diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java b/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java index c4047dd5336..72dbee5365e 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ import jakarta.servlet.DispatcherType; import jakarta.servlet.ServletContext; import jakarta.servlet.ServletRegistration; +import jakarta.servlet.http.HttpServletRequest; import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.context.ApplicationContext; @@ -203,11 +204,30 @@ public C requestMatchers(HttpMethod method, String... patterns) { if (!hasDispatcherServlet(registrations)) { return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns)); } - if (registrations.size() > 1) { - String errorMessage = computeErrorMessage(registrations.values()); - throw new IllegalArgumentException(errorMessage); + ServletRegistration dispatcherServlet = requireOneRootDispatcherServlet(registrations); + if (dispatcherServlet != null) { + if (registrations.size() == 1) { + return requestMatchers(createMvcMatchers(method, patterns).toArray(RequestMatcher[]::new)); + } + List matchers = new ArrayList<>(); + for (String pattern : patterns) { + AntPathRequestMatcher ant = new AntPathRequestMatcher(pattern, (method != null) ? method.name() : null); + MvcRequestMatcher mvc = createMvcMatchers(method, pattern).get(0); + matchers.add(new DispatcherServletDelegatingRequestMatcher(ant, mvc, servletContext)); + } + return requestMatchers(matchers.toArray(new RequestMatcher[0])); } - return requestMatchers(createMvcMatchers(method, patterns).toArray(new RequestMatcher[0])); + dispatcherServlet = requireOnlyPathMappedDispatcherServlet(registrations); + if (dispatcherServlet != null) { + String mapping = dispatcherServlet.getMappings().iterator().next(); + List matchers = createMvcMatchers(method, patterns); + for (MvcRequestMatcher matcher : matchers) { + matcher.setServletPath(mapping.substring(0, mapping.length() - 2)); + } + return requestMatchers(matchers.toArray(new RequestMatcher[0])); + } + String errorMessage = computeErrorMessage(registrations.values()); + throw new IllegalArgumentException(errorMessage); } private Map mappableServletRegistrations(ServletContext servletContext) { @@ -225,22 +245,66 @@ private boolean hasDispatcherServlet(Map if (registrations == null) { return false; } - Class dispatcherServlet = ClassUtils.resolveClassName("org.springframework.web.servlet.DispatcherServlet", - null); for (ServletRegistration registration : registrations.values()) { - try { - Class clazz = Class.forName(registration.getClassName()); - if (dispatcherServlet.isAssignableFrom(clazz)) { - return true; - } - } - catch (ClassNotFoundException ex) { - return false; + if (isDispatcherServlet(registration)) { + return true; } } return false; } + private ServletRegistration requireOneRootDispatcherServlet( + Map registrations) { + ServletRegistration rootDispatcherServlet = null; + for (ServletRegistration registration : registrations.values()) { + if (!isDispatcherServlet(registration)) { + continue; + } + if (registration.getMappings().size() > 1) { + return null; + } + if (!"/".equals(registration.getMappings().iterator().next())) { + return null; + } + rootDispatcherServlet = registration; + } + return rootDispatcherServlet; + } + + private ServletRegistration requireOnlyPathMappedDispatcherServlet( + Map registrations) { + ServletRegistration pathDispatcherServlet = null; + for (ServletRegistration registration : registrations.values()) { + if (!isDispatcherServlet(registration)) { + return null; + } + if (registration.getMappings().size() > 1) { + return null; + } + String mapping = registration.getMappings().iterator().next(); + if (!mapping.startsWith("/") || !mapping.endsWith("/*")) { + return null; + } + if (pathDispatcherServlet != null) { + return null; + } + pathDispatcherServlet = registration; + } + return pathDispatcherServlet; + } + + private boolean isDispatcherServlet(ServletRegistration registration) { + Class dispatcherServlet = ClassUtils.resolveClassName("org.springframework.web.servlet.DispatcherServlet", + null); + try { + Class clazz = Class.forName(registration.getClassName()); + return dispatcherServlet.isAssignableFrom(clazz); + } + catch (ClassNotFoundException ex) { + return false; + } + } + private String computeErrorMessage(Collection registrations) { String template = "This method cannot decide whether these patterns are Spring MVC patterns or not. " + "If this endpoint is a Spring MVC endpoint, please use requestMatchers(MvcRequestMatcher); " @@ -380,4 +444,55 @@ static List regexMatchers(String... regexPatterns) { } + static class DispatcherServletDelegatingRequestMatcher implements RequestMatcher { + + private final AntPathRequestMatcher ant; + + private final MvcRequestMatcher mvc; + + private final ServletContext servletContext; + + DispatcherServletDelegatingRequestMatcher(AntPathRequestMatcher ant, MvcRequestMatcher mvc, + ServletContext servletContext) { + this.ant = ant; + this.mvc = mvc; + this.servletContext = servletContext; + } + + @Override + public boolean matches(HttpServletRequest request) { + String name = request.getHttpServletMapping().getServletName(); + ServletRegistration registration = this.servletContext.getServletRegistration(name); + Assert.notNull(registration, "Failed to find servlet [" + name + "] in the servlet context"); + if (isDispatcherServlet(registration)) { + return this.mvc.matches(request); + } + return this.ant.matches(request); + } + + @Override + public MatchResult matcher(HttpServletRequest request) { + String name = request.getHttpServletMapping().getServletName(); + ServletRegistration registration = this.servletContext.getServletRegistration(name); + Assert.notNull(registration, "Failed to find servlet [" + name + "] in the servlet context"); + if (isDispatcherServlet(registration)) { + return this.mvc.matcher(request); + } + return this.ant.matcher(request); + } + + private boolean isDispatcherServlet(ServletRegistration registration) { + Class dispatcherServlet = ClassUtils + .resolveClassName("org.springframework.web.servlet.DispatcherServlet", null); + try { + Class clazz = Class.forName(registration.getClassName()); + return dispatcherServlet.isAssignableFrom(clazz); + } + catch (ClassNotFoundException ex) { + return false; + } + } + + } + } diff --git a/config/src/test/java/org/springframework/security/config/MockServletContext.java b/config/src/test/java/org/springframework/security/config/MockServletContext.java index 67b7c396e74..d819d4c7989 100644 --- a/config/src/test/java/org/springframework/security/config/MockServletContext.java +++ b/config/src/test/java/org/springframework/security/config/MockServletContext.java @@ -55,6 +55,11 @@ public ServletRegistration.Dynamic addServlet(@NonNull String servletName, Class return this.registrations; } + @Override + public ServletRegistration getServletRegistration(String servletName) { + return this.registrations.get(servletName); + } + private static class MockServletRegistration implements ServletRegistration.Dynamic { private final String name; diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/TestMockHttpServletMappings.java b/config/src/test/java/org/springframework/security/config/TestMockHttpServletMappings.java similarity index 79% rename from config/src/test/java/org/springframework/security/config/annotation/web/configurers/TestMockHttpServletMappings.java rename to config/src/test/java/org/springframework/security/config/TestMockHttpServletMappings.java index 8ab2d42aede..3f1f7f797bb 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/TestMockHttpServletMappings.java +++ b/config/src/test/java/org/springframework/security/config/TestMockHttpServletMappings.java @@ -14,32 +14,32 @@ * limitations under the License. */ -package org.springframework.security.config.annotation.web.configurers; +package org.springframework.security.config; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.MappingMatch; import org.springframework.mock.web.MockHttpServletMapping; -final class TestMockHttpServletMappings { +public final class TestMockHttpServletMappings { private TestMockHttpServletMappings() { } - static MockHttpServletMapping extension(HttpServletRequest request, String extension) { + public static MockHttpServletMapping extension(HttpServletRequest request, String extension) { String uri = request.getRequestURI(); String matchValue = uri.substring(0, uri.lastIndexOf(extension)); return new MockHttpServletMapping(matchValue, "*" + extension, "extension", MappingMatch.EXTENSION); } - static MockHttpServletMapping path(HttpServletRequest request, String path) { + public static MockHttpServletMapping path(HttpServletRequest request, String path) { String uri = request.getRequestURI(); String matchValue = uri.substring(path.length()); return new MockHttpServletMapping(matchValue, path + "/*", "path", MappingMatch.PATH); } - static MockHttpServletMapping defaultMapping() { + public static MockHttpServletMapping defaultMapping() { return new MockHttpServletMapping("", "/", "default", MappingMatch.DEFAULT); } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java index 92e99b1fda5..ddcddcd93c8 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,8 +26,11 @@ import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.context.ApplicationContext; import org.springframework.http.HttpMethod; +import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.security.config.MockServletContext; +import org.springframework.security.config.TestMockHttpServletMappings; import org.springframework.security.config.annotation.ObjectPostProcessor; +import org.springframework.security.config.annotation.web.AbstractRequestMatcherRegistry.DispatcherServletDelegatingRequestMatcher; import org.springframework.security.web.servlet.util.matcher.MvcRequestMatcher; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.DispatcherTypeRequestMatcher; @@ -40,6 +43,9 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.verifyNoMoreInteractions; /** * Tests for {@link AbstractRequestMatcherRegistry}. @@ -159,6 +165,8 @@ public void requestMatchersWhenMvcPresentInClassPathAndMvcIntrospectorBeanNotAva public void requestMatchersWhenNoDispatcherServletThenAntPathRequestMatcherType() { MockServletContext servletContext = new MockServletContext(); given(this.context.getServletContext()).willReturn(servletContext); + servletContext.addServlet("servletOne", Servlet.class).addMapping("/one"); + servletContext.addServlet("servletTwo", Servlet.class).addMapping("/two"); List requestMatchers = this.matcherRegistry.requestMatchers("/**"); assertThat(requestMatchers).isNotEmpty(); assertThat(requestMatchers).hasSize(1); @@ -170,7 +178,26 @@ public void requestMatchersWhenAmbiguousServletsThenException() { MockServletContext servletContext = new MockServletContext(); given(this.context.getServletContext()).willReturn(servletContext); servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/"); - servletContext.addServlet("servletTwo", Servlet.class).addMapping("/servlet/**"); + servletContext.addServlet("servletTwo", DispatcherServlet.class).addMapping("/servlet/*"); + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.matcherRegistry.requestMatchers("/**")); + } + + @Test + public void requestMatchersWhenMultipleDispatcherServletMappingsThenException() { + MockServletContext servletContext = new MockServletContext(); + given(this.context.getServletContext()).willReturn(servletContext); + servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/", "/mvc/*"); + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.matcherRegistry.requestMatchers("/**")); + } + + @Test + public void requestMatchersWhenPathDispatcherServletAndOtherServletsThenException() { + MockServletContext servletContext = new MockServletContext(); + given(this.context.getServletContext()).willReturn(servletContext); + servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/mvc/*"); + servletContext.addServlet("default", Servlet.class).addMapping("/"); assertThatExceptionOfType(IllegalArgumentException.class) .isThrownBy(() -> this.matcherRegistry.requestMatchers("/**")); } @@ -187,6 +214,67 @@ public void requestMatchersWhenUnmappableServletsThenSkips() { assertThat(requestMatchers.get(0)).isInstanceOf(MvcRequestMatcher.class); } + @Test + public void requestMatchersWhenOnlyDispatcherServletThenAllows() { + MockServletContext servletContext = new MockServletContext(); + given(this.context.getServletContext()).willReturn(servletContext); + servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/mvc/*"); + List requestMatchers = this.matcherRegistry.requestMatchers("/**"); + assertThat(requestMatchers).hasSize(1); + assertThat(requestMatchers.get(0)).isInstanceOf(MvcRequestMatcher.class); + } + + @Test + public void requestMatchersWhenImplicitServletsThenAllows() { + mockMvcIntrospector(true); + MockServletContext servletContext = new MockServletContext(); + given(this.context.getServletContext()).willReturn(servletContext); + servletContext.addServlet("defaultServlet", Servlet.class); + servletContext.addServlet("jspServlet", Servlet.class).addMapping("*.jsp", "*.jspx"); + servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/"); + List requestMatchers = this.matcherRegistry.requestMatchers("/**"); + assertThat(requestMatchers).hasSize(1); + assertThat(requestMatchers.get(0)).isInstanceOf(DispatcherServletDelegatingRequestMatcher.class); + } + + @Test + public void requestMatchersWhenPathBasedNonDispatcherServletThenAllows() { + MockServletContext servletContext = new MockServletContext(); + given(this.context.getServletContext()).willReturn(servletContext); + servletContext.addServlet("path", Servlet.class).addMapping("/services/*"); + servletContext.addServlet("default", DispatcherServlet.class).addMapping("/"); + List requestMatchers = this.matcherRegistry.requestMatchers("/services/*"); + assertThat(requestMatchers).hasSize(1); + assertThat(requestMatchers.get(0)).isInstanceOf(DispatcherServletDelegatingRequestMatcher.class); + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/services/endpoint"); + request.setHttpServletMapping(TestMockHttpServletMappings.defaultMapping()); + assertThat(requestMatchers.get(0).matcher(request).isMatch()).isTrue(); + request.setHttpServletMapping(TestMockHttpServletMappings.path(request, "/services")); + request.setServletPath("/services"); + request.setPathInfo("/endpoint"); + assertThat(requestMatchers.get(0).matcher(request).isMatch()).isTrue(); + } + + @Test + public void matchesWhenDispatcherServletThenMvc() { + MockServletContext servletContext = new MockServletContext(); + servletContext.addServlet("default", DispatcherServlet.class).addMapping("/"); + servletContext.addServlet("path", Servlet.class).addMapping("/services/*"); + MvcRequestMatcher mvc = mock(MvcRequestMatcher.class); + AntPathRequestMatcher ant = mock(AntPathRequestMatcher.class); + DispatcherServletDelegatingRequestMatcher requestMatcher = new DispatcherServletDelegatingRequestMatcher(ant, + mvc, servletContext); + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/services/endpoint"); + request.setHttpServletMapping(TestMockHttpServletMappings.defaultMapping()); + assertThat(requestMatcher.matches(request)).isFalse(); + verify(mvc).matches(request); + verifyNoInteractions(ant); + request.setHttpServletMapping(TestMockHttpServletMappings.path(request, "/services")); + assertThat(requestMatcher.matches(request)).isFalse(); + verify(ant).matches(request); + verifyNoMoreInteractions(mvc); + } + private void mockMvcIntrospector(boolean isPresent) { ApplicationContext context = this.matcherRegistry.getApplicationContext(); given(context.containsBean("mvcHandlerMappingIntrospector")).willReturn(isPresent); diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/AuthorizeHttpRequestsConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/AuthorizeHttpRequestsConfigurerTests.java index cb8837f9782..19d5de6f24b 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/AuthorizeHttpRequestsConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/AuthorizeHttpRequestsConfigurerTests.java @@ -36,6 +36,7 @@ import org.springframework.security.authorization.AuthorizationEventPublisher; import org.springframework.security.authorization.AuthorizationManager; import org.springframework.security.config.MockServletContext; +import org.springframework.security.config.TestMockHttpServletMappings; import org.springframework.security.config.annotation.ObjectPostProcessor; import org.springframework.security.config.annotation.web.AbstractRequestMatcherRegistry; import org.springframework.security.config.annotation.web.builders.HttpSecurity; diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/ServletPatternRequestMatcherTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/ServletPatternRequestMatcherTests.java index 5c287a6936a..98a371bbc7b 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/ServletPatternRequestMatcherTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/ServletPatternRequestMatcherTests.java @@ -19,6 +19,7 @@ import org.junit.jupiter.api.Test; import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.security.config.TestMockHttpServletMappings; import static org.assertj.core.api.Assertions.assertThat;