Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for asymmetric embedding models #710

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x)
### Features
- Add support for asymmetric embedding models ([#710](https://github.com/opensearch-project/neural-search/pull/710))
### Enhancements
### Bug Fixes
### Infrastructure
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.opensearch.env.Environment;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest;
import org.opensearch.neuralsearch.util.TokenWeightUtil;

import lombok.extern.log4j.Log4j2;
Expand Down Expand Up @@ -48,17 +49,19 @@ public void doExecute(
List<String> inferenceList,
BiConsumer<IngestDocument, Exception> handler
) {
mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> {
setVectorFieldsToDocument(ingestDocument, ProcessMap, TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps));
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); }));
mlCommonsClientAccessor.inferenceSentencesWithMapResult(
InferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build(),
ActionListener.wrap(resultMaps -> {
setVectorFieldsToDocument(ingestDocument, ProcessMap, TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps));
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); })
);
}

@Override
public void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException) {
    // Assemble a single batch request for the sparse encoding model, then hand the
    // token-weight maps extracted from the raw inference output to the handler.
    // Failures are routed unchanged to the onException callback.
    final InferenceRequest batchRequest = InferenceRequest.builder()
        .modelId(this.modelId)
        .inputTexts(inferenceList)
        .build();
    mlCommonsClientAccessor.inferenceSentencesWithMapResult(
        batchRequest,
        ActionListener.wrap(resultMaps -> handler.accept(TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)), onException)
    );
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,28 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.env.Environment;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters;
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;

import lombok.extern.log4j.Log4j2;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest;

/**
* This processor is used for user input data text embedding processing, model_id can be used to indicate which model user use,
* and field_map can be used to indicate which fields needs text embedding and the corresponding keys for the text embedding results.
 * This processor is used for embedding text from user input data. The model_id can be used to
 * indicate which model the user uses, and the field_map can be used to indicate which fields need
 * text embedding and the corresponding keys for the text embedding results.
*/
@Log4j2
public final class TextEmbeddingProcessor extends InferenceProcessor {

public static final String TYPE = "text_embedding";
public static final String LIST_TYPE_NESTED_MAP_KEY = "knn";

private static final AsymmetricTextEmbeddingParameters PASSAGE_PARAMETERS = AsymmetricTextEmbeddingParameters.builder()
.embeddingContentType(EmbeddingContentType.PASSAGE)
.build();

public TextEmbeddingProcessor(
String tag,
String description,
Expand All @@ -47,14 +55,20 @@ public void doExecute(
List<String> inferenceList,
BiConsumer<IngestDocument, Exception> handler
) {
mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, ActionListener.wrap(vectors -> {
setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors);
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); }));
mlCommonsClientAccessor.inferenceSentences(
InferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).mlAlgoParams(PASSAGE_PARAMETERS).build(),
ActionListener.wrap(vectors -> {
setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors);
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); })
);
}

