Skip to content

Commit

Permalink
introduce bucket datacell parameter
Browse files Browse the repository at this point in the history
Signed-off-by: LHT129 <[email protected]>
  • Loading branch information
LHT129 committed Jan 22, 2025
1 parent 2b9ddac commit a35fcde
Show file tree
Hide file tree
Showing 5 changed files with 204 additions and 2 deletions.
50 changes: 50 additions & 0 deletions src/data_cell/bucket_datacell_parameter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@

// Copyright 2024-present the vsag project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "bucket_datacell_parameter.h"

#include <fmt/format-inl.h>

#include "inner_string_params.h"

namespace vsag {
BucketDataCellParameter::BucketDataCellParameter() = default;

void
BucketDataCellParameter::FromJson(const JsonType& json) {
CHECK_ARGUMENT(json.contains(IO_PARAMS_KEY),
fmt::format("bucket interface parameters must contains {}", IO_PARAMS_KEY));
this->io_parameter_ = IOParameter::GetIOParameterByJson(json[IO_PARAMS_KEY]);

CHECK_ARGUMENT(
json.contains(QUANTIZATION_PARAMS_KEY),
fmt::format("bucket interface parameters must contains {}", QUANTIZATION_PARAMS_KEY));
this->quantizer_parameter_ =
QuantizerParameter::GetQuantizerParameterByJson(json[QUANTIZATION_PARAMS_KEY]);

if (json.contains(BUCKETS_COUNT_KEY)) {
this->buckets_count_ = json[BUCKETS_COUNT_KEY];
}
}

JsonType
BucketDataCellParameter::ToJson() {
JsonType json;
json[IO_PARAMS_KEY] = this->io_parameter_->ToJson();
json[QUANTIZATION_PARAMS_KEY] = this->quantizer_parameter_->ToJson();
json[BUCKETS_COUNT_KEY] = this->buckets_count_;
return json;
}

Check warning on line 49 in src/data_cell/bucket_datacell_parameter.cpp

View check run for this annotation

Codecov / codecov/patch

src/data_cell/bucket_datacell_parameter.cpp#L49

Added line #L49 was not covered by tests
} // namespace vsag
44 changes: 44 additions & 0 deletions src/data_cell/bucket_datacell_parameter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@

// Copyright 2024-present the vsag project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "io/io_parameter.h"
#include "parameter.h"
#include "quantization/quantizer_parameter.h"

namespace vsag {

class BucketDataCellParameter : public Parameter {
public:
explicit BucketDataCellParameter();

void
FromJson(const JsonType& json) override;

JsonType
ToJson() override;

public:
QuantizerParamPtr quantizer_parameter_{nullptr};

IOParamPtr io_parameter_{nullptr};

int64_t buckets_count_{1};
};

using BucketDataCellParamPtr = std::shared_ptr<BucketDataCellParameter>;

} // namespace vsag
106 changes: 106 additions & 0 deletions src/data_cell/bucket_datacell_parameter_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@

// Copyright 2024-present the vsag project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "bucket_datacell_parameter.h"

#include <catch2/catch_test_macros.hpp>

#include "parameter_test.h"

using namespace vsag;

TEST_CASE("BucketDataCellParameter ToJson Test", "[ut][BucketDataCellParameter]") {
std::string param_str = R"(
{
"io_params": {
"type": "memory_io"
},
"quantization_params": {
"type": "sq8"
},
"buckets_count": 10
})";
auto param = std::make_shared<BucketDataCellParameter>();
auto json = JsonType::parse(param_str);
param->FromJson(json);
REQUIRE(param->buckets_count_ == 10);
ParameterTest::TestToJson(param);
}

TEST_CASE("BucketDataCellParameter Parse Exception", "[ut][BucketDataCellParameter]") {
auto check_param = [](const std::string& str) -> BucketDataCellParamPtr {
auto param = std::make_shared<BucketDataCellParameter>();
auto json = JsonType::parse(str);
param->FromJson(json);
return param;
};

SECTION("miss io param") {
std::string param_str = R"(
{
"quantization_params": {
"type": "sq8",
},
"buckets_count": 10
})";
REQUIRE_THROWS(check_param(param_str));
}

SECTION("miss quantization param") {
std::string param_str = R"(
{
"io_params": {
"type": "memory_io"
},
"buckets_count": 10
})";
REQUIRE_THROWS(check_param(param_str));
}

SECTION("wrong io param type") {
std::string param_str = R"(
{
"io_params": {
"type": "wrong_io"
},
"buckets_count": 10
})";
REQUIRE_THROWS(check_param(param_str));
}

SECTION("wrong quantization param type") {
std::string param_str = R"(
{
"quantization_params": {
"type": "wrong_quantization",
},
"buckets_count": 10
})";
REQUIRE_THROWS(check_param(param_str));
}

SECTION("valid on missing buckets_count") {
std::string param_str = R"(
{
"io_params": {
"type": "memory_io"
},
"quantization_params": {
"type": "sq8"
}
})";
auto param = check_param(param_str);
}
}
3 changes: 1 addition & 2 deletions src/data_cell/flatten_datacell_parameter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
#include "inner_string_params.h"

namespace vsag {
FlattenDataCellParameter::FlattenDataCellParameter() {
}
FlattenDataCellParameter::FlattenDataCellParameter() = default;

void
FlattenDataCellParameter::FromJson(const JsonType& json) {
Expand Down
3 changes: 3 additions & 0 deletions src/inner_string_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ const char* const BUILD_PARAMS_KEY = "build_params";
const char* const BUILD_THREAD_COUNT = "build_thread_count";
const char* const BUILD_EF_CONSTRUCTION = "ef_construction";

const char* const BUCKETS_COUNT_KEY = "buckets_count";

const std::unordered_map<std::string, std::string> DEFAULT_MAP = {
{"INDEX_TYPE_HGRAPH", INDEX_TYPE_HGRAPH},
{"HGRAPH_USE_REORDER_KEY", HGRAPH_USE_REORDER_KEY},
Expand All @@ -76,6 +78,7 @@ const std::unordered_map<std::string, std::string> DEFAULT_MAP = {
{"BUILD_PARAMS_KEY", BUILD_PARAMS_KEY},
{"BUILD_THREAD_COUNT", BUILD_THREAD_COUNT},
{"BUILD_EF_CONSTRUCTION", BUILD_EF_CONSTRUCTION},
{"BUCKETS_COUNT_KEY", BUCKETS_COUNT_KEY},
};

} // namespace vsag

0 comments on commit a35fcde

Please sign in to comment.