From f382b3b39840b92b42a348cc5d73854e9d1cc0a6 Mon Sep 17 00:00:00 2001 From: Darren Bolduc Date: Thu, 9 Jan 2025 15:48:06 -0500 Subject: [PATCH] refactor(auth): prepare to support quota project in UC (#656) --- src/auth/src/credentials/user_credential.rs | 71 +++++++++++++-------- 1 file changed, 43 insertions(+), 28 deletions(-) diff --git a/src/auth/src/credentials/user_credential.rs b/src/auth/src/credentials/user_credential.rs index b61cd1efd..6f8c190c0 100644 --- a/src/auth/src/credentials/user_credential.rs +++ b/src/auth/src/credentials/user_credential.rs @@ -26,7 +26,14 @@ use time::OffsetDateTime; const OAUTH2_ENDPOINT: &str = "https://oauth2.googleapis.com/token"; pub(crate) fn creds_from(js: serde_json::Value) -> Result { - let token_provider = UserTokenProvider::from_json(js)?; + let au = serde_json::from_value::(js) + .map_err(|e| CredentialError::new(false, e.into()))?; + let token_provider = UserTokenProvider { + client_id: au.client_id, + client_secret: au.client_secret, + refresh_token: au.refresh_token, + endpoint: OAUTH2_ENDPOINT.to_string(), + }; Ok(Credential { inner: Arc::new(UserCredential { token_provider }), @@ -52,19 +59,6 @@ impl std::fmt::Debug for UserTokenProvider { } } -impl UserTokenProvider { - fn from_json(js: serde_json::Value) -> Result { - let au: AuthorizedUser = - serde_json::from_value(js).map_err(|e| CredentialError::new(false, e.into()))?; - Ok(UserTokenProvider { - client_id: au.client_id, - client_secret: au.client_secret, - refresh_token: au.refresh_token, - endpoint: OAUTH2_ENDPOINT.to_string(), - }) - } -} - #[async_trait::async_trait] impl TokenProvider for UserTokenProvider { async fn get_token(&self) -> Result { @@ -210,7 +204,7 @@ mod test { } #[test] - fn user_token_provider_from_json_success() { + fn authorized_user_full_from_json_success() { let json = serde_json::json!({ "account": "", "client_id": "test-client-id", @@ -221,18 +215,37 @@ mod test { "quota_project_id": "test-project" }); - let expected = UserTokenProvider { + let expected = AuthorizedUser { + cred_type: "authorized_user".to_string(), + client_id: "test-client-id".to_string(), + client_secret: "test-client-secret".to_string(), + refresh_token: "test-refresh-token".to_string(), + }; + let actual = serde_json::from_value::(json).unwrap(); + assert_eq!(actual, expected); + } + + #[test] + fn authorized_user_partial_from_json_success() { + let json = serde_json::json!({ + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "refresh_token": "test-refresh-token", + "type": "authorized_user", + }); + + let expected = AuthorizedUser { + cred_type: "authorized_user".to_string(), client_id: "test-client-id".to_string(), client_secret: "test-client-secret".to_string(), refresh_token: "test-refresh-token".to_string(), - endpoint: OAUTH2_ENDPOINT.to_string(), }; - let actual = UserTokenProvider::from_json(json).unwrap(); + let actual = serde_json::from_value::(json).unwrap(); assert_eq!(actual, expected); } #[test] - fn user_token_provider_from_json_parse_fail() { + fn authorized_user_from_json_parse_fail() { let json_full = serde_json::json!({ "account": "", "client_id": "test-client-id", @@ -247,7 +260,9 @@ mod test { let mut json = json_full.clone(); // Remove a required field from the JSON json[required_field].take(); - UserTokenProvider::from_json(json).err().unwrap(); + serde_json::from_value::(json) + .err() + .unwrap(); } } @@ -455,13 +470,13 @@ mod test { let (endpoint, _server) = start(StatusCode::OK, response_body).await; println!("endpoint = {endpoint}"); - let tp = UserTokenProvider { + let token_provider = UserTokenProvider { client_id: "test-client-id".to_string(), client_secret: "test-client-secret".to_string(), refresh_token: "test-refresh-token".to_string(), endpoint: endpoint, }; - let uc = UserCredential { token_provider: tp }; + let uc = UserCredential { token_provider }; let now = OffsetDateTime::now_utc(); let token = uc.get_token().await?; assert_eq!(token.token, "test-access-token"); @@ -486,13 +501,13 @@ mod test { let (endpoint, _server) = start(StatusCode::OK, response_body).await; println!("endpoint = {endpoint}"); - let tp = UserTokenProvider { + let token_provider = UserTokenProvider { client_id: "test-client-id".to_string(), client_secret: "test-client-secret".to_string(), refresh_token: "test-refresh-token".to_string(), endpoint: endpoint, }; - let uc = UserCredential { token_provider: tp }; + let uc = UserCredential { token_provider }; let token = uc.get_token().await?; assert_eq!(token.token, "test-access-token"); assert_eq!(token.token_type, "test-token-type"); @@ -507,13 +522,13 @@ mod test { start(StatusCode::SERVICE_UNAVAILABLE, "try again".to_string()).await; println!("endpoint = {endpoint}"); - let tp = UserTokenProvider { + let token_provider = UserTokenProvider { client_id: "test-client-id".to_string(), client_secret: "test-client-secret".to_string(), refresh_token: "test-refresh-token".to_string(), endpoint: endpoint, }; - let uc = UserCredential { token_provider: tp }; + let uc = UserCredential { token_provider }; let e = uc.get_token().await.err().unwrap(); assert!(e.is_retryable()); assert!(e.source().unwrap().to_string().contains("try again")); @@ -526,13 +541,13 @@ mod test { let (endpoint, _server) = start(StatusCode::UNAUTHORIZED, "epic fail".to_string()).await; println!("endpoint = {endpoint}"); - let tp = UserTokenProvider { + let token_provider = UserTokenProvider { client_id: "test-client-id".to_string(), client_secret: "test-client-secret".to_string(), refresh_token: "test-refresh-token".to_string(), endpoint: endpoint, }; - let uc = UserCredential { token_provider: tp }; + let uc = UserCredential { token_provider }; let e = uc.get_token().await.err().unwrap(); assert!(!e.is_retryable()); assert!(e.source().unwrap().to_string().contains("epic fail"));