Skip to content
This repository has been archived by the owner on Jun 14, 2024. It is now read-only.

Commit

Permalink
feat: migrate /users endpoint to /api/v1/users
Browse files Browse the repository at this point in the history
  • Loading branch information
jfcalvo committed May 6, 2024
1 parent 06bd520 commit 4fb7e73
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 3 deletions.
17 changes: 15 additions & 2 deletions src/argilla_server/apis/v1/handlers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, Request, Security, status
Expand All @@ -20,9 +21,8 @@
from argilla_server import models, telemetry
from argilla_server.contexts import accounts
from argilla_server.database import get_async_db
from argilla_server.models import User
from argilla_server.policies import UserPolicyV1, authorize
from argilla_server.schemas.v1.users import User
from argilla_server.schemas.v1.users import User, Users
from argilla_server.schemas.v1.workspaces import Workspaces
from argilla_server.security import auth

Expand All @@ -36,6 +36,19 @@ async def get_current_user(request: Request, current_user: models.User = Securit
return current_user


@router.get("/users", response_model=Users)
async def list_users(
*,
db: AsyncSession = Depends(get_async_db),
current_user: models.User = Security(auth.get_current_user),
):
await authorize(current_user, UserPolicyV1.list)

users = await accounts.list_users(db)

return Users(items=users)


@router.get("/users/{user_id}/workspaces", response_model=Workspaces)
async def list_user_workspaces(
*,
Expand Down
2 changes: 2 additions & 0 deletions src/argilla_server/contexts/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ async def get_user_by_api_key(db: AsyncSession, api_key: str) -> Union[User, Non


async def list_users(db: "AsyncSession") -> Sequence[User]:
# TODO: After removing API v0 implementation we can remove the workspaces eager loading
# because is not used in the new API v1 endpoints.
result = await db.execute(select(User).order_by(User.inserted_at.asc()).options(selectinload(User.workspaces)))
return result.scalars().all()

Expand Down
4 changes: 4 additions & 0 deletions src/argilla_server/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,10 @@ async def is_allowed(actor: User) -> bool:


class UserPolicyV1:
@classmethod
async def list(cls, actor: User) -> bool:
return actor.is_owner

@classmethod
async def list_workspaces(cls, actor: User) -> bool:
return actor.is_owner
Expand Down
6 changes: 5 additions & 1 deletion src/argilla_server/schemas/v1/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from datetime import datetime
from typing import Optional
from typing import List, Optional
from uuid import UUID

from argilla_server.enums import UserRole
Expand All @@ -32,3 +32,7 @@ class User(BaseModel):

class Config:
orm_mode = True


class Users(BaseModel):
items: List[User]
81 changes: 81 additions & 0 deletions tests/unit/api/v1/users/test_list_users.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from argilla_server.constants import API_KEY_HEADER_NAME
from argilla_server.enums import UserRole
from argilla_server.models import User
from httpx import AsyncClient

from tests.factories import UserFactory


@pytest.mark.asyncio
class TestListUsers:
def url(self) -> str:
return "/api/v1/users"

async def test_list_users(self, async_client: AsyncClient, owner: User, owner_auth_header: dict):
user_a, user_b = await UserFactory.create_batch(2)

response = await async_client.get(self.url(), headers=owner_auth_header)

assert response.status_code == 200
assert response.json() == {
"items": [
{
"id": str(owner.id),
"first_name": owner.first_name,
"last_name": owner.last_name,
"username": owner.username,
"role": owner.role,
"api_key": owner.api_key,
"inserted_at": owner.inserted_at.isoformat(),
"updated_at": owner.updated_at.isoformat(),
},
{
"id": str(user_a.id),
"first_name": user_a.first_name,
"last_name": user_a.last_name,
"username": user_a.username,
"role": user_a.role,
"api_key": user_a.api_key,
"inserted_at": user_a.inserted_at.isoformat(),
"updated_at": user_a.updated_at.isoformat(),
},
{
"id": str(user_b.id),
"first_name": user_b.first_name,
"last_name": user_b.last_name,
"username": user_b.username,
"role": user_b.role,
"api_key": user_b.api_key,
"inserted_at": user_b.inserted_at.isoformat(),
"updated_at": user_b.updated_at.isoformat(),
},
]
}

@pytest.mark.parametrize("user_role", [UserRole.admin, UserRole.annotator])
async def test_list_users_with_invalid_role(self, async_client: AsyncClient, user_role: UserRole):
user = await UserFactory.create(role=user_role)

response = await async_client.get(self.url(), headers={API_KEY_HEADER_NAME: user.api_key})

assert response.status_code == 403

async def test_list_users_without_authentication(self, async_client: AsyncClient):
response = await async_client.get(self.url())

assert response.status_code == 401

0 comments on commit 4fb7e73

Please sign in to comment.