diff --git a/src/buffer_wrapper.h b/src/buffer_wrapper.h
new file mode 100644
index 00000000..af10fce9
--- /dev/null
+++ b/src/buffer_wrapper.h
@@ -0,0 +1,54 @@
+
+// Copyright 2024-present the vsag project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <cstdint>
+
+#include "vsag/allocator.h"
+
+namespace vsag {
+/**
+ * Scoped holder for a raw byte buffer obtained from an Allocator.
+ *
+ * The constructor requests `size` from `allocator->Allocate` and the
+ * destructor passes the result back to `allocator->Deallocate`. The
+ * allocator pointer is only borrowed: it is dereferenced again in the
+ * destructor, so it has to stay valid for the wrapper's whole lifetime.
+ */
+class BufferWrapper {
+public:
+    BufferWrapper(uint64_t size, Allocator* allocator) : allocator(allocator) {
+        data = static_cast<uint8_t*>(allocator->Allocate(size));
+    }
+
+    ~BufferWrapper() {
+        allocator->Deallocate(data);
+    }
+
+    // Copying would make two wrappers hand the same pointer to Deallocate,
+    // so copy operations are deleted (which also suppresses implicit moves).
+    BufferWrapper(const BufferWrapper&) = delete;
+    BufferWrapper&
+    operator=(const BufferWrapper&) = delete;
+
+public:
+    // Start of the buffer returned by Allocate; released in the destructor.
+    uint8_t* data{nullptr};
+
+    // Allocator that produced `data`; not owned by this object.
+    Allocator* allocator{nullptr};
+};
+}  // namespace vsag
diff --git a/src/data_cell/flatten_datacell.h b/src/data_cell/flatten_datacell.h
index c8e2437e..8164b614 100644
--- a/src/data_cell/flatten_datacell.h
+++ b/src/data_cell/flatten_datacell.h
@@ -14,13 +14,16 @@
 // limitations under the License.
