Skip to content

Commit

Permalink
fix(rust): set default project/space/user
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianbenavides committed Aug 2, 2024
1 parent c87c66c commit 50f3fe0
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ impl ProjectsSqlxDatabase {

// Check if any of the emails are in the user table
let q = format!(
"SELECT EXISTS(SELECT 1 FROM \"user\" WHERE LOWER(email) IN ({}))",
r#"SELECT EXISTS(SELECT 1 FROM "user" WHERE LOWER(email) IN ({}))"#,
non_admin_emails
.iter()
.map(|e| format!("'{}'", e))
Expand Down Expand Up @@ -112,16 +112,36 @@ impl ProjectsSqlxDatabase {
}

// set the project as the default one
let query1 = query("UPDATE project SET is_default = $1 WHERE project_id = $2")
query("UPDATE project SET is_default = $1 WHERE project_id = $2")
.bind(true)
.bind(project_id);
query1.execute(&mut *transaction).await.void()?;
.bind(project_id)
.execute(&mut *transaction)
.await
.void()?;

// set all the others as non-default
let query2 = query("UPDATE project SET is_default = $1 WHERE project_id <> $2")
query("UPDATE project SET is_default = $1 WHERE project_id <> $2")
.bind(false)
.bind(project_id);
query2.execute(&mut *transaction).await.void()?;
.bind(project_id)
.execute(&mut *transaction)
.await
.void()?;

// set the associated space as default
query("UPDATE space SET is_default = $1 WHERE space_id = (SELECT space_id FROM project WHERE project_id = $2)")
.bind(true)
.bind(project_id)
.execute(&mut *transaction)
.await
.void()?;

// set all the others as non-default
query("UPDATE space SET is_default = $1 WHERE space_id <> (SELECT space_id FROM project WHERE project_id = $2)")
.bind(false)
.bind(project_id)
.execute(&mut *transaction)
.await
.void()?;

Ok(())
}
Expand All @@ -144,15 +164,16 @@ impl ProjectsRepository for ProjectsSqlxDatabase {
// projects with the same name that belong to other spaces.
project_name = &project.id;
} else {
// Is there a default project already?
let query1 = query("SELECT project_id FROM project WHERE is_default = $1").bind(true);
let project_id: Option<String> = query1
.fetch_optional(&mut *transaction)
.await
.into_core()?
.map(|row| row.get(0));
// The project is set as the default one if no other default project exists already
is_default = project_id.is_none() || project_id == Some(project.id.clone());
// Set to default if there is no default project
let default_project_id: Option<String> =
query("SELECT project_id FROM project WHERE is_default = $1")
.bind(true)
.fetch_optional(&mut *transaction)
.await
.into_core()?
.map(|row| row.get(0));
is_default =
default_project_id.is_none() || default_project_id.as_ref() == Some(&project.id);
}

let query2 = query(
Expand All @@ -177,6 +198,10 @@ impl ProjectsRepository for ProjectsSqlxDatabase {
.bind(project.operation_id.as_ref());
query2.execute(&mut *transaction).await.void()?;

if is_default {
self.set_as_default(&project.id, &mut transaction).await?;
}

// remove any existing users related to that project if any
let query3 = query("DELETE FROM user_project WHERE project_id = $1").bind(&project.id);
query3.execute(&mut *transaction).await.void()?;
Expand Down Expand Up @@ -664,10 +689,9 @@ mod test {

// retrieve them as a list or by name
let result = repository.get_projects().await?;
assert_eq!(
result,
vec![project1.clone(), project2.clone(), project3.clone()]
);
for project in vec![project1.clone(), project2.clone(), project3.clone()] {
assert!(result.contains(&project));
}

let result = repository.get_project_by_name("name2").await?;
assert_eq!(result, Some(project2.clone()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,17 @@ impl SpacesRepository for SpacesSqlxDatabase {
async fn store_space(&self, space: &Space) -> Result<()> {
let mut transaction = self.database.begin().await.into_core()?;

// Is there a default space already?
let query1 = query("SELECT space_id FROM project WHERE is_default = $1").bind(true);
let space_id: Option<String> = query1
.fetch_optional(&mut *transaction)
.await
.into_core()?
.map(|row| row.get(0));
// The space is set as the default one if no other default space exists already
let is_default = space_id.is_none() || space_id == Some(space.id.clone());
// Set to default if there is no default space
let is_default = {
let default_space_id: Option<String> =
query("SELECT space_id FROM space WHERE is_default = $1")
.bind(true)
.fetch_optional(&mut *transaction)
.await
.into_core()?
.map(|row| row.get(0));
default_space_id.is_none() || default_space_id.as_ref() == Some(&space.id)
};

let query2 = query(
r#"
Expand All @@ -78,6 +80,10 @@ impl SpacesRepository for SpacesSqlxDatabase {
.bind(is_default);
query2.execute(&mut *transaction).await.void()?;

if is_default {
self.set_as_default(&space.id, &mut transaction).await?;
}

// remove any existing users related to that space if any
let query3 = query("DELETE FROM user_space WHERE space_id = $1").bind(&space.id);
query3.execute(&mut *transaction).await.void()?;
Expand Down Expand Up @@ -363,7 +369,9 @@ mod test {

// retrieve them as a vector or by name
let result = repository.get_spaces().await?;
assert_eq!(result, vec![space1.clone(), space2.clone()]);
for space in vec![space1.clone(), space2.clone()] {
assert!(result.contains(&space));
}

let result = repository.get_space_by_name("name1").await?;
assert_eq!(result, Some(space1.clone()));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use itertools::Itertools;
use sqlx::*;

use crate::cloud::email_address::EmailAddress;
Expand Down Expand Up @@ -25,34 +26,74 @@ impl UsersSqlxDatabase {
pub async fn create() -> Result<Self> {
Ok(Self::new(SqlxDatabase::in_memory("users").await?))
}

async fn set_as_default(
&self,
email: &EmailAddress,
transaction: &mut AnyConnection,
) -> Result<()> {
// set the user as the default one
query(r#"UPDATE "user" SET is_default = $1 WHERE LOWER(email) = LOWER($2)"#)
.bind(true)
.bind(email)
.execute(&mut *transaction)
.await
.void()?;

// set all the others as non-default
query(r#"UPDATE "user" SET is_default = $1 WHERE LOWER(email) <> LOWER($2)"#)
.bind(false)
.bind(email)
.execute(&mut *transaction)
.await
.void()?;

Ok(())
}
}

#[async_trait]
impl UsersRepository for UsersSqlxDatabase {
async fn store_user(&self, user: &UserInfo) -> Result<()> {
let mut transaction = self.database.begin().await.into_core()?;

let query1 = query_scalar(
r#"SELECT EXISTS(SELECT email FROM "user" WHERE is_default = $1 AND email = LOWER($2))"#,
)
.bind(true)
.bind(&user.email);
let is_already_default: Boolean = query1.fetch_one(&mut *transaction).await.into_core()?;
let is_already_default = is_already_default.to_bool();
// Set to default if there is no default user
let is_default = {
let default_user_email: Option<String> =
query(r#"SELECT email FROM "user" WHERE is_default = $1"#)
.bind(true)
.fetch_optional(&mut *transaction)
.await
.into_core()?
.map(|row| row.get(0));
let default_user_email: Option<EmailAddress> =
default_user_email.map(|e| EmailAddress::new_unsafe(&e));
default_user_email.is_none() || default_user_email.as_ref() == Some(&user.email)
};

// Get user if it exists, using the lowercased email
let query1 =
query(r#"SELECT email FROM "user" WHERE LOWER(email) = LOWER($1)"#).bind(&user.email);
let existing_user: Option<String> = query1
.fetch_optional(&mut *transaction)
.await
.into_core()?
.map(|row| row.get(0));
let email = existing_user.unwrap_or(user.email.to_string());

let query2 = query(r#"
INSERT INTO "user" (email, sub, nickname, name, picture, updated_at, email_verified, is_default)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT (email)
DO UPDATE SET sub = $2, nickname = $3, name = $4, picture = $5, updated_at = $6, email_verified = $7, is_default = $8"#)
.bind(&user.email)
.bind(&email)
.bind(&user.sub)
.bind(&user.nickname)
.bind(&user.name)
.bind(&user.picture)
.bind(&user.updated_at)
.bind(user.email_verified)
.bind(is_already_default);
.bind(is_default);
query2.execute(&mut *transaction).await.void()?;

transaction.commit().await.void()
Expand All @@ -68,14 +109,19 @@ impl UsersRepository for UsersSqlxDatabase {
}

async fn set_default_user(&self, email: &EmailAddress) -> Result<()> {
let query = query(r#"UPDATE "user" SET is_default = $1 WHERE email = LOWER($2)"#)
.bind(true)
.bind(email);
query.execute(&*self.database.pool).await.void()
let mut transaction = self.database.begin().await.into_core()?;
self.set_as_default(email, &mut transaction).await?;
transaction.commit().await.void()
}

async fn get_user(&self, email: &EmailAddress) -> Result<Option<UserInfo>> {
let query = query_as(r#"SELECT email, sub, nickname, name, picture, updated_at, email_verified, is_default FROM "user" WHERE email = LOWER($1)"#).bind(email);
let query = query_as(
r#"SELECT email, sub, nickname, name,
picture, updated_at, email_verified, is_default
FROM "user"
WHERE LOWER(email) = LOWER($1)"#,
)
.bind(email);
let row: Option<UserRow> = query
.fetch_optional(&*self.database.pool)
.await
Expand All @@ -92,8 +138,39 @@ impl UsersRepository for UsersSqlxDatabase {
}

async fn delete_user(&self, email: &EmailAddress) -> Result<()> {
let query1 = query(r#"DELETE FROM "user" WHERE email = LOWER($1)"#).bind(email);
query1.execute(&*self.database.pool).await.void()
let mut transaction = self.database.begin().await.into_core()?;

// Check if the space is the default one
let q = query_scalar(
r#"SELECT EXISTS(SELECT 1 FROM "user" WHERE LOWER(email) = LOWER($1) AND is_default = $2)"#,
)
.bind(email.to_string())
.bind(true);
let is_default: Boolean = q.fetch_one(&mut *transaction).await.into_core()?;
let is_default = is_default.to_bool();

let query1 = query(r#"DELETE FROM "user" WHERE LOWER(email) = LOWER($1)"#).bind(email);
query1.execute(&mut *transaction).await.void()?;

// Set another space as default if the deleted one was the default
if is_default {
let user_emails: Vec<String> = query_scalar(r#"SELECT email FROM "user""#)
.fetch_all(&mut *transaction)
.await
.into_core()?;
let user_emails: Vec<_> = user_emails
.into_iter()
.map(|e| EmailAddress::new_unsafe(&e))
.unique()
.collect();
for email in user_emails {
if self.set_as_default(&email, &mut transaction).await.is_ok() {
break;
}
}
}

transaction.commit().await.void()
}
}

Expand Down Expand Up @@ -141,6 +218,7 @@ mod test {

let my_email_address: EmailAddress = "[email protected]".try_into().unwrap();
let your_email_address: EmailAddress = "[email protected]".try_into().unwrap();
let your_email_address_capitalized: EmailAddress = "[email protected]".try_into().unwrap();

// create and store 2 users
let user1 = UserInfo {
Expand All @@ -153,6 +231,15 @@ mod test {
email_verified: false,
};
let user2 = UserInfo {
sub: "ub".into(),
nickname: "You".to_string(),
name: "You".to_string(),
picture: "You".to_string(),
updated_at: "today".to_string(),
email: your_email_address_capitalized.clone(),
email_verified: false,
};
let user3 = UserInfo {
sub: "sub".into(),
nickname: "you".to_string(),
name: "you".to_string(),
Expand All @@ -163,25 +250,43 @@ mod test {
};

repository.store_user(&user1).await?;
// The first stored user is the default one
let result = repository.get_default_user().await?;
assert_eq!(result, Some(user1.clone()));

repository.store_user(&user2).await?;
// The second stored space is not the default one
let result = repository.get_default_user().await?;
assert_eq!(result, Some(user1.clone()));

repository.store_user(&user3).await?;
// The third space replaces the second one, as the email is equivalent
let result = repository.get_user(&your_email_address).await?;
assert_eq!(result, Some(user3.clone()));
let result = repository.get_user(&your_email_address_capitalized).await?;
assert_eq!(result, Some(user3.clone()));

// retrieve them as a vector or by name
let result = repository.get_users().await?;
assert_eq!(result, vec![user1.clone(), user2.clone()]);
assert_eq!(result, vec![user1.clone(), user3.clone()]);

let result = repository.get_user(&my_email_address).await?;
assert_eq!(result, Some(user1.clone()));

// a user can be set created as the default user
repository.set_default_user(&my_email_address).await?;
// a user can be set as the default user
repository.set_default_user(&your_email_address).await?;
let result = repository.get_default_user().await?;
assert_eq!(result, Some(user1.clone()));
assert_eq!(result, Some(user3.clone()));

// a user can be deleted
repository.delete_user(&your_email_address).await?;
let result = repository.get_user(&your_email_address).await?;
assert_eq!(result, None);

// if the default user is deleted, the next one becomes the default
let result = repository.get_default_user().await?;
assert_eq!(result, Some(user1.clone()));

let result = repository.get_users().await?;
assert_eq!(result, vec![user1.clone()]);
Ok(())
Expand Down
6 changes: 0 additions & 6 deletions implementations/rust/ockam/ockam_api/src/cli_state/users.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,7 @@ impl CliState {
#[instrument(skip_all, fields(user = %user))]
pub async fn store_user(&self, user: &UserInfo) -> Result<()> {
let repository = self.users_repository();
let is_first_user = repository.get_users().await?.is_empty();
repository.store_user(user).await?;

// if this is the first user we store we mark it as the default user
if is_first_user {
self.set_default_user(&user.email).await?
}
Ok(())
}

Expand Down
Loading

0 comments on commit 50f3fe0

Please sign in to comment.