From 6679dbee15580d5d1b63968c77e769da2d83a707 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Tue, 20 Feb 2024 18:09:04 +0800 Subject: [PATCH] Address comments Signed-off-by: zane-neo --- .../algorithms/remote/ConnectorUtils.java | 21 +++++--------- .../remote/RemoteConnectorExecutor.java | 5 ++++ .../httpclient/MLHttpClientFactory.java | 18 ++++-------- .../remote/HttpJsonConnectorExecutorTest.java | 28 +++++++++++++++++++ .../TextEmbeddingDenseModelTest.java | 10 +++++++ .../httpclient/MLHttpClientFactoryTests.java | 13 ++++----- 6 files changed, 61 insertions(+), 34 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index 5a5b4551c9..c09f00786c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -16,7 +16,7 @@ import java.io.IOException; import java.net.URI; -import java.net.UnknownHostException; +import java.net.URL; import java.nio.charset.Charset; import java.util.ArrayList; import java.util.HashMap; @@ -24,8 +24,6 @@ import java.util.Map; import java.util.Optional; import java.util.function.Function; -import java.util.regex.Matcher; -import java.util.regex.Pattern; import org.apache.commons.lang3.StringUtils; import org.apache.commons.text.StringSubstitutor; @@ -266,18 +264,13 @@ public static SdkHttpFullRequest buildSdkRequest( Map parameters, String payload, SdkHttpMethod method - ) throws UnknownHostException { + ) throws Exception { String endpoint = connector.getPredictEndpoint(parameters); - Pattern pattern = Pattern.compile("(?:(\\w+)://)?([-a-zA-Z0-9+&@#%?=~_|!,.;]*)(?::(\\w+))?"); - Matcher matcher = pattern.matcher(endpoint); - if (matcher.find()) { - String protocol = matcher.group(1); - String host = matcher.group(2); - String port = matcher.group(3); - MLHttpClientFactory.validate(protocol, host, port); - } else { - throw new IllegalArgumentException("Invalid endpoint: " + endpoint); - } + URL url = new URL(endpoint); + String protocol = url.getProtocol(); + String host = url.getHost(); + int port = url.getPort(); + MLHttpClientFactory.validate(protocol, host, port); String charset = parameters.getOrDefault("charset", "UTF-8"); RequestBody requestBody; if (payload != null) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index 69504e0ae2..ab4d01bae4 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -70,6 +70,11 @@ default void executePredict(MLInput mlInput, ActionListener acti } } + /** + * Calculate the chunk size. + * @param textDocsInputDataSet + * @return Tuple of chunk size and step size. + */ private Tuple calculateChunkSize(TextDocsInputDataSet textDocsInputDataSet) { int textDocsLength = textDocsInputDataSet.getDocs().size(); Map parameters = getConnector().getParameters(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java index c674100434..097c0f30cd 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java @@ -13,9 +13,6 @@ import java.security.PrivilegedExceptionAction; import java.util.Arrays; import java.util.Locale; -import java.util.Optional; - -import org.apache.commons.lang3.math.NumberUtils; import lombok.extern.log4j.Log4j2; import software.amazon.awssdk.http.async.SdkAsyncHttpClient; @@ -35,24 +32,19 @@ public static SdkAsyncHttpClient getAsyncHttpClient() { } } - public static void validate(String protocol, String host, String port) throws UnknownHostException { + public static void validate(String protocol, String host, int port) throws UnknownHostException { if (protocol != null && !"http".equalsIgnoreCase(protocol) && !"https".equals(protocol)) { log.error("Remote inference protocol is not http or https: " + protocol); throw new IllegalArgumentException("Protocol is not http or https: " + protocol); } - String portStr = Optional.ofNullable(port).orElseGet(() -> { + if (port == -1) { if (protocol == null || "http".equals(protocol.toLowerCase(Locale.getDefault()))) { - return "80"; + port = 80; } else { - return "443"; + port = 443; } - }); - if (!NumberUtils.isDigits(portStr)) { - log.error("Remote inference port is not a valid number: " + portStr); - throw new IllegalArgumentException("Port is not a valid number: " + portStr); } - int portNum = Integer.parseInt(portStr); - if (portNum < 0 || portNum > 65536) { + if (port < 0 || port > 65536) { log.error("Remote inference port out of range: " + port); throw new IllegalArgumentException("Port out of range: " + port); } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index 5b52c48ddc..8545acd541 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -159,6 +159,34 @@ public void invokeRemoteModel_get_request() { ); } + @Test + public void invokeRemoteModel_post_request() { + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://openai.com/mock") + .requestBody("hello world") + .build(); + Connector connector = HttpConnector + .builder() + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)) + .build(); + HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); + executor + .invokeRemoteModel( + createMLInput(), + new HashMap<>(), + "hello world", + new HashMap<>(), + new WrappedCountDownLatch(0, new CountDownLatch(1)), + actionListener + ); + } + private MLInput createMLInput() { MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build(); return MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.REMOTE).build(); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java index 7c7e2be4b9..07a9bc2e09 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java @@ -6,6 +6,7 @@ package org.opensearch.ml.engine.algorithms.text_embedding; import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.FrameworkType.HUGGINGFACE_TRANSFORMERS; import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS; import static org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingDenseModel.ML_ENGINE; @@ -29,6 +30,7 @@ import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensearch.ResourceNotFoundException; +import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; @@ -342,6 +344,14 @@ public void predict_BeforeInitingModel() { textEmbeddingDenseModel.predict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), model); } + @Test + public void test_async_inference() { + exceptionRule.expect(IllegalStateException.class); + exceptionRule.expectMessage("Method is not implemented"); + textEmbeddingDenseModel.asyncPredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), mock( + ActionListener.class)); + } + @After public void tearDown() { FileUtils.deleteFileQuietly(mlCachePath); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java index cd3ba4322b..1e656eb769 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java @@ -8,7 +8,6 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; -import java.net.UnknownHostException; import java.util.Arrays; import java.util.Map; @@ -36,7 +35,7 @@ public void test_getSdkAsyncHttpClient_success() { } @Test - public void test_validateIp_validIp_noException() throws UnknownHostException { + public void test_validateIp_validIp_noException() throws Exception { ConnectorAction predictAction = ConnectorAction .builder() .actionType(ConnectorAction.ActionType.PREDICT) @@ -56,7 +55,7 @@ public void test_validateIp_validIp_noException() throws UnknownHostException { } @Test - public void test_validateIp_rarePrivateIp_throwException() throws UnknownHostException { + public void test_validateIp_rarePrivateIp_throwException() throws Exception { try { ConnectorAction predictAction = ConnectorAction .builder() @@ -179,7 +178,7 @@ public void test_validateIp_rarePrivateIp_throwException() throws UnknownHostExc } @Test - public void test_validateSchemaAndPort_success() throws UnknownHostException { + public void test_validateSchemaAndPort_success() throws Exception { ConnectorAction predictAction = ConnectorAction .builder() .actionType(ConnectorAction.ActionType.PREDICT) @@ -199,7 +198,7 @@ public void test_validateSchemaAndPort_success() throws UnknownHostException { } @Test - public void test_validateSchemaAndPort_notAllowedSchema_throwException() throws UnknownHostException { + public void test_validateSchemaAndPort_notAllowedSchema_throwException() throws Exception { expectedException.expect(IllegalArgumentException.class); expectedException.expectMessage("Protocol is not http or https: ftp"); ConnectorAction predictAction = ConnectorAction @@ -221,7 +220,7 @@ public void test_validateSchemaAndPort_notAllowedSchema_throwException() throws } @Test - public void test_validateSchemaAndPort_portNotInRange_throwException() throws UnknownHostException { + public void test_validateSchemaAndPort_portNotInRange_throwException() throws Exception { expectedException.expect(IllegalArgumentException.class); expectedException.expectMessage("Port out of range: 65537"); ConnectorAction predictAction = ConnectorAction @@ -242,7 +241,7 @@ public void test_validateSchemaAndPort_portNotInRange_throwException() throws Un } @Test - public void test_validateSchemaAndPort_portNotANumber_throwException() throws UnknownHostException { + public void test_validateSchemaAndPort_portNotANumber_throwException() throws Exception { expectedException.expect(IllegalArgumentException.class); expectedException.expectMessage("Port is not a valid number: abc"); ConnectorAction predictAction = ConnectorAction