Skip to content

Commit

Permalink
Change error message when remote model returns empty and change the b…
Browse files Browse the repository at this point in the history
…ehavior when one of the requests fails

Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Mar 9, 2024
1 parent 23b0d77 commit 9fff702
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ private void processResponse(
ModelTensors tensors;
if (Strings.isBlank(body)) {
log.error("Remote model response body is empty!");
tensors = processErrorResponse(statusCode, "null");
tensors = processErrorResponse(statusCode, "Remote model response is empty!");
} else {
if (statusCode < HttpStatus.SC_OK || statusCode > HttpStatus.SC_MULTIPLE_CHOICES) {
log.error("Remote server returned error code: {}", statusCode);
Expand All @@ -145,33 +145,44 @@ private void reOrderTensorResponses(Map<Integer, ModelTensors> tensorOutputs) {
log.debug("Reordered tensor outputs size is {}", sortedMap.size());
if (tensorOutputs.size() == 1) {
// batch API case
int status = tensorOutputs.get(0).getStatusCode();
ModelTensors singleTensor = tensorOutputs.get(0);
int status = singleTensor.getStatusCode();
if (status == HttpStatus.SC_OK) {
modelTensors.add(tensorOutputs.get(0));
modelTensors.add(singleTensor);
actionListener.onResponse(modelTensors);
} else {
try {
actionListener
.onFailure(
new OpenSearchStatusException(
AccessController
.doPrivileged(
(PrivilegedExceptionAction<String>) () -> GSON
.toJson(tensorOutputs.get(0).getMlModelTensors().get(0).getDataAsMap())
),
RestStatus.fromCode(status)
)
);
} catch (PrivilegedActionException e) {
actionListener.onFailure(new OpenSearchStatusException(e.getMessage(), RestStatus.fromCode(statusCode)));
}
actionListener.onFailure(buildOpenSearchStatusException(singleTensor));
}
} else {
// non batch API.
// non batch API. This follows the previous code logic: when making multiple requests to the remote model,
// if any one of them fails, we return a failure response.
OpenSearchStatusException openSearchStatusException = null;
for (Map.Entry<Integer, ModelTensors> entry : sortedMap.entrySet()) {
if (entry.getValue().getStatusCode() != HttpStatus.SC_OK) {
openSearchStatusException = buildOpenSearchStatusException(entry.getValue());
break;
}
modelTensors.add(entry.getKey(), entry.getValue());
}
actionListener.onResponse(modelTensors);
if (openSearchStatusException != null) {
actionListener.onFailure(openSearchStatusException);
} else {
actionListener.onResponse(modelTensors);
}
}
}

/**
 * Builds the exception handed to the listener for a remote model response that did not succeed.
 *
 * <p>The exception message is the JSON rendering (via {@code GSON}) of the data map of the first tensor in
 * {@code modelTensors}; the serialization runs inside {@link AccessController#doPrivileged}. The REST status
 * is always derived from the status code carried by {@code modelTensors} itself, on both the normal and the
 * fallback path.
 *
 * @param modelTensors the failed response whose first tensor's data map becomes the message
 * @return an {@link OpenSearchStatusException} carrying the serialized payload, or, if serialization fails,
 *         the underlying failure's message with that failure attached as the cause
 */
private OpenSearchStatusException buildOpenSearchStatusException(ModelTensors modelTensors) {
    // Resolve the status from the tensors being reported rather than from the enclosing handler's
    // statusCode field: this helper is also called for an arbitrary failed entry of a multi-request
    // batch, whose status may differ from the handler's own.
    final RestStatus restStatus = RestStatus.fromCode(modelTensors.getStatusCode());
    try {
        final String message = AccessController
            .doPrivileged(
                (PrivilegedExceptionAction<String>) () -> GSON.toJson(modelTensors.getMlModelTensors().get(0).getDataAsMap())
            );
        return new OpenSearchStatusException(message, restStatus);
    } catch (PrivilegedActionException e) {
        // PrivilegedActionException is only a wrapper and usually has no message of its own; report the
        // wrapped failure and keep it as the cause so the stack trace is not lost.
        final Throwable cause = e.getCause() != null ? e.getCause() : e;
        return new OpenSearchStatusException(cause.getMessage(), restStatus, cause);
    }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ public void test_MLResponseSubscriber_onError() {
ArgumentCaptor<Exception> captor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener, times(1)).onFailure(captor.capture());
assert captor.getValue() instanceof OpenSearchStatusException;
assert captor.getValue().getMessage().equals("{\"remote_response\":\"null\"}");
assert captor.getValue().getMessage().equals("{\"remote_response\":\"Remote model response is empty!\"}");
}

@Test
Expand Down Expand Up @@ -198,7 +198,7 @@ public void test_onComplete_partial_success_exceptionSecond() {
+ " -0.0011978149\n"
+ " ]\n"
+ "}";
String response2 = "Failed to predict";
String response2 = "Model current status is: FAILED";
CountDownLatch count = new CountDownLatch(2);
MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler1 = new MLSdkAsyncHttpResponseHandler(
new WrappedCountDownLatch(0, count),
Expand Down Expand Up @@ -243,13 +243,10 @@ public void test_onComplete_partial_success_exceptionSecond() {
}
};
mlSdkAsyncHttpResponseHandler2.onStream(stream2);
ArgumentCaptor<List<ModelTensors>> captor = ArgumentCaptor.forClass(List.class);
verify(actionListener, times(1)).onResponse(captor.capture());
assert captor.getValue().size() == 2;
assert captor.getValue().get(0).getStatusCode() == 200;
assert captor.getValue().get(0).getMlModelTensors().get(0).getData().length == 8;
assert captor.getValue().get(1).getMlModelTensors().get(0).getDataAsMap().size() == 1;
assert captor.getValue().get(1).getMlModelTensors().get(0).getDataAsMap().get("remote_response").equals("Failed to predict");
ArgumentCaptor<OpenSearchStatusException> captor = ArgumentCaptor.forClass(OpenSearchStatusException.class);
verify(actionListener, times(1)).onFailure(captor.capture());
assert captor.getValue().getMessage().equals("{\"remote_response\":\"Model current status is: FAILED\"}");
assert captor.getValue().status().getStatus() == 500;
}

@Test
Expand All @@ -266,7 +263,7 @@ public void test_onComplete_partial_success_exceptionFirst() {
+ " -0.0011978149\n"
+ " ]\n"
+ "}";
String response2 = "Failed to predict";
String response2 = "Model current status is: FAILED";
CountDownLatch count = new CountDownLatch(2);
MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler1 = new MLSdkAsyncHttpResponseHandler(
new WrappedCountDownLatch(0, count),
Expand Down Expand Up @@ -312,13 +309,10 @@ public void test_onComplete_partial_success_exceptionFirst() {
}
};
mlSdkAsyncHttpResponseHandler1.onStream(stream1);
ArgumentCaptor<List<ModelTensors>> captor = ArgumentCaptor.forClass(List.class);
verify(actionListener, times(1)).onResponse(captor.capture());
assert captor.getValue().size() == 2;
assert captor.getValue().get(0).getStatusCode() == 200;
assert captor.getValue().get(0).getMlModelTensors().get(0).getData().length == 8;
assert captor.getValue().get(1).getMlModelTensors().get(0).getDataAsMap().size() == 1;
assert captor.getValue().get(1).getMlModelTensors().get(0).getDataAsMap().get("remote_response").equals("Failed to predict");
ArgumentCaptor<OpenSearchStatusException> captor = ArgumentCaptor.forClass(OpenSearchStatusException.class);
verify(actionListener, times(1)).onFailure(captor.capture());
assert captor.getValue().getMessage().equals("{\"remote_response\":\"Model current status is: FAILED\"}");
assert captor.getValue().status().getStatus() == 500;
}

@Test
Expand All @@ -336,7 +330,14 @@ public void test_onComplete_empty_response_body() {
mlSdkAsyncHttpResponseHandler.onStream(stream);
ArgumentCaptor<List<ModelTensors>> captor = ArgumentCaptor.forClass(List.class);
verify(actionListener, times(1)).onResponse(captor.capture());
assert captor.getValue().get(0).getMlModelTensors().get(0).getDataAsMap().get("remote_response").equals("null");
assert captor
.getValue()
.get(0)
.getMlModelTensors()
.get(0)
.getDataAsMap()
.get("remote_response")
.equals("Remote model response is empty!");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,10 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe
if (mlInput.getAlgorithm() == FunctionName.REMOTE) {
long startTime = System.nanoTime();
ActionListener<MLTaskResponse> trackPredictDurationListener = ActionListener.wrap(output -> {
handleAsyncMLTaskComplete(mlTask);
mlModelManager.trackPredictDuration(modelId, startTime);
internalListener.onResponse(output);
}, internalListener::onFailure);
}, e -> handlePredictFailure(mlTask, internalListener, e, false, modelId));
predictor.asyncPredict(mlInput, trackPredictDurationListener);
} else {
MLOutput output = mlModelManager.trackPredictDuration(modelId, () -> predictor.predict(mlInput));
Expand Down

0 comments on commit 9fff702

Please sign in to comment.