Skip to content

Commit

Permalink
add query
Browse files Browse the repository at this point in the history
  • Loading branch information
uzushino committed Dec 8, 2023
1 parent 77dd212 commit 2e5bf5e
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 15 deletions.
7 changes: 5 additions & 2 deletions src/c/binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@ IndexVm* init_index() {
return vm;
}

void add_item(IndexVm* vm, float* vec_, size_t len, size_t _id) {
void add_item(IndexVm* vm, float* vec_, size_t len, size_t is_some, size_t _id) {
std::vector<float> v(vec_, vec_ + len);
vm->index_->AddItem(v, _id);
std::optional<size_t> id = is_some ? std::optional<size_t>(_id) : std::nullopt;

vm->index_->AddItem(v, id);

return ;
}

Expand Down
2 changes: 1 addition & 1 deletion src/c/binding.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ extern "C" {

IndexVm* init_index();

void add_item(IndexVm* vm, float* item, size_t len, size_t _id);
void add_item(IndexVm* vm, float* item, size_t len, size_t is_some, size_t _id);

void dispose(IndexVm* vm);

Expand Down
62 changes: 51 additions & 11 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,15 @@ mod ffi {
#[link(name = "binding", kind = "static")]
extern "C" {
pub fn init_index() -> *mut Index;
pub fn add_item(index: *mut Index, item: *const c_float, len: usize, size: c_uint);
pub fn add_item(
index: *mut Index,
item: *const c_float,
len: usize,
is_some: c_int,
id: c_uint,
);
pub fn dispose(index: *mut Index);

#[allow(clippy::all, dead_code)]
pub fn query(
index: *mut Index,
input: *const c_float,
Expand All @@ -24,28 +29,47 @@ mod ffi {
}
}

pub struct Voyager(usize, *mut Index);
pub struct Voyager(*mut Index);

impl Voyager {
pub fn new(n: usize) -> Self {
pub fn new() -> Self {
let index = unsafe { ffi::init_index() };
Voyager(n, index)
Voyager(index)
}

pub fn add_item(&self, w: &[f32], id: Option<u32>) {
let len = w.len();
let is_some: c_int = id.is_some() as c_int;

unsafe {
ffi::add_item(self.0, w.as_ptr(), len, is_some, id.unwrap_or(0));
}
}

pub fn add_item(&self, w: &[f32]) {
pub fn query(&self, w: &[f32], k: i32) -> (Vec<usize>, Vec<f32>) {
let len = w.len();
let size = self.0;

let mut result = Vec::with_capacity(k as usize);
let result_ptr = result.as_mut_ptr();

let mut distance = Vec::with_capacity(k as usize);
let distance_ptr = distance.as_mut_ptr();

unsafe {
ffi::add_item(self.1, w.as_ptr(), len, size as c_uint);
ffi::query(self.0, w.as_ptr(), len, result_ptr, distance_ptr, k, -1);
}

let a = unsafe { std::slice::from_raw_parts_mut(result_ptr, k as usize) };
let b = unsafe { std::slice::from_raw_parts_mut(distance_ptr, k as usize) };

(a.to_vec(), b.to_vec())
}
}

impl Drop for Voyager {
fn drop(&mut self) {
unsafe {
ffi::dispose(self.1);
ffi::dispose(self.0);
}
}
}
Expand All @@ -54,6 +78,22 @@ impl Drop for Voyager {
mod test {
use super::*;

#[test]
fn test_voyager() {
let v = Voyager::new();

let v1 = &[1.0, 2.0, 3.0, 4.0, 5.0];
let v2 = &[6.0, 7.0, 8.0, 9.0, 10.0];

v.add_item(v1, Some(1));
v.add_item(v2, Some(2));

let (result, distance) = v.query(v1, 2);

assert!(result == vec![1, 2]);
assert!(distance == vec![0.0, 125.0]);
}

#[test]
fn test_runtime() {
unsafe {
Expand All @@ -67,8 +107,8 @@ mod test {
let mut distance = Vec::with_capacity(2);
let distance_ptr = distance.as_mut_ptr();

ffi::add_item(index, v1.as_ptr(), v1.len(), 0);
ffi::add_item(index, v2.as_ptr(), v2.len(), 1);
ffi::add_item(index, v1.as_ptr(), v1.len(), 1, 0);
ffi::add_item(index, v2.as_ptr(), v2.len(), 1, 1);

ffi::query(
index,
Expand Down

0 comments on commit 2e5bf5e

Please sign in to comment.