Skip to content

Commit

Permalink
Merge pull request #78 from benjamin-james/master
Browse files Browse the repository at this point in the history
Added dot product from Annoy c++ source
  • Loading branch information
eddelbuettel authored Aug 4, 2024
2 parents b33c20b + ae3fd70 commit 5c10b71
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 0 deletions.
8 changes: 8 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
2024-08-03 Benjamin James <[email protected]>

* R/annoy.R: Add AnnoyDotProduct to namespace
* inst/tinytest/testDotProduct.R: Unit tests for 'AnnoyDotProduct'

* src/init.c: Added new dot product distance measure (via template)
* src/annoy.cpp: Added template and module for AnnoyDotProduct

2024-05-20 Dirk Eddelbuettel <[email protected]>

* README.md: Use tinyverse.netlify.app for dependency badge
Expand Down
1 change: 1 addition & 0 deletions R/annoy.R
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,4 @@ loadModule("AnnoyAngular", TRUE)
loadModule("AnnoyEuclidean", TRUE)
loadModule("AnnoyManhattan", TRUE)
loadModule("AnnoyHamming", TRUE)
loadModule("AnnoyDotProduct", TRUE)
76 changes: 76 additions & 0 deletions inst/tinytest/testDotProduct.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@

suppressMessages(library(RcppAnnoy))

f <- 3
a <- new(AnnoyDotProduct, f)
a$addItem(0, c(0,0,1))
a$addItem(1, c(0,1,0))
a$addItem(2, c(1,0,0))
a$build(10)
checkEqual(a$getNNsByVector(c(3,2,1), 3), c(2,1,0), msg="getNNsByVector check 1")
checkEqual(a$getNNsByVector(c(1,2,3), 3), c(0,1,2), msg="getNNsByVector check 1")
checkEqual(a$getNNsByVector(c(2,0,1), 3), c(2,0,1), msg="getNNsByVector check 1")


f <- 3
a <- new(AnnoyDotProduct, f)
a$addItem(0, c(2,1,0))
a$addItem(1, c(1,2,0))
a$addItem(2, c(0,0,1))
a$build(10)
checkEqual(a$getNNsByItem(0, 3), c(0,1,2), msg="getNNsByItem check1")
checkEqual(a$getNNsByItem(1, 3), c(1,0,2), msg="getNNsByItem check2")


f <- 2
a <- new(AnnoyDotProduct, f)
a$addItem(0, c(0, 1))
a$addItem(1, c(1, 1))
checkEqual(a$getDistance(0, 1), 0 * 1 + 1 * 1, msg="distance 1", tolerance=1e-6)


f <- 2
a <- new(AnnoyDotProduct, f)
a$addItem(0, c(1000, 0))
a$addItem(1, c(10, 0))
checkEqual(a$getDistance(0, 1), 1000 * 10 + 0 * 0, msg="distance 2", tolerance=1e-6)


f <- 2
a <- new(AnnoyDotProduct, f)
a$addItem(0, c(97, 0))
a$addItem(1, c(42, 42))
d <- 97 * 42 + 0 * 42
checkEqual(a$getDistance(0, 1), d, msg="distance 3", tolerance=1.0e-6)


f <- 2
a <- new(AnnoyDotProduct, f)
a$addItem(0, c(1, 0))
a$addItem(1, c(0, 0))
checkEqual(a$getDistance(0, 1), 0, msg="distance 4", tolerance=1.0e-6)


## Generate pairs of random points where the pair is super close
f <- 10
a <- new(AnnoyDotProduct, f)
set.seed(123)
for (j in seq(0, 10000, by=2)) {
p <- rnorm(f)
f1 <- runif(1) + 1
f2 <- runif(1) + 1
x <- f1 * p + rnorm(f, 0, 1.0e-2)
y <- f2 * p + rnorm(f, 0, 1.0e-2)
a$addItem(j, x / norm(x, "2"))
a$addItem(j+1, y/norm(y, "2"))
}
a$build(10)
res <- TRUE
for (j in seq(0, 10000, by=2)) {
#expect_equal(a$getNNsByItem(j, 2), c(j, j+1), msg="getNNsByItem check1")
#expect_equal(a$getNNsByItem(j+1, 2), c(j+1, j), msg="getNNsByItem check1")
res <- res &&
all.equal(a$getNNsByItem(j, 2), c(j, j+1)) &&
all.equal(a$getNNsByItem(j+1, 2), c(j+1, j))
}
checkTrue(res)
31 changes: 31 additions & 0 deletions src/annoy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,42 @@ class Annoy

}

