Skip to content

Commit

Permalink
fix duplicate response in channel
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Apr 26, 2024
1 parent d12fb86 commit 5e6a7f6
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public void invokeRemoteModel(
Map<String, String> parameters,
String payload,
Map<Integer, ModelTensors> tensorOutputs,
WrappedCountDownLatch countDownLatch,
ExecutionContext countDownLatch,
ActionListener<List<ModelTensors>> actionListener
) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,17 @@
package org.opensearch.ml.engine.algorithms.remote;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;

import lombok.AllArgsConstructor;
import lombok.Data;

@Data
@AllArgsConstructor
public class WrappedCountDownLatch {
public class ExecutionContext {
// Index of this chunk within a split-batch request; used as the key when
// storing this chunk's tensors so responses can be re-ordered once all
// chunks have completed.
private int sequence;
// Counted down once per received response; a count of zero means every
// response in the batch has arrived. Should never be null.
private CountDownLatch countDownLatch;
// This is to hold any exception thrown in a split-batch request; one holder
// is shared by all chunks and set via compareAndSet so only the first
// failure is recorded and reported. Should never be null.
private AtomicReference<Exception> exceptionHolder;
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public void invokeRemoteModel(
Map<String, String> parameters,
String payload,
Map<Integer, ModelTensors> tensorOutputs,
WrappedCountDownLatch countDownLatch,
ExecutionContext countDownLatch,
ActionListener<List<ModelTensors>> actionListener
) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public class MLSdkAsyncHttpResponseHandler implements SdkAsyncHttpResponseHandle
@Getter
private final StringBuilder responseBody = new StringBuilder();

private final WrappedCountDownLatch countDownLatch;
private final ExecutionContext executionContext;

private final ActionListener<List<ModelTensors>> actionListener;

Expand All @@ -64,20 +64,18 @@ public class MLSdkAsyncHttpResponseHandler implements SdkAsyncHttpResponseHandle

private final MLGuard mlGuard;

private Exception exception;

private final static Gson GSON = StringUtils.gson;

public MLSdkAsyncHttpResponseHandler(
WrappedCountDownLatch countDownLatch,
ExecutionContext executionContext,
ActionListener<List<ModelTensors>> actionListener,
Map<String, String> parameters,
Map<Integer, ModelTensors> tensorOutputs,
Connector connector,
ScriptService scriptService,
MLGuard mlGuard
) {
this.countDownLatch = countDownLatch;
this.executionContext = executionContext;
this.actionListener = actionListener;
this.parameters = parameters;
this.tensorOutputs = tensorOutputs;
Expand Down Expand Up @@ -114,22 +112,28 @@ private void processResponse(
) {
if (Strings.isBlank(body)) {
log.error("Remote model response body is empty!");
if (exception == null)
exception = new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST);
if (executionContext.getExceptionHolder().get() == null)
executionContext
.getExceptionHolder()
.compareAndSet(null, new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST));
} else {
if (statusCode < HttpStatus.SC_OK || statusCode > HttpStatus.SC_MULTIPLE_CHOICES) {
log.error("Remote server returned error code: {}", statusCode);
if (exception == null)
exception = new OpenSearchStatusException(REMOTE_SERVICE_ERROR + body, RestStatus.fromCode(statusCode));
if (executionContext.getExceptionHolder().get() == null)
executionContext
.getExceptionHolder()
.compareAndSet(null, new OpenSearchStatusException(REMOTE_SERVICE_ERROR + body, RestStatus.fromCode(statusCode)));
} else {
try {
ModelTensors tensors = processOutput(body, connector, scriptService, parameters, mlGuard);
tensors.setStatusCode(statusCode);
tensorOutputs.put(countDownLatch.getSequence(), tensors);
tensorOutputs.put(executionContext.getSequence(), tensors);
} catch (Exception e) {
log.error("Failed to process response body: {}", body, e);
if (exception == null)
exception = new MLException("Fail to execute predict in aws connector", e);
if (executionContext.getExceptionHolder().get() == null)
executionContext
.getExceptionHolder()
.compareAndSet(null, new MLException("Fail to execute predict in aws connector", e));
}
}
}
Expand Down Expand Up @@ -217,16 +221,16 @@ public void onComplete() {

private void response(Map<Integer, ModelTensors> tensors) {
processResponse(statusCode, responseBody.toString(), parameters, tensorOutputs);
countDownLatch.getCountDownLatch().countDown();
executionContext.getCountDownLatch().countDown();
// when countdown's count equals to 0 means all responses are received.
if (countDownLatch.getCountDownLatch().getCount() == 0) {
if (exception != null) {
actionListener.onFailure(exception);
if (executionContext.getCountDownLatch().getCount() == 0) {
if (executionContext.getExceptionHolder().get() != null) {
actionListener.onFailure(executionContext.getExceptionHolder().get());
return;
}
reOrderTensorResponses(tensors);
} else {
log.debug("Not all responses received, left response count is: " + countDownLatch.getCountDownLatch().getCount());
log.debug("Not all responses received, left response count is: " + executionContext.getCountDownLatch().getCount());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.client.Client;
Expand Down Expand Up @@ -42,8 +43,9 @@ default void executePredict(MLInput mlInput, ActionListener<MLTaskResponse> acti
ActionListener<List<ModelTensors>> tensorActionListener = ActionListener.wrap(r -> {
actionListener.onResponse(new MLTaskResponse(new ModelTensorOutput(r)));
}, actionListener::onFailure);
Map<Integer, ModelTensors> modelTensors = new ConcurrentHashMap<>();
try {
Map<Integer, ModelTensors> modelTensors = new ConcurrentHashMap<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) {
TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) mlInput.getInputDataset();
Tuple<Integer, Integer> calculatedChunkSize = calculateChunkSize(textDocsInputDataSet);
Expand All @@ -59,15 +61,15 @@ default void executePredict(MLInput mlInput, ActionListener<MLTaskResponse> acti
.inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build())
.build(),
modelTensors,
new WrappedCountDownLatch(sequence++, countDownLatch),
new ExecutionContext(sequence++, countDownLatch, exceptionHolder),
tensorActionListener
);
}
} else {
preparePayloadAndInvokeRemoteModel(
mlInput,
modelTensors,
new WrappedCountDownLatch(0, new CountDownLatch(1)),
new ExecutionContext(0, new CountDownLatch(1), exceptionHolder),
tensorActionListener
);
}
Expand Down Expand Up @@ -131,7 +133,7 @@ default void setMlGuard(MLGuard mlGuard) {}
default void preparePayloadAndInvokeRemoteModel(
MLInput mlInput,
Map<Integer, ModelTensors> tensorOutputs,
WrappedCountDownLatch countDownLatch,
ExecutionContext countDownLatch,
ActionListener<List<ModelTensors>> actionListener
) {
Connector connector = getConnector();
Expand Down Expand Up @@ -183,7 +185,7 @@ void invokeRemoteModel(
Map<String, String> parameters,
String payload,
Map<Integer, ModelTensors> tensorOutputs,
WrappedCountDownLatch countDownLatch,
ExecutionContext countDownLatch,
ActionListener<List<ModelTensors>> actionListener
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;

import org.junit.Before;
import org.junit.Rule;
Expand Down Expand Up @@ -93,7 +94,7 @@ public void invokeRemoteModel_invalidIpAddress() {
new HashMap<>(),
"{\"input\": \"hello world\"}",
new HashMap<>(),
new WrappedCountDownLatch(0, new CountDownLatch(1)),
new ExecutionContext(0, new CountDownLatch(1), new AtomicReference<>()),
actionListener
);
ArgumentCaptor<Exception> captor = ArgumentCaptor.forClass(IllegalArgumentException.class);
Expand Down Expand Up @@ -125,7 +126,7 @@ public void invokeRemoteModel_Empty_payload() {
new HashMap<>(),
null,
new HashMap<>(),
new WrappedCountDownLatch(0, new CountDownLatch(1)),
new ExecutionContext(0, new CountDownLatch(1), new AtomicReference<>()),
actionListener
);
ArgumentCaptor<Exception> captor = ArgumentCaptor.forClass(IllegalArgumentException.class);
Expand Down Expand Up @@ -157,7 +158,7 @@ public void invokeRemoteModel_get_request() {
new HashMap<>(),
null,
new HashMap<>(),
new WrappedCountDownLatch(0, new CountDownLatch(1)),
new ExecutionContext(0, new CountDownLatch(1), new AtomicReference<>()),
actionListener
);
}
Expand Down Expand Up @@ -185,7 +186,7 @@ public void invokeRemoteModel_post_request() {
new HashMap<>(),
"hello world",
new HashMap<>(),
new WrappedCountDownLatch(0, new CountDownLatch(1)),
new ExecutionContext(0, new CountDownLatch(1), new AtomicReference<>()),
actionListener
);
}
Expand Down Expand Up @@ -216,7 +217,7 @@ public void invokeRemoteModel_nullHttpClient_throwMLException() throws NoSuchFie
new HashMap<>(),
"hello world",
new HashMap<>(),
new WrappedCountDownLatch(0, new CountDownLatch(1)),
new ExecutionContext(0, new CountDownLatch(1), new AtomicReference<>()),
actionListener
);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;

import org.junit.Before;
import org.junit.Test;
Expand All @@ -40,8 +41,7 @@
import software.amazon.awssdk.http.SdkHttpResponse;

public class MLSdkAsyncHttpResponseHandlerTest {

private final WrappedCountDownLatch countDownLatch = new WrappedCountDownLatch(0, new CountDownLatch(1));
private final ExecutionContext executionContext = new ExecutionContext(0, new CountDownLatch(1), new AtomicReference<>());
@Mock
private ActionListener<List<ModelTensors>> actionListener;
@Mock
Expand Down Expand Up @@ -95,7 +95,7 @@ public void setup() {
.actions(Arrays.asList(noProcessFunctionPredictAction))
.build();
mlSdkAsyncHttpResponseHandler = new MLSdkAsyncHttpResponseHandler(
countDownLatch,
executionContext,
actionListener,
parameters,
tensorOutputs,
Expand Down Expand Up @@ -155,7 +155,7 @@ public void test_OnStream_without_postProcessFunction() {
}
};
MLSdkAsyncHttpResponseHandler noProcessFunctionMlSdkAsyncHttpResponseHandler = new MLSdkAsyncHttpResponseHandler(
countDownLatch,
executionContext,
actionListener,
parameters,
tensorOutputs,
Expand All @@ -178,7 +178,7 @@ public void test_onError() {
ArgumentCaptor<Exception> captor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(captor.capture());
assert captor.getValue() instanceof OpenSearchStatusException;
assert captor.getValue().getMessage().equals("Error on communication with remote model: runtime exception");
assert captor.getValue().getMessage().equals("Error communicating with remote model: runtime exception");
}

@Test
Expand Down Expand Up @@ -244,6 +244,7 @@ public void test_onComplete_success() {

@Test
public void test_onComplete_partial_success_exceptionSecond() {
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
String response1 = "{\n"
+ " \"embedding\": [\n"
+ " 0.46484375,\n"
Expand All @@ -259,7 +260,7 @@ public void test_onComplete_partial_success_exceptionSecond() {
String response2 = "Model current status is: FAILED";
CountDownLatch count = new CountDownLatch(2);
MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler1 = new MLSdkAsyncHttpResponseHandler(
new WrappedCountDownLatch(0, count),
new ExecutionContext(0, count, exceptionHolder),
actionListener,
parameters,
tensorOutputs,
Expand All @@ -268,7 +269,7 @@ public void test_onComplete_partial_success_exceptionSecond() {
null
);
MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler2 = new MLSdkAsyncHttpResponseHandler(
new WrappedCountDownLatch(1, count),
new ExecutionContext(1, count, exceptionHolder),
actionListener,
parameters,
tensorOutputs,
Expand Down Expand Up @@ -311,6 +312,7 @@ public void test_onComplete_partial_success_exceptionSecond() {

@Test
public void test_onComplete_partial_success_exceptionFirst() {
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
String response1 = "{\n"
+ " \"embedding\": [\n"
+ " 0.46484375,\n"
Expand All @@ -326,7 +328,7 @@ public void test_onComplete_partial_success_exceptionFirst() {
String response2 = "Model current status is: FAILED";
CountDownLatch count = new CountDownLatch(2);
MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler1 = new MLSdkAsyncHttpResponseHandler(
new WrappedCountDownLatch(0, count),
new ExecutionContext(0, count, exceptionHolder),
actionListener,
parameters,
tensorOutputs,
Expand All @@ -335,7 +337,7 @@ public void test_onComplete_partial_success_exceptionFirst() {
null
);
MLSdkAsyncHttpResponseHandler mlSdkAsyncHttpResponseHandler2 = new MLSdkAsyncHttpResponseHandler(
new WrappedCountDownLatch(1, count),
new ExecutionContext(1, count, exceptionHolder),
actionListener,
parameters,
tensorOutputs,
Expand Down

0 comments on commit 5e6a7f6

Please sign in to comment.