diff --git a/.github/workflows/tests-unit.yml b/.github/workflows/tests-unit.yml index 1af51c31..05a34296 100644 --- a/.github/workflows/tests-unit.yml +++ b/.github/workflows/tests-unit.yml @@ -15,7 +15,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] os: - ubuntu-latest - macos-latest diff --git a/examples/managers/blender/blender.py b/examples/managers/blender/blender.py index bbb8e338..b96bcbe7 100644 --- a/examples/managers/blender/blender.py +++ b/examples/managers/blender/blender.py @@ -15,8 +15,8 @@ PayAllPaymentManager, PaymentPlatformNegotiator, PoolActivityManager, + ProposalScoringBuffer, RefreshingDemandManager, - ScoringBuffer, WorkContext, WorkResult, retry, @@ -58,7 +58,7 @@ async def run_on_golem( demand_manager.get_initial_proposal, plugins=[ NegotiatingPlugin(proposal_negotiators=negotiators), - ScoringBuffer( + ProposalScoringBuffer( min_size=3, max_size=5, fill_concurrency_size=3, proposal_scorers=scorers ), ], diff --git a/examples/managers/flexible_negotiation.py b/examples/managers/flexible_negotiation.py index c173b259..05e2718d 100644 --- a/examples/managers/flexible_negotiation.py +++ b/examples/managers/flexible_negotiation.py @@ -3,13 +3,13 @@ from golem.managers import ( BlacklistProviderIdPlugin, - Buffer, DefaultAgreementManager, DefaultProposalManager, NegotiatingPlugin, PayAllPaymentManager, PaymentPlatformNegotiator, PoolActivityManager, + ProposalBuffer, RefreshingDemandManager, SequentialWorkManager, WorkContext, @@ -48,14 +48,14 @@ async def main(): golem, demand_manager.get_initial_proposal, plugins=[ - Buffer( + ProposalBuffer( min_size=10, max_size=1000, fill_concurrency_size=5, ), BlacklistProviderIdPlugin(BLACKLISTED_PROVIDERS), NegotiatingPlugin(proposal_negotiators=[PaymentPlatformNegotiator()]), - Buffer( + ProposalBuffer( min_size=3, max_size=5, fill_concurrency_size=3, diff --git a/examples/managers/mid_agreement_payments.py b/examples/managers/mid_agreement_payments.py index 72f75087..b94320a9 100644 --- a/examples/managers/mid_agreement_payments.py +++ b/examples/managers/mid_agreement_payments.py @@ -3,13 +3,13 @@ from datetime import datetime, timedelta from golem.managers import ( - Buffer, DefaultAgreementManager, DefaultProposalManager, MidAgreementPaymentsNegotiator, NegotiatingPlugin, PayAllPaymentManager, PaymentPlatformNegotiator, + ProposalBuffer, RefreshingDemandManager, SequentialWorkManager, SingleUseActivityManager, @@ -91,7 +91,7 @@ async def main(): ), ] ), - Buffer( + ProposalBuffer( min_size=1, max_size=4, fill_concurrency_size=2, diff --git a/examples/managers/proposal_plugins.py b/examples/managers/proposal_plugins.py index 552781f7..8bf4ca9d 100644 --- a/examples/managers/proposal_plugins.py +++ b/examples/managers/proposal_plugins.py @@ -5,7 +5,6 @@ from golem.managers import ( BlacklistProviderIdPlugin, - Buffer, DefaultAgreementManager, DefaultProposalManager, LinearAverageCostPricing, @@ -15,11 +14,12 @@ PaymentPlatformNegotiator, PoolActivityManager, PropertyValueLerpScore, + ProposalBuffer, + ProposalScoringBuffer, RandomScore, RefreshingDemandManager, RejectIfCostsExceeds, RejectProposal, - ScoringBuffer, SequentialWorkManager, WorkContext, WorkResult, @@ -78,7 +78,7 @@ async def main(): golem, demand_manager.get_initial_proposal, plugins=[ - Buffer( + ProposalBuffer( min_size=10, max_size=1000, ), @@ -101,7 +101,7 @@ async def main(): else None, ) ), - ScoringBuffer( + ProposalScoringBuffer( min_size=3, max_size=5, fill_concurrency_size=3, diff --git a/examples/task_api_draft/examples/yacat.py b/examples/task_api_draft/examples/yacat.py index 4854d222..1849ca75 100644 --- a/examples/task_api_draft/examples/yacat.py +++ b/examples/task_api_draft/examples/yacat.py @@ -1,7 +1,6 @@ import asyncio from collections import defaultdict from typing import ( - Any, AsyncIterator, Callable, DefaultDict, @@ -221,9 +220,7 @@ async def main() -> None: ############################################# # NOT REALLY INTERESTING PARTS OF THE LOGIC -async def close_agreement_repeat_task( - func: Callable, args: Tuple[Activity, Any], e: Exception -) -> None: +async def close_agreement_repeat_task(func: Callable, args: Tuple, e: Exception) -> None: activity, task = args tasks_queue.put_nowait(task) print("Task failed on", activity) diff --git a/examples/task_api_draft/task_api/activity_pool.py b/examples/task_api_draft/task_api/activity_pool.py index df814fa3..f97523ce 100644 --- a/examples/task_api_draft/task_api/activity_pool.py +++ b/examples/task_api_draft/task_api/activity_pool.py @@ -4,6 +4,7 @@ from golem.pipeline import InputStreamExhausted from golem.resources import Activity +from golem.utils.asyncio import ensure_cancelled class ActivityPool: @@ -94,7 +95,7 @@ async def _activity_destroyed_cleanup( return await activity.wait_destroyed() - manager_task.cancel() + await ensure_cancelled(manager_task) async def _get_next_idle_activity( self, activity_stream: AsyncIterator[Union[Activity, Awaitable[Activity]]] diff --git a/examples/task_api_draft/task_api/execute_tasks.py b/examples/task_api_draft/task_api/execute_tasks.py index 1618912e..656d6b24 100644 --- a/examples/task_api_draft/task_api/execute_tasks.py +++ b/examples/task_api_draft/task_api/execute_tasks.py @@ -31,10 +31,10 @@ async def random_score(proposal: Proposal) -> float: def close_agreement_repeat_task( task_stream: TaskDataStream[TaskData], -) -> Callable[[Callable, Tuple[Activity, TaskData], Exception], Awaitable[None]]: +) -> Callable[[Callable, Tuple, Exception], Awaitable[None]]: async def on_exception( - func: Callable[[Activity, TaskData], Awaitable[TaskResult]], - args: Tuple[Activity, TaskData], + func: Callable, + args: Tuple, e: Exception, ) -> None: activity, in_data = args diff --git a/golem/event_bus/in_memory/event_bus.py b/golem/event_bus/in_memory/event_bus.py index 052f2959..d019f312 100644 --- a/golem/event_bus/in_memory/event_bus.py +++ b/golem/event_bus/in_memory/event_bus.py @@ -5,7 +5,7 @@ from typing import Awaitable, Callable, DefaultDict, List, Optional, Tuple, Type from golem.event_bus.base import Event, EventBus, EventBusError, TEvent -from golem.utils.asyncio import create_task_with_logging +from golem.utils.asyncio import create_task_with_logging, ensure_cancelled from golem.utils.logging import get_trace_id_name, trace_span logger = logging.getLogger(__name__) @@ -49,7 +49,7 @@ async def stop(self): await self._event_queue.join() if self._process_event_queue_loop_task is not None: - self._process_event_queue_loop_task.cancel() + await ensure_cancelled(self._process_event_queue_loop_task) self._process_event_queue_loop_task = None @trace_span(show_results=True) diff --git a/golem/managers/__init__.py b/golem/managers/__init__.py index 470f15da..1d1beb3c 100644 --- a/golem/managers/__init__.py +++ b/golem/managers/__init__.py @@ -23,7 +23,6 @@ from golem.managers.payment import PayAllPaymentManager from golem.managers.proposal import ( BlacklistProviderIdPlugin, - Buffer, DefaultProposalManager, LinearAverageCostPricing, LinearCoeffsCost, @@ -34,10 +33,11 @@ NegotiatingPlugin, PaymentPlatformNegotiator, PropertyValueLerpScore, + ProposalBuffer, + ProposalScoringBuffer, ProposalScoringMixin, RandomScore, RejectIfCostsExceeds, - ScoringBuffer, ) from golem.managers.work import ( ConcurrentWorkManager, @@ -73,7 +73,7 @@ "PayAllPaymentManager", "DefaultProposalManager", "BlacklistProviderIdPlugin", - "Buffer", + "ProposalBuffer", "PaymentPlatformNegotiator", "MidAgreementPaymentsNegotiator", "NegotiatingPlugin", @@ -86,7 +86,7 @@ "LinearCoeffsCost", "PropertyValueLerpScore", "RandomScore", - "ScoringBuffer", + "ProposalScoringBuffer", "SequentialWorkManager", "ConcurrentWorkManager", "WorkManagerPluginsMixin", diff --git a/golem/managers/base.py b/golem/managers/base.py index cc5df9b6..34f64b13 100644 --- a/golem/managers/base.py +++ b/golem/managers/base.py @@ -2,7 +2,20 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from datetime import timedelta -from typing import Any, Awaitable, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union +from typing import ( + Any, + Awaitable, + Callable, + Dict, + List, + Optional, + Protocol, + Sequence, + Tuple, + TypeVar, + Union, + runtime_checkable, +) from golem.exceptions import GolemException from golem.resources import ( @@ -15,6 +28,7 @@ Script, ) from golem.resources.activity import commands +from golem.utils.typing import MaybeAwaitable logger = logging.getLogger(__name__) @@ -180,11 +194,11 @@ class RejectProposal(ManagerPluginException): pass -class ProposalNegotiator(ABC): - @abstractmethod +@runtime_checkable +class ProposalNegotiator(Protocol): def __call__( self, demand_data: DemandData, proposal_data: ProposalData - ) -> Union[Awaitable[Optional[RejectProposal]], Optional[RejectProposal]]: + ) -> MaybeAwaitable[Optional[RejectProposal]]: ... @@ -208,21 +222,22 @@ async def stop(self) -> None: ProposalScoringResult = Sequence[Optional[float]] -class ProposalScorer(ABC): - @abstractmethod +@runtime_checkable +class ProposalScorer(Protocol): def __call__( self, proposals_data: Sequence[ProposalData] - ) -> Union[Awaitable[ProposalScoringResult], ProposalScoringResult]: + ) -> MaybeAwaitable[ProposalScoringResult]: ... ScorerWithOptionalWeight = Union[ProposalScorer, Tuple[float, ProposalScorer]] -class WorkManagerPlugin(ABC): - @abstractmethod +@runtime_checkable +class WorkManagerPlugin(Protocol): def __call__(self, do_work: DoWorkCallable) -> DoWorkCallable: ... +# TODO: Make consistent naming on functions in arguments in whole project - callable or func PricingCallable = Callable[[ProposalData], Optional[float]] diff --git a/golem/managers/demand/refreshing.py b/golem/managers/demand/refreshing.py index c0155086..844eb0a9 100644 --- a/golem/managers/demand/refreshing.py +++ b/golem/managers/demand/refreshing.py @@ -1,6 +1,6 @@ import asyncio import logging -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import Awaitable, Callable, List, Tuple from golem.managers.base import DemandManager @@ -10,9 +10,8 @@ from golem.payload import defaults as payload_defaults from golem.resources import Allocation, Demand, Proposal from golem.resources.demand.demand_builder import DemandBuilder -from golem.utils.asyncio import create_task_with_logging +from golem.utils.asyncio import ErrorReportingQueue, create_task_with_logging from golem.utils.logging import get_trace_id_name, trace_span -from golem.utils.queue import ErrorReportingQueue logger = logging.getLogger(__name__) @@ -68,13 +67,10 @@ async def _wait_for_demand_to_expire(self): if not self._demands: return - remaining: timedelta = ( - datetime.utcfromtimestamp( - self._demands[-1][0].data.properties["golem.srv.comp.expiration"] / 1000 - ) - - datetime.utcnow() - ) - await asyncio.sleep(remaining.seconds) + await self._demands[-1][0].get_data() + expiration_date = self._demands[-1][0].get_expiration_date() + remaining = expiration_date - datetime.now(timezone.utc) + await asyncio.sleep(remaining.total_seconds()) @trace_span() async def _create_and_subscribe_demand(self): diff --git a/golem/managers/mixins.py b/golem/managers/mixins.py index bc1170a2..9bcba831 100644 --- a/golem/managers/mixins.py +++ b/golem/managers/mixins.py @@ -3,7 +3,7 @@ from typing import Generic, List, Optional, Sequence from golem.managers.base import ManagerException, TPlugin -from golem.utils.asyncio import create_task_with_logging +from golem.utils.asyncio import create_task_with_logging, ensure_cancelled from golem.utils.logging import get_trace_id_name, trace_span logger = logging.getLogger(__name__) @@ -29,7 +29,7 @@ async def stop(self) -> None: raise ManagerException("Already stopped!") if self._background_loop_task is not None: - self._background_loop_task.cancel() + await ensure_cancelled(self._background_loop_task) self._background_loop_task = None def is_started(self) -> bool: diff --git a/golem/managers/proposal/__init__.py b/golem/managers/proposal/__init__.py index 819b4014..fee2e36c 100644 --- a/golem/managers/proposal/__init__.py +++ b/golem/managers/proposal/__init__.py @@ -1,7 +1,6 @@ from golem.managers.proposal.default import DefaultProposalManager from golem.managers.proposal.plugins import ( BlacklistProviderIdPlugin, - Buffer, LinearAverageCostPricing, LinearCoeffsCost, LinearPerCpuAverageCostPricing, @@ -11,16 +10,17 @@ NegotiatingPlugin, PaymentPlatformNegotiator, PropertyValueLerpScore, + ProposalBuffer, + ProposalScoringBuffer, ProposalScoringMixin, RandomScore, RejectIfCostsExceeds, - ScoringBuffer, ) __all__ = ( "DefaultProposalManager", "BlacklistProviderIdPlugin", - "Buffer", + "ProposalBuffer", "PaymentPlatformNegotiator", "MidAgreementPaymentsNegotiator", "NegotiatingPlugin", @@ -33,5 +33,5 @@ "LinearPerCpuCoeffsCost", "PropertyValueLerpScore", "RandomScore", - "ScoringBuffer", + "ProposalScoringBuffer", ) diff --git a/golem/managers/proposal/plugins/__init__.py b/golem/managers/proposal/plugins/__init__.py index 5cad6098..aac10077 100644 --- a/golem/managers/proposal/plugins/__init__.py +++ b/golem/managers/proposal/plugins/__init__.py @@ -1,5 +1,5 @@ from golem.managers.proposal.plugins.blacklist import BlacklistProviderIdPlugin -from golem.managers.proposal.plugins.buffer import Buffer +from golem.managers.proposal.plugins.buffer import ProposalBuffer from golem.managers.proposal.plugins.linear_coeffs import LinearCoeffsCost, LinearPerCpuCoeffsCost from golem.managers.proposal.plugins.negotiating import ( MidAgreementPaymentsNegotiator, @@ -12,14 +12,14 @@ LinearPerCpuAverageCostPricing, MapScore, PropertyValueLerpScore, + ProposalScoringBuffer, ProposalScoringMixin, RandomScore, - ScoringBuffer, ) __all__ = ( "BlacklistProviderIdPlugin", - "Buffer", + "ProposalBuffer", "PaymentPlatformNegotiator", "MidAgreementPaymentsNegotiator", "NegotiatingPlugin", @@ -32,5 +32,5 @@ "LinearPerCpuCoeffsCost", "PropertyValueLerpScore", "RandomScore", - "ScoringBuffer", + "ProposalScoringBuffer", ) diff --git a/golem/managers/proposal/plugins/buffer.py b/golem/managers/proposal/plugins/buffer.py index f4ec6f6d..7d3d6dfc 100644 --- a/golem/managers/proposal/plugins/buffer.py +++ b/golem/managers/proposal/plugins/buffer.py @@ -1,130 +1,130 @@ -import asyncio import logging -from asyncio import Queue -from typing import List +from datetime import timedelta +from typing import Callable, Optional from golem.managers import ProposalManagerPlugin from golem.resources import Proposal -from golem.utils.asyncio import create_task_with_logging -from golem.utils.logging import get_trace_id_name, trace_span +from golem.utils.asyncio.buffer import BackgroundFillBuffer, Buffer, ExpirableBuffer, SimpleBuffer +from golem.utils.asyncio.tasks import resolve_maybe_awaitable +from golem.utils.logging import trace_span +from golem.utils.typing import MaybeAwaitable logger = logging.getLogger(__name__) -class Buffer(ProposalManagerPlugin): +class ProposalBuffer(ProposalManagerPlugin): def __init__( self, min_size: int, max_size: int, fill_concurrency_size: int = 1, fill_at_start=False, - ): + get_expiration_func: Optional[ + Callable[[Proposal], MaybeAwaitable[Optional[timedelta]]] + ] = None, + on_expiration_func: Optional[Callable[[Proposal], MaybeAwaitable[None]]] = None, + ) -> None: self._min_size = min_size self._max_size = max_size self._fill_concurrency_size = fill_concurrency_size self._fill_at_start = fill_at_start + self._get_expiration_func = get_expiration_func + self._on_expiration_func = on_expiration_func - self._get_proposal_lock = asyncio.Lock() - self._worker_tasks: List[asyncio.Task] = [] - # As there is no awaitable counter, queue with dummy values is used - self._requests_queue: Queue[int] = asyncio.Queue() - self._requests_pending_count = 0 - self._buffered: List[Proposal] = [] - self._buffered_condition = asyncio.Condition() - self._is_started = False + # TODO: Consider moving buffer composition from here to plugin level + buffer: Buffer[Proposal] = SimpleBuffer() - @trace_span(show_results=True) - async def get_proposal(self) -> Proposal: - if not self.is_started(): - raise RuntimeError("Not started!") - - async with self._get_proposal_lock: - return await self._get_item() - - @trace_span() - async def start(self) -> None: - if self.is_started(): - raise RuntimeError("Already started!") - - for i in range(self._fill_concurrency_size): - self._worker_tasks.append( - create_task_with_logging( - self._worker_loop(), trace_id=get_trace_id_name(self, f"worker-{i}") - ) + if self._get_expiration_func: + buffer = ExpirableBuffer( + buffer=buffer, + get_expiration_func=self._get_expiration_func, + on_expired_func=self._on_expired, ) - if self._fill_at_start: - self._handle_item_requests() - - self._is_started = True - - @trace_span() - async def stop(self) -> None: - if not self.is_started(): - raise RuntimeError("Already stopped!") - - for worker_task in self._worker_tasks: - worker_task.cancel() - - try: - await worker_task - except asyncio.CancelledError: - pass + self._buffer: BackgroundFillBuffer[Proposal] = BackgroundFillBuffer( + buffer=buffer, + fill_func=self._call_feed_func, + fill_concurrency_size=self._fill_concurrency_size, + on_added_func=self._on_added, + ) - self._worker_tasks.clear() - self._is_started = False + async def _call_feed_func(self) -> Proposal: + return await self._get_proposal() - self._requests_queue = asyncio.Queue() - self._requests_pending_count = 0 + async def _on_added(self, proposal: Proposal) -> None: + count_current = self._buffer.size() + count_with_requested = self._buffer.size_with_requested() + pending = count_with_requested - count_current - def is_started(self) -> bool: - return self._is_started + logger.debug( + "Proposal added, having %d proposals, and %d already requested, target %d", + count_current, + pending, + self._max_size, + ) - async def _worker_loop(self): - while True: - await self._wait_for_any_item_requests() + async def _on_expired(self, proposal: Proposal): + logger.debug("Rejecting expired `%r` and requesting fill", proposal) + await proposal.reject("Proposal no longer needed due to its near expiration.") + await self._request_proposals() - item = await self._get_proposal() - - async with self._buffered_condition: - self._buffered.append(item) - - self._buffered_condition.notify_all() - - self._requests_queue.task_done() - self._requests_pending_count -= 1 + if self._on_expiration_func: + await resolve_maybe_awaitable(self._on_expiration_func(proposal)) @trace_span() - async def _wait_for_any_item_requests(self) -> None: - await self._requests_queue.get() - - async def _get_item(self): - async with self._buffered_condition: - if self._get_items_count() == 0: # This supports lazy (not at start) buffer filling - logger.debug("No items to get, requesting fill") - self._handle_item_requests() + async def start(self) -> None: + await self._buffer.start() - logger.debug("Waiting for any item to pick...") + if self._fill_at_start: + await self._request_proposals() - await self._buffered_condition.wait_for(lambda: 0 < len(self._buffered)) - item = self._buffered.pop() + @trace_span() + async def stop(self) -> None: + await self._buffer.stop() - # Check if we need to request any additional items - if self._get_items_count() < self._min_size: - self._handle_item_requests() + async def _request_proposals(self) -> None: + count_current = self._buffer.size() + count_with_requested = self._buffer.size_with_requested() + requested = self._max_size - count_with_requested - return item + logger.debug( + "Proposal count %d and %d already requested, requesting additional %d to match" + " target %d", + count_current, + count_with_requested - count_current, + requested, + self._max_size, + ) - def _get_items_count(self): - return len(self._buffered) + self._requests_pending_count + await self._buffer.request(requested) - @trace_span() - def _handle_item_requests(self) -> None: - items_to_request = self._max_size - self._get_items_count() + @trace_span(show_results=True) + async def get_proposal(self) -> Proposal: + if not self._get_buffered_proposals_count(): + logger.debug("No proposals to get, requesting fill") + await self._request_proposals() + + proposal = await self._get_buffered_proposal() + + proposals_count = self._get_buffered_proposals_count() + if proposals_count < self._min_size: + logger.debug( + "Proposals count %d is below minimum size %d, requesting fill", + proposals_count, + self._min_size, + ) + await self._request_proposals() + else: + logger.debug( + "Proposals count %d is above minimum size %d, skipping fill", + proposals_count, + self._min_size, + ) - for i in range(items_to_request): - self._requests_queue.put_nowait(i) + return proposal - self._requests_pending_count += items_to_request + async def _get_buffered_proposal(self) -> Proposal: + return await self._buffer.get() - logger.debug("Requested %d items", items_to_request) + def _get_buffered_proposals_count(self) -> int: + return self._buffer.size() diff --git a/golem/managers/proposal/plugins/negotiating/negotiating_plugin.py b/golem/managers/proposal/plugins/negotiating/negotiating_plugin.py index 15edb600..199339dc 100644 --- a/golem/managers/proposal/plugins/negotiating/negotiating_plugin.py +++ b/golem/managers/proposal/plugins/negotiating/negotiating_plugin.py @@ -9,6 +9,7 @@ from golem.managers import ProposalManagerPlugin, RejectProposal from golem.managers.base import ProposalNegotiator from golem.resources import DemandData, Proposal +from golem.utils.asyncio.tasks import resolve_maybe_awaitable from golem.utils.logging import trace_span logger = logging.getLogger(__name__) @@ -116,10 +117,9 @@ async def _run_negotiators( proposal_data = await offer_proposal.get_proposal_data() for negotiator in self._proposal_negotiators: - negotiator_result = negotiator(demand_data_after_negotiators, proposal_data) - - if asyncio.iscoroutine(negotiator_result): - negotiator_result = await negotiator_result + negotiator_result = await resolve_maybe_awaitable( + negotiator(demand_data_after_negotiators, proposal_data) + ) if isinstance(negotiator_result, RejectProposal): raise negotiator_result diff --git a/golem/managers/proposal/plugins/scoring/__init__.py b/golem/managers/proposal/plugins/scoring/__init__.py index e8209595..35b374f2 100644 --- a/golem/managers/proposal/plugins/scoring/__init__.py +++ b/golem/managers/proposal/plugins/scoring/__init__.py @@ -6,7 +6,7 @@ ) from golem.managers.proposal.plugins.scoring.property_value_lerp import PropertyValueLerpScore from golem.managers.proposal.plugins.scoring.random import RandomScore -from golem.managers.proposal.plugins.scoring.scoring_buffer import ScoringBuffer +from golem.managers.proposal.plugins.scoring.scoring_buffer import ProposalScoringBuffer __all__ = ( "MapScore", @@ -15,5 +15,5 @@ "LinearPerCpuAverageCostPricing", "PropertyValueLerpScore", "RandomScore", - "ScoringBuffer", + "ProposalScoringBuffer", ) diff --git a/golem/managers/proposal/plugins/scoring/mixins.py b/golem/managers/proposal/plugins/scoring/mixins.py index 8bf9f9db..c314870a 100644 --- a/golem/managers/proposal/plugins/scoring/mixins.py +++ b/golem/managers/proposal/plugins/scoring/mixins.py @@ -1,10 +1,10 @@ -import inspect from datetime import datetime from typing import List, Optional, Sequence, Tuple, cast from golem.managers.base import ScorerWithOptionalWeight from golem.payload import PayloadSyntaxParser, Properties from golem.resources import Proposal, ProposalData +from golem.utils.asyncio.tasks import resolve_maybe_awaitable from golem.utils.logging import trace_span @@ -43,10 +43,7 @@ async def _run_scorers( else: weight = 1 - scorer_scores = scorer(proposals_data) - - if inspect.isawaitable(scorer_scores): - scorer_scores = await scorer_scores + scorer_scores = await resolve_maybe_awaitable(scorer(proposals_data)) proposal_scores.append((weight, scorer_scores)) # type: ignore[arg-type] diff --git a/golem/managers/proposal/plugins/scoring/scoring_buffer.py b/golem/managers/proposal/plugins/scoring/scoring_buffer.py index 850b6727..e5384386 100644 --- a/golem/managers/proposal/plugins/scoring/scoring_buffer.py +++ b/golem/managers/proposal/plugins/scoring/scoring_buffer.py @@ -1,32 +1,66 @@ import asyncio import logging from datetime import timedelta -from typing import List, Optional +from typing import Callable, Optional, Sequence -from golem.managers.proposal.plugins.buffer import Buffer -from golem.managers.proposal.plugins.scoring.mixins import ProposalScoringMixin +from golem.managers.base import ScorerWithOptionalWeight +from golem.managers.proposal.plugins.buffer import ProposalBuffer +from golem.managers.proposal.plugins.scoring import ProposalScoringMixin from golem.resources import Proposal -from golem.utils.asyncio import create_task_with_logging +from golem.utils.asyncio import ( + Buffer, + ExpirableBuffer, + SimpleBuffer, + create_task_with_logging, + ensure_cancelled, +) from golem.utils.logging import get_trace_id_name, trace_span +from golem.utils.typing import MaybeAwaitable logger = logging.getLogger(__name__) -class ScoringBuffer(ProposalScoringMixin, Buffer): +class ProposalScoringBuffer(ProposalScoringMixin, ProposalBuffer): def __init__( self, - update_interval: timedelta = timedelta(seconds=10), - *args, - **kwargs, - ): - super().__init__(*args, **kwargs) - - self._scored: List[Proposal] = [] - self._scored_condition = asyncio.Condition() - self._background_loop_task: Optional[asyncio.Task] = None + min_size: int, + max_size: int, + fill_concurrency_size=1, + fill_at_start=False, + get_expiration_func: Optional[ + Callable[[Proposal], MaybeAwaitable[Optional[timedelta]]] + ] = None, + on_expiration_func: Optional[Callable[[Proposal], MaybeAwaitable[None]]] = None, + scoring_debounce: timedelta = timedelta(seconds=10), + proposal_scorers: Optional[Sequence[ScorerWithOptionalWeight]] = None, + ) -> None: + super().__init__( + min_size=min_size, + max_size=max_size, + fill_concurrency_size=fill_concurrency_size, + fill_at_start=fill_at_start, + get_expiration_func=None, + on_expiration_func=on_expiration_func, + proposal_scorers=proposal_scorers, + ) + + self._scoring_debounce = scoring_debounce + + # Postponing argument would disable expiration from ProposalBuffer parent + # as we want to expire only scored proposals instead + self._get_expiration_func = get_expiration_func - self._items_requested_event = asyncio.Event() - self._update_interval = update_interval + scored_buffer: Buffer[Proposal] = SimpleBuffer() + + if get_expiration_func is not None: + scored_buffer = ExpirableBuffer( + buffer=scored_buffer, + get_expiration_func=get_expiration_func, + on_expired_func=self._on_expired, + ) + + self._buffer_scored: Buffer[Proposal] = scored_buffer + self._background_loop_task: Optional[asyncio.Task] = None @trace_span() async def start(self) -> None: @@ -41,89 +75,34 @@ async def stop(self) -> None: await super().stop() if self._background_loop_task is not None: - self._background_loop_task.cancel() + await ensure_cancelled(self._background_loop_task) self._background_loop_task = None - def is_started(self) -> bool: - is_started = super().is_started() - return ( - is_started - and self._background_loop_task is not None - and not self._background_loop_task.done() - ) - - async def _background_loop(self): - # TODO: Rework this flow to: - # - Use composition instead of inheritance - # - Get rid of loop checking state for proper state - # - composable Buffer should expose async context that exposes its items - # and items asyncio.Condition + async def _on_added(self, proposal: Proposal) -> None: + pass # explicit no-op + async def _background_loop(self) -> None: while True: - logger.debug("Waiting for any requested items...") - await self._items_requested_event.wait() - logger.debug("Waiting for any requested items done") - - keep_retrying = True - while keep_retrying: # FIXME: Needs refactor too - logger.debug("Waiting up to %s for all requested items...", self._update_interval) - try: - await asyncio.wait_for( - self._requests_queue.join(), - timeout=self._update_interval.total_seconds(), - ) - except asyncio.TimeoutError: - logger.debug( - "Waiting up to %s for all requested items failed with timeout, trying to" - "update anyways...", - self._update_interval, - ) - else: - logger.debug( - "Waiting up to %s for all requested items done", self._update_interval - ) - keep_retrying = False - self._items_requested_event.clear() - - async with self._buffered_condition: - if not self._buffered: - logger.debug("Update not needed, as no items were buffered in the meantime") - continue - - items_to_score = self._buffered[:] - self._buffered.clear() - - async with self._scored_condition: - scored_proposals = await self.do_scoring(self._scored + items_to_score) - self._scored = [proposal for _, proposal in scored_proposals] - - logger.debug("Item collection updated %s", scored_proposals) - - self._scored_condition.notify_all() - - async def _get_item(self): - async with self._scored_condition: - if self._get_items_count() == 0: # This supports lazy (not at start) buffer filling - logger.debug("No items to get, requesting fill") - self._handle_item_requests() - - logger.debug("Waiting for any item to pick...") - - await self._scored_condition.wait_for(lambda: 0 < len(self._scored)) - item = self._scored.pop() - - # Check if we need to request any additional items - if self._get_items_count() < self._min_size: - self._handle_item_requests() - - return item - - def _handle_item_requests(self) -> None: - super()._handle_item_requests() - - self._items_requested_event.set() - - def _get_items_count(self): - items_count = super()._get_items_count() - - return items_count + len(self._scored) + logger.debug( + "Waiting for any proposals to score with debounce of `%s`...", + self._scoring_debounce, + ) + proposals = await self._buffer.get_requested(self._scoring_debounce) + logger.debug( + "Waiting for any proposals done, %d new proposals will be scored", len(proposals) + ) + + proposals.extend(await self._buffer_scored.get_all()) + + logger.debug("Scoring total %d proposals...", len(proposals)) + + scored_proposals = await self.do_scoring(proposals) + await self._buffer_scored.put_all([proposal for _, proposal in scored_proposals]) + + logger.debug("Scoring total %d proposals done", len(proposals)) + + async def _get_buffered_proposal(self) -> Proposal: + return await self._buffer_scored.get() + + def _get_buffered_proposals_count(self) -> int: + return super()._get_buffered_proposals_count() + self._buffer_scored.size() diff --git a/golem/managers/work/plugins.py b/golem/managers/work/plugins.py index 5612b9bb..78a0eb8e 100644 --- a/golem/managers/work/plugins.py +++ b/golem/managers/work/plugins.py @@ -10,6 +10,7 @@ WorkManagerPlugin, WorkResult, ) +from golem.utils.asyncio import ensure_cancelled_many logger = logging.getLogger(__name__) @@ -70,14 +71,7 @@ async def wrapper(work: Work) -> WorkResult: tasks, return_when=asyncio.FIRST_COMPLETED ) - for task in tasks_pending: - task.cancel() - - for task in tasks_pending: - try: - await task - except asyncio.CancelledError: - pass + await ensure_cancelled_many(tasks_pending) return tasks_done.pop().result() diff --git a/golem/resources/demand/demand.py b/golem/resources/demand/demand.py index 2b9f4e27..34eef73a 100644 --- a/golem/resources/demand/demand.py +++ b/golem/resources/demand/demand.py @@ -1,5 +1,5 @@ import asyncio -from datetime import datetime +from datetime import datetime, timedelta from typing import TYPE_CHECKING, AsyncIterator, Callable, Dict, List, Optional, Union, cast from ya_market import RequestorApi @@ -15,6 +15,8 @@ if TYPE_CHECKING: from golem.node import GolemNode +DEFAULT_TTL = timedelta(hours=1) + class Demand(Resource[RequestorApi, models.Demand, _NULL, Proposal, _NULL], YagnaEventCollector): """A single demand on the Golem Network. @@ -139,3 +141,8 @@ async def get_demand_data(self) -> DemandData: ) return self._demand_data + + def get_expiration_date(self) -> datetime: + """Return expiration date to auto unsubscribe.""" + + return cast(datetime, self.data.timestamp) + DEFAULT_TTL diff --git a/golem/resources/proposal/proposal.py b/golem/resources/proposal/proposal.py index 488d0245..1f91e1f2 100644 --- a/golem/resources/proposal/proposal.py +++ b/golem/resources/proposal/proposal.py @@ -16,6 +16,8 @@ from golem.node import GolemNode from golem.resources.demand import Demand +DEFAULT_TTL = timedelta(hours=1) + class Proposal( Resource[ @@ -230,3 +232,14 @@ async def get_provider_name(self): node_info = NodeInfo.from_properties(proposal_data.properties) self._provider_node_name = node_info.name return self._provider_node_name + + def get_expiration_date(self) -> datetime: + """Return expiration date to auto unsubscribe. + + Note: As Proposal can have different expiration date than its Demand, it would be unusable + after demand expiration anyway, hence earliest from both dates is returned. + """ + demand_expiration_date = self.demand.get_expiration_date() + proposal_expiration_date = cast(datetime, self.data.timestamp) + DEFAULT_TTL + + return min(proposal_expiration_date, demand_expiration_date) diff --git a/golem/utils/asyncio.py b/golem/utils/asyncio.py deleted file mode 100644 index 78d5f2e2..00000000 --- a/golem/utils/asyncio.py +++ /dev/null @@ -1,40 +0,0 @@ -import asyncio -import contextvars -import logging -from typing import Optional - -from golem.utils.logging import trace_id_var - -logger = logging.getLogger(__name__) - - -def create_task_with_logging(coro, *, trace_id: Optional[str] = None) -> asyncio.Task: - context = contextvars.copy_context() - task = context.run(_create_task_with_logging, coro, trace_id=trace_id) - - if trace_id is not None: - task_name = trace_id - else: - task_name = task.get_name() - - logger.debug(f"Task `{task_name}` created") - - return task - - -def _create_task_with_logging(coro, *, trace_id: Optional[str] = None) -> asyncio.Task: - if trace_id is not None: - trace_id_var.set(trace_id) - - task = asyncio.create_task(coro) - task.add_done_callback(_handle_task_logging) - return task - - -def _handle_task_logging(task: asyncio.Task): - try: - return task.result() - except asyncio.CancelledError: - pass - except Exception: - logger.exception("Background async task encountered unhandled exception!") diff --git a/golem/utils/asyncio/__init__.py b/golem/utils/asyncio/__init__.py new file mode 100644 index 00000000..dbdef0df --- /dev/null +++ b/golem/utils/asyncio/__init__.py @@ -0,0 +1,29 @@ +from golem.utils.asyncio.buffer import ( + BackgroundFillBuffer, + Buffer, + ComposableBuffer, + ExpirableBuffer, + SimpleBuffer, +) +from golem.utils.asyncio.queue import ErrorReportingQueue +from golem.utils.asyncio.semaphore import SingleUseSemaphore +from golem.utils.asyncio.tasks import ( + create_task_with_logging, + ensure_cancelled, + ensure_cancelled_many, +) +from golem.utils.asyncio.waiter import Waiter + +__all__ = ( + "BackgroundFillBuffer", + "Buffer", + "ComposableBuffer", + "ExpirableBuffer", + "SimpleBuffer", + "ErrorReportingQueue", + "SingleUseSemaphore", + "ensure_cancelled", + "ensure_cancelled_many", + "create_task_with_logging", + "Waiter", +) diff --git a/golem/utils/asyncio/buffer.py b/golem/utils/asyncio/buffer.py new file mode 100644 index 00000000..c437db37 --- /dev/null +++ b/golem/utils/asyncio/buffer.py @@ -0,0 +1,402 @@ +import asyncio +import logging +from abc import ABC, abstractmethod +from collections import defaultdict +from contextlib import asynccontextmanager +from datetime import timedelta +from typing import ( + Awaitable, + Callable, + Dict, + Generic, + List, + MutableSequence, + Optional, + Sequence, + TypeVar, +) + +from golem.utils.asyncio.semaphore import SingleUseSemaphore +from golem.utils.asyncio.tasks import ( + create_task_with_logging, + ensure_cancelled_many, + resolve_maybe_awaitable, +) +from golem.utils.logging import get_trace_id_name +from golem.utils.typing import MaybeAwaitable + +TItem = TypeVar("TItem") + +logger = logging.getLogger(__name__) + + +class Buffer(ABC, Generic[TItem]): + """Interface class for object similar to `asyncio.Queue` but with more control over its \ + items.""" + + # TODO: Consider hiding condition in favor of public lock related context manager method + condition: asyncio.Condition + + @abstractmethod + def size(self) -> int: + """Return the number of items stored in the buffer.""" + + @abstractmethod + async def wait_for_any_items(self, *, lock=True) -> None: + """Wait until any items are stored in the buffer.""" + + @abstractmethod + async def get(self, *, lock=True) -> TItem: + """Await, remove and return left-most item stored in the buffer. + + If `.set_exception()` was previously called, the exception will be raised only if the buffer + is empty. + """ + + @abstractmethod + async def get_all(self, *, lock=True) -> MutableSequence[TItem]: + """Remove and immediately return all items stored in the buffer. + + Note that this method should not await any items if the buffer is empty and should instead + return an empty iterable. + + If `.set_exception()` was previously called, exception should be raised only if buffer + is empty. + """ + + @abstractmethod + async def put(self, item: TItem, *, lock=True) -> None: + """Add item to right-most position to buffer. + + Duplicates are supported. + """ + + @abstractmethod + async def put_all(self, items: Sequence[TItem], *, lock=True) -> None: + """Replace all items stored in buffer. + + Duplicates are supported. + """ + + @abstractmethod + async def remove(self, item: TItem, *, lock=True) -> None: + """Remove the first occurrence of the item from the buffer or raise `ValueError` \ + if not found.""" + + @abstractmethod + async def set_exception(self, exc: BaseException, *, lock=True) -> None: + """Set the exception that will be raised on `.get()`/`.get_all()` if there are no more \ + items in the buffer.""" + + def reset_exception(self) -> None: + """Reset (clear) the exception that was previously set by calling `.set_exception()`.""" + + +class ComposableBuffer(Buffer[TItem]): + """Utility class for composable/stackable buffer implementations to help with calling \ + the underlying buffer.""" + + def __init__(self, buffer: Buffer[TItem]): + self._buffer = buffer + + @asynccontextmanager + async def _handle_lock(self, lock: bool): + if lock: + async with self._buffer.condition: + yield + else: + yield + + def size(self) -> int: + return self._buffer.size() + + async def wait_for_any_items(self, *, lock=True) -> None: + await self._buffer.wait_for_any_items(lock=lock) + + async def get(self, *, lock=True) -> TItem: + return await self._buffer.get(lock=lock) + + async def get_all(self, *, lock=True) -> MutableSequence[TItem]: + return await self._buffer.get_all(lock=lock) + + async def put(self, item: TItem, *, lock=True) -> None: + await self._buffer.put(item, lock=lock) + + async def put_all(self, items: Sequence[TItem], *, lock=True) -> None: + await self._buffer.put_all(items, lock=lock) + + async def remove(self, item: TItem, *, lock=True) -> None: + await self._buffer.remove(item, lock=lock) + + async def set_exception(self, exc: BaseException, *, lock=True) -> None: + await self._buffer.set_exception(exc, lock=lock) + + def reset_exception(self) -> None: + self._buffer.reset_exception() + + +class SimpleBuffer(Buffer[TItem]): + """Basic implementation of the Buffer interface.""" + + def __init__(self, items: Optional[Sequence[TItem]] = None): + self._items = list(items) if items else [] + self._error: Optional[BaseException] = None + + self.condition = asyncio.Condition() + + def size(self) -> int: + return len(self._items) + + @asynccontextmanager + async def _handle_lock(self, lock: bool): + if lock: + async with self.condition: + yield + else: + yield + + async def wait_for_any_items(self, lock=True) -> None: + async with self._handle_lock(lock): + await self.condition.wait_for(lambda: bool(self.size() or self._error)) + + async def get(self, *, lock=True) -> TItem: + async with self._handle_lock(lock): + await self.wait_for_any_items(lock=False) + + if not self.size() and self._error: + raise self._error + + return self._items.pop(0) + + async def get_all(self, *, lock=True) -> MutableSequence[TItem]: + async with self._handle_lock(lock): + if not self._items and self._error: + raise self._error + + items = self._items[:] + self._items.clear() + + return items + + async def put(self, item: TItem, *, lock=True) -> None: + async with self._handle_lock(lock): + self._items.append(item) + self.condition.notify() + + async def put_all(self, items: Sequence[TItem], *, lock=True) -> None: + async with self._handle_lock(lock): + self._items.clear() + self._items.extend(items[:]) + + self.condition.notify(len(items)) + + async def remove(self, item: TItem, *, lock=True) -> None: + async with self._handle_lock(lock): + self._items.remove(item) + + async def set_exception(self, exc: BaseException, *, lock=True) -> None: + async with self._handle_lock(lock): + self._error = exc + self.condition.notify() + + def reset_exception(self) -> None: + self._error = None + + +class ExpirableBuffer(ComposableBuffer[TItem]): + """Composable `Buffer` that adds option to expire item after some time. + + Items that are already in provided buffer will not expire. + """ + + # TODO: Optimisation options: Use single expiration task that wakes up to expire the earliest + # item, then check next earliest item and sleep to it and repeat + + def __init__( + self, + buffer: Buffer[TItem], + get_expiration_func: Callable[[TItem], MaybeAwaitable[Optional[timedelta]]], + on_expired_func: Optional[Callable[[TItem], MaybeAwaitable[None]]] = None, + ): + super().__init__(buffer) + + self._get_expiration_func = get_expiration_func + self._on_expired_func = on_expired_func + + # TODO: Could this collection be liable to race conditions? + self._expiration_handlers: Dict[int, List[asyncio.TimerHandle]] = defaultdict(list) + + async def _add_expiration_task_for_item(self, item: TItem) -> None: + expiration = await resolve_maybe_awaitable(self._get_expiration_func(item)) + + if expiration is None: + return + + loop = asyncio.get_event_loop() + + self._expiration_handlers[id(item)].append( + loop.call_later( + expiration.total_seconds(), + lambda: create_task_with_logging( + self._expire_item(item), trace_id=get_trace_id_name(self, f"item-expire-{item}") + ), + ) + ) + + def _remove_expiration_handler_for_item(self, item: TItem) -> None: + item_id = id(item) + + if item_id not in self._expiration_handlers or not len(self._expiration_handlers[item_id]): + return + + expiration_handle = self._expiration_handlers[item_id].pop(0) + expiration_handle.cancel() + + if not self._expiration_handlers[item_id]: + del self._expiration_handlers[item_id] + + def _remove_all_expiration_handlers(self) -> None: + for handlers in self._expiration_handlers.values(): + for handler in handlers: + handler.cancel() + + self._expiration_handlers.clear() + + async def get(self, *, lock=True) -> TItem: + async with self._handle_lock(lock): + item = await super().get(lock=False) + + self._remove_expiration_handler_for_item(item) + + return item + + async def get_all(self, *, lock=True) -> MutableSequence[TItem]: + async with self._handle_lock(lock): + items = await super().get_all(lock=False) + + self._remove_all_expiration_handlers() + + return items + + async def put(self, item: TItem, *, lock=True) -> None: + async with self._handle_lock(lock): + await super().put(item, lock=False) + + await self._add_expiration_task_for_item(item) + + async def put_all(self, items: Sequence[TItem], *, lock=True) -> None: + async with self._handle_lock(lock): + await super().put_all(items, lock=False) + + self._remove_all_expiration_handlers() + + await asyncio.gather( + *[self._add_expiration_task_for_item(item) for item in items], + return_exceptions=True, + ) + + async def remove(self, item: TItem, *, lock=True) -> None: + async with self._handle_lock(lock): + await super().remove(item, lock=False) + + self._remove_expiration_handler_for_item(item) + + async def _expire_item(self, item: TItem) -> None: + await self.remove(item) + + if self._on_expired_func: + await resolve_maybe_awaitable(self._on_expired_func(item)) + + +class BackgroundFillBuffer(ComposableBuffer[TItem]): + """Composable `Buffer` that adds option to fill buffer in background tasks. + + Background fill will happen only if background tasks are started by calling `.start()` + and items were requested by `.request()`. + """ + + def __init__( + self, + buffer: Buffer[TItem], + fill_func: Callable[[], Awaitable[TItem]], + fill_concurrency_size=1, + on_added_func: Optional[Callable[[TItem], Awaitable[None]]] = None, + ): + super().__init__(buffer) + + self._fill_func = fill_func + self._fill_concurrency_size = fill_concurrency_size + self._on_added_func = on_added_func + + self._is_started = False + self._worker_tasks: List[asyncio.Task] = [] + self._workers_semaphore = SingleUseSemaphore() + + async def start(self) -> None: + if self.is_started(): + raise RuntimeError("Already started!") + + for i in range(self._fill_concurrency_size): + self._worker_tasks.append( + create_task_with_logging( + self._worker_loop(), trace_id=get_trace_id_name(self, f"worker-{i}") + ) + ) + + self._is_started = True + + async def stop(self) -> None: + if not self.is_started(): + raise RuntimeError("Already stopped!") + + await ensure_cancelled_many(self._worker_tasks) + self._worker_tasks.clear() + self._is_started = False + + self._workers_semaphore.reset() + + def is_started(self) -> bool: + return self._is_started + + async def _worker_loop(self) -> None: + while True: + logger.debug("Waiting for fill item request...") + + async with self._workers_semaphore: + logger.debug("Waiting for fill item request done") + + logger.debug("Adding new item...") + + item = await self._fill_func() + + await self.put(item) + + if self._on_added_func: + await self._on_added_func(item) + + logger.debug("Adding new item done with total of %d items in buffer", self.size()) + + async def request(self, count: int) -> None: + """Request given number of items to be filled in background.""" + + await self._workers_semaphore.increase(count) + + def size_with_requested(self) -> int: + """Return sum of item count stored in buffer and requested to be filled.""" + + return self.size() + self._workers_semaphore.get_count_with_pending() + + async def get_requested(self, debounce: timedelta) -> MutableSequence[TItem]: + """Await for any requested items with given debounce time window, then remove and return \ + all items stored in buffer.""" + + await self._buffer.wait_for_any_items() + + try: + await asyncio.wait_for( + self._workers_semaphore.finished.wait(), debounce.total_seconds() + ) + except asyncio.TimeoutError: + pass + + return await self.get_all() diff --git a/golem/utils/queue.py b/golem/utils/asyncio/queue.py similarity index 76% rename from golem/utils/queue.py rename to golem/utils/asyncio/queue.py index 6c15d0d8..a2df2715 100644 --- a/golem/utils/queue.py +++ b/golem/utils/asyncio/queue.py @@ -1,21 +1,21 @@ import asyncio from typing import Generic, Optional, TypeVar -QueueItem = TypeVar("QueueItem") +from golem.utils.asyncio.tasks import ensure_cancelled_many +TQueueItem = TypeVar("TQueueItem") -class ErrorReportingQueue(asyncio.Queue, Generic[QueueItem]): - """Asyncio Queue that enables exceptions to be passed to consumers from the feeding code.""" - _error: Optional[BaseException] - _error_event: asyncio.Event +class ErrorReportingQueue(asyncio.Queue, Generic[TQueueItem]): + """Asyncio Queue that enables exceptions to be passed to consumers from the feeding code.""" def __init__(self, *args, **kwargs): - self._error = None - self._error_event = asyncio.Event() + self._error: Optional[BaseException] = None + self._error_event: asyncio.Event = asyncio.Event() + super().__init__(*args, **kwargs) - def get_nowait(self) -> QueueItem: + def get_nowait(self) -> TQueueItem: """Perform a regular `get_nowait` if there are items in the queue. Otherwise, if an exception had been signalled, raise the exception. @@ -24,7 +24,7 @@ def get_nowait(self) -> QueueItem: raise self._error return super().get_nowait() - async def get(self) -> QueueItem: + async def get(self) -> TQueueItem: """Perform a regular, waiting `get` but raise an exception if happens while waiting. If there had been items in the queue, @@ -37,7 +37,7 @@ async def get(self) -> QueueItem: [error_task, get_task], return_when=asyncio.FIRST_COMPLETED ) - [t.cancel() for t in pending] + await ensure_cancelled_many(pending) if get_task in done: return await get_task @@ -45,10 +45,10 @@ async def get(self) -> QueueItem: assert self._error raise self._error - async def put(self, item: QueueItem): + async def put(self, item: TQueueItem): await super().put(item) - def put_nowait(self, item: QueueItem): + def put_nowait(self, item: TQueueItem): super().put_nowait(item) def set_exception(self, exc: BaseException): diff --git a/golem/utils/asyncio/semaphore.py b/golem/utils/asyncio/semaphore.py new file mode 100644 index 00000000..8557d984 --- /dev/null +++ b/golem/utils/asyncio/semaphore.py @@ -0,0 +1,79 @@ +import asyncio + + +class SingleUseSemaphore: + """Class similar to `asyncio.Semaphore` but with more limited count of `.acquire()` calls and\ + exposed counters.""" + + def __init__(self, value=0): + if value < 0: + raise ValueError("Initial value must be greater or equal to zero!") + + self._value = value + + self._pending = 0 + self._condition = asyncio.Condition() + + self.finished = asyncio.Event() + if self.locked(): + self.finished.set() + + async def __aenter__(self): + await self.acquire() + + async def __aexit__(self, exc_type, exc, tb): + self.release() + + def locked(self) -> bool: + """Return True if there are no more "charges" left in semaphore.""" + + return not self._value + + async def acquire(self) -> None: + """Decrease "charges" counter and increase pending count, or await until there any\ + "charges".""" + + async with self._condition: + await self._condition.wait_for(lambda: self._value) + + self._value -= 1 + self._pending += 1 + + def release(self) -> None: + """Decrease pending count.""" + + if self._pending <= 0: + raise RuntimeError("Release called too many times!") + + self._pending -= 1 + + if self.locked(): + self.finished.set() + + async def increase(self, value: int) -> None: + """Add given "charges" amount.""" + + async with self._condition: + self._value += value + self.finished.clear() + self._condition.notify(value) + + def get_count(self) -> int: + """Return "charges" count.""" + + return self._value + + def get_count_with_pending(self) -> int: + """Return sum of "charges" and pending count.""" + + return self.get_count() + self.get_pending_count() + + def get_pending_count(self) -> int: + """Return pending count.""" + + return self._pending + + def reset(self) -> None: + """Reset "charges" amount to zero.""" + + self._value = 0 diff --git a/golem/utils/asyncio/tasks.py b/golem/utils/asyncio/tasks.py new file mode 100644 index 00000000..06496b0a --- /dev/null +++ b/golem/utils/asyncio/tasks.py @@ -0,0 +1,75 @@ +import asyncio +import contextvars +import inspect +import logging +from typing import Iterable, Optional, TypeVar, cast + +from golem.utils.logging import trace_id_var +from golem.utils.typing import MaybeAwaitable + +T = TypeVar("T") + +logger = logging.getLogger(__name__) + + +def create_task_with_logging(coro, *, trace_id: Optional[str] = None) -> asyncio.Task: + context = contextvars.copy_context() + task = context.run(_create_task_with_logging, coro, trace_id=trace_id) + + if trace_id is not None: + task_name = trace_id + else: + task_name = task.get_name() + + logger.debug("Task `%s` created", task_name) + + return task + + +def _create_task_with_logging(coro, *, trace_id: Optional[str] = None) -> asyncio.Task: + if trace_id is not None: + trace_id_var.set(trace_id) + + task = asyncio.create_task(coro) + task.add_done_callback(_handle_task_logging) + return task + + +def _handle_task_logging(task: asyncio.Task): + try: + return task.result() + except asyncio.CancelledError: + pass + except Exception: + logger.exception("Background async task encountered unhandled exception!") + + +async def ensure_cancelled(task: asyncio.Task) -> None: + """Cancel given task and await for its cancellation.""" + + if task.done(): + return + + task.cancel() + + try: + await task + except asyncio.CancelledError: + pass + + +async def ensure_cancelled_many(tasks: Iterable[asyncio.Task]) -> None: + """Cancel given tasks and concurrently await for their cancellation.""" + + await asyncio.gather(*[ensure_cancelled(task) for task in tasks]) + + +async def resolve_maybe_awaitable(value: MaybeAwaitable[T]) -> T: + """Return given value or await for it results if value is awaitable.""" + + if inspect.isawaitable(value): + return await value + + # TODO: remove cast as inspect.isawaitable can't tell mypy that at this + # point value is T + return cast(T, value) diff --git a/golem/utils/asyncio/waiter.py b/golem/utils/asyncio/waiter.py new file mode 100644 index 00000000..c0b7dfc3 --- /dev/null +++ b/golem/utils/asyncio/waiter.py @@ -0,0 +1,58 @@ +import asyncio +import collections +from typing import Callable, Deque + + +class Waiter: + """Class similar to `asyncio.Event` but valueless and with notify interface similar to \ + `asyncio.Condition`. + + Note: Developed to support `golem.utils.asyncio.buffer`, but finally not used. + """ + + def __init__(self) -> None: + self._waiters: Deque[asyncio.Future] = collections.deque() + self._loop = asyncio.get_event_loop() + + async def wait_for(self, predicate: Callable[[], bool]) -> None: + """Check if predicate is true and return immediately, or await until it becomes true.""" + + result = predicate() + + while not result: + await self._wait() + result = predicate() + + if not result: + # as last `._wait()` call woken us up but predicate is still false, lets give a + # chance another `.wait_for()` pending call. + self._notify_first() + + def _notify_first(self) -> None: + try: + first_waiter = self._waiters[0] + except IndexError: + return + + if not first_waiter.done(): + first_waiter.set_result(None) + + async def _wait(self) -> None: + future = self._loop.create_future() + self._waiters.append(future) + try: + await future + finally: + self._waiters.remove(future) + + def notify(self, count=1) -> None: + """Notify given amount of `.wait_for()` calls to check its predicates.""" + + notified = 0 + for waiter in self._waiters: + if count <= notified: + break + + if not waiter.done(): + waiter.set_result(None) + notified += 1 diff --git a/golem/utils/logging.py b/golem/utils/logging.py index 3f57cab7..11f3295e 100644 --- a/golem/utils/logging.py +++ b/golem/utils/logging.py @@ -1,6 +1,7 @@ import contextvars import inspect import logging +import sys from datetime import datetime, timezone from functools import wraps from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, Union @@ -9,6 +10,12 @@ from golem.event_bus import Event +if (3, 12) <= sys.version_info: + FilterReturn = Union[bool, logging.LogRecord] +else: + FilterReturn = bool + + DEFAULT_LOGGING = { "version": 1, "disable_existing_loggers": False, @@ -241,7 +248,7 @@ async def _async_wrapper(self, func: Callable, args: Sequence, kwargs: Dict) -> class AddTraceIdFilter(logging.Filter): - def filter(self, record: logging.LogRecord) -> bool: + def filter(self, record: logging.LogRecord) -> FilterReturn: record.traceid = trace_id_var.get() return super().filter(record) diff --git a/golem/utils/typing.py b/golem/utils/typing.py index 76b774dc..7877ba15 100644 --- a/golem/utils/typing.py +++ b/golem/utils/typing.py @@ -1,4 +1,8 @@ -from typing import Any, Callable, Optional, Type, Union, get_args, get_origin +from typing import Any, Awaitable, Callable, Optional, Type, TypeVar, Union, get_args, get_origin + +T = TypeVar("T") + +MaybeAwaitable = Union[Awaitable[T], T] def match_type_union_aware(obj_type: Type, match_func: Callable[[Type], bool]) -> Optional[Any]: diff --git a/pyproject.toml b/pyproject.toml index 18f9d79d..43826a08 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ requires = ["poetry_core>=1.0.0"] build-backend = "poetry.core.masonry.api" [tool.poetry.dependencies] -python = "^3.8.0" +python = "^3.8.1" # CLI prettytable = "^3.4.1" @@ -45,7 +45,7 @@ isort = "^5" black = "^23" mypy = "^1" -flake8 = "^5" +flake8 = "^7" flake8-docstrings = "^1" Flake8-pyproject = "^1" @@ -122,6 +122,7 @@ extend-ignore = [ "D104", # No docs for public package "D107", # No docs for __init__ "D202", # We prefer whitelines after docstrings + "W604", # Backticks complain on Python 3.12+ ] [tool.mypy] diff --git a/tests/unit/utils/test_buffer.py b/tests/unit/utils/test_buffer.py new file mode 100644 index 00000000..0fd71174 --- /dev/null +++ b/tests/unit/utils/test_buffer.py @@ -0,0 +1,477 @@ +import asyncio +from datetime import timedelta + +import pytest + +from golem.utils.asyncio.buffer import BackgroundFillBuffer, Buffer, ExpirableBuffer, SimpleBuffer + + +@pytest.fixture +def mocked_buffer(mocker): + mock = mocker.Mock(spec=Buffer) + + mock.condition = mocker.AsyncMock() + + return mock + + +def test_simple_buffer_creation(): + buffer: Buffer[str] = SimpleBuffer() + assert buffer.size() == 0 + + buffer = SimpleBuffer(["a", "b", "c"]) + assert buffer.size() == 3 + + +async def test_simple_buffer_put_get(): + buffer: Buffer[object] = SimpleBuffer() + assert buffer.size() == 0 + + item_put = object() + + await buffer.put(item_put) + + assert buffer.size() == 1 + + item_get = await buffer.get() + + assert buffer.size() == 0 + + assert item_put == item_get + + +async def test_simple_buffer_put_all_get_all(): + buffer = SimpleBuffer(["a"]) + assert buffer.size() == 1 + + await buffer.put_all(["b", "c", "d"]) + assert buffer.size() == 3 + + assert await buffer.get_all() == ["b", "c", "d"] + + assert buffer.size() == 0 + + assert await buffer.get_all() == [] + + +async def test_simple_buffer_remove(): + buffer = SimpleBuffer(["a", "b", "c"]) + + await buffer.remove("b") + + assert await buffer.get_all() == ["a", "c"] + + +async def test_simple_buffer_get_waits_for_items(): + buffer: Buffer[object] = SimpleBuffer() + assert buffer.size() == 0 + + _, pending = await asyncio.wait([asyncio.create_task(buffer.get())], timeout=0.1) + if not pending: + pytest.fail("Getting empty buffer somehow finished instead of blocking!") + + get_task = pending.pop() + + item_put = object() + await buffer.put(item_put) + + item_get = await asyncio.wait_for(get_task, timeout=0.1) + + assert item_get == item_put + + # concurrent wait + get_task1 = asyncio.create_task(buffer.get()) + get_task2 = asyncio.create_task(buffer.get()) + + await asyncio.sleep(0.1) + + await buffer.put(item_put) + + done, pending = await asyncio.wait([get_task1, get_task2], timeout=0.1) + if len(done) != len(pending): + pytest.fail("One of the tasks should not block at this point!") + + await buffer.put(item_put) + + await asyncio.sleep(0.1) + + await asyncio.wait_for(pending.pop(), timeout=0.1) + + +async def test_simple_buffer_keeps_item_order(): + buffer = SimpleBuffer(["a", "b", "c"]) + + assert await buffer.get() == "a" + assert await buffer.get() == "b" + assert await buffer.get() == "c" + + await buffer.put("d") + await buffer.put("e") + await buffer.put("f") + + assert await buffer.get() == "d" + assert await buffer.get_all() == ["e", "f"] + + +async def test_simple_buffer_keeps_shallow_copy_of_items(): + initial_items = ["a", "b", "c"] + buffer = SimpleBuffer(initial_items) + assert buffer.size() == 3 + + initial_items.extend(["d", "e", "f"]) + assert buffer.size() == 3 + + put_items = ["g", "h", "i"] + await buffer.put_all(put_items) + assert buffer.size() == 3 + + put_items.extend(["j", "k", "l"]) + assert buffer.size() == 3 + + +async def test_simple_buffer_exceptions(): + buffer: Buffer[str] = SimpleBuffer() + assert buffer.size() == 0 + + exc = ZeroDivisionError() + + await buffer.set_exception(exc) + + # should raise when exception set and no items + with pytest.raises(ZeroDivisionError): + await buffer.get() + + with pytest.raises(ZeroDivisionError): + await buffer.get_all() + + await buffer.put("a") + + # should not raise when exception set and with items + assert await buffer.get_all() == ["a"] + + # should raise when exception set and items were cleared + with pytest.raises(ZeroDivisionError): + await buffer.get_all() + + buffer.reset_exception() + + assert await buffer.get_all() == [] + + try: + await asyncio.wait_for(buffer.get(), timeout=0.1) + except asyncio.TimeoutError: + pass + else: + pytest.fail("Getting empty buffer somehow finished instead of blocking!") + + get_task = asyncio.create_task(buffer.get()) + + await asyncio.sleep(0.1) + + await buffer.set_exception(exc) + + await asyncio.sleep(0.1) + + with pytest.raises(ZeroDivisionError): + get_task.result() + + +async def test_simple_buffer_wait_for_any_items(): + buffer: Buffer[str] = SimpleBuffer() + assert buffer.size() == 0 + + # should block on empty + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(buffer.wait_for_any_items(), timeout=0.1) + + # should unblock on item + wait_task = asyncio.create_task(buffer.wait_for_any_items()) + + await asyncio.sleep(0.1) + + assert not wait_task.done() + + await buffer.put("a") + + await asyncio.sleep(0.1) + + assert wait_task.done() + + # should not block a long time on item + await asyncio.wait_for(buffer.wait_for_any_items(), timeout=0.1) + + await buffer.set_exception(ZeroDivisionError()) + + # should not block a long time on item with exception + await asyncio.wait_for(buffer.wait_for_any_items(), timeout=0.1) + + await buffer.get() + + # should not block a long time with exception and no items + await asyncio.wait_for(buffer.wait_for_any_items(), timeout=0.1) + + buffer.reset_exception() + + # should block on after exception reset and no items + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(buffer.wait_for_any_items(), timeout=0.1) + + # should unblock on item + wait_task = asyncio.create_task(buffer.wait_for_any_items()) + + await asyncio.sleep(0.1) + + assert not wait_task.done() + + await buffer.set_exception(ZeroDivisionError()) + + await asyncio.sleep(0.1) + + assert wait_task.done() + + +async def test_expirable_buffer_is_not_expiring_initial_items(mocked_buffer): + expire_after = timedelta(seconds=0.1) + ExpirableBuffer( + mocked_buffer, + lambda i: expire_after, + ) + + await asyncio.sleep(0.2) + + mocked_buffer.remove.assert_not_called() + + +async def test_expirable_buffer_is_not_expiring_items_with_none_expiration(mocked_buffer, mocker): + expiration_func = mocker.Mock( + side_effect=[timedelta(seconds=0.1), None, timedelta(seconds=0.1)] + ) + buffer = ExpirableBuffer( + mocked_buffer, + expiration_func, + ) + + await buffer.put("a") + await buffer.put("b") + await buffer.put("c") + + await asyncio.sleep(0.2) + + assert mocker.call("a", lock=False) in mocked_buffer.remove.mock_calls + assert mocker.call("c", lock=False) in mocked_buffer.remove.mock_calls + + mocked_buffer.get.return_value = "b" + + assert await buffer.get() == "b" + + +async def test_expirable_buffer_can_expire_items_with_put_get(mocked_buffer, mocker): + expire_after = timedelta(seconds=0.1) + on_expire = mocker.AsyncMock() + buffer = ExpirableBuffer( + mocked_buffer, + lambda i: expire_after, + on_expire, + ) + item_put = object() + + await buffer.put(item_put) + mocked_buffer.put.assert_called_with(item_put, lock=False) + + mocked_buffer.get.return_value = item_put + await buffer.get() + mocked_buffer.get.assert_called() + + mocked_buffer.remove.assert_not_called() + on_expire.assert_not_called() + + await asyncio.sleep(0.2) + + mocked_buffer.remove.assert_not_called() + on_expire.assert_not_called() + + mocked_buffer.reset_mock() + + await buffer.put(item_put) + mocked_buffer.put.assert_called_with(item_put, lock=False) + + await asyncio.sleep(0.2) + + mocked_buffer.remove.assert_called_with(item_put, lock=False) + on_expire.assert_called_with(item_put) + + +async def test_expirable_buffer_can_expire_items_with_put_all_get_all(mocked_buffer, mocker): + expire_after = timedelta(seconds=0.1) + on_expire = mocker.AsyncMock() + buffer = ExpirableBuffer( + mocked_buffer, + lambda i: expire_after, + on_expire, + ) + items_put_all = ["a", "b", "c"] + + await buffer.put_all(items_put_all) + mocked_buffer.put_all.assert_called_with(items_put_all, lock=False) + + mocked_buffer.get_all.return_value = items_put_all + await buffer.get_all() + mocked_buffer.get_all.assert_called_with(lock=False) + + with pytest.raises(AssertionError): + mocked_buffer.remove.assert_called_with(lock=False) + + on_expire.assert_not_called() + + await asyncio.sleep(0.2) + + on_expire.assert_not_called() + + mocked_buffer.reset_mock() + + await buffer.put_all(items_put_all) + mocked_buffer.put_all.assert_called_with(items_put_all, lock=False) + + await asyncio.sleep(0.2) + + assert mocker.call(items_put_all[0], lock=False) in mocked_buffer.remove.mock_calls + assert mocker.call(items_put_all[1], lock=False) in mocked_buffer.remove.mock_calls + assert mocker.call(items_put_all[2], lock=False) in mocked_buffer.remove.mock_calls + + assert mocker.call(items_put_all[0]) in on_expire.mock_calls + assert mocker.call(items_put_all[1]) in on_expire.mock_calls + assert mocker.call(items_put_all[2]) in on_expire.mock_calls + + +async def test_background_feed_buffer_start_stop(mocked_buffer, mocker): + feed_func = mocker.AsyncMock() + buffer = BackgroundFillBuffer( + mocked_buffer, + feed_func, + ) + + assert not buffer.is_started() + + await buffer.start() + + assert buffer.is_started() + + with pytest.raises(RuntimeError): + await buffer.start() + + assert buffer.is_started() + mocked_buffer.size.return_value = 0 + assert buffer.size_with_requested() == 0 + + await buffer.stop() + + assert not buffer.is_started() + + with pytest.raises(RuntimeError): + await buffer.stop() + + assert not buffer.is_started() + + feed_func.assert_not_called() + + +async def test_background_feed_buffer_request(mocked_buffer, mocker): + item = object() + feed_queue: asyncio.Queue[object] = asyncio.Queue() + feed_func = mocker.AsyncMock(wraps=feed_queue.get) + buffer = BackgroundFillBuffer( + mocked_buffer, + feed_func, + ) + await buffer.start() + + await buffer.request(1) + + await asyncio.sleep(0.1) + + feed_func.assert_called() + mocked_buffer.size.return_value = 0 + assert buffer.size() == 0 + assert buffer.size_with_requested() == 1 + + await feed_queue.put(item) + + await asyncio.sleep(0.1) + + mocked_buffer.put.assert_called_with(item, lock=True) + mocked_buffer.size.return_value = 1 + assert buffer.size() == 1 + assert buffer.size_with_requested() == 1 + + await buffer.stop() + + +async def test_background_feed_buffer_get_all_requested(mocked_buffer, mocker, event_loop): + timeout = timedelta(seconds=0.1) + item = object() + feed_queue: asyncio.Queue[object] = asyncio.Queue() + feed_func = mocker.AsyncMock(wraps=feed_queue.get) + buffer = BackgroundFillBuffer( + mocked_buffer, + feed_func, + ) + await buffer.start() + + # Try to get items while buffer is empty and with no requests + mocked_buffer.get_all.return_value = [] + mocked_buffer.size.return_value = 0 + + time_before_wait = event_loop.time() + assert await buffer.get_requested(timeout) == [] + time_after_wait = event_loop.time() + + assert ( + time_after_wait - time_before_wait < timeout.total_seconds() + ), "get_requested seems to wait for the deadline instead of retuning fast" + # + + await buffer.request(1) + + await asyncio.sleep(0.1) + + feed_func.assert_called() + mocked_buffer.size.return_value = 0 + assert buffer.size() == 0 + assert buffer.size_with_requested() == 1 + + # Try to get items while buffer is empty but with pending requests + mocked_buffer.get_all.return_value = [] + mocked_buffer.size.return_value = 0 + + time_before_wait = event_loop.time() + assert await buffer.get_requested(timeout) == [] + time_after_wait = event_loop.time() + + assert ( + timeout.total_seconds() <= time_after_wait - time_before_wait + ), "get_requested seems to not wait to the deadline" + # + + await feed_queue.put(item) + + await asyncio.sleep(0.1) + + mocked_buffer.put.assert_called_with(item, lock=True) + mocked_buffer.size.return_value = 1 + assert buffer.size() == 1 + assert buffer.size_with_requested() == 1 + + # Try to get items while buffer is have items and no pending requests + mocked_buffer.get_all.return_value = [item] + mocked_buffer.size.return_value = 0 + + time_before_wait = event_loop.time() + assert await buffer.get_requested(timeout) == [item] + time_after_wait = event_loop.time() + + assert ( + time_after_wait - time_before_wait < timeout.total_seconds() + ), "get_requested seems to wait for the deadline instead of retuning fast" + # + + await buffer.stop() diff --git a/tests/unit/utils/test_error_reporting_queue.py b/tests/unit/utils/test_error_reporting_queue.py index 2dba6c7e..318782df 100644 --- a/tests/unit/utils/test_error_reporting_queue.py +++ b/tests/unit/utils/test_error_reporting_queue.py @@ -2,7 +2,7 @@ import pytest -from golem.utils.queue import ErrorReportingQueue +from golem.utils.asyncio import ErrorReportingQueue class SomeException(Exception): diff --git a/tests/unit/utils/test_semaphore.py b/tests/unit/utils/test_semaphore.py new file mode 100644 index 00000000..ca986127 --- /dev/null +++ b/tests/unit/utils/test_semaphore.py @@ -0,0 +1,187 @@ +import asyncio + +import pytest + +from golem.utils.asyncio.semaphore import SingleUseSemaphore + + +async def test_creation(): + sem = SingleUseSemaphore() + + assert sem.get_count() == 0 + assert sem.get_count_with_pending() == 0 + assert sem.get_pending_count() == 0 + assert sem.finished.is_set() + assert sem.locked() + + sem = SingleUseSemaphore(0) + + assert sem.get_count() == 0 + assert sem.get_count_with_pending() == 0 + assert sem.get_pending_count() == 0 + assert sem.finished.is_set() + assert sem.locked() + + sem_count = 10 + sem = SingleUseSemaphore(sem_count) + + assert sem.get_count() == sem_count + assert sem.get_count_with_pending() == sem_count + assert sem.get_pending_count() == 0 + assert not sem.finished.is_set() + assert not sem.locked() + + with pytest.raises(ValueError): + SingleUseSemaphore(-10) + + +async def test_increase(): + sem = SingleUseSemaphore() + + assert sem.get_count() == 0 + assert sem.get_count_with_pending() == 0 + assert sem.get_pending_count() == 0 + assert sem.finished.is_set() + assert sem.locked() + + await sem.increase(1) + + assert sem.get_count() == 1 + assert sem.get_count_with_pending() == 1 + assert sem.get_pending_count() == 0 + assert not sem.finished.is_set() + assert not sem.locked() + + await sem.increase(10) + + assert sem.get_count() == 11 + assert sem.get_count_with_pending() == 11 + assert sem.get_pending_count() == 0 + assert not sem.finished.is_set() + assert not sem.locked() + + +async def test_reset(): + sem_value = 10 + sem = SingleUseSemaphore(sem_value) + + assert sem.get_count() == sem_value + assert sem.get_count_with_pending() == sem_value + assert sem.get_pending_count() == 0 + assert not sem.finished.is_set() + assert not sem.locked() + + sem.reset() + + assert sem.get_count() == 0 + assert sem.get_count_with_pending() == 0 + assert sem.get_pending_count() == 0 + assert not sem.finished.is_set() + assert sem.locked() + + +async def test_acquire(): + sem_value = 1 + sem = SingleUseSemaphore(sem_value) + + assert sem.get_count() == sem_value + assert sem.get_count_with_pending() == sem_value + assert sem.get_pending_count() == 0 + assert not sem.finished.is_set() + assert not sem.locked() + + await asyncio.wait_for(sem.acquire(), 0.1) + + assert sem.get_count() == 0 + assert sem.get_count_with_pending() == 1 + assert sem.get_pending_count() == 1 + assert not sem.finished.is_set() + assert sem.locked() + + _, pending = await asyncio.wait([asyncio.create_task(sem.acquire())], timeout=0.1) + acquire_task = pending.pop() + if not acquire_task: + pytest.fail("Acquiring locked semaphore somehow finished instead of blocking!") + + await sem.increase(1) + + assert sem.get_count() == 1 + assert sem.get_count_with_pending() == 2 + assert sem.get_pending_count() == 1 + assert not sem.finished.is_set() + assert not sem.locked() + + await asyncio.wait_for(acquire_task, 0.1) + + assert sem.get_count() == 0 + assert sem.get_count_with_pending() == 2 + assert sem.get_pending_count() == 2 + assert not sem.finished.is_set() + assert sem.locked() + + +async def test_release(): + sem_value = 1 + sem = SingleUseSemaphore(sem_value) + + assert sem.get_count() == sem_value + assert sem.get_count_with_pending() == sem_value + assert sem.get_pending_count() == 0 + assert not sem.finished.is_set() + assert not sem.locked() + + with pytest.raises(RuntimeError): + sem.release() + + await asyncio.wait_for(sem.acquire(), 0.1) + + assert sem.get_count() == 0 + assert sem.get_count_with_pending() == 1 + assert sem.get_pending_count() == 1 + assert not sem.finished.is_set() + assert sem.locked() + + sem.release() + + assert sem.get_count() == 0 + assert sem.get_count_with_pending() == 0 + assert sem.get_pending_count() == 0 + assert sem.finished.is_set() + assert sem.locked() + + +async def test_context_manager(): + sem_value = 2 + sem = SingleUseSemaphore(sem_value) + + assert sem.get_count() == sem_value + assert sem.get_count_with_pending() == sem_value + assert sem.get_pending_count() == 0 + assert not sem.finished.is_set() + assert not sem.locked() + + async with sem: + assert sem.get_count() == 1 + assert sem.get_count_with_pending() == sem_value + assert sem.get_pending_count() == 1 + assert not sem.finished.is_set() + assert not sem.locked() + + assert sem.get_count() == 1 + assert sem.get_count_with_pending() == 1 + assert sem.get_pending_count() == 0 + assert not sem.finished.is_set() + assert not sem.locked() + + async with sem: + assert sem.get_count() == 0 + assert sem.get_count_with_pending() == 1 + assert sem.get_pending_count() == 1 + assert not sem.finished.is_set() + assert sem.locked() + + assert sem.get_count() == 0 + assert sem.get_count_with_pending() == 0 + assert sem.get_pending_count() == 0 + assert sem.finished.is_set() + assert sem.locked() diff --git a/tests/unit/utils/test_waiter.py b/tests/unit/utils/test_waiter.py new file mode 100644 index 00000000..486da920 --- /dev/null +++ b/tests/unit/utils/test_waiter.py @@ -0,0 +1,82 @@ +import asyncio + +import pytest + +from golem.utils.asyncio import Waiter + + +async def test_waiter_single_wait_for(): + waiter = Waiter() + value = 0 + + wait_task = asyncio.create_task(waiter.wait_for(lambda: value == 1)) + + await asyncio.sleep(0.1) + + assert not wait_task.done() + + waiter.notify() + + await asyncio.sleep(0.1) + + assert not wait_task.done() + + value = 1 + waiter.notify() + + await asyncio.wait_for(wait_task, timeout=0.1) + + +async def test_waiter_multiple_wait_for(): + waiter = Waiter() + value = 0 + + wait_task1 = asyncio.create_task(waiter.wait_for(lambda: value == 1)) + wait_task2 = asyncio.create_task(waiter.wait_for(lambda: value == 1)) + wait_task3 = asyncio.create_task(waiter.wait_for(lambda: value == 1)) + + await asyncio.sleep(0.1) + + assert not wait_task1.done() + + value = 1 + waiter.notify() + + await asyncio.wait_for(wait_task1, timeout=0.1) + + done, _ = await asyncio.wait([wait_task2, wait_task3], timeout=0.1) + + if done: + pytest.fail("Somehow some tasks finished too early!") + + waiter.notify(2) + + _, pending = await asyncio.wait([wait_task2, wait_task3], timeout=0.1) + + if pending: + pytest.fail("Somehow some tasks not finished!") + + +async def test_waiter_will_renotify_when_predicate_was_false(): + waiter = Waiter() + value = 0 + + wait_task1 = asyncio.create_task(waiter.wait_for(lambda: value == 2)) + wait_task2 = asyncio.create_task(waiter.wait_for(lambda: value == 1)) + + await asyncio.sleep(0.1) + assert not wait_task2.done() and not wait_task1.done() + + value = 1 + waiter.notify() + + await asyncio.sleep(0.1) + + assert wait_task2.done(), "Task2 should be done at this point!" + assert not wait_task1.done(), "Task1 should still block at this point!" + + value = 2 + + waiter.notify() + + await asyncio.wait_for(wait_task1, timeout=0.1)