Skip to content

Commit

Permalink
Add inference class
Browse files Browse the repository at this point in the history
  • Loading branch information
thunn committed Nov 12, 2024
1 parent 251e907 commit 2c5080e
Show file tree
Hide file tree
Showing 9 changed files with 180 additions and 15 deletions.
34 changes: 34 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,36 @@
# BigVGANInference
An unofficial, minimal package for using BigVGAN at inference time.


## Installation

```bash
pip install bigvganinference
```

or install from source:

```bash
git clone https://github.com/thunn/BigVGANInference.git
cd BigVGANInference
poetry install
```

## Usage

Loading a model is as simple as:
```python
from bigvganinference.inference import BigVGANInference, BigVGANHFModel

model = BigVGANInference.from_pretrained(
BigVGANHFModel.V2_44KHZ_128BAND_512X, use_cuda_kernel=False
)
```

See the [example](example/inference.py) for a full usage example.

## Acknowledgements
This is an unofficial implementation based on the [original BigVGAN repository](https://github.com/NVIDIA/BigVGAN).

## License
This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.
4 changes: 4 additions & 0 deletions bigvganinference/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import importlib.metadata

from bigvganinference.inference import BigVGANHFModel, BigVGANInference

# Resolve the installed distribution's version. When the package is imported
# from a source checkout that was never installed there is no distribution
# metadata, so fall back to a placeholder instead of failing the import.
try:
    __version__ = importlib.metadata.version("bigvganinference")
except importlib.metadata.PackageNotFoundError:
    __version__ = "0.0.0+unknown"

__all__ = ["BigVGANInference", "BigVGANHFModel"]
4 changes: 2 additions & 2 deletions bigvganinference/alias_free_activation/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
# LICENSE is in incl_licenses directory.

from .act import Activation1d
from .filter import LowPassFilter1d, Sinc, kaiser_sinc_filter1d
from .filter import LowPassFilter1d, kaiser_sinc_filter1d, sinc
from .resample import DownSample1d, UpSample1d

__all__ = [
"Activation1d",
"Sinc",
"sinc",
"LowPassFilter1d",
"kaiser_sinc_filter1d",
"UpSample1d",
Expand Down
6 changes: 3 additions & 3 deletions bigvganinference/bigvgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(

# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
if self.h.get("use_cuda_kernel", False):
from alias_free_activation.cuda.activation1d import (
from bigvganinference.alias_free_activation.cuda.activation1d import (
Activation1d as CudaActivation1d,
)

Expand Down Expand Up @@ -242,7 +242,7 @@ def __init__(self, h: AttrDict, use_cuda_kernel: bool = False):

# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
if self.h.get("use_cuda_kernel", False):
from alias_free_activation.cuda.activation1d import (
from bigvganinference.alias_free_activation.cuda.activation1d import (
Activation1d as CudaActivation1d,
)

Expand Down Expand Up @@ -437,7 +437,7 @@ def _from_pretrained(
local_files_only=local_files_only,
)

checkpoint_dict = torch.load(model_file, map_location=map_location)
checkpoint_dict = torch.load(model_file, map_location=map_location, weights_only=True)

try:
model.load_state_dict(checkpoint_dict["generator"])
Expand Down
71 changes: 71 additions & 0 deletions bigvganinference/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import numpy as np
import torch
from bigvganinference.env import AttrDict
from enum import Enum
from bigvganinference.bigvgan import BigVGAN

from typing import Dict, Optional, Union
from bigvganinference.meldataset import get_mel_spectrogram


class BigVGANHFModel(str, Enum):
    """
    Identifiers of the pretrained BigVGAN models ("HF models").

    Each member is also a ``str`` whose content is the model id, so a member
    can be used wherever the plain id string is accepted.
    """

    # ids containing a "_v2" segment
    V2_44KHZ_128BAND_512X = "nvidia/bigvgan_v2_44khz_128band_512x"
    V2_44KHZ_128BAND_256X = "nvidia/bigvgan_v2_44khz_128band_256x"
    V2_24KHZ_100BAND_256X = "nvidia/bigvgan_v2_24khz_100band_256x"
    V2_22KHZ_80BAND_256X = "nvidia/bigvgan_v2_22khz_80band_256x"
    V2_22KHZ_80BAND_FMAX8K_256X = "nvidia/bigvgan_v2_22khz_80band_fmax8k_256x"

    # NOTE(review): these two members are prefixed V2_ but their ids carry no
    # "_v2" segment -- confirm the naming is intentional.
    V2_24KHZ_100BAND = "nvidia/bigvgan_24khz_100band"
    V2_22KHZ_80BAND = "nvidia/bigvgan_22khz_80band"

    # ids containing a "_base" segment
    BASE_24KHZ_100BAND = "nvidia/bigvgan_base_24khz_100band"
    BASE_22KHZ_80BAND = "nvidia/bigvgan_base_22khz_80band"

    def __str__(self):
        # Render as the bare model id rather than "BigVGANHFModel.<NAME>".
        return str(self.value)


class BigVGANInference(BigVGAN):
    """
    BigVGAN set up for inference.

    On construction the model is put in eval mode and weight norm is removed.
    ``get_mel_spectrogram`` converts a raw waveform into the mel spectrogram
    passed to the model.
    """

    def __init__(self, h: AttrDict, use_cuda_kernel: bool = False):
        """
        Build the underlying BigVGAN and prepare it for inference.

        Args:
            h (AttrDict): Model configuration, forwarded to ``BigVGAN``.
            use_cuda_kernel (bool): Forwarded to ``BigVGAN``.
        """
        super().__init__(h, use_cuda_kernel)

        # set to eval and remove weight norm
        # NOTE(review): weight norm is removed here, before any checkpoint is
        # loaded -- confirm checkpoints loaded afterwards match this layout.
        self.eval()
        self.remove_weight_norm()

    def get_mel_spectrogram(self, wav: torch.Tensor | np.ndarray) -> torch.Tensor:
        """
        Wrapper function to preprocess audio and convert to mel spectrogram.

        Args:
            wav (torch.Tensor | np.ndarray): Floating-point audio waveform,
                either 1-D ``[T_time]`` or 2-D with the channels on dim 0.
        Returns:
            torch.Tensor: Mel spectrogram, on the same device as the model.
        Raises:
            TypeError: If ``wav`` does not have a floating-point dtype.
        """

        if isinstance(wav, np.ndarray):
            wav = torch.from_numpy(wav)

        # Integer input would otherwise fail later with an opaque error (or be
        # treated as unscaled samples), so reject it explicitly.
        if not wav.is_floating_point():
            raise TypeError(
                f"wav must have a floating-point dtype, got {wav.dtype}; "
                "convert integer PCM to float first"
            )

        # ensure wav is a FloatTensor (float32), e.g. for float64 NumPy input
        wav = wav.float()

        # If batch dimension is missing, add it
        if wav.ndim == 1:
            wav = wav.unsqueeze(0)

        # ensure that audio is mono (batch size of 1): average over dim 0
        if wav.shape[0] > 1:
            wav = wav.mean(dim=0).unsqueeze(0)

        mel = get_mel_spectrogram(wav, self.h)

        # ensure mel is on the same device as the model
        device = next(self.parameters()).device
        mel = mel.to(device)

        return mel
2 changes: 1 addition & 1 deletion bigvganinference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def get_padding(kernel_size, dilation=1):
def load_checkpoint(filepath, device):
    """
    Load a checkpoint from disk.

    Args:
        filepath: Path of the checkpoint file; must be an existing file.
        device: Passed to ``torch.load`` as ``map_location``.
    Returns:
        The object stored in the checkpoint file.
    """
    assert os.path.isfile(filepath)
    print(f"Loading '{filepath}'")
    # Load once, with weights_only=True so unpickling is restricted to
    # tensors and plain containers instead of arbitrary objects.
    checkpoint_dict = torch.load(filepath, map_location=device, weights_only=True)
    print("Complete.")
    return checkpoint_dict

Expand Down
29 changes: 29 additions & 0 deletions example/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import librosa
import torch

from bigvganinference import BigVGANInference, BigVGANHFModel


# Load the pretrained model through the inference wrapper.
checkpoint = BigVGANHFModel.V2_22KHZ_80BAND_FMAX8K_256X
model = BigVGANInference.from_pretrained(checkpoint, use_cuda_kernel=False)


# wav is np.ndarray with shape [T_time] and values in [-1, 1]
wav_path = "example/example.wav"
wav, sr = librosa.load(wav_path, sr=model.h.sampling_rate, mono=True)

# Note: this implementation is a wrapper around the get_mel_spectrogram function;
# additional audio preprocessing is done to ensure the input is in the correct format
mel = model.get_mel_spectrogram(wav)

with torch.inference_mode():
    wav_gen = model(mel)

# drop dim 0 and move to CPU: wav_gen_float is FloatTensor with shape [1, T_time]
wav_gen_float = wav_gen.squeeze(0).cpu()

# convert to 16 bit: np.ndarray with shape [1, T_time] and int16 dtype
scaled = wav_gen_float * 32767.0
wav_gen_int16 = scaled.numpy().astype("int16")
27 changes: 27 additions & 0 deletions example/vanilla_bigvgan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import librosa
import torch

from bigvganinference import bigvgan
from bigvganinference.meldataset import get_mel_spectrogram

device = "cpu"

# Instantiate the model. You can optionally set use_cuda_kernel=True for faster inference.
model = bigvgan.BigVGAN.from_pretrained(
    "nvidia/bigvgan_v2_24khz_100band_256x",
    use_cuda_kernel=False,
)

# Remove weight norm, switch to eval mode and move to the target device.
model.remove_weight_norm()
model = model.eval().to(device)


# wav is np.ndarray with shape [T_time] and values in [-1, 1]
wav_path = "example/example.wav"
wav, sr = librosa.load(wav_path, sr=model.h.sampling_rate, mono=True)

# wav is FloatTensor with shape [B(1), T_time]
wav = torch.FloatTensor(wav).unsqueeze(0)

# mel is FloatTensor with shape [B(1), C_mel, T_frame]
mel = get_mel_spectrogram(wav, model.h).to(device)

# generate waveform from mel
with torch.inference_mode():
    # wav_gen is FloatTensor with shape [B(1), 1, T_time] and values in [-1, 1]
    wav_gen = model(mel)

# wav_gen_float is FloatTensor with shape [1, T_time]
wav_gen_float = wav_gen.squeeze(0).cpu()

# you can convert the generated waveform to 16 bit linear PCM
# (np.ndarray with shape [1, T_time] and int16 dtype)
scaled = wav_gen_float * 32767.0
wav_gen_int16 = scaled.numpy().astype("int16")
18 changes: 9 additions & 9 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 2c5080e

Please sign in to comment.