Skip to content

Commit

Permalink
Merge pull request #63 from Flared/isra/tweak
Browse files Browse the repository at this point in the history
Runner tweak
  • Loading branch information
isra17 authored Nov 12, 2021
2 parents 8c8c8ea + 340b4a2 commit 797a0b7
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 12 deletions.
4 changes: 3 additions & 1 deletion src/saturn_engine/worker/executors/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ def bootstrap_pipeline(message: PipelineMessage) -> PipelineResult:

# Ensure result is an iterator.
results: Iterator
if isinstance(execute_result, Iterable):
if execute_result is None:
results = iter([])
elif isinstance(execute_result, Iterable):
results = iter(execute_result)
elif not isinstance(execute_result, Iterator):
if isinstance(execute_result, (TopicMessage, PipelineOutput, ResourceUsed)):
Expand Down
19 changes: 12 additions & 7 deletions src/saturn_engine/worker/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,22 @@

from .broker import Broker


async def async_main() -> None:
loop = asyncio.get_running_loop()
broker = Broker()
for signame in ["SIGINT", "SIGTERM"]:
loop.add_signal_handler(getattr(signal, signame), broker.stop)
await broker.run()


if __name__ == "__main__":
logging.basicConfig(
level=logging.DEBUG, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
)
logger = logging.getLogger(__name__)
broker = Broker()
loop = asyncio.get_event_loop()

for signame in ["SIGINT", "SIGTERM"]:
loop.add_signal_handler(getattr(signal, signame), broker.stop)

loop.run_until_complete(broker.run())
loop = asyncio.get_event_loop()
loop.run_until_complete(async_main())
if tasks := asyncio.all_tasks(loop):
logger.error("Leftover tasks: %s", tasks)
5 changes: 3 additions & 2 deletions src/saturn_engine/worker/topics/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
@dataclasses.dataclass
class MemoryOptions:
name: str
buffer_size: int = 100


class MemoryTopic(Topic):
Expand All @@ -25,7 +26,7 @@ def __init__(self, options: MemoryOptions, **kwargs: Any):
self.options = options

async def run(self) -> AsyncGenerator[AsyncContextManager[TopicMessage], None]:
queue = get_queue(self.options.name)
queue = get_queue(self.options.name, maxsize=self.options.buffer_size)
while True:
message = await queue.get()
yield self.message_context(message, queue=queue)
Expand All @@ -40,7 +41,7 @@ async def message_context(
queue.task_done()

async def publish(self, message: TopicMessage, wait: bool) -> bool:
queue = get_queue(self.options.name)
queue = get_queue(self.options.name, maxsize=self.options.buffer_size)
if wait:
await queue.put(message)
else:
Expand Down
6 changes: 4 additions & 2 deletions src/saturn_engine/worker/work_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,10 @@ def build_inventory(
options = {"name": inventory_item.name} | inventory_item.options

inventory_input = klass.from_options(options, context=context)
inventory_output = MemoryTopic(MemoryTopic.Options(name=queue_item.name))
inventory_output = MemoryTopic(
MemoryTopic.Options(name=queue_item.name, buffer_size=1)
)
store = context.services.job_store.for_queue(queue_item)
job = Job(inventory=inventory_input, publisher=inventory_output, store=store)
input_topic = MemoryTopic(MemoryTopic.Options(name=queue_item.name))
input_topic = MemoryTopic(MemoryTopic.Options(name=queue_item.name, buffer_size=1))
return (asyncio.create_task(job.run()), input_topic)

0 comments on commit 797a0b7

Please sign in to comment.