diff --git a/demo/store_demo.cpp b/demo/store_demo.cpp index e6a42aac..d9e70321 100644 --- a/demo/store_demo.cpp +++ b/demo/store_demo.cpp @@ -46,5 +46,12 @@ int main(int argc, const char* argv[]) { for (const auto& text : similar_texts) { std::cout << text << std::endl; } + store->save("./tmp.mnn"); + store.reset(TextVectorStore::load("./tmp.mnn")); + store->set_embedding(embedding); + similar_texts = store->search_similar_texts(text, 2); + for (const auto& text : similar_texts) { + std::cout << text << std::endl; + } return 0; } diff --git a/src/llm.cpp b/src/llm.cpp index 79637832..26c2e9ff 100644 --- a/src/llm.cpp +++ b/src/llm.cpp @@ -791,17 +791,28 @@ VARP Bge::gen_position_ids(int seq_len) { // Embedding end // TextVectorStore strat - TextVectorStore* TextVectorStore::load(const std::string& path) { auto vars = Variable::load(path.c_str()); - return nullptr; - // TODO + if (vars.size() < 2) { + return nullptr; + } + TextVectorStore* store = new TextVectorStore; + store->vectors_ = vars[0]; + for (int i = 1; i < vars.size(); i++) { + const char* txt = vars[i]->readMap(); + store->texts_.push_back(txt); + } + return store; } void TextVectorStore::save(const std::string& path) { std::vector vars; + vars.push_back(vectors_); + for (auto text : texts_) { + auto text_var = _Const(text.data(), {text.size()}, NHWC, halide_type_of()); + vars.push_back(text_var); + } Variable::save(vars, path.c_str()); - // TODO } void TextVectorStore::add_text(const std::string& text) { @@ -812,6 +823,7 @@ void TextVectorStore::add_text(const std::string& text) { } else { vectors_ = _Concat({vectors_, vector}, 0); } + vectors_.fix(VARP::CONSTANT); } void TextVectorStore::add_texts(const std::vector& texts) { @@ -824,8 +836,7 @@ std::vector TextVectorStore::search_similar_texts(const std::string auto vector = text2vector(text); auto dist = _Sqrt(_ReduceSum(_Square(vectors_ - vector), {-1})); auto indices = _Sort(dist, 0, true); - auto ptr = dist->readMap(); - auto iptr = indices->readMap(); + // auto ptr = dist->readMap(); auto idx_ptr = indices->readMap(); std::vector res; for (int i = 0; i < topk; i++) { @@ -848,8 +859,8 @@ void TextVectorStore::bench() { auto vec = _RandomUnifom(shape1, halide_type_of()); auto start = std::chrono::high_resolution_clock::now(); auto dist = _Sqrt(_ReduceSum(_Square(vectors_ - vec), {-1})); - auto ptr = dist->readMap(); auto indices = _Sort(dist, 0, true); + auto ptr = dist->readMap(); auto iptr = indices->readMap(); auto end = std::chrono::high_resolution_clock::now(); auto duration = std::chrono::duration_cast(end - start);