From d628ee1187d64e69865b9f74299cadfbd4e06473 Mon Sep 17 00:00:00 2001 From: Matthias Veit Date: Tue, 17 Oct 2023 12:47:17 +0200 Subject: [PATCH 1/3] [feat] Handle AWS marketplace fulfillment --- fixbackend/all_models.py | 1 + fixbackend/app.py | 19 +- fixbackend/auth/depedencies.py | 4 +- fixbackend/base_model.py | 19 +- fixbackend/config.py | 3 + fixbackend/dependencies.py | 23 +- fixbackend/ids.py | 1 + fixbackend/sqs.py | 10 +- fixbackend/subscription/__init__.py | 21 ++ fixbackend/subscription/aws_marketplace.py | 121 +++++++++++ fixbackend/subscription/models.py | 41 ++++ fixbackend/subscription/router.py | 67 ++++++ .../subscription/subscription_repository.py | 94 ++++++++ fixbackend/workspaces/repository.py | 201 +++++++++--------- ...-10-17T09:03:56Z_session_entity_genesis.py | 57 +++++ static/openapi-events.yaml | 4 +- tests/fixbackend/conftest.py | 57 ++++- .../subscription/aws_marketplace_test.py | 41 ++++ .../subscription_repository_test.py | 50 +++++ 19 files changed, 706 insertions(+), 128 deletions(-) create mode 100644 fixbackend/subscription/__init__.py create mode 100644 fixbackend/subscription/aws_marketplace.py create mode 100644 fixbackend/subscription/models.py create mode 100644 fixbackend/subscription/router.py create mode 100644 fixbackend/subscription/subscription_repository.py create mode 100644 migrations/versions/2023-10-17T09:03:56Z_session_entity_genesis.py create mode 100644 tests/fixbackend/subscription/aws_marketplace_test.py create mode 100644 tests/fixbackend/subscription/subscription_repository_test.py diff --git a/fixbackend/all_models.py b/fixbackend/all_models.py index a086770a..11051f86 100644 --- a/fixbackend/all_models.py +++ b/fixbackend/all_models.py @@ -25,3 +25,4 @@ from fixbackend.dispatcher.next_run_repository import NextRun # noqa from fixbackend.metering.metering_repository import MeteringRecordEntity # noqa from fixbackend.keyvalue.json_kv import JsonEntry # noqa +from fixbackend.subscription.subscription_repository import SubscriptionEntity # noqa diff --git a/fixbackend/app.py b/fixbackend/app.py index 4041bcbe..604d9aee 100644 --- a/fixbackend/app.py +++ b/fixbackend/app.py @@ -19,6 +19,7 @@ from ssl import Purpose, create_default_context from typing import Any, AsyncIterator, ClassVar, Optional, Set, Tuple, cast +import boto3 import httpx from arq import create_pool from arq.connections import RedisSettings @@ -57,6 +58,10 @@ from fixbackend.inventory.inventory_service import InventoryService from fixbackend.inventory.router import inventory_router from fixbackend.metering.metering_repository import MeteringRepository +from fixbackend.subscription.aws_marketplace import AwsMarketplaceHandler +from fixbackend.subscription.router import subscription_router +from fixbackend.subscription.subscription_repository import SubscriptionRepository +from fixbackend.workspaces.repository import WorkspaceRepositoryImpl from fixbackend.workspaces.router import workspaces_router log = logging.getLogger(__name__) @@ -67,6 +72,7 @@ def fast_api_app(cfg: Config) -> FastAPI: google = google_client(cfg) github = github_client(cfg) + boto_session = boto3.Session(cfg.aws_access_key_id, cfg.aws_secret_access_key) deps = FixDependencies() ca_cert_path = str(cfg.ca_cert) if cfg.ca_cert else None client_context = create_default_context(purpose=Purpose.SERVER_AUTH) @@ -109,13 +115,21 @@ async def setup_teardown_application(app: FastAPI) -> AsyncIterator[None]: deps.add(SN.next_run_repo, NextRunRepository(session_maker)) deps.add(SN.metering_repo, MeteringRepository(session_maker)) deps.add(SN.collect_queue, RedisCollectQueue(arq_redis)) - deps.add(SN.graph_db_access, GraphDatabaseAccessManager(cfg, session_maker)) + graph_db_access = deps.add(SN.graph_db_access, GraphDatabaseAccessManager(cfg, session_maker)) inventory_client = deps.add(SN.inventory_client, InventoryClient(cfg.inventory_url, http_client)) deps.add(SN.inventory, InventoryService(inventory_client)) deps.add( SN.cloudaccount_publisher, RedisStreamPublisher(readwrite_redis, "fixbackend::cloudaccount", f"fixbackend-{cfg.instance_id}"), ) + workspace_repo = deps.add(SN.workspace_repo, WorkspaceRepositoryImpl(session_maker, graph_db_access)) + subscription_repo = deps.add(SN.subscription_repo, SubscriptionRepository(session_maker)) + deps.add( + SN.aws_marketplace_handler, + AwsMarketplaceHandler( + subscription_repo, workspace_repo, boto_session, cfg.args.aws_marketplace_metering_sqs_url + ), + ) domain_events_stream_name = "fixbackend::domain_events" domain_event_redis_publisher = deps.add( @@ -268,8 +282,9 @@ async def domain_events_swagger_ui_html(req: Request) -> HTMLResponse: api_router.include_router(cloud_accounts_callback_router(), prefix="/cloud", tags=["cloud_accounts"]) api_router.include_router(users_router(), prefix="/users", tags=["users"]) - app.include_router(api_router) + api_router.include_router(subscription_router(deps)) + app.include_router(api_router) app.mount("/static", StaticFiles(directory="static"), name="static") if cfg.static_assets: diff --git a/fixbackend/auth/depedencies.py b/fixbackend/auth/depedencies.py index e03f8f39..45570614 100644 --- a/fixbackend/auth/depedencies.py +++ b/fixbackend/auth/depedencies.py @@ -19,7 +19,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Annotated +from typing import Annotated, Optional from uuid import UUID from fastapi import Depends @@ -36,6 +36,8 @@ # the value below is a dependency itself get_current_active_verified_user = fastapi_users.current_user(active=True, verified=True) +maybe_current_active_verified_user = fastapi_users.current_user(active=True, verified=True, optional=True) AuthenticatedUser = Annotated[User, Depends(get_current_active_verified_user)] +OptionalAuthenticatedUser = Annotated[Optional[User], Depends(maybe_current_active_verified_user)] diff --git a/fixbackend/base_model.py b/fixbackend/base_model.py index 8eb2a63f..ae0184de 100644 --- a/fixbackend/base_model.py +++ b/fixbackend/base_model.py @@ -11,8 +11,12 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from datetime import datetime -from sqlalchemy.orm import DeclarativeBase +from sqlalchemy import func, text, DefaultClause +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +from fixbackend.sqlalechemy_extensions import UTCDateTime class Base(DeclarativeBase): @@ -21,3 +25,16 @@ class Base(DeclarativeBase): All model classes should inherit from this class. """ + + +class CreatedUpdatedMixin: + """ + Mixin to always have created_at and updated_at columns in a model. + """ + + created_at: Mapped[datetime] = mapped_column(UTCDateTime, server_default=func.now()) + updated_at: Mapped[datetime] = mapped_column( + UTCDateTime, + server_default=func.now(), + server_onupdate=DefaultClause(text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")), + ) diff --git a/fixbackend/config.py b/fixbackend/config.py index 86b0b976..5e0f627a 100644 --- a/fixbackend/config.py +++ b/fixbackend/config.py @@ -115,6 +115,9 @@ def parse_args(argv: Optional[Sequence[str]] = None) -> Namespace: parser.add_argument("--aws-access-key-id", default=os.environ.get("AWS_ACCESS_KEY_ID", "")) parser.add_argument("--aws-secret-access-key", default=os.environ.get("AWS_SECRET_ACCESS_KEY", "")) parser.add_argument("--aws-region", default=os.environ.get("AWS_REGION", "us-east-1")) + parser.add_argument( + "--aws-marketplace-metering-sqs-url", default=os.environ.get("AWS_MARKETPLACE_METERING_SQS_URL") + ) parser.add_argument("--ca-cert", type=Path, default=os.environ.get("CA_CERT")) parser.add_argument("--host-cert", type=Path, default=os.environ.get("HOST_CERT")) parser.add_argument("--host-key", type=Path, default=os.environ.get("HOST_KEY")) diff --git a/fixbackend/dependencies.py b/fixbackend/dependencies.py index d25b9bbc..c0d7b909 100644 --- a/fixbackend/dependencies.py +++ b/fixbackend/dependencies.py @@ -11,7 +11,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Annotated +from typing import Annotated, cast from arq import ArqRedis from fastapi.params import Depends @@ -20,14 +20,12 @@ from redis.asyncio import Redis from sqlalchemy.ext.asyncio import AsyncEngine -from fixbackend.collect.collect_queue import RedisCollectQueue -from fixbackend.graph_db.service import GraphDatabaseAccessManager -from fixbackend.inventory.inventory_client import InventoryClient -from fixbackend.inventory.inventory_service import InventoryService -from fixbackend.types import AsyncSessionMaker from fixbackend.certificates.cert_store import CertificateStore from fixbackend.domain_events.sender import DomainEventSender from fixbackend.domain_events.sender_impl import DomainEventSenderImpl +from fixbackend.graph_db.service import GraphDatabaseAccessManager +from fixbackend.inventory.inventory_service import InventoryService +from fixbackend.types import AsyncSessionMaker class ServiceNames: @@ -50,6 +48,9 @@ class ServiceNames: domain_event_redis_stream_publisher = "domain_event_redis_stream_publisher" domain_event_sender = "domain_event_sender" customerio_consumer = "customerio_consumer" + aws_marketplace_handler = "aws_marketplace_handler" + workspace_repo = "workspace_repo" + subscription_repo = "subscription_repo" class FixDependencies(Dependencies): @@ -57,26 +58,18 @@ class FixDependencies(Dependencies): def arq_redis(self) -> ArqRedis: return self.service(ServiceNames.arq_redis, ArqRedis) - @property - def collect_queue(self) -> RedisCollectQueue: - return self.service(ServiceNames.collect_queue, RedisCollectQueue) - @property def async_engine(self) -> AsyncEngine: return self.service(ServiceNames.async_engine, AsyncEngine) @property def session_maker(self) -> AsyncSessionMaker: - return self.service(ServiceNames.async_engine, AsyncSessionMaker) # type: ignore + return cast(AsyncSessionMaker, self.lookup[ServiceNames.session_maker]) @property def inventory(self) -> InventoryService: return self.service(ServiceNames.inventory, InventoryService) - @property - def inventory_client(self) -> InventoryClient: - return self.service(ServiceNames.inventory, InventoryClient) - @property def readonly_redis(self) -> Redis: return self.service(ServiceNames.readonly_redis, Redis) diff --git a/fixbackend/ids.py b/fixbackend/ids.py index d467c0e6..fea3e2f6 100644 --- a/fixbackend/ids.py +++ b/fixbackend/ids.py @@ -6,3 +6,4 @@ UserId = NewType("UserId", UUID) CloudAccountId = NewType("CloudAccountId", UUID) ExternalId = NewType("ExternalId", UUID) +PaymentMethodId = NewType("PaymentMethodId", UUID) diff --git a/fixbackend/sqs.py b/fixbackend/sqs.py index 54a31baf..ef7ff2b5 100644 --- a/fixbackend/sqs.py +++ b/fixbackend/sqs.py @@ -45,6 +45,7 @@ class SQSRawListener(Service): def __init__( self, + session: boto3.Session, queue_url: str, message_processor: Callable[[Json], Awaitable[Any]], *, @@ -62,7 +63,7 @@ def __init__( :param wait_for_new_messages_to_arrive: The time to wait for new messages to arrive. :param backoff: The backoff strategy to use when processing messages. """ - self.sqs = boto3.client("sqs") + self.sqs = session.client("sqs") self.queue_url = queue_url self.message_processor = message_processor self.consider_failed_after = consider_failed_after.total_seconds() if consider_failed_after else 30 @@ -111,11 +112,12 @@ class SQSListener(SQSRawListener): def __init__( self, + session: boto3.Session, queue_url: str, message_processor: Callable[[Json, MessageContext], Awaitable[Any]], **kwargs: Any, ) -> None: - super().__init__(queue_url, self.__context_handler(message_processor), **kwargs) + super().__init__(session, queue_url, self.__context_handler(message_processor), **kwargs) @staticmethod def __context_handler(fn: Callable[[Json, MessageContext], Awaitable[Any]]) -> Callable[[Json], Awaitable[Any]]: @@ -134,9 +136,9 @@ async def handler(message: Json) -> Any: class SQSPublisher(Service): - def __init__(self, publisher_name: str, queue_url: str) -> None: + def __init__(self, session: boto3.Session, publisher_name: str, queue_url: str) -> None: self.publisher_name = publisher_name - self.sqs = boto3.client("sqs") + self.sqs = session.client("sqs") self.queue_url = queue_url async def publish(self, kind: str, message: Json) -> None: diff --git a/fixbackend/subscription/__init__.py b/fixbackend/subscription/__init__.py new file mode 100644 index 00000000..e4ca030b --- /dev/null +++ b/fixbackend/subscription/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2023. Some Engineering +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . diff --git a/fixbackend/subscription/aws_marketplace.py b/fixbackend/subscription/aws_marketplace.py new file mode 100644 index 00000000..b8e55a4d --- /dev/null +++ b/fixbackend/subscription/aws_marketplace.py @@ -0,0 +1,121 @@ +# Copyright (c) 2023. Some Engineering +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +import logging +from datetime import timedelta +from typing import Optional +from uuid import uuid4 + +import boto3 +from fixcloudutils.service import Service +from fixcloudutils.types import Json + +from fixbackend.auth.models import User +from fixbackend.ids import PaymentMethodId +from fixbackend.sqs import SQSRawListener +from fixbackend.subscription.models import AwsMarketplaceSubscription, SubscriptionMethod +from fixbackend.subscription.subscription_repository import ( + SubscriptionRepository, +) +from fixbackend.workspaces.repository import WorkspaceRepository + +log = logging.getLogger(__name__) + + +class AwsMarketplaceHandler(Service): + def __init__( + self, + subscription_repo: SubscriptionRepository, + workspace_repo: WorkspaceRepository, + session: boto3.Session, + sqs_queue_url: Optional[str], + ) -> None: + self.aws_marketplace_repo = subscription_repo + self.workspace_repo = workspace_repo + self.listener = ( + SQSRawListener( + session, + sqs_queue_url, + self.handle_message, + consider_failed_after=timedelta(minutes=5), + max_nr_of_messages_in_one_batch=1, + wait_for_new_messages_to_arrive=timedelta(seconds=5), + ) + if sqs_queue_url is not None + else None + ) + self.marketplace_client = session.client("meteringmarketplace") + + async def start(self) -> None: + if self.listener is not None: + await self.listener.start() + + async def stop(self) -> None: + if self.listener is not None: + await self.listener.stop() + + async def subscribed(self, user: User, token: str) -> Optional[SubscriptionMethod]: + # Get the related data from AWS. Will throw in case of an error. + customer_data = self.marketplace_client.resolve_customer(RegistrationToken=token) + product_code = customer_data["ProductCode"] + customer_identifier = customer_data["CustomerIdentifier"] + customer_aws_account_id = customer_data["CustomerAWSAccountId"] + + # get all workspaces of the user and use the first one if it is the only one + # if more than one workspace exists, the user needs to define the workspace in a later step + workspaces = await self.workspace_repo.list_workspaces(user.id) + workspace_id = workspaces[0].id if len(workspaces) == 1 else None + + # only create a new subscription if there is no existing one + if existing := await self.aws_marketplace_repo.aws_marketplace_subscription(user.id, customer_identifier): + return existing + else: + subscription = AwsMarketplaceSubscription( + id=PaymentMethodId(uuid4()), + user_id=user.id, + workspace_id=workspace_id, + customer_identifier=customer_identifier, + customer_aws_account_id=customer_aws_account_id, + product_code=product_code, + active=True, + ) + return await self.aws_marketplace_repo.create(subscription) + + async def handle_message(self, message: Json) -> None: + # See: https://docs.aws.amazon.com/marketplace/latest/userguide/saas-notification.html + action = message["action"] + # customer_identifier = message["customer-identifier"] + # free_trial = message.get("isFreeTrialTermPresent", False) + match action: + case "subscribe-success": + # allow sending metering records + pass + case "subscribe-fail": + # wait for subscribe-success + pass + case "unsubscribe-pending": + # TODO: send metering records! + pass + case "unsubscribe-success": + # the user has unsubscribed + pass + case _: + raise ValueError(f"Unknown action: {action}") diff --git a/fixbackend/subscription/models.py b/fixbackend/subscription/models.py new file mode 100644 index 00000000..9a3d2abf --- /dev/null +++ b/fixbackend/subscription/models.py @@ -0,0 +1,41 @@ +# Copyright (c) 2023. Some Engineering +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from typing import Union, Optional + +from attr import frozen + +from fixbackend.ids import PaymentMethodId, WorkspaceId, UserId + + +@frozen +class AwsMarketplaceSubscription: + id: PaymentMethodId + user_id: Optional[UserId] + workspace_id: Optional[WorkspaceId] + customer_identifier: str + customer_aws_account_id: str + product_code: str + active: bool + + +# Multiple payment methods are possible, but for now we only support AWS Marketplace +SubscriptionMethod = Union[AwsMarketplaceSubscription] diff --git a/fixbackend/subscription/router.py b/fixbackend/subscription/router.py new file mode 100644 index 00000000..5d543d1b --- /dev/null +++ b/fixbackend/subscription/router.py @@ -0,0 +1,67 @@ +# Copyright (c) 2023. Some Engineering +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from functools import partial +from typing import Callable + +from fastapi import APIRouter, Form, Cookie, Response +from starlette.responses import RedirectResponse + +from fixbackend.auth.depedencies import OptionalAuthenticatedUser, AuthenticatedUser +from fixbackend.dependencies import FixDependencies, ServiceNames as SN +from fixbackend.subscription.aws_marketplace import AwsMarketplaceHandler + + +def subscription_router(deps: FixDependencies) -> APIRouter: + router = APIRouter() + market_place_handler: Callable[[], AwsMarketplaceHandler] = partial( # type: ignore + deps.service, SN.aws_marketplace_handler, AwsMarketplaceHandler + ) + + # Attention: Changing this route will break the AWS Marketplace integration! + @router.post("/cloud/callbacks/aws/marketplace") + async def aws_marketplace_fulfillment( + maybe_user: OptionalAuthenticatedUser, x_amzn_marketplace_token: str = Form(alias="x-amzn-marketplace-token") + ) -> Response: + if user := maybe_user: + # add marketplace subscription + await market_place_handler().subscribed(user, x_amzn_marketplace_token) + # load the app and show a message + return RedirectResponse("/?message=aws-marketplace-subscribed") + else: + response = RedirectResponse("/auth/login?returnUrl=/subscriptions/aws/marketplace/add") + response.set_cookie("fix-aws-marketplace-token", x_amzn_marketplace_token, secure=True, httponly=True) + return response + + @router.get("/subscriptions/aws/marketplace/add", response_model=None) + async def aws_marketplace_fulfillment_after_login( + user: AuthenticatedUser, fix_aws_marketplace_token: str = Cookie(None) + ) -> Response: + if fix_aws_marketplace_token is not None: + await market_place_handler().subscribed(user, fix_aws_marketplace_token) + # load the app and show a message + response = RedirectResponse("/?message=aws-marketplace-subscribed") + response.set_cookie("fix-aws-marketplace-token", "", expires=0) # delete the cookie + return response + else: + raise ValueError("No AWS token found!") + + return router diff --git a/fixbackend/subscription/subscription_repository.py b/fixbackend/subscription/subscription_repository.py new file mode 100644 index 00000000..31a4bd7a --- /dev/null +++ b/fixbackend/subscription/subscription_repository.py @@ -0,0 +1,94 @@ +# Copyright (c) 2023. Some Engineering +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from __future__ import annotations + +from typing import Optional +from uuid import UUID + +from fastapi_users_db_sqlalchemy.generics import GUID +from sqlalchemy import String, Boolean, select, Index +from sqlalchemy.orm import Mapped, mapped_column + +from fixbackend.base_model import Base, CreatedUpdatedMixin +from fixbackend.ids import WorkspaceId, UserId, PaymentMethodId +from fixbackend.subscription.models import AwsMarketplaceSubscription +from fixbackend.types import AsyncSessionMaker + + +class SubscriptionEntity(CreatedUpdatedMixin, Base): + __tablename__ = "subscriptions" + __table_args__ = (Index("idx_aws_customer_user", "aws_customer_identifier", "user_id"),) + + id: Mapped[UUID] = mapped_column(GUID, primary_key=True) + user_id: Mapped[Optional[UserId]] = mapped_column(GUID, nullable=True, index=True) + workspace_id: Mapped[Optional[WorkspaceId]] = mapped_column(GUID, nullable=True, index=True) + aws_customer_identifier: Mapped[str] = mapped_column(String(128), nullable=False) + aws_customer_account_id: Mapped[str] = mapped_column(String(128), nullable=True, default="") + aws_product_code: Mapped[str] = mapped_column(String(128), nullable=False) + active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) + + def to_model(self) -> AwsMarketplaceSubscription: + return AwsMarketplaceSubscription( + id=PaymentMethodId(self.id), + user_id=self.user_id, + workspace_id=self.workspace_id, + customer_identifier=self.aws_customer_identifier, + customer_aws_account_id=self.aws_customer_account_id, + product_code=self.aws_product_code, + active=self.active, + ) + + @staticmethod + def from_model(subscription: AwsMarketplaceSubscription) -> SubscriptionEntity: + return SubscriptionEntity( + id=subscription.id, + user_id=subscription.user_id, + workspace_id=subscription.workspace_id, + aws_customer_identifier=subscription.customer_identifier, + aws_customer_account_id=subscription.customer_aws_account_id, + aws_product_code=subscription.product_code, + active=subscription.active, + ) + + +class SubscriptionRepository: + def __init__(self, session_maker: AsyncSessionMaker) -> None: + self.session_maker = session_maker + + async def aws_marketplace_subscription( + self, user_id: UserId, customer_identifier: str + ) -> Optional[AwsMarketplaceSubscription]: + async with self.session_maker() as session: + stmt = select(SubscriptionEntity).where( + SubscriptionEntity.aws_customer_identifier == customer_identifier + and SubscriptionEntity.user_id == user_id + ) + if result := (await session.execute(stmt)).scalar_one_or_none(): + return result.to_model() + else: + return None + + async def create(self, subscription: AwsMarketplaceSubscription) -> AwsMarketplaceSubscription: + async with self.session_maker() as session: + session.add(SubscriptionEntity.from_model(subscription)) + await session.commit() + return subscription diff --git a/fixbackend/workspaces/repository.py b/fixbackend/workspaces/repository.py index a3f2092b..cf1b22f2 100644 --- a/fixbackend/workspaces/repository.py +++ b/fixbackend/workspaces/repository.py @@ -20,15 +20,14 @@ from fastapi import Depends from sqlalchemy import select from sqlalchemy.exc import IntegrityError -from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload from fixbackend.auth.models import User from fixbackend.auth.models import orm as auth_orm -from fixbackend.db import AsyncSessionDependency -from fixbackend.dependencies import FixDependency +from fixbackend.dependencies import FixDependency, ServiceNames from fixbackend.graph_db.service import GraphDatabaseAccessManager from fixbackend.ids import WorkspaceId, UserId +from fixbackend.types import AsyncSessionMaker from fixbackend.workspaces.models import Workspace, WorkspaceInvite, orm @@ -85,126 +84,136 @@ async def delete_invitation(self, invitation_id: uuid.UUID) -> None: class WorkspaceRepositoryImpl(WorkspaceRepository): - def __init__(self, session: AsyncSession, graph_db_access_manager: GraphDatabaseAccessManager) -> None: - self.session = session + def __init__(self, session_maker: AsyncSessionMaker, graph_db_access_manager: GraphDatabaseAccessManager) -> None: + self.session_maker = session_maker self.graph_db_access_manager = graph_db_access_manager async def create_workspace(self, name: str, slug: str, owner: User) -> Workspace: - workspace_id = WorkspaceId(uuid.uuid4()) - organization = orm.Organization(id=workspace_id, name=name, slug=slug) - owner_relationship = orm.OrganizationOwners(user_id=owner.id) - organization.owners.append(owner_relationship) - self.session.add(organization) - # create a database access object for this organization in the same transaction - await self.graph_db_access_manager.create_database_access(workspace_id, session=self.session) - - await self.session.commit() - await self.session.refresh(organization) - statement = ( - select(orm.Organization) - .where(orm.Organization.id == organization.id) - .options(selectinload(orm.Organization.owners), selectinload(orm.Organization.members)) - ) - results = await self.session.execute(statement) - org = results.unique().scalar_one() - return org.to_model() + async with self.session_maker() as session: + workspace_id = WorkspaceId(uuid.uuid4()) + organization = orm.Organization(id=workspace_id, name=name, slug=slug) + owner_relationship = orm.OrganizationOwners(user_id=owner.id) + organization.owners.append(owner_relationship) + session.add(organization) + # create a database access object for this organization in the same transaction + await self.graph_db_access_manager.create_database_access(workspace_id, session=session) + + await session.commit() + await session.refresh(organization) + statement = ( + select(orm.Organization) + .where(orm.Organization.id == organization.id) + .options(selectinload(orm.Organization.owners), selectinload(orm.Organization.members)) + ) + results = await session.execute(statement) + org = results.unique().scalar_one() + return org.to_model() async def get_workspace(self, workspace_id: WorkspaceId) -> Optional[Workspace]: - statement = select(orm.Organization).where(orm.Organization.id == workspace_id) - results = await self.session.execute(statement) - org = results.unique().scalar_one_or_none() - return org.to_model() if org else None + async with self.session_maker() as session: + statement = select(orm.Organization).where(orm.Organization.id == workspace_id) + results = await session.execute(statement) + org = results.unique().scalar_one_or_none() + return org.to_model() if org else None async def list_workspaces(self, user_id: UserId) -> Sequence[Workspace]: - statement = ( - select(orm.Organization).join(orm.OrganizationOwners).where(orm.OrganizationOwners.user_id == user_id) - ) - results = await self.session.execute(statement) - orgs = results.unique().scalars().all() - return [org.to_model() for org in orgs] + async with self.session_maker() as session: + statement = ( + select(orm.Organization).join(orm.OrganizationOwners).where(orm.OrganizationOwners.user_id == user_id) + ) + results = await session.execute(statement) + orgs = results.unique().scalars().all() + return [org.to_model() for org in orgs] async def add_to_workspace(self, workspace_id: WorkspaceId, user_id: UserId) -> None: - existing_membership = await self.session.get(orm.OrganizationMembers, (workspace_id, user_id)) - if existing_membership is not None: - # user is already a member of the organization, do nothing - return None - - member_relationship = orm.OrganizationMembers(user_id=user_id, organization_id=workspace_id) - self.session.add(member_relationship) - try: - await self.session.commit() - except IntegrityError: - raise ValueError("Can't add user to workspace.") + async with self.session_maker() as session: + existing_membership = await session.get(orm.OrganizationMembers, (workspace_id, user_id)) + if existing_membership is not None: + # user is already a member of the organization, do nothing + return None + + member_relationship = orm.OrganizationMembers(user_id=user_id, organization_id=workspace_id) + session.add(member_relationship) + try: + await session.commit() + except IntegrityError: + raise ValueError("Can't add user to workspace.") async def remove_from_workspace(self, workspace_id: WorkspaceId, user_id: UserId) -> None: - membership = await self.session.get(orm.OrganizationMembers, (workspace_id, user_id)) - if membership is None: - raise ValueError(f"User {uuid} is not a member of workspace {workspace_id}") - await self.session.delete(membership) - await self.session.commit() + async with self.session_maker() as session: + membership = await session.get(orm.OrganizationMembers, (workspace_id, user_id)) + if membership is None: + raise ValueError(f"User {uuid} is not a member of workspace {workspace_id}") + await session.delete(membership) + await session.commit() async def create_invitation(self, workspace_id: WorkspaceId, user_id: UserId) -> WorkspaceInvite: - user = await self.session.get(auth_orm.User, user_id) - organization = await self.get_workspace(workspace_id) + async with self.session_maker() as session: + user = await session.get(auth_orm.User, user_id) + organization = await self.get_workspace(workspace_id) - if user is None or organization is None: - raise ValueError(f"User {user_id} or organization {workspace_id} does not exist.") + if user is None or organization is None: + raise ValueError(f"User {user_id} or organization {workspace_id} does not exist.") - if user.id in [owner for owner in organization.owners]: - raise ValueError(f"User {user_id} is already an owner of workspace {workspace_id}") + if user.id in [owner for owner in organization.owners]: + raise ValueError(f"User {user_id} is already an owner of workspace {workspace_id}") - if user.id in [member for member in organization.members]: - raise ValueError(f"User {user_id} is already a member of workspace {workspace_id}") + if user.id in [member for member in organization.members]: + raise ValueError(f"User {user_id} is already a member of workspace {workspace_id}") - invite = orm.OrganizationInvite( - user_id=user_id, organization_id=workspace_id, expires_at=datetime.utcnow() + timedelta(days=7) - ) - self.session.add(invite) - await self.session.commit() - await self.session.refresh(invite) - return invite.to_model() + invite = orm.OrganizationInvite( + user_id=user_id, organization_id=workspace_id, expires_at=datetime.utcnow() + timedelta(days=7) + ) + session.add(invite) + await session.commit() + await session.refresh(invite) + return invite.to_model() async def get_invitation(self, invitation_id: uuid.UUID) -> Optional[WorkspaceInvite]: - statement = ( - select(orm.OrganizationInvite) - .where(orm.OrganizationInvite.id == invitation_id) - .options(selectinload(orm.OrganizationInvite.user)) - ) - results = await self.session.execute(statement) - invite = results.unique().scalar_one_or_none() - return invite.to_model() if invite else None + async with self.session_maker() as session: + statement = ( + select(orm.OrganizationInvite) + .where(orm.OrganizationInvite.id == invitation_id) + .options(selectinload(orm.OrganizationInvite.user)) + ) + results = await session.execute(statement) + invite = results.unique().scalar_one_or_none() + return invite.to_model() if invite else None async def list_invitations(self, workspace_id: WorkspaceId) -> Sequence[WorkspaceInvite]: - statement = ( - select(orm.OrganizationInvite) - .where(orm.OrganizationInvite.organization_id == workspace_id) - .options(selectinload(orm.OrganizationInvite.user), selectinload(orm.OrganizationInvite.organization)) - ) - results = await self.session.execute(statement) - invites = results.scalars().all() - return [invite.to_model() for invite in invites] + async with self.session_maker() as session: + statement = ( + select(orm.OrganizationInvite) + .where(orm.OrganizationInvite.organization_id == workspace_id) + .options(selectinload(orm.OrganizationInvite.user), selectinload(orm.OrganizationInvite.organization)) + ) + results = await session.execute(statement) + invites = results.scalars().all() + return [invite.to_model() for invite in invites] async def accept_invitation(self, invitation_id: uuid.UUID) -> None: - invite = await self.session.get(orm.OrganizationInvite, invitation_id) - if invite is None: - raise ValueError(f"Invitation {invitation_id} does not exist.") - if invite.expires_at < datetime.utcnow(): - raise ValueError(f"Invitation {invitation_id} has expired.") - membership = orm.OrganizationMembers(user_id=invite.user_id, organization_id=invite.organization_id) - self.session.add(membership) - await self.session.delete(invite) - await self.session.commit() + async with self.session_maker() as session: + invite = await session.get(orm.OrganizationInvite, invitation_id) + if invite is None: + raise ValueError(f"Invitation {invitation_id} does not exist.") + if invite.expires_at < datetime.utcnow(): + raise ValueError(f"Invitation {invitation_id} has expired.") + membership = orm.OrganizationMembers(user_id=invite.user_id, organization_id=invite.organization_id) + session.add(membership) + await session.delete(invite) + await session.commit() async def delete_invitation(self, invitation_id: uuid.UUID) -> None: - invite = await self.session.get(orm.OrganizationInvite, invitation_id) - if invite is None: - raise ValueError(f"Invitation {invitation_id} does not exist.") - await self.session.delete(invite) - await self.session.commit() + async with self.session_maker() as session: + invite = await session.get(orm.OrganizationInvite, invitation_id) + if invite is None: + raise ValueError(f"Invitation {invitation_id} does not exist.") + await session.delete(invite) + await session.commit() -async def get_workspace_repository(session: AsyncSessionDependency, fix: FixDependency) -> WorkspaceRepository: - return WorkspaceRepositoryImpl(session, fix.graph_database_access) +async def get_workspace_repository(fix: FixDependency) -> WorkspaceRepository: + return fix.service(ServiceNames.workspace_repo, WorkspaceRepositoryImpl) WorkspaceRepositoryDependency = Annotated[WorkspaceRepository, Depends(get_workspace_repository)] diff --git a/migrations/versions/2023-10-17T09:03:56Z_session_entity_genesis.py b/migrations/versions/2023-10-17T09:03:56Z_session_entity_genesis.py new file mode 100644 index 00000000..1cbbdd52 --- /dev/null +++ b/migrations/versions/2023-10-17T09:03:56Z_session_entity_genesis.py @@ -0,0 +1,57 @@ +"""session entity genesis + +Revision ID: 6e6db9c38194 +Revises: 9b482c179740 +Create Date: 2023-10-17 09:03:56.141766+00:00 + +""" +from typing import Sequence, Union + +import fastapi_users_db_sqlalchemy +from alembic import op +import sqlalchemy as sa +from sqlalchemy import DefaultClause, text + +from fixbackend.sqlalechemy_extensions import UTCDateTime + +# revision identifiers, used by Alembic. +revision: str = "6e6db9c38194" +down_revision: Union[str, None] = "9b482c179740" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "subscriptions", + sa.Column("id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False), + sa.Column("user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=True), + sa.Column("workspace_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=True), + sa.Column("aws_customer_identifier", sa.String(length=128), nullable=False), + sa.Column("aws_customer_account_id", sa.String(length=128), nullable=True), + sa.Column("aws_product_code", sa.String(length=128), nullable=False), + sa.Column("active", sa.Boolean(), nullable=False), + sa.Column("created_at", UTCDateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.Column( + "updated_at", + UTCDateTime(timezone=True), + server_default=sa.text("now()"), + server_onupdate=DefaultClause(text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index("idx_aws_customer_user", "subscriptions", ["aws_customer_identifier", "user_id"], unique=False) + op.create_index(op.f("ix_subscriptions_user_id"), "subscriptions", ["user_id"], unique=False) + op.create_index(op.f("ix_subscriptions_workspace_id"), "subscriptions", ["workspace_id"], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_subscriptions_workspace_id"), table_name="subscriptions") + op.drop_index(op.f("ix_subscriptions_user_id"), table_name="subscriptions") + op.drop_index("idx_aws_customer_user", table_name="subscriptions") + op.drop_table("subscriptions") + # ### end Alembic commands ### diff --git a/static/openapi-events.yaml b/static/openapi-events.yaml index 9713ec2c..1f5a231d 100644 --- a/static/openapi-events.yaml +++ b/static/openapi-events.yaml @@ -49,7 +49,7 @@ components: tenant_id: type: string description: Id of the defaut workspace. - + AwsAccountDiscovered: description: "This event emitted when a new aws account has been discovered." type: object @@ -65,4 +65,4 @@ components: description: Name of the cloud. aws_account_id: type: string - description: Id of the aws account. \ No newline at end of file + description: Id of the aws account. diff --git a/tests/fixbackend/conftest.py b/tests/fixbackend/conftest.py index 13c65600..8d1cc48e 100644 --- a/tests/fixbackend/conftest.py +++ b/tests/fixbackend/conftest.py @@ -16,13 +16,15 @@ import json from argparse import Namespace from asyncio import AbstractEventLoop -from typing import AsyncIterator, Iterator, List +from typing import AsyncIterator, Iterator, List, Any, Dict +from unittest.mock import patch import pytest from alembic.command import upgrade as alembic_upgrade from alembic.config import Config as AlembicConfig from arq import ArqRedis, create_pool from arq.connections import RedisSettings +from boto3 import Session as BotoSession from fastapi import FastAPI from fixcloudutils.types import Json from httpx import AsyncClient, MockTransport, Request, Response @@ -30,9 +32,9 @@ from sqlalchemy_utils import create_database, database_exists, drop_database from fixbackend.app import fast_api_app -from fixbackend.cloud_accounts.repository import CloudAccountRepository, CloudAccountRepositoryImpl from fixbackend.auth.db import get_user_repository from fixbackend.auth.models import User +from fixbackend.cloud_accounts.repository import CloudAccountRepository, CloudAccountRepositoryImpl from fixbackend.collect.collect_queue import RedisCollectQueue from fixbackend.config import Config, get_config from fixbackend.db import get_async_session @@ -43,9 +45,11 @@ from fixbackend.inventory.inventory_client import InventoryClient from fixbackend.inventory.inventory_service import InventoryService from fixbackend.metering.metering_repository import MeteringRepository +from fixbackend.subscription.aws_marketplace import AwsMarketplaceHandler +from fixbackend.subscription.subscription_repository import SubscriptionRepository +from fixbackend.types import AsyncSessionMaker from fixbackend.workspaces.models import Workspace from fixbackend.workspaces.repository import WorkspaceRepository, WorkspaceRepositoryImpl -from fixbackend.types import AsyncSessionMaker DATABASE_URL = "mysql+aiomysql://root@127.0.0.1:3306/fixbackend-testdb" # only used to create/drop the database @@ -152,6 +156,23 @@ async def session(db_engine: AsyncEngine) -> AsyncIterator[AsyncSession]: await connection.close() +@pytest.fixture +async def boto_answers() -> Dict[str, Any]: + return {} + + +@pytest.fixture +async def boto_session(boto_answers: Dict[str, Any]) -> AsyncIterator[BotoSession]: + def mock_make_api_call(client: Any, operation_name: str, kwarg: Any) -> Any: + if result := boto_answers.get(operation_name): + return result + else: + raise Exception(f"Please provide mocked answer for boto operation {operation_name} and arguments {kwarg}") + + with patch("botocore.client.BaseClient._make_api_call", new=mock_make_api_call): + yield BotoSession() + + @pytest.fixture def async_session_maker(session: AsyncSession) -> AsyncSessionMaker: def get_session() -> AsyncSession: @@ -282,6 +303,11 @@ async def cloud_account_repository(async_session_maker: AsyncSessionMaker) -> Cl return CloudAccountRepositoryImpl(async_session_maker) +@pytest.fixture +async def subscription_repository(async_session_maker: AsyncSessionMaker) -> SubscriptionRepository: + return SubscriptionRepository(async_session_maker) + + @pytest.fixture async def metering_repository(async_session_maker: AsyncSessionMaker) -> MeteringRepository: return MeteringRepository(async_session_maker) @@ -289,9 +315,18 @@ async def metering_repository(async_session_maker: AsyncSessionMaker) -> Meterin @pytest.fixture async def workspace_repository( - session: AsyncSession, graph_database_access_manager: GraphDatabaseAccessManager + async_session_maker: AsyncSessionMaker, graph_database_access_manager: GraphDatabaseAccessManager ) -> WorkspaceRepository: - return WorkspaceRepositoryImpl(session, graph_database_access_manager) + return WorkspaceRepositoryImpl(async_session_maker, graph_database_access_manager) + + +@pytest.fixture +async def aws_marketplace_handler( + subscription_repository: SubscriptionRepository, + workspace_repository: WorkspaceRepository, + boto_session: BotoSession, +) -> AwsMarketplaceHandler: + return AwsMarketplaceHandler(subscription_repository, workspace_repository, boto_session, None) @pytest.fixture @@ -315,10 +350,18 @@ async def dispatcher( @pytest.fixture async def fix_deps( - db_engine: AsyncEngine, graph_database_access_manager: GraphDatabaseAccessManager + db_engine: AsyncEngine, + graph_database_access_manager: GraphDatabaseAccessManager, + async_session_maker: AsyncSessionMaker, + workspace_repository: WorkspaceRepository, ) -> FixDependencies: return FixDependencies( - **{ServiceNames.async_engine: db_engine, ServiceNames.graph_db_access: graph_database_access_manager} + **{ + ServiceNames.async_engine: db_engine, + ServiceNames.graph_db_access: graph_database_access_manager, + ServiceNames.session_maker: async_session_maker, + ServiceNames.workspace_repo: workspace_repository, + } ) diff --git a/tests/fixbackend/subscription/aws_marketplace_test.py b/tests/fixbackend/subscription/aws_marketplace_test.py new file mode 100644 index 00000000..21b6a1e8 --- /dev/null +++ b/tests/fixbackend/subscription/aws_marketplace_test.py @@ -0,0 +1,41 @@ +# Copyright (c) 2023. Some Engineering +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from typing import Dict, Any + +from fixbackend.auth.models import User +from fixbackend.subscription.aws_marketplace import AwsMarketplaceHandler + + +async def test_handle_subscription( + aws_marketplace_handler: AwsMarketplaceHandler, user: User, boto_answers: Dict[str, Any] +) -> None: + boto_answers["ResolveCustomer"] = {"CustomerAWSAccountId": "1", "CustomerIdentifier": "2", "ProductCode": "3"} + result = await aws_marketplace_handler.subscribed(user, "123") + assert result is not None + assert result.customer_aws_account_id == "1" + assert result.customer_identifier == "2" + assert result.product_code == "3" + assert result.user_id == user.id + assert result.workspace_id is None # user does not have any workspaces yet + # subscribe again: will not create a new subscription + result2 = await aws_marketplace_handler.subscribed(user, "123") + assert result == result2 diff --git a/tests/fixbackend/subscription/subscription_repository_test.py b/tests/fixbackend/subscription/subscription_repository_test.py new file mode 100644 index 00000000..5ce61874 --- /dev/null +++ b/tests/fixbackend/subscription/subscription_repository_test.py @@ -0,0 +1,50 @@ +# Copyright (c) 2023. Some Engineering +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from uuid import uuid4 + +from sqlalchemy.ext.asyncio import AsyncSession + +from fixbackend.ids import PaymentMethodId, UserId, WorkspaceId +from fixbackend.subscription.models import AwsMarketplaceSubscription +from fixbackend.subscription.subscription_repository import ( + SubscriptionRepository, + SubscriptionEntity, +) + + +async def test_create_entry(subscription_repository: SubscriptionRepository, session: AsyncSession) -> None: + id = PaymentMethodId(uuid4()) + user_id = UserId(uuid4()) + entity = AwsMarketplaceSubscription( + id=id, + user_id=user_id, + workspace_id=WorkspaceId(uuid4()), + customer_identifier="123", + customer_aws_account_id="123", + product_code="123", + active=True, + ) + await subscription_repository.create(entity) + assert await subscription_repository.aws_marketplace_subscription(user_id, entity.customer_identifier) is not None + assert await subscription_repository.aws_marketplace_subscription(user_id, "124") is None + result = await session.get(SubscriptionEntity, id) + assert result is not None From a6abc2793f0ecc4536b863cb3b103d5a6e2ff885 Mon Sep 17 00:00:00 2001 From: Matthias Veit Date: Tue, 17 Oct 2023 13:17:19 +0200 Subject: [PATCH 2/3] define the default region --- fixbackend/app.py | 2 +- tests/fixbackend/conftest.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fixbackend/app.py b/fixbackend/app.py index 604d9aee..8aca86c4 100644 --- a/fixbackend/app.py +++ b/fixbackend/app.py @@ -72,7 +72,7 @@ def fast_api_app(cfg: Config) -> FastAPI: google = google_client(cfg) github = github_client(cfg) - boto_session = boto3.Session(cfg.aws_access_key_id, cfg.aws_secret_access_key) + boto_session = boto3.Session(cfg.aws_access_key_id, cfg.aws_secret_access_key, region_name="us-east-1") deps = FixDependencies() ca_cert_path = str(cfg.ca_cert) if cfg.ca_cert else None client_context = create_default_context(purpose=Purpose.SERVER_AUTH) diff --git a/tests/fixbackend/conftest.py b/tests/fixbackend/conftest.py index 8d1cc48e..b1b15cbd 100644 --- a/tests/fixbackend/conftest.py +++ b/tests/fixbackend/conftest.py @@ -170,7 +170,7 @@ def mock_make_api_call(client: Any, operation_name: str, kwarg: Any) -> Any: raise Exception(f"Please provide mocked answer for boto operation {operation_name} and arguments {kwarg}") with patch("botocore.client.BaseClient._make_api_call", new=mock_make_api_call): - yield BotoSession() + yield BotoSession(region_name="us-east-1") @pytest.fixture From 08fae1e1705b3bbffc48103fcbf9bba689287acd Mon Sep 17 00:00:00 2001 From: Matthias Veit Date: Tue, 17 Oct 2023 15:02:57 +0200 Subject: [PATCH 3/3] log relevant actions --- fixbackend/subscription/aws_marketplace.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/fixbackend/subscription/aws_marketplace.py b/fixbackend/subscription/aws_marketplace.py index b8e55a4d..e2a91194 100644 --- a/fixbackend/subscription/aws_marketplace.py +++ b/fixbackend/subscription/aws_marketplace.py @@ -73,8 +73,10 @@ async def stop(self) -> None: await self.listener.stop() async def subscribed(self, user: User, token: str) -> Optional[SubscriptionMethod]: + log.info(f"AWS Marketplace subscription for user {user.email} with token {token}") # Get the related data from AWS. Will throw in case of an error. customer_data = self.marketplace_client.resolve_customer(RegistrationToken=token) + log.debug(f"AWS Marketplace user {user.email} got customer data: {customer_data}") product_code = customer_data["ProductCode"] customer_identifier = customer_data["CustomerIdentifier"] customer_aws_account_id = customer_data["CustomerAWSAccountId"] @@ -86,6 +88,7 @@ async def subscribed(self, user: User, token: str) -> Optional[SubscriptionMetho # only create a new subscription if there is no existing one if existing := await self.aws_marketplace_repo.aws_marketplace_subscription(user.id, customer_identifier): + log.debug(f"AWS Marketplace user {user.email}: return existing subscription") return existing else: subscription = AwsMarketplaceSubscription( @@ -102,6 +105,7 @@ async def subscribed(self, user: User, token: str) -> Optional[SubscriptionMetho async def handle_message(self, message: Json) -> None: # See: https://docs.aws.amazon.com/marketplace/latest/userguide/saas-notification.html action = message["action"] + log.info(f"AWS Marketplace. Received message: {message}") # customer_identifier = message["customer-identifier"] # free_trial = message.get("isFreeTrialTermPresent", False) match action: