Entity Framework expression tree lambda visitor error

I'm trying to build a select parser for EF Core that converts a dynamically formatted select-clause string into a LINQ/EF select expression.

The idea is that, given for example these classes:

using System.Collections.Generic;
using System.ComponentModel.DataAnnotations;

public class User
{
    [Key]
    public int Id { get; set; }
    public string FirstName { get; set; }
    public string LastName { get; set; }
    public long? AnualSalary { get; set; }
    public string Phone { get; set; }
    public IEnumerable<Document> Documents { get; set; }

    public Document PrincipalDocument { get; set; }

    public User Parent { get; set; }
}

public class Document
{
    [Key]
    public int Id { get; set; }
    public string Name { get; set; }
    //public User? Owner { get; set; }
}

if I pass "Id,Parent[Id],Documents[Id]", the parser returns an expression tree equivalent to:

user=>new User(){
  Id=user.Id,
  Parent=(user.Parent!=null)?new User(){ Id= user.Parent.Id}:null,
  Documents=(user.Documents!=null)?user.Documents.Select(d=>new Document(){ Id= d.Id}).ToList():null
}
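For comparison, the target projection can be written as a C# expression lambda so that the compiler builds the tree itself. A minimal sketch, assuming trimmed copies of the User and Document classes (no [Key] attributes); HandWrittenSelect is a hypothetical name. The resulting tree contains no Invoke nodes, which makes it a useful reference when comparing against the dynamically built one.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Trimmed copies of the entity classes (hypothetical, for this sketch only).
public class Document
{
    public int Id { get; set; }
    public string Name { get; set; }
}

public class User
{
    public int Id { get; set; }
    public string FirstName { get; set; }
    public IEnumerable<Document> Documents { get; set; }
    public User Parent { get; set; }
}

public static class HandWrittenSelect
{
    // The C# compiler builds this expression tree: MemberInit, Conditional
    // and method-call nodes, but no Invoke nodes.
    public static readonly Expression<Func<User, User>> Projection = user => new User
    {
        Id = user.Id,
        Parent = user.Parent != null ? new User { Id = user.Parent.Id } : null,
        Documents = user.Documents != null
            ? user.Documents.Select(d => new Document { Id = d.Id }).ToList()
            : null
    };
}
```

Compiling Projection and applying it to an in-memory User copies only Id, Parent.Id and Documents[].Id, and leaves Parent and Documents null when the source values are null.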

The parser works fine on in-memory collections, but fails with EF (InMemory Sqlite), throwing the following error:

Error Message:
   System.InvalidOperationException : When called from 'VisitLambda', rewriting a node of type 'System.Linq.Expressions.ParameterExpression' must return a non-null value of the same type. Alternatively, override 'VisitLambda' and change it to not visit children of this type.

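The error seems to be related to my visitor implementation and the parameters passed to the lambda expressions, but I don't know how to fix it. I've tried many suggestions from SO and other sites, with no luck. Any ideas?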
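The message itself comes from ExpressionVisitor: when VisitLambda re-visits a lambda's parameter list, every parameter must come back as a ParameterExpression. A visitor that substitutes a parameter with some other expression (roughly what inlining Invoke(lambda, argument) has to do) therefore fails as soon as it reaches a nested lambda that declares that same parameter. The sketch below reproduces the message outside EF; ReplaceParameterVisitor and ErrorRepro are hypothetical names, and whether EF Core reaches the error through exactly this path is an assumption.

```csharp
using System;
using System.Linq.Expressions;

// Hypothetical visitor that substitutes one parameter with another expression,
// roughly what inlining Invoke(lambda, argument) has to do.
public sealed class ReplaceParameterVisitor : ExpressionVisitor
{
    private readonly ParameterExpression _target;
    private readonly Expression _replacement;

    public ReplaceParameterVisitor(ParameterExpression target, Expression replacement)
    {
        _target = target;
        _replacement = replacement;
    }

    // Every occurrence of _target is replaced, including the one in a lambda's
    // parameter list, which VisitLambda refuses to accept.
    protected override Expression VisitParameter(ParameterExpression node)
        => node == _target ? _replacement : base.VisitParameter(node);
}

public static class ErrorRepro
{
    // Returns the InvalidOperationException message, or null if the rewrite succeeded.
    public static string TryRewrite(Expression tree, ParameterExpression target, Expression replacement)
    {
        try
        {
            new ReplaceParameterVisitor(target, replacement).Visit(tree);
            return null;
        }
        catch (InvalidOperationException ex)
        {
            return ex.Message;
        }
    }
}
```

Rewriting p => p.Length with p replaced by a constant fails with the "When called from 'VisitLambda'..." message, while q => p.Length + q.Length (where p is used but not declared) rewrites cleanly. In the CreateNewObject code below, parameter is declared both by the outer lambda and by the nested true/false lambdas, which plausibly triggers the first case.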

Here is the parser code:

using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Text.RegularExpressions;
using System.Threading.Tasks;

namespace Cypretex.Data.Filters.Parsers.Linq
{
    internal class SelectEntry
    {
        public string Property { get; set; }
        public IList<SelectEntry> Childs { get; set; } = new List<SelectEntry>();

        public SelectEntry AddChildProperty(SelectEntry entry)
        {
            Childs.Add(entry);
            return this;
        }
    }

    public static class LinqSelectParser
    {



        //Expression visitor instance
        private readonly static Visitor visitor = new Visitor();

        /// <summary>
        /// Parse the selection clause
        /// </summary>
        /// <param name="properties">The select properties clause</param>
        /// <param name="source"></param>
        /// <param name="suffix"></param>
        /// <typeparam name="T"></typeparam>
        /// <returns></returns>
        public static IQueryable<T> ParseSelect<T>(string properties, IQueryable<T> source, string suffix = "")
        where T : class, new()
        {
            if (String.IsNullOrEmpty(properties) || properties.Equals("*"))
            {
                return source;
            }
            List<SelectEntry> props = ParsePropertyNames(properties.Replace(" ", String.Empty));
            Expression<Func<T, T>> expression = (Expression<Func<T, T>>)Process<T, T>(props, typeof(T), typeof(T), suffix);
            return source.Select<T, T>(expression);
        }

        /// <summary>
        /// Convert the string of the properties to a SelectEntry collection
        /// </summary>
        /// <param name="properties"></param>
        /// <param name="prefix"></param>
        /// <returns></returns>
        private static List<SelectEntry> ParsePropertyNames(string properties, string prefix = "")
        {
            string pattern = @"((?<complex>[A-Za-z0-9]+)\[(?<props>[[A-Za-z0-9,]+)\]?)+|(?<simple>\w+)";
            List<SelectEntry> ret = new List<SelectEntry>();
            MatchCollection matches = Regex.Matches(properties, pattern);
            if (matches.Any())
            {
                matches.ToList().ForEach(o =>
                {
                    if (!String.IsNullOrEmpty(o.Groups["simple"].Value))
                    {
                        ret.Add(new SelectEntry()
                        {
                            Property = o.Value
                        });
                    }
                    else
                    {
                        SelectEntry entry = new SelectEntry()
                        {
                            Property = o.Groups["complex"].Value,
                            Childs = ParsePropertyNames(o.Groups["props"].Value)
                        };
                        ret.Add(entry);
                    }
                });
            }

            return ret;
        }
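To see what ParsePropertyNames receives from its regex, here is a minimal sketch that runs the same pattern and returns the top-level (property, children) pairs. SelectClauseRegexDemo.Split is a hypothetical helper, not part of the parser.

```csharp
using System.Collections.Generic;
using System.Linq;
using System.Text.RegularExpressions;

public static class SelectClauseRegexDemo
{
    // Same pattern as ParsePropertyNames in the parser above.
    private const string Pattern =
        @"((?<complex>[A-Za-z0-9]+)\[(?<props>[[A-Za-z0-9,]+)\]?)+|(?<simple>\w+)";

    // Returns one (Property, Children) pair per top-level match; Children is the raw
    // text that ParsePropertyNames would parse recursively ("" for simple properties).
    public static List<(string Property, string Children)> Split(string clause)
    {
        return Regex.Matches(clause, Pattern)
            .Cast<Match>()
            .Select(m => !string.IsNullOrEmpty(m.Groups["simple"].Value)
                ? (m.Value, "")
                : (m.Groups["complex"].Value, m.Groups["props"].Value))
            .ToList();
    }
}
```

For "Id,Parent[Id],Documents[Id]" this yields ("Id", ""), ("Parent", "Id") and ("Documents", "Id"). For a nested clause such as "Parent[Id,Parent[Id]]", the inner bracket stays inside the props group ("Id,Parent[Id"), and the optional \]? lets the recursive call still match "Parent[Id"; the outermost closing bracket is simply left unmatched.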

        private static Expression Process<T, TReturn>(List<SelectEntry> props, Type sourceType, Type destType, string suffix = "")
            where T : class, new()
            where TReturn : class, new()
        {

            List<MemberAssignment> bindings = new List<MemberAssignment>();
            ParameterExpression parameter = Expression.Parameter(sourceType, sourceType.Name);
            foreach (SelectEntry entry in props)
            {
                bindings.AddRange(ProcessEntry(entry, parameter));
            }
            NewExpression newData = Expression.New(destType);
            MemberInitExpression initExpression = Expression.MemberInit(newData, bindings);
            Expression finalExpression = MakeLambda(initExpression, parameter);
            //Console.WriteLine(finalExpression);
            return (Expression<Func<T, TReturn>>)finalExpression;

        }


        private static IList<MemberAssignment> ProcessEntry(SelectEntry entry, ParameterExpression parameter, string suffix = "")
        {
            List<MemberAssignment> bindings = new List<MemberAssignment>();
            Type type = parameter.Type;

            //process the sub properties
            if (entry.Childs.Count > 0)
            {


                PropertyInfo propertyInfo = parameter.Type.GetProperty(entry.Property);
                MemberExpression originalMember = Expression.Property(parameter, propertyInfo);

                Type childType = propertyInfo.PropertyType;
                ParameterExpression childParameter = Expression.Parameter(childType, entry.Property);
                List<MemberAssignment> subBindings = new List<MemberAssignment>();



                var isCollection = Utils.IsEnumerable(childParameter);
                //The property is an enumerable (collection)
                if (isCollection)
                {
                    // Get the type of the child elements
                    Type elementType = childType.GetGenericArguments()[0];
                    // Create an expression for the parameter
                    ParameterExpression elementParameter = Expression.Parameter(elementType, entry.Property + ".Element");

                    foreach (SelectEntry e in entry.Childs)
                    {
                        subBindings.AddRange(ProcessEntry(e, elementParameter));
                    }

                    // Convert the list to Queryable
                    Expression asQueryable = Utils.AsQueryable(childParameter);
                    //Expression to generate a new element of the list
                    NewExpression newElementExpression = Expression.New(elementType);
                    MemberInitExpression initElementExpression = Expression.MemberInit(newElementExpression, subBindings);
                    //Iterate over the original elements (Queryable.Select)
                    MethodCallExpression selectExpr = Expression.Call(typeof(Queryable), "Select", new[] { elementType, elementType }, asQueryable, visitor.Visit(MakeLambda(initElementExpression, elementParameter)));
                    //Convert the result to list
                    Expression toListCall = Expression.Call(typeof(Enumerable), "ToList", selectExpr.Type.GetGenericArguments(), visitor.Visit(selectExpr));
                    // Check for null original collection (avoid null pointer)
                    Expression notNullConditionExpression = Expression.NotEqual(childParameter, Expression.Constant(null, childParameter.Type));
                    Expression trueExpression = MakeLambda(Expression.Convert(toListCall, childParameter.Type), childParameter);
                    Expression falseExpression = MakeLambda(Expression.Constant(null, childParameter.Type), childParameter);

                    Expression notNullExpression = Expression.Condition(notNullConditionExpression, trueExpression, falseExpression);
                    Expression notNullLambda = MakeLambda(Expression.Invoke(notNullExpression, originalMember), childParameter);
                    Console.WriteLine(notNullLambda);

                    //Invoke the null-check expression
                    Expression invocation = Expression.Invoke(notNullLambda, originalMember);
                    // Add the invocation to the bindings on the original element
                    bindings.Add(Expression.Bind(propertyInfo, invocation));
                }
                else
                {
                    // Add the child entities to the initialization bindings of the object
                    foreach (SelectEntry e in entry.Childs)
                    {
                        subBindings.AddRange(ProcessEntry(e, childParameter));
                    }
                    // Add the lambda to the bindings of the property in the parent object
                    bindings.Add(Expression.Bind(propertyInfo, CreateNewObject(childParameter, childType, subBindings, originalMember)));
                }

            }
            else
            {
                // Add the property to the init bindings
                bindings.Add(AssignProperty(parameter.Type, entry.Property, parameter));
            }
            return bindings;
        }

        /// <summary>
        /// Create a new object for assignment to the member of the result object
        /// </summary>
        /// <param name="parameter">The child parameter</param>
        /// <param name="objectType">The type of the object</param>
        /// <param name="bindings">The bindings for the initialization</param>
        /// <param name="originalMember">The member on the original (parent) object</param>
        /// <returns></returns>
        private static Expression CreateNewObject(ParameterExpression parameter, Type objectType, List<MemberAssignment> bindings, MemberExpression originalMember)
        {
            // Create new object of type childType
            NewExpression newExpression = Expression.New(objectType);
            // Initialize the members of the object
            MemberInitExpression initExpression = Expression.MemberInit(newExpression, bindings);
            // Check for not null original property (avoid the null pointer)
            Expression notNullConditionExpression = Expression.NotEqual(parameter, Expression.Constant(null, objectType));
            Expression trueExpression = MakeLambda(initExpression, parameter);
            Expression falseExpression = MakeLambda(Expression.Constant(null, objectType), parameter);
            Expression notNullExpression = Expression.Condition(notNullConditionExpression, trueExpression, falseExpression);

            // Create the lambda
            Expression initLambdaExpression = MakeLambda(notNullExpression, parameter);
            Expression initInvocation = Expression.Invoke(initLambdaExpression, originalMember);

            // Invoke the initialization expression and the not null expression
            Expression invocation = Expression.Invoke(initInvocation, originalMember);
            return invocation;
        }
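CreateNewObject wraps the MemberInit in several lambdas over the same child parameter and then invokes them with the parent member. A flatter alternative is a plain conditional with no nested lambdas or Invoke nodes, which matches the shape of the hand-written expression. This sketch assumes the child bindings are built against the member expression itself (for example user.Parent.Id) rather than a fresh child parameter; NullSafeInit.CreateNewObject is a hypothetical helper and User is a trimmed copy of the entity.

```csharp
using System;
using System.Collections.Generic;
using System.Linq.Expressions;

// Trimmed copy of the entity (hypothetical, for this sketch only).
public class User
{
    public int Id { get; set; }
    public string FirstName { get; set; }
    public User Parent { get; set; }
}

public static class NullSafeInit
{
    // Builds: member != null ? new T { <bindings> } : null
    // The bindings are expected to read from `member` itself (e.g. user.Parent.Id),
    // so no extra lambda or Invoke node is needed.
    public static Expression CreateNewObject(MemberExpression member, IEnumerable<MemberBinding> bindings)
    {
        Type type = member.Type;
        Expression init = Expression.MemberInit(Expression.New(type), bindings);
        Expression isNotNull = Expression.NotEqual(member, Expression.Constant(null, type));
        return Expression.Condition(isNotNull, init, Expression.Constant(null, type));
    }
}
```

Bound to User.Parent inside an outer MemberInit, this produces user => new User { Id = user.Id, Parent = user.Parent != null ? new User { Id = user.Parent.Id } : null }.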


        private static MemberAssignment AssignProperty(Type type, string propertyName, Expression parameter)
        {
            PropertyInfo propertyInfo = type.GetProperty(propertyName);
            MemberExpression originalMember = Expression.Property(parameter, propertyInfo);
            return Expression.Bind(propertyInfo, originalMember);
        }


        private static Expression MakeLambda(Expression predicate, params ParameterExpression[] parameters)
        {

            List<ParameterExpression> resultParameters = new List<ParameterExpression>();
            //var resultParameterVisitor = new ParameterVisitor();
            foreach (ParameterExpression parameter in parameters)
            {

                resultParameters.Add(((ParameterExpression)visitor.Visit(parameter)) ?? parameter);
            }
            return visitor.Visit(Expression.Lambda(visitor.Visit(predicate), resultParameters));
        }



        /// <summary>
        /// Returns the enumerable collection of expressions that comprise
        /// the expression tree rooted at the specified node.
        /// </summary>
        /// <param name="node">The node.</param>
        /// <returns>
        /// The enumerable collection of expressions that comprise the expression tree.
        /// </returns>
        public static IEnumerable<Expression> Explore(Expression node)
        {
            return visitor.Explore(node);
        }
        private class Visitor : ExpressionVisitor
        {
            private readonly List<Expression> expressions = new List<Expression>();

            protected override Expression VisitBinary(BinaryExpression node)
            {
                this.expressions.Add(node);
                return base.VisitBinary(node);
            }

            protected override Expression VisitBlock(BlockExpression node)
            {
                this.expressions.Add(node);
                return base.VisitBlock(node);
            }

            protected override Expression VisitConditional(ConditionalExpression node)
            {
                this.expressions.Add(node);
                return base.VisitConditional(node);
            }

            protected override Expression VisitConstant(ConstantExpression node)
            {
                this.expressions.Add(node);
                return base.VisitConstant(node);
            }

            protected override Expression VisitDebugInfo(DebugInfoExpression node)
            {
                this.expressions.Add(node);
                return base.VisitDebugInfo(node);
            }

            protected override Expression VisitDefault(DefaultExpression node)
            {
                this.expressions.Add(node);
                return base.VisitDefault(node);
            }

            protected override Expression VisitDynamic(DynamicExpression node)
            {
                this.expressions.Add(node);
                return base.VisitDynamic(node);
            }

            protected override Expression VisitExtension(Expression node)
            {
                this.expressions.Add(node);
                return base.VisitExtension(node);
            }

            protected override Expression VisitGoto(GotoExpression node)
            {
                this.expressions.Add(node);
                return base.VisitGoto(node);
            }

            protected override Expression VisitIndex(IndexExpression node)
            {
                this.expressions.Add(node);
                return base.VisitIndex(node);
            }

            protected override Expression VisitInvocation(InvocationExpression node)
            {
                this.expressions.Add(node);
                return base.VisitInvocation(node);
            }

            protected override Expression VisitLabel(LabelExpression node)
            {
                this.expressions.Add(node);
                return base.VisitLabel(node);
            }

            protected override Expression VisitLambda<T>(Expression<T> node)
            {
                this.expressions.Add(node);
                return base.VisitLambda(node);
            }

            protected override Expression VisitListInit(ListInitExpression node)
            {
                this.expressions.Add(node);
                return base.VisitListInit(node);
            }

            protected override Expression VisitLoop(LoopExpression node)
            {
                this.expressions.Add(node);
                return base.VisitLoop(node);
            }

            protected override Expression VisitMember(MemberExpression node)
            {
                this.expressions.Add(node);
                return base.VisitMember(node);
            }

            protected override Expression VisitMemberInit(MemberInitExpression node)
            {
                this.expressions.Add(node);
                return base.VisitMemberInit(node);
            }

            protected override Expression VisitMethodCall(MethodCallExpression node)
            {
                this.expressions.Add(node);
                return base.VisitMethodCall(node);
            }

            protected override Expression VisitNew(NewExpression node)
            {
                this.expressions.Add(node);
                return base.VisitNew(node);
            }

            protected override Expression VisitNewArray(NewArrayExpression node)
            {
                this.expressions.Add(node);
                return base.VisitNewArray(node);
            }

            protected override Expression VisitParameter(ParameterExpression node)
            {
                this.expressions.Add(node);
                return base.VisitParameter(node);
            }

            protected override Expression VisitRuntimeVariables(RuntimeVariablesExpression node)
            {
                this.expressions.Add(node);
                return base.VisitRuntimeVariables(node);
            }

            protected override Expression VisitSwitch(SwitchExpression node)
            {
                this.expressions.Add(node);
                return base.VisitSwitch(node);
            }

            protected override Expression VisitTry(TryExpression node)
            {
                this.expressions.Add(node);
                return base.VisitTry(node);
            }

            protected override Expression VisitTypeBinary(TypeBinaryExpression node)
            {
                this.expressions.Add(node);
                return base.VisitTypeBinary(node);
            }

            protected override Expression VisitUnary(UnaryExpression node)
            {
                this.expressions.Add(node);
                return base.VisitUnary(node);
            }

            public IEnumerable<Expression> Explore(Expression node)
            {
                this.expressions.Clear();
                this.Visit(node);
                return expressions.ToArray();
            }
        }

    }
}

public class Utils
{
    // Only the generic overload AsQueryable<TElement>(IEnumerable<TElement>) can be used with MakeGenericMethod
    public static readonly MethodInfo AsQueryableMethod = typeof(Queryable).GetRuntimeMethods().FirstOrDefault(
        method => method.Name == "AsQueryable" && method.IsStatic && method.IsGenericMethodDefinition);

    /// <summary>
    /// Cast IEnumerable to IQueryable.
    /// </summary>
    /// <param name="prop">IEnumerable expression</param>
    /// <returns>IQueryable expression.</returns>
    public static Expression AsQueryable(Expression prop)
    {
        return Expression.Call(
            AsQueryableMethod.MakeGenericMethod(prop.Type.GenericTypeArguments.Single()),
            prop);
    }

    public static bool IsEnumerable(Expression prop)
    {
        return prop.Type.GetTypeInfo().ImplementedInterfaces.FirstOrDefault(x => x.Name == "IEnumerable") != null;
    }
}
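Queryable exposes both a generic AsQueryable<TElement>(IEnumerable<TElement>) and a non-generic AsQueryable(IEnumerable), so a lookup by name alone can return the non-generic overload, which MakeGenericMethod rejects. A sketch of a more explicit lookup, assuming the element type should be taken from the IEnumerable<T> interface so that arrays work as well as List<T>; QueryableHelpers is a hypothetical name.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

public static class QueryableHelpers
{
    // Queryable.AsQueryable has two overloads: AsQueryable<TElement>(IEnumerable<TElement>)
    // and AsQueryable(IEnumerable). Only the generic definition can go through MakeGenericMethod.
    public static readonly MethodInfo AsQueryableMethod = typeof(Queryable)
        .GetMethods(BindingFlags.Public | BindingFlags.Static)
        .First(m => m.Name == nameof(Queryable.AsQueryable) && m.IsGenericMethodDefinition);

    // Wraps an IEnumerable<T>-typed expression in Queryable.AsQueryable<T>(...).
    public static Expression AsQueryable(Expression prop)
    {
        Type elementType = GetEnumerableElementType(prop.Type)
            ?? throw new ArgumentException($"{prop.Type} does not implement IEnumerable<T>", nameof(prop));
        return Expression.Call(AsQueryableMethod.MakeGenericMethod(elementType), prop);
    }

    // Finds T for IEnumerable<T> itself or for any type implementing it (List<T>, T[], ...).
    private static Type GetEnumerableElementType(Type type)
    {
        if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(IEnumerable<>))
            return type.GetGenericArguments()[0];
        return type.GetInterfaces()
            .Where(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IEnumerable<>))
            .Select(i => i.GetGenericArguments()[0])
            .FirstOrDefault();
    }
}
```

Unlike prop.Type.GenericTypeArguments.Single(), the interface lookup also handles non-generic collection types such as Document[].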

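After several hours of debugging and trying to find out what was wrong, I fixed it. As suggested, I wrote the expression by hand and compared the results, and I noticed that the problem was that I was invoking the expression for the same lambda twice: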

// Check for null original collection (avoid null pointer)
Expression notNullConditionExpression = Expression.NotEqual(childParameter, Expression.Constant(null, childParameter.Type));
Expression trueExpression = MakeLambda(Expression.Convert(toListCall, childParameter.Type), childParameter);
Expression falseExpression = MakeLambda(Expression.Constant(null, childParameter.Type), childParameter);

Expression notNullExpression = Expression.Condition(notNullConditionExpression, trueExpression, falseExpression);
Expression notNullLambda = MakeLambda(Expression.Invoke(notNullExpression, originalMember), childParameter);

// Invoke the null-check expression
Expression invocation = Expression.Invoke(notNullLambda, originalMember);

I changed it to:

Expression notNullConditionExpression = Expression.NotEqual(parameter, Expression.Constant(null, objectType));
Expression trueExpression = initExpression;
Expression falseExpression = Expression.Constant(null, objectType);
Expression notNullExpression = Expression.Condition(notNullConditionExpression, trueExpression, falseExpression);

// Create the lambda
Expression initLambdaExpression = MakeLambda(notNullExpression, parameter);

// Invoke the initialization expression and the not null expression
Expression invocation = Expression.Invoke(initLambdaExpression, originalMember);

The first invocation has been removed, and it now works: I get the same result as when writing the expression by hand or as a lambda expression.
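The same idea can be applied to the collection branch of ProcessEntry: instead of lambdas over childParameter plus a double Invoke, the null check, Select and ToList can reference the member expression (user.Documents) directly. A minimal sketch, assuming Enumerable.Select (what the C# compiler emits for user.Documents.Select(...) inside an expression lambda) in place of the original AsQueryable/Queryable.Select pair; CollectionProjection.Build and its bindChildren callback are hypothetical.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Trimmed copies of the entities (hypothetical, for this sketch only).
public class Document
{
    public int Id { get; set; }
    public string Name { get; set; }
}

public class User
{
    public int Id { get; set; }
    public IEnumerable<Document> Documents { get; set; }
}

public static class CollectionProjection
{
    // Builds: member != null ? (TCollection)member.Select(e => new TElement { ... }).ToList() : null
    public static Expression Build(MemberExpression member, Type elementType,
        Func<ParameterExpression, IEnumerable<MemberBinding>> bindChildren)
    {
        // The element lambda declares its own parameter; the parent member is never a lambda parameter.
        ParameterExpression element = Expression.Parameter(elementType, "e");
        Expression init = Expression.MemberInit(Expression.New(elementType), bindChildren(element));
        LambdaExpression selector = Expression.Lambda(init, element);

        Expression select = Expression.Call(typeof(Enumerable), nameof(Enumerable.Select),
            new[] { elementType, elementType }, member, selector);
        Expression toList = Expression.Call(typeof(Enumerable), nameof(Enumerable.ToList),
            new[] { elementType }, select);

        Expression isNotNull = Expression.NotEqual(member, Expression.Constant(null, member.Type));
        return Expression.Condition(isNotNull,
            Expression.Convert(toList, member.Type),
            Expression.Constant(null, member.Type));
    }
}
```

The resulting tree has the same shape as user.Documents != null ? user.Documents.Select(d => new Document { Id = d.Id }).ToList() : null, with no Invoke nodes for a query provider to inline.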