diff --git a/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs b/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs index aec8dcbe33..c5eb1d4d36 100644 --- a/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs +++ b/src/Discord.Net.Commands/Builders/ModuleClassBuilder.cs @@ -275,24 +275,23 @@ private static void BuildParameter(ParameterBuilder builder, System.Reflection.P if (builder.TypeReader == null) { - builder.TypeReader = service.GetDefaultTypeReader(paramType) - ?? service.GetTypeReaders(paramType)?.FirstOrDefault().Value; + builder.TypeReader = service.GetTypeReaders(paramType, false)?.FirstOrDefault().Value + ?? service.GetDefaultTypeReader(paramType); } } internal static TypeReader GetTypeReader(CommandService service, Type paramType, Type typeReaderType, IServiceProvider services) { - var readers = service.GetTypeReaders(paramType); - TypeReader reader = null; + var readers = service.GetTypeReaders(paramType, true); if (readers != null) - { - if (readers.TryGetValue(typeReaderType, out reader)) - return reader; - } + foreach (var kvp in readers) + if (kvp.Key == typeReaderType) + return kvp.Value; //We dont have a cached type reader, create one - reader = ReflectionUtils.CreateObject(typeReaderType.GetTypeInfo(), service, services); - service.AddTypeReader(paramType, reader, false); + TypeReader reader = ReflectionUtils.CreateObject(typeReaderType.GetTypeInfo(), service, services); + reader.IsOverride = true; + service.AddTypeReader(paramType, reader); return reader; } diff --git a/src/Discord.Net.Commands/Builders/ParameterBuilder.cs b/src/Discord.Net.Commands/Builders/ParameterBuilder.cs index 4ad5bfac08..f2a3ee70ce 100644 --- a/src/Discord.Net.Commands/Builders/ParameterBuilder.cs +++ b/src/Discord.Net.Commands/Builders/ParameterBuilder.cs @@ -60,7 +60,7 @@ private TypeReader GetReader(Type type) if (type.GetTypeInfo().GetCustomAttribute() != null) { IsRemainder = true; - var reader = commands.GetTypeReaders(type)?.FirstOrDefault().Value; + var reader = commands.GetTypeReaders(type, false)?.FirstOrDefault().Value; if (reader == null) { Type readerType; @@ -80,8 +80,7 @@ private TypeReader GetReader(Type type) return reader; } - - var readers = commands.GetTypeReaders(type); + var readers = commands.GetTypeReaders(type, false); if (readers != null) return readers.FirstOrDefault().Value; else diff --git a/src/Discord.Net.Commands/CommandService.cs b/src/Discord.Net.Commands/CommandService.cs index d5c060fe43..4ca26c82f5 100644 --- a/src/Discord.Net.Commands/CommandService.cs +++ b/src/Discord.Net.Commands/CommandService.cs @@ -48,6 +48,7 @@ public class CommandService : IDisposable private readonly SemaphoreSlim _moduleLock; private readonly ConcurrentDictionary _typedModuleDefs; private readonly ConcurrentDictionary> _typeReaders; + private readonly ConcurrentDictionary> _userEntityTypeReaders; private readonly ConcurrentDictionary _defaultTypeReaders; private readonly ImmutableList<(Type EntityType, Type TypeReaderType)> _entityTypeReaders; private readonly HashSet _moduleDefs; @@ -77,6 +78,15 @@ public class CommandService : IDisposable /// public ILookup TypeReaders => _typeReaders.SelectMany(x => x.Value.Select(y => new { y.Key, y.Value })).ToLookup(x => x.Key, x => x.Value); + /// + /// Represents all entity type reader s loaded within . + /// + /// + /// A ; the key is the object type to be read by the , + /// while the element is the type of the generic definition. + /// + public ILookup EntityTypeReaders => _userEntityTypeReaders.SelectMany(x => x.Value.Select(y => new { x.Key, TypeReaderType = y })).ToLookup(x => x.Key, y => y.TypeReaderType); + /// /// Initializes a new class. /// @@ -109,6 +119,7 @@ public CommandService(CommandServiceConfig config) _moduleDefs = new HashSet(); _map = new CommandMap(this); _typeReaders = new ConcurrentDictionary>(); + _userEntityTypeReaders = new ConcurrentDictionary>(); _defaultTypeReaders = new ConcurrentDictionary(); foreach (var type in PrimitiveParsers.SupportedTypes) @@ -329,8 +340,6 @@ private bool RemoveModuleInternal(ModuleInfo module) /// type. /// If is a , a nullable will /// also be added. - /// If a default exists for , a warning will be logged - /// and the default will be replaced. /// /// The object type to be read by the . /// An instance of the to be added. @@ -341,17 +350,49 @@ public void AddTypeReader(TypeReader reader) /// type. /// If is a , a nullable for the /// value type will also be added. - /// If a default exists for , a warning will be logged and - /// the default will be replaced. /// /// A instance for the type to be read. /// An instance of the to be added. public void AddTypeReader(Type type, TypeReader reader) { - if (_defaultTypeReaders.ContainsKey(type)) - _ = _cmdLogger.WarningAsync($"The default TypeReader for {type.FullName} was replaced by {reader.GetType().FullName}." + - "To suppress this message, use AddTypeReader(reader, true)."); - AddTypeReader(type, reader, true); + var readers = _typeReaders.GetOrAdd(type, x => new ConcurrentDictionary()); + readers[reader.GetType()] = reader; + + if (type.GetTypeInfo().IsValueType) + AddNullableTypeReader(type, reader); + } + /// + /// Adds a custom entity to this for the supplied + /// object type. + /// + /// + /// The following example adds a custom entity reader to this . + /// + /// + /// The object type to be read by the . + /// A generic type definition (with one open argument) of the . + public void AddEntityTypeReader(Type typeReaderGenericType) + => AddEntityTypeReader(typeof(T), typeReaderGenericType); + /// + /// Adds a custom entity to this for the supplied + /// object type. + /// + /// A instance for the type to be read. + /// A generic type definition (with one open argument) of the . + public void AddEntityTypeReader(Type type, Type typeReaderGenericType) + { + if (!typeReaderGenericType.IsGenericTypeDefinition) + throw new ArgumentException("TypeReader type must be a generic type definition.", nameof(typeReaderGenericType)); + Type[] genericArgs = typeReaderGenericType.GetGenericArguments(); + if (genericArgs.Length != 1) + throw new ArgumentException("TypeReader type must accept one and only one open generic argument.", nameof(typeReaderGenericType)); + if (!genericArgs[0].IsGenericParameter) + throw new ArgumentException("TypeReader type must accept one and only one open generic argument.", nameof(typeReaderGenericType)); + if (!genericArgs[0].GenericParameterAttributes.HasFlag(GenericParameterAttributes.ReferenceTypeConstraint)) + throw new ArgumentException("TypeReader generic argument must have a reference type constraint.", nameof(typeReaderGenericType)); + var readers = _userEntityTypeReaders.GetOrAdd(type, x => new ConcurrentQueue()); + readers.Enqueue(typeReaderGenericType); } /// /// Adds a custom to this for the supplied object @@ -359,14 +400,20 @@ public void AddTypeReader(Type type, TypeReader reader) /// If is a , a nullable will /// also be added. /// + /// + /// The following example adds a custom entity reader to this . + /// + /// /// The object type to be read by the . /// An instance of the to be added. /// /// Defines whether the should replace the default one for /// if it exists. /// + [Obsolete("This method is deprecated. Use the method without the replaceDefault argument.")] public void AddTypeReader(TypeReader reader, bool replaceDefault) - => AddTypeReader(typeof(T), reader, replaceDefault); + => AddTypeReader(typeof(T), reader); /// /// Adds a custom to this for the supplied object /// type. @@ -379,27 +426,10 @@ public void AddTypeReader(TypeReader reader, bool replaceDefault) /// Defines whether the should replace the default one for if /// it exists. /// + [Obsolete("This method is deprecated. Use the method without the replaceDefault argument.")] public void AddTypeReader(Type type, TypeReader reader, bool replaceDefault) - { - if (replaceDefault && HasDefaultTypeReader(type)) - { - _defaultTypeReaders.AddOrUpdate(type, reader, (k, v) => reader); - if (type.GetTypeInfo().IsValueType) - { - var nullableType = typeof(Nullable<>).MakeGenericType(type); - var nullableReader = NullableTypeReader.Create(type, reader); - _defaultTypeReaders.AddOrUpdate(nullableType, nullableReader, (k, v) => nullableReader); - } - } - else - { - var readers = _typeReaders.GetOrAdd(type, x => new ConcurrentDictionary()); - readers[reader.GetType()] = reader; + => AddTypeReader(type, reader); - if (type.GetTypeInfo().IsValueType) - AddNullableTypeReader(type, reader); - } - } internal bool HasDefaultTypeReader(Type type) { if (_defaultTypeReaders.ContainsKey(type)) @@ -408,7 +438,7 @@ internal bool HasDefaultTypeReader(Type type) var typeInfo = type.GetTypeInfo(); if (typeInfo.IsEnum) return true; - return _entityTypeReaders.Any(x => type == x.EntityType || typeInfo.ImplementedInterfaces.Contains(x.TypeReaderType)); + return _entityTypeReaders.Any(x => type == x.EntityType || typeInfo.ImplementedInterfaces.Contains(x.EntityType)); } internal void AddNullableTypeReader(Type valueType, TypeReader valueTypeReader) { @@ -416,10 +446,39 @@ internal void AddNullableTypeReader(Type valueType, TypeReader valueTypeReader) var nullableReader = NullableTypeReader.Create(valueType, valueTypeReader); readers[nullableReader.GetType()] = nullableReader; } - internal IDictionary GetTypeReaders(Type type) + internal IEnumerable> GetTypeReaders(Type type, bool includeOverride) { if (_typeReaders.TryGetValue(type, out var definedTypeReaders)) - return definedTypeReaders; + return includeOverride ? definedTypeReaders : definedTypeReaders.Where(x => !x.Value.IsOverride); + + var assignableEntityReaders = _userEntityTypeReaders.Where(x => x.Key.IsAssignableFrom(type)); + + int assignableTo = -1; + KeyValuePair>? entityReaders = null; + foreach (var entityReader in assignableEntityReaders) + { + int assignables = assignableEntityReaders.Sum(x => !x.Equals(entityReader) && x.Key.IsAssignableFrom(entityReader.Key) ? 1 : 0); + if (assignableTo == -1) + { + // First time + assignableTo = assignables; + entityReaders = entityReader; + } + // Try to get the most specific type reader, i.e. IMessageChannel is assignable to IChannel, but not the inverse + else if (assignables > assignableTo) + { + assignableTo = assignables; + entityReaders = entityReader; + } + } + + if (entityReaders != null) + { + var entityTypeReaderType = entityReaders.Value.Value.First(); + TypeReader reader = Activator.CreateInstance(entityTypeReaderType.MakeGenericType(type)) as TypeReader; + AddTypeReader(type, reader); + return GetTypeReaders(type, false); + } return null; } internal TypeReader GetDefaultTypeReader(Type type) @@ -511,7 +570,7 @@ public async Task ExecuteAsync(ICommandContext context, string input, I await _commandExecutedEvent.InvokeAsync(Optional.Create(), context, searchResult).ConfigureAwait(false); return searchResult; } - + var commands = searchResult.Commands; var preconditionResults = new Dictionary(); diff --git a/src/Discord.Net.Commands/Readers/NamedArgumentTypeReader.cs b/src/Discord.Net.Commands/Readers/NamedArgumentTypeReader.cs index 0adf610463..b584d29d57 100644 --- a/src/Discord.Net.Commands/Readers/NamedArgumentTypeReader.cs +++ b/src/Discord.Net.Commands/Readers/NamedArgumentTypeReader.cs @@ -136,8 +136,8 @@ async Task ReadArgumentAsync(PropertyInfo prop, string arg) var overridden = prop.GetCustomAttribute(); var reader = (overridden != null) ? ModuleClassBuilder.GetTypeReader(_commands, elemType, overridden.TypeReader, services) - : (_commands.GetDefaultTypeReader(elemType) - ?? _commands.GetTypeReaders(elemType).FirstOrDefault().Value); + : (_commands.GetTypeReaders(elemType, false)?.FirstOrDefault().Value + ?? _commands.GetDefaultTypeReader(elemType)); if (reader != null) { diff --git a/src/Discord.Net.Commands/Readers/TypeReader.cs b/src/Discord.Net.Commands/Readers/TypeReader.cs index af780993dc..a071d9b53e 100644 --- a/src/Discord.Net.Commands/Readers/TypeReader.cs +++ b/src/Discord.Net.Commands/Readers/TypeReader.cs @@ -8,6 +8,7 @@ namespace Discord.Commands /// public abstract class TypeReader { + internal bool IsOverride { get; set; } = false; /// /// Attempts to parse the into the desired type. /// diff --git a/src/Discord.Net.Examples/Commands/CommandService.Examples.cs b/src/Discord.Net.Examples/Commands/CommandService.Examples.cs new file mode 100644 index 0000000000..ca656aaf97 --- /dev/null +++ b/src/Discord.Net.Examples/Commands/CommandService.Examples.cs @@ -0,0 +1,55 @@ +using Discord.Commands; +using JetBrains.Annotations; +using System; +using System.Threading.Tasks; + +namespace Discord.Net.Examples.Commands +{ + [PublicAPI] + internal class CommandServiceExamples + { + #region AddEntityTypeReader + + public void AddCustomUserEntityReader(CommandService commandService) + { + commandService.AddEntityTypeReader(typeof(MyUserTypeReader<>)); + } + + public class MyUserTypeReader : TypeReader + where T : class, IUser + { + public override async Task ReadAsync(ICommandContext context, string input, IServiceProvider services) + { + if (ulong.TryParse(input, out var id)) + return ((await context.Client.GetUserAsync(id)) is T user) + ? TypeReaderResult.FromSuccess(user) + : TypeReaderResult.FromError(CommandError.ObjectNotFound, "User not found."); + return TypeReaderResult.FromError(CommandError.ParseFailed, "Couldn't parse input to ulong."); + } + } + + #endregion + + #region AddEntityTypeReader2 + + public void AddCustomChannelEntityReader(CommandService commandService) + { + commandService.AddEntityTypeReader(typeof(MyUserTypeReader<>)); + } + + public class MyChannelTypeReader : TypeReader + where T : class, IChannel + { + public override async Task ReadAsync(ICommandContext context, string input, IServiceProvider services) + { + if (ulong.TryParse(input, out var id)) + return ((await context.Client.GetChannelAsync(id)) is T channel) + ? TypeReaderResult.FromSuccess(channel) + : TypeReaderResult.FromError(CommandError.ObjectNotFound, "Channel not found."); + return TypeReaderResult.FromError(CommandError.ParseFailed, "Couldn't parse input to ulong."); + } + } + + #endregion + } +} diff --git a/src/Discord.Net.Examples/Discord.Net.Examples.csproj b/src/Discord.Net.Examples/Discord.Net.Examples.csproj index ec02534280..1f459f2e6c 100644 --- a/src/Discord.Net.Examples/Discord.Net.Examples.csproj +++ b/src/Discord.Net.Examples/Discord.Net.Examples.csproj @@ -13,6 +13,7 @@ +