Skip to content

Commit

Permalink
Merge pull request #5732 from govuk-one-login/ATO-1113-max-age-backch…
Browse files Browse the repository at this point in the history
…annel-logout-if-different-user

ATO-1113: Send backchannel logouts if different user post `max_age`
  • Loading branch information
Ryan-Andrews99 authored Jan 16, 2025
2 parents 4f773ea + 21ca9a1 commit 1a6318d
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import uk.gov.di.orchestration.shared.domain.AccountInterventionsAuditableEvent;
import uk.gov.di.orchestration.shared.domain.LogoutAuditableEvent;
import uk.gov.di.orchestration.shared.entity.AuthenticationUserInfo;
import uk.gov.di.orchestration.shared.entity.BackChannelLogoutMessage;
import uk.gov.di.orchestration.shared.entity.ClientSession;
import uk.gov.di.orchestration.shared.entity.ClientType;
import uk.gov.di.orchestration.shared.entity.CredentialTrustLevel;
Expand Down Expand Up @@ -60,7 +61,9 @@
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.time.Duration;
import java.time.LocalDateTime;
import java.time.temporal.ChronoUnit;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
Expand All @@ -70,6 +73,7 @@

import static com.nimbusds.jose.JWSAlgorithm.ES256;
import static java.util.Collections.singletonList;
import static org.awaitility.Awaitility.await;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.endsWith;
Expand Down Expand Up @@ -115,6 +119,10 @@ public class AuthenticationCallbackHandlerIntegrationTest extends ApiGatewayHand
@RegisterExtension
public static final OrchSessionExtension orchSessionExtension = new OrchSessionExtension();

@RegisterExtension
public static final SqsQueueExtension backChannelLogoutQueueExtension =
new SqsQueueExtension("back-channel-logout-");

protected static ConfigurationService configurationService;

private static final String CLIENT_ID = "test-client-id";
Expand Down Expand Up @@ -782,6 +790,9 @@ class MaxAgeSessionHandling {
"3eee3869-abf1-41c1-bdb5-c25f68d0a54d",
"aef54391-95d8-4d3b-ac30-cbe1e3e2f0d4");

private static final List<String> PREVIOUS_CLIENTS_FOR_CLIENT_SESSION =
List.of("client-id-1", "client-id-2", "client-id-3");

@Test
void
updatesOrchSessionAndSharedSessionWhenPreviousCommonSubjectIdMatchesAuthUserInfoResponse()
Expand Down Expand Up @@ -814,6 +825,9 @@ class MaxAgeSessionHandling {
new Subject(INTERNAL_COMMON_SUBJECT_ID), Long.MAX_VALUE, false);
setupMaxAgeSession();
setupPreviousSessions(DIFFERENT_INTERNAL_COMMON_SUBJECT_ID);
setupPreviousClientsAndPreviousClientSessions();
// Sending back channel logouts requires a user store entry
userStore.signUp(TEST_EMAIL_ADDRESS, "");

var response =
makeRequest(
Expand All @@ -822,12 +836,33 @@ class MaxAgeSessionHandling {
Optional.of(buildSessionCookie(SESSION_ID, CLIENT_SESSION_ID))),
constructQueryStringParameters());

assertUserInfoStoredAndRedirectedToRp(response);
assertThat(response, hasStatus(302));

URI redirectLocationHeader =
URI.create(response.getHeaders().get(ResponseHeaders.LOCATION));
assertEquals(
REDIRECT_URI.getAuthority() + REDIRECT_URI.getPath(),
redirectLocationHeader.getAuthority() + redirectLocationHeader.getPath());

assertThat(redirectLocationHeader.getQuery(), containsString(RP_STATE.getValue()));

assertThat(redirectLocationHeader.getQuery(), containsString("code"));

assertTxmaAuditEventsReceived(
txmaAuditQueue,
List.of(
OrchestrationAuditableEvent.AUTH_CALLBACK_RESPONSE_RECEIVED,
OrchestrationAuditableEvent.AUTH_SUCCESSFUL_TOKEN_RESPONSE_RECEIVED,
OrchestrationAuditableEvent.AUTH_SUCCESSFUL_USERINFO_RESPONSE_RECEIVED,
AccountInterventionsAuditableEvent.AIS_RESPONSE_RECEIVED,
OidcAuditableEvent.AUTHENTICATION_COMPLETE,
OidcAuditableEvent.AUTH_CODE_ISSUED));

var sharedSession = redis.getSession(SESSION_ID);
var orchSession = orchSessionExtension.getSession(SESSION_ID).get();
assertEquals(List.of(), sharedSession.getClientSessions());
assertNull(orchSession.getPreviousSessionId());
assertBackChannelLogoutsSent(PREVIOUS_CLIENTS_FOR_CLIENT_SESSION);
}

