如何: LINQ 到 SQL 转换 






4.63/5 (10投票s)
一篇关于 LINQ to SQL 转换的文章。
引言
“.NET”框架的 v3.5 版本包含大量新增和增强的技术。LINQ(语言集成查询)在我看来是 v3.5 版本中最重要的新技术。微软实现了一套库,用于将 LINQ 表达式树转换为 SQL 语句,并将其命名为 DLINQ。DLINQ 是一个非常出色的作品,但遗憾的是,它仅适用于 SQL Server 2000 和 2005。
背景
本文(以及后续文章)的目标是演示如何将 LINQ 表达式树转换为 SQL 语句,这些语句可以针对多个 RDBMS 系统执行,而不仅仅是微软的 SQL Server 产品。我至少知道有另外一套出色的文章,尤其是 WaywardWeblog 上的这篇文章,演示了如何执行这种转换。我使用了 WaywardWeblog 文章中引入的两个组件,即部分求值器和表达式树遍历器。但是,在我上次检查时,这些文章并未演示如何
- 正确且全面地翻译具有有效 SQL 翻译的二元和一元表达式。
- 翻译具有 SQL 等效项的函数调用(例如 customer.FirstName.ToUpper())。
- 实现 GroupBy。
- 实现 IQueryable方法ANY、ALL、COUNT、AVERAGE等。
- 参数化查询,而不是在 SQL 转换中嵌入常量。
- 缓存先前翻译的表达式树。
- 可能不使用 MARS。
此外,我想以最简单、最直接的方式执行转换(当然,这在一定程度上是主观品味的问题,但我希望您在阅读完这些文章后会同意我的观点)。因此,您会发现这里采用的方法与您在其他地方可能找到的方法有显著差异。
绑定器 (The Binder)
创建 LINQ to SQL 翻译器是一项艰巨的任务,在一篇文章中无法涵盖。因此,在本文中,我将只讨论我的实现中使用的一个类——Binder。这个类将说明许多有趣的概念,但仍然足够易于理解,以免让读者“淹没”其中。
Binder 是一个类,它接收一个 DbDataReader 并将该读取器中的值赋给一个给定类的已实例化对象。
我听到一些人要窒息了,所以这里有一个例子来帮助您消化这个概念。
假设我们有一个 LINQ 查询,如下所示
var customers = from customer in customers
                where customer.City == city
                select new { Name = customer.ContactName, 
                             Phone = customer.Phone };
