Skip to content

Commit

Permalink
Add support for customization jwt decoder for ReactiveOidcIdTokenDeco…
Browse files Browse the repository at this point in the history
…derFactory, OidcIdTokenDecoderFactory

Closes spring-projectsgh-14673
  • Loading branch information
Max Batischev authored and Max Batischev committed Mar 27, 2024
1 parent e771267 commit 078218a
Show file tree
Hide file tree
Showing 4 changed files with 222 additions and 132 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ public final class OidcIdTokenDecoderFactory implements JwtDecoderFactory<Client
private Function<ClientRegistration, RestOperations> restOperationsFactory = (
clientRegistration) -> new RestTemplate();

private Function<ClientRegistration, JwtDecoder> jwtDecoderFactory = new DefaultJwtJwtDecoderFactory();

/**
* Returns the default {@link Converter}'s used for type conversion of claim values
* for an {@link OidcIdToken}.
Expand Down Expand Up @@ -131,7 +133,7 @@ public final class OidcIdTokenDecoderFactory implements JwtDecoderFactory<Client
public JwtDecoder createDecoder(ClientRegistration clientRegistration) {
Assert.notNull(clientRegistration, "clientRegistration cannot be null");
return this.jwtDecoders.computeIfAbsent(clientRegistration.getRegistrationId(), (key) -> {
NimbusJwtDecoder jwtDecoder = buildDecoder(clientRegistration);
NimbusJwtDecoder jwtDecoder = (NimbusJwtDecoder) this.jwtDecoderFactory.apply(clientRegistration);
jwtDecoder.setJwtValidator(this.jwtValidatorFactory.apply(clientRegistration));
Converter<Map<String, Object>, Map<String, Object>> claimTypeConverter = this.claimTypeConverterFactory
.apply(clientRegistration);
Expand All @@ -142,70 +144,6 @@ public JwtDecoder createDecoder(ClientRegistration clientRegistration) {
});
}

private NimbusJwtDecoder buildDecoder(ClientRegistration clientRegistration) {
JwsAlgorithm jwsAlgorithm = this.jwsAlgorithmResolver.apply(clientRegistration);
if (jwsAlgorithm != null && SignatureAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) {
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
//
// 6. If the ID Token is received via direct communication between the Client
// and the Token Endpoint (which it is in this flow),
// the TLS server validation MAY be used to validate the issuer in place of
// checking the token signature.
// The Client MUST validate the signature of all other ID Tokens according to
// JWS [JWS]
// using the algorithm specified in the JWT alg Header Parameter.
// The Client MUST use the keys provided by the Issuer.
//
// 7. The alg value SHOULD be the default of RS256 or the algorithm sent by
// the Client
// in the id_token_signed_response_alg parameter during Registration.

String jwkSetUri = clientRegistration.getProviderDetails().getJwkSetUri();
if (!StringUtils.hasText(jwkSetUri)) {
OAuth2Error oauth2Error = new OAuth2Error(MISSING_SIGNATURE_VERIFIER_ERROR_CODE,
"Failed to find a Signature Verifier for Client Registration: '"
+ clientRegistration.getRegistrationId()
+ "'. Check to ensure you have configured the JwkSet URI.",
null);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
return NimbusJwtDecoder.withJwkSetUri(jwkSetUri)
.jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm)
.restOperations(this.restOperationsFactory.apply(clientRegistration))
.build();
}
if (jwsAlgorithm != null && MacAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) {
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
//
// 8. If the JWT alg Header Parameter uses a MAC based algorithm such as
// HS256, HS384, or HS512,
// the octets of the UTF-8 representation of the client_secret
// corresponding to the client_id contained in the aud (audience) Claim
// are used as the key to validate the signature.
// For MAC based algorithms, the behavior is unspecified if the aud is
// multi-valued or
// if an azp value is present that is different than the aud value.
String clientSecret = clientRegistration.getClientSecret();
if (!StringUtils.hasText(clientSecret)) {
OAuth2Error oauth2Error = new OAuth2Error(MISSING_SIGNATURE_VERIFIER_ERROR_CODE,
"Failed to find a Signature Verifier for Client Registration: '"
+ clientRegistration.getRegistrationId()
+ "'. Check to ensure you have configured the client secret.",
null);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
SecretKeySpec secretKeySpec = new SecretKeySpec(clientSecret.getBytes(StandardCharsets.UTF_8),
JCA_ALGORITHM_MAPPINGS.get(jwsAlgorithm));
return NimbusJwtDecoder.withSecretKey(secretKeySpec).macAlgorithm((MacAlgorithm) jwsAlgorithm).build();
}
OAuth2Error oauth2Error = new OAuth2Error(MISSING_SIGNATURE_VERIFIER_ERROR_CODE,
"Failed to find a Signature Verifier for Client Registration: '"
+ clientRegistration.getRegistrationId()
+ "'. Check to ensure you have configured a valid JWS Algorithm: '" + jwsAlgorithm + "'",
null);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}

/**
* Sets the factory that provides an {@link OAuth2TokenValidator}, which is used by
* the {@link JwtDecoder}. The default composes {@link JwtTimestampValidator} and
Expand Down Expand Up @@ -259,4 +197,93 @@ public void setRestOperationsFactory(Function<ClientRegistration, RestOperations
this.restOperationsFactory = restOperationsFactory;
}

/**
* Sets the factory that provides a {@link JwtDecoder} used for building
* {@link NimbusJwtDecoder}
* @param jwtDecoderFactory the factory that provides a {@link JwtDecoder}
*
* @since 6.3
*/
public void setJwtDecoderFactory(Function<ClientRegistration, JwtDecoder> jwtDecoderFactory) {
Assert.notNull(jwtDecoderFactory, "jwtDecoderFactory cannot be null");
this.jwtDecoderFactory = jwtDecoderFactory;
}

