Skip to content

Commit

Permalink
Merge pull request #1242 from rommapp/fix/simplify-query-to-validate-…
Browse files Browse the repository at this point in the history
…username-exists

fix: Simplify query that validates new username already exists
  • Loading branch information
adamantike authored Oct 14, 2024
2 parents 5b73c4e + eba2971 commit fc8cbb7
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 3 deletions.
16 changes: 16 additions & 0 deletions backend/endpoints/tests/test_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,22 @@ def test_add_user_from_unauthorized_user(
assert response.status_code == expected_status_code


def test_add_user_with_existing_username(client, access_token, admin_user):
response = client.post(
"/api/users",
params={
"username": admin_user.username,
"password": "new_user_password",
"role": Role.VIEWER.value,
},
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == HTTPStatus.BAD_REQUEST

response = response.json()
assert response["detail"] == f"Username {admin_user.username} already exists"


def test_update_user(client, access_token, editor_user):
assert editor_user.role == Role.EDITOR

Expand Down
3 changes: 2 additions & 1 deletion backend/endpoints/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def add_user(request: Request, username: str, password: str, role: str) -> UserS
detail="Forbidden",
)

if username in [user.username for user in db_user_handler.get_users()]:
existing_user = db_user_handler.get_user_by_username(username)
if existing_user:
msg = f"Username {username} already exists"
log.error(msg)
raise HTTPException(
Expand Down
6 changes: 4 additions & 2 deletions backend/handler/database/users_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ def add_user(self, user: User, session: Session = None) -> User:
return session.merge(user)

@begin_session
def get_user_by_username(self, username: str, session: Session = None):
def get_user_by_username(
self, username: str, session: Session = None
) -> User | None:
return session.scalar(select(User).filter_by(username=username).limit(1))

@begin_session
def get_user(self, id: int, session: Session = None) -> User:
def get_user(self, id: int, session: Session = None) -> User | None:
return session.get(User, id)

@begin_session
Expand Down

0 comments on commit fc8cbb7

Please sign in to comment.