Skip to content

Commit

Permalink
feat(core): EmbeddedFlow task
Browse files Browse the repository at this point in the history
Adds an EmbeddedFlow that allow to embed subflow tasks into a parent tasks.

Fixes #6518
  • Loading branch information
loicmathieu committed Jan 8, 2025
1 parent c5e23d4 commit 7b9b9bd
Show file tree
Hide file tree
Showing 8 changed files with 342 additions and 28 deletions.
50 changes: 31 additions & 19 deletions core/src/main/java/io/kestra/core/runners/ExecutableUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -112,28 +112,11 @@ public static <T extends Task & ExecutableTask<?>> Optional<SubflowExecution<?>>
}
}

String tenantId = currentExecution.getTenantId();
String subflowNamespace = runContext.render(currentTask.subflowId().namespace());
String subflowId = runContext.render(currentTask.subflowId().flowId());
Optional<Integer> subflowRevision = currentTask.subflowId().revision();

io.kestra.core.models.flows.Flow flow = flowExecutorInterface.findByIdFromTask(
currentExecution.getTenantId(),
subflowNamespace,
subflowId,
subflowRevision,
currentExecution.getTenantId(),
currentFlow.getNamespace(),
currentFlow.getId()
)
.orElseThrow(() -> new IllegalStateException("Unable to find flow '" + subflowNamespace + "'.'" + subflowId + "' with revision '" + subflowRevision.orElse(0) + "'"));

if (flow.isDisabled()) {
throw new IllegalStateException("Cannot execute a flow which is disabled");
}

if (flow instanceof FlowWithException fwe) {
throw new IllegalStateException("Cannot execute an invalid flow: " + fwe.getException());
}
Flow flow = getSubflow(tenantId, subflowNamespace, subflowId, subflowRevision, flowExecutorInterface, currentFlow);

List<Label> newLabels = inheritLabels ? new ArrayList<>(currentExecution.getLabels()) : new ArrayList<>(systemLabels(currentExecution));
if (labels != null) {
Expand Down Expand Up @@ -176,6 +159,35 @@ public static <T extends Task & ExecutableTask<?>> Optional<SubflowExecution<?>>
.build());
}

public static Flow getSubflow(String tenantId,
String subflowNamespace,
String subflowId,
Optional<Integer> subflowRevision,
FlowExecutorInterface flowExecutorInterface,
Flow currentFlow) {

Flow flow = flowExecutorInterface.findByIdFromTask(
tenantId,
subflowNamespace,
subflowId,
subflowRevision,
tenantId,
currentFlow.getNamespace(),
currentFlow.getId()
)
.orElseThrow(() -> new IllegalStateException("Unable to find flow '" + subflowNamespace + "'.'" + subflowId + "' with revision '" + subflowRevision.orElse(0) + "'"));

if (flow.isDisabled()) {
throw new IllegalStateException("Cannot execute a flow which is disabled");
}

if (flow instanceof FlowWithException fwe) {
throw new IllegalStateException("Cannot execute an invalid flow: " + fwe.getException());
}

return flow;
}

