Skip to content

Commit

Permalink
Added support for topk best tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
ryback123 committed Oct 28, 2024
1 parent 3c0db89 commit 2848e13
Show file tree
Hide file tree
Showing 2 changed files with 189 additions and 3 deletions.
184 changes: 181 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
mod utils;
use js_sys::{Array as JSArray, Float32Array as JSFloat32Array};
use js_sys::{Array as JSArray, Float32Array as JSFloat32Array, Object as JSObject};
use mel_spec::mel::mel;
use ndarray::{concatenate, s, Array1, Array2, Array3, Axis, Zip};
use ndarray_stats::QuantileExt;
use rustfft::{num_complex::Complex32, FftPlanner};
use std::cmp::min;
use std::collections::HashMap;
use std::io::{BufReader, Cursor};
use utils::Value;
use wasm_bindgen::prelude::*;
use wavers::{IntoNdarray, ReadSeek, Wav};

Expand Down Expand Up @@ -211,6 +213,19 @@ fn run_mask(audio: Array3<f32>, seq_len: Array1<usize>) -> Array3<f32> {
masked_signal
}

/// Computes a softmax over the last axis (axis 2) of `array`.
///
/// The maximum of each axis-2 lane is subtracted before exponentiating so the
/// largest exponent is `exp(0)`, which keeps `exp` from overflowing.
///
/// Returns a new array of the same shape; `array` itself is not modified.
///
/// A NaN input no longer aborts the call: the lane containing it comes back
/// as NaN instead (the previous `partial_cmp(..).unwrap()` panicked).
fn softmax(array: &Array3<f32>) -> Array3<f32> {
    // Per-lane maximum, re-expanded to a length-1 trailing axis so it
    // broadcasts against `array` below. `f32::max` skips NaN operands, so
    // the fold itself never panics.
    let max_values = array
        .map_axis(Axis(2), |lane| {
            lane.iter().copied().fold(f32::NEG_INFINITY, f32::max)
        })
        .insert_axis(Axis(2));

    let exp_array = (array - &max_values).mapv(f32::exp);

    // Normaliser per lane, again with the trailing axis kept for broadcasting.
    let sum_exp = exp_array.sum_axis(Axis(2)).insert_axis(Axis(2));

    exp_array / sum_exp
}

