From a0ef15f89a4cf91141a47a59e8ef773f63e04718 Mon Sep 17 00:00:00 2001 From: uzushino Date: Mon, 19 Feb 2024 22:37:20 +0900 Subject: [PATCH] add const generic --- examples/mnist.rs | 13 +++++++----- src/lib.rs | 53 ++++++++++++++++++++++++++--------------------- 2 files changed, 37 insertions(+), 29 deletions(-) diff --git a/examples/mnist.rs b/examples/mnist.rs index 33dd98b..01838c3 100644 --- a/examples/mnist.rs +++ b/examples/mnist.rs @@ -1,6 +1,7 @@ use mnist::{Mnist, MnistBuilder}; use rand::Rng; use rulinalg::matrix::{BaseMatrix, Matrix}; +use std::convert::TryInto; fn load_mnist( size: u32, @@ -34,7 +35,7 @@ fn main() { .test_set_length(tst_size) .finalize(); - let ann = voyager_rs::Voyager::new(28 * 28); + let ann = voyager_rs::Voyager::new(); let mut rng = rand::thread_rng(); for i in 0..trn_size { @@ -46,8 +47,9 @@ fn main() { .into_iter() .map(|v| v as f32) .collect::>(); - - ann.add_item(&img_to_vec, None); + + let v: [f32; 28*28] = img_to_vec.try_into().unwrap(); + ann.add_item(v, None); if i % 1_000 == 0 { println!("Add item {}/{}.", i, trn_size); @@ -72,7 +74,8 @@ fn main() { .map(|v| v as f32) .collect::>(); - let (result, _distance) = ann.query(&img_to_vec, 1, None); + let v: [f32; 28*28] = img_to_vec.try_into().unwrap(); + let (result, _distance) = ann.query(v, 1, None); let actual = result .into_iter() .map(|v| trn_lbl[v as usize]) @@ -86,4 +89,4 @@ fn main() { println!("{}\n{}", trn, tst); } } -} +} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 2583908..220fe5e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,5 @@ use std::ffi::CString; use std::path::Path; - use libc::{c_float, c_int, c_uint, c_void}; pub enum Index {} @@ -43,24 +42,30 @@ mod ffi { } } -pub struct Voyager(*mut Index); +pub struct Voyager { + ix: *mut Index +} + +impl Voyager { + pub fn new() -> Self { + let n = Voyager::::dimension(); + let index = unsafe { ffi::init_index(n as u32) }; -impl Voyager { - pub fn new(n: u32) -> Self { - let index = unsafe { ffi::init_index(n) }; - Voyager(index) + Voyager { ix: index } } - pub fn add_item(&self, w: &[f32], id: Option) { + pub const fn dimension() -> usize { N } + + pub fn add_item(&self, w: [f32; N], id: Option) { let len = w.len(); let is_some: c_int = id.is_some() as c_int; unsafe { - ffi::add_item(self.0, w.as_ptr(), len, is_some, id.unwrap_or(0)); + ffi::add_item(self.ix, w.as_ptr(), len, is_some, id.unwrap_or(0)); } } - pub fn query(&self, w: &[f32], k: i32, ef: Option) -> (Vec, Vec) { + pub fn query(&self, w: [f32; N], k: i32, ef: Option) -> (Vec, Vec) { let len = w.len(); let mut result = Vec::with_capacity(k as usize); @@ -71,7 +76,7 @@ impl Voyager { unsafe { ffi::query( - self.0, + self.ix, w.as_ptr(), len, result_ptr, @@ -87,26 +92,26 @@ impl Voyager { (a.to_vec(), b.to_vec()) } - pub fn get_distance(&self, w1: &[f32], w2: &[f32]) -> f32 { + pub fn get_distance(&self, w1: [f32; N], w2: [f32; N]) -> f32 { let len = w1.len(); - unsafe { ffi::get_distance(self.0, w1.as_ptr(), w2.as_ptr(), len) } + unsafe { ffi::get_distance(self.ix, w1.as_ptr(), w2.as_ptr(), len) } } pub fn save>(&self, path: P) { unsafe { if let Some(f) = path.as_ref().as_os_str().to_str() { let path_str_c = CString::new(f).unwrap(); - ffi::save_index(self.0, path_str_c.as_ptr() as *const c_void); + ffi::save_index(self.ix, path_str_c.as_ptr() as *const c_void); } } } } -impl Drop for Voyager { +impl Drop for Voyager { fn drop(&mut self) { unsafe { - ffi::dispose(self.0); + ffi::dispose(self.ix); } } } @@ -117,10 +122,10 @@ mod test { #[test] fn test_voyager() { - let v = Voyager::new(5); + let v = Voyager::new(); - let v1 = &[1.0, 2.0, 3.0, 4.0, 5.0]; - let v2 = &[6.0, 7.0, 8.0, 9.0, 10.0]; + let v1 = [1.0, 2.0, 3.0, 4.0, 5.0]; + let v2 = [6.0, 7.0, 8.0, 9.0, 10.0]; v.add_item(v1, Some(1)); v.add_item(v2, Some(2)); @@ -133,10 +138,10 @@ mod test { #[test] fn test_distance() { - let v = Voyager::new(5); + let v = Voyager::new(); - let v1 = &[1.0, 2.0, 3.0, 4.0, 5.0]; - let v2 = &[6.0, 7.0, 8.0, 9.0, 10.0]; + let v1 = [1.0, 2.0, 3.0, 4.0, 5.0]; + let v2 = [6.0, 7.0, 8.0, 9.0, 10.0]; let distance = v.get_distance(v1, v2); @@ -145,10 +150,10 @@ mod test { #[test] fn test_save() { - let v = Voyager::new(5); + let v = Voyager::new(); - let v1 = &[1.0, 2.0, 3.0, 4.0, 5.0]; - let v2 = &[6.0, 7.0, 8.0, 9.0, 10.0]; + let v1 = [1.0, 2.0, 3.0, 4.0, 5.0]; + let v2 = [6.0, 7.0, 8.0, 9.0, 10.0]; v.add_item(v1, Some(1)); v.add_item(v2, Some(2));