From 1b0d29bc3c30fb84da6ad00750a904be0c002a1a Mon Sep 17 00:00:00 2001
From: Brian Sam-Bodden
Date: Mon, 13 May 2024 12:37:08 -0700
Subject: [PATCH] Fix: Issue with search dialect 3 and JSON (resolves #140)

---
Review revision: flattened _handle_dialect_3 with guard clauses and corrected
its docstring (single-item lists keep their decoded type, the dict is updated
in place); removed unused imports and debug prints from the new test and
shared its result assertions. The blob ids on the "index" lines predate this
revision.

 redisvl/index/index.py             |  50 ++++++++++++-
 tests/integration/test_dialects.py | 102 +++++++++++++++++++++++++++++
 2 files changed, 151 insertions(+), 1 deletion(-)
 create mode 100644 tests/integration/test_dialects.py

diff --git a/redisvl/index/index.py b/redisvl/index/index.py
index 2162f4de..88124224 100644
--- a/redisvl/index/index.py
+++ b/redisvl/index/index.py
@@ -32,6 +32,42 @@
 logger = get_logger(__name__)
 
 
+def _handle_dialect_3(result: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Flatten the JSON-encoded list values of a dialect 3 result, in place.
+
+    Each string value that decodes to a JSON list is replaced:
+    - If the list has exactly one item, that item is used as-is and keeps
+      its decoded type (for example, a number stays a number).
+    - Otherwise the items are joined into a single ", "-separated string,
+      so an empty list becomes an empty string.
+
+    Non-string values, strings that are not valid JSON, and JSON values
+    other than lists are left untouched.
+
+    Args:
+        result (Dict[str, Any]): The dictionary containing the results to process.
+
+    Returns:
+        Dict[str, Any]: The same dictionary object, with updated values.
+    """
+    for field, value in result.items():
+        if not isinstance(value, str):
+            continue
+
+        try:
+            parsed_value = json.loads(value)
+        except json.JSONDecodeError:
+            continue  # Skip processing if value is not valid JSON
+
+        if not isinstance(parsed_value, list):
+            continue
+
+        # Only existing keys are reassigned, so the dict size never changes
+        # and updating it while iterating over items() is safe.
+        result[field] = (
+            parsed_value[0]
+            if len(parsed_value) == 1
+            else ", ".join(map(str, parsed_value))
+        )
+
+    return result
+
+
 def process_results(
     results: "Result", query: BaseQuery, storage_type: StorageType
 ) -> List[Dict[str, Any]]:
@@ -81,7 +123,13 @@ def _process(doc: "Document") -> Dict[str, Any]:
 
         return doc_dict
 
-    return [_process(doc) for doc in results.docs]
+    processed_results = [_process(doc) for doc in results.docs]
+
+    # Handle dialect 3 responses
+    if query._dialect == 3:
+        processed_results = [_handle_dialect_3(result) for result in processed_results]
+
+    return processed_results
 
 
 def check_modules_present():
diff --git a/tests/integration/test_dialects.py b/tests/integration/test_dialects.py
new file mode 100644
index 00000000..55184163
--- /dev/null
+++ b/tests/integration/test_dialects.py
@@ -0,0 +1,102 @@
+import pytest
+from redis import Redis
+
+from redisvl.index import SearchIndex
+from redisvl.query import FilterQuery, VectorQuery
+from redisvl.query.filter import Tag
+from redisvl.schema.schema import IndexSchema
+
+
+@pytest.fixture
+def sample_data():
+    return [
+        {
+            "name": "Noise-cancelling Bluetooth headphones",
+            "description": "Wireless Bluetooth headphones with noise-cancelling technology",
+            "connection": {"wireless": True, "type": "Bluetooth"},
+            "price": 99.98,
+            "stock": 25,
+            "colors": ["black", "silver"],
+            "embedding": [0.87, -0.15, 0.55, 0.03],
+            "embeddings": [[0.56, -0.34, 0.69, 0.02], [0.94, -0.23, 0.45, 0.19]],
+        },
+        {
+            "name": "Wireless earbuds",
+            "description": "Wireless Bluetooth in-ear headphones",
+            "connection": {"wireless": True, "type": "Bluetooth"},
+            "price": 64.99,
+            "stock": 17,
+            "colors": ["red", "black", "white"],
+            "embedding": [-0.7, -0.51, 0.88, 0.14],
+            "embeddings": [[0.54, -0.14, 0.79, 0.92], [0.94, -0.93, 0.45, 0.16]],
+        },
+    ]
+
+
+@pytest.fixture
+def schema_dict():
+    return {
+        "index": {"name": "products", "prefix": "product", "storage_type": "json"},
+        "fields": [
+            {"name": "name", "type": "text"},
+            {"name": "description", "type": "text"},
+            {"name": "connection_type", "path": "$.connection.type", "type": "tag"},
+            {"name": "price", "type": "numeric"},
+            {"name": "stock", "type": "numeric"},
+            {"name": "color", "path": "$.colors.*", "type": "tag"},
+            {
+                "name": "embedding",
+                "type": "vector",
+                "attrs": {"dims": 4, "algorithm": "flat", "distance_metric": "cosine"},
+            },
+            {
+                "name": "embeddings",
+                "path": "$.embeddings[*]",
+                "type": "vector",
+                "attrs": {"dims": 4, "algorithm": "hnsw", "distance_metric": "l2"},
+            },
+        ],
+    }
+
+
+@pytest.fixture
+def index(sample_data, redis_url, schema_dict):
+    index_schema = IndexSchema.from_dict(schema_dict)
+    redis_client = Redis.from_url(redis_url)
+    index = SearchIndex(index_schema, redis_client)
+    index.create(overwrite=True, drop=True)
+    index.load(sample_data)
+    yield index
+    index.delete(drop=True)
+
+
+def _assert_unwrapped(results):
+    """Assert that no returned field is still a list (or, for price, a string)."""
+    assert len(results) > 0
+    for result in results:
+        assert not isinstance(result["name"], list)
+        assert not isinstance(result["description"], list)
+        assert not isinstance(result["price"], (list, str))
+
+
+def test_dialect_3_json(index, sample_data):
+    # Create a VectorQuery with dialect 3
+    vector_query = VectorQuery(
+        vector=[0.23, 0.12, -0.03, 0.98],
+        vector_field_name="embedding",
+        return_fields=["name", "description", "price"],
+        dialect=3,
+    )
+
+    # Execute the query and check the format of the results
+    _assert_unwrapped(index.query(vector_query))
+
+    # Create a FilterQuery with dialect 3
+    filter_query = FilterQuery(
+        filter_expression=Tag("color") == "black",
+        return_fields=["name", "description", "price"],
+        dialect=3,
+    )
+
+    # Execute the query and check the format of the results
+    _assert_unwrapped(index.query(filter_query))