fn get_argmax(arr: &Array3<f32>) -> Array2<usize> {
let dims = arr.dim();

Expand All @@ -225,16 +240,37 @@ fn get_argmax(arr: &Array3<f32>) -> Array2<usize> {
argmax
}

/// Selects, for every `(batch, timestep)` lane of `arr`, the `topk` largest
/// values along axis 2 together with their indices.
///
/// The result has shape `(batch, timestep, topk)`; entry `k` of a lane is
/// `(index, value)` of the k-th largest element, ties resolved in favour of
/// the lowest index.
///
/// `arr` is used as scratch space: every selected element is overwritten
/// with `f32::NEG_INFINITY` so it cannot be selected again. The previous
/// `0.0` sentinel could be re-selected whenever the remaining values were
/// `0.0` or negative, yielding duplicate indices.
///
/// If a lane has fewer than `topk` selectable values (axis 2 is shorter than
/// `topk`, or the remaining values are NaN / negative infinity), the unused
/// trailing entries keep the filler `(0, 0.0)`. This no longer panics on NaN
/// or on an empty axis 2.
fn get_argmax_topk(arr: &mut Array3<f32>, topk: usize) -> Array3<(usize, f32)> {
    let (n_batches, n_timesteps, n_classes) = arr.dim();

    let mut selected: Array3<(usize, f32)> =
        Array3::from_elem((n_batches, n_timesteps, topk), (0, 0.0));

    // There cannot be more distinct picks than there are classes.
    let n_picks = topk.min(n_classes);

    for batch in 0..n_batches {
        for timestep in 0..n_timesteps {
            for k in 0..n_picks {
                // First maximum among values that are still selectable.
                let best = arr
                    .slice(s![batch, timestep, ..])
                    .iter()
                    .copied()
                    .enumerate()
                    .filter(|&(_, value)| !value.is_nan() && value > f32::NEG_INFINITY)
                    .fold(None, |best: Option<(usize, f32)>, (idx, value)| match best {
                        Some((_, best_value)) if best_value >= value => best,
                        _ => Some((idx, value)),
                    });

                let (idx, value) = match best {
                    Some(pick) => pick,
                    // Nothing selectable is left in this lane.
                    None => break,
                };

                selected[[batch, timestep, k]] = (idx, value);
                // Mask the pick so the next round finds the runner-up.
                arr[[batch, timestep, idx]] = f32::NEG_INFINITY;
            }
        }
    }

    selected
}

fn merge_logprobs(arr: &Array2<usize>) -> Vec<Vec<usize>> {
let mut result = Vec::new();

// Iterate over the slices along the specified axis
for slice in arr.axis_iter(Axis(0)) {
for batch in arr.axis_iter(Axis(0)) {
let mut unique = Vec::new();
let mut previous = None;

// Iterate over the elements of the current slice
for &value in slice.iter() {
for &value in batch.iter() {
if Some(value) != previous {
unique.push(value);
previous = Some(value);
Expand All @@ -247,6 +283,29 @@ fn merge_logprobs(arr: &Array2<usize>) -> Vec<Vec<usize>> {
result
}

/// Collapses consecutive timesteps whose best candidate is the same index.
///
/// `arr` is indexed `(batch, timestep, k)`, with entry `k = 0` of each
/// timestep treated as the best candidate. For every batch, a timestep is
/// kept only when its best index differs from the previously kept one; the
/// kept entry is the full candidate list of the first timestep of each run.
///
/// Returns one `Vec` per batch, holding one candidate list per kept timestep.
/// Timesteps with no candidates (`topk == 0`) are skipped instead of
/// panicking on the `[0]` index.
fn merge_logprobs_topk(arr: &Array3<(usize, f32)>) -> Vec<Vec<Vec<(usize, f32)>>> {
    arr.axis_iter(Axis(0))
        .map(|batch| {
            let mut merged = Vec::new();
            let mut previous = None;

            for timestep in batch.axis_iter(Axis(0)) {
                let top_idx = match timestep.first() {
                    Some(&(idx, _)) => idx,
                    None => continue,
                };

                if previous != Some(top_idx) {
                    merged.push(timestep.to_vec());
                    previous = Some(top_idx);
                }
            }

            merged
        })
        .collect()
}

fn get_text(vocab: &Vec<String>, indices: Vec<usize>) -> String {
let mut text = String::from("");
for i in indices {
Expand All @@ -262,6 +321,66 @@ fn get_text(vocab: &Vec<String>, indices: Vec<usize>) -> String {
text.replace("▁", " ")
}

/// Appends a finished word to the sentence and to the per-word list.
///
/// Does nothing when `word` is empty. Otherwise the word is added to `sent`
/// with every "▁" turned into a space, and a `{ "word", "tokens" }` entry is
/// pushed to `words` with the "▁" markers stripped from the word text.
fn push_word_entry(
    sent: &mut String,
    words: &mut Vec<HashMap<String, Value>>,
    word: String,
    tokens: Vec<HashMap<String, Value>>,
) {
    if word.is_empty() {
        return;
    }

    sent.push_str(&word.replace('▁', " "));
    words.push(HashMap::from([
        ("word".to_string(), Value::Text(word.replace('▁', ""))),
        ("tokens".to_string(), Value::MapList(tokens)),
    ]));
}

/// Builds the decoded sentence plus per-word top-k token alternatives.
///
/// `indices` holds one candidate list per kept timestep; the first candidate
/// of each list is the one used to build the text. A timestep whose best
/// candidate is the last vocabulary index is skipped, and timesteps with an
/// empty candidate list are ignored rather than panicking. A best token that
/// contains "▁" starts a new word.
///
/// Returns a map with:
/// * `"text"`: the sentence, every "▁" replaced by a space;
/// * `"nBestTokens"`: one `{ "word", "tokens" }` map per word, where
///   `"tokens"` lists, for each token of the word, a map from candidate
///   token text to its score.
///
/// # Panics
///
/// Panics if a candidate index is out of range for `vocab`.
fn get_text_topk(vocab: &[String], indices: Vec<Vec<(usize, f32)>>) -> HashMap<String, Value> {
    let mut n_best_tokens = Vec::new();
    let mut sent = String::new();

    let mut cur_word = String::new();
    let mut cur_tokens = Vec::new();

    for tokens in indices {
        let best_id = match tokens.first() {
            Some(&(id, _)) => id,
            None => continue,
        };

        // The last vocabulary slot is skipped entirely. Written as `+ 1` so an
        // empty vocabulary cannot underflow `len() - 1`.
        if best_id + 1 == vocab.len() {
            continue;
        }

        let best_token = &vocab[best_id];

        // A "▁" marks a word boundary: emit what has been collected so far.
        if best_token.contains('▁') {
            push_word_entry(
                &mut sent,
                &mut n_best_tokens,
                std::mem::take(&mut cur_word),
                std::mem::take(&mut cur_tokens),
            );
        }

        cur_word.push_str(best_token);

        // A bare boundary marker carries no alternatives worth reporting.
        if best_token == "▁" {
            continue;
        }

        let tokens_topk: HashMap<String, Value> = tokens
            .into_iter()
            .map(|(id, prob)| (vocab[id].clone(), Value::Float(prob)))
            .collect();

        cur_tokens.push(tokens_topk);
    }

    // Emit the trailing word, if any.
    push_word_entry(&mut sent, &mut n_best_tokens, cur_word, cur_tokens);

    HashMap::from([
        ("text".to_string(), Value::Text(sent)),
        ("nBestTokens".to_string(), Value::MapList(n_best_tokens)),
    ])
}

fn array3_to_js_array(audio: Array3<f32>) -> JSArray {
let audio_array: JSArray = JSArray::new();
let (n, m, _) = audio.dim();
Expand Down Expand Up @@ -298,6 +417,31 @@ fn js_array_to_array3(
concatenate(Axis(2), &[logits_arr.view(), blanks_arr.view()]).unwrap()
}

/// Recursively converts a `HashMap<String, Value>` into a plain JS object.
///
/// * `Value::Text` becomes a JS string.
/// * `Value::Float` becomes a JS number.
/// * `Value::MapList` becomes a JS array of objects built by this function.
///
/// Properties are set in `HashMap` iteration order, which is unspecified.
fn hashmap_to_jsmap(topk_map: &HashMap<String, Value>) -> JsValue {
    let js_object = JSObject::new();

    for (key, value) in topk_map {
        let js_value = match value {
            // Borrow the text instead of cloning the `String` first.
            Value::Text(s) => JsValue::from_str(s),
            Value::Float(f) => JsValue::from(*f),
            Value::MapList(m) => {
                JsValue::from(m.iter().map(hashmap_to_jsmap).collect::<Vec<JsValue>>())
            }
        };

        js_sys::Reflect::set(&js_object, &JsValue::from_str(key), &js_value)
            .expect("setting a property on a freshly created plain object cannot fail");
    }

    JsValue::from(js_object)
}

#[wasm_bindgen]
pub fn run_preprocessor(audio_file: &[u8]) -> JSArray {
#[cfg(debug_assertions)]
Expand Down Expand Up @@ -358,3 +502,37 @@ pub fn decode_logprobs(

text
}

/// Decodes logits into text plus per-word top-k token alternatives.
///
/// `offset * actual_vocab_size` selects which slice of `vocab_arr` is used;
/// that same range is forwarded to `js_array_to_array3` together with `logits`
/// and `shape`. The resulting array is passed through `softmax`, the `topk`
/// best candidates per timestep are extracted, and consecutive repeats are
/// merged before the text is assembled.
///
/// Returns a JS array with one object per batch entry, as produced by
/// `get_text_topk` and converted by `hashmap_to_jsmap`.
///
/// # Panics
///
/// Panics if the vocabulary range is out of bounds for `vocab_arr`, or if an
/// entry in that range is not a JS string.
#[wasm_bindgen]
pub fn decode_logprobs_topk(
    logits: JSArray,
    shape: &[usize],
    vocab_arr: JSArray,
    offset: usize,
    actual_vocab_size: usize,
    topk: usize,
) -> JSArray {
    let vocab_start = offset * actual_vocab_size;
    let vocab_end = vocab_start + actual_vocab_size;

    // Softmax output holds probabilities; top-k selection consumes it in place.
    let mut probs = softmax(&js_array_to_array3(logits, shape, vocab_start, vocab_end));
    let topk_tokens = get_argmax_topk(&mut probs, topk);
    let merged_batches = merge_logprobs_topk(&topk_tokens);

    // Vocabulary slice for this offset, followed by one extra empty entry
    // that occupies the last index.
    let all_tokens = vocab_arr.to_vec();
    let mut vocab: Vec<String> = Vec::new();
    for token in &all_tokens[vocab_start..vocab_end] {
        vocab.push(token.as_string().unwrap());
    }
    vocab.push(String::new());

    let results = JSArray::new();
    merged_batches.into_iter().for_each(|batch_tokens| {
        let decoded = get_text_topk(&vocab, batch_tokens);
        results.push(&hashmap_to_jsmap(&decoded));
    });

    results
}
8 changes: 8 additions & 0 deletions src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::collections::HashMap;

pub fn set_panic_hook() {
// When the `console_error_panic_hook` feature is enabled, we can call the
// `set_panic_hook` function at least once during initialization, and then
Expand All @@ -9,3 +11,9 @@ pub fn set_panic_hook() {
#[cfg(debug_assertions)]
console_error_panic_hook::set_once();
}

/// A small dynamically-typed value used to build nested, JSON-like results.
#[derive(Debug, Clone, PartialEq)]
pub enum Value {
    /// A piece of text.
    Text(String),
    /// A list of string-keyed maps whose values are themselves `Value`s.
    MapList(Vec<HashMap<String, Value>>),
    /// A single-precision number.
    Float(f32),
}

0 comments on commit 2848e13

Please sign in to comment.