/**
* @author Max Batischev
* @since 6.3
*/
private class DefaultJwtJwtDecoderFactory implements Function<ClientRegistration, JwtDecoder> {

@Override
public JwtDecoder apply(ClientRegistration clientRegistration) {
JwsAlgorithm jwsAlgorithm = OidcIdTokenDecoderFactory.this.jwsAlgorithmResolver.apply(clientRegistration);
if (jwsAlgorithm != null && SignatureAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) {
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
//
// 6. If the ID Token is received via direct communication between the
// Client
// and the Token Endpoint (which it is in this flow),
// the TLS server validation MAY be used to validate the issuer in place
// of
// checking the token signature.
// The Client MUST validate the signature of all other ID Tokens according
// to
// JWS [JWS]
// using the algorithm specified in the JWT alg Header Parameter.
// The Client MUST use the keys provided by the Issuer.
//
// 7. The alg value SHOULD be the default of RS256 or the algorithm sent
// by
// the Client
// in the id_token_signed_response_alg parameter during Registration.

String jwkSetUri = clientRegistration.getProviderDetails().getJwkSetUri();
if (!StringUtils.hasText(jwkSetUri)) {
OAuth2Error oauth2Error = new OAuth2Error(MISSING_SIGNATURE_VERIFIER_ERROR_CODE,
"Failed to find a Signature Verifier for Client Registration: '"
+ clientRegistration.getRegistrationId()
+ "'. Check to ensure you have configured the JwkSet URI.",
null);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
return NimbusJwtDecoder.withJwkSetUri(jwkSetUri)
.jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm)
.restOperations(OidcIdTokenDecoderFactory.this.restOperationsFactory.apply(clientRegistration))
.build();
}
if (jwsAlgorithm != null && MacAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) {
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
//
// 8. If the JWT alg Header Parameter uses a MAC based algorithm such as
// HS256, HS384, or HS512,
// the octets of the UTF-8 representation of the client_secret
// corresponding to the client_id contained in the aud (audience) Claim
// are used as the key to validate the signature.
// For MAC based algorithms, the behavior is unspecified if the aud is
// multi-valued or
// if an azp value is present that is different than the aud value.
String clientSecret = clientRegistration.getClientSecret();
if (!StringUtils.hasText(clientSecret)) {
OAuth2Error oauth2Error = new OAuth2Error(MISSING_SIGNATURE_VERIFIER_ERROR_CODE,
"Failed to find a Signature Verifier for Client Registration: '"
+ clientRegistration.getRegistrationId()
+ "'. Check to ensure you have configured the client secret.",
null);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
SecretKeySpec secretKeySpec = new SecretKeySpec(clientSecret.getBytes(StandardCharsets.UTF_8),
JCA_ALGORITHM_MAPPINGS.get(jwsAlgorithm));
return NimbusJwtDecoder.withSecretKey(secretKeySpec).macAlgorithm((MacAlgorithm) jwsAlgorithm).build();
}
OAuth2Error oauth2Error = new OAuth2Error(MISSING_SIGNATURE_VERIFIER_ERROR_CODE,
"Failed to find a Signature Verifier for Client Registration: '"
+ clientRegistration.getRegistrationId()
+ "'. Check to ensure you have configured a valid JWS Algorithm: '" + jwsAlgorithm + "'",
null);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ public final class ReactiveOidcIdTokenDecoderFactory implements ReactiveJwtDecod

private Function<ClientRegistration, WebClient> webClientFactory = (clientRegistration) -> WebClient.create();

private Function<ClientRegistration, ReactiveJwtDecoder> reactiveJwtDecoderFactory = new DefaultReactiveJwtDecoderFactory();

/**
* Returns the default {@link Converter}'s used for type conversion of claim values
* for an {@link OidcIdToken}.
Expand Down Expand Up @@ -130,7 +132,8 @@ public final class ReactiveOidcIdTokenDecoderFactory implements ReactiveJwtDecod
public ReactiveJwtDecoder createDecoder(ClientRegistration clientRegistration) {
Assert.notNull(clientRegistration, "clientRegistration cannot be null");
return this.jwtDecoders.computeIfAbsent(clientRegistration.getRegistrationId(), (key) -> {
NimbusReactiveJwtDecoder jwtDecoder = buildDecoder(clientRegistration);
NimbusReactiveJwtDecoder jwtDecoder = (NimbusReactiveJwtDecoder) this.reactiveJwtDecoderFactory
.apply(clientRegistration);
jwtDecoder.setJwtValidator(this.jwtValidatorFactory.apply(clientRegistration));
Converter<Map<String, Object>, Map<String, Object>> claimTypeConverter = this.claimTypeConverterFactory
.apply(clientRegistration);
Expand All @@ -141,72 +144,6 @@ public ReactiveJwtDecoder createDecoder(ClientRegistration clientRegistration) {
});
}

private NimbusReactiveJwtDecoder buildDecoder(ClientRegistration clientRegistration) {
JwsAlgorithm jwsAlgorithm = this.jwsAlgorithmResolver.apply(clientRegistration);
if (jwsAlgorithm != null && SignatureAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) {
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
//
// 6. If the ID Token is received via direct communication between the Client
// and the Token Endpoint (which it is in this flow),
// the TLS server validation MAY be used to validate the issuer in place of
// checking the token signature.
// The Client MUST validate the signature of all other ID Tokens according to
// JWS [JWS]
// using the algorithm specified in the JWT alg Header Parameter.
// The Client MUST use the keys provided by the Issuer.
//
// 7. The alg value SHOULD be the default of RS256 or the algorithm sent by
// the Client
// in the id_token_signed_response_alg parameter during Registration.
String jwkSetUri = clientRegistration.getProviderDetails().getJwkSetUri();
if (!StringUtils.hasText(jwkSetUri)) {
OAuth2Error oauth2Error = new OAuth2Error(MISSING_SIGNATURE_VERIFIER_ERROR_CODE,
"Failed to find a Signature Verifier for Client Registration: '"
+ clientRegistration.getRegistrationId()
+ "'. Check to ensure you have configured the JwkSet URI.",
null);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
return NimbusReactiveJwtDecoder.withJwkSetUri(jwkSetUri)
.jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm)
.webClient(this.webClientFactory.apply(clientRegistration))
.build();
}
if (jwsAlgorithm != null && MacAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) {
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
//
// 8. If the JWT alg Header Parameter uses a MAC based algorithm such as
// HS256, HS384, or HS512,
// the octets of the UTF-8 representation of the client_secret
// corresponding to the client_id contained in the aud (audience) Claim
// are used as the key to validate the signature.
// For MAC based algorithms, the behavior is unspecified if the aud is
// multi-valued or
// if an azp value is present that is different than the aud value.

String clientSecret = clientRegistration.getClientSecret();
if (!StringUtils.hasText(clientSecret)) {
OAuth2Error oauth2Error = new OAuth2Error(MISSING_SIGNATURE_VERIFIER_ERROR_CODE,
"Failed to find a Signature Verifier for Client Registration: '"
+ clientRegistration.getRegistrationId()
+ "'. Check to ensure you have configured the client secret.",
null);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
SecretKeySpec secretKeySpec = new SecretKeySpec(clientSecret.getBytes(StandardCharsets.UTF_8),
JCA_ALGORITHM_MAPPINGS.get(jwsAlgorithm));
return NimbusReactiveJwtDecoder.withSecretKey(secretKeySpec)
.macAlgorithm((MacAlgorithm) jwsAlgorithm)
.build();
}
OAuth2Error oauth2Error = new OAuth2Error(MISSING_SIGNATURE_VERIFIER_ERROR_CODE,
"Failed to find a Signature Verifier for Client Registration: '"
+ clientRegistration.getRegistrationId()
+ "'. Check to ensure you have configured a valid JWS Algorithm: '" + jwsAlgorithm + "'",
null);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}

/**
* Sets the factory that provides an {@link OAuth2TokenValidator}, which is used by
* the {@link ReactiveJwtDecoder}. The default composes {@link JwtTimestampValidator}
Expand Down Expand Up @@ -261,4 +198,98 @@ public void setWebClientFactory(Function<ClientRegistration, WebClient> webClien
this.webClientFactory = webClientFactory;
}

/**
* Sets the factory that provides a {@link ReactiveJwtDecoder} used for building
* {@link NimbusReactiveJwtDecoder}
* @param reactiveJwtDecoderFactory the factory that provides a
* {@link ReactiveJwtDecoder}
*
* @since 6.3
*/
public void setReactiveJwtDecoderFactory(
Function<ClientRegistration, ReactiveJwtDecoder> reactiveJwtDecoderFactory) {
Assert.notNull(reactiveJwtDecoderFactory, "reactiveJwtDecoderFactory cannot be null");
this.reactiveJwtDecoderFactory = reactiveJwtDecoderFactory;
}

/**
* @author Max Batischev
* @since 6.3
*/
private class DefaultReactiveJwtDecoderFactory implements Function<ClientRegistration, ReactiveJwtDecoder> {

@Override
public ReactiveJwtDecoder apply(ClientRegistration clientRegistration) {
JwsAlgorithm jwsAlgorithm = ReactiveOidcIdTokenDecoderFactory.this.jwsAlgorithmResolver
.apply(clientRegistration);
if (jwsAlgorithm != null && SignatureAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) {
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
//
// 6. If the ID Token is received via direct communication between the
// Client
// and the Token Endpoint (which it is in this flow),
// the TLS server validation MAY be used to validate the issuer in place
// of
// checking the token signature.
// The Client MUST validate the signature of all other ID Tokens according
// to
// JWS [JWS]
// using the algorithm specified in the JWT alg Header Parameter.
// The Client MUST use the keys provided by the Issuer.
//
// 7. The alg value SHOULD be the default of RS256 or the algorithm sent
// by
// the Client
// in the id_token_signed_response_alg parameter during Registration.
String jwkSetUri = clientRegistration.getProviderDetails().getJwkSetUri();
if (!StringUtils.hasText(jwkSetUri)) {
OAuth2Error oauth2Error = new OAuth2Error(MISSING_SIGNATURE_VERIFIER_ERROR_CODE,
"Failed to find a Signature Verifier for Client Registration: '"
+ clientRegistration.getRegistrationId()
+ "'. Check to ensure you have configured the JwkSet URI.",
null);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
return NimbusReactiveJwtDecoder.withJwkSetUri(jwkSetUri)
.jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm)
.webClient(ReactiveOidcIdTokenDecoderFactory.this.webClientFactory.apply(clientRegistration))
.build();
}
if (jwsAlgorithm != null && MacAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) {
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
//
// 8. If the JWT alg Header Parameter uses a MAC based algorithm such as
// HS256, HS384, or HS512,
// the octets of the UTF-8 representation of the client_secret
// corresponding to the client_id contained in the aud (audience) Claim
// are used as the key to validate the signature.
// For MAC based algorithms, the behavior is unspecified if the aud is
// multi-valued or
// if an azp value is present that is different than the aud value.

String clientSecret = clientRegistration.getClientSecret();
if (!StringUtils.hasText(clientSecret)) {
OAuth2Error oauth2Error = new OAuth2Error(MISSING_SIGNATURE_VERIFIER_ERROR_CODE,
"Failed to find a Signature Verifier for Client Registration: '"
+ clientRegistration.getRegistrationId()
+ "'. Check to ensure you have configured the client secret.",
null);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
SecretKeySpec secretKeySpec = new SecretKeySpec(clientSecret.getBytes(StandardCharsets.UTF_8),
JCA_ALGORITHM_MAPPINGS.get(jwsAlgorithm));
return NimbusReactiveJwtDecoder.withSecretKey(secretKeySpec)
.macAlgorithm((MacAlgorithm) jwsAlgorithm)
.build();
}
OAuth2Error oauth2Error = new OAuth2Error(MISSING_SIGNATURE_VERIFIER_ERROR_CODE,
"Failed to find a Signature Verifier for Client Registration: '"
+ clientRegistration.getRegistrationId()
+ "'. Check to ensure you have configured a valid JWS Algorithm: '" + jwsAlgorithm + "'",
null);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}

}

}
Loading

0 comments on commit 078218a

Please sign in to comment.