From 85c3d0ab13212b6df710207024b7240c71d2ebef Mon Sep 17 00:00:00 2001 From: Steve Riesenberg <5248162+sjohnr@users.noreply.github.com> Date: Wed, 28 Feb 2024 17:40:06 -0600 Subject: [PATCH] Add reactive support for OAuth 2.0 Token Exchange Grant Issue gh-5199 --- ...eactiveOAuth2AuthorizedClientProvider.java | 166 +++++ ...ctiveTokenExchangeTokenResponseClient.java | 83 +++ ...veOAuth2AuthorizedClientProviderTests.java | 389 +++++++++++ ...TokenExchangeTokenResponseClientTests.java | 654 ++++++++++++++++++ 4 files changed, 1292 insertions(+) create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/TokenExchangeReactiveOAuth2AuthorizedClientProvider.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveTokenExchangeTokenResponseClient.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/TokenExchangeReactiveOAuth2AuthorizedClientProviderTests.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveTokenExchangeTokenResponseClientTests.java diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/TokenExchangeReactiveOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/TokenExchangeReactiveOAuth2AuthorizedClientProvider.java new file mode 100644 index 00000000000..43e0607d2ea --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/TokenExchangeReactiveOAuth2AuthorizedClientProvider.java @@ -0,0 +1,166 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.function.Function; + +import reactor.core.publisher.Mono; + +import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.TokenExchangeGrantRequest; +import org.springframework.security.oauth2.client.endpoint.WebClientReactiveTokenExchangeTokenResponseClient; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.util.Assert; + +/** + * An implementation of an {@link ReactiveOAuth2AuthorizedClientProvider} for the + * {@link AuthorizationGrantType#TOKEN_EXCHANGE token-exchange} grant. + * + * @author Steve Riesenberg + * @since 6.3 + * @see ReactiveOAuth2AuthorizedClientProvider + * @see WebClientReactiveTokenExchangeTokenResponseClient + */ +public final class TokenExchangeReactiveOAuth2AuthorizedClientProvider + implements ReactiveOAuth2AuthorizedClientProvider { + + private ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient = new WebClientReactiveTokenExchangeTokenResponseClient(); + + private Function> subjectTokenResolver = this::resolveSubjectToken; + + private Function> actorTokenResolver = (context) -> Mono.empty(); + + private Duration clockSkew = Duration.ofSeconds(60); + + private Clock clock = Clock.systemUTC(); + + /** + * Attempt to authorize (or re-authorize) the + * {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided + * {@code context}. Returns an empty {@code Mono} if authorization (or + * re-authorization) is not supported, e.g. the client's + * {@link ClientRegistration#getAuthorizationGrantType() authorization grant type} is + * not {@link AuthorizationGrantType#TOKEN_EXCHANGE token-exchange} OR the + * {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired. + * @param context the context that holds authorization-specific state for the client + * @return the {@link OAuth2AuthorizedClient} or an empty {@code Mono} if + * authorization is not supported + */ + @Override + public Mono authorize(OAuth2AuthorizationContext context) { + Assert.notNull(context, "context cannot be null"); + ClientRegistration clientRegistration = context.getClientRegistration(); + if (!AuthorizationGrantType.TOKEN_EXCHANGE.equals(clientRegistration.getAuthorizationGrantType())) { + return Mono.empty(); + } + OAuth2AuthorizedClient authorizedClient = context.getAuthorizedClient(); + if (authorizedClient != null && !hasTokenExpired(authorizedClient.getAccessToken())) { + // If client is already authorized but access token is NOT expired than no + // need for re-authorization + return Mono.empty(); + } + + return this.subjectTokenResolver.apply(context) + .flatMap((subjectToken) -> this.actorTokenResolver.apply(context) + .map((actorToken) -> new TokenExchangeGrantRequest(clientRegistration, subjectToken, actorToken)) + .defaultIfEmpty(new TokenExchangeGrantRequest(clientRegistration, subjectToken, null))) + .flatMap(this.accessTokenResponseClient::getTokenResponse) + .onErrorMap(OAuth2AuthorizationException.class, + (ex) -> new ClientAuthorizationException(ex.getError(), clientRegistration.getRegistrationId(), ex)) + .map((tokenResponse) -> new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(), + tokenResponse.getAccessToken())); + } + + private Mono resolveSubjectToken(OAuth2AuthorizationContext context) { + // @formatter:off + return Mono.just(context) + .map((ctx) -> ctx.getPrincipal().getPrincipal()) + .filter((principal) -> principal instanceof OAuth2Token) + .cast(OAuth2Token.class); + // @formatter:on + } + + private boolean hasTokenExpired(OAuth2Token token) { + return this.clock.instant().isAfter(token.getExpiresAt().minus(this.clockSkew)); + } + + /** + * Sets the client used when requesting an access token credential at the Token + * Endpoint for the {@code token-exchange} grant. + * @param accessTokenResponseClient the client used when requesting an access token + * credential at the Token Endpoint for the {@code token-exchange} grant + */ + public void setAccessTokenResponseClient( + ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient) { + Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null"); + this.accessTokenResponseClient = accessTokenResponseClient; + } + + /** + * Sets the resolver used for resolving the {@link OAuth2Token subject token}. + * @param subjectTokenResolver the resolver used for resolving the {@link OAuth2Token + * subject token} + */ + public void setSubjectTokenResolver(Function> subjectTokenResolver) { + Assert.notNull(subjectTokenResolver, "subjectTokenResolver cannot be null"); + this.subjectTokenResolver = subjectTokenResolver; + } + + /** + * Sets the resolver used for resolving the {@link OAuth2Token actor token}. + * @param actorTokenResolver the resolver used for resolving the {@link OAuth2Token + * actor token} + */ + public void setActorTokenResolver(Function> actorTokenResolver) { + Assert.notNull(actorTokenResolver, "actorTokenResolver cannot be null"); + this.actorTokenResolver = actorTokenResolver; + } + + /** + * Sets the maximum acceptable clock skew, which is used when checking the + * {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is + * 60 seconds. + * + *

+ * An access token is considered expired if + * {@code OAuth2AccessToken#getExpiresAt() - clockSkew} is before the current time + * {@code clock#instant()}. + * @param clockSkew the maximum acceptable clock skew + */ + public void setClockSkew(Duration clockSkew) { + Assert.notNull(clockSkew, "clockSkew cannot be null"); + Assert.isTrue(clockSkew.getSeconds() >= 0, "clockSkew must be >= 0"); + this.clockSkew = clockSkew; + } + + /** + * Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking the access + * token expiry. + * @param clock the clock + */ + public void setClock(Clock clock) { + Assert.notNull(clock, "clock cannot be null"); + this.clock = clock; + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveTokenExchangeTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveTokenExchangeTokenResponseClient.java new file mode 100644 index 00000000000..abc9ad751b8 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveTokenExchangeTokenResponseClient.java @@ -0,0 +1,83 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import java.util.Set; + +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.web.reactive.function.BodyInserters; +import org.springframework.web.reactive.function.client.WebClient; + +/** + * The default implementation of an {@link ReactiveOAuth2AccessTokenResponseClient} for + * the {@link AuthorizationGrantType#TOKEN_EXCHANGE token-exchange} grant. This + * implementation uses {@link WebClient} when requesting an access token credential at the + * Authorization Server's Token Endpoint. + * + * @author Steve Riesenberg + * @since 6.3 + * @see ReactiveOAuth2AccessTokenResponseClient + * @see TokenExchangeGrantRequest + * @see OAuth2AccessToken + * @see Section + * 2.1 Request + * @see Section + * 2.2 Response + */ +public final class WebClientReactiveTokenExchangeTokenResponseClient + extends AbstractWebClientReactiveOAuth2AccessTokenResponseClient { + + private static final String ACCESS_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:access_token"; + + private static final String JWT_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:jwt"; + + @Override + ClientRegistration clientRegistration(TokenExchangeGrantRequest grantRequest) { + return grantRequest.getClientRegistration(); + } + + @Override + Set scopes(TokenExchangeGrantRequest grantRequest) { + return grantRequest.getClientRegistration().getScopes(); + } + + @Override + BodyInserters.FormInserter populateTokenRequestBody(TokenExchangeGrantRequest grantRequest, + BodyInserters.FormInserter body) { + super.populateTokenRequestBody(grantRequest, body); + body.with(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE); + OAuth2Token subjectToken = grantRequest.getSubjectToken(); + body.with(OAuth2ParameterNames.SUBJECT_TOKEN, subjectToken.getTokenValue()); + body.with(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, tokenType(subjectToken)); + OAuth2Token actorToken = grantRequest.getActorToken(); + if (actorToken != null) { + body.with(OAuth2ParameterNames.ACTOR_TOKEN, actorToken.getTokenValue()); + body.with(OAuth2ParameterNames.ACTOR_TOKEN_TYPE, tokenType(actorToken)); + } + return body; + } + + private static String tokenType(OAuth2Token token) { + return (token instanceof Jwt) ? JWT_TOKEN_TYPE_VALUE : ACCESS_TOKEN_TYPE_VALUE; + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/TokenExchangeReactiveOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/TokenExchangeReactiveOAuth2AuthorizedClientProviderTests.java new file mode 100644 index 00000000000..2b7250911f7 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/TokenExchangeReactiveOAuth2AuthorizedClientProviderTests.java @@ -0,0 +1,389 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.function.Function; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import reactor.core.publisher.Mono; + +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.TokenExchangeGrantRequest; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; + +/** + * Tests for {@link TokenExchangeReactiveOAuth2AuthorizedClientProvider}. + * + * @author Steve Riesenberg + */ +public class TokenExchangeReactiveOAuth2AuthorizedClientProviderTests { + + private TokenExchangeReactiveOAuth2AuthorizedClientProvider authorizedClientProvider; + + private ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient; + + private ClientRegistration clientRegistration; + + private OAuth2Token subjectToken; + + private OAuth2Token actorToken; + + private Authentication principal; + + @BeforeEach + public void setUp() { + this.authorizedClientProvider = new TokenExchangeReactiveOAuth2AuthorizedClientProvider(); + this.accessTokenResponseClient = mock(ReactiveOAuth2AccessTokenResponseClient.class); + this.authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); + // @formatter:off + this.clientRegistration = ClientRegistration.withRegistrationId("token-exchange") + .clientId("client-id") + .clientSecret("client-secret") + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC) + .authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE) + .scope("read", "write") + .tokenUri("https://example.com/oauth2/token") + .build(); + // @formatter:on + this.subjectToken = TestOAuth2AccessTokens.scopes("read", "write"); + this.actorToken = TestOAuth2AccessTokens.noScopes(); + this.principal = new TestingAuthenticationToken(this.subjectToken, this.subjectToken); + } + + @Test + public void setAccessTokenResponseClientWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null)) + .withMessage("accessTokenResponseClient cannot be null"); + // @formatter:on + } + + @Test + public void setSubjectTokenResolverWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setSubjectTokenResolver(null)) + .withMessage("subjectTokenResolver cannot be null"); + // @formatter:on + } + + @Test + public void setActorTokenResolverWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setActorTokenResolver(null)) + .withMessage("actorTokenResolver cannot be null"); + // @formatter:on + } + + @Test + public void setClockSkewWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setClockSkew(null)) + .withMessage("clockSkew cannot be null"); + // @formatter:on + } + + @Test + public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1))) + .withMessage("clockSkew must be >= 0"); + // @formatter:on + } + + @Test + public void setClockWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setClock(null)) + .withMessage("clock cannot be null"); + // @formatter:on + } + + @Test + public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.authorize(null).block()) + .withMessage("context cannot be null"); + // @formatter:on + } + + @Test + public void authorizeWhenNotTokenExchangeThenUnableToAuthorize() { + ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().build(); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(clientRegistration) + .principal(this.principal) + .build(); + // @formatter:on + assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); + verifyNoInteractions(this.accessTokenResponseClient); + } + + @Test + public void authorizeWhenTokenExchangeAndTokenNotExpiredThenNotReauthorized() { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), TestOAuth2AccessTokens.scopes("read", "write")); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on + assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); + verifyNoInteractions(this.accessTokenResponseClient); + } + + @Test + public void authorizeWhenInvalidRequestThenThrowClientAuthorizationException() { + given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))).willReturn( + Mono.error(new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST)))); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(this.principal) + .build(); + // @formatter:on + + // @formatter:off + assertThatExceptionOfType(ClientAuthorizationException.class) + .isThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext).block()) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST)) + .withMessageContaining("[invalid_request]"); + // @formatter:on + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(TokenExchangeGrantRequest.class); + verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture()); + TokenExchangeGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getSubjectToken()).isEqualTo(this.subjectToken); + assertThat(grantRequest.getActorToken()).isNull(); + } + + @Test + public void authorizeWhenTokenExchangeAndTokenExpiredThenReauthorized() { + Instant now = Instant.now(); + Instant issuedAt = now.minus(Duration.ofMinutes(60)); + Instant expiresAt = now.minus(Duration.ofMinutes(30)); + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token-1234", + issuedAt, expiresAt); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), accessToken); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))) + .willReturn(Mono.just(accessTokenResponse)); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on + OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext) + .block(); + assertThat(reauthorizedClient).isNotNull(); + assertThat(reauthorizedClient).isNotEqualTo(authorizedClient); + assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); + assertThat(reauthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(TokenExchangeGrantRequest.class); + verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture()); + TokenExchangeGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getSubjectToken()).isEqualTo(this.subjectToken); + assertThat(grantRequest.getActorToken()).isNull(); + } + + @Test + public void authorizeWhenTokenExchangeAndTokenNotExpiredButClockSkewForcesExpiryThenReauthorized() { + Instant now = Instant.now(); + Instant issuedAt = now.minus(Duration.ofMinutes(60)); + Instant expiresAt = now.plus(Duration.ofMinutes(1)); + OAuth2AccessToken expiresInOneMinAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "access-token-1234", issuedAt, expiresAt); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), expiresInOneMinAccessToken); + // Shorten the lifespan of the access token by 90 seconds, which will ultimately + // force it to expire on the client + this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(90)); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))) + .willReturn(Mono.just(accessTokenResponse)); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on + OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext) + .block(); + assertThat(reauthorizedClient).isNotNull(); + assertThat(reauthorizedClient).isNotEqualTo(authorizedClient); + assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); + assertThat(reauthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(TokenExchangeGrantRequest.class); + verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture()); + TokenExchangeGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getSubjectToken()).isEqualTo(this.subjectToken); + assertThat(grantRequest.getActorToken()).isNull(); + } + + @Test + public void authorizeWhenTokenExchangeAndNotAuthorizedAndSubjectTokenDoesNotResolveThenUnableToAuthorize() { + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(new TestingAuthenticationToken("user", "password")) + .build(); + // @formatter:on + assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); + verifyNoInteractions(this.accessTokenResponseClient); + } + + @Test + public void authorizeWhenTokenExchangeAndNotAuthorizedAndSubjectTokenResolvesThenAuthorized() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))) + .willReturn(Mono.just(accessTokenResponse)); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(this.principal) + .build(); + // @formatter:on + OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block(); + assertThat(authorizedClient).isNotNull(); + assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); + assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(TokenExchangeGrantRequest.class); + verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture()); + TokenExchangeGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getSubjectToken()).isEqualTo(this.subjectToken); + assertThat(grantRequest.getActorToken()).isNull(); + } + + @Test + public void authorizeWhenCustomSubjectTokenResolverSetThenCalled() { + Function> subjectTokenResolver = mock(Function.class); + given(subjectTokenResolver.apply(any(OAuth2AuthorizationContext.class))) + .willReturn(Mono.just(this.subjectToken)); + this.authorizedClientProvider.setSubjectTokenResolver(subjectTokenResolver); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))) + .willReturn(Mono.just(accessTokenResponse)); + TestingAuthenticationToken principal = new TestingAuthenticationToken("user", "password"); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(principal) + .build(); + // @formatter:on + OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block(); + assertThat(authorizedClient).isNotNull(); + assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(authorizedClient.getPrincipalName()).isEqualTo(principal.getName()); + assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + verify(subjectTokenResolver).apply(authorizationContext); + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(TokenExchangeGrantRequest.class); + verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture()); + TokenExchangeGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getSubjectToken()).isEqualTo(this.subjectToken); + assertThat(grantRequest.getActorToken()).isNull(); + } + + @Test + public void authorizeWhenCustomActorTokenResolverSetThenCalled() { + Function> actorTokenResolver = mock(Function.class); + given(actorTokenResolver.apply(any(OAuth2AuthorizationContext.class))).willReturn(Mono.just(this.actorToken)); + this.authorizedClientProvider.setActorTokenResolver(actorTokenResolver); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))) + .willReturn(Mono.just(accessTokenResponse)); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(this.principal) + .build(); + // @formatter:on + OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block(); + assertThat(authorizedClient).isNotNull(); + assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); + assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + verify(actorTokenResolver).apply(authorizationContext); + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(TokenExchangeGrantRequest.class); + verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture()); + TokenExchangeGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getSubjectToken()).isEqualTo(this.subjectToken); + assertThat(grantRequest.getActorToken()).isEqualTo(this.actorToken); + } + + @Test + public void authorizeWhenClockSetThenCalled() { + Clock clock = mock(Clock.class); + given(clock.instant()).willReturn(Instant.now()); + this.authorizedClientProvider.setClock(clock); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), TestOAuth2AccessTokens.noScopes()); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on + assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); + verify(clock).instant(); + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveTokenExchangeTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveTokenExchangeTokenResponseClientTests.java new file mode 100644 index 00000000000..c9bbfcf5178 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveTokenExchangeTokenResponseClientTests.java @@ -0,0 +1,654 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import java.io.IOException; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.util.Collections; + +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.ReactiveHttpInputMessage; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; +import org.springframework.security.oauth2.jwt.TestJwts; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; +import org.springframework.web.reactive.function.BodyExtractor; +import org.springframework.web.reactive.function.client.WebClient; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link WebClientReactiveTokenExchangeTokenResponseClient}. + * + * @author Steve Riesenberg + */ +public class WebClientReactiveTokenExchangeTokenResponseClientTests { + + private static final String ACCESS_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:access_token"; + + private static final String JWT_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:jwt"; + + private WebClientReactiveTokenExchangeTokenResponseClient tokenResponseClient; + + private ClientRegistration.Builder clientRegistration; + + private OAuth2Token subjectToken; + + private OAuth2Token actorToken; + + private MockWebServer server; + + @BeforeEach + public void setUp() throws IOException { + this.tokenResponseClient = new WebClientReactiveTokenExchangeTokenResponseClient(); + this.server = new MockWebServer(); + this.server.start(); + String tokenUri = this.server.url("/oauth2/token").toString(); + // @formatter:off + this.clientRegistration = TestClientRegistrations.clientCredentials() + .clientId("client-1") + .clientSecret("secret") + .authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE) + .tokenUri(tokenUri) + .scope("read", "write"); + // @formatter:on + this.subjectToken = TestOAuth2AccessTokens.scopes("read", "write"); + this.actorToken = null; + } + + @AfterEach + public void cleanUp() throws IOException { + this.server.shutdown(); + } + + @Test + public void setWebClientWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.setWebClient(null)) + .withMessage("webClient cannot be null"); + // @formatter:on + } + + @Test + public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.setHeadersConverter(null)) + .withMessage("headersConverter cannot be null"); + // @formatter:on + } + + @Test + public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.addHeadersConverter(null)) + .withMessage("headersConverter cannot be null"); + // @formatter:on + } + + @Test + public void setParametersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.setParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + // @formatter:on + } + + @Test + public void addParametersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.addParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + // @formatter:on + } + + @Test + public void setBodyExtractorWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.setBodyExtractor(null)) + .withMessage("bodyExtractor cannot be null"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenGrantRequestIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(null)) + .withMessage("grantRequest cannot be null"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + Instant expiresAtBefore = Instant.now().plusSeconds(3600); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest).block(); + assertThat(accessTokenResponse).isNotNull(); + Instant expiresAtAfter = Instant.now().plusSeconds(3600); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); + String formParameters = recordedRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.TOKEN_EXCHANGE.getValue()), + param(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SUBJECT_TOKEN, this.subjectToken.getTokenValue()), + param(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(this.clientRegistration.build().getScopes(), " ")) + ); + // @formatter:on + assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); + assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); + assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactlyInAnyOrder("read", "write"); + assertThat(accessTokenResponse.getRefreshToken()).isNull(); + } + + @Test + public void getTokenResponseWhenSubjectTokenIsJwtThenSubjectTokenTypeIsJwt() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + Instant expiresAtBefore = Instant.now().plusSeconds(3600); + this.subjectToken = TestJwts.jwt().build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest).block(); + assertThat(accessTokenResponse).isNotNull(); + Instant expiresAtAfter = Instant.now().plusSeconds(3600); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); + String formParameters = recordedRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.TOKEN_EXCHANGE.getValue()), + param(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SUBJECT_TOKEN, this.subjectToken.getTokenValue()), + param(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, JWT_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(this.clientRegistration.build().getScopes(), " ")) + ); + // @formatter:on + assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); + assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); + assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactlyInAnyOrder("read", "write"); + assertThat(accessTokenResponse.getRefreshToken()).isNull(); + } + + @Test + public void getTokenResponseWhenActorTokenIsNotNullThenActorParametersAreSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + Instant expiresAtBefore = Instant.now().plusSeconds(3600); + this.actorToken = TestOAuth2AccessTokens.noScopes(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest).block(); + assertThat(accessTokenResponse).isNotNull(); + Instant expiresAtAfter = Instant.now().plusSeconds(3600); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); + String formParameters = recordedRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.TOKEN_EXCHANGE.getValue()), + param(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SUBJECT_TOKEN, this.subjectToken.getTokenValue()), + param(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.ACTOR_TOKEN, this.actorToken.getTokenValue()), + param(OAuth2ParameterNames.ACTOR_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(this.clientRegistration.build().getScopes(), " ")) + ); + // @formatter:on + assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); + assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); + assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactlyInAnyOrder("read", "write"); + assertThat(accessTokenResponse.getRefreshToken()).isNull(); + } + + @Test + public void getTokenResponseWhenActorTokenIsJwtThenActorTokenTypeIsJwt() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + Instant expiresAtBefore = Instant.now().plusSeconds(3600); + this.actorToken = TestJwts.jwt().build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest).block(); + assertThat(accessTokenResponse).isNotNull(); + Instant expiresAtAfter = Instant.now().plusSeconds(3600); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); + String formParameters = recordedRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.TOKEN_EXCHANGE.getValue()), + param(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SUBJECT_TOKEN, this.subjectToken.getTokenValue()), + param(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.ACTOR_TOKEN, this.actorToken.getTokenValue()), + param(OAuth2ParameterNames.ACTOR_TOKEN_TYPE, JWT_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(this.clientRegistration.build().getScopes(), " ")) + ); + // @formatter:on + assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); + assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); + assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactlyInAnyOrder("read", "write"); + assertThat(accessTokenResponse.getRefreshToken()).isNull(); + } + + @Test + public void getTokenResponseWhenAuthenticationClientSecretBasicThenAuthorizationHeaderIsSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + this.tokenResponseClient.getTokenResponse(grantRequest).block(); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); + } + + @Test + public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParametersAreSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST) + .build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + this.tokenResponseClient.getTokenResponse(grantRequest).block(); + RecordedRequest recordedRequest = this.server.takeRequest(); + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters).contains("client_id=client-1", "client_secret=secret"); + } + + @Test + public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"not-bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(grantRequest).block()) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("invalid_token_response")) + .withMessageContaining("[invalid_token_response] An error occurred parsing the Access Token response") + .havingRootCause().withMessage("Unsupported token_type: not-bearer"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest).block(); + assertThat(accessTokenResponse).isNotNull(); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read"); + } + + @Test + public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessTokenHasNoScope() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest).block(); + assertThat(accessTokenResponse).isNotNull(); + assertThat(accessTokenResponse.getAccessToken().getScopes()).isEmpty(); + } + + @Test + public void getTokenResponseWhenInvalidResponseThenThrowOAuth2AuthorizationException() { + this.server.enqueue(new MockResponse().setResponseCode(301)); + TokenExchangeGrantRequest request = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(request).block()) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("invalid_token_response")) + .withMessage("[invalid_token_response] Empty OAuth 2.0 Access Token Response"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() { + String accessTokenErrorResponse = "{\"error\": \"server_error\", \"error_description\": \"A server error occurred\"}"; + this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(500)); + TokenExchangeGrantRequest request = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(request).block()) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo(OAuth2ErrorCodes.SERVER_ERROR)) + .withMessage("[server_error] A server error occurred"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() { + String accessTokenErrorResponse = "{\"error\": \"invalid_grant\", \"error_description\": \"Invalid grant\"}"; + this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400)); + TokenExchangeGrantRequest request = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(request).block()) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_GRANT)) + .withMessage("[invalid_grant] Invalid grant"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenCustomClientAuthenticationMethodThenIllegalArgument() { + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(new ClientAuthenticationMethod("basic")) + .build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + // @formatter:off + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(grantRequest).block()) + .withMessageContaining("This class supports `client_secret_basic`, `client_secret_post`, and `none` by default."); + // @formatter:on + } + + @Test + public void getTokenResponseWhenUnsupportedClientAuthenticationMethodThenIllegalArgument() { + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT) + .build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + // @formatter:off + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(grantRequest).block()) + .withMessageContaining("This class supports `client_secret_basic`, `client_secret_post`, and `none` by default."); + // @formatter:on + } + + @Test + public void getTokenResponseWhenHeadersConverterAddedThenCalled() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + Converter headersConverter = mock(Converter.class); + HttpHeaders headers = new HttpHeaders(); + headers.put("custom-header-name", Collections.singletonList("custom-header-value")); + given(headersConverter.convert(grantRequest)).willReturn(headers); + this.tokenResponseClient.addHeadersConverter(headersConverter); + this.tokenResponseClient.getTokenResponse(grantRequest).block(); + verify(headersConverter).convert(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); + assertThat(recordedRequest.getHeader("custom-header-name")).isEqualTo("custom-header-value"); + } + + @Test + public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + Converter headersConverter = mock(Converter.class); + HttpHeaders headers = new HttpHeaders(); + headers.put("custom-header-name", Collections.singletonList("custom-header-value")); + given(headersConverter.convert(grantRequest)).willReturn(headers); + this.tokenResponseClient.setHeadersConverter(headersConverter); + this.tokenResponseClient.getTokenResponse(grantRequest).block(); + verify(headersConverter).convert(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + assertThat(recordedRequest.getHeader("custom-header-name")).isEqualTo("custom-header-value"); + } + + @Test + public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + Converter> parametersConverter = mock(Converter.class); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.add("custom-parameter-name", "custom-parameter-value"); + given(parametersConverter.convert(grantRequest)).willReturn(parameters); + this.tokenResponseClient.setParametersConverter(parametersConverter); + this.tokenResponseClient.getTokenResponse(grantRequest).block(); + verify(parametersConverter).convert(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters).contains("custom-parameter-name=custom-parameter-value"); + } + + @Test + public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + Converter> parametersConverter = mock(Converter.class); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.add("custom-parameter-name", "custom-parameter-value"); + given(parametersConverter.convert(grantRequest)).willReturn(parameters); + this.tokenResponseClient.addParametersConverter(parametersConverter); + this.tokenResponseClient.getTokenResponse(grantRequest).block(); + verify(parametersConverter).convert(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + String formParameters = recordedRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.TOKEN_EXCHANGE.getValue()), + param(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SUBJECT_TOKEN, this.subjectToken.getTokenValue()), + param(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(this.clientRegistration.build().getScopes(), " ")), + param("custom-parameter-name", "custom-parameter-value") + ); + // @formatter:on + } + + @Test + public void getTokenResponseWhenBodyExtractorSetThenCalled() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + BodyExtractor, ReactiveHttpInputMessage> bodyExtractor = mock( + BodyExtractor.class); + OAuth2AccessTokenResponse response = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(bodyExtractor.extract(any(ReactiveHttpInputMessage.class), any(BodyExtractor.Context.class))) + .willReturn(Mono.just(response)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + this.tokenResponseClient.setBodyExtractor(bodyExtractor); + this.tokenResponseClient.getTokenResponse(grantRequest).block(); + verify(bodyExtractor).extract(any(ReactiveHttpInputMessage.class), any(BodyExtractor.Context.class)); + } + + @Test + public void getTokenResponseWhenWebClientSetThenCalled() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + WebClient customClient = mock(WebClient.class); + given(customClient.post()).willReturn(WebClient.builder().build().post()); + this.tokenResponseClient.setWebClient(customClient); + ClientRegistration clientRegistration = this.clientRegistration.build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + this.tokenResponseClient.getTokenResponse(grantRequest).block(); + verify(customClient).post(); + } + + private MockResponse jsonResponse(String json) { + return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json); + } + + private static String param(String parameterName, String parameterValue) { + return "%s=%s".formatted(parameterName, URLEncoder.encode(parameterValue, StandardCharsets.UTF_8)); + } + +}