diff --git a/FluentValidation.AutoValidation.Mvc/src/Filters/FluentValidationAutoValidationActionFilter.cs b/FluentValidation.AutoValidation.Mvc/src/Filters/FluentValidationAutoValidationActionFilter.cs index 438e914..2ae20dd 100644 --- a/FluentValidation.AutoValidation.Mvc/src/Filters/FluentValidationAutoValidationActionFilter.cs +++ b/FluentValidation.AutoValidation.Mvc/src/Filters/FluentValidationAutoValidationActionFilter.cs @@ -1,10 +1,12 @@ -using System.Linq; +using System; +using System.Linq; using System.Threading.Tasks; using FluentValidation; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc.Controllers; using Microsoft.AspNetCore.Mvc.Filters; +using Microsoft.AspNetCore.Mvc.Infrastructure; using Microsoft.AspNetCore.Mvc.ModelBinding; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Options; @@ -22,9 +24,7 @@ public class FluentValidationAutoValidationActionFilter : IAsyncActionFilter private readonly IFluentValidationAutoValidationResultFactory fluentValidationAutoValidationResultFactory; private readonly AutoValidationMvcConfiguration autoValidationMvcConfiguration; - public FluentValidationAutoValidationActionFilter( - IFluentValidationAutoValidationResultFactory fluentValidationAutoValidationResultFactory, - IOptions autoValidationMvcConfiguration) + public FluentValidationAutoValidationActionFilter(IFluentValidationAutoValidationResultFactory fluentValidationAutoValidationResultFactory, IOptions autoValidationMvcConfiguration) { this.fluentValidationAutoValidationResultFactory = fluentValidationAutoValidationResultFactory; this.autoValidationMvcConfiguration = autoValidationMvcConfiguration.Value; @@ -32,7 +32,7 @@ public FluentValidationAutoValidationActionFilter( public async Task OnActionExecutionAsync(ActionExecutingContext actionExecutingContext, ActionExecutionDelegate next) { - if (actionExecutingContext.Controller is ControllerBase controllerBase) + if (IsValidController(actionExecutingContext.Controller)) { var endpoint = actionExecutingContext.HttpContext.GetEndpoint(); var controllerActionDescriptor = (ControllerActionDescriptor) actionExecutingContext.ActionDescriptor; @@ -108,7 +108,8 @@ public async Task OnActionExecutionAsync(ActionExecutingContext actionExecutingC if (!actionExecutingContext.ModelState.IsValid) { - var validationProblemDetails = controllerBase.ProblemDetailsFactory.CreateValidationProblemDetails(actionExecutingContext.HttpContext, actionExecutingContext.ModelState); + var problemDetailsFactory = serviceProvider.GetRequiredService(); + var validationProblemDetails = problemDetailsFactory.CreateValidationProblemDetails(actionExecutingContext.HttpContext, actionExecutingContext.ModelState); actionExecutingContext.Result = fluentValidationAutoValidationResultFactory.CreateActionResult(actionExecutingContext, validationProblemDetails); @@ -119,6 +120,21 @@ public async Task OnActionExecutionAsync(ActionExecutingContext actionExecutingC await next(); } + private bool IsValidController(object controller) + { + var controllerType = controller.GetType(); + + if (controllerType.HasCustomAttribute()) + { + return false; + } + + return controller is ControllerBase || + controllerType.HasCustomAttribute() || + controllerType.Name.EndsWith("Controller", StringComparison.OrdinalIgnoreCase) || + controllerType.InheritsFromTypeWithNameEndingIn("Controller"); + } + private bool HasValidBindingSource(BindingSource? bindingSource) { return (autoValidationMvcConfiguration.EnableBodyBindingSourceAutomaticValidation && bindingSource == BindingSource.Body) || diff --git a/FluentValidation.AutoValidation.Shared/src/Extensions/TypeExtensions.cs b/FluentValidation.AutoValidation.Shared/src/Extensions/TypeExtensions.cs index 9f90947..f8b08ca 100644 --- a/FluentValidation.AutoValidation.Shared/src/Extensions/TypeExtensions.cs +++ b/FluentValidation.AutoValidation.Shared/src/Extensions/TypeExtensions.cs @@ -25,5 +25,20 @@ public static bool HasCustomAttribute(this Type type) where TAttribu { return type.CustomAttributes.Any(attribute => attribute.AttributeType == typeof(TAttribute)); } + + public static bool InheritsFromTypeWithNameEndingIn(this Type type, string name) + { + while (type.BaseType != null) + { + type = type.BaseType; + + if (type.Name.EndsWith(name, StringComparison.OrdinalIgnoreCase)) + { + return true; + } + } + + return false; + } } } \ No newline at end of file diff --git a/Tests/FluentValidation.AutoValidation.Tests.csproj b/Tests/FluentValidation.AutoValidation.Tests.csproj index 6249b84..2dfddc9 100644 --- a/Tests/FluentValidation.AutoValidation.Tests.csproj +++ b/Tests/FluentValidation.AutoValidation.Tests.csproj @@ -8,26 +8,26 @@ - - - + + + all runtime; build; native; contentfiles; analyzers; buildtransitive - - + + all runtime; build; native; contentfiles; analyzers; buildtransitive - + all runtime; build; native; contentfiles; analyzers; buildtransitive - + all runtime; build; native; contentfiles; analyzers; buildtransitive - + all runtime; build; native; contentfiles; analyzers; buildtransitive diff --git a/Tests/src/FluentValidation.AutoValidation.Mvc/Filters/FluentValidationAutoValidationActionFilterTest.cs b/Tests/src/FluentValidation.AutoValidation.Mvc/Filters/FluentValidationAutoValidationActionFilterTest.cs index fa762ae..835d626 100644 --- a/Tests/src/FluentValidation.AutoValidation.Mvc/Filters/FluentValidationAutoValidationActionFilterTest.cs +++ b/Tests/src/FluentValidation.AutoValidation.Mvc/Filters/FluentValidationAutoValidationActionFilterTest.cs @@ -73,10 +73,11 @@ public async Task TestOnActionExecutionAsync() serviceProvider.GetService(typeof(IValidator<>).MakeGenericType(typeof(TestModel))).Returns(new TestValidator()); serviceProvider.GetService(typeof(IGlobalValidationInterceptor)).Returns(new GlobalValidationInterceptor()); + serviceProvider.GetService(typeof(ProblemDetailsFactory)).Returns(problemDetailsFactory); + problemDetailsFactory.CreateValidationProblemDetails(httpContext, modelStateDictionary).Returns(validationProblemDetails); fluentValidationAutoValidationResultFactory.CreateActionResult(actionExecutingContext, validationProblemDetails).Returns(new BadRequestObjectResult(validationProblemDetails)); httpContext.RequestServices.Returns(serviceProvider); - controller.ProblemDetailsFactory = problemDetailsFactory; actionExecutingContext.Controller.Returns(controller); actionExecutingContext.ActionDescriptor = controllerActionDescriptor; actionExecutingContext.ActionArguments.Returns(actionArguments); diff --git a/Tests/src/FluentValidation.AutoValidation.Shared/Extensions/TypeExtensionsTest.cs b/Tests/src/FluentValidation.AutoValidation.Shared/Extensions/TypeExtensionsTest.cs index a1d87d8..bfbd3d6 100644 --- a/Tests/src/FluentValidation.AutoValidation.Shared/Extensions/TypeExtensionsTest.cs +++ b/Tests/src/FluentValidation.AutoValidation.Shared/Extensions/TypeExtensionsTest.cs @@ -1,4 +1,5 @@ using System; +using Microsoft.AspNetCore.Mvc; using SharpGrip.FluentValidation.AutoValidation.Mvc.Attributes; using SharpGrip.FluentValidation.AutoValidation.Shared.Extensions; using Xunit; @@ -46,6 +47,21 @@ public void Test_HasCustomAttribute() Assert.False(typeof(TestModelRecord).HasCustomAttribute()); } + [Fact] + public void Test_InheritsFromTypeWithNameEndingIn() + { + Assert.True(typeof(TestInherits1).InheritsFromTypeWithNameEndingIn("Controller")); + Assert.True(typeof(TestInherits1).InheritsFromTypeWithNameEndingIn("controller")); + Assert.True(typeof(TestInherits2).InheritsFromTypeWithNameEndingIn("Controller")); + Assert.True(typeof(TestInherits2).InheritsFromTypeWithNameEndingIn("controller")); + Assert.False(typeof(TestInherits3).InheritsFromTypeWithNameEndingIn("Controller")); + Assert.False(typeof(TestInherits3).InheritsFromTypeWithNameEndingIn("controller")); + Assert.False(typeof(TestInherits4).InheritsFromTypeWithNameEndingIn("Controller")); + Assert.False(typeof(TestInherits4).InheritsFromTypeWithNameEndingIn("controller")); + Assert.False(typeof(TestInherits5).InheritsFromTypeWithNameEndingIn("Controller")); + Assert.False(typeof(TestInherits5).InheritsFromTypeWithNameEndingIn("controller")); + } + [AutoValidation] private class TestModelClass; @@ -53,4 +69,16 @@ private class TestModelClass; private record TestModelRecord; private enum TestModelEnum; + + private class TestInherits1 : Controller; + + private class TestInherits2 : CustomControllerBase; + + private class TestInherits3 : ControllerBase; + + private class TestInherits4 : ActionContext; + + private class TestInherits5 : object; + + private class CustomControllerBase : Controller; } \ No newline at end of file