-
Notifications
You must be signed in to change notification settings - Fork 133
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update / add task execution hooks #1269
base: main
Are you sure you want to change the base?
Changes from all commits
65ccf5f
87c33c0
68c6eb6
1680fe8
aff7fdf
47d9553
18e6874
c7c4993
7d42e6d
b70e9b2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
=================================== | ||
lifecycle.api.TaskExecutionHook | ||
=================================== | ||
|
||
|
||
.. autoclass:: hamilton.lifecycle.api.TaskExecutionHook | ||
:special-members: __init__ | ||
:members: | ||
:inherited-members: |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
=================================== | ||
lifecycle.api.TaskGroupingHook | ||
=================================== | ||
|
||
|
||
.. autoclass:: hamilton.lifecycle.api.TaskGroupingHook | ||
:special-members: __init__ | ||
:members: | ||
:inherited-members: |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,7 +15,11 @@ | |
# To really fix this we should move everything user-facing out of base, which is a pretty sloppy name for a package anyway | ||
# And put it where it belongs. For now we're OK with the TYPE_CHECKING hack | ||
if TYPE_CHECKING: | ||
from hamilton.execution.grouping import NodeGroupPurpose, TaskSpec | ||
from hamilton.graph import FunctionGraph | ||
else: | ||
NodeGroupPurpose = None | ||
TaskSpec = None | ||
|
||
from hamilton.graph_types import HamiltonGraph, HamiltonNode | ||
from hamilton.lifecycle.base import ( | ||
|
@@ -27,6 +31,8 @@ | |
BasePostGraphExecute, | ||
BasePostNodeExecute, | ||
BasePostTaskExecute, | ||
BasePostTaskExpand, | ||
BasePostTaskGroup, | ||
BasePreGraphExecute, | ||
BasePreNodeExecute, | ||
BasePreTaskExecute, | ||
|
@@ -379,13 +385,17 @@ def pre_task_execute( | |
nodes: List["node.Node"], | ||
inputs: Dict[str, Any], | ||
overrides: Dict[str, Any], | ||
spawning_task_id: Optional[str], | ||
purpose: NodeGroupPurpose, | ||
): | ||
self.run_before_task_execution( | ||
run_id=run_id, | ||
task_id=task_id, | ||
nodes=[HamiltonNode.from_node(n) for n in nodes], | ||
inputs=inputs, | ||
overrides=overrides, | ||
spawning_task_id=spawning_task_id, | ||
purpose=purpose, | ||
) | ||
|
||
def post_task_execute( | ||
|
@@ -397,6 +407,8 @@ def post_task_execute( | |
results: Optional[Dict[str, Any]], | ||
success: bool, | ||
error: Exception, | ||
spawning_task_id: Optional[str], | ||
purpose: NodeGroupPurpose, | ||
): | ||
self.run_after_task_execution( | ||
run_id=run_id, | ||
|
@@ -405,6 +417,8 @@ def post_task_execute( | |
results=results, | ||
success=success, | ||
error=error, | ||
spawning_task_id=spawning_task_id, | ||
purpose=purpose, | ||
) | ||
|
||
@abc.abstractmethod | ||
|
@@ -416,6 +430,8 @@ def run_before_task_execution( | |
nodes: List[HamiltonNode], | ||
inputs: Dict[str, Any], | ||
overrides: Dict[str, Any], | ||
spawning_task_id: Optional[str], | ||
purpose: NodeGroupPurpose, | ||
**future_kwargs, | ||
): | ||
"""Implement this to run something before task execution. Tasks are tools used to group nodes. | ||
|
@@ -428,6 +444,8 @@ def run_before_task_execution( | |
:param inputs: Inputs to the task | ||
:param overrides: Overrides passed to the task | ||
:param future_kwargs: Reserved for backwards compatibility. | ||
:param spawning_task_id: ID of the task that spawned this task | ||
:param purpose: Purpose of the current task group | ||
""" | ||
pass | ||
|
||
|
@@ -441,6 +459,8 @@ def run_after_task_execution( | |
results: Optional[Dict[str, Any]], | ||
success: bool, | ||
error: Exception, | ||
spawning_task_id: Optional[str], | ||
purpose: NodeGroupPurpose, | ||
**future_kwargs, | ||
): | ||
"""Implement this to run something after task execution. See note in run_before_task_execution. | ||
|
@@ -452,6 +472,8 @@ def run_after_task_execution( | |
:param success: Whether the task was successful | ||
:param error: The error the task threw, if any | ||
:param future_kwargs: Reserved for backwards compatibility. | ||
:param spawning_task_id: ID of the task that spawned this task | ||
:param purpose: Purpose of the current task group | ||
""" | ||
pass | ||
|
||
|
@@ -614,6 +636,42 @@ def validate_graph( | |
return self.run_to_validate_graph(graph=HamiltonGraph.from_graph(graph)) | ||
|
||
|
||
class TaskGroupingHook(BasePostTaskGroup, BasePostTaskExpand): | ||
"""Implement this to run something after task grouping or task expansion. This will allow you to | ||
capture information about the tasks during `Parallelize`/`Collect` blocks in dynamic DAG execution.""" | ||
|
||
@override | ||
@final | ||
def post_task_group(self, *, run_id: str, tasks: List[TaskSpec]): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm not sure we should expose the task spec here now -- what data in it do we need? We can narrow the contract to give us more flexibility... |
||
return self.run_after_task_grouping(run_id=run_id, tasks=tasks) | ||
|
||
@override | ||
@final | ||
def post_task_expand(self, *, run_id: str, task_id: str, parameters: Dict[str, Any]): | ||
return self.run_after_task_expansion(run_id=run_id, task_id=task_id, parameters=parameters) | ||
|
||
@abc.abstractmethod | ||
def run_after_task_grouping(self, *, run_id: str, tasks: List[TaskSpec], **future_kwargs): | ||
"""Hook that is called after task grouping. | ||
:param run_id: ID of the run, unique in scope of the driver. | ||
:param tasks: List of tasks that were grouped together. | ||
:param future_kwargs: Additional keyword arguments -- this is kept for backwards compatibility. | ||
""" | ||
pass | ||
|
||
@abc.abstractmethod | ||
def run_after_task_expansion( | ||
self, *, run_id: str, task_id: str, parameters: Dict[str, Any], **future_kwargs | ||
): | ||
"""Hook that is called after task expansion. | ||
:param run_id: ID of the run, unique in scope of the driver. | ||
:param task_id: ID of the task that was expanded. | ||
:param parameters: Parameters that were passed to the task. | ||
:param future_kwargs: Additional keyword arguments -- this is kept for backwards compatibility. | ||
""" | ||
pass | ||
|
||
|
||
class GraphConstructionHook(BasePostGraphConstruct, abc.ABC): | ||
"""Hook that is run after graph construction. This allows you to register/capture info on the graph. | ||
Note that, in the case of materialization, this may be called multiple times (once when we create the graph, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -43,6 +43,10 @@ | |
# python, which (our usage of) leans type-hinting trigger-happy, this will suffice. | ||
if TYPE_CHECKING: | ||
from hamilton import graph, node | ||
from hamilton.execution.grouping import NodeGroupPurpose, TaskSpec | ||
else: | ||
NodeGroupPurpose = None | ||
TaskSpec = None | ||
|
||
# All of these are internal APIs. Specifically, structure required to manage a set of | ||
# hooks/methods/validators that we will likely expand. We store them in constants (rather than, say, a more complex single object) | ||
|
@@ -418,6 +422,8 @@ def pre_task_execute( | |
nodes: List["node.Node"], | ||
inputs: Dict[str, Any], | ||
overrides: Dict[str, Any], | ||
spawning_task_id: Optional[str], | ||
purpose: NodeGroupPurpose, | ||
): | ||
"""Hook that is called immediately prior to task execution. Note that this is only useful in dynamic | ||
execution, although we reserve the right to add this back into the standard hamilton execution pattern. | ||
|
@@ -427,6 +433,8 @@ def pre_task_execute( | |
:param nodes: Nodes that are being executed | ||
:param inputs: Inputs to the task | ||
:param overrides: Overrides to task execution | ||
:param spawning_task_id: ID of the task that spawned this task | ||
:param purpose: Purpose of the current task group | ||
""" | ||
pass | ||
|
||
|
@@ -442,6 +450,8 @@ async def pre_task_execute( | |
nodes: List["node.Node"], | ||
inputs: Dict[str, Any], | ||
overrides: Dict[str, Any], | ||
spawning_task_id: Optional[str], | ||
purpose: NodeGroupPurpose, | ||
): | ||
"""Hook that is called immediately prior to task execution. Note that this is only useful in dynamic | ||
execution, although we reserve the right to add this back into the standard hamilton execution pattern. | ||
|
@@ -451,6 +461,8 @@ async def pre_task_execute( | |
:param nodes: Nodes that are being executed | ||
:param inputs: Inputs to the task | ||
:param overrides: Overrides to task execution | ||
:param spawning_task_id: ID of the task that spawned this task | ||
:param purpose: Purpose of the current task group | ||
""" | ||
pass | ||
|
||
|
@@ -615,6 +627,31 @@ async def post_node_execute( | |
pass | ||
|
||
|
||
@lifecycle.base_hook("post_task_group") | ||
class BasePostTaskGroup(abc.ABC): | ||
@abc.abstractmethod | ||
def post_task_group(self, *, run_id: str, tasks: List[TaskSpec]): | ||
"""Hook that is called immediately after a task group is created. Note that this is only useful in dynamic | ||
execution, although we reserve the right to add this back into the standard hamilton execution pattern. | ||
|
||
:param run_id: ID of the run, unique in scope of the driver. | ||
:param tasks: Task specs that are in the group.""" | ||
pass | ||
|
||
|
||
@lifecycle.base_hook("post_task_expand") | ||
class BasePostTaskExpand(abc.ABC): | ||
@abc.abstractmethod | ||
def post_task_expand(self, *, run_id: str, task_id: str, parameters: Dict[str, Any]): | ||
"""Hook that is called immediately after a task is expanded into separate tasks. Note that this is only useful | ||
in dynamic execution, although we reserve the right to add this back into the standard hamilton execution pattern. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't think this will be added back in -- the expansion piece is a specific dynamic execution notion. |
||
|
||
:param run_id: ID of the run, unique in scope of the driver. | ||
:param task_id: ID of the task. | ||
:param parameters: Parameters that are being passed to each of the expanded tasks.""" | ||
pass | ||
|
||
|
||
@lifecycle.base_hook("post_task_execute") | ||
class BasePostTaskExecute(abc.ABC): | ||
@abc.abstractmethod | ||
|
@@ -627,6 +664,8 @@ def post_task_execute( | |
results: Optional[Dict[str, Any]], | ||
success: bool, | ||
error: Exception, | ||
spawning_task_id: Optional[str], | ||
purpose: NodeGroupPurpose, | ||
): | ||
"""Hook called immediately after task execution. Note that this is only useful in dynamic | ||
execution, although we reserve the right to add this back into the standard hamilton execution pattern. | ||
|
@@ -637,6 +676,8 @@ def post_task_execute( | |
:param results: Results of the task | ||
:param success: Whether or not the task executed successfully | ||
:param error: The error that was raised, if any | ||
:param spawning_task_id: ID of the task that spawned this task | ||
:param purpose: Purpose of the current task group | ||
""" | ||
pass | ||
|
||
|
@@ -653,6 +694,8 @@ async def post_task_execute( | |
results: Optional[Dict[str, Any]], | ||
success: bool, | ||
error: Exception, | ||
spawning_task_id: Optional[str], | ||
purpose: NodeGroupPurpose, | ||
): | ||
"""Asynchronous Hook called immediately after task execution. Note that this is only useful in dynamic | ||
execution, although we reserve the right to add this back into the standard hamilton execution pattern. | ||
|
@@ -663,6 +706,8 @@ async def post_task_execute( | |
:param results: Results of the task | ||
:param success: Whether or not the task executed successfully | ||
:param error: The error that was raised, if any | ||
:param spawning_task_id: ID of the task that spawned this task | ||
:param purpose: Purpose of the current task group | ||
""" | ||
pass | ||
|
||
|
@@ -737,6 +782,8 @@ def do_build_result(self, *, outputs: Any) -> Any: | |
BasePostGraphConstructAsync, | ||
BasePreGraphExecute, | ||
BasePreGraphExecuteAsync, | ||
BasePostTaskGroup, | ||
BasePostTaskExpand, | ||
BasePreTaskExecute, | ||
BasePreTaskExecuteAsync, | ||
BasePreNodeExecute, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So, there's a slight design issue here, curious as to your thoughts. The fact that we create a dict with all the parameterization values is not part of the contract -- the idea is we could go to having a generator where they're not all decided for now. Not married to this (it's a bit baked in that it's a list now), but curious what you're planning on using this value for?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Conversely, it could change to storing the whole set of values as a list regardless of whether it's a generator or not, only if this hook exists -- that could be an implementation detail we could handle later (e.g. add something that materializes it) -- this would allow us to release this now and not change the contract.