From ed1f9a81d0c06e85ed887e2da72c1bd8e15855ba Mon Sep 17 00:00:00 2001 From: approxit Date: Thu, 4 Jan 2024 16:29:19 +0100 Subject: [PATCH 01/20] initial --- golem/utils/asyncio.py | 15 ++- golem/utils/buffer.py | 235 +++++++++++++++++++++++++++++++++++++++++ golem/utils/counter.py | 44 ++++++++ golem/utils/queue.py | 8 +- 4 files changed, 296 insertions(+), 6 deletions(-) create mode 100644 golem/utils/buffer.py create mode 100644 golem/utils/counter.py diff --git a/golem/utils/asyncio.py b/golem/utils/asyncio.py index 78d5f2e2..0d9e8fa2 100644 --- a/golem/utils/asyncio.py +++ b/golem/utils/asyncio.py @@ -1,7 +1,7 @@ import asyncio import contextvars import logging -from typing import Optional +from typing import Optional, Sequence from golem.utils.logging import trace_id_var @@ -38,3 +38,16 @@ def _handle_task_logging(task: asyncio.Task): pass except Exception: logger.exception("Background async task encountered unhandled exception!") + + +async def cancel_and_await(task: asyncio.Task) -> None: + task.cancel() + + try: + await task + except asyncio.CancelledError: + pass + + +async def cancel_and_await_many(tasks: Sequence[asyncio.Task]) -> None: + await asyncio.gather(*[cancel_and_await(task) for task in tasks]) diff --git a/golem/utils/buffer.py b/golem/utils/buffer.py new file mode 100644 index 00000000..6e46cdbb --- /dev/null +++ b/golem/utils/buffer.py @@ -0,0 +1,235 @@ +import asyncio +import logging +from abc import abstractmethod, ABC +from collections import defaultdict +from datetime import timedelta +from typing import TypeVar, Generic, Optional, Sequence, Iterable, Callable, List, Dict, Awaitable + +from golem.utils.asyncio import create_task_with_logging, cancel_and_await_many +from golem.utils.counter import AsyncCounter +from golem.utils.logging import trace_span, get_trace_id_name + +TItem = TypeVar("TItem") + +logger = logging.getLogger(__name__) + + +class Buffer(ABC, Generic[TItem]): + @abstractmethod + async def get(self) -> TItem: + ... + + @abstractmethod + async def put(self, item: TItem, *, notify_all=True) -> None: + ... + + @abstractmethod + async def remove(self, item: TItem) -> None: + ... + + @abstractmethod + def size(self) -> int: + ... + + @abstractmethod + async def notify_all(self) -> None: + ... + + +class ComposableBuffer(Generic[TItem], Buffer[TItem]): + def __init__(self, buffer: Buffer): + self._buffer = buffer + + async def get(self) -> TItem: + return await self._buffer.get() + + async def put(self, item: TItem, *, notify_all=True) -> None: + await self._buffer.put(item, notify_all=notify_all) + + async def remove(self, item: TItem) -> None: + await self._buffer.remove(item) + + def size(self) -> int: + return self._buffer.size() + + async def notify_all(self) -> None: + await self._buffer.notify_all() + + +class SimpleBuffer(Buffer[TItem]): + def __init__(self, items: Optional[Sequence[TItem]] = None): + self._items = list(items) if items is not None else [] + + self._condition = asyncio.Condition() + + async def get(self) -> TItem: + async with self._condition: + await self._condition.wait_for(lambda: 0 < len(self._items)) + return self._items.pop(0) + + async def put(self, item: TItem, *, notify_all=True) -> None: + async with self._condition: + self._items.append(item) + + if notify_all: + self._condition.notify_all() + + async def remove(self, item: TItem) -> None: + async with self._condition: + self._items.remove(item) + + def size(self) -> int: + return len(self._items) + + async def notify_all(self) -> None: + async with self._condition: + self._condition.notify_all() + + +class ExpirableBuffer(ComposableBuffer[TItem]): + # Optimisation options: Use single expiration task that wakes up to expire the earliest item, + # then check next earliest item and sleep to it and repeat + def __init__(self, buffer: Buffer, get_expiration_func: Callable[[TItem], Optional[timedelta]]): + super().__init__(buffer) + + self._get_expiration_func = get_expiration_func + + self._lock = asyncio.Lock() + self._expiration_tasks: Dict[int, List[asyncio.Task]] = defaultdict(list) + + def _cancel_expiration_task_for_item(self, item: TItem) -> None: + item_id = id(item) + + if item_id not in self._expiration_tasks or not len(self._expiration_tasks[item_id]): + return + + self._expiration_tasks[item_id].pop(0) + + if not self._expiration_tasks[item_id]: + del self._expiration_tasks[item_id] + + async def get(self) -> Iterable[TItem]: + async with self._lock: + item = await self._buffer.get() + self._cancel_expiration_task_for_item(item) + + return item + + async def put(self, item: TItem, *, notify_all=True) -> None: + async with self._lock: + await self._buffer.put(item, notify_all=False) + expiration = self._get_expiration_func(item) + + if expiration is not None: + self._expiration_tasks[id(item)].append(asyncio.create_task(self._expire_item(expiration, item))) + + if notify_all: + await self._buffer.notify_all() + + async def remove(self, item: TItem) -> None: + async with self._lock: + await self._buffer.remove(item) + self._cancel_expiration_task_for_item(item) + + async def _expire_item(self, expiration: timedelta, item: TItem) -> None: + await asyncio.sleep(expiration.total_seconds()) + + async with self._lock: + await self._buffer.remove(item) + del self._expiration_tasks[id(item)] + + async def notify_all(self) -> None: + async with self._lock: + await self._buffer.notify_all() + + +class BackgroundFedBuffer(ComposableBuffer[TItem]): + def __init__( + self, + buffer: Buffer, + feed_func: Callable[[], Awaitable[TItem]], + min_size: int, + max_size: int, + fill_concurrency_size: int = 1, + fill_at_start=False, + ): + super().__init__(buffer) + + self._feed_func = feed_func + self._min_size = min_size + self._max_size = max_size + self._fill_concurrency_size = fill_concurrency_size + self._fill_at_start = fill_at_start + + self._is_started = False + self._worker_tasks: List[asyncio.Task] = [] + self._requests_counter = AsyncCounter() + self._lock = asyncio.Lock() + + @trace_span() + async def start(self) -> None: + if self.is_started(): + raise RuntimeError("Already started!") + + for i in range(self._fill_concurrency_size): + self._worker_tasks.append( + create_task_with_logging( + self._worker_loop(), trace_id=get_trace_id_name(self, f"worker-{i}") + ) + ) + + if self._fill_at_start: + await self._handle_item_requests() + + self._is_started = True + + @trace_span() + async def stop(self) -> None: + if not self.is_started(): + raise RuntimeError("Already stopped!") + + await cancel_and_await_many(self._worker_tasks) + self._worker_tasks.clear() + self._is_started = False + + self._requests_counter = AsyncCounter() + + def is_started(self) -> bool: + return self._is_started + + async def _worker_loop(self): + while True: + await self._requests_counter.decrement() + + item = await self._feed_func() + + await self._buffer.put(item) + + self._requests_counter.task_done() + + async def get(self) -> TItem: + async with self._lock: + if self._size_with_pending() == 0: # This supports lazy (not at start) buffer filling + logger.debug("No items to get, requesting fill") + await self._handle_item_requests() + + logger.debug("Waiting for any item to pick...") + + item = await self._buffer.get() + + # Check if we need to request any additional items + if self._size_with_pending() < self._min_size: + await self._handle_item_requests() + + return item + + def _size_with_pending(self): + return self._buffer.size() + self._requests_counter.pending_count() + + @trace_span() + async def _handle_item_requests(self) -> None: + items_to_request = self._max_size - self._size_with_pending() + + await self._requests_counter.increment(items_to_request) + + logger.debug("Requested %d items", items_to_request) diff --git a/golem/utils/counter.py b/golem/utils/counter.py new file mode 100644 index 00000000..b0fb3ed3 --- /dev/null +++ b/golem/utils/counter.py @@ -0,0 +1,44 @@ +import asyncio + + +class AsyncCounter: + def __init__(self, start_with=0): + self._counter = start_with + self._pending_count = 0 + self._condition = asyncio.Condition() + self._finished = asyncio.Event() + + self._finished.set() + + async def increment(self, value=1) -> None: + async with self._condition: + self._counter += value + + self._condition.notify_all() + self._finished.clear() + + async def decrement(self, value=1) -> None: + async with self._condition: + await self._condition.wait_for(lambda: value < self._counter) + + self._counter -= value + + async def reset(self, value=0) -> None: + self._counter = value + self._pending_count = 0 + self._finished.set() + + def task_done(self): + if self._pending_count <= 0: + raise ValueError('task_done() called too many times!') + + self._pending_count -= 1 + + if not self._pending_count: + self._finished.set() + + async def join(self) -> None: + await self._finished.wait() + + def pending_count(self) -> int: + return self._pending_count diff --git a/golem/utils/queue.py b/golem/utils/queue.py index 6c15d0d8..6cc1ece6 100644 --- a/golem/utils/queue.py +++ b/golem/utils/queue.py @@ -7,12 +7,10 @@ class ErrorReportingQueue(asyncio.Queue, Generic[QueueItem]): """Asyncio Queue that enables exceptions to be passed to consumers from the feeding code.""" - _error: Optional[BaseException] - _error_event: asyncio.Event - def __init__(self, *args, **kwargs): - self._error = None - self._error_event = asyncio.Event() + self._error: Optional[BaseException] = None + self._error_event: asyncio.Event = asyncio.Event() + super().__init__(*args, **kwargs) def get_nowait(self) -> QueueItem: From c44925a5179d36663fe953a17a1d66988a60b2fc Mon Sep 17 00:00:00 2001 From: approxit Date: Fri, 5 Jan 2024 18:04:43 +0100 Subject: [PATCH 02/20] a little better this time --- golem/utils/asyncio.py | 3 + golem/utils/buffer.py | 230 +++++++++++++++++++++++---------------- golem/utils/semaphore.py | 48 ++++++++ 3 files changed, 186 insertions(+), 95 deletions(-) create mode 100644 golem/utils/semaphore.py diff --git a/golem/utils/asyncio.py b/golem/utils/asyncio.py index 0d9e8fa2..5285a315 100644 --- a/golem/utils/asyncio.py +++ b/golem/utils/asyncio.py @@ -41,6 +41,9 @@ def _handle_task_logging(task: asyncio.Task): async def cancel_and_await(task: asyncio.Task) -> None: + if task.done(): + return + task.cancel() try: diff --git a/golem/utils/buffer.py b/golem/utils/buffer.py index 6e46cdbb..21c84a60 100644 --- a/golem/utils/buffer.py +++ b/golem/utils/buffer.py @@ -3,90 +3,127 @@ from abc import abstractmethod, ABC from collections import defaultdict from datetime import timedelta -from typing import TypeVar, Generic, Optional, Sequence, Iterable, Callable, List, Dict, Awaitable +from typing import TypeVar, Generic, Optional, Sequence, Iterable, Callable, List, Dict, Awaitable, MutableSequence -from golem.utils.asyncio import create_task_with_logging, cancel_and_await_many -from golem.utils.counter import AsyncCounter +from golem.utils.asyncio import create_task_with_logging, cancel_and_await_many, cancel_and_await from golem.utils.logging import trace_span, get_trace_id_name +from golem.utils.semaphore import SingleUseSemaphore TItem = TypeVar("TItem") +TBuffer = TypeVar("TBuffer") logger = logging.getLogger(__name__) class Buffer(ABC, Generic[TItem]): + @abstractmethod + def size(self) -> int: + ... + + @abstractmethod + async def wait_for_any_items(self) -> None: + ... + @abstractmethod async def get(self) -> TItem: ... @abstractmethod - async def put(self, item: TItem, *, notify_all=True) -> None: + async def get_all(self) -> MutableSequence[TItem]: ... @abstractmethod - async def remove(self, item: TItem) -> None: + async def put(self, item: TItem) -> None: ... @abstractmethod - def size(self) -> int: + async def put_all(self, items: Sequence[TItem]) -> None: ... @abstractmethod - async def notify_all(self) -> None: + async def remove(self, item: TItem) -> None: ... -class ComposableBuffer(Generic[TItem], Buffer[TItem]): - def __init__(self, buffer: Buffer): +class ComposableBuffer(Generic[TBuffer, TItem], Buffer[TItem]): + def __init__(self, buffer: TBuffer): self._buffer = buffer + def size(self) -> int: + return self._buffer.size() + + async def wait_for_any_items(self) -> None: + await self._buffer.wait_for_any_items() + async def get(self) -> TItem: return await self._buffer.get() - async def put(self, item: TItem, *, notify_all=True) -> None: - await self._buffer.put(item, notify_all=notify_all) + async def get_all(self) -> MutableSequence[TItem]: + return await self._buffer.get_all() - async def remove(self, item: TItem) -> None: - await self._buffer.remove(item) + async def put(self, item: TItem) -> None: + await self._buffer.put(item) - def size(self) -> int: - return self._buffer.size() + async def put_all(self, items: Sequence[TItem]) -> None: + await self._buffer.put_all(items) - async def notify_all(self) -> None: - await self._buffer.notify_all() + async def remove(self, item: TItem) -> None: + await self._buffer.remove(item) class SimpleBuffer(Buffer[TItem]): def __init__(self, items: Optional[Sequence[TItem]] = None): self._items = list(items) if items is not None else [] - self._condition = asyncio.Condition() + self._have_items = asyncio.Event() # TODO: collections of future-object waiters instead of event? + + if self.size(): + self._have_items.set() + + def size(self) -> int: + return len(self._items) + + async def wait_for_any_items(self) -> None: + while not self.size(): + await self._have_items.wait() async def get(self) -> TItem: - async with self._condition: - await self._condition.wait_for(lambda: 0 < len(self._items)) - return self._items.pop(0) + await self.wait_for_any_items() - async def put(self, item: TItem, *, notify_all=True) -> None: - async with self._condition: - self._items.append(item) + item = self._items.pop(0) - if notify_all: - self._condition.notify_all() + if not self.size(): + self._have_items.clear() - async def remove(self, item: TItem) -> None: - async with self._condition: - self._items.remove(item) + return item - def size(self) -> int: - return len(self._items) + async def get_all(self) -> MutableSequence[TItem]: + items = self._items[:] + self._items.clear() + self._have_items.clear() + return items - async def notify_all(self) -> None: - async with self._condition: - self._condition.notify_all() + async def put(self, item: TItem) -> None: + self._items.append(item) + self._have_items.set() + async def put_all(self, items: Sequence[TItem]) -> None: + self._items.clear() + self._items.extend(items[:]) -class ExpirableBuffer(ComposableBuffer[TItem]): + if self.size(): + self._have_items.set() + else: + self._have_items.clear() + + async def remove(self, item: TItem) -> None: + self._items.remove(item) + + if not self.size(): + self._have_items.clear() + + +class ExpirableBuffer(ComposableBuffer[Buffer, TItem]): # Optimisation options: Use single expiration task that wakes up to expire the earliest item, # then check next earliest item and sleep to it and repeat def __init__(self, buffer: Buffer, get_expiration_func: Callable[[TItem], Optional[timedelta]]): @@ -97,90 +134,96 @@ def __init__(self, buffer: Buffer, get_expiration_func: Callable[[TItem], Option self._lock = asyncio.Lock() self._expiration_tasks: Dict[int, List[asyncio.Task]] = defaultdict(list) - def _cancel_expiration_task_for_item(self, item: TItem) -> None: + def _add_expiration_task_for_item(self, item: TItem) -> None: + expiration = self._get_expiration_func(item) + + if expiration is None: + return + + self._expiration_tasks[id(item)].append(asyncio.create_task(self._expire_item(expiration, item))) + + async def _remove_expiration_task_for_item(self, item: TItem) -> None: item_id = id(item) if item_id not in self._expiration_tasks or not len(self._expiration_tasks[item_id]): return - self._expiration_tasks[item_id].pop(0) + expiration_task = self._expiration_tasks[item_id].pop(0) + + await cancel_and_await(expiration_task) if not self._expiration_tasks[item_id]: del self._expiration_tasks[item_id] + async def _remove_all_expiration_tasks(self) -> None: + await cancel_and_await_many(self._expiration_tasks) + self._expiration_tasks.clear() + async def get(self) -> Iterable[TItem]: async with self._lock: - item = await self._buffer.get() - self._cancel_expiration_task_for_item(item) + item = await super().get() + await self._remove_expiration_task_for_item(item) return item - async def put(self, item: TItem, *, notify_all=True) -> None: + async def get_all(self) -> MutableSequence[TItem]: async with self._lock: - await self._buffer.put(item, notify_all=False) - expiration = self._get_expiration_func(item) + items = await super().get_all() + await self._remove_all_expiration_tasks() + return items - if expiration is not None: - self._expiration_tasks[id(item)].append(asyncio.create_task(self._expire_item(expiration, item))) + async def put(self, item: TItem) -> None: + async with self._lock: + await super().put(item) + self._add_expiration_task_for_item(item) + + async def put_all(self, items: Sequence[TItem]) -> None: + async with self._lock: + await super().put_all(items) + await self._remove_all_expiration_tasks() - if notify_all: - await self._buffer.notify_all() + for item in items: + self._add_expiration_task_for_item(item) async def remove(self, item: TItem) -> None: async with self._lock: - await self._buffer.remove(item) - self._cancel_expiration_task_for_item(item) + await super().remove(item) + await self._remove_expiration_task_for_item(item) async def _expire_item(self, expiration: timedelta, item: TItem) -> None: await asyncio.sleep(expiration.total_seconds()) - async with self._lock: - await self._buffer.remove(item) - del self._expiration_tasks[id(item)] - - async def notify_all(self) -> None: - async with self._lock: - await self._buffer.notify_all() + await self.remove(item) -class BackgroundFedBuffer(ComposableBuffer[TItem]): +class BackgroundFeedBuffer(ComposableBuffer[Buffer, TItem]): def __init__( self, buffer: Buffer, feed_func: Callable[[], Awaitable[TItem]], - min_size: int, - max_size: int, - fill_concurrency_size: int = 1, - fill_at_start=False, + feed_concurrency_size=1, ): super().__init__(buffer) self._feed_func = feed_func - self._min_size = min_size - self._max_size = max_size - self._fill_concurrency_size = fill_concurrency_size - self._fill_at_start = fill_at_start + self._feed_concurrency_size = feed_concurrency_size self._is_started = False self._worker_tasks: List[asyncio.Task] = [] - self._requests_counter = AsyncCounter() - self._lock = asyncio.Lock() + self._workers_semaphore = SingleUseSemaphore() @trace_span() async def start(self) -> None: if self.is_started(): raise RuntimeError("Already started!") - for i in range(self._fill_concurrency_size): + for i in range(self._feed_concurrency_size): self._worker_tasks.append( create_task_with_logging( self._worker_loop(), trace_id=get_trace_id_name(self, f"worker-{i}") ) ) - if self._fill_at_start: - await self._handle_item_requests() - self._is_started = True @trace_span() @@ -192,44 +235,41 @@ async def stop(self) -> None: self._worker_tasks.clear() self._is_started = False - self._requests_counter = AsyncCounter() + self._workers_semaphore.reset() def is_started(self) -> bool: return self._is_started async def _worker_loop(self): while True: - await self._requests_counter.decrement() - - item = await self._feed_func() - - await self._buffer.put(item) + async with self._workers_semaphore: + item = await self._feed_func() - self._requests_counter.task_done() + await self.put(item) - async def get(self) -> TItem: - async with self._lock: - if self._size_with_pending() == 0: # This supports lazy (not at start) buffer filling - logger.debug("No items to get, requesting fill") - await self._handle_item_requests() + async def request(self, count: int) -> None: + await self._workers_semaphore.increase(count) - logger.debug("Waiting for any item to pick...") + @property + def finished(self): + return self._workers_semaphore.finished - item = await self._buffer.get() - # Check if we need to request any additional items - if self._size_with_pending() < self._min_size: - await self._handle_item_requests() +class BatchedGetAllBuffer(ComposableBuffer[BackgroundFeedBuffer, TItem]): + def __init__(self, buffer: BackgroundFeedBuffer, batch_deadline: timedelta): + super().__init__(buffer) - return item + self._batch_deadline = batch_deadline - def _size_with_pending(self): - return self._buffer.size() + self._requests_counter.pending_count() + self._lock = asyncio.Lock() - @trace_span() - async def _handle_item_requests(self) -> None: - items_to_request = self._max_size - self._size_with_pending() + async def get_all(self) -> Sequence[TItem]: + async with self._lock: + await self.wait_for_any_items() - await self._requests_counter.increment(items_to_request) + try: + await asyncio.wait_for(self._buffer.finished.wait(), self._batch_deadline.total_seconds()) + except TimeoutError: + pass - logger.debug("Requested %d items", items_to_request) + return await super().get_all() diff --git a/golem/utils/semaphore.py b/golem/utils/semaphore.py new file mode 100644 index 00000000..83483226 --- /dev/null +++ b/golem/utils/semaphore.py @@ -0,0 +1,48 @@ +import asyncio + + +class SingleUseSemaphore: + def __init__(self, value=0): + self._value = value + + self._pending = 0 + self._condition = asyncio.Condition() + + self.finished = asyncio.Event() + if not self._value: + self.finished.set() + + async def __aenter__(self): + await self.acquire() + + async def __aexit__(self, exc_type, exc, tb): + self.release() + + def locked(self) -> bool: + return not self._value + + async def acquire(self): + async with self._condition: + await self._condition.wait_for(lambda: self._value) + + self._value -= 1 + self._pending += 1 + + def release(self): + self._pending -= 1 + + if not self._pending: + self.finished.set() + + async def increase(self, value: int) -> None: + async with self._condition: + self._value += value + self.finished.clear() + self._condition.notify(value) + + def get_pending_count(self) -> int: + return self._pending + + def reset(self) -> None: + self._value = 0 + self.finished.set() From f945e99874befd3e943d29150b2595427139c911 Mon Sep 17 00:00:00 2001 From: approxit Date: Wed, 17 Jan 2024 16:48:11 +0100 Subject: [PATCH 03/20] nearly finish --- golem/managers/proposal/plugins/new_buffer.py | 55 +++ .../proposal/plugins/scoring/new_scoring.py | 34 ++ golem/utils/buffer.py | 155 +++++--- golem/utils/counter.py | 44 --- golem/utils/semaphore.py | 12 + tests/unit/utils/test_buffer.py | 347 ++++++++++++++++++ tests/unit/utils/test_semaphore.py | 174 +++++++++ 7 files changed, 723 insertions(+), 98 deletions(-) create mode 100644 golem/managers/proposal/plugins/new_buffer.py create mode 100644 golem/managers/proposal/plugins/scoring/new_scoring.py delete mode 100644 golem/utils/counter.py create mode 100644 tests/unit/utils/test_buffer.py create mode 100644 tests/unit/utils/test_semaphore.py diff --git a/golem/managers/proposal/plugins/new_buffer.py b/golem/managers/proposal/plugins/new_buffer.py new file mode 100644 index 00000000..edd2f699 --- /dev/null +++ b/golem/managers/proposal/plugins/new_buffer.py @@ -0,0 +1,55 @@ +from golem.managers import ProposalManagerPlugin +from golem.resources import Proposal +from golem.utils.buffer import BackgroundFeedBuffer, SimpleBuffer + + +class Buffer(ProposalManagerPlugin): + def __init__( + self, + min_size: int, + max_size: int, + fill_concurrency_size: int = 1, + fill_at_start=False, + ) -> None: + self._min_size = min_size + self._max_size = max_size + self._fill_concurrency_size = fill_concurrency_size + self._fill_at_start = fill_at_start + + self._buffer = BackgroundFeedBuffer( + buffer=SimpleBuffer(), + feed_func=self._call_feed_func, + feed_concurrency_size=self._fill_concurrency_size, + ) + + async def _call_feed_func(self) -> Proposal: + return await self._get_proposal() + + async def start(self) -> None: + await self._buffer.start() + + if self._fill_at_start: + self._request_items() + + async def stop(self) -> None: + await self._buffer.stop() + + def _request_items(self): + self._buffer.request(self._max_size - self._buffer.size_with_requested()) + + async def get_proposal(self) -> Proposal: + if not self._get_items_count(): + self._request_items() + + proposal = await self._get_item() + + if self._get_items_count() < self._min_size: + self._request_items() + + return proposal + + async def _get_item(self) -> Proposal: + return await self._buffer.get() + + def _get_items_count(self) -> int: + return self._buffer.size_with_requested() diff --git a/golem/managers/proposal/plugins/scoring/new_scoring.py b/golem/managers/proposal/plugins/scoring/new_scoring.py new file mode 100644 index 00000000..6cebffe8 --- /dev/null +++ b/golem/managers/proposal/plugins/scoring/new_scoring.py @@ -0,0 +1,34 @@ +import asyncio +from datetime import timedelta + +from golem.managers import ProposalScoringMixin +from golem.managers.proposal.plugins.new_buffer import Buffer as BufferPlugin +from golem.resources import Proposal +from golem.utils.buffer import Buffer + + +class ScoringBuffer(ProposalScoringMixin, BufferPlugin): + def __init__(self, update_interval: timedelta = timedelta(seconds=10), *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self._update_interval = update_interval + + self._buffer_scored: Buffer[Proposal] = Buffer() + self._lock = asyncio.Lock() + + async def _background_loop(self) -> None: + while True: + self._buffer.wait_for_any_items() + + async with self._lock: + items = await self._buffer.get_all_requested(self._update_interval) + items.extend(await self._buffer_scored.get_all()) + scored_items = await self.do_scoring(items) + await self._buffer_scored.put_all([proposal for _, proposal in scored_items]) + + async def _get_item(self) -> Proposal: + async with self._lock: + return await self._buffer_scored.get() + + def _get_items_count(self) -> int: + return super()._get_items_count() + self._buffer_scored.size() diff --git a/golem/utils/buffer.py b/golem/utils/buffer.py index 21c84a60..e84e9583 100644 --- a/golem/utils/buffer.py +++ b/golem/utils/buffer.py @@ -1,16 +1,26 @@ import asyncio import logging -from abc import abstractmethod, ABC +from abc import ABC, abstractmethod from collections import defaultdict from datetime import timedelta -from typing import TypeVar, Generic, Optional, Sequence, Iterable, Callable, List, Dict, Awaitable, MutableSequence - -from golem.utils.asyncio import create_task_with_logging, cancel_and_await_many, cancel_and_await -from golem.utils.logging import trace_span, get_trace_id_name +from typing import ( + Awaitable, + Callable, + Dict, + Generic, + Iterable, + List, + MutableSequence, + Optional, + Sequence, + TypeVar, +) + +from golem.utils.asyncio import cancel_and_await_many, create_task_with_logging +from golem.utils.logging import get_trace_id_name, trace_span from golem.utils.semaphore import SingleUseSemaphore TItem = TypeVar("TItem") -TBuffer = TypeVar("TBuffer") logger = logging.getLogger(__name__) @@ -18,35 +28,50 @@ class Buffer(ABC, Generic[TItem]): @abstractmethod def size(self) -> int: + """Return number of items stored in buffer.""" ... @abstractmethod async def wait_for_any_items(self) -> None: + """Wait until any items are stored in buffer.""" ... @abstractmethod async def get(self) -> TItem: + """Await, remove and return left-most item stored in buffer.""" ... @abstractmethod async def get_all(self) -> MutableSequence[TItem]: + """Remove and return all items stored in buffer.""" ... @abstractmethod async def put(self, item: TItem) -> None: + """Add item to right-most position to buffer. + + Duplicates are supported. + """ ... @abstractmethod async def put_all(self, items: Sequence[TItem]) -> None: + """Replace all items stored in buffer. + + Duplicates are supported. + """ ... @abstractmethod async def remove(self, item: TItem) -> None: + """Remove first occurrence of item from buffer or raise `ValueError` if not found.""" ... -class ComposableBuffer(Generic[TBuffer, TItem], Buffer[TItem]): - def __init__(self, buffer: TBuffer): +class ComposableBuffer(Buffer[TItem]): + """Utility class for composable/stackable buffer implementations to help with calling underlying buffer.""" + + def __init__(self, buffer: Buffer[TItem]): self._buffer = buffer def size(self) -> int: @@ -72,10 +97,14 @@ async def remove(self, item: TItem) -> None: class SimpleBuffer(Buffer[TItem]): + """Most basic implementation of Buffer interface.""" + def __init__(self, items: Optional[Sequence[TItem]] = None): self._items = list(items) if items is not None else [] - self._have_items = asyncio.Event() # TODO: collections of future-object waiters instead of event? + self._have_items = ( + asyncio.Event() + ) # TODO: collections of future-object waiters instead of event if self.size(): self._have_items.set() @@ -123,16 +152,29 @@ async def remove(self, item: TItem) -> None: self._have_items.clear() -class ExpirableBuffer(ComposableBuffer[Buffer, TItem]): - # Optimisation options: Use single expiration task that wakes up to expire the earliest item, - # then check next earliest item and sleep to it and repeat - def __init__(self, buffer: Buffer, get_expiration_func: Callable[[TItem], Optional[timedelta]]): +class ExpirableBuffer(ComposableBuffer[TItem]): + """Composable that adds option to expire item after some time. + + Items that are already in provided buffer will not expire. + """ + + # TODO: Optimisation options: Use single expiration task that wakes up to expire the earliest item, + # then check next earliest item and sleep to it and repeat + + def __init__( + self, + buffer: Buffer[TItem], + get_expiration_func: Callable[[TItem], Optional[timedelta]], + on_expiration_func: Optional[Callable[[TItem], Awaitable[None]]] = None, + ): super().__init__(buffer) self._get_expiration_func = get_expiration_func + self._on_expiration_func = on_expiration_func + # Lock is used to keep items in buffer and expiration tasks in sync self._lock = asyncio.Lock() - self._expiration_tasks: Dict[int, List[asyncio.Task]] = defaultdict(list) + self._expiration_handlers: Dict[int, List[asyncio.TimerHandle]] = defaultdict(list) def _add_expiration_task_for_item(self, item: TItem) -> None: expiration = self._get_expiration_func(item) @@ -140,36 +182,42 @@ def _add_expiration_task_for_item(self, item: TItem) -> None: if expiration is None: return - self._expiration_tasks[id(item)].append(asyncio.create_task(self._expire_item(expiration, item))) + loop = asyncio.get_event_loop() + + self._expiration_handlers[id(item)].append( + loop.call_later(expiration.total_seconds(), lambda: asyncio.create_task(self._expire_item(item))) + ) - async def _remove_expiration_task_for_item(self, item: TItem) -> None: + async def _remove_expiration_handler_for_item(self, item: TItem) -> None: item_id = id(item) - if item_id not in self._expiration_tasks or not len(self._expiration_tasks[item_id]): + if item_id not in self._expiration_handlers or not len(self._expiration_handlers[item_id]): return - expiration_task = self._expiration_tasks[item_id].pop(0) + expiration_handle = self._expiration_handlers[item_id].pop(0) + expiration_handle.cancel() - await cancel_and_await(expiration_task) + if not self._expiration_handlers[item_id]: + del self._expiration_handlers[item_id] - if not self._expiration_tasks[item_id]: - del self._expiration_tasks[item_id] + async def _remove_all_expiration_handlers(self) -> None: + for handlers in self._expiration_handlers.values(): + for handler in handlers: + handler.cancel() - async def _remove_all_expiration_tasks(self) -> None: - await cancel_and_await_many(self._expiration_tasks) - self._expiration_tasks.clear() + self._expiration_handlers.clear() async def get(self) -> Iterable[TItem]: async with self._lock: item = await super().get() - await self._remove_expiration_task_for_item(item) + await self._remove_expiration_handler_for_item(item) return item async def get_all(self) -> MutableSequence[TItem]: async with self._lock: items = await super().get_all() - await self._remove_all_expiration_tasks() + await self._remove_all_expiration_handlers() return items async def put(self, item: TItem) -> None: @@ -180,7 +228,7 @@ async def put(self, item: TItem) -> None: async def put_all(self, items: Sequence[TItem]) -> None: async with self._lock: await super().put_all(items) - await self._remove_all_expiration_tasks() + await self._remove_all_expiration_handlers() for item in items: self._add_expiration_task_for_item(item) @@ -188,18 +236,25 @@ async def put_all(self, items: Sequence[TItem]) -> None: async def remove(self, item: TItem) -> None: async with self._lock: await super().remove(item) - await self._remove_expiration_task_for_item(item) - - async def _expire_item(self, expiration: timedelta, item: TItem) -> None: - await asyncio.sleep(expiration.total_seconds()) + await self._remove_expiration_handler_for_item(item) + async def _expire_item(self, item: TItem) -> None: await self.remove(item) + if self._on_expiration_func: + await self._on_expiration_func(item) + + +class BackgroundFeedBuffer(ComposableBuffer[TItem]): + """Composable that adds option to feed buffer in background task. + + Background feed will happen only if background tasks are started by calling `.start()` + and items were requested by `.request()`. + """ -class BackgroundFeedBuffer(ComposableBuffer[Buffer, TItem]): def __init__( self, - buffer: Buffer, + buffer: Buffer[TItem], feed_func: Callable[[], Awaitable[TItem]], feed_concurrency_size=1, ): @@ -248,28 +303,20 @@ async def _worker_loop(self): await self.put(item) async def request(self, count: int) -> None: + """Request given number of items to be filled in background.""" await self._workers_semaphore.increase(count) - @property - def finished(self): - return self._workers_semaphore.finished - - -class BatchedGetAllBuffer(ComposableBuffer[BackgroundFeedBuffer, TItem]): - def __init__(self, buffer: BackgroundFeedBuffer, batch_deadline: timedelta): - super().__init__(buffer) - - self._batch_deadline = batch_deadline - - self._lock = asyncio.Lock() - - async def get_all(self) -> Sequence[TItem]: - async with self._lock: - await self.wait_for_any_items() + def size_with_requested(self) -> int: + """Return sum of items stored in buffer and requested to be filled.""" + return self.size() + self._workers_semaphore.get_count_with_pending() - try: - await asyncio.wait_for(self._buffer.finished.wait(), self._batch_deadline.total_seconds()) - except TimeoutError: - pass + async def get_all_requested(self, deadline: timedelta) -> MutableSequence[TItem]: + """Await for all requested items with given deadline, then remove and return all items stored in buffer.""" + try: + await asyncio.wait_for( + self._workers_semaphore.finished.wait(), deadline.total_seconds() + ) + except asyncio.TimeoutError: + pass - return await super().get_all() + return await self.get_all() diff --git a/golem/utils/counter.py b/golem/utils/counter.py deleted file mode 100644 index b0fb3ed3..00000000 --- a/golem/utils/counter.py +++ /dev/null @@ -1,44 +0,0 @@ -import asyncio - - -class AsyncCounter: - def __init__(self, start_with=0): - self._counter = start_with - self._pending_count = 0 - self._condition = asyncio.Condition() - self._finished = asyncio.Event() - - self._finished.set() - - async def increment(self, value=1) -> None: - async with self._condition: - self._counter += value - - self._condition.notify_all() - self._finished.clear() - - async def decrement(self, value=1) -> None: - async with self._condition: - await self._condition.wait_for(lambda: value < self._counter) - - self._counter -= value - - async def reset(self, value=0) -> None: - self._counter = value - self._pending_count = 0 - self._finished.set() - - def task_done(self): - if self._pending_count <= 0: - raise ValueError('task_done() called too many times!') - - self._pending_count -= 1 - - if not self._pending_count: - self._finished.set() - - async def join(self) -> None: - await self._finished.wait() - - def pending_count(self) -> int: - return self._pending_count diff --git a/golem/utils/semaphore.py b/golem/utils/semaphore.py index 83483226..44bfd970 100644 --- a/golem/utils/semaphore.py +++ b/golem/utils/semaphore.py @@ -3,6 +3,9 @@ class SingleUseSemaphore: def __init__(self, value=0): + if value < 0: + raise ValueError("Initial value must be greater or equal to zero!") + self._value = value self._pending = 0 @@ -29,6 +32,9 @@ async def acquire(self): self._pending += 1 def release(self): + if self._pending <= 0: + raise RuntimeError("Release called too many times!") + self._pending -= 1 if not self._pending: @@ -40,6 +46,12 @@ async def increase(self, value: int) -> None: self.finished.clear() self._condition.notify(value) + def get_count(self) -> int: + return self._value + + def get_count_with_pending(self) -> int: + return self.get_count() + self.get_pending_count() + def get_pending_count(self) -> int: return self._pending diff --git a/tests/unit/utils/test_buffer.py b/tests/unit/utils/test_buffer.py new file mode 100644 index 00000000..8a10a604 --- /dev/null +++ b/tests/unit/utils/test_buffer.py @@ -0,0 +1,347 @@ +from datetime import timedelta + +import asyncio + +import pytest + +from golem.utils.buffer import SimpleBuffer, Buffer, ExpirableBuffer, BackgroundFeedBuffer + + + +@pytest.fixture +def mocked_buffer(mocker): + return mocker.Mock(spec=Buffer) + + +def test_simple_buffer_creation(): + buffer = SimpleBuffer() + assert buffer.size() == 0 + + buffer = SimpleBuffer(["a", "b", "c"]) + assert buffer.size() == 3 + + +async def test_simple_buffer_put_get(): + buffer = SimpleBuffer() + assert buffer.size() == 0 + + item_put = object() + + await buffer.put(item_put) + + assert buffer.size() == 1 + + item_get = await buffer.get() + + assert buffer.size() == 0 + + assert item_put == item_get + + +async def test_simple_buffer_put_all_get_all(): + buffer = SimpleBuffer(["a"]) + assert buffer.size() == 1 + + await buffer.put_all(["b", "c", "d"]) + assert buffer.size() == 3 + + assert await buffer.get_all() == ["b", "c", "d"] + + assert buffer.size() == 0 + + assert await buffer.get_all() == [] + + +async def test_simple_buffer_remove(): + buffer = SimpleBuffer(["a", "b", "c"]) + + await buffer.remove("b") + + assert await buffer.get_all() == ["a", "c"] + + +async def test_simple_buffer_get_waits_for_items(): + buffer = SimpleBuffer() + assert buffer.size() == 0 + + _, pending = await asyncio.wait([buffer.get()], timeout=0.1) + get_task = pending.pop() + if not get_task: + pytest.fail("Getting empty buffer somehow finished instead of blocking!") + + item_put = object() + await buffer.put(item_put) + + item_get = await asyncio.wait_for(get_task, timeout=0.1) + + assert item_get == item_put + + +async def test_simple_buffer_keeps_item_order(): + buffer = SimpleBuffer(["a", "b", "c"]) + + assert await buffer.get() == "a" + assert await buffer.get() == "b" + assert await buffer.get() == "c" + + await buffer.put("d") + await buffer.put("e") + await buffer.put("f") + + assert await buffer.get() == "d" + assert await buffer.get_all() == ["e", "f"] + + +async def test_simple_buffer_keeps_shallow_copy_of_items(): + initial_items = ["a", "b", "c"] + buffer = SimpleBuffer(initial_items) + assert buffer.size() == 3 + + initial_items.extend(["d", "e", "f"]) + assert buffer.size() == 3 + + put_items = ["g", "h", "i"] + await buffer.put_all(put_items) + assert buffer.size() == 3 + + put_items.extend(["j", "k", "l"]) + assert buffer.size() == 3 + + +async def test_expirable_buffer_is_not_expiring_initial_items(mocked_buffer, mocker): + expire_after = timedelta(seconds=0.1) + ExpirableBuffer( + mocked_buffer, + lambda i: expire_after, + ) + + await asyncio.sleep(0.2) + + mocked_buffer.remove.assert_not_called() + +async def test_expirable_buffer_is_not_expiring_items_with_none_expiration(mocked_buffer, mocker): + expiration_func = mocker.Mock(side_effect=[ + timedelta(seconds=0.1), + None, + timedelta(seconds=0.1) + ]) + buffer = ExpirableBuffer( + mocked_buffer, + expiration_func, + ) + + await buffer.put('a') + await buffer.put('b') + await buffer.put('c') + + await asyncio.sleep(0.2) + + assert mocker.call('a') in mocked_buffer.remove.mock_calls + assert mocker.call('c') in mocked_buffer.remove.mock_calls + + mocked_buffer.get.return_value = 'b' + + assert await buffer.get() == 'b' + + +async def test_expirable_buffer_can_expire_items_with_put_get(mocked_buffer, mocker): + expire_after = timedelta(seconds=0.1) + on_expire = mocker.AsyncMock() + buffer = ExpirableBuffer( + mocked_buffer, + lambda i: expire_after, + on_expire, + ) + item_put = object() + + await buffer.put(item_put) + mocked_buffer.put.assert_called_with(item_put) + + mocked_buffer.get.return_value = item_put + await buffer.get() + mocked_buffer.get.assert_called() + + mocked_buffer.remove.assert_not_called() + on_expire.assert_not_called() + + await asyncio.sleep(0.2) + + mocked_buffer.remove.assert_not_called() + on_expire.assert_not_called() + + mocked_buffer.reset_mock() + + await buffer.put(item_put) + mocked_buffer.put.assert_called_with(item_put) + + await asyncio.sleep(0.2) + + mocked_buffer.remove.assert_called_with(item_put) + on_expire.assert_called_with(item_put) + + +async def test_expirable_buffer_can_expire_items_with_put_all_get_all(mocked_buffer, mocker): + expire_after = timedelta(seconds=0.1) + on_expire = mocker.AsyncMock() + buffer = ExpirableBuffer( + mocked_buffer, + lambda i: expire_after, + on_expire, + ) + items_put_all = ['a', 'b' , 'c'] + + await buffer.put_all(items_put_all) + mocked_buffer.put_all.assert_called_with(items_put_all) + + mocked_buffer.get_all.return_value = items_put_all + await buffer.get_all() + mocked_buffer.get_all.assert_called() + + mocked_buffer.remove.assert_not_called() + on_expire.assert_not_called() + + await asyncio.sleep(0.2) + + on_expire.assert_not_called() + + mocked_buffer.reset_mock() + + await buffer.put_all(items_put_all) + mocked_buffer.put_all.assert_called_with(items_put_all) + + await asyncio.sleep(0.2) + + assert mocker.call(items_put_all[0]) in mocked_buffer.remove.mock_calls + assert mocker.call(items_put_all[1]) in mocked_buffer.remove.mock_calls + assert mocker.call(items_put_all[2]) in mocked_buffer.remove.mock_calls + + assert mocker.call(items_put_all[0]) in on_expire.mock_calls + assert mocker.call(items_put_all[1]) in on_expire.mock_calls + assert mocker.call(items_put_all[2]) in on_expire.mock_calls + + +async def test_background_feed_buffer_start_stop(mocked_buffer, mocker): + feed_func = mocker.AsyncMock() + buffer = BackgroundFeedBuffer( + mocked_buffer, + feed_func, + ) + + assert not buffer.is_started() + + await buffer.start() + + assert buffer.is_started() + + with pytest.raises(RuntimeError): + await buffer.start() + + assert buffer.is_started() + mocked_buffer.size.return_value = 0 + assert buffer.size_with_requested() == 0 + + await buffer.stop() + + assert not buffer.is_started() + + with pytest.raises(RuntimeError): + await buffer.stop() + + assert not buffer.is_started() + + feed_func.assert_not_called() + + +async def test_background_feed_buffer_request(mocked_buffer, mocker): + item = object() + feed_queue = asyncio.Queue() + feed_func = mocker.AsyncMock(wraps=feed_queue.get) + buffer = BackgroundFeedBuffer( + mocked_buffer, + feed_func, + ) + await buffer.start() + + await buffer.request(1) + + await asyncio.sleep(0.1) + + feed_func.assert_called() + mocked_buffer.size.return_value = 0 + assert buffer.size() == 0 + assert buffer.size_with_requested() == 1 + + await feed_queue.put(item) + + await asyncio.sleep(0.1) + + mocked_buffer.put.assert_called_with(item) + mocked_buffer.size.return_value = 1 + assert buffer.size() == 1 + assert buffer.size_with_requested() == 1 + + await buffer.stop() + +async def test_background_feed_buffer_get_all_requested(mocked_buffer, mocker, event_loop): + timeout = timedelta(seconds=0.1) + item = object() + feed_queue = asyncio.Queue() + feed_func = mocker.AsyncMock(wraps=feed_queue.get) + buffer = BackgroundFeedBuffer( + mocked_buffer, + feed_func, + ) + await buffer.start() + + # Try to get items while buffer is empty and with no requests + mocked_buffer.get_all.return_value = [] + mocked_buffer.size.return_value = 0 + + time_before_wait = event_loop.time() + assert await buffer.get_all_requested(timeout) == [] + time_after_wait = event_loop.time() + + assert time_after_wait - time_before_wait < timeout.total_seconds(), 'get_all_requested seems to wait for the deadline instead of retuning fast' + # + + await buffer.request(1) + + await asyncio.sleep(0.1) + + feed_func.assert_called() + mocked_buffer.size.return_value = 0 + assert buffer.size() == 0 + assert buffer.size_with_requested() == 1 + + # Try to get items while buffer is empty but with pending requests + mocked_buffer.get_all.return_value = [] + mocked_buffer.size.return_value = 0 + + time_before_wait = event_loop.time() + assert await buffer.get_all_requested(timeout) == [] + time_after_wait = event_loop.time() + + assert timeout.total_seconds() <= time_after_wait - time_before_wait, 'get_all_requested seems to not wait to the deadline' + # + + await feed_queue.put(item) + + await asyncio.sleep(0.1) + + mocked_buffer.put.assert_called_with(item) + mocked_buffer.size.return_value = 1 + assert buffer.size() == 1 + assert buffer.size_with_requested() == 1 + + # Try to get items while buffer is have items and no pending requests + mocked_buffer.get_all.return_value = [item] + mocked_buffer.size.return_value = 0 + + time_before_wait = event_loop.time() + assert await buffer.get_all_requested(timeout) == [item] + time_after_wait = event_loop.time() + + assert time_after_wait - time_before_wait < timeout.total_seconds(), 'get_all_requested seems to wait for the deadline instead of retuning fast' + # + + await buffer.stop() diff --git a/tests/unit/utils/test_semaphore.py b/tests/unit/utils/test_semaphore.py new file mode 100644 index 00000000..d4c9b365 --- /dev/null +++ b/tests/unit/utils/test_semaphore.py @@ -0,0 +1,174 @@ +import asyncio + +import pytest + +from golem.utils.semaphore import SingleUseSemaphore + + +def test_creation(): + sem = SingleUseSemaphore() + + assert sem.get_count() == 0 + assert sem.get_count_with_pending() == 0 + assert sem.get_pending_count() == 0 + assert sem.finished.is_set() + assert sem.locked() + + sem = SingleUseSemaphore(0) + + assert sem.get_count() == 0 + assert sem.get_count_with_pending() == 0 + assert sem.get_pending_count() == 0 + assert sem.finished.is_set() + assert sem.locked() + + sem_count = 10 + sem = SingleUseSemaphore(sem_count) + + assert sem.get_count() == sem_count + assert sem.get_count_with_pending() == sem_count + assert sem.get_pending_count() == 0 + assert not sem.finished.is_set() + assert not sem.locked() + + with pytest.raises(ValueError): + SingleUseSemaphore(-10) + + +async def test_increase(): + sem = SingleUseSemaphore() + + assert sem.get_count() == 0 + assert sem.get_count_with_pending() == 0 + assert sem.get_pending_count() == 0 + assert sem.finished.is_set() + assert sem.locked() + + await sem.increase(1) + + assert sem.get_count() == 1 + assert sem.get_count_with_pending() == 1 + assert sem.get_pending_count() == 0 + assert not sem.finished.is_set() + assert not sem.locked() + + await sem.increase(10) + + assert sem.get_count() == 11 + assert sem.get_count_with_pending() == 11 + assert sem.get_pending_count() == 0 + assert not sem.finished.is_set() + assert not sem.locked() + + +async def test_reset(): + sem_value = 10 + sem = SingleUseSemaphore(sem_value) + + assert sem.get_count() == sem_value + assert sem.get_count_with_pending() == sem_value + assert sem.get_pending_count() == 0 + assert not sem.finished.is_set() + assert not sem.locked() + + sem.reset() + + assert sem.get_count() == 0 + assert sem.get_count_with_pending() == 0 + assert sem.get_pending_count() == 0 + assert sem.finished.is_set() + assert sem.locked() + + +async def test_acquire(): + sem_value = 1 + sem = SingleUseSemaphore(sem_value) + + assert sem.get_count() == sem_value + assert sem.get_count_with_pending() == sem_value + assert sem.get_pending_count() == 0 + assert not sem.finished.is_set() + assert not sem.locked() + + await asyncio.wait_for(sem.acquire(), 0.1) + + assert sem.get_count() == 0 + assert sem.get_count_with_pending() == 1 + assert sem.get_pending_count() == 1 + assert not sem.finished.is_set() + assert sem.locked() + + _, pending = await asyncio.wait([sem.acquire()], timeout=0.1) + acquire_task = pending.pop() + if not acquire_task: + pytest.fail("Acquiring locked semaphore somehow finished instead of blocking!") + + await sem.increase(1) + + assert sem.get_count() == 1 + assert sem.get_count_with_pending() == 2 + assert sem.get_pending_count() == 1 + assert not sem.finished.is_set() + assert not sem.locked() + + await asyncio.wait_for(acquire_task, 0.1) + + assert sem.get_count() == 0 + assert sem.get_count_with_pending() == 2 + assert sem.get_pending_count() == 2 + assert not sem.finished.is_set() + assert sem.locked() + + +async def test_release(): + sem_value = 1 + sem = SingleUseSemaphore(sem_value) + + assert sem.get_count() == sem_value + assert sem.get_count_with_pending() == sem_value + assert sem.get_pending_count() == 0 + assert not sem.finished.is_set() + assert not sem.locked() + + with pytest.raises(RuntimeError): + sem.release() + + await asyncio.wait_for(sem.acquire(), 0.1) + + assert sem.get_count() == 0 + assert sem.get_count_with_pending() == 1 + assert sem.get_pending_count() == 1 + assert not sem.finished.is_set() + assert sem.locked() + + sem.release() + + assert sem.get_count() == 0 + assert sem.get_count_with_pending() == 0 + assert sem.get_pending_count() == 0 + assert sem.finished.is_set() + assert sem.locked() + + +async def test_context_manager(): + sem_value = 1 + sem = SingleUseSemaphore(sem_value) + + assert sem.get_count() == sem_value + assert sem.get_count_with_pending() == sem_value + assert sem.get_pending_count() == 0 + assert not sem.finished.is_set() + assert not sem.locked() + + async with sem: + assert sem.get_count() == 0 + assert sem.get_count_with_pending() == 1 + assert sem.get_pending_count() == 1 + assert not sem.finished.is_set() + assert sem.locked() + + assert sem.get_count() == 0 + assert sem.get_count_with_pending() == 0 + assert sem.get_pending_count() == 0 + assert sem.finished.is_set() + assert sem.locked() From a87f7763c89302461bcfd3fc13b096bd2d7f2972 Mon Sep 17 00:00:00 2001 From: approxit Date: Fri, 19 Jan 2024 16:37:11 +0100 Subject: [PATCH 04/20] finished? --- golem/managers/demand/refreshing.py | 2 +- golem/managers/proposal/plugins/new_buffer.py | 14 +- .../proposal/plugins/scoring/new_scoring.py | 52 ++++-- golem/utils/asyncio/__init__.py | 29 +++ golem/utils/{ => asyncio}/buffer.py | 103 +++++++---- golem/utils/{ => asyncio}/queue.py | 12 +- golem/utils/{ => asyncio}/semaphore.py | 0 golem/utils/{asyncio.py => asyncio/tasks.py} | 0 golem/utils/asyncio/waiter.py | 52 ++++++ tests/unit/utils/test_buffer.py | 168 +++++++++++++++--- .../unit/utils/test_error_reporting_queue.py | 2 +- tests/unit/utils/test_semaphore.py | 4 +- tests/unit/utils/test_waiter.py | 82 +++++++++ 13 files changed, 436 insertions(+), 84 deletions(-) create mode 100644 golem/utils/asyncio/__init__.py rename golem/utils/{ => asyncio}/buffer.py (77%) rename golem/utils/{ => asyncio}/queue.py (86%) rename golem/utils/{ => asyncio}/semaphore.py (100%) rename golem/utils/{asyncio.py => asyncio/tasks.py} (100%) create mode 100644 golem/utils/asyncio/waiter.py create mode 100644 tests/unit/utils/test_waiter.py diff --git a/golem/managers/demand/refreshing.py b/golem/managers/demand/refreshing.py index c0155086..4e72cada 100644 --- a/golem/managers/demand/refreshing.py +++ b/golem/managers/demand/refreshing.py @@ -11,8 +11,8 @@ from golem.resources import Allocation, Demand, Proposal from golem.resources.demand.demand_builder import DemandBuilder from golem.utils.asyncio import create_task_with_logging +from golem.utils.asyncio.queue import ErrorReportingQueue from golem.utils.logging import get_trace_id_name, trace_span -from golem.utils.queue import ErrorReportingQueue logger = logging.getLogger(__name__) diff --git a/golem/managers/proposal/plugins/new_buffer.py b/golem/managers/proposal/plugins/new_buffer.py index edd2f699..f6a53fb6 100644 --- a/golem/managers/proposal/plugins/new_buffer.py +++ b/golem/managers/proposal/plugins/new_buffer.py @@ -1,6 +1,6 @@ from golem.managers import ProposalManagerPlugin from golem.resources import Proposal -from golem.utils.buffer import BackgroundFeedBuffer, SimpleBuffer +from golem.utils.asyncio.buffer import BackgroundFeedBuffer, SimpleBuffer class Buffer(ProposalManagerPlugin): @@ -8,7 +8,7 @@ def __init__( self, min_size: int, max_size: int, - fill_concurrency_size: int = 1, + fill_concurrency_size=1, fill_at_start=False, ) -> None: self._min_size = min_size @@ -29,22 +29,22 @@ async def start(self) -> None: await self._buffer.start() if self._fill_at_start: - self._request_items() + await self._request_items() async def stop(self) -> None: await self._buffer.stop() - def _request_items(self): - self._buffer.request(self._max_size - self._buffer.size_with_requested()) + async def _request_items(self): + await self._buffer.request(self._max_size - self._buffer.size_with_requested()) async def get_proposal(self) -> Proposal: if not self._get_items_count(): - self._request_items() + await self._request_items() proposal = await self._get_item() if self._get_items_count() < self._min_size: - self._request_items() + await self._request_items() return proposal diff --git a/golem/managers/proposal/plugins/scoring/new_scoring.py b/golem/managers/proposal/plugins/scoring/new_scoring.py index 6cebffe8..0ddd48bb 100644 --- a/golem/managers/proposal/plugins/scoring/new_scoring.py +++ b/golem/managers/proposal/plugins/scoring/new_scoring.py @@ -1,10 +1,16 @@ import asyncio +import logging from datetime import timedelta +from typing import Optional from golem.managers import ProposalScoringMixin from golem.managers.proposal.plugins.new_buffer import Buffer as BufferPlugin from golem.resources import Proposal -from golem.utils.buffer import Buffer +from golem.utils.asyncio import create_task_with_logging +from golem.utils.asyncio.buffer import Buffer, SimpleBuffer +from golem.utils.logging import get_trace_id_name, trace_span + +logger = logging.getLogger(__name__) class ScoringBuffer(ProposalScoringMixin, BufferPlugin): @@ -13,22 +19,46 @@ def __init__(self, update_interval: timedelta = timedelta(seconds=10), *args, ** self._update_interval = update_interval - self._buffer_scored: Buffer[Proposal] = Buffer() - self._lock = asyncio.Lock() + self._buffer_scored: Buffer[Proposal] = SimpleBuffer() + self._background_loop_task: Optional[asyncio.Task] = None + + @trace_span() + async def start(self) -> None: + await super().start() + + self._background_loop_task = create_task_with_logging( + self._background_loop(), trace_id=get_trace_id_name(self, "background-loop") + ) + + @trace_span() + async def stop(self) -> None: + await super().stop() + + if self._background_loop_task is not None: + self._background_loop_task.cancel() + self._background_loop_task = None async def _background_loop(self) -> None: while True: - self._buffer.wait_for_any_items() + logger.debug("Waiting for any items to score...") + await self._buffer.wait_for_any_items() + logger.debug("Waiting for any items to score done, items are available for scoring") + + logger.debug(f"Waiting for more items up to {self._update_interval}...") + items = await self._buffer.get_all_requested(self._update_interval) + logger.debug(f"Waiting for more items done, {len(items)} new items will be scored") + + items.extend(await self._buffer_scored.get_all()) + + logger.debug(f"Scoring total {len(items)} items...") + + scored_items = await self.do_scoring(items) + await self._buffer_scored.put_all([proposal for _, proposal in scored_items]) - async with self._lock: - items = await self._buffer.get_all_requested(self._update_interval) - items.extend(await self._buffer_scored.get_all()) - scored_items = await self.do_scoring(items) - await self._buffer_scored.put_all([proposal for _, proposal in scored_items]) + logger.debug(f"Scoring total {len(items)} items done") async def _get_item(self) -> Proposal: - async with self._lock: - return await self._buffer_scored.get() + return await self._buffer_scored.get() def _get_items_count(self) -> int: return super()._get_items_count() + self._buffer_scored.size() diff --git a/golem/utils/asyncio/__init__.py b/golem/utils/asyncio/__init__.py new file mode 100644 index 00000000..bd8798b7 --- /dev/null +++ b/golem/utils/asyncio/__init__.py @@ -0,0 +1,29 @@ +from golem.utils.asyncio.buffer import ( + BackgroundFeedBuffer, + Buffer, + ComposableBuffer, + ExpirableBuffer, + SimpleBuffer, +) +from golem.utils.asyncio.queue import ErrorReportingQueue +from golem.utils.asyncio.semaphore import SingleUseSemaphore +from golem.utils.asyncio.tasks import ( + cancel_and_await, + cancel_and_await_many, + create_task_with_logging, +) +from golem.utils.asyncio.waiter import Waiter + +__all__ = ( + "BackgroundFeedBuffer", + "Buffer", + "ComposableBuffer", + "ExpirableBuffer", + "SimpleBuffer", + "ErrorReportingQueue", + "SingleUseSemaphore", + "cancel_and_await", + "cancel_and_await_many", + "create_task_with_logging", + "Waiter", +) diff --git a/golem/utils/buffer.py b/golem/utils/asyncio/buffer.py similarity index 77% rename from golem/utils/buffer.py rename to golem/utils/asyncio/buffer.py index e84e9583..52fb1f1f 100644 --- a/golem/utils/buffer.py +++ b/golem/utils/asyncio/buffer.py @@ -16,9 +16,10 @@ TypeVar, ) -from golem.utils.asyncio import cancel_and_await_many, create_task_with_logging +from golem.utils.asyncio.semaphore import SingleUseSemaphore +from golem.utils.asyncio.tasks import cancel_and_await_many, create_task_with_logging +from golem.utils.asyncio.waiter import Waiter from golem.utils.logging import get_trace_id_name, trace_span -from golem.utils.semaphore import SingleUseSemaphore TItem = TypeVar("TItem") @@ -26,25 +27,31 @@ class Buffer(ABC, Generic[TItem]): + """Interface class for object similar to `asyncio.Queue` but with more control over its items.""" + @abstractmethod def size(self) -> int: """Return number of items stored in buffer.""" - ... @abstractmethod async def wait_for_any_items(self) -> None: """Wait until any items are stored in buffer.""" - ... @abstractmethod async def get(self) -> TItem: - """Await, remove and return left-most item stored in buffer.""" - ... + """Await, remove and return left-most item stored in buffer. + + If `.set_exception()` was previously called, exception will be raised only if buffer is empty. + """ @abstractmethod async def get_all(self) -> MutableSequence[TItem]: - """Remove and return all items stored in buffer.""" - ... + """Remove and return all items stored in buffer. + + Note that this method will not await for any items if buffer is empty. + + If `.set_exception()` was previously called, exception will be raised only if buffer is empty. + """ @abstractmethod async def put(self, item: TItem) -> None: @@ -52,7 +59,6 @@ async def put(self, item: TItem) -> None: Duplicates are supported. """ - ... @abstractmethod async def put_all(self, items: Sequence[TItem]) -> None: @@ -60,12 +66,17 @@ async def put_all(self, items: Sequence[TItem]) -> None: Duplicates are supported. """ - ... @abstractmethod async def remove(self, item: TItem) -> None: """Remove first occurrence of item from buffer or raise `ValueError` if not found.""" - ... + + @abstractmethod + def set_exception(self, exc: BaseException) -> None: + """Set exception that will be raised while trying to `.get()`/`.get_all()` item from empty buffer.""" + + def reset_exception(self) -> None: + """Reset exception that was previously set by calling `.set_exception()`.""" class ComposableBuffer(Buffer[TItem]): @@ -95,65 +106,76 @@ async def put_all(self, items: Sequence[TItem]) -> None: async def remove(self, item: TItem) -> None: await self._buffer.remove(item) + def set_exception(self, exc: BaseException) -> None: + self._buffer.set_exception(exc) + + def reset_exception(self) -> None: + self._buffer.reset_exception() + class SimpleBuffer(Buffer[TItem]): """Most basic implementation of Buffer interface.""" def __init__(self, items: Optional[Sequence[TItem]] = None): self._items = list(items) if items is not None else [] + self._error: Optional[BaseException] = None - self._have_items = ( - asyncio.Event() - ) # TODO: collections of future-object waiters instead of event - - if self.size(): - self._have_items.set() + self._waiter = Waiter() def size(self) -> int: return len(self._items) + @trace_span() async def wait_for_any_items(self) -> None: - while not self.size(): - await self._have_items.wait() + await self._waiter.wait_for(lambda: bool(self.size() or self._error)) + @trace_span() async def get(self) -> TItem: - await self.wait_for_any_items() + if not self.size(): + if self._error: + raise self._error + else: + await self.wait_for_any_items() - item = self._items.pop(0) + if not self.size() and self._error: + raise self._error - if not self.size(): - self._have_items.clear() + item = self._items.pop(0) return item async def get_all(self) -> MutableSequence[TItem]: + if not self._items and self._error: + raise self._error + items = self._items[:] self._items.clear() - self._have_items.clear() + return items async def put(self, item: TItem) -> None: self._items.append(item) - self._have_items.set() + self._waiter.notify() async def put_all(self, items: Sequence[TItem]) -> None: self._items.clear() self._items.extend(items[:]) - if self.size(): - self._have_items.set() - else: - self._have_items.clear() + self._waiter.notify(len(items)) async def remove(self, item: TItem) -> None: self._items.remove(item) - if not self.size(): - self._have_items.clear() + def set_exception(self, exc: BaseException) -> None: + self._error = exc + self._waiter.notify() + + def reset_exception(self) -> None: + self._error = None class ExpirableBuffer(ComposableBuffer[TItem]): - """Composable that adds option to expire item after some time. + """Composable `Buffer` that adds option to expire item after some time. Items that are already in provided buffer will not expire. """ @@ -185,7 +207,9 @@ def _add_expiration_task_for_item(self, item: TItem) -> None: loop = asyncio.get_event_loop() self._expiration_handlers[id(item)].append( - loop.call_later(expiration.total_seconds(), lambda: asyncio.create_task(self._expire_item(item))) + loop.call_later( + expiration.total_seconds(), lambda: asyncio.create_task(self._expire_item(item)) + ) ) async def _remove_expiration_handler_for_item(self, item: TItem) -> None: @@ -238,6 +262,7 @@ async def remove(self, item: TItem) -> None: await super().remove(item) await self._remove_expiration_handler_for_item(item) + @trace_span() async def _expire_item(self, item: TItem) -> None: await self.remove(item) @@ -246,7 +271,7 @@ async def _expire_item(self, item: TItem) -> None: class BackgroundFeedBuffer(ComposableBuffer[TItem]): - """Composable that adds option to feed buffer in background task. + """Composable `Buffer` that adds option to feed buffer in background task. Background feed will happen only if background tasks are started by calling `.start()` and items were requested by `.request()`. @@ -297,15 +322,25 @@ def is_started(self) -> bool: async def _worker_loop(self): while True: + logger.debug("Waiting for item request...") + async with self._workers_semaphore: + logger.debug("Waiting for item request done") + + logger.debug("Adding new item...") + item = await self._feed_func() await self.put(item) + logger.debug("Adding new item done") + async def request(self, count: int) -> None: """Request given number of items to be filled in background.""" await self._workers_semaphore.increase(count) + logger.debug(f"Requested {count} items") + def size_with_requested(self) -> int: """Return sum of items stored in buffer and requested to be filled.""" return self.size() + self._workers_semaphore.get_count_with_pending() diff --git a/golem/utils/queue.py b/golem/utils/asyncio/queue.py similarity index 86% rename from golem/utils/queue.py rename to golem/utils/asyncio/queue.py index 6cc1ece6..98c982d5 100644 --- a/golem/utils/queue.py +++ b/golem/utils/asyncio/queue.py @@ -1,10 +1,10 @@ import asyncio from typing import Generic, Optional, TypeVar -QueueItem = TypeVar("QueueItem") +TQueueItem = TypeVar("TQueueItem") -class ErrorReportingQueue(asyncio.Queue, Generic[QueueItem]): +class ErrorReportingQueue(asyncio.Queue, Generic[TQueueItem]): """Asyncio Queue that enables exceptions to be passed to consumers from the feeding code.""" def __init__(self, *args, **kwargs): @@ -13,7 +13,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def get_nowait(self) -> QueueItem: + def get_nowait(self) -> TQueueItem: """Perform a regular `get_nowait` if there are items in the queue. Otherwise, if an exception had been signalled, raise the exception. @@ -22,7 +22,7 @@ def get_nowait(self) -> QueueItem: raise self._error return super().get_nowait() - async def get(self) -> QueueItem: + async def get(self) -> TQueueItem: """Perform a regular, waiting `get` but raise an exception if happens while waiting. If there had been items in the queue, @@ -43,10 +43,10 @@ async def get(self) -> QueueItem: assert self._error raise self._error - async def put(self, item: QueueItem): + async def put(self, item: TQueueItem): await super().put(item) - def put_nowait(self, item: QueueItem): + def put_nowait(self, item: TQueueItem): super().put_nowait(item) def set_exception(self, exc: BaseException): diff --git a/golem/utils/semaphore.py b/golem/utils/asyncio/semaphore.py similarity index 100% rename from golem/utils/semaphore.py rename to golem/utils/asyncio/semaphore.py diff --git a/golem/utils/asyncio.py b/golem/utils/asyncio/tasks.py similarity index 100% rename from golem/utils/asyncio.py rename to golem/utils/asyncio/tasks.py diff --git a/golem/utils/asyncio/waiter.py b/golem/utils/asyncio/waiter.py new file mode 100644 index 00000000..09c08326 --- /dev/null +++ b/golem/utils/asyncio/waiter.py @@ -0,0 +1,52 @@ +import asyncio +import collections +from typing import Callable + + +class Waiter: + """Class similar to `asyncio.Event` but valueless and with notify interface similar to `asyncio.Condition`.""" + + def __init__(self) -> None: + self._waiters: collections.deque[asyncio.Future] = collections.deque() + self._loop = asyncio.get_event_loop() + + async def wait_for(self, predicate: Callable[[], bool]) -> None: + """Check if predicate is true and return immediately, or await until it becomes true.""" + result = predicate() + + while not result: + await self._wait() + result = predicate() + + if not result: + # as last `._wait()` call woken us up but predicate is still false, lets give a + # chance another `.wait_for()` pending call. + self._notify_first() + + def _notify_first(self) -> None: + try: + first_waiter = self._waiters[0] + except IndexError: + return + + if not first_waiter.done(): + first_waiter.set_result(None) + + async def _wait(self): + future = self._loop.create_future() + self._waiters.append(future) + try: + await future + finally: + self._waiters.remove(future) + + def notify(self, count=1) -> None: + """Notify given amount of `.wait_for()` calls to check its predicates.""" + notified = 0 + for waiter in self._waiters: + if count <= notified: + break + + if not waiter.done(): + waiter.set_result(None) + notified += 1 diff --git a/tests/unit/utils/test_buffer.py b/tests/unit/utils/test_buffer.py index 8a10a604..b0f4cf9d 100644 --- a/tests/unit/utils/test_buffer.py +++ b/tests/unit/utils/test_buffer.py @@ -1,11 +1,9 @@ -from datetime import timedelta - import asyncio +from datetime import timedelta import pytest -from golem.utils.buffer import SimpleBuffer, Buffer, ExpirableBuffer, BackgroundFeedBuffer - +from golem.utils.asyncio.buffer import BackgroundFeedBuffer, Buffer, ExpirableBuffer, SimpleBuffer @pytest.fixture @@ -65,10 +63,11 @@ async def test_simple_buffer_get_waits_for_items(): assert buffer.size() == 0 _, pending = await asyncio.wait([buffer.get()], timeout=0.1) - get_task = pending.pop() - if not get_task: + if not pending: pytest.fail("Getting empty buffer somehow finished instead of blocking!") + get_task = pending.pop() + item_put = object() await buffer.put(item_put) @@ -76,6 +75,24 @@ async def test_simple_buffer_get_waits_for_items(): assert item_get == item_put + # concurrent wait + get_task1 = asyncio.create_task(buffer.get()) + get_task2 = asyncio.create_task(buffer.get()) + + await asyncio.sleep(0.1) + + await buffer.put(item_put) + + done, pending = await asyncio.wait([get_task1, get_task2], timeout=0.1) + if len(done) != len(pending): + pytest.fail(f"One of the tasks should not block at this point!") + + await buffer.put(item_put) + + await asyncio.sleep(0.1) + + await asyncio.wait_for(pending.pop(), timeout=0.1) + async def test_simple_buffer_keeps_item_order(): buffer = SimpleBuffer(["a", "b", "c"]) @@ -108,6 +125,107 @@ async def test_simple_buffer_keeps_shallow_copy_of_items(): assert buffer.size() == 3 +async def test_simple_buffer_exceptions(): + buffer = SimpleBuffer() + assert buffer.size() == 0 + + exc = ZeroDivisionError() + + buffer.set_exception(exc) + + # should raise when exception set and no items + with pytest.raises(ZeroDivisionError): + await buffer.get() + + with pytest.raises(ZeroDivisionError): + await buffer.get_all() + + await buffer.put("a") + + # should not raise when exception set and with items + assert await buffer.get_all() == ["a"] + + # should raise when exception set and items were cleared + with pytest.raises(ZeroDivisionError): + await buffer.get_all() + + buffer.reset_exception() + + assert await buffer.get_all() == [] + + try: + await asyncio.wait_for(buffer.get(), timeout=0.1) + except asyncio.TimeoutError: + pass + else: + pytest.fail("Getting empty buffer somehow finished instead of blocking!") + + get_task = asyncio.create_task(buffer.get()) + + await asyncio.sleep(0.1) + + buffer.set_exception(exc) + + await asyncio.sleep(0.1) + + with pytest.raises(ZeroDivisionError): + get_task.result() + + +async def test_simple_buffer_wait_for_any_items(): + buffer = SimpleBuffer() + assert buffer.size() == 0 + + # should block on empty + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(buffer.wait_for_any_items(), timeout=0.1) + + # should unblock on item + wait_task = asyncio.create_task(buffer.wait_for_any_items()) + + await asyncio.sleep(0.1) + + assert not wait_task.done() + + await buffer.put("a") + + await asyncio.sleep(0.1) + + assert wait_task.done() + + # should not block a long time on item + await asyncio.wait_for(buffer.wait_for_any_items(), timeout=0.1) + + buffer.set_exception(ZeroDivisionError()) + + # should not block a long time on item with exception + await asyncio.wait_for(buffer.wait_for_any_items(), timeout=0.1) + + await buffer.get() + + # should not block a long time with exception and no items + await asyncio.wait_for(buffer.wait_for_any_items(), timeout=0.1) + + buffer.reset_exception() + + # should block on after exception reset and no items + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(buffer.wait_for_any_items(), timeout=0.1) + + # should unblock on item + wait_task = asyncio.create_task(buffer.wait_for_any_items()) + + await asyncio.sleep(0.1) + + assert not wait_task.done() + + buffer.set_exception(ZeroDivisionError()) + + await asyncio.sleep(0.1) + + assert wait_task.done() + + async def test_expirable_buffer_is_not_expiring_initial_items(mocked_buffer, mocker): expire_after = timedelta(seconds=0.1) ExpirableBuffer( @@ -119,29 +237,28 @@ async def test_expirable_buffer_is_not_expiring_initial_items(mocked_buffer, moc mocked_buffer.remove.assert_not_called() + async def test_expirable_buffer_is_not_expiring_items_with_none_expiration(mocked_buffer, mocker): - expiration_func = mocker.Mock(side_effect=[ - timedelta(seconds=0.1), - None, - timedelta(seconds=0.1) - ]) + expiration_func = mocker.Mock( + side_effect=[timedelta(seconds=0.1), None, timedelta(seconds=0.1)] + ) buffer = ExpirableBuffer( mocked_buffer, expiration_func, ) - await buffer.put('a') - await buffer.put('b') - await buffer.put('c') + await buffer.put("a") + await buffer.put("b") + await buffer.put("c") await asyncio.sleep(0.2) - assert mocker.call('a') in mocked_buffer.remove.mock_calls - assert mocker.call('c') in mocked_buffer.remove.mock_calls + assert mocker.call("a") in mocked_buffer.remove.mock_calls + assert mocker.call("c") in mocked_buffer.remove.mock_calls - mocked_buffer.get.return_value = 'b' + mocked_buffer.get.return_value = "b" - assert await buffer.get() == 'b' + assert await buffer.get() == "b" async def test_expirable_buffer_can_expire_items_with_put_get(mocked_buffer, mocker): @@ -188,7 +305,7 @@ async def test_expirable_buffer_can_expire_items_with_put_all_get_all(mocked_buf lambda i: expire_after, on_expire, ) - items_put_all = ['a', 'b' , 'c'] + items_put_all = ["a", "b", "c"] await buffer.put_all(items_put_all) mocked_buffer.put_all.assert_called_with(items_put_all) @@ -282,6 +399,7 @@ async def test_background_feed_buffer_request(mocked_buffer, mocker): await buffer.stop() + async def test_background_feed_buffer_get_all_requested(mocked_buffer, mocker, event_loop): timeout = timedelta(seconds=0.1) item = object() @@ -301,7 +419,9 @@ async def test_background_feed_buffer_get_all_requested(mocked_buffer, mocker, e assert await buffer.get_all_requested(timeout) == [] time_after_wait = event_loop.time() - assert time_after_wait - time_before_wait < timeout.total_seconds(), 'get_all_requested seems to wait for the deadline instead of retuning fast' + assert ( + time_after_wait - time_before_wait < timeout.total_seconds() + ), "get_all_requested seems to wait for the deadline instead of retuning fast" # await buffer.request(1) @@ -321,7 +441,9 @@ async def test_background_feed_buffer_get_all_requested(mocked_buffer, mocker, e assert await buffer.get_all_requested(timeout) == [] time_after_wait = event_loop.time() - assert timeout.total_seconds() <= time_after_wait - time_before_wait, 'get_all_requested seems to not wait to the deadline' + assert ( + timeout.total_seconds() <= time_after_wait - time_before_wait + ), "get_all_requested seems to not wait to the deadline" # await feed_queue.put(item) @@ -341,7 +463,9 @@ async def test_background_feed_buffer_get_all_requested(mocked_buffer, mocker, e assert await buffer.get_all_requested(timeout) == [item] time_after_wait = event_loop.time() - assert time_after_wait - time_before_wait < timeout.total_seconds(), 'get_all_requested seems to wait for the deadline instead of retuning fast' + assert ( + time_after_wait - time_before_wait < timeout.total_seconds() + ), "get_all_requested seems to wait for the deadline instead of retuning fast" # await buffer.stop() diff --git a/tests/unit/utils/test_error_reporting_queue.py b/tests/unit/utils/test_error_reporting_queue.py index 2dba6c7e..8e7f0c67 100644 --- a/tests/unit/utils/test_error_reporting_queue.py +++ b/tests/unit/utils/test_error_reporting_queue.py @@ -2,7 +2,7 @@ import pytest -from golem.utils.queue import ErrorReportingQueue +from golem.utils.asyncio.queue import ErrorReportingQueue class SomeException(Exception): diff --git a/tests/unit/utils/test_semaphore.py b/tests/unit/utils/test_semaphore.py index d4c9b365..3ef09c78 100644 --- a/tests/unit/utils/test_semaphore.py +++ b/tests/unit/utils/test_semaphore.py @@ -2,10 +2,10 @@ import pytest -from golem.utils.semaphore import SingleUseSemaphore +from golem.utils.asyncio.semaphore import SingleUseSemaphore -def test_creation(): +async def test_creation(): sem = SingleUseSemaphore() assert sem.get_count() == 0 diff --git a/tests/unit/utils/test_waiter.py b/tests/unit/utils/test_waiter.py new file mode 100644 index 00000000..486da920 --- /dev/null +++ b/tests/unit/utils/test_waiter.py @@ -0,0 +1,82 @@ +import asyncio + +import pytest + +from golem.utils.asyncio import Waiter + + +async def test_waiter_single_wait_for(): + waiter = Waiter() + value = 0 + + wait_task = asyncio.create_task(waiter.wait_for(lambda: value == 1)) + + await asyncio.sleep(0.1) + + assert not wait_task.done() + + waiter.notify() + + await asyncio.sleep(0.1) + + assert not wait_task.done() + + value = 1 + waiter.notify() + + await asyncio.wait_for(wait_task, timeout=0.1) + + +async def test_waiter_multiple_wait_for(): + waiter = Waiter() + value = 0 + + wait_task1 = asyncio.create_task(waiter.wait_for(lambda: value == 1)) + wait_task2 = asyncio.create_task(waiter.wait_for(lambda: value == 1)) + wait_task3 = asyncio.create_task(waiter.wait_for(lambda: value == 1)) + + await asyncio.sleep(0.1) + + assert not wait_task1.done() + + value = 1 + waiter.notify() + + await asyncio.wait_for(wait_task1, timeout=0.1) + + done, _ = await asyncio.wait([wait_task2, wait_task3], timeout=0.1) + + if done: + pytest.fail("Somehow some tasks finished too early!") + + waiter.notify(2) + + _, pending = await asyncio.wait([wait_task2, wait_task3], timeout=0.1) + + if pending: + pytest.fail("Somehow some tasks not finished!") + + +async def test_waiter_will_renotify_when_predicate_was_false(): + waiter = Waiter() + value = 0 + + wait_task1 = asyncio.create_task(waiter.wait_for(lambda: value == 2)) + wait_task2 = asyncio.create_task(waiter.wait_for(lambda: value == 1)) + + await asyncio.sleep(0.1) + assert not wait_task2.done() and not wait_task1.done() + + value = 1 + waiter.notify() + + await asyncio.sleep(0.1) + + assert wait_task2.done(), "Task2 should be done at this point!" + assert not wait_task1.done(), "Task1 should still block at this point!" + + value = 2 + + waiter.notify() + + await asyncio.wait_for(wait_task1, timeout=0.1) From 7f8cdd4f84b9a0e01261f9aa8cead09bad56c6c6 Mon Sep 17 00:00:00 2001 From: approxit Date: Fri, 19 Jan 2024 17:44:34 +0100 Subject: [PATCH 05/20] at least finished to pr draft --- golem/managers/demand/refreshing.py | 3 +- golem/managers/proposal/plugins/buffer.py | 141 ++++++------------ golem/managers/proposal/plugins/new_buffer.py | 55 ------- .../proposal/plugins/scoring/new_scoring.py | 64 -------- .../plugins/scoring/scoring_buffer.py | 128 ++++------------ golem/utils/asyncio/__init__.py | 4 +- golem/utils/asyncio/buffer.py | 44 +++--- golem/utils/asyncio/tasks.py | 2 +- tests/unit/utils/test_buffer.py | 8 +- .../unit/utils/test_error_reporting_queue.py | 2 +- 10 files changed, 109 insertions(+), 342 deletions(-) delete mode 100644 golem/managers/proposal/plugins/new_buffer.py delete mode 100644 golem/managers/proposal/plugins/scoring/new_scoring.py diff --git a/golem/managers/demand/refreshing.py b/golem/managers/demand/refreshing.py index 4e72cada..6289a074 100644 --- a/golem/managers/demand/refreshing.py +++ b/golem/managers/demand/refreshing.py @@ -10,8 +10,7 @@ from golem.payload import defaults as payload_defaults from golem.resources import Allocation, Demand, Proposal from golem.resources.demand.demand_builder import DemandBuilder -from golem.utils.asyncio import create_task_with_logging -from golem.utils.asyncio.queue import ErrorReportingQueue +from golem.utils.asyncio import ErrorReportingQueue, create_task_with_logging from golem.utils.logging import get_trace_id_name, trace_span logger = logging.getLogger(__name__) diff --git a/golem/managers/proposal/plugins/buffer.py b/golem/managers/proposal/plugins/buffer.py index f4ec6f6d..7236284f 100644 --- a/golem/managers/proposal/plugins/buffer.py +++ b/golem/managers/proposal/plugins/buffer.py @@ -1,12 +1,9 @@ -import asyncio import logging -from asyncio import Queue -from typing import List from golem.managers import ProposalManagerPlugin from golem.resources import Proposal -from golem.utils.asyncio import create_task_with_logging -from golem.utils.logging import get_trace_id_name, trace_span +from golem.utils.asyncio.buffer import BackgroundFillBuffer, SimpleBuffer +from golem.utils.logging import trace_span logger = logging.getLogger(__name__) @@ -16,115 +13,67 @@ def __init__( self, min_size: int, max_size: int, - fill_concurrency_size: int = 1, + fill_concurrency_size=1, fill_at_start=False, - ): + ) -> None: self._min_size = min_size self._max_size = max_size self._fill_concurrency_size = fill_concurrency_size self._fill_at_start = fill_at_start - self._get_proposal_lock = asyncio.Lock() - self._worker_tasks: List[asyncio.Task] = [] - # As there is no awaitable counter, queue with dummy values is used - self._requests_queue: Queue[int] = asyncio.Queue() - self._requests_pending_count = 0 - self._buffered: List[Proposal] = [] - self._buffered_condition = asyncio.Condition() - self._is_started = False + self._buffer = BackgroundFillBuffer( + buffer=SimpleBuffer(), + fill_func=self._call_feed_func, + fill_concurrency_size=self._fill_concurrency_size, + ) - @trace_span(show_results=True) - async def get_proposal(self) -> Proposal: - if not self.is_started(): - raise RuntimeError("Not started!") - - async with self._get_proposal_lock: - return await self._get_item() + async def _call_feed_func(self) -> Proposal: + return await self._get_proposal() @trace_span() async def start(self) -> None: - if self.is_started(): - raise RuntimeError("Already started!") - - for i in range(self._fill_concurrency_size): - self._worker_tasks.append( - create_task_with_logging( - self._worker_loop(), trace_id=get_trace_id_name(self, f"worker-{i}") - ) - ) + await self._buffer.start() if self._fill_at_start: - self._handle_item_requests() - - self._is_started = True + await self._request_items() @trace_span() async def stop(self) -> None: - if not self.is_started(): - raise RuntimeError("Already stopped!") - - for worker_task in self._worker_tasks: - worker_task.cancel() - - try: - await worker_task - except asyncio.CancelledError: - pass - - self._worker_tasks.clear() - self._is_started = False - - self._requests_queue = asyncio.Queue() - self._requests_pending_count = 0 - - def is_started(self) -> bool: - return self._is_started - - async def _worker_loop(self): - while True: - await self._wait_for_any_item_requests() - - item = await self._get_proposal() - - async with self._buffered_condition: - self._buffered.append(item) + await self._buffer.stop() - self._buffered_condition.notify_all() + async def _request_items(self): + count = self._max_size - self._buffer.size_with_requested() + await self._buffer.request(count) - self._requests_queue.task_done() - self._requests_pending_count -= 1 + logger.debug("Requested %s items", count) - @trace_span() - async def _wait_for_any_item_requests(self) -> None: - await self._requests_queue.get() - - async def _get_item(self): - async with self._buffered_condition: - if self._get_items_count() == 0: # This supports lazy (not at start) buffer filling - logger.debug("No items to get, requesting fill") - self._handle_item_requests() - - logger.debug("Waiting for any item to pick...") - - await self._buffered_condition.wait_for(lambda: 0 < len(self._buffered)) - item = self._buffered.pop() - - # Check if we need to request any additional items - if self._get_items_count() < self._min_size: - self._handle_item_requests() - - return item - - def _get_items_count(self): - return len(self._buffered) + self._requests_pending_count - - @trace_span() - def _handle_item_requests(self) -> None: - items_to_request = self._max_size - self._get_items_count() + @trace_span(show_results=True) + async def get_proposal(self) -> Proposal: + if not self._get_items_count(): + logger.debug("No items to get, requesting fill") + await self._request_items() + + proposal = await self._get_item() + + items_count = self._get_items_count() + if items_count < self._min_size: + logger.debug( + "Target items count `%s` is below min size `%d`, requesting fill", + items_count, + self._min_size, + ) + await self._request_items() + else: + logger.debug( + "Target items count `%s` is not below min size `%d`, requesting fill not needed", + items_count, + self._min_size, + ) - for i in range(items_to_request): - self._requests_queue.put_nowait(i) + return proposal - self._requests_pending_count += items_to_request + async def _get_item(self) -> Proposal: + return await self._buffer.get() - logger.debug("Requested %d items", items_to_request) + def _get_items_count(self) -> int: + return self._buffer.size_with_requested() diff --git a/golem/managers/proposal/plugins/new_buffer.py b/golem/managers/proposal/plugins/new_buffer.py deleted file mode 100644 index f6a53fb6..00000000 --- a/golem/managers/proposal/plugins/new_buffer.py +++ /dev/null @@ -1,55 +0,0 @@ -from golem.managers import ProposalManagerPlugin -from golem.resources import Proposal -from golem.utils.asyncio.buffer import BackgroundFeedBuffer, SimpleBuffer - - -class Buffer(ProposalManagerPlugin): - def __init__( - self, - min_size: int, - max_size: int, - fill_concurrency_size=1, - fill_at_start=False, - ) -> None: - self._min_size = min_size - self._max_size = max_size - self._fill_concurrency_size = fill_concurrency_size - self._fill_at_start = fill_at_start - - self._buffer = BackgroundFeedBuffer( - buffer=SimpleBuffer(), - feed_func=self._call_feed_func, - feed_concurrency_size=self._fill_concurrency_size, - ) - - async def _call_feed_func(self) -> Proposal: - return await self._get_proposal() - - async def start(self) -> None: - await self._buffer.start() - - if self._fill_at_start: - await self._request_items() - - async def stop(self) -> None: - await self._buffer.stop() - - async def _request_items(self): - await self._buffer.request(self._max_size - self._buffer.size_with_requested()) - - async def get_proposal(self) -> Proposal: - if not self._get_items_count(): - await self._request_items() - - proposal = await self._get_item() - - if self._get_items_count() < self._min_size: - await self._request_items() - - return proposal - - async def _get_item(self) -> Proposal: - return await self._buffer.get() - - def _get_items_count(self) -> int: - return self._buffer.size_with_requested() diff --git a/golem/managers/proposal/plugins/scoring/new_scoring.py b/golem/managers/proposal/plugins/scoring/new_scoring.py deleted file mode 100644 index 0ddd48bb..00000000 --- a/golem/managers/proposal/plugins/scoring/new_scoring.py +++ /dev/null @@ -1,64 +0,0 @@ -import asyncio -import logging -from datetime import timedelta -from typing import Optional - -from golem.managers import ProposalScoringMixin -from golem.managers.proposal.plugins.new_buffer import Buffer as BufferPlugin -from golem.resources import Proposal -from golem.utils.asyncio import create_task_with_logging -from golem.utils.asyncio.buffer import Buffer, SimpleBuffer -from golem.utils.logging import get_trace_id_name, trace_span - -logger = logging.getLogger(__name__) - - -class ScoringBuffer(ProposalScoringMixin, BufferPlugin): - def __init__(self, update_interval: timedelta = timedelta(seconds=10), *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - - self._update_interval = update_interval - - self._buffer_scored: Buffer[Proposal] = SimpleBuffer() - self._background_loop_task: Optional[asyncio.Task] = None - - @trace_span() - async def start(self) -> None: - await super().start() - - self._background_loop_task = create_task_with_logging( - self._background_loop(), trace_id=get_trace_id_name(self, "background-loop") - ) - - @trace_span() - async def stop(self) -> None: - await super().stop() - - if self._background_loop_task is not None: - self._background_loop_task.cancel() - self._background_loop_task = None - - async def _background_loop(self) -> None: - while True: - logger.debug("Waiting for any items to score...") - await self._buffer.wait_for_any_items() - logger.debug("Waiting for any items to score done, items are available for scoring") - - logger.debug(f"Waiting for more items up to {self._update_interval}...") - items = await self._buffer.get_all_requested(self._update_interval) - logger.debug(f"Waiting for more items done, {len(items)} new items will be scored") - - items.extend(await self._buffer_scored.get_all()) - - logger.debug(f"Scoring total {len(items)} items...") - - scored_items = await self.do_scoring(items) - await self._buffer_scored.put_all([proposal for _, proposal in scored_items]) - - logger.debug(f"Scoring total {len(items)} items done") - - async def _get_item(self) -> Proposal: - return await self._buffer_scored.get() - - def _get_items_count(self) -> int: - return super()._get_items_count() + self._buffer_scored.size() diff --git a/golem/managers/proposal/plugins/scoring/scoring_buffer.py b/golem/managers/proposal/plugins/scoring/scoring_buffer.py index 850b6727..26d53f13 100644 --- a/golem/managers/proposal/plugins/scoring/scoring_buffer.py +++ b/golem/managers/proposal/plugins/scoring/scoring_buffer.py @@ -1,33 +1,26 @@ import asyncio import logging from datetime import timedelta -from typing import List, Optional +from typing import Optional -from golem.managers.proposal.plugins.buffer import Buffer -from golem.managers.proposal.plugins.scoring.mixins import ProposalScoringMixin +from golem.managers.proposal.plugins.buffer import Buffer as BufferPlugin +from golem.managers.proposal.plugins.scoring import ProposalScoringMixin from golem.resources import Proposal -from golem.utils.asyncio import create_task_with_logging +from golem.utils.asyncio import Buffer, SimpleBuffer, create_task_with_logging from golem.utils.logging import get_trace_id_name, trace_span logger = logging.getLogger(__name__) -class ScoringBuffer(ProposalScoringMixin, Buffer): - def __init__( - self, - update_interval: timedelta = timedelta(seconds=10), - *args, - **kwargs, - ): +class ScoringBuffer(ProposalScoringMixin, BufferPlugin): + def __init__(self, update_interval: timedelta = timedelta(seconds=10), *args, **kwargs) -> None: super().__init__(*args, **kwargs) - self._scored: List[Proposal] = [] - self._scored_condition = asyncio.Condition() - self._background_loop_task: Optional[asyncio.Task] = None - - self._items_requested_event = asyncio.Event() self._update_interval = update_interval + self._buffer_scored: Buffer[Proposal] = SimpleBuffer() + self._background_loop_task: Optional[asyncio.Task] = None + @trace_span() async def start(self) -> None: await super().start() @@ -44,86 +37,27 @@ async def stop(self) -> None: self._background_loop_task.cancel() self._background_loop_task = None - def is_started(self) -> bool: - is_started = super().is_started() - return ( - is_started - and self._background_loop_task is not None - and not self._background_loop_task.done() - ) + async def _background_loop(self) -> None: + while True: + logger.debug("Waiting for any items to score...") + await self._buffer.wait_for_any_items() + logger.debug("Waiting for any items to score done, items are available for scoring") - async def _background_loop(self): - # TODO: Rework this flow to: - # - Use composition instead of inheritance - # - Get rid of loop checking state for proper state - # - composable Buffer should expose async context that exposes its items - # and items asyncio.Condition + logger.debug("Waiting for more items up to %s...", self._update_interval) + items = await self._buffer.get_all_requested(self._update_interval) + logger.debug("Waiting for more items done, %d new items will be scored", len(items)) - while True: - logger.debug("Waiting for any requested items...") - await self._items_requested_event.wait() - logger.debug("Waiting for any requested items done") - - keep_retrying = True - while keep_retrying: # FIXME: Needs refactor too - logger.debug("Waiting up to %s for all requested items...", self._update_interval) - try: - await asyncio.wait_for( - self._requests_queue.join(), - timeout=self._update_interval.total_seconds(), - ) - except asyncio.TimeoutError: - logger.debug( - "Waiting up to %s for all requested items failed with timeout, trying to" - "update anyways...", - self._update_interval, - ) - else: - logger.debug( - "Waiting up to %s for all requested items done", self._update_interval - ) - keep_retrying = False - self._items_requested_event.clear() - - async with self._buffered_condition: - if not self._buffered: - logger.debug("Update not needed, as no items were buffered in the meantime") - continue - - items_to_score = self._buffered[:] - self._buffered.clear() - - async with self._scored_condition: - scored_proposals = await self.do_scoring(self._scored + items_to_score) - self._scored = [proposal for _, proposal in scored_proposals] - - logger.debug("Item collection updated %s", scored_proposals) - - self._scored_condition.notify_all() - - async def _get_item(self): - async with self._scored_condition: - if self._get_items_count() == 0: # This supports lazy (not at start) buffer filling - logger.debug("No items to get, requesting fill") - self._handle_item_requests() - - logger.debug("Waiting for any item to pick...") - - await self._scored_condition.wait_for(lambda: 0 < len(self._scored)) - item = self._scored.pop() - - # Check if we need to request any additional items - if self._get_items_count() < self._min_size: - self._handle_item_requests() - - return item - - def _handle_item_requests(self) -> None: - super()._handle_item_requests() - - self._items_requested_event.set() - - def _get_items_count(self): - items_count = super()._get_items_count() - - return items_count + len(self._scored) + items.extend(await self._buffer_scored.get_all()) + + logger.debug("Scoring total %d items...", len(items)) + + scored_items = await self.do_scoring(items) + await self._buffer_scored.put_all([proposal for _, proposal in scored_items]) + + logger.debug("Scoring total %d items done", len(items)) + + async def _get_item(self) -> Proposal: + return await self._buffer_scored.get() + + def _get_items_count(self) -> int: + return super()._get_items_count() + self._buffer_scored.size() diff --git a/golem/utils/asyncio/__init__.py b/golem/utils/asyncio/__init__.py index bd8798b7..e7a60662 100644 --- a/golem/utils/asyncio/__init__.py +++ b/golem/utils/asyncio/__init__.py @@ -1,5 +1,5 @@ from golem.utils.asyncio.buffer import ( - BackgroundFeedBuffer, + BackgroundFillBuffer, Buffer, ComposableBuffer, ExpirableBuffer, @@ -15,7 +15,7 @@ from golem.utils.asyncio.waiter import Waiter __all__ = ( - "BackgroundFeedBuffer", + "BackgroundFillBuffer", "Buffer", "ComposableBuffer", "ExpirableBuffer", diff --git a/golem/utils/asyncio/buffer.py b/golem/utils/asyncio/buffer.py index 52fb1f1f..887e07eb 100644 --- a/golem/utils/asyncio/buffer.py +++ b/golem/utils/asyncio/buffer.py @@ -270,23 +270,23 @@ async def _expire_item(self, item: TItem) -> None: await self._on_expiration_func(item) -class BackgroundFeedBuffer(ComposableBuffer[TItem]): - """Composable `Buffer` that adds option to feed buffer in background task. +class BackgroundFillBuffer(ComposableBuffer[TItem]): + """Composable `Buffer` that adds option to fill buffer in background tasks. - Background feed will happen only if background tasks are started by calling `.start()` + Background fill will happen only if background tasks are started by calling `.start()` and items were requested by `.request()`. """ def __init__( self, buffer: Buffer[TItem], - feed_func: Callable[[], Awaitable[TItem]], - feed_concurrency_size=1, + fill_func: Callable[[], Awaitable[TItem]], + fill_concurrency_size=1, ): super().__init__(buffer) - self._feed_func = feed_func - self._feed_concurrency_size = feed_concurrency_size + self._fill_func = fill_func + self._fill_concurrency_size = fill_concurrency_size self._is_started = False self._worker_tasks: List[asyncio.Task] = [] @@ -297,7 +297,7 @@ async def start(self) -> None: if self.is_started(): raise RuntimeError("Already started!") - for i in range(self._feed_concurrency_size): + for i in range(self._fill_concurrency_size): self._worker_tasks.append( create_task_with_logging( self._worker_loop(), trace_id=get_trace_id_name(self, f"worker-{i}") @@ -322,36 +322,40 @@ def is_started(self) -> bool: async def _worker_loop(self): while True: - logger.debug("Waiting for item request...") + logger.debug("Waiting for fill item request...") async with self._workers_semaphore: - logger.debug("Waiting for item request done") + logger.debug("Waiting for fill item request done") logger.debug("Adding new item...") - item = await self._feed_func() + item = await self._fill_func() await self.put(item) - logger.debug("Adding new item done") + logger.debug("Adding new item done with total of %d items in buffer", self.size()) async def request(self, count: int) -> None: """Request given number of items to be filled in background.""" + await self._workers_semaphore.increase(count) - logger.debug(f"Requested {count} items") + logger.debug("Requested %d items to be filled in background", count) def size_with_requested(self) -> int: - """Return sum of items stored in buffer and requested to be filled.""" + """Return sum of item count stored in buffer and requested to be filled.""" + return self.size() + self._workers_semaphore.get_count_with_pending() async def get_all_requested(self, deadline: timedelta) -> MutableSequence[TItem]: """Await for all requested items with given deadline, then remove and return all items stored in buffer.""" - try: - await asyncio.wait_for( - self._workers_semaphore.finished.wait(), deadline.total_seconds() - ) - except asyncio.TimeoutError: - pass + + if not self._workers_semaphore.finished.is_set(): + try: + await asyncio.wait_for( + self._workers_semaphore.finished.wait(), deadline.total_seconds() + ) + except asyncio.TimeoutError: + pass return await self.get_all() diff --git a/golem/utils/asyncio/tasks.py b/golem/utils/asyncio/tasks.py index 5285a315..1660be42 100644 --- a/golem/utils/asyncio/tasks.py +++ b/golem/utils/asyncio/tasks.py @@ -17,7 +17,7 @@ def create_task_with_logging(coro, *, trace_id: Optional[str] = None) -> asyncio else: task_name = task.get_name() - logger.debug(f"Task `{task_name}` created") + logger.debug("Task `%s` created", task_name) return task diff --git a/tests/unit/utils/test_buffer.py b/tests/unit/utils/test_buffer.py index b0f4cf9d..762cdd35 100644 --- a/tests/unit/utils/test_buffer.py +++ b/tests/unit/utils/test_buffer.py @@ -3,7 +3,7 @@ import pytest -from golem.utils.asyncio.buffer import BackgroundFeedBuffer, Buffer, ExpirableBuffer, SimpleBuffer +from golem.utils.asyncio.buffer import BackgroundFillBuffer, Buffer, ExpirableBuffer, SimpleBuffer @pytest.fixture @@ -339,7 +339,7 @@ async def test_expirable_buffer_can_expire_items_with_put_all_get_all(mocked_buf async def test_background_feed_buffer_start_stop(mocked_buffer, mocker): feed_func = mocker.AsyncMock() - buffer = BackgroundFeedBuffer( + buffer = BackgroundFillBuffer( mocked_buffer, feed_func, ) @@ -373,7 +373,7 @@ async def test_background_feed_buffer_request(mocked_buffer, mocker): item = object() feed_queue = asyncio.Queue() feed_func = mocker.AsyncMock(wraps=feed_queue.get) - buffer = BackgroundFeedBuffer( + buffer = BackgroundFillBuffer( mocked_buffer, feed_func, ) @@ -405,7 +405,7 @@ async def test_background_feed_buffer_get_all_requested(mocked_buffer, mocker, e item = object() feed_queue = asyncio.Queue() feed_func = mocker.AsyncMock(wraps=feed_queue.get) - buffer = BackgroundFeedBuffer( + buffer = BackgroundFillBuffer( mocked_buffer, feed_func, ) diff --git a/tests/unit/utils/test_error_reporting_queue.py b/tests/unit/utils/test_error_reporting_queue.py index 8e7f0c67..318782df 100644 --- a/tests/unit/utils/test_error_reporting_queue.py +++ b/tests/unit/utils/test_error_reporting_queue.py @@ -2,7 +2,7 @@ import pytest -from golem.utils.asyncio.queue import ErrorReportingQueue +from golem.utils.asyncio import ErrorReportingQueue class SomeException(Exception): From a0e72e511bf2743af31814ee051b9aa7b5845ff0 Mon Sep 17 00:00:00 2001 From: approxit Date: Wed, 24 Jan 2024 12:52:39 +0100 Subject: [PATCH 06/20] logs and fixes --- .../task_api_draft/task_api/activity_pool.py | 3 ++- golem/event_bus/in_memory/event_bus.py | 4 +-- golem/managers/mixins.py | 4 +-- golem/managers/proposal/plugins/buffer.py | 25 ++++++++++++++----- .../plugins/scoring/scoring_buffer.py | 7 ++++-- golem/managers/work/plugins.py | 10 ++------ golem/utils/asyncio/buffer.py | 10 ++++++++ golem/utils/asyncio/queue.py | 4 ++- golem/utils/asyncio/semaphore.py | 4 +-- tests/unit/utils/test_semaphore.py | 15 ++++++++++- 10 files changed, 61 insertions(+), 25 deletions(-) diff --git a/examples/task_api_draft/task_api/activity_pool.py b/examples/task_api_draft/task_api/activity_pool.py index df814fa3..9ce2fa95 100644 --- a/examples/task_api_draft/task_api/activity_pool.py +++ b/examples/task_api_draft/task_api/activity_pool.py @@ -4,6 +4,7 @@ from golem.pipeline import InputStreamExhausted from golem.resources import Activity +from golem.utils.asyncio import cancel_and_await class ActivityPool: @@ -94,7 +95,7 @@ async def _activity_destroyed_cleanup( return await activity.wait_destroyed() - manager_task.cancel() + await cancel_and_await(manager_task) async def _get_next_idle_activity( self, activity_stream: AsyncIterator[Union[Activity, Awaitable[Activity]]] diff --git a/golem/event_bus/in_memory/event_bus.py b/golem/event_bus/in_memory/event_bus.py index 052f2959..c4294a35 100644 --- a/golem/event_bus/in_memory/event_bus.py +++ b/golem/event_bus/in_memory/event_bus.py @@ -5,7 +5,7 @@ from typing import Awaitable, Callable, DefaultDict, List, Optional, Tuple, Type from golem.event_bus.base import Event, EventBus, EventBusError, TEvent -from golem.utils.asyncio import create_task_with_logging +from golem.utils.asyncio import create_task_with_logging, cancel_and_await from golem.utils.logging import get_trace_id_name, trace_span logger = logging.getLogger(__name__) @@ -49,7 +49,7 @@ async def stop(self): await self._event_queue.join() if self._process_event_queue_loop_task is not None: - self._process_event_queue_loop_task.cancel() + await cancel_and_await(self._process_event_queue_loop_task) self._process_event_queue_loop_task = None @trace_span(show_results=True) diff --git a/golem/managers/mixins.py b/golem/managers/mixins.py index bc1170a2..2f33ff2d 100644 --- a/golem/managers/mixins.py +++ b/golem/managers/mixins.py @@ -3,7 +3,7 @@ from typing import Generic, List, Optional, Sequence from golem.managers.base import ManagerException, TPlugin -from golem.utils.asyncio import create_task_with_logging +from golem.utils.asyncio import create_task_with_logging, cancel_and_await from golem.utils.logging import get_trace_id_name, trace_span logger = logging.getLogger(__name__) @@ -29,7 +29,7 @@ async def stop(self) -> None: raise ManagerException("Already stopped!") if self._background_loop_task is not None: - self._background_loop_task.cancel() + await cancel_and_await(self._background_loop_task) self._background_loop_task = None def is_started(self) -> bool: diff --git a/golem/managers/proposal/plugins/buffer.py b/golem/managers/proposal/plugins/buffer.py index 7236284f..a29b5b61 100644 --- a/golem/managers/proposal/plugins/buffer.py +++ b/golem/managers/proposal/plugins/buffer.py @@ -25,11 +25,20 @@ def __init__( buffer=SimpleBuffer(), fill_func=self._call_feed_func, fill_concurrency_size=self._fill_concurrency_size, + on_added_callback=self._on_added_callback ) async def _call_feed_func(self) -> Proposal: return await self._get_proposal() + async def _on_added_callback(self): + count_current = self._buffer.size() + count_with_requested = self._buffer.size_with_requested() + pending = count_with_requested - count_current + + logger.debug("Item added, having %d items, and %d pending, target %d", count_current, pending, self._max_size) + + @trace_span() async def start(self) -> None: await self._buffer.start() @@ -42,10 +51,14 @@ async def stop(self) -> None: await self._buffer.stop() async def _request_items(self): - count = self._max_size - self._buffer.size_with_requested() - await self._buffer.request(count) + count_current = self._buffer.size() + count_with_requested = self._buffer.size_with_requested() + requested = self._max_size - count_with_requested + + logger.debug("Having %d items, and %d already requested, requesting additional %d items to match target %d", count_current, count_with_requested - count_current, requested, self._max_size) + + await self._buffer.request(requested) - logger.debug("Requested %s items", count) @trace_span(show_results=True) async def get_proposal(self) -> Proposal: @@ -58,14 +71,14 @@ async def get_proposal(self) -> Proposal: items_count = self._get_items_count() if items_count < self._min_size: logger.debug( - "Target items count `%s` is below min size `%d`, requesting fill", + "Items count is now `%s` which is below min size `%d`, requesting fill", items_count, self._min_size, ) await self._request_items() else: logger.debug( - "Target items count `%s` is not below min size `%d`, requesting fill not needed", + "Target items is now `%s` which is not below min size `%d`, requesting fill not needed", items_count, self._min_size, ) @@ -76,4 +89,4 @@ async def _get_item(self) -> Proposal: return await self._buffer.get() def _get_items_count(self) -> int: - return self._buffer.size_with_requested() + return self._buffer.size() diff --git a/golem/managers/proposal/plugins/scoring/scoring_buffer.py b/golem/managers/proposal/plugins/scoring/scoring_buffer.py index 26d53f13..cd52041d 100644 --- a/golem/managers/proposal/plugins/scoring/scoring_buffer.py +++ b/golem/managers/proposal/plugins/scoring/scoring_buffer.py @@ -6,7 +6,7 @@ from golem.managers.proposal.plugins.buffer import Buffer as BufferPlugin from golem.managers.proposal.plugins.scoring import ProposalScoringMixin from golem.resources import Proposal -from golem.utils.asyncio import Buffer, SimpleBuffer, create_task_with_logging +from golem.utils.asyncio import Buffer, SimpleBuffer, create_task_with_logging, cancel_and_await from golem.utils.logging import get_trace_id_name, trace_span logger = logging.getLogger(__name__) @@ -34,9 +34,12 @@ async def stop(self) -> None: await super().stop() if self._background_loop_task is not None: - self._background_loop_task.cancel() + await cancel_and_await(self._background_loop_task) self._background_loop_task = None + async def _on_added_callback(self): + pass # explicit no-op + async def _background_loop(self) -> None: while True: logger.debug("Waiting for any items to score...") diff --git a/golem/managers/work/plugins.py b/golem/managers/work/plugins.py index 5612b9bb..bfeab014 100644 --- a/golem/managers/work/plugins.py +++ b/golem/managers/work/plugins.py @@ -10,6 +10,7 @@ WorkManagerPlugin, WorkResult, ) +from golem.utils.asyncio import cancel_and_await_many logger = logging.getLogger(__name__) @@ -70,14 +71,7 @@ async def wrapper(work: Work) -> WorkResult: tasks, return_when=asyncio.FIRST_COMPLETED ) - for task in tasks_pending: - task.cancel() - - for task in tasks_pending: - try: - await task - except asyncio.CancelledError: - pass + await cancel_and_await_many(tasks_pending) return tasks_done.pop().result() diff --git a/golem/utils/asyncio/buffer.py b/golem/utils/asyncio/buffer.py index 887e07eb..a81bc766 100644 --- a/golem/utils/asyncio/buffer.py +++ b/golem/utils/asyncio/buffer.py @@ -282,11 +282,13 @@ def __init__( buffer: Buffer[TItem], fill_func: Callable[[], Awaitable[TItem]], fill_concurrency_size=1, + on_added_callback: Optional[Callable[[], Awaitable[None]]] = None, ): super().__init__(buffer) self._fill_func = fill_func self._fill_concurrency_size = fill_concurrency_size + self._on_added_callback = on_added_callback self._is_started = False self._worker_tasks: List[asyncio.Task] = [] @@ -333,6 +335,9 @@ async def _worker_loop(self): await self.put(item) + if self._on_added_callback is not None: + await self._on_added_callback() + logger.debug("Adding new item done with total of %d items in buffer", self.size()) async def request(self, count: int) -> None: @@ -351,6 +356,7 @@ async def get_all_requested(self, deadline: timedelta) -> MutableSequence[TItem] """Await for all requested items with given deadline, then remove and return all items stored in buffer.""" if not self._workers_semaphore.finished.is_set(): + logger.debug("semaphore %d is not finished, waiting...", self._workers_semaphore.get_pending_count()) try: await asyncio.wait_for( self._workers_semaphore.finished.wait(), deadline.total_seconds() @@ -358,4 +364,8 @@ async def get_all_requested(self, deadline: timedelta) -> MutableSequence[TItem] except asyncio.TimeoutError: pass + logger.debug("semaphore is not finished, waiting done") + else: + logger.debug("semaphore %d is finished, not waiting", self._workers_semaphore.get_pending_count()) + return await self.get_all() diff --git a/golem/utils/asyncio/queue.py b/golem/utils/asyncio/queue.py index 98c982d5..b5b612c9 100644 --- a/golem/utils/asyncio/queue.py +++ b/golem/utils/asyncio/queue.py @@ -1,6 +1,8 @@ import asyncio from typing import Generic, Optional, TypeVar +from golem.utils.asyncio.tasks import cancel_and_await_many + TQueueItem = TypeVar("TQueueItem") @@ -35,7 +37,7 @@ async def get(self) -> TQueueItem: [error_task, get_task], return_when=asyncio.FIRST_COMPLETED ) - [t.cancel() for t in pending] + await cancel_and_await_many(pending) if get_task in done: return await get_task diff --git a/golem/utils/asyncio/semaphore.py b/golem/utils/asyncio/semaphore.py index 44bfd970..fb2bc9bd 100644 --- a/golem/utils/asyncio/semaphore.py +++ b/golem/utils/asyncio/semaphore.py @@ -12,7 +12,7 @@ def __init__(self, value=0): self._condition = asyncio.Condition() self.finished = asyncio.Event() - if not self._value: + if self.locked(): self.finished.set() async def __aenter__(self): @@ -37,7 +37,7 @@ def release(self): self._pending -= 1 - if not self._pending: + if self.locked(): self.finished.set() async def increase(self, value: int) -> None: diff --git a/tests/unit/utils/test_semaphore.py b/tests/unit/utils/test_semaphore.py index 3ef09c78..b9af0029 100644 --- a/tests/unit/utils/test_semaphore.py +++ b/tests/unit/utils/test_semaphore.py @@ -151,7 +151,7 @@ async def test_release(): async def test_context_manager(): - sem_value = 1 + sem_value = 2 sem = SingleUseSemaphore(sem_value) assert sem.get_count() == sem_value @@ -160,6 +160,19 @@ async def test_context_manager(): assert not sem.finished.is_set() assert not sem.locked() + async with sem: + assert sem.get_count() == 1 + assert sem.get_count_with_pending() == sem_value + assert sem.get_pending_count() == 1 + assert not sem.finished.is_set() + assert not sem.locked() + + assert sem.get_count() == 1 + assert sem.get_count_with_pending() == 1 + assert sem.get_pending_count() == 0 + assert not sem.finished.is_set() + assert not sem.locked() + async with sem: assert sem.get_count() == 0 assert sem.get_count_with_pending() == 1 From a912cf207646fcecd5e364149b5978ad30d01689 Mon Sep 17 00:00:00 2001 From: approxit Date: Wed, 24 Jan 2024 15:56:05 +0100 Subject: [PATCH 07/20] finishing touches? --- examples/managers/blender/blender.py | 4 +- examples/managers/flexible_negotiation.py | 6 +-- examples/managers/mid_agreement_payments.py | 4 +- examples/managers/proposal_plugins.py | 8 ++-- examples/task_api_draft/examples/yacat.py | 2 +- .../task_api_draft/task_api/execute_tasks.py | 2 +- golem/event_bus/in_memory/event_bus.py | 2 +- golem/managers/__init__.py | 8 ++-- golem/managers/mixins.py | 2 +- golem/managers/proposal/__init__.py | 8 ++-- golem/managers/proposal/plugins/__init__.py | 8 ++-- golem/managers/proposal/plugins/buffer.py | 27 +++++++++---- .../proposal/plugins/scoring/__init__.py | 4 +- .../plugins/scoring/scoring_buffer.py | 6 +-- golem/utils/asyncio/buffer.py | 38 +++++++++---------- golem/utils/asyncio/tasks.py | 4 +- golem/utils/asyncio/waiter.py | 7 ++-- tests/unit/utils/test_buffer.py | 16 ++++---- 18 files changed, 83 insertions(+), 73 deletions(-) diff --git a/examples/managers/blender/blender.py b/examples/managers/blender/blender.py index bbb8e338..5554ecaa 100644 --- a/examples/managers/blender/blender.py +++ b/examples/managers/blender/blender.py @@ -16,7 +16,7 @@ PaymentPlatformNegotiator, PoolActivityManager, RefreshingDemandManager, - ScoringBuffer, + ScoringBufferPlugin, WorkContext, WorkResult, retry, @@ -58,7 +58,7 @@ async def run_on_golem( demand_manager.get_initial_proposal, plugins=[ NegotiatingPlugin(proposal_negotiators=negotiators), - ScoringBuffer( + ScoringBufferPlugin( min_size=3, max_size=5, fill_concurrency_size=3, proposal_scorers=scorers ), ], diff --git a/examples/managers/flexible_negotiation.py b/examples/managers/flexible_negotiation.py index c173b259..3e7bdeef 100644 --- a/examples/managers/flexible_negotiation.py +++ b/examples/managers/flexible_negotiation.py @@ -3,7 +3,7 @@ from golem.managers import ( BlacklistProviderIdPlugin, - Buffer, + BufferPlugin, DefaultAgreementManager, DefaultProposalManager, NegotiatingPlugin, @@ -48,14 +48,14 @@ async def main(): golem, demand_manager.get_initial_proposal, plugins=[ - Buffer( + BufferPlugin( min_size=10, max_size=1000, fill_concurrency_size=5, ), BlacklistProviderIdPlugin(BLACKLISTED_PROVIDERS), NegotiatingPlugin(proposal_negotiators=[PaymentPlatformNegotiator()]), - Buffer( + BufferPlugin( min_size=3, max_size=5, fill_concurrency_size=3, diff --git a/examples/managers/mid_agreement_payments.py b/examples/managers/mid_agreement_payments.py index 72f75087..7c99aaa9 100644 --- a/examples/managers/mid_agreement_payments.py +++ b/examples/managers/mid_agreement_payments.py @@ -3,7 +3,7 @@ from datetime import datetime, timedelta from golem.managers import ( - Buffer, + BufferPlugin, DefaultAgreementManager, DefaultProposalManager, MidAgreementPaymentsNegotiator, @@ -91,7 +91,7 @@ async def main(): ), ] ), - Buffer( + BufferPlugin( min_size=1, max_size=4, fill_concurrency_size=2, diff --git a/examples/managers/proposal_plugins.py b/examples/managers/proposal_plugins.py index 552781f7..53f85d22 100644 --- a/examples/managers/proposal_plugins.py +++ b/examples/managers/proposal_plugins.py @@ -5,7 +5,7 @@ from golem.managers import ( BlacklistProviderIdPlugin, - Buffer, + BufferPlugin, DefaultAgreementManager, DefaultProposalManager, LinearAverageCostPricing, @@ -19,7 +19,7 @@ RefreshingDemandManager, RejectIfCostsExceeds, RejectProposal, - ScoringBuffer, + ScoringBufferPlugin, SequentialWorkManager, WorkContext, WorkResult, @@ -78,7 +78,7 @@ async def main(): golem, demand_manager.get_initial_proposal, plugins=[ - Buffer( + BufferPlugin( min_size=10, max_size=1000, ), @@ -101,7 +101,7 @@ async def main(): else None, ) ), - ScoringBuffer( + ScoringBufferPlugin( min_size=3, max_size=5, fill_concurrency_size=3, diff --git a/examples/task_api_draft/examples/yacat.py b/examples/task_api_draft/examples/yacat.py index 4854d222..257c4636 100644 --- a/examples/task_api_draft/examples/yacat.py +++ b/examples/task_api_draft/examples/yacat.py @@ -205,7 +205,7 @@ async def main() -> None: Zip(async_queue_aiter(tasks_queue)), # type: ignore # mypy, why? Map( lambda activity, task: task.execute(activity), # type: ignore - on_exception=close_agreement_repeat_task, + on_exception=close_agreement_repeat_task, # type: ignore ), Buffer(size=MAX_WORKERS * 2), ): diff --git a/examples/task_api_draft/task_api/execute_tasks.py b/examples/task_api_draft/task_api/execute_tasks.py index 1618912e..d56f77ca 100644 --- a/examples/task_api_draft/task_api/execute_tasks.py +++ b/examples/task_api_draft/task_api/execute_tasks.py @@ -67,7 +67,7 @@ def get_chain( Zip(task_stream), Map( execute_task, # type: ignore[arg-type] - on_exception=close_agreement_repeat_task(task_stream), + on_exception=close_agreement_repeat_task(task_stream), # type: ignore[arg-type] ), Buffer(size=max_workers), ) diff --git a/golem/event_bus/in_memory/event_bus.py b/golem/event_bus/in_memory/event_bus.py index c4294a35..2b7bf85e 100644 --- a/golem/event_bus/in_memory/event_bus.py +++ b/golem/event_bus/in_memory/event_bus.py @@ -5,7 +5,7 @@ from typing import Awaitable, Callable, DefaultDict, List, Optional, Tuple, Type from golem.event_bus.base import Event, EventBus, EventBusError, TEvent -from golem.utils.asyncio import create_task_with_logging, cancel_and_await +from golem.utils.asyncio import cancel_and_await, create_task_with_logging from golem.utils.logging import get_trace_id_name, trace_span logger = logging.getLogger(__name__) diff --git a/golem/managers/__init__.py b/golem/managers/__init__.py index 470f15da..96006298 100644 --- a/golem/managers/__init__.py +++ b/golem/managers/__init__.py @@ -23,7 +23,7 @@ from golem.managers.payment import PayAllPaymentManager from golem.managers.proposal import ( BlacklistProviderIdPlugin, - Buffer, + BufferPlugin, DefaultProposalManager, LinearAverageCostPricing, LinearCoeffsCost, @@ -37,7 +37,7 @@ ProposalScoringMixin, RandomScore, RejectIfCostsExceeds, - ScoringBuffer, + ScoringBufferPlugin, ) from golem.managers.work import ( ConcurrentWorkManager, @@ -73,7 +73,7 @@ "PayAllPaymentManager", "DefaultProposalManager", "BlacklistProviderIdPlugin", - "Buffer", + "BufferPlugin", "PaymentPlatformNegotiator", "MidAgreementPaymentsNegotiator", "NegotiatingPlugin", @@ -86,7 +86,7 @@ "LinearCoeffsCost", "PropertyValueLerpScore", "RandomScore", - "ScoringBuffer", + "ScoringBufferPlugin", "SequentialWorkManager", "ConcurrentWorkManager", "WorkManagerPluginsMixin", diff --git a/golem/managers/mixins.py b/golem/managers/mixins.py index 2f33ff2d..e6793e20 100644 --- a/golem/managers/mixins.py +++ b/golem/managers/mixins.py @@ -3,7 +3,7 @@ from typing import Generic, List, Optional, Sequence from golem.managers.base import ManagerException, TPlugin -from golem.utils.asyncio import create_task_with_logging, cancel_and_await +from golem.utils.asyncio import cancel_and_await, create_task_with_logging from golem.utils.logging import get_trace_id_name, trace_span logger = logging.getLogger(__name__) diff --git a/golem/managers/proposal/__init__.py b/golem/managers/proposal/__init__.py index 819b4014..49918f2b 100644 --- a/golem/managers/proposal/__init__.py +++ b/golem/managers/proposal/__init__.py @@ -1,7 +1,7 @@ from golem.managers.proposal.default import DefaultProposalManager from golem.managers.proposal.plugins import ( BlacklistProviderIdPlugin, - Buffer, + BufferPlugin, LinearAverageCostPricing, LinearCoeffsCost, LinearPerCpuAverageCostPricing, @@ -14,13 +14,13 @@ ProposalScoringMixin, RandomScore, RejectIfCostsExceeds, - ScoringBuffer, + ScoringBufferPlugin, ) __all__ = ( "DefaultProposalManager", "BlacklistProviderIdPlugin", - "Buffer", + "BufferPlugin", "PaymentPlatformNegotiator", "MidAgreementPaymentsNegotiator", "NegotiatingPlugin", @@ -33,5 +33,5 @@ "LinearPerCpuCoeffsCost", "PropertyValueLerpScore", "RandomScore", - "ScoringBuffer", + "ScoringBufferPlugin", ) diff --git a/golem/managers/proposal/plugins/__init__.py b/golem/managers/proposal/plugins/__init__.py index 5cad6098..900eb32a 100644 --- a/golem/managers/proposal/plugins/__init__.py +++ b/golem/managers/proposal/plugins/__init__.py @@ -1,5 +1,5 @@ from golem.managers.proposal.plugins.blacklist import BlacklistProviderIdPlugin -from golem.managers.proposal.plugins.buffer import Buffer +from golem.managers.proposal.plugins.buffer import BufferPlugin from golem.managers.proposal.plugins.linear_coeffs import LinearCoeffsCost, LinearPerCpuCoeffsCost from golem.managers.proposal.plugins.negotiating import ( MidAgreementPaymentsNegotiator, @@ -14,12 +14,12 @@ PropertyValueLerpScore, ProposalScoringMixin, RandomScore, - ScoringBuffer, + ScoringBufferPlugin, ) __all__ = ( "BlacklistProviderIdPlugin", - "Buffer", + "BufferPlugin", "PaymentPlatformNegotiator", "MidAgreementPaymentsNegotiator", "NegotiatingPlugin", @@ -32,5 +32,5 @@ "LinearPerCpuCoeffsCost", "PropertyValueLerpScore", "RandomScore", - "ScoringBuffer", + "ScoringBufferPlugin", ) diff --git a/golem/managers/proposal/plugins/buffer.py b/golem/managers/proposal/plugins/buffer.py index a29b5b61..0c81d7dd 100644 --- a/golem/managers/proposal/plugins/buffer.py +++ b/golem/managers/proposal/plugins/buffer.py @@ -8,7 +8,7 @@ logger = logging.getLogger(__name__) -class Buffer(ProposalManagerPlugin): +class BufferPlugin(ProposalManagerPlugin): def __init__( self, min_size: int, @@ -21,11 +21,11 @@ def __init__( self._fill_concurrency_size = fill_concurrency_size self._fill_at_start = fill_at_start - self._buffer = BackgroundFillBuffer( + self._buffer: BackgroundFillBuffer[Proposal] = BackgroundFillBuffer( buffer=SimpleBuffer(), fill_func=self._call_feed_func, fill_concurrency_size=self._fill_concurrency_size, - on_added_callback=self._on_added_callback + on_added_callback=self._on_added_callback, ) async def _call_feed_func(self) -> Proposal: @@ -36,8 +36,12 @@ async def _on_added_callback(self): count_with_requested = self._buffer.size_with_requested() pending = count_with_requested - count_current - logger.debug("Item added, having %d items, and %d pending, target %d", count_current, pending, self._max_size) - + logger.debug( + "Item added, having %d items, and %d pending, target %d", + count_current, + pending, + self._max_size, + ) @trace_span() async def start(self) -> None: @@ -55,11 +59,17 @@ async def _request_items(self): count_with_requested = self._buffer.size_with_requested() requested = self._max_size - count_with_requested - logger.debug("Having %d items, and %d already requested, requesting additional %d items to match target %d", count_current, count_with_requested - count_current, requested, self._max_size) + logger.debug( + "Having %d items, and %d already requested, requesting additional %d items to match" + " target %d", + count_current, + count_with_requested - count_current, + requested, + self._max_size, + ) await self._buffer.request(requested) - @trace_span(show_results=True) async def get_proposal(self) -> Proposal: if not self._get_items_count(): @@ -78,7 +88,8 @@ async def get_proposal(self) -> Proposal: await self._request_items() else: logger.debug( - "Target items is now `%s` which is not below min size `%d`, requesting fill not needed", + "Target items is now `%s` which is not below min size `%d`, requesting fill not" + " needed", items_count, self._min_size, ) diff --git a/golem/managers/proposal/plugins/scoring/__init__.py b/golem/managers/proposal/plugins/scoring/__init__.py index e8209595..b29ac289 100644 --- a/golem/managers/proposal/plugins/scoring/__init__.py +++ b/golem/managers/proposal/plugins/scoring/__init__.py @@ -6,7 +6,7 @@ ) from golem.managers.proposal.plugins.scoring.property_value_lerp import PropertyValueLerpScore from golem.managers.proposal.plugins.scoring.random import RandomScore -from golem.managers.proposal.plugins.scoring.scoring_buffer import ScoringBuffer +from golem.managers.proposal.plugins.scoring.scoring_buffer import ScoringBufferPlugin __all__ = ( "MapScore", @@ -15,5 +15,5 @@ "LinearPerCpuAverageCostPricing", "PropertyValueLerpScore", "RandomScore", - "ScoringBuffer", + "ScoringBufferPlugin", ) diff --git a/golem/managers/proposal/plugins/scoring/scoring_buffer.py b/golem/managers/proposal/plugins/scoring/scoring_buffer.py index cd52041d..7e49069b 100644 --- a/golem/managers/proposal/plugins/scoring/scoring_buffer.py +++ b/golem/managers/proposal/plugins/scoring/scoring_buffer.py @@ -3,16 +3,16 @@ from datetime import timedelta from typing import Optional -from golem.managers.proposal.plugins.buffer import Buffer as BufferPlugin +from golem.managers.proposal.plugins.buffer import BufferPlugin as BufferPlugin from golem.managers.proposal.plugins.scoring import ProposalScoringMixin from golem.resources import Proposal -from golem.utils.asyncio import Buffer, SimpleBuffer, create_task_with_logging, cancel_and_await +from golem.utils.asyncio import Buffer, SimpleBuffer, cancel_and_await, create_task_with_logging from golem.utils.logging import get_trace_id_name, trace_span logger = logging.getLogger(__name__) -class ScoringBuffer(ProposalScoringMixin, BufferPlugin): +class ScoringBufferPlugin(ProposalScoringMixin, BufferPlugin): def __init__(self, update_interval: timedelta = timedelta(seconds=10), *args, **kwargs) -> None: super().__init__(*args, **kwargs) diff --git a/golem/utils/asyncio/buffer.py b/golem/utils/asyncio/buffer.py index a81bc766..f4e61a1f 100644 --- a/golem/utils/asyncio/buffer.py +++ b/golem/utils/asyncio/buffer.py @@ -8,7 +8,6 @@ Callable, Dict, Generic, - Iterable, List, MutableSequence, Optional, @@ -27,7 +26,8 @@ class Buffer(ABC, Generic[TItem]): - """Interface class for object similar to `asyncio.Queue` but with more control over its items.""" + """Interface class for object similar to `asyncio.Queue` but with more control over its \ + items.""" @abstractmethod def size(self) -> int: @@ -41,7 +41,8 @@ async def wait_for_any_items(self) -> None: async def get(self) -> TItem: """Await, remove and return left-most item stored in buffer. - If `.set_exception()` was previously called, exception will be raised only if buffer is empty. + If `.set_exception()` was previously called, exception will be raised only if buffer + is empty. """ @abstractmethod @@ -50,7 +51,8 @@ async def get_all(self) -> MutableSequence[TItem]: Note that this method will not await for any items if buffer is empty. - If `.set_exception()` was previously called, exception will be raised only if buffer is empty. + If `.set_exception()` was previously called, exception will be raised only if buffer + is empty. """ @abstractmethod @@ -73,14 +75,16 @@ async def remove(self, item: TItem) -> None: @abstractmethod def set_exception(self, exc: BaseException) -> None: - """Set exception that will be raised while trying to `.get()`/`.get_all()` item from empty buffer.""" + """Set exception that will be raised while trying to `.get()`/`.get_all()` item from \ + empty buffer.""" def reset_exception(self) -> None: """Reset exception that was previously set by calling `.set_exception()`.""" class ComposableBuffer(Buffer[TItem]): - """Utility class for composable/stackable buffer implementations to help with calling underlying buffer.""" + """Utility class for composable/stackable buffer implementations to help with calling \ + underlying buffer.""" def __init__(self, buffer: Buffer[TItem]): self._buffer = buffer @@ -180,8 +184,8 @@ class ExpirableBuffer(ComposableBuffer[TItem]): Items that are already in provided buffer will not expire. """ - # TODO: Optimisation options: Use single expiration task that wakes up to expire the earliest item, - # then check next earliest item and sleep to it and repeat + # TODO: Optimisation options: Use single expiration task that wakes up to expire the earliest + # item, then check next earliest item and sleep to it and repeat def __init__( self, @@ -231,7 +235,7 @@ async def _remove_all_expiration_handlers(self) -> None: self._expiration_handlers.clear() - async def get(self) -> Iterable[TItem]: + async def get(self) -> TItem: async with self._lock: item = await super().get() await self._remove_expiration_handler_for_item(item) @@ -335,28 +339,26 @@ async def _worker_loop(self): await self.put(item) - if self._on_added_callback is not None: - await self._on_added_callback() + if self._on_added_callback is not None: + await self._on_added_callback() - logger.debug("Adding new item done with total of %d items in buffer", self.size()) + logger.debug("Adding new item done with total of %d items in buffer", self.size()) async def request(self, count: int) -> None: """Request given number of items to be filled in background.""" await self._workers_semaphore.increase(count) - logger.debug("Requested %d items to be filled in background", count) - def size_with_requested(self) -> int: """Return sum of item count stored in buffer and requested to be filled.""" return self.size() + self._workers_semaphore.get_count_with_pending() async def get_all_requested(self, deadline: timedelta) -> MutableSequence[TItem]: - """Await for all requested items with given deadline, then remove and return all items stored in buffer.""" + """Await for all requested items with given deadline, then remove and return all items \ + stored in buffer.""" if not self._workers_semaphore.finished.is_set(): - logger.debug("semaphore %d is not finished, waiting...", self._workers_semaphore.get_pending_count()) try: await asyncio.wait_for( self._workers_semaphore.finished.wait(), deadline.total_seconds() @@ -364,8 +366,4 @@ async def get_all_requested(self, deadline: timedelta) -> MutableSequence[TItem] except asyncio.TimeoutError: pass - logger.debug("semaphore is not finished, waiting done") - else: - logger.debug("semaphore %d is finished, not waiting", self._workers_semaphore.get_pending_count()) - return await self.get_all() diff --git a/golem/utils/asyncio/tasks.py b/golem/utils/asyncio/tasks.py index 1660be42..039d3c86 100644 --- a/golem/utils/asyncio/tasks.py +++ b/golem/utils/asyncio/tasks.py @@ -1,7 +1,7 @@ import asyncio import contextvars import logging -from typing import Optional, Sequence +from typing import Iterable, Optional from golem.utils.logging import trace_id_var @@ -52,5 +52,5 @@ async def cancel_and_await(task: asyncio.Task) -> None: pass -async def cancel_and_await_many(tasks: Sequence[asyncio.Task]) -> None: +async def cancel_and_await_many(tasks: Iterable[asyncio.Task]) -> None: await asyncio.gather(*[cancel_and_await(task) for task in tasks]) diff --git a/golem/utils/asyncio/waiter.py b/golem/utils/asyncio/waiter.py index 09c08326..a99ccbf1 100644 --- a/golem/utils/asyncio/waiter.py +++ b/golem/utils/asyncio/waiter.py @@ -1,13 +1,14 @@ import asyncio import collections -from typing import Callable +from typing import Callable, Deque class Waiter: - """Class similar to `asyncio.Event` but valueless and with notify interface similar to `asyncio.Condition`.""" + """Class similar to `asyncio.Event` but valueless and with notify interface similar to \ + `asyncio.Condition`.""" def __init__(self) -> None: - self._waiters: collections.deque[asyncio.Future] = collections.deque() + self._waiters: Deque[asyncio.Future] = collections.deque() self._loop = asyncio.get_event_loop() async def wait_for(self, predicate: Callable[[], bool]) -> None: diff --git a/tests/unit/utils/test_buffer.py b/tests/unit/utils/test_buffer.py index 762cdd35..684d848a 100644 --- a/tests/unit/utils/test_buffer.py +++ b/tests/unit/utils/test_buffer.py @@ -12,7 +12,7 @@ def mocked_buffer(mocker): def test_simple_buffer_creation(): - buffer = SimpleBuffer() + buffer: Buffer[str] = SimpleBuffer() assert buffer.size() == 0 buffer = SimpleBuffer(["a", "b", "c"]) @@ -20,7 +20,7 @@ def test_simple_buffer_creation(): async def test_simple_buffer_put_get(): - buffer = SimpleBuffer() + buffer: Buffer[object] = SimpleBuffer() assert buffer.size() == 0 item_put = object() @@ -59,7 +59,7 @@ async def test_simple_buffer_remove(): async def test_simple_buffer_get_waits_for_items(): - buffer = SimpleBuffer() + buffer: Buffer[object] = SimpleBuffer() assert buffer.size() == 0 _, pending = await asyncio.wait([buffer.get()], timeout=0.1) @@ -85,7 +85,7 @@ async def test_simple_buffer_get_waits_for_items(): done, pending = await asyncio.wait([get_task1, get_task2], timeout=0.1) if len(done) != len(pending): - pytest.fail(f"One of the tasks should not block at this point!") + pytest.fail("One of the tasks should not block at this point!") await buffer.put(item_put) @@ -126,7 +126,7 @@ async def test_simple_buffer_keeps_shallow_copy_of_items(): async def test_simple_buffer_exceptions(): - buffer = SimpleBuffer() + buffer: Buffer[str] = SimpleBuffer() assert buffer.size() == 0 exc = ZeroDivisionError() @@ -173,7 +173,7 @@ async def test_simple_buffer_exceptions(): async def test_simple_buffer_wait_for_any_items(): - buffer = SimpleBuffer() + buffer: Buffer[str] = SimpleBuffer() assert buffer.size() == 0 # should block on empty @@ -371,7 +371,7 @@ async def test_background_feed_buffer_start_stop(mocked_buffer, mocker): async def test_background_feed_buffer_request(mocked_buffer, mocker): item = object() - feed_queue = asyncio.Queue() + feed_queue: asyncio.Queue[object] = asyncio.Queue() feed_func = mocker.AsyncMock(wraps=feed_queue.get) buffer = BackgroundFillBuffer( mocked_buffer, @@ -403,7 +403,7 @@ async def test_background_feed_buffer_request(mocked_buffer, mocker): async def test_background_feed_buffer_get_all_requested(mocked_buffer, mocker, event_loop): timeout = timedelta(seconds=0.1) item = object() - feed_queue = asyncio.Queue() + feed_queue: asyncio.Queue[object] = asyncio.Queue() feed_func = mocker.AsyncMock(wraps=feed_queue.get) buffer = BackgroundFillBuffer( mocked_buffer, From 5ae9a86db2e59e6f264d998b7e4ca21425134569 Mon Sep 17 00:00:00 2001 From: approxit Date: Wed, 24 Jan 2024 16:30:46 +0100 Subject: [PATCH 08/20] mypy fixes --- examples/task_api_draft/task_api/execute_tasks.py | 8 ++++---- golem/managers/proposal/plugins/buffer.py | 4 ++-- golem/utils/asyncio/buffer.py | 2 +- golem/utils/asyncio/semaphore.py | 4 ++-- golem/utils/asyncio/waiter.py | 2 +- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/examples/task_api_draft/task_api/execute_tasks.py b/examples/task_api_draft/task_api/execute_tasks.py index d56f77ca..656d6b24 100644 --- a/examples/task_api_draft/task_api/execute_tasks.py +++ b/examples/task_api_draft/task_api/execute_tasks.py @@ -31,10 +31,10 @@ async def random_score(proposal: Proposal) -> float: def close_agreement_repeat_task( task_stream: TaskDataStream[TaskData], -) -> Callable[[Callable, Tuple[Activity, TaskData], Exception], Awaitable[None]]: +) -> Callable[[Callable, Tuple, Exception], Awaitable[None]]: async def on_exception( - func: Callable[[Activity, TaskData], Awaitable[TaskResult]], - args: Tuple[Activity, TaskData], + func: Callable, + args: Tuple, e: Exception, ) -> None: activity, in_data = args @@ -67,7 +67,7 @@ def get_chain( Zip(task_stream), Map( execute_task, # type: ignore[arg-type] - on_exception=close_agreement_repeat_task(task_stream), # type: ignore[arg-type] + on_exception=close_agreement_repeat_task(task_stream), ), Buffer(size=max_workers), ) diff --git a/golem/managers/proposal/plugins/buffer.py b/golem/managers/proposal/plugins/buffer.py index 0c81d7dd..a88058b0 100644 --- a/golem/managers/proposal/plugins/buffer.py +++ b/golem/managers/proposal/plugins/buffer.py @@ -31,7 +31,7 @@ def __init__( async def _call_feed_func(self) -> Proposal: return await self._get_proposal() - async def _on_added_callback(self): + async def _on_added_callback(self) -> None: count_current = self._buffer.size() count_with_requested = self._buffer.size_with_requested() pending = count_with_requested - count_current @@ -54,7 +54,7 @@ async def start(self) -> None: async def stop(self) -> None: await self._buffer.stop() - async def _request_items(self): + async def _request_items(self) -> None: count_current = self._buffer.size() count_with_requested = self._buffer.size_with_requested() requested = self._max_size - count_with_requested diff --git a/golem/utils/asyncio/buffer.py b/golem/utils/asyncio/buffer.py index f4e61a1f..bcca3f53 100644 --- a/golem/utils/asyncio/buffer.py +++ b/golem/utils/asyncio/buffer.py @@ -326,7 +326,7 @@ async def stop(self) -> None: def is_started(self) -> bool: return self._is_started - async def _worker_loop(self): + async def _worker_loop(self) -> None: while True: logger.debug("Waiting for fill item request...") diff --git a/golem/utils/asyncio/semaphore.py b/golem/utils/asyncio/semaphore.py index fb2bc9bd..cc77b814 100644 --- a/golem/utils/asyncio/semaphore.py +++ b/golem/utils/asyncio/semaphore.py @@ -24,14 +24,14 @@ async def __aexit__(self, exc_type, exc, tb): def locked(self) -> bool: return not self._value - async def acquire(self): + async def acquire(self) -> None: async with self._condition: await self._condition.wait_for(lambda: self._value) self._value -= 1 self._pending += 1 - def release(self): + def release(self) -> None: if self._pending <= 0: raise RuntimeError("Release called too many times!") diff --git a/golem/utils/asyncio/waiter.py b/golem/utils/asyncio/waiter.py index a99ccbf1..10c918e5 100644 --- a/golem/utils/asyncio/waiter.py +++ b/golem/utils/asyncio/waiter.py @@ -33,7 +33,7 @@ def _notify_first(self) -> None: if not first_waiter.done(): first_waiter.set_result(None) - async def _wait(self): + async def _wait(self) -> None: future = self._loop.create_future() self._waiters.append(future) try: From 1b115d2be8b50570700a14eea6ae0dcc5839a55e Mon Sep 17 00:00:00 2001 From: approxit Date: Wed, 24 Jan 2024 16:37:05 +0100 Subject: [PATCH 09/20] mypy fixes 2 --- examples/task_api_draft/examples/yacat.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/examples/task_api_draft/examples/yacat.py b/examples/task_api_draft/examples/yacat.py index 257c4636..1849ca75 100644 --- a/examples/task_api_draft/examples/yacat.py +++ b/examples/task_api_draft/examples/yacat.py @@ -1,7 +1,6 @@ import asyncio from collections import defaultdict from typing import ( - Any, AsyncIterator, Callable, DefaultDict, @@ -205,7 +204,7 @@ async def main() -> None: Zip(async_queue_aiter(tasks_queue)), # type: ignore # mypy, why? Map( lambda activity, task: task.execute(activity), # type: ignore - on_exception=close_agreement_repeat_task, # type: ignore + on_exception=close_agreement_repeat_task, ), Buffer(size=MAX_WORKERS * 2), ): @@ -221,9 +220,7 @@ async def main() -> None: ############################################# # NOT REALLY INTERESTING PARTS OF THE LOGIC -async def close_agreement_repeat_task( - func: Callable, args: Tuple[Activity, Any], e: Exception -) -> None: +async def close_agreement_repeat_task(func: Callable, args: Tuple, e: Exception) -> None: activity, task = args tasks_queue.put_nowait(task) print("Task failed on", activity) From 2b5d83a66a908c5bd74364c8213bbefc07341ea4 Mon Sep 17 00:00:00 2001 From: approxit Date: Wed, 24 Jan 2024 16:44:48 +0100 Subject: [PATCH 10/20] test fixes for py 3.11 --- .github/workflows/tests-unit.yml | 2 +- tests/unit/utils/test_buffer.py | 2 +- tests/unit/utils/test_semaphore.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests-unit.yml b/.github/workflows/tests-unit.yml index 1af51c31..05a34296 100644 --- a/.github/workflows/tests-unit.yml +++ b/.github/workflows/tests-unit.yml @@ -15,7 +15,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] os: - ubuntu-latest - macos-latest diff --git a/tests/unit/utils/test_buffer.py b/tests/unit/utils/test_buffer.py index 684d848a..557f5c42 100644 --- a/tests/unit/utils/test_buffer.py +++ b/tests/unit/utils/test_buffer.py @@ -62,7 +62,7 @@ async def test_simple_buffer_get_waits_for_items(): buffer: Buffer[object] = SimpleBuffer() assert buffer.size() == 0 - _, pending = await asyncio.wait([buffer.get()], timeout=0.1) + _, pending = await asyncio.wait([asyncio.create_task(buffer.get())], timeout=0.1) if not pending: pytest.fail("Getting empty buffer somehow finished instead of blocking!") diff --git a/tests/unit/utils/test_semaphore.py b/tests/unit/utils/test_semaphore.py index b9af0029..adab0022 100644 --- a/tests/unit/utils/test_semaphore.py +++ b/tests/unit/utils/test_semaphore.py @@ -98,7 +98,7 @@ async def test_acquire(): assert not sem.finished.is_set() assert sem.locked() - _, pending = await asyncio.wait([sem.acquire()], timeout=0.1) + _, pending = await asyncio.wait([asyncio.create_task(sem.acquire())], timeout=0.1) acquire_task = pending.pop() if not acquire_task: pytest.fail("Acquiring locked semaphore somehow finished instead of blocking!") From fa50207504e03429ee2315b02bf5229b34c4367e Mon Sep 17 00:00:00 2001 From: approxit Date: Wed, 24 Jan 2024 16:50:37 +0100 Subject: [PATCH 11/20] ignore backticks complain --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 18f9d79d..588c89cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -122,6 +122,7 @@ extend-ignore = [ "D104", # No docs for public package "D107", # No docs for __init__ "D202", # We prefer whitelines after docstrings + "W604", # Backticks complain on Python 3.12+ ] [tool.mypy] From 94ebff703956eaa13ad454be32d193079250f5c6 Mon Sep 17 00:00:00 2001 From: approxit Date: Wed, 24 Jan 2024 17:01:08 +0100 Subject: [PATCH 12/20] flake8 bump --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 588c89cd..917c52da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ requires = ["poetry_core>=1.0.0"] build-backend = "poetry.core.masonry.api" [tool.poetry.dependencies] -python = "^3.8.0" +python = "^3.8.1" # CLI prettytable = "^3.4.1" @@ -45,7 +45,7 @@ isort = "^5" black = "^23" mypy = "^1" -flake8 = "^5" +flake8 = "^7.0.0" flake8-docstrings = "^1" Flake8-pyproject = "^1" From 467fe04714a59dd5aaab8ff5874c71ff4747759f Mon Sep 17 00:00:00 2001 From: approxit Date: Wed, 24 Jan 2024 17:23:11 +0100 Subject: [PATCH 13/20] python 3.12 support --- golem/utils/logging.py | 9 ++++++++- pyproject.toml | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/golem/utils/logging.py b/golem/utils/logging.py index 3f57cab7..fcdaa90f 100644 --- a/golem/utils/logging.py +++ b/golem/utils/logging.py @@ -1,6 +1,7 @@ import contextvars import inspect import logging +import sys from datetime import datetime, timezone from functools import wraps from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, Union @@ -9,6 +10,12 @@ from golem.event_bus import Event +if (3, 8) <= sys.version_info: + FilterReturn = Union[bool, logging.LogRecord] +else: + FilterReturn = bool + + DEFAULT_LOGGING = { "version": 1, "disable_existing_loggers": False, @@ -241,7 +248,7 @@ async def _async_wrapper(self, func: Callable, args: Sequence, kwargs: Dict) -> class AddTraceIdFilter(logging.Filter): - def filter(self, record: logging.LogRecord) -> bool: + def filter(self, record: logging.LogRecord) -> FilterReturn: record.traceid = trace_id_var.get() return super().filter(record) diff --git a/pyproject.toml b/pyproject.toml index 917c52da..43826a08 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ isort = "^5" black = "^23" mypy = "^1" -flake8 = "^7.0.0" +flake8 = "^7" flake8-docstrings = "^1" Flake8-pyproject = "^1" From e8b0d4259d04ecf00f4093e13b2ee1c4d8f71d1b Mon Sep 17 00:00:00 2001 From: approxit Date: Wed, 24 Jan 2024 17:30:58 +0100 Subject: [PATCH 14/20] python 3.12 support --- golem/utils/logging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/golem/utils/logging.py b/golem/utils/logging.py index fcdaa90f..11f3295e 100644 --- a/golem/utils/logging.py +++ b/golem/utils/logging.py @@ -10,7 +10,7 @@ from golem.event_bus import Event -if (3, 8) <= sys.version_info: +if (3, 12) <= sys.version_info: FilterReturn = Union[bool, logging.LogRecord] else: FilterReturn = bool From d6eac918052cc2a4ed092a5fcc3e87ba4498514a Mon Sep 17 00:00:00 2001 From: approxit Date: Mon, 29 Jan 2024 19:07:23 +0100 Subject: [PATCH 15/20] better expire --- golem/managers/base.py | 5 +- golem/managers/demand/refreshing.py | 13 ++- golem/managers/proposal/plugins/buffer.py | 35 +++++++- .../plugins/negotiating/negotiating_plugin.py | 8 +- .../proposal/plugins/scoring/mixins.py | 7 +- .../plugins/scoring/scoring_buffer.py | 27 +++++- golem/resources/demand/demand.py | 9 +- golem/resources/proposal/proposal.py | 13 +++ golem/utils/asyncio/buffer.py | 83 ++++++++++--------- golem/utils/asyncio/semaphore.py | 20 +++++ golem/utils/asyncio/tasks.py | 15 +++- golem/utils/asyncio/waiter.py | 2 + golem/utils/typing.py | 6 +- 13 files changed, 177 insertions(+), 66 deletions(-) diff --git a/golem/managers/base.py b/golem/managers/base.py index cc5df9b6..05484744 100644 --- a/golem/managers/base.py +++ b/golem/managers/base.py @@ -15,6 +15,7 @@ Script, ) from golem.resources.activity import commands +from golem.utils.typing import MaybeAwaitable logger = logging.getLogger(__name__) @@ -184,7 +185,7 @@ class ProposalNegotiator(ABC): @abstractmethod def __call__( self, demand_data: DemandData, proposal_data: ProposalData - ) -> Union[Awaitable[Optional[RejectProposal]], Optional[RejectProposal]]: + ) -> MaybeAwaitable[Optional[RejectProposal]]: ... @@ -212,7 +213,7 @@ class ProposalScorer(ABC): @abstractmethod def __call__( self, proposals_data: Sequence[ProposalData] - ) -> Union[Awaitable[ProposalScoringResult], ProposalScoringResult]: + ) -> MaybeAwaitable[ProposalScoringResult]: ... diff --git a/golem/managers/demand/refreshing.py b/golem/managers/demand/refreshing.py index 6289a074..844eb0a9 100644 --- a/golem/managers/demand/refreshing.py +++ b/golem/managers/demand/refreshing.py @@ -1,6 +1,6 @@ import asyncio import logging -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import Awaitable, Callable, List, Tuple from golem.managers.base import DemandManager @@ -67,13 +67,10 @@ async def _wait_for_demand_to_expire(self): if not self._demands: return - remaining: timedelta = ( - datetime.utcfromtimestamp( - self._demands[-1][0].data.properties["golem.srv.comp.expiration"] / 1000 - ) - - datetime.utcnow() - ) - await asyncio.sleep(remaining.seconds) + await self._demands[-1][0].get_data() + expiration_date = self._demands[-1][0].get_expiration_date() + remaining = expiration_date - datetime.now(timezone.utc) + await asyncio.sleep(remaining.total_seconds()) @trace_span() async def _create_and_subscribe_demand(self): diff --git a/golem/managers/proposal/plugins/buffer.py b/golem/managers/proposal/plugins/buffer.py index a88058b0..efb0cca7 100644 --- a/golem/managers/proposal/plugins/buffer.py +++ b/golem/managers/proposal/plugins/buffer.py @@ -1,9 +1,13 @@ import logging +from datetime import timedelta +from typing import Callable, Optional from golem.managers import ProposalManagerPlugin from golem.resources import Proposal -from golem.utils.asyncio.buffer import BackgroundFillBuffer, SimpleBuffer +from golem.utils.asyncio.buffer import BackgroundFillBuffer, Buffer, ExpirableBuffer, SimpleBuffer +from golem.utils.asyncio.tasks import resolve_maybe_awaitable from golem.utils.logging import trace_span +from golem.utils.typing import MaybeAwaitable logger = logging.getLogger(__name__) @@ -15,23 +19,38 @@ def __init__( max_size: int, fill_concurrency_size=1, fill_at_start=False, + get_expiration_func: Optional[ + Callable[[Proposal], MaybeAwaitable[Optional[timedelta]]] + ] = None, + on_expiration_func: Optional[Callable[[Proposal], MaybeAwaitable[None]]] = None, ) -> None: self._min_size = min_size self._max_size = max_size self._fill_concurrency_size = fill_concurrency_size self._fill_at_start = fill_at_start + self._get_expiration_func = get_expiration_func + self._on_expiration_func = on_expiration_func + + buffer: Buffer[Proposal] = SimpleBuffer() + + if self._get_expiration_func is not None: + buffer = ExpirableBuffer( + buffer=buffer, + get_expiration_func=self._get_expiration_func, + on_expired_func=self._on_item_expire, + ) self._buffer: BackgroundFillBuffer[Proposal] = BackgroundFillBuffer( - buffer=SimpleBuffer(), + buffer=buffer, fill_func=self._call_feed_func, fill_concurrency_size=self._fill_concurrency_size, - on_added_callback=self._on_added_callback, + on_added_func=self._on_item_added, ) async def _call_feed_func(self) -> Proposal: return await self._get_proposal() - async def _on_added_callback(self) -> None: + async def _on_item_added(self, item: Proposal) -> None: count_current = self._buffer.size() count_with_requested = self._buffer.size_with_requested() pending = count_with_requested - count_current @@ -43,6 +62,14 @@ async def _on_added_callback(self) -> None: self._max_size, ) + async def _on_item_expire(self, item: Proposal): + logger.debug("Item %r expired, rejecting proposal and requesting fill", item) + await item.reject("Proposal no longer needed") + await self._request_items() + + if self._on_expiration_func is not None: + await resolve_maybe_awaitable(self._on_expiration_func, item) + @trace_span() async def start(self) -> None: await self._buffer.start() diff --git a/golem/managers/proposal/plugins/negotiating/negotiating_plugin.py b/golem/managers/proposal/plugins/negotiating/negotiating_plugin.py index 15edb600..283dd993 100644 --- a/golem/managers/proposal/plugins/negotiating/negotiating_plugin.py +++ b/golem/managers/proposal/plugins/negotiating/negotiating_plugin.py @@ -9,6 +9,7 @@ from golem.managers import ProposalManagerPlugin, RejectProposal from golem.managers.base import ProposalNegotiator from golem.resources import DemandData, Proposal +from golem.utils.asyncio.tasks import resolve_maybe_awaitable from golem.utils.logging import trace_span logger = logging.getLogger(__name__) @@ -116,10 +117,9 @@ async def _run_negotiators( proposal_data = await offer_proposal.get_proposal_data() for negotiator in self._proposal_negotiators: - negotiator_result = negotiator(demand_data_after_negotiators, proposal_data) - - if asyncio.iscoroutine(negotiator_result): - negotiator_result = await negotiator_result + negotiator_result = await resolve_maybe_awaitable( + negotiator, demand_data_after_negotiators, proposal_data + ) if isinstance(negotiator_result, RejectProposal): raise negotiator_result diff --git a/golem/managers/proposal/plugins/scoring/mixins.py b/golem/managers/proposal/plugins/scoring/mixins.py index 8bf9f9db..b29ade39 100644 --- a/golem/managers/proposal/plugins/scoring/mixins.py +++ b/golem/managers/proposal/plugins/scoring/mixins.py @@ -1,10 +1,10 @@ -import inspect from datetime import datetime from typing import List, Optional, Sequence, Tuple, cast from golem.managers.base import ScorerWithOptionalWeight from golem.payload import PayloadSyntaxParser, Properties from golem.resources import Proposal, ProposalData +from golem.utils.asyncio.tasks import resolve_maybe_awaitable from golem.utils.logging import trace_span @@ -43,10 +43,7 @@ async def _run_scorers( else: weight = 1 - scorer_scores = scorer(proposals_data) - - if inspect.isawaitable(scorer_scores): - scorer_scores = await scorer_scores + scorer_scores = await resolve_maybe_awaitable(scorer, proposals_data) proposal_scores.append((weight, scorer_scores)) # type: ignore[arg-type] diff --git a/golem/managers/proposal/plugins/scoring/scoring_buffer.py b/golem/managers/proposal/plugins/scoring/scoring_buffer.py index 7e49069b..fe6ccf53 100644 --- a/golem/managers/proposal/plugins/scoring/scoring_buffer.py +++ b/golem/managers/proposal/plugins/scoring/scoring_buffer.py @@ -6,7 +6,13 @@ from golem.managers.proposal.plugins.buffer import BufferPlugin as BufferPlugin from golem.managers.proposal.plugins.scoring import ProposalScoringMixin from golem.resources import Proposal -from golem.utils.asyncio import Buffer, SimpleBuffer, cancel_and_await, create_task_with_logging +from golem.utils.asyncio import ( + Buffer, + ExpirableBuffer, + SimpleBuffer, + cancel_and_await, + create_task_with_logging, +) from golem.utils.logging import get_trace_id_name, trace_span logger = logging.getLogger(__name__) @@ -14,11 +20,26 @@ class ScoringBufferPlugin(ProposalScoringMixin, BufferPlugin): def __init__(self, update_interval: timedelta = timedelta(seconds=10), *args, **kwargs) -> None: + get_expiration_func = kwargs.pop("get_expiration_func", None) + super().__init__(*args, **kwargs) self._update_interval = update_interval - self._buffer_scored: Buffer[Proposal] = SimpleBuffer() + # Postponing argument would disable expiration from BufferPlugin parent + # as we want to expire only scored items instead + self._get_expiration_func = get_expiration_func + + scored_buffer: Buffer[Proposal] = SimpleBuffer() + + if get_expiration_func is not None: + scored_buffer = ExpirableBuffer( + buffer=scored_buffer, + get_expiration_func=get_expiration_func, + on_expired_func=self._on_item_expire, + ) + + self._buffer_scored: Buffer[Proposal] = scored_buffer self._background_loop_task: Optional[asyncio.Task] = None @trace_span() @@ -37,7 +58,7 @@ async def stop(self) -> None: await cancel_and_await(self._background_loop_task) self._background_loop_task = None - async def _on_added_callback(self): + async def _on_item_added(self, item: Proposal): pass # explicit no-op async def _background_loop(self) -> None: diff --git a/golem/resources/demand/demand.py b/golem/resources/demand/demand.py index 2b9f4e27..34eef73a 100644 --- a/golem/resources/demand/demand.py +++ b/golem/resources/demand/demand.py @@ -1,5 +1,5 @@ import asyncio -from datetime import datetime +from datetime import datetime, timedelta from typing import TYPE_CHECKING, AsyncIterator, Callable, Dict, List, Optional, Union, cast from ya_market import RequestorApi @@ -15,6 +15,8 @@ if TYPE_CHECKING: from golem.node import GolemNode +DEFAULT_TTL = timedelta(hours=1) + class Demand(Resource[RequestorApi, models.Demand, _NULL, Proposal, _NULL], YagnaEventCollector): """A single demand on the Golem Network. @@ -139,3 +141,8 @@ async def get_demand_data(self) -> DemandData: ) return self._demand_data + + def get_expiration_date(self) -> datetime: + """Return expiration date to auto unsubscribe.""" + + return cast(datetime, self.data.timestamp) + DEFAULT_TTL diff --git a/golem/resources/proposal/proposal.py b/golem/resources/proposal/proposal.py index 488d0245..1f91e1f2 100644 --- a/golem/resources/proposal/proposal.py +++ b/golem/resources/proposal/proposal.py @@ -16,6 +16,8 @@ from golem.node import GolemNode from golem.resources.demand import Demand +DEFAULT_TTL = timedelta(hours=1) + class Proposal( Resource[ @@ -230,3 +232,14 @@ async def get_provider_name(self): node_info = NodeInfo.from_properties(proposal_data.properties) self._provider_node_name = node_info.name return self._provider_node_name + + def get_expiration_date(self) -> datetime: + """Return expiration date to auto unsubscribe. + + Note: As Proposal can have different expiration date than its Demand, it would be unusable + after demand expiration anyway, hence earliest from both dates is returned. + """ + demand_expiration_date = self.demand.get_expiration_date() + proposal_expiration_date = cast(datetime, self.data.timestamp) + DEFAULT_TTL + + return min(proposal_expiration_date, demand_expiration_date) diff --git a/golem/utils/asyncio/buffer.py b/golem/utils/asyncio/buffer.py index bcca3f53..ea9e2ff4 100644 --- a/golem/utils/asyncio/buffer.py +++ b/golem/utils/asyncio/buffer.py @@ -16,9 +16,14 @@ ) from golem.utils.asyncio.semaphore import SingleUseSemaphore -from golem.utils.asyncio.tasks import cancel_and_await_many, create_task_with_logging +from golem.utils.asyncio.tasks import ( + cancel_and_await_many, + create_task_with_logging, + resolve_maybe_awaitable, +) from golem.utils.asyncio.waiter import Waiter from golem.utils.logging import get_trace_id_name, trace_span +from golem.utils.typing import MaybeAwaitable TItem = TypeVar("TItem") @@ -190,20 +195,19 @@ class ExpirableBuffer(ComposableBuffer[TItem]): def __init__( self, buffer: Buffer[TItem], - get_expiration_func: Callable[[TItem], Optional[timedelta]], - on_expiration_func: Optional[Callable[[TItem], Awaitable[None]]] = None, + get_expiration_func: Callable[[TItem], MaybeAwaitable[Optional[timedelta]]], + on_expired_func: Optional[Callable[[TItem], MaybeAwaitable[None]]] = None, ): super().__init__(buffer) self._get_expiration_func = get_expiration_func - self._on_expiration_func = on_expiration_func + self._on_expired_func = on_expired_func - # Lock is used to keep items in buffer and expiration tasks in sync - self._lock = asyncio.Lock() + # TODO: Could this collection be liable to race conditions? self._expiration_handlers: Dict[int, List[asyncio.TimerHandle]] = defaultdict(list) - def _add_expiration_task_for_item(self, item: TItem) -> None: - expiration = self._get_expiration_func(item) + async def _add_expiration_task_for_item(self, item: TItem) -> None: + expiration = await resolve_maybe_awaitable(self._get_expiration_func, item) if expiration is None: return @@ -212,11 +216,14 @@ def _add_expiration_task_for_item(self, item: TItem) -> None: self._expiration_handlers[id(item)].append( loop.call_later( - expiration.total_seconds(), lambda: asyncio.create_task(self._expire_item(item)) + expiration.total_seconds(), + lambda: create_task_with_logging( + self._expire_item(item), trace_id=get_trace_id_name(self, f"item-expire-{item}") + ), ) ) - async def _remove_expiration_handler_for_item(self, item: TItem) -> None: + def _remove_expiration_handler_for_item(self, item: TItem) -> None: item_id = id(item) if item_id not in self._expiration_handlers or not len(self._expiration_handlers[item_id]): @@ -228,7 +235,7 @@ async def _remove_expiration_handler_for_item(self, item: TItem) -> None: if not self._expiration_handlers[item_id]: del self._expiration_handlers[item_id] - async def _remove_all_expiration_handlers(self) -> None: + def _remove_all_expiration_handlers(self) -> None: for handlers in self._expiration_handlers.values(): for handler in handlers: handler.cancel() @@ -236,42 +243,44 @@ async def _remove_all_expiration_handlers(self) -> None: self._expiration_handlers.clear() async def get(self) -> TItem: - async with self._lock: - item = await super().get() - await self._remove_expiration_handler_for_item(item) + item = await super().get() + + self._remove_expiration_handler_for_item(item) - return item + return item async def get_all(self) -> MutableSequence[TItem]: - async with self._lock: - items = await super().get_all() - await self._remove_all_expiration_handlers() - return items + items = await super().get_all() + + self._remove_all_expiration_handlers() + + return items async def put(self, item: TItem) -> None: - async with self._lock: - await super().put(item) - self._add_expiration_task_for_item(item) + await super().put(item) + + await self._add_expiration_task_for_item(item) async def put_all(self, items: Sequence[TItem]) -> None: - async with self._lock: - await super().put_all(items) - await self._remove_all_expiration_handlers() + await super().put_all(items) + + self._remove_all_expiration_handlers() - for item in items: - self._add_expiration_task_for_item(item) + await asyncio.gather( + *[self._add_expiration_task_for_item(item) for item in items], return_exceptions=True + ) async def remove(self, item: TItem) -> None: - async with self._lock: - await super().remove(item) - await self._remove_expiration_handler_for_item(item) + await super().remove(item) - @trace_span() + self._remove_expiration_handler_for_item(item) + + @trace_span(show_arguments=True) async def _expire_item(self, item: TItem) -> None: await self.remove(item) - if self._on_expiration_func: - await self._on_expiration_func(item) + if self._on_expired_func: + await resolve_maybe_awaitable(self._on_expired_func, item) class BackgroundFillBuffer(ComposableBuffer[TItem]): @@ -286,13 +295,13 @@ def __init__( buffer: Buffer[TItem], fill_func: Callable[[], Awaitable[TItem]], fill_concurrency_size=1, - on_added_callback: Optional[Callable[[], Awaitable[None]]] = None, + on_added_func: Optional[Callable[[TItem], Awaitable[None]]] = None, ): super().__init__(buffer) self._fill_func = fill_func self._fill_concurrency_size = fill_concurrency_size - self._on_added_callback = on_added_callback + self._on_added_func = on_added_func self._is_started = False self._worker_tasks: List[asyncio.Task] = [] @@ -339,8 +348,8 @@ async def _worker_loop(self) -> None: await self.put(item) - if self._on_added_callback is not None: - await self._on_added_callback() + if self._on_added_func is not None: + await self._on_added_func(item) logger.debug("Adding new item done with total of %d items in buffer", self.size()) diff --git a/golem/utils/asyncio/semaphore.py b/golem/utils/asyncio/semaphore.py index cc77b814..8d651bfd 100644 --- a/golem/utils/asyncio/semaphore.py +++ b/golem/utils/asyncio/semaphore.py @@ -2,6 +2,9 @@ class SingleUseSemaphore: + """Class similar to `asyncio.Semaphore` but with more limited count of `.acquire()` calls and\ + exposed counters.""" + def __init__(self, value=0): if value < 0: raise ValueError("Initial value must be greater or equal to zero!") @@ -22,9 +25,14 @@ async def __aexit__(self, exc_type, exc, tb): self.release() def locked(self) -> bool: + """Return True if there are no more "charges" left in semaphore.""" + return not self._value async def acquire(self) -> None: + """Decrease "charges" counter and increase pending count, or await until there any\ + "charges".""" + async with self._condition: await self._condition.wait_for(lambda: self._value) @@ -32,6 +40,8 @@ async def acquire(self) -> None: self._pending += 1 def release(self) -> None: + """Decrease pending count.""" + if self._pending <= 0: raise RuntimeError("Release called too many times!") @@ -41,20 +51,30 @@ def release(self) -> None: self.finished.set() async def increase(self, value: int) -> None: + """Add given "charges" amount.""" + async with self._condition: self._value += value self.finished.clear() self._condition.notify(value) def get_count(self) -> int: + """Return "charges" count.""" + return self._value def get_count_with_pending(self) -> int: + """Return sum of "charges" and pending count.""" + return self.get_count() + self.get_pending_count() def get_pending_count(self) -> int: + """Return pending count.""" + return self._pending def reset(self) -> None: + """Reset "charges" amount to zero.""" + self._value = 0 self.finished.set() diff --git a/golem/utils/asyncio/tasks.py b/golem/utils/asyncio/tasks.py index 039d3c86..a425838a 100644 --- a/golem/utils/asyncio/tasks.py +++ b/golem/utils/asyncio/tasks.py @@ -1,9 +1,13 @@ import asyncio import contextvars +import inspect import logging -from typing import Iterable, Optional +from typing import Callable, Iterable, Optional, TypeVar, cast from golem.utils.logging import trace_id_var +from golem.utils.typing import MaybeAwaitable + +T = TypeVar("T") logger = logging.getLogger(__name__) @@ -54,3 +58,12 @@ async def cancel_and_await(task: asyncio.Task) -> None: async def cancel_and_await_many(tasks: Iterable[asyncio.Task]) -> None: await asyncio.gather(*[cancel_and_await(task) for task in tasks]) + + +async def resolve_maybe_awaitable(func: Callable[..., MaybeAwaitable[T]], *args, **kwargs) -> T: + result = func(*args, **kwargs) + + if inspect.iscoroutine(result): + result = await result + + return cast(T, result) # FIXME: This cast should not be needed diff --git a/golem/utils/asyncio/waiter.py b/golem/utils/asyncio/waiter.py index 10c918e5..143bcf99 100644 --- a/golem/utils/asyncio/waiter.py +++ b/golem/utils/asyncio/waiter.py @@ -13,6 +13,7 @@ def __init__(self) -> None: async def wait_for(self, predicate: Callable[[], bool]) -> None: """Check if predicate is true and return immediately, or await until it becomes true.""" + result = predicate() while not result: @@ -43,6 +44,7 @@ async def _wait(self) -> None: def notify(self, count=1) -> None: """Notify given amount of `.wait_for()` calls to check its predicates.""" + notified = 0 for waiter in self._waiters: if count <= notified: diff --git a/golem/utils/typing.py b/golem/utils/typing.py index 76b774dc..7877ba15 100644 --- a/golem/utils/typing.py +++ b/golem/utils/typing.py @@ -1,4 +1,8 @@ -from typing import Any, Callable, Optional, Type, Union, get_args, get_origin +from typing import Any, Awaitable, Callable, Optional, Type, TypeVar, Union, get_args, get_origin + +T = TypeVar("T") + +MaybeAwaitable = Union[Awaitable[T], T] def match_type_union_aware(obj_type: Type, match_func: Callable[[Type], bool]) -> Optional[Any]: From 9938b45416bd0ed8bdb374208630378d167ac728 Mon Sep 17 00:00:00 2001 From: approxit Date: Mon, 29 Jan 2024 22:57:00 +0100 Subject: [PATCH 16/20] changed buffer interface --- golem/utils/asyncio/buffer.py | 173 ++++++++++++++++++-------------- golem/utils/asyncio/waiter.py | 5 +- tests/unit/utils/test_buffer.py | 44 ++++---- 3 files changed, 128 insertions(+), 94 deletions(-) diff --git a/golem/utils/asyncio/buffer.py b/golem/utils/asyncio/buffer.py index ea9e2ff4..78bf4e27 100644 --- a/golem/utils/asyncio/buffer.py +++ b/golem/utils/asyncio/buffer.py @@ -2,6 +2,7 @@ import logging from abc import ABC, abstractmethod from collections import defaultdict +from contextlib import asynccontextmanager from datetime import timedelta from typing import ( Awaitable, @@ -21,7 +22,6 @@ create_task_with_logging, resolve_maybe_awaitable, ) -from golem.utils.asyncio.waiter import Waiter from golem.utils.logging import get_trace_id_name, trace_span from golem.utils.typing import MaybeAwaitable @@ -34,16 +34,18 @@ class Buffer(ABC, Generic[TItem]): """Interface class for object similar to `asyncio.Queue` but with more control over its \ items.""" + condition: asyncio.Condition + @abstractmethod def size(self) -> int: """Return number of items stored in buffer.""" @abstractmethod - async def wait_for_any_items(self) -> None: + async def wait_for_any_items(self, *, lock=True) -> None: """Wait until any items are stored in buffer.""" @abstractmethod - async def get(self) -> TItem: + async def get(self, *, lock=True) -> TItem: """Await, remove and return left-most item stored in buffer. If `.set_exception()` was previously called, exception will be raised only if buffer @@ -51,7 +53,7 @@ async def get(self) -> TItem: """ @abstractmethod - async def get_all(self) -> MutableSequence[TItem]: + async def get_all(self, *, lock=True) -> MutableSequence[TItem]: """Remove and return all items stored in buffer. Note that this method will not await for any items if buffer is empty. @@ -61,25 +63,25 @@ async def get_all(self) -> MutableSequence[TItem]: """ @abstractmethod - async def put(self, item: TItem) -> None: + async def put(self, item: TItem, *, lock=True) -> None: """Add item to right-most position to buffer. Duplicates are supported. """ @abstractmethod - async def put_all(self, items: Sequence[TItem]) -> None: + async def put_all(self, items: Sequence[TItem], *, lock=True) -> None: """Replace all items stored in buffer. Duplicates are supported. """ @abstractmethod - async def remove(self, item: TItem) -> None: + async def remove(self, item: TItem, *, lock=True) -> None: """Remove first occurrence of item from buffer or raise `ValueError` if not found.""" @abstractmethod - def set_exception(self, exc: BaseException) -> None: + async def set_exception(self, exc: BaseException, *, lock=True) -> None: """Set exception that will be raised while trying to `.get()`/`.get_all()` item from \ empty buffer.""" @@ -94,29 +96,37 @@ class ComposableBuffer(Buffer[TItem]): def __init__(self, buffer: Buffer[TItem]): self._buffer = buffer + @asynccontextmanager + async def _handle_lock(self, lock: bool): + if lock: + async with self._buffer.condition: + yield + else: + yield + def size(self) -> int: return self._buffer.size() - async def wait_for_any_items(self) -> None: - await self._buffer.wait_for_any_items() + async def wait_for_any_items(self, *, lock=True) -> None: + await self._buffer.wait_for_any_items(lock=lock) - async def get(self) -> TItem: - return await self._buffer.get() + async def get(self, *, lock=True) -> TItem: + return await self._buffer.get(lock=lock) - async def get_all(self) -> MutableSequence[TItem]: - return await self._buffer.get_all() + async def get_all(self, *, lock=True) -> MutableSequence[TItem]: + return await self._buffer.get_all(lock=lock) - async def put(self, item: TItem) -> None: - await self._buffer.put(item) + async def put(self, item: TItem, *, lock=True) -> None: + await self._buffer.put(item, lock=lock) - async def put_all(self, items: Sequence[TItem]) -> None: - await self._buffer.put_all(items) + async def put_all(self, items: Sequence[TItem], *, lock=True) -> None: + await self._buffer.put_all(items, lock=lock) - async def remove(self, item: TItem) -> None: - await self._buffer.remove(item) + async def remove(self, item: TItem, *, lock=True) -> None: + await self._buffer.remove(item, lock=lock) - def set_exception(self, exc: BaseException) -> None: - self._buffer.set_exception(exc) + async def set_exception(self, exc: BaseException, *, lock=True) -> None: + await self._buffer.set_exception(exc, lock=lock) def reset_exception(self) -> None: self._buffer.reset_exception() @@ -129,55 +139,64 @@ def __init__(self, items: Optional[Sequence[TItem]] = None): self._items = list(items) if items is not None else [] self._error: Optional[BaseException] = None - self._waiter = Waiter() + self.condition = asyncio.Condition() def size(self) -> int: return len(self._items) - @trace_span() - async def wait_for_any_items(self) -> None: - await self._waiter.wait_for(lambda: bool(self.size() or self._error)) + @asynccontextmanager + async def _handle_lock(self, lock: bool): + if lock: + async with self.condition: + yield + else: + yield @trace_span() - async def get(self) -> TItem: - if not self.size(): - if self._error: - raise self._error - else: - await self.wait_for_any_items() + async def wait_for_any_items(self, lock=True) -> None: + async with self._handle_lock(lock): + await self.condition.wait_for(lambda: bool(self.size() or self._error)) - if not self.size() and self._error: - raise self._error + @trace_span() + async def get(self, *, lock=True) -> TItem: + async with self._handle_lock(lock): + await self.wait_for_any_items(lock=False) - item = self._items.pop(0) + if not self.size() and self._error: + raise self._error - return item + return self._items.pop(0) - async def get_all(self) -> MutableSequence[TItem]: - if not self._items and self._error: - raise self._error + async def get_all(self, *, lock=True) -> MutableSequence[TItem]: + async with self._handle_lock(lock): + if not self._items and self._error: + raise self._error - items = self._items[:] - self._items.clear() + items = self._items[:] + self._items.clear() - return items + return items - async def put(self, item: TItem) -> None: - self._items.append(item) - self._waiter.notify() + async def put(self, item: TItem, *, lock=True) -> None: + async with self._handle_lock(lock): + self._items.append(item) + self.condition.notify() - async def put_all(self, items: Sequence[TItem]) -> None: - self._items.clear() - self._items.extend(items[:]) + async def put_all(self, items: Sequence[TItem], *, lock=True) -> None: + async with self._handle_lock(lock): + self._items.clear() + self._items.extend(items[:]) - self._waiter.notify(len(items)) + self.condition.notify(len(items)) - async def remove(self, item: TItem) -> None: - self._items.remove(item) + async def remove(self, item: TItem, *, lock=True) -> None: + async with self._handle_lock(lock): + self._items.remove(item) - def set_exception(self, exc: BaseException) -> None: - self._error = exc - self._waiter.notify() + async def set_exception(self, exc: BaseException, *, lock=True) -> None: + async with self._handle_lock(lock): + self._error = exc + self.condition.notify() def reset_exception(self) -> None: self._error = None @@ -242,38 +261,44 @@ def _remove_all_expiration_handlers(self) -> None: self._expiration_handlers.clear() - async def get(self) -> TItem: - item = await super().get() + async def get(self, *, lock=True) -> TItem: + async with self._handle_lock(lock): + item = await super().get(lock=False) - self._remove_expiration_handler_for_item(item) + self._remove_expiration_handler_for_item(item) - return item + return item - async def get_all(self) -> MutableSequence[TItem]: - items = await super().get_all() + async def get_all(self, *, lock=True) -> MutableSequence[TItem]: + async with self._handle_lock(lock): + items = await super().get_all(lock=False) - self._remove_all_expiration_handlers() + self._remove_all_expiration_handlers() - return items + return items - async def put(self, item: TItem) -> None: - await super().put(item) + async def put(self, item: TItem, *, lock=True) -> None: + async with self._handle_lock(lock): + await super().put(item, lock=False) - await self._add_expiration_task_for_item(item) + await self._add_expiration_task_for_item(item) - async def put_all(self, items: Sequence[TItem]) -> None: - await super().put_all(items) + async def put_all(self, items: Sequence[TItem], *, lock=True) -> None: + async with self._handle_lock(lock): + await super().put_all(items, lock=False) - self._remove_all_expiration_handlers() + self._remove_all_expiration_handlers() - await asyncio.gather( - *[self._add_expiration_task_for_item(item) for item in items], return_exceptions=True - ) + await asyncio.gather( + *[self._add_expiration_task_for_item(item) for item in items], + return_exceptions=True, + ) - async def remove(self, item: TItem) -> None: - await super().remove(item) + async def remove(self, item: TItem, *, lock=True) -> None: + async with self._handle_lock(lock): + await super().remove(item, lock=False) - self._remove_expiration_handler_for_item(item) + self._remove_expiration_handler_for_item(item) @trace_span(show_arguments=True) async def _expire_item(self, item: TItem) -> None: diff --git a/golem/utils/asyncio/waiter.py b/golem/utils/asyncio/waiter.py index 143bcf99..c0b7dfc3 100644 --- a/golem/utils/asyncio/waiter.py +++ b/golem/utils/asyncio/waiter.py @@ -5,7 +5,10 @@ class Waiter: """Class similar to `asyncio.Event` but valueless and with notify interface similar to \ - `asyncio.Condition`.""" + `asyncio.Condition`. + + Note: Developed to support `golem.utils.asyncio.buffer`, but finally not used. + """ def __init__(self) -> None: self._waiters: Deque[asyncio.Future] = collections.deque() diff --git a/tests/unit/utils/test_buffer.py b/tests/unit/utils/test_buffer.py index 557f5c42..e6624311 100644 --- a/tests/unit/utils/test_buffer.py +++ b/tests/unit/utils/test_buffer.py @@ -8,7 +8,11 @@ @pytest.fixture def mocked_buffer(mocker): - return mocker.Mock(spec=Buffer) + mock = mocker.Mock(spec=Buffer) + + mock.condition = mocker.AsyncMock() + + return mock def test_simple_buffer_creation(): @@ -131,7 +135,7 @@ async def test_simple_buffer_exceptions(): exc = ZeroDivisionError() - buffer.set_exception(exc) + await buffer.set_exception(exc) # should raise when exception set and no items with pytest.raises(ZeroDivisionError): @@ -164,7 +168,7 @@ async def test_simple_buffer_exceptions(): await asyncio.sleep(0.1) - buffer.set_exception(exc) + await buffer.set_exception(exc) await asyncio.sleep(0.1) @@ -196,7 +200,7 @@ async def test_simple_buffer_wait_for_any_items(): # should not block a long time on item await asyncio.wait_for(buffer.wait_for_any_items(), timeout=0.1) - buffer.set_exception(ZeroDivisionError()) + await buffer.set_exception(ZeroDivisionError()) # should not block a long time on item with exception await asyncio.wait_for(buffer.wait_for_any_items(), timeout=0.1) @@ -219,7 +223,7 @@ async def test_simple_buffer_wait_for_any_items(): assert not wait_task.done() - buffer.set_exception(ZeroDivisionError()) + await buffer.set_exception(ZeroDivisionError()) await asyncio.sleep(0.1) @@ -253,8 +257,8 @@ async def test_expirable_buffer_is_not_expiring_items_with_none_expiration(mocke await asyncio.sleep(0.2) - assert mocker.call("a") in mocked_buffer.remove.mock_calls - assert mocker.call("c") in mocked_buffer.remove.mock_calls + assert mocker.call("a", lock=False) in mocked_buffer.remove.mock_calls + assert mocker.call("c", lock=False) in mocked_buffer.remove.mock_calls mocked_buffer.get.return_value = "b" @@ -272,7 +276,7 @@ async def test_expirable_buffer_can_expire_items_with_put_get(mocked_buffer, moc item_put = object() await buffer.put(item_put) - mocked_buffer.put.assert_called_with(item_put) + mocked_buffer.put.assert_called_with(item_put, lock=False) mocked_buffer.get.return_value = item_put await buffer.get() @@ -289,11 +293,11 @@ async def test_expirable_buffer_can_expire_items_with_put_get(mocked_buffer, moc mocked_buffer.reset_mock() await buffer.put(item_put) - mocked_buffer.put.assert_called_with(item_put) + mocked_buffer.put.assert_called_with(item_put, lock=False) await asyncio.sleep(0.2) - mocked_buffer.remove.assert_called_with(item_put) + mocked_buffer.remove.assert_called_with(item_put, lock=False) on_expire.assert_called_with(item_put) @@ -308,13 +312,15 @@ async def test_expirable_buffer_can_expire_items_with_put_all_get_all(mocked_buf items_put_all = ["a", "b", "c"] await buffer.put_all(items_put_all) - mocked_buffer.put_all.assert_called_with(items_put_all) + mocked_buffer.put_all.assert_called_with(items_put_all, lock=False) mocked_buffer.get_all.return_value = items_put_all await buffer.get_all() - mocked_buffer.get_all.assert_called() + mocked_buffer.get_all.assert_called_with(lock=False) + + with pytest.raises(AssertionError): + mocked_buffer.remove.assert_called_with(lock=False) - mocked_buffer.remove.assert_not_called() on_expire.assert_not_called() await asyncio.sleep(0.2) @@ -324,13 +330,13 @@ async def test_expirable_buffer_can_expire_items_with_put_all_get_all(mocked_buf mocked_buffer.reset_mock() await buffer.put_all(items_put_all) - mocked_buffer.put_all.assert_called_with(items_put_all) + mocked_buffer.put_all.assert_called_with(items_put_all, lock=False) await asyncio.sleep(0.2) - assert mocker.call(items_put_all[0]) in mocked_buffer.remove.mock_calls - assert mocker.call(items_put_all[1]) in mocked_buffer.remove.mock_calls - assert mocker.call(items_put_all[2]) in mocked_buffer.remove.mock_calls + assert mocker.call(items_put_all[0], lock=False) in mocked_buffer.remove.mock_calls + assert mocker.call(items_put_all[1], lock=False) in mocked_buffer.remove.mock_calls + assert mocker.call(items_put_all[2], lock=False) in mocked_buffer.remove.mock_calls assert mocker.call(items_put_all[0]) in on_expire.mock_calls assert mocker.call(items_put_all[1]) in on_expire.mock_calls @@ -392,7 +398,7 @@ async def test_background_feed_buffer_request(mocked_buffer, mocker): await asyncio.sleep(0.1) - mocked_buffer.put.assert_called_with(item) + mocked_buffer.put.assert_called_with(item, lock=True) mocked_buffer.size.return_value = 1 assert buffer.size() == 1 assert buffer.size_with_requested() == 1 @@ -450,7 +456,7 @@ async def test_background_feed_buffer_get_all_requested(mocked_buffer, mocker, e await asyncio.sleep(0.1) - mocked_buffer.put.assert_called_with(item) + mocked_buffer.put.assert_called_with(item, lock=True) mocked_buffer.size.return_value = 1 assert buffer.size() == 1 assert buffer.size_with_requested() == 1 From 3fe8fc7915ad4d5b86ddfc6253049bad1257e4c5 Mon Sep 17 00:00:00 2001 From: approxit Date: Wed, 31 Jan 2024 17:30:35 +0100 Subject: [PATCH 17/20] review changes --- examples/managers/blender/blender.py | 4 +- examples/managers/flexible_negotiation.py | 6 +- examples/managers/mid_agreement_payments.py | 4 +- examples/managers/proposal_plugins.py | 8 +- .../task_api_draft/task_api/activity_pool.py | 4 +- golem/event_bus/in_memory/event_bus.py | 4 +- golem/managers/__init__.py | 8 +- golem/managers/base.py | 27 +++++-- golem/managers/mixins.py | 4 +- golem/managers/proposal/__init__.py | 8 +- golem/managers/proposal/plugins/__init__.py | 8 +- golem/managers/proposal/plugins/buffer.py | 58 +++++++------- .../plugins/negotiating/negotiating_plugin.py | 2 +- .../proposal/plugins/scoring/__init__.py | 4 +- .../proposal/plugins/scoring/mixins.py | 2 +- .../plugins/scoring/scoring_buffer.py | 79 ++++++++++++------- golem/managers/work/plugins.py | 4 +- golem/utils/asyncio/__init__.py | 8 +- golem/utils/asyncio/buffer.py | 67 ++++++++-------- golem/utils/asyncio/queue.py | 4 +- golem/utils/asyncio/semaphore.py | 1 - golem/utils/asyncio/tasks.py | 24 +++--- tests/unit/utils/test_buffer.py | 14 ++-- tests/unit/utils/test_semaphore.py | 2 +- 24 files changed, 196 insertions(+), 158 deletions(-) diff --git a/examples/managers/blender/blender.py b/examples/managers/blender/blender.py index 5554ecaa..b96bcbe7 100644 --- a/examples/managers/blender/blender.py +++ b/examples/managers/blender/blender.py @@ -15,8 +15,8 @@ PayAllPaymentManager, PaymentPlatformNegotiator, PoolActivityManager, + ProposalScoringBuffer, RefreshingDemandManager, - ScoringBufferPlugin, WorkContext, WorkResult, retry, @@ -58,7 +58,7 @@ async def run_on_golem( demand_manager.get_initial_proposal, plugins=[ NegotiatingPlugin(proposal_negotiators=negotiators), - ScoringBufferPlugin( + ProposalScoringBuffer( min_size=3, max_size=5, fill_concurrency_size=3, proposal_scorers=scorers ), ], diff --git a/examples/managers/flexible_negotiation.py b/examples/managers/flexible_negotiation.py index 3e7bdeef..05e2718d 100644 --- a/examples/managers/flexible_negotiation.py +++ b/examples/managers/flexible_negotiation.py @@ -3,13 +3,13 @@ from golem.managers import ( BlacklistProviderIdPlugin, - BufferPlugin, DefaultAgreementManager, DefaultProposalManager, NegotiatingPlugin, PayAllPaymentManager, PaymentPlatformNegotiator, PoolActivityManager, + ProposalBuffer, RefreshingDemandManager, SequentialWorkManager, WorkContext, @@ -48,14 +48,14 @@ async def main(): golem, demand_manager.get_initial_proposal, plugins=[ - BufferPlugin( + ProposalBuffer( min_size=10, max_size=1000, fill_concurrency_size=5, ), BlacklistProviderIdPlugin(BLACKLISTED_PROVIDERS), NegotiatingPlugin(proposal_negotiators=[PaymentPlatformNegotiator()]), - BufferPlugin( + ProposalBuffer( min_size=3, max_size=5, fill_concurrency_size=3, diff --git a/examples/managers/mid_agreement_payments.py b/examples/managers/mid_agreement_payments.py index 7c99aaa9..b94320a9 100644 --- a/examples/managers/mid_agreement_payments.py +++ b/examples/managers/mid_agreement_payments.py @@ -3,13 +3,13 @@ from datetime import datetime, timedelta from golem.managers import ( - BufferPlugin, DefaultAgreementManager, DefaultProposalManager, MidAgreementPaymentsNegotiator, NegotiatingPlugin, PayAllPaymentManager, PaymentPlatformNegotiator, + ProposalBuffer, RefreshingDemandManager, SequentialWorkManager, SingleUseActivityManager, @@ -91,7 +91,7 @@ async def main(): ), ] ), - BufferPlugin( + ProposalBuffer( min_size=1, max_size=4, fill_concurrency_size=2, diff --git a/examples/managers/proposal_plugins.py b/examples/managers/proposal_plugins.py index 53f85d22..8bf4ca9d 100644 --- a/examples/managers/proposal_plugins.py +++ b/examples/managers/proposal_plugins.py @@ -5,7 +5,6 @@ from golem.managers import ( BlacklistProviderIdPlugin, - BufferPlugin, DefaultAgreementManager, DefaultProposalManager, LinearAverageCostPricing, @@ -15,11 +14,12 @@ PaymentPlatformNegotiator, PoolActivityManager, PropertyValueLerpScore, + ProposalBuffer, + ProposalScoringBuffer, RandomScore, RefreshingDemandManager, RejectIfCostsExceeds, RejectProposal, - ScoringBufferPlugin, SequentialWorkManager, WorkContext, WorkResult, @@ -78,7 +78,7 @@ async def main(): golem, demand_manager.get_initial_proposal, plugins=[ - BufferPlugin( + ProposalBuffer( min_size=10, max_size=1000, ), @@ -101,7 +101,7 @@ async def main(): else None, ) ), - ScoringBufferPlugin( + ProposalScoringBuffer( min_size=3, max_size=5, fill_concurrency_size=3, diff --git a/examples/task_api_draft/task_api/activity_pool.py b/examples/task_api_draft/task_api/activity_pool.py index 9ce2fa95..f97523ce 100644 --- a/examples/task_api_draft/task_api/activity_pool.py +++ b/examples/task_api_draft/task_api/activity_pool.py @@ -4,7 +4,7 @@ from golem.pipeline import InputStreamExhausted from golem.resources import Activity -from golem.utils.asyncio import cancel_and_await +from golem.utils.asyncio import ensure_cancelled class ActivityPool: @@ -95,7 +95,7 @@ async def _activity_destroyed_cleanup( return await activity.wait_destroyed() - await cancel_and_await(manager_task) + await ensure_cancelled(manager_task) async def _get_next_idle_activity( self, activity_stream: AsyncIterator[Union[Activity, Awaitable[Activity]]] diff --git a/golem/event_bus/in_memory/event_bus.py b/golem/event_bus/in_memory/event_bus.py index 2b7bf85e..d019f312 100644 --- a/golem/event_bus/in_memory/event_bus.py +++ b/golem/event_bus/in_memory/event_bus.py @@ -5,7 +5,7 @@ from typing import Awaitable, Callable, DefaultDict, List, Optional, Tuple, Type from golem.event_bus.base import Event, EventBus, EventBusError, TEvent -from golem.utils.asyncio import cancel_and_await, create_task_with_logging +from golem.utils.asyncio import create_task_with_logging, ensure_cancelled from golem.utils.logging import get_trace_id_name, trace_span logger = logging.getLogger(__name__) @@ -49,7 +49,7 @@ async def stop(self): await self._event_queue.join() if self._process_event_queue_loop_task is not None: - await cancel_and_await(self._process_event_queue_loop_task) + await ensure_cancelled(self._process_event_queue_loop_task) self._process_event_queue_loop_task = None @trace_span(show_results=True) diff --git a/golem/managers/__init__.py b/golem/managers/__init__.py index 96006298..1d1beb3c 100644 --- a/golem/managers/__init__.py +++ b/golem/managers/__init__.py @@ -23,7 +23,6 @@ from golem.managers.payment import PayAllPaymentManager from golem.managers.proposal import ( BlacklistProviderIdPlugin, - BufferPlugin, DefaultProposalManager, LinearAverageCostPricing, LinearCoeffsCost, @@ -34,10 +33,11 @@ NegotiatingPlugin, PaymentPlatformNegotiator, PropertyValueLerpScore, + ProposalBuffer, + ProposalScoringBuffer, ProposalScoringMixin, RandomScore, RejectIfCostsExceeds, - ScoringBufferPlugin, ) from golem.managers.work import ( ConcurrentWorkManager, @@ -73,7 +73,7 @@ "PayAllPaymentManager", "DefaultProposalManager", "BlacklistProviderIdPlugin", - "BufferPlugin", + "ProposalBuffer", "PaymentPlatformNegotiator", "MidAgreementPaymentsNegotiator", "NegotiatingPlugin", @@ -86,7 +86,7 @@ "LinearCoeffsCost", "PropertyValueLerpScore", "RandomScore", - "ScoringBufferPlugin", + "ProposalScoringBuffer", "SequentialWorkManager", "ConcurrentWorkManager", "WorkManagerPluginsMixin", diff --git a/golem/managers/base.py b/golem/managers/base.py index 05484744..1a6006d1 100644 --- a/golem/managers/base.py +++ b/golem/managers/base.py @@ -2,7 +2,19 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from datetime import timedelta -from typing import Any, Awaitable, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union +from typing import ( + Any, + Awaitable, + Callable, + Dict, + List, + Optional, + Protocol, + Sequence, + Tuple, + TypeVar, + Union, runtime_checkable, +) from golem.exceptions import GolemException from golem.resources import ( @@ -181,8 +193,8 @@ class RejectProposal(ManagerPluginException): pass -class ProposalNegotiator(ABC): - @abstractmethod +@runtime_checkable +class ProposalNegotiator(Protocol): def __call__( self, demand_data: DemandData, proposal_data: ProposalData ) -> MaybeAwaitable[Optional[RejectProposal]]: @@ -209,8 +221,8 @@ async def stop(self) -> None: ProposalScoringResult = Sequence[Optional[float]] -class ProposalScorer(ABC): - @abstractmethod +@runtime_checkable +class ProposalScorer(Protocol): def __call__( self, proposals_data: Sequence[ProposalData] ) -> MaybeAwaitable[ProposalScoringResult]: @@ -220,10 +232,11 @@ def __call__( ScorerWithOptionalWeight = Union[ProposalScorer, Tuple[float, ProposalScorer]] -class WorkManagerPlugin(ABC): - @abstractmethod +@runtime_checkable +class WorkManagerPlugin(Protocol): def __call__(self, do_work: DoWorkCallable) -> DoWorkCallable: ... +# TODO: Make consistent naming on functions in arguments in whole project - callable or func PricingCallable = Callable[[ProposalData], Optional[float]] diff --git a/golem/managers/mixins.py b/golem/managers/mixins.py index e6793e20..9bcba831 100644 --- a/golem/managers/mixins.py +++ b/golem/managers/mixins.py @@ -3,7 +3,7 @@ from typing import Generic, List, Optional, Sequence from golem.managers.base import ManagerException, TPlugin -from golem.utils.asyncio import cancel_and_await, create_task_with_logging +from golem.utils.asyncio import create_task_with_logging, ensure_cancelled from golem.utils.logging import get_trace_id_name, trace_span logger = logging.getLogger(__name__) @@ -29,7 +29,7 @@ async def stop(self) -> None: raise ManagerException("Already stopped!") if self._background_loop_task is not None: - await cancel_and_await(self._background_loop_task) + await ensure_cancelled(self._background_loop_task) self._background_loop_task = None def is_started(self) -> bool: diff --git a/golem/managers/proposal/__init__.py b/golem/managers/proposal/__init__.py index 49918f2b..fee2e36c 100644 --- a/golem/managers/proposal/__init__.py +++ b/golem/managers/proposal/__init__.py @@ -1,7 +1,6 @@ from golem.managers.proposal.default import DefaultProposalManager from golem.managers.proposal.plugins import ( BlacklistProviderIdPlugin, - BufferPlugin, LinearAverageCostPricing, LinearCoeffsCost, LinearPerCpuAverageCostPricing, @@ -11,16 +10,17 @@ NegotiatingPlugin, PaymentPlatformNegotiator, PropertyValueLerpScore, + ProposalBuffer, + ProposalScoringBuffer, ProposalScoringMixin, RandomScore, RejectIfCostsExceeds, - ScoringBufferPlugin, ) __all__ = ( "DefaultProposalManager", "BlacklistProviderIdPlugin", - "BufferPlugin", + "ProposalBuffer", "PaymentPlatformNegotiator", "MidAgreementPaymentsNegotiator", "NegotiatingPlugin", @@ -33,5 +33,5 @@ "LinearPerCpuCoeffsCost", "PropertyValueLerpScore", "RandomScore", - "ScoringBufferPlugin", + "ProposalScoringBuffer", ) diff --git a/golem/managers/proposal/plugins/__init__.py b/golem/managers/proposal/plugins/__init__.py index 900eb32a..aac10077 100644 --- a/golem/managers/proposal/plugins/__init__.py +++ b/golem/managers/proposal/plugins/__init__.py @@ -1,5 +1,5 @@ from golem.managers.proposal.plugins.blacklist import BlacklistProviderIdPlugin -from golem.managers.proposal.plugins.buffer import BufferPlugin +from golem.managers.proposal.plugins.buffer import ProposalBuffer from golem.managers.proposal.plugins.linear_coeffs import LinearCoeffsCost, LinearPerCpuCoeffsCost from golem.managers.proposal.plugins.negotiating import ( MidAgreementPaymentsNegotiator, @@ -12,14 +12,14 @@ LinearPerCpuAverageCostPricing, MapScore, PropertyValueLerpScore, + ProposalScoringBuffer, ProposalScoringMixin, RandomScore, - ScoringBufferPlugin, ) __all__ = ( "BlacklistProviderIdPlugin", - "BufferPlugin", + "ProposalBuffer", "PaymentPlatformNegotiator", "MidAgreementPaymentsNegotiator", "NegotiatingPlugin", @@ -32,5 +32,5 @@ "LinearPerCpuCoeffsCost", "PropertyValueLerpScore", "RandomScore", - "ScoringBufferPlugin", + "ProposalScoringBuffer", ) diff --git a/golem/managers/proposal/plugins/buffer.py b/golem/managers/proposal/plugins/buffer.py index efb0cca7..9f524507 100644 --- a/golem/managers/proposal/plugins/buffer.py +++ b/golem/managers/proposal/plugins/buffer.py @@ -12,7 +12,7 @@ logger = logging.getLogger(__name__) -class BufferPlugin(ProposalManagerPlugin): +class ProposalBuffer(ProposalManagerPlugin): def __init__( self, min_size: int, @@ -31,63 +31,64 @@ def __init__( self._get_expiration_func = get_expiration_func self._on_expiration_func = on_expiration_func + # TODO: Consider moving buffer composition from here to plugin level buffer: Buffer[Proposal] = SimpleBuffer() - if self._get_expiration_func is not None: + if self._get_expiration_func: buffer = ExpirableBuffer( buffer=buffer, get_expiration_func=self._get_expiration_func, - on_expired_func=self._on_item_expire, + on_expired_func=self._on_expired, ) self._buffer: BackgroundFillBuffer[Proposal] = BackgroundFillBuffer( buffer=buffer, fill_func=self._call_feed_func, fill_concurrency_size=self._fill_concurrency_size, - on_added_func=self._on_item_added, + on_added_func=self._on_added, ) async def _call_feed_func(self) -> Proposal: return await self._get_proposal() - async def _on_item_added(self, item: Proposal) -> None: + async def _on_added(self, proposal: Proposal) -> None: count_current = self._buffer.size() count_with_requested = self._buffer.size_with_requested() pending = count_with_requested - count_current logger.debug( - "Item added, having %d items, and %d pending, target %d", + "Proposal added, having %d proposals, and %d already requested, target %d", count_current, pending, self._max_size, ) - async def _on_item_expire(self, item: Proposal): - logger.debug("Item %r expired, rejecting proposal and requesting fill", item) - await item.reject("Proposal no longer needed") - await self._request_items() + async def _on_expired(self, proposal: Proposal): + logger.debug("Rejecting expired `%r` and requesting fill", proposal) + await proposal.reject("Proposal no longer needed due to its near expiration.") + await self._request_proposals() - if self._on_expiration_func is not None: - await resolve_maybe_awaitable(self._on_expiration_func, item) + if self._on_expiration_func: + await resolve_maybe_awaitable(self._on_expiration_func(proposal)) @trace_span() async def start(self) -> None: await self._buffer.start() if self._fill_at_start: - await self._request_items() + await self._request_proposals() @trace_span() async def stop(self) -> None: await self._buffer.stop() - async def _request_items(self) -> None: + async def _request_proposals(self) -> None: count_current = self._buffer.size() count_with_requested = self._buffer.size_with_requested() requested = self._max_size - count_with_requested logger.debug( - "Having %d items, and %d already requested, requesting additional %d items to match" + "Proposal count %d and %d already requested, requesting additional %d to match" " target %d", count_current, count_with_requested - count_current, @@ -99,32 +100,31 @@ async def _request_items(self) -> None: @trace_span(show_results=True) async def get_proposal(self) -> Proposal: - if not self._get_items_count(): - logger.debug("No items to get, requesting fill") - await self._request_items() + if not self._get_buffered_proposals_count(): + logger.debug("No proposals to get, requesting fill") + await self._request_proposals() - proposal = await self._get_item() + proposal = await self._get_buffered_proposal() - items_count = self._get_items_count() - if items_count < self._min_size: + proposals_count = self._get_buffered_proposals_count() + if proposals_count < self._min_size: logger.debug( - "Items count is now `%s` which is below min size `%d`, requesting fill", - items_count, + "Proposals count %d is below minimum size %d, requesting fill", + proposals_count, self._min_size, ) - await self._request_items() + await self._request_proposals() else: logger.debug( - "Target items is now `%s` which is not below min size `%d`, requesting fill not" - " needed", - items_count, + "Proposals count %d is not below minimum size %d, skipping fill", + proposals_count, self._min_size, ) return proposal - async def _get_item(self) -> Proposal: + async def _get_buffered_proposal(self) -> Proposal: return await self._buffer.get() - def _get_items_count(self) -> int: + def _get_buffered_proposals_count(self) -> int: return self._buffer.size() diff --git a/golem/managers/proposal/plugins/negotiating/negotiating_plugin.py b/golem/managers/proposal/plugins/negotiating/negotiating_plugin.py index 283dd993..199339dc 100644 --- a/golem/managers/proposal/plugins/negotiating/negotiating_plugin.py +++ b/golem/managers/proposal/plugins/negotiating/negotiating_plugin.py @@ -118,7 +118,7 @@ async def _run_negotiators( for negotiator in self._proposal_negotiators: negotiator_result = await resolve_maybe_awaitable( - negotiator, demand_data_after_negotiators, proposal_data + negotiator(demand_data_after_negotiators, proposal_data) ) if isinstance(negotiator_result, RejectProposal): diff --git a/golem/managers/proposal/plugins/scoring/__init__.py b/golem/managers/proposal/plugins/scoring/__init__.py index b29ac289..35b374f2 100644 --- a/golem/managers/proposal/plugins/scoring/__init__.py +++ b/golem/managers/proposal/plugins/scoring/__init__.py @@ -6,7 +6,7 @@ ) from golem.managers.proposal.plugins.scoring.property_value_lerp import PropertyValueLerpScore from golem.managers.proposal.plugins.scoring.random import RandomScore -from golem.managers.proposal.plugins.scoring.scoring_buffer import ScoringBufferPlugin +from golem.managers.proposal.plugins.scoring.scoring_buffer import ProposalScoringBuffer __all__ = ( "MapScore", @@ -15,5 +15,5 @@ "LinearPerCpuAverageCostPricing", "PropertyValueLerpScore", "RandomScore", - "ScoringBufferPlugin", + "ProposalScoringBuffer", ) diff --git a/golem/managers/proposal/plugins/scoring/mixins.py b/golem/managers/proposal/plugins/scoring/mixins.py index b29ade39..c314870a 100644 --- a/golem/managers/proposal/plugins/scoring/mixins.py +++ b/golem/managers/proposal/plugins/scoring/mixins.py @@ -43,7 +43,7 @@ async def _run_scorers( else: weight = 1 - scorer_scores = await resolve_maybe_awaitable(scorer, proposals_data) + scorer_scores = await resolve_maybe_awaitable(scorer(proposals_data)) proposal_scores.append((weight, scorer_scores)) # type: ignore[arg-type] diff --git a/golem/managers/proposal/plugins/scoring/scoring_buffer.py b/golem/managers/proposal/plugins/scoring/scoring_buffer.py index fe6ccf53..e5384386 100644 --- a/golem/managers/proposal/plugins/scoring/scoring_buffer.py +++ b/golem/managers/proposal/plugins/scoring/scoring_buffer.py @@ -1,33 +1,53 @@ import asyncio import logging from datetime import timedelta -from typing import Optional +from typing import Callable, Optional, Sequence -from golem.managers.proposal.plugins.buffer import BufferPlugin as BufferPlugin +from golem.managers.base import ScorerWithOptionalWeight +from golem.managers.proposal.plugins.buffer import ProposalBuffer from golem.managers.proposal.plugins.scoring import ProposalScoringMixin from golem.resources import Proposal from golem.utils.asyncio import ( Buffer, ExpirableBuffer, SimpleBuffer, - cancel_and_await, create_task_with_logging, + ensure_cancelled, ) from golem.utils.logging import get_trace_id_name, trace_span +from golem.utils.typing import MaybeAwaitable logger = logging.getLogger(__name__) -class ScoringBufferPlugin(ProposalScoringMixin, BufferPlugin): - def __init__(self, update_interval: timedelta = timedelta(seconds=10), *args, **kwargs) -> None: - get_expiration_func = kwargs.pop("get_expiration_func", None) - - super().__init__(*args, **kwargs) +class ProposalScoringBuffer(ProposalScoringMixin, ProposalBuffer): + def __init__( + self, + min_size: int, + max_size: int, + fill_concurrency_size=1, + fill_at_start=False, + get_expiration_func: Optional[ + Callable[[Proposal], MaybeAwaitable[Optional[timedelta]]] + ] = None, + on_expiration_func: Optional[Callable[[Proposal], MaybeAwaitable[None]]] = None, + scoring_debounce: timedelta = timedelta(seconds=10), + proposal_scorers: Optional[Sequence[ScorerWithOptionalWeight]] = None, + ) -> None: + super().__init__( + min_size=min_size, + max_size=max_size, + fill_concurrency_size=fill_concurrency_size, + fill_at_start=fill_at_start, + get_expiration_func=None, + on_expiration_func=on_expiration_func, + proposal_scorers=proposal_scorers, + ) - self._update_interval = update_interval + self._scoring_debounce = scoring_debounce - # Postponing argument would disable expiration from BufferPlugin parent - # as we want to expire only scored items instead + # Postponing argument would disable expiration from ProposalBuffer parent + # as we want to expire only scored proposals instead self._get_expiration_func = get_expiration_func scored_buffer: Buffer[Proposal] = SimpleBuffer() @@ -36,7 +56,7 @@ def __init__(self, update_interval: timedelta = timedelta(seconds=10), *args, ** scored_buffer = ExpirableBuffer( buffer=scored_buffer, get_expiration_func=get_expiration_func, - on_expired_func=self._on_item_expire, + on_expired_func=self._on_expired, ) self._buffer_scored: Buffer[Proposal] = scored_buffer @@ -55,33 +75,34 @@ async def stop(self) -> None: await super().stop() if self._background_loop_task is not None: - await cancel_and_await(self._background_loop_task) + await ensure_cancelled(self._background_loop_task) self._background_loop_task = None - async def _on_item_added(self, item: Proposal): + async def _on_added(self, proposal: Proposal) -> None: pass # explicit no-op async def _background_loop(self) -> None: while True: - logger.debug("Waiting for any items to score...") - await self._buffer.wait_for_any_items() - logger.debug("Waiting for any items to score done, items are available for scoring") - - logger.debug("Waiting for more items up to %s...", self._update_interval) - items = await self._buffer.get_all_requested(self._update_interval) - logger.debug("Waiting for more items done, %d new items will be scored", len(items)) + logger.debug( + "Waiting for any proposals to score with debounce of `%s`...", + self._scoring_debounce, + ) + proposals = await self._buffer.get_requested(self._scoring_debounce) + logger.debug( + "Waiting for any proposals done, %d new proposals will be scored", len(proposals) + ) - items.extend(await self._buffer_scored.get_all()) + proposals.extend(await self._buffer_scored.get_all()) - logger.debug("Scoring total %d items...", len(items)) + logger.debug("Scoring total %d proposals...", len(proposals)) - scored_items = await self.do_scoring(items) - await self._buffer_scored.put_all([proposal for _, proposal in scored_items]) + scored_proposals = await self.do_scoring(proposals) + await self._buffer_scored.put_all([proposal for _, proposal in scored_proposals]) - logger.debug("Scoring total %d items done", len(items)) + logger.debug("Scoring total %d proposals done", len(proposals)) - async def _get_item(self) -> Proposal: + async def _get_buffered_proposal(self) -> Proposal: return await self._buffer_scored.get() - def _get_items_count(self) -> int: - return super()._get_items_count() + self._buffer_scored.size() + def _get_buffered_proposals_count(self) -> int: + return super()._get_buffered_proposals_count() + self._buffer_scored.size() diff --git a/golem/managers/work/plugins.py b/golem/managers/work/plugins.py index bfeab014..78a0eb8e 100644 --- a/golem/managers/work/plugins.py +++ b/golem/managers/work/plugins.py @@ -10,7 +10,7 @@ WorkManagerPlugin, WorkResult, ) -from golem.utils.asyncio import cancel_and_await_many +from golem.utils.asyncio import ensure_cancelled_many logger = logging.getLogger(__name__) @@ -71,7 +71,7 @@ async def wrapper(work: Work) -> WorkResult: tasks, return_when=asyncio.FIRST_COMPLETED ) - await cancel_and_await_many(tasks_pending) + await ensure_cancelled_many(tasks_pending) return tasks_done.pop().result() diff --git a/golem/utils/asyncio/__init__.py b/golem/utils/asyncio/__init__.py index e7a60662..dbdef0df 100644 --- a/golem/utils/asyncio/__init__.py +++ b/golem/utils/asyncio/__init__.py @@ -8,9 +8,9 @@ from golem.utils.asyncio.queue import ErrorReportingQueue from golem.utils.asyncio.semaphore import SingleUseSemaphore from golem.utils.asyncio.tasks import ( - cancel_and_await, - cancel_and_await_many, create_task_with_logging, + ensure_cancelled, + ensure_cancelled_many, ) from golem.utils.asyncio.waiter import Waiter @@ -22,8 +22,8 @@ "SimpleBuffer", "ErrorReportingQueue", "SingleUseSemaphore", - "cancel_and_await", - "cancel_and_await_many", + "ensure_cancelled", + "ensure_cancelled_many", "create_task_with_logging", "Waiter", ) diff --git a/golem/utils/asyncio/buffer.py b/golem/utils/asyncio/buffer.py index 78bf4e27..940d5b36 100644 --- a/golem/utils/asyncio/buffer.py +++ b/golem/utils/asyncio/buffer.py @@ -18,11 +18,11 @@ from golem.utils.asyncio.semaphore import SingleUseSemaphore from golem.utils.asyncio.tasks import ( - cancel_and_await_many, create_task_with_logging, + ensure_cancelled_many, resolve_maybe_awaitable, ) -from golem.utils.logging import get_trace_id_name, trace_span +from golem.utils.logging import get_trace_id_name from golem.utils.typing import MaybeAwaitable TItem = TypeVar("TItem") @@ -34,11 +34,12 @@ class Buffer(ABC, Generic[TItem]): """Interface class for object similar to `asyncio.Queue` but with more control over its \ items.""" + # TODO: Consider hiding condition in favor of public lock related context manager method condition: asyncio.Condition @abstractmethod def size(self) -> int: - """Return number of items stored in buffer.""" + """Return the number of items stored in the buffer.""" @abstractmethod async def wait_for_any_items(self, *, lock=True) -> None: @@ -46,19 +47,20 @@ async def wait_for_any_items(self, *, lock=True) -> None: @abstractmethod async def get(self, *, lock=True) -> TItem: - """Await, remove and return left-most item stored in buffer. + """Await, remove and return left-most item stored in the buffer. - If `.set_exception()` was previously called, exception will be raised only if buffer + If `.set_exception()` was previously called, the exception will be raised only if the buffer is empty. """ @abstractmethod async def get_all(self, *, lock=True) -> MutableSequence[TItem]: - """Remove and return all items stored in buffer. + """Remove and immediately return all items stored in the buffer. - Note that this method will not await for any items if buffer is empty. + Note that this method should not await any items if the buffer is empty and should instead + return an empty iterable. - If `.set_exception()` was previously called, exception will be raised only if buffer + If `.set_exception()` was previously called, exception should be raised only if buffer is empty. """ @@ -78,20 +80,21 @@ async def put_all(self, items: Sequence[TItem], *, lock=True) -> None: @abstractmethod async def remove(self, item: TItem, *, lock=True) -> None: - """Remove first occurrence of item from buffer or raise `ValueError` if not found.""" + """Remove the first occurrence of the item from the buffer or raise `ValueError` \ + if not found.""" @abstractmethod async def set_exception(self, exc: BaseException, *, lock=True) -> None: - """Set exception that will be raised while trying to `.get()`/`.get_all()` item from \ - empty buffer.""" + """Set the exception that will be raised on `.get()`/`.get_all()` if there are no more \ + items in the buffer.""" def reset_exception(self) -> None: - """Reset exception that was previously set by calling `.set_exception()`.""" + """Reset (clear) the exception that was previously set by calling `.set_exception()`.""" class ComposableBuffer(Buffer[TItem]): """Utility class for composable/stackable buffer implementations to help with calling \ - underlying buffer.""" + the underlying buffer.""" def __init__(self, buffer: Buffer[TItem]): self._buffer = buffer @@ -133,10 +136,10 @@ def reset_exception(self) -> None: class SimpleBuffer(Buffer[TItem]): - """Most basic implementation of Buffer interface.""" + """Basic implementation of the Buffer interface.""" def __init__(self, items: Optional[Sequence[TItem]] = None): - self._items = list(items) if items is not None else [] + self._items = list(items) if items else [] self._error: Optional[BaseException] = None self.condition = asyncio.Condition() @@ -152,12 +155,10 @@ async def _handle_lock(self, lock: bool): else: yield - @trace_span() async def wait_for_any_items(self, lock=True) -> None: async with self._handle_lock(lock): await self.condition.wait_for(lambda: bool(self.size() or self._error)) - @trace_span() async def get(self, *, lock=True) -> TItem: async with self._handle_lock(lock): await self.wait_for_any_items(lock=False) @@ -226,7 +227,7 @@ def __init__( self._expiration_handlers: Dict[int, List[asyncio.TimerHandle]] = defaultdict(list) async def _add_expiration_task_for_item(self, item: TItem) -> None: - expiration = await resolve_maybe_awaitable(self._get_expiration_func, item) + expiration = await resolve_maybe_awaitable(self._get_expiration_func(item)) if expiration is None: return @@ -300,12 +301,11 @@ async def remove(self, item: TItem, *, lock=True) -> None: self._remove_expiration_handler_for_item(item) - @trace_span(show_arguments=True) async def _expire_item(self, item: TItem) -> None: await self.remove(item) if self._on_expired_func: - await resolve_maybe_awaitable(self._on_expired_func, item) + await resolve_maybe_awaitable(self._on_expired_func(item)) class BackgroundFillBuffer(ComposableBuffer[TItem]): @@ -332,7 +332,6 @@ def __init__( self._worker_tasks: List[asyncio.Task] = [] self._workers_semaphore = SingleUseSemaphore() - @trace_span() async def start(self) -> None: if self.is_started(): raise RuntimeError("Already started!") @@ -346,12 +345,11 @@ async def start(self) -> None: self._is_started = True - @trace_span() async def stop(self) -> None: if not self.is_started(): raise RuntimeError("Already stopped!") - await cancel_and_await_many(self._worker_tasks) + await ensure_cancelled_many(self._worker_tasks) self._worker_tasks.clear() self._is_started = False @@ -373,7 +371,7 @@ async def _worker_loop(self) -> None: await self.put(item) - if self._on_added_func is not None: + if self._on_added_func: await self._on_added_func(item) logger.debug("Adding new item done with total of %d items in buffer", self.size()) @@ -388,16 +386,17 @@ def size_with_requested(self) -> int: return self.size() + self._workers_semaphore.get_count_with_pending() - async def get_all_requested(self, deadline: timedelta) -> MutableSequence[TItem]: - """Await for all requested items with given deadline, then remove and return all items \ - stored in buffer.""" + async def get_requested(self, debounce: timedelta) -> MutableSequence[TItem]: + """Await for any requested items with given debounce time window, then remove and return \ + all items stored in buffer.""" - if not self._workers_semaphore.finished.is_set(): - try: - await asyncio.wait_for( - self._workers_semaphore.finished.wait(), deadline.total_seconds() - ) - except asyncio.TimeoutError: - pass + await self._buffer.wait_for_any_items() + + try: + await asyncio.wait_for( + self._workers_semaphore.finished.wait(), debounce.total_seconds() + ) + except asyncio.TimeoutError: + pass return await self.get_all() diff --git a/golem/utils/asyncio/queue.py b/golem/utils/asyncio/queue.py index b5b612c9..a2df2715 100644 --- a/golem/utils/asyncio/queue.py +++ b/golem/utils/asyncio/queue.py @@ -1,7 +1,7 @@ import asyncio from typing import Generic, Optional, TypeVar -from golem.utils.asyncio.tasks import cancel_and_await_many +from golem.utils.asyncio.tasks import ensure_cancelled_many TQueueItem = TypeVar("TQueueItem") @@ -37,7 +37,7 @@ async def get(self) -> TQueueItem: [error_task, get_task], return_when=asyncio.FIRST_COMPLETED ) - await cancel_and_await_many(pending) + await ensure_cancelled_many(pending) if get_task in done: return await get_task diff --git a/golem/utils/asyncio/semaphore.py b/golem/utils/asyncio/semaphore.py index 8d651bfd..8557d984 100644 --- a/golem/utils/asyncio/semaphore.py +++ b/golem/utils/asyncio/semaphore.py @@ -77,4 +77,3 @@ def reset(self) -> None: """Reset "charges" amount to zero.""" self._value = 0 - self.finished.set() diff --git a/golem/utils/asyncio/tasks.py b/golem/utils/asyncio/tasks.py index a425838a..06496b0a 100644 --- a/golem/utils/asyncio/tasks.py +++ b/golem/utils/asyncio/tasks.py @@ -2,7 +2,7 @@ import contextvars import inspect import logging -from typing import Callable, Iterable, Optional, TypeVar, cast +from typing import Iterable, Optional, TypeVar, cast from golem.utils.logging import trace_id_var from golem.utils.typing import MaybeAwaitable @@ -44,7 +44,9 @@ def _handle_task_logging(task: asyncio.Task): logger.exception("Background async task encountered unhandled exception!") -async def cancel_and_await(task: asyncio.Task) -> None: +async def ensure_cancelled(task: asyncio.Task) -> None: + """Cancel given task and await for its cancellation.""" + if task.done(): return @@ -56,14 +58,18 @@ async def cancel_and_await(task: asyncio.Task) -> None: pass -async def cancel_and_await_many(tasks: Iterable[asyncio.Task]) -> None: - await asyncio.gather(*[cancel_and_await(task) for task in tasks]) +async def ensure_cancelled_many(tasks: Iterable[asyncio.Task]) -> None: + """Cancel given tasks and concurrently await for their cancellation.""" + + await asyncio.gather(*[ensure_cancelled(task) for task in tasks]) -async def resolve_maybe_awaitable(func: Callable[..., MaybeAwaitable[T]], *args, **kwargs) -> T: - result = func(*args, **kwargs) +async def resolve_maybe_awaitable(value: MaybeAwaitable[T]) -> T: + """Return given value or await for it results if value is awaitable.""" - if inspect.iscoroutine(result): - result = await result + if inspect.isawaitable(value): + return await value - return cast(T, result) # FIXME: This cast should not be needed + # TODO: remove cast as inspect.isawaitable can't tell mypy that at this + # point value is T + return cast(T, value) diff --git a/tests/unit/utils/test_buffer.py b/tests/unit/utils/test_buffer.py index e6624311..0fd71174 100644 --- a/tests/unit/utils/test_buffer.py +++ b/tests/unit/utils/test_buffer.py @@ -230,7 +230,7 @@ async def test_simple_buffer_wait_for_any_items(): assert wait_task.done() -async def test_expirable_buffer_is_not_expiring_initial_items(mocked_buffer, mocker): +async def test_expirable_buffer_is_not_expiring_initial_items(mocked_buffer): expire_after = timedelta(seconds=0.1) ExpirableBuffer( mocked_buffer, @@ -422,12 +422,12 @@ async def test_background_feed_buffer_get_all_requested(mocked_buffer, mocker, e mocked_buffer.size.return_value = 0 time_before_wait = event_loop.time() - assert await buffer.get_all_requested(timeout) == [] + assert await buffer.get_requested(timeout) == [] time_after_wait = event_loop.time() assert ( time_after_wait - time_before_wait < timeout.total_seconds() - ), "get_all_requested seems to wait for the deadline instead of retuning fast" + ), "get_requested seems to wait for the deadline instead of retuning fast" # await buffer.request(1) @@ -444,12 +444,12 @@ async def test_background_feed_buffer_get_all_requested(mocked_buffer, mocker, e mocked_buffer.size.return_value = 0 time_before_wait = event_loop.time() - assert await buffer.get_all_requested(timeout) == [] + assert await buffer.get_requested(timeout) == [] time_after_wait = event_loop.time() assert ( timeout.total_seconds() <= time_after_wait - time_before_wait - ), "get_all_requested seems to not wait to the deadline" + ), "get_requested seems to not wait to the deadline" # await feed_queue.put(item) @@ -466,12 +466,12 @@ async def test_background_feed_buffer_get_all_requested(mocked_buffer, mocker, e mocked_buffer.size.return_value = 0 time_before_wait = event_loop.time() - assert await buffer.get_all_requested(timeout) == [item] + assert await buffer.get_requested(timeout) == [item] time_after_wait = event_loop.time() assert ( time_after_wait - time_before_wait < timeout.total_seconds() - ), "get_all_requested seems to wait for the deadline instead of retuning fast" + ), "get_requested seems to wait for the deadline instead of retuning fast" # await buffer.stop() diff --git a/tests/unit/utils/test_semaphore.py b/tests/unit/utils/test_semaphore.py index adab0022..ca986127 100644 --- a/tests/unit/utils/test_semaphore.py +++ b/tests/unit/utils/test_semaphore.py @@ -76,7 +76,7 @@ async def test_reset(): assert sem.get_count() == 0 assert sem.get_count_with_pending() == 0 assert sem.get_pending_count() == 0 - assert sem.finished.is_set() + assert not sem.finished.is_set() assert sem.locked() From 9cac75e0e6cdc7c1a8ef68f63a8a3b9a2eeddb35 Mon Sep 17 00:00:00 2001 From: approxit Date: Wed, 31 Jan 2024 20:50:13 +0100 Subject: [PATCH 18/20] reformat --- golem/managers/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/golem/managers/base.py b/golem/managers/base.py index 1a6006d1..34f64b13 100644 --- a/golem/managers/base.py +++ b/golem/managers/base.py @@ -13,7 +13,8 @@ Sequence, Tuple, TypeVar, - Union, runtime_checkable, + Union, + runtime_checkable, ) from golem.exceptions import GolemException From 7bd090558cca76acb2cc603a53f7a76b9e68390b Mon Sep 17 00:00:00 2001 From: approxit Date: Thu, 1 Feb 2024 16:47:15 +0100 Subject: [PATCH 19/20] review changes --- golem/managers/proposal/plugins/buffer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/golem/managers/proposal/plugins/buffer.py b/golem/managers/proposal/plugins/buffer.py index 9f524507..7d3d6dfc 100644 --- a/golem/managers/proposal/plugins/buffer.py +++ b/golem/managers/proposal/plugins/buffer.py @@ -17,7 +17,7 @@ def __init__( self, min_size: int, max_size: int, - fill_concurrency_size=1, + fill_concurrency_size: int = 1, fill_at_start=False, get_expiration_func: Optional[ Callable[[Proposal], MaybeAwaitable[Optional[timedelta]]] @@ -116,7 +116,7 @@ async def get_proposal(self) -> Proposal: await self._request_proposals() else: logger.debug( - "Proposals count %d is not below minimum size %d, skipping fill", + "Proposals count %d is above minimum size %d, skipping fill", proposals_count, self._min_size, ) From cab74a54f050b9626d2ae5b6440cded9343a415a Mon Sep 17 00:00:00 2001 From: approxit Date: Fri, 2 Feb 2024 10:32:02 +0100 Subject: [PATCH 20/20] Update golem/utils/asyncio/buffer.py Co-authored-by: shadeofblue --- golem/utils/asyncio/buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/golem/utils/asyncio/buffer.py b/golem/utils/asyncio/buffer.py index 940d5b36..c437db37 100644 --- a/golem/utils/asyncio/buffer.py +++ b/golem/utils/asyncio/buffer.py @@ -43,7 +43,7 @@ def size(self) -> int: @abstractmethod async def wait_for_any_items(self, *, lock=True) -> None: - """Wait until any items are stored in buffer.""" + """Wait until any items are stored in the buffer.""" @abstractmethod async def get(self, *, lock=True) -> TItem: