Skip to content

Commit

Permalink
fix(backend): pass correct ParentDagID to iterator DAG
Browse files Browse the repository at this point in the history
- Passthrough ParentDagID rather than DriverExecutionID to iterator such
  that iteration item correctly detects dependentTasks.
- Remove depends from iterator DAG as it is already handled by
  root-level task
- Update Iterator template names/nomenclature for clarity
- Update tests accordingly

Signed-off-by: Giulio Frasca <[email protected]>
  • Loading branch information
gmfrasca committed Oct 3, 2024
1 parent 02e27f0 commit e74b9ce
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 104 deletions.
21 changes: 9 additions & 12 deletions backend/src/v2/compiler/argocompiler/dag.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ func (c *workflowCompiler) DAG(name string, componentSpec *pipelinespec.Componen
if err != nil {
return err
}

dag.DAG.Tasks = append(dag.DAG.Tasks, tasks...)
}
_, err = c.addTemplate(dag, name)
Expand Down Expand Up @@ -274,10 +273,9 @@ func (c *workflowCompiler) iteratorTask(name string, task *pipelinespec.Pipeline
// Set up Loop Control Template
loopDriverArgoName := name + "-loop-driver"
loopDriverInputs := dagDriverInputs{
component: componentSpecPlaceholder,
parentDagID: parentDagID,
task: taskJson, // TODO(Bobgy): avoid duplicating task JSON twice in the template.
iterationIndex: "0",
component: componentSpecPlaceholder,
parentDagID: parentDagID,
task: taskJson, // TODO(Bobgy): avoid duplicating task JSON twice in the template.
}
loopDriver, loopDriverOutputs, err := c.dagDriverTask(loopDriverArgoName, loopDriverInputs)
if err != nil {
Expand All @@ -300,12 +298,12 @@ func (c *workflowCompiler) iteratorTask(name string, task *pipelinespec.Pipeline
Tasks: iteratorTasks,
},
}
parallellism_limit := int64(task.GetIteratorPolicy().GetParallelismLimit())
if parallellism_limit > 0 {
loopTmpl.Parallelism = &parallellism_limit
parallelismLimit := int64(task.GetIteratorPolicy().GetParallelismLimit())
if parallelismLimit > 0 {
loopTmpl.Parallelism = &parallelismLimit
}

loopTmplName, err := c.addTemplate(loopTmpl, componentName+"-loop-"+name)
loopTmplName, err := c.addTemplate(loopTmpl, fmt.Sprintf("%s-loop-iterator", componentName))
if err != nil {
return nil, err
}
Expand All @@ -325,7 +323,7 @@ func (c *workflowCompiler) iteratorTask(name string, task *pipelinespec.Pipeline
Parameters: []wfapi.Parameter{
{
Name: paramParentDagID,
Value: wfapi.AnyStringPtr(loopDriverOutputs.executionID),
Value: wfapi.AnyStringPtr(parentDagID),
},
},
},
Expand Down Expand Up @@ -357,7 +355,6 @@ func (c *workflowCompiler) iterationItemTask(name string, task *pipelinespec.Pip
if err != nil {
return nil, err
}
//driver.Depends = depends(task.GetDependentTasks()) # TODO(gfrasca): Handled already by root task

iterationCount := intstr.FromString(driverOutputs.iterationCount)
iterationTasks, err := c.task(
Expand All @@ -382,7 +379,7 @@ func (c *workflowCompiler) iterationItemTask(name string, task *pipelinespec.Pip
Tasks: iterationTasks,
},
}
iterationsTmplName, err := c.addTemplate(iterationsTmpl, componentName+"-"+name)
iterationsTmplName, err := c.addTemplate(iterationsTmpl, componentName+"-iteration")
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit e74b9ce

Please sign in to comment.