// --------------------------------------------------------------------------------------------------------------------
//
// Copyright (c) by respective owners including Yahoo!, Microsoft, and
// individual contributors. All rights reserved. Released under a BSD
// license as described in the file LICENSE.
//
// --------------------------------------------------------------------------------------------------------------------

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.Contracts;
using System.IO;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Reflection.Emit;

namespace VW.Reflection
{
    /// <summary>
    /// Reflection helper to find methods on visitors.
    /// </summary>
    public static class ReflectionHelper
    {
        /// <summary>
        /// Name of the embedded strong-name key resource used to sign the generated serializer assembly.
        /// </summary>
        private const string SigningKeyResourceName = "VW.vw_key.snk";

        /// <summary>
        /// Compiles the supplied <paramref name="sourceExpression"/> to a callable function.
        /// </summary>
        /// <typeparam name="T">The resulting delegate type; must be one of the <c>System.Func&lt;...&gt;</c> types.</typeparam>
        /// <param name="sourceExpression">The source expression to be compiled.</param>
        /// <returns>A callable function, created as a delegate of type <typeparamref name="T"/>.</returns>
        /// <exception cref="ArgumentException"><typeparamref name="T"/> is not a <c>System.Func&lt;...&gt;</c> type.</exception>
        /// <exception cref="InvalidOperationException">The embedded signing key resource could not be found.</exception>
        /// <remarks>Can't constrain on Func (or would have to have 11 overloads) nor is it possible to constrain on delegate.</remarks>
        public static System.Delegate CompileToFunc<T>(this Expression<T> sourceExpression)
        {
            Contract.Requires(sourceExpression != null);

            // inspect T to be System.Func<...>; ordinal comparison keeps the check culture-independent
            var funcType = typeof(T);
            if (funcType.Namespace != "System" || !funcType.Name.StartsWith("Func`", StringComparison.Ordinal))
            {
                throw new ArgumentException("T must be one of the System.Func<...> type.");
            }

            // Func<T1, ..., TResult>: the last generic argument is the return type, the rest are parameters
            var genericArguments = funcType.GetGenericArguments();
            var returnType = genericArguments[genericArguments.Length - 1];
            var paramTypes = genericArguments.Take(genericArguments.Length - 1).ToArray();

            // sign serializer so we can get access to internal members.
            var asmName = new AssemblyName("VowpalWabbitSerializer")
            {
                KeyPair = LoadSigningKeyPair()
            };

            // RunAndSave is kept from the original implementation, which allowed saving the generated assembly while debugging.
            var dynAsm = AppDomain.CurrentDomain.DefineDynamicAssembly(asmName, AssemblyBuilderAccess.RunAndSave);

            // Create a dynamic module and a uniquely named type to host the compiled method
            var moduleBuilder = dynAsm.DefineDynamicModule("VowpalWabbitSerializerModule");
            var typeBuilder = moduleBuilder.DefineType("VowpalWabbitSerializer" + Guid.NewGuid().ToString().Replace('-', '_'));

            const string methodName = "Method";
            var methodBuilder = typeBuilder.DefineMethod(
                methodName,
                MethodAttributes.Public | MethodAttributes.Static,
                returnType,
                paramTypes);

            // compared to Compile this looks rather ugly, but there is a feature-bug
            // that adds a security check to every call of the Serialize method
            sourceExpression.CompileToMethod(methodBuilder);

            var dynType = typeBuilder.CreateType();

            return Delegate.CreateDelegate(typeof(T), dynType.GetMethod(methodName));
        }

        /// <summary>
        /// Loads the strong-name key pair embedded in this assembly.
        /// </summary>
        /// <exception cref="InvalidOperationException">The key resource is not embedded in this assembly.</exception>
        private static StrongNameKeyPair LoadSigningKeyPair()
        {
            using (var stream = typeof(ReflectionHelper).Assembly.GetManifestResourceStream(SigningKeyResourceName))
            {
                // GetManifestResourceStream returns null (rather than throwing) when the resource is missing
                if (stream == null)
                {
                    throw new InvalidOperationException("Embedded signing key resource '" + SigningKeyResourceName + "' not found.");
                }

                using (var memStream = new MemoryStream())
                {
                    stream.CopyTo(memStream, 1024);
                    return new StrongNameKeyPair(memStream.ToArray());
                }
            }
        }

