使用 DacFX SqlCodeAnalysisRule 检查 where 子句中的数据类型

Check datatypes in where clause using DacFX SqlCodeAnalysisRule

我正在尝试编写一个 SqlCodeAnalysisRule 规则来检查 where 子句中的比较运算符,以确保双方具有相同的数据类型以实现可搜索性。但是,我很难从 ScriptDom ColumnReferenceExpression 对象中找出数据类型实际是什么。我花了几天时间研究这个,但真的一无所获。希望有人可以提供帮助。

当被分析的对象是可以包含多个 select 语句的存储过程和函数,甚至是带有多个子查询(sub-select)或 union 的视图时,情况会变得更加复杂。我看过这个回答《Get DataType of computed column from dacpac》,但它没有说明如何从 select 中的列追溯到该列所属的表。

这是我目前规则中的 Analyze 方法:

    /// <summary>
    /// Walks every SELECT statement in the analyzed element's script and collects the
    /// boolean comparisons found in its WHERE clause. This draft does not report any
    /// problems yet; the data-type comparison itself is still an open placeholder.
    /// </summary>
    /// <param name="ruleExecutionContext">Context supplying the model element and its script fragment.</param>
    /// <returns>The list of problems found (always empty in this draft).</returns>
    public override IList<SqlRuleProblem> Analyze(SqlRuleExecutionContext ruleExecutionContext)
    {
        var problems = new List<SqlRuleProblem>();
        var sqlObj = ruleExecutionContext.ModelElement;

        // Nothing to analyze without a model element.
        if (sqlObj == null)
        {
            return problems;
        }

        try
        {
            var fragment = ruleExecutionContext.ScriptFragment;

            var selectStatementVisitor = new SelectStatementVisitor();
            fragment.Accept(selectStatementVisitor);

            foreach (var select in selectStatementVisitor.Statements)
            {
                // Only plain query specifications with a WHERE clause are inspected.
                var query = select.QueryExpression as QuerySpecification;
                if (query == null || query.WhereClause == null)
                {
                    continue;
                }

                var booleanComparisonVisitor = new BooleanComparisonVisitor();
                query.WhereClause.Accept(booleanComparisonVisitor);

                foreach (var comparison in booleanComparisonVisitor.Statements)
                {
                    var refColumn = comparison.FirstExpression as ColumnReferenceExpression;
                    // HERE I WOULD LIKE TO INSPECT THE DATATYPE OF FirstExpression AND COMPARE IT TO THE DATATYPE OF the SecondExpression 
                }
            }
        }
        catch (System.Exception ex)
        {
            //TODO: PROPERLY LOG THIS ERROR
            // Best-effort: a failure while walking the script must not break the analysis run.
            Debug.WriteLine(ex.ToString());
            //throw;
        }

        return problems;
    }

好的,我把自己想出的解决方案贴出来,尽管它并不是 100% 准确。之所以做不到 100% 准确,是因为 DacFX API 本身的限制:它不允许查询内置函数的信息。虽然不是 100%,但事实证明这个方案相当准确。凡是无法确定数据类型的情况我都会跳过(返回 null),这部分隐式转换只能依靠扫描计划缓存来发现。

/// <summary>
/// Code analysis rule that reports comparisons inside a WHERE clause whose two operands
/// resolve to different SQL data type names. Operands whose type cannot be determined
/// are skipped (the type resolvers return <c>null</c>), so the rule is best-effort.
/// </summary>
[ExportCodeAnalysisRule(DataTypesOnBothSidesOfEqualityRule.RuleId,
    DataTypesOnBothSidesOfEqualityRule.RuleDisplayName,
    Description = DataTypesOnBothSidesOfEqualityRule.RuleDisplayName,
    Category = Constants.Performance,
    RuleScope = SqlRuleScope.Element)]
public sealed class DataTypesOnBothSidesOfEqualityRule : SqlCodeAnalysisRule
{
    public const string RuleId = Constants.RuleNameSpace + "SRP0016";
    public const string RuleDisplayName = "Data types on both sides of an equality check should be the same in the where clause. (Sargeable)";
    public const string Message = "{0} {1} data types on both sides of an equality do not match in the where clause.";

    // Comparer used for every data type, variable, alias and column name comparison in this rule.
    private static readonly StringComparer _comparer = StringComparer.InvariantCultureIgnoreCase;

    // NOTE(review): this field is not referenced anywhere in this class — confirm whether it can be removed.
    private readonly FunctionChain functionChain = new FunctionChain();

