Skip to content

Commit

Permalink
add support for GET schema
Browse files Browse the repository at this point in the history
Signed-off-by: Morgan Gallant <[email protected]>
  • Loading branch information
morgangallant committed Aug 14, 2024
1 parent a63d8fe commit f83dae4
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 1 deletion.
21 changes: 21 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import turbopuffer as tpuf
import tests

def test_schema():
    """
    Upserts two rows into a fresh namespace, then checks the schema that
    ns.schema() reports for the two attributes present on the second row.
    """
    ns = tpuf.Namespace(tests.test_prefix + "schema")

    # Seed the namespace; only the second row carries attributes.
    rows = [
        {'id': 2, 'vector': [2, 2]},
        {'id': 7, 'vector': [0.7, 0.7], 'attributes': {'hello': 'world', 'test': 'rows'}},
    ]
    ns.upsert(rows, distance_metric='euclidean_squared')

    # Fetch the schema and verify each attribute's entry.
    schema = ns.schema()
    for attr in ('hello', 'test'):
        attr_schema = schema.get(attr)
        assert attr_schema.type == "OptString"
        assert attr_schema.filterable
        assert attr_schema.bm25 is None

    # TODO: patch schema

61 changes: 60 additions & 1 deletion turbopuffer/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,58 @@
from typing import Dict, List, Optional, Iterable, Union, overload
import turbopuffer as tpuf

class BM25Params:
    """
    Used for configuring BM25 full-text indexing for a given attribute.

    A plain value holder: the constructor stores its four arguments as
    same-named attributes and performs no validation.
    """

    language: str
    stemming: bool
    remove_stopwords: bool
    case_sensitive: bool

    def __init__(self, language: str, stemming: bool, remove_stopwords: bool, case_sensitive: bool):
        self.language = language
        self.stemming = stemming
        self.remove_stopwords = remove_stopwords
        self.case_sensitive = case_sensitive

    def __repr__(self) -> str:
        # Without this, printing a schema shows only an object address.
        return (
            f"{self.__class__.__name__}("
            f"language={self.language!r}, "
            f"stemming={self.stemming!r}, "
            f"remove_stopwords={self.remove_stopwords!r}, "
            f"case_sensitive={self.case_sensitive!r})"
        )

class AttributeSchema:
    """
    The schema for a particular attribute within a namespace.

    A plain value holder: the constructor stores its arguments as
    same-named attributes and performs no validation. `bm25` is None
    unless BM25 parameters are supplied.
    """

    # one of: '?string', '?uint', '[]string', '[]uint'
    # NOTE(review): tests/test_schema.py asserts the value "OptString" for
    # string attributes -- confirm which spelling the API actually returns.
    type: str
    filterable: bool
    bm25: Optional[BM25Params] = None

    def __init__(self, type: str, filterable: bool, bm25: Optional[BM25Params] = None):
        self.type = type
        self.filterable = filterable
        self.bm25 = bm25

    def __repr__(self) -> str:
        # Without this, printing a schema shows only an object address.
        return (
            f"{self.__class__.__name__}("
            f"type={self.type!r}, "
            f"filterable={self.filterable!r}, "
            f"bm25={self.bm25!r})"
        )

# Type alias for a namespace schema: a dict keyed by attribute name (str)
# whose values are the AttributeSchema objects describing each attribute.
NamespaceSchema = Dict[str, AttributeSchema]

def parse_namespace_schema(data: dict) -> NamespaceSchema:
    """
    Builds a NamespaceSchema from a decoded schema payload.

    `data` maps attribute names to dicts. Each dict must carry 'type' and
    'filterable'. When its 'bm25' entry is present and truthy, that entry
    must carry 'language', 'stemming', 'remove_stopwords' and
    'case_sensitive'; otherwise the attribute's bm25 is left as None.

    Raises KeyError if any required key is missing.
    """
    parsed = {}
    for name, spec in data.items():
        raw_bm25 = spec.get('bm25')
        # A falsy value covers both an absent key and an explicit null/empty entry.
        bm25 = BM25Params(
            language=raw_bm25['language'],
            stemming=raw_bm25['stemming'],
            remove_stopwords=raw_bm25['remove_stopwords'],
            case_sensitive=raw_bm25['case_sensitive'],
        ) if raw_bm25 else None
        parsed[name] = AttributeSchema(
            type=spec['type'],
            filterable=spec['filterable'],
            bm25=bm25,
        )
    return parsed

class Namespace:
"""
Expand Down Expand Up @@ -63,6 +115,13 @@ def refresh_metadata(self):
}
else:
raise APIError(response.status_code, 'Unexpected status code', response.get('content'))

def schema(self) -> NamespaceSchema:
    """
    Fetches and returns the current schema for the namespace.

    Calls self.backend.make_api_request('vectors', self.name, 'schema',
    method='GET') and hands the response's "content" entry to
    parse_namespace_schema.
    """
    schema_response = self.backend.make_api_request('vectors', self.name, 'schema', method='GET')
    return parse_namespace_schema(schema_response["content"])

def exists(self) -> bool:
"""
Expand Down Expand Up @@ -354,7 +413,7 @@ def recall(self, num=20, top_k=10) -> float:
content = response.get('content', dict())
assert 'recall' in content, f'Invalid recall() response: {response}'
return float(content.get('recall'))


class NamespaceIterator:
"""
Expand Down

0 comments on commit f83dae4

Please sign in to comment.