Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Feb 20, 2024
1 parent 7e548c6 commit 6679dbe
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,14 @@

import java.io.IOException;
import java.net.URI;
import java.net.UnknownHostException;
import java.net.URL;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.text.StringSubstitutor;
Expand Down Expand Up @@ -266,18 +264,13 @@ public static SdkHttpFullRequest buildSdkRequest(
Map<String, String> parameters,
String payload,
SdkHttpMethod method
) throws UnknownHostException {
) throws Exception {
String endpoint = connector.getPredictEndpoint(parameters);
Pattern pattern = Pattern.compile("(?:(\\w+)://)?([-a-zA-Z0-9+&@#%?=~_|!,.;]*)(?::(\\w+))?");
Matcher matcher = pattern.matcher(endpoint);
if (matcher.find()) {
String protocol = matcher.group(1);
String host = matcher.group(2);
String port = matcher.group(3);
MLHttpClientFactory.validate(protocol, host, port);
} else {
throw new IllegalArgumentException("Invalid endpoint: " + endpoint);
}
URL url = new URL(endpoint);
String protocol = url.getProtocol();
String host = url.getHost();
int port = url.getPort();
MLHttpClientFactory.validate(protocol, host, port);
String charset = parameters.getOrDefault("charset", "UTF-8");
RequestBody requestBody;
if (payload != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ default void executePredict(MLInput mlInput, ActionListener<MLTaskResponse> acti
}
}

/**
* Calculate the chunk size.
* @param textDocsInputDataSet
* @return Tuple of chunk size and step size.
*/
private Tuple<Integer, Integer> calculateChunkSize(TextDocsInputDataSet textDocsInputDataSet) {
int textDocsLength = textDocsInputDataSet.getDocs().size();
Map<String, String> parameters = getConnector().getParameters();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@
import java.security.PrivilegedExceptionAction;
import java.util.Arrays;
import java.util.Locale;
import java.util.Optional;

import org.apache.commons.lang3.math.NumberUtils;

import lombok.extern.log4j.Log4j2;
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
Expand All @@ -35,24 +32,19 @@ public static SdkAsyncHttpClient getAsyncHttpClient() {
}
}

public static void validate(String protocol, String host, String port) throws UnknownHostException {
public static void validate(String protocol, String host, int port) throws UnknownHostException {
if (protocol != null && !"http".equalsIgnoreCase(protocol) && !"https".equals(protocol)) {
log.error("Remote inference protocol is not http or https: " + protocol);
throw new IllegalArgumentException("Protocol is not http or https: " + protocol);
}
String portStr = Optional.ofNullable(port).orElseGet(() -> {
if (port == -1) {
if (protocol == null || "http".equals(protocol.toLowerCase(Locale.getDefault()))) {
return "80";
port = 80;
} else {
return "443";
port = 443;
}
});
if (!NumberUtils.isDigits(portStr)) {
log.error("Remote inference port is not a valid number: " + portStr);
throw new IllegalArgumentException("Port is not a valid number: " + portStr);
}
int portNum = Integer.parseInt(portStr);
if (portNum < 0 || portNum > 65536) {
if (port < 0 || port > 65536) {
log.error("Remote inference port out of range: " + port);
throw new IllegalArgumentException("Port out of range: " + port);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,34 @@ public void invokeRemoteModel_get_request() {
);
}

@Test
public void invokeRemoteModel_post_request() {
ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://openai.com/mock")
.requestBody("hello world")
.build();
Connector connector = HttpConnector
.builder()
.name("test connector")
.version("1")
.protocol("http")
.actions(Arrays.asList(predictAction))
.build();
HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector);
executor
.invokeRemoteModel(
createMLInput(),
new HashMap<>(),
"hello world",
new HashMap<>(),
new WrappedCountDownLatch(0, new CountDownLatch(1)),
actionListener
);
}

private MLInput createMLInput() {
MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build();
return MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.REMOTE).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.engine.algorithms.text_embedding;

import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock;
import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.FrameworkType.HUGGINGFACE_TRANSFORMERS;
import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS;
import static org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingDenseModel.ML_ENGINE;
Expand All @@ -29,6 +30,7 @@
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
Expand Down Expand Up @@ -342,6 +344,14 @@ public void predict_BeforeInitingModel() {
textEmbeddingDenseModel.predict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), model);
}

@Test
public void test_async_inference() {
exceptionRule.expect(IllegalStateException.class);
exceptionRule.expectMessage("Method is not implemented");
textEmbeddingDenseModel.asyncPredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), mock(
ActionListener.class));
}

@After
public void tearDown() {
FileUtils.deleteFileQuietly(mlCachePath);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;

import java.net.UnknownHostException;
import java.util.Arrays;
import java.util.Map;

Expand Down Expand Up @@ -36,7 +35,7 @@ public void test_getSdkAsyncHttpClient_success() {
}

@Test
public void test_validateIp_validIp_noException() throws UnknownHostException {
public void test_validateIp_validIp_noException() throws Exception {
ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
Expand All @@ -56,7 +55,7 @@ public void test_validateIp_validIp_noException() throws UnknownHostException {
}

@Test
public void test_validateIp_rarePrivateIp_throwException() throws UnknownHostException {
public void test_validateIp_rarePrivateIp_throwException() throws Exception {
try {
ConnectorAction predictAction = ConnectorAction
.builder()
Expand Down Expand Up @@ -179,7 +178,7 @@ public void test_validateIp_rarePrivateIp_throwException() throws UnknownHostExc
}

@Test
public void test_validateSchemaAndPort_success() throws UnknownHostException {
public void test_validateSchemaAndPort_success() throws Exception {
ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
Expand All @@ -199,7 +198,7 @@ public void test_validateSchemaAndPort_success() throws UnknownHostException {
}

@Test
public void test_validateSchemaAndPort_notAllowedSchema_throwException() throws UnknownHostException {
public void test_validateSchemaAndPort_notAllowedSchema_throwException() throws Exception {
expectedException.expect(IllegalArgumentException.class);
expectedException.expectMessage("Protocol is not http or https: ftp");
ConnectorAction predictAction = ConnectorAction
Expand All @@ -221,7 +220,7 @@ public void test_validateSchemaAndPort_notAllowedSchema_throwException() throws
}

@Test
public void test_validateSchemaAndPort_portNotInRange_throwException() throws UnknownHostException {
public void test_validateSchemaAndPort_portNotInRange_throwException() throws Exception {
expectedException.expect(IllegalArgumentException.class);
expectedException.expectMessage("Port out of range: 65537");
ConnectorAction predictAction = ConnectorAction
Expand All @@ -242,7 +241,7 @@ public void test_validateSchemaAndPort_portNotInRange_throwException() throws Un
}

@Test
public void test_validateSchemaAndPort_portNotANumber_throwException() throws UnknownHostException {
public void test_validateSchemaAndPort_portNotANumber_throwException() throws Exception {
expectedException.expect(IllegalArgumentException.class);
expectedException.expectMessage("Port is not a valid number: abc");
ConnectorAction predictAction = ConnectorAction
Expand Down

0 comments on commit 6679dbe

Please sign in to comment.