Skip to content

Commit

Permalink
add example
Browse files Browse the repository at this point in the history
  • Loading branch information
uzushino committed Jan 19, 2024
1 parent f1960c0 commit 63619a3
Show file tree
Hide file tree
Showing 11 changed files with 265 additions and 18 deletions.
149 changes: 149 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,9 @@ path = "src/lib.rs"
libc = "0.2.150"

[build-dependencies]
cc = "1.0"
cc = "1.0"

[dev-dependencies]
mnist = "0.5"
rulinalg = "0.3"
rand = "*"
9 changes: 8 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,15 @@ assert!(result == vec![1, 2]);
assert!(distance == vec![0.0, 125.0]);
```

### mnist

```
$ cargo run --example mnist
```

## Feature

- [x] add_item
- [x] get_distance
- [x] query
- [x] query
- [x] save_index
Binary file added data/t10k-images-idx3-ubyte
Binary file not shown.
Binary file added data/t10k-labels-idx1-ubyte
Binary file not shown.
Binary file added data/train-images-idx3-ubyte
Binary file not shown.
Binary file added data/train-labels-idx1-ubyte
Binary file not shown.
89 changes: 89 additions & 0 deletions examples/mnist.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
use mnist::{Mnist, MnistBuilder};
use rand::Rng;
use rulinalg::matrix::{BaseMatrix, Matrix};

fn load_mnist(
size: u32,
rows: u32,
cols: u32,
img: &Vec<u8>,
lbl: &Vec<u8>,
index: usize,
) -> (u8, Matrix<u8>) {
let img = Matrix::new((size * rows) as usize, cols as usize, img.clone());
let s = index * 28;
let e = s + 28;
let row_indexes = (s..e).collect::<Vec<_>>();
let img = img.select_rows(&row_indexes);

(lbl[index], img)
}

fn main() {
let (trn_size, tst_size, rows, cols) = (5_000, 5_000, 28, 28);

let Mnist {
trn_img,
trn_lbl,
tst_img,
tst_lbl,
..
} = MnistBuilder::new()
.label_format_digit()
.training_set_length(trn_size)
.test_set_length(tst_size)
.finalize();

let ann = voyager_rs::Voyager::new(28 * 28);
let mut rng = rand::thread_rng();

for i in 0..trn_size {
let (_, img) = load_mnist(trn_size, rows, cols, &trn_img, &trn_lbl, i as usize);

let img_to_vec = img
.data()
.clone()
.into_iter()
.map(|v| v as f32)
.collect::<Vec<_>>();

ann.add_item(&img_to_vec, None);

if i % 1_000 == 0 {
println!("Add item {}/{}.", i, trn_size);
}
}

for i in 0..10 {
let ti: u32 = rng.gen();
let (lbl, img) = load_mnist(
trn_size,
rows,
cols,
&tst_img,
&tst_lbl,
(ti % tst_size) as usize,
);

let img_to_vec = img
.data()
.clone()
.into_iter()
.map(|v| v as f32)
.collect::<Vec<_>>();

let (result, _distance) = ann.query(&img_to_vec, 1, None);
let actual = result
.into_iter()
.map(|v| trn_lbl[v as usize])
.collect::<Vec<_>>();

println!("TEST{}: expected: {}, actual: {:?}", i, lbl, actual);
if actual[0] != lbl {
let (_, trn) = load_mnist(10_000, 28, 28, &trn_img, &trn_lbl, lbl as usize);
let (_, tst) = load_mnist(10_000, 28, 28, &tst_img, &tst_lbl, actual[0] as usize);

println!("{}\n{}", trn, tst);
}
}
}
4 changes: 2 additions & 2 deletions src/c/binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
#include <TypedIndex.h>
#include "binding.hpp"

IndexVm* init_index() {
IndexVm* init_index(size_t num_dimensions) {
IndexVm* vm = new IndexVm {};
vm->index_ = std::make_shared<voyager::Index>(SpaceType::Euclidean, 5);
vm->index_ = std::make_shared<voyager::Index>(SpaceType::Euclidean, num_dimensions);
return vm;
}

Expand Down
2 changes: 1 addition & 1 deletion src/c/binding.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ extern "C" {
std::shared_ptr<voyager::Index> index_;
} IndexVm;

IndexVm* init_index();
IndexVm* init_index(size_t num_dimensions);

void add_item(IndexVm* vm, float* item, size_t len, size_t is_some, size_t _id);

Expand Down
23 changes: 10 additions & 13 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@ mod ffi {

#[link(name = "binding", kind = "static")]
extern "C" {
pub fn init_index() -> *mut Index;
pub fn init_index(num_dimensions: c_uint) -> *mut Index;

pub fn add_item(
index: *mut Index,
item: *const c_float,
len: usize,
is_some: c_int,
id: c_uint,
);

pub fn dispose(index: *mut Index);

pub fn query(
Expand Down Expand Up @@ -47,8 +49,8 @@ mod ffi {
pub struct Voyager(*mut Index);

impl Voyager {
pub fn new() -> Self {
let index = unsafe { ffi::init_index() };
pub fn new(num_dimensions: u32) -> Self {
let index = unsafe { ffi::init_index(num_dimensions) };
Voyager(index)
}

Expand Down Expand Up @@ -104,12 +106,6 @@ impl Voyager {
}
}

impl Default for Voyager {
fn default() -> Self {
Self::new()
}
}

impl Drop for Voyager {
fn drop(&mut self) {
unsafe {
Expand All @@ -124,7 +120,7 @@ mod test {

#[test]
fn test_voyager() {
let v = Voyager::new();
let v = Voyager::new(5);

let v1 = &[1.0, 2.0, 3.0, 4.0, 5.0];
let v2 = &[6.0, 7.0, 8.0, 9.0, 10.0];
Expand All @@ -140,7 +136,7 @@ mod test {

#[test]
fn test_distance() {
let v = Voyager::new();
let v = Voyager::new(5);

let v1 = &[1.0, 2.0, 3.0, 4.0, 5.0];
let v2 = &[6.0, 7.0, 8.0, 9.0, 10.0];
Expand All @@ -152,7 +148,7 @@ mod test {

#[test]
fn test_save() {
let v = Voyager::new();
let v = Voyager::new(5);

let v1 = &[1.0, 2.0, 3.0, 4.0, 5.0];
let v2 = &[6.0, 7.0, 8.0, 9.0, 10.0];
Expand All @@ -167,7 +163,8 @@ mod test {
#[test]
fn test_runtime() {
unsafe {
let index = ffi::init_index();
let index = ffi::init_index(5);

let v1 = &[1.0, 2.0, 3.0, 4.0, 5.0];
let v2 = &[6.0, 7.0, 8.0, 9.0, 10.0];

Expand Down

0 comments on commit 63619a3

Please sign in to comment.