Skip to content

Commit

Permalink
feat: feeds operations API function (#838)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidgamez authored Dec 6, 2024
1 parent 868662b commit dc5700c
Show file tree
Hide file tree
Showing 62 changed files with 3,111 additions and 61 deletions.
12 changes: 11 additions & 1 deletion .github/workflows/api-deployer.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ on:
description: Validator endpoint
required: true
type: string
OPERATIONS_OAUTH2_CLIENT_ID_1PASSWORD:
description: Oauth client id part of the authorization for the operations API
required: true
type: string

env:
python_version: '3.11'
Expand Down Expand Up @@ -255,6 +259,11 @@ jobs:
name: feeds_gen
path: api/src/feeds_gen/

- uses: actions/download-artifact@v4
with:
name: feeds_operations_gen
path: functions-python/operations_api/src/feeds_operations_gen/

- name: Build python functions
run: |
scripts/function-python-build.sh --all
Expand Down Expand Up @@ -290,11 +299,12 @@ jobs:
env:
OP_SERVICE_ACCOUNT_TOKEN: ${{ secrets.OP_SERVICE_ACCOUNT_TOKEN }}
TRANSITLAND_API_KEY: "op://rbiv7rvkkrsdlpcrz3bmv7nmcu/TansitLand API Key/credential"
OPERATIONS_OAUTH2_CLIENT_ID: ${{ inputs.OPERATIONS_OAUTH2_CLIENT_ID_1PASSWORD }}

- name: Populate Variables
run: |
scripts/replace-variables.sh -in_file infra/backend.conf.rename_me -out_file infra/backend.conf -variables BUCKET_NAME,OBJECT_PREFIX
scripts/replace-variables.sh -in_file infra/vars.tfvars.rename_me -out_file infra/vars.tfvars -variables PROJECT_ID,REGION,ENVIRONMENT,DEPLOYER_SERVICE_ACCOUNT,FEED_API_IMAGE_VERSION,OAUTH2_CLIENT_ID,OAUTH2_CLIENT_SECRET,GLOBAL_RATE_LIMIT_REQ_PER_MINUTE,ARTIFACT_REPO_NAME,VALIDATOR_ENDPOINT,TRANSITLAND_API_KEY
scripts/replace-variables.sh -in_file infra/vars.tfvars.rename_me -out_file infra/vars.tfvars -variables PROJECT_ID,REGION,ENVIRONMENT,DEPLOYER_SERVICE_ACCOUNT,FEED_API_IMAGE_VERSION,OAUTH2_CLIENT_ID,OAUTH2_CLIENT_SECRET,GLOBAL_RATE_LIMIT_REQ_PER_MINUTE,ARTIFACT_REPO_NAME,VALIDATOR_ENDPOINT,TRANSITLAND_API_KEY,OPERATIONS_OAUTH2_CLIENT_ID
- uses: hashicorp/setup-terraform@v3
with:
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/api-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ jobs:
GLOBAL_RATE_LIMIT_REQ_PER_MINUTE: ${{ vars.GLOBAL_RATE_LIMIT_REQ_PER_MINUTE }}
TF_APPLY: true
VALIDATOR_ENDPOINT: https://stg-gtfs-validator-web-mbzoxaljzq-ue.a.run.app
OPERATIONS_OAUTH2_CLIENT_ID_1PASSWORD: "op://rbiv7rvkkrsdlpcrz3bmv7nmcu/GCP_RETOOL_OAUTH2_CREDS/username"
secrets:
GCP_MOBILITY_FEEDS_SA_KEY: ${{ secrets.DEV_GCP_MOBILITY_FEEDS_SA_KEY }}
OAUTH2_CLIENT_ID: ${{ secrets.DEV_MOBILITY_FEEDS_OAUTH2_CLIENT_ID}}
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/api-prod.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ jobs:
GLOBAL_RATE_LIMIT_REQ_PER_MINUTE: ${{ vars.GLOBAL_RATE_LIMIT_REQ_PER_MINUTE }}
TF_APPLY: true
VALIDATOR_ENDPOINT: https://gtfs-validator-web-mbzoxaljzq-ue.a.run.app
OPERATIONS_OAUTH2_CLIENT_ID_1PASSWORD: "op://rbiv7rvkkrsdlpcrz3bmv7nmcu/GCP_RETOOL_OAUTH2_CREDS/username"
secrets:
GCP_MOBILITY_FEEDS_SA_KEY: ${{ secrets.PROD_GCP_MOBILITY_FEEDS_SA_KEY }}
OAUTH2_CLIENT_ID: ${{ secrets.PROD_MOBILITY_FEEDS_OAUTH2_CLIENT_ID}}
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/api-qa.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ jobs:
TF_APPLY: true
GLOBAL_RATE_LIMIT_REQ_PER_MINUTE: ${{ vars.GLOBAL_RATE_LIMIT_REQ_PER_MINUTE }}
VALIDATOR_ENDPOINT: https://stg-gtfs-validator-web-mbzoxaljzq-ue.a.run.app
OPERATIONS_OAUTH2_CLIENT_ID_1PASSWORD: "op://rbiv7rvkkrsdlpcrz3bmv7nmcu/GCP_RETOOL_OAUTH2_CREDS/username"
secrets:
GCP_MOBILITY_FEEDS_SA_KEY: ${{ secrets.QA_GCP_MOBILITY_FEEDS_SA_KEY }}
OAUTH2_CLIENT_ID: ${{ secrets.DEV_MOBILITY_FEEDS_OAUTH2_CLIENT_ID}}
Expand Down
13 changes: 12 additions & 1 deletion .github/workflows/build-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ jobs:
scripts/setup-openapi-generator.sh
scripts/api-gen.sh
- name: Generate Operations API code
run: |
scripts/api-operations-gen.sh
- name: Unit tests - API
shell: bash
run: |
Expand All @@ -104,9 +108,16 @@ jobs:
path: api/src/database_gen/
overwrite: true

- name: API generated code
- name: Upload API generated code
uses: actions/upload-artifact@v4
with:
name: feeds_gen
path: api/src/feeds_gen/
overwrite: true

- name: Upload Operations API generated code
uses: actions/upload-artifact@v4
with:
name: feeds_operations_gen
path: functions-python/operations_api/src/feeds_operations_gen/
overwrite: true
17 changes: 14 additions & 3 deletions api/src/feeds/impl/feeds_api_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
LocationTranslation,
get_feeds_location_translations,
)
from utils.logger import Logger

