Exhaustive search / generate each combination of an expression tree

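I'm using a basic expression tree optimizer to build query plans. While parsing the tree, I can decide how "best" to construct it, depending on the weights I can assign to each operation.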

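If I have a simple tree with 2 choices for how to perform an action, I'd like to generate both variants of the tree, so that I can then compare the weights of each variant and see which is the most efficient.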
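Once the variants exist, the comparison step itself is small. A minimal sketch, assuming each candidate plan can be scored by some cost function; the `Cheapest` helper and the weights below are hypothetical, made up purely for illustration:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical helper: return the candidate with the smallest weight.
// `weight` stands in for whatever cost model the optimizer uses.
T Cheapest<T>(IEnumerable<T> candidates, Func<T, double> weight) =>
    candidates.Aggregate((best, next) => weight(next) < weight(best) ? next : best);

// Illustrative only: invented weights for the two join strategies in the question.
var weights = new Dictionary<string, double>
{
    ["MergeJoin"] = 12.0,
    ["NestedLoopJoin"] = 30.0,
};

string cheapest = Cheapest(weights.Keys, name => weights[name]);
Console.WriteLine(cheapest); // MergeJoin
```

In the real optimizer the candidates would be the generated expression trees and `weight` would walk each tree summing per-operation costs.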

For example, the code below would let me build two variants of an expression tree's join operation: one with a MergeJoinExpression and one with a NestedLoopJoinExpression.

class Customer
{
    public int Id { get; set; }
}
class Orders
{
    public int Id { get; set; }
    public int CustomerId { get; set; }
}

class MergeJoinExpresion : JoinExpression
{
}

class NestLoopJoinExpresion : JoinExpression
{
}

class Visitor : ExpressionVisitor
{
    public List<Expression> GetPlans(Expression expr)
    {
        // ???
    }

    protected override Expression VisitJoin(JoinExpression join)
    {
        // For this join, I can return the following (trite example)
        // return MergeJoinExpresion
        // return NestLoopJoinExpresion

        return base.VisitJoin(join);
    }
}

How can I write a method that generates every variant of the tree and returns them all to me?

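class Program
{
    static void Main(string[] args)
    {
        // customers and orders are assumed to be IQueryable<T> sources
        var query = from c in customers
                    join o in orders on c.Id equals o.CustomerId
                    select new
                    {
                        CustomerId = c.Id,
                        OrderId = o.Id
                    };

        // GetPlans takes an Expression, so pass the query's expression tree
        var plans = new Visitor().GetPlans(query.Expression);
    }
}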

Can anyone tell me how to modify the GetPlans method of the Visitor class so that it generates these variants?

Edit - something like:

class Visitor : ExpressionVisitor
{
    private List<Expression> exprs = new List<Expression>();

    public List<Expression> GetPlans(Expression expr)
    {
        Visit(expr);    
        return exprs;
    }

    protected override Expression VisitJoin(JoinExpression join)
    {
        // For this join, I can return the following (trite example)
        // return MergeJoinExpresion
        // return NestLoopJoinExpresion
        var choices = new JoinExpression[] { MergeJoinExpresion.Create(join), NestLoopJoinExpresion.Create(join) };

        foreach(var choice in choices)
        {
             var cloned = Cloner.Clone(choice);
             var newTree = base.VisitJoin(cloned);
             exprs.Add(newTree);
        }

        return base.VisitJoin(join);
    }
}

You definitely need immutable trees.

Create a class:

class JoinOptionsExpression: JoinExpression {
    public IEnumerable<JoinExpression> Options {get; private set;}
    private JoinOptionsExpression(){}
    public static JoinOptionsExpression Create(IEnumerable<JoinExpression> options){
        return new JoinOptionsExpression{Options = options.ToList().AsReadOnly()}; // you can improve this probably
    }
}

Then have your VisitJoin method return an options node holding all of the choices, and record every possible selection:

private List<Dictionary<JoinOptionsExpression, int>> selections =
    new List<Dictionary<JoinOptionsExpression, int>> { new Dictionary<JoinOptionsExpression, int>() };

protected override Expression VisitJoin(JoinExpression join)
{
    var choices = new JoinExpression[] { MergeJoinExpresion.Create(join), NestLoopJoinExpresion.Create(join) };
    var exprs = new List<JoinExpression>();
    foreach (var choice in choices)
    {
        var cloned = Cloner.Clone(choice);
        // assumes visiting a join node always yields a join node
        exprs.Add((JoinExpression)base.VisitJoin(cloned));
    }
    var result = JoinOptionsExpression.Create(exprs);
    // now add all choices: every existing selection is forked once per extra option
    if (exprs.Count > 0)
        foreach (var selection in selections.ToList()) // snapshot so the list isn't modified during enumeration; you can improve this too
        {
            selection.Add(result, 0);
            for (int i = 1; i < exprs.Count; i++)
            {
                var copy = new Dictionary<JoinOptionsExpression, int>(selection);
                copy[result] = i;
                selections.Add(copy);
            }
        }
    return result;
}
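The forking of `selections` is the subtle part, so here is that bookkeeping replayed on its own. This is a sketch under one assumption: plain string names stand in for `JoinOptionsExpression` nodes, so it runs without the join classes. Each join with m options multiplies the number of selections by m.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Start with one empty selection, as the `selections` field above does.
var selections = new List<Dictionary<string, int>> { new Dictionary<string, int>() };

// Stand-in for one VisitJoin call: `node` names a (hypothetical) options node,
// `optionCount` is how many alternatives it holds.
void AddOptions(string node, int optionCount)
{
    if (optionCount == 0) return;
    foreach (var selection in selections.ToList()) // snapshot before mutating the list
    {
        selection.Add(node, 0);                    // existing selection takes option 0
        for (int i = 1; i < optionCount; i++)
        {
            // fork: same earlier choices, but option i for this node
            var copy = new Dictionary<string, int>(selection) { [node] = i };
            selections.Add(copy);
        }
    }
}

AddOptions("join1", 2);
AddOptions("join2", 2);

// Prints one line per complete selection (2 * 2 = 4 lines).
foreach (var s in selections)
    Console.WriteLine(string.Join(", ", s.Select(kv => $"{kv.Key}={kv.Value}")));
```

Because copies inherit every key added before them, each selection ends up with an entry for every options node, which is what the extractor relies on when it indexes `currentSelections[opts]`.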

Then you'll need a second visitor, derived from the framework visitor, whose only job is to extract your options:

class OptionsExtractor:ExpressionVisitor
{
    public IEnumerable<Expression> Extract(Expression expression, List<Dictionary<JoinOptionsExpression,int>> selections)
    {
        foreach(var selection in selections)
        {
            currentSelections = selection;
            yield return Visit(expression);
        }
    }
    private Dictionary<JoinOptionsExpression,int> currentSelections;
    public override Expression Visit(Expression node)
    {
        var opts = node as JoinOptionsExpression;
        // substitute the option picked by the current selection, then keep visiting it
        if (opts != null)
            return base.Visit(opts.Options.ElementAt(currentSelections[opts]));
        else
            return base.Visit(node);
    }
}

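Either way, an exhaustive search will blow up in your face very quickly; I assume you know that. Disclaimer: I typed this straight into the editor, so it may not even compile, but you should get the idea.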
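To put a number on that blow-up: if the tree contains n join nodes and join j can be implemented in k_j ways, chosen independently, the number of complete plans is the product of the per-join choices. With the two strategies from the question that is 2^n:

```latex
N_{\text{plans}} = \prod_{j=1}^{n} k_j = k^{n} \quad \text{when every } k_j = k,
\qquad 2^{10} = 1024, \qquad 2^{20} = 1\,048\,576
```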

So first we'll create a visitor that helps us extract a list of JoinExpression objects from an Expression:
internal class FindJoinsVisitor : ExpressionVisitor
{
    private List<JoinExpression> expressions = new List<JoinExpression>();
    protected override Expression VisitJoin(JoinExpression join)
    {
        expressions.Add(join);
        return base.VisitJoin(join);
    }
    public IEnumerable<JoinExpression> JoinExpressions
    {
        get
        {
            return expressions;
        }
    }
}
// Extension method: must be declared inside a non-nested static class.
public static IEnumerable<JoinExpression> FindJoins(
    this Expression expression)
{
    var visitor = new FindJoinsVisitor();
    visitor.Visit(expression);
    return visitor.JoinExpressions;
}

Next we'll use the method below, taken from this blog post, to compute the Cartesian product of a sequence of sequences:

static IEnumerable<IEnumerable<T>> CartesianProduct<T>(
    this IEnumerable<IEnumerable<T>> sequences) 
{ 
    IEnumerable<IEnumerable<T>> emptyProduct = new[] { Enumerable.Empty<T>() }; 
    return sequences.Aggregate( 
        emptyProduct, 
        (accumulator, sequence) => 
            from accseq in accumulator 
            from item in sequence 
            select accseq.Concat(new[] {item})); 
}
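To see what `CartesianProduct` yields for this problem, here is a self-contained run with two joins and two strategies each. Plain strings replace `JoinExpression` nodes (an assumption made only so it runs without a query provider), and the method is declared as a local function instead of an extension method:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Same body as the CartesianProduct extension above, as a local function.
IEnumerable<IEnumerable<T>> CartesianProduct<T>(IEnumerable<IEnumerable<T>> sequences)
{
    IEnumerable<IEnumerable<T>> emptyProduct = new[] { Enumerable.Empty<T>() };
    return sequences.Aggregate(
        emptyProduct,
        (accumulator, sequence) =>
            from accseq in accumulator
            from item in sequence
            select accseq.Concat(new[] { item }));
}

// One inner array per join, listing that join's interchangeable strategies.
var perJoin = new[]
{
    new[] { "Merge", "NestLoop" },   // choices for join #1
    new[] { "Merge", "NestLoop" },   // choices for join #2
};

var plans = CartesianProduct(perJoin).Select(p => string.Join(" + ", p)).ToList();
foreach (var plan in plans)
    Console.WriteLine(plan);
// Merge + Merge
// Merge + NestLoop
// NestLoop + Merge
// NestLoop + NestLoop
```

Each output row picks exactly one item from every inner sequence, which is why the pairs built in `AllJoinCombinations` can be fed straight into it.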

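Next we'll create a visitor that takes a sequence of expression pairs and replaces every instance of the first expression in each pair with the second: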

internal class ReplaceVisitor : ExpressionVisitor
{
    private readonly Dictionary<Expression, Expression> lookup;
    public ReplaceVisitor(Dictionary<Expression, Expression> pairsToReplace)
    {
        lookup = pairsToReplace;
    }
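    public override Expression Visit(Expression node)
    {
        // ExpressionVisitor can visit null children (e.g. the instance of a static
        // method call), and Dictionary throws on a null key, so guard first.
        if (node != null && lookup.TryGetValue(node, out var replacement))
            return base.Visit(replacement);
        return base.Visit(node);
    }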
}

public static Expression ReplaceAll(this Expression expression,
    Dictionary<Expression, Expression> pairsToReplace)
{
    return new ReplaceVisitor(pairsToReplace).Visit(expression);
}

public static Expression ReplaceAll(this Expression expression,
    IEnumerable<Tuple<Expression, Expression>> pairsToReplace)
{
    var lookup = pairsToReplace.ToDictionary(pair => pair.Item1, pair => pair.Item2);
    return new ReplaceVisitor(lookup).Visit(expression);
}
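One way to sanity-check the replacement logic is to run the same visitor on an ordinary `System.Linq.Expressions` tree instead of a join tree. This sketch (the tree and values are illustrative, not from the question) swaps one `Add` node, matched by reference, for a `Subtract` node, and shows that the original tree is left untouched:

```csharp
using System;
using System.Collections.Generic;
using System.Linq.Expressions;

// Build (x + 1) * 10 by hand so we hold a reference to the Add node.
var x = Expression.Parameter(typeof(int), "x");
var add = Expression.Add(x, Expression.Constant(1));
var body = Expression.Multiply(add, Expression.Constant(10));

// Replace that exact node with x - 1.
var replaced = new ReplaceVisitor(new Dictionary<Expression, Expression>
{
    [add] = Expression.Subtract(x, Expression.Constant(1)),
}).Visit(body);

var f = Expression.Lambda<Func<int, int>>(replaced, x).Compile();
Console.WriteLine(replaced);
Console.WriteLine(f(5)); // 40

// Expression trees are immutable: Visit built a new Multiply node,
// while `body` still points at the original Add node.
Console.WriteLine(ReferenceEquals(body.Left, add)); // True

// ReplaceVisitor as above, including the null guard.
class ReplaceVisitor : ExpressionVisitor
{
    private readonly Dictionary<Expression, Expression> lookup;
    public ReplaceVisitor(Dictionary<Expression, Expression> pairsToReplace) => lookup = pairsToReplace;

    public override Expression Visit(Expression node)
    {
        if (node != null && lookup.TryGetValue(node, out var replacement))
            return base.Visit(replacement);
        return base.Visit(node);
    }
}
```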

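Finally, we put everything together. We find all the join expressions in our expression and project each one into a sequence of pairs, where the JoinExpression is the first item of each pair and the second is one of its possible replacement values. From there we can take the Cartesian product to get every combination of pairwise replacements. Lastly, we project each combination of replacements into an expression in which all of those pairs have actually been replaced in the original expression: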

public static IEnumerable<Expression> AllJoinCombinations(Expression expression)
{
    var combinations = expression.FindJoins()
        .Select(join => new Tuple<Expression, Expression>[]
        {
            Tuple.Create<Expression, Expression>(join, new NestLoopJoinExpresion(join)), 
            Tuple.Create<Expression, Expression>(join, new MergeJoinExpresion(join)),
        })
        .CartesianProduct();

    return combinations.Select(combination => expression.ReplaceAll(combination));
}
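For an end-to-end check that needs no query provider, here is a runnable toy version of the same pipeline. It assumes a stand-in problem: instead of choosing a join strategy, each `Add` node may be kept as `a + b` or swapped to `b + a`, two equivalent "physical" forms. `FindAddsVisitor` and `AllAddCombinations` are hypothetical names mirroring `FindJoinsVisitor` and `AllJoinCombinations`:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

var x = Expression.Parameter(typeof(int), "x");
Expression body = Expression.Multiply(
    Expression.Add(x, Expression.Constant(1)),
    Expression.Add(x, Expression.Constant(2)));        // (x + 1) * (x + 2)

var variants = AllAddCombinations(body).ToList();      // 2 Add nodes -> 2^2 = 4 trees
var results = variants
    .Select(v => Expression.Lambda<Func<int, int>>(v, x).Compile()(3))
    .ToList();

foreach (var v in variants)
    Console.WriteLine(v);
Console.WriteLine(string.Join(", ", results));         // 20, 20, 20, 20

// Analogue of AllJoinCombinations: find nodes, pair each with its options,
// take the Cartesian product, then apply each combination of replacements.
IEnumerable<Expression> AllAddCombinations(Expression expression)
{
    var finder = new FindAddsVisitor();
    finder.Visit(expression);
    return CartesianProduct(finder.Found.Select(add => new[]
        {
            Tuple.Create<Expression, Expression>(add, add),                                 // keep a + b
            Tuple.Create<Expression, Expression>(add, Expression.Add(add.Right, add.Left)), // use b + a
        }))
        .Select(combo => new ReplaceVisitor(combo.ToDictionary(p => p.Item1, p => p.Item2))
            .Visit(expression));
}

IEnumerable<IEnumerable<T>> CartesianProduct<T>(IEnumerable<IEnumerable<T>> sequences)
{
    IEnumerable<IEnumerable<T>> empty = new[] { Enumerable.Empty<T>() };
    return sequences.Aggregate(empty, (acc, seq) =>
        from a in acc from item in seq select a.Concat(new[] { item }));
}

// Collects every Add node (analogue of FindJoinsVisitor).
class FindAddsVisitor : ExpressionVisitor
{
    public List<BinaryExpression> Found { get; } = new List<BinaryExpression>();
    protected override Expression VisitBinary(BinaryExpression node)
    {
        if (node.NodeType == ExpressionType.Add) Found.Add(node);
        return base.VisitBinary(node); // keep descending so nested matches are found
    }
}

// Replaces nodes by reference (ReplaceVisitor with a null guard).
class ReplaceVisitor : ExpressionVisitor
{
    private readonly Dictionary<Expression, Expression> lookup;
    public ReplaceVisitor(Dictionary<Expression, Expression> lookup) => this.lookup = lookup;
    public override Expression Visit(Expression node) =>
        node != null && lookup.TryGetValue(node, out var r) ? base.Visit(r) : base.Visit(node);
}
```

All four trees compute the same value, which is exactly the situation where a weight function, rather than the result, has to decide between them.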