为 LINQ to Entities 单元测试创建一个不区分大小写的 IQueryable - NSubstitute

Create a case-insensitive IQueryable for unit testing LINQ to Entities - NSubstitute

我正在对 LINQ to Entities 查询进行单元测试,希望无论搜索词的大小写如何,都能返回同一个词。

目前我想测试分别用小写和大写形式搜索同一个词的情况,例如 "Hi" 和 "hi"。

使用 Entity Framework 的 LINQ to Entities 已经支持这一点(这通常取决于数据库的排序规则):我在 where 子句中无论用哪种大小写搜索,都能得到结果。

问题:我想构造一个行为与之相同的模拟查询:

 /// <summary>
 /// Test entity whose <see cref="Text"/> value is searched by the test below.
 /// </summary>
 public class SimpleWord
    {
        // A property rather than a public field: Entity Framework maps properties (not fields)
        // by convention, so a field would not be a queryable column in LINQ to Entities.
        public string Text { get; set; }
    }

    [Test]
    public void someTest()
    {
        //arrange
        var lowerWords = new[] { "hi" };
        var upperWords = new[] { "Hi" };

        var wordsList = new List<SimpleWord> { new SimpleWord { Text = "hi" } };
        IDbSet<SimpleWord> wordsDbSet = Substitute.For<DbSet<SimpleWord>, IDbSet<SimpleWord>>();

        // Route the substitute's IQueryable members to an in-memory queryable over the test data.
        var dataAsList = wordsList.ToList();
        var queryable = dataAsList.AsQueryable();
        wordsDbSet.Provider.Returns(queryable.Provider);
        wordsDbSet.Expression.Returns(queryable.Expression);
        wordsDbSet.ElementType.Returns(queryable.ElementType);
        // Produce a fresh enumerator on every call. Returning a single pre-built instance would
        // leave it exhausted after the first enumeration, so a second enumeration would see no items.
        wordsDbSet.GetEnumerator().Returns(_ => queryable.GetEnumerator());

        //act
        // Both queries are executed by queryable.Provider; Enumerable.Contains without a comparer
        // uses default (case-sensitive) string equality, which is why the upper-case search misses.
        var resultLower = wordsDbSet.Where(wrd => lowerWords.Contains(wrd.Text)).ToList();
        var resultHigher = wordsDbSet.Where(wrd => upperWords.Contains(wrd.Text)).ToList();

        //assert
        Assert.That(resultHigher.Count, Is.EqualTo(1), "did not find upper case");
        Assert.That(resultLower.Count, Is.EqualTo(1), "did not find lower case");
    }

问题:如何让 wordsDbSet 在执行任何 .Where() 搜索时都不区分大小写?

我不想像下面这样修改 Act(执行)部分:

 // The per-call-site alternative the asker wants to avoid: supply the comparer inside every query.
 var resultHigher = wordsDbSet
     .Where(wrd => upperWords.Contains(wrd.Text, StringComparer.OrdinalIgnoreCase))
     .ToList();

我希望的答案是只修改 Arrange(准备)部分,类似于:

// Desired shape of the arrange step only (hypothetical: `contains` is not defined anywhere; illustrative, not compilable).
wordsDbSet.When(contains.IsCalled).Return(contains.OrdinalIgnoreCasing)

感谢观看!

好吧……这是可行的,但代码很长(并不复杂……只是长)。主要问题在于实现 IQueryable<> 和 IQueryProvider 很麻烦,而且几乎没有资料解释它们是如何工作的(网上能找到一些可以直接复制的代码,但很少有人解释为什么要这样写以及它是如何工作的)。

我写了一个 IQueryable<> 包装器,它“包装”一个 IQueryable<> 对象(例如 AsQueryable() 返回的对象),并在传给它的所有表达式树中“动态地”把若干 string 方法(以及 Enumerable.Contains<string>)替换为接受 StringComparison/StringComparer 参数的对应重载。用法如下:

