diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 183541c4f..2df834e2e 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -15,17 +15,17 @@ repos:
# Run the formatter.
- id: ruff-format
# args: [ --diff ] # Use for previewing changes
-- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v4.6.0
- hooks:
- - id: trailing-whitespace
- # ensures files are either empty or end with a blank line
- - id: end-of-file-fixer
- # sorts requirements
- - id: requirements-txt-fixer
- # valid python file
- - id: check-ast
-- repo: https://github.com/pycqa/flake8
- rev: 7.1.1
- hooks:
- - id: flake8
+- repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v4.6.0
+ hooks:
+ - id: trailing-whitespace
+ # ensures files are either empty or end with a blank line
+ - id: end-of-file-fixer
+ # sorts requirements
+ - id: requirements-txt-fixer
+ # valid python file
+ - id: check-ast
+- repo: https://github.com/pycqa/flake8
+ rev: 7.1.1
+ hooks:
+ - id: flake8
diff --git a/hamilton/ad_hoc_utils.py b/hamilton/ad_hoc_utils.py
index e5a9591f5..c0bd317f0 100644
--- a/hamilton/ad_hoc_utils.py
+++ b/hamilton/ad_hoc_utils.py
@@ -1,5 +1,7 @@
"""A suite of tools for ad-hoc use"""
+from __future__ import annotations
+
import atexit
import importlib.util
import linecache
@@ -9,7 +11,7 @@
import types
import uuid
from types import ModuleType
-from typing import Callable, Optional
+from typing import Callable
def _copy_func(f):
@@ -64,7 +66,7 @@ def create_temporary_module(*functions: Callable, module_name: str = None) -> Mo
return module
-def module_from_source(source: str, module_name: Optional[str] = None) -> ModuleType:
+def module_from_source(source: str, module_name: str | None = None) -> ModuleType:
"""Create a temporary module from source code."""
module_name = module_name if module_name else _generate_unique_temp_module_name()
module_object = ModuleType(module_name)
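The rewrites above lean on PEP 563: with `from __future__ import annotations`, every annotation is stored as a string and never evaluated at definition time, which is what lets `str | None` (PEP 604) and builtin generics (PEP 585) appear on interpreters older than 3.10/3.9. A minimal standalone sketch of the effect (illustrative, not Hamilton code):

from __future__ import annotations


def greet(name: str | None = None) -> str:
    # Under postponed evaluation, "str | None" stays a string, so this
    # definition parses even on Python 3.8/3.9.
    return f"hello {name or 'world'}"


if __name__ == "__main__":
    print(greet())                # hello world
    print(greet.__annotations__)  # {'name': 'str | None', 'return': 'str'}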
diff --git a/hamilton/async_driver.py b/hamilton/async_driver.py
index f724773ba..9a4e80860 100644
--- a/hamilton/async_driver.py
+++ b/hamilton/async_driver.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import asyncio
import inspect
import logging
@@ -5,18 +7,21 @@
import time
import typing
import uuid
-from types import ModuleType
-from typing import Any, Dict, Optional, Tuple
+from typing import TYPE_CHECKING, Any
import hamilton.lifecycle.base as lifecycle_base
from hamilton import base, driver, graph, lifecycle, node, telemetry
from hamilton.execution.graph_functions import create_error_message
-from hamilton.io.materialization import ExtractorFactory, MaterializerFactory
+
+if TYPE_CHECKING:
+ from types import ModuleType
+
+ from hamilton.io.materialization import ExtractorFactory, MaterializerFactory
logger = logging.getLogger(__name__)
-async def await_dict_of_tasks(task_dict: Dict[str, typing.Awaitable]) -> Dict[str, Any]:
+async def await_dict_of_tasks(task_dict: dict[str, typing.Awaitable]) -> dict[str, Any]:
"""Util to await a dictionary of tasks as asyncio.gather is kind of garbage"""
keys = sorted(task_dict.keys())
coroutines = [task_dict[key] for key in keys]
@@ -42,7 +47,7 @@ class AsyncGraphAdapter(lifecycle_base.BaseDoNodeExecute, lifecycle.ResultBuilde
def __init__(
self,
result_builder: base.ResultMixin = None,
- async_lifecycle_adapters: Optional[lifecycle_base.LifecycleAdapterSet] = None,
+ async_lifecycle_adapters: lifecycle_base.LifecycleAdapterSet | None = None,
):
"""Creates an AsyncGraphAdapter class. Note this will *only* work with the AsyncDriver class.
@@ -52,7 +57,7 @@ def __init__(
2. This does *not* work with decorators when the async function is being decorated. That is\
because that function is called directly within the decorator, so we cannot await it.
"""
- super(AsyncGraphAdapter, self).__init__()
+ super().__init__()
self.adapter = (
async_lifecycle_adapters
if async_lifecycle_adapters is not None
@@ -66,8 +71,8 @@ def do_node_execute(
*,
run_id: str,
node_: node.Node,
- kwargs: typing.Dict[str, typing.Any],
- task_id: Optional[str] = None,
+ kwargs: dict[str, typing.Any],
+ task_id: str | None = None,
) -> typing.Any:
"""Executes a node. Note this doesn't actually execute it -- rather, it returns a task.
This does *not* use async def, as we want it to be awaited on later -- this await is done
@@ -159,8 +164,8 @@ def build_result(self, **outputs: Any) -> Any:
def separate_sync_from_async(
- adapters: typing.List[lifecycle.LifecycleAdapter],
-) -> Tuple[typing.List[lifecycle.LifecycleAdapter], typing.List[lifecycle.LifecycleAdapter]]:
+ adapters: list[lifecycle.LifecycleAdapter],
+) -> tuple[list[lifecycle.LifecycleAdapter], list[lifecycle.LifecycleAdapter]]:
"""Separates the sync and async adapters from a list of adapters.
Note this only works with hooks -- we'll be dealing with methods later.
@@ -196,8 +201,8 @@ def __init__(
self,
config,
*modules,
- result_builder: Optional[base.ResultMixin] = None,
- adapters: typing.List[lifecycle.LifecycleAdapter] = None,
+ result_builder: base.ResultMixin | None = None,
+ adapters: list[lifecycle.LifecycleAdapter] = None,
):
"""Instantiates an asynchronous driver.
@@ -229,7 +234,7 @@ def __init__(
)
# it will be defaulted by the graph adapter
result_builder = result_builders[0] if len(result_builders) == 1 else None
- super(AsyncDriver, self).__init__(
+ super().__init__(
config,
*modules,
adapter=[
@@ -246,7 +251,7 @@ def __init__(
)
self.initialized = False
- async def ainit(self) -> "AsyncDriver":
+ async def ainit(self) -> AsyncDriver:
"""Initializes the driver when using async. This only exists for backwards compatibility.
In Hamilton 2.0, we will be using an asynchronous constructor.
See https://dev.to/akarshan/asynchronous-python-magic-how-to-create-awaitable-constructors-with-asyncmixin-18j5.
@@ -267,12 +272,12 @@ async def ainit(self) -> "AsyncDriver":
async def raw_execute(
self,
- final_vars: typing.List[str],
- overrides: Dict[str, Any] = None,
+ final_vars: list[str],
+ overrides: dict[str, Any] = None,
display_graph: bool = False, # don't care
- inputs: Dict[str, Any] = None,
+ inputs: dict[str, Any] = None,
_fn_graph: graph.FunctionGraph = None,
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
"""Executes the graph, returning a dictionary of strings (node keys) to final results.
:param final_vars: Variables to execute (+ upstream)
@@ -332,10 +337,10 @@ async def raw_execute(
async def execute(
self,
- final_vars: typing.List[str],
- overrides: Dict[str, Any] = None,
+ final_vars: list[str],
+ overrides: dict[str, Any] = None,
display_graph: bool = False,
- inputs: Dict[str, Any] = None,
+ inputs: dict[str, Any] = None,
) -> Any:
"""Executes computation.
@@ -386,9 +391,9 @@ async def make_coroutine():
def capture_constructor_telemetry(
self,
- error: Optional[str],
- modules: Tuple[ModuleType],
- config: Dict[str, Any],
+ error: str | None,
+ modules: tuple[ModuleType],
+ config: dict[str, Any],
adapter: base.HamiltonGraphAdapter,
):
"""Ensures we capture constructor telemetry the right way in an async context.
@@ -407,7 +412,7 @@ def capture_constructor_telemetry(
if loop.is_running():
loop.run_in_executor(
None,
- super(AsyncDriver, self).capture_constructor_telemetry,
+ super().capture_constructor_telemetry,
error,
modules,
config,
@@ -450,22 +455,20 @@ class Builder(driver.Builder):
"""
def __init__(self):
- super(Builder, self).__init__()
+ super().__init__()
def _not_supported(self, method_name: str, additional_message: str = ""):
raise ValueError(
f"Builder().{method_name}() is not supported for the async driver. {additional_message}"
)
- def enable_dynamic_execution(self, *, allow_experimental_mode: bool = False) -> "Builder":
+ def enable_dynamic_execution(self, *, allow_experimental_mode: bool = False) -> Builder:
self._not_supported("enable_dynamic_execution")
- def with_materializers(
- self, *materializers: typing.Union[ExtractorFactory, MaterializerFactory]
- ) -> "Builder":
+ def with_materializers(self, *materializers: ExtractorFactory | MaterializerFactory) -> Builder:
self._not_supported("with_materializers")
- def with_adapter(self, adapter: base.HamiltonGraphAdapter) -> "Builder":
+ def with_adapter(self, adapter: base.HamiltonGraphAdapter) -> Builder:
self._not_supported(
"with_adapter",
"Use with_adapters instead to pass in the tracker (or other async hooks/methods)",
diff --git a/hamilton/base.py b/hamilton/base.py
index 136c0b3fa..e3bcf3956 100644
--- a/hamilton/base.py
+++ b/hamilton/base.py
@@ -3,10 +3,12 @@
It cannot import hamilton.graph, or hamilton.driver.
"""
+from __future__ import annotations
+
import abc
import collections
import logging
-from typing import Any, Dict, List, Optional, Tuple, Type, Union
+from typing import TYPE_CHECKING, Any, Dict
import numpy as np
import pandas as pd
@@ -17,7 +19,8 @@
try:
from . import htypes, node
except ImportError:
- import node
+ if TYPE_CHECKING:
+ import node
logger = logging.getLogger(__name__)
@@ -59,14 +62,14 @@ class DictResult(ResultMixin):
"""
@staticmethod
- def build_result(**outputs: Dict[str, Any]) -> Dict:
+ def build_result(**outputs: dict[str, Any]) -> dict:
"""This function builds a simple dict of output -> computed values."""
return outputs
- def input_types(self) -> Optional[List[Type[Type]]]:
+ def input_types(self) -> list[type[type]] | None:
return [Any]
- def output_type(self) -> Type:
+ def output_type(self) -> type:
return Dict[str, Any]
@@ -91,8 +94,8 @@ class PandasDataFrameResult(ResultMixin):
@staticmethod
def pandas_index_types(
- outputs: Dict[str, Any],
- ) -> Tuple[Dict[str, List[str]], Dict[str, List[str]], Dict[str, List[str]]]:
+ outputs: dict[str, Any],
+ ) -> tuple[dict[str, list[str]], dict[str, list[str]], dict[str, list[str]]]:
"""This function creates three dictionaries according to whether there is an index type or not.
The three dicts we create are:
@@ -107,7 +110,7 @@ def pandas_index_types(
time_indexes = collections.defaultdict(list)
no_indexes = collections.defaultdict(list)
- def index_key_name(pd_object: Union[pd.DataFrame, pd.Series]) -> str:
+ def index_key_name(pd_object: pd.DataFrame | pd.Series) -> str:
"""Creates a string helping identify the index and it's type.
Useful for disambiguating time related indexes."""
return f"{pd_object.index.__class__.__name__}:::{pd_object.index.dtype}"
@@ -143,9 +146,9 @@ def get_parent_time_index_type():
@staticmethod
def check_pandas_index_types_match(
- all_index_types: Dict[str, List[str]],
- time_indexes: Dict[str, List[str]],
- no_indexes: Dict[str, List[str]],
+ all_index_types: dict[str, list[str]],
+ time_indexes: dict[str, list[str]],
+ no_indexes: dict[str, list[str]],
) -> bool:
"""Checks that pandas index types match.
@@ -195,7 +198,7 @@ def check_pandas_index_types_match(
return types_match
@staticmethod
- def build_result(**outputs: Dict[str, Any]) -> pd.DataFrame:
+ def build_result(**outputs: dict[str, Any]) -> pd.DataFrame:
"""Builds a Pandas DataFrame from the outputs.
This function will check the index types of the outputs, and log warnings if they don't match.
@@ -227,7 +230,7 @@ def build_result(**outputs: Dict[str, Any]) -> pd.DataFrame:
return pd.DataFrame(outputs) # this does an implicit outer join based on index.
@staticmethod
- def build_dataframe_with_dataframes(outputs: Dict[str, Any]) -> pd.DataFrame:
+ def build_dataframe_with_dataframes(outputs: dict[str, Any]) -> pd.DataFrame:
"""Builds a dataframe from the outputs in an "outer join" manner based on index.
The behavior of pd.DataFrame(outputs) is that it will do an outer join based on indexes of the Series passed in.
@@ -277,12 +280,12 @@ def get_output_name(output_name: str, column_name: str) -> str:
return pd.DataFrame(flattened_outputs)
- def input_types(self) -> List[Type[Type]]:
+ def input_types(self) -> list[type[type]]:
"""Currently this just shoves anything into a dataframe. We should probably
tighten this up."""
return [Any]
- def output_type(self) -> Type:
+ def output_type(self) -> type:
return pd.DataFrame
@@ -307,7 +310,7 @@ class StrictIndexTypePandasDataFrameResult(PandasDataFrameResult):
"""
@staticmethod
- def build_result(**outputs: Dict[str, Any]) -> pd.DataFrame:
+ def build_result(**outputs: dict[str, Any]) -> pd.DataFrame:
# TODO check inputs are pd.Series, arrays, or scalars -- else error
output_index_type_tuple = PandasDataFrameResult.pandas_index_types(outputs)
indexes_match = PandasDataFrameResult.check_pandas_index_types_match(
@@ -340,7 +343,7 @@ class NumpyMatrixResult(ResultMixin):
"""
@staticmethod
- def build_result(**outputs: Dict[str, Any]) -> np.matrix:
+ def build_result(**outputs: dict[str, Any]) -> np.matrix:
"""Builds a numpy matrix from the passed in, inputs.
Note: this does not check that the inputs are all numpy arrays/array like things.
@@ -380,11 +383,11 @@ def build_result(**outputs: Dict[str, Any]) -> np.matrix:
# Create the matrix with columns as rows and then transpose
return np.asmatrix(list_of_columns).T
- def input_types(self) -> List[Type[Type]]:
+ def input_types(self) -> list[type[type]]:
"""Currently returns anything as numpy types are relatively new and"""
return [Any] # Typing
- def output_type(self) -> Type:
+ def output_type(self) -> type:
return pd.DataFrame
@@ -405,14 +408,14 @@ class SimplePythonDataFrameGraphAdapter(HamiltonGraphAdapter, PandasDataFrameRes
"""
@staticmethod
- def check_input_type(node_type: Type, input_value: Any) -> bool:
+ def check_input_type(node_type: type, input_value: Any) -> bool:
return htypes.check_input_type(node_type, input_value)
@staticmethod
- def check_node_type_equivalence(node_type: Type, input_type: Type) -> bool:
+ def check_node_type_equivalence(node_type: type, input_type: type) -> bool:
return node_type == input_type
- def execute_node(self, node: node.Node, kwargs: Dict[str, Any]) -> Any:
+ def execute_node(self, node: node.Node, kwargs: dict[str, Any]) -> Any:
return node.callable(**kwargs)
@@ -436,11 +439,11 @@ def __init__(self, result_builder: ResultMixin = None):
result_builder = DictResult()
self.result_builder = result_builder
- def build_result(self, **outputs: Dict[str, Any]) -> Any:
+ def build_result(self, **outputs: dict[str, Any]) -> Any:
"""Delegates to the result builder function supplied."""
return self.result_builder.build_result(**outputs)
- def output_type(self) -> Type:
+ def output_type(self) -> type:
return self.result_builder.output_type()
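Note why base.py keeps `Dict` in its imports: the future import only defers annotations, not expressions, and `output_type` returns `Dict[str, Any]` as a runtime value. A sketch of the distinction, assuming a Python 3.8 target:

from __future__ import annotations

from typing import Any, Dict


def annotation_only() -> dict[str, Any]:
    # Never evaluated at runtime, so builtin generics are safe everywhere.
    return {}


def runtime_value() -> Any:
    # This expression *is* evaluated when called; on Python 3.8 it must stay
    # typing.Dict, since bare dict[str, Any] raises TypeError there.
    return Dict[str, Any]


if __name__ == "__main__":
    print(annotation_only.__annotations__["return"])  # dict[str, Any]
    print(runtime_value())                            # typing.Dict[str, Any]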
diff --git a/hamilton/driver.py b/hamilton/driver.py
index a2ba558d5..be0846451 100644
--- a/hamilton/driver.py
+++ b/hamilton/driver.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import abc
import functools
import importlib
@@ -13,8 +15,7 @@
import uuid
from collections.abc import Sequence # typing.Sequence is deprecated in >=3.9
from datetime import datetime
-from types import ModuleType
-from typing import Any, Callable, Collection, Dict, List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Collection, Dict
import pandas as pd
@@ -25,6 +26,9 @@
from hamilton.io.materialization import ExtractorFactory, MaterializerFactory
from hamilton.lifecycle import base as lifecycle_base
+if TYPE_CHECKING:
+ from types import ModuleType
+
SLACK_ERROR_MESSAGE = (
"-------------------------------------------------------------------\n"
"Oh no an error! Need help with Hamilton?\n"
@@ -92,11 +96,11 @@ class GraphExecutor(abc.ABC):
def execute(
self,
fg: graph.FunctionGraph,
- final_vars: List[Union[str, Callable, Variable]],
- overrides: Dict[str, Any],
- inputs: Dict[str, Any],
+ final_vars: list[str | Callable | Variable],
+ overrides: dict[str, Any],
+ inputs: dict[str, Any],
run_id: str,
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
"""Executes a graph in a blocking function.
:param fg: Graph to execute
@@ -110,7 +114,7 @@ def execute(
pass
@abc.abstractmethod
- def validate(self, nodes_to_execute: List[node.Node]):
+ def validate(self, nodes_to_execute: list[node.Node]):
"""Validates whether the executor can execute the given graph.
Some executors allow API constructs that others do not support
(such as Parallelizable[]/Collect[])
@@ -124,14 +128,14 @@ def validate(self, nodes_to_execute: List[node.Node]):
class DefaultGraphExecutor(GraphExecutor):
DEFAULT_TASK_NAME = "root" # Not task-based, so we just assign a default name for a task
- def __init__(self, adapter: Optional[lifecycle_base.LifecycleAdapterSet] = None):
+ def __init__(self, adapter: lifecycle_base.LifecycleAdapterSet | None = None):
"""Constructor for the default graph executor.
:param adapter: Adapter to use for execution (optional).
"""
self.adapter = adapter
- def validate(self, nodes_to_execute: List[node.Node]):
+ def validate(self, nodes_to_execute: list[node.Node]):
"""The default graph executor cannot handle parallelizable[]/collect[] nodes.
:param nodes_to_execute:
@@ -148,11 +152,11 @@ def validate(self, nodes_to_execute: List[node.Node]):
def execute(
self,
fg: graph.FunctionGraph,
- final_vars: List[str],
- overrides: Dict[str, Any],
- inputs: Dict[str, Any],
+ final_vars: list[str],
+ overrides: dict[str, Any],
+ inputs: dict[str, Any],
run_id: str,
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
"""Basic executor for a function graph. Does no task-based execution, just does a DFS
and executes the graph in order, in memory."""
memoized_computation = dict() # memoized storage
@@ -169,7 +173,7 @@ def execute(
class TaskBasedGraphExecutor(GraphExecutor):
- def validate(self, nodes_to_execute: List[node.Node]):
+ def validate(self, nodes_to_execute: list[node.Node]):
"""Currently this can run every valid graph"""
pass
@@ -193,11 +197,11 @@ def __init__(
def execute(
self,
fg: graph.FunctionGraph,
- final_vars: List[str],
- overrides: Dict[str, Any],
- inputs: Dict[str, Any],
+ final_vars: list[str],
+ overrides: dict[str, Any],
+ inputs: dict[str, Any],
run_id: str,
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
"""Executes a graph, task by task. This blocks until completion.
This does the following:
@@ -291,13 +295,12 @@ def __setstate__(self, state):
@staticmethod
def normalize_adapter_input(
- adapter: Optional[
- Union[
- lifecycle_base.LifecycleAdapter,
- List[lifecycle_base.LifecycleAdapter],
- lifecycle_base.LifecycleAdapterSet,
- ]
- ],
+ adapter: None
+ | (
+ lifecycle_base.LifecycleAdapter
+ | list[lifecycle_base.LifecycleAdapter]
+ | lifecycle_base.LifecycleAdapterSet
+ ),
use_legacy_adapter: bool = True,
) -> lifecycle_base.LifecycleAdapterSet:
"""Normalizes the adapter argument in the driver to a list of adapters. Adds back the legacy adapter if needed.
@@ -385,12 +388,11 @@ def _perform_graph_validations(
def __init__(
self,
- config: Dict[str, Any],
+ config: dict[str, Any],
*modules: ModuleType,
- adapter: Optional[
- Union[lifecycle_base.LifecycleAdapter, List[lifecycle_base.LifecycleAdapter]]
- ] = None,
- _materializers: typing.Sequence[Union[ExtractorFactory, MaterializerFactory]] = None,
+ adapter: None
+ | (lifecycle_base.LifecycleAdapter | list[lifecycle_base.LifecycleAdapter]) = None,
+ _materializers: typing.Sequence[ExtractorFactory | MaterializerFactory] = None,
_graph_executor: GraphExecutor = None,
_use_legacy_adapter: bool = True,
):
@@ -455,9 +457,9 @@ def _repr_mimebundle_(self, include=None, exclude=None, **kwargs):
def capture_constructor_telemetry(
self,
- error: Optional[str],
- modules: Tuple[ModuleType],
- config: Dict[str, Any],
+ error: str | None,
+ modules: tuple[ModuleType],
+ config: dict[str, Any],
adapter: lifecycle_base.LifecycleAdapterSet,
):
"""Captures constructor telemetry. Notes:
@@ -497,13 +499,13 @@ def capture_constructor_telemetry(
@staticmethod
def validate_inputs(
fn_graph: graph.FunctionGraph,
- adapter: Union[
- lifecycle_base.LifecycleAdapter,
- List[lifecycle_base.LifecycleAdapter],
- lifecycle_base.LifecycleAdapterSet,
- ],
+ adapter: (
+ lifecycle_base.LifecycleAdapter
+ | list[lifecycle_base.LifecycleAdapter]
+ | lifecycle_base.LifecycleAdapterSet
+ ),
user_nodes: Collection[node.Node],
- inputs: typing.Optional[Dict[str, Any]] = None,
+ inputs: dict[str, Any] | None = None,
nodes_set: Collection[node.Node] = None,
):
"""Validates that inputs meet our expectations. This means that:
@@ -558,10 +560,10 @@ def validate_inputs(
def execute(
self,
- final_vars: List[Union[str, Callable, Variable]],
- overrides: Dict[str, Any] = None,
+ final_vars: list[str | Callable | Variable],
+ overrides: dict[str, Any] = None,
display_graph: bool = False,
- inputs: Dict[str, Any] = None,
+ inputs: dict[str, Any] = None,
) -> Any:
"""Executes computation.
@@ -600,7 +602,7 @@ def execute(
error, _final_vars, inputs, overrides, run_successful, duration
)
- def _create_final_vars(self, final_vars: List[Union[str, Callable, Variable]]) -> List[str]:
+ def _create_final_vars(self, final_vars: list[str | Callable | Variable]) -> list[str]:
"""Creates the final variables list - converting functions names as required.
:param final_vars:
@@ -612,10 +614,10 @@ def _create_final_vars(self, final_vars: List[Union[str, Callable, Variable]]) -
def capture_execute_telemetry(
self,
- error: Optional[str],
- final_vars: List[str],
- inputs: Dict[str, Any],
- overrides: Dict[str, Any],
+ error: str | None,
+ final_vars: list[str],
+ inputs: dict[str, Any],
+ overrides: dict[str, Any],
run_successful: bool,
duration: float,
):
@@ -651,12 +653,12 @@ def capture_execute_telemetry(
def raw_execute(
self,
- final_vars: List[str],
- overrides: Dict[str, Any] = None,
+ final_vars: list[str],
+ overrides: dict[str, Any] = None,
display_graph: bool = False,
- inputs: Dict[str, Any] = None,
+ inputs: dict[str, Any] = None,
_fn_graph: graph.FunctionGraph = None,
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
"""Raw execute function that does the meat of execute.
Don't use this entry point for execution directly. Always go through `.execute()`.
@@ -727,8 +729,8 @@ def raw_execute(
@capture_function_usage
def list_available_variables(
- self, *, tag_filter: Dict[str, Union[Optional[str], List[str]]] = None
- ) -> List[Variable]:
+ self, *, tag_filter: dict[str, str | None | list[str]] = None
+ ) -> list[Variable]:
"""Returns available variables, i.e. outputs.
These variables correspond 1:1 with nodes in the DAG, and contain the following information:
@@ -791,7 +793,7 @@ def display_all_functions(
show_schema: bool = True,
custom_style_function: Callable = None,
keep_dot: bool = False,
- ) -> Optional["graphviz.Digraph"]: # noqa F821
+ ) -> graphviz.Digraph | None: # noqa F821
"""Displays the graph of all functions loaded!
:param output_file_path: the full URI of path + file name to save the dot file to.
@@ -836,12 +838,12 @@ def display_all_functions(
def _visualize_execution_helper(
fn_graph: graph.FunctionGraph,
adapter: lifecycle_base.LifecycleAdapterSet,
- final_vars: List[str],
+ final_vars: list[str],
output_file_path: str,
render_kwargs: dict,
- inputs: Dict[str, Any] = None,
+ inputs: dict[str, Any] = None,
graphviz_kwargs: dict = None,
- overrides: Dict[str, Any] = None,
+ overrides: dict[str, Any] = None,
show_legend: bool = True,
orient: str = "LR",
hide_inputs: bool = False,
@@ -908,12 +910,12 @@ def _visualize_execution_helper(
@capture_function_usage
def visualize_execution(
self,
- final_vars: List[Union[str, Callable, Variable]],
+ final_vars: list[str | Callable | Variable],
output_file_path: str = None,
render_kwargs: dict = None,
- inputs: Dict[str, Any] = None,
+ inputs: dict[str, Any] = None,
graphviz_kwargs: dict = None,
- overrides: Dict[str, Any] = None,
+ overrides: dict[str, Any] = None,
show_legend: bool = True,
orient: str = "LR",
hide_inputs: bool = False,
@@ -922,7 +924,7 @@ def visualize_execution(
custom_style_function: Callable = None,
bypass_validation: bool = False,
keep_dot: bool = False,
- ) -> Optional["graphviz.Digraph"]: # noqa F821
+ ) -> graphviz.Digraph | None: # noqa F821
"""Visualizes Execution.
Note: overrides are not handled at this time.
@@ -980,9 +982,9 @@ def visualize_execution(
@capture_function_usage
def export_execution(
self,
- final_vars: List[str],
- inputs: Dict[str, Any] = None,
- overrides: Dict[str, Any] = None,
+ final_vars: list[str],
+ inputs: dict[str, Any] = None,
+ overrides: dict[str, Any] = None,
) -> str:
"""Method to create JSON representation of the Graph.
@@ -1002,7 +1004,7 @@ def export_execution(
@capture_function_usage
def has_cycles(
self,
- final_vars: List[Union[str, Callable, Variable]],
+ final_vars: list[str | Callable | Variable],
_fn_graph: graph.FunctionGraph = None,
) -> bool:
"""Checks that the created graph does not have cycles.
@@ -1018,7 +1020,7 @@ def has_cycles(
return self.graph.has_cycles(nodes, user_nodes)
@capture_function_usage
- def what_is_downstream_of(self, *node_names: str) -> List[Variable]:
+ def what_is_downstream_of(self, *node_names: str) -> list[Variable]:
"""Tells you what is downstream of this function(s), i.e. node(s).
:param node_names: names of function(s) that are starting points for traversing the graph.
@@ -1042,7 +1044,7 @@ def display_downstream_of(
show_schema: bool = True,
custom_style_function: Callable = None,
keep_dot: bool = False,
- ) -> Optional["graphviz.Digraph"]: # noqa F821
+ ) -> graphviz.Digraph | None: # noqa F821
"""Creates a visualization of the DAG starting from the passed in function name(s).
Note: for any "node" visualized, we will also add its parents to the visualization as well, so
@@ -1110,7 +1112,7 @@ def display_upstream_of(
show_schema: bool = True,
custom_style_function: Callable = None,
keep_dot: bool = False,
- ) -> Optional["graphviz.Digraph"]: # noqa F821
+ ) -> graphviz.Digraph | None: # noqa F821
"""Creates a visualization of the DAG going backwards from the passed in function name(s).
Note: for any "node" visualized, we will also add its parents to the visualization as well, so
@@ -1160,7 +1162,7 @@ def display_upstream_of(
logger.warning(f"Unable to import {e}", exc_info=True)
@capture_function_usage
- def what_is_upstream_of(self, *node_names: str) -> List[Variable]:
+ def what_is_upstream_of(self, *node_names: str) -> list[Variable]:
"""Tells you what is upstream of this function(s), i.e. node(s).
:param node_names: names of function(s) that are starting points for traversing the graph backwards.
@@ -1173,7 +1175,7 @@ def what_is_upstream_of(self, *node_names: str) -> List[Variable]:
@capture_function_usage
def what_is_the_path_between(
self, upstream_node_name: str, downstream_node_name: str
- ) -> List[Variable]:
+ ) -> list[Variable]:
"""Tells you what nodes are on the path between two nodes.
Note: this is inclusive of the two nodes, and returns an unsorted list of nodes.
@@ -1199,7 +1201,7 @@ def what_is_the_path_between(
def _get_nodes_between(
self, upstream_node_name: str, downstream_node_name: str
- ) -> Set[node.Node]:
+ ) -> set[node.Node]:
"""Gets the nodes representing the path between two nodes, inclusive of the two nodes.
Assumes that the nodes exist in the graph.
@@ -1219,7 +1221,7 @@ def visualize_path_between(
self,
upstream_node_name: str,
downstream_node_name: str,
- output_file_path: Optional[str] = None,
+ output_file_path: str | None = None,
render_kwargs: dict = None,
graphviz_kwargs: dict = None,
strict_path_visualization: bool = False,
@@ -1230,7 +1232,7 @@ def visualize_path_between(
show_schema: bool = True,
custom_style_function: Callable = None,
keep_dot: bool = False,
- ) -> Optional["graphviz.Digraph"]: # noqa F821
+ ) -> graphviz.Digraph | None: # noqa F821
"""Visualizes the path between two nodes.
This is useful for debugging and understanding the path between two nodes.
@@ -1321,8 +1323,8 @@ def visualize_path_between(
logger.warning(f"Unable to import {e}", exc_info=True)
def _process_materializers(
- self, materializers: typing.Sequence[Union[MaterializerFactory, ExtractorFactory]]
- ) -> Tuple[List[MaterializerFactory], List[ExtractorFactory]]:
+ self, materializers: typing.Sequence[MaterializerFactory | ExtractorFactory]
+ ) -> tuple[list[MaterializerFactory], list[ExtractorFactory]]:
"""Processes materializers, splitting them into materializers and extractors.
Note that this also sanitizes the variable names in the materializer dependencies,
so one can pass in functions instead of strings.
@@ -1342,13 +1344,11 @@ def _process_materializers(
@capture_function_usage
def materialize(
self,
- *materializers: Union[
- materialization.MaterializerFactory, materialization.ExtractorFactory
- ],
- additional_vars: List[Union[str, Callable, Variable]] = None,
- overrides: Dict[str, Any] = None,
- inputs: Dict[str, Any] = None,
- ) -> Tuple[Any, Dict[str, Any]]:
+ *materializers: (materialization.MaterializerFactory | materialization.ExtractorFactory),
+ additional_vars: list[str | Callable | Variable] = None,
+ overrides: dict[str, Any] = None,
+ inputs: dict[str, Any] = None,
+ ) -> tuple[Any, dict[str, Any]]:
"""Executes and materializes with ad-hoc materializers (`to`) and extractors (`from_`).This does the following:
1. Creates a new graph, appending the desired materialization nodes and prepending the desired extraction nodes
@@ -1574,13 +1574,13 @@ def materialize(
@capture_function_usage
def visualize_materialization(
self,
- *materializers: Union[MaterializerFactory, ExtractorFactory],
+ *materializers: MaterializerFactory | ExtractorFactory,
output_file_path: str = None,
render_kwargs: dict = None,
- additional_vars: List[Union[str, Callable, Variable]] = None,
- inputs: Dict[str, Any] = None,
+ additional_vars: list[str | Callable | Variable] = None,
+ inputs: dict[str, Any] = None,
graphviz_kwargs: dict = None,
- overrides: Dict[str, Any] = None,
+ overrides: dict[str, Any] = None,
show_legend: bool = True,
orient: str = "LR",
hide_inputs: bool = False,
@@ -1589,7 +1589,7 @@ def visualize_materialization(
custom_style_function: Callable = None,
bypass_validation: bool = False,
keep_dot: bool = False,
- ) -> Optional["graphviz.Digraph"]: # noqa F821
+ ) -> graphviz.Digraph | None: # noqa F821
"""Visualizes materialization. This helps give you a sense of how materialization
will impact the DAG.
@@ -1643,9 +1643,9 @@ def visualize_materialization(
def validate_execution(
self,
- final_vars: List[Union[str, Callable, Variable]],
- overrides: Dict[str, Any] = None,
- inputs: Dict[str, Any] = None,
+ final_vars: list[str | Callable | Variable],
+ overrides: dict[str, Any] = None,
+ inputs: dict[str, Any] = None,
):
"""Validates execution of the graph. One can call this to validate execution, independently of actually executing.
Note this has no return -- it will raise a ValueError if there is an issue.
@@ -1662,9 +1662,9 @@ def validate_execution(
def validate_materialization(
self,
*materializers: materialization.MaterializerFactory,
- additional_vars: List[Union[str, Callable, Variable]] = None,
- overrides: Dict[str, Any] = None,
- inputs: Dict[str, Any] = None,
+ additional_vars: list[str | Callable | Variable] = None,
+ overrides: dict[str, Any] = None,
+ inputs: dict[str, Any] = None,
):
"""Validates materialization of the graph. Effectively .materialize() with a dry-run.
Note this has no return -- it will raise a ValueError if there is an issue.
@@ -1710,7 +1710,7 @@ def __init__(self):
self.legacy_graph_adapter = None
# Standard execution fields
- self.adapters: List[lifecycle_base.LifecycleAdapter] = []
+ self.adapters: list[lifecycle_base.LifecycleAdapter] = []
# Dynamic execution fields
self.execution_manager = None
@@ -1730,7 +1730,7 @@ def _require_field_set(self, field: str, message: str, unset_value: Any = None):
if getattr(self, field) == unset_value:
raise ValueError(message)
- def enable_dynamic_execution(self, *, allow_experimental_mode: bool = False) -> "Builder":
+ def enable_dynamic_execution(self, *, allow_experimental_mode: bool = False) -> Builder:
"""Enables the Parallelizable[] type, which in turn enables:
1. Grouped execution into tasks
2. Parallel execution
@@ -1744,7 +1744,7 @@ def enable_dynamic_execution(self, *, allow_experimental_mode: bool = False) ->
self.v2_executor = True
return self
- def with_config(self, config: Dict[str, Any]) -> "Builder":
+ def with_config(self, config: dict[str, Any]) -> Builder:
"""Adds the specified configuration to the config.
This can be called multiple times -- later calls will take precedence.
@@ -1754,7 +1754,7 @@ def with_config(self, config: Dict[str, Any]) -> "Builder":
self.config.update(config)
return self
- def with_modules(self, *modules: ModuleType) -> "Builder":
+ def with_modules(self, *modules: ModuleType) -> Builder:
"""Adds the specified modules to the modules list.
This can be called multiple times -- later calls will take precedence.
@@ -1764,7 +1764,7 @@ def with_modules(self, *modules: ModuleType) -> "Builder":
self.modules.extend(modules)
return self
- def with_adapter(self, adapter: base.HamiltonGraphAdapter) -> "Builder":
+ def with_adapter(self, adapter: base.HamiltonGraphAdapter) -> Builder:
"""Sets the adapter to use.
:param adapter: Adapter to use.
@@ -1774,7 +1774,7 @@ def with_adapter(self, adapter: base.HamiltonGraphAdapter) -> "Builder":
self.legacy_graph_adapter = adapter
return self
- def with_adapters(self, *adapters: lifecycle_base.LifecycleAdapter) -> "Builder":
+ def with_adapters(self, *adapters: lifecycle_base.LifecycleAdapter) -> Builder:
"""Sets the adapter to use.
:param adapter: Adapter to use.
@@ -1783,9 +1783,7 @@ def with_adapters(self, *adapters: lifecycle_base.LifecycleAdapter) -> "Builder"
self.adapters.extend(adapters)
return self
- def with_materializers(
- self, *materializers: Union[ExtractorFactory, MaterializerFactory]
- ) -> "Builder":
+ def with_materializers(self, *materializers: ExtractorFactory | MaterializerFactory) -> Builder:
"""Add materializer nodes to the `Driver`
The generated nodes can be referenced by name in `.execute()`
@@ -1807,7 +1805,7 @@ def with_materializers(
self.materializers.extend(materializers)
return self
- def with_execution_manager(self, execution_manager: executors.ExecutionManager) -> "Builder":
+ def with_execution_manager(self, execution_manager: executors.ExecutionManager) -> Builder:
"""Sets the execution manager to use. Note that this cannot be used if local_executor
or remote_executor are also set
@@ -1824,7 +1822,7 @@ def with_execution_manager(self, execution_manager: executors.ExecutionManager)
self.execution_manager = execution_manager
return self
- def with_remote_executor(self, remote_executor: executors.TaskExecutor) -> "Builder":
+ def with_remote_executor(self, remote_executor: executors.TaskExecutor) -> Builder:
"""Sets the execution manager to use. Note that this cannot be used if local_executor
or remote_executor are also set
@@ -1840,7 +1838,7 @@ def with_remote_executor(self, remote_executor: executors.TaskExecutor) -> "Buil
self.remote_executor = remote_executor
return self
- def with_local_executor(self, local_executor: executors.TaskExecutor) -> "Builder":
+ def with_local_executor(self, local_executor: executors.TaskExecutor) -> Builder:
"""Sets the execution manager to use. Note that this cannot be used if local_executor
or remote_executor are also set
@@ -1856,7 +1854,7 @@ def with_local_executor(self, local_executor: executors.TaskExecutor) -> "Builde
self.local_executor = local_executor
return self
- def with_grouping_strategy(self, grouping_strategy: grouping.GroupingStrategy) -> "Builder":
+ def with_grouping_strategy(self, grouping_strategy: grouping.GroupingStrategy) -> Builder:
"""Sets a node grouper, which tells the driver how to group nodes into tasks for execution.
:param grouping_strategy: Grouping strategy to use.
@@ -1907,7 +1905,7 @@ def build(self) -> Driver:
_use_legacy_adapter=False,
)
- def copy(self) -> "Builder":
+ def copy(self) -> Builder:
"""Creates a copy of the current state of this Builder.
NOTE. The copied Builder currently holds references to this Builder's attributes
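All the `-> "Builder"` quotes dropped above work because postponed evaluation lets a method name its own class before the class object exists. A toy fluent builder showing the same shape (illustrative names only):

from __future__ import annotations


class QueryBuilder:
    def __init__(self) -> None:
        self.parts: list[str] = []

    def select(self, *cols: str) -> QueryBuilder:  # no quotes needed around the class name
        self.parts.append("SELECT " + ", ".join(cols))
        return self

    def from_(self, table: str) -> QueryBuilder:
        self.parts.append(f"FROM {table}")
        return self

    def build(self) -> str:
        return " ".join(self.parts)


if __name__ == "__main__":
    print(QueryBuilder().select("a", "b").from_("t").build())  # SELECT a, b FROM t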
diff --git a/hamilton/graph.py b/hamilton/graph.py
index c4a18b655..b3143bebd 100644
--- a/hamilton/graph.py
+++ b/hamilton/graph.py
@@ -6,14 +6,15 @@
Note: one should largely consider the code in this module to be "private".
"""
+from __future__ import annotations
+
import inspect
import logging
import os.path
import pathlib
import uuid
from enum import Enum
-from types import ModuleType
-from typing import Any, Callable, Collection, Dict, FrozenSet, List, Optional, Set, Tuple, Type
+from typing import TYPE_CHECKING, Any, Callable, Collection
import hamilton.lifecycle.base as lifecycle_base
from hamilton import graph_types, node
@@ -22,7 +23,11 @@
from hamilton.function_modifiers.metadata import schema
from hamilton.graph_utils import find_functions
from hamilton.htypes import get_type_as_string, types_match
-from hamilton.node import Node
+
+if TYPE_CHECKING:
+ from types import ModuleType
+
+ from hamilton.node import Node
logger = logging.getLogger(__name__)
@@ -39,9 +44,9 @@ class VisualizationNodeModifiers(Enum):
def add_dependency(
func_node: node.Node,
func_name: str,
- nodes: Dict[str, node.Node],
+ nodes: dict[str, node.Node],
param_name: str,
- param_type: Type,
+ param_type: type,
adapter: lifecycle_base.LifecycleAdapterSet,
):
"""Adds dependencies to the node objects.
@@ -114,7 +119,7 @@ def add_dependency(
def update_dependencies(
- nodes: Dict[str, node.Node],
+ nodes: dict[str, node.Node],
adapter: lifecycle_base.LifecycleAdapterSet,
reset_dependencies: bool = True,
):
@@ -143,10 +148,10 @@ def update_dependencies(
def create_function_graph(
*modules: ModuleType,
- config: Dict[str, Any],
+ config: dict[str, Any],
adapter: lifecycle_base.LifecycleAdapterSet = None,
- fg: Optional["FunctionGraph"] = None,
-) -> Dict[str, node.Node]:
+ fg: FunctionGraph | None = None,
+) -> dict[str, node.Node]:
"""Creates a graph of all available functions & their dependencies.
:param modules: A set of modules over which one wants to compute the function graph
:param config: Dictionary that we will inspect to get values from in building the function graph.
@@ -197,10 +202,10 @@ def _check_keyword_args_only(func: Callable) -> bool:
def create_graphviz_graph(
- nodes: Set[node.Node],
+ nodes: set[node.Node],
comment: str,
graphviz_kwargs: dict,
- node_modifiers: Dict[str, Set[VisualizationNodeModifiers]],
+ node_modifiers: dict[str, set[VisualizationNodeModifiers]],
strictly_display_only_nodes_passed_in: bool,
show_legend: bool = True,
orient: str = "LR",
@@ -209,7 +214,7 @@ def create_graphviz_graph(
display_fields: bool = True,
custom_style_function: Callable = None,
config: dict = None,
-) -> "graphviz.Digraph": # noqa: F821
+) -> graphviz.Digraph: # noqa: F821
"""Helper function to create a graphviz graph.
:param nodes: The set of computational nodes
@@ -255,8 +260,8 @@ def create_graphviz_graph(
def _get_node_label(
n: node.Node,
- name: Optional[str] = None,
- type_string: Optional[str] = None,
+ name: str | None = None,
+ type_string: str | None = None,
) -> str:
"""Get a graphviz HTML-like node label. It uses the DAG node
name and type but values can be overridden. Overriding is currently
@@ -270,7 +275,7 @@ def _get_node_label(
return f"<{name}
{type_string}>"
- def _get_input_label(input_nodes: FrozenSet[node.Node]) -> str:
+ def _get_input_label(input_nodes: frozenset[node.Node]) -> str:
"""Get a graphviz HTML-like node label formatted aspyer a table.
Each row is a different input node with one column containing
the name and the other the type.
@@ -301,7 +306,7 @@ def _get_node_type(n: node.Node) -> str:
else:
return "function"
- def _get_node_style(node_type: str) -> Dict[str, str]:
+ def _get_node_style(node_type: str) -> dict[str, str]:
"""Get the style of a node type.
Graphviz needs values to be strings.
"""
@@ -338,7 +343,7 @@ def _get_node_style(node_type: str) -> Dict[str, str]:
return node_style
- def _get_function_modifier_style(modifier: str) -> Dict[str, str]:
+ def _get_function_modifier_style(modifier: str) -> dict[str, str]:
"""Get the style of a modifier. The dictionary returned
is used to overwrite values of the base node style.
Graphviz needs values to be strings.
@@ -364,7 +369,7 @@ def _get_function_modifier_style(modifier: str) -> Dict[str, str]:
return modifier_style
- def _get_edge_style(from_type: str, to_type: str) -> Dict:
+ def _get_edge_style(from_type: str, to_type: str) -> dict:
"""
Graphviz needs values to be strings.
@@ -384,7 +389,7 @@ def _get_edge_style(from_type: str, to_type: str) -> Dict:
return edge_style
def _get_legend(
- node_types: Set[str], extra_legend_nodes: Dict[Tuple[str, str], Dict[str, str]]
+ node_types: set[str], extra_legend_nodes: dict[tuple[str, str], dict[str, str]]
):
"""Create a visualization legend as a graphviz subgraph. The legend includes the
node types and modifiers present in the visualization.
@@ -547,7 +552,7 @@ def _get_legend(
seen_node_types.add("cluster")
seen_node_types.add("field")
- def _create_equal_length_cols(schema_tag: str) -> List[str]:
+ def _create_equal_length_cols(schema_tag: str) -> list[str]:
cols = schema_tag.split(",")
for i in range(len(cols)):
@@ -653,8 +658,8 @@ def _insert_space_after_colon(col: str) -> str:
def create_networkx_graph(
- nodes: Set[node.Node], user_nodes: Set[node.Node], name: str
-) -> "networkx.DiGraph": # noqa: F821
+ nodes: set[node.Node], user_nodes: set[node.Node], name: str
+) -> networkx.DiGraph: # noqa: F821
"""Helper function to create a networkx graph.
:param nodes: The set of computational nodes
@@ -684,8 +689,8 @@ class FunctionGraph:
def __init__(
self,
- nodes: Dict[str, Node],
- config: Dict[str, Any],
+ nodes: dict[str, Node],
+ config: dict[str, Any],
adapter: lifecycle_base.LifecycleAdapterSet = None,
):
"""Initializes a function graph from specified nodes. See note on `from_modules` if you
@@ -705,7 +710,7 @@ def __init__(
@staticmethod
def from_modules(
*modules: ModuleType,
- config: Dict[str, Any],
+ config: dict[str, Any],
adapter: lifecycle_base.LifecycleAdapterSet = None,
):
"""Initializes a function graph from the specified modules. Note that this was the old
@@ -723,7 +728,7 @@ def from_modules(
nodes = create_function_graph(*modules, config=config, adapter=adapter)
return FunctionGraph(nodes, config, adapter)
- def with_nodes(self, nodes: Dict[str, Node]) -> "FunctionGraph":
+ def with_nodes(self, nodes: dict[str, Node]) -> FunctionGraph:
"""Creates a new function graph with the additional specified nodes.
Note that if there is a duplication in the node definitions,
it will error out.
@@ -743,10 +748,10 @@ def config(self):
return self._config
@property
- def decorator_counter(self) -> Dict[str, int]:
+ def decorator_counter(self) -> dict[str, int]:
return fm_base.DECORATOR_COUNTER
- def get_nodes(self) -> List[node.Node]:
+ def get_nodes(self) -> list[node.Node]:
return list(self.nodes.values())
def display_all(
@@ -761,7 +766,7 @@ def display_all(
display_fields: bool = True,
custom_style_function: Callable = None,
keep_dot: bool = False,
- ) -> Optional["graphviz.Digraph"]: # noqa F821
+ ) -> graphviz.Digraph | None: # noqa F821
"""Displays & saves a dot file of the entire DAG structure constructed.
:param output_file_path: the place to save the files.
@@ -807,7 +812,7 @@ def display_all(
keep_dot=keep_dot,
)
- def has_cycles(self, nodes: Set[node.Node], user_nodes: Set[node.Node]) -> bool:
+ def has_cycles(self, nodes: set[node.Node], user_nodes: set[node.Node]) -> bool:
"""Checks that the graph created does not contain cycles.
:param nodes: the set of nodes that need to be computed.
@@ -817,7 +822,7 @@ def has_cycles(self, nodes: Set[node.Node], user_nodes: Set[node.Node]) -> bool:
cycles = self.get_cycles(nodes, user_nodes)
return True if cycles else False
- def get_cycles(self, nodes: Set[node.Node], user_nodes: Set[node.Node]) -> List[List[str]]:
+ def get_cycles(self, nodes: set[node.Node], user_nodes: set[node.Node]) -> list[list[str]]:
"""Returns cycles found in the graph.
:param nodes: the set of nodes that need to be computed.
@@ -838,11 +843,11 @@ def get_cycles(self, nodes: Set[node.Node], user_nodes: Set[node.Node]) -> List[
@staticmethod
def display(
- nodes: Set[node.Node],
- output_file_path: Optional[str] = None,
+ nodes: set[node.Node],
+ output_file_path: str | None = None,
render_kwargs: dict = None,
graphviz_kwargs: dict = None,
- node_modifiers: Dict[str, Set[VisualizationNodeModifiers]] = None,
+ node_modifiers: dict[str, set[VisualizationNodeModifiers]] = None,
strictly_display_only_passed_in_nodes: bool = False,
show_legend: bool = True,
orient: str = "LR",
@@ -852,7 +857,7 @@ def display(
custom_style_function: Callable = None,
config: dict = None,
keep_dot: bool = False,
- ) -> Optional["graphviz.Digraph"]: # noqa F821
+ ) -> graphviz.Digraph | None: # noqa F821
"""Function to display the graph represented by the passed in nodes.
The output file format is determined through the following steps, each one overwriting the previous one:
@@ -932,7 +937,7 @@ def display(
pathlib.Path(output_file_path).write_bytes(dot.pipe(**kwargs))
return dot
- def get_impacted_nodes(self, var_changes: List[str]) -> Set[node.Node]:
+ def get_impacted_nodes(self, var_changes: list[str]) -> set[node.Node]:
"""DEPRECATED - use `get_downstream_nodes` instead."""
logger.warning(
"FunctionGraph.get_impacted_nodes is deprecated. "
@@ -941,7 +946,7 @@ def get_impacted_nodes(self, var_changes: List[str]) -> Set[node.Node]:
)
return self.get_downstream_nodes(var_changes)
- def get_downstream_nodes(self, var_changes: List[str]) -> Set[node.Node]:
+ def get_downstream_nodes(self, var_changes: list[str]) -> set[node.Node]:
"""Given our function graph, and a list of nodes that are changed,
returns the subgraph that they will impact.
@@ -955,10 +960,10 @@ def get_downstream_nodes(self, var_changes: List[str]) -> Set[node.Node]:
def get_upstream_nodes(
self,
- final_vars: List[str],
- runtime_inputs: Dict[str, Any] = None,
- runtime_overrides: Dict[str, Any] = None,
- ) -> Tuple[Set[node.Node], Set[node.Node]]:
+ final_vars: list[str],
+ runtime_inputs: dict[str, Any] = None,
+ runtime_overrides: dict[str, Any] = None,
+ ) -> tuple[set[node.Node], set[node.Node]]:
"""Given our function graph, and a list of desired output variables, returns the subgraph
required to compute them.
@@ -972,7 +977,7 @@ def get_upstream_nodes(
:return: a tuple of sets: - set of all nodes. - subset of nodes that human input is required for.
"""
- def next_nodes_function(n: node.Node) -> List[node.Node]:
+ def next_nodes_function(n: node.Node) -> list[node.Node]:
deps = []
if runtime_overrides is not None and n.name in runtime_overrides:
return deps
@@ -996,7 +1001,7 @@ def next_nodes_function(n: node.Node) -> List[node.Node]:
next_nodes_function, starting_nodes=final_vars, runtime_inputs=runtime_inputs
)
- def nodes_between(self, start: str, end: str) -> Set[node.Node]:
+ def nodes_between(self, start: str, end: str) -> set[node.Node]:
"""Given our function graph, and a list of desired output variables, returns the subgraph
required to compute them. Note that this returns an empty set if no path exists.
@@ -1018,8 +1023,8 @@ def nodes_between(self, start: str, end: str) -> Set[node.Node]:
def directional_dfs_traverse(
self,
next_nodes_fn: Callable[[node.Node], Collection[node.Node]],
- starting_nodes: List[str],
- runtime_inputs: Dict[str, Any] = None,
+ starting_nodes: list[str],
+ runtime_inputs: dict[str, Any] = None,
):
"""Traverses the DAG directionally using a DFS.
@@ -1061,11 +1066,11 @@ def dfs_traverse(node: node.Node):
def execute(
self,
nodes: Collection[node.Node] = None,
- computed: Dict[str, Any] = None,
- overrides: Dict[str, Any] = None,
- inputs: Dict[str, Any] = None,
+ computed: dict[str, Any] = None,
+ overrides: dict[str, Any] = None,
+ inputs: dict[str, Any] = None,
run_id: str = None,
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
"""Executes the DAG, given potential inputs/previously computed components.
:param nodes: Nodes to compute
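The new `if TYPE_CHECKING:` blocks keep annotation-only imports (`ModuleType`, `Node`) out of the runtime import graph, which avoids import cycles and import-time cost. The pattern in isolation:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only the type checker evaluates this block; at runtime the module
    # is never imported, which sidesteps cycles and import cost.
    from types import ModuleType


def module_name(mod: ModuleType) -> str:
    # The annotation is just the string "ModuleType" at runtime.
    return mod.__name__


if __name__ == "__main__":
    import math

    print(module_name(math))            # math
    print(module_name.__annotations__)  # {'mod': 'ModuleType', 'return': 'str'}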
diff --git a/hamilton/graph_types.py b/hamilton/graph_types.py
index 978f8afed..c55d792d4 100644
--- a/hamilton/graph_types.py
+++ b/hamilton/graph_types.py
@@ -1,5 +1,7 @@
"""Module for external-facing graph constructs. These help the user navigate/manage the graph as needed."""
+from __future__ import annotations
+
import ast
import functools
import hashlib
@@ -62,7 +64,7 @@ def _remove_docs_and_comments(source: str) -> str:
return ast.unparse(parsed)
-def hash_source_code(source: typing.Union[str, typing.Callable], strip: bool = False) -> str:
+def hash_source_code(source: str | typing.Callable, strip: bool = False) -> str:
"""Hashes the source code of a function (str).
The `strip` parameter requires Python 3.9
@@ -93,13 +95,13 @@ class HamiltonNode:
Furthermore, we can always add attributes and maintain backwards compatibility."""
name: str
- type: typing.Type
- tags: typing.Dict[str, typing.Union[str, typing.List[str]]]
+ type: type
+ tags: dict[str, str | list[str]]
is_external_input: bool
- originating_functions: typing.Optional[typing.Tuple[typing.Callable, ...]]
- documentation: typing.Optional[str]
- required_dependencies: typing.Set[str]
- optional_dependencies: typing.Set[str]
+ originating_functions: tuple[typing.Callable, ...] | None
+ documentation: str | None
+ required_dependencies: set[str]
+ optional_dependencies: set[str]
def as_dict(self):
"""Create a dictionary representation of the Node that is JSON serializable"""
@@ -119,7 +121,7 @@ def as_dict(self):
}
@staticmethod
- def from_node(n: node.Node) -> "HamiltonNode":
+ def from_node(n: node.Node) -> HamiltonNode:
"""Creates a HamiltonNode from a Node (Hamilton's internal representation).
:param n: Node to create the Variable from.
@@ -145,7 +147,7 @@ def from_node(n: node.Node) -> "HamiltonNode":
)
@functools.cached_property
- def version(self) -> typing.Optional[str]:
+ def version(self) -> str | None:
"""Generate a hash of the node originating function source code.
Note that this will be `None` if the node is an external input/has no
@@ -184,11 +186,11 @@ class HamiltonGraph:
Note that you do not construct this class directly -- instead, you will get this at various points in the API.
"""
- nodes: typing.List[HamiltonNode]
+ nodes: list[HamiltonNode]
# store the original graph for internal use
@staticmethod
- def from_graph(fn_graph: "graph.FunctionGraph") -> "HamiltonGraph":
+ def from_graph(fn_graph: graph.FunctionGraph) -> HamiltonGraph:
"""Creates a HamiltonGraph from a FunctionGraph (Hamilton's internal representation).
:param fn_graph: FunctionGraph to convert
@@ -209,7 +211,7 @@ def version(self) -> str:
return hashlib.sha256(str(sorted_node_versions).encode()).hexdigest()
@functools.cached_property
- def __nodes_lookup(self) -> typing.Dict[str, HamiltonNode]:
+ def __nodes_lookup(self) -> dict[str, HamiltonNode]:
"""Cache the mapping {node_name: node} for faster `__getitem__`"""
return {n.name: n for n in self.nodes}
@@ -221,8 +223,6 @@ def __getitem__(self, key: str) -> HamiltonNode:
"""
return self.__nodes_lookup[key]
- def filter_nodes(
- self, filter: typing.Callable[[HamiltonNode], bool]
- ) -> typing.List[HamiltonNode]:
+ def filter_nodes(self, filter: typing.Callable[[HamiltonNode], bool]) -> list[HamiltonNode]:
"""Return Hamilton nodes matching the filter criteria"""
return [n for n in self.nodes if filter(n) is True]
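A standalone sketch of the idea behind `hash_source_code`: fetch a callable's source with `inspect`, optionally round-trip it through the AST so comments and formatting stop affecting the digest (the `_remove_docs_and_comments` helper above is more thorough), then hash. Helper names here are illustrative.

from __future__ import annotations

import ast
import hashlib
import inspect
import typing


def hash_source(source: str | typing.Callable, strip: bool = False) -> str:
    if callable(source):
        source = inspect.getsource(source)
    if strip:
        # Round-tripping through the AST discards comments and formatting,
        # so cosmetic edits stop changing the hash (ast.unparse needs 3.9+).
        source = ast.unparse(ast.parse(source))
    return hashlib.sha256(source.encode()).hexdigest()


def f(x: int) -> int:
    # this comment disappears when strip=True
    return x + 1


if __name__ == "__main__":
    print(hash_source(f)[:12])
    print(hash_source(f, strip=True)[:12])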
diff --git a/hamilton/graph_utils.py b/hamilton/graph_utils.py
index 8c3695b4d..32219ece9 100644
--- a/hamilton/graph_utils.py
+++ b/hamilton/graph_utils.py
@@ -1,13 +1,17 @@
+from __future__ import annotations
+
import inspect
-from types import ModuleType
-from typing import Callable, List, Tuple
+from typing import TYPE_CHECKING, Callable
+
+if TYPE_CHECKING:
+ from types import ModuleType
def is_submodule(child: ModuleType, parent: ModuleType):
return parent.__name__ in child.__name__
-def find_functions(function_module: ModuleType) -> List[Tuple[str, Callable]]:
+def find_functions(function_module: ModuleType) -> list[tuple[str, Callable]]:
"""Function to determine the set of functions we want to build a graph from.
This iterates through the function module and grabs all function definitions.
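For context on `find_functions`, a sketch of how such a helper can be written with `inspect.getmembers`; the exact filtering (e.g., skipping underscore-prefixed names) is an assumption, not Hamilton's precise rule.

from __future__ import annotations

import inspect
import json
from types import ModuleType
from typing import Callable


def find_module_functions(mod: ModuleType) -> list[tuple[str, Callable]]:
    # Keep functions actually defined in this module, skipping private ones.
    return [
        (name, fn)
        for name, fn in inspect.getmembers(mod, inspect.isfunction)
        if fn.__module__ == mod.__name__ and not name.startswith("_")
    ]


if __name__ == "__main__":
    print([name for name, _ in find_module_functions(json)])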
diff --git a/hamilton/htypes.py b/hamilton/htypes.py
index 195bbe077..e3d76d277 100644
--- a/hamilton/htypes.py
+++ b/hamilton/htypes.py
@@ -1,8 +1,10 @@
+from __future__ import annotations
+
import inspect
import sys
import typing
from abc import ABC
-from typing import Any, Generator, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, Generator, TypeVar, Union
import typing_inspect
@@ -15,7 +17,7 @@
BASE_ARGS_FOR_GENERICS = (typing.T,)
-def _safe_subclass(candidate_type: Type, base_type: Type) -> bool:
+def _safe_subclass(candidate_type: type, base_type: type) -> bool:
"""Safely checks subclass, returning False if python's subclass does not work.
This is *not* a true subclass check, and will not tell you whether hamilton
considers the types to be equivalent. Rather, it is used to short-circuit further
@@ -36,7 +38,7 @@ def _safe_subclass(candidate_type: Type, base_type: Type) -> bool:
return False
-def custom_subclass_check(requested_type: Type, param_type: Type):
+def custom_subclass_check(requested_type: type, param_type: type):
"""This is a custom check around generics & classes. It probably misses a few edge cases.
We will likely need to revisit this in the future (perhaps integrate with graphadapter?)
@@ -92,7 +94,7 @@ def custom_subclass_check(requested_type: Type, param_type: Type):
return False
-def get_type_as_string(type_: Type) -> Optional[str]:
+def get_type_as_string(type_: type) -> str | None:
"""Get a string representation of a type.
The logic supports the evolution of the type system between 3.8 and 3.10.
@@ -113,7 +115,7 @@ def get_type_as_string(type_: Type) -> Optional[str]:
return type_string
-def types_match(param_type: Type[Type], required_node_type: Any) -> bool:
+def types_match(param_type: type[type], required_node_type: Any) -> bool:
"""Checks that we have "types" that "match".
Matching can be loose here -- and depends on the adapter being used as to what is
@@ -165,26 +167,21 @@ def types_match(param_type: Type[Type], required_node_type: Any) -> bool:
# Before 3.9 we use typing_extensions
import typing_extensions
- column = typing_extensions.Annotated
+ _get_origin = typing_extensions.get_origin
+ _get_args = typing_extensions.get_args
+ column = typing_extensions.Annotated
else:
ANNOTATE_ALLOWED = True
- from typing import Annotated, Type
-
- column = Annotated
-
-if _version_tuple < (3, 9, 0):
- import typing_extensions
-
- _get_origin = typing_extensions.get_origin
- _get_args = typing_extensions.get_args
-else:
+ from typing import Annotated
from typing import get_args as _get_args
from typing import get_origin as _get_origin
+ column = Annotated
+
-def _is_annotated_type(type_: Type[Type]) -> bool:
+def _is_annotated_type(type_: type[type]) -> bool:
"""Utility function to tell if a type is Annotated"""
return _get_origin(type_) == column
@@ -204,7 +201,7 @@ class InvalidTypeException(Exception):
)
-def _is_valid_series_type(candidate_type: Type[Type]) -> bool:
+def _is_valid_series_type(candidate_type: type[type]) -> bool:
"""Tells if something is a valid series type, using the registry we have.
:param candidate_type: Type to check
@@ -218,7 +215,7 @@ def _is_valid_series_type(candidate_type: Type[Type]) -> bool:
return False
-def validate_type_annotation(annotation: Type[Type]):
+def validate_type_annotation(annotation: type[type]):
"""Validates a type annotation for a hamilton function.
If it is not an Annotated type, it will be fine.
If it is the Annotated type, it will check that
@@ -253,7 +250,7 @@ def validate_type_annotation(annotation: Type[Type]):
)
-def get_type_information(some_type: Any) -> Tuple[Type[Type], list]:
+def get_type_information(some_type: Any) -> tuple[type[type], list]:
"""Gets the type information for a given type.
If it is an annotated type, it will return the original type and the annotation.
@@ -283,7 +280,7 @@ class Parallelizable(typing.Generator[U, None, None], ABC):
pass
-def is_parallelizable_type(type_: Type) -> bool:
+def is_parallelizable_type(type_: type) -> bool:
return issubclass(type_, Parallelizable)
@@ -291,7 +288,7 @@ class Collect(Generator[V, None, None], ABC):
pass
-def check_input_type(node_type: Type, input_value: Any) -> bool:
+def check_input_type(node_type: type, input_value: Any) -> bool:
"""Checks an input value against the declare input type. This is a utility function to be
used for checking types against values. Note we are looser here than in custom_subclass_check,
as runtime-typing is less specific.
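The restructured block above collapses two `sys.version_info` checks into one branch. The general shape, sketched standalone (assumes `typing_extensions` is installed on pre-3.9 interpreters):

from __future__ import annotations

import sys

if sys.version_info < (3, 9):
    # One branch now covers everything the pre-3.9 path needs.
    from typing_extensions import Annotated, get_args, get_origin
else:
    from typing import Annotated, get_args, get_origin

Seconds = Annotated[int, "seconds"]

if __name__ == "__main__":
    print(get_origin(Seconds) is Annotated)  # True
    print(get_args(Seconds))                 # (<class 'int'>, 'seconds')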
diff --git a/hamilton/models.py b/hamilton/models.py
index 224c12836..f62bc8f12 100644
--- a/hamilton/models.py
+++ b/hamilton/models.py
@@ -1,7 +1,10 @@
+from __future__ import annotations
+
import abc
-from typing import Any, Dict, List
+from typing import TYPE_CHECKING, Any
-import pandas as pd
+if TYPE_CHECKING:
+ import pandas as pd
class DynamicTransformBase(abc.ABC):
@@ -19,7 +22,7 @@ def __init__(self, config_parameters: Any, name: str):
self._name = name
@abc.abstractmethod
- def get_dependents(self) -> List[str]:
+ def get_dependents(self) -> list[str]:
"""Gets the names/types of the inputs to this transform.
:return: A list of columns on which this model depends.
"""
@@ -34,7 +37,7 @@ def compute(self, **inputs: Any) -> Any:
pass
@property
- def config_parameters(self) -> Dict[str, Any]:
+ def config_parameters(self) -> dict[str, Any]:
"""Accessor for configuration parameters"""
return self._config_parameters
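Moving `pandas` behind `TYPE_CHECKING` means importing `hamilton.models` no longer pays the pandas import cost; the `from __future__ import annotations` line is what keeps the annotations legal, since they are stored as strings and never evaluated at runtime. A minimal sketch of the pattern (the `as_series` function is illustrative, not from this module):

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        import pandas as pd  # evaluated by type checkers only, never at runtime

    def as_series(values: list) -> pd.Series:  # stored as the string "pd.Series"
        import pandas as pd  # real import deferred until the function runs

        return pd.Series(values)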
diff --git a/hamilton/node.py b/hamilton/node.py
index 2c6e5c73f..e895852f3 100644
--- a/hamilton/node.py
+++ b/hamilton/node.py
@@ -1,8 +1,10 @@
+from __future__ import annotations
+
import inspect
import sys
import typing
from enum import Enum
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Callable
import typing_inspect
@@ -41,20 +43,20 @@ def from_parameter(param: inspect.Parameter):
return DependencyType.OPTIONAL
-class Node(object):
+class Node:
"""Object representing a node of computation."""
def __init__(
self,
name: str,
- typ: Type,
+ typ: type,
doc_string: str = "",
callabl: Callable = None,
node_source: NodeType = NodeType.STANDARD,
- input_types: Dict[str, Union[Type, Tuple[Type, DependencyType]]] = None,
- tags: Dict[str, Any] = None,
- namespace: Tuple[str, ...] = (),
- originating_functions: Optional[Tuple[Callable, ...]] = None,
+ input_types: dict[str, type | tuple[type, DependencyType]] = None,
+ tags: dict[str, Any] = None,
+ namespace: tuple[str, ...] = (),
+ originating_functions: tuple[Callable, ...] | None = None,
):
"""Constructor for our Node object.
@@ -128,7 +130,7 @@ def collect_dependency(self) -> str:
return key
@property
- def namespace(self) -> Tuple[str, ...]:
+ def namespace(self) -> tuple[str, ...]:
return self._namespace
@property
@@ -136,7 +138,7 @@ def documentation(self) -> str:
return self._doc
@property
- def input_types(self) -> Dict[Any, Tuple[Any, DependencyType]]:
+ def input_types(self) -> dict[Any, tuple[Any, DependencyType]]:
return self._input_types
def requires(self, dependency: str) -> bool:
@@ -177,19 +179,19 @@ def node_role(self):
return self._node_source
@property
- def dependencies(self) -> List["Node"]:
+ def dependencies(self) -> list[Node]:
return self._dependencies
@property
- def depended_on_by(self) -> List["Node"]:
+ def depended_on_by(self) -> list[Node]:
return self._depended_on_by
@property
- def tags(self) -> Dict[str, str]:
+ def tags(self) -> dict[str, str]:
return self._tags
@property
- def originating_functions(self) -> Optional[Tuple[Callable, ...]]:
+ def originating_functions(self) -> tuple[Callable, ...] | None:
"""Gives all functions from which this node was created. None if the data
is not available (it is user-defined, or we have not added it yet). Note that this can be
multiple in the case of subdags (the subdag function + the other function). In that case,
@@ -222,7 +224,7 @@ def __hash__(self):
def __repr__(self):
return f"<{self.name} {self._tags}>"
- def __eq__(self, other: "Node"):
+ def __eq__(self, other: Node):
"""Want to deeply compare nodes in a custom way.
Current user is just unit tests. But you never know :)
@@ -241,7 +243,7 @@ def __eq__(self, other: "Node"):
and self.node_role == other.node_role
)
- def __ne__(self, other: "Node"):
+ def __ne__(self, other: Node):
return not self.__eq__(other)
def __call__(self, *args, **kwargs):
@@ -249,7 +251,7 @@ def __call__(self, *args, **kwargs):
return self.callable(*args, **kwargs)
@staticmethod
- def from_fn(fn: Callable, name: str = None) -> "Node":
+ def from_fn(fn: Callable, name: str = None) -> Node:
"""Generates a node from a function. Optionally overrides the name.
Note that currently, the `originating_function` is externally passed in -- this
@@ -292,7 +294,7 @@ def from_fn(fn: Callable, name: str = None) -> "Node":
node_source=node_source,
)
- def copy_with(self, include_refs: bool = True, **overrides) -> "Node":
+ def copy_with(self, include_refs: bool = True, **overrides) -> Node:
"""Copies a node with the specified overrides for the constructor arguments.
Utility function for creating a node -- useful for modifying it.
@@ -317,7 +319,7 @@ def copy_with(self, include_refs: bool = True, **overrides) -> "Node":
out._depended_on_by = self._depended_on_by
return out
- def copy(self, include_refs: bool = True) -> "Node":
+ def copy(self, include_refs: bool = True) -> Node:
"""Copies a node, not modifying anything (except for the references
/dependencies if specified).
@@ -330,8 +332,8 @@ def copy(self, include_refs: bool = True) -> "Node":
return self.copy_with(include_refs)
def reassign_inputs(
- self, input_names: Dict[str, Any] = None, input_values: Dict[str, Any] = None
- ) -> "Node":
+ self, input_names: dict[str, Any] = None, input_values: dict[str, Any] = None
+ ) -> Node:
"""Reassigns the input names of a node. Useful for applying
a node to a separate input if needed. Note that things can get a
little strange if you have multiple inputs with the same name, so
@@ -357,8 +359,8 @@ def new_callable(**kwargs) -> Any:
return out
def transform_output(
- self, __transform: Callable[[Dict[str, Any], Any], Any], __output_type: Type[Any]
- ) -> "Node":
+ self, __transform: Callable[[dict[str, Any], Any], Any], __output_type: type[Any]
+ ) -> Node:
"""Applies a transformation on the output of the node, returning a new node.
Also modifies the type.
@@ -375,7 +377,7 @@ def new_callable(**kwargs) -> Any:
def matches_query(
- tags: Dict[str, Union[str, List[str]]], query_dict: Dict[str, Optional[Union[str, List[str]]]]
+ tags: dict[str, str | list[str]], query_dict: dict[str, str | list[str] | None]
) -> bool:
"""Check whether a set of node tags matches the query based on tags.
diff --git a/hamilton/plugins/h_mlflow.py b/hamilton/plugins/h_mlflow.py
index 2d77815a0..0583613e2 100644
--- a/hamilton/plugins/h_mlflow.py
+++ b/hamilton/plugins/h_mlflow.py
@@ -1,14 +1,18 @@
+from __future__ import annotations
+
import logging
import pickle
import warnings
-from typing import Any, Dict, List, Optional, Type, Union
+from typing import TYPE_CHECKING, Any
import mlflow
import mlflow.data
-from hamilton import graph_types
from hamilton.lifecycle import GraphConstructionHook, GraphExecutionHook, NodeExecutionHook
+if TYPE_CHECKING:
+ from hamilton import graph_types
+
# silence odd ongoing MLFlow issue that spams warnings
# GitHub Issue https://github.com/mlflow/mlflow/issues/8605
warnings.filterwarnings("ignore", category=UserWarning)
@@ -33,7 +37,7 @@
logger = logging.getLogger(__name__)
-def get_path_from_metadata(metadata: dict) -> Union[str, None]:
+def get_path_from_metadata(metadata: dict) -> str | None:
"""Retrieve the `path` attribute from DataSaver output metadata"""
path = None
if "path" in metadata:
@@ -57,16 +61,16 @@ class MLFlowTracker(
def __init__(
self,
- tracking_uri: Optional[str] = None,
- registry_uri: Optional[str] = None,
- artifact_location: Optional[str] = None,
+ tracking_uri: str | None = None,
+ registry_uri: str | None = None,
+ artifact_location: str | None = None,
experiment_name: str = "Hamilton",
- experiment_tags: Optional[dict] = None,
- experiment_description: Optional[str] = None,
- run_id: Optional[str] = None,
- run_name: Optional[str] = None,
- run_tags: Optional[dict] = None,
- run_description: Optional[str] = None,
+ experiment_tags: dict | None = None,
+ experiment_description: str | None = None,
+ run_id: str | None = None,
+ run_name: str | None = None,
+ run_tags: dict | None = None,
+ run_description: str | None = None,
log_system_metrics: bool = False,
):
"""Configure the MLFlow client and experiment for the lifetime of the tracker
@@ -127,8 +131,8 @@ def run_before_graph_execution(
self,
*,
run_id: str,
- final_vars: List[str],
- inputs: Dict[str, Any],
+ final_vars: list[str],
+ inputs: dict[str, Any],
graph: graph_types.HamiltonGraph,
**kwargs,
):
@@ -175,7 +179,7 @@ def run_after_node_execution(
self,
*,
node_name: str,
- node_return_type: Type,
+ node_return_type: type,
node_tags: dict,
node_kwargs: dict,
result: Any,
diff --git a/hamilton/plugins/h_ray.py b/hamilton/plugins/h_ray.py
index 0f0692803..8af915fda 100644
--- a/hamilton/plugins/h_ray.py
+++ b/hamilton/plugins/h_ray.py
@@ -1,17 +1,22 @@
+from __future__ import annotations
+
import functools
import json
import logging
import typing
+from typing import TYPE_CHECKING
import ray
from ray import workflow
from hamilton import base, htypes, node
from hamilton.execution import executors
-from hamilton.execution.executors import TaskFuture
-from hamilton.execution.grouping import TaskImplementation
from hamilton.function_modifiers.metadata import RAY_REMOTE_TAG_NAMESPACE
+if TYPE_CHECKING:
+ from hamilton.execution.executors import TaskFuture
+ from hamilton.execution.grouping import TaskImplementation
+
logger = logging.getLogger(__name__)
@@ -31,7 +36,7 @@ def new_fn(*args, **kwargs):
return fn
-def parse_ray_remote_options_from_tags(tags: typing.Dict[str, str]) -> typing.Dict[str, typing.Any]:
+def parse_ray_remote_options_from_tags(tags: dict[str, str]) -> dict[str, typing.Any]:
"""DRY helper to parse ray.remote(**options) from Hamilton Tags
Tags are added to nodes via the @ray_remote_options decorator
@@ -101,17 +106,17 @@ def __init__(self, result_builder: base.ResultMixin):
)
@staticmethod
- def check_input_type(node_type: typing.Type, input_value: typing.Any) -> bool:
+ def check_input_type(node_type: type, input_value: typing.Any) -> bool:
# NOTE: the type of a raylet is unknown until they are computed
if isinstance(input_value, ray._raylet.ObjectRef):
return True
return htypes.check_input_type(node_type, input_value)
@staticmethod
- def check_node_type_equivalence(node_type: typing.Type, input_type: typing.Type) -> bool:
+ def check_node_type_equivalence(node_type: type, input_type: type) -> bool:
return node_type == input_type
- def execute_node(self, node: node.Node, kwargs: typing.Dict[str, typing.Any]) -> typing.Any:
+ def execute_node(self, node: node.Node, kwargs: dict[str, typing.Any]) -> typing.Any:
"""Function that is called as we walk the graph to determine how to execute a hamilton function.
:param node: the node from the graph.
@@ -121,7 +126,7 @@ def execute_node(self, node: node.Node, kwargs: typing.Dict[str, typing.Any]) ->
ray_options = parse_ray_remote_options_from_tags(node.tags)
return ray.remote(raify(node.callable)).options(**ray_options).remote(**kwargs)
- def build_result(self, **outputs: typing.Dict[str, typing.Any]) -> typing.Any:
+ def build_result(self, **outputs: dict[str, typing.Any]) -> typing.Any:
"""Builds the result and brings it back to this running process.
:param outputs: the dictionary of key -> Union[ray object reference, value]
@@ -195,17 +200,17 @@ def __init__(self, result_builder: base.ResultMixin, workflow_id: str):
)
@staticmethod
- def check_input_type(node_type: typing.Type, input_value: typing.Any) -> bool:
+ def check_input_type(node_type: type, input_value: typing.Any) -> bool:
# NOTE: the type of a raylet is unknown until they are computed
if isinstance(input_value, ray._raylet.ObjectRef):
return True
return htypes.check_input_type(node_type, input_value)
@staticmethod
- def check_node_type_equivalence(node_type: typing.Type, input_type: typing.Type) -> bool:
+ def check_node_type_equivalence(node_type: type, input_type: type) -> bool:
return node_type == input_type
- def execute_node(self, node: node.Node, kwargs: typing.Dict[str, typing.Any]) -> typing.Any:
+ def execute_node(self, node: node.Node, kwargs: dict[str, typing.Any]) -> typing.Any:
"""Function that is called as we walk the graph to determine how to execute a hamilton function.
:param node: the node from the graph.
@@ -215,7 +220,7 @@ def execute_node(self, node: node.Node, kwargs: typing.Dict[str, typing.Any]) ->
ray_options = parse_ray_remote_options_from_tags(node.tags)
return ray.remote(raify(node.callable)).options(**ray_options).bind(**kwargs)
- def build_result(self, **outputs: typing.Dict[str, typing.Any]) -> typing.Any:
+ def build_result(self, **outputs: dict[str, typing.Any]) -> typing.Any:
"""Builds the result and brings it back to this running process.
:param outputs: the dictionary of key -> Union[ray object reference, value]
@@ -240,7 +245,7 @@ class RayTaskExecutor(executors.TaskExecutor):
def __init__(
self,
num_cpus: int = None,
- ray_init_config: typing.Dict[str, typing.Any] = None,
+ ray_init_config: dict[str, typing.Any] = None,
skip_init: bool = False,
):
"""Creates a ray task executor. Note this will likely take in more parameters. This is
diff --git a/hamilton/registry.py b/hamilton/registry.py
index a3d24aec7..27f0c4c74 100644
--- a/hamilton/registry.py
+++ b/hamilton/registry.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import collections
import configparser
import functools
@@ -5,7 +7,7 @@
import logging
import os
import pathlib
-from typing import Any, Dict, Literal, Optional, Tuple, Type, get_args
+from typing import Any, Literal, get_args
logger = logging.getLogger(__name__)
@@ -33,14 +35,14 @@
"huggingface",
"mlflow",
]
-HAMILTON_EXTENSIONS: Tuple[ExtensionName, ...] = get_args(ExtensionName)
+HAMILTON_EXTENSIONS: tuple[ExtensionName, ...] = get_args(ExtensionName)
HAMILTON_AUTOLOAD_ENV = "HAMILTON_AUTOLOAD_EXTENSIONS"
# NOTE the variable DEFAULT_CONFIG_LOCATION is redundant with `hamilton.telemetry`
# but this `registry` module must avoid circular imports
DEFAULT_CONFIG_LOCATION = pathlib.Path("~/.hamilton.conf").expanduser()
# This is a dictionary of extension name -> dict with dataframe and column types.
-DF_TYPE_AND_COLUMN_TYPES: Dict[str, Dict[str, Type]] = {}
+DF_TYPE_AND_COLUMN_TYPES: dict[str, dict[str, type]] = {}
COLUMN_TYPE = "column_type"
DATAFRAME_TYPE = "dataframe_type"
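Because `HAMILTON_EXTENSIONS` is derived via `get_args`, the tuple stays in sync with the `ExtensionName` literal automatically. Illustratively, with an abbreviated alias:

    from __future__ import annotations

    from typing import Literal, get_args

    ExtensionName = Literal["pandas", "polars"]  # the real alias lists many more
    HAMILTON_EXTENSIONS: tuple[ExtensionName, ...] = get_args(ExtensionName)
    assert HAMILTON_EXTENSIONS == ("pandas", "polars")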
@@ -148,7 +150,7 @@ def config_disable_autoload():
config.write(f)
-def register_types(extension_name: str, dataframe_type: Type, column_type: Optional[Type]):
+def register_types(extension_name: str, dataframe_type: type, column_type: type | None):
"""Registers the dataframe and column types for the extension. Note that column types are optional
as some extensions may not have a column type (e.g. spark). In this case, this is not included
@@ -189,7 +191,7 @@ def fill_with_scalar(df: Any, column_name: str, scalar_value: Any) -> Any:
raise NotImplementedError()
-def get_column_type_from_df_type(dataframe_type: Type) -> Type:
+def get_column_type_from_df_type(dataframe_type: type) -> type:
"""Function to cycle through the registered extensions and return the column type for the dataframe type.
:param dataframe_type: the dataframe type to find the column type for.
@@ -221,7 +223,7 @@ def register_adapter(adapter: Any):
SAVER_REGISTRY[adapter.name()].append(adapter)
-def get_registered_dataframe_types() -> Dict[str, Type]:
+def get_registered_dataframe_types() -> dict[str, type]:
"""Returns a dictionary of extension name -> dataframe type.
:return: the dictionary.
@@ -232,7 +234,7 @@ def get_registered_dataframe_types() -> Dict[str, Type]:
}
-def get_registered_column_types() -> Dict[str, Type]:
+def get_registered_column_types() -> dict[str, type]:
"""Returns a dictionary of extension name -> column type.
:return: the dictionary.
diff --git a/pyproject.toml b/pyproject.toml
index 5434a7ce5..89cc77711 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -206,7 +206,7 @@ extend-select = [
"TCH", # Move type-only imports to a type-checking block.
"TID", # Helps you write tidier imports.
# "TRY", # Prevent exception handling anti-patterns
-# "UP", # pyupgrade
+ "UP", # pyupgrade
"W", # pycodestyle warnings
]
extend-ignore = [
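Enabling `UP` turns on Ruff's pyupgrade rules, which is what most of the mechanical rewrites in this diff correspond to (e.g. UP006 for PEP 585 builtin generics, UP007 for PEP 604 unions). A representative before/after, illustratively:

    from __future__ import annotations

    from typing import Dict, Optional

    # Before: flagged by UP006/UP007.
    def lookup(d: Dict[str, int], key: str) -> Optional[int]:
        return d.get(key)

    # After `ruff check --fix`:
    def lookup(d: dict[str, int], key: str) -> int | None:
        return d.get(key)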
diff --git a/ui/sdk/.pre-commit-config.yaml b/ui/sdk/.pre-commit-config.yaml
index 75a5287c3..da5107f3d 100644
--- a/ui/sdk/.pre-commit-config.yaml
+++ b/ui/sdk/.pre-commit-config.yaml
@@ -16,6 +16,11 @@ repos:
hooks:
- id: black
args: [--line-length=100]
+- repo: https://github.com/asottile/pyupgrade
+ rev: v3.17.0
+ hooks:
+ - id: pyupgrade
+ args: [--py38-plus]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
diff --git a/ui/sdk/src/hamilton_sdk/adapters.py b/ui/sdk/src/hamilton_sdk/adapters.py
index 995143e64..b553786a3 100644
--- a/ui/sdk/src/hamilton_sdk/adapters.py
+++ b/ui/sdk/src/hamilton_sdk/adapters.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import datetime
import hashlib
import logging
@@ -6,7 +8,7 @@
import traceback
from datetime import timezone
from types import ModuleType
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable
from hamilton_sdk import driver
from hamilton_sdk.api import clients, constants
@@ -22,7 +24,7 @@
logger = logging.getLogger(__name__)
-def get_node_name(node_: node.Node, task_id: Optional[str]) -> str:
+def get_node_name(node_: node.Node, task_id: str | None) -> str:
if task_id is not None:
return f"{task_id}-{node_.name}"
return node_.name
@@ -43,14 +45,14 @@ def __init__(
project_id: int,
username: str,
dag_name: str,
- tags: Dict[str, str] = None,
+ tags: dict[str, str] = None,
client_factory: Callable[
- [str, str, str, Union[str, bool]], clients.HamiltonClient
+ [str, str, str, str | bool], clients.HamiltonClient
] = clients.BasicSynchronousHamiltonClient,
api_key: str = None,
hamilton_api_url=os.environ.get("HAMILTON_API_URL", constants.HAMILTON_API_URL),
hamilton_ui_url=os.environ.get("HAMILTON_UI_URL", constants.HAMILTON_UI_URL),
- verify: Union[str, bool] = True,
+ verify: str | bool = True,
):
"""This hooks into Hamilton execution to track DAG runs in Hamilton UI.
@@ -104,7 +106,7 @@ def __init__(
self.seed = None
def post_graph_construct(
- self, graph: h_graph.FunctionGraph, modules: List[ModuleType], config: Dict[str, Any]
+ self, graph: h_graph.FunctionGraph, modules: list[ModuleType], config: dict[str, Any]
):
"""Registers the DAG to get an ID."""
if self.seed is None:
@@ -138,9 +140,9 @@ def pre_graph_execute(
self,
run_id: str,
graph: h_graph.FunctionGraph,
- final_vars: List[str],
- inputs: Dict[str, Any],
- overrides: Dict[str, Any],
+ final_vars: list[str],
+ inputs: dict[str, Any],
+ overrides: dict[str, Any],
):
"""Creates a DAG run."""
logger.debug("pre_graph_execute %s", run_id)
@@ -167,7 +169,7 @@ def pre_graph_execute(
return dw_run_id
def pre_node_execute(
- self, run_id: str, node_: node.Node, kwargs: Dict[str, Any], task_id: Optional[str] = None
+ self, run_id: str, node_: node.Node, kwargs: dict[str, Any], task_id: str | None = None
):
"""Captures start of node execution."""
logger.debug("pre_node_execute %s %s", run_id, task_id)
@@ -200,8 +202,8 @@ def pre_node_execute(
def get_hash(self, block_value: int):
"""Creates a deterministic hash."""
- full_salt = "%s.%s%s" % (self.seed, "DAGWORKS", ".")
- hash_str = "%s%s" % (full_salt, str(block_value))
+ full_salt = "{}.{}{}".format(self.seed, "DAGWORKS", ".")
+ hash_str = f"{full_salt}{str(block_value)}"
hash_str = hash_str.encode("ascii")
return int(hashlib.sha1(hash_str).hexdigest()[:15], 16)
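The `%`-interpolation and its pyupgrade replacements produce byte-identical strings, so the hash is unchanged. A quick check with illustrative values:

    seed, block_value = 42, 7
    full_salt = "{}.{}{}".format(seed, "DAGWORKS", ".")
    assert full_salt == "%s.%s%s" % (seed, "DAGWORKS", ".")
    assert f"{full_salt}{str(block_value)}" == "%s%s" % (full_salt, str(block_value))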
@@ -246,11 +248,11 @@ def post_node_execute(
self,
run_id: str,
node_: node.Node,
- kwargs: Dict[str, Any],
+ kwargs: dict[str, Any],
success: bool,
- error: Optional[Exception],
- result: Optional[Any],
- task_id: Optional[str] = None,
+ error: Exception | None,
+ result: Any | None,
+ task_id: str | None = None,
):
"""Captures end of node execution."""
logger.debug("post_node_execute %s %s", run_id, task_id)
@@ -339,8 +341,8 @@ def post_graph_execute(
run_id: str,
graph: h_graph.FunctionGraph,
success: bool,
- error: Optional[Exception],
- results: Optional[Dict[str, Any]],
+ error: Exception | None,
+ results: dict[str, Any] | None,
):
"""Captures end of DAG execution."""
logger.debug("post_graph_execute %s", run_id)
@@ -387,14 +389,14 @@ def __init__(
project_id: int,
username: str,
dag_name: str,
- tags: Dict[str, str] = None,
+ tags: dict[str, str] = None,
client_factory: Callable[
- [str, str, str, Union[str, bool]], clients.BasicAsynchronousHamiltonClient
+ [str, str, str, str | bool], clients.BasicAsynchronousHamiltonClient
] = clients.BasicAsynchronousHamiltonClient,
api_key: str = os.environ.get("HAMILTON_API_KEY", ""),
hamilton_api_url=os.environ.get("HAMILTON_API_URL", constants.HAMILTON_API_URL),
hamilton_ui_url=os.environ.get("HAMILTON_UI_URL", constants.HAMILTON_UI_URL),
- verify: Union[str, bool] = True,
+ verify: str | bool = True,
):
self.project_id = project_id
self.api_key = api_key
@@ -441,7 +443,7 @@ async def ainit(self):
return self
async def post_graph_construct(
- self, graph: h_graph.FunctionGraph, modules: List[ModuleType], config: Dict[str, Any]
+ self, graph: h_graph.FunctionGraph, modules: list[ModuleType], config: dict[str, Any]
):
logger.debug("post_graph_construct")
fg_id = id(graph)
@@ -472,9 +474,9 @@ async def pre_graph_execute(
self,
run_id: str,
graph: h_graph.FunctionGraph,
- final_vars: List[str],
- inputs: Dict[str, Any],
- overrides: Dict[str, Any],
+ final_vars: list[str],
+ inputs: dict[str, Any],
+ overrides: dict[str, Any],
):
logger.debug("pre_graph_execute %s", run_id)
fg_id = id(graph)
@@ -496,7 +498,7 @@ async def pre_graph_execute(
self.task_runs[run_id] = {}
async def pre_node_execute(
- self, run_id: str, node_: node.Node, kwargs: Dict[str, Any], task_id: Optional[str] = None
+ self, run_id: str, node_: node.Node, kwargs: dict[str, Any], task_id: str | None = None
):
logger.debug("pre_node_execute %s", run_id)
tracking_state = self.tracking_states[run_id]
@@ -529,9 +531,9 @@ async def post_node_execute(
run_id: str,
node_: node.Node,
success: bool,
- error: Optional[Exception],
+ error: Exception | None,
result: Any,
- task_id: Optional[str] = None,
+ task_id: str | None = None,
**future_kwargs,
):
logger.debug("post_node_execute %s", run_id)
@@ -616,8 +618,8 @@ async def post_graph_execute(
run_id: str,
graph: h_graph.FunctionGraph,
success: bool,
- error: Optional[Exception],
- results: Optional[Dict[str, Any]],
+ error: Exception | None,
+ results: dict[str, Any] | None,
):
logger.debug("post_graph_execute %s", run_id)
dw_run_id = self.dw_run_ids[run_id]
diff --git a/ui/sdk/src/hamilton_sdk/api/clients.py b/ui/sdk/src/hamilton_sdk/api/clients.py
index 87e687027..499cb313f 100644
--- a/ui/sdk/src/hamilton_sdk/api/clients.py
+++ b/ui/sdk/src/hamilton_sdk/api/clients.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import abc
import asyncio
import datetime
@@ -8,7 +10,7 @@
import threading
import time
from collections import defaultdict
-from typing import Any, Callable, Dict, List, Union
+from typing import Any, Callable
from urllib.parse import urlencode
import aiohttp
@@ -85,12 +87,12 @@ def register_dag_template_if_not_exists(
project_id: int,
dag_hash: str,
code_hash: str,
- nodes: List[dict],
- code_artifacts: List[dict],
+ nodes: list[dict],
+ code_artifacts: list[dict],
name: str,
config: dict,
- tags: Dict[str, Any],
- code: List[dict],
+ tags: dict[str, Any],
+ code: list[dict],
vcs_info: GitInfo, # TODO -- separate this out so we can support more code version types -- just pass it directly to the client
) -> int:
"""Registers a project version with the Hamilton BE API.
@@ -113,9 +115,9 @@ def register_dag_template_if_not_exists(
def create_and_start_dag_run(
self,
dag_template_id: int,
- tags: Dict[str, str],
- inputs: Dict[str, Any],
- outputs: List[str],
+ tags: dict[str, str],
+ inputs: dict[str, Any],
+ outputs: list[str],
) -> int:
"""Logs a DAG run to the Hamilton BE API.
@@ -133,9 +135,9 @@ def create_and_start_dag_run(
def update_tasks(
self,
dag_run_id: int,
- attributes: List[dict],
- task_updates: List[dict],
- in_samples: List[bool] = None,
+ attributes: list[dict],
+ task_updates: list[dict],
+ in_samples: list[bool] = None,
):
"""Updates the tasks + attributes in a DAG run. Does not change the DAG run's status.
@@ -168,7 +170,7 @@ def __init__(
username: str,
h_api_url: str,
base_path: str = "/api/v1",
- verify: Union[str, bool] = True,
+ verify: str | bool = True,
):
"""Initializes a Hamilton API client
@@ -251,7 +253,7 @@ def flush(self, batch):
"""Flush the batch (send it to the backend or process it)."""
logger.debug(f"Flushing batch: {len(batch)}") # Replace with actual processing logic
# group by dag_run_id -- just in case someone does something weird?
- dag_run_ids = set([item["dag_run_id"] for item in batch])
+ dag_run_ids = {item["dag_run_id"] for item in batch}
for dag_run_id in dag_run_ids:
attributes_list, task_updates_list = create_batch(batch, dag_run_id)
response = requests.put(
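The set comprehension is behaviorally identical to `set([...])` but skips building the throwaway intermediate list. For example:

    batch = [{"dag_run_id": 1}, {"dag_run_id": 2}, {"dag_run_id": 1}]
    dag_run_ids = {item["dag_run_id"] for item in batch}
    assert dag_run_ids == {1, 2}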
@@ -287,7 +289,7 @@ def stop(self):
except queue.Empty:
break
- def _common_headers(self) -> Dict[str, Any]:
+ def _common_headers(self) -> dict[str, Any]:
"""Yields the common headers for all requests.
@return: a dictionary of headers.
@@ -313,7 +315,7 @@ def register_code_version_if_not_exists(
project_id: int,
code_hash: str,
vcs_info: GitInfo,
- slurp_code: Callable[[], Dict[str, str]],
+ slurp_code: Callable[[], dict[str, str]],
) -> int:
logger.debug(f"Checking if code version {code_hash} exists for project {project_id}")
response = requests.get(
@@ -391,12 +393,12 @@ def register_dag_template_if_not_exists(
project_id: int,
dag_hash: str,
code_hash: str,
- nodes: List[dict],
- code_artifacts: List[dict],
+ nodes: list[dict],
+ code_artifacts: list[dict],
name: str,
config: dict,
- tags: Dict[str, Any],
- code: List[dict],
+ tags: dict[str, Any],
+ code: list[dict],
vcs_info: GitInfo,
) -> int:
logger.debug(
@@ -456,7 +458,7 @@ def register_dag_template_if_not_exists(
raise
def create_and_start_dag_run(
- self, dag_template_id: int, tags: Dict[str, str], inputs: Dict[str, Any], outputs: List[str]
+ self, dag_template_id: int, tags: dict[str, str], inputs: dict[str, Any], outputs: list[str]
) -> int:
logger.debug(f"Creating DAG run for project version {dag_template_id}")
response = requests.post(
@@ -487,9 +489,9 @@ def create_and_start_dag_run(
def update_tasks(
self,
dag_run_id: int,
- attributes: List[dict],
- task_updates: List[dict],
- in_samples: List[bool] = None,
+ attributes: list[dict],
+ task_updates: list[dict],
+ in_samples: list[bool] = None,
):
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
@@ -530,7 +532,7 @@ def __init__(
username: str,
h_api_url: str,
base_path: str = "/api/v1",
- verify: Union[str, bool] = True,
+ verify: str | bool = True,
):
"""Initializes an async Hamilton API client
@@ -562,7 +564,7 @@ async def flush(self, batch):
"""Flush the batch (send it to the backend or process it)."""
logger.debug(f"Flushing batch: {len(batch)}") # Replace with actual processing logic
# group by dag_run_id -- just in case someone does something weird?
- dag_run_ids = set([item["dag_run_id"] for item in batch])
+ dag_run_ids = {item["dag_run_id"] for item in batch}
for dag_run_id in dag_run_ids:
attributes_list, task_updates_list = create_batch(batch, dag_run_id)
async with aiohttp.ClientSession() as session:
@@ -611,7 +613,7 @@ async def worker(self):
batch = []
last_flush_time = time.time()
- def _common_headers(self) -> Dict[str, Any]:
+ def _common_headers(self) -> dict[str, Any]:
"""Yields the common headers for all requests.
@return: a dictionary of headers.
@@ -640,7 +642,7 @@ async def register_code_version_if_not_exists(
project_id: int,
code_hash: str,
vcs_info: GitInfo,
- slurp_code: Callable[[], Dict[str, str]],
+ slurp_code: Callable[[], dict[str, str]],
) -> int:
logger.debug(f"Checking if code version {code_hash} exists for project {project_id}")
async with aiohttp.ClientSession() as session:
@@ -721,12 +723,12 @@ async def register_dag_template_if_not_exists(
project_id: int,
dag_hash: str,
code_hash: str,
- nodes: List[dict],
- code_artifacts: List[dict],
+ nodes: list[dict],
+ code_artifacts: list[dict],
name: str,
config: dict,
- tags: Dict[str, Any],
- code: List[dict],
+ tags: dict[str, Any],
+ code: list[dict],
vcs_info: GitInfo,
) -> int:
logger.debug(
@@ -793,7 +795,7 @@ async def register_dag_template_if_not_exists(
raise
async def create_and_start_dag_run(
- self, dag_template_id: int, tags: Dict[str, str], inputs: Dict[str, Any], outputs: List[str]
+ self, dag_template_id: int, tags: dict[str, str], inputs: dict[str, Any], outputs: list[str]
) -> int:
logger.debug(f"Creating DAG run for project version {dag_template_id}")
async with aiohttp.ClientSession() as session:
@@ -825,9 +827,9 @@ async def create_and_start_dag_run(
async def update_tasks(
self,
dag_run_id: int,
- attributes: List[dict],
- task_updates: List[dict],
- in_samples: List[bool] = None,
+ attributes: list[dict],
+ task_updates: list[dict],
+ in_samples: list[bool] = None,
):
logger.debug(
f"Updating tasks for DAG run {dag_run_id} with {len(attributes)} "