Skip to content

Commit

Permalink
Add more UTs to increase code coverage
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Feb 20, 2024
1 parent 9aefe47 commit 689c2cb
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public class MLSdkAsyncHttpResponseHandler implements SdkAsyncHttpResponseHandle

private ScriptService scriptService;

private Gson gson = new GsonBuilder().disableHtmlEscaping().create();
private final static Gson GSON = new GsonBuilder().disableHtmlEscaping().create();

public MLSdkAsyncHttpResponseHandler(
WrappedCountDownLatch countDownLatch,
Expand Down Expand Up @@ -156,7 +156,7 @@ private void reOrderTensorResponses(Map<Integer, ModelTensors> tensorOutputs) {
new OpenSearchStatusException(
AccessController
.doPrivileged(
(PrivilegedExceptionAction<String>) () -> gson
(PrivilegedExceptionAction<String>) () -> GSON
.toJson(tensorOutputs.get(0).getMlModelTensors().get(0).getDataAsMap())
),
RestStatus.fromCode(status)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,4 +238,44 @@ public void executePredict_TextDocsInferenceInput() {
executor
.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener);
}

@Test
public void executePredict_TextDocsInferenceInput_withStepSize() {
ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://openai.com/mock")
.requestBody("{\"input\": ${parameters.input}}")
.preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT)
.build();
Map<String, String> credential = ImmutableMap
.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
Map<String, String> parameters = ImmutableMap
.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker", "input_docs_processed_step_size", "2");
Connector connector = AwsConnector
.awsConnectorBuilder()
.name("test connector")
.version("1")
.protocol("http")
.parameters(parameters)
.credential(credential)
.actions(Arrays.asList(predictAction))
.build();
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));
Settings settings = Settings.builder().build();
threadContext = new ThreadContext(settings);
when(executor.getClient()).thenReturn(client);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);

MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build();
executor
.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener);

MLInputDataset inputDataSet1 = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2")).build();
executor
.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet1).build(), actionListener);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.function.Supplier;

import org.junit.Before;
import org.junit.Ignore;
Expand Down Expand Up @@ -1121,6 +1122,35 @@ public void testRegisterModelMeta_FailedToInitIndexIfPresent() {
verify(actionListener).onFailure(argumentCaptor.capture());
}

public void test_trackPredictDuration_sync() {
Supplier<String> mockResult = () -> {
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
return "test";
};
String modelId = "test_model";
modelManager.trackPredictDuration(modelId, mockResult);
ArgumentCaptor<String> modelIdCaptor = ArgumentCaptor.forClass(String.class);
ArgumentCaptor<Double> durationCaptor = ArgumentCaptor.forClass(Double.class);
verify(modelCacheHelper).addModelInferenceDuration(modelIdCaptor.capture(), durationCaptor.capture());
assert modelIdCaptor.getValue().equals(modelId);
assert durationCaptor.getValue() > 0;
}

public void test_trackPredictDuration_async() {
String modelId = "test_model";
long startTime = System.nanoTime();
modelManager.trackPredictDuration(modelId, startTime);
ArgumentCaptor<String> modelIdCaptor = ArgumentCaptor.forClass(String.class);
ArgumentCaptor<Double> durationCaptor = ArgumentCaptor.forClass(Double.class);
verify(modelCacheHelper).addModelInferenceDuration(modelIdCaptor.capture(), durationCaptor.capture());
assert modelIdCaptor.getValue().equals(modelId);
assert durationCaptor.getValue() > 0;
}

private void setupForModelMeta() {
doAnswer(invocation -> {
ActionListener<IndexResponse> listener = invocation.getArgument(1);
Expand Down

0 comments on commit 689c2cb

Please sign in to comment.