        /// <summary>
        /// Finds the public instance method named <paramref name="name"/> on <paramref name="objectType"/> that best
        /// accepts arguments of <paramref name="parameterTypes"/>, specializing it if it is a generic method.
        /// TODO: replace me with Roslyn once it's released and just generate string code. This way the overload resolution is properly done.
        /// </summary>
        /// <param name="objectType">The type whose public instance methods are searched.</param>
        /// <param name="name">The method name.</param>
        /// <param name="parameterTypes">The argument types, in parameter order.</param>
        /// <returns>The best matching (closed) method, or <c>null</c> if no method matches.</returns>
        /// <exception cref="NotSupportedException">
        /// The best candidate is generic and its generic arguments are bound inconsistently or cannot be inferred.
        /// </exception>
        /// <remarks>This is a simple heuristic for overload resolution, not the full thing.</remarks>
        public static MethodInfo FindMethod(Type objectType, string name, params Type[] parameterTypes)
        {
            Contract.Requires(objectType != null);
            Contract.Requires(name != null);
            Contract.Requires(parameterTypes != null);

            // let's find the "best" match; candidates are ordered by
            // 1. summed distance (see Distance; lower = closer match) --> ascending
            // 2. # of generic parameter bindings. the less the better (the more specific we are) --> ascending
            // 3. # of interfaces implemented. the more the better (the more specific we are) --> descending
            var methods = from m in objectType.GetMethods(BindingFlags.Instance | BindingFlags.Public)
                          where m.Name == name
                          let parameters = m.GetParameters()
                          where parameters.Length == parameterTypes.Length
                          let output = parameterTypes.Zip(parameters, (valueType, methodParameter) => Distance(methodParameter.ParameterType, valueType)).ToArray()
                          where output.All(o => o != null)
                          let distance = output.Sum(o => o.Distance)
                          let interfacesImplemented = output.Sum(o => o.InterfacesImplemented)
                          let generics = output.Sum(o => o.GenericTypes.Count)
                          orderby distance, generics, interfacesImplemented descending
                          select new
                          {
                              Method = m,
                              GenericTypes = output.Select(o => o.GenericTypes)
                          };

            var bestCandidate = methods.FirstOrDefault();
            if (bestCandidate == null)
            {
                return null;
            }

            var method = bestCandidate.Method;
            if (!method.IsGenericMethod)
            {
                return method;
            }

            // generic parameter -> concrete type bindings collected from every parameter match
            var mergedGenericTypes = bestCandidate.GenericTypes.SelectMany(d => d).ToLookup(kvp => kvp.Key, kvp => kvp.Value);

            // consistency check: each generic parameter must be bound to a single concrete type
            foreach (var gt in mergedGenericTypes)
            {
                var refElem = gt.First();
                if (gt.Any(t => t != refElem))
                {
                    throw new NotSupportedException("Inconsistent generic argument mapping: " + string.Join(",", gt));
                }
            }

            // map generic arguments to actual argument
            var genericArguments = method.GetGenericArguments();
            var actualTypes = new Type[genericArguments.Length];
            for (var i = 0; i < genericArguments.Length; i++)
            {
                var bindings = mergedGenericTypes[genericArguments[i]];
                if (!bindings.Any())
                {
                    // e.g. the generic parameter only occurs in the return type
                    throw new NotSupportedException(
                        "Unable to infer generic argument '" + genericArguments[i] + "' of method '" + method + "' from the supplied parameter types.");
                }

                actualTypes[i] = bindings.First();
            }

            return method.MakeGenericMethod(actualTypes);
        }

        /// <summary>
        /// Scores how well a value of <paramref name="valueType"/> matches a parameter declared as <paramref name="candidate"/>.
        /// </summary>
        /// <param name="candidate">The declared parameter type (may be or contain generic parameters).</param>
        /// <param name="valueType">The type of the supplied argument.</param>
        /// <returns>
        /// A match with distance 0 when the types are identical, 1 when <paramref name="candidate"/> is assignable from
        /// <paramref name="valueType"/>, 2 when <paramref name="candidate"/> is a generic parameter whose constraints
        /// <paramref name="valueType"/> satisfies; for constructed generic candidates, the best match among the interfaces
        /// and base types of <paramref name="valueType"/>. <c>null</c> when there is no match.
        /// </returns>
        internal static TypeMatch Distance(Type candidate, Type valueType)
        {
            if (candidate == valueType)
            {
                return new TypeMatch(0) { InterfacesImplemented = candidate.GetInterfaces().Length };
            }

            if (candidate.IsAssignableFrom(valueType))
            {
                return new TypeMatch(1) { InterfacesImplemented = candidate.GetInterfaces().Length };
            }

            if (candidate.IsGenericParameter
                && candidate.GetGenericParameterConstraints().All(c => c.IsAssignableFrom(valueType))
                && SatisfiesSpecialConstraints(candidate, valueType))
            {
                return new TypeMatch(2, candidate, valueType) { InterfacesImplemented = candidate.GetInterfaces().Length };
            }

            if (!candidate.IsGenericType)
            {
                return null;
            }

            // try to find a match that is assignable...
            var genericCandidate = candidate.GetGenericTypeDefinition();
            var bestMatches = from typeDistance in valueType.GetInterfaces().Select(it => new TypeDistance { Distance = 1, Type = it })
                                  .Union(GetBaseTypes(valueType))
                              let type = typeDistance.Type
                              where type.IsGenericType && type.GetGenericTypeDefinition() == genericCandidate
                              let distances = candidate.GetGenericArguments().Zip(type.GetGenericArguments(), (a, b) => Distance(a, b)).ToList()
                              where distances.All(d => d != null)
                              let output = new TypeMatch(typeDistance.Distance, distances)
                              {
                                  InterfacesImplemented = distances.Sum(d => d.InterfacesImplemented) + (candidate.IsInterface ? candidate.GetInterfaces().Length : 0)
                              }
                              orderby output.Distance, output.InterfacesImplemented descending, output.GenericTypes.Count
                              select output;

            return bestMatches.FirstOrDefault();
        }

