diff --git a/.github/workflows/pre-merge-checks.yml b/.github/workflows/pre-merge-checks.yml index c68b2e815d..69ebe5f813 100644 --- a/.github/workflows/pre-merge-checks.yml +++ b/.github/workflows/pre-merge-checks.yml @@ -177,7 +177,7 @@ jobs: localstack: image: ${{ (needs.check-changed-files.outputs.java_changed == 'true') && 'localstack/localstack:3.0.0' || '' }} env: - SERVICES: "lambda, apigateway, iam, ec2, sqs, s3, sts, kms, sns, ssm, events" + SERVICES: "lambda, apigateway, iam, ec2, sqs, s3, sts, kms, sns, ssm, events, logs" GATEWAY_LISTEN: 0.0.0.0:45678 LOCALSTACK_HOST: localhost:45678 TEST_AWS_ACCOUNT_ID: 123456789012 diff --git a/account-management-integration-tests/build.gradle b/account-management-integration-tests/build.gradle index d2bc727bad..6ab199efca 100644 --- a/account-management-integration-tests/build.gradle +++ b/account-management-integration-tests/build.gradle @@ -28,9 +28,7 @@ dependencies { test { useJUnitPlatform() - filter { - includeTestsMatching "*" - } + environment "AUDIT_SIGNING_KEY_ALIAS", "alias/local-audit-payload-signing-key-alias" environment "AWS_ACCESS_KEY_ID", "mock-access-key" environment "AWS_REGION", "eu-west-2" diff --git a/delivery-receipts-integration-tests/build.gradle b/delivery-receipts-integration-tests/build.gradle index 155cf6632e..ae90d9ecce 100644 --- a/delivery-receipts-integration-tests/build.gradle +++ b/delivery-receipts-integration-tests/build.gradle @@ -21,9 +21,6 @@ dependencies { test { useJUnitPlatform() - filter { - includeTestsMatching "*" - } environment "AWS_ACCESS_KEY_ID", "mock-access-key" environment "AWS_REGION", "eu-west-2" @@ -56,14 +53,15 @@ test { doLast { tasks.getByName("jacocoTestReport").sourceDirectories.from( - project(":account-management-api").sourceSets.main.java, + project(":delivery-receipts-api").sourceSets.main.java, project(":shared").sourceSets.main.java) tasks.getByName("jacocoTestReport").classDirectories.from( - project(":account-management-api").sourceSets.main.output, + project(":delivery-receipts-api").sourceSets.main.output, project(":shared").sourceSets.main.output) } dependsOn ":composeUp" + finalizedBy ":composeDown" } jacocoTestReport { diff --git a/docker-compose.yml b/docker-compose.yml index 930e98a20c..0063336132 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -5,7 +5,7 @@ services: image: localstack/localstack:3.1.0@sha256:a47b435f876a100115d1d9f24e19b6b302cc7acb78b91e9258122116283ba462 restart: no environment: - SERVICES: iam, ec2, sqs, s3, sts, kms, sns, ssm, cloudwatch, events + SERVICES: iam, ec2, sqs, s3, sts, kms, sns, ssm, cloudwatch, events, logs GATEWAY_LISTEN: 0.0.0.0:45678 LOCALSTACK_HOST: localhost:45678 TEST_AWS_ACCOUNT_ID: 123456789012 diff --git a/frontend-api/src/main/java/uk/gov/di/authentication/frontendapi/lambda/MfaResetJarJwkHandler.java b/frontend-api/src/main/java/uk/gov/di/authentication/frontendapi/lambda/MfaResetJarJwkHandler.java index 58d33ce4f9..169d0c1dae 100644 --- a/frontend-api/src/main/java/uk/gov/di/authentication/frontendapi/lambda/MfaResetJarJwkHandler.java +++ b/frontend-api/src/main/java/uk/gov/di/authentication/frontendapi/lambda/MfaResetJarJwkHandler.java @@ -46,7 +46,8 @@ public APIGatewayProxyResponseEvent handleRequest( } public APIGatewayProxyResponseEvent mfaResetJarJwkHandler() { - LOG.info("MFA reset JAR Signing JWK request received"); + LOG.info( + "Request for Auth reverification request JAR signature verification key received."); try { List signingKeys = new ArrayList<>(); @@ -55,7 +56,7 @@ public APIGatewayProxyResponseEvent mfaResetJarJwkHandler() { JWKSet jwkSet = new JWKSet(signingKeys); - LOG.info("Generating MFA reset JAR signing JWK successful response"); + LOG.info("Served Auth reverification request JAR signature verification key JWK set."); return generateApiGatewayProxyResponse( 200, @@ -63,8 +64,12 @@ public APIGatewayProxyResponseEvent mfaResetJarJwkHandler() { Map.of("Cache-Control", "max-age=86400"), null); } catch (Exception e) { - LOG.error("Error in MfaResetJarJwkHandler lambda", e); - return generateApiGatewayProxyResponse(500, "Error providing MFA Reset JAR JWK data"); + LOG.error( + "Failed to serve Auth reverification request JAR signature verification key.", + e); + return generateApiGatewayProxyResponse( + 500, + "Auth MFA reverification request JAR signature verification key not available."); } } } diff --git a/frontend-api/src/main/java/uk/gov/di/authentication/frontendapi/lambda/MfaResetStorageTokenJwkHandler.java b/frontend-api/src/main/java/uk/gov/di/authentication/frontendapi/lambda/MfaResetStorageTokenJwkHandler.java index 5154554cbf..c0a360d7a0 100644 --- a/frontend-api/src/main/java/uk/gov/di/authentication/frontendapi/lambda/MfaResetStorageTokenJwkHandler.java +++ b/frontend-api/src/main/java/uk/gov/di/authentication/frontendapi/lambda/MfaResetStorageTokenJwkHandler.java @@ -49,7 +49,7 @@ public APIGatewayProxyResponseEvent handleRequest( public APIGatewayProxyResponseEvent mfaResetStorageTokenJwkHandler() { try { - LOG.info("MfaResetStorageTokenJwk request received"); + LOG.info("Request for Auth MFA storage token signature verification key received."); List signingKeys = new ArrayList<>(); @@ -57,7 +57,7 @@ public APIGatewayProxyResponseEvent mfaResetStorageTokenJwkHandler() { JWKSet jwkSet = new JWKSet(signingKeys); - LOG.info("Generating MfaResetStorageTokenJwk successful response"); + LOG.info("Served Auth MFA storage token signature verification key JWK set."); return generateApiGatewayProxyResponse( 200, @@ -65,9 +65,9 @@ public APIGatewayProxyResponseEvent mfaResetStorageTokenJwkHandler() { Map.of("Cache-Control", "max-age=86400"), null); } catch (Exception e) { - LOG.error("Error in MfaResetStorageTokenJwk lambda", e); + LOG.error("Failed to serve Auth MFA storage token signature verification key.", e); return generateApiGatewayProxyResponse( - 500, "Error providing MfaResetStorageTokenJwk data"); + 500, "Auth MFA storage token signature verification key not available."); } } } diff --git a/frontend-api/src/main/java/uk/gov/di/authentication/frontendapi/services/IPVReverificationService.java b/frontend-api/src/main/java/uk/gov/di/authentication/frontendapi/services/IPVReverificationService.java index 1dcadc3095..5d2dca2c75 100644 --- a/frontend-api/src/main/java/uk/gov/di/authentication/frontendapi/services/IPVReverificationService.java +++ b/frontend-api/src/main/java/uk/gov/di/authentication/frontendapi/services/IPVReverificationService.java @@ -39,7 +39,7 @@ public class IPVReverificationService { private static final Logger LOG = LogManager.getLogger(IPVReverificationService.class); private static final JWSAlgorithm SIGNING_ALGORITHM = JWSAlgorithm.ES256; private static final String MFA_RESET_SCOPE = "reverification"; - private static final String STATE_STORAGE_PREFIX = "mfaReset:state:"; + public static final String STATE_STORAGE_PREFIX = "mfaReset:state:"; private final ConfigurationService configurationService; private final JwtService jwtService; private final NowClock nowClock; diff --git a/frontend-api/src/test/java/uk/gov/di/authentication/frontendapi/lambda/MfaResetJarJwkHandlerTest.java b/frontend-api/src/test/java/uk/gov/di/authentication/frontendapi/lambda/MfaResetJarJwkHandlerTest.java index 5e06b4af0d..582cf5c612 100644 --- a/frontend-api/src/test/java/uk/gov/di/authentication/frontendapi/lambda/MfaResetJarJwkHandlerTest.java +++ b/frontend-api/src/test/java/uk/gov/di/authentication/frontendapi/lambda/MfaResetJarJwkHandlerTest.java @@ -57,6 +57,9 @@ void shouldReturn500WhenSigningKeyIsNotPresent() { var result = handler.handleRequest(event, context); assertThat(result, hasStatus(500)); - assertThat(result, hasBody("Error providing MFA Reset JAR JWK data")); + assertThat( + result, + hasBody( + "Auth MFA reverification request JAR signature verification key not available.")); } } diff --git a/frontend-api/src/test/java/uk/gov/di/authentication/frontendapi/lambda/MfaResetStorageTokenJwkHandlerTest.java b/frontend-api/src/test/java/uk/gov/di/authentication/frontendapi/lambda/MfaResetStorageTokenJwkHandlerTest.java index ceea0063b9..050bdccc35 100644 --- a/frontend-api/src/test/java/uk/gov/di/authentication/frontendapi/lambda/MfaResetStorageTokenJwkHandlerTest.java +++ b/frontend-api/src/test/java/uk/gov/di/authentication/frontendapi/lambda/MfaResetStorageTokenJwkHandlerTest.java @@ -57,7 +57,9 @@ void shouldReturn500WhenSigningKeyIsNotPresent() { var result = handler.handleRequest(event, context); assertThat(result, hasStatus(500)); - assertThat(result, hasBody("Error providing MfaResetStorageTokenJwk data")); + assertThat( + result, + hasBody("Auth MFA storage token signature verification key not available.")); } @Test diff --git a/integration-tests/build.gradle b/integration-tests/build.gradle index 1b5439ded7..741a4aba5c 100644 --- a/integration-tests/build.gradle +++ b/integration-tests/build.gradle @@ -36,14 +36,16 @@ dependencies { implementation project(":utils"), noXray testRuntimeOnly "org.junit.jupiter:junit-jupiter-engine:${dependencyVersions.junit}" + testImplementation("uk.org.webcompere:system-stubs-jupiter:2.1.3") + testImplementation("com.google.guava:guava:33.3.1-jre") + testImplementation("org.awaitility:awaitility:4.2.0") + testImplementation('org.wiremock:wiremock-jetty12:3.10.0') } test { useJUnitPlatform() exclude 'uk/gov/di/authentication/contract/**' - jacoco { - } environment "AUDIT_SIGNING_KEY_ALIAS", "alias/local-audit-payload-signing-key-alias" environment "AWS_ACCESS_KEY_ID", "mock-access-key" environment "AWS_REGION", "eu-west-2" @@ -69,6 +71,10 @@ test { environment "BULK_USER_EMAIL_INCLUDED_TERMS_AND_CONDITIONS", "1.0,1.1,1.2,1.3,1.4" environment "SEND_STORAGE_TOKEN_TO_IPV_ENABLED", "true" + testLogging { + showStandardStreams = false + } + doLast { tasks.getByName("jacocoTestReport").sourceDirectories.from( project(":frontend-api").sourceSets.main.java, diff --git a/integration-tests/src/test/java/uk/gov/di/authentication/api/AuthSigningKeyJWKSIntegrationTest.java b/integration-tests/src/test/java/uk/gov/di/authentication/api/AuthSigningKeyJWKSIntegrationTest.java new file mode 100644 index 0000000000..05f4b09d54 --- /dev/null +++ b/integration-tests/src/test/java/uk/gov/di/authentication/api/AuthSigningKeyJWKSIntegrationTest.java @@ -0,0 +1,124 @@ +package uk.gov.di.authentication.api; + +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.jwk.Curve; +import com.nimbusds.jose.jwk.KeyType; +import com.nimbusds.jose.jwk.KeyUse; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.extension.RegisterExtension; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.kms.KmsClient; +import software.amazon.awssdk.services.kms.model.GetPublicKeyRequest; +import software.amazon.awssdk.services.kms.model.GetPublicKeyResponse; +import software.amazon.awssdk.services.kms.model.KeyUsageType; +import uk.gov.di.authentication.frontendapi.lambda.MfaResetJarJwkHandler; +import uk.gov.di.authentication.shared.services.ConfigurationService; +import uk.gov.di.authentication.sharedtest.basetest.ApiGatewayHandlerIntegrationTest; +import uk.gov.di.authentication.sharedtest.extensions.KmsKeyExtension; +import uk.gov.di.authentication.sharedtest.logging.CaptureLoggingExtension; +import uk.org.webcompere.systemstubs.environment.EnvironmentVariables; +import uk.org.webcompere.systemstubs.jupiter.SystemStub; +import uk.org.webcompere.systemstubs.jupiter.SystemStubsExtension; + +import java.net.URI; +import java.util.Map; +import java.util.Optional; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.hasItem; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static uk.gov.di.authentication.shared.helpers.HashHelper.hashSha256String; +import static uk.gov.di.authentication.sharedtest.logging.LogEventMatcher.withMessageContaining; +import static uk.gov.di.orchestration.sharedtest.matchers.APIGatewayProxyResponseEventMatcher.hasStatus; + +@ExtendWith(SystemStubsExtension.class) +class AuthSigningKeyJWKSIntegrationTest extends ApiGatewayHandlerIntegrationTest { + private static final Logger LOG = LogManager.getLogger(AuthSigningKeyJWKSIntegrationTest.class); + + @SystemStub private static final EnvironmentVariables environment = new EnvironmentVariables(); + + @RegisterExtension + private static final KmsKeyExtension mfaResetJarSigningKey = + new KmsKeyExtension("mfa-reset-jar-signing-key", KeyUsageType.SIGN_VERIFY); + + @RegisterExtension + private static final CaptureLoggingExtension logging = + new CaptureLoggingExtension(MfaResetJarJwkHandler.class); + + private static String expectedHashKeyArn; + + @BeforeAll + static void setupEnvironment() { + environment.set( + "IPV_REVERIFICATION_REQUESTS_SIGNING_KEY_ALIAS", mfaResetJarSigningKey.getKeyId()); + + try (KmsClient kmsClient = getKmsClient()) { + GetPublicKeyRequest getPublicKeyRequest = + GetPublicKeyRequest.builder().keyId(mfaResetJarSigningKey.getKeyId()).build(); + + GetPublicKeyResponse getPublicKeyResponse = kmsClient.getPublicKey(getPublicKeyRequest); + + expectedHashKeyArn = hashSha256String(getPublicKeyResponse.keyId()); + } catch (Exception e) { + LOG.error(e.getMessage(), e); + } + } + + private static KmsClient getKmsClient() { + return KmsClient.builder() + .endpointOverride(URI.create("http://localhost:45678")) + .credentialsProvider( + StaticCredentialsProvider.create( + AwsBasicCredentials.create("dummy", "dummy"))) + .region(Region.EU_WEST_2) + .build(); + } + + @Test + void shouldReturnJWKSetContainingTheReverificationSigningKey() { + handler = new MfaResetJarJwkHandler(new ConfigurationService()); + + var response = makeRequest(Optional.empty(), Map.of(), Map.of()); + + assertThat(response, hasStatus(200)); + + JsonObject jwk = JsonParser.parseString(response.getBody()).getAsJsonObject(); + JsonArray keys = jwk.get("keys").getAsJsonArray(); + assertEquals(1, keys.size(), "JWKS endpoint must return a single key."); + + checkPublicSigningKeyResponseMeetsADR0030(keys.get(0).getAsJsonObject()); + } + + @Test + void shouldNotAllowExceptionsToEscape() { + environment.set("IPV_REVERIFICATION_REQUESTS_SIGNING_KEY_ALIAS", "wrong-key-alias"); + + handler = new MfaResetJarJwkHandler(new ConfigurationService()); + + var response = makeRequest(Optional.empty(), Map.of(), Map.of()); + + assertThat(response, hasStatus(500)); + assertThat( + logging.events(), + hasItem( + withMessageContaining( + "Failed to serve Auth reverification request JAR signature verification key."))); + } + + private static void checkPublicSigningKeyResponseMeetsADR0030(JsonObject key) { + assertEquals(expectedHashKeyArn, key.get("kid").getAsString()); + assertEquals(KeyType.EC.getValue(), key.get("kty").getAsString()); + assertEquals(KeyUse.SIGNATURE.getValue(), key.get("use").getAsString()); + assertEquals(Curve.P_256.getName(), key.get("crv").getAsString()); + assertEquals(JWSAlgorithm.ES256.toString(), key.get("alg").getAsString()); + } +} diff --git a/integration-tests/src/test/java/uk/gov/di/authentication/api/MfaResetAuthorizeHandlerIntegrationTest.java b/integration-tests/src/test/java/uk/gov/di/authentication/api/MfaResetAuthorizeHandlerIntegrationTest.java new file mode 100644 index 0000000000..41f975b9a9 --- /dev/null +++ b/integration-tests/src/test/java/uk/gov/di/authentication/api/MfaResetAuthorizeHandlerIntegrationTest.java @@ -0,0 +1,215 @@ +package uk.gov.di.authentication.api; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.nimbusds.jose.JOSEException; +import com.nimbusds.jose.JWEObject; +import com.nimbusds.jose.crypto.RSADecrypter; +import com.nimbusds.jose.jwk.RSAKey; +import com.nimbusds.jwt.SignedJWT; +import com.nimbusds.oauth2.sdk.id.Subject; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.extension.RegisterExtension; +import software.amazon.awssdk.services.kms.model.KeyUsageType; +import uk.gov.di.authentication.frontendapi.entity.MfaResetRequest; +import uk.gov.di.authentication.frontendapi.lambda.MfaResetAuthorizeHandler; +import uk.gov.di.authentication.shared.helpers.ClientSubjectHelper; +import uk.gov.di.authentication.shared.helpers.SaltHelper; +import uk.gov.di.authentication.shared.serialization.Json; +import uk.gov.di.authentication.shared.services.ConfigurationService; +import uk.gov.di.authentication.shared.services.SerializationService; +import uk.gov.di.authentication.sharedtest.basetest.ApiGatewayHandlerIntegrationTest; +import uk.gov.di.authentication.sharedtest.doubles.MetricsLoggerTestDouble; +import uk.gov.di.authentication.sharedtest.extensions.CloudWatchExtension; +import uk.gov.di.authentication.sharedtest.extensions.IDReverificationStateExtension; +import uk.gov.di.authentication.sharedtest.extensions.KmsKeyExtension; +import uk.gov.di.authentication.sharedtest.extensions.RedisExtension; +import uk.org.webcompere.systemstubs.environment.EnvironmentVariables; +import uk.org.webcompere.systemstubs.jupiter.SystemStub; +import uk.org.webcompere.systemstubs.jupiter.SystemStubsExtension; + +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.NoSuchAlgorithmException; +import java.text.ParseException; +import java.util.Arrays; +import java.util.Base64; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import static uk.gov.di.authentication.frontendapi.services.IPVReverificationService.STATE_STORAGE_PREFIX; +import static uk.gov.di.authentication.shared.domain.CloudwatchMetrics.MFA_RESET_HANDOFF; +import static uk.gov.di.authentication.sharedtest.matchers.APIGatewayProxyResponseEventMatcher.hasStatus; + +@ExtendWith(SystemStubsExtension.class) +class MfaResetAuthorizeHandlerIntegrationTest extends ApiGatewayHandlerIntegrationTest { + private static final String USER_EMAIL = "test@email.com"; + private static final String USER_PASSWORD = "Password123!"; + private static final String USER_PHONE_NUMBER = "+447712345432"; + private static KeyPair keyPair; + private String sessionId; + + @SystemStub static EnvironmentVariables environment = new EnvironmentVariables(); + + @RegisterExtension + private static final KmsKeyExtension mfaResetStorageTokenSigningKey = + new KmsKeyExtension("mfa-reset-storage-token-signing-key", KeyUsageType.SIGN_VERIFY); + + @RegisterExtension + private static final KmsKeyExtension ipvReverificationRequestsSigningKey = + new KmsKeyExtension("mfa-reset-jar-signing-key", KeyUsageType.SIGN_VERIFY); + + @RegisterExtension + public static final RedisExtension redisExtension = + new RedisExtension(new SerializationService(), new ConfigurationService()); + + @RegisterExtension + private static final CloudWatchExtension cloudwatchExtension = new CloudWatchExtension(); + + @RegisterExtension + private static final IDReverificationStateExtension idReverificationStateExtension = + new IDReverificationStateExtension(); + + @BeforeAll + static void setupEnvironment() { + environment.set("TXMA_AUDIT_QUEUE_URL", txmaAuditQueue.getQueueUrl()); + environment.set("IPV_AUTHORISATION_CLIENT_ID", "test-client-id"); + environment.set( + "MFA_RESET_STORAGE_TOKEN_SIGNING_KEY_ALIAS", + mfaResetStorageTokenSigningKey.getKeyId()); + environment.set( + "IPV_REVERIFICATION_REQUESTS_SIGNING_KEY_ALIAS", + ipvReverificationRequestsSigningKey.getKeyId()); + + createTestIPVEncryptionKeyPair(); + putIPVPublicKeyInEnvironmentVariableUntilIPVJWKSAvailable(); + } + + private static void createTestIPVEncryptionKeyPair() { + try { + KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA"); + keyPairGenerator.initialize(2048); + keyPair = keyPairGenerator.generateKeyPair(); + } catch (NoSuchAlgorithmException e) { + fail("Unable to create RSA key pair: " + e.getMessage()); + } + } + + private static void putIPVPublicKeyInEnvironmentVariableUntilIPVJWKSAvailable() { + RSAKey rsaKey = + new RSAKey.Builder((java.security.interfaces.RSAPublicKey) keyPair.getPublic()) + .privateKey(keyPair.getPrivate()) + .keyID("key-id") + .build(); + + try { + String base64PublicKey = + Base64.getEncoder().encodeToString(rsaKey.toRSAPublicKey().getEncoded()); + + environment.set( + "IPV_PUBLIC_ENCRYPTION_KEY", + "-----BEGIN PUBLIC KEY-----\n" + + base64PublicKey + + "\n-----END PUBLIC KEY-----"); + } catch (JOSEException e) { + fail("Unable to create IPV public key for test environment: " + e.getMessage()); + } + } + + @BeforeEach + void setup() throws Json.JsonException { + ConfigurationService configurationService = ConfigurationService.getInstance(); + configurationService.setMetricsLoggerAdapter( + new MetricsLoggerTestDouble( + cloudwatchExtension.getLogGroupName(), + cloudwatchExtension.getLogStreamName())); + + handler = new MfaResetAuthorizeHandler(); + + sessionId = redis.createAuthenticatedSessionWithEmail(USER_EMAIL); + var internalCommonSubjectId = + ClientSubjectHelper.calculatePairwiseIdentifier( + new Subject().getValue(), + "test.account.gov.uk", + SaltHelper.generateNewSalt()); + redis.addInternalCommonSubjectIdToSession(sessionId, internalCommonSubjectId); + + String subjectId = "test-subject-id"; + userStore.signUp(USER_EMAIL, USER_PASSWORD, new Subject(subjectId)); + userStore.addVerifiedPhoneNumber(USER_EMAIL, USER_PHONE_NUMBER); + } + + @Test + void shouldAuthenticateMfaReset() { + idReverificationStateExtension.store("orch-redirect-url", "client-session-id"); + + var response = + makeRequest( + Optional.of(new MfaResetRequest(USER_EMAIL, "")), + constructFrontendHeaders(sessionId, sessionId), + Map.of()); + + assertThat(response, hasStatus(200)); + + checkCorrectKeysUsedViaIntegrationWithKms(response.getBody()); + checkStateIsStoredViaIntegrationWithRedis(sessionId); + checkTxmaEventPublishedViaIntegrationWithSQS(); + checkExecutionMetricsPublishedViaIntegrationWithCloudWatch(); + } + + private static void checkCorrectKeysUsedViaIntegrationWithKms(String body) { + var kmsAccessInterceptor = ConfigurationService.getKmsAccessInterceptor(); + assertTrue( + kmsAccessInterceptor.wasKeyUsedToSign( + ipvReverificationRequestsSigningKey.getKeyId())); + assertTrue( + kmsAccessInterceptor.wasKeyUsedToSign(mfaResetStorageTokenSigningKey.getKeyId())); + ObjectMapper objectMapper = new ObjectMapper(); + try { + JsonNode rootNode = objectMapper.readTree(body); + String url = rootNode.get("authorize_url").asText(); + Map params = + Arrays.stream(url.substring(1).split("&")) + .map(param -> param.split("=")) + .collect(Collectors.toMap(param -> param[0], param -> param[1])); + + String request = params.get("request"); + + JWEObject jweObject = JWEObject.parse(request); + jweObject.decrypt(new RSADecrypter(keyPair.getPrivate())); + + var payload = jweObject.getPayload().toString(); + + SignedJWT signedJWT = SignedJWT.parse(payload); + + assertNotNull(signedJWT); + } catch (JsonProcessingException e) { + fail("Body could not be parsed: " + body); + } catch (ParseException | JOSEException e) { + fail("JOSE exception processing JAR", e); + } + } + + private static void checkStateIsStoredViaIntegrationWithRedis(String sessionId) { + var state = redisExtension.getFromRedis(STATE_STORAGE_PREFIX + sessionId); + assertNotNull(state); + } + + private static void checkTxmaEventPublishedViaIntegrationWithSQS() { + assertEquals(1, txmaAuditQueue.getRawMessages().size()); + } + + private static void checkExecutionMetricsPublishedViaIntegrationWithCloudWatch() { + assertTrue(cloudwatchExtension.hasLoggedMetric(MFA_RESET_HANDOFF.getValue())); + } +} diff --git a/integration-tests/src/test/java/uk/gov/di/authentication/api/MfaResetStorageTokenJwkHandlerIntegrationTest.java b/integration-tests/src/test/java/uk/gov/di/authentication/api/MfaResetStorageTokenJwkHandlerIntegrationTest.java new file mode 100644 index 0000000000..fe33780c4f --- /dev/null +++ b/integration-tests/src/test/java/uk/gov/di/authentication/api/MfaResetStorageTokenJwkHandlerIntegrationTest.java @@ -0,0 +1,130 @@ +package uk.gov.di.authentication.api; + +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.jwk.Curve; +import com.nimbusds.jose.jwk.KeyType; +import com.nimbusds.jose.jwk.KeyUse; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.extension.RegisterExtension; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.kms.KmsClient; +import software.amazon.awssdk.services.kms.model.GetPublicKeyRequest; +import software.amazon.awssdk.services.kms.model.GetPublicKeyResponse; +import software.amazon.awssdk.services.kms.model.KeyUsageType; +import uk.gov.di.authentication.frontendapi.lambda.MfaResetStorageTokenJwkHandler; +import uk.gov.di.authentication.shared.services.ConfigurationService; +import uk.gov.di.authentication.sharedtest.basetest.ApiGatewayHandlerIntegrationTest; +import uk.gov.di.authentication.sharedtest.extensions.KmsKeyExtension; +import uk.gov.di.authentication.sharedtest.logging.CaptureLoggingExtension; +import uk.org.webcompere.systemstubs.environment.EnvironmentVariables; +import uk.org.webcompere.systemstubs.jupiter.SystemStub; +import uk.org.webcompere.systemstubs.jupiter.SystemStubsExtension; + +import java.net.URI; +import java.util.Map; +import java.util.Optional; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.hasItem; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static uk.gov.di.authentication.shared.helpers.HashHelper.hashSha256String; +import static uk.gov.di.authentication.sharedtest.logging.LogEventMatcher.withMessageContaining; +import static uk.gov.di.orchestration.sharedtest.matchers.APIGatewayProxyResponseEventMatcher.hasStatus; + +@ExtendWith(SystemStubsExtension.class) +class MfaResetStorageTokenJwkHandlerIntegrationTest extends ApiGatewayHandlerIntegrationTest { + private static final Logger LOG = + LogManager.getLogger(MfaResetStorageTokenJwkHandlerIntegrationTest.class); + + @SystemStub static EnvironmentVariables environment = new EnvironmentVariables(); + + @RegisterExtension + private static final KmsKeyExtension mfaResetStorageTokenSigningKey = + new KmsKeyExtension("mfa-reset-storage-token-signing-key", KeyUsageType.SIGN_VERIFY); + + @RegisterExtension + private static final CaptureLoggingExtension logging = + new CaptureLoggingExtension(MfaResetStorageTokenJwkHandler.class); + + private static String expectedKid; + + @BeforeAll + static void setupEnvironment() { + environment.set( + "MFA_RESET_STORAGE_TOKEN_SIGNING_KEY_ALIAS", + mfaResetStorageTokenSigningKey.getKeyId()); + + try (KmsClient kmsClient = getKmsClient()) { + GetPublicKeyRequest getPublicKeyRequest = + GetPublicKeyRequest.builder() + .keyId(mfaResetStorageTokenSigningKey.getKeyId()) + .build(); + + GetPublicKeyResponse getPublicKeyResponse = kmsClient.getPublicKey(getPublicKeyRequest); + + expectedKid = hashSha256String(getPublicKeyResponse.keyId()); + + LOG.info("Retrieved kid: {}", expectedKid); + } catch (Exception e) { + LOG.error(e.getMessage(), e); + } + } + + private static KmsClient getKmsClient() { + return KmsClient.builder() + .endpointOverride(URI.create("http://localhost:45678")) + .credentialsProvider( + StaticCredentialsProvider.create( + AwsBasicCredentials.create("dummy", "dummy"))) + .region(Region.EU_WEST_2) + .build(); + } + + @Test + void shouldReturnJWKSetContainingTheStorageTokenSigningKey() { + handler = new MfaResetStorageTokenJwkHandler(new ConfigurationService()); + + var response = makeRequest(Optional.empty(), Map.of(), Map.of()); + + assertThat(response, hasStatus(200)); + + JsonObject jwk = JsonParser.parseString(response.getBody()).getAsJsonObject(); + JsonArray keys = jwk.get("keys").getAsJsonArray(); + assertEquals(1, keys.size(), "JWKS endpoint must return a single key."); + + checkPublicSigningKeyResponseMeetsADR0030(keys.get(0).getAsJsonObject()); + } + + @Test + void shouldNotAllowExceptionsToEscape() { + environment.set("MFA_RESET_STORAGE_TOKEN_SIGNING_KEY_ALIAS", "wrong-key-alias"); + + handler = new MfaResetStorageTokenJwkHandler(new ConfigurationService()); + + var response = makeRequest(Optional.empty(), Map.of(), Map.of()); + + assertThat(response, hasStatus(500)); + assertThat( + logging.events(), + hasItem( + withMessageContaining( + "Failed to serve Auth MFA storage token signature verification key."))); + } + + private static void checkPublicSigningKeyResponseMeetsADR0030(JsonObject key) { + assertEquals(expectedKid, key.get("kid").getAsString()); + assertEquals(KeyType.EC.getValue(), key.get("kty").getAsString()); + assertEquals(KeyUse.SIGNATURE.getValue(), key.get("use").getAsString()); + assertEquals(Curve.P_256.getName(), key.get("crv").getAsString()); + assertEquals(JWSAlgorithm.ES256.toString(), key.get("alg").getAsString()); + } +} diff --git a/integration-tests/src/test/java/uk/gov/di/authentication/api/ReverificationResultHandlerIntegrationTest.java b/integration-tests/src/test/java/uk/gov/di/authentication/api/ReverificationResultHandlerIntegrationTest.java new file mode 100644 index 0000000000..f55625485c --- /dev/null +++ b/integration-tests/src/test/java/uk/gov/di/authentication/api/ReverificationResultHandlerIntegrationTest.java @@ -0,0 +1,150 @@ +package uk.gov.di.authentication.api; + +import com.github.tomakehurst.wiremock.WireMockServer; +import com.github.tomakehurst.wiremock.client.WireMock; +import com.github.tomakehurst.wiremock.core.WireMockConfiguration; +import com.nimbusds.oauth2.sdk.id.Subject; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.extension.RegisterExtension; +import software.amazon.awssdk.services.kms.model.KeyUsageType; +import uk.gov.di.authentication.frontendapi.entity.ReverificationResultRequest; +import uk.gov.di.authentication.frontendapi.lambda.ReverificationResultHandler; +import uk.gov.di.authentication.shared.helpers.ClientSubjectHelper; +import uk.gov.di.authentication.shared.helpers.SaltHelper; +import uk.gov.di.authentication.shared.serialization.Json; +import uk.gov.di.authentication.shared.services.ConfigurationService; +import uk.gov.di.authentication.sharedtest.basetest.ApiGatewayHandlerIntegrationTest; +import uk.gov.di.authentication.sharedtest.extensions.KmsKeyExtension; +import uk.org.webcompere.systemstubs.environment.EnvironmentVariables; +import uk.org.webcompere.systemstubs.jupiter.SystemStub; +import uk.org.webcompere.systemstubs.jupiter.SystemStubsExtension; + +import java.net.URI; +import java.util.Map; +import java.util.Optional; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.configureFor; +import static com.github.tomakehurst.wiremock.client.WireMock.get; +import static com.github.tomakehurst.wiremock.client.WireMock.getRequestedFor; +import static com.github.tomakehurst.wiremock.client.WireMock.post; +import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor; +import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; +import static com.github.tomakehurst.wiremock.client.WireMock.urlPathMatching; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static uk.gov.di.authentication.sharedtest.matchers.APIGatewayProxyResponseEventMatcher.hasStatus; + +@ExtendWith(SystemStubsExtension.class) +class ReverificationResultHandlerIntegrationTest extends ApiGatewayHandlerIntegrationTest { + private static final String USER_EMAIL = "test@email.com"; + private static final String USER_PASSWORD = "Password123!"; + private static final String USER_PHONE_NUMBER = "+447712345432"; + + public static final String SUCCESSFUL_TOKEN_RESPONSE = + """ + { + "access_token": "access-token", + "token_type": "bearer", + "expires_in": 3600, + "scope": "openid" + } + """; + + private static final String SUCCESSFUL_USER_INFO_HTTP_RESPONSE_CONTENT = + """ + { + "sub": "urn:uuid:f81d4fae-7dec-11d0-a765-00a0c91e6bf6", + "success": true" + } + """; + + private String sessionId; + + @SystemStub static EnvironmentVariables environment = new EnvironmentVariables(); + + @RegisterExtension + private static final KmsKeyExtension mfaResetJarSigningKey = + new KmsKeyExtension("mfa-reset-jar-signing-key", KeyUsageType.SIGN_VERIFY); + + @BeforeAll + static void setupEnvironment() { + environment.set( + "IPV_REVERIFICATION_REQUESTS_SIGNING_KEY_ALIAS", mfaResetJarSigningKey.getKeyId()); + environment.set("IPV_AUTHORISATION_CLIENT_ID", "test-client-id"); + environment.set("IPV_AUDIENCE", "test-audience"); + environment.set("TXMA_AUDIT_QUEUE_URL", txmaAuditQueue.getQueueUrl()); + + WireMockServer wireMockServer = + new WireMockServer(WireMockConfiguration.wireMockConfig().dynamicPort()); + wireMockServer.start(); + configureFor("localhost", wireMockServer.port()); + URI ipvUri = URI.create("http://localhost:" + wireMockServer.port()); + + environment.set("IPV_BACKEND_URI", ipvUri); + } + + @BeforeEach + void setup() throws Json.JsonException { + handler = new ReverificationResultHandler(); + sessionId = redis.createAuthenticatedSessionWithEmail(USER_EMAIL); + var internalCommonSubjectId = + ClientSubjectHelper.calculatePairwiseIdentifier( + new Subject().getValue(), + "test.account.gov.uk", + SaltHelper.generateNewSalt()); + redis.addInternalCommonSubjectIdToSession(sessionId, internalCommonSubjectId); + + String subjectId = "test-subject-id"; + userStore.signUp(USER_EMAIL, USER_PASSWORD, new Subject(subjectId)); + userStore.addVerifiedPhoneNumber(USER_EMAIL, USER_PHONE_NUMBER); + } + + @Test + void shouldSuccessfullyProcessAReverificationResult() { + stubFor( + post(urlPathMatching("/token")) + .willReturn( + aResponse() + .withStatus(200) + .withHeader("Content-Type", "application/json") + .withBody(SUCCESSFUL_TOKEN_RESPONSE))); + + stubFor( + get(urlPathMatching("/reverification")) + .willReturn( + aResponse() + .withStatus(200) + .withHeader("Content-Type", "application/json") + .withBody(SUCCESSFUL_USER_INFO_HTTP_RESPONSE_CONTENT))); + + var response = + makeRequest( + Optional.of(new ReverificationResultRequest("code", "eamil")), + constructFrontendHeaders(sessionId, sessionId), + Map.of()); + + assertThat(response, hasStatus(200)); + checkIntegrationWithTxmaViaSQS(); + checkCorrectKeyUsedToSignRequestToIPVViaIntegrationWithKms(); + checkIntegrationWithIPV(); + } + + private static void checkIntegrationWithTxmaViaSQS() { + assertEquals(2, txmaAuditQueue.getRawMessages().size()); + } + + private static void checkCorrectKeyUsedToSignRequestToIPVViaIntegrationWithKms() { + var kmsAccessInterceptor = ConfigurationService.getKmsAccessInterceptor(); + assertTrue(kmsAccessInterceptor.wasKeyUsedToSign(mfaResetJarSigningKey.getKeyId())); + } + + private static void checkIntegrationWithIPV() { + WireMock.verify(1, postRequestedFor(urlPathMatching("/token"))); + WireMock.verify(1, getRequestedFor(urlPathMatching("/reverification"))); + } +} diff --git a/integration-tests/src/test/resources/log4j2-test.xml b/integration-tests/src/test/resources/log4j2-test.xml new file mode 100644 index 0000000000..b122109ce3 --- /dev/null +++ b/integration-tests/src/test/resources/log4j2-test.xml @@ -0,0 +1,15 @@ + + + + + + + + + + + + + + + diff --git a/shared-test/build.gradle b/shared-test/build.gradle index 194ccbcda2..e1c99247bb 100644 --- a/shared-test/build.gradle +++ b/shared-test/build.gradle @@ -26,6 +26,7 @@ dependencies { implementation project(":doc-checking-app-api") implementation project(":oidc-api") implementation project(":interventions-api-stub") + implementation("software.amazon.awssdk:cloudwatchlogs:${dependencyVersions.aws_sdk_v2_version}") } test { diff --git a/shared-test/src/main/java/uk/gov/di/authentication/sharedtest/doubles/MetricsLoggerTestDouble.java b/shared-test/src/main/java/uk/gov/di/authentication/sharedtest/doubles/MetricsLoggerTestDouble.java new file mode 100644 index 0000000000..4216e8b60c --- /dev/null +++ b/shared-test/src/main/java/uk/gov/di/authentication/sharedtest/doubles/MetricsLoggerTestDouble.java @@ -0,0 +1,78 @@ +package uk.gov.di.authentication.sharedtest.doubles; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.cloudwatchlogs.CloudWatchLogsClient; +import software.amazon.awssdk.services.cloudwatchlogs.model.CreateLogGroupRequest; +import software.amazon.awssdk.services.cloudwatchlogs.model.CreateLogStreamRequest; +import software.amazon.awssdk.services.cloudwatchlogs.model.InputLogEvent; +import software.amazon.awssdk.services.cloudwatchlogs.model.PutLogEventsRequest; +import software.amazon.cloudwatchlogs.emf.logger.MetricsLogger; +import software.amazon.cloudwatchlogs.emf.model.Unit; + +import java.net.URI; +import java.time.Instant; +import java.util.Collections; + +public class MetricsLoggerTestDouble extends MetricsLogger { + private static final Logger log = LogManager.getLogger(MetricsLoggerTestDouble.class); + private final CloudWatchLogsClient cloudWatchLogsClient; + private final String logGroupName; + private final String logStreamName; + + public MetricsLoggerTestDouble(String logGroupName, String logStreamName) { + this.logGroupName = logGroupName; + this.logStreamName = logStreamName; + this.cloudWatchLogsClient = + CloudWatchLogsClient.builder() + .endpointOverride(URI.create("http://localhost:45678")) + .region(Region.EU_WEST_2) + .credentialsProvider( + StaticCredentialsProvider.create( + AwsBasicCredentials.create("dummy", "dummy"))) + .build(); + cloudWatchLogsClient.createLogGroup( + CreateLogGroupRequest.builder().logGroupName(logGroupName).build()); + cloudWatchLogsClient.createLogStream( + CreateLogStreamRequest.builder() + .logGroupName(logGroupName) + .logStreamName(logStreamName) + .build()); + } + + @Override + public MetricsLogger setNamespace(String namespace) { + return null; + } + + @Override + public MetricsLogger putMetric(String metricName, double value, Unit unit) { + String message = + String.format( + "{\"metricName\":\"%s\", \"value\": %f, \"timestamp\": %d}", + metricName, value, Instant.now().toEpochMilli()); + + PutLogEventsRequest request = + PutLogEventsRequest.builder() + .logGroupName(logGroupName) + .logStreamName(logStreamName) + .logEvents( + Collections.singletonList( + InputLogEvent.builder() + .message(message) + .timestamp(Instant.now().toEpochMilli()) + .build())) + .build(); + + cloudWatchLogsClient.putLogEvents(request); + return null; + } + + @Override + public void flush() { + log.info("NO-OP: flush"); + } +} diff --git a/shared-test/src/main/java/uk/gov/di/authentication/sharedtest/extensions/CloudWatchExtension.java b/shared-test/src/main/java/uk/gov/di/authentication/sharedtest/extensions/CloudWatchExtension.java new file mode 100644 index 0000000000..ac90cb4e3f --- /dev/null +++ b/shared-test/src/main/java/uk/gov/di/authentication/sharedtest/extensions/CloudWatchExtension.java @@ -0,0 +1,72 @@ +package uk.gov.di.authentication.sharedtest.extensions; + +import org.junit.jupiter.api.extension.AfterAllCallback; +import org.junit.jupiter.api.extension.BeforeAllCallback; +import org.junit.jupiter.api.extension.ExtensionContext; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.cloudwatchlogs.CloudWatchLogsClient; +import software.amazon.awssdk.services.cloudwatchlogs.model.GetLogEventsRequest; +import software.amazon.awssdk.services.cloudwatchlogs.model.GetLogEventsResponse; + +import java.net.URI; + +public class CloudWatchExtension implements BeforeAllCallback, AfterAllCallback { + private CloudWatchLogsClient cloudWatchLogsClient; + + public String getLogGroupName() { + return logGroupName; + } + + public String getLogStreamName() { + return logStreamName; + } + + private String logGroupName; + private String logStreamName; + + @Override + public void beforeAll(ExtensionContext context) throws Exception { + logGroupName = context.getTestClass().map(Class::getSimpleName).orElse("unknown"); + logStreamName = context.getTestClass().map(Class::getSimpleName).orElse("unknown"); + cloudWatchLogsClient = + CloudWatchLogsClient.builder() + .endpointOverride(URI.create("http://localhost:45678")) + .region(Region.EU_WEST_2) + .credentialsProvider( + StaticCredentialsProvider.create( + AwsBasicCredentials.create("some-key", "some-secret"))) + .build(); + } + + @Override + public void afterAll(ExtensionContext context) { + if (cloudWatchLogsClient != null) { + cloudWatchLogsClient.close(); + } + } + + /** + * Localstack does not support CloudWatch metrics generation from CloudWatch logs so we have to + * check the logs for a matching metric log message. This is sufficient as we do not need to + * test internal AWS functionality just that our lambda initiates the metric creation process + * correctly. + * + * @param metricName name of the metric to be created from the log message + * @return boolean indicating whether the log message containing the metric was found + */ + public boolean hasLoggedMetric(String metricName) { + GetLogEventsRequest request = + GetLogEventsRequest.builder() + .logGroupName(this.logGroupName) + .logStreamName(this.logStreamName) + .build(); + + GetLogEventsResponse response = cloudWatchLogsClient.getLogEvents(request); + + return response.hasEvents() + && response.events().stream() + .anyMatch(event -> event.message().contains(metricName)); + } +} diff --git a/shared-test/src/main/java/uk/gov/di/authentication/sharedtest/extensions/KmsKeyExtension.java b/shared-test/src/main/java/uk/gov/di/authentication/sharedtest/extensions/KmsKeyExtension.java index 029986f1c3..aefd1507f8 100644 --- a/shared-test/src/main/java/uk/gov/di/authentication/sharedtest/extensions/KmsKeyExtension.java +++ b/shared-test/src/main/java/uk/gov/di/authentication/sharedtest/extensions/KmsKeyExtension.java @@ -26,6 +26,8 @@ public class KmsKeyExtension extends BaseAwsResourceExtension implements BeforeA private String keyAlias; private final KeyUsageType keyUsageType; + private String keyId; + public KmsKeyExtension(String keyAliasSuffix) { this(keyAliasSuffix, SIGN_VERIFY); } @@ -69,6 +71,8 @@ protected void createTokenSigningKey(String keyAlias) { .build(); var keyResponse = kms.createKey(keyRequest); + keyId = keyResponse.keyMetadata().keyId(); + CreateAliasRequest aliasRequest = CreateAliasRequest.builder() .aliasName(keyAlias) @@ -87,6 +91,8 @@ protected void createEncryptionKey(String keyAlias) { var keyResponse = kms.createKey(keyRequest); + keyId = keyResponse.keyMetadata().keyId(); + CreateAliasRequest aliasRequest = CreateAliasRequest.builder() .aliasName(keyAlias) @@ -109,4 +115,8 @@ protected boolean keyExists(String keyAlias) { public String getKeyAlias() { return keyAlias; } + + public String getKeyId() { + return keyId; + } } diff --git a/shared-test/src/main/java/uk/gov/di/authentication/sharedtest/extensions/TokenSigningExtension.java b/shared-test/src/main/java/uk/gov/di/authentication/sharedtest/extensions/TokenSigningExtension.java index 024f3c662f..842d8b9335 100644 --- a/shared-test/src/main/java/uk/gov/di/authentication/sharedtest/extensions/TokenSigningExtension.java +++ b/shared-test/src/main/java/uk/gov/di/authentication/sharedtest/extensions/TokenSigningExtension.java @@ -11,6 +11,7 @@ import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.services.kms.model.SignRequest; import software.amazon.awssdk.services.kms.model.SigningAlgorithmSpec; +import uk.gov.di.authentication.shared.services.ConfigurationService; import uk.gov.di.authentication.shared.services.KmsConnectionService; import java.nio.ByteBuffer; @@ -33,7 +34,11 @@ public TokenSigningExtension(String keyAliasSuffix) { public void beforeAll(ExtensionContext context) { super.beforeAll(context); kmsConnectionService = - new KmsConnectionService(Optional.of(LOCALSTACK_ENDPOINT), REGION, getKeyAlias()); + new KmsConnectionService( + Optional.of(LOCALSTACK_ENDPOINT), + REGION, + getKeyAlias(), + ConfigurationService.getKmsAccessInterceptor()); } public SignedJWT signJwt(JWTClaimsSet claimsSet) { diff --git a/shared/build.gradle b/shared/build.gradle index 55865035d7..3cb66ce525 100644 --- a/shared/build.gradle +++ b/shared/build.gradle @@ -26,6 +26,7 @@ dependencies { configurations.cloudwatch, configurations.gson, configurations.apache + implementation("software.amazon.awssdk:cloudwatchlogs:${dependencyVersions.aws_sdk_v2_version}") testImplementation configurations.tests, configurations.lambda_tests, diff --git a/shared/src/main/java/uk/gov/di/authentication/shared/interceptors/KmsAccessInterceptor.java b/shared/src/main/java/uk/gov/di/authentication/shared/interceptors/KmsAccessInterceptor.java new file mode 100644 index 0000000000..3198e92f4d --- /dev/null +++ b/shared/src/main/java/uk/gov/di/authentication/shared/interceptors/KmsAccessInterceptor.java @@ -0,0 +1,27 @@ +package uk.gov.di.authentication.shared.interceptors; + +import software.amazon.awssdk.core.interceptor.Context; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; +import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; +import software.amazon.awssdk.services.kms.model.SignRequest; + +import java.util.HashSet; + +public class KmsAccessInterceptor implements ExecutionInterceptor { + private HashSet signingKeysUsed = new HashSet<>(); + + public boolean wasKeyUsedToSign(String keyId) { + return signingKeysUsed.contains(keyId); + } + + @Override + public void beforeExecution( + Context.BeforeExecution context, ExecutionAttributes executionAttributes) { + String operation = executionAttributes.getAttribute(SdkExecutionAttribute.OPERATION_NAME); + if (operation != null && (operation.equals("Sign"))) { + SignRequest signRequest = (SignRequest) context.request(); + signingKeysUsed.add(signRequest.keyId()); + } + } +} diff --git a/shared/src/main/java/uk/gov/di/authentication/shared/services/CloudwatchMetricsService.java b/shared/src/main/java/uk/gov/di/authentication/shared/services/CloudwatchMetricsService.java index a588a35d63..e267668211 100644 --- a/shared/src/main/java/uk/gov/di/authentication/shared/services/CloudwatchMetricsService.java +++ b/shared/src/main/java/uk/gov/di/authentication/shared/services/CloudwatchMetricsService.java @@ -1,6 +1,5 @@ package uk.gov.di.authentication.shared.services; -import software.amazon.cloudwatchlogs.emf.logger.MetricsLogger; import software.amazon.cloudwatchlogs.emf.model.DimensionSet; import software.amazon.cloudwatchlogs.emf.model.Unit; import uk.gov.di.authentication.shared.entity.AuthSessionItem; @@ -38,7 +37,7 @@ public void putEmbeddedValue(String name, double value, Map dime segmentedFunctionCall( "Metrics::EMF", () -> { - var metrics = new MetricsLogger(); + var metrics = configurationService.getMetricsLogger(); var dimensionsSet = new DimensionSet(); String namespace = "Authentication"; diff --git a/shared/src/main/java/uk/gov/di/authentication/shared/services/ConfigurationService.java b/shared/src/main/java/uk/gov/di/authentication/shared/services/ConfigurationService.java index 58847cc0a5..77bf78a681 100644 --- a/shared/src/main/java/uk/gov/di/authentication/shared/services/ConfigurationService.java +++ b/shared/src/main/java/uk/gov/di/authentication/shared/services/ConfigurationService.java @@ -10,10 +10,12 @@ import software.amazon.awssdk.services.ssm.model.GetParametersRequest; import software.amazon.awssdk.services.ssm.model.Parameter; import software.amazon.awssdk.services.ssm.model.ParameterNotFoundException; +import software.amazon.cloudwatchlogs.emf.logger.MetricsLogger; import uk.gov.di.authentication.shared.configuration.AuditPublisherConfiguration; import uk.gov.di.authentication.shared.configuration.BaseLambdaConfiguration; import uk.gov.di.authentication.shared.entity.DeliveryReceiptsNotificationType; import uk.gov.di.authentication.shared.exceptions.MissingEnvVariableException; +import uk.gov.di.authentication.shared.interceptors.KmsAccessInterceptor; import java.net.URI; import java.time.Clock; @@ -33,6 +35,12 @@ public class ConfigurationService implements BaseLambdaConfiguration, AuditPubli public static final String FEATURE_SWITCH_ON = "true"; private static ConfigurationService configurationService; + public static KmsAccessInterceptor getKmsAccessInterceptor() { + return kmsAccessInterceptor; + } + + private static final KmsAccessInterceptor kmsAccessInterceptor = new KmsAccessInterceptor(); + public static ConfigurationService getInstance() { if (configurationService == null) { configurationService = new ConfigurationService(); @@ -46,6 +54,7 @@ public static ConfigurationService getInstance() { private String notifyCallbackBearerToken; protected SystemService systemService; + private MetricsLogger metricsLoggerAdapter; public ConfigurationService() {} @@ -686,4 +695,20 @@ public URI getIPVAuthorisationCallbackURI() { public String getIPVAuthorisationClientId() { return System.getenv().getOrDefault("IPV_AUTHORISATION_CLIENT_ID", ""); } + + public MetricsLogger getMetricsLogger() { + return getLocalstackEndpointUri() + .map( + l -> { + LOG.info("Localstack endpoint URI is present: {}", l); + return metricsLoggerAdapter != null + ? metricsLoggerAdapter + : new MetricsLogger(); + }) + .orElseGet(MetricsLogger::new); + } + + public void setMetricsLoggerAdapter(MetricsLogger metricsLoggerAdapter) { + this.metricsLoggerAdapter = metricsLoggerAdapter; + } } diff --git a/shared/src/main/java/uk/gov/di/authentication/shared/services/KmsConnectionService.java b/shared/src/main/java/uk/gov/di/authentication/shared/services/KmsConnectionService.java index a8f43ea072..61d9836c56 100644 --- a/shared/src/main/java/uk/gov/di/authentication/shared/services/KmsConnectionService.java +++ b/shared/src/main/java/uk/gov/di/authentication/shared/services/KmsConnectionService.java @@ -3,12 +3,14 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.kms.KmsClient; import software.amazon.awssdk.services.kms.model.GetPublicKeyRequest; import software.amazon.awssdk.services.kms.model.GetPublicKeyResponse; import software.amazon.awssdk.services.kms.model.SignRequest; import software.amazon.awssdk.services.kms.model.SignResponse; +import uk.gov.di.authentication.shared.interceptors.KmsAccessInterceptor; import java.net.URI; import java.util.Optional; @@ -22,11 +24,15 @@ public KmsConnectionService(ConfigurationService configurationService) { this( configurationService.getLocalstackEndpointUri(), configurationService.getAwsRegion(), - configurationService.getTokenSigningKeyAlias()); + configurationService.getTokenSigningKeyAlias(), + ConfigurationService.getKmsAccessInterceptor()); } public KmsConnectionService( - Optional localstackEndpointUri, String awsRegion, String tokenSigningKeyId) { + Optional localstackEndpointUri, + String awsRegion, + String tokenSigningKeyId, + KmsAccessInterceptor kmsAccessInterceptor) { if (localstackEndpointUri.isPresent()) { LOG.info("Localstack endpoint URI is present: " + localstackEndpointUri.get()); this.kmsClient = @@ -34,6 +40,10 @@ public KmsConnectionService( .endpointOverride(URI.create(localstackEndpointUri.get())) .credentialsProvider(DefaultCredentialsProvider.create()) .region(Region.of(awsRegion)) + .overrideConfiguration( + ClientOverrideConfiguration.builder() + .addExecutionInterceptor(kmsAccessInterceptor) + .build()) .build(); } else { this.kmsClient = diff --git a/shared/src/main/java/uk/gov/di/authentication/shared/services/TokenService.java b/shared/src/main/java/uk/gov/di/authentication/shared/services/TokenService.java index 865c805951..602a8602de 100644 --- a/shared/src/main/java/uk/gov/di/authentication/shared/services/TokenService.java +++ b/shared/src/main/java/uk/gov/di/authentication/shared/services/TokenService.java @@ -428,7 +428,7 @@ private SignedJWT generateSignedJWT( .message( SdkBytes.fromByteArray( message.getBytes(StandardCharsets.UTF_8))) - .keyId(signingKeyId) + .keyId(signingKey) .signingAlgorithm(signingAlgorithm) .build(); SignResponse signResult = kmsConnectionService.sign(signRequest);