-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add session routes to create new sessions and authentication helper
- Loading branch information
Showing
4 changed files
with
73 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from fastapi import HTTPException, status, Security | ||
from fastapi.security import APIKeyHeader | ||
from db.models import Session | ||
|
||
api_key_header = APIKeyHeader(name='X-Session-ID', auto_error=False) | ||
|
||
|
||
async def get_current_session(session_id: str = Security(api_key_header)) -> Session: | ||
try: | ||
session = Session.get(Session.public_id == session_id, Session.is_valid is True) | ||
return session | ||
except Session.DoesNotExist: | ||
raise HTTPException( | ||
status_code=status.HTTP_401_UNAUTHORIZED, | ||
detail="Invalid session or session expired", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from fastapi import APIRouter, Depends | ||
from pydantic import BaseModel | ||
|
||
from http_api.auth import get_current_session | ||
from db.models import Session | ||
|
||
router = APIRouter() | ||
|
||
|
||
class SessionResponse(BaseModel): | ||
public_id: str | ||
|
||
|
||
@router.post('/session', response_model=SessionResponse) | ||
async def create_session() -> SessionResponse: | ||
session = Session() | ||
session.save() | ||
return SessionResponse(public_id=session.public_id) | ||
|
||
|
||
@router.get('/session/me', dependencies=[Depends(get_current_session)]) | ||
async def get_current_session(session: Session = Depends(get_current_session)): | ||
return { | ||
'message': f"hello there your session is '{session.public_id}', it was started at {session.created_at}." | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from db.models import Session | ||
|
||
|
||
def test_user_can_be_granted_a_session_by_visiting_auth_url(api_client): | ||
response = api_client.post('/session') | ||
data = response.json() | ||
|
||
assert response.status_code == 200 | ||
assert 'public_id' in data | ||
assert Session.select().filter(Session.public_id == data['public_id']).exists() | ||
|
||
|
||
def test_protected_route_with_valid_session(api_client): | ||
valid_session = Session() | ||
valid_session.save() | ||
|
||
headers = {'X-Session-ID': valid_session.public_id} | ||
response = api_client.get('/session/me', headers=headers) | ||
data = response.json() | ||
|
||
assert response.status_code == 200 | ||
assert valid_session.public_id in data['message'] | ||
|
||
|
||
def test_protected_route_with_invalid_session(api_client): | ||
headers = {'X-Session-ID': 'something-invalid'} | ||
response = api_client.get('/session/me', headers=headers) | ||
|
||
assert response.status_code == 401 |