From 9fff7025a9d031360b18e017239468269d125db2 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Sat, 9 Mar 2024 15:46:15 +0800 Subject: [PATCH] Change error message when remote model returns empty and change the behavior when one of the requests fails Signed-off-by: zane-neo --- .../remote/MLSdkAsyncHttpResponseHandler.java | 51 +++++++++++-------- .../MLSdkAsyncHttpResponseHandlerTest.java | 37 +++++++------- .../ml/task/MLPredictTaskRunner.java | 3 +- 3 files changed, 52 insertions(+), 39 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java index 73398ccf2f..4743a26101 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java @@ -121,7 +121,7 @@ private void processResponse( ModelTensors tensors; if (Strings.isBlank(body)) { log.error("Remote model response body is empty!"); - tensors = processErrorResponse(statusCode, "null"); + tensors = processErrorResponse(statusCode, "Remote model response is empty!"); } else { if (statusCode < HttpStatus.SC_OK || statusCode > HttpStatus.SC_MULTIPLE_CHOICES) { log.error("Remote server returned error code: {}", statusCode); @@ -145,33 +145,44 @@ private void reOrderTensorResponses(Map tensorOutputs) { log.debug("Reordered tensor outputs size is {}", sortedMap.size()); if (tensorOutputs.size() == 1) { // batch API case - int status = tensorOutputs.get(0).getStatusCode(); + ModelTensors singleTensor = tensorOutputs.get(0); + int status = singleTensor.getStatusCode(); if (status == HttpStatus.SC_OK) { - modelTensors.add(tensorOutputs.get(0)); + modelTensors.add(singleTensor); actionListener.onResponse(modelTensors); } else { - try { - actionListener - .onFailure( - new 
OpenSearchStatusException( - AccessController - .doPrivileged( - (PrivilegedExceptionAction) () -> GSON - .toJson(tensorOutputs.get(0).getMlModelTensors().get(0).getDataAsMap()) - ), - RestStatus.fromCode(status) - ) - ); - } catch (PrivilegedActionException e) { - actionListener.onFailure(new OpenSearchStatusException(e.getMessage(), RestStatus.fromCode(statusCode))); - } + actionListener.onFailure(buildOpenSearchStatusException(singleTensor)); } } else { - // non batch API. + // non batch API. This is to follow the previously code logic. Previously when making multiple requests to remote model, + // either one fails, we will return a failure response. + OpenSearchStatusException openSearchStatusException = null; for (Map.Entry entry : sortedMap.entrySet()) { + if (entry.getValue().getStatusCode() != HttpStatus.SC_OK) { + openSearchStatusException = buildOpenSearchStatusException(entry.getValue()); + break; + } modelTensors.add(entry.getKey(), entry.getValue()); } - actionListener.onResponse(modelTensors); + if (openSearchStatusException != null) { + actionListener.onFailure(openSearchStatusException); + } else { + actionListener.onResponse(modelTensors); + } + } + } + + private OpenSearchStatusException buildOpenSearchStatusException(ModelTensors modelTensors) { + try { + return new OpenSearchStatusException( + AccessController + .doPrivileged( + (PrivilegedExceptionAction) () -> GSON.toJson(modelTensors.getMlModelTensors().get(0).getDataAsMap()) + ), + RestStatus.fromCode(modelTensors.getStatusCode()) + ); + } catch (PrivilegedActionException e) { + return new OpenSearchStatusException(e.getMessage(), RestStatus.fromCode(statusCode)); } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java index 42d3cdc736..5635a8b4fb 100644 --- 
a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java @@ -150,7 +150,7 @@ public void test_MLResponseSubscriber_onError() { ArgumentCaptor captor = ArgumentCaptor.forClass(Exception.class); verify(actionListener, times(1)).onFailure(captor.capture()); assert captor.getValue() instanceof OpenSearchStatusException; - assert captor.getValue().getMessage().equals("{\"remote_response\":\"null\"}"); + assert captor.getValue().getMessage().equals("{\"remote_response\":\"Remote model response is empty!\"}"); } @Test @@ -198,7 +198,7 @@ public void test_onComplete_partial_success_exceptionSecond() { + " -0.0011978149\n" + " ]\n" + "}"; - String response2 = "Failed to predict"; + String response2 = "Model current status is: FAILED"; CountDownLatch count = new CountDownLatch(2); MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler1 = new MLSdkAsyncHttpResponseHandler( new WrappedCountDownLatch(0, count), @@ -243,13 +243,10 @@ public void test_onComplete_partial_success_exceptionSecond() { } }; mlSdkAsyncHttpResponseHandler2.onStream(stream2); - ArgumentCaptor> captor = ArgumentCaptor.forClass(List.class); - verify(actionListener, times(1)).onResponse(captor.capture()); - assert captor.getValue().size() == 2; - assert captor.getValue().get(0).getStatusCode() == 200; - assert captor.getValue().get(0).getMlModelTensors().get(0).getData().length == 8; - assert captor.getValue().get(1).getMlModelTensors().get(0).getDataAsMap().size() == 1; - assert captor.getValue().get(1).getMlModelTensors().get(0).getDataAsMap().get("remote_response").equals("Failed to predict"); + ArgumentCaptor captor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener, times(1)).onFailure(captor.capture()); + assert captor.getValue().getMessage().equals("{\"remote_response\":\"Model current status 
is: FAILED\"}"); + assert captor.getValue().status().getStatus() == 500; } @Test @@ -266,7 +263,7 @@ public void test_onComplete_partial_success_exceptionFirst() { + " -0.0011978149\n" + " ]\n" + "}"; - String response2 = "Failed to predict"; + String response2 = "Model current status is: FAILED"; CountDownLatch count = new CountDownLatch(2); MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler1 = new MLSdkAsyncHttpResponseHandler( new WrappedCountDownLatch(0, count), @@ -312,13 +309,10 @@ public void test_onComplete_partial_success_exceptionFirst() { } }; mlSdkAsyncHttpResponseHandler1.onStream(stream1); - ArgumentCaptor> captor = ArgumentCaptor.forClass(List.class); - verify(actionListener, times(1)).onResponse(captor.capture()); - assert captor.getValue().size() == 2; - assert captor.getValue().get(0).getStatusCode() == 200; - assert captor.getValue().get(0).getMlModelTensors().get(0).getData().length == 8; - assert captor.getValue().get(1).getMlModelTensors().get(0).getDataAsMap().size() == 1; - assert captor.getValue().get(1).getMlModelTensors().get(0).getDataAsMap().get("remote_response").equals("Failed to predict"); + ArgumentCaptor captor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener, times(1)).onFailure(captor.capture()); + assert captor.getValue().getMessage().equals("{\"remote_response\":\"Model current status is: FAILED\"}"); + assert captor.getValue().status().getStatus() == 500; } @Test @@ -336,7 +330,14 @@ public void test_onComplete_empty_response_body() { mlSdkAsyncHttpResponseHandler.onStream(stream); ArgumentCaptor> captor = ArgumentCaptor.forClass(List.class); verify(actionListener, times(1)).onResponse(captor.capture()); - assert captor.getValue().get(0).getMlModelTensors().get(0).getDataAsMap().get("remote_response").equals("null"); + assert captor + .getValue() + .get(0) + .getMlModelTensors() + .get(0) + .getDataAsMap() + .get("remote_response") + .equals("Remote model response is empty!"); } 
@Test diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index 82acfd6637..09bb37ddaa 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -222,9 +222,10 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe if (mlInput.getAlgorithm() == FunctionName.REMOTE) { long startTime = System.nanoTime(); ActionListener trackPredictDurationListener = ActionListener.wrap(output -> { + handleAsyncMLTaskComplete(mlTask); mlModelManager.trackPredictDuration(modelId, startTime); internalListener.onResponse(output); - }, internalListener::onFailure); + }, e -> handlePredictFailure(mlTask, internalListener, e, false, modelId)); predictor.asyncPredict(mlInput, trackPredictDurationListener); } else { MLOutput output = mlModelManager.trackPredictDuration(modelId, () -> predictor.predict(mlInput));