Skip to content
This repository has been archived by the owner on Jun 14, 2024. It is now read-only.

Commit

Permalink
Merge branch 'feat/move-api-v0-functionality-to-v1' into feat/migrate…
Browse files Browse the repository at this point in the history
…-create-workspace-to-api-v1
  • Loading branch information
jfcalvo committed May 13, 2024
2 parents a29021b + 2ab1d19 commit a57ae46
Show file tree
Hide file tree
Showing 10 changed files with 639 additions and 24 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ These are the section headers that we use:

- Added `POST /api/v1/token` endpoint to generate a new API token for a user. ([#138](https://github.com/argilla-io/argilla-server/pull/138))
- Added `GET /api/v1/me` endpoint to get the current user information. ([#140](https://github.com/argilla-io/argilla-server/pull/140))
- Added `GET /api/v1/users` endpoint to get a list of all users. ([#142](https://github.com/argilla-io/argilla-server/pull/142))
- Added `POST /api/v1/users` endpoint to create a new user. ([#146](https://github.com/argilla-io/argilla-server/pull/146))
- Added `DELETE /api/v1/users` endpoint to delete a user. ([#148](https://github.com/argilla-io/argilla-server/pull/148))
- Added `POST /api/v1/workspaces` endpoint to create a new workspace. ([#150](https://github.com/argilla-io/argilla-server/pull/150))

## [Unreleased]()
Expand Down
11 changes: 6 additions & 5 deletions src/argilla_server/apis/v0/handlers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from argilla_server.contexts import accounts
from argilla_server.database import get_async_db
from argilla_server.errors import EntityAlreadyExistsError, EntityNotFoundError
from argilla_server.errors.future import NotUniqueError
from argilla_server.policies import UserPolicy, authorize
from argilla_server.pydantic_v1 import parse_obj_as
from argilla_server.schemas.v0.users import User, UserCreate
Expand Down Expand Up @@ -90,17 +91,17 @@ async def create_user(
):
await authorize(current_user, UserPolicy.create)

user = await accounts.get_user_by_username(db, user_create.username)
if user is not None:
raise EntityAlreadyExistsError(name=user_create.username, type=User)

try:
user = await accounts.create_user(db, user_create)
user = await accounts.create_user(db, user_create.dict(), user_create.workspaces)

telemetry.track_user_created(user)
except NotUniqueError:
raise EntityAlreadyExistsError(name=user_create.username, type=User)
except Exception as e:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e))

await user.awaitable_attrs.workspaces

return User.from_orm(user)


Expand Down
60 changes: 58 additions & 2 deletions src/argilla_server/apis/v1/handlers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, Request, Security, status
Expand All @@ -20,9 +21,9 @@
from argilla_server import models, telemetry
from argilla_server.contexts import accounts
from argilla_server.database import get_async_db
from argilla_server.models import User
from argilla_server.errors.future import NotUniqueError
from argilla_server.policies import UserPolicyV1, authorize
from argilla_server.schemas.v1.users import User
from argilla_server.schemas.v1.users import User, UserCreate, Users
from argilla_server.schemas.v1.workspaces import Workspaces
from argilla_server.security import auth

Expand All @@ -36,6 +37,61 @@ async def get_current_user(request: Request, current_user: models.User = Securit
return current_user


@router.get("/users", response_model=Users)
async def list_users(
*,
db: AsyncSession = Depends(get_async_db),
current_user: models.User = Security(auth.get_current_user),
):
await authorize(current_user, UserPolicyV1.list)

users = await accounts.list_users(db)

return Users(items=users)


@router.post("/users", status_code=status.HTTP_201_CREATED, response_model=User)
async def create_user(
*,
db: AsyncSession = Depends(get_async_db),
user_create: UserCreate,
current_user: models.User = Security(auth.get_current_user),
):
await authorize(current_user, UserPolicyV1.create)

try:
user = await accounts.create_user(db, user_create.dict())

telemetry.track_user_created(user)
except NotUniqueError as e:
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e))
except Exception as e:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e))

return user


@router.delete("/users/{user_id}", response_model=User)
async def delete_user(
*,
db: AsyncSession = Depends(get_async_db),
user_id: UUID,
current_user: models.User = Security(auth.get_current_user),
):
user = await accounts.get_user_by_id(db, user_id)
if user is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"User with id `{user_id}` not found",
)

await authorize(current_user, UserPolicyV1.delete)

await accounts.delete_user(db, user)

return user


@router.get("/users/{user_id}/workspaces", response_model=Workspaces)
async def list_user_workspaces(
*,
Expand Down
41 changes: 26 additions & 15 deletions src/argilla_server/contexts/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ async def get_user_by_api_key(db: AsyncSession, api_key: str) -> Union[User, Non


async def list_users(db: "AsyncSession") -> Sequence[User]:
# TODO: After removing API v0 implementation we can remove the workspaces eager loading
# because is not used in the new API v1 endpoints.
result = await db.execute(select(User).order_by(User.inserted_at.asc()).options(selectinload(User.workspaces)))
return result.scalars().all()

Expand All @@ -123,23 +125,29 @@ async def list_users_by_ids(db: AsyncSession, ids: Iterable[UUID]) -> Sequence[U
return result.scalars().all()


async def create_user(db: "AsyncSession", user_create: UserCreate) -> User:
# TODO: After removing API v0 implementation we can remove the workspaces attribute.
# With API v1 the workspaces will be created doing additional requests to other endpoints for it.
async def create_user(db: AsyncSession, user_attrs: dict, workspaces: Union[List[str], None] = None) -> User:
if (await get_user_by_username(db, user_attrs["username"])) is not None:
raise NotUniqueError(f"Username `{user_attrs['username']}` is not unique")

async with db.begin_nested():
user = await User.create(
db,
first_name=user_create.first_name,
last_name=user_create.last_name,
username=user_create.username,
role=user_create.role,
password_hash=hash_password(user_create.password),
first_name=user_attrs["first_name"],
last_name=user_attrs["last_name"],
username=user_attrs["username"],
role=user_attrs["role"],
password_hash=hash_password(user_attrs["password"]),
autocommit=False,
)

if user_create.workspaces:
for workspace_name in user_create.workspaces:
if workspaces is not None:
for workspace_name in workspaces:
workspace = await get_workspace_by_name(db, workspace_name)
if not workspace:
raise ValueError(f"Workspace '{workspace_name}' does not exist")

await WorkspaceUser.create(
db,
workspace_id=workspace.id,
Expand All @@ -156,15 +164,18 @@ async def create_user_with_random_password(
db,
username: str,
first_name: str,
workspaces: List[str] = None,
role: UserRole = UserRole.annotator,
workspaces: Union[List[str], None] = None,
) -> User:
password = _generate_random_password()

user_create = UserCreate(
first_name=first_name, username=username, role=role, password=password, workspaces=workspaces
)
return await create_user(db, user_create)
user_attrs = {
"first_name": first_name,
"last_name": None,
"username": username,
"role": role,
"password": _generate_random_password(),
}

return await create_user(db, user_attrs, workspaces)


async def delete_user(db: AsyncSession, user: User) -> User:
Expand Down
12 changes: 12 additions & 0 deletions src/argilla_server/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,18 @@ async def is_allowed(actor: User) -> bool:


class UserPolicyV1:
@classmethod
async def list(cls, actor: User) -> bool:
return actor.is_owner

@classmethod
async def create(cls, actor: User) -> bool:
return actor.is_owner

@classmethod
async def delete(cls, actor: User) -> bool:
return actor.is_owner

@classmethod
async def list_workspaces(cls, actor: User) -> bool:
return actor.is_owner
Expand Down
20 changes: 18 additions & 2 deletions src/argilla_server/schemas/v1/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,15 @@
# limitations under the License.

from datetime import datetime
from typing import Optional
from typing import List, Optional
from uuid import UUID

from argilla_server.enums import UserRole
from argilla_server.pydantic_v1 import BaseModel
from argilla_server.pydantic_v1 import BaseModel, Field, constr

USER_USERNAME_REGEX = "^(?!-|_)[A-za-z0-9-_]+$"
USER_PASSWORD_MIN_LENGTH = 8
USER_PASSWORD_MAX_LENGTH = 100


class User(BaseModel):
Expand All @@ -32,3 +36,15 @@ class User(BaseModel):

class Config:
orm_mode = True


class UserCreate(BaseModel):
first_name: constr(min_length=1, strip_whitespace=True)
last_name: Optional[constr(min_length=1, strip_whitespace=True)]
username: str = Field(regex=USER_USERNAME_REGEX, min_length=1)
role: Optional[UserRole]
password: str = Field(min_length=USER_PASSWORD_MIN_LENGTH, max_length=USER_PASSWORD_MAX_LENGTH)


class Users(BaseModel):
items: List[User]
47 changes: 47 additions & 0 deletions tests/unit/api/v0/test_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,53 @@ async def test_create_user_with_non_default_role(
assert response_body["role"] == UserRole.owner.value


@pytest.mark.asyncio
async def test_create_user_with_first_name_including_leading_and_trailing_spaces(
async_client: "AsyncClient", db: "AsyncSession", owner_auth_header: dict
):
response = await async_client.post(
"/api/users",
headers=owner_auth_header,
json={
"first_name": " First name ",
"username": "username",
"password": "12345678",
},
)

assert response.status_code == 200

assert (await db.execute(select(func.count(User.id)))).scalar() == 2
user = (await db.execute(select(User).filter_by(username="username"))).scalar_one()

assert response.json()["first_name"] == "First name"
assert user.first_name == "First name"


@pytest.mark.asyncio
async def test_create_user_with_last_name_including_leading_and_trailing_spaces(
async_client: "AsyncClient", db: "AsyncSession", owner_auth_header: dict
):
response = await async_client.post(
"/api/users",
headers=owner_auth_header,
json={
"first_name": "First name",
"last_name": " Last name ",
"username": "username",
"password": "12345678",
},
)

assert response.status_code == 200

assert (await db.execute(select(func.count(User.id)))).scalar() == 2
user = (await db.execute(select(User).filter_by(username="username"))).scalar_one()

assert response.json()["last_name"] == "Last name"
assert user.last_name == "Last name"


@pytest.mark.asyncio
async def test_create_user_without_authentication(async_client: "AsyncClient", db: "AsyncSession"):
user = {"first_name": "first-name", "username": "username", "password": "12345678"}
Expand Down
Loading

0 comments on commit a57ae46

Please sign in to comment.