    #region built in function data types
    // Maps a built-in function name to the data type name this rule assumes it returns.
    private static readonly IDictionary<string, string> _functions = new Dictionary<string, string>(StringComparer.InvariantCultureIgnoreCase)
    {
        /*Date and Time Data Types and Functions (Transact-SQL)*/
        { "SYSDATETIME", "datetime2" },
        { "SYSDATETIMEOFFSET", "datetimeoffset" },
        { "SYSUTCDATETIME", "datetime2" },
        { "CURRENT_TIMESTAMP", "datetime" },
        { "GETDATE", "datetime" },
        { "GETUTCDATE", "datetime" },
        { "DATENAME", "nvarchar" },
        { "DATEPART", "int" },
        { "DAY", "int" },
        { "MONTH", "int" },
        { "YEAR", "int" },
        { "DATEFROMPARTS", "date" },
        { "DATETIME2FROMPARTS", "datetime2" },
        { "DATETIMEFROMPARTS", "datetime" },
        { "DATETIMEOFFSETFROMPARTS", "datetimeoffset" },
        { "SMALLDATETIMEFROMPARTS", "smalldatetime" },
        { "TIMEFROMPARTS", "time" },
        { "DATEDIFF", "int" },
        { "DATEDIFF_BIG", "bigint" },
        { "ISDATE", "int" },
        /* Mathematical Functions (Transact-SQL)*/
        { "ACOS", "float" },
        { "ASIN", "float" },
        { "ATAN", "float" },
        { "ATN2", "float" },
        { "COS", "float" },
        { "COT", "float" },
        { "EXP", "float" },
        { "LOG", "float" },
        { "LOG10", "float" },
        { "PI", "float" },
        { "POWER", "float" },
        { "RAND", "float" },
        //{ "ROUND", "" }, completely unable to figure out how to map these. leaving commented here to mark that in case someone else figures it out
        //{ "SIGN", "" },
        { "SIN", "float" },
        { "SQRT", "float" },
        { "SQUARE", "float" },
        { "TAN", "float" },
        /*String Functions (Transact-SQL)*/
        { "ASCII", "int" },
        { "CHAR", "char" },
        { "DIFFERENCE", "int" },
        { "FORMAT", "nvarchar" },
        { "QUOTENAME", "nvarchar" },
        { "SOUNDEX", "varchar" },
        { "SPACE", "varchar" },
        { "STR", "varchar" },
        { "STRING_ESCAPE", "nvarchar" },
        { "UNICODE", "int" },
        /* System Functions (Transact-SQL)*/
        { "HOST_ID", "char" },
        { "HOST_NAME", "nvarchar" },
        { "ISNUMERIC", "int" },
        { "NEWID", "uniqueidentifier" },
        { "NEWSEQUENTIALID", "uniqueidentifier" },
        { "ROWCOUNT_BIG", "bigint" },
        { "SESSION_CONTEXT", "sql_variant" },
        { "SESSION_ID", "nvarchar" },
        { "XACT_STATE", "smallint" },
        /*Metadata Functions (Transact-SQL)*/
        { "APP_NAME", "nvarchar" },
        { "APPLOCK_MODE", "nvarchar" },
        { "APPLOCK_TEST", "smallint" },
        { "ASSEMBLYPROPERTY", "sql_variant" },
        { "COL_LENGTH", "smallint" },
        { "COL_NAME", "nvarchar" },
        { "COLUMNPROPERTY", "int" },
        { "DATABASEPROPERTYEX", "sql_variant" },
        { "DB_ID", "int" },
        { "DB_NAME", "nvarchar" },
        { "FILE_ID", "smallint" },
        { "FILE_IDEX", "int" },
        { "FILE_NAME", "nvarchar" },
        { "FILEGROUP_ID", "int" },
        { "FILEGROUP_NAME", "nvarchar" },
        { "FILEGROUPPROPERTY", "int" },
        { "FILEPROPERTY", "int" },
        { "FULLTEXTCATALOGPROPERTY", "int" },
        { "FULLTEXTSERVICEPROPERTY", "int" },
        { "INDEX_COL", "nvarchar" },
        { "INDEXKEY_PROPERTY", "int" },
        { "INDEXPROPERTY", "int" },
        { "OBJECT_DEFINITION", "nvarchar" },
        { "OBJECT_ID", "int" },
        { "OBJECT_NAME", "sysname" },
        { "OBJECT_SCHEMA_NAME", "sysname" },
        { "OBJECTPROPERTY", "int" },
        { "OBJECTPROPERTYEX", "sql_variant" },
        { "ORIGINAL_DB_NAME", "nvarchar" },
        { "PARSENAME", "nchar" },
        { "SCHEMA_ID", "int" },
        { "SCHEMA_NAME", "sysname" },
        { "SCOPE_IDENTITY", "numeric" },
        { "SERVERPROPERTY", "sql_variant" },
        { "STATS_DATE", "datetime" },
        { "TYPE_ID", "int" },
        { "TYPE_NAME", "sysname" },
        { "TYPEPROPERTY", "int" },
        /*Security Functions (Transact-SQL)*/
        { "CERTENCODED", "varbinary" },
        { "CERTPRIVATEKEY", "varbinary" },
        { "CURRENT_USER", "sysname" },
        { "DATABASE_PRINCIPAL_ID", "int" },
        { "HAS_PERMS_BY_NAME", "int" },
        { "IS_MEMBER", "int" },
        { "IS_ROLEMEMBER", "int" },
        { "IS_SRVROLEMEMBER", "int" },
        { "ORIGINAL_LOGIN", "sysname" },
        { "PERMISSIONS", "int" },
        { "PWDCOMPARE", "int" },
        { "PWDENCRYPT", "varbinary" },
        { "SESSION_USER", "nvarchar" },
        { "SUSER_ID", "int" },
        { "SUSER_SID", "varbinary" },
        { "SUSER_SNAME", "nvarchar" },
        { "SYSTEM_USER", "nchar" },
        { "SUSER_NAME", "nvarchar" },
        { "USER_ID", "int" },
        { "USER_NAME", "nvarchar" }
    };
    #endregion

