Skip to content

Commit

Permalink
Expose Tantivy's TermSetQuery
Browse files Browse the repository at this point in the history
  • Loading branch information
aecio committed Apr 24, 2024
1 parent 41f72b2 commit 90619e4
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 0 deletions.
26 changes: 26 additions & 0 deletions src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,32 @@ impl Query {
})
}

/// Construct a Tantivy's TermSetQuery
#[staticmethod]
#[pyo3(signature = (schema, field_name, field_values))]
pub(crate) fn term_set_query(
schema: &Schema,
field_name: &str,
field_values: Vec<&PyAny>,
) -> PyResult<Query> {
let mut terms: Vec<tv::Term> = Vec::new();
for field_value in field_values {
let term = make_term(&schema.inner, field_name, &field_value);
if let Ok(term) = term {
terms.push(term);
} else {
return Err(exceptions::PyTypeError::new_err(format!(
"Cannot create a term from the value: {}",
field_value
)));
}
}
let inner = tv::query::TermSetQuery::new(terms);
Ok(Query {
inner: Box::new(inner),
})
}

/// Construct a Tantivy's AllQuery
#[staticmethod]
pub(crate) fn all_query() -> PyResult<Query> {
Expand Down
4 changes: 4 additions & 0 deletions tantivy/tantivy.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,10 @@ class Query:
def term_query(schema: Schema, field_name: str, field_value: Any, index_option: str = "position") -> Query:
pass

@staticmethod
def term_set_query(schema: Schema, field_name: str, field_values: Sequence[Any]) -> Query:
pass

@staticmethod
def all_query() -> Query:
pass
Expand Down
29 changes: 29 additions & 0 deletions tests/tantivy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,35 @@ def test_term_query(self, ram_index):
searched_doc = index.searcher().doc(doc_address)
assert searched_doc["title"] == ["The Old Man and the Sea"]

def test_term_set_query(self, ram_index):
index = ram_index

# Should match 1 document that contains both terms
terms = ["old", "man"]
query = Query.term_set_query(index.schema, "title", terms)
result = index.searcher().search(query, 10)
assert len(result.hits) == 1
_, doc_address = result.hits[0]
searched_doc = index.searcher().doc(doc_address)
assert searched_doc["title"] == ["The Old Man and the Sea"]

# Should not match any document since the term does not exist in the index
terms = ["a long term that does not exist in the index"]
query = Query.term_set_query(index.schema, "title", terms)
result = index.searcher().search(query, 10)
assert len(result.hits) == 0

# Should not match any document when the terms list is empty
terms = []
query = Query.term_set_query(index.schema, "title", terms)
result = index.searcher().search(query, 10)
assert len(result.hits) == 0

# Should fail to create the query due to the invalid list object in the terms list
with pytest.raises(TypeError, match = r"Cannot create a term from the value: \[\]"):
terms = ["old", [], "man"]
query = Query.term_set_query(index.schema, "title", terms)

def test_all_query(self, ram_index):
index = ram_index
query = Query.all_query()
Expand Down

0 comments on commit 90619e4

Please sign in to comment.