Skip to content

Commit

Permalink
Always return current ClientRegistration in loadAuthorizedClient
Browse files Browse the repository at this point in the history
This changes `InMemoryOAuth2AuthorizedClientService.loadAuthorizedClient`
(and its reactive counterpart) to always return `OAuth2AuthorizedClient`
instances containing the current `ClientRegistration` as obtained from
the `ClientRegistrationRepository`.

Before this change, the first `ClientRegistration` instance was cached,
with the effect that any changes made in the `ClientRegistrationRepository`
(such as a new client secret) would not have taken effect.

Closes gh-15511
  • Loading branch information
kzander91 committed Nov 19, 2024
1 parent 30c9860 commit cd1fdc6
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 26 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -80,7 +80,13 @@ public <T extends OAuth2AuthorizedClient> T loadAuthorizedClient(String clientRe
if (registration == null) {
return null;
}
return (T) this.authorizedClients.get(new OAuth2AuthorizedClientId(clientRegistrationId, principalName));
OAuth2AuthorizedClient cachedAuthorizedClient = this.authorizedClients
.get(new OAuth2AuthorizedClientId(clientRegistrationId, principalName));
if (cachedAuthorizedClient == null) {
return null;
}
return (T) new OAuth2AuthorizedClient(registration, cachedAuthorizedClient.getPrincipalName(),
cachedAuthorizedClient.getAccessToken(), cachedAuthorizedClient.getRefreshToken());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,27 +16,26 @@

package org.springframework.security.oauth2.client;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import reactor.core.publisher.Mono;

import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.util.Assert;
import reactor.core.publisher.Mono;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
* An {@link OAuth2AuthorizedClientService} that stores {@link OAuth2AuthorizedClient
* Authorized Client(s)} in-memory.
*
* @author Rob Winch
* @author Vedran Pavic
* @since 5.1
* @see OAuth2AuthorizedClientService
* @see OAuth2AuthorizedClient
* @see ClientRegistration
* @see Authentication
* @since 5.1
*/
public final class InMemoryReactiveOAuth2AuthorizedClientService implements ReactiveOAuth2AuthorizedClientService {

Expand All @@ -47,6 +46,7 @@ public final class InMemoryReactiveOAuth2AuthorizedClientService implements Reac
/**
* Constructs an {@code InMemoryReactiveOAuth2AuthorizedClientService} using the
* provided parameters.
*
* @param clientRegistrationRepository the repository of client registrations
*/
public InMemoryReactiveOAuth2AuthorizedClientService(
Expand All @@ -62,8 +62,19 @@ public <T extends OAuth2AuthorizedClient> Mono<T> loadAuthorizedClient(String cl
Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
Assert.hasText(principalName, "principalName cannot be empty");
return (Mono<T>) this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
.map((clientRegistration) -> new OAuth2AuthorizedClientId(clientRegistrationId, principalName))
.flatMap((identifier) -> Mono.justOrEmpty(this.authorizedClients.get(identifier)));
.mapNotNull((clientRegistration) -> {
OAuth2AuthorizedClientId id = new OAuth2AuthorizedClientId(clientRegistrationId, principalName);
OAuth2AuthorizedClient cachedAuthorizedClient = this.authorizedClients.get(id);
if (cachedAuthorizedClient == null) {
return null;
}
// @formatter:off
return new OAuth2AuthorizedClient(clientRegistration,
cachedAuthorizedClient.getPrincipalName(),
cachedAuthorizedClient.getAccessToken(),
cachedAuthorizedClient.getRefreshToken());
// @formatter:on
});
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -28,12 +28,9 @@
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.OAuth2AccessToken;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.assertj.core.api.Assertions.assertThatObject;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.assertj.core.api.Assertions.*;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.BDDMockito.*;

/**
* Tests for {@link InMemoryOAuth2AuthorizedClientService}.
Expand Down Expand Up @@ -79,9 +76,11 @@ public void constructorWhenAuthorizedClientsIsNullThenThrowIllegalArgumentExcept
@Test
public void constructorWhenAuthorizedClientsProvidedThenUseProvidedAuthorizedClients() {
String registrationId = this.registration3.getRegistrationId();
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration3, this.principalName1,
mock(OAuth2AccessToken.class));
Map<OAuth2AuthorizedClientId, OAuth2AuthorizedClient> authorizedClients = Collections.singletonMap(
new OAuth2AuthorizedClientId(this.registration3.getRegistrationId(), this.principalName1),
mock(OAuth2AuthorizedClient.class));
authorizedClient);
ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class);
given(clientRegistrationRepository.findByRegistrationId(eq(registrationId))).willReturn(this.registration3);
InMemoryOAuth2AuthorizedClientService authorizedClientService = new InMemoryOAuth2AuthorizedClientService(
Expand Down Expand Up @@ -124,7 +123,35 @@ public void loadAuthorizedClientWhenClientRegistrationFoundAndAssociatedToPrinci
this.authorizedClientService.saveAuthorizedClient(authorizedClient, authentication);
OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService
.loadAuthorizedClient(this.registration1.getRegistrationId(), this.principalName1);
assertThat(loadedAuthorizedClient).isEqualTo(authorizedClient);
assertAuthorizedClientEquals(authorizedClient, loadedAuthorizedClient);
}

@Test
public void loadAuthorizedClientWhenClientRegistrationIsUpdatedThenReturnAuthorizedClientWithUpdatedClientRegistration() {
ClientRegistration updatedRegistration = ClientRegistration.withClientRegistration(this.registration1)
.clientSecret("updated secret")
.build();
ClientRegistrationRepository repository = mock(ClientRegistrationRepository.class);
given(repository.findByRegistrationId(this.registration1.getRegistrationId())).willReturn(this.registration1,
updatedRegistration);

Authentication authentication = mock(Authentication.class);
given(authentication.getName()).willReturn(this.principalName1);

InMemoryOAuth2AuthorizedClientService service = new InMemoryOAuth2AuthorizedClientService(repository);

OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration1, this.principalName1,
mock(OAuth2AccessToken.class));
service.saveAuthorizedClient(authorizedClient, authentication);

OAuth2AuthorizedClient authorizedClientWithUpdatedRegistration = new OAuth2AuthorizedClient(updatedRegistration,
this.principalName1, mock(OAuth2AccessToken.class));
OAuth2AuthorizedClient firstLoadedClient = service.loadAuthorizedClient(this.registration1.getRegistrationId(),
this.principalName1);
OAuth2AuthorizedClient secondLoadedClient = service.loadAuthorizedClient(this.registration1.getRegistrationId(),
this.principalName1);
assertAuthorizedClientEquals(authorizedClient, firstLoadedClient);
assertAuthorizedClientEquals(authorizedClientWithUpdatedRegistration, secondLoadedClient);
}

@Test
Expand All @@ -148,7 +175,7 @@ public void saveAuthorizedClientWhenSavedThenCanLoad() {
this.authorizedClientService.saveAuthorizedClient(authorizedClient, authentication);
OAuth2AuthorizedClient loadedAuthorizedClient = this.authorizedClientService
.loadAuthorizedClient(this.registration3.getRegistrationId(), this.principalName2);
assertThat(loadedAuthorizedClient).isEqualTo(authorizedClient);
assertAuthorizedClientEquals(authorizedClient, loadedAuthorizedClient);
}

@Test
Expand Down Expand Up @@ -180,4 +207,29 @@ public void removeAuthorizedClientWhenSavedThenRemoved() {
assertThat(loadedAuthorizedClient).isNull();
}

private static void assertAuthorizedClientEquals(OAuth2AuthorizedClient expected, OAuth2AuthorizedClient actual) {
assertThat(actual).isNotNull();
assertThat(actual.getClientRegistration().getRegistrationId())
.isEqualTo(expected.getClientRegistration().getRegistrationId());
assertThat(actual.getClientRegistration().getClientName())
.isEqualTo(expected.getClientRegistration().getClientName());
assertThat(actual.getClientRegistration().getRedirectUri())
.isEqualTo(expected.getClientRegistration().getRedirectUri());
assertThat(actual.getClientRegistration().getAuthorizationGrantType())
.isEqualTo(expected.getClientRegistration().getAuthorizationGrantType());
assertThat(actual.getClientRegistration().getClientAuthenticationMethod())
.isEqualTo(expected.getClientRegistration().getClientAuthenticationMethod());
assertThat(actual.getClientRegistration().getClientId())
.isEqualTo(expected.getClientRegistration().getClientId());
assertThat(actual.getClientRegistration().getClientSecret())
.isEqualTo(expected.getClientRegistration().getClientSecret());
assertThat(actual.getPrincipalName()).isEqualTo(expected.getPrincipalName());
assertThat(actual.getAccessToken().getTokenType()).isEqualTo(expected.getAccessToken().getTokenType());
assertThat(actual.getAccessToken().getTokenValue()).isEqualTo(expected.getAccessToken().getTokenValue());
assertThat(actual.getAccessToken().getIssuedAt()).isEqualTo(expected.getAccessToken().getIssuedAt());
assertThat(actual.getAccessToken().getExpiresAt()).isEqualTo(expected.getAccessToken().getExpiresAt());
assertThat(actual.getAccessToken().getScopes()).isEqualTo(expected.getAccessToken().getScopes());
assertThat(actual.getRefreshToken()).isEqualTo(expected.getRefreshToken());
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,12 +18,14 @@

import java.time.Duration;
import java.time.Instant;
import java.util.function.Consumer;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;

Expand All @@ -35,8 +37,8 @@
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;

import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.BDDMockito.given;
import static org.assertj.core.api.Assertions.*;
import static org.mockito.BDDMockito.*;

/**
* @author Rob Winch
Expand All @@ -60,7 +62,7 @@ public class InMemoryReactiveOAuth2AuthorizedClientServiceTests {
Instant.now().plus(Duration.ofDays(1)));

// @formatter:off
private ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(this.clientRegistrationId)
private final ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(this.clientRegistrationId)
.redirectUri("{baseUrl}/{action}/oauth2/code/{registrationId}")
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
Expand Down Expand Up @@ -153,7 +155,33 @@ public void loadAuthorizedClientWhenClientRegistrationFoundThenFound() {
.saveAuthorizedClient(authorizedClient, this.principal)
.then(this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName));
StepVerifier.create(saveAndLoad)
.expectNext(authorizedClient)
.assertNext(isEqualTo(authorizedClient))
.verifyComplete();
// @formatter:on
}

@Test
@SuppressWarnings("unchecked")
public void loadAuthorizedClientWhenClientRegistrationChangedThenCurrentVersionFound() {
ClientRegistration changedClientRegistration = ClientRegistration
.withClientRegistration(this.clientRegistration)
.clientSecret("updated secret")
.build();

given(this.clientRegistrationRepository.findByRegistrationId(this.clientRegistrationId))
.willReturn(Mono.just(this.clientRegistration), Mono.just(changedClientRegistration));
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
this.principalName, this.accessToken);
OAuth2AuthorizedClient authorizedClientWithChangedRegistration = new OAuth2AuthorizedClient(
changedClientRegistration, this.principalName, this.accessToken);
// @formatter:off
Flux<OAuth2AuthorizedClient> saveAndLoadTwice = this.authorizedClientService
.saveAuthorizedClient(authorizedClient, this.principal)
.then(this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName))
.concatWith(this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName));
StepVerifier.create(saveAndLoadTwice)
.assertNext(isEqualTo(authorizedClient))
.assertNext(isEqualTo(authorizedClientWithChangedRegistration))
.verifyComplete();
// @formatter:on
}
Expand Down Expand Up @@ -246,4 +274,31 @@ public void removeAuthorizedClientWhenClientRegistrationFoundRemovedThenNotFound
// @formatter:on
}

private static Consumer<OAuth2AuthorizedClient> isEqualTo(OAuth2AuthorizedClient expected) {
return (actual) -> {
assertThat(actual).isNotNull();
assertThat(actual.getClientRegistration().getRegistrationId())
.isEqualTo(expected.getClientRegistration().getRegistrationId());
assertThat(actual.getClientRegistration().getClientName())
.isEqualTo(expected.getClientRegistration().getClientName());
assertThat(actual.getClientRegistration().getRedirectUri())
.isEqualTo(expected.getClientRegistration().getRedirectUri());
assertThat(actual.getClientRegistration().getAuthorizationGrantType())
.isEqualTo(expected.getClientRegistration().getAuthorizationGrantType());
assertThat(actual.getClientRegistration().getClientAuthenticationMethod())
.isEqualTo(expected.getClientRegistration().getClientAuthenticationMethod());
assertThat(actual.getClientRegistration().getClientId())
.isEqualTo(expected.getClientRegistration().getClientId());
assertThat(actual.getClientRegistration().getClientSecret())
.isEqualTo(expected.getClientRegistration().getClientSecret());
assertThat(actual.getPrincipalName()).isEqualTo(expected.getPrincipalName());
assertThat(actual.getAccessToken().getTokenType()).isEqualTo(expected.getAccessToken().getTokenType());
assertThat(actual.getAccessToken().getTokenValue()).isEqualTo(expected.getAccessToken().getTokenValue());
assertThat(actual.getAccessToken().getIssuedAt()).isEqualTo(expected.getAccessToken().getIssuedAt());
assertThat(actual.getAccessToken().getExpiresAt()).isEqualTo(expected.getAccessToken().getExpiresAt());
assertThat(actual.getAccessToken().getScopes()).isEqualTo(expected.getAccessToken().getScopes());
assertThat(actual.getRefreshToken()).isEqualTo(expected.getRefreshToken());
};
}

}

0 comments on commit cd1fdc6

Please sign in to comment.