Skip to content

Commit

Permalink
chore: to worker.ts
Browse files Browse the repository at this point in the history
  • Loading branch information
nekomeowww committed Dec 10, 2024
1 parent bff652c commit ed2cd55
Showing 1 changed file with 29 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -1,17 +1,32 @@
import { AutoProcessor, AutoTokenizer, full, TextStreamer, WhisperForConditionalGeneration } from '@huggingface/transformers'
import type {
ModelOutput,
PreTrainedModel,
PreTrainedTokenizer,
Processor,
ProgressCallback,
Tensor,
} from '@huggingface/transformers'

import {
AutoProcessor,
AutoTokenizer,
full,
TextStreamer,
WhisperForConditionalGeneration,
} from '@huggingface/transformers'

const MAX_NEW_TOKENS = 64

/**
* This class uses the Singleton pattern to ensure that only one instance of the model is loaded.
*/
class AutomaticSpeechRecognitionPipeline {
static model_id = null
static tokenizer = null
static processor = null
static model = null
static model_id: string | null = null
static tokenizer: Promise<PreTrainedTokenizer>
static processor: Promise<Processor>
static model: Promise<PreTrainedModel>

static async getInstance(progress_callback = null) {
static async getInstance(progress_callback?: ProgressCallback) {
this.model_id = 'onnx-community/whisper-base'

this.tokenizer ??= AutoTokenizer.from_pretrained(this.model_id, {
Expand All @@ -36,7 +51,7 @@ class AutomaticSpeechRecognitionPipeline {
}

let processing = false
async function generate({ audio, language }) {
async function generate({ audio, language }: { audio: ArrayBuffer, language: string }) {
if (processing)
return
processing = true
Expand All @@ -49,13 +64,14 @@ async function generate({ audio, language }) {

let startTime
let numTokens = 0
const callback_function = (output) => {
const callback_function = (output: ModelOutput | Tensor) => {
startTime ??= performance.now()

let tps
if (numTokens++ > 0) {
tps = numTokens / (performance.now() - startTime) * 1000
}

globalThis.postMessage({
status: 'update',
output,
Expand All @@ -66,7 +82,9 @@ async function generate({ audio, language }) {

const streamer = new TextStreamer(tokenizer, {
skip_prompt: true,
skip_special_tokens: true,
decode_kwargs: {
skip_special_tokens: true,
},
callback_function,
})

Expand All @@ -79,7 +97,7 @@ async function generate({ audio, language }) {
streamer,
})

const outputText = tokenizer.batch_decode(outputs, { skip_special_tokens: true })
const outputText = tokenizer.batch_decode(outputs as Tensor, { skip_special_tokens: true })

// Send the output back to the main thread
globalThis.postMessage({
Expand All @@ -96,7 +114,7 @@ async function load() {
})

// Load the pipeline and save it for future use.
// eslint-disable-next-line no-unused-vars, unused-imports/no-unused-vars
// eslint-disable-next-line unused-imports/no-unused-vars
const [tokenizer, processor, model] = await AutomaticSpeechRecognitionPipeline.getInstance((x) => {
// We also add a progress callback to the pipeline so that we can
// track model loading.
Expand Down

0 comments on commit ed2cd55

Please sign in to comment.