@Override
public void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException) {
    // Embed the whole batch as PASSAGE content (see PASSAGE_PARAMETERS, used for
    // asymmetric embedding models) and forward the resulting vectors directly to
    // the handler; errors are delegated to onException.
    final InferenceRequest batchRequest = InferenceRequest.builder()
        .modelId(this.modelId)
        .inputTexts(inferenceList)
        .mlAlgoParams(PASSAGE_PARAMETERS)
        .build();
    mlCommonsClientAccessor.inferenceSentences(batchRequest, ActionListener.wrap(handler::accept, onException));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.google.common.annotations.VisibleForTesting;

import lombok.extern.log4j.Log4j2;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest;
import org.opensearch.neuralsearch.util.ProcessorDocumentUtils;

/**
Expand Down Expand Up @@ -113,10 +114,13 @@ public void execute(final IngestDocument ingestDocument, final BiConsumer<Ingest
if (inferenceMap.isEmpty()) {
handler.accept(ingestDocument, null);
} else {
mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceMap, ActionListener.wrap(vectors -> {
setVectorFieldsToDocument(ingestDocument, vectors);
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); }));
mlCommonsClientAccessor.inferenceSentencesMap(
InferenceRequest.builder().modelId(this.modelId).inputObjects(inferenceMap).build(),
ActionListener.wrap(vectors -> {
setVectorFieldsToDocument(ingestDocument, vectors);
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); })
);
}
} catch (Exception e) {
handler.accept(null, e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.opensearch.action.search.SearchResponse;
import org.opensearch.core.action.ActionListener;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest;
import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory;
import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher;
import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher;
Expand Down Expand Up @@ -73,9 +74,11 @@ public void rescoreSearchResponse(
List<?> ctxList = (List<?>) ctxObj;
List<String> contexts = ctxList.stream().map(str -> (String) str).collect(Collectors.toList());
mlCommonsClientAccessor.inferenceSimilarity(
modelId,
(String) rerankingContext.get(QueryContextSourceFetcher.QUERY_TEXT_FIELD),
contexts,
InferenceRequest.builder()
.modelId(modelId)
.queryText((String) rerankingContext.get(QueryContextSourceFetcher.QUERY_TEXT_FIELD))
.inputTexts(contexts)
.build(),
listener
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
import org.opensearch.knn.index.query.parser.RescoreParser;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.neuralsearch.common.MinClusterVersionUtil;
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters;
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;

import com.google.common.annotations.VisibleForTesting;
Expand All @@ -55,11 +57,12 @@
import lombok.Setter;
import lombok.experimental.Accessors;
import lombok.extern.log4j.Log4j2;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest;

/**
* NeuralQueryBuilder is responsible for producing "neural" query types. A "neural" query type is a wrapper around a
* k-NN vector query. It uses a ML language model to produce a dense vector from a query string that is then used as
* the query vector for the k-NN search.
 * NeuralQueryBuilder is responsible for producing "neural" query types. A "neural" query type is a
 * wrapper around a k-NN vector query. It uses an ML language model to produce a dense vector from a
 * query string that is then used as the query vector for the k-NN search.
*/

@Log4j2
Expand All @@ -84,6 +87,9 @@ public class NeuralQueryBuilder extends AbstractQueryBuilder<NeuralQueryBuilder>
static final ParseField K_FIELD = new ParseField("k");

private static final int DEFAULT_K = 10;
private static final AsymmetricTextEmbeddingParameters QUERY_PARAMETERS = AsymmetricTextEmbeddingParameters.builder()
.embeddingContentType(EmbeddingContentType.QUERY)
.build();

private static MLCommonsClientAccessor ML_CLIENT;

Expand Down Expand Up @@ -333,10 +339,13 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
inferenceInput.put(INPUT_IMAGE, queryImage());
}
queryRewriteContext.registerAsyncAction(
((client, actionListener) -> ML_CLIENT.inferenceSentences(modelId(), inferenceInput, ActionListener.wrap(floatList -> {
vectorSetOnce.set(vectorAsListToArray(floatList));
actionListener.onResponse(null);
}, actionListener::onFailure)))
((client, actionListener) -> ML_CLIENT.inferenceSentencesMap(
InferenceRequest.builder().modelId(modelId()).inputObjects(inferenceInput).mlAlgoParams(QUERY_PARAMETERS).build(),
ActionListener.wrap(floatList -> {
vectorSetOnce.set(vectorAsListToArray(floatList));
actionListener.onResponse(null);
}, actionListener::onFailure)
))
);
return new NeuralQueryBuilder(
fieldName(),
Expand All @@ -361,8 +370,12 @@ protected Query doToQuery(QueryShardContext queryShardContext) {

@Override
protected boolean doEquals(NeuralQueryBuilder obj) {
if (this == obj) return true;
if (obj == null || getClass() != obj.getClass()) return false;
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
EqualsBuilder equalsBuilder = new EqualsBuilder();
equalsBuilder.append(fieldName, obj.fieldName);
equalsBuilder.append(queryText, obj.queryText);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor.InferenceRequest;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
import org.opensearch.neuralsearch.util.TokenWeightUtil;

Expand Down Expand Up @@ -341,8 +342,7 @@ private BiConsumer<Client, ActionListener<?>> getModelInferenceAsync(SetOnce<Map
// it splits the tokens using a threshold defined by a ratio of the maximum score of tokens, updating the token set
// accordingly.
return ((client, actionListener) -> ML_CLIENT.inferenceSentencesWithMapResult(
modelId(),
List.of(queryText),
InferenceRequest.builder().modelId(modelId()).inputTexts(List.of(queryText)).build(),
ActionListener.wrap(mapResultList -> {
Map<String, Float> queryTokens = TokenWeightUtil.fetchListOfTokenWeightMap(mapResultList).get(0);
if (Objects.nonNull(twoPhaseSharedQueryToken)) {
Expand Down
Loading
Loading