From 4c20e8bcaeac00c59aaa75d1990b321d7abb5443 Mon Sep 17 00:00:00 2001 From: thunn Date: Tue, 12 Nov 2024 19:16:36 +1000 Subject: [PATCH] Add inference class (#3) * Add inference class --- README.md | 41 +++++++++++ bigvganinference/__init__.py | 4 ++ .../alias_free_activation/torch/__init__.py | 4 +- bigvganinference/bigvgan.py | 6 +- bigvganinference/inference.py | 71 +++++++++++++++++++ bigvganinference/utils.py | 2 +- example/inference.py | 20 ++++++ example/vanilla_bigvgan.py | 27 +++++++ poetry.lock | 45 +++++++++--- pyproject.toml | 5 +- 10 files changed, 208 insertions(+), 17 deletions(-) create mode 100644 bigvganinference/inference.py create mode 100644 example/inference.py create mode 100644 example/vanilla_bigvgan.py diff --git a/README.md b/README.md index 600d006..32d63ac 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,43 @@ # BigVGANInference An unofficial minimal package for using BigVGAN at inference time + +[![PyPI version](https://img.shields.io/pypi/v/bigvganinference)](https://pypi.org/project/bigvganinference/) +![License](https://img.shields.io/pypi/l/bigvganinference) +![Python versions](https://img.shields.io/pypi/pyversions/bigvganinference) + +## Installation + +```bash +pip install bigvganinference +``` + +or install from source: + +```bash +git clone https://github.com/thunn/BigVGANInference.git +cd BigVGANInference +poetry install +``` + +## Usage + +Loading the model is as simple as: +```python +from bigvganinference import BigVGANInference, BigVGANHFModel + +# model is loaded, set to eval and weight norm is removed +model = BigVGANInference.from_pretrained( + BigVGANHFModel.V2_44KHZ_128BAND_512X, use_cuda_kernel=False +) + + +output_audio = model(input_mel) +``` + +See the [example](https://github.com/thunn/BigVGANInference/blob/main/example/inference.py) for a full usage example. + +## Acknowledgements +This is an unofficial implementation based on the [original BigVGAN repository](https://github.com/NVIDIA/BigVGAN). 
+ +## License +This project is licensed under the MIT License. See the [LICENSE](https://github.com/thunn/BigVGANInference/blob/main/LICENSE) file for details. \ No newline at end of file diff --git a/bigvganinference/__init__.py b/bigvganinference/__init__.py index 1503569..507043b 100644 --- a/bigvganinference/__init__.py +++ b/bigvganinference/__init__.py @@ -1,3 +1,7 @@ import importlib.metadata +from bigvganinference.inference import BigVGANHFModel, BigVGANInference + __version__ = importlib.metadata.version("bigvganinference") + +__all__ = ["BigVGANInference", "BigVGANHFModel"] diff --git a/bigvganinference/alias_free_activation/torch/__init__.py b/bigvganinference/alias_free_activation/torch/__init__.py index a483d73..17595c9 100644 --- a/bigvganinference/alias_free_activation/torch/__init__.py +++ b/bigvganinference/alias_free_activation/torch/__init__.py @@ -2,12 +2,12 @@ # LICENSE is in incl_licenses directory. from .act import Activation1d -from .filter import LowPassFilter1d, Sinc, kaiser_sinc_filter1d +from .filter import LowPassFilter1d, kaiser_sinc_filter1d, sinc from .resample import DownSample1d, UpSample1d __all__ = [ "Activation1d", - "Sinc", + "sinc", "LowPassFilter1d", "kaiser_sinc_filter1d", "UpSample1d", diff --git a/bigvganinference/bigvgan.py b/bigvganinference/bigvgan.py index abb1e83..c6ad4a2 100644 --- a/bigvganinference/bigvgan.py +++ b/bigvganinference/bigvgan.py @@ -95,7 +95,7 @@ def __init__( # Select which Activation1d, lazy-load cuda version to ensure backward compatibility if self.h.get("use_cuda_kernel", False): - from alias_free_activation.cuda.activation1d import ( + from bigvganinference.alias_free_activation.cuda.activation1d import ( Activation1d as CudaActivation1d, ) @@ -242,7 +242,7 @@ def __init__(self, h: AttrDict, use_cuda_kernel: bool = False): # Select which Activation1d, lazy-load cuda version to ensure backward compatibility if self.h.get("use_cuda_kernel", False): - from alias_free_activation.cuda.activation1d 
import ( + from bigvganinference.alias_free_activation.cuda.activation1d import ( Activation1d as CudaActivation1d, ) @@ -437,7 +437,7 @@ def _from_pretrained( local_files_only=local_files_only, ) - checkpoint_dict = torch.load(model_file, map_location=map_location) + checkpoint_dict = torch.load(model_file, map_location=map_location, weights_only=True) try: model.load_state_dict(checkpoint_dict["generator"]) diff --git a/bigvganinference/inference.py b/bigvganinference/inference.py new file mode 100644 index 0000000..a6e2d21 --- /dev/null +++ b/bigvganinference/inference.py @@ -0,0 +1,71 @@ +from enum import Enum + +import numpy as np +import torch + +from bigvganinference.bigvgan import BigVGAN +from bigvganinference.env import AttrDict +from bigvganinference.meldataset import get_mel_spectrogram + + +class BigVGANHFModel(str, Enum): + """ + BigVGAN HF models. + """ + + V2_44KHZ_128BAND_512X = "nvidia/bigvgan_v2_44khz_128band_512x" + V2_44KHZ_128BAND_256X = "nvidia/bigvgan_v2_44khz_128band_256x" + V2_24KHZ_100BAND_256X = "nvidia/bigvgan_v2_24khz_100band_256x" + V2_22KHZ_80BAND_256X = "nvidia/bigvgan_v2_22khz_80band_256x" + V2_22KHZ_80BAND_FMAX8K_256X = "nvidia/bigvgan_v2_22khz_80band_fmax8k_256x" + V2_24KHZ_100BAND = "nvidia/bigvgan_24khz_100band" + V2_22KHZ_80BAND = "nvidia/bigvgan_22khz_80band" + BASE_24KHZ_100BAND = "nvidia/bigvgan_base_24khz_100band" + BASE_22KHZ_80BAND = "nvidia/bigvgan_base_22khz_80band" + + def __str__(self): + return self.value + + +class BigVGANInference(BigVGAN): + """ + BigVGAN inference. + """ + + def __init__(self, h: AttrDict, use_cuda_kernel: bool = False): + super().__init__(h, use_cuda_kernel) + + # set to eval and remove weight norm + self.eval() + self.remove_weight_norm() + + def get_mel_spectrogram(self, wav: torch.Tensor | np.ndarray) -> torch.Tensor: + """ + Wrapper function to preprocess audio and convert to mel spectrogram. + + Args: + wav (torch.Tensor | np.ndarray): Audio waveform. 
+ + Returns: + torch.Tensor: Mel spectrogram. + """ + + # ensure wav is FloatTensor with shape [B(1), T_time] + if isinstance(wav, np.ndarray): + wav = torch.from_numpy(wav) + + # If batch dimension is missing, add it + if wav.ndim == 1: + wav = wav.unsqueeze(0) + + # ensure that audio is mono (batch size of 1) + if wav.shape[0] > 1: + wav = wav.mean(dim=0).unsqueeze(0) + + mel = get_mel_spectrogram(wav, self.h) + + # ensure mel is on the same device as the model + device = next(self.parameters()).device + mel = mel.to(device) + + return mel diff --git a/bigvganinference/utils.py b/bigvganinference/utils.py index 1ef6ad4..713b1b4 100644 --- a/bigvganinference/utils.py +++ b/bigvganinference/utils.py @@ -30,7 +30,7 @@ def get_padding(kernel_size, dilation=1): def load_checkpoint(filepath, device): assert os.path.isfile(filepath) print(f"Loading '{filepath}'") - checkpoint_dict = torch.load(filepath, map_location=device) + checkpoint_dict = torch.load(filepath, map_location=device, weights_only=True) print("Complete.") return checkpoint_dict diff --git a/example/inference.py b/example/inference.py new file mode 100644 index 0000000..868e90a --- /dev/null +++ b/example/inference.py @@ -0,0 +1,20 @@ +import librosa +import torch + +from bigvganinference import BigVGANHFModel, BigVGANInference + +model = BigVGANInference.from_pretrained(BigVGANHFModel.V2_22KHZ_80BAND_FMAX8K_256X, use_cuda_kernel=False) + + +wav_path = "example/example.wav" +wav, sr = librosa.load(wav_path, sr=model.h.sampling_rate, mono=True) # wav is np.ndarray with shape [T_time] and values in [-1, 1] + +# Note this implementation is a wrapper around the get_mel_spectrogram function +# additional audio preprocessing is done to ensure the input is in the correct format +mel = model.get_mel_spectrogram(wav) + +with torch.inference_mode(): + wav_gen = model(mel) + +wav_gen_float = wav_gen.squeeze(0).cpu() # wav_gen is FloatTensor with shape [1, T_time] +wav_gen_int16 = (wav_gen_float * 
32767.0).numpy().astype("int16") # wav_gen is now np.ndarray with shape [1, T_time] and int16 dtype diff --git a/example/vanilla_bigvgan.py b/example/vanilla_bigvgan.py new file mode 100644 index 0000000..8f9cbee --- /dev/null +++ b/example/vanilla_bigvgan.py @@ -0,0 +1,27 @@ +import librosa +import torch + +from bigvganinference import bigvgan +from bigvganinference.meldataset import get_mel_spectrogram + +# instantiate the model. You can optionally set use_cuda_kernel=True for faster inference. +model = bigvgan.BigVGAN.from_pretrained("nvidia/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False) +device = "cpu" + +model.remove_weight_norm() +model = model.eval().to(device) + + +wav_path = "example/example.wav" +wav, sr = librosa.load(wav_path, sr=model.h.sampling_rate, mono=True) # wav is np.ndarray with shape [T_time] and values in [-1, 1] +wav = torch.FloatTensor(wav).unsqueeze(0) # wav is FloatTensor with shape [B(1), T_time] + +mel = get_mel_spectrogram(wav, model.h).to(device) # mel is FloatTensor with shape [B(1), C_mel, T_frame] + +# generate waveform from mel +with torch.inference_mode(): + wav_gen = model(mel) # wav_gen is FloatTensor with shape [B(1), 1, T_time] and values in [-1, 1] +wav_gen_float = wav_gen.squeeze(0).cpu() # wav_gen is FloatTensor with shape [1, T_time] + +# you can convert the generated waveform to 16 bit linear PCM +wav_gen_int16 = (wav_gen_float * 32767.0).numpy().astype("int16") # wav_gen is now np.ndarray with shape [1, T_time] and int16 dtype diff --git a/poetry.lock b/poetry.lock index da26b70..aa40216 100644 --- a/poetry.lock +++ b/poetry.lock @@ -665,6 +665,33 @@ doc = ["nb2plots (>=0.7)", "nbconvert (<7.9)", "numpydoc (>=1.6)", "pillow (>=9. 
extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.11)", "sympy (>=1.10)"] test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"] +[[package]] +name = "ninja" +version = "1.11.1.1" +description = "Ninja is a small build system with a focus on speed" +optional = false +python-versions = "*" +files = [ + {file = "ninja-1.11.1.1-py2.py3-none-macosx_10_9_universal2.macosx_10_9_x86_64.macosx_11_0_arm64.macosx_11_0_universal2.whl", hash = "sha256:376889c76d87b95b5719fdd61dd7db193aa7fd4432e5d52d2e44e4c497bdbbee"}, + {file = "ninja-1.11.1.1-py2.py3-none-manylinux1_i686.manylinux_2_5_i686.whl", hash = "sha256:ecf80cf5afd09f14dcceff28cb3f11dc90fb97c999c89307aea435889cb66877"}, + {file = "ninja-1.11.1.1-py2.py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:84502ec98f02a037a169c4b0d5d86075eaf6afc55e1879003d6cab51ced2ea4b"}, + {file = "ninja-1.11.1.1-py2.py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:73b93c14046447c7c5cc892433d4fae65d6364bec6685411cb97a8bcf815f93a"}, + {file = "ninja-1.11.1.1-py2.py3-none-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:18302d96a5467ea98b68e1cae1ae4b4fb2b2a56a82b955193c637557c7273dbd"}, + {file = "ninja-1.11.1.1-py2.py3-none-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:aad34a70ef15b12519946c5633344bc775a7656d789d9ed5fdb0d456383716ef"}, + {file = "ninja-1.11.1.1-py2.py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:d491fc8d89cdcb416107c349ad1e3a735d4c4af5e1cb8f5f727baca6350fdaea"}, + {file = "ninja-1.11.1.1-py2.py3-none-musllinux_1_1_i686.whl", hash = "sha256:7563ce1d9fe6ed5af0b8dd9ab4a214bf4ff1f2f6fd6dc29f480981f0f8b8b249"}, + {file = "ninja-1.11.1.1-py2.py3-none-musllinux_1_1_ppc64le.whl", hash = "sha256:9df724344202b83018abb45cb1efc22efd337a1496514e7e6b3b59655be85205"}, + {file = "ninja-1.11.1.1-py2.py3-none-musllinux_1_1_s390x.whl", hash = "sha256:3e0f9be5bb20d74d58c66cc1c414c3e6aeb45c35b0d0e41e8d739c2c0d57784f"}, + {file = 
"ninja-1.11.1.1-py2.py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:76482ba746a2618eecf89d5253c0d1e4f1da1270d41e9f54dfbd91831b0f6885"}, + {file = "ninja-1.11.1.1-py2.py3-none-win32.whl", hash = "sha256:fa2ba9d74acfdfbfbcf06fad1b8282de8a7a8c481d9dee45c859a8c93fcc1082"}, + {file = "ninja-1.11.1.1-py2.py3-none-win_amd64.whl", hash = "sha256:95da904130bfa02ea74ff9c0116b4ad266174fafb1c707aa50212bc7859aebf1"}, + {file = "ninja-1.11.1.1-py2.py3-none-win_arm64.whl", hash = "sha256:185e0641bde601e53841525c4196278e9aaf4463758da6dd1e752c0a0f54136a"}, + {file = "ninja-1.11.1.1.tar.gz", hash = "sha256:9d793b08dd857e38d0b6ffe9e6b7145d7c485a42dcfea04905ca0cdb6017cc3c"}, +] + +[package.extras] +test = ["codecov (>=2.0.5)", "coverage (>=4.2)", "flake8 (>=3.0.4)", "pytest (>=4.5.0)", "pytest-cov (>=2.7.1)", "pytest-runner (>=5.1)", "pytest-virtualenv (>=1.7.0)", "virtualenv (>=15.0.3)"] + [[package]] name = "nodeenv" version = "1.9.1" @@ -1181,23 +1208,23 @@ test = ["Cython", "array-api-strict (>=2.0)", "asv", "gmpy2", "hypothesis (>=6.3 [[package]] name = "setuptools" -version = "75.2.0" +version = "75.4.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "setuptools-75.2.0-py3-none-any.whl", hash = "sha256:a7fcb66f68b4d9e8e66b42f9876150a3371558f98fa32222ffaa5bced76406f8"}, - {file = "setuptools-75.2.0.tar.gz", hash = "sha256:753bb6ebf1f465a1912e19ed1d41f403a79173a9acf66a42e7e6aec45c3c16ec"}, + {file = "setuptools-75.4.0-py3-none-any.whl", hash = "sha256:b3c5d862f98500b06ffdf7cc4499b48c46c317d8d56cb30b5c8bce4d88f5c216"}, + {file = "setuptools-75.4.0.tar.gz", hash = "sha256:1dc484f5cf56fd3fe7216d7b8df820802e7246cfb534a1db2aa64f14fcb9cdcb"}, ] [package.extras] -check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.5.2)"] -core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.collections", "jaraco.functools", 
"jaraco.text (>=3.7)", "more-itertools", "more-itertools (>=8.8)", "packaging", "packaging (>=24)", "platformdirs (>=2.6.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.7.0)"] +core = ["importlib-metadata (>=6)", "jaraco.collections", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more-itertools", "more-itertools (>=8.8)", "packaging", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] cover = ["pytest-cov"] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"] enabler = ["pytest-enabler (>=2.2)"] -test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] -type = ["importlib-metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.11.*)", "pytest-mypy"] +test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] +type = ["importlib-metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (>=1.12,<1.14)", "pytest-mypy"] [[package]] name = 
"soundfile" @@ -1432,4 +1459,4 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess [metadata] lock-version = "2.0" python-versions = ">=3.10,<4.0" -content-hash = "6d18b2aabee8f3b0d3acd6b615da588a123864dcde0606e7c272f52f02144d72" +content-hash = "100e05e1311afc56ca7d1b20d1f9078c5b529e1756f4de2e2edf3f7ba4a62ee2" diff --git a/pyproject.toml b/pyproject.toml index 0271ccb..73ebc15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,8 +1,8 @@ [tool.poetry] name = "bigvganinference" -version = "0.1.0" +version = "0.0.1" description = "An unofficial minimal package for using BigVGAN at inference time" -authors = ["Tom Hunn "] +authors = [] license = "MIT" readme = "README.md" packages = [{ include = "bigvganinference" }] @@ -13,6 +13,7 @@ torch = ">=2.3.1" librosa = ">=0.8.1" scipy = "^1.14.1" huggingface-hub = "^0.26.0" +ninja = "^1.11.1.1" [tool.poetry.group.dev.dependencies]