Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement new auth flow #138

Merged
merged 5 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion infrastructure/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,10 @@ resource "aws_iam_role_policy_attachment" "lambda_basic_execution" {
policy_arn = "arn:aws:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole"
}


# Function to build the JWKS URL
locals {
build_jwks_url = "${format("https://cognito-idp.%s.amazonaws.com/%s/.well-known/jwks.json", local.aws_region, var.userpool_id)}"
}
resource "aws_lambda_function" "workflows_api_handler" {
function_name = "${var.prefix}_workflows_api_handler"
role = aws_iam_role.lambda_execution_role.arn
Expand All @@ -213,6 +216,9 @@ resource "aws_lambda_function" "workflows_api_handler" {
RASTER_URL = var.raster_url
STAC_URL = var.stac_url
MWAA_ENV = "${var.prefix}-mwaa"
COGNITO_DOMAIN = var.cognito_domain
CLIENT_ID = var.client_id
JWKS_URL = local.build_jwks_url
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions infrastructure/terraform.tfvars.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,6 @@ vector_security_group="${VECTOR_SECURITY_GROUP}"
vector_vpc="${VECTOR_VPC:-null}"
workflow_root_path="${WORKFLOW_ROOT_PATH}"
cloudfront_id="${CLOUDFRONT_ID}"
cognito_domain="${VEDA_COGNITO_DOMAIN}"
client_id="${VEDA_CLIENT_ID}"
userpool_id="${VEDA_USERPOOL_ID}"
11 changes: 11 additions & 0 deletions infrastructure/variables.tf
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,14 @@ variable "cloudfront_id" {
type = string
}

variable "cognito_domain" {
type = string
}

variable "client_id" {
type = string
}

variable "userpool_id" {
type = string
}
1 change: 1 addition & 0 deletions workflows_api/runtime/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ orjson>=3.6.8
psycopg[binary,pool]>=3.0.15
pydantic_ssm_settings>=0.2.0
pydantic>=1.9.0
pyjwt>=2.8.0
pypgstac==0.7.4
python-multipart==0.0.5
requests>=2.27.1
Expand Down
101 changes: 45 additions & 56 deletions workflows_api/runtime/src/auth.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
import logging

import requests
import jwt
import src.config as config
from authlib.jose import JsonWebKey, JsonWebToken, JWTClaims, KeySet, errors
from cachetools import TTLCache, cached

from fastapi import Depends, HTTPException, security
from fastapi import Depends, HTTPException, Security, security, status
from typing import Annotated, Any, Dict


logger = logging.getLogger(__name__)

token_scheme = security.HTTPBearer()
settings = config.Settings()

oauth2_scheme = security.OAuth2AuthorizationCodeBearer(
authorizationUrl=settings.cognito_authorization_url,
tokenUrl=settings.cognito_token_url,
refreshUrl=settings.cognito_token_url,
)

jwks_client = jwt.PyJWKClient(settings.jwks_url)


def get_settings() -> config.Settings:
Expand All @@ -18,56 +26,37 @@ def get_settings() -> config.Settings:
return main.settings


def get_jwks_url(settings: config.Settings = Depends(get_settings)) -> str:
import boto3
import json
client = boto3.client("secretsmanager")
response = client.get_secret_value(SecretId=settings.workflows_client_secret_id)
secrets = json.loads(response["SecretString"])

return f"https://cognito-idp.{secrets['aws_region']}.amazonaws.com/{secrets['userpool_id']}/.well-known/jwks.json"


@cached(TTLCache(maxsize=1, ttl=3600))
def get_jwks(jwks_url: str = Depends(get_jwks_url)) -> KeySet:
with requests.get(jwks_url) as response:
response.raise_for_status()
return JsonWebKey.import_key_set(response.json())


def decode_token(
token: security.HTTPAuthorizationCredentials = Depends(token_scheme),
jwks: KeySet = Depends(get_jwks),
) -> JWTClaims:
"""
Validate & decode JWT
"""
def validated_token(
token_str: Annotated[str, Security(oauth2_scheme)],
required_scopes: security.SecurityScopes,
):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you include type hint?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

# Parse & validate token
try:
claims = JsonWebToken(["RS256"]).decode(
s=token.credentials,
key=jwks,
claims_options={
# # Example of validating audience to match expected value
# "aud": {"essential": True, "values": [APP_CLIENT_ID]}
},
token = jwt.decode(
token_str,
jwks_client.get_signing_key_from_jwt(token_str).key,
algorithms=["RS256"],
)

if "client_id" in claims:
# Insert Cognito's `client_id` into `aud` claim if `aud` claim is unset
claims.setdefault("aud", claims["client_id"])

claims.validate()
return claims
except errors.JoseError: #
logger.exception("Unable to decode token")
raise HTTPException(status_code=403, detail="Bad auth token")


def get_username(claims: security.HTTPBasicCredentials = Depends(decode_token)):
return claims["sub"]


def get_token(
token: security.HTTPAuthorizationCredentials = Depends(token_scheme),
):
return token.credentials
except jwt.exceptions.InvalidTokenError as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
) from e

# Validate scopes (if required)
for scope in required_scopes.scopes:
if scope not in token["scope"]:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not enough permissions",
headers={
"WWW-Authenticate": f'Bearer scope="{required_scopes.scope_str}"'
},
)

return token


def get_username(token: Annotated[Dict[Any, Any], Depends(validated_token)]):
botanical marked this conversation as resolved.
Show resolved Hide resolved
return token["username"]
24 changes: 22 additions & 2 deletions workflows_api/runtime/src/config.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,41 @@
from pydantic import BaseSettings, Field, constr
from pydantic import AnyHttpUrl, BaseSettings, Field, constr
from typing import Optional

AwsArn = constr(regex=r"^arn:aws:iam::\d{12}:role/.+")
AwsStepArn = constr(regex=r"^arn:aws:states:.+:\d{12}:stateMachine:.+")


class Settings(BaseSettings):
cognito_domain: AnyHttpUrl = Field(
description="The base url of the Cognito domain for authorization and token urls"
)
client_id: str = Field(description="The Cognito APP client ID")
data_access_role_arn: AwsArn = Field( # type: ignore
description="ARN of AWS Role used to validate access to S3 data"
)
jwks_url: Optional[AnyHttpUrl] = Field(
description="URL of JWKS, e.g. https://cognito-idp.{region}.amazonaws.com/{userpool_id}/.well-known/jwks.json"
)

workflows_client_secret_id: str = Field(description="The Cognito APP Secret that contains cognito creds")
workflows_client_secret_id: str = Field(
description="The Cognito APP Secret that contains cognito creds"
)
stage: str = Field(description="API stage")
workflow_root_path: str = Field(description="Root path of API")
ingest_url: str = Field(description="URL of ingest API")
raster_url: str = Field(description="URL of raster API")
stac_url: str = Field(description="URL of STAC API")
mwaa_env: str = Field(description="MWAA URL")

@property
def cognito_authorization_url(self) -> AnyHttpUrl:
"""Cognito user pool authorization url"""
return f"{self.cognito_domain}/oauth2/authorize"

@property
def cognito_token_url(self) -> AnyHttpUrl:
"""Cognito user pool token and refresh url"""
return f"{self.cognito_domain}/oauth2/token"

class Config:
env_file = ".env"
30 changes: 23 additions & 7 deletions workflows_api/runtime/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@
root_path=settings.workflow_root_path,
openapi_url="/openapi.json",
docs_url="/docs",
swagger_ui_init_oauth={
"appName": "Cognito",
"clientId": settings.client_id,
"usePkceWithAuthorizationCodeGrant": True,
},
router = APIRouter(route_class=LoggerRouteHandler)
)

