Skip to content

Commit

Permalink
Let resolver op be able to get external artifacts.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 623548131
  • Loading branch information
tfx-copybara committed Apr 26, 2024
1 parent 1c42e19 commit 17681e2
Show file tree
Hide file tree
Showing 6 changed files with 315 additions and 47 deletions.
95 changes: 78 additions & 17 deletions tfx/dsl/input_resolution/ops/latest_policy_model_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for LatestPolicyModel operator."""

import collections
import enum
from typing import Dict

from absl import logging
from tfx import types
from tfx.dsl.input_resolution import resolver_op
from tfx.dsl.input_resolution.ops import ops_utils
Expand All @@ -24,6 +26,7 @@
from tfx.orchestration.portable.mlmd import event_lib
from tfx.orchestration.portable.mlmd import filter_query_builder as q
from tfx.types import artifact_utils
from tfx.types import external_artifact_utils
from tfx.utils import typing_utils

from ml_metadata.proto import metadata_store_pb2
Expand Down Expand Up @@ -344,7 +347,17 @@ def apply(self, input_dict: typing_utils.ArtifactMultiMap):
input_child_artifacts = input_dict.get(
ops_utils.MODEL_BLESSSING_KEY, []
) + input_dict.get(ops_utils.MODEL_INFRA_BLESSING_KEY, [])
input_child_artifact_ids = set([a.id for a in input_child_artifacts])

input_child_artifact_ids = set()
for a in input_child_artifacts:
if a.is_external:
input_child_artifact_ids.add(
external_artifact_utils.get_id_from_external_id(
a.mlmd_artifact.external_id
)
)
else:
input_child_artifact_ids.add(a.id)

# If the ModelBlessing and ModelInfraBlessing lists are empty, then no
# child artifacts can be considered and we raise a SkipSignal. This can
Expand Down Expand Up @@ -372,8 +385,38 @@ def apply(self, input_dict: typing_utils.ArtifactMultiMap):

# There could be multiple events with the same execution ID but different
# artifact IDs (e.g. model and baseline_model passed to an Evaluator), so we
# keep the values of model_artifact_ids_by_execution_id as sets.
model_artifact_ids = sorted(set(m.id for m in models))
# keep the values of model_artifact_ids as sets.
are_models_external = [m.is_external for m in models]
if any(are_models_external) and not all(are_models_external):
raise exceptions.InvalidArgument(
'Inputs to the LastestPolicyModel are from both current pipeline and'
' external pipeline. LastestPolicyModel does not support such usage.'
)
if all(are_models_external):
pipeline_assets = set([
external_artifact_utils.get_pipeline_asset_from_external_id(
m.mlmd_artifact.external_id
)
for m in models
])
if len(pipeline_assets) != 1:
raise exceptions.InvalidArgument(
'Input models to the LastestPolicyModel are from multiple'
' pipelines. LastestPolicyModel does not support such usage.'
)

model_by_external_id = {m.mlmd_artifact.external_id: m for m in models}
deduped_models = list(model_by_external_id.values())
model_artifact_ids = sorted(
set([
external_artifact_utils.get_id_from_external_id(i)
for i in model_by_external_id.keys()
])
)
else:
model_by_id = {m.id: m for m in models}
deduped_models = list(model_by_id.values())
model_artifact_ids = sorted(set(model_by_id.keys()))

downstream_artifact_type_names_filter_query = q.to_sql_string([
ops_utils.MODEL_BLESSING_TYPE_NAME,
Expand Down Expand Up @@ -417,10 +460,13 @@ def event_filter(event):
else:
return event_lib.is_valid_output_event(event)

mlmd_resolver = metadata_resolver.MetadataResolver(self.context.store)
mlmd_resolver = metadata_resolver.MetadataResolver(
self.context.store,
mlmd_connection_manager=self.context.mlmd_connection_manager,
)
# Populate the ModelRelations associated with each Model artifact and its
# children.
model_relations_by_model_artifact_id = collections.defaultdict(
model_relations_by_model_identifier = collections.defaultdict(
ModelRelations
)
artifact_type_by_name: Dict[str, metadata_store_pb2.ArtifactType] = {}
Expand All @@ -429,34 +475,44 @@ def event_filter(event):
# fetching downstream artifacts, because
# `get_downstream_artifacts_by_artifact_ids()` supports at most 100 ids
# as starting artifact ids.
for id_index in range(0, len(model_artifact_ids), ops_utils.BATCH_SIZE):
batch_model_artifact_ids = model_artifact_ids[
for id_index in range(0, len(deduped_models), ops_utils.BATCH_SIZE):
batch_model_artifacts = deduped_models[
id_index : id_index + ops_utils.BATCH_SIZE
]
# Set `max_num_hops` to 50, which should be enough for this use case.
batch_downstream_artifacts_and_types_by_model_ids = (
mlmd_resolver.get_downstream_artifacts_by_artifact_ids(
batch_model_artifact_ids,
batch_downstream_artifacts_and_types_by_model_identifier = (
mlmd_resolver.get_downstream_artifacts_by_artifacts(
batch_model_artifacts,
max_num_hops=ops_utils.LATEST_POLICY_MODEL_OP_MAX_NUM_HOPS,
filter_query=filter_query,
event_filter=event_filter,
)
)

logging.error(
'Guowei batch_downstream_artifacts_and_types_by_model_identifier %s',
batch_downstream_artifacts_and_types_by_model_identifier,
)
for (
model_artifact_id,
model_identifier,
artifacts_and_types,
) in batch_downstream_artifacts_and_types_by_model_ids.items():
) in batch_downstream_artifacts_and_types_by_model_identifier.items():
for downstream_artifact, artifact_type in artifacts_and_types:
artifact_type_by_name[artifact_type.name] = artifact_type
model_relations = model_relations_by_model_artifact_id[
model_artifact_id
]
model_relations.add_downstream_artifact(downstream_artifact)
model_relations_by_model_identifier[
model_identifier
].add_downstream_artifact(downstream_artifact)

logging.error(
'Guowei model_relations_by_model_identifier %s',
model_relations_by_model_identifier,
)

# Find the latest model and ModelRelations that meets the Policy.
result = {}
for model in models:
model_relations = model_relations_by_model_artifact_id[model.id]
identifier = external_artifact_utils.identifier(model)
model_relations = model_relations_by_model_identifier[identifier]
if model_relations.meets_policy(self.policy):
result[ops_utils.MODEL_KEY] = [model]
break
Expand All @@ -465,6 +521,11 @@ def event_filter(event):
f'No model found that meets the Policy {Policy(self.policy).name}'
)

logging.error(
'Guowei result %s',
result,
)

return _build_result_dictionary(
result, model_relations, self.policy, artifact_type_by_name
)
31 changes: 25 additions & 6 deletions tfx/dsl/input_resolution/resolver_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for ResolverOp and its related definitions."""

