Skip to content

Commit

Permalink
Add test_work_manager_plugins_manager_mixin_ok
Browse files Browse the repository at this point in the history
  • Loading branch information
lucekdudek committed Jul 27, 2023
1 parent 2f2a9f0 commit 46e8767
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 13 deletions.
28 changes: 27 additions & 1 deletion golem/managers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,22 @@
RandomScore,
ScoredAheadOfTimeAgreementManager,
)
from golem.managers.base import Manager, ManagerScorePlugin, RejectProposal, WorkContext, WorkResult
from golem.managers.base import (
ActivityManager,
AgreementManager,
DemandManager,
DoWorkCallable,
Manager,
ManagerScorePlugin,
NegotiationManager,
NetworkManager,
PaymentManager,
RejectProposal,
Work,
WorkContext,
WorkManager,
WorkResult,
)
from golem.managers.demand import AutoDemandManager
from golem.managers.mixins import BackgroundLoopMixin, WeightProposalScoringPluginsMixin
from golem.managers.negotiation import (
Expand All @@ -19,6 +34,7 @@
from golem.managers.payment import PayAllPaymentManager
from golem.managers.work import (
SequentialWorkManager,
WorkManagerPluginsMixin,
redundancy_cancel_others_on_first_done,
retry,
work_plugin,
Expand All @@ -33,11 +49,20 @@
"PropertyValueLerpScore",
"RandomScore",
"ScoredAheadOfTimeAgreementManager",
"DoWorkCallable",
"Manager",
"ManagerScorePlugin",
"RejectProposal",
"Work",
"WorkManager",
"WorkContext",
"WorkResult",
"NetworkManager",
"PaymentManager",
"DemandManager",
"NegotiationManager",
"AgreementManager",
"ActivityManager",
"AutoDemandManager",
"BackgroundLoopMixin",
"WeightProposalScoringPluginsMixin",
Expand All @@ -48,6 +73,7 @@
"SingleNetworkManager",
"PayAllPaymentManager",
"SequentialWorkManager",
"WorkManagerPluginsMixin",
"redundancy_cancel_others_on_first_done",
"retry",
"work_plugin",
Expand Down
9 changes: 1 addition & 8 deletions golem/managers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,7 @@ class WorkResult:

WORK_PLUGIN_FIELD_NAME = "_work_plugins"


class Work(ABC):
_work_plugins: Optional[List["WorkManagerPlugin"]]

@abstractmethod
def __call__(self, context: WorkContext) -> Awaitable[Optional[WorkResult]]:
...

Work = Callable[[WorkContext], Awaitable[Optional[WorkResult]]]

DoWorkCallable = Callable[[Work], Awaitable[WorkResult]]

Expand Down
7 changes: 4 additions & 3 deletions golem/managers/work/plugins.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import logging
from functools import wraps
from typing import List

from golem.managers.base import (
WORK_PLUGIN_FIELD_NAME,
Expand All @@ -16,9 +17,9 @@
def work_plugin(plugin: WorkManagerPlugin):
def _work_plugin(work: Work):
if not hasattr(work, WORK_PLUGIN_FIELD_NAME):
work._work_plugins = []
setattr(work, WORK_PLUGIN_FIELD_NAME, [])

work._work_plugins.append(plugin) # type: ignore [union-attr]
getattr(work, WORK_PLUGIN_FIELD_NAME).append(plugin)

return work

Expand Down Expand Up @@ -63,7 +64,7 @@ def redundancy_cancel_others_on_first_done(size: int):
def _redundancy(do_work: DoWorkCallable):
@wraps(do_work)
async def wrapper(work: Work) -> WorkResult:
tasks = [asyncio.ensure_future(do_work(work)) for _ in range(size)]
tasks: List[asyncio.Task] = [asyncio.ensure_future(do_work(work)) for _ in range(size)]

tasks_done, tasks_pending = await asyncio.wait(
tasks, return_when=asyncio.FIRST_COMPLETED
Expand Down
54 changes: 54 additions & 0 deletions tests/unit/test_managers_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,18 @@

from golem.managers import (
BackgroundLoopMixin,
DoWorkCallable,
LinearAverageCostPricing,
Manager,
ManagerScorePlugin,
MapScore,
PropertyValueLerpScore,
WeightProposalScoringPluginsMixin,
Work,
WorkContext,
WorkManager,
WorkManagerPluginsMixin,
WorkResult,
)
from golem.payload import defaults

Expand Down Expand Up @@ -138,3 +144,51 @@ async def test_weight_proposal_scoring_plugins_mixin_ok(
manager = FooBarWeightProposalScoringPluginsManager(plugins=given_plugins)
received_proposals = await manager.do_scoring(given_proposals)
assert expected_weights == [weight for weight, _ in received_proposals]


class FooBarWorkManagerPluginsManager(WorkManagerPluginsMixin, WorkManager):
def __init__(self, do_work: DoWorkCallable, *args, **kwargs):
self._do_work = do_work
super().__init__(*args, **kwargs)

async def do_work(self, work: Work) -> WorkResult:
return await self._do_work_with_plugins(self._do_work, work)


@pytest.mark.parametrize(
"expected_work_result, expected_called_count",
(
("ZERO", None),
("ONE", 1),
("TWO", 2),
("TEN", 10),
),
)
async def test_work_manager_plugins_manager_mixin_ok(
expected_work_result: str, expected_called_count: Optional[int]
):
async def _do_work_func(work: Work) -> WorkResult:
work_result = await work(AsyncMock())
if not isinstance(work_result, WorkResult):
work_result = WorkResult(result=work_result)
return work_result

async def _work(context: WorkContext) -> Optional[WorkResult]:
return WorkResult(result=expected_work_result)

def _plugin(do_work: DoWorkCallable) -> DoWorkCallable:
async def wrapper(work: Work) -> WorkResult:
work_result = await do_work(work)
work_result.extras["called_count"] = work_result.extras.get("called_count", 0) + 1
return work_result

return wrapper

work_plugins = [_plugin for _ in range(expected_called_count or 0)]

manager = FooBarWorkManagerPluginsManager(do_work=_do_work_func, plugins=work_plugins)

result = await manager.do_work(_work)

assert result.result == expected_work_result
assert result.extras.get("called_count") == expected_called_count
2 changes: 1 addition & 1 deletion tests/unit/utils/test_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ async def test_buffer_get_item_will_trigger_fill_on_below_min_size(create_buffer
assert fill_callback.await_count == 10

done, _ = await asyncio.wait(
[asyncio.ensure_future(buffer.get_item()) for _ in range(6)], timeout=0.05
[asyncio.create_task(buffer.get_item()) for _ in range(6)], timeout=0.05
)

assert [d.result() for d in done] == fill_callback.mock_calls[1:7]
Expand Down

0 comments on commit 46e8767

Please sign in to comment.