Complex edit of a body Expression<Func<T,bool>>

Summary: I want to know how to detect a specific construct in an expression's body and then rewrite it the way I want. For example, turn

e.Entity.ListA.Union(e.Entity.ListB).Any(...)

into

e.Entity != null && 
((e.Entity.ListA != null && e.Entity.ListA.Any(...)) 
|| (e.Entity.ListB != null && e.Entity.ListB.Any(...)))

I think a solution that uses only LINQ expression techniques would be ideal.

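As part of writing clean C# code, I wrote a set of predefined expressions, and with the LinqKit extensions I can combine them, which makes it easy to build complex expressions out of small ones. I also want to use the same expressions to filter both IQueryable and IEnumerable sources. However, as you know, some expressions work against one of these but not the other. I have managed to avoid many such problems, until I hit a case that I did solve, but with a solution I still don't consider ideal.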
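LinqKit ships its own helpers for combining predicates; the sketch below does not use LinqKit and only illustrates the underlying idea, assuming two predicates over the same `T`. The `PredicateCombiner` class and its `And` method are hypothetical names. The merged tree can be passed to an `IQueryable<T>` as-is, or compiled with `Compile()` for an `IEnumerable<T>`.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical helper (not LinqKit's API): merges two predicates into a single expression tree
public static class PredicateCombiner
{
    public static Expression<Func<T, bool>> And<T>(this Expression<Func<T, bool>> left, Expression<Func<T, bool>> right)
    {
        ParameterExpression parameter = left.Parameters[0];

        // Rewrite right's body so it refers to left's parameter instead of its own
        Expression rightBody = new ReplaceParameter(right.Parameters[0], parameter).Visit(right.Body);

        return Expression.Lambda<Func<T, bool>>(Expression.AndAlso(left.Body, rightBody), parameter);
    }

    private sealed class ReplaceParameter : ExpressionVisitor
    {
        private readonly ParameterExpression _from;
        private readonly ParameterExpression _to;

        public ReplaceParameter(ParameterExpression from, ParameterExpression to)
        {
            _from = from;
            _to = to;
        }

        // Swap only the one parameter being replaced; everything else is visited normally
        protected override Expression VisitParameter(ParameterExpression node)
            => node == _from ? _to : base.VisitParameter(node);
    }
}
```

For example, `positive.And(small)` can be used both as `source.AsQueryable().Where(both)` and as `list.Where(both.Compile())`; the compiled form is exactly where null navigations start to matter.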

I'll start by showing the problem, then explain the solution I'd like, and finally share my attempt.

//---
public class AssignmentsEx : BaseEx
{
    //.........

    /// <summary>
    /// (e.FreeRoles AND e.RoleClass.Roles) ⊆ ass.AllRoles
    /// </summary>
    public static Expression<Func<T, bool>> RolesInclosedBy<T>(IAssignedInstitution assignedInstitution) where T : class, IAssignedInstitution
    {
        var allStaticRoles = AppRolesStaticData.AdminRolesStr.GetAll();
        var assAllRoles = assignedInstitution.AllRoles.Select(s => s.Name).ToList();
        var hasAllRoles = allStaticRoles.All(assR => assAllRoles.Any(sR => sR == assR));

        if (hasAllRoles)
            return e => true;

        // For LINQ to SQL the expression works perfectly: as you know, it is translated to SQL.
        // For the IEnumerable case, accessing the nested Roles collection throws a
        // NullReferenceException when RoleClass is null (and a null RoleClass is a
        // legitimate state during normal execution).
        return e => e.FreeRoles.Union(e.RoleClass.Roles).All(eR => assAllRoles.Any(assR => assR == eR.Name));
    }

    //.........
}

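As you can see, if I use this method to filter an in-memory list of objects and RoleClass or FreeRoles is null, it throws: a NullReferenceException when RoleClass is null, and an ArgumentNullException from Union when FreeRoles is null.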
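A small self-contained repro of the failure, assuming minimal stand-in types: `Role`, `RoleClass` and `Assignment` are hypothetical simplifications of the question's entities, and `NullUnionRepro` is a hypothetical test harness. Compiling the expression and running it in memory (the IEnumerable path) throws as soon as one side of the `Union` is missing.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical stand-ins for the question's entities
public class Role { public string Name { get; set; } }
public class RoleClass { public List<Role> Roles { get; set; } }
public class Assignment
{
    public List<Role> FreeRoles { get; set; }
    public RoleClass RoleClass { get; set; }
}

public static class NullUnionRepro
{
    // Same shape as the expression returned by RolesInclosedBy
    public static Expression<Func<Assignment, bool>> RolesInclosedBy(List<string> assAllRoles)
        => e => e.FreeRoles.Union(e.RoleClass.Roles).All(eR => assAllRoles.Any(assR => assR == eR.Name));

    // Runs the compiled expression in memory and reports the result or the exception type
    public static string Run(Assignment assignment)
    {
        Func<Assignment, bool> predicate = RolesInclosedBy(new List<string> { "Admin" }).Compile();
        try
        {
            return predicate(assignment).ToString();
        }
        catch (Exception ex)
        {
            return ex.GetType().Name;
        }
    }
}
```

With `RoleClass == null` the member access `e.RoleClass.Roles` fails before `Union` is even called; with `FreeRoles == null`, `Enumerable.Union` rejects its first argument.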

-- The best solution, as I see it, would address three factors:

This approach would let me keep the method static and adjust it through an extension method, e.g. ex.WithSplittedUnion()

instead of the traditional way, which is what I currently use, shown below:

public class AssignmentsEx : BaseEx
{
    public LinqExpressionPurpose Purpose { get; }

    public AssignmentsEx(LinqExpressionPurpose purpose) : base(purpose)
    {
        Purpose = purpose;
    }

    public Expression<Func<T, bool>> RolesInclosedBy<T>(IAssignedInstitution assignedInstitution) where T : class, IAssignedInstitution
    {
        var allStaticRoles = AppRolesStaticData.AdminRolesStr.GetAll();
        var assAllRoles = assignedInstitution.AllRoles.Select(s => s.Name).ToList();
        var hasAllRoles = allStaticRoles.All(assR => assAllRoles.Any(sR => sR == assR));

        if (hasAllRoles)
            return e => true;

        // In-memory (IEnumerable) version: every collection that may be null is checked explicitly
        Expression<Func<T, bool>> whenToObject = e => (e.FreeRoles == null || e.FreeRoles.All(eR => assAllRoles.Any(assR => assR == eR.Name)))
            && (e.RoleClass == null || e.RoleClass.Roles == null || e.RoleClass.Roles.All(eR => assAllRoles.Any(assR => assR == eR.Name)));

        // Database (IQueryable) version: translated to SQL, where a null navigation does not throw
        Expression<Func<T, bool>> whenToEntity = e => e.FreeRoles.Union(e.RoleClass.Roles).All(eR => assAllRoles.Any(assR => assR == eR.Name));

        return Purpose switch
        {
            LinqExpressionPurpose.ToEntity => whenToEntity,
            LinqExpressionPurpose.ToObject => whenToObject,
            _ => null,
        };
    }
}

I hope this is clear. Thanks in advance.

In my opinion, what you need is an ExpressionVisitor to traverse and modify the expression tree. One thing I would change is the way you call Any. Instead of

e.Entity != null && 
((e.Entity.ListA != null && e.Entity.ListA.Any(...)) 
|| (e.Entity.ListB != null && e.Entity.ListB.Any(...)))

I would go with

(
    e.Entity != null && e.Entity.ListA != null && e.Entity.ListB != null
        ? e.Entity.ListA.Union(e.Entity.ListB)
        : e.Entity != null && e.Entity.ListA != null
            ? e.Entity.ListA
            : e.Entity != null && e.Entity.ListB != null
                ? e.Entity.ListB
                : new Entity[0]
).Any(...)
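For comparison, the nested conditional has the same meaning as the ordinary delegate below, which uses `?.` and `??`. The `?.` operator is not allowed inside an `Expression<Func<...>>` lambda (the compiler rejects it with error CS8072), which is one reason to spell the conditional out or build it with the `Expression` API. `Holder`, `Entity`, `Item` and `NullSafeUnion` are hypothetical names standing in for `e`, `e.Entity` and the list elements.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical types standing in for e, e.Entity and the list elements
public class Item { public int Value { get; set; } }
public class Entity
{
    public List<Item> ListA { get; set; }
    public List<Item> ListB { get; set; }
}
public class Holder { public Entity Entity { get; set; } }

public static class NullSafeUnion
{
    // Same result as the nested conditional: a missing Entity or list behaves like an empty list.
    // This body could not be written as an Expression<Func<...>> lambda, because ?. is rejected there.
    public static bool AnyInEither(Holder e, Func<Item, bool> predicate)
        => (e.Entity?.ListA ?? Enumerable.Empty<Item>())
            .Union(e.Entity?.ListB ?? Enumerable.Empty<Item>())
            .Any(predicate);
}
```

Replacing a null side with an empty sequence does not change the result of `Any`, because `Union` only removes duplicates.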

I find this form easier to build as an expression tree, and the result is the same.

示例代码:

using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

public class OptionalCallFix : ExpressionVisitor
{
    private readonly List<Expression> _conditionalExpressions = new List<Expression>();
    private readonly Type _contextType;
    private readonly Type _entityType;

    private OptionalCallFix(Type contextType, Type entityType)
    {
        this._contextType = contextType;
        this._entityType = entityType;
    }

    protected override Expression VisitMethodCall(MethodCallExpression node)
    {
        // Replace a Queryable.Union(left, right) call with:
        //     left == null && right == null ? new Entity[0].AsQueryable()
        //         : (left == null ? right.AsQueryable() : (right == null ? left : Queryable.Union(left, right)))
        // Only the two-argument overload is handled; the overload taking an IEqualityComparer is left as-is
        if (node.Method.DeclaringType == typeof(Queryable) && node.Method.Name == nameof(Queryable.Union) && node.Arguments.Count == 2)
        {
            Expression left = this.Visit(node.Arguments[0]);
            Expression right = this.Visit(node.Arguments[1]);

            // left == null
            Expression leftIsNull = Expression.Equal(left, Expression.Constant(null, left.Type));

            // right == null
            Expression rightIsNull = Expression.Equal(right, Expression.Constant(null, right.Type));

            // new Entity[0].AsQueryable()
            Expression emptyArray = Expression.Call
            (
                typeof(Queryable),
                nameof(Queryable.AsQueryable),
                new [] { this._entityType },
                Expression.NewArrayInit(this._entityType, new Expression[0])
            );

            // right can be any IEnumerable<Entity> (e.g. a List<Entity>), so wrap it in AsQueryable()
            // so that every branch of the conditional can be typed as IQueryable<Entity>
            Expression rightAsQueryable = Expression.Call
            (
                typeof(Queryable),
                nameof(Queryable.AsQueryable),
                new [] { this._entityType },
                right
            );

            // Each Condition is typed explicitly as node.Type (IQueryable<Entity>),
            // because the branches do not necessarily share the same static type
            return Expression.Condition
            (
                Expression.AndAlso(leftIsNull, rightIsNull),
                emptyArray,
                Expression.Condition
                (
                    leftIsNull,
                    rightAsQueryable,
                    Expression.Condition
                    (
                        rightIsNull,
                        left,
                        // Reuse the original Union method with the null-guarded arguments
                        node.Update(null, new[] { left, right }),
                        node.Type
                    ),
                    node.Type
                ),
                node.Type
            );
        }

        return base.VisitMethodCall(node);
    }

    protected override Expression VisitMember(MemberExpression node)
    {
        Expression expression = this.Visit(node.Expression);

        // Only a member whose type can hold null can be replaced with a null constant
        bool memberCanBeNull = !node.Type.IsValueType || Nullable.GetUnderlyingType(node.Type) != null;

        // Check if expression should be fixed
        if (expression != null && memberCanBeNull && this._conditionalExpressions.Contains(expression))
        {
            // replace e.XXX with e == null ? null : e.XXX
            ConditionalExpression condition = Expression.Condition
            (
                Expression.Equal(expression, Expression.Constant(null, expression.Type)),
                Expression.Constant(null, node.Type),
                Expression.MakeMemberAccess(expression, node.Member)
            );

            // Add fixed expression to the _conditionalExpressions list
            this._conditionalExpressions.Add(condition);

            return condition;
        }

        // Reuse the already visited child; base.VisitMember(node) would visit node.Expression a second time
        return node.Update(expression);
    }

    protected override Expression VisitParameter(ParameterExpression node)
    {
        if (node.Type == this._contextType)
        {
            // Add ParameterExpression to the _conditionalExpressions list
            // It is used in VisitMember method to check if expression should be fixed this way
            this._conditionalExpressions.Add(node);
        }

        return base.VisitParameter(node);
    }

    public static IQueryable<TEntity> Fix<TContext, TEntity>(TContext context, in Expression<Func<TContext, IQueryable<TEntity>>> method)
    {
        return ((Expression<Func<TContext, IQueryable<TEntity>>>)new OptionalCallFix(typeof(TContext), typeof(TEntity)).Visit(method)).Compile().Invoke(context);
    }
}

You can call it like this:

OptionalCallFix.Fix(context, ctx => ctx.Entity.ListA.Union(ctx.ListB));
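The visitor above targets `Queryable.Union`. In the question's in-memory case the navigation properties are plain lists, so `Union` binds to `Enumerable.Union`, and a similar visitor could be wrapped in the extension method the question asks for. The following is an assumption-based sketch, not the answerer's code: `WithSplittedUnion` is the hypothetical name from the question, `Role`/`RoleClass`/`Assignment` are hypothetical stand-ins for the question's entities, and instead of the nested conditional it substitutes `Enumerable.Empty<T>()` for any null side of the `Union`, guarding each member access along the way.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

// Hypothetical stand-ins for the question's entities
public class Role { public string Name { get; set; } }
public class RoleClass { public List<Role> Roles { get; set; } }
public class Assignment
{
    public List<Role> FreeRoles { get; set; }
    public RoleClass RoleClass { get; set; }
}

public static class ExpressionFixExtensions
{
    // Hypothetical extension: rewrites every Enumerable.Union(a, b) in the body into
    // Enumerable.Union(nullSafe(a) ?? Enumerable.Empty<T>(), nullSafe(b) ?? Enumerable.Empty<T>())
    public static Expression<Func<T, bool>> WithSplittedUnion<T>(this Expression<Func<T, bool>> expression)
        => (Expression<Func<T, bool>>)new NullSafeUnionVisitor().Visit(expression);

    private sealed class NullSafeUnionVisitor : ExpressionVisitor
    {
        private static readonly MethodInfo EmptyMethod = typeof(Enumerable).GetMethod(nameof(Enumerable.Empty));

        protected override Expression VisitMethodCall(MethodCallExpression node)
        {
            if (node.Method.DeclaringType != typeof(Enumerable)
                || node.Method.Name != nameof(Enumerable.Union)
                || node.Arguments.Count != 2)
            {
                return base.VisitMethodCall(node);
            }

            Type elementType = node.Method.GetGenericArguments()[0];
            Type sequenceType = typeof(IEnumerable<>).MakeGenericType(elementType);
            Expression empty = Expression.Call(EmptyMethod.MakeGenericMethod(elementType));

            // Visit each argument first (it may contain a nested Union), then guard it and fall back to Empty<T>()
            List<Expression> safeArguments = node.Arguments
                .Select(argument => (Expression)Expression.Coalesce(
                    Expression.Convert(NullSafe(Visit(argument)), sequenceType),
                    empty))
                .ToList();

            return node.Update(null, safeArguments);
        }

        // Guards every reference-typed member access in a chain such as a.B.C
        // with an "owner == null ? null : owner.Member" check
        private static Expression NullSafe(Expression expression)
        {
            if (expression is UnaryExpression unary && unary.NodeType == ExpressionType.Convert
                && !unary.Type.IsValueType && !unary.Operand.Type.IsValueType)
            {
                return unary.Update(NullSafe(unary.Operand));
            }

            if (expression is MemberExpression member && member.Expression != null
                && !member.Expression.Type.IsValueType && !member.Type.IsValueType)
            {
                Expression owner = NullSafe(member.Expression);
                return Expression.Condition(
                    Expression.Equal(owner, Expression.Constant(null, owner.Type)),
                    Expression.Constant(null, member.Type),
                    Expression.MakeMemberAccess(owner, member.Member));
            }

            return expression;
        }
    }
}
```

Because a null side becomes an empty sequence, `All` treats a missing collection as vacuously satisfied, which matches the `whenToObject` expression in the question. The sketch is meant for the compiled, in-memory path only; the unmodified expression would still be the one to hand to a database provider.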