// --------------------------------------------------------------------------------------------------------------------
//
// Copyright (c) by respective owners including Yahoo!, Microsoft, and
// individual contributors. All rights reserved. Released under a BSD
// license as described in the file LICENSE.
//
// --------------------------------------------------------------------------------------------------------------------
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.Contracts;
using System.IO;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Reflection.Emit;
namespace VW.Reflection
{
    /// <summary>
    /// Reflection helpers: finds methods on visitors using a simple overload-resolution heuristic,
    /// compiles expression trees into static methods of a signed dynamic assembly, and extracts the
    /// <see cref="MemberInfo"/> referenced by an expression.
    /// </summary>
    public static class ReflectionHelper
    {
        /// <summary>
        /// Compiles the supplied expression into a public static method of a freshly emitted,
        /// strong-name signed dynamic assembly and returns a delegate of type T bound to that method.
        /// </summary>
        /// <param name="sourceExpression">The source expression to be compiled.</param>
        /// <returns>A callable function.</returns>
        /// <exception cref="ArgumentException">T is not one of the System.Func types.</exception>
        /// <exception cref="InvalidOperationException">The embedded signing key resource is missing.</exception>
        /// <remarks>
        /// Can't constrain on Func (or would have to have 11 overloads), nor is it possible to constrain on delegate.
        /// Every call defines a new dynamic assembly.
        /// </remarks>
        public static System.Delegate CompileToFunc(this Expression sourceExpression)
        {
            // T must be System.Func<...>: all generic arguments but the last are the parameter
            // types of the emitted method, the last one is its return type.
            var funcType = typeof(T);
            if (!funcType.IsGenericType
                || funcType.Namespace != "System"
                || !funcType.Name.StartsWith("Func`", StringComparison.Ordinal))
            {
                throw new ArgumentException("T must be one of the System.Func<...> type.");
            }

            var genericArguments = funcType.GetGenericArguments();
            var returnType = genericArguments.Last();
            var parameterTypes = genericArguments.Take(genericArguments.Length - 1).ToArray();

            // Sign the serializer so the generated code can access internal members
            // (presumably granted via InternalsVisibleTo elsewhere in the project - not visible here).
            var asmName = new AssemblyName("VowpalWabbitSerializer")
            {
                KeyPair = LoadSigningKeyPair()
            };

            var assemblyBuilder = AppDomain.CurrentDomain.DefineDynamicAssembly(asmName, AssemblyBuilderAccess.RunAndSave);

            // Transient module (no file name). To inspect the generated code, define the module with a
            // file name instead, e.g. DefineDynamicModule(name, asmName.Name + ".dll", true), and call
            // Save on the assembly builder after CreateType.
            var moduleBuilder = assemblyBuilder.DefineDynamicModule("VowpalWabbitSerializerModule");
            var typeBuilder = moduleBuilder.DefineType("VowpalWabbitSerializer" + Guid.NewGuid().ToString().Replace('-', '_'));

            const string methodName = "Method";
            var methodBuilder = typeBuilder.DefineMethod(
                methodName,
                MethodAttributes.Public | MethodAttributes.Static,
                returnType,
                parameterTypes);

            // Compared to Compile this looks rather verbose, but Compile has a feature-bug that adds
            // a security check to every call of the resulting (Serialize) method.
            sourceExpression.CompileToMethod(methodBuilder);

            var dynType = typeBuilder.CreateType();
            return Delegate.CreateDelegate(typeof(T), dynType.GetMethod(methodName));
        }

        /// <summary>
        /// Reads the strong-name key pair embedded as the "VW.vw_key.snk" manifest resource of this assembly.
        /// </summary>
        /// <returns>The key pair used to sign generated assemblies.</returns>
        /// <exception cref="InvalidOperationException">The resource is not embedded in this assembly.</exception>
        private static StrongNameKeyPair LoadSigningKeyPair()
        {
            const string resourceName = "VW.vw_key.snk";

            using (var stream = typeof(ReflectionHelper).Assembly.GetManifestResourceStream(resourceName))
            {
                // GetManifestResourceStream returns null (rather than throwing) for a missing resource.
                if (stream == null)
                {
                    throw new InvalidOperationException(
                        "Embedded resource '" + resourceName + "' not found; it is required to sign the generated serializer assembly.");
                }

                using (var memStream = new MemoryStream())
                {
                    stream.CopyTo(memStream, 1024);
                    return new StrongNameKeyPair(memStream.ToArray());
                }
            }
        }

