From 46e876765b6a79ef9cb9135788646e60e06722ac Mon Sep 17 00:00:00 2001 From: Lucjan Dudek Date: Thu, 27 Jul 2023 11:55:51 +0200 Subject: [PATCH] Add test_work_manager_plugins_manager_mixin_ok --- golem/managers/__init__.py | 28 +++++++++++++++- golem/managers/base.py | 9 +---- golem/managers/work/plugins.py | 7 ++-- tests/unit/test_managers_mixins.py | 54 ++++++++++++++++++++++++++++++ tests/unit/utils/test_buffer.py | 2 +- 5 files changed, 87 insertions(+), 13 deletions(-) diff --git a/golem/managers/__init__.py b/golem/managers/__init__.py index f9ba9ba3..bd84cee2 100644 --- a/golem/managers/__init__.py +++ b/golem/managers/__init__.py @@ -6,7 +6,22 @@ RandomScore, ScoredAheadOfTimeAgreementManager, ) -from golem.managers.base import Manager, ManagerScorePlugin, RejectProposal, WorkContext, WorkResult +from golem.managers.base import ( + ActivityManager, + AgreementManager, + DemandManager, + DoWorkCallable, + Manager, + ManagerScorePlugin, + NegotiationManager, + NetworkManager, + PaymentManager, + RejectProposal, + Work, + WorkContext, + WorkManager, + WorkResult, +) from golem.managers.demand import AutoDemandManager from golem.managers.mixins import BackgroundLoopMixin, WeightProposalScoringPluginsMixin from golem.managers.negotiation import ( @@ -19,6 +34,7 @@ from golem.managers.payment import PayAllPaymentManager from golem.managers.work import ( SequentialWorkManager, + WorkManagerPluginsMixin, redundancy_cancel_others_on_first_done, retry, work_plugin, @@ -33,11 +49,20 @@ "PropertyValueLerpScore", "RandomScore", "ScoredAheadOfTimeAgreementManager", + "DoWorkCallable", "Manager", "ManagerScorePlugin", "RejectProposal", + "Work", + "WorkManager", "WorkContext", "WorkResult", + "NetworkManager", + "PaymentManager", + "DemandManager", + "NegotiationManager", + "AgreementManager", + "ActivityManager", "AutoDemandManager", "BackgroundLoopMixin", "WeightProposalScoringPluginsMixin", @@ -48,6 +73,7 @@ "SingleNetworkManager", "PayAllPaymentManager", "SequentialWorkManager", + "WorkManagerPluginsMixin", "redundancy_cancel_others_on_first_done", "retry", "work_plugin", diff --git a/golem/managers/base.py b/golem/managers/base.py index b8c8da4b..e5598cac 100644 --- a/golem/managers/base.py +++ b/golem/managers/base.py @@ -88,14 +88,7 @@ class WorkResult: WORK_PLUGIN_FIELD_NAME = "_work_plugins" - -class Work(ABC): - _work_plugins: Optional[List["WorkManagerPlugin"]] - - @abstractmethod - def __call__(self, context: WorkContext) -> Awaitable[Optional[WorkResult]]: - ... - +Work = Callable[[WorkContext], Awaitable[Optional[WorkResult]]] DoWorkCallable = Callable[[Work], Awaitable[WorkResult]] diff --git a/golem/managers/work/plugins.py b/golem/managers/work/plugins.py index 70658c20..5612b9bb 100644 --- a/golem/managers/work/plugins.py +++ b/golem/managers/work/plugins.py @@ -1,6 +1,7 @@ import asyncio import logging from functools import wraps +from typing import List from golem.managers.base import ( WORK_PLUGIN_FIELD_NAME, @@ -16,9 +17,9 @@ def work_plugin(plugin: WorkManagerPlugin): def _work_plugin(work: Work): if not hasattr(work, WORK_PLUGIN_FIELD_NAME): - work._work_plugins = [] + setattr(work, WORK_PLUGIN_FIELD_NAME, []) - work._work_plugins.append(plugin) # type: ignore [union-attr] + getattr(work, WORK_PLUGIN_FIELD_NAME).append(plugin) return work @@ -63,7 +64,7 @@ def redundancy_cancel_others_on_first_done(size: int): def _redundancy(do_work: DoWorkCallable): @wraps(do_work) async def wrapper(work: Work) -> WorkResult: - tasks = [asyncio.ensure_future(do_work(work)) for _ in range(size)] + tasks: List[asyncio.Task] = [asyncio.ensure_future(do_work(work)) for _ in range(size)] tasks_done, tasks_pending = await asyncio.wait( tasks, return_when=asyncio.FIRST_COMPLETED diff --git a/tests/unit/test_managers_mixins.py b/tests/unit/test_managers_mixins.py index bf1552af..73e05c23 100644 --- a/tests/unit/test_managers_mixins.py +++ b/tests/unit/test_managers_mixins.py @@ -8,12 +8,18 @@ from golem.managers import ( BackgroundLoopMixin, + DoWorkCallable, LinearAverageCostPricing, Manager, ManagerScorePlugin, MapScore, PropertyValueLerpScore, WeightProposalScoringPluginsMixin, + Work, + WorkContext, + WorkManager, + WorkManagerPluginsMixin, + WorkResult, ) from golem.payload import defaults @@ -138,3 +144,51 @@ async def test_weight_proposal_scoring_plugins_mixin_ok( manager = FooBarWeightProposalScoringPluginsManager(plugins=given_plugins) received_proposals = await manager.do_scoring(given_proposals) assert expected_weights == [weight for weight, _ in received_proposals] + + +class FooBarWorkManagerPluginsManager(WorkManagerPluginsMixin, WorkManager): + def __init__(self, do_work: DoWorkCallable, *args, **kwargs): + self._do_work = do_work + super().__init__(*args, **kwargs) + + async def do_work(self, work: Work) -> WorkResult: + return await self._do_work_with_plugins(self._do_work, work) + + +@pytest.mark.parametrize( + "expected_work_result, expected_called_count", + ( + ("ZERO", None), + ("ONE", 1), + ("TWO", 2), + ("TEN", 10), + ), +) +async def test_work_manager_plugins_manager_mixin_ok( + expected_work_result: str, expected_called_count: Optional[int] +): + async def _do_work_func(work: Work) -> WorkResult: + work_result = await work(AsyncMock()) + if not isinstance(work_result, WorkResult): + work_result = WorkResult(result=work_result) + return work_result + + async def _work(context: WorkContext) -> Optional[WorkResult]: + return WorkResult(result=expected_work_result) + + def _plugin(do_work: DoWorkCallable) -> DoWorkCallable: + async def wrapper(work: Work) -> WorkResult: + work_result = await do_work(work) + work_result.extras["called_count"] = work_result.extras.get("called_count", 0) + 1 + return work_result + + return wrapper + + work_plugins = [_plugin for _ in range(expected_called_count or 0)] + + manager = FooBarWorkManagerPluginsManager(do_work=_do_work_func, plugins=work_plugins) + + result = await manager.do_work(_work) + + assert result.result == expected_work_result + assert result.extras.get("called_count") == expected_called_count diff --git a/tests/unit/utils/test_buffer.py b/tests/unit/utils/test_buffer.py index b770c42a..38697942 100644 --- a/tests/unit/utils/test_buffer.py +++ b/tests/unit/utils/test_buffer.py @@ -118,7 +118,7 @@ async def test_buffer_get_item_will_trigger_fill_on_below_min_size(create_buffer assert fill_callback.await_count == 10 done, _ = await asyncio.wait( - [asyncio.ensure_future(buffer.get_item()) for _ in range(6)], timeout=0.05 + [asyncio.create_task(buffer.get_item()) for _ in range(6)], timeout=0.05 ) assert [d.result() for d in done] == fill_callback.mock_calls[1:7]