from __future__ import annotations

import abc
from typing import Any, Generic, Literal, Mapping, Optional, Sequence, Set, Type, TypeVar, Union
from typing import Any, Generic, Literal, Mapping, Optional, Sequence, Set, Type, TypeVar, Union, cast

import attr
from tfx import types
from tfx.orchestration import mlmd_connection_manager as mlmd_cm
from tfx.proto.orchestration import pipeline_pb2
from tfx.utils import json_utils
from tfx.utils import typing_utils
Expand All @@ -28,13 +30,30 @@

# Mark frozen as context instance may be used across multiple operator
# invocations.
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
class Context:
"""Context for running ResolverOp."""
# MetadataStore for MLMD read access.
store: mlmd.MetadataStore
# TODO(jjong): Add more context such as current pipeline, current pipeline
# run, and current running node information.

def __init__(
self,
store=mlmd.MetadataStore,
mlmd_handle_like: Optional[mlmd_cm.HandleLike] = None,
):
self._store = store
self._mlmd_handle_like = mlmd_handle_like

@property
def store(self):
return self._store

@property
def mlmd_connection_manager(self):
if isinstance(self._mlmd_handle_like, mlmd_cm.MLMDConnectionManager):
return cast(mlmd_cm.MLMDConnectionManager, self._mlmd_handle_like)
else:
return None

# # TODO(jjong): Add more context such as current pipeline, current pipeline
# # run, and current running node information.


# Note that to use DataType as a generic type parameter (e.g.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@
import collections
import dataclasses
import functools
from typing import Union, Sequence, Mapping, Tuple, List, Iterable, Callable
from typing import Callable, Iterable, List, Mapping, Sequence, Tuple, Union

from tfx import types
from tfx.dsl.components.common import resolver
from tfx.dsl.input_resolution import resolver_op
from tfx.dsl.input_resolution.ops import ops
from tfx.orchestration import data_types_utils
from tfx.orchestration import metadata
from tfx.orchestration import mlmd_connection_manager as mlmd_cm
from tfx.orchestration.portable.input_resolution import exceptions
from tfx.proto.orchestration import pipeline_pb2
from tfx.utils import topsort
Expand All @@ -52,8 +52,12 @@

@dataclasses.dataclass
class _Context:
mlmd_handle: metadata.Metadata
input_graph: pipeline_pb2.InputGraph
mlmd_handle_like: mlmd_cm.HandleLike

@property
def mlmd_handle(self):
return mlmd_cm.get_handle(self.mlmd_handle_like)


def _topologically_sorted_node_ids(
Expand Down Expand Up @@ -131,7 +135,12 @@ def _evaluate_op_node(
f'nodes[{node_id}] has unknown op_type {op_node.op_type}.') from e
if issubclass(op_type, resolver_op.ResolverOp):
op: resolver_op.ResolverOp = op_type.create(**kwargs)
op.set_context(resolver_op.Context(store=ctx.mlmd_handle.store))
op.set_context(
resolver_op.Context(
store=mlmd_cm.get_handle(ctx.mlmd_handle_like).store,
mlmd_handle_like=ctx.mlmd_handle_like,
)
)
return op.apply(*args)
elif issubclass(op_type, resolver.ResolverStrategy):
if len(args) != 1:
Expand Down Expand Up @@ -207,7 +216,7 @@ def new_graph_fn(data: Mapping[str, _Data]):


def build_graph_fn(
mlmd_handle: metadata.Metadata,
handle_like: mlmd_cm.HandleLike,
input_graph: pipeline_pb2.InputGraph,
) -> Tuple[_GraphFn, List[str]]:
"""Build a functional interface for the `input_graph`.
Expand All @@ -222,7 +231,7 @@ def build_graph_fn(
z = graph_fn({'x': inputs['x'], 'y': inputs['y']})
Args:
mlmd_handle: A `Metadata` instance.
handle_like: A `mlmd_cm.HandleLike` instance.
input_graph: An `pipeline_pb2.InputGraph` proto.
Returns:
Expand All @@ -235,7 +244,7 @@ def build_graph_fn(
f'result_node {input_graph.result_node} does not exist in input_graph. '
f'Valid node ids: {list(input_graph.nodes.keys())}')

context = _Context(mlmd_handle=mlmd_handle, input_graph=input_graph)
context = _Context(mlmd_handle_like=handle_like, input_graph=input_graph)

input_key_to_node_id = {}
for node_id in input_graph.nodes:
Expand Down
Loading

0 comments on commit 17681e2

Please sign in to comment.