Skip to content

Commit

Permalink
Fix KNNQuery Precision
Browse files Browse the repository at this point in the history
Signed-off-by: luyuncheng <[email protected]>
  • Loading branch information
luyuncheng committed Mar 1, 2024
1 parent 045c805 commit 4436e89
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package org.opensearch.client.opensearch._types.query_dsl;

import jakarta.json.stream.JsonGenerator;
import java.math.BigDecimal;
import java.util.function.Function;
import javax.annotation.Nullable;
import org.opensearch.client.json.JsonpDeserializable;
Expand All @@ -19,6 +20,8 @@
import org.opensearch.client.util.ApiTypeHelper;
import org.opensearch.client.util.ObjectBuilder;

import static java.math.RoundingMode.HALF_UP;

@JsonpDeserializable
public class KnnQuery extends QueryBase implements QueryVariant {
private final String field;
Expand Down Expand Up @@ -93,7 +96,9 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
generator.writeKey("vector");
generator.writeStartArray();
for (float value : this.vector) {
generator.write(value);
BigDecimal b = new BigDecimal(value);
double T = b.setScale(6, HALF_UP).doubleValue();
generator.write(T);
}
generator.writeEnd();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,11 @@ public void toBuilder() {

assertEquals(toJson(copied), toJson(origin));
}

@Test
public void toBuilderPrecision() {
KnnQuery origin = new KnnQuery.Builder().field("field").vector(new float[] { 0.1f, 0.4f }).k(1).build();

assertEquals(toJson(origin), "{\"field\":{\"vector\":[0.1,0.4],\"k\":1}}");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
package org.opensearch.client.opensearch.model;

import org.junit.Test;
import org.opensearch.client.json.jackson.JacksonJsonpMapper;
import org.opensearch.client.json.jsonb.JsonbJsonpMapper;
import org.opensearch.client.opensearch.core.SearchRequest;

Expand Down Expand Up @@ -61,4 +62,22 @@ public void testParametersNotInJson() {
assertNull(request.q());

}

@Test
public void testKnnVectorPrecision() {

float[] vector = {0.4f, 0.3f};
SearchRequest request = new SearchRequest.Builder().q("knn")
.query(q -> q.knn(k -> k.field("values").vector(vector).k(1)))
.build();

JacksonJsonpMapper mapper = new JacksonJsonpMapper();
String str = toJson(request, mapper);
assertEquals("{\"query\":{\"knn\":{\"values\":{\"vector\":[0.4,0.3],\"k\":1}}}}", str);

request = fromJson(str, SearchRequest.class, mapper);

assertTrue(request.query().isKnn());
assertNull(request.q());
}
}

0 comments on commit 4436e89

Please sign in to comment.