diff --git a/src/saturn_engine/core/__init__.py b/src/saturn_engine/core/__init__.py index 2fd1b7a0..1a8ba2e2 100644 --- a/src/saturn_engine/core/__init__.py +++ b/src/saturn_engine/core/__init__.py @@ -6,6 +6,9 @@ from .pipeline import ResourceUsed from .resource import Resource from .topic import TopicMessage +from .types import Cursor +from .types import JobId +from .types import MessageId __all__ = [ "PipelineInfo", @@ -16,4 +19,7 @@ "Resource", "ResourceUsed", "TopicMessage", + "MessageId", + "Cursor", + "JobId", ] diff --git a/src/saturn_engine/core/api.py b/src/saturn_engine/core/api.py index 3d8dddbe..49de5acb 100644 --- a/src/saturn_engine/core/api.py +++ b/src/saturn_engine/core/api.py @@ -9,8 +9,10 @@ from pydantic import BaseModel from pydantic import dataclasses -from saturn_engine.core import PipelineInfo # noqa: F401 # Reexport for public API -from saturn_engine.core import QueuePipeline +from .pipeline import PipelineInfo # noqa: F401 # Reexport for public API +from .pipeline import QueuePipeline +from .types import Cursor +from .types import JobId T = TypeVar("T") @@ -102,16 +104,16 @@ class InventoriesResponse(ListResponse[ComponentDefinition]): @dataclasses.dataclass class JobItem: - name: str + name: JobId started_at: datetime completed_at: Optional[datetime] = None - cursor: Optional[str] = None + cursor: Optional[Cursor] = None error: Optional[str] = None @dataclasses.dataclass class JobInput: - cursor: Optional[str] = None + cursor: Optional[Cursor] = None completed_at: Optional[datetime] = None error: Optional[str] = None diff --git a/src/saturn_engine/core/topic.py b/src/saturn_engine/core/topic.py index cdb57706..62e0bd43 100644 --- a/src/saturn_engine/core/topic.py +++ b/src/saturn_engine/core/topic.py @@ -4,6 +4,8 @@ import dataclasses import uuid +from .types import MessageId + @dataclasses.dataclass class TopicMessage: @@ -11,7 +13,9 @@ class TopicMessage: args: dict[str, Optional[Any]] #: Unique message id. - id: str = dataclasses.field(default_factory=lambda: str(uuid.uuid4())) + id: MessageId = dataclasses.field( + default_factory=lambda: MessageId(str(uuid.uuid4())) + ) #: Tags to attach to observability (logging, events, metrics and tracing). tags: dict[str, str] = dataclasses.field(default_factory=dict) diff --git a/src/saturn_engine/core/types.py b/src/saturn_engine/core/types.py new file mode 100644 index 00000000..f92bebb6 --- /dev/null +++ b/src/saturn_engine/core/types.py @@ -0,0 +1,5 @@ +import typing as t + +JobId = t.NewType("JobId", str) +MessageId = t.NewType("MessageId", str) +Cursor = t.NewType("Cursor", str) diff --git a/src/saturn_engine/models/job.py b/src/saturn_engine/models/job.py index 3b739282..919acd4c 100644 --- a/src/saturn_engine/models/job.py +++ b/src/saturn_engine/models/job.py @@ -10,6 +10,8 @@ from sqlalchemy.orm import relationship from sqlalchemy.sql.sqltypes import Text +from saturn_engine.core import Cursor +from saturn_engine.core import JobId from saturn_engine.core.api import JobItem from saturn_engine.utils import utcnow @@ -58,10 +60,10 @@ def __init__( def as_core_item(self) -> JobItem: return JobItem( - name=self.name, + name=JobId(self.name), completed_at=self.completed_at, started_at=self.started_at, - cursor=self.cursor, + cursor=Cursor(self.cursor) if self.cursor else None, error=self.error, ) diff --git a/src/saturn_engine/utils/tester/config/inventory_test.py b/src/saturn_engine/utils/tester/config/inventory_test.py index 5b84f14b..08fe9047 100644 --- a/src/saturn_engine/utils/tester/config/inventory_test.py +++ b/src/saturn_engine/utils/tester/config/inventory_test.py @@ -2,6 +2,7 @@ from pydantic import dataclasses +from saturn_engine.core import Cursor from saturn_engine.utils.declarative_config import BaseObject from saturn_engine.worker.inventories import Item @@ -16,7 +17,7 @@ class InventoryTestSpec: selector: InventorySelector items: list[Item] limit: Optional[int] = None - after: Optional[str] = None + after: Optional[Cursor] = None @dataclasses.dataclass diff --git a/src/saturn_engine/utils/tester/inventory_test.py b/src/saturn_engine/utils/tester/inventory_test.py index 0c4714e0..918344d2 100644 --- a/src/saturn_engine/utils/tester/inventory_test.py +++ b/src/saturn_engine/utils/tester/inventory_test.py @@ -3,6 +3,7 @@ import asyncio from saturn_engine.config import default_config_with_env +from saturn_engine.core import Cursor from saturn_engine.utils.options import asdict from saturn_engine.worker import work_factory from saturn_engine.worker.services.manager import ServicesManager @@ -17,7 +18,7 @@ def run_saturn_inventory( static_definitions: StaticDefinitions, inventory_name: str, limit: Optional[int] = None, - after: Optional[str] = None, + after: Optional[Cursor] = None, ) -> list[dict]: inventory_item = static_definitions.inventories[inventory_name] inventory = work_factory.build_inventory( diff --git a/src/saturn_engine/utils/tester/runner.py b/src/saturn_engine/utils/tester/runner.py index 177443fe..ba83ff9c 100644 --- a/src/saturn_engine/utils/tester/runner.py +++ b/src/saturn_engine/utils/tester/runner.py @@ -8,6 +8,7 @@ import click +from saturn_engine.core import Cursor from saturn_engine.utils.declarative_config import UncompiledObject from saturn_engine.utils.declarative_config import load_uncompiled_objects_from_path from saturn_engine.utils.options import fromdict @@ -154,7 +155,9 @@ def run( @click.option("--name", type=str, required=True) @click.option("--limit", type=int, required=True, default=1) @click.option("--after", type=str, required=False) -def show_inventory(topology: str, name: str, limit: int, after: Optional[str]) -> None: +def show_inventory( + topology: str, name: str, limit: int, after: Optional[Cursor] +) -> None: static_definitions = compile_static_definitions( load_uncompiled_objects_from_path(topology), ) diff --git a/src/saturn_engine/worker/inventories/batching.py b/src/saturn_engine/worker/inventories/batching.py index 25226b0d..e62b8e5d 100644 --- a/src/saturn_engine/worker/inventories/batching.py +++ b/src/saturn_engine/worker/inventories/batching.py @@ -5,6 +5,7 @@ import asyncstdlib as alib +from saturn_engine.core import Cursor from saturn_engine.core.api import ComponentDefinition from saturn_engine.worker.services import Services @@ -28,13 +29,13 @@ def __init__( self.inventory = build_inventory(options.inventory, services=services) - async def next_batch(self, after: Optional[str] = None) -> list[Item]: + async def next_batch(self, after: Optional[Cursor] = None) -> list[Item]: batch: list[Item] = await alib.list( alib.islice(self.inventory.iterate(after=after), self.batch_size) ) return batch - async def iterate(self, after: Optional[str] = None) -> AsyncIterator[Item]: + async def iterate(self, after: Optional[Cursor] = None) -> AsyncIterator[Item]: while True: batch = await self.next_batch(after) if not batch: diff --git a/src/saturn_engine/worker/inventories/chained.py b/src/saturn_engine/worker/inventories/chained.py index 5d20565d..11567137 100644 --- a/src/saturn_engine/worker/inventories/chained.py +++ b/src/saturn_engine/worker/inventories/chained.py @@ -1,5 +1,7 @@ from collections.abc import AsyncIterator +from saturn_engine.core import Cursor + from . import Inventory from .multi import MultiInventory from .multi import MultiItems @@ -7,7 +9,7 @@ class ChainedInventory(MultiInventory): async def inventories_iterator( - self, *, inventories: list[tuple[str, Inventory]], after: dict[str, str] + self, *, inventories: list[tuple[str, Inventory]], after: dict[str, Cursor] ) -> AsyncIterator[MultiItems]: start_inventory = 0 if after: diff --git a/src/saturn_engine/worker/inventories/dummy.py b/src/saturn_engine/worker/inventories/dummy.py index 8c0e1570..5664acab 100644 --- a/src/saturn_engine/worker/inventories/dummy.py +++ b/src/saturn_engine/worker/inventories/dummy.py @@ -2,6 +2,9 @@ import dataclasses +from saturn_engine.core import Cursor +from saturn_engine.core import MessageId + from . import Inventory from . import Item @@ -14,9 +17,9 @@ class Options: def __init__(self, options: Options, **kwrags: object) -> None: self.count = options.count or 1000 - async def next_batch(self, after: Optional[str] = None) -> list[Item]: + async def next_batch(self, after: Optional[Cursor] = None) -> list[Item]: n = int(after) + 1 if after is not None else 0 n_end = min(n + 100, self.count) if n_end == n: return [] - return [Item(id=str(i), args={"n": i}) for i in range(n, n_end)] + return [Item(id=MessageId(str(i)), args={"n": i}) for i in range(n, n_end)] diff --git a/src/saturn_engine/worker/inventories/joined.py b/src/saturn_engine/worker/inventories/joined.py index 73512844..cac4be16 100644 --- a/src/saturn_engine/worker/inventories/joined.py +++ b/src/saturn_engine/worker/inventories/joined.py @@ -1,5 +1,7 @@ from collections.abc import AsyncIterator +from saturn_engine.core import Cursor + from . import Inventory from .multi import MultiInventory from .multi import MultiItems @@ -7,7 +9,7 @@ class JoinedInventory(MultiInventory): async def inventories_iterator( - self, *, inventories: list[tuple[str, Inventory]], after: dict[str, str] + self, *, inventories: list[tuple[str, Inventory]], after: dict[str, Cursor] ) -> AsyncIterator[MultiItems]: name, inventory = inventories[0] last_cursor = after.pop(name, None) diff --git a/src/saturn_engine/worker/inventories/joined_sub.py b/src/saturn_engine/worker/inventories/joined_sub.py index 766804f3..abc0bdee 100644 --- a/src/saturn_engine/worker/inventories/joined_sub.py +++ b/src/saturn_engine/worker/inventories/joined_sub.py @@ -5,6 +5,8 @@ import json from collections.abc import AsyncIterator +from saturn_engine.core import Cursor +from saturn_engine.core import MessageId from saturn_engine.core.api import ComponentDefinition from saturn_engine.worker.services import Services @@ -42,7 +44,7 @@ def __init__(self, options: Options, services: Services, **kwargs: object) -> No options.sub_inventory, services=services ) - async def iterate(self, after: t.Optional[str] = None) -> AsyncIterator[Item]: + async def iterate(self, after: t.Optional[Cursor] = None) -> AsyncIterator[Item]: cursors = json.loads(after) if after else {} inventory_cursor = cursors.get(self.inventory_name) sub_inventory_cursor = cursors.get(self.sub_inventory_name) @@ -67,25 +69,29 @@ async def iterate(self, after: t.Optional[str] = None) -> AsyncIterator[Item]: } yield Item( - id=json.dumps( - { - self.inventory_name: item.id, - self.sub_inventory_name: sub_item.id, - } + id=MessageId( + json.dumps( + { + self.inventory_name: item.id, + self.sub_inventory_name: sub_item.id, + } + ) ), - cursor=json.dumps( - { - **( - {self.inventory_name: inventory_cursor} - if inventory_cursor - else {} - ), - **( - {self.sub_inventory_name: sub_item.cursor} - if sub_item.cursor - else {} - ), - } + cursor=Cursor( + json.dumps( + { + **( + {self.inventory_name: inventory_cursor} + if inventory_cursor + else {} + ), + **( + {self.sub_inventory_name: sub_item.cursor} + if sub_item.cursor + else {} + ), + } + ) ), args=args, ) diff --git a/src/saturn_engine/worker/inventories/multi.py b/src/saturn_engine/worker/inventories/multi.py index 8465a303..001820b1 100644 --- a/src/saturn_engine/worker/inventories/multi.py +++ b/src/saturn_engine/worker/inventories/multi.py @@ -7,6 +7,8 @@ import json from collections.abc import AsyncIterator +from saturn_engine.core import Cursor +from saturn_engine.core import MessageId from saturn_engine.core.api import ComponentDefinition from saturn_engine.worker.services import Services @@ -48,7 +50,7 @@ def __init__(self, options: Options, services: Services, **kwargs: object) -> No (inventory.name, build_inventory(inventory, services=services)) ) - async def iterate(self, after: Optional[str] = None) -> AsyncIterator[Item]: + async def iterate(self, after: Optional[Cursor] = None) -> AsyncIterator[Item]: cursors = json.loads(after) if after else {} async for item in self.inventories_iterator( @@ -64,12 +66,14 @@ async def iterate(self, after: Optional[str] = None) -> AsyncIterator[Item]: self.alias: args, } yield Item( - id=json.dumps(item.ids), cursor=json.dumps(item.cursors), args=args + id=MessageId(json.dumps(item.ids)), + cursor=Cursor(json.dumps(item.cursors)), + args=args, ) @abc.abstractmethod async def inventories_iterator( - self, *, inventories: list[tuple[str, Inventory]], after: dict[str, str] + self, *, inventories: list[tuple[str, Inventory]], after: dict[str, Cursor] ) -> AsyncIterator[MultiItems]: raise NotImplementedError() yield diff --git a/src/saturn_engine/worker/inventories/static.py b/src/saturn_engine/worker/inventories/static.py index e5d18c43..4515b12c 100644 --- a/src/saturn_engine/worker/inventories/static.py +++ b/src/saturn_engine/worker/inventories/static.py @@ -2,6 +2,9 @@ import dataclasses +from saturn_engine.core import Cursor +from saturn_engine.core import MessageId + from . import Inventory from . import Item @@ -14,9 +17,9 @@ class Options: def __init__(self, options: Options, **kwargs: object) -> None: self.items = options.items - async def next_batch(self, after: t.Optional[str] = None) -> list[Item]: + async def next_batch(self, after: t.Optional[Cursor] = None) -> list[Item]: begin = int(after) + 1 if after else 0 return [ - Item(id=str(i), args=args) + Item(id=MessageId(str(i)), args=args) for i, args in enumerate(self.items[begin:], start=begin) ] diff --git a/src/saturn_engine/worker/inventory.py b/src/saturn_engine/worker/inventory.py index 228637b5..8e6f122e 100644 --- a/src/saturn_engine/worker/inventory.py +++ b/src/saturn_engine/worker/inventory.py @@ -11,6 +11,8 @@ import asyncstdlib as alib +from saturn_engine.core import Cursor +from saturn_engine.core import MessageId from saturn_engine.utils.options import OptionsSchema MISSING = object() @@ -19,14 +21,16 @@ @dataclasses.dataclass class Item: args: dict[str, t.Any] - id: str = dataclasses.field(default_factory=lambda: str(uuid.uuid4())) - cursor: t.Optional[str] = MISSING # type: ignore[assignment] + id: MessageId = dataclasses.field( + default_factory=lambda: MessageId(str(uuid.uuid4())) + ) + cursor: t.Optional[Cursor] = MISSING # type: ignore[assignment] tags: dict[str, str] = dataclasses.field(default_factory=dict) metadata: dict[str, t.Any] = dataclasses.field(default_factory=dict) def __post_init__(self) -> None: if self.cursor is MISSING: - self.cursor = self.id + self.cursor = Cursor(self.id) class MaxRetriesError(Exception): @@ -52,11 +56,11 @@ class Inventory(abc.ABC, OptionsSchema): name: str @abc.abstractmethod - async def next_batch(self, after: t.Optional[str] = None) -> list[Item]: + async def next_batch(self, after: t.Optional[Cursor] = None) -> list[Item]: """Returns a batch of item with id greater than `after`.""" raise NotImplementedError() - async def iterate(self, after: t.Optional[str] = None) -> AsyncIterator[Item]: + async def iterate(self, after: t.Optional[Cursor] = None) -> AsyncIterator[Item]: """Returns an iterable that goes over the whole inventory.""" retries_count = 0 while True: @@ -87,20 +91,20 @@ class IteratorInventory(Inventory): def __init__(self, *, batch_size: t.Optional[int] = None, **kwargs: object) -> None: self.batch_size = batch_size or 10 - async def next_batch(self, after: t.Optional[str] = None) -> list[Item]: + async def next_batch(self, after: t.Optional[Cursor] = None) -> list[Item]: batch: list[Item] = await alib.list( alib.islice(self.iterate(after=after), self.batch_size) ) return batch @abc.abstractmethod - async def iterate(self, after: t.Optional[str] = None) -> AsyncIterator[Item]: + async def iterate(self, after: t.Optional[Cursor] = None) -> AsyncIterator[Item]: raise NotImplementedError() yield class BlockingInventory(Inventory, abc.ABC): - async def next_batch(self, after: t.Optional[str] = None) -> list[Item]: + async def next_batch(self, after: t.Optional[Cursor] = None) -> list[Item]: return await asyncio.get_event_loop().run_in_executor( None, self.next_batch_blocking, @@ -108,20 +112,20 @@ async def next_batch(self, after: t.Optional[str] = None) -> list[Item]: ) @abc.abstractmethod - def next_batch_blocking(self, after: t.Optional[str] = None) -> list[Item]: + def next_batch_blocking(self, after: t.Optional[Cursor] = None) -> list[Item]: raise NotImplementedError() class SubInventory(abc.ABC, OptionsSchema): @abc.abstractmethod async def next_batch( - self, source_item: Item, after: t.Optional[str] = None + self, source_item: Item, after: t.Optional[Cursor] = None ) -> list[Item]: """Returns a batch of item with id greater than `after`.""" raise NotImplementedError() async def iterate( - self, source_item: Item, after: t.Optional[str] = None + self, source_item: Item, after: t.Optional[Cursor] = None ) -> AsyncIterator[Item]: """Returns an iterable that goes over the whole inventory.""" retries_count = 0 @@ -153,7 +157,7 @@ def logger(self) -> logging.Logger: class BlockingSubInventory(SubInventory, abc.ABC): async def next_batch( - self, source_item: Item, after: t.Optional[str] = None + self, source_item: Item, after: t.Optional[Cursor] = None ) -> list[Item]: return await asyncio.get_event_loop().run_in_executor( None, @@ -164,6 +168,6 @@ async def next_batch( @abc.abstractmethod def next_batch_blocking( - self, source_item: Item, after: t.Optional[str] = None + self, source_item: Item, after: t.Optional[Cursor] = None ) -> list[Item]: raise NotImplementedError() diff --git a/src/saturn_engine/worker/job/__init__.py b/src/saturn_engine/worker/job/__init__.py index c9b916b0..8dde5a4b 100644 --- a/src/saturn_engine/worker/job/__init__.py +++ b/src/saturn_engine/worker/job/__init__.py @@ -5,6 +5,8 @@ import asyncstdlib as alib +from saturn_engine.core import Cursor +from saturn_engine.core import MessageId from saturn_engine.core import TopicMessage from saturn_engine.utils.log import getLogger from saturn_engine.utils.options import OptionsSchema @@ -16,11 +18,11 @@ class JobStore(OptionsSchema, abc.ABC): @abc.abstractmethod - async def load_cursor(self) -> Optional[str]: + async def load_cursor(self) -> Optional[Cursor]: pass @abc.abstractmethod - async def save_cursor(self, *, after: str) -> None: + async def save_cursor(self, *, after: Cursor) -> None: pass @abc.abstractmethod @@ -54,7 +56,7 @@ async def run(self) -> AsyncGenerator[TopicOutput, None]: after = item.cursor done = False message = TopicMessage( - id=str(item.id), + id=MessageId(item.id), args=item.args, tags=item.tags, metadata=item.metadata, diff --git a/src/saturn_engine/worker/job/api.py b/src/saturn_engine/worker/job/api.py index 7d1e8db2..c20a07b6 100644 --- a/src/saturn_engine/worker/job/api.py +++ b/src/saturn_engine/worker/job/api.py @@ -3,6 +3,7 @@ import aiohttp +from saturn_engine.core import Cursor from saturn_engine.core.api import JobInput from saturn_engine.core.api import JobResponse from saturn_engine.utils import urlcat @@ -29,18 +30,18 @@ def __init__( self.base_url = base_url self.logger = getLogger(__name__, self) - self.after: Optional[str] = None + self.after: Optional[Cursor] = None self.throttle_save_cursor = DelayedThrottle(self.delayed_save_cursor, delay=1) - async def load_cursor(self) -> Optional[str]: + async def load_cursor(self) -> Optional[Cursor]: async with self.http_client.get(self.job_url) as response: return fromdict(await response.json(), JobResponse).data.cursor - async def save_cursor(self, *, after: str) -> None: + async def save_cursor(self, *, after: Cursor) -> None: self.after = after await self.throttle_save_cursor(after=self.after) - async def delayed_save_cursor(self, *, after: str) -> None: + async def delayed_save_cursor(self, *, after: Cursor) -> None: try: json = asdict(JobInput(cursor=after)) async with self.http_client.put(self.job_url, json=json) as response: diff --git a/src/saturn_engine/worker/job/memory.py b/src/saturn_engine/worker/job/memory.py index 0fd80fa0..6bb31407 100644 --- a/src/saturn_engine/worker/job/memory.py +++ b/src/saturn_engine/worker/job/memory.py @@ -1,19 +1,21 @@ from typing import Any from typing import Optional +from saturn_engine.core import Cursor + from . import JobStore class MemoryJobStore(JobStore): def __init__(self, **kwargs: Any) -> None: - self.after: Optional[str] = None + self.after: Optional[Cursor] = None self.completed = False self.error: Optional[Exception] = None - async def load_cursor(self) -> Optional[str]: + async def load_cursor(self) -> Optional[Cursor]: return self.after - async def save_cursor(self, *, after: str) -> None: + async def save_cursor(self, *, after: Cursor) -> None: self.after = after async def set_completed(self) -> None: diff --git a/src/saturn_engine/worker/topics/periodic.py b/src/saturn_engine/worker/topics/periodic.py index 8757fa0e..28bba53a 100644 --- a/src/saturn_engine/worker/topics/periodic.py +++ b/src/saturn_engine/worker/topics/periodic.py @@ -5,6 +5,7 @@ from croniter import croniter +from saturn_engine.core import MessageId from saturn_engine.core import TopicMessage from . import Topic @@ -27,7 +28,7 @@ async def run(self) -> AsyncGenerator[TopicMessage, None]: wait_time = next_tick - time.time() if wait_time > 0: await asyncio.sleep(wait_time) - yield TopicMessage(id=str(next_tick), args={}) + yield TopicMessage(id=MessageId(str(next_tick)), args={}) async def publish(self, message: TopicMessage, wait: bool) -> bool: raise ValueError("Cannot publish on periodic topic") diff --git a/tests/worker/inventories/test_batching_inventory.py b/tests/worker/inventories/test_batching_inventory.py index 91f9a7c3..92e0ba9a 100644 --- a/tests/worker/inventories/test_batching_inventory.py +++ b/tests/worker/inventories/test_batching_inventory.py @@ -1,6 +1,7 @@ import asyncstdlib as alib import pytest +from saturn_engine.core import Cursor from saturn_engine.worker.inventories.batching import BatchingInventory @@ -33,7 +34,7 @@ async def test_batching_inventory() -> None: ), ] - items = await alib.list(inventory.iterate(after="4")) + items = await alib.list(inventory.iterate(after=Cursor("4"))) assert [(i.id, i.args) for i in items] == [ ("7", {"batch": [{"a": "5"}, {"a": "6"}, {"a": "7"}]}), diff --git a/tests/worker/inventories/test_blocking_inventory.py b/tests/worker/inventories/test_blocking_inventory.py index 86432335..81a82c88 100644 --- a/tests/worker/inventories/test_blocking_inventory.py +++ b/tests/worker/inventories/test_blocking_inventory.py @@ -2,6 +2,8 @@ import pytest +from saturn_engine.core import Cursor +from saturn_engine.core import MessageId from saturn_engine.worker.inventories import BlockingInventory from saturn_engine.worker.inventories import Item @@ -9,13 +11,13 @@ @pytest.mark.asyncio async def test_blocking_inventory() -> None: class BI(BlockingInventory): - def next_batch_blocking(self, after: Optional[str] = None) -> list[Item]: + def next_batch_blocking(self, after: Optional[Cursor] = None) -> list[Item]: # Don't really block here as we want tests to be fast. # This still tests almost completely that BlockingInventory works. - return [Item(id="66", args={"after": after})] + return [Item(id=MessageId("66"), args={"after": after})] batch = list(await BI.from_options(dict()).next_batch()) - assert batch[0] == Item(id="66", args={"after": None}) + assert batch[0] == Item(id=MessageId("66"), args={"after": None}) - batch = list(await BI.from_options(dict()).next_batch(after="20")) - assert batch[0] == Item(id="66", args={"after": "20"}) + batch = list(await BI.from_options(dict()).next_batch(after=Cursor("20"))) + assert batch[0] == Item(id=MessageId("66"), args={"after": "20"}) diff --git a/tests/worker/inventories/test_chained_inventory.py b/tests/worker/inventories/test_chained_inventory.py index 9ffd4f0d..fcb7645f 100644 --- a/tests/worker/inventories/test_chained_inventory.py +++ b/tests/worker/inventories/test_chained_inventory.py @@ -3,6 +3,7 @@ import asyncstdlib as alib import pytest +from saturn_engine.core import Cursor from saturn_engine.worker.inventories.chained import ChainedInventory @@ -44,11 +45,11 @@ async def test_chained_inventory() -> None: ({"c": "2"}, {"c": "2"}, {"c": {"c": "3"}}), ] - batch = await alib.list(inventory.iterate(after='{"b": "1"}')) + batch = await alib.list(inventory.iterate(after=Cursor('{"b": "1"}'))) assert [(json.loads(i.id), json.loads(i.cursor or ""), i.args) for i in batch] == [ ({"b": "2"}, {"b": "2"}, {"b": {"b": "3"}}), ({"c": "0"}, {"c": "0"}, {"c": {"c": "1"}}), ({"c": "1"}, {"c": "1"}, {"c": {"c": "2"}}), ({"c": "2"}, {"c": "2"}, {"c": {"c": "3"}}), ] - assert not await alib.list(inventory.iterate(after='{"c": "2"}')) + assert not await alib.list(inventory.iterate(after=Cursor('{"c": "2"}'))) diff --git a/tests/worker/inventories/test_inventory.py b/tests/worker/inventories/test_inventory.py index 65bfa9b9..a436d0ba 100644 --- a/tests/worker/inventories/test_inventory.py +++ b/tests/worker/inventories/test_inventory.py @@ -4,6 +4,8 @@ import pytest +from saturn_engine.core import Cursor +from saturn_engine.core import MessageId from saturn_engine.utils import utcnow from saturn_engine.worker.inventory import Inventory from saturn_engine.worker.inventory import Item @@ -20,9 +22,9 @@ def __init__( self.retry_delay = retry_delay self.retrying_count = retrying_count - async def next_batch(self, after: t.Optional[str] = None) -> list[Item]: + async def next_batch(self, after: t.Optional[Cursor] = None) -> list[Item]: if self.retried == self.retrying_count: - return [Item(id="1", args={"x": 1})] + return [Item(id=MessageId("1"), args={"x": 1})] self.retried += 1 raise RetryBatch(delay=self.retry_delay, max_retries=5) diff --git a/tests/worker/inventories/test_joined_inventory.py b/tests/worker/inventories/test_joined_inventory.py index 479ae472..d1487686 100644 --- a/tests/worker/inventories/test_joined_inventory.py +++ b/tests/worker/inventories/test_joined_inventory.py @@ -3,6 +3,7 @@ import asyncstdlib as alib import pytest +from saturn_engine.core import Cursor from saturn_engine.worker.inventories import JoinedInventory @@ -39,14 +40,14 @@ async def test_joined_inventory() -> None: ({"a": "2", "b": "2"}, {"a": "1", "b": "2"}, {"a": {"n": 3}, "b": {"c": "C"}}), ] - batch = await alib.list(inventory.iterate(after='{"a": "0", "b": "1"}')) + batch = await alib.list(inventory.iterate(after=Cursor('{"a": "0", "b": "1"}'))) assert [(json.loads(i.id), json.loads(i.cursor or ""), i.args) for i in batch] == [ ({"a": "1", "b": "2"}, {"a": "0", "b": "2"}, {"a": {"n": 2}, "b": {"c": "C"}}), ({"a": "2", "b": "0"}, {"a": "1", "b": "0"}, {"a": {"n": 3}, "b": {"c": "A"}}), ({"a": "2", "b": "1"}, {"a": "1", "b": "1"}, {"a": {"n": 3}, "b": {"c": "B"}}), ({"a": "2", "b": "2"}, {"a": "1", "b": "2"}, {"a": {"n": 3}, "b": {"c": "C"}}), ] - assert not await alib.list(inventory.iterate(after='{"a": "1", "b": "2"}')) + assert not await alib.list(inventory.iterate(after=Cursor('{"a": "1", "b": "2"}'))) @pytest.mark.asyncio diff --git a/tests/worker/inventories/test_joined_sub_inventory.py b/tests/worker/inventories/test_joined_sub_inventory.py index cc938a52..cdab4ae3 100644 --- a/tests/worker/inventories/test_joined_sub_inventory.py +++ b/tests/worker/inventories/test_joined_sub_inventory.py @@ -5,6 +5,8 @@ import asyncstdlib as alib import pytest +from saturn_engine.core import Cursor +from saturn_engine.core import MessageId from saturn_engine.worker.inventories import JoinedSubInventory from saturn_engine.worker.inventory import Item from saturn_engine.worker.inventory import SubInventory @@ -14,27 +16,27 @@ class TestSubInventory(SubInventory): __test__ = False async def next_batch( - self, source_item: Item, after: t.Optional[str] = None + self, source_item: Item, after: t.Optional[Cursor] = None ) -> list[Item]: items: list[Item] = [] if source_item.args["a"] == 0: items = [ - Item(args={"b": {"c": 7}}, id="7"), - Item(args={"b": {"c": 8}}, id="8"), - Item(args={"b": {"c": 9}}, id="9"), + Item(args={"b": {"c": 7}}, id=MessageId("7")), + Item(args={"b": {"c": 8}}, id=MessageId("8")), + Item(args={"b": {"c": 9}}, id=MessageId("9")), ] elif source_item.args["a"] == 1: items = [ - Item(args={"b": {"c": 4}}, id="4"), - Item(args={"b": {"c": 5}}, id="5"), - Item(args={"b": {"c": 6}}, id="6"), + Item(args={"b": {"c": 4}}, id=MessageId("4")), + Item(args={"b": {"c": 5}}, id=MessageId("5")), + Item(args={"b": {"c": 6}}, id=MessageId("6")), ] elif source_item.args["a"] == 2: items = [ - Item(args={"b": {"c": 1}}, id="1"), - Item(args={"b": {"c": 2}}, id="2"), - Item(args={"b": {"c": 3}}, id="3"), + Item(args={"b": {"c": 1}}, id=MessageId("1")), + Item(args={"b": {"c": 2}}, id=MessageId("2")), + Item(args={"b": {"c": 3}}, id=MessageId("3")), ] return [i for i in items if after is None or int(i.id) > int(after)] @@ -99,7 +101,7 @@ async def test_joined_sub_inventory() -> None: ), ] - batch = await alib.list(inventory.iterate(after='{"a": "0", "b": "6"}')) + batch = await alib.list(inventory.iterate(after=Cursor('{"a": "0", "b": "6"}'))) assert [(json.loads(i.id), json.loads(i.cursor or ""), i.args) for i in batch] == [ ( {"a": "2", "b": "1"}, @@ -117,7 +119,7 @@ async def test_joined_sub_inventory() -> None: {"a": {"a": 2}, "b": {"b": {"c": 3}}}, ), ] - assert not await alib.list(inventory.iterate(after='{"a": "1", "b": "3"}')) + assert not await alib.list(inventory.iterate(after=Cursor('{"a": "1", "b": "3"}'))) @pytest.mark.asyncio diff --git a/tests/worker/inventories/test_static_inventory.py b/tests/worker/inventories/test_static_inventory.py index 3eda1635..1899d781 100644 --- a/tests/worker/inventories/test_static_inventory.py +++ b/tests/worker/inventories/test_static_inventory.py @@ -1,5 +1,7 @@ import pytest +from saturn_engine.core import Cursor +from saturn_engine.core import MessageId from saturn_engine.worker.inventories import Item from saturn_engine.worker.inventories import StaticInventory @@ -9,17 +11,17 @@ async def test_static_inventory() -> None: inventory = StaticInventory.from_options({"items": [{"n": 1}, {"n": 2}]}) batch = list(await inventory.next_batch()) assert batch == [ - Item(id="0", args={"n": 1}), - Item(id="1", args={"n": 2}), + Item(id=MessageId("0"), args={"n": 1}), + Item(id=MessageId("1"), args={"n": 2}), ] inventory = StaticInventory.from_options({"items": [{}] * 9}) assert len(list(await inventory.next_batch())) == 9 - assert len(list(await inventory.next_batch(after="4"))) == 4 - assert not list(await inventory.next_batch(after="8")) + assert len(list(await inventory.next_batch(after=Cursor("4")))) == 4 + assert not list(await inventory.next_batch(after=Cursor("8"))) inventory = StaticInventory.from_options({"items": [{"a": None}]}) batch = list(await inventory.next_batch()) assert batch == [ - Item(id="0", args={"a": None}), + Item(id=MessageId("0"), args={"a": None}), ] diff --git a/tests/worker/job/test_api.py b/tests/worker/job/test_api.py index 36180a34..6fb9c1be 100644 --- a/tests/worker/job/test_api.py +++ b/tests/worker/job/test_api.py @@ -4,6 +4,7 @@ import pytest +from saturn_engine.core import Cursor from saturn_engine.worker.job.api import ApiJobStore from tests.conftest import FreezeTime from tests.utils import HttpClientMock @@ -41,11 +42,11 @@ async def test_api_jobstore( assert await job_store.load_cursor() == "10" http_client_mock.put("/api/jobs/test").return_value = {} - await job_store.save_cursor(after="20") + await job_store.save_cursor(after=Cursor("20")) http_client_mock.put("/api/jobs/test").assert_not_called() await asyncio.sleep(0.5) - await job_store.save_cursor(after="30") + await job_store.save_cursor(after=Cursor("30")) http_client_mock.put("/api/jobs/test").assert_not_called() await asyncio.sleep(0.6) @@ -58,7 +59,7 @@ async def test_api_jobstore( http_client_mock.put("/api/jobs/test").assert_not_called() http_client_mock.reset_mock() - await job_store.save_cursor(after="40") + await job_store.save_cursor(after=Cursor("40")) await job_store.set_completed() http_client_mock.put("/api/jobs/test").assert_called_once_with( json={ @@ -72,7 +73,7 @@ async def test_api_jobstore( await asyncio.sleep(1.1) http_client_mock.put("/api/jobs/test").assert_not_called() - await job_store.save_cursor(after="50") + await job_store.save_cursor(after=Cursor("50")) await job_store.set_failed(ValueError("test")) http_client_mock.put("/api/jobs/test").assert_called_once_with( json={ diff --git a/tests/worker/services/test_logger.py b/tests/worker/services/test_logger.py index ce81797b..9d0f1682 100644 --- a/tests/worker/services/test_logger.py +++ b/tests/worker/services/test_logger.py @@ -7,6 +7,7 @@ import pytest from saturn_engine.config import Config +from saturn_engine.core import MessageId from saturn_engine.core import PipelineInfo from saturn_engine.core import PipelineOutput from saturn_engine.core import PipelineResults @@ -53,7 +54,7 @@ async def test_logger_message_executed( pipeline_info = PipelineInfo.from_pipeline(fake_pipeline) xmsg = executable_maker(pipeline_info=pipeline_info) - xmsg.message.message = TopicMessage(id="m1", args={"x": 42}) + xmsg.message.message = TopicMessage(id=MessageId("m1"), args={"x": 42}) xmsg.message.update_with_resources( {FakeResource._typename(): {"name": "r1", "data": "foobar"}} ) @@ -63,7 +64,8 @@ async def test_logger_message_executed( results = PipelineResults( outputs=[ PipelineOutput( - channel="default", message=TopicMessage(id="m2", args={"foo": "bar"}) + channel="default", + message=TopicMessage(id=MessageId("m2"), args={"foo": "bar"}), ) ], resources=[ResourceUsed(type=FakeResource._typename(), release_at=10)], diff --git a/tests/worker/test_broker.py b/tests/worker/test_broker.py index 5805892b..ec157423 100644 --- a/tests/worker/test_broker.py +++ b/tests/worker/test_broker.py @@ -5,6 +5,7 @@ import pytest from saturn_engine.config import Config +from saturn_engine.core import JobId from saturn_engine.core import PipelineResults from saturn_engine.core import api from saturn_engine.core.api import ComponentDefinition @@ -89,7 +90,7 @@ async def test_broker_dummy( worker_manager_client.lock.return_value = LockResponse( items=[ QueueItem( - name="j1", + name=JobId("j1"), input=ComponentDefinition( name="dummy", type="DummyInventory", options={"count": 10000} ), diff --git a/tests/worker/topics/test_blocking_topic.py b/tests/worker/topics/test_blocking_topic.py index 47da1b29..6b9f326b 100644 --- a/tests/worker/topics/test_blocking_topic.py +++ b/tests/worker/topics/test_blocking_topic.py @@ -6,6 +6,7 @@ import asyncstdlib as alib import pytest +from saturn_engine.core import MessageId from saturn_engine.core import TopicMessage from saturn_engine.worker.topic import TopicOutput from saturn_engine.worker.topics import BlockingTopic @@ -26,14 +27,14 @@ def run_once_blocking(self) -> Optional[list[TopicOutput]]: if self.x == 0: self.x += 2 return [ - TopicMessage(id=str(1), args={}), - TopicMessage(id=str(2), args={}), + TopicMessage(id=MessageId(str(1)), args={}), + TopicMessage(id=MessageId(str(2)), args={}), ] self.x += 1 if self.x == 3: return None - return [TopicMessage(id=str(self.x), args={})] + return [TopicMessage(id=MessageId(str(self.x)), args={})] def publish_blocking(self, message: TopicMessage, wait: bool) -> bool: if message.args["block"]: @@ -47,26 +48,32 @@ def publish_blocking(self, message: TopicMessage, wait: bool) -> bool: topic = FakeTopic() assert await alib.list(topic.run()) == [ - TopicMessage(id="1", args={}), - TopicMessage(id="2", args={}), + TopicMessage(id=MessageId("1"), args={}), + TopicMessage(id=MessageId("2"), args={}), ] - assert await topic.publish(TopicMessage(id="1", args={"block": False}), wait=True) - assert await topic.publish(TopicMessage(id="2", args={"block": False}), wait=False) + assert await topic.publish( + TopicMessage(id=MessageId("1"), args={"block": False}), wait=True + ) + assert await topic.publish( + TopicMessage(id=MessageId("2"), args={"block": False}), wait=False + ) assert not await topic.publish( - TopicMessage(id="3", args={"block": True}), wait=False + TopicMessage(id=MessageId("3"), args={"block": True}), wait=False ) async with event_loop.until_idle(): publish_task1 = asyncio.create_task( - topic.publish(TopicMessage(id="4", args={"block": True}), wait=True) + topic.publish( + TopicMessage(id=MessageId("4"), args={"block": True}), wait=True + ) ) assert not await topic.publish( - TopicMessage(id="5", args={"block": False}), wait=False + TopicMessage(id=MessageId("5"), args={"block": False}), wait=False ) publish_task2 = asyncio.create_task( - topic.publish(TopicMessage(id="6", args={"block": False}), wait=True) + topic.publish(TopicMessage(id=MessageId("6"), args={"block": False}), wait=True) ) event.set() @@ -89,10 +96,10 @@ def run_once_blocking(self) -> Optional[list[TopicOutput]]: item = self.items.pop(0) if isinstance(item, Exception): raise item - return [TopicMessage(id=str(item), args={})] + return [TopicMessage(id=MessageId(str(item)), args={})] topic = FakeTopic() assert await alib.list(topic.run()) == [ - TopicMessage(id="1", args={}), - TopicMessage(id="2", args={}), + TopicMessage(id=MessageId("1"), args={}), + TopicMessage(id=MessageId("2"), args={}), ] diff --git a/tests/worker/topics/test_delayed_topic.py b/tests/worker/topics/test_delayed_topic.py index 636a2b78..67565e9e 100644 --- a/tests/worker/topics/test_delayed_topic.py +++ b/tests/worker/topics/test_delayed_topic.py @@ -10,6 +10,7 @@ import asyncstdlib as alib import pytest +from saturn_engine.core import MessageId from saturn_engine.core import TopicMessage from saturn_engine.utils import utcnow from saturn_engine.worker.topics import DelayedTopic @@ -41,8 +42,8 @@ def run_time() -> datetime: topic = await rabbitmq_topic_maker(DelayedTopic, delay=TEST_DELAY) messages = [ - TopicMessage(id="a", args={"message": "test-a"}), - TopicMessage(id="b", args={"message": "test-b"}), + TopicMessage(id=MessageId("a"), args={"message": "test-a"}), + TopicMessage(id=MessageId("b"), args={"message": "test-b"}), ] with mock.patch("saturn_engine.worker.topics.delayed.utcnow", new=publish_time): diff --git a/tests/worker/topics/test_file_topic.py b/tests/worker/topics/test_file_topic.py index 67214aa9..f78610ef 100644 --- a/tests/worker/topics/test_file_topic.py +++ b/tests/worker/topics/test_file_topic.py @@ -3,6 +3,7 @@ import asyncstdlib as alib import pytest +from saturn_engine.core import MessageId from saturn_engine.core import TopicMessage from saturn_engine.worker.topics import FileTopic @@ -11,8 +12,8 @@ async def test_file_topic(tmp_path: Path) -> None: path = tmp_path / "topic.json" messages = [ - TopicMessage(id="0", args={"n": 1}), - TopicMessage(id="1", args={"n": 2}), + TopicMessage(id=MessageId("0"), args={"n": 1}), + TopicMessage(id=MessageId("1"), args={"n": 2}), ] topic = FileTopic.from_options({"path": str(path), "mode": "w"}) diff --git a/tests/worker/topics/test_logging_topic.py b/tests/worker/topics/test_logging_topic.py index 531dbcb0..7fb62e1f 100644 --- a/tests/worker/topics/test_logging_topic.py +++ b/tests/worker/topics/test_logging_topic.py @@ -4,6 +4,7 @@ import pytest +from saturn_engine.core import MessageId from saturn_engine.core import TopicMessage from saturn_engine.core.error import ErrorMessageArgs from saturn_engine.utils.traceback_data import TracebackData @@ -19,12 +20,12 @@ def generate_exception(msg: str) -> BaseException: @pytest.mark.asyncio async def test_logging_topic(caplog: t.Any) -> None: messages = [ - TopicMessage(id="0", args={"n": 1}), - TopicMessage(id="1", args={"n": 2}), + TopicMessage(id=MessageId("0"), args={"n": 1}), + TopicMessage(id=MessageId("1"), args={"n": 2}), ] error_messages = [ TopicMessage( - id="2", + id=MessageId("2"), args={ "error": ErrorMessageArgs( type="Exception", @@ -37,7 +38,7 @@ async def test_logging_topic(caplog: t.Any) -> None: }, ), TopicMessage( - id="3", + id=MessageId("3"), args={ "error": ErrorMessageArgs( type="Exception", diff --git a/tests/worker/topics/test_periodic_topic.py b/tests/worker/topics/test_periodic_topic.py index 643afce9..7648fa1e 100644 --- a/tests/worker/topics/test_periodic_topic.py +++ b/tests/worker/topics/test_periodic_topic.py @@ -3,6 +3,7 @@ import asyncstdlib as alib import pytest +from saturn_engine.core import MessageId from saturn_engine.core import TopicMessage from saturn_engine.worker.topics import PeriodicTopic from tests.utils import TimeForwardLoop @@ -15,11 +16,11 @@ async def test_periodic_topic(event_loop: TimeForwardLoop, frozen_time: t.Any) - async with alib.scoped_iter(topic.run()) as scoped_topic_iter: items = await alib.list(alib.islice(scoped_topic_iter, 5)) assert items == [ - TopicMessage(id="1514851500.0", args={}), - TopicMessage(id="1514851800.0", args={}), - TopicMessage(id="1514852100.0", args={}), - TopicMessage(id="1514852400.0", args={}), - TopicMessage(id="1514852700.0", args={}), + TopicMessage(id=MessageId("1514851500.0"), args={}), + TopicMessage(id=MessageId("1514851800.0"), args={}), + TopicMessage(id=MessageId("1514852100.0"), args={}), + TopicMessage(id=MessageId("1514852400.0"), args={}), + TopicMessage(id=MessageId("1514852700.0"), args={}), ] await topic.close() diff --git a/tests/worker/topics/test_rabbitmq_topic.py b/tests/worker/topics/test_rabbitmq_topic.py index 9f5cb6a9..4ae772a0 100644 --- a/tests/worker/topics/test_rabbitmq_topic.py +++ b/tests/worker/topics/test_rabbitmq_topic.py @@ -9,6 +9,7 @@ from aiormq.exceptions import AMQPConnectionError from saturn_engine.config import Config +from saturn_engine.core import MessageId from saturn_engine.core import TopicMessage from saturn_engine.utils import utcnow from saturn_engine.worker.services.manager import ServicesManager @@ -32,8 +33,8 @@ async def test_rabbitmq_topic_simple( topic = await rabbitmq_topic_maker(RabbitMQTopic) messages = [ - TopicMessage(id="0", args={"n": 1}), - TopicMessage(id="1", args={"n": 2}), + TopicMessage(id=MessageId("0"), args={"n": 1}), + TopicMessage(id=MessageId("1"), args={"n": 2}), ] for message in messages: @@ -58,8 +59,8 @@ async def test_rabbitmq_topic_pickle( ) messages = [ - TopicMessage(id="0", args={"n": b"1", "time": utcnow()}), - TopicMessage(id="1", args={"n": b"2", "time": utcnow()}), + TopicMessage(id=MessageId("0"), args={"n": b"1", "time": utcnow()}), + TopicMessage(id=MessageId("1"), args={"n": b"2", "time": utcnow()}), ] for message in messages: @@ -82,7 +83,7 @@ async def test_bounded_rabbitmq_topic_max_length( topic = await rabbitmq_topic_maker(RabbitMQTopic, max_length=2, prefetch_count=2) topic.RETRY_PUBLISH_DELAY = timedelta(seconds=0.1) - message = TopicMessage(id="0", args={"n": 1}) + message = TopicMessage(id=MessageId("0"), args={"n": 1}) assert await topic.publish(message, wait=False) assert await topic.publish(message, wait=True) @@ -147,7 +148,7 @@ async def test_rabbitmq_topic_channel_closed( connection_name="proxy", ) - message = TopicMessage(id="0", args={"n": 1}) + message = TopicMessage(id=MessageId("0"), args={"n": 1}) async with alib.scoped_iter(reader.run()) as topic_iter: assert await topic.publish(message, wait=False) @@ -175,4 +176,4 @@ async def test_closed_rabbitmq_topic( topic = await rabbitmq_topic_maker(RabbitMQTopic) await topic.close() with pytest.raises(TopicClosedError): - await topic.publish(TopicMessage(id="0", args={"n": 0}), wait=True) + await topic.publish(TopicMessage(id=MessageId("0"), args={"n": 0}), wait=True)