Skip to content

Commit

Permalink
Add inference class (#3)
Browse files Browse the repository at this point in the history
* Add inference class
  • Loading branch information
thunn authored Nov 12, 2024
1 parent 251e907 commit 4c20e8b
Show file tree
Hide file tree
Showing 10 changed files with 208 additions and 17 deletions.
41 changes: 41 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,43 @@
# BigVGANInference
An unofficial minimal package for using BigVGAN at inference time

[![PyPI version](https://img.shields.io/pypi/v/bigvganinference)](https://pypi.org/project/bigvganinference/)
![License](https://img.shields.io/pypi/l/bigvganinference)
![Python versions](https://img.shields.io/pypi/pyversions/bigvganinference)

## Installation

```python
pip install bigvganinference
```

or install from source:

```python
git clone https://github.com/thunn/BigVGANInference.git
cd BigVGANInference
poetry install
```

## Usage

Loading model is as simple as:
```python
from bigvganinference import BigVGANInference, BigVGANHFModel

# model is loaded, set to eval and weight norm is removed
model = BigVGANInference.from_pretrained(
BigVGANHFModel.V2_44KHZ_128BAND_512X, use_cuda_kernel=False
)


output_audio = model(input_mel)
```

See the [example](https://github.com/thunn/BigVGANInference/blob/main/example/inference.py) for full usage example.

## Acknowledgements
This is an unofficial implementation based on [original BigVGAN repository](https://github.com/NVIDIA/BigVGAN).

## License
This project is licensed under the MIT License. See the [LICENSE](https://github.com/thunn/BigVGANInference/blob/main/LICENSE) file for details.
4 changes: 4 additions & 0 deletions bigvganinference/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import importlib.metadata

from bigvganinference.inference import BigVGANHFModel, BigVGANInference

__version__ = importlib.metadata.version("bigvganinference")

__all__ = ["BigVGANInference", "BigVGANHFModel"]
4 changes: 2 additions & 2 deletions bigvganinference/alias_free_activation/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
# LICENSE is in incl_licenses directory.

from .act import Activation1d
from .filter import LowPassFilter1d, Sinc, kaiser_sinc_filter1d
from .filter import LowPassFilter1d, kaiser_sinc_filter1d, sinc
from .resample import DownSample1d, UpSample1d

__all__ = [
"Activation1d",
"Sinc",
"sinc",
"LowPassFilter1d",
"kaiser_sinc_filter1d",
"UpSample1d",
Expand Down
6 changes: 3 additions & 3 deletions bigvganinference/bigvgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(

# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
if self.h.get("use_cuda_kernel", False):
from alias_free_activation.cuda.activation1d import (
from bigvganinference.alias_free_activation.cuda.activation1d import (
Activation1d as CudaActivation1d,
)

Expand Down Expand Up @@ -242,7 +242,7 @@ def __init__(self, h: AttrDict, use_cuda_kernel: bool = False):

# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
if self.h.get("use_cuda_kernel", False):
from alias_free_activation.cuda.activation1d import (
from bigvganinference.alias_free_activation.cuda.activation1d import (
Activation1d as CudaActivation1d,
)

Expand Down Expand Up @@ -437,7 +437,7 @@ def _from_pretrained(
local_files_only=local_files_only,
)

checkpoint_dict = torch.load(model_file, map_location=map_location)
checkpoint_dict = torch.load(model_file, map_location=map_location, weights_only=True)

try:
model.load_state_dict(checkpoint_dict["generator"])
Expand Down
71 changes: 71 additions & 0 deletions bigvganinference/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from enum import Enum

import numpy as np
import torch

from bigvganinference.bigvgan import BigVGAN
from bigvganinference.env import AttrDict
from bigvganinference.meldataset import get_mel_spectrogram


class BigVGANHFModel(str, Enum):
"""
BigVGAN HF models.
"""

V2_44KHZ_128BAND_512X = "nvidia/bigvgan_v2_44khz_128band_512x"
V2_44KHZ_128BAND_256X = "nvidia/bigvgan_v2_44khz_128band_256x"
V2_24KHZ_100BAND_256X = "nvidia/bigvgan_v2_24khz_100band_256x"
V2_22KHZ_80BAND_256X = "nvidia/bigvgan_v2_22khz_80band_256x"
V2_22KHZ_80BAND_FMAX8K_256X = "nvidia/bigvgan_v2_22khz_80band_fmax8k_256x"
V2_24KHZ_100BAND = "nvidia/bigvgan_24khz_100band"
V2_22KHZ_80BAND = "nvidia/bigvgan_22khz_80band"
BASE_24KHZ_100BAND = "nvidia/bigvgan_base_24khz_100band"
BASE_22KHZ_80BAND = "nvidia/bigvgan_base_22khz_80band"

def __str__(self):
return self.value


class BigVGANInference(BigVGAN):
"""
BigVGAN inference.
"""

def __init__(self, h: AttrDict, use_cuda_kernel: bool = False):
super().__init__(h, use_cuda_kernel)

# set to eval and remove weight norm
self.eval()
self.remove_weight_norm()

def get_mel_spectrogram(self, wav: torch.Tensor | np.ndarray) -> torch.Tensor:
"""
Wrapper function to preprocess audio and convert to mel spectrogram.
Args:
wav (torch.Tensor | np.ndarray): Audio waveform.
Returns:
torch.Tensor: Mel spectrogram.
"""

# ensure wav is FloatTensor with shape [B(1), T_time]
if isinstance(wav, np.ndarray):
wav = torch.from_numpy(wav)

# If batch dimension is missing, add it
if wav.ndim == 1:
wav = wav.unsqueeze(0)

# ensure that audio is mono (batch size of 1)
if wav.shape[0] > 1:
wav = wav.mean(dim=0).unsqueeze(0)

mel = get_mel_spectrogram(wav, self.h)

# ensure mel is on the same device as the model
device = next(self.parameters()).device
mel = mel.to(device)

return mel
2 changes: 1 addition & 1 deletion bigvganinference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def get_padding(kernel_size, dilation=1):
def load_checkpoint(filepath, device):
assert os.path.isfile(filepath)
print(f"Loading '{filepath}'")
checkpoint_dict = torch.load(filepath, map_location=device)
checkpoint_dict = torch.load(filepath, map_location=device, weights_only=True)
print("Complete.")
return checkpoint_dict

Expand Down
20 changes: 20 additions & 0 deletions example/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import librosa
import torch

from bigvganinference import BigVGANHFModel, BigVGANInference

model = BigVGANInference.from_pretrained(BigVGANHFModel.V2_22KHZ_80BAND_FMAX8K_256X, use_cuda_kernel=False)


wav_path = "example/example.wav"
wav, sr = librosa.load(wav_path, sr=model.h.sampling_rate, mono=True) # wav is np.ndarray with shape [T_time] and values in [-1, 1]

# Note this implemntation is a wrapper around the get_mel_spectrogram function
# additional audio preprocessing is done to ensure the input is in the correct format
mel = model.get_mel_spectrogram(wav)

with torch.inference_mode():
wav_gen = model(mel)

wav_gen_float = wav_gen.squeeze(0).cpu() # wav_gen is FloatTensor with shape [1, T_time]
wav_gen_int16 = (wav_gen_float * 32767.0).numpy().astype("int16") # wav_gen is now np.ndarray with shape [1, T_time] and int16 dtype
27 changes: 27 additions & 0 deletions example/vanilla_bigvgan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import librosa
import torch

from bigvganinference import bigvgan
from bigvganinference.meldataset import get_mel_spectrogram

# instantiate the model. You can optionally set use_cuda_kernel=True for faster inference.
model = bigvgan.BigVGAN.from_pretrained("nvidia/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False)
device = "cpu"

model.remove_weight_norm()
model = model.eval().to(device)


wav_path = "example/example.wav"
wav, sr = librosa.load(wav_path, sr=model.h.sampling_rate, mono=True) # wav is np.ndarray with shape [T_time] and values in [-1, 1]
wav = torch.FloatTensor(wav).unsqueeze(0) # wav is FloatTensor with shape [B(1), T_time]

mel = get_mel_spectrogram(wav, model.h).to(device) # mel is FloatTensor with shape [B(1), C_mel, T_frame]

# generate waveform from mel
with torch.inference_mode():
wav_gen = model(mel) # wav_gen is FloatTensor with shape [B(1), 1, T_time] and values in [-1, 1]
wav_gen_float = wav_gen.squeeze(0).cpu() # wav_gen is FloatTensor with shape [1, T_time]

# you can convert the generated waveform to 16 bit linear PCM
wav_gen_int16 = (wav_gen_float * 32767.0).numpy().astype("int16") # wav_gen is now np.ndarray with shape [1, T_time] and int16 dtype
45 changes: 36 additions & 9 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
[tool.poetry]
name = "bigvganinference"
version = "0.1.0"
version = "0.0.1"
description = "An unofficial minimal package for using BigVGAN at inference time"
authors = ["Tom Hunn <thunn_on_github>"]
authors = []
license = "MIT"
readme = "README.md"
packages = [{ include = "bigvganinference" }]
Expand All @@ -13,6 +13,7 @@ torch = ">=2.3.1"
librosa = ">=0.8.1"
scipy = "^1.14.1"
huggingface-hub = "^0.26.0"
ninja = "^1.11.1.1"


[tool.poetry.group.dev.dependencies]
Expand Down

0 comments on commit 4c20e8b

Please sign in to comment.