Skip to content

Commit

Permalink
refactor(auth): prepare to support quota project in UC (#656)
Browse files Browse the repository at this point in the history
  • Loading branch information
dbolduc authored Jan 9, 2025
1 parent f134976 commit f382b3b
Showing 1 changed file with 43 additions and 28 deletions.
71 changes: 43 additions & 28 deletions src/auth/src/credentials/user_credential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,14 @@ use time::OffsetDateTime;
const OAUTH2_ENDPOINT: &str = "https://oauth2.googleapis.com/token";

pub(crate) fn creds_from(js: serde_json::Value) -> Result<Credential> {
let token_provider = UserTokenProvider::from_json(js)?;
let au = serde_json::from_value::<AuthorizedUser>(js)
.map_err(|e| CredentialError::new(false, e.into()))?;
let token_provider = UserTokenProvider {
client_id: au.client_id,
client_secret: au.client_secret,
refresh_token: au.refresh_token,
endpoint: OAUTH2_ENDPOINT.to_string(),
};

Ok(Credential {
inner: Arc::new(UserCredential { token_provider }),
Expand All @@ -52,19 +59,6 @@ impl std::fmt::Debug for UserTokenProvider {
}
}

impl UserTokenProvider {
fn from_json(js: serde_json::Value) -> Result<Self> {
let au: AuthorizedUser =
serde_json::from_value(js).map_err(|e| CredentialError::new(false, e.into()))?;
Ok(UserTokenProvider {
client_id: au.client_id,
client_secret: au.client_secret,
refresh_token: au.refresh_token,
endpoint: OAUTH2_ENDPOINT.to_string(),
})
}
}

#[async_trait::async_trait]
impl TokenProvider for UserTokenProvider {
async fn get_token(&self) -> Result<Token> {
Expand Down Expand Up @@ -210,7 +204,7 @@ mod test {
}

#[test]
fn user_token_provider_from_json_success() {
fn authorized_user_full_from_json_success() {
let json = serde_json::json!({
"account": "",
"client_id": "test-client-id",
Expand All @@ -221,18 +215,37 @@ mod test {
"quota_project_id": "test-project"
});

let expected = UserTokenProvider {
let expected = AuthorizedUser {
cred_type: "authorized_user".to_string(),
client_id: "test-client-id".to_string(),
client_secret: "test-client-secret".to_string(),
refresh_token: "test-refresh-token".to_string(),
};
let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
assert_eq!(actual, expected);
}

#[test]
fn authorized_user_partial_from_json_success() {
let json = serde_json::json!({
"client_id": "test-client-id",
"client_secret": "test-client-secret",
"refresh_token": "test-refresh-token",
"type": "authorized_user",
});

let expected = AuthorizedUser {
cred_type: "authorized_user".to_string(),
client_id: "test-client-id".to_string(),
client_secret: "test-client-secret".to_string(),
refresh_token: "test-refresh-token".to_string(),
endpoint: OAUTH2_ENDPOINT.to_string(),
};
let actual = UserTokenProvider::from_json(json).unwrap();
let actual = serde_json::from_value::<AuthorizedUser>(json).unwrap();
assert_eq!(actual, expected);
}

#[test]
fn user_token_provider_from_json_parse_fail() {
fn authorized_user_from_json_parse_fail() {
let json_full = serde_json::json!({
"account": "",
"client_id": "test-client-id",
Expand All @@ -247,7 +260,9 @@ mod test {
let mut json = json_full.clone();
// Remove a required field from the JSON
json[required_field].take();
UserTokenProvider::from_json(json).err().unwrap();
serde_json::from_value::<AuthorizedUser>(json)
.err()
.unwrap();
}
}

Expand Down Expand Up @@ -455,13 +470,13 @@ mod test {
let (endpoint, _server) = start(StatusCode::OK, response_body).await;
println!("endpoint = {endpoint}");

let tp = UserTokenProvider {
let token_provider = UserTokenProvider {
client_id: "test-client-id".to_string(),
client_secret: "test-client-secret".to_string(),
refresh_token: "test-refresh-token".to_string(),
endpoint: endpoint,
};
let uc = UserCredential { token_provider: tp };
let uc = UserCredential { token_provider };
let now = OffsetDateTime::now_utc();
let token = uc.get_token().await?;
assert_eq!(token.token, "test-access-token");
Expand All @@ -486,13 +501,13 @@ mod test {
let (endpoint, _server) = start(StatusCode::OK, response_body).await;
println!("endpoint = {endpoint}");

let tp = UserTokenProvider {
let token_provider = UserTokenProvider {
client_id: "test-client-id".to_string(),
client_secret: "test-client-secret".to_string(),
refresh_token: "test-refresh-token".to_string(),
endpoint: endpoint,
};
let uc = UserCredential { token_provider: tp };
let uc = UserCredential { token_provider };
let token = uc.get_token().await?;
assert_eq!(token.token, "test-access-token");
assert_eq!(token.token_type, "test-token-type");
Expand All @@ -507,13 +522,13 @@ mod test {
start(StatusCode::SERVICE_UNAVAILABLE, "try again".to_string()).await;
println!("endpoint = {endpoint}");

let tp = UserTokenProvider {
let token_provider = UserTokenProvider {
client_id: "test-client-id".to_string(),
client_secret: "test-client-secret".to_string(),
refresh_token: "test-refresh-token".to_string(),
endpoint: endpoint,
};
let uc = UserCredential { token_provider: tp };
let uc = UserCredential { token_provider };
let e = uc.get_token().await.err().unwrap();
assert!(e.is_retryable());
assert!(e.source().unwrap().to_string().contains("try again"));
Expand All @@ -526,13 +541,13 @@ mod test {
let (endpoint, _server) = start(StatusCode::UNAUTHORIZED, "epic fail".to_string()).await;
println!("endpoint = {endpoint}");

let tp = UserTokenProvider {
let token_provider = UserTokenProvider {
client_id: "test-client-id".to_string(),
client_secret: "test-client-secret".to_string(),
refresh_token: "test-refresh-token".to_string(),
endpoint: endpoint,
};
let uc = UserCredential { token_provider: tp };
let uc = UserCredential { token_provider };
let e = uc.get_token().await.err().unwrap();
assert!(!e.is_retryable());
assert!(e.source().unwrap().to_string().contains("epic fail"));
Expand Down

0 comments on commit f382b3b

Please sign in to comment.