using Microsoft.SqlServer.TransactSql.ScriptDom;
using SQLLinter.Common;
using SQLLinter.Core.Interfaces;
using SQLLinter.Infrastructure.Configuration.Overrides;
using SQLLinter.Infrastructure.Interfaces;
using SQLLinter.Infrastructure.Rules;
using SQLLinter.Infrastructure.Rules.RuleExceptions;
using SQLLinter.Infrastructure.Rules.RuleViolations;
using System.Data;

namespace SQLLinter.Infrastructure.Parser;

/// <summary>
/// Parses a SQL file and runs every rule exposed by the plugin handler against the
/// resulting fragment. It also runs the rules against dynamic SQL strings found inside
/// the fragment. Rule violations and parser errors are forwarded to the reporter.
/// </summary>
public class SqlRuleVisitor : IRuleVisitor
{
    // A rule whose full type name contains any of these strings is not run against dynamic SQL.
    private static readonly string[] DynamicSqlExcludedRules =
    {
        "SetAnsiNullsRule",
        "SetNoCountRule",
        "SetQuotedIdentifierRule",
        "SetTransactionIsolationLevelRule",
        "UnicodeStringRule",
    };

    private readonly IFragmentBuilder _fragmentBuilder;
    private readonly IReporter _reporter;
    private readonly ISqlStreamReaderBuilder _sqlStreamReaderBuilder;
    private readonly IPluginHandler _pluginHandler;
    private readonly OverrideFinder _overrideFinder = new OverrideFinder();

    /// <summary>
    /// Creates a visitor that uses a default <c>SqlStreamReaderBuilder</c> to open SQL streams.
    /// </summary>
    public SqlRuleVisitor(IPluginHandler pluginHandler, IFragmentBuilder fragmentBuilder, IReporter reporter)
        : this(pluginHandler, fragmentBuilder, reporter, new SqlStreamReaderBuilder())
    {
    }

    /// <summary>
    /// Creates a visitor with an explicit stream-reader builder.
    /// </summary>
    /// <param name="pluginHandler">Supplies the rules to run.</param>
    /// <param name="fragmentBuilder">Parses SQL text into a fragment.</param>
    /// <param name="reporter">Receives violations and parser errors.</param>
    /// <param name="sqlStreamReaderBuilder">Creates readers over SQL streams.</param>
    public SqlRuleVisitor(IPluginHandler pluginHandler, IFragmentBuilder fragmentBuilder, IReporter reporter, ISqlStreamReaderBuilder sqlStreamReaderBuilder)
    {
        _fragmentBuilder = fragmentBuilder;
        _reporter = reporter;
        _pluginHandler = pluginHandler;
        _sqlStreamReaderBuilder = sqlStreamReaderBuilder;
    }

    /// <summary>
    /// Parses <paramref name="sqlFileStream"/>, reports any parser errors, and runs every rule
    /// against the parsed fragment.
    /// </summary>
    /// <param name="sqlPath">File path, passed to the fragment builder and used in every report.</param>
    /// <param name="ignoredRules">
    /// Rule exceptions for the file. A parser error is suppressed when a global exception
    /// covers its line.
    /// </param>
    /// <param name="sqlFileStream">
    /// The SQL source. It is first scanned for overrides and then parsed.
    /// </param>
    /// <param name="generateDetails">
    /// When true, a parent map of the fragment is built and passed to each rule via
    /// <c>SetParents</c>. Otherwise null is passed.
    /// </param>
    public void VisitRules(string sqlPath, IEnumerable<IRuleException> ignoredRules, Stream sqlFileStream, bool generateDetails)
    {
        // NOTE(review): the override finder reads the stream before it is parsed below;
        // presumably GetOverrideList rewinds it — confirm.
        var overrides = _overrideFinder.GetOverrideList(sqlFileStream);
        var overrideArray = overrides as IOverride[] ?? overrides.ToArray();

        var sqlFragment = _fragmentBuilder.GetFragment(sqlPath, GetSqlTextReader(sqlFileStream), out var errors, overrideArray);

        // Report syntax errors BEFORE bailing out on a missing fragment. Otherwise a script the
        // parser could not turn into a tree would be silently treated as clean.
        if (errors != null && errors.Any())
        {
            HandleParserErrors(sqlPath, errors, ignoredRules);
        }

        if (sqlFragment == null)
        {
            return;
        }

        var parentMap = generateDetails ? ParentMapBuilder.Build(sqlFragment) : null;

        foreach (var rule in _pluginHandler.Rules)
        {
            rule.SetParents(parentMap);
            VisitFragment(sqlFragment, rule, overrideArray, sqlPath);
        }
    }

