From 722da8210fcfe1088c0b35387b2e8f80051306ee Mon Sep 17 00:00:00 2001 From: jiltseb Date: Tue, 29 Oct 2024 13:01:48 +0000 Subject: [PATCH] batched whisper tests --- .../deployments/test_whisper_deployment.py | 118 ++++++++++-------- 1 file changed, 66 insertions(+), 52 deletions(-) diff --git a/aana/tests/deployments/test_whisper_deployment.py b/aana/tests/deployments/test_whisper_deployment.py index 824157e4..b3fe7871 100644 --- a/aana/tests/deployments/test_whisper_deployment.py +++ b/aana/tests/deployments/test_whisper_deployment.py @@ -1,12 +1,15 @@ # ruff: noqa: S101 +import json from collections import defaultdict from importlib import resources +from pathlib import Path import pytest from aana.core.models.audio import Audio from aana.core.models.base import pydantic_to_dict -from aana.core.models.whisper import WhisperParams +from aana.core.models.vad import VadSegment +from aana.core.models.whisper import BatchedWhisperParams, WhisperParams from aana.deployments.aana_deployment_handle import AanaDeploymentHandle from aana.deployments.whisper_deployment import ( WhisperComputeType, @@ -57,7 +60,7 @@ class TestWhisperDeployment: """Test Whisper deployment.""" @pytest.mark.asyncio - @pytest.mark.parametrize("audio_file", ["squirrel.wav", "physicsworks.wav"]) + @pytest.mark.parametrize("audio_file", ["physicsworks.wav", "squirrel.wav"]) async def test_transcribe(self, setup_deployment, audio_file): """Test transcribe methods.""" deployment_name, handle_name, _ = setup_deployment @@ -99,53 +102,64 @@ async def test_transcribe(self, setup_deployment, audio_file): verify_deployment_results(expected_output_path, grouped_dict) - # Test transcribe_batch method - - # Test transcribe_in_chunks method: Note that the expected asr output is different - # TODO: Update once batched whisper PR is merged - # expected_batched_output_path = resources.path( - # f"aana.tests.files.expected.whisper.{model_size}", - # f"{audio_file_name}_batched.json", - # ) - # assert ( - # expected_batched_output_path.exists() - # ), f"Expected output not found: {expected_batched_output_path}" - # with Path(expected_batched_output_path) as path, path.open() as f: - # expected_output_batched = json.load(f) - - # # Get expected vad segments - # vad_path = resources.path( - # "aana.tests.files.expected.vad", f"{audio_file_name}_vad.json" - # ) - # assert vad_path.exists(), f"vad expected predictions not found: {vad_path}" - - # with Path(vad_path) as path, path.open() as f: - # expected_output_vad = json.load(f) - - # final_input = [ - # VadSegment(time_interval=seg["time_interval"], segments=seg["segments"]) - # for seg in expected_output_vad["segments"] - # ] - - # batched_stream = handle.options(stream=True).transcribe_in_chunks.remote( - # audio=audio, - # segments=final_input, - # batch_size=16, - # params=BatchedWhisperParams(temperature=0.0), - # ) - - # # Combine individual segments and compare with the final dict - # transcript = "" - # grouped_dict = defaultdict(list) - # async for chunk in batched_stream: - # output = pydantic_to_dict(chunk) - # transcript += output["transcription"]["text"] - # grouped_dict["segments"].extend(output.get("segments", [])) - - # grouped_dict["transcription"] = {"text": transcript} - # grouped_dict["transcription_info"] = output.get("transcription_info") - - # compare_transcriptions( - # expected_output_batched, - # dict(grouped_dict), - # ) + # Test transcribe_batch method + + # Test transcribe_in_chunks method: Note that the expected asr output is different + expected_output_path = ( + resources.files("aana.tests.files.expected") + / "whisper" + / f"{deployment_name}_{audio_file}_batched.json" + ) + audio_file_name = audio_file.removesuffix(".wav") + + # Get expected vad segments + vad_path = resources.path( + "aana.tests.files.expected.vad", f"{audio_file_name}.json" + ) + assert vad_path.exists(), f"vad expected predictions not found: {vad_path}" + + with Path(vad_path) as path, path.open() as f: + expected_output_vad = json.load(f) + + vad_input = [ + VadSegment(time_interval=seg["time_interval"], segments=seg["segments"]) + for seg in expected_output_vad["segments"] + ] + + batched_stream = handle.transcribe_in_chunks( + audio=audio, + vad_segments=vad_input, + batch_size=16, + params=BatchedWhisperParams(temperature=0.0), + ) + + # Combine individual segments and compare with the final dict + grouped_dict = defaultdict(list) + transcript = "" + async for chunk in batched_stream: + output = pydantic_to_dict(chunk) + transcript += output["transcription"]["text"] + grouped_dict["segments"].extend(output.get("segments", [])) + + grouped_dict["transcription"] = {"text": transcript} + grouped_dict["transcription_info"] = output.get("transcription_info") + verify_deployment_results(expected_output_path, grouped_dict) + + # Test with whisper internal VAD + batched_stream = handle.transcribe_in_chunks( + audio=audio, + batch_size=16, + params=BatchedWhisperParams(temperature=0.0), + ) + + # Combine individual segments and compare with the final dict + grouped_dict = defaultdict(list) + transcript = "" + async for chunk in batched_stream: + output = pydantic_to_dict(chunk) + transcript += output["transcription"]["text"] + grouped_dict["segments"].extend(output.get("segments", [])) + + grouped_dict["transcription"] = {"text": transcript} + grouped_dict["transcription_info"] = output.get("transcription_info") + verify_deployment_results(expected_output_path, grouped_dict)