private void setupMaxAgeSession() throws Json.JsonException {
Expand All @@ -843,7 +878,7 @@ private void setupMaxAgeSession() throws Json.JsonException {

private void setupPreviousSessions(String internalCommonSubjectId)
throws Json.JsonException {
var session = new Session(PREVIOUS_SESSION_ID);
var session = new Session(PREVIOUS_SESSION_ID).setEmailAddress(TEST_EMAIL_ADDRESS);
PREVIOUS_CLIENT_SESSIONS.forEach(session::addClientSession);
redis.addSession(session);
redis.addStateToRedis(
Expand All @@ -855,6 +890,17 @@ private void setupPreviousSessions(String internalCommonSubjectId)
new OrchSessionItem(PREVIOUS_SESSION_ID)
.withInternalCommonSubjectId(internalCommonSubjectId));
}

private void setupPreviousClientsAndPreviousClientSessions() throws Json.JsonException {
PREVIOUS_CLIENTS_FOR_CLIENT_SESSION.forEach(
AuthenticationCallbackHandlerIntegrationTest.this::setupClientRegWithClientId);
for (String clientSessionId : PREVIOUS_CLIENT_SESSIONS) {
setupClientSessionWithId(
clientSessionId,
PREVIOUS_CLIENTS_FOR_CLIENT_SESSION.get(
PREVIOUS_CLIENT_SESSIONS.indexOf(clientSessionId)));
}
}
}

private void assertRedirectToSuspendedPage(APIGatewayProxyResponseEvent response) {
Expand Down Expand Up @@ -986,6 +1032,24 @@ private void assertSessionIsDeleted() {
assertTrue(orchSession.isEmpty());
}

private void setupClientRegWithClientId(String clientId) {
clientStore.registerClient(
clientId,
"test-client",
singletonList(REDIRECT_URI.toString()),
singletonList("[email protected]"),
singletonList("openid"),
null,
singletonList("http://localhost/post-redirect-logout"),
"http://example.com",
String.valueOf(ServiceType.MANDATORY),
"https://test.com",
"pairwise",
ClientType.APP,
ES256.getName(),
false);
}

private void setupClientReg(boolean identityVerificationSupported) {
clientStore.registerClient(
CLIENT_ID,
Expand Down Expand Up @@ -1179,9 +1243,15 @@ public boolean isAccountInterventionServiceActionEnabled() {
public boolean abortOnAccountInterventionsErrorResponse() {
return this.abortOnAisErrorResponse;
}

@Override
public String getBackChannelLogoutQueueURI() {
return backChannelLogoutQueueExtension.getQueueUrl();
}
}

private void setUpClientSession() throws Json.JsonException {
public void setupClientSessionWithId(String clientSessionId, String clientId)
throws Json.JsonException {
String vtrStr1 =
LevelOfConfidence.MEDIUM_LEVEL.getValue()
+ "."
Expand All @@ -1196,7 +1266,7 @@ private void setUpClientSession() throws Json.JsonException {

var authRequestBuilder =
new AuthenticationRequest.Builder(
ResponseType.CODE, SCOPE, new ClientID(CLIENT_ID), REDIRECT_URI)
ResponseType.CODE, SCOPE, new ClientID(clientId), REDIRECT_URI)
.state(RP_STATE)
.nonce(new Nonce())
.customParameter("vtr", jsonArrayOf(vtrStr1, vtrStr2));
Expand All @@ -1207,7 +1277,11 @@ ResponseType.CODE, SCOPE, new ClientID(CLIENT_ID), REDIRECT_URI)
vtrList,
CLIENT_NAME);

redis.createClientSession(CLIENT_SESSION_ID, clientSession);
redis.createClientSession(clientSessionId, clientSession);
}

private void setUpClientSession() throws Json.JsonException {
setupClientSessionWithId(CLIENT_SESSION_ID, CLIENT_ID);
}

private AuthorizationRequest validateQueryRequestToIPVAndReturnAuthRequest(
Expand Down Expand Up @@ -1265,4 +1339,26 @@ private void validateClaimsInJar(SignedJWT signedJWT, boolean reproveIdentity)
assertThat(signedJWT.getJWTClaimsSet().getClaim("scope"), equalTo("openid"));
assertThat(signedJWT.getHeader().getAlgorithm(), equalTo(ES256));
}

private void assertBackChannelLogoutsSent(List<String> clientIds) {
await().atMost(Duration.of(1, ChronoUnit.SECONDS))
.untilAsserted(
() ->
assertThat(
backChannelLogoutQueueExtension
.getApproximateMessageCount(),
equalTo(clientIds.size())));

var events = backChannelLogoutQueueExtension.getMessages(BackChannelLogoutMessage.class);
assertTrue(
clientIds.stream()
.allMatch(
clientId ->
events.stream()
.anyMatch(
backChannelLogoutMessage ->
backChannelLogoutMessage
.getClientId()
.equals(clientId))));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -827,8 +827,7 @@ private void handleMaxAgeSession(
} else {
LOG.info(
"Previous OrchSession InternalCommonSubjectId does not match Auth UserInfo response");
// TODO: ATO-1101: Send backchannel logouts + audit events

logoutService.handleMaxAgeLogout(previousSharedSession.get());
}
currentOrchSession.setPreviousSessionId(null);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1024,12 +1024,13 @@ void itDoesNotAssignClientSessionsIfItCannotFindThePreviousSharedSession()
}

@Test
void itDoesNotAttatchThePreviousClientSessionsIfTheInternalCommonSubjectIdsDoNotMatch()
throws UnsuccessfulCredentialResponseException {
void
itSendsBackChannelLogoutNotificationForThePreviousSessionIfTheInternalCommonSubjectIdsDoNotMatch()
throws UnsuccessfulCredentialResponseException {
var orchSession = withMaxAgeOrchSession(INTERNAL_COMMON_SUBJECT_ID);
var sharedSession = withMaxAgeSharedSession();
withPreviousOrchSessionDueToMaxAge();
withPreviousSharedSessionDueToMaxAge();
var previousSharedSession = withPreviousSharedSessionDueToMaxAge();

when(tokenService.sendTokenRequest(any())).thenReturn(SUCCESSFUL_TOKEN_RESPONSE);
when(tokenService.sendUserInfoDataRequest(any(HTTPRequest.class)))
Expand All @@ -1055,7 +1056,8 @@ void itDoesNotAttatchThePreviousClientSessionsIfTheInternalCommonSubjectIdsDoNot
.updateSession(argThat(s -> s.getPreviousSessionId() == null));
verify(sessionService, times(2))
.storeOrUpdateSession(argThat(s -> s.getClientSessions().equals(List.of())));
// TODO: ATO-1101: Send backchannel logouts + audit events

verify(logoutService, times(1)).handleMaxAgeLogout(previousSharedSession);
}

private void withPreviousOrchSessionDueToMaxAge() {
Expand All @@ -1067,11 +1069,12 @@ private void withPreviousOrchSessionDueToMaxAge() {
INTERNAL_COMMON_SUBJECT_ID)));
}

private void withPreviousSharedSessionDueToMaxAge() {
private Session withPreviousSharedSessionDueToMaxAge() {
var previousSharedSession = new Session(PREVIOUS_SESSION_ID);
PREVIOUS_CLIENT_SESSIONS.forEach(previousSharedSession::addClientSession);
when(sessionService.getSession(PREVIOUS_SESSION_ID))
.thenReturn(Optional.of(previousSharedSession));
return previousSharedSession;
}

private void withNoPreviousSharedSession() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,11 @@ public APIGatewayProxyResponseEvent handleAccountInterventionLogout(
Optional.empty());
}

public void handleMaxAgeLogout(Session previousSession) {
destroySessions(previousSession);
cloudwatchMetricsService.incrementLogout(Optional.empty());
}

private void sendAuditEvent(
TxmaAuditUser auditUser,
LogoutReason logoutReason,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -542,6 +543,35 @@ void successfullyLogsOutAndGeneratesRedirectResponseForeReauthenticationFailure(
is(equalTo(REAUTH_FAILURE_URI.toString())));
}

@Test
void handlesAMaxAgeSessionExpiry() {
var clientSessionId1 = IdGenerator.generate();
var clientSessionId2 = IdGenerator.generate();
var clientId1 = CLIENT_ID + "1";
var clientId2 = CLIENT_ID + "2";
var prevousSession =
new Session(SESSION_ID)
.setEmailAddress(EMAIL)
.addClientSession(clientSessionId1)
.addClientSession(clientSessionId2);
setUpClientSession(clientSessionId1, clientId1);
setUpClientSession(clientSessionId2, clientId2);

logoutService.handleMaxAgeLogout(prevousSession);

verify(clientSessionService, times(1)).deleteStoredClientSession(clientSessionId1);
verify(clientSessionService, times(1)).deleteStoredClientSession(clientSessionId2);
verify(sessionService).deleteStoredSession(session.getSessionId());
verify(orchSessionService).deleteSession(SESSION_ID);
verify(backChannelLogoutService)
.sendLogoutMessage(
argThat(withClientId(clientId1)), eq(EMAIL), eq(INTERNAL_SECTOR_URI));
verify(backChannelLogoutService)
.sendLogoutMessage(
argThat(withClientId(clientId2)), eq(EMAIL), eq(INTERNAL_SECTOR_URI));
verify(cloudwatchMetricsService).incrementLogout(Optional.empty());
}

private void setupAdditionalClientSessions() {
setUpClientSession("client-session-id-2", "client-id-2");
setUpClientSession("client-session-id-3", "client-id-3");
Expand Down
1 change: 1 addition & 0 deletions template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1747,6 +1747,7 @@ Resources:
authEnvironment,
]
ACCOUNT_INTERVENTIONS_ERROR_METRIC_NAME: "AISException"
BACK_CHANNEL_LOGOUT_QUEUE_URI: !GetAtt BackChannelLogoutQueue.QueueUrl
CREDENTIAL_STORE_URI: !Sub
- https://credential-store.${ServiceDomain}
- ServiceDomain:
Expand Down

0 comments on commit 1a6318d

Please sign in to comment.