    /// <summary>
    /// Runs <paramref name="rule"/> against <paramref name="sqlFragment"/>. Unless the rule is
    /// excluded, it also runs against every dynamic SQL string found in the fragment. All
    /// resulting violations are then reported with the rule's severity.
    /// </summary>
    private void VisitFragment(TSqlFragment sqlFragment, IRule rule, IEnumerable<IOverride> overrides, string filePath)
    {
        var violations = rule.Analyze(sqlFragment).ToList();

        if (!VisitorIsBlackListedForDynamicSql(rule))
        {
            // The callback below overwrites the rule's dynamic-SQL offsets. Rules come from the
            // plugin handler and may be reused for later files, so put the previous values back
            // afterwards. This keeps the offsets from leaking into the next top-level analysis.
            var savedStartLine = rule.DynamicSqlStartLine;
            var savedStartColumn = rule.DynamicSqlStartColumn;
            try
            {
                sqlFragment.Accept(new DynamicSQLParser(DynamicSqlCallback));
            }
            finally
            {
                rule.DynamicSqlStartLine = savedStartLine;
                rule.DynamicSqlStartColumn = savedStartColumn;
            }
        }

        foreach (var violation in violations)
        {
            _reporter.ReportViolation(filePath, violation.Line, violation.Column, rule.Severity, violation.RuleName, violation.Template, violation.Snippet, violation.Params);
        }

        // Called by DynamicSQLParser for each dynamic SQL string. It records the string's
        // start position on the rule, parses the string, and appends the rule's findings.
        void DynamicSqlCallback(string dynamicSql, int dynamicSqlStartLine, int dynamicSqlStartColumn)
        {
            rule.DynamicSqlStartLine = dynamicSqlStartLine;
            rule.DynamicSqlStartColumn = dynamicSqlStartColumn;

            using var dynamicSqlStream = ParsingUtility.GenerateStreamFromString(dynamicSql);
            using var dynamicSqlReader = GetSqlTextReader(dynamicSqlStream);

            // Parse errors inside dynamic SQL are discarded; only a successfully built fragment is analyzed.
            var dynamicFragment = _fragmentBuilder.GetFragment(filePath, dynamicSqlReader, out _, overrides);
            if (dynamicFragment != null)
            {
                violations.AddRange(rule.Analyze(dynamicFragment));
            }
        }
    }

    /// <summary>
    /// Returns true when the rule's full type name contains one of
    /// <see cref="DynamicSqlExcludedRules"/>. Such rules are not applied to dynamic SQL.
    /// </summary>
    private static bool VisitorIsBlackListedForDynamicSql(IRule visitor)
    {
        var typeName = visitor.GetType().ToString();
        return DynamicSqlExcludedRules.Any(excluded => typeName.Contains(excluded, StringComparison.Ordinal));
    }

    /// <summary>Creates a reader over <paramref name="sqlFileStream"/> using the configured builder.</summary>
    private StreamReader GetSqlTextReader(Stream sqlFileStream)
    {
        return _sqlStreamReaderBuilder.CreateReader(sqlFileStream);
    }

    /// <summary>
    /// Reports each parser error as a critical "invalid-syntax" violation, unless a global rule
    /// exception covers the error's line. Sets <see cref="Environment.ExitCode"/> to 1 when the
    /// first such error is reported.
    /// </summary>
    private void HandleParserErrors(string sqlPath, IEnumerable<ParseError> errors, IEnumerable<IRuleException> ignoredRules)
    {
        // The set of global exceptions does not depend on the error, so filter it only once.
        var globalExceptions = ignoredRules.OfType<GlobalRuleException>().ToArray();
        var exitCodeSet = false;

        foreach (var error in errors)
        {
            var suppressed = globalExceptions.Any(x => error.Line >= x.StartLine && error.Line <= x.EndLine);
            if (suppressed)
            {
                continue;
            }

            _reporter.ReportViolation(new RuleViolation()
            {
                FileName = sqlPath,
                RuleName = "invalid-syntax",
                Text = error.Message,
                Line = error.Line,
                Column = error.Column,
                Severity = RuleViolationSeverity.Critical
            });

            if (!exitCodeSet)
            {
                exitCodeSet = true;
                Environment.ExitCode = 1;
            }
        }
    }
}