Skip to content

Commit

Permalink
Refactor AbstractWebClientReactiveOAuth2AccessTokenResponseClient
Browse files Browse the repository at this point in the history
  • Loading branch information
sjohnr committed Apr 29, 2024
1 parent d8fe5e7 commit d7ca009
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 260 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@

package org.springframework.security.oauth2.client.endpoint;

import java.util.Collections;
import java.util.Set;

import reactor.core.publisher.Mono;

import org.springframework.core.convert.converter.Converter;
Expand All @@ -36,7 +33,6 @@
import org.springframework.util.StringUtils;
import org.springframework.web.reactive.function.BodyExtractor;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.reactive.function.client.WebClient.RequestHeadersSpec;

Expand All @@ -54,6 +50,7 @@
*
* @param <T> type of grant request
* @author Phil Clay
* @author Steve Riesenberg
* @since 5.3
* @see <a href="https://tools.ietf.org/html/rfc6749#section-3.2">RFC-6749 Token
* Endpoint</a>
Expand All @@ -72,7 +69,7 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T

private Converter<T, HttpHeaders> headersConverter = new DefaultOAuth2TokenRequestHeadersConverter<>();

private Converter<T, MultiValueMap<String, String>> parametersConverter = this::populateTokenRequestParameters;
private Converter<T, MultiValueMap<String, String>> parametersConverter = this::createParameters;

private BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> bodyExtractor = OAuth2BodyExtractors
.oauth2AccessTokenResponse();
Expand All @@ -86,18 +83,11 @@ public Mono<OAuth2AccessTokenResponse> getTokenResponse(T grantRequest) {
// @formatter:off
return Mono.defer(() -> this.requestEntityConverter.convert(grantRequest)
.exchange()
.flatMap((response) -> readTokenResponse(grantRequest, response))
.flatMap((response) -> response.body(this.bodyExtractor))
);
// @formatter:on
}

/**
* Returns the {@link ClientRegistration} for the given {@code grantRequest}.
* @param grantRequest the grant request
* @return the {@link ClientRegistration} for the given {@code grantRequest}.
*/
abstract ClientRegistration clientRegistration(T grantRequest);

