From 85a735cc4a6eaafa1d7b841a8d9750af2a048ac2 Mon Sep 17 00:00:00 2001 From: br3no Date: Fri, 26 Apr 2024 19:50:32 +0200 Subject: [PATCH] added missing result filter to inference Signed-off-by: br3no --- .../engine/algorithms/TextEmbeddingModel.java | 2 +- .../TextEmbeddingDenseModelTest.java | 38 +++++++++++++++++-- 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java index 33a69697d1..63c11ca79d 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java @@ -83,7 +83,7 @@ private TextDocsInputDataSet addPrefixesToData(AsymmetricTextEmbeddingParameters : modelConfig.getQueryPrefix(); if (prefix != null) { List prefixedDocs = inputDataSet.getDocs().stream().map(s -> prefix + s).collect(Collectors.toList()); - return TextDocsInputDataSet.builder().docs(prefixedDocs).build(); + return TextDocsInputDataSet.builder().docs(prefixedDocs).resultFilter(inputDataSet.getResultFilter()).build(); } return inputDataSet; } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java index 6c72a97fcb..179d11d2a1 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingDenseModelTest.java @@ -262,7 +262,13 @@ public void initModel_predict_TorchScript_SentenceTransformer_SmallModel_With_As .builder() .algorithm(FunctionName.TEXT_EMBEDDING) .inputDataset( - TextDocsInputDataSet.builder().docs(Arrays.asList("what is the meaning of life?", "who won this year's us open")).build() + TextDocsInputDataSet + .builder() + .docs(Arrays.asList("what is the meaning of life?", "who won this year's us open")) + .resultFilter( + ModelResultFilter.builder().targetResponse(List.of(SENTENCE_EMBEDDING)).returnBytes(true).returnNumber(true).build() + ) + .build() ) .parameters(new AsymmetricTextEmbeddingParameters(EmbeddingContentType.QUERY)) .build(); @@ -270,7 +276,13 @@ public void initModel_predict_TorchScript_SentenceTransformer_SmallModel_With_As .builder() .algorithm(FunctionName.TEXT_EMBEDDING) .inputDataset( - TextDocsInputDataSet.builder().docs(Arrays.asList("The meaning of life is 42", "I won this year's us open")).build() + TextDocsInputDataSet + .builder() + .docs(Arrays.asList("The meaning of life is 42", "I won this year's us open")) + .resultFilter( + ModelResultFilter.builder().targetResponse(List.of(SENTENCE_EMBEDDING)).returnBytes(true).returnNumber(true).build() + ) + .build() ) .parameters(new AsymmetricTextEmbeddingParameters(EmbeddingContentType.PASSAGE)) .build(); @@ -285,20 +297,38 @@ public void initModel_predict_TorchScript_SentenceTransformer_SmallModel_With_As .builder() .algorithm(FunctionName.TEXT_EMBEDDING) .inputDataset( - TextDocsInputDataSet.builder().docs(Arrays.asList("what is the meaning of life?", "who won this year's us open")).build() + TextDocsInputDataSet + .builder() + .docs(Arrays.asList("what is the meaning of life?", "who won this year's us open")) + .resultFilter( + ModelResultFilter.builder().targetResponse(List.of(SENTENCE_EMBEDDING)).returnBytes(true).returnNumber(true).build() + ) + .build() ) .build(); MLInput symmetricMlInputPassages = MLInput .builder() .algorithm(FunctionName.TEXT_EMBEDDING) .inputDataset( - TextDocsInputDataSet.builder().docs(Arrays.asList("The meaning of life is 42", "I won this year's us open")).build() + TextDocsInputDataSet + .builder() + .docs(Arrays.asList("The meaning of life is 42", "I won this year's us open")) + .resultFilter( + ModelResultFilter.builder().targetResponse(List.of(SENTENCE_EMBEDDING)).returnBytes(true).returnNumber(true).build() + ) + .build() ) .build(); ModelTensorOutput symmetricQueryEmbeddings = (ModelTensorOutput) textEmbeddingDenseModel.predict(symmetricMlInputQueries); ModelTensorOutput symmetricPassageEmbeddings = (ModelTensorOutput) textEmbeddingDenseModel.predict(symmetricMlInputPassages); + assertTrue( + "asymmetric and symmetric embeddings should have the same number of tensors", + asymmetricQueryEmbeddings.getMlModelOutputs().get(0).getMlModelTensors().size() == 1 + && symmetricQueryEmbeddings.getMlModelOutputs().get(0).getMlModelTensors().size() == 1 + ); + assertTrue( "asymmetric and symmetric query embeddings should be different", areTensorsDifferent(