    /// <summary>
    /// Registers the element types this rule applies to: procedures, views and functions.
    /// </summary>
    public DataTypesOnBothSidesOfEqualityRule()
    {
        SupportedElementTypes = new[] { ModelSchema.Procedure, ModelSchema.View, ModelSchema.ScalarFunction, ModelSchema.TableValuedFunction };
    }

    /// <summary>
    /// Inspects every comparison in the WHERE clause of each SELECT statement of the analyzed
    /// element and reports those whose two sides resolve to different data type names.
    /// </summary>
    /// <param name="ruleExecutionContext">Context supplying the model element and its script fragment.</param>
    /// <returns>The problems found; empty when nothing could be flagged.</returns>
    public override IList<SqlRuleProblem> Analyze(SqlRuleExecutionContext ruleExecutionContext)
    {
        var problems = new List<SqlRuleProblem>();
        var sqlObj = ruleExecutionContext.ModelElement; //proc / view / function

        if (sqlObj == null)
        {
            return problems;
        }

        try
        {
            var fragment = ruleExecutionContext.ScriptFragment;

            //get the combined parameters and declare variables into one searchable list
            var variablesVisitor = new VariablesVisitor();
            fragment.AcceptChildren(variablesVisitor);
            var variables = variablesVisitor.GetVariables();

            var selectStatementVisitor = new SelectStatementVisitor();
            fragment.Accept(selectStatementVisitor);

            foreach (var select in selectStatementVisitor.Statements)
            {
                // Only plain query specifications with a WHERE clause are inspected.
                var query = select.QueryExpression as QuerySpecification;
                if (query == null || query.WhereClause == null)
                {
                    continue;
                }

                var booleanComparisonVisitor = new BooleanComparisonVisitor();
                query.WhereClause.Accept(booleanComparisonVisitor);

                foreach (var comparison in booleanComparisonVisitor.Statements)
                {
                    // An unresolved type on either side means we cannot judge the comparison.
                    var datatype1 = GetDataType(sqlObj, query, comparison.FirstExpression, variables);
                    if (string.IsNullOrEmpty(datatype1)) { continue; }

                    var datatype2 = GetDataType(sqlObj, query, comparison.SecondExpression, variables);
                    if (string.IsNullOrEmpty(datatype2)) { continue; }

                    //when checking the numeric literal I am not sure if it is a bit or tinyint.
                    if ((_comparer.Equals(datatype1, "bit") && _comparer.Equals(datatype2, "tinyint")) || (_comparer.Equals(datatype1, "tinyint") && _comparer.Equals(datatype2, "bit"))) { continue; }

                    if (!_comparer.Equals(datatype1, datatype2))
                    {
                        string msg = string.Format(Message, sqlObj.ObjectType.Name, RuleUtils.GetElementName(ruleExecutionContext, sqlObj));
                        problems.Add(new SqlRuleProblem(msg, sqlObj, comparison));
                    }
                }
            }
        }
        catch (System.Exception ex)
        {
            //TODO: PROPERLY LOG THIS ERROR
            // Best-effort: problems collected before the failure are still returned.
            Debug.WriteLine(ex.ToString());
            //throw;
        }

        return problems;
    }

