Create an IQueryable that is case-insensitive for unit testing LINQ to Entities - NSubstitute
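I'm trying to unit test LINQ to Entities: I want to search for the same word in different casing and get the same word back.
Concretely, I'm trying to unit test searching for a lowercase and an uppercase word, e.g. "Hi" and "hi".
LINQ to Entities with Entity Framework currently supports this: I can search for either term in the where clause and it does the job for me.
Problem:
I'm trying to build a mock query that behaves the same way: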
using System.Collections.Generic;
using System.Data.Entity;   // EF6: DbSet<T>, IDbSet<T>
using System.Linq;
using NSubstitute;
using NUnit.Framework;

public class SimpleWord
{
    public string Text;
}

[Test]
public void someTest()
{
    //arrange
    var lowerWords = new[] { "hi" };
    var upperWords = new[] { "Hi" };
    var wordsList = new List<SimpleWord> { new SimpleWord { Text = "hi" } };
    IDbSet<SimpleWord> wordsDbSet = Substitute.For<DbSet<SimpleWord>, IDbSet<SimpleWord>>();

    //set up the mock dbSet
    var dataAsList = wordsList.ToList();
    var queryable = dataAsList.AsQueryable();
    wordsDbSet.Provider.Returns(queryable.Provider);
    wordsDbSet.Expression.Returns(queryable.Expression);
    wordsDbSet.ElementType.Returns(queryable.ElementType);
    // Use a factory so every enumeration gets a fresh enumerator
    // (a single shared enumerator would be exhausted after the first pass)
    wordsDbSet.GetEnumerator().Returns(_ => queryable.GetEnumerator());

    //act
    var resultLower = wordsDbSet.Where(wrd => lowerWords.Contains(wrd.Text)).ToList();
    var resultHigher = wordsDbSet.Where(wrd => upperWords.Contains(wrd.Text)).ToList();

    //assert
    Assert.That(resultHigher.Count, Is.EqualTo(1), "did not find upper case");
    Assert.That(resultLower.Count, Is.EqualTo(1), "did not find lower case");
}
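Why the upper-case assertion fails: the mocked DbSet forwards Provider and Expression to the EnumerableQuery<T> created by AsQueryable(), so the query runs as LINQ to Objects, where Enumerable.Contains without a comparer uses the default string equality, which is ordinal and case-sensitive. Against a real database, EF translates the query to SQL, and the case-insensitivity there typically comes from the database collation (an assumption about the asker's setup, e.g. SQL Server's default collation). A minimal sketch of the in-memory side; CaseSensitivityDemo and its methods are hypothetical names used only for illustration:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical helper contrasting what the mocked DbSet does in memory
// with the case-insensitive behaviour the test expects.
public static class CaseSensitivityDemo
{
    // Enumerable.Contains without a comparer uses EqualityComparer<string>.Default,
    // i.e. ordinal, case-sensitive comparison.
    public static bool DefaultContains(IEnumerable<string> words, string value) =>
        words.Contains(value);

    // The comparer overload the asker does not want to spell out in the act step.
    public static bool IgnoreCaseContains(IEnumerable<string> words, string value) =>
        words.Contains(value, StringComparer.OrdinalIgnoreCase);

    // Mirrors the mock: AsQueryable() yields an EnumerableQuery<T>, which compiles
    // the expression tree and runs it as LINQ to Objects.
    public static int CountMatches(IEnumerable<string> data, IEnumerable<string> search) =>
        data.AsQueryable().Count(x => search.Contains(x));
}
```

With this, CountMatches(new[] { "hi" }, new[] { "Hi" }) returns 0, the same outcome as resultHigher in the test.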
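Question:
How can I make wordsDbSet case-insensitive whenever I run a .Where() search against it?
I don't want to change the act step to:
var resultHigher = wordsDbSet.Where(wrd =>
    upperWords.Contains(wrd.Text, StringComparer.OrdinalIgnoreCase)).ToList();
The answer I'm looking for changes the arrange step instead, something like this (pseudocode):
wordsDbSet.When(contains.IsCalled).Return(contains.OrdinalIgnoreCasing)
Thanks for looking!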
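OK... it's doable, but long (not very complex... just long). The main problem is that implementing IQueryable<> and IQueryProvider is a pain, and there is almost no explanation of how it works (you can copy some code you find on the internet, but there is little explanation of why and how it works).
What I wrote is an IQueryable<> wrapper that "wraps" an IQueryable<> object (like the one returned by AsQueryable()) and "on the fly" rewrites every expression tree passed to it, replacing some string methods (plus Enumerable.Contains<string>) with the corresponding overloads that accept a StringComparison/StringComparer. Use it like this: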
var arr = new[] { "foo " };
var query = new[] { "Foo", "Bar", "bar" }
    .AsQueryable()
    .AsStringComparison(StringComparison.CurrentCultureIgnoreCase);
// query is an IQueryable<>
var res = query
    .Where(x => string.Compare(x, "foo") < 0)
    .Where(x => x.CompareTo("foo") < 0)
    .Where(x => string.Compare(x, 0, "foo", 0, 3) < 0)
    .Where(x => x.Contains("foo"))
    .Where(x => string.Equals(x, "foo"))
    .Where(x => x.Equals("foo"))
    .Where(x => arr.Contains(x))
    .Where(x => x == "foo")
    .Where(x => x != "foo");
(This is the list of all the methods I replace.)
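The core trick can be seen without any IQueryProvider plumbing. A minimal sketch, assuming only the BCL and handling just the Enumerable.Contains<string> call used in the question: an ExpressionVisitor spots the two-argument overload and rebuilds the call with an explicit comparer. ContainsRewriter is a hypothetical name; the answer's visitor does the same through a MethodInfo dictionary and also covers the other string methods.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

// Hypothetical, reduced visitor: rewrites Enumerable.Contains<string>(source, value)
// into Enumerable.Contains<string>(source, value, comparer).
public sealed class ContainsRewriter : ExpressionVisitor
{
    // Enumerable.Contains<string>(IEnumerable<string>, string)
    private static readonly MethodInfo ContainsTwoArgs =
        typeof(Enumerable).GetMethods(BindingFlags.Static | BindingFlags.Public)
            .Single(m => m.Name == nameof(Enumerable.Contains) && m.GetParameters().Length == 2)
            .MakeGenericMethod(typeof(string));

    // Enumerable.Contains<string>(IEnumerable<string>, string, IEqualityComparer<string>)
    private static readonly MethodInfo ContainsWithComparer =
        typeof(Enumerable).GetMethods(BindingFlags.Static | BindingFlags.Public)
            .Single(m => m.Name == nameof(Enumerable.Contains) && m.GetParameters().Length == 3)
            .MakeGenericMethod(typeof(string));

    private readonly IEqualityComparer<string> comparer;

    public ContainsRewriter(IEqualityComparer<string> comparer) => this.comparer = comparer;

    protected override Expression VisitMethodCall(MethodCallExpression node)
    {
        if (node.Method == ContainsTwoArgs)
        {
            // Visit the arguments too, so nested calls are rewritten as well.
            return Expression.Call(
                ContainsWithComparer,
                Visit(node.Arguments[0]),
                Visit(node.Arguments[1]),
                Expression.Constant(comparer, typeof(IEqualityComparer<string>)));
        }
        return base.VisitMethodCall(node);
    }
}
```

For example, visiting w => Enumerable.Contains(search, w) with StringComparer.OrdinalIgnoreCase, where search holds "Hi", gives a lambda that returns true for "hi", while the compiled original returns false. The wrapper's job is to run such a visitor over every expression that reaches CreateQuery/Execute.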
And the implementation:
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

public static class StringComparisonQueryableWrapper
{
    // Wraps any IQueryable<T> so that the string comparisons in its
    // expression trees use the given StringComparison
    public static IQueryable<T> AsStringComparison<T>(this IQueryable<T> query, StringComparison comparisonType)
    {
        return new StringComparisonQueryableWrapper<T>(query, comparisonType);
    }
}

// Both the IQueryable<T> and its IQueryProvider: every expression handed to
// CreateQuery/Execute is rewritten, then forwarded to the wrapped provider
public class StringComparisonQueryableWrapper<T> : IQueryable<T>, IQueryable, IQueryProvider
{
    private readonly IQueryable<T> baseQuery;
    public readonly StringComparison ComparisonType;

    public StringComparisonQueryableWrapper(IQueryable<T> baseQuery, StringComparison comparisonType)
    {
        this.baseQuery = baseQuery;
        this.ComparisonType = comparisonType;
    }

    Expression IQueryable.Expression => baseQuery.Expression;
    Type IQueryable.ElementType => baseQuery.ElementType;
    IQueryProvider IQueryable.Provider => this;

    IQueryable IQueryProvider.CreateQuery(Expression expression)
    {
        Type type = expression.Type;
        // expression.Type is usually IQueryable<X> itself, and GetInterfaces()
        // does not return the type it is called on, so check that case first
        var iqueryableT = type.IsGenericType && type.GetGenericTypeDefinition() == typeof(IQueryable<>)
            ? type
            : type.GetInterfaces().Where(x => x.IsGenericType && x.GetGenericTypeDefinition() == typeof(IQueryable<>)).Single();
        Type type2 = iqueryableT.GetGenericArguments()[0];
        // Call IQueryProvider.CreateQuery<TElement> through the interface method,
        // which dispatches to the explicit generic implementation below
        var createQueryMethod = typeof(IQueryProvider).GetMethods()
            .Where(x => x.Name == nameof(IQueryProvider.CreateQuery) && x.IsGenericMethod)
            .Single()
            .MakeGenericMethod(type2);
        var queryable = (IQueryable)createQueryMethod.Invoke(this, new object[] { expression });
        return queryable;
    }

    IQueryable<TElement> IQueryProvider.CreateQuery<TElement>(Expression expression)
    {
        var expression2 = TransformExpression(expression);
        var query = baseQuery.Provider.CreateQuery<TElement>(expression2);
        // Re-wrap the result so that operators chained onto it are rewritten as well
        return new StringComparisonQueryableWrapper<TElement>(query, ComparisonType);
    }

    object IQueryProvider.Execute(Expression expression)
    {
        var expression2 = TransformExpression(expression);
        return baseQuery.Provider.Execute(expression2);
    }

    TResult IQueryProvider.Execute<TResult>(Expression expression)
    {
        var expression2 = TransformExpression(expression);
        return baseQuery.Provider.Execute<TResult>(expression2);
    }

    // baseQuery is either the original source or was built from an already
    // rewritten expression, so enumeration can be delegated directly
    IEnumerator<T> IEnumerable<T>.GetEnumerator()
    {
        return baseQuery.GetEnumerator();
    }

    IEnumerator IEnumerable.GetEnumerator()
    {
        return baseQuery.GetEnumerator();
    }

    private Expression TransformExpression(Expression expression)
    {
        Expression expression2 = new StringComparisonExpressionTranformer(ComparisonType).Visit(expression);
        return expression2;
    }

    // Replaces string method calls and the ==/!= operators with calls to the
    // overloads that take a StringComparison/StringComparer
    private class StringComparisonExpressionTranformer : ExpressionVisitor
    {
        private readonly StringComparison comparisonType;
        private static readonly IReadOnlyDictionary<MethodInfo, Func<MethodCallExpression, StringComparison, Expression>> transformers;
        private static readonly IReadOnlyDictionary<MethodInfo, Func<BinaryExpression, StringComparison, Expression>> transformers2;

        // Maps each StringComparison to the matching StringComparer (used for Enumerable.Contains)
        private static readonly IReadOnlyDictionary<StringComparison, StringComparer> comparisonToComparer = new Dictionary<StringComparison, System.StringComparer>
        {
            { StringComparison.CurrentCulture, StringComparer.CurrentCulture },
            { StringComparison.CurrentCultureIgnoreCase, StringComparer.CurrentCultureIgnoreCase },
            { StringComparison.InvariantCulture, StringComparer.InvariantCulture },
            { StringComparison.InvariantCultureIgnoreCase, StringComparer.InvariantCultureIgnoreCase },
            { StringComparison.Ordinal, StringComparer.Ordinal },
            { StringComparison.OrdinalIgnoreCase, StringComparer.OrdinalIgnoreCase }
        };

        static StringComparisonExpressionTranformer()
        {
            var transformers = new Dictionary<MethodInfo, Func<MethodCallExpression, StringComparison, Expression>>();
            {
                // string.Compare("foo", "bar")
                var method = typeof(string).GetMethod(nameof(string.Compare), BindingFlags.Static | BindingFlags.Public, null, new[] { typeof(string), typeof(string) }, null);
                transformers.Add(method, Compare);
            }
            {
                // string.Compare("foo", 0, "bar", 0, 3)
                var method = typeof(string).GetMethod(nameof(string.Compare), BindingFlags.Static | BindingFlags.Public, null, new[] { typeof(string), typeof(int), typeof(string), typeof(int), typeof(int) }, null);
                transformers.Add(method, CompareIndexLength);
            }
            {
                // "foo".CompareTo("bar")
                var method = typeof(string).GetMethod(nameof(string.CompareTo), BindingFlags.Instance | BindingFlags.Public, null, new[] { typeof(string) }, null);
                transformers.Add(method, CompareTo);
            }
            {
                // "foo".Contains("bar")
                var method = typeof(string).GetMethod(nameof(string.Contains), BindingFlags.Instance | BindingFlags.Public, null, new[] { typeof(string) }, null);
                transformers.Add(method, Contains);
            }
            {
                // string.Equals("foo", "bar")
                var method = typeof(string).GetMethod(nameof(string.Equals), BindingFlags.Static | BindingFlags.Public, null, new[] { typeof(string), typeof(string) }, null);
                transformers.Add(method, EqualsStatic);
            }
            {
                // "foo".Equals("bar")
                var method = typeof(string).GetMethod(nameof(string.Equals), BindingFlags.Instance | BindingFlags.Public, null, new[] { typeof(string) }, null);
                transformers.Add(method, EqualsInstance);
            }
            {
                // Enumerable.Contains<TSource>(source, "foo")
                var method = (from x in typeof(Enumerable).GetMethods(BindingFlags.Static | BindingFlags.Public)
                              where x.Name == nameof(Enumerable.Contains)
                              let args = x.GetGenericArguments()
                              where args.Length == 1
                              let pars = x.GetParameters()
                              where pars.Length == 2 &&
                                    pars[0].ParameterType == typeof(IEnumerable<>).MakeGenericType(args[0]) &&
                                    pars[1].ParameterType == args[0]
                              select x).Single();
                // Enumerable.Contains<string>(source, "foo")
                var method2 = method.MakeGenericMethod(typeof(string));
                transformers.Add(method2, EnumerableContains);
            }
            // TODO: all the various Array.Find*, Array.IndexOf
            StringComparisonExpressionTranformer.transformers = transformers;

            var transformers2 = new Dictionary<MethodInfo, Func<BinaryExpression, StringComparison, Expression>>();
            {
                // ==
                var method = typeof(string).GetMethod("op_Equality", BindingFlags.Static | BindingFlags.Public, null, new[] { typeof(string), typeof(string) }, null);
                transformers2.Add(method, OpEquality);
            }
            {
                // !=
                var method = typeof(string).GetMethod("op_Inequality", BindingFlags.Static | BindingFlags.Public, null, new[] { typeof(string), typeof(string) }, null);
                transformers2.Add(method, OpInequality);
            }
            StringComparisonExpressionTranformer.transformers2 = transformers2;
        }

        public StringComparisonExpressionTranformer(StringComparison comparisonType)
        {
            this.comparisonType = comparisonType;
        }

        // methods
        protected override Expression VisitMethodCall(MethodCallExpression node)
        {
            Func<MethodCallExpression, StringComparison, Expression> transformer;
            if (transformers.TryGetValue(node.Method, out transformer))
            {
                Expression node2 = transformer(node, comparisonType);
                return Visit(node2);
            }
            return base.VisitMethodCall(node);
        }

        // operators
        protected override Expression VisitBinary(BinaryExpression node)
        {
            Func<BinaryExpression, StringComparison, Expression> transformer;
            if (node.Method != null && transformers2.TryGetValue(node.Method, out transformer))
            {
                Expression node2 = transformer(node, comparisonType);
                return Visit(node2);
            }
            return base.VisitBinary(node);
        }

        private static readonly MethodInfo StringEqualsStatic = typeof(string).GetMethod(nameof(string.Equals), BindingFlags.Static | BindingFlags.Public, null, new[] { typeof(string), typeof(string), typeof(StringComparison) }, null);
        private static readonly MethodInfo StringEqualsInstance = typeof(string).GetMethod(nameof(string.Equals), BindingFlags.Instance | BindingFlags.Public, null, new[] { typeof(string), typeof(StringComparison) }, null);
        private static readonly MethodInfo StringCompareStatic = typeof(string).GetMethod(nameof(string.Compare), BindingFlags.Static | BindingFlags.Public, null, new[] { typeof(string), typeof(string), typeof(StringComparison) }, null);
        private static readonly MethodInfo StringCompareIndexLengthStatic = typeof(string).GetMethod(nameof(string.Compare), BindingFlags.Static | BindingFlags.Public, null, new[] { typeof(string), typeof(int), typeof(string), typeof(int), typeof(int), typeof(StringComparison) }, null);
        private static readonly MethodInfo StringIndexOfInstance = typeof(string).GetMethod(nameof(string.IndexOf), BindingFlags.Instance | BindingFlags.Public, null, new[] { typeof(string), typeof(StringComparison) }, null);
        private static readonly MethodInfo EnumerableContainsStatic = (from x in typeof(Enumerable).GetMethods(BindingFlags.Static | BindingFlags.Public)
                                                                       where x.Name == nameof(Enumerable.Contains)
                                                                       let args = x.GetGenericArguments()
                                                                       where args.Length == 1
                                                                       let pars = x.GetParameters()
                                                                       where pars.Length == 3 &&
                                                                             pars[0].ParameterType == typeof(IEnumerable<>).MakeGenericType(args[0]) &&
                                                                             pars[1].ParameterType == args[0] &&
                                                                             pars[2].ParameterType == typeof(IEqualityComparer<>).MakeGenericType(args[0])
                                                                       select x).Single().MakeGenericMethod(typeof(string));

        private static Expression Compare(MethodCallExpression exp, StringComparison comparisonType)
        {
            return Expression.Call(StringCompareStatic, exp.Arguments[0], exp.Arguments[1], Expression.Constant(comparisonType));
        }

        private static Expression CompareIndexLength(MethodCallExpression exp, StringComparison comparisonType)
        {
            return Expression.Call(StringCompareIndexLengthStatic, exp.Arguments[0], exp.Arguments[1], exp.Arguments[2], exp.Arguments[3], exp.Arguments[4], Expression.Constant(comparisonType));
        }

        private static Expression CompareTo(MethodCallExpression exp, StringComparison comparisonType)
        {
            return Expression.Call(StringCompareStatic, exp.Object, exp.Arguments[0], Expression.Constant(comparisonType));
        }

        private static Expression Contains(MethodCallExpression exp, StringComparison comparisonType)
        {
            // No "".Contains(, StringComparison). Translate to "".IndexOf(, StringComparison) != -1
            return Expression.NotEqual(Expression.Call(exp.Object, StringIndexOfInstance, exp.Arguments[0], Expression.Constant(comparisonType)), Expression.Constant(-1));
        }

        private static Expression EqualsStatic(MethodCallExpression exp, StringComparison comparisonType)
        {
            return Expression.Call(StringEqualsStatic, exp.Arguments[0], exp.Arguments[1], Expression.Constant(comparisonType));
        }

        private static Expression EqualsInstance(MethodCallExpression exp, StringComparison comparisonType)
        {
            return Expression.Call(exp.Object, StringEqualsInstance, exp.Arguments[0], Expression.Constant(comparisonType));
        }

        private static Expression EnumerableContains(MethodCallExpression exp, StringComparison comparisonType)
        {
            StringComparer comparer = comparisonToComparer[comparisonType];
            return Expression.Call(EnumerableContainsStatic, exp.Arguments[0], exp.Arguments[1], Expression.Constant(comparer));
        }

        private static Expression OpEquality(BinaryExpression exp, StringComparison comparisonType)
        {
            return Expression.Call(StringEqualsStatic, exp.Left, exp.Right, Expression.Constant(comparisonType));
        }

        private static Expression OpInequality(BinaryExpression exp, StringComparison comparisonType)
        {
            return Expression.Not(Expression.Call(StringEqualsStatic, exp.Left, exp.Right, Expression.Constant(comparisonType)));
        }
    }
}
If you look at it, there is a very simple proxy IQueryable/IQueryProvider implementation (StringComparisonQueryableWrapper<T>) that uses an ExpressionVisitor (StringComparisonExpressionTranformer) to find certain specific MethodCallExpressions (method calls) and BinaryExpressions (the == and != operators) and replace them with MethodCallExpressions that call the overloads taking a StringComparison/StringComparer. Replacements for Array.IndexOf, Enumerable.SequenceEqual... are still missing.
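The operator half of the visitor can be shown in isolation. A minimal sketch, assuming only the BCL: the BinaryExpression the C# compiler emits for a string == or != carries string's op_Equality/op_Inequality in its Method property, so the visitor can swap it for string.Equals(string, string, StringComparison). EqualityRewriter is a hypothetical name; unlike the answer, it matches on the operator's declaring type rather than looking the MethodInfo up in a dictionary.

```csharp
using System;
using System.Linq.Expressions;
using System.Reflection;

// Hypothetical visitor: string == / != become string.Equals(a, b, comparison)
// (negated for !=). Numeric comparisons have Method == null and are left alone.
public sealed class EqualityRewriter : ExpressionVisitor
{
    // string.Equals(string, string, StringComparison)
    private static readonly MethodInfo EqualsWithComparison =
        typeof(string).GetMethod(nameof(string.Equals),
            new[] { typeof(string), typeof(string), typeof(StringComparison) });

    private readonly StringComparison comparison;

    public EqualityRewriter(StringComparison comparison) => this.comparison = comparison;

    protected override Expression VisitBinary(BinaryExpression node)
    {
        bool isStringOperator = node.Method != null && node.Method.DeclaringType == typeof(string);
        if (isStringOperator &&
            (node.NodeType == ExpressionType.Equal || node.NodeType == ExpressionType.NotEqual))
        {
            Expression call = Expression.Call(
                EqualsWithComparison,
                Visit(node.Left),
                Visit(node.Right),
                Expression.Constant(comparison));
            return node.NodeType == ExpressionType.Equal ? call : Expression.Not(call);
        }
        return base.VisitBinary(node);
    }
}
```

Visiting x => x == "foo" with StringComparison.OrdinalIgnoreCase yields a lambda that returns true for "FOO", and x != "foo" becomes the negated string.Equals call, which returns false for "FOO".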