Skip to content

Commit

Permalink
add const generic
Browse files Browse the repository at this point in the history
  • Loading branch information
uzushino committed Feb 19, 2024
1 parent e5d51ff commit a0ef15f
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 29 deletions.
13 changes: 8 additions & 5 deletions examples/mnist.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use mnist::{Mnist, MnistBuilder};
use rand::Rng;
use rulinalg::matrix::{BaseMatrix, Matrix};
use std::convert::TryInto;

fn load_mnist(
size: u32,
Expand Down Expand Up @@ -34,7 +35,7 @@ fn main() {
.test_set_length(tst_size)
.finalize();

let ann = voyager_rs::Voyager::new(28 * 28);
let ann = voyager_rs::Voyager::new();
let mut rng = rand::thread_rng();

for i in 0..trn_size {
Expand All @@ -46,8 +47,9 @@ fn main() {
.into_iter()
.map(|v| v as f32)
.collect::<Vec<_>>();

ann.add_item(&img_to_vec, None);

let v: [f32; 28*28] = img_to_vec.try_into().unwrap();
ann.add_item(v, None);

if i % 1_000 == 0 {
println!("Add item {}/{}.", i, trn_size);
Expand All @@ -72,7 +74,8 @@ fn main() {
.map(|v| v as f32)
.collect::<Vec<_>>();

let (result, _distance) = ann.query(&img_to_vec, 1, None);
let v: [f32; 28*28] = img_to_vec.try_into().unwrap();
let (result, _distance) = ann.query(v, 1, None);
let actual = result
.into_iter()
.map(|v| trn_lbl[v as usize])
Expand All @@ -86,4 +89,4 @@ fn main() {
println!("{}\n{}", trn, tst);
}
}
}
}
53 changes: 29 additions & 24 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::ffi::CString;
use std::path::Path;

use libc::{c_float, c_int, c_uint, c_void};

pub enum Index {}
Expand Down Expand Up @@ -43,24 +42,30 @@ mod ffi {
}
}

pub struct Voyager(*mut Index);
pub struct Voyager<const N : usize> {
ix: *mut Index
}

impl<const N: usize> Voyager<N> {
pub fn new() -> Self {
let n = Voyager::<N>::dimension();
let index = unsafe { ffi::init_index(n as u32) };

impl Voyager {
pub fn new(n: u32) -> Self {
let index = unsafe { ffi::init_index(n) };
Voyager(index)
Voyager { ix: index }
}

pub fn add_item(&self, w: &[f32], id: Option<u32>) {
pub const fn dimension() -> usize { N }

pub fn add_item(&self, w: [f32; N], id: Option<u32>) {
let len = w.len();
let is_some: c_int = id.is_some() as c_int;

unsafe {
ffi::add_item(self.0, w.as_ptr(), len, is_some, id.unwrap_or(0));
ffi::add_item(self.ix, w.as_ptr(), len, is_some, id.unwrap_or(0));
}
}

pub fn query(&self, w: &[f32], k: i32, ef: Option<i32>) -> (Vec<usize>, Vec<f32>) {
pub fn query(&self, w: [f32; N], k: i32, ef: Option<i32>) -> (Vec<usize>, Vec<f32>) {
let len = w.len();

let mut result = Vec::with_capacity(k as usize);
Expand All @@ -71,7 +76,7 @@ impl Voyager {

unsafe {
ffi::query(
self.0,
self.ix,
w.as_ptr(),
len,
result_ptr,
Expand All @@ -87,26 +92,26 @@ impl Voyager {
(a.to_vec(), b.to_vec())
}

pub fn get_distance(&self, w1: &[f32], w2: &[f32]) -> f32 {
pub fn get_distance(&self, w1: [f32; N], w2: [f32; N]) -> f32 {
let len = w1.len();

unsafe { ffi::get_distance(self.0, w1.as_ptr(), w2.as_ptr(), len) }
unsafe { ffi::get_distance(self.ix, w1.as_ptr(), w2.as_ptr(), len) }
}

pub fn save<P: AsRef<Path>>(&self, path: P) {
unsafe {
if let Some(f) = path.as_ref().as_os_str().to_str() {
let path_str_c = CString::new(f).unwrap();
ffi::save_index(self.0, path_str_c.as_ptr() as *const c_void);
ffi::save_index(self.ix, path_str_c.as_ptr() as *const c_void);
}
}
}
}

impl Drop for Voyager {
impl<const N: usize> Drop for Voyager<N> {
fn drop(&mut self) {
unsafe {
ffi::dispose(self.0);
ffi::dispose(self.ix);
}
}
}
Expand All @@ -117,10 +122,10 @@ mod test {

#[test]
fn test_voyager() {
let v = Voyager::new(5);
let v = Voyager::new();

let v1 = &[1.0, 2.0, 3.0, 4.0, 5.0];
let v2 = &[6.0, 7.0, 8.0, 9.0, 10.0];
let v1 = [1.0, 2.0, 3.0, 4.0, 5.0];
let v2 = [6.0, 7.0, 8.0, 9.0, 10.0];

v.add_item(v1, Some(1));
v.add_item(v2, Some(2));
Expand All @@ -133,10 +138,10 @@ mod test {

#[test]
fn test_distance() {
let v = Voyager::new(5);
let v = Voyager::new();

let v1 = &[1.0, 2.0, 3.0, 4.0, 5.0];
let v2 = &[6.0, 7.0, 8.0, 9.0, 10.0];
let v1 = [1.0, 2.0, 3.0, 4.0, 5.0];
let v2 = [6.0, 7.0, 8.0, 9.0, 10.0];

let distance = v.get_distance(v1, v2);

Expand All @@ -145,10 +150,10 @@ mod test {

#[test]
fn test_save() {
let v = Voyager::new(5);
let v = Voyager::new();

let v1 = &[1.0, 2.0, 3.0, 4.0, 5.0];
let v2 = &[6.0, 7.0, 8.0, 9.0, 10.0];
let v1 = [1.0, 2.0, 3.0, 4.0, 5.0];
let v2 = [6.0, 7.0, 8.0, 9.0, 10.0];

v.add_item(v1, Some(1));
v.add_item(v2, Some(2));
Expand Down

0 comments on commit a0ef15f

Please sign in to comment.