Skip to content
This repository has been archived by the owner on Jun 14, 2024. It is now read-only.

Commit

Permalink
feat: migrate POST /users endpoint to POST /api/v1/users
Browse files Browse the repository at this point in the history
  • Loading branch information
jfcalvo committed May 7, 2024
1 parent 0453420 commit 06c32d1
Show file tree
Hide file tree
Showing 8 changed files with 422 additions and 24 deletions.
2 changes: 1 addition & 1 deletion src/argilla_server/apis/v0/handlers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ async def create_user(
raise EntityAlreadyExistsError(name=user_create.username, type=User)

try:
user = await accounts.create_user(db, user_create)
user = await accounts.create_user(db, user_create.dict(), user_create.workspaces)
telemetry.track_user_created(user)
except Exception as e:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e))
Expand Down
26 changes: 25 additions & 1 deletion src/argilla_server/apis/v1/handlers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@
from argilla_server import models, telemetry
from argilla_server.contexts import accounts
from argilla_server.database import get_async_db
from argilla_server.errors import EntityAlreadyExistsError
from argilla_server.policies import UserPolicyV1, authorize
from argilla_server.schemas.v1.users import User, Users
from argilla_server.schemas.v1.users import User, UserCreate, Users
from argilla_server.schemas.v1.workspaces import Workspaces
from argilla_server.security import auth

Expand All @@ -49,6 +50,29 @@ async def list_users(
return Users(items=users)


@router.post("/users", response_model=User)
async def create_user(
*,
db: AsyncSession = Depends(get_async_db),
user_create: UserCreate,
current_user: models.User = Security(auth.get_current_user),
):
await authorize(current_user, UserPolicyV1.create)

user = await accounts.get_user_by_username(db, user_create.username)
if user is not None:
raise EntityAlreadyExistsError(name=user_create.username, type=User)

try:
user = await accounts.create_user(db, user_create.dict())

telemetry.track_user_created(user)
except Exception as e:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e))

return user


@router.get("/users/{user_id}/workspaces", response_model=Workspaces)
async def list_user_workspaces(
*,
Expand Down
36 changes: 21 additions & 15 deletions src/argilla_server/contexts/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,23 +121,26 @@ async def list_users_by_ids(db: AsyncSession, ids: Iterable[UUID]) -> Sequence[U
return result.scalars().all()


async def create_user(db: "AsyncSession", user_create: UserCreate) -> User:
# TODO: After removing API v0 implementation we can remove the workspaces attribute.
# With API v1 the workspaces will be created doing additional requests to other endpoints for it.
async def create_user(db: AsyncSession, user_attrs: dict, workspaces: Union[List[str], None] = None) -> User:
async with db.begin_nested():
user = await User.create(
db,
first_name=user_create.first_name,
last_name=user_create.last_name,
username=user_create.username,
role=user_create.role,
password_hash=hash_password(user_create.password),
first_name=user_attrs["first_name"],
last_name=user_attrs["last_name"],
username=user_attrs["username"],
role=user_attrs["role"],
password_hash=hash_password(user_attrs["password"]),
autocommit=False,
)

if user_create.workspaces:
for workspace_name in user_create.workspaces:
if workspaces is not None:
for workspace_name in workspaces:
workspace = await get_workspace_by_name(db, workspace_name)
if not workspace:
raise ValueError(f"Workspace '{workspace_name}' does not exist")

await WorkspaceUser.create(
db,
workspace_id=workspace.id,
Expand All @@ -154,15 +157,18 @@ async def create_user_with_random_password(
db,
username: str,
first_name: str,
workspaces: List[str] = None,
role: UserRole = UserRole.annotator,
workspaces: Union[List[str], None] = None,
) -> User:
password = _generate_random_password()

user_create = UserCreate(
first_name=first_name, username=username, role=role, password=password, workspaces=workspaces
)
return await create_user(db, user_create)
user_attrs = {
"first_name": first_name,
"last_name": None,
"username": username,
"role": role,
"password": _generate_random_password(),
}

return await create_user(db, user_attrs, workspaces)


async def delete_user(db: AsyncSession, user: User) -> User:
Expand Down
4 changes: 4 additions & 0 deletions src/argilla_server/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ class UserPolicyV1:
async def list(cls, actor: User) -> bool:
return actor.is_owner

@classmethod
async def create(cls, actor: User) -> bool:
return actor.is_owner

@classmethod
async def list_workspaces(cls, actor: User) -> bool:
return actor.is_owner
Expand Down
14 changes: 13 additions & 1 deletion src/argilla_server/schemas/v1/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
from uuid import UUID

from argilla_server.enums import UserRole
from argilla_server.pydantic_v1 import BaseModel
from argilla_server.pydantic_v1 import BaseModel, Field, constr

USER_USERNAME_REGEX = "^(?!-|_)[A-za-z0-9-_]+$"
USER_PASSWORD_MIN_LENGTH = 8
USER_PASSWORD_MAX_LENGTH = 100


class User(BaseModel):
Expand All @@ -34,5 +38,13 @@ class Config:
orm_mode = True


class UserCreate(BaseModel):
first_name: constr(min_length=1, strip_whitespace=True)
last_name: Optional[constr(min_length=1, strip_whitespace=True)]
username: str = Field(regex=USER_USERNAME_REGEX, min_length=1)
role: Optional[UserRole]
password: str = Field(min_length=USER_PASSWORD_MIN_LENGTH, max_length=USER_PASSWORD_MAX_LENGTH)


class Users(BaseModel):
items: List[User]
47 changes: 47 additions & 0 deletions tests/unit/api/v0/test_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,53 @@ async def test_create_user_with_non_default_role(
assert response_body["role"] == UserRole.owner.value


@pytest.mark.asyncio
async def test_create_user_with_first_name_including_leading_and_trailing_spaces(
async_client: "AsyncClient", db: "AsyncSession", owner_auth_header: dict
):
response = await async_client.post(
"/api/users",
headers=owner_auth_header,
json={
"first_name": " First name ",
"username": "username",
"password": "12345678",
},
)

assert response.status_code == 200

assert (await db.execute(select(func.count(User.id)))).scalar() == 2
user = (await db.execute(select(User).filter_by(username="username"))).scalar_one()

assert response.json()["first_name"] == "First name"
assert user.first_name == "First name"


@pytest.mark.asyncio
async def test_create_user_with_last_name_including_leading_and_trailing_spaces(
async_client: "AsyncClient", db: "AsyncSession", owner_auth_header: dict
):
response = await async_client.post(
"/api/users",
headers=owner_auth_header,
json={
"first_name": "First name",
"last_name": " Last name ",
"username": "username",
"password": "12345678",
},
)

assert response.status_code == 200

assert (await db.execute(select(func.count(User.id)))).scalar() == 2
user = (await db.execute(select(User).filter_by(username="username"))).scalar_one()

assert response.json()["last_name"] == "Last name"
assert user.last_name == "Last name"


@pytest.mark.asyncio
async def test_create_user_without_authentication(async_client: "AsyncClient", db: "AsyncSession"):
user = {"first_name": "first-name", "username": "username", "password": "12345678"}
Expand Down
Loading

0 comments on commit 06c32d1

Please sign in to comment.