Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support phrase match query #38869

Merged
merged 18 commits into from
Jan 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions internal/core/src/common/EasyAssert.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ enum ErrorCode {
OutOfRange = 2039,
GcpNativeError = 2040,
TextIndexNotFound = 2041,
InvalidParameter = 2042,

KnowhereError = 2099
};
Expand Down
34 changes: 30 additions & 4 deletions internal/core/src/exec/expression/UnaryExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -796,11 +796,12 @@
template <typename T>
VectorPtr
PhyUnaryRangeFilterExpr::ExecRangeVisitorImpl(OffsetVector* input) {
if (expr_->op_type_ == proto::plan::OpType::TextMatch) {
if (expr_->op_type_ == proto::plan::OpType::TextMatch ||
expr_->op_type_ == proto::plan::OpType::PhraseMatch) {
if (has_offset_input_) {
PanicInfo(
OpTypeInvalid,
fmt::format("text match does not support iterative filter"));
fmt::format("match query does not support iterative filter"));
}
return ExecTextMatch();
}
Expand Down Expand Up @@ -1089,8 +1090,33 @@
PhyUnaryRangeFilterExpr::ExecTextMatch() {
using Index = index::TextMatchIndex;
auto query = GetValueFromProto<std::string>(expr_->val_);
auto func = [](Index* index, const std::string& query) -> TargetBitmap {
return index->MatchQuery(query);
int64_t slop = 0;
if (expr_->op_type_ == proto::plan::PhraseMatch) {
// extra_values_ should be non-empty in normal cases. Check its size in case we receive an old-version proto.
if (expr_->extra_values_.size() > 0) {
slop = GetValueFromProto<int64_t>(expr_->extra_values_[0]);
}
if (slop < 0 || slop > std::numeric_limits<uint32_t>::max()) {
throw SegcoreError(

Check warning on line 1100 in internal/core/src/exec/expression/UnaryExpr.cpp

View check run for this annotation

Codecov / codecov/patch

internal/core/src/exec/expression/UnaryExpr.cpp#L1100

Added line #L1100 was not covered by tests
ErrorCode::InvalidParameter,
fmt::format(

Check warning on line 1102 in internal/core/src/exec/expression/UnaryExpr.cpp

View check run for this annotation

Codecov / codecov/patch

internal/core/src/exec/expression/UnaryExpr.cpp#L1102

Added line #L1102 was not covered by tests
"Slop {} is invalid in phrase match query. Should be "
"within [0, UINT32_MAX].",
slop));

Check warning on line 1105 in internal/core/src/exec/expression/UnaryExpr.cpp

View check run for this annotation

Codecov / codecov/patch

internal/core/src/exec/expression/UnaryExpr.cpp#L1105

Added line #L1105 was not covered by tests
}
}
auto op_type = expr_->op_type_;
auto func = [op_type, slop](Index* index,
const std::string& query) -> TargetBitmap {
if (op_type == proto::plan::OpType::TextMatch) {
return index->MatchQuery(query);
} else if (op_type == proto::plan::OpType::PhraseMatch) {
return index->PhraseMatchQuery(query, slop);
} else {
PanicInfo(OpTypeInvalid,

Check warning on line 1116 in internal/core/src/exec/expression/UnaryExpr.cpp

View check run for this annotation

Codecov / codecov/patch

internal/core/src/exec/expression/UnaryExpr.cpp#L1116

Added line #L1116 was not covered by tests
"unsupported operator type for match query: {}",
op_type);
}
};
auto res = ProcessTextMatchIndex(func, query);
return res;
Expand Down
26 changes: 21 additions & 5 deletions internal/core/src/expr/ITypeExpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -349,18 +349,33 @@

class UnaryRangeFilterExpr : public ITypeFilterExpr {
public:
explicit UnaryRangeFilterExpr(const ColumnInfo& column,
proto::plan::OpType op_type,
const proto::plan::GenericValue& val)
: ITypeFilterExpr(), column_(column), op_type_(op_type), val_(val) {
explicit UnaryRangeFilterExpr(
const ColumnInfo& column,
proto::plan::OpType op_type,
const proto::plan::GenericValue& val,
const std::vector<proto::plan::GenericValue>& extra_values)
: ITypeFilterExpr(),
column_(column),
op_type_(op_type),
val_(val),
extra_values_(extra_values) {
}

std::string
ToString() const override {
std::stringstream ss;
ss << "UnaryRangeFilterExpr: {columnInfo:" << column_.ToString()
<< " op_type:" << milvus::proto::plan::OpType_Name(op_type_)
<< " val:" << val_.DebugString() << "}";
<< " val:" << val_.DebugString() << " extra_values: [";

Check warning on line 369 in internal/core/src/expr/ITypeExpr.h

View check run for this annotation

Codecov / codecov/patch

internal/core/src/expr/ITypeExpr.h#L369

Added line #L369 was not covered by tests

for (size_t i = 0; i < extra_values_.size(); i++) {
ss << extra_values_[i].DebugString();
if (i != extra_values_.size() - 1) {
ss << ", ";

Check warning on line 374 in internal/core/src/expr/ITypeExpr.h

View check run for this annotation

Codecov / codecov/patch

internal/core/src/expr/ITypeExpr.h#L371-L374

Added lines #L371 - L374 were not covered by tests
}
}

ss << "]}";

Check warning on line 378 in internal/core/src/expr/ITypeExpr.h

View check run for this annotation

Codecov / codecov/patch

internal/core/src/expr/ITypeExpr.h#L378

Added line #L378 was not covered by tests
return ss.str();
}

Expand Down Expand Up @@ -393,6 +408,7 @@
const ColumnInfo column_;
const proto::plan::OpType op_type_;
const proto::plan::GenericValue val_;
const std::vector<proto::plan::GenericValue> extra_values_;
};

class AlwaysTrueExpr : public ITypeFilterExpr {
Expand Down
21 changes: 21 additions & 0 deletions internal/core/src/index/TextMatchIndex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,4 +291,25 @@
apply_hits(bitset, hits, true);
return bitset;
}

// Runs a phrase-match query and returns the hits as a bitmap.
//
// `query` and `slop` are forwarded unchanged to the tantivy wrapper's
// phrase_match_query(). Before querying, Commit() and Reload() are called
// when shouldTriggerCommit() returns true.
//
// The returned bitmap is sized by should_allocate_bitset_size(hits); if that
// yields an empty bitmap it is returned as-is, otherwise the hits are applied
// to it via apply_hits(bitset, hits, true).
TargetBitmap
TextMatchIndex::PhraseMatchQuery(const std::string& query, uint32_t slop) {
if (shouldTriggerCommit()) {
Commit();
Reload();

Check warning on line 299 in internal/core/src/index/TextMatchIndex.cpp

View check run for this annotation

Codecov / codecov/patch

internal/core/src/index/TextMatchIndex.cpp#L298-L299

Added lines #L298 - L299 were not covered by tests
}

// The count operation of tantivy may return a stale count if the index has been committed with a new tantivy segment.
// So we cannot use the count operation to get the total count for the bitmap.
// Instead, use the maximum offset of the hits to derive the bitmap size here.
auto hits = wrapper_->phrase_match_query(query, slop);
auto cnt = should_allocate_bitset_size(hits);
TargetBitmap bitset(cnt);
if (bitset.empty()) {
return bitset;
}
apply_hits(bitset, hits, true);
return bitset;
}

} // namespace milvus::index
3 changes: 3 additions & 0 deletions internal/core/src/index/TextMatchIndex.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ class TextMatchIndex : public InvertedIndexTantivy<std::string> {
TargetBitmap
MatchQuery(const std::string& query);

TargetBitmap
PhraseMatchQuery(const std::string& query, uint32_t slop);

private:
bool
shouldTriggerCommit();
Expand Down
9 changes: 8 additions & 1 deletion internal/core/src/query/PlanProto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,15 @@ ProtoParser::ParseUnaryRangeExprs(const proto::plan::UnaryRangeExpr& expr_pb) {
auto field_id = FieldId(column_info.field_id());
auto data_type = schema[field_id].get_data_type();
Assert(data_type == static_cast<DataType>(column_info.data_type()));
std::vector<::milvus::proto::plan::GenericValue> extra_values;
for (auto val : expr_pb.extra_values()) {
extra_values.emplace_back(val);
}
return std::make_shared<milvus::expr::UnaryRangeFilterExpr>(
expr::ColumnInfo(column_info), expr_pb.op(), expr_pb.value());
expr::ColumnInfo(column_info),
expr_pb.op(),
expr_pb.value(),
extra_values);
}

expr::TypedExprPtr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ RustResult tantivy_regex_query(void *ptr, const char *pattern);

RustResult tantivy_match_query(void *ptr, const char *query);

RustResult tantivy_phrase_match_query(void *ptr, const char *query, uint32_t slop);

RustResult tantivy_register_tokenizer(void *ptr,
const char *tokenizer_name,
const char *analyzer_params);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
use tantivy::{
query::BooleanQuery,
query::{BooleanQuery, PhraseQuery},
tokenizer::{TextAnalyzer, TokenStream},
Term,
};

use crate::error::Result;
use crate::error::{Result, TantivyBindingError};
use crate::{index_reader::IndexReaderWrapper, tokenizer::standard_analyzer};

impl IndexReaderWrapper {
// split the query string into multiple tokens using index's default tokenizer,
// and then execute the disconjunction of term query.
pub(crate) fn match_query(&self, q: &str) -> Result<Vec<u32>> {
// clone the tokenizer to make `match_query` thread-safe.
let mut tokenizer = self
.index
.tokenizer_for_field(self.field)
Expand All @@ -27,6 +26,31 @@ impl IndexReaderWrapper {
self.search(&query)
}

/// Runs a phrase query for `q` against this reader's field.
///
/// The query string is split into tokens with the analyzer registered for the
/// field (falling back to `standard_analyzer(vec![])` when that lookup fails).
/// Each token becomes one term of the phrase, positioned by its index in the
/// token stream. `slop` is forwarded unchanged to tantivy's `PhraseQuery`.
///
/// When the query yields zero or one token, the call is forwarded to a plain
/// multi-term `BooleanQuery` instead, because tantivy panics when a phrase
/// query is built from fewer than two terms.
pub(crate) fn phrase_match_query(&self, q: &str, slop: u32) -> Result<Vec<u32>> {
    // Work on our own copy of the analyzer. The fallback analyzer is only
    // constructed when the field lookup fails.
    let mut tokenizer = self
        .index
        .tokenizer_for_field(self.field)
        .unwrap_or_else(|_| standard_analyzer(vec![]))
        .clone();
    let mut token_stream = tokenizer.token_stream(q);
    let mut terms: Vec<Term> = Vec::new();
    while token_stream.advance() {
        let token = token_stream.token();
        terms.push(Term::from_field_text(self.field, &token.text));
    }
    if terms.len() <= 1 {
        // tantivy will panic when terms.len() <= 1, so we forward to text match instead.
        let query = BooleanQuery::new_multiterms_query(terms);
        return self.search(&query);
    }
    // Pair every term with its ordinal in the token stream; that ordinal is
    // used as the term's offset inside the phrase.
    let positioned_terms: Vec<(usize, Term)> = terms.into_iter().enumerate().collect();
    let phrase_query = PhraseQuery::new_with_offset_and_slop(positioned_terms, slop);
    self.search(&phrase_query)
}

pub(crate) fn register_tokenizer(&self, tokenizer_name: String, tokenizer: TextAnalyzer) {
self.index.tokenizers().register(&tokenizer_name, tokenizer)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,19 @@ pub extern "C" fn tantivy_match_query(ptr: *mut c_void, query: *const c_char) ->
}
}

/// C entry point for phrase-match queries.
///
/// `ptr` is reinterpreted as a `*mut IndexReaderWrapper` and `query` is
/// converted with `cstr_to_str!`. The outcome of
/// `IndexReaderWrapper::phrase_match_query(query, slop)` is converted into
/// the returned `RustResult`.
///
/// NOTE(review): neither pointer is validated here; presumably `ptr` must
/// point to a live `IndexReaderWrapper` and `query` to a valid C string.
/// Confirm against the callers.
#[no_mangle]
pub extern "C" fn tantivy_phrase_match_query(
    ptr: *mut c_void,
    query: *const c_char,
    slop: u32,
) -> RustResult {
    let reader = ptr as *mut IndexReaderWrapper;
    unsafe {
        let phrase = cstr_to_str!(query);
        let outcome = (*reader).phrase_match_query(phrase, slop);
        outcome.into()
    }
}

#[no_mangle]
pub extern "C" fn tantivy_register_tokenizer(
ptr: *mut c_void,
Expand Down
13 changes: 13 additions & 0 deletions internal/core/thirdparty/tantivy/tantivy-wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,19 @@ struct TantivyIndexWrapper {
return RustArrayWrapper(std::move(res.result_->value.rust_array._0));
}

// Executes a phrase-match query through the tantivy binding.
//
// `query` and `slop` are passed to tantivy_phrase_match_query() together with
// this wrapper's reader_. Asserts that the call reported success and that the
// result is tagged as a RustArray, then returns that array wrapped in a
// RustArrayWrapper.
RustArrayWrapper
phrase_match_query(const std::string& query, uint32_t slop) {
    auto raw_result = tantivy_phrase_match_query(reader_, query.c_str(), slop);
    RustResultWrapper wrapped(raw_result);
    auto& result = *wrapped.result_;
    AssertInfo(result.success,
               "TantivyIndexWrapper.phrase_match_query: {}",
               result.error);
    AssertInfo(
        result.value.tag == Value::Tag::RustArray,
        "TantivyIndexWrapper.phrase_match_query: invalid result type");
    // Hand the array payload over to the returned wrapper.
    return RustArrayWrapper(std::move(result.value.rust_array._0));
}

public:
inline IndexWriter
get_writer() {
Expand Down
3 changes: 2 additions & 1 deletion internal/core/unittest/test_array_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2461,7 +2461,8 @@ TEST(Expr, TestArrayStringMatch) {
milvus::expr::ColumnInfo(
string_array_fid, DataType::ARRAY, testcase.nested_path),
testcase.op_type,
value);
value,
std::vector<proto::plan::GenericValue>{});
BitsetType final;
auto plan =
std::make_shared<plan::FilterBitsNode>(DEFAULT_PLANNODE_ID, expr);
Expand Down
Loading
Loading