diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java index 468afafad4e..789b0042ce2 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java @@ -16,9 +16,13 @@ package org.springframework.security.config.annotation.web.configurers.saml2; +import java.util.ArrayList; import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; +import jakarta.servlet.http.HttpServletRequest; + import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.context.ApplicationContext; import org.springframework.security.authentication.AuthenticationManager; @@ -33,6 +37,7 @@ import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrations; import org.springframework.security.saml2.provider.service.web.HttpSessionSaml2AuthenticationRequestRepository; import org.springframework.security.saml2.provider.service.web.OpenSamlAuthenticationTokenConverter; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository; @@ -50,6 +55,7 @@ import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.NegatedRequestMatcher; import org.springframework.security.web.util.matcher.OrRequestMatcher; +import org.springframework.security.web.util.matcher.ParameterRequestMatcher; import org.springframework.security.web.util.matcher.RequestHeaderRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.security.web.util.matcher.RequestMatchers; @@ -111,7 +117,13 @@ public final class Saml2LoginConfigurer> private String loginPage; - private String authenticationRequestUri = Saml2AuthenticationRequestResolver.DEFAULT_AUTHENTICATION_REQUEST_URI; + private String authenticationRequestUri = "/saml2/authenticate"; + + private String[] authenticationRequestParams = { "registrationId={registrationId}" }; + + private RequestMatcher authenticationRequestMatcher = RequestMatchers.anyOf( + new AntPathRequestMatcher(Saml2AuthenticationRequestResolver.DEFAULT_AUTHENTICATION_REQUEST_URI), + new AntPathQueryRequestMatcher(this.authenticationRequestUri, this.authenticationRequestParams)); private Saml2AuthenticationRequestResolver authenticationRequestResolver; @@ -196,11 +208,31 @@ public Saml2LoginConfigurer authenticationRequestResolver( * Request * @return the {@link Saml2LoginConfigurer} for further configuration * @since 6.0 + * @deprecated Use {@link #authenticationRequestUriQuery} instead */ public Saml2LoginConfigurer authenticationRequestUri(String authenticationRequestUri) { - Assert.state(authenticationRequestUri.contains("{registrationId}"), - "authenticationRequestUri must contain {registrationId} path variable"); - this.authenticationRequestUri = authenticationRequestUri; + return authenticationRequestUriQuery(authenticationRequestUri); + } + + /** + * Customize the URL that the SAML Authentication Request will be sent to. This method + * also supports query parameters like so:
+	 * 	authenticationRequestUriQuery("/saml/authenticate?registrationId={registrationId}")
+	 * 
{@link RelyingPartyRegistrations} + * @param authenticationRequestUriQuery the URI and query to use for the SAML 2.0 + * Authentication Request + * @return the {@link Saml2LoginConfigurer} for further configuration + * @since 6.0 + */ + public Saml2LoginConfigurer authenticationRequestUriQuery(String authenticationRequestUriQuery) { + Assert.state(authenticationRequestUriQuery.contains("{registrationId}"), + "authenticationRequestUri must contain {registrationId} path variable or query value"); + String[] parts = authenticationRequestUriQuery.split("[?&]"); + this.authenticationRequestUri = parts[0]; + this.authenticationRequestParams = new String[parts.length - 1]; + System.arraycopy(parts, 1, this.authenticationRequestParams, 0, parts.length - 1); + this.authenticationRequestMatcher = new AntPathQueryRequestMatcher(this.authenticationRequestUri, + this.authenticationRequestParams); return this; } @@ -255,7 +287,7 @@ public void init(B http) throws Exception { } else { Map providerUrlMap = getIdentityProviderUrlMap(this.authenticationRequestUri, - this.relyingPartyRegistrationRepository); + this.authenticationRequestParams, this.relyingPartyRegistrationRepository); boolean singleProvider = providerUrlMap.size() == 1; if (singleProvider) { // Setup auto-redirect to provider login page @@ -336,8 +368,7 @@ private Saml2AuthenticationRequestResolver getAuthenticationRequestResolver(B ht } OpenSaml4AuthenticationRequestResolver openSaml4AuthenticationRequestResolver = new OpenSaml4AuthenticationRequestResolver( relyingPartyRegistrationRepository(http)); - openSaml4AuthenticationRequestResolver - .setRequestMatcher(new AntPathRequestMatcher(this.authenticationRequestUri)); + openSaml4AuthenticationRequestResolver.setRequestMatcher(this.authenticationRequestMatcher); return openSaml4AuthenticationRequestResolver; } @@ -382,20 +413,28 @@ private void initDefaultLoginFilter(B http) { return; } loginPageGeneratingFilter.setSaml2LoginEnabled(true); - loginPageGeneratingFilter.setSaml2AuthenticationUrlToProviderName( - this.getIdentityProviderUrlMap(this.authenticationRequestUri, this.relyingPartyRegistrationRepository)); + loginPageGeneratingFilter + .setSaml2AuthenticationUrlToProviderName(this.getIdentityProviderUrlMap(this.authenticationRequestUri, + this.authenticationRequestParams, this.relyingPartyRegistrationRepository)); loginPageGeneratingFilter.setLoginPageUrl(this.getLoginPage()); loginPageGeneratingFilter.setFailureUrl(this.getFailureUrl()); } @SuppressWarnings("unchecked") - private Map getIdentityProviderUrlMap(String authRequestPrefixUrl, + private Map getIdentityProviderUrlMap(String authRequestPrefixUrl, String[] authRequestQueryParams, RelyingPartyRegistrationRepository idpRepo) { Map idps = new LinkedHashMap<>(); if (idpRepo instanceof Iterable) { Iterable repo = (Iterable) idpRepo; - repo.forEach((p) -> idps.put(authRequestPrefixUrl.replace("{registrationId}", p.getRegistrationId()), - p.getRegistrationId())); + StringBuilder authRequestQuery = new StringBuilder("?"); + for (String authRequestQueryParam : authRequestQueryParams) { + authRequestQuery.append(authRequestQueryParam + "&"); + } + authRequestQuery.deleteCharAt(authRequestQuery.length() - 1); + String authenticationRequestUriQuery = authRequestPrefixUrl + authRequestQuery; + repo.forEach( + (p) -> idps.put(authenticationRequestUriQuery.replace("{registrationId}", p.getRegistrationId()), + p.getRegistrationId())); } return idps; } @@ -437,4 +476,35 @@ private void setSharedObject(B http, Class clazz, C object) { } } + static class AntPathQueryRequestMatcher implements RequestMatcher { + + private final RequestMatcher matcher; + + AntPathQueryRequestMatcher(String path, String... params) { + List matchers = new ArrayList<>(); + matchers.add(new AntPathRequestMatcher(path)); + for (String param : params) { + String[] parts = param.split("="); + if (parts.length == 1) { + matchers.add(new ParameterRequestMatcher(parts[0])); + } + else { + matchers.add(new ParameterRequestMatcher(parts[0], parts[1])); + } + } + this.matcher = new AndRequestMatcher(matchers); + } + + @Override + public boolean matches(HttpServletRequest request) { + return matcher(request).isMatch(); + } + + @Override + public MatchResult matcher(HttpServletRequest request) { + return this.matcher.matcher(request); + } + + } + } diff --git a/config/src/main/kotlin/org/springframework/security/config/annotation/web/Saml2Dsl.kt b/config/src/main/kotlin/org/springframework/security/config/annotation/web/Saml2Dsl.kt index 810bf54447a..e8f52dd44ac 100644 --- a/config/src/main/kotlin/org/springframework/security/config/annotation/web/Saml2Dsl.kt +++ b/config/src/main/kotlin/org/springframework/security/config/annotation/web/Saml2Dsl.kt @@ -48,6 +48,7 @@ import org.springframework.security.web.authentication.AuthenticationSuccessHand class Saml2Dsl { var relyingPartyRegistrationRepository: RelyingPartyRegistrationRepository? = null var loginPage: String? = null + var authenticationRequestUriQuery: String? = null var authenticationSuccessHandler: AuthenticationSuccessHandler? = null var authenticationFailureHandler: AuthenticationFailureHandler? = null var failureUrl: String? = null @@ -88,6 +89,9 @@ class Saml2Dsl { defaultSuccessUrlOption?.also { saml2Login.defaultSuccessUrl(defaultSuccessUrlOption!!.first, defaultSuccessUrlOption!!.second) } + authenticationRequestUriQuery?.also { + saml2Login.authenticationRequestUriQuery(authenticationRequestUriQuery) + } authenticationSuccessHandler?.also { saml2Login.successHandler(authenticationSuccessHandler) } authenticationFailureHandler?.also { saml2Login.failureHandler(authenticationFailureHandler) } authenticationManager?.also { saml2Login.authenticationManager(authenticationManager) } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java index 461f030e9f2..c637c70a409 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java @@ -101,6 +101,7 @@ import org.springframework.web.util.UriComponentsBuilder; import static org.assertj.core.api.Assertions.assertThat; +import static org.hamcrest.Matchers.startsWith; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.atLeastOnce; @@ -113,6 +114,7 @@ import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; @@ -343,6 +345,19 @@ public void authenticationRequestWhenCustomAuthenticationRequestUriRepositoryThe any(HttpServletRequest.class), any(HttpServletResponse.class)); } + @Test + public void authenticationRequestWhenCustomAuthenticationRequestPathRepositoryThenUses() throws Exception { + this.spring.register(CustomAuthenticationRequestUriQuery.class).autowire(); + MockHttpServletRequestBuilder request = get("/custom/auth/sso"); + this.mvc.perform(request) + .andExpect(status().isFound()) + .andExpect(redirectedUrl("http://localhost/custom/auth/sso?entityId=registration-id")); + request.queryParam("entityId", registration.getRegistrationId()); + MvcResult result = this.mvc.perform(request).andExpect(status().isFound()).andReturn(); + String redirectedUrl = result.getResponse().getRedirectedUrl(); + assertThat(redirectedUrl).startsWith(registration.getAssertingPartyDetails().getSingleSignOnServiceLocation()); + } + @Test public void saml2LoginWhenLoginProcessingUrlWithoutRegistrationIdAndDefaultAuthenticationConverterThenAutowires() throws Exception { @@ -390,7 +405,7 @@ public void getFaviconWhenDefaultConfigurationThenDoesNotSaveAuthnRequest() thro .andExpect(redirectedUrl("http://localhost/login")); this.mvc.perform(get("/").accept(MediaType.TEXT_HTML)) .andExpect(status().isFound()) - .andExpect(redirectedUrl("http://localhost/saml2/authenticate/registration-id")); + .andExpect(header().string("Location", startsWith("http://localhost/saml2/authenticate"))); } @Test @@ -669,6 +684,23 @@ SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception { } + @Configuration + @EnableWebSecurity + @Import(Saml2LoginConfigBeans.class) + static class CustomAuthenticationRequestUriQuery { + + @Bean + SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeHttpRequests((authz) -> authz.anyRequest().authenticated()) + .saml2Login((saml2) -> saml2.authenticationRequestUriQuery("/custom/auth/sso?entityId={registrationId}")); + // @formatter:on + return http.build(); + } + + } + @Configuration @EnableWebSecurity @Import(Saml2LoginConfigBeans.class) diff --git a/config/src/test/kotlin/org/springframework/security/config/annotation/web/Saml2DslTests.kt b/config/src/test/kotlin/org/springframework/security/config/annotation/web/Saml2DslTests.kt index cf7d716e43d..40b88fbc18c 100644 --- a/config/src/test/kotlin/org/springframework/security/config/annotation/web/Saml2DslTests.kt +++ b/config/src/test/kotlin/org/springframework/security/config/annotation/web/Saml2DslTests.kt @@ -43,11 +43,13 @@ import org.springframework.security.saml2.provider.service.registration.TestRely import org.springframework.security.saml2.provider.service.web.authentication.Saml2WebSsoAuthenticationFilter import org.springframework.security.web.SecurityFilterChain import org.springframework.test.web.servlet.MockMvc +import org.springframework.test.web.servlet.MvcResult import org.springframework.test.web.servlet.get import org.springframework.test.web.servlet.request.MockMvcRequestBuilders +import org.springframework.test.web.servlet.result.MockMvcResultMatchers import java.security.cert.Certificate import java.security.cert.CertificateFactory -import java.util.Base64 +import java.util.* /** * Tests for [Saml2Dsl] @@ -136,6 +138,23 @@ class Saml2DslTests { verify(exactly = 1) { Saml2LoginCustomAuthenticationManagerConfig.AUTHENTICATION_MANAGER.authenticate(any()) } } + @Test + @Throws(Exception::class) + fun authenticationRequestWhenCustomAuthenticationRequestPathRepositoryThenUses() { + this.spring.register(CustomAuthenticationRequestUriQuery::class.java).autowire() + val registration = TestRelyingPartyRegistrations.relyingPartyRegistration().build(); + val request = MockMvcRequestBuilders.get("/custom/auth/sso") + this.mockMvc.perform(request) + .andExpect(MockMvcResultMatchers.status().isFound()) + .andExpect(MockMvcResultMatchers.redirectedUrl("http://localhost/custom/auth/sso?entityId=simplesamlphp")) + request.queryParam("entityId", registration.registrationId) + val result: MvcResult = + this.mockMvc.perform(request).andExpect(MockMvcResultMatchers.status().isFound()).andReturn() + val redirectedUrl = result.response.redirectedUrl + Assertions.assertThat(redirectedUrl) + .startsWith(registration.assertingPartyDetails.singleSignOnServiceLocation) + } + @Configuration @EnableWebSecurity open class Saml2LoginCustomAuthenticationManagerConfig { @@ -162,4 +181,26 @@ class Saml2DslTests { return repository } } + + @Configuration + @EnableWebSecurity + open class CustomAuthenticationRequestUriQuery { + @Bean + open fun securityFilterChain(http: HttpSecurity): SecurityFilterChain { + http { + authorizeHttpRequests { + authorize(anyRequest, authenticated) + } + saml2Login { + authenticationRequestUriQuery = "/custom/auth/sso?entityId={registrationId}" + } + } + return http.build() + } + + @Bean + open fun relyingPartyRegistrationRepository(): RelyingPartyRegistrationRepository? { + return InMemoryRelyingPartyRegistrationRepository(TestRelyingPartyRegistrations.relyingPartyRegistration().build()) + } + } } diff --git a/docs/modules/ROOT/pages/servlet/saml2/login/authentication-requests.adoc b/docs/modules/ROOT/pages/servlet/saml2/login/authentication-requests.adoc index 4e0ec21d32e..6e4e49d891f 100644 --- a/docs/modules/ROOT/pages/servlet/saml2/login/authentication-requests.adoc +++ b/docs/modules/ROOT/pages/servlet/saml2/login/authentication-requests.adoc @@ -4,7 +4,7 @@ As stated earlier, Spring Security's SAML 2.0 support produces a `` to commence authentication with the asserting party. Spring Security achieves this in part by registering the `Saml2WebSsoAuthenticationRequestFilter` in the filter chain. -This filter by default responds to endpoint `+/saml2/authenticate/{registrationId}+`. +This filter by default responds to the endpoints `+/saml2/authenticate/{registrationId}+` and `+/saml2/authenticate?registrationId={registrationId}+`. For example, if you were deployed to `https://rp.example.com` and you gave your registration an ID of `okta`, you could navigate to: @@ -12,6 +12,42 @@ For example, if you were deployed to `https://rp.example.com` and you gave your and the result would be a redirect that included a `SAMLRequest` parameter containing the signed, deflated, and encoded ``. +== Configuring the `` Endpoint + +To configure the endpoint differently from the default, you can set the value in `saml2Login`: + +[tabs] +====== +Java:: ++ +[source,java,role="primary"] +---- +@Bean +SecurityFilterChain filterChain(HttpSecurity http) { + http + .saml2Login((saml2) -> saml2 + .authenticationRequestUriQuery("/custom/auth/sso?peerEntityID={registrationId}") + ); + return new CustomSaml2AuthenticationRequestRepository(); +} +---- + +Kotlin:: ++ +[source,kotlin,role="secondary"] +---- +@Bean +fun filterChain(http: HttpSecurity): SecurityFilterChain { + http { + saml2Login { + authenticationRequestUriQuery = "/custom/auth/sso?peerEntityID={registrationId}" + } + } + return CustomSaml2AuthenticationRequestRepository() +} +---- +====== + [[servlet-saml2login-store-authn-request]] == Changing How the `` Gets Stored diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java index 39cd8bafb21..85ed7ae877a 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java @@ -17,6 +17,8 @@ package org.springframework.security.saml2.provider.service.web.authentication; import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; import java.util.Map; import java.util.UUID; import java.util.function.BiConsumer; @@ -50,8 +52,11 @@ import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers; import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers.UriResolver; import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver; +import org.springframework.security.web.util.matcher.AndRequestMatcher; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; +import org.springframework.security.web.util.matcher.ParameterRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; +import org.springframework.security.web.util.matcher.RequestMatchers; import org.springframework.util.Assert; /** @@ -75,8 +80,9 @@ class OpenSamlAuthenticationRequestResolver { private final NameIDPolicyBuilder nameIdPolicyBuilder; - private RequestMatcher requestMatcher = new AntPathRequestMatcher( - Saml2AuthenticationRequestResolver.DEFAULT_AUTHENTICATION_REQUEST_URI); + private RequestMatcher requestMatcher = RequestMatchers.anyOf( + new AntPathRequestMatcher(Saml2AuthenticationRequestResolver.DEFAULT_AUTHENTICATION_REQUEST_URI), + new AntPathQueryRequestMatcher("/saml2/authenticate", "registrationId={registrationId}")); private Converter relayStateResolver = (request) -> UUID.randomUUID().toString(); @@ -199,4 +205,35 @@ private String serialize(AuthnRequest authnRequest) { } } + private static final class AntPathQueryRequestMatcher implements RequestMatcher { + + private final RequestMatcher matcher; + + AntPathQueryRequestMatcher(String path, String... params) { + List matchers = new ArrayList<>(); + matchers.add(new AntPathRequestMatcher(path)); + for (String param : params) { + String[] parts = param.split("="); + if (parts.length == 1) { + matchers.add(new ParameterRequestMatcher(parts[0])); + } + else { + matchers.add(new ParameterRequestMatcher(parts[0], parts[1])); + } + } + this.matcher = new AndRequestMatcher(matchers); + } + + @Override + public boolean matches(HttpServletRequest request) { + return matcher(request).isMatch(); + } + + @Override + public MatchResult matcher(HttpServletRequest request) { + return this.matcher.matcher(request); + } + + } + }