Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added open generic support for decorator func #81

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 39 additions & 9 deletions src/Scrutor/ServiceCollectionExtensions.Decoration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ public static bool TryDecorate(this IServiceCollection services, Type serviceTyp

if (serviceType.IsOpenGeneric() && decoratorType.IsOpenGeneric())
{
return services.TryDecorateOpenGeneric(serviceType, decoratorType);
var openTypeTryDecorator = OpenTypeTryDecorator(services, serviceType, decoratorType);
return services.TryDecorateOpenGeneric(serviceType, openTypeTryDecorator);
}

return services.TryDecorateDescriptors(serviceType, x => x.Decorate(decoratorType));
Expand Down Expand Up @@ -170,6 +171,11 @@ public static IServiceCollection Decorate(this IServiceCollection services, Type
Preconditions.NotNull(serviceType, nameof(serviceType));
Preconditions.NotNull(decorator, nameof(decorator));

if (serviceType.IsOpenGeneric())
{
return services.DecorateOpenGeneric(serviceType, decorator);
}

return services.DecorateDescriptors(serviceType, x => x.Decorate(decorator));
}

Expand Down Expand Up @@ -230,7 +236,8 @@ public static bool TryDecorate(this IServiceCollection services, Type serviceTyp

private static IServiceCollection DecorateOpenGeneric(this IServiceCollection services, Type serviceType, Type decoratorType)
{
if (services.TryDecorateOpenGeneric(serviceType, decoratorType))
var openTypeTryDecorator = OpenTypeTryDecorator(services, serviceType, decoratorType);
if (services.TryDecorateOpenGeneric(serviceType, openTypeTryDecorator))
{
return services;
}
Expand All @@ -243,16 +250,19 @@ private static bool IsSameGenericType(Type t1, Type t2)
return t1.IsGenericType && t2.IsGenericType && t1.GetGenericTypeDefinition() == t2.GetGenericTypeDefinition();
}

private static bool TryDecorateOpenGeneric(this IServiceCollection services, Type serviceType, Type decoratorType)
private static IServiceCollection DecorateOpenGeneric(this IServiceCollection services, Type serviceType, Func<object, IServiceProvider, object> decorator)
{
bool TryDecorate(Type[] typeArguments)
var openTypeTryDecorator = OpenTypeTryDecorator(services, serviceType, decorator);
if (services.TryDecorateOpenGeneric(serviceType, openTypeTryDecorator))
{
var closedServiceType = serviceType.MakeGenericType(typeArguments);
var closedDecoratorType = decoratorType.MakeGenericType(typeArguments);

return services.TryDecorateDescriptors(closedServiceType, x => x.Decorate(closedDecoratorType));
return services;
}

throw new MissingTypeRegistrationException(serviceType);
}

private static bool TryDecorateOpenGeneric(this IServiceCollection services, Type serviceType, Func<Type[], bool> openTypeTryDecorator)
{
var arguments = services
.Where(descriptor => IsSameGenericType(descriptor.ServiceType, serviceType))
.Select(descriptor => descriptor.ServiceType.GenericTypeArguments)
Expand All @@ -263,7 +273,27 @@ bool TryDecorate(Type[] typeArguments)
return false;
}

return arguments.Aggregate(true, (result, args) => result && TryDecorate(args));
return arguments.Aggregate(true, (result, args) => result && openTypeTryDecorator(args));
}

private static Func<Type[], bool> OpenTypeTryDecorator(IServiceCollection services, Type serviceType, Type decoratorType)
{
return typeArguments =>
{
var closedServiceType = serviceType.MakeGenericType(typeArguments);
var closedDecoratorType = decoratorType.MakeGenericType(typeArguments);

return services.TryDecorateDescriptors(closedServiceType, x => x.Decorate(closedDecoratorType));
};
}

private static Func<Type[], bool> OpenTypeTryDecorator(IServiceCollection services, Type serviceType, Func<object, IServiceProvider, object> decorator)
{
return typeArguments =>
{
var closedServiceType = serviceType.MakeGenericType(typeArguments);
return services.TryDecorateDescriptors(closedServiceType, x => x.Decorate(decorator));
};
}

private static IServiceCollection DecorateDescriptors(this IServiceCollection services, Type serviceType, Func<ServiceDescriptor, ServiceDescriptor> decorator)
Expand Down
38 changes: 38 additions & 0 deletions test/Scrutor.Tests/OpenGenericDecorationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,28 @@ public void CanDecorateOpenGenericTypeBasedOnInterface()
Assert.IsType<MyQueryHandler>(loggingDecorator.Inner);
}

[Fact]
public void CanDecorateOpenGenericTypeBasedOnInterfaceByDecoratorFunc()
{
var provider = ConfigureProvider(services =>
{
services.AddSingleton<IQueryHandler<MyQuery, MyResult>, MySpecialQueryHandler>();
services.Decorate(typeof(IQueryHandler<,>), (handlerObj, serviceProvider) =>
{
if (handlerObj is ISpecialInterface specialInterface)
{
specialInterface.InitSomeField();
}

return handlerObj;
});
});

var instance = provider.GetRequiredService<IQueryHandler<MyQuery, MyResult>>();
var myQueryHandler = Assert.IsType<MySpecialQueryHandler>(instance);
Assert.True(myQueryHandler.GetSomeField());
}

[Fact]
public void DecoratingNonRegisteredOpenGenericServiceThrows()
{
Expand Down Expand Up @@ -79,6 +101,22 @@ public void DecoratingOpenGenericTypeBasedOnGrandparentInterfaceDoesNotDecorateP
}
}

public interface ISpecialInterface
{
void InitSomeField();
}

public class MySpecialQueryHandler : QueryHandler<MyQuery, MyResult>, ISpecialInterface
{
private bool _someField = false;
public void InitSomeField()
{
_someField = true;
}

public bool GetSomeField() => _someField;
}

public class MyQuery { }

public class MyResult { }
Expand Down