Skip to content

Commit

Permalink
add get_distance
Browse files Browse the repository at this point in the history
  • Loading branch information
uzushino committed Dec 8, 2023
1 parent 160af0f commit 853fe37
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 3 deletions.
7 changes: 7 additions & 0 deletions src/c/binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,11 @@ void query(
std::copy(distancesV.begin(), distancesV.end(), distances);

return ;
}

float get_distance(IndexVm* vm, float* a, float* b, size_t len) {
std::vector<float> v1(a, a + len);
std::vector<float> v2(b, b + len);

return vm->index_->GetDistance(v1, v2);
}
2 changes: 2 additions & 0 deletions src/c/binding.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ extern "C" {
void query(IndexVm* vm, float* item, size_t len, size_t *result, float *distances,
int k = 1, long queryEf = -1);

float get_distance(IndexVm* vm, float* a, float* b, size_t len);

#ifdef __cplusplus
}
#endif
4 changes: 4 additions & 0 deletions src/c/index.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ namespace voyager {
return index_->query(queryVectors, k, queryEf);
}

float GetDistance(std::vector<float> v1, std::vector<float> v2) {
return index_->getDistance(v1, v2);
}

private:
std::shared_ptr<::Index> index_;

Expand Down
28 changes: 25 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ mod ffi {
k: c_int,
query_ef: c_int,
);

pub fn get_distance(index: *mut Index, item1: *const c_float, item2: *const c_float, len: usize) -> c_float;
}
}

Expand All @@ -46,7 +48,7 @@ impl Voyager {
}
}

pub fn query(&self, w: &[f32], k: i32) -> (Vec<usize>, Vec<f32>) {
pub fn query(&self, w: &[f32], k: i32, ef: Option<i32>) -> (Vec<usize>, Vec<f32>) {
let len = w.len();

let mut result = Vec::with_capacity(k as usize);
Expand All @@ -56,14 +58,22 @@ impl Voyager {
let distance_ptr = distance.as_mut_ptr();

unsafe {
ffi::query(self.0, w.as_ptr(), len, result_ptr, distance_ptr, k, -1);
ffi::query(self.0, w.as_ptr(), len, result_ptr, distance_ptr, k, ef.unwrap_or(-1));
}

let a = unsafe { std::slice::from_raw_parts_mut(result_ptr, k as usize) };
let b = unsafe { std::slice::from_raw_parts_mut(distance_ptr, k as usize) };

(a.to_vec(), b.to_vec())
}

pub fn get_distance(&self, w1: &[f32], w2: &[f32]) -> f32 {
let len = w1.len();

unsafe {
ffi::get_distance(self.0, w1.as_ptr(), w2.as_ptr(), len)
}
}
}

impl Default for Voyager {
Expand Down Expand Up @@ -94,12 +104,24 @@ mod test {
v.add_item(v1, Some(1));
v.add_item(v2, Some(2));

let (result, distance) = v.query(v1, 2);
let (result, distance) = v.query(v1, 2, None);

assert!(result == vec![1, 2]);
assert!(distance == vec![0.0, 125.0]);
}

#[test]
fn test_distance() {
let v = Voyager::new();

let v1 = &[1.0, 2.0, 3.0, 4.0, 5.0];
let v2 = &[6.0, 7.0, 8.0, 9.0, 10.0];

let distance = v.get_distance(v1, v2);

assert!(distance == 125.0);
}

#[test]
fn test_runtime() {
unsafe {
Expand Down

0 comments on commit 853fe37

Please sign in to comment.