Skip to content

Commit

Permalink
unified all io serialize method
Browse files Browse the repository at this point in the history
Signed-off-by: LHT129 <[email protected]>
  • Loading branch information
LHT129 committed Jan 22, 2025
1 parent eeb0295 commit 05b8506
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 114 deletions.
36 changes: 36 additions & 0 deletions src/buffer_wrapper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@

// Copyright 2024-present the vsag project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "vsag/allocator.h"

namespace vsag {
class BufferWrapper {
public:
BufferWrapper(uint64_t size, Allocator* allocator) : allocator(allocator) {
data = static_cast<uint8_t*>(allocator->Allocate(size));
}

~BufferWrapper() {
allocator->Deallocate(data);
}

public:
uint8_t* data;

Allocator* allocator;
};
} // namespace vsag
19 changes: 3 additions & 16 deletions src/data_cell/flatten_datacell.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@
// limitations under the License.

#pragma once

#include <algorithm>
#include <limits>
#include <memory>

#include "buffer_wrapper.h"
#include "flatten_interface.h"
#include "io/basic_io.h"
#include "quantization/quantizer.h"

namespace vsag {
/*
* thread unsafe
Expand Down Expand Up @@ -160,22 +163,6 @@ FlattenDataCell<QuantTmpl, IOTmpl>::InsertVector(const float* vector, InnerIdTyp
allocator_->Deallocate(codes);
}

struct BufferWrapper {
public:
BufferWrapper(uint64_t size, Allocator* allocator) : allocator(allocator) {
data = static_cast<uint8_t*>(allocator->Allocate(size));
}

~BufferWrapper() {
allocator->Deallocate(data);
}

public:
uint8_t* data;

Allocator* allocator;
};

template <typename QuantTmpl, typename IOTmpl>
void
FlattenDataCell<QuantTmpl, IOTmpl>::BatchInsertVector(const float* vectors,
Expand Down
40 changes: 27 additions & 13 deletions src/io/basic_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <cstdint>

#include "buffer_wrapper.h"
#include "io_parameter.h"
#include "stream_reader.h"
#include "stream_writer.h"
Expand All @@ -43,7 +44,7 @@ namespace vsag {
template <typename IOTmpl>
class BasicIO {
public:
BasicIO<IOTmpl>() = default;
explicit BasicIO<IOTmpl>(Allocator* allocator) : allocator_(allocator){};

virtual ~BasicIO() = default;

Expand Down Expand Up @@ -100,21 +101,28 @@ class BasicIO {

inline void
Serialize(StreamWriter& writer) {
if constexpr (has_SerializeImpl<IOTmpl>::value) {
cast().SerializeImpl(writer);
} else {
throw std::runtime_error(
fmt::format("class {} have no func named SerializeImpl", typeid(IOTmpl).name()));
StreamWriter::WriteObj(writer, this->size_);
BufferWrapper buffer(BUFFER_SIZE, this->allocator_);
uint64_t offset = 0;
while (offset < this->size_) {
auto cur_size = std::min(BUFFER_SIZE, this->size_ - offset);
this->Read(cur_size, offset, buffer.data);
writer.Write(reinterpret_cast<const char*>(buffer.data), cur_size);
offset += cur_size;
}
}

inline void
Deserialize(StreamReader& reader) {
if constexpr (has_DeserializeImpl<IOTmpl>::value) {
cast().DeserializeImpl(reader);
} else {
throw std::runtime_error(
fmt::format("class {} have no func named DeserializeImpl", typeid(IOTmpl).name()));
uint64_t size;
StreamReader::ReadObj(reader, size);
BufferWrapper buffer(BUFFER_SIZE, this->allocator_);
uint64_t offset = 0;
while (offset < size) {
auto cur_size = std::min(BUFFER_SIZE, size - offset);
reader.Read(reinterpret_cast<char*>(buffer.data), cur_size);
this->Write(buffer.data, cur_size, offset);
offset += cur_size;
}
}

Expand All @@ -128,6 +136,12 @@ class BasicIO {
}
}

public:
uint64_t size_{0};

protected:
Allocator* const allocator_{nullptr};

private:
inline IOTmpl&
cast() {
Expand All @@ -139,14 +153,14 @@ class BasicIO {
return static_cast<const IOTmpl&>(*this);
}

constexpr static uint64_t BUFFER_SIZE = 1024 * 1024 * 2;

private:
GENERATE_HAS_MEMBER_FUNC(WriteImpl, void (U::*)(const uint8_t*, uint64_t, uint64_t))
GENERATE_HAS_MEMBER_FUNC(ReadImpl, bool (U::*)(uint64_t, uint64_t, uint8_t*))
GENERATE_HAS_MEMBER_FUNC(DirectReadImpl, const uint8_t* (U::*)(uint64_t, uint64_t, bool&))
GENERATE_HAS_MEMBER_FUNC(MultiReadImpl, bool (U::*)(uint8_t*, uint64_t*, uint64_t*, uint64_t))
GENERATE_HAS_MEMBER_FUNC(PrefetchImpl, void (U::*)(uint64_t, uint64_t))
GENERATE_HAS_MEMBER_FUNC(SerializeImpl, void (U::*)(StreamWriter&))
GENERATE_HAS_MEMBER_FUNC(DeserializeImpl, void (U::*)(StreamReader&))
GENERATE_HAS_MEMBER_FUNC(ReleaseImpl, void (U::*)(const uint8_t*))
};
} // namespace vsag
20 changes: 7 additions & 13 deletions src/io/basic_io_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,22 @@
#include <memory>

#include "fixtures.h"
#include "safe_allocator.h"

class WrongIO : public vsag::BasicIO<WrongIO> {};
class WrongIO : public vsag::BasicIO<WrongIO> {
public:
WrongIO(vsag::Allocator* allocator) : vsag::BasicIO<WrongIO>(allocator){};
};

TEST_CASE("wrong io", "[ut][basic io]") {
auto io = std::make_shared<WrongIO>();
auto allocator = vsag::SafeAllocator::FactoryDefaultAllocator();
auto io = std::make_shared<WrongIO>(allocator.get());
std::vector<uint8_t> data(100);
bool release;
fixtures::TempDir dirname("TestWrongIO");
std::string filename = dirname.GenerateRandomFile();
std::ofstream outfile(filename.c_str(), std::ios::binary);
IOStreamWriter writer(outfile);
REQUIRE_THROWS(io->Serialize(writer));
outfile.close();

std::ifstream infile(filename.c_str(), std::ios::binary);
IOStreamReader reader(infile);
REQUIRE_THROWS(io->Read(1, 0, data.data()));
REQUIRE_THROWS(io->Write(data.data(), 1, 0));
REQUIRE_THROWS(io->Read(1, 0, release));
REQUIRE_THROWS(io->Deserialize(reader));
REQUIRE_THROWS(io->Prefetch(1, 0));
REQUIRE_THROWS(io->MultiRead(data.data(), nullptr, nullptr, 1));

infile.close();
}
51 changes: 11 additions & 40 deletions src/io/memory_block_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ namespace vsag {
class MemoryBlockIO : public BasicIO<MemoryBlockIO> {
public:
explicit MemoryBlockIO(Allocator* allocator, uint64_t block_size)
: block_size_(MemoryBlockIOParameter::NearestPowerOfTwo(block_size)),
allocator_(allocator),
: BasicIO<MemoryBlockIO>(allocator),
block_size_(MemoryBlockIOParameter::NearestPowerOfTwo(block_size)),
blocks_(0, allocator) {
this->update_by_block_size();
}
Expand All @@ -49,7 +49,7 @@ class MemoryBlockIO : public BasicIO<MemoryBlockIO> {

~MemoryBlockIO() override {
for (auto* block : blocks_) {
allocator_->Deallocate(block);
this->allocator_->Deallocate(block);
}
}

Expand All @@ -65,7 +65,7 @@ class MemoryBlockIO : public BasicIO<MemoryBlockIO> {
inline void
ReleaseImpl(const uint8_t* data) const {
auto ptr = const_cast<uint8_t*>(data);
allocator_->Deallocate(ptr);
this->allocator_->Deallocate(ptr);
};

inline bool
Expand All @@ -74,16 +74,10 @@ class MemoryBlockIO : public BasicIO<MemoryBlockIO> {
inline void
PrefetchImpl(uint64_t offset, uint64_t cache_line = 64);

inline void
SerializeImpl(StreamWriter& writer);

inline void
DeserializeImpl(StreamReader& reader);

private:
[[nodiscard]] inline bool
check_valid_offset(uint64_t size) const {
return size <= (blocks_.size() << block_bit_);
return size <= this->size_;
}

inline void
Expand Down Expand Up @@ -112,8 +106,6 @@ class MemoryBlockIO : public BasicIO<MemoryBlockIO> {

Vector<uint8_t*> blocks_;

Allocator* const allocator_{nullptr};

static constexpr uint64_t DEFAULT_BLOCK_SIZE = 128 * 1024 * 1024; // 128MB

static constexpr uint64_t DEFAULT_BLOCK_BIT = 27;
Expand All @@ -139,6 +131,9 @@ MemoryBlockIO::WriteImpl(const uint8_t* data, uint64_t size, uint64_t offset) {
++start_no;
start_off = 0;
}
if (size + offset > this->size_) {
this->size_ = size + offset;
}
}

bool
Expand Down Expand Up @@ -170,7 +165,7 @@ MemoryBlockIO::DirectReadImpl(uint64_t size, uint64_t offset, bool& need_release
return this->get_data_ptr(offset);
} else {
need_release = true;
auto* ptr = reinterpret_cast<uint8_t*>(allocator_->Allocate(size));
auto* ptr = reinterpret_cast<uint8_t*>(this->allocator_->Allocate(size));
this->ReadImpl(size, offset, ptr);
return ptr;
}
Expand All @@ -196,40 +191,16 @@ MemoryBlockIO::PrefetchImpl(uint64_t offset, uint64_t cache_line) {

void
MemoryBlockIO::check_and_realloc(uint64_t size) {
if (check_valid_offset(size)) {
if (size <= (blocks_.size() << block_bit_)) {
return;
}
const uint64_t new_block_count = (size + this->block_size_ - 1) >> block_bit_;
auto cur_block_size = this->blocks_.size();
this->blocks_.reserve(new_block_count);
while (cur_block_size < new_block_count) {
this->blocks_.emplace_back((uint8_t*)(allocator_->Allocate(block_size_)));
this->blocks_.emplace_back((uint8_t*)(this->allocator_->Allocate(block_size_)));
++cur_block_size;
}
}
void
MemoryBlockIO::SerializeImpl(StreamWriter& writer) {
StreamWriter::WriteObj(writer, this->block_size_);
uint64_t block_count = this->blocks_.size();
StreamWriter::WriteObj(writer, block_count);
for (uint64_t i = 0; i < block_count; ++i) {
writer.Write(reinterpret_cast<char*>(this->blocks_[i]), block_size_);
}
}

void
MemoryBlockIO::DeserializeImpl(StreamReader& reader) {
for (auto* block : blocks_) {
allocator_->Deallocate(block);
}
StreamReader::ReadObj(reader, this->block_size_);
uint64_t block_count;
StreamReader::ReadObj(reader, block_count);
this->blocks_.resize(block_count);
for (uint64_t i = 0; i < block_count; ++i) {
blocks_[i] = static_cast<unsigned char*>(allocator_->Allocate(this->block_size_));
reader.Read(reinterpret_cast<char*>(blocks_[i]), block_size_);
}
}

} // namespace vsag
1 change: 0 additions & 1 deletion src/io/memory_block_io_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include <memory>

#include "basic_io_test.h"
#include "default_allocator.h"
#include "safe_allocator.h"

using namespace vsag;
Expand Down
Loading

0 comments on commit 05b8506

Please sign in to comment.