From bc791d49964947fc800cd68d62111d9383e8eb73 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 14 Aug 2023 17:02:23 +0800 Subject: [PATCH] Fix C api for Go and MFC to support streaming paraformer (#268) --- .github/workflows/go.yaml | 21 ++- CMakeLists.txt | 2 +- .../main.go | 8 +- .../streaming-decode-files/main.go | 8 +- .../streaming-decode-files/run-paraformer.sh | 21 +++ .../{run.sh => run-transducer.sh} | 0 .../NonStreamingSpeechRecognitionDlg.cpp | 84 ++++++++- .../NonStreamingSpeechRecognitionDlg.h | 1 + .../StreamingSpeechRecognitionDlg.cpp | 167 +++++++++++++----- .../StreamingSpeechRecognitionDlg.h | 2 + scripts/go/sherpa_onnx.go | 47 +++-- sherpa-onnx/c-api/c-api.cc | 6 + sherpa-onnx/c-api/c-api.h | 6 + 13 files changed, 307 insertions(+), 66 deletions(-) create mode 100755 go-api-examples/streaming-decode-files/run-paraformer.sh rename go-api-examples/streaming-decode-files/{run.sh => run-transducer.sh} (100%) diff --git a/.github/workflows/go.yaml b/.github/workflows/go.yaml index b19a801ed..c4bad36d4 100644 --- a/.github/workflows/go.yaml +++ b/.github/workflows/go.yaml @@ -178,9 +178,14 @@ jobs: echo "Test transducer" git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26 - ./run.sh + ./run-transducer.sh rm -rf sherpa-onnx-streaming-zipformer-en-2023-06-26 + echo "Test paraformer" + git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en + ./run-paraformer.sh + rm -rf sherpa-onnx-streaming-paraformer-bilingual-zh-en + - name: Test streaming decoding files (Win64) if: matrix.os == 'windows-latest' && matrix.arch == 'x64' shell: bash @@ -202,9 +207,14 @@ jobs: echo "Test transducer" git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26 - ./run.sh + ./run-transducer.sh rm -rf sherpa-onnx-streaming-zipformer-en-2023-06-26 + echo "Test paraformer" + git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en + ./run-paraformer.sh + rm -rf sherpa-onnx-streaming-paraformer-bilingual-zh-en + - name: Test streaming decoding files (Win32) if: matrix.os == 'windows-latest' && matrix.arch == 'x86' shell: bash @@ -235,5 +245,10 @@ jobs: echo "Test transducer" git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26 - ./run.sh + ./run-transducer.sh rm -rf sherpa-onnx-streaming-zipformer-en-2023-06-26 + + echo "Test paraformer" + git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en + ./run-paraformer.sh + rm -rf sherpa-onnx-streaming-paraformer-bilingual-zh-en diff --git a/CMakeLists.txt b/CMakeLists.txt index c6086fa31..cd2d85213 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.13 FATAL_ERROR) project(sherpa-onnx) -set(SHERPA_ONNX_VERSION "1.7.4") +set(SHERPA_ONNX_VERSION "1.7.5") # Disable warning about # diff --git a/go-api-examples/real-time-speech-recognition-from-microphone/main.go b/go-api-examples/real-time-speech-recognition-from-microphone/main.go index 229b25b3d..094132835 100644 --- a/go-api-examples/real-time-speech-recognition-from-microphone/main.go +++ b/go-api-examples/real-time-speech-recognition-from-microphone/main.go @@ -33,9 +33,11 @@ func main() { config := sherpa.OnlineRecognizerConfig{} config.FeatConfig = sherpa.FeatureConfig{SampleRate: 16000, FeatureDim: 80} - flag.StringVar(&config.ModelConfig.Encoder, "encoder", "", "Path to the encoder model") - flag.StringVar(&config.ModelConfig.Decoder, 
"decoder", "", "Path to the decoder model") - flag.StringVar(&config.ModelConfig.Joiner, "joiner", "", "Path to the joiner model") + flag.StringVar(&config.ModelConfig.Transducer.Encoder, "encoder", "", "Path to the transducer encoder model") + flag.StringVar(&config.ModelConfig.Transducer.Decoder, "decoder", "", "Path to the transducer decoder model") + flag.StringVar(&config.ModelConfig.Transducer.Joiner, "joiner", "", "Path to the transducer joiner model") + flag.StringVar(&config.ModelConfig.Paraformer.Encoder, "paraformer-encoder", "", "Path to the paraformer encoder model") + flag.StringVar(&config.ModelConfig.Paraformer.Decoder, "paraformer-decoder", "", "Path to the paraformer decoder model") flag.StringVar(&config.ModelConfig.Tokens, "tokens", "", "Path to the tokens file") flag.IntVar(&config.ModelConfig.NumThreads, "num-threads", 1, "Number of threads for computing") flag.IntVar(&config.ModelConfig.Debug, "debug", 0, "Whether to show debug message") diff --git a/go-api-examples/streaming-decode-files/main.go b/go-api-examples/streaming-decode-files/main.go index 657767efe..fc2922236 100644 --- a/go-api-examples/streaming-decode-files/main.go +++ b/go-api-examples/streaming-decode-files/main.go @@ -17,9 +17,11 @@ func main() { config := sherpa.OnlineRecognizerConfig{} config.FeatConfig = sherpa.FeatureConfig{SampleRate: 16000, FeatureDim: 80} - flag.StringVar(&config.ModelConfig.Encoder, "encoder", "", "Path to the encoder model") - flag.StringVar(&config.ModelConfig.Decoder, "decoder", "", "Path to the decoder model") - flag.StringVar(&config.ModelConfig.Joiner, "joiner", "", "Path to the joiner model") + flag.StringVar(&config.ModelConfig.Transducer.Encoder, "encoder", "", "Path to the transducer encoder model") + flag.StringVar(&config.ModelConfig.Transducer.Decoder, "decoder", "", "Path to the transducer decoder model") + flag.StringVar(&config.ModelConfig.Transducer.Joiner, "joiner", "", "Path to the transducer joiner model") + flag.StringVar(&config.ModelConfig.Paraformer.Encoder, "paraformer-encoder", "", "Path to the paraformer encoder model") + flag.StringVar(&config.ModelConfig.Paraformer.Decoder, "paraformer-decoder", "", "Path to the paraformer decoder model") flag.StringVar(&config.ModelConfig.Tokens, "tokens", "", "Path to the tokens file") flag.IntVar(&config.ModelConfig.NumThreads, "num-threads", 1, "Number of threads for computing") flag.IntVar(&config.ModelConfig.Debug, "debug", 0, "Whether to show debug message") diff --git a/go-api-examples/streaming-decode-files/run-paraformer.sh b/go-api-examples/streaming-decode-files/run-paraformer.sh new file mode 100755 index 000000000..f2b7fbf2a --- /dev/null +++ b/go-api-examples/streaming-decode-files/run-paraformer.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash + +# Please refer to +# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/paraformer-models.html#csukuangfj-sherpa-onnx-streaming-paraformer-bilingual-zh-en-chinese-english +# to download the model files + +if [ ! -d ./sherpa-onnx-streaming-paraformer-bilingual-zh-en ]; then + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en + cd sherpa-onnx-streaming-paraformer-bilingual-zh-en + git lfs pull --include "*.onnx" + cd .. 
+fi + +./streaming-decode-files \ + --paraformer-encoder ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.int8.onnx \ + --paraformer-decoder ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.int8.onnx \ + --tokens ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt \ + --decoding-method greedy_search \ + --model-type paraformer \ + --debug 0 \ + ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/0.wav diff --git a/go-api-examples/streaming-decode-files/run.sh b/go-api-examples/streaming-decode-files/run-transducer.sh similarity index 100% rename from go-api-examples/streaming-decode-files/run.sh rename to go-api-examples/streaming-decode-files/run-transducer.sh diff --git a/mfc-examples/NonStreamingSpeechRecognition/NonStreamingSpeechRecognitionDlg.cpp b/mfc-examples/NonStreamingSpeechRecognition/NonStreamingSpeechRecognitionDlg.cpp index aefd5b57d..0d1454ba6 100644 --- a/mfc-examples/NonStreamingSpeechRecognition/NonStreamingSpeechRecognitionDlg.cpp +++ b/mfc-examples/NonStreamingSpeechRecognition/NonStreamingSpeechRecognitionDlg.cpp @@ -306,12 +306,10 @@ void CNonStreamingSpeechRecognitionDlg::ShowInitRecognizerHelpMessage() { "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html " "\r\n"; msg += "to download a non-streaming model, i.e., an offline model.\r\n"; + msg += "You need to rename them after downloading\r\n\r\n"; + msg += "It supports transducer, paraformer, and whisper models.\r\n\r\n"; msg += - "You need to rename them to encoder.onnx, decoder.onnx, and " - "joiner.onnx correspoondingly.\r\n\r\n"; - msg += "It supports both transducer models and paraformer models.\r\n\r\n"; - msg += - "We give two examples below to show you how to download models\r\n\r\n"; + "We give three examples below to show you how to download models\r\n\r\n"; msg += "(1) Transducer\r\n\r\n"; msg += "We use " @@ -346,13 +344,82 @@ void CNonStreamingSpeechRecognitionDlg::ShowInitRecognizerHelpMessage() { "https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28/" "resolve/main/tokens.txt\r\n\r\n"; msg += "\r\n Now rename them\r\n"; - msg += "mv model.onnx paraformer.onnx\r\n"; + msg += "mv model.onnx paraformer.onnx\r\n\r\n"; + msg += "(3) Whisper\r\n\r\n"; + msg += + "wget " + "https://huggingface.co/csukuangfj/sherpa-onnx-whisper-tiny.en/resolve/" + "main/tiny.en-encoder.onnx\r\n"; + msg += + "wget " + "https://huggingface.co/csukuangfj/sherpa-onnx-whisper-tiny.en/resolve/" + "main/tiny.en-decoder.onnx\r\n"; + msg += + "wget " + "https://huggingface.co/csukuangfj/sherpa-onnx-whisper-tiny.en/resolve/" + "main/tiny.en-tokens.txt\r\n"; + msg += "\r\n Now rename them\r\n"; + msg += "mv tiny.en-encoder.onnx whisper-encoder.onnx\r\n"; + msg += "mv tiny.en-decoder.onnx whisper-decoder.onnx\r\n"; msg += "\r\n"; msg += "That's it!\r\n"; AppendLineToMultilineEditCtrl(msg); } +void CNonStreamingSpeechRecognitionDlg::InitWhisper() { + std::string whisper_encoder = "./whisper-encoder.onnx"; + std::string whisper_decoder = "./whisper-decoder.onnx"; + + std::string tokens = "./tokens.txt"; + + bool is_ok = true; + + if (Exists("./whisper-encoder.int8.onnx")) { + whisper_encoder = "./whisper-encoder.int8.onnx"; + } else if (!Exists(whisper_encoder)) { + std::string msg = whisper_encoder + " does not exist!"; + AppendLineToMultilineEditCtrl(msg); + is_ok = false; + } + + if (Exists("./whisper-decoder.int8.onnx")) { + whisper_decoder = "./whisper-decoder.int8.onnx"; + } else if (!Exists(whisper_decoder)) { + std::string msg = whisper_decoder + " does not 
exist!"; + AppendLineToMultilineEditCtrl(msg); + is_ok = false; + } + + if (!Exists(tokens)) { + std::string msg = tokens + " does not exist!"; + AppendLineToMultilineEditCtrl(msg); + is_ok = false; + } + + if (!is_ok) { + ShowInitRecognizerHelpMessage(); + return; + } + + memset(&config_, 0, sizeof(config_)); + + config_.feat_config.sample_rate = 16000; + config_.feat_config.feature_dim = 80; + + config_.model_config.whisper.encoder = whisper_encoder.c_str(); + config_.model_config.whisper.decoder = whisper_decoder.c_str(); + config_.model_config.tokens = tokens.c_str(); + config_.model_config.num_threads = 1; + config_.model_config.debug = 1; + config_.model_config.model_type = "whisper"; + + config_.decoding_method = "greedy_search"; + config_.max_active_paths = 4; + + recognizer_ = CreateOfflineRecognizer(&config_); +} + void CNonStreamingSpeechRecognitionDlg::InitParaformer() { std::string paraformer = "./paraformer.onnx"; std::string tokens = "./tokens.txt"; @@ -401,6 +468,11 @@ void CNonStreamingSpeechRecognitionDlg::InitRecognizer() { return; } + if (Exists("./whisper-encoder.onnx") || Exists("./whisper-encoder.int8.onnx")) { + InitWhisper(); + return; + } + // assume it is transducer std::string encoder = "./encoder.onnx"; diff --git a/mfc-examples/NonStreamingSpeechRecognition/NonStreamingSpeechRecognitionDlg.h b/mfc-examples/NonStreamingSpeechRecognition/NonStreamingSpeechRecognitionDlg.h index e364bc58d..77a8992e9 100644 --- a/mfc-examples/NonStreamingSpeechRecognition/NonStreamingSpeechRecognitionDlg.h +++ b/mfc-examples/NonStreamingSpeechRecognition/NonStreamingSpeechRecognitionDlg.h @@ -69,5 +69,6 @@ class CNonStreamingSpeechRecognitionDlg : public CDialogEx { void InitRecognizer(); void InitParaformer(); + void InitWhisper(); void ShowInitRecognizerHelpMessage(); }; diff --git a/mfc-examples/StreamingSpeechRecognition/StreamingSpeechRecognitionDlg.cpp b/mfc-examples/StreamingSpeechRecognition/StreamingSpeechRecognitionDlg.cpp index 1748b985d..7be8dbe39 100644 --- a/mfc-examples/StreamingSpeechRecognition/StreamingSpeechRecognitionDlg.cpp +++ b/mfc-examples/StreamingSpeechRecognition/StreamingSpeechRecognitionDlg.cpp @@ -234,50 +234,18 @@ bool CStreamingSpeechRecognitionDlg::Exists(const std::string &filename) { return is.good(); } -void CStreamingSpeechRecognitionDlg::InitRecognizer() { - std::string encoder = "./encoder.onnx"; - std::string decoder = "./decoder.onnx"; - std::string joiner = "./joiner.onnx"; - std::string tokens = "./tokens.txt"; - - bool is_ok = true; - if (!Exists(encoder)) { - std::string msg = encoder + " does not exist!"; - AppendLineToMultilineEditCtrl(msg); - is_ok = false; - } - - if (!Exists(decoder)) { - std::string msg = decoder + " does not exist!"; - AppendLineToMultilineEditCtrl(msg); - is_ok = false; - } - - if (!Exists(joiner)) { - std::string msg = joiner + " does not exist!"; - AppendLineToMultilineEditCtrl(msg); - is_ok = false; - } - - if (!Exists(tokens)) { - std::string msg = tokens + " does not exist!"; - AppendLineToMultilineEditCtrl(msg); - is_ok = false; - } - - if (!is_ok) { +void CStreamingSpeechRecognitionDlg::ShowInitRecognizerHelpMessage() { my_btn_.EnableWindow(FALSE); std::string msg = "\r\nPlease go to\r\n" "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html " "\r\n"; msg += "to download a streaming model, i.e., an online model.\r\n"; + msg += "You need to rename them after downloading\r\n\r\n"; + msg += "It supports both transducer and paraformer models.\r\n\r\n"; msg += - "You need to rename them to 
encoder.onnx, decoder.onnx, and " - "joiner.onnx correspoondingly.\r\n\r\n"; - msg += - "We use the following model as an example to show you how to do " - "that.\r\n"; + "We give two examples below to show you how to download models\r\n\r\n"; + msg += "(1) Transducer\r\n\r\n"; msg += "https://huggingface.co/pkufool/" "icefall-asr-zipformer-streaming-wenetspeech-20230615"; @@ -308,13 +276,132 @@ void CStreamingSpeechRecognitionDlg::InitRecognizer() { msg += "mv decoder-epoch-12-avg-4-chunk-16-left-128.onnx decoder.onnx\r\n"; msg += "mv joiner-epoch-12-avg-4-chunk-16-left-128.onnx joiner.onnx\r\n"; msg += "\r\n"; + msg += "(2) Paraformer\r\n\r\n"; + msg += + "wget " + "https://huggingface.co/csukuangfj/" + "sherpa-onnx-streaming-paraformer-bilingual-zh-en/resolve/main/" + "encoder.int8.onnx\r\n"; + msg += + "wget " + "https://huggingface.co/csukuangfj/" + "sherpa-onnx-streaming-paraformer-bilingual-zh-en/resolve/main/" + "decoder.int8.onnx\r\n"; + msg += + "wget " + "https://huggingface.co/csukuangfj/" + "sherpa-onnx-streaming-paraformer-bilingual-zh-en/resolve/main/" + "tokens.txt\r\n"; + msg += "\r\nNow rename them.\r\n"; + msg += "mv encoder.int8.onnx paraformer-encoder.onnx\r\n"; + msg += "mv decoder.int8.onnx paraformer-decoder.onnx\r\n\r\n"; msg += "That's it!\r\n"; AppendLineToMultilineEditCtrl(msg); +} + +void CStreamingSpeechRecognitionDlg::InitParaformer() { + std::string paraformer_encoder = "./paraformer-encoder.onnx"; + std::string paraformer_decoder = "./paraformer-decoder.onnx"; + + std::string tokens = "./tokens.txt"; + + bool is_ok = true; + + if (Exists("./paraformer-encoder.int8.onnx")) { + paraformer_encoder = "./paraformer-encoder.int8.onnx"; + } else if (!Exists(paraformer_encoder)) { + std::string msg = paraformer_encoder + " does not exist!"; + AppendLineToMultilineEditCtrl(msg); + is_ok = false; + } + + if (Exists("./paraformer-decoder.int8.onnx")) { + paraformer_decoder = "./paraformer-decoder.int8.onnx"; + } else if (!Exists(paraformer_decoder)) { + std::string msg = paraformer_decoder + " does not exist!"; + AppendLineToMultilineEditCtrl(msg); + is_ok = false; + } + + if (!Exists(tokens)) { + std::string msg = tokens + " does not exist!"; + AppendLineToMultilineEditCtrl(msg); + is_ok = false; + } + + if (!is_ok) { + ShowInitRecognizerHelpMessage(); + return; + } + + SherpaOnnxOnlineRecognizerConfig config; + memset(&config, 0, sizeof(config)); + config.model_config.debug = 0; + config.model_config.num_threads = 1; + config.model_config.provider = "cpu"; + + config.decoding_method = "greedy_search"; + config.max_active_paths = 4; + + config.feat_config.sample_rate = 16000; + config.feat_config.feature_dim = 80; + + config.enable_endpoint = 1; + config.rule1_min_trailing_silence = 1.2f; + config.rule2_min_trailing_silence = 0.8f; + config.rule3_min_utterance_length = 300.0f; + + config.model_config.tokens = tokens.c_str(); + config.model_config.paraformer.encoder = paraformer_encoder.c_str(); + config.model_config.paraformer.decoder = paraformer_decoder.c_str(); + + recognizer_ = CreateOnlineRecognizer(&config); +} + +void CStreamingSpeechRecognitionDlg::InitRecognizer() { + if (Exists("./paraformer-encoder.onnx") || Exists("./paraformer-encoder.int8.onnx")) { + InitParaformer(); + return; + } + + std::string encoder = "./encoder.onnx"; + std::string decoder = "./decoder.onnx"; + std::string joiner = "./joiner.onnx"; + std::string tokens = "./tokens.txt"; + + bool is_ok = true; + if (!Exists(encoder)) { + std::string msg = encoder + " does not exist!"; + 
AppendLineToMultilineEditCtrl(msg); + is_ok = false; + } + + if (!Exists(decoder)) { + std::string msg = decoder + " does not exist!"; + AppendLineToMultilineEditCtrl(msg); + is_ok = false; + } + + if (!Exists(joiner)) { + std::string msg = joiner + " does not exist!"; + AppendLineToMultilineEditCtrl(msg); + is_ok = false; + } + + if (!Exists(tokens)) { + std::string msg = tokens + " does not exist!"; + AppendLineToMultilineEditCtrl(msg); + is_ok = false; + } + + if (!is_ok) { + ShowInitRecognizerHelpMessage(); return; } SherpaOnnxOnlineRecognizerConfig config; + memset(&config, 0, sizeof(config)); config.model_config.debug = 0; config.model_config.num_threads = 1; config.model_config.provider = "cpu"; @@ -331,9 +418,9 @@ void CStreamingSpeechRecognitionDlg::InitRecognizer() { config.rule3_min_utterance_length = 300.0f; config.model_config.tokens = tokens.c_str(); - config.model_config.encoder = encoder.c_str(); - config.model_config.decoder = decoder.c_str(); - config.model_config.joiner = joiner.c_str(); + config.model_config.transducer.encoder = encoder.c_str(); + config.model_config.transducer.decoder = decoder.c_str(); + config.model_config.transducer.joiner = joiner.c_str(); recognizer_ = CreateOnlineRecognizer(&config); } diff --git a/mfc-examples/StreamingSpeechRecognition/StreamingSpeechRecognitionDlg.h b/mfc-examples/StreamingSpeechRecognition/StreamingSpeechRecognitionDlg.h index 122b61983..36b3d9cad 100644 --- a/mfc-examples/StreamingSpeechRecognition/StreamingSpeechRecognitionDlg.h +++ b/mfc-examples/StreamingSpeechRecognition/StreamingSpeechRecognitionDlg.h @@ -67,6 +67,8 @@ class CStreamingSpeechRecognitionDlg : public CDialogEx { bool Exists(const std::string &filename); void InitRecognizer(); + void InitParaformer(); + void ShowInitRecognizerHelpMessage(); }; class RecognizerThread : public CWinThread { diff --git a/scripts/go/sherpa_onnx.go b/scripts/go/sherpa_onnx.go index 9320af5d0..a5ec4b529 100644 --- a/scripts/go/sherpa_onnx.go +++ b/scripts/go/sherpa_onnx.go @@ -45,9 +45,30 @@ import "unsafe" // https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/index.html // to download pre-trained models type OnlineTransducerModelConfig struct { - Encoder string // Path to the encoder model, e.g., encoder.onnx or encoder.int8.onnx - Decoder string // Path to the decoder model. - Joiner string // Path to the joiner model. + Encoder string // Path to the encoder model, e.g., encoder.onnx or encoder.int8.onnx + Decoder string // Path to the decoder model. + Joiner string // Path to the joiner model. +} + +// Configuration for online/streaming paraformer models +// +// Please refer to +// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/index.html +// to download pre-trained models +type OnlineParaformerModelConfig struct { + Encoder string // Path to the encoder model, e.g., encoder.onnx or encoder.int8.onnx + Decoder string // Path to the decoder model. +} + +// Configuration for online/streaming models +// +// Please refer to +// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/index.html +// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/index.html +// to download pre-trained models +type OnlineModelConfig struct { + Transducer OnlineTransducerModelConfig + Paraformer OnlineParaformerModelConfig Tokens string // Path to tokens.txt NumThreads int // Number of threads to use for neural network computation Provider string // Optional. 
Valid values are: cpu, cuda, coreml @@ -68,7 +89,7 @@ type FeatureConfig struct { // Configuration for the online/streaming recognizer. type OnlineRecognizerConfig struct { FeatConfig FeatureConfig - ModelConfig OnlineTransducerModelConfig + ModelConfig OnlineModelConfig // Valid decoding methods: greedy_search, modified_beam_search DecodingMethod string @@ -116,14 +137,20 @@ func NewOnlineRecognizer(config *OnlineRecognizerConfig) *OnlineRecognizer { c.feat_config.sample_rate = C.int(config.FeatConfig.SampleRate) c.feat_config.feature_dim = C.int(config.FeatConfig.FeatureDim) - c.model_config.encoder = C.CString(config.ModelConfig.Encoder) - defer C.free(unsafe.Pointer(c.model_config.encoder)) + c.model_config.transducer.encoder = C.CString(config.ModelConfig.Transducer.Encoder) + defer C.free(unsafe.Pointer(c.model_config.transducer.encoder)) + + c.model_config.transducer.decoder = C.CString(config.ModelConfig.Transducer.Decoder) + defer C.free(unsafe.Pointer(c.model_config.transducer.decoder)) + + c.model_config.transducer.joiner = C.CString(config.ModelConfig.Transducer.Joiner) + defer C.free(unsafe.Pointer(c.model_config.transducer.joiner)) - c.model_config.decoder = C.CString(config.ModelConfig.Decoder) - defer C.free(unsafe.Pointer(c.model_config.decoder)) + c.model_config.paraformer.encoder = C.CString(config.ModelConfig.Paraformer.Encoder) + defer C.free(unsafe.Pointer(c.model_config.paraformer.encoder)) - c.model_config.joiner = C.CString(config.ModelConfig.Joiner) - defer C.free(unsafe.Pointer(c.model_config.joiner)) + c.model_config.paraformer.decoder = C.CString(config.ModelConfig.Paraformer.Decoder) + defer C.free(unsafe.Pointer(c.model_config.paraformer.decoder)) c.model_config.tokens = C.CString(config.ModelConfig.Tokens) defer C.free(unsafe.Pointer(c.model_config.tokens)) diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index 0a3bc13f1..9d1f39196 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -265,6 +265,12 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( recognizer_config.model_config.nemo_ctc.model = SHERPA_ONNX_OR(config->model_config.nemo_ctc.model, ""); + recognizer_config.model_config.whisper.encoder = + SHERPA_ONNX_OR(config->model_config.whisper.encoder, ""); + + recognizer_config.model_config.whisper.decoder = + SHERPA_ONNX_OR(config->model_config.whisper.decoder, ""); + recognizer_config.model_config.tokens = SHERPA_ONNX_OR(config->model_config.tokens, ""); recognizer_config.model_config.num_threads = diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index 621b6a804..5bbd9fe2c 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -300,6 +300,11 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineNemoEncDecCtcModelConfig { const char *model; } SherpaOnnxOfflineNemoEncDecCtcModelConfig; +SHERPA_ONNX_API typedef struct SherpaOnnxOfflineWhisperModelConfig { + const char *encoder; + const char *decoder; +} SherpaOnnxOfflineWhisperModelConfig; + SHERPA_ONNX_API typedef struct SherpaOnnxOfflineLMConfig { const char *model; float scale; @@ -309,6 +314,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineModelConfig { SherpaOnnxOfflineTransducerModelConfig transducer; SherpaOnnxOfflineParaformerModelConfig paraformer; SherpaOnnxOfflineNemoEncDecCtcModelConfig nemo_ctc; + SherpaOnnxOfflineWhisperModelConfig whisper; const char *tokens; int32_t num_threads;
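
Usage note (outside the diff above): the following sketch shows how a Go caller would fill in the new OnlineModelConfig to run the streaming paraformer through the updated binding. It only uses fields and functions that appear in this patch (OnlineRecognizerConfig, FeatureConfig, ModelConfig.Paraformer, NewOnlineRecognizer); the import path and the model file names are assumptions taken from run-paraformer.sh, and the stream creation and decoding loop are left out (see go-api-examples/streaming-decode-files/main.go for the complete flow).

package main

import (
	sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx" // assumed import path
)

func main() {
	// Feature settings shared by the streaming examples in this patch.
	config := sherpa.OnlineRecognizerConfig{}
	config.FeatConfig = sherpa.FeatureConfig{SampleRate: 16000, FeatureDim: 80}

	// Streaming paraformer: set only Paraformer.Encoder/Decoder and leave the
	// Transducer fields empty. File names follow run-paraformer.sh.
	// run-paraformer.sh additionally passes --model-type paraformer; the
	// corresponding config field is not shown in this excerpt, so it is
	// omitted here.
	config.ModelConfig.Paraformer.Encoder = "./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.int8.onnx"
	config.ModelConfig.Paraformer.Decoder = "./sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.int8.onnx"
	config.ModelConfig.Tokens = "./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt"
	config.ModelConfig.NumThreads = 1
	config.DecodingMethod = "greedy_search"

	recognizer := sherpa.NewOnlineRecognizer(&config)
	_ = recognizer // create a stream, feed samples, and decode as in streaming-decode-files/main.go
}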