-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
180 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,36 @@ | ||
# BigVGANInference | ||
An unofficial minimal package for using BigVGAN at inference time | ||
|
||
|
||
## Installation | ||
|
||
```bash | ||
pip install bigvganinference | ||
``` | ||
|
||
or install from source: | ||
|
||
```bash | ||
git clone https://github.com/thunn/BigVGANInference.git | ||
cd BigVGANInference | ||
poetry install | ||
``` | ||
|
||
## Usage | ||
|
||
Loading a model is as simple as: | ||
```python | ||
from bigvganinference.inference import BigVGANInference, BigVGANHFModel | ||
|
||
model = BigVGANInference.from_pretrained( | ||
BigVGANHFModel.V2_44KHZ_128BAND_512X, use_cuda_kernel=False | ||
) | ||
``` | ||
|
||
See the [example](example/inference.py) for a full usage example. | ||
|
||
## Acknowledgements | ||
This is an unofficial implementation based on the [original BigVGAN repository](https://github.com/NVIDIA/BigVGAN). | ||
|
||
## License | ||
This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,7 @@ | ||
import importlib.metadata | ||
|
||
from bigvganinference.inference import BigVGANHFModel, BigVGANInference | ||
|
||
# Resolve the version from the installed distribution's metadata. When the
# package is imported without being installed (for example straight from a
# source checkout) no metadata exists, so fall back to a placeholder instead
# of making the whole package un-importable.
try:
    __version__ = importlib.metadata.version("bigvganinference")
except importlib.metadata.PackageNotFoundError:
    __version__ = "0.0.0"

# Public API re-exported at package level.
__all__ = ["BigVGANInference", "BigVGANHFModel"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import numpy as np | ||
import torch | ||
from bigvganinference.env import AttrDict | ||
from enum import Enum | ||
from bigvganinference.bigvgan import BigVGAN | ||
|
||
from typing import Dict, Optional, Union | ||
from bigvganinference.meldataset import get_mel_spectrogram | ||
|
||
|
||
class BigVGANHFModel(str, Enum):
    """
    BigVGAN HF models.

    Each member's value is a model identifier string of the form
    ``nvidia/<model_name>``. Because the enum also subclasses ``str``, a
    member compares equal to its identifier and can be used wherever a plain
    string is accepted.
    """

    V2_44KHZ_128BAND_512X = "nvidia/bigvgan_v2_44khz_128band_512x"
    V2_44KHZ_128BAND_256X = "nvidia/bigvgan_v2_44khz_128band_256x"
    V2_24KHZ_100BAND_256X = "nvidia/bigvgan_v2_24khz_100band_256x"
    V2_22KHZ_80BAND_256X = "nvidia/bigvgan_v2_22khz_80band_256x"
    V2_22KHZ_80BAND_FMAX8K_256X = "nvidia/bigvgan_v2_22khz_80band_fmax8k_256x"
    # NOTE(review): the next two identifiers lack the `v2` tag that the
    # entries above carry, yet the member names use the `V2_` prefix.
    # Names are kept as-is for backward compatibility — confirm intent.
    V2_24KHZ_100BAND = "nvidia/bigvgan_24khz_100band"
    V2_22KHZ_80BAND = "nvidia/bigvgan_22khz_80band"
    BASE_24KHZ_100BAND = "nvidia/bigvgan_base_24khz_100band"
    BASE_22KHZ_80BAND = "nvidia/bigvgan_base_22khz_80band"

    def __str__(self) -> str:
        """Return the bare identifier rather than ``ClassName.MEMBER``."""
        return self.value
|
||
|
||
class BigVGANInference(BigVGAN):
    """
    BigVGAN inference.

    Thin subclass of ``BigVGAN`` that is put into eval mode with weight norm
    removed as soon as it is constructed, and that adds a convenience wrapper
    for turning raw audio into a mel spectrogram.
    """

    def __init__(self, h: AttrDict, use_cuda_kernel: bool = False):
        """
        Build the model and prepare it for inference.

        Args:
            h (AttrDict): Model configuration, passed through to
                ``BigVGAN.__init__``.
            use_cuda_kernel (bool): Passed through to ``BigVGAN.__init__``.
        """
        super().__init__(h, use_cuda_kernel)

        # set to eval and remove weight norm
        self.eval()
        self.remove_weight_norm()

    # `Union` is used instead of `X | Y` because the annotation is evaluated
    # when the method is defined, and `|` between classes needs Python 3.10+.
    def get_mel_spectrogram(self, wav: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        """
        Wrapper function to preprocess audio and convert to mel spectrogram.

        Args:
            wav (torch.Tensor | np.ndarray): Audio waveform, either 1-D
                ``[T_time]`` or 2-D with the rows averaged into one.
        Returns:
            torch.Tensor: Mel spectrogram, on the same device as the model.
        """

        # ensure wav is FloatTensor with shape [B(1), T_time]
        if isinstance(wav, np.ndarray):
            wav = torch.from_numpy(wav)

        # from_numpy keeps the NumPy dtype (e.g. float64 or an integer type);
        # convert explicitly so the input really is a float32 tensor.
        wav = wav.float()

        # If batch dimension is missing, add it
        if wav.ndim == 1:
            wav = wav.unsqueeze(0)

        # ensure that audio is mono (batch size of 1)
        if wav.shape[0] > 1:
            wav = wav.mean(dim=0).unsqueeze(0)

        mel = get_mel_spectrogram(wav, self.h)

        # ensure mel is on the same device as the model
        device = next(self.parameters()).device
        mel = mel.to(device)

        return mel
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import librosa | ||
import torch | ||
|
||
from bigvganinference import BigVGANInference, BigVGANHFModel | ||
|
||
|
||
# Load a pretrained model through the inference wrapper.
model = BigVGANInference.from_pretrained(
    BigVGANHFModel.V2_22KHZ_80BAND_FMAX8K_256X, use_cuda_kernel=False
)

wav_path = "example/example.wav"
wav, sr = librosa.load(
    wav_path, sr=model.h.sampling_rate, mono=True
)  # wav is np.ndarray with shape [T_time] and values in [-1, 1]

# Note: this implementation is a wrapper around the get_mel_spectrogram function;
# additional audio preprocessing is done to ensure the input is in the correct format
mel = model.get_mel_spectrogram(wav)

# generate waveform from mel
with torch.inference_mode():
    wav_gen = model(mel)

wav_gen_float = wav_gen.squeeze(
    0
).cpu()  # wav_gen is FloatTensor with shape [1, T_time]

# convert the generated waveform to 16 bit linear PCM
wav_gen_int16 = (
    (wav_gen_float * 32767.0).numpy().astype("int16")
)  # wav_gen is now np.ndarray with shape [1, T_time] and int16 dtype
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import librosa | ||
import torch | ||
|
||
from bigvganinference import bigvgan | ||
from bigvganinference.meldataset import get_mel_spectrogram | ||
|
||
# instantiate the model. You can optionally set use_cuda_kernel=True for faster inference.
model = bigvgan.BigVGAN.from_pretrained(
    "nvidia/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False
)
device = "cpu"

# remove weight norm and switch to eval mode before running inference
model.remove_weight_norm()
model = model.eval().to(device)

wav_path = "example/example.wav"
# wav is np.ndarray with shape [T_time] and values in [-1, 1]
wav, sr = librosa.load(wav_path, sr=model.h.sampling_rate, mono=True)
# wav is FloatTensor with shape [B(1), T_time]
wav = torch.FloatTensor(wav).unsqueeze(0)

# mel is FloatTensor with shape [B(1), C_mel, T_frame]
mel = get_mel_spectrogram(wav, model.h).to(device)

# generate waveform from mel
with torch.inference_mode():
    # wav_gen is FloatTensor with shape [B(1), 1, T_time] and values in [-1, 1]
    wav_gen = model(mel)
# wav_gen_float is FloatTensor with shape [1, T_time]
wav_gen_float = wav_gen.squeeze(0).cpu()

# you can convert the generated waveform to 16 bit linear PCM
# wav_gen_int16 is np.ndarray with shape [1, T_time] and int16 dtype
wav_gen_int16 = (wav_gen_float * 32767.0).numpy().astype("int16")
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.