Skip to content
This repository has been archived by the owner on Jun 14, 2024. It is now read-only.

Commit

Permalink
feat: migrate GET /api/workspaces/:workspace_id/users to GET /api/v1/…
Browse files Browse the repository at this point in the history
…workspaces/:workspace_id/users
  • Loading branch information
jfcalvo committed May 8, 2024
1 parent 1cd5abc commit 38d5fa8
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 2 deletions.
28 changes: 26 additions & 2 deletions src/argilla_server/apis/v1/handlers/workspaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from argilla_server.database import get_async_db
from argilla_server.errors import EntityAlreadyExistsError
from argilla_server.models import User
from argilla_server.policies import WorkspacePolicyV1, authorize
from argilla_server.policies import WorkspacePolicyV1, WorkspaceUserPolicyV1, authorize
from argilla_server.schemas.v1.users import Users
from argilla_server.schemas.v1.workspaces import Workspace, WorkspaceCreate, Workspaces
from argilla_server.security import auth
from argilla_server.services.datasets import DatasetsService
Expand Down Expand Up @@ -97,7 +98,9 @@ async def delete_workspace(

@router.get("/me/workspaces", response_model=Workspaces)
async def list_workspaces_me(
*, db: AsyncSession = Depends(get_async_db), current_user: User = Security(auth.get_current_user)
*,
db: AsyncSession = Depends(get_async_db),
current_user: User = Security(auth.get_current_user),
) -> Workspaces:
await authorize(current_user, WorkspacePolicyV1.list_workspaces_me)

Expand All @@ -107,3 +110,24 @@ async def list_workspaces_me(
workspaces = await accounts.list_workspaces_by_user_id(db, current_user.id)

return Workspaces(items=workspaces)


@router.get("/workspaces/{workspace_id}/users", response_model=Users)
async def list_workspace_users(
*,
db: AsyncSession = Depends(get_async_db),
workspace_id: UUID,
current_user: User = Security(auth.get_current_user),
):
await authorize(current_user, WorkspaceUserPolicyV1.list(workspace_id))

workspace = await accounts.get_workspace_by_id(db, workspace_id)
if workspace is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Workspace with id `{workspace_id}` not found",
)

await workspace.awaitable_attrs.users

return Users(items=workspace.users)
11 changes: 11 additions & 0 deletions src/argilla_server/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,17 @@ async def is_allowed(actor: User) -> bool:
return is_allowed


class WorkspaceUserPolicyV1:
@classmethod
def list(cls, workspace_id: UUID) -> PolicyAction:
async def is_allowed(actor: User) -> bool:
return actor.is_owner or (
actor.is_admin and await _exists_workspace_user_by_user_and_workspace_id(actor, workspace_id)
)

return is_allowed


class WorkspacePolicy:
@classmethod
async def list(cls, actor: User) -> bool:
Expand Down
4 changes: 4 additions & 0 deletions src/argilla_server/schemas/v1/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,7 @@ class User(BaseModel):

class Config:
orm_mode = True


class Users(BaseModel):
items: list[User]
130 changes: 130 additions & 0 deletions tests/unit/api/v1/workspaces/test_list_workspace_users.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from uuid import UUID, uuid4

import pytest
from argilla_server.constants import API_KEY_HEADER_NAME
from argilla_server.enums import UserRole
from httpx import AsyncClient

from tests.factories import AdminFactory, AnnotatorFactory, UserFactory, WorkspaceFactory, WorkspaceUserFactory


@pytest.mark.asyncio
class TestListWorkspaceUsers:
def url(self, workspace_id: UUID) -> str:
return f"/api/v1/workspaces/{workspace_id}/users"

async def test_list_workspace_users(self, async_client: AsyncClient, owner_auth_header: dict):
workspace = await WorkspaceFactory.create()
users = await UserFactory.create_batch(3)
await WorkspaceUserFactory.create(workspace_id=workspace.id, user_id=users[0].id)
await WorkspaceUserFactory.create(workspace_id=workspace.id, user_id=users[1].id)
await WorkspaceUserFactory.create(workspace_id=workspace.id, user_id=users[2].id)

other_workspace = await WorkspaceFactory.create()
other_users = await UserFactory.create_batch(2)
await WorkspaceUserFactory.create(workspace_id=other_workspace.id, user_id=other_users[0].id)
await WorkspaceUserFactory.create(workspace_id=other_workspace.id, user_id=other_users[1].id)

response = await async_client.get(self.url(workspace.id), headers=owner_auth_header)

assert response.status_code == 200
assert response.json() == {
"items": [
{
"id": str(users[0].id),
"first_name": users[0].first_name,
"last_name": users[0].last_name,
"username": users[0].username,
"role": UserRole.annotator,
"api_key": users[0].api_key,
"inserted_at": users[0].inserted_at.isoformat(),
"updated_at": users[0].updated_at.isoformat(),
},
{
"id": str(users[1].id),
"first_name": users[1].first_name,
"last_name": users[1].last_name,
"username": users[1].username,
"role": UserRole.annotator,
"api_key": users[1].api_key,
"inserted_at": users[1].inserted_at.isoformat(),
"updated_at": users[1].updated_at.isoformat(),
},
{
"id": str(users[2].id),
"first_name": users[2].first_name,
"last_name": users[2].last_name,
"username": users[2].username,
"role": UserRole.annotator,
"api_key": users[2].api_key,
"inserted_at": users[2].inserted_at.isoformat(),
"updated_at": users[2].updated_at.isoformat(),
},
],
}

async def test_list_workspace_users_without_authentication(self, async_client: AsyncClient):
workspace = await WorkspaceFactory.create()

response = await async_client.get(self.url(workspace.id))

assert response.status_code == 401

async def test_list_workspace_users_as_admin(self, async_client: AsyncClient):
workspace = await WorkspaceFactory.create()
admin = await AdminFactory.create()
await WorkspaceUserFactory.create(workspace_id=workspace.id, user_id=admin.id)

response = await async_client.get(
self.url(workspace.id),
headers={API_KEY_HEADER_NAME: admin.api_key},
)

assert response.status_code == 200

async def test_list_workspace_users_as_admin_from_different_workspace(self, async_client: AsyncClient):
workspace = await WorkspaceFactory.create()
admin = await AdminFactory.create()

response = await async_client.get(
self.url(workspace.id),
headers={API_KEY_HEADER_NAME: admin.api_key},
)

assert response.status_code == 403

async def test_list_workspace_users_as_annotator(self, async_client: AsyncClient):
workspace = await WorkspaceFactory.create()
annotator = await AnnotatorFactory.create()
await WorkspaceUserFactory.create(workspace_id=workspace.id, user_id=annotator.id)

response = await async_client.get(
self.url(workspace.id),
headers={API_KEY_HEADER_NAME: annotator.api_key},
)

assert response.status_code == 403

async def test_list_workspace_with_nonexistent_workspace_id(
self, async_client: AsyncClient, owner_auth_header: dict
):
workspace_id = uuid4()

response = await async_client.get(self.url(workspace_id), headers=owner_auth_header)

assert response.status_code == 404
assert response.json() == {"detail": f"Workspace with id `{workspace_id}` not found"}

0 comments on commit 38d5fa8

Please sign in to comment.