Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove Deprecated Usages of RemoteJWKSet #16296

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,12 @@

package org.springframework.security.oauth2.jwt;

import com.nimbusds.jose.KeySourceException;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKMatcher;
import com.nimbusds.jose.jwk.JWKSelector;
import com.nimbusds.jose.jwk.source.JWKSetCacheRefreshEvaluator;
import com.nimbusds.jose.jwk.source.URLBasedJWKSetSource;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
Expand All @@ -26,6 +32,7 @@
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Consumer;
Expand All @@ -35,11 +42,8 @@

import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.RemoteKeySourceException;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.source.JWKSetCache;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.jwk.source.RemoteJWKSet;
import com.nimbusds.jose.proc.JWSKeySelector;
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.nimbusds.jose.proc.SecurityContext;
Expand Down Expand Up @@ -80,6 +84,7 @@
* @author Josh Cummings
* @author Joe Grandja
* @author Mykyta Bezverkhyi
* @author Daeho Kwon
* @since 5.2
*/
public final class NimbusJwtDecoder implements JwtDecoder {
Expand Down Expand Up @@ -165,7 +170,7 @@ private Jwt createJwt(String token, JWT parsedJwt) {
.build();
// @formatter:on
}
catch (RemoteKeySourceException ex) {
catch (KeySourceException ex) {
this.logger.trace("Failed to retrieve JWK set", ex);
if (ex.getCause() instanceof ParseException) {
throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed Jwk set"), ex);
Expand Down Expand Up @@ -377,11 +382,12 @@ JWSKeySelector<SecurityContext> jwsKeySelector(JWKSource<SecurityContext> jwkSou
}

JWKSource<SecurityContext> jwkSource(ResourceRetriever jwkSetRetriever, String jwkSetUri) {
if (this.cache == null) {
return new RemoteJWKSet<>(toURL(jwkSetUri), jwkSetRetriever);
URLBasedJWKSetSource urlBasedJWKSetSource = new URLBasedJWKSetSource(toURL(jwkSetUri), jwkSetRetriever);
if(this.cache == null) {
return new SpringURLBasedJWKSource(urlBasedJWKSetSource);
}
JWKSetCache jwkSetCache = new SpringJWKSetCache(jwkSetUri, this.cache);
return new RemoteJWKSet<>(toURL(jwkSetUri), jwkSetRetriever, jwkSetCache);
SpringJWKSetCache jwkSetCache = new SpringJWKSetCache(jwkSetUri, this.cache);
return new SpringURLBasedJWKSource<>(urlBasedJWKSetSource, jwkSetCache);
}

JWTProcessor<SecurityContext> processor() {
Expand Down Expand Up @@ -414,7 +420,80 @@ private static URL toURL(String url) {
}
}

private static final class SpringJWKSetCache implements JWKSetCache {
private static final class SpringURLBasedJWKSource<C extends SecurityContext> implements JWKSource<C> {

private final URLBasedJWKSetSource urlBasedJWKSetSource;

private final SpringJWKSetCache jwkSetCache;

private SpringURLBasedJWKSource(URLBasedJWKSetSource urlBasedJWKSetSource) {
this(urlBasedJWKSetSource, null);
}

private SpringURLBasedJWKSource(URLBasedJWKSetSource urlBasedJWKSetSource, SpringJWKSetCache jwkSetCache) {
this.urlBasedJWKSetSource = urlBasedJWKSetSource;
this.jwkSetCache = jwkSetCache;
}

@Override
public List<JWK> get(JWKSelector jwkSelector, SecurityContext context) throws KeySourceException {
if (this.jwkSetCache != null) {
JWKSet jwkSet = this.jwkSetCache.get();
if (this.jwkSetCache.requiresRefresh() || jwkSet == null) {
synchronized (this) {
jwkSet = fetchJWKSet();
this.jwkSetCache.put(jwkSet);
}
}
List<JWK> matches = jwkSelector.select(jwkSet);
if(!matches.isEmpty()) {
return matches;
}
String soughtKeyID = getFirstSpecifiedKeyID(jwkSelector.getMatcher());
if (soughtKeyID == null) {
return Collections.emptyList();
}
if (jwkSet.getKeyByKeyId(soughtKeyID) != null) {
return Collections.emptyList();
}
synchronized (this) {
if(jwkSet == this.jwkSetCache.get()) {
jwkSet = fetchJWKSet();
this.jwkSetCache.put(jwkSet);
} else {
jwkSet = this.jwkSetCache.get();
}
}
if(jwkSet == null) {
return Collections.emptyList();
}
return jwkSelector.select(jwkSet);
}
return jwkSelector.select(fetchJWKSet());
}

private JWKSet fetchJWKSet() throws KeySourceException {
return this.urlBasedJWKSetSource.getJWKSet(JWKSetCacheRefreshEvaluator.noRefresh(),
System.currentTimeMillis(), null);
}

private String getFirstSpecifiedKeyID(JWKMatcher jwkMatcher) {
Set<String> keyIDs = jwkMatcher.getKeyIDs();

if (keyIDs == null || keyIDs.isEmpty()) {
return null;
}

for (String id: keyIDs) {
if (id != null) {
return id;
}
}
return null;
}
}

private static final class SpringJWKSetCache {

private final String jwkSetUri;

Expand All @@ -440,20 +519,16 @@ private void updateJwkSetFromCache() {
}
}

// Note: Only called from inside a synchronized block in RemoteJWKSet.
@Override
// Note: Only called from inside a synchronized block in SpringURLBasedJWKSource.
public void put(JWKSet jwkSet) {
this.jwkSet = jwkSet;
this.cache.put(this.jwkSetUri, jwkSet.toString(false));
}

@Override
public JWKSet get() {
return (!requiresRefresh()) ? this.jwkSet : null;

}

@Override
public boolean requiresRefresh() {
return this.cache.get(this.jwkSetUri) == null;
}
Expand Down