Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] Handle AWS marketplace fulfillment #107

Merged
merged 3 commits into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions fixbackend/all_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@
from fixbackend.dispatcher.next_run_repository import NextRun # noqa
from fixbackend.metering.metering_repository import MeteringRecordEntity # noqa
from fixbackend.keyvalue.json_kv import JsonEntry # noqa
from fixbackend.subscription.subscription_repository import SubscriptionEntity # noqa
19 changes: 17 additions & 2 deletions fixbackend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from ssl import Purpose, create_default_context
from typing import Any, AsyncIterator, ClassVar, Optional, Set, Tuple, cast

import boto3
import httpx
from arq import create_pool
from arq.connections import RedisSettings
Expand Down Expand Up @@ -57,6 +58,10 @@
from fixbackend.inventory.inventory_service import InventoryService
from fixbackend.inventory.router import inventory_router
from fixbackend.metering.metering_repository import MeteringRepository
from fixbackend.subscription.aws_marketplace import AwsMarketplaceHandler
from fixbackend.subscription.router import subscription_router
from fixbackend.subscription.subscription_repository import SubscriptionRepository
from fixbackend.workspaces.repository import WorkspaceRepositoryImpl
from fixbackend.workspaces.router import workspaces_router

log = logging.getLogger(__name__)
Expand All @@ -67,6 +72,7 @@
def fast_api_app(cfg: Config) -> FastAPI:
google = google_client(cfg)
github = github_client(cfg)
boto_session = boto3.Session(cfg.aws_access_key_id, cfg.aws_secret_access_key, region_name="us-east-1")
deps = FixDependencies()
ca_cert_path = str(cfg.ca_cert) if cfg.ca_cert else None
client_context = create_default_context(purpose=Purpose.SERVER_AUTH)
Expand Down Expand Up @@ -109,13 +115,21 @@ async def setup_teardown_application(app: FastAPI) -> AsyncIterator[None]:
deps.add(SN.next_run_repo, NextRunRepository(session_maker))
deps.add(SN.metering_repo, MeteringRepository(session_maker))
deps.add(SN.collect_queue, RedisCollectQueue(arq_redis))
deps.add(SN.graph_db_access, GraphDatabaseAccessManager(cfg, session_maker))
graph_db_access = deps.add(SN.graph_db_access, GraphDatabaseAccessManager(cfg, session_maker))
inventory_client = deps.add(SN.inventory_client, InventoryClient(cfg.inventory_url, http_client))
deps.add(SN.inventory, InventoryService(inventory_client))
deps.add(
SN.cloudaccount_publisher,
RedisStreamPublisher(readwrite_redis, "fixbackend::cloudaccount", f"fixbackend-{cfg.instance_id}"),
)
workspace_repo = deps.add(SN.workspace_repo, WorkspaceRepositoryImpl(session_maker, graph_db_access))
subscription_repo = deps.add(SN.subscription_repo, SubscriptionRepository(session_maker))
deps.add(
SN.aws_marketplace_handler,
AwsMarketplaceHandler(
subscription_repo, workspace_repo, boto_session, cfg.args.aws_marketplace_metering_sqs_url
),
)

