using Microsoft.SqlServer.TransactSql.ScriptDom;
using SQLLinter.Common;
using SQLLinter.Common.Helpers;
using SQLLinter.Core;

namespace SQLLinter.Infrastructure.Rules;

/// <summary>
/// Reports user-defined (non-system) functions used inside queries:
/// scalar function calls anywhere within a <see cref="QuerySpecification"/> subtree, and
/// table-valued functions that appear on a side of a join/APPLY or next to other table
/// references in a query's FROM clause.
/// </summary>
/// <remarks>
/// A query whose FROM clause consists of exactly one table reference is allowed even when that
/// reference is a user table-valued function (<c>SELECT ... FROM dbo.fn(...)</c>).
/// Every node is checked exactly once, by the visitor's normal child traversal; the rule never
/// re-visits subtrees manually, so each offending function is reported a single time.
/// </remarks>
public class UserFunctionJoinRule : BaseRuleVisitor
{
    public override string Text => "Запрещено использование пользовательских функций внутри запросов: {0}";

    // Number of QuerySpecification nodes currently being traversed. A counter (rather than a bool)
    // keeps the outer query's context intact after a nested subquery has been left.
    private int _queryDepth;

    /// <summary>
    /// Marks the whole <see cref="QuerySpecification"/> subtree (select list, FROM, WHERE, nested
    /// subqueries, ...) as query context for <see cref="Visit(FunctionCall)"/>.
    /// </summary>
    /// <remarks>
    /// The context is established here rather than in <c>Visit(QuerySpecification)</c>:
    /// ScriptDom's <c>ExplicitVisit</c> calls <c>Visit(node)</c> first and only then traverses the
    /// node's children. A flag that was set and cleared inside <c>Visit</c> would therefore already
    /// be cleared when the child <see cref="FunctionCall"/> nodes are reached.
    /// </remarks>
    public override void ExplicitVisit(QuerySpecification node)
    {
        _queryDepth++;
        try
        {
            base.ExplicitVisit(node);
        }
        finally
        {
            // Restore even if traversal throws, so a reused visitor never stays "inside" a query.
            _queryDepth--;
        }
    }

    /// <summary>
    /// Checks the FROM clause of a query. When it lists more than one table reference, each user
    /// table-valued function among them is reported. A FROM clause with a single table reference
    /// is allowed, even if that reference is a user table-valued function.
    /// </summary>
    /// <remarks>
    /// Functions nested inside joins are reported by the join visitors, and derived tables and
    /// subqueries are reached through the normal child traversal, so neither is handled here.
    /// </remarks>
    public override void Visit(QuerySpecification node)
    {
        var tables = node.FromClause?.TableReferences;
        if (tables is { Count: > 1 })
        {
            foreach (var table in tables)
            {
                ReportUserTableFunction(table);
            }
        }

        base.Visit(node);
    }

    /// <summary>Reports a user scalar function that is called anywhere inside a query.</summary>
    public override void Visit(FunctionCall node)
    {
        if (_queryDepth > 0 && IsUserFunction(node.FunctionName.Value))
        {
            AddViolationFunction(node);
        }

        base.Visit(node);
    }

    /// <summary>
    /// Reports user table-valued functions used directly on either side of a qualified (ON) join.
    /// </summary>
    /// <remarks>
    /// Nested joins and the ON condition are reached through the normal child traversal. Like the
    /// unqualified-join check, this is not limited to query context.
    /// </remarks>
    public override void Visit(QualifiedJoin node)
    {
        ReportUserTableFunction(node.FirstTableReference);
        ReportUserTableFunction(node.SecondTableReference);
        base.Visit(node);
    }

    /// <summary>
    /// Reports user table-valued functions used directly on either side of an unqualified join
    /// (e.g. CROSS APPLY / OUTER APPLY).
    /// </summary>
    /// <remarks>Like the qualified-join check, this is not limited to query context.</remarks>
    public override void Visit(UnqualifiedJoin node)
    {
        ReportUserTableFunction(node.FirstTableReference);
        ReportUserTableFunction(node.SecondTableReference);
        base.Visit(node);
    }

    // ------------------------
    // Helper methods
    // ------------------------

    /// <summary>
    /// Reports a scalar function call. The reported name is the call target's multi-part name
    /// (when present) followed by a dot and the bracketed function name.
    /// </summary>
    private void AddViolationFunction(FunctionCall func)
    {
        string ident = "";
        if (func.CallTarget is MultiPartIdentifierCallTarget callTarget)
        {
            ident = SQLHelpers.ObjectGetFullName(callTarget.MultiPartIdentifier.Identifiers);
            if (!string.IsNullOrWhiteSpace(ident))
            {
                ident += ".";
            }
        }

        ident += "[" + func.FunctionName.Value + "]";
        AddViolation(func, ident);
    }

    /// <summary>
    /// Reports <paramref name="tableReference"/> when it is a call to a user table-valued function.
    /// Any other kind of table reference, including <c>null</c>, is ignored.
    /// </summary>
    private void ReportUserTableFunction(TableReference? tableReference)
    {
        if (tableReference is SchemaObjectFunctionTableReference function
            && IsUserFunction(function.SchemaObject?.BaseIdentifier?.Value))
        {
            AddViolation(function, SQLHelpers.ObjectGetFullName(function.SchemaObject));
        }
    }

    /// <summary>
    /// Returns <c>true</c> when <paramref name="functionName"/> is non-empty and not listed in
    /// <c>Constants.SystemFunctions</c>.
    /// </summary>
    /// <remarks>
    /// NOTE(review): whether this match ignores case depends on how Constants.SystemFunctions is
    /// declared (its comparer). T-SQL names are case-insensitive, so confirm it uses a
    /// case-insensitive comparer.
    /// </remarks>
    private bool IsUserFunction(string? functionName)
    {
        return !string.IsNullOrEmpty(functionName) && !Constants.SystemFunctions.Contains(functionName);
    }
}