Skip to content

Commit

Permalink
Catch index corruption a bit more easily.
Browse files Browse the repository at this point in the history
  • Loading branch information
psobot committed Oct 2, 2023
1 parent d72590a commit 53cddf7
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 3 deletions.
23 changes: 20 additions & 3 deletions cpp/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,13 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t, data_t> {
totalFileSize = inputStream->getTotalLength();
}
readBinaryPOD(inputStream, offsetLevel0_);
if (totalFileSize > 0 && offsetLevel0_ > totalFileSize) {
throw std::domain_error("Index appears to contain corrupted data; level "
"0 offset parameter (" +
std::to_string(offsetLevel0_) +
") exceeded size of index file (" +
std::to_string(totalFileSize) + ").");
}

readBinaryPOD(inputStream, max_elements_);
readBinaryPOD(inputStream, cur_element_count);
Expand Down Expand Up @@ -763,7 +770,14 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t, data_t> {
if (inputStream->getPosition() < 0 ||
inputStream->getPosition() >= totalFileSize) {
throw std::runtime_error(
"Index seems to be corrupted or unsupported");
"Index seems to be corrupted or unsupported. Seeked to " +
std::to_string(position +
(cur_element_count * size_data_per_element_) +
(sizeof(unsigned int) * i)) +
" bytes to read linked list, but resulting stream position was " +
std::to_string(inputStream->getPosition()) +
" (of total file size " + std::to_string(totalFileSize) +
" bytes).");
}

unsigned int linkListSize;
Expand All @@ -774,7 +788,9 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t, data_t> {
}

if (inputStream->getPosition() != totalFileSize)
throw std::runtime_error("Index seems to be corrupted or unsupported");
throw std::runtime_error(
"Index seems to be corrupted or unsupported. After reading all "
"linked lists, extra data remained at the end of the index.");

inputStream->setPosition(position);
}
Expand Down Expand Up @@ -903,7 +919,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t, data_t> {
tableint label_c;
auto search = label_lookup_.find(label);
if (search == label_lookup_.end() || isMarkedDeleted(search->second)) {
throw std::runtime_error("Label not found");
throw std::runtime_error("Label " + std::to_string(label) +
" not found in index.");
}
label_c = search->second;

Expand Down
72 changes: 72 additions & 0 deletions python/tests/test_load_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pytest

import os
from io import BytesIO
import numpy as np
from glob import glob

Expand Down Expand Up @@ -64,6 +65,8 @@ def test_load_v0_indices(load_from_stream: bool, index_filename: str):
num_dimensions = detect_num_dimensions_from_filename(index_filename)
if load_from_stream:
with open(index_filename, "rb") as f:
print(f.read(8))
f.seek(0)
index = Index.load(
f,
space=space,
Expand Down Expand Up @@ -107,3 +110,72 @@ def test_load_v1_indices(load_from_stream: bool, index_filename: str):
# Voyager stores only normalized vectors in Cosine mode:
expected_vector = expected_vector / np.sqrt(np.sum(expected_vector**2))
np.testing.assert_allclose(index[_id], expected_vector, atol=0.2)


@pytest.mark.parametrize(
    "data,should_pass",
    [
        # Case 1: the stream stops immediately after the fixed-size header.
        (
            b"VOYA"  # Header
            b"\x01\x00\x00\x00"  # File version
            b"\x0A\x00\x00\x00"  # Number of dimensions (10)
            b"\x00"  # Space type
            b"\x20",  # Storage data type
            False,
        ),
        # Case 2: a complete index holding one all-zero vector.
        (
            b"VOYA"  # Header
            b"\x01\x00\x00\x00"  # File version
            b"\x0A\x00\x00\x00"  # Number of dimensions (10)
            b"\x00"  # Space type
            b"\x20"  # Storage data type
            b"\x00\x00\x00\x00\x00\x00\x00\x00"  # offsetLevel0_
            b"\x01\x00\x00\x00\x00\x00\x00\x00"  # max_elements_
            b"\x01\x00\x00\x00\x00\x00\x00\x00"  # cur_element_count
            b"\x34\x00\x00\x00\x00\x00\x00\x00"  # size_data_per_element_ (52)
            b"\x00\x00\x00\x00\x00\x00\x00\x00"  # label_offset_
            b"\x00\x00\x00\x00\x00\x00\x00\x00"  # offsetData_
            b"\x00\x00\x00\x00"  # maxlevel_
            b"\x00\x00\x00\x00"  # enterpoint_node_
            b"\x00\x00\x00\x00\x00\x00\x00\x00"  # maxM_
            b"\x00\x00\x00\x00\x00\x00\x00\x00"  # maxM0_
            b"\x00\x00\x00\x00\x00\x00\x00\x00"  # M_
            b"\x00\x00\x00\x00\x00\x00\x00\x00"  # mult_
            b"\x00\x00\x00\x00\x00\x00\x00\x00"  # ef_construction_
            + bytes(52)  # one vector: 52 zero bytes
            + b"\x00\x00\x00\x00",  # one linklist
            True,
        ),
        # Case 3: same layout as case 2, except offsetLevel0_ is 0xFF000000,
        # which is far larger than the data supplied.
        (
            b"VOYA"  # Header
            b"\x01\x00\x00\x00"  # File version
            b"\x0A\x00\x00\x00"  # Number of dimensions (10)
            b"\x00"  # Space type
            b"\x20"  # Storage data type
            b"\x00\x00\x00\xFF\x00\x00\x00\x00"  # offsetLevel0_
            b"\x01\x00\x00\x00\x00\x00\x00\x00"  # max_elements_
            b"\x01\x00\x00\x00\x00\x00\x00\x00"  # cur_element_count
            b"\x34\x00\x00\x00\x00\x00\x00\x00"  # size_data_per_element_ (52)
            b"\x00\x00\x00\x00\x00\x00\x00\x00"  # label_offset_
            b"\x00\x00\x00\x00\x00\x00\x00\x00"  # offsetData_
            b"\x00\x00\x00\x00"  # maxlevel_
            b"\x00\x00\x00\x00"  # enterpoint_node_
            b"\x00\x00\x00\x00\x00\x00\x00\x00"  # maxM_
            b"\x00\x00\x00\x00\x00\x00\x00\x00"  # maxM0_
            b"\x00\x00\x00\x00\x00\x00\x00\x00"  # M_
            b"\x00\x00\x00\x00\x00\x00\x00\x00"  # mult_
            b"\x00\x00\x00\x00\x00\x00\x00\x00"  # ef_construction_
            + bytes(52)  # one vector: 52 zero bytes
            + b"\x00\x00\x00\x00",  # one linklist
            False,
        ),
    ],
)
def test_loading_random_data_cannot_crash(data: bytes, should_pass: bool):
    """Load a hand-built byte string as an index and check the outcome.

    When ``should_pass`` is false, ``Index.load`` must raise an exception.
    Otherwise the loaded index must contain exactly one element whose
    vector is all zeros.
    """
    stream = BytesIO(data)

    # Failure cases first: loading must raise, and nothing more is checked.
    if not should_pass:
        with pytest.raises(Exception):
            Index.load(stream)
        return

    loaded = Index.load(stream)
    assert len(loaded) == 1
    expected_vector = np.zeros(loaded.num_dimensions)
    np.testing.assert_allclose(loaded[0], expected_vector)

0 comments on commit 53cddf7

Please sign in to comment.