    /// <summary>
    /// Resolves the data type name of a scalar expression, or <c>null</c> when it cannot be determined.
    /// </summary>
    /// <param name="sqlObj">The model element being analyzed; used to resolve column types.</param>
    /// <param name="query">The query whose FROM clause scopes column references.</param>
    /// <param name="expression">The expression to type.</param>
    /// <param name="variables">Declared variables and parameters collected from the script.</param>
    /// <returns>A data type name such as <c>int</c> or <c>nvarchar</c>, or <c>null</c> if unknown.</returns>
    private string GetDataType(TSqlObject sqlObj, QuerySpecification query, ScalarExpression expression, IList<VariableView> variables)
    {
        if (expression == null)
        {
            return null;
        }

        if (expression is ColumnReferenceExpression)
        {
            return GetColumnDataType(sqlObj, query, expression);
        }

        // Parentheses do not change the type of the wrapped expression.
        var parenthesis = expression as ParenthesisExpression;
        if (parenthesis != null)
        {
            return GetDataType(sqlObj, query, parenthesis.Expression, variables);
        }

        var stringLiteral = expression as StringLiteral;
        if (stringLiteral != null)
        {
            return stringLiteral.IsNational ? "nvarchar" : "varchar";
        }

        var integerLiteral = expression as IntegerLiteral;
        if (integerLiteral != null)
        {
            return GetIntegerLiteralDataType(integerLiteral.Value);
        }

        var castCall = expression as CastCall;
        if (castCall != null)
        {
            return GetDataTypeName(castCall.DataType);
        }

        // CONVERT(type, value) names its target type just like CAST does.
        var convertCall = expression as ConvertCall;
        if (convertCall != null)
        {
            return GetDataTypeName(convertCall.DataType);
        }

        var variableReference = expression as VariableReference;
        if (variableReference != null)
        {
            var variable = variables.FirstOrDefault(v => _comparer.Equals(v.Name, variableReference.Name));
            return variable != null ? variable.DataType : null;
        }

        var function = expression as FunctionCall;
        if (function != null)
        {
            //TIM C: sigh, this does not work for all functions. the api does not allow for me to look up built in functions. nor does it allow me to get the 
            //data types of parameters, so I am not able to type ALL functions like DATEADD, the parameter could be a column, string literal, variable, function etc...
            string functionType;
            if (function.FunctionName != null && _functions.TryGetValue(function.FunctionName.Value, out functionType))
            {
                return functionType;
            }
            return null;
        }

        var binaryExpression = expression as BinaryExpression;
        if (binaryExpression != null)
        {
            // Use the first operand whose type can be resolved.
            var datatype1 = GetDataType(sqlObj, query, binaryExpression.FirstExpression, variables);
            if (!string.IsNullOrEmpty(datatype1)) { return datatype1; }
            return GetDataType(sqlObj, query, binaryExpression.SecondExpression, variables);
        }

        var scalarSubquery = expression as ScalarSubquery;
        if (scalarSubquery != null)
        {
            // Only a plain query specification whose first select element is a scalar
            // expression can be typed; anything else (e.g. a UNION) is skipped.
            var scalarQuery = scalarSubquery.QueryExpression as QuerySpecification;
            var selectElement = scalarQuery?.SelectElements.FirstOrDefault() as SelectScalarExpression;
            if (selectElement == null)
            {
                return null;
            }
            return GetDataType(sqlObj, scalarQuery, selectElement.Expression, variables);
        }

        Debug.WriteLine("Unknown expression");
        return null;
    }

    /// <summary>
    /// Picks the smallest integer type name whose range contains the literal's value.
    /// </summary>
    /// <param name="literal">The literal text as it appears in the script.</param>
    /// <returns>
    /// <c>tinyint</c>, <c>smallint</c>, <c>int</c> or <c>bigint</c>; <c>null</c> when the
    /// text does not fit in a 64-bit integer.
    /// </returns>
    private static string GetIntegerLiteralDataType(string literal)
    {
        long val;
        if (!long.TryParse(literal, System.Globalization.NumberStyles.Integer, System.Globalization.CultureInfo.InvariantCulture, out val))
        {
            return null;
        }

        if (val >= byte.MinValue && val <= byte.MaxValue) // to bit or not to bit? NFC.
        {
            return "tinyint";
        }
        if (val >= short.MinValue && val <= short.MaxValue)
        {
            return "smallint";
        }
        if (val >= int.MinValue && val <= int.MaxValue)
        {
            return "int";
        }
        return "bigint";
    }

