From 954f5f8d54a77b8361842432cb1a646416ea5128 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Fri, 22 Sep 2023 13:46:57 -0700 Subject: [PATCH 01/13] Initial UploadModel integration Signed-off-by: Owais Kazi --- .../workflow/UploadModel/UploadModelStep.java | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 src/main/java/org/opensearch/flowframework/workflow/UploadModel/UploadModelStep.java diff --git a/src/main/java/org/opensearch/flowframework/workflow/UploadModel/UploadModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/UploadModel/UploadModelStep.java new file mode 100644 index 000000000..7d27cf8b3 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/UploadModel/UploadModelStep.java @@ -0,0 +1,22 @@ +package org.opensearch.flowframework.workflow.UploadModel; + +import org.opensearch.flowframework.workflow.WorkflowData; +import org.opensearch.flowframework.workflow.WorkflowStep; + +import java.util.List; +import java.util.concurrent.CompletableFuture; + +public class UploadModelStep implements WorkflowStep { + + private final String NAME = "upload_model_step"; + + @Override + public CompletableFuture execute(List data) { + return null; + } + + @Override + public String getName() { + return NAME; + } +} From 8d673e76476f0374c08ab4289cf2cd7f0784fae2 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Fri, 22 Sep 2023 18:49:40 -0700 Subject: [PATCH 02/13] Implemented Register Model Step Signed-off-by: Owais Kazi --- build.gradle | 3 + .../RegisterModel/RegisterModelStep.java | 145 ++++++++++++++++++ .../workflow/UploadModel/UploadModelStep.java | 22 --- .../RegisterModel/RegisterModelTests.java | 98 ++++++++++++ 4 files changed, 246 insertions(+), 22 deletions(-) create mode 100644 src/main/java/org/opensearch/flowframework/workflow/RegisterModel/RegisterModelStep.java delete mode 100644 src/main/java/org/opensearch/flowframework/workflow/UploadModel/UploadModelStep.java create mode 100644 src/test/java/org/opensearch/flowframework/workflow/RegisterModel/RegisterModelTests.java diff --git a/build.gradle b/build.gradle index 6265e5bd9..a460326c9 100644 --- a/build.gradle +++ b/build.gradle @@ -132,6 +132,9 @@ dependencies { implementation 'org.junit.jupiter:junit-jupiter:5.10.0' implementation "com.google.guava:guava:32.1.3-jre" api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}" + implementation "com.google.code.gson:gson:2.10.1" + compileOnly "com.google.guava:guava:32.1.2-jre" + api group: 'org.opensearch', name:'opensearch-ml-client', version: "2.10.0.0-SNAPSHOT" configurations.all { resolutionStrategy { diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterModel/RegisterModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterModel/RegisterModelStep.java new file mode 100644 index 000000000..906acfe74 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterModel/RegisterModelStep.java @@ -0,0 +1,145 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow.RegisterModel; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.client.Client; +import org.opensearch.client.node.NodeClient; +import org.opensearch.flowframework.client.MLClient; +import org.opensearch.flowframework.workflow.WorkflowData; +import org.opensearch.flowframework.workflow.WorkflowStep; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput; +import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Stream; + +public class RegisterModelStep implements WorkflowStep { + + private static final Logger logger = LogManager.getLogger(RegisterModelStep.class); + + private Client client; + private final String NAME = "register_model_step"; + + private static final String FUNCTION_NAME = "function_name"; + private static final String MODEL_NAME = "model_name"; + private static final String MODEL_VERSION = "model_version"; + private static final String MODEL_GROUP_ID = "model_group_id"; + private static final String MODEL_URL = "url"; + private static final String MODEL_FORMAT = "model_format"; + private static final String MODEL_CONFIG = "model_config"; + private static final String DEPLOY_MODEL = "deploy_model"; + private static final String MODEL_NODES_IDS = "model_nodes_ids"; + + public RegisterModelStep(Client client) { + this.client = client; + } + + @Override + public CompletableFuture execute(List data) { + + CompletableFuture registerModelFuture = new CompletableFuture<>(); + + FunctionName functionName = null; + String modelName = null; + String modelVersion = null; + String modelGroupId = null; + String modelUrl = null; + MLModelFormat modelFormat = null; + String modelConfig = null; + Boolean deployModel = null; + String[] modelNodesId = null; + + for (WorkflowData workflowData : data) { + Map parameters = workflowData.getParams(); + Map content = workflowData.getContent(); + logger.info("Previous step sent params: {}, content: {}", parameters, content); + + for (Entry entry : content.entrySet()) { + switch (entry.getKey()) { + case FUNCTION_NAME: + functionName = (FunctionName) content.get(FUNCTION_NAME); + break; + case MODEL_NAME: + modelName = (String) content.get(MODEL_NAME); + break; + case MODEL_VERSION: + modelVersion = (String) content.get(MODEL_VERSION); + break; + case MODEL_GROUP_ID: + modelGroupId = (String) content.get(MODEL_GROUP_ID); + break; + case MODEL_URL: + modelUrl = (String) content.get(MODEL_URL); + break; + case MODEL_FORMAT: + modelFormat = (MLModelFormat) content.get(MODEL_FORMAT); + break; + case MODEL_CONFIG: + modelConfig = (String) content.get(MODEL_CONFIG); + break; + case DEPLOY_MODEL: + deployModel = (Boolean) content.get(DEPLOY_MODEL); + break; + case MODEL_NODES_IDS: + modelNodesId = (String[]) content.get(MODEL_NODES_IDS); + default: + break; + + } + } + } + + if (Stream.of(functionName, modelName, modelVersion, modelGroupId, modelConfig, modelFormat, deployModel, modelNodesId) + .allMatch(x -> x != null)) { + MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient((NodeClient) client); + // TODO: Add model Config and type cast correctly + MLRegisterModelInput mlInput = MLRegisterModelInput.builder() + .functionName(functionName) + .modelName(modelName) + .version(modelVersion) + .modelGroupId(modelGroupId) + .url(modelUrl) + .modelFormat(modelFormat) + .deployModel(deployModel) + .modelNodeIds(modelNodesId) + .build(); + + MLRegisterModelResponse mlRegisterModelResponse = machineLearningNodeClient.register(mlInput).actionGet(); + + registerModelFuture.complete(new WorkflowData() { + @Override + public Map getContent() { + return Map.ofEntries( + Map.entry("taskId", mlRegisterModelResponse.getTaskId()), + Map.entry("status", mlRegisterModelResponse.getStatus()) + ); + } + }); + + } else { + logger.error("Failed to register model"); + registerModelFuture.completeExceptionally(new IOException("Failed to register model ")); + } + return registerModelFuture; + } + + @Override + public String getName() { + return NAME; + } +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/UploadModel/UploadModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/UploadModel/UploadModelStep.java deleted file mode 100644 index 7d27cf8b3..000000000 --- a/src/main/java/org/opensearch/flowframework/workflow/UploadModel/UploadModelStep.java +++ /dev/null @@ -1,22 +0,0 @@ -package org.opensearch.flowframework.workflow.UploadModel; - -import org.opensearch.flowframework.workflow.WorkflowData; -import org.opensearch.flowframework.workflow.WorkflowStep; - -import java.util.List; -import java.util.concurrent.CompletableFuture; - -public class UploadModelStep implements WorkflowStep { - - private final String NAME = "upload_model_step"; - - @Override - public CompletableFuture execute(List data) { - return null; - } - - @Override - public String getName() { - return NAME; - } -} diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterModel/RegisterModelTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterModel/RegisterModelTests.java new file mode 100644 index 000000000..80839fee6 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterModel/RegisterModelTests.java @@ -0,0 +1,98 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow.RegisterModel; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.flowframework.workflow.WorkflowData; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput; +import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; + +import static org.mockito.Answers.RETURNS_DEEP_STUBS; +import static org.mockito.Mockito.*; + +public class RegisterModelTests extends OpenSearchTestCase { + private WorkflowData inputData = WorkflowData.EMPTY; + + @Mock(answer = RETURNS_DEEP_STUBS) + NodeClient client; + + MachineLearningNodeClient machineLearningNodeClient; + + @Override + public void setUp() throws Exception { + super.setUp(); + + inputData = new WorkflowData() { + @Override + public Map getContent() { + return Map.ofEntries( + Map.entry("function_name", FunctionName.KMEANS), + Map.entry("model_name", "bedrock"), + Map.entry("model_version", "1.0.0"), + Map.entry("model_group_id", "1.0"), + Map.entry("url", "url"), + Map.entry("model_format", MLModelFormat.TORCH_SCRIPT), + Map.entry("deploy_model", true), + Map.entry("model_nodes_ids", new String[] { "foo", "bar", "baz" }) + ); + } + }; + + machineLearningNodeClient = mock(MachineLearningNodeClient.class); + + } + + public void testRegisterModel() { + + FunctionName functionName = FunctionName.KMEANS; + + MLModelConfig config = TextEmbeddingModelConfig.builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); + + MLRegisterModelInput mlInput = MLRegisterModelInput.builder() + .functionName(functionName) + .modelName("testModelName") + .version("testModelVersion") + .modelGroupId("modelGroupId") + .url("url") + .modelFormat(MLModelFormat.ONNX) + .modelConfig(config) + .deployModel(true) + .modelNodeIds(new String[] { "modelNodeIds" }) + .build(); + + RegisterModelStep registerModelStep = new RegisterModelStep(client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelResponse.class); + CompletableFuture future = registerModelStep.execute(List.of(inputData)); + + verify(machineLearningNodeClient, times(1)).register(mlInput); + assertEquals("1", (argumentCaptor.getValue()).getTaskId()); + + assertTrue(future.isDone()); + } + +} From 01fbb1e534bd78953dd766ed1f046dbfe9144337 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Fri, 29 Sep 2023 17:26:45 -0700 Subject: [PATCH 03/13] Integrated register for remote model Signed-off-by: Owais Kazi --- build.gradle | 3 -- .../RegisterModel/RegisterModelStep.java | 48 ++++++++----------- .../RegisterModel/RegisterModelTests.java | 46 +++++++++++------- 3 files changed, 48 insertions(+), 49 deletions(-) diff --git a/build.gradle b/build.gradle index a460326c9..6265e5bd9 100644 --- a/build.gradle +++ b/build.gradle @@ -132,9 +132,6 @@ dependencies { implementation 'org.junit.jupiter:junit-jupiter:5.10.0' implementation "com.google.guava:guava:32.1.3-jre" api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}" - implementation "com.google.code.gson:gson:2.10.1" - compileOnly "com.google.guava:guava:32.1.2-jre" - api group: 'org.opensearch', name:'opensearch-ml-client', version: "2.10.0.0-SNAPSHOT" configurations.all { resolutionStrategy { diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterModel/RegisterModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterModel/RegisterModelStep.java index 906acfe74..1c44d09a5 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterModel/RegisterModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterModel/RegisterModelStep.java @@ -17,6 +17,7 @@ import org.opensearch.flowframework.workflow.WorkflowStep; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; @@ -39,11 +40,10 @@ public class RegisterModelStep implements WorkflowStep { private static final String MODEL_NAME = "model_name"; private static final String MODEL_VERSION = "model_version"; private static final String MODEL_GROUP_ID = "model_group_id"; - private static final String MODEL_URL = "url"; + private static final String DESCRIPTION = "description"; + private static final String CONNECTOR_ID = "connector_id"; private static final String MODEL_FORMAT = "model_format"; private static final String MODEL_CONFIG = "model_config"; - private static final String DEPLOY_MODEL = "deploy_model"; - private static final String MODEL_NODES_IDS = "model_nodes_ids"; public RegisterModelStep(Client client) { this.client = client; @@ -58,11 +58,10 @@ public CompletableFuture execute(List data) { String modelName = null; String modelVersion = null; String modelGroupId = null; - String modelUrl = null; + String connectorId = null; + String description = null; MLModelFormat modelFormat = null; - String modelConfig = null; - Boolean deployModel = null; - String[] modelNodesId = null; + MLModelConfig modelConfig = null; for (WorkflowData workflowData : data) { Map parameters = workflowData.getParams(); @@ -83,20 +82,18 @@ public CompletableFuture execute(List data) { case MODEL_GROUP_ID: modelGroupId = (String) content.get(MODEL_GROUP_ID); break; - case MODEL_URL: - modelUrl = (String) content.get(MODEL_URL); - break; case MODEL_FORMAT: modelFormat = (MLModelFormat) content.get(MODEL_FORMAT); break; case MODEL_CONFIG: - modelConfig = (String) content.get(MODEL_CONFIG); + modelConfig = (MLModelConfig) content.get(MODEL_CONFIG); + break; + case DESCRIPTION: + description = (String) content.get(DESCRIPTION); break; - case DEPLOY_MODEL: - deployModel = (Boolean) content.get(DEPLOY_MODEL); + case CONNECTOR_ID: + connectorId = (String) content.get(CONNECTOR_ID); break; - case MODEL_NODES_IDS: - modelNodesId = (String[]) content.get(MODEL_NODES_IDS); default: break; @@ -104,8 +101,7 @@ public CompletableFuture execute(List data) { } } - if (Stream.of(functionName, modelName, modelVersion, modelGroupId, modelConfig, modelFormat, deployModel, modelNodesId) - .allMatch(x -> x != null)) { + if (Stream.of(functionName, modelName, modelVersion, modelGroupId, description, connectorId).allMatch(x -> x != null)) { MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient((NodeClient) client); // TODO: Add model Config and type cast correctly MLRegisterModelInput mlInput = MLRegisterModelInput.builder() @@ -113,23 +109,17 @@ public CompletableFuture execute(List data) { .modelName(modelName) .version(modelVersion) .modelGroupId(modelGroupId) - .url(modelUrl) .modelFormat(modelFormat) - .deployModel(deployModel) - .modelNodeIds(modelNodesId) + .modelConfig(modelConfig) + .description(description) + .connectorId(connectorId) .build(); MLRegisterModelResponse mlRegisterModelResponse = machineLearningNodeClient.register(mlInput).actionGet(); - registerModelFuture.complete(new WorkflowData() { - @Override - public Map getContent() { - return Map.ofEntries( - Map.entry("taskId", mlRegisterModelResponse.getTaskId()), - Map.entry("status", mlRegisterModelResponse.getStatus()) - ); - } - }); + registerModelFuture.complete( + new WorkflowData(Map.of("taskId", mlRegisterModelResponse.getTaskId(), "status", mlRegisterModelResponse.getStatus())) + ); } else { logger.error("Failed to register model"); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterModel/RegisterModelTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterModel/RegisterModelTests.java index 80839fee6..8929b1495 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterModel/RegisterModelTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterModel/RegisterModelTests.java @@ -41,21 +41,33 @@ public class RegisterModelTests extends OpenSearchTestCase { public void setUp() throws Exception { super.setUp(); - inputData = new WorkflowData() { - @Override - public Map getContent() { - return Map.ofEntries( - Map.entry("function_name", FunctionName.KMEANS), - Map.entry("model_name", "bedrock"), - Map.entry("model_version", "1.0.0"), - Map.entry("model_group_id", "1.0"), - Map.entry("url", "url"), - Map.entry("model_format", MLModelFormat.TORCH_SCRIPT), - Map.entry("deploy_model", true), - Map.entry("model_nodes_ids", new String[] { "foo", "bar", "baz" }) - ); - } - }; + MLModelConfig config = TextEmbeddingModelConfig.builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); + + inputData = new WorkflowData( + Map.of( + "function_name", + FunctionName.KMEANS, + "model_name", + "bedrock", + "model_version", + "1.0.0", + "model_group_id", + "1.0", + "model_format", + MLModelFormat.TORCH_SCRIPT, + "model_config", + config, + "description", + "description", + "connector_id", + "abcdefgh" + ) + ); machineLearningNodeClient = mock(MachineLearningNodeClient.class); @@ -80,8 +92,8 @@ public void testRegisterModel() { .url("url") .modelFormat(MLModelFormat.ONNX) .modelConfig(config) - .deployModel(true) - .modelNodeIds(new String[] { "modelNodeIds" }) + .description("description") + .connectorId("abcdefgh") .build(); RegisterModelStep registerModelStep = new RegisterModelStep(client); From a389c878fa8669485663f45f5a9b2f380a7fb0fc Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Sun, 1 Oct 2023 02:02:34 -0700 Subject: [PATCH 04/13] Integrated deploy model Signed-off-by: Owais Kazi --- .../flowframework/workflow/DeployModel.java | 39 +++++++++++ .../flowframework/workflow/GetTask.java | 56 +++++++++++++++ ...p.java => RegisterAndDeployModelStep.java} | 70 ++++++++++++++----- ...a => RegisterAndDeployModelStepTests.java} | 49 ++++++++----- .../resources/template/finaltemplate.json | 16 ++++- 5 files changed, 196 insertions(+), 34 deletions(-) create mode 100644 src/main/java/org/opensearch/flowframework/workflow/DeployModel.java create mode 100644 src/main/java/org/opensearch/flowframework/workflow/GetTask.java rename src/main/java/org/opensearch/flowframework/workflow/{RegisterModel/RegisterModelStep.java => RegisterAndDeployModelStep.java} (63%) rename src/test/java/org/opensearch/flowframework/workflow/{RegisterModel/RegisterModelTests.java => RegisterAndDeployModelStepTests.java} (61%) diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModel.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModel.java new file mode 100644 index 000000000..8bb6ec232 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModel.java @@ -0,0 +1,39 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; + +public class DeployModel { + private static final Logger logger = LogManager.getLogger(DeployModel.class); + + public void deployModel(MachineLearningNodeClient machineLearningNodeClient, String modelId) { + + ActionListener actionListener = new ActionListener<>() { + @Override + public void onResponse(MLDeployModelResponse mlDeployModelResponse) { + if (mlDeployModelResponse.getStatus() == MLTaskState.COMPLETED.name()) { + logger.info("Model deployed successfully"); + } + } + + @Override + public void onFailure(Exception e) { + logger.error("Model deployment failed"); + } + }; + machineLearningNodeClient.deploy(modelId, actionListener); + + } +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/GetTask.java b/src/main/java/org/opensearch/flowframework/workflow/GetTask.java new file mode 100644 index 000000000..44ae253e2 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/GetTask.java @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.MLTask; +import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.transport.task.MLTaskGetResponse; + +public class GetTask { + + private static final Logger logger = LogManager.getLogger(GetTask.class); + private MachineLearningNodeClient machineLearningNodeClient; + private String taskId; + + public GetTask(MachineLearningNodeClient machineLearningNodeClient, String taskId) { + this.machineLearningNodeClient = machineLearningNodeClient; + this.taskId = taskId; + } + + public void getTask() { + + ActionListener actionListener = new ActionListener<>() { + @Override + public void onResponse(MLTask mlTask) { + if (mlTask.getState() == MLTaskState.COMPLETED) { + logger.info("Model registration successful"); + MLTaskGetResponse response = MLTaskGetResponse.builder().mlTask(mlTask).build(); + logger.info("Response from task {}", response); + } + } + + @Override + public void onFailure(Exception e) { + logger.error("Model registration failed"); + } + }; + + machineLearningNodeClient.getTask(taskId, actionListener); + + } + + /*@Override + public void run() { + getTask(); + }*/ +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterModel/RegisterModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterAndDeployModelStep.java similarity index 63% rename from src/main/java/org/opensearch/flowframework/workflow/RegisterModel/RegisterModelStep.java rename to src/main/java/org/opensearch/flowframework/workflow/RegisterAndDeployModelStep.java index 1c44d09a5..ef2e4e59c 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterModel/RegisterModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterAndDeployModelStep.java @@ -6,21 +6,22 @@ * this file be licensed under the Apache-2.0 license or a * compatible open source license. */ -package org.opensearch.flowframework.workflow.RegisterModel; +package org.opensearch.flowframework.workflow; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.client.Client; import org.opensearch.client.node.NodeClient; +import org.opensearch.core.action.ActionListener; import org.opensearch.flowframework.client.MLClient; -import org.opensearch.flowframework.workflow.WorkflowData; -import org.opensearch.flowframework.workflow.WorkflowStep; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; +import org.opensearch.threadpool.Scheduler; +import org.opensearch.threadpool.ThreadPool; import java.io.IOException; import java.util.List; @@ -29,12 +30,15 @@ import java.util.concurrent.CompletableFuture; import java.util.stream.Stream; -public class RegisterModelStep implements WorkflowStep { +public class RegisterAndDeployModelStep implements WorkflowStep { - private static final Logger logger = LogManager.getLogger(RegisterModelStep.class); + private static final Logger logger = LogManager.getLogger(RegisterAndDeployModelStep.class); private Client client; - private final String NAME = "register_model_step"; + private ThreadPool threadPool; + private volatile Scheduler.Cancellable scheduledFuture; + + static final String NAME = "register_model_step"; private static final String FUNCTION_NAME = "function_name"; private static final String MODEL_NAME = "model_name"; @@ -45,7 +49,7 @@ public class RegisterModelStep implements WorkflowStep { private static final String MODEL_FORMAT = "model_format"; private static final String MODEL_CONFIG = "model_config"; - public RegisterModelStep(Client client) { + public RegisterAndDeployModelStep(Client client) { this.client = client; } @@ -54,6 +58,45 @@ public CompletableFuture execute(List data) { CompletableFuture registerModelFuture = new CompletableFuture<>(); + MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient((NodeClient) client); + + ActionListener actionListener = new ActionListener<>() { + @Override + public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { + + /*ActionListener deployActionListener = new ActionListener<>() { + @Override + public void onResponse(MLDeployModelResponse mlDeployModelResponse) { + if (mlDeployModelResponse.getStatus() == MLTaskState.COMPLETED.name()) { + logger.info("Model deployment successful"); + registerModelFuture.complete( + new WorkflowData(Map.ofEntries(Map.entry("modelId", mlRegisterModelResponse.getModelId()))) + ); + } + } + + @Override + public void onFailure(Exception e) { + logger.error("Model deployment failed"); + registerModelFuture.completeExceptionally(new IOException("Model deployment failed")); + } + }; + machineLearningNodeClient.deploy(mlRegisterModelResponse.getModelId(), deployActionListener);*/ + // scheduledFuture = threadPool.scheduleWithFixedDelay(new GetTask(machineLearningNodeClient, + // mlRegisterModelResponse.getTaskId()), TimeValue.timeValueMillis(10L), ThreadPool.Names.GENERIC); + + DeployModel deployModel = new DeployModel(); + deployModel.deployModel(machineLearningNodeClient, mlRegisterModelResponse.getModelId()); + + } + + @Override + public void onFailure(Exception e) { + logger.error("Failed to register model"); + registerModelFuture.completeExceptionally(new IOException("Failed to register model ")); + } + }; + FunctionName functionName = null; String modelName = null; String modelVersion = null; @@ -102,7 +145,7 @@ public CompletableFuture execute(List data) { } if (Stream.of(functionName, modelName, modelVersion, modelGroupId, description, connectorId).allMatch(x -> x != null)) { - MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient((NodeClient) client); + // TODO: Add model Config and type cast correctly MLRegisterModelInput mlInput = MLRegisterModelInput.builder() .functionName(functionName) @@ -115,16 +158,9 @@ public CompletableFuture execute(List data) { .connectorId(connectorId) .build(); - MLRegisterModelResponse mlRegisterModelResponse = machineLearningNodeClient.register(mlInput).actionGet(); - - registerModelFuture.complete( - new WorkflowData(Map.of("taskId", mlRegisterModelResponse.getTaskId(), "status", mlRegisterModelResponse.getStatus())) - ); - - } else { - logger.error("Failed to register model"); - registerModelFuture.completeExceptionally(new IOException("Failed to register model ")); + machineLearningNodeClient.register(mlInput, actionListener); } + return registerModelFuture; } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterModel/RegisterModelTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterAndDeployModelStepTests.java similarity index 61% rename from src/test/java/org/opensearch/flowframework/workflow/RegisterModel/RegisterModelTests.java rename to src/test/java/org/opensearch/flowframework/workflow/RegisterAndDeployModelStepTests.java index 8929b1495..5d3ba1972 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterModel/RegisterModelTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterAndDeployModelStepTests.java @@ -6,12 +6,14 @@ * this file be licensed under the Apache-2.0 license or a * compatible open source license. */ -package org.opensearch.flowframework.workflow.RegisterModel; +package org.opensearch.flowframework.workflow; import org.opensearch.client.node.NodeClient; -import org.opensearch.flowframework.workflow.WorkflowData; +import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.client.MLClient; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; @@ -22,20 +24,19 @@ import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; -import org.mockito.ArgumentCaptor; -import org.mockito.Mock; +import org.mockito.*; -import static org.mockito.Answers.RETURNS_DEEP_STUBS; import static org.mockito.Mockito.*; -public class RegisterModelTests extends OpenSearchTestCase { +public class RegisterAndDeployModelStepTests extends OpenSearchTestCase { private WorkflowData inputData = WorkflowData.EMPTY; - @Mock(answer = RETURNS_DEEP_STUBS) - NodeClient client; + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private NodeClient nodeClient; - MachineLearningNodeClient machineLearningNodeClient; + private MachineLearningNodeClient machineLearningNodeClient; @Override public void setUp() throws Exception { @@ -69,11 +70,11 @@ public void setUp() throws Exception { ) ); - machineLearningNodeClient = mock(MachineLearningNodeClient.class); + nodeClient = mock(NodeClient.class); } - public void testRegisterModel() { + public void testRegisterModel() throws ExecutionException, InterruptedException { FunctionName functionName = FunctionName.KMEANS; @@ -96,15 +97,31 @@ public void testRegisterModel() { .connectorId("abcdefgh") .build(); - RegisterModelStep registerModelStep = new RegisterModelStep(client); + RegisterAndDeployModelStep registerModelStep = new RegisterAndDeployModelStep(nodeClient); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelResponse.class); + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); CompletableFuture future = registerModelStep.execute(List.of(inputData)); - verify(machineLearningNodeClient, times(1)).register(mlInput); - assertEquals("1", (argumentCaptor.getValue()).getTaskId()); + assertFalse(future.isDone()); + + /*try (MockedStatic mlClientMockedStatic = Mockito.mockStatic(MLClient.class)) { + mlClientMockedStatic + .when(() -> MLClient.createMLClient(any(NodeClient.class))) + .thenReturn(machineLearningNodeClient); + + }*/ + when(spy(MLClient.createMLClient(nodeClient))).thenReturn(machineLearningNodeClient); + verify(machineLearningNodeClient, times(1)).register(any(MLRegisterModelInput.class), actionListenerCaptor.capture()); + actionListenerCaptor.getValue().onResponse(new MLRegisterModelResponse("xyz", MLTaskState.COMPLETED.name(), "abc")); + + assertTrue(future.isDone() && !future.isCompletedExceptionally()); + + Map outputData = Map.of("index-name", "demo"); + + assertTrue(future.isDone() && future.isCompletedExceptionally()); + + assertEquals(outputData, future.get().getContent()); - assertTrue(future.isDone()); } } diff --git a/src/test/resources/template/finaltemplate.json b/src/test/resources/template/finaltemplate.json index fe1a57e36..88a7425f3 100644 --- a/src/test/resources/template/finaltemplate.json +++ b/src/test/resources/template/finaltemplate.json @@ -45,12 +45,26 @@ }], "node_timeout": "10s" } + }, + { + "id": "register_and_deploy", + "type": "register_and_deploy", + "inputs": { + "name": "openAI-gpt-3.5-turbo", + "function_name": "remote", + "description": "test model", + "connector_id": "uDna54oB76l1MtYJF84U" + } } ], "edges": [{ "source": "create_index", "dest": "create_ingest_pipeline" - }] + }, + { + "source": "create_ingest_pipeline", + "dest": "register_and_deploy" + }] }, "ingest": { "user_params": { From 3bcbc6a03eacc031747bcd627ff1c3a1ece86a1a Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Tue, 3 Oct 2023 10:56:08 -0700 Subject: [PATCH 05/13] Separated Register and Deploy Steps Signed-off-by: Owais Kazi --- .../flowframework/workflow/DeployModel.java | 48 +++++++- ...yModelStep.java => RegisterModelStep.java} | 104 +++++++++--------- .../workflow/WorkflowStepFactory.java | 2 + ...Tests.java => RegisterModelStepTests.java} | 60 +++------- src/test/resources/template/demo.json | 40 +++---- .../resources/template/finaltemplate.json | 3 +- 6 files changed, 140 insertions(+), 117 deletions(-) rename src/main/java/org/opensearch/flowframework/workflow/{RegisterAndDeployModelStep.java => RegisterModelStep.java} (63%) rename src/test/java/org/opensearch/flowframework/workflow/{RegisterAndDeployModelStepTests.java => RegisterModelStepTests.java} (64%) diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModel.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModel.java index 8bb6ec232..65ba93e58 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModel.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModel.java @@ -10,30 +10,74 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.client.Client; +import org.opensearch.client.node.NodeClient; import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.client.MLClient; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; -public class DeployModel { +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +public class DeployModel implements WorkflowStep { private static final Logger logger = LogManager.getLogger(DeployModel.class); - public void deployModel(MachineLearningNodeClient machineLearningNodeClient, String modelId) { + private NodeClient nodeClient; + private static final String MODEL_ID = "model_id"; + static final String NAME = "deploy_model"; + + public DeployModel(Client client) { + this.nodeClient = (NodeClient) client; + } + + @Override + public CompletableFuture execute(List data) { + + CompletableFuture deployModelFuture = new CompletableFuture<>(); + + MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient(nodeClient); ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(MLDeployModelResponse mlDeployModelResponse) { if (mlDeployModelResponse.getStatus() == MLTaskState.COMPLETED.name()) { logger.info("Model deployed successfully"); + deployModelFuture.complete( + new WorkflowData(Map.ofEntries(Map.entry("deploy-model-status", mlDeployModelResponse.getStatus()))) + ); } } @Override public void onFailure(Exception e) { logger.error("Model deployment failed"); + deployModelFuture.completeExceptionally(e); } }; + + String modelId = null; + + for (WorkflowData workflowData : data) { + if (workflowData != null) { + Map content = workflowData.getContent(); + + for (Map.Entry entry : content.entrySet()) { + if (entry.getKey() == MODEL_ID) { + modelId = (String) content.get(MODEL_ID); + } + + } + } + } machineLearningNodeClient.deploy(modelId, actionListener); + return deployModelFuture; + } + @Override + public String getName() { + return NAME; } } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterAndDeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java similarity index 63% rename from src/main/java/org/opensearch/flowframework/workflow/RegisterAndDeployModelStep.java rename to src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java index ef2e4e59c..7d11e2f73 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterAndDeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java @@ -21,27 +21,26 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.threadpool.Scheduler; -import org.opensearch.threadpool.ThreadPool; import java.io.IOException; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Map.Entry; import java.util.concurrent.CompletableFuture; import java.util.stream.Stream; -public class RegisterAndDeployModelStep implements WorkflowStep { +public class RegisterModelStep implements WorkflowStep { - private static final Logger logger = LogManager.getLogger(RegisterAndDeployModelStep.class); + private static final Logger logger = LogManager.getLogger(RegisterModelStep.class); - private Client client; - private ThreadPool threadPool; + private NodeClient nodeClient; private volatile Scheduler.Cancellable scheduledFuture; - static final String NAME = "register_model_step"; + static final String NAME = "register_model"; private static final String FUNCTION_NAME = "function_name"; - private static final String MODEL_NAME = "model_name"; + private static final String MODEL_NAME = "name"; private static final String MODEL_VERSION = "model_version"; private static final String MODEL_GROUP_ID = "model_group_id"; private static final String DESCRIPTION = "description"; @@ -49,8 +48,8 @@ public class RegisterAndDeployModelStep implements WorkflowStep { private static final String MODEL_FORMAT = "model_format"; private static final String MODEL_CONFIG = "model_config"; - public RegisterAndDeployModelStep(Client client) { - this.client = client; + public RegisterModelStep(Client client) { + this.nodeClient = (NodeClient) client; } @Override @@ -58,7 +57,7 @@ public CompletableFuture execute(List data) { CompletableFuture registerModelFuture = new CompletableFuture<>(); - MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient((NodeClient) client); + MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient(nodeClient); ActionListener actionListener = new ActionListener<>() { @Override @@ -85,9 +84,17 @@ public void onFailure(Exception e) { // scheduledFuture = threadPool.scheduleWithFixedDelay(new GetTask(machineLearningNodeClient, // mlRegisterModelResponse.getTaskId()), TimeValue.timeValueMillis(10L), ThreadPool.Names.GENERIC); - DeployModel deployModel = new DeployModel(); - deployModel.deployModel(machineLearningNodeClient, mlRegisterModelResponse.getModelId()); - + /*DeployModel deployModel = new DeployModel(); + deployModel.deployModel(machineLearningNodeClient, mlRegisterModelResponse.getModelId());*/ + logger.info("Model registration successful"); + registerModelFuture.complete( + new WorkflowData( + Map.ofEntries( + Map.entry("modelId", mlRegisterModelResponse.getModelId()), + Map.entry("model-register-status", mlRegisterModelResponse.getStatus()) + ) + ) + ); } @Override @@ -107,53 +114,50 @@ public void onFailure(Exception e) { MLModelConfig modelConfig = null; for (WorkflowData workflowData : data) { - Map parameters = workflowData.getParams(); - Map content = workflowData.getContent(); - logger.info("Previous step sent params: {}, content: {}", parameters, content); - - for (Entry entry : content.entrySet()) { - switch (entry.getKey()) { - case FUNCTION_NAME: - functionName = (FunctionName) content.get(FUNCTION_NAME); - break; - case MODEL_NAME: - modelName = (String) content.get(MODEL_NAME); - break; - case MODEL_VERSION: - modelVersion = (String) content.get(MODEL_VERSION); - break; - case MODEL_GROUP_ID: - modelGroupId = (String) content.get(MODEL_GROUP_ID); - break; - case MODEL_FORMAT: - modelFormat = (MLModelFormat) content.get(MODEL_FORMAT); - break; - case MODEL_CONFIG: - modelConfig = (MLModelConfig) content.get(MODEL_CONFIG); - break; - case DESCRIPTION: - description = (String) content.get(DESCRIPTION); - break; - case CONNECTOR_ID: - connectorId = (String) content.get(CONNECTOR_ID); - break; - default: - break; + if (workflowData != null) { + Map content = workflowData.getContent(); + logger.info("Previous step sent content: {}", content); + + for (Entry entry : content.entrySet()) { + switch (entry.getKey()) { + case FUNCTION_NAME: + functionName = FunctionName.from(((String) content.get(FUNCTION_NAME)).toUpperCase(Locale.ROOT)); + break; + case MODEL_NAME: + modelName = (String) content.get(MODEL_NAME); + break; + case MODEL_VERSION: + modelVersion = (String) content.get(MODEL_VERSION); + break; + case MODEL_GROUP_ID: + modelGroupId = (String) content.get(MODEL_GROUP_ID); + break; + case MODEL_FORMAT: + modelFormat = MLModelFormat.from((String) content.get(MODEL_FORMAT)); + break; + case MODEL_CONFIG: + modelConfig = (MLModelConfig) content.get(MODEL_CONFIG); + break; + case DESCRIPTION: + description = (String) content.get(DESCRIPTION); + break; + case CONNECTOR_ID: + connectorId = (String) content.get(CONNECTOR_ID); + break; + default: + break; + } } } } - if (Stream.of(functionName, modelName, modelVersion, modelGroupId, description, connectorId).allMatch(x -> x != null)) { + if (Stream.of(functionName, modelName, description, connectorId).allMatch(x -> x != null)) { // TODO: Add model Config and type cast correctly MLRegisterModelInput mlInput = MLRegisterModelInput.builder() .functionName(functionName) .modelName(modelName) - .version(modelVersion) - .modelGroupId(modelGroupId) - .modelFormat(modelFormat) - .modelConfig(modelConfig) .description(description) .connectorId(connectorId) .build(); diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index 73468f5f6..a9ad01e15 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -38,6 +38,8 @@ public WorkflowStepFactory(ClusterService clusterService, Client client) { private void populateMap(ClusterService clusterService, Client client) { stepMap.put(CreateIndexStep.NAME, new CreateIndexStep(clusterService, client)); stepMap.put(CreateIngestPipelineStep.NAME, new CreateIngestPipelineStep(client)); + stepMap.put(RegisterModelStep.NAME, new RegisterModelStep(client)); + stepMap.put(DeployModel.NAME, new DeployModel(client)); // TODO: These are from the demo class as placeholders, remove when demos are deleted stepMap.put("demo_delay_3", new DemoWorkflowStep(3000)); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterAndDeployModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java similarity index 64% rename from src/test/java/org/opensearch/flowframework/workflow/RegisterAndDeployModelStepTests.java rename to src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java index 5d3ba1972..13bf9b30c 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterAndDeployModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java @@ -8,18 +8,19 @@ */ package org.opensearch.flowframework.workflow; +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; + import org.opensearch.client.node.NodeClient; import org.opensearch.core.action.ActionListener; -import org.opensearch.flowframework.client.MLClient; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.model.MLModelConfig; -import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.client.NoOpNodeClient; import java.util.List; import java.util.Map; @@ -30,13 +31,15 @@ import static org.mockito.Mockito.*; -public class RegisterAndDeployModelStepTests extends OpenSearchTestCase { +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +public class RegisterModelStepTests extends OpenSearchTestCase { private WorkflowData inputData = WorkflowData.EMPTY; @Mock(answer = Answers.RETURNS_DEEP_STUBS) private NodeClient nodeClient; - private MachineLearningNodeClient machineLearningNodeClient; + @Mock + MachineLearningNodeClient machineLearningNodeClient; @Override public void setUp() throws Exception { @@ -49,68 +52,41 @@ public void setUp() throws Exception { .embeddingDimension(100) .build(); + MockitoAnnotations.openMocks(this); + inputData = new WorkflowData( - Map.of( - "function_name", - FunctionName.KMEANS, - "model_name", - "bedrock", - "model_version", - "1.0.0", - "model_group_id", - "1.0", - "model_format", - MLModelFormat.TORCH_SCRIPT, - "model_config", - config, - "description", - "description", - "connector_id", - "abcdefgh" + Map.ofEntries( + Map.entry("function_name", "remote"), + Map.entry("name", "xyz"), + Map.entry("description", "description"), + Map.entry("connector_id", "abcdefg") ) ); - nodeClient = mock(NodeClient.class); - + nodeClient = new NoOpNodeClient("xyz"); } public void testRegisterModel() throws ExecutionException, InterruptedException { - FunctionName functionName = FunctionName.KMEANS; - - MLModelConfig config = TextEmbeddingModelConfig.builder() - .modelType("testModelType") - .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") - .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) - .embeddingDimension(100) - .build(); - MLRegisterModelInput mlInput = MLRegisterModelInput.builder() - .functionName(functionName) + .functionName(FunctionName.from("REMOTE")) .modelName("testModelName") - .version("testModelVersion") - .modelGroupId("modelGroupId") - .url("url") - .modelFormat(MLModelFormat.ONNX) - .modelConfig(config) .description("description") .connectorId("abcdefgh") .build(); - RegisterAndDeployModelStep registerModelStep = new RegisterAndDeployModelStep(nodeClient); + RegisterModelStep registerModelStep = new RegisterModelStep(nodeClient); ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); CompletableFuture future = registerModelStep.execute(List.of(inputData)); - assertFalse(future.isDone()); - /*try (MockedStatic mlClientMockedStatic = Mockito.mockStatic(MLClient.class)) { mlClientMockedStatic .when(() -> MLClient.createMLClient(any(NodeClient.class))) .thenReturn(machineLearningNodeClient); }*/ - when(spy(MLClient.createMLClient(nodeClient))).thenReturn(machineLearningNodeClient); + // when(spy(MLClient.createMLClient(nodeClient))).thenReturn(machineLearningNodeClient); verify(machineLearningNodeClient, times(1)).register(any(MLRegisterModelInput.class), actionListenerCaptor.capture()); actionListenerCaptor.getValue().onResponse(new MLRegisterModelResponse("xyz", MLTaskState.COMPLETED.name(), "abc")); diff --git a/src/test/resources/template/demo.json b/src/test/resources/template/demo.json index e27158bff..103afb92a 100644 --- a/src/test/resources/template/demo.json +++ b/src/test/resources/template/demo.json @@ -9,37 +9,33 @@ "nodes": [ { "id": "fetch_model", - "type": "demo_delay_3" + "type": "demo_delay_3", + "inputs": { + "ingest_key": "ingest_value" + } }, { - "id": "create_ingest_pipeline", - "type": "demo_delay_3" + "id": "register_model", + "type": "register_model", + "inputs": { + "name": "openAI-gpt-3.5-turbo", + "function_name": "remote", + "description": "test model", + "connector_id": "uDna54oB76l1MtYJF84U" + } }, { - "id": "create_search_pipeline", - "type": "demo_delay_5" - }, - { - "id": "create_neural_search_index", - "type": "demo_delay_3" + "id": "deploy_model", + "type": "deploy_model", + "inputs": { + "model_id": "abc" + } } ], "edges": [ { "source": "fetch_model", - "dest": "create_ingest_pipeline" - }, - { - "source": "fetch_model", - "dest": "create_search_pipeline" - }, - { - "source": "create_ingest_pipeline", - "dest": "create_neural_search_index" - }, - { - "source": "create_search_pipeline", - "dest": "create_neural_search_index" + "dest": "deploy_model" } ] } diff --git a/src/test/resources/template/finaltemplate.json b/src/test/resources/template/finaltemplate.json index 88a7425f3..689547d8d 100644 --- a/src/test/resources/template/finaltemplate.json +++ b/src/test/resources/template/finaltemplate.json @@ -16,7 +16,8 @@ }, "user_inputs": { "index_name": "my-knn-index", - "index_settings": {} + "index_settings": {}, + }, "workflows": { "provision": { From 022da959c4beea9a84184f37f796a41bc8591298 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Wed, 4 Oct 2023 14:02:26 -0700 Subject: [PATCH 06/13] Added tests Signed-off-by: Owais Kazi --- .../flowframework/workflow/DeployModel.java | 11 +- .../workflow/RegisterModelStep.java | 25 --- .../workflow/DeployModelTests.java | 81 +++++++ .../workflow/RegisterModelStepTests.java | 39 ++-- src/test/resources/template/demo.json | 49 ++++- .../resources/template/finaltemplate.json | 205 ++++++++---------- 6 files changed, 245 insertions(+), 165 deletions(-) create mode 100644 src/test/java/org/opensearch/flowframework/workflow/DeployModelTests.java diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModel.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModel.java index 65ba93e58..3ca978f6a 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModel.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModel.java @@ -15,7 +15,6 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.flowframework.client.MLClient; import org.opensearch.ml.client.MachineLearningNodeClient; -import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; import java.util.List; @@ -43,12 +42,10 @@ public CompletableFuture execute(List data) { ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(MLDeployModelResponse mlDeployModelResponse) { - if (mlDeployModelResponse.getStatus() == MLTaskState.COMPLETED.name()) { - logger.info("Model deployed successfully"); - deployModelFuture.complete( - new WorkflowData(Map.ofEntries(Map.entry("deploy-model-status", mlDeployModelResponse.getStatus()))) - ); - } + logger.info("Model deployment state {}", mlDeployModelResponse.getStatus()); + deployModelFuture.complete( + new WorkflowData(Map.ofEntries(Map.entry("deploy-model-status", mlDeployModelResponse.getStatus()))) + ); } @Override diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java index 7d11e2f73..050e51ad9 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java @@ -62,30 +62,6 @@ public CompletableFuture execute(List data) { ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { - - /*ActionListener deployActionListener = new ActionListener<>() { - @Override - public void onResponse(MLDeployModelResponse mlDeployModelResponse) { - if (mlDeployModelResponse.getStatus() == MLTaskState.COMPLETED.name()) { - logger.info("Model deployment successful"); - registerModelFuture.complete( - new WorkflowData(Map.ofEntries(Map.entry("modelId", mlRegisterModelResponse.getModelId()))) - ); - } - } - - @Override - public void onFailure(Exception e) { - logger.error("Model deployment failed"); - registerModelFuture.completeExceptionally(new IOException("Model deployment failed")); - } - }; - machineLearningNodeClient.deploy(mlRegisterModelResponse.getModelId(), deployActionListener);*/ - // scheduledFuture = threadPool.scheduleWithFixedDelay(new GetTask(machineLearningNodeClient, - // mlRegisterModelResponse.getTaskId()), TimeValue.timeValueMillis(10L), ThreadPool.Names.GENERIC); - - /*DeployModel deployModel = new DeployModel(); - deployModel.deployModel(machineLearningNodeClient, mlRegisterModelResponse.getModelId());*/ logger.info("Model registration successful"); registerModelFuture.complete( new WorkflowData( @@ -116,7 +92,6 @@ public void onFailure(Exception e) { for (WorkflowData workflowData : data) { if (workflowData != null) { Map content = workflowData.getContent(); - logger.info("Previous step sent content: {}", content); for (Entry entry : content.entrySet()) { switch (entry.getKey()) { diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeployModelTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeployModelTests.java new file mode 100644 index 000000000..548f303cc --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/DeployModelTests.java @@ -0,0 +1,81 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.MLTaskType; +import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.client.NoOpNodeClient; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +import org.mockito.Answers; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import static org.mockito.Mockito.*; + +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +public class DeployModelTests extends OpenSearchTestCase { + + private WorkflowData inputData = WorkflowData.EMPTY; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private NodeClient nodeClient; + + @Mock + MachineLearningNodeClient machineLearningNodeClient; + + @Override + public void setUp() throws Exception { + super.setUp(); + + inputData = new WorkflowData(Map.ofEntries(Map.entry("model_id", "modelId"))); + + MockitoAnnotations.openMocks(this); + + nodeClient = new NoOpNodeClient("xyz"); + + } + + public void testDeployModel() { + + String taskId = "taskId"; + String status = MLTaskState.CREATED.name(); + MLTaskType mlTaskType = MLTaskType.DEPLOY_MODEL; + + DeployModel deployModel = new DeployModel(nodeClient); + + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + MLDeployModelResponse output = new MLDeployModelResponse(taskId, mlTaskType, status); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); + + CompletableFuture future = deployModel.execute(List.of(inputData)); + + // TODO: Find a way to verify the below + // verify(machineLearningNodeClient).deploy(eq(MLRegisterModelInput.class), actionListenerCaptor.capture()); + + assertTrue(future.isCompletedExceptionally()); + + } +} diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java index 13bf9b30c..b1a2b2fc0 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java @@ -27,9 +27,13 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; -import org.mockito.*; +import org.mockito.Answers; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; -import static org.mockito.Mockito.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; @ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class RegisterModelStepTests extends OpenSearchTestCase { @@ -38,6 +42,9 @@ public class RegisterModelStepTests extends OpenSearchTestCase { @Mock(answer = Answers.RETURNS_DEEP_STUBS) private NodeClient nodeClient; + @Mock + ActionListener registerModelActionListener; + @Mock MachineLearningNodeClient machineLearningNodeClient; @@ -68,6 +75,9 @@ public void setUp() throws Exception { public void testRegisterModel() throws ExecutionException, InterruptedException { + String taskId = "abcd"; + String modelId = "efgh"; + String status = MLTaskState.CREATED.name(); MLRegisterModelInput mlInput = MLRegisterModelInput.builder() .functionName(FunctionName.from("REMOTE")) .modelName("testModelName") @@ -78,25 +88,20 @@ public void testRegisterModel() throws ExecutionException, InterruptedException RegisterModelStep registerModelStep = new RegisterModelStep(nodeClient); ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); - CompletableFuture future = registerModelStep.execute(List.of(inputData)); - /*try (MockedStatic mlClientMockedStatic = Mockito.mockStatic(MLClient.class)) { - mlClientMockedStatic - .when(() -> MLClient.createMLClient(any(NodeClient.class))) - .thenReturn(machineLearningNodeClient); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, modelId); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), actionListenerCaptor.capture()); - }*/ - // when(spy(MLClient.createMLClient(nodeClient))).thenReturn(machineLearningNodeClient); - verify(machineLearningNodeClient, times(1)).register(any(MLRegisterModelInput.class), actionListenerCaptor.capture()); - actionListenerCaptor.getValue().onResponse(new MLRegisterModelResponse("xyz", MLTaskState.COMPLETED.name(), "abc")); - - assertTrue(future.isDone() && !future.isCompletedExceptionally()); - - Map outputData = Map.of("index-name", "demo"); + CompletableFuture future = registerModelStep.execute(List.of(inputData)); - assertTrue(future.isDone() && future.isCompletedExceptionally()); + // TODO: Find a way to verify the below + // verify(machineLearningNodeClient).register(any(MLRegisterModelInput.class), actionListenerCaptor.capture()); - assertEquals(outputData, future.get().getContent()); + assertTrue(future.isCompletedExceptionally()); } diff --git a/src/test/resources/template/demo.json b/src/test/resources/template/demo.json index 103afb92a..8719bf2fe 100644 --- a/src/test/resources/template/demo.json +++ b/src/test/resources/template/demo.json @@ -9,14 +9,27 @@ "nodes": [ { "id": "fetch_model", - "type": "demo_delay_3", - "inputs": { - "ingest_key": "ingest_value" - } + "type": "demo_delay_3" + }, + { + "id": "create_index", + "type": "demo_delay_3" + }, + { + "id": "create_ingest_pipeline", + "type": "demo_delay_3" + }, + { + "id": "create_search_pipeline", + "type": "demo_delay_5" + }, + { + "id": "create_neural_search_index", + "type": "demo_delay_3" }, { "id": "register_model", - "type": "register_model", + "type": "demo_delay_3", "inputs": { "name": "openAI-gpt-3.5-turbo", "function_name": "remote", @@ -26,7 +39,7 @@ }, { "id": "deploy_model", - "type": "deploy_model", + "type": "demo_delay_3", "inputs": { "model_id": "abc" } @@ -35,6 +48,30 @@ "edges": [ { "source": "fetch_model", + "dest": "create_index" + }, + { + "source": "create_index", + "dest": "create_ingest_pipeline" + }, + { + "source": "fetch_model", + "dest": "create_search_pipeline" + }, + { + "source": "create_ingest_pipeline", + "dest": "create_neural_search_index" + }, + { + "source": "create_search_pipeline", + "dest": "create_neural_search_index" + }, + { + "source": "create_neural_search_index", + "dest": "register_model" + }, + { + "source": "register_model", "dest": "deploy_model" } ] diff --git a/src/test/resources/template/finaltemplate.json b/src/test/resources/template/finaltemplate.json index 689547d8d..a950f069f 100644 --- a/src/test/resources/template/finaltemplate.json +++ b/src/test/resources/template/finaltemplate.json @@ -1,116 +1,101 @@ { - "name": "semantic-search", - "description": "My semantic search use case", - "use_case": "SEMANTIC_SEARCH", - "operations": [ - "PROVISION", - "INGEST", - "QUERY" - ], - "version": { - "template": "1.0.0", - "compatibility": [ - "2.9.0", - "3.0.0" - ] + "name": "semantic-search", + "description": "My semantic search use case", + "use_case": "SEMANTIC_SEARCH", + "operations": [ + "PROVISION", + "INGEST", + "QUERY" + ], + "version": { + "template": "1.0.0", + "compatibility": [ + "2.9.0", + "3.0.0" + ] + }, + "user_inputs": { + "index_name": "my-knn-index", + "index_settings": {} + }, + "workflows": { + "provision": { + "nodes": [{ + "id": "create_index", + "type": "create_index", + "inputs": { + "name": "user_inputs.index_name", + "settings": "user_inputs.index_settings", + "node_timeout": "10s" + } + }, + { + "id": "create_ingest_pipeline", + "type": "create_ingest_pipeline", + "inputs": { + "name": "my-ingest-pipeline", + "description": "some description", + "processors": [{ + "type": "text_embedding", + "params": { + "model_id": "my-existing-model-id", + "input_field": "text_passage", + "output_field": "text_embedding" + } + }], + "node_timeout": "10s" + } + } + ], + "edges": [{ + "source": "create_index", + "dest": "create_ingest_pipeline" + }] }, - "user_inputs": { - "index_name": "my-knn-index", - "index_settings": {}, - + "ingest": { + "user_params": { + "document": "doc" + }, + "nodes": [{ + "id": "ingest_index", + "type": "ingest_index", + "inputs": { + "index": "user_inputs.index_name", + "ingest_pipeline": "my-ingest-pipeline", + "document": "user_params.document", + "node_timeout": "10s" + } + }] }, - "workflows": { - "provision": { - "nodes": [{ - "id": "create_index", - "type": "create_index", - "inputs": { - "name": "user_inputs.index_name", - "settings": "user_inputs.index_settings", - "node_timeout": "10s" - } - }, - { - "id": "create_ingest_pipeline", - "type": "create_ingest_pipeline", - "inputs": { - "name": "my-ingest-pipeline", - "description": "some description", - "processors": [{ - "type": "text_embedding", - "params": { - "model_id": "my-existing-model-id", - "input_field": "text_passage", - "output_field": "text_embedding" - } - }], - "node_timeout": "10s" - } - }, - { - "id": "register_and_deploy", - "type": "register_and_deploy", - "inputs": { - "name": "openAI-gpt-3.5-turbo", - "function_name": "remote", - "description": "test model", - "connector_id": "uDna54oB76l1MtYJF84U" - } - } - ], - "edges": [{ - "source": "create_index", - "dest": "create_ingest_pipeline" - }, - { - "source": "create_ingest_pipeline", - "dest": "register_and_deploy" - }] - }, - "ingest": { - "user_params": { - "document": "doc" - }, - "nodes": [{ - "id": "ingest_index", - "type": "ingest_index", - "inputs": { - "index": "user_inputs.index_name", - "ingest_pipeline": "my-ingest-pipeline", - "document": "user_params.document", - "node_timeout": "10s" - } - }] - }, - "query": { - "user_params": { - "plaintext": "string" - }, - "nodes": [{ - "id": "transform_query", - "type": "transform_query", - "inputs": { - "template": "neural-search-template-1", - "plaintext": "user_params.plaintext", - "node_timeout": "10s" - } - }, - { - "id": "query_index", - "type": "query_index", - "inputs": { - "index": "user_inputs.index_name", - "query": "{{output-from-prev-step}}.query", - "search_request_processors": [], - "search_response_processors": [], - "node_timeout": "10s" - } - } - ], - "edges": [{ - "source": "transform_query", - "dest": "query_index" - }] + "query": { + "user_params": { + "plaintext": "string" + }, + "nodes": [{ + "id": "transform_query", + "type": "transform_query", + "inputs": { + "template": "neural-search-template-1", + "plaintext": "user_params.plaintext", + "node_timeout": "10s" + } + }, + { + "id": "query_index", + "type": "query_index", + "inputs": { + "index": "user_inputs.index_name", + "query": "{{output-from-prev-step}}.query", + "search_request_processors": [], + "search_response_processors": [], + "node_timeout": "10s" + } } + ], + "edges": [{ + "source": "transform_query", + "dest": "query_index" + }] } + } } From 49078d78682176a46cfb1ee086f4adb967fd5d44 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Wed, 4 Oct 2023 14:25:46 -0700 Subject: [PATCH 07/13] Added NodeClient Signed-off-by: Owais Kazi --- src/main/java/demo/Demo.java | 3 ++- src/main/java/demo/TemplateParseDemo.java | 4 ++-- .../flowframework/FlowFrameworkPlugin.java | 7 ++++++- .../{DeployModel.java => DeployModelStep.java} | 9 ++++----- .../opensearch/flowframework/workflow/GetTask.java | 2 ++ .../flowframework/workflow/RegisterModelStep.java | 7 ++----- .../flowframework/workflow/WorkflowStepFactory.java | 12 +++++++----- .../flowframework/FlowFrameworkPluginTests.java | 5 ++++- ...ployModelTests.java => DeployModelStepTests.java} | 7 ++++--- .../workflow/WorkflowProcessSorterTests.java | 4 +++- 10 files changed, 36 insertions(+), 24 deletions(-) rename src/main/java/org/opensearch/flowframework/workflow/{DeployModel.java => DeployModelStep.java} (93%) rename src/test/java/org/opensearch/flowframework/workflow/{DeployModelTests.java => DeployModelStepTests.java} (91%) diff --git a/src/main/java/demo/Demo.java b/src/main/java/demo/Demo.java index 910f22b14..b4022f92f 100644 --- a/src/main/java/demo/Demo.java +++ b/src/main/java/demo/Demo.java @@ -59,7 +59,8 @@ public static void main(String[] args) throws IOException { } ClusterService clusterService = new ClusterService(null, null, null); Client client = new NodeClient(null, null); - WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client); + NodeClient nodeClient = new NodeClient(null, null); + WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client, nodeClient); ThreadPool threadPool = new ThreadPool(Settings.EMPTY); WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool); diff --git a/src/main/java/demo/TemplateParseDemo.java b/src/main/java/demo/TemplateParseDemo.java index e9bddb749..b1c2cad73 100644 --- a/src/main/java/demo/TemplateParseDemo.java +++ b/src/main/java/demo/TemplateParseDemo.java @@ -55,8 +55,8 @@ public static void main(String[] args) throws IOException { } ClusterService clusterService = new ClusterService(null, null, null); Client client = new NodeClient(null, null); - - WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client); + NodeClient nodeClient = new NodeClient(null, null); + WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client, nodeClient); ThreadPool threadPool = new ThreadPool(Settings.EMPTY); WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool); diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index 5d9692006..ebe945c5a 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -10,8 +10,10 @@ import com.google.common.collect.ImmutableList; import org.opensearch.client.Client; +import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.env.Environment; @@ -51,7 +53,10 @@ public Collection createComponents( IndexNameExpressionResolver indexNameExpressionResolver, Supplier repositoriesServiceSupplier ) { - WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client); + Settings settings = environment.settings(); + // TODO: Creating NodeClient is a temporary fix until we get the NodeClient from the provision API + NodeClient nodeClient = new NodeClient(settings, threadPool); + WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client, nodeClient); WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool); return ImmutableList.of(workflowStepFactory, workflowProcessSorter); diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModel.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java similarity index 93% rename from src/main/java/org/opensearch/flowframework/workflow/DeployModel.java rename to src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index 3ca978f6a..3a3918bce 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModel.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -10,7 +10,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.client.Client; import org.opensearch.client.node.NodeClient; import org.opensearch.core.action.ActionListener; import org.opensearch.flowframework.client.MLClient; @@ -21,15 +20,15 @@ import java.util.Map; import java.util.concurrent.CompletableFuture; -public class DeployModel implements WorkflowStep { - private static final Logger logger = LogManager.getLogger(DeployModel.class); +public class DeployModelStep implements WorkflowStep { + private static final Logger logger = LogManager.getLogger(DeployModelStep.class); private NodeClient nodeClient; private static final String MODEL_ID = "model_id"; static final String NAME = "deploy_model"; - public DeployModel(Client client) { - this.nodeClient = (NodeClient) client; + public DeployModelStep(NodeClient nodeClient) { + this.nodeClient = nodeClient; } @Override diff --git a/src/main/java/org/opensearch/flowframework/workflow/GetTask.java b/src/main/java/org/opensearch/flowframework/workflow/GetTask.java index 44ae253e2..9394b55ec 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/GetTask.java +++ b/src/main/java/org/opensearch/flowframework/workflow/GetTask.java @@ -10,12 +10,14 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.common.SuppressForbidden; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.transport.task.MLTaskGetResponse; +@SuppressForbidden(reason = "This class is for the future work of registering local model") public class GetTask { private static final Logger logger = LogManager.getLogger(GetTask.class); diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java index 050e51ad9..8f59ce5c4 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java @@ -10,7 +10,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.client.Client; import org.opensearch.client.node.NodeClient; import org.opensearch.core.action.ActionListener; import org.opensearch.flowframework.client.MLClient; @@ -20,7 +19,6 @@ import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; -import org.opensearch.threadpool.Scheduler; import java.io.IOException; import java.util.List; @@ -35,7 +33,6 @@ public class RegisterModelStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(RegisterModelStep.class); private NodeClient nodeClient; - private volatile Scheduler.Cancellable scheduledFuture; static final String NAME = "register_model"; @@ -48,8 +45,8 @@ public class RegisterModelStep implements WorkflowStep { private static final String MODEL_FORMAT = "model_format"; private static final String MODEL_CONFIG = "model_config"; - public RegisterModelStep(Client client) { - this.nodeClient = (NodeClient) client; + public RegisterModelStep(NodeClient nodeClient) { + this.nodeClient = nodeClient; } @Override diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index a9ad01e15..8d1e54031 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -9,6 +9,7 @@ package org.opensearch.flowframework.workflow; import org.opensearch.client.Client; +import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.service.ClusterService; import java.util.HashMap; @@ -31,15 +32,16 @@ public class WorkflowStepFactory { * @param clusterService The OpenSearch cluster service * @param client The OpenSearch client steps can use */ - public WorkflowStepFactory(ClusterService clusterService, Client client) { - populateMap(clusterService, client); + + public WorkflowStepFactory(ClusterService clusterService, Client client, NodeClient nodeClient) { + populateMap(clusterService, client, nodeClient); } - private void populateMap(ClusterService clusterService, Client client) { + private void populateMap(ClusterService clusterService, Client client, NodeClient nodeClient) { stepMap.put(CreateIndexStep.NAME, new CreateIndexStep(clusterService, client)); stepMap.put(CreateIngestPipelineStep.NAME, new CreateIngestPipelineStep(client)); - stepMap.put(RegisterModelStep.NAME, new RegisterModelStep(client)); - stepMap.put(DeployModel.NAME, new DeployModel(client)); + stepMap.put(RegisterModelStep.NAME, new RegisterModelStep(nodeClient)); + stepMap.put(DeployModelStep.NAME, new DeployModelStep(nodeClient)); // TODO: These are from the demo class as placeholders, remove when demos are deleted stepMap.put("demo_delay_3", new DemoWorkflowStep(3000)); diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java index d211e3928..47500693a 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java @@ -10,6 +10,8 @@ import org.opensearch.client.AdminClient; import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.env.Environment; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; @@ -41,7 +43,8 @@ public void tearDown() throws Exception { public void testPlugin() throws IOException { try (FlowFrameworkPlugin ffp = new FlowFrameworkPlugin()) { - assertEquals(2, ffp.createComponents(client, null, threadPool, null, null, null, null, null, null, null, null).size()); + Environment env = new Environment(Settings.builder().put("path.home", "dummy").build(), null); + assertEquals(2, ffp.createComponents(client, null, threadPool, null, null, null, env, null, null, null, null).size()); } } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeployModelTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java similarity index 91% rename from src/test/java/org/opensearch/flowframework/workflow/DeployModelTests.java rename to src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java index 548f303cc..44f51f65c 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeployModelTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java @@ -28,10 +28,11 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import static org.mockito.Mockito.*; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; @ThreadLeakScope(ThreadLeakScope.Scope.NONE) -public class DeployModelTests extends OpenSearchTestCase { +public class DeployModelStepTests extends OpenSearchTestCase { private WorkflowData inputData = WorkflowData.EMPTY; @@ -59,7 +60,7 @@ public void testDeployModel() { String status = MLTaskState.CREATED.name(); MLTaskType mlTaskType = MLTaskType.DEPLOY_MODEL; - DeployModel deployModel = new DeployModel(nodeClient); + DeployModelStep deployModel = new DeployModelStep(nodeClient); ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index eab29121d..e8aadf51a 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -10,6 +10,7 @@ import org.opensearch.client.AdminClient; import org.opensearch.client.Client; +import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.flowframework.model.TemplateTestJsonUtil; @@ -60,10 +61,11 @@ public static void setup() { AdminClient adminClient = mock(AdminClient.class); ClusterService clusterService = mock(ClusterService.class); Client client = mock(Client.class); + NodeClient nodeClient = mock(NodeClient.class); when(client.admin()).thenReturn(adminClient); testThreadPool = new TestThreadPool(WorkflowProcessSorterTests.class.getName()); - WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client); + WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client, nodeClient); workflowProcessSorter = new WorkflowProcessSorter(factory, testThreadPool); } From 570173e41b31f4649f276d5ac6dbba4351ac02ed Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Wed, 4 Oct 2023 15:27:40 -0700 Subject: [PATCH 08/13] Added javadocs Signed-off-by: Owais Kazi --- .../flowframework/workflow/DeployModelStep.java | 7 +++++++ .../flowframework/workflow/GetTask.java | 15 +++++++++++---- .../flowframework/workflow/RegisterModelStep.java | 7 +++++++ .../workflow/WorkflowStepFactory.java | 1 + 4 files changed, 26 insertions(+), 4 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index 3a3918bce..43ed8cbb1 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -20,6 +20,9 @@ import java.util.Map; import java.util.concurrent.CompletableFuture; +/** + * Step to deploy a model + */ public class DeployModelStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(DeployModelStep.class); @@ -27,6 +30,10 @@ public class DeployModelStep implements WorkflowStep { private static final String MODEL_ID = "model_id"; static final String NAME = "deploy_model"; + /** + * Instantiate this class + * @param nodeClient client to instantiate MLClient + */ public DeployModelStep(NodeClient nodeClient) { this.nodeClient = nodeClient; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/GetTask.java b/src/main/java/org/opensearch/flowframework/workflow/GetTask.java index 9394b55ec..a3d1caa4e 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/GetTask.java +++ b/src/main/java/org/opensearch/flowframework/workflow/GetTask.java @@ -17,6 +17,9 @@ import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.transport.task.MLTaskGetResponse; +/** + * Step to get modelID of a registered local model + */ @SuppressForbidden(reason = "This class is for the future work of registering local model") public class GetTask { @@ -24,11 +27,19 @@ public class GetTask { private MachineLearningNodeClient machineLearningNodeClient; private String taskId; + /** + * Instantiate this class + * @param machineLearningNodeClient client to instantiate ml-commons APIs + * @param taskId taskID of the model + */ public GetTask(MachineLearningNodeClient machineLearningNodeClient, String taskId) { this.machineLearningNodeClient = machineLearningNodeClient; this.taskId = taskId; } + /** + * Invokes get task API of ml-commons + */ public void getTask() { ActionListener actionListener = new ActionListener<>() { @@ -51,8 +62,4 @@ public void onFailure(Exception e) { } - /*@Override - public void run() { - getTask(); - }*/ } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java index 8f59ce5c4..74a811764 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java @@ -28,6 +28,9 @@ import java.util.concurrent.CompletableFuture; import java.util.stream.Stream; +/** + * Step to register a remote model + */ public class RegisterModelStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(RegisterModelStep.class); @@ -45,6 +48,10 @@ public class RegisterModelStep implements WorkflowStep { private static final String MODEL_FORMAT = "model_format"; private static final String MODEL_CONFIG = "model_config"; + /** + * Instantiate this class + * @param nodeClient client to instantiate MLClient + */ public RegisterModelStep(NodeClient nodeClient) { this.nodeClient = nodeClient; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index 8d1e54031..d15d41341 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -31,6 +31,7 @@ public class WorkflowStepFactory { * * @param clusterService The OpenSearch cluster service * @param client The OpenSearch client steps can use + * @param nodeClient NodeClient to execute transport calls */ public WorkflowStepFactory(ClusterService clusterService, Client client, NodeClient nodeClient) { From d59087fe71e1a1000fbad7172a47c11a0d919693 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Thu, 5 Oct 2023 14:50:27 -0700 Subject: [PATCH 09/13] Addressed PR comments Signed-off-by: Owais Kazi --- .../opensearch/flowframework/workflow/DeployModelStep.java | 2 +- .../opensearch/flowframework/workflow/RegisterModelStep.java | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index 43ed8cbb1..33594583b 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -50,7 +50,7 @@ public CompletableFuture execute(List data) { public void onResponse(MLDeployModelResponse mlDeployModelResponse) { logger.info("Model deployment state {}", mlDeployModelResponse.getStatus()); deployModelFuture.complete( - new WorkflowData(Map.ofEntries(Map.entry("deploy-model-status", mlDeployModelResponse.getStatus()))) + new WorkflowData(Map.ofEntries(Map.entry("deploy_model_status", mlDeployModelResponse.getStatus()))) ); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java index 74a811764..8faf736eb 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java @@ -70,8 +70,8 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { registerModelFuture.complete( new WorkflowData( Map.ofEntries( - Map.entry("modelId", mlRegisterModelResponse.getModelId()), - Map.entry("model-register-status", mlRegisterModelResponse.getStatus()) + Map.entry("model_id", mlRegisterModelResponse.getModelId()), + Map.entry("register_model_status", mlRegisterModelResponse.getStatus()) ) ) ); From 58c3f7afdfb1ef2a90ddc8d0c9e713e41615a235 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Mon, 9 Oct 2023 12:03:11 -0700 Subject: [PATCH 10/13] Addressed PR comments Signed-off-by: Owais Kazi --- src/main/java/demo/DemoWorkflowStep.java | 2 +- .../workflow/DeployModelStep.java | 13 ++-- .../workflow/RegisterModelStep.java | 64 +++++++++---------- 3 files changed, 37 insertions(+), 42 deletions(-) diff --git a/src/main/java/demo/DemoWorkflowStep.java b/src/main/java/demo/DemoWorkflowStep.java index 037d9b6f6..267a8c8ab 100644 --- a/src/main/java/demo/DemoWorkflowStep.java +++ b/src/main/java/demo/DemoWorkflowStep.java @@ -37,7 +37,7 @@ public CompletableFuture execute(List data) { CompletableFuture.runAsync(() -> { try { Thread.sleep(this.delay); - future.complete(null); + future.complete(WorkflowData.EMPTY); } catch (InterruptedException e) { future.completeExceptionally(e); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index 33594583b..82e12af1e 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -64,15 +64,12 @@ public void onFailure(Exception e) { String modelId = null; for (WorkflowData workflowData : data) { - if (workflowData != null) { - Map content = workflowData.getContent(); - - for (Map.Entry entry : content.entrySet()) { - if (entry.getKey() == MODEL_ID) { - modelId = (String) content.get(MODEL_ID); - } - + Map content = workflowData.getContent(); + for (Map.Entry entry : content.entrySet()) { + if (entry.getKey() == MODEL_ID) { + modelId = (String) content.get(MODEL_ID); } + } } machineLearningNodeClient.deploy(modelId, actionListener); diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java index 8faf736eb..f42d52b99 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java @@ -94,39 +94,37 @@ public void onFailure(Exception e) { MLModelConfig modelConfig = null; for (WorkflowData workflowData : data) { - if (workflowData != null) { - Map content = workflowData.getContent(); - - for (Entry entry : content.entrySet()) { - switch (entry.getKey()) { - case FUNCTION_NAME: - functionName = FunctionName.from(((String) content.get(FUNCTION_NAME)).toUpperCase(Locale.ROOT)); - break; - case MODEL_NAME: - modelName = (String) content.get(MODEL_NAME); - break; - case MODEL_VERSION: - modelVersion = (String) content.get(MODEL_VERSION); - break; - case MODEL_GROUP_ID: - modelGroupId = (String) content.get(MODEL_GROUP_ID); - break; - case MODEL_FORMAT: - modelFormat = MLModelFormat.from((String) content.get(MODEL_FORMAT)); - break; - case MODEL_CONFIG: - modelConfig = (MLModelConfig) content.get(MODEL_CONFIG); - break; - case DESCRIPTION: - description = (String) content.get(DESCRIPTION); - break; - case CONNECTOR_ID: - connectorId = (String) content.get(CONNECTOR_ID); - break; - default: - break; - - } + Map content = workflowData.getContent(); + + for (Entry entry : content.entrySet()) { + switch (entry.getKey()) { + case FUNCTION_NAME: + functionName = FunctionName.from(((String) content.get(FUNCTION_NAME)).toUpperCase(Locale.ROOT)); + break; + case MODEL_NAME: + modelName = (String) content.get(MODEL_NAME); + break; + case MODEL_VERSION: + modelVersion = (String) content.get(MODEL_VERSION); + break; + case MODEL_GROUP_ID: + modelGroupId = (String) content.get(MODEL_GROUP_ID); + break; + case MODEL_FORMAT: + modelFormat = MLModelFormat.from((String) content.get(MODEL_FORMAT)); + break; + case MODEL_CONFIG: + modelConfig = (MLModelConfig) content.get(MODEL_CONFIG); + break; + case DESCRIPTION: + description = (String) content.get(DESCRIPTION); + break; + case CONNECTOR_ID: + connectorId = (String) content.get(CONNECTOR_ID); + break; + default: + break; + } } } From 43138782682aae0c713416fffcddfb472f328428 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Mon, 9 Oct 2023 15:03:48 -0700 Subject: [PATCH 11/13] Addressed PR comments - 2 Signed-off-by: Owais Kazi --- src/main/java/demo/Demo.java | 3 +-- src/main/java/demo/TemplateParseDemo.java | 3 +-- .../opensearch/flowframework/FlowFrameworkPlugin.java | 7 +++---- .../flowframework/workflow/WorkflowStepFactory.java | 11 +++++------ .../flowframework/FlowFrameworkPluginTests.java | 5 +---- .../workflow/WorkflowProcessSorterTests.java | 4 +--- 6 files changed, 12 insertions(+), 21 deletions(-) diff --git a/src/main/java/demo/Demo.java b/src/main/java/demo/Demo.java index b4022f92f..910f22b14 100644 --- a/src/main/java/demo/Demo.java +++ b/src/main/java/demo/Demo.java @@ -59,8 +59,7 @@ public static void main(String[] args) throws IOException { } ClusterService clusterService = new ClusterService(null, null, null); Client client = new NodeClient(null, null); - NodeClient nodeClient = new NodeClient(null, null); - WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client, nodeClient); + WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client); ThreadPool threadPool = new ThreadPool(Settings.EMPTY); WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool); diff --git a/src/main/java/demo/TemplateParseDemo.java b/src/main/java/demo/TemplateParseDemo.java index b1c2cad73..a2d0db443 100644 --- a/src/main/java/demo/TemplateParseDemo.java +++ b/src/main/java/demo/TemplateParseDemo.java @@ -55,8 +55,7 @@ public static void main(String[] args) throws IOException { } ClusterService clusterService = new ClusterService(null, null, null); Client client = new NodeClient(null, null); - NodeClient nodeClient = new NodeClient(null, null); - WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client, nodeClient); + WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client); ThreadPool threadPool = new ThreadPool(Settings.EMPTY); WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool); diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index ebe945c5a..fde7028e8 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -13,7 +13,6 @@ import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.Settings; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.env.Environment; @@ -34,6 +33,8 @@ */ public class FlowFrameworkPlugin extends Plugin { + private NodeClient client; + /** * Instantiate this plugin. */ @@ -53,10 +54,8 @@ public Collection createComponents( IndexNameExpressionResolver indexNameExpressionResolver, Supplier repositoriesServiceSupplier ) { - Settings settings = environment.settings(); // TODO: Creating NodeClient is a temporary fix until we get the NodeClient from the provision API - NodeClient nodeClient = new NodeClient(settings, threadPool); - WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client, nodeClient); + WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client); WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool); return ImmutableList.of(workflowStepFactory, workflowProcessSorter); diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index d15d41341..0976f2bfb 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -31,18 +31,17 @@ public class WorkflowStepFactory { * * @param clusterService The OpenSearch cluster service * @param client The OpenSearch client steps can use - * @param nodeClient NodeClient to execute transport calls */ - public WorkflowStepFactory(ClusterService clusterService, Client client, NodeClient nodeClient) { - populateMap(clusterService, client, nodeClient); + public WorkflowStepFactory(ClusterService clusterService, Client client) { + populateMap(clusterService, client); } - private void populateMap(ClusterService clusterService, Client client, NodeClient nodeClient) { + private void populateMap(ClusterService clusterService, Client client) { stepMap.put(CreateIndexStep.NAME, new CreateIndexStep(clusterService, client)); stepMap.put(CreateIngestPipelineStep.NAME, new CreateIngestPipelineStep(client)); - stepMap.put(RegisterModelStep.NAME, new RegisterModelStep(nodeClient)); - stepMap.put(DeployModelStep.NAME, new DeployModelStep(nodeClient)); + stepMap.put(RegisterModelStep.NAME, new RegisterModelStep((NodeClient) client)); + stepMap.put(DeployModelStep.NAME, new DeployModelStep((NodeClient) client)); // TODO: These are from the demo class as placeholders, remove when demos are deleted stepMap.put("demo_delay_3", new DemoWorkflowStep(3000)); diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java index 47500693a..d211e3928 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java @@ -10,8 +10,6 @@ import org.opensearch.client.AdminClient; import org.opensearch.client.Client; -import org.opensearch.common.settings.Settings; -import org.opensearch.env.Environment; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; @@ -43,8 +41,7 @@ public void tearDown() throws Exception { public void testPlugin() throws IOException { try (FlowFrameworkPlugin ffp = new FlowFrameworkPlugin()) { - Environment env = new Environment(Settings.builder().put("path.home", "dummy").build(), null); - assertEquals(2, ffp.createComponents(client, null, threadPool, null, null, null, env, null, null, null, null).size()); + assertEquals(2, ffp.createComponents(client, null, threadPool, null, null, null, null, null, null, null, null).size()); } } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index e8aadf51a..eab29121d 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -10,7 +10,6 @@ import org.opensearch.client.AdminClient; import org.opensearch.client.Client; -import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.flowframework.model.TemplateTestJsonUtil; @@ -61,11 +60,10 @@ public static void setup() { AdminClient adminClient = mock(AdminClient.class); ClusterService clusterService = mock(ClusterService.class); Client client = mock(Client.class); - NodeClient nodeClient = mock(NodeClient.class); when(client.admin()).thenReturn(adminClient); testThreadPool = new TestThreadPool(WorkflowProcessSorterTests.class.getName()); - WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client, nodeClient); + WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client); workflowProcessSorter = new WorkflowProcessSorter(factory, testThreadPool); } From 0d894bbabffaaac8a981d9ee2279df12a5a729ee Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Tue, 10 Oct 2023 11:45:59 -0700 Subject: [PATCH 12/13] Fixed test failure Signed-off-by: Owais Kazi --- build.gradle | 3 ++ gradle.properties | 38 +++++++++++++++++++ .../flowframework/FlowFrameworkPlugin.java | 4 -- .../flowframework/client/MLClient.java | 8 ++-- .../workflow/DeployModelStep.java | 12 +++--- .../workflow/RegisterModelStep.java | 12 +++--- .../workflow/WorkflowStepFactory.java | 5 +-- .../FlowFrameworkPluginTests.java | 12 +++++- .../workflow/DeployModelStepTests.java | 2 +- .../workflow/WorkflowProcessSorterTests.java | 1 + 10 files changed, 72 insertions(+), 25 deletions(-) create mode 100644 gradle.properties diff --git a/build.gradle b/build.gradle index 6265e5bd9..36e2a6605 100644 --- a/build.gradle +++ b/build.gradle @@ -112,6 +112,9 @@ publishing { allprojects { group = opensearch_group version = "${opensearch_build}" +} + +java { targetCompatibility = JavaVersion.VERSION_11 sourceCompatibility = JavaVersion.VERSION_11 } diff --git a/gradle.properties b/gradle.properties new file mode 100644 index 000000000..200c21212 --- /dev/null +++ b/gradle.properties @@ -0,0 +1,38 @@ +# +# SPDX-License-Identifier: Apache-2.0 +# +# The OpenSearch Contributors require contributions made to +# this file be licensed under the Apache-2.0 license or a +# compatible open source license. +# +# Modifications Copyright OpenSearch Contributors. See +# GitHub history for details. +# + + +# Enable build caching +org.gradle.caching=true +org.gradle.warning.mode=none +org.gradle.parallel=true +# Workaround for https://github.com/diffplug/spotless/issues/834 +org.gradle.jvmargs=-Xmx3g -XX:+HeapDumpOnOutOfMemoryError -Xss2m \ + --add-exports jdk.compiler/com.sun.tools.javac.api=ALL-UNNAMED \ + --add-exports jdk.compiler/com.sun.tools.javac.file=ALL-UNNAMED \ + --add-exports jdk.compiler/com.sun.tools.javac.parser=ALL-UNNAMED \ + --add-exports jdk.compiler/com.sun.tools.javac.tree=ALL-UNNAMED \ + --add-exports jdk.compiler/com.sun.tools.javac.util=ALL-UNNAMED +options.forkOptions.memoryMaximumSize=2g + +# Disable duplicate project id detection +# See https://docs.gradle.org/current/userguide/upgrading_version_6.html#duplicate_project_names_may_cause_publication_to_fail +systemProp.org.gradle.dependency.duplicate.project.detection=false + +# Enforce the build to fail on deprecated gradle api usage +systemProp.org.gradle.warning.mode=fail + +# forcing to use TLS1.2 to avoid failure in vault +# see https://github.com/hashicorp/vault/issues/8750#issuecomment-631236121 +systemProp.jdk.tls.client.protocols=TLSv1.2 + +# jvm args for faster test execution by default +systemProp.tests.jvm.argline=-XX:TieredStopAtLevel=1 -XX:ReservedCodeCacheSize=64m diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index fde7028e8..5d9692006 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -10,7 +10,6 @@ import com.google.common.collect.ImmutableList; import org.opensearch.client.Client; -import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; @@ -33,8 +32,6 @@ */ public class FlowFrameworkPlugin extends Plugin { - private NodeClient client; - /** * Instantiate this plugin. */ @@ -54,7 +51,6 @@ public Collection createComponents( IndexNameExpressionResolver indexNameExpressionResolver, Supplier repositoriesServiceSupplier ) { - // TODO: Creating NodeClient is a temporary fix until we get the NodeClient from the provision API WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client); WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool); diff --git a/src/main/java/org/opensearch/flowframework/client/MLClient.java b/src/main/java/org/opensearch/flowframework/client/MLClient.java index a1ef7d61e..977e24588 100644 --- a/src/main/java/org/opensearch/flowframework/client/MLClient.java +++ b/src/main/java/org/opensearch/flowframework/client/MLClient.java @@ -8,7 +8,7 @@ */ package org.opensearch.flowframework.client; -import org.opensearch.client.node.NodeClient; +import org.opensearch.client.Client; import org.opensearch.ml.client.MachineLearningNodeClient; /** @@ -22,12 +22,12 @@ private MLClient() {} /** * Creates machine learning client. * - * @param nodeClient node client of OpenSearch. + * @param client client of OpenSearch. * @return machine learning client from ml-commons. */ - public static MachineLearningNodeClient createMLClient(NodeClient nodeClient) { + public static MachineLearningNodeClient createMLClient(Client client) { if (INSTANCE == null) { - INSTANCE = new MachineLearningNodeClient(nodeClient); + INSTANCE = new MachineLearningNodeClient(client); } return INSTANCE; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index 82e12af1e..81024785d 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -10,7 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.client.node.NodeClient; +import org.opensearch.client.Client; import org.opensearch.core.action.ActionListener; import org.opensearch.flowframework.client.MLClient; import org.opensearch.ml.client.MachineLearningNodeClient; @@ -26,16 +26,16 @@ public class DeployModelStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(DeployModelStep.class); - private NodeClient nodeClient; + private Client client; private static final String MODEL_ID = "model_id"; static final String NAME = "deploy_model"; /** * Instantiate this class - * @param nodeClient client to instantiate MLClient + * @param client client to instantiate MLClient */ - public DeployModelStep(NodeClient nodeClient) { - this.nodeClient = nodeClient; + public DeployModelStep(Client client) { + this.client = client; } @Override @@ -43,7 +43,7 @@ public CompletableFuture execute(List data) { CompletableFuture deployModelFuture = new CompletableFuture<>(); - MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient(nodeClient); + MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient(client); ActionListener actionListener = new ActionListener<>() { @Override diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java index f42d52b99..b9b071a69 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java @@ -10,7 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.client.node.NodeClient; +import org.opensearch.client.Client; import org.opensearch.core.action.ActionListener; import org.opensearch.flowframework.client.MLClient; import org.opensearch.ml.client.MachineLearningNodeClient; @@ -35,7 +35,7 @@ public class RegisterModelStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(RegisterModelStep.class); - private NodeClient nodeClient; + private Client client; static final String NAME = "register_model"; @@ -50,10 +50,10 @@ public class RegisterModelStep implements WorkflowStep { /** * Instantiate this class - * @param nodeClient client to instantiate MLClient + * @param client client to instantiate MLClient */ - public RegisterModelStep(NodeClient nodeClient) { - this.nodeClient = nodeClient; + public RegisterModelStep(Client client) { + this.client = client; } @Override @@ -61,7 +61,7 @@ public CompletableFuture execute(List data) { CompletableFuture registerModelFuture = new CompletableFuture<>(); - MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient(nodeClient); + MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient(client); ActionListener actionListener = new ActionListener<>() { @Override diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index 0976f2bfb..fdb82ef0b 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -9,7 +9,6 @@ package org.opensearch.flowframework.workflow; import org.opensearch.client.Client; -import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.service.ClusterService; import java.util.HashMap; @@ -40,8 +39,8 @@ public WorkflowStepFactory(ClusterService clusterService, Client client) { private void populateMap(ClusterService clusterService, Client client) { stepMap.put(CreateIndexStep.NAME, new CreateIndexStep(clusterService, client)); stepMap.put(CreateIngestPipelineStep.NAME, new CreateIngestPipelineStep(client)); - stepMap.put(RegisterModelStep.NAME, new RegisterModelStep((NodeClient) client)); - stepMap.put(DeployModelStep.NAME, new DeployModelStep((NodeClient) client)); + stepMap.put(RegisterModelStep.NAME, new RegisterModelStep(client)); + stepMap.put(DeployModelStep.NAME, new DeployModelStep(client)); // TODO: These are from the demo class as placeholders, remove when demos are deleted stepMap.put("demo_delay_3", new DemoWorkflowStep(3000)); diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java index d211e3928..ea8a3b520 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginTests.java @@ -10,6 +10,8 @@ import org.opensearch.client.AdminClient; import org.opensearch.client.Client; +import org.opensearch.client.ClusterAdminClient; +import org.opensearch.client.node.NodeClient; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; @@ -23,13 +25,21 @@ public class FlowFrameworkPluginTests extends OpenSearchTestCase { private Client client; + private NodeClient nodeClient; + + private AdminClient adminClient; + + private ClusterAdminClient clusterAdminClient; private ThreadPool threadPool; @Override public void setUp() throws Exception { super.setUp(); client = mock(Client.class); - when(client.admin()).thenReturn(mock(AdminClient.class)); + adminClient = mock(AdminClient.class); + clusterAdminClient = mock(ClusterAdminClient.class); + when(client.admin()).thenReturn(adminClient); + when(adminClient.cluster()).thenReturn(clusterAdminClient); threadPool = new TestThreadPool(FlowFrameworkPluginTests.class.getName()); } diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java index 44f51f65c..87db208c2 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java @@ -74,7 +74,7 @@ public void testDeployModel() { CompletableFuture future = deployModel.execute(List.of(inputData)); // TODO: Find a way to verify the below - // verify(machineLearningNodeClient).deploy(eq(MLRegisterModelInput.class), actionListenerCaptor.capture()); + // verify(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); assertTrue(future.isCompletedExceptionally()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index eab29121d..e8ada0e15 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -60,6 +60,7 @@ public static void setup() { AdminClient adminClient = mock(AdminClient.class); ClusterService clusterService = mock(ClusterService.class); Client client = mock(Client.class); + when(client.admin()).thenReturn(adminClient); testThreadPool = new TestThreadPool(WorkflowProcessSorterTests.class.getName()); From e0cd1cfdedbd1322fff1a96642208f0c3a486207 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Wed, 11 Oct 2023 15:12:02 -0700 Subject: [PATCH 13/13] Addressed PR comments Signed-off-by: Owais Kazi --- .../flowframework/common/CommonValue.java | 9 +++++++++ .../workflow/DeployModelStep.java | 12 +++++------- .../workflow/RegisterModelStep.java | 18 +++++++++--------- 3 files changed, 23 insertions(+), 16 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index a8fdf2929..0bf8ae890 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -19,4 +19,13 @@ public class CommonValue { public static final String GLOBAL_CONTEXT_INDEX = ".plugins-ai-global-context"; public static final String GLOBAL_CONTEXT_INDEX_MAPPING = "mappings/global-context.json"; public static final Integer GLOBAL_CONTEXT_INDEX_VERSION = 1; + public static final String MODEL_ID = "model_id"; + public static final String FUNCTION_NAME = "function_name"; + public static final String MODEL_NAME = "name"; + public static final String MODEL_VERSION = "model_version"; + public static final String MODEL_GROUP_ID = "model_group_id"; + public static final String DESCRIPTION = "description"; + public static final String CONNECTOR_ID = "connector_id"; + public static final String MODEL_FORMAT = "model_format"; + public static final String MODEL_CONFIG = "model_config"; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index 81024785d..e4c9b1a14 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -20,6 +20,8 @@ import java.util.Map; import java.util.concurrent.CompletableFuture; +import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; + /** * Step to deploy a model */ @@ -27,7 +29,6 @@ public class DeployModelStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(DeployModelStep.class); private Client client; - private static final String MODEL_ID = "model_id"; static final String NAME = "deploy_model"; /** @@ -64,12 +65,9 @@ public void onFailure(Exception e) { String modelId = null; for (WorkflowData workflowData : data) { - Map content = workflowData.getContent(); - for (Map.Entry entry : content.entrySet()) { - if (entry.getKey() == MODEL_ID) { - modelId = (String) content.get(MODEL_ID); - } - + if (workflowData.getContent().containsKey(MODEL_ID)) { + modelId = (String) workflowData.getContent().get(MODEL_ID); + break; } } machineLearningNodeClient.deploy(modelId, actionListener); diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java index b9b071a69..b97c56d57 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java @@ -28,6 +28,15 @@ import java.util.concurrent.CompletableFuture; import java.util.stream.Stream; +import static org.opensearch.flowframework.common.CommonValue.CONNECTOR_ID; +import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION; +import static org.opensearch.flowframework.common.CommonValue.FUNCTION_NAME; +import static org.opensearch.flowframework.common.CommonValue.MODEL_CONFIG; +import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT; +import static org.opensearch.flowframework.common.CommonValue.MODEL_GROUP_ID; +import static org.opensearch.flowframework.common.CommonValue.MODEL_NAME; +import static org.opensearch.flowframework.common.CommonValue.MODEL_VERSION; + /** * Step to register a remote model */ @@ -39,15 +48,6 @@ public class RegisterModelStep implements WorkflowStep { static final String NAME = "register_model"; - private static final String FUNCTION_NAME = "function_name"; - private static final String MODEL_NAME = "name"; - private static final String MODEL_VERSION = "model_version"; - private static final String MODEL_GROUP_ID = "model_group_id"; - private static final String DESCRIPTION = "description"; - private static final String CONNECTOR_ID = "connector_id"; - private static final String MODEL_FORMAT = "model_format"; - private static final String MODEL_CONFIG = "model_config"; - /** * Instantiate this class * @param client client to instantiate MLClient