Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Decode search results at field level #3309

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dev_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ urllib3<2
uvloop
vulture>=2.3.0
wheel>=0.30.0
numpy>=1.24.0
5 changes: 3 additions & 2 deletions docs/examples/search_vector_similarity_examples.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions redis/commands/search/_util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
def to_string(s, encoding: str = "utf-8"):
    """Coerce *s* to ``str`` if it is a byte string.

    - **s**: the value to convert; ``str`` and ``bytes`` are handled,
      anything else is returned untouched.
    - **encoding**: codec used to decode ``bytes`` (defaults to UTF-8).
      Undecodable bytes are dropped (``errors="ignore"``) rather than
      raising, so search results never fail on malformed data.
    """
    if isinstance(s, str):
        return s
    elif isinstance(s, bytes):
        return s.decode(encoding, "ignore")
    else:
        return s  # Not a string we care about
1 change: 1 addition & 0 deletions redis/commands/search/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def _parse_search(self, res, **kwargs):
duration=kwargs["duration"],
has_payload=kwargs["query"]._with_payloads,
with_scores=kwargs["query"]._with_scores,
field_encodings=kwargs["query"]._return_fields_decode_as,
)

def _parse_aggregate(self, res, **kwargs):
Expand Down
23 changes: 19 additions & 4 deletions redis/commands/search/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(self, query_string: str) -> None:
self._in_order: bool = False
self._sortby: Optional[SortbyField] = None
self._return_fields: List = []
self._return_fields_decode_as: dict = {}
self._summarize_fields: List = []
self._highlight_fields: List = []
self._language: Optional[str] = None
Expand All @@ -53,13 +54,27 @@ def limit_ids(self, *ids) -> "Query":

def return_fields(self, *fields) -> "Query":
    """Add fields to the list of fields to return.

    Each field is routed through :meth:`return_field` so that it also
    picks up the default per-field decoding behavior.
    """
    for field in fields:
        self.return_field(field)
    return self

def return_field(
    self,
    field: str,
    as_field: Optional[str] = None,
    decode_field: Optional[bool] = True,
    encoding: Optional[str] = "utf8",
) -> "Query":
    """
    Add a field to the list of fields to return.

    - **field**: The field to include in query results
    - **as_field**: The alias for the field
    - **decode_field**: Whether to decode the field from bytes to string
    - **encoding**: The encoding to use when decoding the field

    A ``None`` entry in ``_return_fields_decode_as`` marks the field as
    "return raw bytes, do not decode".
    """
    self._return_fields.append(field)
    # NOTE(review): the encoding map is keyed by the original field name even
    # when an AS alias is given; if the server returns results under the
    # alias, the lookup in Result will miss — confirm against server behavior.
    self._return_fields_decode_as[field] = encoding if decode_field else None
    if as_field is not None:
        # Extends the args list with the two tokens RediSearch expects.
        self._return_fields += ("AS", as_field)
    return self
Expand Down
44 changes: 29 additions & 15 deletions redis/commands/search/result.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

from ._util import to_string
from .document import Document