// Usage demo. Each Where below exercises one of the string operations the wrapper rewrites.
// They are chained only to list them all; combined (e.g. == "foo" together with != "foo")
// the resulting query yields no elements.
var arr = new[] { "foo" };
var query = new[] { "Foo", "Bar", "bar" }
    .AsQueryable()
    .AsStringComparison(StringComparison.CurrentCultureIgnoreCase);

// query is an IQueryable<string>; every expression composed on it is rewritten to
// use StringComparison.CurrentCultureIgnoreCase.

var res = query
    .Where(x => string.Compare(x, "foo") < 0)
    .Where(x => x.CompareTo("foo") < 0)
    .Where(x => string.Compare(x, 0, "foo", 0, 3) < 0)
    .Where(x => x.Contains("foo"))
    .Where(x => string.Equals(x, "foo"))
    .Where(x => x.Equals("foo"))
    .Where(x => arr.Contains(x))
    .Where(x => x == "foo")
    .Where(x => x != "foo");

(上面列出了我要替换的全部方法;把这些 Where 串在一起只是为了展示,合起来并不是一个有意义的查询——例如 == "foo" 与 != "foo" 同时存在时结果必然为空)

实现如下:

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

/// <summary>
/// Extension entry point for wrapping a query with <see cref="StringComparisonQueryableWrapper{T}"/>.
/// </summary>
public static class StringComparisonQueryableWrapper
{
    /// <summary>
    /// Wraps <paramref name="query"/> so that the string operations recognised by
    /// <see cref="StringComparisonQueryableWrapper{T}"/> in expressions composed on the result
    /// are rewritten to use <paramref name="comparisonType"/>.
    /// </summary>
    /// <typeparam name="T">Element type of the query.</typeparam>
    /// <param name="query">The query to wrap.</param>
    /// <param name="comparisonType">The comparison to apply to rewritten string operations.</param>
    /// <returns>A wrapping <see cref="IQueryable{T}"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="query"/> is <c>null</c>.</exception>
    public static IQueryable<T> AsStringComparison<T>(this IQueryable<T> query, StringComparison comparisonType)
    {
        // Fail here rather than building a wrapper whose every member would later dereference null.
        if (query == null)
        {
            throw new ArgumentNullException(nameof(query));
        }

        return new StringComparisonQueryableWrapper<T>(query, comparisonType);
    }
}

/// <summary>
/// An <see cref="IQueryable{T}"/> proxy that rewrites every expression tree it receives so that a fixed
/// set of string operations use <see cref="ComparisonType"/>. The rewritten tree is then forwarded to the
/// wrapped query's own provider.
/// </summary>
/// <remarks>
/// Rewritten operations: <c>string.Compare(a, b)</c>, <c>string.Compare(a, i, b, j, len)</c>,
/// <c>a.CompareTo(b)</c>, <c>a.Contains(b)</c>, <c>string.Equals(a, b)</c>, <c>a.Equals(b)</c>,
/// <c>Enumerable.Contains&lt;string&gt;(source, value)</c>, and the <c>==</c> / <c>!=</c> operators on strings.
/// The replacement overloads are not themselves rewrite targets, so rewriting an already rewritten
/// tree changes nothing.
/// </remarks>
/// <typeparam name="T">Element type of the wrapped query.</typeparam>
public class StringComparisonQueryableWrapper<T> : IQueryable<T>, IQueryable, IQueryProvider
{
    private readonly IQueryable<T> baseQuery;

    /// <summary>The comparison applied to every rewritten string operation.</summary>
    public readonly StringComparison ComparisonType;

    // Open generic definition of CreateQueryCore<TElement>. The non-generic CreateQuery closes it over
    // an element type that is only known at runtime.
    private static readonly MethodInfo CreateQueryCoreDefinition =
        typeof(StringComparisonQueryableWrapper<T>).GetMethod(nameof(CreateQueryCore), BindingFlags.Instance | BindingFlags.NonPublic);