domain_events_stream_name = "fixbackend::domain_events"
domain_event_redis_publisher = deps.add(
Expand Down Expand Up @@ -268,8 +282,9 @@ async def domain_events_swagger_ui_html(req: Request) -> HTMLResponse:

api_router.include_router(cloud_accounts_callback_router(), prefix="/cloud", tags=["cloud_accounts"])
api_router.include_router(users_router(), prefix="/users", tags=["users"])
app.include_router(api_router)
api_router.include_router(subscription_router(deps))

app.include_router(api_router)
app.mount("/static", StaticFiles(directory="static"), name="static")

if cfg.static_assets:
Expand Down
4 changes: 3 additions & 1 deletion fixbackend/auth/depedencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from typing import Annotated
from typing import Annotated, Optional
from uuid import UUID

from fastapi import Depends
Expand All @@ -36,6 +36,8 @@

# the value below is a dependency itself
get_current_active_verified_user = fastapi_users.current_user(active=True, verified=True)
maybe_current_active_verified_user = fastapi_users.current_user(active=True, verified=True, optional=True)


AuthenticatedUser = Annotated[User, Depends(get_current_active_verified_user)]
OptionalAuthenticatedUser = Annotated[Optional[User], Depends(maybe_current_active_verified_user)]
19 changes: 18 additions & 1 deletion fixbackend/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from datetime import datetime

from sqlalchemy.orm import DeclarativeBase
from sqlalchemy import func, text, DefaultClause
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

from fixbackend.sqlalechemy_extensions import UTCDateTime


class Base(DeclarativeBase):
Expand All @@ -21,3 +25,16 @@ class Base(DeclarativeBase):

All model classes should inherit from this class.
"""


class CreatedUpdatedMixin:
"""
Mixin to always have created_at and updated_at columns in a model.
"""

created_at: Mapped[datetime] = mapped_column(UTCDateTime, server_default=func.now())
updated_at: Mapped[datetime] = mapped_column(
UTCDateTime,
server_default=func.now(),
server_onupdate=DefaultClause(text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")),
)
3 changes: 3 additions & 0 deletions fixbackend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ def parse_args(argv: Optional[Sequence[str]] = None) -> Namespace:
parser.add_argument("--aws-access-key-id", default=os.environ.get("AWS_ACCESS_KEY_ID", ""))
parser.add_argument("--aws-secret-access-key", default=os.environ.get("AWS_SECRET_ACCESS_KEY", ""))
parser.add_argument("--aws-region", default=os.environ.get("AWS_REGION", "us-east-1"))
parser.add_argument(
"--aws-marketplace-metering-sqs-url", default=os.environ.get("AWS_MARKETPLACE_METERING_SQS_URL")
)
parser.add_argument("--ca-cert", type=Path, default=os.environ.get("CA_CERT"))
parser.add_argument("--host-cert", type=Path, default=os.environ.get("HOST_CERT"))
parser.add_argument("--host-key", type=Path, default=os.environ.get("HOST_KEY"))
Expand Down
23 changes: 8 additions & 15 deletions fixbackend/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from typing import Annotated
from typing import Annotated, cast

from arq import ArqRedis
from fastapi.params import Depends
Expand All @@ -20,14 +20,12 @@
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncEngine

from fixbackend.collect.collect_queue import RedisCollectQueue
from fixbackend.graph_db.service import GraphDatabaseAccessManager
from fixbackend.inventory.inventory_client import InventoryClient
from fixbackend.inventory.inventory_service import InventoryService
from fixbackend.types import AsyncSessionMaker
from fixbackend.certificates.cert_store import CertificateStore
from fixbackend.domain_events.sender import DomainEventSender
from fixbackend.domain_events.sender_impl import DomainEventSenderImpl
from fixbackend.graph_db.service import GraphDatabaseAccessManager
from fixbackend.inventory.inventory_service import InventoryService
from fixbackend.types import AsyncSessionMaker


class ServiceNames:
Expand All @@ -50,33 +48,28 @@ class ServiceNames:
domain_event_redis_stream_publisher = "domain_event_redis_stream_publisher"
domain_event_sender = "domain_event_sender"
customerio_consumer = "customerio_consumer"
aws_marketplace_handler = "aws_marketplace_handler"
workspace_repo = "workspace_repo"
subscription_repo = "subscription_repo"


class FixDependencies(Dependencies):
@property
def arq_redis(self) -> ArqRedis:
return self.service(ServiceNames.arq_redis, ArqRedis)

@property
def collect_queue(self) -> RedisCollectQueue:
return self.service(ServiceNames.collect_queue, RedisCollectQueue)

@property
def async_engine(self) -> AsyncEngine:
return self.service(ServiceNames.async_engine, AsyncEngine)

@property
def session_maker(self) -> AsyncSessionMaker:
return self.service(ServiceNames.async_engine, AsyncSessionMaker) # type: ignore
return cast(AsyncSessionMaker, self.lookup[ServiceNames.session_maker])

@property
def inventory(self) -> InventoryService:
return self.service(ServiceNames.inventory, InventoryService)

@property
def inventory_client(self) -> InventoryClient:
return self.service(ServiceNames.inventory, InventoryClient)

@property
def readonly_redis(self) -> Redis:
return self.service(ServiceNames.readonly_redis, Redis)
Expand Down
1 change: 1 addition & 0 deletions fixbackend/ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
UserId = NewType("UserId", UUID)
CloudAccountId = NewType("CloudAccountId", UUID)
ExternalId = NewType("ExternalId", UUID)
PaymentMethodId = NewType("PaymentMethodId", UUID)
10 changes: 6 additions & 4 deletions fixbackend/sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class SQSRawListener(Service):

def __init__(
self,
session: boto3.Session,
queue_url: str,
message_processor: Callable[[Json], Awaitable[Any]],
*,
Expand All @@ -62,7 +63,7 @@ def __init__(
:param wait_for_new_messages_to_arrive: The time to wait for new messages to arrive.
:param backoff: The backoff strategy to use when processing messages.
"""
self.sqs = boto3.client("sqs")
self.sqs = session.client("sqs")
self.queue_url = queue_url
self.message_processor = message_processor
self.consider_failed_after = consider_failed_after.total_seconds() if consider_failed_after else 30
Expand Down Expand Up @@ -111,11 +112,12 @@ class SQSListener(SQSRawListener):

def __init__(
self,
session: boto3.Session,
queue_url: str,
message_processor: Callable[[Json, MessageContext], Awaitable[Any]],
**kwargs: Any,
) -> None:
super().__init__(queue_url, self.__context_handler(message_processor), **kwargs)
super().__init__(session, queue_url, self.__context_handler(message_processor), **kwargs)

@staticmethod
def __context_handler(fn: Callable[[Json, MessageContext], Awaitable[Any]]) -> Callable[[Json], Awaitable[Any]]:
Expand All @@ -134,9 +136,9 @@ async def handler(message: Json) -> Any:


class SQSPublisher(Service):
def __init__(self, publisher_name: str, queue_url: str) -> None:
def __init__(self, session: boto3.Session, publisher_name: str, queue_url: str) -> None:
self.publisher_name = publisher_name
self.sqs = boto3.client("sqs")
self.sqs = session.client("sqs")
self.queue_url = queue_url

async def publish(self, kind: str, message: Json) -> None:
Expand Down
21 changes: 21 additions & 0 deletions fixbackend/subscription/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) 2023. Some Engineering
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
125 changes: 125 additions & 0 deletions fixbackend/subscription/aws_marketplace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright (c) 2023. Some Engineering
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import logging
from datetime import timedelta
from typing import Optional
from uuid import uuid4