Expand All @@ -9,11 +11,19 @@ class Result:
"""

def __init__(
self, res, hascontent, duration=0, has_payload=False, with_scores=False
self,
res,
hascontent,
duration=0,
has_payload=False,
with_scores=False,
field_encodings: Optional[dict] = None,
):
"""
- **snippets**: An optional dictionary of the form
{field: snippet_size} for snippet formatting
- duration: the execution time of the query
- has_payload: whether the query has payloads
- with_scores: whether the query has scores
- field_encodings: a dictionary of field encodings if any is provided
"""

self.total = res[0]
Expand All @@ -39,18 +49,22 @@ def __init__(

fields = {}
if hascontent and res[i + fields_offset] is not None:
fields = (
dict(
dict(
zip(
map(to_string, res[i + fields_offset][::2]),
map(to_string, res[i + fields_offset][1::2]),
)
)
)
if hascontent
else {}
)
keys = map(to_string, res[i + fields_offset][::2])
values = res[i + fields_offset][1::2]

for key, value in zip(keys, values):
if field_encodings is None or key not in field_encodings:
fields[key] = to_string(value)
gerzse marked this conversation as resolved.
Show resolved Hide resolved
continue

encoding = field_encodings[key]

# If the encoding is None, we don't need to decode the value
if encoding is None:
fields[key] = value
else:
fields[key] = to_string(value, encoding=encoding)

try:
del fields["id"]
except KeyError:
Expand Down
75 changes: 73 additions & 2 deletions tests/test_asyncio/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,30 @@
import time
from io import TextIOWrapper

import numpy as np
import pytest
import pytest_asyncio
import redis.asyncio as redis
import redis.commands.search
import redis.commands.search.aggregation as aggregations
import redis.commands.search.reducers as reducers
from redis.commands.search import AsyncSearch
from redis.commands.search.field import GeoField, NumericField, TagField, TextField
from redis.commands.search.indexDefinition import IndexDefinition
from redis.commands.search.field import (
GeoField,
NumericField,
TagField,
TextField,
VectorField,
)
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
from redis.commands.search.query import GeoFilter, NumericFilter, Query
from redis.commands.search.result import Result
from redis.commands.search.suggestion import Suggestion
from tests.conftest import (
assert_resp_response,
is_resp2_connection,
skip_if_redis_enterprise,
skip_if_resp_version,
skip_ifmodversion_lt,
)

Expand All @@ -37,6 +45,11 @@ async def decoded_r(create_redis, stack_url):
return await create_redis(decode_responses=True, url=stack_url)


@pytest_asyncio.fixture()
async def binary_client(create_redis, stack_url):
    # Client with decode_responses=False so search results arrive as raw
    # bytes — required for the per-field decoding tests below.
    return await create_redis(decode_responses=False, url=stack_url)


async def waitForIndex(env, idx, timeout=None):
delay = 0.1
while True:
Expand Down Expand Up @@ -1560,3 +1573,61 @@ async def test_query_timeout(decoded_r: redis.Redis):
q2 = Query("foo").timeout("not_a_number")
with pytest.raises(redis.ResponseError):
await decoded_r.ft().search(q2)


@pytest.mark.redismod
@skip_if_resp_version(3)
async def test_binary_and_text_fields(binary_client):
    """Verify per-field decoding: a vector field returned as raw bytes and a
    text field decoded to str, on a client with decode_responses=False."""
    assert (
        binary_client.get_connection_kwargs()["decode_responses"] is False
    ), "This feature is only available when decode_responses is False"

    fake_vec = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32)

    index_name = "mixed_index"
    mixed_data = {"first_name": "🐍python", "vector_emb": fake_vec.tobytes()}
    await binary_client.hset(f"{index_name}:1", mapping=mixed_data)

    schema = (
        TagField("first_name"),
        VectorField(
            "embeddings_bio",
            algorithm="HNSW",
            attributes={
                "TYPE": "FLOAT32",
                "DIM": 4,
                "DISTANCE_METRIC": "COSINE",
            },
        ),
    )

    await binary_client.ft(index_name).create_index(
        fields=schema,
        definition=IndexDefinition(
            prefix=[f"{index_name}:"], index_type=IndexType.HASH
        ),
    )

    # Sanity check: the raw hash round-trips the vector bytes unchanged.
    bytes_person_1 = await binary_client.hget(f"{index_name}:1", "vector_emb")
    decoded_vec_from_hash = np.frombuffer(bytes_person_1, dtype=np.float32)
    assert np.array_equal(decoded_vec_from_hash, fake_vec), "The vectors are not equal"

    query = (
        Query("*")
        .return_field("vector_emb", decode_field=False)
        .return_field("first_name", decode_field=True)
    )
    result = await binary_client.ft(index_name).search(query=query, query_params={})
    docs = result.docs

    # vector_emb was requested undecoded, so it must still be raw bytes.
    decoded_vec_from_search_results = np.frombuffer(
        docs[0]["vector_emb"], dtype=np.float32
    )

    assert np.array_equal(
        decoded_vec_from_search_results, fake_vec
    ), "The vectors are not equal"

    assert (
        docs[0]["first_name"] == mixed_data["first_name"]
    ), "The first name is not decoded correctly"
65 changes: 65 additions & 0 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import time
from io import TextIOWrapper

import numpy as np
import pytest
import redis
import redis.commands.search
Expand All @@ -29,6 +30,7 @@
assert_resp_response,
is_resp2_connection,
skip_if_redis_enterprise,
skip_if_resp_version,
skip_ifmodversion_lt,
)

Expand Down Expand Up @@ -113,6 +115,13 @@ def client(request, stack_url):
return r


@pytest.fixture
def binary_client(request, stack_url):
    # Client with decode_responses=False so search results arrive as raw
    # bytes — required for the per-field decoding tests below.
    r = _get_client(redis.Redis, request, decode_responses=False, from_url=stack_url)
    r.flushdb()
    return r


@pytest.mark.redismod
def test_client(client):
num_docs = 500
Expand Down Expand Up @@ -1705,6 +1714,62 @@ def test_search_return_fields(client):
assert "telmatosaurus" == total["results"][0]["extra_attributes"]["txt"]


@pytest.mark.redismod
@skip_if_resp_version(3)
def test_binary_and_text_fields(binary_client):
    """Verify per-field decoding: a vector field returned as raw bytes and a
    text field decoded to str, on a client with decode_responses=False."""
    assert (
        binary_client.get_connection_kwargs()["decode_responses"] is False
    ), "This feature is only available when decode_responses is False"

    fake_vec = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32)

    index_name = "mixed_index"
    mixed_data = {"first_name": "🐍python", "vector_emb": fake_vec.tobytes()}
    binary_client.hset(f"{index_name}:1", mapping=mixed_data)

    schema = (
        TagField("first_name"),
        VectorField(
            "embeddings_bio",
            algorithm="HNSW",
            attributes={
                "TYPE": "FLOAT32",
                "DIM": 4,
                "DISTANCE_METRIC": "COSINE",
            },
        ),
    )

    binary_client.ft(index_name).create_index(
        fields=schema,
        definition=IndexDefinition(
            prefix=[f"{index_name}:"], index_type=IndexType.HASH
        ),
    )

    # Sanity check: the raw hash round-trips the vector bytes unchanged.
    bytes_person_1 = binary_client.hget(f"{index_name}:1", "vector_emb")
    decoded_vec_from_hash = np.frombuffer(bytes_person_1, dtype=np.float32)
    assert np.array_equal(decoded_vec_from_hash, fake_vec), "The vectors are not equal"

    query = (
        Query("*")
        .return_field("vector_emb", decode_field=False)
        .return_field("first_name", decode_field=True)
    )
    docs = binary_client.ft(index_name).search(query=query, query_params={}).docs

    # vector_emb was requested undecoded, so it must still be raw bytes.
    decoded_vec_from_search_results = np.frombuffer(
        docs[0]["vector_emb"], dtype=np.float32
    )

    assert np.array_equal(
        decoded_vec_from_search_results, fake_vec
    ), "The vectors are not equal"

    assert (
        docs[0]["first_name"] == mixed_data["first_name"]
    ), "The first name is not decoded correctly"

gerzse marked this conversation as resolved.
Show resolved Hide resolved

@pytest.mark.redismod
def test_synupdate(client):
definition = IndexDefinition(index_type=IndexType.HASH)
Expand Down
Loading