Skip to content

Commit

Permalink
Convert record to lombok value, add unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Nov 13, 2024
1 parent 9830ab3 commit 7a95087
Show file tree
Hide file tree
Showing 8 changed files with 255 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,17 +89,17 @@ public SearchResponse processResponse(
for (int i = 0; i < queryLevelExplanation.getDetails().length; i++) {
normalizedExplanation[i] = Explanation.match(
// normalized score
normalizationExplanation.scoreDetails().get(i).getKey(),
normalizationExplanation.getScoreDetails().get(i).getKey(),
// description of normalized score
normalizationExplanation.scoreDetails().get(i).getValue(),
normalizationExplanation.getScoreDetails().get(i).getValue(),
// shard level details
queryLevelExplanation.getDetails()[i]
);
}
Explanation finalExplanation = Explanation.match(
searchHit.getScore(),
// combination level explanation is always a single detail
combinationExplanation.scoreDetails().get(0).getValue(),
combinationExplanation.getScoreDetails().get(0).getValue(),
normalizedExplanation
);
searchHit.explanation(finalExplanation);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ private void explain(NormalizationProcessorWorkflowExecuteRequest request, List<
Map<SearchShard, List<CombinedExplanationDetails>> combinedExplanations = combinationExplain.entrySet()
.stream()
.collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().stream().map(explainDetail -> {
DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(explainDetail.docId(), entry.getKey());
DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(explainDetail.getDocId(), entry.getKey());
return CombinedExplanationDetails.builder()
.normalizationExplanations(normalizationExplain.get(docIdAtSearchShard))
.combinationExplanations(explainDetail)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,19 @@
*/
package org.opensearch.neuralsearch.processor;

import lombok.AllArgsConstructor;
import lombok.Value;
import org.opensearch.search.SearchShardTarget;

/**
* DTO class to store index, shardId and nodeId for a search shard.
*/
public record SearchShard(String index, int shardId, String nodeId) {
@Value
@AllArgsConstructor
public class SearchShard {
String index;
int shardId;
String nodeId;

/**
* Create SearchShard from SearchShardTarget
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
*/
package org.opensearch.neuralsearch.processor.explain;

import lombok.Value;
import org.opensearch.neuralsearch.processor.SearchShard;

/**
* DTO class to store docId and search shard for a query.
* Used in {@link org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow} to normalize scores across shards.
* @param docId
* @param searchShard
*/
public record DocIdAtSearchShard(int docId, SearchShard searchShard) {
@Value
public class DocIdAtSearchShard {
int docId;
SearchShard searchShard;
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,21 @@
*/
package org.opensearch.neuralsearch.processor.explain;

import lombok.AllArgsConstructor;
import lombok.Value;
import org.apache.commons.lang3.tuple.Pair;

import java.util.List;

/**
* DTO class to store value and description for explain details.
* Used in {@link org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow} to normalize scores across shards.
* @param docId iterator based id of the document
* @param scoreDetails list of score details for the document, each Pair object contains score and description of the score
*/
public record ExplanationDetails(int docId, List<Pair<Float, String>> scoreDetails) {
@Value
@AllArgsConstructor
public class ExplanationDetails {
int docId;
List<Pair<Float, String>> scoreDetails;

public ExplanationDetails(List<Pair<Float, String>> scoreDetails) {
this(-1, scoreDetails);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -45,6 +46,9 @@ public static Map<DocIdAtSearchShard, ExplanationDetails> getDocIdAtQueryForNorm
* @return a string describing the combination technique and its parameters
*/
public static String describeCombinationTechnique(final String techniqueName, final List<Float> weights) {
if (Objects.isNull(techniqueName)) {
throw new IllegalArgumentException("combination technique name cannot be null");
}
return Optional.ofNullable(weights)
.filter(w -> !w.isEmpty())
.map(w -> String.format(Locale.ROOT, "%s, weights %s", techniqueName, weights))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor.explain;

import org.apache.commons.lang3.tuple.Pair;
import org.junit.Before;

import org.opensearch.neuralsearch.processor.SearchShard;
import org.opensearch.neuralsearch.processor.normalization.MinMaxScoreNormalizationTechnique;
import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class ExplanationUtilsTests extends OpenSearchQueryTestCase {

    // the technique is stateless, so one shared instance is safe for all tests;
    // static final matches the UPPER_SNAKE_CASE constant naming
    private static final MinMaxScoreNormalizationTechnique MIN_MAX_TECHNIQUE = new MinMaxScoreNormalizationTechnique();

    private DocIdAtSearchShard docId1;
    private DocIdAtSearchShard docId2;
    private Map<DocIdAtSearchShard, List<Float>> normalizedScores;

    @Before
    public void setUp() throws Exception {
        super.setUp();
        // both doc ids share one shard; normalizedScores starts empty and is
        // populated per-test
        SearchShard searchShard = new SearchShard("test_index", 0, "abcdefg");
        docId1 = new DocIdAtSearchShard(1, searchShard);
        docId2 = new DocIdAtSearchShard(2, searchShard);
        normalizedScores = new HashMap<>();
    }

    /**
     * Each input score must come back as a Pair of (score, technique description),
     * grouped per document.
     */
    public void testGetDocIdAtQueryForNormalization() {
        // Setup
        normalizedScores.put(docId1, Arrays.asList(1.0f, 0.5f));
        normalizedScores.put(docId2, Arrays.asList(0.8f));
        // Act
        Map<DocIdAtSearchShard, ExplanationDetails> result = ExplanationUtils.getDocIdAtQueryForNormalization(
            normalizedScores,
            MIN_MAX_TECHNIQUE
        );
        // Assert
        assertNotNull(result);
        assertEquals(2, result.size());

        // Assert first document: two sub-query scores -> two explanation pairs
        ExplanationDetails details1 = result.get(docId1);
        assertNotNull(details1);
        List<Pair<Float, String>> explanations1 = details1.getScoreDetails();
        assertEquals(2, explanations1.size());
        assertEquals(1.0f, explanations1.get(0).getLeft(), 0.001);
        assertEquals(0.5f, explanations1.get(1).getLeft(), 0.001);
        assertEquals("min_max normalization of:", explanations1.get(0).getRight());
        assertEquals("min_max normalization of:", explanations1.get(1).getRight());

        // Assert second document: single score -> single explanation pair
        ExplanationDetails details2 = result.get(docId2);
        assertNotNull(details2);
        List<Pair<Float, String>> explanations2 = details2.getScoreDetails();
        assertEquals(1, explanations2.size());
        assertEquals(0.8f, explanations2.get(0).getLeft(), 0.001);
        assertEquals("min_max normalization of:", explanations2.get(0).getRight());
    }

    /**
     * An empty score map must produce an empty (non-null) result map.
     */
    public void testGetDocIdAtQueryForNormalizationWithEmptyScores() {
        // Setup
        // Using empty normalizedScores from setUp
        // Act
        Map<DocIdAtSearchShard, ExplanationDetails> result = ExplanationUtils.getDocIdAtQueryForNormalization(
            normalizedScores,
            MIN_MAX_TECHNIQUE
        );
        // Assert
        assertNotNull(result);
        assertTrue(result.isEmpty());
    }

    /**
     * With weights present the description appends them in list form.
     */
    public void testDescribeCombinationTechniqueWithWeights() {
        // Setup
        String techniqueName = "test_technique";
        List<Float> weights = Arrays.asList(0.3f, 0.7f);
        // Act
        String result = ExplanationUtils.describeCombinationTechnique(techniqueName, weights);
        // Assert
        assertEquals("test_technique, weights [0.3, 0.7]", result);
    }

    /**
     * Null weights must degrade gracefully to the bare technique name.
     */
    public void testDescribeCombinationTechniqueWithoutWeights() {
        // Setup
        String techniqueName = "test_technique";
        // Act
        String result = ExplanationUtils.describeCombinationTechnique(techniqueName, null);
        // Assert
        assertEquals("test_technique", result);
    }

    /**
     * An empty weight list is treated the same as no weights at all.
     */
    public void testDescribeCombinationTechniqueWithEmptyWeights() {
        // Setup
        String techniqueName = "test_technique";
        List<Float> weights = Arrays.asList();
        // Act
        String result = ExplanationUtils.describeCombinationTechnique(techniqueName, weights);
        // Assert
        assertEquals("test_technique", result);
    }

    /**
     * A null technique name is a programming error and must fail fast.
     */
    public void testDescribeCombinationTechniqueWithNullTechnique() {
        // Setup
        List<Float> weights = Arrays.asList(1.0f);
        // Act & Assert
        expectThrows(IllegalArgumentException.class, () -> ExplanationUtils.describeCombinationTechnique(null, weights));
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor.factory;

import lombok.SneakyThrows;
import org.opensearch.neuralsearch.processor.ExplanationResponseProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;
import org.opensearch.test.OpenSearchTestCase;

import java.util.HashMap;
import java.util.Map;

import static org.mockito.Mockito.mock;

public class ExplanationResponseProcessorFactoryTests extends OpenSearchTestCase {

    private static final String TAG = "tag";
    private static final String DESCRIPTION = "description";
    private static final boolean IGNORE_FAILURE = false;

    @SneakyThrows
    public void testDefaults_whenNoParams_thenSuccessful() {
        // an empty config is the default, supported case
        SearchResponseProcessor responseProcessor = buildProcessor(new HashMap<>());
        assertProcessor(responseProcessor, TAG, DESCRIPTION, IGNORE_FAILURE);
    }

    @SneakyThrows
    public void testInvalidInput_whenParamsPassedToFactory_thenSuccessful() {
        // the factory ignores unknown config keys, so a config full of random
        // entries must still yield a valid processor
        Map<String, Object> randomConfig = new HashMap<>();
        for (int i = 0; i < randomInt(1_000); i++) {
            randomConfig.put(randomAlphaOfLength(10) + i, randomAlphaOfLength(100));
        }
        SearchResponseProcessor responseProcessor = buildProcessor(randomConfig);
        assertProcessor(responseProcessor, TAG, DESCRIPTION, IGNORE_FAILURE);
    }

    @SneakyThrows
    public void testNewInstanceCreation_whenCreateMultipleTimes_thenNewInstanceReturned() {
        // two create() calls on the same factory with identical arguments must
        // not hand back the same processor instance
        ExplanationResponseProcessorFactory factory = new ExplanationResponseProcessorFactory();
        Map<String, Processor.Factory<SearchResponseProcessor>> processorFactories = new HashMap<>();
        Map<String, Object> config = new HashMap<>();
        Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class);

        SearchResponseProcessor first = factory.create(processorFactories, TAG, DESCRIPTION, IGNORE_FAILURE, config, pipelineContext);
        SearchResponseProcessor second = factory.create(processorFactories, TAG, DESCRIPTION, IGNORE_FAILURE, config, pipelineContext);

        assertNotEquals(first, second);
    }

    // runs the factory with the shared tag/description/ignoreFailure fixture
    // values and the given config
    @SneakyThrows
    private static SearchResponseProcessor buildProcessor(Map<String, Object> config) {
        ExplanationResponseProcessorFactory factory = new ExplanationResponseProcessorFactory();
        Map<String, Processor.Factory<SearchResponseProcessor>> processorFactories = new HashMap<>();
        Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class);
        return factory.create(processorFactories, TAG, DESCRIPTION, IGNORE_FAILURE, config, pipelineContext);
    }

    // shared shape check: right concrete type plus tag/description/ignoreFailure
    // carried through unchanged
    private static void assertProcessor(SearchResponseProcessor responseProcessor, String tag, String description, boolean ignoreFailure) {
        assertNotNull(responseProcessor);
        assertTrue(responseProcessor instanceof ExplanationResponseProcessor);
        ExplanationResponseProcessor explanationResponseProcessor = (ExplanationResponseProcessor) responseProcessor;
        assertEquals("explanation_response_processor", explanationResponseProcessor.getType());
        assertEquals(tag, explanationResponseProcessor.getTag());
        assertEquals(description, explanationResponseProcessor.getDescription());
        assertEquals(ignoreFailure, explanationResponseProcessor.isIgnoreFailure());
    }
}

0 comments on commit 7a95087

Please sign in to comment.