import boto3
from fixcloudutils.service import Service
from fixcloudutils.types import Json

from fixbackend.auth.models import User
from fixbackend.ids import PaymentMethodId
from fixbackend.sqs import SQSRawListener
from fixbackend.subscription.models import AwsMarketplaceSubscription, SubscriptionMethod
from fixbackend.subscription.subscription_repository import (
SubscriptionRepository,
)
from fixbackend.workspaces.repository import WorkspaceRepository

log = logging.getLogger(__name__)


class AwsMarketplaceHandler(Service):
def __init__(
self,
subscription_repo: SubscriptionRepository,
workspace_repo: WorkspaceRepository,
session: boto3.Session,
sqs_queue_url: Optional[str],
) -> None:
self.aws_marketplace_repo = subscription_repo
self.workspace_repo = workspace_repo
self.listener = (
SQSRawListener(
session,
sqs_queue_url,
self.handle_message,
consider_failed_after=timedelta(minutes=5),
max_nr_of_messages_in_one_batch=1,
wait_for_new_messages_to_arrive=timedelta(seconds=5),
)
if sqs_queue_url is not None
else None
)
self.marketplace_client = session.client("meteringmarketplace")

async def start(self) -> None:
if self.listener is not None:
await self.listener.start()

async def stop(self) -> None:
if self.listener is not None:
await self.listener.stop()

async def subscribed(self, user: User, token: str) -> Optional[SubscriptionMethod]:
log.info(f"AWS Marketplace subscription for user {user.email} with token {token}")
# Get the related data from AWS. Will throw in case of an error.
customer_data = self.marketplace_client.resolve_customer(RegistrationToken=token)
log.debug(f"AWS Marketplace user {user.email} got customer data: {customer_data}")
product_code = customer_data["ProductCode"]
customer_identifier = customer_data["CustomerIdentifier"]
customer_aws_account_id = customer_data["CustomerAWSAccountId"]

# get all workspaces of the user and use the first one if it is the only one
# if more than one workspace exists, the user needs to define the workspace in a later step
workspaces = await self.workspace_repo.list_workspaces(user.id)
workspace_id = workspaces[0].id if len(workspaces) == 1 else None

# only create a new subscription if there is no existing one
if existing := await self.aws_marketplace_repo.aws_marketplace_subscription(user.id, customer_identifier):
log.debug(f"AWS Marketplace user {user.email}: return existing subscription")
return existing
else:
subscription = AwsMarketplaceSubscription(
id=PaymentMethodId(uuid4()),
user_id=user.id,
workspace_id=workspace_id,
customer_identifier=customer_identifier,
customer_aws_account_id=customer_aws_account_id,
product_code=product_code,
active=True,
)
return await self.aws_marketplace_repo.create(subscription)

async def handle_message(self, message: Json) -> None:
# See: https://docs.aws.amazon.com/marketplace/latest/userguide/saas-notification.html
action = message["action"]
log.info(f"AWS Marketplace. Received message: {message}")
# customer_identifier = message["customer-identifier"]
# free_trial = message.get("isFreeTrialTermPresent", False)
match action:
case "subscribe-success":
# allow sending metering records
pass
case "subscribe-fail":
# wait for subscribe-success
pass
case "unsubscribe-pending":
# TODO: send metering records!
pass
case "unsubscribe-success":
# the user has unsubscribed
pass
case _:
raise ValueError(f"Unknown action: {action}")
Loading