private RequestHeadersSpec<?> validatingPopulateRequest(T grantRequest) {
validateClientAuthenticationMethod(grantRequest);
return populateRequest(grantRequest);
Expand All @@ -117,128 +107,41 @@ private void validateClientAuthenticationMethod(T grantRequest) {
}

private RequestHeadersSpec<?> populateRequest(T grantRequest) {
MultiValueMap<String, String> parameters = this.parametersConverter.convert(grantRequest);
return this.webClient.post()
.uri(clientRegistration(grantRequest).getProviderDetails().getTokenUri())
.uri(grantRequest.getClientRegistration().getProviderDetails().getTokenUri())
.headers((headers) -> {
HttpHeaders headersToAdd = getHeadersConverter().convert(grantRequest);
HttpHeaders headersToAdd = this.headersConverter.convert(grantRequest);
if (headersToAdd != null) {
headers.addAll(headersToAdd);
}
})
.body(createTokenRequestBody(grantRequest));
.body(BodyInserters.fromFormData(parameters));
}

/**
* Populates default parameters for the token request.
* @param grantRequest the grant request
* @return the parameters populated for the token request.
* Returns a {@link MultiValueMap} of the parameters used in the OAuth 2.0 Access
* Token Request body.
* @param grantRequest the authorization grant request
* @return a {@link MultiValueMap} of the parameters used in the OAuth 2.0 Access
* Token Request body
*/
private MultiValueMap<String, String> populateTokenRequestParameters(T grantRequest) {
MultiValueMap<String, String> createParameters(T grantRequest) {
ClientRegistration clientRegistration = grantRequest.getClientRegistration();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add(OAuth2ParameterNames.GRANT_TYPE, grantRequest.getGrantType().getValue());
return parameters;
}

/**
* Combine the results of {@code parametersConverter} and
* {@link #populateTokenRequestBody}.
*
* <p>
* This method pre-populates the body with some standard properties, and then
* delegates to
* {@link #populateTokenRequestBody(AbstractOAuth2AuthorizationGrantRequest, BodyInserters.FormInserter)}
* for subclasses to further populate the body before returning.
* </p>
* @param grantRequest the grant request
* @return the body for the token request.
*/
private BodyInserters.FormInserter<String> createTokenRequestBody(T grantRequest) {
MultiValueMap<String, String> parameters = getParametersConverter().convert(grantRequest);
return populateTokenRequestBody(grantRequest, BodyInserters.fromFormData(parameters));
}

/**
* Populates the body of the token request.
*
* <p>
* By default, populates properties that are common to all grant types. Subclasses can
* extend this method to populate grant type specific properties.
* </p>
* @param grantRequest the grant request
* @param body the body to populate
* @return the populated body
*/
BodyInserters.FormInserter<String> populateTokenRequestBody(T grantRequest,
BodyInserters.FormInserter<String> body) {
ClientRegistration clientRegistration = clientRegistration(grantRequest);
parameters.set(OAuth2ParameterNames.GRANT_TYPE, grantRequest.getGrantType().getValue());
if (!ClientAuthenticationMethod.CLIENT_SECRET_BASIC
.equals(clientRegistration.getClientAuthenticationMethod())) {
body.with(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId());
parameters.set(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId());
}
if (ClientAuthenticationMethod.CLIENT_SECRET_POST.equals(clientRegistration.getClientAuthenticationMethod())) {
body.with(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret());
parameters.set(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret());
}
Set<String> scopes = scopes(grantRequest);
if (!CollectionUtils.isEmpty(scopes)) {
body.with(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(scopes, " "));
if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) {
parameters.set(OAuth2ParameterNames.SCOPE,
StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " "));
}
return body;
}

/**
* Returns the scopes to include as a property in the token request.
* @param grantRequest the grant request
* @return the scopes to include as a property in the token request.
*/
abstract Set<String> scopes(T grantRequest);

/**
* Returns the scopes to include in the response if the authorization server returned
* no scopes in the response.
*
* <p>
* As per <a href="https://tools.ietf.org/html/rfc6749#section-5.1">RFC-6749 Section
* 5.1 Successful Access Token Response</a>, if AccessTokenResponse.scope is empty,
* then default to the scope originally requested by the client in the Token Request.
* </p>
* @param grantRequest the grant request
* @return the scopes to include in the response if the authorization server returned
* no scopes.
*/
Set<String> defaultScopes(T grantRequest) {
return Collections.emptySet();
}

/**
* Reads the token response from the response body.
* @param grantRequest the request for which the response was received.
* @param response the client response from which to read
* @return the token response from the response body.
*/
private Mono<OAuth2AccessTokenResponse> readTokenResponse(T grantRequest, ClientResponse response) {
return response.body(this.bodyExtractor)
.map((tokenResponse) -> populateTokenResponse(grantRequest, tokenResponse));
}

/**
* Populates the given {@link OAuth2AccessTokenResponse} with additional details from
* the grant request.
* @param grantRequest the request for which the response was received.
* @param tokenResponse the original token response
* @return a token response optionally populated with additional details from the
* request.
*/
OAuth2AccessTokenResponse populateTokenResponse(T grantRequest, OAuth2AccessTokenResponse tokenResponse) {
if (CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes())) {
Set<String> defaultScopes = defaultScopes(grantRequest);
// @formatter:off
tokenResponse = OAuth2AccessTokenResponse
.withResponse(tokenResponse)
.scopes(defaultScopes)
.build();
// @formatter:on
}
return tokenResponse;
return parameters;
}

