diff --git a/README.md b/README.md index 600d006..6b61134 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,36 @@ # BigVGANInference An unofficial minimal package for using BigVGAN at inference time + + +## Installation + +```python +pip install bigvganinference +``` + +or install from source: + +```python +git clone https://github.com/thunn/BigVGANInference.git +cd BigVGANInference +poetry install +``` + +## Usage + +Loading model is as simple as: +```python +from bigvganinference.inference import BigVGANInference, BigVGANHFModel + +model = BigVGANInference.from_pretrained( + BigVGANHFModel.V2_44KHZ_128BAND_512X, use_cuda_kernel=False +) +``` + +See the [example](example/inference.py) for full usage example. + +## Acknowledgements +This is an unofficial implementation based on [original BigVGAN repository](https://github.com/NVIDIA/BigVGAN). + +## License +This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details. \ No newline at end of file diff --git a/bigvganinference/__init__.py b/bigvganinference/__init__.py index 1503569..507043b 100644 --- a/bigvganinference/__init__.py +++ b/bigvganinference/__init__.py @@ -1,3 +1,7 @@ import importlib.metadata +from bigvganinference.inference import BigVGANHFModel, BigVGANInference + __version__ = importlib.metadata.version("bigvganinference") + +__all__ = ["BigVGANInference", "BigVGANHFModel"] diff --git a/bigvganinference/alias_free_activation/torch/__init__.py b/bigvganinference/alias_free_activation/torch/__init__.py index a483d73..17595c9 100644 --- a/bigvganinference/alias_free_activation/torch/__init__.py +++ b/bigvganinference/alias_free_activation/torch/__init__.py @@ -2,12 +2,12 @@ # LICENSE is in incl_licenses directory. from .act import Activation1d -from .filter import LowPassFilter1d, Sinc, kaiser_sinc_filter1d +from .filter import LowPassFilter1d, kaiser_sinc_filter1d, sinc from .resample import DownSample1d, UpSample1d __all__ = [ "Activation1d", - "Sinc", + "sinc", "LowPassFilter1d", "kaiser_sinc_filter1d", "UpSample1d", diff --git a/bigvganinference/bigvgan.py b/bigvganinference/bigvgan.py index abb1e83..c6ad4a2 100644 --- a/bigvganinference/bigvgan.py +++ b/bigvganinference/bigvgan.py @@ -95,7 +95,7 @@ def __init__( # Select which Activation1d, lazy-load cuda version to ensure backward compatibility if self.h.get("use_cuda_kernel", False): - from alias_free_activation.cuda.activation1d import ( + from bigvganinference.alias_free_activation.cuda.activation1d import ( Activation1d as CudaActivation1d, ) @@ -242,7 +242,7 @@ def __init__(self, h: AttrDict, use_cuda_kernel: bool = False): # Select which Activation1d, lazy-load cuda version to ensure backward compatibility if self.h.get("use_cuda_kernel", False): - from alias_free_activation.cuda.activation1d import ( + from bigvganinference.alias_free_activation.cuda.activation1d import ( Activation1d as CudaActivation1d, ) @@ -437,7 +437,7 @@ def _from_pretrained( local_files_only=local_files_only, ) - checkpoint_dict = torch.load(model_file, map_location=map_location) + checkpoint_dict = torch.load(model_file, map_location=map_location, weights_only=True) try: model.load_state_dict(checkpoint_dict["generator"]) diff --git a/bigvganinference/inference.py b/bigvganinference/inference.py new file mode 100644 index 0000000..45269e1 --- /dev/null +++ b/bigvganinference/inference.py @@ -0,0 +1,71 @@ +import numpy as np +import torch +from bigvganinference.env import AttrDict +from enum import Enum +from bigvganinference.bigvgan import BigVGAN + +from typing import Dict, Optional, Union +from bigvganinference.meldataset import get_mel_spectrogram + + +class BigVGANHFModel(str, Enum): + """ + BigVGAN HF models. + """ + + V2_44KHZ_128BAND_512X = "nvidia/bigvgan_v2_44khz_128band_512x" + V2_44KHZ_128BAND_256X = "nvidia/bigvgan_v2_44khz_128band_256x" + V2_24KHZ_100BAND_256X = "nvidia/bigvgan_v2_24khz_100band_256x" + V2_22KHZ_80BAND_256X = "nvidia/bigvgan_v2_22khz_80band_256x" + V2_22KHZ_80BAND_FMAX8K_256X = "nvidia/bigvgan_v2_22khz_80band_fmax8k_256x" + V2_24KHZ_100BAND = "nvidia/bigvgan_24khz_100band" + V2_22KHZ_80BAND = "nvidia/bigvgan_22khz_80band" + BASE_24KHZ_100BAND = "nvidia/bigvgan_base_24khz_100band" + BASE_22KHZ_80BAND = "nvidia/bigvgan_base_22khz_80band" + + def __str__(self): + return self.value + + +class BigVGANInference(BigVGAN): + """ + BigVGAN inference. + """ + + def __init__(self, h: AttrDict, use_cuda_kernel: bool = False): + super().__init__(h, use_cuda_kernel) + + # set to eval and remove weight norm + self.eval() + self.remove_weight_norm() + + def get_mel_spectrogram(self, wav: torch.Tensor | np.ndarray) -> torch.Tensor: + """ + Wrapper function to preprocess audio and convert to mel spectrogram. + + Args: + wav (torch.Tensor | np.ndarray): Audio waveform. + + Returns: + torch.Tensor: Mel spectrogram. + """ + + # ensure wav is FloatTensor with shape [B(1), T_time] + if isinstance(wav, np.ndarray): + wav = torch.from_numpy(wav) + + # If batch dimension is missing, add it + if wav.ndim == 1: + wav = wav.unsqueeze(0) + + # ensure that audio is mono (batch size of 1) + if wav.shape[0] > 1: + wav = wav.mean(dim=0).unsqueeze(0) + + mel = get_mel_spectrogram(wav, self.h) + + # ensure mel is on the same device as the model + device = next(self.parameters()).device + mel = mel.to(device) + + return mel diff --git a/bigvganinference/utils.py b/bigvganinference/utils.py index 1ef6ad4..713b1b4 100644 --- a/bigvganinference/utils.py +++ b/bigvganinference/utils.py @@ -30,7 +30,7 @@ def get_padding(kernel_size, dilation=1): def load_checkpoint(filepath, device): assert os.path.isfile(filepath) print(f"Loading '{filepath}'") - checkpoint_dict = torch.load(filepath, map_location=device) + checkpoint_dict = torch.load(filepath, map_location=device, weights_only=True) print("Complete.") return checkpoint_dict diff --git a/example/inference.py b/example/inference.py new file mode 100644 index 0000000..152a407 --- /dev/null +++ b/example/inference.py @@ -0,0 +1,29 @@ +import librosa +import torch + +from bigvganinference import BigVGANInference, BigVGANHFModel + + +model = BigVGANInference.from_pretrained( + BigVGANHFModel.V2_22KHZ_80BAND_FMAX8K_256X, use_cuda_kernel=False +) + + +wav_path = "example/example.wav" +wav, sr = librosa.load( + wav_path, sr=model.h.sampling_rate, mono=True +) # wav is np.ndarray with shape [T_time] and values in [-1, 1] + +# Note this implemntation is a wrapper around the get_mel_spectrogram function +# additional audio preprocessing is done to ensure the input is in the correct format +mel = model.get_mel_spectrogram(wav) + +with torch.inference_mode(): + wav_gen = model(mel) + +wav_gen_float = wav_gen.squeeze( + 0 +).cpu() # wav_gen is FloatTensor with shape [1, T_time] +wav_gen_int16 = ( + (wav_gen_float * 32767.0).numpy().astype("int16") +) # wav_gen is now np.ndarray with shape [1, T_time] and int16 dtype diff --git a/example/vanilla_bigvgan.py b/example/vanilla_bigvgan.py new file mode 100644 index 0000000..8f9cbee --- /dev/null +++ b/example/vanilla_bigvgan.py @@ -0,0 +1,27 @@ +import librosa +import torch + +from bigvganinference import bigvgan +from bigvganinference.meldataset import get_mel_spectrogram + +# instantiate the model. You can optionally set use_cuda_kernel=True for faster inference. +model = bigvgan.BigVGAN.from_pretrained("nvidia/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False) +device = "cpu" + +model.remove_weight_norm() +model = model.eval().to(device) + + +wav_path = "example/example.wav" +wav, sr = librosa.load(wav_path, sr=model.h.sampling_rate, mono=True) # wav is np.ndarray with shape [T_time] and values in [-1, 1] +wav = torch.FloatTensor(wav).unsqueeze(0) # wav is FloatTensor with shape [B(1), T_time] + +mel = get_mel_spectrogram(wav, model.h).to(device) # mel is FloatTensor with shape [B(1), C_mel, T_frame] + +# generate waveform from mel +with torch.inference_mode(): + wav_gen = model(mel) # wav_gen is FloatTensor with shape [B(1), 1, T_time] and values in [-1, 1] +wav_gen_float = wav_gen.squeeze(0).cpu() # wav_gen is FloatTensor with shape [1, T_time] + +# you can convert the generated waveform to 16 bit linear PCM +wav_gen_int16 = (wav_gen_float * 32767.0).numpy().astype("int16") # wav_gen is now np.ndarray with shape [1, T_time] and int16 dtype diff --git a/poetry.lock b/poetry.lock index da26b70..b9db9e1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1181,23 +1181,23 @@ test = ["Cython", "array-api-strict (>=2.0)", "asv", "gmpy2", "hypothesis (>=6.3 [[package]] name = "setuptools" -version = "75.2.0" +version = "75.4.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "setuptools-75.2.0-py3-none-any.whl", hash = "sha256:a7fcb66f68b4d9e8e66b42f9876150a3371558f98fa32222ffaa5bced76406f8"}, - {file = "setuptools-75.2.0.tar.gz", hash = "sha256:753bb6ebf1f465a1912e19ed1d41f403a79173a9acf66a42e7e6aec45c3c16ec"}, + {file = "setuptools-75.4.0-py3-none-any.whl", hash = "sha256:b3c5d862f98500b06ffdf7cc4499b48c46c317d8d56cb30b5c8bce4d88f5c216"}, + {file = "setuptools-75.4.0.tar.gz", hash = "sha256:1dc484f5cf56fd3fe7216d7b8df820802e7246cfb534a1db2aa64f14fcb9cdcb"}, ] [package.extras] -check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.5.2)"] -core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.collections", "jaraco.functools", "jaraco.text (>=3.7)", "more-itertools", "more-itertools (>=8.8)", "packaging", "packaging (>=24)", "platformdirs (>=2.6.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.7.0)"] +core = ["importlib-metadata (>=6)", "jaraco.collections", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more-itertools", "more-itertools (>=8.8)", "packaging", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] cover = ["pytest-cov"] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"] enabler = ["pytest-enabler (>=2.2)"] -test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] -type = ["importlib-metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.11.*)", "pytest-mypy"] +test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] +type = ["importlib-metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (>=1.12,<1.14)", "pytest-mypy"] [[package]] name = "soundfile" @@ -1432,4 +1432,4 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess [metadata] lock-version = "2.0" python-versions = ">=3.10,<4.0" -content-hash = "6d18b2aabee8f3b0d3acd6b615da588a123864dcde0606e7c272f52f02144d72" +content-hash = "639e5c5ac7427616356ba85a714224e31b6302d4a9974b7b4feac73516f33685"