        /// <summary>
        /// Checks the <c>class</c>, <c>struct</c> and <c>new()</c> constraints of <paramref name="genericParameter"/>,
        /// which are not reported by <see cref="Type.GetGenericParameterConstraints"/>. Without this check a violating
        /// candidate could be selected and later fail in <see cref="MethodInfo.MakeGenericMethod"/>.
        /// </summary>
        private static bool SatisfiesSpecialConstraints(Type genericParameter, Type valueType)
        {
            // open generic parameters on the value side are left to the structural checks, as before
            if (valueType.IsGenericParameter)
            {
                return true;
            }

            var attributes = genericParameter.GenericParameterAttributes & GenericParameterAttributes.SpecialConstraintMask;

            if ((attributes & GenericParameterAttributes.ReferenceTypeConstraint) != 0 && valueType.IsValueType)
            {
                return false;
            }

            if ((attributes & GenericParameterAttributes.NotNullableValueTypeConstraint) != 0
                && (!valueType.IsValueType || Nullable.GetUnderlyingType(valueType) != null))
            {
                return false;
            }

            if ((attributes & GenericParameterAttributes.DefaultConstructorConstraint) != 0
                && !valueType.IsValueType
                && (valueType.IsAbstract || valueType.GetConstructor(Type.EmptyTypes) == null))
            {
                return false;
            }

            return true;
        }

        /// <summary>
        /// Lazily enumerates <paramref name="type"/> and its base types (excluding <see cref="object"/>), tagging each with
        /// its inheritance distance starting at <paramref name="depth"/>.
        /// </summary>
        internal static IEnumerable<TypeDistance> GetBaseTypes(Type type, int depth = 0)
        {
            var distance = depth;
            for (var current = type; current != null && current != typeof(object); current = current.BaseType)
            {
                yield return new TypeDistance { Type = current, Distance = distance };
                distance++;
            }
        }

        /// <summary>
        /// Gets the member info in a sort of type safe manner - it's better than using strings, but some runtime errors are still possible.
        /// </summary>
        public static MemberInfo GetInfo<T, TResult>(Expression<Func<T, TResult>> expression)
        {
            Contract.Requires(expression != null);

            return GetInfo(expression.Body);
        }

        /// <summary>
        /// Gets the member info in a sort of type safe manner - it's better than using strings, but some runtime errors are still possible.
        /// </summary>
        public static MemberInfo GetInfo<T>(Expression<Action<T>> expression)
        {
            Contract.Requires(expression != null);

            return GetInfo(expression.Body);
        }

        /// <summary>
        /// Gets the member info in a sort of type safe manner - it's better than using strings, but some runtime errors are still possible.
        /// </summary>
        /// <param name="expression">A member access, method call, constructor call, or operator expression.</param>
        /// <returns>The referenced member, method, constructor, or operator method.</returns>
        /// <exception cref="NotSupportedException">The expression does not reference a member.</exception>
        public static MemberInfo GetInfo(Expression expression)
        {
            Contract.Requires(expression != null);

            var binaryExpression = expression as BinaryExpression;
            if (binaryExpression != null)
            {
                if (binaryExpression.Method != null)
                {
                    return binaryExpression.Method;
                }

                throw new NotSupportedException("Binary expression '" + expression + "' does not use a user-defined operator method.");
            }

            var memberExpression = expression as MemberExpression;
            if (memberExpression != null)
            {
                return memberExpression.Member;
            }

            var methodCallExpression = expression as MethodCallExpression;
            if (methodCallExpression != null)
            {
                return methodCallExpression.Method;
            }

            var newExpression = expression as NewExpression;
            if (newExpression != null)
            {
                return newExpression.Constructor;
            }

            var unaryExpression = expression as UnaryExpression;
            if (unaryExpression != null)
            {
                if (unaryExpression.Method != null)
                {
                    return unaryExpression.Method;
                }

                // conversions without an operator method (e.g. boxing a value-type member to object) wrap the member of interest
                if (unaryExpression.NodeType == ExpressionType.Convert || unaryExpression.NodeType == ExpressionType.ConvertChecked)
                {
                    return GetInfo(unaryExpression.Operand);
                }
            }

            throw new NotSupportedException("Unable to extract member info from expression: " + expression);
        }
    }
}