diff --git a/common/build.gradle b/common/build.gradle index 79077317a1..3857cf4ee3 100644 --- a/common/build.gradle +++ b/common/build.gradle @@ -9,6 +9,7 @@ plugins { id 'com.github.johnrengelman.shadow' id 'jacoco' id "io.freefair.lombok" + id 'com.diffplug.spotless' version '6.25.0' id 'maven-publish' id 'signing' } @@ -67,6 +68,15 @@ jacocoTestCoverageVerification { } check.dependsOn jacocoTestCoverageVerification +spotless { + java { + removeUnusedImports() + importOrder 'java', 'javax', 'org', 'com' + + eclipse().configFile rootProject.file('.eclipseformat.xml') + } +} + shadowJar { destinationDirectory = file("${project.buildDir}/distributions") archiveClassifier.set(null) diff --git a/common/src/main/java/org/opensearch/ml/common/AccessMode.java b/common/src/main/java/org/opensearch/ml/common/AccessMode.java index 6b8e31e2fd..d4195206d5 100644 --- a/common/src/main/java/org/opensearch/ml/common/AccessMode.java +++ b/common/src/main/java/org/opensearch/ml/common/AccessMode.java @@ -7,11 +7,11 @@ package org.opensearch.ml.common; -import lombok.Getter; - import java.util.HashMap; import java.util.Map; +import lombok.Getter; + public enum AccessMode { PUBLIC("public"), PRIVATE("private"), diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index 39da1edd23..933f00b5ad 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -5,14 +5,6 @@ package org.opensearch.ml.common; -import com.google.common.collect.ImmutableSet; -import org.opensearch.Version; -import org.opensearch.ml.common.agent.MLAgent; -import org.opensearch.ml.common.connector.AbstractConnector; -import org.opensearch.ml.common.controller.MLController; - -import java.util.Set; - import static org.opensearch.ml.common.MLConfig.CONFIG_TYPE_FIELD; import static org.opensearch.ml.common.MLConfig.LAST_UPDATED_TIME_FIELD; import static org.opensearch.ml.common.MLConfig.ML_CONFIGURATION_FIELD; @@ -40,506 +32,549 @@ import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.NORMALIZE_RESULT_FIELD; import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.POOLING_MODE_FIELD; +import java.util.Set; + +import org.opensearch.Version; +import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.connector.AbstractConnector; +import org.opensearch.ml.common.controller.MLController; + +import com.google.common.collect.ImmutableSet; + public class CommonValue { - public static Integer NO_SCHEMA_VERSION = 0; - public static final String REMOTE_SERVICE_ERROR = "Error from remote service: "; - public static final String USER = "user"; - public static final String META = "_meta"; - public static final String SCHEMA_VERSION_FIELD = "schema_version"; - public static final String UNDEPLOYED = "undeployed"; - public static final String NOT_FOUND = "not_found"; + public static Integer NO_SCHEMA_VERSION = 0; + public static final String REMOTE_SERVICE_ERROR = "Error from remote service: "; + public static final String USER = "user"; + public static final String META = "_meta"; + public static final String SCHEMA_VERSION_FIELD = "schema_version"; + public static final String UNDEPLOYED = "undeployed"; + public static final String NOT_FOUND = "not_found"; - public static final String MASTER_KEY = "master_key"; - public static final String CREATE_TIME_FIELD = "create_time"; - public static final String LAST_UPDATE_TIME_FIELD = "last_update_time"; + public static final String MASTER_KEY = "master_key"; + public static final String CREATE_TIME_FIELD = "create_time"; + public static final String LAST_UPDATE_TIME_FIELD = "last_update_time"; - public static final String BOX_TYPE_KEY = "box_type"; - // hot node - public static String HOT_BOX_TYPE = "hot"; - // warm node - public static String WARM_BOX_TYPE = "warm"; - public static final String ML_MODEL_GROUP_INDEX = ".plugins-ml-model-group"; - public static final String ML_MODEL_INDEX = ".plugins-ml-model"; - public static final String ML_TASK_INDEX = ".plugins-ml-task"; - public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 2; - public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 11; - public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector"; - public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 2; - public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 3; - public static final String ML_CONFIG_INDEX = ".plugins-ml-config"; - public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 3; - public static final String ML_CONTROLLER_INDEX = ".plugins-ml-controller"; - public static final Integer ML_CONTROLLER_INDEX_SCHEMA_VERSION = 1; - public static final String ML_MAP_RESPONSE_KEY = "response"; - public static final String ML_AGENT_INDEX = ".plugins-ml-agent"; - public static final Integer ML_AGENT_INDEX_SCHEMA_VERSION = 2; - public static final String ML_MEMORY_META_INDEX = ".plugins-ml-memory-meta"; - public static final Integer ML_MEMORY_META_INDEX_SCHEMA_VERSION = 1; - public static final String ML_MEMORY_MESSAGE_INDEX = ".plugins-ml-memory-message"; - public static final String ML_STOP_WORDS_INDEX = ".plugins-ml-stop-words"; - public static final Set stopWordsIndices = ImmutableSet.of(".plugins-ml-stop-words"); - public static final Integer ML_MEMORY_MESSAGE_INDEX_SCHEMA_VERSION = 1; - public static final String USER_FIELD_MAPPING = " \"" - + CommonValue.USER - + "\": {\n" - + " \"type\": \"nested\",\n" - + " \"properties\": {\n" - + " \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n" - + " \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" - + " \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" - + " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" - + " }\n" - + " }\n"; - public static final String ML_MODEL_GROUP_INDEX_MAPPING = "{\n" - + " \"_meta\": {\n" - + " \"schema_version\": " + ML_MODEL_GROUP_INDEX_SCHEMA_VERSION + "\n" - + " },\n" - + " \"properties\": {\n" - + " \"" + MLModelGroup.MODEL_GROUP_NAME_FIELD + "\": {\n" - + " \"type\": \"text\",\n" - + " \"fields\": {\n" - + " \"keyword\": {\n" - + " \"type\": \"keyword\",\n" - + " \"ignore_above\": 256\n" - + " }\n" - + " }\n" - + " },\n" - + " \"" + MLModelGroup.DESCRIPTION_FIELD + "\": {\n" - + " \"type\": \"text\"\n" - + " },\n" - + " \"" + MLModelGroup.LATEST_VERSION_FIELD + "\": {\n" - + " \"type\": \"integer\"\n" - + " },\n" - + " \"" + MLModelGroup.MODEL_GROUP_ID_FIELD + "\": {\n" - + " \"type\": \"keyword\"\n" - + " },\n" - + " \"" + MLModelGroup.BACKEND_ROLES_FIELD + "\": {\n" - + " \"type\": \"text\",\n" - + " \"fields\": {\n" - + " \"keyword\": {\n" - + " \"type\": \"keyword\",\n" - + " \"ignore_above\": 256\n" - + " }\n" - + " }\n" - + " },\n" - + " \"" + MLModelGroup.ACCESS + "\": {\n" - + " \"type\": \"keyword\"\n" - + " },\n" - + " \"" + MLModelGroup.OWNER + "\": {\n" - + " \"type\": \"nested\",\n" - + " \"properties\": {\n" - + " \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n" - + " \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" - + " \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" - + " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" - + " }\n" - + " },\n" - + " \"" + MLModelGroup.CREATED_TIME_FIELD + "\": {\n" - + " \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" + MLModelGroup.LAST_UPDATED_TIME_FIELD + "\": {\n" - + " \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" - + " }\n" - + "}"; + public static final String BOX_TYPE_KEY = "box_type"; + // hot node + public static String HOT_BOX_TYPE = "hot"; + // warm node + public static String WARM_BOX_TYPE = "warm"; + public static final String ML_MODEL_GROUP_INDEX = ".plugins-ml-model-group"; + public static final String ML_MODEL_INDEX = ".plugins-ml-model"; + public static final String ML_TASK_INDEX = ".plugins-ml-task"; + public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 2; + public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 11; + public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector"; + public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 2; + public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 3; + public static final String ML_CONFIG_INDEX = ".plugins-ml-config"; + public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 3; + public static final String ML_CONTROLLER_INDEX = ".plugins-ml-controller"; + public static final Integer ML_CONTROLLER_INDEX_SCHEMA_VERSION = 1; + public static final String ML_MAP_RESPONSE_KEY = "response"; + public static final String ML_AGENT_INDEX = ".plugins-ml-agent"; + public static final Integer ML_AGENT_INDEX_SCHEMA_VERSION = 2; + public static final String ML_MEMORY_META_INDEX = ".plugins-ml-memory-meta"; + public static final Integer ML_MEMORY_META_INDEX_SCHEMA_VERSION = 1; + public static final String ML_MEMORY_MESSAGE_INDEX = ".plugins-ml-memory-message"; + public static final String ML_STOP_WORDS_INDEX = ".plugins-ml-stop-words"; + public static final Set stopWordsIndices = ImmutableSet.of(".plugins-ml-stop-words"); + public static final Integer ML_MEMORY_MESSAGE_INDEX_SCHEMA_VERSION = 1; + public static final String USER_FIELD_MAPPING = " \"" + + CommonValue.USER + + "\": {\n" + + " \"type\": \"nested\",\n" + + " \"properties\": {\n" + + " \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n" + + " \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + + " \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + + " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" + + " }\n" + + " }\n"; + public static final String ML_MODEL_GROUP_INDEX_MAPPING = "{\n" + + " \"_meta\": {\n" + + " \"schema_version\": " + + ML_MODEL_GROUP_INDEX_SCHEMA_VERSION + + "\n" + + " },\n" + + " \"properties\": {\n" + + " \"" + + MLModelGroup.MODEL_GROUP_NAME_FIELD + + "\": {\n" + + " \"type\": \"text\",\n" + + " \"fields\": {\n" + + " \"keyword\": {\n" + + " \"type\": \"keyword\",\n" + + " \"ignore_above\": 256\n" + + " }\n" + + " }\n" + + " },\n" + + " \"" + + MLModelGroup.DESCRIPTION_FIELD + + "\": {\n" + + " \"type\": \"text\"\n" + + " },\n" + + " \"" + + MLModelGroup.LATEST_VERSION_FIELD + + "\": {\n" + + " \"type\": \"integer\"\n" + + " },\n" + + " \"" + + MLModelGroup.MODEL_GROUP_ID_FIELD + + "\": {\n" + + " \"type\": \"keyword\"\n" + + " },\n" + + " \"" + + MLModelGroup.BACKEND_ROLES_FIELD + + "\": {\n" + + " \"type\": \"text\",\n" + + " \"fields\": {\n" + + " \"keyword\": {\n" + + " \"type\": \"keyword\",\n" + + " \"ignore_above\": 256\n" + + " }\n" + + " }\n" + + " },\n" + + " \"" + + MLModelGroup.ACCESS + + "\": {\n" + + " \"type\": \"keyword\"\n" + + " },\n" + + " \"" + + MLModelGroup.OWNER + + "\": {\n" + + " \"type\": \"nested\",\n" + + " \"properties\": {\n" + + " \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n" + + " \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + + " \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + + " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" + + " }\n" + + " },\n" + + " \"" + + MLModelGroup.CREATED_TIME_FIELD + + "\": {\n" + + " \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLModelGroup.LAST_UPDATED_TIME_FIELD + + "\": {\n" + + " \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" + + " }\n" + + "}"; - public static final String ML_CONNECTOR_INDEX_FIELDS = " \"properties\": {\n" - + " \"" - + AbstractConnector.NAME_FIELD - + "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n" - + " \"" - + AbstractConnector.VERSION_FIELD - + "\" : {\"type\": \"keyword\"},\n" - + " \"" - + AbstractConnector.DESCRIPTION_FIELD - + "\" : {\"type\": \"text\"},\n" - + " \"" - + AbstractConnector.PROTOCOL_FIELD - + "\" : {\"type\": \"keyword\"},\n" - + " \"" - + AbstractConnector.PARAMETERS_FIELD - + "\" : {\"type\": \"flat_object\"},\n" - + " \"" - + AbstractConnector.CREDENTIAL_FIELD - + "\" : {\"type\": \"flat_object\"},\n" - + " \"" - + AbstractConnector.CLIENT_CONFIG_FIELD - + "\" : {\"type\": \"flat_object\"},\n" - + " \"" - + AbstractConnector.ACTIONS_FIELD - + "\" : {\"type\": \"flat_object\"}\n"; + public static final String ML_CONNECTOR_INDEX_FIELDS = " \"properties\": {\n" + + " \"" + + AbstractConnector.NAME_FIELD + + "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n" + + " \"" + + AbstractConnector.VERSION_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + + AbstractConnector.DESCRIPTION_FIELD + + "\" : {\"type\": \"text\"},\n" + + " \"" + + AbstractConnector.PROTOCOL_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + + AbstractConnector.PARAMETERS_FIELD + + "\" : {\"type\": \"flat_object\"},\n" + + " \"" + + AbstractConnector.CREDENTIAL_FIELD + + "\" : {\"type\": \"flat_object\"},\n" + + " \"" + + AbstractConnector.CLIENT_CONFIG_FIELD + + "\" : {\"type\": \"flat_object\"},\n" + + " \"" + + AbstractConnector.ACTIONS_FIELD + + "\" : {\"type\": \"flat_object\"}\n"; - public static final String ML_MODEL_INDEX_MAPPING = "{\n" - + " \"_meta\": {\"schema_version\": " - + ML_MODEL_INDEX_SCHEMA_VERSION - + "},\n" - + " \"properties\": {\n" - + " \"" - + MLModel.ALGORITHM_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLModel.MODEL_NAME_FIELD - + "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n" - + " \"" - + MLModel.OLD_MODEL_VERSION_FIELD - + "\" : {\"type\": \"long\"},\n" - + " \"" - + MLModel.MODEL_VERSION_FIELD - + "\" : {\"type\": \"keyword\"},\n" - + " \"" - + MLModel.MODEL_GROUP_ID_FIELD - + "\" : {\"type\": \"keyword\"},\n" - + " \"" - + MLModel.MODEL_CONTENT_FIELD - + "\" : {\"type\": \"binary\"},\n" - + " \"" - + MLModel.CHUNK_NUMBER_FIELD - + "\" : {\"type\": \"long\"},\n" - + " \"" - + MLModel.TOTAL_CHUNKS_FIELD - + "\" : {\"type\": \"long\"},\n" - + " \"" - + MLModel.MODEL_ID_FIELD - + "\" : {\"type\": \"keyword\"},\n" - + " \"" - + MLModel.DESCRIPTION_FIELD - + "\" : {\"type\": \"text\"},\n" - + " \"" - + MLModel.MODEL_FORMAT_FIELD - + "\" : {\"type\": \"keyword\"},\n" - + " \"" - + MLModel.MODEL_STATE_FIELD - + "\" : {\"type\": \"keyword\"},\n" - + " \"" - + MLModel.MODEL_CONTENT_SIZE_IN_BYTES_FIELD - + "\" : {\"type\": \"long\"},\n" - + " \"" - + MLModel.PLANNING_WORKER_NODE_COUNT_FIELD - + "\" : {\"type\": \"integer\"},\n" - + " \"" - + MLModel.CURRENT_WORKER_NODE_COUNT_FIELD - + "\" : {\"type\": \"integer\"},\n" - + " \"" - + MLModel.PLANNING_WORKER_NODES_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLModel.DEPLOY_TO_ALL_NODES_FIELD - + "\": {\"type\": \"boolean\"},\n" - + " \"" - + MLModel.IS_HIDDEN_FIELD - + "\": {\"type\": \"boolean\"},\n" - + " \"" - + MLModel.MODEL_CONFIG_FIELD - + "\" : {\"properties\":{\"" - + MODEL_TYPE_FIELD + "\":{\"type\":\"keyword\"},\"" - + EMBEDDING_DIMENSION_FIELD + "\":{\"type\":\"integer\"},\"" - + FRAMEWORK_TYPE_FIELD + "\":{\"type\":\"keyword\"},\"" - + POOLING_MODE_FIELD + "\":{\"type\":\"keyword\"},\"" - + NORMALIZE_RESULT_FIELD + "\":{\"type\":\"boolean\"},\"" - + MODEL_MAX_LENGTH_FIELD + "\":{\"type\":\"integer\"},\"" - + ALL_CONFIG_FIELD + "\":{\"type\":\"text\"}}},\n" - + " \"" - + MLModel.DEPLOY_SETTING_FIELD - + "\" : {\"type\": \"flat_object\"},\n" - + " \"" - + MLModel.IS_ENABLED_FIELD - + "\" : {\"type\": \"boolean\"},\n" - + " \"" - + MLModel.IS_CONTROLLER_ENABLED_FIELD - + "\" : {\"type\": \"boolean\"},\n" - + " \"" - + MLModel.RATE_LIMITER_FIELD - + "\" : {\"type\": \"flat_object\"},\n" - + " \"" - + MLModel.MODEL_CONTENT_HASH_VALUE_FIELD - + "\" : {\"type\": \"keyword\"},\n" - + " \"" - + MLModel.AUTO_REDEPLOY_RETRY_TIMES_FIELD - + "\" : {\"type\": \"integer\"},\n" - + " \"" - + MLModel.CREATED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + MLModel.LAST_UPDATED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + MLModel.LAST_REGISTERED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + MLModel.LAST_DEPLOYED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + MLModel.LAST_UNDEPLOYED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + MLModel.INTERFACE_FIELD - + "\": {\"type\": \"flat_object\"},\n" - + " \"" - + MLModel.GUARDRAILS_FIELD - + "\" : {\n" - + " \"properties\": {\n" - + " \"input_guardrail\": {\n" - + " \"properties\": {\n" - + " \"regex\": {\n" - + " \"type\": \"text\"\n" - + " },\n" - + " \"stop_words\": {\n" - + " \"properties\": {\n" - + " \"index_name\": {\n" - + " \"type\": \"text\"\n" - + " },\n" - + " \"source_fields\": {\n" - + " \"type\": \"text\"\n" - + " }\n" - + " }\n" - + " }\n" - + " }\n" - + " },\n" - + " \"output_guardrail\": {\n" - + " \"properties\": {\n" - + " \"regex\": {\n" - + " \"type\": \"text\"\n" - + " },\n" - + " \"stop_words\": {\n" - + " \"properties\": {\n" - + " \"index_name\": {\n" - + " \"type\": \"text\"\n" - + " },\n" - + " \"source_fields\": {\n" - + " \"type\": \"text\"\n" - + " }\n" - + " }\n" - + " }\n" - + " }\n" - + " }\n" - + " }\n" - + " },\n" - + " \"" - + MLModel.CONNECTOR_FIELD - + "\": {" + ML_CONNECTOR_INDEX_FIELDS + " }\n}," - + USER_FIELD_MAPPING - + " }\n" - + "}"; + public static final String ML_MODEL_INDEX_MAPPING = "{\n" + + " \"_meta\": {\"schema_version\": " + + ML_MODEL_INDEX_SCHEMA_VERSION + + "},\n" + + " \"properties\": {\n" + + " \"" + + MLModel.ALGORITHM_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.MODEL_NAME_FIELD + + "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n" + + " \"" + + MLModel.OLD_MODEL_VERSION_FIELD + + "\" : {\"type\": \"long\"},\n" + + " \"" + + MLModel.MODEL_VERSION_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.MODEL_GROUP_ID_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.MODEL_CONTENT_FIELD + + "\" : {\"type\": \"binary\"},\n" + + " \"" + + MLModel.CHUNK_NUMBER_FIELD + + "\" : {\"type\": \"long\"},\n" + + " \"" + + MLModel.TOTAL_CHUNKS_FIELD + + "\" : {\"type\": \"long\"},\n" + + " \"" + + MLModel.MODEL_ID_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.DESCRIPTION_FIELD + + "\" : {\"type\": \"text\"},\n" + + " \"" + + MLModel.MODEL_FORMAT_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.MODEL_STATE_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.MODEL_CONTENT_SIZE_IN_BYTES_FIELD + + "\" : {\"type\": \"long\"},\n" + + " \"" + + MLModel.PLANNING_WORKER_NODE_COUNT_FIELD + + "\" : {\"type\": \"integer\"},\n" + + " \"" + + MLModel.CURRENT_WORKER_NODE_COUNT_FIELD + + "\" : {\"type\": \"integer\"},\n" + + " \"" + + MLModel.PLANNING_WORKER_NODES_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.DEPLOY_TO_ALL_NODES_FIELD + + "\": {\"type\": \"boolean\"},\n" + + " \"" + + MLModel.IS_HIDDEN_FIELD + + "\": {\"type\": \"boolean\"},\n" + + " \"" + + MLModel.MODEL_CONFIG_FIELD + + "\" : {\"properties\":{\"" + + MODEL_TYPE_FIELD + + "\":{\"type\":\"keyword\"},\"" + + EMBEDDING_DIMENSION_FIELD + + "\":{\"type\":\"integer\"},\"" + + FRAMEWORK_TYPE_FIELD + + "\":{\"type\":\"keyword\"},\"" + + POOLING_MODE_FIELD + + "\":{\"type\":\"keyword\"},\"" + + NORMALIZE_RESULT_FIELD + + "\":{\"type\":\"boolean\"},\"" + + MODEL_MAX_LENGTH_FIELD + + "\":{\"type\":\"integer\"},\"" + + ALL_CONFIG_FIELD + + "\":{\"type\":\"text\"}}},\n" + + " \"" + + MLModel.DEPLOY_SETTING_FIELD + + "\" : {\"type\": \"flat_object\"},\n" + + " \"" + + MLModel.IS_ENABLED_FIELD + + "\" : {\"type\": \"boolean\"},\n" + + " \"" + + MLModel.IS_CONTROLLER_ENABLED_FIELD + + "\" : {\"type\": \"boolean\"},\n" + + " \"" + + MLModel.RATE_LIMITER_FIELD + + "\" : {\"type\": \"flat_object\"},\n" + + " \"" + + MLModel.MODEL_CONTENT_HASH_VALUE_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.AUTO_REDEPLOY_RETRY_TIMES_FIELD + + "\" : {\"type\": \"integer\"},\n" + + " \"" + + MLModel.CREATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLModel.LAST_UPDATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLModel.LAST_REGISTERED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLModel.LAST_DEPLOYED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLModel.LAST_UNDEPLOYED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLModel.INTERFACE_FIELD + + "\": {\"type\": \"flat_object\"},\n" + + " \"" + + MLModel.GUARDRAILS_FIELD + + "\" : {\n" + + " \"properties\": {\n" + + " \"input_guardrail\": {\n" + + " \"properties\": {\n" + + " \"regex\": {\n" + + " \"type\": \"text\"\n" + + " },\n" + + " \"stop_words\": {\n" + + " \"properties\": {\n" + + " \"index_name\": {\n" + + " \"type\": \"text\"\n" + + " },\n" + + " \"source_fields\": {\n" + + " \"type\": \"text\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " },\n" + + " \"output_guardrail\": {\n" + + " \"properties\": {\n" + + " \"regex\": {\n" + + " \"type\": \"text\"\n" + + " },\n" + + " \"stop_words\": {\n" + + " \"properties\": {\n" + + " \"index_name\": {\n" + + " \"type\": \"text\"\n" + + " },\n" + + " \"source_fields\": {\n" + + " \"type\": \"text\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " },\n" + + " \"" + + MLModel.CONNECTOR_FIELD + + "\": {" + + ML_CONNECTOR_INDEX_FIELDS + + " }\n}," + + USER_FIELD_MAPPING + + " }\n" + + "}"; - public static final String ML_TASK_INDEX_MAPPING = "{\n" - + " \"_meta\": {\"schema_version\": " - + ML_TASK_INDEX_SCHEMA_VERSION - + "},\n" - + " \"properties\": {\n" - + " \"" - + MLTask.MODEL_ID_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLTask.TASK_TYPE_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLTask.FUNCTION_NAME_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLTask.STATE_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLTask.INPUT_TYPE_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLTask.PROGRESS_FIELD - + "\": {\"type\": \"float\"},\n" - + " \"" - + MLTask.OUTPUT_INDEX_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLTask.WORKER_NODE_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLTask.CREATE_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + MLTask.LAST_UPDATE_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + MLTask.ERROR_FIELD - + "\": {\"type\": \"text\"},\n" - + " \"" - + MLTask.IS_ASYNC_TASK_FIELD - + "\" : {\"type\" : \"boolean\"}, \n" - + USER_FIELD_MAPPING - + " }\n" - + "}"; + public static final String ML_TASK_INDEX_MAPPING = "{\n" + + " \"_meta\": {\"schema_version\": " + + ML_TASK_INDEX_SCHEMA_VERSION + + "},\n" + + " \"properties\": {\n" + + " \"" + + MLTask.MODEL_ID_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.TASK_TYPE_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.FUNCTION_NAME_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.STATE_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.INPUT_TYPE_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.PROGRESS_FIELD + + "\": {\"type\": \"float\"},\n" + + " \"" + + MLTask.OUTPUT_INDEX_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.WORKER_NODE_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.CREATE_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLTask.LAST_UPDATE_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLTask.ERROR_FIELD + + "\": {\"type\": \"text\"},\n" + + " \"" + + MLTask.IS_ASYNC_TASK_FIELD + + "\" : {\"type\" : \"boolean\"}, \n" + + USER_FIELD_MAPPING + + " }\n" + + "}"; - public static final String ML_CONNECTOR_INDEX_MAPPING = "{\n" - + " \"_meta\": {\"schema_version\": " - + ML_CONNECTOR_SCHEMA_VERSION - + "},\n" - + ML_CONNECTOR_INDEX_FIELDS + ",\n" - + " \"" - + MLModelGroup.BACKEND_ROLES_FIELD - + "\": {\n" - + " \"type\": \"text\",\n" - + " \"fields\": {\n" - + " \"keyword\": {\n" - + " \"type\": \"keyword\",\n" - + " \"ignore_above\": 256\n" - + " }\n" - + " }\n" - + " },\n" - + " \"" - + MLModelGroup.ACCESS - + "\": {\n" - + " \"type\": \"keyword\"\n" - + " },\n" - + " \"" - + MLModelGroup.OWNER - + "\": {\n" - + " \"type\": \"nested\",\n" - + " \"properties\": {\n" - + " \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n" - + " \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" - + " \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" - + " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" - + " }\n" - + " },\n" - + " \"" - + AbstractConnector.CREATED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + AbstractConnector.LAST_UPDATED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" - + " }\n" - + "}"; + public static final String ML_CONNECTOR_INDEX_MAPPING = "{\n" + + " \"_meta\": {\"schema_version\": " + + ML_CONNECTOR_SCHEMA_VERSION + + "},\n" + + ML_CONNECTOR_INDEX_FIELDS + + ",\n" + + " \"" + + MLModelGroup.BACKEND_ROLES_FIELD + + "\": {\n" + + " \"type\": \"text\",\n" + + " \"fields\": {\n" + + " \"keyword\": {\n" + + " \"type\": \"keyword\",\n" + + " \"ignore_above\": 256\n" + + " }\n" + + " }\n" + + " },\n" + + " \"" + + MLModelGroup.ACCESS + + "\": {\n" + + " \"type\": \"keyword\"\n" + + " },\n" + + " \"" + + MLModelGroup.OWNER + + "\": {\n" + + " \"type\": \"nested\",\n" + + " \"properties\": {\n" + + " \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n" + + " \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + + " \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + + " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" + + " }\n" + + " },\n" + + " \"" + + AbstractConnector.CREATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + AbstractConnector.LAST_UPDATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" + + " }\n" + + "}"; - public static final String ML_CONFIG_INDEX_MAPPING = "{\n" - + " \"_meta\": {\"schema_version\": " - + ML_CONFIG_INDEX_SCHEMA_VERSION - + "},\n" - + " \"properties\": {\n" - + " \"" - + MASTER_KEY - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + CONFIG_TYPE_FIELD - + "\" : {\"type\":\"keyword\"},\n" - + " \"" - + ML_CONFIGURATION_FIELD - + "\" : {\"type\": \"flat_object\"},\n" - + " \"" - + CREATE_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + LAST_UPDATED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" - + " }\n" - + "}"; + public static final String ML_CONFIG_INDEX_MAPPING = "{\n" + + " \"_meta\": {\"schema_version\": " + + ML_CONFIG_INDEX_SCHEMA_VERSION + + "},\n" + + " \"properties\": {\n" + + " \"" + + MASTER_KEY + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + CONFIG_TYPE_FIELD + + "\" : {\"type\":\"keyword\"},\n" + + " \"" + + ML_CONFIGURATION_FIELD + + "\" : {\"type\": \"flat_object\"},\n" + + " \"" + + CREATE_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + LAST_UPDATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" + + " }\n" + + "}"; - public static final String ML_CONTROLLER_INDEX_MAPPING = "{\n" - + " \"_meta\": {\"schema_version\": " - + ML_CONTROLLER_INDEX_SCHEMA_VERSION - + "},\n" - + " \"properties\": {\n" - + " \"" - + MLController.USER_RATE_LIMITER - + "\" : {\"type\": \"flat_object\"}\n" - + " }\n" - + "}"; + public static final String ML_CONTROLLER_INDEX_MAPPING = "{\n" + + " \"_meta\": {\"schema_version\": " + + ML_CONTROLLER_INDEX_SCHEMA_VERSION + + "},\n" + + " \"properties\": {\n" + + " \"" + + MLController.USER_RATE_LIMITER + + "\" : {\"type\": \"flat_object\"}\n" + + " }\n" + + "}"; - public static final String ML_AGENT_INDEX_MAPPING = "{\n" - + " \"_meta\": {\"schema_version\": " - + ML_AGENT_INDEX_SCHEMA_VERSION - + "},\n" - + " \"properties\": {\n" - + " \"" - + MLAgent.AGENT_NAME_FIELD - + "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n" - + " \"" - + MLAgent.AGENT_TYPE_FIELD - + "\" : {\"type\":\"keyword\"},\n" - + " \"" - + MLAgent.DESCRIPTION_FIELD - + "\" : {\"type\": \"text\"},\n" - + " \"" - + MLAgent.LLM_FIELD - + "\" : {\"type\": \"flat_object\"},\n" - + " \"" - + MLAgent.TOOLS_FIELD - + "\" : {\"type\": \"flat_object\"},\n" - + " \"" - + MLAgent.PARAMETERS_FIELD - + "\" : {\"type\": \"flat_object\"},\n" - + " \"" - + MLAgent.MEMORY_FIELD - + "\" : {\"type\": \"flat_object\"},\n" - + " \"" - + MLAgent.IS_HIDDEN_FIELD - + "\": {\"type\": \"boolean\"},\n" - + " \"" - + MLAgent.CREATED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + MLAgent.LAST_UPDATED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" - + " }\n" - + "}"; + public static final String ML_AGENT_INDEX_MAPPING = "{\n" + + " \"_meta\": {\"schema_version\": " + + ML_AGENT_INDEX_SCHEMA_VERSION + + "},\n" + + " \"properties\": {\n" + + " \"" + + MLAgent.AGENT_NAME_FIELD + + "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n" + + " \"" + + MLAgent.AGENT_TYPE_FIELD + + "\" : {\"type\":\"keyword\"},\n" + + " \"" + + MLAgent.DESCRIPTION_FIELD + + "\" : {\"type\": \"text\"},\n" + + " \"" + + MLAgent.LLM_FIELD + + "\" : {\"type\": \"flat_object\"},\n" + + " \"" + + MLAgent.TOOLS_FIELD + + "\" : {\"type\": \"flat_object\"},\n" + + " \"" + + MLAgent.PARAMETERS_FIELD + + "\" : {\"type\": \"flat_object\"},\n" + + " \"" + + MLAgent.MEMORY_FIELD + + "\" : {\"type\": \"flat_object\"},\n" + + " \"" + + MLAgent.IS_HIDDEN_FIELD + + "\": {\"type\": \"boolean\"},\n" + + " \"" + + MLAgent.CREATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLAgent.LAST_UPDATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" + + " }\n" + + "}"; - public static final String ML_MEMORY_META_INDEX_MAPPING = "{\n" - + " \"_meta\": {\n" - + " \"schema_version\": " + META_INDEX_SCHEMA_VERSION + "\n" - + " },\n" - + " \"properties\": {\n" - + " \"" - + META_NAME_FIELD - + "\": {\"type\": \"text\"},\n" - + " \"" - + META_CREATED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + META_UPDATED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + USER_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + APPLICATION_TYPE_FIELD - + "\": {\"type\": \"keyword\"}\n" - + " }\n" - + "}"; + public static final String ML_MEMORY_META_INDEX_MAPPING = "{\n" + + " \"_meta\": {\n" + + " \"schema_version\": " + + META_INDEX_SCHEMA_VERSION + + "\n" + + " },\n" + + " \"properties\": {\n" + + " \"" + + META_NAME_FIELD + + "\": {\"type\": \"text\"},\n" + + " \"" + + META_CREATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + META_UPDATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + USER_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + APPLICATION_TYPE_FIELD + + "\": {\"type\": \"keyword\"}\n" + + " }\n" + + "}"; - public static final String ML_MEMORY_MESSAGE_INDEX_MAPPING = "{\n" - + " \"_meta\": {\n" - + " \"schema_version\": " + INTERACTIONS_INDEX_SCHEMA_VERSION + "\n" - + " },\n" - + " \"properties\": {\n" - + " \"" - + INTERACTIONS_CONVERSATION_ID_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + INTERACTIONS_CREATE_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + INTERACTIONS_INPUT_FIELD - + "\": {\"type\": \"text\"},\n" - + " \"" - + INTERACTIONS_PROMPT_TEMPLATE_FIELD - + "\": {\"type\": \"text\"},\n" - + " \"" - + INTERACTIONS_RESPONSE_FIELD - + "\": {\"type\": \"text\"},\n" - + " \"" - + INTERACTIONS_ORIGIN_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + INTERACTIONS_ADDITIONAL_INFO_FIELD - + "\": {\"type\": \"flat_object\"},\n" - + " \"" - + PARENT_INTERACTIONS_ID_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + INTERACTIONS_TRACE_NUMBER_FIELD - + "\": {\"type\": \"long\"}\n" - + " }\n" - + "}"; - // Calculate Versions independently of OpenSearch core version - public static final Version VERSION_2_11_0 = Version.fromString("2.11.0"); - public static final Version VERSION_2_12_0 = Version.fromString("2.12.0"); - public static final Version VERSION_2_13_0 = Version.fromString("2.13.0"); - public static final Version VERSION_2_14_0 = Version.fromString("2.14.0"); - public static final Version VERSION_2_16_0 = Version.fromString("2.16.0"); - public static final Version VERSION_2_17_0 = Version.fromString("2.17.0"); + public static final String ML_MEMORY_MESSAGE_INDEX_MAPPING = "{\n" + + " \"_meta\": {\n" + + " \"schema_version\": " + + INTERACTIONS_INDEX_SCHEMA_VERSION + + "\n" + + " },\n" + + " \"properties\": {\n" + + " \"" + + INTERACTIONS_CONVERSATION_ID_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + INTERACTIONS_CREATE_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + INTERACTIONS_INPUT_FIELD + + "\": {\"type\": \"text\"},\n" + + " \"" + + INTERACTIONS_PROMPT_TEMPLATE_FIELD + + "\": {\"type\": \"text\"},\n" + + " \"" + + INTERACTIONS_RESPONSE_FIELD + + "\": {\"type\": \"text\"},\n" + + " \"" + + INTERACTIONS_ORIGIN_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + INTERACTIONS_ADDITIONAL_INFO_FIELD + + "\": {\"type\": \"flat_object\"},\n" + + " \"" + + PARENT_INTERACTIONS_ID_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + INTERACTIONS_TRACE_NUMBER_FIELD + + "\": {\"type\": \"long\"}\n" + + " }\n" + + "}"; + // Calculate Versions independently of OpenSearch core version + public static final Version VERSION_2_11_0 = Version.fromString("2.11.0"); + public static final Version VERSION_2_12_0 = Version.fromString("2.12.0"); + public static final Version VERSION_2_13_0 = Version.fromString("2.13.0"); + public static final Version VERSION_2_14_0 = Version.fromString("2.14.0"); + public static final Version VERSION_2_16_0 = Version.fromString("2.16.0"); + public static final Version VERSION_2_17_0 = Version.fromString("2.17.0"); } diff --git a/common/src/main/java/org/opensearch/ml/common/Configuration.java b/common/src/main/java/org/opensearch/ml/common/Configuration.java index fa5a1bfe22..5154b1e1db 100644 --- a/common/src/main/java/org/opensearch/ml/common/Configuration.java +++ b/common/src/main/java/org/opensearch/ml/common/Configuration.java @@ -5,10 +5,10 @@ package org.opensearch.ml.common; -import lombok.Builder; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.Setter; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -16,9 +16,10 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import java.io.IOException; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.Setter; @Getter @EqualsAndHashCode @@ -30,9 +31,7 @@ public class Configuration implements ToXContentObject, Writeable { private String agentId; @Builder(toBuilder = true) - public Configuration( - String agentId - ) { + public Configuration(String agentId) { this.agentId = agentId; } @@ -76,8 +75,6 @@ public static Configuration parse(XContentParser parser) throws IOException { break; } } - return Configuration.builder() - .agentId(agentId) - .build(); + return Configuration.builder().agentId(agentId).build(); } } diff --git a/common/src/main/java/org/opensearch/ml/common/FunctionName.java b/common/src/main/java/org/opensearch/ml/common/FunctionName.java index cf308f1d8d..96df6baa55 100644 --- a/common/src/main/java/org/opensearch/ml/common/FunctionName.java +++ b/common/src/main/java/org/opensearch/ml/common/FunctionName.java @@ -41,13 +41,9 @@ public static FunctionName from(String value) { } } - private static final HashSet DL_MODELS = new HashSet<>(Set.of( - TEXT_EMBEDDING, - TEXT_SIMILARITY, - SPARSE_ENCODING, - SPARSE_TOKENIZE, - QUESTION_ANSWERING - )); + private static final HashSet DL_MODELS = new HashSet<>( + Set.of(TEXT_EMBEDDING, TEXT_SIMILARITY, SPARSE_ENCODING, SPARSE_TOKENIZE, QUESTION_ANSWERING) + ); /** * Check if model is deep learning model. diff --git a/common/src/main/java/org/opensearch/ml/common/MLCommonsClassLoader.java b/common/src/main/java/org/opensearch/ml/common/MLCommonsClassLoader.java index 75046e9bfd..1709923c48 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLCommonsClassLoader.java +++ b/common/src/main/java/org/opensearch/ml/common/MLCommonsClassLoader.java @@ -5,7 +5,15 @@ package org.opensearch.ml.common; -import lombok.extern.log4j.Log4j2; +import java.io.IOException; +import java.lang.reflect.Constructor; +import java.security.AccessController; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.ml.common.annotation.Connector; import org.opensearch.ml.common.annotation.ExecuteInput; @@ -20,14 +28,7 @@ import org.opensearch.ml.common.output.MLOutputType; import org.reflections.Reflections; -import java.io.IOException; -import java.lang.reflect.Constructor; -import java.security.AccessController; -import java.security.PrivilegedActionException; -import java.security.PrivilegedExceptionAction; -import java.util.HashMap; -import java.util.Map; -import java.util.Set; +import lombok.extern.log4j.Log4j2; @Log4j2 @SuppressWarnings("removal") @@ -93,7 +94,7 @@ private static void loadMLAlgoParameterClassMapping() { if (mlAlgoParameter != null) { FunctionName[] algorithms = mlAlgoParameter.algorithms(); if (algorithms != null && algorithms.length > 0) { - for(FunctionName name : algorithms){ + for (FunctionName name : algorithms) { parameterClassMap.put(name, clazz); } } @@ -157,7 +158,7 @@ private static void loadExecuteInputClassMapping() { if (executeInput != null) { FunctionName[] algorithms = executeInput.algorithms(); if (algorithms != null && algorithms.length > 0) { - for(FunctionName name : algorithms){ + for (FunctionName name : algorithms) { executeInputClassMap.put(name, clazz); } } @@ -176,7 +177,7 @@ private static void loadExecuteOutputClassMapping() { if (executeOutput != null) { FunctionName[] algorithms = executeOutput.algorithms(); if (algorithms != null && algorithms.length > 0) { - for(FunctionName name : algorithms){ + for (FunctionName name : algorithms) { executeOutputClassMap.put(name, clazz); } } @@ -192,7 +193,7 @@ private static void loadMLInputClassMapping() { if (mlInput != null) { FunctionName[] algorithms = mlInput.functionNames(); if (algorithms != null && algorithms.length > 0) { - for(FunctionName name : algorithms){ + for (FunctionName name : algorithms) { mlInputClassMap.put(name, clazz); } } @@ -242,7 +243,7 @@ private static S init(Map> map, T type, I i } catch (Exception e) { Throwable cause = e.getCause(); if (cause instanceof MLException || cause instanceof IllegalArgumentException) { - throw (RuntimeException)cause; + throw (RuntimeException) cause; } else { log.error("Failed to init instance for type " + type, e); return null; @@ -254,19 +255,16 @@ public static boolean canInitMLInput(FunctionName functionName) { return mlInputClassMap.containsKey(functionName); } - public static S initConnector(String name, Object[] initArgs, - Class... constructorParameterTypes) { + public static S initConnector(String name, Object[] initArgs, Class... constructorParameterTypes) { return init(connectorClassMap, name, initArgs, constructorParameterTypes); } @SuppressWarnings("unchecked") - public static , S> S initMLInput(T type, Object[] initArgs, - Class... constructorParameterTypes) { + public static , S> S initMLInput(T type, Object[] initArgs, Class... constructorParameterTypes) { return init(mlInputClassMap, type, initArgs, constructorParameterTypes); } - private static S init(Map> map, T type, - Object[] initArgs, Class... constructorParameterTypes) { + private static S init(Map> map, T type, Object[] initArgs, Class... constructorParameterTypes) { Class clazz = map.get(type); if (clazz == null) { throw new IllegalArgumentException("Can't find class for type " + type); @@ -277,7 +275,7 @@ private static S init(Map> map, T type, } catch (Exception e) { Throwable cause = e.getCause(); if (cause instanceof MLException) { - throw (MLException)cause; + throw (MLException) cause; } else if (cause instanceof IllegalArgumentException) { throw (IllegalArgumentException) cause; } else { diff --git a/common/src/main/java/org/opensearch/ml/common/MLConfig.java b/common/src/main/java/org/opensearch/ml/common/MLConfig.java index bbcbb4aee1..ccdeed5df3 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLConfig.java +++ b/common/src/main/java/org/opensearch/ml/common/MLConfig.java @@ -5,10 +5,11 @@ package org.opensearch.ml.common; -import lombok.Builder; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.Setter; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.time.Instant; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -16,10 +17,10 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import java.io.IOException; -import java.time.Instant; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.Setter; @Getter @EqualsAndHashCode @@ -38,7 +39,6 @@ public class MLConfig implements ToXContentObject, Writeable { public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; - @Setter private String type; @@ -47,12 +47,7 @@ public class MLConfig implements ToXContentObject, Writeable { private Instant lastUpdateTime; @Builder(toBuilder = true) - public MLConfig( - String type, - Configuration configuration, - Instant createTime, - Instant lastUpdateTime - ) { + public MLConfig(String type, Configuration configuration, Instant createTime, Instant lastUpdateTime) { this.type = type; this.configuration = configuration; this.createTime = createTime; @@ -145,11 +140,12 @@ public static MLConfig parse(XContentParser parser) throws IOException { break; } } - return MLConfig.builder() - .type(configType == null ? type : configType) - .configuration(mlConfiguration == null ? configuration : mlConfiguration) - .createTime(createTime) - .lastUpdateTime(lastUpdatedTime == null ? lastUpdateTime : lastUpdatedTime) - .build(); + return MLConfig + .builder() + .type(configType == null ? type : configType) + .configuration(mlConfiguration == null ? configuration : mlConfiguration) + .createTime(createTime) + .lastUpdateTime(lastUpdatedTime == null ? lastUpdateTime : lastUpdatedTime) + .build(); } } diff --git a/common/src/main/java/org/opensearch/ml/common/MLModel.java b/common/src/main/java/org/opensearch/ml/common/MLModel.java index 363fa4bb7d..a742f542c5 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModel.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModel.java @@ -5,42 +5,42 @@ package org.opensearch.ml.common; -import lombok.Builder; -import lombok.Getter; -import lombok.Setter; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.USER; +import static org.opensearch.ml.common.connector.Connector.createConnector; +import static org.opensearch.ml.common.utils.StringUtils.filteredParameterMap; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; + +import org.opensearch.commons.authuser.User; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.commons.authuser.User; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.controller.MLRateLimiter; import org.opensearch.ml.common.model.Guardrails; import org.opensearch.ml.common.model.MLDeploySetting; import org.opensearch.ml.common.model.MLModelConfig; -import org.opensearch.ml.common.controller.MLRateLimiter; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelState; +import org.opensearch.ml.common.model.MetricsCorrelationModelConfig; import org.opensearch.ml.common.model.QuestionAnsweringModelConfig; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; -import org.opensearch.ml.common.model.MetricsCorrelationModelConfig; - -import java.io.IOException; -import java.time.Instant; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.Set; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.CommonValue.USER; -import static org.opensearch.ml.common.connector.Connector.createConnector; -import static org.opensearch.ml.common.utils.StringUtils.filteredParameterMap; +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; @Getter public class MLModel implements ToXContentObject { @@ -174,39 +174,42 @@ public class MLModel implements ToXContentObject { private Map modelInterface; @Builder(toBuilder = true) - public MLModel(String name, - String modelGroupId, - FunctionName algorithm, - String version, - String content, - User user, - String description, - MLModelFormat modelFormat, - MLModelState modelState, - Long modelContentSizeInBytes, - String modelContentHash, - Boolean isEnabled, - Boolean isControllerEnabled, - MLRateLimiter rateLimiter, - MLModelConfig modelConfig, - MLDeploySetting deploySetting, - Instant createdTime, - Instant lastUpdateTime, - Instant lastRegisteredTime, - Instant lastDeployedTime, - Instant lastUndeployedTime, - Integer autoRedeployRetryTimes, - String modelId, Integer chunkNumber, - Integer totalChunks, - Integer planningWorkerNodeCount, - Integer currentWorkerNodeCount, - String[] planningWorkerNodes, - boolean deployToAllNodes, - Boolean isHidden, - Connector connector, - String connectorId, - Guardrails guardrails, - Map modelInterface) { + public MLModel( + String name, + String modelGroupId, + FunctionName algorithm, + String version, + String content, + User user, + String description, + MLModelFormat modelFormat, + MLModelState modelState, + Long modelContentSizeInBytes, + String modelContentHash, + Boolean isEnabled, + Boolean isControllerEnabled, + MLRateLimiter rateLimiter, + MLModelConfig modelConfig, + MLDeploySetting deploySetting, + Instant createdTime, + Instant lastUpdateTime, + Instant lastRegisteredTime, + Instant lastDeployedTime, + Instant lastUndeployedTime, + Integer autoRedeployRetryTimes, + String modelId, + Integer chunkNumber, + Integer totalChunks, + Integer planningWorkerNodeCount, + Integer currentWorkerNodeCount, + String[] planningWorkerNodes, + boolean deployToAllNodes, + Boolean isHidden, + Connector connector, + String connectorId, + Guardrails guardrails, + Map modelInterface + ) { this.name = name; this.modelGroupId = modelGroupId; this.algorithm = algorithm; @@ -679,42 +682,43 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws break; } } - return MLModel.builder() - .name(name) - .modelGroupId(modelGroupId) - .algorithm(algorithm) - .version(version == null ? oldVersion + "" : version) - .content(content == null ? oldContent : content) - .user(user) - .description(description) - .modelFormat(modelFormat) - .modelState(modelState) - .modelContentSizeInBytes(modelContentSizeInBytes) - .modelContentHash(modelContentHash) - .modelConfig(modelConfig) - .deploySetting(deploySetting) - .isEnabled(isEnabled) - .isControllerEnabled(isControllerEnabled) - .rateLimiter(rateLimiter) - .createdTime(createdTime) - .lastUpdateTime(lastUpdateTime) - .lastRegisteredTime(lastRegisteredTime == null ? lastUploadedTime : lastRegisteredTime) - .lastDeployedTime(lastDeployedTime == null ? lastLoadedTime : lastDeployedTime) - .lastUndeployedTime(lastUndeployedTime == null ? lastUnloadedTime : lastUndeployedTime) - .modelId(modelId) - .autoRedeployRetryTimes(autoRedeployRetryTimes) - .chunkNumber(chunkNumber) - .totalChunks(totalChunks) - .planningWorkerNodeCount(planningWorkerNodeCount) - .currentWorkerNodeCount(currentWorkerNodeCount) - .planningWorkerNodes(planningWorkerNodes.toArray(new String[0])) - .deployToAllNodes(deployToAllNodes) - .isHidden(isHidden) - .connector(connector) - .connectorId(connectorId) - .guardrails(guardrails) - .modelInterface(modelInterface) - .build(); + return MLModel + .builder() + .name(name) + .modelGroupId(modelGroupId) + .algorithm(algorithm) + .version(version == null ? oldVersion + "" : version) + .content(content == null ? oldContent : content) + .user(user) + .description(description) + .modelFormat(modelFormat) + .modelState(modelState) + .modelContentSizeInBytes(modelContentSizeInBytes) + .modelContentHash(modelContentHash) + .modelConfig(modelConfig) + .deploySetting(deploySetting) + .isEnabled(isEnabled) + .isControllerEnabled(isControllerEnabled) + .rateLimiter(rateLimiter) + .createdTime(createdTime) + .lastUpdateTime(lastUpdateTime) + .lastRegisteredTime(lastRegisteredTime == null ? lastUploadedTime : lastRegisteredTime) + .lastDeployedTime(lastDeployedTime == null ? lastLoadedTime : lastDeployedTime) + .lastUndeployedTime(lastUndeployedTime == null ? lastUnloadedTime : lastUndeployedTime) + .modelId(modelId) + .autoRedeployRetryTimes(autoRedeployRetryTimes) + .chunkNumber(chunkNumber) + .totalChunks(totalChunks) + .planningWorkerNodeCount(planningWorkerNodeCount) + .currentWorkerNodeCount(currentWorkerNodeCount) + .planningWorkerNodes(planningWorkerNodes.toArray(new String[0])) + .deployToAllNodes(deployToAllNodes) + .isHidden(isHidden) + .connector(connector) + .connectorId(connectorId) + .guardrails(guardrails) + .modelInterface(modelInterface) + .build(); } public static MLModel fromStream(StreamInput in) throws IOException { diff --git a/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java b/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java index 8a8cf5ff04..91b21131d4 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java @@ -11,9 +11,8 @@ import java.time.Instant; import java.util.ArrayList; import java.util.List; -import lombok.Builder; -import lombok.Getter; -import lombok.Setter; +import java.util.Objects; + import org.opensearch.commons.authuser.User; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -21,23 +20,23 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import java.util.Objects; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; @Getter public class MLModelGroup implements ToXContentObject { - public static final String MODEL_GROUP_NAME_FIELD = "name"; //name of the model group - public static final String DESCRIPTION_FIELD = "description"; //description of the model group - public static final String LATEST_VERSION_FIELD = "latest_version"; //latest model version added to the model group - public static final String BACKEND_ROLES_FIELD = "backend_roles"; //back_end roles as specified by the owner/admin - public static final String OWNER = "owner"; //user who creates/owns the model group - - public static final String ACCESS = "access"; //assigned to public, private, or null when model group created - public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; //unique ID assigned to each model group - public static final String CREATED_TIME_FIELD = "created_time"; //model group created time stamp - public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; //updated whenever a new model version is created + public static final String MODEL_GROUP_NAME_FIELD = "name"; // name of the model group + public static final String DESCRIPTION_FIELD = "description"; // description of the model group + public static final String LATEST_VERSION_FIELD = "latest_version"; // latest model version added to the model group + public static final String BACKEND_ROLES_FIELD = "backend_roles"; // back_end roles as specified by the owner/admin + public static final String OWNER = "owner"; // user who creates/owns the model group + public static final String ACCESS = "access"; // assigned to public, private, or null when model group created + public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; // unique ID assigned to each model group + public static final String CREATED_TIME_FIELD = "created_time"; // model group created time stamp + public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; // updated whenever a new model version is created @Setter private String name; @@ -53,13 +52,18 @@ public class MLModelGroup implements ToXContentObject { private Instant createdTime; private Instant lastUpdatedTime; - @Builder(toBuilder = true) - public MLModelGroup(String name, String description, int latestVersion, - List backendRoles, User owner, String access, - String modelGroupId, - Instant createdTime, - Instant lastUpdatedTime) { + public MLModelGroup( + String name, + String description, + int latestVersion, + List backendRoles, + User owner, + String access, + String modelGroupId, + Instant createdTime, + Instant lastUpdatedTime + ) { this.name = Objects.requireNonNull(name, "model group name must not be null"); this.description = description; this.latestVersion = latestVersion; @@ -71,8 +75,7 @@ public MLModelGroup(String name, String description, int latestVersion, this.lastUpdatedTime = lastUpdatedTime; } - - public MLModelGroup(StreamInput input) throws IOException{ + public MLModelGroup(StreamInput input) throws IOException { name = input.readString(); description = input.readOptionalString(); latestVersion = input.readInt(); @@ -195,20 +198,20 @@ public static MLModelGroup parse(XContentParser parser) throws IOException { break; } } - return MLModelGroup.builder() - .name(name) - .description(description) - .backendRoles(backendRoles) - .latestVersion(latestVersion) - .owner(owner) - .access(access) - .modelGroupId(modelGroupId) - .createdTime(createdTime) - .lastUpdatedTime(lastUpdateTime) - .build(); + return MLModelGroup + .builder() + .name(name) + .description(description) + .backendRoles(backendRoles) + .latestVersion(latestVersion) + .owner(owner) + .access(access) + .modelGroupId(modelGroupId) + .createdTime(createdTime) + .lastUpdatedTime(lastUpdateTime) + .build(); } - public static MLModelGroup fromStream(StreamInput in) throws IOException { MLModelGroup mlModel = new MLModelGroup(in); return mlModel; diff --git a/common/src/main/java/org/opensearch/ml/common/MLTask.java b/common/src/main/java/org/opensearch/ml/common/MLTask.java index 229bba5771..a810fa5159 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLTask.java +++ b/common/src/main/java/org/opensearch/ml/common/MLTask.java @@ -5,27 +5,28 @@ package org.opensearch.ml.common; -import lombok.Builder; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.Setter; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.USER; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.opensearch.commons.authuser.User; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; -import org.opensearch.commons.authuser.User; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.dataset.MLInputDataType; -import java.io.IOException; -import java.time.Instant; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.CommonValue.USER; +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.Setter; @Getter @EqualsAndHashCode @@ -279,21 +280,22 @@ public static MLTask parse(XContentParser parser) throws IOException { break; } } - return MLTask.builder() - .taskId(taskId) - .modelId(modelId) - .taskType(taskType) - .functionName(functionName) - .state(state) - .inputType(inputType) - .progress(progress) - .outputIndex(outputIndex) - .workerNodes(workerNodes) - .createTime(createTime) - .lastUpdateTime(lastUpdateTime) - .error(error) - .user(user) - .async(async) - .build(); + return MLTask + .builder() + .taskId(taskId) + .modelId(modelId) + .taskType(taskType) + .functionName(functionName) + .state(state) + .inputType(inputType) + .progress(progress) + .outputIndex(outputIndex) + .workerNodes(workerNodes) + .createTime(createTime) + .lastUpdateTime(lastUpdateTime) + .error(error) + .user(user) + .async(async) + .build(); } } diff --git a/common/src/main/java/org/opensearch/ml/common/ToolMetadata.java b/common/src/main/java/org/opensearch/ml/common/ToolMetadata.java index fa9c29ead5..19011bf407 100644 --- a/common/src/main/java/org/opensearch/ml/common/ToolMetadata.java +++ b/common/src/main/java/org/opensearch/ml/common/ToolMetadata.java @@ -4,8 +4,10 @@ */ package org.opensearch.ml.common; -import lombok.Builder; -import lombok.Getter; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -14,10 +16,8 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import java.io.IOException; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; - +import lombok.Builder; +import lombok.Getter; public class ToolMetadata implements ToXContentObject, Writeable { @@ -26,7 +26,6 @@ public class ToolMetadata implements ToXContentObject, Writeable { public static final String TOOL_TYPE_FIELD = "type"; public static final String TOOL_VERSION_FIELD = "version"; - @Getter private String name; @Getter @@ -103,16 +102,11 @@ public static ToolMetadata parse(XContentParser parser) throws IOException { break; } } - return ToolMetadata.builder() - .name(name) - .description(description) - .type(type) - .version(version) - .build(); + return ToolMetadata.builder().name(name).description(description).type(type).version(version).build(); } public static ToolMetadata fromStream(StreamInput in) throws IOException { ToolMetadata toolMetadata = new ToolMetadata(in); return toolMetadata; } -} \ No newline at end of file +} diff --git a/common/src/main/java/org/opensearch/ml/common/agent/LLMSpec.java b/common/src/main/java/org/opensearch/ml/common/agent/LLMSpec.java index 561fe81d5f..6a91e24644 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/LLMSpec.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/LLMSpec.java @@ -5,20 +5,21 @@ package org.opensearch.ml.common.agent; -import lombok.Builder; -import lombok.EqualsAndHashCode; -import lombok.Getter; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; + +import java.io.IOException; +import java.util.Map; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import java.io.IOException; -import java.util.Map; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; @EqualsAndHashCode @Getter @@ -29,7 +30,6 @@ public class LLMSpec implements ToXContentObject { private String modelId; private Map parameters; - @Builder(toBuilder = true) public LLMSpec(String modelId, Map parameters) { if (modelId == null) { @@ -39,7 +39,7 @@ public LLMSpec(String modelId, Map parameters) { this.parameters = parameters; } - public LLMSpec(StreamInput input) throws IOException{ + public LLMSpec(StreamInput input) throws IOException { modelId = input.readString(); if (input.readBoolean()) { parameters = input.readMap(StreamInput::readString, StreamInput::readOptionalString); @@ -90,10 +90,7 @@ public static LLMSpec parse(XContentParser parser) throws IOException { break; } } - return LLMSpec.builder() - .modelId(modelId) - .parameters(parameters) - .build(); + return LLMSpec.builder().modelId(modelId).parameters(parameters).build(); } public static LLMSpec fromStream(StreamInput in) throws IOException { diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java index a7a67d2e00..dd7872c91b 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java @@ -5,19 +5,8 @@ package org.opensearch.ml.common.agent; -import lombok.Builder; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import org.opensearch.Version; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.common.io.stream.Writeable; -import org.opensearch.core.xcontent.ToXContentObject; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.CommonValue; -import org.opensearch.ml.common.MLAgentType; -import org.opensearch.ml.common.MLModel; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; import java.io.IOException; import java.time.Instant; @@ -29,8 +18,20 @@ import java.util.Optional; import java.util.Set; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; +import org.opensearch.Version; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.CommonValue; +import org.opensearch.ml.common.MLAgentType; +import org.opensearch.ml.common.MLModel; + +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; @EqualsAndHashCode @Getter @@ -64,17 +65,19 @@ public class MLAgent implements ToXContentObject, Writeable { private Boolean isHidden; @Builder(toBuilder = true) - public MLAgent(String name, - String type, - String description, - LLMSpec llm, - List tools, - Map parameters, - MLMemorySpec memory, - Instant createdTime, - Instant lastUpdateTime, - String appType, - Boolean isHidden) { + public MLAgent( + String name, + String type, + String description, + LLMSpec llm, + List tools, + Map parameters, + MLMemorySpec memory, + Instant createdTime, + Instant lastUpdateTime, + String appType, + Boolean isHidden + ) { this.name = name; this.type = type; this.description = description; @@ -124,7 +127,7 @@ private void validateMLAgentType(String agentType) { } } - public MLAgent(StreamInput input) throws IOException{ + public MLAgent(StreamInput input) throws IOException { Version streamInputVersion = input.getVersion(); name = input.readString(); type = input.readString(); @@ -135,7 +138,7 @@ public MLAgent(StreamInput input) throws IOException{ if (input.readBoolean()) { tools = new ArrayList<>(); int size = input.readInt(); - for (int i=0; i parameters; private boolean includeOutputInAgentResponse; - @Builder(toBuilder = true) - public MLToolSpec(String type, - String name, - String description, - Map parameters, - boolean includeOutputInAgentResponse) { + public MLToolSpec(String type, String name, String description, Map parameters, boolean includeOutputInAgentResponse) { if (type == null) { throw new IllegalArgumentException("tool type is null"); } @@ -52,7 +48,7 @@ public MLToolSpec(String type, this.includeOutputInAgentResponse = includeOutputInAgentResponse; } - public MLToolSpec(StreamInput input) throws IOException{ + public MLToolSpec(StreamInput input) throws IOException { type = input.readString(); name = input.readOptionalString(); description = input.readOptionalString(); @@ -128,13 +124,14 @@ public static MLToolSpec parse(XContentParser parser) throws IOException { break; } } - return MLToolSpec.builder() - .type(type) - .name(name) - .description(description) - .parameters(parameters) - .includeOutputInAgentResponse(includeOutputInAgentResponse) - .build(); + return MLToolSpec + .builder() + .type(type) + .name(name) + .description(description) + .parameters(parameters) + .includeOutputInAgentResponse(includeOutputInAgentResponse) + .build(); } public static MLToolSpec fromStream(StreamInput in) throws IOException { diff --git a/common/src/main/java/org/opensearch/ml/common/annotation/ExecuteInput.java b/common/src/main/java/org/opensearch/ml/common/annotation/ExecuteInput.java index 34a879aaae..a9874a286a 100644 --- a/common/src/main/java/org/opensearch/ml/common/annotation/ExecuteInput.java +++ b/common/src/main/java/org/opensearch/ml/common/annotation/ExecuteInput.java @@ -5,13 +5,13 @@ package org.opensearch.ml.common.annotation; -import org.opensearch.ml.common.FunctionName; - import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; +import org.opensearch.ml.common.FunctionName; + @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.TYPE) public @interface ExecuteInput { diff --git a/common/src/main/java/org/opensearch/ml/common/annotation/ExecuteOutput.java b/common/src/main/java/org/opensearch/ml/common/annotation/ExecuteOutput.java index e5a858f42d..42b2e1c1d0 100644 --- a/common/src/main/java/org/opensearch/ml/common/annotation/ExecuteOutput.java +++ b/common/src/main/java/org/opensearch/ml/common/annotation/ExecuteOutput.java @@ -5,13 +5,13 @@ package org.opensearch.ml.common.annotation; -import org.opensearch.ml.common.FunctionName; - import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; +import org.opensearch.ml.common.FunctionName; + @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.TYPE) public @interface ExecuteOutput { diff --git a/common/src/main/java/org/opensearch/ml/common/annotation/InputDataSet.java b/common/src/main/java/org/opensearch/ml/common/annotation/InputDataSet.java index 847e00ac36..93965886ff 100644 --- a/common/src/main/java/org/opensearch/ml/common/annotation/InputDataSet.java +++ b/common/src/main/java/org/opensearch/ml/common/annotation/InputDataSet.java @@ -5,13 +5,13 @@ package org.opensearch.ml.common.annotation; -import org.opensearch.ml.common.dataset.MLInputDataType; - import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; +import org.opensearch.ml.common.dataset.MLInputDataType; + @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.TYPE) public @interface InputDataSet { diff --git a/common/src/main/java/org/opensearch/ml/common/annotation/MLAlgoOutput.java b/common/src/main/java/org/opensearch/ml/common/annotation/MLAlgoOutput.java index df0afd7673..d24064be71 100644 --- a/common/src/main/java/org/opensearch/ml/common/annotation/MLAlgoOutput.java +++ b/common/src/main/java/org/opensearch/ml/common/annotation/MLAlgoOutput.java @@ -5,13 +5,13 @@ package org.opensearch.ml.common.annotation; -import org.opensearch.ml.common.output.MLOutputType; - import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; +import org.opensearch.ml.common.output.MLOutputType; + @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.TYPE) public @interface MLAlgoOutput { diff --git a/common/src/main/java/org/opensearch/ml/common/annotation/MLAlgoParameter.java b/common/src/main/java/org/opensearch/ml/common/annotation/MLAlgoParameter.java index 18136a78f6..eff313fc8a 100644 --- a/common/src/main/java/org/opensearch/ml/common/annotation/MLAlgoParameter.java +++ b/common/src/main/java/org/opensearch/ml/common/annotation/MLAlgoParameter.java @@ -5,13 +5,13 @@ package org.opensearch.ml.common.annotation; -import org.opensearch.ml.common.FunctionName; - import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; +import org.opensearch.ml.common.FunctionName; + @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.TYPE) public @interface MLAlgoParameter { diff --git a/common/src/main/java/org/opensearch/ml/common/annotation/MLInput.java b/common/src/main/java/org/opensearch/ml/common/annotation/MLInput.java index b8100473b0..31f520b181 100644 --- a/common/src/main/java/org/opensearch/ml/common/annotation/MLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/annotation/MLInput.java @@ -5,13 +5,13 @@ package org.opensearch.ml.common.annotation; -import org.opensearch.ml.common.FunctionName; - import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; +import org.opensearch.ml.common.FunctionName; + @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.TYPE) public @interface MLInput { diff --git a/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java index 90837425c4..aac9a1acad 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java @@ -5,8 +5,16 @@ package org.opensearch.ml.common.connector; -import lombok.Getter; -import lombok.Setter; +import static org.opensearch.ml.common.CommonValue.ML_MAP_RESPONSE_KEY; +import static org.opensearch.ml.common.utils.StringUtils.isJson; + +import java.io.IOException; +import java.time.Instant; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + import org.apache.commons.text.StringSubstitutor; import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.authuser.User; @@ -16,15 +24,8 @@ import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.utils.StringUtils; -import java.io.IOException; -import java.time.Instant; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; - -import static org.opensearch.ml.common.CommonValue.ML_MAP_RESPONSE_KEY; -import static org.opensearch.ml.common.utils.StringUtils.isJson; +import lombok.Getter; +import lombok.Setter; @Getter public abstract class AbstractConnector implements Connector { @@ -45,7 +46,6 @@ public abstract class AbstractConnector implements Connector { public static final String ACCESS_FIELD = "access"; public static final String CLIENT_CONFIG_FIELD = "client_config"; - protected String name; protected String description; protected String version; @@ -105,7 +105,7 @@ public void parseResponse(T response, List modelTensors, boolea } return; } - if (response instanceof String && isJson((String)response)) { + if (response instanceof String && isJson((String) response)) { Map data = StringUtils.fromJson((String) response, ML_MAP_RESPONSE_KEY); modelTensors.add(ModelTensor.builder().name("response").dataAsMap(data).build()); } else { diff --git a/common/src/main/java/org/opensearch/ml/common/connector/AwsConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/AwsConnector.java index 4052b45874..fb5badf2ea 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/AwsConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/AwsConnector.java @@ -5,22 +5,23 @@ package org.opensearch.ml.common.connector; -import lombok.Builder; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import lombok.extern.log4j.Log4j2; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.commons.authuser.User; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.AccessMode; +import static org.opensearch.ml.common.connector.ConnectorProtocols.AWS_SIGV4; import java.io.IOException; import java.util.List; import java.util.Map; import java.util.Optional; -import static org.opensearch.ml.common.connector.ConnectorProtocols.AWS_SIGV4; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.AccessMode; + +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; +import lombok.extern.log4j.Log4j2; @Log4j2 @NoArgsConstructor @@ -29,12 +30,32 @@ public class AwsConnector extends HttpConnector { @Builder(builderMethodName = "awsConnectorBuilder") - public AwsConnector(String name, String description, String version, String protocol, - Map parameters, Map credential, List actions, - List backendRoles, AccessMode accessMode, User owner, - ConnectorClientConfig connectorClientConfig) { - super(name, description, version, protocol, parameters, credential, actions, backendRoles, accessMode, - owner, connectorClientConfig); + public AwsConnector( + String name, + String description, + String version, + String protocol, + Map parameters, + Map credential, + List actions, + List backendRoles, + AccessMode accessMode, + User owner, + ConnectorClientConfig connectorClientConfig + ) { + super( + name, + description, + version, + protocol, + parameters, + credential, + actions, + backendRoles, + accessMode, + owner, + connectorClientConfig + ); validate(); } @@ -43,7 +64,6 @@ public AwsConnector(String protocol, XContentParser parser) throws IOException { validate(); } - public AwsConnector(StreamInput input) throws IOException { super(input); validate(); @@ -53,17 +73,19 @@ private void validate() { if (credential == null || !credential.containsKey(ACCESS_KEY_FIELD) || !credential.containsKey(SECRET_KEY_FIELD)) { throw new IllegalArgumentException("Missing credential"); } - if ((credential == null || !credential.containsKey(SERVICE_NAME_FIELD)) && (parameters == null || !parameters.containsKey(SERVICE_NAME_FIELD))) { + if ((credential == null || !credential.containsKey(SERVICE_NAME_FIELD)) + && (parameters == null || !parameters.containsKey(SERVICE_NAME_FIELD))) { throw new IllegalArgumentException("Missing service name"); } - if ((credential == null || !credential.containsKey(REGION_FIELD)) && (parameters == null || !parameters.containsKey(REGION_FIELD))) { + if ((credential == null || !credential.containsKey(REGION_FIELD)) + && (parameters == null || !parameters.containsKey(REGION_FIELD))) { throw new IllegalArgumentException("Missing region"); } } @Override public Connector cloneConnector() { - try (BytesStreamOutput bytesStreamOutput = new BytesStreamOutput()){ + try (BytesStreamOutput bytesStreamOutput = new BytesStreamOutput()) { this.writeTo(bytesStreamOutput); StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); return new AwsConnector(streamInput); diff --git a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java index 155f531948..c808f6628c 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java @@ -5,6 +5,8 @@ package org.opensearch.ml.common.connector; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.utils.StringUtils.gson; import java.io.IOException; import java.security.AccessController; @@ -16,13 +18,14 @@ import java.util.function.Function; import java.util.regex.Matcher; import java.util.regex.Pattern; + import org.apache.commons.text.StringSubstitutor; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; @@ -32,24 +35,27 @@ import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.utils.StringUtils.gson; - /** * Connector defines how to connect to a remote service. */ public interface Connector extends ToXContentObject, Writeable { String getName(); + String getProtocol(); + User getOwner(); + void setOwner(User user); AccessMode getAccess(); + void setAccess(AccessMode access); + List getBackendRoles(); void setBackendRoles(List backendRoles); + Map getParameters(); List getActions(); @@ -63,6 +69,7 @@ public interface Connector extends ToXContentObject, Writeable { T createPayload(String action, Map parameters); void decrypt(String action, Function function); + void encrypt(Function function); Connector cloneConnector(); @@ -95,7 +102,8 @@ default void validatePayload(String payload) { static Connector fromStream(StreamInput in) throws IOException { try { String connectorProtocol = in.readString(); - return MLCommonsClassLoader.initConnector(connectorProtocol, new Object[]{connectorProtocol, in}, String.class, StreamInput.class); + return MLCommonsClassLoader + .initConnector(connectorProtocol, new Object[] { connectorProtocol, in }, String.class, StreamInput.class); } catch (IllegalArgumentException illegalArgumentException) { throw illegalArgumentException; } @@ -119,25 +127,30 @@ static Connector createConnector(XContentParser parser) throws IOException { } catch (PrivilegedActionException e) { throw new IllegalArgumentException("wrong connector"); } - String connectorProtocol = (String)connectorMap.get("protocol"); + String connectorProtocol = (String) connectorMap.get("protocol"); return createConnector(jsonStr, connectorProtocol); } private static Connector createConnector(String jsonStr, String connectorProtocol) throws IOException { - try (XContentParser connectorParser = XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr)) { + try ( + XContentParser connectorParser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr) + ) { ensureExpectedToken(XContentParser.Token.START_OBJECT, connectorParser.nextToken(), connectorParser); if (connectorProtocol == null) { throw new IllegalArgumentException("connector protocol is null"); } - return MLCommonsClassLoader.initConnector(connectorProtocol, new Object[]{connectorProtocol, connectorParser}, String.class, XContentParser.class); + return MLCommonsClassLoader + .initConnector(connectorProtocol, new Object[] { connectorProtocol, connectorParser }, String.class, XContentParser.class); } catch (Exception ex) { if (ex instanceof IllegalArgumentException) { throw ex; } return null; - } + } } default void validateConnectorURL(List urlRegexes) { diff --git a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java index 9be290d126..93fb5cca57 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java @@ -5,16 +5,7 @@ package org.opensearch.ml.common.connector; -import lombok.Builder; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.common.io.stream.Writeable; -import org.opensearch.core.xcontent.ToXContentObject; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.FunctionName; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import java.io.IOException; import java.util.HashSet; @@ -22,7 +13,16 @@ import java.util.Map; import java.util.Set; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; @Getter @EqualsAndHashCode @@ -173,15 +173,16 @@ public static ConnectorAction parse(XContentParser parser) throws IOException { break; } } - return ConnectorAction.builder() - .actionType(actionType) - .method(method) - .url(url) - .headers(headers) - .requestBody(requestBody) - .preProcessFunction(preProcessFunction) - .postProcessFunction(postProcessFunction) - .build(); + return ConnectorAction + .builder() + .actionType(actionType) + .method(method) + .url(url) + .headers(headers) + .requestBody(requestBody) + .preProcessFunction(preProcessFunction) + .postProcessFunction(postProcessFunction) + .build(); } public enum ActionType { @@ -197,10 +198,7 @@ public static ActionType from(String value) { } } - private static final HashSet MODEL_SUPPORT_ACTIONS = new HashSet<>(Set.of( - PREDICT, - BATCH_PREDICT - )); + private static final HashSet MODEL_SUPPORT_ACTIONS = new HashSet<>(Set.of(PREDICT, BATCH_PREDICT)); public static boolean isValidActionInModelPrediction(ActionType actionType) { return MODEL_SUPPORT_ACTIONS.contains(actionType); diff --git a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorClientConfig.java b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorClientConfig.java index 2621b1adb9..4d617ce896 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorClientConfig.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorClientConfig.java @@ -5,9 +5,12 @@ package org.opensearch.ml.common.connector; -import lombok.Builder; -import lombok.EqualsAndHashCode; -import lombok.Getter; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.Locale; +import java.util.Objects; + import org.opensearch.Version; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -16,12 +19,9 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import java.io.IOException; -import java.util.List; -import java.util.Locale; -import java.util.Objects; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; @Getter @EqualsAndHashCode @@ -75,7 +75,7 @@ public ConnectorClientConfig(StreamInput input) throws IOException { this.maxConnections = input.readOptionalInt(); this.connectionTimeout = input.readOptionalInt(); this.readTimeout = input.readOptionalInt(); - if(streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_RETRY)) { + if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_RETRY)) { this.retryBackoffMillis = input.readOptionalInt(); this.retryTimeoutSeconds = input.readOptionalInt(); this.maxRetryTimes = input.readOptionalInt(); @@ -101,7 +101,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalInt(maxConnections); out.writeOptionalInt(connectionTimeout); out.writeOptionalInt(readTimeout); - if(streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_RETRY)){ + if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_RETRY)) { out.writeOptionalInt(retryBackoffMillis); out.writeOptionalInt(retryTimeoutSeconds); out.writeOptionalInt(maxRetryTimes); @@ -187,14 +187,15 @@ public static ConnectorClientConfig parse(XContentParser parser) throws IOExcept break; } } - return ConnectorClientConfig.builder() - .maxConnections(maxConnections) - .connectionTimeout(connectionTimeout) - .readTimeout(readTimeout) - .retryBackoffMillis(retryBackoffMillis) - .retryTimeoutSeconds(retryTimeoutSeconds) - .maxRetryTimes(maxRetryTimes) - .retryBackoffPolicy(retryBackoffPolicy) - .build(); + return ConnectorClientConfig + .builder() + .maxConnections(maxConnections) + .connectionTimeout(connectionTimeout) + .readTimeout(readTimeout) + .retryBackoffMillis(retryBackoffMillis) + .retryTimeoutSeconds(retryTimeoutSeconds) + .maxRetryTimes(maxRetryTimes) + .retryBackoffPolicy(retryBackoffPolicy) + .build(); } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorProtocols.java b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorProtocols.java index 50412ce09a..408e4ea7c4 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorProtocols.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorProtocols.java @@ -7,7 +7,6 @@ import java.util.Arrays; import java.util.List; -import java.util.Set; public class ConnectorProtocols { @@ -18,10 +17,14 @@ public class ConnectorProtocols { public static void validateProtocol(String protocol) { if (protocol == null) { - throw new IllegalArgumentException("Connector protocol is null. Please use one of " + Arrays.toString(VALID_PROTOCOLS.toArray(new String[0]))); + throw new IllegalArgumentException( + "Connector protocol is null. Please use one of " + Arrays.toString(VALID_PROTOCOLS.toArray(new String[0])) + ); } if (!VALID_PROTOCOLS.contains(protocol)) { - throw new IllegalArgumentException("Unsupported connector protocol. Please use one of " + Arrays.toString(VALID_PROTOCOLS.toArray(new String[0]))); + throw new IllegalArgumentException( + "Unsupported connector protocol. Please use one of " + Arrays.toString(VALID_PROTOCOLS.toArray(new String[0])) + ); } } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java index 7ac05a842d..56134d1c43 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java @@ -5,18 +5,11 @@ package org.opensearch.ml.common.connector; -import lombok.Builder; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import lombok.extern.log4j.Log4j2; -import org.apache.commons.text.StringSubstitutor; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.commons.authuser.User; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.AccessMode; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.connector.ConnectorProtocols.HTTP; +import static org.opensearch.ml.common.connector.ConnectorProtocols.validateProtocol; +import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; +import static org.opensearch.ml.common.utils.StringUtils.isJson; import java.io.IOException; import java.time.Instant; @@ -29,14 +22,21 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; -import static org.opensearch.ml.common.connector.ConnectorProtocols.HTTP; -import static org.opensearch.ml.common.connector.ConnectorProtocols.validateProtocol; -import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; -import static org.opensearch.ml.common.utils.StringUtils.isJson; +import org.apache.commons.text.StringSubstitutor; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; +import lombok.extern.log4j.Log4j2; + @Log4j2 @NoArgsConstructor @EqualsAndHashCode @@ -48,13 +48,22 @@ public class HttpConnector extends AbstractConnector { public static final String SERVICE_NAME_FIELD = "service_name"; public static final String REGION_FIELD = "region"; - //TODO: add RequestConfig like request time out, + // TODO: add RequestConfig like request time out, @Builder - public HttpConnector(String name, String description, String version, String protocol, - Map parameters, Map credential, List actions, - List backendRoles, AccessMode accessMode, User owner, - ConnectorClientConfig connectorClientConfig) { + public HttpConnector( + String name, + String description, + String version, + String protocol, + Map parameters, + Map credential, + List actions, + List backendRoles, + AccessMode accessMode, + User owner, + ConnectorClientConfig connectorClientConfig + ) { validateProtocol(protocol); this.name = name; this.description = description; @@ -308,7 +317,7 @@ public void update(MLCreateConnectorInput updateContent, Function T createPayload(String action, Map parameters) { + public T createPayload(String action, Map parameters) { Optional connectorAction = findAction(action); if (connectorAction.isPresent() && connectorAction.get().getRequestBody() != null) { String payload = connectorAction.get().getRequestBody(); diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java index 4fb3f75412..5ba465b15a 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java @@ -5,16 +5,16 @@ package org.opensearch.ml.common.connector; -import org.opensearch.ml.common.connector.functions.postprocess.BedrockEmbeddingPostProcessFunction; -import org.opensearch.ml.common.connector.functions.postprocess.CohereRerankPostProcessFunction; -import org.opensearch.ml.common.connector.functions.postprocess.EmbeddingPostProcessFunction; -import org.opensearch.ml.common.output.model.ModelTensor; - import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.function.Function; +import org.opensearch.ml.common.connector.functions.postprocess.BedrockEmbeddingPostProcessFunction; +import org.opensearch.ml.common.connector.functions.postprocess.CohereRerankPostProcessFunction; +import org.opensearch.ml.common.connector.functions.postprocess.EmbeddingPostProcessFunction; +import org.opensearch.ml.common.output.model.ModelTensor; + public class MLPostProcessFunction { public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding"; @@ -57,4 +57,4 @@ public static Function> get(String postProcessFunction public static boolean contains(String postProcessFunction) { return POST_PROCESS_FUNCTIONS.containsKey(postProcessFunction); } -} \ No newline at end of file +} diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java index 27c8e6dd93..3a5a3427a8 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java @@ -5,6 +5,10 @@ package org.opensearch.ml.common.connector; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; + import org.opensearch.ml.common.connector.functions.preprocess.BedrockEmbeddingPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.CohereEmbeddingPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction; @@ -13,10 +17,6 @@ import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; -import java.util.HashMap; -import java.util.Map; -import java.util.function.Function; - public class MLPreProcessFunction { private static final Map> PRE_PROCESS_FUNCTIONS = new HashMap<>(); diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockEmbeddingPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockEmbeddingPostProcessFunction.java index eb55253c01..82823187e8 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockEmbeddingPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockEmbeddingPostProcessFunction.java @@ -5,12 +5,12 @@ package org.opensearch.ml.common.connector.functions.postprocess; -import org.opensearch.ml.common.output.model.MLResultDataType; -import org.opensearch.ml.common.output.model.ModelTensor; - import java.util.ArrayList; import java.util.List; +import org.opensearch.ml.common.output.model.MLResultDataType; +import org.opensearch.ml.common.output.model.ModelTensor; + public class BedrockEmbeddingPostProcessFunction extends ConnectorPostProcessFunction> { @Override @@ -21,7 +21,7 @@ public void validate(Object input) { List outerList = (List) input; - if (!outerList.isEmpty() && !(((List)input).get(0) instanceof Number)) { + if (!outerList.isEmpty() && !(((List) input).get(0) instanceof Number)) { throw new IllegalArgumentException("The embedding should be a non-empty List containing Float values."); } } @@ -29,14 +29,16 @@ public void validate(Object input) { @Override public List process(List embedding) { List modelTensors = new ArrayList<>(); - modelTensors.add( + modelTensors + .add( ModelTensor - .builder() - .name("sentence_embedding") - .dataType(MLResultDataType.FLOAT32) - .shape(new long[]{embedding.size()}) - .data(embedding.toArray(new Number[0])) - .build()); + .builder() + .name("sentence_embedding") + .dataType(MLResultDataType.FLOAT32) + .shape(new long[] { embedding.size() }) + .data(embedding.toArray(new Number[0])) + .build() + ); return modelTensors; } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/CohereRerankPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/CohereRerankPostProcessFunction.java index 216fcc9d0a..cf93202366 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/CohereRerankPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/CohereRerankPostProcessFunction.java @@ -5,13 +5,13 @@ package org.opensearch.ml.common.connector.functions.postprocess; -import org.opensearch.ml.common.output.model.MLResultDataType; -import org.opensearch.ml.common.output.model.ModelTensor; - import java.util.ArrayList; import java.util.List; import java.util.Map; +import org.opensearch.ml.common.output.model.MLResultDataType; +import org.opensearch.ml.common.output.model.ModelTensor; + public class CohereRerankPostProcessFunction extends ConnectorPostProcessFunction>> { @Override @@ -44,12 +44,16 @@ public List process(List> rerankResults) { } for (int i = 0; i < scores.length; i++) { - modelTensors.add(ModelTensor.builder() - .name("similarity") - .shape(new long[]{1}) - .data(new Number[]{scores[i]}) - .dataType(MLResultDataType.FLOAT32) - .build()); + modelTensors + .add( + ModelTensor + .builder() + .name("similarity") + .shape(new long[] { 1 }) + .data(new Number[] { scores[i] }) + .dataType(MLResultDataType.FLOAT32) + .build() + ); } } return modelTensors; diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/ConnectorPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/ConnectorPostProcessFunction.java index 9cb81099c4..a5374a42bb 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/ConnectorPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/ConnectorPostProcessFunction.java @@ -5,11 +5,11 @@ package org.opensearch.ml.common.connector.functions.postprocess; -import org.opensearch.ml.common.output.model.ModelTensor; - import java.util.List; import java.util.function.Function; +import org.opensearch.ml.common.output.model.ModelTensor; + public abstract class ConnectorPostProcessFunction implements Function> { @Override @@ -18,7 +18,7 @@ public List apply(Object input) { throw new IllegalArgumentException("Can't run post process function as model output is null"); } validate(input); - return process((T)input); + return process((T) input); } public abstract void validate(Object input); diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunction.java index 6e6d373302..e3142b8368 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunction.java @@ -5,12 +5,12 @@ package org.opensearch.ml.common.connector.functions.postprocess; -import org.opensearch.ml.common.output.model.MLResultDataType; -import org.opensearch.ml.common.output.model.ModelTensor; - import java.util.ArrayList; import java.util.List; +import org.opensearch.ml.common.output.model.MLResultDataType; +import org.opensearch.ml.common.output.model.ModelTensor; + public class EmbeddingPostProcessFunction extends ConnectorPostProcessFunction>> { @Override @@ -36,15 +36,19 @@ public void validate(Object input) { @Override public List process(List> embeddings) { List modelTensors = new ArrayList<>(); - embeddings.forEach(embedding -> modelTensors.add( - ModelTensor - .builder() - .name("sentence_embedding") - .dataType(MLResultDataType.FLOAT32) - .shape(new long[]{embedding.size()}) - .data(embedding.toArray(new Number[0])) - .build() - )); + embeddings + .forEach( + embedding -> modelTensors + .add( + ModelTensor + .builder() + .name("sentence_embedding") + .dataType(MLResultDataType.FLOAT32) + .shape(new long[] { embedding.size() }) + .data(embedding.toArray(new Number[0])) + .build() + ) + ); return modelTensors; } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java index 7ca22c3cdc..b6a95be042 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java @@ -5,14 +5,13 @@ package org.opensearch.ml.common.connector.functions.preprocess; -import org.opensearch.ml.common.dataset.TextDocsInputDataSet; -import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; -import org.opensearch.ml.common.input.MLInput; +import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; import java.util.Map; -import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; - +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; public class BedrockEmbeddingPreProcessFunction extends ConnectorPreProcessFunction { diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/CohereEmbeddingPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/CohereEmbeddingPreProcessFunction.java index 0b66be089d..7ef9845b94 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/CohereEmbeddingPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/CohereEmbeddingPreProcessFunction.java @@ -5,14 +5,13 @@ package org.opensearch.ml.common.connector.functions.preprocess; -import org.opensearch.ml.common.dataset.TextDocsInputDataSet; -import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; -import org.opensearch.ml.common.input.MLInput; +import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; import java.util.Map; -import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; - +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; public class CohereEmbeddingPreProcessFunction extends ConnectorPreProcessFunction { diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/CohereRerankPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/CohereRerankPreProcessFunction.java index c975f7f329..823cfbf2a1 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/CohereRerankPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/CohereRerankPreProcessFunction.java @@ -5,14 +5,13 @@ package org.opensearch.ml.common.connector.functions.preprocess; -import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; -import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; -import org.opensearch.ml.common.input.MLInput; +import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; import java.util.Map; -import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; - +import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; public class CohereRerankPreProcessFunction extends ConnectorPreProcessFunction { @@ -30,11 +29,11 @@ public void validate(MLInput mlInput) { @Override public RemoteInferenceInputDataSet process(MLInput mlInput) { TextSimilarityInputDataSet inputData = (TextSimilarityInputDataSet) mlInput.getInputDataset(); - Map processedResult = Map.of("parameters", Map.of( - "query", inputData.getQueryText(), - "documents", inputData.getTextDocs(), - "top_n", inputData.getTextDocs().size() - )); + Map processedResult = Map + .of( + "parameters", + Map.of("query", inputData.getQueryText(), "documents", inputData.getTextDocs(), "top_n", inputData.getTextDocs().size()) + ); return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build(); } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java index c4c88532ef..387ac27467 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java @@ -5,7 +5,13 @@ package org.opensearch.ml.common.connector.functions.preprocess; -import lombok.extern.log4j.Log4j2; +import static org.opensearch.ml.common.utils.StringUtils.addDefaultMethod; + +import java.util.Collections; +import java.util.Locale; +import java.util.Map; +import java.util.function.Function; + import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; @@ -14,13 +20,7 @@ import org.opensearch.script.ScriptType; import org.opensearch.script.TemplateScript; -import java.util.Collections; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.function.Function; - -import static org.opensearch.ml.common.utils.StringUtils.addDefaultMethod; +import lombok.extern.log4j.Log4j2; /** * This abstract class represents a pre-processing function for a connector. @@ -50,7 +50,7 @@ public RemoteInferenceInputDataSet apply(MLInput mlInput) { throw new IllegalArgumentException("Preprocess function input can't be null"); } if (returnDirectlyForRemoteInferenceInput && mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) { - return (RemoteInferenceInputDataSet)mlInput.getInputDataset(); + return (RemoteInferenceInputDataSet) mlInput.getInputDataset(); } else { validate(mlInput); return process(mlInput); @@ -70,8 +70,18 @@ public RemoteInferenceInputDataSet apply(MLInput mlInput) { */ public void validateTextDocsInput(MLInput mlInput) { if (!(mlInput.getInputDataset() instanceof TextDocsInputDataSet)) { - log.error(String.format(Locale.ROOT, "This pre_process_function can only support TextDocsInputDataSet, actual input type is: %s", mlInput.getInputDataset().getClass().getName())); - throw new IllegalArgumentException("This pre_process_function can only support TextDocsInputDataSet which including a list of string with key 'text_docs'"); + log + .error( + String + .format( + Locale.ROOT, + "This pre_process_function can only support TextDocsInputDataSet, actual input type is: %s", + mlInput.getInputDataset().getClass().getName() + ) + ); + throw new IllegalArgumentException( + "This pre_process_function can only support TextDocsInputDataSet which including a list of string with key 'text_docs'" + ); } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunction.java index fac2b5bc94..88bd5be1b5 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunction.java @@ -5,21 +5,22 @@ package org.opensearch.ml.common.connector.functions.preprocess; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.experimental.FieldDefaults; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; +import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.io.IOException; +import java.util.Map; + import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.script.ScriptService; -import java.io.IOException; -import java.util.Map; - -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; -import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; -import static org.opensearch.ml.common.utils.StringUtils.gson; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.experimental.FieldDefaults; @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) public class DefaultPreProcessFunction extends ConnectorPreProcessFunction { @@ -37,8 +38,7 @@ public DefaultPreProcessFunction(ScriptService scriptService, String preProcessF } @Override - public void validate(MLInput mlInput) { - } + public void validate(MLInput mlInput) {} @Override public RemoteInferenceInputDataSet process(MLInput mlInput) { diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java index 008b1efe58..d7acc1a70b 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java @@ -7,15 +7,15 @@ package org.opensearch.ml.common.connector.functions.preprocess; -import org.opensearch.ml.common.dataset.TextDocsInputDataSet; -import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; -import org.opensearch.ml.common.input.MLInput; +import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; import java.util.HashMap; import java.util.List; import java.util.Map; -import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; /** * This class provides a pre-processing function for multi-modal input data. @@ -53,7 +53,10 @@ public RemoteInferenceInputDataSet process(MLInput mlInput) { if (inputData.getDocs().size() > 1) { parametersMap.put("inputImage", inputData.getDocs().get(1)); } - return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(Map.of("parameters", parametersMap))).build(); - + return RemoteInferenceInputDataSet + .builder() + .parameters(convertScriptStringToJsonString(Map.of("parameters", parametersMap))) + .build(); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/OpenAIEmbeddingPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/OpenAIEmbeddingPreProcessFunction.java index 83f7ebd74d..41a9f3fcfd 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/OpenAIEmbeddingPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/OpenAIEmbeddingPreProcessFunction.java @@ -5,14 +5,13 @@ package org.opensearch.ml.common.connector.functions.preprocess; -import org.opensearch.ml.common.dataset.TextDocsInputDataSet; -import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; -import org.opensearch.ml.common.input.MLInput; +import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; import java.util.Map; -import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; - +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; public class OpenAIEmbeddingPreProcessFunction extends ConnectorPreProcessFunction { diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunction.java index 882c1409f6..3921bb6818 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunction.java @@ -5,19 +5,20 @@ package org.opensearch.ml.common.connector.functions.preprocess; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.experimental.FieldDefaults; -import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; -import org.opensearch.ml.common.input.MLInput; -import org.opensearch.script.ScriptService; +import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; +import static org.opensearch.ml.common.utils.StringUtils.gson; +import static org.opensearch.ml.common.utils.StringUtils.isJson; import java.util.HashMap; import java.util.Map; -import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; -import static org.opensearch.ml.common.utils.StringUtils.gson; -import static org.opensearch.ml.common.utils.StringUtils.isJson; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.script.ScriptService; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.experimental.FieldDefaults; @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) public class RemoteInferencePreProcessFunction extends ConnectorPreProcessFunction { @@ -47,8 +48,8 @@ public void validate(MLInput mlInput) { public RemoteInferenceInputDataSet process(MLInput mlInput) { Map inputParams = new HashMap<>(); Map parameters = ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getParameters(); - if (params.containsKey(CONVERT_REMOTE_INFERENCE_PARAM_TO_OBJECT) && - Boolean.parseBoolean(params.get(CONVERT_REMOTE_INFERENCE_PARAM_TO_OBJECT))) { + if (params.containsKey(CONVERT_REMOTE_INFERENCE_PARAM_TO_OBJECT) + && Boolean.parseBoolean(params.get(CONVERT_REMOTE_INFERENCE_PARAM_TO_OBJECT))) { for (String key : parameters.keySet()) { if (isJson(parameters.get(key))) { inputParams.put(key, gson.fromJson(parameters.get(key), Object.class)); diff --git a/common/src/main/java/org/opensearch/ml/common/controller/MLController.java b/common/src/main/java/org/opensearch/ml/common/controller/MLController.java index f356cad477..a8f091af9e 100644 --- a/common/src/main/java/org/opensearch/ml/common/controller/MLController.java +++ b/common/src/main/java/org/opensearch/ml/common/controller/MLController.java @@ -5,9 +5,13 @@ package org.opensearch.ml.common.controller; -import lombok.Builder; -import lombok.Data; -import lombok.Getter; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.common.io.stream.StreamInput; @@ -19,14 +23,9 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import java.io.IOException; -import java.util.HashMap; -import java.util.Iterator; -import java.util.Map; -import java.util.Objects; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; +import lombok.Builder; +import lombok.Data; +import lombok.Getter; @Data public class MLController implements ToXContentObject, Writeable { @@ -63,8 +62,9 @@ public static MLController parse(XContentParser parser) throws IOException { Map userRateLimiterStringMap = getParameterMap(parser.map()); userRateLimiterStringMap.forEach((user, rateLimiterString) -> { try { - XContentParser rateLimiterParser = XContentType.JSON.xContent().createParser( - NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, rateLimiterString); + XContentParser rateLimiterParser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, rateLimiterString); rateLimiterParser.nextToken(); MLRateLimiter rateLimiter = MLRateLimiter.parse(rateLimiterParser); if (!rateLimiter.isEmpty()) { @@ -96,8 +96,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(modelId); if (userRateLimiter != null) { out.writeBoolean(true); - out.writeMap(userRateLimiter, StreamOutput::writeString, - (streamOutput, rateLimiter) -> rateLimiter.writeTo(streamOutput)); + out.writeMap(userRateLimiter, StreamOutput::writeString, (streamOutput, rateLimiter) -> rateLimiter.writeTo(streamOutput)); } else { out.writeBoolean(false); } @@ -121,8 +120,7 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par * @return True if a deployment is required, false otherwise. */ public boolean isDeployRequiredAfterUpdate(MLController updateContent) { - if (updateContent != null && updateContent.getUserRateLimiter() != null - && !updateContent.getUserRateLimiter().isEmpty()) { + if (updateContent != null && updateContent.getUserRateLimiter() != null && !updateContent.getUserRateLimiter().isEmpty()) { Map updateUserRateLimiter = updateContent.getUserRateLimiter(); for (Map.Entry entry : updateUserRateLimiter.entrySet()) { String newUser = entry.getKey(); diff --git a/common/src/main/java/org/opensearch/ml/common/controller/MLRateLimiter.java b/common/src/main/java/org/opensearch/ml/common/controller/MLRateLimiter.java index 5b5bf9d713..182281e9e5 100644 --- a/common/src/main/java/org/opensearch/ml/common/controller/MLRateLimiter.java +++ b/common/src/main/java/org/opensearch/ml/common/controller/MLRateLimiter.java @@ -5,23 +5,24 @@ package org.opensearch.ml.common.controller; -import lombok.Builder; -import lombok.Getter; -import lombok.Setter; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.Objects; +import java.util.concurrent.TimeUnit; + import org.opensearch.OpenSearchParseException; -import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import java.io.IOException; -import java.util.Objects; -import java.util.concurrent.TimeUnit; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; @Setter @Getter @@ -137,10 +138,8 @@ public static boolean updateValidityPreCheck(MLRateLimiter rateLimiter, MLRateLi } else if (updateContent.isEmpty()) { return false; } else - return (!Objects.equals(updateContent.getLimit(), rateLimiter.getLimit()) - && updateContent.getLimit() != null) - || (!Objects.equals(updateContent.getUnit(), rateLimiter.getUnit()) - && updateContent.getUnit() != null); + return (!Objects.equals(updateContent.getLimit(), rateLimiter.getLimit()) && updateContent.getLimit() != null) + || (!Objects.equals(updateContent.getUnit(), rateLimiter.getUnit()) && updateContent.getUnit() != null); } /** @@ -156,8 +155,8 @@ public static boolean isDeployRequiredAfterUpdate(MLRateLimiter rateLimiter, MLR return false; } else { return updateContent.isValid() - || (rateLimiter.getUnit() != null && updateContent.getLimit() != null) - || (rateLimiter.getLimit() != null && updateContent.getUnit() != null); + || (rateLimiter.getUnit() != null && updateContent.getLimit() != null) + || (rateLimiter.getLimit() != null && updateContent.getUnit() != null); } } diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java b/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java index 783e9c815a..57f540d27b 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java @@ -65,9 +65,9 @@ public class ActionConstants { /** path for create conversation */ public final static String CREATE_CONVERSATION_REST_PATH = BASE_REST_PATH; /** path for get conversations */ - public final static String GET_CONVERSATIONS_REST_PATH = BASE_REST_PATH; + public final static String GET_CONVERSATIONS_REST_PATH = BASE_REST_PATH; /** path for update conversations */ - public final static String UPDATE_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/{memory_id}"; + public final static String UPDATE_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/{memory_id}"; /** path for create interaction */ public final static String CREATE_INTERACTION_REST_PATH = BASE_REST_PATH + "/{memory_id}/messages"; /** path for get interactions */ @@ -92,4 +92,4 @@ public class ActionConstants { /** default username for reporting security errors if no or malformed username */ public final static String DEFAULT_USERNAME_FOR_ERRORS = "BAD_USER"; -} \ No newline at end of file +} diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/ConversationMeta.java b/common/src/main/java/org/opensearch/ml/common/conversation/ConversationMeta.java index 9cc3b49bc4..21ed608654 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/ConversationMeta.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/ConversationMeta.java @@ -34,8 +34,6 @@ import lombok.AllArgsConstructor; import lombok.Getter; -import static org.opensearch.ml.common.CommonValue.VERSION_2_17_0; - /** * Class for holding conversational metadata */ @@ -76,7 +74,7 @@ public static ConversationMeta fromMap(String id, Map docFields) Instant updated = Instant.parse((String) docFields.get(ConversationalIndexConstants.META_UPDATED_TIME_FIELD)); String name = (String) docFields.get(ConversationalIndexConstants.META_NAME_FIELD); String user = (String) docFields.get(ConversationalIndexConstants.USER_FIELD); - Map additionalInfos = (Map)docFields.get(ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD); + Map additionalInfos = (Map) docFields.get(ConversationalIndexConstants.META_ADDITIONAL_INFO_FIELD); return new ConversationMeta(id, created, updated, name, user, additionalInfos); } @@ -109,7 +107,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeInstant(updatedTime); out.writeString(name); out.writeOptionalString(user); - if(out.getVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_ADDITIONAL_INFO)) { + if (out.getVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_ADDITIONAL_INFO)) { if (additionalInfos == null) { out.writeBoolean(false); } else { @@ -121,11 +119,16 @@ public void writeTo(StreamOutput out) throws IOException { @Override public String toString() { - return "{id=" + id - + ", name=" + name - + ", created=" + createdTime.toString() - + ", updated=" + updatedTime.toString() - + ", user=" + user + return "{id=" + + id + + ", name=" + + name + + ", created=" + + createdTime.toString() + + ", updated=" + + updatedTime.toString() + + ", user=" + + user + "}"; } @@ -136,7 +139,7 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Para builder.field(ConversationalIndexConstants.META_CREATED_TIME_FIELD, this.createdTime); builder.field(ConversationalIndexConstants.META_UPDATED_TIME_FIELD, this.updatedTime); builder.field(ConversationalIndexConstants.META_NAME_FIELD, this.name); - if(this.user != null) { + if (this.user != null) { builder.field(ConversationalIndexConstants.USER_FIELD, this.user); } if (this.additionalInfos != null) { @@ -148,15 +151,15 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Para @Override public boolean equals(Object other) { - if(!(other instanceof ConversationMeta)) { + if (!(other instanceof ConversationMeta)) { return false; } ConversationMeta otherConversation = (ConversationMeta) other; - return Objects.equals(this.id, otherConversation.id) && - Objects.equals(this.user, otherConversation.user) && - Objects.equals(this.createdTime, otherConversation.createdTime) && - Objects.equals(this.updatedTime, otherConversation.updatedTime) && - Objects.equals(this.name, otherConversation.name); + return Objects.equals(this.id, otherConversation.id) + && Objects.equals(this.user, otherConversation.user) + && Objects.equals(this.createdTime, otherConversation.createdTime) + && Objects.equals(this.updatedTime, otherConversation.updatedTime) + && Objects.equals(this.name, otherConversation.name); } - + } diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java b/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java index b542864726..ac639babb2 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java @@ -43,7 +43,9 @@ public class ConversationalIndexConstants { /** Mappings for the conversational metadata index */ public final static String META_MAPPING = "{\n" + " \"_meta\": {\n" - + " \"schema_version\": " + META_INDEX_SCHEMA_VERSION + "\n" + + " \"schema_version\": " + + META_INDEX_SCHEMA_VERSION + + "\n" + " },\n" + " \"properties\": {\n" + " \"" @@ -92,7 +94,9 @@ public class ConversationalIndexConstants { /** Mappings for the interactions index */ public final static String INTERACTIONS_MAPPINGS = "{\n" + " \"_meta\": {\n" - + " \"schema_version\": " + INTERACTIONS_INDEX_SCHEMA_VERSION + "\n" + + " \"schema_version\": " + + INTERACTIONS_INDEX_SCHEMA_VERSION + + "\n" + " },\n" + " \"properties\": {\n" + " \"" @@ -129,5 +133,7 @@ public class ConversationalIndexConstants { public static final Setting ML_COMMONS_MEMORY_FEATURE_ENABLED = Setting .boolSetting("plugins.ml_commons.memory_feature_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic); - public static final String ML_COMMONS_MEMORY_FEATURE_DISABLED_MESSAGE = "The Conversation Memory feature is not enabled. To enable, please update the setting " + ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(); + public static final String ML_COMMONS_MEMORY_FEATURE_DISABLED_MESSAGE = + "The Conversation Memory feature is not enabled. To enable, please update the setting " + + ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(); } diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java b/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java index 8e06569672..93afbb52a3 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java @@ -62,7 +62,16 @@ public class Interaction implements Writeable, ToXContentObject { private Integer traceNum; @Builder(toBuilder = true) - public Interaction(String id, Instant createTime, String conversationId, String input, String promptTemplate, String response, String origin, Map additionalInfo) { + public Interaction( + String id, + Instant createTime, + String conversationId, + String input, + String promptTemplate, + String response, + String origin, + Map additionalInfo + ) { this.id = id; this.createTime = createTime; this.conversationId = conversationId; @@ -83,15 +92,27 @@ public Interaction(String id, Instant createTime, String conversationId, String */ public static Interaction fromMap(String id, Map fields) { Instant createTime = Instant.parse((String) fields.get(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD)); - String conversationId = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD); - String input = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD); + String conversationId = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD); + String input = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD); String promptTemplate = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD); - String response = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD); - String origin = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD); - Map additionalInfo = (Map) fields.get(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD); + String response = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD); + String origin = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD); + Map additionalInfo = (Map) fields + .get(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD); String parentInteractionId = (String) fields.getOrDefault(ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, null); Integer traceNum = (Integer) fields.getOrDefault(ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, null); - return new Interaction(id, createTime, conversationId, input, promptTemplate, response, origin, additionalInfo, parentInteractionId, traceNum); + return new Interaction( + id, + createTime, + conversationId, + input, + promptTemplate, + response, + origin, + additionalInfo, + parentInteractionId, + traceNum + ); } /** @@ -124,10 +145,20 @@ public static Interaction fromStream(StreamInput in) throws IOException { } String parentInteractionId = in.readOptionalString(); Integer traceNum = in.readOptionalInt(); - return new Interaction(id, createTime, conversationId, input, promptTemplate, response, origin, additionalInfo, parentInteractionId, traceNum); + return new Interaction( + id, + createTime, + conversationId, + input, + promptTemplate, + response, + origin, + additionalInfo, + parentInteractionId, + traceNum + ); } - @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(id); @@ -157,7 +188,7 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Para builder.field(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, promptTemplate); builder.field(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, response); builder.field(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, origin); - if(additionalInfo != null) { + if (additionalInfo != null) { builder.field(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, additionalInfo); } if (parentInteractionId != null) { @@ -172,21 +203,19 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Para @Override public boolean equals(Object other) { - return ( - other instanceof Interaction && - ((Interaction) other).id.equals(this.id) && - ((Interaction) other).conversationId.equals(this.conversationId) && - ((Interaction) other).createTime.equals(this.createTime) && - ((Interaction) other).input.equals(this.input) && - ((Interaction) other).promptTemplate.equals(this.promptTemplate) && - ((Interaction) other).response.equals(this.response) && - ((Interaction) other).origin.equals(this.origin) && - ( (((Interaction) other).additionalInfo == null && this.additionalInfo == null) || - ((Interaction) other).additionalInfo.equals(this.additionalInfo)) && - ( (((Interaction) other).parentInteractionId == null && this.parentInteractionId == null) || - ((Interaction) other).parentInteractionId.equals(this.parentInteractionId)) && - ( (((Interaction) other).traceNum == null && this.traceNum == null) || - ((Interaction) other).traceNum.equals(this.traceNum)) + return (other instanceof Interaction + && ((Interaction) other).id.equals(this.id) + && ((Interaction) other).conversationId.equals(this.conversationId) + && ((Interaction) other).createTime.equals(this.createTime) + && ((Interaction) other).input.equals(this.input) + && ((Interaction) other).promptTemplate.equals(this.promptTemplate) + && ((Interaction) other).response.equals(this.response) + && ((Interaction) other).origin.equals(this.origin) + && ((((Interaction) other).additionalInfo == null && this.additionalInfo == null) + || ((Interaction) other).additionalInfo.equals(this.additionalInfo)) + && ((((Interaction) other).parentInteractionId == null && this.parentInteractionId == null) + || ((Interaction) other).parentInteractionId.equals(this.parentInteractionId)) + && ((((Interaction) other).traceNum == null && this.traceNum == null) || ((Interaction) other).traceNum.equals(this.traceNum)) ); } @@ -194,17 +223,27 @@ public boolean equals(Object other) { @Override public String toString() { return "Interaction{" - + "id=" + id - + ",cid=" + conversationId - + ",create_time=" + createTime - + ",origin=" + origin - + ",input=" + input - + ",promt_template=" + promptTemplate - + ",response=" + response - + ",additional_info=" + additionalInfo - + ",parentInteractionId=" + parentInteractionId - + ",traceNum=" + traceNum + + "id=" + + id + + ",cid=" + + conversationId + + ",create_time=" + + createTime + + ",origin=" + + origin + + ",input=" + + input + + ",promt_template=" + + promptTemplate + + ",response=" + + response + + ",additional_info=" + + additionalInfo + + ",parentInteractionId=" + + parentInteractionId + + ",traceNum=" + + traceNum + "}"; } -} \ No newline at end of file +} diff --git a/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnMeta.java b/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnMeta.java index cfbb6484cb..4ca89ae2fa 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnMeta.java +++ b/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnMeta.java @@ -5,12 +5,17 @@ package org.opensearch.ml.common.dataframe; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + import java.io.IOException; import java.util.Locale; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; import lombok.AccessLevel; import lombok.Builder; @@ -18,11 +23,6 @@ import lombok.RequiredArgsConstructor; import lombok.ToString; import lombok.experimental.FieldDefaults; -import org.opensearch.core.xcontent.ToXContentObject; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @Getter diff --git a/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnType.java b/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnType.java index 28fe550cfe..1b15f3c7bf 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnType.java +++ b/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnType.java @@ -16,31 +16,31 @@ public enum ColumnType { NULL; public static ColumnType from(Object object) { - if(object instanceof Short) { + if (object instanceof Short) { return SHORT; } - if(object instanceof Integer) { + if (object instanceof Integer) { return INTEGER; } - if(object instanceof Long) { + if (object instanceof Long) { return LONG; } - if(object instanceof String) { + if (object instanceof String) { return STRING; } - if(object instanceof Double) { + if (object instanceof Double) { return DOUBLE; } - if(object instanceof Boolean) { + if (object instanceof Boolean) { return BOOLEAN; } - if(object instanceof Float) { + if (object instanceof Float) { return FLOAT; } diff --git a/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnValue.java b/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnValue.java index 3a804d3f5e..03aa7d6acc 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnValue.java +++ b/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnValue.java @@ -5,13 +5,13 @@ package org.opensearch.ml.common.dataframe; +import java.io.IOException; +import java.util.Objects; + import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; -import java.util.Objects; - public interface ColumnValue extends Writeable, ToXContentObject { ColumnType columnType(); diff --git a/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnValueBuilder.java b/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnValueBuilder.java index 6f91b11764..098699be63 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnValueBuilder.java +++ b/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnValueBuilder.java @@ -18,36 +18,36 @@ public class ColumnValueBuilder { * @return ColumnValue */ public ColumnValue build(Object object) { - if(Objects.isNull(object)) { + if (Objects.isNull(object)) { return new NullValue(); } - if(object instanceof Short) { - return new ShortValue((Short)object); + if (object instanceof Short) { + return new ShortValue((Short) object); } - if(object instanceof Integer) { - return new IntValue((Integer)object); + if (object instanceof Integer) { + return new IntValue((Integer) object); } - if(object instanceof Long) { - return new LongValue((Long)object); + if (object instanceof Long) { + return new LongValue((Long) object); } - if(object instanceof String) { - return new StringValue((String)object); + if (object instanceof String) { + return new StringValue((String) object); } - if(object instanceof Double) { - return new DoubleValue((Double)object); + if (object instanceof Double) { + return new DoubleValue((Double) object); } - if(object instanceof Boolean) { - return new BooleanValue((Boolean)object); + if (object instanceof Boolean) { + return new BooleanValue((Boolean) object); } - if(object instanceof Float) { - return new FloatValue((Float)object); + if (object instanceof Float) { + return new FloatValue((Float) object); } throw new IllegalArgumentException("unsupported type:" + object.getClass().getName()); diff --git a/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnValueReader.java b/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnValueReader.java index c94cae3cf8..e132b362a5 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnValueReader.java +++ b/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnValueReader.java @@ -6,6 +6,7 @@ package org.opensearch.ml.common.dataframe; import java.io.IOException; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.Writeable; @@ -13,7 +14,7 @@ public class ColumnValueReader implements Writeable.Reader { @Override public ColumnValue read(StreamInput in) throws IOException { ColumnType columnType = in.readEnum(ColumnType.class); - switch (columnType){ + switch (columnType) { case SHORT: return new ShortValue(in.readShort()); case INTEGER: diff --git a/common/src/main/java/org/opensearch/ml/common/dataframe/DataFrameBuilder.java b/common/src/main/java/org/opensearch/ml/common/dataframe/DataFrameBuilder.java index c225b742e2..fac9e2ce24 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataframe/DataFrameBuilder.java +++ b/common/src/main/java/org/opensearch/ml/common/dataframe/DataFrameBuilder.java @@ -24,7 +24,7 @@ public class DataFrameBuilder { * @return empty data frame */ public DataFrame emptyDataFrame(final ColumnMeta[] columnMetas) { - if(columnMetas == null || columnMetas.length == 0) { + if (columnMetas == null || columnMetas.length == 0) { throw new IllegalArgumentException("columnMetas array is null or empty"); } return new DefaultDataFrame(columnMetas); @@ -37,7 +37,7 @@ public DataFrame emptyDataFrame(final ColumnMeta[] columnMetas) { * @return data frame */ public DataFrame load(final List> input) { - if(input == null || input.isEmpty()) { + if (input == null || input.isEmpty()) { throw new IllegalArgumentException("input is null or empty"); } @@ -45,11 +45,8 @@ public DataFrame load(final List> input) { ColumnMeta[] columnMetas = new ColumnMeta[element.size()]; int index = 0; - for(Map.Entry entry : element.entrySet()) { - ColumnMeta columnMeta = ColumnMeta.builder() - .name(entry.getKey()) - .columnType(ColumnType.from(entry.getValue())) - .build(); + for (Map.Entry entry : element.entrySet()) { + ColumnMeta columnMeta = ColumnMeta.builder().name(entry.getKey()).columnType(ColumnType.from(entry.getValue())).build(); columnMetas[index++] = columnMeta; } @@ -63,36 +60,36 @@ public DataFrame load(final List> input) { * @param input input list of map objects * @return data frame */ - public DataFrame load(final ColumnMeta[] columnMetas, final List> input){ - if(columnMetas == null || columnMetas.length == 0) { + public DataFrame load(final ColumnMeta[] columnMetas, final List> input) { + if (columnMetas == null || columnMetas.length == 0) { throw new IllegalArgumentException("columnMetas array is null or empty"); } - if(input == null || input.isEmpty()) { + if (input == null || input.isEmpty()) { throw new IllegalArgumentException("input data list is null or empty"); } int columnSize = columnMetas.length; Map columnsMap = new HashMap<>(); - for(int i = 0; i < columnSize; i++) { + for (int i = 0; i < columnSize; i++) { columnsMap.put(columnMetas[i].getName(), i); } List rows = input.stream().map(item -> { Row row = new Row(columnSize); - if(item.size() != columnSize) { + if (item.size() != columnSize) { throw new IllegalArgumentException("input item map size is different in the map"); } - for(Map.Entry entry : item.entrySet()) { - if(!columnsMap.containsKey(entry.getKey())) { + for (Map.Entry entry : item.entrySet()) { + if (!columnsMap.containsKey(entry.getKey())) { throw new IllegalArgumentException("field of input item doesn't exist in columns, filed:" + entry.getKey()); } String columnName = entry.getKey(); int index = columnsMap.get(columnName); ColumnType columnType = columnMetas[index].getColumnType(); ColumnValue value = ColumnValueBuilder.build(entry.getValue()); - if(columnType != value.columnType()) { + if (columnType != value.columnType()) { throw new IllegalArgumentException("the same field has different data type"); } row.setValue(index, value); diff --git a/common/src/main/java/org/opensearch/ml/common/dataframe/DefaultDataFrame.java b/common/src/main/java/org/opensearch/ml/common/dataframe/DefaultDataFrame.java index 27dd667de6..e7d67fcca6 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataframe/DefaultDataFrame.java +++ b/common/src/main/java/org/opensearch/ml/common/dataframe/DefaultDataFrame.java @@ -5,6 +5,8 @@ package org.opensearch.ml.common.dataframe; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; @@ -14,31 +16,29 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; - -import lombok.AccessLevel; -import lombok.ToString; -import lombok.experimental.FieldDefaults; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.AccessLevel; +import lombok.ToString; +import lombok.experimental.FieldDefaults; @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @ToString -public class DefaultDataFrame extends AbstractDataFrame{ +public class DefaultDataFrame extends AbstractDataFrame { private static final String COLUMN_META_FIELD = "column_metas"; private static final String ROWS_FIELD = "rows"; List rows; ColumnMeta[] columnMetas; - public DefaultDataFrame(final ColumnMeta[] columnMetas){ + public DefaultDataFrame(final ColumnMeta[] columnMetas) { super(DataFrameType.DEFAULT); this.columnMetas = columnMetas; this.rows = new ArrayList<>(); } - public DefaultDataFrame(final ColumnMeta[] columnMetas, final List rows){ + public DefaultDataFrame(final ColumnMeta[] columnMetas, final List rows) { super(DataFrameType.DEFAULT); this.columnMetas = columnMetas; this.rows = rows; @@ -52,12 +52,12 @@ public DefaultDataFrame(StreamInput streamInput) throws IOException { @Override public void appendRow(final Object[] values) { - if(values == null) { + if (values == null) { throw new IllegalArgumentException("input values can't be null"); } Row row = new Row(values.length); - for(int i = 0; i < values.length; i++) { + for (int i = 0; i < values.length; i++) { row.setValue(i, ColumnValueBuilder.build(values[i])); } @@ -66,20 +66,25 @@ public void appendRow(final Object[] values) { @Override public void appendRow(final Row row) { - if(row == null) { + if (row == null) { throw new IllegalArgumentException("input row can't be null"); } - if(row.size() != columnMetas.length) { - final String message = String.format("the size is different between input row:%d " + - "and column size in dataframe:%d", row.size(), columnMetas.length); + if (row.size() != columnMetas.length) { + final String message = String + .format("the size is different between input row:%d " + "and column size in dataframe:%d", row.size(), columnMetas.length); throw new IllegalArgumentException(message); } - for(int i = 0; i < columnMetas.length; i++) { - if(columnMetas[i].getColumnType() != row.getValue(i).columnType()) { - final String message = String.format("the column type is different in column meta:%s and input row:%s for index: %d", - columnMetas[i].getColumnType(), row.getValue(i).columnType(), i); + for (int i = 0; i < columnMetas.length; i++) { + if (columnMetas[i].getColumnType() != row.getValue(i).columnType()) { + final String message = String + .format( + "the column type is different in column meta:%s and input row:%s for index: %d", + columnMetas[i].getColumnType(), + row.getValue(i).columnType(), + i + ); throw new IllegalArgumentException(message); } } @@ -103,33 +108,33 @@ public ColumnMeta[] columnMetas() { @Override public DataFrame remove(int columnIndex) { - if(columnIndex < 0 || columnIndex >= columnMetas.length) { + if (columnIndex < 0 || columnIndex >= columnMetas.length) { throw new IllegalArgumentException("columnIndex can't be negative or bigger than columns length:" + columnMetas.length); } ColumnMeta[] newColumnMetas = new ColumnMeta[columnMetas.length - 1]; int index = 0; - for(int i = 0; i < columnMetas.length && i != columnIndex; i++) { + for (int i = 0; i < columnMetas.length && i != columnIndex; i++) { newColumnMetas[index++] = columnMetas[i]; } - return new DefaultDataFrame(newColumnMetas, rows.stream().map(row-> row.remove(columnIndex)).collect(Collectors.toList())); + return new DefaultDataFrame(newColumnMetas, rows.stream().map(row -> row.remove(columnIndex)).collect(Collectors.toList())); } @Override public DataFrame select(int[] columns) { - if(columns == null || columns.length == 0) { + if (columns == null || columns.length == 0) { throw new IllegalArgumentException("columns can't be null or empty"); } ColumnMeta[] newColumnMetas = new ColumnMeta[columns.length]; int index = 0; - for(int col : columns) { - if(col < 0 || col >= columnMetas.length) { + for (int col : columns) { + if (col < 0 || col >= columnMetas.length) { throw new IllegalArgumentException("columnIndex can't be negative or bigger than columns length"); } newColumnMetas[index++] = columnMetas[col]; } - return new DefaultDataFrame(newColumnMetas, rows.stream().map(row-> row.select(columns)).collect(Collectors.toList())); + return new DefaultDataFrame(newColumnMetas, rows.stream().map(row -> row.select(columns)).collect(Collectors.toList())); } @Override @@ -155,7 +160,6 @@ public Iterator iterator() { return rows.iterator(); } - @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); @@ -200,13 +204,13 @@ public XContentBuilder toXContent(XContentBuilder builder) throws IOException { @Override public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { builder.startArray(COLUMN_META_FIELD); - for(ColumnMeta columnMeta : columnMetas) { + for (ColumnMeta columnMeta : columnMetas) { columnMeta.toXContent(builder, params); } builder.endArray(); builder.startArray(ROWS_FIELD); - for(Row row : rows) { + for (Row row : rows) { row.toXContent(builder, params); } builder.endArray(); diff --git a/common/src/main/java/org/opensearch/ml/common/dataframe/FloatValue.java b/common/src/main/java/org/opensearch/ml/common/dataframe/FloatValue.java index 98727f537b..7cf0be0543 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataframe/FloatValue.java +++ b/common/src/main/java/org/opensearch/ml/common/dataframe/FloatValue.java @@ -5,13 +5,14 @@ package org.opensearch.ml.common.dataframe; +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamOutput; + import lombok.AccessLevel; import lombok.RequiredArgsConstructor; import lombok.ToString; import lombok.experimental.FieldDefaults; -import org.opensearch.core.common.io.stream.StreamOutput; - -import java.io.IOException; @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @RequiredArgsConstructor diff --git a/common/src/main/java/org/opensearch/ml/common/dataframe/LongValue.java b/common/src/main/java/org/opensearch/ml/common/dataframe/LongValue.java index 7d40265732..d3c7d70c8e 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataframe/LongValue.java +++ b/common/src/main/java/org/opensearch/ml/common/dataframe/LongValue.java @@ -5,14 +5,15 @@ package org.opensearch.ml.common.dataframe; +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamOutput; + import lombok.AccessLevel; import lombok.RequiredArgsConstructor; import lombok.ToString; import lombok.experimental.FieldDefaults; -import java.io.IOException; -import org.opensearch.core.common.io.stream.StreamOutput; - @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @RequiredArgsConstructor @ToString diff --git a/common/src/main/java/org/opensearch/ml/common/dataframe/Row.java b/common/src/main/java/org/opensearch/ml/common/dataframe/Row.java index 8727c416e3..5d38b4bc21 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataframe/Row.java +++ b/common/src/main/java/org/opensearch/ml/common/dataframe/Row.java @@ -5,9 +5,14 @@ package org.opensearch.ml.common.dataframe; -import lombok.AccessLevel; -import lombok.ToString; -import lombok.experimental.FieldDefaults; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -16,13 +21,9 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Iterator; -import java.util.List; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.AccessLevel; +import lombok.ToString; +import lombok.experimental.FieldDefaults; @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @ToString @@ -43,14 +44,14 @@ public Row(ColumnValue[] values) { } void setValue(int index, ColumnValue value) { - if(index < 0 || index > size() - 1) { + if (index < 0 || index > size() - 1) { throw new IllegalArgumentException("index is out of scope, index:" + index + "; row size:" + size()); } this.values[index] = value; } public ColumnValue getValue(int index) { - if(index < 0 || index > size() - 1) { + if (index < 0 || index > size() - 1) { throw new IllegalArgumentException("index is out of scope, index:" + index + "; row size:" + size()); } return this.values[index]; @@ -71,7 +72,7 @@ public void writeTo(StreamOutput out) throws IOException { } Row remove(int removedIndex) { - if(removedIndex < 0 || removedIndex >= values.length) { + if (removedIndex < 0 || removedIndex >= values.length) { throw new IllegalArgumentException("removed index can't be negative or bigger than row's values length:" + values.length); } ColumnValue[] newValues = new ColumnValue[Math.max(values.length - 1, 0)]; @@ -86,7 +87,7 @@ Row remove(int removedIndex) { Row select(int[] columns) { ColumnValue[] newValues = new ColumnValue[columns.length]; int index = 0; - for(int col: columns) { + for (int col : columns) { newValues[index++] = values[col]; } @@ -109,7 +110,9 @@ public static Row parse(XContentParser parser) throws IOException { if (parser.nextToken() != XContentParser.Token.END_OBJECT) { String columnTypeField = parser.currentName(); if (!"column_type".equals(columnTypeField)) { - throw new IllegalArgumentException("wrong column type, expect column_type field but got " + columnTypeField); + throw new IllegalArgumentException( + "wrong column type, expect column_type field but got " + columnTypeField + ); } parser.nextToken(); String columnType = parser.text(); @@ -182,26 +185,28 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par @Override public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; Row other = (Row) o; if (this.size() != other.size()) { return false; } - for (int i = 0; i< this.size(); i++) { - if(!this.getValue(i).equals(other.getValue(i))) { + for (int i = 0; i < this.size(); i++) { + if (!this.getValue(i).equals(other.getValue(i))) { return false; } } return true; } - public boolean equals(Row other) { + public boolean equals(Row other) { if (this.size() != other.size()) { return false; } - for (int i = 0; i< this.size(); i++) { - if(!this.getValue(i).equals(other.getValue(i))) { + for (int i = 0; i < this.size(); i++) { + if (!this.getValue(i).equals(other.getValue(i))) { return false; } } diff --git a/common/src/main/java/org/opensearch/ml/common/dataframe/ShortValue.java b/common/src/main/java/org/opensearch/ml/common/dataframe/ShortValue.java index 77de5aecf4..c08f6e1ac9 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataframe/ShortValue.java +++ b/common/src/main/java/org/opensearch/ml/common/dataframe/ShortValue.java @@ -5,13 +5,14 @@ package org.opensearch.ml.common.dataframe; +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamOutput; + import lombok.AccessLevel; import lombok.RequiredArgsConstructor; import lombok.ToString; import lombok.experimental.FieldDefaults; -import org.opensearch.core.common.io.stream.StreamOutput; - -import java.io.IOException; @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @RequiredArgsConstructor diff --git a/common/src/main/java/org/opensearch/ml/common/dataset/DataFrameInputDataset.java b/common/src/main/java/org/opensearch/ml/common/dataset/DataFrameInputDataset.java index a535144354..ccb5e84014 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataset/DataFrameInputDataset.java +++ b/common/src/main/java/org/opensearch/ml/common/dataset/DataFrameInputDataset.java @@ -11,14 +11,14 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.ml.common.annotation.InputDataSet; import org.opensearch.ml.common.dataframe.DataFrame; +import org.opensearch.ml.common.dataframe.DataFrameType; +import org.opensearch.ml.common.dataframe.DefaultDataFrame; import lombok.AccessLevel; import lombok.Builder; import lombok.Getter; import lombok.NonNull; import lombok.experimental.FieldDefaults; -import org.opensearch.ml.common.dataframe.DataFrameType; -import org.opensearch.ml.common.dataframe.DefaultDataFrame; /** * DataFrame based input data. Client directly passes the data frame to ML plugin with this. diff --git a/common/src/main/java/org/opensearch/ml/common/dataset/MLInputDataset.java b/common/src/main/java/org/opensearch/ml/common/dataset/MLInputDataset.java index 2c3514530f..0d7374d4c1 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataset/MLInputDataset.java +++ b/common/src/main/java/org/opensearch/ml/common/dataset/MLInputDataset.java @@ -10,12 +10,12 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.ml.common.MLCommonsClassLoader; import lombok.AccessLevel; import lombok.Getter; import lombok.RequiredArgsConstructor; import lombok.experimental.FieldDefaults; -import org.opensearch.ml.common.MLCommonsClassLoader; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) diff --git a/common/src/main/java/org/opensearch/ml/common/dataset/QuestionAnsweringInputDataSet.java b/common/src/main/java/org/opensearch/ml/common/dataset/QuestionAnsweringInputDataSet.java index 204d7df149..d0ae938d98 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataset/QuestionAnsweringInputDataSet.java +++ b/common/src/main/java/org/opensearch/ml/common/dataset/QuestionAnsweringInputDataSet.java @@ -4,32 +4,33 @@ */ package org.opensearch.ml.common.dataset; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.experimental.FieldDefaults; +import java.io.IOException; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.ml.common.annotation.InputDataSet; -import java.io.IOException; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.experimental.FieldDefaults; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @InputDataSet(MLInputDataType.QUESTION_ANSWERING) public class QuestionAnsweringInputDataSet extends MLInputDataset { - String question; + String question; - String context; + String context; @Builder(toBuilder = true) public QuestionAnsweringInputDataSet(String question, String context) { super(MLInputDataType.QUESTION_ANSWERING); - if(question == null) { + if (question == null) { throw new IllegalArgumentException("Question is not provided"); } - if(context == null) { + if (context == null) { throw new IllegalArgumentException("Context is not provided"); } this.question = question; diff --git a/common/src/main/java/org/opensearch/ml/common/dataset/SearchQueryInputDataset.java b/common/src/main/java/org/opensearch/ml/common/dataset/SearchQueryInputDataset.java index 636384adbc..6d737887a1 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataset/SearchQueryInputDataset.java +++ b/common/src/main/java/org/opensearch/ml/common/dataset/SearchQueryInputDataset.java @@ -9,11 +9,11 @@ import java.util.Collections; import java.util.List; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.annotation.InputDataSet; @@ -60,7 +60,9 @@ public SearchQueryInputDataset(@NonNull List indices, @NonNull SearchSou public SearchQueryInputDataset(StreamInput streaminput) throws IOException { super(MLInputDataType.SEARCH_QUERY); String searchString = streaminput.readString(); - XContentParser parser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, searchString); + XContentParser parser = XContentType.JSON + .xContent() + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, searchString); this.searchSourceBuilder = SearchSourceBuilder.fromXContent(parser); this.indices = streaminput.readStringList(); } diff --git a/common/src/main/java/org/opensearch/ml/common/dataset/TextDocsInputDataSet.java b/common/src/main/java/org/opensearch/ml/common/dataset/TextDocsInputDataSet.java index 34cc561ace..adc290f5ad 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataset/TextDocsInputDataSet.java +++ b/common/src/main/java/org/opensearch/ml/common/dataset/TextDocsInputDataSet.java @@ -5,10 +5,11 @@ package org.opensearch.ml.common.dataset; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.experimental.FieldDefaults; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + import org.opensearch.Version; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -16,15 +17,15 @@ import org.opensearch.ml.common.annotation.InputDataSet; import org.opensearch.ml.common.output.model.ModelResultFilter; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Objects; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.experimental.FieldDefaults; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @InputDataSet(MLInputDataType.TEXT_DOCS) -public class TextDocsInputDataSet extends MLInputDataset{ +public class TextDocsInputDataSet extends MLInputDataset { private ModelResultFilter resultFilter; @@ -49,7 +50,7 @@ public TextDocsInputDataSet(StreamInput streamInput) throws IOException { if (version.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_MULTI_MODAL)) { docs = new ArrayList<>(); int size = streamInput.readInt(); - for (int i=0; i textDocs; - String queryText; + List textDocs; + + String queryText; @Builder(toBuilder = true) public TextSimilarityInputDataSet(String queryText, List textDocs) { super(MLInputDataType.TEXT_SIMILARITY); Objects.requireNonNull(textDocs); Objects.requireNonNull(queryText); - if(textDocs.isEmpty()) { + if (textDocs.isEmpty()) { throw new IllegalArgumentException("No text documents were provided"); } this.textDocs = textDocs; @@ -57,7 +57,7 @@ public TextSimilarityInputDataSet(StreamInput in) throws IOException { this.queryText = in.readString(); int size = in.readInt(); this.textDocs = new ArrayList(); - for(int i = 0; i < size; i++) { + for (int i = 0; i < size; i++) { String context = in.readString(); this.textDocs.add(context); } diff --git a/common/src/main/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSet.java b/common/src/main/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSet.java index 3023d5c3fc..6a70beb5a1 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSet.java +++ b/common/src/main/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSet.java @@ -7,18 +7,20 @@ import java.io.IOException; import java.util.Map; -import lombok.Builder; -import lombok.Getter; -import lombok.Setter; + import org.opensearch.Version; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.ml.common.CommonValue; -import org.opensearch.ml.common.connector.ConnectorAction.ActionType; import org.opensearch.ml.common.annotation.InputDataSet; +import org.opensearch.ml.common.connector.ConnectorAction.ActionType; import org.opensearch.ml.common.dataset.MLInputDataType; import org.opensearch.ml.common.dataset.MLInputDataset; +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; + @Getter @InputDataSet(MLInputDataType.REMOTE) public class RemoteInferenceInputDataSet extends MLInputDataset { @@ -43,7 +45,7 @@ public RemoteInferenceInputDataSet(StreamInput streamInput) throws IOException { super(MLInputDataType.REMOTE); Version streamInputVersion = streamInput.getVersion(); if (streamInput.readBoolean()) { - parameters = streamInput.readMap(s -> s.readString(), s-> s.readString()); + parameters = streamInput.readMap(s -> s.readString(), s -> s.readString()); } if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_CLIENT_CONFIG)) { if (streamInput.readBoolean()) { @@ -58,7 +60,7 @@ public RemoteInferenceInputDataSet(StreamInput streamInput) throws IOException { public void writeTo(StreamOutput streamOutput) throws IOException { super.writeTo(streamOutput); Version streamOutputVersion = streamOutput.getVersion(); - if (parameters != null) { + if (parameters != null) { streamOutput.writeBoolean(true); streamOutput.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeString); } else { diff --git a/common/src/main/java/org/opensearch/ml/common/exception/ExecuteException.java b/common/src/main/java/org/opensearch/ml/common/exception/ExecuteException.java index 756ec4f319..a8d4322e8c 100644 --- a/common/src/main/java/org/opensearch/ml/common/exception/ExecuteException.java +++ b/common/src/main/java/org/opensearch/ml/common/exception/ExecuteException.java @@ -1,7 +1,15 @@ package org.opensearch.ml.common.exception; -public class ExecuteException extends MLException{ - public ExecuteException(String msg) { super(msg); } - public ExecuteException(Throwable cause) { super(cause); } - public ExecuteException(String msg, Throwable cause) { super(msg, cause); } +public class ExecuteException extends MLException { + public ExecuteException(String msg) { + super(msg); + } + + public ExecuteException(Throwable cause) { + super(cause); + } + + public ExecuteException(String msg, Throwable cause) { + super(msg, cause); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/exception/MLLimitExceededException.java b/common/src/main/java/org/opensearch/ml/common/exception/MLLimitExceededException.java index b5a529ad6d..476c0e0c8a 100644 --- a/common/src/main/java/org/opensearch/ml/common/exception/MLLimitExceededException.java +++ b/common/src/main/java/org/opensearch/ml/common/exception/MLLimitExceededException.java @@ -9,7 +9,7 @@ * This exception is thrown when a some limit is exceeded. * Won't count this exception in stats. */ -public class MLLimitExceededException extends MLException{ +public class MLLimitExceededException extends MLException { /** * Constructor with error message. diff --git a/common/src/main/java/org/opensearch/ml/common/input/InputHelper.java b/common/src/main/java/org/opensearch/ml/common/input/InputHelper.java index 067c74e4b2..fb2bbb8a7b 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/InputHelper.java +++ b/common/src/main/java/org/opensearch/ml/common/input/InputHelper.java @@ -5,15 +5,6 @@ package org.opensearch.ml.common.input; -import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.input.parameter.MLAlgoParams; -import org.opensearch.ml.common.input.parameter.clustering.KMeansParams; -import org.opensearch.ml.common.input.parameter.rcf.BatchRCFParams; -import org.opensearch.ml.common.input.parameter.rcf.FitRCFParams; - -import java.util.Locale; -import java.util.Map; - import static org.opensearch.ml.common.FunctionName.BATCH_RCF; import static org.opensearch.ml.common.FunctionName.FIT_RCF; import static org.opensearch.ml.common.FunctionName.KMEANS; @@ -34,6 +25,15 @@ import static org.opensearch.ml.common.input.Constants.KM_DISTANCE_TYPE; import static org.opensearch.ml.common.input.Constants.KM_ITERATIONS; +import java.util.Locale; +import java.util.Map; + +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.input.parameter.MLAlgoParams; +import org.opensearch.ml.common.input.parameter.clustering.KMeansParams; +import org.opensearch.ml.common.input.parameter.rcf.BatchRCFParams; +import org.opensearch.ml.common.input.parameter.rcf.FitRCFParams; + public class InputHelper { public static String getAction(Map arguments) { return (String) arguments.get(ACTION); @@ -42,22 +42,19 @@ public static String getAction(Map arguments) { public static FunctionName getFunctionName(Map arguments) { String algo = (String) arguments.get(ALGORITHM); if (algo == null) { - throw new IllegalArgumentException("The parameter algorithm is required."); + throw new IllegalArgumentException("The parameter algorithm is required."); } switch (algo.toLowerCase(Locale.ROOT)) { case Constants.KMEANS: return KMEANS; case Constants.RCF: - return arguments.get(AD_TIME_FIELD) == null ? - BATCH_RCF : FIT_RCF; + return arguments.get(AD_TIME_FIELD) == null ? BATCH_RCF : FIT_RCF; default: - throw new IllegalArgumentException( - String.format("unsupported algorithm: %s.", algo)); + throw new IllegalArgumentException(String.format("unsupported algorithm: %s.", algo)); } } - public static MLAlgoParams convertArgumentToMLParameter(Map arguments, - FunctionName func) { + public static MLAlgoParams convertArgumentToMLParameter(Map arguments, FunctionName func) { switch (func) { case KMEANS: return buildKMeansParameters(arguments); @@ -66,45 +63,46 @@ public static MLAlgoParams convertArgumentToMLParameter(Map argu case FIT_RCF: return buildFitRCFParameters(arguments); default: - throw new IllegalArgumentException( - String.format("unsupported algorithm: %s.", func)); + throw new IllegalArgumentException(String.format("unsupported algorithm: %s.", func)); } } private static MLAlgoParams buildKMeansParameters(Map arguments) { - return KMeansParams.builder() - .centroids((Integer) arguments.get(KM_CENTROIDS)) - .iterations((Integer) arguments.get(KM_ITERATIONS)) - .distanceType(arguments.containsKey(KM_DISTANCE_TYPE) - ? KMeansParams.DistanceType.valueOf(( - (String) arguments.get(KM_DISTANCE_TYPE)).toUpperCase(Locale.ROOT)) - : null) - .build(); + return KMeansParams + .builder() + .centroids((Integer) arguments.get(KM_CENTROIDS)) + .iterations((Integer) arguments.get(KM_ITERATIONS)) + .distanceType( + arguments.containsKey(KM_DISTANCE_TYPE) + ? KMeansParams.DistanceType.valueOf(((String) arguments.get(KM_DISTANCE_TYPE)).toUpperCase(Locale.ROOT)) + : null + ) + .build(); } private static MLAlgoParams buildBatchRCFParameters(Map arguments) { - return BatchRCFParams.builder() - .numberOfTrees((Integer) arguments.get(AD_NUMBER_OF_TREES)) - .sampleSize((Integer) arguments.get(AD_SAMPLE_SIZE)) - .outputAfter((Integer) arguments.get(AD_OUTPUT_AFTER)) - .trainingDataSize((Integer) arguments.get(AD_TRAINING_DATA_SIZE)) - .anomalyScoreThreshold((Double) arguments.get(AD_ANOMALY_SCORE_THRESHOLD)) - .build(); + return BatchRCFParams + .builder() + .numberOfTrees((Integer) arguments.get(AD_NUMBER_OF_TREES)) + .sampleSize((Integer) arguments.get(AD_SAMPLE_SIZE)) + .outputAfter((Integer) arguments.get(AD_OUTPUT_AFTER)) + .trainingDataSize((Integer) arguments.get(AD_TRAINING_DATA_SIZE)) + .anomalyScoreThreshold((Double) arguments.get(AD_ANOMALY_SCORE_THRESHOLD)) + .build(); } private static MLAlgoParams buildFitRCFParameters(Map arguments) { - return FitRCFParams.builder() - .numberOfTrees((Integer) arguments.get(AD_NUMBER_OF_TREES)) - .shingleSize((Integer) arguments.get(AD_SHINGLE_SIZE)) - .sampleSize((Integer) arguments.get(AD_SAMPLE_SIZE)) - .outputAfter((Integer) arguments.get(AD_OUTPUT_AFTER)) - .timeDecay((Double) arguments.get(AD_TIME_DECAY)) - .anomalyRate((Double) arguments.get(AD_ANOMALY_RATE)) - .timeField((String) arguments.get(AD_TIME_FIELD)) - .dateFormat(arguments.containsKey(AD_DATE_FORMAT) - ? ((String) arguments.get(AD_DATE_FORMAT)) - : "yyyy-MM-dd HH:mm:ss") - .timeZone((String) arguments.get(AD_TIME_ZONE)) - .build(); + return FitRCFParams + .builder() + .numberOfTrees((Integer) arguments.get(AD_NUMBER_OF_TREES)) + .shingleSize((Integer) arguments.get(AD_SHINGLE_SIZE)) + .sampleSize((Integer) arguments.get(AD_SAMPLE_SIZE)) + .outputAfter((Integer) arguments.get(AD_OUTPUT_AFTER)) + .timeDecay((Double) arguments.get(AD_TIME_DECAY)) + .anomalyRate((Double) arguments.get(AD_ANOMALY_RATE)) + .timeField((String) arguments.get(AD_TIME_FIELD)) + .dateFormat(arguments.containsKey(AD_DATE_FORMAT) ? ((String) arguments.get(AD_DATE_FORMAT)) : "yyyy-MM-dd HH:mm:ss") + .timeZone((String) arguments.get(AD_TIME_ZONE)) + .build(); } } diff --git a/common/src/main/java/org/opensearch/ml/common/input/MLInput.java b/common/src/main/java/org/opensearch/ml/common/input/MLInput.java index 4bf166f9b6..2faa3a599f 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/MLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/MLInput.java @@ -5,38 +5,38 @@ package org.opensearch.ml.common.input; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.input.remote.RemoteInferenceMLInput.ACTION_TYPE_FIELD; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Map; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.connector.ConnectorAction.ActionType; +import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLCommonsClassLoader; +import org.opensearch.ml.common.connector.ConnectorAction.ActionType; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataframe.DefaultDataFrame; import org.opensearch.ml.common.dataset.DataFrameInputDataset; -import org.opensearch.ml.common.dataset.QuestionAnsweringInputDataSet; -import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; -import org.opensearch.ml.common.output.model.ModelResultFilter; import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.dataset.QuestionAnsweringInputDataSet; import org.opensearch.ml.common.dataset.SearchQueryInputDataset; -import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.parameter.MLAlgoParams; +import org.opensearch.ml.common.output.model.ModelResultFilter; import org.opensearch.search.builder.SearchSourceBuilder; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Locale; -import java.util.Map; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.input.remote.RemoteInferenceMLInput.ACTION_TYPE_FIELD; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; /** * ML input data: algorithm name, parameters and input data set. @@ -89,8 +89,14 @@ public MLInput(FunctionName algorithm, MLAlgoParams parameters, MLInputDataset i this.inputDataset = inputDataset; } - public MLInput(FunctionName algorithm, MLAlgoParams parameters, SearchSourceBuilder searchSourceBuilder, - List sourceIndices, DataFrame dataFrame, MLInputDataset inputDataset) { + public MLInput( + FunctionName algorithm, + MLAlgoParams parameters, + SearchSourceBuilder searchSourceBuilder, + List sourceIndices, + DataFrame dataFrame, + MLInputDataset inputDataset + ) { validate(algorithm); this.algorithm = algorithm; this.parameters = parameters; @@ -146,12 +152,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (inputDataset != null) { switch (inputDataset.getInputDataType()) { case SEARCH_QUERY: - builder.field(INPUT_INDEX_FIELD, ((SearchQueryInputDataset)inputDataset).getIndices().toArray(new String[0])); - builder.field(INPUT_QUERY_FIELD, ((SearchQueryInputDataset)inputDataset).getSearchSourceBuilder()); + builder.field(INPUT_INDEX_FIELD, ((SearchQueryInputDataset) inputDataset).getIndices().toArray(new String[0])); + builder.field(INPUT_QUERY_FIELD, ((SearchQueryInputDataset) inputDataset).getSearchSourceBuilder()); break; case DATA_FRAME: builder.startObject(INPUT_DATA_FIELD); - ((DataFrameInputDataset)inputDataset).getDataFrame().toXContent(builder, EMPTY_PARAMS); + ((DataFrameInputDataset) inputDataset).getDataFrame().toXContent(builder, EMPTY_PARAMS); builder.endObject(); break; case TEXT_DOCS: @@ -181,7 +187,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(QUERY_TEXT_FIELD, queryText); if (documents != null && !documents.isEmpty()) { builder.startArray(TEXT_DOCS_FIELD); - for(String d : documents) { + for (String d : documents) { builder.value(d); } builder.endArray(); @@ -212,7 +218,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws public static MLInput parse(XContentParser parser, String inputAlgoName, ActionType actionType) throws IOException { MLInput mlInput = parse(parser, inputAlgoName); if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) { - RemoteInferenceInputDataSet remoteInferenceInputDataSet = (RemoteInferenceInputDataSet)mlInput.getInputDataset(); + RemoteInferenceInputDataSet remoteInferenceInputDataSet = (RemoteInferenceInputDataSet) mlInput.getInputDataset(); if (remoteInferenceInputDataSet.getActionType() == null) { remoteInferenceInputDataSet.setActionType(actionType); } @@ -225,7 +231,8 @@ public static MLInput parse(XContentParser parser, String inputAlgoName) throws FunctionName algorithm = FunctionName.from(algorithmName); if (MLCommonsClassLoader.canInitMLInput(algorithm)) { - MLInput mlInput = MLCommonsClassLoader.initMLInput(algorithm, new Object[]{parser, algorithm}, XContentParser.class, FunctionName.class); + MLInput mlInput = MLCommonsClassLoader + .initMLInput(algorithm, new Object[] { parser, algorithm }, XContentParser.class, FunctionName.class); mlInput.setAlgorithm(algorithm); return mlInput; } @@ -305,7 +312,9 @@ public static MLInput parse(XContentParser parser, String inputAlgoName) throws } } MLInputDataset inputDataSet = null; - if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.SPARSE_ENCODING || algorithm == FunctionName.SPARSE_TOKENIZE) { + if (algorithm == FunctionName.TEXT_EMBEDDING + || algorithm == FunctionName.SPARSE_ENCODING + || algorithm == FunctionName.SPARSE_TOKENIZE) { ModelResultFilter filter = new ModelResultFilter(returnBytes, returnNumber, targetResponse, targetResponsePositions); inputDataSet = new TextDocsInputDataSet(textDocs, filter); } else if (algorithm == FunctionName.TEXT_SIMILARITY) { diff --git a/common/src/main/java/org/opensearch/ml/common/input/execute/agent/AgentMLInput.java b/common/src/main/java/org/opensearch/ml/common/input/execute/agent/AgentMLInput.java index 3aa3ac382b..e39ceffe8c 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/execute/agent/AgentMLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/execute/agent/AgentMLInput.java @@ -5,9 +5,11 @@ package org.opensearch.ml.common.input.execute.agent; -import lombok.Builder; -import lombok.Getter; -import lombok.Setter; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.Map; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentParser; @@ -17,18 +19,17 @@ import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.utils.StringUtils; -import java.io.IOException; -import java.util.Map; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; - +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; -@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.AGENT}) +@org.opensearch.ml.common.annotation.MLInput(functionNames = { FunctionName.AGENT }) public class AgentMLInput extends MLInput { public static final String AGENT_ID_FIELD = "agent_id"; public static final String PARAMETERS_FIELD = "parameters"; - @Getter @Setter + @Getter + @Setter private String agentId; @Builder(builderMethodName = "AgentMLInputBuilder") diff --git a/common/src/main/java/org/opensearch/ml/common/input/execute/anomalylocalization/AnomalyLocalizationInput.java b/common/src/main/java/org/opensearch/ml/common/input/execute/anomalylocalization/AnomalyLocalizationInput.java index 6383bf6646..823763aa9d 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/execute/anomalylocalization/AnomalyLocalizationInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/execute/anomalylocalization/AnomalyLocalizationInput.java @@ -5,22 +5,25 @@ package org.opensearch.ml.common.input.execute.anomalylocalization; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.index.query.AbstractQueryBuilder.parseInnerQueryBuilder; + import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Optional; +import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.ParseField; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.query.QueryBuilder; -import org.opensearch.ml.common.annotation.ExecuteInput; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.annotation.ExecuteInput; import org.opensearch.ml.common.input.Input; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.aggregations.AggregatorFactories; @@ -28,13 +31,10 @@ import lombok.AllArgsConstructor; import lombok.Data; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.index.query.AbstractQueryBuilder.parseInnerQueryBuilder; - /** * Information about aggregate, time, etc to localize. */ -@ExecuteInput(algorithms={FunctionName.ANOMALY_LOCALIZATION}) +@ExecuteInput(algorithms = { FunctionName.ANOMALY_LOCALIZATION }) @Data @AllArgsConstructor public class AnomalyLocalizationInput implements Input { @@ -50,9 +50,9 @@ public class AnomalyLocalizationInput implements Input { public static final String FIELD_ANOMALY_START_TIME = "anomaly_start_time"; public static final String FIELD_FILTER_QUERY = "filter_query"; public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY_ENTRY = new NamedXContentRegistry.Entry( - Input.class, - new ParseField(FunctionName.ANOMALY_LOCALIZATION.name()), - parser -> parse(parser) + Input.class, + new ParseField(FunctionName.ANOMALY_LOCALIZATION.name()), + parser -> parse(parser) ); public static AnomalyLocalizationInput parse(XContentParser parser) throws IOException { @@ -124,9 +124,18 @@ public static AnomalyLocalizationInput parse(XContentParser parser) throws IOExc break; } } - return new AnomalyLocalizationInput(indexName, attributeFieldNames, aggregations, timeFieldName, startTime, endTime, - minTimeInterval, numOutputs, - anomalyStartTime, filterQuery); + return new AnomalyLocalizationInput( + indexName, + attributeFieldNames, + aggregations, + timeFieldName, + startTime, + endTime, + minTimeInterval, + numOutputs, + anomalyStartTime, + filterQuery + ); } private final String indexName; // name pattern of the data index diff --git a/common/src/main/java/org/opensearch/ml/common/input/execute/metricscorrelation/MetricsCorrelationInput.java b/common/src/main/java/org/opensearch/ml/common/input/execute/metricscorrelation/MetricsCorrelationInput.java index 3de3cee60d..8a8713cd47 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/execute/metricscorrelation/MetricsCorrelationInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/execute/metricscorrelation/MetricsCorrelationInput.java @@ -5,11 +5,15 @@ package org.opensearch.ml.common.input.execute.metricscorrelation; -import lombok.Builder; -import lombok.Data; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.ParseField; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; @@ -17,20 +21,17 @@ import org.opensearch.ml.common.annotation.ExecuteInput; import org.opensearch.ml.common.input.Input; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Data; -@ExecuteInput(algorithms={FunctionName.METRICS_CORRELATION}) +@ExecuteInput(algorithms = { FunctionName.METRICS_CORRELATION }) @Data public class MetricsCorrelationInput implements Input { public static final String PARSE_FIELD_NAME = FunctionName.METRICS_CORRELATION.name(); public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( - Input.class, - new ParseField(PARSE_FIELD_NAME), - it -> parse(it) + Input.class, + new ParseField(PARSE_FIELD_NAME), + it -> parse(it) ); public static final String METRICS_FIELD = "metrics"; diff --git a/common/src/main/java/org/opensearch/ml/common/input/execute/samplecalculator/LocalSampleCalculatorInput.java b/common/src/main/java/org/opensearch/ml/common/input/execute/samplecalculator/LocalSampleCalculatorInput.java index a4d08fb69f..91c046e99b 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/execute/samplecalculator/LocalSampleCalculatorInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/execute/samplecalculator/LocalSampleCalculatorInput.java @@ -5,32 +5,33 @@ package org.opensearch.ml.common.input.execute.samplecalculator; -import lombok.Builder; -import lombok.Data; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.ParseField; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.annotation.ExecuteInput; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.annotation.ExecuteInput; import org.opensearch.ml.common.input.Input; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Data; -@ExecuteInput(algorithms={FunctionName.LOCAL_SAMPLE_CALCULATOR}) +@ExecuteInput(algorithms = { FunctionName.LOCAL_SAMPLE_CALCULATOR }) @Data public class LocalSampleCalculatorInput implements Input { public static final String PARSE_FIELD_NAME = FunctionName.LOCAL_SAMPLE_CALCULATOR.name(); public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( - Input.class, - new ParseField(PARSE_FIELD_NAME), - it -> parse(it) + Input.class, + new ParseField(PARSE_FIELD_NAME), + it -> parse(it) ); public static final String OPERATION_FIELD = "operation"; @@ -87,7 +88,7 @@ public LocalSampleCalculatorInput(StreamInput in) throws IOException { this.operation = in.readString(); int size = in.readInt(); this.inputData = new ArrayList<>(); - for (int i = 0; i docs = ds.getTextDocs(); String queryText = ds.getQueryText(); builder.field(QUERY_TEXT_FIELD, queryText); if (docs != null && !docs.isEmpty()) { builder.startArray(TEXT_DOCS_FIELD); - for(String d : docs) { + for (String d : docs) { builder.value(d); } builder.endArray(); @@ -97,18 +96,18 @@ public TextSimilarityMLInput(XContentParser parser, FunctionName functionName) t docs.add(context); } break; - case QUERY_TEXT_FIELD: + case QUERY_TEXT_FIELD: queryText = parser.text(); break; default: parser.skipChildren(); break; } - } - if(docs.isEmpty()) { + } + if (docs.isEmpty()) { throw new IllegalArgumentException("No text documents were provided"); } - if(queryText == null) { + if (queryText == null) { throw new IllegalArgumentException("No query text was provided"); } inputDataset = new TextSimilarityInputDataSet(queryText, docs); diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/ad/AnomalyDetectionLibSVMParams.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/ad/AnomalyDetectionLibSVMParams.java index 87bd3d0d36..8363e30982 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/parameter/ad/AnomalyDetectionLibSVMParams.java +++ b/common/src/main/java/org/opensearch/ml/common/input/parameter/ad/AnomalyDetectionLibSVMParams.java @@ -5,11 +5,14 @@ package org.opensearch.ml.common.input.parameter.ad; -import lombok.Builder; -import lombok.Data; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.Locale; + +import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.ParseField; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; @@ -17,19 +20,17 @@ import org.opensearch.ml.common.annotation.MLAlgoParameter; import org.opensearch.ml.common.input.parameter.MLAlgoParams; -import java.io.IOException; -import java.util.Locale; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Data; @Data -@MLAlgoParameter(algorithms={FunctionName.AD_LIBSVM}) +@MLAlgoParameter(algorithms = { FunctionName.AD_LIBSVM }) public class AnomalyDetectionLibSVMParams implements MLAlgoParams { public static final String PARSE_FIELD_NAME = FunctionName.AD_LIBSVM.name(); public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( - MLAlgoParams.class, - new ParseField(PARSE_FIELD_NAME), - it -> parse(it) + MLAlgoParams.class, + new ParseField(PARSE_FIELD_NAME), + it -> parse(it) ); public static final String KERNEL_FIELD = "kernel"; @@ -47,9 +48,16 @@ public class AnomalyDetectionLibSVMParams implements MLAlgoParams { private Double epsilon; private Integer degree; - @Builder(toBuilder = true) - public AnomalyDetectionLibSVMParams(ADKernelType kernelType, Double gamma, Double nu, Double cost, Double coeff, Double epsilon, Integer degree) { + public AnomalyDetectionLibSVMParams( + ADKernelType kernelType, + Double gamma, + Double nu, + Double cost, + Double coeff, + Double epsilon, + Integer degree + ) { this.kernelType = kernelType; this.gamma = gamma; this.nu = nu; @@ -176,7 +184,7 @@ public enum ADKernelType { SIGMOID; public static ADKernelType from(String value) { - try{ + try { return ADKernelType.valueOf(value); } catch (Exception e) { throw new IllegalArgumentException("Wrong AD kernel type"); diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/clustering/KMeansParams.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/clustering/KMeansParams.java index 73fff86f94..39d3684a92 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/parameter/clustering/KMeansParams.java +++ b/common/src/main/java/org/opensearch/ml/common/input/parameter/clustering/KMeansParams.java @@ -5,11 +5,13 @@ package org.opensearch.ml.common.input.parameter.clustering; -import lombok.Builder; -import lombok.Data; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; + +import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.ParseField; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; @@ -18,32 +20,31 @@ import org.opensearch.ml.common.annotation.MLAlgoParameter; import org.opensearch.ml.common.input.parameter.MLAlgoParams; -import java.io.IOException; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Data; @Data -@MLAlgoParameter(algorithms={FunctionName.KMEANS}) +@MLAlgoParameter(algorithms = { FunctionName.KMEANS }) public class KMeansParams implements MLAlgoParams { public static final String PARSE_FIELD_NAME = FunctionName.KMEANS.name(); public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( - MLAlgoParams.class, - new ParseField(PARSE_FIELD_NAME), - it -> parse(it) + MLAlgoParams.class, + new ParseField(PARSE_FIELD_NAME), + it -> parse(it) ); public static final String CENTROIDS_FIELD = "centroids"; public static final String ITERATIONS_FIELD = "iterations"; public static final String DISTANCE_TYPE_FIELD = "distance_type"; - //The number of centroids to use. + // The number of centroids to use. private Integer centroids; - //The maximum number of iterations + // The maximum number of iterations private Integer iterations; - //The distance function. + // The distance function. private DistanceType distanceType; - //TODO: expose number of thread and seed? + // TODO: expose number of thread and seed? @Builder(toBuilder = true) public KMeansParams(Integer centroids, Integer iterations, DistanceType distanceType) { diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/clustering/RCFSummarizeParams.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/clustering/RCFSummarizeParams.java index b23461428c..3956514b47 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/parameter/clustering/RCFSummarizeParams.java +++ b/common/src/main/java/org/opensearch/ml/common/input/parameter/clustering/RCFSummarizeParams.java @@ -5,11 +5,13 @@ package org.opensearch.ml.common.input.parameter.clustering; -import lombok.Builder; -import lombok.Data; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; + +import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.ParseField; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; @@ -17,18 +19,17 @@ import org.opensearch.ml.common.annotation.MLAlgoParameter; import org.opensearch.ml.common.input.parameter.MLAlgoParams; -import java.io.IOException; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Data; @Data -@MLAlgoParameter(algorithms={FunctionName.RCF_SUMMARIZE}) +@MLAlgoParameter(algorithms = { FunctionName.RCF_SUMMARIZE }) public class RCFSummarizeParams implements MLAlgoParams { public static final String PARSE_FIELD_NAME = FunctionName.RCF_SUMMARIZE.name(); public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( - MLAlgoParams.class, - new ParseField(PARSE_FIELD_NAME), - it -> parse(it) + MLAlgoParams.class, + new ParseField(PARSE_FIELD_NAME), + it -> parse(it) ); public static final String MAX_K_FIELD = "max_k"; @@ -37,7 +38,7 @@ public class RCFSummarizeParams implements MLAlgoParams { public static final String PHASE1_REASSIGN_FIELD = "phase1_reassign"; public static final String PARALLEL__FIELD = "parallel"; - // The max of K allowed + // The max of K allowed private Integer maxK; // The initial K used private Integer initialK; diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/rcf/BatchRCFParams.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/rcf/BatchRCFParams.java index 3c284a51a6..fc23c13933 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/parameter/rcf/BatchRCFParams.java +++ b/common/src/main/java/org/opensearch/ml/common/input/parameter/rcf/BatchRCFParams.java @@ -5,30 +5,31 @@ package org.opensearch.ml.common.input.parameter.rcf; -import lombok.Builder; -import lombok.Data; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; + +import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.ParseField; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.annotation.MLAlgoParameter; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.annotation.MLAlgoParameter; import org.opensearch.ml.common.input.parameter.MLAlgoParams; -import java.io.IOException; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Data; @Data -@MLAlgoParameter(algorithms={FunctionName.BATCH_RCF}) +@MLAlgoParameter(algorithms = { FunctionName.BATCH_RCF }) public class BatchRCFParams implements MLAlgoParams { public static final String PARSE_FIELD_NAME = FunctionName.BATCH_RCF.name(); public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( - MLAlgoParams.class, - new ParseField(PARSE_FIELD_NAME), - it -> parse(it) + MLAlgoParams.class, + new ParseField(PARSE_FIELD_NAME), + it -> parse(it) ); public static final String NUMBER_OF_TREES = "number_of_trees"; @@ -45,12 +46,14 @@ public class BatchRCFParams implements MLAlgoParams { private Double anomalyScoreThreshold; @Builder - public BatchRCFParams(Integer numberOfTrees, - Integer shingleSize, - Integer sampleSize, - Integer outputAfter, - Integer trainingDataSize, - Double anomalyScoreThreshold) { + public BatchRCFParams( + Integer numberOfTrees, + Integer shingleSize, + Integer sampleSize, + Integer outputAfter, + Integer trainingDataSize, + Double anomalyScoreThreshold + ) { this.numberOfTrees = numberOfTrees; this.shingleSize = shingleSize; this.sampleSize = sampleSize; @@ -115,8 +118,7 @@ public static BatchRCFParams parse(XContentParser parser) throws IOException { break; } } - return new BatchRCFParams(numberOfTrees, shingleSize, sampleSize, outputAfter, - trainingDataSize, anomalyScoreThreshold); + return new BatchRCFParams(numberOfTrees, shingleSize, sampleSize, outputAfter, trainingDataSize, anomalyScoreThreshold); } @Override diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/rcf/FitRCFParams.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/rcf/FitRCFParams.java index d55fd57735..59ae8f7037 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/parameter/rcf/FitRCFParams.java +++ b/common/src/main/java/org/opensearch/ml/common/input/parameter/rcf/FitRCFParams.java @@ -5,30 +5,31 @@ package org.opensearch.ml.common.input.parameter.rcf; -import lombok.Builder; -import lombok.Data; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; + +import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.ParseField; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.annotation.MLAlgoParameter; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.annotation.MLAlgoParameter; import org.opensearch.ml.common.input.parameter.MLAlgoParams; -import java.io.IOException; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Data; @Data -@MLAlgoParameter(algorithms={FunctionName.FIT_RCF}) +@MLAlgoParameter(algorithms = { FunctionName.FIT_RCF }) public class FitRCFParams implements MLAlgoParams { public static final String PARSE_FIELD_NAME = FunctionName.FIT_RCF.name(); public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( - MLAlgoParams.class, - new ParseField(PARSE_FIELD_NAME), - it -> parse(it) + MLAlgoParams.class, + new ParseField(PARSE_FIELD_NAME), + it -> parse(it) ); public static final String NUMBER_OF_TREES = "number_of_trees"; @@ -51,15 +52,17 @@ public class FitRCFParams implements MLAlgoParams { private String timeZone; @Builder - public FitRCFParams(Integer numberOfTrees, - Integer shingleSize, - Integer sampleSize, - Integer outputAfter, - Double timeDecay, - Double anomalyRate, - String timeField, - String dateFormat, - String timeZone) { + public FitRCFParams( + Integer numberOfTrees, + Integer shingleSize, + Integer sampleSize, + Integer outputAfter, + Double timeDecay, + Double anomalyRate, + String timeField, + String dateFormat, + String timeZone + ) { this.numberOfTrees = numberOfTrees; this.shingleSize = shingleSize; this.sampleSize = sampleSize; @@ -145,8 +148,17 @@ public static FitRCFParams parse(XContentParser parser) throws IOException { break; } } - return new FitRCFParams(numberOfTrees, shingleSize, sampleSize, outputAfter, - timeDecay, anomalyRate, timeField, dateFormat, timeZone); + return new FitRCFParams( + numberOfTrees, + shingleSize, + sampleSize, + outputAfter, + timeDecay, + anomalyRate, + timeField, + dateFormat, + timeZone + ); } @Override diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/regression/LinearRegressionParams.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/regression/LinearRegressionParams.java index 9e9cb7f129..9ea9d88959 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/parameter/regression/LinearRegressionParams.java +++ b/common/src/main/java/org/opensearch/ml/common/input/parameter/regression/LinearRegressionParams.java @@ -5,11 +5,14 @@ package org.opensearch.ml.common.input.parameter.regression; -import lombok.Builder; -import lombok.Data; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.Locale; + +import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.ParseField; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; @@ -17,20 +20,18 @@ import org.opensearch.ml.common.annotation.MLAlgoParameter; import org.opensearch.ml.common.input.parameter.MLAlgoParams; -import java.io.IOException; -import java.util.Locale; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Data; @Data -@MLAlgoParameter(algorithms={FunctionName.LINEAR_REGRESSION}) +@MLAlgoParameter(algorithms = { FunctionName.LINEAR_REGRESSION }) public class LinearRegressionParams implements MLAlgoParams { public static final String PARSE_FIELD_NAME = FunctionName.LINEAR_REGRESSION.name(); public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( - MLAlgoParams.class, - new ParseField(PARSE_FIELD_NAME), - it -> parse(it) + MLAlgoParams.class, + new ParseField(PARSE_FIELD_NAME), + it -> parse(it) ); public static final String OBJECTIVE_FIELD = "objective"; @@ -64,7 +65,22 @@ public class LinearRegressionParams implements MLAlgoParams { private String target; @Builder(toBuilder = true) - public LinearRegressionParams(ObjectiveType objectiveType, OptimizerType optimizerType, Double learningRate, MomentumType momentumType, Double momentumFactor, Double epsilon, Double beta1, Double beta2, Double decayRate, Integer epochs, Integer batchSize, Integer loggingInterval, Long seed, String target) { + public LinearRegressionParams( + ObjectiveType objectiveType, + OptimizerType optimizerType, + Double learningRate, + MomentumType momentumType, + Double momentumFactor, + Double epsilon, + Double beta1, + Double beta2, + Double decayRate, + Integer epochs, + Integer batchSize, + Integer loggingInterval, + Long seed, + String target + ) { this.objectiveType = objectiveType; this.optimizerType = optimizerType; this.learningRate = learningRate; @@ -173,7 +189,22 @@ public static MLAlgoParams parse(XContentParser parser) throws IOException { break; } } - return new LinearRegressionParams(objective, optimizerType, learningRate, momentumType, momentumFactor, epsilon, beta1, beta2,decayRate, epochs, batchSize, loggingInterval, seed, target); + return new LinearRegressionParams( + objective, + optimizerType, + learningRate, + momentumType, + momentumFactor, + epsilon, + beta1, + beta2, + decayRate, + epochs, + batchSize, + loggingInterval, + seed, + target + ); } @Override @@ -272,8 +303,9 @@ public enum ObjectiveType { SQUARED_LOSS, ABSOLUTE_LOSS, HUBER; + public static ObjectiveType from(String value) { - try{ + try { return ObjectiveType.valueOf(value); } catch (Exception e) { throw new IllegalArgumentException("Wrong objective type"); @@ -286,7 +318,7 @@ public enum MomentumType { NESTEROV; public static MomentumType from(String value) { - try{ + try { return MomentumType.valueOf(value); } catch (Exception e) { throw new IllegalArgumentException("Wrong momentum type"); @@ -304,7 +336,7 @@ public enum OptimizerType { RMS_PROP; public static OptimizerType from(String value) { - try{ + try { return OptimizerType.valueOf(value); } catch (Exception e) { throw new IllegalArgumentException("Wrong optimizer type"); diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/regression/LogisticRegressionParams.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/regression/LogisticRegressionParams.java index 3340050ff5..d4238c1111 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/parameter/regression/LogisticRegressionParams.java +++ b/common/src/main/java/org/opensearch/ml/common/input/parameter/regression/LogisticRegressionParams.java @@ -5,11 +5,14 @@ package org.opensearch.ml.common.input.parameter.regression; -import lombok.Builder; -import lombok.Data; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.Locale; + +import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.ParseField; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; @@ -17,20 +20,18 @@ import org.opensearch.ml.common.annotation.MLAlgoParameter; import org.opensearch.ml.common.input.parameter.MLAlgoParams; -import java.io.IOException; -import java.util.Locale; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Data; @Data -@MLAlgoParameter(algorithms={FunctionName.LOGISTIC_REGRESSION}) +@MLAlgoParameter(algorithms = { FunctionName.LOGISTIC_REGRESSION }) public class LogisticRegressionParams implements MLAlgoParams { public static final String PARSE_FIELD_NAME = FunctionName.LOGISTIC_REGRESSION.name(); public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( - MLAlgoParams.class, - new ParseField(PARSE_FIELD_NAME), - it -> parse(it) + MLAlgoParams.class, + new ParseField(PARSE_FIELD_NAME), + it -> parse(it) ); public static final String OBJECTIVE_FIELD = "objective"; @@ -188,7 +189,22 @@ public static MLAlgoParams parse(XContentParser parser) throws IOException { break; } } - return new LogisticRegressionParams(objective, optimizerType, momentumType, learningRate, epsilon, momentumFactor, beta1, beta2, decayRate, epochs, batchSize, loggingInterval, seed, target); + return new LogisticRegressionParams( + objective, + optimizerType, + momentumType, + learningRate, + epsilon, + momentumFactor, + beta1, + beta2, + decayRate, + epochs, + batchSize, + loggingInterval, + seed, + target + ); } @Override @@ -286,8 +302,9 @@ public int getVersion() { public enum ObjectiveType { HINGE, LOGMULTICLASS; + public static ObjectiveType from(String value) { - try{ + try { return ObjectiveType.valueOf(value); } catch (Exception e) { throw new IllegalArgumentException("Wrong objective type"); @@ -300,7 +317,7 @@ public enum MomentumType { NESTEROV; public static MomentumType from(String value) { - try{ + try { return MomentumType.valueOf(value); } catch (Exception e) { throw new IllegalArgumentException("Wrong momentum type"); @@ -318,7 +335,7 @@ public enum OptimizerType { RMS_PROP; public static OptimizerType from(String value) { - try{ + try { return OptimizerType.valueOf(value); } catch (Exception e) { throw new IllegalArgumentException("Wrong optimizer type"); diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/sample/SampleAlgoParams.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/sample/SampleAlgoParams.java index 2544a748f5..7fc8c8be38 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/parameter/sample/SampleAlgoParams.java +++ b/common/src/main/java/org/opensearch/ml/common/input/parameter/sample/SampleAlgoParams.java @@ -5,11 +5,13 @@ package org.opensearch.ml.common.input.parameter.sample; -import lombok.Builder; -import lombok.Data; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; + +import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.ParseField; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; @@ -17,18 +19,17 @@ import org.opensearch.ml.common.annotation.MLAlgoParameter; import org.opensearch.ml.common.input.parameter.MLAlgoParams; -import java.io.IOException; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Data; @Data -@MLAlgoParameter(algorithms={FunctionName.SAMPLE_ALGO}) +@MLAlgoParameter(algorithms = { FunctionName.SAMPLE_ALGO }) public class SampleAlgoParams implements MLAlgoParams { public static final String PARSE_FIELD_NAME = FunctionName.SAMPLE_ALGO.name(); public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( - MLAlgoParams.class, - new ParseField(PARSE_FIELD_NAME), - it -> parse(it) + MLAlgoParams.class, + new ParseField(PARSE_FIELD_NAME), + it -> parse(it) ); public static final String SAMPLE_PARAM_FIELD = "sample_param"; diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AsymmetricTextEmbeddingParameters.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AsymmetricTextEmbeddingParameters.java index be7c139efa..f73b83e106 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AsymmetricTextEmbeddingParameters.java +++ b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AsymmetricTextEmbeddingParameters.java @@ -33,7 +33,7 @@ * `query_prefix` and `passage_prefix` configuration parameters. */ @Data -@MLAlgoParameter(algorithms={FunctionName.TEXT_EMBEDDING}) +@MLAlgoParameter(algorithms = { FunctionName.TEXT_EMBEDDING }) public class AsymmetricTextEmbeddingParameters implements MLAlgoParams { public enum EmbeddingContentType { diff --git a/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java b/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java index cd45cb19cb..f30d845179 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java @@ -5,6 +5,11 @@ package org.opensearch.ml.common.input.remote; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.Map; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentParser; @@ -14,12 +19,7 @@ import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.utils.StringUtils; -import java.io.IOException; -import java.util.Map; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; - -@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.REMOTE}) +@org.opensearch.ml.common.annotation.MLInput(functionNames = { FunctionName.REMOTE }) public class RemoteInferenceMLInput extends MLInput { public static final String PARAMETERS_FIELD = "parameters"; public static final String ACTION_TYPE_FIELD = "action_type"; diff --git a/common/src/main/java/org/opensearch/ml/common/model/Guardrail.java b/common/src/main/java/org/opensearch/ml/common/model/Guardrail.java index 598121b8ed..3c3e828021 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/Guardrail.java +++ b/common/src/main/java/org/opensearch/ml/common/model/Guardrail.java @@ -5,14 +5,14 @@ package org.opensearch.ml.common.model; +import java.io.IOException; +import java.util.Map; + import org.opensearch.client.Client; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContentObject; -import java.io.IOException; -import java.util.Map; - public abstract class Guardrail implements ToXContentObject { public abstract void writeTo(StreamOutput out) throws IOException; diff --git a/common/src/main/java/org/opensearch/ml/common/model/Guardrails.java b/common/src/main/java/org/opensearch/ml/common/model/Guardrails.java index db7558b7cc..b6f0017878 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/Guardrails.java +++ b/common/src/main/java/org/opensearch/ml/common/model/Guardrails.java @@ -5,20 +5,21 @@ package org.opensearch.ml.common.model; -import lombok.Builder; -import lombok.EqualsAndHashCode; -import lombok.Getter; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.Map; +import java.util.Set; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import java.io.IOException; -import java.util.Map; -import java.util.Set; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; @EqualsAndHashCode @Getter @@ -126,11 +127,12 @@ public static Guardrails parse(XContentParser parser) throws IOException { throw new IllegalArgumentException("The type of guardrails is required, can not be null."); } - return Guardrails.builder() - .type(type) - .inputGuardrail(createGuardrail(type, inputGuardrailMap)) - .outputGuardrail(createGuardrail(type, outputGuardrailMap)) - .build(); + return Guardrails + .builder() + .type(type) + .inputGuardrail(createGuardrail(type, inputGuardrailMap)) + .outputGuardrail(createGuardrail(type, outputGuardrailMap)) + .build(); } private static Boolean validateType(String type) { diff --git a/common/src/main/java/org/opensearch/ml/common/model/LocalRegexGuardrail.java b/common/src/main/java/org/opensearch/ml/common/model/LocalRegexGuardrail.java index 0f142bde3b..8da4e4db97 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/LocalRegexGuardrail.java +++ b/common/src/main/java/org/opensearch/ml/common/model/LocalRegexGuardrail.java @@ -5,25 +5,10 @@ package org.opensearch.ml.common.model; -import lombok.Builder; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.NonNull; -import lombok.extern.log4j.Log4j2; -import org.opensearch.action.LatchedActionListener; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.client.Client; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.common.xcontent.LoggingDeprecationHandler; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.core.action.ActionListener; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.search.builder.SearchSourceBuilder; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.stopWordsIndices; +import static org.opensearch.ml.common.utils.StringUtils.gson; import java.io.IOException; import java.security.AccessController; @@ -39,10 +24,26 @@ import java.util.regex.Pattern; import java.util.stream.Collectors; -import static java.util.concurrent.TimeUnit.SECONDS; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.CommonValue.stopWordsIndices; -import static org.opensearch.ml.common.utils.StringUtils.gson; +import org.opensearch.action.LatchedActionListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.builder.SearchSourceBuilder; + +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.NonNull; +import lombok.extern.log4j.Log4j2; @Log4j2 @EqualsAndHashCode @@ -63,6 +64,7 @@ public LocalRegexGuardrail(List stopWords, String[] regex) { this.stopWords = stopWords; this.regex = regex; } + public LocalRegexGuardrail(@NonNull Map params) { List words = (List) params.get(STOP_WORDS_FIELD); stopWords = new ArrayList<>(); @@ -81,7 +83,7 @@ public LocalRegexGuardrail(StreamInput input) throws IOException { if (input.readBoolean()) { stopWords = new ArrayList<>(); int size = input.readInt(); - for (int i=0; i queryBodyMap = Map - .of("query", Map.of("percolate", Map.of("field", "query", "document", documentMap))); + Map queryBodyMap = Map.of("query", Map.of("percolate", Map.of("field", "query", "document", documentMap))); CountDownLatch latch = new CountDownLatch(1); ThreadContext.StoredContext context = null; try { queryBody = AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(queryBodyMap)); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - XContentParser queryParser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, queryBody); + XContentParser queryParser = XContentType.JSON + .xContent() + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, queryBody); searchSourceBuilder.parseXContent(queryParser); - searchSourceBuilder.size(1); //Only need 1 doc returned, if hit. + searchSourceBuilder.size(1); // Only need 1 doc returned, if hit. searchRequest = new SearchRequest().source(searchSourceBuilder).indices(indexName); if (isStopWordsSystemIndex(indexName)) { context = client.threadPool().getThreadContext().stashContext(); diff --git a/common/src/main/java/org/opensearch/ml/common/model/MLDeploySetting.java b/common/src/main/java/org/opensearch/ml/common/model/MLDeploySetting.java index 23c81d9ead..16147dbafa 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/MLDeploySetting.java +++ b/common/src/main/java/org/opensearch/ml/common/model/MLDeploySetting.java @@ -5,22 +5,23 @@ package org.opensearch.ml.common.model; -import lombok.Builder; -import lombok.Getter; -import lombok.Setter; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; + import org.opensearch.Version; -import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.CommonValue; -import java.io.IOException; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; @Setter @Getter diff --git a/common/src/main/java/org/opensearch/ml/common/model/MLGuard.java b/common/src/main/java/org/opensearch/ml/common/model/MLGuard.java index 3aa8d060b1..7bf146f36c 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/MLGuard.java +++ b/common/src/main/java/org/opensearch/ml/common/model/MLGuard.java @@ -5,12 +5,13 @@ package org.opensearch.ml.common.model; -import lombok.Getter; -import lombok.extern.log4j.Log4j2; +import java.util.Map; + import org.opensearch.client.Client; import org.opensearch.core.xcontent.NamedXContentRegistry; -import java.util.Map; +import lombok.Getter; +import lombok.extern.log4j.Log4j2; @Log4j2 @Getter diff --git a/common/src/main/java/org/opensearch/ml/common/model/MLModelConfig.java b/common/src/main/java/org/opensearch/ml/common/model/MLModelConfig.java index 2fb07b6d8e..67f13e4f62 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/MLModelConfig.java +++ b/common/src/main/java/org/opensearch/ml/common/model/MLModelConfig.java @@ -5,14 +5,15 @@ package org.opensearch.ml.common.model; -import lombok.Getter; -import lombok.Setter; +import java.io.IOException; + import org.opensearch.core.common.io.stream.NamedWriteable; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; -import java.io.IOException; +import lombok.Getter; +import lombok.Setter; @Setter @Getter diff --git a/common/src/main/java/org/opensearch/ml/common/model/MLModelState.java b/common/src/main/java/org/opensearch/ml/common/model/MLModelState.java index cfd06be1f0..3bb84c6bbd 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/MLModelState.java +++ b/common/src/main/java/org/opensearch/ml/common/model/MLModelState.java @@ -37,4 +37,4 @@ public static MLModelState from(String value) { throw new IllegalArgumentException("Wrong model state"); } } -} \ No newline at end of file +} diff --git a/common/src/main/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfig.java b/common/src/main/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfig.java index e1c9203cae..b690bd0342 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfig.java +++ b/common/src/main/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfig.java @@ -5,18 +5,18 @@ package org.opensearch.ml.common.model; -import lombok.Builder; -import lombok.Getter; -import lombok.Setter; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.FunctionName; -import java.io.IOException; -import java.util.Locale; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; @Setter @Getter @@ -29,7 +29,7 @@ public MetricsCorrelationModelConfig(String modelType, String allConfig) { super(modelType, allConfig); } - public MetricsCorrelationModelConfig(StreamInput in) throws IOException{ + public MetricsCorrelationModelConfig(StreamInput in) throws IOException { super(in); } diff --git a/common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java b/common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java index d64050a8a3..edf78601ed 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java +++ b/common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java @@ -5,11 +5,21 @@ package org.opensearch.ml.common.model; -import lombok.Builder; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.NonNull; -import lombok.extern.log4j.Log4j2; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedExceptionAction; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + import org.opensearch.action.ActionRequest; import org.opensearch.action.LatchedActionListener; import org.opensearch.client.Client; @@ -29,20 +39,11 @@ import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; -import java.io.IOException; -import java.security.AccessController; -import java.security.PrivilegedExceptionAction; -import java.util.HashMap; -import java.util.Map; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Function; -import java.util.regex.Matcher; -import java.util.regex.Pattern; - -import static java.util.concurrent.TimeUnit.SECONDS; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.utils.StringUtils.gson; +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.NonNull; +import lombok.extern.log4j.Log4j2; @Log4j2 @EqualsAndHashCode @@ -65,8 +66,13 @@ public ModelGuardrail(String modelId, String responseFilter, String responseAcce this.responseFilter = responseFilter; this.responseAccept = responseAccept; } + public ModelGuardrail(@NonNull Map params) { - this((String) params.get(MODEL_ID_FIELD), (String) params.get(RESPONSE_FILTER_FIELD), (String) params.get(RESPONSE_VALIDATION_REGEX_FIELD)); + this( + (String) params.get(MODEL_ID_FIELD), + (String) params.get(RESPONSE_FILTER_FIELD), + (String) params.get(RESPONSE_VALIDATION_REGEX_FIELD) + ); } public ModelGuardrail(StreamInput input) throws IOException { @@ -97,14 +103,14 @@ public Boolean validate(String in, Map parameters) { AtomicBoolean isAccepted = new AtomicBoolean(true); ActionListener internalListener = ActionListener.wrap(predictionResponse -> { ModelTensorOutput output = (ModelTensorOutput) predictionResponse.getOutput(); - ModelTensor tensor = output.getMlModelOutputs().get(0).getMlModelTensors().get(0); + ModelTensor tensor = output.getMlModelOutputs().get(0).getMlModelTensors().get(0); String guardrailResponse = AccessController - .doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(tensor.getDataAsMap().get("response"))); + .doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(tensor.getDataAsMap().get("response"))); log.info("Guardrail response: {}", guardrailResponse); if (!validateAcceptRegex(guardrailResponse)) { isAccepted.set(false); } - }, e -> {log.error("[ModelGuardrail] Failed to get prediction response.", e);}); + }, e -> { log.error("[ModelGuardrail] Failed to get prediction response.", e); }); ActionListener actionListener = wrapActionListener(internalListener, res -> { MLTaskResponse predictionResponse = MLTaskResponse.fromActionResponse(res); return predictionResponse; @@ -117,19 +123,14 @@ public Boolean validate(String in, Map parameters) { } log.info("Guardrail resFilter: {}", responseFilter); ActionRequest request = new MLPredictionTaskRequest( - modelId, - RemoteInferenceMLInput - .builder() - .algorithm(FunctionName.REMOTE) - .inputDataset(RemoteInferenceInputDataSet.builder().parameters(guardrailModelParams).build()) - .build() + modelId, + RemoteInferenceMLInput + .builder() + .algorithm(FunctionName.REMOTE) + .inputDataset(RemoteInferenceInputDataSet.builder().parameters(guardrailModelParams).build()) + .build() ); - client - .execute( - MLPredictionTaskAction.INSTANCE, - request, - new LatchedActionListener(actionListener, latch) - ); + client.execute(MLPredictionTaskAction.INSTANCE, request, new LatchedActionListener(actionListener, latch)); try { latch.await(5, SECONDS); } catch (InterruptedException e) { @@ -187,16 +188,12 @@ public static ModelGuardrail parse(XContentParser parser) throws IOException { break; } } - return ModelGuardrail.builder() - .modelId(modelId) - .responseFilter(responseFilter) - .responseAccept(responseAccept) - .build(); + return ModelGuardrail.builder().modelId(modelId).responseFilter(responseFilter).responseAccept(responseAccept).build(); } private ActionListener wrapActionListener( - final ActionListener listener, - final Function recreate + final ActionListener listener, + final Function recreate ) { ActionListener actionListener = ActionListener.wrap(r -> { listener.onResponse(recreate.apply(r)); diff --git a/common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java b/common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java index 7b01f847a2..503f20249f 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java +++ b/common/src/main/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfig.java @@ -5,9 +5,11 @@ package org.opensearch.ml.common.model; -import lombok.Builder; -import lombok.Getter; -import lombok.Setter; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.Locale; + import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -16,19 +18,18 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.FunctionName; -import java.io.IOException; -import java.util.Locale; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; @Setter @Getter public class QuestionAnsweringModelConfig extends MLModelConfig { public static final String PARSE_FIELD_NAME = FunctionName.QUESTION_ANSWERING.name(); public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( - QuestionAnsweringModelConfig.class, - new ParseField(PARSE_FIELD_NAME), - it -> parse(it) + QuestionAnsweringModelConfig.class, + new ParseField(PARSE_FIELD_NAME), + it -> parse(it) ); public static final String FRAMEWORK_TYPE_FIELD = "framework_type"; public static final String NORMALIZE_RESULT_FIELD = "normalize_result"; @@ -39,7 +40,13 @@ public class QuestionAnsweringModelConfig extends MLModelConfig { private final Integer modelMaxLength; @Builder(toBuilder = true) - public QuestionAnsweringModelConfig(String modelType, FrameworkType frameworkType, String allConfig, boolean normalizeResult, Integer modelMaxLength) { + public QuestionAnsweringModelConfig( + String modelType, + FrameworkType frameworkType, + String allConfig, + boolean normalizeResult, + Integer modelMaxLength + ) { super(modelType, allConfig); if (frameworkType == null) { throw new IllegalArgumentException("framework type is null"); @@ -90,7 +97,7 @@ public String getWriteableName() { return PARSE_FIELD_NAME; } - public QuestionAnsweringModelConfig(StreamInput in) throws IOException{ + public QuestionAnsweringModelConfig(StreamInput in) throws IOException { super(in); frameworkType = in.readEnum(FrameworkType.class); normalizeResult = in.readBoolean(); @@ -126,6 +133,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.endObject(); return builder; } + public enum FrameworkType { HUGGINGFACE_TRANSFORMERS, SENTENCE_TRANSFORMERS, diff --git a/common/src/main/java/org/opensearch/ml/common/model/StopWords.java b/common/src/main/java/org/opensearch/ml/common/model/StopWords.java index 648f465891..a66c70f7b0 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/StopWords.java +++ b/common/src/main/java/org/opensearch/ml/common/model/StopWords.java @@ -5,21 +5,22 @@ package org.opensearch.ml.common.model; -import lombok.Builder; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.NonNull; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import java.io.IOException; -import java.util.List; -import java.util.Map; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.NonNull; @EqualsAndHashCode @Getter @@ -86,9 +87,6 @@ public static StopWords parse(XContentParser parser) throws IOException { break; } } - return StopWords.builder() - .index(index) - .sourceFields(sourceFields) - .build(); + return StopWords.builder().index(index).sourceFields(sourceFields).build(); } } diff --git a/common/src/main/java/org/opensearch/ml/common/model/TextEmbeddingModelConfig.java b/common/src/main/java/org/opensearch/ml/common/model/TextEmbeddingModelConfig.java index b1c249da44..9576f19036 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/TextEmbeddingModelConfig.java +++ b/common/src/main/java/org/opensearch/ml/common/model/TextEmbeddingModelConfig.java @@ -5,31 +5,32 @@ package org.opensearch.ml.common.model; -import lombok.Builder; -import lombok.Getter; -import lombok.Setter; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.Locale; + +import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.ParseField; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.FunctionName; -import java.io.IOException; -import java.util.Locale; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; @Setter @Getter public class TextEmbeddingModelConfig extends MLModelConfig { public static final String PARSE_FIELD_NAME = FunctionName.TEXT_EMBEDDING.name(); public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( - TextEmbeddingModelConfig.class, - new ParseField(PARSE_FIELD_NAME), - it -> parse(it) + TextEmbeddingModelConfig.class, + new ParseField(PARSE_FIELD_NAME), + it -> parse(it) ); public static final String EMBEDDING_DIMENSION_FIELD = "embedding_dimension"; @@ -48,14 +49,30 @@ public class TextEmbeddingModelConfig extends MLModelConfig { private final String queryPrefix; private final String passagePrefix; - public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, FrameworkType frameworkType, String allConfig, - PoolingMode poolingMode, boolean normalizeResult, Integer modelMaxLength) { + public TextEmbeddingModelConfig( + String modelType, + Integer embeddingDimension, + FrameworkType frameworkType, + String allConfig, + PoolingMode poolingMode, + boolean normalizeResult, + Integer modelMaxLength + ) { this(modelType, embeddingDimension, frameworkType, allConfig, poolingMode, normalizeResult, modelMaxLength, null, null); } @Builder(toBuilder = true) - public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, FrameworkType frameworkType, String allConfig, - PoolingMode poolingMode, boolean normalizeResult, Integer modelMaxLength, String queryPrefix, String passagePrefix) { + public TextEmbeddingModelConfig( + String modelType, + Integer embeddingDimension, + FrameworkType frameworkType, + String allConfig, + PoolingMode poolingMode, + boolean normalizeResult, + Integer modelMaxLength, + String queryPrefix, + String passagePrefix + ) { super(modelType, allConfig); if (embeddingDimension == null) { throw new IllegalArgumentException("embedding dimension is null"); @@ -121,7 +138,17 @@ public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOExc break; } } - return new TextEmbeddingModelConfig(modelType, embeddingDimension, frameworkType, allConfig, poolingMode, normalizeResult, modelMaxLength, queryPrefix, passagePrefix); + return new TextEmbeddingModelConfig( + modelType, + embeddingDimension, + frameworkType, + allConfig, + poolingMode, + normalizeResult, + modelMaxLength, + queryPrefix, + passagePrefix + ); } @Override @@ -129,7 +156,7 @@ public String getWriteableName() { return PARSE_FIELD_NAME; } - public TextEmbeddingModelConfig(StreamInput in) throws IOException{ + public TextEmbeddingModelConfig(StreamInput in) throws IOException { super(in); embeddingDimension = in.readInt(); frameworkType = in.readEnum(FrameworkType.class); @@ -208,6 +235,7 @@ public enum PoolingMode { public String getName() { return name; } + PoolingMode(String name) { this.name = name; } @@ -220,6 +248,7 @@ public static PoolingMode from(String value) { } } } + public enum FrameworkType { HUGGINGFACE_TRANSFORMERS, SENTENCE_TRANSFORMERS, diff --git a/common/src/main/java/org/opensearch/ml/common/output/MLOutput.java b/common/src/main/java/org/opensearch/ml/common/output/MLOutput.java index 83fbfe1cc1..d967059892 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/MLOutput.java +++ b/common/src/main/java/org/opensearch/ml/common/output/MLOutput.java @@ -5,13 +5,14 @@ package org.opensearch.ml.common.output; -import lombok.NonNull; -import lombok.RequiredArgsConstructor; +import java.io.IOException; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.ml.common.MLCommonsClassLoader; -import java.io.IOException; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; /** * ML output data. Must specify output type and diff --git a/common/src/main/java/org/opensearch/ml/common/output/MLPredictionOutput.java b/common/src/main/java/org/opensearch/ml/common/output/MLPredictionOutput.java index 28b5b07821..5675dab409 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/MLPredictionOutput.java +++ b/common/src/main/java/org/opensearch/ml/common/output/MLPredictionOutput.java @@ -5,10 +5,8 @@ package org.opensearch.ml.common.output; -import lombok.Builder; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.ToString; +import java.io.IOException; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentBuilder; @@ -17,10 +15,13 @@ import org.opensearch.ml.common.dataframe.DataFrameType; import org.opensearch.ml.common.dataframe.DefaultDataFrame; -import java.io.IOException; +import lombok.Builder; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.ToString; @Data -@EqualsAndHashCode(callSuper=false) +@EqualsAndHashCode(callSuper = false) @MLAlgoOutput(MLOutputType.PREDICTION) public class MLPredictionOutput extends MLOutput { diff --git a/common/src/main/java/org/opensearch/ml/common/output/MLTrainingOutput.java b/common/src/main/java/org/opensearch/ml/common/output/MLTrainingOutput.java index c69bb9ca74..c6bb98c73f 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/MLTrainingOutput.java +++ b/common/src/main/java/org/opensearch/ml/common/output/MLTrainingOutput.java @@ -5,15 +5,16 @@ package org.opensearch.ml.common.output; -import lombok.Builder; -import lombok.Getter; +import java.io.IOException; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.annotation.MLAlgoOutput; -import java.io.IOException; +import lombok.Builder; +import lombok.Getter; @Getter @MLAlgoOutput(MLOutputType.TRAINING) @@ -32,7 +33,7 @@ public MLTrainingOutput(String modelId, String taskId, String status) { super(OUTPUT_TYPE); this.modelId = modelId; this.taskId = taskId; - this.status= status; + this.status = status; } public MLTrainingOutput(StreamInput in) throws IOException { diff --git a/common/src/main/java/org/opensearch/ml/common/output/execute/anomalylocalization/AnomalyLocalizationOutput.java b/common/src/main/java/org/opensearch/ml/common/output/execute/anomalylocalization/AnomalyLocalizationOutput.java index dde28400b7..490ea101f1 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/execute/anomalylocalization/AnomalyLocalizationOutput.java +++ b/common/src/main/java/org/opensearch/ml/common/output/execute/anomalylocalization/AnomalyLocalizationOutput.java @@ -5,6 +5,8 @@ package org.opensearch.ml.common.output.execute.anomalylocalization; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; @@ -15,25 +17,23 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; - -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import lombok.SneakyThrows; -import lombok.ToString; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.annotation.ExecuteOutput; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.annotation.ExecuteOutput; import org.opensearch.ml.common.output.Output; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; +import lombok.SneakyThrows; +import lombok.ToString; /** * Output of localized results. */ -@ExecuteOutput(algorithms={FunctionName.ANOMALY_LOCALIZATION}) +@ExecuteOutput(algorithms = { FunctionName.ANOMALY_LOCALIZATION }) @Data @NoArgsConstructor public class AnomalyLocalizationOutput implements Output { @@ -102,8 +102,8 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par */ @Data @NoArgsConstructor - @ToString(exclude = {"base", "counter", "completed"}) - @EqualsAndHashCode(exclude = {"base", "counter", "completed"}) + @ToString(exclude = { "base", "counter", "completed" }) + @EqualsAndHashCode(exclude = { "base", "counter", "completed" }) public static class Bucket implements Output { public static final String FIELD_START_TIME = "start_time"; @@ -113,7 +113,7 @@ public static class Bucket implements Output { private long startTime; // start time of the bucket private long endTime; // end time of the bucket - private double overallAggValue; // overall value of the bucket + private double overallAggValue; // overall value of the bucket private List entities = null; // localized entities of the bucket private Optional base = Optional.empty(); @@ -134,7 +134,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeLong(startTime); out.writeLong(endTime); out.writeDouble(overallAggValue); - if (entities == null) { + if (entities == null) { out.writeBoolean(false); } else { out.writeBoolean(true); diff --git a/common/src/main/java/org/opensearch/ml/common/output/execute/metrics_correlation/MCorrModelTensor.java b/common/src/main/java/org/opensearch/ml/common/output/execute/metrics_correlation/MCorrModelTensor.java index a8dc54481b..2783787cc4 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/execute/metrics_correlation/MCorrModelTensor.java +++ b/common/src/main/java/org/opensearch/ml/common/output/execute/metrics_correlation/MCorrModelTensor.java @@ -5,15 +5,16 @@ package org.opensearch.ml.common.output.execute.metrics_correlation; -import lombok.Builder; -import lombok.Data; +import java.io.IOException; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; +import lombok.Builder; +import lombok.Data; @Data public class MCorrModelTensor implements Writeable, ToXContentObject { diff --git a/common/src/main/java/org/opensearch/ml/common/output/execute/metrics_correlation/MCorrModelTensors.java b/common/src/main/java/org/opensearch/ml/common/output/execute/metrics_correlation/MCorrModelTensors.java index d26a9e8b0e..5ebcc41248 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/execute/metrics_correlation/MCorrModelTensors.java +++ b/common/src/main/java/org/opensearch/ml/common/output/execute/metrics_correlation/MCorrModelTensors.java @@ -5,11 +5,13 @@ package org.opensearch.ml.common.output.execute.metrics_correlation; -import lombok.Builder; -import lombok.Getter; -import lombok.extern.log4j.Log4j2; -import org.opensearch.core.common.bytes.BytesReference; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -18,10 +20,9 @@ import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.output.model.ModelResultFilter; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.List; +import lombok.Builder; +import lombok.Getter; +import lombok.extern.log4j.Log4j2; @Log4j2 @Getter @@ -48,7 +49,7 @@ public MCorrModelTensors(StreamInput in) throws IOException { if (in.readBoolean()) { mCorrModelTensors = new ArrayList<>(); int size = in.readInt(); - for (int i=0; i targetResponse = resultFilter.getTargetResponse(); List targetResponsePositions = resultFilter.getTargetResponsePositions(); if ((targetResponse == null || targetResponse.size() == 0) - && (targetResponsePositions == null || targetResponsePositions.size() == 0)) { - mCorrModelTensors.forEach(output -> filter(output, returnNumber)); + && (targetResponsePositions == null || targetResponsePositions.size() == 0)) { + mCorrModelTensors.forEach(output -> filter(output, returnNumber)); return; } List targetOutput = new ArrayList<>(); if (mCorrModelTensors != null) { - for (int i = 0 ; i(); int size = in.readInt(); - for (int i=0; i targetResponsePositions; @Builder - public ModelResultFilter(boolean returnBytes, - boolean returnNumber, - List targetResponse, - List targetResponsePositions + public ModelResultFilter( + boolean returnBytes, + boolean returnNumber, + List targetResponse, + List targetResponsePositions ) { this.returnBytes = returnBytes; this.returnNumber = returnNumber; @@ -65,7 +67,7 @@ public ModelResultFilter(StreamInput streamInput) throws IOException { if (streamInput.readBoolean()) { int size = streamInput.readInt(); targetResponsePositions = new ArrayList<>(); - for (int i=0;i dataAsMap;// whole result in Map @Builder - public ModelTensor(String name, Number[] data, long[] shape, MLResultDataType dataType, ByteBuffer byteBuffer, String result, Map dataAsMap) { + public ModelTensor( + String name, + Number[] data, + long[] shape, + MLResultDataType dataType, + ByteBuffer byteBuffer, + String result, + Map dataAsMap + ) { if (data != null && (dataType == null || dataType == MLResultDataType.UNKNOWN)) { throw new IllegalArgumentException("data type is null"); } @@ -179,14 +189,7 @@ public static ModelTensor parser(XContentParser parser) throws IOException { data[i] = (Number) dataList.get(i); } } - return ModelTensor.builder() - .name(name) - .shape(shape) - .dataType(dataType) - .data(data) - .result(result) - .dataAsMap(dataAsMap) - .build(); + return ModelTensor.builder().name(name).shape(shape).dataType(dataType).data(data).result(result).dataAsMap(dataAsMap).build(); } public ModelTensor(StreamInput in) throws IOException { diff --git a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensorOutput.java b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensorOutput.java index 664bd3510f..32f3318718 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensorOutput.java +++ b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensorOutput.java @@ -5,9 +5,10 @@ package org.opensearch.ml.common.output.model; -import lombok.Builder; -import lombok.Data; -import lombok.EqualsAndHashCode; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentBuilder; @@ -15,12 +16,12 @@ import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.MLOutputType; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; +import lombok.Builder; +import lombok.Data; +import lombok.EqualsAndHashCode; @Data -@EqualsAndHashCode(callSuper=false) +@EqualsAndHashCode(callSuper = false) @MLAlgoOutput(MLOutputType.MODEL_TENSOR) public class ModelTensorOutput extends MLOutput { private static final MLOutputType OUTPUT_TYPE = MLOutputType.MODEL_TENSOR; @@ -34,13 +35,12 @@ public ModelTensorOutput(List mlModelOutputs) { this.mlModelOutputs = mlModelOutputs; } - public ModelTensorOutput(StreamInput in) throws IOException { super(OUTPUT_TYPE); if (in.readBoolean()) { mlModelOutputs = new ArrayList<>(); int size = in.readInt(); - for (int i=0; i(); int size = in.readInt(); - for (int i=0; i targetResponse = resultFilter.getTargetResponse(); List targetResponsePositions = resultFilter.getTargetResponsePositions(); if ((targetResponse == null || targetResponse.size() == 0) - && (targetResponsePositions == null || targetResponsePositions.size() == 0)) { - mlModelTensors.forEach(output -> filter(output, returnBytes, returnNumber)); + && (targetResponsePositions == null || targetResponsePositions.size() == 0)) { + mlModelTensors.forEach(output -> filter(output, returnBytes, returnNumber)); return; } List targetOutput = new ArrayList<>(); if (mlModelTensors != null) { - for (int i = 0 ; i { public static final MLAgentDeleteAction INSTANCE = new MLAgentDeleteAction(); public static final String NAME = "cluster:admin/opensearch/ml/agents/delete"; - private MLAgentDeleteAction() { super(NAME, DeleteResponse::new);} + private MLAgentDeleteAction() { + super(NAME, DeleteResponse::new); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteRequest.java index ddc568fc60..9786dc8b3b 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteRequest.java @@ -5,8 +5,13 @@ package org.opensearch.ml.common.transport.agent; -import lombok.Builder; -import lombok.Getter; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -14,12 +19,8 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.Builder; +import lombok.Getter; public class MLAgentDeleteRequest extends ActionRequest { @Getter @@ -54,11 +55,10 @@ public ActionRequestValidationException validate() { public static MLAgentDeleteRequest fromActionRequest(ActionRequest actionRequest) { if (actionRequest instanceof MLAgentDeleteRequest) { - return (MLAgentDeleteRequest)actionRequest; + return (MLAgentDeleteRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLAgentDeleteRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetAction.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetAction.java index 2a61035ce8..b30fea6ba4 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetAction.java @@ -11,6 +11,8 @@ public class MLAgentGetAction extends ActionType { public static final MLAgentGetAction INSTANCE = new MLAgentGetAction(); public static final String NAME = "cluster:admin/opensearch/ml/agents/get"; - private MLAgentGetAction() { super(NAME, MLAgentGetResponse::new);} + private MLAgentGetAction() { + super(NAME, MLAgentGetResponse::new); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetRequest.java index ea65a768df..d6923ac280 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetRequest.java @@ -5,8 +5,13 @@ package org.opensearch.ml.common.transport.agent; -import lombok.Builder; -import lombok.Getter; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -14,12 +19,8 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.Builder; +import lombok.Getter; @Getter public class MLAgentGetRequest extends ActionRequest { @@ -65,8 +66,7 @@ public static MLAgentGetRequest fromActionRequest(ActionRequest actionRequest) { return (MLAgentGetRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLAgentGetRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponse.java index 593e314b31..4749cb02b2 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponse.java @@ -5,8 +5,11 @@ package org.opensearch.ml.common.transport.agent; -import lombok.Builder; -import lombok.Getter; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -16,10 +19,8 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.agent.MLAgent; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; +import lombok.Builder; +import lombok.Getter; @Getter public class MLAgentGetResponse extends ActionResponse implements ToXContentObject { @@ -36,7 +37,7 @@ public MLAgentGetResponse(StreamInput in) throws IOException { } @Override - public void writeTo(StreamOutput out) throws IOException{ + public void writeTo(StreamOutput out) throws IOException { mlAgent.writeTo(out); } @@ -50,8 +51,7 @@ public static MLAgentGetResponse fromActionResponse(ActionResponse actionRespons return (MLAgentGetResponse) actionResponse; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionResponse.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLAgentGetResponse(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequest.java index 4add7827d5..c73f2150aa 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequest.java @@ -5,11 +5,13 @@ package org.opensearch.ml.common.transport.agent; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; -import lombok.experimental.FieldDefaults; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -18,12 +20,11 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.ml.common.agent.MLAgent; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @@ -63,8 +64,7 @@ public static MLRegisterAgentRequest fromActionRequest(ActionRequest actionReque return (MLRegisterAgentRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLRegisterAgentRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponse.java index 7f8b633cbe..bea3be0c81 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponse.java @@ -5,7 +5,11 @@ package org.opensearch.ml.common.transport.agent; -import lombok.Getter; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -14,10 +18,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; +import lombok.Getter; @Getter public class MLRegisterAgentResponse extends ActionResponse implements ToXContentObject { @@ -31,7 +32,7 @@ public MLRegisterAgentResponse(StreamInput in) throws IOException { } public MLRegisterAgentResponse(String agentId) { - this.agentId= agentId; + this.agentId = agentId; } @Override @@ -52,8 +53,7 @@ public static MLRegisterAgentResponse fromActionResponse(ActionResponse actionRe return (MLRegisterAgentResponse) actionResponse; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionResponse.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLRegisterAgentResponse(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/config/MLConfigGetAction.java b/common/src/main/java/org/opensearch/ml/common/transport/config/MLConfigGetAction.java index 6287559c03..ec93cc4578 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/config/MLConfigGetAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/config/MLConfigGetAction.java @@ -11,6 +11,8 @@ public class MLConfigGetAction extends ActionType { public static final MLConfigGetAction INSTANCE = new MLConfigGetAction(); public static final String NAME = "cluster:admin/opensearch/ml/config/get"; - private MLConfigGetAction() { super(NAME, MLConfigGetResponse::new);} + private MLConfigGetAction() { + super(NAME, MLConfigGetResponse::new); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/config/MLConfigGetRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/config/MLConfigGetRequest.java index 0542c9480b..bfc1c156db 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/config/MLConfigGetRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/config/MLConfigGetRequest.java @@ -5,8 +5,13 @@ package org.opensearch.ml.common.transport.config; -import lombok.Builder; -import lombok.Getter; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -14,12 +19,8 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.Builder; +import lombok.Getter; @Getter public class MLConfigGetRequest extends ActionRequest { @@ -58,8 +59,7 @@ public static MLConfigGetRequest fromActionRequest(ActionRequest actionRequest) return (MLConfigGetRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLConfigGetRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/config/MLConfigGetResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/config/MLConfigGetResponse.java index 1fc353e54f..33e95cc474 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/config/MLConfigGetResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/config/MLConfigGetResponse.java @@ -5,7 +5,11 @@ package org.opensearch.ml.common.transport.config; -import lombok.Builder; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -15,10 +19,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.MLConfig; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; +import lombok.Builder; public class MLConfigGetResponse extends ActionResponse implements ToXContentObject { MLConfig mlConfig; @@ -34,7 +35,7 @@ public MLConfigGetResponse(StreamInput in) throws IOException { } @Override - public void writeTo(StreamOutput out) throws IOException{ + public void writeTo(StreamOutput out) throws IOException { mlConfig.writeTo(out); } @@ -48,8 +49,7 @@ public static MLConfigGetResponse fromActionResponse(ActionResponse actionRespon return (MLConfigGetResponse) actionResponse; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionResponse.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLConfigGetResponse(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteAction.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteAction.java index dc0c1044f1..b3db0ce9cf 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteAction.java @@ -12,5 +12,7 @@ public class MLConnectorDeleteAction extends ActionType { public static final MLConnectorDeleteAction INSTANCE = new MLConnectorDeleteAction(); public static final String NAME = "cluster:admin/opensearch/ml/connectors/delete"; - private MLConnectorDeleteAction() { super(NAME, DeleteResponse::new);} + private MLConnectorDeleteAction() { + super(NAME, DeleteResponse::new); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteRequest.java index 9da5be98aa..a1e3a6391e 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteRequest.java @@ -5,8 +5,13 @@ package org.opensearch.ml.common.transport.connector; -import lombok.Builder; -import lombok.Getter; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -14,12 +19,8 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.Builder; +import lombok.Getter; public class MLConnectorDeleteRequest extends ActionRequest { @Getter @@ -54,11 +55,10 @@ public ActionRequestValidationException validate() { public static MLConnectorDeleteRequest fromActionRequest(ActionRequest actionRequest) { if (actionRequest instanceof MLConnectorDeleteRequest) { - return (MLConnectorDeleteRequest)actionRequest; + return (MLConnectorDeleteRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLConnectorDeleteRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetAction.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetAction.java index da29dd86fe..6695e2ada1 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetAction.java @@ -11,6 +11,8 @@ public class MLConnectorGetAction extends ActionType { public static final MLConnectorGetAction INSTANCE = new MLConnectorGetAction(); public static final String NAME = "cluster:admin/opensearch/ml/connectors/get"; - private MLConnectorGetAction() { super(NAME, MLConnectorGetResponse::new);} + private MLConnectorGetAction() { + super(NAME, MLConnectorGetResponse::new); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetRequest.java index 118a70ccde..53c6c9c497 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetRequest.java @@ -5,8 +5,13 @@ package org.opensearch.ml.common.transport.connector; -import lombok.Builder; -import lombok.Getter; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -14,12 +19,8 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.Builder; +import lombok.Getter; @Getter public class MLConnectorGetRequest extends ActionRequest { @@ -62,8 +63,7 @@ public static MLConnectorGetRequest fromActionRequest(ActionRequest actionReques return (MLConnectorGetRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLConnectorGetRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponse.java index 492566d20a..dbd7c9b42c 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponse.java @@ -5,7 +5,11 @@ package org.opensearch.ml.common.transport.connector; -import lombok.Builder; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -16,10 +20,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.connector.Connector; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; +import lombok.Builder; public class MLConnectorGetResponse extends ActionResponse implements ToXContentObject { Connector mlConnector; @@ -35,7 +36,7 @@ public MLConnectorGetResponse(StreamInput in) throws IOException { } @Override - public void writeTo(StreamOutput out) throws IOException{ + public void writeTo(StreamOutput out) throws IOException { mlConnector.writeTo(out); } @@ -49,8 +50,7 @@ public static MLConnectorGetResponse fromActionResponse(ActionResponse actionRes return (MLConnectorGetResponse) actionResponse; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionResponse.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLConnectorGetResponse(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java index 007b65e286..697f27494f 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java @@ -5,8 +5,15 @@ package org.opensearch.ml.common.transport.connector; -import lombok.Builder; -import lombok.Data; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + import org.opensearch.Version; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -21,14 +28,8 @@ import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.ConnectorClientConfig; -import java.io.IOException; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; +import lombok.Builder; +import lombok.Data; @Data public class MLCreateConnectorInput implements ToXContentObject, Writeable { @@ -65,21 +66,21 @@ public class MLCreateConnectorInput implements ToXContentObject, Writeable { private boolean updateConnector; private ConnectorClientConfig connectorClientConfig; - @Builder(toBuilder = true) - public MLCreateConnectorInput(String name, - String description, - String version, - String protocol, - Map parameters, - Map credential, - List actions, - List backendRoles, - Boolean addAllBackendRoles, - AccessMode access, - boolean dryRun, - boolean updateConnector, - ConnectorClientConfig connectorClientConfig + public MLCreateConnectorInput( + String name, + String description, + String version, + String protocol, + Map parameters, + Map credential, + List actions, + List backendRoles, + Boolean addAllBackendRoles, + AccessMode access, + boolean dryRun, + boolean updateConnector, + ConnectorClientConfig connectorClientConfig ) { if (!dryRun && !updateConnector) { @@ -182,8 +183,21 @@ public static MLCreateConnectorInput parse(XContentParser parser, boolean update break; } } - return new MLCreateConnectorInput(name, description, version, protocol, parameters, credential, actions, - backendRoles, addAllBackendRoles, access, dryRun, updateConnector, connectorClientConfig); + return new MLCreateConnectorInput( + name, + description, + version, + protocol, + parameters, + credential, + actions, + backendRoles, + addAllBackendRoles, + access, + dryRun, + updateConnector, + connectorClientConfig + ); } @Override @@ -289,7 +303,7 @@ public MLCreateConnectorInput(StreamInput input) throws IOException { parameters = input.readMap(s -> s.readString(), s -> s.readString()); } if (input.readBoolean()) { - credential = input.readMap(s -> s.readString(), s-> s.readString()); + credential = input.readMap(s -> s.readString(), s -> s.readString()); } if (input.readBoolean()) { actions = new ArrayList<>(); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequest.java index 107d5001b8..e227c30478 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequest.java @@ -5,8 +5,13 @@ package org.opensearch.ml.common.transport.connector; -import lombok.Builder; -import lombok.Getter; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -14,12 +19,8 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.Builder; +import lombok.Getter; @Getter public class MLCreateConnectorRequest extends ActionRequest { @@ -56,8 +57,7 @@ public static MLCreateConnectorRequest fromActionRequest(ActionRequest actionReq return (MLCreateConnectorRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLCreateConnectorRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponse.java index 68ce877baa..08b1631853 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponse.java @@ -5,7 +5,11 @@ package org.opensearch.ml.common.transport.connector; -import lombok.Getter; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -14,10 +18,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; +import lombok.Getter; @Getter public class MLCreateConnectorResponse extends ActionResponse implements ToXContentObject { @@ -53,8 +54,7 @@ public static MLCreateConnectorResponse fromActionResponse(ActionResponse action return (MLCreateConnectorResponse) actionResponse; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionResponse.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLCreateConnectorResponse(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequest.java index ab7ffa9c9f..9b24115455 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequest.java @@ -5,11 +5,13 @@ package org.opensearch.ml.common.transport.connector; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; -import lombok.experimental.FieldDefaults; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -19,12 +21,11 @@ import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.transport.MLTaskRequest; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; @Getter @FieldDefaults(level = AccessLevel.PRIVATE) @@ -70,14 +71,12 @@ public ActionRequestValidationException validate() { return exception; } - public static MLExecuteConnectorRequest fromActionRequest(ActionRequest actionRequest) { if (actionRequest instanceof MLExecuteConnectorRequest) { return (MLExecuteConnectorRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLExecuteConnectorRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorAction.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorAction.java index 9fa10a39c6..8609af2134 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorAction.java @@ -12,5 +12,7 @@ public class MLUpdateConnectorAction extends ActionType { public static final MLUpdateConnectorAction INSTANCE = new MLUpdateConnectorAction(); public static final String NAME = "cluster:admin/opensearch/ml/connectors/update"; - private MLUpdateConnectorAction() { super(NAME, UpdateResponse::new);} + private MLUpdateConnectorAction() { + super(NAME, UpdateResponse::new); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequest.java index 089180cdc5..8a365140de 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequest.java @@ -5,8 +5,13 @@ package org.opensearch.ml.common.transport.connector; -import lombok.Builder; -import lombok.Getter; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -15,12 +20,8 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentParser; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.Builder; +import lombok.Getter; @Getter public class MLUpdateConnectorRequest extends ActionRequest { @@ -72,8 +73,7 @@ public static MLUpdateConnectorRequest fromActionRequest(ActionRequest actionReq return (MLUpdateConnectorRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLUpdateConnectorRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLControllerDeleteRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLControllerDeleteRequest.java index 8fdb8bc564..9acf6393c2 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLControllerDeleteRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLControllerDeleteRequest.java @@ -5,8 +5,13 @@ package org.opensearch.ml.common.transport.controller; -import lombok.Builder; -import lombok.Getter; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -14,12 +19,8 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.Builder; +import lombok.Getter; public class MLControllerDeleteRequest extends ActionRequest { @Getter @@ -57,8 +58,7 @@ public static MLControllerDeleteRequest fromActionRequest(ActionRequest actionRe return (MLControllerDeleteRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLControllerDeleteRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLControllerGetRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLControllerGetRequest.java index 86754c1732..c3e9fa648f 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLControllerGetRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLControllerGetRequest.java @@ -5,11 +5,13 @@ package org.opensearch.ml.common.transport.controller; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; -import lombok.experimental.FieldDefaults; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -17,12 +19,11 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @@ -67,8 +68,7 @@ public static MLControllerGetRequest fromActionRequest(ActionRequest actionReque return (MLControllerGetRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLControllerGetRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLControllerGetResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLControllerGetResponse.java index 7c07e91a1f..86cd878aba 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLControllerGetResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLControllerGetResponse.java @@ -5,8 +5,11 @@ package org.opensearch.ml.common.transport.controller; -import lombok.Builder; -import lombok.Getter; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -17,10 +20,8 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.controller.MLController; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; +import lombok.Builder; +import lombok.Getter; public class MLControllerGetResponse extends ActionResponse implements ToXContentObject { @@ -52,8 +53,7 @@ public static MLControllerGetResponse fromActionResponse(ActionResponse actionRe return (MLControllerGetResponse) actionResponse; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionResponse.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLControllerGetResponse(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateControllerRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateControllerRequest.java index efea44da24..a179a2ffcd 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateControllerRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateControllerRequest.java @@ -4,11 +4,13 @@ */ package org.opensearch.ml.common.transport.controller; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; -import lombok.experimental.FieldDefaults; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -17,12 +19,11 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.ml.common.controller.MLController; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @@ -59,8 +60,7 @@ public static MLCreateControllerRequest fromActionRequest(ActionRequest actionRe if (actionRequest instanceof MLCreateControllerRequest) { return (MLCreateControllerRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLCreateControllerRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateControllerResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateControllerResponse.java index 592caf5d6b..29157ab65c 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateControllerResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateControllerResponse.java @@ -4,8 +4,11 @@ */ package org.opensearch.ml.common.transport.controller; -import lombok.Builder; -import lombok.Getter; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -14,10 +17,8 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; +import lombok.Builder; +import lombok.Getter; @Getter public class MLCreateControllerResponse extends ActionResponse implements ToXContentObject { @@ -60,8 +61,7 @@ public static MLCreateControllerResponse fromActionResponse(ActionResponse actio if (actionResponse instanceof MLCreateControllerResponse) { return (MLCreateControllerResponse) actionResponse; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionResponse.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLCreateControllerResponse(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodeRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodeRequest.java index ceb5ba92c6..092b206e54 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodeRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodeRequest.java @@ -6,11 +6,13 @@ package org.opensearch.ml.common.transport.controller; import java.io.IOException; -import lombok.Getter; + import org.opensearch.action.support.nodes.BaseNodeRequest; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import lombok.Getter; + public class MLDeployControllerNodeRequest extends BaseNodeRequest { @Getter private MLDeployControllerNodesRequest deployControllerNodesRequest; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodeResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodeResponse.java index 7b038a4a21..1f5a05cbd8 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodeResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodeResponse.java @@ -5,8 +5,9 @@ package org.opensearch.ml.common.transport.controller; -import lombok.Getter; -import lombok.extern.log4j.Log4j2; +import java.io.IOException; +import java.util.Map; + import org.opensearch.action.support.nodes.BaseNodeResponse; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.common.io.stream.StreamInput; @@ -14,8 +15,8 @@ import org.opensearch.core.xcontent.ToXContentFragment; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; -import java.util.Map; +import lombok.Getter; +import lombok.extern.log4j.Log4j2; @Getter @Log4j2 diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodesRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodesRequest.java index 1a70c53a90..1b42a62a8e 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodesRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodesRequest.java @@ -5,12 +5,14 @@ package org.opensearch.ml.common.transport.controller; -import lombok.Getter; +import java.io.IOException; + import org.opensearch.action.support.nodes.BaseNodesRequest; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.IOException; + +import lombok.Getter; public class MLDeployControllerNodesRequest extends BaseNodesRequest { diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodesResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodesResponse.java index 50d60d1801..3dd9ea27c1 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodesResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodesResponse.java @@ -5,6 +5,9 @@ package org.opensearch.ml.common.transport.controller; +import java.io.IOException; +import java.util.List; + import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.nodes.BaseNodesResponse; import org.opensearch.cluster.ClusterName; @@ -15,19 +18,17 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; -import java.util.List; - -public class MLDeployControllerNodesResponse extends BaseNodesResponse - implements ToXContentObject { +public class MLDeployControllerNodesResponse extends BaseNodesResponse implements ToXContentObject { public MLDeployControllerNodesResponse(StreamInput in) throws IOException { - super(new ClusterName(in), in.readList(MLDeployControllerNodeResponse::readStats), - in.readList(FailedNodeException::new)); + super(new ClusterName(in), in.readList(MLDeployControllerNodeResponse::readStats), in.readList(FailedNodeException::new)); } - public MLDeployControllerNodesResponse(ClusterName clusterName, List nodes, - List failures) { + public MLDeployControllerNodesResponse( + ClusterName clusterName, + List nodes, + List failures + ) { super(clusterName, nodes, failures); } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodeRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodeRequest.java index bea9c709ba..4c998e3608 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodeRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodeRequest.java @@ -6,11 +6,13 @@ package org.opensearch.ml.common.transport.controller; import java.io.IOException; -import lombok.Getter; + import org.opensearch.action.support.nodes.BaseNodeRequest; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import lombok.Getter; + public class MLUndeployControllerNodeRequest extends BaseNodeRequest { @Getter private MLUndeployControllerNodesRequest undeployControllerNodesRequest; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodeResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodeResponse.java index 7438871caf..232fbe3e07 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodeResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodeResponse.java @@ -5,8 +5,9 @@ package org.opensearch.ml.common.transport.controller; -import lombok.Getter; -import lombok.extern.log4j.Log4j2; +import java.io.IOException; +import java.util.Map; + import org.opensearch.action.support.nodes.BaseNodeResponse; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.common.io.stream.StreamInput; @@ -14,8 +15,8 @@ import org.opensearch.core.xcontent.ToXContentFragment; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; -import java.util.Map; +import lombok.Getter; +import lombok.extern.log4j.Log4j2; @Getter @Log4j2 diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodesRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodesRequest.java index af9785dcee..973883c392 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodesRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodesRequest.java @@ -5,12 +5,14 @@ package org.opensearch.ml.common.transport.controller; -import lombok.Getter; +import java.io.IOException; + import org.opensearch.action.support.nodes.BaseNodesRequest; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.IOException; + +import lombok.Getter; public class MLUndeployControllerNodesRequest extends BaseNodesRequest { diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodesResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodesResponse.java index 11996955c9..75586cfdd6 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodesResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodesResponse.java @@ -5,6 +5,9 @@ package org.opensearch.ml.common.transport.controller; +import java.io.IOException; +import java.util.List; + import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.nodes.BaseNodesResponse; import org.opensearch.cluster.ClusterName; @@ -15,19 +18,17 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; -import java.util.List; - -public class MLUndeployControllerNodesResponse extends BaseNodesResponse - implements ToXContentObject { +public class MLUndeployControllerNodesResponse extends BaseNodesResponse implements ToXContentObject { public MLUndeployControllerNodesResponse(StreamInput in) throws IOException { - super(new ClusterName(in), in.readList(MLUndeployControllerNodeResponse::readStats), - in.readList(FailedNodeException::new)); + super(new ClusterName(in), in.readList(MLUndeployControllerNodeResponse::readStats), in.readList(FailedNodeException::new)); } - public MLUndeployControllerNodesResponse(ClusterName clusterName, List nodes, - List failures) { + public MLUndeployControllerNodesResponse( + ClusterName clusterName, + List nodes, + List failures + ) { super(clusterName, nodes, failures); } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUpdateControllerRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUpdateControllerRequest.java index 5a067a1411..68ea3c57ad 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUpdateControllerRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUpdateControllerRequest.java @@ -4,11 +4,13 @@ */ package org.opensearch.ml.common.transport.controller; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; -import lombok.experimental.FieldDefaults; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -17,12 +19,11 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.ml.common.controller.MLController; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @@ -60,8 +61,7 @@ public static MLUpdateControllerRequest fromActionRequest(ActionRequest actionRe if (actionRequest instanceof MLUpdateControllerRequest) { return (MLUpdateControllerRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLUpdateControllerRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelInput.java index ae30dd45f5..d8ae7ab829 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelInput.java @@ -6,13 +6,15 @@ package org.opensearch.ml.common.transport.deploy; import java.io.IOException; -import lombok.Builder; -import lombok.Data; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.ml.common.MLTask; +import lombok.Builder; +import lombok.Data; + @Data public class MLDeployModelInput implements Writeable { private String modelId; @@ -34,7 +36,15 @@ public MLDeployModelInput(StreamInput in) throws IOException { } @Builder - public MLDeployModelInput(String modelId, String taskId, String modelContentHash, Integer nodeCount, String coordinatingNodeId, Boolean isDeployToAllNodes, MLTask mlTask) { + public MLDeployModelInput( + String modelId, + String taskId, + String modelContentHash, + Integer nodeCount, + String coordinatingNodeId, + Boolean isDeployToAllNodes, + MLTask mlTask + ) { this.modelId = modelId; this.taskId = taskId; this.modelContentHash = modelContentHash; @@ -44,8 +54,7 @@ public MLDeployModelInput(String modelId, String taskId, String modelContentHash this.mlTask = mlTask; } - public MLDeployModelInput() { - } + public MLDeployModelInput() {} @Override public void writeTo(StreamOutput out) throws IOException { diff --git a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodeRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodeRequest.java index ffa30072e0..9dc192bafa 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodeRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodeRequest.java @@ -5,12 +5,13 @@ package org.opensearch.ml.common.transport.deploy; -import lombok.Getter; +import java.io.IOException; + import org.opensearch.action.support.nodes.BaseNodeRequest; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.IOException; +import lombok.Getter; public class MLDeployModelNodeRequest extends BaseNodeRequest { @Getter diff --git a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodeResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodeResponse.java index 685fc43cf7..587332397d 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodeResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodeResponse.java @@ -5,7 +5,9 @@ package org.opensearch.ml.common.transport.deploy; -import lombok.extern.log4j.Log4j2; +import java.io.IOException; +import java.util.Map; + import org.opensearch.action.support.nodes.BaseNodeResponse; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.common.io.stream.StreamInput; @@ -13,8 +15,8 @@ import org.opensearch.core.xcontent.ToXContentFragment; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; -import java.util.Map; +import lombok.extern.log4j.Log4j2; + @Log4j2 public class MLDeployModelNodeResponse extends BaseNodeResponse implements ToXContentFragment { @@ -27,6 +29,7 @@ public MLDeployModelNodeResponse(DiscoveryNode node, Map modelDe super(node); this.modelDeployStatus = modelDeployStatus; } + /** * Constructor * diff --git a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodesRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodesRequest.java index e2c8043b04..5f5c347dac 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodesRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodesRequest.java @@ -5,13 +5,14 @@ package org.opensearch.ml.common.transport.deploy; -import lombok.Getter; +import java.io.IOException; + import org.opensearch.action.support.nodes.BaseNodesRequest; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.IOException; +import lombok.Getter; public class MLDeployModelNodesRequest extends BaseNodesRequest { diff --git a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodesResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodesResponse.java index c27abebfaf..be8d5cc1ed 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodesResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodesResponse.java @@ -5,6 +5,9 @@ package org.opensearch.ml.common.transport.deploy; +import java.io.IOException; +import java.util.List; + import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.nodes.BaseNodesResponse; import org.opensearch.cluster.ClusterName; @@ -14,9 +17,6 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; -import java.util.List; - public class MLDeployModelNodesResponse extends BaseNodesResponse implements ToXContentObject { /** diff --git a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelRequest.java index 2b8b2f51c1..c6aa04bc6a 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelRequest.java @@ -5,11 +5,16 @@ package org.opensearch.ml.common.transport.deploy; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; -import lombok.experimental.FieldDefaults; +import static org.opensearch.action.ValidateActions.addValidationError; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.List; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -19,15 +24,11 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.transport.MLTaskRequest; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.util.ArrayList; -import java.util.List; - -import static org.opensearch.action.ValidateActions.addValidationError; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @@ -43,7 +44,13 @@ public class MLDeployModelRequest extends MLTaskRequest { private final boolean isUserInitiatedDeployRequest; @Builder - public MLDeployModelRequest(String modelId, String[] modelNodeIds, boolean async, boolean dispatchTask, boolean isUserInitiatedDeployRequest) { + public MLDeployModelRequest( + String modelId, + String[] modelNodeIds, + boolean async, + boolean dispatchTask, + boolean isUserInitiatedDeployRequest + ) { super(dispatchTask); this.modelId = modelId; this.modelNodeIds = modelNodeIds; @@ -113,8 +120,7 @@ public static MLDeployModelRequest fromActionRequest(ActionRequest actionRequest return (MLDeployModelRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLDeployModelRequest(input); @@ -125,4 +131,4 @@ public static MLDeployModelRequest fromActionRequest(ActionRequest actionRequest } -} \ No newline at end of file +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponse.java index ca35af68f0..eefcee7de5 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponse.java @@ -5,7 +5,11 @@ package org.opensearch.ml.common.transport.deploy; -import lombok.Getter; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -16,10 +20,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.MLTaskType; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; +import lombok.Getter; @Getter public class MLDeployModelResponse extends ActionResponse implements ToXContentObject { @@ -41,7 +42,7 @@ public MLDeployModelResponse(StreamInput in) throws IOException { public MLDeployModelResponse(String taskId, MLTaskType mlTaskType, String status) { this.taskId = taskId; this.taskType = mlTaskType; - this.status= status; + this.status = status; } @Override @@ -68,8 +69,7 @@ public static MLDeployModelResponse fromActionResponse(ActionResponse actionResp return (MLDeployModelResponse) actionResponse; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionResponse.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLDeployModelResponse(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskRequest.java index e772b78d2d..d998ea71de 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskRequest.java @@ -5,29 +5,30 @@ package org.opensearch.ml.common.transport.execute; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.NonNull; -import lombok.ToString; -import lombok.experimental.FieldDefaults; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.ml.common.MLCommonsClassLoader; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLCommonsClassLoader; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.transport.MLTaskRequest; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.NonNull; +import lombok.ToString; +import lombok.experimental.FieldDefaults; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @@ -64,7 +65,7 @@ public void writeTo(StreamOutput out) throws IOException { @Override public ActionRequestValidationException validate() { ActionRequestValidationException exception = null; - if(this.input == null) { + if (this.input == null) { exception = addValidationError("ML input can't be null", exception); } else { if (this.input.getFunctionName() == null) { @@ -75,14 +76,12 @@ public ActionRequestValidationException validate() { return exception; } - public static MLExecuteTaskRequest fromActionRequest(ActionRequest actionRequest) { if (actionRequest instanceof MLExecuteTaskRequest) { return (MLExecuteTaskRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLExecuteTaskRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskResponse.java index dafaf03281..e49116a8e6 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskResponse.java @@ -5,10 +5,11 @@ package org.opensearch.ml.common.transport.execute; -import lombok.Builder; -import lombok.Getter; -import lombok.NonNull; -import lombok.ToString; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -20,10 +21,10 @@ import org.opensearch.ml.common.MLCommonsClassLoader; import org.opensearch.ml.common.output.Output; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; +import lombok.Builder; +import lombok.Getter; +import lombok.NonNull; +import lombok.ToString; @Getter @ToString @@ -62,8 +63,7 @@ public static MLExecuteTaskResponse fromActionResponse(ActionResponse actionResp return (MLExecuteTaskResponse) actionResponse; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionResponse.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLExecuteTaskResponse(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/forward/MLForwardInput.java b/common/src/main/java/org/opensearch/ml/common/transport/forward/MLForwardInput.java index 624cec3c7d..902603cf03 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/forward/MLForwardInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/forward/MLForwardInput.java @@ -5,9 +5,8 @@ package org.opensearch.ml.common.transport.forward; -import lombok.Builder; -import lombok.Data; -import lombok.extern.log4j.Log4j2; +import java.io.IOException; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -15,7 +14,9 @@ import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; -import java.io.IOException; +import lombok.Builder; +import lombok.Data; +import lombok.extern.log4j.Log4j2; @Data @Log4j2 @@ -32,9 +33,17 @@ public class MLForwardInput implements Writeable { private MLRegisterModelInput registerModelInput; @Builder(toBuilder = true) - public MLForwardInput(String taskId, String modelId, String workerNodeId, MLForwardRequestType requestType, - MLTask mlTask, MLInput modelInput, - String error, String[] workerNodes, MLRegisterModelInput registerModelInput) { + public MLForwardInput( + String taskId, + String modelId, + String workerNodeId, + MLForwardRequestType requestType, + MLTask mlTask, + MLInput modelInput, + String error, + String[] workerNodes, + MLRegisterModelInput registerModelInput + ) { this.taskId = taskId; this.modelId = modelId; this.workerNodeId = workerNodeId; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/forward/MLForwardRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/forward/MLForwardRequest.java index 7d2949fd3a..c029e81bd2 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/forward/MLForwardRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/forward/MLForwardRequest.java @@ -5,12 +5,13 @@ package org.opensearch.ml.common.transport.forward; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; -import lombok.experimental.FieldDefaults; -import lombok.extern.log4j.Log4j2; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -18,12 +19,12 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; +import lombok.extern.log4j.Log4j2; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @@ -64,8 +65,7 @@ public static MLForwardRequest fromActionRequest(ActionRequest actionRequest) { return (MLForwardRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLForwardRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/forward/MLForwardResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/forward/MLForwardResponse.java index ff51103671..f873c8a4b9 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/forward/MLForwardResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/forward/MLForwardResponse.java @@ -5,9 +5,11 @@ package org.opensearch.ml.common.transport.forward; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -18,10 +20,9 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.output.MLOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; @Getter @ToString @@ -36,7 +37,6 @@ public MLForwardResponse(String status, MLOutput mlOutput) { this.mlOutput = mlOutput; } - public MLForwardResponse(StreamInput in) throws IOException { super(in); status = in.readOptionalString(); @@ -70,8 +70,7 @@ public static MLForwardResponse fromActionResponse(ActionResponse actionResponse return (MLForwardResponse) actionResponse; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionResponse.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLForwardResponse(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelDeleteAction.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelDeleteAction.java index 6886fc57d6..8374eb5f9f 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelDeleteAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelDeleteAction.java @@ -12,5 +12,7 @@ public class MLModelDeleteAction extends ActionType { public static final MLModelDeleteAction INSTANCE = new MLModelDeleteAction(); public static final String NAME = "cluster:admin/opensearch/ml/models/delete"; - private MLModelDeleteAction() { super(NAME, DeleteResponse::new);} + private MLModelDeleteAction() { + super(NAME, DeleteResponse::new); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelDeleteRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelDeleteRequest.java index a42cf1d071..4c57c5912c 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelDeleteRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelDeleteRequest.java @@ -5,8 +5,13 @@ package org.opensearch.ml.common.transport.model; -import lombok.Builder; -import lombok.Getter; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -14,12 +19,8 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.Builder; +import lombok.Getter; public class MLModelDeleteRequest extends ActionRequest { @Getter @@ -54,11 +55,10 @@ public ActionRequestValidationException validate() { public static MLModelDeleteRequest fromActionRequest(ActionRequest actionRequest) { if (actionRequest instanceof MLModelDeleteRequest) { - return (MLModelDeleteRequest)actionRequest; + return (MLModelDeleteRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLModelDeleteRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetAction.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetAction.java index 37e3831404..dd47e8cdee 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetAction.java @@ -11,5 +11,7 @@ public class MLModelGetAction extends ActionType { public static final MLModelGetAction INSTANCE = new MLModelGetAction(); public static final String NAME = "cluster:admin/opensearch/ml/models/get"; - private MLModelGetAction() { super(NAME, MLModelGetResponse::new);} + private MLModelGetAction() { + super(NAME, MLModelGetResponse::new); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetRequest.java index 7cad570f1d..8a6e4c2c1e 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetRequest.java @@ -5,11 +5,13 @@ package org.opensearch.ml.common.transport.model; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; -import lombok.experimental.FieldDefaults; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -17,12 +19,11 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @@ -71,11 +72,10 @@ public ActionRequestValidationException validate() { public static MLModelGetRequest fromActionRequest(ActionRequest actionRequest) { if (actionRequest instanceof MLModelGetRequest) { - return (MLModelGetRequest)actionRequest; + return (MLModelGetRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLModelGetRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetResponse.java index b9a1040474..ec91a4ea43 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetResponse.java @@ -5,9 +5,11 @@ package org.opensearch.ml.common.transport.model; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -18,10 +20,9 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.MLModel; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; @Getter @ToString @@ -34,14 +35,13 @@ public MLModelGetResponse(MLModel mlModel) { this.mlModel = mlModel; } - public MLModelGetResponse(StreamInput in) throws IOException { super(in); mlModel = mlModel.fromStream(in); } @Override - public void writeTo(StreamOutput out) throws IOException{ + public void writeTo(StreamOutput out) throws IOException { mlModel.writeTo(out); } @@ -55,8 +55,7 @@ public static MLModelGetResponse fromActionResponse(ActionResponse actionRespons return (MLModelGetResponse) actionResponse; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionResponse.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLModelGetResponse(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java index 03047cf692..81dae3a903 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java @@ -5,9 +5,14 @@ package org.opensearch.ml.common.transport.model; -import lombok.Data; -import lombok.Builder; -import lombok.Getter; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.MLModel.allowedInterfaceFieldKeys; +import static org.opensearch.ml.common.utils.StringUtils.filteredParameterMap; + +import java.io.IOException; +import java.time.Instant; +import java.util.Map; + import org.opensearch.Version; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -17,22 +22,17 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.controller.MLRateLimiter; import org.opensearch.ml.common.model.Guardrails; import org.opensearch.ml.common.model.MLDeploySetting; import org.opensearch.ml.common.model.MLModelConfig; -import org.opensearch.ml.common.controller.MLRateLimiter; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; -import java.io.IOException; -import java.time.Instant; -import java.util.HashMap; -import java.util.Map; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.MLModel.allowedInterfaceFieldKeys; -import static org.opensearch.ml.common.utils.StringUtils.filteredParameterMap; +import lombok.Builder; +import lombok.Data; +import lombok.Getter; @Data public class MLUpdateModelInput implements ToXContentObject, Writeable { @@ -74,10 +74,23 @@ public class MLUpdateModelInput implements ToXContentObject, Writeable { private Map modelInterface; @Builder(toBuilder = true) - public MLUpdateModelInput(String modelId, String description, String version, String name, String modelGroupId, - Boolean isEnabled, MLRateLimiter rateLimiter, MLModelConfig modelConfig, MLDeploySetting deploySetting, - Connector updatedConnector, String connectorId, MLCreateConnectorInput connector, Instant lastUpdateTime, - Guardrails guardrails, Map modelInterface) { + public MLUpdateModelInput( + String modelId, + String description, + String version, + String name, + String modelGroupId, + Boolean isEnabled, + MLRateLimiter rateLimiter, + MLModelConfig modelConfig, + MLDeploySetting deploySetting, + Connector updatedConnector, + String connectorId, + MLCreateConnectorInput connector, + Instant lastUpdateTime, + Guardrails guardrails, + Map modelInterface + ) { this.modelId = modelId; this.description = description; this.version = version; @@ -160,7 +173,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (deploySetting != null) { builder.field(DEPLOY_SETTING_FIELD, deploySetting); } - // Notice that we serialize the updatedConnector to the connector field, in order to be compatible with original internal connector field format. + // Notice that we serialize the updatedConnector to the connector field, in order to be compatible with original internal connector + // field format. if (updatedConnector != null) { builder.field(CONNECTOR_FIELD, updatedConnector); } @@ -301,8 +315,22 @@ public static MLUpdateModelInput parse(XContentParser parser) throws IOException } // Model ID can only be set through RestRequest. Model version can only be set // automatically. - return new MLUpdateModelInput(modelId, description, version, name, modelGroupId, isEnabled, rateLimiter, - modelConfig, deploySetting, updatedConnector, connectorId, connector, lastUpdateTime, guardrails, - modelInterface); + return new MLUpdateModelInput( + modelId, + description, + version, + name, + modelGroupId, + isEnabled, + rateLimiter, + modelConfig, + deploySetting, + updatedConnector, + connectorId, + connector, + lastUpdateTime, + guardrails, + modelInterface + ); } -} \ No newline at end of file +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequest.java index b589f71ed4..61524689f7 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequest.java @@ -5,11 +5,13 @@ package org.opensearch.ml.common.transport.model; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; -import lombok.experimental.FieldDefaults; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -17,18 +19,17 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @ToString public class MLUpdateModelRequest extends ActionRequest { - + MLUpdateModelInput updateModelInput; @Builder @@ -57,13 +58,12 @@ public void writeTo(StreamOutput out) throws IOException { this.updateModelInput.writeTo(out); } - public static MLUpdateModelRequest fromActionRequest(ActionRequest actionRequest){ + public static MLUpdateModelRequest fromActionRequest(ActionRequest actionRequest) { if (actionRequest instanceof MLUpdateModelRequest) { return (MLUpdateModelRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput in = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLUpdateModelRequest(in); @@ -72,4 +72,4 @@ public static MLUpdateModelRequest fromActionRequest(ActionRequest actionRequest throw new UncheckedIOException("Failed to parse ActionRequest into MLUpdateModelRequest", e); } } -} \ No newline at end of file +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteAction.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteAction.java index 7acd877c3a..434ace5a63 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteAction.java @@ -12,5 +12,7 @@ public class MLModelGroupDeleteAction extends ActionType { public static final MLModelGroupDeleteAction INSTANCE = new MLModelGroupDeleteAction(); public static final String NAME = "cluster:admin/opensearch/ml/model_groups/delete"; - private MLModelGroupDeleteAction() { super(NAME, DeleteResponse::new);} + private MLModelGroupDeleteAction() { + super(NAME, DeleteResponse::new); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequest.java index 8c5326ab8d..86a1d093ee 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequest.java @@ -5,8 +5,13 @@ package org.opensearch.ml.common.transport.model_group; -import lombok.Builder; -import lombok.Getter; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -14,12 +19,8 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.Builder; +import lombok.Getter; public class MLModelGroupDeleteRequest extends ActionRequest { @Getter @@ -54,11 +55,10 @@ public ActionRequestValidationException validate() { public static MLModelGroupDeleteRequest fromActionRequest(ActionRequest actionRequest) { if (actionRequest instanceof MLModelGroupDeleteRequest) { - return (MLModelGroupDeleteRequest)actionRequest; + return (MLModelGroupDeleteRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLModelGroupDeleteRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetAction.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetAction.java index 2a8177eda5..15bdb1cee1 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetAction.java @@ -11,5 +11,7 @@ public class MLModelGroupGetAction extends ActionType { public static final MLModelGroupGetAction INSTANCE = new MLModelGroupGetAction(); public static final String NAME = "cluster:admin/opensearch/ml/model_groups/get"; - private MLModelGroupGetAction() { super(NAME, MLModelGroupGetResponse::new);} -} \ No newline at end of file + private MLModelGroupGetAction() { + super(NAME, MLModelGroupGetResponse::new); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetRequest.java index 265b16b9d1..a3a3d0fa57 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetRequest.java @@ -5,11 +5,13 @@ package org.opensearch.ml.common.transport.model_group; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; -import lombok.experimental.FieldDefaults; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -17,12 +19,11 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @@ -60,11 +61,10 @@ public ActionRequestValidationException validate() { public static MLModelGroupGetRequest fromActionRequest(ActionRequest actionRequest) { if (actionRequest instanceof MLModelGroupGetRequest) { - return (MLModelGroupGetRequest)actionRequest; + return (MLModelGroupGetRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLModelGroupGetRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetResponse.java index 90775e09c4..45581b3c12 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetResponse.java @@ -5,9 +5,11 @@ package org.opensearch.ml.common.transport.model_group; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -17,10 +19,9 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.MLModelGroup; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; @Getter @ToString @@ -33,14 +34,13 @@ public MLModelGroupGetResponse(MLModelGroup mlModelGroup) { this.mlModelGroup = mlModelGroup; } - public MLModelGroupGetResponse(StreamInput in) throws IOException { super(in); mlModelGroup = mlModelGroup.fromStream(in); } @Override - public void writeTo(StreamOutput out) throws IOException{ + public void writeTo(StreamOutput out) throws IOException { mlModelGroup.writeTo(out); } @@ -54,8 +54,7 @@ public static MLModelGroupGetResponse fromActionResponse(ActionResponse actionRe return (MLModelGroupGetResponse) actionResponse; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionResponse.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLModelGroupGetResponse(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInput.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInput.java index c686d4bef5..8f4162f11f 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInput.java @@ -5,8 +5,14 @@ package org.opensearch.ml.common.transport.model_group; -import lombok.Builder; -import lombok.Data; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Objects; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -15,22 +21,17 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.AccessMode; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Locale; -import java.util.Objects; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Data; @Data -public class MLRegisterModelGroupInput implements ToXContentObject, Writeable{ +public class MLRegisterModelGroupInput implements ToXContentObject, Writeable { - public static final String NAME_FIELD = "name"; //mandatory - public static final String DESCRIPTION_FIELD = "description"; //optional - public static final String BACKEND_ROLES_FIELD = "backend_roles"; //optional - public static final String MODEL_ACCESS_MODE = "access_mode"; //optional - public static final String ADD_ALL_BACKEND_ROLES = "add_all_backend_roles"; //optional + public static final String NAME_FIELD = "name"; // mandatory + public static final String DESCRIPTION_FIELD = "description"; // optional + public static final String BACKEND_ROLES_FIELD = "backend_roles"; // optional + public static final String MODEL_ACCESS_MODE = "access_mode"; // optional + public static final String ADD_ALL_BACKEND_ROLES = "add_all_backend_roles"; // optional private String name; private String description; @@ -39,7 +40,13 @@ public class MLRegisterModelGroupInput implements ToXContentObject, Writeable{ private Boolean isAddAllBackendRoles; @Builder(toBuilder = true) - public MLRegisterModelGroupInput(String name, String description, List backendRoles, AccessMode modelAccessMode, Boolean isAddAllBackendRoles) { + public MLRegisterModelGroupInput( + String name, + String description, + List backendRoles, + AccessMode modelAccessMode, + Boolean isAddAllBackendRoles + ) { this.name = Objects.requireNonNull(name, "model group name must not be null"); this.description = description; this.backendRoles = backendRoles; @@ -47,7 +54,7 @@ public MLRegisterModelGroupInput(String name, String description, List b this.isAddAllBackendRoles = isAddAllBackendRoles; } - public MLRegisterModelGroupInput(StreamInput in) throws IOException{ + public MLRegisterModelGroupInput(StreamInput in) throws IOException { this.name = in.readString(); this.description = in.readOptionalString(); this.backendRoles = in.readOptionalStringList(); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequest.java index d3394191e0..4ecfa46b4b 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequest.java @@ -5,11 +5,13 @@ package org.opensearch.ml.common.transport.model_group; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; -import lombok.experimental.FieldDefaults; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -17,12 +19,11 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @@ -62,8 +63,7 @@ public static MLRegisterModelGroupRequest fromActionRequest(ActionRequest action return (MLRegisterModelGroupRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLRegisterModelGroupRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupResponse.java index 01c63d18de..83aace89f6 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupResponse.java @@ -5,7 +5,11 @@ package org.opensearch.ml.common.transport.model_group; -import lombok.Getter; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -14,10 +18,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; +import lombok.Getter; @Getter public class MLRegisterModelGroupResponse extends ActionResponse implements ToXContentObject { @@ -37,7 +38,7 @@ public MLRegisterModelGroupResponse(StreamInput in) throws IOException { public MLRegisterModelGroupResponse(String modelGroupId, String status) { this.modelGroupId = modelGroupId; - this.status= status; + this.status = status; } @Override @@ -60,8 +61,7 @@ public static MLRegisterModelGroupResponse fromActionResponse(ActionResponse act return (MLRegisterModelGroupResponse) actionResponse; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionResponse.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLRegisterModelGroupResponse(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupInput.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupInput.java index 22e612a5b1..3dd92082c8 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupInput.java @@ -5,8 +5,13 @@ package org.opensearch.ml.common.transport.model_group; -import lombok.Builder; -import lombok.Data; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -15,23 +20,18 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.AccessMode; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Locale; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Data; @Data public class MLUpdateModelGroupInput implements ToXContentObject, Writeable { - public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; //mandatory - public static final String NAME_FIELD = "name"; //optional - public static final String DESCRIPTION_FIELD = "description"; //optional - public static final String BACKEND_ROLES_FIELD = "backend_roles"; //optional - public static final String MODEL_ACCESS_MODE = "access_mode"; //optional - public static final String ADD_ALL_BACKEND_ROLES_FIELD = "add_all_backend_roles"; //optional - + public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; // mandatory + public static final String NAME_FIELD = "name"; // optional + public static final String DESCRIPTION_FIELD = "description"; // optional + public static final String BACKEND_ROLES_FIELD = "backend_roles"; // optional + public static final String MODEL_ACCESS_MODE = "access_mode"; // optional + public static final String ADD_ALL_BACKEND_ROLES_FIELD = "add_all_backend_roles"; // optional private String modelGroupID; private String name; @@ -41,7 +41,14 @@ public class MLUpdateModelGroupInput implements ToXContentObject, Writeable { private Boolean isAddAllBackendRoles; @Builder(toBuilder = true) - public MLUpdateModelGroupInput(String modelGroupID, String name, String description, List backendRoles, AccessMode modelAccessMode, Boolean isAddAllBackendRoles) { + public MLUpdateModelGroupInput( + String modelGroupID, + String name, + String description, + List backendRoles, + AccessMode modelAccessMode, + Boolean isAddAllBackendRoles + ) { this.modelGroupID = modelGroupID; this.name = name; this.description = description; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequest.java index aecb62a8d2..e3f103dcf3 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequest.java @@ -5,11 +5,13 @@ package org.opensearch.ml.common.transport.model_group; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; -import lombok.experimental.FieldDefaults; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -17,12 +19,11 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @@ -62,8 +63,7 @@ public static MLUpdateModelGroupRequest fromActionRequest(ActionRequest actionRe return (MLUpdateModelGroupRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLUpdateModelGroupRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupResponse.java index 23bec3b0aa..fbe5795c4f 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupResponse.java @@ -5,14 +5,15 @@ package org.opensearch.ml.common.transport.model_group; -import lombok.Getter; +import java.io.IOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; +import lombok.Getter; @Getter public class MLUpdateModelGroupResponse extends ActionResponse implements ToXContentObject { diff --git a/common/src/main/java/org/opensearch/ml/common/transport/package-info.java b/common/src/main/java/org/opensearch/ml/common/transport/package-info.java index 77111bf8f4..d01f4f9512 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/package-info.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/package-info.java @@ -3,4 +3,4 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.ml.common.transport; \ No newline at end of file +package org.opensearch.ml.common.transport; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequest.java index 8355f1485c..ffb45a50b2 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequest.java @@ -11,12 +11,7 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.UncheckedIOException; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.Setter; -import lombok.ToString; -import lombok.experimental.FieldDefaults; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.commons.authuser.User; @@ -27,6 +22,13 @@ import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.transport.MLTaskRequest; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; + @Getter @FieldDefaults(level = AccessLevel.PRIVATE) @ToString @@ -87,14 +89,12 @@ public ActionRequestValidationException validate() { return exception; } - public static MLPredictionTaskRequest fromActionRequest(ActionRequest actionRequest) { if (actionRequest instanceof MLPredictionTaskRequest) { return (MLPredictionTaskRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLPredictionTaskRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java index 9eb5ba6b4f..2db003f005 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java @@ -5,8 +5,18 @@ package org.opensearch.ml.common.transport.register; -import lombok.Builder; -import lombok.Data; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.MLModel.allowedInterfaceFieldKeys; +import static org.opensearch.ml.common.connector.Connector.createConnector; +import static org.opensearch.ml.common.utils.StringUtils.filteredParameterMap; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; + import org.opensearch.Version; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -19,28 +29,17 @@ import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.controller.MLRateLimiter; import org.opensearch.ml.common.model.Guardrails; -import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.MLDeploySetting; -import org.opensearch.ml.common.controller.MLRateLimiter; +import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MetricsCorrelationModelConfig; import org.opensearch.ml.common.model.QuestionAnsweringModelConfig; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; -import java.io.IOException; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.Objects; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.MLModel.allowedInterfaceFieldKeys; -import static org.opensearch.ml.common.connector.Connector.createConnector; -import static org.opensearch.ml.common.utils.StringUtils.filteredParameterMap; - +import lombok.Builder; +import lombok.Data; /** * ML input data: algirithm name, parameters and input data set. @@ -105,29 +104,31 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable { private Map modelInterface; @Builder(toBuilder = true) - public MLRegisterModelInput(FunctionName functionName, - String modelName, - String modelGroupId, - String version, - String description, - Boolean isEnabled, - MLRateLimiter rateLimiter, - String url, - String hashValue, - MLModelFormat modelFormat, - MLModelConfig modelConfig, - MLDeploySetting deploySetting, - boolean deployModel, - String[] modelNodeIds, - Connector connector, - String connectorId, - List backendRoles, - Boolean addAllBackendRoles, - AccessMode accessMode, - Boolean doesVersionCreateModelGroup, - Boolean isHidden, - Guardrails guardrails, - Map modelInterface) { + public MLRegisterModelInput( + FunctionName functionName, + String modelName, + String modelGroupId, + String version, + String description, + Boolean isEnabled, + MLRateLimiter rateLimiter, + String url, + String hashValue, + MLModelFormat modelFormat, + MLModelConfig modelConfig, + MLDeploySetting deploySetting, + boolean deployModel, + String[] modelNodeIds, + Connector connector, + String connectorId, + List backendRoles, + Boolean addAllBackendRoles, + AccessMode accessMode, + Boolean doesVersionCreateModelGroup, + Boolean isHidden, + Guardrails guardrails, + Map modelInterface + ) { this.functionName = Objects.requireNonNullElse(functionName, FunctionName.TEXT_EMBEDDING); if (modelName == null) { throw new IllegalArgumentException("model name is null"); @@ -136,11 +137,13 @@ public MLRegisterModelInput(FunctionName functionName, if (modelFormat == null) { throw new IllegalArgumentException("model format is null"); } - if (url != null && modelConfig == null && functionName != FunctionName.SPARSE_TOKENIZE - && functionName != FunctionName.SPARSE_ENCODING) { // The tokenize model doesn't require a model - // configuration. Currently, we only support one - // type of sparse model, which is pretrained, and - // it doesn't necessitate a model configuration. + if (url != null + && modelConfig == null + && functionName != FunctionName.SPARSE_TOKENIZE + && functionName != FunctionName.SPARSE_ENCODING) { // The tokenize model doesn't require a model + // configuration. Currently, we only support one + // type of sparse model, which is pretrained, and + // it doesn't necessitate a model configuration. throw new IllegalArgumentException("model config is null"); } } @@ -378,8 +381,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } - public static MLRegisterModelInput parse(XContentParser parser, String modelName, String version, - boolean deployModel) throws IOException { + public static MLRegisterModelInput parse(XContentParser parser, String modelName, String version, boolean deployModel) + throws IOException { FunctionName functionName = null; String modelGroupId = null; Boolean isEnabled = null; @@ -481,10 +484,31 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName break; } } - return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, isEnabled, - rateLimiter, url, hashValue, modelFormat, modelConfig, deploySetting, deployModel, modelNodeIds.toArray(new String[0]), - connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup, - isHidden, guardrails, modelInterface); + return new MLRegisterModelInput( + functionName, + modelName, + modelGroupId, + version, + description, + isEnabled, + rateLimiter, + url, + hashValue, + modelFormat, + modelConfig, + deploySetting, + deployModel, + modelNodeIds.toArray(new String[0]), + connector, + connectorId, + backendRoles, + addAllBackendRoles, + accessMode, + doesVersionCreateModelGroup, + isHidden, + guardrails, + modelInterface + ); } public static MLRegisterModelInput parse(XContentParser parser, boolean deployModel) throws IOException { @@ -598,9 +622,30 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo break; } } - return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, isEnabled, rateLimiter, - url, hashValue, modelFormat, modelConfig, deploySetting, deployModel, modelNodeIds.toArray(new String[0]), - connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup, - isHidden, guardrails, modelInterface); + return new MLRegisterModelInput( + functionName, + name, + modelGroupId, + version, + description, + isEnabled, + rateLimiter, + url, + hashValue, + modelFormat, + modelConfig, + deploySetting, + deployModel, + modelNodeIds.toArray(new String[0]), + connector, + connectorId, + backendRoles, + addAllBackendRoles, + accessMode, + doesVersionCreateModelGroup, + isHidden, + guardrails, + modelInterface + ); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelRequest.java index b57b65c524..adff46812f 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelRequest.java @@ -5,11 +5,13 @@ package org.opensearch.ml.common.transport.register; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; -import lombok.experimental.FieldDefaults; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -17,12 +19,11 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @@ -62,8 +63,7 @@ public static MLRegisterModelRequest fromActionRequest(ActionRequest actionReque return (MLRegisterModelRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLRegisterModelRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponse.java index 18c64c6c5f..2714ddef3e 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponse.java @@ -5,7 +5,11 @@ package org.opensearch.ml.common.transport.register; -import lombok.Getter; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -14,12 +18,8 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.ml.common.transport.MLTaskResponse; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; +import lombok.Getter; @Getter public class MLRegisterModelResponse extends ActionResponse implements ToXContentObject { @@ -40,12 +40,12 @@ public MLRegisterModelResponse(StreamInput in) throws IOException { public MLRegisterModelResponse(String taskId, String status) { this.taskId = taskId; - this.status= status; + this.status = status; } public MLRegisterModelResponse(String taskId, String status, String modelId) { this.taskId = taskId; - this.status= status; + this.status = status; this.modelId = modelId; } @@ -73,8 +73,7 @@ public static MLRegisterModelResponse fromActionResponse(ActionResponse actionRe return (MLRegisterModelResponse) actionResponse; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionResponse.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLRegisterModelResponse(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpInput.java b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpInput.java index 7ad34321b8..017b97761b 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpInput.java @@ -5,16 +5,16 @@ package org.opensearch.ml.common.transport.sync; -import lombok.Builder; -import lombok.Data; -import org.opensearch.Version; +import java.io.IOException; +import java.util.Map; +import java.util.Set; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; -import java.io.IOException; -import java.util.Map; -import java.util.Set; +import lombok.Builder; +import lombok.Data; @Data public class MLSyncUpInput implements Writeable { @@ -37,14 +37,16 @@ public class MLSyncUpInput implements Writeable { private Map deployToAllNodes; @Builder - public MLSyncUpInput(boolean getDeployedModels, - Map addedWorkerNodes, - Map removedWorkerNodes, - Map> modelRoutingTable, - Map> runningDeployModelTasks, - Map deployToAllNodes, - boolean clearRoutingTable, - boolean syncRunningDeployModelTasks) { + public MLSyncUpInput( + boolean getDeployedModels, + Map addedWorkerNodes, + Map removedWorkerNodes, + Map> modelRoutingTable, + Map> runningDeployModelTasks, + Map deployToAllNodes, + boolean clearRoutingTable, + boolean syncRunningDeployModelTasks + ) { this.getDeployedModels = getDeployedModels; this.addedWorkerNodes = addedWorkerNodes; this.removedWorkerNodes = removedWorkerNodes; @@ -55,7 +57,7 @@ public MLSyncUpInput(boolean getDeployedModels, this.syncRunningDeployModelTasks = syncRunningDeployModelTasks; } - public MLSyncUpInput(){} + public MLSyncUpInput() {} public MLSyncUpInput(StreamInput in) throws IOException { this.getDeployedModels = in.readBoolean(); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeRequest.java index a7a468502a..d7042d4ad4 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeRequest.java @@ -5,12 +5,13 @@ package org.opensearch.ml.common.transport.sync; -import lombok.Getter; +import java.io.IOException; + import org.opensearch.action.support.nodes.BaseNodeRequest; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.IOException; +import lombok.Getter; public class MLSyncUpNodeRequest extends BaseNodeRequest { @Getter diff --git a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponse.java index 74893ec91e..8f8d960b38 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponse.java @@ -5,8 +5,8 @@ package org.opensearch.ml.common.transport.sync; -import lombok.Getter; -import lombok.extern.log4j.Log4j2; +import java.io.IOException; + import org.opensearch.Version; import org.opensearch.action.support.nodes.BaseNodeResponse; import org.opensearch.cluster.node.DiscoveryNode; @@ -14,11 +14,12 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.ml.common.model.MLDeploySetting; -import java.io.IOException; +import lombok.Getter; +import lombok.extern.log4j.Log4j2; @Log4j2 @Getter -public class MLSyncUpNodeResponse extends BaseNodeResponse { +public class MLSyncUpNodeResponse extends BaseNodeResponse { private String modelStatus; private String[] deployedModelIds; @@ -26,8 +27,14 @@ public class MLSyncUpNodeResponse extends BaseNodeResponse { private String[] runningDeployModelTaskIds; // deploy model task ids which is running private String[] expiredModelIds; - public MLSyncUpNodeResponse(DiscoveryNode node, String modelStatus, String[] deployedModelIds, String[] runningDeployModelIds, - String[] runningDeployModelTaskIds, String[] expiredModelIds) { + public MLSyncUpNodeResponse( + DiscoveryNode node, + String modelStatus, + String[] deployedModelIds, + String[] runningDeployModelIds, + String[] runningDeployModelTaskIds, + String[] expiredModelIds + ) { super(node); this.modelStatus = modelStatus; this.deployedModelIds = deployedModelIds; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodesRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodesRequest.java index 56ec920f5f..d66af5d8f7 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodesRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodesRequest.java @@ -5,13 +5,14 @@ package org.opensearch.ml.common.transport.sync; -import lombok.Getter; +import java.io.IOException; + import org.opensearch.action.support.nodes.BaseNodesRequest; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.IOException; +import lombok.Getter; public class MLSyncUpNodesRequest extends BaseNodesRequest { diff --git a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodesResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodesResponse.java index dee614685c..ecfd42f464 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodesResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodesResponse.java @@ -5,15 +5,15 @@ package org.opensearch.ml.common.transport.sync; +import java.io.IOException; +import java.util.List; + import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.nodes.BaseNodesResponse; import org.opensearch.cluster.ClusterName; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.IOException; -import java.util.List; - public class MLSyncUpNodesResponse extends BaseNodesResponse { public MLSyncUpNodesResponse(StreamInput in) throws IOException { diff --git a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpResponse.java index 6c4f4ed82f..1353edcb83 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpResponse.java @@ -5,7 +5,8 @@ package org.opensearch.ml.common.transport.sync; -import lombok.Getter; +import java.io.IOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -13,7 +14,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; +import lombok.Getter; @Getter public class MLSyncUpResponse extends ActionResponse implements ToXContentObject { @@ -27,7 +28,7 @@ public MLSyncUpResponse(StreamInput in) throws IOException { } public MLSyncUpResponse(String status) { - this.status= status; + this.status = status; } @Override diff --git a/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskDeleteAction.java b/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskDeleteAction.java index 7b00b6509a..5aed868589 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskDeleteAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskDeleteAction.java @@ -12,5 +12,7 @@ public class MLTaskDeleteAction extends ActionType { public static final MLTaskDeleteAction INSTANCE = new MLTaskDeleteAction(); public static final String NAME = "cluster:admin/opensearch/ml/tasks/delete"; - private MLTaskDeleteAction() { super(NAME, DeleteResponse::new);} + private MLTaskDeleteAction() { + super(NAME, DeleteResponse::new); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskDeleteRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskDeleteRequest.java index a7782a60ea..b109c52b42 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskDeleteRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskDeleteRequest.java @@ -5,8 +5,13 @@ package org.opensearch.ml.common.transport.task; -import lombok.Builder; -import lombok.Getter; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -14,12 +19,8 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.Builder; +import lombok.Getter; public class MLTaskDeleteRequest extends ActionRequest { @Getter @@ -54,11 +55,10 @@ public ActionRequestValidationException validate() { public static MLTaskDeleteRequest fromActionRequest(ActionRequest actionRequest) { if (actionRequest instanceof MLTaskDeleteRequest) { - return (MLTaskDeleteRequest)actionRequest; + return (MLTaskDeleteRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLTaskDeleteRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskGetAction.java b/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskGetAction.java index 4aaa143a1f..2d76df4dc7 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskGetAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskGetAction.java @@ -11,5 +11,7 @@ public class MLTaskGetAction extends ActionType { public static final MLTaskGetAction INSTANCE = new MLTaskGetAction(); public static final String NAME = "cluster:admin/opensearch/ml/tasks/get"; - private MLTaskGetAction() { super(NAME, MLTaskGetResponse::new);} + private MLTaskGetAction() { + super(NAME, MLTaskGetResponse::new); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskGetRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskGetRequest.java index 06145adef7..3feb5c661d 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskGetRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskGetRequest.java @@ -5,9 +5,13 @@ package org.opensearch.ml.common.transport.task; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; -import lombok.Builder; -import lombok.Getter; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -15,12 +19,8 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.Builder; +import lombok.Getter; public class MLTaskGetRequest extends ActionRequest { @Getter @@ -55,11 +55,10 @@ public ActionRequestValidationException validate() { public static MLTaskGetRequest fromActionRequest(ActionRequest actionRequest) { if (actionRequest instanceof MLTaskGetRequest) { - return (MLTaskGetRequest)actionRequest; + return (MLTaskGetRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLTaskGetRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskGetResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskGetResponse.java index cc4d51192a..071ab82682 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskGetResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskGetResponse.java @@ -5,8 +5,11 @@ package org.opensearch.ml.common.transport.task; -import lombok.Builder; -import lombok.Getter; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -17,10 +20,8 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.MLTask; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; +import lombok.Builder; +import lombok.Getter; @Getter public class MLTaskGetResponse extends ActionResponse implements ToXContentObject { @@ -37,7 +38,7 @@ public MLTaskGetResponse(StreamInput in) throws IOException { } @Override - public void writeTo(StreamOutput out) throws IOException{ + public void writeTo(StreamOutput out) throws IOException { mlTask.writeTo(out); } @@ -51,8 +52,7 @@ public static MLTaskGetResponse fromActionResponse(ActionResponse actionResponse return (MLTaskGetResponse) actionResponse; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionResponse.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLTaskGetResponse(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskSearchAction.java b/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskSearchAction.java index 13b38aa687..cd2636e991 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskSearchAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskSearchAction.java @@ -2,7 +2,6 @@ import org.opensearch.action.ActionType; import org.opensearch.action.search.SearchResponse; -import org.opensearch.ml.common.transport.model.MLModelSearchAction; public class MLTaskSearchAction extends ActionType { // External Action which used for public facing RestAPIs. diff --git a/common/src/main/java/org/opensearch/ml/common/transport/tools/MLGetToolAction.java b/common/src/main/java/org/opensearch/ml/common/transport/tools/MLGetToolAction.java index 468d53d34a..bd61708de8 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/tools/MLGetToolAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/tools/MLGetToolAction.java @@ -13,4 +13,4 @@ public class MLGetToolAction extends ActionType { public MLGetToolAction() { super(NAME, MLToolGetResponse::new); } -} \ No newline at end of file +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/tools/MLListToolsAction.java b/common/src/main/java/org/opensearch/ml/common/transport/tools/MLListToolsAction.java index 3ec6b4c99e..58e9f300cf 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/tools/MLListToolsAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/tools/MLListToolsAction.java @@ -13,4 +13,4 @@ public class MLListToolsAction extends ActionType { public MLListToolsAction() { super(NAME, MLToolsListResponse::new); } -} \ No newline at end of file +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/tools/MLToolGetRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/tools/MLToolGetRequest.java index e89e506fe3..d64d591f81 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/tools/MLToolGetRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/tools/MLToolGetRequest.java @@ -4,11 +4,14 @@ */ package org.opensearch.ml.common.transport.tools; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; -import lombok.experimental.FieldDefaults; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.List; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -17,13 +20,11 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.ml.common.ToolMetadata; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.util.List; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @@ -66,11 +67,10 @@ public ActionRequestValidationException validate() { public static MLToolGetRequest fromActionRequest(ActionRequest actionRequest) { if (actionRequest instanceof MLToolGetRequest) { - return (MLToolGetRequest)actionRequest; + return (MLToolGetRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLToolGetRequest(input); @@ -80,5 +80,4 @@ public static MLToolGetRequest fromActionRequest(ActionRequest actionRequest) { } } - -} \ No newline at end of file +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/tools/MLToolGetResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/tools/MLToolGetResponse.java index d4623039c8..183a5a584c 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/tools/MLToolGetResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/tools/MLToolGetResponse.java @@ -5,9 +5,11 @@ package org.opensearch.ml.common.transport.tools; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -18,10 +20,9 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.ToolMetadata; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; @Getter @ToString @@ -54,8 +55,7 @@ public static MLToolGetResponse fromActionResponse(ActionResponse actionResponse return (MLToolGetResponse) actionResponse; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionResponse.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLToolGetResponse(input); @@ -64,4 +64,4 @@ public static MLToolGetResponse fromActionResponse(ActionResponse actionResponse throw new UncheckedIOException("failed to parse ActionResponse into MLToolGetResponse", e); } } -} \ No newline at end of file +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/tools/MLToolsListRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/tools/MLToolsListRequest.java index 49575aaac2..90d94f17a4 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/tools/MLToolsListRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/tools/MLToolsListRequest.java @@ -5,11 +5,11 @@ package org.opensearch.ml.common.transport.tools; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; -import lombok.experimental.FieldDefaults; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.List; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; @@ -19,11 +19,11 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.ml.common.ToolMetadata; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.util.List; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @@ -55,11 +55,10 @@ public ActionRequestValidationException validate() { public static MLToolsListRequest fromActionRequest(ActionRequest actionRequest) { if (actionRequest instanceof MLToolsListRequest) { - return (MLToolsListRequest)actionRequest; + return (MLToolsListRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLToolsListRequest(input); @@ -69,4 +68,4 @@ public static MLToolsListRequest fromActionRequest(ActionRequest actionRequest) } } -} \ No newline at end of file +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/tools/MLToolsListResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/tools/MLToolsListResponse.java index 6f2f3cad00..87e203e7ae 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/tools/MLToolsListResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/tools/MLToolsListResponse.java @@ -4,9 +4,12 @@ */ package org.opensearch.ml.common.transport.tools; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.List; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -17,11 +20,9 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.ToolMetadata; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.util.List; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; @Getter @ToString @@ -33,6 +34,7 @@ public class MLToolsListResponse extends ActionResponse implements ToXContentObj public MLToolsListResponse(List toolMetadata) { this.toolMetadataList = toolMetadata; } + public MLToolsListResponse(StreamInput in) throws IOException { super(in); this.toolMetadataList = in.readList(ToolMetadata::new); @@ -51,7 +53,8 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, ToXContent.Pa xContentBuilder.field(ToolMetadata.TOOL_NAME_FIELD, toolMetadata.getName()); xContentBuilder.field(ToolMetadata.TOOL_DESCRIPTION_FIELD, toolMetadata.getDescription()); xContentBuilder.field(ToolMetadata.TOOL_TYPE_FIELD, toolMetadata.getType()); - xContentBuilder.field(ToolMetadata.TOOL_VERSION_FIELD, toolMetadata.getVersion() != null ? toolMetadata.getVersion() : "undefined"); + xContentBuilder + .field(ToolMetadata.TOOL_VERSION_FIELD, toolMetadata.getVersion() != null ? toolMetadata.getVersion() : "undefined"); xContentBuilder.endObject(); } xContentBuilder.endArray(); @@ -63,15 +66,13 @@ public static MLToolsListResponse fromActionResponse(ActionResponse actionRespon return (MLToolsListResponse) actionResponse; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionResponse.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLToolsListResponse(input); } - } - catch (IOException e) { + } catch (IOException e) { throw new UncheckedIOException("failed to parse ActionResponse into MLToolsListResponse", e); } } -} \ No newline at end of file +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequest.java index e03f1e8dda..45012f09eb 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequest.java @@ -5,11 +5,14 @@ package org.opensearch.ml.common.transport.training; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; -import lombok.experimental.FieldDefaults; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Objects; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -19,13 +22,11 @@ import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.transport.MLTaskRequest; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.util.Objects; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @@ -79,8 +80,7 @@ public static MLTrainingTaskRequest fromActionRequest(ActionRequest actionReques return (MLTrainingTaskRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLTrainingTaskRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelInput.java index d0e399f291..c08dffc336 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelInput.java @@ -5,8 +5,12 @@ package org.opensearch.ml.common.transport.undeploy; -import lombok.Builder; -import lombok.Data; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -15,11 +19,8 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Data; @Data public class MLUndeployModelInput implements ToXContentObject, Writeable { diff --git a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodeRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodeRequest.java index 58a20af248..cff8d7bd90 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodeRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodeRequest.java @@ -6,11 +6,13 @@ package org.opensearch.ml.common.transport.undeploy; import java.io.IOException; -import lombok.Getter; + import org.opensearch.action.support.nodes.BaseNodeRequest; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import lombok.Getter; + public class MLUndeployModelNodeRequest extends BaseNodeRequest { @Getter private MLUndeployModelNodesRequest mlUndeployModelNodesRequest; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodeResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodeResponse.java index 99a7f39882..b4f32d2302 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodeResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodeResponse.java @@ -5,7 +5,9 @@ package org.opensearch.ml.common.transport.undeploy; -import lombok.Getter; +import java.io.IOException; +import java.util.Map; + import org.opensearch.action.support.nodes.BaseNodeResponse; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.common.io.stream.StreamInput; @@ -14,8 +16,7 @@ import org.opensearch.core.xcontent.ToXContentFragment; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; -import java.util.Map; +import lombok.Getter; @Getter public class MLUndeployModelNodeResponse extends BaseNodeResponse implements ToXContentFragment { @@ -26,9 +27,10 @@ public class MLUndeployModelNodeResponse extends BaseNodeResponse implements ToX // This is to record before undeploy the model, which nodes are working nodes. private Map modelWorkerNodeBeforeRemoval; - public MLUndeployModelNodeResponse(DiscoveryNode node, - Map modelUndeployStatus, - Map modelWorkerNodeBeforeRemoval + public MLUndeployModelNodeResponse( + DiscoveryNode node, + Map modelUndeployStatus, + Map modelWorkerNodeBeforeRemoval ) { super(node); this.modelUndeployStatus = modelUndeployStatus; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesRequest.java index cea0d484fe..48b2bf7c5c 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesRequest.java @@ -5,13 +5,14 @@ package org.opensearch.ml.common.transport.undeploy; -import lombok.Getter; +import java.io.IOException; + import org.opensearch.action.support.nodes.BaseNodesRequest; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.IOException; +import lombok.Getter; public class MLUndeployModelNodesRequest extends BaseNodesRequest { diff --git a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesResponse.java index 3728f4dd8e..22976eebf5 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesResponse.java @@ -5,6 +5,9 @@ package org.opensearch.ml.common.transport.undeploy; +import java.io.IOException; +import java.util.List; + import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.nodes.BaseNodesResponse; import org.opensearch.cluster.ClusterName; @@ -15,16 +18,17 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; -import java.util.List; - public class MLUndeployModelNodesResponse extends BaseNodesResponse implements ToXContentObject { public MLUndeployModelNodesResponse(StreamInput in) throws IOException { super(new ClusterName(in), in.readList(MLUndeployModelNodeResponse::readStats), in.readList(FailedNodeException::new)); } - public MLUndeployModelNodesResponse(ClusterName clusterName, List nodes, List failures) { + public MLUndeployModelNodesResponse( + ClusterName clusterName, + List nodes, + List failures + ) { super(clusterName, nodes, failures); } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsRequest.java index c586698025..32fdfced27 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsRequest.java @@ -13,11 +13,7 @@ import java.io.UncheckedIOException; import java.util.ArrayList; import java.util.List; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; -import lombok.experimental.FieldDefaults; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -27,6 +23,12 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.transport.MLTaskRequest; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; + @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @ToString @@ -107,8 +109,7 @@ public static MLUndeployModelsRequest fromActionRequest(ActionRequest actionRequ return (MLUndeployModelsRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLUndeployModelsRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponse.java index 71fc7ef38b..ed3f503f68 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponse.java @@ -5,7 +5,11 @@ package org.opensearch.ml.common.transport.undeploy; -import lombok.Getter; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -14,10 +18,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; +import lombok.Getter; @Getter public class MLUndeployModelsResponse extends ActionResponse implements ToXContentObject { @@ -60,8 +61,7 @@ public static MLUndeployModelsResponse fromActionResponse(ActionResponse actionR return (MLUndeployModelsResponse) actionResponse; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionResponse.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLUndeployModelsResponse(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheAction.java b/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheAction.java index 8cccab63d6..eccac151f0 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheAction.java @@ -11,5 +11,7 @@ public class MLUpdateModelCacheAction extends ActionType { diff --git a/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesResponse.java index 6e26174eb6..00dffb0d98 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesResponse.java @@ -5,6 +5,9 @@ package org.opensearch.ml.common.transport.update_cache; +import java.io.IOException; +import java.util.List; + import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.nodes.BaseNodesResponse; import org.opensearch.cluster.ClusterName; @@ -15,16 +18,17 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; -import java.util.List; - public class MLUpdateModelCacheNodesResponse extends BaseNodesResponse implements ToXContentObject { public MLUpdateModelCacheNodesResponse(StreamInput in) throws IOException { super(new ClusterName(in), in.readList(MLUpdateModelCacheNodeResponse::readStats), in.readList(FailedNodeException::new)); } - public MLUpdateModelCacheNodesResponse(ClusterName clusterName, List nodes, List failures) { + public MLUpdateModelCacheNodesResponse( + ClusterName clusterName, + List nodes, + List failures + ) { super(clusterName, nodes, failures); } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaAction.java b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaAction.java index 3ee8b66805..3fdd8bb09e 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaAction.java @@ -15,4 +15,4 @@ private MLRegisterModelMetaAction() { super(NAME, MLRegisterModelMetaResponse::new); } -} \ No newline at end of file +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java index d56120aa5c..ffde87f594 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java @@ -5,8 +5,16 @@ package org.opensearch.ml.common.transport.upload_chunk; -import lombok.Builder; -import lombok.Data; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.MLModel.allowedInterfaceFieldKeys; +import static org.opensearch.ml.common.utils.StringUtils.filteredParameterMap; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Map; + import org.opensearch.Version; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -15,28 +23,20 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.controller.MLRateLimiter; import org.opensearch.ml.common.model.MLDeploySetting; import org.opensearch.ml.common.model.MLModelConfig; -import org.opensearch.ml.common.controller.MLRateLimiter; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.model.QuestionAnsweringModelConfig; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; -import java.io.IOException; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Locale; -import java.util.Map; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.MLModel.allowedInterfaceFieldKeys; -import static org.opensearch.ml.common.utils.StringUtils.filteredParameterMap; +import lombok.Builder; +import lombok.Data; @Data public class MLRegisterModelMetaInput implements ToXContentObject, Writeable { @@ -86,13 +86,28 @@ public class MLRegisterModelMetaInput implements ToXContentObject, Writeable { private Map modelInterface; @Builder(toBuilder = true) - public MLRegisterModelMetaInput(String name, FunctionName functionName, String modelGroupId, String version, - String description, Boolean isEnabled, MLRateLimiter rateLimiter, MLModelFormat modelFormat, - MLModelState modelState, Long modelContentSizeInBytes, String modelContentHashValue, - MLModelConfig modelConfig, MLDeploySetting deploySetting, Integer totalChunks, List backendRoles, - AccessMode accessMode, - Boolean isAddAllBackendRoles, - Boolean doesVersionCreateModelGroup, Boolean isHidden, Map modelInterface) { + public MLRegisterModelMetaInput( + String name, + FunctionName functionName, + String modelGroupId, + String version, + String description, + Boolean isEnabled, + MLRateLimiter rateLimiter, + MLModelFormat modelFormat, + MLModelState modelState, + Long modelContentSizeInBytes, + String modelContentHashValue, + MLModelConfig modelConfig, + MLDeploySetting deploySetting, + Integer totalChunks, + List backendRoles, + AccessMode accessMode, + Boolean isAddAllBackendRoles, + Boolean doesVersionCreateModelGroup, + Boolean isHidden, + Map modelInterface + ) { if (name == null) { throw new IllegalArgumentException("model name is null"); } @@ -107,11 +122,32 @@ public MLRegisterModelMetaInput(String name, FunctionName functionName, String m if (modelContentHashValue == null) { throw new IllegalArgumentException("model content hash value is null"); } - if (modelConfig == null && functionName != FunctionName.SPARSE_TOKENIZE - && functionName != FunctionName.SPARSE_ENCODING) { // The tokenize model doesn't require a model - // configuration. Currently, we only support one type - // of sparse model, which is pretrained, and it - // doesn't necessitate a model configuration. + if (modelConfig == null && functionName != FunctionName.SPARSE_TOKENIZE && functionName != FunctionName.SPARSE_ENCODING) { // The + // tokenize + // model + // doesn't + // require + // a + // model + // configuration. + // Currently, + // we + // only + // support + // one + // type + // of + // sparse + // model, + // which + // is + // pretrained, + // and it + // doesn't + // necessitate + // a + // model + // configuration. throw new IllegalArgumentException("model config is null"); } if (totalChunks == null) { @@ -415,10 +451,28 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc break; } } - return new MLRegisterModelMetaInput(name, functionName, modelGroupId, version, description, isEnabled, - rateLimiter, modelFormat, modelState, modelContentSizeInBytes, modelContentHashValue, modelConfig, - deploySetting, totalChunks, backendRoles, accessMode, isAddAllBackendRoles, doesVersionCreateModelGroup, - isHidden, modelInterface); + return new MLRegisterModelMetaInput( + name, + functionName, + modelGroupId, + version, + description, + isEnabled, + rateLimiter, + modelFormat, + modelState, + modelContentSizeInBytes, + modelContentHashValue, + modelConfig, + deploySetting, + totalChunks, + backendRoles, + accessMode, + isAddAllBackendRoles, + doesVersionCreateModelGroup, + isHidden, + modelInterface + ); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequest.java index dbfc9283fc..19558cfb60 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequest.java @@ -5,11 +5,13 @@ package org.opensearch.ml.common.transport.upload_chunk; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; -import lombok.experimental.FieldDefaults; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -17,12 +19,11 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @@ -62,8 +63,7 @@ public static MLRegisterModelMetaRequest fromActionRequest(ActionRequest actionR return (MLRegisterModelMetaRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLRegisterModelMetaRequest(input); @@ -73,4 +73,4 @@ public static MLRegisterModelMetaRequest fromActionRequest(ActionRequest actionR } } -} \ No newline at end of file +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaResponse.java index 4ac6220e55..62c3ef1b7f 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaResponse.java @@ -5,7 +5,8 @@ package org.opensearch.ml.common.transport.upload_chunk; -import lombok.Getter; +import java.io.IOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -13,7 +14,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; +import lombok.Getter; public class MLRegisterModelMetaResponse extends ActionResponse implements ToXContentObject { @@ -33,7 +34,7 @@ public MLRegisterModelMetaResponse(StreamInput in) throws IOException { public MLRegisterModelMetaResponse(String modelId, String status) { this.modelId = modelId; - this.status= status; + this.status = status; } @Override @@ -50,4 +51,4 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par builder.endObject(); return builder; } -} \ No newline at end of file +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkAction.java b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkAction.java index e6337f1347..1658bb6483 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkAction.java @@ -5,7 +5,6 @@ package org.opensearch.ml.common.transport.upload_chunk; - import org.opensearch.action.ActionType; public class MLUploadModelChunkAction extends ActionType { diff --git a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkInput.java b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkInput.java index 256c4b1fe4..8f1392895f 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkInput.java @@ -5,8 +5,10 @@ package org.opensearch.ml.common.transport.upload_chunk; -import lombok.Builder; -import lombok.Data; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -15,9 +17,8 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import java.io.IOException; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Data; @Data public class MLUploadModelChunkInput implements ToXContentObject, Writeable { @@ -37,7 +38,6 @@ public MLUploadModelChunkInput(String modelId, Integer chunkNumber, byte[] conte this.chunkNumber = chunkNumber; } - public MLUploadModelChunkInput(StreamInput in) throws IOException { this.modelId = in.readString(); this.chunkNumber = in.readInt(); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkRequest.java index 253d13c1ed..5edea364aa 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkRequest.java @@ -5,25 +5,25 @@ package org.opensearch.ml.common.transport.upload_chunk; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; -import lombok.experimental.FieldDefaults; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.common.io.stream.InputStreamStreamInput; - -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @@ -63,8 +63,7 @@ public static MLUploadModelChunkRequest fromActionRequest(ActionRequest actionRe return (MLUploadModelChunkRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLUploadModelChunkRequest(input); @@ -74,4 +73,4 @@ public static MLUploadModelChunkRequest fromActionRequest(ActionRequest actionRe } } -} \ No newline at end of file +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkResponse.java index b6a065a1be..de5b1603a2 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkResponse.java @@ -5,7 +5,8 @@ package org.opensearch.ml.common.transport.upload_chunk; -import lombok.Getter; +import java.io.IOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -13,20 +14,20 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; +import lombok.Getter; public class MLUploadModelChunkResponse extends ActionResponse implements ToXContentObject { public static final String STATUS_FIELD = "status"; @Getter private String status; - public MLUploadModelChunkResponse (StreamInput in) throws IOException { + public MLUploadModelChunkResponse(StreamInput in) throws IOException { super(in); this.status = in.readString(); } - public MLUploadModelChunkResponse (String status) { - this.status= status; + public MLUploadModelChunkResponse(String status) { + this.status = status; } @Override @@ -42,4 +43,3 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par return builder; } } - diff --git a/common/src/main/java/org/opensearch/ml/common/utils/IndexUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/IndexUtils.java index beccb6fdb4..32c8cd6e25 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/IndexUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/IndexUtils.java @@ -5,9 +5,10 @@ package org.opensearch.ml.common.utils; -import lombok.extern.log4j.Log4j2; import java.util.Map; +import lombok.extern.log4j.Log4j2; + @Log4j2 public class IndexUtils { diff --git a/common/src/main/java/org/opensearch/ml/common/utils/ModelInterfaceUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/ModelInterfaceUtils.java index c6a4886ba3..5c5cc5fd99 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/ModelInterfaceUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/ModelInterfaceUtils.java @@ -5,593 +5,598 @@ package org.opensearch.ml.common.utils; -import lombok.extern.log4j.Log4j2; -import org.opensearch.ml.common.MLModel; +import java.util.Map; + import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; -import java.util.Map; +import lombok.extern.log4j.Log4j2; @Log4j2 public class ModelInterfaceUtils { - private static final String GENERAL_CONVERSATIONAL_MODEL_INTERFACE_INPUT = "{\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"parameters\": {\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"inputs\": {\n" + - " \"type\": \"string\"\n" + - " }\n" + - " },\n" + - " \"required\": [\n" + - " \"inputs\"\n" + - " ]\n" + - " }\n" + - " },\n" + - " \"required\": [\n" + - " \"parameters\"\n" + - " ]\n" + - "}"; + private static final String GENERAL_CONVERSATIONAL_MODEL_INTERFACE_INPUT = "{\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"parameters\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"inputs\": {\n" + + " \"type\": \"string\"\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"inputs\"\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"parameters\"\n" + + " ]\n" + + "}"; - private static final String GENERAL_EMBEDDING_MODEL_INTERFACE_INPUT = "{\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"parameters\": {\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"texts\": {\n" + - " \"type\": \"array\",\n" + - " \"items\": {\n" + - " \"type\": \"string\"\n" + - " }\n" + - " }\n" + - " },\n" + - " \"required\": [\n" + - " \"texts\"\n" + - " ]\n" + - " }\n" + - " },\n" + - " \"required\": [\n" + - " \"parameters\"\n" + - " ]\n" + - "}"; + private static final String GENERAL_EMBEDDING_MODEL_INTERFACE_INPUT = "{\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"parameters\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"texts\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"string\"\n" + + " }\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"texts\"\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"parameters\"\n" + + " ]\n" + + "}"; - private static final String TITAN_TEXT_EMBEDDING_MODEL_INTERFACE_INPUT = "{\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"parameters\": {\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"inputText\": {\n" + - " \"type\": \"string\"\n" + - " }\n" + - " },\n" + - " \"required\": [\n" + - " \"inputText\"\n" + - " ]\n" + - " }\n" + - " },\n" + - " \"required\": [\n" + - " \"parameters\"\n" + - " ]\n" + - "}"; + private static final String TITAN_TEXT_EMBEDDING_MODEL_INTERFACE_INPUT = "{\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"parameters\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"inputText\": {\n" + + " \"type\": \"string\"\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"inputText\"\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"parameters\"\n" + + " ]\n" + + "}"; - private static final String TITAN_MULTI_MODAL_EMBEDDING_MODEL_INTERFACE_INPUT = "{\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"parameters\": {\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"inputText\": {\n" + - " \"type\": \"string\"\n" + - " },\n" + - " \"inputImage\": {\n" + - " \"type\": \"string\"\n" + - " }\n" + - " }\n" + - " }\n" + - " },\n" + - " \"required\": [\n" + - " \"parameters\"\n" + - " ]\n" + - "}"; + private static final String TITAN_MULTI_MODAL_EMBEDDING_MODEL_INTERFACE_INPUT = "{\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"parameters\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"inputText\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"inputImage\": {\n" + + " \"type\": \"string\"\n" + + " }\n" + + " }\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"parameters\"\n" + + " ]\n" + + "}"; - private static final String AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE_INPUT = "{\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"parameters\": {\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"Text\": {\n" + - " \"type\": \"string\"\n" + - " }\n" + - " },\n" + - " \"required\": [\n" + - " \"Text\"\n" + - " ]\n" + - " }\n" + - " },\n" + - " \"required\": [\n" + - " \"parameters\"\n" + - " ]\n" + - "}"; + private static final String AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE_INPUT = "{\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"parameters\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"Text\": {\n" + + " \"type\": \"string\"\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"Text\"\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"parameters\"\n" + + " ]\n" + + "}"; - private static final String AMAZON_TEXTRACT_DETECTDOCUMENTTEXT_API_INTERFACE_INPUT = "{\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"parameters\": {\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"bytes\": {\n" + - " \"type\": \"string\"\n" + - " }\n" + - " },\n" + - " \"required\": [\n" + - " \"bytes\"\n" + - " ]\n" + - " }\n" + - " },\n" + - " \"required\": [\n" + - " \"parameters\"\n" + - " ]\n" + - "}"; + private static final String AMAZON_TEXTRACT_DETECTDOCUMENTTEXT_API_INTERFACE_INPUT = "{\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"parameters\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"bytes\": {\n" + + " \"type\": \"string\"\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"bytes\"\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"parameters\"\n" + + " ]\n" + + "}"; - private static final String GENERAL_CONVERSATIONAL_MODEL_INTERFACE_OUTPUT = "{\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"inference_results\": {\n" + - " \"type\": \"array\",\n" + - " \"items\": {\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"output\": {\n" + - " \"type\": \"array\",\n" + - " \"items\": {\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"name\": {\n" + - " \"type\": \"string\"\n" + - " },\n" + - " \"dataAsMap\": {\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"response\": {\n" + - " \"type\": \"string\"\n" + - " }\n" + - " },\n" + - " \"required\": [\n" + - " \"response\"\n" + - " ]\n" + - " }\n" + - " },\n" + - " \"required\": [\n" + - " \"name\",\n" + - " \"dataAsMap\"\n" + - " ]\n" + - " }\n" + - " },\n" + - " \"status_code\": {\n" + - " \"type\": \"integer\"\n" + - " }\n" + - " },\n" + - " \"required\": [\n" + - " \"output\",\n" + - " \"status_code\"\n" + - " ]\n" + - " }\n" + - " }\n" + - " },\n" + - " \"required\": [\n" + - " \"inference_results\"\n" + - " ]\n" + - "}"; + private static final String GENERAL_CONVERSATIONAL_MODEL_INTERFACE_OUTPUT = "{\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"inference_results\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"output\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"name\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"dataAsMap\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"response\": {\n" + + " \"type\": \"string\"\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"response\"\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"name\",\n" + + " \"dataAsMap\"\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"status_code\": {\n" + + " \"type\": \"integer\"\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"output\",\n" + + " \"status_code\"\n" + + " ]\n" + + " }\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"inference_results\"\n" + + " ]\n" + + "}"; - private static final String BEDROCK_ANTHROPIC_CLAUDE_V2_MODEL_INTERFACE_OUTPUT = "{\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"inference_results\": {\n" + - " \"type\": \"array\",\n" + - " \"items\": {\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"output\": {\n" + - " \"type\": \"array\",\n" + - " \"items\": {\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"name\": {\n" + - " \"type\": \"string\"\n" + - " },\n" + - " \"dataAsMap\": {\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"type\": {\n" + - " \"type\": \"string\"\n" + - " },\n" + - " \"completion\": {\n" + - " \"type\": \"string\"\n" + - " },\n" + - " \"stop_reason\": {\n" + - " \"type\": \"string\"\n" + - " },\n" + - " \"stop\": {\n" + - " \"type\": \"string\"\n" + - " }\n" + - " },\n" + - " \"required\": [\n" + - " \"type\",\n" + - " \"completion\",\n" + - " \"stop_reason\",\n" + - " \"stop\"\n" + - " ]\n" + - " }\n" + - " },\n" + - " \"required\": [\n" + - " \"name\",\n" + - " \"dataAsMap\"\n" + - " ]\n" + - " }\n" + - " },\n" + - " \"status_code\": {\n" + - " \"type\": \"integer\"\n" + - " }\n" + - " },\n" + - " \"required\": [\n" + - " \"output\",\n" + - " \"status_code\"\n" + - " ]\n" + - " }\n" + - " }\n" + - " },\n" + - " \"required\": [\n" + - " \"inference_results\"\n" + - " ]\n" + - "}"; + private static final String BEDROCK_ANTHROPIC_CLAUDE_V2_MODEL_INTERFACE_OUTPUT = "{\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"inference_results\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"output\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"name\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"dataAsMap\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"type\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"completion\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"stop_reason\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"stop\": {\n" + + " \"type\": \"string\"\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"type\",\n" + + " \"completion\",\n" + + " \"stop_reason\",\n" + + " \"stop\"\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"name\",\n" + + " \"dataAsMap\"\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"status_code\": {\n" + + " \"type\": \"integer\"\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"output\",\n" + + " \"status_code\"\n" + + " ]\n" + + " }\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"inference_results\"\n" + + " ]\n" + + "}"; - private static final String GENERAL_EMBEDDING_MODEL_INTERFACE_OUTPUT = "{\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"inference_results\": {\n" + - " \"type\": \"array\",\n" + - " \"items\": {\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"output\": {\n" + - " \"type\": \"array\",\n" + - " \"items\": {\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"name\": {\n" + - " \"type\": \"string\"\n" + - " },\n" + - " \"data_type\": {\n" + - " \"type\": \"string\"\n" + - " },\n" + - " \"shape\": {\n" + - " \"type\": \"array\",\n" + - " \"items\": {\n" + - " \"type\": \"integer\"\n" + - " }\n" + - " },\n" + - " \"data\": {\n" + - " \"type\": \"array\",\n" + - " \"items\": {\n" + - " \"type\": \"number\"\n" + - " }\n" + - " }\n" + - " },\n" + - " \"required\": [\n" + - " \"name\",\n" + - " \"data_type\",\n" + - " \"shape\",\n" + - " \"data\"\n" + - " ]\n" + - " }\n" + - " },\n" + - " \"status_code\": {\n" + - " \"type\": \"integer\"\n" + - " }\n" + - " },\n" + - " \"required\": [\n" + - " \"output\",\n" + - " \"status_code\"\n" + - " ]\n" + - " }\n" + - " }\n" + - " },\n" + - " \"required\": [\n" + - " \"inference_results\"\n" + - " ]\n" + - "}"; + private static final String GENERAL_EMBEDDING_MODEL_INTERFACE_OUTPUT = "{\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"inference_results\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"output\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"name\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"data_type\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"shape\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"integer\"\n" + + " }\n" + + " },\n" + + " \"data\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"number\"\n" + + " }\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"name\",\n" + + " \"data_type\",\n" + + " \"shape\",\n" + + " \"data\"\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"status_code\": {\n" + + " \"type\": \"integer\"\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"output\",\n" + + " \"status_code\"\n" + + " ]\n" + + " }\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"inference_results\"\n" + + " ]\n" + + "}"; - private static final String AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE_OUTPUT = "{\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"inference_results\": {\n" + - " \"type\": \"array\",\n" + - " \"items\": {\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"output\": {\n" + - " \"type\": \"array\",\n" + - " \"items\": {\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"name\": {\n" + - " \"type\": \"string\"\n" + - " },\n" + - " \"dataAsMap\": {\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"response\": {\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"Languages\": {\n" + - " \"type\": \"array\",\n" + - " \"items\": {\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"LanguageCode\": {\n" + - " \"type\": \"string\"\n" + - " },\n" + - " \"Score\": {\n" + - " \"type\": \"number\"\n" + - " }\n" + - " },\n" + - " \"required\": [\n" + - " \"LanguageCode\",\n" + - " \"Score\"\n" + - " ]\n" + - " }\n" + - " }\n" + - " },\n" + - " \"required\": [\n" + - " \"Languages\"\n" + - " ]\n" + - " }\n" + - " },\n" + - " \"required\": [\n" + - " \"response\"\n" + - " ]\n" + - " }\n" + - " },\n" + - " \"required\": [\n" + - " \"name\",\n" + - " \"dataAsMap\"\n" + - " ]\n" + - " }\n" + - " },\n" + - " \"status_code\": {\n" + - " \"type\": \"integer\"\n" + - " }\n" + - " },\n" + - " \"required\": [\n" + - " \"output\",\n" + - " \"status_code\"\n" + - " ]\n" + - " }\n" + - " }\n" + - " },\n" + - " \"required\": [\n" + - " \"inference_results\"\n" + - " ]\n" + - "}"; + private static final String AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE_OUTPUT = "{\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"inference_results\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"output\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"name\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"dataAsMap\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"response\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"Languages\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"LanguageCode\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"Score\": {\n" + + " \"type\": \"number\"\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"LanguageCode\",\n" + + " \"Score\"\n" + + " ]\n" + + " }\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"Languages\"\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"response\"\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"name\",\n" + + " \"dataAsMap\"\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"status_code\": {\n" + + " \"type\": \"integer\"\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"output\",\n" + + " \"status_code\"\n" + + " ]\n" + + " }\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"inference_results\"\n" + + " ]\n" + + "}"; - private static final String AMAZON_TEXTRACT_DETECTDOCUMENTTEXT_API_INTERFACE_OUTPUT = "{\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"inference_results\": {\n" + - " \"type\": \"array\",\n" + - " \"items\": {\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"output\": {\n" + - " \"type\": \"array\",\n" + - " \"items\": {\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"name\": {\n" + - " \"type\": \"string\"\n" + - " },\n" + - " \"dataAsMap\": {\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"Blocks\": {\n" + - " \"type\": \"array\",\n" + - " \"items\": {\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"BlockType\": {\n" + - " \"type\": \"string\"\n" + - " },\n" + - " \"Geometry\": {\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"BoundingBox\": {\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"Height\": {\n" + - " \"type\": \"number\"\n" + - " },\n" + - " \"Left\": {\n" + - " \"type\": \"number\"\n" + - " },\n" + - " \"Top\": {\n" + - " \"type\": \"number\"\n" + - " },\n" + - " \"Width\": {\n" + - " \"type\": \"number\"\n" + - " }\n" + - " }\n" + - " },\n" + - " \"Polygon\": {\n" + - " \"type\": \"array\",\n" + - " \"items\": {\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"X\": {\n" + - " \"type\": \"number\"\n" + - " },\n" + - " \"Y\": {\n" + - " \"type\": \"number\"\n" + - " }\n" + - " }\n" + - " }\n" + - " }\n" + - " }\n" + - " },\n" + - " \"Id\": {\n" + - " \"type\": \"string\"\n" + - " },\n" + - " \"Relationships\": {\n" + - " \"type\": \"array\",\n" + - " \"items\": {\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"Ids\": {\n" + - " \"type\": \"array\",\n" + - " \"items\": {\n" + - " \"type\": \"string\"\n" + - " }\n" + - " },\n" + - " \"Type\": {\n" + - " \"type\": \"string\"\n" + - " }\n" + - " }\n" + - " }\n" + - " }\n" + - " }\n" + - " }\n" + - " },\n" + - " \"DetectDocumentTextModelVersion\": {\n" + - " \"type\": \"string\"\n" + - " },\n" + - " \"DocumentMetadata\": {\n" + - " \"type\": \"object\",\n" + - " \"properties\": {\n" + - " \"Pages\": {\n" + - " \"type\": \"number\"\n" + - " }\n" + - " }\n" + - " }\n" + - " }\n" + - " }\n" + - " }\n" + - " }\n" + - " },\n" + - " \"status_code\": {\n" + - " \"type\": \"number\"\n" + - " }\n" + - " }\n" + - " }\n" + - " }\n" + - " }\n" + - "}"; + private static final String AMAZON_TEXTRACT_DETECTDOCUMENTTEXT_API_INTERFACE_OUTPUT = "{\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"inference_results\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"output\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"name\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"dataAsMap\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"Blocks\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"BlockType\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"Geometry\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"BoundingBox\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"Height\": {\n" + + " \"type\": \"number\"\n" + + " },\n" + + " \"Left\": {\n" + + " \"type\": \"number\"\n" + + " },\n" + + " \"Top\": {\n" + + " \"type\": \"number\"\n" + + " },\n" + + " \"Width\": {\n" + + " \"type\": \"number\"\n" + + " }\n" + + " }\n" + + " },\n" + + " \"Polygon\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"X\": {\n" + + " \"type\": \"number\"\n" + + " },\n" + + " \"Y\": {\n" + + " \"type\": \"number\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " },\n" + + " \"Id\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"Relationships\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"Ids\": {\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"string\"\n" + + " }\n" + + " },\n" + + " \"Type\": {\n" + + " \"type\": \"string\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " },\n" + + " \"DetectDocumentTextModelVersion\": {\n" + + " \"type\": \"string\"\n" + + " },\n" + + " \"DocumentMetadata\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"Pages\": {\n" + + " \"type\": \"number\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " },\n" + + " \"status_code\": {\n" + + " \"type\": \"number\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + "}"; + public static final Map BEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE = Map + .of("input", GENERAL_CONVERSATIONAL_MODEL_INTERFACE_INPUT, "output", GENERAL_CONVERSATIONAL_MODEL_INTERFACE_OUTPUT); - public static final Map BEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE = Map.of( - "input", - GENERAL_CONVERSATIONAL_MODEL_INTERFACE_INPUT, - "output", - GENERAL_CONVERSATIONAL_MODEL_INTERFACE_OUTPUT - ); + public static final Map BEDROCK_ANTHROPIC_CLAUDE_V3_SONNET_MODEL_INTERFACE = Map + .of("input", GENERAL_CONVERSATIONAL_MODEL_INTERFACE_INPUT, "output", GENERAL_CONVERSATIONAL_MODEL_INTERFACE_OUTPUT); - public static final Map BEDROCK_ANTHROPIC_CLAUDE_V3_SONNET_MODEL_INTERFACE = Map.of( - "input", - GENERAL_CONVERSATIONAL_MODEL_INTERFACE_INPUT, - "output", - GENERAL_CONVERSATIONAL_MODEL_INTERFACE_OUTPUT - ); + public static final Map BEDROCK_ANTHROPIC_CLAUDE_V2_MODEL_INTERFACE = Map + .of("input", GENERAL_CONVERSATIONAL_MODEL_INTERFACE_INPUT, "output", BEDROCK_ANTHROPIC_CLAUDE_V2_MODEL_INTERFACE_OUTPUT); - public static final Map BEDROCK_ANTHROPIC_CLAUDE_V2_MODEL_INTERFACE = Map.of( - "input", - GENERAL_CONVERSATIONAL_MODEL_INTERFACE_INPUT, - "output", - BEDROCK_ANTHROPIC_CLAUDE_V2_MODEL_INTERFACE_OUTPUT - ); + public static final Map BEDROCK_COHERE_EMBED_ENGLISH_V3_MODEL_INTERFACE = Map + .of("input", GENERAL_EMBEDDING_MODEL_INTERFACE_INPUT, "output", GENERAL_EMBEDDING_MODEL_INTERFACE_OUTPUT); - public static final Map BEDROCK_COHERE_EMBED_ENGLISH_V3_MODEL_INTERFACE = Map.of( - "input", - GENERAL_EMBEDDING_MODEL_INTERFACE_INPUT, - "output", - GENERAL_EMBEDDING_MODEL_INTERFACE_OUTPUT - ); + public static final Map BEDROCK_COHERE_EMBED_MULTILINGUAL_V3_MODEL_INTERFACE = Map + .of("input", GENERAL_EMBEDDING_MODEL_INTERFACE_INPUT, "output", GENERAL_EMBEDDING_MODEL_INTERFACE_OUTPUT); - public static final Map BEDROCK_COHERE_EMBED_MULTILINGUAL_V3_MODEL_INTERFACE = Map.of( - "input", - GENERAL_EMBEDDING_MODEL_INTERFACE_INPUT, - "output", - GENERAL_EMBEDDING_MODEL_INTERFACE_OUTPUT - ); + public static final Map BEDROCK_TITAN_EMBED_TEXT_V1_MODEL_INTERFACE = Map + .of("input", TITAN_TEXT_EMBEDDING_MODEL_INTERFACE_INPUT, "output", GENERAL_EMBEDDING_MODEL_INTERFACE_OUTPUT); - public static final Map BEDROCK_TITAN_EMBED_TEXT_V1_MODEL_INTERFACE = Map.of( - "input", - TITAN_TEXT_EMBEDDING_MODEL_INTERFACE_INPUT, - "output", - GENERAL_EMBEDDING_MODEL_INTERFACE_OUTPUT - ); + public static final Map BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_MODEL_INTERFACE = Map + .of("input", TITAN_MULTI_MODAL_EMBEDDING_MODEL_INTERFACE_INPUT, "output", GENERAL_EMBEDDING_MODEL_INTERFACE_OUTPUT); - public static final Map BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_MODEL_INTERFACE = Map.of( - "input", - TITAN_MULTI_MODAL_EMBEDDING_MODEL_INTERFACE_INPUT, - "output", - GENERAL_EMBEDDING_MODEL_INTERFACE_OUTPUT - ); - - public static final Map AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE = Map.of( + public static final Map AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE = Map + .of( "input", AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE_INPUT, "output", AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE_OUTPUT - ); + ); - public static final Map AMAZON_TEXTRACT_DETECTDOCUMENTTEXT_API_INTERFACE = Map.of( + public static final Map AMAZON_TEXTRACT_DETECTDOCUMENTTEXT_API_INTERFACE = Map + .of( "input", AMAZON_TEXTRACT_DETECTDOCUMENTTEXT_API_INTERFACE_INPUT, "output", AMAZON_TEXTRACT_DETECTDOCUMENTTEXT_API_INTERFACE_OUTPUT - ); + ); private static Map createPresetModelInterfaceByConnector(Connector connector) { if (connector.getParameters() != null) { switch ((connector.getParameters().get("service_name") != null) ? connector.getParameters().get("service_name") : "null") { case "bedrock": - log.debug("Detected Amazon Bedrock model"); - switch ((connector.getParameters().get("model") != null) ? connector.getParameters().get("model") : "null") { - case "ai21.j2-mid-v1": - log.debug("Creating preset model interface for Amazon Bedrock model: {}", connector.getParameters().get("model")); - return BEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE; - case "anthropic.claude-3-sonnet-20240229-v1:0": - log.debug("Creating preset model interface for Amazon Bedrock model: {}", connector.getParameters().get("model")); - return BEDROCK_ANTHROPIC_CLAUDE_V3_SONNET_MODEL_INTERFACE; - case "anthropic.claude-v2": - log.debug("Creating preset model interface for Amazon Bedrock model: {}", connector.getParameters().get("model")); - return BEDROCK_ANTHROPIC_CLAUDE_V2_MODEL_INTERFACE; - case "cohere.embed-english-v3": - log.debug("Creating preset model interface for Amazon Bedrock model: {}", connector.getParameters().get("model")); - return BEDROCK_COHERE_EMBED_ENGLISH_V3_MODEL_INTERFACE; - case "cohere.embed-multilingual-v3": - log.debug("Creating preset model interface for Amazon Bedrock model: {}", connector.getParameters().get("model")); - return BEDROCK_COHERE_EMBED_MULTILINGUAL_V3_MODEL_INTERFACE; - case "amazon.titan-embed-text-v1": - log.debug("Creating preset model interface for Amazon Bedrock model: {}", connector.getParameters().get("model")); - return BEDROCK_TITAN_EMBED_TEXT_V1_MODEL_INTERFACE; - case "amazon.titan-embed-image-v1": - log.debug("Creating preset model interface for Amazon Bedrock model: {}", connector.getParameters().get("model")); - return BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_MODEL_INTERFACE; - default: - return null; - } + log.debug("Detected Amazon Bedrock model"); + switch ((connector.getParameters().get("model") != null) ? connector.getParameters().get("model") : "null") { + case "ai21.j2-mid-v1": + log + .debug( + "Creating preset model interface for Amazon Bedrock model: {}", + connector.getParameters().get("model") + ); + return BEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE; + case "anthropic.claude-3-sonnet-20240229-v1:0": + log + .debug( + "Creating preset model interface for Amazon Bedrock model: {}", + connector.getParameters().get("model") + ); + return BEDROCK_ANTHROPIC_CLAUDE_V3_SONNET_MODEL_INTERFACE; + case "anthropic.claude-v2": + log + .debug( + "Creating preset model interface for Amazon Bedrock model: {}", + connector.getParameters().get("model") + ); + return BEDROCK_ANTHROPIC_CLAUDE_V2_MODEL_INTERFACE; + case "cohere.embed-english-v3": + log + .debug( + "Creating preset model interface for Amazon Bedrock model: {}", + connector.getParameters().get("model") + ); + return BEDROCK_COHERE_EMBED_ENGLISH_V3_MODEL_INTERFACE; + case "cohere.embed-multilingual-v3": + log + .debug( + "Creating preset model interface for Amazon Bedrock model: {}", + connector.getParameters().get("model") + ); + return BEDROCK_COHERE_EMBED_MULTILINGUAL_V3_MODEL_INTERFACE; + case "amazon.titan-embed-text-v1": + log + .debug( + "Creating preset model interface for Amazon Bedrock model: {}", + connector.getParameters().get("model") + ); + return BEDROCK_TITAN_EMBED_TEXT_V1_MODEL_INTERFACE; + case "amazon.titan-embed-image-v1": + log + .debug( + "Creating preset model interface for Amazon Bedrock model: {}", + connector.getParameters().get("model") + ); + return BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_MODEL_INTERFACE; + default: + return null; + } case "comprehend": log.debug("Detected Amazon Comprehend model"); - switch ((connector.getParameters().get("api_name") != null) ? connector.getParameters().get("api_name") : "null"){ + switch ((connector.getParameters().get("api_name") != null) ? connector.getParameters().get("api_name") : "null") { // Single case for switch-case statement due to there is one more API in blueprint for Amazon Comprehend Model // Not set here because there is more than one input/output schema for the DetectEntities API // TODO: Add default model interface for Amazon Comprehend DetectEntities APIs case "DetectDominantLanguage": - log.debug("Creating preset model interface for Amazon Comprehend API: {}", connector.getParameters().get("api_name")); + log + .debug( + "Creating preset model interface for Amazon Comprehend API: {}", + connector.getParameters().get("api_name") + ); return AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE; default: return null; @@ -612,7 +617,10 @@ private static Map createPresetModelInterfaceByConnector(Connect * @param registerModelInput the register model input * @param connector the connector */ - public static void updateRegisterModelInputModelInterfaceFieldsByConnector(MLRegisterModelInput registerModelInput, Connector connector) { + public static void updateRegisterModelInputModelInterfaceFieldsByConnector( + MLRegisterModelInput registerModelInput, + Connector connector + ) { Map presetModelInterface = createPresetModelInterfaceByConnector(connector); if (presetModelInterface != null) { registerModelInput.setModelInterface(presetModelInterface); diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index 9fab197a8c..e71636e01b 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -5,51 +5,45 @@ package org.opensearch.ml.common.utils; -import com.google.gson.Gson; -import com.google.gson.JsonElement; -import com.google.gson.JsonParser; -import com.google.gson.JsonSyntaxException; -import lombok.extern.log4j.Log4j2; -import org.apache.commons.lang3.BooleanUtils; -import org.json.JSONArray; -import org.json.JSONException; -import org.json.JSONObject; -import org.opensearch.OpenSearchParseException; - -import java.io.IOException; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.security.AccessController; import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; import java.util.ArrayList; -import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.Set; import java.util.regex.Matcher; import java.util.regex.Pattern; -import java.util.stream.Collectors; -import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD; -import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD; +import org.apache.commons.lang3.BooleanUtils; +import org.json.JSONArray; +import org.json.JSONException; +import org.json.JSONObject; + +import com.google.gson.Gson; +import com.google.gson.JsonElement; +import com.google.gson.JsonParser; +import com.google.gson.JsonSyntaxException; + +import lombok.extern.log4j.Log4j2; @Log4j2 public class StringUtils { - public static final String DEFAULT_ESCAPE_FUNCTION = "\n String escape(def input) { \n" + - " if (input.contains(\"\\\\\")) {\n input = input.replace(\"\\\\\", \"\\\\\\\\\");\n }\n" + - " if (input.contains(\"\\\"\")) {\n input = input.replace(\"\\\"\", \"\\\\\\\"\");\n }\n" + - " if (input.contains('\r')) {\n input = input = input.replace('\r', '\\\\r');\n }\n" + - " if (input.contains(\"\\\\t\")) {\n input = input.replace(\"\\\\t\", \"\\\\\\\\\\\\t\");\n }\n" + - " if (input.contains('\n')) {\n input = input.replace('\n', '\\\\n');\n }\n" + - " if (input.contains('\b')) {\n input = input.replace('\b', '\\\\b');\n }\n" + - " if (input.contains('\f')) {\n input = input.replace('\f', '\\\\f');\n }\n" + - " return input;" + - "\n }\n"; + public static final String DEFAULT_ESCAPE_FUNCTION = "\n String escape(def input) { \n" + + " if (input.contains(\"\\\\\")) {\n input = input.replace(\"\\\\\", \"\\\\\\\\\");\n }\n" + + " if (input.contains(\"\\\"\")) {\n input = input.replace(\"\\\"\", \"\\\\\\\"\");\n }\n" + + " if (input.contains('\r')) {\n input = input = input.replace('\r', '\\\\r');\n }\n" + + " if (input.contains(\"\\\\t\")) {\n input = input.replace(\"\\\\t\", \"\\\\\\\\\\\\t\");\n }\n" + + " if (input.contains('\n')) {\n input = input.replace('\n', '\\\\n');\n }\n" + + " if (input.contains('\b')) {\n input = input.replace('\b', '\\\\b');\n }\n" + + " if (input.contains('\f')) {\n input = input.replace('\f', '\\\\f');\n }\n" + + " return input;" + + "\n }\n"; public static final Gson gson; @@ -75,10 +69,10 @@ public static boolean isJson(String json) { if (!isValidJsonString(json)) { return false; } - //This is to cover such edge case "[]\"" + // This is to cover such edge case "[]\"" gson.fromJson(json, Object.class); return true; - } catch(JsonSyntaxException ex) { + } catch (JsonSyntaxException ex) { return false; } } @@ -114,7 +108,7 @@ public static Map filteredParameterMap(Map parameterO try { AccessController.doPrivileged((PrivilegedExceptionAction) () -> { if (value instanceof String) { - parameters.put(key, (String)value); + parameters.put(key, (String) value); } else { parameters.put(key, gson.toJson(value)); } @@ -135,7 +129,7 @@ public static Map getParameterMap(Map parameterObjs) try { AccessController.doPrivileged((PrivilegedExceptionAction) () -> { if (value instanceof String) { - parameters.put(key, (String)value); + parameters.put(key, (String) value); } else { parameters.put(key, gson.toJson(value)); } @@ -218,7 +212,7 @@ public static boolean patternExist(String input, String patternString) { } public static boolean isEscapeUsed(String input) { - return patternExist(input,"(? inputData = new ArrayList<>(); - inputData.add(new float[]{1.0f, 2.0f, 3.0f, 4.0f}); - inputData.add(new float[]{1.0f, 2.0f, 3.0f, 4.0f}); - inputData.add(new float[]{1.0f, 2.0f, 3.0f, 4.0f}); + inputData.add(new float[] { 1.0f, 2.0f, 3.0f, 4.0f }); + inputData.add(new float[] { 1.0f, 2.0f, 3.0f, 4.0f }); + inputData.add(new float[] { 1.0f, 2.0f, 3.0f, 4.0f }); Input mcorrInput = new MetricsCorrelationInput(inputData); BytesStreamOutput bytesStreamOutputMCorrInput = new BytesStreamOutput(); mcorrInput.writeTo(bytesStreamOutputMCorrInput); StreamInput streamInputForMcorrInput = bytesStreamOutputMCorrInput.bytes().streamInput(); - MetricsCorrelationInput mcorrStreamInput = MLCommonsClassLoader.initExecuteInputInstance(FunctionName.METRICS_CORRELATION, streamInputForMcorrInput, StreamInput.class); + MetricsCorrelationInput mcorrStreamInput = MLCommonsClassLoader + .initExecuteInputInstance(FunctionName.METRICS_CORRELATION, streamInputForMcorrInput, StreamInput.class); assertArrayEquals(((MetricsCorrelationInput) mcorrInput).getInputData().toArray(), mcorrStreamInput.getInputData().toArray()); } @@ -130,11 +128,12 @@ public void testClassLoader_ExecuteMCorrInput() throws IOException { @Test public void testClassLoader_ExecuteOutputMCorr() throws IOException { List outputs = new ArrayList<>(); - MCorrModelTensor mCorrModelTensor = MCorrModelTensor.builder() - .event_pattern(new float[]{1.0f, 2.0f, 3.0f}) - .event_window(new float[]{4.0f, 5.0f, 6.0f}) - .suspected_metrics(new long[]{1, 2}) - .build(); + MCorrModelTensor mCorrModelTensor = MCorrModelTensor + .builder() + .event_pattern(new float[] { 1.0f, 2.0f, 3.0f }) + .event_window(new float[] { 4.0f, 5.0f, 6.0f }) + .suspected_metrics(new long[] { 1, 2 }) + .build(); List mlModelTensors = Arrays.asList(mCorrModelTensor); MCorrModelTensors modelTensors = MCorrModelTensors.builder().mCorrModelTensors(mlModelTensors).build(); outputs.add(modelTensors); @@ -142,7 +141,8 @@ public void testClassLoader_ExecuteOutputMCorr() throws IOException { BytesStreamOutput bytesStreamOutputMcorrOutput = new BytesStreamOutput(); output.writeTo(bytesStreamOutputMcorrOutput); StreamInput streamInputForOutput = bytesStreamOutputMcorrOutput.bytes().streamInput(); - MetricsCorrelationOutput mcorrOutput = MLCommonsClassLoader.initExecuteOutputInstance(FunctionName.METRICS_CORRELATION, streamInputForOutput, StreamInput.class); + MetricsCorrelationOutput mcorrOutput = MLCommonsClassLoader + .initExecuteOutputInstance(FunctionName.METRICS_CORRELATION, streamInputForOutput, StreamInput.class); assertEquals(1, mcorrOutput.getModelOutput().size()); MCorrModelTensors testmodelTensors = mcorrOutput.getModelOutput().get(0); @@ -150,22 +150,29 @@ public void testClassLoader_ExecuteOutputMCorr() throws IOException { MCorrModelTensor testmodelTensor = testmodelTensors.getMCorrModelTensors().get(0); float[] events = testmodelTensor.getEvent_pattern(); long[] metrics = testmodelTensor.getSuspected_metrics(); - assertArrayEquals(new float[]{1.0f, 2.0f, 3.0f}, events, 0.001f); - assertArrayEquals(new long[]{1, 2}, metrics); + assertArrayEquals(new float[] { 1.0f, 2.0f, 3.0f }, events, 0.001f); + assertArrayEquals(new long[] { 1, 2 }, metrics); } private void testClassLoader_MLInput_DlModel(FunctionName functionName) throws IOException { assertTrue(MLCommonsClassLoader.canInitMLInput(functionName)); - String jsonStr = "{\"text_docs\":[\"doc1\",\"doc2\"],\"result_filter\":{\"return_bytes\":true,\"return_number\":true,\"target_response\":[\"field1\"], \"target_response_positions\": [2]}}"; - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); + String jsonStr = + "{\"text_docs\":[\"doc1\",\"doc2\"],\"result_filter\":{\"return_bytes\":true,\"return_number\":true,\"target_response\":[\"field1\"], \"target_response_positions\": [2]}}"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); parser.nextToken(); - TextDocsMLInput mlInput = MLCommonsClassLoader.initMLInput(functionName, new Object[]{parser, functionName}, XContentParser.class, FunctionName.class); + TextDocsMLInput mlInput = MLCommonsClassLoader + .initMLInput(functionName, new Object[] { parser, functionName }, XContentParser.class, FunctionName.class); assertNotNull(mlInput); assertEquals(functionName, mlInput.getFunctionName()); - assertEquals(2, ((TextDocsInputDataSet)mlInput.getInputDataset()).getDocs().size()); + assertEquals(2, ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs().size()); } @Test @@ -181,7 +188,7 @@ public void testConnectorInitializationException() { String initParam1 = "parameter1"; // Initialize the first connector type - MLCommonsClassLoader.initConnector("Connector", new Object[]{initParam1}, String.class); + MLCommonsClassLoader.initConnector("Connector", new Object[] { initParam1 }, String.class); } public enum TestEnum { diff --git a/common/src/test/java/org/opensearch/ml/common/MLModelGroupTest.java b/common/src/test/java/org/opensearch/ml/common/MLModelGroupTest.java index 71f7f46cf2..c1abe07297 100644 --- a/common/src/test/java/org/opensearch/ml/common/MLModelGroupTest.java +++ b/common/src/test/java/org/opensearch/ml/common/MLModelGroupTest.java @@ -5,6 +5,10 @@ package org.opensearch.ml.common; +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; + import org.junit.Assert; import org.junit.Rule; import org.junit.Test; @@ -19,10 +23,6 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.search.SearchModule; -import java.io.IOException; -import java.util.Arrays; -import java.util.Collections; - public class MLModelGroupTest { @Rule @@ -47,31 +47,41 @@ public void toXContent_Empty() throws IOException { @Test public void toXContent() throws IOException { - MLModelGroup modelGroup = MLModelGroup.builder() - .name("test") - .description("this is test group") - .latestVersion(1) - .backendRoles(Arrays.asList("role1", "role2")) - .owner(new User()) - .access(AccessMode.PUBLIC.name()) - .build(); + MLModelGroup modelGroup = MLModelGroup + .builder() + .name("test") + .description("this is test group") + .latestVersion(1) + .backendRoles(Arrays.asList("role1", "role2")) + .owner(new User()) + .access(AccessMode.PUBLIC.name()) + .build(); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); modelGroup.toXContent(builder, ToXContent.EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); - Assert.assertEquals("{\"name\":\"test\",\"latest_version\":1,\"description\":\"this is test group\"," + - "\"backend_roles\":[\"role1\",\"role2\"]," + - "\"owner\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null}," + - "\"access\":\"PUBLIC\"}", content); + Assert + .assertEquals( + "{\"name\":\"test\",\"latest_version\":1,\"description\":\"this is test group\"," + + "\"backend_roles\":[\"role1\",\"role2\"]," + + "\"owner\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null}," + + "\"access\":\"PUBLIC\"}", + content + ); } @Test public void parse() throws IOException { - String jsonStr = "{\"name\":\"test\",\"latest_version\":1,\"description\":\"this is test group\"," + - "\"backend_roles\":[\"role1\",\"role2\"]," + - "\"owner\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null}," + - "\"access\":\"PUBLIC\"}"; - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); + String jsonStr = "{\"name\":\"test\",\"latest_version\":1,\"description\":\"this is test group\"," + + "\"backend_roles\":[\"role1\",\"role2\"]," + + "\"owner\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null}," + + "\"access\":\"PUBLIC\"}"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); parser.nextToken(); MLModelGroup modelGroup = MLModelGroup.parse(parser); Assert.assertEquals("test", modelGroup.getName()); @@ -85,8 +95,13 @@ public void parse() throws IOException { @Test public void parse_Empty() throws IOException { String jsonStr = "{\"name\":\"test\"}"; - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); parser.nextToken(); MLModelGroup modelGroup = MLModelGroup.parse(parser); Assert.assertEquals("test", modelGroup.getName()); @@ -97,14 +112,15 @@ public void parse_Empty() throws IOException { @Test public void writeTo() throws IOException { - MLModelGroup originalModelGroup = MLModelGroup.builder() - .name("test") - .description("this is test group") - .latestVersion(1) - .backendRoles(Arrays.asList("role1", "role2")) - .owner(new User()) - .access(AccessMode.PUBLIC.name()) - .build(); + MLModelGroup originalModelGroup = MLModelGroup + .builder() + .name("test") + .description("this is test group") + .latestVersion(1) + .backendRoles(Arrays.asList("role1", "role2")) + .owner(new User()) + .access(AccessMode.PUBLIC.name()) + .build(); BytesStreamOutput output = new BytesStreamOutput(); originalModelGroup.writeTo(output); diff --git a/common/src/test/java/org/opensearch/ml/common/MLModelTests.java b/common/src/test/java/org/opensearch/ml/common/MLModelTests.java index b493ea29e9..77a3dd6a25 100644 --- a/common/src/test/java/org/opensearch/ml/common/MLModelTests.java +++ b/common/src/test/java/org/opensearch/ml/common/MLModelTests.java @@ -5,62 +5,65 @@ package org.opensearch.ml.common; +import static org.junit.Assert.assertEquals; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; +import java.time.Instant; +import java.util.function.Function; + import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; -import java.io.IOException; -import java.time.Instant; -import java.util.function.Function; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; - public class MLModelTests { MLModel mlModel; TextEmbeddingModelConfig config; Function function; + @Before public void setUp() { FunctionName algorithm = FunctionName.KMEANS; - User user = new User(); - config = TextEmbeddingModelConfig.builder() - .modelType("testModelType") - .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") - .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) - .embeddingDimension(100) - .build(); + User user = new User(); + config = TextEmbeddingModelConfig + .builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); Instant now = Instant.now(); - mlModel = MLModel.builder() - .name("some model") - .algorithm(algorithm) - .version("1.0.0") - .content("some content") - .user(user) - .description("test description") - .modelFormat(MLModelFormat.ONNX) - .modelState(MLModelState.DEPLOYED) - .modelContentSizeInBytes(10_000_000l) - .modelContentHash("test_hash") - .modelConfig(config) - .createdTime(now) - .lastRegisteredTime(now) - .lastDeployedTime(now) - .lastUndeployedTime(now) - .modelId("model_id") - .chunkNumber(1) - .totalChunks(10) - .isHidden(false) - .build(); + mlModel = MLModel + .builder() + .name("some model") + .algorithm(algorithm) + .version("1.0.0") + .content("some content") + .user(user) + .description("test description") + .modelFormat(MLModelFormat.ONNX) + .modelState(MLModelState.DEPLOYED) + .modelContentSizeInBytes(10_000_000l) + .modelContentHash("test_hash") + .modelConfig(config) + .createdTime(now) + .lastRegisteredTime(now) + .lastDeployedTime(now) + .lastUndeployedTime(now) + .modelId("model_id") + .chunkNumber(1) + .totalChunks(10) + .isHidden(false) + .build(); function = parser -> { try { return MLModel.parse(parser, algorithm.name()); @@ -72,11 +75,21 @@ public void setUp() { @Test public void toXContent() throws IOException { - MLModel mlModel = MLModel.builder().algorithm(FunctionName.KMEANS).name("model_name").version("1.0.0").content("test_content").isHidden(true).build(); + MLModel mlModel = MLModel + .builder() + .algorithm(FunctionName.KMEANS) + .name("model_name") + .version("1.0.0") + .content("test_content") + .isHidden(true) + .build(); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); mlModel.toXContent(builder, EMPTY_PARAMS); String mlModelContent = TestHelper.xContentBuilderToString(builder); - assertEquals("{\"name\":\"model_name\",\"algorithm\":\"KMEANS\",\"model_version\":\"1.0.0\",\"model_content\":\"test_content\",\"is_hidden\":true}", mlModelContent); + assertEquals( + "{\"name\":\"model_name\",\"algorithm\":\"KMEANS\",\"model_version\":\"1.0.0\",\"model_content\":\"test_content\",\"is_hidden\":true}", + mlModelContent + ); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/MLTaskTests.java b/common/src/test/java/org/opensearch/ml/common/MLTaskTests.java index 6c050fb978..2ffdc32679 100644 --- a/common/src/test/java/org/opensearch/ml/common/MLTaskTests.java +++ b/common/src/test/java/org/opensearch/ml/common/MLTaskTests.java @@ -55,13 +55,14 @@ public void toXContent() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); mlTask.toXContent(builder, ToXContent.EMPTY_PARAMS); String taskContent = TestHelper.xContentBuilderToString(builder); - Assert.assertEquals( - "{\"task_id\":\"dummy taskId\",\"model_id\":\"test_model_id\",\"task_type\":\"PREDICTION\"," - + "\"function_name\":\"KMEANS\",\"state\":\"RUNNING\",\"input_type\":\"DATA_FRAME\",\"progress\":0.0," - + "\"output_index\":\"test_index\",\"worker_node\":[\"node1\"],\"create_time\":1641599940000," - + "\"last_update_time\":1641600000000,\"error\":\"test_error\",\"is_async\":false}", - taskContent - ); + Assert + .assertEquals( + "{\"task_id\":\"dummy taskId\",\"model_id\":\"test_model_id\",\"task_type\":\"PREDICTION\"," + + "\"function_name\":\"KMEANS\",\"state\":\"RUNNING\",\"input_type\":\"DATA_FRAME\",\"progress\":0.0," + + "\"output_index\":\"test_index\",\"worker_node\":[\"node1\"],\"create_time\":1641599940000," + + "\"last_update_time\":1641600000000,\"error\":\"test_error\",\"is_async\":false}", + taskContent + ); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/RemoteModelTests.java b/common/src/test/java/org/opensearch/ml/common/RemoteModelTests.java index 0710af6dcc..acb05c9b66 100644 --- a/common/src/test/java/org/opensearch/ml/common/RemoteModelTests.java +++ b/common/src/test/java/org/opensearch/ml/common/RemoteModelTests.java @@ -9,6 +9,7 @@ import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; import java.io.IOException; + import org.junit.Assert; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -25,66 +26,76 @@ public class RemoteModelTests { @Test public void toXContent_ConnectorId() throws IOException { - MLModel mlModel = MLModel.builder() - .algorithm(FunctionName.REMOTE) - .name("test_model_name") - .version("1.0.0") - .modelGroupId("test_group_id") - .description("test model") - .connectorId("test_connector_id") - .build(); + MLModel mlModel = MLModel + .builder() + .algorithm(FunctionName.REMOTE) + .name("test_model_name") + .version("1.0.0") + .modelGroupId("test_group_id") + .description("test model") + .connectorId("test_connector_id") + .build(); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); mlModel.toXContent(builder, EMPTY_PARAMS); String mlModelContent = TestHelper.xContentBuilderToString(builder); - assertEquals("{\"name\":\"test_model_name\",\"model_group_id\":\"test_group_id\",\"algorithm\":\"REMOTE\"" + - ",\"model_version\":\"1.0.0\",\"description\":\"test model\"," + - "\"connector_id\":\"test_connector_id\"}", mlModelContent); + assertEquals( + "{\"name\":\"test_model_name\",\"model_group_id\":\"test_group_id\",\"algorithm\":\"REMOTE\"" + + ",\"model_version\":\"1.0.0\",\"description\":\"test model\"," + + "\"connector_id\":\"test_connector_id\"}", + mlModelContent + ); } @Test public void toXContent_InternalConnector() throws IOException { Connector connector = HttpConnectorTest.createHttpConnector(); - MLModel mlModel = MLModel.builder() - .algorithm(FunctionName.REMOTE) - .name("test_model_name") - .version("1.0.0") - .modelGroupId("test_group_id") - .description("test model") - .connector(connector) - .build(); + MLModel mlModel = MLModel + .builder() + .algorithm(FunctionName.REMOTE) + .name("test_model_name") + .version("1.0.0") + .modelGroupId("test_group_id") + .description("test model") + .connector(connector) + .build(); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); mlModel.toXContent(builder, EMPTY_PARAMS); String mlModelContent = TestHelper.xContentBuilderToString(builder); - assertEquals("{\"name\":\"test_model_name\",\"model_group_id\":\"test_group_id\"," + - "\"algorithm\":\"REMOTE\",\"model_version\":\"1.0.0\",\"description\":\"test model\"," + - "\"connector\":{\"name\":\"test_connector_name\",\"version\":\"1\"," + - "\"description\":\"this is a test connector\",\"protocol\":\"http\"," + - "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," + - "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + - "\"headers\":{\"api_key\":\"${credential.key}\"}," + - "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + - "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + - "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + - "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\"," + - "\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000," + - "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}}", - mlModelContent); + assertEquals( + "{\"name\":\"test_model_name\",\"model_group_id\":\"test_group_id\"," + + "\"algorithm\":\"REMOTE\",\"model_version\":\"1.0.0\",\"description\":\"test model\"," + + "\"connector\":{\"name\":\"test_connector_name\",\"version\":\"1\"," + + "\"description\":\"this is a test connector\",\"protocol\":\"http\"," + + "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," + + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + + "\"headers\":{\"api_key\":\"${credential.key}\"}," + + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + + "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\"," + + "\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000," + + "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}}", + mlModelContent + ); } @Test public void parse_ConnectorId() throws IOException { - MLModel mlModel = MLModel.builder() - .algorithm(FunctionName.REMOTE) - .name("test_model_name") - .version("1.0.0") - .modelGroupId("test_group_id") - .description("test model") - .connectorId("test_connector_id") - .build(); + MLModel mlModel = MLModel + .builder() + .algorithm(FunctionName.REMOTE) + .name("test_model_name") + .version("1.0.0") + .modelGroupId("test_group_id") + .description("test model") + .connectorId("test_connector_id") + .build(); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); mlModel.toXContent(builder, EMPTY_PARAMS); String jsonStr = TestHelper.xContentBuilderToString(builder); - XContentParser parser = XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); parser.nextToken(); MLModel parsedModel = MLModel.parse(parser, FunctionName.REMOTE.name()); Assert.assertNull(parsedModel.getConnector()); @@ -94,49 +105,53 @@ public void parse_ConnectorId() throws IOException { @Test public void parse_InternalConnector() throws IOException { Connector connector = HttpConnectorTest.createHttpConnector(); - MLModel mlModel = MLModel.builder() - .algorithm(FunctionName.REMOTE) - .name("test_model_name") - .version("1.0.0") - .modelGroupId("test_group_id") - .description("test model") - .connector(connector) - .build(); + MLModel mlModel = MLModel + .builder() + .algorithm(FunctionName.REMOTE) + .name("test_model_name") + .version("1.0.0") + .modelGroupId("test_group_id") + .description("test model") + .connector(connector) + .build(); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); mlModel.toXContent(builder, EMPTY_PARAMS); String jsonStr = TestHelper.xContentBuilderToString(builder); - XContentParser parser = XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); parser.nextToken(); MLModel parsedModel = MLModel.parse(parser, FunctionName.REMOTE.name()); Assert.assertEquals(mlModel.getConnector(), parsedModel.getConnector()); } - @Test public void readInputStream_ConnectorId() throws IOException { - MLModel mlModel = MLModel.builder() - .algorithm(FunctionName.REMOTE) - .name("test_model_name") - .version("1.0.0") - .modelGroupId("test_group_id") - .description("test model") - .connectorId("test_connector_id") - .build(); + MLModel mlModel = MLModel + .builder() + .algorithm(FunctionName.REMOTE) + .name("test_model_name") + .version("1.0.0") + .modelGroupId("test_group_id") + .description("test model") + .connectorId("test_connector_id") + .build(); readInputStream(mlModel); } @Test public void readInputStream_InternalConnector() throws IOException { Connector connector = HttpConnectorTest.createHttpConnector(); - MLModel mlModel = MLModel.builder() - .algorithm(FunctionName.REMOTE) - .name("test_model_name") - .version("1.0.0") - .modelGroupId("test_group_id") - .description("test model") - .connector(connector) - .build(); + MLModel mlModel = MLModel + .builder() + .algorithm(FunctionName.REMOTE) + .name("test_model_name") + .version("1.0.0") + .modelGroupId("test_group_id") + .description("test model") + .connector(connector) + .build(); readInputStream(mlModel); } diff --git a/common/src/test/java/org/opensearch/ml/common/TestHelper.java b/common/src/test/java/org/opensearch/ml/common/TestHelper.java index 81810aef92..9dffad66a1 100644 --- a/common/src/test/java/org/opensearch/ml/common/TestHelper.java +++ b/common/src/test/java/org/opensearch/ml/common/TestHelper.java @@ -5,26 +5,27 @@ package org.opensearch.ml.common; -import org.opensearch.core.common.bytes.BytesReference; +import java.io.IOException; +import java.util.function.Function; + import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import java.io.IOException; -import java.util.function.Function; - public class TestHelper { public static void testParse(ToXContentObject obj, Function function) throws IOException { testParse(obj, function, false); } - public static void testParse(ToXContentObject obj, Function function, boolean wrapWithObject) throws IOException { + public static void testParse(ToXContentObject obj, Function function, boolean wrapWithObject) + throws IOException { XContentBuilder builder = XContentFactory.jsonBuilder(); if (wrapWithObject) { builder.startObject(); @@ -37,8 +38,11 @@ public static void testParse(ToXContentObject obj, Function void testParseFromString(ToXContentObject obj, String jsonStr, Function function) throws IOException { - XContentParser parser = XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); + public static void testParseFromString(ToXContentObject obj, String jsonStr, Function function) + throws IOException { + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); parser.nextToken(); T parsedObj = function.apply(parser); obj.equals(parsedObj); diff --git a/common/src/test/java/org/opensearch/ml/common/ToolMetadataTests.java b/common/src/test/java/org/opensearch/ml/common/ToolMetadataTests.java index 02234757b3..8d71fa1fc3 100644 --- a/common/src/test/java/org/opensearch/ml/common/ToolMetadataTests.java +++ b/common/src/test/java/org/opensearch/ml/common/ToolMetadataTests.java @@ -4,6 +4,12 @@ */ package org.opensearch.ml.common; +import static org.junit.Assert.assertEquals; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; +import java.util.function.Function; + import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -15,12 +21,6 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import java.io.IOException; -import java.util.function.Function; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; - public class ToolMetadataTests { ToolMetadata toolMetadata; @@ -28,12 +28,13 @@ public class ToolMetadataTests { @Before public void setUp() { - toolMetadata = ToolMetadata.builder() - .name("MathTool") - .description("Use this tool to calculate any math problem.") - .type("MathTool") - .version("test") - .build(); + toolMetadata = ToolMetadata + .builder() + .name("MathTool") + .description("Use this tool to calculate any math problem.") + .type("MathTool") + .version("test") + .build(); function = parser -> { try { @@ -49,7 +50,10 @@ public void toXContent() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); toolMetadata.toXContent(builder, EMPTY_PARAMS); String toolMetadataString = TestHelper.xContentBuilderToString(builder); - assertEquals(toolMetadataString, "{\"name\":\"MathTool\",\"description\":\"Use this tool to calculate any math problem.\",\"type\":\"MathTool\",\"version\":\"test\"}"); + assertEquals( + toolMetadataString, + "{\"name\":\"MathTool\",\"description\":\"Use this tool to calculate any math problem.\",\"type\":\"MathTool\",\"version\":\"test\"}" + ); } @Test @@ -66,13 +70,13 @@ public void parse() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); toolMetadata.toXContent(builder, EMPTY_PARAMS); String toolMetadataString = TestHelper.xContentBuilderToString(builder); - XContentParser parser = XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY, - LoggingDeprecationHandler.INSTANCE, toolMetadataString); + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, toolMetadataString); parser.nextToken(); toolMetadata.equals(function.apply(parser)); } - @Test public void readInputStream_Success() throws IOException { readInputStream(toolMetadata); diff --git a/common/src/test/java/org/opensearch/ml/common/agent/LLMSpecTest.java b/common/src/test/java/org/opensearch/ml/common/agent/LLMSpecTest.java index 0964efb7d1..7ebde1d183 100644 --- a/common/src/test/java/org/opensearch/ml/common/agent/LLMSpecTest.java +++ b/common/src/test/java/org/opensearch/ml/common/agent/LLMSpecTest.java @@ -1,5 +1,11 @@ package org.opensearch.ml.common.agent; +import static org.junit.Assert.*; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; + import org.junit.Assert; import org.junit.Rule; import org.junit.Test; @@ -12,15 +18,8 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; -import org.opensearch.ml.common.connector.HttpConnector; import org.opensearch.search.SearchModule; -import java.io.IOException; -import java.util.Collections; -import java.util.Map; - -import static org.junit.Assert.*; - public class LLMSpecTest { @Rule @@ -69,8 +68,13 @@ public void toXContent() throws IOException { @Test public void parse() throws IOException { String jsonStr = "{\"model_id\":\"test_model\",\"parameters\":{\"test_key\":\"test_value\"}}"; - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); parser.nextToken(); LLMSpec spec = LLMSpec.parse(parser); @@ -88,4 +92,4 @@ public void fromStream() throws IOException { Assert.assertEquals(spec.getModelId(), spec1.getModelId()); Assert.assertEquals(spec.getParameters(), spec1.getParameters()); } -} \ No newline at end of file +} diff --git a/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java b/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java index 8b1a96e07b..b83758fc23 100644 --- a/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java +++ b/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java @@ -5,6 +5,14 @@ package org.opensearch.ml.common.agent; +import static org.junit.Assert.*; + +import java.io.IOException; +import java.time.Instant; +import java.util.Collections; +import java.util.List; +import java.util.Map; + import org.junit.Assert; import org.junit.Rule; import org.junit.Test; @@ -23,14 +31,6 @@ import org.opensearch.ml.common.TestHelper; import org.opensearch.search.SearchModule; -import java.io.IOException; -import java.time.Instant; -import java.util.Collections; -import java.util.List; -import java.util.Map; - -import static org.junit.Assert.*; - public class MLAgentTest { @Rule @@ -41,7 +41,19 @@ public void constructor_NullName() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("Agent name can't be null"); - MLAgent agent = new MLAgent(null, MLAgentType.CONVERSATIONAL.name(), "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), null, null, Instant.EPOCH, Instant.EPOCH, "test", false); + MLAgent agent = new MLAgent( + null, + MLAgentType.CONVERSATIONAL.name(), + "test", + new LLMSpec("test_model", Map.of("test_key", "test_value")), + List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test", + false + ); } @Test @@ -49,7 +61,19 @@ public void constructor_NullType() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("Agent type can't be null"); - MLAgent agent = new MLAgent("test_agent", null, "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), null, null, Instant.EPOCH, Instant.EPOCH, "test", false); + MLAgent agent = new MLAgent( + "test_agent", + null, + "test", + new LLMSpec("test_model", Map.of("test_key", "test_value")), + List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test", + false + ); } @Test @@ -57,7 +81,19 @@ public void constructor_NullLLMSpec() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("We need model information for the conversational agent type"); - MLAgent agent = new MLAgent("test_agent", MLAgentType.CONVERSATIONAL.name(), "test", null, List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), null, null, Instant.EPOCH, Instant.EPOCH, "test", false); + MLAgent agent = new MLAgent( + "test_agent", + MLAgentType.CONVERSATIONAL.name(), + "test", + null, + List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test", + false + ); } @Test @@ -65,12 +101,36 @@ public void constructor_DuplicateTool() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("Duplicate tool defined: test_tool_name"); MLToolSpec mlToolSpec = new MLToolSpec("test_tool_type", "test_tool_name", "test", Collections.EMPTY_MAP, false); - MLAgent agent = new MLAgent("test_name", MLAgentType.CONVERSATIONAL.name(), "test_description", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(mlToolSpec, mlToolSpec), null, null, Instant.EPOCH, Instant.EPOCH, "test", false); + MLAgent agent = new MLAgent( + "test_name", + MLAgentType.CONVERSATIONAL.name(), + "test_description", + new LLMSpec("test_model", Map.of("test_key", "test_value")), + List.of(mlToolSpec, mlToolSpec), + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test", + false + ); } @Test public void writeTo() throws IOException { - MLAgent agent = new MLAgent("test", "CONVERSATIONAL", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test", false); + MLAgent agent = new MLAgent( + "test", + "CONVERSATIONAL", + "test", + new LLMSpec("test_model", Map.of("test_key", "test_value")), + List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), + Map.of("test", "test"), + new MLMemorySpec("test", "123", 0), + Instant.EPOCH, + Instant.EPOCH, + "test", + false + ); BytesStreamOutput output = new BytesStreamOutput(); agent.writeTo(output); MLAgent agent1 = new MLAgent(output.bytes().streamInput()); @@ -85,7 +145,19 @@ public void writeTo() throws IOException { @Test public void writeTo_NullLLM() throws IOException { - MLAgent agent = new MLAgent("test", "FLOW", "test", null, List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test", false); + MLAgent agent = new MLAgent( + "test", + "FLOW", + "test", + null, + List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), + Map.of("test", "test"), + new MLMemorySpec("test", "123", 0), + Instant.EPOCH, + Instant.EPOCH, + "test", + false + ); BytesStreamOutput output = new BytesStreamOutput(); agent.writeTo(output); MLAgent agent1 = new MLAgent(output.bytes().streamInput()); @@ -95,7 +167,19 @@ public void writeTo_NullLLM() throws IOException { @Test public void writeTo_NullTools() throws IOException { - MLAgent agent = new MLAgent("test", "FLOW", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test", false); + MLAgent agent = new MLAgent( + "test", + "FLOW", + "test", + new LLMSpec("test_model", Map.of("test_key", "test_value")), + List.of(), + Map.of("test", "test"), + new MLMemorySpec("test", "123", 0), + Instant.EPOCH, + Instant.EPOCH, + "test", + false + ); BytesStreamOutput output = new BytesStreamOutput(); agent.writeTo(output); MLAgent agent1 = new MLAgent(output.bytes().streamInput()); @@ -105,7 +189,19 @@ public void writeTo_NullTools() throws IOException { @Test public void writeTo_NullParameters() throws IOException { - MLAgent agent = new MLAgent("test", MLAgentType.CONVERSATIONAL.name(), "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), null, new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test", false); + MLAgent agent = new MLAgent( + "test", + MLAgentType.CONVERSATIONAL.name(), + "test", + new LLMSpec("test_model", Map.of("test_key", "test_value")), + List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), + null, + new MLMemorySpec("test", "123", 0), + Instant.EPOCH, + Instant.EPOCH, + "test", + false + ); BytesStreamOutput output = new BytesStreamOutput(); agent.writeTo(output); MLAgent agent1 = new MLAgent(output.bytes().streamInput()); @@ -115,7 +211,19 @@ public void writeTo_NullParameters() throws IOException { @Test public void writeTo_NullMemory() throws IOException { - MLAgent agent = new MLAgent("test", "CONVERSATIONAL", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), Map.of("test", "test"), null, Instant.EPOCH, Instant.EPOCH, "test", false); + MLAgent agent = new MLAgent( + "test", + "CONVERSATIONAL", + "test", + new LLMSpec("test_model", Map.of("test_key", "test_value")), + List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), + Map.of("test", "test"), + null, + Instant.EPOCH, + Instant.EPOCH, + "test", + false + ); BytesStreamOutput output = new BytesStreamOutput(); agent.writeTo(output); MLAgent agent1 = new MLAgent(output.bytes().streamInput()); @@ -125,20 +233,39 @@ public void writeTo_NullMemory() throws IOException { @Test public void toXContent() throws IOException { - MLAgent agent = new MLAgent("test", "CONVERSATIONAL", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Map.of("test", "test"), false)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test", false); + MLAgent agent = new MLAgent( + "test", + "CONVERSATIONAL", + "test", + new LLMSpec("test_model", Map.of("test_key", "test_value")), + List.of(new MLToolSpec("test", "test", "test", Map.of("test", "test"), false)), + Map.of("test", "test"), + new MLMemorySpec("test", "123", 0), + Instant.EPOCH, + Instant.EPOCH, + "test", + false + ); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); agent.toXContent(builder, ToXContent.EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); - String expectedStr = "{\"name\":\"test\",\"type\":\"CONVERSATIONAL\",\"description\":\"test\",\"llm\":{\"model_id\":\"test_model\",\"parameters\":{\"test_key\":\"test_value\"}},\"tools\":[{\"type\":\"test\",\"name\":\"test\",\"description\":\"test\",\"parameters\":{\"test\":\"test\"},\"include_output_in_agent_response\":false}],\"parameters\":{\"test\":\"test\"},\"memory\":{\"type\":\"test\",\"window_size\":0,\"session_id\":\"123\"},\"created_time\":0,\"last_updated_time\":0,\"app_type\":\"test\",\"is_hidden\":false}"; + String expectedStr = + "{\"name\":\"test\",\"type\":\"CONVERSATIONAL\",\"description\":\"test\",\"llm\":{\"model_id\":\"test_model\",\"parameters\":{\"test_key\":\"test_value\"}},\"tools\":[{\"type\":\"test\",\"name\":\"test\",\"description\":\"test\",\"parameters\":{\"test\":\"test\"},\"include_output_in_agent_response\":false}],\"parameters\":{\"test\":\"test\"},\"memory\":{\"type\":\"test\",\"window_size\":0,\"session_id\":\"123\"},\"created_time\":0,\"last_updated_time\":0,\"app_type\":\"test\",\"is_hidden\":false}"; Assert.assertEquals(content, expectedStr); } @Test public void parse() throws IOException { - String jsonStr = "{\"name\":\"test\",\"type\":\"CONVERSATIONAL\",\"description\":\"test\",\"llm\":{\"model_id\":\"test_model\",\"parameters\":{\"test_key\":\"test_value\"}},\"tools\":[{\"type\":\"test\",\"name\":\"test\",\"description\":\"test\",\"parameters\":{\"test\":\"test\"},\"include_output_in_agent_response\":false}],\"parameters\":{\"test\":\"test\"},\"memory\":{\"type\":\"test\",\"window_size\":0,\"session_id\":\"123\"},\"created_time\":0,\"last_updated_time\":0,\"app_type\":\"test\",\"is_hidden\":false}"; - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); + String jsonStr = + "{\"name\":\"test\",\"type\":\"CONVERSATIONAL\",\"description\":\"test\",\"llm\":{\"model_id\":\"test_model\",\"parameters\":{\"test_key\":\"test_value\"}},\"tools\":[{\"type\":\"test\",\"name\":\"test\",\"description\":\"test\",\"parameters\":{\"test\":\"test\"},\"include_output_in_agent_response\":false}],\"parameters\":{\"test\":\"test\"},\"memory\":{\"type\":\"test\",\"window_size\":0,\"session_id\":\"123\"},\"created_time\":0,\"last_updated_time\":0,\"app_type\":\"test\",\"is_hidden\":false}"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); parser.nextToken(); MLAgent agent = MLAgent.parse(parser); @@ -162,7 +289,19 @@ public void parse() throws IOException { @Test public void fromStream() throws IOException { - MLAgent agent = new MLAgent("test", MLAgentType.CONVERSATIONAL.name(), "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test", false); + MLAgent agent = new MLAgent( + "test", + MLAgentType.CONVERSATIONAL.name(), + "test", + new LLMSpec("test_model", Map.of("test_key", "test_value")), + List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), + Map.of("test", "test"), + new MLMemorySpec("test", "123", 0), + Instant.EPOCH, + Instant.EPOCH, + "test", + false + ); BytesStreamOutput output = new BytesStreamOutput(); agent.writeTo(output); MLAgent agent1 = MLAgent.fromStream(output.bytes().streamInput()); @@ -186,7 +325,19 @@ public void constructor_InvalidAgentType() { @Test public void constructor_NonConversationalNoLLM() { try { - MLAgent agent = new MLAgent("test_name", MLAgentType.FLOW.name(), "test_description", null, null, null, null, Instant.EPOCH, Instant.EPOCH, "test", false); + MLAgent agent = new MLAgent( + "test_name", + MLAgentType.FLOW.name(), + "test_description", + null, + null, + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test", + false + ); assertNotNull(agent); // Ensuring object creation was successful without throwing an exception } catch (IllegalArgumentException e) { fail("Should not throw an exception for non-conversational types without LLM"); @@ -218,8 +369,13 @@ public void writeTo_ReadFrom_HiddenFlag_VersionCompatibility() throws IOExceptio @Test public void parse_MissingFields() throws IOException { String jsonStr = "{\"name\":\"test\",\"type\":\"FLOW\"}"; - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); parser.nextToken(); MLAgent agent = MLAgent.parse(parser); diff --git a/common/src/test/java/org/opensearch/ml/common/agent/MLMemorySpecTest.java b/common/src/test/java/org/opensearch/ml/common/agent/MLMemorySpecTest.java index 2d028985e0..78e9d62ea1 100644 --- a/common/src/test/java/org/opensearch/ml/common/agent/MLMemorySpecTest.java +++ b/common/src/test/java/org/opensearch/ml/common/agent/MLMemorySpecTest.java @@ -1,5 +1,10 @@ package org.opensearch.ml.common.agent; +import static org.junit.Assert.*; + +import java.io.IOException; +import java.util.Collections; + import org.junit.Assert; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -12,12 +17,6 @@ import org.opensearch.ml.common.TestHelper; import org.opensearch.search.SearchModule; -import java.io.IOException; -import java.util.Collections; -import java.util.Map; - -import static org.junit.Assert.*; - public class MLMemorySpecTest { @Test @@ -45,8 +44,13 @@ public void toXContent() throws IOException { @Test public void parse() throws IOException { String jsonStr = "{\"type\":\"test\",\"window_size\":0,\"session_id\":\"123\"}"; - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); parser.nextToken(); MLMemorySpec spec = MLMemorySpec.parse(parser); @@ -66,4 +70,4 @@ public void fromStream() throws IOException { Assert.assertEquals(spec.getSessionId(), spec1.getSessionId()); Assert.assertEquals(spec.getWindowSize(), spec1.getWindowSize()); } -} \ No newline at end of file +} diff --git a/common/src/test/java/org/opensearch/ml/common/agent/MLToolSpecTest.java b/common/src/test/java/org/opensearch/ml/common/agent/MLToolSpecTest.java index d831611035..3d4d9a2ce5 100644 --- a/common/src/test/java/org/opensearch/ml/common/agent/MLToolSpecTest.java +++ b/common/src/test/java/org/opensearch/ml/common/agent/MLToolSpecTest.java @@ -1,5 +1,11 @@ package org.opensearch.ml.common.agent; +import static org.junit.Assert.*; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; + import org.junit.Assert; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -12,12 +18,6 @@ import org.opensearch.ml.common.TestHelper; import org.opensearch.search.SearchModule; -import java.io.IOException; -import java.util.Collections; -import java.util.Map; - -import static org.junit.Assert.*; - public class MLToolSpecTest { @Test @@ -41,14 +41,24 @@ public void toXContent() throws IOException { spec.toXContent(builder, ToXContent.EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); - Assert.assertEquals("{\"type\":\"test\",\"name\":\"test\",\"description\":\"test\",\"parameters\":{\"test\":\"test\"},\"include_output_in_agent_response\":false}", content); + Assert + .assertEquals( + "{\"type\":\"test\",\"name\":\"test\",\"description\":\"test\",\"parameters\":{\"test\":\"test\"},\"include_output_in_agent_response\":false}", + content + ); } @Test public void parse() throws IOException { - String jsonStr = "{\"type\":\"test\",\"name\":\"test\",\"description\":\"test\",\"parameters\":{\"test\":\"test\"},\"include_output_in_agent_response\":false}"; - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); + String jsonStr = + "{\"type\":\"test\",\"name\":\"test\",\"description\":\"test\",\"parameters\":{\"test\":\"test\"},\"include_output_in_agent_response\":false}"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); parser.nextToken(); MLToolSpec spec = MLToolSpec.parse(parser); @@ -72,4 +82,4 @@ public void fromStream() throws IOException { Assert.assertEquals(spec.getDescription(), spec1.getDescription()); Assert.assertEquals(spec.isIncludeOutputInAgentResponse(), spec1.isIncludeOutputInAgentResponse()); } -} \ No newline at end of file +} diff --git a/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java index 36a964cef1..a60d3ac1cf 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java @@ -5,6 +5,21 @@ package org.opensearch.ml.common.connector; +import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD; +import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD; +import static org.opensearch.ml.common.connector.AbstractConnector.SESSION_TOKEN_FIELD; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; +import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD; +import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; +import java.util.function.Function; + import org.junit.Assert; import org.junit.Before; import org.junit.Rule; @@ -20,21 +35,6 @@ import org.opensearch.ml.common.TestHelper; import org.opensearch.search.SearchModule; -import java.io.IOException; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.Locale; -import java.util.Map; -import java.util.function.Function; - -import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD; -import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD; -import static org.opensearch.ml.common.connector.AbstractConnector.SESSION_TOKEN_FIELD; -import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; -import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD; -import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD; - public class AwsConnectorTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -44,8 +44,8 @@ public class AwsConnectorTest { @Before public void setUp() { - encryptFunction = s -> "encrypted: "+s.toLowerCase(Locale.ROOT); - decryptFunction = s -> "decrypted: "+s.toUpperCase(Locale.ROOT); + encryptFunction = s -> "encrypted: " + s.toLowerCase(Locale.ROOT); + decryptFunction = s -> "decrypted: " + s.toUpperCase(Locale.ROOT); } @Test @@ -107,7 +107,12 @@ public void constructor_NoPredictAction() { credential.put(REGION_FIELD, "test_region"); Map parameters = new HashMap<>(); parameters.put(SERVICE_NAME_FIELD, "test_service"); - AwsConnector connector = AwsConnector.awsConnectorBuilder().protocol(ConnectorProtocols.AWS_SIGV4).credential(credential).parameters(parameters).build(); + AwsConnector connector = AwsConnector + .awsConnectorBuilder() + .protocol(ConnectorProtocols.AWS_SIGV4) + .credential(credential) + .parameters(parameters) + .build(); Assert.assertNotNull(connector); connector.encrypt(encryptFunction); @@ -126,8 +131,13 @@ public void constructor_Parser() throws IOException { awsConnector.toXContent(builder, ToXContent.EMPTY_PARAMS); String jsonStr = TestHelper.xContentBuilderToString(builder); - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); parser.nextToken(); AwsConnector connector = new AwsConnector(awsConnector.getProtocol(), parser); @@ -210,19 +220,28 @@ private AwsConnector createAwsConnector(Map parameters, Map ConnectorClientConfig.parse(parser)); @@ -147,13 +159,12 @@ public void testDefaultValues() { public void testDefaultValuesInitByNewInstance() { ConnectorClientConfig config = new ConnectorClientConfig(); - Assert.assertEquals(Integer.valueOf(30),config.getMaxConnections()); - Assert.assertEquals(Integer.valueOf(30000),config.getConnectionTimeout()); - Assert.assertEquals(Integer.valueOf(30000),config.getReadTimeout()); - Assert.assertEquals(Integer.valueOf(200),config.getRetryBackoffMillis()); - Assert.assertEquals(Integer.valueOf(30),config.getRetryTimeoutSeconds()); - Assert.assertEquals(Integer.valueOf(0),config.getMaxRetryTimes()); + Assert.assertEquals(Integer.valueOf(30), config.getMaxConnections()); + Assert.assertEquals(Integer.valueOf(30000), config.getConnectionTimeout()); + Assert.assertEquals(Integer.valueOf(30000), config.getReadTimeout()); + Assert.assertEquals(Integer.valueOf(200), config.getRetryBackoffMillis()); + Assert.assertEquals(Integer.valueOf(30), config.getRetryTimeoutSeconds()); + Assert.assertEquals(Integer.valueOf(0), config.getMaxRetryTimes()); Assert.assertEquals(RetryBackoffPolicy.CONSTANT, config.getRetryBackoffPolicy()); } } - diff --git a/common/src/test/java/org/opensearch/ml/common/connector/ConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/ConnectorTest.java index 93285ec1a6..81c4a73866 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/ConnectorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/ConnectorTest.java @@ -1,5 +1,11 @@ package org.opensearch.ml.common.connector; +import static org.opensearch.ml.common.connector.HttpConnectorTest.createHttpConnector; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; + import org.junit.Assert; import org.junit.Rule; import org.junit.Test; @@ -14,12 +20,6 @@ import org.opensearch.ml.common.TestHelper; import org.opensearch.search.SearchModule; -import java.io.IOException; -import java.util.Arrays; -import java.util.Collections; - -import static org.opensearch.ml.common.connector.HttpConnectorTest.createHttpConnector; - public class ConnectorTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -50,8 +50,13 @@ public void createConnector_Parser() throws IOException { connector.toXContent(builder, ToXContent.EMPTY_PARAMS); String jsonStr = TestHelper.xContentBuilderToString(builder); - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); parser.nextToken(); Connector connector2 = Connector.createConnector(parser); @@ -63,19 +68,30 @@ public void validateConnectorURL_Invalid() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("Connector URL is not matching the trusted connector endpoint regex"); HttpConnector connector = createHttpConnector(); - connector.validateConnectorURL(Arrays.asList("^https://runtime\\.sagemaker\\..*[a-z0-9-]\\.amazonaws\\.com/.*$", - "^https://api\\.openai\\.com/.*$", - "^https://api\\.cohere\\.ai/.*$", - "^https://bedrock-agent-runtime\\\\..*[a-z0-9-]\\\\.amazonaws\\\\.com/.*$" - )); + connector + .validateConnectorURL( + Arrays + .asList( + "^https://runtime\\.sagemaker\\..*[a-z0-9-]\\.amazonaws\\.com/.*$", + "^https://api\\.openai\\.com/.*$", + "^https://api\\.cohere\\.ai/.*$", + "^https://bedrock-agent-runtime\\\\..*[a-z0-9-]\\\\.amazonaws\\\\.com/.*$" + ) + ); } @Test public void validateConnectorURL() { HttpConnector connector = createHttpConnector(); - connector.validateConnectorURL(Arrays.asList("^https://runtime\\.sagemaker\\..*[a-z0-9-]\\.amazonaws\\.com/.*$", - "^https://api\\.openai\\.com/.*$", - "^https://bedrock-agent-runtime\\\\..*[a-z0-9-]\\\\.amazonaws\\\\.com/.*$", - "^" + connector.getActions().get(0).getUrl())); + connector + .validateConnectorURL( + Arrays + .asList( + "^https://runtime\\.sagemaker\\..*[a-z0-9-]\\.amazonaws\\.com/.*$", + "^https://api\\.openai\\.com/.*$", + "^https://bedrock-agent-runtime\\\\..*[a-z0-9-]\\\\.amazonaws\\\\.com/.*$", + "^" + connector.getActions().get(0).getUrl() + ) + ); } } diff --git a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java index c25f9653c3..0115ac1376 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java @@ -5,6 +5,18 @@ package org.opensearch.ml.common.connector; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.function.Function; + import org.junit.Assert; import org.junit.Before; import org.junit.Rule; @@ -22,18 +34,6 @@ import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.search.SearchModule; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.function.Function; - -import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; - public class HttpConnectorTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -41,22 +41,22 @@ public class HttpConnectorTest { Function encryptFunction; Function decryptFunction; - String TEST_CONNECTOR_JSON_STRING = "{\"name\":\"test_connector_name\",\"version\":\"1\"," + - "\"description\":\"this is a test connector\",\"protocol\":\"http\"," + - "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," + - "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + - "\"headers\":{\"api_key\":\"${credential.key}\"}," + - "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + - "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + - "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + - "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\"," + - "\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000," + - "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}"; + String TEST_CONNECTOR_JSON_STRING = "{\"name\":\"test_connector_name\",\"version\":\"1\"," + + "\"description\":\"this is a test connector\",\"protocol\":\"http\"," + + "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," + + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + + "\"headers\":{\"api_key\":\"${credential.key}\"}," + + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + + "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\"," + + "\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000," + + "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}"; @Before public void setUp() { - encryptFunction = s -> "encrypted: "+s.toLowerCase(Locale.ROOT); - decryptFunction = s -> "decrypted: "+s.toUpperCase(Locale.ROOT); + encryptFunction = s -> "encrypted: " + s.toLowerCase(Locale.ROOT); + decryptFunction = s -> "decrypted: " + s.toUpperCase(Locale.ROOT); } @Test @@ -88,12 +88,16 @@ public void toXContent() throws IOException { Assert.assertEquals(TEST_CONNECTOR_JSON_STRING, content); } - @Test public void constructor_Parser() throws IOException { - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, TEST_CONNECTOR_JSON_STRING); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + TEST_CONNECTOR_JSON_STRING + ); parser.nextToken(); HttpConnector connector = new HttpConnector("http", parser); @@ -287,7 +291,15 @@ public static HttpConnector createHttpConnectorWithRequestBody(String requestBod String preProcessFunction = MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT; String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING; - ConnectorAction action = new ConnectorAction(actionType, method, url, headers, requestBody, preProcessFunction, postProcessFunction); + ConnectorAction action = new ConnectorAction( + actionType, + method, + url, + headers, + requestBody, + preProcessFunction, + postProcessFunction + ); Map parameters = new HashMap<>(); parameters.put("input", "test input value"); @@ -297,18 +309,19 @@ public static HttpConnector createHttpConnectorWithRequestBody(String requestBod ConnectorClientConfig httpClientConfig = new ConnectorClientConfig(30, 30000, 30000, 10, 10, -1, RetryBackoffPolicy.CONSTANT); - HttpConnector connector = HttpConnector.builder() - .name("test_connector_name") - .description("this is a test connector") - .version("1") - .protocol("http") - .parameters(parameters) - .credential(credential) - .actions(Arrays.asList(action)) - .backendRoles(Arrays.asList("role1", "role2")) - .accessMode(AccessMode.PUBLIC) - .connectorClientConfig(httpClientConfig) - .build(); + HttpConnector connector = HttpConnector + .builder() + .name("test_connector_name") + .description("this is a test connector") + .version("1") + .protocol("http") + .parameters(parameters) + .credential(credential) + .actions(Arrays.asList(action)) + .backendRoles(Arrays.asList("role1", "role2")) + .accessMode(AccessMode.PUBLIC) + .connectorClientConfig(httpClientConfig) + .build(); return connector; } diff --git a/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java index 1c60ee8b16..944ea82eae 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java @@ -5,19 +5,19 @@ package org.opensearch.ml.common.connector; -import org.junit.Assert; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; +import static org.opensearch.ml.common.connector.MLPostProcessFunction.BEDROCK_EMBEDDING; +import static org.opensearch.ml.common.connector.MLPostProcessFunction.COHERE_EMBEDDING; +import static org.opensearch.ml.common.connector.MLPostProcessFunction.DEFAULT_EMBEDDING; +import static org.opensearch.ml.common.connector.MLPostProcessFunction.OPENAI_EMBEDDING; import java.util.ArrayList; import java.util.Collections; import java.util.List; -import static org.opensearch.ml.common.connector.MLPostProcessFunction.BEDROCK_EMBEDDING; -import static org.opensearch.ml.common.connector.MLPostProcessFunction.COHERE_EMBEDDING; -import static org.opensearch.ml.common.connector.MLPostProcessFunction.DEFAULT_EMBEDDING; -import static org.opensearch.ml.common.connector.MLPostProcessFunction.OPENAI_EMBEDDING; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; public class MLPostProcessFunctionTest { diff --git a/common/src/test/java/org/opensearch/ml/common/connector/MLPreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/MLPreProcessFunctionTest.java index b3784c1c1c..dfcd2a41a8 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/MLPreProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/MLPreProcessFunctionTest.java @@ -5,11 +5,11 @@ package org.opensearch.ml.common.connector; +import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT; + import org.junit.Assert; import org.junit.Test; -import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT; - public class MLPreProcessFunctionTest { @Test diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockEmbeddingPostProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockEmbeddingPostProcessFunctionTest.java index 224e807031..5a455e0e4b 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockEmbeddingPostProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/BedrockEmbeddingPostProcessFunctionTest.java @@ -5,17 +5,17 @@ package org.opensearch.ml.common.connector.functions.postprocess; +import static org.junit.Assert.assertEquals; + +import java.util.Arrays; +import java.util.List; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensearch.ml.common.output.model.ModelTensor; -import java.util.Arrays; -import java.util.List; - -import static org.junit.Assert.assertEquals; - public class BedrockEmbeddingPostProcessFunctionTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/CohereRerankPostProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/CohereRerankPostProcessFunctionTest.java index 5e8cfd4319..45f5696912 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/CohereRerankPostProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/CohereRerankPostProcessFunctionTest.java @@ -5,17 +5,17 @@ package org.opensearch.ml.common.connector.functions.postprocess; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; -import org.opensearch.ml.common.output.model.ModelTensor; +import static org.junit.Assert.assertEquals; import java.util.Arrays; import java.util.List; import java.util.Map; -import static org.junit.Assert.assertEquals; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.ml.common.output.model.ModelTensor; public class CohereRerankPostProcessFunctionTest { @Rule @@ -51,10 +51,12 @@ public void process_WrongInput_NotCorrectMap() { @Test public void process_CorrectInput() { - List> rerankResults = List.of( + List> rerankResults = List + .of( Map.of("index", 2, "relevance_score", 0.5), Map.of("index", 1, "relevance_score", 0.4), - Map.of("index", 0, "relevance_score", 0.3)); + Map.of("index", 0, "relevance_score", 0.3) + ); List result = function.apply(rerankResults); assertEquals(3, result.size()); assertEquals(1, result.get(0).getData().length); diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunctionTest.java index 01240759ca..b2abf33216 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/postprocess/EmbeddingPostProcessFunctionTest.java @@ -5,17 +5,17 @@ package org.opensearch.ml.common.connector.functions.postprocess; +import static org.junit.Assert.assertEquals; + +import java.util.Arrays; +import java.util.List; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensearch.ml.common.output.model.ModelTensor; -import java.util.Arrays; -import java.util.List; - -import static org.junit.Assert.assertEquals; - public class EmbeddingPostProcessFunctionTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunctionTest.java index eb50befdf9..851d7eaab7 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunctionTest.java @@ -5,6 +5,11 @@ package org.opensearch.ml.common.connector.functions.preprocess; +import static org.junit.Assert.assertEquals; + +import java.util.Arrays; +import java.util.Map; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -15,11 +20,6 @@ import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; -import java.util.Arrays; -import java.util.Map; - -import static org.junit.Assert.assertEquals; - public class BedrockEmbeddingPreProcessFunctionTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/CohereEmbeddingPreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/CohereEmbeddingPreProcessFunctionTest.java index f739796ae8..72b208d859 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/CohereEmbeddingPreProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/CohereEmbeddingPreProcessFunctionTest.java @@ -5,6 +5,10 @@ package org.opensearch.ml.common.connector.functions.preprocess; +import static org.junit.Assert.assertEquals; + +import java.util.Arrays; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -15,10 +19,6 @@ import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; -import java.util.Arrays; - -import static org.junit.Assert.assertEquals; - public class CohereEmbeddingPreProcessFunctionTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/CohereRerankPreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/CohereRerankPreProcessFunctionTest.java index d8a6f4d311..d41b5c60f1 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/CohereRerankPreProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/CohereRerankPreProcessFunctionTest.java @@ -5,6 +5,10 @@ package org.opensearch.ml.common.connector.functions.preprocess; +import static org.junit.Assert.assertEquals; + +import java.util.Arrays; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -15,10 +19,6 @@ import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; -import java.util.Arrays; - -import static org.junit.Assert.assertEquals; - public class CohereRerankPreProcessFunctionTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunctionTest.java index 2e344cbd0f..0eaef7f6a5 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunctionTest.java @@ -5,6 +5,12 @@ package org.opensearch.ml.common.connector.functions.preprocess; +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +import java.util.Arrays; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -18,12 +24,6 @@ import org.opensearch.ml.common.input.MLInput; import org.opensearch.script.ScriptService; -import java.util.Arrays; - -import static org.junit.Assert.assertEquals; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.when; - public class DefaultPreProcessFunctionTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -58,8 +58,7 @@ public void process_NullInput() { public void process_CorrectInput_WrongProcessedResult() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("Preprocess function output is null"); - when(scriptService.compile(any(), any())) - .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(null)); + when(scriptService.compile(any(), any())).then(invocation -> new TestTemplateService.MockTemplateScript.Factory(null)); MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); functionWithConvertToJsonString.apply(mlInput); } @@ -68,8 +67,7 @@ public void process_CorrectInput_WrongProcessedResult() { public void process_CorrectInput_WrongProcessedResult_WithoutConvertToJsonString() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("Preprocess function output is null"); - when(scriptService.compile(any(), any())) - .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(null)); + when(scriptService.compile(any(), any())).then(invocation -> new TestTemplateService.MockTemplateScript.Factory(null)); MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); functionWithoutConvertToJsonString.apply(mlInput); } @@ -77,8 +75,7 @@ public void process_CorrectInput_WrongProcessedResult_WithoutConvertToJsonString @Test public void process_CorrectInput() { String preprocessResult = "{\"parameters\": { \"input\": \"test doc1\" } }"; - when(scriptService.compile(any(), any())) - .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult)); + when(scriptService.compile(any(), any())).then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult)); MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); RemoteInferenceInputDataSet dataSet = functionWithConvertToJsonString.apply(mlInput); assertEquals(1, dataSet.getParameters().size()); diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunctionTest.java index 4bc4c4cd8f..6ea8da20f9 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunctionTest.java @@ -5,6 +5,13 @@ package org.opensearch.ml.common.connector.functions.preprocess; +import static org.junit.Assert.assertEquals; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -15,13 +22,6 @@ import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Map; - -import static org.junit.Assert.assertEquals; - public class MultiModalConnectorPreProcessFunctionTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -41,7 +41,10 @@ public void setUp() { function = new MultiModalConnectorPreProcessFunction(); textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(Arrays.asList("hello")).build(); textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build(); - remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("inputText", "value1", "inputImage", "value2")).build(); + remoteInferenceInputDataSet = RemoteInferenceInputDataSet + .builder() + .parameters(Map.of("inputText", "value1", "inputImage", "value2")) + .build(); textEmbeddingInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); textSimilarityInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(textSimilarityInputDataSet).build(); diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/OpenAIEmbeddingPreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/OpenAIEmbeddingPreProcessFunctionTest.java index e4a08ed550..001c602320 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/OpenAIEmbeddingPreProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/OpenAIEmbeddingPreProcessFunctionTest.java @@ -5,6 +5,10 @@ package org.opensearch.ml.common.connector.functions.preprocess; +import static org.junit.Assert.assertEquals; + +import java.util.Arrays; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -15,10 +19,6 @@ import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; -import java.util.Arrays; - -import static org.junit.Assert.assertEquals; - public class OpenAIEmbeddingPreProcessFunctionTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunctionTest.java index e50ad2441b..ebee78ef25 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/RemoteInferencePreProcessFunctionTest.java @@ -5,6 +5,14 @@ package org.opensearch.ml.common.connector.functions.preprocess; +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -18,14 +26,6 @@ import org.opensearch.ml.common.input.MLInput; import org.opensearch.script.ScriptService; -import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; - -import static org.junit.Assert.assertEquals; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.when; - public class RemoteInferencePreProcessFunctionTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -70,8 +70,7 @@ public void process_WrongInput() { public void process_CorrectInput_WrongProcessedResult() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("Preprocess function output is null"); - when(scriptService.compile(any(), any())) - .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(null)); + when(scriptService.compile(any(), any())).then(invocation -> new TestTemplateService.MockTemplateScript.Factory(null)); MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(remoteInferenceInputDataSet).build(); function.apply(mlInput); } @@ -79,8 +78,7 @@ public void process_CorrectInput_WrongProcessedResult() { @Test public void process_CorrectInput() { String preprocessResult = "{\"parameters\": { \"input\": \"test doc1\" } }"; - when(scriptService.compile(any(), any())) - .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult)); + when(scriptService.compile(any(), any())).then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult)); MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(remoteInferenceInputDataSet).build(); RemoteInferenceInputDataSet dataSet = function.apply(mlInput); assertEquals(1, dataSet.getParameters().size()); diff --git a/common/src/test/java/org/opensearch/ml/common/controller/MLControllerTest.java b/common/src/test/java/org/opensearch/ml/common/controller/MLControllerTest.java index da5d6415ff..c2abecb0ed 100644 --- a/common/src/test/java/org/opensearch/ml/common/controller/MLControllerTest.java +++ b/common/src/test/java/org/opensearch/ml/common/controller/MLControllerTest.java @@ -22,11 +22,11 @@ import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; @@ -40,21 +40,17 @@ public class MLControllerTest { private MLController controllerNull; - private final String expectedInputStr = "{\"model_id\":\"testModelId\",\"user_rate_limiter\":" + - "{\"testUser\":{\"limit\":\"1\",\"unit\":\"MILLISECONDS\"}}}"; + private final String expectedInputStr = "{\"model_id\":\"testModelId\",\"user_rate_limiter\":" + + "{\"testUser\":{\"limit\":\"1\",\"unit\":\"MILLISECONDS\"}}}"; @Rule public ExpectedException exceptionRule = ExpectedException.none(); @Before public void setUp() throws Exception { - rateLimiter = MLRateLimiter.builder() - .limit("1") - .unit(TimeUnit.MILLISECONDS) - .build(); + rateLimiter = MLRateLimiter.builder().limit("1").unit(TimeUnit.MILLISECONDS).build(); - controllerNull = MLController.builder() - .modelId("testModelId").build(); + controllerNull = MLController.builder().modelId("testModelId").build(); controller = MLControllerGenerator("testUser", rateLimiter); @@ -64,17 +60,17 @@ public void setUp() throws Exception { public void readInputStreamSuccess() throws IOException { readInputStream(controller, parsedInput -> { assertEquals("testModelId", parsedInput.getModelId()); - assertEquals(controller.getUserRateLimiter().get("testUser").getLimit(), - parsedInput.getUserRateLimiter().get("testUser").getLimit()); + assertEquals( + controller.getUserRateLimiter().get("testUser").getLimit(), + parsedInput.getUserRateLimiter().get("testUser").getLimit() + ); }); } @Test public void readInputStreamSuccessWithNullFields() throws IOException { controller.setUserRateLimiter(null); - readInputStream(controller, parsedInput -> { - assertNull(parsedInput.getUserRateLimiter()); - }); + readInputStream(controller, parsedInput -> { assertNull(parsedInput.getUserRateLimiter()); }); } @Test @@ -97,14 +93,15 @@ public void testToXContentWithNullMLRateLimiterInUserRateLimiter() throws Except // parseWithNullMLRateLimiterInUserRateLimiterFieldWithException test // below. final String expectedOutputStrWithNullField = "{\"model_id\":\"testModelId\",\"user_rate_limiter\":{\"testUser\":null}}"; - MLController controllerWithTestUserAndEmptyRateLimiter = MLController.builder() - .modelId("testModelId") - .userRateLimiter(new HashMap<>() { - { - put("testUser", null); - } - }) - .build(); + MLController controllerWithTestUserAndEmptyRateLimiter = MLController + .builder() + .modelId("testModelId") + .userRateLimiter(new HashMap<>() { + { + put("testUser", null); + } + }) + .build(); String jsonStr = serializationWithToXContent(controllerWithTestUserAndEmptyRateLimiter); assertEquals(expectedOutputStrWithNullField, jsonStr); } @@ -150,10 +147,8 @@ public void parseWithNullUserRateLimiterFieldWithNoException() throws Exception // Notice that this won't throw an IllegalStateException, which is pretty // different from usual public void parseWithTestUserAndEmptyRateLimiterFieldWithNoException() throws Exception { - final String expectedInputStrWithEmptyField = "{\"model_id\":\"testModelId\",\"user_rate_limiter\":" + - "{\"testUser\":{}}}"; - final String expectedOutputStr = "{\"model_id\":\"testModelId\",\"user_rate_limiter\":" + - "{}}"; + final String expectedInputStrWithEmptyField = "{\"model_id\":\"testModelId\",\"user_rate_limiter\":" + "{\"testUser\":{}}}"; + final String expectedOutputStr = "{\"model_id\":\"testModelId\",\"user_rate_limiter\":" + "{}}"; testParseFromJsonString(expectedInputStrWithEmptyField, parsedInput -> { try { assertEquals(expectedOutputStr, serializationWithToXContent(parsedInput)); @@ -166,8 +161,8 @@ public void parseWithTestUserAndEmptyRateLimiterFieldWithNoException() throws Ex @Test public void parseWithNullField() throws Exception { exceptionRule.expect(IllegalStateException.class); - final String expectedInputStrWithNullField = "{\"model_id\":null,\"user_rate_limiter\":" + - "{\"testUser\":{\"limit\":\"1\",\"unit\":\"MILLISECONDS\"}}}"; + final String expectedInputStrWithNullField = "{\"model_id\":null,\"user_rate_limiter\":" + + "{\"testUser\":{\"limit\":\"1\",\"unit\":\"MILLISECONDS\"}}}"; testParseFromJsonString(expectedInputStrWithNullField, parsedInput -> { try { @@ -180,9 +175,9 @@ public void parseWithNullField() throws Exception { @Test public void parseWithIllegalField() throws Exception { - final String expectedInputStrWithIllegalField = "{\"model_id\":\"testModelId\",\"illegal_field\":\"This field need to be skipped.\",\"user_rate_limiter\":" - + - "{\"testUser\":{\"limit\":\"1\",\"unit\":\"MILLISECONDS\"}}}"; + final String expectedInputStrWithIllegalField = + "{\"model_id\":\"testModelId\",\"illegal_field\":\"This field need to be skipped.\",\"user_rate_limiter\":" + + "{\"testUser\":{\"limit\":\"1\",\"unit\":\"MILLISECONDS\"}}}"; testParseFromJsonString(expectedInputStrWithIllegalField, parsedInput -> { try { @@ -213,9 +208,9 @@ public void parseWithNullMLRateLimiterInUserRateLimiterFieldWithException() thro @Test public void parseWithIllegalRateLimiterFieldWithException() throws Exception { exceptionRule.expect(RuntimeException.class); - final String expectedInputStrWithIllegalField = "{\"model_id\":\"testModelId\",\"illegal_field\":\"This field need to be skipped.\",\"user_rate_limiter\":" - + - "{\"testUser\":\"Some illegal content that MLRateLimiter parser cannot parse.\"}}"; + final String expectedInputStrWithIllegalField = + "{\"model_id\":\"testModelId\",\"illegal_field\":\"This field need to be skipped.\",\"user_rate_limiter\":" + + "{\"testUser\":\"Some illegal content that MLRateLimiter parser cannot parse.\"}}"; testParseFromJsonString(expectedInputStrWithIllegalField, parsedInput -> { try { @@ -231,8 +226,7 @@ public void testUserRateLimiterUpdate() { MLRateLimiter rateLimiterWithNumber = MLRateLimiter.builder().limit("1").build(); MLController controllerWithEmptyUserRateLimiter = MLControllerGenerator(); - MLController controllerWithTestUserAndRateLimiterWithNumber = MLControllerGenerator("testUser", - rateLimiterWithNumber); + MLController controllerWithTestUserAndRateLimiterWithNumber = MLControllerGenerator("testUser", rateLimiterWithNumber); MLController controllerWithNewUserAndEmptyRateLimiter = MLControllerGenerator("newUser"); controllerWithEmptyUserRateLimiter.update(controllerNull); @@ -242,16 +236,12 @@ public void testUserRateLimiterUpdate() { assertTrue(controllerWithEmptyUserRateLimiter.getUserRateLimiter().isEmpty()); controllerWithEmptyUserRateLimiter.update(controllerWithTestUserAndRateLimiterWithNumber); - assertEquals("1", controllerWithEmptyUserRateLimiter.getUserRateLimiter().get("testUser") - .getLimit()); - assertNull(controllerWithEmptyUserRateLimiter.getUserRateLimiter().get("testUser") - .getUnit()); + assertEquals("1", controllerWithEmptyUserRateLimiter.getUserRateLimiter().get("testUser").getLimit()); + assertNull(controllerWithEmptyUserRateLimiter.getUserRateLimiter().get("testUser").getUnit()); controllerWithEmptyUserRateLimiter.update(controller); - assertEquals("1", controllerWithEmptyUserRateLimiter.getUserRateLimiter().get("testUser") - .getLimit()); - assertEquals(TimeUnit.MILLISECONDS, controllerWithEmptyUserRateLimiter.getUserRateLimiter() - .get("testUser").getUnit()); + assertEquals("1", controllerWithEmptyUserRateLimiter.getUserRateLimiter().get("testUser").getLimit()); + assertEquals(TimeUnit.MILLISECONDS, controllerWithEmptyUserRateLimiter.getUserRateLimiter().get("testUser").getUnit()); controllerWithEmptyUserRateLimiter.update(controllerWithNewUserAndEmptyRateLimiter); assertTrue(controllerWithEmptyUserRateLimiter.getUserRateLimiter().get("newUser").isEmpty()); @@ -262,38 +252,37 @@ public void testUserRateLimiterIsUpdatable() { MLRateLimiter rateLimiterWithNumber = MLRateLimiter.builder().limit("1").build(); MLController controllerWithEmptyUserRateLimiter = MLControllerGenerator(); - MLController controllerWithTestUserAndRateLimiterWithNumber = MLControllerGenerator("testUser", - rateLimiterWithNumber); - MLController controllerWithNewUserAndRateLimiterWithNumber = MLControllerGenerator("newUser", - rateLimiterWithNumber); + MLController controllerWithTestUserAndRateLimiterWithNumber = MLControllerGenerator("testUser", rateLimiterWithNumber); + MLController controllerWithNewUserAndRateLimiterWithNumber = MLControllerGenerator("newUser", rateLimiterWithNumber); MLController controllerWithNewUserAndEmptyRateLimiter = MLControllerGenerator("newUser"); MLController controllerWithNewUserAndRateLimiter = MLControllerGenerator("newUser", rateLimiter); assertFalse(controllerWithEmptyUserRateLimiter.isDeployRequiredAfterUpdate(null)); assertFalse(controllerWithEmptyUserRateLimiter.isDeployRequiredAfterUpdate(controllerNull)); - assertFalse(controllerWithEmptyUserRateLimiter - .isDeployRequiredAfterUpdate(controllerWithEmptyUserRateLimiter)); - assertFalse(controllerWithEmptyUserRateLimiter - .isDeployRequiredAfterUpdate(controllerWithNewUserAndEmptyRateLimiter)); - - assertFalse(controllerWithEmptyUserRateLimiter - .isDeployRequiredAfterUpdate(controllerWithTestUserAndRateLimiterWithNumber)); - assertFalse(controllerWithTestUserAndRateLimiterWithNumber - .isDeployRequiredAfterUpdate(controllerWithTestUserAndRateLimiterWithNumber)); + assertFalse(controllerWithEmptyUserRateLimiter.isDeployRequiredAfterUpdate(controllerWithEmptyUserRateLimiter)); + assertFalse(controllerWithEmptyUserRateLimiter.isDeployRequiredAfterUpdate(controllerWithNewUserAndEmptyRateLimiter)); + + assertFalse(controllerWithEmptyUserRateLimiter.isDeployRequiredAfterUpdate(controllerWithTestUserAndRateLimiterWithNumber)); + assertFalse( + controllerWithTestUserAndRateLimiterWithNumber.isDeployRequiredAfterUpdate(controllerWithTestUserAndRateLimiterWithNumber) + ); assertTrue(controllerWithEmptyUserRateLimiter.isDeployRequiredAfterUpdate(controller)); assertTrue(controllerWithTestUserAndRateLimiterWithNumber.isDeployRequiredAfterUpdate(controller)); - assertFalse(controllerWithTestUserAndRateLimiterWithNumber - .isDeployRequiredAfterUpdate(controllerWithNewUserAndRateLimiterWithNumber)); - assertTrue(controllerWithTestUserAndRateLimiterWithNumber - .isDeployRequiredAfterUpdate(controllerWithNewUserAndRateLimiter)); + assertFalse( + controllerWithTestUserAndRateLimiterWithNumber.isDeployRequiredAfterUpdate(controllerWithNewUserAndRateLimiterWithNumber) + ); + assertTrue(controllerWithTestUserAndRateLimiterWithNumber.isDeployRequiredAfterUpdate(controllerWithNewUserAndRateLimiter)); } private void testParseFromJsonString(String expectedInputStr, Consumer verify) throws Exception { - XContentParser parser = XContentType.JSON.xContent() - .createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), LoggingDeprecationHandler.INSTANCE, - expectedInputStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + LoggingDeprecationHandler.INSTANCE, + expectedInputStr + ); parser.nextToken(); MLController parsedInput = MLController.parse(parser); verify.accept(parsedInput); @@ -315,48 +304,40 @@ private String serializationWithToXContent(MLController input) throws IOExceptio } private MLController MLControllerGenerator(String user, MLRateLimiter rateLimiter) { - return MLController.builder() - .modelId("testModelId") - .userRateLimiter(new HashMap<>() { - { - put(user, rateLimiter); - } - }) - .build(); + return MLController.builder().modelId("testModelId").userRateLimiter(new HashMap<>() { + { + put(user, rateLimiter); + } + }).build(); } private MLController MLControllerGenerator(String user) { - return MLController.builder() - .modelId("testModelId") - .userRateLimiter(new HashMap<>() { - { - put(user, MLRateLimiter.builder().build()); - } - }) - .build(); + return MLController.builder().modelId("testModelId").userRateLimiter(new HashMap<>() { + { + put(user, MLRateLimiter.builder().build()); + } + }).build(); } private MLController MLControllerGenerator() { - return MLController.builder() - .modelId("testModelId") - .userRateLimiter(new HashMap<>()) - .build(); + return MLController.builder().modelId("testModelId").userRateLimiter(new HashMap<>()).build(); } @Ignore @Test public void testRateLimiterRemove() { - MLController controllerWithTestUserAndEmptyRateLimiter = MLController.builder() - .modelId("testModelId") - .userRateLimiter(new HashMap<>() { - { - put("testUser", MLRateLimiter.builder().build()); - } - }) - .build(); + MLController controllerWithTestUserAndEmptyRateLimiter = MLController + .builder() + .modelId("testModelId") + .userRateLimiter(new HashMap<>() { + { + put("testUser", MLRateLimiter.builder().build()); + } + }) + .build(); controller.update(controllerWithTestUserAndEmptyRateLimiter); assertNull(controller.getUserRateLimiter().get("testUser")); diff --git a/common/src/test/java/org/opensearch/ml/common/controller/MLRateLimiterTest.java b/common/src/test/java/org/opensearch/ml/common/controller/MLRateLimiterTest.java index d4e4d6b967..b47631cd8b 100644 --- a/common/src/test/java/org/opensearch/ml/common/controller/MLRateLimiterTest.java +++ b/common/src/test/java/org/opensearch/ml/common/controller/MLRateLimiterTest.java @@ -22,11 +22,11 @@ import org.junit.rules.ExpectedException; import org.opensearch.OpenSearchParseException; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; @@ -50,18 +50,11 @@ public class MLRateLimiterTest { @Before public void setUp() throws Exception { - rateLimiter = MLRateLimiter.builder() - .limit("1") - .unit(TimeUnit.MILLISECONDS) - .build(); + rateLimiter = MLRateLimiter.builder().limit("1").unit(TimeUnit.MILLISECONDS).build(); - rateLimiterWithNumber = MLRateLimiter.builder() - .limit("1") - .build(); + rateLimiterWithNumber = MLRateLimiter.builder().limit("1").build(); - rateLimiterWithUnit = MLRateLimiter.builder() - .unit(TimeUnit.MILLISECONDS) - .build(); + rateLimiterWithUnit = MLRateLimiter.builder().unit(TimeUnit.MILLISECONDS).build(); rateLimiterNull = MLRateLimiter.builder().build(); @@ -77,9 +70,7 @@ public void readInputStreamSuccess() throws IOException { @Test public void readInputStreamSuccessWithNullFields() throws IOException { - readInputStream(rateLimiterWithNumber, parsedInput -> { - assertNull(parsedInput.getUnit()); - }); + readInputStream(rateLimiterWithNumber, parsedInput -> { assertNull(parsedInput.getUnit()); }); } @Test @@ -134,8 +125,8 @@ public void parseWithNullField() throws Exception { @Test public void parseWithIllegalField() throws Exception { - final String expectedInputStrWithIllegalField = "{\"limit\":\"1\",\"unit\":" + - "\"MILLISECONDS\",\"illegal_field\":\"This field need to be skipped.\"}"; + final String expectedInputStrWithIllegalField = "{\"limit\":\"1\",\"unit\":" + + "\"MILLISECONDS\",\"illegal_field\":\"This field need to be skipped.\"}"; testParseFromJsonString(expectedInputStrWithIllegalField, parsedInput -> { try { @@ -205,13 +196,9 @@ public void testRateLimiterIsUpdatable() { @Test public void testRateLimiterIsDeployRequiredAfterUpdate() { - MLRateLimiter rateLimiterWithNumber2 = MLRateLimiter.builder() - .limit("2") - .build(); + MLRateLimiter rateLimiterWithNumber2 = MLRateLimiter.builder().limit("2").build(); - MLRateLimiter rateLimiterWithUnit2 = MLRateLimiter.builder() - .unit(TimeUnit.NANOSECONDS) - .build(); + MLRateLimiter rateLimiterWithUnit2 = MLRateLimiter.builder().unit(TimeUnit.NANOSECONDS).build(); assertTrue(MLRateLimiter.isDeployRequiredAfterUpdate(rateLimiter, rateLimiterWithNumber2)); @@ -224,10 +211,13 @@ public void testRateLimiterIsDeployRequiredAfterUpdate() { } private void testParseFromJsonString(String expectedInputStr, Consumer verify) throws Exception { - XContentParser parser = XContentType.JSON.xContent() - .createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), LoggingDeprecationHandler.INSTANCE, - expectedInputStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + LoggingDeprecationHandler.INSTANCE, + expectedInputStr + ); parser.nextToken(); MLRateLimiter parsedInput = MLRateLimiter.parse(parser); verify.accept(parsedInput); diff --git a/common/src/test/java/org/opensearch/ml/common/conversation/ConversationMetaTests.java b/common/src/test/java/org/opensearch/ml/common/conversation/ConversationMetaTests.java index 2b4e628d05..aaa52ffcff 100644 --- a/common/src/test/java/org/opensearch/ml/common/conversation/ConversationMetaTests.java +++ b/common/src/test/java/org/opensearch/ml/common/conversation/ConversationMetaTests.java @@ -5,6 +5,13 @@ package org.opensearch.ml.common.conversation; +import static org.junit.Assert.assertEquals; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; +import java.time.Instant; +import java.util.Map; + import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -15,14 +22,6 @@ import org.opensearch.ml.common.TestHelper; import org.opensearch.search.SearchHit; -import java.io.IOException; -import java.time.Instant; -import java.util.Map; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; - public class ConversationMetaTests { ConversationMeta conversationMeta; @@ -58,16 +57,16 @@ public void test_fromSearchHit() throws IOException { @Test public void test_fromMap() { Map params = Map - .of( - ConversationalIndexConstants.META_CREATED_TIME_FIELD, - time.toString(), - ConversationalIndexConstants.META_UPDATED_TIME_FIELD, - time.toString(), - ConversationalIndexConstants.META_NAME_FIELD, - "meta name", - ConversationalIndexConstants.USER_FIELD, - "admin" - ); + .of( + ConversationalIndexConstants.META_CREATED_TIME_FIELD, + time.toString(), + ConversationalIndexConstants.META_UPDATED_TIME_FIELD, + time.toString(), + ConversationalIndexConstants.META_NAME_FIELD, + "meta name", + ConversationalIndexConstants.USER_FIELD, + "admin" + ); ConversationMeta conversationMeta = ConversationMeta.fromMap("test-conversation-meta", params); assertEquals(conversationMeta.getId(), "test-conversation-meta"); assertEquals(conversationMeta.getName(), "meta name"); @@ -88,22 +87,49 @@ public void test_fromStream() throws IOException { @Test public void test_ToXContent() throws IOException { - ConversationMeta conversationMeta = new ConversationMeta("test_id", Instant.ofEpochMilli(123), Instant.ofEpochMilli(123), "test meta", "admin", null); + ConversationMeta conversationMeta = new ConversationMeta( + "test_id", + Instant.ofEpochMilli(123), + Instant.ofEpochMilli(123), + "test meta", + "admin", + null + ); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); conversationMeta.toXContent(builder, EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); - assertEquals(content, "{\"memory_id\":\"test_id\",\"create_time\":\"1970-01-01T00:00:00.123Z\",\"updated_time\":\"1970-01-01T00:00:00.123Z\",\"name\":\"test meta\",\"user\":\"admin\"}"); + assertEquals( + content, + "{\"memory_id\":\"test_id\",\"create_time\":\"1970-01-01T00:00:00.123Z\",\"updated_time\":\"1970-01-01T00:00:00.123Z\",\"name\":\"test meta\",\"user\":\"admin\"}" + ); } @Test public void test_toString() { - ConversationMeta conversationMeta = new ConversationMeta("test_id", Instant.ofEpochMilli(123), Instant.ofEpochMilli(123), "test meta", "admin", null); - assertEquals("{id=test_id, name=test meta, created=1970-01-01T00:00:00.123Z, updated=1970-01-01T00:00:00.123Z, user=admin}", conversationMeta.toString()); + ConversationMeta conversationMeta = new ConversationMeta( + "test_id", + Instant.ofEpochMilli(123), + Instant.ofEpochMilli(123), + "test meta", + "admin", + null + ); + assertEquals( + "{id=test_id, name=test meta, created=1970-01-01T00:00:00.123Z, updated=1970-01-01T00:00:00.123Z, user=admin}", + conversationMeta.toString() + ); } @Test public void test_equal() { - ConversationMeta meta = new ConversationMeta("test_id", Instant.ofEpochMilli(123), Instant.ofEpochMilli(123), "test meta", "admin", null); + ConversationMeta meta = new ConversationMeta( + "test_id", + Instant.ofEpochMilli(123), + Instant.ofEpochMilli(123), + "test meta", + "admin", + null + ); assertEquals(meta.equals(conversationMeta), false); } } diff --git a/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java b/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java index 128d9449ea..9ef58dd394 100644 --- a/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java +++ b/common/src/test/java/org/opensearch/ml/common/conversation/InteractionTests.java @@ -5,6 +5,14 @@ package org.opensearch.ml.common.conversation; +import static org.junit.Assert.assertEquals; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; +import java.time.Instant; +import java.util.Collections; +import java.util.Map; + import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -15,14 +23,6 @@ import org.opensearch.ml.common.TestHelper; import org.opensearch.search.SearchHit; -import java.io.IOException; -import java.time.Instant; -import java.util.Collections; -import java.util.Map; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; - public class InteractionTests { Interaction interaction; @@ -31,43 +31,44 @@ public class InteractionTests { @Before public void setUp() { time = Instant.ofEpochMilli(123); - interaction = Interaction.builder() - .id("test-interaction-id") - .createTime(time) - .conversationId("conversation-id") - .input("sample inputs") - .promptTemplate("some prompt template") - .response("sample responses") - .origin("amazon bedrock") - .additionalInfo(Collections.singletonMap("suggestion", "new suggestion")) - .parentInteractionId("parent id") - .traceNum(1) - .build(); + interaction = Interaction + .builder() + .id("test-interaction-id") + .createTime(time) + .conversationId("conversation-id") + .input("sample inputs") + .promptTemplate("some prompt template") + .response("sample responses") + .origin("amazon bedrock") + .additionalInfo(Collections.singletonMap("suggestion", "new suggestion")) + .parentInteractionId("parent id") + .traceNum(1) + .build(); } @Test public void test_fromMap() { Map params = Map - .of( - ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, - time.toString(), - ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, - "conversation-id", - ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, - "sample inputs", - ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, - "some prompt template", - ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, - "sample responses", - ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, - "amazon bedrock", - ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, - Collections.singletonMap("suggestion", "new suggestion"), - ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, - "parent id", - ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, - 1 - ); + .of( + ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, + time.toString(), + ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, + "conversation-id", + ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, + "sample inputs", + ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, + "some prompt template", + ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, + "sample responses", + ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, + "amazon bedrock", + ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, + Collections.singletonMap("suggestion", "new suggestion"), + ConversationalIndexConstants.PARENT_INTERACTIONS_ID_FIELD, + "parent id", + ConversationalIndexConstants.INTERACTIONS_TRACE_NUMBER_FIELD, + 1 + ); Interaction interaction = Interaction.fromMap("test-interaction-id", params); assertEquals(interaction.getId(), "test-interaction-id"); assertEquals(interaction.getCreateTime(), time); @@ -117,74 +118,88 @@ public void test_fromStream() throws IOException { @Test public void test_ToXContent() throws IOException { - Interaction interaction = Interaction.builder() - .conversationId("conversation id") - .origin("amazon bedrock") - .parentInteractionId("parant id") - .additionalInfo(Collections.singletonMap("suggestion", "new suggestion")) - .traceNum(1) - .build(); + Interaction interaction = Interaction + .builder() + .conversationId("conversation id") + .origin("amazon bedrock") + .parentInteractionId("parant id") + .additionalInfo(Collections.singletonMap("suggestion", "new suggestion")) + .traceNum(1) + .build(); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); interaction.toXContent(builder, EMPTY_PARAMS); String interactionContent = TestHelper.xContentBuilderToString(builder); - assertEquals("{\"memory_id\":\"conversation id\",\"message_id\":null,\"create_time\":null,\"input\":null,\"prompt_template\":null,\"response\":null,\"origin\":\"amazon bedrock\",\"additional_info\":{\"suggestion\":\"new suggestion\"},\"parent_message_id\":\"parant id\",\"trace_number\":1}", interactionContent); + assertEquals( + "{\"memory_id\":\"conversation id\",\"message_id\":null,\"create_time\":null,\"input\":null,\"prompt_template\":null,\"response\":null,\"origin\":\"amazon bedrock\",\"additional_info\":{\"suggestion\":\"new suggestion\"},\"parent_message_id\":\"parant id\",\"trace_number\":1}", + interactionContent + ); } @Test public void test_not_equal() { - Interaction interaction1 = Interaction.builder() - .id("id") - .conversationId("conversation id") - .origin("amazon bedrock") - .parentInteractionId("parent id") - .additionalInfo(Collections.singletonMap("suggestion", "new suggestion")) - .traceNum(1) - .build(); + Interaction interaction1 = Interaction + .builder() + .id("id") + .conversationId("conversation id") + .origin("amazon bedrock") + .parentInteractionId("parent id") + .additionalInfo(Collections.singletonMap("suggestion", "new suggestion")) + .traceNum(1) + .build(); assertEquals(interaction.equals(interaction1), false); } @Test public void test_Equal() { - Interaction interaction1 = Interaction.builder() - .id("test-interaction-id") - .createTime(time) - .conversationId("conversation-id") - .input("sample inputs") - .promptTemplate("some prompt template") - .response("sample responses") - .origin("amazon bedrock") - .additionalInfo(Collections.singletonMap("suggestion", "new suggestion")) - .parentInteractionId("parent id") - .traceNum(1) - .build(); + Interaction interaction1 = Interaction + .builder() + .id("test-interaction-id") + .createTime(time) + .conversationId("conversation-id") + .input("sample inputs") + .promptTemplate("some prompt template") + .response("sample responses") + .origin("amazon bedrock") + .additionalInfo(Collections.singletonMap("suggestion", "new suggestion")) + .parentInteractionId("parent id") + .traceNum(1) + .build(); assertEquals(interaction.equals(interaction1), true); } @Test public void test_toString() { - Interaction interaction1 = Interaction.builder() - .id("id") - .conversationId("conversation id") - .origin("amazon bedrock") - .parentInteractionId("parent id") - .additionalInfo(Collections.singletonMap("suggestion", "new suggestion")) - .traceNum(1) - .build(); - assertEquals("Interaction{id=id,cid=conversation id,create_time=null,origin=amazon bedrock,input=null,promt_template=null,response=null,additional_info={suggestion=new suggestion},parentInteractionId=parent id,traceNum=1}", interaction1.toString()); + Interaction interaction1 = Interaction + .builder() + .id("id") + .conversationId("conversation id") + .origin("amazon bedrock") + .parentInteractionId("parent id") + .additionalInfo(Collections.singletonMap("suggestion", "new suggestion")) + .traceNum(1) + .build(); + assertEquals( + "Interaction{id=id,cid=conversation id,create_time=null,origin=amazon bedrock,input=null,promt_template=null,response=null,additional_info={suggestion=new suggestion},parentInteractionId=parent id,traceNum=1}", + interaction1.toString() + ); } @Test public void test_ParentInteraction() { - Interaction parentInteraction = Interaction.builder() - .id("test-interaction-id") - .createTime(time) - .conversationId("conversation-id") - .input("sample inputs") - .promptTemplate("some prompt template") - .response("sample responses") - .origin("amazon bedrock") - .additionalInfo(Collections.singletonMap("suggestion", "new suggestion")) - .build(); - assertEquals("Interaction{id=test-interaction-id,cid=conversation-id,create_time=1970-01-01T00:00:00.123Z,origin=amazon bedrock,input=sample inputs,promt_template=some prompt template,response=sample responses,additional_info={suggestion=new suggestion},parentInteractionId=null,traceNum=null}", parentInteraction.toString()); + Interaction parentInteraction = Interaction + .builder() + .id("test-interaction-id") + .createTime(time) + .conversationId("conversation-id") + .input("sample inputs") + .promptTemplate("some prompt template") + .response("sample responses") + .origin("amazon bedrock") + .additionalInfo(Collections.singletonMap("suggestion", "new suggestion")) + .build(); + assertEquals( + "Interaction{id=test-interaction-id,cid=conversation-id,create_time=1970-01-01T00:00:00.123Z,origin=amazon bedrock,input=sample inputs,promt_template=some prompt template,response=sample responses,additional_info={suggestion=new suggestion},parentInteractionId=null,traceNum=null}", + parentInteraction.toString() + ); } } diff --git a/common/src/test/java/org/opensearch/ml/common/dataframe/BooleanValueTest.java b/common/src/test/java/org/opensearch/ml/common/dataframe/BooleanValueTest.java index d33564a1bc..b3b8d455bb 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataframe/BooleanValueTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataframe/BooleanValueTest.java @@ -5,16 +5,16 @@ package org.opensearch.ml.common.dataframe; -import org.junit.Test; -import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.core.xcontent.XContentBuilder; - -import java.io.IOException; - import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; +import java.io.IOException; + +import org.junit.Test; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; + public class BooleanValueTest { @Test public void booleanValue() { diff --git a/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnMetaTest.java b/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnMetaTest.java index 0655ab9b3e..22dae7fa9f 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnMetaTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnMetaTest.java @@ -5,6 +5,12 @@ package org.opensearch.ml.common.dataframe; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import java.io.IOException; +import java.util.function.Function; + import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -13,12 +19,6 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; -import java.util.function.Function; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; - public class ColumnMetaTest { ColumnMeta columnMeta; diff --git a/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnTypeTest.java b/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnTypeTest.java index 905fe31159..3c0a5cfa38 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnTypeTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnTypeTest.java @@ -5,12 +5,12 @@ package org.opensearch.ml.common.dataframe; +import java.math.BigDecimal; + import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; -import java.math.BigDecimal; - public class ColumnTypeTest { @Rule diff --git a/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnValueBuilderTest.java b/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnValueBuilderTest.java index 4b783a3755..a9c137f459 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnValueBuilderTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnValueBuilderTest.java @@ -5,13 +5,13 @@ package org.opensearch.ml.common.dataframe; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; +import static org.junit.Assert.assertEquals; import java.math.BigDecimal; -import static org.junit.Assert.assertEquals; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; public class ColumnValueBuilderTest { @@ -45,12 +45,12 @@ public void build() { assertEquals(2.1f, value.floatValue(), 1e-5); assertEquals(2.1d, value.doubleValue(), 1e-5); - value = ColumnValueBuilder.build((short)2); + value = ColumnValueBuilder.build((short) 2); assertEquals(ColumnType.SHORT, value.columnType()); assertEquals(2, value.shortValue()); assertEquals(2.0d, value.doubleValue(), 1e-5); - value = ColumnValueBuilder.build((long)2); + value = ColumnValueBuilder.build((long) 2); assertEquals(ColumnType.LONG, value.columnType()); assertEquals(2, value.longValue()); assertEquals(2.0d, value.doubleValue(), 1e-5); @@ -63,4 +63,4 @@ public void build_IllegalType() { Object obj = new BigDecimal("0"); ColumnValueBuilder.build(obj); } -} \ No newline at end of file +} diff --git a/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnValueReaderTest.java b/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnValueReaderTest.java index 07287da537..baac846c6f 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnValueReaderTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnValueReaderTest.java @@ -5,13 +5,13 @@ package org.opensearch.ml.common.dataframe; +import static org.junit.Assert.assertEquals; + import java.io.IOException; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; -import static org.junit.Assert.assertEquals; - public class ColumnValueReaderTest { ColumnValueReader reader = new ColumnValueReader(); @@ -86,7 +86,7 @@ public void read_FloatValue() throws IOException { @Test public void read_ShortValue() throws IOException { - ColumnValue value = new ShortValue((short)2); + ColumnValue value = new ShortValue((short) 2); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); value.writeTo(bytesStreamOutput); value = reader.read(bytesStreamOutput.bytes().streamInput()); @@ -96,7 +96,7 @@ public void read_ShortValue() throws IOException { @Test public void read_LongValue() throws IOException { - ColumnValue value = new LongValue((long)2); + ColumnValue value = new LongValue((long) 2); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); value.writeTo(bytesStreamOutput); value = reader.read(bytesStreamOutput.bytes().streamInput()); diff --git a/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnValueTest.java b/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnValueTest.java index 3e8661d275..cc60a85608 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnValueTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnValueTest.java @@ -5,12 +5,12 @@ package org.opensearch.ml.common.dataframe; +import static org.junit.Assert.assertTrue; + import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; -import static org.junit.Assert.assertTrue; - public class ColumnValueTest { @Rule diff --git a/common/src/test/java/org/opensearch/ml/common/dataframe/DataFrameBuilderTest.java b/common/src/test/java/org/opensearch/ml/common/dataframe/DataFrameBuilderTest.java index 98f4282254..d9d20e7b4d 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataframe/DataFrameBuilderTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataframe/DataFrameBuilderTest.java @@ -5,6 +5,8 @@ package org.opensearch.ml.common.dataframe; +import static org.junit.Assert.assertEquals; + import java.io.IOException; import java.util.Collections; import java.util.HashMap; @@ -16,8 +18,6 @@ import org.junit.rules.ExpectedException; import org.opensearch.common.io.stream.BytesStreamOutput; -import static org.junit.Assert.assertEquals; - public class DataFrameBuilderTest { @Rule @@ -25,10 +25,7 @@ public class DataFrameBuilderTest { @Test public void emptyDataFrame_Success() { - ColumnMeta[] columnMetas = new ColumnMeta[]{ColumnMeta.builder() - .name("k1") - .columnType(ColumnType.DOUBLE) - .build()}; + ColumnMeta[] columnMetas = new ColumnMeta[] { ColumnMeta.builder().name("k1").columnType(ColumnType.DOUBLE).build() }; DataFrame dataFrame = DataFrameBuilder.emptyDataFrame(columnMetas); assertEquals(0, dataFrame.size()); } @@ -68,9 +65,7 @@ public void load_Exception_NullInputMapList() { public void load_Success_ColumnMetasAndInputMapList() { Map map = new HashMap<>(); map.put("k1", 2.3D); - ColumnMeta[] columnMetas = new ColumnMeta[]{ - ColumnMeta.builder().name("k1").columnType(ColumnType.DOUBLE).build() - }; + ColumnMeta[] columnMetas = new ColumnMeta[] { ColumnMeta.builder().name("k1").columnType(ColumnType.DOUBLE).build() }; DataFrame dataFrame = DataFrameBuilder.load(columnMetas, Collections.singletonList(map)); assertEquals(1, dataFrame.size()); } @@ -91,17 +86,13 @@ public void load_Exception_NullColumnMetas() { @Test(expected = IllegalArgumentException.class) public void load_Exception_ColumnMetasAndEmptyInputMapList() { - ColumnMeta[] columnMetas = new ColumnMeta[]{ - ColumnMeta.builder().name("k1").columnType(ColumnType.DOUBLE).build() - }; + ColumnMeta[] columnMetas = new ColumnMeta[] { ColumnMeta.builder().name("k1").columnType(ColumnType.DOUBLE).build() }; DataFrameBuilder.load(columnMetas, Collections.emptyList()); } @Test(expected = IllegalArgumentException.class) public void load_Exception_ColumnMetasAndNullInputMapList() { - ColumnMeta[] columnMetas = new ColumnMeta[]{ - ColumnMeta.builder().name("k1").columnType(ColumnType.DOUBLE).build() - }; + ColumnMeta[] columnMetas = new ColumnMeta[] { ColumnMeta.builder().name("k1").columnType(ColumnType.DOUBLE).build() }; DataFrameBuilder.load(columnMetas, null); } @@ -112,10 +103,9 @@ public void load_Exception_DifferentColumnsInColumnMetasAndInputMapList() { Map map = new HashMap<>(); map.put("k1", 2.3D); - ColumnMeta[] columnMetas = new ColumnMeta[]{ - ColumnMeta.builder().name("k1").columnType(ColumnType.DOUBLE).build(), - ColumnMeta.builder().name("k2").columnType(ColumnType.DOUBLE).build() - }; + ColumnMeta[] columnMetas = new ColumnMeta[] { + ColumnMeta.builder().name("k1").columnType(ColumnType.DOUBLE).build(), + ColumnMeta.builder().name("k2").columnType(ColumnType.DOUBLE).build() }; DataFrameBuilder.load(columnMetas, Collections.singletonList(map)); } @@ -126,9 +116,7 @@ public void load_Exception_DifferentTypesForSameField() { Map map = new HashMap<>(); map.put("k1", 2.3D); - ColumnMeta[] columnMetas = new ColumnMeta[]{ - ColumnMeta.builder().name("k1").columnType(ColumnType.INTEGER).build() - }; + ColumnMeta[] columnMetas = new ColumnMeta[] { ColumnMeta.builder().name("k1").columnType(ColumnType.INTEGER).build() }; DataFrameBuilder.load(columnMetas, Collections.singletonList(map)); } @@ -139,9 +127,7 @@ public void load_Exception_DifferentFields() { Map map = new HashMap<>(); map.put("k2", 2.3D); - ColumnMeta[] columnMetas = new ColumnMeta[]{ - ColumnMeta.builder().name("k1").columnType(ColumnType.INTEGER).build() - }; + ColumnMeta[] columnMetas = new ColumnMeta[] { ColumnMeta.builder().name("k1").columnType(ColumnType.INTEGER).build() }; DataFrameBuilder.load(columnMetas, Collections.singletonList(map)); } @@ -158,4 +144,4 @@ public void load_Success_StreamInput() throws IOException { dataFrame = DataFrameBuilder.load(bytesStreamOutput.bytes().streamInput()); assertEquals(1, dataFrame.size()); } -} \ No newline at end of file +} diff --git a/common/src/test/java/org/opensearch/ml/common/dataframe/DefaultDataFrameTest.java b/common/src/test/java/org/opensearch/ml/common/dataframe/DefaultDataFrameTest.java index eb1460f24e..da34540bc9 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataframe/DefaultDataFrameTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataframe/DefaultDataFrameTest.java @@ -5,6 +5,10 @@ package org.opensearch.ml.common.dataframe; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + import java.io.IOException; import java.util.ArrayList; import java.util.List; @@ -15,16 +19,12 @@ import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; - public class DefaultDataFrameTest { DefaultDataFrame defaultDataFrame; @@ -36,23 +36,11 @@ public class DefaultDataFrameTest { @Before public void setUp() { ColumnMeta[] columnMetas = new ColumnMeta[4]; - columnMetas[0] = ColumnMeta.builder() - .name("c1") - .columnType(ColumnType.STRING) - .build(); - columnMetas[1] = ColumnMeta.builder() - .name("c2") - .columnType(ColumnType.INTEGER) - .build(); - columnMetas[2] = ColumnMeta.builder() - .name("c3") - .columnType(ColumnType.DOUBLE) - .build(); - - columnMetas[3] = ColumnMeta.builder() - .name("c4") - .columnType(ColumnType.BOOLEAN) - .build(); + columnMetas[0] = ColumnMeta.builder().name("c1").columnType(ColumnType.STRING).build(); + columnMetas[1] = ColumnMeta.builder().name("c2").columnType(ColumnType.INTEGER).build(); + columnMetas[2] = ColumnMeta.builder().name("c3").columnType(ColumnType.DOUBLE).build(); + + columnMetas[3] = ColumnMeta.builder().name("c4").columnType(ColumnType.BOOLEAN).build(); Row row = new Row(4); row.setValue(0, new StringValue("string")); @@ -156,8 +144,7 @@ public void appendRow_Exception_DifferentColumns() { @Test public void appendRow_Exception_DifferentColumnTypes() { exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("the column type is different in column meta:BOOLEAN and input row:DOUBLE for " + - "index: 3"); + exceptionRule.expectMessage("the column type is different in column meta:BOOLEAN and input row:DOUBLE for " + "index: 3"); Row row = new Row(4); row.setValue(0, new StringValue("string2")); row.setValue(1, new IntValue(2)); @@ -173,21 +160,21 @@ public void columnMetas_Success() { } @Test - public void remove_Exception_InputColumnIndexBiggerThanColumensLength(){ + public void remove_Exception_InputColumnIndexBiggerThanColumensLength() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("columnIndex can't be negative or bigger than columns length:4"); defaultDataFrame.remove(4); } @Test - public void remove_Exception_InputColumnIndexNegtiveColumensLength(){ + public void remove_Exception_InputColumnIndexNegtiveColumensLength() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("columnIndex can't be negative or bigger than columns length:4"); defaultDataFrame.remove(-1); } @Test - public void remove_EmptyColumnMeta(){ + public void remove_EmptyColumnMeta() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("columnIndex can't be negative or bigger than columns length:0"); DefaultDataFrame dataFrame = new DefaultDataFrame(new ColumnMeta[0]); @@ -196,31 +183,31 @@ public void remove_EmptyColumnMeta(){ } @Test - public void remove_Success(){ + public void remove_Success() { DataFrame dataFrame = defaultDataFrame.remove(3); assertEquals(3, dataFrame.columnMetas().length); assertEquals(3, dataFrame.getRow(0).size()); } @Test - public void select_Success(){ - DataFrame dataFrame = defaultDataFrame.select(new int[]{1, 3}); + public void select_Success() { + DataFrame dataFrame = defaultDataFrame.select(new int[] { 1, 3 }); assertEquals(2, dataFrame.columnMetas().length); assertEquals(2, dataFrame.getRow(0).size()); } @Test - public void select_Exception_EmptyInputColumns(){ + public void select_Exception_EmptyInputColumns() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("columns can't be null or empty"); defaultDataFrame.select(new int[0]); } @Test - public void select_Exception_InvalidColumn(){ + public void select_Exception_InvalidColumn() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("columnIndex can't be negative or bigger than columns length"); - defaultDataFrame.select(new int[]{5}); + defaultDataFrame.select(new int[] { 5 }); } @Test @@ -232,44 +219,47 @@ public void testToXContent() throws IOException { assertNotNull(builder); String jsonStr = builder.toString(); - assertEquals("{\"column_metas\":[" + - "{\"name\":\"c1\",\"column_type\":\"STRING\"}," + - "{\"name\":\"c2\",\"column_type\":\"INTEGER\"}," + - "{\"name\":\"c3\",\"column_type\":\"DOUBLE\"}," + - "{\"name\":\"c4\",\"column_type\":\"BOOLEAN\"}]," + - "\"rows\":[" + - "{\"values\":[" + - "{\"column_type\":\"STRING\",\"value\":\"string\"}," + - "{\"column_type\":\"INTEGER\",\"value\":1}," + - "{\"column_type\":\"DOUBLE\",\"value\":2.0}," + - "{\"column_type\":\"BOOLEAN\",\"value\":true}]}]}", jsonStr); + assertEquals( + "{\"column_metas\":[" + + "{\"name\":\"c1\",\"column_type\":\"STRING\"}," + + "{\"name\":\"c2\",\"column_type\":\"INTEGER\"}," + + "{\"name\":\"c3\",\"column_type\":\"DOUBLE\"}," + + "{\"name\":\"c4\",\"column_type\":\"BOOLEAN\"}]," + + "\"rows\":[" + + "{\"values\":[" + + "{\"column_type\":\"STRING\",\"value\":\"string\"}," + + "{\"column_type\":\"INTEGER\",\"value\":1}," + + "{\"column_type\":\"DOUBLE\",\"value\":2.0}," + + "{\"column_type\":\"BOOLEAN\",\"value\":true}]}]}", + jsonStr + ); } @Test public void testParse_EmptyDataFrame() throws IOException { - ColumnMeta[] columnMetas = new ColumnMeta[] {new ColumnMeta("test_int", ColumnType.INTEGER)}; + ColumnMeta[] columnMetas = new ColumnMeta[] { new ColumnMeta("test_int", ColumnType.INTEGER) }; DefaultDataFrame dataFrame = new DefaultDataFrame(columnMetas); TestHelper.testParse(dataFrame, function, true); } @Test public void testParse_DataFrame() throws IOException { - ColumnMeta[] columnMetas = new ColumnMeta[] {new ColumnMeta("test_int", ColumnType.INTEGER)}; + ColumnMeta[] columnMetas = new ColumnMeta[] { new ColumnMeta("test_int", ColumnType.INTEGER) }; DefaultDataFrame dataFrame = new DefaultDataFrame(columnMetas); - dataFrame.appendRow(new Integer[]{1}); - dataFrame.appendRow(new Integer[]{2}); + dataFrame.appendRow(new Integer[] { 1 }); + dataFrame.appendRow(new Integer[] { 2 }); TestHelper.testParse(dataFrame, function, true); } @Test public void testParse_WrongExtraField() throws IOException { - ColumnMeta[] columnMetas = new ColumnMeta[] {new ColumnMeta("test_int", ColumnType.INTEGER)}; + ColumnMeta[] columnMetas = new ColumnMeta[] { new ColumnMeta("test_int", ColumnType.INTEGER) }; DefaultDataFrame dataFrame = new DefaultDataFrame(columnMetas); - dataFrame.appendRow(new Integer[]{1}); - dataFrame.appendRow(new Integer[]{2}); - String jsonStr = "{\"wrong_field\":{\"test\":\"abc\"},\"column_metas\":[{\"name\":\"test_int\",\"column_type\":" + - "\"INTEGER\"}],\"rows\":[{\"values\":[{\"column_type\":\"INTEGER\",\"value\":1}]},{\"values\":" + - "[{\"column_type\":\"INTEGER\",\"value\":2}]}]}"; + dataFrame.appendRow(new Integer[] { 1 }); + dataFrame.appendRow(new Integer[] { 2 }); + String jsonStr = "{\"wrong_field\":{\"test\":\"abc\"},\"column_metas\":[{\"name\":\"test_int\",\"column_type\":" + + "\"INTEGER\"}],\"rows\":[{\"values\":[{\"column_type\":\"INTEGER\",\"value\":1}]},{\"values\":" + + "[{\"column_type\":\"INTEGER\",\"value\":2}]}]}"; TestHelper.testParseFromString(dataFrame, jsonStr, function); } } diff --git a/common/src/test/java/org/opensearch/ml/common/dataframe/DoubleValueTest.java b/common/src/test/java/org/opensearch/ml/common/dataframe/DoubleValueTest.java index 84e79a1561..319d0051c9 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataframe/DoubleValueTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataframe/DoubleValueTest.java @@ -5,6 +5,9 @@ package org.opensearch.ml.common.dataframe; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + import java.io.IOException; import org.junit.Test; @@ -12,9 +15,6 @@ import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.xcontent.XContentBuilder; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; - public class DoubleValueTest { @Test diff --git a/common/src/test/java/org/opensearch/ml/common/dataframe/FloatValueTest.java b/common/src/test/java/org/opensearch/ml/common/dataframe/FloatValueTest.java index ce30022bd2..6afbf01c0a 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataframe/FloatValueTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataframe/FloatValueTest.java @@ -5,14 +5,14 @@ package org.opensearch.ml.common.dataframe; -import org.junit.Test; -import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.core.xcontent.XContentBuilder; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import java.io.IOException; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; +import org.junit.Test; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; public class FloatValueTest { diff --git a/common/src/test/java/org/opensearch/ml/common/dataframe/IntValueTest.java b/common/src/test/java/org/opensearch/ml/common/dataframe/IntValueTest.java index 1c307ab1b9..1ab0ddfbb9 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataframe/IntValueTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataframe/IntValueTest.java @@ -5,14 +5,14 @@ package org.opensearch.ml.common.dataframe; -import org.junit.Test; -import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.core.xcontent.XContentBuilder; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import java.io.IOException; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; +import org.junit.Test; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; public class IntValueTest { diff --git a/common/src/test/java/org/opensearch/ml/common/dataframe/LongValueTest.java b/common/src/test/java/org/opensearch/ml/common/dataframe/LongValueTest.java index 0266c24fc5..0e1e9d2668 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataframe/LongValueTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataframe/LongValueTest.java @@ -5,20 +5,20 @@ package org.opensearch.ml.common.dataframe; -import org.junit.Test; -import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.core.xcontent.XContentBuilder; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import java.io.IOException; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; +import org.junit.Test; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; public class LongValueTest { @Test public void longValue() { - LongValue longValue = new LongValue((long)2); + LongValue longValue = new LongValue((long) 2); assertEquals(ColumnType.LONG, longValue.columnType()); assertEquals(2L, longValue.getValue()); assertEquals(2.0d, longValue.doubleValue(), 1e-5); @@ -26,7 +26,7 @@ public void longValue() { @Test public void testToXContent() throws IOException { - LongValue longValue = new LongValue((long)2); + LongValue longValue = new LongValue((long) 2); XContentBuilder builder = XContentFactory.jsonBuilder(); longValue.toXContent(builder); diff --git a/common/src/test/java/org/opensearch/ml/common/dataframe/NullValueTest.java b/common/src/test/java/org/opensearch/ml/common/dataframe/NullValueTest.java index a4c42424f4..2ea26a4a28 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataframe/NullValueTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataframe/NullValueTest.java @@ -5,17 +5,17 @@ package org.opensearch.ml.common.dataframe; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; + +import java.io.IOException; + import org.junit.Test; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; - public class NullValueTest { @Test diff --git a/common/src/test/java/org/opensearch/ml/common/dataframe/RowTest.java b/common/src/test/java/org/opensearch/ml/common/dataframe/RowTest.java index 0fb081d105..a42ca77a73 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataframe/RowTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataframe/RowTest.java @@ -5,6 +5,13 @@ package org.opensearch.ml.common.dataframe; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.opensearch.ml.common.TestHelper.testParse; +import static org.opensearch.ml.common.TestHelper.testParseFromString; + import java.io.IOException; import java.util.Iterator; import java.util.function.Function; @@ -18,13 +25,6 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; -import static org.opensearch.ml.common.TestHelper.testParse; -import static org.opensearch.ml.common.TestHelper.testParseFromString; - public class RowTest { Row row; @@ -38,7 +38,7 @@ public void setup() { row = new Row(1); row.setValue(0, ColumnValueBuilder.build(0)); - function = parser -> { + function = parser -> { try { return Row.parse(parser); } catch (IOException e) { @@ -124,7 +124,7 @@ public void select() { row = new Row(2); row.setValue(0, ColumnValueBuilder.build(0)); row.setValue(1, ColumnValueBuilder.build(false)); - row = row.select(new int[]{1}); + row = row.select(new int[] { 1 }); assertEquals(1, row.size()); assertFalse(row.getValue(0).booleanValue()); } @@ -142,45 +142,70 @@ public void testToXContent() throws IOException { @Test public void testParse_NullValue() throws IOException { - ColumnValue[] values = new ColumnValue[] {new NullValue()}; + ColumnValue[] values = new ColumnValue[] { new NullValue() }; Row row = new Row(values); testParse(row, function); } @Test public void testParse_NullValue_AtLast() throws IOException { - ColumnValue[] values = new ColumnValue[] {new IntValue(1), new DoubleValue(2.0), new StringValue("test"), new BooleanValue(true), new NullValue()}; + ColumnValue[] values = new ColumnValue[] { + new IntValue(1), + new DoubleValue(2.0), + new StringValue("test"), + new BooleanValue(true), + new NullValue() }; Row row = new Row(values); testParse(row, function); } @Test public void testParse_NullValue_AtFirst() throws IOException { - ColumnValue[] values = new ColumnValue[] {new NullValue(), new IntValue(1), new DoubleValue(2.0), new StringValue("test"), new BooleanValue(true)}; + ColumnValue[] values = new ColumnValue[] { + new NullValue(), + new IntValue(1), + new DoubleValue(2.0), + new StringValue("test"), + new BooleanValue(true) }; Row row = new Row(values); testParse(row, function); } @Test public void testParse_NullValue_AtMiddle() throws IOException { - ColumnValue[] values = new ColumnValue[] {new IntValue(1), new DoubleValue(2.0), new NullValue(), new StringValue("test"), new BooleanValue(true)}; + ColumnValue[] values = new ColumnValue[] { + new IntValue(1), + new DoubleValue(2.0), + new NullValue(), + new StringValue("test"), + new BooleanValue(true) }; Row row = new Row(values); testParse(row, function); } @Test public void testParse_ExtraWrongValueField() throws IOException { - ColumnValue[] values = new ColumnValue[] {new IntValue(1), new DoubleValue(2.0), new NullValue(), new StringValue("test"), new BooleanValue(true)}; + ColumnValue[] values = new ColumnValue[] { + new IntValue(1), + new DoubleValue(2.0), + new NullValue(), + new StringValue("test"), + new BooleanValue(true) }; Row row = new Row(values); - String jsonStr = "{\"values\":[{\"column_type\":\"INTEGER\",\"value\":1},{\"column_type\":\"DOUBLE\",\"value\":2}," + - "{\"column_type\":\"NULL\"},{\"column_type\":\"STRING\",\"value\":\"test\"},{\"column_type\":\"BOOLEAN\"," + - "\"value\":true},{\"column_type\":\"WRONG\",\"value\":true}],\"wrong_filed\":{\"test\":\"abc\"}}"; + String jsonStr = "{\"values\":[{\"column_type\":\"INTEGER\",\"value\":1},{\"column_type\":\"DOUBLE\",\"value\":2}," + + "{\"column_type\":\"NULL\"},{\"column_type\":\"STRING\",\"value\":\"test\"},{\"column_type\":\"BOOLEAN\"," + + "\"value\":true},{\"column_type\":\"WRONG\",\"value\":true}],\"wrong_filed\":{\"test\":\"abc\"}}"; testParseFromString(row, jsonStr, function); } @Test public void testParse_EmptyValueField() throws IOException { - ColumnValue[] values = new ColumnValue[] {new IntValue(1), new DoubleValue(2.0), new NullValue(), new StringValue("test"), new BooleanValue(true)}; + ColumnValue[] values = new ColumnValue[] { + new IntValue(1), + new DoubleValue(2.0), + new NullValue(), + new StringValue("test"), + new BooleanValue(true) }; Row row = new Row(values); String jsonStr = "{\"values\":[{}]}"; testParseFromString(row, jsonStr, function); @@ -190,7 +215,12 @@ public void testParse_EmptyValueField() throws IOException { public void testParse_WrongColumnTypeField() throws IOException { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("wrong column type, expect column_type field but got column_type_wrong"); - ColumnValue[] values = new ColumnValue[] {new IntValue(1), new DoubleValue(2.0), new NullValue(), new StringValue("test"), new BooleanValue(true)}; + ColumnValue[] values = new ColumnValue[] { + new IntValue(1), + new DoubleValue(2.0), + new NullValue(), + new StringValue("test"), + new BooleanValue(true) }; Row row = new Row(values); String jsonStr = "{\"values\":[{\"column_type_wrong\":\"INTEGER\",\"value\":1}]}"; testParseFromString(row, jsonStr, function); @@ -200,7 +230,12 @@ public void testParse_WrongColumnTypeField() throws IOException { public void testParse_WrongValueField() throws IOException { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("wrong column value, expect value field but got value_wrong"); - ColumnValue[] values = new ColumnValue[] {new IntValue(1), new DoubleValue(2.0), new NullValue(), new StringValue("test"), new BooleanValue(true)}; + ColumnValue[] values = new ColumnValue[] { + new IntValue(1), + new DoubleValue(2.0), + new NullValue(), + new StringValue("test"), + new BooleanValue(true) }; Row row = new Row(values); String jsonStr = "{\"values\":[{\"column_type\":\"INTEGER\",\"value_wrong\":1}]}"; testParseFromString(row, jsonStr, function); @@ -214,34 +249,69 @@ public void testEquals_EmptyValues() { Row row2 = new Row(values2); assertTrue(row1.equals(row2)); - ColumnValue[] values3 = new ColumnValue[] {new IntValue(1), new DoubleValue(2.0), new NullValue(), new StringValue("test"), new BooleanValue(true)}; + ColumnValue[] values3 = new ColumnValue[] { + new IntValue(1), + new DoubleValue(2.0), + new NullValue(), + new StringValue("test"), + new BooleanValue(true) }; Row row3 = new Row(values3); assertFalse(row1.equals(row3)); } @Test public void testEquals_AllValuesMatch() { - ColumnValue[] values1 = new ColumnValue[] {new IntValue(1), new DoubleValue(2.0), new NullValue(), new StringValue("test"), new BooleanValue(true)}; + ColumnValue[] values1 = new ColumnValue[] { + new IntValue(1), + new DoubleValue(2.0), + new NullValue(), + new StringValue("test"), + new BooleanValue(true) }; Row row1 = new Row(values1); - ColumnValue[] values2 = new ColumnValue[] {new IntValue(1), new DoubleValue(2.0), new NullValue(), new StringValue("test"), new BooleanValue(true)}; + ColumnValue[] values2 = new ColumnValue[] { + new IntValue(1), + new DoubleValue(2.0), + new NullValue(), + new StringValue("test"), + new BooleanValue(true) }; Row row2 = new Row(values2); assertTrue(row1.equals(row2)); } @Test public void testEquals_SomeValueNotMatch() { - ColumnValue[] values1 = new ColumnValue[] {new IntValue(1), new DoubleValue(2.0), new NullValue(), new StringValue("test"), new BooleanValue(true)}; + ColumnValue[] values1 = new ColumnValue[] { + new IntValue(1), + new DoubleValue(2.0), + new NullValue(), + new StringValue("test"), + new BooleanValue(true) }; Row row1 = new Row(values1); - ColumnValue[] values2 = new ColumnValue[] {new IntValue(2), new DoubleValue(2.0), new NullValue(), new StringValue("test"), new BooleanValue(true)}; + ColumnValue[] values2 = new ColumnValue[] { + new IntValue(2), + new DoubleValue(2.0), + new NullValue(), + new StringValue("test"), + new BooleanValue(true) }; Row row2 = new Row(values2); assertFalse(row1.equals(row2)); } @Test public void testEquals_SomeTypeNotMatch() { - ColumnValue[] values1 = new ColumnValue[] {new IntValue(1), new DoubleValue(2.0), new NullValue(), new StringValue("test"), new BooleanValue(true)}; + ColumnValue[] values1 = new ColumnValue[] { + new IntValue(1), + new DoubleValue(2.0), + new NullValue(), + new StringValue("test"), + new BooleanValue(true) }; Row row1 = new Row(values1); - ColumnValue[] values2 = new ColumnValue[] {new DoubleValue(1.0), new DoubleValue(2.0), new NullValue(), new StringValue("test"), new BooleanValue(true)}; + ColumnValue[] values2 = new ColumnValue[] { + new DoubleValue(1.0), + new DoubleValue(2.0), + new NullValue(), + new StringValue("test"), + new BooleanValue(true) }; Row row2 = new Row(values2); assertFalse(row1.equals(row2)); } diff --git a/common/src/test/java/org/opensearch/ml/common/dataframe/ShortValueTest.java b/common/src/test/java/org/opensearch/ml/common/dataframe/ShortValueTest.java index 2ef5d4f411..5d21b05cb2 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataframe/ShortValueTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataframe/ShortValueTest.java @@ -5,28 +5,28 @@ package org.opensearch.ml.common.dataframe; -import org.junit.Test; -import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.core.xcontent.XContentBuilder; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import java.io.IOException; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; +import org.junit.Test; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; public class ShortValueTest { @Test public void shortValue() { - ShortValue shortValue = new ShortValue((short)2); + ShortValue shortValue = new ShortValue((short) 2); assertEquals(ColumnType.SHORT, shortValue.columnType()); - assertEquals((short)2, shortValue.getValue()); + assertEquals((short) 2, shortValue.getValue()); assertEquals(2.0d, shortValue.doubleValue(), 1e-5); } @Test public void testToXContent() throws IOException { - ShortValue shortValue = new ShortValue((short)2); + ShortValue shortValue = new ShortValue((short) 2); XContentBuilder builder = XContentFactory.jsonBuilder(); shortValue.toXContent(builder); diff --git a/common/src/test/java/org/opensearch/ml/common/dataframe/StringValueTest.java b/common/src/test/java/org/opensearch/ml/common/dataframe/StringValueTest.java index 0f4df02a47..c0810f4f38 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataframe/StringValueTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataframe/StringValueTest.java @@ -5,16 +5,16 @@ package org.opensearch.ml.common.dataframe; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import java.io.IOException; + import org.junit.Test; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; - public class StringValueTest { @Test public void stringValue() { diff --git a/common/src/test/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParametersTest.java b/common/src/test/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParametersTest.java index df50348ff5..b949208472 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParametersTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParametersTest.java @@ -1,5 +1,12 @@ package org.opensearch.ml.common.dataset; +import static org.junit.Assert.assertEquals; +import static org.opensearch.ml.common.TestHelper.contentObjectToString; +import static org.opensearch.ml.common.TestHelper.testParseFromString; + +import java.io.IOException; +import java.util.function.Function; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -8,77 +15,71 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; - -import java.io.IOException; -import java.util.function.Function; import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType; -import static org.junit.Assert.assertEquals; -import static org.opensearch.ml.common.TestHelper.contentObjectToString; -import static org.opensearch.ml.common.TestHelper.testParseFromString; - public class AsymmetricTextEmbeddingParametersTest { - @Rule - public ExpectedException exceptionRule = ExpectedException.none(); + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); - AsymmetricTextEmbeddingParameters params; - private Function function = parser -> { - try { - return (AsymmetricTextEmbeddingParameters) AsymmetricTextEmbeddingParameters.parse(parser); - } catch (IOException e) { - throw new RuntimeException("failed to parse AsymmetricTextEmbeddingParameters", e); - } - }; + AsymmetricTextEmbeddingParameters params; + private Function function = parser -> { + try { + return (AsymmetricTextEmbeddingParameters) AsymmetricTextEmbeddingParameters.parse(parser); + } catch (IOException e) { + throw new RuntimeException("failed to parse AsymmetricTextEmbeddingParameters", e); + } + }; - @Before - public void setUp() { - params = AsymmetricTextEmbeddingParameters.builder() - .embeddingContentType(EmbeddingContentType.QUERY) - .build(); - } + @Before + public void setUp() { + params = AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.QUERY).build(); + } - @Test - public void parse_AsymmetricTextEmbeddingParameters() throws IOException { - TestHelper.testParse(params, function); - } + @Test + public void parse_AsymmetricTextEmbeddingParameters() throws IOException { + TestHelper.testParse(params, function); + } - @Test - public void parse_AsymmetricTextEmbeddingParameters_Passage() throws IOException { - String paramsStr = contentObjectToString(params); - testParseFromString(params, paramsStr.replace("QUERY", "PASSAGE"), function); - } + @Test + public void parse_AsymmetricTextEmbeddingParameters_Passage() throws IOException { + String paramsStr = contentObjectToString(params); + testParseFromString(params, paramsStr.replace("QUERY", "PASSAGE"), function); + } - @Test - public void parse_AsymmetricTextEmbeddingParameters_Invalid() throws IOException { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("No enum constant org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType.FU"); - String paramsStr = contentObjectToString(params); - testParseFromString(params, paramsStr.replace("QUERY","fu"), function); - } + @Test + public void parse_AsymmetricTextEmbeddingParameters_Invalid() throws IOException { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule + .expectMessage( + "No enum constant org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType.FU" + ); + String paramsStr = contentObjectToString(params); + testParseFromString(params, paramsStr.replace("QUERY", "fu"), function); + } - @Test - public void parse_EmptyAsymmetricTextEmbeddingParameters() throws IOException { - TestHelper.testParse(AsymmetricTextEmbeddingParameters.builder().build(), function); - } + @Test + public void parse_EmptyAsymmetricTextEmbeddingParameters() throws IOException { + TestHelper.testParse(AsymmetricTextEmbeddingParameters.builder().build(), function); + } - @Test - public void readInputStream_Success() throws IOException { - readInputStream(params); - } + @Test + public void readInputStream_Success() throws IOException { + readInputStream(params); + } - @Test - public void readInputStream_Success_EmptyParams() throws IOException { - readInputStream(AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build()); - } + @Test + public void readInputStream_Success_EmptyParams() throws IOException { + readInputStream(AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build()); + } - private void readInputStream(AsymmetricTextEmbeddingParameters params) throws IOException { - BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); - params.writeTo(bytesStreamOutput); + private void readInputStream(AsymmetricTextEmbeddingParameters params) throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + params.writeTo(bytesStreamOutput); - StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); - AsymmetricTextEmbeddingParameters parsedParams = new AsymmetricTextEmbeddingParameters(streamInput); - assertEquals(params, parsedParams); - } + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + AsymmetricTextEmbeddingParameters parsedParams = new AsymmetricTextEmbeddingParameters(streamInput); + assertEquals(params, parsedParams); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/dataset/DataFrameInputDatasetTest.java b/common/src/test/java/org/opensearch/ml/common/dataset/DataFrameInputDatasetTest.java index d7c6294d20..eb1d4c33fe 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataset/DataFrameInputDatasetTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataset/DataFrameInputDatasetTest.java @@ -5,6 +5,8 @@ package org.opensearch.ml.common.dataset; +import static org.junit.Assert.assertEquals; + import java.io.IOException; import java.util.Collections; import java.util.HashMap; @@ -13,19 +15,20 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.ml.common.dataframe.DataFrameBuilder; -import static org.junit.Assert.assertEquals; - public class DataFrameInputDatasetTest { @Test public void writeTo_Success() throws IOException { - DataFrameInputDataset dataFrameInputDataset = DataFrameInputDataset.builder() - .dataFrame(DataFrameBuilder.load(Collections.singletonList(new HashMap() {{ - put("key1", 2.0D); - }}))) + DataFrameInputDataset dataFrameInputDataset = DataFrameInputDataset + .builder() + .dataFrame(DataFrameBuilder.load(Collections.singletonList(new HashMap() { + { + put("key1", 2.0D); + } + }))) .build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); dataFrameInputDataset.writeTo(bytesStreamOutput); assertEquals(21, bytesStreamOutput.size()); } -} \ No newline at end of file +} diff --git a/common/src/test/java/org/opensearch/ml/common/dataset/QuestionAnsweringInputDatasetTest.java b/common/src/test/java/org/opensearch/ml/common/dataset/QuestionAnsweringInputDatasetTest.java index f332f18db5..c8c123de8f 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataset/QuestionAnsweringInputDatasetTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataset/QuestionAnsweringInputDatasetTest.java @@ -4,6 +4,10 @@ */ package org.opensearch.ml.common.dataset; +import static org.junit.Assert.assertThrows; + +import java.io.IOException; + import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.bytes.BytesReference; @@ -12,13 +16,8 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.IOException; -import java.util.List; - -import static org.junit.Assert.assertThrows; - public class QuestionAnsweringInputDatasetTest { - + @Test public void testStreaming() throws IOException { String question = "What color is apple"; @@ -36,15 +35,16 @@ public void testStreaming() throws IOException { @Test public void noContext_ThenFail() { String question = "What color is apple"; - IllegalArgumentException e = assertThrows(IllegalArgumentException.class, - () -> QuestionAnsweringInputDataSet.builder().question(question).build()); + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> QuestionAnsweringInputDataSet.builder().question(question).build() + ); assert (e.getMessage().equals("Context is not provided")); } @Test public void noQuestion_ThenFail() { String context = "I like Apples. They are red"; - assertThrows(IllegalArgumentException.class, - () -> QuestionAnsweringInputDataSet.builder().context(context).build()); + assertThrows(IllegalArgumentException.class, () -> QuestionAnsweringInputDataSet.builder().context(context).build()); } } diff --git a/common/src/test/java/org/opensearch/ml/common/dataset/SearchQueryInputDatasetTest.java b/common/src/test/java/org/opensearch/ml/common/dataset/SearchQueryInputDatasetTest.java index d1a3af0ce2..3502ea8723 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataset/SearchQueryInputDatasetTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataset/SearchQueryInputDatasetTest.java @@ -5,6 +5,8 @@ package org.opensearch.ml.common.dataset; +import static org.junit.Assert.assertEquals; + import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; @@ -17,8 +19,6 @@ import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.search.builder.SearchSourceBuilder; -import static org.junit.Assert.assertEquals; - public class SearchQueryInputDatasetTest { @Rule @@ -26,7 +26,8 @@ public class SearchQueryInputDatasetTest { @Test public void writeTo_Success() throws IOException { - SearchQueryInputDataset searchQueryInputDataset = SearchQueryInputDataset.builder() + SearchQueryInputDataset searchQueryInputDataset = SearchQueryInputDataset + .builder() .indices(Arrays.asList("index1")) .searchSourceBuilder(new SearchSourceBuilder().query(new MatchAllQueryBuilder()).size(1)) .build(); @@ -45,9 +46,6 @@ public void writeTo_Success() throws IOException { public void init_EmptyIndices() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("indices can't be empty"); - SearchQueryInputDataset.builder() - .indices(new ArrayList<>()) - .searchSourceBuilder(new SearchSourceBuilder().size(1)) - .build(); + SearchQueryInputDataset.builder().indices(new ArrayList<>()).searchSourceBuilder(new SearchSourceBuilder().size(1)).build(); } -} \ No newline at end of file +} diff --git a/common/src/test/java/org/opensearch/ml/common/dataset/TextDocsInputDataSetTest.java b/common/src/test/java/org/opensearch/ml/common/dataset/TextDocsInputDataSetTest.java index 89f629e7c2..811a1243e7 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataset/TextDocsInputDataSetTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataset/TextDocsInputDataSetTest.java @@ -5,17 +5,17 @@ package org.opensearch.ml.common.dataset; +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.util.Arrays; + import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; -import java.io.IOException; -import java.util.Arrays; - -import static org.junit.Assert.assertEquals; - public class TextDocsInputDataSetTest { @Rule diff --git a/common/src/test/java/org/opensearch/ml/common/dataset/TextSimilarityInputDatasetTest.java b/common/src/test/java/org/opensearch/ml/common/dataset/TextSimilarityInputDatasetTest.java index 74969d537f..3a3695abe7 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataset/TextSimilarityInputDatasetTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataset/TextSimilarityInputDatasetTest.java @@ -31,7 +31,7 @@ import org.opensearch.core.common.io.stream.StreamOutput; public class TextSimilarityInputDatasetTest { - + @Test public void testStreaming() throws IOException { List docs = List.of("That is a happy dog", "it's summer"); @@ -50,8 +50,10 @@ public void testStreaming() throws IOException { public void noPairs_ThenFail() { List docs = List.of(); String queryText = "today is sunny"; - IllegalArgumentException e = assertThrows(IllegalArgumentException.class, - () -> TextSimilarityInputDataSet.builder().textDocs(docs).queryText(queryText).build()); + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> TextSimilarityInputDataSet.builder().textDocs(docs).queryText(queryText).build() + ); assert (e.getMessage().equals("No text documents were provided")); } @@ -59,7 +61,6 @@ public void noPairs_ThenFail() { public void noQuery_ThenFail() { List docs = List.of("That is a happy dog", "it's summer"); String queryText = null; - assertThrows(NullPointerException.class, - () -> TextSimilarityInputDataSet.builder().textDocs(docs).queryText(queryText).build()); + assertThrows(NullPointerException.class, () -> TextSimilarityInputDataSet.builder().textDocs(docs).queryText(queryText).build()); } } diff --git a/common/src/test/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSetTest.java b/common/src/test/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSetTest.java index 22a549a1d1..17ea68ce04 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSetTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSetTest.java @@ -5,6 +5,7 @@ import java.io.IOException; import java.util.HashMap; import java.util.Map; + import org.junit.Assert; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -52,7 +53,11 @@ public void writeTo_withActionType() throws IOException { parameters.put("key1", "test value1"); parameters.put("key2", "test value2"); ActionType actionType = ActionType.from("predict"); - RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).actionType(actionType).build(); + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet + .builder() + .parameters(parameters) + .actionType(actionType) + .build(); BytesStreamOutput output = new BytesStreamOutput(); inputDataSet.writeTo(output); diff --git a/common/src/test/java/org/opensearch/ml/common/exception/MLExceptionTest.java b/common/src/test/java/org/opensearch/ml/common/exception/MLExceptionTest.java index 5b872c0f36..21ce33f4ed 100644 --- a/common/src/test/java/org/opensearch/ml/common/exception/MLExceptionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/exception/MLExceptionTest.java @@ -5,12 +5,12 @@ package org.opensearch.ml.common.exception; -import org.junit.Test; - import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import org.junit.Test; + public class MLExceptionTest { @Test diff --git a/common/src/test/java/org/opensearch/ml/common/exception/MLLimitExceededExceptionTest.java b/common/src/test/java/org/opensearch/ml/common/exception/MLLimitExceededExceptionTest.java index a85e07e02d..ed3b6bb180 100644 --- a/common/src/test/java/org/opensearch/ml/common/exception/MLLimitExceededExceptionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/exception/MLLimitExceededExceptionTest.java @@ -5,11 +5,10 @@ package org.opensearch.ml.common.exception; -import org.junit.Test; - import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; + +import org.junit.Test; public class MLLimitExceededExceptionTest { diff --git a/common/src/test/java/org/opensearch/ml/common/exception/MLResourceNotFoundExceptionTest.java b/common/src/test/java/org/opensearch/ml/common/exception/MLResourceNotFoundExceptionTest.java index 9409859e71..7e1c41d297 100644 --- a/common/src/test/java/org/opensearch/ml/common/exception/MLResourceNotFoundExceptionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/exception/MLResourceNotFoundExceptionTest.java @@ -5,11 +5,11 @@ package org.opensearch.ml.common.exception; -import org.junit.Test; - import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import org.junit.Test; + public class MLResourceNotFoundExceptionTest { @Test diff --git a/common/src/test/java/org/opensearch/ml/common/exception/MLValidationExceptionTest.java b/common/src/test/java/org/opensearch/ml/common/exception/MLValidationExceptionTest.java index f9e3a04fdc..eee1038a52 100644 --- a/common/src/test/java/org/opensearch/ml/common/exception/MLValidationExceptionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/exception/MLValidationExceptionTest.java @@ -5,12 +5,11 @@ package org.opensearch.ml.common.exception; -import org.junit.Test; - import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import org.junit.Test; + public class MLValidationExceptionTest { @Test diff --git a/common/src/test/java/org/opensearch/ml/common/input/MLInputTest.java b/common/src/test/java/org/opensearch/ml/common/input/MLInputTest.java index a63f815232..a92941fa7b 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/MLInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/MLInputTest.java @@ -5,21 +5,35 @@ package org.opensearch.ml.common.input; -import lombok.NonNull; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Function; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.dataframe.ColumnMeta; import org.opensearch.ml.common.dataframe.ColumnType; import org.opensearch.ml.common.dataframe.ColumnValue; @@ -27,8 +41,6 @@ import org.opensearch.ml.common.dataframe.DefaultDataFrame; import org.opensearch.ml.common.dataframe.DoubleValue; import org.opensearch.ml.common.dataframe.Row; -import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.dataset.DataFrameInputDataset; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.SearchQueryInputDataset; @@ -41,18 +53,7 @@ import org.opensearch.search.SearchModule; import org.opensearch.search.builder.SearchSourceBuilder; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.function.Consumer; -import java.util.function.Function; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; +import lombok.NonNull; public class MLInputTest { @@ -79,11 +80,12 @@ public void setUp() throws Exception { rows.add(new Row(new ColumnValue[] { new DoubleValue(2.0) })); rows.add(new Row(new ColumnValue[] { new DoubleValue(3.0) })); DataFrame dataFrame = new DefaultDataFrame(columnMetas, rows); - input = MLInput.builder() - .algorithm(algorithm) - .parameters(LinearRegressionParams.builder().learningRate(0.1).build()) - .inputDataset(DataFrameInputDataset.builder().dataFrame(dataFrame).build()) - .build(); + input = MLInput + .builder() + .algorithm(algorithm) + .parameters(LinearRegressionParams.builder().learningRate(0.1).build()) + .inputDataset(DataFrameInputDataset.builder().dataFrame(dataFrame).build()) + .build(); } @Test @@ -96,11 +98,13 @@ public void constructor_NullAlgorithm() { @Test public void parse_LinearRegression() throws IOException { String indexName = "index1"; - SearchQueryInputDataset inputDataset = SearchQueryInputDataset.builder() - .indices(Arrays.asList(indexName)) - .searchSourceBuilder(new SearchSourceBuilder().query(new MatchAllQueryBuilder()).size(1)) - .build(); - String expectedInputStr = "{\"algorithm\":\"LINEAR_REGRESSION\",\"input_index\":[\"index1\"],\"input_query\":{\"size\":1,\"query\":{\"match_all\":{\"boost\":1.0}}}}"; + SearchQueryInputDataset inputDataset = SearchQueryInputDataset + .builder() + .indices(Arrays.asList(indexName)) + .searchSourceBuilder(new SearchSourceBuilder().query(new MatchAllQueryBuilder()).size(1)) + .build(); + String expectedInputStr = + "{\"algorithm\":\"LINEAR_REGRESSION\",\"input_index\":[\"index1\"],\"input_query\":{\"size\":1,\"query\":{\"match_all\":{\"boost\":1.0}}}}"; testParse(FunctionName.LINEAR_REGRESSION, inputDataset, expectedInputStr, parsedInput -> { assertNotNull(parsedInput.getInputDataset()); assertEquals(1, ((SearchQueryInputDataset) parsedInput.getInputDataset()).getIndices().size()); @@ -109,15 +113,20 @@ public void parse_LinearRegression() throws IOException { @NonNull DataFrame dataFrame = new DefaultDataFrame( - new ColumnMeta[] { ColumnMeta.builder().name("value").columnType(ColumnType.FLOAT).build() }); + new ColumnMeta[] { ColumnMeta.builder().name("value").columnType(ColumnType.FLOAT).build() } + ); dataFrame.appendRow(new Float[] { 1.0f }); DataFrameInputDataset dataFrameInputDataset = DataFrameInputDataset.builder().dataFrame(dataFrame).build(); - expectedInputStr = "{\"algorithm\":\"LINEAR_REGRESSION\",\"input_data\":{\"column_metas\":[{\"name\":\"value\",\"column_type\":\"FLOAT\"}],\"rows\":[{\"values\":[{\"column_type\":\"FLOAT\",\"value\":1.0}]}]}}"; + expectedInputStr = + "{\"algorithm\":\"LINEAR_REGRESSION\",\"input_data\":{\"column_metas\":[{\"name\":\"value\",\"column_type\":\"FLOAT\"}],\"rows\":[{\"values\":[{\"column_type\":\"FLOAT\",\"value\":1.0}]}]}}"; testParse(FunctionName.LINEAR_REGRESSION, dataFrameInputDataset, expectedInputStr, parsedInput -> { assertNotNull(parsedInput.getInputDataset()); assertEquals(1, ((DataFrameInputDataset) parsedInput.getInputDataset()).getDataFrame().size()); - assertEquals(1.0f, ((DataFrameInputDataset) parsedInput.getInputDataset()).getDataFrame().getRow(0) - .getValue(0).floatValue(), 1e-5); + assertEquals( + 1.0f, + ((DataFrameInputDataset) parsedInput.getInputDataset()).getDataFrame().getRow(0).getValue(0).floatValue(), + 1e-5 + ); }); } @@ -125,13 +134,15 @@ private void parse_NLPModel(FunctionName functionName) throws IOException { String sentence = "test sentence"; String column = "column1"; Integer position = 1; - ModelResultFilter resultFilter = ModelResultFilter.builder() - .targetResponse(Arrays.asList(column)) - .targetResponsePositions(Arrays.asList(position)) - .build(); + ModelResultFilter resultFilter = ModelResultFilter + .builder() + .targetResponse(Arrays.asList(column)) + .targetResponsePositions(Arrays.asList(position)) + .build(); TextDocsInputDataSet inputDataset = TextDocsInputDataSet.builder().docs(Arrays.asList(sentence)).resultFilter(resultFilter).build(); - String expectedInputStr = "{\"algorithm\":\"functionName\",\"text_docs\":[\"test sentence\"],\"return_bytes\":false,\"return_number\":false,\"target_response\":[\"column1\"],\"target_response_positions\":[1]}"; + String expectedInputStr = + "{\"algorithm\":\"functionName\",\"text_docs\":[\"test sentence\"],\"return_bytes\":false,\"return_number\":false,\"target_response\":[\"column1\"],\"target_response_positions\":[1]}"; expectedInputStr = expectedInputStr.replace("functionName", functionName.toString()); testParse(functionName, inputDataset, expectedInputStr, parsedInput -> { assertNotNull(parsedInput.getInputDataset()); @@ -163,7 +174,6 @@ private void parse_NLPModel_NullResultFilter(FunctionName functionName) throws I }); } - @Test public void parse_NLPRelated_NullResultFilter() throws IOException { parse_NLPModel_NullResultFilter(FunctionName.TEXT_EMBEDDING); @@ -174,12 +184,14 @@ public void parse_NLPRelated_NullResultFilter() throws IOException { @Test public void parse_Remote_Model() throws IOException { Map parameters = Map.of("TransformJobName", "new name"); - RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder() - .parameters(parameters) - .actionType(ConnectorAction.ActionType.PREDICT) - .build(); + RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet + .builder() + .parameters(parameters) + .actionType(ConnectorAction.ActionType.PREDICT) + .build(); - String expectedInputStr = "{\"algorithm\":\"REMOTE\",\"parameters\":{\"TransformJobName\":\"new name\"},\"action_type\":\"PREDICT\"}"; + String expectedInputStr = + "{\"algorithm\":\"REMOTE\",\"parameters\":{\"TransformJobName\":\"new name\"},\"action_type\":\"PREDICT\"}"; testParse(FunctionName.REMOTE, remoteInferenceInputDataSet, expectedInputStr, parsedInput -> { assertNotNull(parsedInput.getInputDataset()); @@ -191,21 +203,30 @@ public void parse_Remote_Model() throws IOException { @Test public void parse_Remote_Model_With_ActionType() throws IOException { Map parameters = Map.of("TransformJobName", "new name"); - RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder() - .parameters(parameters) - .actionType(ConnectorAction.ActionType.BATCH_PREDICT) - .build(); - - String expectedInputStr = "{\"algorithm\":\"REMOTE\",\"parameters\":{\"TransformJobName\":\"new name\"},\"action_type\":\"BATCH_PREDICT\"}"; - - testParseWithActionType(FunctionName.REMOTE, remoteInferenceInputDataSet, ConnectorAction.ActionType.BATCH_PREDICT, expectedInputStr, parsedInput -> { - assertNotNull(parsedInput.getInputDataset()); - RemoteInferenceInputDataSet parsedInputDataSet = (RemoteInferenceInputDataSet) parsedInput.getInputDataset(); - assertEquals(ConnectorAction.ActionType.BATCH_PREDICT, parsedInputDataSet.getActionType()); - }); + RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet + .builder() + .parameters(parameters) + .actionType(ConnectorAction.ActionType.BATCH_PREDICT) + .build(); + + String expectedInputStr = + "{\"algorithm\":\"REMOTE\",\"parameters\":{\"TransformJobName\":\"new name\"},\"action_type\":\"BATCH_PREDICT\"}"; + + testParseWithActionType( + FunctionName.REMOTE, + remoteInferenceInputDataSet, + ConnectorAction.ActionType.BATCH_PREDICT, + expectedInputStr, + parsedInput -> { + assertNotNull(parsedInput.getInputDataset()); + RemoteInferenceInputDataSet parsedInputDataSet = (RemoteInferenceInputDataSet) parsedInput.getInputDataset(); + assertEquals(ConnectorAction.ActionType.BATCH_PREDICT, parsedInputDataSet.getActionType()); + } + ); } - private void testParse(FunctionName algorithm, MLInputDataset inputDataset, String expectedInputStr, Consumer verify) throws IOException { + private void testParse(FunctionName algorithm, MLInputDataset inputDataset, String expectedInputStr, Consumer verify) + throws IOException { MLInput input = MLInput.builder().inputDataset(inputDataset).algorithm(algorithm).build(); XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); input.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -213,9 +234,13 @@ private void testParse(FunctionName algorithm, MLInputDataset inputDataset, Stri String jsonStr = builder.toString(); assertEquals(expectedInputStr, jsonStr); - XContentParser parser = XContentType.JSON.xContent() - .createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); parser.nextToken(); MLInput parsedInput = MLInput.parse(parser, algorithm.name()); assertEquals(input.getFunctionName(), parsedInput.getFunctionName()); @@ -223,7 +248,13 @@ private void testParse(FunctionName algorithm, MLInputDataset inputDataset, Stri verify.accept(parsedInput); } - private void testParseWithActionType(FunctionName algorithm, MLInputDataset inputDataset, ConnectorAction.ActionType actionType, String expectedInputStr, Consumer verify) throws IOException { + private void testParseWithActionType( + FunctionName algorithm, + MLInputDataset inputDataset, + ConnectorAction.ActionType actionType, + String expectedInputStr, + Consumer verify + ) throws IOException { MLInput input = MLInput.builder().inputDataset(inputDataset).algorithm(algorithm).build(); XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); input.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -231,9 +262,13 @@ private void testParseWithActionType(FunctionName algorithm, MLInputDataset inpu String jsonStr = builder.toString(); assertEquals(expectedInputStr, jsonStr); - XContentParser parser = XContentType.JSON.xContent() - .createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); parser.nextToken(); MLInput parsedInput = MLInput.parse(parser, algorithm.name(), actionType); assertEquals(input.getFunctionName(), parsedInput.getFunctionName()); @@ -274,7 +309,8 @@ public void testParse_TextSimilarity() throws IOException { MLInputDataset dataset = TextSimilarityInputDataSet.builder().textDocs(docs).queryText(queryText).build(); input = new TextSimilarityMLInput(FunctionName.TEXT_SIMILARITY, dataset); MLInput inp = new MLInput(FunctionName.TEXT_SIMILARITY, null, dataset); - String expected = "{\"algorithm\":\"TEXT_SIMILARITY\",\"query_text\":\"today is sunny\",\"text_docs\":[\"That is a happy dog\",\"it's summer\"]}"; + String expected = + "{\"algorithm\":\"TEXT_SIMILARITY\",\"query_text\":\"today is sunny\",\"text_docs\":[\"That is a happy dog\",\"it's summer\"]}"; XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); inp.toXContent(builder, ToXContent.EMPTY_PARAMS); diff --git a/common/src/test/java/org/opensearch/ml/common/input/execute/agent/AgentMLInputTests.java b/common/src/test/java/org/opensearch/ml/common/input/execute/agent/AgentMLInputTests.java index 36235adffe..3dcbbe89c2 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/execute/agent/AgentMLInputTests.java +++ b/common/src/test/java/org/opensearch/ml/common/input/execute/agent/AgentMLInputTests.java @@ -5,6 +5,17 @@ package org.opensearch.ml.common.input.execute.agent; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + import org.junit.Test; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -13,17 +24,6 @@ import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; -import java.io.IOException; -import java.util.HashMap; -import java.util.Map; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - public class AgentMLInputTests { @Test @@ -77,21 +77,19 @@ public void testConstructorWithXContentParser() throws IOException { // Simulate parser behavior for START_OBJECT token when(parser.currentToken()).thenReturn(XContentParser.Token.START_OBJECT); - when(parser.nextToken()).thenReturn(XContentParser.Token.FIELD_NAME) - .thenReturn(XContentParser.Token.VALUE_STRING) - .thenReturn(XContentParser.Token.FIELD_NAME) // For PARAMETERS_FIELD - .thenReturn(XContentParser.Token.START_OBJECT) // Start of PARAMETERS_FIELD map - .thenReturn(XContentParser.Token.FIELD_NAME) // Key in PARAMETERS_FIELD map - .thenReturn(XContentParser.Token.VALUE_STRING) // Value in PARAMETERS_FIELD map - .thenReturn(XContentParser.Token.END_OBJECT) // End of PARAMETERS_FIELD map - .thenReturn(XContentParser.Token.END_OBJECT); // End of the main object + when(parser.nextToken()) + .thenReturn(XContentParser.Token.FIELD_NAME) + .thenReturn(XContentParser.Token.VALUE_STRING) + .thenReturn(XContentParser.Token.FIELD_NAME) // For PARAMETERS_FIELD + .thenReturn(XContentParser.Token.START_OBJECT) // Start of PARAMETERS_FIELD map + .thenReturn(XContentParser.Token.FIELD_NAME) // Key in PARAMETERS_FIELD map + .thenReturn(XContentParser.Token.VALUE_STRING) // Value in PARAMETERS_FIELD map + .thenReturn(XContentParser.Token.END_OBJECT) // End of PARAMETERS_FIELD map + .thenReturn(XContentParser.Token.END_OBJECT); // End of the main object // Simulate parser behavior for agent_id - when(parser.currentName()).thenReturn("agent_id") - .thenReturn("parameters") - .thenReturn("paramKey"); - when(parser.text()).thenReturn("testAgentId") - .thenReturn("paramValue"); + when(parser.currentName()).thenReturn("agent_id").thenReturn("parameters").thenReturn("paramKey"); + when(parser.text()).thenReturn("testAgentId").thenReturn("paramValue"); // Simulate parser behavior for parameters Map paramMap = new HashMap<>(); diff --git a/common/src/test/java/org/opensearch/ml/common/input/execute/anomalylocalization/AnomalyLocalizationInputTests.java b/common/src/test/java/org/opensearch/ml/common/input/execute/anomalylocalization/AnomalyLocalizationInputTests.java index bf1a809501..34682155da 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/execute/anomalylocalization/AnomalyLocalizationInputTests.java +++ b/common/src/test/java/org/opensearch/ml/common/input/execute/anomalylocalization/AnomalyLocalizationInputTests.java @@ -5,18 +5,20 @@ package org.opensearch.ml.common.input.execute.anomalylocalization; +import static org.junit.Assert.assertEquals; + import java.util.Arrays; import java.util.Collections; import java.util.Optional; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; -import org.opensearch.core.common.io.stream.NamedWriteableRegistry; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; @@ -24,21 +26,33 @@ import org.opensearch.search.SearchModule; import org.opensearch.search.aggregations.AggregationBuilders; -import static org.junit.Assert.assertEquals; - public class AnomalyLocalizationInputTests { @Test public void testXContentFullObject() throws Exception { - AnomalyLocalizationInput input = new AnomalyLocalizationInput("indexName", Arrays.asList("attribute"), Arrays.asList(AggregationBuilders.max("max").field("field"), - AggregationBuilders.min("min").field("field")), "@timestamp", 0L, 10L, 1L, 2, Optional.of(3L), - Optional.of(QueryBuilders.matchAllQuery())); + AnomalyLocalizationInput input = new AnomalyLocalizationInput( + "indexName", + Arrays.asList("attribute"), + Arrays.asList(AggregationBuilders.max("max").field("field"), AggregationBuilders.min("min").field("field")), + "@timestamp", + 0L, + 10L, + 1L, + 2, + Optional.of(3L), + Optional.of(QueryBuilders.matchAllQuery()) + ); XContentBuilder builder = XContentFactory.jsonBuilder(); builder = input.toXContent(builder, null); String json = builder.toString(); - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, json); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + json + ); parser.nextToken(); AnomalyLocalizationInput newInput = AnomalyLocalizationInput.parse(parser); @@ -47,14 +61,29 @@ public void testXContentFullObject() throws Exception { @Test public void testXContentMissingAnomalyStartFilter() throws Exception { - AnomalyLocalizationInput input = new AnomalyLocalizationInput("indexName", Arrays.asList("attribute"), Arrays.asList(AggregationBuilders.max("max").field("field")), - "@timestamp", 0L, 10L, 1L, 2, Optional.empty(), Optional.empty()); + AnomalyLocalizationInput input = new AnomalyLocalizationInput( + "indexName", + Arrays.asList("attribute"), + Arrays.asList(AggregationBuilders.max("max").field("field")), + "@timestamp", + 0L, + 10L, + 1L, + 2, + Optional.empty(), + Optional.empty() + ); XContentBuilder builder = XContentFactory.jsonBuilder(); builder = input.toXContent(builder, null); String json = builder.toString(); - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, json); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + json + ); parser.nextToken(); AnomalyLocalizationInput newInput = AnomalyLocalizationInput.parse(parser); @@ -63,14 +92,25 @@ public void testXContentMissingAnomalyStartFilter() throws Exception { @Test public void testWriteable() throws Exception { - AnomalyLocalizationInput input = new AnomalyLocalizationInput("indexName", Arrays.asList("attribute"), Arrays.asList(AggregationBuilders.max("max").field("field"), - AggregationBuilders.min("min").field("field")), "@timestamp", 0L, 10L, 1L, 2, Optional.of(3L), - Optional.of(QueryBuilders.matchAllQuery())); + AnomalyLocalizationInput input = new AnomalyLocalizationInput( + "indexName", + Arrays.asList("attribute"), + Arrays.asList(AggregationBuilders.max("max").field("field"), AggregationBuilders.min("min").field("field")), + "@timestamp", + 0L, + 10L, + 1L, + 2, + Optional.of(3L), + Optional.of(QueryBuilders.matchAllQuery()) + ); BytesStreamOutput out = new BytesStreamOutput(); input.writeTo(out); - StreamInput in = new NamedWriteableAwareStreamInput(out.bytes().streamInput(), - new NamedWriteableRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedWriteables())); + StreamInput in = new NamedWriteableAwareStreamInput( + out.bytes().streamInput(), + new NamedWriteableRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedWriteables()) + ); AnomalyLocalizationInput newInput = new AnomalyLocalizationInput(in); assertEquals(input, newInput); diff --git a/common/src/test/java/org/opensearch/ml/common/input/execute/metricscorrelation/MetricsCorrelationInputTest.java b/common/src/test/java/org/opensearch/ml/common/input/execute/metricscorrelation/MetricsCorrelationInputTest.java index 832e07b7e7..c9ad55fc22 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/execute/metricscorrelation/MetricsCorrelationInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/execute/metricscorrelation/MetricsCorrelationInputTest.java @@ -5,6 +5,13 @@ package org.opensearch.ml.common.input.execute.metricscorrelation; +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Function; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -14,13 +21,6 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.function.Function; - -import static org.junit.Assert.assertEquals; - public class MetricsCorrelationInputTest { MetricsCorrelationInput input; @@ -39,9 +39,9 @@ public class MetricsCorrelationInputTest { @Before public void setUp() { List inputData = new ArrayList<>(); - inputData.add(new float[]{1.0f, 2.0f, 3.0f, 4.0f}); - inputData.add(new float[]{1.0f, 2.0f, 3.0f, 4.0f}); - inputData.add(new float[]{1.0f, 2.0f, 3.0f, 4.0f}); + inputData.add(new float[] { 1.0f, 2.0f, 3.0f, 4.0f }); + inputData.add(new float[] { 1.0f, 2.0f, 3.0f, 4.0f }); + inputData.add(new float[] { 1.0f, 2.0f, 3.0f, 4.0f }); input = MetricsCorrelationInput.builder().inputData(inputData).build(); } @@ -57,9 +57,9 @@ public void constructor_variableLengthInput() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("All the input metrics sizes should be same"); List inputData = new ArrayList<>(); - inputData.add(new float[]{1.0f, 2.0f, 3.0f, 4.0f}); - inputData.add(new float[]{1.0f, 2.0f, 3.0f}); - inputData.add(new float[]{1.0f, 2.0f, 3.0f, 4.0f}); + inputData.add(new float[] { 1.0f, 2.0f, 3.0f, 4.0f }); + inputData.add(new float[] { 1.0f, 2.0f, 3.0f }); + inputData.add(new float[] { 1.0f, 2.0f, 3.0f, 4.0f }); MetricsCorrelationInput.builder().inputData(inputData).build(); } diff --git a/common/src/test/java/org/opensearch/ml/common/input/execute/samplecalculator/LocalSampleCalculatorInputTest.java b/common/src/test/java/org/opensearch/ml/common/input/execute/samplecalculator/LocalSampleCalculatorInputTest.java index fec5c99ced..19535c1d7a 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/execute/samplecalculator/LocalSampleCalculatorInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/execute/samplecalculator/LocalSampleCalculatorInputTest.java @@ -5,6 +5,13 @@ package org.opensearch.ml.common.input.execute.samplecalculator; +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Function; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -14,13 +21,6 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.function.Function; - -import static org.junit.Assert.assertEquals; - public class LocalSampleCalculatorInputTest { LocalSampleCalculatorInput input; @@ -42,10 +42,7 @@ public void setUp() { inputData.add(1.0); inputData.add(2.0); inputData.add(3.0); - input = LocalSampleCalculatorInput.builder() - .operation("sum") - .inputData(inputData) - .build(); + input = LocalSampleCalculatorInput.builder().operation("sum").inputData(inputData).build(); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/input/nlp/QuestionAnsweringMLInputTest.java b/common/src/test/java/org/opensearch/ml/common/input/nlp/QuestionAnsweringMLInputTest.java index dd91f4023f..4a1134be2a 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/nlp/QuestionAnsweringMLInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/nlp/QuestionAnsweringMLInputTest.java @@ -4,6 +4,9 @@ */ package org.opensearch.ml.common.input.nlp; +import java.io.IOException; +import java.util.Collections; + import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -25,14 +28,8 @@ import org.opensearch.ml.common.input.MLInput; import org.opensearch.search.SearchModule; -import java.io.IOException; -import java.util.Collections; -import java.util.List; - -import static org.junit.Assert.assertThrows; - public class QuestionAnsweringMLInputTest { - + MLInput input; private final FunctionName algorithm = FunctionName.QUESTION_ANSWERING; @@ -50,9 +47,13 @@ public void testXContent_IsInternallyConsistent() throws IOException { XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); input.toXContent(builder, ToXContent.EMPTY_PARAMS); String jsonStr = builder.toString(); - XContentParser parser = XContentType.JSON.xContent() - .createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); parser.nextToken(); MLInput parsedInput = MLInput.parse(parser, input.getFunctionName().name()); @@ -69,15 +70,23 @@ public void testXContent_String() throws IOException { XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); input.toXContent(builder, ToXContent.EMPTY_PARAMS); String jsonStr = builder.toString(); - assert (jsonStr.equals("{\"algorithm\":\"QUESTION_ANSWERING\",\"question\":\"What color is apple\",\"context\":\"I like Apples. They are red\"}")); + assert (jsonStr + .equals( + "{\"algorithm\":\"QUESTION_ANSWERING\",\"question\":\"What color is apple\",\"context\":\"I like Apples. They are red\"}" + )); } @Test public void testParseJson() throws IOException { - String json = "{\"algorithm\":\"QUESTION_ANSWERING\",\"question\":\"What color is apple\",\"context\":\"I like Apples. They are red\"}"; - XContentParser parser = XContentType.JSON.xContent() - .createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, json); + String json = + "{\"algorithm\":\"QUESTION_ANSWERING\",\"question\":\"What color is apple\",\"context\":\"I like Apples. They are red\"}"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + json + ); parser.nextToken(); MLInput parsedInput = MLInput.parse(parser, input.getFunctionName().name()); diff --git a/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java b/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java index 397769146c..7068027bb9 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java @@ -1,7 +1,16 @@ package org.opensearch.ml.common.input.nlp; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; import java.util.stream.Collectors; import java.util.stream.Stream; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -17,19 +26,10 @@ import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.input.MLInput; -import org.opensearch.ml.common.output.model.ModelResultFilter; import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; +import org.opensearch.ml.common.output.model.ModelResultFilter; import org.opensearch.search.SearchModule; -import java.io.IOException; -import java.util.Arrays; -import java.util.Collections; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertTrue; - public class TextDocsMLInputTest { MLInput input; @@ -41,10 +41,14 @@ public class TextDocsMLInputTest { @Before public void setUp() throws Exception { - ModelResultFilter resultFilter = ModelResultFilter.builder().returnBytes(true).returnNumber(true) - .targetResponse(Arrays.asList("field1")).targetResponsePositions(Arrays.asList(2)).build(); - MLInputDataset inputDataset = TextDocsInputDataSet.builder().docs(Arrays.asList("doc1", "doc2")) - .resultFilter(resultFilter).build(); + ModelResultFilter resultFilter = ModelResultFilter + .builder() + .returnBytes(true) + .returnNumber(true) + .targetResponse(Arrays.asList("field1")) + .targetResponsePositions(Arrays.asList(2)) + .build(); + MLInputDataset inputDataset = TextDocsInputDataSet.builder().docs(Arrays.asList("doc1", "doc2")).resultFilter(resultFilter).build(); input = new TextDocsMLInput(algorithm, inputDataset); } @@ -58,27 +62,40 @@ public void parseTextDocsMLInput() throws IOException { @Test public void parseTextDocsMLInput_OldWay() throws IOException { - String jsonStr = "{\"text_docs\": [ \"doc1\", \"doc2\", null ],\"return_number\": true, \"return_bytes\": true,\"target_response\": [ \"field1\" ], \"target_response_positions\": [2]}"; + String jsonStr = + "{\"text_docs\": [ \"doc1\", \"doc2\", null ],\"return_number\": true, \"return_bytes\": true,\"target_response\": [ \"field1\" ], \"target_response_positions\": [2]}"; parseMLInput(jsonStr, 3); } @Test public void parseTextDocsMLInput_NewWay() throws IOException { - String jsonStr = "{\"text_docs\":[\"doc1\",\"doc2\"],\"result_filter\":{\"return_bytes\":true,\"return_number\":true,\"target_response\":[\"field1\"], \"target_response_positions\": [2]}}"; + String jsonStr = + "{\"text_docs\":[\"doc1\",\"doc2\"],\"result_filter\":{\"return_bytes\":true,\"return_number\":true,\"target_response\":[\"field1\"], \"target_response_positions\": [2]}}"; parseMLInput(jsonStr, 2); } @Test public void parseTextDocsMLInput_WithParameters() throws IOException { - String jsonStr = "{\"text_docs\":[\"doc1\",\"doc2\"],\"result_filter\":{\"return_bytes\":true,\"return_number\":true,\"target_response\":[\"field1\"], \"target_response_positions\": [2]}, \"parameters\" : {\"content_type\": \"passage\"}}"; + String jsonStr = + "{\"text_docs\":[\"doc1\",\"doc2\"],\"result_filter\":{\"return_bytes\":true,\"return_number\":true,\"target_response\":[\"field1\"], \"target_response_positions\": [2]}, \"parameters\" : {\"content_type\": \"passage\"}}"; parseMLInput(jsonStr, 2); } private void parseMLInput(String jsonStr, int docSize) throws IOException { - XContentParser parser = XContentType.JSON.xContent() - .createParser(new NamedXContentRegistry(Stream.concat(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents().stream(), Stream.of( - AsymmetricTextEmbeddingParameters.XCONTENT_REGISTRY)).collect(Collectors.toList())), null, jsonStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry( + Stream + .concat( + new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents().stream(), + Stream.of(AsymmetricTextEmbeddingParameters.XCONTENT_REGISTRY) + ) + .collect(Collectors.toList()) + ), + null, + jsonStr + ); parser.nextToken(); MLInput parsedInput = MLInput.parse(parser, input.getFunctionName().name()); diff --git a/common/src/test/java/org/opensearch/ml/common/input/nlp/TextSimilarityMLInputTest.java b/common/src/test/java/org/opensearch/ml/common/input/nlp/TextSimilarityMLInputTest.java index 296b939f5f..ec6b860b41 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/nlp/TextSimilarityMLInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/nlp/TextSimilarityMLInputTest.java @@ -45,7 +45,7 @@ import org.opensearch.search.SearchModule; public class TextSimilarityMLInputTest { - + MLInput input; private final FunctionName algorithm = FunctionName.TEXT_SIMILARITY; @@ -63,9 +63,13 @@ public void testXContent_IsInternallyConsistent() throws IOException { XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); input.toXContent(builder, ToXContent.EMPTY_PARAMS); String jsonStr = builder.toString(); - XContentParser parser = XContentType.JSON.xContent() - .createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); parser.nextToken(); MLInput parsedInput = MLInput.parse(parser, input.getFunctionName().name()); @@ -84,15 +88,23 @@ public void testXContent_String() throws IOException { XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); input.toXContent(builder, ToXContent.EMPTY_PARAMS); String jsonStr = builder.toString(); - assert (jsonStr.equals("{\"algorithm\":\"TEXT_SIMILARITY\",\"query_text\":\"today is sunny\",\"text_docs\":[\"That is a happy dog\",\"it's summer\"]}")); + assert (jsonStr + .equals( + "{\"algorithm\":\"TEXT_SIMILARITY\",\"query_text\":\"today is sunny\",\"text_docs\":[\"That is a happy dog\",\"it's summer\"]}" + )); } @Test public void testParseJson() throws IOException { - String json = "{\"algorithm\":\"TEXT_SIMILARITY\",\"query_text\":\"today is sunny\",\"text_docs\":[\"That is a happy dog\",\"it's summer\"]}"; - XContentParser parser = XContentType.JSON.xContent() - .createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, json); + String json = + "{\"algorithm\":\"TEXT_SIMILARITY\",\"query_text\":\"today is sunny\",\"text_docs\":[\"That is a happy dog\",\"it's summer\"]}"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + json + ); parser.nextToken(); MLInput parsedInput = MLInput.parse(parser, input.getFunctionName().name()); @@ -109,13 +121,19 @@ public void testParseJson() throws IOException { @Test public void testParseJson_NoPairs_ThenFail() throws IOException { String json = "{\"algorithm\":\"TEXT_SIMILARITY\",\"query_text\":\"today is sunny\",\"text_docs\":[]}"; - XContentParser parser = XContentType.JSON.xContent() - .createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, json); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + json + ); parser.nextToken(); - IllegalArgumentException e = assertThrows(IllegalArgumentException.class, - () -> MLInput.parse(parser, input.getFunctionName().name())); + IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> MLInput.parse(parser, input.getFunctionName().name()) + ); assert (e.getMessage().equals("No text documents were provided")); } diff --git a/common/src/test/java/org/opensearch/ml/common/input/parameter/ad/AnomalyDetectionLibSVMParamsTest.java b/common/src/test/java/org/opensearch/ml/common/input/parameter/ad/AnomalyDetectionLibSVMParamsTest.java index 544f43ec1e..0e87cf62d4 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/parameter/ad/AnomalyDetectionLibSVMParamsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/parameter/ad/AnomalyDetectionLibSVMParamsTest.java @@ -5,6 +5,11 @@ package org.opensearch.ml.common.input.parameter.ad; +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.util.function.Function; + import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -12,11 +17,6 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; -import java.util.function.Function; - -import static org.junit.Assert.assertEquals; - public class AnomalyDetectionLibSVMParamsTest { AnomalyDetectionLibSVMParams params; @@ -30,15 +30,16 @@ public class AnomalyDetectionLibSVMParamsTest { @Before public void setUp() { - params = AnomalyDetectionLibSVMParams.builder() - .kernelType(AnomalyDetectionLibSVMParams.ADKernelType.POLY) - .gamma(1.0) - .nu(0.5) - .cost(1.0) - .coeff(0.1) - .epsilon(0.2) - .degree(2) - .build(); + params = AnomalyDetectionLibSVMParams + .builder() + .kernelType(AnomalyDetectionLibSVMParams.ADKernelType.POLY) + .gamma(1.0) + .nu(0.5) + .cost(1.0) + .coeff(0.1) + .epsilon(0.2) + .degree(2) + .build(); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/input/parameter/clustering/KMeansParamsTest.java b/common/src/test/java/org/opensearch/ml/common/input/parameter/clustering/KMeansParamsTest.java index b4cd2f0c81..a9886a0d3f 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/parameter/clustering/KMeansParamsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/parameter/clustering/KMeansParamsTest.java @@ -5,6 +5,13 @@ package org.opensearch.ml.common.input.parameter.clustering; +import static org.junit.Assert.assertEquals; +import static org.opensearch.ml.common.TestHelper.contentObjectToString; +import static org.opensearch.ml.common.TestHelper.testParseFromString; + +import java.io.IOException; +import java.util.function.Function; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -14,13 +21,6 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; -import java.util.function.Function; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.ml.common.TestHelper.contentObjectToString; -import static org.opensearch.ml.common.TestHelper.testParseFromString; - public class KMeansParamsTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -28,7 +28,7 @@ public class KMeansParamsTest { KMeansParams params; private Function function = parser -> { try { - return (KMeansParams)KMeansParams.parse(parser); + return (KMeansParams) KMeansParams.parse(parser); } catch (IOException e) { throw new RuntimeException("failed to parse KMeansParams", e); } @@ -36,11 +36,7 @@ public class KMeansParamsTest { @Before public void setUp() { - params = KMeansParams.builder() - .centroids(2) - .iterations(10) - .distanceType(KMeansParams.DistanceType.COSINE) - .build(); + params = KMeansParams.builder().centroids(2).iterations(10).distanceType(KMeansParams.DistanceType.COSINE).build(); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/input/parameter/clustering/RCFSummarizeParamsTest.java b/common/src/test/java/org/opensearch/ml/common/input/parameter/clustering/RCFSummarizeParamsTest.java index 0a29ebfffa..e6cbba3722 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/parameter/clustering/RCFSummarizeParamsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/parameter/clustering/RCFSummarizeParamsTest.java @@ -5,6 +5,13 @@ package org.opensearch.ml.common.input.parameter.clustering; +import static org.junit.Assert.assertEquals; +import static org.opensearch.ml.common.TestHelper.contentObjectToString; +import static org.opensearch.ml.common.TestHelper.testParseFromString; + +import java.io.IOException; +import java.util.function.Function; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -14,13 +21,6 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; -import java.util.function.Function; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.ml.common.TestHelper.contentObjectToString; -import static org.opensearch.ml.common.TestHelper.testParseFromString; - public class RCFSummarizeParamsTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -28,7 +28,7 @@ public class RCFSummarizeParamsTest { RCFSummarizeParams params; private Function function = parser -> { try { - return (RCFSummarizeParams)RCFSummarizeParams.parse(parser); + return (RCFSummarizeParams) RCFSummarizeParams.parse(parser); } catch (IOException e) { throw new RuntimeException("failed to parse RCFSummarizeParams", e); } @@ -36,11 +36,7 @@ public class RCFSummarizeParamsTest { @Before public void setUp() { - params = RCFSummarizeParams.builder() - .maxK(2) - .initialK(10) - .distanceType(RCFSummarizeParams.DistanceType.L1) - .build(); + params = RCFSummarizeParams.builder().maxK(2).initialK(10).distanceType(RCFSummarizeParams.DistanceType.L1).build(); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/input/parameter/rcf/BatchRCFParamsTest.java b/common/src/test/java/org/opensearch/ml/common/input/parameter/rcf/BatchRCFParamsTest.java index 7137763146..e9e2737174 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/parameter/rcf/BatchRCFParamsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/parameter/rcf/BatchRCFParamsTest.java @@ -5,6 +5,11 @@ package org.opensearch.ml.common.input.parameter.rcf; +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.util.function.Function; + import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -12,11 +17,6 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; -import java.util.function.Function; - -import static org.junit.Assert.assertEquals; - public class BatchRCFParamsTest { BatchRCFParams params; @@ -30,13 +30,7 @@ public class BatchRCFParamsTest { @Before public void setUp() { - params = BatchRCFParams.builder() - .numberOfTrees(10) - .shingleSize(8) - .sampleSize(256) - .outputAfter(32) - .trainingDataSize(200) - .build(); + params = BatchRCFParams.builder().numberOfTrees(10).shingleSize(8).sampleSize(256).outputAfter(32).trainingDataSize(200).build(); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/input/parameter/rcf/FitRCFParamsTest.java b/common/src/test/java/org/opensearch/ml/common/input/parameter/rcf/FitRCFParamsTest.java index e87973a98a..e2e7050a72 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/parameter/rcf/FitRCFParamsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/parameter/rcf/FitRCFParamsTest.java @@ -5,6 +5,11 @@ package org.opensearch.ml.common.input.parameter.rcf; +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.util.function.Function; + import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -12,11 +17,6 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; -import java.util.function.Function; - -import static org.junit.Assert.assertEquals; - public class FitRCFParamsTest { FitRCFParams params; @@ -30,17 +30,18 @@ public class FitRCFParamsTest { @Before public void setUp() { - params = FitRCFParams.builder() - .numberOfTrees(10) - .shingleSize(8) - .sampleSize(256) - .outputAfter(32) - .timeDecay(0.001) - .anomalyRate(0.005) - .timeField("timestamp") - .dateFormat("yyyy-mm-dd") - .timeZone("UTC") - .build(); + params = FitRCFParams + .builder() + .numberOfTrees(10) + .shingleSize(8) + .sampleSize(256) + .outputAfter(32) + .timeDecay(0.001) + .anomalyRate(0.005) + .timeField("timestamp") + .dateFormat("yyyy-mm-dd") + .timeZone("UTC") + .build(); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/input/parameter/regression/LinearRegressionParamsTest.java b/common/src/test/java/org/opensearch/ml/common/input/parameter/regression/LinearRegressionParamsTest.java index be71883e92..dc9d92d0cb 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/parameter/regression/LinearRegressionParamsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/parameter/regression/LinearRegressionParamsTest.java @@ -5,6 +5,13 @@ package org.opensearch.ml.common.input.parameter.regression; +import static org.junit.Assert.assertEquals; +import static org.opensearch.ml.common.TestHelper.contentObjectToString; +import static org.opensearch.ml.common.TestHelper.testParseFromString; + +import java.io.IOException; +import java.util.function.Function; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -14,13 +21,6 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; -import java.util.function.Function; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.ml.common.TestHelper.contentObjectToString; -import static org.opensearch.ml.common.TestHelper.testParseFromString; - public class LinearRegressionParamsTest { @Rule @@ -39,21 +39,21 @@ public class LinearRegressionParamsTest { @Before public void setUp() { params = LinearRegressionParams - .builder() - .objectiveType(LinearRegressionParams.ObjectiveType.ABSOLUTE_LOSS) - .optimizerType(LinearRegressionParams.OptimizerType.ADAM) - .learningRate(0.1) - .momentumType(LinearRegressionParams.MomentumType.NESTEROV) - .momentumFactor(0.2) - .epsilon(0.3) - .beta1(0.4) - .beta2(0.5) - .decayRate(0.6) - .epochs(1) - .batchSize(2) - .seed(3L) - .target("test_target") - .build(); + .builder() + .objectiveType(LinearRegressionParams.ObjectiveType.ABSOLUTE_LOSS) + .optimizerType(LinearRegressionParams.OptimizerType.ADAM) + .learningRate(0.1) + .momentumType(LinearRegressionParams.MomentumType.NESTEROV) + .momentumFactor(0.2) + .epsilon(0.3) + .beta1(0.4) + .beta2(0.5) + .decayRate(0.6) + .epochs(1) + .batchSize(2) + .seed(3L) + .target("test_target") + .build(); } @Test @@ -69,21 +69,21 @@ public void readInputStream_Success() throws IOException { @Test public void parse_PassIntValueToDoubleField() throws IOException { LinearRegressionParams params = LinearRegressionParams - .builder() - .objectiveType(LinearRegressionParams.ObjectiveType.ABSOLUTE_LOSS) - .optimizerType(LinearRegressionParams.OptimizerType.ADAM) - .learningRate(0.1) - .momentumType(LinearRegressionParams.MomentumType.NESTEROV) - .momentumFactor(0.2) - .epsilon(3.0) - .beta1(0.4) - .beta2(0.5) - .decayRate(0.6) - .epochs(1) - .batchSize(2) - .seed(3L) - .target("test_target") - .build(); + .builder() + .objectiveType(LinearRegressionParams.ObjectiveType.ABSOLUTE_LOSS) + .optimizerType(LinearRegressionParams.OptimizerType.ADAM) + .learningRate(0.1) + .momentumType(LinearRegressionParams.MomentumType.NESTEROV) + .momentumFactor(0.2) + .epsilon(3.0) + .beta1(0.4) + .beta2(0.5) + .decayRate(0.6) + .epochs(1) + .batchSize(2) + .seed(3L) + .target("test_target") + .build(); String paramsStr = contentObjectToString(params); testParseFromString(params, paramsStr.replace("\"epsilon\":3.0,", "\"epsilon\":3,"), function); } diff --git a/common/src/test/java/org/opensearch/ml/common/input/parameter/regression/LogisticRegressionParamsTest.java b/common/src/test/java/org/opensearch/ml/common/input/parameter/regression/LogisticRegressionParamsTest.java index ab45c9e41e..12b5a7f5cb 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/parameter/regression/LogisticRegressionParamsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/parameter/regression/LogisticRegressionParamsTest.java @@ -5,6 +5,14 @@ package org.opensearch.ml.common.input.parameter.regression; +import static org.junit.Assert.assertEquals; +import static org.opensearch.ml.common.TestHelper.contentObjectToString; +import static org.opensearch.ml.common.TestHelper.testParseFromString; +import static org.opensearch.ml.common.input.parameter.regression.LogisticRegressionParams.PARSE_FIELD_NAME; + +import java.io.IOException; +import java.util.function.Function; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -14,14 +22,6 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; -import java.util.function.Function; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.ml.common.TestHelper.contentObjectToString; -import static org.opensearch.ml.common.TestHelper.testParseFromString; -import static org.opensearch.ml.common.input.parameter.regression.LogisticRegressionParams.PARSE_FIELD_NAME; - public class LogisticRegressionParamsTest { @Rule @@ -40,21 +40,21 @@ public class LogisticRegressionParamsTest { @Before public void setUp() { logisticRegressionParams = LogisticRegressionParams - .builder() - .objectiveType(LogisticRegressionParams.ObjectiveType.LOGMULTICLASS) - .optimizerType(LogisticRegressionParams.OptimizerType.ADA_GRAD) - .learningRate(0.1) - .momentumType(LogisticRegressionParams.MomentumType.STANDARD) - .momentumFactor(0.2) - .epsilon(0.3) - .beta1(0.4) - .beta2(0.5) - .decayRate(0.6) - .epochs(1) - .batchSize(2) - .seed(3L) - .target("test_target") - .build(); + .builder() + .objectiveType(LogisticRegressionParams.ObjectiveType.LOGMULTICLASS) + .optimizerType(LogisticRegressionParams.OptimizerType.ADA_GRAD) + .learningRate(0.1) + .momentumType(LogisticRegressionParams.MomentumType.STANDARD) + .momentumFactor(0.2) + .epsilon(0.3) + .beta1(0.4) + .beta2(0.5) + .decayRate(0.6) + .epochs(1) + .batchSize(2) + .seed(3L) + .target("test_target") + .build(); } @Test @@ -114,8 +114,12 @@ public void parse_EmptyLogisticRegressionParams() throws IOException { @Test public void parse_LogisticRegressionParams_WrongExtraField() throws IOException { - TestHelper.testParseFromString(logisticRegressionParams, "{\"objective\":\"LOGMULTICLASS\",\"learning_rate\":0.1,\"wrong_field\":1.0}", function); + TestHelper + .testParseFromString( + logisticRegressionParams, + "{\"objective\":\"LOGMULTICLASS\",\"learning_rate\":0.1,\"wrong_field\":1.0}", + function + ); } } - diff --git a/common/src/test/java/org/opensearch/ml/common/input/parameter/sample/SampleAlgoParamsTest.java b/common/src/test/java/org/opensearch/ml/common/input/parameter/sample/SampleAlgoParamsTest.java index 2ad3fcca39..9dbc858e36 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/parameter/sample/SampleAlgoParamsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/parameter/sample/SampleAlgoParamsTest.java @@ -5,21 +5,21 @@ package org.opensearch.ml.common.input.parameter.sample; +import java.io.IOException; +import java.util.function.Function; + import org.junit.Before; import org.junit.Test; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; -import java.util.function.Function; - public class SampleAlgoParamsTest { SampleAlgoParams params; private Function function = parser -> { try { - return (SampleAlgoParams)SampleAlgoParams.parse(parser); + return (SampleAlgoParams) SampleAlgoParams.parse(parser); } catch (IOException e) { throw new RuntimeException("failed to parse SampleAlgoParams", e); } @@ -27,9 +27,7 @@ public class SampleAlgoParamsTest { @Before public void setUp() { - params = SampleAlgoParams.builder() - .sampleParam(2) - .build(); + params = SampleAlgoParams.builder().sampleParam(2).build(); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInputTest.java b/common/src/test/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInputTest.java index 759bf154b3..3ff8f82b18 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInputTest.java @@ -1,5 +1,8 @@ package org.opensearch.ml.common.input.remote; +import java.io.IOException; +import java.util.Collections; + import org.junit.Assert; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -12,9 +15,6 @@ import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.search.SearchModule; -import java.io.IOException; -import java.util.Collections; - public class RemoteInferenceMLInputTest { @Test @@ -22,7 +22,7 @@ public void constructor_parser() throws IOException { RemoteInferenceMLInput input = createRemoteInferenceMLInput(); Assert.assertNotNull(input.getInputDataset()); Assert.assertEquals(MLInputDataType.REMOTE, input.getInputDataset().getInputDataType()); - RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet)input.getInputDataset(); + RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) input.getInputDataset(); Assert.assertEquals(1, inputDataSet.getParameters().size()); Assert.assertEquals("hello world", inputDataSet.getParameters().get("prompt")); } @@ -36,7 +36,7 @@ public void constructor_stream() throws IOException { RemoteInferenceMLInput input = new RemoteInferenceMLInput(output.bytes().streamInput()); Assert.assertNotNull(input.getInputDataset()); Assert.assertEquals(MLInputDataType.REMOTE, input.getInputDataset().getInputDataType()); - RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet)input.getInputDataset(); + RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) input.getInputDataset(); Assert.assertEquals(1, inputDataSet.getParameters().size()); Assert.assertEquals("hello world", inputDataSet.getParameters().get("prompt")); Assert.assertEquals("BATCH_PREDICT", inputDataSet.getActionType().toString()); @@ -44,8 +44,13 @@ public void constructor_stream() throws IOException { private static RemoteInferenceMLInput createRemoteInferenceMLInput() throws IOException { String jsonStr = "{ \"parameters\": { \"prompt\": \"hello world\" }, \"action_type\": \"batch_predict\" }"; - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); parser.nextToken(); RemoteInferenceMLInput input = new RemoteInferenceMLInput(parser, FunctionName.REMOTE); return input; diff --git a/common/src/test/java/org/opensearch/ml/common/model/GuardrailsTests.java b/common/src/test/java/org/opensearch/ml/common/model/GuardrailsTests.java index a1b589d07c..b4429ebdf8 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/GuardrailsTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/GuardrailsTests.java @@ -5,6 +5,10 @@ package org.opensearch.ml.common.model; +import java.io.IOException; +import java.util.Collections; +import java.util.List; + import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -18,10 +22,6 @@ import org.opensearch.ml.common.TestHelper; import org.opensearch.search.SearchModule; -import java.io.IOException; -import java.util.Collections; -import java.util.List; - public class GuardrailsTests { StopWords stopWords; String[] regex; @@ -55,19 +55,27 @@ public void toXContent() throws IOException { guardrails.toXContent(builder, ToXContent.EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); - Assert.assertEquals("{\"type\":\"local_regex\"," + - "\"input_guardrail\":{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}," + - "\"output_guardrail\":{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}}", - content); + Assert + .assertEquals( + "{\"type\":\"local_regex\"," + + "\"input_guardrail\":{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}," + + "\"output_guardrail\":{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}}", + content + ); } @Test public void parse() throws IOException { - String jsonStr = "{\"type\":\"local_regex\"," + - "\"input_guardrail\":{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}," + - "\"output_guardrail\":{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}}"; - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); + String jsonStr = "{\"type\":\"local_regex\"," + + "\"input_guardrail\":{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}," + + "\"output_guardrail\":{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}}"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); parser.nextToken(); Guardrails guardrails = Guardrails.parse(parser); @@ -75,4 +83,4 @@ public void parse() throws IOException { Assert.assertEquals(guardrails.getInputGuardrail(), inputLocalRegexGuardrail); Assert.assertEquals(guardrails.getOutputGuardrail(), outputLocalRegexGuardrail); } -} \ No newline at end of file +} diff --git a/common/src/test/java/org/opensearch/ml/common/model/LocalRegexGuardrailTests.java b/common/src/test/java/org/opensearch/ml/common/model/LocalRegexGuardrailTests.java index 6c0cdfb1ef..6ad9c6d8e5 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/LocalRegexGuardrailTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/LocalRegexGuardrailTests.java @@ -5,6 +5,16 @@ package org.opensearch.ml.common.model; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.regex.Pattern; + import org.apache.lucene.search.TotalHits; import org.junit.Assert; import org.junit.Before; @@ -36,16 +46,6 @@ import org.opensearch.search.suggest.Suggest; import org.opensearch.threadpool.ThreadPool; -import java.io.IOException; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.concurrent.TimeUnit; -import java.util.regex.Pattern; - -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.when; - public class LocalRegexGuardrailTests { NamedXContentRegistry xContentRegistry; @Mock @@ -114,14 +114,24 @@ public void toXContent() throws IOException { localRegexGuardrail.toXContent(builder, ToXContent.EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); - Assert.assertEquals("{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"(.|\\n)*stop words(.|\\n)*\"]}", content); + Assert + .assertEquals( + "{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"(.|\\n)*stop words(.|\\n)*\"]}", + content + ); } @Test public void parse() throws IOException { - String jsonStr = "{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"(.|\\n)*stop words(.|\\n)*\"]}"; - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); + String jsonStr = + "{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"(.|\\n)*stop words(.|\\n)*\"]}"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); parser.nextToken(); LocalRegexGuardrail localRegexGuardrail = LocalRegexGuardrail.parse(parser); @@ -217,22 +227,22 @@ private SearchResponse createSearchResponse(int size) throws IOException { hits[0] = new SearchHit(0).sourceRef(BytesReference.bytes(content)); } return new SearchResponse( - new InternalSearchResponse( - new SearchHits(hits, new TotalHits(size, TotalHits.Relation.EQUAL_TO), 1.0f), - InternalAggregations.EMPTY, - new Suggest(Collections.emptyList()), - new SearchProfileShardResults(Collections.emptyMap()), - false, - false, - 1 - ), - "", - 5, - 5, - 0, - 100, - ShardSearchFailure.EMPTY_ARRAY, - SearchResponse.Clusters.EMPTY + new InternalSearchResponse( + new SearchHits(hits, new TotalHits(size, TotalHits.Relation.EQUAL_TO), 1.0f), + InternalAggregations.EMPTY, + new Suggest(Collections.emptyList()), + new SearchProfileShardResults(Collections.emptyMap()), + false, + false, + 1 + ), + "", + 5, + 5, + 0, + 100, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY ); } @@ -289,4 +299,4 @@ public SearchResponse get(long timeout, TimeUnit unit) { } }; } -} \ No newline at end of file +} diff --git a/common/src/test/java/org/opensearch/ml/common/model/MLDeployingSettingTests.java b/common/src/test/java/org/opensearch/ml/common/model/MLDeployingSettingTests.java index 72b5e883b5..970b4df91b 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/MLDeployingSettingTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/MLDeployingSettingTests.java @@ -5,7 +5,14 @@ package org.opensearch.ml.common.model; -import com.fasterxml.jackson.core.JsonParseException; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.util.Collections; +import java.util.function.Consumer; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -22,13 +29,7 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.search.SearchModule; -import java.io.IOException; -import java.util.Collections; -import java.util.function.Consumer; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; +import com.fasterxml.jackson.core.JsonParseException; public class MLDeployingSettingTests { @@ -43,9 +44,7 @@ public class MLDeployingSettingTests { @Before public void setUp() throws Exception { - deploySetting = MLDeploySetting.builder() - .isAutoDeployEnabled(true) - .build(); + deploySetting = MLDeploySetting.builder().isAutoDeployEnabled(true).build(); deploySettingNull = MLDeploySetting.builder().build(); @@ -53,9 +52,7 @@ public void setUp() throws Exception { @Test public void readInputStreamSuccess() throws IOException { - readInputStream(deploySetting, parsedInput -> { - assertTrue(parsedInput.getIsAutoDeployEnabled()); - }); + readInputStream(deploySetting, parsedInput -> { assertTrue(parsedInput.getIsAutoDeployEnabled()); }); } @Test @@ -74,9 +71,7 @@ public void testToXContentIncomplete() throws Exception { @Test public void parseSuccess() throws Exception { - testParseFromJsonString(expectedInputStr, parsedInput -> { - assertTrue(parsedInput.getIsAutoDeployEnabled()); - }); + testParseFromJsonString(expectedInputStr, parsedInput -> { assertTrue(parsedInput.getIsAutoDeployEnabled()); }); } @Test @@ -109,8 +104,9 @@ public void parseWithIllegalArgumentInteger() throws Exception { @Test public void parseWithIllegalField() throws Exception { - final String expectedInputStrWithIllegalField = "{\"is_auto_deploy_enabled\":true," + "\"model_ttl_hours\":0," + - "\"illegal_field\":\"This field need to be skipped.\"}"; + final String expectedInputStrWithIllegalField = "{\"is_auto_deploy_enabled\":true," + + "\"model_ttl_hours\":0," + + "\"illegal_field\":\"This field need to be skipped.\"}"; testParseFromJsonString(expectedInputStrWithIllegalField, parsedInput -> { try { @@ -122,10 +118,13 @@ public void parseWithIllegalField() throws Exception { } private void testParseFromJsonString(String expectedInputStr, Consumer verify) throws Exception { - XContentParser parser = XContentType.JSON.xContent() - .createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), LoggingDeprecationHandler.INSTANCE, - expectedInputStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + LoggingDeprecationHandler.INSTANCE, + expectedInputStr + ); parser.nextToken(); MLDeploySetting parsedInput = MLDeploySetting.parse(parser); verify.accept(parsedInput); diff --git a/common/src/test/java/org/opensearch/ml/common/model/MLGuardTests.java b/common/src/test/java/org/opensearch/ml/common/model/MLGuardTests.java index b2a29ba7c7..a99e66be35 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/MLGuardTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/MLGuardTests.java @@ -5,6 +5,12 @@ package org.opensearch.ml.common.model; +import static org.mockito.Mockito.when; + +import java.util.Collections; +import java.util.List; +import java.util.regex.Pattern; + import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -17,12 +23,6 @@ import org.opensearch.search.SearchModule; import org.opensearch.threadpool.ThreadPool; -import java.util.Collections; -import java.util.List; -import java.util.regex.Pattern; - -import static org.mockito.Mockito.when; - public class MLGuardTests { NamedXContentRegistry xContentRegistry; @@ -80,4 +80,4 @@ public void validateInitializedStopWordsEmpty() { Boolean res = mlGuard.validate(input, MLGuard.Type.INPUT, Collections.emptyMap()); Assert.assertTrue(res); } -} \ No newline at end of file +} diff --git a/common/src/test/java/org/opensearch/ml/common/model/MLModelFormatTests.java b/common/src/test/java/org/opensearch/ml/common/model/MLModelFormatTests.java index 8bdf0564e2..ee14189592 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/MLModelFormatTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/MLModelFormatTests.java @@ -5,12 +5,12 @@ package org.opensearch.ml.common.model; +import static org.junit.Assert.assertEquals; + import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; -import static org.junit.Assert.assertEquals; - public class MLModelFormatTests { @Rule diff --git a/common/src/test/java/org/opensearch/ml/common/model/MLModelStateTests.java b/common/src/test/java/org/opensearch/ml/common/model/MLModelStateTests.java index c4f8e7e51f..713f793ddf 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/MLModelStateTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/MLModelStateTests.java @@ -5,12 +5,11 @@ package org.opensearch.ml.common.model; +import static org.junit.Assert.assertEquals; + import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; -import org.opensearch.ml.common.CommonValue; - -import static org.junit.Assert.assertEquals; public class MLModelStateTests { diff --git a/common/src/test/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfigTests.java b/common/src/test/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfigTests.java index 4700039939..c115c9d1d7 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfigTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfigTests.java @@ -5,6 +5,12 @@ package org.opensearch.ml.common.model; +import static org.junit.Assert.assertEquals; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; +import java.util.function.Function; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -16,12 +22,6 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; -import java.util.function.Function; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; - public class MetricsCorrelationModelConfigTests { MetricsCorrelationModelConfig config; @@ -31,10 +31,11 @@ public class MetricsCorrelationModelConfigTests { @Before public void setUp() { - config = MetricsCorrelationModelConfig.builder() - .modelType("testModelType") - .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") - .build(); + config = MetricsCorrelationModelConfig + .builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .build(); function = parser -> { try { return MetricsCorrelationModelConfig.parse(parser); @@ -49,20 +50,23 @@ public void toXContent() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); config.toXContent(builder, EMPTY_PARAMS); String configContent = TestHelper.xContentBuilderToString(builder); - assertEquals("{\"model_type\":\"testModelType\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}", configContent); + assertEquals( + "{\"model_type\":\"testModelType\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}", + configContent + ); } @Test public void nullFields_ModelType() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("model type is null"); - config = MetricsCorrelationModelConfig.builder() - .build(); + config = MetricsCorrelationModelConfig.builder().build(); } @Test public void parse() throws IOException { - String content = "{\"wrong_field\":\"test_value\", \"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}"; + String content = + "{\"wrong_field\":\"test_value\", \"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}"; TestHelper.testParseFromString(config, content, function); } diff --git a/common/src/test/java/org/opensearch/ml/common/model/ModelGuardrailTests.java b/common/src/test/java/org/opensearch/ml/common/model/ModelGuardrailTests.java index ebbbd0b9a4..e59cf7075f 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/ModelGuardrailTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/ModelGuardrailTests.java @@ -1,5 +1,13 @@ package org.opensearch.ml.common.model; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doNothing; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; +import java.util.regex.Pattern; + import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -8,7 +16,6 @@ import org.opensearch.client.Client; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; @@ -17,16 +24,6 @@ import org.opensearch.ml.common.TestHelper; import org.opensearch.search.SearchModule; -import java.io.IOException; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.regex.Pattern; - -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doNothing; -import static org.mockito.Mockito.when; - public class ModelGuardrailTests { NamedXContentRegistry xContentRegistry; @Mock @@ -83,14 +80,23 @@ public void toXContent() throws IOException { modelGuardrail.toXContent(builder, ToXContent.EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); - Assert.assertEquals("{\"model_id\":\"test_model_id\",\"response_filter\":\"$.test\",\"response_validation_regex\":\"^accept$\"}", content); + Assert + .assertEquals( + "{\"model_id\":\"test_model_id\",\"response_filter\":\"$.test\",\"response_validation_regex\":\"^accept$\"}", + content + ); } @Test public void parse() throws IOException { String jsonStr = "{\"model_id\":\"test_model_id\",\"response_filter\":\"$.test\",\"response_validation_regex\":\"^accept$\"}"; - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); parser.nextToken(); ModelGuardrail modelGuardrail1 = ModelGuardrail.parse(parser); @@ -98,4 +104,4 @@ public void parse() throws IOException { Assert.assertEquals(modelGuardrail1.getResponseFilter(), modelGuardrail.getResponseFilter()); Assert.assertEquals(modelGuardrail1.getResponseAccept(), modelGuardrail.getResponseAccept()); } -} \ No newline at end of file +} diff --git a/common/src/test/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfigTests.java b/common/src/test/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfigTests.java index 5136c187b7..d3c8d8cb8a 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfigTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/QuestionAnsweringModelConfigTests.java @@ -5,6 +5,12 @@ package org.opensearch.ml.common.model; +import static org.junit.Assert.assertEquals; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; +import java.util.function.Function; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -16,12 +22,6 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; -import java.util.function.Function; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; - public class QuestionAnsweringModelConfigTests { QuestionAnsweringModelConfig config; @@ -31,12 +31,13 @@ public class QuestionAnsweringModelConfigTests { @Before public void setUp() { - config = QuestionAnsweringModelConfig.builder() - .modelType("testModelType") - .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") - .normalizeResult(false) - .frameworkType(QuestionAnsweringModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) - .build(); + config = QuestionAnsweringModelConfig + .builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .normalizeResult(false) + .frameworkType(QuestionAnsweringModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .build(); function = parser -> { try { return QuestionAnsweringModelConfig.parse(parser); @@ -51,29 +52,30 @@ public void toXContent() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); config.toXContent(builder, EMPTY_PARAMS); String configContent = TestHelper.xContentBuilderToString(builder); - assertEquals("{\"model_type\":\"testModelType\",\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}", configContent); + assertEquals( + "{\"model_type\":\"testModelType\",\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}", + configContent + ); } @Test public void nullFields_ModelType() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("model type is null"); - config = QuestionAnsweringModelConfig.builder() - .build(); + config = QuestionAnsweringModelConfig.builder().build(); } @Test public void nullFields_FrameworkType() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("framework type is null"); - config = QuestionAnsweringModelConfig.builder() - .modelType("testModelType") - .build(); + config = QuestionAnsweringModelConfig.builder().modelType("testModelType").build(); } @Test public void parse() throws IOException { - String content = "{\"wrong_field\":\"test_value\", \"model_type\":\"testModelType\",\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"normalize_result\":false,\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}"; + String content = + "{\"wrong_field\":\"test_value\", \"model_type\":\"testModelType\",\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"normalize_result\":false,\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}"; TestHelper.testParseFromString(config, content, function); } diff --git a/common/src/test/java/org/opensearch/ml/common/model/StopWordsTests.java b/common/src/test/java/org/opensearch/ml/common/model/StopWordsTests.java index 19764bb736..62e970ae56 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/StopWordsTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/StopWordsTests.java @@ -5,6 +5,10 @@ package org.opensearch.ml.common.model; +import java.io.IOException; +import java.util.Collections; +import java.util.List; + import org.junit.Assert; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -17,10 +21,6 @@ import org.opensearch.ml.common.TestHelper; import org.opensearch.search.SearchModule; -import java.io.IOException; -import java.util.Collections; -import java.util.List; - public class StopWordsTests { @Test @@ -47,12 +47,17 @@ public void toXContent() throws IOException { @Test public void parse() throws IOException { String jsonStr = "{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}"; - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); parser.nextToken(); StopWords stopWords = StopWords.parse(parser); Assert.assertEquals(stopWords.getIndex(), "test_index"); Assert.assertArrayEquals(stopWords.getSourceFields(), List.of("test_field").toArray(new String[0])); } -} \ No newline at end of file +} diff --git a/common/src/test/java/org/opensearch/ml/common/model/TextEmbeddingModelConfigTests.java b/common/src/test/java/org/opensearch/ml/common/model/TextEmbeddingModelConfigTests.java index 9bc97f7c9f..4c687dcb92 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/TextEmbeddingModelConfigTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/TextEmbeddingModelConfigTests.java @@ -5,23 +5,23 @@ package org.opensearch.ml.common.model; +import static org.junit.Assert.assertEquals; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; +import java.util.function.Function; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; -import java.util.function.Function; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; - public class TextEmbeddingModelConfigTests { TextEmbeddingModelConfig config; @@ -31,14 +31,15 @@ public class TextEmbeddingModelConfigTests { @Before public void setUp() { - config = TextEmbeddingModelConfig.builder() - .modelType("testModelType") - .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") - .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) - .embeddingDimension(100) - .passagePrefix("passage: ") - .queryPrefix("query: ") - .build(); + config = TextEmbeddingModelConfig + .builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .passagePrefix("passage: ") + .queryPrefix("query: ") + .build(); function = parser -> { try { return TextEmbeddingModelConfig.parse(parser); @@ -53,39 +54,37 @@ public void toXContent() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); config.toXContent(builder, EMPTY_PARAMS); String configContent = TestHelper.xContentBuilderToString(builder); - assertEquals("{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\",\"query_prefix\":\"query: \",\"passage_prefix\":\"passage: \"}", configContent); + assertEquals( + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\",\"query_prefix\":\"query: \",\"passage_prefix\":\"passage: \"}", + configContent + ); } @Test public void nullFields_ModelType() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("model type is null"); - config = TextEmbeddingModelConfig.builder() - .build(); + config = TextEmbeddingModelConfig.builder().build(); } - @Test public void nullFields_EmbeddingDimension() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("embedding dimension is null"); - config = TextEmbeddingModelConfig.builder().modelType("testModelType") - .build(); + config = TextEmbeddingModelConfig.builder().modelType("testModelType").build(); } @Test public void nullFields_FrameworkType() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("framework type is null"); - config = TextEmbeddingModelConfig.builder() - .modelType("testModelType") - .embeddingDimension(100) - .build(); + config = TextEmbeddingModelConfig.builder().modelType("testModelType").embeddingDimension(100).build(); } @Test public void parse() throws IOException { - String content = "{\"wrong_field\":\"test_value\", \"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\",\"query_prefix\":\"query: \",\"passage_prefix\":\"passage: \"}"; + String content = + "{\"wrong_field\":\"test_value\", \"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\",\"query_prefix\":\"query: \",\"passage_prefix\":\"passage: \"}"; TestHelper.testParseFromString(config, content, function); } diff --git a/common/src/test/java/org/opensearch/ml/common/output/MLPredictionOutputTest.java b/common/src/test/java/org/opensearch/ml/common/output/MLPredictionOutputTest.java index 9ba54314e2..9df2d489ce 100644 --- a/common/src/test/java/org/opensearch/ml/common/output/MLPredictionOutputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/output/MLPredictionOutputTest.java @@ -5,11 +5,17 @@ package org.opensearch.ml.common.output; +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.dataframe.ColumnMeta; @@ -20,37 +26,32 @@ import org.opensearch.ml.common.dataframe.IntValue; import org.opensearch.ml.common.dataframe.Row; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - -import static org.junit.Assert.assertEquals; - public class MLPredictionOutputTest { MLPredictionOutput output; + @Before public void setUp() { - ColumnMeta[] columnMetas = new ColumnMeta[]{new ColumnMeta("test", ColumnType.INTEGER)}; + ColumnMeta[] columnMetas = new ColumnMeta[] { new ColumnMeta("test", ColumnType.INTEGER) }; List rows = new ArrayList<>(); - rows.add(new Row(new ColumnValue[]{new IntValue(1)})); - rows.add(new Row(new ColumnValue[]{new IntValue(2)})); + rows.add(new Row(new ColumnValue[] { new IntValue(1) })); + rows.add(new Row(new ColumnValue[] { new IntValue(2) })); DataFrame dataFrame = new DefaultDataFrame(columnMetas, rows); - output = MLPredictionOutput.builder() - .taskId("test_task_id") - .status("test_status") - .predictionResult(dataFrame) - .build(); + output = MLPredictionOutput.builder().taskId("test_task_id").status("test_status").predictionResult(dataFrame).build(); } + @Test public void toXContent() throws IOException { XContentBuilder builder = XContentFactory.jsonBuilder(); output.toXContent(builder, ToXContent.EMPTY_PARAMS); String jsonStr = builder.toString(); - assertEquals("{\"task_id\":\"test_task_id\",\"status\":\"test_status\",\"prediction_result\":" + - "{\"column_metas\":[{\"name\":\"test\",\"column_type\":\"INTEGER\"}],\"rows\":[{\"values\":" + - "[{\"column_type\":\"INTEGER\",\"value\":1}]},{\"values\":[{\"column_type\":\"INTEGER\"," + - "\"value\":2}]}]}}", jsonStr); + assertEquals( + "{\"task_id\":\"test_task_id\",\"status\":\"test_status\",\"prediction_result\":" + + "{\"column_metas\":[{\"name\":\"test\",\"column_type\":\"INTEGER\"}],\"rows\":[{\"values\":" + + "[{\"column_type\":\"INTEGER\",\"value\":1}]},{\"values\":[{\"column_type\":\"INTEGER\"," + + "\"value\":2}]}]}}", + jsonStr + ); } @Test @@ -88,7 +89,7 @@ private void readInputStream(MLPredictionOutput output) throws IOException { assertEquals(output.predictionResult, parsedOutput.getPredictionResult()); } else { assertEquals(output.predictionResult.size(), parsedOutput.getPredictionResult().size()); - for (int i = 0 ;i outputs = new ArrayList<>(); - MCorrModelTensor mCorrModelTensor = MCorrModelTensor.builder() - .event_pattern(new float[]{1.0f, 2.0f, 3.0f}) - .event_window(new float[]{4.0f, 5.0f, 6.0f}) - .suspected_metrics(new long[]{1, 2}) - .build(); + MCorrModelTensor mCorrModelTensor = MCorrModelTensor + .builder() + .event_pattern(new float[] { 1.0f, 2.0f, 3.0f }) + .event_window(new float[] { 4.0f, 5.0f, 6.0f }) + .suspected_metrics(new long[] { 1, 2 }) + .build(); List mlModelTensors = Arrays.asList(mCorrModelTensor); MCorrModelTensors modelTensors = MCorrModelTensors.builder().mCorrModelTensors(mlModelTensors).build(); outputs.add(modelTensors); @@ -48,8 +49,8 @@ public void readInputStream_Success() throws IOException { MCorrModelTensor modelTensor = modelTensors.getMCorrModelTensors().get(0); float[] events = modelTensor.getEvent_pattern(); long[] metrics = modelTensor.getSuspected_metrics(); - assertArrayEquals(new float[]{1.0f, 2.0f, 3.0f}, events, 0.001f); - assertArrayEquals(new long[]{1, 2}, metrics); + assertArrayEquals(new float[] { 1.0f, 2.0f, 3.0f }, events, 0.001f); + assertArrayEquals(new long[] { 1, 2 }, metrics); }); } @@ -57,9 +58,7 @@ public void readInputStream_Success() throws IOException { @Test public void readInputStream_NullField() throws IOException { MetricsCorrelationOutput modelTensorOutput = MetricsCorrelationOutput.builder().build(); - readInputStream(modelTensorOutput, parsedTensorOutput -> { - assertNull(parsedTensorOutput.getModelOutput()); - }); + readInputStream(modelTensorOutput, parsedTensorOutput -> { assertNull(parsedTensorOutput.getModelOutput()); }); } private void readInputStream(MetricsCorrelationOutput input, Consumer verify) throws IOException { @@ -67,8 +66,8 @@ private void readInputStream(MetricsCorrelationOutput input, Consumer { - assertArrayEquals(resultFilter.getTargetResponse().toArray(new String[0]), parsedFilter.getTargetResponse().toArray(new String[0])); + assertArrayEquals( + resultFilter.getTargetResponse().toArray(new String[0]), + parsedFilter.getTargetResponse().toArray(new String[0]) + ); assertFalse(parsedFilter.returnBytes); assertFalse(parsedFilter.returnNumber); }); @@ -48,7 +51,6 @@ public void readInputStream_NullFields() throws IOException { }); } - private void readInputStream(ModelResultFilter input, Consumer verify) throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); input.writeTo(bytesStreamOutput); diff --git a/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorOutputTest.java b/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorOutputTest.java index d9f4c2c968..67690ed2bf 100644 --- a/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorOutputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorOutputTest.java @@ -1,10 +1,8 @@ package org.opensearch.ml.common.output.model; -import org.junit.Before; -import org.junit.Test; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.ml.common.output.MLOutputType; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; import java.io.IOException; import java.nio.ByteBuffer; @@ -13,9 +11,11 @@ import java.util.List; import java.util.function.Consumer; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNull; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.ml.common.output.MLOutputType; public class ModelTensorOutputTest { @@ -24,11 +24,16 @@ public class ModelTensorOutputTest { @Before public void setUp() throws Exception { - value = new Float[]{1.0f, 2.0f, 3.0f}; + value = new Float[] { 1.0f, 2.0f, 3.0f }; List outputs = new ArrayList<>(); - ModelTensor tensor = ModelTensor.builder().data(value) - .name("test").shape(new long[]{1, 3}).dataType(MLResultDataType.FLOAT32) - .byteBuffer(ByteBuffer.wrap(new byte[]{0,1,0,1})).build(); + ModelTensor tensor = ModelTensor + .builder() + .data(value) + .name("test") + .shape(new long[] { 1, 3 }) + .dataType(MLResultDataType.FLOAT32) + .byteBuffer(ByteBuffer.wrap(new byte[] { 0, 1, 0, 1 })) + .build(); List mlModelTensors = Arrays.asList(tensor); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(mlModelTensors).build(); outputs.add(modelTensors); diff --git a/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorTest.java b/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorTest.java index 68904cb390..d41ba82fe4 100644 --- a/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorTest.java @@ -5,24 +5,24 @@ package org.opensearch.ml.common.output.model; +import static org.junit.Assert.assertEquals; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Map; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.HashMap; -import java.util.Map; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; - public class ModelTensorTest { @Rule @@ -35,15 +35,16 @@ public void setUp() { Map dataMap = new HashMap<>(); dataMap.put("key1", "test value1"); dataMap.put("key2", "test value2"); - modelTensor = ModelTensor.builder() - .name("model_tensor") - .data(new Number[]{1, 2, 3}) - .shape(new long[]{1, 2, 3,}) - .dataType(MLResultDataType.INT32) - .byteBuffer(ByteBuffer.wrap(new byte[]{0,1,0,1})) - .result("test result") - .dataAsMap(dataMap) - .build(); + modelTensor = ModelTensor + .builder() + .name("model_tensor") + .data(new Number[] { 1, 2, 3 }) + .shape(new long[] { 1, 2, 3, }) + .dataType(MLResultDataType.INT32) + .byteBuffer(ByteBuffer.wrap(new byte[] { 0, 1, 0, 1 })) + .result("test result") + .dataAsMap(dataMap) + .build(); } @Test @@ -61,13 +62,16 @@ public void test_ModelTensorSuccess() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); modelTensor.toXContent(builder, EMPTY_PARAMS); String modelTensorContent = TestHelper.xContentBuilderToString(builder); - assertEquals("{\"name\":\"model_tensor\"," + - "\"data_type\":\"INT32\"," + - "\"shape\":[1,2,3]," + - "\"data\":[1,2,3]," + - "\"byte_buffer\":{\"array\":\"AAEAAQ==\",\"order\":\"BIG_ENDIAN\"}," + - "\"result\":\"test result\"," + - "\"dataAsMap\":{\"key1\":\"test value1\",\"key2\":\"test value2\"}}", modelTensorContent); + assertEquals( + "{\"name\":\"model_tensor\"," + + "\"data_type\":\"INT32\"," + + "\"shape\":[1,2,3]," + + "\"data\":[1,2,3]," + + "\"byte_buffer\":{\"array\":\"AAEAAQ==\",\"order\":\"BIG_ENDIAN\"}," + + "\"result\":\"test result\"," + + "\"dataAsMap\":{\"key1\":\"test value1\",\"key2\":\"test value2\"}}", + modelTensorContent + ); } @Test @@ -95,26 +99,27 @@ public void test_UnknownDataType() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("data type is null"); - ModelTensor.builder() - .name("null_data") - .data(new Number[]{1, 2, 3}) - .shape(null) - .dataType(MLResultDataType.UNKNOWN) - .byteBuffer(ByteBuffer.wrap(new byte[]{0,1,0,1})) - .build(); + ModelTensor + .builder() + .name("null_data") + .data(new Number[] { 1, 2, 3 }) + .shape(null) + .dataType(MLResultDataType.UNKNOWN) + .byteBuffer(ByteBuffer.wrap(new byte[] { 0, 1, 0, 1 })) + .build(); } @Test public void test_NullDataType() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("data type is null"); - ModelTensor.builder() - .name("null_data") - .data(new Number[]{1, 2, 3}) - .shape(null) - .dataType(null) - .byteBuffer(ByteBuffer.wrap(new byte[]{0,1,0,1})) - .build(); + ModelTensor + .builder() + .name("null_data") + .data(new Number[] { 1, 2, 3 }) + .shape(null) + .dataType(null) + .byteBuffer(ByteBuffer.wrap(new byte[] { 0, 1, 0, 1 })) + .build(); } } - diff --git a/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorsTest.java b/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorsTest.java index f8e3fee984..a4f7dc51b1 100644 --- a/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorsTest.java @@ -5,23 +5,23 @@ package org.opensearch.ml.common.output.model; +import static org.junit.Assert.assertEquals; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.Arrays; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; - public class ModelTensorsTest { @Rule @@ -34,18 +34,20 @@ public void setUp() { String sentence = "test sentence"; String column = "model_tensor"; Integer position = 1; - modelResultFilter = ModelResultFilter.builder() - .targetResponse(Arrays.asList(column)) - .targetResponsePositions(Arrays.asList(position)) - .build(); - - ModelTensor modelTensor = ModelTensor.builder() - .name("model_tensor") - .data(new Number[]{1, 2, 3}) - .shape(new long[]{1, 2, 3,}) - .dataType(MLResultDataType.INT32) - .byteBuffer(ByteBuffer.wrap(new byte[]{0,1,0,1})) - .build(); + modelResultFilter = ModelResultFilter + .builder() + .targetResponse(Arrays.asList(column)) + .targetResponsePositions(Arrays.asList(position)) + .build(); + + ModelTensor modelTensor = ModelTensor + .builder() + .name("model_tensor") + .data(new Number[] { 1, 2, 3 }) + .shape(new long[] { 1, 2, 3, }) + .dataType(MLResultDataType.INT32) + .byteBuffer(ByteBuffer.wrap(new byte[] { 0, 1, 0, 1 })) + .build(); modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); } @@ -55,7 +57,10 @@ public void test_ModelTensortoXContent() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); modelTensors.toXContent(builder, EMPTY_PARAMS); String modelTensorContent = TestHelper.xContentBuilderToString(builder); - assertEquals("{\"output\":[{\"name\":\"model_tensor\",\"data_type\":\"INT32\",\"shape\":[1,2,3],\"data\":[1,2,3],\"byte_buffer\":{\"array\":\"AAEAAQ==\",\"order\":\"BIG_ENDIAN\"}}]}", modelTensorContent); + assertEquals( + "{\"output\":[{\"name\":\"model_tensor\",\"data_type\":\"INT32\",\"shape\":[1,2,3],\"data\":[1,2,3],\"byte_buffer\":{\"array\":\"AAEAAQ==\",\"order\":\"BIG_ENDIAN\"}}]}", + modelTensorContent + ); } @Test @@ -80,14 +85,15 @@ public void test_StreamInAndOut_NullValue() throws IOException { @Test public void test_Filter() { - ModelTensor modelTensorFiltered = ModelTensor.builder() - .name("model_tensor") - .shape(new long[]{1, 2, 3,}) - .dataType(MLResultDataType.INT32) - .build(); + ModelTensor modelTensorFiltered = ModelTensor + .builder() + .name("model_tensor") + .shape(new long[] { 1, 2, 3, }) + .dataType(MLResultDataType.INT32) + .build(); modelTensors.filter(modelResultFilter); assertEquals(modelTensors.getMlModelTensors().size(), 1); - //assertEquals(modelTensors.getMlModelTensors().get(0), modelTensorFiltered); + // assertEquals(modelTensors.getMlModelTensors().get(0), modelTensorFiltered); } @Test @@ -112,7 +118,6 @@ public void test_ToAndFromBytes() throws IOException { assertEquals(bytes.length, bytesStreamOutput.bytes().toBytesRef().bytes.length); ModelTensors tensors = ModelTensors.fromBytes(bytes); - //assertEquals(modelTensors.getMlModelTensors(), tensors.getMlModelTensors()); + // assertEquals(modelTensors.getMlModelTensors(), tensors.getMlModelTensors()); } } - diff --git a/common/src/test/java/org/opensearch/ml/common/output/sample/SampleAlgoOutputTest.java b/common/src/test/java/org/opensearch/ml/common/output/sample/SampleAlgoOutputTest.java index 1fd5a51176..f831cf3361 100644 --- a/common/src/test/java/org/opensearch/ml/common/output/sample/SampleAlgoOutputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/output/sample/SampleAlgoOutputTest.java @@ -5,26 +5,25 @@ package org.opensearch.ml.common.output.sample; +import static org.junit.Assert.assertEquals; + +import java.io.IOException; + import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.output.MLOutputType; -import java.io.IOException; - -import static org.junit.Assert.assertEquals; - public class SampleAlgoOutputTest { SampleAlgoOutput output; + @Before public void setUp() { - output = SampleAlgoOutput.builder() - .sampleResult(1.0) - .build(); + output = SampleAlgoOutput.builder().sampleResult(1.0).build(); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteActionTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteActionTest.java index 7cc9e66793..1a7036e67b 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteActionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteActionTest.java @@ -4,11 +4,11 @@ */ package org.opensearch.ml.common.transport.agent; -import org.junit.Test; - import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import org.junit.Test; + public class MLAgentDeleteActionTest { @Test public void testMLAgentDeleteActionInstance() { diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteRequestTest.java index 070b5a6e33..19baef8494 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentDeleteRequestTest.java @@ -4,18 +4,18 @@ */ package org.opensearch.ml.common.transport.agent; +import static org.junit.Assert.assertEquals; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; +import java.io.UncheckedIOException; + import org.junit.Test; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.action.ValidateActions.addValidationError; - public class MLAgentDeleteRequestTest { String agentId; @@ -23,7 +23,7 @@ public class MLAgentDeleteRequestTest { public void constructor_AgentId() { agentId = "test-abc"; MLAgentDeleteRequest mLAgentDeleteRequest = new MLAgentDeleteRequest(agentId); - assertEquals(mLAgentDeleteRequest.agentId,agentId); + assertEquals(mLAgentDeleteRequest.agentId, agentId); } @Test @@ -52,10 +52,10 @@ public void validate_Success() { public void validate_Failure() { agentId = null; MLAgentDeleteRequest mLAgentDeleteRequest = new MLAgentDeleteRequest(agentId); - assertEquals(null,mLAgentDeleteRequest.agentId); + assertEquals(null, mLAgentDeleteRequest.agentId); ActionRequestValidationException exception = addValidationError("ML agent id can't be null", null); - mLAgentDeleteRequest.validate().equals(exception) ; + mLAgentDeleteRequest.validate().equals(exception); } @Test @@ -65,6 +65,7 @@ public void fromActionRequest_Success() throws IOException { assertEquals(mLAgentDeleteRequest.fromActionRequest(mLAgentDeleteRequest), mLAgentDeleteRequest); } + @Test public void fromActionRequest_Success_fromActionRequest() throws IOException { agentId = "test-opq"; @@ -75,6 +76,7 @@ public void fromActionRequest_Success_fromActionRequest() throws IOException { public ActionRequestValidationException validate() { return null; } + @Override public void writeTo(StreamOutput out) throws IOException { mLAgentDeleteRequest.writeTo(out); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetActionTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetActionTest.java index cba838fb02..eaf64b05ca 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetActionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetActionTest.java @@ -5,10 +5,11 @@ package org.opensearch.ml.common.transport.agent; -import org.junit.Test; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import org.junit.Test; + public class MLAgentGetActionTest { @Test @@ -17,5 +18,4 @@ public void testMLAgentGetActionInstance() { assertEquals("cluster:admin/opensearch/ml/agents/get", MLAgentGetAction.NAME); } - } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetRequestTest.java index c32fdebb5b..e8d545d980 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetRequestTest.java @@ -4,18 +4,18 @@ */ package org.opensearch.ml.common.transport.agent; +import static org.junit.Assert.assertEquals; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; +import java.io.UncheckedIOException; + import org.junit.Test; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.action.ValidateActions.addValidationError; - public class MLAgentGetRequestTest { String agentId; @@ -23,8 +23,8 @@ public class MLAgentGetRequestTest { public void constructor_AgentId() { agentId = "test-abc"; MLAgentGetRequest mLAgentGetRequest = new MLAgentGetRequest(agentId, true); - assertEquals(mLAgentGetRequest.getAgentId(),agentId); - assertEquals(mLAgentGetRequest.isUserInitiatedGetRequest(),true); + assertEquals(mLAgentGetRequest.getAgentId(), agentId); + assertEquals(mLAgentGetRequest.isUserInitiatedGetRequest(), true); } @Test @@ -54,13 +54,14 @@ public void validate_Success() { public void validate_Failure() { agentId = null; MLAgentGetRequest mLAgentGetRequest = new MLAgentGetRequest(agentId, true); - assertEquals(null,mLAgentGetRequest.agentId); + assertEquals(null, mLAgentGetRequest.agentId); ActionRequestValidationException exception = addValidationError("ML agent id can't be null", null); - mLAgentGetRequest.validate().equals(exception) ; + mLAgentGetRequest.validate().equals(exception); } + @Test - public void fromActionRequest_Success() throws IOException { + public void fromActionRequest_Success() throws IOException { agentId = "test-lmn"; MLAgentGetRequest mLAgentGetRequest = new MLAgentGetRequest(agentId, true); assertEquals(mLAgentGetRequest.fromActionRequest(mLAgentGetRequest), mLAgentGetRequest); @@ -76,6 +77,7 @@ public void fromActionRequest_Success_fromActionRequest() throws IOException { public ActionRequestValidationException validate() { return null; } + @Override public void writeTo(StreamOutput out) throws IOException { mLAgentGetRequest.writeTo(out); @@ -103,5 +105,3 @@ public void writeTo(StreamOutput out) throws IOException { mLAgentGetRequest.fromActionRequest(actionRequest); } } - - diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java index 34ef3f332b..cad3794134 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java @@ -4,6 +4,18 @@ */ package org.opensearch.ml.common.transport.agent; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.*; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -18,28 +30,19 @@ import org.opensearch.ml.common.agent.MLMemorySpec; import org.opensearch.ml.common.agent.MLToolSpec; -import java.io.*; -import java.time.Instant; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; - public class MLAgentGetResponseTest { MLAgent mlAgent; @Before public void setUp() { - mlAgent = MLAgent.builder() - .name("test_agent") - .appType("test_app") - .type(MLAgentType.FLOW.name()) - .tools(Arrays.asList(MLToolSpec.builder().type("CatIndexTool").build())) - .build(); + mlAgent = MLAgent + .builder() + .name("test_agent") + .appType("test_app") + .type(MLAgentType.FLOW.name()) + .tools(Arrays.asList(MLToolSpec.builder().type("CatIndexTool").build())) + .build(); } @Test @@ -60,20 +63,29 @@ public void writeTo(StreamOutput out) throws IOException { @Test public void mLAgentGetResponse_Builder() throws IOException { - MLAgentGetResponse mlAgentGetResponse = MLAgentGetResponse.builder() - .mlAgent(mlAgent) - .build(); + MLAgentGetResponse mlAgentGetResponse = MLAgentGetResponse.builder().mlAgent(mlAgent).build(); assertEquals(mlAgentGetResponse.mlAgent, mlAgent); } + @Test public void writeTo() throws IOException { - //create ml agent using MLAgent and mlAgentGetResponse - mlAgent = new MLAgent("test", MLAgentType.CONVERSATIONAL.name(), "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test", false); - MLAgentGetResponse mlAgentGetResponse = MLAgentGetResponse.builder() - .mlAgent(mlAgent) - .build(); - //use write out for both agents + // create ml agent using MLAgent and mlAgentGetResponse + mlAgent = new MLAgent( + "test", + MLAgentType.CONVERSATIONAL.name(), + "test", + new LLMSpec("test_model", Map.of("test_key", "test_value")), + List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), + Map.of("test", "test"), + new MLMemorySpec("test", "123", 0), + Instant.EPOCH, + Instant.EPOCH, + "test", + false + ); + MLAgentGetResponse mlAgentGetResponse = MLAgentGetResponse.builder().mlAgent(mlAgent).build(); + // use write out for both agents BytesStreamOutput output = new BytesStreamOutput(); mlAgent.writeTo(output); mlAgentGetResponse.writeTo(output); @@ -90,9 +102,7 @@ public void writeTo() throws IOException { @Test public void toXContent() throws IOException { mlAgent = new MLAgent("mock", MLAgentType.FLOW.name(), "test", null, null, null, null, null, null, "test", false); - MLAgentGetResponse mlAgentGetResponse = MLAgentGetResponse.builder() - .mlAgent(mlAgent) - .build(); + MLAgentGetResponse mlAgentGetResponse = MLAgentGetResponse.builder().mlAgent(mlAgent).build(); XContentBuilder builder = XContentFactory.jsonBuilder(); ToXContent.Params params = EMPTY_PARAMS; XContentBuilder getResponseXContentBuilder = mlAgentGetResponse.toXContent(builder, params); @@ -101,17 +111,14 @@ public void toXContent() throws IOException { @Test public void fromActionResponse_Success() throws IOException { - MLAgentGetResponse mlAgentGetResponse = MLAgentGetResponse.builder() - .mlAgent(mlAgent) - .build(); + MLAgentGetResponse mlAgentGetResponse = MLAgentGetResponse.builder().mlAgent(mlAgent).build(); assertEquals(mlAgentGetResponse.fromActionResponse(mlAgentGetResponse), mlAgentGetResponse); - } + } + @Test public void fromActionResponse_Success_fromActionResponse() throws IOException { - MLAgentGetResponse mlAgentGetResponse = MLAgentGetResponse.builder() - .mlAgent(mlAgent) - .build(); + MLAgentGetResponse mlAgentGetResponse = MLAgentGetResponse.builder().mlAgent(mlAgent).build(); ActionResponse actionResponse = new ActionResponse() { @Override @@ -125,9 +132,7 @@ public void writeTo(StreamOutput out) throws IOException { @Test(expected = UncheckedIOException.class) public void fromActionResponse_IOException() { - MLAgentGetResponse mlAgentGetResponse = MLAgentGetResponse.builder() - .mlAgent(mlAgent) - .build(); + MLAgentGetResponse mlAgentGetResponse = MLAgentGetResponse.builder().mlAgent(mlAgent).build(); ActionResponse actionResponse = new ActionResponse() { @Override public void writeTo(StreamOutput out) throws IOException { @@ -136,4 +141,4 @@ public void writeTo(StreamOutput out) throws IOException { }; mlAgentGetResponse.fromActionResponse(actionResponse); } - } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentActionTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentActionTest.java index aa790d0ccd..8c73c8be09 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentActionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentActionTest.java @@ -5,11 +5,11 @@ package org.opensearch.ml.common.transport.agent; -import org.junit.Test; - import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import org.junit.Test; + public class MLRegisterAgentActionTest { @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequestTest.java index ee446db82f..7b1b5a8694 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequestTest.java @@ -5,6 +5,12 @@ package org.opensearch.ml.common.transport.agent; +import static org.junit.Assert.*; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Arrays; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -16,12 +22,6 @@ import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLToolSpec; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.util.Arrays; - -import static org.junit.Assert.*; - public class MLRegisterAgentRequestTest { MLAgent mlAgent; @@ -30,12 +30,13 @@ public class MLRegisterAgentRequestTest { @Before public void setUp() { - mlAgent = MLAgent.builder() - .name("test_agent") - .appType("test_app") - .type("flow") - .tools(Arrays.asList(MLToolSpec.builder().type("CatIndexTool").build())) - .build(); + mlAgent = MLAgent + .builder() + .name("test_agent") + .appType("test_app") + .type("flow") + .tools(Arrays.asList(MLToolSpec.builder().type("CatIndexTool").build())) + .build(); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponseTest.java index 9997eb0ad6..2961b9ae82 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponseTest.java @@ -5,6 +5,11 @@ package org.opensearch.ml.common.transport.agent; +import static org.junit.Assert.*; + +import java.io.IOException; +import java.io.UncheckedIOException; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -17,11 +22,6 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.junit.Assert.*; - public class MLRegisterAgentResponseTest { String agentId; @Rule diff --git a/common/src/test/java/org/opensearch/ml/common/transport/config/MLConfigGetActionTest.java b/common/src/test/java/org/opensearch/ml/common/transport/config/MLConfigGetActionTest.java index 935b4f0db8..98431e6cb2 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/config/MLConfigGetActionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/config/MLConfigGetActionTest.java @@ -5,10 +5,11 @@ package org.opensearch.ml.common.transport.config; -import org.junit.Test; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import org.junit.Test; + public class MLConfigGetActionTest { @Test @@ -17,5 +18,4 @@ public void testMLAgentGetActionInstance() { assertEquals("cluster:admin/opensearch/ml/config/get", MLConfigGetAction.NAME); } - } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/config/MLConfigGetRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/config/MLConfigGetRequestTest.java index 7c86587816..ea16005d14 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/config/MLConfigGetRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/config/MLConfigGetRequestTest.java @@ -4,18 +4,18 @@ */ package org.opensearch.ml.common.transport.config; +import static org.junit.Assert.assertEquals; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; +import java.io.UncheckedIOException; + import org.junit.Test; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.action.ValidateActions.addValidationError; - public class MLConfigGetRequestTest { String configId; @@ -23,7 +23,7 @@ public class MLConfigGetRequestTest { public void constructor_configId() { configId = "test-abc"; MLConfigGetRequest mlConfigGetRequest = new MLConfigGetRequest(configId); - assertEquals(mlConfigGetRequest.getConfigId(),configId); + assertEquals(mlConfigGetRequest.getConfigId(), configId); } @Test @@ -52,13 +52,14 @@ public void validate_Success() { public void validate_Failure() { configId = null; MLConfigGetRequest mlConfigGetRequest = new MLConfigGetRequest(configId); - assertEquals(null,mlConfigGetRequest.configId); + assertEquals(null, mlConfigGetRequest.configId); ActionRequestValidationException exception = addValidationError("ML config id can't be null", null); - mlConfigGetRequest.validate().equals(exception) ; + mlConfigGetRequest.validate().equals(exception); } + @Test - public void fromActionRequest_Success() throws IOException { + public void fromActionRequest_Success() throws IOException { configId = "test-lmn"; MLConfigGetRequest mlConfigGetRequest = new MLConfigGetRequest(configId); assertEquals(mlConfigGetRequest.fromActionRequest(mlConfigGetRequest), mlConfigGetRequest); @@ -74,6 +75,7 @@ public void fromActionRequest_Success_fromActionRequest() throws IOException { public ActionRequestValidationException validate() { return null; } + @Override public void writeTo(StreamOutput out) throws IOException { mlConfigGetRequest.writeTo(out); @@ -101,5 +103,3 @@ public void writeTo(StreamOutput out) throws IOException { mlConfigGetRequest.fromActionRequest(actionRequest); } } - - diff --git a/common/src/test/java/org/opensearch/ml/common/transport/config/MLConfigGetResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/config/MLConfigGetResponseTest.java index ea370f979a..b187b4a8c8 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/config/MLConfigGetResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/config/MLConfigGetResponseTest.java @@ -5,6 +5,14 @@ package org.opensearch.ml.common.transport.config; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.time.Instant; + import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -16,18 +24,6 @@ import org.opensearch.ml.common.Configuration; import org.opensearch.ml.common.MLConfig; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.time.Instant; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.Map; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; - public class MLConfigGetResponseTest { MLConfig mlConfig; @@ -35,10 +31,7 @@ public class MLConfigGetResponseTest { @Before public void setUp() { Configuration configuration = Configuration.builder().agentId("agent_id").build(); - mlConfig = MLConfig.builder() - .type("olly_agent") - .configuration(configuration) - .build(); + mlConfig = MLConfig.builder().type("olly_agent").configuration(configuration).build(); } @Test @@ -59,20 +52,17 @@ public void writeTo(StreamOutput out) throws IOException { @Test public void MLConfigGetResponse_Builder() throws IOException { - MLConfigGetResponse mlConfigGetResponse = MLConfigGetResponse.builder() - .mlConfig(mlConfig) - .build(); + MLConfigGetResponse mlConfigGetResponse = MLConfigGetResponse.builder().mlConfig(mlConfig).build(); assertEquals(mlConfigGetResponse.mlConfig, mlConfig); } + @Test public void writeTo() throws IOException { - //create ml agent using mlConfig and mlConfigGetResponse - mlConfig = new MLConfig("olly_agent",new Configuration("agent_id"), Instant.EPOCH, Instant.EPOCH); - MLConfigGetResponse mlConfigGetResponse = MLConfigGetResponse.builder() - .mlConfig(mlConfig) - .build(); - //use write out for both agents + // create ml agent using mlConfig and mlConfigGetResponse + mlConfig = new MLConfig("olly_agent", new Configuration("agent_id"), Instant.EPOCH, Instant.EPOCH); + MLConfigGetResponse mlConfigGetResponse = MLConfigGetResponse.builder().mlConfig(mlConfig).build(); + // use write out for both agents BytesStreamOutput output = new BytesStreamOutput(); mlConfig.writeTo(output); mlConfigGetResponse.writeTo(output); @@ -87,9 +77,7 @@ public void writeTo() throws IOException { @Test public void toXContent() throws IOException { mlConfig = new MLConfig(null, null, null, null); - MLConfigGetResponse mlConfigGetResponse = MLConfigGetResponse.builder() - .mlConfig(mlConfig) - .build(); + MLConfigGetResponse mlConfigGetResponse = MLConfigGetResponse.builder().mlConfig(mlConfig).build(); XContentBuilder builder = XContentFactory.jsonBuilder(); ToXContent.Params params = EMPTY_PARAMS; XContentBuilder getResponseXContentBuilder = mlConfigGetResponse.toXContent(builder, params); @@ -98,17 +86,14 @@ public void toXContent() throws IOException { @Test public void fromActionResponse_Success() throws IOException { - MLConfigGetResponse mlConfigGetResponse = MLConfigGetResponse.builder() - .mlConfig(mlConfig) - .build(); + MLConfigGetResponse mlConfigGetResponse = MLConfigGetResponse.builder().mlConfig(mlConfig).build(); assertEquals(mlConfigGetResponse.fromActionResponse(mlConfigGetResponse), mlConfigGetResponse); - } + } + @Test public void fromActionResponse_Success_fromActionResponse() throws IOException { - MLConfigGetResponse mlConfigGetResponse = MLConfigGetResponse.builder() - .mlConfig(mlConfig) - .build(); + MLConfigGetResponse mlConfigGetResponse = MLConfigGetResponse.builder().mlConfig(mlConfig).build(); ActionResponse actionResponse = new ActionResponse() { @Override @@ -122,9 +107,7 @@ public void writeTo(StreamOutput out) throws IOException { @Test(expected = UncheckedIOException.class) public void fromActionResponse_IOException() { - MLConfigGetResponse mlConfigGetResponse = MLConfigGetResponse.builder() - .mlConfig(mlConfig) - .build(); + MLConfigGetResponse mlConfigGetResponse = MLConfigGetResponse.builder().mlConfig(mlConfig).build(); ActionResponse actionResponse = new ActionResponse() { @Override public void writeTo(StreamOutput out) throws IOException { @@ -133,4 +116,4 @@ public void writeTo(StreamOutput out) throws IOException { }; mlConfigGetResponse.fromActionResponse(actionResponse); } - } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteRequestTests.java index 28e755a32b..b6ad6b054f 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteRequestTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteRequestTests.java @@ -12,6 +12,7 @@ import java.io.IOException; import java.io.UncheckedIOException; + import org.junit.Before; import org.junit.Test; import org.opensearch.action.ActionRequest; @@ -29,8 +30,7 @@ public void setUp() { @Test public void writeToSuccess() throws IOException { - MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder() - .connectorId(connectorId).build(); + MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder().connectorId(connectorId).build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); mlConnectorDeleteRequest.writeTo(bytesStreamOutput); MLConnectorDeleteRequest parsedRequest = new MLConnectorDeleteRequest(bytesStreamOutput.bytes().streamInput()); @@ -46,16 +46,14 @@ public void validWithNullConnectorIdException() { @Test public void validateSuccess() { - MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder() - .connectorId(connectorId).build(); + MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder().connectorId(connectorId).build(); ActionRequestValidationException actionRequestValidationException = mlConnectorDeleteRequest.validate(); assertNull(actionRequestValidationException); } @Test public void fromActionRequestSuccess() { - MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder() - .connectorId(connectorId).build(); + MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder().connectorId(connectorId).build(); ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { @@ -90,9 +88,9 @@ public void writeTo(StreamOutput out) throws IOException { @Test public void fromActionRequestWithConnectorDeleteRequestSuccess() { - MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder() - .connectorId(connectorId).build(); - MLConnectorDeleteRequest mlConnectorDeleteRequestFromActionRequest = MLConnectorDeleteRequest.fromActionRequest(mlConnectorDeleteRequest); + MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder().connectorId(connectorId).build(); + MLConnectorDeleteRequest mlConnectorDeleteRequestFromActionRequest = MLConnectorDeleteRequest + .fromActionRequest(mlConnectorDeleteRequest); assertSame(mlConnectorDeleteRequest, mlConnectorDeleteRequestFromActionRequest); assertEquals(mlConnectorDeleteRequest.getConnectorId(), mlConnectorDeleteRequestFromActionRequest.getConnectorId()); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetRequestTests.java index 0b663f4cc9..83ebe778d7 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetRequestTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetRequestTests.java @@ -3,9 +3,13 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.ml.common.transport.connector; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + import java.io.IOException; import java.io.UncheckedIOException; @@ -16,11 +20,6 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamOutput; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; - public class MLConnectorGetRequestTests { private String connectorId; @@ -95,4 +94,3 @@ public void validateSuccess() { assertNull(actionRequestValidationException); } } - diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponseTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponseTests.java index ce492b0862..70cbd64bc2 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponseTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponseTests.java @@ -13,11 +13,12 @@ import java.io.IOException; import java.io.UncheckedIOException; + import org.junit.Before; import org.junit.Test; -import org.opensearch.core.action.ActionResponse; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; @@ -57,18 +58,21 @@ public void toXContentTest() throws IOException { mlConnectorGetResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); assertNotNull(builder); String jsonStr = builder.toString(); - assertEquals("{\"name\":\"test_connector_name\",\"version\":\"1\"," + - "\"description\":\"this is a test connector\",\"protocol\":\"http\"," + - "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," + - "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + - "\"headers\":{\"api_key\":\"${credential.key}\"}," + - "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + - "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + - "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + - "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\"," + - "\"client_config\":{\"max_connection\":30," + - "\"connection_timeout\":30000,\"read_timeout\":30000," + - "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}", jsonStr); + assertEquals( + "{\"name\":\"test_connector_name\",\"version\":\"1\"," + + "\"description\":\"this is a test connector\",\"protocol\":\"http\"," + + "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," + + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + + "\"headers\":{\"api_key\":\"${credential.key}\"}," + + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + + "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\"," + + "\"client_config\":{\"max_connection\":30," + + "\"connection_timeout\":30000,\"read_timeout\":30000," + + "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}", + jsonStr + ); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java index 8eb885ade2..28e597e186 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java @@ -5,6 +5,19 @@ package org.opensearch.ml.common.transport.connector; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -13,9 +26,9 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; @@ -29,40 +42,27 @@ import org.opensearch.ml.common.connector.RetryBackoffPolicy; import org.opensearch.search.SearchModule; -import java.io.IOException; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.function.Consumer; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertTrue; - public class MLCreateConnectorInputTests { private MLCreateConnectorInput mlCreateConnectorInput; private MLCreateConnectorInput mlCreateDryRunConnectorInput; @Rule public final ExpectedException exceptionRule = ExpectedException.none(); - private final String expectedInputStr = "{\"name\":\"test_connector_name\"," + - "\"description\":\"this is a test connector\",\"version\":\"1\",\"protocol\":\"http\"," + - "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," + - "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + - "\"headers\":{\"api_key\":\"${credential.key}\"}," + - "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + - "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + - "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + - "\"backend_roles\":[\"role1\",\"role2\"],\"add_all_backend_roles\":false," + - "\"access_mode\":\"PUBLIC\",\"client_config\":{\"max_connection\":20," + - "\"connection_timeout\":10000,\"read_timeout\":10000," + - "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}"; + private final String expectedInputStr = "{\"name\":\"test_connector_name\"," + + "\"description\":\"this is a test connector\",\"version\":\"1\",\"protocol\":\"http\"," + + "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," + + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + + "\"headers\":{\"api_key\":\"${credential.key}\"}," + + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + + "\"backend_roles\":[\"role1\",\"role2\"],\"add_all_backend_roles\":false," + + "\"access_mode\":\"PUBLIC\",\"client_config\":{\"max_connection\":20," + + "\"connection_timeout\":10000,\"read_timeout\":10000," + + "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}"; @Before - public void setUp(){ + public void setUp() { ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; String method = "POST"; String url = "https://test.com"; @@ -71,80 +71,90 @@ public void setUp(){ String mlCreateConnectorRequestBody = "{\"input\": \"${parameters.input}\"}"; String preProcessFunction = MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT; String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING; - ConnectorAction action = new ConnectorAction(actionType, method, url, headers, mlCreateConnectorRequestBody, preProcessFunction, postProcessFunction); + ConnectorAction action = new ConnectorAction( + actionType, + method, + url, + headers, + mlCreateConnectorRequestBody, + preProcessFunction, + postProcessFunction + ); ConnectorClientConfig connectorClientConfig = new ConnectorClientConfig(20, 10000, 10000, 10, 10, -1, RetryBackoffPolicy.CONSTANT); - mlCreateConnectorInput = MLCreateConnectorInput.builder() - .name("test_connector_name") - .description("this is a test connector") - .version("1") - .protocol("http") - .parameters(Map.of("input", "test input value")) - .credential(Map.of("key", "test_key_value")) - .actions(List.of(action)) - .access(AccessMode.PUBLIC) - .backendRoles(Arrays.asList("role1", "role2")) - .addAllBackendRoles(false) - .connectorClientConfig(connectorClientConfig) - .build(); - - mlCreateDryRunConnectorInput = MLCreateConnectorInput.builder() - .dryRun(true) - .build(); + mlCreateConnectorInput = MLCreateConnectorInput + .builder() + .name("test_connector_name") + .description("this is a test connector") + .version("1") + .protocol("http") + .parameters(Map.of("input", "test input value")) + .credential(Map.of("key", "test_key_value")) + .actions(List.of(action)) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1", "role2")) + .addAllBackendRoles(false) + .connectorClientConfig(connectorClientConfig) + .build(); + + mlCreateDryRunConnectorInput = MLCreateConnectorInput.builder().dryRun(true).build(); } @Test public void constructorMLCreateConnectorInput_NullName() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("Connector name is null"); - MLCreateConnectorInput.builder() - .name(null) - .description("this is a test connector") - .version("1") - .protocol("http") - .parameters(Map.of("input", "test input value")) - .credential(Map.of("key", "test_key_value")) - .actions(List.of()) - .access(AccessMode.PUBLIC) - .backendRoles(Arrays.asList("role1", "role2")) - .addAllBackendRoles(false) - .build(); + MLCreateConnectorInput + .builder() + .name(null) + .description("this is a test connector") + .version("1") + .protocol("http") + .parameters(Map.of("input", "test input value")) + .credential(Map.of("key", "test_key_value")) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1", "role2")) + .addAllBackendRoles(false) + .build(); } @Test public void constructorMLCreateConnectorInput_NullVersion() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("Connector version is null"); - MLCreateConnectorInput.builder() - .name("test_connector_name") - .description("this is a test connector") - .version(null) - .protocol("http") - .parameters(Map.of("input", "test input value")) - .credential(Map.of("key", "test_key_value")) - .actions(List.of()) - .access(AccessMode.PUBLIC) - .backendRoles(Arrays.asList("role1", "role2")) - .addAllBackendRoles(false) - .build(); + MLCreateConnectorInput + .builder() + .name("test_connector_name") + .description("this is a test connector") + .version(null) + .protocol("http") + .parameters(Map.of("input", "test input value")) + .credential(Map.of("key", "test_key_value")) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1", "role2")) + .addAllBackendRoles(false) + .build(); } @Test public void constructorMLCreateConnectorInput_NullProtocol() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("Connector protocol is null"); - MLCreateConnectorInput.builder() - .name("test_connector_name") - .description("this is a test connector") - .version("1") - .protocol(null) - .parameters(Map.of("input", "test input value")) - .credential(Map.of("key", "test_key_value")) - .actions(List.of()) - .access(AccessMode.PUBLIC) - .backendRoles(Arrays.asList("role1", "role2")) - .addAllBackendRoles(false) - .build(); + MLCreateConnectorInput + .builder() + .name("test_connector_name") + .description("this is a test connector") + .version("1") + .protocol(null) + .parameters(Map.of("input", "test input value")) + .credential(Map.of("key", "test_key_value")) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1", "role2")) + .addAllBackendRoles(false) + .build(); } @Test @@ -177,16 +187,16 @@ public void testParse() throws Exception { @Test public void testParse_ArrayParameter() throws Exception { - String expectedInputStr = "{\"name\":\"test_connector_name\"," + - "\"description\":\"this is a test connector\",\"version\":\"1\",\"protocol\":\"http\"," + - "\"parameters\":{\"input\":[\"test input value\"]},\"credential\":{\"key\":\"test_key_value\"}," + - "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + - "\"headers\":{\"api_key\":\"${credential.key}\"}," + - "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + - "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + - "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + - "\"backend_roles\":[\"role1\",\"role2\"],\"add_all_backend_roles\":false," + - "\"access_mode\":\"PUBLIC\"}"; + String expectedInputStr = "{\"name\":\"test_connector_name\"," + + "\"description\":\"this is a test connector\",\"version\":\"1\",\"protocol\":\"http\"," + + "\"parameters\":{\"input\":[\"test input value\"]},\"credential\":{\"key\":\"test_key_value\"}," + + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + + "\"headers\":{\"api_key\":\"${credential.key}\"}," + + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + + "\"backend_roles\":[\"role1\",\"role2\"],\"add_all_backend_roles\":false," + + "\"access_mode\":\"PUBLIC\"}"; testParseFromJsonString(expectedInputStr, parsedInput -> { assertEquals("test_connector_name", parsedInput.getName()); assertEquals(1, parsedInput.getParameters().size()); @@ -210,11 +220,12 @@ public void readInputStream_Success() throws IOException { @Test public void readInputStream_SuccessWithNullFields() throws IOException { - MLCreateConnectorInput mlCreateMinimalConnectorInput = MLCreateConnectorInput.builder() - .name("test_connector_name") - .version("1") - .protocol("http") - .build(); + MLCreateConnectorInput mlCreateMinimalConnectorInput = MLCreateConnectorInput + .builder() + .name("test_connector_name") + .version("1") + .protocol("http") + .build(); readInputStream(mlCreateMinimalConnectorInput, parsedInput -> { assertEquals(mlCreateMinimalConnectorInput.getName(), parsedInput.getName()); assertNull(parsedInput.getActions()); @@ -225,18 +236,19 @@ public void readInputStream_SuccessWithNullFields() throws IOException { @Test public void testBuilder_NullActions_ShouldNotThrowException() { // Actions can be null for a connector without any specific actions defined. - MLCreateConnectorInput input = MLCreateConnectorInput.builder() - .name("test_connector_name") - .description("this is a test connector") - .version("1") - .protocol("http") - .parameters(Map.of("input", "test input value")) - .credential(Map.of("key", "test_key_value")) - .actions(null) // Setting actions to null - .access(AccessMode.PUBLIC) - .backendRoles(Arrays.asList("role1", "role2")) - .addAllBackendRoles(false) - .build(); + MLCreateConnectorInput input = MLCreateConnectorInput + .builder() + .name("test_connector_name") + .description("this is a test connector") + .version("1") + .protocol("http") + .parameters(Map.of("input", "test input value")) + .credential(Map.of("key", "test_key_value")) + .actions(null) // Setting actions to null + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1", "role2")) + .addAllBackendRoles(false) + .build(); assertNull(input.getActions()); } @@ -274,9 +286,10 @@ public void testWriteToVersionCompatibility() throws IOException { @Test public void testDryRunConnectorInput_IgnoreValidation() { - MLCreateConnectorInput input = MLCreateConnectorInput.builder() - .dryRun(true) // Set dryRun to true - .build(); + MLCreateConnectorInput input = MLCreateConnectorInput + .builder() + .dryRun(true) // Set dryRun to true + .build(); // No exception for missing mandatory fields when dryRun is true assertTrue(input.isDryRun()); @@ -285,17 +298,25 @@ public void testDryRunConnectorInput_IgnoreValidation() { // Helper method to create XContentParser from a JSON string private XContentParser createParser(String jsonString) throws IOException { - XContentParser parser = XContentType.JSON.xContent().createParser( + XContentParser parser = XContentType.JSON + .xContent() + .createParser( new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), LoggingDeprecationHandler.INSTANCE, - jsonString); + jsonString + ); parser.nextToken(); // Move to the first token return parser; } private void testParseFromJsonString(String expectedInputString, Consumer verify) throws Exception { - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), LoggingDeprecationHandler.INSTANCE, expectedInputString); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + LoggingDeprecationHandler.INSTANCE, + expectedInputString + ); parser.nextToken(); MLCreateConnectorInput parsedInput = MLCreateConnectorInput.parse(parser); verify.accept(parsedInput); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTests.java index f2c9aa7737..719e427684 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTests.java @@ -16,6 +16,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; + import org.junit.Before; import org.junit.Test; import org.opensearch.action.ActionRequest; @@ -33,7 +34,7 @@ public class MLCreateConnectorRequestTests { private MLCreateConnectorRequest mlCreateConnectorRequest; @Before - public void setUp(){ + public void setUp() { ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; String method = "POST"; String url = "https://test.com"; @@ -42,34 +43,58 @@ public void setUp(){ String mlCreateConnectorRequestBody = "{\"input\": \"${parameters.input}\"}"; String preProcessFunction = MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT; String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING; - ConnectorAction action = new ConnectorAction(actionType, method, url, headers, mlCreateConnectorRequestBody, preProcessFunction, postProcessFunction); - - mlCreateConnectorInput = MLCreateConnectorInput.builder() - .name("test_connector_name") - .description("this is a test connector") - .version("1") - .protocol("http") - .parameters(Map.of("input", "test input value")) - .credential(Map.of("key", "test_key_value")) - .actions(List.of(action)) - .access(AccessMode.PUBLIC) - .backendRoles(Arrays.asList("role1", "role2")) - .addAllBackendRoles(false) - .build(); + ConnectorAction action = new ConnectorAction( + actionType, + method, + url, + headers, + mlCreateConnectorRequestBody, + preProcessFunction, + postProcessFunction + ); + + mlCreateConnectorInput = MLCreateConnectorInput + .builder() + .name("test_connector_name") + .description("this is a test connector") + .version("1") + .protocol("http") + .parameters(Map.of("input", "test input value")) + .credential(Map.of("key", "test_key_value")) + .actions(List.of(action)) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1", "role2")) + .addAllBackendRoles(false) + .build(); mlCreateConnectorRequest = MLCreateConnectorRequest.builder().mlCreateConnectorInput(mlCreateConnectorInput).build(); } @Test - public void writeToSuccess() throws IOException { + public void writeToSuccess() throws IOException { BytesStreamOutput output = new BytesStreamOutput(); mlCreateConnectorRequest.writeTo(output); MLCreateConnectorRequest parsedRequest = new MLCreateConnectorRequest(output.bytes().streamInput()); assertEquals(mlCreateConnectorRequest.getMlCreateConnectorInput().getName(), parsedRequest.getMlCreateConnectorInput().getName()); - assertEquals(mlCreateConnectorRequest.getMlCreateConnectorInput().getAccess(), parsedRequest.getMlCreateConnectorInput().getAccess()); - assertEquals(mlCreateConnectorRequest.getMlCreateConnectorInput().getProtocol(), parsedRequest.getMlCreateConnectorInput().getProtocol()); - assertEquals(mlCreateConnectorRequest.getMlCreateConnectorInput().getBackendRoles(), parsedRequest.getMlCreateConnectorInput().getBackendRoles()); - assertEquals(mlCreateConnectorRequest.getMlCreateConnectorInput().getActions(), parsedRequest.getMlCreateConnectorInput().getActions()); - assertEquals(mlCreateConnectorRequest.getMlCreateConnectorInput().getParameters(), parsedRequest.getMlCreateConnectorInput().getParameters()); + assertEquals( + mlCreateConnectorRequest.getMlCreateConnectorInput().getAccess(), + parsedRequest.getMlCreateConnectorInput().getAccess() + ); + assertEquals( + mlCreateConnectorRequest.getMlCreateConnectorInput().getProtocol(), + parsedRequest.getMlCreateConnectorInput().getProtocol() + ); + assertEquals( + mlCreateConnectorRequest.getMlCreateConnectorInput().getBackendRoles(), + parsedRequest.getMlCreateConnectorInput().getBackendRoles() + ); + assertEquals( + mlCreateConnectorRequest.getMlCreateConnectorInput().getActions(), + parsedRequest.getMlCreateConnectorInput().getActions() + ); + assertEquals( + mlCreateConnectorRequest.getMlCreateConnectorInput().getParameters(), + parsedRequest.getMlCreateConnectorInput().getParameters() + ); } @Test @@ -79,8 +104,7 @@ public void validateSuccess() { @Test public void validateWithNullMLCreateConnectorInputException() { - MLCreateConnectorRequest mlCreateConnectorRequest = MLCreateConnectorRequest.builder() - .build(); + MLCreateConnectorRequest mlCreateConnectorRequest = MLCreateConnectorRequest.builder().build(); ActionRequestValidationException exception = mlCreateConnectorRequest.validate(); assertEquals("Validation Failed: 1: ML Connector input can't be null;", exception.getMessage()); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponseTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponseTests.java index 7995e47f8f..31a71e9f23 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponseTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponseTests.java @@ -5,7 +5,13 @@ package org.opensearch.ml.common.transport.connector; -import org.junit.Assert; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertSame; + +import java.io.IOException; +import java.io.UncheckedIOException; + import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.XContentType; @@ -15,13 +21,6 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertSame; - public class MLCreateConnectorResponseTests { @Test @@ -43,7 +42,6 @@ public void readFromStream() throws IOException { assertEquals("testConnectorId", response2.getConnectorId()); } - @Test public void fromActionResponseWithMLCreateConnectorResponseSuccess() { MLCreateConnectorResponse response = new MLCreateConnectorResponse("testConnectorId"); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequestTests.java index f95b236259..06ba1d0d21 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequestTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequestTests.java @@ -5,6 +5,16 @@ package org.opensearch.ml.common.transport.connector; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Map; + import org.junit.Before; import org.junit.Test; import org.opensearch.action.ActionRequest; @@ -17,23 +27,13 @@ import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.util.Map; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; -import static org.junit.Assert.assertTrue; - public class MLExecuteConnectorRequestTests { private MLExecuteConnectorRequest mlExecuteConnectorRequest; private MLInput mlInput; private String connectorId; @Before - public void setUp(){ + public void setUp() { MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("input", "hello")).build(); connectorId = "test_connector"; mlInput = RemoteInferenceMLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.CONNECTOR).build(); @@ -41,14 +41,17 @@ public void setUp(){ } @Test - public void writeToSuccess() throws IOException { + public void writeToSuccess() throws IOException { BytesStreamOutput output = new BytesStreamOutput(); mlExecuteConnectorRequest.writeTo(output); MLExecuteConnectorRequest parsedRequest = new MLExecuteConnectorRequest(output.bytes().streamInput()); assertEquals(mlExecuteConnectorRequest.getConnectorId(), parsedRequest.getConnectorId()); assertEquals(mlExecuteConnectorRequest.getMlInput().getAlgorithm(), parsedRequest.getMlInput().getAlgorithm()); - assertEquals(mlExecuteConnectorRequest.getMlInput().getInputDataset().getInputDataType(), parsedRequest.getMlInput().getInputDataset().getInputDataType()); - assertEquals("hello", ((RemoteInferenceInputDataSet)parsedRequest.getMlInput().getInputDataset()).getParameters().get("input")); + assertEquals( + mlExecuteConnectorRequest.getMlInput().getInputDataset().getInputDataType(), + parsedRequest.getMlInput().getInputDataset().getInputDataType() + ); + assertEquals("hello", ((RemoteInferenceInputDataSet) parsedRequest.getMlInput().getInputDataset()).getParameters().get("input")); } @Test @@ -64,16 +67,14 @@ public void testConstructor() { @Test public void validateWithNullMLInputException() { - MLExecuteConnectorRequest executeConnectorRequest = MLExecuteConnectorRequest.builder() - .build(); + MLExecuteConnectorRequest executeConnectorRequest = MLExecuteConnectorRequest.builder().build(); ActionRequestValidationException exception = executeConnectorRequest.validate(); assertEquals("Validation Failed: 1: ML input can't be null;", exception.getMessage()); } @Test public void validateWithNullMLInputDataSetException() { - MLExecuteConnectorRequest executeConnectorRequest = MLExecuteConnectorRequest.builder().mlInput(new MLInput()) - .build(); + MLExecuteConnectorRequest executeConnectorRequest = MLExecuteConnectorRequest.builder().mlInput(new MLInput()).build(); ActionRequestValidationException exception = executeConnectorRequest.validate(); assertEquals("Validation Failed: 1: input data can't be null;", exception.getMessage()); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java index 44e970f95c..49b013cdf2 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java @@ -5,6 +5,16 @@ package org.opensearch.ml.common.transport.connector; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Collections; + import org.junit.Before; import org.junit.Test; import org.mockito.MockitoAnnotations; @@ -18,16 +28,6 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.search.SearchModule; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.util.Collections; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; -import static org.junit.Assert.assertTrue; - public class MLUpdateConnectorRequestTests { private String connectorId; private MLCreateConnectorInput updateContent; @@ -38,10 +38,7 @@ public void setUp() { MockitoAnnotations.openMocks(this); this.connectorId = "test-connector_id"; this.updateContent = MLCreateConnectorInput.builder().description("new description").updateConnector(true).build(); - mlUpdateConnectorRequest = MLUpdateConnectorRequest.builder() - .connectorId(connectorId) - .updateContent(updateContent) - .build(); + mlUpdateConnectorRequest = MLUpdateConnectorRequest.builder().connectorId(connectorId).updateContent(updateContent).build(); } @Test @@ -63,14 +60,22 @@ public void validate_Exception_NullConnectorId() { MLUpdateConnectorRequest updateConnectorRequest = MLUpdateConnectorRequest.builder().build(); Exception exception = updateConnectorRequest.validate(); - assertEquals("Validation Failed: 1: ML connector id can't be null;2: Update connector content can't be null;", exception.getMessage()); + assertEquals( + "Validation Failed: 1: ML connector id can't be null;2: Update connector content can't be null;", + exception.getMessage() + ); } @Test public void parse_success() throws IOException { String jsonStr = "{\"version\":\"new version\",\"description\":\"new description\"}"; - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); parser.nextToken(); MLUpdateConnectorRequest updateConnectorRequest = MLUpdateConnectorRequest.parse(parser, connectorId); assertEquals(updateConnectorRequest.getConnectorId(), connectorId); @@ -81,7 +86,8 @@ public void parse_success() throws IOException { @Test public void fromActionRequest_Success() { - MLUpdateConnectorRequest mlUpdateConnectorRequest = MLUpdateConnectorRequest.builder() + MLUpdateConnectorRequest mlUpdateConnectorRequest = MLUpdateConnectorRequest + .builder() .connectorId(connectorId) .updateContent(updateContent) .build(); @@ -90,7 +96,8 @@ public void fromActionRequest_Success() { @Test public void fromActionRequest_Success_fromActionRequest() { - MLUpdateConnectorRequest mlUpdateConnectorRequest = MLUpdateConnectorRequest.builder() + MLUpdateConnectorRequest mlUpdateConnectorRequest = MLUpdateConnectorRequest + .builder() .connectorId(connectorId) .updateContent(updateContent) .build(); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLControllerDeleteRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLControllerDeleteRequestTest.java index a9ff1a6361..df37c88437 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLControllerDeleteRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLControllerDeleteRequestTest.java @@ -31,16 +31,14 @@ public void setUp() { modelId = "testModelId"; - request = MLControllerDeleteRequest.builder() - .modelId(modelId).build(); + request = MLControllerDeleteRequest.builder().modelId(modelId).build(); } @Test public void writeToSuccess() throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); request.writeTo(bytesStreamOutput); - MLControllerDeleteRequest parsedRequest = new MLControllerDeleteRequest( - bytesStreamOutput.bytes().streamInput()); + MLControllerDeleteRequest parsedRequest = new MLControllerDeleteRequest(bytesStreamOutput.bytes().streamInput()); assertEquals(parsedRequest.getModelId(), modelId); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLControllerGetResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLControllerGetResponseTest.java index 6d29106842..5c5cac92e2 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLControllerGetResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLControllerGetResponseTest.java @@ -10,7 +10,6 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertSame; -import static org.junit.Assert.assertTrue; import java.io.IOException; import java.io.UncheckedIOException; @@ -37,18 +36,12 @@ public class MLControllerGetResponseTest { @Before public void setUp() { - MLRateLimiter rateLimiter = MLRateLimiter.builder() - .limit("1") - .unit(TimeUnit.MILLISECONDS) - .build(); - controller = MLController.builder() - .modelId("testModelId") - .userRateLimiter(new HashMap<>() { - { - put("testUser", rateLimiter); - } - }) - .build(); + MLRateLimiter rateLimiter = MLRateLimiter.builder().limit("1").unit(TimeUnit.MILLISECONDS).build(); + controller = MLController.builder().modelId("testModelId").userRateLimiter(new HashMap<>() { + { + put("testUser", rateLimiter); + } + }).build(); response = MLControllerGetResponse.builder().controller(controller).build(); } @@ -56,14 +49,17 @@ public void setUp() { public void writeToSuccess() throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); response.writeTo(bytesStreamOutput); - MLControllerGetResponse parsedResponse = new MLControllerGetResponse( - bytesStreamOutput.bytes().streamInput()); + MLControllerGetResponse parsedResponse = new MLControllerGetResponse(bytesStreamOutput.bytes().streamInput()); assertNotEquals(response.getController(), parsedResponse.getController()); assertEquals(response.getController().getModelId(), parsedResponse.getController().getModelId()); - assertEquals(response.getController().getUserRateLimiter().get("testUser").getLimit(), - parsedResponse.getController().getUserRateLimiter().get("testUser").getLimit()); - assertEquals(response.getController().getUserRateLimiter().get("testUser").getUnit(), - parsedResponse.getController().getUserRateLimiter().get("testUser").getUnit()); + assertEquals( + response.getController().getUserRateLimiter().get("testUser").getLimit(), + parsedResponse.getController().getUserRateLimiter().get("testUser").getLimit() + ); + assertEquals( + response.getController().getUserRateLimiter().get("testUser").getUnit(), + parsedResponse.getController().getUserRateLimiter().get("testUser").getUnit() + ); } @Test @@ -73,14 +69,14 @@ public void toXContentTest() throws IOException { assertNotNull(builder); String jsonStr = builder.toString(); assertEquals( - "{\"model_id\":\"testModelId\",\"user_rate_limiter\":{\"testUser\":{\"limit\":\"1\",\"unit\":\"MILLISECONDS\"}}}", - jsonStr); + "{\"model_id\":\"testModelId\",\"user_rate_limiter\":{\"testUser\":{\"limit\":\"1\",\"unit\":\"MILLISECONDS\"}}}", + jsonStr + ); } @Test public void fromActionResponseWithMLControllerGetResponseSuccess() { - MLControllerGetResponse responseFromActionResponse = MLControllerGetResponse - .fromActionResponse(response); + MLControllerGetResponse responseFromActionResponse = MLControllerGetResponse.fromActionResponse(response); assertSame(response, responseFromActionResponse); assertEquals(response.getController(), responseFromActionResponse.getController()); } @@ -93,8 +89,7 @@ public void writeTo(StreamOutput out) throws IOException { response.writeTo(out); } }; - MLControllerGetResponse responseFromActionResponse = MLControllerGetResponse - .fromActionResponse(actionResponse); + MLControllerGetResponse responseFromActionResponse = MLControllerGetResponse.fromActionResponse(actionResponse); assertNotSame(response, responseFromActionResponse); assertNotEquals(response.getController(), responseFromActionResponse.getController()); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLCreateControllerRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLCreateControllerRequestTest.java index a8e19b58a7..51d285c669 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLCreateControllerRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLCreateControllerRequestTest.java @@ -26,103 +26,90 @@ import org.opensearch.ml.common.controller.MLRateLimiter; public class MLCreateControllerRequestTest { - private MLController controllerInput; - - private MLCreateControllerRequest request; - - @Before - public void setUp() throws Exception { - - MLRateLimiter rateLimiter = MLRateLimiter.builder() - .limit("1") - .unit(TimeUnit.MILLISECONDS) - .build(); - controllerInput = MLController.builder() - .modelId("testModelId") - .userRateLimiter(new HashMap<>() { - { - put("testUser", rateLimiter); - } - }) - .build(); - request = MLCreateControllerRequest.builder() - .controllerInput(controllerInput) - .build(); - } - - @Test - public void writeToSuccess() throws IOException { - BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); - request.writeTo(bytesStreamOutput); - MLCreateControllerRequest parsedRequest = new MLCreateControllerRequest( - bytesStreamOutput.bytes().streamInput()); - assertEquals("testModelId", parsedRequest.getControllerInput().getModelId()); - assertTrue(parsedRequest.getControllerInput().getUserRateLimiter().containsKey("testUser")); - assertEquals("1", parsedRequest.getControllerInput().getUserRateLimiter().get("testUser") - .getLimit()); - assertEquals(TimeUnit.MILLISECONDS, - parsedRequest.getControllerInput().getUserRateLimiter().get("testUser").getUnit()); - } - - @Test - public void validateSuccess() { - assertNull(request.validate()); - } - - @Test - public void validateWithNullMLControllerInputException() { - MLCreateControllerRequest request = MLCreateControllerRequest.builder().build(); - ActionRequestValidationException exception = request.validate(); - assertEquals("Validation Failed: 1: Model controller input can't be null;", exception.getMessage()); - } - - @Test - public void validateWithNullMLModelID() { - controllerInput.setModelId(null); - MLCreateControllerRequest request = MLCreateControllerRequest.builder() - .controllerInput(controllerInput) - .build(); - - assertNull(request.validate()); - assertNull(request.getControllerInput().getModelId()); - } - - @Test - public void fromActionRequestWithMLCreateControllerRequestSuccess() { - assertSame(MLCreateControllerRequest.fromActionRequest(request), request); - } - - @Test - public void fromActionRequestWithNonMLCreateControllerRequestSuccess() { - ActionRequest actionRequest = new ActionRequest() { - @Override - public ActionRequestValidationException validate() { - return null; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - request.writeTo(out); - } - }; - MLCreateControllerRequest result = MLCreateControllerRequest.fromActionRequest(actionRequest); - assertNotSame(result, request); - assertEquals(request.getControllerInput().getModelId(), result.getControllerInput().getModelId()); - } - - @Test(expected = UncheckedIOException.class) - public void fromActionRequestIOException() { - ActionRequest actionRequest = new ActionRequest() { - @Override - public ActionRequestValidationException validate() { - return null; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - throw new IOException("test"); - } - }; - MLCreateControllerRequest.fromActionRequest(actionRequest); - } + private MLController controllerInput; + + private MLCreateControllerRequest request; + + @Before + public void setUp() throws Exception { + + MLRateLimiter rateLimiter = MLRateLimiter.builder().limit("1").unit(TimeUnit.MILLISECONDS).build(); + controllerInput = MLController.builder().modelId("testModelId").userRateLimiter(new HashMap<>() { + { + put("testUser", rateLimiter); + } + }).build(); + request = MLCreateControllerRequest.builder().controllerInput(controllerInput).build(); + } + + @Test + public void writeToSuccess() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + request.writeTo(bytesStreamOutput); + MLCreateControllerRequest parsedRequest = new MLCreateControllerRequest(bytesStreamOutput.bytes().streamInput()); + assertEquals("testModelId", parsedRequest.getControllerInput().getModelId()); + assertTrue(parsedRequest.getControllerInput().getUserRateLimiter().containsKey("testUser")); + assertEquals("1", parsedRequest.getControllerInput().getUserRateLimiter().get("testUser").getLimit()); + assertEquals(TimeUnit.MILLISECONDS, parsedRequest.getControllerInput().getUserRateLimiter().get("testUser").getUnit()); + } + + @Test + public void validateSuccess() { + assertNull(request.validate()); + } + + @Test + public void validateWithNullMLControllerInputException() { + MLCreateControllerRequest request = MLCreateControllerRequest.builder().build(); + ActionRequestValidationException exception = request.validate(); + assertEquals("Validation Failed: 1: Model controller input can't be null;", exception.getMessage()); + } + + @Test + public void validateWithNullMLModelID() { + controllerInput.setModelId(null); + MLCreateControllerRequest request = MLCreateControllerRequest.builder().controllerInput(controllerInput).build(); + + assertNull(request.validate()); + assertNull(request.getControllerInput().getModelId()); + } + + @Test + public void fromActionRequestWithMLCreateControllerRequestSuccess() { + assertSame(MLCreateControllerRequest.fromActionRequest(request), request); + } + + @Test + public void fromActionRequestWithNonMLCreateControllerRequestSuccess() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + request.writeTo(out); + } + }; + MLCreateControllerRequest result = MLCreateControllerRequest.fromActionRequest(actionRequest); + assertNotSame(result, request); + assertEquals(request.getControllerInput().getModelId(), result.getControllerInput().getModelId()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequestIOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException("test"); + } + }; + MLCreateControllerRequest.fromActionRequest(actionRequest); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLCreateControllerResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLCreateControllerResponseTest.java index 6c11a667d1..831cdae25b 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLCreateControllerResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLCreateControllerResponseTest.java @@ -36,8 +36,7 @@ public void setup() { public void writeToSuccess() throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); response.writeTo(bytesStreamOutput); - MLCreateControllerResponse newResponse = new MLCreateControllerResponse( - bytesStreamOutput.bytes().streamInput()); + MLCreateControllerResponse newResponse = new MLCreateControllerResponse(bytesStreamOutput.bytes().streamInput()); assertEquals(response.getModelId(), newResponse.getModelId()); assertEquals(response.getStatus(), newResponse.getStatus()); } @@ -68,8 +67,7 @@ public void writeTo(StreamOutput out) throws IOException { response.writeTo(out); } }; - MLCreateControllerResponse responseFromActionResponse = MLCreateControllerResponse - .fromActionResponse(actionResponse); + MLCreateControllerResponse responseFromActionResponse = MLCreateControllerResponse.fromActionResponse(actionResponse); assertNotSame(response, responseFromActionResponse); assertEquals(response.getModelId(), responseFromActionResponse.getModelId()); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodeResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodeResponseTest.java index e0d817f2d7..362638840c 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodeResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodeResponseTest.java @@ -39,12 +39,13 @@ public class MLDeployControllerNodeResponseTest { @Before public void setUp() throws Exception { localNode = new DiscoveryNode( - "foo0", - "foo0", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT); + "foo0", + "foo0", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); } @Test @@ -71,8 +72,7 @@ public void testReadProfile() throws IOException { MLDeployControllerNodeResponse response = new MLDeployControllerNodeResponse(localNode, new HashMap<>()); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); - MLDeployControllerNodeResponse newResponse = MLDeployControllerNodeResponse - .readStats(output.bytes().streamInput()); + MLDeployControllerNodeResponse newResponse = MLDeployControllerNodeResponse.readStats(output.bytes().streamInput()); assertNotEquals(newResponse, response); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodesRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodesRequestTest.java index cb07cdbbc0..a975860c0f 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodesRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodesRequestTest.java @@ -5,6 +5,10 @@ package org.opensearch.ml.common.transport.controller; +import static org.junit.Assert.assertEquals; + +import java.io.IOException; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -14,10 +18,6 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; -import java.io.IOException; - -import static org.junit.Assert.assertEquals; - // This test combined MLDeployControllerNodesRequestTest and MLDeployControllerNodeRequestTest together. @RunWith(MockitoJUnitRunner.class) public class MLDeployControllerNodesRequestTest { @@ -40,9 +40,11 @@ public void setUp() throws Exception { DiscoveryNode[] discoveryNodeIds = { localNode1, localNode2 }; deployControllerNodeRequestWithStringNodeIds = new MLDeployControllerNodeRequest( - new MLDeployControllerNodesRequest(stringNodeIds, modelId)); + new MLDeployControllerNodesRequest(stringNodeIds, modelId) + ); deployControllerNodeRequestWithDiscoveryNodeIds = new MLDeployControllerNodeRequest( - new MLDeployControllerNodesRequest(discoveryNodeIds, modelId)); + new MLDeployControllerNodesRequest(discoveryNodeIds, modelId) + ); } @@ -50,15 +52,13 @@ public void setUp() throws Exception { public void testConstructorSerialization1() throws IOException { BytesStreamOutput output = new BytesStreamOutput(); deployControllerNodeRequestWithStringNodeIds.writeTo(output); - assertEquals("testModelId", - deployControllerNodeRequestWithStringNodeIds.getDeployControllerNodesRequest().getModelId()); + assertEquals("testModelId", deployControllerNodeRequestWithStringNodeIds.getDeployControllerNodesRequest().getModelId()); } @Test public void testConstructorSerialization2() { - assertEquals(2, deployControllerNodeRequestWithDiscoveryNodeIds.getDeployControllerNodesRequest() - .concreteNodes().length); + assertEquals(2, deployControllerNodeRequestWithDiscoveryNodeIds.getDeployControllerNodesRequest().concreteNodes().length); } @@ -70,8 +70,10 @@ public void testConstructorFromInputStream() throws IOException { StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); MLDeployControllerNodeRequest parsedNodeRequest = new MLDeployControllerNodeRequest(streamInput); - assertEquals(deployControllerNodeRequestWithStringNodeIds.getDeployControllerNodesRequest().getModelId(), - parsedNodeRequest.getDeployControllerNodesRequest().getModelId()); + assertEquals( + deployControllerNodeRequestWithStringNodeIds.getDeployControllerNodesRequest().getModelId(), + parsedNodeRequest.getDeployControllerNodesRequest().getModelId() + ); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodesResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodesResponseTest.java index bcc3f0c38a..aa081ecb42 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodesResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodesResponseTest.java @@ -12,7 +12,6 @@ import java.net.InetAddress; import java.util.ArrayList; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; @@ -44,27 +43,28 @@ public class MLDeployControllerNodesResponseTest { public void setUp() throws Exception { clusterName = new ClusterName("clusterName"); node1 = new DiscoveryNode( - "foo1", - "foo1", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT); + "foo1", + "foo1", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); node2 = new DiscoveryNode( - "foo2", - "foo2", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT); + "foo2", + "foo2", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); } @Test public void testSerializationDeserialization1() throws IOException { List responseList = new ArrayList<>(); List failuresList = new ArrayList<>(); - MLDeployControllerNodesResponse response = new MLDeployControllerNodesResponse(clusterName, responseList, - failuresList); + MLDeployControllerNodesResponse response = new MLDeployControllerNodesResponse(clusterName, responseList, failuresList); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); MLDeployControllerNodesResponse newResponse = new MLDeployControllerNodesResponse(output.bytes().streamInput()); @@ -86,9 +86,7 @@ public void testToXContent() throws IOException { XContentBuilder builder = XContentFactory.jsonBuilder(); response.toXContent(builder, ToXContent.EMPTY_PARAMS); String jsonStr = builder.toString(); - assertEquals( - "{\"foo1\":{\"stats\":{\"modelId1\":\"response\"}},\"foo2\":{\"stats\":{\"modelId2\":\"response\"}}}", - jsonStr); + assertEquals("{\"foo1\":{\"stats\":{\"modelId1\":\"response\"}},\"foo2\":{\"stats\":{\"modelId2\":\"response\"}}}", jsonStr); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodeResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodeResponseTest.java index e1df438393..eea39ae6c6 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodeResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodeResponseTest.java @@ -25,10 +25,7 @@ import org.opensearch.Version; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.common.transport.TransportAddress; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentBuilder; @RunWith(MockitoJUnitRunner.class) public class MLUndeployControllerNodeResponseTest { @@ -42,23 +39,22 @@ public class MLUndeployControllerNodeResponseTest { @Before public void setUp() throws Exception { localNode = new DiscoveryNode( - "foo0", - "foo0", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT); + "foo0", + "foo0", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); } @Test public void testSerializationDeserialization() throws IOException { Map undeployControllerStatus = Map.of("modelName:version", "response"); - MLUndeployControllerNodeResponse response = new MLUndeployControllerNodeResponse(localNode, - undeployControllerStatus); + MLUndeployControllerNodeResponse response = new MLUndeployControllerNodeResponse(localNode, undeployControllerStatus); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); - MLUndeployControllerNodeResponse newResponse = new MLUndeployControllerNodeResponse( - output.bytes().streamInput()); + MLUndeployControllerNodeResponse newResponse = new MLUndeployControllerNodeResponse(output.bytes().streamInput()); assertEquals(newResponse.getNode().getId(), response.getNode().getId()); } @@ -67,8 +63,7 @@ public void testSerializationDeserializationNullModelUpdateModelCacheStatus() th MLUndeployControllerNodeResponse response = new MLUndeployControllerNodeResponse(localNode, null); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); - MLUndeployControllerNodeResponse newResponse = new MLUndeployControllerNodeResponse( - output.bytes().streamInput()); + MLUndeployControllerNodeResponse newResponse = new MLUndeployControllerNodeResponse(output.bytes().streamInput()); assertEquals(newResponse.getNode().getId(), response.getNode().getId()); } @@ -77,8 +72,7 @@ public void testReadProfile() throws IOException { MLUndeployControllerNodeResponse response = new MLUndeployControllerNodeResponse(localNode, new HashMap<>()); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); - MLUndeployControllerNodeResponse newResponse = MLUndeployControllerNodeResponse - .readStats(output.bytes().streamInput()); + MLUndeployControllerNodeResponse newResponse = MLUndeployControllerNodeResponse.readStats(output.bytes().streamInput()); assertNotEquals(newResponse, response); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodesRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodesRequestTest.java index 20d6d66fb5..58787ab6a7 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodesRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodesRequestTest.java @@ -6,6 +6,7 @@ package org.opensearch.ml.common.transport.controller; import static org.junit.Assert.assertEquals; + import java.io.IOException; import org.junit.Before; @@ -39,9 +40,11 @@ public void setUp() throws Exception { DiscoveryNode[] discoveryNodeIds = { localNode1, localNode2 }; undeployControllerNodeRequestWithStringNodeIds = new MLUndeployControllerNodeRequest( - new MLUndeployControllerNodesRequest(stringNodeIds, modelId)); + new MLUndeployControllerNodesRequest(stringNodeIds, modelId) + ); undeployControllerNodeRequestWithDiscoveryNodeIds = new MLUndeployControllerNodeRequest( - new MLUndeployControllerNodesRequest(discoveryNodeIds, modelId)); + new MLUndeployControllerNodesRequest(discoveryNodeIds, modelId) + ); } @@ -49,15 +52,13 @@ public void setUp() throws Exception { public void testConstructorSerialization1() throws IOException { BytesStreamOutput output = new BytesStreamOutput(); undeployControllerNodeRequestWithStringNodeIds.writeTo(output); - assertEquals("testModelId", - undeployControllerNodeRequestWithStringNodeIds.getUndeployControllerNodesRequest().getModelId()); + assertEquals("testModelId", undeployControllerNodeRequestWithStringNodeIds.getUndeployControllerNodesRequest().getModelId()); } @Test public void testConstructorSerialization2() { - assertEquals(2, undeployControllerNodeRequestWithDiscoveryNodeIds.getUndeployControllerNodesRequest() - .concreteNodes().length); + assertEquals(2, undeployControllerNodeRequestWithDiscoveryNodeIds.getUndeployControllerNodesRequest().concreteNodes().length); } @@ -69,8 +70,10 @@ public void testConstructorFromInputStream() throws IOException { StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); MLUndeployControllerNodeRequest parsedNodeRequest = new MLUndeployControllerNodeRequest(streamInput); - assertEquals(undeployControllerNodeRequestWithStringNodeIds.getUndeployControllerNodesRequest().getModelId(), - parsedNodeRequest.getUndeployControllerNodesRequest().getModelId()); + assertEquals( + undeployControllerNodeRequestWithStringNodeIds.getUndeployControllerNodesRequest().getModelId(), + parsedNodeRequest.getUndeployControllerNodesRequest().getModelId() + ); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodesResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodesResponseTest.java index c374b85312..213250bc6a 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodesResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodesResponseTest.java @@ -12,7 +12,6 @@ import java.net.InetAddress; import java.util.ArrayList; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; @@ -44,31 +43,31 @@ public class MLUndeployControllerNodesResponseTest { public void setUp() throws Exception { clusterName = new ClusterName("clusterName"); node1 = new DiscoveryNode( - "foo1", - "foo1", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT); + "foo1", + "foo1", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); node2 = new DiscoveryNode( - "foo2", - "foo2", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT); + "foo2", + "foo2", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); } @Test public void testSerializationDeserialization1() throws IOException { List responseList = new ArrayList<>(); List failuresList = new ArrayList<>(); - MLUndeployControllerNodesResponse response = new MLUndeployControllerNodesResponse(clusterName, responseList, - failuresList); + MLUndeployControllerNodesResponse response = new MLUndeployControllerNodesResponse(clusterName, responseList, failuresList); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); - MLUndeployControllerNodesResponse newResponse = new MLUndeployControllerNodesResponse( - output.bytes().streamInput()); + MLUndeployControllerNodesResponse newResponse = new MLUndeployControllerNodesResponse(output.bytes().streamInput()); assertEquals(newResponse.getNodes().size(), response.getNodes().size()); } @@ -83,14 +82,11 @@ public void testToXContent() throws IOException { nodes.add(new MLUndeployControllerNodeResponse(node2, undeployControllerStatus2)); List failures = new ArrayList<>(); - MLUndeployControllerNodesResponse response = new MLUndeployControllerNodesResponse(clusterName, nodes, - failures); + MLUndeployControllerNodesResponse response = new MLUndeployControllerNodesResponse(clusterName, nodes, failures); XContentBuilder builder = XContentFactory.jsonBuilder(); response.toXContent(builder, ToXContent.EMPTY_PARAMS); String jsonStr = builder.toString(); - assertEquals( - "{\"foo1\":{\"stats\":{\"modelId1\":\"response\"}},\"foo2\":{\"stats\":{\"modelId2\":\"response\"}}}", - jsonStr); + assertEquals("{\"foo1\":{\"stats\":{\"modelId1\":\"response\"}},\"foo2\":{\"stats\":{\"modelId2\":\"response\"}}}", jsonStr); } @Test @@ -98,8 +94,7 @@ public void testNullUpdateModelCacheStatusToXContent() throws IOException { List nodes = new ArrayList<>(); nodes.add(new MLUndeployControllerNodeResponse(node1, null)); List failures = new ArrayList<>(); - MLUndeployControllerNodesResponse response = new MLUndeployControllerNodesResponse(clusterName, nodes, - failures); + MLUndeployControllerNodesResponse response = new MLUndeployControllerNodesResponse(clusterName, nodes, failures); XContentBuilder builder = XContentFactory.jsonBuilder(); response.toXContent(builder, ToXContent.EMPTY_PARAMS); String jsonStr = builder.toString(); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUpdateControllerRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUpdateControllerRequestTest.java index 73a40047d2..32e7a7d078 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUpdateControllerRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUpdateControllerRequestTest.java @@ -26,104 +26,90 @@ import org.opensearch.ml.common.controller.MLRateLimiter; public class MLUpdateControllerRequestTest { - private MLController updateControllerInput; - - private MLUpdateControllerRequest request; - - @Before - public void setUp() throws Exception { - - MLRateLimiter rateLimiter = MLRateLimiter.builder() - .limit("1") - .unit(TimeUnit.MILLISECONDS) - .build(); - updateControllerInput = MLController.builder() - .modelId("testModelId") - .userRateLimiter(new HashMap<>() { - { - put("testUser", rateLimiter); - } - }) - .build(); - request = MLUpdateControllerRequest.builder() - .updateControllerInput(updateControllerInput) - .build(); - } - - @Test - public void writeToSuccess() throws IOException { - BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); - request.writeTo(bytesStreamOutput); - MLUpdateControllerRequest parsedRequest = new MLUpdateControllerRequest( - bytesStreamOutput.bytes().streamInput()); - assertEquals("testModelId", parsedRequest.getUpdateControllerInput().getModelId()); - assertTrue(parsedRequest.getUpdateControllerInput().getUserRateLimiter().containsKey("testUser")); - assertEquals("1", parsedRequest.getUpdateControllerInput().getUserRateLimiter().get("testUser") - .getLimit()); - assertEquals(TimeUnit.MILLISECONDS, parsedRequest.getUpdateControllerInput().getUserRateLimiter() - .get("testUser").getUnit()); - } - - @Test - public void validateSuccess() { - assertNull(request.validate()); - } - - @Test - public void validateWithNullMLControllerInputException() { - MLUpdateControllerRequest request = MLUpdateControllerRequest.builder().build(); - ActionRequestValidationException exception = request.validate(); - assertEquals("Validation Failed: 1: Update model controller input can't be null;", exception.getMessage()); - } - - @Test - public void validateWithNullMLModelID() { - updateControllerInput.setModelId(null); - MLUpdateControllerRequest request = MLUpdateControllerRequest.builder() - .updateControllerInput(updateControllerInput) - .build(); - - assertNull(request.validate()); - assertNull(request.getUpdateControllerInput().getModelId()); - } - - @Test - public void fromActionRequestWithMLUpdateControllerRequestSuccess() { - assertSame(MLUpdateControllerRequest.fromActionRequest(request), request); - } - - @Test - public void fromActionRequestWithNonMLUpdateControllerRequestSuccess() { - ActionRequest actionRequest = new ActionRequest() { - @Override - public ActionRequestValidationException validate() { - return null; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - request.writeTo(out); - } - }; - MLUpdateControllerRequest result = MLUpdateControllerRequest.fromActionRequest(actionRequest); - assertNotSame(result, request); - assertEquals(request.getUpdateControllerInput().getModelId(), - result.getUpdateControllerInput().getModelId()); - } - - @Test(expected = UncheckedIOException.class) - public void fromActionRequestIOException() { - ActionRequest actionRequest = new ActionRequest() { - @Override - public ActionRequestValidationException validate() { - return null; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - throw new IOException("test"); - } - }; - MLUpdateControllerRequest.fromActionRequest(actionRequest); - } + private MLController updateControllerInput; + + private MLUpdateControllerRequest request; + + @Before + public void setUp() throws Exception { + + MLRateLimiter rateLimiter = MLRateLimiter.builder().limit("1").unit(TimeUnit.MILLISECONDS).build(); + updateControllerInput = MLController.builder().modelId("testModelId").userRateLimiter(new HashMap<>() { + { + put("testUser", rateLimiter); + } + }).build(); + request = MLUpdateControllerRequest.builder().updateControllerInput(updateControllerInput).build(); + } + + @Test + public void writeToSuccess() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + request.writeTo(bytesStreamOutput); + MLUpdateControllerRequest parsedRequest = new MLUpdateControllerRequest(bytesStreamOutput.bytes().streamInput()); + assertEquals("testModelId", parsedRequest.getUpdateControllerInput().getModelId()); + assertTrue(parsedRequest.getUpdateControllerInput().getUserRateLimiter().containsKey("testUser")); + assertEquals("1", parsedRequest.getUpdateControllerInput().getUserRateLimiter().get("testUser").getLimit()); + assertEquals(TimeUnit.MILLISECONDS, parsedRequest.getUpdateControllerInput().getUserRateLimiter().get("testUser").getUnit()); + } + + @Test + public void validateSuccess() { + assertNull(request.validate()); + } + + @Test + public void validateWithNullMLControllerInputException() { + MLUpdateControllerRequest request = MLUpdateControllerRequest.builder().build(); + ActionRequestValidationException exception = request.validate(); + assertEquals("Validation Failed: 1: Update model controller input can't be null;", exception.getMessage()); + } + + @Test + public void validateWithNullMLModelID() { + updateControllerInput.setModelId(null); + MLUpdateControllerRequest request = MLUpdateControllerRequest.builder().updateControllerInput(updateControllerInput).build(); + + assertNull(request.validate()); + assertNull(request.getUpdateControllerInput().getModelId()); + } + + @Test + public void fromActionRequestWithMLUpdateControllerRequestSuccess() { + assertSame(MLUpdateControllerRequest.fromActionRequest(request), request); + } + + @Test + public void fromActionRequestWithNonMLUpdateControllerRequestSuccess() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + request.writeTo(out); + } + }; + MLUpdateControllerRequest result = MLUpdateControllerRequest.fromActionRequest(actionRequest); + assertNotSame(result, request); + assertEquals(request.getUpdateControllerInput().getModelId(), result.getUpdateControllerInput().getModelId()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequestIOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException("test"); + } + }; + MLUpdateControllerRequest.fromActionRequest(actionRequest); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelInputTest.java index e3f1583f13..8d93213a4a 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelInputTest.java @@ -1,5 +1,12 @@ package org.opensearch.ml.common.transport.deploy; +import static org.junit.Assert.*; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Arrays; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -12,16 +19,6 @@ import org.opensearch.ml.common.MLTaskType; import org.opensearch.ml.common.dataset.MLInputDataType; -import java.io.IOException; -import java.time.Instant; -import java.time.temporal.ChronoUnit; -import java.util.Arrays; - -import static org.junit.Assert.*; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.verify; - @RunWith(MockitoJUnitRunner.class) public class MLDeployModelInputTest { @@ -31,29 +28,31 @@ public class MLDeployModelInputTest { @Before public void setUp() throws Exception { Instant time = Instant.now(); - mlTask = MLTask.builder() - .taskId("mlTaskTaskId") - .modelId("mlTaskModelId") - .taskType(MLTaskType.PREDICTION) - .functionName(FunctionName.LINEAR_REGRESSION) - .state(MLTaskState.RUNNING) - .inputType(MLInputDataType.DATA_FRAME) - .workerNodes(Arrays.asList("node1")) - .progress(0.0f) - .outputIndex("test_index") - .error("test_error") - .createTime(time.minus(1, ChronoUnit.MINUTES)) - .lastUpdateTime(time) - .build(); + mlTask = MLTask + .builder() + .taskId("mlTaskTaskId") + .modelId("mlTaskModelId") + .taskType(MLTaskType.PREDICTION) + .functionName(FunctionName.LINEAR_REGRESSION) + .state(MLTaskState.RUNNING) + .inputType(MLInputDataType.DATA_FRAME) + .workerNodes(Arrays.asList("node1")) + .progress(0.0f) + .outputIndex("test_index") + .error("test_error") + .createTime(time.minus(1, ChronoUnit.MINUTES)) + .lastUpdateTime(time) + .build(); - mlDeployModelInput = mlDeployModelInput.builder() - .modelId("testModelId") - .taskId("testTaskId") - .modelContentHash("modelContentHash") - .nodeCount(3) - .coordinatingNodeId("coordinatingNodeId") - .mlTask(mlTask) - .build(); + mlDeployModelInput = mlDeployModelInput + .builder() + .modelId("testModelId") + .taskId("testTaskId") + .modelContentHash("modelContentHash") + .nodeCount(3) + .coordinatingNodeId("coordinatingNodeId") + .mlTask(mlTask) + .build(); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodeResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodeResponseTest.java index cce0c463be..5d4136b06c 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodeResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodeResponseTest.java @@ -1,5 +1,15 @@ package org.opensearch.ml.common.transport.deploy; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -10,16 +20,6 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.transport.TransportAddress; -import java.io.IOException; -import java.net.InetAddress; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; -import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; - @RunWith(MockitoJUnitRunner.class) public class MLDeployModelNodeResponseTest { @@ -29,12 +29,12 @@ public class MLDeployModelNodeResponseTest { @Before public void setUp() throws Exception { localNode = new DiscoveryNode( - "foo0", - "foo0", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT + "foo0", + "foo0", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT ); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodesRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodesRequestTest.java index 938543f230..eb510d2a11 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodesRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodesRequestTest.java @@ -1,5 +1,15 @@ package org.opensearch.ml.common.transport.deploy; +import static org.junit.Assert.*; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; + +import java.io.IOException; +import java.net.InetAddress; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Arrays; +import java.util.Collections; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -16,16 +26,6 @@ import org.opensearch.ml.common.MLTaskType; import org.opensearch.ml.common.dataset.MLInputDataType; -import java.io.IOException; -import java.net.InetAddress; -import java.time.Instant; -import java.time.temporal.ChronoUnit; -import java.util.Arrays; -import java.util.Collections; - -import static org.junit.Assert.*; -import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; - @RunWith(MockitoJUnitRunner.class) public class MLDeployModelNodesRequestTest { @@ -39,54 +39,63 @@ public class MLDeployModelNodesRequestTest { @Before public void setUp() throws Exception { localNode1 = new DiscoveryNode( - "foo1", - "foo1", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT + "foo1", + "foo1", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT ); localNode2 = new DiscoveryNode( - "foo2", - "foo2", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT + "foo2", + "foo2", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT ); localNode3 = new DiscoveryNode( - "foo3", - "foo3", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT + "foo3", + "foo3", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT ); Instant time = Instant.now(); - mlTask = MLTask.builder() - .taskId("mlTaskTaskId") - .modelId("mlTaskModelId") - .taskType(MLTaskType.PREDICTION) - .functionName(FunctionName.LINEAR_REGRESSION) - .state(MLTaskState.RUNNING) - .inputType(MLInputDataType.DATA_FRAME) - .workerNodes(Arrays.asList("node1")) - .progress(0.0f) - .outputIndex("test_index") - .error("test_error") - .createTime(time.minus(1, ChronoUnit.MINUTES)) - .lastUpdateTime(time) - .build(); + mlTask = MLTask + .builder() + .taskId("mlTaskTaskId") + .modelId("mlTaskModelId") + .taskType(MLTaskType.PREDICTION) + .functionName(FunctionName.LINEAR_REGRESSION) + .state(MLTaskState.RUNNING) + .inputType(MLInputDataType.DATA_FRAME) + .workerNodes(Arrays.asList("node1")) + .progress(0.0f) + .outputIndex("test_index") + .error("test_error") + .createTime(time.minus(1, ChronoUnit.MINUTES)) + .lastUpdateTime(time) + .build(); } @Test public void testConstructorSerialization1() throws IOException { - String [] nodeIds = {"id1", "id2", "id3"}; - MLDeployModelInput deployModelInput = new MLDeployModelInput("modelId", "taskId", "modelContentHash", 3, "coordinatingNodeId", true, mlTask); + String[] nodeIds = { "id1", "id2", "id3" }; + MLDeployModelInput deployModelInput = new MLDeployModelInput( + "modelId", + "taskId", + "modelContentHash", + 3, + "coordinatingNodeId", + true, + mlTask + ); MLDeployModelNodeRequest MLDeployModelNodeRequest = new MLDeployModelNodeRequest( - new MLDeployModelNodesRequest(nodeIds, deployModelInput) + new MLDeployModelNodesRequest(nodeIds, deployModelInput) ); BytesStreamOutput output = new BytesStreamOutput(); @@ -95,54 +104,103 @@ public void testConstructorSerialization1() throws IOException { assertNotNull(MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput()); assertEquals("modelId", MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getModelId()); assertEquals("taskId", MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getTaskId()); - assertEquals("modelContentHash", MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getModelContentHash()); + assertEquals( + "modelContentHash", + MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getModelContentHash() + ); assertEquals(3, MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getNodeCount().intValue()); - assertEquals("coordinatingNodeId", MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getCoordinatingNodeId()); - assertEquals(mlTask.getTaskId(), MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getMlTask().getTaskId()); + assertEquals( + "coordinatingNodeId", + MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getCoordinatingNodeId() + ); + assertEquals( + mlTask.getTaskId(), + MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getMlTask().getTaskId() + ); } @Test public void testConstructorSerialization2() throws IOException { - DiscoveryNode [] nodeIds = {localNode1, localNode2, localNode3}; - MLDeployModelInput deployModelInput = new MLDeployModelInput("modelId", "taskId", "modelContentHash", 3, "coordinatingNodeId", true, mlTask); + DiscoveryNode[] nodeIds = { localNode1, localNode2, localNode3 }; + MLDeployModelInput deployModelInput = new MLDeployModelInput( + "modelId", + "taskId", + "modelContentHash", + 3, + "coordinatingNodeId", + true, + mlTask + ); MLDeployModelNodeRequest MLDeployModelNodeRequest = new MLDeployModelNodeRequest( - new MLDeployModelNodesRequest(nodeIds, deployModelInput) + new MLDeployModelNodesRequest(nodeIds, deployModelInput) ); assertNotNull(MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput()); assertEquals("modelId", MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getModelId()); assertEquals("taskId", MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getTaskId()); - assertEquals("modelContentHash", MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getModelContentHash()); + assertEquals( + "modelContentHash", + MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getModelContentHash() + ); assertEquals(3, MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getNodeCount().intValue()); - assertEquals("coordinatingNodeId", MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getCoordinatingNodeId()); - assertEquals(mlTask.getTaskId(), MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getMlTask().getTaskId()); + assertEquals( + "coordinatingNodeId", + MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getCoordinatingNodeId() + ); + assertEquals( + mlTask.getTaskId(), + MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getMlTask().getTaskId() + ); } @Test public void testConstructorSerialization3() throws IOException { MLDeployModelNodeRequest MLDeployModelNodeRequest = new MLDeployModelNodeRequest( - new MLDeployModelNodesRequest(localNode1, localNode2, localNode3) + new MLDeployModelNodesRequest(localNode1, localNode2, localNode3) ); MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().setModelId("modelIdSetDuringTest"); MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().setTaskId("taskIdSetDuringTest"); - MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().setModelContentHash("modelContentHashSetDuringTest"); + MLDeployModelNodeRequest + .getMLDeployModelNodesRequest() + .getMlDeployModelInput() + .setModelContentHash("modelContentHashSetDuringTest"); MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().setNodeCount(2); - MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().setCoordinatingNodeId("coordinatingNodeIdSetDuringTest"); + MLDeployModelNodeRequest + .getMLDeployModelNodesRequest() + .getMlDeployModelInput() + .setCoordinatingNodeId("coordinatingNodeIdSetDuringTest"); MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().setMlTask(mlTask); assertNotNull(MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput()); assertEquals("modelIdSetDuringTest", MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getModelId()); assertEquals("taskIdSetDuringTest", MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getTaskId()); - assertEquals("modelContentHashSetDuringTest", MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getModelContentHash()); + assertEquals( + "modelContentHashSetDuringTest", + MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getModelContentHash() + ); assertEquals(2, MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getNodeCount().intValue()); - assertEquals("coordinatingNodeIdSetDuringTest", MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getCoordinatingNodeId()); - assertEquals(mlTask.getTaskId(), MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getMlTask().getTaskId()); + assertEquals( + "coordinatingNodeIdSetDuringTest", + MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getCoordinatingNodeId() + ); + assertEquals( + mlTask.getTaskId(), + MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getMlTask().getTaskId() + ); } @Test public void testConstructorFromInputStream() throws IOException { - String [] nodeIds = {"id1", "id2", "id3"}; - MLDeployModelInput deployModelInput = new MLDeployModelInput("modelId", "taskId", "modelContentHash", 3, "coordinatingNodeId", true, mlTask); + String[] nodeIds = { "id1", "id2", "id3" }; + MLDeployModelInput deployModelInput = new MLDeployModelInput( + "modelId", + "taskId", + "modelContentHash", + 3, + "coordinatingNodeId", + true, + mlTask + ); MLDeployModelNodeRequest MLDeployModelNodeRequest = new MLDeployModelNodeRequest( - new MLDeployModelNodesRequest(nodeIds, deployModelInput) + new MLDeployModelNodesRequest(nodeIds, deployModelInput) ); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); MLDeployModelNodeRequest.writeTo(bytesStreamOutput); @@ -150,8 +208,10 @@ public void testConstructorFromInputStream() throws IOException { MLDeployModelNodeRequest parsedNodeRequest = new MLDeployModelNodeRequest(streamInput); assertNotNull(parsedNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput()); - assertEquals(MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getModelId(), - parsedNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getModelId()); + assertEquals( + MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getModelId(), + parsedNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getModelId() + ); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodesResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodesResponseTest.java index f3a146aa06..e05a18a242 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodesResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodesResponseTest.java @@ -1,5 +1,12 @@ package org.opensearch.ml.common.transport.deploy; +import static org.junit.Assert.assertEquals; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.*; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -10,18 +17,11 @@ import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; -import java.net.InetAddress; -import java.util.*; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; - @RunWith(MockitoJUnitRunner.class) public class MLDeployModelNodesResponseTest { @@ -33,7 +33,6 @@ public void setUp() throws Exception { clusterName = new ClusterName("clusterName"); } - @Test public void testSerializationDeserialization() throws IOException { List responseList = new ArrayList<>(); @@ -49,24 +48,24 @@ public void testSerializationDeserialization() throws IOException { public void testToXContent() throws IOException { List nodes = new ArrayList<>(); DiscoveryNode node1 = new DiscoveryNode( - "foo1", - "foo1", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT + "foo1", + "foo1", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT ); Map modelToDeployStatus1 = new HashMap<>(); modelToDeployStatus1.put("modelName:version1", "response"); nodes.add(new MLDeployModelNodeResponse(node1, modelToDeployStatus1)); DiscoveryNode node2 = new DiscoveryNode( - "foo2", - "foo2", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT + "foo2", + "foo2", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT ); Map modelToDeployStatus2 = new HashMap<>(); modelToDeployStatus2.put("modelName:version2", "response"); @@ -78,8 +77,8 @@ public void testToXContent() throws IOException { response.toXContent(builder, ToXContent.EMPTY_PARAMS); String jsonStr = builder.toString(); assertEquals( - "{\"foo1\":{\"stats\":{\"modelName:version1\":\"response\"}},\"foo2\":{\"stats\":{\"modelName:version2\":\"response\"}}}", - jsonStr + "{\"foo1\":{\"stats\":{\"modelName:version1\":\"response\"}},\"foo2\":{\"stats\":{\"modelName:version2\":\"response\"}}}", + jsonStr ); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelRequestTest.java index e2945dc212..b75193d970 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelRequestTest.java @@ -1,44 +1,43 @@ package org.opensearch.ml.common.transport.deploy; +import static org.junit.Assert.*; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Collections; +import java.util.function.Consumer; + import org.junit.Before; import org.junit.Test; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.*; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.search.SearchModule; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.util.Collections; -import java.util.function.Consumer; - -import static org.junit.Assert.*; - public class MLDeployModelRequestTest { private MLDeployModelRequest mlDeployModelRequest; @Before public void setUp() throws Exception { - mlDeployModelRequest = mlDeployModelRequest.builder(). - modelId("modelId"). - modelNodeIds(new String[]{"modelNodeIds"}). - async(true). - dispatchTask(true). - build(); + mlDeployModelRequest = mlDeployModelRequest + .builder() + .modelId("modelId") + .modelNodeIds(new String[] { "modelNodeIds" }) + .async(true) + .dispatchTask(true) + .build(); } @Test public void testValidateWithBuilder() { - MLDeployModelRequest request = mlDeployModelRequest.builder(). - modelId("modelId"). - build(); + MLDeployModelRequest request = mlDeployModelRequest.builder().modelId("modelId").build(); assertNull(request.validate()); } @@ -50,12 +49,13 @@ public void testValidateWithoutBuilder() { @Test public void validate_Exception_WithNullModelId() { - MLDeployModelRequest request = mlDeployModelRequest.builder(). - modelId(null). - modelNodeIds(new String[]{"modelNodeIds"}). - async(true). - dispatchTask(true). - build(); + MLDeployModelRequest request = mlDeployModelRequest + .builder() + .modelId(null) + .modelNodeIds(new String[] { "modelNodeIds" }) + .async(true) + .dispatchTask(true) + .build(); ActionRequestValidationException exception = request.validate(); assertEquals("Validation Failed: 1: ML model id can't be null;", exception.getMessage()); } @@ -69,7 +69,7 @@ public void writeTo_Success() throws IOException { request = new MLDeployModelRequest(bytesStreamOutput.bytes().streamInput()); assertEquals("modelId", request.getModelId()); - assertArrayEquals(new String[]{"modelNodeIds"}, request.getModelNodeIds()); + assertArrayEquals(new String[] { "modelNodeIds" }, request.getModelNodeIds()); assertTrue(request.isAsync()); assertTrue(request.isDispatchTask()); } @@ -92,9 +92,7 @@ public void writeTo(StreamOutput out) throws IOException { @Test public void fromActionRequest_Success_WithMLDeployModelRequest() { - MLDeployModelRequest request = mlDeployModelRequest.builder(). - modelId("modelId"). - build(); + MLDeployModelRequest request = mlDeployModelRequest.builder().modelId("modelId").build(); assertSame(mlDeployModelRequest.fromActionRequest(request), request); } @@ -124,27 +122,33 @@ public void testParse() throws Exception { String expectedInputStr = "{\"node_ids\":[\"modelNodeIds\"]}"; parseFromJsonString(modelId, expectedInputStr, parsedInput -> { assertEquals("modelId", parsedInput.getModelId()); - assertArrayEquals(new String [] {"modelNodeIds"}, parsedInput.getModelNodeIds()); + assertArrayEquals(new String[] { "modelNodeIds" }, parsedInput.getModelNodeIds()); assertFalse(parsedInput.isAsync()); - assertTrue(parsedInput.isDispatchTask());} - ); + assertTrue(parsedInput.isDispatchTask()); + }); } @Test public void testParseWithInvalidField() throws Exception { String modelId = "modelId"; - String withInvalidFieldInputStr = "{\"void\":\"void\", \"dispatchTask\":\"false\", \"async\":\"true\", \"node_ids\":[\"modelNodeIds\"]}"; + String withInvalidFieldInputStr = + "{\"void\":\"void\", \"dispatchTask\":\"false\", \"async\":\"true\", \"node_ids\":[\"modelNodeIds\"]}"; parseFromJsonString(modelId, withInvalidFieldInputStr, parsedInput -> { assertEquals("modelId", parsedInput.getModelId()); - assertArrayEquals(new String [] {"modelNodeIds"}, parsedInput.getModelNodeIds()); + assertArrayEquals(new String[] { "modelNodeIds" }, parsedInput.getModelNodeIds()); assertFalse(parsedInput.isAsync()); - assertTrue(parsedInput.isDispatchTask());} - ); + assertTrue(parsedInput.isDispatchTask()); + }); } private void parseFromJsonString(String modelId, String expectedInputStr, Consumer verify) throws Exception { - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), LoggingDeprecationHandler.INSTANCE, expectedInputStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + LoggingDeprecationHandler.INSTANCE, + expectedInputStr + ); parser.nextToken(); MLDeployModelRequest parsedInput = mlDeployModelRequest.parse(parser, modelId); verify.accept(parsedInput); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponseTest.java index 5b4c2f2cd3..10ad2d46cf 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponseTest.java @@ -1,5 +1,10 @@ package org.opensearch.ml.common.transport.deploy; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import java.io.IOException; + import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -8,11 +13,6 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.MLTaskType; -import java.io.IOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; - public class MLDeployModelResponseTest { private String taskId; @@ -50,7 +50,6 @@ public void testToXContent() throws IOException { assertNotNull(builder); String jsonStr = builder.toString(); // Verify the results - assertEquals("{\"task_id\":\"test_id\"," + "\"task_type\":\"DEPLOY_MODEL\"," + - "\"status\":\"test\"}", jsonStr); + assertEquals("{\"task_id\":\"test_id\"," + "\"task_type\":\"DEPLOY_MODEL\"," + "\"status\":\"test\"}", jsonStr); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskRequestTest.java index 3c5faa1559..8458c090e4 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskRequestTest.java @@ -1,5 +1,11 @@ package org.opensearch.ml.common.transport.execute; +import static org.junit.Assert.*; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -7,26 +13,8 @@ import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.dataframe.ColumnType; -import org.opensearch.ml.common.dataframe.DataFrame; -import org.opensearch.ml.common.dataframe.DataFrameBuilder; -import org.opensearch.ml.common.dataset.DataFrameInputDataset; -import org.opensearch.ml.common.dataset.MLInputDataType; -import org.opensearch.ml.common.dataset.MLInputDataset; -import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.input.Input; -import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.input.execute.metricscorrelation.MetricsCorrelationInput; -import org.opensearch.ml.common.input.parameter.clustering.KMeansParams; -import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; - -import static org.junit.Assert.*; public class MLExecuteTaskRequestTest { private Input exInput; @@ -37,9 +25,9 @@ public class MLExecuteTaskRequestTest { @Before public void setUp() { inputData = new ArrayList<>(); - inputData.add(new float[]{1.0f, 2.0f, 3.0f, 4.0f}); - inputData.add(new float[]{1.0f, 2.0f, 3.0f, 4.0f}); - inputData.add(new float[]{1.0f, 2.0f, 3.0f, 4.0f}); + inputData.add(new float[] { 1.0f, 2.0f, 3.0f, 4.0f }); + inputData.add(new float[] { 1.0f, 2.0f, 3.0f, 4.0f }); + inputData.add(new float[] { 1.0f, 2.0f, 3.0f, 4.0f }); exInput = MetricsCorrelationInput.builder().inputData(inputData).build(); } @@ -47,10 +35,7 @@ public void setUp() { @Test public void writeTo_Success() throws IOException { - MLExecuteTaskRequest request = MLExecuteTaskRequest.builder() - .functionName(FunctionName.METRICS_CORRELATION) - .input(exInput) - .build(); + MLExecuteTaskRequest request = MLExecuteTaskRequest.builder().functionName(FunctionName.METRICS_CORRELATION).input(exInput).build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); request.writeTo(bytesStreamOutput); request = new MLExecuteTaskRequest(bytesStreamOutput.bytes().streamInput()); @@ -62,10 +47,7 @@ public void writeTo_Success() throws IOException { @Test public void validate_Success() { - MLExecuteTaskRequest request = MLExecuteTaskRequest.builder() - .functionName(FunctionName.METRICS_CORRELATION) - .input(exInput) - .build(); + MLExecuteTaskRequest request = MLExecuteTaskRequest.builder().functionName(FunctionName.METRICS_CORRELATION).input(exInput).build(); assertNull(request.validate()); } @@ -74,17 +56,14 @@ public void validate_Success() { public void validate_Exception_NullFunctionNane() { exceptionRule.expect(NullPointerException.class); exceptionRule.expectMessage("functionName is marked non-null but is null"); - MLExecuteTaskRequest request = MLExecuteTaskRequest.builder() - .build(); + MLExecuteTaskRequest request = MLExecuteTaskRequest.builder().build(); request.validate(); } @Test public void validate_Exception_NullMLInput() { - MLExecuteTaskRequest request = MLExecuteTaskRequest.builder() - .functionName(FunctionName.METRICS_CORRELATION) - .build(); + MLExecuteTaskRequest request = MLExecuteTaskRequest.builder().functionName(FunctionName.METRICS_CORRELATION).build(); ActionRequestValidationException exception = request.validate(); assertEquals("Validation Failed: 1: ML input can't be null;", exception.getMessage()); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskResponseTest.java index e7aa0aa774..ec085d68e3 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskResponseTest.java @@ -1,5 +1,12 @@ package org.opensearch.ml.common.transport.execute; +import static org.junit.Assert.*; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.XContentFactory; @@ -10,88 +17,90 @@ import org.opensearch.ml.common.output.execute.metrics_correlation.MCorrModelTensors; import org.opensearch.ml.common.output.execute.metrics_correlation.MetricsCorrelationOutput; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -import static org.junit.Assert.*; - public class MLExecuteTaskResponseTest { @Test public void writeTo_Success() throws IOException { List outputs = new ArrayList<>(); - MCorrModelTensor mCorrModelTensor = MCorrModelTensor.builder() - .event_pattern(new float[]{1.0f, 2.0f, 3.0f}) - .event_window(new float[]{4.0f, 5.0f, 6.0f}) - .suspected_metrics(new long[]{1, 2}) - .build(); + MCorrModelTensor mCorrModelTensor = MCorrModelTensor + .builder() + .event_pattern(new float[] { 1.0f, 2.0f, 3.0f }) + .event_window(new float[] { 4.0f, 5.0f, 6.0f }) + .suspected_metrics(new long[] { 1, 2 }) + .build(); List mlModelTensors = Arrays.asList(mCorrModelTensor); MCorrModelTensors modelTensors = MCorrModelTensors.builder().mCorrModelTensors(mlModelTensors).build(); outputs.add(modelTensors); MetricsCorrelationOutput output = MetricsCorrelationOutput.builder().modelOutput(outputs).build(); - MLExecuteTaskResponse response = MLExecuteTaskResponse.builder() - .functionName(FunctionName.METRICS_CORRELATION) - .output(output) - .build(); + MLExecuteTaskResponse response = MLExecuteTaskResponse + .builder() + .functionName(FunctionName.METRICS_CORRELATION) + .output(output) + .build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); response.writeTo(bytesStreamOutput); response = new MLExecuteTaskResponse(bytesStreamOutput.bytes().streamInput()); - MetricsCorrelationOutput mcorrOutputTest = (MetricsCorrelationOutput)response.getOutput(); + MetricsCorrelationOutput mcorrOutputTest = (MetricsCorrelationOutput) response.getOutput(); assertEquals(1, mcorrOutputTest.getModelOutput().size()); MCorrModelTensors testmodelTensors = mcorrOutputTest.getModelOutput().get(0); assertEquals(1, testmodelTensors.getMCorrModelTensors().size()); MCorrModelTensor testmodelTensor = testmodelTensors.getMCorrModelTensors().get(0); float[] events = testmodelTensor.getEvent_pattern(); long[] metrics = testmodelTensor.getSuspected_metrics(); - assertArrayEquals(new float[]{1.0f, 2.0f, 3.0f}, events, 0.001f); - assertArrayEquals(new long[]{1, 2}, metrics); + assertArrayEquals(new float[] { 1.0f, 2.0f, 3.0f }, events, 0.001f); + assertArrayEquals(new long[] { 1, 2 }, metrics); } @Test public void fromActionResponse_WithMLPredictionTaskResponse() { List outputs = new ArrayList<>(); - MCorrModelTensor mCorrModelTensor = MCorrModelTensor.builder() - .event_pattern(new float[]{1.0f, 2.0f, 3.0f}) - .event_window(new float[]{4.0f, 5.0f, 6.0f}) - .suspected_metrics(new long[]{1, 2}) - .build(); + MCorrModelTensor mCorrModelTensor = MCorrModelTensor + .builder() + .event_pattern(new float[] { 1.0f, 2.0f, 3.0f }) + .event_window(new float[] { 4.0f, 5.0f, 6.0f }) + .suspected_metrics(new long[] { 1, 2 }) + .build(); List mlModelTensors = Arrays.asList(mCorrModelTensor); MCorrModelTensors modelTensors = MCorrModelTensors.builder().mCorrModelTensors(mlModelTensors).build(); outputs.add(modelTensors); MetricsCorrelationOutput output = MetricsCorrelationOutput.builder().modelOutput(outputs).build(); - MLExecuteTaskResponse response = MLExecuteTaskResponse.builder() - .functionName(FunctionName.METRICS_CORRELATION) - .output(output) - .build(); + MLExecuteTaskResponse response = MLExecuteTaskResponse + .builder() + .functionName(FunctionName.METRICS_CORRELATION) + .output(output) + .build(); assertSame(response, MLExecuteTaskResponse.fromActionResponse(response)); } @Test public void toXContentTest() throws IOException { List outputs = new ArrayList<>(); - MCorrModelTensor mCorrModelTensor = MCorrModelTensor.builder() - .event_pattern(new float[]{1.0f, 2.0f, 3.0f}) - .event_window(new float[]{4.0f, 5.0f, 6.0f}) - .suspected_metrics(new long[]{1, 2}) - .build(); + MCorrModelTensor mCorrModelTensor = MCorrModelTensor + .builder() + .event_pattern(new float[] { 1.0f, 2.0f, 3.0f }) + .event_window(new float[] { 4.0f, 5.0f, 6.0f }) + .suspected_metrics(new long[] { 1, 2 }) + .build(); List mlModelTensors = Arrays.asList(mCorrModelTensor); MCorrModelTensors modelTensors = MCorrModelTensors.builder().mCorrModelTensors(mlModelTensors).build(); outputs.add(modelTensors); MetricsCorrelationOutput output = MetricsCorrelationOutput.builder().modelOutput(outputs).build(); - MLExecuteTaskResponse response = MLExecuteTaskResponse.builder() - .functionName(FunctionName.METRICS_CORRELATION) - .output(output) - .build(); + MLExecuteTaskResponse response = MLExecuteTaskResponse + .builder() + .functionName(FunctionName.METRICS_CORRELATION) + .output(output) + .build(); XContentBuilder builder = XContentFactory.jsonBuilder(); response.toXContent(builder, ToXContent.EMPTY_PARAMS); assertNotNull(builder); String jsonStr = builder.toString(); - assertEquals("{\"inference_results\":[{" + - "\"event_window\":[4.0,5.0,6.0]," + - "\"event_pattern\":[1.0,2.0,3.0]," + - "\"suspected_metrics\":[1,2]}]}", jsonStr); + assertEquals( + "{\"inference_results\":[{" + + "\"event_window\":[4.0,5.0,6.0]," + + "\"event_pattern\":[1.0,2.0,3.0]," + + "\"suspected_metrics\":[1,2]}]}", + jsonStr + ); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardInputTest.java index d9b3fc77c4..7720f478e7 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardInputTest.java @@ -1,5 +1,15 @@ package org.opensearch.ml.common.transport.forward; +import static org.junit.Assert.*; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.function.Consumer; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -21,79 +31,74 @@ import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; -import java.io.IOException; -import java.time.Instant; -import java.time.temporal.ChronoUnit; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.function.Consumer; - -import static org.junit.Assert.*; - - @RunWith(MockitoJUnitRunner.class) public class MLForwardInputTest { private MLForwardInput forwardInput; private final FunctionName functionName = FunctionName.KMEANS; - @Before public void setUp() throws Exception { Instant time = Instant.now(); - MLTask mlTask = MLTask.builder() - .taskId("mlTaskTaskId") - .modelId("mlTaskModelId") - .taskType(MLTaskType.PREDICTION) - .functionName(functionName) - .state(MLTaskState.RUNNING) - .inputType(MLInputDataType.DATA_FRAME) - .workerNodes(Arrays.asList("mlTaskNode1")) - .progress(0.0f) - .outputIndex("test_index") - .error("test_error") - .createTime(time.minus(1, ChronoUnit.MINUTES)) - .lastUpdateTime(time) - .build(); - - DataFrame dataFrame = DataFrameBuilder.load(Collections.singletonList(new HashMap() {{ - put("key1", 2.0D); - }})); - MLInput modelInput = MLInput.builder() - .algorithm(FunctionName.KMEANS) - .parameters(KMeansParams.builder().centroids(1).build()) - .inputDataset(DataFrameInputDataset.builder().dataFrame(dataFrame).build()) - .build(); - MLModelConfig config = TextEmbeddingModelConfig.builder() - .modelType("testModelType") - .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") - .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) - .embeddingDimension(100) - .build(); - MLRegisterModelInput registerModelInput = MLRegisterModelInput.builder() - .functionName(functionName) - .modelName("testModelName") - .version("testModelVersion") - .modelGroupId("mockModelGroupId") - .url("url") - .modelFormat(MLModelFormat.ONNX) - .modelConfig(config) - .deployModel(true) - .modelNodeIds(new String[]{"modelNodeIds"}) - .build(); - - forwardInput = MLForwardInput.builder() - .taskId("forwardInputTaskId") - .modelId("forwardInputModelId") - .workerNodeId("forwardInputWorkerNodeId") - .requestType(MLForwardRequestType.DEPLOY_MODEL_DONE) - .mlTask(mlTask) - .modelInput(modelInput) - .error("forwardInputError") - .workerNodes(new String [] {"forwardInputNodeId1", "forwardInputNodeId2", "forwardInputNodeId3"}) - .registerModelInput(registerModelInput) - .build(); + MLTask mlTask = MLTask + .builder() + .taskId("mlTaskTaskId") + .modelId("mlTaskModelId") + .taskType(MLTaskType.PREDICTION) + .functionName(functionName) + .state(MLTaskState.RUNNING) + .inputType(MLInputDataType.DATA_FRAME) + .workerNodes(Arrays.asList("mlTaskNode1")) + .progress(0.0f) + .outputIndex("test_index") + .error("test_error") + .createTime(time.minus(1, ChronoUnit.MINUTES)) + .lastUpdateTime(time) + .build(); + + DataFrame dataFrame = DataFrameBuilder.load(Collections.singletonList(new HashMap() { + { + put("key1", 2.0D); + } + })); + MLInput modelInput = MLInput + .builder() + .algorithm(FunctionName.KMEANS) + .parameters(KMeansParams.builder().centroids(1).build()) + .inputDataset(DataFrameInputDataset.builder().dataFrame(dataFrame).build()) + .build(); + MLModelConfig config = TextEmbeddingModelConfig + .builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); + MLRegisterModelInput registerModelInput = MLRegisterModelInput + .builder() + .functionName(functionName) + .modelName("testModelName") + .version("testModelVersion") + .modelGroupId("mockModelGroupId") + .url("url") + .modelFormat(MLModelFormat.ONNX) + .modelConfig(config) + .deployModel(true) + .modelNodeIds(new String[] { "modelNodeIds" }) + .build(); + + forwardInput = MLForwardInput + .builder() + .taskId("forwardInputTaskId") + .modelId("forwardInputModelId") + .workerNodeId("forwardInputWorkerNodeId") + .requestType(MLForwardRequestType.DEPLOY_MODEL_DONE) + .mlTask(mlTask) + .modelInput(modelInput) + .error("forwardInputError") + .workerNodes(new String[] { "forwardInputNodeId1", "forwardInputNodeId2", "forwardInputNodeId3" }) + .registerModelInput(registerModelInput) + .build(); } @Test @@ -104,7 +109,6 @@ public void readInputStream_Success() throws IOException { }); } - @Test public void readInputStream_SuccessWithNullFields() throws IOException { forwardInput.setMlTask(null); @@ -117,7 +121,6 @@ public void readInputStream_SuccessWithNullFields() throws IOException { }); } - private void readInputStream(MLForwardInput input, Consumer verify) throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); input.writeTo(bytesStreamOutput); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardRequestTest.java index b0eabfcb83..65815c921a 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardRequestTest.java @@ -1,5 +1,15 @@ package org.opensearch.ml.common.transport.forward; +import static org.junit.Assert.*; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -23,16 +33,6 @@ import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.time.Instant; -import java.time.temporal.ChronoUnit; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; - -import static org.junit.Assert.*; - @RunWith(MockitoJUnitRunner.class) public class MLForwardRequestTest { @@ -42,70 +42,74 @@ public class MLForwardRequestTest { private MLRegisterModelInput registerModelInput; private final FunctionName functionName = FunctionName.KMEANS; - @Before public void setUp() throws Exception { Instant time = Instant.now(); - mlTask = MLTask.builder() - .taskId("mlTaskTaskId") - .modelId("mlTaskModelId") - .taskType(MLTaskType.PREDICTION) - .functionName(functionName) - .state(MLTaskState.RUNNING) - .inputType(MLInputDataType.DATA_FRAME) - .workerNodes(Arrays.asList("mlTaskNode1")) - .progress(0.0f) - .outputIndex("test_index") - .error("test_error") - .createTime(time.minus(1, ChronoUnit.MINUTES)) - .lastUpdateTime(time) - .build(); - - DataFrame dataFrame = DataFrameBuilder.load(Collections.singletonList(new HashMap() {{ - put("key1", 2.0D); - }})); - modelInput = MLInput.builder() - .algorithm(FunctionName.KMEANS) - .parameters(KMeansParams.builder().centroids(1).build()) - .inputDataset(DataFrameInputDataset.builder().dataFrame(dataFrame).build()) - .build(); - MLModelConfig config = TextEmbeddingModelConfig.builder() - .modelType("testModelType") - .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") - .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) - .embeddingDimension(100) - .build(); - registerModelInput = MLRegisterModelInput.builder() - .functionName(functionName) - .modelName("testModelName") - .version("testModelVersion") - .modelGroupId("modelGroupId") - .url("url") - .modelFormat(MLModelFormat.ONNX) - .modelConfig(config) - .deployModel(true) - .modelNodeIds(new String[]{"modelNodeIds" }) - .build(); - - forwardInput = MLForwardInput.builder() - .taskId("forwardInputTaskId") - .modelId("forwardInputModelId") - .workerNodeId("forwardInputWorkerNodeId") - .requestType(MLForwardRequestType.DEPLOY_MODEL_DONE) - .mlTask(mlTask) - .modelInput(modelInput) - .error("forwardInputError") - .workerNodes(new String [] {"forwardInputNodeId1", "forwardInputNodeId2", "forwardInputNodeId3"}) - .registerModelInput(registerModelInput) - .build(); + mlTask = MLTask + .builder() + .taskId("mlTaskTaskId") + .modelId("mlTaskModelId") + .taskType(MLTaskType.PREDICTION) + .functionName(functionName) + .state(MLTaskState.RUNNING) + .inputType(MLInputDataType.DATA_FRAME) + .workerNodes(Arrays.asList("mlTaskNode1")) + .progress(0.0f) + .outputIndex("test_index") + .error("test_error") + .createTime(time.minus(1, ChronoUnit.MINUTES)) + .lastUpdateTime(time) + .build(); + + DataFrame dataFrame = DataFrameBuilder.load(Collections.singletonList(new HashMap() { + { + put("key1", 2.0D); + } + })); + modelInput = MLInput + .builder() + .algorithm(FunctionName.KMEANS) + .parameters(KMeansParams.builder().centroids(1).build()) + .inputDataset(DataFrameInputDataset.builder().dataFrame(dataFrame).build()) + .build(); + MLModelConfig config = TextEmbeddingModelConfig + .builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); + registerModelInput = MLRegisterModelInput + .builder() + .functionName(functionName) + .modelName("testModelName") + .version("testModelVersion") + .modelGroupId("modelGroupId") + .url("url") + .modelFormat(MLModelFormat.ONNX) + .modelConfig(config) + .deployModel(true) + .modelNodeIds(new String[] { "modelNodeIds" }) + .build(); + + forwardInput = MLForwardInput + .builder() + .taskId("forwardInputTaskId") + .modelId("forwardInputModelId") + .workerNodeId("forwardInputWorkerNodeId") + .requestType(MLForwardRequestType.DEPLOY_MODEL_DONE) + .mlTask(mlTask) + .modelInput(modelInput) + .error("forwardInputError") + .workerNodes(new String[] { "forwardInputNodeId1", "forwardInputNodeId2", "forwardInputNodeId3" }) + .registerModelInput(registerModelInput) + .build(); } @Test public void writeTo_Success() throws IOException { - MLForwardRequest request = MLForwardRequest.builder() - .forwardInput(forwardInput) - .build(); + MLForwardRequest request = MLForwardRequest.builder().forwardInput(forwardInput).build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); request.writeTo(bytesStreamOutput); request = new MLForwardRequest(bytesStreamOutput.bytes().streamInput()); @@ -114,7 +118,10 @@ public void writeTo_Success() throws IOException { assertEquals("forwardInputWorkerNodeId", request.getForwardInput().getWorkerNodeId()); assertEquals(MLForwardRequestType.DEPLOY_MODEL_DONE, request.getForwardInput().getRequestType()); assertEquals("forwardInputError", request.getForwardInput().getError()); - assertArrayEquals(new String [] {"forwardInputNodeId1", "forwardInputNodeId2", "forwardInputNodeId3"}, request.getForwardInput().getWorkerNodes()); + assertArrayEquals( + new String[] { "forwardInputNodeId1", "forwardInputNodeId2", "forwardInputNodeId3" }, + request.getForwardInput().getWorkerNodes() + ); assertEquals(mlTask.getTaskId(), request.getForwardInput().getMlTask().getTaskId()); assertEquals(modelInput.getAlgorithm().toString(), request.getForwardInput().getModelInput().getAlgorithm().toString()); assertEquals(registerModelInput.getModelName(), request.getForwardInput().getRegisterModelInput().getModelName()); @@ -122,17 +129,14 @@ public void writeTo_Success() throws IOException { @Test public void validate_Success() { - MLForwardRequest request = MLForwardRequest.builder() - .forwardInput(forwardInput) - .build(); + MLForwardRequest request = MLForwardRequest.builder().forwardInput(forwardInput).build(); assertNull(request.validate()); } @Test public void validate_Exception_NullMLInput() { - MLForwardRequest request = MLForwardRequest.builder() - .build(); + MLForwardRequest request = MLForwardRequest.builder().build(); ActionRequestValidationException exception = request.validate(); assertEquals("Validation Failed: 1: ML input can't be null;", exception.getMessage()); } @@ -141,9 +145,7 @@ public void validate_Exception_NullMLInput() { // MLForwardInput check its parameters when created, so exception is not thrown here public void validate_Exception_NullMLModelName() { forwardInput.setTaskId(null); - MLForwardRequest request = MLForwardRequest.builder() - .forwardInput(forwardInput) - .build(); + MLForwardRequest request = MLForwardRequest.builder().forwardInput(forwardInput).build(); assertNull(request.validate()); assertNull(request.getForwardInput().getTaskId()); @@ -151,19 +153,14 @@ public void validate_Exception_NullMLModelName() { @Test public void fromActionRequest_Success_WithMLForwardRequest() { - MLForwardRequest request = MLForwardRequest.builder() - .forwardInput(forwardInput) - .build(); + MLForwardRequest request = MLForwardRequest.builder().forwardInput(forwardInput).build(); assertSame(MLForwardRequest.fromActionRequest(request), request); } - @Test public void fromActionRequest_Success_WithNonMLForwardRequest() { - MLForwardRequest request = MLForwardRequest.builder() - .forwardInput(forwardInput) - .build(); + MLForwardRequest request = MLForwardRequest.builder().forwardInput(forwardInput).build(); ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { @@ -179,8 +176,14 @@ public void writeTo(StreamOutput out) throws IOException { assertNotSame(result, request); assertEquals(request.getForwardInput().getTaskId(), result.getForwardInput().getTaskId()); assertEquals(request.getForwardInput().getMlTask().getTaskId(), result.getForwardInput().getMlTask().getTaskId()); - assertEquals(request.getForwardInput().getModelInput().getAlgorithm().toString(), result.getForwardInput().getModelInput().getAlgorithm().toString()); - assertEquals(request.getForwardInput().getRegisterModelInput().getModelName(), result.getForwardInput().getRegisterModelInput().getModelName()); + assertEquals( + request.getForwardInput().getModelInput().getAlgorithm().toString(), + result.getForwardInput().getModelInput().getAlgorithm().toString() + ); + assertEquals( + request.getForwardInput().getRegisterModelInput().getModelName(), + result.getForwardInput().getRegisterModelInput().getModelName() + ); } @Test(expected = UncheckedIOException.class) diff --git a/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardResponseTest.java index 60992ae0d0..9748a65b8e 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardResponseTest.java @@ -1,26 +1,26 @@ package org.opensearch.ml.common.transport.forward; +import static org.junit.Assert.*; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Collections; +import java.util.HashMap; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.junit.MockitoJUnitRunner; -import org.opensearch.core.action.ActionResponse; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataframe.DataFrameBuilder; import org.opensearch.ml.common.output.MLPredictionOutput; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.util.Collections; -import java.util.HashMap; - -import static org.junit.Assert.*; - @RunWith(MockitoJUnitRunner.class) public class MLForwardResponseTest { @@ -30,14 +30,12 @@ public class MLForwardResponseTest { @Before public void setUp() throws Exception { status = "test"; - DataFrame dataFrame = DataFrameBuilder.load(Collections.singletonList(new HashMap() {{ - put("key1", 2.0D); - }})); - predictionOutput = MLPredictionOutput.builder() - .status("Success") - .predictionResult(dataFrame) - .taskId("taskId") - .build(); + DataFrame dataFrame = DataFrameBuilder.load(Collections.singletonList(new HashMap() { + { + put("key1", 2.0D); + } + })); + predictionOutput = MLPredictionOutput.builder().status("Success").predictionResult(dataFrame).taskId("taskId").build(); } @Test @@ -63,7 +61,10 @@ public void testToXContent() throws IOException { assertNotNull(builder); String jsonStr = builder.toString(); // Verify the results - assertEquals("{\"result\":{\"task_id\":\"taskId\",\"status\":\"Success\",\"prediction_result\":{\"column_metas\":[{\"name\":\"key1\",\"column_type\":\"DOUBLE\"}],\"rows\":[{\"values\":[{\"column_type\":\"DOUBLE\",\"value\":2.0}]}]}}}", jsonStr); + assertEquals( + "{\"result\":{\"task_id\":\"taskId\",\"status\":\"Success\",\"prediction_result\":{\"column_metas\":[{\"name\":\"key1\",\"column_type\":\"DOUBLE\"}],\"rows\":[{\"values\":[{\"column_type\":\"DOUBLE\",\"value\":2.0}]}]}}}", + jsonStr + ); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelDeleteRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelDeleteRequestTest.java index 533b96ecdf..109eafde95 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelDeleteRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelDeleteRequestTest.java @@ -5,6 +5,14 @@ package org.opensearch.ml.common.transport.model; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + +import java.io.IOException; +import java.io.UncheckedIOException; + import org.junit.Before; import org.junit.Test; import org.opensearch.action.ActionRequest; @@ -12,14 +20,6 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; - public class MLModelDeleteRequestTest { private String modelId; @@ -30,8 +30,7 @@ public void setUp() { @Test public void writeTo_Success() throws IOException { - MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder() - .modelId(modelId).build(); + MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder().modelId(modelId).build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); mlModelDeleteRequest.writeTo(bytesStreamOutput); MLModelDeleteRequest parsedModel = new MLModelDeleteRequest(bytesStreamOutput.bytes().streamInput()); @@ -40,8 +39,7 @@ public void writeTo_Success() throws IOException { @Test public void validate_Success() { - MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder() - .modelId(modelId).build(); + MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder().modelId(modelId).build(); ActionRequestValidationException actionRequestValidationException = mlModelDeleteRequest.validate(); assertNull(actionRequestValidationException); } @@ -56,8 +54,7 @@ public void validate_Exception_NullModelId() { @Test public void fromActionRequest_Success() { - MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder() - .modelId(modelId).build(); + MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder().modelId(modelId).build(); ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { @@ -90,11 +87,9 @@ public void writeTo(StreamOutput out) throws IOException { MLModelDeleteRequest.fromActionRequest(actionRequest); } - @Test public void fromActionRequestWithModelDeleteRequest_Success() { - MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder() - .modelId(modelId).build(); + MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder().modelId(modelId).build(); MLModelDeleteRequest mlModelDeleteRequestFromActionRequest = MLModelDeleteRequest.fromActionRequest(mlModelDeleteRequest); assertSame(mlModelDeleteRequest, mlModelDeleteRequestFromActionRequest); assertEquals(mlModelDeleteRequest.getModelId(), mlModelDeleteRequestFromActionRequest.getModelId()); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetRequestTest.java index 97f784d868..4a16bf9347 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetRequestTest.java @@ -5,6 +5,14 @@ package org.opensearch.ml.common.transport.model; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + +import java.io.IOException; +import java.io.UncheckedIOException; + import org.junit.Before; import org.junit.Test; import org.opensearch.action.ActionRequest; @@ -12,14 +20,6 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; - public class MLModelGetRequestTest { private String modelId; @@ -30,8 +30,7 @@ public void setUp() { @Test public void writeTo_Success() throws IOException { - MLModelGetRequest mlModelGetRequest = MLModelGetRequest.builder() - .modelId(modelId).build(); + MLModelGetRequest mlModelGetRequest = MLModelGetRequest.builder().modelId(modelId).build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); mlModelGetRequest.writeTo(bytesStreamOutput); MLModelGetRequest parsedModel = new MLModelGetRequest(bytesStreamOutput.bytes().streamInput()); @@ -48,17 +47,16 @@ public void validate_Exception_NullModelId() { @Test public void fromActionRequest_Success() { - MLModelGetRequest mlModelGetRequest = MLModelGetRequest.builder() - .modelId(modelId).build(); + MLModelGetRequest mlModelGetRequest = MLModelGetRequest.builder().modelId(modelId).build(); ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { - return null; + return null; } @Override public void writeTo(StreamOutput out) throws IOException { - mlModelGetRequest.writeTo(out); + mlModelGetRequest.writeTo(out); } }; MLModelGetRequest result = MLModelGetRequest.fromActionRequest(actionRequest); @@ -71,12 +69,12 @@ public void fromActionRequest_IOException() { ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { - return null; + return null; } @Override public void writeTo(StreamOutput out) throws IOException { - throw new IOException("test"); + throw new IOException("test"); } }; MLModelGetRequest.fromActionRequest(actionRequest); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetResponseTest.java index 4574cb72d5..cbd4cd6133 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetResponseTest.java @@ -13,12 +13,13 @@ import java.io.IOException; import java.io.UncheckedIOException; + import org.junit.Before; import org.junit.Test; -import org.opensearch.core.action.ActionResponse; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; @@ -32,14 +33,15 @@ public class MLModelGetResponseTest { @Before public void setUp() { - mlModel = MLModel.builder() - .name("model") - .algorithm(FunctionName.KMEANS) - .version("1.0.0") - .content("content") - .user(new User()) - .modelState(MLModelState.TRAINED) - .build(); + mlModel = MLModel + .builder() + .name("model") + .algorithm(FunctionName.KMEANS) + .version("1.0.0") + .content("content") + .user(new User()) + .modelState(MLModelState.TRAINED) + .build(); } @Test @@ -63,11 +65,14 @@ public void toXContentTest() throws IOException { mlModelGetResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); assertNotNull(builder); String jsonStr = builder.toString(); - assertEquals("{\"name\":\"model\"," + - "\"algorithm\":\"KMEANS\"," + - "\"model_version\":\"1.0.0\"," + - "\"model_content\":\"content\"," + - "\"user\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null},\"model_state\":\"TRAINED\"}", jsonStr); + assertEquals( + "{\"name\":\"model\"," + + "\"algorithm\":\"KMEANS\"," + + "\"model_version\":\"1.0.0\"," + + "\"model_content\":\"content\"," + + "\"user\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null},\"model_state\":\"TRAINED\"}", + jsonStr + ); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java index d656ae5134..9014f0ec49 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java @@ -22,11 +22,11 @@ import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; @@ -35,116 +35,104 @@ import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.HttpConnector; import org.opensearch.ml.common.controller.MLRateLimiter; -import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; -import org.opensearch.search.SearchModule; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; +import org.opensearch.search.SearchModule; public class MLUpdateModelInputTest { private MLUpdateModelInput updateModelInput; - private final String expectedInputStr = "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":" - + - "\"2\",\"model_group_id\":\"modelGroupId\",\"is_enabled\":false,\"rate_limiter\":" + - "{\"limit\":\"1\",\"unit\":\"MILLISECONDS\"},\"model_config\":" + - "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" - + - "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"updated_connector\":" + - "{\"name\":\"test\",\"version\":\"1\",\"protocol\":\"http\",\"parameters\":{\"param1\":\"value1\"},\"credential\":" - + - "{\"api_key\":\"credential_value\"},\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":" - + - "\"https://api.openai.com/v1/chat/completions\",\"headers\":{\"Authorization\":\"Bearer ${credential.api_key}\"},\"request_body\":" - + - "\"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"}]},\"connector_id\":" - + - "\"test-connector_id\",\"connector\":{\"description\":\"updated description\",\"version\":\"1\"},\"last_updated_time\":1}"; - - private final String expectedOutputStrForUpdateRequestDoc = "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":" - + - "\"2\",\"model_group_id\":\"modelGroupId\",\"is_enabled\":false,\"rate_limiter\":" + - "{\"limit\":\"1\",\"unit\":\"MILLISECONDS\"},\"model_config\":" + - "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" - + - "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector\":" + - "{\"name\":\"test\",\"version\":\"1\",\"protocol\":\"http\",\"parameters\":{\"param1\":\"value1\"},\"credential\":" - + - "{\"api_key\":\"credential_value\"},\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":" - + - "\"https://api.openai.com/v1/chat/completions\",\"headers\":{\"Authorization\":\"Bearer ${credential.api_key}\"},\"request_body\":" - + - "\"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"}]},\"connector_id\":" - + - "\"test-connector_id\",\"last_updated_time\":1}"; + private final String expectedInputStr = + "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":" + + "\"2\",\"model_group_id\":\"modelGroupId\",\"is_enabled\":false,\"rate_limiter\":" + + "{\"limit\":\"1\",\"unit\":\"MILLISECONDS\"},\"model_config\":" + + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"updated_connector\":" + + "{\"name\":\"test\",\"version\":\"1\",\"protocol\":\"http\",\"parameters\":{\"param1\":\"value1\"},\"credential\":" + + "{\"api_key\":\"credential_value\"},\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":" + + "\"https://api.openai.com/v1/chat/completions\",\"headers\":{\"Authorization\":\"Bearer ${credential.api_key}\"},\"request_body\":" + + "\"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"}]},\"connector_id\":" + + "\"test-connector_id\",\"connector\":{\"description\":\"updated description\",\"version\":\"1\"},\"last_updated_time\":1}"; + + private final String expectedOutputStrForUpdateRequestDoc = + "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":" + + "\"2\",\"model_group_id\":\"modelGroupId\",\"is_enabled\":false,\"rate_limiter\":" + + "{\"limit\":\"1\",\"unit\":\"MILLISECONDS\"},\"model_config\":" + + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector\":" + + "{\"name\":\"test\",\"version\":\"1\",\"protocol\":\"http\",\"parameters\":{\"param1\":\"value1\"},\"credential\":" + + "{\"api_key\":\"credential_value\"},\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":" + + "\"https://api.openai.com/v1/chat/completions\",\"headers\":{\"Authorization\":\"Bearer ${credential.api_key}\"},\"request_body\":" + + "\"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"}]},\"connector_id\":" + + "\"test-connector_id\",\"last_updated_time\":1}"; private final String expectedOutputStr = "{\"model_id\":null,\"name\":\"name\",\"description\":\"description\",\"model_group_id\":" - + - "\"modelGroupId\",\"is_enabled\":false,\"rate_limiter\":" + - "{\"limit\":\"1\",\"unit\":\"MILLISECONDS\"},\"model_config\":" + - "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" - + - "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":" + - "\"test-connector_id\"}"; + + "\"modelGroupId\",\"is_enabled\":false,\"rate_limiter\":" + + "{\"limit\":\"1\",\"unit\":\"MILLISECONDS\"},\"model_config\":" + + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":" + + "\"test-connector_id\"}"; @Rule public ExpectedException exceptionRule = ExpectedException.none(); @Before public void setUp() throws Exception { - MLModelConfig config = TextEmbeddingModelConfig.builder() - .modelType("testModelType") - .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") - .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) - .embeddingDimension(100) - .build(); + MLModelConfig config = TextEmbeddingModelConfig + .builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); Connector updatedConnector = HttpConnector - .builder() - .name("test") - .protocol("http") - .version("1") - .credential(Map.of("api_key", "credential_value")) - .parameters(Map.of("param1", "value1")) - .actions( - Arrays - .asList( - ConnectorAction - .builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("POST") - .url("https://api.openai.com/v1/chat/completions") - .headers(Map.of("Authorization", "Bearer ${credential.api_key}")) - .requestBody( - "{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }") - .build())) - .build(); + .builder() + .name("test") + .protocol("http") + .version("1") + .credential(Map.of("api_key", "credential_value")) + .parameters(Map.of("param1", "value1")) + .actions( + Arrays + .asList( + ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("https://api.openai.com/v1/chat/completions") + .headers(Map.of("Authorization", "Bearer ${credential.api_key}")) + .requestBody("{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }") + .build() + ) + ) + .build(); MLCreateConnectorInput updateContent = MLCreateConnectorInput - .builder() - .updateConnector(true) - .version("1") - .description("updated description") - .build(); - - MLRateLimiter rateLimiter = MLRateLimiter.builder() - .limit("1") - .unit(TimeUnit.MILLISECONDS) - .build(); - - updateModelInput = MLUpdateModelInput.builder() - .modelId("test-model_id") - .modelGroupId("modelGroupId") - .version("2") - .name("name") - .description("description") - .isEnabled(false) - .rateLimiter(rateLimiter) - .modelConfig(config) - .updatedConnector(updatedConnector) - .connectorId("test-connector_id") - .connector(updateContent) - .lastUpdateTime(Instant.ofEpochMilli(1)) - .build(); + .builder() + .updateConnector(true) + .version("1") + .description("updated description") + .build(); + + MLRateLimiter rateLimiter = MLRateLimiter.builder().limit("1").unit(TimeUnit.MILLISECONDS).build(); + + updateModelInput = MLUpdateModelInput + .builder() + .modelId("test-model_id") + .modelGroupId("modelGroupId") + .version("2") + .name("name") + .description("description") + .isEnabled(false) + .rateLimiter(rateLimiter) + .modelConfig(config) + .updatedConnector(updatedConnector) + .connectorId("test-connector_id") + .connector(updateContent) + .lastUpdateTime(Instant.ofEpochMilli(1)) + .build(); } @Test @@ -158,9 +146,7 @@ public void readInputStreamSuccess() throws IOException { @Test public void readInputStreamSuccessWithNullFields() throws IOException { updateModelInput.setModelConfig(null); - readInputStream(updateModelInput, parsedInput -> { - assertNull(parsedInput.getModelConfig()); - }); + readInputStream(updateModelInput, parsedInput -> { assertNull(parsedInput.getModelConfig()); }); } @Test @@ -172,29 +158,25 @@ public void testToXContent() throws Exception { @Test public void testToXContentIncomplete() throws Exception { String expectedIncompleteInputStr = "{\"model_id\":\"test-model_id\"}"; - updateModelInput = MLUpdateModelInput.builder() - .modelId("test-model_id").build(); + updateModelInput = MLUpdateModelInput.builder().modelId("test-model_id").build(); String jsonStr = serializationWithToXContent(updateModelInput); assertEquals(expectedIncompleteInputStr, jsonStr); } @Test public void parseSuccess() throws Exception { - testParseFromJsonString(expectedInputStr, parsedInput -> { - assertEquals("name", parsedInput.getName()); - }); + testParseFromJsonString(expectedInputStr, parsedInput -> { assertEquals("name", parsedInput.getName()); }); } @Test public void parseWithNullFieldWithoutModel() throws Exception { exceptionRule.expect(IllegalStateException.class); - String expectedInputStrWithNullField = "{\"model_id\":\"test-model_id\",\"name\":null,\"description\":\"description\",\"model_version\":" - + - "\"2\",\"model_group_id\":\"modelGroupId\",\"is_enabled\":false,\"rate_limiter\":" + - "{\"limit\":\"1\",\"unit\":\"MILLISECONDS\"},\"model_config\":" + - "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" - + - "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\"}"; + String expectedInputStrWithNullField = + "{\"model_id\":\"test-model_id\",\"name\":null,\"description\":\"description\",\"model_version\":" + + "\"2\",\"model_group_id\":\"modelGroupId\",\"is_enabled\":false,\"rate_limiter\":" + + "{\"limit\":\"1\",\"unit\":\"MILLISECONDS\"},\"model_config\":" + + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\"}"; testParseFromJsonString(expectedInputStrWithNullField, parsedInput -> { try { assertEquals(expectedOutputStr, serializationWithToXContent(parsedInput)); @@ -206,22 +188,17 @@ public void parseWithNullFieldWithoutModel() throws Exception { @Test public void parseWithIllegalFieldWithoutModel() throws Exception { - String expectedInputStrWithIllegalField = "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":" - + - "\"2\",\"model_group_id\":\"modelGroupId\",\"is_enabled\":false,\"rate_limiter\":" + - "{\"limit\":\"1\",\"unit\":\"MILLISECONDS\"},\"model_config\":" + - "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" - + - "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"updated_connector\":" + - "{\"name\":\"test\",\"version\":\"1\",\"protocol\":\"http\",\"parameters\":{\"param1\":\"value1\"},\"credential\":" - + - "{\"api_key\":\"credential_value\"},\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":" - + - "\"https://api.openai.com/v1/chat/completions\",\"headers\":{\"Authorization\":\"Bearer ${credential.api_key}\"},\"request_body\":" - + - "\"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"}]},\"connector_id\":" - + - "\"test-connector_id\",\"connector\":{\"description\":\"updated description\",\"version\":\"1\"},\"last_updated_time\":1,\"illegal_field\":\"This field need to be skipped.\"}"; + String expectedInputStrWithIllegalField = + "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":" + + "\"2\",\"model_group_id\":\"modelGroupId\",\"is_enabled\":false,\"rate_limiter\":" + + "{\"limit\":\"1\",\"unit\":\"MILLISECONDS\"},\"model_config\":" + + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"updated_connector\":" + + "{\"name\":\"test\",\"version\":\"1\",\"protocol\":\"http\",\"parameters\":{\"param1\":\"value1\"},\"credential\":" + + "{\"api_key\":\"credential_value\"},\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":" + + "\"https://api.openai.com/v1/chat/completions\",\"headers\":{\"Authorization\":\"Bearer ${credential.api_key}\"},\"request_body\":" + + "\"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"}]},\"connector_id\":" + + "\"test-connector_id\",\"connector\":{\"description\":\"updated description\",\"version\":\"1\"},\"last_updated_time\":1,\"illegal_field\":\"This field need to be skipped.\"}"; testParseFromJsonString(expectedInputStrWithIllegalField, parsedInput -> { try { assertEquals(expectedOutputStr, serializationWithToXContent(parsedInput)); @@ -232,8 +209,13 @@ public void parseWithIllegalFieldWithoutModel() throws Exception { } private void testParseFromJsonString(String expectedInputStr, Consumer verify) throws Exception { - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), LoggingDeprecationHandler.INSTANCE, expectedInputStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + LoggingDeprecationHandler.INSTANCE, + expectedInputStr + ); parser.nextToken(); MLUpdateModelInput parsedInput = MLUpdateModelInput.parse(parser); verify.accept(parsedInput); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequestTest.java index ef0298df27..184ab097d2 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequestTest.java @@ -5,54 +5,50 @@ package org.opensearch.ml.common.transport.model; -import org.junit.Before; -import org.opensearch.ml.common.model.MLModelConfig; -import org.opensearch.ml.common.model.TextEmbeddingModelConfig; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.junit.Before; import org.junit.Test; import org.mockito.MockitoAnnotations; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.ml.common.transport.model.MLUpdateModelInput; -import org.opensearch.ml.common.transport.model.MLUpdateModelRequest; - -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; - +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; public class MLUpdateModelRequestTest { private MLUpdateModelRequest updateModelRequest; @Before - public void setUp(){ + public void setUp() { MockitoAnnotations.openMocks(this); - MLModelConfig config = TextEmbeddingModelConfig.builder() - .modelType("testModelType") - .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") - .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) - .embeddingDimension(100) - .build(); - - MLUpdateModelInput updateModelInput = MLUpdateModelInput.builder() - .modelId("test-model_id") - .modelGroupId("modelGroupId") - .name("name") - .description("description") - .modelConfig(config) - .build(); - - updateModelRequest = MLUpdateModelRequest.builder() - .updateModelInput(updateModelInput) - .build(); + MLModelConfig config = TextEmbeddingModelConfig + .builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); + + MLUpdateModelInput updateModelInput = MLUpdateModelInput + .builder() + .modelId("test-model_id") + .modelGroupId("modelGroupId") + .name("name") + .description("description") + .modelConfig(config) + .build(); + + updateModelRequest = MLUpdateModelRequest.builder().updateModelInput(updateModelInput).build(); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequestTest.java index 627985813e..a33e169668 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequestTest.java @@ -1,5 +1,13 @@ package org.opensearch.ml.common.transport.model_group; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + +import java.io.IOException; +import java.io.UncheckedIOException; + import org.junit.Before; import org.junit.Test; import org.opensearch.action.ActionRequest; @@ -7,14 +15,6 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; - public class MLModelGroupDeleteRequestTest { private String modelGroupId; @@ -25,16 +25,14 @@ public class MLModelGroupDeleteRequestTest { public void setUp() { modelGroupId = "testGroupId"; - request = MLModelGroupDeleteRequest.builder() - .modelGroupId(modelGroupId).build(); + request = MLModelGroupDeleteRequest.builder().modelGroupId(modelGroupId).build(); } @Test public void writeToSuccess() throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); request.writeTo(bytesStreamOutput); - MLModelGroupDeleteRequest parsedRequest = new MLModelGroupDeleteRequest( - bytesStreamOutput.bytes().streamInput()); + MLModelGroupDeleteRequest parsedRequest = new MLModelGroupDeleteRequest(bytesStreamOutput.bytes().streamInput()); assertEquals(parsedRequest.getModelGroupId(), modelGroupId); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetRequestTest.java index 5b8000bdb9..7a463f28bc 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetRequestTest.java @@ -5,6 +5,14 @@ package org.opensearch.ml.common.transport.model_group; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + +import java.io.IOException; +import java.io.UncheckedIOException; + import org.junit.Before; import org.junit.Test; import org.opensearch.action.ActionRequest; @@ -12,14 +20,6 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; - public class MLModelGroupGetRequestTest { private String modelGroupId; @@ -30,8 +30,7 @@ public void setUp() { @Test public void writeTo_Success() throws IOException { - MLModelGroupGetRequest mlModelGroupGetRequest = MLModelGroupGetRequest.builder() - .modelGroupId(modelGroupId).build(); + MLModelGroupGetRequest mlModelGroupGetRequest = MLModelGroupGetRequest.builder().modelGroupId(modelGroupId).build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); mlModelGroupGetRequest.writeTo(bytesStreamOutput); MLModelGroupGetRequest parsedRequest = new MLModelGroupGetRequest(bytesStreamOutput.bytes().streamInput()); @@ -48,17 +47,16 @@ public void validate_Exception_NullmodelGroupId() { @Test public void fromActionRequest_Success() { - MLModelGroupGetRequest mlModelGroupGetRequest = MLModelGroupGetRequest.builder() - .modelGroupId(modelGroupId).build(); + MLModelGroupGetRequest mlModelGroupGetRequest = MLModelGroupGetRequest.builder().modelGroupId(modelGroupId).build(); ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { - return null; + return null; } @Override public void writeTo(StreamOutput out) throws IOException { - mlModelGroupGetRequest.writeTo(out); + mlModelGroupGetRequest.writeTo(out); } }; MLModelGroupGetRequest result = MLModelGroupGetRequest.fromActionRequest(actionRequest); @@ -71,12 +69,12 @@ public void fromActionRequest_IOException() { ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { - return null; + return null; } @Override public void writeTo(StreamOutput out) throws IOException { - throw new IOException("test"); + throw new IOException("test"); } }; MLModelGroupGetRequest.fromActionRequest(actionRequest); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetResponseTest.java index da10789062..4806ab0d40 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupGetResponseTest.java @@ -5,6 +5,15 @@ package org.opensearch.ml.common.transport.model_group; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertSame; + +import java.io.IOException; +import java.io.UncheckedIOException; + import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -16,27 +25,19 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.MLModelGroup; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertSame; - public class MLModelGroupGetResponseTest { MLModelGroup mlModelGroup; @Before public void setUp() { - mlModelGroup = MLModelGroup.builder() - .name("modelGroup1") - .latestVersion(1) - .description("This is an example model group") - .access("public") - .build(); + mlModelGroup = MLModelGroup + .builder() + .name("modelGroup1") + .latestVersion(1) + .description("This is an example model group") + .access("public") + .build(); } @Test @@ -58,17 +59,20 @@ public void toXContentTest() throws IOException { mlModelGroupGetResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); assertNotNull(builder); String jsonStr = builder.toString(); - assertEquals("{\"name\":\"modelGroup1\"," + - "\"latest_version\":1," + - "\"description\":\"This is an example model group\"," + - "\"access\":\"public\"}", - jsonStr); + assertEquals( + "{\"name\":\"modelGroup1\"," + + "\"latest_version\":1," + + "\"description\":\"This is an example model group\"," + + "\"access\":\"public\"}", + jsonStr + ); } @Test public void fromActionResponseWithMLModelGroupGetResponseSuccess() { MLModelGroupGetResponse mlModelGroupGetResponse = MLModelGroupGetResponse.builder().mlModelGroup(mlModelGroup).build(); - MLModelGroupGetResponse mlModelGroupGetResponseFromActionResponse = MLModelGroupGetResponse.fromActionResponse(mlModelGroupGetResponse); + MLModelGroupGetResponse mlModelGroupGetResponseFromActionResponse = MLModelGroupGetResponse + .fromActionResponse(mlModelGroupGetResponse); assertSame(mlModelGroupGetResponse, mlModelGroupGetResponseFromActionResponse); assertEquals(mlModelGroupGetResponse.mlModelGroup, mlModelGroupGetResponseFromActionResponse.mlModelGroup); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInputTest.java index a9a4969533..68cd72836e 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInputTest.java @@ -4,6 +4,7 @@ import java.io.IOException; import java.util.Arrays; + import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -17,13 +18,14 @@ public class MLRegisterModelGroupInputTest { @Before public void setUp() throws Exception { - mlRegisterModelGroupInput = mlRegisterModelGroupInput.builder() - .name("name") - .description("description") - .backendRoles(Arrays.asList("IT")) - .modelAccessMode(AccessMode.RESTRICTED) - .isAddAllBackendRoles(true) - .build(); + mlRegisterModelGroupInput = mlRegisterModelGroupInput + .builder() + .name("name") + .description("description") + .backendRoles(Arrays.asList("IT")) + .modelAccessMode(AccessMode.RESTRICTED) + .isAddAllBackendRoles(true) + .build(); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequestTest.java index 34be768c63..f675f9f321 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequestTest.java @@ -1,5 +1,14 @@ package org.opensearch.ml.common.transport.model_group; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.List; + import org.junit.Before; import org.junit.Test; import org.opensearch.action.ActionRequest; @@ -8,15 +17,6 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.ml.common.AccessMode; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.util.List; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; - public class MLRegisterModelGroupRequestTest { private MLRegisterModelGroupInput mlRegisterModelGroupInput; @@ -24,19 +24,18 @@ public class MLRegisterModelGroupRequestTest { private MLRegisterModelGroupRequest request; @Before - public void setUp(){ - - mlRegisterModelGroupInput = MLRegisterModelGroupInput.builder() - .name("name") - .description("description") - .backendRoles(List.of("IT")) - .modelAccessMode(AccessMode.RESTRICTED) - .isAddAllBackendRoles(true) - .build(); - - request = MLRegisterModelGroupRequest.builder() - .registerModelGroupInput(mlRegisterModelGroupInput) - .build(); + public void setUp() { + + mlRegisterModelGroupInput = MLRegisterModelGroupInput + .builder() + .name("name") + .description("description") + .backendRoles(List.of("IT")) + .modelAccessMode(AccessMode.RESTRICTED) + .isAddAllBackendRoles(true) + .build(); + + request = MLRegisterModelGroupRequest.builder().registerModelGroupInput(mlRegisterModelGroupInput).build(); } @Test @@ -46,9 +45,18 @@ public void writeToSuccess() throws IOException { MLRegisterModelGroupRequest parsedRequest = new MLRegisterModelGroupRequest(bytesStreamOutput.bytes().streamInput()); assertEquals(request.getRegisterModelGroupInput().getName(), parsedRequest.getRegisterModelGroupInput().getName()); assertEquals(request.getRegisterModelGroupInput().getDescription(), parsedRequest.getRegisterModelGroupInput().getDescription()); - assertEquals(request.getRegisterModelGroupInput().getBackendRoles().get(0), parsedRequest.getRegisterModelGroupInput().getBackendRoles().get(0)); - assertEquals(request.getRegisterModelGroupInput().getModelAccessMode(), parsedRequest.getRegisterModelGroupInput().getModelAccessMode()); - assertEquals(request.getRegisterModelGroupInput().getIsAddAllBackendRoles() ,parsedRequest.getRegisterModelGroupInput().getIsAddAllBackendRoles()); + assertEquals( + request.getRegisterModelGroupInput().getBackendRoles().get(0), + parsedRequest.getRegisterModelGroupInput().getBackendRoles().get(0) + ); + assertEquals( + request.getRegisterModelGroupInput().getModelAccessMode(), + parsedRequest.getRegisterModelGroupInput().getModelAccessMode() + ); + assertEquals( + request.getRegisterModelGroupInput().getIsAddAllBackendRoles(), + parsedRequest.getRegisterModelGroupInput().getIsAddAllBackendRoles() + ); } @Test @@ -67,9 +75,10 @@ public void validateNullMLRegisterModelGroupInputException() { // MLRegisterModelGroupInput check its parameters when created, so exception is not thrown here public void validateNullMLModelNameException() { mlRegisterModelGroupInput.setName(null); - MLRegisterModelGroupRequest request = MLRegisterModelGroupRequest.builder() - .registerModelGroupInput(mlRegisterModelGroupInput) - .build(); + MLRegisterModelGroupRequest request = MLRegisterModelGroupRequest + .builder() + .registerModelGroupInput(mlRegisterModelGroupInput) + .build(); assertNull(request.validate()); assertNull(request.getRegisterModelGroupInput().getName()); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupResponseTest.java index 528a0099b5..ae33d13469 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupResponseTest.java @@ -5,6 +5,15 @@ package org.opensearch.ml.common.transport.model_group; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertSame; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; +import java.io.UncheckedIOException; + import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -14,15 +23,6 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertSame; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; - public class MLRegisterModelGroupResponseTest { MLRegisterModelGroupResponse response; @@ -32,7 +32,6 @@ public void setup() { response = new MLRegisterModelGroupResponse("testModelGroupId", "Status"); } - @Test public void writeToSuccess() throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupInputTest.java index 9dc0fc559c..96d6b36a45 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupInputTest.java @@ -4,6 +4,7 @@ import java.io.IOException; import java.util.Arrays; + import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -17,14 +18,15 @@ public class MLUpdateModelGroupInputTest { @Before public void setUp() throws Exception { - mlUpdateModelGroupInput = mlUpdateModelGroupInput.builder() - .modelGroupID("modelGroupId") - .name("name") - .description("description") - .backendRoles(Arrays.asList("IT")) - .modelAccessMode(AccessMode.RESTRICTED) - .isAddAllBackendRoles(true) - .build(); + mlUpdateModelGroupInput = mlUpdateModelGroupInput + .builder() + .modelGroupID("modelGroupId") + .name("name") + .description("description") + .backendRoles(Arrays.asList("IT")) + .modelAccessMode(AccessMode.RESTRICTED) + .isAddAllBackendRoles(true) + .build(); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequestTest.java index 104aacadde..d823b77b16 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequestTest.java @@ -8,6 +8,7 @@ import java.io.IOException; import java.io.UncheckedIOException; import java.util.Arrays; + import org.junit.Before; import org.junit.Test; import org.opensearch.action.ActionRequest; @@ -21,24 +22,23 @@ public class MLUpdateModelGroupRequestTest { private MLUpdateModelGroupInput mlUpdateModelGroupInput; @Before - public void setUp(){ - - mlUpdateModelGroupInput = mlUpdateModelGroupInput.builder() - .modelGroupID("modelGroupId") - .name("name") - .description("description") - .backendRoles(Arrays.asList("IT")) - .modelAccessMode(AccessMode.RESTRICTED) - .isAddAllBackendRoles(true) - .build(); + public void setUp() { + + mlUpdateModelGroupInput = mlUpdateModelGroupInput + .builder() + .modelGroupID("modelGroupId") + .name("name") + .description("description") + .backendRoles(Arrays.asList("IT")) + .modelAccessMode(AccessMode.RESTRICTED) + .isAddAllBackendRoles(true) + .build(); } @Test public void writeToSuccess() throws IOException { - MLUpdateModelGroupRequest request = MLUpdateModelGroupRequest.builder() - .updateModelGroupInput(mlUpdateModelGroupInput) - .build(); + MLUpdateModelGroupRequest request = MLUpdateModelGroupRequest.builder().updateModelGroupInput(mlUpdateModelGroupInput).build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); request.writeTo(bytesStreamOutput); request = new MLUpdateModelGroupRequest(bytesStreamOutput.bytes().streamInput()); @@ -52,17 +52,14 @@ public void writeToSuccess() throws IOException { @Test public void validateSuccess() { - MLUpdateModelGroupRequest request = MLUpdateModelGroupRequest.builder() - .updateModelGroupInput(mlUpdateModelGroupInput) - .build(); + MLUpdateModelGroupRequest request = MLUpdateModelGroupRequest.builder().updateModelGroupInput(mlUpdateModelGroupInput).build(); assertNull(request.validate()); } @Test public void validateWithNullMLUpdateModelGroupInputException() { - MLUpdateModelGroupRequest request = MLUpdateModelGroupRequest.builder() - .build(); + MLUpdateModelGroupRequest request = MLUpdateModelGroupRequest.builder().build(); ActionRequestValidationException exception = request.validate(); assertEquals("Validation Failed: 1: Update Model group input can't be null;", exception.getMessage()); } @@ -71,28 +68,21 @@ public void validateWithNullMLUpdateModelGroupInputException() { // MLUpdateModelGroupInput check its parameters when created, so exception is not thrown here public void validateWithNullMLModelNameException() { mlUpdateModelGroupInput.setName(null); - MLUpdateModelGroupRequest request = MLUpdateModelGroupRequest.builder() - .updateModelGroupInput(mlUpdateModelGroupInput) - .build(); + MLUpdateModelGroupRequest request = MLUpdateModelGroupRequest.builder().updateModelGroupInput(mlUpdateModelGroupInput).build(); assertNull(request.validate()); assertNull(request.getUpdateModelGroupInput().getName()); } - @Test public void fromActionRequestWithMLUpdateModelGroupRequestSuccess() { - MLUpdateModelGroupRequest request = MLUpdateModelGroupRequest.builder() - .updateModelGroupInput(mlUpdateModelGroupInput) - .build(); + MLUpdateModelGroupRequest request = MLUpdateModelGroupRequest.builder().updateModelGroupInput(mlUpdateModelGroupInput).build(); assertSame(MLUpdateModelGroupRequest.fromActionRequest(request), request); } @Test public void fromActionRequestWithNonMLUpdateModelGroupRequestSuccess() { - MLUpdateModelGroupRequest request = MLUpdateModelGroupRequest.builder() - .updateModelGroupInput(mlUpdateModelGroupInput) - .build(); + MLUpdateModelGroupRequest request = MLUpdateModelGroupRequest.builder().updateModelGroupInput(mlUpdateModelGroupInput).build(); ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupResponseTest.java index 2c1305a73e..f42a1bf4d1 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupResponseTest.java @@ -5,6 +5,12 @@ package org.opensearch.ml.common.transport.model_group; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; + import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -12,12 +18,6 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; - public class MLUpdateModelGroupResponseTest { MLUpdateModelGroupResponse mlUpdateModelGroupResponse; @@ -27,7 +27,6 @@ public void setup() { mlUpdateModelGroupResponse = new MLUpdateModelGroupResponse("Status"); } - @Test public void writeTo_Success() throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequestTest.java index b9cbe7d700..468395212a 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequestTest.java @@ -5,12 +5,16 @@ package org.opensearch.ml.common.transport.prediction; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + import java.io.IOException; import java.io.UncheckedIOException; import java.util.Collections; import java.util.HashMap; -import lombok.NonNull; import org.junit.Before; import org.junit.Test; import org.opensearch.action.ActionRequest; @@ -19,6 +23,7 @@ import org.opensearch.commons.authuser.User; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataframe.ColumnType; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataframe.DataFrameBuilder; @@ -26,15 +31,11 @@ import org.opensearch.ml.common.dataset.MLInputDataType; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.SearchQueryInputDataset; -import org.opensearch.ml.common.input.parameter.clustering.KMeansParams; -import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.input.parameter.clustering.KMeansParams; import org.opensearch.search.builder.SearchSourceBuilder; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; +import lombok.NonNull; public class MLPredictionTaskRequestTest { @@ -42,29 +43,29 @@ public class MLPredictionTaskRequestTest { @Before public void setUp() { - DataFrame dataFrame = DataFrameBuilder.load(Collections.singletonList(new HashMap() {{ - put("key1", 2.0D); - }})); - mlInput = MLInput.builder() - .algorithm(FunctionName.KMEANS) - .parameters(KMeansParams.builder().centroids(1).build()) - .inputDataset(DataFrameInputDataset.builder().dataFrame(dataFrame).build()) - .build(); + DataFrame dataFrame = DataFrameBuilder.load(Collections.singletonList(new HashMap() { + { + put("key1", 2.0D); + } + })); + mlInput = MLInput + .builder() + .algorithm(FunctionName.KMEANS) + .parameters(KMeansParams.builder().centroids(1).build()) + .inputDataset(DataFrameInputDataset.builder().dataFrame(dataFrame).build()) + .build(); } @Test public void writeTo_Success() throws IOException { User user = User.parse("admin|role-1|all_access"); - MLPredictionTaskRequest request = MLPredictionTaskRequest.builder() - .mlInput(mlInput) - .user(user) - .build(); + MLPredictionTaskRequest request = MLPredictionTaskRequest.builder().mlInput(mlInput).user(user).build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); request.writeTo(bytesStreamOutput); request = new MLPredictionTaskRequest(bytesStreamOutput.bytes().streamInput()); assertEquals(FunctionName.KMEANS, request.getMlInput().getAlgorithm()); - KMeansParams params = (KMeansParams)request.getMlInput().getParameters(); + KMeansParams params = (KMeansParams) request.getMlInput().getParameters(); assertEquals(1, params.getCentroids().intValue()); MLInputDataset inputDataset = request.getMlInput().getInputDataset(); assertEquals(MLInputDataType.DATA_FRAME, inputDataset.getInputDataType()); @@ -85,10 +86,7 @@ public void writeTo_Success() throws IOException { @Test public void validate_Success() { User user = User.parse("admin|role-1|all_access"); - MLPredictionTaskRequest request = MLPredictionTaskRequest.builder() - .mlInput(mlInput) - .user(user) - .build(); + MLPredictionTaskRequest request = MLPredictionTaskRequest.builder().mlInput(mlInput).user(user).build(); assertNull(request.validate()); } @@ -96,8 +94,7 @@ public void validate_Success() { @Test public void validate_Exception_NullMLInput() { mlInput.setAlgorithm(null); - MLPredictionTaskRequest request = MLPredictionTaskRequest.builder() - .build(); + MLPredictionTaskRequest request = MLPredictionTaskRequest.builder().build(); ActionRequestValidationException exception = request.validate(); assertEquals("Validation Failed: 1: ML input can't be null;", exception.getMessage()); @@ -106,21 +103,16 @@ public void validate_Exception_NullMLInput() { @Test public void validate_Exception_NullInputDataset() { mlInput.setInputDataset(null); - MLPredictionTaskRequest request = MLPredictionTaskRequest.builder() - .mlInput(mlInput) - .build(); + MLPredictionTaskRequest request = MLPredictionTaskRequest.builder().mlInput(mlInput).build(); ActionRequestValidationException exception = request.validate(); assertEquals("Validation Failed: 1: input data can't be null;", exception.getMessage()); } - @Test public void fromActionRequest_Success_WithMLPredictionTaskRequest() { - MLPredictionTaskRequest request = MLPredictionTaskRequest.builder() - .mlInput(mlInput) - .build(); + MLPredictionTaskRequest request = MLPredictionTaskRequest.builder().mlInput(mlInput).build(); assertSame(MLPredictionTaskRequest.fromActionRequest(request), request); } @@ -131,21 +123,23 @@ public void fromActionRequest_Success_WithNonMLPredictionTaskRequest_DataFrameIn @Test public void fromActionRequest_Success_WithNonMLPredictionTaskRequest_SearchQueryInput() { - @NonNull SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + @NonNull + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchSourceBuilder.query(new MatchAllQueryBuilder()); - mlInput.setInputDataset(SearchQueryInputDataset.builder() - .indices(Collections.singletonList("test_index")) - .searchSourceBuilder(searchSourceBuilder) - .build()); + mlInput + .setInputDataset( + SearchQueryInputDataset + .builder() + .indices(Collections.singletonList("test_index")) + .searchSourceBuilder(searchSourceBuilder) + .build() + ); fromActionRequest_Success_WithNonMLPredictionTaskRequest(mlInput); } private void fromActionRequest_Success_WithNonMLPredictionTaskRequest(MLInput mlInput) { User user = User.parse("admin|role-1|all_access"); - MLPredictionTaskRequest request = MLPredictionTaskRequest.builder() - .mlInput(mlInput) - .user(user) - .build(); + MLPredictionTaskRequest request = MLPredictionTaskRequest.builder().mlInput(mlInput).user(user).build(); ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { diff --git a/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskResponseTest.java index 4f6ddac596..cb161f15f9 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskResponseTest.java @@ -5,45 +5,46 @@ package org.opensearch.ml.common.transport.prediction; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertSame; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Collections; +import java.util.HashMap; + import org.junit.Test; -import org.opensearch.core.action.ActionResponse; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.dataframe.DataFrameBuilder; import org.opensearch.ml.common.output.MLPredictionOutput; import org.opensearch.ml.common.transport.MLTaskResponse; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.util.Collections; -import java.util.HashMap; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertSame; - public class MLPredictionTaskResponseTest { @Test public void writeTo_Success() throws IOException { - MLPredictionOutput output = MLPredictionOutput.builder() - .taskId("taskId") - .status("Success") - .predictionResult(DataFrameBuilder.load(Collections.singletonList(new HashMap() {{ + MLPredictionOutput output = MLPredictionOutput + .builder() + .taskId("taskId") + .status("Success") + .predictionResult(DataFrameBuilder.load(Collections.singletonList(new HashMap() { + { put("key1", 2.0D); - }}))) - .build(); - MLTaskResponse response = MLTaskResponse.builder() - .output(output) - .build(); + } + }))) + .build(); + MLTaskResponse response = MLTaskResponse.builder().output(output).build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); response.writeTo(bytesStreamOutput); response = new MLTaskResponse(bytesStreamOutput.bytes().streamInput()); - MLPredictionOutput mlPredictionOutput = (MLPredictionOutput)response.getOutput(); + MLPredictionOutput mlPredictionOutput = (MLPredictionOutput) response.getOutput(); assertEquals("taskId", mlPredictionOutput.getTaskId()); assertEquals("Success", mlPredictionOutput.getStatus()); assertEquals(1, mlPredictionOutput.getPredictionResult().size()); @@ -51,31 +52,33 @@ public void writeTo_Success() throws IOException { @Test public void fromActionResponse_WithMLPredictionTaskResponse() { - MLPredictionOutput output = MLPredictionOutput.builder() - .taskId("taskId") - .status("Success") - .predictionResult(DataFrameBuilder.load(Collections.singletonList(new HashMap() {{ + MLPredictionOutput output = MLPredictionOutput + .builder() + .taskId("taskId") + .status("Success") + .predictionResult(DataFrameBuilder.load(Collections.singletonList(new HashMap() { + { put("key1", 2.0D); - }}))) - .build(); - MLTaskResponse response = MLTaskResponse.builder() - .output(output) - .build(); + } + }))) + .build(); + MLTaskResponse response = MLTaskResponse.builder().output(output).build(); assertSame(response, MLTaskResponse.fromActionResponse(response)); } @Test public void fromActionResponse_WithNonMLPredictionTaskResponse() { - MLPredictionOutput output = MLPredictionOutput.builder() - .taskId("taskId") - .status("Success") - .predictionResult(DataFrameBuilder.load(Collections.singletonList(new HashMap() {{ + MLPredictionOutput output = MLPredictionOutput + .builder() + .taskId("taskId") + .status("Success") + .predictionResult(DataFrameBuilder.load(Collections.singletonList(new HashMap() { + { put("key1", 2.0D); - }}))) - .build(); - MLTaskResponse response = MLTaskResponse.builder() - .output(output) - .build(); + } + }))) + .build(); + MLTaskResponse response = MLTaskResponse.builder().output(output).build(); ActionResponse actionResponse = new ActionResponse() { @Override public void writeTo(StreamOutput out) throws IOException { @@ -101,31 +104,34 @@ public void writeTo(StreamOutput out) throws IOException { } }; - MLTaskResponse.fromActionResponse(actionResponse); } @Test public void toXContentTest() throws IOException { - MLPredictionOutput output = MLPredictionOutput.builder() - .taskId("b5009b99-268f-476d-a676-379a30f82457") - .status("Success") - .predictionResult(DataFrameBuilder.load(Collections.singletonList(new HashMap() {{ + MLPredictionOutput output = MLPredictionOutput + .builder() + .taskId("b5009b99-268f-476d-a676-379a30f82457") + .status("Success") + .predictionResult(DataFrameBuilder.load(Collections.singletonList(new HashMap() { + { put("ClusterID", 0); - }}))) - .build(); - MLTaskResponse response = MLTaskResponse.builder() - .output(output) + } + }))) .build(); + MLTaskResponse response = MLTaskResponse.builder().output(output).build(); XContentBuilder builder = XContentFactory.jsonBuilder(); response.toXContent(builder, ToXContent.EMPTY_PARAMS); assertNotNull(builder); String jsonStr = builder.toString(); - assertEquals("{\"task_id\":\"b5009b99-268f-476d-a676-379a30f82457\"," + - "\"status\":\"Success\"," + - "\"prediction_result\":{" + - "\"column_metas\":[{\"name\":\"ClusterID\",\"column_type\":\"INTEGER\"}]," + - "\"rows\":[{\"values\":[{\"column_type\":\"INTEGER\",\"value\":0}]}]}}", jsonStr); + assertEquals( + "{\"task_id\":\"b5009b99-268f-476d-a676-379a30f82457\"," + + "\"status\":\"Success\"," + + "\"prediction_result\":{" + + "\"column_metas\":[{\"name\":\"ClusterID\",\"column_type\":\"INTEGER\"}]," + + "\"rows\":[{\"values\":[{\"column_type\":\"INTEGER\",\"value\":0}]}]}}", + jsonStr + ); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java index b4d5c23495..8caed811b5 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java @@ -1,5 +1,14 @@ package org.opensearch.ml.common.transport.register; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; + +import java.io.IOException; +import java.util.Collections; +import java.util.function.Consumer; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -8,11 +17,11 @@ import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; @@ -26,15 +35,6 @@ import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.search.SearchModule; -import java.io.IOException; -import java.util.Collections; -import java.util.function.Consumer; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; - @RunWith(MockitoJUnitRunner.class) public class MLRegisterModelInputTest { @@ -44,23 +44,23 @@ public class MLRegisterModelInputTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); - private final String expectedInputStr = "{\"function_name\":\"LINEAR_REGRESSION\",\"name\":\"modelName\"," + - "\"version\":\"version\",\"model_group_id\":\"modelGroupId\",\"description\":\"test description\"," + - "\"url\":\"url\",\"model_content_hash_value\":\"hash_value_test\",\"model_format\":\"ONNX\"," + - "\"model_config\":{\"model_type\":\"testModelType\",\"embedding_dimension\":100," + - "\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\"," + - "\\\"field2\\\":\\\"value2\\\"}\"},\"deploy_model\":true,\"model_node_ids\":[\"modelNodeIds\"]," + - "\"connector\":{\"name\":\"test_connector_name\",\"version\":\"1\"," + - "\"description\":\"this is a test connector\",\"protocol\":\"http\"," + - "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," + - "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + - "\"headers\":{\"api_key\":\"${credential.key}\"}," + - "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + - "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + - "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + - "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\"," + - "\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000," + - "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}},\"is_hidden\":false}"; + private final String expectedInputStr = "{\"function_name\":\"LINEAR_REGRESSION\",\"name\":\"modelName\"," + + "\"version\":\"version\",\"model_group_id\":\"modelGroupId\",\"description\":\"test description\"," + + "\"url\":\"url\",\"model_content_hash_value\":\"hash_value_test\",\"model_format\":\"ONNX\"," + + "\"model_config\":{\"model_type\":\"testModelType\",\"embedding_dimension\":100," + + "\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\"," + + "\\\"field2\\\":\\\"value2\\\"}\"},\"deploy_model\":true,\"model_node_ids\":[\"modelNodeIds\"]," + + "\"connector\":{\"name\":\"test_connector_name\",\"version\":\"1\"," + + "\"description\":\"this is a test connector\",\"protocol\":\"http\"," + + "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," + + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + + "\"headers\":{\"api_key\":\"${credential.key}\"}," + + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + + "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\"," + + "\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000," + + "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}},\"is_hidden\":false}"; private final FunctionName functionName = FunctionName.LINEAR_REGRESSION; private final String modelName = "modelName"; private final String version = "version"; @@ -70,29 +70,31 @@ public class MLRegisterModelInputTest { @Before public void setUp() throws Exception { - config = TextEmbeddingModelConfig.builder() - .modelType("testModelType") - .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") - .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) - .embeddingDimension(100) - .build(); + config = TextEmbeddingModelConfig + .builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); HttpConnector connector = HttpConnectorTest.createHttpConnector(); - input = MLRegisterModelInput.builder() - .functionName(functionName) - .modelName(modelName) - .version(version) - .modelGroupId(modelGroupId) - .url(url) - .modelFormat(MLModelFormat.ONNX) - .modelConfig(config) - .deployModel(true) - .modelNodeIds(new String[]{"modelNodeIds" }) - .isHidden(false) - .description("test description") - .hashValue("hash_value_test") - .connector(connector) - .build(); + input = MLRegisterModelInput + .builder() + .functionName(functionName) + .modelName(modelName) + .version(version) + .modelGroupId(modelGroupId) + .url(url) + .modelFormat(MLModelFormat.ONNX) + .modelConfig(config) + .deployModel(true) + .modelNodeIds(new String[] { "modelNodeIds" }) + .isHidden(false) + .description("test description") + .hashValue("hash_value_test") + .connector(connector) + .build(); } @Test @@ -106,52 +108,51 @@ public void constructor_NullModel() { public void constructor_NullModelName() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("model name is null"); - MLRegisterModelInput.builder() - .functionName(functionName) - .modelGroupId(modelGroupId) - .modelName(null) - .build(); + MLRegisterModelInput.builder().functionName(functionName).modelGroupId(modelGroupId).modelName(null).build(); } @Test public void constructor_NullModelFormat() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("model format is null"); - MLRegisterModelInput.builder() - .functionName(functionName) - .modelName(modelName) - .version(version) - .modelGroupId(modelGroupId) - .modelFormat(null) - .url(url) - .build(); + MLRegisterModelInput + .builder() + .functionName(functionName) + .modelName(modelName) + .version(version) + .modelGroupId(modelGroupId) + .modelFormat(null) + .url(url) + .build(); } @Test public void constructor_NullModelConfig() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("model config is null"); - MLRegisterModelInput.builder() - .functionName(functionName) - .modelName(modelName) - .version(version) - .modelGroupId(modelGroupId) - .modelFormat(MLModelFormat.ONNX) - .modelConfig(null) - .url(url) - .build(); + MLRegisterModelInput + .builder() + .functionName(functionName) + .modelName(modelName) + .version(version) + .modelGroupId(modelGroupId) + .modelFormat(MLModelFormat.ONNX) + .modelConfig(null) + .url(url) + .build(); } @Test public void constructor_SuccessWithMinimalSetup() { - MLRegisterModelInput input = MLRegisterModelInput.builder() - .modelName(modelName) - .version(version) - .modelGroupId(modelGroupId) - .modelFormat(MLModelFormat.ONNX) - .modelConfig(config) - .url(url) - .build(); + MLRegisterModelInput input = MLRegisterModelInput + .builder() + .modelName(modelName) + .version(version) + .modelGroupId(modelGroupId) + .modelFormat(MLModelFormat.ONNX) + .modelConfig(config) + .url(url) + .build(); // MLRegisterModelInput.functionName is set to FunctionName.TEXT_EMBEDDING if not explicitly passed, with no exception thrown assertEquals(FunctionName.TEXT_EMBEDDING, input.getFunctionName()); // MLRegisterModelInput.deployModel is set to false if not explicitly passed, with no exception thrown @@ -171,20 +172,20 @@ public void testToXContent() throws Exception { @Test public void testToXContent_Incomplete() throws Exception { - String expectedIncompleteInputStr = "{\"function_name\":\"LINEAR_REGRESSION\"," + - "\"name\":\"modelName\",\"version\":\"version\",\"model_group_id\":\"modelGroupId\"," + - "\"description\":\"test description\",\"model_content_hash_value\":\"hash_value_test\"," + - "\"deploy_model\":true,\"connector\":{\"name\":\"test_connector_name\",\"version\":\"1\"," + - "\"description\":\"this is a test connector\",\"protocol\":\"http\"," + - "\"parameters\":{\"input\":\"test input value\"}," + - "\"credential\":{\"key\":\"test_key_value\"},\"actions\":[{\"action_type\":\"PREDICT\"," + - "\"method\":\"POST\",\"url\":\"https://test.com\",\"headers\":{\"api_key\":\"${credential.key}\"}," + - "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + - "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + - "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + - "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\"," + - "\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000," + - "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}},\"is_hidden\":false}"; + String expectedIncompleteInputStr = "{\"function_name\":\"LINEAR_REGRESSION\"," + + "\"name\":\"modelName\",\"version\":\"version\",\"model_group_id\":\"modelGroupId\"," + + "\"description\":\"test description\",\"model_content_hash_value\":\"hash_value_test\"," + + "\"deploy_model\":true,\"connector\":{\"name\":\"test_connector_name\",\"version\":\"1\"," + + "\"description\":\"this is a test connector\",\"protocol\":\"http\"," + + "\"parameters\":{\"input\":\"test input value\"}," + + "\"credential\":{\"key\":\"test_key_value\"},\"actions\":[{\"action_type\":\"PREDICT\"," + + "\"method\":\"POST\",\"url\":\"https://test.com\",\"headers\":{\"api_key\":\"${credential.key}\"}," + + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + + "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\"," + + "\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000," + + "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}},\"is_hidden\":false}"; input.setUrl(null); input.setModelConfig(null); input.setModelFormat(null); @@ -207,24 +208,41 @@ public void parse_WithModel() throws Exception { @Test public void parse_WithoutModel() throws Exception { - testParseFromJsonString( false, expectedInputStr, parsedInput -> { + testParseFromJsonString(false, expectedInputStr, parsedInput -> { assertFalse(parsedInput.isDeployModel()); assertEquals("modelName", parsedInput.getModelName()); assertEquals("version", parsedInput.getVersion()); }); } - private void testParseFromJsonString(String modelName, String version, Boolean deployModel, String expectedInputStr, Consumer verify) throws Exception { - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), LoggingDeprecationHandler.INSTANCE, expectedInputStr); + private void testParseFromJsonString( + String modelName, + String version, + Boolean deployModel, + String expectedInputStr, + Consumer verify + ) throws Exception { + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + LoggingDeprecationHandler.INSTANCE, + expectedInputStr + ); parser.nextToken(); MLRegisterModelInput parsedInput = MLRegisterModelInput.parse(parser, modelName, version, deployModel); verify.accept(parsedInput); } - private void testParseFromJsonString(Boolean deployModel,String expectedInputStr, Consumer verify) throws Exception { - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), LoggingDeprecationHandler.INSTANCE, expectedInputStr); + private void testParseFromJsonString(Boolean deployModel, String expectedInputStr, Consumer verify) + throws Exception { + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + LoggingDeprecationHandler.INSTANCE, + expectedInputStr + ); parser.nextToken(); MLRegisterModelInput parsedInput = MLRegisterModelInput.parse(parser, deployModel); verify.accept(parsedInput); @@ -238,7 +256,6 @@ public void readInputStream_Success() throws IOException { }); } - @Test public void readInputStream_SuccessWithNullFields() throws IOException { input.setModelFormat(null); @@ -252,14 +269,15 @@ public void readInputStream_SuccessWithNullFields() throws IOException { @Test public void readInputStream_WithConnectorId() throws IOException { String connectorId = "test_connector_id"; - input = MLRegisterModelInput.builder() - .functionName(FunctionName.REMOTE) - .modelName(modelName) - .description("test model input") - .version(version) - .modelGroupId(modelGroupId) - .connectorId(connectorId) - .build(); + input = MLRegisterModelInput + .builder() + .functionName(FunctionName.REMOTE) + .modelName(modelName) + .description("test model input") + .version(version) + .modelGroupId(modelGroupId) + .connectorId(connectorId) + .build(); readInputStream(input, parsedInput -> { assertNull(parsedInput.getModelConfig()); assertNull(parsedInput.getModelFormat()); @@ -271,14 +289,15 @@ public void readInputStream_WithConnectorId() throws IOException { @Test public void readInputStream_WithInternalConnector() throws IOException { HttpConnector connector = HttpConnectorTest.createHttpConnector(); - input = MLRegisterModelInput.builder() - .functionName(FunctionName.REMOTE) - .modelName(modelName) - .description("test model input") - .version(version) - .modelGroupId(modelGroupId) - .connector(connector) - .build(); + input = MLRegisterModelInput + .builder() + .functionName(FunctionName.REMOTE) + .modelName(modelName) + .description("test model input") + .version(version) + .modelGroupId(modelGroupId) + .connector(connector) + .build(); readInputStream(input, parsedInput -> { assertNull(parsedInput.getModelConfig()); assertNull(parsedInput.getModelFormat()); @@ -288,24 +307,27 @@ public void readInputStream_WithInternalConnector() throws IOException { @Test public void testMCorrInput() throws IOException { - String testString = "{\"function_name\":\"METRICS_CORRELATION\",\"name\":\"METRICS_CORRELATION\",\"version\":\"1.0.0b1\",\"model_group_id\":\"modelGroupId\",\"url\":\"url\",\"model_format\":\"TORCH_SCRIPT\",\"model_config\":{\"model_type\":\"testModelType\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"deploy_model\":true,\"model_node_ids\":[\"modelNodeIds\"]}"; + String testString = + "{\"function_name\":\"METRICS_CORRELATION\",\"name\":\"METRICS_CORRELATION\",\"version\":\"1.0.0b1\",\"model_group_id\":\"modelGroupId\",\"url\":\"url\",\"model_format\":\"TORCH_SCRIPT\",\"model_config\":{\"model_type\":\"testModelType\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"deploy_model\":true,\"model_node_ids\":[\"modelNodeIds\"]}"; - MetricsCorrelationModelConfig mcorrConfig = MetricsCorrelationModelConfig.builder() - .modelType("testModelType") - .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") - .build(); + MetricsCorrelationModelConfig mcorrConfig = MetricsCorrelationModelConfig + .builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .build(); - MLRegisterModelInput mcorrInput = MLRegisterModelInput.builder() - .functionName(FunctionName.METRICS_CORRELATION) - .modelName(FunctionName.METRICS_CORRELATION.name()) - .version("1.0.0b1") - .modelGroupId(modelGroupId) - .url(url) - .modelFormat(MLModelFormat.TORCH_SCRIPT) - .modelConfig(mcorrConfig) - .deployModel(true) - .modelNodeIds(new String[]{"modelNodeIds" }) - .build(); + MLRegisterModelInput mcorrInput = MLRegisterModelInput + .builder() + .functionName(FunctionName.METRICS_CORRELATION) + .modelName(FunctionName.METRICS_CORRELATION.name()) + .version("1.0.0b1") + .modelGroupId(modelGroupId) + .url(url) + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .modelConfig(mcorrConfig) + .deployModel(true) + .modelNodeIds(new String[] { "modelNodeIds" }) + .build(); XContentBuilder builder = XContentFactory.jsonBuilder(); mcorrInput.toXContent(builder, ToXContent.EMPTY_PARAMS); String jsonStr = builder.toString(); @@ -314,22 +336,24 @@ public void testMCorrInput() throws IOException { @Test public void readInputStream_MCorr() throws IOException { - MetricsCorrelationModelConfig mcorrConfig = MetricsCorrelationModelConfig.builder() - .modelType("testModelType") - .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") - .build(); + MetricsCorrelationModelConfig mcorrConfig = MetricsCorrelationModelConfig + .builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .build(); - MLRegisterModelInput mcorrInput = MLRegisterModelInput.builder() - .functionName(FunctionName.METRICS_CORRELATION) - .modelName(FunctionName.METRICS_CORRELATION.name()) - .version("1.0.0b1") - .modelGroupId(modelGroupId) - .url(url) - .modelFormat(MLModelFormat.TORCH_SCRIPT) - .modelConfig(mcorrConfig) - .deployModel(true) - .modelNodeIds(new String[]{"modelNodeIds" }) - .build(); + MLRegisterModelInput mcorrInput = MLRegisterModelInput + .builder() + .functionName(FunctionName.METRICS_CORRELATION) + .modelName(FunctionName.METRICS_CORRELATION.name()) + .version("1.0.0b1") + .modelGroupId(modelGroupId) + .url(url) + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .modelConfig(mcorrConfig) + .deployModel(true) + .modelNodeIds(new String[] { "modelNodeIds" }) + .build(); readInputStream(mcorrInput, parsedInput -> { assertEquals(parsedInput.getModelConfig().getModelType(), mcorrConfig.getModelType()); assertEquals(parsedInput.getModelConfig().getAllConfig(), mcorrConfig.getAllConfig()); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelRequestTest.java index b983fb1827..bcbee60593 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelRequestTest.java @@ -1,5 +1,7 @@ package org.opensearch.ml.common.transport.register; +import static org.junit.Assert.*; + import java.io.IOException; import java.io.UncheckedIOException; @@ -9,47 +11,44 @@ import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelConfig; - -import static org.junit.Assert.*; +import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; public class MLRegisterModelRequestTest { private MLRegisterModelInput mlRegisterModelInput; @Before - public void setUp(){ - - TextEmbeddingModelConfig config = TextEmbeddingModelConfig.builder() - .modelType("testModelType") - .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") - .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) - .embeddingDimension(100) - .build(); - - - mlRegisterModelInput = mlRegisterModelInput.builder() - .functionName(FunctionName.KMEANS) - .modelName("modelName") - .version("version") - .modelGroupId("modelGroupId") - .url("url") - .modelFormat(MLModelFormat.ONNX) - .modelConfig(config) - .deployModel(true) - .modelNodeIds(new String[]{"modelNodeIds" }) - .build(); + public void setUp() { + + TextEmbeddingModelConfig config = TextEmbeddingModelConfig + .builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); + + mlRegisterModelInput = mlRegisterModelInput + .builder() + .functionName(FunctionName.KMEANS) + .modelName("modelName") + .version("version") + .modelGroupId("modelGroupId") + .url("url") + .modelFormat(MLModelFormat.ONNX) + .modelConfig(config) + .deployModel(true) + .modelNodeIds(new String[] { "modelNodeIds" }) + .build(); } @Test public void writeTo_Success() throws IOException { - MLRegisterModelRequest request = MLRegisterModelRequest.builder() - .registerModelInput(mlRegisterModelInput) - .build(); + MLRegisterModelRequest request = MLRegisterModelRequest.builder().registerModelInput(mlRegisterModelInput).build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); request.writeTo(bytesStreamOutput); request = new MLRegisterModelRequest(bytesStreamOutput.bytes().streamInput()); @@ -69,17 +68,14 @@ public void writeTo_Success() throws IOException { @Test public void validate_Success() { - MLRegisterModelRequest request = MLRegisterModelRequest.builder() - .registerModelInput(mlRegisterModelInput) - .build(); + MLRegisterModelRequest request = MLRegisterModelRequest.builder().registerModelInput(mlRegisterModelInput).build(); assertNull(request.validate()); } @Test public void validate_Exception_NullMLRegisterModelInput() { - MLRegisterModelRequest request = MLRegisterModelRequest.builder() - .build(); + MLRegisterModelRequest request = MLRegisterModelRequest.builder().build(); ActionRequestValidationException exception = request.validate(); assertEquals("Validation Failed: 1: ML input can't be null;", exception.getMessage()); } @@ -88,9 +84,7 @@ public void validate_Exception_NullMLRegisterModelInput() { // MLRegisterModelInput check its parameters when created, so exception is not thrown here public void validate_Exception_NullMLModelName() { mlRegisterModelInput.setModelName(null); - MLRegisterModelRequest request = MLRegisterModelRequest.builder() - .registerModelInput(mlRegisterModelInput) - .build(); + MLRegisterModelRequest request = MLRegisterModelRequest.builder().registerModelInput(mlRegisterModelInput).build(); assertNull(request.validate()); assertNull(request.getRegisterModelInput().getModelName()); @@ -98,17 +92,13 @@ public void validate_Exception_NullMLModelName() { @Test public void fromActionRequest_Success_WithMLRegisterModelRequest() { - MLRegisterModelRequest request = MLRegisterModelRequest.builder() - .registerModelInput(mlRegisterModelInput) - .build(); + MLRegisterModelRequest request = MLRegisterModelRequest.builder().registerModelInput(mlRegisterModelInput).build(); assertSame(MLRegisterModelRequest.fromActionRequest(request), request); } @Test public void fromActionRequest_Success_WithNonMLRegisterModelRequest() { - MLRegisterModelRequest request = MLRegisterModelRequest.builder() - .registerModelInput(mlRegisterModelInput) - .build(); + MLRegisterModelRequest request = MLRegisterModelRequest.builder().registerModelInput(mlRegisterModelInput).build(); ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { @@ -123,7 +113,10 @@ public void writeTo(StreamOutput out) throws IOException { MLRegisterModelRequest result = MLRegisterModelRequest.fromActionRequest(actionRequest); assertNotSame(result, request); assertEquals(request.getRegisterModelInput().getModelName(), result.getRegisterModelInput().getModelName()); - assertEquals(request.getRegisterModelInput().getModelConfig().getModelType(), result.getRegisterModelInput().getModelConfig().getModelType()); + assertEquals( + request.getRegisterModelInput().getModelConfig().getModelType(), + result.getRegisterModelInput().getModelConfig().getModelType() + ); } @Test(expected = UncheckedIOException.class) diff --git a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponseTest.java index f222f40200..0caaed0f4c 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponseTest.java @@ -1,5 +1,10 @@ package org.opensearch.ml.common.transport.register; +import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; + +import java.io.IOException; + import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -7,11 +12,6 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; - -import static org.junit.Assert.*; -import static org.junit.Assert.assertEquals; - public class MLRegisterModelResponseTest { private String taskId; @@ -49,8 +49,7 @@ public void testToXContent() throws IOException { assertNotNull(builder); String jsonStr = builder.toString(); // Verify the results - assertEquals("{\"task_id\":\"test_id\"," + - "\"status\":\"test\"}", jsonStr); + assertEquals("{\"task_id\":\"test_id\"," + "\"status\":\"test\"}", jsonStr); } @Test @@ -63,7 +62,6 @@ public void testToXContent_withModelId() throws IOException { assertNotNull(builder); String jsonStr = builder.toString(); // Verify the results - assertEquals("{\"task_id\":\"test_id\"," + - "\"status\":\"test\"," + "\"model_id\":\"model_id\"}", jsonStr); + assertEquals("{\"task_id\":\"test_id\"," + "\"status\":\"test\"," + "\"model_id\":\"model_id\"}", jsonStr); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpInputTest.java index 7a728a1e41..354017e30e 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpInputTest.java @@ -1,8 +1,6 @@ package org.opensearch.ml.common.transport.sync; -import org.junit.Test; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; +import static org.junit.Assert.*; import java.io.IOException; import java.util.HashMap; @@ -10,18 +8,20 @@ import java.util.Map; import java.util.Set; -import static org.junit.Assert.*; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; public class MLSyncUpInputTest { - @Test public void testConstructorSerialization_SuccessWithNullFields() throws IOException { - MLSyncUpInput syncUpInputWithNullFields = MLSyncUpInput.builder() - .getDeployedModels(true) - .clearRoutingTable(true) - .syncRunningDeployModelTasks(true) - .build(); + MLSyncUpInput syncUpInputWithNullFields = MLSyncUpInput + .builder() + .getDeployedModels(true) + .clearRoutingTable(true) + .syncRunningDeployModelTasks(true) + .build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); syncUpInputWithNullFields.writeTo(bytesStreamOutput); @@ -41,34 +41,47 @@ public void testConstructorSerialization_SuccessWithFullFields() throws IOExcept Map> modelRoutingTable = new HashMap<>(); Map> runningDeployModelTasks = new HashMap<>(); - MLSyncUpInput syncUpInput = MLSyncUpInput.builder() - .getDeployedModels(true) - .addedWorkerNodes(addedWorkerNodes) - .removedWorkerNodes(removedWorkerNodes) - .modelRoutingTable(modelRoutingTable) - .runningDeployModelTasks(runningDeployModelTasks) - .clearRoutingTable(true) - .syncRunningDeployModelTasks(true) - .build(); + MLSyncUpInput syncUpInput = MLSyncUpInput + .builder() + .getDeployedModels(true) + .addedWorkerNodes(addedWorkerNodes) + .removedWorkerNodes(removedWorkerNodes) + .modelRoutingTable(modelRoutingTable) + .runningDeployModelTasks(runningDeployModelTasks) + .clearRoutingTable(true) + .syncRunningDeployModelTasks(true) + .build(); Set modelRoutingTableSet = new HashSet<>(); Set runningDeployModelTaskSet = new HashSet<>(); modelRoutingTableSet.add("modelRoutingTable1"); runningDeployModelTaskSet.add("runningDeployModelTask1"); - addedWorkerNodes.put("addedWorkerNodesKey1", new String [] {"addedWorkerNode1"}); - removedWorkerNodes.put("removedWorkerNodesKey1", new String [] {"removedWorkerNode1"}); - modelRoutingTable.put("modelRoutingTableKey1",modelRoutingTableSet); - runningDeployModelTasks.put("runningDeployModelTaskKey1",runningDeployModelTaskSet); + addedWorkerNodes.put("addedWorkerNodesKey1", new String[] { "addedWorkerNode1" }); + removedWorkerNodes.put("removedWorkerNodesKey1", new String[] { "removedWorkerNode1" }); + modelRoutingTable.put("modelRoutingTableKey1", modelRoutingTableSet); + runningDeployModelTasks.put("runningDeployModelTaskKey1", runningDeployModelTaskSet); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); syncUpInput.writeTo(bytesStreamOutput); StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); MLSyncUpInput parsedInput = new MLSyncUpInput(streamInput); - assertArrayEquals(syncUpInput.getAddedWorkerNodes().get("addedWorkerNodesKey1"), parsedInput.getAddedWorkerNodes().get("addedWorkerNodesKey1")); - assertArrayEquals(syncUpInput.getRemovedWorkerNodes().get("removedWorkerNodesKey1"), parsedInput.getRemovedWorkerNodes().get("removedWorkerNodesKey1")); - assertEquals(syncUpInput.getModelRoutingTable().get("modelRoutingTableKey1"), parsedInput.getModelRoutingTable().get("modelRoutingTableKey1")); - assertEquals(syncUpInput.getRunningDeployModelTasks().get("runningDeployModelTaskKey1"), parsedInput.getRunningDeployModelTasks().get("runningDeployModelTaskKey1")); + assertArrayEquals( + syncUpInput.getAddedWorkerNodes().get("addedWorkerNodesKey1"), + parsedInput.getAddedWorkerNodes().get("addedWorkerNodesKey1") + ); + assertArrayEquals( + syncUpInput.getRemovedWorkerNodes().get("removedWorkerNodesKey1"), + parsedInput.getRemovedWorkerNodes().get("removedWorkerNodesKey1") + ); + assertEquals( + syncUpInput.getModelRoutingTable().get("modelRoutingTableKey1"), + parsedInput.getModelRoutingTable().get("modelRoutingTableKey1") + ); + assertEquals( + syncUpInput.getRunningDeployModelTasks().get("runningDeployModelTaskKey1"), + parsedInput.getRunningDeployModelTasks().get("runningDeployModelTaskKey1") + ); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeRequestTest.java index 12d135a9b3..cea291bf23 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeRequestTest.java @@ -1,26 +1,20 @@ package org.opensearch.ml.common.transport.sync; +import static org.junit.Assert.*; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; -import org.opensearch.Version; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.transport.TransportAddress; - -import java.io.IOException; -import java.net.InetAddress; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.Set; - -import static org.junit.Assert.*; -import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; - @RunWith(MockitoJUnitRunner.class) public class MLSyncUpNodeRequestTest { @@ -44,23 +38,22 @@ public void setUp() throws Exception { Map> modelRoutingTable = new HashMap<>(); Map> runningDeployModelTasks = new HashMap<>(); - syncUpInput = MLSyncUpInput.builder() - .getDeployedModels(true) - .addedWorkerNodes(addedWorkerNodes) - .removedWorkerNodes(removedWorkerNodes) - .modelRoutingTable(modelRoutingTable) - .runningDeployModelTasks(runningDeployModelTasks) - .clearRoutingTable(true) - .syncRunningDeployModelTasks(true) - .build(); + syncUpInput = MLSyncUpInput + .builder() + .getDeployedModels(true) + .addedWorkerNodes(addedWorkerNodes) + .removedWorkerNodes(removedWorkerNodes) + .modelRoutingTable(modelRoutingTable) + .runningDeployModelTasks(runningDeployModelTasks) + .clearRoutingTable(true) + .syncRunningDeployModelTasks(true) + .build(); } @Test public void testConstructorSerialization1() throws IOException { - String [] nodeIds = {"id1", "id2", "id3"}; - MLSyncUpNodeRequest syncUpNodeRequest = new MLSyncUpNodeRequest( - new MLSyncUpNodesRequest(nodeIds, syncUpInput) - ); + String[] nodeIds = { "id1", "id2", "id3" }; + MLSyncUpNodeRequest syncUpNodeRequest = new MLSyncUpNodeRequest(new MLSyncUpNodesRequest(nodeIds, syncUpInput)); BytesStreamOutput output = new BytesStreamOutput(); syncUpNodeRequest.writeTo(output); @@ -74,10 +67,8 @@ public void testConstructorSerialization1() throws IOException { @Test public void testConstructorSerialization2() { - DiscoveryNode [] nodeIds = {localNode1, localNode2, localNode3}; - MLSyncUpNodeRequest syncUpNodeRequest = new MLSyncUpNodeRequest( - new MLSyncUpNodesRequest(nodeIds, syncUpInput) - ); + DiscoveryNode[] nodeIds = { localNode1, localNode2, localNode3 }; + MLSyncUpNodeRequest syncUpNodeRequest = new MLSyncUpNodeRequest(new MLSyncUpNodesRequest(nodeIds, syncUpInput)); assertEquals(3, syncUpNodeRequest.getSyncUpNodesRequest().concreteNodes().length); assertTrue(syncUpNodeRequest.getSyncUpNodesRequest().getSyncUpInput().isClearRoutingTable()); @@ -87,9 +78,7 @@ public void testConstructorSerialization2() { @Test public void testConstructorSerialization3() { - MLSyncUpNodeRequest syncUpNodeRequest = new MLSyncUpNodeRequest( - new MLSyncUpNodesRequest(localNode1, localNode2, localNode3) - ); + MLSyncUpNodeRequest syncUpNodeRequest = new MLSyncUpNodeRequest(new MLSyncUpNodesRequest(localNode1, localNode2, localNode3)); syncUpNodeRequest.getSyncUpNodesRequest().getSyncUpInput().setClearRoutingTable(true); syncUpNodeRequest.getSyncUpNodesRequest().getSyncUpInput().setSyncRunningDeployModelTasks(true); syncUpNodeRequest.getSyncUpNodesRequest().getSyncUpInput().setClearRoutingTable(true); @@ -102,19 +91,26 @@ public void testConstructorSerialization3() { @Test public void testConstructorFromInputStream() throws IOException { - String [] nodeIds = {"id1", "id2", "id3"}; - MLSyncUpNodeRequest syncUpNodeRequest = new MLSyncUpNodeRequest( - new MLSyncUpNodesRequest(nodeIds, syncUpInput) - ); + String[] nodeIds = { "id1", "id2", "id3" }; + MLSyncUpNodeRequest syncUpNodeRequest = new MLSyncUpNodeRequest(new MLSyncUpNodesRequest(nodeIds, syncUpInput)); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); syncUpNodeRequest.writeTo(bytesStreamOutput); StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); MLSyncUpNodeRequest parsedNodeRequest = new MLSyncUpNodeRequest(streamInput); assertEquals(3, parsedNodeRequest.getSyncUpNodesRequest().nodesIds().length); - assertEquals(parsedNodeRequest.getSyncUpNodesRequest().getSyncUpInput().isClearRoutingTable(), syncUpNodeRequest.getSyncUpNodesRequest().getSyncUpInput().isClearRoutingTable()); - assertEquals(parsedNodeRequest.getSyncUpNodesRequest().getSyncUpInput().isClearRoutingTable(), syncUpNodeRequest.getSyncUpNodesRequest().getSyncUpInput().isSyncRunningDeployModelTasks()); - assertEquals(parsedNodeRequest.getSyncUpNodesRequest().getSyncUpInput().isClearRoutingTable(), syncUpNodeRequest.getSyncUpNodesRequest().getSyncUpInput().isClearRoutingTable()); + assertEquals( + parsedNodeRequest.getSyncUpNodesRequest().getSyncUpInput().isClearRoutingTable(), + syncUpNodeRequest.getSyncUpNodesRequest().getSyncUpInput().isClearRoutingTable() + ); + assertEquals( + parsedNodeRequest.getSyncUpNodesRequest().getSyncUpInput().isClearRoutingTable(), + syncUpNodeRequest.getSyncUpNodesRequest().getSyncUpInput().isSyncRunningDeployModelTasks() + ); + assertEquals( + parsedNodeRequest.getSyncUpNodesRequest().getSyncUpInput().isClearRoutingTable(), + syncUpNodeRequest.getSyncUpNodesRequest().getSyncUpInput().isClearRoutingTable() + ); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponseTest.java index 8599002354..ff0ea232ff 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponseTest.java @@ -1,5 +1,12 @@ package org.opensearch.ml.common.transport.sync; +import static org.junit.Assert.*; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.Collections; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -9,38 +16,39 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.transport.TransportAddress; -import java.io.IOException; -import java.net.InetAddress; -import java.util.Collections; - -import static org.junit.Assert.*; -import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; - @RunWith(MockitoJUnitRunner.class) public class MLSyncUpNodeResponseTest { private DiscoveryNode localNode; private final String modelStatus = "modelStatus"; - private final String[] loadedModelIds = {"loadedModelIds"}; - private final String[] runningLoadModelTaskIds = {"runningLoadModelTaskIds"}; - private final String[] runningLoadModelIds = {"modelid1"}; + private final String[] loadedModelIds = { "loadedModelIds" }; + private final String[] runningLoadModelTaskIds = { "runningLoadModelTaskIds" }; + private final String[] runningLoadModelIds = { "modelid1" }; + + private final String[] expiredModelIds = { "modelExpired" }; - private final String[] expiredModelIds = {"modelExpired"}; @Before public void setUp() throws Exception { localNode = new DiscoveryNode( - "foo0", - "foo0", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT + "foo0", + "foo0", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT ); } @Test public void testSerializationDeserialization() throws IOException { - MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds, expiredModelIds); + MLSyncUpNodeResponse response = new MLSyncUpNodeResponse( + localNode, + modelStatus, + loadedModelIds, + runningLoadModelIds, + runningLoadModelTaskIds, + expiredModelIds + ); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); MLSyncUpNodeResponse newResponse = new MLSyncUpNodeResponse(output.bytes().streamInput()); @@ -53,7 +61,14 @@ public void testSerializationDeserialization() throws IOException { @Test public void testReadProfile() throws IOException { - MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds, expiredModelIds); + MLSyncUpNodeResponse response = new MLSyncUpNodeResponse( + localNode, + modelStatus, + loadedModelIds, + runningLoadModelIds, + runningLoadModelTaskIds, + expiredModelIds + ); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); MLSyncUpNodeResponse newResponse = MLSyncUpNodeResponse.readStats(output.bytes().streamInput()); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodesResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodesResponseTest.java index 6603f17cbb..80d2608669 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodesResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodesResponseTest.java @@ -1,5 +1,11 @@ package org.opensearch.ml.common.transport.sync; +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.util.*; +import java.util.List; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -8,13 +14,6 @@ import org.opensearch.cluster.ClusterName; import org.opensearch.common.io.stream.BytesStreamOutput; -import java.io.IOException; -import java.util.*; - -import static org.junit.Assert.assertEquals; - -import java.util.List; - @RunWith(MockitoJUnitRunner.class) public class MLSyncUpNodesResponseTest { diff --git a/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpResponseTest.java index c183841431..18e74d0db3 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpResponseTest.java @@ -1,5 +1,10 @@ package org.opensearch.ml.common.transport.sync; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import java.io.IOException; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -9,11 +14,6 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; - @RunWith(MockitoJUnitRunner.class) public class MLSyncUpResponseTest { private String status; diff --git a/common/src/test/java/org/opensearch/ml/common/transport/task/MLTaskGetRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/task/MLTaskGetRequestTest.java index 5d4e300904..0219c68e0e 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/task/MLTaskGetRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/task/MLTaskGetRequestTest.java @@ -1,5 +1,11 @@ package org.opensearch.ml.common.transport.task; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; + +import java.io.IOException; +import java.io.UncheckedIOException; + import org.junit.Before; import org.junit.Test; import org.opensearch.action.ActionRequest; @@ -7,12 +13,6 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; - public class MLTaskGetRequestTest { private String taskId; @@ -23,8 +23,7 @@ public void setUp() { @Test public void writeTo_Success() throws IOException { - MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest.builder() - .taskId(taskId).build(); + MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest.builder().taskId(taskId).build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); mlTaskGetRequest.writeTo(bytesStreamOutput); MLTaskGetRequest parsedModel = new MLTaskGetRequest(bytesStreamOutput.bytes().streamInput()); @@ -41,8 +40,7 @@ public void validate_Exception_NullModelId() { @Test public void fromActionRequest_Success() { - MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest.builder() - .taskId(taskId).build(); + MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest.builder().taskId(taskId).build(); ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { diff --git a/common/src/test/java/org/opensearch/ml/common/transport/task/MLTaskGetResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/task/MLTaskGetResponseTest.java index 8eeff0916e..efc3e859ac 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/task/MLTaskGetResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/task/MLTaskGetResponseTest.java @@ -1,5 +1,12 @@ package org.opensearch.ml.common.transport.task; +import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.time.Instant; +import java.util.Arrays; + import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -7,40 +14,34 @@ import org.opensearch.commons.authuser.User; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.ml.common.dataset.MLInputDataType; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; - -import java.io.IOException; -import java.time.Instant; -import java.util.Arrays; - -import static org.junit.Assert.*; -import static org.junit.Assert.assertEquals; +import org.opensearch.ml.common.dataset.MLInputDataType; public class MLTaskGetResponseTest { MLTask mlTask; @Before public void setUp() { - mlTask = MLTask.builder() - .taskId("id") - .modelId("model id") - .taskType(MLTaskType.EXECUTION) - .functionName(FunctionName.LINEAR_REGRESSION) - .state(MLTaskState.CREATED) - .inputType(MLInputDataType.DATA_FRAME) - .progress(1.3f) - .outputIndex("some index") - .workerNodes(Arrays.asList("some node")) - .createTime(Instant.ofEpochMilli(123)) - .lastUpdateTime(Instant.ofEpochMilli(123)) - .error("error") - .user(new User()) - .async(true) - .build(); + mlTask = MLTask + .builder() + .taskId("id") + .modelId("model id") + .taskType(MLTaskType.EXECUTION) + .functionName(FunctionName.LINEAR_REGRESSION) + .state(MLTaskState.CREATED) + .inputType(MLInputDataType.DATA_FRAME) + .progress(1.3f) + .outputIndex("some index") + .workerNodes(Arrays.asList("some node")) + .createTime(Instant.ofEpochMilli(123)) + .lastUpdateTime(Instant.ofEpochMilli(123)) + .error("error") + .user(new User()) + .async(true) + .build(); } @Test @@ -71,19 +72,22 @@ public void toXContentTest() throws IOException { mlTaskGetResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); assertNotNull(builder); String jsonStr = builder.toString(); - assertEquals("{\"task_id\":\"id\"," + - "\"model_id\":\"model id\"," + - "\"task_type\":\"EXECUTION\"," + - "\"function_name\":\"LINEAR_REGRESSION\"," + - "\"state\":\"CREATED\"," + - "\"input_type\":\"DATA_FRAME\"," + - "\"progress\":1.3," + - "\"output_index\":\"some index\"," + - "\"worker_node\":[\"some node\"]," + - "\"create_time\":123," + - "\"last_update_time\":123," + - "\"error\":\"error\"," + - "\"user\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null}," + - "\"is_async\":true}", jsonStr); + assertEquals( + "{\"task_id\":\"id\"," + + "\"model_id\":\"model id\"," + + "\"task_type\":\"EXECUTION\"," + + "\"function_name\":\"LINEAR_REGRESSION\"," + + "\"state\":\"CREATED\"," + + "\"input_type\":\"DATA_FRAME\"," + + "\"progress\":1.3," + + "\"output_index\":\"some index\"," + + "\"worker_node\":[\"some node\"]," + + "\"create_time\":123," + + "\"last_update_time\":123," + + "\"error\":\"error\"," + + "\"user\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null}," + + "\"is_async\":true}", + jsonStr + ); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolGetRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolGetRequestTests.java index 6e62d99507..e9ae797e72 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolGetRequestTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolGetRequestTests.java @@ -4,6 +4,14 @@ */ package org.opensearch.ml.common.transport.tools; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.List; + import org.junit.Before; import org.junit.Test; import org.opensearch.action.ActionRequest; @@ -12,35 +20,25 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.ml.common.ToolMetadata; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.util.ArrayList; -import java.util.List; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; - public class MLToolGetRequestTests { private List toolMetadataList; @Before public void setUp() { toolMetadataList = new ArrayList<>(); - ToolMetadata wikipediaTool = ToolMetadata.builder() - .name("MathTool") - .description("Use this tool to search general knowledge on wikipedia.") - .type("MathTool") - .version("test") - .build(); + ToolMetadata wikipediaTool = ToolMetadata + .builder() + .name("MathTool") + .description("Use this tool to search general knowledge on wikipedia.") + .type("MathTool") + .version("test") + .build(); toolMetadataList.add(wikipediaTool); } @Test public void writeTo_success() throws IOException { - MLToolGetRequest mlToolGetRequest = MLToolGetRequest.builder() - .toolName("MathTool") - .toolMetadataList(toolMetadataList) - .build(); + MLToolGetRequest mlToolGetRequest = MLToolGetRequest.builder().toolName("MathTool").toolMetadataList(toolMetadataList).build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); mlToolGetRequest.writeTo(bytesStreamOutput); @@ -51,10 +49,7 @@ public void writeTo_success() throws IOException { @Test public void fromActionRequest_success() { - MLToolGetRequest mlToolGetRequest = MLToolGetRequest.builder() - .toolName("MathTool") - .toolMetadataList(toolMetadataList) - .build(); + MLToolGetRequest mlToolGetRequest = MLToolGetRequest.builder().toolName("MathTool").toolMetadataList(toolMetadataList).build(); ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { @@ -95,4 +90,4 @@ public void validate_Exception_NullToolName() { assertEquals("Validation Failed: 1: Tool name can't be null;", exception.getMessage()); } -} \ No newline at end of file +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolGetResponseTests.java b/common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolGetResponseTests.java index 6ec682dcc2..82287aba99 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolGetResponseTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolGetResponseTests.java @@ -4,25 +4,21 @@ */ package org.opensearch.ml.common.transport.tools; +import static org.junit.Assert.*; + +import java.io.IOException; +import java.io.UncheckedIOException; + import org.junit.Before; import org.junit.Test; -import org.opensearch.core.action.ActionResponse; -import org.opensearch.core.common.Strings; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; -import org.opensearch.core.common.Strings; +import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.ToolMetadata; -import org.opensearch.ml.common.transport.model.MLModelGetResponse; - -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.junit.Assert.*; public class MLToolGetResponseTests { ToolMetadata toolMetadata; @@ -31,12 +27,13 @@ public class MLToolGetResponseTests { @Before public void setUp() { - toolMetadata = ToolMetadata.builder() - .name("MathTool") - .description("Use this tool to calculate any math problem.") - .type("MathTool") - .version(null) - .build(); + toolMetadata = ToolMetadata + .builder() + .name("MathTool") + .description("Use this tool to calculate any math problem.") + .type("MathTool") + .version(null) + .build(); mlToolGetResponse = MLToolGetResponse.builder().toolMetadata(toolMetadata).build(); } @@ -57,7 +54,10 @@ public void toXContentTest() throws IOException { mlToolGetResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); assertNotNull(builder); String jsonStr = builder.toString(); - assertEquals("{\"name\":\"MathTool\",\"description\":\"Use this tool to calculate any math problem.\",\"type\":\"MathTool\",\"version\":\"undefined\"}", jsonStr); + assertEquals( + "{\"name\":\"MathTool\",\"description\":\"Use this tool to calculate any math problem.\",\"type\":\"MathTool\",\"version\":\"undefined\"}", + jsonStr + ); } @Test @@ -89,4 +89,4 @@ public void writeTo(StreamOutput out) throws IOException { }; MLToolGetResponse.fromActionResponse(actionResponse); } -} \ No newline at end of file +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolsListRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolsListRequestTests.java index 8aedf99970..3b15f9ff1b 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolsListRequestTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolsListRequestTests.java @@ -5,6 +5,15 @@ package org.opensearch.ml.common.transport.tools; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.List; + import org.junit.Before; import org.junit.Test; import org.opensearch.action.ActionRequest; @@ -12,13 +21,6 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.ml.common.ToolMetadata; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.util.ArrayList; -import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertNull; public class MLToolsListRequestTests { private List toolMetadataList; @@ -26,20 +28,20 @@ public class MLToolsListRequestTests { @Before public void setUp() { toolMetadataList = new ArrayList<>(); - ToolMetadata wikipediaTool = ToolMetadata.builder() - .name("WikipediaTool") - .description("Use this tool to search general knowledge on wikipedia.") - .type("WikipediaTool") - .version(null) - .build(); + ToolMetadata wikipediaTool = ToolMetadata + .builder() + .name("WikipediaTool") + .description("Use this tool to search general knowledge on wikipedia.") + .type("WikipediaTool") + .version(null) + .build(); toolMetadataList.add(wikipediaTool); } + @Test public void writeTo_success() throws IOException { - MLToolsListRequest mlToolsListRequest = MLToolsListRequest.builder() - .toolMetadataList(toolMetadataList) - .build(); + MLToolsListRequest mlToolsListRequest = MLToolsListRequest.builder().toolMetadataList(toolMetadataList).build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); mlToolsListRequest.writeTo(bytesStreamOutput); MLToolsListRequest parsedToolMetadata = new MLToolsListRequest(bytesStreamOutput.bytes().streamInput()); @@ -73,6 +75,7 @@ public void fromActionRequest_IOException() { public ActionRequestValidationException validate() { return null; } + @Override public void writeTo(StreamOutput out) throws IOException { throw new IOException("test"); @@ -83,10 +86,9 @@ public void writeTo(StreamOutput out) throws IOException { @Test public void fromActionRequest_Success() { - MLToolsListRequest mlToolsListRequest = MLToolsListRequest.builder() - .toolMetadataList(toolMetadataList).build(); + MLToolsListRequest mlToolsListRequest = MLToolsListRequest.builder().toolMetadataList(toolMetadataList).build(); - ActionRequest actionRequest = new ActionRequest() { + ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { return null; @@ -108,4 +110,4 @@ public void testValidate() { MLToolsListRequest request = MLToolsListRequest.builder().build(); assertNull(request.validate()); } -} \ No newline at end of file +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolsListResponseTests.java b/common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolsListResponseTests.java index c9ab879731..436f1b8d18 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolsListResponseTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/tools/MLToolsListResponseTests.java @@ -5,11 +5,16 @@ package org.opensearch.ml.common.transport.tools; +import static org.junit.Assert.*; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.List; + import org.junit.Before; import org.junit.Test; -// import org.opensearch.common.Strings; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamOutput; @@ -17,12 +22,6 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.ToolMetadata; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.util.ArrayList; -import java.util.List; - -import static org.junit.Assert.*; public class MLToolsListResponseTests { List toolMetadataList; @@ -32,18 +31,20 @@ public class MLToolsListResponseTests { @Before public void setUp() { toolMetadataList = new ArrayList<>(); - ToolMetadata searchWikipediaTool = ToolMetadata.builder() - .name("SearchWikipediaTool") - .description("Useful when you need to use this tool to search general knowledge on wikipedia.") - .type("SearchWikipediaTool") - .version(null) - .build(); - ToolMetadata toolMetadata = ToolMetadata.builder() - .name("MathTool") - .description("Use this tool to calculate any math problem.") - .type("MathTool") - .version("test") - .build(); + ToolMetadata searchWikipediaTool = ToolMetadata + .builder() + .name("SearchWikipediaTool") + .description("Useful when you need to use this tool to search general knowledge on wikipedia.") + .type("SearchWikipediaTool") + .version(null) + .build(); + ToolMetadata toolMetadata = ToolMetadata + .builder() + .name("MathTool") + .description("Use this tool to calculate any math problem.") + .type("MathTool") + .version("test") + .build(); toolMetadataList.add(searchWikipediaTool); toolMetadataList.add(toolMetadata); @@ -67,14 +68,20 @@ public void toXContentTest() throws IOException { mlToolsListResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); assertNotNull(builder); String jsonStr = builder.toString(); - assertEquals("[{\"name\":\"SearchWikipediaTool\",\"description\":\"Useful when you need to use this tool to search general knowledge on wikipedia.\",\"type\":\"SearchWikipediaTool\",\"version\":\"undefined\"},{\"name\":\"MathTool\",\"description\":\"Use this tool to calculate any math problem.\",\"type\":\"MathTool\",\"version\":\"test\"}]", jsonStr); + assertEquals( + "[{\"name\":\"SearchWikipediaTool\",\"description\":\"Useful when you need to use this tool to search general knowledge on wikipedia.\",\"type\":\"SearchWikipediaTool\",\"version\":\"undefined\"},{\"name\":\"MathTool\",\"description\":\"Use this tool to calculate any math problem.\",\"type\":\"MathTool\",\"version\":\"test\"}]", + jsonStr + ); } @Test public void fromActionResponseWithMLToolsListResponse_Success() { MLToolsListResponse mlToolsListResponseFromActionResponse = MLToolsListResponse.fromActionResponse(mlToolsListResponse); assertSame(mlToolsListResponse, mlToolsListResponseFromActionResponse); - assertEquals(mlToolsListResponse.getToolMetadataList().get(0).getName(), mlToolsListResponseFromActionResponse.getToolMetadataList().get(0).getName()); + assertEquals( + mlToolsListResponse.getToolMetadataList().get(0).getName(), + mlToolsListResponseFromActionResponse.getToolMetadataList().get(0).getName() + ); } @Test @@ -86,7 +93,10 @@ public void writeTo(StreamOutput out) throws IOException { } }; MLToolsListResponse mlToolsListResponseFromActionResponse = MLToolsListResponse.fromActionResponse(actionResponse); - assertEquals(mlToolsListResponse.getToolMetadataList().get(0).getName(), mlToolsListResponseFromActionResponse.getToolMetadataList().get(0).getName()); + assertEquals( + mlToolsListResponse.getToolMetadataList().get(0).getName(), + mlToolsListResponseFromActionResponse.getToolMetadataList().get(0).getName() + ); } @Test(expected = UncheckedIOException.class) @@ -99,4 +109,4 @@ public void writeTo(StreamOutput out) throws IOException { }; MLToolsListResponse.fromActionResponse(actionResponse); } -} \ No newline at end of file +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequestTest.java index 7c3e9eaa06..3de755892e 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequestTest.java @@ -5,29 +5,29 @@ package org.opensearch.ml.common.transport.training; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Collections; +import java.util.HashMap; + import org.junit.Before; import org.junit.Test; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataframe.DataFrameBuilder; import org.opensearch.ml.common.dataset.DataFrameInputDataset; import org.opensearch.ml.common.dataset.MLInputDataType; -import org.opensearch.ml.common.input.parameter.clustering.KMeansParams; -import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.input.MLInput; - -import java.io.IOException; -import java.io.UncheckedIOException; -import java.util.Collections; -import java.util.HashMap; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; +import org.opensearch.ml.common.input.parameter.clustering.KMeansParams; public class MLTrainingTaskRequestTest { @@ -35,14 +35,17 @@ public class MLTrainingTaskRequestTest { @Before public void setUp() { - DataFrame dataFrame = DataFrameBuilder.load(Collections.singletonList(new HashMap() {{ - put("key1", 2.0D); - }})); - mlInput = MLInput.builder() - .algorithm(FunctionName.KMEANS) - .parameters(KMeansParams.builder().centroids(1).build()) - .inputDataset(DataFrameInputDataset.builder().dataFrame(dataFrame).build()) - .build(); + DataFrame dataFrame = DataFrameBuilder.load(Collections.singletonList(new HashMap() { + { + put("key1", 2.0D); + } + })); + mlInput = MLInput + .builder() + .algorithm(FunctionName.KMEANS) + .parameters(KMeansParams.builder().centroids(1).build()) + .inputDataset(DataFrameInputDataset.builder().dataFrame(dataFrame).build()) + .build(); } @Test @@ -53,16 +56,13 @@ public void validate_Success() { @Test public void validate_SuccessWithBuilder() { - MLTrainingTaskRequest request = MLTrainingTaskRequest.builder() - .mlInput(mlInput) - .build(); + MLTrainingTaskRequest request = MLTrainingTaskRequest.builder().mlInput(mlInput).build(); assertNull(request.validate()); } @Test public void validate_Exception_NullMLInput() { - MLTrainingTaskRequest request = MLTrainingTaskRequest.builder() - .build(); + MLTrainingTaskRequest request = MLTrainingTaskRequest.builder().build(); ActionRequestValidationException exception = request.validate(); assertEquals("Validation Failed: 1: ML input can't be null;", exception.getMessage()); } @@ -70,18 +70,14 @@ public void validate_Exception_NullMLInput() { @Test public void validate_Exception_NullInputDataInMLInput() { mlInput.setInputDataset(null); - MLTrainingTaskRequest request = MLTrainingTaskRequest.builder() - .mlInput(mlInput) - .build(); + MLTrainingTaskRequest request = MLTrainingTaskRequest.builder().mlInput(mlInput).build(); ActionRequestValidationException exception = request.validate(); assertEquals("Validation Failed: 1: input data can't be null;", exception.getMessage()); } @Test public void writeTo() throws IOException { - MLTrainingTaskRequest request = MLTrainingTaskRequest.builder() - .mlInput(mlInput) - .build(); + MLTrainingTaskRequest request = MLTrainingTaskRequest.builder().mlInput(mlInput).build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); request.writeTo(bytesStreamOutput); request = new MLTrainingTaskRequest(bytesStreamOutput.bytes().streamInput()); @@ -92,17 +88,13 @@ public void writeTo() throws IOException { @Test public void fromActionRequest_WithMLTrainingTaskRequest() { - MLTrainingTaskRequest request = MLTrainingTaskRequest.builder() - .mlInput(mlInput) - .build(); + MLTrainingTaskRequest request = MLTrainingTaskRequest.builder().mlInput(mlInput).build(); assertSame(request, MLTrainingTaskRequest.fromActionRequest(request)); } @Test public void fromActionRequest_WithNonMLTrainingTaskRequest() { - MLTrainingTaskRequest request = MLTrainingTaskRequest.builder() - .mlInput(mlInput) - .build(); + MLTrainingTaskRequest request = MLTrainingTaskRequest.builder().mlInput(mlInput).build(); ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { @@ -136,4 +128,4 @@ public void writeTo(StreamOutput out) throws IOException { }; MLTrainingTaskRequest.fromActionRequest(actionRequest); } -} \ No newline at end of file +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/training/MLTrainingTaskResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/training/MLTrainingTaskResponseTest.java index cca7a158cf..9b5a033d7a 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/training/MLTrainingTaskResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/training/MLTrainingTaskResponseTest.java @@ -5,54 +5,45 @@ package org.opensearch.ml.common.transport.training; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertSame; + import java.io.IOException; import java.io.UncheckedIOException; import org.junit.Test; -import org.opensearch.core.action.ActionResponse; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.ml.common.output.MLTrainingOutput; import org.opensearch.ml.common.transport.MLTaskResponse; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertSame; - public class MLTrainingTaskResponseTest { @Test public void writeTo() throws IOException { - MLTrainingOutput output = MLTrainingOutput.builder().status("success") - .modelId("taskId").build(); - MLTaskResponse response = MLTaskResponse.builder() - .output(output) - .build(); + MLTrainingOutput output = MLTrainingOutput.builder().status("success").modelId("taskId").build(); + MLTaskResponse response = MLTaskResponse.builder().output(output).build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); response.writeTo(bytesStreamOutput); response = new MLTaskResponse(bytesStreamOutput.bytes().streamInput()); - MLTrainingOutput modelTrainingOutput = (MLTrainingOutput)response.getOutput(); + MLTrainingOutput modelTrainingOutput = (MLTrainingOutput) response.getOutput(); assertEquals("success", modelTrainingOutput.getStatus()); assertEquals("taskId", modelTrainingOutput.getModelId()); } @Test public void fromActionResponse_Success_WithMLTrainingTaskResponse() { - MLTrainingOutput output = MLTrainingOutput.builder().status("success") - .modelId("taskId").build(); - MLTaskResponse response = MLTaskResponse.builder() - .output(output) - .build(); + MLTrainingOutput output = MLTrainingOutput.builder().status("success").modelId("taskId").build(); + MLTaskResponse response = MLTaskResponse.builder().output(output).build(); assertSame(response, MLTaskResponse.fromActionResponse(response)); } @Test public void fromActionResponse_Success_WithNonMLTrainingTaskResponse() { - MLTrainingOutput output = MLTrainingOutput.builder().status("success") - .modelId("taskId").build(); - MLTaskResponse response = MLTaskResponse.builder() - .output(output) - .build(); + MLTrainingOutput output = MLTrainingOutput.builder().status("success").modelId("taskId").build(); + MLTaskResponse response = MLTaskResponse.builder().output(output).build(); ActionResponse actionResponse = new ActionResponse() { @Override public void writeTo(StreamOutput out) throws IOException { @@ -62,8 +53,8 @@ public void writeTo(StreamOutput out) throws IOException { MLTaskResponse result = MLTaskResponse.fromActionResponse(actionResponse); assertNotSame(response, result); - MLTrainingOutput modelTrainingOutput = (MLTrainingOutput)response.getOutput(); - MLTrainingOutput resultModelTrainingOutput = (MLTrainingOutput)result.getOutput(); + MLTrainingOutput modelTrainingOutput = (MLTrainingOutput) response.getOutput(); + MLTrainingOutput resultModelTrainingOutput = (MLTrainingOutput) result.getOutput(); assertEquals(modelTrainingOutput.getStatus(), resultModelTrainingOutput.getStatus()); assertEquals(modelTrainingOutput.getModelId(), resultModelTrainingOutput.getModelId()); } @@ -79,4 +70,4 @@ public void writeTo(StreamOutput out) throws IOException { MLTaskResponse.fromActionResponse(actionResponse); } -} \ No newline at end of file +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelInputTest.java index 44e820dc37..ce28d3fdf2 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelInputTest.java @@ -1,35 +1,32 @@ package org.opensearch.ml.common.transport.undeploy; +import static org.junit.Assert.*; + +import java.io.IOException; +import java.util.Collections; + import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.*; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.search.SearchModule; -import java.io.IOException; -import java.util.Collections; - -import static org.junit.Assert.*; - public class MLUndeployModelInputTest { private MLUndeployModelInput input; - private final String [] modelIds = new String [] {"modelId1","modelId2","modelId3"}; - private final String [] nodeIds = new String [] {"nodeId1","nodeId2","nodeId3"}; - private final String expectedInputStr = "{\"model_ids\":[\"modelId1\",\"modelId2\",\"modelId3\"]," + - "\"node_ids\":[\"nodeId1\",\"nodeId2\",\"nodeId3\"]}"; + private final String[] modelIds = new String[] { "modelId1", "modelId2", "modelId3" }; + private final String[] nodeIds = new String[] { "nodeId1", "nodeId2", "nodeId3" }; + private final String expectedInputStr = "{\"model_ids\":[\"modelId1\",\"modelId2\",\"modelId3\"]," + + "\"node_ids\":[\"nodeId1\",\"nodeId2\",\"nodeId3\"]}"; @Before public void setUp() throws Exception { - input = MLUndeployModelInput.builder() - .modelIds(modelIds) - .nodeIds(nodeIds) - .build(); + input = MLUndeployModelInput.builder().modelIds(modelIds).nodeIds(nodeIds).build(); } @Test @@ -43,25 +40,35 @@ public void testToXContent() throws Exception { @Test public void testParse() throws Exception { - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), LoggingDeprecationHandler.INSTANCE, expectedInputStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + LoggingDeprecationHandler.INSTANCE, + expectedInputStr + ); parser.nextToken(); MLUndeployModelInput parsedInput = MLUndeployModelInput.parse(parser); - assertArrayEquals(new String [] {"modelId1","modelId2","modelId3"}, parsedInput.getModelIds()); - assertArrayEquals(new String [] {"nodeId1","nodeId2","nodeId3"}, parsedInput.getNodeIds()); + assertArrayEquals(new String[] { "modelId1", "modelId2", "modelId3" }, parsedInput.getModelIds()); + assertArrayEquals(new String[] { "nodeId1", "nodeId2", "nodeId3" }, parsedInput.getNodeIds()); } @Test public void testParseWithInvalidField() throws Exception { - String withInvalidFieldInputStr = "{\"void\":\"void\"," + - "\"model_ids\":[\"modelId1\",\"modelId2\",\"modelId3\"]," + - "\"node_ids\":[\"nodeId1\",\"nodeId2\",\"nodeId3\"]}"; - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), LoggingDeprecationHandler.INSTANCE, withInvalidFieldInputStr); + String withInvalidFieldInputStr = "{\"void\":\"void\"," + + "\"model_ids\":[\"modelId1\",\"modelId2\",\"modelId3\"]," + + "\"node_ids\":[\"nodeId1\",\"nodeId2\",\"nodeId3\"]}"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + LoggingDeprecationHandler.INSTANCE, + withInvalidFieldInputStr + ); parser.nextToken(); MLUndeployModelInput parsedInput = MLUndeployModelInput.parse(parser); - assertArrayEquals(new String [] {"modelId1","modelId2","modelId3"}, parsedInput.getModelIds()); - assertArrayEquals(new String [] {"nodeId1","nodeId2","nodeId3"}, parsedInput.getNodeIds()); + assertArrayEquals(new String[] { "modelId1", "modelId2", "modelId3" }, parsedInput.getModelIds()); + assertArrayEquals(new String[] { "nodeId1", "nodeId2", "nodeId3" }, parsedInput.getNodeIds()); } @Test @@ -73,5 +80,4 @@ public void readInputStream() throws IOException { assertArrayEquals(input.getModelIds(), parsedInput.getModelIds()); assertArrayEquals(input.getNodeIds(), parsedInput.getNodeIds()); } - } - +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodeResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodeResponseTest.java index 30dc98f0a9..45648be6d5 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodeResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodeResponseTest.java @@ -1,5 +1,14 @@ package org.opensearch.ml.common.transport.undeploy; +import static org.junit.Assert.*; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -10,15 +19,6 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.transport.TransportAddress; -import java.io.IOException; -import java.net.InetAddress; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; - -import static org.junit.Assert.*; -import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; - @RunWith(MockitoJUnitRunner.class) public class MLUndeployModelNodeResponseTest { @@ -30,15 +30,15 @@ public class MLUndeployModelNodeResponseTest { @Before public void setUp() throws Exception { localNode = new DiscoveryNode( - "foo0", - "foo0", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT + "foo0", + "foo0", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT ); modelWorkerNodeCounts = new HashMap<>(); - modelWorkerNodeCounts.put("modelId1", new String[]{"node"}); + modelWorkerNodeCounts.put("modelId1", new String[] { "node" }); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesRequestTest.java index 7323f059f3..6e721ada02 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesRequestTest.java @@ -1,22 +1,17 @@ package org.opensearch.ml.common.transport.undeploy; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +import java.io.IOException; + import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; -import org.opensearch.Version; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.transport.TransportAddress; - -import java.io.IOException; -import java.net.InetAddress; -import java.util.Collections; - -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; @RunWith(MockitoJUnitRunner.class) public class MLUndeployModelNodesRequestTest { @@ -30,16 +25,19 @@ public class MLUndeployModelNodesRequestTest { @Test public void testConstructorSerialization1() throws IOException { - String[] modelIds = {"modelId1", "modelId2", "modelId3"}; - String[] nodeIds = {"nodeId1", "nodeId2", "nodeId3"}; + String[] modelIds = { "modelId1", "modelId2", "modelId3" }; + String[] nodeIds = { "nodeId1", "nodeId2", "nodeId3" }; MLUndeployModelNodeRequest undeployModelNodeRequest = new MLUndeployModelNodeRequest( - new MLUndeployModelNodesRequest(nodeIds, modelIds) + new MLUndeployModelNodesRequest(nodeIds, modelIds) ); BytesStreamOutput output = new BytesStreamOutput(); undeployModelNodeRequest.writeTo(output); - assertArrayEquals(new String[] {"modelId1", "modelId2", "modelId3"}, undeployModelNodeRequest.getMlUndeployModelNodesRequest().getModelIds()); + assertArrayEquals( + new String[] { "modelId1", "modelId2", "modelId3" }, + undeployModelNodeRequest.getMlUndeployModelNodesRequest().getModelIds() + ); } @@ -47,7 +45,7 @@ public void testConstructorSerialization1() throws IOException { public void testConstructorSerialization2() throws IOException { MLUndeployModelNodeRequest undeployModelNodeRequest = new MLUndeployModelNodeRequest( - new MLUndeployModelNodesRequest(localNode1,localNode2) + new MLUndeployModelNodesRequest(localNode1, localNode2) ); assertEquals(2, undeployModelNodeRequest.getMlUndeployModelNodesRequest().concreteNodes().length); @@ -56,11 +54,11 @@ public void testConstructorSerialization2() throws IOException { @Test public void testConstructorFromInputStream() throws IOException { - String[] modelIds = {"modelId1", "modelId2", "modelId3"}; - String[] nodeIds = {"nodeId1", "nodeId2", "nodeId3"}; + String[] modelIds = { "modelId1", "modelId2", "modelId3" }; + String[] nodeIds = { "nodeId1", "nodeId2", "nodeId3" }; MLUndeployModelNodeRequest undeployModelNodeRequest = new MLUndeployModelNodeRequest( - new MLUndeployModelNodesRequest(nodeIds, modelIds) + new MLUndeployModelNodesRequest(nodeIds, modelIds) ); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); undeployModelNodeRequest.writeTo(bytesStreamOutput); @@ -68,7 +66,10 @@ public void testConstructorFromInputStream() throws IOException { StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); MLUndeployModelNodeRequest parsedNodeRequest = new MLUndeployModelNodeRequest(streamInput); - assertArrayEquals(undeployModelNodeRequest.getMlUndeployModelNodesRequest().getModelIds(), parsedNodeRequest.getMlUndeployModelNodesRequest().getModelIds()); + assertArrayEquals( + undeployModelNodeRequest.getMlUndeployModelNodesRequest().getModelIds(), + parsedNodeRequest.getMlUndeployModelNodesRequest().getModelIds() + ); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesResponseTest.java index e27b53dbb2..5f0be1a8e0 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesResponseTest.java @@ -1,5 +1,16 @@ package org.opensearch.ml.common.transport.undeploy; +import static org.junit.Assert.assertEquals; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -10,22 +21,11 @@ import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; -import java.net.InetAddress; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; - @RunWith(MockitoJUnitRunner.class) public class MLUndeployModelNodesResponseTest { @@ -39,20 +39,20 @@ public class MLUndeployModelNodesResponseTest { public void setUp() throws Exception { clusterName = new ClusterName("clusterName"); node1 = new DiscoveryNode( - "foo1", - "foo1", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT + "foo1", + "foo1", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT ); node2 = new DiscoveryNode( - "foo2", - "foo2", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT + "foo2", + "foo2", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT ); modelWorkerNodeCounts = new HashMap<>(); modelWorkerNodeCounts.put("modelId1", 1); @@ -76,13 +76,13 @@ public void testToXContent() throws IOException { Map modelToUndeployStatus1 = new HashMap<>(); modelToUndeployStatus1.put("modelId1", "response"); Map modelWorkerNodeCounts1 = new HashMap<>(); - modelWorkerNodeCounts1.put("modelId1", new String[]{"mockNode1"}); + modelWorkerNodeCounts1.put("modelId1", new String[] { "mockNode1" }); nodes.add(new MLUndeployModelNodeResponse(node1, modelToUndeployStatus1, modelWorkerNodeCounts1)); Map modelToUndeployStatus2 = new HashMap<>(); modelToUndeployStatus2.put("modelId2", "response"); Map modelWorkerNodeCounts2 = new HashMap<>(); - modelWorkerNodeCounts2.put("modelId2", new String[]{"mockNode2"}); + modelWorkerNodeCounts2.put("modelId2", new String[] { "mockNode2" }); nodes.add(new MLUndeployModelNodeResponse(node2, modelToUndeployStatus2, modelWorkerNodeCounts2)); List failures = new ArrayList<>(); @@ -90,9 +90,6 @@ public void testToXContent() throws IOException { XContentBuilder builder = XContentFactory.jsonBuilder(); response.toXContent(builder, ToXContent.EMPTY_PARAMS); String jsonStr = builder.toString(); - assertEquals( - "{\"foo1\":{\"stats\":{\"modelId1\":\"response\"}},\"foo2\":{\"stats\":{\"modelId2\":\"response\"}}}", - jsonStr - ); + assertEquals("{\"foo1\":{\"stats\":{\"modelId1\":\"response\"}},\"foo2\":{\"stats\":{\"modelId2\":\"response\"}}}", jsonStr); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponseTest.java index 69f12099e9..f08572518c 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponseTest.java @@ -5,6 +5,18 @@ package org.opensearch.ml.common.transport.undeploy; +import static org.junit.Assert.*; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.net.InetAddress; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -22,18 +34,6 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.net.InetAddress; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static org.junit.Assert.*; -import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; - public class MLUndeployModelsResponseTest { MLUndeployModelNodesResponse undeployModelNodesResponse; @@ -46,15 +46,15 @@ public void setUp() { Map modelToDeployStatus = new HashMap<>(); modelToDeployStatus.put("modelId1", "response"); DiscoveryNode localNode = new DiscoveryNode( - "test_node_name", - "test_node_id", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT + "test_node_name", + "test_node_id", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT ); Map modelWorkerNodeCounts = new HashMap<>(); - modelWorkerNodeCounts.put("modelId1", new String[]{"node"}); + modelWorkerNodeCounts.put("modelId1", new String[] { "node" }); MLUndeployModelNodeResponse nodeResponse = new MLUndeployModelNodeResponse(localNode, modelToDeployStatus, modelWorkerNodeCounts); List nodes = Arrays.asList(nodeResponse); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodeResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodeResponseTest.java index 28e61476cc..404c1ffc87 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodeResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodeResponseTest.java @@ -5,6 +5,16 @@ package org.opensearch.ml.common.transport.update_cache; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -16,17 +26,6 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.transport.TransportAddress; -import org.opensearch.ml.common.transport.update_cache.MLUpdateModelCacheNodeResponse; - -import java.io.IOException; -import java.net.InetAddress; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; -import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; @RunWith(MockitoJUnitRunner.class) public class MLUpdateModelCacheNodeResponseTest { @@ -40,12 +39,12 @@ public class MLUpdateModelCacheNodeResponseTest { @Before public void setUp() throws Exception { localNode = new DiscoveryNode( - "foo0", - "foo0", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT + "foo0", + "foo0", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT ); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesRequestTest.java index a698c139dd..9a00ca5427 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesRequestTest.java @@ -5,6 +5,13 @@ package org.opensearch.ml.common.transport.update_cache; +import static org.junit.Assert.assertEquals; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.Collections; + import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.junit.MockitoJUnitRunner; @@ -13,15 +20,6 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.transport.TransportAddress; -import org.opensearch.ml.common.transport.update_cache.MLUpdateModelCacheNodeRequest; -import org.opensearch.ml.common.transport.update_cache.MLUpdateModelCacheNodesRequest; - -import java.io.IOException; -import java.net.InetAddress; -import java.util.Collections; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; @RunWith(MockitoJUnitRunner.class) public class MLUpdateModelCacheNodesRequestTest { @@ -29,10 +27,10 @@ public class MLUpdateModelCacheNodesRequestTest { @Test public void testConstructorSerialization1() throws IOException { String modelId = "testModelId"; - String[] nodeIds = {"nodeId1", "nodeId2", "nodeId3"}; + String[] nodeIds = { "nodeId1", "nodeId2", "nodeId3" }; MLUpdateModelCacheNodeRequest updateModelCacheNodeRequest = new MLUpdateModelCacheNodeRequest( - new MLUpdateModelCacheNodesRequest(nodeIds, modelId) + new MLUpdateModelCacheNodesRequest(nodeIds, modelId) ); BytesStreamOutput output = new BytesStreamOutput(); @@ -44,24 +42,24 @@ public void testConstructorSerialization1() throws IOException { public void testConstructorSerialization2() { String modelId = "testModelId"; DiscoveryNode localNode1 = new DiscoveryNode( - "foo1", - "foo1", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT + "foo1", + "foo1", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT ); DiscoveryNode localNode2 = new DiscoveryNode( - "foo2", - "foo2", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT + "foo2", + "foo2", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT ); - DiscoveryNode[] nodes = {localNode1, localNode2}; + DiscoveryNode[] nodes = { localNode1, localNode2 }; MLUpdateModelCacheNodeRequest updateModelCacheNodeRequest = new MLUpdateModelCacheNodeRequest( - new MLUpdateModelCacheNodesRequest(nodes, modelId) + new MLUpdateModelCacheNodesRequest(nodes, modelId) ); assertEquals(2, updateModelCacheNodeRequest.getUpdateModelCacheNodesRequest().concreteNodes().length); } @@ -69,10 +67,10 @@ public void testConstructorSerialization2() { @Test public void testConstructorFromInputStream() throws IOException { String modelId = "testModelId"; - String[] nodeIds = {"nodeId1", "nodeId2", "nodeId3"}; + String[] nodeIds = { "nodeId1", "nodeId2", "nodeId3" }; MLUpdateModelCacheNodeRequest updateModelCacheNodeRequest = new MLUpdateModelCacheNodeRequest( - new MLUpdateModelCacheNodesRequest(nodeIds, modelId) + new MLUpdateModelCacheNodesRequest(nodeIds, modelId) ); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); updateModelCacheNodeRequest.writeTo(bytesStreamOutput); @@ -80,6 +78,9 @@ public void testConstructorFromInputStream() throws IOException { StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); MLUpdateModelCacheNodeRequest parsedNodeRequest = new MLUpdateModelCacheNodeRequest(streamInput); - assertEquals(updateModelCacheNodeRequest.getUpdateModelCacheNodesRequest().getModelId(), parsedNodeRequest.getUpdateModelCacheNodesRequest().getModelId()); + assertEquals( + updateModelCacheNodeRequest.getUpdateModelCacheNodesRequest().getModelId(), + parsedNodeRequest.getUpdateModelCacheNodesRequest().getModelId() + ); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesResponseTest.java index e1fc242d43..95aab55d9e 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/update_cache/MLUpdateModelCacheNodesResponseTest.java @@ -5,6 +5,16 @@ package org.opensearch.ml.common.transport.update_cache; +import static org.junit.Assert.assertEquals; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -15,22 +25,10 @@ import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.ml.common.transport.update_cache.MLUpdateModelCacheNodeResponse; -import org.opensearch.ml.common.transport.update_cache.MLUpdateModelCacheNodesResponse; - -import java.io.IOException; -import java.net.InetAddress; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Map; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; @RunWith(MockitoJUnitRunner.class) public class MLUpdateModelCacheNodesResponseTest { @@ -44,20 +42,20 @@ public class MLUpdateModelCacheNodesResponseTest { public void setUp() throws Exception { clusterName = new ClusterName("clusterName"); node1 = new DiscoveryNode( - "foo1", - "foo1", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT + "foo1", + "foo1", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT ); node2 = new DiscoveryNode( - "foo2", - "foo2", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT + "foo2", + "foo2", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT ); } @@ -87,10 +85,7 @@ public void testToXContent() throws IOException { XContentBuilder builder = XContentFactory.jsonBuilder(); response.toXContent(builder, ToXContent.EMPTY_PARAMS); String jsonStr = builder.toString(); - assertEquals( - "{\"foo1\":{\"stats\":{\"modelId1\":\"response\"}},\"foo2\":{\"stats\":{\"modelId2\":\"response\"}}}", - jsonStr - ); + assertEquals("{\"foo1\":{\"stats\":{\"modelId1\":\"response\"}},\"foo2\":{\"stats\":{\"modelId2\":\"response\"}}}", jsonStr); } @Test @@ -102,6 +97,6 @@ public void testNullUpdateModelCacheStatusToXContent() throws IOException { XContentBuilder builder = XContentFactory.jsonBuilder(); response.toXContent(builder, ToXContent.EMPTY_PARAMS); String jsonStr = builder.toString(); - assertEquals("{}",jsonStr); + assertEquals("{}", jsonStr); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java index 667d4276c5..4bc3d79c50 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java @@ -1,103 +1,125 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.common.transport.upload_chunk; - -import org.junit.Before; -import org.junit.Test; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.TestHelper; -import org.opensearch.ml.common.model.MLModelFormat; -import org.opensearch.ml.common.model.MLModelState; -import org.opensearch.ml.common.model.TextEmbeddingModelConfig; -import org.opensearch.ml.common.model.TextEmbeddingModelConfig.FrameworkType; - -import java.io.IOException; -import java.util.function.Function; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; - -public class MLRegisterModelMetaInputTest { - - - Function function = parser -> { - try { - return MLRegisterModelMetaInput.parse(parser); - } catch (Exception e) { - throw new RuntimeException("Failed to parse MLRegisterModelMetaInput", e); - } - }; - TextEmbeddingModelConfig config; - MLRegisterModelMetaInput mLRegisterModelMetaInput; - - @Before - public void setup() { - config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config", - TextEmbeddingModelConfig.PoolingMode.MEAN, true, 512); - mLRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "model_group_id", "1.0", - "Model Description", null, null, MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, - 200L, "123", config, null, 2, - null, null, false, false, false, null); - } - - @Test - public void parse_MLRegisterModelMetaInput() throws IOException { - TestHelper.testParse(mLRegisterModelMetaInput, function); - } - - @Test - public void readInputStream_Success() throws IOException { - readInputStream(mLRegisterModelMetaInput); - } - - - private void readInputStream(MLRegisterModelMetaInput input) throws IOException { - BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); - input.writeTo(bytesStreamOutput); - StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); - MLRegisterModelMetaInput newInput = new MLRegisterModelMetaInput(streamInput); - assertEquals(input.getName(), newInput.getName()); - assertEquals(input.getFunctionName(), newInput.getFunctionName()); - assertEquals(input.getModelGroupId(), newInput.getModelGroupId()); - assertEquals(input.getVersion(), newInput.getVersion()); - assertEquals(input.getDescription(), newInput.getDescription()); - assertEquals(input.getModelFormat(), newInput.getModelFormat()); - assertEquals(input.getModelConfig().getAllConfig(), newInput.getModelConfig().getAllConfig()); - assertEquals(input.getModelConfig().getModelType(), newInput.getModelConfig().getModelType()); - assertEquals(input.getModelFormat(), newInput.getModelFormat()); - assertEquals(input.getModelState(), newInput.getModelState()); - assertEquals(input.getModelContentSizeInBytes(), newInput.getModelContentSizeInBytes()); - assertEquals(input.getModelContentHashValue(), newInput.getModelContentHashValue()); - assertEquals(input.getTotalChunks(), newInput.getTotalChunks()); - assertEquals(input.getBackendRoles(), newInput.getBackendRoles()); - assertEquals(input.getIsAddAllBackendRoles(), newInput.getIsAddAllBackendRoles()); - assertEquals(input.getAccessMode(), newInput.getAccessMode()); - assertEquals(input.getDoesVersionCreateModelGroup(), newInput.getDoesVersionCreateModelGroup()); - assertEquals(input.getIsHidden(), newInput.getIsHidden()); - } - - - @Test - public void testToXContent() throws IOException { - XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); - mLRegisterModelMetaInput.toXContent(builder, EMPTY_PARAMS); - String mlModelContent = TestHelper.xContentBuilderToString(builder); - final String expected = "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"model_group_id\":" + - "\"model_group_id\",\"version\":\"1.0\",\"description\":\"Model Description\"," + - "\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"DEPLOYING\"," + - "\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\"," + - "\"model_config\":{\"model_type\":\"Model Type\",\"embedding_dimension\":123," + - "\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"All Config\"," + - "\"model_max_length\":512,\"pooling_mode\":\"MEAN\",\"normalize_result\":true},\"total_chunks\":2," + - "\"add_all_backend_roles\":false,\"does_version_create_model_group\":false,\"is_hidden\":false}"; - assertEquals(expected, mlModelContent); - } -} +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.upload_chunk; + +import static org.junit.Assert.assertEquals; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; +import java.util.function.Function; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.TestHelper; +import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.MLModelState; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig.FrameworkType; + +public class MLRegisterModelMetaInputTest { + + Function function = parser -> { + try { + return MLRegisterModelMetaInput.parse(parser); + } catch (Exception e) { + throw new RuntimeException("Failed to parse MLRegisterModelMetaInput", e); + } + }; + TextEmbeddingModelConfig config; + MLRegisterModelMetaInput mLRegisterModelMetaInput; + + @Before + public void setup() { + config = new TextEmbeddingModelConfig( + "Model Type", + 123, + FrameworkType.SENTENCE_TRANSFORMERS, + "All Config", + TextEmbeddingModelConfig.PoolingMode.MEAN, + true, + 512 + ); + mLRegisterModelMetaInput = new MLRegisterModelMetaInput( + "Model Name", + FunctionName.BATCH_RCF, + "model_group_id", + "1.0", + "Model Description", + null, + null, + MLModelFormat.TORCH_SCRIPT, + MLModelState.DEPLOYING, + 200L, + "123", + config, + null, + 2, + null, + null, + false, + false, + false, + null + ); + } + + @Test + public void parse_MLRegisterModelMetaInput() throws IOException { + TestHelper.testParse(mLRegisterModelMetaInput, function); + } + + @Test + public void readInputStream_Success() throws IOException { + readInputStream(mLRegisterModelMetaInput); + } + + private void readInputStream(MLRegisterModelMetaInput input) throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + input.writeTo(bytesStreamOutput); + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLRegisterModelMetaInput newInput = new MLRegisterModelMetaInput(streamInput); + assertEquals(input.getName(), newInput.getName()); + assertEquals(input.getFunctionName(), newInput.getFunctionName()); + assertEquals(input.getModelGroupId(), newInput.getModelGroupId()); + assertEquals(input.getVersion(), newInput.getVersion()); + assertEquals(input.getDescription(), newInput.getDescription()); + assertEquals(input.getModelFormat(), newInput.getModelFormat()); + assertEquals(input.getModelConfig().getAllConfig(), newInput.getModelConfig().getAllConfig()); + assertEquals(input.getModelConfig().getModelType(), newInput.getModelConfig().getModelType()); + assertEquals(input.getModelFormat(), newInput.getModelFormat()); + assertEquals(input.getModelState(), newInput.getModelState()); + assertEquals(input.getModelContentSizeInBytes(), newInput.getModelContentSizeInBytes()); + assertEquals(input.getModelContentHashValue(), newInput.getModelContentHashValue()); + assertEquals(input.getTotalChunks(), newInput.getTotalChunks()); + assertEquals(input.getBackendRoles(), newInput.getBackendRoles()); + assertEquals(input.getIsAddAllBackendRoles(), newInput.getIsAddAllBackendRoles()); + assertEquals(input.getAccessMode(), newInput.getAccessMode()); + assertEquals(input.getDoesVersionCreateModelGroup(), newInput.getDoesVersionCreateModelGroup()); + assertEquals(input.getIsHidden(), newInput.getIsHidden()); + } + + @Test + public void testToXContent() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + mLRegisterModelMetaInput.toXContent(builder, EMPTY_PARAMS); + String mlModelContent = TestHelper.xContentBuilderToString(builder); + final String expected = "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"model_group_id\":" + + "\"model_group_id\",\"version\":\"1.0\",\"description\":\"Model Description\"," + + "\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"DEPLOYING\"," + + "\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\"," + + "\"model_config\":{\"model_type\":\"Model Type\",\"embedding_dimension\":123," + + "\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"All Config\"," + + "\"model_max_length\":512,\"pooling_mode\":\"MEAN\",\"normalize_result\":true},\"total_chunks\":2," + + "\"add_all_backend_roles\":false,\"does_version_create_model_group\":false,\"is_hidden\":false}"; + assertEquals(expected, mlModelContent); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java index bbf64f5688..f765aced5b 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java @@ -1,105 +1,130 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.common.transport.upload_chunk; - -import org.junit.Before; -import org.junit.Test; -import org.opensearch.action.ActionRequest; -import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.model.MLModelFormat; -import org.opensearch.ml.common.model.MLModelState; -import org.opensearch.ml.common.model.TextEmbeddingModelConfig; -import org.opensearch.ml.common.model.TextEmbeddingModelConfig.FrameworkType; - -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; - -public class MLRegisterModelMetaRequestTest { - - TextEmbeddingModelConfig config; - MLRegisterModelMetaInput mlRegisterModelMetaInput; - - @Before - public void setUp() { - config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config", - TextEmbeddingModelConfig.PoolingMode.MEAN, true, 512); - mlRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "Model Group Id", "1.0", - "Model Description", null, null, MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, null, 2, null, null, null, null, null, null); - } - - @Test - public void writeTo_Succeess() throws IOException { - MLRegisterModelMetaRequest request = new MLRegisterModelMetaRequest(mlRegisterModelMetaInput); - BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); - request.writeTo(bytesStreamOutput); - MLRegisterModelMetaRequest newRequest = new MLRegisterModelMetaRequest(bytesStreamOutput.bytes().streamInput()); - assertEquals(request.getMlRegisterModelMetaInput().getName(), newRequest.getMlRegisterModelMetaInput().getName()); - assertEquals(request.getMlRegisterModelMetaInput().getDescription(), - newRequest.getMlRegisterModelMetaInput().getDescription()); - assertEquals(request.getMlRegisterModelMetaInput().getFunctionName(), - newRequest.getMlRegisterModelMetaInput().getFunctionName()); - assertEquals(request.getMlRegisterModelMetaInput().getModelConfig().getAllConfig(), - newRequest.getMlRegisterModelMetaInput().getModelConfig().getAllConfig()); - assertEquals(request.getMlRegisterModelMetaInput().getModelGroupId(), - newRequest.getMlRegisterModelMetaInput().getModelGroupId()); - } - - @Test - public void validate_Exception_NullModelId() { - MLRegisterModelMetaRequest mlRegisterModelMetaRequest = MLRegisterModelMetaRequest.builder().build(); - ActionRequestValidationException exception = mlRegisterModelMetaRequest.validate(); - assertEquals("Validation Failed: 1: Model meta input can't be null;", exception.getMessage()); - } - - @Test - public void fromActionRequest_Success() { - MLRegisterModelMetaRequest request = new MLRegisterModelMetaRequest(mlRegisterModelMetaInput); - ActionRequest actionRequest = new ActionRequest() { - @Override - public ActionRequestValidationException validate() { - return null; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - request.writeTo(out); - } - }; - MLRegisterModelMetaRequest newRequest = MLRegisterModelMetaRequest.fromActionRequest(actionRequest); - assertNotSame(request, newRequest); - assertEquals(request.getMlRegisterModelMetaInput().getName(), newRequest.getMlRegisterModelMetaInput().getName()); - assertEquals(request.getMlRegisterModelMetaInput().getDescription(), - newRequest.getMlRegisterModelMetaInput().getDescription()); - assertEquals(request.getMlRegisterModelMetaInput().getFunctionName(), - newRequest.getMlRegisterModelMetaInput().getFunctionName()); - assertEquals(request.getMlRegisterModelMetaInput().getModelConfig().getAllConfig(), - newRequest.getMlRegisterModelMetaInput().getModelConfig().getAllConfig()); - assertEquals(request.getMlRegisterModelMetaInput().getModelGroupId(), - newRequest.getMlRegisterModelMetaInput().getModelGroupId()); - } - - @Test(expected = UncheckedIOException.class) - public void fromActionRequest_IOException() { - ActionRequest actionRequest = new ActionRequest() { - @Override - public ActionRequestValidationException validate() { - return null; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - throw new IOException("test"); - } - }; - MLRegisterModelMetaRequest.fromActionRequest(actionRequest); - } -} +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.upload_chunk; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; + +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.MLModelState; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig.FrameworkType; + +public class MLRegisterModelMetaRequestTest { + + TextEmbeddingModelConfig config; + MLRegisterModelMetaInput mlRegisterModelMetaInput; + + @Before + public void setUp() { + config = new TextEmbeddingModelConfig( + "Model Type", + 123, + FrameworkType.SENTENCE_TRANSFORMERS, + "All Config", + TextEmbeddingModelConfig.PoolingMode.MEAN, + true, + 512 + ); + mlRegisterModelMetaInput = new MLRegisterModelMetaInput( + "Model Name", + FunctionName.BATCH_RCF, + "Model Group Id", + "1.0", + "Model Description", + null, + null, + MLModelFormat.TORCH_SCRIPT, + MLModelState.DEPLOYING, + 200L, + "123", + config, + null, + 2, + null, + null, + null, + null, + null, + null + ); + } + + @Test + public void writeTo_Succeess() throws IOException { + MLRegisterModelMetaRequest request = new MLRegisterModelMetaRequest(mlRegisterModelMetaInput); + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + request.writeTo(bytesStreamOutput); + MLRegisterModelMetaRequest newRequest = new MLRegisterModelMetaRequest(bytesStreamOutput.bytes().streamInput()); + assertEquals(request.getMlRegisterModelMetaInput().getName(), newRequest.getMlRegisterModelMetaInput().getName()); + assertEquals(request.getMlRegisterModelMetaInput().getDescription(), newRequest.getMlRegisterModelMetaInput().getDescription()); + assertEquals(request.getMlRegisterModelMetaInput().getFunctionName(), newRequest.getMlRegisterModelMetaInput().getFunctionName()); + assertEquals( + request.getMlRegisterModelMetaInput().getModelConfig().getAllConfig(), + newRequest.getMlRegisterModelMetaInput().getModelConfig().getAllConfig() + ); + assertEquals(request.getMlRegisterModelMetaInput().getModelGroupId(), newRequest.getMlRegisterModelMetaInput().getModelGroupId()); + } + + @Test + public void validate_Exception_NullModelId() { + MLRegisterModelMetaRequest mlRegisterModelMetaRequest = MLRegisterModelMetaRequest.builder().build(); + ActionRequestValidationException exception = mlRegisterModelMetaRequest.validate(); + assertEquals("Validation Failed: 1: Model meta input can't be null;", exception.getMessage()); + } + + @Test + public void fromActionRequest_Success() { + MLRegisterModelMetaRequest request = new MLRegisterModelMetaRequest(mlRegisterModelMetaInput); + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + request.writeTo(out); + } + }; + MLRegisterModelMetaRequest newRequest = MLRegisterModelMetaRequest.fromActionRequest(actionRequest); + assertNotSame(request, newRequest); + assertEquals(request.getMlRegisterModelMetaInput().getName(), newRequest.getMlRegisterModelMetaInput().getName()); + assertEquals(request.getMlRegisterModelMetaInput().getDescription(), newRequest.getMlRegisterModelMetaInput().getDescription()); + assertEquals(request.getMlRegisterModelMetaInput().getFunctionName(), newRequest.getMlRegisterModelMetaInput().getFunctionName()); + assertEquals( + request.getMlRegisterModelMetaInput().getModelConfig().getAllConfig(), + newRequest.getMlRegisterModelMetaInput().getModelConfig().getAllConfig() + ); + assertEquals(request.getMlRegisterModelMetaInput().getModelGroupId(), newRequest.getMlRegisterModelMetaInput().getModelGroupId()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequest_IOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException("test"); + } + }; + MLRegisterModelMetaRequest.fromActionRequest(actionRequest); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaResponseTest.java index 92f66530e9..fb8fcd0e81 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaResponseTest.java @@ -1,50 +1,49 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.common.transport.upload_chunk; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; - -import java.io.IOException; - -import org.junit.Before; -import org.junit.Test; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.ml.common.TestHelper; - -public class MLRegisterModelMetaResponseTest { - - MLRegisterModelMetaResponse mlRegisterModelMetaResponse; - - @Before - public void setup() { - mlRegisterModelMetaResponse = new MLRegisterModelMetaResponse("Model Id", "Status"); - } - - - @Test - public void writeTo_Success() throws IOException { - BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); - mlRegisterModelMetaResponse.writeTo(bytesStreamOutput); - MLRegisterModelMetaResponse newResponse = new MLRegisterModelMetaResponse(bytesStreamOutput.bytes().streamInput()); - assertEquals(mlRegisterModelMetaResponse.getModelId(), newResponse.getModelId()); - assertEquals(mlRegisterModelMetaResponse.getStatus(), newResponse.getStatus()); - } - - @Test - public void testToXContent() throws IOException { - MLRegisterModelMetaResponse response = new MLRegisterModelMetaResponse("Model Id", "Status"); - XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); - response.toXContent(builder, EMPTY_PARAMS); - assertNotNull(builder); - String jsonStr = TestHelper.xContentBuilderToString(builder); - final String expected = "{\"model_id\":\"Model Id\",\"status\":\"Status\"}"; - assertEquals(expected, jsonStr); - } -} +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.upload_chunk; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.TestHelper; + +public class MLRegisterModelMetaResponseTest { + + MLRegisterModelMetaResponse mlRegisterModelMetaResponse; + + @Before + public void setup() { + mlRegisterModelMetaResponse = new MLRegisterModelMetaResponse("Model Id", "Status"); + } + + @Test + public void writeTo_Success() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + mlRegisterModelMetaResponse.writeTo(bytesStreamOutput); + MLRegisterModelMetaResponse newResponse = new MLRegisterModelMetaResponse(bytesStreamOutput.bytes().streamInput()); + assertEquals(mlRegisterModelMetaResponse.getModelId(), newResponse.getModelId()); + assertEquals(mlRegisterModelMetaResponse.getStatus(), newResponse.getStatus()); + } + + @Test + public void testToXContent() throws IOException { + MLRegisterModelMetaResponse response = new MLRegisterModelMetaResponse("Model Id", "Status"); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + response.toXContent(builder, EMPTY_PARAMS); + assertNotNull(builder); + String jsonStr = TestHelper.xContentBuilderToString(builder); + final String expected = "{\"model_id\":\"Model Id\",\"status\":\"Status\"}"; + assertEquals(expected, jsonStr); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkInputTest.java index cafec77356..a04e5fcb51 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkInputTest.java @@ -1,107 +1,116 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.common.transport.upload_chunk; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; - -import java.io.IOException; -import java.util.Collections; -import java.util.function.Function; - -import org.junit.Before; -import org.junit.Test; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.TestHelper; -import org.opensearch.search.SearchModule; - -public class MLUploadModelChunkInputTest { - - MLUploadModelChunkInput mlUploadModelChunkInput; - private Function function = parser -> { - try { - return MLUploadModelChunkInput.parse(parser, new byte[] { 12, 4, 5, 3 }); - } catch (Exception e) { - throw new RuntimeException("Failed to parse MLUploadModelChunkInput", e); - } - }; - - @Before - public void setup() { - mlUploadModelChunkInput = MLUploadModelChunkInput.builder().modelId("modelId").chunkNumber(1) - .content(new byte[] { 1, 3, 4 }).build(); - } - - @Test - public void parse_MLUploadModelChunkInput() throws IOException { - TestHelper.testParse(mlUploadModelChunkInput, function); - } - - @Test - public void readInputStream_Success() throws IOException { - readInputStream(mlUploadModelChunkInput); - } - - private void readInputStream(MLUploadModelChunkInput input) throws IOException { - BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); - input.writeTo(bytesStreamOutput); - StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); - MLUploadModelChunkInput newInput = new MLUploadModelChunkInput(streamInput); - assertEquals(input.getChunkNumber(), newInput.getChunkNumber()); - assertEquals(input.getModelId(), newInput.getModelId()); - } - - @Test - public void testMLUploadModelChunkInputConstructor() { - MLUploadModelChunkInput input = new MLUploadModelChunkInput("modelId", 1, new byte[] { 12, 3 }); - assertNotNull(input); - } - - @Test - public void testMLUploadModelChunkInputWriteToSuccess() throws IOException { - BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); - mlUploadModelChunkInput.writeTo(bytesStreamOutput); - final var newLlUploadModelChunkInput = new MLUploadModelChunkInput(bytesStreamOutput.bytes().streamInput()); - assertEquals(mlUploadModelChunkInput.getModelId(), newLlUploadModelChunkInput.getModelId()); - assertEquals(mlUploadModelChunkInput.getChunkNumber(), newLlUploadModelChunkInput.getChunkNumber()); - } - - @Test - public void testToXContent() throws IOException { - XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); - mlUploadModelChunkInput.toXContent(builder, EMPTY_PARAMS); - String mlModelContent = TestHelper.xContentBuilderToString(builder); - assertEquals("{\"model_id\":\"modelId\",\"chunk_number\":1,\"model_content\":\"AQME\"}", mlModelContent); - } - - @Test - public void testMLUploadModelChunkInputParser() throws IOException { - XContentBuilder builder = XContentFactory.jsonBuilder(); - builder = mlUploadModelChunkInput.toXContent(builder, null); - String json = builder.toString(); - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry( - new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), null, json); - parser.nextToken(); - MLUploadModelChunkInput newMlUploadModelChunkInput = MLUploadModelChunkInput.parse(parser, new byte[] { 1, 3, 4 }); - assertEquals(mlUploadModelChunkInput, newMlUploadModelChunkInput); - } - - @Test - public void testMLUploadModelChunkInputParser_XContentParser() throws IOException { - XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); - mlUploadModelChunkInput.toXContent(builder, EMPTY_PARAMS); - String mlModelContent = TestHelper.xContentBuilderToString(builder); - TestHelper.testParseFromString(mlUploadModelChunkInput, mlModelContent, function); - } -} +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.upload_chunk; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; +import java.util.Collections; +import java.util.function.Function; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.TestHelper; +import org.opensearch.search.SearchModule; + +public class MLUploadModelChunkInputTest { + + MLUploadModelChunkInput mlUploadModelChunkInput; + private Function function = parser -> { + try { + return MLUploadModelChunkInput.parse(parser, new byte[] { 12, 4, 5, 3 }); + } catch (Exception e) { + throw new RuntimeException("Failed to parse MLUploadModelChunkInput", e); + } + }; + + @Before + public void setup() { + mlUploadModelChunkInput = MLUploadModelChunkInput + .builder() + .modelId("modelId") + .chunkNumber(1) + .content(new byte[] { 1, 3, 4 }) + .build(); + } + + @Test + public void parse_MLUploadModelChunkInput() throws IOException { + TestHelper.testParse(mlUploadModelChunkInput, function); + } + + @Test + public void readInputStream_Success() throws IOException { + readInputStream(mlUploadModelChunkInput); + } + + private void readInputStream(MLUploadModelChunkInput input) throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + input.writeTo(bytesStreamOutput); + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLUploadModelChunkInput newInput = new MLUploadModelChunkInput(streamInput); + assertEquals(input.getChunkNumber(), newInput.getChunkNumber()); + assertEquals(input.getModelId(), newInput.getModelId()); + } + + @Test + public void testMLUploadModelChunkInputConstructor() { + MLUploadModelChunkInput input = new MLUploadModelChunkInput("modelId", 1, new byte[] { 12, 3 }); + assertNotNull(input); + } + + @Test + public void testMLUploadModelChunkInputWriteToSuccess() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + mlUploadModelChunkInput.writeTo(bytesStreamOutput); + final var newLlUploadModelChunkInput = new MLUploadModelChunkInput(bytesStreamOutput.bytes().streamInput()); + assertEquals(mlUploadModelChunkInput.getModelId(), newLlUploadModelChunkInput.getModelId()); + assertEquals(mlUploadModelChunkInput.getChunkNumber(), newLlUploadModelChunkInput.getChunkNumber()); + } + + @Test + public void testToXContent() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + mlUploadModelChunkInput.toXContent(builder, EMPTY_PARAMS); + String mlModelContent = TestHelper.xContentBuilderToString(builder); + assertEquals("{\"model_id\":\"modelId\",\"chunk_number\":1,\"model_content\":\"AQME\"}", mlModelContent); + } + + @Test + public void testMLUploadModelChunkInputParser() throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder = mlUploadModelChunkInput.toXContent(builder, null); + String json = builder.toString(); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + json + ); + parser.nextToken(); + MLUploadModelChunkInput newMlUploadModelChunkInput = MLUploadModelChunkInput.parse(parser, new byte[] { 1, 3, 4 }); + assertEquals(mlUploadModelChunkInput, newMlUploadModelChunkInput); + } + + @Test + public void testMLUploadModelChunkInputParser_XContentParser() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + mlUploadModelChunkInput.toXContent(builder, EMPTY_PARAMS); + String mlModelContent = TestHelper.xContentBuilderToString(builder); + TestHelper.testParseFromString(mlUploadModelChunkInput, mlModelContent, function); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkRequestTest.java index 9571c5db53..4f20046077 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkRequestTest.java @@ -1,84 +1,81 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.common.transport.upload_chunk; - -import org.junit.Before; -import org.junit.Test; -import org.opensearch.action.ActionRequest; -import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamOutput; - -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; - -public class MLUploadModelChunkRequestTest { - - MLUploadModelChunkInput mlUploadModelChunkInput; - - @Before - public void setUp() { - mlUploadModelChunkInput = new MLUploadModelChunkInput("modelId", 1, new byte[] { 12, 3 }); - } - - - @Test - public void writeTo_Succeess() throws IOException { - MLUploadModelChunkRequest request = new MLUploadModelChunkRequest(mlUploadModelChunkInput); - BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); - request.writeTo(bytesStreamOutput); - MLUploadModelChunkRequest newRequest = new MLUploadModelChunkRequest(bytesStreamOutput.bytes().streamInput()); - assertEquals(request.getUploadModelChunkInput(), newRequest.getUploadModelChunkInput()); - } - - @Test - public void validate_Exception_NullModelId() { - MLUploadModelChunkRequest mlUploadModelChunkRequest = MLUploadModelChunkRequest.builder().build(); - ActionRequestValidationException exception = mlUploadModelChunkRequest.validate(); - assertEquals("Validation Failed: 1: ML input can't be null;", exception.getMessage()); - } - - - @Test - public void fromActionRequest_Success() { - MLUploadModelChunkRequest request = new MLUploadModelChunkRequest(mlUploadModelChunkInput); - ActionRequest actionRequest = new ActionRequest() { - @Override - public ActionRequestValidationException validate() { - return null; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - request.writeTo(out); - } - }; - MLUploadModelChunkRequest result = MLUploadModelChunkRequest.fromActionRequest(actionRequest); - assertNotSame(request, result); - assertEquals(request.getUploadModelChunkInput(), result.getUploadModelChunkInput()); - } - - - @Test(expected = UncheckedIOException.class) - public void fromActionRequest_IOException() { - ActionRequest actionRequest = new ActionRequest() { - @Override - public ActionRequestValidationException validate() { - return null; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - throw new IOException("test"); - } - }; - MLUploadModelChunkRequest.fromActionRequest(actionRequest); - } - -} +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.upload_chunk; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; + +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamOutput; + +public class MLUploadModelChunkRequestTest { + + MLUploadModelChunkInput mlUploadModelChunkInput; + + @Before + public void setUp() { + mlUploadModelChunkInput = new MLUploadModelChunkInput("modelId", 1, new byte[] { 12, 3 }); + } + + @Test + public void writeTo_Succeess() throws IOException { + MLUploadModelChunkRequest request = new MLUploadModelChunkRequest(mlUploadModelChunkInput); + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + request.writeTo(bytesStreamOutput); + MLUploadModelChunkRequest newRequest = new MLUploadModelChunkRequest(bytesStreamOutput.bytes().streamInput()); + assertEquals(request.getUploadModelChunkInput(), newRequest.getUploadModelChunkInput()); + } + + @Test + public void validate_Exception_NullModelId() { + MLUploadModelChunkRequest mlUploadModelChunkRequest = MLUploadModelChunkRequest.builder().build(); + ActionRequestValidationException exception = mlUploadModelChunkRequest.validate(); + assertEquals("Validation Failed: 1: ML input can't be null;", exception.getMessage()); + } + + @Test + public void fromActionRequest_Success() { + MLUploadModelChunkRequest request = new MLUploadModelChunkRequest(mlUploadModelChunkInput); + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + request.writeTo(out); + } + }; + MLUploadModelChunkRequest result = MLUploadModelChunkRequest.fromActionRequest(actionRequest); + assertNotSame(request, result); + assertEquals(request.getUploadModelChunkInput(), result.getUploadModelChunkInput()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequest_IOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException("test"); + } + }; + MLUploadModelChunkRequest.fromActionRequest(actionRequest); + } + +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkResponseTest.java index 9bff6e68de..14aa51f7ae 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkResponseTest.java @@ -1,47 +1,47 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.common.transport.upload_chunk; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; - -import java.io.IOException; - -import org.junit.Before; -import org.junit.Test; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.ml.common.TestHelper; - -public class MLUploadModelChunkResponseTest { - - MLUploadModelChunkResponse mlUploadModelChunkResponse; - - @Before - public void setup() { - mlUploadModelChunkResponse = new MLUploadModelChunkResponse("Status"); - } - - @Test - public void writeTo_Success() throws IOException { - BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); - mlUploadModelChunkResponse.writeTo(bytesStreamOutput); - MLUploadModelChunkResponse newResponse = new MLUploadModelChunkResponse(bytesStreamOutput.bytes().streamInput()); - assertEquals(mlUploadModelChunkResponse.getStatus(), newResponse.getStatus()); - } - - @Test - public void testToXContent() throws IOException { - XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); - mlUploadModelChunkResponse.toXContent(builder, EMPTY_PARAMS); - assertNotNull(builder); - String jsonStr = TestHelper.xContentBuilderToString(builder); - final String expected = "{\"status\":\"Status\"}"; - assertEquals(expected, jsonStr); - } -} +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.upload_chunk; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.TestHelper; + +public class MLUploadModelChunkResponseTest { + + MLUploadModelChunkResponse mlUploadModelChunkResponse; + + @Before + public void setup() { + mlUploadModelChunkResponse = new MLUploadModelChunkResponse("Status"); + } + + @Test + public void writeTo_Success() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + mlUploadModelChunkResponse.writeTo(bytesStreamOutput); + MLUploadModelChunkResponse newResponse = new MLUploadModelChunkResponse(bytesStreamOutput.bytes().streamInput()); + assertEquals(mlUploadModelChunkResponse.getStatus(), newResponse.getStatus()); + } + + @Test + public void testToXContent() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + mlUploadModelChunkResponse.toXContent(builder, EMPTY_PARAMS); + assertNotNull(builder); + String jsonStr = TestHelper.xContentBuilderToString(builder); + final String expected = "{\"status\":\"Status\"}"; + assertEquals(expected, jsonStr); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/utils/IndexUtilsTest.java b/common/src/test/java/org/opensearch/ml/common/utils/IndexUtilsTest.java index 4c979cf0c4..ed50be13ca 100644 --- a/common/src/test/java/org/opensearch/ml/common/utils/IndexUtilsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/utils/IndexUtilsTest.java @@ -5,25 +5,27 @@ package org.opensearch.ml.common.utils; -import org.junit.Test; -import java.util.Map; import static org.junit.Assert.assertEquals; +import java.util.Map; + +import org.junit.Test; + public class IndexUtilsTest { - @Test - public void testIndexSettingsContainsExpectedValues() { - Map indexSettings = IndexUtils.INDEX_SETTINGS; - assertEquals("index.number_of_shards should be 1", indexSettings.get("index.number_of_shards"), "1"); - assertEquals("index.auto_expand_replicas should be 0-1", indexSettings.get("index.auto_expand_replicas"), "0-1"); - assertEquals("INDEX_SETTINGS should contain exactly 2 settings", 2, indexSettings.size()); - } - - @Test - public void testUpdatedIndexSettingsContainsExpectedValues() { - Map updatedIndexSettings = IndexUtils.UPDATED_INDEX_SETTINGS; - assertEquals("index.auto_expand_replicas should be 0-1", updatedIndexSettings.get("index.auto_expand_replicas"), "0-1"); - assertEquals("INDEX_SETTINGS should contain exactly 1 settings", 1, updatedIndexSettings.size()); - } + @Test + public void testIndexSettingsContainsExpectedValues() { + Map indexSettings = IndexUtils.INDEX_SETTINGS; + assertEquals("index.number_of_shards should be 1", indexSettings.get("index.number_of_shards"), "1"); + assertEquals("index.auto_expand_replicas should be 0-1", indexSettings.get("index.auto_expand_replicas"), "0-1"); + assertEquals("INDEX_SETTINGS should contain exactly 2 settings", 2, indexSettings.size()); + } + + @Test + public void testUpdatedIndexSettingsContainsExpectedValues() { + Map updatedIndexSettings = IndexUtils.UPDATED_INDEX_SETTINGS; + assertEquals("index.auto_expand_replicas should be 0-1", updatedIndexSettings.get("index.auto_expand_replicas"), "0-1"); + assertEquals("INDEX_SETTINGS should contain exactly 1 settings", 1, updatedIndexSettings.size()); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/utils/ModelInterfaceUtilsTest.java b/common/src/test/java/org/opensearch/ml/common/utils/ModelInterfaceUtilsTest.java index 0227ecd520..96eb4b3fca 100644 --- a/common/src/test/java/org/opensearch/ml/common/utils/ModelInterfaceUtilsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/utils/ModelInterfaceUtilsTest.java @@ -5,18 +5,6 @@ package org.opensearch.ml.common.utils; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; -import org.mockito.Spy; -import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.connector.HttpConnector; -import org.opensearch.ml.common.transport.register.MLRegisterModelInput; - -import java.util.HashMap; -import java.util.Map; - import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.opensearch.ml.common.utils.ModelInterfaceUtils.AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE; @@ -30,6 +18,18 @@ import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_TITAN_EMBED_TEXT_V1_MODEL_INTERFACE; import static org.opensearch.ml.common.utils.ModelInterfaceUtils.updateRegisterModelInputModelInterfaceFieldsByConnector; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.Spy; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput; + public class ModelInterfaceUtilsTest { @Spy MLRegisterModelInput registerModelInputWithInnerConnector; @@ -46,16 +46,16 @@ public class ModelInterfaceUtilsTest { @Before public void setUp() throws Exception { registerModelInputWithInnerConnector = MLRegisterModelInput - .builder() - .modelName("test-model-with-inner-connector") - .functionName(FunctionName.REMOTE) - .build(); + .builder() + .modelName("test-model-with-inner-connector") + .functionName(FunctionName.REMOTE) + .build(); registerModelInputWithStandaloneConnector = MLRegisterModelInput - .builder() - .modelName("test-model-with-stand-alone-connector") - .functionName(FunctionName.REMOTE) - .build(); + .builder() + .modelName("test-model-with-stand-alone-connector") + .functionName(FunctionName.REMOTE) + .build(); } @Test @@ -143,7 +143,10 @@ public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorAMAZON_CO connector = HttpConnector.builder().protocol("http").parameters(parameters).build(); updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector); - assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE); + assertEquals( + registerModelInputWithStandaloneConnector.getModelInterface(), + AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE + ); } @Test @@ -195,7 +198,9 @@ public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorNullParam } @Test - public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorInnerConnectorBEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE() { + public + void + testUpdateRegisterModelInputModelInterfaceFieldsByConnectorInnerConnectorBEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE() { Map parameters = new HashMap<>(); parameters.put("service_name", "bedrock"); parameters.put("model", "ai21.j2-mid-v1"); diff --git a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java index b2aa45e068..cf112d6ca3 100644 --- a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java @@ -5,8 +5,7 @@ package org.opensearch.ml.common.utils; -import org.junit.Assert; -import org.junit.Test; +import static org.junit.Assert.assertEquals; import java.util.Arrays; import java.util.HashMap; @@ -15,8 +14,8 @@ import java.util.Map; import java.util.Set; -import static java.util.stream.Collectors.toList; -import static org.junit.Assert.assertEquals; +import org.junit.Assert; +import org.junit.Test; public class StringUtilsTest { @@ -70,12 +69,13 @@ public void fromJson_SimpleMap() { @Test public void fromJson_NestedMap() { - Map response = StringUtils.fromJson("{\"key\": {\"nested_key\": \"nested_value\", \"nested_array\": [1, \"a\"]}}", "response"); + Map response = StringUtils + .fromJson("{\"key\": {\"nested_key\": \"nested_value\", \"nested_array\": [1, \"a\"]}}", "response"); assertEquals(1, response.size()); Assert.assertTrue(response.get("key") instanceof Map); - Map nestedMap = (Map)response.get("key"); + Map nestedMap = (Map) response.get("key"); assertEquals("nested_value", nestedMap.get("nested_key")); - List list = (List)nestedMap.get("nested_array"); + List list = (List) nestedMap.get("nested_array"); assertEquals(2, list.size()); assertEquals(1.0, list.get(0)); assertEquals("a", list.get(1)); @@ -86,7 +86,7 @@ public void fromJson_SimpleList() { Map response = StringUtils.fromJson("[1, \"a\"]", "response"); assertEquals(1, response.size()); Assert.assertTrue(response.get("response") instanceof List); - List list = (List)response.get("response"); + List list = (List) response.get("response"); assertEquals(1.0, list.get(0)); assertEquals("a", list.get(1)); } @@ -96,7 +96,7 @@ public void fromJson_NestedList() { Map response = StringUtils.fromJson("[1, \"a\", [2, 3], {\"key\": \"value\"}]", "response"); assertEquals(1, response.size()); Assert.assertTrue(response.get("response") instanceof List); - List list = (List)response.get("response"); + List list = (List) response.get("response"); assertEquals(1.0, list.get(0)); assertEquals("a", list.get(1)); Assert.assertTrue(list.get(2) instanceof List); @@ -109,8 +109,8 @@ public void getParameterMap() { parameters.put("key1", "value1"); parameters.put("key2", 2); parameters.put("key3", 2.1); - parameters.put("key4", new int[]{10, 20}); - parameters.put("key5", new Object[]{1.01, "abc"}); + parameters.put("key4", new int[] { 10, 20 }); + parameters.put("key5", new Object[] { 1.01, "abc" }); Map parameterMap = StringUtils.getParameterMap(parameters); assertEquals(5, parameterMap.size()); assertEquals("value1", parameterMap.get("key1")); @@ -122,13 +122,13 @@ public void getParameterMap() { @Test public void getInterfaceMap() { - final Set allowedInterfaceFieldNameList = new HashSet<>(Arrays.asList("input","output")); + final Set allowedInterfaceFieldNameList = new HashSet<>(Arrays.asList("input", "output")); Map parameters = new HashMap<>(); parameters.put("input", "value1"); parameters.put("output", 2); parameters.put("key3", 2.1); - parameters.put("key4", new int[]{10, 20}); - parameters.put("key5", new Object[]{1.01, "abc"}); + parameters.put("key4", new int[] { 10, 20 }); + parameters.put("key5", new Object[] { 1.01, "abc" }); Map interfaceMap = StringUtils.filteredParameterMap(parameters, allowedInterfaceFieldNameList); Assert.assertEquals(2, interfaceMap.size()); Assert.assertEquals("value1", interfaceMap.get("input"));