Skip to content

Commit

Permalink
fix the issue caused by passing in invalid vectors
Browse files Browse the repository at this point in the history
Signed-off-by: jinjiabao.jjb <[email protected]>
  • Loading branch information
jinjiabao.jjb committed Jan 22, 2025
1 parent b28a8d0 commit 83fbed5
Show file tree
Hide file tree
Showing 10 changed files with 247 additions and 35 deletions.
27 changes: 18 additions & 9 deletions extern/diskann/DiskANN/src/pq.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1906,16 +1906,25 @@ void generate_disk_quantized_data(const T* train_data, size_t train_size, size_t
size_t sample_size = std::min(train_size, (size_t)(train_size * p_val));
sample_size = std::max(sample_size, std::min(train_size, (size_t)MIN_SAMPLE_NUM));
sample_size = std::min(sample_size, (size_t)MAX_SAMPLE_NUM);
auto sample_data = train_data;
std::shared_ptr<T[]> new_train_data;
if (compare_metric == diskann::Metric::COSINE) {
new_train_data.reset(new T[train_dim * sample_size]);
memcpy(new_train_data.get(), train_data, train_dim * sample_size * sizeof(T));
for (int i = 0; i < sample_size; ++i)
{
normalize(new_train_data.get() + i * train_dim, train_dim);
std::shared_ptr<T[]> new_train_data = std::shared_ptr<T[]>(new T[train_dim * sample_size]);
size_t valid_size = 0;
for (int i = 0; i < sample_size; ++i)
{
auto norm = get_norm(train_data + i * train_dim, train_dim);
if (std::abs(norm) < 1e-6 || std::isnan(norm)) {
continue;
}
sample_data = new_train_data.get();
memcpy(new_train_data.get() + valid_size * train_dim, train_data + i * train_dim, train_dim * sizeof(T));
if (compare_metric == diskann::Metric::COSINE) {
normalize(new_train_data.get() + valid_size * train_dim, train_dim);
}
valid_size ++;
}
auto sample_data = new_train_data.get();
sample_size = valid_size;
if (sample_size < 2) {
throw std::runtime_error("fail to train pq: sample_size " + std::to_string(sample_size) +
" is too small, while train_size is " + std::to_string(train_size));
}

// diskann::cout << "Training data with " << sample_size << " samples loaded." << std::endl;
Expand Down
103 changes: 84 additions & 19 deletions tests/fixtures/test_dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,37 @@ GenerateRandomDataset(uint64_t dim,
return base;
}

// Builds a dataset of `count` random `dim`-dimensional float vectors and then
// deliberately corrupts a small fraction of them:
//   - with probability ~1% a vector gets NaN written into its first component;
//   - with probability ~1% a vector is overwritten with all zeros.
// Ids are consecutive and start at 10086. The corruption pattern is seeded from
// std::random_device, so it differs from run to run.
static TestDataset::DatasetPtr
GenerateNanRandomDataset(uint64_t dim, uint64_t count, std::string metric_str = "l2") {
    auto base = vsag::Dataset::Make();
    bool need_normalize = (metric_str != "cosine");

    std::vector<float> vecs =
        fixtures::generate_vectors(count, dim, need_normalize, fixtures::RandomValue(0, 564));
    std::random_device rd;
    std::mt19937 g(rd());
    std::uniform_real_distribution real;
    // Unsigned 64-bit indices: `count` and `dim` are uint64_t, so a signed int
    // index would be a sign-compare hazard and could overflow for large inputs.
    for (uint64_t i = 0; i < count; ++i) {
        // Exactly one draw per vector decides whether (and how) it is corrupted.
        const float r = static_cast<float>(real(g));
        if (r < 0.01) {
            // Dirty vector: a single NaN component.
            vecs[i * dim] = std::numeric_limits<float>::quiet_NaN();
        } else if (r < 0.02) {
            // Dirty vector: all components zero.
            for (uint64_t j = 0; j < dim; ++j) {
                vecs[i * dim + j] = 0.0f;
            }
        }
    }

    std::vector<int64_t> ids(count);
    std::iota(ids.begin(), ids.end(), 10086);
    base->Dim(dim)
        ->Ids(CopyVector(ids))
        ->Float32Vectors(CopyVector(vecs))
        ->NumElements(count)
        ->Owner(true);
    return base;
}

static std::pair<float*, int64_t*>
CalDistanceFloatMetrix(const vsag::DatasetPtr query,
const vsag::DatasetPtr base,
Expand Down Expand Up @@ -238,35 +269,69 @@ CalGroundTruthWithPath(const std::pair<float*, int64_t*>& result,
return gt;
}

TestDataset::TestDataset(uint64_t dim, uint64_t count, std::string metric_str, bool with_path)
: dim_(dim), count_(count) {
this->base_ = GenerateRandomDataset(dim, count, metric_str);
TestDatasetPtr
TestDataset::CreateTestDataset(uint64_t dim,
uint64_t count,
std::string metric_str,
bool with_path) {
TestDatasetPtr dataset = std::shared_ptr<TestDataset>(new TestDataset);
dataset->dim_ = dim;
dataset->count_ = count;
dataset->base_ = GenerateRandomDataset(dim, count, metric_str);
constexpr uint64_t query_count = 100;
this->query_ = GenerateRandomDataset(dim, query_count, metric_str, true);
this->filter_query_ = query_;
this->range_query_ = query_;
dataset->query_ = GenerateRandomDataset(dim, query_count, metric_str, true);
dataset->filter_query_ = dataset->query_;
dataset->range_query_ = dataset->query_;
{
auto result = CalDistanceFloatMetrix(query_, base_, metric_str);
this->top_k = 10;
auto result = CalDistanceFloatMetrix(dataset->query_, dataset->base_, metric_str);
dataset->top_k = 10;

this->filter_function_ = [](int64_t id) -> bool { return id % 7 != 5; };
dataset->filter_function_ = [](int64_t id) -> bool { return id % 7 != 5; };
if (with_path) {
this->ground_truth_ = CalGroundTruthWithPath(result, top_k, base_, query_);
this->filter_ground_truth_ =
CalGroundTruthWithPath(result, top_k, base_, query_, this->filter_function_);
dataset->ground_truth_ =
CalGroundTruthWithPath(result, dataset->top_k, dataset->base_, dataset->query_);
dataset->filter_ground_truth_ = CalGroundTruthWithPath(
result, dataset->top_k, dataset->base_, dataset->query_, dataset->filter_function_);
} else {
this->ground_truth_ = CalTopKGroundTruth(result, top_k, count, query_count);
this->filter_ground_truth_ =
CalFilterGroundTruth(result, top_k, this->filter_function_, count, query_count);
dataset->ground_truth_ = CalTopKGroundTruth(result, dataset->top_k, count, query_count);
dataset->filter_ground_truth_ = CalFilterGroundTruth(
result, dataset->top_k, dataset->filter_function_, count, query_count);
}
this->range_ground_truth_ = this->ground_truth_;
this->range_radius_.resize(query_count);
dataset->range_ground_truth_ = dataset->ground_truth_;
dataset->range_radius_.resize(query_count);
for (uint64_t i = 0; i < query_count; ++i) {
this->range_radius_[i] =
0.5f * (result.first[i * count + top_k] + result.first[i * count + top_k - 1]);
dataset->range_radius_[i] = 0.5f * (result.first[i * count + dataset->top_k] +
result.first[i * count + dataset->top_k - 1]);
}
delete[] result.first;
delete[] result.second;
}
return dataset;
}

// Creates a fixed-size test dataset (dim 256, 1000 base vectors, 100 queries,
// top_k 10) whose base and query vectors both come from
// GenerateNanRandomDataset. Top-k ground truth is computed from the full
// distance matrix; the range radius of each query is the distance of its
// top_k-th ground-truth neighbour (which may be NaN for a dirty query - callers
// are expected to check).
TestDatasetPtr
TestDataset::CreateNanDataset(const std::string& metric_str) {
    constexpr uint64_t query_count = 100;

    // The default constructor is private, so std::make_shared cannot be used here.
    TestDatasetPtr dataset = std::shared_ptr<TestDataset>(new TestDataset);
    dataset->dim_ = 256;
    dataset->count_ = 1000;
    dataset->top_k = 10;
    dataset->base_ = GenerateNanRandomDataset(dataset->dim_, dataset->count_, metric_str);
    dataset->query_ = GenerateNanRandomDataset(dataset->dim_, query_count, metric_str);
    // Range ground truth and radii are filled in below, so expose the matching
    // queries as well (CreateTestDataset does the same) instead of leaving
    // range_query_ null.
    dataset->range_query_ = dataset->query_;

    auto result = CalDistanceFloatMetrix(dataset->query_, dataset->base_, metric_str);
    // Take ownership of the raw buffers so they are released on every exit path,
    // including an exception thrown while computing the ground truth.
    std::unique_ptr<float[]> distance_guard(result.first);
    std::unique_ptr<int64_t[]> id_guard(result.second);

    dataset->ground_truth_ =
        CalTopKGroundTruth(result, dataset->top_k, dataset->count_, query_count);
    dataset->range_ground_truth_ = dataset->ground_truth_;
    dataset->range_radius_.resize(query_count);
    for (uint64_t i = 0; i < query_count; ++i) {
        // Last (farthest) entry of query i's top_k ground-truth distances.
        dataset->range_radius_[i] =
            dataset->ground_truth_->GetDistances()[i * dataset->top_k + dataset->top_k - 1];
    }
    return dataset;
}

} // namespace fixtures
19 changes: 13 additions & 6 deletions tests/fixtures/test_dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,14 @@ class TestDataset {
public:
using DatasetPtr = vsag::DatasetPtr;

TestDataset(uint64_t dim,
uint64_t count,
std::string metric_str = "l2",
bool with_path = false);
static std::shared_ptr<TestDataset>
CreateTestDataset(uint64_t dim,
uint64_t count,
std::string metric_str = "l2",
bool with_path = false);

static std::shared_ptr<TestDataset>
CreateNanDataset(const std::string& metric_str);

DatasetPtr base_{nullptr};

Expand All @@ -44,8 +48,11 @@ class TestDataset {
DatasetPtr filter_ground_truth_{nullptr};
std::function<bool(int64_t)> filter_function_{nullptr};

const uint64_t dim_;
const uint64_t count_;
uint64_t dim_{0};
uint64_t count_{0};

private:
TestDataset() = default;
};

using TestDatasetPtr = std::shared_ptr<TestDataset>;
Expand Down
11 changes: 10 additions & 1 deletion tests/fixtures/test_dataset_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ TestDatasetPool::GetDatasetAndCreate(uint64_t dim,
auto key = key_gen(dim, count, metric_str, with_path);
if (this->pool_.find(key) == this->pool_.end()) {
this->dim_counts_.emplace_back(dim, count);
this->pool_[key] = std::make_shared<TestDataset>(dim, count, metric_str, with_path);
this->pool_[key] = TestDataset::CreateTestDataset(dim, count, metric_str, with_path);
}
return this->pool_.at(key);
}
Expand All @@ -25,4 +25,13 @@ TestDatasetPool::key_gen(int64_t dim,
return std::to_string(dim) + "_" + std::to_string(count) + "_" + metric_str + "_" +
std::to_string(with_path);
}

// Returns the cached NaN dataset for `metric_str`, creating and caching it on
// first use. Entries are keyed by NAN_DATASET followed by the metric name.
TestDatasetPtr
TestDatasetPool::GetNanDataset(const std::string& metric_str) {
    const auto cache_key = NAN_DATASET + metric_str;
    auto entry = this->pool_.find(cache_key);
    if (entry == this->pool_.end()) {
        // The dataset is fully built before anything is inserted, so a failed
        // creation leaves the pool untouched.
        auto created = TestDataset::CreateNanDataset(metric_str);
        entry = this->pool_.emplace(cache_key, created).first;
    }
    return entry->second;
}
} // namespace fixtures
6 changes: 6 additions & 0 deletions tests/fixtures/test_dataset_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
#include "test_dataset.h"

namespace fixtures {

// Key prefix under which TestDatasetPool::GetNanDataset caches its datasets;
// the metric name is appended to form the full pool key.
static const std::string NAN_DATASET = "nan_dataset";

class TestDatasetPool {
public:
TestDatasetPtr
Expand All @@ -29,6 +32,9 @@ class TestDatasetPool {
const std::string& metric_str = "l2",
bool with_path = false);

TestDatasetPtr
GetNanDataset(const std::string& metric_str);

private:
static std::string
key_gen(int64_t dim,
Expand Down
19 changes: 19 additions & 0 deletions tests/test_diskann.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,25 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::DiskANNTestIndex,
vsag::Options::Instance().set_block_size_limit(origin_size);
}

TEST_CASE_PERSISTENT_FIXTURE(fixtures::DiskANNTestIndex,
                             "DiskANN Search with Dirty Vector",
                             "[ft][diskann]") {
    // bug issue #360
    // Builds a DiskANN index on the NaN dataset for each metric and runs the
    // dirty-vector search checks against it.
    auto saved_block_size_limit = vsag::Options::Instance().block_size_limit();
    auto block_size_limit = GENERATE(1024 * 1024 * 2);
    auto metric = GENERATE("l2", "ip", "cosine");
    auto nan_dataset = pool.GetNanDataset(metric);
    auto vector_dim = nan_dataset->dim_;
    const std::string index_name = "diskann";

    vsag::Options::Instance().set_block_size_limit(block_size_limit);
    auto build_param = GenerateDiskANNBuildParametersString(metric, vector_dim);
    auto index = TestFactory(index_name, build_param, true);
    TestBuildIndex(index, nan_dataset, true);
    TestSearchWithDirtyVector(index, nan_dataset, search_param, true);
    // Put the global block size limit back to what it was before this test.
    vsag::Options::Instance().set_block_size_limit(saved_block_size_limit);
}

/* FIXME: segmentation fault on some platforms
TEST_CASE("DiskAnn OPQ", "[ft][diskann]") {
int dim = 128; // Dimension of the elements
Expand Down
20 changes: 20 additions & 0 deletions tests/test_hgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,26 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, "HGraph Add", "[ft][hgra
}
}

TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex,
                             "HGraph Search with Dirty Vector",
                             "[ft][hgraph]") {
    // For each metric and each base quantization in test_cases, builds an
    // HGraph index on the NaN dataset and runs the dirty-vector search checks.
    auto saved_block_size_limit = vsag::Options::Instance().block_size_limit();
    auto block_size_limit = GENERATE(1024 * 1024 * 2);
    auto metric = GENERATE("l2", "ip", "cosine");
    auto nan_dataset = pool.GetNanDataset(metric);
    auto vector_dim = nan_dataset->dim_;
    const std::string index_name = "hgraph";
    auto knn_search_param = fmt::format(search_param_tmp, 100);
    // The second element of each test case is not used by this test.
    for (auto& [quantization, unused_recall] : test_cases) {
        vsag::Options::Instance().set_block_size_limit(block_size_limit);
        auto build_param = GenerateHGraphBuildParametersString(metric, vector_dim, quantization);
        auto index = TestFactory(index_name, build_param, true);
        TestBuildIndex(index, nan_dataset, true);
        TestSearchWithDirtyVector(index, nan_dataset, knn_search_param, true);
    }
    // Put the global block size limit back to what it was before this test.
    vsag::Options::Instance().set_block_size_limit(saved_block_size_limit);
}

TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, "HGraph Concurrent Add", "[ft][hgraph]") {
auto origin_size = vsag::Options::Instance().block_size_limit();
auto size = GENERATE(1024 * 1024 * 2);
Expand Down
19 changes: 19 additions & 0 deletions tests/test_hnsw_new.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,25 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex,
vsag::Options::Instance().set_block_size_limit(origin_size);
}

TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex,
                             "HNSW Search with Dirty Vector",
                             "[ft][hnsw]") {
    // Builds an HNSW index on the NaN dataset for each metric and runs the
    // dirty-vector search checks against it.
    auto saved_block_size_limit = vsag::Options::Instance().block_size_limit();
    auto block_size_limit = GENERATE(1024 * 1024 * 2);
    auto metric = GENERATE("l2", "ip", "cosine");
    auto nan_dataset = pool.GetNanDataset(metric);
    auto vector_dim = nan_dataset->dim_;
    const std::string index_name = "hnsw";
    auto knn_search_param = fmt::format(search_param_tmp, 100);

    vsag::Options::Instance().set_block_size_limit(block_size_limit);
    auto build_param = GenerateHNSWBuildParametersString(metric, vector_dim);
    auto index = TestFactory(index_name, build_param, true);
    TestBuildIndex(index, nan_dataset, true);
    TestSearchWithDirtyVector(index, nan_dataset, knn_search_param, true);
    // Put the global block size limit back to what it was before this test.
    vsag::Options::Instance().set_block_size_limit(saved_block_size_limit);
}

TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Build", "[ft][hnsw]") {
auto origin_size = vsag::Options::Instance().block_size_limit();
auto size = GENERATE(1024 * 1024 * 2);
Expand Down
52 changes: 52 additions & 0 deletions tests/test_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,58 @@ TestIndex::TestSerializeFile(const IndexPtr& index_from,
}
}
}
// Exercises knn and range search with the queries of a dataset that may contain
// dirty (NaN / all-zero) vectors. Only success or failure is checked, not recall.
//
// The queries are processed one at a time in three passes:
//   1. knn search on the first 90% of the queries: has_value() must equal
//      `expected_success`, and on success the result's GetDim() must equal top_k;
//   2. range search on the same queries, using dataset->range_radius_ and
//      skipping queries whose radius is NaN;
//   3. knn search on the remaining queries, checking has_value() only.
//
// @param index            index under test
// @param dataset          dataset providing query_, top_k and range_radius_
// @param search_param     search parameter string forwarded to the index
// @param expected_success whether every search is expected to succeed; when
//                         false the function returns after the first knn
//                         search has been verified to fail
void
TestIndex::TestSearchWithDirtyVector(const TestIndex::IndexPtr& index,
                                     const TestDatasetPtr& dataset,
                                     const std::string& search_param,
                                     bool expected_success) {
    auto queries = dataset->query_;
    auto query_count = queries->GetNumElements();
    auto dim = queries->GetDim();
    auto topk = dataset->top_k;

    // Index with the same type as the query count to avoid signed/unsigned
    // comparisons in the loops below.
    using QueryIndex = decltype(query_count);
    const auto valid_query_count = static_cast<QueryIndex>(query_count * 0.9);

    // Wraps the i-th query vector in a non-owning single-element dataset.
    auto make_query = [&](QueryIndex i) {
        auto query = vsag::Dataset::Make();
        query->NumElements(1)
            ->Dim(dim)
            ->Float32Vectors(queries->GetFloat32Vectors() + i * dim)
            ->Owner(false);
        return query;
    };

    // Pass 1: knn search on the first 90% of the queries.
    for (QueryIndex i = 0; i < valid_query_count; ++i) {
        auto query = make_query(i);
        auto res = index->KnnSearch(query, topk, search_param);
        REQUIRE(res.has_value() == expected_success);
        if (!expected_success) {
            // One verified failure is enough; the remaining passes are skipped.
            return;
        }
        REQUIRE(res.value()->GetDim() == topk);
    }

    // Pass 2: range search on the same queries.
    const auto& radius = dataset->range_radius_;
    for (QueryIndex i = 0; i < valid_query_count; ++i) {
        if (std::isnan(radius[i])) {
            // No usable radius for this query.
            continue;
        }
        auto query = make_query(i);
        auto res = index->RangeSearch(query, radius[i], search_param);
        REQUIRE(res.has_value() == expected_success);
    }

    // Pass 3: knn search on the remaining queries; only success is checked.
    for (QueryIndex i = valid_query_count; i < query_count; ++i) {
        auto query = make_query(i);
        auto res = index->KnnSearch(query, topk, search_param);
        REQUIRE(res.has_value() == expected_success);
    }
}

void
TestIndex::TestSerializeBinarySet(const IndexPtr& index_from,
Expand Down
6 changes: 6 additions & 0 deletions tests/test_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ class TestIndex {
float expected_recall = 0.99,
bool expected_success = true);

static void
TestSearchWithDirtyVector(const IndexPtr& index,
const TestDatasetPtr& dataset,
const std::string& search_param,
bool expected_success = true);

static void
TestRangeSearch(const IndexPtr& index,
const TestDatasetPtr& dataset,
Expand Down

0 comments on commit 83fbed5

Please sign in to comment.