    /// <summary>
    /// Wraps <paramref name="baseQuery"/> so that queries composed on this instance use <paramref name="comparisonType"/>.
    /// </summary>
    /// <param name="baseQuery">The query whose provider ultimately executes the rewritten expressions.</param>
    /// <param name="comparisonType">The comparison to apply to rewritten string operations.</param>
    /// <exception cref="ArgumentNullException"><paramref name="baseQuery"/> is <c>null</c>.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="comparisonType"/> is not a defined <see cref="StringComparison"/> value.</exception>
    public StringComparisonQueryableWrapper(IQueryable<T> baseQuery, StringComparison comparisonType)
    {
        if (baseQuery == null)
        {
            throw new ArgumentNullException(nameof(baseQuery));
        }

        // Reject undefined values up front. Otherwise they would only surface when a query runs
        // (e.g. the StringComparer lookup used for Enumerable.Contains has no entry for them).
        if (comparisonType < StringComparison.CurrentCulture || comparisonType > StringComparison.OrdinalIgnoreCase)
        {
            throw new ArgumentOutOfRangeException(nameof(comparisonType), comparisonType, "Unsupported StringComparison value.");
        }

        this.baseQuery = baseQuery;
        this.ComparisonType = comparisonType;
    }

    Expression IQueryable.Expression => baseQuery.Expression;

    Type IQueryable.ElementType => baseQuery.ElementType;

    // This instance is its own provider, so every Queryable.* operator composed on it comes back
    // through CreateQuery/Execute below and gets rewritten.
    IQueryProvider IQueryable.Provider => this;

    IQueryable IQueryProvider.CreateQuery(Expression expression)
    {
        if (expression == null)
        {
            throw new ArgumentNullException(nameof(expression));
        }

        Type elementType = GetQueryableElementType(expression.Type);
        MethodInfo createQuery = CreateQueryCoreDefinition.MakeGenericMethod(elementType);

        // As before, exceptions thrown inside CreateQueryCore arrive wrapped in a TargetInvocationException.
        return (IQueryable)createQuery.Invoke(this, new object[] { expression });
    }

    IQueryable<TElement> IQueryProvider.CreateQuery<TElement>(Expression expression)
    {
        return CreateQueryCore<TElement>(expression);
    }

    /// <summary>
    /// Rewrites <paramref name="expression"/>, lets the wrapped provider build a query from it, and wraps
    /// the result so that further composition is rewritten as well.
    /// </summary>
    private IQueryable<TElement> CreateQueryCore<TElement>(Expression expression)
    {
        Expression transformed = TransformExpression(expression);
        IQueryable<TElement> query = baseQuery.Provider.CreateQuery<TElement>(transformed);
        return new StringComparisonQueryableWrapper<TElement>(query, ComparisonType);
    }

    object IQueryProvider.Execute(Expression expression)
    {
        var expression2 = TransformExpression(expression);
        return baseQuery.Provider.Execute(expression2);
    }

    TResult IQueryProvider.Execute<TResult>(Expression expression)
    {
        var expression2 = TransformExpression(expression);
        return baseQuery.Provider.Execute<TResult>(expression2);
    }

    IEnumerator<T> IEnumerable<T>.GetEnumerator()
    {
        return GetTransformedEnumerator();
    }

    IEnumerator IEnumerable.GetEnumerator()
    {
        return GetTransformedEnumerator();
    }

    /// <summary>
    /// Enumerates the wrapped query with its own expression rewritten too. Predicates that were already
    /// part of the query handed to the constructor honour <see cref="ComparisonType"/> just like the
    /// ones composed afterwards.
    /// </summary>
    private IEnumerator<T> GetTransformedEnumerator()
    {
        Expression original = baseQuery.Expression;
        Expression transformed = TransformExpression(original);

        // ExpressionVisitor hands back the very same node when it rewrote nothing. In that case,
        // enumerate the wrapped query exactly as given (e.g. a mocked IDbSet's configured enumerator).
        if (ReferenceEquals(transformed, original))
        {
            return baseQuery.GetEnumerator();
        }

        return baseQuery.Provider.CreateQuery<T>(transformed).GetEnumerator();
    }

