如何从定义为接口的 IQueryable 获取 where 子句
How to get the where clause from IQueryable defined as interface
/// <summary>
/// Console driver for the first sample: fills a <c>SampleClass&lt;ClassString&gt;</c> with two
/// entries, builds a filtered query on top of <c>Query&lt;ClassString&gt;()</c> and writes the
/// query object to the console.
/// </summary>
class Program
{
    /// <summary>Entry point; <paramref name="args"/> is not used.</summary>
    static void Main(string[] args)
    {
        // The collection initializer calls ClassStrings.Add for each element.
        var sample = new SampleClass<ClassString>
        {
            ClassStrings =
            {
                new ClassString { Name1 = "1", Name2 = "1" },
                new ClassString { Name1 = "2", Name2 = "2" }
            }
        };

        // The Where clause is attached by the caller, outside of Query<T>().
        IQueryable<ClassString> source = sample.Query<ClassString>();
        IQueryable<ClassString> filtered = source.Where(s => s.Name1.Equals("2"));

        Console.WriteLine(filtered);
        // Wait for Enter before the process exits.
        Console.ReadLine();
    }
}
/// <summary>
/// Sample element type used by the question's first example: a plain object
/// carrying two string values.
/// </summary>
public class ClassString
{
/// <summary>First string value.</summary>
public string Name1 { get; set; }
/// <summary>Second string value.</summary>
public string Name2 { get; set; }
}
/// <summary>
/// Contract for a data source that exposes its content as a LINQ query.
/// </summary>
public interface ISampleQ
{
/// <summary>
/// Returns a queryable over elements of type <typeparamref name="T"/>.
/// The method takes no arguments, so any filter is applied by the caller on the returned query.
/// </summary>
/// <typeparam name="T">Element type; must be a reference type with a public parameterless constructor.</typeparam>
IQueryable<T> Query<T>() where T: class , new();
}
/// <summary>
/// In-memory <see cref="ISampleQ"/> implementation backed by a <see cref="List{T}"/> of
/// <typeparamref name="X"/>.
/// </summary>
/// <typeparam name="X">Type of the stored items.</typeparam>
public class SampleClass<X> : ISampleQ
{
    /// <summary>The backing list; callers add items to it directly.</summary>
    public List<X> ClassStrings { get; private set; }

    /// <summary>Creates an instance with an empty backing list.</summary>
    public SampleClass()
    {
        ClassStrings = new List<X>();
    }

    /// <summary>
    /// Exposes the backing list as an <see cref="IQueryable{T}"/>.
    /// </summary>
    /// <typeparam name="T">Requested element type.</typeparam>
    /// <returns>A queryable over the items currently held in <see cref="ClassStrings"/>.</returns>
    /// <exception cref="InvalidCastException">
    /// The backing <c>List&lt;X&gt;</c> cannot be viewed as <c>IEnumerable&lt;T&gt;</c>.
    /// </exception>
    public IQueryable<T> Query<T>() where T : class, new()
    {
        // A Where() added by the caller is not visible here: it is applied later,
        // on the object this method returns.
        IEnumerable<T> items = (IEnumerable<T>)ClassStrings;

        // AsQueryable() is the public entry point for wrapping a sequence; it replaces
        // constructing EnumerableQuery<T> by hand.
        return items.AsQueryable();
    }
}
我研究了 solution1、solution2 和 solution3,但它们似乎都不适用于我的问题。where 子句是在 class 外部、针对接口方法返回的 IQueryable 定义的,而调用 Query 方法时没有传入任何参数,那么如何在 Query 方法内部获取这个表达式?
目的,我想要检索并注入回目标(即作为 IQueryable 的 DBContext)。因为我们有一个像 ISampleQ 这样的通用接口。
添加了新的示例代码但场景相同:
/// <summary>
/// Console driver for the second sample: asks an <c>OracleDbContext</c> for a
/// <c>Person</c> query and attaches a name filter from the outside.
/// </summary>
internal class Program
{
    /// <summary>Entry point; <paramref name="args"/> is not used.</summary>
    private static void Main(string[] args)
    {
        var oracle = new OracleDbContext();

        // The filter is defined here, outside of the context class.
        IQueryable<Person> people = oracle.Query<Person>();
        IQueryable<Person> matches = people.Where(p => p.Name.Equals("username"));

        // Writes an empty line; the filtered query itself is not printed.
        Console.WriteLine();
        Console.ReadLine();
    }
}
/// <summary>
/// Query contract shared by the database-specific contexts in this sample.
/// </summary>
public interface IGenericQuery
{
/// <summary>
/// Returns a queryable over elements of type <typeparamref name="T"/>.
/// The method takes no arguments, so any filter is applied by the caller on the returned query.
/// </summary>
/// <typeparam name="T">Element type; must be a reference type with a public parameterless constructor.</typeparam>
IQueryable<T> Query<T>() where T : class , new();
}
/// <summary>
/// Placeholder for the Oracle-backed implementation of <see cref="IGenericQuery"/>.
/// Only <see cref="Query{T}"/> is exposed in this sample.
/// </summary>
public class OracleDbContext : IGenericQuery
{
    /// <summary>Creates the context.</summary>
    public OracleDbContext()
    {
        //Will hold all oracle operations here. For brevity, only
        //Query are exposed.
    }

    /// <summary>
    /// Returns the query root for <typeparamref name="T"/>.
    /// </summary>
    /// <typeparam name="T">Element type.</typeparam>
    /// <returns>
    /// An empty queryable (never null), so that LINQ operators such as <c>Where</c>
    /// can be chained by the caller without throwing.
    /// </returns>
    public IQueryable<T> Query<T>() where T : class, new()
    {
        //Get the where predicate here. Since the where was defined outside of the
        //class. I want to retrieve since the IQueryable<T> is generic to both class
        //OracleDbContext and MssqlDbContext. I want to re-inject the where or add
        //new expression before calling.
        //
        //For eg.
        //oracleDbContext.Query<T>(where clause from here)

        // Returning null made the caller's .Where(...) fail with ArgumentNullException;
        // an empty sequence is a safe stand-in until a real source is wired in.
        return Enumerable.Empty<T>().AsQueryable();
    }
}
/// <summary>
/// Placeholder for the SQL Server-backed implementation of <see cref="IGenericQuery"/>.
/// Only <see cref="Query{T}"/> is exposed in this sample.
/// </summary>
public class MssqlDbContext : IGenericQuery
{
    /// <summary>Creates the context.</summary>
    public MssqlDbContext()
    {
        //Will hold all MSSQL operations here. For brevity, only
        //Query are exposed.
    }

    /// <summary>
    /// Returns the query root for <typeparamref name="T"/>.
    /// </summary>
    /// <typeparam name="T">Element type.</typeparam>
    /// <returns>
    /// An empty queryable (never null), so that LINQ operators such as <c>Where</c>
    /// can be chained by the caller without throwing.
    /// </returns>
    public IQueryable<T> Query<T>() where T : class, new()
    {
        //Get the where predicate here.

        // Returning null made the caller's .Where(...) fail with ArgumentNullException;
        // an empty sequence is a safe stand-in until a real source is wired in.
        return Enumerable.Empty<T>().AsQueryable();
    }
}
/// <summary>
/// Sample entity queried through <c>IGenericQuery.Query&lt;Person&gt;()</c>.
/// </summary>
public class Person
{
    /// <summary>Numeric identifier.</summary>
    public int Id { get; set; }