/**
Expand All @@ -247,22 +150,11 @@ OAuth2AccessTokenResponse populateTokenResponse(T grantRequest, OAuth2AccessToke
* @param webClient the {@link WebClient} used when requesting the Access Token
* Response
*/
public void setWebClient(WebClient webClient) {
public final void setWebClient(WebClient webClient) {
Assert.notNull(webClient, "webClient cannot be null");
this.webClient = webClient;
}

/**
* Returns the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link HttpHeaders}
* used in the OAuth 2.0 Access Token Request headers.
* @return the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} to {@link HttpHeaders}
*/
final Converter<T, HttpHeaders> getHeadersConverter() {
return this.headersConverter;
}

/**
* Sets the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link HttpHeaders}
Expand Down Expand Up @@ -305,17 +197,6 @@ public final void addHeadersConverter(Converter<T, HttpHeaders> headersConverter
this.requestEntityConverter = this::populateRequest;
}

/**
* Returns the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link MultiValueMap}
* used in the OAuth 2.0 Access Token Request body.
* @return the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} to {@link MultiValueMap}
*/
final Converter<T, MultiValueMap<String, String>> getParametersConverter() {
return this.parametersConverter;
}

/**
* Sets the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link MultiValueMap}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,11 @@

package org.springframework.security.oauth2.client.endpoint;

import java.util.Collections;
import java.util.Set;

import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.util.MultiValueMap;

/**
* An implementation of a {@link ReactiveOAuth2AccessTokenResponseClient} that
Expand Down Expand Up @@ -56,32 +51,21 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClient
extends AbstractWebClientReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> {

@Override
ClientRegistration clientRegistration(OAuth2AuthorizationCodeGrantRequest grantRequest) {
return grantRequest.getClientRegistration();
}

@Override
Set<String> scopes(OAuth2AuthorizationCodeGrantRequest grantRequest) {
return Collections.emptySet();
}

@Override
BodyInserters.FormInserter<String> populateTokenRequestBody(OAuth2AuthorizationCodeGrantRequest grantRequest,
BodyInserters.FormInserter<String> body) {
super.populateTokenRequestBody(grantRequest, body);
MultiValueMap<String, String> createParameters(OAuth2AuthorizationCodeGrantRequest grantRequest) {
OAuth2AuthorizationExchange authorizationExchange = grantRequest.getAuthorizationExchange();
OAuth2AuthorizationResponse authorizationResponse = authorizationExchange.getAuthorizationResponse();
body.with(OAuth2ParameterNames.CODE, authorizationResponse.getCode());
MultiValueMap<String, String> parameters = super.createParameters(grantRequest);
parameters.remove(OAuth2ParameterNames.SCOPE);
parameters.set(OAuth2ParameterNames.CODE, authorizationExchange.getAuthorizationResponse().getCode());
String redirectUri = authorizationExchange.getAuthorizationRequest().getRedirectUri();
if (redirectUri != null) {
body.with(OAuth2ParameterNames.REDIRECT_URI, redirectUri);
parameters.set(OAuth2ParameterNames.REDIRECT_URI, redirectUri);
}
String codeVerifier = authorizationExchange.getAuthorizationRequest()
.getAttribute(PkceParameterNames.CODE_VERIFIER);
if (codeVerifier != null) {
body.with(PkceParameterNames.CODE_VERIFIER, codeVerifier);
parameters.set(PkceParameterNames.CODE_VERIFIER, codeVerifier);
}
return body;
return parameters;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@

package org.springframework.security.oauth2.client.endpoint;

import java.util.Set;

import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;

/**
Expand All @@ -44,14 +41,4 @@
public class WebClientReactiveClientCredentialsTokenResponseClient
extends AbstractWebClientReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> {

@Override
ClientRegistration clientRegistration(OAuth2ClientCredentialsGrantRequest grantRequest) {
return grantRequest.getClientRegistration();
}

@Override
Set<String> scopes(OAuth2ClientCredentialsGrantRequest grantRequest) {
return grantRequest.getClientRegistration().getScopes();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,10 @@

package org.springframework.security.oauth2.client.endpoint;

import java.util.Set;

import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.util.MultiValueMap;
import org.springframework.web.reactive.function.client.WebClient;

/**
Expand All @@ -45,20 +42,10 @@ public final class WebClientReactiveJwtBearerTokenResponseClient
extends AbstractWebClientReactiveOAuth2AccessTokenResponseClient<JwtBearerGrantRequest> {

@Override
ClientRegistration clientRegistration(JwtBearerGrantRequest grantRequest) {
return grantRequest.getClientRegistration();
}

@Override
Set<String> scopes(JwtBearerGrantRequest grantRequest) {
return grantRequest.getClientRegistration().getScopes();
}

@Override
BodyInserters.FormInserter<String> populateTokenRequestBody(JwtBearerGrantRequest grantRequest,
BodyInserters.FormInserter<String> body) {
return super.populateTokenRequestBody(grantRequest, body).with(OAuth2ParameterNames.ASSERTION,
grantRequest.getJwt().getTokenValue());
MultiValueMap<String, String> createParameters(JwtBearerGrantRequest grantRequest) {
MultiValueMap<String, String> parameters = super.createParameters(grantRequest);
parameters.set(OAuth2ParameterNames.ASSERTION, grantRequest.getJwt().getTokenValue());
return parameters;
}

}
Loading

0 comments on commit d7ca009

Please sign in to comment.