        /// <summary>
        /// Finds the public instance method called <paramref name="name"/> on <paramref name="objectType"/>
        /// that best accepts arguments of the given <paramref name="parameterTypes"/>. Generic methods are
        /// closed over the argument types inferred from the parameters.
        /// </summary>
        /// <param name="objectType">The type to search.</param>
        /// <param name="name">The method name.</param>
        /// <param name="parameterTypes">The argument types, in parameter order.</param>
        /// <returns>The best matching (and, if generic, constructed) method, or null if no method matches.</returns>
        /// <exception cref="NotSupportedException">A generic parameter is bound to different types by different arguments.</exception>
        /// <exception cref="InvalidOperationException">A generic parameter of the selected method cannot be inferred from the arguments.</exception>
        /// <remarks>
        /// This is a simple heuristic for overload resolution, not the full thing.
        /// TODO: replace me with Roslyn once it's released and just generate string code. This way the overload resolution is properly done.
        /// </remarks>
        public static MethodInfo FindMethod(Type objectType, string name, params Type[] parameterTypes)
        {
            Contract.Requires(objectType != null);
            Contract.Requires(name != null);
            Contract.Requires(parameterTypes != null);

            // Rank candidates, best first, by
            // 1. total distance as computed by Distance (lower = closer match) --> ascending
            // 2. total # of generic bindings. the fewer the better (the more specific we are) --> ascending
            // 3. total # of interfaces implemented. the more the better (the more specific we are) --> descending
            var candidates = from m in objectType.GetMethods(BindingFlags.Instance | BindingFlags.Public)
                             where m.Name == name
                             let parameters = m.GetParameters()
                             where parameters.Length == parameterTypes.Length
                             let matches = parameterTypes.Zip(parameters, (valueType, methodParameter) => Distance(methodParameter.ParameterType, valueType)).ToArray()
                             where matches.All(match => match != null)
                             let distance = matches.Sum(match => match.Distance)
                             let generics = matches.Sum(match => match.GenericTypes.Count)
                             let interfacesImplemented = matches.Sum(match => match.InterfacesImplemented)
                             orderby
                                 distance,
                                 generics,
                                 interfacesImplemented descending
                             select new
                             {
                                 Method = m,
                                 GenericTypes = matches.Select(match => match.GenericTypes)
                             };

            var bestCandidate = candidates.FirstOrDefault();
            if (bestCandidate == null)
            {
                return null;
            }

            var method = bestCandidate.Method;
            if (!method.IsGenericMethod)
            {
                return method;
            }

            // Merge the generic bindings recorded for each argument
            // (presumably generic parameter -> inferred type pairs) into one lookup.
            var mergedGenericTypes = bestCandidate.GenericTypes
                .SelectMany(bindings => bindings)
                .ToLookup(kvp => kvp.Key, kvp => kvp.Value);

            // consistency check: all arguments must bind a generic parameter to the same type
            foreach (var binding in mergedGenericTypes)
            {
                var refElem = binding.First();
                if (binding.Any(t => t != refElem))
                {
                    throw new NotSupportedException("Inconsistent generic argument mapping: " + string.Join(",", binding));
                }
            }

            // map generic arguments to actual argument types
            var genericParameters = method.GetGenericArguments();
            var actualTypes = new Type[genericParameters.Length];
            for (var i = 0; i < genericParameters.Length; i++)
            {
                var inferred = mergedGenericTypes[genericParameters[i]];
                if (!inferred.Any())
                {
                    // e.g. a generic parameter that does not occur in any parameter type
                    throw new InvalidOperationException(
                        "Unable to infer generic argument '" + genericParameters[i] + "' of method '" + method + "' from the supplied parameter types.");
                }

                actualTypes[i] = inferred.First();
            }

            return method.MakeGenericMethod(actualTypes);
        }

