diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index 8652e82b..e1330e36 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -100,6 +100,16 @@ class TranscriptionInfo(NamedTuple): vad_options: VadOptions +def find_numeral_symbol_tokens(tokenizer: Tokenizer): + numeral_symbol_tokens = [] + for i in range(tokenizer.eot): + token = tokenizer.decode([i]).removeprefix(" ") + has_numeral_symbol = any(c in "0123456789%$£¥" for c in token) + if has_numeral_symbol: + numeral_symbol_tokens.append(i) + return numeral_symbol_tokens + + # The code below is originally from HF pipeline and is used in whisper-x # (https://github.com/m-bain/whisperX) and adapted for faster_whisper @@ -314,6 +324,7 @@ def transcribe( prefix: Optional[str] = None, suppress_blank: bool = True, suppress_tokens: Optional[List[int]] = [-1], + suppress_numerals: bool = False, prepend_punctuations: str = "\"'“¿([{-", append_punctuations: str = "\"'.。,,!!??::”)]}、", max_new_tokens: Optional[int] = None, @@ -485,7 +496,9 @@ def transcribe( initial_prompt=initial_prompt, prefix=prefix, suppress_blank=suppress_blank, - suppress_tokens=get_suppressed_tokens(self.tokenizer, suppress_tokens), + suppress_tokens=get_suppressed_tokens( + self.tokenizer, suppress_tokens, suppress_numerals + ), prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, max_new_tokens=max_new_tokens, @@ -2110,6 +2123,7 @@ def get_compression_ratio(text: str) -> float: def get_suppressed_tokens( tokenizer: Tokenizer, suppress_tokens: Tuple[int], + suppress_numerals: bool = False, ) -> Optional[List[int]]: if -1 in suppress_tokens: suppress_tokens = [t for t in suppress_tokens if t >= 0] @@ -2129,6 +2143,12 @@ def get_suppressed_tokens( ] ) + # This is not present in the original faster_whisper implementation + # Follows the same logic as whisperx + if suppress_numerals: + numeral_symbol_tokens = find_numeral_symbol_tokens(tokenizer) + suppress_tokens.extend(numeral_symbol_tokens) + return tuple(sorted(set(suppress_tokens)))