typedef Annoy::Annoy<int32_t, float, Annoy::DotProduct,Kiss64Random, RcppAnnoyIndexThreadPolicy> AnnoyDotProduct;
typedef Annoy::Annoy<int32_t, float, Annoy::Angular, Kiss64Random, RcppAnnoyIndexThreadPolicy> AnnoyAngular;
typedef Annoy::Annoy<int32_t, float, Annoy::Euclidean, Kiss64Random, RcppAnnoyIndexThreadPolicy> AnnoyEuclidean;
typedef Annoy::Annoy<int32_t, float, Annoy::Manhattan, Kiss64Random, RcppAnnoyIndexThreadPolicy> AnnoyManhattan;
typedef Annoy::Annoy<int32_t, uint64_t, Annoy::Hamming, Kiss64Random, RcppAnnoyIndexThreadPolicy> AnnoyHamming;

RCPP_EXPOSED_CLASS_NODECL(AnnoyDotProduct)
RCPP_MODULE(AnnoyDotProduct) {
Rcpp::class_<AnnoyDotProduct>("AnnoyDotProduct")

.constructor<int32_t>("constructor with integer count")

.method("addItem", &AnnoyDotProduct::addItem, "add item")
.method("build", &AnnoyDotProduct::callBuild, "build an index")
.method("unbuild", &AnnoyDotProduct::callUnbuild, "unbuild an index")
.method("save", &AnnoyDotProduct::callSave, "save index to file")
.method("load", &AnnoyDotProduct::callLoad, "load index from file")
.method("unload", &AnnoyDotProduct::callUnload, "unload index")
.method("getDistance", &AnnoyDotProduct::getDistance, "get distance between i and j")
.method("getNNsByItem", &AnnoyDotProduct::getNNsByItem,
"retrieve Nearest Neigbours given item")
.method("getNNsByItemList", &AnnoyDotProduct::getNNsByItemList,
"retrieve Nearest Neigbours given item")
.method("getNNsByVector", &AnnoyDotProduct::getNNsByVector,
"retrieve Nearest Neigbours given vector")
.method("getNNsByVectorList", &AnnoyDotProduct::getNNsByVectorList,
"retrieve Nearest Neigbours given vector")
.method("getItemsVector", &AnnoyDotProduct::getItemsVector, "retrieve item vector")
.method("getNItems", &AnnoyDotProduct::getNItems, "get number of items")
.method("getNTrees", &AnnoyDotProduct::getNTrees, "get number of trees")
.method("setVerbose", &AnnoyDotProduct::verbose, "set verbose")
.method("setSeed", &AnnoyDotProduct::setSeed, "set seed")
.method("onDiskBuild", &AnnoyDotProduct::onDiskBuild, "build in given file")
;
}

RCPP_EXPOSED_CLASS_NODECL(AnnoyAngular)
RCPP_MODULE(AnnoyAngular) {
Rcpp::class_<AnnoyAngular>("AnnoyAngular")
Expand Down
2 changes: 2 additions & 0 deletions src/init.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
*/

/* .Call calls */
extern SEXP _rcpp_module_boot_AnnoyDotProduct(void);
extern SEXP _rcpp_module_boot_AnnoyAngular(void);
extern SEXP _rcpp_module_boot_AnnoyEuclidean(void);
extern SEXP _rcpp_module_boot_AnnoyManhattan(void);
Expand All @@ -16,6 +17,7 @@ extern SEXP _RcppAnnoy_getArchictectureStatus(void);
extern SEXP _RcppAnnoy_annoy_version(void);

static const R_CallMethodDef CallEntries[] = {
{"_rcpp_module_boot_AnnoyDotProduct",(DL_FUNC) &_rcpp_module_boot_AnnoyDotProduct,0},
{"_rcpp_module_boot_AnnoyAngular", (DL_FUNC) &_rcpp_module_boot_AnnoyAngular, 0},
{"_rcpp_module_boot_AnnoyEuclidean", (DL_FUNC) &_rcpp_module_boot_AnnoyEuclidean, 0},
{"_rcpp_module_boot_AnnoyManhattan", (DL_FUNC) &_rcpp_module_boot_AnnoyManhattan, 0},
Expand Down

0 comments on commit 5c10b71

Please sign in to comment.