Skip to content

Commit

Permalink
Updated cookie provision for localhost development
Browse files Browse the repository at this point in the history
  • Loading branch information
davenquinn committed Oct 21, 2024
1 parent d898e50 commit ea50536
Showing 1 changed file with 79 additions and 61 deletions.
140 changes: 79 additions & 61 deletions api/routes/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 1440 # 24 hours
GROUP_TOKEN_LENGTH = 32
GROUP_TOKEN_SALT = b'$2b$12$yQrslvQGWDFjwmDBMURAUe' # Hardcode salt so hashes are consistent
GROUP_TOKEN_SALT = b"$2b$12$yQrslvQGWDFjwmDBMURAUe" # Hardcode salt so hashes are consistent


class Token(BaseModel):
Expand Down Expand Up @@ -59,10 +59,12 @@ class GroupTokenRequest(BaseModel):
expiration: int
group_id: int


access_token_key = "access_token"
# Coming soon
# refresh_token_key = "refresh_token"


class OAuth2AuthorizationCodeBearerWithCookie(OAuth2AuthorizationCodeBearer):
"""Tweak FastAPI's OAuth2AuthorizationCodeBearer to use a cookie instead of a header"""

Expand All @@ -78,43 +80,36 @@ async def __call__(self, request: Request) -> Optional[str]:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={
"WWW-Authenticate": "Bearer"
},
headers={"WWW-Authenticate": "Bearer"},
)
else:
return None # pragma: nocover
return param


oauth2_scheme = OAuth2AuthorizationCodeBearerWithCookie(
authorizationUrl='/security/login',
tokenUrl="/security/callback",
auto_error=False
authorizationUrl="/security/login", tokenUrl="/security/callback", auto_error=False
)

http_bearer = HTTPBearer(auto_error=False)

router = APIRouter(
prefix="/security",
tags=["security"],
responses={
404: {
"description": "Not found"
}
},
responses={404: {"description": "Not found"}},
)


async def get_groups_from_header_token(
header_token: Annotated[HTTPAuthorizationCredentials, Depends(http_bearer)]) -> int | None:
header_token: Annotated[HTTPAuthorizationCredentials, Depends(http_bearer)]
) -> int | None:
"""Get the groups from the bearer token in the header"""

if header_token is None:
return None

token_hash = bcrypt.hashpw(header_token.credentials.encode(), GROUP_TOKEN_SALT)
token_hash_string = token_hash.decode('utf-8')
token_hash_string = token_hash.decode("utf-8")

engine = db.get_engine()
async_session = db.get_async_session(engine)
Expand All @@ -134,10 +129,7 @@ async def get_user(sub: str) -> schemas.User | None:
async_session = db.get_async_session(engine)

async with async_session() as session:
stmt = (
select(schemas.User)
.where(schemas.User.sub == sub)
)
stmt = select(schemas.User).where(schemas.User.sub == sub)

user = await session.scalar(stmt)

