Skip to content

Commit

Permalink
Add subpipeline as node node context for all nodes in a subpipeline.
Browse files Browse the repository at this point in the history
This is needed so that we can query against subpipelines with `artifact_query`.

This change makes it so that in the below example, `node_a` and `node_b` will both have a `node` context with name `parent.subpipeline`.

```
parent {
  subpipeline {
     node_a {}
     node_b {}
  }
}
```

PiperOrigin-RevId: 628120450
  • Loading branch information
kmonte authored and tfx-copybara committed Apr 25, 2024
1 parent 89beaa4 commit 69c52c8
Show file tree
Hide file tree
Showing 5 changed files with 752 additions and 3 deletions.
11 changes: 10 additions & 1 deletion tfx/dsl/compiler/node_contexts_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,16 @@ def compile_node_contexts(
constants.PIPELINE_RUN_ID_PARAMETER_NAME,
str,
)

# If this is a subpipline then set the subpipeline as node context.
if pipeline_ctx.is_subpipeline:
subpipeline_context_pb = node_contexts.contexts.add()
subpipeline_context_pb.type.name = constants.NODE_CONTEXT_TYPE_NAME
subpipeline_context_pb.name.field_value.string_value = (
compiler_utils.node_context_name(
pipeline_ctx.parent.pipeline_info.pipeline_context_name,
pipeline_ctx.pipeline_info.pipeline_context_name,
)
)
# Contexts inherited from the parent pipelines.
for i, parent_pipeline in enumerate(pipeline_ctx.parent_pipelines[::-1]):
parent_pipeline_context_pb = node_contexts.contexts.add()
Expand Down
14 changes: 12 additions & 2 deletions tfx/dsl/compiler/node_contexts_compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,11 @@ def test_compile_node_contexts(self):
pipeline_pb2.NodeContexts(),
)
self.assertProtoEquals(
expected_node_contexts,
node_contexts_compiler.compile_node_contexts(
compiler_context.PipelineContext(pipeline.Pipeline(_PIPELINE_NAME)),
_NODE_ID,
),
expected_node_contexts,
)

def test_compile_node_contexts_for_subpipeline(self):
Expand Down Expand Up @@ -110,6 +110,16 @@ def test_compile_node_contexts_for_subpipeline(self):
}
}
}
contexts {
type {
name: "node"
}
name {
field_value {
string_value: "test_pipeline.subpipeline"
}
}
}
contexts {
type {
name: "pipeline"
Expand Down Expand Up @@ -145,11 +155,11 @@ def test_compile_node_contexts_for_subpipeline(self):
pipeline_pb2.NodeContexts(),
)
self.assertProtoEquals(
expected_node_contexts,
node_contexts_compiler.compile_node_contexts(
subpipeline_context,
_NODE_ID,
),
expected_node_contexts,
)


Expand Down
Loading

0 comments on commit 69c52c8

Please sign in to comment.