Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dimensions parameter support for bedrock titan embedding v2 model #3136

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,17 @@

import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

import org.opensearch.ml.common.connector.functions.preprocess.BedrockEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.CohereEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.MultiModalConnectorPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.OpenAIEmbeddingPreProcessFunction;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.connector.functions.preprocess.PreProcessFunction;

public class MLPreProcessFunction {

private static final Map<String, Function<MLInput, RemoteInferenceInputDataSet>> PRE_PROCESS_FUNCTIONS = new HashMap<>();
private static final Map<String, PreProcessFunction> PRE_PROCESS_FUNCTIONS = new HashMap<>();
public static final String TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT = "connector.pre_process.cohere.embedding";
public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding";
public static final String TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.embedding";
Expand Down Expand Up @@ -50,7 +48,7 @@ public static boolean contains(String functionName) {
return PRE_PROCESS_FUNCTIONS.containsKey(functionName);
}

public static Function<MLInput, RemoteInferenceInputDataSet> get(String postProcessFunction) {
public static PreProcessFunction get(String postProcessFunction) {
return PRE_PROCESS_FUNCTIONS.get(postProcessFunction);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString;

import java.util.Map;
import java.util.Optional;

import org.apache.commons.lang3.math.NumberUtils;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
Expand All @@ -24,10 +26,22 @@ public void validate(MLInput mlInput) {
validateTextDocsInput(mlInput);
}

// Keep this method for robust
@Override
public RemoteInferenceInputDataSet process(MLInput mlInput) {
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
Map<String, Object> processedResult = Map.of("parameters", Map.of("inputText", inputData.getDocs().get(0)));
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build();
}

@Override
public RemoteInferenceInputDataSet process(Map<String, String> connectorParams, MLInput mlInput) {
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
// Amazon Titan Text Embeddings V2 model: https://docs.aws.amazon.com/bedrock/latest/userguide/titan-embedding-models.html
// Default dimension is 1024
int dimensions = Optional.ofNullable(connectorParams.get("dimensions")).map(x -> NumberUtils.toInt(x, 1024)).orElse(1024);
Map<String, Object> processedResult = Map
.of("parameters", Map.of("inputText", inputData.getDocs().get(0), "dimensions", dimensions));
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import java.util.Collections;
import java.util.Locale;
import java.util.Map;
import java.util.function.Function;

import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
Expand All @@ -29,14 +28,40 @@
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it can be returned directly by setting the {@link #returnDirectlyForRemoteInferenceInput} flag to true.
*/
@Log4j2
public abstract class ConnectorPreProcessFunction implements Function<MLInput, RemoteInferenceInputDataSet> {
public abstract class ConnectorPreProcessFunction implements PreProcessFunction {

/**
* This is a flag that can be used to determine if the pre-process function should return the input directly for RemoteInferenceInputDataSet.
* If this is true and the input is already of type RemoteInferenceInputDataSet, it will be returned directly, otherwise it will be processed.
*/
protected boolean returnDirectlyForRemoteInferenceInput;

/**
* Applies the pre-processing function to the given MLInput object and returns the resulting RemoteInferenceInputDataSet.
*
* @param connectorParams the connector parameters: including parameters defined in the connector and the parameters from request.
* refer to RemoteConnectorExecutor.preparePayloadAndInvoke for details.
* @param mlInput the MLInput object to be processed
* @return RemoteInferenceInputDataSet resulting from the pre-processing function
* @throws IllegalArgumentException if the input MLInput object is null
*/
@Override
public RemoteInferenceInputDataSet apply(Map<String, String> connectorParams, MLInput mlInput) {
if (mlInput == null) {
throw new IllegalArgumentException("Preprocess function input can't be null");
}
if (returnDirectlyForRemoteInferenceInput && mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) {
return (RemoteInferenceInputDataSet) mlInput.getInputDataset();
} else {
validate(mlInput);
if (connectorParams != null) {
return process(connectorParams, mlInput);
} else {
return process(mlInput);
}
}
}

/**
* Applies the pre-processing function to the given MLInput object and returns the resulting RemoteInferenceInputDataSet.
*
Expand All @@ -57,10 +82,6 @@ public RemoteInferenceInputDataSet apply(MLInput mlInput) {
}
}

public abstract void validate(MLInput mlInput);

public abstract RemoteInferenceInputDataSet process(MLInput mlInput);

/**
* Validates the input of a pre-process function for text documents.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
*
* * Copyright OpenSearch Contributors
* * SPDX-License-Identifier: Apache-2.0
*
*/

package org.opensearch.ml.common.connector.functions.preprocess;

import java.util.Map;

import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

/**
* The PreProcessFunction interface defines methods for preprocessing {@link MLInput} data
* before it is used for inference. It includes methods to apply preprocessing with or without
* additional parameters and to validate the input data.
*/
public interface PreProcessFunction {

RemoteInferenceInputDataSet apply(Map<String, String> connectorParams, MLInput mlInput);

RemoteInferenceInputDataSet apply(MLInput mlInput);

/**
* The default behavior of this method is to invoke process method with only the MLInput parameter, when the process
* needs more parameters from the connector parameters, the concrete implementation should override this method.
* @param connectorParams
* @param mlInput
* @return
*/
default RemoteInferenceInputDataSet process(Map<String, String> connectorParams, MLInput mlInput) {
return process(mlInput);
}

RemoteInferenceInputDataSet process(MLInput mlInput);

void validate(MLInput mlInput);
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ public void setUp() {
function = new BedrockEmbeddingPreProcessFunction();
textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(Arrays.asList("hello")).build();
textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build();
remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("key1", "value1", "key2", "value2")).build();
remoteInferenceInputDataSet = RemoteInferenceInputDataSet
.builder()
.parameters(Map.of("key1", "value1", "key2", "value2", "dimensions", "1024"))
.build();

textEmbeddingInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build();
textSimilarityInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(textSimilarityInputDataSet).build();
Expand Down Expand Up @@ -73,4 +76,12 @@ public void process_RemoteInferenceInput() {
RemoteInferenceInputDataSet dataSet = function.apply(remoteInferenceInput);
assertEquals(remoteInferenceInputDataSet, dataSet);
}

@Test
public void process_TextDocsInput_withConnectorParams() {
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build();
RemoteInferenceInputDataSet dataSet = function.apply(Map.of("dimensions", "1024"), mlInput);
assertEquals(2, dataSet.getParameters().size());
assertEquals("1024", dataSet.getParameters().get("dimensions"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.text.StringSubstitutor;
Expand All @@ -34,6 +33,7 @@
import org.opensearch.ml.common.connector.MLPostProcessFunction;
import org.opensearch.ml.common.connector.MLPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.DefaultPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.PreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.RemoteInferencePreProcessFunction;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
Expand Down Expand Up @@ -106,8 +106,8 @@ private static RemoteInferenceInputDataSet processMLInput(
} else {
preProcessFunction = fillProcessFunctionParameter(parameters, preProcessFunction);
if (MLPreProcessFunction.contains(preProcessFunction)) {
Function<MLInput, RemoteInferenceInputDataSet> function = MLPreProcessFunction.get(preProcessFunction);
return function.apply(mlInput);
PreProcessFunction function = MLPreProcessFunction.get(preProcessFunction);
return function.apply(parameters, mlInput);
} else if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) {
if (parameters.containsKey(PROCESS_REMOTE_INFERENCE_INPUT)
&& Boolean.parseBoolean(parameters.get(PROCESS_REMOTE_INFERENCE_INPUT))) {
Expand Down
Loading