Skip to content

Commit

Permalink
add multi tenacy support (#489) (#492)
Browse files Browse the repository at this point in the history
Signed-off-by: Hailong Cui <[email protected]>
(cherry picked from commit 854c64d)
  • Loading branch information
Hailong-am authored Jan 26, 2025
1 parent 735f9f4 commit b201795
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.agent.tools;

import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.io.IOException;
Expand Down Expand Up @@ -165,6 +166,7 @@ public CreateAnomalyDetectorTool(Client client, String modelId, String modelType
*/
@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
final String tenantId = parameters.get(TENANT_ID_FIELD);
Map<String, String> enrichedParameters = enrichParameters(parameters);
String indexName = enrichedParameters.get("index");
if (Strings.isNullOrEmpty(indexName)) {
Expand Down Expand Up @@ -227,7 +229,9 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
.build();
ActionRequest request = new MLPredictionTaskRequest(
modelId,
MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()
MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(),
null,
tenantId
);

client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(mlTaskResponse -> {
Expand Down
7 changes: 6 additions & 1 deletion src/main/java/org/opensearch/agent/tools/PPLTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.agent.tools;

import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;

import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
Expand Down Expand Up @@ -170,6 +172,7 @@ public PPLTool(
@SuppressWarnings("unchecked")
@Override
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
final String tenantId = parameters.get(TENANT_ID_FIELD);
extractFromChatParameters(parameters);
String indexName = getIndexNameFromParameters(parameters);
if (StringUtils.isBlank(indexName)) {
Expand Down Expand Up @@ -206,7 +209,9 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
.build();
ActionRequest request = new MLPredictionTaskRequest(
modelId,
MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()
MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(),
null,
tenantId
);
client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(mlTaskResponse -> {
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlTaskResponse.getOutput();
Expand Down
5 changes: 4 additions & 1 deletion src/main/java/org/opensearch/agent/tools/RAGTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.apache.commons.lang3.StringEscapeUtils.escapeJson;
import static org.opensearch.agent.tools.AbstractRetrieverTool.*;
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.common.utils.StringUtils.toJson;

Expand Down Expand Up @@ -94,6 +95,8 @@ public Object parse(Object o) {
}

public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
final String tenantId = parameters.get(TENANT_ID_FIELD);

String input = null;

if (!this.validate(parameters)) {
Expand Down Expand Up @@ -145,7 +148,7 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)

RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(tmpParameters).build();
MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build();
ActionRequest request = new MLPredictionTaskRequest(this.inferenceModelId, mlInput);
ActionRequest request = new MLPredictionTaskRequest(this.inferenceModelId, mlInput, null, tenantId);

client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(resp -> {
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) resp.getOutput();
Expand Down

0 comments on commit b201795

Please sign in to comment.