Skip to content
This repository has been archived by the owner on Jun 14, 2024. It is now read-only.

Sync the latest changes from the argilla repository #11

Merged
merged 3 commits into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion argilla
Submodule argilla updated 103 files
178 changes: 129 additions & 49 deletions pdm.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ dependencies = [
"python-multipart ~= 0.0.5",
"python-jose[cryptography] >= 3.2,< 3.4",
"passlib[bcrypt] ~= 1.7.4",
# OAuth2 integration
"oauthlib ~= 3.2.0",
"social-auth-core ~= 4.5.0",
# Info status
"psutil >= 5.8, <5.10",
# Telemetry
Expand Down
2 changes: 1 addition & 1 deletion src/argilla_server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from argilla_server.app import app
from argilla_server._app import app # noqa
1 change: 0 additions & 1 deletion src/argilla_server/app.py → src/argilla_server/_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,4 +271,3 @@ async def log_default_user_warning_if_present():


app = create_server_app()
app._get_db_wrapper = _get_db_wrapper
38 changes: 38 additions & 0 deletions src/argilla_server/apis/v0/handlers/authentication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession

from argilla_server.contexts import accounts
from argilla_server.database import get_async_db
from argilla_server.errors import UnauthorizedError
from argilla_server.schemas.v0.authentication import Token, UserPasswordRequestForm
from argilla_server.security.authentication.jwt import JWT
from argilla_server.security.authentication.userinfo import UserInfo

router = APIRouter(tags=["Authentication"])


@router.post("/security/token", response_model=Token)
async def create_access_token(
db: AsyncSession = Depends(get_async_db), form: UserPasswordRequestForm = Depends()
) -> Token:
user = await accounts.authenticate_user(db, form.username, form.password)
if not user:
raise UnauthorizedError()

token = JWT.create(UserInfo(username=user.username, name=user.first_name, role=user.role, identity=str(user.id)))

return Token(access_token=token)
5 changes: 3 additions & 2 deletions src/argilla_server/apis/v0/handlers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
from argilla_server.errors import EntityAlreadyExistsError, EntityNotFoundError
from argilla_server.policies import UserPolicy, authorize
from argilla_server.pydantic_v1 import parse_obj_as
from argilla_server.schemas.v0.users import User, UserCreate
from argilla_server.security import auth
from argilla_server.security.model import User, UserCreate

router = APIRouter(tags=["users"])

Expand All @@ -53,7 +53,7 @@ async def whoami(

"""

await telemetry.track_login(request, username=current_user.username)
await telemetry.track_login(request, current_user)

user = User.from_orm(current_user)
# TODO: The current client checks if a user can work on a specific workspace
Expand Down Expand Up @@ -96,6 +96,7 @@ async def create_user(

try:
user = await accounts.create_user(db, user_create)
telemetry.track_user_created(user)
except Exception as e:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e))

Expand Down
3 changes: 2 additions & 1 deletion src/argilla_server/apis/v0/handlers/workspaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
from argilla_server.errors import EntityAlreadyExistsError, EntityNotFoundError
from argilla_server.policies import WorkspacePolicy, WorkspaceUserPolicy, authorize
from argilla_server.pydantic_v1 import parse_obj_as
from argilla_server.schemas.v0.users import User
from argilla_server.schemas.v0.workspaces import Workspace, WorkspaceCreate, WorkspaceUserCreate
from argilla_server.security import auth
from argilla_server.security.model import User, Workspace, WorkspaceCreate, WorkspaceUserCreate

router = APIRouter(tags=["workspaces"])

Expand Down
1 change: 0 additions & 1 deletion src/argilla_server/apis/v1/handlers/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ async def _get_field(db: "AsyncSession", field_id: UUID) -> "Field":
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Field with id `{field_id}` not found",
)

return field


Expand Down
102 changes: 102 additions & 0 deletions src/argilla_server/apis/v1/handlers/oauth2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import RedirectResponse
from sqlalchemy.ext.asyncio import AsyncSession

from argilla_server import telemetry
from argilla_server.contexts import accounts
from argilla_server.database import get_async_db
from argilla_server.enums import UserRole
from argilla_server.errors.future import AuthenticationError
from argilla_server.models import User
from argilla_server.schemas.v1.oauth2 import Provider, Providers, Token
from argilla_server.security.authentication.jwt import JWT
from argilla_server.security.authentication.oauth2 import OAuth2ClientProvider
from argilla_server.security.authentication.userinfo import UserInfo
from argilla_server.security.settings import settings

router = APIRouter(prefix="/oauth2", tags=["Authentication"])


_USER_ROLE_ON_CREATION = UserRole.annotator


@router.get("/providers", response_model=Providers)
def list_providers(_request: Request) -> Providers:
items = [Provider(name=provider_name) for provider_name in settings.oauth.providers]

return Providers(items=items)


@router.get("/providers/{provider}/authentication")
def get_authentication(request: Request, provider: str) -> RedirectResponse:
_check_oauth_enabled_or_raise()

provider = _get_provider_by_name_or_raise(provider)
return provider.authorization_redirect(request)


@router.get("/providers/{provider}/access-token", response_model=Token)
async def get_access_token(
request: Request,
provider: str,
db: AsyncSession = Depends(get_async_db),
) -> Token:
_check_oauth_enabled_or_raise()

try:
provider = _get_provider_by_name_or_raise(provider)
user_info = UserInfo(await provider.get_user_data(request))

user_info.use_claims(provider.claims)
username = user_info.username

user = await accounts.get_user_by_username(db, username)
if user is None:
user = await accounts.create_user_with_random_password(
db,
username=username,
first_name=user_info.name,
role=_USER_ROLE_ON_CREATION,
workspaces=[workspace.name for workspace in settings.oauth.allowed_workspaces],
)
telemetry.track_user_created(user, is_oauth=True)
elif not _is_user_created_by_oauth_provider(user):
# User should sign in using username/password workflow
raise AuthenticationError("Could not authenticate user")

return Token(access_token=JWT.create(user_info))
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except AuthenticationError as e:
raise HTTPException(status_code=401, detail=str(e))


def _check_oauth_enabled_or_raise() -> None:
if not settings.oauth.enabled:
raise HTTPException(status_code=404, detail="OAuth2 is not enabled")


def _get_provider_by_name_or_raise(provider_name: str) -> OAuth2ClientProvider:
if not provider_name in settings.oauth.providers:
raise HTTPException(status_code=404, detail=f"Provider '{provider_name}' not found")
return settings.oauth.providers[provider_name]


def _is_user_created_by_oauth_provider(user: User) -> bool:
# TODO: We must link the created user with the provider, and base this check on that.
# For now, we just validate the user role on creation.
return user.role == _USER_ROLE_ON_CREATION
7 changes: 2 additions & 5 deletions src/argilla_server/cli/database/users/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,8 @@
from argilla_server.database import AsyncSessionLocal
from argilla_server.models import User, UserRole
from argilla_server.pydantic_v1 import constr
from argilla_server.security.model import (
USER_PASSWORD_MIN_LENGTH,
UserCreate,
WorkspaceCreate,
)
from argilla_server.schemas.v0.users import USER_PASSWORD_MIN_LENGTH, UserCreate
from argilla_server.schemas.v0.workspaces import WorkspaceCreate

from .utils import get_or_new_workspace

Expand Down
9 changes: 6 additions & 3 deletions src/argilla_server/cli/database/users/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
from typing import TYPE_CHECKING, List, Optional

import typer
Expand All @@ -20,8 +21,8 @@
from argilla_server.database import AsyncSessionLocal
from argilla_server.models import User, UserRole
from argilla_server.pydantic_v1 import BaseModel, Field, constr
from argilla_server.security.auth_provider.db.settings import settings
from argilla_server.security.model import USER_USERNAME_REGEX, WORKSPACE_NAME_REGEX
from argilla_server.schemas.v0.users import USER_USERNAME_REGEX
from argilla_server.schemas.v0.workspaces import WORKSPACE_NAME_REGEX

from .utils import get_or_new_workspace

Expand Down Expand Up @@ -107,7 +108,9 @@ def _user_workspace_names(self, user: dict) -> List[str]:

def migrate():
"""Migrate users defined in YAML file to database."""
asyncio.run(UsersMigrator(settings.users_db_file).migrate())

users_db_file: str = os.getenv("ARGILLA_LOCAL_AUTH_USERS_DB_FILE", ".users.yml")
asyncio.run(UsersMigrator(users_db_file).migrate())


if __name__ == "__main__":
Expand Down
25 changes: 23 additions & 2 deletions src/argilla_server/contexts/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import secrets
from typing import TYPE_CHECKING, List, Union
from uuid import UUID

from passlib.context import CryptContext
from sqlalchemy import exists, select
from sqlalchemy.orm import Session, selectinload

from argilla_server.enums import UserRole
from argilla_server.models import User, Workspace, WorkspaceUser
from argilla_server.security.model import UserCreate, WorkspaceCreate, WorkspaceUserCreate
from argilla_server.schemas.v0.users import UserCreate
from argilla_server.schemas.v0.workspaces import WorkspaceCreate, WorkspaceUserCreate

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
Expand Down Expand Up @@ -141,6 +143,21 @@ async def create_user(db: "AsyncSession", user_create: UserCreate) -> User:
return user


async def create_user_with_random_password(
db,
username: str,
first_name: str,
workspaces: List[str] = None,
role: UserRole = UserRole.annotator,
) -> User:
password = _generate_random_password()

user_create = UserCreate(
first_name=first_name, username=username, role=role, password=password, workspaces=workspaces
)
return await create_user(db, user_create)


async def delete_user(db: "AsyncSession", user: User) -> User:
return await user.delete(db)

Expand All @@ -162,3 +179,7 @@ def hash_password(password: str) -> str:

def verify_password(password: str, password_hash: str) -> bool:
return _CRYPT_CONTEXT.verify(password, password_hash)


def _generate_random_password() -> str:
return secrets.token_urlsafe()
2 changes: 1 addition & 1 deletion src/argilla_server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
VectorSettings,
)
from argilla_server.models.suggestions import SuggestionCreateWithRecordId
from argilla_server.schemas.v0.users import User
from argilla_server.schemas.v1.datasets import (
DatasetCreate,
)
Expand Down Expand Up @@ -79,7 +80,6 @@
)
from argilla_server.schemas.v1.vectors import Vector as VectorSchema
from argilla_server.search_engine import SearchEngine
from argilla_server.security.model import User

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
Expand Down
2 changes: 1 addition & 1 deletion src/argilla_server/errors/future/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .base_errors import NotFoundError
from .base_errors import * # noqa
8 changes: 8 additions & 0 deletions src/argilla_server/errors/future/base_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ["NotFoundError", "AuthenticationError"]


class NotFoundError(Exception):
"""Custom Argilla not found error. Use it for situations where an Argilla domain entity has not be found on the system."""

pass


class AuthenticationError(Exception):
"""Custom Argilla unauthorized error. Use it for situations where an request is not authorized to perform an action."""

pass
4 changes: 4 additions & 0 deletions src/argilla_server/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from fastapi import APIRouter, HTTPException, Request

from argilla_server.apis.v0.handlers import (
authentication,
datasets,
info,
metrics,
Expand All @@ -36,6 +37,7 @@
from argilla_server.apis.v1.handlers import datasets as datasets_v1
from argilla_server.apis.v1.handlers import fields as fields_v1
from argilla_server.apis.v1.handlers import metadata_properties as metadata_properties_v1
from argilla_server.apis.v1.handlers import oauth2 as oauth2_v1
from argilla_server.apis.v1.handlers import questions as questions_v1
from argilla_server.apis.v1.handlers import records as records_v1
from argilla_server.apis.v1.handlers import responses as responses_v1
Expand All @@ -51,6 +53,7 @@
dependencies = []

for router in [
authentication.router,
users.router,
workspaces.router,
datasets.router,
Expand All @@ -76,6 +79,7 @@
api_router.include_router(users_v1.router, prefix="/v1")
api_router.include_router(vectors_settings_v1.router, prefix="/v1")
api_router.include_router(workspaces_v1.router, prefix="/v1")
api_router.include_router(oauth2_v1.router, prefix="/v1")


@api_router.route("/{_:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"], include_in_schema=False)
Expand Down
Loading
Loading