Skip to content

Commit

Permalink
Updated decoder function to support multisoftmax
Browse files Browse the repository at this point in the history
  • Loading branch information
ryback123 committed Oct 22, 2024
1 parent e0c7296 commit 3c0db89
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 11 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ ndarray = "0.15.6"
ndarray-stats = "0.5.1"
getrandom = { version = "0.2", features = ["js"] }
wavers = {version = "1.4.3", features = ["ndarray"]}
js-sys = "0.3.69"

# The `console_error_panic_hook` crate provides better debugging of panics by
# logging them with `console.error`. This is great for development, but requires
# all the `std::fmt` and `std::panicking` infrastructure, so isn't great for
# code size when deploying.
console_error_panic_hook = { version = "0.1.7", optional = true }
js-sys = "0.3.69"

[dev-dependencies]
wasm-bindgen-test = "0.3.34"
Expand Down
36 changes: 26 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
mod utils;
use js_sys::{Array as JSArray, Float32Array as JSFloat32Array};
use mel_spec::mel::mel;
use ndarray::{concatenate, s, Array1, Array2, Array3, ArrayD, Axis, Dimension, Zip};
use ndarray::{concatenate, s, Array1, Array2, Array3, Axis, Zip};
use ndarray_stats::QuantileExt;
use rustfft::{num_complex::Complex32, FftPlanner};
use std::clone;
use std::cmp::min;
use std::convert::TryInto;
use std::io::{BufReader, Cursor};
use wasm_bindgen::prelude::*;
use wavers::{IntoNdarray, ReadSeek, Wav};
Expand Down Expand Up @@ -280,7 +278,12 @@ fn array3_to_js_array(audio: Array3<f32>) -> JSArray {
audio_array
}

fn js_array_to_array3(arr: JSArray, shape: &[usize]) -> Array3<f32> {
fn js_array_to_array3(
arr: JSArray,
shape: &[usize],
vocab_start: usize,
vocab_end: usize,
) -> Array3<f32> {
let arr3 = Array3::from_shape_vec(
(shape[0], shape[1], shape[2]),
arr.to_vec()
Expand All @@ -290,7 +293,9 @@ fn js_array_to_array3(arr: JSArray, shape: &[usize]) -> Array3<f32> {
)
.unwrap();

arr3
let logits_arr = arr3.slice(s![.., .., vocab_start..vocab_end]).to_owned();
let blanks_arr = arr3.slice(s![.., .., (shape[2] - 1)..]).to_owned();
concatenate(Axis(2), &[logits_arr.view(), blanks_arr.view()]).unwrap()
}

#[wasm_bindgen]
Expand Down Expand Up @@ -323,21 +328,32 @@ pub fn run_preprocessor(audio_file: &[u8]) -> JSArray {
}

#[wasm_bindgen]
pub fn decode_logprobs(logprobs: JSArray, shape: &[usize], vocab_arr: JSArray) -> JSArray {
let arr = js_array_to_array3(logprobs, shape);
pub fn decode_logprobs(
logprobs: JSArray,
shape: &[usize],
vocab_arr: JSArray,
offset: usize,
actual_vocab_size: usize,
) -> JSArray {
let vocab_start = offset * actual_vocab_size;
let vocab_end = vocab_start + actual_vocab_size;

let arr = js_array_to_array3(logprobs, shape, vocab_start, vocab_end);

let argmax = get_argmax(&arr);
let indices_batch = merge_logprobs(&argmax);

let vocab: Vec<String> = vocab_arr
.to_vec()
let mut vocab: Vec<String> = vocab_arr.to_vec()[vocab_start..vocab_end]
.iter()
.map(|a| a.as_string().unwrap())
.collect();

vocab.push(String::from("b"));

let text: JSArray = JSArray::new();
for indices in indices_batch {
let t = get_text(&vocab, indices);
text.push(&JsValue::from_str(&t.as_str()));
text.push(&JsValue::from_str(&t.as_str().trim()));
}

text
Expand Down

0 comments on commit 3c0db89

Please sign in to comment.