From 2d4c496d79d282956edc94d93e375a14e555a1d4 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Mon, 14 Oct 2024 12:05:01 +0800 Subject: [PATCH] Add dimensions parameter support for bedrock titan embedding v2 model Signed-off-by: zane-neo --- .../connector/MLPreProcessFunction.java | 8 ++-- .../BedrockEmbeddingPreProcessFunction.java | 14 +++++++ .../ConnectorPreProcessFunction.java | 33 ++++++++++++--- .../preprocess/PreProcessFunction.java | 40 +++++++++++++++++++ ...edrockEmbeddingPreProcessFunctionTest.java | 13 +++++- .../algorithms/remote/ConnectorUtils.java | 6 +-- 6 files changed, 99 insertions(+), 15 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/PreProcessFunction.java diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java index 3a5a3427a8..5acd1dd9b6 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java @@ -7,19 +7,17 @@ import java.util.HashMap; import java.util.Map; -import java.util.function.Function; import org.opensearch.ml.common.connector.functions.preprocess.BedrockEmbeddingPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.CohereEmbeddingPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.MultiModalConnectorPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.OpenAIEmbeddingPreProcessFunction; -import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; -import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.connector.functions.preprocess.PreProcessFunction; public class MLPreProcessFunction { - private static final Map> PRE_PROCESS_FUNCTIONS = new HashMap<>(); + private static final Map PRE_PROCESS_FUNCTIONS = new HashMap<>(); public static final String TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT = "connector.pre_process.cohere.embedding"; public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding"; public static final String TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.embedding"; @@ -50,7 +48,7 @@ public static boolean contains(String functionName) { return PRE_PROCESS_FUNCTIONS.containsKey(functionName); } - public static Function get(String postProcessFunction) { + public static PreProcessFunction get(String postProcessFunction) { return PRE_PROCESS_FUNCTIONS.get(postProcessFunction); } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java index b6a95be042..cbc140fcc1 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java @@ -8,7 +8,9 @@ import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; import java.util.Map; +import java.util.Optional; +import org.apache.commons.lang3.math.NumberUtils; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; @@ -24,10 +26,22 @@ public void validate(MLInput mlInput) { validateTextDocsInput(mlInput); } + // Keep this method for robust @Override public RemoteInferenceInputDataSet process(MLInput mlInput) { TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset(); Map processedResult = Map.of("parameters", Map.of("inputText", inputData.getDocs().get(0))); return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build(); } + + @Override + public RemoteInferenceInputDataSet process(Map connectorParams, MLInput mlInput) { + TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset(); + // Amazon Titan Text Embeddings V2 model: https://docs.aws.amazon.com/bedrock/latest/userguide/titan-embedding-models.html + // Default dimension is 1024 + int dimensions = Optional.ofNullable(connectorParams.get("dimensions")).map(x -> NumberUtils.toInt(x, 1024)).orElse(1024); + Map processedResult = Map + .of("parameters", Map.of("inputText", inputData.getDocs().get(0), "dimensions", dimensions)); + return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build(); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java index 387ac27467..e8305ab5c3 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java @@ -10,7 +10,6 @@ import java.util.Collections; import java.util.Locale; import java.util.Map; -import java.util.function.Function; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; @@ -29,7 +28,7 @@ * If the input data is already of type {@link RemoteInferenceInputDataSet}, it can be returned directly by setting the {@link #returnDirectlyForRemoteInferenceInput} flag to true. */ @Log4j2 -public abstract class ConnectorPreProcessFunction implements Function { +public abstract class ConnectorPreProcessFunction implements PreProcessFunction { /** * This is a flag that can be used to determine if the pre-process function should return the input directly for RemoteInferenceInputDataSet. @@ -37,6 +36,32 @@ public abstract class ConnectorPreProcessFunction implements Function connectorParams, MLInput mlInput) { + if (mlInput == null) { + throw new IllegalArgumentException("Preprocess function input can't be null"); + } + if (returnDirectlyForRemoteInferenceInput && mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) { + return (RemoteInferenceInputDataSet) mlInput.getInputDataset(); + } else { + validate(mlInput); + if (connectorParams != null) { + return process(connectorParams, mlInput); + } else { + return process(mlInput); + } + } + } + /** * Applies the pre-processing function to the given MLInput object and returns the resulting RemoteInferenceInputDataSet. * @@ -57,10 +82,6 @@ public RemoteInferenceInputDataSet apply(MLInput mlInput) { } } - public abstract void validate(MLInput mlInput); - - public abstract RemoteInferenceInputDataSet process(MLInput mlInput); - /** * Validates the input of a pre-process function for text documents. * diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/PreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/PreProcessFunction.java new file mode 100644 index 0000000000..dd24d3c6a3 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/PreProcessFunction.java @@ -0,0 +1,40 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.ml.common.connector.functions.preprocess; + +import java.util.Map; + +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; + +/** + * The PreProcessFunction interface defines methods for preprocessing {@link MLInput} data + * before it is used for inference. It includes methods to apply preprocessing with or without + * additional parameters and to validate the input data. + */ +public interface PreProcessFunction { + + RemoteInferenceInputDataSet apply(Map connectorParams, MLInput mlInput); + + RemoteInferenceInputDataSet apply(MLInput mlInput); + + /** + * The default behavior of this method is to invoke process method with only the MLInput parameter, when the process + * needs more parameters from the connector parameters, the concrete implementation should override this method. + * @param connectorParams + * @param mlInput + * @return + */ + default RemoteInferenceInputDataSet process(Map connectorParams, MLInput mlInput) { + return process(mlInput); + } + + RemoteInferenceInputDataSet process(MLInput mlInput); + + void validate(MLInput mlInput); +} diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunctionTest.java index 851d7eaab7..eb6e023c34 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunctionTest.java @@ -39,7 +39,10 @@ public void setUp() { function = new BedrockEmbeddingPreProcessFunction(); textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(Arrays.asList("hello")).build(); textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build(); - remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("key1", "value1", "key2", "value2")).build(); + remoteInferenceInputDataSet = RemoteInferenceInputDataSet + .builder() + .parameters(Map.of("key1", "value1", "key2", "value2", "dimensions", "1024")) + .build(); textEmbeddingInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); textSimilarityInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(textSimilarityInputDataSet).build(); @@ -73,4 +76,12 @@ public void process_RemoteInferenceInput() { RemoteInferenceInputDataSet dataSet = function.apply(remoteInferenceInput); assertEquals(remoteInferenceInputDataSet, dataSet); } + + @Test + public void process_TextDocsInput_withConnectorParams() { + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); + RemoteInferenceInputDataSet dataSet = function.apply(Map.of("dimensions", "1024"), mlInput); + assertEquals(2, dataSet.getParameters().size()); + assertEquals("1024", dataSet.getParameters().get("dimensions")); + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index f2c93ef5fd..89af9ed6a2 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -25,7 +25,6 @@ import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.function.Function; import org.apache.commons.lang3.StringUtils; import org.apache.commons.text.StringSubstitutor; @@ -34,6 +33,7 @@ import org.opensearch.ml.common.connector.MLPostProcessFunction; import org.opensearch.ml.common.connector.MLPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.DefaultPreProcessFunction; +import org.opensearch.ml.common.connector.functions.preprocess.PreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.RemoteInferencePreProcessFunction; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; @@ -106,8 +106,8 @@ private static RemoteInferenceInputDataSet processMLInput( } else { preProcessFunction = fillProcessFunctionParameter(parameters, preProcessFunction); if (MLPreProcessFunction.contains(preProcessFunction)) { - Function function = MLPreProcessFunction.get(preProcessFunction); - return function.apply(mlInput); + PreProcessFunction function = MLPreProcessFunction.get(preProcessFunction); + return function.apply(parameters, mlInput); } else if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) { if (parameters.containsKey(PROCESS_REMOTE_INFERENCE_INPUT) && Boolean.parseBoolean(parameters.get(PROCESS_REMOTE_INFERENCE_INPUT))) {