    /// <summary>
    /// Returns X for an expression whose static type is, or implements exactly once, <c>IQueryable&lt;X&gt;</c>.
    /// </summary>
    private static Type GetQueryableElementType(Type type)
    {
        // Type.GetInterfaces() on an interface type lists only the interfaces it inherits, never the
        // type itself. The type of a Queryable.* call expression is IQueryable<X> itself, so check it first.
        if (IsGenericQueryable(type))
        {
            return type.GetGenericArguments()[0];
        }

        Type[] candidates = type.GetInterfaces().Where(i => IsGenericQueryable(i)).ToArray();
        if (candidates.Length != 1)
        {
            throw new ArgumentException($"The expression type {type} must implement IQueryable<T> exactly once.", "expression");
        }

        return candidates[0].GetGenericArguments()[0];
    }

    // True for a closed IQueryable<X>.
    private static bool IsGenericQueryable(Type type)
    {
        return type.IsGenericType && type.GetGenericTypeDefinition() == typeof(IQueryable<>);
    }

    private Expression TransformExpression(Expression expression)
    {
        Expression expression2 = new StringComparisonExpressionTranformer(ComparisonType).Visit(expression);
        return expression2;
    }

    /// <summary>
    /// Replaces calls to selected string methods and the string <c>==</c>/<c>!=</c> operators with the
    /// overloads that take a <see cref="StringComparison"/> (or a <see cref="StringComparer"/> for
    /// <c>Enumerable.Contains&lt;string&gt;</c>).
    /// </summary>
    private class StringComparisonExpressionTranformer : ExpressionVisitor
    {
        private readonly StringComparison comparisonType;

        // Method call rewriters, keyed by the exact MethodInfo being replaced.
        private static readonly IReadOnlyDictionary<MethodInfo, Func<MethodCallExpression, StringComparison, Expression>> transformers;
        // Operator rewriters, keyed by the operator's implementing method (op_Equality / op_Inequality).
        private static readonly IReadOnlyDictionary<MethodInfo, Func<BinaryExpression, StringComparison, Expression>> transformers2;

        // Maps each StringComparison to the equivalent StringComparer, for the Enumerable.Contains overload
        // that takes an IEqualityComparer<string>.
        private static readonly IReadOnlyDictionary<StringComparison, StringComparer> comparisonToComparer = new Dictionary<StringComparison, System.StringComparer>
        {
            { StringComparison.CurrentCulture, StringComparer.CurrentCulture },
            { StringComparison.CurrentCultureIgnoreCase, StringComparer.CurrentCultureIgnoreCase },
            { StringComparison.InvariantCulture, StringComparer.InvariantCulture },
            { StringComparison.InvariantCultureIgnoreCase, StringComparer.InvariantCultureIgnoreCase },
            { StringComparison.Ordinal, StringComparer.Ordinal },
            { StringComparison.OrdinalIgnoreCase, StringComparer.OrdinalIgnoreCase }
        };

