Skip to content

Commit

Permalink
batched whisper tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Jiltseb committed Oct 29, 2024
1 parent ed92a8f commit 722da82
Showing 1 changed file with 66 additions and 52 deletions.
118 changes: 66 additions & 52 deletions aana/tests/deployments/test_whisper_deployment.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
# ruff: noqa: S101
import json
from collections import defaultdict
from importlib import resources
from pathlib import Path

import pytest

from aana.core.models.audio import Audio
from aana.core.models.base import pydantic_to_dict
from aana.core.models.whisper import WhisperParams
from aana.core.models.vad import VadSegment
from aana.core.models.whisper import BatchedWhisperParams, WhisperParams
from aana.deployments.aana_deployment_handle import AanaDeploymentHandle
from aana.deployments.whisper_deployment import (
WhisperComputeType,
Expand Down Expand Up @@ -57,7 +60,7 @@ class TestWhisperDeployment:
"""Test Whisper deployment."""

@pytest.mark.asyncio
@pytest.mark.parametrize("audio_file", ["squirrel.wav", "physicsworks.wav"])
@pytest.mark.parametrize("audio_file", ["physicsworks.wav", "squirrel.wav"])
async def test_transcribe(self, setup_deployment, audio_file):
"""Test transcribe methods."""
deployment_name, handle_name, _ = setup_deployment
Expand Down Expand Up @@ -99,53 +102,64 @@ async def test_transcribe(self, setup_deployment, audio_file):

verify_deployment_results(expected_output_path, grouped_dict)

# Test transcribe_batch method

# Test transcribe_in_chunks method: Note that the expected asr output is different
# TODO: Update once batched whisper PR is merged
# expected_batched_output_path = resources.path(
# f"aana.tests.files.expected.whisper.{model_size}",
# f"{audio_file_name}_batched.json",
# )
# assert (
# expected_batched_output_path.exists()
# ), f"Expected output not found: {expected_batched_output_path}"
# with Path(expected_batched_output_path) as path, path.open() as f:
# expected_output_batched = json.load(f)

# # Get expected vad segments
# vad_path = resources.path(
# "aana.tests.files.expected.vad", f"{audio_file_name}_vad.json"
# )
# assert vad_path.exists(), f"vad expected predictions not found: {vad_path}"

# with Path(vad_path) as path, path.open() as f:
# expected_output_vad = json.load(f)

# final_input = [
# VadSegment(time_interval=seg["time_interval"], segments=seg["segments"])
# for seg in expected_output_vad["segments"]
# ]

# batched_stream = handle.options(stream=True).transcribe_in_chunks.remote(
# audio=audio,
# segments=final_input,
# batch_size=16,
# params=BatchedWhisperParams(temperature=0.0),
# )

# # Combine individual segments and compare with the final dict
# transcript = ""
# grouped_dict = defaultdict(list)
# async for chunk in batched_stream:
# output = pydantic_to_dict(chunk)
# transcript += output["transcription"]["text"]
# grouped_dict["segments"].extend(output.get("segments", []))

# grouped_dict["transcription"] = {"text": transcript}
# grouped_dict["transcription_info"] = output.get("transcription_info")

# compare_transcriptions(
# expected_output_batched,
# dict(grouped_dict),
# )
# Test transcribe_batch method

# Test transcribe_in_chunks method. Note that the expected ASR output differs from the non-batched transcribe output, so it is read from a separate "*_batched.json" file.
expected_output_path = (
resources.files("aana.tests.files.expected")
/ "whisper"
/ f"{deployment_name}_{audio_file}_batched.json"
)
audio_file_name = audio_file.removesuffix(".wav")

# Get expected vad segments
vad_path = resources.path(
"aana.tests.files.expected.vad", f"{audio_file_name}.json"
)
assert vad_path.exists(), f"vad expected predictions not found: {vad_path}"

with Path(vad_path) as path, path.open() as f:
expected_output_vad = json.load(f)

vad_input = [
VadSegment(time_interval=seg["time_interval"], segments=seg["segments"])
for seg in expected_output_vad["segments"]
]

batched_stream = handle.transcribe_in_chunks(
audio=audio,
vad_segments=vad_input,
batch_size=16,
params=BatchedWhisperParams(temperature=0.0),
)

# Combine individual segments and compare with the final dict
grouped_dict = defaultdict(list)
transcript = ""
async for chunk in batched_stream:
output = pydantic_to_dict(chunk)
transcript += output["transcription"]["text"]
grouped_dict["segments"].extend(output.get("segments", []))

grouped_dict["transcription"] = {"text": transcript}
grouped_dict["transcription_info"] = output.get("transcription_info")
verify_deployment_results(expected_output_path, grouped_dict)

# Test with Whisper's internal VAD: no vad_segments are passed, and the result is checked against the same expected batched output.
batched_stream = handle.transcribe_in_chunks(
audio=audio,
batch_size=16,
params=BatchedWhisperParams(temperature=0.0),
)

# Combine individual segments and compare with the final dict
grouped_dict = defaultdict(list)
transcript = ""
async for chunk in batched_stream:
output = pydantic_to_dict(chunk)
transcript += output["transcription"]["text"]
grouped_dict["segments"].extend(output.get("segments", []))

grouped_dict["transcription"] = {"text": transcript}
grouped_dict["transcription_info"] = output.get("transcription_info")
verify_deployment_results(expected_output_path, grouped_dict)

0 comments on commit 722da82

Please sign in to comment.