Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented AnnoyIndex serialization to bytes objects in-memory #661

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
*.egg-info/
*.egg/
*.eggs/
*.so
*.o
build/
Expand Down
2 changes: 2 additions & 0 deletions annoy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ class AnnoyIndex:
def __init__(self, f: int, metric: Literal["angular", "euclidean", "manhattan", "hamming", "dot"]) -> None: ...
def load(self, fn: str, prefault: bool = ...) -> Literal[True]: ...
def save(self, fn: str, prefault: bool = ...) -> Literal[True]: ...
def serialize(self) -> bytes: ...
def deserialize(self, data: bytes, prefault: bool = ...) -> Literal[True]: ...
@overload
def get_nns_by_item(self, i: int, n: int, search_k: int = ..., include_distances: Literal[False] = ...) -> list[int]: ...
@overload
Expand Down
9 changes: 9 additions & 0 deletions src/annoygomodule.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@ class AnnoyIndex {
bool load(const char* filename) {
return ptr->load(filename, true);
};
vector<uint8_t> serialize() {
return ptr->serialize();
}
bool deserialize(vector<uint8_t>* v, bool prefault) {
return ptr->deserialize(v, prefault);
}
bool deserialize(vector<uint8_t>* v) {
return ptr->deserialize(v, true);
}
float getDistance(int i, int j) {
return ptr->get_distance(i, j);
};
Expand Down
75 changes: 75 additions & 0 deletions src/annoylib.h
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,8 @@ class AnnoyIndexInterface {
virtual bool save(const char* filename, bool prefault=false, char** error=NULL) = 0;
virtual void unload() = 0;
virtual bool load(const char* filename, bool prefault=false, char** error=NULL) = 0;
virtual vector<uint8_t> serialize(char** error=NULL) const = 0;
virtual bool deserialize(vector<uint8_t>* bytes, bool prefault=false, char** error=NULL) = 0;
virtual T get_distance(S i, S j) const = 0;
virtual void get_nns_by_item(S item, size_t n, int search_k, vector<S>* result, vector<T>* distances) const = 0;
virtual void get_nns_by_vector(const T* w, size_t n, int search_k, vector<S>* result, vector<T>* distances) const = 0;
Expand Down Expand Up @@ -1221,6 +1223,79 @@ template<typename S, typename T, typename Distance, typename Random, class Threa
return true;
}

vector<uint8_t> serialize(char** error=NULL) const {
if (!_built) {
set_error_from_string(error, "Index cannot be serialized if it hasn't been built");
return {};
}

vector<uint8_t> bytes {};

S n_items = _n_items;
S n_nodes = _n_nodes;
size_t roots_size = _roots.size();
S nodes_size = _nodes_size;

bytes.insert(bytes.end(), (uint8_t*)&n_items, (uint8_t*)&n_items + sizeof(n_items));
bytes.insert(bytes.end(), (uint8_t*)&n_nodes, (uint8_t*)&n_nodes + sizeof(n_nodes));
bytes.insert(bytes.end(), (uint8_t*)&roots_size, (uint8_t*)&roots_size + sizeof(roots_size));
bytes.insert(bytes.end(), (uint8_t*)&nodes_size, (uint8_t*)&nodes_size + sizeof(nodes_size));

uint8_t* roots_buffer = (uint8_t*)_roots.data();
bytes.insert(bytes.end(), roots_buffer, roots_buffer + _roots.size() * sizeof(S));

uint8_t* nodes_buffer = (uint8_t*)_nodes;
bytes.insert(bytes.end(), nodes_buffer, nodes_buffer + _nodes_size * _s);

return bytes;
}

bool deserialize(vector<uint8_t>* bytes, bool prefault=false, char** error=NULL) {
if (bytes->empty()) {
set_error_from_errno(error, "Size of bytes is zero");
return false;
}

int flags = MAP_SHARED;
if (prefault) {
#ifdef MAP_POPULATE
flags |= MAP_POPULATE;
#else
annoylib_showUpdate("prefault is set to true, but MAP_POPULATE is not defined on this platform");
#endif
}

uint8_t* bytes_buffer = (uint8_t*)bytes->data();

_n_items = *(S*)bytes_buffer;
bytes_buffer += sizeof(S);

S n_nodes = *(S*)bytes_buffer;
bytes_buffer += sizeof(S);

size_t roots_size = *(size_t*)bytes_buffer;
bytes_buffer += sizeof(size_t);

S nodes_size = *(S*)bytes_buffer;
bytes_buffer += sizeof(S);

_roots.clear();
_roots.resize(roots_size);
_roots.assign((S*) bytes_buffer, (S*) bytes_buffer + roots_size);
bytes_buffer += roots_size * sizeof(S);

_allocate_size((S) nodes_size);

memcpy(_nodes, bytes_buffer, nodes_size * _s);

_n_nodes = n_nodes;
_loaded = true;
_built = true;

if (_verbose) annoylib_showUpdate("found %zu roots with degree %d\n", _roots.size(), _n_items);
return true;
}

T get_distance(S i, S j) const {
return D::normalized_distance(D::distance(_get(i), _get(j), _f));
}
Expand Down
23 changes: 23 additions & 0 deletions src/annoyluamodule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,27 @@ class LuaAnnoy {
return 1;
}

static int serialize(lua_State* L) {
Impl* self = getAnnoy(L, 1);
int nargs = lua_gettop(L);
vector<uint8_t> bytes = self->serialize();

lua_pushlstring(L, (const char*) bytes.data(), bytes.size());

return 1;
}

static int deserialize(lua_State* L) {
Impl* self = getAnnoy(L, 1);
int nargs = lua_gettop(L);
const char* bytes_buffer = lua_tostring(L, 2);
size_t bytes_buffer_size = lua_rawlen(L, 2);
vector<uint8_t> bytes(bytes_buffer, bytes_buffer + bytes_buffer_size);
self->deserialize(&bytes);

return 1;
}

static int unload(lua_State* L) {
Impl* self = getAnnoy(L, 1);
self->unload();
Expand Down Expand Up @@ -260,6 +281,8 @@ class LuaAnnoy {
{"build", &ThisClass::build},
{"save", &ThisClass::save},
{"load", &ThisClass::load},
{"serialize", &ThisClass::serialize},
{"deserialize", &ThisClass::deserialize},
{"unload", &ThisClass::unload},
{"get_nns_by_item", &ThisClass::get_nns_by_item},
{"get_nns_by_vector", &ThisClass::get_nns_by_vector},
Expand Down
50 changes: 50 additions & 0 deletions src/annoymodule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "kissrandom.h"
#include "Python.h"
#include "structmember.h"
#include "bytesobject.h"
#include <exception>
#if defined(_MSC_VER) && _MSC_VER == 1500
typedef signed __int32 int32_t;
Expand Down Expand Up @@ -96,6 +97,8 @@ class HammingWrapper : public AnnoyIndexInterface<int32_t, float> {
bool save(const char* filename, bool prefault, char** error) { return _index.save(filename, prefault, error); };
void unload() { _index.unload(); };
bool load(const char* filename, bool prefault, char** error) { return _index.load(filename, prefault, error); };
vector<uint8_t> serialize(char** error) const { return _index.serialize(error); };
bool deserialize(vector<uint8_t>* bytes, bool prefault, char** error) { return _index.deserialize(bytes, prefault, error); };
float get_distance(int32_t i, int32_t j) const { return _index.get_distance(i, j); };
void get_nns_by_item(int32_t item, size_t n, int search_k, vector<int32_t>* result, vector<float>* distances) const {
if (distances) {
Expand Down Expand Up @@ -235,6 +238,51 @@ py_an_save(py_annoy *self, PyObject *args, PyObject *kwargs) {
Py_RETURN_TRUE;
}

static PyObject *
py_an_serialize(py_annoy *self, PyObject *args, PyObject *kwargs) {
if (!self->ptr)
return NULL;

vector<uint8_t> bytes = self->ptr->serialize(NULL);

return PyBytes_FromStringAndSize((const char*)bytes.data(), bytes.size());
}

static PyObject *
py_an_deserialize(py_annoy *self, PyObject *args, PyObject *kwargs) {
PyObject* bytes_object;
char *error;
bool prefault = false;
if (!self->ptr)
return NULL;

static char const * kwlist[] = {"bytes", "prefault", NULL};

if (!PyArg_ParseTupleAndKeywords(args, kwargs, "S|b", (char**)kwlist, &bytes_object, &prefault))
return NULL;

if (bytes_object == NULL) {
PyErr_SetString(PyExc_TypeError, "Expected bytes");
return NULL;
}

if (!PyBytes_Check(bytes_object)) {
PyErr_SetString(PyExc_TypeError, "Expected bytes");
return NULL;
}

Py_ssize_t length = PyBytes_Size(bytes_object);
uint8_t* raw_bytes = (uint8_t*)PyBytes_AsString(bytes_object);
vector<uint8_t> v(raw_bytes, raw_bytes + length);

if (!self->ptr->deserialize(&v, prefault, &error)) {
PyErr_SetString(PyExc_IOError, error);
free(error);
return NULL;
}

Py_RETURN_TRUE;
}

PyObject*
get_nns_to_python(const vector<int32_t>& result, const vector<float>& distances, int include_distances) {
Expand Down Expand Up @@ -575,6 +623,8 @@ py_an_set_seed(py_annoy *self, PyObject *args) {
static PyMethodDef AnnoyMethods[] = {
{"load", (PyCFunction)py_an_load, METH_VARARGS | METH_KEYWORDS, "Loads (mmaps) an index from disk."},
{"save", (PyCFunction)py_an_save, METH_VARARGS | METH_KEYWORDS, "Saves the index to disk."},
{"serialize", (PyCFunction)py_an_serialize, METH_VARARGS | METH_KEYWORDS, "Serializes the index to bytes."},
{"deserialize", (PyCFunction)py_an_deserialize, METH_VARARGS | METH_KEYWORDS, "Deserializes the index from bytes."},
{"get_nns_by_item",(PyCFunction)py_an_get_nns_by_item, METH_VARARGS | METH_KEYWORDS, "Returns the `n` closest items to item `i`.\n\n:param search_k: the query will inspect up to `search_k` nodes.\n`search_k` gives you a run-time tradeoff between better accuracy and speed.\n`search_k` defaults to `n_trees * n` if not provided.\n\n:param include_distances: If `True`, this function will return a\n2 element tuple of lists. The first list contains the `n` closest items.\nThe second list contains the corresponding distances."},
{"get_nns_by_vector",(PyCFunction)py_an_get_nns_by_vector, METH_VARARGS | METH_KEYWORDS, "Returns the `n` closest items to vector `vector`.\n\n:param search_k: the query will inspect up to `search_k` nodes.\n`search_k` gives you a run-time tradeoff between better accuracy and speed.\n`search_k` defaults to `n_trees * n` if not provided.\n\n:param include_distances: If `True`, this function will return a\n2 element tuple of lists. The first list contains the `n` closest items.\nThe second list contains the corresponding distances."},
{"get_item_vector",(PyCFunction)py_an_get_item_vector, METH_VARARGS, "Returns the vector for item `i` that was previously added."},
Expand Down
45 changes: 45 additions & 0 deletions test/annoy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,51 @@ func (suite *AnnoyTestSuite) TestFileHandling() {
os.Remove("go_test3.ann")
}

func (suite *AnnoyTestSuite) TestSerialization() {
index := annoyindex.NewAnnoyIndexAngular(3)
index.AddItem(0, []float32{0, 0, 1})
index.AddItem(1, []float32{0, 1, 0})
index.AddItem(2, []float32{1, 0, 0})
index.Build(10)

bytes := index.Serialize()

index2 := annoyindex.NewAnnoyIndexAngular(3)

success := index2.Deserialize(bytes)

if !success {
assert.Fail(suite.T(), "Failed to deserialize")
}

itemCountIsSame := index.GetNItems() == index2.GetNItems()

if !itemCountIsSame {
assert.Fail(suite.T(), "Item count is not the same")
}

var resultIndex []int
var resultIndex2 []int

itemCount := index.GetNItems()

index.GetNnsByItem(0, itemCount, -1, &resultIndex)
index2.GetNnsByItem(0, itemCount, -1, &resultIndex2)

itemsAreSame := true

for index := 0; index < itemCount; index++ {
if resultIndex[index] != resultIndex2[index] {
itemsAreSame = false
break
}
}

if !itemsAreSame {
assert.Fail(suite.T(), "Items are not the same")
}
}

func (suite *AnnoyTestSuite) TestOnDiskBuild() {
index := annoyindex.NewAnnoyIndexAngular(3)
index.OnDiskBuild("go_test.ann");
Expand Down
24 changes: 24 additions & 0 deletions test/annoy_test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,30 @@ describe("index test", function()
assert.same(u, y)
end)

it("serialize_deserialize", function()
local f = 2
local i = AnnoyIndex(f, 'euclidean')
i:add_item(0, {2, 2})
i:add_item(1, {3, 2})
i:add_item(2, {3, 3})
i:add_item(3, {4, 4})
i:add_item(4, {5, 5})
i:build(10)

local bytes = i:serialize()

local j = AnnoyIndex(f, 'euclidean')

j:deserialize(bytes)

local item_count = 4

local first_items = i:get_nns_by_item(0, item_count)
local second_items = j:get_nns_by_item(0, item_count)

assert.same(first_items, second_items)
end)

it("on_disk_build", function()
local f = 2
local i = AnnoyIndex(f, 'euclidean')
Expand Down
63 changes: 63 additions & 0 deletions test/serialize_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import random

from annoy import AnnoyIndex

def test_serialize_index():
f = 32

index = AnnoyIndex(f, 'angular')

for iteration in range(1000):
vector = [random.gauss(0, 1) for _ in range(f)]
index.add_item(iteration, vector)

index.build(10)

_ = index.serialize()


def test_deserialize_index():
f = 32

index = AnnoyIndex(f, 'angular')

for iteration in range(1000):
vector = [random.gauss(0, 1) for _ in range(f)]
index.add_item(iteration, vector)

index.build(10)

data = index.serialize()

index2 = AnnoyIndex(f, 'angular')

index2.deserialize(data)

index_item_count = index.get_n_items()

assert index_item_count == index2.get_n_items()
assert index.get_n_trees() == index2.get_n_trees()
assert index.get_nns_by_item(0, index_item_count) == index2.get_nns_by_item(0, index_item_count)

def test_serialize_after_load():
f = 32

index1 = AnnoyIndex(f, 'angular')

for iteration in range(1000):
vector = [random.gauss(0, 1) for _ in range(f)]
index1.add_item(iteration, vector)

index1.build(10)

save_path = "test/test.tree"
index1.save(save_path)

index2 = AnnoyIndex(f, 'angular')
index2.load(save_path)

assert index1.serialize() == index2.serialize()
assert index1.get_n_items() == index2.get_n_items()
assert index1.get_n_trees() == index2.get_n_trees()
assert index1.get_nns_by_item(0, index1.get_n_items()) == index2.get_nns_by_item(0, index1.get_n_items())