        static StringComparisonExpressionTranformer()
        {
            var transformers = new Dictionary<MethodInfo, Func<MethodCallExpression, StringComparison, Expression>>();

            {
                // string.Compare("foo", "bar")
                var method = typeof(string).GetMethod(nameof(string.Compare), BindingFlags.Static | BindingFlags.Public, null, new[] { typeof(string), typeof(string) }, null);
                transformers.Add(method, Compare);
            }

            {
                // string.Compare("foo", 0, "bar", 0, 3)
                var method = typeof(string).GetMethod(nameof(string.Compare), BindingFlags.Static | BindingFlags.Public, null, new[] { typeof(string), typeof(int), typeof(string), typeof(int), typeof(int) }, null);
                transformers.Add(method, CompareIndexLength);
            }

            {
                // "foo".CompareTo("bar")
                var method = typeof(string).GetMethod(nameof(string.CompareTo), BindingFlags.Instance | BindingFlags.Public, null, new[] { typeof(string) }, null);
                transformers.Add(method, CompareTo);
            }

            {
                // "foo".Contains("bar")
                var method = typeof(string).GetMethod(nameof(string.Contains), BindingFlags.Instance | BindingFlags.Public, null, new[] { typeof(string) }, null);
                transformers.Add(method, Contains);
            }

            {
                // string.Equals("foo", "bar")
                var method = typeof(string).GetMethod(nameof(string.Equals), BindingFlags.Static | BindingFlags.Public, null, new[] { typeof(string), typeof(string) }, null);
                transformers.Add(method, EqualsStatic);
            }

            {
                // "foo".Equals("bar")
                var method = typeof(string).GetMethod(nameof(string.Equals), BindingFlags.Instance | BindingFlags.Public, null, new[] { typeof(string) }, null);
                transformers.Add(method, EqualsInstance);
            }

            {
                // Enumerable.Contains<TSource>(source, "foo")
                var method = (from x in typeof(Enumerable).GetMethods(BindingFlags.Static | BindingFlags.Public)
                              where x.Name == nameof(Enumerable.Contains)
                              let args = x.GetGenericArguments()
                              where args.Length == 1
                              let pars = x.GetParameters()
                              where pars.Length == 2 &&
                                  pars[0].ParameterType == typeof(IEnumerable<>).MakeGenericType(args[0]) &&
                                  pars[1].ParameterType == args[0]
                              select x).Single();

                // Enumerable.Contains<string>(source, "foo")
                var method2 = method.MakeGenericMethod(typeof(string));

                transformers.Add(method2, EnumerableContains);
            }

            // TODO: all the various Array.Find*, Array.IndexOf

            StringComparisonExpressionTranformer.transformers = transformers;

            var transformers2 = new Dictionary<MethodInfo, Func<BinaryExpression, StringComparison, Expression>>();

            {
                // ==
                var method = typeof(string).GetMethod("op_Equality", BindingFlags.Static | BindingFlags.Public, null, new[] { typeof(string), typeof(string) }, null);
                transformers2.Add(method, OpEquality);
            }

            {
                // !=
                var method = typeof(string).GetMethod("op_Inequality", BindingFlags.Static | BindingFlags.Public, null, new[] { typeof(string), typeof(string) }, null);
                transformers2.Add(method, OpInequality);
            }

            StringComparisonExpressionTranformer.transformers2 = transformers2;
        }

        public StringComparisonExpressionTranformer(StringComparison comparisonType)
        {
            this.comparisonType = comparisonType;
        }

        // methods
        protected override Expression VisitMethodCall(MethodCallExpression node)
        {
            Func<MethodCallExpression, StringComparison, Expression> transformer;

            if (transformers.TryGetValue(node.Method, out transformer))
            {
                // The replacement still holds the original (unvisited) arguments; visiting it rewrites them too.
                Expression node2 = transformer(node, comparisonType);
                return Visit(node2);
            }

            return base.VisitMethodCall(node);
        }

        // operators
        protected override Expression VisitBinary(BinaryExpression node)
        {
            Func<BinaryExpression, StringComparison, Expression> transformer;

            // Only user-defined operators carry a Method; built-in numeric comparisons have none.
            if (node.Method != null && transformers2.TryGetValue(node.Method, out transformer))
            {
                Expression node2 = transformer(node, comparisonType);
                return Visit(node2);
            }

            return base.VisitBinary(node);
        }

        private static readonly MethodInfo StringEqualsStatic = typeof(string).GetMethod(nameof(string.Equals), BindingFlags.Static | BindingFlags.Public, null, new[] { typeof(string), typeof(string), typeof(StringComparison) }, null);
        private static readonly MethodInfo StringEqualsInstance = typeof(string).GetMethod(nameof(string.Equals), BindingFlags.Instance | BindingFlags.Public, null, new[] { typeof(string), typeof(StringComparison) }, null);

        private static readonly MethodInfo StringCompareStatic = typeof(string).GetMethod(nameof(string.Compare), BindingFlags.Static | BindingFlags.Public, null, new[] { typeof(string), typeof(string), typeof(StringComparison) }, null);
        private static readonly MethodInfo StringCompareIndexLengthStatic = typeof(string).GetMethod(nameof(string.Compare), BindingFlags.Static | BindingFlags.Public, null, new[] { typeof(string), typeof(int), typeof(string), typeof(int), typeof(int), typeof(StringComparison) }, null);

