diff --git a/.gitignore b/.gitignore
index 3e54ba1a..a1c3e257 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,6 @@
 *.egg-info/
 *.egg/
+*.eggs/
 *.so
 *.o
 build/
diff --git a/annoy/__init__.pyi b/annoy/__init__.pyi
index 08adf4b9..c30d21aa 100644
--- a/annoy/__init__.pyi
+++ b/annoy/__init__.pyi
@@ -10,6 +10,8 @@ class AnnoyIndex:
     def __init__(self, f: int, metric: Literal["angular", "euclidean", "manhattan", "hamming", "dot"]) -> None: ...
     def load(self, fn: str, prefault: bool = ...) -> Literal[True]: ...
     def save(self, fn: str, prefault: bool = ...) -> Literal[True]: ...
+    def serialize(self) -> bytes: ...
+    def deserialize(self, data: bytes, prefault: bool = ...) -> Literal[True]: ...
     @overload
     def get_nns_by_item(self, i: int, n: int, search_k: int = ..., include_distances: Literal[False] = ...) -> list[int]: ...
     @overload
diff --git a/src/annoygomodule.h b/src/annoygomodule.h
index 074cc635..a4cad929 100644
--- a/src/annoygomodule.h
+++ b/src/annoygomodule.h
@@ -36,6 +36,15 @@ class AnnoyIndex {
   bool load(const char* filename) {
     return ptr->load(filename, true);
   };
+  vector<uint8_t> serialize() {
+    return ptr->serialize();
+  }
+  bool deserialize(vector<uint8_t>* v, bool prefault) {
+    return ptr->deserialize(v, prefault);
+  }
+  bool deserialize(vector<uint8_t>* v) {
+    return ptr->deserialize(v, true);
+  }
   float getDistance(int i, int j) {
     return ptr->get_distance(i, j);
   };
diff --git a/src/annoylib.h b/src/annoylib.h
index 657977cb..a657a714 100644
--- a/src/annoylib.h
+++ b/src/annoylib.h
@@ -912,6 +912,8 @@ class AnnoyIndexInterface {
   virtual bool save(const char* filename, bool prefault=false, char** error=NULL) = 0;
   virtual void unload() = 0;
   virtual bool load(const char* filename, bool prefault=false, char** error=NULL) = 0;
+  virtual vector<uint8_t> serialize(char** error=NULL) const = 0;
+  virtual bool deserialize(vector<uint8_t>* bytes, bool prefault=false, char** error=NULL) = 0;
   virtual T get_distance(S i, S j) const = 0;
   virtual void get_nns_by_item(S item, size_t n, int search_k, vector<S>* result, vector<T>* distances) const = 0;
   virtual void get_nns_by_vector(const T* w, size_t n, int search_k, vector<S>* result, vector<T>* distances) const = 0;
@@ -1221,6 +1223,79 @@ template<typename S, typename T, typename Distance, typename Random, class ThreadedBuildPolicy>
     return true;
   }

+  vector<uint8_t> serialize(char** error=NULL) const {
+    if (!_built) {
+      set_error_from_string(error, "Index cannot be serialized if it hasn't been built");
+      return {};
+    }
+
+    vector<uint8_t> bytes {};
+
+    S n_items = _n_items;
+    S n_nodes = _n_nodes;
+    size_t roots_size = _roots.size();
+    S nodes_size = _nodes_size;
+
+    bytes.insert(bytes.end(), (uint8_t*)&n_items, (uint8_t*)&n_items + sizeof(n_items));
+    bytes.insert(bytes.end(), (uint8_t*)&n_nodes, (uint8_t*)&n_nodes + sizeof(n_nodes));
+    bytes.insert(bytes.end(), (uint8_t*)&roots_size, (uint8_t*)&roots_size + sizeof(roots_size));
+    bytes.insert(bytes.end(), (uint8_t*)&nodes_size, (uint8_t*)&nodes_size + sizeof(nodes_size));
+
+    uint8_t* roots_buffer = (uint8_t*)_roots.data();
+    bytes.insert(bytes.end(), roots_buffer, roots_buffer + _roots.size() * sizeof(S));
+
+    uint8_t* nodes_buffer = (uint8_t*)_nodes;
+    bytes.insert(bytes.end(), nodes_buffer, nodes_buffer + _nodes_size * _s);
+
+    return bytes;
+  }
+
+  bool deserialize(vector<uint8_t>* bytes, bool prefault=false, char** error=NULL) {
+    if (bytes->empty()) {
+      set_error_from_errno(error, "Size of bytes is zero");
+      return false;
+    }
+
+    int flags = MAP_SHARED;
+    if (prefault) {
+#ifdef MAP_POPULATE
+      flags |= MAP_POPULATE;
+#else
+      annoylib_showUpdate("prefault is set to true, but MAP_POPULATE is not defined on this platform");
+#endif
+    }
+
+    uint8_t* bytes_buffer = (uint8_t*)bytes->data();
+
+    _n_items = *(S*)bytes_buffer;
+    bytes_buffer += sizeof(S);
+
+    S n_nodes = *(S*)bytes_buffer;
+    bytes_buffer += sizeof(S);
+
+    size_t roots_size = *(size_t*)bytes_buffer;
+    bytes_buffer += sizeof(size_t);
+
+    S nodes_size = *(S*)bytes_buffer;
+    bytes_buffer += sizeof(S);
+
+    _roots.clear();
+    _roots.resize(roots_size);
+    _roots.assign((S*) bytes_buffer, (S*) bytes_buffer + roots_size);
+    bytes_buffer += roots_size * sizeof(S);
+
+    _allocate_size((S) nodes_size);
+
+    memcpy(_nodes, bytes_buffer, nodes_size * _s);
+
+    _n_nodes = n_nodes;
+    _loaded = true;
+    _built = true;
+
+    if (_verbose) annoylib_showUpdate("found %zu roots with degree %d\n", _roots.size(), _n_items);
+    return true;
+  }
+
   T get_distance(S i, S j) const {
     return D::normalized_distance(D::distance(_get(i), _get(j), _f));
   }
diff --git a/src/annoyluamodule.cc b/src/annoyluamodule.cc
index a005df11..c638f4d3 100644
--- a/src/annoyluamodule.cc
+++ b/src/annoyluamodule.cc
@@ -164,6 +164,27 @@ class LuaAnnoy {
     return 1;
   }

+  static int serialize(lua_State* L) {
+    Impl* self = getAnnoy(L, 1);
+    int nargs = lua_gettop(L);
+    vector<uint8_t> bytes = self->serialize();
+
+    lua_pushlstring(L, (const char*) bytes.data(), bytes.size());
+
+    return 1;
+  }
+
+  static int deserialize(lua_State* L) {
+    Impl* self = getAnnoy(L, 1);
+    int nargs = lua_gettop(L);
+    const char* bytes_buffer = lua_tostring(L, 2);
+    size_t bytes_buffer_size = lua_rawlen(L, 2);
+    vector<uint8_t> bytes(bytes_buffer, bytes_buffer + bytes_buffer_size);
+    self->deserialize(&bytes);
+
+    return 1;
+  }
+
   static int unload(lua_State* L) {
     Impl* self = getAnnoy(L, 1);
     self->unload();
@@ -260,6 +281,8 @@ class LuaAnnoy {
     {"build", &ThisClass::build},
     {"save", &ThisClass::save},
     {"load", &ThisClass::load},
+    {"serialize", &ThisClass::serialize},
+    {"deserialize", &ThisClass::deserialize},
     {"unload", &ThisClass::unload},
     {"get_nns_by_item", &ThisClass::get_nns_by_item},
     {"get_nns_by_vector", &ThisClass::get_nns_by_vector},
diff --git a/src/annoymodule.cc b/src/annoymodule.cc
index 6bb0ae1b..7705ed7b 100644
--- a/src/annoymodule.cc
+++ b/src/annoymodule.cc
@@ -16,6 +16,7 @@
 #include "kissrandom.h"
 #include "Python.h"
 #include "structmember.h"
+#include "bytesobject.h"
 #include <exception>
 #if defined(_MSC_VER) && _MSC_VER == 1500
 typedef signed __int32 int32_t;
@@ -96,6 +97,8 @@ class HammingWrapper : public AnnoyIndexInterface<int32_t, float> {
   bool save(const char* filename, bool prefault, char** error) { return _index.save(filename, prefault, error); };
   void unload() { _index.unload(); };
   bool load(const char* filename, bool prefault, char** error) { return _index.load(filename, prefault, error); };
+  vector<uint8_t> serialize(char** error) const { return _index.serialize(error); };
+  bool deserialize(vector<uint8_t>* bytes, bool prefault, char** error) { return _index.deserialize(bytes, prefault, error); };
   float get_distance(int32_t i, int32_t j) const { return _index.get_distance(i, j); };
   void get_nns_by_item(int32_t item, size_t n, int search_k, vector<int32_t>* result, vector<float>* distances) const {
     if (distances) {
@@ -235,6 +238,51 @@ py_an_save(py_annoy *self, PyObject *args, PyObject *kwargs) {
   Py_RETURN_TRUE;
 }

+static PyObject *
+py_an_serialize(py_annoy *self, PyObject *args, PyObject *kwargs) {
+  if (!self->ptr)
+    return NULL;
+
+  vector<uint8_t> bytes = self->ptr->serialize(NULL);
+
+  return PyBytes_FromStringAndSize((const char*)bytes.data(), bytes.size());
+}
+
+static PyObject *
+py_an_deserialize(py_annoy *self, PyObject *args, PyObject *kwargs) {
+  PyObject* bytes_object;
+  char *error;
+  bool prefault = false;
+  if (!self->ptr)
+    return NULL;
+
+  static char const * kwlist[] = {"bytes", "prefault", NULL};
+
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "S|b", (char**)kwlist, &bytes_object, &prefault))
+    return NULL;
+
+  if (bytes_object == NULL) {
+    PyErr_SetString(PyExc_TypeError, "Expected bytes");
+    return NULL;
+  }
+
+  if (!PyBytes_Check(bytes_object)) {
+    PyErr_SetString(PyExc_TypeError, "Expected bytes");
+    return NULL;
+  }
+
+  Py_ssize_t length = PyBytes_Size(bytes_object);
+  uint8_t* raw_bytes = (uint8_t*)PyBytes_AsString(bytes_object);
+  vector<uint8_t> v(raw_bytes, raw_bytes + length);
+
+  if (!self->ptr->deserialize(&v, prefault, &error)) {
+    PyErr_SetString(PyExc_IOError, error);
+    free(error);
+    return NULL;
+  }
+
+  Py_RETURN_TRUE;
+}

 PyObject*
 get_nns_to_python(const vector<int32_t>& result, const vector<float>& distances, int include_distances) {
@@ -575,6 +623,8 @@ py_an_set_seed(py_annoy *self, PyObject *args) {
 static PyMethodDef AnnoyMethods[] = {
   {"load", (PyCFunction)py_an_load, METH_VARARGS | METH_KEYWORDS, "Loads (mmaps) an index from disk."},
   {"save", (PyCFunction)py_an_save, METH_VARARGS | METH_KEYWORDS, "Saves the index to disk."},
+  {"serialize", (PyCFunction)py_an_serialize, METH_VARARGS | METH_KEYWORDS, "Serializes the index to bytes."},
+  {"deserialize", (PyCFunction)py_an_deserialize, METH_VARARGS | METH_KEYWORDS, "Deserializes the index from bytes."},
   {"get_nns_by_item",(PyCFunction)py_an_get_nns_by_item, METH_VARARGS | METH_KEYWORDS, "Returns the `n` closest items to item `i`.\n\n:param search_k: the query will inspect up to `search_k` nodes.\n`search_k` gives you a run-time tradeoff between better accuracy and speed.\n`search_k` defaults to `n_trees * n` if not provided.\n\n:param include_distances: If `True`, this function will return a\n2 element tuple of lists. The first list contains the `n` closest items.\nThe second list contains the corresponding distances."},
   {"get_nns_by_vector",(PyCFunction)py_an_get_nns_by_vector, METH_VARARGS | METH_KEYWORDS, "Returns the `n` closest items to vector `vector`.\n\n:param search_k: the query will inspect up to `search_k` nodes.\n`search_k` gives you a run-time tradeoff between better accuracy and speed.\n`search_k` defaults to `n_trees * n` if not provided.\n\n:param include_distances: If `True`, this function will return a\n2 element tuple of lists. The first list contains the `n` closest items.\nThe second list contains the corresponding distances."},
   {"get_item_vector",(PyCFunction)py_an_get_item_vector, METH_VARARGS, "Returns the vector for item `i` that was previously added."},
diff --git a/test/annoy_test.go b/test/annoy_test.go
index bd0e569d..44505c6c 100644
--- a/test/annoy_test.go
+++ b/test/annoy_test.go
@@ -104,6 +104,51 @@ func (suite *AnnoyTestSuite) TestFileHandling() {
   os.Remove("go_test3.ann")
 }

+func (suite *AnnoyTestSuite) TestSerialization() {
+  index := annoyindex.NewAnnoyIndexAngular(3)
+  index.AddItem(0, []float32{0, 0, 1})
+  index.AddItem(1, []float32{0, 1, 0})
+  index.AddItem(2, []float32{1, 0, 0})
+  index.Build(10)
+
+  bytes := index.Serialize()
+
+  index2 := annoyindex.NewAnnoyIndexAngular(3)
+
+  success := index2.Deserialize(bytes)
+
+  if !success {
+    assert.Fail(suite.T(), "Failed to deserialize")
+  }
+
+  itemCountIsSame := index.GetNItems() == index2.GetNItems()
+
+  if !itemCountIsSame {
+    assert.Fail(suite.T(), "Item count is not the same")
+  }
+
+  var resultIndex []int
+  var resultIndex2 []int
+
+  itemCount := index.GetNItems()
+
+  index.GetNnsByItem(0, itemCount, -1, &resultIndex)
+  index2.GetNnsByItem(0, itemCount, -1, &resultIndex2)
+
+  itemsAreSame := true
+
+  for index := 0; index < itemCount; index++ {
+    if resultIndex[index] != resultIndex2[index] {
+      itemsAreSame = false
+      break
+    }
+  }
+
+  if !itemsAreSame {
+    assert.Fail(suite.T(), "Items are not the same")
+  }
+}
+
 func (suite *AnnoyTestSuite) TestOnDiskBuild() {
   index := annoyindex.NewAnnoyIndexAngular(3)
   index.OnDiskBuild("go_test.ann");
diff --git a/test/annoy_test.lua b/test/annoy_test.lua
index 5e8d2e02..ef922bbd 100644
--- a/test/annoy_test.lua
+++ b/test/annoy_test.lua
@@ -496,6 +496,30 @@ describe("index test", function()
     assert.same(u, y)
   end)

+  it("serialize_deserialize", function()
+    local f = 2
+    local i = AnnoyIndex(f, 'euclidean')
+    i:add_item(0, {2, 2})
+    i:add_item(1, {3, 2})
+    i:add_item(2, {3, 3})
+    i:add_item(3, {4, 4})
+    i:add_item(4, {5, 5})
+    i:build(10)
+
+    local bytes = i:serialize()
+
+    local j = AnnoyIndex(f, 'euclidean')
+
+    j:deserialize(bytes)
+
+    local item_count = 4
+
+    local first_items = i:get_nns_by_item(0, item_count)
+    local second_items = j:get_nns_by_item(0, item_count)
+
+    assert.same(first_items, second_items)
+  end)
+
   it("on_disk_build", function()
     local f = 2
     local i = AnnoyIndex(f, 'euclidean')
diff --git a/test/serialize_test.py b/test/serialize_test.py
new file mode 100644
index 00000000..3742d524
--- /dev/null
+++ b/test/serialize_test.py
@@ -0,0 +1,63 @@
+import random
+
+from annoy import AnnoyIndex
+
+def test_serialize_index():
+    f = 32
+
+    index = AnnoyIndex(f, 'angular')
+
+    for iteration in range(1000):
+        vector = [random.gauss(0, 1) for _ in range(f)]
+        index.add_item(iteration, vector)
+
+    index.build(10)
+
+    _ = index.serialize()
+
+
+def test_deserialize_index():
+    f = 32
+
+    index = AnnoyIndex(f, 'angular')
+
+    for iteration in range(1000):
+        vector = [random.gauss(0, 1) for _ in range(f)]
+        index.add_item(iteration, vector)
+
+    index.build(10)
+
+    data = index.serialize()
+
+    index2 = AnnoyIndex(f, 'angular')
+
+    index2.deserialize(data)
+
+    index_item_count = index.get_n_items()
+
+    assert index_item_count == index2.get_n_items()
+    assert index.get_n_trees() == index2.get_n_trees()
+    assert index.get_nns_by_item(0, index_item_count) == index2.get_nns_by_item(0, index_item_count)
+
+def test_serialize_after_load():
+    f = 32
+
+    index1 = AnnoyIndex(f, 'angular')
+
+    for iteration in range(1000):
+        vector = [random.gauss(0, 1) for _ in range(f)]
+        index1.add_item(iteration, vector)
+
+    index1.build(10)
+
+    save_path = "test/test.tree"
+    index1.save(save_path)
+
+    index2 = AnnoyIndex(f, 'angular')
+    index2.load(save_path)
+
+    assert index1.serialize() == index2.serialize()
+    assert index1.get_n_items() == index2.get_n_items()
+    assert index1.get_n_trees() == index2.get_n_trees()
+    assert index1.get_nns_by_item(0, index1.get_n_items()) == index2.get_nns_by_item(0, index1.get_n_items())
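
A minimal round-trip sketch of the API this diff adds, mirroring test/serialize_test.py above. It assumes the Python extension is built from this branch; the variable names and dimensions are illustrative only:

import random
from annoy import AnnoyIndex

f = 32
source = AnnoyIndex(f, 'angular')
for i in range(100):
    source.add_item(i, [random.gauss(0, 1) for _ in range(f)])
source.build(10)                      # serialize() requires a built index

blob = source.serialize()             # plain bytes, suitable for a cache or blob store

restored = AnnoyIndex(f, 'angular')   # create with the same f and metric; the blob does not store them
restored.deserialize(blob)
assert restored.get_n_items() == source.get_n_items()
assert restored.get_nns_by_item(0, 10) == source.get_nns_by_item(0, 10)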