Skip to content

Commit

Permalink
Merge remote-tracking branch 'fork/pyav' into update
Browse files Browse the repository at this point in the history
  • Loading branch information
Luis Nunez committed Oct 16, 2024
2 parents 64852b5 + ab192e7 commit 88ac65e
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 10 deletions.
88 changes: 79 additions & 9 deletions faster_whisper/audio.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@
"""We use the PyAV library to decode the audio: https://github.com/PyAV-Org/PyAV
The advantage of PyAV is that it bundles the FFmpeg libraries so there is no additional
system dependencies. FFmpeg does not need to be installed on the system.
However, the API is quite low-level so we need to manipulate audio frames directly.
"""

import gc
import io
import itertools

from typing import BinaryIO, Union

import av
import numpy as np
import torch
import torchaudio


def decode_audio(
Expand All @@ -17,22 +30,79 @@ def decode_audio(
split_stereo: Return separate left and right channels.
Returns:
A float32 Torch Tensor.
A float32 Numpy array.
If `split_stereo` is enabled, the function returns a 2-tuple with the
separated left and right channels.
"""
resampler = av.audio.resampler.AudioResampler(
format="s16",
layout="mono" if not split_stereo else "stereo",
rate=sampling_rate,
)

waveform, audio_sf = torchaudio.load(input_file) # waveform: channels X T
raw_buffer = io.BytesIO()
dtype = None

with av.open(input_file, mode="r", metadata_errors="ignore") as container:
frames = container.decode(audio=0)
frames = _ignore_invalid_frames(frames)
frames = _group_frames(frames, 500000)
frames = _resample_frames(frames, resampler)

for frame in frames:
array = frame.to_ndarray()
dtype = array.dtype
raw_buffer.write(array)

# It appears that some objects related to the resampler are not freed
# unless the garbage collector is manually run.
del resampler
gc.collect()

audio = np.frombuffer(raw_buffer.getbuffer(), dtype=dtype)

# Convert s16 back to f32.
audio = audio.astype(np.float32) / 32768.0

if audio_sf != sampling_rate:
waveform = torchaudio.functional.resample(
waveform, orig_freq=audio_sf, new_freq=sampling_rate
)
if split_stereo:
return waveform[0], waveform[1]
left_channel = audio[0::2]
right_channel = audio[1::2]
return torch.from_numpy(left_channel), torch.from_numpy(right_channel)

return torch.from_numpy(audio)


def _ignore_invalid_frames(frames):
iterator = iter(frames)

while True:
try:
yield next(iterator)
except StopIteration:
break
except av.error.InvalidDataError:
continue


def _group_frames(frames, num_samples=None):
fifo = av.audio.fifo.AudioFifo()

for frame in frames:
frame.pts = None # Ignore timestamp check.
fifo.write(frame)

if num_samples is not None and fifo.samples >= num_samples:
yield fifo.read()

if fifo.samples > 0:
yield fifo.read()


return waveform.mean(0)
def _resample_frames(frames, resampler):
# Add None to flush the resampler.
for frame in itertools.chain(frames, [None]):
yield from resampler.resample(frame)


def pad_or_trim(array, length: int, *, axis: int = -1):
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ huggingface_hub>=0.13
tokenizers>=0.13,<1
onnxruntime>=1.14,<2
torch>=2.1.1
torchaudio>=2.1.2
av>=11
tqdm

0 comments on commit 88ac65e

Please sign in to comment.