From 0d21732353d5355be4676e6fed57bf0001652303 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Fri, 23 Feb 2024 10:55:00 +0800 Subject: [PATCH] Address comments and add modelTensor status code Signed-off-by: zane-neo --- .../org/opensearch/ml/common/output/model/ModelTensors.java | 6 ++++++ .../ml/engine/algorithms/remote/ConnectorUtils.java | 6 ++++-- .../ml/engine/httpclient/MLHttpClientFactory.java | 3 ++- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java index 03b0ce5fca..9629a7d5b6 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java +++ b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java @@ -35,6 +35,12 @@ public ModelTensors(List mlModelTensors) { this.mlModelTensors = mlModelTensors; } + @Builder + public ModelTensors(List mlModelTensors, Integer statusCode) { + this.mlModelTensors = mlModelTensors; + this.statusCode = statusCode; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index c09f00786c..75c9ca30a1 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -27,6 +27,7 @@ import org.apache.commons.lang3.StringUtils; import org.apache.commons.text.StringSubstitutor; +import org.opensearch.core.rest.RestStatus; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.MLPostProcessFunction; @@ -215,12 +216,13 @@ public static ModelTensors processOutput( Object filteredResponse = JsonPath.parse(response).read(parameters.get(RESPONSE_FILTER_FIELD)); connector.parseResponse(filteredResponse, modelTensors, scriptReturnModelTensor); } - return ModelTensors.builder().mlModelTensors(modelTensors).build(); + return ModelTensors.builder().statusCode(RestStatus.OK.getStatus()).mlModelTensors(modelTensors).build(); } public static ModelTensors processErrorResponse(String errorResponse) { return ModelTensors .builder() + .statusCode(RestStatus.INTERNAL_SERVER_ERROR.getStatus()) .mlModelTensors(List.of(ModelTensor.builder().dataAsMap(Map.of("remote_response", errorResponse)).build())) .build(); } @@ -278,7 +280,7 @@ public static SdkHttpFullRequest buildSdkRequest( } else { requestBody = RequestBody.empty(); } - if (SdkHttpMethod.POST == method && "0".equals(requestBody.optionalContentLength().get().toString())) { + if (SdkHttpMethod.POST == method && 0 == requestBody.optionalContentLength().get()) { log.error("Content length is 0. Aborting request to remote model"); throw new IllegalArgumentException("Content length is 0. Aborting request to remote model"); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java index 097c0f30cd..2b754307c2 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java @@ -33,10 +33,11 @@ public static SdkAsyncHttpClient getAsyncHttpClient() { } public static void validate(String protocol, String host, int port) throws UnknownHostException { - if (protocol != null && !"http".equalsIgnoreCase(protocol) && !"https".equals(protocol)) { + if (protocol != null && !"http".equalsIgnoreCase(protocol) && !"https".equalsIgnoreCase(protocol)) { log.error("Remote inference protocol is not http or https: " + protocol); throw new IllegalArgumentException("Protocol is not http or https: " + protocol); } + // When port is not specified, the default port is -1, and we need to set it to 80 or 443 based on protocol. if (port == -1) { if (protocol == null || "http".equals(protocol.toLowerCase(Locale.getDefault()))) { port = 80;