diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/web/server/logout/OidcClientInitiatedServerLogoutSuccessHandler.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/web/server/logout/OidcClientInitiatedServerLogoutSuccessHandler.java index 3cc5754cac1..fc3d4fecf57 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/web/server/logout/OidcClientInitiatedServerLogoutSuccessHandler.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/web/server/logout/OidcClientInitiatedServerLogoutSuccessHandler.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,6 +23,7 @@ import reactor.core.publisher.Mono; +import org.springframework.core.convert.converter.Converter; import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; @@ -35,6 +36,7 @@ import org.springframework.security.web.server.authentication.logout.RedirectServerLogoutSuccessHandler; import org.springframework.security.web.server.authentication.logout.ServerLogoutSuccessHandler; import org.springframework.util.Assert; +import org.springframework.web.server.ServerWebExchange; import org.springframework.web.util.UriComponents; import org.springframework.web.util.UriComponentsBuilder; @@ -57,6 +59,8 @@ public class OidcClientInitiatedServerLogoutSuccessHandler implements ServerLogo private String postLogoutRedirectUri; + private Converter> redirectUriResolver = new DefaultRedirectUriResolver(); + /** * Constructs an {@link OidcClientInitiatedServerLogoutSuccessHandler} with the * provided parameters @@ -72,22 +76,11 @@ public OidcClientInitiatedServerLogoutSuccessHandler( @Override public Mono onLogoutSuccess(WebFilterExchange exchange, Authentication authentication) { + RedirectUriParameters redirectUriParameters = new RedirectUriParameters(); + redirectUriParameters.setAuthentication(authentication); + redirectUriParameters.setServerWebExchange(exchange.getExchange()); // @formatter:off - return Mono.just(authentication) - .filter(OAuth2AuthenticationToken.class::isInstance) - .filter((token) -> authentication.getPrincipal() instanceof OidcUser) - .map(OAuth2AuthenticationToken.class::cast) - .map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId) - .flatMap(this.clientRegistrationRepository::findByRegistrationId) - .flatMap((clientRegistration) -> { - URI endSessionEndpoint = endSessionEndpoint(clientRegistration); - if (endSessionEndpoint == null) { - return Mono.empty(); - } - String idToken = idToken(authentication); - String postLogoutRedirectUri = postLogoutRedirectUri(exchange.getExchange().getRequest(), clientRegistration); - return Mono.just(endpointUri(endSessionEndpoint, idToken, postLogoutRedirectUri)); - }) + return this.redirectUriResolver.convert(redirectUriParameters) .switchIfEmpty( this.serverLogoutSuccessHandler.onLogoutSuccess(exchange, authentication).then(Mono.empty()) ) @@ -189,4 +182,90 @@ public void setLogoutSuccessUrl(URI logoutSuccessUrl) { this.serverLogoutSuccessHandler.setLogoutSuccessUrl(logoutSuccessUrl); } + /** + * Set the {@link Converter} that converts {@link RedirectUriParameters} to redirect + * URI + * @param redirectUriResolver {@link Converter} + * @since 6.4 + */ + public void setRedirectUriResolver(Converter> redirectUriResolver) { + Assert.notNull(redirectUriResolver, "redirectUriResolver cannot be null"); + this.redirectUriResolver = redirectUriResolver; + } + + /** + * Parameters, required for redirect URI resolving. + * + * @author Max Batischev + * @since 6.4 + */ + public static class RedirectUriParameters { + + private ServerWebExchange serverWebExchange; + + private Authentication authentication; + + private ClientRegistration clientRegistration; + + public ServerWebExchange getServerWebExchange() { + return this.serverWebExchange; + } + + public Authentication getAuthentication() { + return this.authentication; + } + + public ClientRegistration getClientRegistration() { + return this.clientRegistration; + } + + public void setClientRegistration(ClientRegistration clientRegistration) { + Assert.notNull(clientRegistration, "clientRegistration cannot be null"); + this.clientRegistration = clientRegistration; + } + + public void setServerWebExchange(ServerWebExchange serverWebExchange) { + Assert.notNull(serverWebExchange, "serverWebExchange cannot be null"); + this.serverWebExchange = serverWebExchange; + } + + public void setAuthentication(Authentication authentication) { + Assert.notNull(authentication, "authentication cannot be null"); + this.authentication = authentication; + } + + } + + /** + * Default {@link Converter} for redirect uri resolving. + * + * @since 6.4 + */ + private final class DefaultRedirectUriResolver implements Converter> { + + @Override + public Mono convert(RedirectUriParameters redirectUriParameters) { + // @formatter:off + return Mono.just(redirectUriParameters.authentication) + .filter(OAuth2AuthenticationToken.class::isInstance) + .filter((token) -> redirectUriParameters.authentication.getPrincipal() instanceof OidcUser) + .map(OAuth2AuthenticationToken.class::cast) + .map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId) + .flatMap( + OidcClientInitiatedServerLogoutSuccessHandler.this.clientRegistrationRepository::findByRegistrationId) + .flatMap((clientRegistration) -> { + URI endSessionEndpoint = endSessionEndpoint(clientRegistration); + if (endSessionEndpoint == null) { + return Mono.empty(); + } + String idToken = idToken(redirectUriParameters.authentication); + String postLogoutRedirectUri = postLogoutRedirectUri( + redirectUriParameters.serverWebExchange.getRequest(), clientRegistration); + return Mono.just(endpointUri(endSessionEndpoint, idToken, postLogoutRedirectUri)); + }); + // @formatter:on + } + + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/web/server/logout/OidcClientInitiatedServerLogoutSuccessHandlerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/web/server/logout/OidcClientInitiatedServerLogoutSuccessHandlerTests.java index 300a815caf4..591ef091dae 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/web/server/logout/OidcClientInitiatedServerLogoutSuccessHandlerTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/web/server/logout/OidcClientInitiatedServerLogoutSuccessHandlerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ import java.io.IOException; import java.net.URI; import java.util.Collections; +import java.util.Objects; import jakarta.servlet.ServletException; import org.junit.jupiter.api.BeforeEach; @@ -199,6 +200,25 @@ public void setPostLogoutRedirectUriTemplateWhenGivenNullThenThrowsException() { assertThatIllegalArgumentException().isThrownBy(() -> this.handler.setPostLogoutRedirectUri((String) null)); } + @Test + public void logoutWhenCustomRedirectUriResolverSetThenRedirects() { + OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(TestOidcUsers.create(), + AuthorityUtils.NO_AUTHORITIES, this.registration.getRegistrationId()); + WebFilterExchange filterExchange = new WebFilterExchange(this.exchange, this.chain); + given(this.exchange.getRequest()) + .willReturn(MockServerHttpRequest.get("/").queryParam("location", "https://test.com").build()); + // @formatter:off + this.handler.setRedirectUriResolver((params) -> Mono.just( + Objects.requireNonNull(params.getServerWebExchange() + .getRequest() + .getQueryParams() + .getFirst("location")))); + // @formatter:on + this.handler.onLogoutSuccess(filterExchange, token).block(); + + assertThat(redirectedUrl(this.exchange)).isEqualTo("https://test.com"); + } + private String redirectedUrl(ServerWebExchange exchange) { return exchange.getResponse().getHeaders().getFirst("Location"); }