From f83dae4a91de96be9f9c3c6e93f80674919122ef Mon Sep 17 00:00:00 2001
From: Morgan Gallant
Date: Wed, 14 Aug 2024 14:45:22 -0700
Subject: [PATCH] add support for GET schema

Signed-off-by: Morgan Gallant
---
 tests/test_schema.py     | 21 ++++++++++++++
 turbopuffer/namespace.py | 61 +++++++++++++++++++++++++++++++++++++++-
 2 files changed, 81 insertions(+), 1 deletion(-)
 create mode 100644 tests/test_schema.py

diff --git a/tests/test_schema.py b/tests/test_schema.py
new file mode 100644
index 0000000..07ff443
--- /dev/null
+++ b/tests/test_schema.py
@@ -0,0 +1,21 @@
+import turbopuffer as tpuf
+import tests
+
+def test_schema():
+    ns = tpuf.Namespace(tests.test_prefix + "schema")
+
+    # Upsert some data
+    ns.upsert([
+        {'id': 2, 'vector': [2, 2]},
+        {'id': 7, 'vector': [0.7, 0.7], 'attributes': {'hello': 'world', 'test': 'rows'}},
+    ], distance_metric='euclidean_squared')
+
+    # Get the schema for the namespace
+    schema = ns.schema()
+    for attr in ['hello', 'test']:
+        assert schema.get(attr).type == "OptString"
+        assert schema.get(attr).filterable
+        assert schema.get(attr).bm25 is None
+
+    # todo patch schema
+
diff --git a/turbopuffer/namespace.py b/turbopuffer/namespace.py
index 4bc4096..4bb7f6b 100644
--- a/turbopuffer/namespace.py
+++ b/turbopuffer/namespace.py
@@ -8,6 +8,58 @@ from typing import Dict, List, Optional, Iterable, Union, overload
 import turbopuffer as tpuf
 
 
+class BM25Params:
+    """
+    Used for configuring BM25 full-text indexing for a given attribute.
+    """
+
+    language: str
+    stemming: bool
+    remove_stopwords: bool
+    case_sensitive: bool
+
+    def __init__(self, language: str, stemming: bool, remove_stopwords: bool, case_sensitive: bool):
+        self.language = language
+        self.stemming = stemming
+        self.remove_stopwords = remove_stopwords
+        self.case_sensitive = case_sensitive
+
+class AttributeSchema:
+    """
+    The schema for a particular attribute within a namespace.
+    """
+
+    type: str # one of: '?string', '?uint', '[]string', '[]uint'
+    filterable: bool
+    bm25: Optional[BM25Params] = None
+
+    def __init__(self, type: str, filterable: bool, bm25: Optional[BM25Params] = None):
+        self.type = type
+        self.filterable = filterable
+        self.bm25 = bm25
+
+# Type alias for a namespace schema
+NamespaceSchema = Dict[str, AttributeSchema]
+
+def parse_namespace_schema(data: dict) -> NamespaceSchema:
+    namespace_schema = {}
+    for key, value in data.items():
+        bm25_params = value.get('bm25')
+        bm25_instance = None
+        if bm25_params:
+            bm25_instance = BM25Params(
+                language=bm25_params['language'],
+                stemming=bm25_params['stemming'],
+                remove_stopwords=bm25_params['remove_stopwords'],
+                case_sensitive=bm25_params['case_sensitive']
+            )
+        attribute_schema = AttributeSchema(
+            type=value['type'],
+            filterable=value['filterable'],
+            bm25=bm25_instance
+        )
+        namespace_schema[key] = attribute_schema
+    return namespace_schema
 
 class Namespace:
     """
@@ -63,6 +115,13 @@ def refresh_metadata(self):
             }
         else:
             raise APIError(response.status_code, 'Unexpected status code', response.get('content'))
+
+    def schema(self) -> NamespaceSchema:
+        """
+        Returns the current schema for the namespace.
+        """
+        response = self.backend.make_api_request('vectors', self.name, 'schema', method='GET')
+        return parse_namespace_schema(response["content"])
 
     def exists(self) -> bool:
         """
@@ -354,7 +413,7 @@ def recall(self, num=20, top_k=10) -> float:
         content = response.get('content', dict())
         assert 'recall' in content, f'Invalid recall() response: {response}'
         return float(content.get('recall'))
-
+
 
 class NamespaceIterator:
     """