diff --git a/CHANGELOG.md b/CHANGELOG.md index d9a014e5..99cf6f31 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). +## [1.3.0a] - TBD + +Adding the MAGNeT model (https://arxiv.org/abs/2401.04577) along with hf checkpoints and a gradio demo app. + +Typo fixes. + ## [1.2.0] - 2024-01-11 Adding stereo models. diff --git a/README.md b/README.md index bb2e41f2..f89a9829 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,7 @@ At the moment, AudioCraft contains the training code and inference code for: * [AudioGen](./docs/AUDIOGEN.md): A state-of-the-art text-to-sound model. * [EnCodec](./docs/ENCODEC.md): A state-of-the-art high fidelity neural audio codec. * [Multi Band Diffusion](./docs/MBD.md): An EnCodec compatible decoder using diffusion. +* [MAGNeT](./docs/MAGNET.md): A state-of-the-art non-autoregressive model for text-to-music and text-to-sound. ## Training code diff --git a/audiocraft/__init__.py b/audiocraft/__init__.py index 018a37a7..85446fb7 100644 --- a/audiocraft/__init__.py +++ b/audiocraft/__init__.py @@ -23,4 +23,4 @@ # flake8: noqa from . import data, modules, models -__version__ = '1.2.0' +__version__ = '1.3.0a' diff --git a/audiocraft/grids/magnet/__init__.py b/audiocraft/grids/magnet/__init__.py new file mode 100644 index 00000000..fb497091 --- /dev/null +++ b/audiocraft/grids/magnet/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""MAGNeT grids.""" diff --git a/audiocraft/grids/magnet/audio_magnet_16khz.py b/audiocraft/grids/magnet/audio_magnet_16khz.py new file mode 100644 index 00000000..d8ed75db --- /dev/null +++ b/audiocraft/grids/magnet/audio_magnet_16khz.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from ..musicgen._explorers import LMExplorer +from ...environment import AudioCraftEnvironment + + +@LMExplorer +def explorer(launcher): + partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) + launcher.slurm_(gpus=32, partition=partitions) + launcher.bind_(solver='magnet/audio_magnet_16khz') + # replace this by the desired environmental sound dataset + launcher.bind_(dset='internal/sounds_16khz') + + fsdp = {'autocast': False, 'fsdp.use': True} + medium = {'model/lm/model_scale': 'medium'} + + # Small model (300M) + launcher.slurm_(gpus=32).bind_(label='32gpus') + with launcher.job_array(): + sub = launcher.bind() + sub() + + # Medium model (1.5B) + launcher.slurm_(gpus=64).bind_(label='64gpus') + with launcher.job_array(): + sub = launcher.bind() + sub(medium, fsdp) diff --git a/audiocraft/grids/magnet/audio_magnet_pretrained_16khz_eval.py b/audiocraft/grids/magnet/audio_magnet_pretrained_16khz_eval.py new file mode 100644 index 00000000..71282fef --- /dev/null +++ b/audiocraft/grids/magnet/audio_magnet_pretrained_16khz_eval.py @@ -0,0 +1,74 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Evaluation with objective metrics for the pretrained audio-MAGNeT models. 
+This grid takes signature from the training grid and runs evaluation-only stage. + +When running the grid for the first time, please use: +REGEN=1 dora grid magnet.audio_magnet_pretrained_16khz_eval +and re-use the REGEN=1 option when the grid is changed to force regenerating it. + +Note that you need the proper metrics external libraries setup to use all +the objective metrics activated in this grid. Refer to the README for more information. +""" + +import os + +from ..musicgen._explorers import GenerationEvalExplorer +from ...environment import AudioCraftEnvironment +from ... import train + + +def eval(launcher, batch_size: int = 32): + opts = { + 'dset': 'audio/audiocaps_16khz', + 'solver/audiogen/evaluation': 'objective_eval', + 'execute_only': 'evaluate', + '+dataset.evaluate.batch_size': batch_size, + '+metrics.fad.tf.batch_size': 32, + } + # binary for FAD computation: replace this path with your own path + metrics_opts = { + 'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research' + } + + sub = launcher.bind(opts) + sub.bind_(metrics_opts) + + # base objective metrics + sub() + + +@GenerationEvalExplorer +def explorer(launcher): + partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) + launcher.slurm_(gpus=4, partition=partitions) + + if 'REGEN' not in os.environ: + folder = train.main.dora.dir / 'grids' / __name__.split('.', 2)[-1] + with launcher.job_array(): + for sig in folder.iterdir(): + if not sig.is_symlink(): + continue + xp = train.main.get_xp_from_sig(sig.name) + launcher(xp.argv) + return + + with launcher.job_array(): + audio_magnet = launcher.bind(solver="magnet/audio_magnet_16khz") + + fsdp = {'autocast': False, 'fsdp.use': True} + + # Small audio-MAGNeT model (300M) + audio_magnet_small = audio_magnet.bind({'continue_from': '//pretrained/facebook/audio-magnet-small'}) + eval(audio_magnet_small, batch_size=128) + + # Medium audio-MAGNeT model (1.5B) + audio_magnet_medium = audio_magnet.bind({'continue_from': '//pretrained/facebook/audio-magnet-medium'}) + audio_magnet_medium.bind_({'model/lm/model_scale': 'medium'}) + audio_magnet_medium.bind_(fsdp) + eval(audio_magnet_medium, batch_size=128) diff --git a/audiocraft/grids/magnet/magnet_32khz.py b/audiocraft/grids/magnet/magnet_32khz.py new file mode 100644 index 00000000..c3575b30 --- /dev/null +++ b/audiocraft/grids/magnet/magnet_32khz.py @@ -0,0 +1,47 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +from ..musicgen._explorers import LMExplorer +from ...environment import AudioCraftEnvironment + + +@LMExplorer +def explorer(launcher): + partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) + launcher.slurm_(gpus=32, partition=partitions) + launcher.bind_(solver='magnet/magnet_base_32khz') + # replace this by the desired music dataset + launcher.bind_(dset='internal/music_400k_32khz') + + fsdp = {'autocast': False, 'fsdp.use': True} + medium = {'model/lm/model_scale': 'medium'} + adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4} + segdur_10secs = {'dataset.segment_duration': 10, + 'dataset.batch_size': 576, + 'generate.lm.decoding_steps': [20, 10, 10, 10]} + + # Small models (300M) + launcher.slurm_(gpus=32).bind_(label='32gpus') + with launcher.job_array(): + # 30 seconds + sub = launcher.bind() + sub() + + # 10 seconds + sub = launcher.bind() + sub(segdur_10secs) + + # Medium models (1.5B) + launcher.bind_(fsdp) + launcher.slurm_(gpus=64).bind_(label='64gpus') + with launcher.job_array(): + # 30 seconds + sub = launcher.bind() + sub(medium, adam) + + # 10 seconds + sub = launcher.bind() + sub(segdur_10secs) diff --git a/audiocraft/grids/magnet/magnet_pretrained_32khz_eval.py b/audiocraft/grids/magnet/magnet_pretrained_32khz_eval.py new file mode 100644 index 00000000..2aaabc9b --- /dev/null +++ b/audiocraft/grids/magnet/magnet_pretrained_32khz_eval.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Evaluation with objective metrics for the pretrained MAGNeT models. +This grid takes signature from the training grid and runs evaluation-only stage. + +When running the grid for the first time, please use: +REGEN=1 dora grid magnet.magnet_pretrained_32khz_eval +and re-use the REGEN=1 option when the grid is changed to force regenerating it. + +Note that you need the proper metrics external libraries setup to use all +the objective metrics activated in this grid. Refer to the README for more information. +""" + +import os + +from ..musicgen._explorers import GenerationEvalExplorer +from ...environment import AudioCraftEnvironment +from ... 
import train + + +def eval(launcher, batch_size: int = 32): + opts = { + 'dset': 'audio/musiccaps_32khz', + 'solver/musicgen/evaluation': 'objective_eval', + 'execute_only': 'evaluate', + '+dataset.evaluate.batch_size': batch_size, + '+metrics.fad.tf.batch_size': 16, + } + # binary for FAD computation: replace this path with your own path + metrics_opts = { + 'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research' + } + + sub = launcher.bind(opts) + sub.bind_(metrics_opts) + + # base objective metrics + sub() + + +@GenerationEvalExplorer +def explorer(launcher): + partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) + launcher.slurm_(gpus=4, partition=partitions) + + if 'REGEN' not in os.environ: + folder = train.main.dora.dir / 'grids' / __name__.split('.', 2)[-1] + with launcher.job_array(): + for sig in folder.iterdir(): + if not sig.is_symlink(): + continue + xp = train.main.get_xp_from_sig(sig.name) + launcher(xp.argv) + return + + with launcher.job_array(): + magnet = launcher.bind(solver="magnet/magnet_32khz") + + fsdp = {'autocast': False, 'fsdp.use': True} + + segdur_10secs = {'dataset.segment_duration': 10, + 'generate.lm.decoding_steps': [20, 10, 10, 10]} + + # 10-second magnet models + magnet_small_10secs = magnet.bind({'continue_from': '//pretrained/facebook/magnet-small-10secs'}) + magnet_small_10secs.bind_(segdur_10secs) + eval(magnet_small_10secs, batch_size=128) + + magnet_medium_10secs = magnet.bind({'continue_from': '//pretrained/facebook/magnet-medium-10secs'}) + magnet_medium_10secs.bind_(segdur_10secs) + magnet_medium_10secs.bind_({'model/lm/model_scale': 'medium'}) + magnet_medium_10secs.bind_(fsdp) + eval(magnet_medium_10secs, batch_size=128) + + # 30-second magnet models + magnet_small_30secs = magnet.bind({'continue_from': '//pretrained/facebook/magnet-small-30secs'}) + eval(magnet_small_30secs, batch_size=128) + + magnet_medium_30secs = magnet.bind({'continue_from': '//pretrained/facebook/magnet-medium-30secs'}) + magnet_medium_30secs.bind_({'model/lm/model_scale': 'medium'}) + magnet_medium_30secs.bind_(fsdp) + eval(magnet_medium_30secs, batch_size=128) diff --git a/audiocraft/models/__init__.py b/audiocraft/models/__init__.py index be6bfe4b..a6b49825 100644 --- a/audiocraft/models/__init__.py +++ b/audiocraft/models/__init__.py @@ -13,6 +13,8 @@ HFEncodecModel, HFEncodecCompressionModel) from .audiogen import AudioGen from .lm import LMModel +from .lm_magnet import MagnetLMModel from .multibanddiffusion import MultiBandDiffusion from .musicgen import MusicGen +from .magnet import MAGNeT from .unet import DiffusionUnet diff --git a/audiocraft/models/audiogen.py b/audiocraft/models/audiogen.py index b4df536e..5f0e7f36 100644 --- a/audiocraft/models/audiogen.py +++ b/audiocraft/models/audiogen.py @@ -14,15 +14,13 @@ import torch from .encodec import CompressionModel +from .genmodel import BaseGenModel from .lm import LMModel from .builders import get_debug_compression_model, get_debug_lm_model from .loaders import load_compression_model, load_lm_model -from ..data.audio_utils import convert_audio -from ..modules.conditioners import ConditioningAttributes -from ..utils.autocast import TorchAutocast -class AudioGen: +class AudioGen(BaseGenModel): """AudioGen main model with convenient generation API. 
Args: @@ -35,44 +33,8 @@ class AudioGen: """ def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel, max_duration: tp.Optional[float] = None): - self.name = name - self.compression_model = compression_model - self.lm = lm - # Just to be safe, let's put everything in eval mode. - self.compression_model.eval() - self.lm.eval() - - if max_duration is None: - if hasattr(lm, 'cfg'): - max_duration = lm.cfg.dataset.segment_duration # type: ignore - else: - raise ValueError("You must provide max_duration when building directly AudioGen") - assert max_duration is not None - self.max_duration: float = max_duration - self.device = next(iter(lm.parameters())).device - self.generation_params: dict = {} - self.set_generation_params(duration=5) # 5 seconds by default - self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None - if self.device.type == 'cpu': - self.autocast = TorchAutocast(enabled=False) - else: - self.autocast = TorchAutocast( - enabled=True, device_type=self.device.type, dtype=torch.float16) - - @property - def frame_rate(self) -> float: - """Roughly the number of AR steps per seconds.""" - return self.compression_model.frame_rate - - @property - def sample_rate(self) -> int: - """Sample rate of the generated audio.""" - return self.compression_model.sample_rate - - @property - def audio_channels(self) -> int: - """Audio channels of the generated audio.""" - return self.compression_model.channels + super().__init__(name, compression_model, lm, max_duration) + self.set_generation_params(duration=5) # default duration @staticmethod def get_pretrained(name: str = 'facebook/audiogen-medium', device=None): @@ -129,139 +91,3 @@ def set_generation_params(self, use_sampling: bool = True, top_k: int = 250, 'cfg_coef': cfg_coef, 'two_step_cfg': two_step_cfg, } - - def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None): - """Override the default progress callback.""" - self._progress_callback = progress_callback - - def generate(self, descriptions: tp.List[str], progress: bool = False) -> torch.Tensor: - """Generate samples conditioned on text. - - Args: - descriptions (list of str): A list of strings used as text conditioning. - progress (bool, optional): Flag to display progress of the generation process. Defaults to False. - """ - attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None) - assert prompt_tokens is None - return self._generate_tokens(attributes, prompt_tokens, progress) - - def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int, - descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None, - progress: bool = False) -> torch.Tensor: - """Generate samples conditioned on audio prompts. - - Args: - prompt (torch.Tensor): A batch of waveforms used for continuation. - Prompt should be [B, C, T], or [C, T] if only one sample is generated. - prompt_sample_rate (int): Sampling rate of the given audio waveforms. - descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None. - progress (bool, optional): Flag to display progress of the generation process. Defaults to False. 
- """ - if prompt.dim() == 2: - prompt = prompt[None] - if prompt.dim() != 3: - raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).") - prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels) - if descriptions is None: - descriptions = [None] * len(prompt) - attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt) - assert prompt_tokens is not None - return self._generate_tokens(attributes, prompt_tokens, progress) - - @torch.no_grad() - def _prepare_tokens_and_attributes( - self, - descriptions: tp.Sequence[tp.Optional[str]], - prompt: tp.Optional[torch.Tensor], - ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]: - """Prepare model inputs. - - Args: - descriptions (list of str): A list of strings used as text conditioning. - prompt (torch.Tensor): A batch of waveforms used for continuation. - """ - attributes = [ - ConditioningAttributes(text={'description': description}) - for description in descriptions] - - if prompt is not None: - if descriptions is not None: - assert len(descriptions) == len(prompt), "Prompt and nb. descriptions doesn't match" - prompt = prompt.to(self.device) - prompt_tokens, scale = self.compression_model.encode(prompt) - assert scale is None - else: - prompt_tokens = None - return attributes, prompt_tokens - - def _generate_tokens(self, attributes: tp.List[ConditioningAttributes], - prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor: - """Generate discrete audio tokens given audio prompt and/or conditions. - - Args: - attributes (list of ConditioningAttributes): Conditions used for generation (here text). - prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation. - progress (bool, optional): Flag to display progress of the generation process. Defaults to False. - Returns: - torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params. - """ - total_gen_len = int(self.duration * self.frame_rate) - max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate) - current_gen_offset: int = 0 - - def _progress_callback(generated_tokens: int, tokens_to_generate: int): - generated_tokens += current_gen_offset - if self._progress_callback is not None: - # Note that total_gen_len might be quite wrong depending on the - # codebook pattern used, but with delay it is almost accurate. - self._progress_callback(generated_tokens, total_gen_len) - else: - print(f'{generated_tokens: 6d} / {total_gen_len: 6d}', end='\r') - - if prompt_tokens is not None: - assert max_prompt_len >= prompt_tokens.shape[-1], \ - "Prompt is longer than audio to generate" - - callback = None - if progress: - callback = _progress_callback - - if self.duration <= self.max_duration: - # generate by sampling from LM, simple case. 
- with self.autocast: - gen_tokens = self.lm.generate( - prompt_tokens, attributes, - callback=callback, max_gen_len=total_gen_len, **self.generation_params) - - else: - all_tokens = [] - if prompt_tokens is None: - prompt_length = 0 - else: - all_tokens.append(prompt_tokens) - prompt_length = prompt_tokens.shape[-1] - - stride_tokens = int(self.frame_rate * self.extend_stride) - while current_gen_offset + prompt_length < total_gen_len: - time_offset = current_gen_offset / self.frame_rate - chunk_duration = min(self.duration - time_offset, self.max_duration) - max_gen_len = int(chunk_duration * self.frame_rate) - with self.autocast: - gen_tokens = self.lm.generate( - prompt_tokens, attributes, - callback=callback, max_gen_len=max_gen_len, **self.generation_params) - if prompt_tokens is None: - all_tokens.append(gen_tokens) - else: - all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:]) - prompt_tokens = gen_tokens[:, :, stride_tokens:] - prompt_length = prompt_tokens.shape[-1] - current_gen_offset += stride_tokens - - gen_tokens = torch.cat(all_tokens, dim=-1) - - # generate audio - assert gen_tokens.dim() == 3 - with torch.no_grad(): - gen_audio = self.compression_model.decode(gen_tokens, None) - return gen_audio diff --git a/audiocraft/models/builders.py b/audiocraft/models/builders.py index b7144874..66aa85c6 100644 --- a/audiocraft/models/builders.py +++ b/audiocraft/models/builders.py @@ -17,6 +17,7 @@ from .encodec import CompressionModel, EncodecModel, InterleaveStereoCompressionModel from .lm import LMModel +from .lm_magnet import MagnetLMModel from ..modules.codebooks_patterns import ( CodebooksPatternProvider, DelayedPatternProvider, @@ -85,7 +86,7 @@ def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel: def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel: """Instantiate a transformer LM.""" - if cfg.lm_model == 'transformer_lm': + if cfg.lm_model in ['transformer_lm', 'transformer_lm_magnet']: kwargs = dict_from_config(getattr(cfg, 'transformer_lm')) n_q = kwargs['n_q'] q_modeling = kwargs.pop('q_modeling', None) @@ -103,8 +104,10 @@ def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel: codebooks_pattern_cfg = omegaconf.OmegaConf.create( {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}} ) + pattern_provider = get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg) - return LMModel( + lm_class = MagnetLMModel if cfg.lm_model == 'transformer_lm_magnet' else LMModel + return lm_class( pattern_provider=pattern_provider, condition_provider=condition_provider, fuser=fuser, diff --git a/audiocraft/models/encodec.py b/audiocraft/models/encodec.py index cb0484ee..627fdddd 100644 --- a/audiocraft/models/encodec.py +++ b/audiocraft/models/encodec.py @@ -99,7 +99,7 @@ def get_pretrained( - dac_24khz (same) - facebook/encodec_24khz (https://huggingface.co/facebook/encodec_24khz) - facebook/encodec_32khz (https://huggingface.co/facebook/encodec_32khz) - - your own model on HugginFace. Export instructions to come... + - your own model on Hugging Face. Export instructions to come... """ from . import builders, loaders diff --git a/audiocraft/models/genmodel.py b/audiocraft/models/genmodel.py new file mode 100644 index 00000000..96397450 --- /dev/null +++ b/audiocraft/models/genmodel.py @@ -0,0 +1,267 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
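For reference, the BaseGenModel introduced in genmodel.py just below centralizes the generation plumbing that used to be duplicated in AudioGen and MusicGen; a downstream class now only has to supply the two abstract methods. A minimal, hypothetical sketch of such a subclass (the class name, defaults and checkpoint id are placeholders, not part of this diff):

```python
# Rough sketch only: 'MyGenModel' and 'my-org/my-checkpoint' are illustrative placeholders.
import torch

from audiocraft.models.genmodel import BaseGenModel
from audiocraft.models.loaders import load_compression_model, load_lm_model


class MyGenModel(BaseGenModel):
    """Hypothetical downstream model reusing the shared generation API."""

    def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
                              temperature: float = 1.0, duration: float = 10.0):
        # Everything stored here is forwarded as **kwargs to self.lm.generate() by _generate_tokens().
        self.duration = duration
        self.generation_params = {'use_sampling': use_sampling, 'top_k': top_k, 'temp': temperature}

    @staticmethod
    def get_pretrained(name: str = 'my-org/my-checkpoint', device=None):
        # Same device resolution and checkpoint loaders as the concrete models in this diff.
        if device is None:
            device = 'cuda' if torch.cuda.device_count() else 'cpu'
        compression_model = load_compression_model(name, device=device)
        lm = load_lm_model(name, device=device)
        model = MyGenModel(name, compression_model, lm)
        model.set_generation_params()
        return model
```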
+ +""" +Base implementation for audio generative models. This base implementation +combines all the required components to run inference with pretrained audio +generative models. It can be easily inherited by downstream model classes to +provide easy access to the generation API. +""" + +from abc import ABC, abstractmethod +import typing as tp + +import omegaconf +import torch + +from .encodec import CompressionModel +from .lm import LMModel +from .builders import get_wrapped_compression_model +from ..data.audio_utils import convert_audio +from ..modules.conditioners import ConditioningAttributes +from ..utils.autocast import TorchAutocast + + +class BaseGenModel(ABC): + """Base generative model with convenient generation API. + + Args: + name (str): name of the model. + compression_model (CompressionModel): Compression model + used to map audio to invertible discrete representations. + lm (LMModel): Language model over discrete representations. + max_duration (float, optional): maximum duration the model can produce, + otherwise, inferred from the training params. + """ + def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel, + max_duration: tp.Optional[float] = None): + self.name = name + self.compression_model = compression_model + self.lm = lm + self.cfg: tp.Optional[omegaconf.DictConfig] = None + # Just to be safe, let's put everything in eval mode. + self.compression_model.eval() + self.lm.eval() + + if hasattr(lm, 'cfg'): + cfg = lm.cfg + assert isinstance(cfg, omegaconf.DictConfig) + self.cfg = cfg + + if self.cfg is not None: + self.compression_model = get_wrapped_compression_model(self.compression_model, self.cfg) + + if max_duration is None: + if self.cfg is not None: + max_duration = lm.cfg.dataset.segment_duration # type: ignore + else: + raise ValueError("You must provide max_duration when building directly your GenModel") + assert max_duration is not None + + self.max_duration: float = max_duration + self.duration = self.max_duration + + # self.extend_stride is the length of audio extension when generating samples longer + # than self.max_duration. NOTE: the derived class must set self.extend_stride to a + # positive float value when generating with self.duration > self.max_duration. 
+ self.extend_stride: tp.Optional[float] = None + self.device = next(iter(lm.parameters())).device + self.generation_params: dict = {} + self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None + if self.device.type == 'cpu': + self.autocast = TorchAutocast(enabled=False) + else: + self.autocast = TorchAutocast( + enabled=True, device_type=self.device.type, dtype=torch.float16) + + @property + def frame_rate(self) -> float: + """Roughly the number of AR steps per seconds.""" + return self.compression_model.frame_rate + + @property + def sample_rate(self) -> int: + """Sample rate of the generated audio.""" + return self.compression_model.sample_rate + + @property + def audio_channels(self) -> int: + """Audio channels of the generated audio.""" + return self.compression_model.channels + + def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None): + """Override the default progress callback.""" + self._progress_callback = progress_callback + + @abstractmethod + def set_generation_params(self, *args, **kwargs): + """Set the generation parameters.""" + raise NotImplementedError("No base implementation for setting generation params.") + + @staticmethod + @abstractmethod + def get_pretrained(name: str, device=None): + raise NotImplementedError("No base implementation for getting pretrained model") + + @torch.no_grad() + def _prepare_tokens_and_attributes( + self, + descriptions: tp.Sequence[tp.Optional[str]], + prompt: tp.Optional[torch.Tensor], + ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]: + """Prepare model inputs. + + Args: + descriptions (list of str): A list of strings used as text conditioning. + prompt (torch.Tensor): A batch of waveforms used for continuation. + """ + attributes = [ + ConditioningAttributes(text={'description': description}) + for description in descriptions] + + if prompt is not None: + if descriptions is not None: + assert len(descriptions) == len(prompt), "Prompt and nb. descriptions doesn't match" + prompt = prompt.to(self.device) + prompt_tokens, scale = self.compression_model.encode(prompt) + assert scale is None + else: + prompt_tokens = None + return attributes, prompt_tokens + + def generate_unconditional(self, num_samples: int, progress: bool = False, + return_tokens: bool = False) -> tp.Union[torch.Tensor, + tp.Tuple[torch.Tensor, torch.Tensor]]: + """Generate samples in an unconditional manner. + + Args: + num_samples (int): Number of samples to be generated. + progress (bool, optional): Flag to display progress of the generation process. Defaults to False. + """ + descriptions: tp.List[tp.Optional[str]] = [None] * num_samples + attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None) + tokens = self._generate_tokens(attributes, prompt_tokens, progress) + if return_tokens: + return self.generate_audio(tokens), tokens + return self.generate_audio(tokens) + + def generate(self, descriptions: tp.List[str], progress: bool = False, return_tokens: bool = False) \ + -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]: + """Generate samples conditioned on text. + + Args: + descriptions (list of str): A list of strings used as text conditioning. + progress (bool, optional): Flag to display progress of the generation process. Defaults to False. 
+ """ + attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None) + assert prompt_tokens is None + tokens = self._generate_tokens(attributes, prompt_tokens, progress) + if return_tokens: + return self.generate_audio(tokens), tokens + return self.generate_audio(tokens) + + def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int, + descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None, + progress: bool = False, return_tokens: bool = False) \ + -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]: + """Generate samples conditioned on audio prompts and an optional text description. + + Args: + prompt (torch.Tensor): A batch of waveforms used for continuation. + Prompt should be [B, C, T], or [C, T] if only one sample is generated. + prompt_sample_rate (int): Sampling rate of the given audio waveforms. + descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None. + progress (bool, optional): Flag to display progress of the generation process. Defaults to False. + """ + if prompt.dim() == 2: + prompt = prompt[None] + if prompt.dim() != 3: + raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).") + prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels) + if descriptions is None: + descriptions = [None] * len(prompt) + attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt) + assert prompt_tokens is not None + tokens = self._generate_tokens(attributes, prompt_tokens, progress) + if return_tokens: + return self.generate_audio(tokens), tokens + return self.generate_audio(tokens) + + def _generate_tokens(self, attributes: tp.List[ConditioningAttributes], + prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor: + """Generate discrete audio tokens given audio prompt and/or conditions. + + Args: + attributes (list of ConditioningAttributes): Conditions used for generation (here text). + prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation. + progress (bool, optional): Flag to display progress of the generation process. Defaults to False. + Returns: + torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params. + """ + total_gen_len = int(self.duration * self.frame_rate) + max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate) + current_gen_offset: int = 0 + + def _progress_callback(generated_tokens: int, tokens_to_generate: int): + generated_tokens += current_gen_offset + if self._progress_callback is not None: + # Note that total_gen_len might be quite wrong depending on the + # codebook pattern used, but with delay it is almost accurate. + self._progress_callback(generated_tokens, tokens_to_generate) + else: + print(f'{generated_tokens: 6d} / {tokens_to_generate: 6d}', end='\r') + + if prompt_tokens is not None: + assert max_prompt_len >= prompt_tokens.shape[-1], \ + "Prompt is longer than audio to generate" + + callback = None + if progress: + callback = _progress_callback + + if self.duration <= self.max_duration: + # generate by sampling from LM, simple case. 
+ with self.autocast: + gen_tokens = self.lm.generate( + prompt_tokens, attributes, + callback=callback, max_gen_len=total_gen_len, **self.generation_params) + + else: + assert self.extend_stride is not None, "Stride should be defined to generate beyond max_duration" + assert self.extend_stride < self.max_duration, "Cannot stride by more than max generation duration." + all_tokens = [] + if prompt_tokens is None: + prompt_length = 0 + else: + all_tokens.append(prompt_tokens) + prompt_length = prompt_tokens.shape[-1] + + stride_tokens = int(self.frame_rate * self.extend_stride) + while current_gen_offset + prompt_length < total_gen_len: + time_offset = current_gen_offset / self.frame_rate + chunk_duration = min(self.duration - time_offset, self.max_duration) + max_gen_len = int(chunk_duration * self.frame_rate) + with self.autocast: + gen_tokens = self.lm.generate( + prompt_tokens, attributes, + callback=callback, max_gen_len=max_gen_len, **self.generation_params) + if prompt_tokens is None: + all_tokens.append(gen_tokens) + else: + all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:]) + prompt_tokens = gen_tokens[:, :, stride_tokens:] + prompt_length = prompt_tokens.shape[-1] + current_gen_offset += stride_tokens + + gen_tokens = torch.cat(all_tokens, dim=-1) + return gen_tokens + + def generate_audio(self, gen_tokens: torch.Tensor) -> torch.Tensor: + """Generate Audio from tokens.""" + assert gen_tokens.dim() == 3 + with torch.no_grad(): + gen_audio = self.compression_model.decode(gen_tokens, None) + return gen_audio diff --git a/audiocraft/models/lm.py b/audiocraft/models/lm.py index c4ea2e5e..3bd68109 100644 --- a/audiocraft/models/lm.py +++ b/audiocraft/models/lm.py @@ -219,7 +219,8 @@ def num_codebooks(self) -> int: def forward(self, sequence: torch.Tensor, conditions: tp.List[ConditioningAttributes], - condition_tensors: tp.Optional[ConditionTensors] = None) -> torch.Tensor: + condition_tensors: tp.Optional[ConditionTensors] = None, + stage: int = -1) -> torch.Tensor: """Apply language model on sequence and conditions. Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and S the sequence steps, return the logits with shape [B, card, K, S]. @@ -231,6 +232,9 @@ def forward(self, sequence: torch.Tensor, you should pre-compute those and pass them as `condition_tensors`. condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning tensors, see `conditions`. + stage (int): The codebook level that is being predicted. Relevant for MAGNeT + in which prediction is done in a codebook-by-codebook manner. + Takes values in range(n_q), and ignored by default. Returns: torch.Tensor: Logits. 
""" @@ -250,7 +254,8 @@ def forward(self, sequence: torch.Tensor, input_, cross_attention_input = self.fuser(input_, condition_tensors) - out = self.transformer(input_, cross_attention_src=cross_attention_input) + out = self.transformer(input_, cross_attention_src=cross_attention_input, + src_mask=(self.attn_mask_per_stage[stage] if stage >= 0 else None)) if self.out_norm: out = self.out_norm(out) logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1) # [B, K, S, card] @@ -264,7 +269,9 @@ def forward(self, sequence: torch.Tensor, def compute_predictions( self, codes: torch.Tensor, conditions: tp.List[ConditioningAttributes], - condition_tensors: tp.Optional[ConditionTensors] = None) -> LMOutput: + condition_tensors: tp.Optional[ConditionTensors] = None, + stage: int = -1, + keep_only_valid_steps: bool = True) -> LMOutput: """Given an input tensor of codes [B, K, T] and list of conditions, runs the model forward using the specified codes interleaving pattern. @@ -276,6 +283,11 @@ def compute_predictions( you should pre-compute those and pass them as `condition_tensors`. condition_tensors (dict[str, ConditionType], optional): pre-computed conditioning tensors, see `conditions`. + stage (int): The codebook level that is being predicted. Relevant for MAGNeT + in which prediction is done in a codebook-by-codebook manner. + Takes values in range(n_q), and ignored by default. + keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps. + Steps that are beyond valid steps will be replaced by the special_token in that case. Returns: LMOutput: Language model outputs logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes, @@ -290,17 +302,18 @@ def compute_predictions( # map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens pattern = self.pattern_provider.get_pattern(T) sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence( - codes, self.special_token_id, keep_only_valid_steps=True + codes, self.special_token_id, keep_only_valid_steps=keep_only_valid_steps, ) + # apply model on pattern sequence model = self if self._fsdp is None else self._fsdp - logits = model(sequence_codes, conditions, condition_tensors) # [B, K, S, card] + logits = model(sequence_codes, conditions, condition_tensors, stage=stage) # [B, K, S, card] # map back the logits on pattern sequence to logits on original codes: [B, K, S, card] -> [B, K, T, card] # and provide the corresponding mask over invalid positions of tokens logits = logits.permute(0, 3, 1, 2) # [B, card, K, S] # note: we use nans as special token to make it obvious if we feed unexpected logits logits, logits_indexes, logits_mask = pattern.revert_pattern_logits( - logits, float('nan'), keep_only_valid_steps=True + logits, float('nan'), keep_only_valid_steps=keep_only_valid_steps ) logits = logits.permute(0, 2, 3, 1) # [B, K, T, card] logits_mask = logits_mask[None, :, :].expand(B, -1, -1) # [K, T] -> [B, K, T] @@ -393,9 +406,10 @@ def generate(self, two_step_cfg: tp.Optional[bool] = None, remove_prompts: bool = False, check: bool = False, - callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> torch.Tensor: + callback: tp.Optional[tp.Callable[[int, int], None]] = None, + **kwargs) -> torch.Tensor: """Generate tokens sampling from the model given a prompt or unconditionally. Generation can - be perform in a greedy fashion or using sampling with top K and top P strategies. 
+ be performed in a greedy fashion or using sampling with top K and top P strategies. Args: prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T]. diff --git a/audiocraft/models/lm_magnet.py b/audiocraft/models/lm_magnet.py new file mode 100644 index 00000000..4c2ab9ee --- /dev/null +++ b/audiocraft/models/lm_magnet.py @@ -0,0 +1,490 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import math +import typing as tp +import torch +import numpy as np + +from ..utils import utils +from ..modules.conditioners import ( + ClassifierFreeGuidanceDropout, + ConditioningAttributes, + ConditionType, +) +from .lm import LMModel + +logger = logging.getLogger(__name__) +ConditionTensors = tp.Dict[str, ConditionType] +CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]] + + +class MagnetLMModel(LMModel): + """Transformer-based, non-autoregressive model, operates on multiple streams of audio tokens (MAGNeT). + Args: + subcodes_context (int): The number of timesteps attended in the self-attention blocks of codebooks > 0. + When set to -1, attention is unrestricted and all timesteps are attended. Defaults to 5. + compression_model_framerate (int): frame rate of the audio tokenizer. + segment_duration (int): Sample length in seconds. + span_len (int): Determines the length of masking spans. This is the minimal length of consecutive masked tokens, + for both training and inference. Defaults to 3. + **kwargs: Additional parameters for the LMModel. + """ + def __init__(self, subcodes_context: int = 5, compression_model_framerate: int = 50, + segment_duration: int = 10, span_len: int = 3, **kwargs): + super().__init__(**kwargs) + self.causal = kwargs['causal'] + self.subcodes_context = subcodes_context + self.span_len = span_len + self._build_attn_masks(compression_model_framerate, segment_duration, + device=kwargs['device'], dtype=kwargs['dtype']) + + def restricted_context_attn_mask(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor: + """Creates a restricted attention mask (local attention map) where the context + is determined by self.subcodes_context. + Args: + seq_len (int): token sequence length. + device (torch.device): device of the output tensor. + dtype (torch.dtype): data type of the output tensor. + Returns: + torch.Tensor: The restricted attention mask. + """ + # Return a context restricted non-causal att mask + queries_pos = torch.arange(seq_len, device=device).view(-1, 1) + keys_pos = torch.arange(seq_len, device=device).view(1, -1) + + delta = queries_pos - keys_pos + valid = torch.abs(delta) <= self.subcodes_context + return torch.where( + valid, + torch.zeros([], device=device, dtype=dtype), + torch.full([], float('-inf'), device=device, dtype=dtype)) + + def _stage_attn_mask(self, stage: int, seq_len: int, + device: torch.device, dtype: torch.dtype) -> tp.Optional[torch.Tensor]: + """Creates a restricted attention mask given the stage (codebook index). + Args: + stage (int): The codebook index. Takes values in [0, n_q]. + seq_len (int): Token sequence length. + device (torch.device): device of the output tensor. + dtype (torch.dtype): data type of the output tensor. + Returns: + torch.Tensor: Either a restricted attention mask or None if stage attention is unrestricted. 
+ """ + sa_mask = None + + if stage > 0 and self.subcodes_context > -1: + # parallel - non-causal - with restricted subcodes context + sa_mask = self.restricted_context_attn_mask(seq_len, device=device, dtype=dtype) + + if sa_mask is not None: + # align8 to enable memory efficient attention + MEMORY_EFFICIENT_ATTN_ALIGN_FACTOR = 8 + seq_len_aligned = \ + int(np.ceil(seq_len / MEMORY_EFFICIENT_ATTN_ALIGN_FACTOR)) * MEMORY_EFFICIENT_ATTN_ALIGN_FACTOR + + sa_mask_aligned = torch.zeros((seq_len_aligned, seq_len_aligned), device=device, dtype=dtype) + sa_mask_aligned[:seq_len, :seq_len] = sa_mask + sa_mask = sa_mask_aligned + + return sa_mask + + def _build_attn_masks(self, compression_model_framerate: int, segment_duration: int, + device: torch.device, dtype: torch.dtype): + """Construct attention mask per stage. For each of the RVQ codebook levels in the [0, n_q] range, + either a local attention map or None would be stored as an entry in the self.attn_mask_per_stage list. + Args: + compression_model_framerate (int): The frame rate of the tokenizer. + segment_duration (int): Sample length in seconds. + device (torch.device): device of the output tensor. + dtype (torch.dtype): data type of the output tensor. + """ + seq_len = compression_model_framerate * segment_duration + self.attn_mask_per_stage = [self._stage_attn_mask(stage, seq_len, device, dtype) for stage in range(self.n_q)] + + @torch.no_grad() + def generate(self, + prompt: tp.Optional[torch.Tensor] = None, + conditions: tp.List[ConditioningAttributes] = [], + num_samples: tp.Optional[int] = None, + max_gen_len: int = 256, + use_sampling: bool = True, + temp: float = 1.0, + top_k: int = 250, + top_p: float = 0.0, + cfg_coef: tp.Optional[float] = None, + two_step_cfg: tp.Optional[bool] = None, + remove_prompts: bool = False, + check: bool = False, + callback: tp.Optional[tp.Callable[[int, int], None]] = None, + **kwargs) -> torch.Tensor: + + assert cfg_coef is None, "Unsupported in MAGNeT. Use max_cfg_coef,min_cfg_coef instead." + assert two_step_cfg is None, "MAGNeT currently doesn't support two step classifier-free-guidance." + assert remove_prompts is False, "MAGNeT currently doesn't support the remove_prompts arg." + assert check is False, "MAGNeT currently doesn't support the check arg." + # Call the MAGNeT-specific generation method + return self._generate_magnet(prompt=prompt, + conditions=conditions, + num_samples=num_samples, + max_gen_len=max_gen_len, + use_sampling=use_sampling, + temp=temp, + top_k=top_k, + top_p=top_p, + callback=callback, **kwargs) + + @torch.no_grad() + def _generate_magnet(self, + prompt: tp.Optional[torch.Tensor] = None, + conditions: tp.List[ConditioningAttributes] = [], + num_samples: tp.Optional[int] = None, + max_gen_len: int = 256, + use_sampling: bool = True, + temp: float = 3.0, + top_k: int = 0, + top_p: float = 0.9, + callback: tp.Optional[tp.Callable[[int, int], None]] = None, + max_cfg_coef: float = 10.0, + min_cfg_coef: float = 1.0, + decoding_steps: tp.List[int] = [20, 10, 10, 10], + anneal_temp: bool = True, + span_scoring='max', + span_arrangement='nonoverlap') -> torch.Tensor: + """Generate audio tokens given textual conditions, and optionally given audio prompts, + by running MAGNeT's iterative decoding algorithm for each of the n_q RVQ levels. + Args: + prompt (torch.Tensor): Prompt tokens of shape [B, K, T]. + conditions (list of ConditioningAttributes): List of conditions. + num_samples (int): Number of samples to generate when no prompt and no conditions are given. 
            max_gen_len (int): Maximum generation length.
+            use_sampling (bool): Whether to use a sampling strategy or not.
+            temp (float): Initial sampling temperature.
+            top_k (int): k for "top-k" sampling.
+            top_p (float): p for "top-p" sampling.
+            callback (Callback): Callback function to report generation progress.
+            max_cfg_coef (float): Initial coefficient used for classifier free guidance.
+            min_cfg_coef (float): Final coefficient used for classifier free guidance.
+            decoding_steps (list of n_q ints): The number of iterative decoding steps,
+                for each of the n_q RVQ codebooks.
+            anneal_temp (bool): When set to True, softmax temperature will be linearly decayed to zero at each stage.
+            span_scoring (str): Use the maximum probability of each span ('max')
+                or the product of probabilities ('prod').
+            span_arrangement (str): Use either non-overlapping spans ('nonoverlap') or overlapping spans ('stride1')
+                in the masking scheme.
+        Returns:
+            torch.Tensor: Generated tokens.
+        """
+        assert not self.training, "generation shouldn't be used in training mode."
+        first_param = next(iter(self.parameters()))
+        device = first_param.device
+
+        # Checking all input shapes are consistent.
+        possible_num_samples = []
+        if num_samples is not None:
+            possible_num_samples.append(num_samples)
+        elif prompt is not None:
+            possible_num_samples.append(prompt.shape[0])
+        elif conditions:
+            possible_num_samples.append(len(conditions))
+        else:
+            possible_num_samples.append(1)
+        assert [x == possible_num_samples[0] for x in possible_num_samples], "Inconsistent inputs shapes"
+        num_samples = possible_num_samples[0]
+
+        # below we create set of conditions: one conditional and one unconditional
+        # to do that we merge the regular condition together with the null condition
+        # we then do 1 forward pass instead of 2.
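# Note: the conditional and null-condition batches built below are concatenated along the batch
# dimension, so each model call in _generate_stage evaluates both halves in one forward pass.
# They are split back with all_logits.split(B, dim=0) and mixed as
# uncond_logits + (cond_logits - uncond_logits) * clsfg_coef, with the guidance coefficient
# annealed from max_cfg_coef down to min_cfg_coef as the masking rate mask_p decreases.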
+ cfg_conditions: tp.Optional[ConditionTensors] + if conditions: + null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions) + conditions = conditions + null_conditions + tokenized = self.condition_provider.tokenize(conditions) + cfg_conditions = self.condition_provider(tokenized) + else: + cfg_conditions = {} + + if prompt is None: + assert num_samples > 0 + prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device) + + B, K, prompt_length = prompt.shape + start_offset = prompt_length + assert start_offset < max_gen_len + + mask_id = self.special_token_id + + # we generate codes with a fixed sequence length + shape = (B, K, max_gen_len) + + gen_codes = torch.full(shape, mask_id, dtype=torch.long, device=device) + # filling the gen_codes with the prompt if needed + gen_codes[..., :start_offset] = prompt + # create the gen_sequence with proper interleaving from the pattern: [B, K, S] + gen_sequence = gen_codes + + curr_step = 0 + for stage, n_steps in zip(range(self.n_q), decoding_steps): + gen_sequence, curr_step = self._generate_stage(gen_sequence, + cfg_conditions, + stage=stage, + device=device, + prompt_length=prompt_length, + prompt=prompt, + temp=temp, + max_cfg_coef=max_cfg_coef, + min_cfg_coef=min_cfg_coef, + top_k=top_k, + top_p=top_p, + timesteps=n_steps, + anneal_temp=anneal_temp, + span_scoring=span_scoring, + use_sampling=use_sampling, + span_arrangement=span_arrangement, + curr_step=curr_step, + total_steps=sum(decoding_steps), + callback=callback) + + return gen_sequence + + @torch.no_grad() + def _generate_stage(self, + gen_sequence: torch.Tensor, + condition_tensors: tp.Optional[ConditionTensors], + stage: int, + device: torch.device, + prompt_length: int = 0, + prompt: tp.Optional[torch.Tensor] = None, + use_sampling: bool = True, + temp: float = 3.0, + max_cfg_coef: float = 10.0, + min_cfg_coef: float = 1.0, + top_k: int = 0, + top_p: float = 0.0, + timesteps: int = 10, + anneal_temp: bool = True, + span_scoring: str = 'max', + span_arrangement: str = 'nonoverlap', + curr_step: int = 0, + total_steps: int = 0, + callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> tp.Tuple[torch.Tensor, int]: + """Generate audio tokens of a single RVQ level (stage), given the previously generated stages, + and the textual conditions. + Args: + gen_sequence (torch.Tensor): Previously generated tokens. + condition_tensors (tp.Optional[ConditionTensors]): pre-computed conditioning tensors. + stage (int): RVQ level to generate. + device (torch.device): device of the output tensor. + prompt_length (int): Temporal length of the audio prompt. + prompt (torch.Tensor): Prompt tokens of shape [B, K, T]. + use_sampling (bool): Whether to use a sampling strategy or not. + temp (float): Initial sampling temperature. + max_clsfg_coef (float): Initial coefficient used for classifier free guidance. + min_clsfg_coef (float): Final coefficient used for classifier free guidance. + top_k (int): k for "top-k" sampling. + top_p (float): p for "top-p" sampling. + timesteps (int): Number of iterative decoding steps. + anneal_temp (bool): When set to True, softmax temperature will be linearly decayed to zero, at each stage. + span_scoring (str): Use the maximum probability of each span ('max') + or the product of probabilities ('prod'). + span_arrangement (str): Use either non-overlapping spans ('nonoverlap') or overlapping spans ('stride1'). + in the masking scheme. + curr_step (int): Global iterative decoding step counter. 
+ total_steps (int): Total decoding steps. + callback (Callback): Callback function to report generation progress. + Returns: + tuple(torch.Tensor, int): Generated tokens and the current decoding step counter. + """ + B, K, T = gen_sequence.shape + shape = (B, 1, T) # generating a single codebook per stage + + mask_id = self.special_token_id + stage_gen_seq = torch.full(shape, mask_id, dtype=torch.long, device=device) + + assert span_arrangement == 'nonoverlap' or span_arrangement == 'stride1' + chunk_masking = self.span_len > 1 and span_arrangement == 'nonoverlap' + + DONT_REMASK_ME_SCORE = -1e4 + + model = self if self._fsdp is None else self._fsdp + + if chunk_masking: + # span-wise scores + n_chunks = T // self.span_len + if T % self.span_len != 0: + # trim sequence ending to achieve a multiple of span_len + T = self.span_len * n_chunks + gen_sequence = gen_sequence[..., :T] + stage_gen_seq = stage_gen_seq[..., :T] + + chunked_shape = (B, 1, n_chunks) + n_prompt_chunks = prompt_length // self.span_len + scores = torch.zeros(chunked_shape, dtype=torch.float32, device=device) + scores[..., :n_prompt_chunks] = DONT_REMASK_ME_SCORE + num_chunks_to_gen = n_chunks - n_prompt_chunks + else: + # token-wise scores + scores = torch.zeros(shape, dtype=torch.float32, device=device) + scores[..., :prompt_length] = DONT_REMASK_ME_SCORE + gen_T = T - prompt_length + + # run MAGNeT iterative decoding for "timesteps" iterations + for timestep, steps_left in zip(torch.linspace(0, 1, timesteps, device=device), reversed(range(timesteps))): + + mask_p = torch.cos(timestep * math.pi * 0.5) + + if chunk_masking: + num_masked = max(int((mask_p * num_chunks_to_gen).item()), 1) + else: + num_masked = max(int((mask_p * gen_T).item()), 1) + + # masking + run_lps_masking = (span_arrangement == 'stride1') and self.span_len > 1 + if run_lps_masking: + # masking of the k least probable overlapping (stride 1) spans + mask = torch.concat(( + [self._least_probable_span_masking(scores[[i], :, :], num_masked).to(device) + for i in range(B)]), dim=0) + stage_gen_seq[mask] = mask_id + else: + # masking of the k least probable non-overlapping spans + masked = scores.topk(num_masked, dim=-1).indices + if chunk_masking: + chunks_mask = torch.full(chunked_shape, False, dtype=torch.bool, device=device) + chunks_mask = chunks_mask.scatter(2, masked, True) + mask = torch.repeat_interleave(chunks_mask, self.span_len, dim=-1) + stage_gen_seq[mask] = mask_id + else: + stage_gen_seq = stage_gen_seq.scatter(2, masked, mask_id) + + if prompt is not None: + stage_gen_seq[..., :prompt_length] = prompt[:, stage, :].unsqueeze(1) + + gen_sequence[:, [stage], :] = stage_gen_seq + if condition_tensors: + # duplicate input for classifier free guidance + sequence = torch.cat([gen_sequence, gen_sequence], dim=0) + + all_logits = model(sequence, [], condition_tensors, stage=stage) + + if condition_tensors: + # classifier free guidance with annealing + cond_logits, uncond_logits = all_logits.split(B, dim=0) # [B, K, T, card] + clsfg_coef = float(mask_p) * max_cfg_coef + (1 - float(mask_p)) * min_cfg_coef + logits = uncond_logits + (cond_logits - uncond_logits) * clsfg_coef + else: + logits = all_logits + + # temperature annealing - linear + t = temp * (steps_left / timesteps) if anneal_temp else temp + + # sampling + logits = logits[:, stage, :, :].unsqueeze(1) + probs = torch.softmax(logits / max(t, 1e-2), dim=-1) + if use_sampling: + if top_p > 0.0: + sampled_tokens = utils.sample_top_p(probs, p=top_p) + elif top_k > 0: + sampled_tokens = 
utils.sample_top_k(probs, k=top_k) + else: + sampled_tokens = utils.multinomial(probs, num_samples=1) + else: + sampled_tokens = torch.argmax(logits, dim=-1, keepdim=True) + + # place mask_id token in each of the masked positions + mask = stage_gen_seq == mask_id + stage_gen_seq = torch.where(mask, sampled_tokens[..., 0], stage_gen_seq) + gen_sequence[:, [stage], :] = stage_gen_seq + + # get probs of sampled tokens + sampled_probs = torch.gather(probs, 3, sampled_tokens)[..., 0] + + # span scoring + if chunk_masking: + if span_scoring == 'max': + # max in linear space + scores = 1 - torch.max(sampled_probs.reshape((B, 1, n_chunks, -1)), dim=-1)[0] + elif span_scoring == 'prod': + # prod in log space + scores = torch.sum(-torch.log(sampled_probs).reshape((B, 1, n_chunks, -1)), dim=-1) + else: + raise NotImplementedError + else: + # prod in log space for lps masking (stride1) + scores = -torch.log(sampled_probs) + + # Fix unmasked tokens by placing inf probs (-inf scores) + if chunk_masking: + scores = scores.masked_fill(~chunks_mask, DONT_REMASK_ME_SCORE) + else: + scores = scores.masked_fill(~mask, DONT_REMASK_ME_SCORE) + + if callback is not None: + curr_step += 1 + callback(curr_step, total_steps) + + return gen_sequence, curr_step + + def _construct_spans_mask(self, span_starts: torch.Tensor, T: int, device: torch.device) -> torch.Tensor: + """Build a [1x1xT] boolean mask consists of overlapping spans of True values, where + span_starts defines the initial index of each span, and the span length is + defined by self.span_len. + Args: + span_starts (torch.Tensor): Boolean mask determines the temporal location of each span start. + T (int): Sequence length. + device (torch.device): device of the output tensor. + Returns: + torch.Tensor: Spans mask of shape [1x1xT] + """ + mask = torch.full((1, 1, T), False, device=device) + mask[:, :, span_starts] = True + shifted_mask = mask.clone() + for _ in range(self.span_len - 1): + shifted_mask = torch.concat((torch.full((1, 1, 1), False, device=device), shifted_mask[:, :, :-1]), dim=-1) + mask = torch.logical_or(mask, shifted_mask) + return mask + + def _least_probable_span_masking(self, scores: torch.Tensor, num_masked_trg: int) -> torch.Tensor: + """Construct a [1x1xT] boolean mask, consists of the u least probable spans, + where the token probability is determined by -scores, and the total + number of masked tokens is as closest as possible to num_masked_trg. + Find u using binary search. + Args: + scores (torch.Tensor): Per token score [-log(prob)] + num_masked_trg: int: The desired amount of tokens to be masked. + Returns: + torch.Tensor: Spans mask of shape [1x1xT] + """ + T = scores.shape[-1] + device = scores.device + scores_unfolded = scores.unfold(2, self.span_len, 1) + # Span score is the product of probs (sum in log space) + span_scores = scores_unfolded.sum(dim=-1) + spans_by_scores = torch.argsort(span_scores[0, 0], descending=True) + + num_masked_trg = max(num_masked_trg, self.span_len) + + # Binary search for u - the number least probable overlapping masked spans s.t. + # the total masking rate is the closest to num_masked_trg / T. 
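# The search bounds follow from the span geometry: u fully disjoint spans cover u * span_len tokens,
# so at least num_masked_trg // span_len spans are needed (min_u), while u maximally overlapping
# spans (consecutive starts) cover only u + span_len - 1 tokens, so num_masked_trg - span_len + 1
# spans are always enough (max_u).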
+ min_u = num_masked_trg // self.span_len + max_u = num_masked_trg - self.span_len + 1 + mid = round(0.5 * (min_u + max_u)) + + if mid == min_u or mid == max_u: + return self._construct_spans_mask(spans_by_scores[:mid], T, device) + + while mid > min_u and mid < max_u: + mask = self._construct_spans_mask(spans_by_scores[:mid], T, device) + n_masked = mask.sum() + if n_masked > num_masked_trg: + max_u = mid + mid = round(0.5 * (min_u + max_u)) + else: + min_u = mid + mid = round(0.5 * (min_u + max_u)) + + return mask diff --git a/audiocraft/models/loaders.py b/audiocraft/models/loaders.py index f02ba115..a6ec475e 100644 --- a/audiocraft/models/loaders.py +++ b/audiocraft/models/loaders.py @@ -118,6 +118,34 @@ def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_di return model +def load_lm_model_magnet(file_or_url_or_id: tp.Union[Path, str], compression_model_frame_rate: int, + device='cpu', cache_dir: tp.Optional[str] = None): + pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir) + cfg = OmegaConf.create(pkg['xp.cfg']) + cfg.device = str(device) + if cfg.device == 'cpu': + cfg.dtype = 'float32' + else: + cfg.dtype = 'float16' + _delete_param(cfg, 'conditioners.args.merge_text_conditions_p') + _delete_param(cfg, 'conditioners.args.drop_desc_p') + + cfg.transformer_lm.compression_model_framerate = compression_model_frame_rate + cfg.transformer_lm.segment_duration = cfg.dataset.segment_duration + cfg.transformer_lm.span_len = cfg.masking.span_len + + # MAGNeT models v1 support only xformers backend. + from audiocraft.modules.transformer import set_efficient_attention_backend + if cfg.transformer_lm.memory_efficient: + set_efficient_attention_backend("xformers") + + model = builders.get_lm_model(cfg) + model.load_state_dict(pkg['best_state']) + model.eval() + model.cfg = cfg + return model + + def load_mbd_ckpt(file_or_url_or_id: tp.Union[Path, str], filename: tp.Optional[str] = None, cache_dir: tp.Optional[str] = None): diff --git a/audiocraft/models/magnet.py b/audiocraft/models/magnet.py new file mode 100644 index 00000000..453269ad --- /dev/null +++ b/audiocraft/models/magnet.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Main model for using MAGNeT. This will combine all the required components +and provide easy access to the generation API. +""" +import typing as tp +import torch + +from .genmodel import BaseGenModel +from .loaders import load_compression_model, load_lm_model_magnet + + +class MAGNeT(BaseGenModel): + """MAGNeT main model with convenient generation API. + Args: + See MusicGen class. + """ + def __init__(self, **kwargs): + super().__init__(**kwargs) + # MAGNeT operates over a fixed sequence length defined in it's config. + self.duration = self.lm.cfg.dataset.segment_duration + self.set_generation_params() + + @staticmethod + def get_pretrained(name: str = 'facebook/magnet-small-10secs', device=None): + """Return pretrained model, we provide six models: + - facebook/magnet-small-10secs (300M), text to music, 10-second audio samples. + # see: https://huggingface.co/facebook/magnet-small-10secs + - facebook/magnet-medium-10secs (1.5B), text to music, 10-second audio samples. + # see: https://huggingface.co/facebook/magnet-medium-10secs + - facebook/magnet-small-30secs (300M), text to music, 30-second audio samples. 
+ # see: https://huggingface.co/facebook/magnet-small-30secs + - facebook/magnet-medium-30secs (1.5B), text to music, 30-second audio samples. + # see: https://huggingface.co/facebook/magnet-medium-30secs + - facebook/audio-magnet-small (300M), text to sound-effect (10-second samples). + # see: https://huggingface.co/facebook/audio-magnet-small + - facebook/audio-magnet-medium (1.5B), text to sound-effect (10-second samples). + # see: https://huggingface.co/facebook/audio-magnet-medium + """ + if device is None: + if torch.cuda.device_count(): + device = 'cuda' + else: + device = 'cpu' + + compression_model = load_compression_model(name, device=device) + lm = load_lm_model_magnet(name, compression_model_frame_rate=int(compression_model.frame_rate), device=device) + + if 'self_wav' in lm.condition_provider.conditioners: + lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True + + kwargs = {'name': name, 'compression_model': compression_model, 'lm': lm} + return MAGNeT(**kwargs) + + def set_generation_params(self, use_sampling: bool = True, top_k: int = 0, + top_p: float = 0.9, temperature: float = 3.0, + max_cfg_coef: float = 10.0, min_cfg_coef: float = 1.0, + decoding_steps: tp.List[int] = [20, 10, 10, 10], + span_arrangement: str = 'nonoverlap'): + """Set the generation parameters for MAGNeT. + + Args: + use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True. + top_k (int, optional): top_k used for sampling. Defaults to 0. + top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.9. + temperature (float, optional): Initial softmax temperature parameter. Defaults to 3.0. + max_cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 10.0. + min_cfg_coef (float, optional): End coefficient of classifier free guidance annealing. Defaults to 1.0. + decoding_steps (list of n_q ints, optional): The number of iterative decoding steps, + for each of the n_q RVQ codebooks. + span_arrangement (str, optional): Use either non-overlapping spans ('nonoverlap') + or overlapping spans ('stride1') in the masking scheme. + """ + self.generation_params = { + 'use_sampling': use_sampling, + 'temp': temperature, + 'top_k': top_k, + 'top_p': top_p, + 'max_cfg_coef': max_cfg_coef, + 'min_cfg_coef': min_cfg_coef, + 'decoding_steps': [int(s) for s in decoding_steps], + 'span_arrangement': span_arrangement + } diff --git a/audiocraft/models/musicgen.py b/audiocraft/models/musicgen.py index 88ee13b6..463f61c6 100644 --- a/audiocraft/models/musicgen.py +++ b/audiocraft/models/musicgen.py @@ -12,16 +12,15 @@ import typing as tp import warnings -import omegaconf import torch from .encodec import CompressionModel +from .genmodel import BaseGenModel from .lm import LMModel -from .builders import get_debug_compression_model, get_debug_lm_model, get_wrapped_compression_model +from .builders import get_debug_compression_model, get_debug_lm_model from .loaders import load_compression_model, load_lm_model from ..data.audio_utils import convert_audio from ..modules.conditioners import ConditioningAttributes, WavCondition -from ..utils.autocast import TorchAutocast MelodyList = tp.List[tp.Optional[torch.Tensor]] @@ -37,7 +36,7 @@ } -class MusicGen: +class MusicGen(BaseGenModel): """MusicGen main model with convenient generation API. 
Args: @@ -50,54 +49,8 @@ class MusicGen: """ def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel, max_duration: tp.Optional[float] = None): - self.name = name - self.compression_model = compression_model - self.lm = lm - self.cfg: tp.Optional[omegaconf.DictConfig] = None - # Just to be safe, let's put everything in eval mode. - self.compression_model.eval() - self.lm.eval() - - if hasattr(lm, 'cfg'): - cfg = lm.cfg - assert isinstance(cfg, omegaconf.DictConfig) - self.cfg = cfg - - if self.cfg is not None: - self.compression_model = get_wrapped_compression_model(self.compression_model, self.cfg) - - if max_duration is None: - if self.cfg is not None: - max_duration = lm.cfg.dataset.segment_duration # type: ignore - else: - raise ValueError("You must provide max_duration when building directly MusicGen") - assert max_duration is not None - self.max_duration: float = max_duration - self.device = next(iter(lm.parameters())).device - - self.generation_params: dict = {} - self.set_generation_params(duration=15) # 15 seconds by default - self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None - if self.device.type == 'cpu': - self.autocast = TorchAutocast(enabled=False) - else: - self.autocast = TorchAutocast( - enabled=True, device_type=self.device.type, dtype=torch.float16) - - @property - def frame_rate(self) -> float: - """Roughly the number of AR steps per seconds.""" - return self.compression_model.frame_rate - - @property - def sample_rate(self) -> int: - """Sample rate of the generated audio.""" - return self.compression_model.sample_rate - - @property - def audio_channels(self) -> int: - """Audio channels of the generated audio.""" - return self.compression_model.channels + super().__init__(name, compression_model, lm, max_duration) + self.set_generation_params(duration=15) # default duration @staticmethod def get_pretrained(name: str = 'facebook/musicgen-melody', device=None): @@ -169,41 +122,6 @@ def set_generation_params(self, use_sampling: bool = True, top_k: int = 250, 'two_step_cfg': two_step_cfg, } - def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None): - """Override the default progress callback.""" - self._progress_callback = progress_callback - - def generate_unconditional(self, num_samples: int, progress: bool = False, - return_tokens: bool = False) -> tp.Union[torch.Tensor, - tp.Tuple[torch.Tensor, torch.Tensor]]: - """Generate samples in an unconditional manner. - - Args: - num_samples (int): Number of samples to be generated. - progress (bool, optional): Flag to display progress of the generation process. Defaults to False. - """ - descriptions: tp.List[tp.Optional[str]] = [None] * num_samples - attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None) - tokens = self._generate_tokens(attributes, prompt_tokens, progress) - if return_tokens: - return self.generate_audio(tokens), tokens - return self.generate_audio(tokens) - - def generate(self, descriptions: tp.List[str], progress: bool = False, return_tokens: bool = False) \ - -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]: - """Generate samples conditioned on text. - - Args: - descriptions (list of str): A list of strings used as text conditioning. - progress (bool, optional): Flag to display progress of the generation process. Defaults to False. 
- """ - attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None) - assert prompt_tokens is None - tokens = self._generate_tokens(attributes, prompt_tokens, progress) - if return_tokens: - return self.generate_audio(tokens), tokens - return self.generate_audio(tokens) - def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs: MelodyType, melody_sample_rate: int, progress: bool = False, return_tokens: bool = False) -> tp.Union[torch.Tensor, @@ -242,33 +160,6 @@ def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs: MelodyTy return self.generate_audio(tokens), tokens return self.generate_audio(tokens) - def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int, - descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None, - progress: bool = False, return_tokens: bool = False) \ - -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]: - """Generate samples conditioned on audio prompts. - - Args: - prompt (torch.Tensor): A batch of waveforms used for continuation. - Prompt should be [B, C, T], or [C, T] if only one sample is generated. - prompt_sample_rate (int): Sampling rate of the given audio waveforms. - descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None. - progress (bool, optional): Flag to display progress of the generation process. Defaults to False. - """ - if prompt.dim() == 2: - prompt = prompt[None] - if prompt.dim() != 3: - raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).") - prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels) - if descriptions is None: - descriptions = [None] * len(prompt) - attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt) - assert prompt_tokens is not None - tokens = self._generate_tokens(attributes, prompt_tokens, progress) - if return_tokens: - return self.generate_audio(tokens), tokens - return self.generate_audio(tokens) - @torch.no_grad() def _prepare_tokens_and_attributes( self, @@ -347,9 +238,9 @@ def _progress_callback(generated_tokens: int, tokens_to_generate: int): if self._progress_callback is not None: # Note that total_gen_len might be quite wrong depending on the # codebook pattern used, but with delay it is almost accurate. - self._progress_callback(generated_tokens, total_gen_len) + self._progress_callback(generated_tokens, tokens_to_generate) else: - print(f'{generated_tokens: 6d} / {total_gen_len: 6d}', end='\r') + print(f'{generated_tokens: 6d} / {tokens_to_generate: 6d}', end='\r') if prompt_tokens is not None: assert max_prompt_len >= prompt_tokens.shape[-1], \ @@ -377,6 +268,8 @@ def _progress_callback(generated_tokens: int, tokens_to_generate: int): all_tokens.append(prompt_tokens) prompt_length = prompt_tokens.shape[-1] + assert self.extend_stride is not None, "Stride should be defined to generate beyond max_duration" + assert self.extend_stride < self.max_duration, "Cannot stride by more than max generation duration." 
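+            # Generation beyond max_duration proceeds in overlapping windows: each window keeps
+            # the last (max_duration - extend_stride) seconds of the previous one as an audio
+            # prompt and advances by extend_stride seconds.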
stride_tokens = int(self.frame_rate * self.extend_stride) while current_gen_offset + prompt_length < total_gen_len: @@ -413,10 +306,3 @@ def _progress_callback(generated_tokens: int, tokens_to_generate: int): gen_tokens = torch.cat(all_tokens, dim=-1) return gen_tokens - - def generate_audio(self, gen_tokens: torch.Tensor): - """Generate Audio from tokens""" - assert gen_tokens.dim() == 3 - with torch.no_grad(): - gen_audio = self.compression_model.decode(gen_tokens, None) - return gen_audio diff --git a/audiocraft/modules/codebooks_patterns.py b/audiocraft/modules/codebooks_patterns.py index 61362588..386df582 100644 --- a/audiocraft/modules/codebooks_patterns.py +++ b/audiocraft/modules/codebooks_patterns.py @@ -30,7 +30,7 @@ class Pattern: The pattern provides convenient methods to build and revert interleaved sequences from it: ``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T] - to the interleaved sequence of shape [B, K, S] applying the pattern, with S being the batch size, + to the interleaved sequence of shape [B, K, S] applying the pattern, with B being the batch size, K being the number of codebooks, T the number of original timesteps and S the number of sequence steps for the output sequence. The unfilled positions are replaced with a special token and the built sequence is returned along with a mask indicating valid tokens. @@ -49,7 +49,6 @@ class Pattern: def __post_init__(self): assert len(self.layout) > 0 - assert self.layout[0] == [] self._validate_layout() self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes) self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes) @@ -93,6 +92,9 @@ def valid_layout(self): valid_step = len(self.layout) - self.max_delay return self.layout[:valid_step] + def starts_with_special_token(self): + return self.layout[0] == [] + def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None): """Get codebook coordinates in the layout that corresponds to the specified timestep t and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step @@ -202,7 +204,7 @@ def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}" # ensure we take the appropriate indexes to keep the model output from the first special token as well - if is_model_output: + if is_model_output and self.starts_with_special_token(): ref_layout = ref_layout[1:] # single item indexing being super slow with pytorch vs. numpy, so we use numpy here @@ -335,7 +337,8 @@ def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None, assert sorted(self.delays) == self.delays def get_pattern(self, timesteps: int) -> Pattern: - out: PatternLayout = [[]] + omit_special_token = self.empty_initial < 0 + out: PatternLayout = [] if omit_special_token else [[]] max_delay = max(self.delays) if self.empty_initial: out += [[] for _ in range(self.empty_initial)] @@ -360,9 +363,10 @@ class ParallelPatternProvider(DelayedPatternProvider): Args: n_q (int): Number of codebooks. + empty_initial (int): Prepend with N empty list of coordinates. 
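+            If negative (e.g. -1, as used by the MAGNeT configs), the initial special token
+            is omitted instead.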
""" - def __init__(self, n_q: int): - super().__init__(n_q, [0] * n_q) + def __init__(self, n_q: int, empty_initial: int = 0): + super().__init__(n_q, [0] * n_q, empty_initial=empty_initial) class UnrolledPatternProvider(CodebooksPatternProvider): diff --git a/audiocraft/modules/transformer.py b/audiocraft/modules/transformer.py index e8100a4c..818e98c0 100644 --- a/audiocraft/modules/transformer.py +++ b/audiocraft/modules/transformer.py @@ -315,7 +315,6 @@ def _apply_rope(self, query: torch.Tensor, key: torch.Tensor): def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, key_padding_mask=None, need_weights=False, attn_mask=None, average_attn_weights=True, is_causal=False): - assert attn_mask is None assert not is_causal, ("New param added in torch 2.0.1 not supported, " "use the causal args in the constructor.") @@ -329,7 +328,10 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, assert self.causal or self.cross_attention, \ "Streaming only available for causal or cross attention" + custom_attn_mask = attn_mask is not None + if self.causal: + assert attn_mask is None # At the moment we specialize only for the self-attention case. assert query.shape[1] == key.shape[1], "Causal only for same length query / key / value" assert value.shape[1] == key.shape[1], "Causal only for same length query / key / value" @@ -398,6 +400,11 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, if self.attention_as_float32: q, k, v = [x.float() for x in [q, k, v]] if self.memory_efficient: + if custom_attn_mask: + # When using a custom attn mask: move to query's device + remove align8 padding + seq_len = query.shape[1] + attn_mask = attn_mask.to(q.dtype) + attn_mask = attn_mask[:seq_len, :seq_len] p = self.dropout if self.training else 0 if _efficient_attention_backend == 'torch': x = torch.nn.functional.scaled_dot_product_attention( diff --git a/audiocraft/solvers/builders.py b/audiocraft/solvers/builders.py index 304d8f08..7c53b3ac 100644 --- a/audiocraft/solvers/builders.py +++ b/audiocraft/solvers/builders.py @@ -45,10 +45,13 @@ def get_solver(cfg: omegaconf.DictConfig) -> StandardSolver: from .compression import CompressionSolver from .musicgen import MusicGenSolver from .diffusion import DiffusionSolver + from .magnet import MagnetSolver, AudioMagnetSolver klass = { 'compression': CompressionSolver, 'musicgen': MusicGenSolver, 'audiogen': AudioGenSolver, + 'magnet': MagnetSolver, + 'audio_magnet': AudioMagnetSolver, 'lm': MusicGenSolver, # backward compatibility 'diffusion': DiffusionSolver, 'sound_lm': AudioGenSolver, # backward compatibility @@ -108,7 +111,7 @@ def get_optimizer(params: tp.Union[nn.Module, tp.Iterable[torch.Tensor]], cfg: o elif cfg.optimizer == 'dadam': optimizer = optim.DAdaptAdam(parameters, lr=cfg.lr, **cfg.adam) else: - raise ValueError(f"Unsupported LR Scheduler: {cfg.lr_scheduler}") + raise ValueError(f"Unsupported Optimizer: {cfg.optimizer}") return optimizer diff --git a/audiocraft/solvers/magnet.py b/audiocraft/solvers/magnet.py new file mode 100644 index 00000000..5c401202 --- /dev/null +++ b/audiocraft/solvers/magnet.py @@ -0,0 +1,276 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from omegaconf import DictConfig +from . 
import builders, musicgen +from einops import rearrange +from torch.nn import functional as F +from ..modules.conditioners import SegmentWithAttributes + +import torch +import numpy as np +import random +import typing as tp +import math +import flashy + + +class MagnetSolver(musicgen.MusicGenSolver): + """Solver for MAGNeT - Masked Audio Generation using + a single Non-autoregressive Transformer https://arxiv.org/abs/2401.04577. + """ + def __init__(self, cfg: DictConfig): + super().__init__(cfg) + + # initialize generation parameters by config + self.generation_params = { + 'use_sampling': self.cfg.generate.lm.use_sampling, + 'temp': self.cfg.generate.lm.temp, + 'top_k': self.cfg.generate.lm.top_k, + 'top_p': self.cfg.generate.lm.top_p, + 'max_cfg_coef': self.cfg.generate.lm.max_cfg_coef, + 'min_cfg_coef': self.cfg.generate.lm.min_cfg_coef, + 'decoding_steps': list(self.cfg.generate.lm.decoding_steps), + 'anneal_temp': self.cfg.generate.lm.anneal_temp, + 'span_scoring': self.cfg.generate.lm.span_scoring, + 'span_arrangement': self.cfg.generate.lm.span_arrangement + } + + sequence_len = int(cfg.dataset.segment_duration * self.compression_model.frame_rate) + self.mean_maskrate_to_u = torch.tensor(self._calc_mean_maskrate_to_u_LUT(sequence_len), device=self.device) + self.ce_per_codebook = [torch.log(torch.tensor(self.compression_model.cardinality, device=self.device)) + for _ in range(cfg.transformer_lm.n_q)] + + def build_model(self) -> None: + self.cfg.transformer_lm.segment_duration = self.cfg.dataset.segment_duration + self.cfg.transformer_lm.span_len = self.cfg.masking.span_len + assert self.cfg.efficient_attention_backend == "xformers", "MAGNeT v1 models support only xformers backend." + super().build_model() + + def _calc_mean_maskrate_to_u_LUT(self, T: int): + """ Create a Look Up Table (LUT) transforming a discrete masking percentage m in 0,1,...,100 to u, + the number of overlapping spans of length L to place s.t. the masking rate is approximately m/float(100). + It first creates the inverse transformation, of the masking rate as function of u, + using the expression choose(T - L, u) / choose(T, u), where L is the atomic span length used + during masking. See https://arxiv.org/abs/2401.04577, + appendix C, for the mean mask rate derivation. + + We leverage the fact that: + choose(T - L, u) / choose(T, u) = Prod_{j = 0}^{u - 1}((T - L - j)/(T - j)) + in the provided implementation, in order to avoid overflow. + Args: + T (float): Sequence length. + Returns: + (List) A LUT transforming m in 0,1,...,100 to u, + s.t. the masking rate of the span-L mask is approximately m/float(100). + """ + + L = self.cfg.masking.span_len + + u2mean = [0.0] # mean mask rate is 0.0 for u = 0 + v = (T - L) / float(T) + for u in range(1, T): + u2mean.append(1 - v) + v *= (T - L - u) / (T - u) # Overflow-safe implementation of choose(T - L, u) / choose(T, u). + + mean2u = [] + for maskperc in range(101): + maskrate = maskperc / float(100) + u = int(np.searchsorted(u2mean, maskrate)) + mean2u.append(u) + + return mean2u + + def _non_spans_mask(self, mask_probs: torch.Tensor, B: int, T: int, device: torch.device) -> torch.Tensor: + """ Construct a boolean mask of shape [B, T, 1], with masking rates defined by mask_probs. + The masked tokens are singletons, placed uniformly at random. + Args: + mask_probs (torch.Tensor): The desired masking rate per sample, of shape [B,] + B (int): Batch size. + T (int): Sequence length. 
+ device (torch.device): device of the output tensor + Returns: + (torch.Tensor): A mask of shape [B, T] + """ + num_token_masked = (T * mask_probs).round().clamp(min=1) + batch_randperm = torch.rand((B, T), device=device).argsort(dim=-1) + return batch_randperm < rearrange(num_token_masked, 'b -> b 1') + + def _spans_mask(self, mask_probs: torch.Tensor, B: int, T: int, device: torch.device) -> torch.Tensor: + """ Construct a spans mask with masking rates defined by mask_probs, + where the atomic span length ( > 1 ) is defined by cfg.masking.span_len. + Args: + mask_probs (torch.Tensor): The desired masking rate per sample, of shape [B,] + B (int): Batch size. + T (int): Sequence length. + device (torch.device): device of the output tensor + Returns: + (torch.Tensor): A spans mask of shape [B, T] + """ + rounded_probs = torch.round(100 * mask_probs).long() + k = self.mean_maskrate_to_u[rounded_probs].clamp(min=1) # k is the number of span starts + + # sample random span starts + batch_randperm = torch.rand((B, T), device=device).argsort(dim=-1) + mask = batch_randperm < rearrange(k, 'b -> b 1') + B, T = mask.shape + shifted_mask = mask.clone() + for _ in range(self.cfg.masking.span_len - 1): + shifted_mask = torch.concat((torch.full((B, 1), False, device=device), shifted_mask[:, :-1]), dim=1) + mask = torch.logical_or(mask, shifted_mask) + + return mask + + def _get_mask(self, mask_probs: torch.Tensor, B: int, T: int, device: torch.device) -> torch.Tensor: + """ Construct a boolean mask with masking rates defined by mask_probs, and atomic + span length defined by cfg.masking.span_len. + Args: + mask_probs (torch.Tensor): The desired masking rate per sample, of shape [B,] + B (int): Batch size. + T (int): Sequence length. + device (torch.device): device of the output tensor + Returns: + (torch.Tensor): A boolean tensor of shape [B, T] + """ + if self.cfg.masking.span_len <= 1: + return self._non_spans_mask(mask_probs, B, T, device) + + return self._spans_mask(mask_probs, B, T, device) + + def _compute_cross_entropy_magnet(self, logits: torch.Tensor, + targets: torch.Tensor, mask: torch.Tensor, stage: torch.Tensor) -> torch.Tensor: + """ Compute cross entropy between multi-codebook targets and model's logits. + The cross entropy is computed only on a specific codebook, defined by the stage argument. + Valid timesteps for each codebook are pulled from the mask, where invalid + timesteps are set to 0. + + Args: + logits (torch.Tensor): Model's logits of shape [B, K, T, card]. + targets (torch.Tensor): Target codes, of shape [B, K, T]. + mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T]. + stage (torch.Tensor): The codebook (idx) that is being optimized, as a scalar tensor. + Returns: + ce (torch.Tensor): Cross entropy of the codebook that is being optimized. 
+ """ + assert logits.shape[:-1] == targets.shape + assert mask.shape == targets.shape + ce = torch.zeros([], device=targets.device) + logits_k = logits[:, stage, ...].contiguous().view(-1, logits.size(-1)) # [B x T, card] + targets_k = targets[:, stage, ...].contiguous().view(-1) # [B x T] + mask_k = mask[:, stage, ...].contiguous().view(-1) # [B x T] + + IGNORE_IDX = -1 + targets_k[~mask_k] = IGNORE_IDX + q_ce = F.cross_entropy(logits_k, targets_k, ignore_index=IGNORE_IDX) + + ce += q_ce + return ce + + def run_step(self, idx: int, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]], metrics: dict) -> dict: + """Perform one training or valid step on a given batch.""" + check_synchronization_points = idx == 1 and self.device == 'cuda' + + condition_tensors, audio_tokens, padding_mask = self._prepare_tokens_and_attributes( + batch, check_synchronization_points) + + self.deadlock_detect.update('tokens_and_conditions') + + if check_synchronization_points: + torch.cuda.set_sync_debug_mode('warn') + + B, K, T = audio_tokens.shape + device = self.device + + # Choose the stage (codebook idx) for update, uniformly at random. + stage_ = random.randint(0, K - 1) + stage = torch.full((1, ), stage_, device=device) + + # masking + rand_time = torch.zeros((B,), device=device).float().uniform_(0, 1) + rand_mask_probs = torch.cos(rand_time * math.pi * 0.5) + + # stage mask + stage_mask = self._get_mask(rand_mask_probs, B, T, device) # [B, T] + stage_mask = stage_mask.unsqueeze(1) # [B, 1, T] + + # Keep all preceding codebooks. + mask = torch.full((B, K, T), False, device=device) + mask[:, stage, :] = stage_mask + + # Mask all codebooks larger than stage_ + mask_id = self.model.special_token_id + mask[:, (stage_+1):, :] = torch.full((B, K - stage_ - 1, T), True, device=device) + input_tokens = torch.where(mask, mask_id, audio_tokens) + + # Take loss only on the chosen stage, and only on the masked tokens. + loss_mask = torch.full((B, K, T), False, device=device) + loss_mask[:, stage, :] = stage_mask + + with self.autocast: + model_output = self.model.compute_predictions(input_tokens, [], condition_tensors, stage=stage_) + logits = model_output.logits + loss_mask &= padding_mask + ce = self._compute_cross_entropy_magnet(logits, audio_tokens, loss_mask, stage) + loss = ce + self.deadlock_detect.update('loss') + + if check_synchronization_points: + torch.cuda.set_sync_debug_mode('default') + + if self.is_training: + metrics['lr'] = self.optimizer.param_groups[0]['lr'] + if self.scaler is not None: + loss = self.scaler.scale(loss) + self.deadlock_detect.update('scale') + if self.cfg.fsdp.use: + loss.backward() + flashy.distrib.average_tensors(self.model.buffers()) + elif self.cfg.optim.eager_sync: + with flashy.distrib.eager_sync_model(self.model): + loss.backward() + else: + # this should always be slower but can be useful + # for weird use cases like multiple backwards. 
+ loss.backward() + flashy.distrib.sync_model(self.model) + self.deadlock_detect.update('backward') + + if self.scaler is not None: + self.scaler.unscale_(self.optimizer) + if self.cfg.optim.max_norm: + if self.cfg.fsdp.use: + metrics['grad_norm'] = self.model.clip_grad_norm_(self.cfg.optim.max_norm) # type: ignore + else: + metrics['grad_norm'] = torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.cfg.optim.max_norm + ) + if self.scaler is None: + self.optimizer.step() + else: + self.scaler.step(self.optimizer) + self.scaler.update() + if self.lr_scheduler: + self.lr_scheduler.step() + self.optimizer.zero_grad() + self.deadlock_detect.update('optim') + if self.scaler is not None: + scale = self.scaler.get_scale() + metrics['grad_scale'] = scale + if not loss.isfinite().all(): + raise RuntimeError("Model probably diverged.") + + metrics['ce'] = ce + metrics['ppl'] = torch.exp(ce) + + return metrics + + +class AudioMagnetSolver(MagnetSolver): + """Solver for audio-MAGNeT. A MAGNeT model for sound generation. + + More information can be found in the MAGNeT model card. + """ + DATASET_TYPE: builders.DatasetType = builders.DatasetType.SOUND diff --git a/config/solver/magnet/audio_magnet_16khz.yaml b/config/solver/magnet/audio_magnet_16khz.yaml new file mode 100644 index 00000000..79326db3 --- /dev/null +++ b/config/solver/magnet/audio_magnet_16khz.yaml @@ -0,0 +1,104 @@ +# @package __global__ + +# This is the training loop solver +# for the base audio-MAGNeT model (text-to-sound) +# on monophonic audio sampled at 16 kHz +# using a similar EnCodec+LM setup to MAGNeT +defaults: + - audiogen/default + - /model: lm/audiogen_lm + - override /dset: audio/default + - _self_ + +lm_model: transformer_lm_magnet +solver: audio_magnet + +autocast: true +autocast_dtype: float16 + +# EnCodec large trained on mono-channel music audio sampled at 16khz +# with a total stride of 320 leading to 50 frames/s. 
+# rvq.n_q=4, rvq.bins=2048, no quantization dropout +# (transformer_lm card and n_q must be compatible) +compression_model_checkpoint: //reference/bd44a852/checkpoint.th + +channels: 1 +sample_rate: 16000 + +deadlock: + use: true # deadlock detection + +dataset: + batch_size: 128 # matching AudioGen paper setup (256 * mix_p=0.5 = 128) + num_workers: 10 + segment_duration: 10 + min_segment_ratio: 1.0 + sample_on_weight: false # Uniform sampling all the way + sample_on_duration: false # Uniform sampling all the way + external_metadata_source: null + # sample mixing augmentation at train time + train: + batch_size: 256 # matching AudioGen paper setup + aug_p: 0.5 # perform audio mixing 50% of the time + mix_p: 0.5 # proportion of batch items mixed together + # important: note that this will reduce the + # actual batch size used at train time + # which will be equal to mix_p * batch_size + mix_snr_low: -5 + mix_snr_high: 5 + mix_min_overlap: 0.5 + +optim: + epochs: 100 + optimizer: adamw + lr: 5e-4 + ema: + use: true + updates: 10 + device: cuda + +logging: + log_tensorboard: true + +schedule: + lr_scheduler: inverse_sqrt + inverse_sqrt: + warmup: 3000 + warmup_init_lr: 0.0 + +codebooks_pattern: + modeling: parallel + parallel: + empty_initial: -1 + +transformer_lm: + card: 2048 + causal: false + subcodes_context: 5 + compression_model_framerate: 50 # NOTE: Must match the actual frame rate of the used compression model + segment_duration: 0 + span_len: -1 + +masking: + span_len: 3 + +generate: + lm: + max_prompt_len: null + max_gen_len: null + remove_prompts: false + use_sampling: true + temp: 3.5 + top_k: 0 + top_p: 0.8 + max_cfg_coef: 20.0 + min_cfg_coef: 1.0 + decoding_steps: [20, 10, 10, 10] + anneal_temp: true + span_scoring: 'max' + span_arrangement: 'nonoverlap' + prompted_samples: false + samples: + prompted: false + unprompted: true + diff --git a/config/solver/magnet/magnet_32khz.yaml b/config/solver/magnet/magnet_32khz.yaml new file mode 100644 index 00000000..8d53b566 --- /dev/null +++ b/config/solver/magnet/magnet_32khz.yaml @@ -0,0 +1,90 @@ +# @package __global__ + +# This is the training loop solver +# for the base MusicGen model (text-to-music) +# on monophonic audio sampled at 32 kHz +defaults: + - musicgen/default + - /model: lm/musicgen_lm + - override /dset: audio/default + - _self_ + +lm_model: transformer_lm_magnet +solver: magnet + +autocast: true +autocast_dtype: float16 + +# EnCodec large trained on mono-channel music audio sampled at 32khz +# with a total stride of 640 leading to 50 frames/s. 
+# rvq.n_q=4, rvq.bins=2048, no quantization dropout +# (transformer_lm card and n_q must be compatible) +compression_model_checkpoint: //pretrained/facebook/encodec_32khz + +efficient_attention_backend: xformers # restricted attention implementation supports only xformers at the moment + +channels: 1 +sample_rate: 32000 + +deadlock: + use: true # deadlock detection + +dataset: + batch_size: 192 # 32 GPUs + sample_on_weight: false # Uniform sampling all the way + sample_on_duration: false # Uniform sampling all the way + +optim: + epochs: 500 + optimizer: dadam + lr: 1 + ema: + use: true + updates: 10 + device: cuda + +logging: + log_tensorboard: true + +schedule: + lr_scheduler: cosine + cosine: + warmup: 4000 + lr_min_ratio: 0.0 + cycle_length: 1.0 + +codebooks_pattern: + modeling: parallel + parallel: + empty_initial: -1 + +transformer_lm: + card: 2048 + causal: false + subcodes_context: 5 + compression_model_framerate: 50 # NOTE: Must match the actual frame rate of the used compression model + segment_duration: 0 + span_len: -1 + +masking: + span_len: 3 + +generate: + lm: + max_prompt_len: null + max_gen_len: null + remove_prompts: false + use_sampling: true + temp: 3.0 + top_k: 0 + top_p: 0.9 + max_cfg_coef: 10.0 + min_cfg_coef: 1.0 + decoding_steps: [60, 10, 10, 10] + anneal_temp: true + span_scoring: 'max' + span_arrangement: 'nonoverlap' + prompted_samples: false + samples: + prompted: false + unprompted: true diff --git a/demos/magnet_app.py b/demos/magnet_app.py new file mode 100644 index 00000000..a5713c56 --- /dev/null +++ b/demos/magnet_app.py @@ -0,0 +1,351 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under thmage license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +from concurrent.futures import ProcessPoolExecutor +import logging +import os +from pathlib import Path +import subprocess as sp +import sys +from tempfile import NamedTemporaryFile +import time +import typing as tp +import warnings + +import gradio as gr + +from audiocraft.data.audio import audio_write +from audiocraft.models import MAGNeT + + +MODEL = None # Last used model +SPACE_ID = os.environ.get('SPACE_ID', '') +MAX_BATCH_SIZE = 12 +N_REPEATS = 2 +INTERRUPTING = False +MBD = None +# We have to wrap subprocess call to clean a bit the log when using gr.make_waveform +_old_call = sp.call + +PROD_STRIDE_1 = "prod-stride1 (new!)" + + +def _call_nostderr(*args, **kwargs): + # Avoid ffmpeg vomiting on the logs. + kwargs['stderr'] = sp.DEVNULL + kwargs['stdout'] = sp.DEVNULL + _old_call(*args, **kwargs) + + +sp.call = _call_nostderr +# Preallocating the pool of processes. +pool = ProcessPoolExecutor(4) +pool.__enter__() + + +def interrupt(): + global INTERRUPTING + INTERRUPTING = True + + +class FileCleaner: + def __init__(self, file_lifetime: float = 3600): + self.file_lifetime = file_lifetime + self.files = [] + + def add(self, path: tp.Union[str, Path]): + self._cleanup() + self.files.append((time.time(), Path(path))) + + def _cleanup(self): + now = time.time() + for time_added, path in list(self.files): + if now - time_added > self.file_lifetime: + if path.exists(): + path.unlink() + self.files.pop(0) + else: + break + + +file_cleaner = FileCleaner() + + +def make_waveform(*args, **kwargs): + # Further remove some warnings. 
+ be = time.time() + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + out = gr.make_waveform(*args, **kwargs) + print("Make a video took", time.time() - be) + return out + + +def load_model(version='facebook/magnet-small-10secs'): + global MODEL + print("Loading model", version) + if MODEL is None or MODEL.name != version: + MODEL = None # in case loading would crash + MODEL = MAGNeT.get_pretrained(version) + + +def _do_predictions(texts, progress=False, gradio_progress=None, **gen_kwargs): + MODEL.set_generation_params(**gen_kwargs) + print("new batch", len(texts), texts) + be = time.time() + + try: + outputs = MODEL.generate(texts, progress=progress, return_tokens=False) + except RuntimeError as e: + raise gr.Error("Error while generating " + e.args[0]) + outputs = outputs.detach().cpu().float() + pending_videos = [] + out_wavs = [] + for i, output in enumerate(outputs): + with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file: + audio_write( + file.name, output, MODEL.sample_rate, strategy="loudness", + loudness_headroom_db=16, loudness_compressor=True, add_suffix=False) + if i == 0: + pending_videos.append(pool.submit(make_waveform, file.name)) + out_wavs.append(file.name) + file_cleaner.add(file.name) + out_videos = [pending_video.result() for pending_video in pending_videos] + for video in out_videos: + file_cleaner.add(video) + print("batch finished", len(texts), time.time() - be) + print("Tempfiles currently stored: ", len(file_cleaner.files)) + return out_videos, out_wavs + + +def predict_batched(texts, melodies): + max_text_length = 512 + texts = [text[:max_text_length] for text in texts] + load_model('facebook/magnet-small-10secs') + res = _do_predictions(texts, melodies) + return res + + +def predict_full(model, model_path, text, temperature, topp, + max_cfg_coef, min_cfg_coef, + decoding_steps1, decoding_steps2, decoding_steps3, decoding_steps4, + span_score, + progress=gr.Progress()): + global INTERRUPTING + INTERRUPTING = False + progress(0, desc="Loading model...") + model_path = model_path.strip() + if model_path: + if not Path(model_path).exists(): + raise gr.Error(f"Model path {model_path} doesn't exist.") + if not Path(model_path).is_dir(): + raise gr.Error(f"Model path {model_path} must be a folder containing " + "state_dict.bin and compression_state_dict_.bin.") + model = model_path + if temperature < 0: + raise gr.Error("Temperature must be >= 0.") + + load_model(model) + + max_generated = 0 + + def _progress(generated, to_generate): + nonlocal max_generated + max_generated = max(generated, max_generated) + progress((min(max_generated, to_generate), to_generate)) + if INTERRUPTING: + raise gr.Error("Interrupted.") + MODEL.set_custom_progress_callback(_progress) + + videos, wavs = _do_predictions( + [text] * N_REPEATS, progress=True, + temperature=temperature, top_p=topp, + max_cfg_coef=max_cfg_coef, min_cfg_coef=min_cfg_coef, + decoding_steps=[decoding_steps1, decoding_steps2, decoding_steps3, decoding_steps4], + span_arrangement='stride1' if (span_score == PROD_STRIDE_1) else 'nonoverlap', + gradio_progress=progress) + + outputs_ = [videos[0]] + [wav for wav in wavs] + return tuple(outputs_) + +def ui_full(launch_kwargs): + with gr.Blocks() as interface: + gr.Markdown( + """ + # MAGNeT + This is your private demo for [MAGNeT](https://github.com/facebookresearch/audiocraft), + A fast text-to-music model, consists of a single, non-autoregressive transformer. 
+ presented at: ["Masked Audio Generation using a Single Non-Autoregressive Transformer"] (https://huggingface.co/papers/2401.04577) + """ + ) + with gr.Row(): + with gr.Column(): + with gr.Row(): + text = gr.Text(label="Input Text", value="80s electronic track with melodic synthesizers, catchy beat and groovy bass", interactive=True) + with gr.Row(): + submit = gr.Button("Submit") + # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license. + _ = gr.Button("Interrupt").click(fn=interrupt, queue=False) + with gr.Row(): + model = gr.Radio(['facebook/magnet-small-10secs', 'facebook/magnet-medium-10secs', + 'facebook/magnet-small-30secs', 'facebook/magnet-medium-30secs', + 'facebook/audio-magnet-small', 'facebook/audio-magnet-medium'], + label="Model", value='facebook/magnet-small-10secs', interactive=True) + model_path = gr.Text(label="Model Path (custom models)") + with gr.Row(): + span_score = gr.Radio(["max-nonoverlap", PROD_STRIDE_1], + label="Span Scoring", value=PROD_STRIDE_1, interactive=True) + with gr.Row(): + decoding_steps1 = gr.Number(label="Decoding Steps (stage 1)", value=20, interactive=True) + decoding_steps2 = gr.Number(label="Decoding Steps (stage 2)", value=10, interactive=True) + decoding_steps3 = gr.Number(label="Decoding Steps (stage 3)", value=10, interactive=True) + decoding_steps4 = gr.Number(label="Decoding Steps (stage 4)", value=10, interactive=True) + with gr.Row(): + temperature = gr.Number(label="Temperature", value=3.0, step=0.25, minimum=0, interactive=True) + topp = gr.Number(label="Top-p", value=0.9, step=0.1, minimum=0, maximum=1, interactive=True) + max_cfg_coef = gr.Number(label="Max CFG coefficient", value=10.0, minimum=0, interactive=True) + min_cfg_coef = gr.Number(label="Min CFG coefficient", value=1.0, minimum=0, interactive=True) + with gr.Column(): + output = gr.Video(label="Generated Audio - variation 1") + audio_outputs = [gr.Audio(label=f"Generated Audio - variation {i+1}", type='filepath') for i in range(N_REPEATS)] + submit.click(fn=predict_full, + inputs=[model, model_path, text, + temperature, topp, + max_cfg_coef, min_cfg_coef, + decoding_steps1, decoding_steps2, decoding_steps3, decoding_steps4, + span_score], + outputs=[output] + [o for o in audio_outputs]) + gr.Examples( + fn=predict_full, + examples=[ + [ + "80s electronic track with melodic synthesizers, catchy beat and groovy bass", + 'facebook/magnet-small-10secs', + 20, 3.0, 0.9, 10.0, + ], + [ + "80s electronic track with melodic synthesizers, catchy beat and groovy bass. 
170 bpm", + 'facebook/magnet-small-10secs', + 20, 3.0, 0.9, 10.0, + ], + [ + "Earthy tones, environmentally conscious, ukulele-infused, harmonic, breezy, easygoing, organic instrumentation, gentle grooves", + 'facebook/magnet-medium-10secs', + 20, 3.0, 0.9, 10.0, + ], + [ "Funky groove with electric piano playing blue chords rhythmically", + 'facebook/magnet-medium-10secs', + 20, 3.0, 0.9, 10.0, + ], + [ + "Rock with saturated guitars, a heavy bass line and crazy drum break and fills.", + 'facebook/magnet-small-30secs', + 60, 3.0, 0.9, 10.0, + ], + [ "A grand orchestral arrangement with thunderous percussion, epic brass fanfares, and soaring strings, creating a cinematic atmosphere fit for a heroic battle", + 'facebook/magnet-medium-30secs', + 60, 3.0, 0.9, 10.0, + ], + [ "Seagulls squawking as ocean waves crash while wind blows heavily into a microphone.", + 'facebook/audio-magnet-small', + 20, 3.5, 0.8, 20.0, + ], + [ "A toilet flushing as music is playing and a man is singing in the distance.", + 'facebook/audio-magnet-medium', + 20, 3.5, 0.8, 20.0, + ], + ], + + inputs=[text, model, decoding_steps1, temperature, topp, max_cfg_coef], + outputs=[output] + ) + + gr.Markdown( + """ + ### More details + + #### Music Generation + "magnet" models will generate a short music extract based on the textual description you provided. + These models can generate either 10 seconds or 30 seconds of music. + These models were trained with descriptions from a stock music catalog. Descriptions that will work best + should include some level of details on the instruments present, along with some intended use case + (e.g. adding "perfect for a commercial" can somehow help). + + We present 4 model variants: + 1. facebook/magnet-small-10secs - a 300M non-autoregressive transformer capable of generating 10-second music conditioned + on text. + 2. facebook/magnet-medium-10secs - 1.5B parameters, 10 seconds audio. + 3. facebook/magnet-small-30secs - 300M parameters, 30 seconds audio. + 4. facebook/magnet-medium-30secs - 1.5B parameters, 30 seconds audio. + + #### Sound-Effect Generation + "audio-magnet" models will generate a 10-second sound effect based on the description you provide. + + These models were trained on the following data sources: a subset of AudioSet (Gemmeke et al., 2017), + [BBC sound effects](https://sound-effects.bbcrewind.co.uk/), AudioCaps (Kim et al., 2019), + Clotho v2 (Drossos et al., 2020), VGG-Sound (Chen et al., 2020), FSD50K (Fonseca et al., 2021), + [Free To Use Sounds](https://www.freetousesounds.com/all-in-one-bundle/), [Sonniss Game Effects](https://sonniss.com/gameaudiogdc), + [WeSoundEffects](https://wesoundeffects.com/we-sound-effects-bundle-2020/), + [Paramount Motion - Odeon Cinematic Sound Effects](https://www.paramountmotion.com/odeon-sound-effects). + + We present 2 model variants: + 1. facebook/audio-magnet-small - 10 second sound effect generation, 300M parameters. + 2. facebook/audio-magnet-medium - 10 second sound effect generation, 1.5B parameters. + + See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft/blob/main/docs/MAGNET.md) + for more details. 
+ """ + ) + + interface.queue().launch(**launch_kwargs) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + '--listen', + type=str, + default='0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1', + help='IP to listen on for connections to Gradio', + ) + parser.add_argument( + '--username', type=str, default='', help='Username for authentication' + ) + parser.add_argument( + '--password', type=str, default='', help='Password for authentication' + ) + parser.add_argument( + '--server_port', + type=int, + default=0, + help='Port to run the server listener on', + ) + parser.add_argument( + '--inbrowser', action='store_true', help='Open in browser' + ) + parser.add_argument( + '--share', action='store_true', help='Share the gradio UI' + ) + + args = parser.parse_args() + + launch_kwargs = {} + launch_kwargs['server_name'] = args.listen + + if args.username and args.password: + launch_kwargs['auth'] = (args.username, args.password) + if args.server_port: + launch_kwargs['server_port'] = args.server_port + if args.inbrowser: + launch_kwargs['inbrowser'] = args.inbrowser + if args.share: + launch_kwargs['share'] = args.share + + logging.basicConfig(level=logging.INFO, stream=sys.stderr) + + # Show the interface + ui_full(launch_kwargs) diff --git a/demos/magnet_demo.ipynb b/demos/magnet_demo.ipynb new file mode 100644 index 00000000..0138468a --- /dev/null +++ b/demos/magnet_demo.ipynb @@ -0,0 +1,214 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# MAGNeT\n", + "Welcome to MAGNeT's demo jupyter notebook. \n", + "Here you will find a self-contained example of how to use MAGNeT for music/sound-effect generation.\n", + "\n", + "First, we start by initializing MAGNeT for music generation, you can choose a model from the following selection:\n", + "1. facebook/magnet-small-10secs - a 300M non-autoregressive transformer capable of generating 10-second music conditioned on text.\n", + "2. facebook/magnet-medium-10secs - 1.5B parameters, 10 seconds music samples.\n", + "3. facebook/magnet-small-30secs - 300M parameters, 30 seconds music samples.\n", + "4. facebook/magnet-medium-30secs - 1.5B parameters, 30 seconds music samples.\n", + "\n", + "We will use the `facebook/magnet-small-10secs` variant for the purpose of this demonstration." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from audiocraft.models import MAGNeT\n", + "\n", + "model = MAGNeT.get_pretrained('facebook/magnet-small-10secs')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, let us configure the generation parameters. Specifically, you can control the following:\n", + "* `use_sampling` (bool, optional): use sampling if True, else do argmax decoding. Defaults to True.\n", + "* `top_k` (int, optional): top_k used for sampling. Defaults to 0.\n", + "* `top_p` (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.9.\n", + "* `temperature` (float, optional): Initial softmax temperature parameter. Defaults to 3.0.\n", + "* `max_clsfg_coef` (float, optional): Initial coefficient used for classifier free guidance. Defaults to 10.0.\n", + "* `min_clsfg_coef` (float, optional): Final coefficient used for classifier free guidance. 
Defaults to 1.0.\n", + "* `decoding_steps` (list of n_q ints, optional): The number of iterative decoding steps, for each of the n_q RVQ codebooks.\n", + "* `span_arrangement` (str, optional): Use either non-overlapping spans ('nonoverlap') or overlapping spans ('stride1') \n", + " in the masking scheme. \n", + "\n", + "When left unchanged, MAGNeT will revert to its default parameters." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model.set_generation_params(\n", + " use_sampling=True,\n", + " top_k=0,\n", + " top_p=0.9,\n", + " temperature=3.0,\n", + " max_cfg_coef=10.0,\n", + " min_cfg_coef=1.0,\n", + " decoding_steps=[int(20 * model.lm.cfg.dataset.segment_duration // 10), 10, 10, 10],\n", + " span_arrangement='nonoverlap'\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we can go ahead and start generating music given textual prompts." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Text-conditional Generation - Music" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from audiocraft.utils.notebook import display_audio\n", + "\n", + "###### Text-to-music prompts - examples ######\n", + "text = \"80s electronic track with melodic synthesizers, catchy beat and groovy bass\"\n", + "# text = \"80s electronic track with melodic synthesizers, catchy beat and groovy bass. 170 bpm\"\n", + "# text = \"Earthy tones, environmentally conscious, ukulele-infused, harmonic, breezy, easygoing, organic instrumentation, gentle grooves\"\n", + "# text = \"Funky groove with electric piano playing blue chords rhythmically\"\n", + "# text = \"Rock with saturated guitars, a heavy bass line and crazy drum break and fills.\"\n", + "# text = \"A grand orchestral arrangement with thunderous percussion, epic brass fanfares, and soaring strings, creating a cinematic atmosphere fit for a heroic battle\"\n", + " \n", + "N_VARIATIONS = 3\n", + "descriptions = [text for _ in range(N_VARIATIONS)]\n", + "\n", + "print(f\"text prompt: {text}\\n\")\n", + "output = model.generate(descriptions=descriptions, progress=True, return_tokens=True)\n", + "display_audio(output[0], sample_rate=model.compression_model.sample_rate)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Text-conditional Generation - Sound Effects" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Besides music, MAGNeT models can generate sound effects given textual prompts. \n", + "First, let's load an Audio-MAGNeT model, out of the following collection: \n", + "1. facebook/audio-magnet-small - a 300M non-autoregressive transformer capable of generating 10 second sound effects conditioned on text.\n", + "2. facebook/audio-magnet-medium - 10 second sound effect generation, 1.5B parameters.\n", + "\n", + "We will use the `facebook/audio-magnet-small` variant for the purpose of this demonstration." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from audiocraft.models import MAGNeT\n", + "\n", + "model = MAGNeT.get_pretrained('facebook/audio-magnet-small')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The recommended parameters for sound generation are a bit different than the defaults in MAGNeT, let's initialize it: " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model.set_generation_params(\n", + " use_sampling=True,\n", + " top_k=0,\n", + " top_p=0.8,\n", + " temperature=3.5,\n", + " max_cfg_coef=20.0,\n", + " min_cfg_coef=1.0,\n", + " decoding_steps=[int(20 * model.lm.cfg.dataset.segment_duration // 10), 10, 10, 10],\n", + " span_arrangement='nonoverlap'\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we can go ahead and start generating sounds given textual prompts." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from audiocraft.utils.notebook import display_audio\n", + " \n", + "###### Text-to-audio prompts - examples ######\n", + "text = \"Seagulls squawking as ocean waves crash while wind blows heavily into a microphone.\"\n", + "# text = \"A toilet flushing as music is playing and a man is singing in the distance.\"\n", + "\n", + "N_VARIATIONS = 3\n", + "descriptions = [text for _ in range(N_VARIATIONS)]\n", + "\n", + "print(f\"text prompt: {text}\\n\")\n", + "output = model.generate(descriptions=descriptions, progress=True, return_tokens=True)\n", + "display_audio(output[0], sample_rate=model.compression_model.sample_rate)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + }, + "vscode": { + "interpreter": { + "hash": "b02c911f9b3627d505ea4a19966a915ef21f28afb50dbf6b2115072d27c69103" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/MAGNET.md b/docs/MAGNET.md new file mode 100644 index 00000000..5d115d75 --- /dev/null +++ b/docs/MAGNET.md @@ -0,0 +1,237 @@ +# MAGNeT: Masked Audio Generation using a Single Non-Autoregressive Transformer + +AudioCraft provides the code and models for MAGNeT, [Masked Audio Generation using a Single Non-Autoregressive Transformer][arxiv]. + +MAGNeT is a text-to-music and text-to-sound model capable of generating high-quality audio samples conditioned on text descriptions. +It is a masked generative non-autoregressive Transformer trained over a 32kHz EnCodec tokenizer with 4 codebooks sampled at 50 Hz. +Unlike prior work on masked generative audio Transformers, such as [SoundStorm](https://arxiv.org/abs/2305.09636) and [VampNet](https://arxiv.org/abs/2307.04686), +MAGNeT doesn't require semantic token conditioning, model cascading or audio prompting, and employs a full text-to-audio using a single non-autoregressive Transformer. + +Check out our [sample page][magnet_samples] or test the available demo! + +We use 16K hours of licensed music to train MAGNeT. Specifically, we rely on an internal dataset +of 10K high-quality music tracks, and on the ShutterStock and Pond5 music data. 
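+
+For a rough sense of the sequence MAGNeT operates on: with the tokenizer described above (4 codebooks at 50 Hz), a 10-second sample corresponds to a grid of 4 x 500 discrete tokens that is generated non-autoregressively. The snippet below is only an illustration of that arithmetic, not part of the AudioCraft API:
+
+```python
+# Illustrative arithmetic only: the token grid implied by a 50 Hz, 4-codebook EnCodec tokenizer.
+frame_rate = 50      # EnCodec frames per second
+n_codebooks = 4      # RVQ codebooks
+duration_sec = 10    # e.g. the magnet-*-10secs models
+
+tokens_per_codebook = frame_rate * duration_sec       # 500 timesteps per codebook
+total_tokens = n_codebooks * tokens_per_codebook      # 2000 tokens per generated sample
+print(tokens_per_codebook, total_tokens)
+```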
+
+
+## Model Card
+
+See [the model card](../model_cards/MAGNET_MODEL_CARD.md).
+
+
+## Installation
+
+Please follow the AudioCraft installation instructions from the [README](../README.md).
+
+AudioCraft requires a GPU with at least 16 GB of memory for running inference with the medium-sized models (~1.5B parameters).
+
+## Usage
+
+We currently offer two ways to interact with MAGNeT:
+1. You can use the gradio demo locally by running [`python -m demos.magnet_app --share`](../demos/magnet_app.py).
+2. You can play with MAGNeT by running the jupyter notebook at [`demos/magnet_demo.ipynb`](../demos/magnet_demo.ipynb) locally (if you have a GPU).
+
+## API
+
+We provide a simple API and 6 pre-trained models. The pre-trained models are:
+- `facebook/magnet-small-10secs`: 300M model, text to music, generates 10-second samples - [🤗 Hub](https://huggingface.co/facebook/magnet-small-10secs)
+- `facebook/magnet-medium-10secs`: 1.5B model, text to music, generates 10-second samples - [🤗 Hub](https://huggingface.co/facebook/magnet-medium-10secs)
+- `facebook/magnet-small-30secs`: 300M model, text to music, generates 30-second samples - [🤗 Hub](https://huggingface.co/facebook/magnet-small-30secs)
+- `facebook/magnet-medium-30secs`: 1.5B model, text to music, generates 30-second samples - [🤗 Hub](https://huggingface.co/facebook/magnet-medium-30secs)
+- `facebook/audio-magnet-small`: 300M model, text to sound-effect - [🤗 Hub](https://huggingface.co/facebook/audio-magnet-small)
+- `facebook/audio-magnet-medium`: 1.5B model, text to sound-effect - [🤗 Hub](https://huggingface.co/facebook/audio-magnet-medium)
+
+In order to use MAGNeT locally **you must have a GPU**. We recommend 16GB of memory, especially for
+the medium-sized models.
+
+See below for a quick example of using the API.
+
+```python
+import torchaudio
+from audiocraft.models import MAGNeT
+from audiocraft.data.audio import audio_write
+
+model = MAGNeT.get_pretrained('facebook/magnet-small-10secs')
+descriptions = ['disco beat', 'energetic EDM', 'funky groove']
+wav = model.generate(descriptions)  # generates 3 samples.
+
+for idx, one_wav in enumerate(wav):
+    # Will save under {idx}.wav, with loudness normalization at -14 db LUFS.
+    audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True)
+```
+
+## 🤗 Transformers Usage
+
+Coming soon...
+
+## Training
+
+The [MagnetSolver](../audiocraft/solvers/magnet.py) implements MAGNeT's training pipeline.
+It defines a masked generation task over multiple streams of discrete tokens
+extracted from a pre-trained EnCodec model (see [EnCodec documentation](./ENCODEC.md)
+for more details on how to train such a model).
+
+Note that **we do NOT provide any of the datasets** used for training MAGNeT.
+We provide a dummy dataset containing just a few examples for illustrative purposes.
+
+Please read first the [TRAINING documentation](./TRAINING.md), in particular the Environment Setup section.
+
+
+### Example configurations and grids
+
+We provide configurations to reproduce the released models and our research.
+MAGNeT solver configurations are available in [config/solver/magnet](../config/solver/magnet),
+in particular:
+* MAGNeT model for text-to-music:
+[`solver=magnet/magnet_32khz`](../config/solver/magnet/magnet_32khz.yaml)
+* MAGNeT model for text-to-sound:
+[`solver=magnet/audio_magnet_16khz`](../config/solver/magnet/audio_magnet_16khz.yaml)
+
+We provide 3 different scales: `model/lm/model_scale=small` (300M), `medium` (1.5B), and `large` (3.3B); a minimal example run is shown below.
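+
+This is only a sketch of a single (non-grid) run: `dset=audio/example` is a placeholder for whatever dataset definition you have set up (see the dataset section further below), and any scale from the list above can be substituted.
+
+```shell
+# Sketch: launch a single training run of the small text-to-music MAGNeT,
+# overriding the model scale and the dataset definition on the command line.
+dora run solver=magnet/magnet_32khz model/lm/model_scale=small dset=audio/example
+```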
+
+Please find some example grids to train MAGNeT at
+[audiocraft/grids/magnet](../audiocraft/grids/magnet/).
+
+```shell
+# text-to-music
+dora grid magnet.magnet_32khz --dry_run --init
+
+# text-to-sound
+dora grid magnet.audio_magnet_16khz --dry_run --init
+
+# Remove the `--dry_run --init` flags to actually schedule the jobs once everything is set up.
+```
+
+### Dataset and metadata
+
+Learn more in the [datasets section](./DATASETS.md).
+
+#### Music Models
+
+MAGNeT's underlying dataset is an AudioDataset augmented with music-specific metadata.
+The MAGNeT dataset implementation expects the metadata to be available as `.json` files
+at the same location as the audio files.
+
+#### Sound Models
+
+Audio-MAGNeT's underlying dataset is an AudioDataset augmented with description metadata.
+The Audio-MAGNeT dataset implementation expects the metadata to be available as `.json` files
+at the same location as the audio files or through a specified external folder.
+
+### Audio tokenizers
+
+See [MusicGen](./MUSICGEN.md).
+
+### Fine tuning existing models
+
+You can initialize your model to one of the pretrained models by using the `continue_from` argument, in particular:
+
+```bash
+# Using pretrained MAGNeT model.
+dora run solver=magnet/magnet_32khz model/lm/model_scale=medium continue_from=//pretrained/facebook/magnet-medium-10secs conditioner=text2music
+
+# Using another model you already trained with a Dora signature SIG.
+dora run solver=magnet/magnet_32khz model/lm/model_scale=medium continue_from=//sig/SIG conditioner=text2music
+
+# Or providing manually a path
+dora run solver=magnet/magnet_32khz model/lm/model_scale=medium continue_from=/checkpoints/my_other_xp/checkpoint.th
+```
+
+**Warning:** You are responsible for selecting the other parameters accordingly, in a way that makes them compatible
+with the model you are fine-tuning. Configuration is NOT automatically inherited from the model you continue from. In particular make sure to select the proper `conditioner` and `model/lm/model_scale`.
+
+**Warning:** We currently do not support fine-tuning a model with slightly different layers. If you decide
+to change some parts, like the conditioning or some other parts of the model, you are responsible for manually crafting a checkpoint file from which we can safely run `load_state_dict`.
+If you decide to do so, make sure your checkpoint is saved with `torch.save` and contains a dict
+`{'best_state': {'model': model_state_dict_here}}`. Directly give the path to `continue_from` without a `//pretrained/` prefix.
+
+### Evaluation stage
+
+For the 6 pretrained MAGNeT models, objective metrics can be reproduced using the following grids:
+
+```shell
+# text-to-music
+REGEN=1 dora grid magnet.magnet_pretrained_32khz_eval --dry_run --init
+
+# text-to-sound
+REGEN=1 dora grid magnet.audio_magnet_pretrained_16khz_eval --dry_run --init
+
+# Remove the `--dry_run --init` flags to actually schedule the jobs once everything is set up.
+```
+
+See [MusicGen](./MUSICGEN.md) for more details.
+
+### Generation stage
+
+See [MusicGen](./MUSICGEN.md).
+
+### Playing with the model
+
+Once you have launched some experiments, you can easily get access
+to the Solver with the latest trained model using the following snippet.
+
+### Evaluation stage
+
+For the 6 pretrained MAGNeT models, objective metrics can be reproduced using the following grids:
+
+```shell
+# text-to-music
+REGEN=1 dora grid magnet.magnet_pretrained_32khz_eval --dry_run --init
+
+# text-to-sound
+REGEN=1 dora grid magnet.audio_magnet_pretrained_16khz_eval --dry_run --init
+
+# Remove the `--dry_run --init` flags to actually schedule the jobs once everything is set up.
+```
+
+See [MusicGen](./MUSICGEN.md) for more details.
+
+### Generation stage
+
+See [MusicGen](./MUSICGEN.md).
+
+### Playing with the model
+
+Once you have launched some experiments, you can easily get access
+to the Solver with the latest trained model using the following snippet.
+
+```python
+from audiocraft.solvers.magnet import MagnetSolver
+
+solver = MagnetSolver.get_eval_solver_from_sig('SIG', device='cpu', batch_size=8)
+solver.model
+solver.dataloaders
+```
+
+### Importing / Exporting models
+
+We currently do not support loading a model from the Hugging Face implementation, nor exporting to it.
+If you want to export your model in a way that is compatible with the `audiocraft.models.MAGNeT`
+API, you can run:
+
+```python
+from audiocraft.utils import export
+from audiocraft import train
+xp = train.main.get_xp_from_sig('SIG_OF_LM')
+export.export_lm(xp.folder / 'checkpoint.th', '/checkpoints/my_audio_lm/state_dict.bin')
+# You also need to bundle the EnCodec model you used!
+## Case 1) you trained your own EnCodec model.
+xp_encodec = train.main.get_xp_from_sig('SIG_OF_ENCODEC')
+export.export_encodec(xp_encodec.folder / 'checkpoint.th', '/checkpoints/my_audio_lm/compression_state_dict.bin')
+## Case 2) you used a pretrained model. Give the name you used without the //pretrained/ prefix.
+## This will not dump the actual model, only a pointer to the right model to download.
+export.export_pretrained_compression_model('facebook/encodec_32khz', '/checkpoints/my_audio_lm/compression_state_dict.bin')
+```
+
+Now you can load your custom model with:
+```python
+import audiocraft.models
+magnet = audiocraft.models.MAGNeT.get_pretrained('/checkpoints/my_audio_lm/')
+```
+
+
+### Learn more
+
+Learn more about AudioCraft training pipelines in the [dedicated section](./TRAINING.md).
+
+## FAQ
+
+#### What are top-k, top-p, temperature and classifier-free guidance?
+
+Check out [@FurkanGozukara's tutorial](https://github.com/FurkanGozukara/Stable-Diffusion/blob/main/Tutorials/AI-Music-Generation-Audiocraft-Tutorial.md#more-info-about-top-k-top-p-temperature-and-classifier-free-guidance-from-chatgpt).
+
+#### Should I use FSDP or autocast?
+
+The two are mutually exclusive (because FSDP does autocast on its own).
+You can use autocast up to 1.5B (medium), if you have enough memory on your GPU.
+FSDP makes everything more complex but will free up some memory for the actual
+activations by sharding the optimizer state.
+
+## Citation
+```
+@misc{ziv2024masked,
+      title={Masked Audio Generation using a Single Non-Autoregressive Transformer},
+      author={Alon Ziv and Itai Gat and Gael Le Lan and Tal Remez and Felix Kreuk and Alexandre Défossez and Jade Copet and Gabriel Synnaeve and Yossi Adi},
+      year={2024},
+      eprint={2401.04577},
+      archivePrefix={arXiv},
+      primaryClass={cs.SD}
+}
+```
+
+## License
+
+See license information in the [model card](../model_cards/MAGNET_MODEL_CARD.md).
+
+[arxiv]: https://arxiv.org/abs/2401.04577
+[magnet_samples]: https://pages.cs.huji.ac.il/adiyoss-lab/MAGNeT/
diff --git a/model_cards/MAGNET_MODEL_CARD.md b/model_cards/MAGNET_MODEL_CARD.md
new file mode 100644
index 00000000..b77e2037
--- /dev/null
+++ b/model_cards/MAGNET_MODEL_CARD.md
@@ -0,0 +1,109 @@
+# MAGNeT Model Card
+
+## Model details
+
+**Organization developing the model:** The FAIR team of Meta AI.
+
+**Model date:** MAGNeT was trained between November 2023 and January 2024.
+
+**Model version:** This is version 1 of the model.
+
+**Model type:** MAGNeT consists of an EnCodec model for audio tokenization, and a non-autoregressive model based on the transformer architecture for music modeling. The model comes in different sizes: 300M and 1.5B; and two variants: a model trained for text-to-music generation, and a model trained for text-to-sound generation.
+
+**Paper or resources for more information:** More information can be found in the paper [Masked Audio Generation using a Single Non-Autoregressive Transformer][arxiv].
+
+**Citation details:** See [our paper][arxiv].
+
+**License:** Code is released under MIT; model weights are released under CC-BY-NC 4.0.
+
+**Where to send questions or comments about the model:** Questions and comments about MAGNeT can be sent via the [GitHub repository](https://github.com/facebookresearch/audiocraft) of the project, or by opening an issue.
+
+## Intended use
+**Primary intended use:** The primary use of MAGNeT is research on AI-based music generation, including:
+
+- Research efforts, such as probing and better understanding the limitations of generative models, to further improve the state of science
+- Generation of music guided by text, to help machine learning amateurs understand the current abilities of generative AI models
+
+**Primary intended users:** The primary intended users of the model are researchers in audio, machine learning and artificial intelligence, as well as amateurs seeking to better understand those models.
+
+**Out-of-scope use cases:** The model should not be used in downstream applications without further risk evaluation and mitigation. The model should not be used to intentionally create or disseminate music pieces that create hostile or alienating environments for people. This includes generating music that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.
+
+## Metrics
+
+**Model performance measures:** We used the following objective measures to evaluate the model on a standard music benchmark:
+
+- Frechet Audio Distance computed on features extracted from a pre-trained audio classifier (VGGish)
+- Kullback-Leibler Divergence on label distributions extracted from a pre-trained audio classifier (PaSST)
+- CLAP Score between the audio embedding and the text embedding extracted from a pre-trained CLAP model
+
+Additionally, we ran qualitative studies with human participants, evaluating the performance of the model along the following axes:
+
+- Overall quality of the music samples;
+- Relevance to the provided text input.
+
+More details on performance measures and human studies can be found in the paper.
+
+**Decision thresholds:** Not applicable.
+
+## Evaluation datasets
+
+The model was evaluated on the [MusicCaps benchmark](https://www.kaggle.com/datasets/googleai/musiccaps) and on an in-domain held-out evaluation set, with no artist overlap with the training set.
+
+## Training datasets
+
+The model was trained on licensed data using the following sources: the [Meta Music Initiative Sound Collection](https://www.fb.com/sound), the [Shutterstock music collection](https://www.shutterstock.com/music) and the [Pond5 music collection](https://www.pond5.com/). See the paper for more details about the training set and the corresponding preprocessing.
+
+## Evaluation results
+
+Below are the objective metrics obtained on MusicCaps with the released models. Note that for the publicly released models, we used the state-of-the-art music source separation method, namely the open-source [Hybrid Transformer for Music Source Separation](https://github.com/facebookresearch/demucs) (HT-Demucs), in order to keep only instrumental tracks. This explains the difference in objective metrics compared to the models used in the paper.
+
+| Model | Frechet Audio Distance | KLD | Text Consistency |
+|---|---|---|---|
+| facebook/magnet-small-10secs | 4.22 | 1.11 | 0.28 |
+| facebook/magnet-medium-10secs | 4.61 | 1.14 | 0.28 |
+| facebook/magnet-small-30secs | 4.35 | 1.17 | 0.28 |
+| facebook/magnet-medium-30secs | 4.63 | 1.20 | 0.28 |
+
+More information can be found in the paper [Masked Audio Generation using a Single Non-Autoregressive Transformer][arxiv], in the Results section.
+
+## Limitations and biases
+
+**Data:** The data sources used to train the model are created by music professionals and covered by legal agreements with the right holders. The model is trained on 16K hours of data. We believe that scaling the model on larger datasets can further improve its performance.
+
+**Mitigations:** Tracks that include vocals have been removed from the data source using the corresponding tags, and then using a state-of-the-art music source separation method, namely the open-source [Hybrid Transformer for Music Source Separation](https://github.com/facebookresearch/demucs) (HT-Demucs).
+
+**Limitations:**
+
+- The model is not able to generate realistic vocals.
+- The model has been trained with English descriptions and will not perform as well in other languages.
+- The model does not perform equally well for all music styles and cultures.
+- The model sometimes generates the end of a song, collapsing to silence.
+- It is sometimes difficult to assess what types of text descriptions provide the best generations. Prompt engineering may be required to obtain satisfying results.
+
+**Biases:** The source of data is potentially lacking diversity, and not all music cultures are equally represented in the dataset. The model may not perform equally well on the wide variety of music genres that exist. The generated samples from the model will reflect the biases from the training data. Further work on this model should include methods for balanced and just representations of cultures, for example, by scaling the training data to be both diverse and inclusive.
+
+**Risks and harms:** Biases and limitations of the model may lead to generation of samples that may be considered biased, inappropriate or offensive. We believe that providing the code to reproduce the research and train new models will allow broadening the application to new and more representative data.
+
+**Use cases:** Users must be aware of the biases, limitations and risks of the model. MAGNeT is a model developed for artificial intelligence research on controllable music generation. As such, it should not be used for downstream applications without further investigation and mitigation of risks.
+
+[arxiv]: https://arxiv.org/abs/2401.04577
+
+## Audio-MAGNeT - Sound-effect generation models
+
+### Training datasets
+
+The audio-MAGNeT models were trained on the following data sources: a subset of AudioSet (Gemmeke et al., 2017), [BBC sound effects](https://sound-effects.bbcrewind.co.uk/), AudioCaps (Kim et al., 2019), Clotho v2 (Drossos et al., 2020), VGG-Sound (Chen et al., 2020), FSD50K (Fonseca et al., 2021), [Free To Use Sounds](https://www.freetousesounds.com/all-in-one-bundle/), [Sonniss Game Effects](https://sonniss.com/gameaudiogdc), [WeSoundEffects](https://wesoundeffects.com/we-sound-effects-bundle-2020/), [Paramount Motion - Odeon Cinematic Sound Effects](https://www.paramountmotion.com/odeon-sound-effects).
+
+
+### Evaluation datasets
+
+The audio-MAGNeT models (sound-effect generation) were evaluated on the [AudioCaps benchmark](https://audiocaps.github.io/).
+
+### Evaluation results
+
+Below are the objective metrics obtained with the released audio-MAGNeT models on AudioCaps (consisting of 10-second-long samples).
+
+| Model | Frechet Audio Distance | KLD |
+|---|---|---|
+| facebook/audio-magnet-small | 3.21 | 1.42 |
+| facebook/audio-magnet-medium | 2.32 | 1.64 |