    /// <summary>
    /// Returns the last identifier of a data type reference (the type name without any
    /// schema qualifier), or <c>null</c> when the reference carries no name.
    /// </summary>
    private static string GetDataTypeName(DataTypeReference dataType)
    {
        return dataType?.Name?.Identifiers.LastOrDefault()?.Value;
    }

    /// <summary>
    /// Resolves the data type name of a column reference by locating the column in the
    /// tables named in the query's FROM clause.
    /// </summary>
    /// <param name="sqlObj">The model element whose referenced tables are searched.</param>
    /// <param name="query">The query whose FROM clause lists the candidate tables.</param>
    /// <param name="expression">The column reference expression.</param>
    /// <returns>The column's data type name, or <c>null</c> when it cannot be resolved.</returns>
    private string GetColumnDataType(TSqlObject sqlObj, QuerySpecification query, ScalarExpression expression)
    {
        var column = expression as ColumnReferenceExpression;
        if (column == null || column.MultiPartIdentifier == null || query == null || query.FromClause == null)
        {
            return null;
        }

        var identifiers = column.MultiPartIdentifier.Identifiers;
        if (identifiers.Count == 0)
        {
            return null;
        }
        var columnName = identifiers.Last().Value;

        var namedTableVisitor = new NamedTableReferenceVisitor();
        query.FromClause.AcceptChildren(namedTableVisitor);

        var tables = new List<NamedTableReference>();
        if (identifiers.Count >= 2)
        {
            // The identifier just before the column name is either a table alias or,
            // for a table referenced without an alias, the table name itself.
            var qualifier = identifiers[identifiers.Count - 2].Value;
            tables.AddRange(namedTableVisitor.Statements.Where(x =>
                x.Alias != null
                    ? _comparer.Equals(x.Alias.Value, qualifier)
                    : _comparer.Equals(x.SchemaObject?.BaseIdentifier?.Value, qualifier)));
        }
        else
        {
            //they did NOT use a two part name, so logic dictates that this column SHOULD only appear once in the list of tables, but we will have to search all of the tables.
            tables.AddRange(namedTableVisitor.Statements);
        }

        if (tables.Count == 0)
        {
            return null;
        }

        var referencedTables = sqlObj.GetReferenced().Where(x => x.ObjectType == Table.TypeClass && tables.Any(t => x.Name.CompareTo(t.SchemaObject.Identifiers)));

        foreach (var referencedTable in referencedTables)
        {
            string fullColumnName = referencedTable.Name.ToString() + ".[" + columnName + "]";

            var retColumn = referencedTable.GetReferencedRelationshipInstances(Table.Columns).FirstOrDefault(p => _comparer.Equals(p.ObjectName.ToString(), fullColumnName));

            if (retColumn != null)
            {
                // A column without a data type reference yields null, so the caller skips the comparison.
                var dataType = retColumn.Object.GetReferenced(Column.DataType)
                            .FirstOrDefault();

                // The last name part is the type name; earlier parts, if any, qualify it.
                return dataType?.Name.Parts.LastOrDefault();
            }
        }

        return null;
    }
}

变量访问者(VariablesVisitor):

/// <summary>
/// Identifies where a variable in the analyzed script came from.
/// </summary>
internal enum VariableViewType
{
    /// <summary>The variable is a parameter of the procedure or function.</summary>
    Parameter = 0,

    /// <summary>The variable was introduced by a DECLARE statement.</summary>
    Declare = 1
}
/// <summary>
/// Flattened description of a variable or parameter: its name, its data type name
/// and whether it is a parameter or a declared variable.
/// </summary>
internal class VariableView
{
    /// <summary>Variable name as written in the script.</summary>
    public string Name { get; set; }

    /// <summary>Data type name of the variable.</summary>
    public string DataType { get; set; }

    /// <summary>Origin of the variable (parameter or DECLARE).</summary>
    public VariableViewType Type { get; set; }

    /// <summary>
    /// Formats the view as "Type Name DataType"; falls back to the default
    /// representation when no name has been set.
    /// </summary>
    public override string ToString()
    {
        return string.IsNullOrEmpty(Name)
            ? base.ToString()
            : $"{Type} {Name} {DataType}";
    }
}

