diff --git a/src/Orleans.CodeGenerator/CodeGenerator.cs b/src/Orleans.CodeGenerator/CodeGenerator.cs index ff2f0847b2e..4149f241c0c 100644 --- a/src/Orleans.CodeGenerator/CodeGenerator.cs +++ b/src/Orleans.CodeGenerator/CodeGenerator.cs @@ -23,6 +23,7 @@ public class CodeGeneratorOptions public List ImmutableAttributes { get; } = new() { "Orleans.ImmutableAttribute" }; public List ConstructorAttributes { get; } = new() { "Orleans.OrleansConstructorAttribute", "Microsoft.Extensions.DependencyInjection.ActivatorUtilitiesConstructorAttribute" }; public GenerateFieldIds GenerateFieldIds { get; set; } + public bool GenerateCompatibilityInvokers { get; set; } } public class CodeGenerator @@ -643,7 +644,7 @@ internal InvokableMethodProxyBase GetProxyBase(INamedTypeSymbol interfaceType) return result; } - internal ProxyInterfaceDescription GetInvokableInterfaceDescription(INamedTypeSymbol proxyBaseType, INamedTypeSymbol interfaceType) + private ProxyInterfaceDescription GetInvokableInterfaceDescription(INamedTypeSymbol proxyBaseType, INamedTypeSymbol interfaceType) { var originalInterface = interfaceType.OriginalDefinition; if (MetadataModel.InvokableInterfaces.TryGetValue(originalInterface, out var description)) @@ -668,62 +669,56 @@ internal ProxyInterfaceDescription GetInvokableInterfaceDescription(INamedTypeSy return description; } - internal InvokableMethodDescription GetInvokableMethod(InvokableMethodId invokableId) + internal ProxyMethodDescription GetProxyMethodDescription(INamedTypeSymbol interfaceType, IMethodSymbol method, bool hasCollision) { - if (!_invokableMethodDescriptions.TryGetValue(invokableId, out var result)) - { - result = _invokableMethodDescriptions[invokableId] = InvokableMethodDescription.Create(invokableId); - } - - return result; - } + var originalMethod = method.OriginalDefinition; + var proxyBaseInfo = GetProxyBase(interfaceType); + var invokableId = new InvokableMethodId(proxyBaseInfo, originalMethod); + var interfaceDescription = GetInvokableInterfaceDescription(invokableId.ProxyBase.ProxyBaseType, interfaceType); - internal GeneratedInvokableDescription GetGeneratedInvokable(InvokableMethodId invokableId) - { // Get or generate an invokable for the original method definition. - if (MetadataModel.GeneratedInvokables.TryGetValue(invokableId, out var result)) + if (!MetadataModel.GeneratedInvokables.TryGetValue(invokableId, out var generatedInvokable)) { - return result; - } + if (!_invokableMethodDescriptions.TryGetValue(invokableId, out var methodDescription)) + { + methodDescription = _invokableMethodDescriptions[invokableId] = InvokableMethodDescription.Create(invokableId); + } - var methodDescription = GetInvokableMethod(invokableId); - result = MetadataModel.GeneratedInvokables[invokableId] = InvokableGenerator.Generate(methodDescription); + generatedInvokable = MetadataModel.GeneratedInvokables[invokableId] = InvokableGenerator.Generate(methodDescription); - if (Compilation.GetTypeByMetadataName(result.MetadataName) == null) - { - // Emit the generated code on-demand. - AddMember(result.GeneratedNamespace, result.ClassDeclarationSyntax); + if (Compilation.GetTypeByMetadataName(generatedInvokable.MetadataName) == null) + { + // Emit the generated code on-demand. + AddMember(generatedInvokable.GeneratedNamespace, generatedInvokable.ClassDeclarationSyntax); - // Ensure the type will have a serializer generated for it. - MetadataModel.SerializableTypes.Add(result); + // Ensure the type will have a serializer generated for it. + MetadataModel.SerializableTypes.Add(generatedInvokable); - foreach (var alias in result.CompoundTypeAliases) - { - MetadataModel.CompoundTypeAliases.Add(alias, result.OpenTypeSyntax); + foreach (var alias in generatedInvokable.CompoundTypeAliases) + { + MetadataModel.CompoundTypeAliases.Add(alias, generatedInvokable.OpenTypeSyntax); + } } } - return result; - } - - internal ProxyMethodDescription GetProxyMethodDescription(INamedTypeSymbol interfaceType, IMethodSymbol method, bool hasCollision) - { - var invokableId = GetInvokableMethodId(interfaceType, method); - return GetInterfaceMethodDescription(interfaceType, method, invokableId, hasCollision); - } + var proxyMethodDescription = ProxyMethodDescription.Create(interfaceDescription, generatedInvokable, method, hasCollision); - internal ProxyMethodDescription GetInterfaceMethodDescription(INamedTypeSymbol interfaceType, IMethodSymbol method, InvokableMethodId invokableId, bool hasCollision) - { - var interfaceDescription = GetInvokableInterfaceDescription(invokableId.ProxyBase.ProxyBaseType, interfaceType); - var generatedInvokable = GetGeneratedInvokable(invokableId); - return ProxyMethodDescription.Create(interfaceDescription, generatedInvokable, method, hasCollision); - } + // For backwards compatibility, generate invokers for the specific implementation types as well, where they differ. + if (Options.GenerateCompatibilityInvokers && !SymbolEqualityComparer.Default.Equals(method.OriginalDefinition.ContainingType, interfaceType)) + { + var compatInvokableId = new InvokableMethodId(proxyBaseInfo, method); + var compatMethodDescription = InvokableMethodDescription.Create(compatInvokableId, interfaceType); + var compatInvokable = InvokableGenerator.Generate(compatMethodDescription); + AddMember(compatInvokable.GeneratedNamespace, compatInvokable.ClassDeclarationSyntax); + var alias = + InvokableGenerator.GetCompoundTypeAliasComponents( + compatInvokableId, + interfaceType, + compatMethodDescription.GeneratedMethodId); + MetadataModel.CompoundTypeAliases.Add(alias, compatInvokable.OpenTypeSyntax); + } - internal InvokableMethodId GetInvokableMethodId(INamedTypeSymbol interfaceType, IMethodSymbol method) - { - var originalMethod = method.OriginalDefinition; - var proxyBaseInfo = GetProxyBase(interfaceType); - return new InvokableMethodId(proxyBaseInfo, originalMethod); + return proxyMethodDescription; } } } diff --git a/src/Orleans.CodeGenerator/InvokableGenerator.cs b/src/Orleans.CodeGenerator/InvokableGenerator.cs index 6c1ab912009..cbf97e5ceb0 100644 --- a/src/Orleans.CodeGenerator/InvokableGenerator.cs +++ b/src/Orleans.CodeGenerator/InvokableGenerator.cs @@ -34,7 +34,7 @@ public GeneratedInvokableDescription Generate(InvokableMethodDescription invokab var fields = GetFieldDeclarations(invokableMethodInfo, fieldDescriptions); var (ctor, ctorArgs) = GenerateConstructor(generatedClassName, invokableMethodInfo, baseClassType); var accessibility = GetAccessibility(method); - var compoundTypeAliases = GetCompoundTypeAliasAttributeArguments(invokableMethodInfo); + var compoundTypeAliases = GetCompoundTypeAliasAttributeArguments(invokableMethodInfo, invokableMethodInfo.Key); List serializationHooks = new(); if (baseClassType.GetAttributes(LibraryTypes.SerializationCallbacksAttribute, out var hookAttributes)) @@ -208,38 +208,35 @@ internal AttributeSyntax GetCompoundTypeAliasAttribute(CompoundTypeAliasComponen return Attribute(LibraryTypes.CompoundTypeAliasAttribute.ToNameSyntax()).AddArgumentListArguments(args); } - internal static List GetCompoundTypeAliasAttributeArguments(InvokableMethodDescription methodDescription) + internal static List GetCompoundTypeAliasAttributeArguments(InvokableMethodDescription methodDescription, InvokableMethodId invokableId) { var result = new List(2); - var proxyBaseComponents = methodDescription.Key.ProxyBase.CompositeAliasComponents; + var containingInterface = methodDescription.ContainingInterface; if (methodDescription.HasAlias) { - var alias = new CompoundTypeAliasComponent[1 + proxyBaseComponents.Length + 2]; - alias[0] = new("inv"); - for (var i = 0; i < proxyBaseComponents.Length; i++) - { - alias[i + 1] = proxyBaseComponents[i]; - } - - alias[1 + proxyBaseComponents.Length] = new(methodDescription.ContainingInterface); - alias[1 + proxyBaseComponents.Length + 1] = new(methodDescription.MethodId); - result.Add(alias); + result.Add(GetCompoundTypeAliasComponents(invokableId, containingInterface, methodDescription.MethodId)); } - { - var alias = new CompoundTypeAliasComponent[1 + proxyBaseComponents.Length + 2]; - alias[0] = new("inv"); - for (var i = 0; i < proxyBaseComponents.Length; i++) - { - alias[i + 1] = proxyBaseComponents[i]; - } + result.Add(GetCompoundTypeAliasComponents(invokableId, containingInterface, methodDescription.GeneratedMethodId)); + return result; + } - alias[1 + proxyBaseComponents.Length] = new(methodDescription.ContainingInterface); - alias[1 + proxyBaseComponents.Length + 1] = new(methodDescription.GeneratedMethodId); - result.Add(alias); + public static CompoundTypeAliasComponent[] GetCompoundTypeAliasComponents( + InvokableMethodId invokableId, + INamedTypeSymbol containingInterface, + string methodId) + { + var proxyBaseComponents = invokableId.ProxyBase.CompositeAliasComponents; + var alias = new CompoundTypeAliasComponent[1 + proxyBaseComponents.Length + 2]; + alias[0] = new("inv"); + for (var i = 0; i < proxyBaseComponents.Length; i++) + { + alias[i + 1] = proxyBaseComponents[i]; } - return result; + alias[1 + proxyBaseComponents.Length] = new(containingInterface); + alias[1 + proxyBaseComponents.Length + 1] = new(methodId); + return alias; } private INamedTypeSymbol GetBaseClassType(InvokableMethodDescription method) @@ -582,7 +579,7 @@ public static string GetSimpleClassName(InvokableMethodDescription method) var genericArity = method.AllTypeParameters.Count; var typeArgs = genericArity > 0 ? "_" + genericArity : string.Empty; var proxyKey = method.ProxyBase.Key.GeneratedClassNameComponent; - return $"Invokable_{method.Method.ContainingType.Name}_{proxyKey}_{method.GeneratedMethodId}{typeArgs}"; + return $"Invokable_{method.ContainingInterface.Name}_{proxyKey}_{method.GeneratedMethodId}{typeArgs}"; } private MemberDeclarationSyntax[] GetFieldDeclarations( diff --git a/src/Orleans.CodeGenerator/Model/GeneratedInvokableDescription.cs b/src/Orleans.CodeGenerator/Model/GeneratedInvokableDescription.cs index 23db9b9351d..66aa2ace16a 100644 --- a/src/Orleans.CodeGenerator/Model/GeneratedInvokableDescription.cs +++ b/src/Orleans.CodeGenerator/Model/GeneratedInvokableDescription.cs @@ -79,7 +79,7 @@ public GeneratedInvokableDescription( public bool IsExceptionType => false; public List ActivatorConstructorParameters { get; } public bool HasActivatorConstructor => UseActivator; - public List CompoundTypeAliases {get;} + public List CompoundTypeAliases { get; } public ClassDeclarationSyntax ClassDeclarationSyntax { get; } public string ReturnValueInitializerMethod { get; } diff --git a/src/Orleans.CodeGenerator/Model/InvokableMethodDescription.cs b/src/Orleans.CodeGenerator/Model/InvokableMethodDescription.cs index f2649fd2d4b..0cdcf514b81 100644 --- a/src/Orleans.CodeGenerator/Model/InvokableMethodDescription.cs +++ b/src/Orleans.CodeGenerator/Model/InvokableMethodDescription.cs @@ -14,11 +14,12 @@ namespace Orleans.CodeGenerator /// internal sealed class InvokableMethodDescription : IEquatable { - public static InvokableMethodDescription Create(InvokableMethodId method) => new(method); + public static InvokableMethodDescription Create(InvokableMethodId method, INamedTypeSymbol containingType = null) => new(method, containingType); - private InvokableMethodDescription(InvokableMethodId invokableId) + private InvokableMethodDescription(InvokableMethodId invokableId, INamedTypeSymbol containingType) { Key = invokableId; + ContainingInterface = containingType ?? invokableId.Method.ContainingType; GeneratedMethodId = CodeGenerator.CreateHashedMethodId(Method); MethodId = CodeGenerator.GetId(Method)?.ToString(CultureInfo.InvariantCulture) ?? CodeGenerator.GetAlias(Method) ?? GeneratedMethodId; @@ -98,7 +99,7 @@ private InvokableMethodDescription(InvokableMethodId invokableId) MethodTypeParameters = new List<(string Name, ITypeParameterSymbol Parameter)>(); var names = new HashSet(StringComparer.Ordinal); - foreach (var typeParameter in Method.ContainingType.GetAllTypeParameters()) + foreach (var typeParameter in ContainingInterface.GetAllTypeParameters()) { var tpName = GetTypeParameterName(names, typeParameter); AllTypeParameters.Add((tpName, typeParameter)); @@ -203,7 +204,7 @@ static bool TryGetNamedArgument(ImmutableArray /// Gets the interface which this type is contained in. /// - public INamedTypeSymbol ContainingInterface => Method.ContainingType; + public INamedTypeSymbol ContainingInterface { get; } public bool Equals(InvokableMethodDescription other) => Key.Equals(other.Key); public override bool Equals(object obj) => obj is InvokableMethodDescription imd && Equals(imd); diff --git a/src/Orleans.CodeGenerator/Model/InvokableMethodId.cs b/src/Orleans.CodeGenerator/Model/InvokableMethodId.cs index 82c8c9d2680..e2e4fe4cbea 100644 --- a/src/Orleans.CodeGenerator/Model/InvokableMethodId.cs +++ b/src/Orleans.CodeGenerator/Model/InvokableMethodId.cs @@ -10,11 +10,6 @@ namespace Orleans.CodeGenerator { public InvokableMethodId(InvokableMethodProxyBase proxyBaseInfo, IMethodSymbol method) { - if (!SymbolEqualityComparer.Default.Equals(method, method.OriginalDefinition)) - { - throw new ArgumentException("Method must be an original definition", nameof(method)); - } - ProxyBase = proxyBaseInfo; Method = method; } @@ -29,11 +24,9 @@ public InvokableMethodId(InvokableMethodProxyBase proxyBaseInfo, IMethodSymbol m /// public IMethodSymbol Method { get; } - public bool Equals(InvokableMethodId other) - { - return ProxyBase.Equals(other.ProxyBase) - && SymbolEqualityComparer.Default.Equals(Method, other.Method); - } + public bool Equals(InvokableMethodId other) => + ProxyBase.Equals(other.ProxyBase) + && SymbolEqualityComparer.Default.Equals(Method, other.Method); public override bool Equals(object obj) => obj is InvokableMethodId imd && Equals(imd); public override int GetHashCode() => ProxyBase.GetHashCode() * 17 ^ SymbolEqualityComparer.Default.GetHashCode(Method); diff --git a/src/Orleans.CodeGenerator/Model/ProxyMethodDescription.cs b/src/Orleans.CodeGenerator/Model/ProxyMethodDescription.cs index 81f325e7c08..a2e45e50885 100644 --- a/src/Orleans.CodeGenerator/Model/ProxyMethodDescription.cs +++ b/src/Orleans.CodeGenerator/Model/ProxyMethodDescription.cs @@ -156,8 +156,6 @@ public ConstructedGeneratedInvokableDescription(GeneratedInvokableDescription in public bool IsExceptionType => _invokableDescription.IsExceptionType; public List ActivatorConstructorParameters => _invokableDescription.ActivatorConstructorParameters; public bool HasActivatorConstructor => UseActivator; - public List CompoundTypeAliases => _invokableDescription.CompoundTypeAliases; - public ClassDeclarationSyntax ClassDeclarationSyntax => _invokableDescription.ClassDeclarationSyntax; public string ReturnValueInitializerMethod => _invokableDescription.ReturnValueInitializerMethod; public InvokableMethodDescription MethodDescription => _invokableDescription.MethodDescription; diff --git a/src/Orleans.CodeGenerator/OrleansSourceGenerator.cs b/src/Orleans.CodeGenerator/OrleansSourceGenerator.cs index 21668c07bdc..fe561191d1e 100644 --- a/src/Orleans.CodeGenerator/OrleansSourceGenerator.cs +++ b/src/Orleans.CodeGenerator/OrleansSourceGenerator.cs @@ -57,7 +57,15 @@ public void Execute(GeneratorExecutionContext context) if (context.AnalyzerConfigOptions.GlobalOptions.TryGetValue("build_property.orleans_generatefieldids", out var generateFieldIds) && generateFieldIds is { Length: > 0 }) { if (Enum.TryParse(generateFieldIds, out GenerateFieldIds fieldIdOption)) + { options.GenerateFieldIds = fieldIdOption; + } + } + + if (context.AnalyzerConfigOptions.GlobalOptions.TryGetValue("build_property.orleansgeneratecompatibilityinvokers", out var generateCompatInvokersValue) + && bool.TryParse(generateCompatInvokersValue, out var genCompatInvokers)) + { + options.GenerateCompatibilityInvokers = genCompatInvokers; } var codeGenerator = new CodeGenerator(context.Compilation, options); diff --git a/src/Orleans.CodeGenerator/build/Microsoft.Orleans.CodeGenerator.props b/src/Orleans.CodeGenerator/build/Microsoft.Orleans.CodeGenerator.props index 81a3f826d90..e72ffc3cec2 100644 --- a/src/Orleans.CodeGenerator/build/Microsoft.Orleans.CodeGenerator.props +++ b/src/Orleans.CodeGenerator/build/Microsoft.Orleans.CodeGenerator.props @@ -9,6 +9,7 @@ +