diff --git a/backend/alembic/versions/548cdd92ed41_setup_roles.py b/backend/alembic/versions/548cdd92ed41_setup_roles.py new file mode 100644 index 0000000..31f51bc --- /dev/null +++ b/backend/alembic/versions/548cdd92ed41_setup_roles.py @@ -0,0 +1,30 @@ +"""Setup Roles + +Revision ID: 548cdd92ed41 +Revises: 0f18e18f9ae9 +Create Date: 2024-06-16 14:36:40.496965 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '548cdd92ed41' +down_revision: Union[str, None] = '0f18e18f9ae9' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('user', 'quota') + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('user', sa.Column('quota', sa.INTEGER(), autoincrement=False, nullable=True)) + # ### end Alembic commands ### diff --git a/backend/app/auth/models.py b/backend/app/auth/models.py index 14229a4..9b31722 100644 --- a/backend/app/auth/models.py +++ b/backend/app/auth/models.py @@ -16,7 +16,6 @@ class User(Base): username = Column(String, nullable=False, unique=True) password = Column(String, nullable=False) avatar = Column(String, nullable=True) - quota = Column(Integer, nullable=True) timestamp = Column(DateTime(timezone=True), nullable=True) role_id = Column(UUID(as_uuid=True), ForeignKey("role.id"), nullable=True) @@ -29,8 +28,14 @@ class User(Base): shared_files = relationship("Share", back_populates="user", lazy="selectin") code_2fa = Column(String, nullable=True) - def has_remaining_quota(self: Self) -> bool: - return bool(self.quota != 0) + def has_remaining_files_quota(self: Self) -> bool: + return self.role.quota_files is None or len(self.files) < self.role.quota_files + + def has_remaining_size_quota(self: Self, size: int) -> bool: + return self.role.quota_size is None or self.get_used_space() + size < self.role.quota_size + + def get_used_space(self: Self) -> int: + return sum(file.size for file in self.files) class Role(Base): diff --git a/backend/app/auth/router.py b/backend/app/auth/router.py index da30651..ad616d5 100644 --- a/backend/app/auth/router.py +++ b/backend/app/auth/router.py @@ -32,7 +32,7 @@ async def simple_test() -> str: @router.post("/create-test-user", response_model=str) async def test(session: Annotated[AsyncSession, Depends(get_async_session)]) -> str: - await create_user("a", "a", 30, session) + await create_user("a", "a", session) return "Created a test user" @@ -81,7 +81,7 @@ async def logout( async def register( user_schema: User, session: Annotated[AsyncSession, Depends(get_async_session)] ) -> RequestStatus: - user = await create_user(user_schema.username, user_schema.password, 30, session) + user = await create_user(user_schema.username, user_schema.password, session) return RequestStatus(message=f"User {user.username} registered successfully") @@ -121,6 +121,8 @@ async def fetch_user_data( ) -> UserMetadata: return UserMetadata( username=str(current_user.username), - quota=int(current_user.quota), + role=str(current_user.role.name), + files_quota=int(current_user.role.quota_files), + size_quota=int(current_user.role.quota_size), is_2fa_enabled=(current_user.code_2fa is not None), ) diff --git a/backend/app/auth/schemas.py b/backend/app/auth/schemas.py index 1b15e58..1746edf 100644 --- a/backend/app/auth/schemas.py +++ b/backend/app/auth/schemas.py @@ -25,5 +25,7 @@ class Code2FA(BaseModel): class UserMetadata(BaseModel): username: str - quota: int + role: str + files_quota: int + size_quota: int is_2fa_enabled: bool diff --git a/backend/app/auth/service.py b/backend/app/auth/service.py index 4b219c5..ebd4d8d 100644 --- a/backend/app/auth/service.py +++ b/backend/app/auth/service.py @@ -10,11 +10,12 @@ from passlib.context import CryptContext from sqlalchemy import ColumnElement, delete, select, update from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload from ..database import get_async_session from .constants import ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM from .exceptions import USERNAME_TAKEN_EXCEPTION -from .models import LoggedInTokens, User +from .models import LoggedInTokens, Role, User SECRET_KEY = os.getenv("JWT_SECRET_KEY", "SECRET") @@ -22,11 +23,36 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login") -async def create_user( - username: str, plain_password: str, quota: int, session: AsyncSession -) -> User: +async def get_user_role(session: AsyncSession) -> Role: + role = await session.execute(select(Role).where(Role.name == "user")) + + scalar_role = role.scalar_one_or_none() + + if scalar_role is None: + raise ValueError("Role 'user' not found") + + return scalar_role + + +async def get_admin_role(session: AsyncSession) -> Role: + role = await session.execute(select(Role).where(Role.name == "admin")) + + scalar_role = role.scalar_one_or_none() + + if scalar_role is None: + raise ValueError("Role 'admin' not found") + + return scalar_role + + +async def create_user(username: str, plain_password: str, session: AsyncSession) -> User: password = pwd_context.hash(plain_password) - user = User(username=username, password=password, quota=quota) + + user_role = ( + await get_admin_role(session) if username == "admin" else await get_user_role(session) + ) + + user = User(username=username, password=password, role_id=user_role.id) existing_user = await get_user_by_username(username, session) @@ -137,7 +163,9 @@ async def get_user_by_id(user_id: UUID, session: AsyncSession) -> User | None: async def get_user_by_username(username: str, session: AsyncSession) -> User | None: async with session: - users = await session.execute(select(User).filter(User.username == username)) + users = await session.execute( + select(User).options(joinedload(User.role)).filter(User.username == username) + ) return users.scalar_one_or_none() diff --git a/backend/app/database.py b/backend/app/database.py index d4a87a5..b3bb8a2 100644 --- a/backend/app/database.py +++ b/backend/app/database.py @@ -1,6 +1,8 @@ import os from collections.abc import AsyncGenerator +from datetime import datetime +from sqlalchemy import select from sqlalchemy.ext.asyncio import ( AsyncEngine, AsyncSession, @@ -11,6 +13,8 @@ from alembic import command from alembic.config import Config +from .auth.models import Role + SQLALCHEMY_DATABASE_URL = os.getenv( "DATABASE_URL", "postgresql+asyncpg://synthra:synthra@database:5432/synthra", @@ -29,7 +33,7 @@ def get_engine(cls: type["DatabaseEngine"]) -> AsyncEngine: class AsyncSessionMaker: - _instance: None | async_sessionmaker[AsyncSession] = None + _instance: async_sessionmaker[AsyncSession] | None = None @classmethod def get_sessionmaker(cls: type["AsyncSessionMaker"]) -> async_sessionmaker[AsyncSession]: @@ -51,3 +55,22 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]: def run_migrations() -> None: alembic_cfg = Config("alembic.ini") command.upgrade(alembic_cfg, "head") + + +async def initialize_database() -> None: + async with AsyncSessionMaker.get_sessionmaker()() as session: + await initialize_roles(session) + + +async def initialize_roles(session: AsyncSession) -> None: + existing_roles = await session.execute(select(Role).limit(1)) + if existing_roles.scalars().first() is not None: + return + + roles = [ + Role(name="admin", quota_size=1000000000, quota_files=50, timestamp=datetime.now()), + Role(name="user", quota_size=100000000, quota_files=10, timestamp=datetime.now()), + ] + + session.add_all(roles) + await session.commit() diff --git a/backend/app/files/service.py b/backend/app/files/service.py index cc32fec..cbe1fc7 100644 --- a/backend/app/files/service.py +++ b/backend/app/files/service.py @@ -9,7 +9,7 @@ from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC from fastapi import Depends, UploadFile -from sqlalchemy import delete, select, update +from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncSession from ..auth.dependencies import get_current_user @@ -27,7 +27,10 @@ async def upload_file( is_shared: bool = False, password: str | None = None, ) -> str: - if not current_user.has_remaining_quota(): + if not current_user.has_remaining_files_quota(): + raise QUOTA_EXCEPTION + + if not current_user.has_remaining_size_quota(file.size or 0): raise QUOTA_EXCEPTION file_path = f"{uuid.uuid4()}{file.filename}" @@ -56,11 +59,6 @@ async def upload_file( ) session.add(file_db) - update_statement = ( - update(User).where(User.id == current_user.id).values(quota=User.quota - 1) - ) - - await session.execute(update_statement) return str(file_path) diff --git a/backend/app/main.py b/backend/app/main.py index 8bc984e..dce63fc 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -7,7 +7,7 @@ from starlette.middleware.cors import CORSMiddleware from .auth.router import router as auth_router -from .database import run_migrations +from .database import initialize_database, run_migrations from .files.constants import FILE_PATH from .files.router import router as file_router from .jobs import schedule_jobs @@ -19,30 +19,40 @@ @asynccontextmanager async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]: - print("Starting...") + print("Starting...", flush=True) - print("Creating file storage directory...") - Path.mkdir(Path(FILE_PATH), exist_ok=True) + print("Creating file storage directory...", flush=True) + Path(FILE_PATH).mkdir(parents=True, exist_ok=True) + print("File storage directory created", flush=True) - print("Running migrations...") + print("Running migrations...", flush=True) run_migrations() + print("Migrations complete", flush=True) - print("Scheduling jobs...") + print("Initializing database...", flush=True) + await initialize_database() + print("Database initialized", flush=True) + + print("Scheduling jobs...", flush=True) scheduler = schedule_jobs() + print("Jobs scheduled", flush=True) - print("Server started") + print("Server started", flush=True) yield - print("Shutting down...") + print("Shutting down...", flush=True) + + print("Stopping jobs...", flush=True) scheduler.shutdown(wait=False) + print("Jobs stopped", flush=True) - print("Server stopped") + print("Server stopped", flush=True) def make_app() -> FastAPI: # set debug=True to enable verbose logging - app = FastAPI(lifespan=lifespan, debug=True) + app = FastAPI(lifespan=lifespan) # URL Normalizer Middleware app.add_middleware(SlashNormalizerMiddleware) diff --git a/docker-compose.yaml b/docker-compose.yaml index 53239a7..5ba7965 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -58,3 +58,5 @@ services: environment: PGADMIN_DEFAULT_EMAIL: test@test.com PGADMIN_DEFAULT_PASSWORD: test + volumes: + - ./pgadmin_data:/var/lib/pgadmin diff --git a/frontend/src/lib/types/UserMetadata.ts b/frontend/src/lib/types/UserMetadata.ts index 1406a36..976a71c 100644 --- a/frontend/src/lib/types/UserMetadata.ts +++ b/frontend/src/lib/types/UserMetadata.ts @@ -1,5 +1,7 @@ export type UserMetadata = { username: string; - quota: number; + files_quota: number; + role: string; + size_quota: number; is_2fa_enabled: boolean; }; diff --git a/frontend/src/routes/user/account/+page.svelte b/frontend/src/routes/user/account/+page.svelte index d56194b..dbb007f 100644 --- a/frontend/src/routes/user/account/+page.svelte +++ b/frontend/src/routes/user/account/+page.svelte @@ -1,19 +1,20 @@ -Hi, {username}! -
- + Account + Username: {user?.username} + Role: {user?.role.toUpperCase()} + +
+ + 2FA {#if user?.is_2fa_enabled} @@ -86,5 +94,16 @@ {/if} +
+ + ShareX + +
+ + Quotas + {numberOfFiles} / {user?.files_quota ?? 0} Files + {storageUsed} / {user?.size_quota ?? 0} B ({((storageUsed ?? 0) / 1000 / 1000).toFixed(2)} MB) +
diff --git a/frontend/src/routes/user/home/+page.svelte b/frontend/src/routes/user/home/+page.svelte index 75c6892..5a6d511 100644 --- a/frontend/src/routes/user/home/+page.svelte +++ b/frontend/src/routes/user/home/+page.svelte @@ -47,6 +47,11 @@ return; } + if (error.response?.status === 403) { + alert('Your session has expired or you have reached your quota.'); + return; + } + alert('An error occurred while uploading the file.'); } };