diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java
index 3958382998..dcf8aa8a01 100644
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java
@@ -60,7 +60,8 @@ public class MLSdkAsyncHttpResponseHandler implements SdkAsyncHttpResponseHandle
 
     private ScriptService scriptService;
 
-    private Gson gson = new GsonBuilder().disableHtmlEscaping().create();
+    // Gson instances are thread-safe, so a single shared constant replaces the per-handler field.
+    private static final Gson GSON = new GsonBuilder().disableHtmlEscaping().create();
 
     public MLSdkAsyncHttpResponseHandler(
         WrappedCountDownLatch countDownLatch,
@@ -156,7 +157,7 @@ private void reOrderTensorResponses(Map tensorOutputs) {
                             new OpenSearchStatusException(
                                 AccessController
                                     .doPrivileged(
-                                        (PrivilegedExceptionAction<String>) () -> gson
+                                        (PrivilegedExceptionAction<String>) () -> GSON
                                             .toJson(tensorOutputs.get(0).getMlModelTensors().get(0).getDataAsMap())
                                     ),
                                 RestStatus.fromCode(status)
diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java
index ce89358b47..5e928a3f96 100644
--- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java
+++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java
@@ -238,4 +238,48 @@ public void executePredict_TextDocsInferenceInput() {
         executor
             .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener);
     }
+
+    /**
+     * Runs {@code executePredict} on a connector whose {@code input_docs_processed_step_size} parameter is "2",
+     * first with three input docs and then with two. Smoke test only: no outcome is verified on the listener.
+     */
+    @Test
+    public void executePredict_TextDocsInferenceInput_withStepSize() {
+        ConnectorAction predictAction = ConnectorAction
+            .builder()
+            .actionType(ConnectorAction.ActionType.PREDICT)
+            .method("POST")
+            .url("http://openai.com/mock")
+            .requestBody("{\"input\": ${parameters.input}}")
+            .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT)
+            .build();
+        Map<String, String> credential = ImmutableMap
+            .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
+        Map<String, String> parameters = ImmutableMap
+            .of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker", "input_docs_processed_step_size", "2");
+        Connector connector = AwsConnector
+            .awsConnectorBuilder()
+            .name("test connector")
+            .version("1")
+            .protocol("http")
+            .parameters(parameters)
+            .credential(credential)
+            .actions(Arrays.asList(predictAction))
+            .build();
+        connector.decrypt((c) -> encryptor.decrypt(c));
+        AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));
+        Settings settings = Settings.builder().build();
+        threadContext = new ThreadContext(settings);
+        when(executor.getClient()).thenReturn(client);
+        when(client.threadPool()).thenReturn(threadPool);
+        when(threadPool.getThreadContext()).thenReturn(threadContext);
+
+        MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build();
+        executor
+            .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener);
+
+        MLInputDataset inputDataSet1 = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2")).build();
+        executor
+            .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet1).build(), actionListener);
+    }
 }
diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java
index 60338e4ccd..6428ceb5e2 100644
--- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java
+++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java
@@ -61,6 +61,7 @@
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ExecutorService;
+import java.util.function.Supplier;
 
 import org.junit.Before;
 import org.junit.Ignore;
@@ -1121,6 +1122,46 @@ public void testRegisterModelMeta_FailedToInitIndexIfPresent() {
         verify(actionListener).onFailure(argumentCaptor.capture());
     }
 
+    /**
+     * Verifies that the {@code Supplier} overload of {@code trackPredictDuration} records a positive
+     * inference duration for the given model id on {@code modelCacheHelper}.
+     */
+    public void test_trackPredictDuration_sync() {
+        Supplier<String> mockResult = () -> {
+            try {
+                // A brief pause is enough for a strictly positive measured duration and keeps the test fast.
+                Thread.sleep(10);
+            } catch (InterruptedException e) {
+                // Restore the interrupt flag so callers further up can still observe the interruption.
+                Thread.currentThread().interrupt();
+                throw new RuntimeException(e);
+            }
+            return "test";
+        };
+        String modelId = "test_model";
+        modelManager.trackPredictDuration(modelId, mockResult);
+        ArgumentCaptor<String> modelIdCaptor = ArgumentCaptor.forClass(String.class);
+        ArgumentCaptor<Double> durationCaptor = ArgumentCaptor.forClass(Double.class);
+        verify(modelCacheHelper).addModelInferenceDuration(modelIdCaptor.capture(), durationCaptor.capture());
+        assert modelIdCaptor.getValue().equals(modelId);
+        assert durationCaptor.getValue() > 0;
+    }
+
+    /**
+     * Verifies that the start-time overload of {@code trackPredictDuration} records a positive
+     * inference duration for the given model id on {@code modelCacheHelper}.
+     */
+    public void test_trackPredictDuration_async() {
+        String modelId = "test_model";
+        long startTime = System.nanoTime();
+        modelManager.trackPredictDuration(modelId, startTime);
+        ArgumentCaptor<String> modelIdCaptor = ArgumentCaptor.forClass(String.class);
+        ArgumentCaptor<Double> durationCaptor = ArgumentCaptor.forClass(Double.class);
+        verify(modelCacheHelper).addModelInferenceDuration(modelIdCaptor.capture(), durationCaptor.capture());
+        assert modelIdCaptor.getValue().equals(modelId);
+        assert durationCaptor.getValue() > 0;
+    }
+
     private void setupForModelMeta() {
         doAnswer(invocation -> {
             ActionListener listener = invocation.getArgument(1);