T = TypeVar("T", bound="BasicFeed")

Expand All @@ -59,11 +60,17 @@ class FeedsApiImpl(BaseFeedsApi):

APIFeedType = Union[BasicFeed, GtfsFeed, GtfsRTFeed]

def __init__(self) -> None:
self.logger = Logger("FeedsApiImpl").get_logger()

def get_feed(
self,
id: str,
) -> BasicFeed:
"""Get the specified feed from the Mobility Database."""
is_email_restricted = is_user_email_restricted()
self.logger.info(f"User email is restricted: {is_email_restricted}")

feed = (
FeedFilter(stable_id=id, provider__ilike=None, producer_url__ilike=None, status=None)
.filter(Database().get_query_model(Feed))
Expand All @@ -72,7 +79,7 @@ def get_feed(
or_(
Feed.operational_status == None, # noqa: E711
Feed.operational_status != "wip",
not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted
not is_email_restricted, # Allow all feeds to be returned if the user is not restricted
)
)
.first()
Expand All @@ -91,6 +98,8 @@ def get_feeds(
producer_url: str,
) -> List[BasicFeed]:
"""Get some (or all) feeds from the Mobility Database."""
is_email_restricted = is_user_email_restricted()
self.logger.info(f"User email is restricted: {is_email_restricted}")
feed_filter = FeedFilter(
status=status, provider__ilike=provider, producer_url__ilike=producer_url, stable_id=None
)
Expand All @@ -100,7 +109,7 @@ def get_feeds(
or_(
Feed.operational_status == None, # noqa: E711
Feed.operational_status != "wip",
not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted
not is_email_restricted, # Allow all feeds to be returned if the user is not restricted
)
)
# Results are sorted by provider
Expand Down Expand Up @@ -239,6 +248,8 @@ def get_gtfs_feeds(
subquery, dataset_latitudes, dataset_longitudes, bounding_filter_method
).subquery()

is_email_restricted = is_user_email_restricted()
self.logger.info(f"User email is restricted: {is_email_restricted}")
feed_query = (
Database()
.get_session()
Expand All @@ -248,7 +259,7 @@ def get_gtfs_feeds(
or_(
Gtfsfeed.operational_status == None, # noqa: E711
Gtfsfeed.operational_status != "wip",
not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted
not is_email_restricted, # Allow all feeds to be returned if the user is not restricted
)
)
.options(
Expand Down
2 changes: 1 addition & 1 deletion api/src/feeds/impl/search_api_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def add_search_query_filters(query, search_query, data_type, feed_id, status) ->
or_(
t_feedsearch.c.operational_status == None, # noqa: E711
t_feedsearch.c.operational_status != "wip",
is_user_email_restricted(),
not is_user_email_restricted(),
)
)
if feed_id:
Expand Down
13 changes: 8 additions & 5 deletions api/src/middleware/request_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,10 @@ def _extract_from_headers(self, headers: dict, scope: Scope) -> None:
def __repr__(self) -> str:
# Omitting sensitive data like email and jwt assertion
safe_properties = dict(
user_id=self.user_id, client_user_agent=self.client_user_agent, client_host=self.client_host
user_id=self.user_id,
client_user_agent=self.client_user_agent,
client_host=self.client_host,
email=self.user_email,
)
return f"request-context={safe_properties})"

Expand All @@ -108,8 +111,8 @@ def is_user_email_restricted() -> bool:
Check if an email's domain is restricted (e.g., for WIP visibility).
"""
request_context = get_request_context()
if not isinstance(request_context, RequestContext):
return True # Default to restricted
email = get_request_context().user_email
unrestricted_domains = ["@mobilitydata.org"]
if not request_context:
return True
email = request_context["user_email"]
unrestricted_domains = ["mobilitydata.org"]
return not email or not any(email.endswith(f"@{domain}") for domain in unrestricted_domains)
44 changes: 1 addition & 43 deletions api/tests/unittest/middleware/test_request_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from starlette.datastructures import Headers

from middleware.request_context import RequestContext, get_request_context, _request_context, is_user_email_restricted
from middleware.request_context import RequestContext, get_request_context, _request_context


class TestRequestContext(unittest.TestCase):
Expand Down Expand Up @@ -54,45 +54,3 @@ def test_get_request_context(self):
request_context = RequestContext(MagicMock())
_request_context.set(request_context)
self.assertEqual(request_context, get_request_context())

def test_is_user_email_restricted(self):
self.assertTrue(is_user_email_restricted())
scope_instance = {
"type": "http",
"asgi": {"version": "3.0"},
"http_version": "1.1",
"method": "GET",
"headers": [
(b"host", b"localhost"),
(b"x-forwarded-proto", b"https"),
(b"x-forwarded-for", b"client, proxy1"),
(b"server", b"server"),
(b"user-agent", b"user-agent"),
(b"x-goog-iap-jwt-assertion", b"jwt"),
(b"x-cloud-trace-context", b"TRACE_ID/SPAN_ID;o=1"),
(b"x-goog-authenticated-user-id", b"user_id"),
(b"x-goog-authenticated-user-email", b"email"),
],
"path": "/",
"raw_path": b"/",
"query_string": b"",
"client": ("127.0.0.1", 32767),
"server": ("127.0.0.1", 80),
}
request_context = RequestContext(scope=scope_instance)
_request_context.set(request_context)
self.assertTrue(is_user_email_restricted())
scope_instance["headers"] = [
(b"host", b"localhost"),
(b"x-forwarded-proto", b"https"),
(b"x-forwarded-for", b"client, proxy1"),
(b"server", b"server"),
(b"user-agent", b"user-agent"),
(b"x-goog-iap-jwt-assertion", b"jwt"),
(b"x-cloud-trace-context", b"TRACE_ID/SPAN_ID;o=1"),
(b"x-goog-authenticated-user-id", b"user_id"),
(b"x-goog-authenticated-user-email", b"[email protected]"),
]
request_context = RequestContext(scope=scope_instance)
_request_context.set(request_context)
self.assertTrue(is_user_email_restricted())
Loading

0 comments on commit dc5700c

Please sign in to comment.