diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/Predictable.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/Predictable.java index 93c1c69fb3..5ee1e4d7a1 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/Predictable.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/Predictable.java @@ -38,7 +38,7 @@ default MLOutput predict(MLInput mlInput) { } default void asyncPredict(MLInput mlInput, ActionListener actionListener) { - throw new IllegalStateException("Method is not implemented"); + actionListener.onFailure(new IllegalStateException("Method is not implemented")); } /** diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index ab4d01bae4..52096b1790 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -42,31 +42,36 @@ default void executePredict(MLInput mlInput, ActionListener acti actionListener.onResponse(new MLTaskResponse(new ModelTensorOutput(r))); }, actionListener::onFailure); Map modelTensors = new ConcurrentHashMap<>(); - if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) { - TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) mlInput.getInputDataset(); - Tuple calculatedChunkSize = calculateChunkSize(textDocsInputDataSet); - CountDownLatch countDownLatch = new CountDownLatch(calculatedChunkSize.v1()); - int sequence = 0; - for (int processedDocs = 0; processedDocs < textDocsInputDataSet.getDocs().size(); processedDocs += calculatedChunkSize.v2()) { - List textDocs = textDocsInputDataSet.getDocs().subList(processedDocs, textDocsInputDataSet.getDocs().size()); + try { + if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) { + TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) mlInput.getInputDataset(); + Tuple calculatedChunkSize = calculateChunkSize(textDocsInputDataSet); + CountDownLatch countDownLatch = new CountDownLatch(calculatedChunkSize.v1()); + int sequence = 0; + for (int processedDocs = 0; processedDocs < textDocsInputDataSet.getDocs().size(); processedDocs += calculatedChunkSize + .v2()) { + List textDocs = textDocsInputDataSet.getDocs().subList(processedDocs, textDocsInputDataSet.getDocs().size()); + preparePayloadAndInvokeRemoteModel( + MLInput + .builder() + .algorithm(FunctionName.TEXT_EMBEDDING) + .inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()) + .build(), + modelTensors, + new WrappedCountDownLatch(sequence++, countDownLatch), + tensorActionListener + ); + } + } else { preparePayloadAndInvokeRemoteModel( - MLInput - .builder() - .algorithm(FunctionName.TEXT_EMBEDDING) - .inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()) - .build(), + mlInput, modelTensors, - new WrappedCountDownLatch(sequence++, countDownLatch), + new WrappedCountDownLatch(0, new CountDownLatch(1)), tensorActionListener ); } - } else { - preparePayloadAndInvokeRemoteModel( - mlInput, - modelTensors, - new WrappedCountDownLatch(0, new CountDownLatch(1)), - tensorActionListener - ); + } catch (Exception e) { + actionListener.onFailure(e); } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index 43f18c53f6..94b4a7304a 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -281,7 +281,7 @@ public void executePredict_TextDocsInferenceInput_withStepSize() { } @Test - public void executePredict_RemoteInferenceInput_nullHttpClient_throwMLException() throws NoSuchFieldException, IllegalAccessException { + public void executePredict_RemoteInferenceInput_nullHttpClient_throwNPException() throws NoSuchFieldException, IllegalAccessException { ConnectorAction predictAction = ConnectorAction .builder() .actionType(ConnectorAction.ActionType.PREDICT) @@ -319,4 +319,42 @@ public void executePredict_RemoteInferenceInput_nullHttpClient_throwMLException( Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture()); assert exceptionCaptor.getValue() instanceof NullPointerException; } + + @Test + public void executePredict_RemoteInferenceInput_negativeStepSize_throwIllegalArgumentException() { + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://openai.com/mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .build(); + Map credential = ImmutableMap + .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + Map parameters = ImmutableMap + .of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker", "input_docs_processed_step_size", "-1"); + Connector connector = AwsConnector + .awsConnectorBuilder() + .name("test connector") + .version("1") + .protocol("http") + .parameters(parameters) + .credential(credential) + .actions(Arrays.asList(predictAction)) + .build(); + connector.decrypt((c) -> encryptor.decrypt(c)); + AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(executor.getClient()).thenReturn(client); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input1", "input2", "input3")).build(); + executor + .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), actionListener); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture()); + assert exceptionCaptor.getValue() instanceof IllegalArgumentException; + } }