        private static readonly MethodInfo StringIndexOfInstance = typeof(string).GetMethod(nameof(string.IndexOf), BindingFlags.Instance | BindingFlags.Public, null, new[] { typeof(string), typeof(StringComparison) }, null);

        private static readonly MethodInfo EnumerableContainsStatic = (from x in typeof(Enumerable).GetMethods(BindingFlags.Static | BindingFlags.Public)
                                                                       where x.Name == nameof(Enumerable.Contains)
                                                                       let args = x.GetGenericArguments()
                                                                       where args.Length == 1
                                                                       let pars = x.GetParameters()
                                                                       where pars.Length == 3 &&
                                                                           pars[0].ParameterType == typeof(IEnumerable<>).MakeGenericType(args[0]) &&
                                                                           pars[1].ParameterType == args[0] &&
                                                                           pars[2].ParameterType == typeof(IEqualityComparer<>).MakeGenericType(args[0])
                                                                       select x).Single().MakeGenericMethod(typeof(string));

        private static Expression Compare(MethodCallExpression exp, StringComparison comparisonType)
        {
            return Expression.Call(StringCompareStatic, exp.Arguments[0], exp.Arguments[1], Expression.Constant(comparisonType));
        }

        private static Expression CompareIndexLength(MethodCallExpression exp, StringComparison comparisonType)
        {
            return Expression.Call(StringCompareIndexLengthStatic, exp.Arguments[0], exp.Arguments[1], exp.Arguments[2], exp.Arguments[3], exp.Arguments[4], Expression.Constant(comparisonType));
        }

        private static Expression CompareTo(MethodCallExpression exp, StringComparison comparisonType)
        {
            return Expression.Call(StringCompareStatic, exp.Object, exp.Arguments[0], Expression.Constant(comparisonType));
        }

        private static Expression Contains(MethodCallExpression exp, StringComparison comparisonType)
        {
            // No "".Contains(, StringComparison) is assumed available. Translate to "".IndexOf(, StringComparison) != -1
            return Expression.NotEqual(Expression.Call(exp.Object, StringIndexOfInstance, exp.Arguments[0], Expression.Constant(comparisonType)), Expression.Constant(-1));
        }

        private static Expression EqualsStatic(MethodCallExpression exp, StringComparison comparisonType)
        {
            return Expression.Call(StringEqualsStatic, exp.Arguments[0], exp.Arguments[1], Expression.Constant(comparisonType));
        }

        private static Expression EqualsInstance(MethodCallExpression exp, StringComparison comparisonType)
        {
            return Expression.Call(exp.Object, StringEqualsInstance, exp.Arguments[0], Expression.Constant(comparisonType));
        }

        private static Expression EnumerableContains(MethodCallExpression exp, StringComparison comparisonType)
        {
            StringComparer comparer = comparisonToComparer[comparisonType];
            return Expression.Call(EnumerableContainsStatic, exp.Arguments[0], exp.Arguments[1], Expression.Constant(comparer));
        }


        private static Expression OpEquality(BinaryExpression exp, StringComparison comparisonType)
        {
            return Expression.Call(StringEqualsStatic, exp.Left, exp.Right, Expression.Constant(comparisonType));
        }

        private static Expression OpInequality(BinaryExpression exp, StringComparison comparisonType)
        {
            return Expression.Not(Expression.Call(StringEqualsStatic, exp.Left, exp.Right, Expression.Constant(comparisonType)));
        }
    }
}

仔细看可以发现,这里有一个非常简单的代理 IQueryable/IQueryProvider 实现(StringComparisonQueryableWrapper<T>)。它使用一个 ExpressionVisitor(StringComparisonExpressionTranformer)查找某些特定的 MethodCallExpression(方法调用)和 BinaryExpression(== 和 != 运算符),并将它们替换为使用 StringComparison/StringComparer 的 MethodCallExpression。目前还缺少对 Array.IndexOf、Enumerable.SequenceEqual 等方法的替换。