Skip to content

Commit

Permalink
Enabled the IVF algorithm to work with Filters of K-NN Query. (#1013) (
Browse files Browse the repository at this point in the history
…#1015)

Signed-off-by: Navneet Verma <[email protected]>
(cherry picked from commit 85c5a3a)

Co-authored-by: Navneet Verma <[email protected]>
  • Loading branch information
opensearch-trigger-bot[bot] and navneet1v authored Aug 1, 2023
1 parent 591fff6 commit d6e166f
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.9...2.x)
### Features
### Enhancements
* Enabled the IVF algorithm to work with Filters of K-NN Query. [#1013](https://github.com/opensearch-project/k-NN/pull/1013)
### Bug Fixes
### Infrastructure
### Documentation
Expand Down
20 changes: 19 additions & 1 deletion src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.knn.index.query;

import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang.StringUtils;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.search.FilteredDocIdSetIterator;
Expand Down Expand Up @@ -49,6 +50,7 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -290,7 +292,7 @@ private Map<Integer, Float> doExactSearch(final LeafReaderContext leafReaderCont
float[] queryVector = this.knnQuery.getQueryVector();
try {
final BinaryDocValues values = DocValues.getBinary(leafReaderContext.reader(), fieldInfo.getName());
final SpaceType spaceType = SpaceType.getSpace(fieldInfo.getAttribute(SPACE_TYPE));
final SpaceType spaceType = getSpaceType(fieldInfo);
// Creating min heap and init with MAX DocID and Score as -INF.
final HitQueue queue = new HitQueue(this.knnQuery.getK(), true);
ScoreDoc topDoc = queue.top();
Expand Down Expand Up @@ -351,4 +353,20 @@ public static float normalizeScore(float score) {
if (score >= 0) return 1 / (1 + score);
return -score + 1;
}

private SpaceType getSpaceType(final FieldInfo fieldInfo) {
final String spaceTypeString = fieldInfo.getAttribute(SPACE_TYPE);
if (StringUtils.isNotEmpty(spaceTypeString)) {
return SpaceType.getSpace(spaceTypeString);
}

final String modelId = fieldInfo.getAttribute(MODEL_ID);
if (StringUtils.isNotEmpty(modelId)) {
ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
return modelMetadata.getSpaceType();
}
throw new IllegalArgumentException(
String.format(Locale.ROOT, "Unable to find the Space Type from Field Info attribute for field %s", fieldInfo.getName())
);
}
}
35 changes: 31 additions & 4 deletions src/test/java/org/opensearch/knn/index/FaissIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ public void testDocDeletion() throws IOException {
deleteKnnDoc(INDEX_NAME, "1");
}

public void testEndToEnd_fromModel() throws IOException, InterruptedException {
public void testKNNQuery_withModelDifferentCombination_thenSuccess() throws IOException, InterruptedException {
String modelId = "test-model";
int dimension = 128;

Expand Down Expand Up @@ -270,10 +270,9 @@ public void testEndToEnd_fromModel() throws IOException, InterruptedException {
// Index some documents
int numDocs = 100;
for (int i = 0; i < numDocs; i++) {
Float[] indexVector = new Float[dimension];
float[] indexVector = new float[dimension];
Arrays.fill(indexVector, (float) i);

addKnnDoc(indexName, Integer.toString(i), fieldName, indexVector);
addKnnDocWithAttributes(indexName, Integer.toString(i), fieldName, indexVector, ImmutableMap.of("rating", String.valueOf(i)));
}

// Run search and ensure that the values returned are expected
Expand All @@ -287,6 +286,34 @@ public void testEndToEnd_fromModel() throws IOException, InterruptedException {
for (int i = 0; i < k; i++) {
assertEquals(numDocs - i - 1, Integer.parseInt(results.get(i).getDocId()));
}

// doing exact search with filters
Response exactSearchFilteredResponse = searchKNNIndex(
indexName,
new KNNQueryBuilder(fieldName, queryVector, k, QueryBuilders.rangeQuery("rating").gte("90").lte("99")),
k
);
List<KNNResult> exactSearchFilteredResults = parseSearchResponse(
EntityUtils.toString(exactSearchFilteredResponse.getEntity()),
fieldName
);
for (int i = 0; i < k; i++) {
assertEquals(numDocs - i - 1, Integer.parseInt(exactSearchFilteredResults.get(i).getDocId()));
}

// doing exact search with filters
Response aNNSearchFilteredResponse = searchKNNIndex(
indexName,
new KNNQueryBuilder(fieldName, queryVector, k, QueryBuilders.rangeQuery("rating").gte("80").lte("99")),
k
);
List<KNNResult> aNNSearchFilteredResults = parseSearchResponse(
EntityUtils.toString(aNNSearchFilteredResponse.getEntity()),
fieldName
);
for (int i = 0; i < k; i++) {
assertEquals(numDocs - i - 1, Integer.parseInt(aNNSearchFilteredResults.get(i).getDocId()));
}
}

@SneakyThrows
Expand Down
22 changes: 22 additions & 0 deletions src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -1322,4 +1322,26 @@ protected void addKnnDocWithAttributes(String docId, float[] vector, Map<String,
Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
}

protected void addKnnDocWithAttributes(
String indexName,
String docId,
String vectorFieldName,
float[] vector,
Map<String, String> fieldValues
) throws IOException {
Request request = new Request("POST", "/" + indexName + "/_doc/" + docId + "?refresh=true");

final XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field(vectorFieldName, vector);
for (String fieldName : fieldValues.keySet()) {
builder.field(fieldName, fieldValues.get(fieldName));
}
builder.endObject();
request.setJsonEntity(Strings.toString(builder));
client().performRequest(request);

request = new Request("POST", "/" + indexName + "/_refresh");
Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
}
}

0 comments on commit d6e166f

Please sign in to comment.