From 04c4cb6b88765dfcc2746c99b1352f1b4feb9ce3 Mon Sep 17 00:00:00 2001 From: tfx-team Date: Wed, 10 Apr 2024 11:05:01 -0700 Subject: [PATCH] Let resolver op be able to get external artifacts. PiperOrigin-RevId: 623548131 --- .../ops/latest_policy_model_op.py | 76 +++++++++++++++++-- tfx/dsl/input_resolution/resolver_op.py | 8 ++ .../input_resolution/input_graph_resolver.py | 20 +++-- .../input_resolution/node_inputs_resolver.py | 15 ++-- tfx/types/external_artifact_utils.py | 35 +++++++++ 5 files changed, 134 insertions(+), 20 deletions(-) create mode 100644 tfx/types/external_artifact_utils.py diff --git a/tfx/dsl/input_resolution/ops/latest_policy_model_op.py b/tfx/dsl/input_resolution/ops/latest_policy_model_op.py index 1492744c255..7953c0fc112 100644 --- a/tfx/dsl/input_resolution/ops/latest_policy_model_op.py +++ b/tfx/dsl/input_resolution/ops/latest_policy_model_op.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Module for LatestPolicyModel operator.""" + import collections import enum from typing import Dict @@ -23,6 +24,7 @@ from tfx.orchestration.portable.mlmd import event_lib from tfx.orchestration.portable.mlmd import filter_query_builder as q from tfx.types import artifact_utils +from tfx.types import external_artifact_utils from tfx.utils import typing_utils from ml_metadata.proto import metadata_store_pb2 @@ -324,7 +326,17 @@ def apply(self, input_dict: typing_utils.ArtifactMultiMap): input_child_artifacts = input_dict.get( ops_utils.MODEL_BLESSSING_KEY, [] ) + input_dict.get(ops_utils.MODEL_INFRA_BLESSING_KEY, []) - input_child_artifact_ids = set([a.id for a in input_child_artifacts]) + + input_child_artifact_ids = set() + for a in input_child_artifacts: + if a.is_external: + input_child_artifact_ids.add( + external_artifact_utils.get_id_from_external_id( + a.mlmd_artifact.external_id + ) + ) + else: + input_child_artifact_ids.add(a.id) # If the ModelBlessing and ModelInfraBlessing lists are empty, then no # child artifacts can be considered and we raise a SkipSignal. This can @@ -353,7 +365,52 @@ def apply(self, input_dict: typing_utils.ArtifactMultiMap): # There could be multiple events with the same execution ID but different # artifact IDs (e.g. model and baseline_model passed to an Evaluator), so we # keep the values of model_artifact_ids_by_execution_id as sets. - model_artifact_ids = sorted(set(m.id for m in models)) + are_model_external = [m.is_external for m in models] + if any(are_model_external) and not all(are_model_external): + raise exceptions.InvalidArgument( + 'Inputs to the LastestPolicyModel are from both current pipeline and' + ' external pipeline. LastestPolicyModel does not support such usage.' + ) + + if not all(are_model_external): + model_artifact_ids = sorted(set(m.id for m in models)) + store = self.context.store + else: + # If the input models are from external pipeline, try to get a MLMD + # `store` which connects to the external MLMD instance. + model_external_ids = sorted( + set([m.mlmd_artifact.external_id for m in models]) + ) + model_artifact_ids = sorted( + set([ + external_artifact_utils.get_id_from_external_id(i) + for i in model_external_ids + ]) + ) + + pipeline_assets = [ + external_artifact_utils.get_pipeline_asset_from_external_id(i) + for i in model_external_ids + ] + pipeline_assets = set([a.SerializeToString() for a in pipeline_assets]) + if len(pipeline_assets) > 1: + raise exceptions.InvalidArgument( + 'Input models to the LastestPolicyModel are from multiple' + ' pipelines. LastestPolicyModel does not support such usage.' + ) + + external_connection_config = ( + external_artifact_utils.get_external_connection_config( + model_external_ids[0] + ) + ) + if not self.context.mlmd_manager: + raise ValueError('Not able to connect to external MLMD instance.') + store = self.context.mlmd_manager.get_mlmd_handle( + external_connection_config + ).store + + mlmd_resolver = metadata_resolver.MetadataResolver(store) downstream_artifact_type_names_filter_query = q.to_sql_string([ ops_utils.MODEL_BLESSING_TYPE_NAME, @@ -397,9 +454,7 @@ def event_filter(event): else: return event_lib.is_valid_output_event(event) - mlmd_resolver = metadata_resolver.MetadataResolver(self.context.store) downstream_artifacts_by_model_ids = {} - # Split `model_artifact_ids` into batches with batch size = 100 while # fetching downstream artifacts, because # `get_downstream_artifacts_by_artifact_ids()` supports at most 100 ids @@ -420,12 +475,12 @@ def event_filter(event): downstream_artifacts_by_model_ids.update( batch_downstream_artifacts_by_model_ids ) + # Populate the ModelRelations associated with each Model artifact and its # children. model_relations_by_model_artifact_id = collections.defaultdict( ModelRelations ) - type_ids = set() for ( model_artifact_id, @@ -455,7 +510,13 @@ def event_filter(event): # Find the latest model and ModelRelations that meets the Policy. result = {} for model in models: - model_relations = model_relations_by_model_artifact_id[model.id] + if model.is_external: + model_id = external_artifact_utils.get_id_from_external_id( + model.mlmd_artifact.external_id + ) + else: + model_id = model.id + model_relations = model_relations_by_model_artifact_id[model_id] if model_relations.meets_policy(self.policy): result[ops_utils.MODEL_KEY] = [model] break @@ -463,7 +524,8 @@ def event_filter(event): return self._raise_skip_signal_or_return_empty_dict( f'No model found that meets the Policy {Policy(self.policy).name}' ) - artifact_types = self.context.store.get_artifact_types_by_id(type_ids) + + artifact_types = store.get_artifact_types_by_id(type_ids) artifact_type_by_name = {t.name: t for t in artifact_types} return _build_result_dictionary( result, model_relations, self.policy, artifact_type_by_name diff --git a/tfx/dsl/input_resolution/resolver_op.py b/tfx/dsl/input_resolution/resolver_op.py index 8594d93b6db..8c4793a719e 100644 --- a/tfx/dsl/input_resolution/resolver_op.py +++ b/tfx/dsl/input_resolution/resolver_op.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Module for ResolverOp and its related definitions.""" + from __future__ import annotations import abc @@ -19,6 +20,7 @@ import attr from tfx import types +from tfx.orchestration import mlmd_connection_manager as mlmd_cm from tfx.proto.orchestration import pipeline_pb2 from tfx.utils import json_utils from tfx.utils import typing_utils @@ -31,8 +33,14 @@ @attr.s(auto_attribs=True, frozen=True, kw_only=True) class Context: """Context for running ResolverOp.""" + # TODO(b/302730333) We could remove store and only use mlmd_manager. Keeping + # this for now to keep it backward compatible with other resolver ops. # MetadataStore for MLMD read access. store: mlmd.MetadataStore + + # An MLMDConnectionManager instance. It can manage multiple MLMD connections. + mlmd_manager: Optional[mlmd_cm.MLMDConnectionManager] = None + # TODO(jjong): Add more context such as current pipeline, current pipeline # run, and current running node information. diff --git a/tfx/orchestration/portable/input_resolution/input_graph_resolver.py b/tfx/orchestration/portable/input_resolution/input_graph_resolver.py index 5c6e04a9a94..cd516f37309 100644 --- a/tfx/orchestration/portable/input_resolution/input_graph_resolver.py +++ b/tfx/orchestration/portable/input_resolution/input_graph_resolver.py @@ -29,7 +29,7 @@ import collections import dataclasses import functools -from typing import Union, Sequence, Mapping, Tuple, List, Iterable, Callable +from typing import Callable, Iterable, List, Mapping, Optional, Sequence, Tuple, Union, cast from tfx import types from tfx.dsl.components.common import resolver @@ -37,6 +37,7 @@ from tfx.dsl.input_resolution.ops import ops from tfx.orchestration import data_types_utils from tfx.orchestration import metadata +from tfx.orchestration import mlmd_connection_manager as mlmd_cm from tfx.orchestration.portable.input_resolution import exceptions from tfx.proto.orchestration import pipeline_pb2 from tfx.utils import topsort @@ -54,6 +55,7 @@ class _Context: mlmd_handle: metadata.Metadata input_graph: pipeline_pb2.InputGraph + mlmd_manager: Optional[mlmd_cm.MLMDConnectionManager] = None def _topologically_sorted_node_ids( @@ -131,7 +133,11 @@ def _evaluate_op_node( f'nodes[{node_id}] has unknown op_type {op_node.op_type}.') from e if issubclass(op_type, resolver_op.ResolverOp): op: resolver_op.ResolverOp = op_type.create(**kwargs) - op.set_context(resolver_op.Context(store=ctx.mlmd_handle.store)) + op.set_context( + resolver_op.Context( + store=ctx.mlmd_handle.store, mlmd_manager=ctx.mlmd_manager + ) + ) return op.apply(*args) elif issubclass(op_type, resolver.ResolverStrategy): if len(args) != 1: @@ -207,7 +213,7 @@ def new_graph_fn(data: Mapping[str, _Data]): def build_graph_fn( - mlmd_handle: metadata.Metadata, + handle_like: mlmd_cm.HandleLike, input_graph: pipeline_pb2.InputGraph, ) -> Tuple[_GraphFn, List[str]]: """Build a functional interface for the `input_graph`. @@ -222,7 +228,7 @@ def build_graph_fn( z = graph_fn({'x': inputs['x'], 'y': inputs['y']}) Args: - mlmd_handle: A `Metadata` instance. + handle_like: A `mlmd_cm.HandleLike` instance. input_graph: An `pipeline_pb2.InputGraph` proto. Returns: @@ -235,7 +241,11 @@ def build_graph_fn( f'result_node {input_graph.result_node} does not exist in input_graph. ' f'Valid node ids: {list(input_graph.nodes.keys())}') - context = _Context(mlmd_handle=mlmd_handle, input_graph=input_graph) + context = _Context( + mlmd_handle=mlmd_cm.get_handle(handle_like), input_graph=input_graph + ) + if isinstance(handle_like, mlmd_cm.MLMDConnectionManager): + context.mlmd_manager = cast(mlmd_cm.MLMDConnectionManager, handle_like) input_key_to_node_id = {} for node_id in input_graph.nodes: diff --git a/tfx/orchestration/portable/input_resolution/node_inputs_resolver.py b/tfx/orchestration/portable/input_resolution/node_inputs_resolver.py index cad7d29c250..fee73bda28d 100644 --- a/tfx/orchestration/portable/input_resolution/node_inputs_resolver.py +++ b/tfx/orchestration/portable/input_resolution/node_inputs_resolver.py @@ -341,7 +341,7 @@ def _join_artifacts( def _resolve_input_graph_ref( - mlmd_handle: metadata.Metadata, + handle_like: mlmd_cm.HandleLike, node_inputs: pipeline_pb2.NodeInputs, input_key: str, resolved: Dict[str, List[_Entry]], @@ -352,12 +352,12 @@ def _resolve_input_graph_ref( (i.e. `InputGraphRef` with the same `graph_id`). Args: - mlmd_handle: A `Metadata` instance. + handle_like: A `mlmd_cm.HandleLike` instance. node_inputs: A `NodeInputs` proto. input_key: A target input key whose corresponding `InputSpec` has an - `InputGraphRef`. + `InputGraphRef`. resolved: A dict that contains the already resolved inputs, and to which the - resolved result would be written from this function. + resolved result would be written from this function. """ graph_id = node_inputs.inputs[input_key].input_graph_ref.graph_id input_graph = node_inputs.input_graphs[graph_id] @@ -372,7 +372,8 @@ def _resolve_input_graph_ref( } graph_fn, graph_input_keys = input_graph_resolver.build_graph_fn( - mlmd_handle, node_inputs.input_graphs[graph_id]) + handle_like, node_inputs.input_graphs[graph_id] + ) for partition, input_dict in _join_artifacts(resolved, graph_input_keys): result = graph_fn(input_dict) if graph_output_type == _DataType.ARTIFACT_LIST: @@ -514,9 +515,7 @@ def resolve( (partition_utils.NO_PARTITION, _filter_live(artifacts)) ] elif input_spec.input_graph_ref.graph_id: - _resolve_input_graph_ref( - mlmd_cm.get_handle(handle_like), node_inputs, input_key, - resolved) + _resolve_input_graph_ref(handle_like, node_inputs, input_key, resolved) elif input_spec.mixed_inputs.input_keys: _resolve_mixed_inputs(node_inputs, input_key, resolved) elif input_spec.HasField('static_inputs'): diff --git a/tfx/types/external_artifact_utils.py b/tfx/types/external_artifact_utils.py new file mode 100644 index 00000000000..ba0b87c5db1 --- /dev/null +++ b/tfx/types/external_artifact_utils.py @@ -0,0 +1,35 @@ +# Copyright 2024 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Third party version of external_artifact_utils.py.""" + + +def get_artifact_id_from_external_id(external_id: str): + del external_id + + +def get_pipeline_asset_from_external_id( + external_id: str, +): + del external_id + + +def get_external_connection_config( + external_id: str, +): + del external_id + + +def identifier(artifact): + del artifact