-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add inference class
- Loading branch information
Showing
10 changed files
with
208 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,43 @@ | ||
# BigVGANInference | ||
An unofficial minimal package for using BigVGAN at inference time | ||
|
||
[![PyPI version](https://img.shields.io/pypi/v/bigvganinference)](https://pypi.org/project/bigvganinference/) | ||
![License](https://img.shields.io/pypi/l/bigvganinference) | ||
![Python versions](https://img.shields.io/pypi/pyversions/bigvganinference) | ||
|
||
## Installation | ||
|
||
```bash | ||
pip install bigvganinference | ||
``` | ||
|
||
or install from source: | ||
|
||
```bash | ||
git clone https://github.com/thunn/BigVGANInference.git | ||
cd BigVGANInference | ||
poetry install | ||
``` | ||
|
||
## Usage | ||
|
||
Loading a model is as simple as: | ||
```python | ||
from bigvganinference import BigVGANInference, BigVGANHFModel | ||
|
||
# model is loaded, set to eval and weight norm is removed | ||
model = BigVGANInference.from_pretrained( | ||
BigVGANHFModel.V2_44KHZ_128BAND_512X, use_cuda_kernel=False | ||
) | ||
|
||
|
||
output_audio = model(input_mel) | ||
``` | ||
|
||
See the [example](https://github.com/thunn/BigVGANInference/blob/main/example/inference.py) for a full usage example. | ||
|
||
## Acknowledgements | ||
This is an unofficial implementation based on the [original BigVGAN repository](https://github.com/NVIDIA/BigVGAN). | ||
|
||
## License | ||
This project is licensed under the MIT License. See the [LICENSE](https://github.com/thunn/BigVGANInference/blob/main/LICENSE) file for details. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,7 @@ | ||
import importlib.metadata | ||
|
||
from bigvganinference.inference import BigVGANHFModel, BigVGANInference | ||
|
||
# Names re-exported as the package's public API.
__all__ = ["BigVGANInference", "BigVGANHFModel"]

# Version string of the installed "bigvganinference" distribution, looked up
# from package metadata at import time.
__version__ = importlib.metadata.version("bigvganinference")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from enum import Enum | ||
|
||
import numpy as np | ||
import torch | ||
|
||
from bigvganinference.bigvgan import BigVGAN | ||
from bigvganinference.env import AttrDict | ||
from bigvganinference.meldataset import get_mel_spectrogram | ||
|
||
|
||
class BigVGANHFModel(str, Enum):
    """
    Identifiers of the BigVGAN checkpoints that can be requested by name.

    Each member's value is a repository id string of the form
    ``nvidia/<model_name>``. Because the enum also derives from ``str``, a
    member can be used wherever a plain string id is accepted.
    """

    # Repository ids carrying the ``v2`` infix.
    V2_44KHZ_128BAND_512X = "nvidia/bigvgan_v2_44khz_128band_512x"
    V2_44KHZ_128BAND_256X = "nvidia/bigvgan_v2_44khz_128band_256x"
    V2_24KHZ_100BAND_256X = "nvidia/bigvgan_v2_24khz_100band_256x"
    V2_22KHZ_80BAND_256X = "nvidia/bigvgan_v2_22khz_80band_256x"
    V2_22KHZ_80BAND_FMAX8K_256X = "nvidia/bigvgan_v2_22khz_80band_fmax8k_256x"

    # NOTE(review): the two ids below have no ``v2`` infix even though the
    # member names start with ``V2_`` -- confirm the naming is intended.
    V2_24KHZ_100BAND = "nvidia/bigvgan_24khz_100band"
    V2_22KHZ_80BAND = "nvidia/bigvgan_22khz_80band"

    # Repository ids carrying the ``base`` infix.
    BASE_24KHZ_100BAND = "nvidia/bigvgan_base_24khz_100band"
    BASE_22KHZ_80BAND = "nvidia/bigvgan_base_22khz_80band"

    def __str__(self) -> str:
        """Return the bare repository id instead of ``ClassName.MEMBER``."""
        return str(self.value)
|
||
|
||
class BigVGANInference(BigVGAN):
    """
    BigVGAN wrapper prepared for inference.

    On construction the module is switched to eval mode and weight
    normalization is removed, so no further setup calls are needed before
    running the model.
    """

    def __init__(self, h: AttrDict, use_cuda_kernel: bool = False):
        """
        Build the model and put it into its inference configuration.

        Args:
            h (AttrDict): Model hyperparameters, forwarded to ``BigVGAN``.
            use_cuda_kernel (bool): Forwarded unchanged to ``BigVGAN``.
        """
        super().__init__(h, use_cuda_kernel)

        # set to eval and remove weight norm
        self.eval()
        self.remove_weight_norm()

    def get_mel_spectrogram(self, wav: torch.Tensor | np.ndarray) -> torch.Tensor:
        """
        Wrapper function to preprocess audio and convert to mel spectrogram.

        The waveform is converted to a float32 tensor, given a leading
        dimension if it is 1-D, and averaged over its leading dimension if
        that dimension is larger than 1. Integer input is cast but not
        rescaled, so callers holding raw PCM samples must normalise them
        themselves.

        Args:
            wav (torch.Tensor | np.ndarray): Audio waveform.
        Returns:
            torch.Tensor: Mel spectrogram, placed on the same device as the
            model parameters.
        """

        # ensure wav is FloatTensor with shape [B(1), T_time]
        if isinstance(wav, np.ndarray):
            wav = torch.from_numpy(wav)

        # torch.from_numpy keeps the NumPy dtype (e.g. float64 or int16), and a
        # caller-supplied tensor may have any dtype. Cast explicitly so the
        # mean below works on integer input and the mel computation receives
        # float32. This is a no-op for tensors that are already float32.
        wav = wav.float()

        # If batch dimension is missing, add it
        if wav.ndim == 1:
            wav = wav.unsqueeze(0)

        # ensure that audio is mono (batch size of 1)
        if wav.shape[0] > 1:
            wav = wav.mean(dim=0, keepdim=True)

        mel = get_mel_spectrogram(wav, self.h)

        # ensure mel is on the same device as the model
        device = next(self.parameters()).device
        return mel.to(device)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import librosa | ||
import torch | ||
|
||
from bigvganinference import BigVGANHFModel, BigVGANInference | ||
|
||
# Load a pretrained checkpoint through the inference wrapper.
model = BigVGANInference.from_pretrained(
    BigVGANHFModel.V2_22KHZ_80BAND_FMAX8K_256X,
    use_cuda_kernel=False,
)

# Read the example clip as mono, resampled to the model's configured rate.
wav_path = "example/example.wav"
target_sr = model.h.sampling_rate
wav, sr = librosa.load(wav_path, sr=target_sr, mono=True)  # wav is np.ndarray with shape [T_time] and values in [-1, 1]

# Note: this implementation is a wrapper around the get_mel_spectrogram function;
# additional audio preprocessing is done to ensure the input is in the correct format
mel = model.get_mel_spectrogram(wav)

# Generate the waveform from the mel spectrogram.
with torch.inference_mode():
    wav_gen = model(mel)

wav_gen_float = wav_gen.squeeze(0).cpu()  # wav_gen_float is FloatTensor with shape [1, T_time]

# Scale to the 16-bit range first, then cast to get 16 bit linear PCM.
scaled = wav_gen_float * 32767.0
wav_gen_int16 = scaled.numpy().astype("int16")  # wav_gen_int16 is np.ndarray with shape [1, T_time] and int16 dtype
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import librosa | ||
import torch | ||
|
||
from bigvganinference import bigvgan | ||
from bigvganinference.meldataset import get_mel_spectrogram | ||
|
||
# instantiate the model. You can optionally set use_cuda_kernel=True for faster inference.
model = bigvgan.BigVGAN.from_pretrained(
    "nvidia/bigvgan_v2_24khz_100band_256x",
    use_cuda_kernel=False,
)
device = "cpu"

# Prepare the model for inference: drop weight norm, switch to eval mode,
# then move it to the target device.
model.remove_weight_norm()
model.eval()
model = model.to(device)

# Read the example clip as mono, resampled to the model's configured rate.
wav_path = "example/example.wav"
wav, sr = librosa.load(wav_path, sr=model.h.sampling_rate, mono=True)  # wav is np.ndarray with shape [T_time] and values in [-1, 1]
wav = torch.FloatTensor(wav)
wav = wav.unsqueeze(0)  # wav is FloatTensor with shape [B(1), T_time]

# compute the mel spectrogram and place it on the same device as the model
mel = get_mel_spectrogram(wav, model.h)
mel = mel.to(device)  # mel is FloatTensor with shape [B(1), C_mel, T_frame]

# generate waveform from mel
with torch.inference_mode():
    wav_gen = model(mel)  # wav_gen is FloatTensor with shape [B(1), 1, T_time] and values in [-1, 1]
wav_gen_float = wav_gen.squeeze(0).cpu()  # wav_gen_float is FloatTensor with shape [1, T_time]

# you can convert the generated waveform to 16 bit linear PCM
scaled = wav_gen_float * 32767.0
wav_gen_int16 = scaled.numpy().astype("int16")  # wav_gen_int16 is np.ndarray with shape [1, T_time] and int16 dtype
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters