Skip to content

Commit

Permalink
bug fix: EnableDictTraining can be false (microsoft#319)
Browse files: browse the repository at this point in the history
  • Loading branch information
suiguoxin authored Jul 25, 2022
1 parent 5ab4ced commit 8b25316
Showing 1 changed file with 25 additions and 23 deletions.
48 changes: 25 additions & 23 deletions AnnService/inc/Core/SPANN/ExtraFullGraphSearcher.h
Original file line number | Diff line number | Diff line change
Expand Up @@ -652,33 +652,35 @@ namespace SPTAG
return false;
}
}

// train dict
// create compressor
if (p_opt.m_enableDataCompression && i == 0)
{
m_pCompressor = std::make_unique<Compressor>(p_opt.m_zstdCompressLevel, p_opt.m_dictBufferCapacity);
LOG(Helper::LogLevel::LL_Info, "Training dictionary...\n");
std::string samplesBuffer("");
std::vector<size_t> samplesSizes;
for (int j = 0; j < curPostingListSizes.size(); j++) {
if (curPostingListSizes[j] == 0) {
continue;
}
ValueType* headVector = nullptr;
if (p_opt.m_enableDeltaEncoding)
{
headVector = (ValueType*)p_headIndex->GetSample(j);
}
std::string postingListFullData = GetPostingListFullData(
j, curPostingListSizes[j], selections, fullVectors, p_opt.m_enableDeltaEncoding, p_opt.m_enablePostingListRearrange, headVector);
// train dict
if (p_opt.m_enableDictTraining) {
LOG(Helper::LogLevel::LL_Info, "Training dictionary...\n");
std::string samplesBuffer("");
std::vector<size_t> samplesSizes;
for (int j = 0; j < curPostingListSizes.size(); j++) {
if (curPostingListSizes[j] == 0) {
continue;
}
ValueType* headVector = nullptr;
if (p_opt.m_enableDeltaEncoding)
{
headVector = (ValueType*)p_headIndex->GetSample(j);
}
std::string postingListFullData = GetPostingListFullData(
j, curPostingListSizes[j], selections, fullVectors, p_opt.m_enableDeltaEncoding, p_opt.m_enablePostingListRearrange, headVector);

samplesBuffer += postingListFullData;
samplesSizes.push_back(postingListFullData.size());
if (samplesBuffer.size() > p_opt.m_minDictTraingBufferSize) break;
samplesBuffer += postingListFullData;
samplesSizes.push_back(postingListFullData.size());
if (samplesBuffer.size() > p_opt.m_minDictTraingBufferSize) break;
}
LOG(Helper::LogLevel::LL_Info, "Using the first %zu postingLists to train dictionary... \n", samplesSizes.size());
std::size_t dictSize = m_pCompressor->TrainDict(samplesBuffer, &samplesSizes[0], samplesSizes.size());
LOG(Helper::LogLevel::LL_Info, "Dictionary trained, dictionary size: %zu \n", dictSize);
}
LOG(Helper::LogLevel::LL_Info, "Using the first %zu postingLists to train dictionary... \n", samplesSizes.size());
std::size_t dictSize = m_pCompressor->TrainDict(samplesBuffer, &samplesSizes[0], samplesSizes.size());
LOG(Helper::LogLevel::LL_Info, "Dictionary trained, dictionary size: %zu \n", dictSize);
}

if (p_opt.m_enableDataCompression) {
Expand All @@ -703,7 +705,7 @@ namespace SPTAG
if (sizeToCompress != postingListFullData.size()) {
LOG(Helper::LogLevel::LL_Error, "Size to compress NOT MATCH! PostingListFullData size: %zu sizeToCompress: %zu \n", postingListFullData.size(), sizeToCompress);
}
curPostingListBytes[j] = m_pCompressor->GetCompressedSize(postingListFullData, true);
curPostingListBytes[j] = m_pCompressor->GetCompressedSize(postingListFullData, p_opt.m_enableDictTraining);
if (postingListId % 10000 == 0 || curPostingListBytes[j] > static_cast<uint64_t>(p_opt.m_postingPageLimit) * PageSize) {
LOG(Helper::LogLevel::LL_Info, "Posting list %d/%d, compressed size: %d, compression ratio: %.4f\n", postingListId, postingListSize.size(), curPostingListBytes[j], curPostingListBytes[j] / float(sizeToCompress));
}
Expand Down

0 comments on commit 8b25316

Please sign in to comment.