Skip to content

Commit

Permalink
Address revie comments
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Nov 15, 2024
1 parent 7a95087 commit d297e0f
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,17 @@ public class ExplanationResponseProcessor implements SearchResponseProcessor {
private final String tag;
private final boolean ignoreFailure;

/**
* Add explanation details to search response if it is present in request context
*/
@Override
public SearchResponse processResponse(SearchRequest request, SearchResponse response) {
return processResponse(request, response, null);
}

/**
* Combines explanation from processor with search hits level explanations and adds it to search response
*/
@Override
public SearchResponse processResponse(
final SearchRequest request,
Expand All @@ -56,35 +62,43 @@ public SearchResponse processResponse(
|| requestContext.getAttribute(EXPLANATION_RESPONSE_KEY) instanceof ExplanationPayload == false) {
return response;
}
// Extract explanation payload from context
ExplanationPayload explanationPayload = (ExplanationPayload) requestContext.getAttribute(EXPLANATION_RESPONSE_KEY);
Map<ExplanationPayload.PayloadType, Object> explainPayload = explanationPayload.getExplainPayload();
if (explainPayload.containsKey(NORMALIZATION_PROCESSOR)) {
// for score normalization, processor level explanations will be sorted in scope of each shard,
// and we are merging both into a single sorted list
SearchHits searchHits = response.getHits();
SearchHit[] searchHitsArray = searchHits.getHits();
// create a map of searchShard and list of indexes of search hit objects in search hits array
// the list will keep original order of sorting as per final search results
Map<SearchShard, List<Integer>> searchHitsByShard = new HashMap<>();
// we keep index for each shard, where index is a position in searchHitsByShard list
Map<SearchShard, Integer> explainsByShardCount = new HashMap<>();
// Build initial shard mappings
for (int i = 0; i < searchHitsArray.length; i++) {
SearchHit searchHit = searchHitsArray[i];
SearchShardTarget searchShardTarget = searchHit.getShard();
SearchShard searchShard = SearchShard.createSearchShard(searchShardTarget);
searchHitsByShard.computeIfAbsent(searchShard, k -> new ArrayList<>()).add(i);
explainsByShardCount.putIfAbsent(searchShard, -1);
}
// Process normalization details if available in correct format
if (explainPayload.get(NORMALIZATION_PROCESSOR) instanceof Map<?, ?>) {
@SuppressWarnings("unchecked")
Map<SearchShard, List<CombinedExplanationDetails>> combinedExplainDetails = (Map<
SearchShard,
List<CombinedExplanationDetails>>) explainPayload.get(NORMALIZATION_PROCESSOR);

// Process each search hit to add processor level explanations
for (SearchHit searchHit : searchHitsArray) {
SearchShard searchShard = SearchShard.createSearchShard(searchHit.getShard());
int explanationIndexByShard = explainsByShardCount.get(searchShard) + 1;
CombinedExplanationDetails combinedExplainDetail = combinedExplainDetails.get(searchShard).get(explanationIndexByShard);
// Extract various explanation components
Explanation queryLevelExplanation = searchHit.getExplanation();
ExplanationDetails normalizationExplanation = combinedExplainDetail.getNormalizationExplanations();
ExplanationDetails combinationExplanation = combinedExplainDetail.getCombinationExplanations();
// Create normalized explanations for each detail
Explanation[] normalizedExplanation = new Explanation[queryLevelExplanation.getDetails().length];
for (int i = 0; i < queryLevelExplanation.getDetails().length; i++) {
normalizedExplanation[i] = Explanation.match(
Expand All @@ -96,6 +110,7 @@ public SearchResponse processResponse(
queryLevelExplanation.getDetails()[i]
);
}
// Create and set final explanation combining all components
Explanation finalExplanation = Explanation.match(
searchHit.getScore(),
// combination level explanation is always a single detail
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/
package org.opensearch.neuralsearch.processor;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -106,6 +107,10 @@ public void execute(final NormalizationProcessorWorkflowExecuteRequest request)
updateOriginalFetchResults(request.getQuerySearchResults(), request.getFetchSearchResultOptional(), unprocessedDocIds);
}

/**
* Collects explanations from normalization and combination techniques and save thme into pipeline context. Later that
* information will be read by the response processor to add it to search response
*/
private void explain(NormalizationProcessorWorkflowExecuteRequest request, List<CompoundTopDocs> queryTopDocs) {
if (!request.isExplain()) {
return;
Expand All @@ -122,15 +127,19 @@ private void explain(NormalizationProcessorWorkflowExecuteRequest request, List<
request.getCombinationTechnique(),
sortForQuery
);
Map<SearchShard, List<CombinedExplanationDetails>> combinedExplanations = combinationExplain.entrySet()
.stream()
.collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().stream().map(explainDetail -> {
Map<SearchShard, List<CombinedExplanationDetails>> combinedExplanations = new HashMap<>();
for (Map.Entry<SearchShard, List<ExplanationDetails>> entry : combinationExplain.entrySet()) {
List<CombinedExplanationDetails> combinedDetailsList = new ArrayList<>();
for (ExplanationDetails explainDetail : entry.getValue()) {
DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(explainDetail.getDocId(), entry.getKey());
return CombinedExplanationDetails.builder()
CombinedExplanationDetails combinedDetail = CombinedExplanationDetails.builder()
.normalizationExplanations(normalizationExplain.get(docIdAtSearchShard))
.combinationExplanations(explainDetail)
.build();
}).collect(Collectors.toList())));
combinedDetailsList.add(combinedDetail);
}
combinedExplanations.put(entry.getKey(), combinedDetailsList);
}

ExplanationPayload explanationPayload = ExplanationPayload.builder()
.explainPayload(Map.of(ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, combinedExplanations))
Expand All @@ -139,7 +148,6 @@ private void explain(NormalizationProcessorWorkflowExecuteRequest request, List<
PipelineProcessingContext pipelineProcessingContext = request.getPipelineProcessingContext();
pipelineProcessingContext.setAttribute(EXPLANATION_RESPONSE_KEY, explanationPayload);
}

}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -359,19 +359,19 @@ private List<ExplanationDetails> explainByShard(
// sort combined scores as per sorting criteria - either score desc or field sorting
Collection<Integer> sortedDocsIds = getSortedDocsIds(compoundQueryTopDocs, sort, combinedNormalizedScoresByDocId);

List<ExplanationDetails> listOfExplanations = sortedDocsIds.stream()
.map(
docId -> new ExplanationDetails(
docId,
List.of(
Pair.of(
combinedNormalizedScoresByDocId.get(docId),
String.format(Locale.ROOT, "%s combination of:", ((ExplainableTechnique) scoreCombinationTechnique).describe())
)
)
)
)
.toList();
List<ExplanationDetails> listOfExplanations = new ArrayList<>();
String combinationDescription = String.format(
Locale.ROOT,
"%s combination of:",
((ExplainableTechnique) scoreCombinationTechnique).describe()
);
for (int docId : sortedDocsIds) {
ExplanationDetails explanation = new ExplanationDetails(
docId,
List.of(Pair.of(combinedNormalizedScoresByDocId.get(docId), combinationDescription))
);
listOfExplanations.add(explanation);
}
return listOfExplanations;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@

import org.apache.commons.lang3.tuple.Pair;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

/**
* Utility class for explain functionality
Expand All @@ -27,15 +28,17 @@ public static Map<DocIdAtSearchShard, ExplanationDetails> getDocIdAtQueryForNorm
final Map<DocIdAtSearchShard, List<Float>> normalizedScores,
final ExplainableTechnique technique
) {
Map<DocIdAtSearchShard, ExplanationDetails> explain = normalizedScores.entrySet()
.stream()
.collect(Collectors.toMap(Map.Entry::getKey, entry -> {
List<Float> normScores = normalizedScores.get(entry.getKey());
List<Pair<Float, String>> explanations = normScores.stream()
.map(score -> Pair.of(score, String.format(Locale.ROOT, "%s normalization of:", technique.describe())))
.collect(Collectors.toList());
return new ExplanationDetails(explanations);
}));
Map<DocIdAtSearchShard, ExplanationDetails> explain = new HashMap<>();
for (Map.Entry<DocIdAtSearchShard, List<Float>> entry : normalizedScores.entrySet()) {
List<Float> normScores = normalizedScores.get(entry.getKey());
List<Pair<Float, String>> explanations = new ArrayList<>();
for (float score : normScores) {
String description = String.format(Locale.ROOT, "%s normalization of:", technique.describe());
explanations.add(Pair.of(score, description));
}
explain.put(entry.getKey(), new ExplanationDetails(explanations));
}

return explain;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.stream.IntStream;
Expand All @@ -37,6 +38,7 @@ public class HybridQueryExplainIT extends BaseNeuralSearchIT {
private static final String TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME = "test-hybrid-vector-doc-field-index";
private static final String TEST_MULTI_DOC_WITH_NESTED_FIELDS_INDEX_NAME = "test-hybrid-multi-doc-nested-fields-index";
private static final String TEST_MULTI_DOC_INDEX_NAME = "test-hybrid-multi-doc-index";
private static final String TEST_LARGE_DOCS_INDEX_NAME = "test-hybrid-large-docs-index";

private static final String TEST_QUERY_TEXT3 = "hello";
private static final String TEST_QUERY_TEXT4 = "place";
Expand Down Expand Up @@ -459,6 +461,64 @@ public void testExplanationResponseProcessor_whenProcessorIsNotConfigured_thenRe
}
}

@SneakyThrows
public void testExplain_whenLargeNumberOfDocuments_thenSuccessful() {
try {
initializeIndexIfNotExist(TEST_LARGE_DOCS_INDEX_NAME);
// create search pipeline with both normalization processor and explain response processor
createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true);

TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
hybridQueryBuilder.add(termQueryBuilder);

Map<String, Object> searchResponseAsMap = search(
TEST_LARGE_DOCS_INDEX_NAME,
hybridQueryBuilder,
null,
1000,
Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString())
);

List<Map<String, Object>> hitsNestedList = getNestedHits(searchResponseAsMap);
assertNotNull(hitsNestedList);
assertFalse(hitsNestedList.isEmpty());

// Verify total hits
Map<String, Object> total = getTotalHits(searchResponseAsMap);
assertNotNull(total.get("value"));
assertTrue((int) total.get("value") > 0);
assertEquals(RELATION_EQUAL_TO, total.get("relation"));

// Sanity checks for each hit's explanation
for (Map<String, Object> hit : hitsNestedList) {
// Verify score is positive
double score = (double) hit.get("_score");
assertTrue("Score should be positive", score > 0.0);

// Basic explanation structure checks
Map<String, Object> explanation = (Map<String, Object>) hit.get("_explanation");
assertNotNull(explanation);
assertEquals("arithmetic_mean combination of:", explanation.get("description"));
Map<String, Object> hitDetailsForHit = getListOfValues(explanation, "details").get(0);
assertTrue((double) hitDetailsForHit.get("value") > 0.0f);
assertEquals("min_max normalization of:", hitDetailsForHit.get("description"));
Map<String, Object> subQueryDetailsForHit = getListOfValues(hitDetailsForHit, "details").get(0);
assertTrue((double) subQueryDetailsForHit.get("value") > 0.0f);
assertFalse(subQueryDetailsForHit.get("description").toString().isEmpty());
assertEquals(1, getListOfValues(subQueryDetailsForHit, "details").size());
}
// Verify scores are properly ordered
List<Double> scores = new ArrayList<>();
for (Map<String, Object> hit : hitsNestedList) {
scores.add((Double) hit.get("_score"));
}
assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(i -> scores.get(i) < scores.get(i + 1)));
} finally {
wipeOfTestResources(TEST_LARGE_DOCS_INDEX_NAME, null, null, SEARCH_PIPELINE);
}
}

@SneakyThrows
private void initializeIndexIfNotExist(String indexName) {
if (TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME)) {
Expand Down Expand Up @@ -521,6 +581,43 @@ private void initializeIndexIfNotExist(String indexName) {
);
addDocsToIndex(TEST_MULTI_DOC_INDEX_NAME);
}

if (TEST_LARGE_DOCS_INDEX_NAME.equals(indexName) && !indexExists(TEST_LARGE_DOCS_INDEX_NAME)) {
prepareKnnIndex(
TEST_LARGE_DOCS_INDEX_NAME,
List.of(
new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE),
new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_2, TEST_DIMENSION, TEST_SPACE_TYPE)
)
);

// Index 1000 documents
for (int i = 0; i < 1000; i++) {
String docText;
if (i % 5 == 0) {
docText = TEST_DOC_TEXT1; // "Hello world"
} else if (i % 7 == 0) {
docText = TEST_DOC_TEXT2; // "Hi to this place"
} else if (i % 11 == 0) {
docText = TEST_DOC_TEXT3; // "We would like to welcome everyone"
} else {
docText = String.format(Locale.ROOT, "Document %d with random content", i);
}

addKnnDoc(
TEST_LARGE_DOCS_INDEX_NAME,
String.valueOf(i),
List.of(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_KNN_VECTOR_FIELD_NAME_2),
List.of(
Floats.asList(createRandomVector(TEST_DIMENSION)).toArray(),
Floats.asList(createRandomVector(TEST_DIMENSION)).toArray()
),
Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
Collections.singletonList(docText)
);
}
assertEquals(1000, getDocCount(TEST_LARGE_DOCS_INDEX_NAME));
}
}

private void addDocsToIndex(final String testMultiDocIndexName) {
Expand Down

0 comments on commit d297e0f

Please sign in to comment.