diff --git a/packages/stt-realtime-webgpu/src/libs/worker.js b/packages/stt-realtime-webgpu/src/libs/worker.ts similarity index 75% rename from packages/stt-realtime-webgpu/src/libs/worker.js rename to packages/stt-realtime-webgpu/src/libs/worker.ts index 2fadc07..8842380 100644 --- a/packages/stt-realtime-webgpu/src/libs/worker.js +++ b/packages/stt-realtime-webgpu/src/libs/worker.ts @@ -1,4 +1,19 @@ -import { AutoProcessor, AutoTokenizer, full, TextStreamer, WhisperForConditionalGeneration } from '@huggingface/transformers' +import type { + ModelOutput, + PreTrainedModel, + PreTrainedTokenizer, + Processor, + ProgressCallback, + Tensor, +} from '@huggingface/transformers' + +import { + AutoProcessor, + AutoTokenizer, + full, + TextStreamer, + WhisperForConditionalGeneration, +} from '@huggingface/transformers' const MAX_NEW_TOKENS = 64 @@ -6,12 +21,12 @@ const MAX_NEW_TOKENS = 64 * This class uses the Singleton pattern to ensure that only one instance of the model is loaded. */ class AutomaticSpeechRecognitionPipeline { - static model_id = null - static tokenizer = null - static processor = null - static model = null + static model_id: string | null = null + static tokenizer: Promise + static processor: Promise + static model: Promise - static async getInstance(progress_callback = null) { + static async getInstance(progress_callback?: ProgressCallback) { this.model_id = 'onnx-community/whisper-base' this.tokenizer ??= AutoTokenizer.from_pretrained(this.model_id, { @@ -36,7 +51,7 @@ class AutomaticSpeechRecognitionPipeline { } let processing = false -async function generate({ audio, language }) { +async function generate({ audio, language }: { audio: ArrayBuffer, language: string }) { if (processing) return processing = true @@ -49,13 +64,14 @@ async function generate({ audio, language }) { let startTime let numTokens = 0 - const callback_function = (output) => { + const callback_function = (output: ModelOutput | Tensor) => { startTime ??= performance.now() let tps if (numTokens++ > 0) { tps = numTokens / (performance.now() - startTime) * 1000 } + globalThis.postMessage({ status: 'update', output, @@ -66,7 +82,9 @@ async function generate({ audio, language }) { const streamer = new TextStreamer(tokenizer, { skip_prompt: true, - skip_special_tokens: true, + decode_kwargs: { + skip_special_tokens: true, + }, callback_function, }) @@ -79,7 +97,7 @@ async function generate({ audio, language }) { streamer, }) - const outputText = tokenizer.batch_decode(outputs, { skip_special_tokens: true }) + const outputText = tokenizer.batch_decode(outputs as Tensor, { skip_special_tokens: true }) // Send the output back to the main thread globalThis.postMessage({ @@ -96,7 +114,7 @@ async function load() { }) // Load the pipeline and save it for future use. - // eslint-disable-next-line no-unused-vars, unused-imports/no-unused-vars + // eslint-disable-next-line unused-imports/no-unused-vars const [tokenizer, processor, model] = await AutomaticSpeechRecognitionPipeline.getInstance((x) => { // We also add a progress callback to the pipeline so that we can // track model loading.