Expand All @@ -42,7 +47,7 @@
@workflows_app.post(
"/dataset/validate",
tags=["Dataset"],
dependencies=[Depends(auth.get_username)],
dependencies=[Depends(auth.validated_token)],
)
def validate_dataset(dataset: schemas.COGDataset):
# for all sample files in dataset, test access using raster /validate endpoint
Expand All @@ -69,7 +74,7 @@ def validate_dataset(dataset: schemas.COGDataset):
"/dataset/publish", tags=["Dataset"], dependencies=[Depends(auth.get_username)]
)
async def publish_dataset(
token=Depends(auth.get_token),
token=Depends(auth.oauth2_scheme),
dataset: Union[schemas.ZarrDataset, schemas.COGDataset] = Body(
..., discriminator="data_type"
),
Expand Down Expand Up @@ -102,10 +107,12 @@ async def publish_dataset(
response_model=schemas.WorkflowExecutionResponse,
tags=["Workflow-Executions"],
status_code=201,
dependencies=[Depends(auth.get_username)],
dependencies=[Depends(auth.validated_token)],
)
async def start_discovery_workflow_execution(
input: Union[schemas.S3Input, schemas.CmrInput]=Body(..., discriminator="discovery"),
input: Union[schemas.S3Input, schemas.CmrInput] = Body(
..., discriminator="discovery"
),
) -> schemas.WorkflowExecutionResponse:
"""
Triggers the ingestion workflow
Expand All @@ -117,7 +124,7 @@ async def start_discovery_workflow_execution(
"/discovery-executions/{workflow_execution_id}",
response_model=Union[schemas.ExecutionResponse, schemas.WorkflowExecutionResponse],
tags=["Workflow-Executions"],
dependencies=[Depends(auth.get_username)],
dependencies=[Depends(auth.validated_token)],
)
async def get_discovery_workflow_execution_status(
workflow_execution_id: str,
Expand All @@ -131,7 +138,7 @@ async def get_discovery_workflow_execution_status(
@workflows_app.get(
"/list-workflows",
tags=["Workflow-Executions"],
dependencies=[Depends(auth.get_username)],
dependencies=[Depends(auth.validated_token)],
)
async def get_workflow_list() -> (
Union[schemas.ExecutionResponse, schemas.WorkflowExecutionResponse]
Expand All @@ -145,7 +152,7 @@ async def get_workflow_list() -> (
@workflows_app.post(
"/cli-input",
tags=["Admin"],
dependencies=[Depends(auth.get_username)],
dependencies=[Depends(auth.validated_token)],
)
async def send_cli_command(cli_command: str):
return airflow_helpers.send_cli_command(cli_command)
Expand Down Expand Up @@ -178,6 +185,14 @@ async def add_correlation_id(request: Request, call_next):
logger.info("Request completed")
return response

@workflows_app.get("/auth/me", tags=["Auth"])
def who_am_i(claims=Depends(auth.validated_token)):
"""
Return claims for the provided JWT
"""
print(f"\n CLAIMS {claims}")
return claims

# exception handling
@workflows_app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
Expand All @@ -190,3 +205,4 @@ async def general_exception_handler(request, err):
metrics.add_metric(name="UnhandledExceptions", unit=MetricUnit.Count, value=1)
logger.exception(f"Unhandled exception: {err}")
return JSONResponse(status_code=500, content={"detail": "Internal Server Error"})

3 changes: 2 additions & 1 deletion workflows_api/runtime/src/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def collection_exists(collection_id: str) -> bool:
f"{response.status_code} response code from STAC API"
)


def time_density_is_valid(is_periodic: bool, time_density: Union[str, None]):
"""
Ensures that the time_density is valid based on the value of is_periodic
Expand All @@ -101,4 +102,4 @@ def time_density_is_valid(is_periodic: bool, time_density: Union[str, None]):

# Literal[str, None] doesn't quite work for null field inputs from a dict()
if time_density and time_density not in ["day", "month", "year"]:
raise ValueError("If set, time_density must be one of 'day, 'month' or 'year'")
raise ValueError("If set, time_density must be one of 'day, 'month' or 'year'")