
collect based on tenant id #104

Merged · 17 commits · Oct 17, 2023
Changes from 3 commits
16 changes: 14 additions & 2 deletions fixbackend/app.py
@@ -61,6 +61,7 @@

log = logging.getLogger(__name__)
API_PREFIX = "/api"
domain_events_stream_name = "fixbackend::domain_events"


# noinspection PyUnresolvedReferences
@@ -115,7 +116,6 @@ async def setup_teardown_application(app: FastAPI) -> AsyncIterator[None]:
RedisStreamPublisher(readwrite_redis, "fixbackend::cloudaccount", f"fixbackend-{cfg.instance_id}"),
)

domain_events_stream_name = "fixbackend::domain_events"
domain_event_redis_publisher = deps.add(
SN.domain_event_redis_stream_publisher,
RedisStreamPublisher(
@@ -167,9 +167,21 @@ async def setup_teardown_dispatcher(_: FastAPI) -> AsyncIterator[None]:
metering_repo = deps.add(SN.metering_repo, MeteringRepository(session_maker))
collect_queue = deps.add(SN.collect_queue, RedisCollectQueue(arq_redis))
db_access = deps.add(SN.graph_db_access, GraphDatabaseAccessManager(cfg, session_maker))
domain_event_redis_publisher = deps.add(
SN.domain_event_redis_stream_publisher,
RedisStreamPublisher(
rw_redis,
domain_events_stream_name,
"dispatching",
keep_unprocessed_messages_for=timedelta(days=7),
),
)
domain_event_sender = deps.add(SN.domain_event_sender, DomainEventSenderImpl(domain_event_redis_publisher))
deps.add(
SN.dispatching,
DispatcherService(rw_redis, cloud_accounts, next_run_repo, metering_repo, collect_queue, db_access),
DispatcherService(
rw_redis, cloud_accounts, next_run_repo, metering_repo, collect_queue, db_access, domain_event_sender
),
)

async with deps:
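Note: moving `domain_events_stream_name` to module scope (see the app.py hunk above) lets both lifespans, the API application and the dispatcher, publish to the same Redis stream without repeating the literal. A minimal sketch of that shared wiring, assuming `RedisStreamPublisher` lives next to `RedisStreamListener` in `fixcloudutils.redis.event_stream` and takes the arguments shown in this diff:

```python
# Sketch only: mirrors the wiring above, not the actual fixbackend module layout.
from datetime import timedelta

from fixcloudutils.redis.event_stream import RedisStreamPublisher
from redis.asyncio import Redis

DOMAIN_EVENTS_STREAM = "fixbackend::domain_events"  # single module-level constant


def make_domain_event_publisher(redis: Redis, publisher_name: str) -> RedisStreamPublisher:
    # The API process and the dispatcher each call this with their own name,
    # so every publisher writes to the one shared stream.
    return RedisStreamPublisher(
        redis,
        DOMAIN_EVENTS_STREAM,
        publisher_name,
        keep_unprocessed_messages_for=timedelta(days=7),
    )
```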
126 changes: 108 additions & 18 deletions fixbackend/dispatcher/dispatcher_service.py
@@ -12,29 +12,50 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.


import logging
import uuid
from datetime import timedelta, datetime
from typing import Any, Optional
from datetime import datetime, timedelta
from typing import Any, Awaitable, Optional, Literal, cast, Dict

from fixcloudutils.asyncio.periodic import Periodic
from fixcloudutils.redis.event_stream import RedisStreamListener, Json, MessageContext
from fixcloudutils.redis.event_stream import Json, MessageContext, RedisStreamListener
from fixcloudutils.service import Service
from fixcloudutils.util import utc, parse_utc_str
from fixcloudutils.util import parse_utc_str, utc
from redis.asyncio import Redis

from fixbackend.cloud_accounts.models import AwsCloudAccess, CloudAccount
from fixbackend.cloud_accounts.repository import CloudAccountRepository
from fixbackend.collect.collect_queue import CollectQueue, AccountInformation, AwsAccountInformation
from fixbackend.collect.collect_queue import AccountInformation, AwsAccountInformation, CollectQueue
from fixbackend.dispatcher.next_run_repository import NextRunRepository
from fixbackend.graph_db.service import GraphDatabaseAccessManager
from fixbackend.ids import CloudAccountId, WorkspaceId
from fixbackend.metering import MeteringRecord
from fixbackend.metering.metering_repository import MeteringRepository
from fixbackend.domain_events.events import TenantAccountsCollected
from fixbackend.domain_events.sender import DomainEventSender
from dataclasses import dataclass
import dataclasses
import json

log = logging.getLogger(__name__)


@dataclass(frozen=True)
class AccountCollectInProgress:
job_id: str
account_id: str
started_at: str
status: Literal["in_progress", "done"] = "in_progress"

def to_json_str(self) -> str:
return json.dumps(dataclasses.asdict(self))

@staticmethod
def from_json(value: str) -> "AccountCollectInProgress":
return AccountCollectInProgress(**json.loads(value))
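
A quick standalone round-trip check for `AccountCollectInProgress` (stdlib only; the class is repeated so the snippet runs on its own):

```python
import dataclasses
import json
from typing import Literal


@dataclasses.dataclass(frozen=True)
class AccountCollectInProgress:
    job_id: str
    account_id: str
    started_at: str
    status: Literal["in_progress", "done"] = "in_progress"

    def to_json_str(self) -> str:
        return json.dumps(dataclasses.asdict(self))

    @staticmethod
    def from_json(value: str) -> "AccountCollectInProgress":
        return AccountCollectInProgress(**json.loads(value))


progress = AccountCollectInProgress("job-1", "acct-1", "2023-10-13T11:23:36+00:00")
done = dataclasses.replace(progress, status="done")  # frozen: copy, don't mutate
assert AccountCollectInProgress.from_json(done.to_json_str()) == done
```

Because the dataclass is frozen, flipping `status` requires `dataclasses.replace`, which is exactly what `mark_job_as_done` below does.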


class DispatcherService(Service):
def __init__(
self,
@@ -44,6 +65,7 @@ def __init__(
metering_repo: MeteringRepository,
collect_queue: CollectQueue,
access_manager: GraphDatabaseAccessManager,
domain_event_sender: DomainEventSender,
) -> None:
self.cloud_account_repo = cloud_account_repo
self.next_run_repo = next_run_repo
@@ -69,6 +91,8 @@ def __init__(
consider_failed_after=timedelta(minutes=5),
batch_size=1,
)
self.readwrite_redis = readwrite_redis
self.domain_event_sender = domain_event_sender

async def start(self) -> Any:
await self.collect_result_listener.start()
@@ -85,7 +109,7 @@ async def process_cloud_account_changed_message(self, message: Json, context: Me
case "cloud_account_created":
await self.cloud_account_created(CloudAccountId(message["cloud_account_id"]))
case "cloud_account_deleted":
await self.cloud_account_deleted(CloudAccountId(message["cloud_account_id"]))
pass # we don't care about deleted accounts since the scheduling is done via the tenant id
case _:
log.error(f"Don't know how to handle messages of kind {context.kind}")

@@ -96,18 +120,68 @@ async def process_collect_done_message(self, message: Json, context: MessageCont
case _:
log.info(f"Collect messages: will ignore messages of kine {context.kind}")

def _collect_progress_hash_key(self, workspace_id: WorkspaceId) -> str:
return f"dispatching:collect_jobs_in_progress:{workspace_id}"

async def maybe_send_domain_event(self, tenant_id: WorkspaceId, completed_job_id: str) -> None:
redis_set_key = self._collect_progress_hash_key(tenant_id)

async def get_redis_hash() -> Dict[bytes, bytes]:
result = self.readwrite_redis.hgetall(redis_set_key)
if isinstance(result, Awaitable):
result = await result
return cast(Dict[bytes, bytes], result)

def parse_collect_state(hash: Dict[bytes, bytes]) -> Dict[str, AccountCollectInProgress]:
return {k.decode(): AccountCollectInProgress.from_json(v.decode()) for k, v in hash.items()}

async def mark_job_as_done(progress: AccountCollectInProgress) -> AccountCollectInProgress:
progress = dataclasses.replace(progress, status="done")
result = self.readwrite_redis.hset(redis_set_key, completed_job_id, progress.to_json_str())
if isinstance(result, Awaitable):
await result
return progress

def all_jobs_finished(collect_state: Dict[str, AccountCollectInProgress]) -> bool:
return all(job.status == "done" for job in collect_state.values())

async def send_domain_event(collect_state: Dict[str, AccountCollectInProgress]) -> None:
collected_accounts = [CloudAccountId(uuid.UUID(job.account_id)) for job in collect_state.values()]
await self.domain_event_sender.publish(TenantAccountsCollected(tenant_id, collected_accounts))

# fetch the redis hash
hash = await get_redis_hash()
if not hash:
log.error(f"Could not find any job context for tenant id {tenant_id}")
return
# parse it to dataclass
tenant_collect_state = parse_collect_state(hash)
if not (progress := tenant_collect_state.get(completed_job_id)):
log.error(f"Could not find job context for job id {completed_job_id}")
return
# mark the job as done
progress = await mark_job_as_done(progress)
tenant_collect_state[completed_job_id] = progress
# check if we can send the domain event
if not all_jobs_finished(tenant_collect_state):
return

# all jobs are finished, send domain event and delete the hash
await send_domain_event(tenant_collect_state)
await self.readwrite_redis.delete(redis_set_key)
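
The protocol above in one sentence: `trigger_collect` writes one hash entry per enqueued job under `dispatching:collect_jobs_in_progress:{workspace_id}`, `collect_job_finished` flips its own entry to `done`, and the tenant-level event fires only once every entry is `done`. (The `isinstance(result, Awaitable)` dance exists because redis-py types its command methods as possibly sync, possibly async, since the same command mixin backs both clients.) A condensed, runnable sketch of the lifecycle; it assumes a local Redis at `localhost:6379` and stubs the event publish with a print:

```python
import asyncio
import json

from redis.asyncio import Redis

KEY = "dispatching:collect_jobs_in_progress:tenant-42"  # hypothetical tenant


async def main() -> None:
    r = Redis(host="localhost", port=6379)

    # trigger_collect: one hash field per enqueued job
    for job in ("job-a", "job-b"):
        await r.hset(KEY, job, json.dumps({"account_id": job, "status": "in_progress"}))
    await r.expire(KEY, 24 * 3600)  # safety-net cleanup, as in the diff

    # collect_job_finished: mark a job done, then check whether all are done
    for job in ("job-a", "job-b"):
        state = json.loads(await r.hget(KEY, job))
        state["status"] = "done"
        await r.hset(KEY, job, json.dumps(state))

        remaining = await r.hgetall(KEY)  # Dict[bytes, bytes]
        if all(json.loads(v)["status"] == "done" for v in remaining.values()):
            print("all jobs done -> publish TenantAccountsCollected")  # stubbed
            await r.delete(KEY)

    await r.aclose()  # redis-py >= 5; use close() on older versions


asyncio.run(main())
```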

async def collect_job_finished(self, message: Json) -> None:
job_id = message["job_id"]
task_id = message["task_id"]
workspace_id = message["tenant_id"]
workspace_id = WorkspaceId(uuid.UUID(message["tenant_id"]))
account_info = message["account_info"]
messages = message["messages"]
started_at = parse_utc_str(message["started_at"])
duration = message["duration"]
records = [
MeteringRecord(
id=uuid.uuid4(),
workspace_id=WorkspaceId(uuid.UUID(workspace_id)),
workspace_id=workspace_id,
cloud=account_details["cloud"],
account_id=account_id,
account_name=account_details["name"],
@@ -122,20 +196,20 @@ async def collect_job_finished(self, message: Json) -> None:
for account_id, account_details in account_info.items()
]
await self.metering_repo.add(records)
await self.maybe_send_domain_event(
workspace_id,
job_id,
)

async def cloud_account_created(self, cid: CloudAccountId) -> None:
if account := await self.cloud_account_repo.get(cid):
await self.trigger_collect(account)
# store an entry in the next_run table
next_run_at = await self.compute_next_run(account.workspace_id)
await self.next_run_repo.create(cid, next_run_at)
await self.next_run_repo.create(account.workspace_id, next_run_at)
else:
log.error("Received a message, that a cloud account is created, but it does not exist in the database")

async def cloud_account_deleted(self, cid: CloudAccountId) -> None:
# delete the entry from the scheduler table
await self.next_run_repo.delete(cid)

async def compute_next_run(self, tenant: WorkspaceId, last_run: Optional[datetime] = None) -> datetime:
now = utc()
delta = timedelta(hours=1) # TODO: compute delta dependent on the tenant.
@@ -149,6 +223,20 @@ async def compute_next_run(self, tenant: WorkspaceId, last_run: Optional[datetim
log.info(f"Next run for tenant: {tenant} is {result}")
return result

async def _add_collect_in_progress_account(
self, workspace_id: WorkspaceId, job_id: str, account_id: CloudAccountId
) -> None:
value = AccountCollectInProgress(job_id, str(account_id), utc().isoformat()).to_json_str()
result = self.readwrite_redis.hset(name=self._collect_progress_hash_key(workspace_id), key=job_id, value=value)
if isinstance(result, Awaitable):
await result
# cleanup after 24 hours just to be sure
result = self.readwrite_redis.expire(
name=self._collect_progress_hash_key(workspace_id), time=timedelta(hours=24)
)
if isinstance(result, Awaitable):
await result

async def trigger_collect(self, account: CloudAccount) -> None:
def account_information() -> Optional[AccountInformation]:
match account.access:
@@ -170,15 +258,17 @@ def account_information() -> Optional[AccountInformation]:
log.info(
f"Trigger collect for tenant: {account.workspace_id} and account: {account.id} with job_id: {job_id}"
)
await self._add_collect_in_progress_account(account.workspace_id, job_id, account.id)
await self.collect_queue.enqueue(db, ai, job_id=job_id)

async def schedule_next_runs(self) -> None:
now = utc()
async for cid, at in self.next_run_repo.older_than(now):
if account := await self.cloud_account_repo.get(cid):
await self.trigger_collect(account)
next_run_at = await self.compute_next_run(account.workspace_id, at)
await self.next_run_repo.update_next_run_at(cid, next_run_at)
async for workspace_id, at in self.next_run_repo.older_than(now):
if accounts := await self.cloud_account_repo.list_by_workspace_id(workspace_id):
for account in accounts:
await self.trigger_collect(account)
next_run_at = await self.compute_next_run(workspace_id, at)
await self.next_run_repo.update_next_run_at(workspace_id, next_run_at)
else:
log.error(f"A next run was due for workspace {workspace_id}, but no cloud accounts exist for it")
continue
29 changes: 19 additions & 10 deletions fixbackend/dispatcher/next_run_repository.py
@@ -19,41 +19,50 @@
from sqlalchemy.orm import Mapped, mapped_column

from fixbackend.base_model import Base
from fixbackend.ids import CloudAccountId
from fixbackend.ids import CloudAccountId, WorkspaceId
from fixbackend.sqlalechemy_extensions import UTCDateTime
from fixbackend.types import AsyncSessionMaker


# deprecated, do not use
class NextRun(Base):
__tablename__ = "next_run"

cloud_account_id: Mapped[CloudAccountId] = mapped_column(GUID, primary_key=True)
at: Mapped[datetime] = mapped_column(UTCDateTime, nullable=False, index=True)


class NextTenantRun(Base):
__tablename__ = "next_tenant_run"

tenant_id: Mapped[WorkspaceId] = mapped_column(GUID, primary_key=True)
at: Mapped[datetime] = mapped_column(UTCDateTime, nullable=False, index=True)


class NextRunRepository:
def __init__(self, session_maker: AsyncSessionMaker) -> None:
self.session_maker = session_maker

async def create(self, cid: CloudAccountId, next_run: datetime) -> None:
async def create(self, workspace_id: WorkspaceId, next_run: datetime) -> None:
async with self.session_maker() as session:
session.add(NextRun(cloud_account_id=cid, at=next_run))
next_tenant_run = NextTenantRun(tenant_id=workspace_id, at=next_run)
session.add(next_tenant_run)
await session.commit()

async def update_next_run_at(self, cid: CloudAccountId, next_run: datetime) -> None:
async def update_next_run_at(self, workspace_id: WorkspaceId, next_run: datetime) -> None:
async with self.session_maker() as session:
if nxt := await session.get(NextRun, cid):
if nxt := await session.get(NextTenantRun, workspace_id):
nxt.at = next_run
await session.commit()

async def delete(self, cid: CloudAccountId) -> None:
async def delete(self, workspace_id: WorkspaceId) -> None:
async with self.session_maker() as session:
results = await session.execute(select(NextRun).where(NextRun.cloud_account_id == cid))
results = await session.execute(select(NextTenantRun).where(NextTenantRun.tenant_id == workspace_id))
if run := results.unique().scalar():
await session.delete(run)
await session.commit()

async def older_than(self, at: datetime) -> AsyncIterator[Tuple[CloudAccountId, datetime]]:
async def older_than(self, at: datetime) -> AsyncIterator[Tuple[WorkspaceId, datetime]]:
async with self.session_maker() as session:
async for (entry,) in await session.stream(select(NextRun).where(NextRun.at < at)):
yield entry.cloud_account_id, entry.at
async for (entry,) in await session.stream(select(NextTenantRun).where(NextTenantRun.at < at)):
yield entry.tenant_id, entry.at
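
The scheduling key changes from cloud account to workspace: the old `next_run` table stays behind (marked deprecated above) and `next_tenant_run` holds exactly one row per tenant. A standalone sketch of the `older_than` streaming pattern against an in-memory SQLite database; it assumes `sqlalchemy>=2.0` plus `aiosqlite`, and re-declares a local `NextTenantRun` so it runs on its own:

```python
import asyncio
import uuid
from datetime import datetime, timedelta, timezone

from sqlalchemy import DateTime, Uuid, select
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy.pool import StaticPool


class Base(DeclarativeBase):
    pass


class NextTenantRun(Base):
    __tablename__ = "next_tenant_run"
    tenant_id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True)
    at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True)


async def main() -> None:
    # StaticPool keeps the single in-memory database alive across connections.
    engine = create_async_engine("sqlite+aiosqlite://", poolclass=StaticPool)
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)

    session_maker = async_sessionmaker(engine, expire_on_commit=False)
    now = datetime.now(timezone.utc)
    async with session_maker() as session:
        session.add(NextTenantRun(tenant_id=uuid.uuid4(), at=now - timedelta(hours=2)))  # due
        session.add(NextTenantRun(tenant_id=uuid.uuid4(), at=now + timedelta(hours=1)))  # not yet
        await session.commit()

    # older_than: stream only the tenants whose next run is due
    async with session_maker() as session:
        async for (entry,) in await session.stream(select(NextTenantRun).where(NextTenantRun.at < now)):
            print(entry.tenant_id, entry.at)

    await engine.dispose()


asyncio.run(main())
```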
17 changes: 16 additions & 1 deletion fixbackend/domain_events/events.py
@@ -21,7 +21,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.


from typing import ClassVar
from typing import ClassVar, List
from attrs import frozen
from abc import ABC, abstractmethod
from fixbackend.ids import UserId, WorkspaceId, CloudAccountId
@@ -75,3 +75,18 @@ def to_json(self) -> Json:
@staticmethod
def from_json(json: Json) -> "AwsAccountDiscovered":
return converter.structure(json, AwsAccountDiscovered)


@frozen
class TenantAccountsCollected(Event):
kind: ClassVar[str] = "tenant_accounts_collected"

tenant_id: WorkspaceId
cloud_account_ids: List[CloudAccountId]

def to_json(self) -> Json:
return converter.unstructure(self) # type: ignore

@staticmethod
def from_json(json: Json) -> "TenantAccountsCollected":
return converter.structure(json, TenantAccountsCollected)
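
Like the other domain events, `TenantAccountsCollected` round-trips through a `cattrs` converter. A self-contained sketch of the pattern; the UUID hooks here are assumptions for the snippet, and fixbackend's shared `converter` may configure them differently:

```python
import uuid
from typing import ClassVar, List

import cattrs
from attrs import frozen

converter = cattrs.Converter()
converter.register_unstructure_hook(uuid.UUID, str)  # assumed hook for the sketch
converter.register_structure_hook(uuid.UUID, lambda v, _: uuid.UUID(v))


@frozen
class TenantAccountsCollected:
    kind: ClassVar[str] = "tenant_accounts_collected"  # ClassVar: not a cattrs field

    tenant_id: uuid.UUID
    cloud_account_ids: List[uuid.UUID]


event = TenantAccountsCollected(uuid.uuid4(), [uuid.uuid4(), uuid.uuid4()])
payload = converter.unstructure(event)  # plain dict with stringified UUIDs
restored = converter.structure(payload, TenantAccountsCollected)
assert restored == event
```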
@@ -0,0 +1,31 @@
"""add the next_tenant_run table

Revision ID: cd9ef4e05fdd
Revises: 9b482c179740
Create Date: 2023-10-13 11:23:36.653345+00:00

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from fastapi_users_db_sqlalchemy.generics import GUID
from fixbackend.sqlalechemy_extensions import UTCDateTime

# revision identifiers, used by Alembic.
revision: str = "cd9ef4e05fdd"
down_revision: Union[str, None] = "9b482c179740"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"next_tenant_run",
sa.Column("tenant_id", GUID, nullable=False),
sa.Column("at", UTCDateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint("tenant_id"),
)
op.create_index(op.f("ix_next_tenant_run_at"), "next_tenant_run", ["at"], unique=False)
# ### end Alembic commands ###
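
The migration as shown defines only `upgrade`. A hedged sketch of the matching `downgrade`, should one be needed, mirroring the create calls above with standard Alembic ops:

```python
def downgrade() -> None:
    # Reverse of upgrade(): drop the index first, then the table.
    op.drop_index(op.f("ix_next_tenant_run_at"), table_name="next_tenant_run")
    op.drop_table("next_tenant_run")
```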