Skip to content

Commit

Permalink
Rebase code
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Apr 11, 2024
1 parent e446f49 commit e138d83
Show file tree
Hide file tree
Showing 9 changed files with 56 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ public AwsConnectorExecutor(Connector connector) {
this.httpClient = MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection);
}


@SuppressWarnings("removal")
@Override
public void invokeRemoteModel(
Expand All @@ -87,7 +86,15 @@ public void invokeRemoteModel(
.request(signRequest(request))
.requestContentPublisher(new SimpleHttpContentPublisher(request))
.responseHandler(
new MLSdkAsyncHttpResponseHandler(countDownLatch, actionListener, parameters, tensorOutputs, connector, scriptService)
new MLSdkAsyncHttpResponseHandler(
countDownLatch,
actionListener,
parameters,
tensorOutputs,
connector,
scriptService,
mlGuard
)
)
.build();
AccessController.doPrivileged((PrivilegedExceptionAction<CompletableFuture<Void>>) () -> httpClient.execute(executeRequest));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLGuard;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.httpclient.MLHttpClientFactory;
Expand Down Expand Up @@ -185,11 +186,15 @@ public static ModelTensors processOutput(
String modelResponse,
Connector connector,
ScriptService scriptService,
Map<String, String> parameters
Map<String, String> parameters,
MLGuard mlGuard
) throws IOException {
if (modelResponse == null) {
throw new IllegalArgumentException("model response is null");
}
if (mlGuard != null && mlGuard.validate(modelResponse, MLGuard.Type.OUTPUT)) {
throw new IllegalArgumentException("guardrails triggered for LLM output");
}
List<ModelTensor> modelTensors = new ArrayList<>();
Optional<ConnectorAction> predictAction = connector.findPredictAction();
if (predictAction.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,15 @@ public void invokeRemoteModel(
.request(request)
.requestContentPublisher(new SimpleHttpContentPublisher(request))
.responseHandler(
new MLSdkAsyncHttpResponseHandler(countDownLatch, actionListener, parameters, tensorOutputs, connector, scriptService)
new MLSdkAsyncHttpResponseHandler(
countDownLatch,
actionListener,
parameters,
tensorOutputs,
connector,
scriptService,
mlGuard
)
)
.build();
AccessController.doPrivileged((PrivilegedExceptionAction<CompletableFuture<Void>>) () -> httpClient.execute(executeRequest));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,15 @@
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.model.MLGuard;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.script.ScriptService;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;

import lombok.Getter;
import lombok.extern.log4j.Log4j2;
Expand All @@ -49,34 +50,38 @@ public class MLSdkAsyncHttpResponseHandler implements SdkAsyncHttpResponseHandle
@Getter
private final StringBuilder responseBody = new StringBuilder();

private WrappedCountDownLatch countDownLatch;
private final WrappedCountDownLatch countDownLatch;

private ActionListener<List<ModelTensors>> actionListener;
private final ActionListener<List<ModelTensors>> actionListener;

private Map<String, String> parameters;
private final Map<String, String> parameters;

private Map<Integer, ModelTensors> tensorOutputs;
private final Map<Integer, ModelTensors> tensorOutputs;

private Connector connector;
private final Connector connector;

private ScriptService scriptService;
private final ScriptService scriptService;

private final static Gson GSON = new GsonBuilder().disableHtmlEscaping().create();
private final MLGuard mlGuard;

private final static Gson GSON = StringUtils.gson;

public MLSdkAsyncHttpResponseHandler(
WrappedCountDownLatch countDownLatch,
ActionListener<List<ModelTensors>> actionListener,
Map<String, String> parameters,
Map<Integer, ModelTensors> tensorOutputs,
Connector connector,
ScriptService scriptService
ScriptService scriptService,
MLGuard mlGuard
) {
this.countDownLatch = countDownLatch;
this.actionListener = actionListener;
this.parameters = parameters;
this.tensorOutputs = tensorOutputs;
this.connector = connector;
this.scriptService = scriptService;
this.mlGuard = mlGuard;
}

@Override
Expand Down Expand Up @@ -128,7 +133,7 @@ private void processResponse(
actionListener.onFailure(new OpenSearchStatusException(REMOTE_SERVICE_ERROR + body, RestStatus.fromCode(statusCode)));
} else {
try {
ModelTensors tensors = processOutput(body, connector, scriptService, parameters);
ModelTensors tensors = processOutput(body, connector, scriptService, parameters, mlGuard);
tensors.setStatusCode(statusCode);
tensorOutputs.put(countDownLatch.getSequence(), tensors);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.model.MLGuard;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ public void processInput_TextDocsInputDataSet_PreprocessFunction_MultiTextDoc()
public void processOutput_NullResponse() throws IOException {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("model response is null");
ConnectorUtils.processOutput(null, null, null, null);
ConnectorUtils.processOutput(null, null, null, null, null);
}

@Test
Expand All @@ -192,7 +192,7 @@ public void processOutput_NoPostprocessFunction_jsonResponse() throws IOExceptio
.build();
String modelResponse =
"{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}";
ModelTensors tensors = ConnectorUtils.processOutput(modelResponse, connector, scriptService, ImmutableMap.of());
ModelTensors tensors = ConnectorUtils.processOutput(modelResponse, connector, scriptService, ImmutableMap.of(), null);
Assert.assertEquals(1, tensors.getMlModelTensors().size());
Assert.assertEquals("response", tensors.getMlModelTensors().get(0).getName());
Assert.assertEquals(4, tensors.getMlModelTensors().get(0).getDataAsMap().size());
Expand Down Expand Up @@ -224,7 +224,7 @@ public void processOutput_PostprocessFunction() throws IOException {
.build();
String modelResponse =
"{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}";
ModelTensors tensors = ConnectorUtils.processOutput(modelResponse, connector, scriptService, ImmutableMap.of());
ModelTensors tensors = ConnectorUtils.processOutput(modelResponse, connector, scriptService, ImmutableMap.of(), null);
Assert.assertEquals(1, tensors.getMlModelTensors().size());
Assert.assertEquals("sentence_embedding", tensors.getMlModelTensors().get(0).getName());
Assert.assertNull(tensors.getMlModelTensors().get(0).getDataAsMap());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ public void setup() {
parameters,
tensorOutputs,
connector,
scriptService
scriptService,
null
);
responseSubscriber = mlSdkAsyncHttpResponseHandler.new MLResponseSubscriber();
}
Expand Down Expand Up @@ -159,7 +160,8 @@ public void test_OnStream_without_postProcessFunction() {
parameters,
tensorOutputs,
noProcessFunctionConnector,
scriptService
scriptService,
null
);
noProcessFunctionMlSdkAsyncHttpResponseHandler.onHeaders(sdkHttpResponse);
noProcessFunctionMlSdkAsyncHttpResponseHandler.onStream(stream);
Expand Down Expand Up @@ -262,15 +264,17 @@ public void test_onComplete_partial_success_exceptionSecond() {
parameters,
tensorOutputs,
connector,
scriptService
scriptService,
null
);
MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler2 = new MLSdkAsyncHttpResponseHandler(
new WrappedCountDownLatch(1, count),
actionListener,
parameters,
tensorOutputs,
connector,
scriptService
scriptService,
null
);
SdkHttpFullResponse sdkHttpResponse1 = mock(SdkHttpFullResponse.class);
when(sdkHttpResponse1.statusCode()).thenReturn(200);
Expand Down Expand Up @@ -327,15 +331,17 @@ public void test_onComplete_partial_success_exceptionFirst() {
parameters,
tensorOutputs,
connector,
scriptService
scriptService,
null
);
MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler2 = new MLSdkAsyncHttpResponseHandler(
new WrappedCountDownLatch(1, count),
actionListener,
parameters,
tensorOutputs,
connector,
scriptService
scriptService,
null
);

SdkHttpFullResponse sdkHttpResponse2 = mock(SdkHttpFullResponse.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ private void runPredict(
}
// Once prediction complete, reduce ML_EXECUTING_TASK_COUNT and update task state
handleAsyncMLTaskComplete(mlTask);
listener.onResponse(new MLTaskResponse(output));
internalListener.onResponse(new MLTaskResponse(output));
}
return;
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@

import java.io.IOException;
import java.nio.file.Path;
import java.util.List;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
Expand Down

0 comments on commit e138d83

Please sign in to comment.