Host Snappy compression #17824

Draft · wants to merge 2 commits into branch-25.02
214 changes: 185 additions & 29 deletions cpp/src/io/comp/comp.cpp
@@ -18,7 +18,6 @@

#include "gpuinflate.hpp"
#include "io/utilities/getenv_or.hpp"
#include "io/utilities/hostdevice_vector.hpp"
#include "nvcomp_adapter.hpp"

#include <cudf/detail/nvtx/ranges.hpp>
@@ -92,33 +91,189 @@ std::vector<std::uint8_t> compress_gzip(host_span<uint8_t const> src)
  return dst;
}

-/**
- * @brief SNAPPY device compressor
- */
-std::vector<std::uint8_t> compress_snappy(host_span<uint8_t const> src,
-                                          rmm::cuda_stream_view stream)
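
// Reads an unaligned value from the byte stream; the byte-by-byte stepping used below
// (data >>= 8) relies on a little-endian host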
template <typename T>
[[nodiscard]] T load(uint8_t const* ptr)
{
  T value;
  std::memcpy(&value, ptr, sizeof(T));
  return value;
}

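// Maps hashes of four-byte sequences to the position of their most recent occurrence;
// 16-bit entries are sufficient because blocks are at most 64 KiB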
class hash_table {
  std::vector<uint16_t> tbl;
  static constexpr int hash_table_bits = 15;

 public:
  hash_table() : tbl(1 << hash_table_bits, 0) {}

  void clear() { std::fill(tbl.begin(), tbl.end(), 0); }

  [[nodiscard]] uint16_t* entry(uint32_t bytes)
  {
    constexpr uint32_t multiplier = 0x1e35a7bd;
    // The upper 16 bits of the product form a byte offset into the table; dividing by the
    // entry size turns it into an index
    auto const hash = (bytes * multiplier) >> (31 - hash_table_bits);
    return tbl.data() + hash / sizeof(uint16_t);
  }
};

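// Emits a Snappy literal element: for lengths up to 60, the tag byte stores (length - 1) << 2;
// longer literals store (length - 1) in the 1-4 bytes following the tag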
uint8_t* emit_literal(uint8_t* out_begin, uint8_t const* literal_begin, uint8_t const* literal_end)
{
  auto const literal_size = literal_end - literal_begin;
  if (literal_size == 0) { return out_begin; }
  auto const n = literal_size - 1;

  auto out_it = out_begin;
  if (n < 60) {
    // Fits into a single tag byte
    *out_it++ = n << 2;
  } else {
    // Store the number of length bytes in the tag; the length itself follows
    auto const log2_n = 31 - __builtin_clz(n);
    auto const count  = (log2_n >> 3) + 1;
    *out_it++         = (59 + count) << 2;
    std::memcpy(out_it, &n, count);
    out_it += count;
  }
  std::memcpy(out_it, literal_begin, literal_size);
  return out_it + literal_size;
}

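// Emits Snappy copy elements with two-byte offsets (element type 0b10); matches longer than
// 64 bytes are split into multiple copy elements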
uint8_t* emit_copy(uint8_t* out_begin, size_t offset, size_t len)
{
  while (len > 0) {
    // A single copy element encodes at most 64 bytes
    auto const copy_len = std::min(len, size_t{64});
    // First byte: element type 0b10 and (copy_len - 1) << 2; the next two bytes hold the offset
    auto const out_val = 2 + ((copy_len - 1) << 2) + (offset << 8);
    std::memcpy(out_begin, &out_val, 3);

    out_begin += 3;
    len -= copy_len;
  }
  return out_begin;
}

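// Compresses a single block of input into output using the given hash table; returns the
// number of bytes written to output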
size_t compress_block(host_span<uint8_t const> input, hash_table& table, host_span<uint8_t> output)
{
  auto const [in_remain, out_remain] = [&]() -> std::pair<uint8_t const*, uint8_t*> {
    auto in_it  = input.begin();
    auto out_it = output.begin();

    // The algorithm reads 8 bytes at a time, so we need to ensure there are at least 8 bytes
    auto const input_max = input.end() - sizeof(uint64_t);
    while (in_it < input_max) {
      auto const next_emit     = in_it++;
      auto data                = load<uint64_t>(in_it);
      uint32_t stride          = 1;
      uint8_t const* candidate = nullptr;

      // Fast path: check the next 16 positions for a four-byte match, reloading eight bytes
      // at a time; on a match, the pending bytes are emitted as a literal
      auto word_match_found = [&]() {
        if (input_max - in_it < 16) { return false; }
        for (size_t word_idx = 0; word_idx < 4; ++word_idx) {
          for (size_t byte_idx = 0; byte_idx < sizeof(uint32_t); ++byte_idx) {
            auto const offset = sizeof(uint32_t) * word_idx + byte_idx;
            auto* const entry = table.entry(static_cast<uint32_t>(data));
            candidate         = input.begin() + *entry;
            *entry            = in_it - input.data() + offset;

            if (load<uint32_t>(candidate) == static_cast<uint32_t>(data)) {
              // Literal tag for the offset + 1 pending bytes: (len - 1) << 2
              *(out_it++) = offset * sizeof(uint32_t);
              std::memcpy(out_it, next_emit, offset + 1);
              in_it += offset;
              out_it += offset + 1;
              stride = 1;
              return true;
            }
            data >>= 8;
          }
          // Fetch the next eight bytes
          data = load<uint64_t>(in_it + sizeof(uint32_t) * (word_idx + 1));
        }
        in_it += 16;
        return false;
      }();

      if (not word_match_found) {
        // Keep looking for a match with increasing stride
        for (;;) {
          auto* const entry = table.entry(static_cast<uint32_t>(data));
          candidate         = input.begin() + *entry;
          *entry            = in_it - input.begin();
          if (static_cast<uint32_t>(data) == load<uint32_t>(candidate)) {
            stride = 1;
            break;
          }

          auto const next_input = in_it + stride;
          if (next_input > input_max) {
            // Reached the end of the input without finding a match
            return {next_emit, out_it};
          }

          data  = load<uint32_t>(next_input);
          in_it = next_input;
          stride += 1;
        }

        // Emit data prior to the match as a literal
        out_it = emit_literal(out_it, next_emit, in_it);
      }

      // Emit match(es)
      do {
        auto const match_len = std::mismatch(in_it, input.end(), candidate).first - in_it;
        out_it               = emit_copy(out_it, in_it - candidate, match_len);

        in_it += match_len;
        if (in_it >= input_max) {
          // Reached the end of the input, no more matches to look for
          return {in_it, out_it};
        }
        // Update the table and check whether the position right after the match also matches
        data = load<uint64_t>(in_it);
        *table.entry(load<uint32_t>(in_it - 1)) = in_it - input.begin() - 1;
        auto* const entry = table.entry(static_cast<uint32_t>(data));
        candidate         = input.begin() + *entry;
        *entry            = in_it - input.begin();
      } while (static_cast<uint32_t>(data) == load<uint32_t>(candidate));
    }

    return {in_it, out_it};
  }();

  // Emit the remaining data as a literal
  return emit_literal(out_remain, in_remain, input.end()) - output.begin();
}

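// Appends the value as a little-endian base-128 varint: seven bits per byte, with the high
// bit set on all bytes but the last (e.g. 100000 is encoded as 0xA0 0x8D 0x06)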
void append_varint(std::vector<uint8_t>& output, size_t v)
{
  while (v > 127) {
    output.push_back((v & 0x7F) | 0x80);
    v >>= 7;
  }
  output.push_back(v);
}

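/**
 * @brief SNAPPY host compressor
 *
 * Writes the uncompressed length as a varint header, then compresses the input in
 * independent 64 KiB blocks, producing a standard Snappy-format stream
 */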
[[nodiscard]] std::vector<std::uint8_t> compress_snappy(host_span<uint8_t const> src)
{
-  auto const d_src =
-    cudf::detail::make_device_uvector_async(src, stream, cudf::get_current_device_resource_ref());
-  cudf::detail::hostdevice_vector<device_span<uint8_t const>> inputs(1, stream);
-  inputs[0] = d_src;
-  inputs.host_to_device_async(stream);
-
-  auto dst_size = compress_max_output_chunk_size(nvcomp::compression_type::SNAPPY, src.size());
-  rmm::device_uvector<uint8_t> d_dst(dst_size, stream);
-  cudf::detail::hostdevice_vector<device_span<uint8_t>> outputs(1, stream);
-  outputs[0] = d_dst;
-  outputs.host_to_device_async(stream);
-
-  cudf::detail::hostdevice_vector<compression_result> hd_status(1, stream);
-  hd_status[0] = {};
-  hd_status.host_to_device_async(stream);
-
-  nvcomp::batched_compress(nvcomp::compression_type::SNAPPY, inputs, outputs, hd_status, stream);
-
-  hd_status.device_to_host_sync(stream);
-  CUDF_EXPECTS(hd_status[0].status == compression_status::SUCCESS, "snappy compression failed");
-  return cudf::detail::make_std_vector_sync<uint8_t>(d_dst, stream);
  std::vector<uint8_t> dst;
  // A Snappy stream starts with the uncompressed length encoded as a varint
  append_varint(dst, src.size());
  dst.reserve(dst.size() + max_compressed_size(compression_type::SNAPPY, src.size()));
  hash_table table;  // reuse the hash table across blocks

  constexpr size_t block_size = 1 << 16;
  for (size_t src_offset = 0; src_offset < src.size(); src_offset += block_size) {
    // Compress data in blocks of limited size
    auto const block = src.subspan(src_offset, std::min(src.size() - src_offset, block_size));

    // Make room for the worst-case compressed size of this block
    auto const previous_size = dst.size();
    dst.resize(previous_size + max_compressed_size(compression_type::SNAPPY, block.size()));
    auto const block_dst =
      host_span<uint8_t>{dst.data() + previous_size, dst.size() - previous_size};

    table.clear();
    auto const comp_block_size = compress_block(block, table, block_dst);
    // Shrink the output to the actual compressed size
    dst.resize(previous_size + comp_block_size);
  }

  return dst;
}

void device_compress(compression_type compression,
@@ -183,6 +338,7 @@ void host_compress(compression_type compression,
{
  switch (compression) {
    case compression_type::GZIP:
    case compression_type::SNAPPY:
    case compression_type::NONE: return true;
    default: return false;
  }
@@ -249,12 +405,12 @@ std::optional<size_t> compress_max_allowed_chunk_size(compression_type compression

std::vector<std::uint8_t> compress(compression_type compression,
                                   host_span<uint8_t const> src,
-                                  rmm::cuda_stream_view stream)
                                   rmm::cuda_stream_view)
{
  CUDF_FUNC_RANGE();
  switch (compression) {
    case compression_type::GZIP: return compress_gzip(src);
-   case compression_type::SNAPPY: return compress_snappy(src, stream);
    case compression_type::SNAPPY: return compress_snappy(src);
    default: CUDF_FAIL("Unsupported compression type");
  }
}