Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Logout feature #79

Merged
merged 5 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion backend/app/auth/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
from typing import Annotated

from fastapi import Depends
Expand All @@ -7,7 +8,13 @@
from app.database import get_async_session

from .models import User
from .service import ALGORITHM, SECRET_KEY, get_user_by_username, oauth2_scheme
from .service import (
ALGORITHM,
SECRET_KEY,
get_user_by_username,
oauth2_scheme,
validate_token,
)

from .exceptions import credentials_exception

Expand All @@ -20,9 +27,16 @@ async def get_current_user(
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str | None = payload.get("sub")

if not (await validate_token(token, session)):
raise credentials_exception

if username is None:
raise credentials_exception

date_expir: int | None = payload.get("exp")
if date_expir is None or date_expir < datetime.datetime.now().timestamp():
raise credentials_exception

except JWTError as err:
raise credentials_exception from err

Expand Down
6 changes: 6 additions & 0 deletions backend/app/auth/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@
headers={"WWW-Authenticate": "Bearer"},
)

expired_credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Expired credentials",
headers={"WWW-Authenticate": "Bearer"},
)

username_taken_exception = HTTPException(
status_code=status.HTTP_406_NOT_ACCEPTABLE, detail="Username already taken."
)
7 changes: 7 additions & 0 deletions backend/app/auth/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,10 @@ class Role(Base):
timestamp = Column(DateTime, nullable=False)

users = relationship("User", back_populates="role")


class LoggedInTokens(Base):
__tablename__ = "logged_in_tokens"

token = Column(String, primary_key=True, nullable=False)
expiration = Column(DateTime, nullable=False)
14 changes: 10 additions & 4 deletions backend/app/auth/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
from app.database import get_async_session

from .schemas import Token, User
from .service import authenticate_user, create_access_token, create_user
from .service import authenticate_user, create_access_token, create_user, remove_token
from .exceptions import credentials_exception

from .service import oauth2_scheme

router = APIRouter(tags=["auth"])


Expand All @@ -23,13 +25,17 @@ async def login_for_access_token(
if not user:
raise credentials_exception

access_token = create_access_token(data={"sub": user.username})
access_token = await create_access_token(session, data={"sub": user.username})
return {"access_token": access_token, "token_type": "bearer"}


@router.post("/logout")
def logout() -> None:
...
async def logout(
token: Annotated[str, Depends(oauth2_scheme)],
session: Annotated[AsyncSession, Depends(get_async_session)],
) -> dict:
await remove_token(token, session)
return {"message": "Logged out"}


@router.post("/register")
Expand Down
45 changes: 40 additions & 5 deletions backend/app/auth/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@
from fastapi.security import OAuth2PasswordBearer
from jose import jwt
from passlib.context import CryptContext
from sqlalchemy import and_, select
from sqlalchemy import and_, delete, select
from sqlalchemy.ext.asyncio import AsyncSession

from ..database import get_async_session

from .constants import ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM
from .models import User
from .models import LoggedInTokens, User
from .exceptions import username_taken_exception

import hashlib

SECRET_KEY = os.getenv("JWT_SECRET_KEY", "SECRET")

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
Expand Down Expand Up @@ -58,18 +62,41 @@ async def authenticate_user(
)


def create_access_token(
async def create_access_token(
session: AsyncSession,
data: dict,
expires_delta: timedelta | None = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
) -> str:
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=15)
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)

to_encode.update({"exp": expire})
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)

token = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)

encrypted_token = hashlib.sha256(token.encode()).hexdigest()
session.add(LoggedInTokens(token=encrypted_token, expiration=expire))
await session.commit()

return token


async def remove_token(token: str, session: AsyncSession) -> None:
await session.execute(delete(LoggedInTokens).where(LoggedInTokens.token == token))
await session.commit()


async def validate_token(token: str, session: AsyncSession) -> bool:
async with session:
encrypted_token = hashlib.sha256(token.encode()).hexdigest()

valid_token = await session.execute(
select(LoggedInTokens).filter(LoggedInTokens.token == encrypted_token)
)
return valid_token.scalar_one_or_none() is not None


async def add_user(user: User, session: AsyncSession) -> None:
Expand All @@ -95,3 +122,11 @@ async def get_user_by_username(username: str, session: AsyncSession) -> User | N
async with session:
users = await session.execute(select(User).filter(User.username == username))
return users.scalar_one_or_none()


async def delete_inactive_tokens() -> None:
async for session in get_async_session():
await session.execute(
delete(LoggedInTokens).where(LoggedInTokens.expiration < datetime.utcnow())
)
await session.commit()
3 changes: 3 additions & 0 deletions backend/app/constants.py
amatanasovska marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
MAX_DB_CONNECTION_ATTEMPTS = 10

INVALID_TOKENS_CLEAN_UP_INTERVAL_SECONDS = 60 * 30
21 changes: 21 additions & 0 deletions backend/app/jobs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.interval import IntervalTrigger

from .auth.service import delete_inactive_tokens

from .constants import INVALID_TOKENS_CLEAN_UP_INTERVAL_SECONDS


def configure_auth_jobs(scheduler: AsyncIOScheduler) -> AsyncIOScheduler:
scheduler.add_job(
delete_inactive_tokens,
trigger=IntervalTrigger(seconds=INVALID_TOKENS_CLEAN_UP_INTERVAL_SECONDS),
)


def schedule_jobs() -> AsyncIOScheduler:
scheduler = AsyncIOScheduler()
configure_auth_jobs(scheduler)
scheduler.start()

return scheduler
13 changes: 9 additions & 4 deletions backend/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@
import uvicorn
import time
import sys

from .auth.router import router as auth_router
from .file_transfer.constants import FILE_PATH
from .file_transfer.router import router as file_router

from app.settings import APISettings

MAX_CONNECTION_ATTEMPTS = 10
from .constants import MAX_DB_CONNECTION_ATTEMPTS

from .jobs import schedule_jobs


@asynccontextmanager
Expand All @@ -24,21 +26,24 @@ async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
for i in range(10):
try:
await initialize_database()
print("Database initialized...")
break
except Exception as ex:
print(ex)
time.sleep(1)
if i == MAX_CONNECTION_ATTEMPTS - 1:
if i == MAX_DB_CONNECTION_ATTEMPTS - 1:
sys.exit()

Path.mkdir(Path(FILE_PATH), exist_ok=True)

scheduler = schedule_jobs()

yield

print("Application is shutting down")
files = Path.rglob(Path(FILE_PATH), "*")
for f in files:
Path.unlink(f)
scheduler.shutdown(wait=False)


def make_app() -> FastAPI:
Expand Down
Loading
Loading