/// <summary>
/// Fragment visitor that collects DECLARE statements and procedure parameters so they
/// can be flattened into a single list of <see cref="VariableView"/> entries.
/// </summary>
internal class VariablesVisitor : TSqlFragmentVisitor
{
    /// <summary>DECLARE statements encountered during the visit.</summary>
    public IList<DeclareVariableStatement> DeclareVariables { get; private set; }

    /// <summary>Parameters encountered during the visit.</summary>
    public IList<ProcedureParameter> ProcedureParameters { get; private set; }

    public VariablesVisitor()
    {
        DeclareVariables = new List<DeclareVariableStatement>();
        ProcedureParameters = new List<ProcedureParameter>();
    }

    /// <summary>Records a DECLARE statement.</summary>
    public override void ExplicitVisit(DeclareVariableStatement node)
    {
        DeclareVariables.Add(node);
    }

    /// <summary>Records a parameter.</summary>
    public override void ExplicitVisit(ProcedureParameter node)
    {
        ProcedureParameters.Add(node);
    }

    /// <summary>
    /// Builds one <see cref="VariableView"/> per declared variable and per parameter,
    /// declared variables first. Entries whose name or data type name cannot be read
    /// are skipped instead of raising a null reference error.
    /// </summary>
    /// <returns>The combined list of variables and parameters.</returns>
    public IList<VariableView> GetVariables()
    {
        var ret = new List<VariableView>();

        foreach (var statements in DeclareVariables)
        {
            foreach (var declare in statements.Declarations)
            {
                AddVariable(ret, declare.VariableName, declare.DataType, VariableViewType.Declare);
            }
        }

        foreach (var parameter in ProcedureParameters)
        {
            AddVariable(ret, parameter.VariableName, parameter.DataType, VariableViewType.Parameter);
        }

        return ret;
    }

    /// <summary>
    /// Appends a view for the given variable when both its name and the last identifier
    /// of its data type are available.
    /// </summary>
    private static void AddVariable(ICollection<VariableView> target, Identifier variableName, DataTypeReference dataType, VariableViewType type)
    {
        // The last identifier is the type name itself; earlier ones qualify it.
        var typeName = dataType?.Name?.Identifiers.LastOrDefault()?.Value;
        if (variableName == null || typeName == null)
        {
            return;
        }

        target.Add(new VariableView()
        {
            Name = variableName.Value,
            DataType = typeName,
            Type = type
        });
    }
}

其他访问者:

/// <summary>
/// Fragment visitor that gathers each <see cref="SelectStatement"/> it is handed.
/// The override does not call the base implementation.
/// </summary>
internal class SelectStatementVisitor : TSqlFragmentVisitor
{
    /// <summary>SELECT statements collected so far, in visit order.</summary>
    public IList<SelectStatement> Statements { get; private set; }

    public SelectStatementVisitor()
    {
        this.Statements = new List<SelectStatement>();
    }

    /// <summary>Appends the visited statement to <see cref="Statements"/>.</summary>
    public override void ExplicitVisit(SelectStatement node) => this.Statements.Add(node);
}
/// <summary>
/// Fragment visitor that gathers each <see cref="BooleanComparisonExpression"/> it is handed.
/// The override does not call the base implementation.
/// </summary>
internal class BooleanComparisonVisitor : TSqlFragmentVisitor
{
    /// <summary>Comparison expressions collected so far, in visit order.</summary>
    public IList<BooleanComparisonExpression> Statements { get; private set; }

    public BooleanComparisonVisitor()
    {
        this.Statements = new List<BooleanComparisonExpression>();
    }

    /// <summary>Appends the visited comparison to <see cref="Statements"/>.</summary>
    public override void ExplicitVisit(BooleanComparisonExpression node) => this.Statements.Add(node);
}

NamedTableReference(命名表引用)访问者:

/// <summary>
/// Fragment visitor that gathers each <see cref="NamedTableReference"/> it is handed.
/// The override does not call the base implementation.
/// </summary>
internal class NamedTableReferenceVisitor : TSqlFragmentVisitor
{
    /// <summary>Named table references collected so far, in visit order.</summary>
    public IList<NamedTableReference> Statements { get; private set; }

    public NamedTableReferenceVisitor()
    {
        this.Statements = new List<NamedTableReference>();
    }

    /// <summary>Appends the visited table reference to <see cref="Statements"/>.</summary>
    public override void ExplicitVisit(NamedTableReference node) => this.Statements.Add(node);
}