        /// <summary>
        /// Scores how well a value of type <paramref name="valueType"/> fits a parameter declared as
        /// <paramref name="candidate"/>; lower distances are better.
        /// </summary>
        /// <param name="candidate">The declared parameter type (possibly generic).</param>
        /// <param name="valueType">The type of the supplied value.</param>
        /// <returns>
        /// Distance 0 for identical types; 1 when <paramref name="candidate"/> is assignable from
        /// <paramref name="valueType"/>; 2 when <paramref name="candidate"/> is a generic parameter whose type
        /// constraints all accept <paramref name="valueType"/>. For constructed generic candidates, the best match
        /// among the interfaces (distance 1) and base types (inheritance depth, 0 for the type itself) of
        /// <paramref name="valueType"/> whose generic definition equals the candidate's, combined by TypeMatch
        /// with the matches of the generic arguments. Null if there is no match.
        /// </returns>
        internal static TypeMatch Distance(Type candidate, Type valueType)
        {
            if (candidate == valueType)
            {
                return new TypeMatch(0)
                {
                    InterfacesImplemented = candidate.GetInterfaces().Length
                };
            }

            if (candidate.IsAssignableFrom(valueType))
            {
                return new TypeMatch(1)
                {
                    InterfacesImplemented = candidate.GetInterfaces().Length
                };
            }

            // Only type constraints are checked here (no class/struct/new() constraints).
            if (candidate.IsGenericParameter && candidate.GetGenericParameterConstraints().All(c => c.IsAssignableFrom(valueType)))
            {
                return new TypeMatch(2, candidate, valueType)
                {
                    InterfacesImplemented = candidate.GetInterfaces().Length
                };
            }

            if (candidate.IsGenericType)
            {
                // try to find an interface or base type of valueType with the same generic definition
                // whose generic arguments all match the candidate's generic arguments
                var genericCandidate = candidate.GetGenericTypeDefinition();
                var bestMatches =
                    from typeDistance in valueType.GetInterfaces().Select(it => new TypeDistance { Distance = 1, Type = it })
                                                  .Union(GetBaseTypes(valueType))
                    let type = typeDistance.Type
                    where type.IsGenericType && type.GetGenericTypeDefinition() == genericCandidate
                    let distances = candidate.GetGenericArguments().Zip(type.GetGenericArguments(), (a, b) => Distance(a, b)).ToList()
                    where distances.All(d => d != null)
                    let output = new TypeMatch(typeDistance.Distance, distances)
                    {
                        InterfacesImplemented = distances.Sum(d => d.InterfacesImplemented)
                            + (candidate.IsInterface ? candidate.GetInterfaces().Length : 0)
                    }
                    orderby output.Distance, output.InterfacesImplemented descending, output.GenericTypes.Count
                    select output;

                return bestMatches.FirstOrDefault();
            }

            return null;
        }

        /// <summary>
        /// Lazily enumerates <paramref name="type"/> and its base classes, excluding <see cref="object"/>,
        /// each paired with its inheritance distance, starting at <paramref name="depth"/>.
        /// </summary>
        /// <param name="type">The most derived type; null yields nothing.</param>
        /// <param name="depth">The distance reported for <paramref name="type"/> itself.</param>
        internal static IEnumerable GetBaseTypes(Type type, int depth = 0)
        {
            // Iterative rather than recursive: a nested iterator per level re-yields every element
            // through each level, which is quadratic in the inheritance depth.
            for (var current = type; current != null && current != typeof(object); current = current.BaseType)
            {
                yield return new TypeDistance { Type = current, Distance = depth };
                depth++;
            }
        }

        /// <summary>
        /// Gets the member referenced by the body of the given lambda expression in a sort of type safe
        /// manner - it's better than using strings, but some runtime errors are still possible.
        /// </summary>
        /// <exception cref="NotSupportedException">The lambda body does not reference a supported member.</exception>
        public static MemberInfo GetInfo(Expression> expression)
        {
            Contract.Requires(expression != null);

            return GetInfo(expression.Body);
        }

        /// <summary>
        /// Gets the member referenced by the body of the given lambda expression in a sort of type safe
        /// manner - it's better than using strings, but some runtime errors are still possible.
        /// </summary>
        /// <exception cref="NotSupportedException">The lambda body does not reference a supported member.</exception>
        public static MemberInfo GetInfo(Expression> expression)
        {
            Contract.Requires(expression != null);

            return GetInfo(expression.Body);
        }

        /// <summary>
        /// Gets the member referenced by <paramref name="expression"/> in a sort of type safe manner - it's
        /// better than using strings, but some runtime errors are still possible.
        /// </summary>
        /// <returns>
        /// The operator method of a binary expression, the member of a member access, the method of a call,
        /// the constructor of a <c>new</c> expression, or the operator method of a unary expression.
        /// </returns>
        /// <exception cref="NotSupportedException">No member can be extracted from <paramref name="expression"/>.</exception>
        public static MemberInfo GetInfo(Expression expression)
        {
            Contract.Requires(expression != null);

            var binaryExpression = expression as BinaryExpression;
            if (binaryExpression != null)
            {
                // Method is null when the operation has no implementing method (e.g. arithmetic on primitive numeric types).
                if (binaryExpression.Method != null)
                {
                    return binaryExpression.Method;
                }

                throw new NotSupportedException("Binary expression '" + expression + "' does not reference an operator method.");
            }

            var memberExpression = expression as MemberExpression;
            if (memberExpression != null)
            {
                return memberExpression.Member;
            }

            var methodCallExpression = expression as MethodCallExpression;
            if (methodCallExpression != null)
            {
                return methodCallExpression.Method;
            }

            var newExpression = expression as NewExpression;
            if (newExpression != null)
            {
                return newExpression.Constructor;
            }

            var unaryExpression = expression as UnaryExpression;
            if (unaryExpression != null && unaryExpression.Method != null)
            {
                return unaryExpression.Method;
            }

            // string concatenation keeps this null-safe should the contract above be compiled out
            throw new NotSupportedException("Unable to extract member info from expression '" + expression + "'.");
        }
    }
}