Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds Register and Deploy Model Step for remote model #52

Merged
merged 13 commits into from
Oct 12, 2023
3 changes: 3 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ publishing {
allprojects {
group = opensearch_group
version = "${opensearch_build}"
}

java {
targetCompatibility = JavaVersion.VERSION_11
sourceCompatibility = JavaVersion.VERSION_11
}
Expand Down
38 changes: 38 additions & 0 deletions gradle.properties
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#
# SPDX-License-Identifier: Apache-2.0
#
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
#
# Modifications Copyright OpenSearch Contributors. See
# GitHub history for details.
#


# Enable build caching
org.gradle.caching=true
org.gradle.warning.mode=none
org.gradle.parallel=true
# Workaround for https://github.com/diffplug/spotless/issues/834
org.gradle.jvmargs=-Xmx3g -XX:+HeapDumpOnOutOfMemoryError -Xss2m \
--add-exports jdk.compiler/com.sun.tools.javac.api=ALL-UNNAMED \
--add-exports jdk.compiler/com.sun.tools.javac.file=ALL-UNNAMED \
--add-exports jdk.compiler/com.sun.tools.javac.parser=ALL-UNNAMED \
--add-exports jdk.compiler/com.sun.tools.javac.tree=ALL-UNNAMED \
--add-exports jdk.compiler/com.sun.tools.javac.util=ALL-UNNAMED
options.forkOptions.memoryMaximumSize=2g

# Disable duplicate project id detection
# See https://docs.gradle.org/current/userguide/upgrading_version_6.html#duplicate_project_names_may_cause_publication_to_fail
systemProp.org.gradle.dependency.duplicate.project.detection=false

# Enforce the build to fail on deprecated gradle api usage
systemProp.org.gradle.warning.mode=fail

# forcing to use TLS1.2 to avoid failure in vault
# see https://github.com/hashicorp/vault/issues/8750#issuecomment-631236121
systemProp.jdk.tls.client.protocols=TLSv1.2

# jvm args for faster test execution by default
systemProp.tests.jvm.argline=-XX:TieredStopAtLevel=1 -XX:ReservedCodeCacheSize=64m
2 changes: 1 addition & 1 deletion src/main/java/demo/DemoWorkflowStep.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
CompletableFuture.runAsync(() -> {
try {
Thread.sleep(this.delay);
future.complete(null);
future.complete(WorkflowData.EMPTY);
} catch (InterruptedException e) {
future.completeExceptionally(e);
}
Expand Down
1 change: 0 additions & 1 deletion src/main/java/demo/TemplateParseDemo.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ public static void main(String[] args) throws IOException {
}
ClusterService clusterService = new ClusterService(null, null, null);
Client client = new NodeClient(null, null);

WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client);
ThreadPool threadPool = new ThreadPool(Settings.EMPTY);
WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
*/
package org.opensearch.flowframework.client;

import org.opensearch.client.node.NodeClient;
import org.opensearch.client.Client;
import org.opensearch.ml.client.MachineLearningNodeClient;

/**
Expand All @@ -22,12 +22,12 @@ private MLClient() {}
/**
* Creates machine learning client.
*
* @param nodeClient node client of OpenSearch.
* @param client client of OpenSearch.
* @return machine learning client from ml-commons.
*/
public static MachineLearningNodeClient createMLClient(NodeClient nodeClient) {
public static MachineLearningNodeClient createMLClient(Client client) {
owaiskazi19 marked this conversation as resolved.
Show resolved Hide resolved
if (INSTANCE == null) {
INSTANCE = new MachineLearningNodeClient(nodeClient);
INSTANCE = new MachineLearningNodeClient(client);
}
return INSTANCE;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,13 @@ public class CommonValue {
public static final String GLOBAL_CONTEXT_INDEX = ".plugins-ai-global-context";
public static final String GLOBAL_CONTEXT_INDEX_MAPPING = "mappings/global-context.json";
public static final Integer GLOBAL_CONTEXT_INDEX_VERSION = 1;
public static final String MODEL_ID = "model_id";
public static final String FUNCTION_NAME = "function_name";
public static final String MODEL_NAME = "name";
public static final String MODEL_VERSION = "model_version";
public static final String MODEL_GROUP_ID = "model_group_id";
public static final String DESCRIPTION = "description";
public static final String CONNECTOR_ID = "connector_id";
public static final String MODEL_FORMAT = "model_format";
public static final String MODEL_CONFIG = "model_config";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.flowframework.client.MLClient;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;

import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

import static org.opensearch.flowframework.common.CommonValue.MODEL_ID;

/**
* Step to deploy a model
*/
public class DeployModelStep implements WorkflowStep {
private static final Logger logger = LogManager.getLogger(DeployModelStep.class);

private Client client;
static final String NAME = "deploy_model";

/**
* Instantiate this class
* @param client client to instantiate MLClient
*/
public DeployModelStep(Client client) {
this.client = client;
}

@Override
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {

CompletableFuture<WorkflowData> deployModelFuture = new CompletableFuture<>();

MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient(client);

ActionListener<MLDeployModelResponse> actionListener = new ActionListener<>() {
@Override
public void onResponse(MLDeployModelResponse mlDeployModelResponse) {
logger.info("Model deployment state {}", mlDeployModelResponse.getStatus());
deployModelFuture.complete(
new WorkflowData(Map.ofEntries(Map.entry("deploy_model_status", mlDeployModelResponse.getStatus())))

Check warning on line 54 in src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java#L52-L54

Added lines #L52 - L54 were not covered by tests
);
}

Check warning on line 56 in src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java#L56

Added line #L56 was not covered by tests

@Override
public void onFailure(Exception e) {
logger.error("Model deployment failed");
deployModelFuture.completeExceptionally(e);
}
};

String modelId = null;

for (WorkflowData workflowData : data) {
if (workflowData.getContent().containsKey(MODEL_ID)) {
modelId = (String) workflowData.getContent().get(MODEL_ID);
break;
}
}

Check warning on line 72 in src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java#L72

Added line #L72 was not covered by tests
machineLearningNodeClient.deploy(modelId, actionListener);
owaiskazi19 marked this conversation as resolved.
Show resolved Hide resolved
return deployModelFuture;
}

@Override
public String getName() {
return NAME;

Check warning on line 79 in src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java#L79

Added line #L79 was not covered by tests
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.common.SuppressForbidden;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.transport.task.MLTaskGetResponse;

/**
* Step to get modelID of a registered local model
*/
@SuppressForbidden(reason = "This class is for the future work of registering local model")
public class GetTask {

private static final Logger logger = LogManager.getLogger(GetTask.class);

Check warning on line 26 in src/main/java/org/opensearch/flowframework/workflow/GetTask.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/GetTask.java#L26

Added line #L26 was not covered by tests
private MachineLearningNodeClient machineLearningNodeClient;
private String taskId;

/**
* Instantiate this class
* @param machineLearningNodeClient client to instantiate ml-commons APIs
* @param taskId taskID of the model
*/
public GetTask(MachineLearningNodeClient machineLearningNodeClient, String taskId) {
this.machineLearningNodeClient = machineLearningNodeClient;
this.taskId = taskId;
}

Check warning on line 38 in src/main/java/org/opensearch/flowframework/workflow/GetTask.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/GetTask.java#L35-L38

Added lines #L35 - L38 were not covered by tests

/**
* Invokes get task API of ml-commons
*/
public void getTask() {

ActionListener<MLTask> actionListener = new ActionListener<>() {

Check warning on line 45 in src/main/java/org/opensearch/flowframework/workflow/GetTask.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/GetTask.java#L45

Added line #L45 was not covered by tests
@Override
public void onResponse(MLTask mlTask) {
if (mlTask.getState() == MLTaskState.COMPLETED) {
logger.info("Model registration successful");
MLTaskGetResponse response = MLTaskGetResponse.builder().mlTask(mlTask).build();
logger.info("Response from task {}", response);

Check warning on line 51 in src/main/java/org/opensearch/flowframework/workflow/GetTask.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/GetTask.java#L49-L51

Added lines #L49 - L51 were not covered by tests
}
}

Check warning on line 53 in src/main/java/org/opensearch/flowframework/workflow/GetTask.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/GetTask.java#L53

Added line #L53 was not covered by tests

@Override
public void onFailure(Exception e) {
logger.error("Model registration failed");
}

Check warning on line 58 in src/main/java/org/opensearch/flowframework/workflow/GetTask.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/GetTask.java#L57-L58

Added lines #L57 - L58 were not covered by tests
};

machineLearningNodeClient.getTask(taskId, actionListener);

Check warning on line 61 in src/main/java/org/opensearch/flowframework/workflow/GetTask.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/GetTask.java#L61

Added line #L61 was not covered by tests

}

Check warning on line 63 in src/main/java/org/opensearch/flowframework/workflow/GetTask.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/GetTask.java#L63

Added line #L63 was not covered by tests

}
Loading
Loading