这将转换为以下 SQL 语句
SELECT t0.ContactName, t0.Phone
FROM dbo.Customers AS t0
WHERE (t0.City = @p0)
然后,我们需要创建一个命令,相应地填充参数集合(即,在这种情况下,为 city 参数提供一个值),然后执行该命令并检索一个包含两个字段的 DbDataReader:ContactName 和 Phone。
然后,Binder 将负责创建一个有两个属性(Name 和 Phone)的匿名类型,我们将分别从 DbDataReader 中检索到的值 ContactName 和 Phone 赋给它们。
血腥的细节
上面的 LINQ 查询将产生以下表达式
.Where(customer => (customer.City = value(LinqTest.NorthwindLinq+<>c__DisplayClass1).city))
.Select(customer => new <>f__AnonymousType0`2(Name = customer.ContactName, 
                                              Phone = customer.Phone))
为了 Binder 的目的,我们只对以下 Lambda 表达式感兴趣
customer => new <>f__AnonymousType0`2(Name = customer.ContactName,       
                                      Phone = customer.Phone)
这本质上意味着
- 给定一个名为 customer、类型为Customer的参数
- 创建一个类型为 <>f__AnonymousType0`2的实例
- 同时,将 customer.ContactName的值赋给Name,将customer.Phone的值赋给Phone
- 您就完成了。
customer => new <>f__AnonymousType0`2(Name = customer.ContactName, 
                                      Phone = customer.Phone)
customer => new <>f__AnonymousType0`2(Name = customer.ContactName, 
                                      Phone = customer.Phone)
customer => new <>f__AnonymousType0`2(Name = customer.ContactName, 
                                      Phone = customer.Phone)
听起来很简单,但像往常一样,细节决定成败。
您还记得上面提到的,为了 Binder 的目的,我们得到了一个包含两个字段的 DbDataReader:ContactName 和 Phone,也就是说,我们没有一个名为 customer、类型为 Customer 且带有 ContactName 和 Phone 两个属性的参数。那么,该怎么办?
有人告诉我,两点之间最短的路径是直线,那么为什么不修改上面的 Lambda 表达式,使其从名为 reader、类型为 DbDataReader 的参数中获取值呢?
换句话说,我们想把这个变成
customer => new <>f__AnonymousType0`2(Name = customer.ContactName, 
                                      Phone = customer.Phone)
变成这样
reader => new <>f__AnonymousType0`2(Name = reader.GetString(0), 
                                    Phone = reader.GetString(1))
如果您在问自己是否已经完成了,那么答案是“否”。
我们有三个问题(至少有这么多)
- 我们如何知道 reader.GetString(0)获得的是ContactName?
- 我们如何知道应该调用 reader.GetString(0)而不是reader.GetInt16(0)或其他reader.Getxxx方法?
- 如果我们调用 reader.Getxxx并且值为null,会发生什么?(答案:您会遇到错误。)
事实证明,第三个问题是最容易解决的。我们想要一个如下所示的 Lambda 表达式
reader => new <>f__AnonymousType0`2(Name = IIF(Not(reader.IsDBNull(0)), 
                                                   reader.GetString(0), Convert(null)), 
                                    Phone = IIF(Not(reader.IsDBNull(1)), 
                                                    reader.GetString(1), Convert(null)))
这虽然说起来有点长,但最终我们只是在说
IF NOT reader.IsDBNull(0)) Then
     Name = reader.GetString(0)
ELSE
         Name = NULL
END IF
同样,对于 phone 也是如此。
不幸的是,为了解决前两个问题,我们必须深入研究代码,这将在下一篇文章中讨论。对于那些等不及的读者,这里是 Binder 类的完整列表
private class Binder : ExpressionVisitor {
    private readonly LambdaExpression selector = null;
    private readonly LambdaExpression binderLambda = null;
    private readonly Delegate binderMethod = null;
    private readonly Dictionary<string,> columnPositions = new Dictionary<string,>();
    private readonly ParameterExpression reader = 
            Expression.Parameter(typeof(DbDataReader), "reader");
    private static readonly MethodInfo getBoolean = 
            typeof(DbDataReader).GetMethod("GetBoolean");
    private static readonly MethodInfo getByte = 
            typeof(DbDataReader).GetMethod("GetByte");
    private static readonly MethodInfo getChar = 
            typeof(DbDataReader).GetMethod("GetChar");
    private static readonly MethodInfo getDateTime = 
            typeof(DbDataReader).GetMethod("GetDateTime");
    private static readonly MethodInfo getDecimal = 
            typeof(DbDataReader).GetMethod("GetDecimal");
    private static readonly MethodInfo getDouble = 
            typeof(DbDataReader).GetMethod("GetDouble");
    private static readonly MethodInfo getGUID = 
            typeof(DbDataReader).GetMethod("GetGuid");
    private static readonly MethodInfo getInt16 = 
            typeof(DbDataReader).GetMethod("GetInt16");
    private static readonly MethodInfo getInt32 = 
            typeof(DbDataReader).GetMethod("GetInt32");
    private static readonly MethodInfo getInt64 = 
            typeof(DbDataReader).GetMethod("GetInt64");
    private static readonly MethodInfo getString = 
            typeof(DbDataReader).GetMethod("GetString");
    private static readonly MethodInfo getValue = 
            typeof(DbDataReader).GetMethod("GetValue");
    public Delegate BinderMethod {
        get {
            return binderMethod;
        }
    }
    public Binder(LambdaExpression selector) {
        this.selector = selector;
        if (selector.Body.NodeType != ExpressionType.Parameter) {
            binderLambda = Expression.Lambda(((LambdaExpression)this.Visit(selector)).Body,
                                          reader);
        }
        else {
            binderLambda = GetBindingLambda(selector);
        }
        binderMethod = binderLambda.Compile();
    }
    protected override Expression VisitMethodCall(MethodCallExpression m) {
        switch (m.Method.Name) {
            case "Count":
            case "Average":
            case "Max":
            case "Min":
            case "Sum":
                break;
            default:
                return base.VisitMethodCall(m);
        }
        Debug.Assert(m.Arguments.Count > 0);
        Debug.Assert(m.Arguments[0].NodeType == ExpressionType.MemberAccess);
        if (GetAccessedType(m.Arguments[0] as MemberExpression) != 
                                           selector.Parameters[0].Type) {
            return m;
        }
        int columnPosition = GetColumnPosition(m.ToString());
        return GetColumnReader(m, columnPosition);
    }
    protected override Expression VisitMemberAccess(MemberExpression m) {
        Debug.Assert(selector.Parameters.Count == 1);
        if (GetAccessedType(m) != selector.Parameters[0].Type) {
            return m;
        }
        int columnPosition = GetColumnPosition(m);
        return GetColumnReader(m, columnPosition);
    }
    private Expression GetColumnReader(Expression m, int columnPosition) {
        var column = Expression.Constant(columnPosition, typeof(int));
        var callExpression = GetCallMethod(m, column);
        var isDbNull = Expression.Call(reader,
                                       typeof(DbDataReader).GetMethod("IsDBNull"),
                                       column);
        var conditionalExpression =
            Expression.Condition(Expression.Not(isDbNull),
                                 callExpression,
                                 Expression.Convert(Expression.Constant(null),
                                                     callExpression.Type));
        return conditionalExpression;
    }
    private static Type GetAccessedType(MemberExpression m) {
        if (m.Expression.NodeType == ExpressionType.MemberAccess) {
            return GetAccessedType((MemberExpression)m.Expression);
        }
        return m.Expression.Type;
    }
    private Expression GetCallMethod(Expression m, ConstantExpression column) {
        MethodInfo getMethod = GetGetMethod(m);
        var callMethod = Expression.Call(reader, getMethod, column);
        if (getMethod.ReturnType == m.Type) {
            return callMethod;
        }
        return Expression.Convert(callMethod, m.Type);
    }
    private int GetColumnPosition(MemberExpression m) {
        return GetColumnPosition(m.Member.Name);
    }
    private int GetColumnPosition(string columnName) {
        int columnPosition = 0;
        if (columnPositions.ContainsKey(columnName)) {
            columnPosition = columnPositions[columnName];
            return columnPosition;
        }
        columnPosition = columnPositions.Count();
        columnPositions.Add(columnName, columnPosition);
        return columnPosition;
    }
    private static MethodInfo GetGetMethod(Expression m) {
        Type memberType = GetMemberType(m);
        MethodInfo getMethod = null;
        switch (Type.GetTypeCode(memberType)) {
            case TypeCode.Boolean:
                getMethod = getBoolean;
                break;
            case TypeCode.Byte:
                getMethod = getByte;
                break;
            case TypeCode.Char:
                getMethod = getChar;
                break;
            case TypeCode.DateTime:
                getMethod = getDateTime;
                break;
            case TypeCode.Decimal:
                getMethod = getDecimal;
                break;
            case TypeCode.Double:
                getMethod = getDouble;
                break;
            case TypeCode.Int16:
                getMethod = getInt16;
                break;
            case TypeCode.Int32:
                getMethod = getInt32;
                break;
            case TypeCode.Int64:
                getMethod = getInt64;
                break;
            case TypeCode.String:
                getMethod = getString;
                break;
            case TypeCode.Object:
                getMethod = getValue;
                break;
            default:
                if (m.Type == typeof(Guid)) {
                    getMethod = getGUID;
                }
                else {
                    getMethod = getValue;
                }
                break;
        }
        return getMethod;
    }
    private static Type GetMemberType(Expression m) {
        Type memberType = null;
        if (m.Type.Name == "Nullable`1") {
            memberType = m.Type.GetGenericArguments()[0];
        }
        else {
            memberType = m.Type;
        }
        return memberType;
    }
    private LambdaExpression GetBindingLambda(LambdaExpression selector) {
        var instanceType = selector.Body.Type;
        // this is a hack
        var properties = (from property in instanceType.GetProperties()
                          where property.PropertyType.IsValueType ||
                                property.PropertyType == typeof(string)
                          orderby property.Name
                          select instanceType.GetField("_" + property.Name,
                                                       BindingFlags.Instance |
                                                       BindingFlags.NonPublic))
                          .ToArray();
        var bindings = new MemberBinding[properties.Length];
        for (int i = 0; i < properties.Length; i++) {
            var callMethod = GetColumnReader(
                                Expression.MakeMemberAccess(
                                    Expression.Parameter(instanceType, "param"),
                                    properties[i]),
                                i);
            bindings[i] = Expression.Bind(properties[i], callMethod);
        }
        return Expression.Lambda(Expression.MemberInit(Expression.New(instanceType),
                                 bindings),
                                 reader);
    }
}
注意
我的网络连接慢如糖浆,所以我稍后会发布完整的 LINQ to SQL IQueryable Provider 的项目。

