Skip to content

Commit

Permalink
Add more UTs for throw exception cases
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Mar 13, 2024
1 parent feee4f1 commit 16df256
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ default MLOutput predict(MLInput mlInput) {
}

default void asyncPredict(MLInput mlInput, ActionListener<MLTaskResponse> actionListener) {
throw new IllegalStateException("Method is not implemented");
actionListener.onFailure(new IllegalStateException("Method is not implemented"));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,31 +42,36 @@ default void executePredict(MLInput mlInput, ActionListener<MLTaskResponse> acti
actionListener.onResponse(new MLTaskResponse(new ModelTensorOutput(r)));
}, actionListener::onFailure);
Map<Integer, ModelTensors> modelTensors = new ConcurrentHashMap<>();
if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) {
TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) mlInput.getInputDataset();
Tuple<Integer, Integer> calculatedChunkSize = calculateChunkSize(textDocsInputDataSet);
CountDownLatch countDownLatch = new CountDownLatch(calculatedChunkSize.v1());
int sequence = 0;
for (int processedDocs = 0; processedDocs < textDocsInputDataSet.getDocs().size(); processedDocs += calculatedChunkSize.v2()) {
List<String> textDocs = textDocsInputDataSet.getDocs().subList(processedDocs, textDocsInputDataSet.getDocs().size());
try {
if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) {
TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) mlInput.getInputDataset();
Tuple<Integer, Integer> calculatedChunkSize = calculateChunkSize(textDocsInputDataSet);
CountDownLatch countDownLatch = new CountDownLatch(calculatedChunkSize.v1());
int sequence = 0;
for (int processedDocs = 0; processedDocs < textDocsInputDataSet.getDocs().size(); processedDocs += calculatedChunkSize
.v2()) {
List<String> textDocs = textDocsInputDataSet.getDocs().subList(processedDocs, textDocsInputDataSet.getDocs().size());
preparePayloadAndInvokeRemoteModel(
MLInput
.builder()
.algorithm(FunctionName.TEXT_EMBEDDING)
.inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build())
.build(),
modelTensors,
new WrappedCountDownLatch(sequence++, countDownLatch),
tensorActionListener
);
}
} else {
preparePayloadAndInvokeRemoteModel(
MLInput
.builder()
.algorithm(FunctionName.TEXT_EMBEDDING)
.inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build())
.build(),
mlInput,
modelTensors,
new WrappedCountDownLatch(sequence++, countDownLatch),
new WrappedCountDownLatch(0, new CountDownLatch(1)),
tensorActionListener
);
}
} else {
preparePayloadAndInvokeRemoteModel(
mlInput,
modelTensors,
new WrappedCountDownLatch(0, new CountDownLatch(1)),
tensorActionListener
);
} catch (Exception e) {
actionListener.onFailure(e);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ public void executePredict_TextDocsInferenceInput_withStepSize() {
}

@Test
public void executePredict_RemoteInferenceInput_nullHttpClient_throwMLException() throws NoSuchFieldException, IllegalAccessException {
public void executePredict_RemoteInferenceInput_nullHttpClient_throwNPException() throws NoSuchFieldException, IllegalAccessException {
ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
Expand Down Expand Up @@ -319,4 +319,42 @@ public void executePredict_RemoteInferenceInput_nullHttpClient_throwMLException(
Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture());
assert exceptionCaptor.getValue() instanceof NullPointerException;
}

@Test
public void executePredict_RemoteInferenceInput_negativeStepSize_throwIllegalArgumentException() {
ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://openai.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Map<String, String> credential = ImmutableMap
.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
Map<String, String> parameters = ImmutableMap
.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker", "input_docs_processed_step_size", "-1");
Connector connector = AwsConnector
.awsConnectorBuilder()
.name("test connector")
.version("1")
.protocol("http")
.parameters(parameters)
.credential(credential)
.actions(Arrays.asList(predictAction))
.build();
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));
Settings settings = Settings.builder().build();
threadContext = new ThreadContext(settings);
when(executor.getClient()).thenReturn(client);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);

MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build();
executor
.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener);
ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture());
assert exceptionCaptor.getValue() instanceof IllegalArgumentException;
}
}

0 comments on commit 16df256

Please sign in to comment.