private static List<Label> systemLabels(Execution execution) {
return Streams.of(execution.getLabels())
.filter(label -> label.key().startsWith(Label.SYSTEM_PREFIX))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import io.kestra.core.services.ConditionService;
import io.kestra.core.utils.ListUtils;
import io.kestra.plugin.core.condition.*;
import io.kestra.plugin.core.flow.ChildFlowInterface;
import io.micronaut.core.annotation.Nullable;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
Expand Down Expand Up @@ -157,11 +158,9 @@ protected boolean isFlowTaskChild(FlowWithSource parent, FlowWithSource child) {
return parent
.allTasksWithChilds()
.stream()
.filter(t -> t instanceof ExecutableTask)
.map(t -> (ExecutableTask<?>) t)
.anyMatch(t ->
t.subflowId() != null && t.subflowId().namespace().equals(child.getNamespace()) && t.subflowId().flowId().equals(child.getId())
);
.filter(t -> t instanceof ChildFlowInterface)
.map(t -> (ChildFlowInterface) t)
.anyMatch(t -> Objects.equals(t.getFlowId(), child.getId()) && Objects.equals(t.getNamespace(), child.getNamespace()));
} catch (Exception e) {
log.warn("Failed to detect flow task on namespace:'{}', flowId:'{}'", parent.getNamespace(), parent.getId(), e);
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import io.kestra.core.models.tasks.Task;
import io.kestra.core.utils.ListUtils;
import io.kestra.core.validations.FlowValidation;
import io.kestra.plugin.core.flow.ChildFlowInterface;
import io.micronaut.core.annotation.AnnotationValue;
import io.micronaut.core.annotation.Introspected;
import io.micronaut.core.annotation.NonNull;
Expand Down Expand Up @@ -60,9 +61,9 @@ public boolean isValid(
}

value.allTasksWithChilds()
.stream().filter(task -> task instanceof ExecutableTask<?> executableTask
&& value.getId().equals(executableTask.subflowId().flowId())
&& value.getNamespace().equals(executableTask.subflowId().namespace()))
.stream().filter(task -> task instanceof ChildFlowInterface childFlow
&& value.getId().equals(childFlow.getFlowId())
&& value.getNamespace().equals(childFlow.getNamespace()))
.forEach(task -> violations.add("Recursive call to flow [" + value.getNamespace() + "." + value.getId() + "]"));

// input unique name
Expand Down
251 changes: 251 additions & 0 deletions core/src/main/java/io/kestra/plugin/core/flow/EmbeddedSubflow.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
package io.kestra.plugin.core.flow;

import com.fasterxml.jackson.annotation.JsonIgnore;
import io.kestra.core.exceptions.IllegalVariableEvaluationException;
import io.kestra.core.models.annotations.Example;
import io.kestra.core.models.annotations.Plugin;
import io.kestra.core.models.annotations.PluginProperty;
import io.kestra.core.models.executions.Execution;
import io.kestra.core.models.executions.NextTaskRun;
import io.kestra.core.models.executions.TaskRun;
import io.kestra.core.models.flows.*;
import io.kestra.core.models.hierarchies.AbstractGraph;
import io.kestra.core.models.hierarchies.GraphCluster;
import io.kestra.core.models.hierarchies.RelationType;
import io.kestra.core.models.tasks.FlowableTask;
import io.kestra.core.models.tasks.ResolvedTask;
import io.kestra.core.models.tasks.Task;
import io.kestra.core.runners.*;
import io.kestra.core.utils.GraphUtils;
import io.micronaut.context.ApplicationContext;
import io.micronaut.context.event.StartupEvent;
import io.micronaut.runtime.event.annotation.EventListener;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
import jakarta.validation.constraints.Min;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
import lombok.*;
import lombok.experimental.SuperBuilder;

import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;

@SuperBuilder
@ToString
@EqualsAndHashCode
@Getter
@NoArgsConstructor
@Schema(
title = "Embeds subflow tasks into this flow."
)
@Plugin(
examples = {
@Example(
title = "Embeds subflow tasks.",
full = true,
code = """
id: parent_flow
namespace: company.team
tasks:
- id: embed_subflow
type: io.kestra.plugin.core.flow.EmbeddedSubflow
namespace: company.team
flowId: subflow
"""
)
}
)
public class EmbeddedSubflow extends Task implements FlowableTask<EmbeddedSubflow.Output>, ChildFlowInterface {
static final String PLUGIN_FLOW_OUTPUTS_ENABLED = "outputs.enabled";

// FIXME no other choice for now as getErrors() and allChildTasks() has no context
// maybe refacto getErrors() and allChildTasks() to take tenantId as parameter
@Schema(
title = "The tenantId of the subflow to be embedded."
)
@PluginProperty
private String tenantId;

@NotEmpty
@Schema(
title = "The namespace of the subflow to be embedded."
)
@PluginProperty
private String namespace;

@NotNull
@Schema(
title = "The identifier of the subflow to be embedded."
)
@PluginProperty
private String flowId;

@Schema(
title = "The revision of the subflow to be embedded.",
description = "By default, the last, i.e. the most recent, revision of the subflow is embedded."
)
@PluginProperty
@Min(value = 1)
private Integer revision;

// TODO list:
// - inputs ?
// - unique taskId via value ?
// - dedicated run context ?

@Override
@JsonIgnore
public List<Task> getErrors() {
Flow subflow = fetchSubflow();

return subflow.getErrors();
}

@Override
public AbstractGraph tasksTree(Execution execution, TaskRun taskRun, List<String> parentValues) throws IllegalVariableEvaluationException {
Flow subflow = fetchSubflow();

GraphCluster subGraph = new GraphCluster(this, taskRun, parentValues, RelationType.SEQUENTIAL);

GraphUtils.sequential(
subGraph,
subflow.getTasks(),
subflow.getErrors(),
taskRun,
execution
);

return subGraph;
}

@Override
public List<Task> allChildTasks() {
Flow subflow = fetchSubflow();

return Stream
.concat(
subflow.getTasks() != null ? subflow.getTasks().stream() : Stream.empty(),
subflow.getErrors() != null ? subflow.getErrors().stream() : Stream.empty()
)
.toList();
}

@Override
public List<ResolvedTask> childTasks(RunContext runContext, TaskRun parentTaskRun) throws IllegalVariableEvaluationException {
Flow subflow = fetchSubflow(runContext);

return FlowableUtils.resolveTasks(subflow.getTasks(), parentTaskRun);
}

@Override
public List<NextTaskRun> resolveNexts(RunContext runContext, Execution execution, TaskRun parentTaskRun) throws IllegalVariableEvaluationException {
return FlowableUtils.resolveSequentialNexts(
execution,
this.childTasks(runContext, parentTaskRun),
FlowableUtils.resolveTasks(this.getErrors(), parentTaskRun),
parentTaskRun
);
}

@Override
public Output outputs(RunContext runContext) throws Exception {
final Output.OutputBuilder builder = Output.builder();
Flow subflow = fetchSubflow(runContext);

final Optional<Map<String, Object>> subflowOutputs = Optional
.ofNullable(subflow.getOutputs())
.map(outputs -> outputs
.stream()
.collect(Collectors.toMap(
io.kestra.core.models.flows.Output::getId,
io.kestra.core.models.flows.Output::getValue)
)
);

if (subflowOutputs.isPresent()) {
Map<String, Object> outputs = runContext.render(subflowOutputs.get());
FlowInputOutput flowInputOutput = ((DefaultRunContext)runContext).getApplicationContext().getBean(FlowInputOutput.class); // this is hacking
if (subflow.getOutputs() != null && flowInputOutput != null) {
// to be able to use FILE Input, we need the execution info, so we create a fake execution with what's needed here
RunContext.FlowInfo flowInfo = runContext.flowInfo();
String executionId = (String) ((Map<String, Object>) runContext.getVariables().get("execution")).get("id");
Execution fake = Execution.builder()
.id(executionId)
.tenantId(flowInfo.tenantId())
.namespace(flowInfo.namespace())
.flowId(flowInfo.id())
.build();
outputs = flowInputOutput.typedOutputs(subflow, fake, outputs);
}
builder.outputs(outputs);
}
return builder.build();
}

// This method should only be used when getSubflow(RunContext) cannot be used.
private Flow fetchSubflow() {
ApplicationContext applicationContext = ContextHelper.context();
FlowExecutorInterface flowExecutor = applicationContext.getBean(FlowExecutorInterface.class);
FlowWithSource subflow = flowExecutor.findById(tenantId, namespace, flowId, Optional.ofNullable(revision)).orElseThrow(() -> new IllegalArgumentException("Unable to find flow " + namespace + "." + flowId));

if (subflow.isDisabled()) {
throw new IllegalStateException("Cannot execute a flow which is disabled");
}

if (subflow instanceof FlowWithException fwe) {
throw new IllegalStateException("Cannot execute an invalid flow: " + fwe.getException());
}

return subflow;
}

// This method is preferred as getSubflow() as it checks current flow and subflow and allowed namespaces
private Flow fetchSubflow(RunContext runContext) {
// we check that the task tenant is the current tenant to avoid accessing flows from another tenant
if (!Objects.equals(tenantId, runContext.flowInfo().tenantId())) {
throw new IllegalArgumentException("Cannot embeds a flow from a different tenant");
}

ApplicationContext applicationContext = ContextHelper.context();
FlowExecutorInterface flowExecutor = applicationContext.getBean(FlowExecutorInterface.class);
RunContext.FlowInfo flowInfo = runContext.flowInfo();

FlowWithSource flow = flowExecutor.findById(flowInfo.tenantId(), flowInfo.namespace(), flowInfo.id(), Optional.of(flowInfo.revision()))
.orElseThrow(() -> new IllegalArgumentException("Unable to find flow " + flowInfo.namespace() + "." + flowInfo.id()));
return ExecutableUtils.getSubflow(tenantId, namespace, flowId, Optional.ofNullable(revision), flowExecutor, flow);
}

/**
* Ugly hack to provide the ApplicationContext on {{@link #allChildTasks }} &amp; {{@link #tasksTree }}
* We need to inject a way to fetch embedded subflows ...
*/
@Singleton
static class ContextHelper {
@Inject
private ApplicationContext applicationContext;

private static ApplicationContext context;

static ApplicationContext context() {
return ContextHelper.context;
}

@EventListener
void onStartup(final StartupEvent event) {
ContextHelper.context = this.applicationContext;
}
}

@Builder
@Getter
public static class Output implements io.kestra.core.models.tasks.Output {
@Schema(
title = "The extracted outputs from the embedded subflow."
)
private final Map<String, Object> outputs;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package io.kestra.plugin.core.flow;

import io.kestra.core.junit.annotations.ExecuteFlow;
import io.kestra.core.junit.annotations.KestraTest;
import io.kestra.core.junit.annotations.LoadFlows;
import io.kestra.core.models.executions.Execution;
import io.kestra.core.models.flows.State;
import org.junit.jupiter.api.Test;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.*;

@KestraTest(startRunner = true)
class EmbeddedSubflowTest {
@Test
@LoadFlows("flows/valids/minimal.yaml")
@ExecuteFlow("flows/valids/embedded-flow.yaml")
void shouldEmbedTasks(Execution execution) throws Exception {
assertThat(execution.getState().getCurrent(), is(State.Type.SUCCESS));
assertThat(execution.getTaskRunList(), hasSize(2));
assertThat(execution.findTaskRunsByTaskId("embeddedFlow"), notNullValue());
assertThat(execution.findTaskRunsByTaskId("date"), notNullValue());
}

@Test
@LoadFlows({"flows/valids/minimal.yaml", "flows/valids/embedded-flow.yaml"})
@ExecuteFlow("flows/valids/embedded-parent.yaml")
void shouldEmbedTasksRecursively(Execution execution) throws Exception {
assertThat(execution.getState().getCurrent(), is(State.Type.SUCCESS));
assertThat(execution.getTaskRunList(), hasSize(3));
assertThat(execution.findTaskRunsByTaskId("embeddedParent"), notNullValue());
assertThat(execution.findTaskRunsByTaskId("embeddedFlow"), notNullValue());
assertThat(execution.findTaskRunsByTaskId("date"), notNullValue());
}
}
Loading

0 comments on commit 7b9b9bd

Please sign in to comment.