From cd1fdc65c23ea3de631d75cf3e00db1a01d5d5ea Mon Sep 17 00:00:00 2001 From: kzander91 Date: Tue, 19 Nov 2024 08:10:51 +0100 Subject: [PATCH] Always return current ClientRegistration in `loadAuthorizedClient` This changes `InMemoryOAuth2AuthorizedClientService.loadAuthorizedClient` (and its reactive counterpart) to always return `OAuth2AuthorizedClient` instances containing the current `ClientRegistration` as obtained from the `ClientRegistrationRepository`. Before this change, the first `ClientRegistration` instance was cached, with the effect that any changes made in the `ClientRegistrationRepository` (such as a new client secret) would not have taken effect. Closes gh-15511 --- ...InMemoryOAuth2AuthorizedClientService.java | 10 ++- ...ReactiveOAuth2AuthorizedClientService.java | 29 +++++--- ...oryOAuth2AuthorizedClientServiceTests.java | 72 ++++++++++++++++--- ...iveOAuth2AuthorizedClientServiceTests.java | 65 +++++++++++++++-- 4 files changed, 150 insertions(+), 26 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientService.java index 3041ce764f3..8ec6fdb17e4 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientService.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -80,7 +80,13 @@ public T loadAuthorizedClient(String clientRe if (registration == null) { return null; } - return (T) this.authorizedClients.get(new OAuth2AuthorizedClientId(clientRegistrationId, principalName)); + OAuth2AuthorizedClient cachedAuthorizedClient = this.authorizedClients + .get(new OAuth2AuthorizedClientId(clientRegistrationId, principalName)); + if (cachedAuthorizedClient == null) { + return null; + } + return (T) new OAuth2AuthorizedClient(registration, cachedAuthorizedClient.getPrincipalName(), + cachedAuthorizedClient.getAccessToken(), cachedAuthorizedClient.getRefreshToken()); } @Override diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientService.java index 3cf977d4779..f5cc5b35e9a 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientService.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,15 +16,14 @@ package org.springframework.security.oauth2.client; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; - -import reactor.core.publisher.Mono; - import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.util.Assert; +import reactor.core.publisher.Mono; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; /** * An {@link OAuth2AuthorizedClientService} that stores {@link OAuth2AuthorizedClient @@ -32,11 +31,11 @@ * * @author Rob Winch * @author Vedran Pavic - * @since 5.1 * @see OAuth2AuthorizedClientService * @see OAuth2AuthorizedClient * @see ClientRegistration * @see Authentication + * @since 5.1 */ public final class InMemoryReactiveOAuth2AuthorizedClientService implements ReactiveOAuth2AuthorizedClientService { @@ -47,6 +46,7 @@ public final class InMemoryReactiveOAuth2AuthorizedClientService implements Reac /** * Constructs an {@code InMemoryReactiveOAuth2AuthorizedClientService} using the * provided parameters. + * * @param clientRegistrationRepository the repository of client registrations */ public InMemoryReactiveOAuth2AuthorizedClientService( @@ -62,8 +62,19 @@ public Mono loadAuthorizedClient(String cl Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); Assert.hasText(principalName, "principalName cannot be empty"); return (Mono) this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId) - .map((clientRegistration) -> new OAuth2AuthorizedClientId(clientRegistrationId, principalName)) - .flatMap((identifier) -> Mono.justOrEmpty(this.authorizedClients.get(identifier))); + .mapNotNull((clientRegistration) -> { + OAuth2AuthorizedClientId id = new OAuth2AuthorizedClientId(clientRegistrationId, principalName); + OAuth2AuthorizedClient cachedAuthorizedClient = this.authorizedClients.get(id); + if (cachedAuthorizedClient == null) { + return null; + } + // @formatter:off + return new OAuth2AuthorizedClient(clientRegistration, + cachedAuthorizedClient.getPrincipalName(), + cachedAuthorizedClient.getAccessToken(), + cachedAuthorizedClient.getRefreshToken()); + // @formatter:on + }); } @Override diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientServiceTests.java index efa546b5d0c..ef4fd16b2a9 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientServiceTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,12 +28,9 @@ import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.OAuth2AccessToken; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; -import static org.assertj.core.api.Assertions.assertThatObject; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.BDDMockito.given; -import static org.mockito.Mockito.mock; +import static org.assertj.core.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.BDDMockito.*; /** * Tests for {@link InMemoryOAuth2AuthorizedClientService}. @@ -79,9 +76,11 @@ public void constructorWhenAuthorizedClientsIsNullThenThrowIllegalArgumentExcept @Test public void constructorWhenAuthorizedClientsProvidedThenUseProvidedAuthorizedClients() { String registrationId = this.registration3.getRegistrationId(); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration3, this.principalName1, + mock(OAuth2AccessToken.class)); Map authorizedClients = Collections.singletonMap( new OAuth2AuthorizedClientId(this.registration3.getRegistrationId(), this.principalName1), - mock(OAuth2AuthorizedClient.class)); + authorizedClient); ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class); given(clientRegistrationRepository.findByRegistrationId(eq(registrationId))).willReturn(this.registration3); InMemoryOAuth2AuthorizedClientService authorizedClientService = new InMemoryOAuth2AuthorizedClientService( @@ -124,7 +123,35 @@ public void loadAuthorizedClientWhenClientRegistrationFoundAndAssociatedToPrinci this.authorizedClientService.saveAuthorizedClient(authorizedClient, authentication); OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService .loadAuthorizedClient(this.registration1.getRegistrationId(), this.principalName1); - assertThat(loadedAuthorizedClient).isEqualTo(authorizedClient); + assertAuthorizedClientEquals(authorizedClient, loadedAuthorizedClient); + } + + @Test + public void loadAuthorizedClientWhenClientRegistrationIsUpdatedThenReturnAuthorizedClientWithUpdatedClientRegistration() { + ClientRegistration updatedRegistration = ClientRegistration.withClientRegistration(this.registration1) + .clientSecret("updated secret") + .build(); + ClientRegistrationRepository repository = mock(ClientRegistrationRepository.class); + given(repository.findByRegistrationId(this.registration1.getRegistrationId())).willReturn(this.registration1, + updatedRegistration); + + Authentication authentication = mock(Authentication.class); + given(authentication.getName()).willReturn(this.principalName1); + + InMemoryOAuth2AuthorizedClientService service = new InMemoryOAuth2AuthorizedClientService(repository); + + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration1, this.principalName1, + mock(OAuth2AccessToken.class)); + service.saveAuthorizedClient(authorizedClient, authentication); + + OAuth2AuthorizedClient authorizedClientWithUpdatedRegistration = new OAuth2AuthorizedClient(updatedRegistration, + this.principalName1, mock(OAuth2AccessToken.class)); + OAuth2AuthorizedClient firstLoadedClient = service.loadAuthorizedClient(this.registration1.getRegistrationId(), + this.principalName1); + OAuth2AuthorizedClient secondLoadedClient = service.loadAuthorizedClient(this.registration1.getRegistrationId(), + this.principalName1); + assertAuthorizedClientEquals(authorizedClient, firstLoadedClient); + assertAuthorizedClientEquals(authorizedClientWithUpdatedRegistration, secondLoadedClient); } @Test @@ -148,7 +175,7 @@ public void saveAuthorizedClientWhenSavedThenCanLoad() { this.authorizedClientService.saveAuthorizedClient(authorizedClient, authentication); OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService .loadAuthorizedClient(this.registration3.getRegistrationId(), this.principalName2); - assertThat(loadedAuthorizedClient).isEqualTo(authorizedClient); + assertAuthorizedClientEquals(authorizedClient, loadedAuthorizedClient); } @Test @@ -180,4 +207,29 @@ public void removeAuthorizedClientWhenSavedThenRemoved() { assertThat(loadedAuthorizedClient).isNull(); } + private static void assertAuthorizedClientEquals(OAuth2AuthorizedClient expected, OAuth2AuthorizedClient actual) { + assertThat(actual).isNotNull(); + assertThat(actual.getClientRegistration().getRegistrationId()) + .isEqualTo(expected.getClientRegistration().getRegistrationId()); + assertThat(actual.getClientRegistration().getClientName()) + .isEqualTo(expected.getClientRegistration().getClientName()); + assertThat(actual.getClientRegistration().getRedirectUri()) + .isEqualTo(expected.getClientRegistration().getRedirectUri()); + assertThat(actual.getClientRegistration().getAuthorizationGrantType()) + .isEqualTo(expected.getClientRegistration().getAuthorizationGrantType()); + assertThat(actual.getClientRegistration().getClientAuthenticationMethod()) + .isEqualTo(expected.getClientRegistration().getClientAuthenticationMethod()); + assertThat(actual.getClientRegistration().getClientId()) + .isEqualTo(expected.getClientRegistration().getClientId()); + assertThat(actual.getClientRegistration().getClientSecret()) + .isEqualTo(expected.getClientRegistration().getClientSecret()); + assertThat(actual.getPrincipalName()).isEqualTo(expected.getPrincipalName()); + assertThat(actual.getAccessToken().getTokenType()).isEqualTo(expected.getAccessToken().getTokenType()); + assertThat(actual.getAccessToken().getTokenValue()).isEqualTo(expected.getAccessToken().getTokenValue()); + assertThat(actual.getAccessToken().getIssuedAt()).isEqualTo(expected.getAccessToken().getIssuedAt()); + assertThat(actual.getAccessToken().getExpiresAt()).isEqualTo(expected.getAccessToken().getExpiresAt()); + assertThat(actual.getAccessToken().getScopes()).isEqualTo(expected.getAccessToken().getScopes()); + assertThat(actual.getRefreshToken()).isEqualTo(expected.getRefreshToken()); + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientServiceTests.java index 71a359b5ab3..b1231b70fd1 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientServiceTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,12 +18,14 @@ import java.time.Duration; import java.time.Instant; +import java.util.function.Consumer; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -35,8 +37,8 @@ import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; -import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; -import static org.mockito.BDDMockito.given; +import static org.assertj.core.api.Assertions.*; +import static org.mockito.BDDMockito.*; /** * @author Rob Winch @@ -60,7 +62,7 @@ public class InMemoryReactiveOAuth2AuthorizedClientServiceTests { Instant.now().plus(Duration.ofDays(1))); // @formatter:off - private ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(this.clientRegistrationId) + private final ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(this.clientRegistrationId) .redirectUri("{baseUrl}/{action}/oauth2/code/{registrationId}") .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC) .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) @@ -153,7 +155,33 @@ public void loadAuthorizedClientWhenClientRegistrationFoundThenFound() { .saveAuthorizedClient(authorizedClient, this.principal) .then(this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName)); StepVerifier.create(saveAndLoad) - .expectNext(authorizedClient) + .assertNext(isEqualTo(authorizedClient)) + .verifyComplete(); + // @formatter:on + } + + @Test + @SuppressWarnings("unchecked") + public void loadAuthorizedClientWhenClientRegistrationChangedThenCurrentVersionFound() { + ClientRegistration changedClientRegistration = ClientRegistration + .withClientRegistration(this.clientRegistration) + .clientSecret("updated secret") + .build(); + + given(this.clientRegistrationRepository.findByRegistrationId(this.clientRegistrationId)) + .willReturn(Mono.just(this.clientRegistration), Mono.just(changedClientRegistration)); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principalName, this.accessToken); + OAuth2AuthorizedClient authorizedClientWithChangedRegistration = new OAuth2AuthorizedClient( + changedClientRegistration, this.principalName, this.accessToken); + // @formatter:off + Flux saveAndLoadTwice = this.authorizedClientService + .saveAuthorizedClient(authorizedClient, this.principal) + .then(this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName)) + .concatWith(this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName)); + StepVerifier.create(saveAndLoadTwice) + .assertNext(isEqualTo(authorizedClient)) + .assertNext(isEqualTo(authorizedClientWithChangedRegistration)) .verifyComplete(); // @formatter:on } @@ -246,4 +274,31 @@ public void removeAuthorizedClientWhenClientRegistrationFoundRemovedThenNotFound // @formatter:on } + private static Consumer isEqualTo(OAuth2AuthorizedClient expected) { + return (actual) -> { + assertThat(actual).isNotNull(); + assertThat(actual.getClientRegistration().getRegistrationId()) + .isEqualTo(expected.getClientRegistration().getRegistrationId()); + assertThat(actual.getClientRegistration().getClientName()) + .isEqualTo(expected.getClientRegistration().getClientName()); + assertThat(actual.getClientRegistration().getRedirectUri()) + .isEqualTo(expected.getClientRegistration().getRedirectUri()); + assertThat(actual.getClientRegistration().getAuthorizationGrantType()) + .isEqualTo(expected.getClientRegistration().getAuthorizationGrantType()); + assertThat(actual.getClientRegistration().getClientAuthenticationMethod()) + .isEqualTo(expected.getClientRegistration().getClientAuthenticationMethod()); + assertThat(actual.getClientRegistration().getClientId()) + .isEqualTo(expected.getClientRegistration().getClientId()); + assertThat(actual.getClientRegistration().getClientSecret()) + .isEqualTo(expected.getClientRegistration().getClientSecret()); + assertThat(actual.getPrincipalName()).isEqualTo(expected.getPrincipalName()); + assertThat(actual.getAccessToken().getTokenType()).isEqualTo(expected.getAccessToken().getTokenType()); + assertThat(actual.getAccessToken().getTokenValue()).isEqualTo(expected.getAccessToken().getTokenValue()); + assertThat(actual.getAccessToken().getIssuedAt()).isEqualTo(expected.getAccessToken().getIssuedAt()); + assertThat(actual.getAccessToken().getExpiresAt()).isEqualTo(expected.getAccessToken().getExpiresAt()); + assertThat(actual.getAccessToken().getScopes()).isEqualTo(expected.getAccessToken().getScopes()); + assertThat(actual.getRefreshToken()).isEqualTo(expected.getRefreshToken()); + }; + } + }