using Microsoft.SqlServer.TransactSql.ScriptDom;
using System.Text.RegularExpressions;

namespace SQLLinter.Common.Helpers;

/// <summary>
/// Helpers for parsing T-SQL source lines, locating fragments at a rule violation's
/// position, and reading text or indentation back out of the source.
/// </summary>
public static class FixHelpers
{
    // Cached so GetIndent does not construct a new Regex on every call.
    private static readonly Regex LeadingWhitespace = new Regex("^\\s+");

    /// <summary>
    /// Visitor that collects every visited fragment of type <typeparamref name="T"/>
    /// that also satisfies the optional filter predicate.
    /// </summary>
    /// <typeparam name="T">The fragment type to collect.</typeparam>
    public class FindViolatingNodeVisitor<T> : TSqlFragmentVisitor where T : TSqlFragment
    {
        private readonly Func<T, bool> _where;

        /// <summary>Fragments collected so far, in the order the visitor encountered them.</summary>
        public List<T> Nodes = new List<T>();

        /// <summary>Creates a visitor with an optional filter.</summary>
        /// <param name="where">Filter applied to each candidate; when null every <typeparamref name="T"/> is collected.</param>
        public FindViolatingNodeVisitor(Func<T, bool> where = null)
        {
            _where = where;
        }

        /// <summary>Records <paramref name="node"/> if it is a <typeparamref name="T"/> accepted by the filter.</summary>
        /// <param name="node">The fragment being visited.</param>
        public override void Visit(TSqlFragment node)
        {
            if (node is T candidate && (_where == null || _where(candidate)))
            {
                Nodes.Add(candidate);
            }

            base.Visit(node);
        }
    }

    /// <summary>
    /// Parses <paramref name="fileLines"/> and returns the first <typeparamref name="TFind"/> node whose
    /// selected fragment starts at the violation's line and column.
    /// </summary>
    /// <typeparam name="TReturn">Type of the fragment whose start position is compared.</typeparam>
    /// <typeparam name="TFind">Type of node searched for in the parsed script.</typeparam>
    /// <param name="fileLines">Source lines; joined with "\n" before parsing.</param>
    /// <param name="ruleViolation">Violation whose <c>Line</c> and <c>Column</c> identify the target.</param>
    /// <param name="getFragment">Selects, from a candidate node, the fragment to compare; may return null.</param>
    /// <returns>
    /// The selected fragment and the node it came from, or <c>(null, null)</c> when no node matches.
    /// </returns>
    /// <exception cref="Exception">The parser reported errors for the script.</exception>
    public static (TReturn, TFind) FindViolatingNode<TReturn, TFind>(List<string> fileLines, IRuleViolation ruleViolation, Func<TFind, TReturn> getFragment)
        where TFind : TSqlFragment
        where TReturn : TSqlFragment
    {
        foreach (TFind candidate in FindNodes<TFind>(fileLines))
        {
            // Evaluate the selector once per candidate and reuse the result on a match.
            TReturn fragment = getFragment(candidate);
            if (fragment is not null
                && fragment.StartLine == ruleViolation.Line
                && fragment.StartColumn == ruleViolation.Column)
            {
                return (fragment, candidate);
            }
        }

        // No match: return nulls rather than invoking getFragment with a null node,
        // which would throw for any selector that dereferences its argument.
        return (null, null);
    }

    /// <summary>
    /// Parses <paramref name="fileLines"/> and returns every <typeparamref name="T"/> fragment that
    /// satisfies <paramref name="where"/>, or all of them when it is null.
    /// </summary>
    /// <typeparam name="T">The fragment type to collect.</typeparam>
    /// <param name="fileLines">Source lines; joined with "\n" before parsing.</param>
    /// <param name="where">Optional filter applied to each candidate.</param>
    /// <returns>The matching fragments in visit order.</returns>
    /// <exception cref="Exception">The parser reported errors; the message lists them.</exception>
    public static List<T> FindNodes<T>(List<string> fileLines, Func<T, bool> where = null) where T : TSqlFragment
    {
        TSqlFragment script = ParseScript(fileLines);
        return FindNodes<T>(script, where);
    }