    /// <summary>
    /// The person's name. Declared as <see cref="string"/> because the sample filters it with
    /// <c>Name.Equals("username")</c>; with the previous <c>int</c> type that comparison
    /// compiled but could never be true.
    /// </summary>
    public string Name { get; set; }
}
它相当复杂...现在...Queryable.Where()
以这种方式工作:
public static IQueryable<TSource> Where<TSource>(this IQueryable<TSource> source, Expression<Func<TSource, bool>> predicate)
{
return source.Provider.CreateQuery<TSource>(Expression.Call(null, ...
所以 Queryable.Where 会调用 source.Provider.CreateQuery(),后者返回一个新的 IQueryable<>。因此,如果您希望在 Where() 被添加时能够"看到"它(并对其进行操作),您就必须"成为" IQueryable<>.Provider,让被调用的是您自己的 CreateQuery();也就是说,您必须创建一个实现 IQueryProvider 的 class(可能还需要一个实现 IQueryable<T> 的 class)。
另一种(更简单的)方法是编写一个简单的查询 "converter":一个接受 IQueryable<> 并返回经过处理的 IQueryable<> 的方法:
var result = c.Query<ClassString>().Where(s => s.Name1.Equals("2")).FixMyQuery();
正如我所说,完整路线很长:
namespace Utilities
{
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Data.Entity;
using System.Data.Entity.Infrastructure;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
/// <summary>
/// A <see cref="DbContext"/> that replaces its public <c>DbSet&lt;T&gt;</c>/<c>IDbSet&lt;T&gt;</c>
/// properties (and the results of <c>Set</c>) with proxy sets, so that every expression
/// passing through them can be rewritten by <see cref="Manipulator"/>.
/// </summary>
public class ProxyDbContext : DbContext
{
    // Looked up by name through reflection; keep in sync with the ProxifySets method below.
    protected static readonly MethodInfo ProxifySetsMethod = typeof(ProxyDbContext).GetMethod("ProxifySets", BindingFlags.Instance | BindingFlags.NonPublic);

    /// <summary>
    /// Builds, once per concrete context type, a compiled delegate that wraps every
    /// eligible set property of <typeparamref name="TContext"/> in a <c>ProxyDbSet&lt;T&gt;</c>.
    /// </summary>
    protected static class ProxyDbContexSetter<TContext> where TContext : ProxyDbContext
    {
        /// <summary>
        /// Wraps the set properties of the given context. Stays a no-op when the context
        /// type has no eligible property.
        /// </summary>
        public static readonly Action<TContext> Do = x => { };

        static ProxyDbContexSetter()
        {
            var properties = typeof(TContext).GetProperties(BindingFlags.Instance | BindingFlags.Public | BindingFlags.FlattenHierarchy);
            ParameterExpression context = Expression.Parameter(typeof(TContext), "context");
            FieldInfo manipulatorField = typeof(ProxyDbContext).GetField("Manipulator", BindingFlags.Instance | BindingFlags.Public);
            Expression manipulator = Expression.Field(context, manipulatorField);
            var sets = new List<Expression>();

            foreach (PropertyInfo property in properties)
            {
                if (property.GetMethod == null)
                {
                    continue;
                }

                // The property is assigned below, so it needs a public setter.
                // A get-only property (SetMethod == null) used to slip through and made
                // Expression.Assign throw inside this type initializer.
                MethodInfo setMethod = property.SetMethod;
                if (setMethod == null || !setMethod.IsPublic)
                {
                    continue;
                }

                Type type = property.PropertyType;
                Type entityType = GetIDbSetTypeArgument(type);
                if (entityType == null)
                {
                    continue;
                }

                // Only properties that can hold a DbSet<entityType> can hold the proxy.
                Type dbSetType = typeof(DbSet<>).MakeGenericType(entityType);
                if (!type.IsAssignableFrom(dbSetType))
                {
                    continue;
                }

                ConstructorInfo constructor = typeof(ProxyDbSet<>)
                    .MakeGenericType(entityType)
                    .GetConstructor(new[]
                    {
                        dbSetType,
                        typeof(Func<bool, Expression, Expression>)
                    });

                // context.Prop = new ProxyDbSet<T>((DbSet<T>)context.Prop, context.Manipulator)
                MemberExpression property2 = Expression.Property(context, property);
                BinaryExpression assign = Expression.Assign(property2, Expression.New(constructor, Expression.Convert(property2, dbSetType), manipulator));
                sets.Add(assign);
            }

            // Expression.Block rejects an empty expression list on .NET Framework;
            // with nothing to wrap, keep the no-op default assigned to Do.
            if (sets.Count == 0)
            {
                return;
            }

            Expression<Action<TContext>> lambda = Expression.Lambda<Action<TContext>>(Expression.Block(sets), context);
            Do = lambda.Compile();
        }

        // Gets the T of IDbSet<T>, or null when the type does not implement IDbSet<>.
        private static Type GetIDbSetTypeArgument(Type type)
        {
            IEnumerable<Type> interfaces = type.IsInterface ?
                new[] { type }.Concat(type.GetInterfaces()) :
                type.GetInterfaces();

            Type argument = (from x in interfaces
                             where x.IsGenericType
                             let gt = x.GetGenericTypeDefinition()
                             where gt == typeof(IDbSet<>)
                             select x.GetGenericArguments()[0]).SingleOrDefault();
            return argument;
        }
    }

    /// <summary>
    /// Expression rewriter handed to every proxy. First argument: true for Execute,
    /// false for CreateQuery. May be null.
    /// </summary>
    public readonly Func<bool, Expression, Expression> Manipulator;

    /// <summary>
    /// Creates the context with the default <see cref="DbContext"/> constructor.
    /// </summary>
    /// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
    /// <param name="resetSets">True to have all the DbSet&lt;TEntity&gt; and IDbSet&lt;TEntity&gt; proxified</param>
    public ProxyDbContext(Func<bool, Expression, Expression> manipulator, bool resetSets = true)
    {
        Manipulator = manipulator;
        if (resetSets)
        {
            // ProxifySets<TContext> needs the runtime type of the derived context.
            ProxifySetsMethod.MakeGenericMethod(GetType()).Invoke(this, null);
        }
    }

    /// <summary>
    /// Creates the context for the given connection name or connection string.
    /// </summary>
    /// <param name="nameOrConnectionString">Passed through to the <see cref="DbContext"/> constructor.</param>
    /// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
    /// <param name="resetSets">True to have all the DbSet&lt;TEntity&gt; and IDbSet&lt;TEntity&gt; proxified</param>
    public ProxyDbContext(string nameOrConnectionString, Func<bool, Expression, Expression> manipulator, bool resetSets = true)
        : base(nameOrConnectionString)
    {
        Manipulator = manipulator;
        if (resetSets)
        {
            // ProxifySets<TContext> needs the runtime type of the derived context.
            ProxifySetsMethod.MakeGenericMethod(GetType()).Invoke(this, null);
        }
    }

    /// <summary>
    /// Runs the compiled set-wrapping delegate for <typeparamref name="TContext"/> on this instance.
    /// </summary>
    protected void ProxifySets<TContext>() where TContext : ProxyDbContext
    {
        ProxyDbContexSetter<TContext>.Do((TContext)this);
    }

    /// <summary>Returns the base set for <typeparamref name="TEntity"/> wrapped in a proxy.</summary>
    public override DbSet<TEntity> Set<TEntity>()
    {
        return new ProxyDbSet<TEntity>(base.Set<TEntity>(), Manipulator);
    }

    /// <summary>Returns the base non-generic set for <paramref name="entityType"/> wrapped in a proxy.</summary>
    public override DbSet Set(Type entityType)
    {
        DbSet set = base.Set(entityType);

        // The proxy type is generic, so it has to be closed and constructed through reflection.
        ConstructorInfo constructor = typeof(ProxyDbSetNonGeneric<>)
            .MakeGenericType(entityType)
            .GetConstructor(new[]
            {
                typeof(DbSet),
                typeof(Func<bool, Expression, Expression>)
            });
        return (DbSet)constructor.Invoke(new object[] { set, Manipulator });
    }
}
/// <summary>
/// Proxy for the non-generic <see cref="DbSet"/>: data operations (Add, Find, Remove, ...)
/// are forwarded to the wrapped set, while the query surface (IQueryable / IEnumerable /
/// IDbAsyncEnumerable) is served by a proxy queryable whose provider is a
/// <see cref="ProxyDbProvider"/>.
/// Original author's note: the DbSet is implemented as InternalDbSet&lt;&gt; by EF.
/// </summary>
/// <typeparam name="TEntity">Entity type exposed through the generic query interfaces.</typeparam>
public class ProxyDbSetNonGeneric<TEntity> : DbSet, IQueryable<TEntity>, IEnumerable<TEntity>, IDbAsyncEnumerable<TEntity>, IQueryable, IEnumerable, IDbAsyncEnumerable where TEntity : class
{
// The wrapped set that receives all forwarded data operations.
protected readonly DbSet BaseDbSet;
// Queryable returned through the IQueryable members of this class.
protected readonly IQueryable<TEntity> ProxyQueryable;
// Expression rewriter; first argument: true for Execute, false for CreateQuery.
public readonly Func<bool, Expression, Expression> Manipulator;
// Private "_internalSet" field of DbSet, looked up by name; null when no such field exists.
// NOTE(review): relies on an EF implementation detail - confirm against the EF version in use.
protected readonly FieldInfo InternalSetField = typeof(DbSet).GetField("_internalSet", BindingFlags.Instance | BindingFlags.NonPublic);
/// <summary>
/// Wraps <paramref name="baseDbSet"/>, creating a new <see cref="ProxyDbProvider"/> around its
/// provider and building the proxy queryable from the set's own expression.
/// </summary>
/// <param name="baseDbSet">The set to wrap.</param>
/// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
public ProxyDbSetNonGeneric(DbSet baseDbSet, Func<bool, Expression, Expression> manipulator)
{
BaseDbSet = baseDbSet;
IQueryProvider provider = ((IQueryable)baseDbSet).Provider;
ProxyDbProvider proxyDbProvider = new ProxyDbProvider(provider, manipulator);
ProxyQueryable = proxyDbProvider.CreateQuery<TEntity>(((IQueryable)baseDbSet).Expression);
Manipulator = manipulator;
// Copy the wrapped set's "_internalSet" value into this instance when the field was found.
if (InternalSetField != null)
{
InternalSetField.SetValue(this, InternalSetField.GetValue(baseDbSet));
}
}
/// <summary>
/// Wraps <paramref name="baseDbSet"/> using an already built proxy queryable.
/// </summary>
/// <param name="baseDbSet">The set that receives forwarded data operations.</param>
/// <param name="proxyQueryable">The queryable exposed through the IQueryable members.</param>
/// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
public ProxyDbSetNonGeneric(DbSet baseDbSet, ProxyQueryable<TEntity> proxyQueryable, Func<bool, Expression, Expression> manipulator)
{
BaseDbSet = baseDbSet;
ProxyQueryable = proxyQueryable;
Manipulator = manipulator;
// Copy the wrapped set's "_internalSet" value into this instance when the field was found.
if (InternalSetField != null)
{
InternalSetField.SetValue(this, InternalSetField.GetValue(baseDbSet));
}
}
/// <summary>Forwards to the wrapped set.</summary>
public override object Add(object entity)
{
return BaseDbSet.Add(entity);
}
/// <summary>Forwards to the wrapped set.</summary>
public override IEnumerable AddRange(IEnumerable entities)
{
return BaseDbSet.AddRange(entities);
}
/// <summary>
/// Returns a new proxy over the same wrapped set whose queryable is built from the wrapped
/// set's <c>AsNoTracking()</c> result, reusing the current provider and manipulator.
/// </summary>
public override DbQuery AsNoTracking()
{
return new ProxyDbSetNonGeneric<TEntity>(BaseDbSet, new ProxyQueryable<TEntity>((ProxyDbProvider)ProxyQueryable.Provider, (IQueryable<TEntity>)BaseDbSet.AsNoTracking()), Manipulator);
}
/// <summary>
/// Returns a new proxy over the same wrapped set whose queryable is built from the wrapped
/// set's <c>AsStreaming()</c> result, reusing the current provider and manipulator.
/// </summary>
[Obsolete]
public override DbQuery AsStreaming()
{
// 618 = "member is obsolete"; suppressed because the wrapped call is itself obsolete.
#pragma warning disable 618
return new ProxyDbSetNonGeneric<TEntity>(BaseDbSet, new ProxyQueryable<TEntity>((ProxyDbProvider)ProxyQueryable.Provider, (IQueryable<TEntity>)BaseDbSet.AsStreaming()), Manipulator);
#pragma warning restore 618
}
/// <summary>Forwards to the wrapped set.</summary>
public override object Attach(object entity)
{
return BaseDbSet.Attach(entity);
}
/// <summary>Forwards to the wrapped set.</summary>
public override object Create(Type derivedEntityType)
{
return BaseDbSet.Create(derivedEntityType);
}
/// <summary>Forwards to the wrapped set.</summary>
public override object Create()
{
return BaseDbSet.Create();
}
/// <summary>Forwards to the wrapped set.</summary>
public override object Find(params object[] keyValues)
{
return BaseDbSet.Find(keyValues);
}
/// <summary>Forwards to the wrapped set.</summary>
public override Task<object> FindAsync(CancellationToken cancellationToken, params object[] keyValues)
{
return BaseDbSet.FindAsync(cancellationToken, keyValues);
}
/// <summary>Forwards to the wrapped set.</summary>
public override Task<object> FindAsync(params object[] keyValues)
{
return BaseDbSet.FindAsync(keyValues);
}
/// <summary>
/// Returns a new proxy over the same wrapped set whose queryable is built from the wrapped
/// set's <c>Include(path)</c> result, reusing the current provider and manipulator.
/// </summary>
public override DbQuery Include(string path)
{
return new ProxyDbSetNonGeneric<TEntity>(BaseDbSet, new ProxyQueryable<TEntity>((ProxyDbProvider)ProxyQueryable.Provider, (IQueryable<TEntity>)BaseDbSet.Include(path)), Manipulator);
}
/// <summary>Forwards to the wrapped set.</summary>
public override IList Local
{
get
{
return BaseDbSet.Local;
}
}
/// <summary>Forwards to the wrapped set.</summary>
public override object Remove(object entity)
{
return BaseDbSet.Remove(entity);
}
/// <summary>Forwards to the wrapped set.</summary>
public override IEnumerable RemoveRange(IEnumerable entities)
{
return BaseDbSet.RemoveRange(entities);
}
/// <summary>Forwards to the wrapped set.</summary>
public override DbSqlQuery SqlQuery(string sql, params object[] parameters)
{
return BaseDbSet.SqlQuery(sql, parameters);
}
// The enumeration and IQueryable members below are all served by ProxyQueryable,
// so LINQ operators applied to this set go through the proxy provider.
IEnumerator<TEntity> IEnumerable<TEntity>.GetEnumerator()
{
return ProxyQueryable.GetEnumerator();
}
IEnumerator IEnumerable.GetEnumerator()
{
return ((IEnumerable)ProxyQueryable).GetEnumerator();
}
Type IQueryable.ElementType
{
get { return ProxyQueryable.ElementType; }
}
Expression IQueryable.Expression
{
get { return ProxyQueryable.Expression; }
}
IQueryProvider IQueryable.Provider
{
get { return ProxyQueryable.Provider; }
}
// The casts below throw InvalidCastException if ProxyQueryable does not implement the async interface.
IDbAsyncEnumerator<TEntity> IDbAsyncEnumerable<TEntity>.GetAsyncEnumerator()
{
return ((IDbAsyncEnumerable<TEntity>)ProxyQueryable).GetAsyncEnumerator();
}
IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator()
{
return ((IDbAsyncEnumerable)ProxyQueryable).GetAsyncEnumerator();
}
/// <summary>Returns the string representation of the proxy queryable.</summary>
public override string ToString()
{
return ProxyQueryable.ToString();
}
}
/// <summary>
/// Proxy for <c>DbSet&lt;TEntity&gt;</c>: data operations (Add, Find, Remove, ...) are forwarded
/// to the wrapped set, while the query surface (IQueryable / IEnumerable / IDbAsyncEnumerable)
/// is served by a proxy queryable whose provider is a <see cref="ProxyDbProvider"/>.
/// </summary>
/// <typeparam name="TEntity">Entity type of the set.</typeparam>
public class ProxyDbSet<TEntity> : DbSet<TEntity>, IQueryable<TEntity>, IEnumerable<TEntity>, IDbAsyncEnumerable<TEntity>, IQueryable, IEnumerable, IDbAsyncEnumerable where TEntity : class
{
// The wrapped set that receives all forwarded data operations.
protected readonly DbSet<TEntity> BaseDbSet;
// Queryable returned through the IQueryable members of this class.
protected readonly IQueryable<TEntity> ProxyQueryable;
// Expression rewriter; first argument: true for Execute, false for CreateQuery.
public readonly Func<bool, Expression, Expression> Manipulator;
// Private "_internalSet" field of DbSet<TEntity>, looked up by name; null when no such field exists.
// NOTE(review): relies on an EF implementation detail - confirm against the EF version in use.
protected readonly FieldInfo InternalSetField = typeof(DbSet<TEntity>).GetField("_internalSet", BindingFlags.Instance | BindingFlags.NonPublic);
/// <summary>
/// Wraps <paramref name="baseDbSet"/>, creating a new <see cref="ProxyDbProvider"/> around its
/// provider and building the proxy queryable from the set's own expression.
/// </summary>
/// <param name="baseDbSet">The set to wrap.</param>
/// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
public ProxyDbSet(DbSet<TEntity> baseDbSet, Func<bool, Expression, Expression> manipulator)
{
BaseDbSet = baseDbSet;
IQueryProvider provider = ((IQueryable)baseDbSet).Provider;
ProxyDbProvider proxyDbProvider = new ProxyDbProvider(provider, manipulator);
ProxyQueryable = proxyDbProvider.CreateQuery<TEntity>(((IQueryable)baseDbSet).Expression);
Manipulator = manipulator;
// Copy the wrapped set's "_internalSet" value into this instance when the field was found.
if (InternalSetField != null)
{
InternalSetField.SetValue(this, InternalSetField.GetValue(baseDbSet));
}
}
/// <summary>
/// Wraps <paramref name="baseDbSet"/> using an already built proxy queryable.
/// </summary>
/// <param name="baseDbSet">The set that receives forwarded data operations.</param>
/// <param name="proxyQueryable">The queryable exposed through the IQueryable members.</param>
/// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
public ProxyDbSet(DbSet<TEntity> baseDbSet, ProxyQueryable<TEntity> proxyQueryable, Func<bool, Expression, Expression> manipulator)
{
BaseDbSet = baseDbSet;
ProxyQueryable = proxyQueryable;
Manipulator = manipulator;
// Copy the wrapped set's "_internalSet" value into this instance when the field was found.
if (InternalSetField != null)
{
InternalSetField.SetValue(this, InternalSetField.GetValue(baseDbSet));
}
}
/// <summary>Forwards to the wrapped set.</summary>
public override TEntity Add(TEntity entity)
{
return BaseDbSet.Add(entity);
}
/// <summary>Forwards to the wrapped set.</summary>
public override IEnumerable<TEntity> AddRange(IEnumerable<TEntity> entities)
{
return BaseDbSet.AddRange(entities);
}
/// <summary>
/// Returns a new proxy over the same wrapped set whose queryable is built from the wrapped
/// set's <c>AsNoTracking()</c> result, reusing the current provider and manipulator.
/// </summary>
public override DbQuery<TEntity> AsNoTracking()
{
return new ProxyDbSet<TEntity>(BaseDbSet, new ProxyQueryable<TEntity>((ProxyDbProvider)ProxyQueryable.Provider, BaseDbSet.AsNoTracking()), Manipulator);
}
/// <summary>
/// Returns a new proxy over the same wrapped set whose queryable is built from the wrapped
/// set's <c>AsStreaming()</c> result, reusing the current provider and manipulator.
/// </summary>
[Obsolete]
public override DbQuery<TEntity> AsStreaming()
{
// 618 = "member is obsolete"; suppressed because the wrapped call is itself obsolete.
#pragma warning disable 618
return new ProxyDbSet<TEntity>(BaseDbSet, new ProxyQueryable<TEntity>((ProxyDbProvider)ProxyQueryable.Provider, BaseDbSet.AsStreaming()), Manipulator);
#pragma warning restore 618
}
/// <summary>Forwards to the wrapped set.</summary>
public override TEntity Attach(TEntity entity)
{
return BaseDbSet.Attach(entity);
}
/// <summary>Forwards to the wrapped set.</summary>
public override TDerivedEntity Create<TDerivedEntity>()
{
return BaseDbSet.Create<TDerivedEntity>();
}
/// <summary>Forwards to the wrapped set.</summary>
public override TEntity Create()
{
return BaseDbSet.Create();
}
/// <summary>Forwards to the wrapped set.</summary>
public override TEntity Find(params object[] keyValues)
{
return BaseDbSet.Find(keyValues);
}
/// <summary>Forwards to the wrapped set.</summary>
public override Task<TEntity> FindAsync(CancellationToken cancellationToken, params object[] keyValues)
{
return BaseDbSet.FindAsync(cancellationToken, keyValues);
}
/// <summary>Forwards to the wrapped set.</summary>
public override Task<TEntity> FindAsync(params object[] keyValues)
{
return BaseDbSet.FindAsync(keyValues);
}
/// <summary>
/// Returns a new proxy over the same wrapped set whose queryable is built from the wrapped
/// set's <c>Include(path)</c> result, reusing the current provider and manipulator.
/// </summary>
public override DbQuery<TEntity> Include(string path)
{
return new ProxyDbSet<TEntity>(BaseDbSet, new ProxyQueryable<TEntity>((ProxyDbProvider)ProxyQueryable.Provider, BaseDbSet.Include(path)), Manipulator);
}
/// <summary>Forwards to the wrapped set.</summary>
public override ObservableCollection<TEntity> Local
{
get
{
return BaseDbSet.Local;
}
}
/// <summary>Forwards to the wrapped set.</summary>
public override TEntity Remove(TEntity entity)
{
return BaseDbSet.Remove(entity);
}
/// <summary>Forwards to the wrapped set.</summary>
public override IEnumerable<TEntity> RemoveRange(IEnumerable<TEntity> entities)
{
return BaseDbSet.RemoveRange(entities);
}
/// <summary>Forwards to the wrapped set.</summary>
public override DbSqlQuery<TEntity> SqlQuery(string sql, params object[] parameters)
{
return BaseDbSet.SqlQuery(sql, parameters);
}
// The enumeration and IQueryable members below are all served by ProxyQueryable,
// so LINQ operators applied to this set go through the proxy provider.
IEnumerator<TEntity> IEnumerable<TEntity>.GetEnumerator()
{
return ProxyQueryable.GetEnumerator();
}
IEnumerator IEnumerable.GetEnumerator()
{
return ((IEnumerable)ProxyQueryable).GetEnumerator();
}
Type IQueryable.ElementType
{
get { return ProxyQueryable.ElementType; }
}
Expression IQueryable.Expression
{
get { return ProxyQueryable.Expression; }
}
IQueryProvider IQueryable.Provider
{
get { return ProxyQueryable.Provider; }
}
// The casts below throw InvalidCastException if ProxyQueryable does not implement the async interface.
IDbAsyncEnumerator<TEntity> IDbAsyncEnumerable<TEntity>.GetAsyncEnumerator()
{
return ((IDbAsyncEnumerable<TEntity>)ProxyQueryable).GetAsyncEnumerator();
}
IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator()
{
return ((IDbAsyncEnumerable)ProxyQueryable).GetAsyncEnumerator();
}
/// <summary>Returns the string representation of the proxy queryable.</summary>
public override string ToString()
{
return ProxyQueryable.ToString();
}
// Conversion to the non-generic proxy: wraps the same underlying set with the same manipulator.
// Note that the operator isn't virtual! If you do:
// DbSet<Foo> foo = new ProxyDbSet<Foo>(...)
// DbSet foo2 = (DbSet)foo;
// then the cast binds to the static type DbSet<Foo>, not to this operator,
// and you'll have a non-proxied DbSet!
public static implicit operator ProxyDbSetNonGeneric<TEntity>(ProxyDbSet<TEntity> entry)
{
return new ProxyDbSetNonGeneric<TEntity>((DbSet)entry.BaseDbSet, entry.Manipulator);
}
}
/// <summary>
/// Query provider that sits in front of another <see cref="IQueryProvider"/>: every expression
/// is first offered to <see cref="Manipulator"/> and then handed to the wrapped provider.
/// Queries created by the wrapped provider are returned wrapped in proxy queryables, so that
/// further LINQ operators keep going through this class.
/// </summary>
public class ProxyDbProvider : IQueryProvider, IDbAsyncQueryProvider
{
    /// <summary>The wrapped provider that does the real work.</summary>
    protected readonly IQueryProvider BaseQueryProvider;

    /// <summary>Expression rewriter; first argument: true for Execute, false for CreateQuery. May be null.</summary>
    public readonly Func<bool, Expression, Expression> Manipulator;

    // Looked up by name through reflection; keep in sync with CreateQueryNonGenericToGeneric below.
    protected static readonly MethodInfo CreateQueryNonGenericToGenericMethod = typeof(ProxyDbProvider).GetMethod("CreateQueryNonGenericToGeneric", BindingFlags.Static | BindingFlags.NonPublic);

    /// <summary>
    /// Creates a proxy around <paramref name="baseQueryProvider"/>.
    /// </summary>
    /// <param name="baseQueryProvider">The provider to wrap.</param>
    /// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
    public ProxyDbProvider(IQueryProvider baseQueryProvider, Func<bool, Expression, Expression> manipulator)
    {
        BaseQueryProvider = baseQueryProvider;
        Manipulator = manipulator;
    }

    /// <summary>Manipulates the expression (executing = false), creates the query on the wrapped provider and wraps the result.</summary>
    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        IQueryable<TElement> inner = BaseQueryProvider.CreateQuery<TElement>(ApplyManipulator(false, expression));
        return new ProxyQueryable<TElement>(ProxyFor(inner.Provider), inner);
    }

    /// <summary>
    /// Non-generic counterpart of <see cref="CreateQuery{TElement}"/>. When the created query
    /// implements <c>IQueryable&lt;T&gt;</c>, the generic proxy is built through reflection;
    /// otherwise the non-generic proxy is returned.
    /// </summary>
    public IQueryable CreateQuery(Expression expression)
    {
        IQueryable inner = BaseQueryProvider.CreateQuery(ApplyManipulator(false, expression));
        ProxyDbProvider proxy = ProxyFor(inner.Provider);

        Type elementType = GetIQueryableTypeArgument(inner.GetType());
        if (elementType != null)
        {
            MethodInfo factory = CreateQueryNonGenericToGenericMethod.MakeGenericMethod(elementType);
            return (IQueryable)factory.Invoke(null, new object[] { proxy, inner });
        }

        return new ProxyQueryable(proxy, inner);
    }

    /// <summary>Strongly typed factory invoked through <see cref="CreateQueryNonGenericToGenericMethod"/>.</summary>
    protected static ProxyQueryable<TElement> CreateQueryNonGenericToGeneric<TElement>(ProxyDbProvider proxy, IQueryable<TElement> query)
    {
        return new ProxyQueryable<TElement>(proxy, query);
    }

    /// <summary>Manipulates the expression (executing = true) and executes it on the wrapped provider.</summary>
    public TResult Execute<TResult>(Expression expression)
    {
        return BaseQueryProvider.Execute<TResult>(ApplyManipulator(true, expression));
    }

    /// <summary>Manipulates the expression (executing = true) and executes it on the wrapped provider.</summary>
    public object Execute(Expression expression)
    {
        return BaseQueryProvider.Execute(ApplyManipulator(true, expression));
    }

    // Gets the T of IQueryable<T>, or null when the type does not implement IQueryable<>.
    protected static Type GetIQueryableTypeArgument(Type type)
    {
        IEnumerable<Type> candidates = type.GetInterfaces();
        if (type.IsInterface)
        {
            // GetInterfaces() does not include the type itself.
            candidates = new[] { type }.Concat(candidates);
        }

        return candidates
            .Where(candidate => candidate.IsGenericType && candidate.GetGenericTypeDefinition() == typeof(IQueryable<>))
            .Select(candidate => candidate.GetGenericArguments()[0])
            .FirstOrDefault();
    }

    /// <summary>Async execution; requires the wrapped provider to implement <see cref="IDbAsyncQueryProvider"/>.</summary>
    /// <exception cref="NotSupportedException">The wrapped provider is not an <see cref="IDbAsyncQueryProvider"/>.</exception>
    public Task<TResult> ExecuteAsync<TResult>(Expression expression, CancellationToken cancellationToken)
    {
        IDbAsyncQueryProvider asyncProvider = RequireAsyncProvider();
        return asyncProvider.ExecuteAsync<TResult>(ApplyManipulator(true, expression), cancellationToken);
    }

    /// <summary>Async execution; requires the wrapped provider to implement <see cref="IDbAsyncQueryProvider"/>.</summary>
    /// <exception cref="NotSupportedException">The wrapped provider is not an <see cref="IDbAsyncQueryProvider"/>.</exception>
    public Task<object> ExecuteAsync(Expression expression, CancellationToken cancellationToken)
    {
        IDbAsyncQueryProvider asyncProvider = RequireAsyncProvider();
        return asyncProvider.ExecuteAsync(ApplyManipulator(true, expression), cancellationToken);
    }

    // Runs the manipulator when one was supplied; otherwise returns the expression unchanged.
    private Expression ApplyManipulator(bool executing, Expression expression)
    {
        return Manipulator != null ? Manipulator(executing, expression) : expression;
    }

    // Reuses this proxy when the created query kept the wrapped provider; otherwise wraps the new provider.
    private ProxyDbProvider ProxyFor(IQueryProvider provider)
    {
        return provider == BaseQueryProvider ? this : new ProxyDbProvider(provider, Manipulator);
    }

    // Returns the wrapped provider as an async provider, or throws when it is not one.
    private IDbAsyncQueryProvider RequireAsyncProvider()
    {
        var asyncProvider = BaseQueryProvider as IDbAsyncQueryProvider;
        if (asyncProvider == null)
        {
            throw new NotSupportedException();
        }

        return asyncProvider;
    }
}
/// <summary>
/// Non-generic queryable wrapper: element type, expression, enumeration and
/// <see cref="ToString"/> come from the wrapped query, while <see cref="Provider"/> reports
/// the proxy provider so that further operators are routed through it.
/// </summary>
public class ProxyQueryable : IOrderedQueryable, IQueryable, IEnumerable, IDbAsyncEnumerable
{
    /// <summary>The provider reported to LINQ.</summary>
    protected readonly ProxyDbProvider ProxyDbProvider;

    /// <summary>The wrapped query.</summary>
    protected readonly IQueryable BaseQueryable;

    /// <summary>Wraps <paramref name="baseQueryable"/> behind <paramref name="proxyDbProvider"/>.</summary>
    public ProxyQueryable(ProxyDbProvider proxyDbProvider, IQueryable baseQueryable)
    {
        ProxyDbProvider = proxyDbProvider;
        BaseQueryable = baseQueryable;
    }

    /// <summary>Element type of the wrapped query.</summary>
    public Type ElementType
    {
        get { return BaseQueryable.ElementType; }
    }

    /// <summary>Expression of the wrapped query.</summary>
    public Expression Expression
    {
        get { return BaseQueryable.Expression; }
    }

    /// <summary>The proxy provider, not the wrapped query's own provider.</summary>
    public IQueryProvider Provider
    {
        get { return ProxyDbProvider; }
    }

    /// <summary>Enumerates the wrapped query directly.</summary>
    public IEnumerator GetEnumerator()
    {
        return BaseQueryable.GetEnumerator();
    }

    /// <summary>Async enumeration; available only when the wrapped query supports it.</summary>
    /// <exception cref="NotSupportedException">The wrapped query is not an <see cref="IDbAsyncEnumerable"/>.</exception>
    IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator()
    {
        var asyncSource = BaseQueryable as IDbAsyncEnumerable;
        if (asyncSource != null)
        {
            return asyncSource.GetAsyncEnumerator();
        }

        throw new NotSupportedException();
    }

    /// <summary>String representation of the wrapped query.</summary>
    public override string ToString()
    {
        return BaseQueryable.ToString();
    }
}
/// <summary>
/// Generic queryable wrapper: element type, expression, enumeration and
/// <see cref="ToString"/> come from the wrapped query, while <see cref="Provider"/> reports
/// the proxy provider so that further operators are routed through it.
/// </summary>
/// <typeparam name="TElement">Element type of the query.</typeparam>
public class ProxyQueryable<TElement> : IOrderedQueryable<TElement>, IQueryable<TElement>, IEnumerable<TElement>, IDbAsyncEnumerable<TElement>, IOrderedQueryable, IQueryable, IEnumerable, IDbAsyncEnumerable
{
    /// <summary>The provider reported to LINQ.</summary>
    protected readonly ProxyDbProvider ProxyDbProvider;

    /// <summary>The wrapped query.</summary>
    protected readonly IQueryable<TElement> BaseQueryable;

    /// <summary>Wraps <paramref name="baseQueryable"/> behind <paramref name="proxyDbProvider"/>.</summary>
    public ProxyQueryable(ProxyDbProvider proxyDbProvider, IQueryable<TElement> baseQueryable)
    {
        ProxyDbProvider = proxyDbProvider;
        BaseQueryable = baseQueryable;
    }

    /// <summary>Element type of the wrapped query.</summary>
    public Type ElementType
    {
        get { return BaseQueryable.ElementType; }
    }

    /// <summary>Expression of the wrapped query.</summary>
    public Expression Expression
    {
        get { return BaseQueryable.Expression; }
    }

    /// <summary>The proxy provider, not the wrapped query's own provider.</summary>
    public IQueryProvider Provider
    {
        get { return ProxyDbProvider; }
    }

    /// <summary>Enumerates the wrapped query directly.</summary>
    public IEnumerator<TElement> GetEnumerator()
    {
        return BaseQueryable.GetEnumerator();
    }

    IEnumerator IEnumerable.GetEnumerator()
    {
        IEnumerable untyped = BaseQueryable;
        return untyped.GetEnumerator();
    }

    /// <summary>Typed async enumeration; available only when the wrapped query supports it.</summary>
    /// <exception cref="NotSupportedException">The wrapped query is not an <c>IDbAsyncEnumerable&lt;TElement&gt;</c>.</exception>
    public IDbAsyncEnumerator<TElement> GetAsyncEnumerator()
    {
        var typedAsyncSource = BaseQueryable as IDbAsyncEnumerable<TElement>;
        if (typedAsyncSource != null)
        {
            return typedAsyncSource.GetAsyncEnumerator();
        }

        throw new NotSupportedException();
    }

    /// <summary>Untyped async enumeration; available only when the wrapped query supports it.</summary>
    /// <exception cref="NotSupportedException">The wrapped query is not an <see cref="IDbAsyncEnumerable"/>.</exception>
    IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator()
    {
        var asyncSource = BaseQueryable as IDbAsyncEnumerable;
        if (asyncSource != null)
        {
            return asyncSource.GetAsyncEnumerator();
        }

        throw new NotSupportedException();
    }

    /// <summary>String representation of the wrapped query.</summary>
    public override string ToString()
    {
        return BaseQueryable.ToString();
    }
}
}
Expression 操纵器的示例(它会把 .Where(x => something) 转换为 .Where(x => something && something)):
namespace My
{
using System.Linq;
using System.Linq.Expressions;
/// <summary>
/// Demonstration visitor: rewrites every <c>Queryable.Where(x =&gt; something)</c> into
/// <c>Queryable.Where(x =&gt; something &amp;&amp; something)</c>.
/// </summary>
public class MyExpressionManipulator : ExpressionVisitor
{
    /// <summary>
    /// Doubles the predicate of a two-argument <c>Queryable.Where</c> call, then lets the base
    /// visitor continue into the (possibly rewritten) call.
    /// </summary>
    protected override Expression VisitMethodCall(MethodCallExpression node)
    {
        LambdaExpression predicate = FindWherePredicate(node);

        // Important: the same query gets visited again at every step it is extended, e.g.
        //   ctx.Where(x => true)                          -> ctx.Where(x => true && true)
        //   ctx.Where(x => true && true).Where(x => true) -> only the second Where must change.
        // A predicate whose body is already an AndAlso is therefore treated as processed
        // and left alone.
        if (predicate != null && predicate.Body.NodeType != ExpressionType.AndAlso)
        {
            Expression[] rewrittenArguments = node.Arguments.ToArray();
            Expression doubledBody = Expression.AndAlso(predicate.Body, predicate.Body);
            rewrittenArguments[1] = Expression.Quote(Expression.Lambda(doubledBody, predicate.Parameters));
            node = Expression.Call(node.Object, node.Method, rewrittenArguments);
        }

        return base.VisitMethodCall(node);
    }

    // Returns the quoted lambda passed as second argument of a two-argument Queryable.Where
    // call, or null when the call has a different shape.
    private static LambdaExpression FindWherePredicate(MethodCallExpression call)
    {
        if (call.Method.DeclaringType != typeof(Queryable) || call.Method.Name != "Where" || call.Arguments.Count != 2)
        {
            return null;
        }

        Expression second = call.Arguments[1];
        if (second.NodeType != ExpressionType.Quote)
        {
            return null;
        }

        Expression quoted = ((UnaryExpression)second).Operand; // Expression.Quote
        if (quoted.NodeType != ExpressionType.Lambda)
        {
            return null;
        }

        return (LambdaExpression)quoted;
    }
}
}
现在...如何使用它?最好的方法是让您的上下文(本例中为 Model1)不从 DbContext 派生,而是从 ProxyDbContext 派生,如下所示:
public partial class Model1 : ProxyDbContext
{
public Model1()
: base("name=Model1", Manipulate)
{
}
/// <summary>
/// Callback passed to the base constructor: runs the expression through a new
/// <c>MyExpressionManipulator</c>.
/// </summary>
/// <param name="executing">true: the returned Expression will be executed directly, false: the returned expression will be returned as IQueryable&lt;&gt;.</param>
/// <param name="expression">The expression to rewrite.</param>
/// <returns>The expression produced by the visitor.</returns>
private static Expression Manipulate(bool executing, Expression expression)
{
    // See the note in MyExpressionManipulator about the same visitor running
    // several times over one query. Running the visitor only when
    // executing == true (and returning the expression untouched when
    // executing == false) would guarantee a single manipulation per expression.
    // As written, the visitor runs on every call, so an expression is
    // manipulated multiple times.
    var visitor = new MyExpressionManipulator();
    Expression manipulated = visitor.Visit(expression);
    return manipulated;
}
// Some tables
public virtual DbSet<Parent> Parent { get; set; }
public virtual IDbSet<Child> Child { get; set; }
然后就很透明了:
// Usage sketch. Assumes: class Model1 : ProxyDbContext {} (the context shown earlier).
using (var ctx = new Model1())
{
// An ordinary LINQ query against a set of the context; no extra call is needed here.
var res = ctx.Parent.Where(x => x.Id > 100);
// The query is automatically manipulated by your Manipulate method
}
另一种方法,无需从 ProxyDbContext 派生子类:
// Where Model1: class Model1 : DbContext {} (a plain context, NOT derived from ProxyDbContext)
using (var ctx = new Model1())
{
    // ProxyDbSet expects the two-argument manipulator shape (executing, expression);
    // the visitor's Visit method alone does not match it, so adapt it with a lambda.
    var visitor = new MyExpressionManipulator();
    Func<bool, Expression, Expression> manipulator = (executing, expression) => visitor.Visit(expression);

    ctx.Parent = new ProxyDbSet<Parent>(ctx.Parent, manipulator);
    // Child is declared as IDbSet<Child>, while ProxyDbSet wraps a DbSet<Child>: cast explicitly.
    ctx.Child = new ProxyDbSet<Child>((DbSet<Child>)ctx.Child, manipulator);

    // Your query
    var res = ctx.Parent.Where(x => x.Id > 100);
}
ProxyDbContext 会把您上下文中的 DbSet<> / IDbSet<> 替换为相应的 ProxyDbSet<>。
在第二个示例中,此操作是明确完成的,但请注意,您可以创建一个方法来执行此操作,或为您的上下文创建一个工厂(一种静态方法,returns 具有各种 DbSet<>
"proxied"),或者您可以将代理放在上下文的构造函数中(因为 DbSet<>
的 "original" 初始化发生在 DbContext
的构造函数中,并且你的上下文的构造函数的主体在这之后执行),或者你可以创建多个子class你的上下文,每个都有以不同方式代理的构造函数...
请注意,第一种方法(从 ProxyDbContext 派生)还会"修复" Set<>() / Set() 方法;否则您必须自己修复它们,即把 ProxyDbContext 中这两个方法的重写(override)代码复制到您的上下文中。
/// <summary>
/// Console driver for the first sample: fills a <c>SampleClass&lt;ClassString&gt;</c> with two
/// entries, builds a filtered query on top of <c>Query&lt;ClassString&gt;()</c> and writes the
/// query object to the console.
/// </summary>
class Program
{
    /// <summary>Entry point; <paramref name="args"/> is not used.</summary>
    static void Main(string[] args)
    {
        // The collection initializer calls ClassStrings.Add for each element.
        var sample = new SampleClass<ClassString>
        {
            ClassStrings =
            {
                new ClassString { Name1 = "1", Name2 = "1" },
                new ClassString { Name1 = "2", Name2 = "2" }
            }
        };

        // The Where clause is attached by the caller, outside of Query<T>().
        IQueryable<ClassString> source = sample.Query<ClassString>();
        IQueryable<ClassString> filtered = source.Where(s => s.Name1.Equals("2"));

        Console.WriteLine(filtered);
        // Wait for Enter before the process exits.
        Console.ReadLine();
    }
}
/// <summary>
/// Sample element type used by the question's first example: a plain object
/// carrying two string values.
/// </summary>
public class ClassString
{
/// <summary>First string value.</summary>
public string Name1 { get; set; }
/// <summary>Second string value.</summary>
public string Name2 { get; set; }
}
/// <summary>
/// Contract for a data source that exposes its content as a LINQ query.
/// </summary>
public interface ISampleQ
{
/// <summary>
/// Returns a queryable over elements of type <typeparamref name="T"/>.
/// The method takes no arguments, so any filter is applied by the caller on the returned query.
/// </summary>
/// <typeparam name="T">Element type; must be a reference type with a public parameterless constructor.</typeparam>
IQueryable<T> Query<T>() where T: class , new();
}
/// <summary>
/// In-memory <see cref="ISampleQ"/> implementation backed by a <see cref="List{T}"/> of
/// <typeparamref name="X"/>.
/// </summary>
/// <typeparam name="X">Type of the stored items.</typeparam>
public class SampleClass<X> : ISampleQ
{
    /// <summary>The backing list; callers add items to it directly.</summary>
    public List<X> ClassStrings { get; private set; }

    /// <summary>Creates an instance with an empty backing list.</summary>
    public SampleClass()
    {
        ClassStrings = new List<X>();
    }

    /// <summary>
    /// Exposes the backing list as an <see cref="IQueryable{T}"/>.
    /// </summary>
    /// <typeparam name="T">Requested element type.</typeparam>
    /// <returns>A queryable over the items currently held in <see cref="ClassStrings"/>.</returns>
    /// <exception cref="InvalidCastException">
    /// The backing <c>List&lt;X&gt;</c> cannot be viewed as <c>IEnumerable&lt;T&gt;</c>.
    /// </exception>
    public IQueryable<T> Query<T>() where T : class, new()
    {
        // A Where() added by the caller is not visible here: it is applied later,
        // on the object this method returns.
        IEnumerable<T> items = (IEnumerable<T>)ClassStrings;

        // AsQueryable() is the public entry point for wrapping a sequence; it replaces
        // constructing EnumerableQuery<T> by hand.
        return items.AsQueryable();
    }
}
我研究了 solution1、solution2 和 solution3,但它们似乎都不适用于我的问题。where 子句是在 class 外部、针对接口方法返回的 IQueryable 定义的,而调用 Query 方法时没有传入任何参数,那么如何在 Query 方法内部获取这个表达式?
目的,我想要检索并注入回目标(即作为 IQueryable 的 DBContext)。因为我们有一个像 ISampleQ 这样的通用接口。
添加了新的示例代码但场景相同:
/// <summary>
/// Console driver for the second sample: asks an <c>OracleDbContext</c> for a
/// <c>Person</c> query and attaches a name filter from the outside.
/// </summary>
internal class Program
{
    /// <summary>Entry point; <paramref name="args"/> is not used.</summary>
    private static void Main(string[] args)
    {
        var oracle = new OracleDbContext();

        // The filter is defined here, outside of the context class.
        IQueryable<Person> people = oracle.Query<Person>();
        IQueryable<Person> matches = people.Where(p => p.Name.Equals("username"));

        // Writes an empty line; the filtered query itself is not printed.
        Console.WriteLine();
        Console.ReadLine();
    }
}
/// <summary>
/// Query contract shared by the database-specific contexts in this sample.
/// </summary>
public interface IGenericQuery
{
/// <summary>
/// Returns a queryable over elements of type <typeparamref name="T"/>.
/// The method takes no arguments, so any filter is applied by the caller on the returned query.
/// </summary>
/// <typeparam name="T">Element type; must be a reference type with a public parameterless constructor.</typeparam>
IQueryable<T> Query<T>() where T : class , new();
}
/// <summary>
/// Placeholder for the Oracle-backed implementation of <see cref="IGenericQuery"/>.
/// Only <see cref="Query{T}"/> is exposed in this sample.
/// </summary>
public class OracleDbContext : IGenericQuery
{
    /// <summary>Creates the context.</summary>
    public OracleDbContext()
    {
        //Will hold all oracle operations here. For brevity, only
        //Query are exposed.
    }

    /// <summary>
    /// Returns the query root for <typeparamref name="T"/>.
    /// </summary>
    /// <typeparam name="T">Element type.</typeparam>
    /// <returns>
    /// An empty queryable (never null), so that LINQ operators such as <c>Where</c>
    /// can be chained by the caller without throwing.
    /// </returns>
    public IQueryable<T> Query<T>() where T : class, new()
    {
        //Get the where predicate here. Since the where was defined outside of the
        //class. I want to retrieve since the IQueryable<T> is generic to both class
        //OracleDbContext and MssqlDbContext. I want to re-inject the where or add
        //new expression before calling.
        //
        //For eg.
        //oracleDbContext.Query<T>(where clause from here)

        // Returning null made the caller's .Where(...) fail with ArgumentNullException;
        // an empty sequence is a safe stand-in until a real source is wired in.
        return Enumerable.Empty<T>().AsQueryable();
    }
}
/// <summary>
/// Placeholder for the SQL Server-backed implementation of <see cref="IGenericQuery"/>.
/// Only <see cref="Query{T}"/> is exposed in this sample.
/// </summary>
public class MssqlDbContext : IGenericQuery
{
    /// <summary>Creates the context.</summary>
    public MssqlDbContext()
    {
        //Will hold all MSSQL operations here. For brevity, only
        //Query are exposed.
    }

    /// <summary>
    /// Returns the query root for <typeparamref name="T"/>.
    /// </summary>
    /// <typeparam name="T">Element type.</typeparam>
    /// <returns>
    /// An empty queryable (never null), so that LINQ operators such as <c>Where</c>
    /// can be chained by the caller without throwing.
    /// </returns>
    public IQueryable<T> Query<T>() where T : class, new()
    {
        //Get the where predicate here.

        // Returning null made the caller's .Where(...) fail with ArgumentNullException;
        // an empty sequence is a safe stand-in until a real source is wired in.
        return Enumerable.Empty<T>().AsQueryable();
    }
}
/// <summary>
/// Sample entity queried through <c>IGenericQuery.Query&lt;Person&gt;()</c>.
/// </summary>
public class Person
{
    /// <summary>Numeric identifier.</summary>
    public int Id { get; set; }

    /// <summary>
    /// The person's name. Declared as <see cref="string"/> because the sample filters it with
    /// <c>Name.Equals("username")</c>; with the previous <c>int</c> type that comparison
    /// compiled but could never be true.
    /// </summary>
    public string Name { get; set; }
}
它相当复杂...现在...Queryable.Where()
以这种方式工作:
public static IQueryable<TSource> Where<TSource>(this IQueryable<TSource> source, Expression<Func<TSource, bool>> predicate)
{
return source.Provider.CreateQuery<TSource>(Expression.Call(null, ...
所以 Queryable.Where 会调用 source.Provider.CreateQuery(),后者返回一个新的 IQueryable<>。因此,如果您希望在 Where() 被添加时能够"看到"它(并对其进行操作),您就必须"成为" IQueryable<>.Provider,让被调用的是您自己的 CreateQuery();也就是说,您必须创建一个实现 IQueryProvider 的 class(可能还需要一个实现 IQueryable<T> 的 class)。
另一种(更简单的)方法是编写一个简单的查询 "converter":一个接受 IQueryable<> 并返回经过处理的 IQueryable<> 的方法:
var result = c.Query<ClassString>().Where(s => s.Name1.Equals("2")).FixMyQuery();
正如我所说,完整路线很长:
namespace Utilities
{
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Data.Entity;
using System.Data.Entity.Infrastructure;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
public class ProxyDbContext : DbContext
{
protected static readonly MethodInfo ProxifySetsMethod = typeof(ProxyDbContext).GetMethod("ProxifySets", BindingFlags.Instance | BindingFlags.NonPublic);
protected static class ProxyDbContexSetter<TContext> where TContext : ProxyDbContext
{
public static readonly Action<TContext> Do = x => { };
static ProxyDbContexSetter()
{
var properties = typeof(TContext).GetProperties(BindingFlags.Instance | BindingFlags.Public | BindingFlags.FlattenHierarchy);
ParameterExpression context = Expression.Parameter(typeof(TContext), "context");
FieldInfo manipulatorField = typeof(ProxyDbContext).GetField("Manipulator", BindingFlags.Instance | BindingFlags.Public);
Expression manipulator = Expression.Field(context, manipulatorField);
var sets = new List<Expression>();
foreach (PropertyInfo property in properties)
{
if (property.GetMethod == null)
{
continue;
}
MethodInfo setMethod = property.SetMethod;
if (setMethod != null && !setMethod.IsPublic)
{
continue;
}
Type type = property.PropertyType;
Type entityType = GetIDbSetTypeArgument(type);
if (entityType == null)
{
continue;
}
if (!type.IsAssignableFrom(typeof(DbSet<>).MakeGenericType(entityType)))
{
continue;
}
Type dbSetType = typeof(DbSet<>).MakeGenericType(entityType);
ConstructorInfo constructor = typeof(ProxyDbSet<>)
.MakeGenericType(entityType)
.GetConstructor(new[]
{
dbSetType,
typeof(Func<bool, Expression, Expression>)
});
MemberExpression property2 = Expression.Property(context, property);
BinaryExpression assign = Expression.Assign(property2, Expression.New(constructor, Expression.Convert(property2, dbSetType), manipulator));
sets.Add(assign);
}
Expression<Action<TContext>> lambda = Expression.Lambda<Action<TContext>>(Expression.Block(sets), context);
Do = lambda.Compile();
}
// Gets the T of IDbSetlt;T>
private static Type GetIDbSetTypeArgument(Type type)
{
IEnumerable<Type> interfaces = type.IsInterface ?
new[] { type }.Concat(type.GetInterfaces()) :
type.GetInterfaces();
Type argument = (from x in interfaces
where x.IsGenericType
let gt = x.GetGenericTypeDefinition()
where gt == typeof(IDbSet<>)
select x.GetGenericArguments()[0]).SingleOrDefault();
return argument;
}
}
/// <summary>
/// Expression callback handed to every proxy set created by this context.
/// </summary>
public readonly Func<bool, Expression, Expression> Manipulator;
/// <summary>
/// Creates the context through the base class's parameterless constructor and
/// optionally replaces its set properties with proxies.
/// </summary>
/// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
/// <param name="resetSets">True to have all the DbSet&lt;TEntity&gt; and IDbSet&lt;TEntity&gt; proxified</param>
public ProxyDbContext(Func<bool, Expression, Expression> manipulator, bool resetSets = true)
{
    Manipulator = manipulator;
    if (!resetSets)
    {
        return;
    }
    // ProxifySets<TContext> is closed over the run-time type of this
    // instance, so it has to be invoked through reflection.
    MethodInfo proxifySets = ProxifySetsMethod.MakeGenericMethod(GetType());
    proxifySets.Invoke(this, null);
}
/// <summary>
/// Creates the context for the given name or connection string and
/// optionally replaces its set properties with proxies.
/// </summary>
/// <param name="nameOrConnectionString">Passed unchanged to the base constructor.</param>
/// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
/// <param name="resetSets">True to have all the DbSet&lt;TEntity&gt; and IDbSet&lt;TEntity&gt; proxified</param>
public ProxyDbContext(string nameOrConnectionString, Func<bool, Expression, Expression> manipulator, bool resetSets = true)
    : base(nameOrConnectionString)
{
    Manipulator = manipulator;
    if (!resetSets)
    {
        return;
    }
    // ProxifySets<TContext> is closed over the run-time type of this
    // instance, so it has to be invoked through reflection.
    MethodInfo proxifySets = ProxifySetsMethod.MakeGenericMethod(GetType());
    proxifySets.Invoke(this, null);
}
// Runs the setter delegate built for TContext (ProxyDbContexSetter<TContext>.Do)
// against this instance, viewed as TContext.
protected void ProxifySets<TContext>() where TContext : ProxyDbContext
{
    TContext self = (TContext)this;
    ProxyDbContexSetter<TContext>.Do(self);
}
// Wraps the set returned by the base implementation in a ProxyDbSet<TEntity>
// that shares this context's Manipulator.
public override DbSet<TEntity> Set<TEntity>()
{
    DbSet<TEntity> inner = base.Set<TEntity>();
    return new ProxyDbSet<TEntity>(inner, Manipulator);
}
// Non-generic counterpart of Set<TEntity>(): wraps the base set in a
// ProxyDbSetNonGeneric<> closed over entityType, created through reflection
// because the entity type is only known at run time.
public override DbSet Set(Type entityType)
{
    // Ask the base class first, exactly as before, so that any failure for
    // this entity type surfaces before the proxy type is constructed.
    DbSet inner = base.Set(entityType);
    Type proxyType = typeof(ProxyDbSetNonGeneric<>).MakeGenericType(entityType);
    Type[] signature =
    {
        typeof(DbSet),
        typeof(Func<bool, Expression, Expression>)
    };
    ConstructorInfo proxyConstructor = proxyType.GetConstructor(signature);
    object proxy = proxyConstructor.Invoke(new object[] { inner, Manipulator });
    return (DbSet)proxy;
}
}
/// <summary>
/// The DbSet, that is implemented as InternalDbSet&lt;&gt; by EF.
/// Data operations are forwarded to the wrapped set; enumeration and the
/// IQueryable members go through a proxied queryable.
/// </summary>
/// <typeparam name="TEntity">Entity type of the wrapped set.</typeparam>
public class ProxyDbSetNonGeneric<TEntity> : DbSet, IQueryable<TEntity>, IEnumerable<TEntity>, IDbAsyncEnumerable<TEntity>, IQueryable, IEnumerable, IDbAsyncEnumerable where TEntity : class
{
    protected readonly DbSet BaseDbSet;
    protected readonly IQueryable<TEntity> ProxyQueryable;
    public readonly Func<bool, Expression, Expression> Manipulator;
    protected readonly FieldInfo InternalSetField = typeof(DbSet).GetField("_internalSet", BindingFlags.Instance | BindingFlags.NonPublic);

    /// <summary>
    /// Wraps a set, building the proxied queryable from the set's own
    /// provider and expression.
    /// </summary>
    /// <param name="baseDbSet">The set to wrap.</param>
    /// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
    public ProxyDbSetNonGeneric(DbSet baseDbSet, Func<bool, Expression, Expression> manipulator)
    {
        BaseDbSet = baseDbSet;
        Manipulator = manipulator;
        IQueryable source = baseDbSet;
        var proxyProvider = new ProxyDbProvider(source.Provider, manipulator);
        ProxyQueryable = proxyProvider.CreateQuery<TEntity>(source.Expression);
        ShareInternalSet(baseDbSet);
    }

    /// <summary>
    /// Wraps a set together with an already proxied queryable.
    /// </summary>
    /// <param name="baseDbSet">The set to wrap.</param>
    /// <param name="proxyQueryable">The proxied queryable used for querying.</param>
    /// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
    public ProxyDbSetNonGeneric(DbSet baseDbSet, ProxyQueryable<TEntity> proxyQueryable, Func<bool, Expression, Expression> manipulator)
    {
        BaseDbSet = baseDbSet;
        ProxyQueryable = proxyQueryable;
        Manipulator = manipulator;
        ShareInternalSet(baseDbSet);
    }

    // Copies the value of the private "_internalSet" field from the wrapped
    // set onto this instance, when reflection was able to find that field.
    private void ShareInternalSet(DbSet source)
    {
        if (InternalSetField == null)
        {
            return;
        }
        object internalSet = InternalSetField.GetValue(source);
        InternalSetField.SetValue(this, internalSet);
    }

    // Builds a new proxy over the same wrapped set, using the given query.
    private ProxyDbSetNonGeneric<TEntity> Wrap(ProxyDbProvider provider, IQueryable<TEntity> query)
    {
        return new ProxyDbSetNonGeneric<TEntity>(BaseDbSet, new ProxyQueryable<TEntity>(provider, query), Manipulator);
    }

    public override object Add(object entity)
    {
        return BaseDbSet.Add(entity);
    }

    public override IEnumerable AddRange(IEnumerable entities)
    {
        return BaseDbSet.AddRange(entities);
    }

    public override DbQuery AsNoTracking()
    {
        ProxyDbProvider provider = (ProxyDbProvider)ProxyQueryable.Provider;
        IQueryable<TEntity> untracked = (IQueryable<TEntity>)BaseDbSet.AsNoTracking();
        return Wrap(provider, untracked);
    }

    [Obsolete]
    public override DbQuery AsStreaming()
    {
#pragma warning disable 618
        ProxyDbProvider provider = (ProxyDbProvider)ProxyQueryable.Provider;
        IQueryable<TEntity> streaming = (IQueryable<TEntity>)BaseDbSet.AsStreaming();
        return Wrap(provider, streaming);
#pragma warning restore 618
    }

    public override object Attach(object entity)
    {
        return BaseDbSet.Attach(entity);
    }

    public override object Create(Type derivedEntityType)
    {
        return BaseDbSet.Create(derivedEntityType);
    }

    public override object Create()
    {
        return BaseDbSet.Create();
    }

    public override object Find(params object[] keyValues)
    {
        return BaseDbSet.Find(keyValues);
    }

    public override Task<object> FindAsync(CancellationToken cancellationToken, params object[] keyValues)
    {
        return BaseDbSet.FindAsync(cancellationToken, keyValues);
    }

    public override Task<object> FindAsync(params object[] keyValues)
    {
        return BaseDbSet.FindAsync(keyValues);
    }

    public override DbQuery Include(string path)
    {
        ProxyDbProvider provider = (ProxyDbProvider)ProxyQueryable.Provider;
        IQueryable<TEntity> included = (IQueryable<TEntity>)BaseDbSet.Include(path);
        return Wrap(provider, included);
    }

    public override IList Local
    {
        get { return BaseDbSet.Local; }
    }

    public override object Remove(object entity)
    {
        return BaseDbSet.Remove(entity);
    }

    public override IEnumerable RemoveRange(IEnumerable entities)
    {
        return BaseDbSet.RemoveRange(entities);
    }

    public override DbSqlQuery SqlQuery(string sql, params object[] parameters)
    {
        return BaseDbSet.SqlQuery(sql, parameters);
    }

    // Everything below is served by the proxied queryable rather than by the
    // wrapped set.
    IEnumerator<TEntity> IEnumerable<TEntity>.GetEnumerator()
    {
        return ProxyQueryable.GetEnumerator();
    }

    IEnumerator IEnumerable.GetEnumerator()
    {
        IEnumerable sequence = ProxyQueryable;
        return sequence.GetEnumerator();
    }

    Type IQueryable.ElementType
    {
        get { return ProxyQueryable.ElementType; }
    }

    Expression IQueryable.Expression
    {
        get { return ProxyQueryable.Expression; }
    }

    IQueryProvider IQueryable.Provider
    {
        get { return ProxyQueryable.Provider; }
    }

    IDbAsyncEnumerator<TEntity> IDbAsyncEnumerable<TEntity>.GetAsyncEnumerator()
    {
        var asyncSequence = (IDbAsyncEnumerable<TEntity>)ProxyQueryable;
        return asyncSequence.GetAsyncEnumerator();
    }

    IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator()
    {
        var asyncSequence = (IDbAsyncEnumerable)ProxyQueryable;
        return asyncSequence.GetAsyncEnumerator();
    }

    public override string ToString()
    {
        return ProxyQueryable.ToString();
    }
}
/// <summary>
/// Generic proxy over a DbSet&lt;TEntity&gt;. Data operations are forwarded
/// to the wrapped set; enumeration and the IQueryable members go through a
/// proxied queryable.
/// </summary>
/// <typeparam name="TEntity">Entity type of the wrapped set.</typeparam>
public class ProxyDbSet<TEntity> : DbSet<TEntity>, IQueryable<TEntity>, IEnumerable<TEntity>, IDbAsyncEnumerable<TEntity>, IQueryable, IEnumerable, IDbAsyncEnumerable where TEntity : class
{
    protected readonly DbSet<TEntity> BaseDbSet;
    protected readonly IQueryable<TEntity> ProxyQueryable;
    public readonly Func<bool, Expression, Expression> Manipulator;
    protected readonly FieldInfo InternalSetField = typeof(DbSet<TEntity>).GetField("_internalSet", BindingFlags.Instance | BindingFlags.NonPublic);

    /// <summary>
    /// Wraps a set, building the proxied queryable from the set's own
    /// provider and expression.
    /// </summary>
    /// <param name="baseDbSet">The set to wrap.</param>
    /// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
    public ProxyDbSet(DbSet<TEntity> baseDbSet, Func<bool, Expression, Expression> manipulator)
    {
        BaseDbSet = baseDbSet;
        Manipulator = manipulator;
        IQueryable source = baseDbSet;
        var proxyProvider = new ProxyDbProvider(source.Provider, manipulator);
        ProxyQueryable = proxyProvider.CreateQuery<TEntity>(source.Expression);
        ShareInternalSet(baseDbSet);
    }

    /// <summary>
    /// Wraps a set together with an already proxied queryable.
    /// </summary>
    /// <param name="baseDbSet">The set to wrap.</param>
    /// <param name="proxyQueryable">The proxied queryable used for querying.</param>
    /// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
    public ProxyDbSet(DbSet<TEntity> baseDbSet, ProxyQueryable<TEntity> proxyQueryable, Func<bool, Expression, Expression> manipulator)
    {
        BaseDbSet = baseDbSet;
        ProxyQueryable = proxyQueryable;
        Manipulator = manipulator;
        ShareInternalSet(baseDbSet);
    }

    // Copies the value of the private "_internalSet" field from the wrapped
    // set onto this instance, when reflection was able to find that field.
    private void ShareInternalSet(DbSet<TEntity> source)
    {
        if (InternalSetField == null)
        {
            return;
        }
        object internalSet = InternalSetField.GetValue(source);
        InternalSetField.SetValue(this, internalSet);
    }

    // Builds a new proxy over the same wrapped set, using the given query.
    private ProxyDbSet<TEntity> Wrap(ProxyDbProvider provider, IQueryable<TEntity> query)
    {
        return new ProxyDbSet<TEntity>(BaseDbSet, new ProxyQueryable<TEntity>(provider, query), Manipulator);
    }

    public override TEntity Add(TEntity entity)
    {
        return BaseDbSet.Add(entity);
    }

    public override IEnumerable<TEntity> AddRange(IEnumerable<TEntity> entities)
    {
        return BaseDbSet.AddRange(entities);
    }

    public override DbQuery<TEntity> AsNoTracking()
    {
        ProxyDbProvider provider = (ProxyDbProvider)ProxyQueryable.Provider;
        DbQuery<TEntity> untracked = BaseDbSet.AsNoTracking();
        return Wrap(provider, untracked);
    }

    [Obsolete]
    public override DbQuery<TEntity> AsStreaming()
    {
#pragma warning disable 618
        ProxyDbProvider provider = (ProxyDbProvider)ProxyQueryable.Provider;
        DbQuery<TEntity> streaming = BaseDbSet.AsStreaming();
        return Wrap(provider, streaming);
#pragma warning restore 618
    }

    public override TEntity Attach(TEntity entity)
    {
        return BaseDbSet.Attach(entity);
    }

    public override TDerivedEntity Create<TDerivedEntity>()
    {
        return BaseDbSet.Create<TDerivedEntity>();
    }

    public override TEntity Create()
    {
        return BaseDbSet.Create();
    }

    public override TEntity Find(params object[] keyValues)
    {
        return BaseDbSet.Find(keyValues);
    }

    public override Task<TEntity> FindAsync(CancellationToken cancellationToken, params object[] keyValues)
    {
        return BaseDbSet.FindAsync(cancellationToken, keyValues);
    }

    public override Task<TEntity> FindAsync(params object[] keyValues)
    {
        return BaseDbSet.FindAsync(keyValues);
    }

    public override DbQuery<TEntity> Include(string path)
    {
        ProxyDbProvider provider = (ProxyDbProvider)ProxyQueryable.Provider;
        DbQuery<TEntity> included = BaseDbSet.Include(path);
        return Wrap(provider, included);
    }

    public override ObservableCollection<TEntity> Local
    {
        get { return BaseDbSet.Local; }
    }

    public override TEntity Remove(TEntity entity)
    {
        return BaseDbSet.Remove(entity);
    }

    public override IEnumerable<TEntity> RemoveRange(IEnumerable<TEntity> entities)
    {
        return BaseDbSet.RemoveRange(entities);
    }

    public override DbSqlQuery<TEntity> SqlQuery(string sql, params object[] parameters)
    {
        return BaseDbSet.SqlQuery(sql, parameters);
    }

    // Everything below is served by the proxied queryable rather than by the
    // wrapped set.
    IEnumerator<TEntity> IEnumerable<TEntity>.GetEnumerator()
    {
        return ProxyQueryable.GetEnumerator();
    }

    IEnumerator IEnumerable.GetEnumerator()
    {
        IEnumerable sequence = ProxyQueryable;
        return sequence.GetEnumerator();
    }

    Type IQueryable.ElementType
    {
        get { return ProxyQueryable.ElementType; }
    }

    Expression IQueryable.Expression
    {
        get { return ProxyQueryable.Expression; }
    }

    IQueryProvider IQueryable.Provider
    {
        get { return ProxyQueryable.Provider; }
    }

    IDbAsyncEnumerator<TEntity> IDbAsyncEnumerable<TEntity>.GetAsyncEnumerator()
    {
        var asyncSequence = (IDbAsyncEnumerable<TEntity>)ProxyQueryable;
        return asyncSequence.GetAsyncEnumerator();
    }

    IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator()
    {
        var asyncSequence = (IDbAsyncEnumerable)ProxyQueryable;
        return asyncSequence.GetAsyncEnumerator();
    }

    public override string ToString()
    {
        return ProxyQueryable.ToString();
    }

    // Beware: conversion operators are not virtual. If you do:
    //   DbSet<Foo> foo = new ProxyDbSet<Foo>(...)
    //   DbSet foo2 = (DbSet)foo;
    // then foo2 is a non-proxied DbSet, because the conversion declared on
    // the static type DbSet<Foo> is the one that runs, not this one.
    public static implicit operator ProxyDbSetNonGeneric<TEntity>(ProxyDbSet<TEntity> entry)
    {
        DbSet nonGenericBase = (DbSet)entry.BaseDbSet;
        return new ProxyDbSetNonGeneric<TEntity>(nonGenericBase, entry.Manipulator);
    }
}
/// <summary>
/// Query provider that passes every expression through the manipulator (when
/// one is set) before handing it to the wrapped provider, and wraps every
/// query it creates in a proxy queryable.
/// </summary>
public class ProxyDbProvider : IQueryProvider, IDbAsyncQueryProvider
{
    protected readonly IQueryProvider BaseQueryProvider;
    public readonly Func<bool, Expression, Expression> Manipulator;
    protected static readonly MethodInfo CreateQueryNonGenericToGenericMethod = typeof(ProxyDbProvider).GetMethod("CreateQueryNonGenericToGeneric", BindingFlags.Static | BindingFlags.NonPublic);

    /// <summary>
    /// Wraps the given provider.
    /// </summary>
    /// <param name="baseQueryProvider">The provider that actually creates and executes queries.</param>
    /// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
    public ProxyDbProvider(IQueryProvider baseQueryProvider, Func<bool, Expression, Expression> manipulator)
    {
        BaseQueryProvider = baseQueryProvider;
        Manipulator = manipulator;
    }

    // Runs the manipulator on the expression, or returns the expression
    // untouched when no manipulator was supplied.
    private Expression Apply(bool executing, Expression expression)
    {
        if (Manipulator == null)
        {
            return expression;
        }
        return Manipulator(executing, expression);
    }

    // Reuses this proxy when the query kept the wrapped provider; otherwise
    // wraps the query's new provider with the same manipulator.
    private ProxyDbProvider ProxyFor(IQueryProvider provider)
    {
        if (provider == BaseQueryProvider)
        {
            return this;
        }
        return new ProxyDbProvider(provider, Manipulator);
    }

    // Returns the wrapped provider as an async provider, or throws
    // NotSupportedException when it is not one.
    private IDbAsyncQueryProvider RequireAsyncProvider()
    {
        var asyncProvider = BaseQueryProvider as IDbAsyncQueryProvider;
        if (asyncProvider != null)
        {
            return asyncProvider;
        }
        throw new NotSupportedException();
    }

    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        IQueryable<TElement> inner = BaseQueryProvider.CreateQuery<TElement>(Apply(false, expression));
        return new ProxyQueryable<TElement>(ProxyFor(inner.Provider), inner);
    }

    public IQueryable CreateQuery(Expression expression)
    {
        IQueryable inner = BaseQueryProvider.CreateQuery(Apply(false, expression));
        ProxyDbProvider proxy = ProxyFor(inner.Provider);
        Type elementType = GetIQueryableTypeArgument(inner.GetType());
        if (elementType != null)
        {
            // The query is an IQueryable<T>: build the generic proxy through
            // reflection, since T is only known at run time.
            MethodInfo factory = CreateQueryNonGenericToGenericMethod.MakeGenericMethod(elementType);
            return (IQueryable)factory.Invoke(null, new object[] { proxy, inner });
        }
        return new ProxyQueryable(proxy, inner);
    }

    protected static ProxyQueryable<TElement> CreateQueryNonGenericToGeneric<TElement>(ProxyDbProvider proxy, IQueryable<TElement> query)
    {
        return new ProxyQueryable<TElement>(proxy, query);
    }

    public TResult Execute<TResult>(Expression expression)
    {
        return BaseQueryProvider.Execute<TResult>(Apply(true, expression));
    }

    public object Execute(Expression expression)
    {
        return BaseQueryProvider.Execute(Apply(true, expression));
    }

    // Gets the T of IQueryable<T>: the element type of the first
    // IQueryable<T> the type is or implements, or null when there is none.
    protected static Type GetIQueryableTypeArgument(Type type)
    {
        IEnumerable<Type> candidates = type.GetInterfaces();
        if (type.IsInterface)
        {
            // GetInterfaces() does not list the type itself.
            candidates = new[] { type }.Concat(candidates);
        }
        return candidates
            .Where(candidate => candidate.IsGenericType)
            .Where(candidate => candidate.GetGenericTypeDefinition() == typeof(IQueryable<>))
            .Select(candidate => candidate.GetGenericArguments()[0])
            .FirstOrDefault();
    }

    public Task<TResult> ExecuteAsync<TResult>(Expression expression, CancellationToken cancellationToken)
    {
        // Check for async support first, so the manipulator is not run when
        // the call is going to fail anyway.
        IDbAsyncQueryProvider asyncProvider = RequireAsyncProvider();
        return asyncProvider.ExecuteAsync<TResult>(Apply(true, expression), cancellationToken);
    }

    public Task<object> ExecuteAsync(Expression expression, CancellationToken cancellationToken)
    {
        IDbAsyncQueryProvider asyncProvider = RequireAsyncProvider();
        return asyncProvider.ExecuteAsync(Apply(true, expression), cancellationToken);
    }
}
/// <summary>
/// Non-generic queryable that exposes the wrapped query's element type,
/// expression and enumerator, but reports the proxy provider as its provider.
/// </summary>
public class ProxyQueryable : IOrderedQueryable, IQueryable, IEnumerable, IDbAsyncEnumerable
{
    protected readonly ProxyDbProvider ProxyDbProvider;
    protected readonly IQueryable BaseQueryable;

    public ProxyQueryable(ProxyDbProvider proxyDbProvider, IQueryable baseQueryable)
    {
        BaseQueryable = baseQueryable;
        ProxyDbProvider = proxyDbProvider;
    }

    public Type ElementType
    {
        get { return BaseQueryable.ElementType; }
    }

    public Expression Expression
    {
        get { return BaseQueryable.Expression; }
    }

    // Returning the proxy here is what sends further query operators through
    // the proxy provider instead of the wrapped one.
    public IQueryProvider Provider
    {
        get { return ProxyDbProvider; }
    }

    public IEnumerator GetEnumerator()
    {
        return BaseQueryable.GetEnumerator();
    }

    // Throws NotSupportedException when the wrapped query cannot be
    // enumerated asynchronously.
    IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator()
    {
        var asyncSource = BaseQueryable as IDbAsyncEnumerable;
        if (asyncSource != null)
        {
            return asyncSource.GetAsyncEnumerator();
        }
        throw new NotSupportedException();
    }

    public override string ToString()
    {
        return BaseQueryable.ToString();
    }
}
/// <summary>
/// Generic queryable that exposes the wrapped query's element type,
/// expression and enumerators, but reports the proxy provider as its provider.
/// </summary>
/// <typeparam name="TElement">Element type of the wrapped query.</typeparam>
public class ProxyQueryable<TElement> : IOrderedQueryable<TElement>, IQueryable<TElement>, IEnumerable<TElement>, IDbAsyncEnumerable<TElement>, IOrderedQueryable, IQueryable, IEnumerable, IDbAsyncEnumerable
{
    protected readonly ProxyDbProvider ProxyDbProvider;
    protected readonly IQueryable<TElement> BaseQueryable;

    public ProxyQueryable(ProxyDbProvider proxyDbProvider, IQueryable<TElement> baseQueryable)
    {
        BaseQueryable = baseQueryable;
        ProxyDbProvider = proxyDbProvider;
    }

    public Type ElementType
    {
        get { return BaseQueryable.ElementType; }
    }

    public Expression Expression
    {
        get { return BaseQueryable.Expression; }
    }

    // Returning the proxy here is what sends further query operators through
    // the proxy provider instead of the wrapped one.
    public IQueryProvider Provider
    {
        get { return ProxyDbProvider; }
    }

    public IEnumerator<TElement> GetEnumerator()
    {
        return BaseQueryable.GetEnumerator();
    }

    IEnumerator IEnumerable.GetEnumerator()
    {
        IEnumerable sequence = BaseQueryable;
        return sequence.GetEnumerator();
    }

    // Throws NotSupportedException when the wrapped query cannot be
    // enumerated asynchronously.
    public IDbAsyncEnumerator<TElement> GetAsyncEnumerator()
    {
        var asyncSource = BaseQueryable as IDbAsyncEnumerable<TElement>;
        if (asyncSource != null)
        {
            return asyncSource.GetAsyncEnumerator();
        }
        throw new NotSupportedException();
    }

    IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator()
    {
        var asyncSource = BaseQueryable as IDbAsyncEnumerable;
        if (asyncSource != null)
        {
            return asyncSource.GetAsyncEnumerator();
        }
        throw new NotSupportedException();
    }

    public override string ToString()
    {
        return BaseQueryable.ToString();
    }
}
}
Expression 操纵器的示例(它会把 .Where(x => something) 转换为 .Where(x => something && something)):
namespace My
{
    using System.Linq;
    using System.Linq.Expressions;

    /// <summary>
    /// Sample visitor: turns every Queryable.Where(x => something) into
    /// Queryable.Where(x => something &amp;&amp; something).
    /// </summary>
    public class MyExpressionManipulator : ExpressionVisitor
    {
        protected override Expression VisitMethodCall(MethodCallExpression node)
        {
            LambdaExpression predicate = FindWherePredicate(node);

            // Important: the full expression is visited again at every step,
            // so take care not to rewrite the same predicate twice.
            // For a query like
            //   ctx.Where(x => true).Where(x => true).Select(x => 1)
            // the first visit sees
            //   ctx.Where(x => true)
            // and produces
            //   ctx.Where(x => true && true)
            // and the second visit sees
            //   ctx.Where(x => true && true).Where(x => true)
            // which must become
            //   ctx.Where(x => true && true).Where(x => true && true)
            // and not
            //   ctx.Where(x => (true && true) && (true && true)).Where(x => true && true)
            // A predicate whose body is already an AndAlso is therefore left alone.
            if (predicate != null && predicate.Body.NodeType != ExpressionType.AndAlso)
            {
                Expression[] rewritten = node.Arguments.ToArray();
                Expression doubled = Expression.AndAlso(predicate.Body, predicate.Body);
                rewritten[1] = Expression.Quote(Expression.Lambda(doubled, predicate.Parameters));
                node = Expression.Call(node.Object, node.Method, rewritten);
            }

            return base.VisitMethodCall(node);
        }

        // Returns the quoted lambda passed as second argument of a
        // two-argument Queryable.Where call, or null for any other call.
        private static LambdaExpression FindWherePredicate(MethodCallExpression call)
        {
            bool isQueryableWhere = call.Method.DeclaringType == typeof(Queryable)
                && call.Method.Name == "Where"
                && call.Arguments.Count == 2;
            if (!isQueryableWhere)
            {
                return null;
            }

            Expression second = call.Arguments[1];
            if (second.NodeType != ExpressionType.Quote)
            {
                return null;
            }

            Expression operand = ((UnaryExpression)second).Operand; // Expression.Quote
            if (operand.NodeType != ExpressionType.Lambda)
            {
                return null;
            }

            return (LambdaExpression)operand;
        }
    }
}
现在...如何使用它?最好的方法是不从 DbContext 而是从 ProxyDbContext 派生您的上下文(在本例中为 Model1),如下所示:
public partial class Model1 : ProxyDbContext
{
// Passes "name=Model1" as the name-or-connection-string and Manipulate as the
// expression manipulator to the ProxyDbContext constructor.
public Model1() : base("name=Model1", Manipulate)
{
}
/// <summary>
/// Manipulator passed to the base constructor: rewrites the expression with
/// a new MyExpressionManipulator on every call.
/// </summary>
/// <param name="executing">true: the returned Expression will be executed directly, false: the returned expression will be returned as IQueryable&lt;&gt;.</param>
/// <param name="expression">The expression to rewrite.</param>
/// <returns>The expression returned by the visitor.</returns>
private static Expression Manipulate(bool executing, Expression expression)
{
    // See the note in MyExpressionManipulator about running the same
    // visitor several times over the same expression.
    // Running the visitor only when executing == true, and simply
    // returning the expression when executing == false, would guarantee
    // that an expression is never manipulated more than once.
    // As written here, the visitor runs on every call, so an expression
    // will be manipulated multiple times.
    var visitor = new MyExpressionManipulator();
    return visitor.Visit(expression);
}
// Some tables
public virtual DbSet<Parent> Parent { get; set; }
public virtual IDbSet<Child> Child { get; set; }
然后就很透明了:
// Where Model1: class Model1 : ProxyDbContext {}
using (Model1 ctx = new Model1())
{
    // Your query
    IQueryable<Parent> res = ctx.Parent.Where(x => x.Id > 100);
    // The query is automatically manipulated by your Manipulate method
}
另一种方法无需从 ProxyDbContext 派生子类:
// Where Model1: class Model1 : DbContext {} (no ProxyDbContext subclassing here)
using (var ctx = new Model1())
{
    // ProxyDbSet<> expects a Func<bool, Expression, Expression>; the bool
    // (true for Execute, false for CreateQuery) is ignored by this manipulator.
    // A single visitor instance is shared by both sets.
    var visitor = new MyExpressionManipulator();
    Func<bool, Expression, Expression> manipulator = (executing, expression) => visitor.Visit(expression);
    // NOTE(review): the ProxyDbSet<> constructor takes a DbSet<>, so both
    // properties are assumed to be declared as DbSet<> in this Model1 - confirm.
    ctx.Parent = new ProxyDbSet<Parent>(ctx.Parent, manipulator);
    ctx.Child = new ProxyDbSet<Child>(ctx.Child, manipulator);
    // Your query
    var res = ctx.Parent.Where(x => x.Id > 100);
}
ProxyDbContext 会把您上下文中的 DbSet<>/IDbSet<> 替换为相应的 ProxyDbSet<>。
在第二个示例中,此操作是显式完成的,但请注意,您可以创建一个方法来执行此操作,或为您的上下文创建一个工厂(一个静态方法,返回其各个 DbSet<> 均已被代理("proxied")的上下文),或者您可以把代理操作放在上下文的构造函数中(因为 DbSet<> 的"原始"初始化发生在 DbContext 的构造函数中,而您的上下文构造函数的主体在此之后执行),或者您可以为上下文创建多个子类,每个子类的构造函数以不同的方式进行代理...
请注意,第一种方法(从 ProxyDbContext 派生子类)会"修复" Set<>/Set 方法;否则您必须参照 ProxyDbContext 自行修复这两个方法。