diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java index 2713ee96b2..934472922f 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,11 @@ package org.springframework.security.oauth2.jwt; +import com.nimbusds.jose.KeySourceException; +import com.nimbusds.jose.jwk.JWK; +import com.nimbusds.jose.jwk.JWKSelector; +import com.nimbusds.jose.jwk.source.JWKSetCacheRefreshEvaluator; +import com.nimbusds.jose.jwk.source.URLBasedJWKSetSource; import java.io.IOException; import java.net.MalformedURLException; import java.net.URL; @@ -26,6 +31,7 @@ import java.util.Collections; import java.util.HashSet; import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.function.Consumer; @@ -35,11 +41,8 @@ import com.nimbusds.jose.JOSEException; import com.nimbusds.jose.JWSAlgorithm; -import com.nimbusds.jose.RemoteKeySourceException; import com.nimbusds.jose.jwk.JWKSet; -import com.nimbusds.jose.jwk.source.JWKSetCache; import com.nimbusds.jose.jwk.source.JWKSource; -import com.nimbusds.jose.jwk.source.RemoteJWKSet; import com.nimbusds.jose.proc.JWSKeySelector; import com.nimbusds.jose.proc.JWSVerificationKeySelector; import com.nimbusds.jose.proc.SecurityContext; @@ -80,6 +83,7 @@ * @author Josh Cummings * @author Joe Grandja * @author Mykyta Bezverkhyi + * @author Daeho Kwon * @since 5.2 */ public final class NimbusJwtDecoder implements JwtDecoder { @@ -165,7 +169,7 @@ private Jwt createJwt(String token, JWT parsedJwt) { .build(); // @formatter:on } - catch (RemoteKeySourceException ex) { + catch (KeySourceException ex) { this.logger.trace("Failed to retrieve JWK set", ex); if (ex.getCause() instanceof ParseException) { throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed Jwk set"), ex); @@ -377,11 +381,12 @@ JWSKeySelector jwsKeySelector(JWKSource jwkSou } JWKSource jwkSource(ResourceRetriever jwkSetRetriever, String jwkSetUri) { - if (this.cache == null) { - return new RemoteJWKSet<>(toURL(jwkSetUri), jwkSetRetriever); + URLBasedJWKSetSource urlBasedJWKSetSource = new URLBasedJWKSetSource(toURL(jwkSetUri), jwkSetRetriever); + if(this.cache == null) { + return new SpringURLBasedJWKSource(urlBasedJWKSetSource); } - JWKSetCache jwkSetCache = new SpringJWKSetCache(jwkSetUri, this.cache); - return new RemoteJWKSet<>(toURL(jwkSetUri), jwkSetRetriever, jwkSetCache); + SpringJWKSetCache jwkSetCache = new SpringJWKSetCache(jwkSetUri, this.cache); + return new SpringURLBasedJWKSource<>(urlBasedJWKSetSource, jwkSetCache); } JWTProcessor processor() { @@ -414,7 +419,49 @@ private static URL toURL(String url) { } } - private static final class SpringJWKSetCache implements JWKSetCache { + private static final class SpringURLBasedJWKSource implements JWKSource { + + private final URLBasedJWKSetSource urlBasedJWKSetSource; + + private final SpringJWKSetCache jwkSetCache; + + private SpringURLBasedJWKSource(URLBasedJWKSetSource urlBasedJWKSetSource) { + this(urlBasedJWKSetSource, null); + } + + private SpringURLBasedJWKSource(URLBasedJWKSetSource urlBasedJWKSetSource, SpringJWKSetCache jwkSetCache) { + this.urlBasedJWKSetSource = urlBasedJWKSetSource; + this.jwkSetCache = jwkSetCache; + } + + @Override + public List get(JWKSelector jwkSelector, SecurityContext context) throws KeySourceException { + if (this.jwkSetCache != null) { + synchronized (this) { + JWKSet jwkSet = this.jwkSetCache.get(); + if (this.jwkSetCache.requiresRefresh() || jwkSet == null) { + jwkSet = fetchJWKSet(context); + this.jwkSetCache.put(jwkSet); + } + List jwks = jwkSelector.select(jwkSet); + if(!jwks.isEmpty()) { + return jwks; + } + jwkSet = fetchJWKSet(context); + this.jwkSetCache.put(jwkSet); + return jwkSelector.select(jwkSet); + } + } + return jwkSelector.select(fetchJWKSet(context)); + } + + private JWKSet fetchJWKSet(SecurityContext context) throws KeySourceException { + return this.urlBasedJWKSetSource.getJWKSet(JWKSetCacheRefreshEvaluator.noRefresh(), + System.currentTimeMillis(), context); + } + } + + private static final class SpringJWKSetCache { private final String jwkSetUri; @@ -440,20 +487,16 @@ private void updateJwkSetFromCache() { } } - // Note: Only called from inside a synchronized block in RemoteJWKSet. - @Override + // Note: Only called from inside a synchronized block in SpringURLBasedJWKSource. public void put(JWKSet jwkSet) { this.jwkSet = jwkSet; this.cache.put(this.jwkSetUri, jwkSet.toString(false)); } - @Override public JWKSet get() { return (!requiresRefresh()) ? this.jwkSet : null; - } - @Override public boolean requiresRefresh() { return this.cache.get(this.jwkSetUri) == null; }