Skip to content

Commit

Permalink
Address comments and add modelTensor status code
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Feb 23, 2024
1 parent 59f16f0 commit 0d21732
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ public ModelTensors(List<ModelTensor> mlModelTensors) {
this.mlModelTensors = mlModelTensors;
}

@Builder
public ModelTensors(List<ModelTensor> mlModelTensors, Integer statusCode) {
this.mlModelTensors = mlModelTensors;
this.statusCode = statusCode;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.text.StringSubstitutor;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.connector.MLPostProcessFunction;
Expand Down Expand Up @@ -215,12 +216,13 @@ public static ModelTensors processOutput(
Object filteredResponse = JsonPath.parse(response).read(parameters.get(RESPONSE_FILTER_FIELD));
connector.parseResponse(filteredResponse, modelTensors, scriptReturnModelTensor);
}
return ModelTensors.builder().mlModelTensors(modelTensors).build();
return ModelTensors.builder().statusCode(RestStatus.OK.getStatus()).mlModelTensors(modelTensors).build();
}

public static ModelTensors processErrorResponse(String errorResponse) {
return ModelTensors
.builder()
.statusCode(RestStatus.INTERNAL_SERVER_ERROR.getStatus())
.mlModelTensors(List.of(ModelTensor.builder().dataAsMap(Map.of("remote_response", errorResponse)).build()))
.build();
}
Expand Down Expand Up @@ -278,7 +280,7 @@ public static SdkHttpFullRequest buildSdkRequest(
} else {
requestBody = RequestBody.empty();
}
if (SdkHttpMethod.POST == method && "0".equals(requestBody.optionalContentLength().get().toString())) {
if (SdkHttpMethod.POST == method && 0 == requestBody.optionalContentLength().get()) {
log.error("Content length is 0. Aborting request to remote model");
throw new IllegalArgumentException("Content length is 0. Aborting request to remote model");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@ public static SdkAsyncHttpClient getAsyncHttpClient() {
}

public static void validate(String protocol, String host, int port) throws UnknownHostException {
if (protocol != null && !"http".equalsIgnoreCase(protocol) && !"https".equals(protocol)) {
if (protocol != null && !"http".equalsIgnoreCase(protocol) && !"https".equalsIgnoreCase(protocol)) {
log.error("Remote inference protocol is not http or https: " + protocol);
throw new IllegalArgumentException("Protocol is not http or https: " + protocol);
}
// When port is not specified, the default port is -1, and we need to set it to 80 or 443 based on protocol.
if (port == -1) {
if (protocol == null || "http".equals(protocol.toLowerCase(Locale.getDefault()))) {
port = 80;
Expand Down

0 comments on commit 0d21732

Please sign in to comment.