From 0e65e07ac4b22188b2fbd2900a052c016a1905f1 Mon Sep 17 00:00:00 2001 From: Chris Petersen Date: Sat, 23 Mar 2024 14:16:22 -0700 Subject: [PATCH 1/8] Embeddings have the correct dimensionality/rank now. Still working on them --- README.md | 7 +++++++ ext/candle/src/model/config.rs | 31 +++++++++++++++++++++++++++---- ext/candle/src/model/rb_tensor.rs | 1 - 3 files changed, 34 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 5b97a03..e0b495e 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,12 @@ x = x.reshape([3, 2]) # Tensor[[3, 2], f32] ``` +```ruby +require 'candle' +model = Candle::Model.new +model.embedding("Hi there!") +``` + ## Development FORK IT! @@ -29,6 +35,7 @@ bundle bundle exec rake compile ``` + Implemented with [Magnus](https://github.com/matsadler/magnus), with reference to [Polars Ruby](https://github.com/ankane/polars-ruby) Policies diff --git a/ext/candle/src/model/config.rs b/ext/candle/src/model/config.rs index a1a3fe7..8e4fa25 100644 --- a/ext/candle/src/model/config.rs +++ b/ext/candle/src/model/config.rs @@ -113,20 +113,43 @@ impl ModelConfig { .map_err(wrap_std_err)? .get_ids() .to_vec(); + println!("TOKENS {:#?}", tokens); let token_ids = Tensor::new(&tokens[..], &self.0.device) .map_err(wrap_candle_err)? .unsqueeze(0) .map_err(wrap_candle_err)?; - println!("Loaded and encoded {:?}", start.elapsed()); + + // let token_ids = Tensor::stack(&token_ids, 0)? + // .map_err(wrap_candle_err)?; + println!("TOKEN IDS {:#?}", token_ids); let start: std::time::Instant = std::time::Instant::now(); let result = model.forward(&token_ids).map_err(wrap_candle_err)?; - // println!("{result}"); println!("Took {:?}", start.elapsed()); - Ok(result) + + // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding) + let (_n_sentence, n_tokens, _hidden_size) = result.dims3() + .map_err(wrap_candle_err)?; + let sum = result.sum(1) + .map_err(wrap_candle_err)?; + let embeddings = (sum / (n_tokens as f64)) + .map_err(wrap_candle_err)?; + println!("EMBEDDINGS {:#?}", embeddings); + let embeddings = Self::normalize_l2(&embeddings).map_err(wrap_candle_err)?; + // let embeddings = if args.normalize_embeddings { + // + // } else { + // embeddings + // }; + + Ok(embeddings) + } + + fn normalize_l2(v: &Tensor) -> Result { + v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?) } pub fn __repr__(&self) -> String { - format!("Candle::Model(path={})", self.0.model_path.as_deref().unwrap_or("None")) + format!("# String { diff --git a/ext/candle/src/model/rb_tensor.rs b/ext/candle/src/model/rb_tensor.rs index 8456cd2..0121ea2 100644 --- a/ext/candle/src/model/rb_tensor.rs +++ b/ext/candle/src/model/rb_tensor.rs @@ -40,7 +40,6 @@ impl RbTensor { )) } - // FXIME: Do not use `to_f64` here. 
pub fn values(&self) -> RbResult> { let values = self .0 From bad5ef5f16c53b1266a2d6f526e0366aa23bcb08 Mon Sep 17 00:00:00 2001 From: Chris Petersen Date: Sat, 23 Mar 2024 14:43:52 -0700 Subject: [PATCH 2/8] Figured out how to get the tokenizer to pad the inputs correctly --- ext/candle/src/model/config.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/ext/candle/src/model/config.rs b/ext/candle/src/model/config.rs index 8e4fa25..1141a39 100644 --- a/ext/candle/src/model/config.rs +++ b/ext/candle/src/model/config.rs @@ -92,10 +92,14 @@ impl ModelConfig { )) .get("tokenizer.json") .map_err(wrap_hf_err)?; - let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path) - // .with_padding(None) - // .with_truncation(None) + let mut tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path) .map_err(wrap_std_err)?; + let pp = tokenizers::PaddingParams { + strategy: tokenizers::PaddingStrategy::BatchLongest, + ..Default::default() + }; + tokenizer.with_padding(Some(pp)); + Ok(tokenizer) } From b0fc1e80bc5f136719c853b6f477038cd81a8976 Mon Sep 17 00:00:00 2001 From: Chris Petersen Date: Sat, 23 Mar 2024 14:46:42 -0700 Subject: [PATCH 3/8] =?UTF-8?q?Removed=20the=20normalization=20and=20now?= =?UTF-8?q?=20it=20matches=20the=20python=20method.=20=F0=9F=8E=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ext/candle/src/model/config.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/candle/src/model/config.rs b/ext/candle/src/model/config.rs index 1141a39..77dc37b 100644 --- a/ext/candle/src/model/config.rs +++ b/ext/candle/src/model/config.rs @@ -138,7 +138,7 @@ impl ModelConfig { let embeddings = (sum / (n_tokens as f64)) .map_err(wrap_candle_err)?; println!("EMBEDDINGS {:#?}", embeddings); - let embeddings = Self::normalize_l2(&embeddings).map_err(wrap_candle_err)?; + // let embeddings = Self::normalize_l2(&embeddings).map_err(wrap_candle_err)?; // let embeddings = if args.normalize_embeddings { // // } else { From b237a3af423d66137c47f424cb4aa1afa1a2d7be Mon Sep 17 00:00:00 2001 From: Chris Petersen Date: Sat, 23 Mar 2024 14:57:45 -0700 Subject: [PATCH 4/8] Refactored ModelConfig into RbModel and ignored some dead_code warnings --- ext/candle/src/lib.rs | 10 +++--- ext/candle/src/model/mod.rs | 4 +-- .../src/model/{config.rs => rb_model.rs} | 31 ++++++------------- ext/candle/src/model/utils.rs | 2 ++ 4 files changed, 18 insertions(+), 29 deletions(-) rename ext/candle/src/model/{config.rs => rb_model.rs} (84%) diff --git a/ext/candle/src/lib.rs b/ext/candle/src/lib.rs index 33c5a5b..08c172f 100644 --- a/ext/candle/src/lib.rs +++ b/ext/candle/src/lib.rs @@ -1,6 +1,6 @@ use magnus::{function, method, prelude::*, Ruby}; -use crate::model::{candle_utils, ModelConfig, RbDType, RbDevice, RbQTensor, RbResult, RbTensor}; +use crate::model::{candle_utils, RbModel, RbDType, RbDevice, RbQTensor, RbResult, RbTensor}; pub mod model; @@ -93,10 +93,10 @@ fn init(ruby: &Ruby) -> RbResult<()> { rb_qtensor.define_method("dequantize", method!(RbQTensor::dequantize, 0))?; let rb_model = rb_candle.define_class("Model", Ruby::class_object(ruby))?; - rb_model.define_singleton_method("new", function!(ModelConfig::new, 0))?; - rb_model.define_method("embedding", method!(ModelConfig::embedding, 1))?; - rb_model.define_method("to_s", method!(ModelConfig::__str__, 0))?; - rb_model.define_method("inspect", method!(ModelConfig::__repr__, 0))?; + rb_model.define_singleton_method("new", function!(RbModel::new, 0))?; + 
rb_model.define_method("embedding", method!(RbModel::embedding, 1))?; + rb_model.define_method("to_s", method!(RbModel::__str__, 0))?; + rb_model.define_method("inspect", method!(RbModel::__repr__, 0))?; Ok(()) } diff --git a/ext/candle/src/model/mod.rs b/ext/candle/src/model/mod.rs index 40ef61d..f0ec326 100644 --- a/ext/candle/src/model/mod.rs +++ b/ext/candle/src/model/mod.rs @@ -1,5 +1,5 @@ -mod config; -pub use config::*; +mod rb_model; +pub use rb_model::*; mod errors; pub use errors::*; diff --git a/ext/candle/src/model/config.rs b/ext/candle/src/model/rb_model.rs similarity index 84% rename from ext/candle/src/model/config.rs rename to ext/candle/src/model/rb_model.rs index 77dc37b..b6ff1c0 100644 --- a/ext/candle/src/model/config.rs +++ b/ext/candle/src/model/rb_model.rs @@ -16,9 +16,9 @@ use crate::model::RbResult; use tokenizers::Tokenizer; #[magnus::wrap(class = "Candle::Model", free_immediately, size)] -pub struct ModelConfig(pub ModelConfigInner); +pub struct RbModel(pub RbModelInner); -pub struct ModelConfigInner { +pub struct RbModelInner { device: Device, tokenizer_path: Option, model_path: Option, @@ -26,14 +26,14 @@ pub struct ModelConfigInner { tokenizer: Option, } -impl ModelConfig { +impl RbModel { pub fn new() -> RbResult { Self::new2(Some("jinaai/jina-embeddings-v2-base-en".to_string()), Some("sentence-transformers/all-MiniLM-L6-v2".to_string()), Some(Device::Cpu)) } pub fn new2(model_path: Option, tokenizer_path: Option, device: Option) -> RbResult { let device = device.unwrap_or(Device::Cpu); - Ok(ModelConfig(ModelConfigInner { + Ok(RbModel(RbModelInner { device: device.clone(), model_path: model_path.clone(), tokenizer_path: tokenizer_path.clone(), @@ -109,26 +109,18 @@ impl ModelConfig { model: &BertModel, tokenizer: &Tokenizer, ) -> Result { - let start: std::time::Instant = std::time::Instant::now(); - // let tokenizer_impl = tokenizer - // .map_err(wrap_std_err)?; let tokens = tokenizer .encode(prompt, true) .map_err(wrap_std_err)? .get_ids() .to_vec(); - println!("TOKENS {:#?}", tokens); let token_ids = Tensor::new(&tokens[..], &self.0.device) .map_err(wrap_candle_err)? .unsqueeze(0) .map_err(wrap_candle_err)?; - // let token_ids = Tensor::stack(&token_ids, 0)? - // .map_err(wrap_candle_err)?; - println!("TOKEN IDS {:#?}", token_ids); - let start: std::time::Instant = std::time::Instant::now(); + // let start: std::time::Instant = std::time::Instant::now(); let result = model.forward(&token_ids).map_err(wrap_candle_err)?; - println!("Took {:?}", start.elapsed()); // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding) let (_n_sentence, n_tokens, _hidden_size) = result.dims3() @@ -137,23 +129,18 @@ impl ModelConfig { .map_err(wrap_candle_err)?; let embeddings = (sum / (n_tokens as f64)) .map_err(wrap_candle_err)?; - println!("EMBEDDINGS {:#?}", embeddings); // let embeddings = Self::normalize_l2(&embeddings).map_err(wrap_candle_err)?; - // let embeddings = if args.normalize_embeddings { - // - // } else { - // embeddings - // }; Ok(embeddings) } + #[allow(dead_code)] fn normalize_l2(v: &Tensor) -> Result { v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?) 
} pub fn __repr__(&self) -> String { - format!("# String { @@ -170,14 +157,14 @@ impl ModelConfig { // #[test] // fn test_build_model_and_tokenizer() { -// let config = super::ModelConfig::build(); +// let config = super::RbModel::build(); // let (_model, tokenizer) = config.build_model_and_tokenizer().unwrap(); // assert_eq!(tokenizer.get_vocab_size(true), 30522); // } // #[test] // fn test_embedding() { -// let config = super::ModelConfig::build(); +// let config = super::RbModel::build(); // // let (_model, tokenizer) = config.build_model_and_tokenizer().unwrap(); // // assert_eq!(config.embedding("Scientist.com is a marketplace for pharmaceutical services.")?, None); // } diff --git a/ext/candle/src/model/utils.rs b/ext/candle/src/model/utils.rs index dc065c6..81ea0f8 100644 --- a/ext/candle/src/model/utils.rs +++ b/ext/candle/src/model/utils.rs @@ -71,6 +71,7 @@ pub fn candle_utils(rb_candle: magnus::RModule) -> Result<(), Error> { /// Applies the Softmax function to a given tensor.# /// &RETURNS&: Tensor +#[allow(dead_code)] fn softmax(tensor: RbTensor, dim: i64) -> RbResult { let dim = actual_dim(&tensor, dim).map_err(wrap_candle_err)?; let sm = candle_nn::ops::softmax(&tensor.0, dim).map_err(wrap_candle_err)?; @@ -79,6 +80,7 @@ fn softmax(tensor: RbTensor, dim: i64) -> RbResult { /// Applies the Sigmoid Linear Unit (SiLU) function to a given tensor. /// &RETURNS&: Tensor +#[allow(dead_code)] fn silu(tensor: RbTensor) -> RbResult { let s = candle_nn::ops::silu(&tensor.0).map_err(wrap_candle_err)?; Ok(RbTensor(s)) From a7d47340f76d89c270fbc0bed448cfbbdb2ff73a Mon Sep 17 00:00:00 2001 From: Chris Petersen Date: Sat, 23 Mar 2024 16:01:43 -0700 Subject: [PATCH 5/8] Added elem_count and adding a thin ruby wrapper around tensor --- README.md | 2 +- ext/candle/src/lib.rs | 1 + ext/candle/src/model/rb_tensor.rs | 6 ++++++ lib/candle.rb | 1 + lib/candle/tensor.rb | 4 ++++ 5 files changed, 13 insertions(+), 1 deletion(-) create mode 100644 lib/candle/tensor.rb diff --git a/README.md b/README.md index e0b495e..3cf25c5 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ x = x.reshape([3, 2]) ```ruby require 'candle' model = Candle::Model.new -model.embedding("Hi there!") +embedding = model.embedding("Hi there!") ``` ## Development diff --git a/ext/candle/src/lib.rs b/ext/candle/src/lib.rs index 08c172f..49e171b 100644 --- a/ext/candle/src/lib.rs +++ b/ext/candle/src/lib.rs @@ -22,6 +22,7 @@ fn init(ruby: &Ruby) -> RbResult<()> { rb_tensor.define_method("dtype", method!(RbTensor::dtype, 0))?; rb_tensor.define_method("device", method!(RbTensor::device, 0))?; rb_tensor.define_method("rank", method!(RbTensor::rank, 0))?; + rb_tensor.define_method("elem_count", method!(RbTensor::elem_count, 0))?; rb_tensor.define_method("sin", method!(RbTensor::sin, 0))?; rb_tensor.define_method("cos", method!(RbTensor::cos, 0))?; rb_tensor.define_method("log", method!(RbTensor::log, 0))?; diff --git a/ext/candle/src/model/rb_tensor.rs b/ext/candle/src/model/rb_tensor.rs index 0121ea2..b0c4370 100644 --- a/ext/candle/src/model/rb_tensor.rs +++ b/ext/candle/src/model/rb_tensor.rs @@ -82,6 +82,12 @@ impl RbTensor { self.0.rank() } + /// The number of elements stored in this tensor. 
+    /// &RETURNS&: int
+    pub fn elem_count(&self) -> usize {
+        self.0.elem_count()
+    }
+
     pub fn __repr__(&self) -> String {
         format!("{}", self.0)
     }
diff --git a/lib/candle.rb b/lib/candle.rb
index 886012a..0d82e32 100644
--- a/lib/candle.rb
+++ b/lib/candle.rb
@@ -1 +1,2 @@
 require_relative 'candle/candle'
+require_relative 'candle/tensor'
diff --git a/lib/candle/tensor.rb b/lib/candle/tensor.rb
new file mode 100644
index 0000000..c3889af
--- /dev/null
+++ b/lib/candle/tensor.rb
@@ -0,0 +1,4 @@
+module Candle
+  class Tensor
+  end
+end
\ No newline at end of file

From 48117aa363a5026e3bbd427184ec60c0c665cd5b Mon Sep 17 00:00:00 2001
From: Chris Petersen
Date: Sat, 23 Mar 2024 16:02:42 -0700
Subject: [PATCH 6/8] Ignore the RustRover project file

---
 .gitignore | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/.gitignore b/.gitignore
index 83e47d2..2c7295b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -16,6 +16,9 @@
 # Ignore Byebug command history file.
 .byebug_history
 
+# Ignore the RustRover project file
+.idea
+
 ## Specific to RubyMotion:
 .dat*
 .repl_history
@@ -64,4 +67,4 @@ target
 *.o
 *.lock
 
-lib.py.rs
\ No newline at end of file
+lib.py.rs

From a2d527ac25efcf10de39868b7e01148d8f68d3c2 Mon Sep 17 00:00:00 2001
From: Chris Petersen
Date: Sun, 24 Mar 2024 14:43:16 -0700
Subject: [PATCH 7/8] Made tensor implement Enumerable

---
 lib/candle/tensor.rb | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/lib/candle/tensor.rb b/lib/candle/tensor.rb
index c3889af..4b412db 100644
--- a/lib/candle/tensor.rb
+++ b/lib/candle/tensor.rb
@@ -1,4 +1,17 @@
 module Candle
   class Tensor
+    include Enumerable
+
+    def each
+      if self.rank == 1
+        self.values.each do |value|
+          yield value
+        end
+      else
+        shape.first.times do |i|
+          yield self[i]
+        end
+      end
+    end
   end
-end
\ No newline at end of file
+end

From a65963ed2cda6cf983ffb5580b3451df4ce08ba0 Mon Sep 17 00:00:00 2001
From: Chris Petersen
Date: Sun, 24 Mar 2024 15:17:20 -0700
Subject: [PATCH 8/8] Updating the README

---
 README.md | 38 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 38 insertions(+)

diff --git a/README.md b/README.md
index 3cf25c5..d530fa0 100644
--- a/README.md
+++ b/README.md
@@ -24,6 +24,44 @@ model = Candle::Model.new
 embedding = model.embedding("Hi there!")
 ```
 
+## A note on memory usage
+The `Candle::Model` defaults to the `jinaai/jina-embeddings-v2-base-en` model with the `sentence-transformers/all-MiniLM-L6-v2` tokenizer (both from [HuggingFace](https://huggingface.co)). With this configuration the model takes a little more than 3 GB of memory running on my Mac. The memory stays with the instantiated `Candle::Model` instance; if you instantiate more than one, you'll use correspondingly more memory. Conversely, if you let a model go out of scope and run the garbage collector, its memory is freed. For example:
+
+```ruby
+> require 'candle'
+# Ruby memory = 25.9 MB
+> model = Candle::Model.new
+# Ruby memory = 3.50 GB
+> model2 = Candle::Model.new
+# Ruby memory = 7.04 GB
+> model2 = nil
+> GC.start
+# Ruby memory = 3.56 GB
+> model = nil
+> GC.start
+# Ruby memory = 55.2 MB
+```
+
+## A note on returned embeddings
+
+The embeddings returned here should match those generated by the Python `transformers` library. For instance, locally I was able to generate the same embedding for the text "Hi there!"
using the Python code:

```python
from transformers import AutoModel
model = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True)
sentence = ['Hi there!']
embedding = model.encode(sentence)
print(embedding)
```

And the following Ruby:

```ruby
require 'candle'
model = Candle::Model.new
embedding = model.embedding("Hi there!")
```

## Development

FORK IT!
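For reference, a minimal usage sketch of how this series fits together once applied, assuming the pooled embedding comes back as a rank-2 `[1, hidden_size]` tensor and that `Tensor#[]` (relied on by the `Enumerable` support in PATCH 7/8) returns a rank-1 row tensor:

```ruby
require 'candle'

model = Candle::Model.new                 # defaults described in the README memory note
embedding = model.embedding("Hi there!")  # Candle::Tensor, assumed shape [1, hidden_size]

# Tensor includes Enumerable (PATCH 7/8), so the batch dimension can be walked with
# #first, and the rank-1 row converted to a plain Ruby Array via #each over #values.
row = embedding.first                     # assumes Tensor#[] yields a rank-1 tensor
values = row.to_a

puts values.length                        # hidden size, e.g. 768 for jina-embeddings-v2-base-en
puts values.first(5).inspect              # compare against the Python output above
```

Pulling the row out as a plain Array this way makes it easy to eyeball the values against the Python `transformers` output shown in the README.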