From ed96e2cddf87caaa8e615dd34f333fe993869a87 Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Fri, 18 Aug 2023 15:11:33 -0600 Subject: [PATCH 1/2] Ignore Unmappable Servlets Closes gh-13666 --- .../web/AbstractRequestMatcherRegistry.java | 15 +++++++++++++-- .../security/config/MockServletContext.java | 11 ++++++++--- .../web/AbstractRequestMatcherRegistryTests.java | 15 +++++++++++++-- 3 files changed, 34 insertions(+), 7 deletions(-) diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java b/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java index 0df4204dbd1..9ea268e8c21 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java @@ -18,6 +18,7 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -312,8 +313,8 @@ public C requestMatchers(HttpMethod method, String... patterns) { if (servletContext == null) { return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns)); } - Map registrations = servletContext.getServletRegistrations(); - if (registrations == null) { + Map registrations = mappableServletRegistrations(servletContext); + if (registrations.isEmpty()) { return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns)); } if (!hasDispatcherServlet(registrations)) { @@ -324,6 +325,16 @@ public C requestMatchers(HttpMethod method, String... patterns) { return requestMatchers(createMvcMatchers(method, patterns).toArray(new RequestMatcher[0])); } + private Map mappableServletRegistrations(ServletContext servletContext) { + Map mappable = new LinkedHashMap<>(); + for (Map.Entry entry : servletContext.getServletRegistrations().entrySet()) { + if (!entry.getValue().getMappings().isEmpty()) { + mappable.put(entry.getKey(), entry.getValue()); + } + } + return mappable; + } + private boolean hasDispatcherServlet(Map registrations) { if (registrations == null) { return false; diff --git a/config/src/test/java/org/springframework/security/config/MockServletContext.java b/config/src/test/java/org/springframework/security/config/MockServletContext.java index aac3515ee35..df3ca415c7d 100644 --- a/config/src/test/java/org/springframework/security/config/MockServletContext.java +++ b/config/src/test/java/org/springframework/security/config/MockServletContext.java @@ -16,8 +16,10 @@ package org.springframework.security.config; +import java.util.Arrays; import java.util.Collection; import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.Map; import java.util.Set; @@ -35,7 +37,7 @@ public class MockServletContext extends org.springframework.mock.web.MockServlet public static MockServletContext mvc() { MockServletContext servletContext = new MockServletContext(); - servletContext.addServlet("dispatcherServlet", DispatcherServlet.class); + servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/"); return servletContext; } @@ -59,6 +61,8 @@ private static class MockServletRegistration implements ServletRegistration.Dyna private final Class clazz; + private final Set mappings = new LinkedHashSet<>(); + MockServletRegistration(String name, Class clazz) { this.name = name; this.clazz = clazz; @@ -91,12 +95,13 @@ public void setAsyncSupported(boolean isAsyncSupported) { @Override public Set addMapping(String... urlPatterns) { - return null; + this.mappings.addAll(Arrays.asList(urlPatterns)); + return this.mappings; } @Override public Collection getMappings() { - return null; + return this.mappings; } @Override diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java index a3125ce54e1..2ca3d8f25c8 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java @@ -211,12 +211,23 @@ public void requestMatchersWhenNoDispatcherServletThenAntPathRequestMatcherType( public void requestMatchersWhenAmbiguousServletsThenException() { MockServletContext servletContext = new MockServletContext(); given(this.context.getServletContext()).willReturn(servletContext); - servletContext.addServlet("dispatcherServlet", DispatcherServlet.class); - servletContext.addServlet("servletTwo", Servlet.class); + servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/"); + servletContext.addServlet("servletTwo", Servlet.class).addMapping("/servlet/**"); assertThatExceptionOfType(IllegalArgumentException.class) .isThrownBy(() -> this.matcherRegistry.requestMatchers("/**")); } + @Test + public void requestMatchersWhenUnmappableServletsThenSkips() { + MockServletContext servletContext = new MockServletContext(); + given(this.context.getServletContext()).willReturn(servletContext); + servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/"); + servletContext.addServlet("servletTwo", Servlet.class); + List requestMatchers = this.matcherRegistry.requestMatchers("/**"); + assertThat(requestMatchers).hasSize(1); + assertThat(requestMatchers.get(0)).isInstanceOf(MvcRequestMatcher.class); + } + private void mockMvcIntrospector(boolean isPresent) { ApplicationContext context = this.matcherRegistry.getApplicationContext(); given(context.containsBean("mvcHandlerMappingIntrospector")).willReturn(isPresent); From 28f98b3351dae6dd350a89d80a044848a8ed60f1 Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Sun, 20 Aug 2023 22:04:37 -0600 Subject: [PATCH 2/2] Improve Error Message Closes gh-13667 --- .../web/AbstractRequestMatcherRegistry.java | 23 ++++++++++++++++--- .../AbstractRequestMatcherRegistryTests.java | 4 +++- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java b/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java index 9ea268e8c21..7ec65d6ddf8 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java @@ -18,6 +18,7 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -320,14 +321,17 @@ public C requestMatchers(HttpMethod method, String... patterns) { if (!hasDispatcherServlet(registrations)) { return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns)); } - Assert.isTrue(registrations.size() == 1, - "This method cannot decide whether these patterns are Spring MVC patterns or not. If this endpoint is a Spring MVC endpoint, please use requestMatchers(MvcRequestMatcher); otherwise, please use requestMatchers(AntPathRequestMatcher)."); + if (registrations.size() > 1) { + String errorMessage = computeErrorMessage(registrations.values()); + throw new IllegalArgumentException(errorMessage); + } return requestMatchers(createMvcMatchers(method, patterns).toArray(new RequestMatcher[0])); } private Map mappableServletRegistrations(ServletContext servletContext) { Map mappable = new LinkedHashMap<>(); - for (Map.Entry entry : servletContext.getServletRegistrations().entrySet()) { + for (Map.Entry entry : servletContext.getServletRegistrations() + .entrySet()) { if (!entry.getValue().getMappings().isEmpty()) { mappable.put(entry.getKey(), entry.getValue()); } @@ -355,6 +359,19 @@ private boolean hasDispatcherServlet(Map return false; } + private String computeErrorMessage(Collection registrations) { + String template = "This method cannot decide whether these patterns are Spring MVC patterns or not. " + + "If this endpoint is a Spring MVC endpoint, please use requestMatchers(MvcRequestMatcher); " + + "otherwise, please use requestMatchers(AntPathRequestMatcher).\n\n" + + "This is because there is more than one mappable servlet in your servlet context: %s.\n\n" + + "For each MvcRequestMatcher, call MvcRequestMatcher#setServletPath to indicate the servlet path."; + Map> mappings = new LinkedHashMap<>(); + for (ServletRegistration registration : registrations) { + mappings.put(registration.getClassName(), registration.getMappings()); + } + return String.format(template, mappings); + } + /** *

* If the {@link HandlerMappingIntrospector} is available in the classpath, maps to an diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java index 2ca3d8f25c8..fa0a794151b 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java @@ -63,12 +63,13 @@ public O postProcess(O object) { private WebApplicationContext context; @BeforeEach - public void setUp() { + public void setUp() throws Exception { this.matcherRegistry = new TestRequestMatcherRegistry(); this.context = mock(WebApplicationContext.class); given(this.context.getBean(ObjectPostProcessor.class)).willReturn(NO_OP_OBJECT_POST_PROCESSOR); given(this.context.getServletContext()).willReturn(MockServletContext.mvc()); this.matcherRegistry.setApplicationContext(this.context); + mockMvcPresentClasspath(true); } @Test @@ -219,6 +220,7 @@ public void requestMatchersWhenAmbiguousServletsThenException() { @Test public void requestMatchersWhenUnmappableServletsThenSkips() { + mockMvcIntrospector(true); MockServletContext servletContext = new MockServletContext(); given(this.context.getServletContext()).willReturn(servletContext); servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/");