From 64189a4a3cc5429467e5e922a868a98d56fd3924 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Sat, 25 Nov 2023 15:15:54 +0800 Subject: [PATCH] format code Signed-off-by: zane-neo --- .../opensearch/ml/cluster/MLSyncUpCron.java | 24 +-- .../ml/plugin/MachineLearningPlugin.java | 1 - .../ml/rest/MyRestPPLQueryAction.java | 196 ++++++++---------- 3 files changed, 104 insertions(+), 117 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java index 801b7f0961..3a5ea83347 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java @@ -244,20 +244,20 @@ void refreshModelState(Map> modelWorkerNodes, Map planningWorkNodes = sourceAsMap.containsKey(MLModel.PLANNING_WORKER_NODES_FIELD) - ? (List) sourceAsMap.get(MLModel.PLANNING_WORKER_NODES_FIELD) - : new ArrayList<>(); + ? (List) sourceAsMap.get(MLModel.PLANNING_WORKER_NODES_FIELD) + : new ArrayList<>(); if (deployToAllNodes) { DiscoveryNode[] eligibleNodes = nodeHelper.getEligibleNodes(functionName); planningWorkerNodeCount = eligibleNodes.length; @@ -312,8 +312,8 @@ private MLModelState getNewModelState( if (currentWorkerNodeCount == 0 && state != MLModelState.DEPLOY_FAILED && !(state == MLModelState.DEPLOYING - && lastUpdateTime != null - && lastUpdateTime + DEPLOY_MODEL_TASK_GRACE_TIME_IN_MS > Instant.now().toEpochMilli())) { + && lastUpdateTime != null + && lastUpdateTime + DEPLOY_MODEL_TASK_GRACE_TIME_IN_MS > Instant.now().toEpochMilli())) { // If model not deployed to any node and no node is deploying the model, then set model state as DEPLOY_FAILED return MLModelState.DEPLOY_FAILED; } diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 9325e9a095..473fbb42fe 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -144,7 +144,6 @@ import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; -//import org.opensearch.ml.rest.MyRestPPLQueryAction; import org.opensearch.ml.rest.MyRestPPLQueryAction; import org.opensearch.ml.rest.RestMLCreateConnectorAction; import org.opensearch.ml.rest.RestMLDeleteConnectorAction; diff --git a/plugin/src/main/java/org/opensearch/ml/rest/MyRestPPLQueryAction.java b/plugin/src/main/java/org/opensearch/ml/rest/MyRestPPLQueryAction.java index 543190e0e8..ced731edd4 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/MyRestPPLQueryAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/MyRestPPLQueryAction.java @@ -5,6 +5,19 @@ package org.opensearch.ml.rest; +import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; +import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; +import static org.opensearch.core.rest.RestStatus.OK; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.OpenSearchSecurityException; @@ -24,118 +37,93 @@ import org.opensearch.sql.plugin.transport.TransportPPLQueryRequest; import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.util.Arrays; -import java.util.HashSet; -import java.util.List; -import java.util.Set; +public class MyRestPPLQueryAction extends BaseRestHandler { + public static final String QUERY_API_ENDPOINT = "_ml/_ppl"; + public static final String EXPLAIN_API_ENDPOINT = "_ml/_ppl/_explain"; + public static final String LEGACY_QUERY_API_ENDPOINT = "_ml/_opendistro/_ppl"; + public static final String LEGACY_EXPLAIN_API_ENDPOINT = "_ml/_opendistro/_ppl/_explain"; -import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; -import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; -import static org.opensearch.core.rest.RestStatus.OK; + private static final Logger LOG = LogManager.getLogger(); -public class MyRestPPLQueryAction extends BaseRestHandler { - public static final String QUERY_API_ENDPOINT = "_ml/_ppl"; - public static final String EXPLAIN_API_ENDPOINT = "_ml/_ppl/_explain"; - public static final String LEGACY_QUERY_API_ENDPOINT = "_ml/_opendistro/_ppl"; - public static final String LEGACY_EXPLAIN_API_ENDPOINT = "_ml/_opendistro/_ppl/_explain"; - - private static final Logger LOG = LogManager.getLogger(); - - /** Constructor of RestPPLQueryAction. */ - public MyRestPPLQueryAction() { - super(); - } - - @Override - public List routes() { - return List.of(new Route(RestRequest.Method.POST, QUERY_API_ENDPOINT), new Route(RestRequest.Method.POST, EXPLAIN_API_ENDPOINT)); - } - - @Override - public String getName() { - return "ml_ppl_query_action"; - } - - @Override - protected Set responseParams() { - Set responseParams = new HashSet<>(super.responseParams()); - responseParams.addAll(Arrays.asList("format", "sanitize")); - return responseParams; - } - - @Override - protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient nodeClient) { - TransportPPLQueryRequest transportPPLQueryRequest = - new TransportPPLQueryRequest(PPLQueryRequestFactory.getPPLRequest(request)); - LOG.info("request classloader: " + transportPPLQueryRequest.getClass().getClassLoader()); - LOG.info("response classloader:" + TransportPPLQueryResponse.class.getClassLoader()); - - return channel -> - nodeClient.execute( - PPLQueryAction.INSTANCE, - transportPPLQueryRequest, - getPPLTransportActionListener( - new ActionListener<>() { - @Override - public void onResponse(TransportPPLQueryResponse response) { - sendResponse(channel, OK, response.getResult()); - } - - @Override - public void onFailure(Exception e) { - if (e instanceof IllegalAccessException) { - LOG.error("Error happened during query handling", e); - reportError(channel, e, BAD_REQUEST); - } else if (transportPPLQueryRequest.isExplainRequest()) { - LOG.error("Error happened during explain", e); - sendResponse( - channel, - INTERNAL_SERVER_ERROR, - "Failed to explain the query due to error: " + e.getMessage()); - } else if (e instanceof OpenSearchSecurityException) { - OpenSearchSecurityException exception = (OpenSearchSecurityException) e; - reportError(channel, exception, exception.status()); - } else { - LOG.error("Error happened during query handling", e); - reportError(channel, e, INTERNAL_SERVER_ERROR); - } - } - })); - } + /** Constructor of RestPPLQueryAction. */ + public MyRestPPLQueryAction() { + super(); + } - private void sendResponse(RestChannel channel, RestStatus status, String content) { - channel.sendResponse(new BytesRestResponse(status, "application/json; charset=UTF-8", content)); - } + @Override + public List routes() { + return List.of(new Route(RestRequest.Method.POST, QUERY_API_ENDPOINT), new Route(RestRequest.Method.POST, EXPLAIN_API_ENDPOINT)); + } - private void reportError(final RestChannel channel, final Exception e, final RestStatus status) { - channel.sendResponse(new BytesRestResponse(status, e.getMessage())); - } + @Override + public String getName() { + return "ml_ppl_query_action"; + } - private ActionListener getPPLTransportActionListener(ActionListener listener) { - return ActionListener.wrap(r -> { - listener.onResponse(fromActionResponse(r)); - }, listener::onFailure); - } + @Override + protected Set responseParams() { + Set responseParams = new HashSet<>(super.responseParams()); + responseParams.addAll(Arrays.asList("format", "sanitize")); + return responseParams; + } - private static TransportPPLQueryResponse fromActionResponse(ActionResponse actionResponse) { - if (actionResponse instanceof TransportPPLQueryResponse) { - return (TransportPPLQueryResponse) actionResponse; + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient nodeClient) { + TransportPPLQueryRequest transportPPLQueryRequest = new TransportPPLQueryRequest(PPLQueryRequestFactory.getPPLRequest(request)); + LOG.info("request classloader: " + transportPPLQueryRequest.getClass().getClassLoader()); + LOG.info("response classloader:" + TransportPPLQueryResponse.class.getClassLoader()); + + return channel -> nodeClient + .execute(PPLQueryAction.INSTANCE, transportPPLQueryRequest, getPPLTransportActionListener(new ActionListener<>() { + @Override + public void onResponse(TransportPPLQueryResponse response) { + sendResponse(channel, OK, response.getResult()); + } + + @Override + public void onFailure(Exception e) { + if (e instanceof IllegalAccessException) { + LOG.error("Error happened during query handling", e); + reportError(channel, e, BAD_REQUEST); + } else if (transportPPLQueryRequest.isExplainRequest()) { + LOG.error("Error happened during explain", e); + sendResponse(channel, INTERNAL_SERVER_ERROR, "Failed to explain the query due to error: " + e.getMessage()); + } else if (e instanceof OpenSearchSecurityException) { + OpenSearchSecurityException exception = (OpenSearchSecurityException) e; + reportError(channel, exception, exception.status()); + } else { + LOG.error("Error happened during query handling", e); + reportError(channel, e, INTERNAL_SERVER_ERROR); + } + } + })); + } + + private void sendResponse(RestChannel channel, RestStatus status, String content) { + channel.sendResponse(new BytesRestResponse(status, "application/json; charset=UTF-8", content)); } - try ( - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { - actionResponse.writeTo(osso); - try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { - return new TransportPPLQueryResponse(input); - } - } catch (IOException e) { - throw new UncheckedIOException("failed to parse ActionResponse into TransportPPLQueryResponse", e); + private void reportError(final RestChannel channel, final Exception e, final RestStatus status) { + channel.sendResponse(new BytesRestResponse(status, e.getMessage())); } - } + private ActionListener getPPLTransportActionListener(ActionListener listener) { + return ActionListener.wrap(r -> { listener.onResponse(fromActionResponse(r)); }, listener::onFailure); + } + + private static TransportPPLQueryResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof TransportPPLQueryResponse) { + return (TransportPPLQueryResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new TransportPPLQueryResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into TransportPPLQueryResponse", e); + } + + } }