Commit

Merge pull request #197 from mobiusml/batched_whisper_PR
Batched whisper integration
Jiltseb authored Nov 4, 2024
2 parents 9d4952e + b631e14 commit a04c5be
Showing 12 changed files with 682 additions and 286 deletions.
169 changes: 75 additions & 94 deletions aana/deployments/whisper_deployment.py
@@ -1,9 +1,9 @@
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator # noqa: I001
from enum import Enum
from typing import Any, cast

import torch
from faster_whisper import WhisperModel
from faster_whisper import BatchedInferencePipeline, WhisperModel
from pydantic import BaseModel, ConfigDict, Field
from ray import serve
from typing_extensions import TypedDict
@@ -15,12 +15,21 @@
)
from aana.core.models.audio import Audio
from aana.core.models.base import pydantic_protected_fields
from aana.core.models.whisper import (
WhisperParams,
)
from aana.core.models.vad import VadSegment
from aana.core.models.whisper import BatchedWhisperParams, WhisperParams
from aana.deployments.base_deployment import BaseDeployment, exception_handler
from aana.exceptions.runtime import InferenceException

# Workaround for CUDNN issue with cTranslate2:
import os
import nvidia.cudnn.lib
from pathlib import Path

cudnn_path = str(Path(nvidia.cudnn.lib.__file__).parent)
os.environ["LD_LIBRARY_PATH"] = (
cudnn_path + "/:" + os.environ.get("LD_LIBRARY_PATH", "")
)


class WhisperComputeType(str, Enum):
"""The data type used by whisper models.
@@ -149,6 +158,9 @@ async def apply_config(self, config: dict[str, Any]):
self.model = WhisperModel(
self.model_size, device=self.device, compute_type=self.compute_type
)
self.batched_model = BatchedInferencePipeline(
model=self.model,
)

@exception_handler
async def transcribe(
@@ -282,92 +294,61 @@ async def transcribe_batch(
segments=segments, transcription_info=infos, transcription=transcriptions
)

# TODO: Update once batched whisper PR is merged
# async def transcribe_in_chunks(
# self,
# audio: Audio,
# segments: list[VadSegment],
# batch_size: int = 16,
# params: BatchedWhisperParams | None = None,
# ) -> AsyncGenerator[WhisperOutput, None]:
# """Transcribe a single audio by segmenting it into chunks (4x faster) in streaming mode.

# Args:
# audio (Audio): The audio to transcribe.
# segments (list[VadSegment]): List of segments to guide batching the audio data.
# batch_size (int): Maximum batch size for the batched inference.
# params (WhisperParams): The parameters for the whisper model.

# Yields:
# WhisperOutput: The transcription output as a dictionary:
# segments (list[AsrSegment]): The ASR segments.
# transcription_info (AsrTranscriptionInfo): The ASR transcription info.
# transcription (AsrTranscription): The ASR transcription.

# Raises:
# InferenceException: If the inference fails.
# """
# try:
# from faster_whisper import BatchedInferencePipeline
# except ImportError as e:
# raise ImportError(
# "Batched version of whisper is not available. "
# "Install faster-whisper from https://github.com/mobiusml/faster-whisper"
# ) from e

# if not params:
# params = BatchedWhisperParams()

# if params.language is not None:
# tokenizer = Tokenizer(
# self.model.hf_tokenizer,
# self.model.model.is_multilingual,
# task="transcribe",
# language=params.language,
# )
# else:
# # If no language is specified, language will be first detected for each audio.
# tokenizer = None

# self.batched_model = BatchedInferencePipeline(
# model=self.model,
# use_vad_model=False,
# options=None,
# tokenizer=tokenizer,
# language=params.language,
# )

# audio_array = audio.get_numpy()

# vad_input = [seg.to_whisper_dict() for seg in segments]
# if not vad_input:
# # For silent audios/no audio tracks, return empty output with language as silence
# yield WhisperOutput(
# segments=[],
# transcription_info=AsrTranscriptionInfo(
# language="silence", language_confidence=1.0
# ),
# transcription=AsrTranscription(text=""),
# )
# else:
# try:
# result = self.batched_model.transcribe(
# audio_array,
# vad_segments=vad_input,
# batch_size=batch_size,
# **params.model_dump(),
# )
# except Exception as e:
# raise InferenceException(self.model_name) from e

# for count, (segment, info) in enumerate(result):
# if count == 0:
# asr_transcription_info = AsrTranscriptionInfo.from_whisper(info)
# asr_segments = [AsrSegment.from_whisper(segment)]
# asr_transcription = AsrTranscription(text=segment.text)

# yield WhisperOutput(
# segments=asr_segments,
# transcription_info=asr_transcription_info,
# transcription=asr_transcription,
# )
@exception_handler
async def transcribe_in_chunks(
self,
audio: Audio,
vad_segments: list[VadSegment] | None = None,
batch_size: int = 16,
params: BatchedWhisperParams | None = None,
) -> AsyncGenerator[WhisperOutput, None]:
"""Transcribe a single audio by segmenting it into chunks (4x faster) in streaming mode.
Args:
audio (Audio): The audio to transcribe.
vad_segments (list[VadSegment]| None): List of vad segments to guide batching of the audio data.
batch_size (int): Maximum batch size for the batched inference.
params (BatchedWhisperParams | None): The parameters for the batched inference with whisper model.
Yields:
WhisperOutput: The transcription output as a dictionary:
segments (list[AsrSegment]): The ASR segments.
transcription_info (AsrTranscriptionInfo): The ASR transcription info.
transcription (AsrTranscription): The ASR transcription.
Raises:
InferenceException: If the inference fails.
"""
if not params:
params = BatchedWhisperParams()
audio_array = audio.get_numpy()
if vad_segments:
vad_input = [seg.to_whisper_dict() for seg in vad_segments]
if not audio_array.any():
# For silent audios/no audio tracks, return empty output with language as silence
yield WhisperOutput(
segments=[],
transcription_info=AsrTranscriptionInfo(
language="silence", language_confidence=1.0
),
transcription=AsrTranscription(text=""),
)
else:
try:
segments, info = self.batched_model.transcribe(
audio_array,
vad_segments=vad_input if vad_segments else None,
batch_size=batch_size,
**params.model_dump(),
)
except Exception as e:
raise InferenceException(self.model_name) from e
asr_transcription_info = AsrTranscriptionInfo.from_whisper(info)
for segment in segments:
asr_segments = [AsrSegment.from_whisper(segment)]
asr_transcription = AsrTranscription(text=segment.text)
yield WhisperOutput(
segments=asr_segments,
transcription_info=asr_transcription_info,
transcription=asr_transcription,
)
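
For context, here is a minimal client-side sketch of how the new streaming endpoint might be consumed, mirroring the calls exercised in the test file further below. The `handle` (an AanaDeploymentHandle for this deployment) and `audio` (an Audio instance) are assumed to be created elsewhere, and the chunk-merging helper is illustrative only, not part of this PR.

from collections import defaultdict

from aana.core.models.base import pydantic_to_dict
from aana.core.models.whisper import BatchedWhisperParams


async def collect_batched_transcription(handle, audio, vad_segments=None):
    """Drain transcribe_in_chunks and merge the streamed chunks into one dict.

    Passing vad_segments=None lets the batched pipeline fall back to whisper's
    internal VAD; otherwise the given VadSegment list guides the batching.
    """
    stream = handle.transcribe_in_chunks(
        audio=audio,
        vad_segments=vad_segments,
        batch_size=16,
        params=BatchedWhisperParams(temperature=0.0),
    )
    merged = defaultdict(list)
    transcript = ""
    async for chunk in stream:
        output = pydantic_to_dict(chunk)
        transcript += output["transcription"]["text"]
        merged["segments"].extend(output.get("segments", []))
    merged["transcription"] = {"text": transcript}
    return dict(merged)
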
1 change: 0 additions & 1 deletion aana/sdk.py
@@ -23,7 +23,6 @@
DeploymentException,
EmptyMigrationsException,
FailedDeployment,
InferenceException,
InsufficientResources,
)
from aana.storage.op import run_alembic_migrations
108 changes: 57 additions & 51 deletions aana/tests/deployments/test_whisper_deployment.py
@@ -1,12 +1,15 @@
# ruff: noqa: S101
import json
from collections import defaultdict
from importlib import resources
from pathlib import Path

import pytest

from aana.core.models.audio import Audio
from aana.core.models.base import pydantic_to_dict
from aana.core.models.whisper import WhisperParams
from aana.core.models.vad import VadSegment
from aana.core.models.whisper import BatchedWhisperParams, WhisperParams
from aana.deployments.aana_deployment_handle import AanaDeploymentHandle
from aana.deployments.whisper_deployment import (
WhisperComputeType,
@@ -99,53 +102,56 @@ async def test_transcribe(self, setup_deployment, audio_file):

verify_deployment_results(expected_output_path, grouped_dict)

# Test transcribe_batch method

# Test transcribe_in_chunks method: Note that the expected asr output is different
# TODO: Update once batched whisper PR is merged
# expected_batched_output_path = resources.path(
# f"aana.tests.files.expected.whisper.{model_size}",
# f"{audio_file_name}_batched.json",
# )
# assert (
# expected_batched_output_path.exists()
# ), f"Expected output not found: {expected_batched_output_path}"
# with Path(expected_batched_output_path) as path, path.open() as f:
# expected_output_batched = json.load(f)

# # Get expected vad segments
# vad_path = resources.path(
# "aana.tests.files.expected.vad", f"{audio_file_name}_vad.json"
# )
# assert vad_path.exists(), f"vad expected predictions not found: {vad_path}"

# with Path(vad_path) as path, path.open() as f:
# expected_output_vad = json.load(f)

# final_input = [
# VadSegment(time_interval=seg["time_interval"], segments=seg["segments"])
# for seg in expected_output_vad["segments"]
# ]

# batched_stream = handle.options(stream=True).transcribe_in_chunks.remote(
# audio=audio,
# segments=final_input,
# batch_size=16,
# params=BatchedWhisperParams(temperature=0.0),
# )

# # Combine individual segments and compare with the final dict
# transcript = ""
# grouped_dict = defaultdict(list)
# async for chunk in batched_stream:
# output = pydantic_to_dict(chunk)
# transcript += output["transcription"]["text"]
# grouped_dict["segments"].extend(output.get("segments", []))

# grouped_dict["transcription"] = {"text": transcript}
# grouped_dict["transcription_info"] = output.get("transcription_info")

# compare_transcriptions(
# expected_output_batched,
# dict(grouped_dict),
# )
# Test transcribe_in_chunks method: Note that the expected asr output is different
expected_output_path = (
resources.files("aana.tests.files.expected")
/ "whisper"
/ f"{deployment_name}_{audio_file}_batched.json"
)
audio_file_name = audio_file.removesuffix(".wav")
# Get expected vad segments
vad_path = (
resources.files("aana.tests.files.expected.vad") / f"{audio_file_name}.json"
)

assert vad_path.exists(), f"vad expected predictions not found: {vad_path}"
with Path(vad_path) as path, path.open() as f:
expected_output_vad = json.load(f)
vad_input = [
VadSegment(time_interval=seg["time_interval"], segments=seg["segments"])
for seg in expected_output_vad["segments"]
]

batched_stream = handle.transcribe_in_chunks(
audio=audio,
vad_segments=vad_input,
batch_size=16,
params=BatchedWhisperParams(temperature=0.0),
)
# Combine individual segments and compare with the final dict
grouped_dict = defaultdict(list)
transcript = ""
async for chunk in batched_stream:
output = pydantic_to_dict(chunk)
transcript += output["transcription"]["text"]
grouped_dict["segments"].extend(output.get("segments", []))
grouped_dict["transcription"] = {"text": transcript}
grouped_dict["transcription_info"] = output.get("transcription_info")
verify_deployment_results(expected_output_path, grouped_dict)

# Test with whisper internal VAD
batched_stream = handle.transcribe_in_chunks(
audio=audio,
batch_size=16,
params=BatchedWhisperParams(temperature=0.0),
)
# Combine individual segments and compare with the final dict
grouped_dict = defaultdict(list)
transcript = ""
async for chunk in batched_stream:
output = pydantic_to_dict(chunk)
transcript += output["transcription"]["text"]
grouped_dict["segments"].extend(output.get("segments", []))
grouped_dict["transcription"] = {"text": transcript}
grouped_dict["transcription_info"] = output.get("transcription_info")
verify_deployment_results(expected_output_path, grouped_dict)
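
As a side note, the VAD fixture consumed above is converted into the VadSegment inputs expected by transcribe_in_chunks roughly as sketched here. The JSON layout (a top-level "segments" list whose items carry "time_interval" and "segments" keys) is inferred from the test code, and the fixture path is a hypothetical placeholder.

import json
from pathlib import Path

from aana.core.models.vad import VadSegment

# Hypothetical fixture path; the real test resolves it via importlib.resources.
vad_json = json.loads(Path("some_audio.json").read_text())

# Each entry becomes a VadSegment that guides how the audio is batched.
vad_input = [
    VadSegment(time_interval=seg["time_interval"], segments=seg["segments"])
    for seg in vad_json["segments"]
]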