diff --git a/Cargo.toml b/Cargo.toml index 10dcd57..b80b15d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,13 +20,13 @@ ndarray = "0.15.6" ndarray-stats = "0.5.1" getrandom = { version = "0.2", features = ["js"] } wavers = {version = "1.4.3", features = ["ndarray"]} +js-sys = "0.3.69" # The `console_error_panic_hook` crate provides better debugging of panics by # logging them with `console.error`. This is great for development, but requires # all the `std::fmt` and `std::panicking` infrastructure, so isn't great for # code size when deploying. console_error_panic_hook = { version = "0.1.7", optional = true } -js-sys = "0.3.69" [dev-dependencies] wasm-bindgen-test = "0.3.34" diff --git a/src/lib.rs b/src/lib.rs index 4863671..947a562 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,12 +1,10 @@ mod utils; use js_sys::{Array as JSArray, Float32Array as JSFloat32Array}; use mel_spec::mel::mel; -use ndarray::{concatenate, s, Array1, Array2, Array3, ArrayD, Axis, Dimension, Zip}; +use ndarray::{concatenate, s, Array1, Array2, Array3, Axis, Zip}; use ndarray_stats::QuantileExt; use rustfft::{num_complex::Complex32, FftPlanner}; -use std::clone; use std::cmp::min; -use std::convert::TryInto; use std::io::{BufReader, Cursor}; use wasm_bindgen::prelude::*; use wavers::{IntoNdarray, ReadSeek, Wav}; @@ -280,7 +278,12 @@ fn array3_to_js_array(audio: Array3) -> JSArray { audio_array } -fn js_array_to_array3(arr: JSArray, shape: &[usize]) -> Array3 { +fn js_array_to_array3( + arr: JSArray, + shape: &[usize], + vocab_start: usize, + vocab_end: usize, +) -> Array3 { let arr3 = Array3::from_shape_vec( (shape[0], shape[1], shape[2]), arr.to_vec() @@ -290,7 +293,9 @@ fn js_array_to_array3(arr: JSArray, shape: &[usize]) -> Array3 { ) .unwrap(); - arr3 + let logits_arr = arr3.slice(s![.., .., vocab_start..vocab_end]).to_owned(); + let blanks_arr = arr3.slice(s![.., .., (shape[2] - 1)..]).to_owned(); + concatenate(Axis(2), &[logits_arr.view(), blanks_arr.view()]).unwrap() } #[wasm_bindgen] @@ -323,21 +328,32 @@ pub fn run_preprocessor(audio_file: &[u8]) -> JSArray { } #[wasm_bindgen] -pub fn decode_logprobs(logprobs: JSArray, shape: &[usize], vocab_arr: JSArray) -> JSArray { - let arr = js_array_to_array3(logprobs, shape); +pub fn decode_logprobs( + logprobs: JSArray, + shape: &[usize], + vocab_arr: JSArray, + offset: usize, + actual_vocab_size: usize, +) -> JSArray { + let vocab_start = offset * actual_vocab_size; + let vocab_end = vocab_start + actual_vocab_size; + + let arr = js_array_to_array3(logprobs, shape, vocab_start, vocab_end); + let argmax = get_argmax(&arr); let indices_batch = merge_logprobs(&argmax); - let vocab: Vec = vocab_arr - .to_vec() + let mut vocab: Vec = vocab_arr.to_vec()[vocab_start..vocab_end] .iter() .map(|a| a.as_string().unwrap()) .collect(); + vocab.push(String::from("b")); + let text: JSArray = JSArray::new(); for indices in indices_batch { let t = get_text(&vocab, indices); - text.push(&JsValue::from_str(&t.as_str())); + text.push(&JsValue::from_str(&t.as_str().trim())); } text