Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update / add task execution hooks #1269

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
9 changes: 9 additions & 0 deletions docs/reference/lifecycle-hooks/TaskExecutionHook.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
===================================
lifecycle.api.TaskExecutionHook
===================================


.. autoclass:: hamilton.lifecycle.api.TaskExecutionHook
:special-members: __init__
:members:
:inherited-members:
9 changes: 9 additions & 0 deletions docs/reference/lifecycle-hooks/TaskGroupingHook.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
===================================
lifecycle.api.TaskGroupingHook
===================================


.. autoclass:: hamilton.lifecycle.api.TaskGroupingHook
:special-members: __init__
:members:
:inherited-members:
3 changes: 2 additions & 1 deletion docs/reference/lifecycle-hooks/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ looking forward.
NodeExecutionMethod
StaticValidator
GraphConstructionHook

TaskExecutionHook
TaskGroupingHook

Available Adapters
-------------------
Expand Down
8 changes: 8 additions & 0 deletions hamilton/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,14 @@ def execute(
results_cache = state.DictBasedResultCache(prehydrated_results)
# Create tasks from the grouped nodes, filtering/pruning as we go
tasks = grouping.create_task_plan(grouped_nodes, final_vars, overrides, self.adapter)

if self.adapter.does_hook("post_task_group", is_async=False):
self.adapter.call_all_lifecycle_hooks_sync(
"post_task_group",
run_id=run_id,
tasks=tasks,
)

# Create a task graph and execution state
execution_state = state.ExecutionState(
tasks, results_cache, run_id
Expand Down
4 changes: 4 additions & 0 deletions hamilton/execution/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ def base_execute_task(task: TaskImplementation) -> Dict[str, Any]:
nodes=task.nodes,
inputs=task.dynamic_inputs,
overrides=task.overrides,
spawning_task_id=task.spawning_task_id,
purpose=task.purpose,
)
error = None
success = True
Expand Down Expand Up @@ -139,6 +141,8 @@ def base_execute_task(task: TaskImplementation) -> Dict[str, Any]:
results=results,
success=success,
error=error,
spawning_task_id=task.spawning_task_id,
purpose=task.purpose,
)
# This selection is for GC
# We also need to get the override values
Expand Down
7 changes: 7 additions & 0 deletions hamilton/execution/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,13 @@ def update_task_state(
self.realize_parameterized_group(
completed_task.task_id, parameterization_values, input_to_parameterize
)
if completed_task.adapter.does_hook("post_task_expand", is_async=False):
completed_task.adapter.call_all_lifecycle_hooks_sync(
"post_task_expand",
run_id=completed_task.run_id,
task_id=completed_task.task_id,
parameters=parameterization_values,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, there's a slight design issue here, curious as to your thoughts. The fact that we create a dict with all the parameterization values is not part of the contract -- the idea is we could go to having a generator where they're not all decided for now. Not married to this (it's a bit baked in that it's a list now), but curious what you're planning on using this value for?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Conversely, it could change to storing the whole set of values as a list regardless of whether it's a generator or not, only if this hook exists -- that could be an implementation detail we could handle later (e.g. add something that materializes it) -- this would allow us to release this now and not change the contract.

)
else:
for candidate_task in self.base_reverse_dependencies[completed_task.base_id]:
# This means its not spawned by another task, or a node spawning group itself
Expand Down
58 changes: 58 additions & 0 deletions hamilton/lifecycle/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
# To really fix this we should move everything user-facing out of base, which is a pretty sloppy name for a package anyway
# And put it where it belongs. For now we're OK with the TYPE_CHECKING hack
if TYPE_CHECKING:
from hamilton.execution.grouping import NodeGroupPurpose, TaskSpec
from hamilton.graph import FunctionGraph
else:
NodeGroupPurpose = None
TaskSpec = None

from hamilton.graph_types import HamiltonGraph, HamiltonNode
from hamilton.lifecycle.base import (
Expand All @@ -27,6 +31,8 @@
BasePostGraphExecute,
BasePostNodeExecute,
BasePostTaskExecute,
BasePostTaskExpand,
BasePostTaskGroup,
BasePreGraphExecute,
BasePreNodeExecute,
BasePreTaskExecute,
Expand Down Expand Up @@ -379,13 +385,17 @@ def pre_task_execute(
nodes: List["node.Node"],
inputs: Dict[str, Any],
overrides: Dict[str, Any],
spawning_task_id: Optional[str],
purpose: NodeGroupPurpose,
):
self.run_before_task_execution(
run_id=run_id,
task_id=task_id,
nodes=[HamiltonNode.from_node(n) for n in nodes],
inputs=inputs,
overrides=overrides,
spawning_task_id=spawning_task_id,
purpose=purpose,
)

def post_task_execute(
Expand All @@ -397,6 +407,8 @@ def post_task_execute(
results: Optional[Dict[str, Any]],
success: bool,
error: Exception,
spawning_task_id: Optional[str],
purpose: NodeGroupPurpose,
):
self.run_after_task_execution(
run_id=run_id,
Expand All @@ -405,6 +417,8 @@ def post_task_execute(
results=results,
success=success,
error=error,
spawning_task_id=spawning_task_id,
purpose=purpose,
)

@abc.abstractmethod
Expand All @@ -416,6 +430,8 @@ def run_before_task_execution(
nodes: List[HamiltonNode],
inputs: Dict[str, Any],
overrides: Dict[str, Any],
spawning_task_id: Optional[str],
purpose: NodeGroupPurpose,
**future_kwargs,
):
"""Implement this to run something after task execution. Tasks are tols used to group nodes.
Expand All @@ -428,6 +444,8 @@ def run_before_task_execution(
:param inputs: Inputs to the task
:param overrides: Overrides passed to the task
:param future_kwargs: Reserved for backwards compatibility.
:param spawning_task_id: ID of the task that spawned this task
:param purpose: Purpose of the current task group
"""
pass

Expand All @@ -441,6 +459,8 @@ def run_after_task_execution(
results: Optional[Dict[str, Any]],
success: bool,
error: Exception,
spawning_task_id: Optional[str],
purpose: NodeGroupPurpose,
**future_kwargs,
):
"""Implement this to run something after task execution. See note in run_before_task_execution.
Expand All @@ -452,6 +472,8 @@ def run_after_task_execution(
:param success: Whether the task was successful
:param error: The error the task threw, if any
:param future_kwargs: Reserved for backwards compatibility.
:param spawning_task_id: ID of the task that spawned this task
:param purpose: Purpose of the current task group
"""
pass

Expand Down Expand Up @@ -614,6 +636,42 @@ def validate_graph(
return self.run_to_validate_graph(graph=HamiltonGraph.from_graph(graph))


class TaskGroupingHook(BasePostTaskGroup, BasePostTaskExpand):
    """Implement this to run something after task grouping or task expansion. This will allow you to
    capture information about the tasks during `Parallelize`/`Collect` blocks in dynamic DAG execution.

    Subclasses implement :meth:`run_after_task_grouping` and :meth:`run_after_task_expansion`;
    the framework-facing ``post_task_group``/``post_task_expand`` entry points are final and
    simply delegate to them.
    """

    @override
    @final
    def post_task_group(self, *, run_id: str, tasks: List[TaskSpec]):
        # Framework entry point -- delegates to the user-facing hook method.
        return self.run_after_task_grouping(run_id=run_id, tasks=tasks)

    @override
    @final
    def post_task_expand(self, *, run_id: str, task_id: str, parameters: Dict[str, Any]):
        # Framework entry point -- delegates to the user-facing hook method.
        return self.run_after_task_expansion(run_id=run_id, task_id=task_id, parameters=parameters)

    @abc.abstractmethod
    def run_after_task_grouping(self, *, run_id: str, tasks: List[TaskSpec], **future_kwargs):
        """Hook that is called after task grouping.

        :param run_id: ID of the run, unique in scope of the driver.
        :param tasks: List of tasks that were grouped together.
        :param future_kwargs: Additional keyword arguments -- this is kept for backwards compatibility.
        """
        pass

    @abc.abstractmethod
    def run_after_task_expansion(
        self, *, run_id: str, task_id: str, parameters: Dict[str, Any], **future_kwargs
    ):
        """Hook that is called after task expansion.

        :param run_id: ID of the run, unique in scope of the driver.
        :param task_id: ID of the task that was expanded.
        :param parameters: Parameters that were passed to the task.
        :param future_kwargs: Additional keyword arguments -- this is kept for backwards compatibility.
        """
        pass


class GraphConstructionHook(BasePostGraphConstruct, abc.ABC):
"""Hook that is run after graph construction. This allows you to register/capture info on the graph.
Note that, in the case of materialization, this may be called multiple times (once when we create the graph,
Expand Down
47 changes: 47 additions & 0 deletions hamilton/lifecycle/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@
# python, which (our usage of) leans type-hinting trigger-happy, this will suffice.
if TYPE_CHECKING:
from hamilton import graph, node
from hamilton.execution.grouping import NodeGroupPurpose, TaskSpec
else:
NodeGroupPurpose = None
TaskSpec = None

# All of these are internal APIs. Specifically, structure required to manage a set of
# hooks/methods/validators that we will likely expand. We store them in constants (rather than, say, a more complex single object)
Expand Down Expand Up @@ -418,6 +422,8 @@ def pre_task_execute(
nodes: List["node.Node"],
inputs: Dict[str, Any],
overrides: Dict[str, Any],
spawning_task_id: Optional[str],
purpose: NodeGroupPurpose,
):
"""Hook that is called immediately prior to task execution. Note that this is only useful in dynamic
execution, although we reserve the right to add this back into the standard hamilton execution pattern.
Expand All @@ -427,6 +433,8 @@ def pre_task_execute(
:param nodes: Nodes that are being executed
:param inputs: Inputs to the task
:param overrides: Overrides to task execution
:param spawning_task_id: ID of the task that spawned this task
:param purpose: Purpose of the current task group
"""
pass

Expand All @@ -442,6 +450,8 @@ async def pre_task_execute(
nodes: List["node.Node"],
inputs: Dict[str, Any],
overrides: Dict[str, Any],
spawning_task_id: Optional[str],
purpose: NodeGroupPurpose,
):
"""Hook that is called immediately prior to task execution. Note that this is only useful in dynamic
execution, although we reserve the right to add this back into the standard hamilton execution pattern.
Expand All @@ -451,6 +461,8 @@ async def pre_task_execute(
:param nodes: Nodes that are being executed
:param inputs: Inputs to the task
:param overrides: Overrides to task execution
:param spawning_task_id: ID of the task that spawned this task
:param purpose: Purpose of the current task group
"""
pass

Expand Down Expand Up @@ -615,6 +627,31 @@ async def post_node_execute(
pass


@lifecycle.base_hook("post_task_group")
class BasePostTaskGroup(abc.ABC):
@abc.abstractmethod
def post_task_group(self, *, run_id: str, tasks: List[TaskSpec]):
"""Hook that is called immediately after a task group is created. Note that this is only useful in dynamic
execution, although we reserve the right to add this back into the standard hamilton execution pattern.

:param run_id: ID of the run, unique in scope of the driver.
:param tasks: Tasks specs that are in the group."""
pass


@lifecycle.base_hook("post_task_expand")
class BasePostTaskExpand(abc.ABC):
@abc.abstractmethod
def post_task_expand(self, *, run_id: str, task_id: str, parameters: Dict[str, Any]):
"""Hook that is called immediately after a task is expanded into separate task. Note that this is only useful
in dynamic execution, although we reserve the right to add this back into the standard hamilton execution pattern.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this will be added back in -- the expansion piece is a specific dynamic execution notion.


:param run_id: ID of the run, unique in scope of the driver.
:param task_id: ID of the task.
:param parameters: Parameters that are being passed to each of the expanded tasks."""
pass


@lifecycle.base_hook("post_task_execute")
class BasePostTaskExecute(abc.ABC):
@abc.abstractmethod
Expand All @@ -627,6 +664,8 @@ def post_task_execute(
results: Optional[Dict[str, Any]],
success: bool,
error: Exception,
spawning_task_id: Optional[str],
purpose: NodeGroupPurpose,
):
"""Hook called immediately after task execution. Note that this is only useful in dynamic
execution, although we reserve the right to add this back into the standard hamilton execution pattern.
Expand All @@ -637,6 +676,8 @@ def post_task_execute(
:param results: Results of the task
:param success: Whether or not the task executed successfully
:param error: The error that was raised, if any
:param spawning_task_id: ID of the task that spawned this task
:param purpose: Purpose of the current task group
"""
pass

Expand All @@ -653,6 +694,8 @@ async def post_task_execute(
results: Optional[Dict[str, Any]],
success: bool,
error: Exception,
spawning_task_id: Optional[str],
purpose: NodeGroupPurpose,
):
"""Asynchronous Hook called immediately after task execution. Note that this is only useful in dynamic
execution, although we reserve the right to add this back into the standard hamilton execution pattern.
Expand All @@ -663,6 +706,8 @@ async def post_task_execute(
:param results: Results of the task
:param success: Whether or not the task executed successfully
:param error: The error that was raised, if any
:param spawning_task_id: ID of the task that spawned this task
:param purpose: Purpose of the current task group
"""
pass

Expand Down Expand Up @@ -737,6 +782,8 @@ def do_build_result(self, *, outputs: Any) -> Any:
BasePostGraphConstructAsync,
BasePreGraphExecute,
BasePreGraphExecuteAsync,
BasePostTaskGroup,
BasePostTaskExpand,
BasePreTaskExecute,
BasePreTaskExecuteAsync,
BasePreNodeExecute,
Expand Down
17 changes: 17 additions & 0 deletions tests/lifecycle/lifecycle_adapters_for_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, Callable, Dict, List, Optional, Tuple

from hamilton import node
from hamilton.execution.grouping import NodeGroupPurpose, TaskSpec
from hamilton.graph import FunctionGraph
from hamilton.lifecycle.base import (
BaseDoBuildResult,
Expand All @@ -15,6 +16,8 @@
BasePostGraphExecute,
BasePostNodeExecute,
BasePostTaskExecute,
BasePostTaskExpand,
BasePostTaskGroup,
BasePreDoAnythingHook,
BasePreGraphExecute,
BasePreNodeExecute,
Expand Down Expand Up @@ -108,6 +111,11 @@ def pre_graph_execute(
pass


class TrackingPostTaskGroupHook(ExtendToTrackCalls, BasePostTaskGroup):
    """Test adapter for the post_task_group hook -- no-op body; presumably call recording
    is provided by ExtendToTrackCalls (TODO confirm against that mixin)."""

    def post_task_group(self, run_id: str, tasks: List[TaskSpec]):
        pass


class TrackingPreTaskExecuteHook(ExtendToTrackCalls, BasePreTaskExecute):
def pre_task_execute(
self,
Expand All @@ -116,6 +124,8 @@ def pre_task_execute(
nodes: List[node.Node],
inputs: Dict[str, Any],
overrides: Dict[str, Any],
spawning_task_id: Optional[str],
purpose: NodeGroupPurpose,
):
pass

Expand Down Expand Up @@ -150,10 +160,17 @@ def post_task_execute(
results: Optional[Dict[str, Any]],
success: bool,
error: Exception,
spawning_task_id: Optional[str],
purpose: NodeGroupPurpose,
):
pass


class TrackingPostTaskExpandHook(ExtendToTrackCalls, BasePostTaskExpand):
    """Test adapter for the post_task_expand hook -- no-op body; presumably call recording
    is provided by ExtendToTrackCalls (TODO confirm against that mixin)."""

    def post_task_expand(self, run_id: str, task_id: str, parameters: Dict[str, Any]):
        pass


class TrackingPostGraphExecuteHook(ExtendToTrackCalls, BasePostGraphExecute):
def post_graph_execute(
self,
Expand Down
Loading