Skip to content

Commit

Permalink
Add jwksRetainOnErrorDuration
Browse files Browse the repository at this point in the history
  • Loading branch information
luneo7 committed Nov 15, 2024
1 parent 0d578de commit 6dd3373
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ protected HttpsJwks initializeHttpsJwks(String location)
new InetSocketAddress(authContextInfo.getHttpProxyHost(), authContextInfo.getHttpProxyPort())));
}
theHttpsJwks.setSimpleHttpGet(httpGet);
theHttpsJwks.setRetainCacheOnErrorDuration(authContextInfo.getJwksRetainCacheOnErrorDuration() * 60L);
return theHttpsJwks;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ public class JWTAuthContextInfo {
private String decryptionKeyContent;
private Integer jwksRefreshInterval = 60;
private int forcedJwksRefreshInterval = 30;
private int jwksRetainCacheOnErrorDuration = 0;
private String tokenHeader = "Authorization";
private String tokenCookie;
private boolean alwaysCheckAuthorization;
Expand Down Expand Up @@ -121,6 +122,7 @@ public JWTAuthContextInfo(JWTAuthContextInfo orig) {
this.decryptionKeyContent = orig.getDecryptionKeyContent();
this.jwksRefreshInterval = orig.getJwksRefreshInterval();
this.forcedJwksRefreshInterval = orig.getForcedJwksRefreshInterval();
this.jwksRetainCacheOnErrorDuration = orig.getJwksRetainCacheOnErrorDuration();
this.tokenHeader = orig.getTokenHeader();
this.tokenCookie = orig.getTokenCookie();
this.alwaysCheckAuthorization = orig.isAlwaysCheckAuthorization();
Expand Down Expand Up @@ -283,6 +285,14 @@ public void setForcedJwksRefreshInterval(int forcedJwksRefreshInterval) {
this.forcedJwksRefreshInterval = forcedJwksRefreshInterval;
}

public int getJwksRetainCacheOnErrorDuration() {
return jwksRetainCacheOnErrorDuration;
}

public void setJwksRetainCacheOnErrorDuration(int jwksRetainCacheOnErrorDuration) {
this.jwksRetainCacheOnErrorDuration = jwksRetainCacheOnErrorDuration;
}

public String getTokenHeader() {
return tokenHeader;
}
Expand Down Expand Up @@ -436,6 +446,7 @@ public String toString() {
", decryptionKeyLocation='" + decryptionKeyLocation + '\'' +
", decryptionKeyContent='" + decryptionKeyContent + '\'' +
", jwksRefreshInterval=" + jwksRefreshInterval +
", jwksRetainCacheOnErrorDuration=" + jwksRetainCacheOnErrorDuration +
", tokenHeader='" + tokenHeader + '\'' +
", tokenCookie='" + tokenCookie + '\'' +
", alwaysCheckAuthorization=" + alwaysCheckAuthorization +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ private static JWTAuthContextInfoProvider create(String key,
provider.mpJwtVerifyTokenAge = Optional.empty();
provider.jwksRefreshInterval = 60;
provider.forcedJwksRefreshInterval = 30;
provider.jwksRetainCacheOnErrorDuration = 0;
provider.signatureAlgorithm = Optional.of(SignatureAlgorithm.RS256);
provider.keyEncryptionAlgorithm = Optional.empty();
provider.mpJwtDecryptKeyAlgorithm = new HashSet<>(Arrays.asList(KeyEncryptionAlgorithm.RSA_OAEP,
Expand Down Expand Up @@ -465,6 +466,15 @@ private static JWTAuthContextInfoProvider create(String key,
@ConfigProperty(name = "smallrye.jwt.jwks.forced-refresh-interval", defaultValue = "30")
private int forcedJwksRefreshInterval;

/**
* JWK cache retain on error duration in minutes which sets the length of time, before trying again, to keep using the cache
* when an error occurs making the request to the JWKS URI or parsing the response.
* It will be ignored unless the 'mp.jwt.verify.publickey.location' property points to the HTTP or HTTPS URL based JWK set.
*/
@Inject
@ConfigProperty(name = "smallrye.jwt.jwks.retain-cache-on-error-duration", defaultValue = "0")
private int jwksRetainCacheOnErrorDuration;

/**
* Supported JSON Web Algorithm asymmetric or symmetric signature algorithm.
*
Expand Down Expand Up @@ -836,6 +846,7 @@ Optional<JWTAuthContextInfo> getOptionalContextInfo() {
contextInfo.setTokenAge(mpJwtVerifyTokenAge.orElse(null));
contextInfo.setJwksRefreshInterval(jwksRefreshInterval);
contextInfo.setForcedJwksRefreshInterval(forcedJwksRefreshInterval);
contextInfo.setJwksRetainCacheOnErrorDuration(jwksRetainCacheOnErrorDuration);
Set<SignatureAlgorithm> resolvedAlgorithm = mpJwtPublicKeyAlgorithm;
if (signatureAlgorithm.isPresent()) {
if (signatureAlgorithm.get().getAlgorithm().startsWith("HS")) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.net.Proxy;
Expand All @@ -30,12 +34,18 @@
import java.security.interfaces.RSAPublicKey;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.concurrent.TimeUnit;

import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;

import org.jose4j.base64url.Base64Url;
import org.jose4j.http.Get;
import org.jose4j.http.SimpleResponse;
import org.jose4j.json.internal.json_simple.JSONObject;
import org.jose4j.jwk.HttpsJwks;
import org.jose4j.jwk.JsonWebKey;
import org.jose4j.jwk.OctetSequenceJsonWebKey;
Expand Down Expand Up @@ -68,6 +78,8 @@ class KeyLocationResolverTest {
Get mockedGet;
@Mock
UrlStreamResolver urlResolver;
@Mock
SimpleResponse simpleResponse;

RSAPublicKey rsaKey;
SecretKey secretKey;
Expand Down Expand Up @@ -180,6 +192,46 @@ protected Get getHttpGet() {
assertNull(keyLocationResolver.key);
}

@Test
void keepsRsaKeyFromHttpsJwksWhenErrorDuringRefresh() throws Exception {
long cacheDuration = 1L;
int jwksRetainCacheOnErrorDuration = 10;
JWTAuthContextInfo contextInfo = new JWTAuthContextInfo("https://github.com/my_key.jwks", "issuer");
contextInfo.setJwksRetainCacheOnErrorDuration(jwksRetainCacheOnErrorDuration);

HttpsJwks spiedHttpsJwks = Mockito.spy(new HttpsJwks(contextInfo.getPublicKeyLocation()));
spiedHttpsJwks.setDefaultCacheDuration(cacheDuration);
when(simpleResponse.getBody()).thenReturn(generateJWK(rsaKey));
when(mockedGet.get(contextInfo.getPublicKeyLocation())).thenReturn(simpleResponse);

KeyLocationResolver keyLocationResolver = new KeyLocationResolver(contextInfo) {
protected HttpsJwks getHttpsJwks(String loc) {
return spiedHttpsJwks;
}

protected Get getHttpGet() {
return mockedGet;
}
};

Mockito.verify(spiedHttpsJwks).setRetainCacheOnErrorDuration(jwksRetainCacheOnErrorDuration * 60L);
Mockito.verify(spiedHttpsJwks).setSimpleHttpGet(mockedGet);

when(signature.getHeaders()).thenReturn(headers);
when(headers.getStringHeaderValue(JsonWebKey.KEY_ID_PARAMETER)).thenReturn("1");
when(headers.getStringHeaderValue(JsonWebKey.ALGORITHM_PARAMETER)).thenReturn("RS256");

assertEquals(rsaKey, keyLocationResolver.resolveKey(signature, emptyList()));

doThrow(RuntimeException.class).when(mockedGet).get(contextInfo.getPublicKeyLocation());

TimeUnit.SECONDS.sleep(cacheDuration);

assertEquals(rsaKey, keyLocationResolver.resolveKey(signature, emptyList()));

verify(mockedGet, atLeastOnce()).get(contextInfo.getPublicKeyLocation());
}

@Test
void loadRsaKeyFromHttpJwks() throws Exception {
JWTAuthContextInfo contextInfo = new JWTAuthContextInfo("http://github.com/my_key.jwks", "issuer");
Expand Down Expand Up @@ -330,7 +382,7 @@ void loadHttpsPemCrt() throws Exception {
contextInfo.setJwksRefreshInterval(10);

Mockito.doThrow(new JoseException("")).when(mockedHttpsJwks).refresh();
Mockito.doReturn(ResourceUtils.getAsClasspathResource("publicCrt.pem"))
doReturn(ResourceUtils.getAsClasspathResource("publicCrt.pem"))
.when(urlResolver).resolve(Mockito.any());
KeyLocationResolver keyLocationResolver = new KeyLocationResolver(contextInfo) {
protected HttpsJwks initializeHttpsJwks(String loc) {
Expand Down Expand Up @@ -380,4 +432,18 @@ void loadJWKOnClassPath() throws Exception {
assertEquals(keyLocationResolver.key,
keyLocationResolver.getJsonWebKey("key1", null).getKey());
}

private String generateJWK(RSAPublicKey publicKey) {
Map<String, Object> key = new HashMap<>();

key.put("alg", "RS256");
key.put("use", "sig");
key.put("kty", publicKey.getAlgorithm());
key.put("kid", "1");
key.put("n", Base64Url.encode(publicKey.getModulus().toByteArray()));
key.put("e", Base64Url.encode(publicKey.getPublicExponent().toByteArray()));

return JSONObject.toJSONString(Collections.singletonMap("keys",
Collections.singletonList(key)));
}
}

0 comments on commit 6dd3373

Please sign in to comment.