Expand Down Expand Up @@ -167,7 +159,9 @@ async def get_user_token_from_cookie(token: Annotated[str | None, Depends(oauth2
return None

try:
payload = jwt.decode(token, os.environ['SECRET_KEY'], algorithms=[os.environ['JWT_ENCRYPTION_ALGORITHM']])
payload = jwt.decode(
token, os.environ["SECRET_KEY"], algorithms=[os.environ["JWT_ENCRYPTION_ALGORITHM"]]
)
sub: str = payload.get("sub")
groups = payload.get("groups", [])
token_data = TokenData(sub=sub, groups=groups)
Expand All @@ -178,8 +172,8 @@ async def get_user_token_from_cookie(token: Annotated[str | None, Depends(oauth2


async def get_groups(
user_token_data: TokenData | None = Depends(get_user_token_from_cookie),
header_token: int | None = Depends(get_groups_from_header_token)
user_token_data: TokenData | None = Depends(get_user_token_from_cookie),
header_token: int | None = Depends(get_groups_from_header_token),
) -> list[int]:
"""Get the groups from both the cookies and header"""

Expand All @@ -196,7 +190,7 @@ async def get_groups(
async def has_access(groups: list[int] = Depends(get_groups)) -> bool:
"""Check if the user has access to the group"""

if 'ENVIRONMENT' in os.environ and os.environ['ENVIRONMENT'] == 'development':
if "ENVIRONMENT" in os.environ and os.environ["ENVIRONMENT"] == "development":
return True

return 1 in groups
Expand All @@ -210,10 +204,10 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None):
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode.update({
"exp": expire
})
encoded_jwt = jwt.encode(to_encode, os.environ['SECRET_KEY'], algorithm=os.environ['JWT_ENCRYPTION_ALGORITHM'])
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(
to_encode, os.environ["SECRET_KEY"], algorithm=os.environ["JWT_ENCRYPTION_ALGORITHM"]
)
return encoded_jwt


Expand All @@ -222,63 +216,70 @@ async def redirect_authorization(return_url: str = None):
"""Redirect to the authorization URL with the appropriate parameters"""

params = {
'scope': "openid profile email",
'client_id': os.environ['OAUTH_CLIENT_ID'],
'response_type': "code",
'redirect_uri': os.environ['REDIRECT_URI']
"scope": "openid profile email",
"client_id": os.environ["OAUTH_CLIENT_ID"],
"response_type": "code",
"redirect_uri": os.environ["REDIRECT_URI"],
}

if return_url is not None:
params['state'] = return_url
params["state"] = return_url

return RedirectResponse(os.environ['OAUTH_AUTHORIZATION_URL'] + "?" + urllib.parse.urlencode(params))
return RedirectResponse(
os.environ["OAUTH_AUTHORIZATION_URL"] + "?" + urllib.parse.urlencode(params)
)


@router.get("/callback")
async def redirect_callback(code: str, state: Optional[str] = None):
"""Exchange the code for a token and redirect to the state URL"""

uri = os.environ['REDIRECT_URI']
uri = os.environ["REDIRECT_URI"]
data = {
'grant_type': 'authorization_code',
'client_id': os.environ['OAUTH_CLIENT_ID'],
'client_secret': os.environ['OAUTH_CLIENT_SECRET'],
'code': code,
'redirect_uri': uri
"grant_type": "authorization_code",
"client_id": os.environ["OAUTH_CLIENT_ID"],
"client_secret": os.environ["OAUTH_CLIENT_SECRET"],
"code": code,
"redirect_uri": uri,
}

# Get the domain for the redirect URL
parsed_url = urllib.parse.urlparse(uri)
domain = parsed_url.netloc


async with aiohttp.ClientSession() as session:
async with session.post(os.environ['OAUTH_TOKEN_URL'], data=data) as token_response:
async with session.post(os.environ["OAUTH_TOKEN_URL"], data=data) as token_response:

if token_response.status != 200:
raise HTTPException(status_code=400, detail=f"Invalid code: {await token_response.text()} ")
raise HTTPException(
status_code=400, detail=f"Invalid code: {await token_response.text()} "
)

response_data = await token_response.json()

async with session.post(os.environ['OAUTH_USERINFO_URL'], data=response_data) as user_response:
async with session.post(
os.environ["OAUTH_USERINFO_URL"], data=response_data
) as user_response:

if user_response.status != 200:
raise HTTPException(status_code=400,
detail=f"Couldn't get user information: {await user_response.text()} ")
raise HTTPException(
status_code=400,
detail=f"Couldn't get user information: {await user_response.text()} ",
)

user_data = await user_response.json()

user = await get_user(user_data['sub'])
user = await get_user(user_data["sub"])

if user is None:

given_name = user_data.get('given_name') if user_data.get('given_name') else ""
family_name = user_data.get('family_name') if user_data.get('family_name') else ""
given_name = user_data.get("given_name") if user_data.get("given_name") else ""
family_name = (
user_data.get("family_name") if user_data.get("family_name") else ""
)

user = await create_user(
user_data['sub'],
f"{given_name} {family_name}",
user_data.get('email', '')
user_data["sub"], f"{given_name} {family_name}", user_data.get("email", "")
)

names = [group.name for group in user.groups]
Expand All @@ -293,40 +294,57 @@ async def redirect_callback(code: str, state: Optional[str] = None):
"sub": user.sub,
"role": role, # For PostgREST
"groups": [group.id for group in user.groups],
"group_names": names
"group_names": names,
}
)

response = RedirectResponse(state if state else "/")
redirect_domain = urllib.parse.urlparse(state).netloc

# Set a cookie for the API domain
response.set_cookie(key=access_token_key, value=f"Bearer {access_token}", httponly=True, samesite="lax",
domain=domain)
_domain = domain

# Overrides for local development
for override in ["localhost", "127.0.0.1"]:
if override in redirect_domain:
_domain = override

response.set_cookie(
access_token_key,
f"Bearer {access_token}",
domain=_domain,
httponly=True,
samesite="lax",
)

return response


@router.post("/token", response_model=AccessToken)
async def create_group_token(group_token_request: GroupTokenRequest,
user_token: TokenData = Depends(get_user_token_from_cookie)):
async def create_group_token(
group_token_request: GroupTokenRequest,
user_token: TokenData = Depends(get_user_token_from_cookie),
):
"""Get an access token for the current user"""

if group_token_request.group_id not in user_token.groups:
raise HTTPException(status_code=401,
detail=f"User cannot create tokens for group {group_token_request.group_id}")
raise HTTPException(
status_code=401,
detail=f"User cannot create tokens for group {group_token_request.group_id}",
)

engine = db.get_engine()

token = ''.join(secrets.choice(string.ascii_letters + string.digits) for i in range(GROUP_TOKEN_LENGTH))
token = "".join(
secrets.choice(string.ascii_letters + string.digits) for i in range(GROUP_TOKEN_LENGTH)
)
token_hash = bcrypt.hashpw(token.encode("utf-8"), GROUP_TOKEN_SALT)
token_hash_string = token_hash.decode('utf-8')
token_hash_string = token_hash.decode("utf-8")

await db.insert_access_token(
engine=engine,
token=token_hash_string,
group_id=group_token_request.group_id,
expiration=datetime.fromtimestamp(group_token_request.expiration)
expiration=datetime.fromtimestamp(group_token_request.expiration),
)

return AccessToken(group=group_token_request.group_id, token=token)
Expand Down

0 comments on commit ea50536

Please sign in to comment.