#pragma once + #include #include #include +#include "buffer_wrapper.h" #include "flatten_interface.h" #include "io/basic_io.h" #include "quantization/quantizer.h" + namespace vsag { /* * thread unsafe @@ -160,22 +163,6 @@ FlattenDataCell::InsertVector(const float* vector, InnerIdTyp allocator_->Deallocate(codes); } -struct BufferWrapper { -public: - BufferWrapper(uint64_t size, Allocator* allocator) : allocator(allocator) { - data = static_cast(allocator->Allocate(size)); - } - - ~BufferWrapper() { - allocator->Deallocate(data); - } - -public: - uint8_t* data; - - Allocator* allocator; -}; - template void FlattenDataCell::BatchInsertVector(const float* vectors, diff --git a/src/io/basic_io.h b/src/io/basic_io.h index 0f051892..83d4bfaa 100644 --- a/src/io/basic_io.h +++ b/src/io/basic_io.h @@ -19,6 +19,7 @@ #include +#include "buffer_wrapper.h" #include "io_parameter.h" #include "stream_reader.h" #include "stream_writer.h" @@ -43,7 +44,7 @@ namespace vsag { template class BasicIO { public: - BasicIO() = default; + explicit BasicIO(Allocator* allocator) : allocator_(allocator){}; virtual ~BasicIO() = default; @@ -100,21 +101,28 @@ class BasicIO { inline void Serialize(StreamWriter& writer) { - if constexpr (has_SerializeImpl::value) { - cast().SerializeImpl(writer); - } else { - throw std::runtime_error( - fmt::format("class {} have no func named SerializeImpl", typeid(IOTmpl).name())); + StreamWriter::WriteObj(writer, this->size_); + BufferWrapper buffer(BUFFER_SIZE, this->allocator_); + uint64_t offset = 0; + while (offset < this->size_) { + auto cur_size = std::min(BUFFER_SIZE, this->size_ - offset); + this->Read(cur_size, offset, buffer.data); + writer.Write(reinterpret_cast(buffer.data), cur_size); + offset += cur_size; } } inline void Deserialize(StreamReader& reader) { - if constexpr (has_DeserializeImpl::value) { - cast().DeserializeImpl(reader); - } else { - throw std::runtime_error( - fmt::format("class {} have no func named DeserializeImpl", 
typeid(IOTmpl).name())); + uint64_t size; + StreamReader::ReadObj(reader, size); + BufferWrapper buffer(BUFFER_SIZE, this->allocator_); + uint64_t offset = 0; + while (offset < size) { + auto cur_size = std::min(BUFFER_SIZE, size - offset); + reader.Read(reinterpret_cast(buffer.data), cur_size); + this->Write(buffer.data, cur_size, offset); + offset += cur_size; } } @@ -128,6 +136,12 @@ class BasicIO { } } +public: + uint64_t size_{0}; + +protected: + Allocator* const allocator_{nullptr}; + private: inline IOTmpl& cast() { @@ -139,14 +153,14 @@ class BasicIO { return static_cast(*this); } + constexpr static uint64_t BUFFER_SIZE = 1024 * 1024 * 2; + private: GENERATE_HAS_MEMBER_FUNC(WriteImpl, void (U::*)(const uint8_t*, uint64_t, uint64_t)) GENERATE_HAS_MEMBER_FUNC(ReadImpl, bool (U::*)(uint64_t, uint64_t, uint8_t*)) GENERATE_HAS_MEMBER_FUNC(DirectReadImpl, const uint8_t* (U::*)(uint64_t, uint64_t, bool&)) GENERATE_HAS_MEMBER_FUNC(MultiReadImpl, bool (U::*)(uint8_t*, uint64_t*, uint64_t*, uint64_t)) GENERATE_HAS_MEMBER_FUNC(PrefetchImpl, void (U::*)(uint64_t, uint64_t)) - GENERATE_HAS_MEMBER_FUNC(SerializeImpl, void (U::*)(StreamWriter&)) - GENERATE_HAS_MEMBER_FUNC(DeserializeImpl, void (U::*)(StreamReader&)) GENERATE_HAS_MEMBER_FUNC(ReleaseImpl, void (U::*)(const uint8_t*)) }; } // namespace vsag diff --git a/src/io/basic_io_test.cpp b/src/io/basic_io_test.cpp index 848d1f43..8bb57f40 100644 --- a/src/io/basic_io_test.cpp +++ b/src/io/basic_io_test.cpp @@ -20,28 +20,22 @@ #include #include "fixtures.h" +#include "safe_allocator.h" -class WrongIO : public vsag::BasicIO {}; +class WrongIO : public vsag::BasicIO { +public: + WrongIO(vsag::Allocator* allocator) : vsag::BasicIO(allocator){}; +}; TEST_CASE("wrong io", "[ut][basic io]") { - auto io = std::make_shared(); + auto allocator = vsag::SafeAllocator::FactoryDefaultAllocator(); + auto io = std::make_shared(allocator.get()); std::vector data(100); bool release; - fixtures::TempDir dirname("TestWrongIO"); - 
std::string filename = dirname.GenerateRandomFile(); - std::ofstream outfile(filename.c_str(), std::ios::binary); - IOStreamWriter writer(outfile); - REQUIRE_THROWS(io->Serialize(writer)); - outfile.close(); - std::ifstream infile(filename.c_str(), std::ios::binary); - IOStreamReader reader(infile); REQUIRE_THROWS(io->Read(1, 0, data.data())); REQUIRE_THROWS(io->Write(data.data(), 1, 0)); REQUIRE_THROWS(io->Read(1, 0, release)); - REQUIRE_THROWS(io->Deserialize(reader)); REQUIRE_THROWS(io->Prefetch(1, 0)); REQUIRE_THROWS(io->MultiRead(data.data(), nullptr, nullptr, 1)); - - infile.close(); } diff --git a/src/io/memory_block_io.h b/src/io/memory_block_io.h index c46fd9ca..8752980e 100644 --- a/src/io/memory_block_io.h +++ b/src/io/memory_block_io.h @@ -35,8 +35,8 @@ namespace vsag { class MemoryBlockIO : public BasicIO { public: explicit MemoryBlockIO(Allocator* allocator, uint64_t block_size) - : block_size_(MemoryBlockIOParameter::NearestPowerOfTwo(block_size)), - allocator_(allocator), + : BasicIO(allocator), + block_size_(MemoryBlockIOParameter::NearestPowerOfTwo(block_size)), blocks_(0, allocator) { this->update_by_block_size(); } @@ -49,7 +49,7 @@ class MemoryBlockIO : public BasicIO { ~MemoryBlockIO() override { for (auto* block : blocks_) { - allocator_->Deallocate(block); + this->allocator_->Deallocate(block); } } @@ -65,7 +65,7 @@ class MemoryBlockIO : public BasicIO { inline void ReleaseImpl(const uint8_t* data) const { auto ptr = const_cast(data); - allocator_->Deallocate(ptr); + this->allocator_->Deallocate(ptr); }; inline bool @@ -74,16 +74,10 @@ class MemoryBlockIO : public BasicIO { inline void PrefetchImpl(uint64_t offset, uint64_t cache_line = 64); - inline void - SerializeImpl(StreamWriter& writer); - - inline void - DeserializeImpl(StreamReader& reader); - private: [[nodiscard]] inline bool check_valid_offset(uint64_t size) const { - return size <= (blocks_.size() << block_bit_); + return size <= this->size_; } inline void @@ -112,8 +106,6 @@ 
class MemoryBlockIO : public BasicIO { Vector blocks_; - Allocator* const allocator_{nullptr}; - static constexpr uint64_t DEFAULT_BLOCK_SIZE = 128 * 1024 * 1024; // 128MB static constexpr uint64_t DEFAULT_BLOCK_BIT = 27; @@ -139,6 +131,9 @@ MemoryBlockIO::WriteImpl(const uint8_t* data, uint64_t size, uint64_t offset) { ++start_no; start_off = 0; } + if (size + offset > this->size_) { + this->size_ = size + offset; + } } bool @@ -170,7 +165,7 @@ MemoryBlockIO::DirectReadImpl(uint64_t size, uint64_t offset, bool& need_release return this->get_data_ptr(offset); } else { need_release = true; - auto* ptr = reinterpret_cast(allocator_->Allocate(size)); + auto* ptr = reinterpret_cast(this->allocator_->Allocate(size)); this->ReadImpl(size, offset, ptr); return ptr; } @@ -196,40 +191,16 @@ MemoryBlockIO::PrefetchImpl(uint64_t offset, uint64_t cache_line) { void MemoryBlockIO::check_and_realloc(uint64_t size) { - if (check_valid_offset(size)) { + if (size <= (blocks_.size() << block_bit_)) { return; } const uint64_t new_block_count = (size + this->block_size_ - 1) >> block_bit_; auto cur_block_size = this->blocks_.size(); this->blocks_.reserve(new_block_count); while (cur_block_size < new_block_count) { - this->blocks_.emplace_back((uint8_t*)(allocator_->Allocate(block_size_))); + this->blocks_.emplace_back((uint8_t*)(this->allocator_->Allocate(block_size_))); ++cur_block_size; } } -void -MemoryBlockIO::SerializeImpl(StreamWriter& writer) { - StreamWriter::WriteObj(writer, this->block_size_); - uint64_t block_count = this->blocks_.size(); - StreamWriter::WriteObj(writer, block_count); - for (uint64_t i = 0; i < block_count; ++i) { - writer.Write(reinterpret_cast(this->blocks_[i]), block_size_); - } -} - -void -MemoryBlockIO::DeserializeImpl(StreamReader& reader) { - for (auto* block : blocks_) { - allocator_->Deallocate(block); - } - StreamReader::ReadObj(reader, this->block_size_); - uint64_t block_count; - StreamReader::ReadObj(reader, block_count); - 
this->blocks_.resize(block_count); - for (uint64_t i = 0; i < block_count; ++i) { - blocks_[i] = static_cast(allocator_->Allocate(this->block_size_)); - reader.Read(reinterpret_cast(blocks_[i]), block_size_); - } -} } // namespace vsag diff --git a/src/io/memory_block_io_test.cpp b/src/io/memory_block_io_test.cpp index 2b962f7f..c946bc4e 100644 --- a/src/io/memory_block_io_test.cpp +++ b/src/io/memory_block_io_test.cpp @@ -19,7 +19,6 @@ #include #include "basic_io_test.h" -#include "default_allocator.h" #include "safe_allocator.h" using namespace vsag; diff --git a/src/io/memory_io.h b/src/io/memory_io.h index 24ef1313..4ae9c998 100644 --- a/src/io/memory_io.h +++ b/src/io/memory_io.h @@ -28,9 +28,8 @@ namespace vsag { class MemoryIO : public BasicIO { public: - explicit MemoryIO(Allocator* allocator) : allocator_(allocator) { - start_ = reinterpret_cast(allocator_->Allocate(MIN_SIZE)); - current_size_ = MIN_SIZE; + explicit MemoryIO(Allocator* allocator) : BasicIO(allocator) { + start_ = static_cast(allocator->Allocate(1)); } explicit MemoryIO(const MemoryIOParamPtr& param, const IndexCommonParam& common_param) @@ -42,7 +41,7 @@ class MemoryIO : public BasicIO { } ~MemoryIO() override { - allocator_->Deallocate(start_); + this->allocator_->Deallocate(start_); } inline void @@ -63,16 +62,10 @@ class MemoryIO : public BasicIO { inline void PrefetchImpl(uint64_t offset, uint64_t cache_line = 64); - inline void - SerializeImpl(StreamWriter& writer); - - inline void - DeserializeImpl(StreamReader& reader); - private: [[nodiscard]] inline bool check_valid_offset(uint64_t size) const { - return size <= current_size_; + return size <= this->size_; } void @@ -80,15 +73,12 @@ class MemoryIO : public BasicIO { if (check_valid_offset(size)) { return; } - start_ = reinterpret_cast(allocator_->Reallocate(start_, size)); - current_size_ = size; + start_ = reinterpret_cast(this->allocator_->Reallocate(start_, size)); + this->size_ = size; } private: - Allocator* const 
allocator_{nullptr}; uint8_t* start_{nullptr}; - uint64_t current_size_{0}; - static const uint64_t MIN_SIZE = 1024; }; void @@ -127,18 +117,4 @@ void MemoryIO::PrefetchImpl(uint64_t offset, uint64_t cache_line) { PrefetchLines(this->start_ + offset, cache_line); } -void -MemoryIO::SerializeImpl(StreamWriter& writer) { - StreamWriter::WriteObj(writer, this->current_size_); - writer.Write(reinterpret_cast(this->start_), current_size_); -} - -void -MemoryIO::DeserializeImpl(StreamReader& reader) { - allocator_->Deallocate(this->start_); - StreamReader::ReadObj(reader, this->current_size_); - this->start_ = static_cast(allocator_->Allocate(this->current_size_)); - reader.Read(reinterpret_cast(this->start_), current_size_); -} - } // namespace vsag diff --git a/src/io/memory_io_test.cpp b/src/io/memory_io_test.cpp index cb9b7ed2..a06f23df 100644 --- a/src/io/memory_io_test.cpp +++ b/src/io/memory_io_test.cpp @@ -19,7 +19,6 @@ #include #include "basic_io_test.h" -#include "default_allocator.h" #include "safe_allocator.h" using namespace vsag;