    /// <summary>
    /// Returns every <typeparamref name="T"/> fragment under <paramref name="statement"/> that
    /// satisfies <paramref name="where"/>, or all of them when it is null.
    /// </summary>
    /// <typeparam name="T">The fragment type to collect.</typeparam>
    /// <param name="statement">Root fragment to search.</param>
    /// <param name="where">Optional filter applied to each candidate.</param>
    /// <returns>The matching fragments in visit order.</returns>
    public static List<T> FindNodes<T>(TSqlFragment statement, Func<T, bool> where = null) where T : TSqlFragment
    {
        var visitor = new FindViolatingNodeVisitor<T>(where);
        statement.Accept(visitor);
        return visitor.Nodes;
    }

    /// <summary>
    /// Finds the fragment of type <typeparamref name="T"/> that starts at the violation's line and column.
    /// </summary>
    /// <typeparam name="T">The fragment type to look for.</typeparam>
    /// <param name="fileLines">Source lines; joined with "\n" before parsing.</param>
    /// <param name="ruleViolation">Violation whose <c>Line</c> and <c>Column</c> identify the target.</param>
    /// <returns>The matching fragment, or null when none starts at that position.</returns>
    /// <exception cref="Exception">The parser reported errors for the script.</exception>
    public static T FindViolatingNode<T>(List<string> fileLines, IRuleViolation ruleViolation) where T : TSqlFragment
    {
        return FindViolatingNode<T, T>(fileLines, ruleViolation, node => node).Item1;
    }

    /// <summary>Returns the leading whitespace of the line referenced by <paramref name="ruleViolation"/>.</summary>
    /// <param name="fileLines">Source lines.</param>
    /// <param name="ruleViolation">Violation whose 1-based <c>Line</c> selects the line.</param>
    /// <returns>The line's leading whitespace, or an empty string.</returns>
    public static string GetIndent(List<string> fileLines, IRuleViolation ruleViolation)
    {
        return GetIndent(fileLines[ruleViolation.Line - 1]);
    }

    /// <summary>Returns the leading whitespace of the line on which <paramref name="statement"/> starts.</summary>
    /// <param name="fileLines">Source lines.</param>
    /// <param name="statement">Statement whose 1-based <c>StartLine</c> selects the line.</param>
    /// <returns>The line's leading whitespace, or an empty string.</returns>
    public static string GetIndent(List<string> fileLines, TSqlStatement statement)
    {
        return GetIndent(fileLines[statement.StartLine - 1]);
    }

    /// <summary>
    /// Reconstructs the source text of <paramref name="fragment"/> by concatenating the text of its
    /// tokens from <c>FirstTokenIndex</c> to <c>LastTokenIndex</c> inclusive.
    /// </summary>
    /// <param name="fragment">The fragment to render.</param>
    /// <returns>The concatenated token text, or an empty string when the fragment has no tokens.</returns>
    public static string GetString(TSqlFragment fragment)
    {
        IList<TSqlParserToken> tokens = fragment.ScriptTokenStream;
        if (tokens is null)
        {
            return string.Empty;
        }

        // Clamp to the stream bounds so only the fragment's own tokens are visited,
        // instead of filtering the entire script's token stream on every call.
        int start = Math.Max(fragment.FirstTokenIndex, 0);
        int end = Math.Min(fragment.LastTokenIndex, tokens.Count - 1);
        if (end < start)
        {
            return string.Empty;
        }

        return string.Concat(Enumerable.Range(start, end - start + 1).Select(i => tokens[i].Text));
    }

    /// <summary>
    /// Joins <paramref name="fileLines"/> with "\n" and parses the result with <see cref="TSql150Parser"/>.
    /// </summary>
    /// <param name="fileLines">Source lines.</param>
    /// <returns>The root fragment produced by the parser.</returns>
    /// <exception cref="Exception">The parser reported one or more errors.</exception>
    private static TSqlFragment ParseScript(List<string> fileLines)
    {
        using var reader = new StringReader(string.Join("\n", fileLines));
        var parser = new TSql150Parser(initialQuotedIdentifiers: true, SqlEngineType.All);
        TSqlFragment script = parser.Parse(reader, out IList<ParseError> errors);

        if (errors != null && errors.Count > 0)
        {
            // NOTE(review): a more specific exception type would be preferable, but callers or
            // tests may match exactly System.Exception, so the type is kept.
            throw new Exception("Parsing failed. " + string.Join(". \n", errors.Select(e => e.Message)));
        }

        return script;
    }

    /// <summary>Returns the run of whitespace at the start of <paramref name="line"/>, or an empty string.</summary>
    /// <param name="line">The line to inspect.</param>
    private static string GetIndent(string line)
    {
        Match match = LeadingWhitespace.Match(line);
        return match.Success ? match.Value : string.Empty;
    }
}