From d398d48e7679e8e06bcfbc4320955bd52957b241 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Fri, 26 Jan 2024 10:04:33 +0800 Subject: [PATCH] Fix no response issue in functional test Signed-off-by: zane-neo --- .../algorithms/remote/MLSdkAsyncHttpResponseHandler.java | 8 ++++++-- .../engine/algorithms/remote/RemoteConnectorExecutor.java | 8 ++++---- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java index 6717d001e4..1d0ccb0715 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java @@ -27,6 +27,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.TreeMap; import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processOutput; @@ -85,6 +86,7 @@ public void onStream(Publisher stream) { subscription.request(Long.MAX_VALUE); } @Override public void onError(Throwable t) { + countDownLatch.getCountDownLatch().countDown(); log.error("Error on receiving response body from remote: {}", t instanceof NullPointerException ? "NullPointerException" : t.getMessage(), t); errorMsg.add("Error on receiving response body from remote: " + (t instanceof NullPointerException ? "NullPointerException" : t.getMessage())); if (countDownLatch.getCountDownLatch().getCount() == 0) { @@ -96,15 +98,16 @@ public void onStream(Publisher stream) { @Override public void onComplete() { - countDownLatch.getCountDownLatch().countDown(); try { String fullResponseBody = responseBody.toString(); processResponse(statusCode, fullResponseBody, parameters, tensorOutputs); + countDownLatch.getCountDownLatch().countDown(); if (countDownLatch.getCountDownLatch().getCount() == 0) { log.debug("All responses received, calling action listener to return final results."); actionListener.onResponse(reOrderTensorResponses(tensorOutputs)); } } catch (Throwable e) { + countDownLatch.getCountDownLatch().countDown(); log.error("Error on processing response from remote: {}", e instanceof NullPointerException ? "NullPointerException" : e.getMessage(), e); errorMsg.add("Error on receiving response from remote: " + (e instanceof NullPointerException ? "NullPointerException" : e.getMessage())); if (countDownLatch.getCountDownLatch().getCount() == 0) { @@ -142,7 +145,8 @@ private void processResponse(Integer statusCode, String body, Map reOrderTensorResponses(Map tensorOutputs) { List modelTensors = new ArrayList<>(); - for (Map.Entry entry : tensorOutputs.entrySet()) { + TreeMap sortedMap = new TreeMap<>(tensorOutputs); + for (Map.Entry entry : sortedMap.entrySet()) { modelTensors.add(entry.getKey(), entry.getValue()); } return modelTensors; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index c3de753020..887eb7bed5 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -40,13 +40,13 @@ default void executePredict(MLInput mlInput, ActionListener acti ActionListener> tensorActionListener = ActionListener.wrap(r -> { actionListener.onResponse(new MLTaskResponse(new ModelTensorOutput(r))); }, actionListener::onFailure); - Map modelTensorsQueue = new ConcurrentHashMap<>(); + Map modelTensors = new ConcurrentHashMap<>(); if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) { TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) mlInput.getInputDataset(); Tuple calculatedChunkSize = calculateChunkSize(textDocsInputDataSet); CountDownLatch countDownLatch = new CountDownLatch(calculatedChunkSize.v1()); int sequence = 0; - for (int processedDocs = 0; processedDocs < calculatedChunkSize.v1(); processedDocs = processedDocs + calculatedChunkSize.v2()) { + for (int processedDocs = 0; processedDocs < textDocsInputDataSet.getDocs().size(); processedDocs += calculatedChunkSize.v2()) { List textDocs = textDocsInputDataSet.getDocs().subList(processedDocs, textDocsInputDataSet.getDocs().size()); preparePayloadAndInvokeRemoteModel( MLInput @@ -54,10 +54,10 @@ default void executePredict(MLInput mlInput, ActionListener acti .algorithm(FunctionName.TEXT_EMBEDDING) .inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()) .build(), - modelTensorsQueue, new WrappedCountDownLatch(sequence++, countDownLatch) , tensorActionListener); + modelTensors, new WrappedCountDownLatch(sequence++, countDownLatch) , tensorActionListener); } } else { - preparePayloadAndInvokeRemoteModel(mlInput, modelTensorsQueue, new WrappedCountDownLatch(0, new CountDownLatch(1)), tensorActionListener); + preparePayloadAndInvokeRemoteModel(mlInput, modelTensors, new WrappedCountDownLatch(0, new CountDownLatch(1)), tensorActionListener); } }