Skip to content

Commit

Permalink
--only-account-ids
Browse files Browse the repository at this point in the history
  • Loading branch information
squeaky-pl committed Nov 14, 2024
1 parent d8c2160 commit 5eebb94
Showing 1 changed file with 29 additions and 12 deletions.
41 changes: 29 additions & 12 deletions bin/recategorize-messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,32 @@


def get_namespace_query(
entities: list, *, only_account_id: int | None, only_types: set[AccountType]
entities: list,
*,
only_account_ids: Iterable[int] | None,
only_types: set[AccountType],
) -> Query:
discriminators = {account_type + "account" for account_type in only_types}
namespace_query = (
Query(entities)
.join(Namespace.account)
.filter(Account.discriminator.in_(discriminators))
)
if only_account_id:
if only_account_ids is not None:
namespace_query = namespace_query.filter(
Namespace.account_id == only_account_id
Namespace.account_id.in_(only_account_ids)
)

return namespace_query


def get_total_namespace_count(
*, only_account_id: int | None, only_types: set[AccountType] = ALL_ACCOUNT_TYPES
*,
only_account_ids: Iterable[int] | None,
only_types: set[AccountType] = ALL_ACCOUNT_TYPES,
) -> int:
namespace_query = get_namespace_query(
[Namespace], only_account_id=only_account_id, only_types=only_types
[Namespace], only_account_ids=only_account_ids, only_types=only_types
)

with global_session_scope() as session:
Expand All @@ -49,15 +54,15 @@ def get_total_namespace_count(

def yield_account_id_and_message_ids(
*,
only_account_id: int | None,
only_account_ids: Iterable[int] | None,
date_start: datetime.date | None,
date_end: datetime.date | None,
only_inbox: bool,
only_types: set[AccountType] = ALL_ACCOUNT_TYPES,
) -> Iterable[int, list[int]]:
namespace_query = get_namespace_query(
[Namespace.account_id, Namespace.id],
only_account_id=only_account_id,
only_account_ids=only_account_ids,
only_types=only_types,
)

Expand All @@ -84,16 +89,28 @@ def yield_account_id_and_message_ids(
yield account_id, message_ids


def split_integers_separated_by_common(
ctx, param, comma_separated_value
) -> list[int] | None:
if comma_separated_value is not None:
return [int(value) for value in comma_separated_value.split(",")]


@click.command()
@click.option("--date-start", type=click.DateTime(formats=["%Y-%m-%d"]), default=None)
@click.option("--date-end", type=click.DateTime(formats=["%Y-%m-%d"]), default=None)
@click.option("--only-account-id", type=int, default=None)
@click.option(
"--only-account-ids",
type=str,
default=None,
callback=split_integers_separated_by_common,
)
@click.option("--only-inbox", is_flag=True, default=False)
@click.option("--only-types", default=",".join(ALL_ACCOUNT_TYPES))
@click.option("--only-categories", default=None)
@click.option("--dry-run/--no-dry-run", default=True)
def main(
only_account_id: int | None,
only_account_ids: list[int] | None,
only_inbox: bool,
only_types: str,
only_categories: str | None,
Expand All @@ -102,11 +119,11 @@ def main(
dry_run: bool,
) -> None:
print(
f"Settings: {only_account_id=}, {only_inbox=}, {date_start=}, {date_end=}, {dry_run=}\n"
f"Settings: {only_account_ids=}, {only_inbox=}, {only_categories=}, {date_start=}, {date_end=}, {dry_run=}\n"
)

total_namespace_count = get_total_namespace_count(
only_account_id=only_account_id, only_types=set(only_types.split(","))
only_account_ids=only_account_ids, only_types=set(only_types.split(","))
)
print(f"{total_namespace_count=}\n")

Expand All @@ -115,7 +132,7 @@ def session_factory():

for progress, (account_id, message_ids) in enumerate(
yield_account_id_and_message_ids(
only_account_id=only_account_id,
only_account_ids=only_account_ids,
date_start=date_start,
date_end=date_end,
only_inbox=only_inbox,
Expand Down

0 comments on commit 5eebb94

Please sign in to comment.