diff --git a/README.md b/README.md
index 7b24146..99ec4fd 100644
--- a/README.md
+++ b/README.md
@@ -35,6 +35,23 @@ services
Or use [automatic registration](https://docs.fluentvalidation.net/en/latest/di.html#automatic-registration).
+## Custom Status Code
+
+Currently, `JSM.FluentValidation.AspNet.AsyncFilter` supports the following status codes:
+
+- 400, Bad Request
+- 403, Forbidden (`ErrorCode.Forbidden`)
+- 404, Not Found (`ErrorCode.NotFound`)
+
+By default, every client error will return a 400 status code (Bad Request). If you want to customize the response, use FluentValidation's [WithErrorCode()](https://docs.fluentvalidation.net/en/latest/error-codes.html):
+
+```c#
+RuleFor(user => user)
+ .Must(user => user.Id != "321")
+ .WithMessage("Insufficient rights to access this resource")
+ .WithErrorCode(ErrorCode.Forbidden);
+```
+
## Customization
If also possible to apply the filter only to controllers that contains the [`ApiControllerAttribute`](https://docs.microsoft.com/en-us/dotnet/api/microsoft.aspnetcore.mvc.apicontrollerattribute).
diff --git a/src/ErrorResponse/ErrorCode.cs b/src/ErrorResponse/ErrorCode.cs
new file mode 100644
index 0000000..655a429
--- /dev/null
+++ b/src/ErrorResponse/ErrorCode.cs
@@ -0,0 +1,28 @@
+using System.Collections.Generic;
+
+namespace JSM.FluentValidation.AspNet.AsyncFilter.ErrorResponse
+{
+ ///
+ /// Defines what HTTP status code should be returned. Use it with the `WithErrorCode` extension method.
+ ///
+ public static class ErrorCode
+ {
+ ///
+ /// 401 HTTP status code as per RFC 2616
+ ///
+ public const string Unauthorized = "UNAUTHORIZED_ERROR";
+
+ ///
+ /// 403 HTTP status code as per RFC 2616
+ ///
+ public const string Forbidden = "FORBIDDEN_ERROR";
+
+ ///
+ /// 404 HTTP status code as per RFC 2616
+ ///
+ public const string NotFound = "NOT_FOUND_ERROR";
+
+ internal static readonly HashSet AvailableCodes = new HashSet()
+ {Unauthorized, Forbidden, NotFound};
+ }
+}
diff --git a/src/ErrorResponse/ErrorResponseFactory.cs b/src/ErrorResponse/ErrorResponseFactory.cs
new file mode 100644
index 0000000..f9d7c89
--- /dev/null
+++ b/src/ErrorResponse/ErrorResponseFactory.cs
@@ -0,0 +1,44 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Net;
+using Microsoft.AspNetCore.Mvc;
+using Microsoft.AspNetCore.Mvc.ModelBinding;
+
+namespace JSM.FluentValidation.AspNet.AsyncFilter.ErrorResponse
+{
+ internal static class ErrorResponseFactory
+ {
+ public static TraceableProblemDetails CreateErrorResponse(ModelStateDictionary modelState,
+ string traceparent)
+ {
+ if (modelState[ErrorCode.Unauthorized] is not null)
+ return new UnauthorizedResponse(
+ modelState[ErrorCode.Unauthorized]?.Errors.FirstOrDefault()?.ErrorMessage ??
+ string.Empty, traceparent);
+
+ if (modelState[ErrorCode.Forbidden] is not null)
+ return new ForbiddenResponse(
+ modelState[ErrorCode.Forbidden]?.Errors.FirstOrDefault()?.ErrorMessage ??
+ string.Empty, traceparent);
+
+ return new NotFoundResponse(
+ modelState[ErrorCode.NotFound]?.Errors.FirstOrDefault()?.ErrorMessage ??
+ string.Empty,
+ traceparent);
+ }
+
+ public static HttpStatusCode GetResponseStatusCode(ModelStateDictionary modelState)
+ {
+ if (modelState[ErrorCode.Unauthorized] is not null)
+ return HttpStatusCode.Unauthorized;
+
+ if (modelState[ErrorCode.Forbidden] is not null)
+ return HttpStatusCode.Forbidden;
+
+ return modelState[ErrorCode.NotFound] is not null
+ ? HttpStatusCode.NotFound
+ : HttpStatusCode.BadRequest;
+ }
+ }
+}
diff --git a/src/ErrorResponse/ForbiddenResponse.cs b/src/ErrorResponse/ForbiddenResponse.cs
new file mode 100644
index 0000000..b4ce659
--- /dev/null
+++ b/src/ErrorResponse/ForbiddenResponse.cs
@@ -0,0 +1,14 @@
+namespace JSM.FluentValidation.AspNet.AsyncFilter.ErrorResponse
+{
+ internal class ForbiddenResponse : TraceableProblemDetails
+ {
+ public ForbiddenResponse(string message, string traceparent)
+ {
+ Type = "https://datatracker.ietf.org/doc/html/rfc7231#section-6.5.3";
+ Title = ErrorCode.Forbidden;
+ Status = 403;
+ Detail = message;
+ TraceId = traceparent;
+ }
+ }
+}
diff --git a/src/ErrorResponse/NotFoundResponse.cs b/src/ErrorResponse/NotFoundResponse.cs
new file mode 100644
index 0000000..ab4d3ad
--- /dev/null
+++ b/src/ErrorResponse/NotFoundResponse.cs
@@ -0,0 +1,17 @@
+using System.Text.Json.Serialization;
+using Microsoft.AspNetCore.Mvc;
+
+namespace JSM.FluentValidation.AspNet.AsyncFilter.ErrorResponse
+{
+ internal class NotFoundResponse : TraceableProblemDetails
+ {
+ public NotFoundResponse(string message, string traceparent)
+ {
+ Type = "https://datatracker.ietf.org/doc/html/rfc7231#section-6.5.4";
+ Title = ErrorCode.NotFound;
+ Status = 404;
+ Detail = message;
+ TraceId = traceparent;
+ }
+ }
+}
diff --git a/src/ErrorResponse/TraceableProblemDetails.cs b/src/ErrorResponse/TraceableProblemDetails.cs
new file mode 100644
index 0000000..11ac28e
--- /dev/null
+++ b/src/ErrorResponse/TraceableProblemDetails.cs
@@ -0,0 +1,14 @@
+using System.Text.Json.Serialization;
+using Microsoft.AspNetCore.Mvc;
+
+namespace JSM.FluentValidation.AspNet.AsyncFilter.ErrorResponse
+{
+ internal abstract class TraceableProblemDetails : ProblemDetails
+ {
+ ///
+ /// A unique identifier responsible to describe the incoming request.
+ ///
+ [JsonPropertyName("traceId")]
+ public string TraceId { get; set; }
+ }
+}
diff --git a/src/ErrorResponse/UnauthorizedResponse.cs b/src/ErrorResponse/UnauthorizedResponse.cs
new file mode 100644
index 0000000..7fe59cf
--- /dev/null
+++ b/src/ErrorResponse/UnauthorizedResponse.cs
@@ -0,0 +1,14 @@
+namespace JSM.FluentValidation.AspNet.AsyncFilter.ErrorResponse
+{
+ internal class UnauthorizedResponse : TraceableProblemDetails
+ {
+ public UnauthorizedResponse(string message, string traceparent)
+ {
+ Type = "https://datatracker.ietf.org/doc/html/rfc7235#section-3.1";
+ Title = ErrorCode.Unauthorized;
+ Status = 401;
+ Detail = message;
+ TraceId = traceparent;
+ }
+ }
+}
diff --git a/src/JSM.FluentValidation.AspNet.AsyncFilter.csproj b/src/JSM.FluentValidation.AspNet.AsyncFilter.csproj
index bb97fa3..d8f72cb 100644
--- a/src/JSM.FluentValidation.AspNet.AsyncFilter.csproj
+++ b/src/JSM.FluentValidation.AspNet.AsyncFilter.csproj
@@ -17,7 +17,6 @@
-
diff --git a/src/ModelValidationAsyncActionFilter.cs b/src/ModelValidationAsyncActionFilter.cs
index 5382baf..38aa054 100644
--- a/src/ModelValidationAsyncActionFilter.cs
+++ b/src/ModelValidationAsyncActionFilter.cs
@@ -8,8 +8,13 @@
using Microsoft.Extensions.Options;
using System;
using System.Collections;
+using System.Diagnostics;
using System.Linq;
+using System.Net;
+using System.Text.Json;
using System.Threading.Tasks;
+using JSM.FluentValidation.AspNet.AsyncFilter.ErrorResponse;
+using Microsoft.AspNetCore.Http;
namespace JSM.FluentValidation.AspNet.AsyncFilter
{
@@ -47,7 +52,8 @@ public ModelValidationAsyncActionFilter(
///
/// Validates values before the controller's action is invoked (before the route is executed).
///
- public async Task OnActionExecutionAsync(ActionExecutingContext context, ActionExecutionDelegate next)
+ public async Task OnActionExecutionAsync(ActionExecutingContext context,
+ ActionExecutionDelegate next)
{
if (ShouldIgnoreFilter(context))
{
@@ -59,8 +65,29 @@ public async Task OnActionExecutionAsync(ActionExecutingContext context, ActionE
if (!context.ModelState.IsValid)
{
- _logger.LogDebug("The request has model state errors, returning an error response.");
- context.Result = _apiBehaviorOptions.InvalidModelStateResponseFactory(context);
+ _logger.LogDebug(
+ "The request has model state errors, returning an error response");
+ var responseStatusCode =
+ ErrorResponseFactory.GetResponseStatusCode(context.ModelState);
+
+ // BadRequest responses will return the default response structure, only different
+ // status codes will be customized
+ if (responseStatusCode == HttpStatusCode.BadRequest)
+ {
+ context.Result = _apiBehaviorOptions.InvalidModelStateResponseFactory(context);
+ return;
+ }
+
+ var errorResponse =
+ ErrorResponseFactory.CreateErrorResponse(context.ModelState,
+ Activity.Current?.Id ?? context.HttpContext.TraceIdentifier);
+
+ context.HttpContext.Response.StatusCode = (int) responseStatusCode;
+ context.HttpContext.Response.ContentType = "application/json";
+
+ var responseBody = JsonSerializer.Serialize(errorResponse);
+ await context.HttpContext.Response.WriteAsync(responseBody);
+
return;
}
@@ -97,7 +124,8 @@ private bool ShouldIgnoreFilter(ActionExecutingContext context)
return !hasApiControllerAttribute;
}
- private async Task ValidateEnumerableObjectsAsync(object value, ModelStateDictionary modelState)
+ private async Task ValidateEnumerableObjectsAsync(object value,
+ ModelStateDictionary modelState)
{
var underlyingType = value.GetType().GenericTypeArguments[0];
var validator = GetValidator(underlyingType);
@@ -105,14 +133,17 @@ private async Task ValidateEnumerableObjectsAsync(object value, ModelStateDictio
if (validator == null)
return;
- foreach (var item in (IEnumerable)value)
+ foreach (var item in (IEnumerable) value)
{
if (item is null)
continue;
var context = new ValidationContext