diff --git a/.github/actions/audiocraft_build/action.yml b/.github/actions/audiocraft_build/action.yml index b412cd02..74ab1e38 100644 --- a/.github/actions/audiocraft_build/action.yml +++ b/.github/actions/audiocraft_build/action.yml @@ -5,7 +5,7 @@ runs: steps: - uses: actions/setup-python@v2 with: - python-version: 3.8 + python-version: 3.9 - uses: actions/cache@v2 id: cache with: @@ -21,9 +21,9 @@ runs: python3 -m venv env . env/bin/activate python -m pip install --upgrade pip - pip install torch torchvision torchaudio + pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 pip install xformers - pip install -e '.[dev]' + pip install -e '.[dev,wm]' - name: System Dependencies shell: bash run: | diff --git a/CHANGELOG.md b/CHANGELOG.md index 7828bcb2..e1931719 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,14 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). +## [1.4.0a1] - 2024-06-03 + +Adding new metric PesqMetric ([Perceptual Evaluation of Speech Quality](https://doi.org/10.5281/zenodo.6549559)) + +Adding multiple audio augmentation functions: generating pink noises, up-/downsampling, low-/highpass filtering, banpass filtering, smoothing, duck masking, boosting. All are wrapped in the `audiocraft.utils.audio_effects.AudioEffects` and can be called with the API `audiocraft.utils.audio_effects.select_audio_effects`. + +Add training code for AudioSeal (https://arxiv.org/abs/2401.17264) along with the [hf checkpoints]( https://huggingface.co/facebook/audioseal). + ## [1.3.0] - 2024-05-02 Adding the MAGNeT model (https://arxiv.org/abs/2401.04577) along with hf checkpoints and a gradio demo app. diff --git a/Makefile b/Makefile index 3a491006..27ed2149 100644 --- a/Makefile +++ b/Makefile @@ -9,6 +9,9 @@ INTEG_AUDIOGEN = $(INTEG) solver=audiogen/debug dset=audio/example compression_m transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false # Using compression model from 5091833e INTEG_MBD = $(INTEG) solver=diffusion/debug dset=audio/example \ checkpoint.save_last=false # Using compression model from 616d7b3c +INTEG_WATERMARK = AUDIOCRAFT_DORA_DIR="/tmp/wm_$(USER)" dora run device=cpu dataset.num_workers=0 optim.epochs=1 \ + dataset.train.num_samples=10 dataset.valid.num_samples=10 dataset.evaluate.num_samples=10 dataset.generate.num_samples=10 \ + logging.level=DEBUG solver=watermark/robustness checkpoint.save_last=false dset=audio/example default: linter tests @@ -29,6 +32,7 @@ tests_integ: $(INTEG_MBD) $(INTEG_MUSICGEN) $(INTEG_AUDIOGEN) + $(INTEG_WATERMARK) api_docs: diff --git a/README.md b/README.md index 795cf948..4560ea24 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ python -m pip install setuptools wheel python -m pip install -U audiocraft # stable release python -m pip install -U git+https://git@github.com/facebookresearch/audiocraft#egg=audiocraft # bleeding edge python -m pip install -e . # or if you cloned the repo locally (mandatory if you want to train). +python -m pip install -e '.[wm]' # if you want to train a watermarking model ``` We also recommend having `ffmpeg` installed, either through your system or Anaconda: @@ -37,6 +38,7 @@ At the moment, AudioCraft contains the training code and inference code for: * [EnCodec](./docs/ENCODEC.md): A state-of-the-art high fidelity neural audio codec. * [Multi Band Diffusion](./docs/MBD.md): An EnCodec compatible decoder using diffusion. 
* [MAGNeT](./docs/MAGNET.md): A state-of-the-art non-autoregressive model for text-to-music and text-to-sound. +* [AudioSeal](./docs/WATERMARKING.md): A state-of-the-art audio watermarking. ## Training code diff --git a/audiocraft/__init__.py b/audiocraft/__init__.py index 062ab7e5..f00d6e5e 100644 --- a/audiocraft/__init__.py +++ b/audiocraft/__init__.py @@ -23,4 +23,4 @@ # flake8: noqa from . import data, modules, models -__version__ = '1.3.0' +__version__ = '1.4.0a1' diff --git a/audiocraft/data/audio.py b/audiocraft/data/audio.py index a35dfd9c..8496cb61 100644 --- a/audiocraft/data/audio.py +++ b/audiocraft/data/audio.py @@ -114,7 +114,7 @@ def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: floa def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0., - duration: float = -1., pad: bool = False) -> tp.Tuple[torch.Tensor, int]: + duration: float = -1.0, pad: bool = False) -> tp.Tuple[torch.Tensor, int]: """Read audio by picking the most appropriate backend tool based on the audio format. Args: @@ -229,3 +229,123 @@ def audio_write(stem_name: tp.Union[str, Path], path.unlink() raise return path + + +def get_spec(y, sr=16000, n_fft=4096, hop_length=128, dur=8) -> np.ndarray: + """Get the mel-spectrogram from the raw audio. + + Args: + y (numpy array): raw input + sr (int): Sampling rate + n_fft (int): Number of samples per FFT. Default is 2048. + hop_length (int): Number of samples between successive frames. Default is 512. + dur (float): Maxium duration to get the spectrograms + Returns: + spectro histogram as a numpy array + """ + import librosa + import librosa.display + + spectrogram = librosa.feature.melspectrogram( + y=y, sr=sr, n_fft=n_fft, hop_length=hop_length + ) + spectrogram_db = librosa.power_to_db(spectrogram, ref=np.max) + return spectrogram_db + + +def save_spectrograms( + ys: tp.List[np.ndarray], + sr: int, + path: str, + names: tp.List[str], + n_fft: int = 4096, + hop_length: int = 128, + dur: float = 8.0, +): + """Plot a spectrogram for an audio file. + + Args: + ys: List of audio spectrograms + sr (int): Sampling rate of the audio file. Default is 22050 Hz. + path (str): Path to the plot file. + names: name of each spectrogram plot + n_fft (int): Number of samples per FFT. Default is 2048. + hop_length (int): Number of samples between successive frames. Default is 512. 
+ dur (float): Maxium duration to plot the spectrograms + + Returns: + None (plots the spectrogram using matplotlib) + """ + import matplotlib as mpl # type: ignore + import matplotlib.pyplot as plt # type: ignore + import librosa.display + + if not names: + names = ["Ground Truth", "Audio Watermarked", "Watermark"] + ys = [wav[: int(dur * sr)] for wav in ys] # crop + assert len(names) == len( + ys + ), f"There are {len(ys)} wavs but {len(names)} names ({names})" + + # Set matplotlib stuff + BIGGER_SIZE = 10 + SMALLER_SIZE = 8 + linewidth = 234.8775 # linewidth in pt + + plt.rc("font", size=BIGGER_SIZE, family="serif") # controls default text sizes + plt.rcParams["font.family"] = "DeJavu Serif" + plt.rcParams["font.serif"] = ["Times New Roman"] + + plt.rc("axes", titlesize=BIGGER_SIZE) # fontsize of the axes title + plt.rc("axes", labelsize=BIGGER_SIZE) # fontsize of the x and y labels + plt.rc("xtick", labelsize=BIGGER_SIZE) # fontsize of the tick labels + plt.rc("ytick", labelsize=SMALLER_SIZE) # fontsize of the tick labels + plt.rc("legend", fontsize=BIGGER_SIZE) # legend fontsize + plt.rc("figure", titlesize=BIGGER_SIZE) + height = 1.6 * linewidth / 72.0 + fig, ax = plt.subplots( + nrows=len(ys), + ncols=1, + sharex=True, + figsize=(linewidth / 72.0, height), + ) + fig.tight_layout() + + # Plot the spectrogram + + for i, ysi in enumerate(ys): + spectrogram_db = get_spec(ysi, sr=sr, n_fft=n_fft, hop_length=hop_length) + if i == 0: + cax = fig.add_axes( + [ + ax[0].get_position().x1 + 0.01, # type: ignore + ax[-1].get_position().y0, + 0.02, + ax[0].get_position().y1 - ax[-1].get_position().y0, + ] + ) + fig.colorbar( + mpl.cm.ScalarMappable( + norm=mpl.colors.Normalize( + np.min(spectrogram_db), np.max(spectrogram_db) + ), + cmap="magma", + ), + ax=ax, + orientation="vertical", + format="%+2.0f dB", + cax=cax, + ) + librosa.display.specshow( + spectrogram_db, + sr=sr, + hop_length=hop_length, + x_axis="time", + y_axis="mel", + ax=ax[i], + ) + ax[i].set(title=names[i]) + ax[i].yaxis.set_label_text(None) + ax[i].label_outer() + fig.savefig(path, bbox_inches="tight") + plt.close() diff --git a/audiocraft/data/audio_utils.py b/audiocraft/data/audio_utils.py index 9d3129b8..cf71b990 100644 --- a/audiocraft/data/audio_utils.py +++ b/audiocraft/data/audio_utils.py @@ -5,6 +5,9 @@ # LICENSE file in the root directory of this source tree. """Various utilities for audio convertion (pcm format, sample rate and channels), and volume normalization.""" +import io +import logging +import re import sys import typing as tp @@ -12,6 +15,8 @@ import torch import torchaudio +logger = logging.getLogger(__name__) + def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor: """Convert audio to the given number of channels. @@ -84,7 +89,9 @@ def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: tp.Optional[str] = None) -> None: - """Utility function to clip the audio with logging if specified.""" + """ + Utility function to clip the audio with logging if specified. + """ max_scale = wav.abs().max() if log_clipping and max_scale > 1: clamp_prob = (wav.abs() > 1).float().mean().item() @@ -146,7 +153,12 @@ def normalize_audio(wav: torch.Tensor, normalize: bool = True, def f32_pcm(wav: torch.Tensor) -> torch.Tensor: - """Convert audio to float 32 bits PCM format. + """ + Convert audio to float 32 bits PCM format. 
+ Args: + wav (torch.tensor): Input wav tensor + Returns: + same wav in float32 PCM format """ if wav.dtype.is_floating_point: return wav @@ -164,6 +176,10 @@ def i16_pcm(wav: torch.Tensor) -> torch.Tensor: due to the asymmetry of the int16 range. One either have possible clipping, DC offset, or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom, it is possible that `i16_pcm(f32_pcm)) != Identity`. + Args: + wav (torch.tensor): Input wav tensor + Returns: + same wav in float16 PCM format """ if wav.dtype.is_floating_point: assert wav.abs().max() <= 1 @@ -174,3 +190,185 @@ def i16_pcm(wav: torch.Tensor) -> torch.Tensor: else: assert wav.dtype == torch.int16 return wav + + +def compress(wav: torch.Tensor, sr: int, + target_format: tp.Literal["mp3", "ogg", "flac"] = "mp3", + bitrate: str = "128k") -> tp.Tuple[torch.Tensor, int]: + """Convert audio wave form to a specified lossy format: mp3, ogg, flac + + Args: + wav (torch.Tensor): Input wav tensor. + sr (int): Sampling rate. + target_format (str): Compression format (e.g., 'mp3'). + bitrate (str): Bitrate for compression. + + Returns: + Tuple of compressed WAV tensor and sampling rate. + """ + + # Extract the bit rate from string (e.g., '128k') + match = re.search(r"\d+(\.\d+)?", str(bitrate)) + parsed_bitrate = float(match.group()) if match else None + assert parsed_bitrate, f"Invalid bitrate specified (got {parsed_bitrate})" + try: + # Create a virtual file instead of saving to disk + buffer = io.BytesIO() + + torchaudio.save( + buffer, wav, sr, format=target_format, bits_per_sample=parsed_bitrate, + ) + # Move to the beginning of the file + buffer.seek(0) + compressed_wav, sr = torchaudio.load(buffer) + return compressed_wav, sr + + except RuntimeError: + logger.warning( + f"compression failed skipping compression: {format} {parsed_bitrate}" + ) + return wav, sr + + +def get_mp3(wav_tensor: torch.Tensor, sr: int, bitrate: str = "128k") -> torch.Tensor: + """Convert a batch of audio files to MP3 format, maintaining the original shape. + + This function takes a batch of audio files represented as a PyTorch tensor, converts + them to MP3 format using the specified bitrate, and returns the batch in the same + shape as the input. + + Args: + wav_tensor (torch.Tensor): Batch of audio files represented as a tensor. + Shape should be (batch_size, channels, length). + sr (int): Sampling rate of the audio. + bitrate (str): Bitrate for MP3 conversion, default is '128k'. + + Returns: + torch.Tensor: Batch of audio files converted to MP3 format, with the same + shape as the input tensor. 
+ """ + device = wav_tensor.device + batch_size, channels, original_length = wav_tensor.shape + + # Flatten tensor for conversion and move to CPU + wav_tensor_flat = wav_tensor.view(1, -1).cpu() + + # Convert to MP3 format with specified bitrate + wav_tensor_flat, _ = compress(wav_tensor_flat, sr, bitrate=bitrate) + + # Reshape back to original batch format and trim or pad if necessary + wav_tensor = wav_tensor_flat.view(batch_size, channels, -1) + compressed_length = wav_tensor.shape[-1] + if compressed_length > original_length: + wav_tensor = wav_tensor[:, :, :original_length] # Trim excess frames + elif compressed_length < original_length: + padding = torch.zeros( + batch_size, channels, original_length - compressed_length, device=device + ) + wav_tensor = torch.cat((wav_tensor, padding), dim=-1) # Pad with zeros + + # Move tensor back to the original device + return wav_tensor.to(device) + + +def get_aac( + wav_tensor: torch.Tensor, + sr: int, + bitrate: str = "128k", + lowpass_freq: tp.Optional[int] = None, +) -> torch.Tensor: + """Converts a batch of audio tensors to AAC format and then back to tensors. + + This function first saves the input tensor batch as WAV files, then uses FFmpeg to convert + these WAV files to AAC format. Finally, it loads the AAC files back into tensors. + + Args: + wav_tensor (torch.Tensor): A batch of audio files represented as a tensor. + Shape should be (batch_size, channels, length). + sr (int): Sampling rate of the audio. + bitrate (str): Bitrate for AAC conversion, default is '128k'. + lowpass_freq (Optional[int]): Frequency for a low-pass filter. If None, no filter is applied. + + Returns: + torch.Tensor: Batch of audio files converted to AAC and back, with the same + shape as the input tensor. + """ + import tempfile + import subprocess + + device = wav_tensor.device + batch_size, channels, original_length = wav_tensor.shape + + # Parse the bitrate value from the string + match = re.search(r"\d+(\.\d+)?", bitrate) + parsed_bitrate = ( + match.group() if match else "128" + ) # Default to 128 if parsing fails + + # Flatten tensor for conversion and move to CPU + wav_tensor_flat = wav_tensor.view(1, -1).cpu() + + with tempfile.NamedTemporaryFile( + suffix=".wav" + ) as f_in, tempfile.NamedTemporaryFile(suffix=".aac") as f_out: + input_path, output_path = f_in.name, f_out.name + + # Save the tensor as a WAV file + torchaudio.save(input_path, wav_tensor_flat, sr, backend="ffmpeg") + + # Prepare FFmpeg command for AAC conversion + command = [ + "ffmpeg", + "-y", + "-i", + input_path, + "-ar", + str(sr), + "-b:a", + f"{parsed_bitrate}k", + "-c:a", + "aac", + ] + if lowpass_freq is not None: + command += ["-cutoff", str(lowpass_freq)] + command.append(output_path) + + try: + # Run FFmpeg and suppress output + subprocess.run(command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + + # Load the AAC audio back into a tensor + aac_tensor, _ = torchaudio.load(output_path, backend="ffmpeg") + except Exception as exc: + raise RuntimeError( + "Failed to run command " ".join(command)} " + "(Often this means ffmpeg is not installed or the encoder is not supported, " + "make sure you installed an older version ffmpeg<5)" + ) from exc + + original_length_flat = batch_size * channels * original_length + compressed_length_flat = aac_tensor.shape[-1] + + # Trim excess frames + if compressed_length_flat > original_length_flat: + aac_tensor = aac_tensor[:, :original_length_flat] + + # Pad the shortedn frames + elif compressed_length_flat < original_length_flat: + 
padding = torch.zeros( + 1, original_length_flat - compressed_length_flat, device=device + ) + aac_tensor = torch.cat((aac_tensor, padding), dim=-1) + + # Reshape and adjust length to match original tensor + wav_tensor = aac_tensor.view(batch_size, channels, -1) + compressed_length = wav_tensor.shape[-1] + + assert compressed_length == original_length, ( + "AAC-compressed audio does not have the same frames as original one. " + "One reason can be ffmpeg is not installed and used as proper backed " + "for torchaudio, or the AAC encoder is not correct. Run " + "`torchaudio.utils.ffmpeg_utils.get_audio_encoders()` and make sure we see entry for" + "AAC in the output." + ) + return wav_tensor.to(device) diff --git a/audiocraft/grids/watermarking/__init__.py b/audiocraft/grids/watermarking/__init__.py new file mode 100644 index 00000000..d930fecc --- /dev/null +++ b/audiocraft/grids/watermarking/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""watermarking grids.""" diff --git a/audiocraft/grids/watermarking/_explorers.py b/audiocraft/grids/watermarking/_explorers.py new file mode 100644 index 00000000..7dd0b784 --- /dev/null +++ b/audiocraft/grids/watermarking/_explorers.py @@ -0,0 +1,115 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import treetable as tt + +from .._base_explorers import BaseExplorer + + +class WatermarkingMbExplorer(BaseExplorer): + eval_metrics = ["acc", "bit_acc", "visqol", "fnr", "fpr", "sisnr"] + + def stages(self): + return ["train", "valid", "valid_ema", "evaluate", "evaluate_ema"] + + def get_grid_meta(self): + """Returns the list of Meta information to display for each XP/job.""" + return [ + tt.leaf("index", align=">"), + tt.leaf("name", wrap=140), + tt.leaf("state"), + tt.leaf("sig", align=">"), + ] + + def get_grid_metrics(self): + """Return the metrics that should be displayed in the tracking table.""" + return [ + tt.group( + "train", + [ + tt.leaf("epoch"), + tt.leaf("sisnr", ".3%"), + tt.leaf("wm_detection_identity", ".3%"), + tt.leaf("wm_mb_identity", ".3%"), + ], + align=">", + ), + tt.group( + "valid", + [ + tt.leaf("sisnr", ".3%"), + tt.leaf("wm_detection_identity", ".3%"), + tt.leaf("wm_mb_identity", ".3%"), + # tt.leaf("loss_0", ".3%"), + ], + align=">", + ), + tt.group( + "evaluate", + [ + tt.leaf("aug_identity_acc", ".4f"), + tt.leaf("aug_identity_fnr", ".4f"), + tt.leaf("aug_identity_fpr", ".4f"), + tt.leaf("aug_identity_bit_acc", ".4f"), + tt.leaf("pesq", ".4f"), + tt.leaf("all_aug_acc", ".4f"), + tt.leaf("localization_acc_padding", ".4f"), + ], + align=">", + ), + ] + + +class WatermarkingExplorer(BaseExplorer): + eval_metrics = ["acc", "visqol", "fnr", "fpr", "sisnr"] + + def stages(self): + return ["train", "valid", "valid_ema", "evaluate", "evaluate_ema"] + + def get_grid_meta(self): + """Returns the list of Meta information to display for each XP/job.""" + return [ + tt.leaf("index", align=">"), + tt.leaf("name", wrap=140), + tt.leaf("state"), + tt.leaf("sig", align=">"), + ] + + def get_grid_metrics(self): + """Return the metrics that should be displayed in the tracking table.""" + return [ + tt.group( + "train", + [ + tt.leaf("epoch"), + tt.leaf("sisnr", ".3f"), + tt.leaf("wm_detection_identity"), + 
], + align=">", + ), + tt.group( + "valid", + [ + tt.leaf("sisnr", ".3f"), + tt.leaf("wm_detection_identity"), + # tt.leaf("loss_0", ".3%"), + ], + align=">", + ), + tt.group( + "evaluate", + [ + tt.leaf("aug_identity_acc", ".4f"), + tt.leaf("aug_identity_fnr", ".4f"), + tt.leaf("aug_identity_fpr", ".4f"), + tt.leaf("pesq", ".4f"), + tt.leaf("all_aug_acc", ".4f"), + tt.leaf("localization_acc_padding", ".4f"), + + ], + align=">", + ), + ] diff --git a/audiocraft/grids/watermarking/audioseal.py b/audiocraft/grids/watermarking/audioseal.py new file mode 100644 index 00000000..84fd86ed --- /dev/null +++ b/audiocraft/grids/watermarking/audioseal.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +""" +dora grid watermarking.audioseal --clear +""" +from audiocraft.environment import AudioCraftEnvironment +from ._explorers import WatermarkingExplorer + + +@WatermarkingExplorer +def explorer(launcher): + partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) + launcher.slurm_( + gpus=8, + partition=partitions, + constraint="volta32gb", + ) + launcher.bind_( + { + "solver": "watermark/robustness", + "dset": "audio/example", + } + ) + launcher.bind_(label="audioseal") + + with launcher.job_array(): + launcher() diff --git a/audiocraft/grids/watermarking/kbits.py b/audiocraft/grids/watermarking/kbits.py new file mode 100644 index 00000000..b86bf890 --- /dev/null +++ b/audiocraft/grids/watermarking/kbits.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
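The lossy-compression helpers added to `audiocraft/data/audio_utils.py` earlier in this diff (`compress`, `get_mp3`, `get_aac`) serve as robustness augmentations for watermark training. Below is a minimal usage sketch, not part of the diff, assuming a torchaudio build with an MP3-capable (e.g. ffmpeg) backend; the tensors and bitrate are illustrative.

```python
# Minimal sketch (not part of the diff) of the lossy-compression helpers from
# audiocraft/data/audio_utils.py. Requires a torchaudio backend that can encode MP3.
import torch
from audiocraft.data.audio_utils import compress, get_mp3

sr = 16000
wav = torch.randn(4, 1, sr).clamp(-1, 1)  # (batch, channels, samples), 1 s at 16 kHz

# Single-clip API: encode to MP3 in memory and decode back; falls back to the
# original audio if the encoder is unavailable.
lossy, lossy_sr = compress(wav[0], sr, target_format="mp3", bitrate="128k")

# Batched API used as a training-time augmentation: output keeps the input shape.
mp3_batch = get_mp3(wav, sr, bitrate="128k")
assert mp3_batch.shape == wav.shape
```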
+ +""" +dora grid watermarking.kbits --clear +""" +import os +from audiocraft.environment import AudioCraftEnvironment +from ._explorers import WatermarkingMbExplorer + + +@WatermarkingMbExplorer +def explorer(launcher): + partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) + launcher.slurm_( + gpus=8, + partition=partitions, + constraint="volta32gb", + ) + launcher.bind_( + { + "solver": "watermark/robustness", + "dset": os.getenv("AUDIOCRAFT_DSET", "audio/example"), + "dataset.batch_size": 16, + # optim + "optim.epochs": 300, + "schedule": { + "lr_scheduler": "cosine", + "cosine": { + "warmup": 4000, + "lr_min_ratio": 0.0, + "cycle_length": 1.0, + }, + }, + # crop and padding + "crop": { + "prob": 0.4, + "shuffle_prob": 0.2, + "pad_prob": 0.2, + "size": 0.5, + "max_n_windows": 5, + }, + # augmentations + "select_aug_mode": 'use_eval', + "aug_weights.updownresample": 0.1, + "aug_weights.speed": 0.1, + "aug_weights.echo": 0.1, + "aug_weights.pink_noise": 0.1, + "aug_weights.lowpass_filter": 0.1, + "aug_weights.highpass_filter": 0.1, + "aug_weights.bandpass_filter": 0.1, + "aug_weights.smooth": 0.1, + "aug_weights.boost_audio": 0.1, + "aug_weights.duck_audio": 0.1, + "aug_weights.mp3_compression": 0.1, + "aug_weights.encodec": 0.1, + "aug_weights.identity": 1.0, + # multi-bit + "audioseal.nbits": 16, + "detector.output_dim": 32, + "wm_mb.loss_type": "bce", + "wm_mb.temperature": 0.1, + # losses + "losses": { # encodec loss + tf = 10 + "adv": 4.0, + "feat": 4.0, + "l1": 0.1, + "mel": 0.0, + "msspec": 2.0, + "sisnr": 0.0, + "tf_loudnessratio": 10.0, + }, + "losses.wm_detection": 1.0, + "losses.wm_mb": 1.0, + } + ) + launcher.bind_(label="kbits16") + + lrs = [5e-5] + seeds = [1, 2, 3, 4] + + with launcher.job_array(): + for lr in lrs: + for seed in seeds: + launcher({ + "optim.lr": lr, + "seed": seed, + }) diff --git a/audiocraft/losses/__init__.py b/audiocraft/losses/__init__.py index d55107b2..272d6bdb 100644 --- a/audiocraft/losses/__init__.py +++ b/audiocraft/losses/__init__.py @@ -19,3 +19,10 @@ MelSpectrogramL1Loss, MultiScaleMelSpectrogramLoss, ) + +from .wmloss import ( + WMDetectionLoss, + WMMbLoss +) + +from .loudnessloss import TFLoudnessRatio diff --git a/audiocraft/losses/loudnessloss.py b/audiocraft/losses/loudnessloss.py new file mode 100644 index 00000000..c1803878 --- /dev/null +++ b/audiocraft/losses/loudnessloss.py @@ -0,0 +1,204 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +import typing as tp + +import julius +import torch +import torchaudio +from torch import nn +from torch.nn import functional as F +from torchaudio.functional.filtering import highpass_biquad, treble_biquad + + +def basic_loudness(waveform: torch.Tensor, sample_rate: int) -> torch.Tensor: + """This is a simpler loudness function that is more stable. 
+ Args: + waveform(torch.Tensor): audio waveform of dimension `(..., channels, time)` + sample_rate (int): sampling rate of the waveform + Returns: + loudness loss as a scalar + """ + + if waveform.size(-2) > 5: + raise ValueError("Only up to 5 channels are supported.") + eps = torch.finfo(torch.float32).eps + gate_duration = 0.4 + overlap = 0.75 + gate_samples = int(round(gate_duration * sample_rate)) + step = int(round(gate_samples * (1 - overlap))) + + # Apply K-weighting + waveform = treble_biquad(waveform, sample_rate, 4.0, 1500.0, 1 / math.sqrt(2)) + waveform = highpass_biquad(waveform, sample_rate, 38.0, 0.5) + + # Compute the energy for each block + energy = torch.square(waveform).unfold(-1, gate_samples, step) + energy = torch.mean(energy, dim=-1) + + # Compute channel-weighted summation + g = torch.tensor([1.0, 1.0, 1.0, 1.41, 1.41], dtype=waveform.dtype, device=waveform.device) + g = g[: energy.size(-2)] + + energy_weighted = torch.sum(g.unsqueeze(-1) * energy, dim=-2) + # loudness with epsilon for stability. Not as much precision in the very low loudness sections + loudness = -0.691 + 10 * torch.log10(energy_weighted + eps) + return loudness + + +def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor: + """Given input of size [*OT, T], output Tensor of size [*OT, F, K] + with K the kernel size, by extracting frames with the given stride. + This will pad the input so that `F = ceil(T / K)`. + see https://github.com/pytorch/pytorch/issues/60466 + """ + *shape, length = a.shape + n_frames = math.ceil(length / stride) + tgt_length = (n_frames - 1) * stride + kernel_size + a = F.pad(a, (0, tgt_length - length)) + strides = list(a.stride()) + assert strides[-1] == 1, "data should be contiguous" + strides = strides[:-1] + [stride, 1] + return a.as_strided([*shape, n_frames, kernel_size], strides) + + +class FLoudnessRatio(nn.Module): + """FSNR loss. + + Input should be [B, C, T], output is scalar. + + Args: + sample_rate (int): Sample rate. + segment (float or None): Evaluate on chunks of that many seconds. If None, evaluate on + entire audio only. + overlap (float): Overlap between chunks, i.e. 0.5 = 50 % overlap. + epsilon (float): Epsilon value for numerical stability. + n_bands (int): number of mel scale bands that we include + """ + def __init__( + self, + sample_rate: int = 16000, + segment: tp.Optional[float] = 20, + overlap: float = 0.5, + epsilon: float = torch.finfo(torch.float32).eps, + n_bands: int = 0, + ): + super().__init__() + self.sample_rate = sample_rate + self.segment = segment + self.overlap = overlap + self.epsilon = epsilon + if n_bands == 0: + self.filter = None + else: + self.filter = julius.SplitBands(sample_rate=sample_rate, n_bands=n_bands) + self.loudness = torchaudio.transforms.Loudness(sample_rate) + + def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor: + B, C, T = ref_sig.shape + assert ref_sig.shape == out_sig.shape + assert self.filter is not None + bands_ref = self.filter(ref_sig) + bands_out = self.filter(out_sig) + l_noise = self.loudness(bands_ref - bands_out) + l_ref = self.loudness(bands_ref) + l_ratio = (l_noise - l_ref).view(-1, B) + loss = torch.nn.functional.softmax(l_ratio, dim=0) * l_ratio + return loss.sum() + + +class TLoudnessRatio(nn.Module): + """TSNR loss. + + Input should be [B, C, T], output is scalar. + + Args: + sample_rate (int): Sample rate. + segment (float or None): Evaluate on chunks of that many seconds. If None, evaluate on + entire audio only. 
+ overlap (float): Overlap between chunks, i.e. 0.5 = 50 % overlap. + """ + def __init__( + self, + sample_rate: int = 16000, + segment: float = 0.5, + overlap: float = 0.5, + ): + super().__init__() + self.sample_rate = sample_rate + self.segment = segment + self.overlap = overlap + self.loudness = torchaudio.transforms.Loudness(sample_rate) + + def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor: + B, C, T = ref_sig.shape + assert ref_sig.shape == out_sig.shape + assert C == 1 + + frame = int(self.segment * self.sample_rate) + stride = int(frame * (1 - self.overlap)) + gt = _unfold(ref_sig, frame, stride).view(-1, 1, frame) + est = _unfold(out_sig, frame, stride).view(-1, 1, frame) + l_noise = self.loudness(gt - est) # watermark + l_ref = self.loudness(gt) # ground truth + l_ratio = (l_noise - l_ref).view(-1, B) + loss = torch.nn.functional.softmax(l_ratio, dim=0) * l_ratio + return loss.sum() + + +class TFLoudnessRatio(nn.Module): + """TF-loudness ratio loss. + + Input should be [B, C, T], output is scalar. + + Args: + sample_rate (int): Sample rate. + segment (float or None): Evaluate on chunks of that many seconds. If None, evaluate on + entire audio only. + overlap (float): Overlap between chunks, i.e. 0.5 = 50 % overlap. + n_bands (int): number of bands to separate + temperature (float): temperature of the softmax step + """ + def __init__( + self, + sample_rate: int = 16000, + segment: float = 0.5, + overlap: float = 0.5, + n_bands: int = 0, + clip_min: float = -100, + temperature: float = 1.0, + ): + super().__init__() + self.sample_rate = sample_rate + self.segment = segment + self.overlap = overlap + self.clip_min = clip_min + self.temperature = temperature + if n_bands == 0: + self.filter = None + else: + self.n_bands = n_bands + self.filter = julius.SplitBands(sample_rate=sample_rate, n_bands=n_bands) + + def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor: + B, C, T = ref_sig.shape + + assert ref_sig.shape == out_sig.shape + assert C == 1 + assert self.filter is not None + + bands_ref = self.filter(ref_sig).view(B * self.n_bands, 1, -1) + bands_out = self.filter(out_sig).view(B * self.n_bands, 1, -1) + frame = int(self.segment * self.sample_rate) + stride = int(frame * (1 - self.overlap)) + gt = _unfold(bands_ref, frame, stride).squeeze(1).contiguous().view(-1, 1, frame) + est = _unfold(bands_out, frame, stride).squeeze(1).contiguous().view(-1, 1, frame) + l_noise = basic_loudness(est - gt, sample_rate=self.sample_rate) # watermark + l_ref = basic_loudness(gt, sample_rate=self.sample_rate) # ground truth + l_ratio = (l_noise - l_ref).view(-1, B) + loss = torch.nn.functional.softmax(l_ratio / self.temperature, dim=0) * l_ratio + return loss.mean() diff --git a/audiocraft/losses/wmloss.py b/audiocraft/losses/wmloss.py new file mode 100644 index 00000000..588938fd --- /dev/null +++ b/audiocraft/losses/wmloss.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
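A minimal sketch, not part of the diff, of how the `TFLoudnessRatio` loss defined above might be called. It assumes mono input of shape `[B, 1, T]` and `n_bands > 0` (the forward pass asserts that the band-split filter exists); the tensors and hyperparameters are illustrative.

```python
# Minimal sketch (not part of the diff): time-frequency loudness-ratio loss between
# a reference signal and the same signal with a small perturbation (e.g. a watermark).
import torch
from audiocraft.losses import TFLoudnessRatio

loss_fn = TFLoudnessRatio(sample_rate=16000, segment=0.5, overlap=0.5,
                          n_bands=8, temperature=1.0)

ref = torch.randn(2, 1, 32000)            # ground-truth audio, 2 s at 16 kHz
out = ref + 0.01 * torch.randn_like(ref)  # e.g. the watermarked version
loss = loss_fn(out, ref)                  # scalar tensor
```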
+ +from typing import Literal + +import torch +import torch.nn as nn + + +class WMDetectionLoss(nn.Module): + """Compute the detection loss""" + def __init__(self, p_weight: float = 1.0, n_weight: float = 1.0) -> None: + super().__init__() + self.criterion = nn.NLLLoss() + self.p_weight = p_weight + self.n_weight = n_weight + + def forward(self, positive, negative, mask, message=None): + + positive = positive[:, :2, :] # b 2+nbits t -> b 2 t + negative = negative[:, :2, :] # b 2+nbits t -> b 2 t + + # dimensionality of positive [bsz, classes=2, time_steps] + # correct classes for pos = [bsz, time_steps] where all values = 1 for positive + classes_shape = positive[ + :, 0, : + ] # same as positive or negative but dropping dim=1 + pos_correct_classes = torch.ones_like(classes_shape, dtype=int) + neg_correct_classes = torch.zeros_like(classes_shape, dtype=int) + + # taking log because network outputs softmax + # NLLLoss expects a logsoftmax input + positive = torch.log(positive) + negative = torch.log(negative) + + if not torch.all(mask == 1): + # pos_correct_classes [bsz, timesteps] mask [bsz, 1, timesptes] + # mask is applied to the watermark, this basically flips the tgt class from 1 (positive) + # to 0 (negative) in the correct places + pos_correct_classes = pos_correct_classes * mask[:, 0, :].to(int) + loss_p = self.p_weight * self.criterion(positive, pos_correct_classes) + # no need for negative class loss here since some of the watermark + # is masked to negative + return loss_p + + else: + loss_p = self.p_weight * self.criterion(positive, pos_correct_classes) + loss_n = self.n_weight * self.criterion(negative, neg_correct_classes) + return loss_p + loss_n + + +class WMMbLoss(nn.Module): + def __init__(self, temperature: float, loss_type: Literal["bce", "mse"]) -> None: + """ + Compute the masked sample-level detection loss + (https://arxiv.org/pdf/2401.17264) + + Args: + temperature: temperature for loss computation + loss_type: bce or mse between outputs and original message + """ + super().__init__() + self.bce_with_logits = ( + nn.BCEWithLogitsLoss() + ) # same as Softmax + NLLLoss, but when only 1 output unit + self.mse = nn.MSELoss() + self.loss_type = loss_type + self.temperature = temperature + + def forward(self, positive, negative, mask, message): + """ + Compute decoding loss + Args: + positive: outputs on watermarked samples [bsz, 2+nbits, time_steps] + negative: outputs on not watermarked samples [bsz, 2+nbits, time_steps] + mask: watermark mask [bsz, 1, time_steps] + message: original message [bsz, nbits] or None + """ + # # no use of negative at the moment + # negative = negative[:, 2:, :] # b 2+nbits t -> b nbits t + # negative = torch.masked_select(negative, mask) + if message.size(0) == 0: + return torch.tensor(0.0) + positive = positive[:, 2:, :] # b 2+nbits t -> b nbits t + assert ( + positive.shape[-2] == message.shape[1] + ), "in decoding loss: \ + enc and dec don't share nbits, are you using multi-bit?" 
+ + # cut last dim of positive to keep only where mask is 1 + new_shape = [*positive.shape[:-1], -1] # b nbits -1 + positive = torch.masked_select(positive, mask == 1).reshape(new_shape) + + message = message.unsqueeze(-1).repeat(1, 1, positive.shape[2]) # b k -> b k t + if self.loss_type == "bce": + # in this case similar to temperature in softmax + loss = self.bce_with_logits(positive / self.temperature, message.float()) + elif self.loss_type == "mse": + loss = self.mse(positive / self.temperature, message.float()) + + return loss diff --git a/audiocraft/metrics/miou.py b/audiocraft/metrics/miou.py new file mode 100644 index 00000000..c705fe65 --- /dev/null +++ b/audiocraft/metrics/miou.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch + + +def calculate_miou(y_pred: torch.Tensor, y_true: torch.Tensor) -> float: + """ + Calculate the mean Intersection over Union (mIoU) between two binary tensors using PyTorch. + + Args: + y_pred (torch.Tensor): Predicted binary tensor of shape [bsz, frames]. + y_true (torch.Tensor): Ground truth binary tensor of shape [bsz, frames]. + + Returns: + float: The mean Intersection over Union (mIoU) score. + + Reference: + The Intersection over Union (IoU) metric is commonly used in computer vision. + For more information, refer to the following paper: + "SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation" + by Vijay Badrinarayanan, Alex Kendall, Roberto Cipolla + """ + # Ensure y_pred and y_true have the same shape + if y_pred.shape != y_true.shape: + raise ValueError("Input tensors must have the same shape") + + # converting predictions to binary vector + y_pred = y_pred > 0.5 + # Compute the intersection and union + intersection = torch.logical_and(y_pred, y_true) + union = torch.logical_or(y_pred, y_true) + + # Compute IoU for each sample in the batch + iou_per_sample = torch.sum(intersection, dim=1) / torch.sum(union, dim=1) + # Calculate mIoU by taking the mean across the batch + miou = torch.mean(iou_per_sample).item() + + return miou diff --git a/audiocraft/metrics/pesq.py b/audiocraft/metrics/pesq.py new file mode 100644 index 00000000..744ca759 --- /dev/null +++ b/audiocraft/metrics/pesq.py @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import julius +import pesq + +import torch +import torchmetrics + + +class PesqMetric(torchmetrics.Metric): + """Metric for Perceptual Evaluation of Speech Quality. 
+ (https://doi.org/10.5281/zenodo.6549559) + + """ + + sum_pesq: torch.Tensor + total: torch.Tensor + + def __init__(self, sample_rate: int): + super().__init__() + self.sr = sample_rate + + self.add_state("sum_pesq", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, targets: torch.Tensor): + if self.sr != 16000: + preds = julius.resample_frac(preds, self.sr, 16000) + targets = julius.resample_frac(targets, self.sr, 16000) + for ii in range(preds.size(0)): + try: + self.sum_pesq += pesq.pesq( + 16000, targets[ii, 0].detach().cpu().numpy(), preds[ii, 0].detach().cpu().numpy() + ) + self.total += 1 + except ( + pesq.NoUtterancesError + ): # this error can append when the sample don't contain speech + pass + + def compute(self) -> torch.Tensor: + return ( + self.sum_pesq / self.total + if (self.total != 0).item() + else torch.tensor(0.0) + ) diff --git a/audiocraft/models/__init__.py b/audiocraft/models/__init__.py index a6b49825..3f485fed 100644 --- a/audiocraft/models/__init__.py +++ b/audiocraft/models/__init__.py @@ -18,3 +18,4 @@ from .musicgen import MusicGen from .magnet import MAGNeT from .unet import DiffusionUnet +from .watermark import WMModel diff --git a/audiocraft/models/builders.py b/audiocraft/models/builders.py index 66aa85c6..0bd0db60 100644 --- a/audiocraft/models/builders.py +++ b/audiocraft/models/builders.py @@ -11,52 +11,49 @@ import typing as tp -import audiocraft import omegaconf import torch -from .encodec import CompressionModel, EncodecModel, InterleaveStereoCompressionModel +import audiocraft + +from .. import quantization as qt +from ..modules.codebooks_patterns import (CoarseFirstPattern, + CodebooksPatternProvider, + DelayedPatternProvider, + MusicLMPattern, + ParallelPatternProvider, + UnrolledPatternProvider) +from ..modules.conditioners import (BaseConditioner, ChromaStemConditioner, + CLAPEmbeddingConditioner, ConditionFuser, + ConditioningProvider, LUTConditioner, + T5Conditioner) +from ..modules.diffusion_schedule import MultiBandProcessor, SampleProcessor +from ..utils.utils import dict_from_config +from .encodec import (CompressionModel, EncodecModel, + InterleaveStereoCompressionModel) from .lm import LMModel from .lm_magnet import MagnetLMModel -from ..modules.codebooks_patterns import ( - CodebooksPatternProvider, - DelayedPatternProvider, - MusicLMPattern, - ParallelPatternProvider, - UnrolledPatternProvider, - CoarseFirstPattern, -) -from ..modules.conditioners import ( - BaseConditioner, - ChromaStemConditioner, - CLAPEmbeddingConditioner, - ConditionFuser, - ConditioningProvider, - LUTConditioner, - T5Conditioner, -) from .unet import DiffusionUnet -from .. 
import quantization as qt -from ..utils.utils import dict_from_config -from ..modules.diffusion_schedule import MultiBandProcessor, SampleProcessor +from .watermark import WMModel -def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> qt.BaseQuantizer: - klass = { - 'no_quant': qt.DummyQuantizer, - 'rvq': qt.ResidualVectorQuantizer - }[quantizer] +def get_quantizer( + quantizer: str, cfg: omegaconf.DictConfig, dimension: int +) -> qt.BaseQuantizer: + klass = {"no_quant": qt.DummyQuantizer, "rvq": qt.ResidualVectorQuantizer}[ + quantizer + ] kwargs = dict_from_config(getattr(cfg, quantizer)) - if quantizer != 'no_quant': - kwargs['dimension'] = dimension + if quantizer != "no_quant": + kwargs["dimension"] = dimension return klass(**kwargs) def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig): - if encoder_name == 'seanet': - kwargs = dict_from_config(getattr(cfg, 'seanet')) - encoder_override_kwargs = kwargs.pop('encoder') - decoder_override_kwargs = kwargs.pop('decoder') + if encoder_name == "seanet": + kwargs = dict_from_config(getattr(cfg, "seanet")) + encoder_override_kwargs = kwargs.pop("encoder") + decoder_override_kwargs = kwargs.pop("decoder") encoder_kwargs = {**kwargs, **encoder_override_kwargs} decoder_kwargs = {**kwargs, **decoder_override_kwargs} encoder = audiocraft.modules.SEANetEncoder(**encoder_kwargs) @@ -68,45 +65,55 @@ def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig): def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel: """Instantiate a compression model.""" - if cfg.compression_model == 'encodec': - kwargs = dict_from_config(getattr(cfg, 'encodec')) - encoder_name = kwargs.pop('autoencoder') - quantizer_name = kwargs.pop('quantizer') + if cfg.compression_model == "encodec": + kwargs = dict_from_config(getattr(cfg, "encodec")) + encoder_name = kwargs.pop("autoencoder") + quantizer_name = kwargs.pop("quantizer") encoder, decoder = get_encodec_autoencoder(encoder_name, cfg) quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension) - frame_rate = kwargs['sample_rate'] // encoder.hop_length - renormalize = kwargs.pop('renormalize', False) + frame_rate = kwargs["sample_rate"] // encoder.hop_length + renormalize = kwargs.pop("renormalize", False) # deprecated params - kwargs.pop('renorm', None) - return EncodecModel(encoder, decoder, quantizer, - frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device) + kwargs.pop("renorm", None) + return EncodecModel( + encoder, + decoder, + quantizer, + frame_rate=frame_rate, + renormalize=renormalize, + **kwargs, + ).to(cfg.device) else: raise KeyError(f"Unexpected compression model {cfg.compression_model}") def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel: """Instantiate a transformer LM.""" - if cfg.lm_model in ['transformer_lm', 'transformer_lm_magnet']: - kwargs = dict_from_config(getattr(cfg, 'transformer_lm')) - n_q = kwargs['n_q'] - q_modeling = kwargs.pop('q_modeling', None) - codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern') - attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout')) - cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance')) - cfg_prob, cfg_coef = cls_free_guidance['training_dropout'], cls_free_guidance['inference_coef'] + if cfg.lm_model in ["transformer_lm", "transformer_lm_magnet"]: + kwargs = dict_from_config(getattr(cfg, "transformer_lm")) + n_q = kwargs["n_q"] + q_modeling = kwargs.pop("q_modeling", None) + codebooks_pattern_cfg = getattr(cfg, 
"codebooks_pattern") + attribute_dropout = dict_from_config(getattr(cfg, "attribute_dropout")) + cls_free_guidance = dict_from_config(getattr(cfg, "classifier_free_guidance")) + cfg_prob, cfg_coef = ( + cls_free_guidance["training_dropout"], + cls_free_guidance["inference_coef"], + ) fuser = get_condition_fuser(cfg) condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device) - if len(fuser.fuse2cond['cross']) > 0: # enforce cross-att programmatically - kwargs['cross_attention'] = True + if len(fuser.fuse2cond["cross"]) > 0: # enforce cross-att programmatically + kwargs["cross_attention"] = True if codebooks_pattern_cfg.modeling is None: - assert q_modeling is not None, \ - "LM model should either have a codebook pattern defined or transformer_lm.q_modeling" + assert ( + q_modeling is not None + ), "LM model should either have a codebook pattern defined or transformer_lm.q_modeling" codebooks_pattern_cfg = omegaconf.OmegaConf.create( - {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}} + {"modeling": q_modeling, "delay": {"delays": list(range(n_q))}} ) pattern_provider = get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg) - lm_class = MagnetLMModel if cfg.lm_model == 'transformer_lm_magnet' else LMModel + lm_class = MagnetLMModel if cfg.lm_model == "transformer_lm_magnet" else LMModel return lm_class( pattern_provider=pattern_provider, condition_provider=condition_provider, @@ -116,67 +123,72 @@ def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel: attribute_dropout=attribute_dropout, dtype=getattr(torch, cfg.dtype), device=cfg.device, - **kwargs + **kwargs, ).to(cfg.device) else: raise KeyError(f"Unexpected LM model {cfg.lm_model}") -def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider: +def get_conditioner_provider( + output_dim: int, cfg: omegaconf.DictConfig +) -> ConditioningProvider: """Instantiate a conditioning model.""" device = cfg.device duration = cfg.dataset.segment_duration - cfg = getattr(cfg, 'conditioners') + cfg = getattr(cfg, "conditioners") dict_cfg = {} if cfg is None else dict_from_config(cfg) conditioners: tp.Dict[str, BaseConditioner] = {} - condition_provider_args = dict_cfg.pop('args', {}) - condition_provider_args.pop('merge_text_conditions_p', None) - condition_provider_args.pop('drop_desc_p', None) + condition_provider_args = dict_cfg.pop("args", {}) + condition_provider_args.pop("merge_text_conditions_p", None) + condition_provider_args.pop("drop_desc_p", None) for cond, cond_cfg in dict_cfg.items(): - model_type = cond_cfg['model'] + model_type = cond_cfg["model"] model_args = cond_cfg[model_type] - if model_type == 't5': - conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args) - elif model_type == 'lut': - conditioners[str(cond)] = LUTConditioner(output_dim=output_dim, **model_args) - elif model_type == 'chroma_stem': + if model_type == "t5": + conditioners[str(cond)] = T5Conditioner( + output_dim=output_dim, device=device, **model_args + ) + elif model_type == "lut": + conditioners[str(cond)] = LUTConditioner( + output_dim=output_dim, **model_args + ) + elif model_type == "chroma_stem": conditioners[str(cond)] = ChromaStemConditioner( - output_dim=output_dim, - duration=duration, - device=device, - **model_args + output_dim=output_dim, duration=duration, device=device, **model_args ) - elif model_type == 'clap': + elif model_type == "clap": conditioners[str(cond)] = CLAPEmbeddingConditioner( - output_dim=output_dim, - device=device, - 
**model_args + output_dim=output_dim, device=device, **model_args ) else: raise ValueError(f"Unrecognized conditioning model: {model_type}") - conditioner = ConditioningProvider(conditioners, device=device, **condition_provider_args) + conditioner = ConditioningProvider( + conditioners, device=device, **condition_provider_args + ) return conditioner def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser: """Instantiate a condition fuser object.""" - fuser_cfg = getattr(cfg, 'fuser') - fuser_methods = ['sum', 'cross', 'prepend', 'input_interpolate'] + fuser_cfg = getattr(cfg, "fuser") + fuser_methods = ["sum", "cross", "prepend", "input_interpolate"] fuse2cond = {k: fuser_cfg[k] for k in fuser_methods} kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods} fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs) return fuser -def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider: +def get_codebooks_pattern_provider( + n_q: int, cfg: omegaconf.DictConfig +) -> CodebooksPatternProvider: """Instantiate a codebooks pattern provider object.""" pattern_providers = { - 'parallel': ParallelPatternProvider, - 'delay': DelayedPatternProvider, - 'unroll': UnrolledPatternProvider, - 'coarse_first': CoarseFirstPattern, - 'musiclm': MusicLMPattern, + "parallel": ParallelPatternProvider, + "delay": DelayedPatternProvider, + "unroll": UnrolledPatternProvider, + "coarse_first": CoarseFirstPattern, + "musiclm": MusicLMPattern, } name = cfg.modeling kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {} @@ -184,20 +196,23 @@ def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> Codeb return klass(n_q, **kwargs) -def get_debug_compression_model(device='cpu', sample_rate: int = 32000): +def get_debug_compression_model(device="cpu", sample_rate: int = 32000): """Instantiate a debug compression model to be used for unit tests.""" - assert sample_rate in [16000, 32000], "unsupported sample rate for debug compression model" + assert sample_rate in [ + 16000, + 32000, + ], "unsupported sample rate for debug compression model" model_ratios = { 16000: [10, 8, 8], # 25 Hz at 16kHz - 32000: [10, 8, 16] # 25 Hz at 32kHz + 32000: [10, 8, 16], # 25 Hz at 32kHz } ratios: tp.List[int] = model_ratios[sample_rate] frame_rate = 25 seanet_kwargs: dict = { - 'n_filters': 4, - 'n_residual_layers': 1, - 'dimension': 32, - 'ratios': ratios, + "n_filters": 4, + "n_residual_layers": 1, + "dimension": 32, + "ratios": ratios, } encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs) decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs) @@ -205,8 +220,13 @@ def get_debug_compression_model(device='cpu', sample_rate: int = 32000): init_x = torch.randn(8, 32, 128) quantizer(init_x, 1) # initialize kmeans etc. 
compression_model = EncodecModel( - encoder, decoder, quantizer, - frame_rate=frame_rate, sample_rate=sample_rate, channels=1).to(device) + encoder, + decoder, + quantizer, + frame_rate=frame_rate, + sample_rate=sample_rate, + channels=1, + ).to(device) return compression_model.eval() @@ -214,48 +234,106 @@ def get_diffusion_model(cfg: omegaconf.DictConfig): # TODO Find a way to infer the channels from dset channels = cfg.channels num_steps = cfg.schedule.num_steps - return DiffusionUnet( - chin=channels, num_steps=num_steps, **cfg.diffusion_unet) + return DiffusionUnet(chin=channels, num_steps=num_steps, **cfg.diffusion_unet) def get_processor(cfg, sample_rate: int = 24000): sample_processor = SampleProcessor() if cfg.use: kw = dict(cfg) - kw.pop('use') - kw.pop('name') + kw.pop("use") + kw.pop("name") if cfg.name == "multi_band_processor": sample_processor = MultiBandProcessor(sample_rate=sample_rate, **kw) return sample_processor -def get_debug_lm_model(device='cpu'): +def get_debug_lm_model(device="cpu"): """Instantiate a debug LM to be used for unit tests.""" pattern = DelayedPatternProvider(n_q=4) dim = 16 providers = { - 'description': LUTConditioner(n_bins=128, dim=dim, output_dim=dim, tokenizer="whitespace"), + "description": LUTConditioner( + n_bins=128, dim=dim, output_dim=dim, tokenizer="whitespace" + ), } condition_provider = ConditioningProvider(providers) fuser = ConditionFuser( - {'cross': ['description'], 'prepend': [], - 'sum': [], 'input_interpolate': []}) + {"cross": ["description"], "prepend": [], "sum": [], "input_interpolate": []} + ) lm = LMModel( - pattern, condition_provider, fuser, - n_q=4, card=400, dim=dim, num_heads=4, custom=True, num_layers=2, - cross_attention=True, causal=True) + pattern, + condition_provider, + fuser, + n_q=4, + card=400, + dim=dim, + num_heads=4, + custom=True, + num_layers=2, + cross_attention=True, + causal=True, + ) return lm.to(device).eval() def get_wrapped_compression_model( - compression_model: CompressionModel, - cfg: omegaconf.DictConfig) -> CompressionModel: - if hasattr(cfg, 'interleave_stereo_codebooks'): + compression_model: CompressionModel, cfg: omegaconf.DictConfig +) -> CompressionModel: + if hasattr(cfg, "interleave_stereo_codebooks"): if cfg.interleave_stereo_codebooks.use: kwargs = dict_from_config(cfg.interleave_stereo_codebooks) - kwargs.pop('use') - compression_model = InterleaveStereoCompressionModel(compression_model, **kwargs) - if hasattr(cfg, 'compression_model_n_q'): + kwargs.pop("use") + compression_model = InterleaveStereoCompressionModel( + compression_model, **kwargs + ) + if hasattr(cfg, "compression_model_n_q"): if cfg.compression_model_n_q is not None: compression_model.set_num_codebooks(cfg.compression_model_n_q) return compression_model + + +def get_watermark_model(cfg: omegaconf.DictConfig) -> WMModel: + """Build a WMModel based by audioseal. 
This requires audioseal to be installed""" + import audioseal + + from .watermark import AudioSeal + + # Builder encoder and decoder directly using audiocraft API to avoid cyclic import + assert hasattr( + cfg, "seanet" + ), "Missing required `seanet` parameters in AudioSeal config" + encoder, decoder = get_encodec_autoencoder("seanet", cfg) + + # Build message processor + kwargs = ( + dict_from_config(getattr(cfg, "audioseal")) if hasattr(cfg, "audioseal") else {} + ) + nbits = kwargs.get("nbits", 0) + hidden_size = getattr(cfg.seanet, "dimension", 128) + msg_processor = audioseal.MsgProcessor(nbits, hidden_size=hidden_size) + + # Build detector using audioseal API + def _get_audioseal_detector(): + # We don't need encoder and decoder params from seanet, remove them + seanet_cfg = dict_from_config(cfg.seanet) + seanet_cfg.pop("encoder") + seanet_cfg.pop("decoder") + detector_cfg = dict_from_config(cfg.detector) + + typed_seanet_cfg = audioseal.builder.SEANetConfig(**seanet_cfg) + typed_detector_cfg = audioseal.builder.DetectorConfig(**detector_cfg) + _cfg = audioseal.builder.AudioSealDetectorConfig( + nbits=nbits, seanet=typed_seanet_cfg, detector=typed_detector_cfg + ) + return audioseal.builder.create_detector(_cfg) + + detector = _get_audioseal_detector() + generator = audioseal.AudioSealWM( + encoder=encoder, decoder=decoder, msg_processor=msg_processor + ) + model = AudioSeal(generator=generator, detector=detector, nbits=nbits) + + device = torch.device(getattr(cfg, "device", "cpu")) + dtype = getattr(torch, getattr(cfg, "dtype", "float32")) + return model.to(device=device, dtype=dtype) diff --git a/audiocraft/models/loaders.py b/audiocraft/models/loaders.py index a6ec475e..3c7dd069 100644 --- a/audiocraft/models/loaders.py +++ b/audiocraft/models/loaders.py @@ -28,6 +28,7 @@ import torch import audiocraft + from . import builders from .encodec import CompressionModel @@ -60,10 +61,13 @@ def _get_state_dict( else: assert filename is not None, "filename needs to be defined if using HF checkpoints" - file = hf_hub_download( - repo_id=file_or_url_or_id, filename=filename, cache_dir=cache_dir, - library_name="audiocraft", library_version=audiocraft.__version__) + repo_id=file_or_url_or_id, + filename=filename, + cache_dir=cache_dir, + library_name="audiocraft", + library_version=audiocraft.__version__, + ) return torch.load(file, map_location=device) @@ -71,14 +75,18 @@ def load_compression_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_di return _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir) -def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None): +def load_compression_model( + file_or_url_or_id: tp.Union[Path, str], + device="cpu", + cache_dir: tp.Optional[str] = None, +): pkg = load_compression_model_ckpt(file_or_url_or_id, cache_dir=cache_dir) if 'pretrained' in pkg: return CompressionModel.get_pretrained(pkg['pretrained'], device=device) cfg = OmegaConf.create(pkg['xp.cfg']) cfg.device = str(device) model = builders.get_compression_model(cfg) - model.load_state_dict(pkg['best_state']) + model.load_state_dict(pkg["best_state"]) model.eval() return model @@ -136,6 +144,7 @@ def load_lm_model_magnet(file_or_url_or_id: tp.Union[Path, str], compression_mod # MAGNeT models v1 support only xformers backend. 
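The evaluation helpers added earlier in this diff, `calculate_miou` (watermark localization) and `PesqMetric` (perceptual speech quality), have small APIs; the sketch below, not part of the diff, shows the expected input shapes. `PesqMetric` needs the optional `pesq` dependency pulled in by the `[wm]` extra, and the random inputs here are placeholders for real speech.

```python
# Minimal sketch (not part of the diff) of the new evaluation helpers.
import torch
from audiocraft.metrics.miou import calculate_miou
from audiocraft.metrics.pesq import PesqMetric

# Localization: per-frame detector scores vs. binary ground-truth mask, [bsz, frames].
y_pred = torch.rand(4, 16000)
y_true = (torch.rand(4, 16000) > 0.5).float()
miou = calculate_miou(y_pred, y_true)          # float in [0, 1]

# Perceptual quality of watermarked speech vs. the original, [bsz, channels, time].
pesq_metric = PesqMetric(sample_rate=16000)
ref = torch.randn(4, 1, 16000)
wm = ref + 0.01 * torch.randn_like(ref)
pesq_metric.update(wm, ref)
score = pesq_metric.compute()                  # mean PESQ over scorable samples
```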
from audiocraft.modules.transformer import set_efficient_attention_backend + if cfg.transformer_lm.memory_efficient: set_efficient_attention_backend("xformers") @@ -175,3 +184,68 @@ def load_diffusion_models(file_or_url_or_id: tp.Union[Path, str], processors.append(processor) cfgs.append(cfg) return models, processors, cfgs + + +def load_audioseal_models( + file_or_url_or_id: tp.Union[Path, str], + device="cpu", + filename: tp.Optional[str] = None, + cache_dir: tp.Optional[str] = None, +): + + detector_ckpt = _get_state_dict( + file_or_url_or_id, + filename=f"detector_{filename}.pth", + device=device, + cache_dir=cache_dir, + ) + assert ( + "model" in detector_ckpt + ), f"No model state dict found in {file_or_url_or_id}/detector_{filename}.pth" + detector_state = detector_ckpt["model"] + + generator_ckpt = _get_state_dict( + file_or_url_or_id, + filename=f"generator_{filename}.pth", + device=device, + cache_dir=cache_dir, + ) + assert ( + "model" in generator_ckpt + ), f"No model state dict found in {file_or_url_or_id}/generator_{filename}.pth" + generator_state = generator_ckpt["model"] + + def load_model_config(): + if Path(file_or_url_or_id).joinpath(f"{filename}.yaml").is_file(): + return OmegaConf.load(Path(file_or_url_or_id).joinpath(f"{filename}.yaml")) + elif file_or_url_or_id.startswith("https://"): + import requests # type: ignore + + resp = requests.get(f"{file_or_url_or_id}/{filename}.yaml") + return OmegaConf.create(resp.text) + else: + file = hf_hub_download( + repo_id=file_or_url_or_id, + filename=f"{filename}.yaml", + cache_dir=cache_dir, + library_name="audiocraft", + library_version=audiocraft.__version__, + ) + return OmegaConf.load(file) + + try: + cfg = load_model_config() + except Exception as exc: # noqa + cfg_fp = ( + Path(__file__) + .parents[2] + .joinpath("config", "model", "watermark", "default.yaml") + ) + cfg = OmegaConf.load(cfg_fp) + + OmegaConf.resolve(cfg) + model = builders.get_watermark_model(cfg) + + model.generator.load_state_dict(generator_state) + model.detector.load_state_dict(detector_state) + return model.to(device) diff --git a/audiocraft/models/watermark.py b/audiocraft/models/watermark.py new file mode 100644 index 00000000..7a762eec --- /dev/null +++ b/audiocraft/models/watermark.py @@ -0,0 +1,111 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import typing as tp +from abc import ABC, abstractmethod + +import torch +import torch.nn as nn + +from audiocraft.models.loaders import load_audioseal_models + + +class WMModel(ABC, nn.Module): + """ + A wrapper interface to different watermarking models for + training or evaluation purporses + """ + + @abstractmethod + def get_watermark( + self, + x: torch.Tensor, + message: tp.Optional[torch.Tensor] = None, + sample_rate: int = 16_000, + ) -> torch.Tensor: + """Get the watermark from an audio tensor and a message. 
+ If the input message is None, a random message of + n bits {0,1} will be generated + """ + + @abstractmethod + def detect_watermark(self, x: torch.Tensor) -> torch.Tensor: + """Detect the watermarks from the audio signal + + Args: + x: Audio signal, size batch x frames + + Returns: + tensor of size (B, 2+n, frames) where: + Detection results of shape (B, 2, frames) + Message decoding results of shape (B, n, frames) + """ + + +class AudioSeal(WMModel): + """Wrap Audioseal (https://github.com/facebookresearch/audioseal) for the + training and evaluation. The generator and detector are jointly trained + """ + + def __init__( + self, + generator: nn.Module, + detector: nn.Module, + nbits: int = 0, + ): + super().__init__() + self.generator = generator # type: ignore + self.detector = detector # type: ignore + + # Allow to re-train an n-bit model with new 0-bit message + self.nbits = nbits if nbits else self.generator.msg_processor.nbits + + def get_watermark( + self, + x: torch.Tensor, + message: tp.Optional[torch.Tensor] = None, + sample_rate: int = 16_000, + ) -> torch.Tensor: + return self.generator.get_watermark(x, message=message, sample_rate=sample_rate) + + def detect_watermark(self, x: torch.Tensor) -> torch.Tensor: + """ + Detect the watermarks from the audio signal. The first two units of the output + are used for detection, the rest is used to decode the message. If the audio is + not watermarked, the message will be random. + + Args: + x: Audio signal, size batch x frames + Returns + torch.Tensor: Detection + decoding results of shape (B, 2+nbits, T). + """ + + # Getting the direct decoded message from the detector + result = self.detector.detector(x) # b x 2+nbits + # hardcode softmax on 2 first units used for detection + result[:, :2, :] = torch.softmax(result[:, :2, :], dim=1) + return result + + def forward( # generator + self, + x: torch.Tensor, + message: tp.Optional[torch.Tensor] = None, + sample_rate: int = 16_000, + alpha: float = 1.0, + ) -> torch.Tensor: + """Apply the watermarking to the audio signal x with a tune-down ratio (default 1.0)""" + wm = self.get_watermark(x, message) + return x + alpha * wm + + @staticmethod + def get_pretrained(name="base", device=None) -> WMModel: + if device is None: + if torch.cuda.device_count(): + device = "cuda" + else: + device = "cpu" + return load_audioseal_models("facebook/audioseal", filename=name, device=device) diff --git a/audiocraft/modules/watermark.py b/audiocraft/modules/watermark.py new file mode 100644 index 00000000..f3a2e7e6 --- /dev/null +++ b/audiocraft/modules/watermark.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
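+
+"""Signal-level helpers used by the watermarking solver to build localization
+targets: `pad` zeroes out parts of a watermarked batch, while `mix` splices a
+window of clean audio into the watermarked audio (see `WatermarkSolver.evaluate`)."""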
+ +import typing as tp +import random + +import torch + + +def pad( + x_wm: torch.Tensor, central: bool = False +) -> tp.Tuple[torch.Tensor, torch.Tensor]: + """Pad a watermarked signal at the begining and the end + + Args: + x_wm (torch.Tensor) : watermarked audio + central (bool): Whether to mask the middle of the wave (around 34%) or the two tails + (beginning and ending frames) + + Returns: + padded (torch.Tensor): padded signal + true_predictions(torch.Tensor): A binary mask where 1 represents + watermarked and 0 represents non-watermarked.""" + # keep at leat 34% of watermarked signal + max_start = int(0.33 * x_wm.size(-1)) + min_end = int(0.66 * x_wm.size(-1)) + starts = torch.randint(0, max_start, size=(x_wm.size(0),)) + ends = torch.randint(min_end, x_wm.size(-1), size=(x_wm.size(0),)) + mask = torch.zeros_like(x_wm) + for i in range(x_wm.size(0)): + mask[i, :, starts[i]: ends[i]] = 1 + if central: + mask = 1 - mask + padded = x_wm * mask + true_predictions = torch.cat([1 - mask, mask], dim=1) + return padded, true_predictions + + +def mix( + x: torch.Tensor, x_wm: torch.Tensor, window_size: float = 0.5, shuffle: bool = False +) -> tp.Tuple[torch.Tensor, torch.Tensor]: + """ + Mixes a window of the non-watermarked audio signal 'x' into the watermarked audio signal 'x_wm'. + + This function takes two tensors of shape [batch, channels, frames], copies a window of 'x' with the specified + 'window_size' into 'x_wm', and returns a new tensor that is a mix between the watermarked (1 - mix_percent %) + and non-watermarked audio (mix_percent %). + + Args: + x (torch.Tensor): The non-watermarked audio signal tensor. + x_wm (torch.Tensor): The watermarked audio signal tensor. + window_size (float, optional): The percentage of 'x' to copy into 'x_wm' (between 0 and 1). + shuffle (bool): whether or no keep the mix from the same batch element + + Returns: + tuple: A tuple containing two tensors: + - mixed_tensor (torch.Tensor): The resulting mixed audio signal tensor. + - mask (torch.Tensor): A binary mask where 1 represents watermarked and 0 represents non-watermarked. + + Raises: + AssertionError: If 'window_size' is not between 0 and 1. + """ + assert 0 < window_size <= 1, "window_size should be between 0 and 1" + + # Calculate the maximum starting point for the window + max_start_point = x.shape[-1] - int(window_size * x.shape[-1]) + + # Generate a random starting point within the adjusted valid range + start_point = random.randint(0, max_start_point) + + # Calculate the window size in frames + total_frames = x.shape[-1] + window_frames = int(window_size * total_frames) + + # Create a mask tensor to identify watermarked and non-watermarked portions + # it outputs two classes to match the detector output shape of [bsz, 2, frames] + # Copy the random window from 'x' to 'x_wm' + mixed = x_wm.detach().clone() + + true_predictions = torch.cat( + [torch.zeros_like(mixed), torch.ones_like(mixed)], dim=1 + ) + # non-watermark class correct labels. 
+ true_predictions[:, 0, start_point: start_point + window_frames] = 1.0 + # watermarked class correct labels + true_predictions[:, 1, start_point: start_point + window_frames] = 0.0 + + if shuffle: + # Take the middle part from a random element of the batch + shuffle_idx = torch.randint(0, x.size(0), (x.size(0),)) + mixed[:, :, start_point: start_point + window_frames] = x[shuffle_idx][ + :, :, start_point: start_point + window_frames + ] + else: + mixed[:, :, start_point: start_point + window_frames] = x[ + :, :, start_point: start_point + window_frames + ] + + return mixed, true_predictions diff --git a/audiocraft/solvers/builders.py b/audiocraft/solvers/builders.py index 7c53b3ac..bf18f2d6 100644 --- a/audiocraft/solvers/builders.py +++ b/audiocraft/solvers/builders.py @@ -19,6 +19,7 @@ import torch from torch import nn from torch.optim import Optimizer + # LRScheduler was renamed in some torch versions try: from torch.optim.lr_scheduler import LRScheduler # type: ignore @@ -46,6 +47,7 @@ def get_solver(cfg: omegaconf.DictConfig) -> StandardSolver: from .musicgen import MusicGenSolver from .diffusion import DiffusionSolver from .magnet import MagnetSolver, AudioMagnetSolver + from .watermark import WatermarkSolver klass = { 'compression': CompressionSolver, 'musicgen': MusicGenSolver, @@ -55,6 +57,7 @@ def get_solver(cfg: omegaconf.DictConfig) -> StandardSolver: 'lm': MusicGenSolver, # backward compatibility 'diffusion': DiffusionSolver, 'sound_lm': AudioGenSolver, # backward compatibility + 'watermarking': WatermarkSolver, }[cfg.solver] return klass(cfg) # type: ignore @@ -189,6 +192,9 @@ def get_loss(loss_name: str, cfg: omegaconf.DictConfig): 'mrstft': losses.MRSTFTLoss, 'msspec': losses.MultiScaleMelSpectrogramLoss, 'sisnr': losses.SISNR, + 'wm_detection': losses.WMDetectionLoss, + 'wm_mb': losses.WMMbLoss, + 'tf_loudnessratio': losses.TFLoudnessRatio }[loss_name] kwargs = dict(getattr(cfg, loss_name)) return klass(**kwargs) diff --git a/audiocraft/solvers/watermark.py b/audiocraft/solvers/watermark.py new file mode 100644 index 00000000..0ae90c7f --- /dev/null +++ b/audiocraft/solvers/watermark.py @@ -0,0 +1,716 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import typing as tp +from functools import partial +import os +from pathlib import Path + +import flashy +from omegaconf import DictConfig +import multiprocessing +import numpy as np +import torch +import torch.nn as nn + +from . import base, builders +from ..models.builders import get_watermark_model +from ..modules.watermark import pad, mix + +from ..metrics.miou import calculate_miou +from ..metrics.pesq import PesqMetric + +from ..utils import checkpoint +from ..utils.audio_effects import ( + compress_with_encodec, + get_audio_effects, + select_audio_effects, +) +from ..utils.samples.manager import SampleManager +from ..data.audio import save_spectrograms +from ..utils.utils import get_pool_executor + +from torchmetrics.audio.snr import ScaleInvariantSignalNoiseRatio +from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility + + +if tp.TYPE_CHECKING: + from ..models.watermark import WMModel + + +def get_encodec_audio_effect(encodec_cfg: DictConfig, sr: int) -> tp.Dict: + """ + Construct encodec-based compression data agumentation. 
This method is + is put here instead of in `audiocraft.utils.audio_effects` because + it depends on the package `audiocraft.solvers`, which is one layer + higher than `audiocraft.utils`, so we avoid the circle dependency + from any solvers using `audiocraft.utils.audio_effects` to do the + augmentation + """ + from ..solvers.compression import CompressionSolver + + codec_model = CompressionSolver.model_from_checkpoint(encodec_cfg.ckpt) + codec_model.train() + return { + f"encodec_nq={n_q}": partial( + compress_with_encodec, + model=codec_model, + n_q=n_q, + sample_rate=sr, + ) + for n_q in encodec_cfg.n_qs + } + + +def random_message(nbits: int, batch_size: int) -> torch.Tensor: + """Return random message as 0/1 tensor.""" + if nbits == 0: + return torch.tensor([]) + return torch.randint(0, 2, (batch_size, nbits)) + + +class WatermarkSolver(base.StandardSolver): + """Solver for different watermarking models""" + + def __init__(self, cfg: DictConfig): + super().__init__(cfg) + self.rng: torch.Generator # set at each epoch + self.model: WMModel + if hasattr(cfg, "fsdp"): + assert not getattr( + cfg.fsdp, "use", False + ), "FSDP not supported by WatermarkSolver." + self._init_losses() + self._init_augmentations() + self.balancer = builders.get_balancer(self.loss_weights, self.cfg.balancer) + self.path_specs = os.path.join(self.folder, "spectrograms") + os.makedirs(self.path_specs, exist_ok=True) + + def _init_losses(self): + assert hasattr(self.cfg, "losses") and isinstance( + self.cfg.losses, (DictConfig, tp.Mapping) + ), "WatermarkSolver must declare training losses in the config" + + self.adv_losses = builders.get_adversarial_losses(self.cfg) # noqa + self.register_stateful("adv_losses") + + self.aux_losses = nn.ModuleDict() # noqa + self.info_losses = nn.ModuleDict() # noqa + self.wm_losses = nn.ModuleDict() # noqa + loss_weights = {} + for loss_name, weight in self.cfg.losses.items(): + + # explicitly skip this loss calculation by setting a -1 as weight + # if weight == 0 it will be calculated but kept as info + if weight == -1: + continue + + if loss_name in ["adv", "feat"]: + for adv_name, _ in self.adv_losses.items(): + loss_weights[f"{loss_name}_{adv_name}"] = weight + elif weight > 0: + if loss_name[:3] == "wm_": + self.wm_losses[loss_name] = builders.get_loss( + loss_name, self.cfg + ).to(self.device) + loss_weights[loss_name] = weight + else: + self.aux_losses[loss_name] = builders.get_loss( + loss_name, self.cfg + ).to(self.device) + loss_weights[loss_name] = weight + else: + self.info_losses[loss_name] = builders.get_loss(loss_name, self.cfg).to( + self.device + ) + + self.loss_weights = loss_weights # noqa + + def _init_augmentations(self): + if not hasattr(self.cfg, "aug_weights") or not hasattr( + self.cfg, "audio_effects" + ): + return + + aug_weights = {} + cfg_audio_effects = dict(self.cfg.audio_effects) + + # Handle `encodec` augmentation separately as this requires loading a + # CompressionSolver checkpoint + encodec_cfg = cfg_audio_effects.pop("encodec", None) + if encodec_cfg: + encodec_effects = get_encodec_audio_effect( + encodec_cfg, self.cfg.sample_rate + ) + for aug_name in encodec_effects.keys(): + aug_weights[aug_name] = getattr(self.cfg.aug_weights, "encodec", -1) + else: + encodec_effects = {} + + other_effects = get_audio_effects(self.cfg) # noqa + for name in other_effects.keys(): + aug_weights[name] = self.cfg.aug_weights.get(name, -1) + + self.aug_weights = aug_weights # noqa + self.augmentations = {**encodec_effects, **other_effects} # noqa + + @property 
+    def best_metric_name(self) -> tp.Optional[str]:
+        # best model is the last for the watermark model for now
+        return None
+
+    def build_model(self):
+        """Instantiate model and optimizer."""
+        # Model and optimizer
+        self.model = get_watermark_model(self.cfg)
+        # Need two optimizers ?
+        self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
+        self.register_stateful("model", "optimizer")
+        self.register_best_state("model")
+        self.register_ema("model")
+
+    def build_dataloaders(self):
+        """Instantiate audio dataloaders for each stage."""
+        self.dataloaders = builders.get_audio_datasets(self.cfg)
+
+    def show(self):
+        """Show the Watermark model and employed adversarial loss."""
+        self.log_model_summary(self.model)
+        self.logger.info("Should print losses here:")
+
+    def crop(
+        self, signal: torch.Tensor, watermark: torch.Tensor
+    ) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Applies a transformation to modify the watermarked signal to train localization.
+        It can be one of the following:
+            - zero padding: add zeros at the beginning and the end of the signal
+            - crop: apply the watermark only on some parts of the signal
+            - shuffle: replace some parts of the audio with other non-watermarked parts
+                from the batch
+        In every case the function returns a mask indicating which parts of the signal
+        are watermarked and which are not.
+
+        Args:
+            watermark (torch.Tensor): The watermark to apply on the signal.
+            signal (torch.Tensor): clean signal
+        Returns:
+            watermark (torch.Tensor): modified watermark
+            signal (torch.Tensor): modified signal
+            mask (torch.Tensor): mask indicating which portion is still watermarked
+        """
+        assert (
+            self.cfg.crop.prob + self.cfg.crop.shuffle_prob + self.cfg.crop.pad_prob
+            <= 1
+        ), f"The sum of the probabilities {self.cfg.crop.prob=} {self.cfg.crop.shuffle_prob=} \
+            {self.cfg.crop.pad_prob=} should not exceed 1"
+        mask = torch.ones_like(watermark)
+        p = torch.rand(1)
+        if p < self.cfg.crop.pad_prob:  # Pad with some probability
+            start = int(torch.rand(1) * 0.33 * watermark.size(-1))
+            finish = int((0.66 + torch.rand(1) * 0.33) * watermark.size(-1))
+            mask[:, :, :start] = 0
+            mask[:, :, finish:] = 0
+            if torch.rand(1) > 0.5:
+                mask = 1 - mask
+            signal *= mask  # pad signal
+
+        elif (
+            p < self.cfg.crop.prob + self.cfg.crop.pad_prob + self.cfg.crop.shuffle_prob
+        ):
+            # Define a mask, then crop or shuffle
+            mask_size = round(watermark.shape[-1] * self.cfg.crop.size)
+            n_windows = int(
+                torch.randint(1, self.cfg.crop.max_n_windows + 1, (1,)).item()
+            )
+            window_size = int(mask_size / n_windows)
+            for _ in range(n_windows):  # Create multiple windows in the mask
+                mask_start = torch.randint(0, watermark.shape[-1] - window_size, (1,))
+                mask[:, :, mask_start: mask_start + window_size] = (
+                    0  # Apply window to mask
+                )
+            # inverse the mask half the time
+            if torch.rand(1) > 0.5:
+                mask = 1 - mask
+
+            if p < self.cfg.crop.pad_prob + self.cfg.crop.shuffle_prob:  # shuffle
+                # shuffle
+                signal_cloned = signal.clone().detach()  # detach to be sure
+                shuffle_idx = torch.randint(0, signal.size(0), (signal.size(0),))
+                signal = signal * mask + signal_cloned[shuffle_idx] * (
+                    1 - mask
+                )  # shuffle signal where not wm
+
+        watermark *= mask  # Apply mask to the watermark
+        return signal, watermark, mask
+
+    def run_step(self, idx: int, batch: torch.Tensor, metrics: dict):
+        """Perform one training or valid step on a given batch."""
+        x = batch.to(self.device)
+        y = x.clone()
+        nbits = getattr(self.model, "nbits")
+        message =
random_message(nbits, y.shape[0]).to(self.device) + watermark = self.model.get_watermark(x, message=message) + y, watermark, mask = self.crop(y, watermark) + + y_wm = y + watermark + + if ( + self.cfg.losses.adv != 0 or self.cfg.losses.feat != 0 + ) and self.is_training: # train quality adv + d_losses: dict = {} + if ( + len(self.adv_losses) > 0 + and torch.rand(1, generator=self.rng).item() + <= 1 / self.cfg.adversarial.every + ): + for adv_name, adversary in self.adv_losses.items(): + disc_loss = adversary.train_adv(y_wm, y) + d_losses[f"d_{adv_name}"] = disc_loss + metrics["d_loss"] = torch.sum(torch.stack(list(d_losses.values()))) + metrics.update(d_losses) + + balanced_losses: dict = {} + other_losses: dict = {} + + # adversarial losses + if self.cfg.losses.adv != 0 or self.cfg.losses.feat != 0: + for adv_name, adversary in self.adv_losses.items(): + adv_loss, feat_loss = adversary(y_wm, y) + balanced_losses[f"adv_{adv_name}"] = adv_loss + balanced_losses[f"feat_{adv_name}"] = feat_loss + + # auxiliary losses on quality/similarity + for loss_name, criterion in self.aux_losses.items(): + loss = criterion(y_wm, y) + balanced_losses[loss_name] = loss + + # apply augmentations + mode = "all" if self.cfg.select_aug_mode == "all" else "weighted" + selected_augs = select_audio_effects( + self.augmentations, + self.aug_weights, + mode=mode, + max_length=self.cfg.n_max_aug, + ) + N_augs = len(selected_augs) + for ( + augmentation_name, + augmentation_method, + ) in selected_augs.items(): + # concatenate to use the augmentation function only once + y_y_wm = torch.cat([y, y_wm], dim=0) + aug_cat, mask_aug = augmentation_method(y_y_wm, mask=mask) + aug_y = aug_cat[: y.size(0)] + aug_y_wm = aug_cat[y.size(0):] + positive = self.model.detect_watermark(aug_y_wm) + negative = self.model.detect_watermark(aug_y) + for loss_name, criterion in self.wm_losses.items(): + loss = criterion(positive, negative, mask_aug, message) + other_losses[f"{loss_name}_{augmentation_name}"] = loss + + # weighted losses + metrics.update(balanced_losses) + metrics.update(other_losses) + if self.is_training: # something is weird about the loss balancer not + other_loss = torch.tensor(0.0, device=self.device) + for name, o_loss in other_losses.items(): + if "wm_detection" in name: + # here we include the detection losses for augmentation + other_loss += (self.loss_weights["wm_detection"] / N_augs) * o_loss + elif "wm_mb" in name: + other_loss += (self.loss_weights["wm_mb"] / N_augs) * o_loss + else: + other_loss += self.loss_weights[name] * o_loss + if other_loss.requires_grad: + other_loss.backward(retain_graph=True) + ratio1 = sum( + p.grad.data.norm(p=2).pow(2) + for p in self.model.parameters() + if p.grad is not None + ) + assert isinstance(ratio1, torch.Tensor) + metrics["ratio1"] = ratio1.sqrt() + + # balancer losses backward, returns effective training loss + # with effective weights at the current batch. 
+ metrics["g_loss"] = self.balancer.backward(balanced_losses, y_wm) + # add metrics corresponding to weight ratios + metrics.update(self.balancer.metrics) + ratio2 = sum( + p.grad.data.norm(p=2).pow(2) + for p in self.model.parameters() + if p.grad is not None + ) + assert isinstance(ratio2, torch.Tensor) + metrics["ratio2"] = ratio2.sqrt() + + # optim + flashy.distrib.sync_model(self.model) + if self.cfg.optim.max_norm: + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.cfg.optim.max_norm + ) + + self.optimizer.step() + self.optimizer.zero_grad() + + # informative losses only + info_losses: dict = {} + with torch.no_grad(): + for loss_name, criterion in self.info_losses.items(): + loss = criterion(y_wm, y) + info_losses[loss_name] = loss + # pesq + metrics["pesq"] = tensor_pesq(y_wm, y, sr=self.cfg.sample_rate) + # max allocated memory + metrics["max_mem"] = torch.cuda.max_memory_allocated() / 1e9 + + metrics.update(info_losses) + if self.cfg.losses.adv != 0 or self.cfg.losses.feat != 0: + # aggregated GAN losses: this is useful to report adv and feat across different adversarial loss setups + adv_losses = [ + loss + for loss_name, loss in metrics.items() + if loss_name.startswith("adv") + ] + if len(adv_losses) > 0: + metrics["adv"] = torch.sum(torch.stack(adv_losses)) + feat_losses = [ + loss + for loss_name, loss in metrics.items() + if loss_name.startswith("feat") + ] + if len(feat_losses) > 0: + metrics["feat"] = torch.sum(torch.stack(feat_losses)) + + return metrics + + def run_epoch(self): + # reset random seed at the beginning of the epoch + self.rng = torch.Generator() + self.rng.manual_seed(1234 + self.epoch) + # run epoch + super().run_epoch() + + def evaluate(self) -> dict: + """Evaluate stage. Runs audio reconstruction evaluation.""" + self.model.eval() + evaluate_stage_name = str(self.current_stage) + + loader = self.dataloaders["evaluate"] + updates = len(loader) + lp = self.log_progress( + f"{evaluate_stage_name} inference", + loader, + total=updates, + updates=self.log_updates, + ) + average = flashy.averager() + + pendings = [] + ctx = multiprocessing.get_context("spawn") + with get_pool_executor(self.cfg.evaluate.num_workers, mp_context=ctx) as pool: + for batch in lp: + x = batch.to(self.device) + with torch.no_grad(): + message = random_message(self.model.nbits, x.shape[0]) + watermark = self.model.get_watermark(x, message) + x_wm = x + watermark + y_pred = x_wm.cpu() + y = batch.cpu() # should already be on CPU but just in case + pendings.append( + pool.submit( + evaluate_audio_watermark, + y_pred, + y, + self.cfg, + ) + ) + # evaluate augmentations + # evaluation is run on all the augmentations + for ( + augmentation_name, + augmentation_method, + ) in self.augmentations.items(): + # if ( + # "mp3" in augmentation_name + # and idx >= 8 + # and self.cfg.evaluate.every <= 2 + # ): + # # When evaluating often do not compute mp3 on the full eval dset to make things faster + # continue + with torch.no_grad(): + aug_positive = self.model.detect_watermark( + augmentation_method(x_wm) + ) + aug_negative = self.model.detect_watermark( + augmentation_method(x) + ) + + pendings.append( + pool.submit( + evaluate_augmentations, + aug_positive.cpu(), + aug_negative.cpu(), + augmentation_name, + message.cpu(), + ) + ) + # end eval of augmentations + + # evaluate localization cropping + for window_size in np.linspace(0.1, 0.9, 9): + + mixed, true_predictions = mix(x, x_wm, window_size=window_size) + model_predictions = self.model.detect_watermark(mixed) + 
pendings.append( + pool.submit( + evaluate_localizations, + model_predictions.cpu(), + true_predictions.cpu(), + f"crop_{window_size:0.1f}", + ) + ) + mixed, true_predictions = mix( + x, x_wm, window_size=window_size, shuffle=True + ) + model_predictions = self.model.detect_watermark(mixed) + pendings.append( + pool.submit( + evaluate_localizations, + model_predictions.cpu(), + true_predictions.cpu(), + f"shuffle_{window_size:0.1f}", + ) + ) + # evaluate localization padding + mixed, true_predictions = pad(x_wm) + model_predictions = self.model.detect_watermark(mixed) + pendings.append( + pool.submit( + evaluate_localizations, + model_predictions.cpu(), + true_predictions.cpu(), + "padding", + ) + ) + mixed, true_predictions = pad(x_wm, central=True) + model_predictions = self.model.detect_watermark(mixed) + pendings.append( + pool.submit( + evaluate_localizations, + model_predictions.cpu(), + true_predictions.cpu(), + "central_padding", + ) + ) + # end of evaluate localization + + metrics_lp = self.log_progress( + f"{evaluate_stage_name} metrics", pendings, updates=self.log_updates + ) + for pending in metrics_lp: + metrics = pending.result() + metrics = average(metrics) + + metrics = flashy.distrib.average_metrics(metrics, len(loader)) + if self.cfg.select_aug_mode == "use_eval_acc": + # Adjust augmentation weights based on evaluation loss. + # Higher accuracy results in lower probability of selecting this augmentation. + for name in self.augmentations.keys(): + if ( + self.aug_weights[name] != -1 + ): # keep weight to -1 for unwanted augmentations + # set to 0.05 to ensure that an augmentation is never completely removed during a full epoch. + self.aug_weights[name] = max(1 - metrics[f"aug_{name}_acc"], 0.05) + return metrics + + def generate(self): + """Generate stage.""" + self.model.eval() + sample_manager = SampleManager(self.xp, map_reference_to_sample_id=True) + generate_stage_name = str(self.current_stage) + + loader = self.dataloaders["generate"] + updates = len(loader) + lp = self.log_progress( + generate_stage_name, loader, total=updates, updates=self.log_updates + ) + path_dir = os.path.join(self.path_specs, f"epoch={self.epoch}") + os.makedirs(path_dir, exist_ok=True) + first_batch = True + for batch in lp: + reference, _ = batch + reference = reference.to(self.device) + with torch.no_grad(): + message = random_message(self.model.nbits, reference.shape[0]) + watermark = self.model.get_watermark(reference, message) + x_wm = reference + watermark + + reference = reference.cpu() + sample_manager.add_samples( + x_wm.cpu(), self.epoch, ground_truth_wavs=reference + ) + if first_batch and flashy.distrib.is_rank_zero(): + for i in range(reference.size(0)): + ys = [ + reference.cpu()[i].squeeze(0).numpy(), + x_wm.cpu()[i].squeeze(0).numpy(), + watermark.cpu()[i].squeeze(0).numpy(), + ] + path = os.path.join(path_dir, f"spec_{i}.pdf") + save_spectrograms( + ys, + names=["Ground Truth", "Audio Watermarked", "Watermark"], + sr=self.cfg.sample_rate, + path=path, + ) + first_batch = False + flashy.distrib.barrier() + + def load_from_pretrained(self, name: str) -> dict: + raise ValueError("No pretrained model") + + @staticmethod + def model_from_checkpoint( + checkpoint_path: tp.Union[Path, str], + device: tp.Union[torch.device, str] = "cpu", + ) -> "WMModel": + """Instantiate a WatermarkModel from a given checkpoint path or dora sig. + + Args: + checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved. 
+ device (torch.device or str): Device on which the model is loaded. + """ + checkpoint_path = str(checkpoint_path) + logger = logging.getLogger(__name__) + logger.info(f"Loading WatermarkModel from checkpoint: {checkpoint_path}") + _checkpoint_path = checkpoint.resolve_checkpoint_path( + checkpoint_path, use_fsdp=False + ) + assert ( + _checkpoint_path is not None + ), f"Could not resolve WatermarkModel checkpoint path: {checkpoint_path}" + state = checkpoint.load_checkpoint(_checkpoint_path) + assert ( + state is not None and "xp.cfg" in state + ), f"Could not load WatermarkModel from ckpt: {checkpoint_path}" + cfg = state["xp.cfg"] + cfg.device = device + watermarking_model = get_watermark_model(cfg).to(device) + + assert "best_state" in state and state["best_state"] != {} + assert ( + "exported" not in state + ), "When loading an exported checkpoint, use the //pretrained/ prefix." + watermarking_model.load_state_dict(state["best_state"]["model"]) + watermarking_model.eval() + logger.info("Watermarking model loaded!") + return watermarking_model + + +def evaluate_localizations(predictions, true_predictions, name): + metrics = {} + # predictions are output of the detector shape [bsz, 2, frames] + # true_predictions is output of the mix method shape [bsz, 2, frames] + metrics[f"localization_acc_{name}"] = ( + ((predictions[:, 1, :] > 0.5) == true_predictions[:, 1, :]) + .float() + .mean() + .item() + ) + metrics[f"localization_miou_{name}"] = calculate_miou( + predictions[:, 1, :], true_predictions[:, 1, :] + ) + return metrics + + +def evaluate_augmentations( + positive: torch.Tensor, + negative: torch.Tensor, + augmentation_name: str, + message: torch.Tensor, +) -> dict: + """calculating evaluation metrics but take name of the augmentation + method that has been done before getting positive and negative results""" + metrics = {} + metrics[f"aug_{augmentation_name}_acc"] = compute_accuracy(positive, negative) + metrics[f"aug_{augmentation_name}_fpr"] = compute_FPR(negative) + metrics[f"aug_{augmentation_name}_fnr"] = compute_FNR(positive) + if message.shape[0] != 0: + metrics[f"aug_{augmentation_name}_bit_acc"] = compute_bit_acc(positive, message) + + # add one metric which is average overall score of all augmentations + metrics["all_aug_acc"] = compute_accuracy(positive, negative) + + return metrics + + +def evaluate_audio_watermark( + y_pred: torch.Tensor, + y: torch.Tensor, + cfg: DictConfig, +) -> dict: + """Audio reconstruction evaluation method that can be conveniently pickled.""" + metrics = {} + if cfg.evaluate.metrics.visqol: + visqol = builders.get_visqol(cfg.metrics.visqol) + metrics["visqol"] = visqol(y_pred, y, cfg.sample_rate) + sisnr = ScaleInvariantSignalNoiseRatio().to(y.device) + stoi = ShortTimeObjectiveIntelligibility(fs=cfg.sample_rate) + metrics["sisnr"] = sisnr(y_pred, y) + metrics["stoi"] = stoi(y_pred, y) + metrics["pesq"] = tensor_pesq(y_pred, y, sr=cfg.sample_rate) + return metrics + + +def tensor_pesq(y_pred: torch.Tensor, y: torch.Tensor, sr: int): + # pesq returns error if no speech is detected, so we catch it + return PesqMetric(sr)(y_pred, y).item() + + +def compute_accuracy(positive, negative): + N = (positive[:, 1, :].mean(dim=1) > 0.5).sum() + ( + negative[:, 0, :].mean(dim=1) > 0.5 + ).sum() + acc = N / (2 * positive.size(0)) + return acc + + +def compute_FPR(negative): + N = (negative[:, 1, :].mean(dim=1) > 0.5).sum() + fpr = N / (negative.size(0)) + return fpr + + +def compute_FNR(positive): + N = (positive[:, 0, :].mean(dim=1) > 0.5).sum() + fpr = N 
/ (positive.size(0)) + return fpr + + +def _bit_acc(decoded, original): + bit_acc = (decoded == original).float().mean() + return bit_acc + + +def compute_bit_acc(positive, original, mask=None): + """Compute bit accuracy. + Args: + positive: detector outputs [bsz, 2+nbits, time_steps] + original: original message (0 or 1) [bsz, nbits] + mask: mask of the watermark [bsz, 1, time_steps] + """ + decoded = positive[:, 2:, :] # b 2+nbits t -> b nbits t + if mask is not None: + # cut last dim of positive to keep only where mask is 1 + new_shape = [*decoded.shape[:-1], -1] # b nbits t -> b nbits -1 + decoded = torch.masked_select(decoded, mask == 1).reshape(new_shape) + # average decision over time, then threshold + decoded = decoded.mean(dim=-1) > 0 # b nbits + return _bit_acc(decoded, original) diff --git a/audiocraft/utils/audio_effects.py b/audiocraft/utils/audio_effects.py new file mode 100644 index 00000000..70fe4dbe --- /dev/null +++ b/audiocraft/utils/audio_effects.py @@ -0,0 +1,457 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import inspect +import random +import typing as tp +from functools import partial + +import julius +import omegaconf +import torch +from julius import fft_conv1d, resample_frac + +from ..data.audio_utils import get_aac, get_mp3 + +if tp.TYPE_CHECKING: + from ..models.encodec import CompressionModel + + +def select_audio_effects( + audio_effects: tp.Dict, + weights: tp.Optional[tp.Dict] = None, + mode: str = "all", + max_length: tp.Optional[int] = None, +): + """Samples a subset of audio effects methods from the `AudioEffects` class. + + This function allows you to select a subset of audio effects + based on the chosen selection mode and optional weights. + + Args: + audio_effects (dict): A dictionary of available audio augmentations, usually + obtained from the output of the 'get_audio_effects' function. + weights (dict): A dictionary mapping augmentation names to their corresponding + probabilities of being selected. This argument is used when 'mode' is set + to "weighted." If 'weights' is None, all augmentations have equal + probability of being selected. + mode (str): The selection mode, which can be one of the following: + - "all": Select all available augmentations. + - "weighted": Select augmentations based on their probabilities in the + 'weights' dictionary. + max_length (int): The maximum number of augmentations to select. If 'max_length' + is None, no limit is applied. + + Returns: + dict: A subset of the 'audio_effects' dictionary containing the selected audio + augmentations. + + Note: + - In "all" mode, all available augmentations are selected. + - In "weighted" mode, augmentations are selected with a probability + proportional to their weights specified in the 'weights' dictionary. + - If 'max_length' is set, the function limits the number of selected + augmentations. + - If no augmentations are selected or 'audio_effects' is empty, the function + defaults to including an "identity" augmentation. + - The "identity" augmentation means that no audio effect is applied. 
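+
+    Example (illustrative sketch; the weight values below are arbitrary, and `cfg`
+    is assumed to be an omegaconf config with an `audio_effects` section):
+
+        effects = get_audio_effects(cfg)
+        weights = {name: 0.5 for name in effects}
+        # each effect is kept with probability ~0.5; pass `max_length` to cap the count
+        selected = select_audio_effects(effects, weights, mode="weighted")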
+ """ + if mode == "all": # original code + out = audio_effects + elif mode == "weighted": + # Probability proportionnal to weights + assert weights is not None + out = { + name: value + for name, value in audio_effects.items() + if random.random() < weights.get(name, 1.0) + } + else: + raise ValueError(f"Unknown mode {mode}") + if max_length is not None: + # Help having a deterministic limit of the gpu memory usage + random_keys = random.sample(list(out.keys()), max_length) + out = {key: out[key] for key in random_keys} + if len(out) == 0: # Check not to return empty dict + out = {"identity": AudioEffects.identity} + return out + + +def get_audio_effects(cfg: omegaconf.DictConfig): + """Automatically pull the list all effects available in this class based on the parameters from the cfg + + Returns: + dict: A dict of names and pointers to all methods in this class. + """ + assert hasattr(cfg, "audio_effects") + cfg_audio_effects = dict(cfg["audio_effects"]) + return { + name: partial(value, **cfg_audio_effects.get(name, {})) + for name, value in inspect.getmembers(AudioEffects) + if inspect.isfunction(value) + } + + +def audio_effect_return( + tensor: torch.Tensor, mask: tp.Optional[torch.Tensor] +) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: + """Return the mask if it was in the input otherwise only the output tensor""" + if mask is None: + return tensor + else: + return tensor, mask + + +def generate_pink_noise(length: int) -> torch.Tensor: + """Generate pink noise using Voss-McCartney algorithm with PyTorch.""" + num_rows = 16 + array = torch.randn(num_rows, length // num_rows + 1) + reshaped_array = torch.cumsum(array, dim=1) + reshaped_array = reshaped_array.reshape(-1) + reshaped_array = reshaped_array[:length] + # Normalize + pink_noise = reshaped_array / torch.max(torch.abs(reshaped_array)) + return pink_noise + + +def compress_with_encodec( + tensor: torch.Tensor, + n_q: int, + model: "CompressionModel", + sample_rate: int, + mask: tp.Optional[torch.Tensor] = None, +) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: + """Special augmentation function that compresses and decompresses wav tensor + using a compression model with the n_q codebooks + """ + + model.to(tensor.device) + model.set_num_codebooks(n_q) + codes, scale = model.encode( + julius.resample_frac(tensor, old_sr=sample_rate, new_sr=model.sample_rate) + ) + compressed = model.decode(codes=codes, scale=scale) + return audio_effect_return( + tensor=julius.resample_frac( + compressed, old_sr=model.sample_rate, new_sr=sample_rate + ), + mask=mask, + ) + + +def apply_compression_skip_grad(tensor: torch.Tensor, compression_fn, **kwargs): + """Applies a specified compression function to the audio tensor. + Whire carrying over the grads to the output tensor with skip through estimator + this is a straight through estimator to make mp3/aac compression differentiable + see more: Yin et al. 2019 https://arxiv.org/pdf/1903.05662.pdf + + Args: + tensor (torch.Tensor): The input audio tensor. + compression_fn (function): The compression function to apply. + **kwargs: Additional keyword arguments for the compression function. + + Returns: + torch.Tensor: The output tensor after applying compression and straight through estimator. 
+ """ + compressed = compression_fn(tensor.detach(), **kwargs) + + # Trim compressed output if needed + compressed = compressed[:, :, : tensor.size(-1)] + + # Straight through estimator for differentiable compression + out = tensor + (compressed - tensor).detach() + + # Check that gradients are not broken + if out.requires_grad: + assert ( + out.grad_fn + ), "The computation graph might be broken due to compression augmentation." + + return out + + +class AudioEffects: + @staticmethod + def speed( + tensor: torch.Tensor, + speed_range: tuple = (0.5, 1.5), + sample_rate: int = 16000, + mask: tp.Optional[torch.Tensor] = None, + ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: + """Function to change the speed of a batch of audio data. + The output will have a different length ! + + Args: + audio_batch (torch.Tensor): The batch of audio data in torch tensor format. + speed (float): The speed to change the audio to. + + Returns: + torch.Tensor: The batch of audio data with the speed changed. + """ + speed = torch.FloatTensor(1).uniform_(*speed_range) + new_sr = int(sample_rate * 1 / speed) + resampled_tensor = julius.resample.resample_frac(tensor, sample_rate, new_sr) + if mask is None: + return resampled_tensor + else: + return resampled_tensor, torch.nn.functional.interpolate( + mask, size=resampled_tensor.size(-1), mode="nearest-exact" + ) + + @staticmethod + def updownresample( + tensor: torch.Tensor, + sample_rate: int = 16000, + intermediate_freq: int = 32000, + mask: tp.Optional[torch.Tensor] = None, + ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: + + orig_shape = tensor.shape + # upsample + tensor = resample_frac(tensor, sample_rate, intermediate_freq) + # downsample + tensor = resample_frac(tensor, intermediate_freq, sample_rate) + + assert tensor.shape == orig_shape + return audio_effect_return(tensor=tensor, mask=mask) + + @staticmethod + def echo( + tensor: torch.Tensor, + volume_range: tuple = (0.1, 0.5), + duration_range: tuple = (0.1, 0.5), + sample_rate: int = 16000, + mask: tp.Optional[torch.Tensor] = None, + ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: + """Attenuating the audio volume by a factor of 0.4, delaying it by 100ms, + and then overlaying it with the original. + + Args: + tensor: 3D Tensor representing the audio signal [bsz, channels, frames] + volumne range: volume range of the echo signal + duration range: duration range of the echo signal + sample_rate: Sample rate of the audio signal. + Returns: + Audio signal with reverb. 
+ """ + + # Create a simple impulse response + # Duration of the impulse response in seconds + duration = torch.FloatTensor(1).uniform_(*duration_range) + volume = torch.FloatTensor(1).uniform_(*volume_range) + + n_samples = int(sample_rate * duration) + impulse_response = torch.zeros(n_samples).type(tensor.type()).to(tensor.device) + + # Define a few reflections with decreasing amplitude + impulse_response[0] = 1.0 # Direct sound + + impulse_response[ + int(sample_rate * duration) - 1 + ] = volume # First reflection after 100ms + + # Add batch and channel dimensions to the impulse response + impulse_response = impulse_response.unsqueeze(0).unsqueeze(0) + + # Convolve the audio signal with the impulse response + reverbed_signal = fft_conv1d(tensor, impulse_response) + + # Normalize to the original amplitude range for stability + reverbed_signal = ( + reverbed_signal + / torch.max(torch.abs(reverbed_signal)) + * torch.max(torch.abs(tensor)) + ) + + # Ensure tensor size is not changed + tmp = torch.zeros_like(tensor) + tmp[..., : reverbed_signal.shape[-1]] = reverbed_signal + reverbed_signal = tmp + + return audio_effect_return(tensor=reverbed_signal, mask=mask) + + @staticmethod + def random_noise( + waveform: torch.Tensor, + noise_std: float = 0.001, + mask: tp.Optional[torch.Tensor] = None, + ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: + """Add Gaussian noise to the waveform.""" + noise = torch.randn_like(waveform) * noise_std + noisy_waveform = waveform + noise + return audio_effect_return(tensor=noisy_waveform, mask=mask) + + @staticmethod + def pink_noise( + waveform: torch.Tensor, + noise_std: float = 0.01, + mask: tp.Optional[torch.Tensor] = None, + ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: + """Add pink background noise to the waveform.""" + noise = generate_pink_noise(waveform.shape[-1]) * noise_std + noise = noise.to(waveform.device) + # Assuming waveform is of shape (bsz, channels, length) + noisy_waveform = waveform + noise.unsqueeze(0).unsqueeze(0).to(waveform.device) + return audio_effect_return(tensor=noisy_waveform, mask=mask) + + @staticmethod + def lowpass_filter( + waveform: torch.Tensor, + cutoff_freq: float = 5000, + sample_rate: int = 16000, + mask: tp.Optional[torch.Tensor] = None, + ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: + """Filter the lowpass frequency from the waveform""" + return audio_effect_return( + tensor=julius.lowpass_filter(waveform, cutoff=cutoff_freq / sample_rate), + mask=mask, + ) + + @staticmethod + def highpass_filter( + waveform: torch.Tensor, + cutoff_freq: float = 500, + sample_rate: int = 16000, + mask: tp.Optional[torch.Tensor] = None, + ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: + """Filter the highpass frequency from the waveform""" + return audio_effect_return( + tensor=julius.highpass_filter(waveform, cutoff=cutoff_freq / sample_rate), + mask=mask, + ) + + @staticmethod + def bandpass_filter( + waveform: torch.Tensor, + cutoff_freq_low: float = 300, + cutoff_freq_high: float = 8000, + sample_rate: int = 16000, + mask: tp.Optional[torch.Tensor] = None, + ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: + """Apply a bandpass filter to the waveform by cascading + a high-pass filter followed by a low-pass filter. + + Args: + waveform (torch.Tensor): Input audio waveform. + low_cutoff (float): Lower cutoff frequency. + high_cutoff (float): Higher cutoff frequency. + sample_rate (int): The sample rate of the waveform. 
+
+        Returns:
+            torch.Tensor: Filtered audio waveform.
+        """
+
+        return audio_effect_return(
+            tensor=julius.bandpass_filter(
+                waveform,
+                cutoff_low=cutoff_freq_low / sample_rate,
+                cutoff_high=cutoff_freq_high / sample_rate,
+            ),
+            mask=mask,
+        )
+
+    @staticmethod
+    def smooth(
+        tensor: torch.Tensor,
+        window_size_range: tuple = (2, 10),
+        mask: tp.Optional[torch.Tensor] = None,
+    ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
+        """Smooths the input tensor (audio signal) using a moving average filter with a
+        randomly sampled window size.
+
+        Args:
+            tensor (torch.Tensor): Input audio tensor. Assumes tensor shape is (batch_size,
+                channels, time).
+            window_size_range (tuple): Range from which the moving average window size is sampled.
+            mask: Masks for the input wave
+
+        Returns:
+            torch.Tensor: Smoothed audio tensor.
+        """
+
+        window_size = int(torch.FloatTensor(1).uniform_(*window_size_range))
+        # Create a uniform smoothing kernel
+        kernel = torch.ones(1, 1, window_size).type(tensor.type()) / window_size
+        kernel = kernel.to(tensor.device)
+
+        smoothed = fft_conv1d(tensor, kernel)
+        # Ensure tensor size is not changed
+        tmp = torch.zeros_like(tensor)
+        tmp[..., : smoothed.shape[-1]] = smoothed
+        smoothed = tmp
+
+        return audio_effect_return(tensor=smoothed, mask=mask)
+
+    @staticmethod
+    def boost_audio(
+        tensor: torch.Tensor,
+        amount: float = 20,
+        mask: tp.Optional[torch.Tensor] = None,
+    ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
+        """Boost the audio volume by `amount` percent"""
+        return audio_effect_return(tensor=tensor * (1 + amount / 100), mask=mask)
+
+    @staticmethod
+    def duck_audio(
+        tensor: torch.Tensor,
+        amount: float = 20,
+        mask: tp.Optional[torch.Tensor] = None,
+    ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
+        """Duck (attenuate) the audio volume by `amount` percent"""
+        return audio_effect_return(tensor=tensor * (1 - amount / 100), mask=mask)
+
+    @staticmethod
+    def identity(
+        tensor: torch.Tensor, mask: tp.Optional[torch.Tensor] = None
+    ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
+        return audio_effect_return(tensor=tensor, mask=mask)
+
+    @staticmethod
+    def mp3_compression(
+        tensor: torch.Tensor,
+        sample_rate: int = 16000,
+        bitrate: str = "128k",
+        mask: tp.Optional[torch.Tensor] = None,
+    ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
+        """
+        Compress audio using the MP3 algorithm.
+        Args:
+            tensor (torch.Tensor): The input audio tensor.
+            sample_rate (int): The sample rate of the audio.
+            bitrate (str): The bitrate for MP3 compression.
+
+        Returns:
+            torch.Tensor: The output tensor after applying MP3 compression.
+        """
+        out = apply_compression_skip_grad(
+            tensor, get_mp3, sr=sample_rate, bitrate=bitrate
+        )
+        return audio_effect_return(tensor=out, mask=mask)
+
+    @staticmethod
+    def aac_compression(
+        tensor: torch.Tensor,
+        sample_rate: int = 16000,
+        bitrate: str = "128k",
+        lowpass_freq: tp.Optional[int] = None,
+        mask: tp.Optional[torch.Tensor] = None,
+    ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
+        """Applies AAC compression to an audio tensor.
+
+        Args:
+            tensor (torch.Tensor): The input audio tensor.
+            sample_rate (int): The sample rate of the audio.
+            bitrate (str): The bitrate for AAC compression.
+            lowpass_freq (Optional[int]): The frequency for a low-pass filter.
+
+        Returns:
+            torch.Tensor: The output tensor after applying AAC compression.
+ """ + out = apply_compression_skip_grad( + tensor, get_aac, sr=sample_rate, bitrate=bitrate, lowpass_freq=lowpass_freq + ) + return audio_effect_return(tensor=out, mask=mask) diff --git a/config/augmentations/default.yaml b/config/augmentations/default.yaml new file mode 100644 index 00000000..120887b0 --- /dev/null +++ b/config/augmentations/default.yaml @@ -0,0 +1,65 @@ +# @package __global__ + +audio_effects: + speed: + sample_rate: ${sample_rate} + speed_range: [0.8, 1.2] + updownresample: + sample_rate: ${sample_rate} + intermediate_freq: 32000 + echo: + sample_rate: ${sample_rate} + volume_range: [0.1, 0.5] + duration_range: [0.1, 0.5] + random_noise: + noise_std: 0.001 + pink_noise: + noise_std: 0.01 + lowpass_filter: + sample_rate: ${sample_rate} + cutoff_freq: 5000 + highpass_filter: + cutoff_freq: 500 + sample_rate: ${sample_rate} + bandpass_filter: + cutoff_freq_low: 300 + cutoff_freq_high: 8000 + sample_rate: ${sample_rate} + smooth: + window_size_range: [2, 10] + boost_audio: + amount: 20 + duck_audio: + amount: 20 + mp3_compression: + sample_rate: ${sample_rate} + bitrate: 128k # should be a string e.g. "8k", "32k".. cf ffmpeg to see available bitrates + aac_compression: + sample_rate: ${sample_rate} + bitrate: 128k # should be a string e.g. "8k", "32k".. cf ffmpeg to see available bitrates + lowpass_freq: null # don't apply low pass freq to ffmpeg aac compression + encodec: + ckpt: "//pretrained/facebook/encodec_24khz" + n_qs: [4, 8, 16] + +select_aug_mode: + "use_eval" # other are 'all' and 'use_eval_acc', used to sample augmentations, `fixed` uses the prob from aug_weights, `all` uses all agmentations every step + # `use_eval_acc` changes the weights based on the accuracies at evaluation time + +aug_weights: + speed: 0.1 + updownresample: 0.1 + echo: 0.1 + pink_noise: 0.1 + lowpass_filter: 0.1 + highpass_filter: 0.1 + bandpass_filter: 0.1 + smooth: 0.1 + boost_audio: 0.1 + duck_audio: 0.1 + mp3_compression: 0.1 # eval only never use in training even if eval_acc low + aac_compression: 0.1 # eval only never use in training even if eval_acc low + encodec: 0.1 + identity: 1 # no augmentation + +n_max_aug: null \ No newline at end of file diff --git a/config/model/watermark/default.yaml b/config/model/watermark/default.yaml new file mode 100644 index 00000000..6e17abb5 --- /dev/null +++ b/config/model/watermark/default.yaml @@ -0,0 +1,41 @@ +# @package __global__ + +audioseal: + autoencoder: seanet + sample_rate: 16000 + channels: 1 + nbits: 16 + +seanet: + dimension: 128 + channels: 1 + causal: false + n_filters: 32 + n_residual_layers: 1 + ratios: [8, 5, 4, 2] + activation: ELU + activation_params: { "alpha": 1. } + norm: weight_norm + norm_params: {} + kernel_size: 7 + residual_kernel_size: 3 + last_kernel_size: 7 + dilation_base: 2 + pad_mode: constant + true_skip: true + compress: 2 + lstm: 2 + disable_norm_outer_blocks: 0 + # Specific encoder or decoder params. + # You can also override any param for the encoder or decoder only + # by using Hydra `+param=` syntax, i.e.` + # `+seanet.decoder.n_filters=64`. 
+ decoder: + trim_right_ratio: 1.0 + final_activation: null + final_activation_params: null + encoder: {} + +detector: { + "output_dim": 32, # output channels of detector upsampling +} \ No newline at end of file diff --git a/config/solver/watermark/debug.yaml b/config/solver/watermark/debug.yaml new file mode 100644 index 00000000..64c002d8 --- /dev/null +++ b/config/solver/watermark/debug.yaml @@ -0,0 +1,207 @@ +# @package __global__ + +defaults: + - /solver/default + - /augmentations/default + - /model: watermark/default + - override /dset: audio/example + - _self_ + +solver: watermarking # standard name to load the solver using builders +sample_rate: 48000 +channels: 1 + +# all the defaults form compression +losses: + adv: 4. + feat: 4. + l1: 0.1 + mel: 0.0 + msspec: 2.0 + sisnr: 0.0 + wm_detection: 1.0 # loss for first 2 bits cannot be 0 + wm_mb: 1.0 # loss for the rest of the bits (wm message) + tf_loudnessratio: 10.0 + +balancer: + balance_grads: true + ema_decay: 0.999 + per_batch_item: true + total_norm: 1. + +crop: + prob: 0.4 + shuffle_prob: 0.2 + pad_prob: 0.2 # shuffle_prob + pad_prob + prob <= 1 + size: 0.5 + max_n_windows: 5 + +adversarial: + every: 1 + adversaries: [msstftd] + adv_loss: hinge + feat_loss: l1 + +tf_loudnessratio: + sample_rate: ${sample_rate} + segment: 0.5 + overlap: 0.5 + n_bands: 16 + temperature: 1.0 + +# watermarking: audioseal + +# losses hyperparameters +l1: {} +l2: {} + +wm_detection: + p_weight: 1 + n_weight: 1 + +wm_mb: + loss_type: bce # loss between decoded and original + temperature: 0.1 # decoded is divided by temperature before loss computation + +spec_range: + n_fft: 2048 + min_frequency: 300.0 + max_frequency: 15000.0 + sample_rate: ${sample_rate} +spec_entropy_range: + n_fft: 2048 + min_frequency: 300.0 + max_frequency: 15000.0 + sample_rate: ${sample_rate} +mrstft: + factor_sc: .5 + factor_mag: .5 + normalized: false +mel: + sample_rate: ${sample_rate} + n_fft: 1024 + hop_length: 256 + win_length: 1024 + n_mels: 64 + f_min: 64 + f_max: null + normalized: false + floor_level: 1e-5 +sisnr: + sample_rate: ${sample_rate} + segment: 5. 
+msspec: + sample_rate: ${sample_rate} + range_start: 6 + range_end: 11 + n_mels: 64 + f_min: 64 + f_max: null + normalized: true + alphas: false + floor_level: 1e-5 + +# metrics +metrics: + visqol: + mode: audio + bin: null # path to visqol install + model: tcdaudio14_aacvopus_coresv_svrnsim_n.68_g.01_c1.model # visqol v3 + +# adversaries hyperparameters +msstftd: + in_channels: 1 + out_channels: 1 + filters: 32 + norm: weight_norm + n_ffts: [1024, 2048, 512, 256, 128] + hop_lengths: [256, 512, 128, 64, 32] + win_lengths: [1024, 2048, 512, 256, 128] + activation: LeakyReLU + activation_params: { negative_slope: 0.3 } +msd: + in_channels: 1 + out_channels: 1 + scale_norms: [spectral_norm, weight_norm, weight_norm] + kernel_sizes: [5, 3] + filters: 16 + max_filters: 1024 + downsample_scales: [4, 4, 4, 4] + inner_kernel_sizes: null + groups: [4, 4, 4, 4] + strides: null + paddings: null + activation: LeakyReLU + activation_params: { negative_slope: 0.3 } +mpd: + in_channels: 1 + out_channels: 1 + periods: [2, 3, 5, 7, 11] + n_layers: 5 + kernel_size: 5 + stride: 3 + filters: 8 + filter_scales: 4 + max_filters: 1024 + activation: LeakyReLU + activation_params: { negative_slope: 0.3 } + norm: weight_norm + +# data hyperparameters +dataset: + batch_size: 16 + num_workers: 10 + segment_duration: 1 + sample_on_weight: false # Uniform sampling all the way + sample_on_duration: false # Uniform sampling all the way + + generate: + batch_size: 16 + num_samples: 50 + segment_duration: 30 + +# solver hyperparameters +evaluate: + every: 10 + num_workers: 5 + metrics: + visqol: false + sisnr: true +generate: + every: 10 + num_workers: 5 + audio: + sample_rate: ${sample_rate} + +# checkpointing schedule +checkpoint: + save_last: true + save_every: 25 + keep_last: 10 + keep_every_states: null + + + +# optimization hyperparameters +optim: + epochs: 2 + updates_per_epoch: 10 + lr: 5e-5 + max_norm: 3.0 + optimizer: adam + adam: + betas: [0.5, 0.9] + weight_decay: 0. + ema: + use: true # whether to use EMA or not + updates: 1 # update at every step + device: ${device} # device for EMA, can be put on GPU if more frequent updates + decay: 0.99 # EMA decay value, if null, no EMA is used + + +schedule: + lr_scheduler: "cosine" + cosine: + warmup: 4000 + lr_min_ratio: 0.0 + cycle_length: 1.0 diff --git a/config/solver/watermark/default.yaml b/config/solver/watermark/default.yaml new file mode 100644 index 00000000..5726e414 --- /dev/null +++ b/config/solver/watermark/default.yaml @@ -0,0 +1,212 @@ +# @package __global__ + +defaults: + - /solver/default + - /augmentations/default + - override /dset: audio/example + - _self_ + +solver: watermarking # standard name to load the solver using builders +sample_rate: ??? +channels: ??? + +# all the defaults form compression +losses: + adv: 4. + feat: 4. + l1: 0.1 + mel: 0.0 + msspec: 2.0 + sisnr: 0.0 + wm_detection: 1.0 # loss for first 2 bits cannot be 0 + wm_mb: 1.0 # loss for the rest of the bits (wm message) + tf_loudnessratio: 10.0 + +balancer: + balance_grads: true + ema_decay: 0.999 + per_batch_item: true + total_norm: 1. 
+ +crop: + prob: 0.4 + shuffle_prob: 0.2 + pad_prob: 0.2 # shuffle_prob + pad_prob + prob <= 1 + size: 0.5 + max_n_windows: 5 + +adversarial: + every: 1 + adversaries: [msstftd] + adv_loss: hinge + feat_loss: l1 + +tf_loudnessratio: + sample_rate: ${sample_rate} + segment: 0.5 + overlap: 0.5 + n_bands: 16 + temperature: 1.0 + +# watermarking: audioseal + +# losses hyperparameters +l1: {} +l2: {} + +wm_detection: + p_weight: 1 + n_weight: 1 + +wm_mb: + loss_type: bce # loss between decoded and original + temperature: 0.1 # decoded is divided by temperature before loss computation + +spec_range: + n_fft: 2048 + min_frequency: 300.0 + max_frequency: 15000.0 + sample_rate: ${sample_rate} +spec_entropy_range: + n_fft: 2048 + min_frequency: 300.0 + max_frequency: 15000.0 + sample_rate: ${sample_rate} +mrstft: + factor_sc: .5 + factor_mag: .5 + normalized: false +mel: + sample_rate: ${sample_rate} + n_fft: 1024 + hop_length: 256 + win_length: 1024 + n_mels: 64 + f_min: 64 + f_max: null + normalized: false + floor_level: 1e-5 +sisnr: + sample_rate: ${sample_rate} + segment: 5. +msspec: + sample_rate: ${sample_rate} + range_start: 6 + range_end: 11 + n_mels: 64 + f_min: 64 + f_max: null + normalized: true + alphas: false + floor_level: 1e-5 + +# metrics +metrics: + visqol: + mode: audio + bin: null # path to visqol install + model: tcdaudio14_aacvopus_coresv_svrnsim_n.68_g.01_c1.model # visqol v3 + +# adversaries hyperparameters +msstftd: + in_channels: 1 + out_channels: 1 + filters: 32 + norm: weight_norm + n_ffts: [1024, 2048, 512, 256, 128] + hop_lengths: [256, 512, 128, 64, 32] + win_lengths: [1024, 2048, 512, 256, 128] + activation: LeakyReLU + activation_params: { negative_slope: 0.3 } +msd: + in_channels: 1 + out_channels: 1 + scale_norms: [spectral_norm, weight_norm, weight_norm] + kernel_sizes: [5, 3] + filters: 16 + max_filters: 1024 + downsample_scales: [4, 4, 4, 4] + inner_kernel_sizes: null + groups: [4, 4, 4, 4] + strides: null + paddings: null + activation: LeakyReLU + activation_params: { negative_slope: 0.3 } +mpd: + in_channels: 1 + out_channels: 1 + periods: [2, 3, 5, 7, 11] + n_layers: 5 + kernel_size: 5 + stride: 3 + filters: 8 + filter_scales: 4 + max_filters: 1024 + activation: LeakyReLU + activation_params: { negative_slope: 0.3 } + norm: weight_norm + +# data hyperparameters +dataset: + batch_size: 16 + num_workers: 10 + segment_duration: 1 + train: + num_samples: 500000 + valid: + num_samples: 10000 + evaluate: + batch_size: 16 + num_samples: 10000 + segment_duration: 10 + + generate: + batch_size: 16 + num_samples: 50 + segment_duration: 30 + +# solver hyperparameters +evaluate: + every: 10 + num_workers: 5 + metrics: + visqol: false + sisnr: true +generate: + every: 10 + num_workers: 5 + audio: + sample_rate: ${sample_rate} + +# checkpointing schedule +checkpoint: + save_last: true + save_every: 25 + keep_last: 10 + keep_every_states: null + + + +# optimization hyperparameters +optim: + epochs: 300 + updates_per_epoch: 2000 + lr: 5e-5 + max_norm: 3.0 + optimizer: adam + adam: + betas: [0.5, 0.9] + weight_decay: 0. 
+  ema:
+    use: true # whether to use EMA or not
+    updates: 1 # update at every step
+    device: ${device} # device for EMA, can be put on GPU if more frequent updates
+    decay: 0.99 # EMA decay value, if null, no EMA is used
+
+
+schedule:
+  lr_scheduler: "cosine"
+  cosine:
+    warmup: 4000
+    lr_min_ratio: 0.0
+    cycle_length: 1.0
diff --git a/config/solver/watermark/robustness.yaml b/config/solver/watermark/robustness.yaml
new file mode 100644
index 00000000..5cf6bb49
--- /dev/null
+++ b/config/solver/watermark/robustness.yaml
@@ -0,0 +1,15 @@
+# @package __global__
+defaults:
+  - watermark/default
+  - /augmentations/default
+  - /model: watermark/default
+  - _self_
+
+sample_rate: 16000
+channels: 1
+
+balancer:
+  balance_grads: true
+  ema_decay: 0.999
+  per_batch_item: true
+  total_norm: 1.
diff --git a/docs/WATERMARKING.md b/docs/WATERMARKING.md
new file mode 100644
index 00000000..425e204f
--- /dev/null
+++ b/docs/WATERMARKING.md
@@ -0,0 +1,40 @@
+# AudioSeal: Proactive Localized Watermarking
+
+AudioCraft provides the training code and models for AudioSeal, a method for localized speech watermarking introduced in [Proactive Detection of Voice Cloning with Localized Watermarking][arxiv], with state-of-the-art robustness and detection speed. AudioSeal jointly trains a generator that embeds a watermark in the audio and a detector that locates the watermarked fragments within longer audio, even after editing.
+
+## Installation and setup
+
+Make sure to install audiocraft version `1.4.0a1` or later with the `[wm]` extra (see [README](../README.md)).
+Alternatively, you can install AudioSeal directly by following the [Installation](https://github.com/facebookresearch/audioseal) guidelines in the AudioSeal repo.
+
+_NOTE_: Since we use AAC augmentation in our training loop, you need to install ffmpeg, otherwise training will fail (see the "Installation" section in the [README](../README.md)).
+
+Make sure you follow the [steps for basic training setup](TRAINING.md) before starting.
+
+## API
+Check the [AudioSeal GitHub repository](https://github.com/facebookresearch/audioseal) for the full inference API and more details.
+
+## Training
+
+The [WatermarkSolver](../audiocraft/solvers/watermark.py) implements AudioSeal's training pipeline. It jointly trains the generator and detector, which wrap
+`audioseal.AudioSealWM` and `audioseal.AudioSealDetector` respectively. For the training recipe, see [config/solver/watermark/robustness.yaml](../config/solver/watermark/robustness.yaml).
+
+For illustration, we use the three example audio files in `datasets`, with the datasource definition in [dset/audio/example.yaml](../config/dset/audio/example.yaml) (please read [DATASETS](./DATASETS.md) to understand AudioCraft's dataset structure).
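+
+Before launching a full training run, it can help to sanity-check the watermarking inference path end to end. The snippet below is a minimal sketch that mirrors the unit test in [tests/models/test_watermark.py](../tests/models/test_watermark.py): it loads the `AudioSeal` wrapper from `audiocraft.models.watermark`, embeds a random 16-bit message into one second of noise, and reads back the frame-level detection probability. The checkpoint name `"base"`, the 16-bit message length, and the `alpha` strength are illustrative values taken from that test, and the pretrained checkpoint is fetched over the network on first use.
+
+```python
+import torch
+
+from audiocraft.models.watermark import AudioSeal
+
+sr = 16_000
+wav = torch.randn(1, 1, sr)  # (batch, channels, samples): one second of noise
+message = torch.randint(0, 2, (1, 16), dtype=torch.int32)  # random 16-bit payload
+
+wm = AudioSeal.get_pretrained(name="base")
+watermarked = wm(wav, message=message, sample_rate=sr, alpha=0.8)
+
+# detect_watermark returns per-frame scores; channel 1 holds the
+# "watermark present" probability for each time step.
+result = wm.detect_watermark(watermarked)
+detect_prob = torch.gt(result[:, 1, :], 0.5).float().mean().item()
+print(f"Fraction of frames flagged as watermarked: {detect_prob:.2f}")
+```
+
+For the lower-level `audioseal` API, with separate generator and detector classes, refer to the AudioSeal repository linked above.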
+
+To run the Watermarking training pipeline locally:
+
+```bash
+dora run solver=watermark/robustness dset=audio/example
+```
+
+You can override model or experiment parameters directly from the command line, for example:
+
+```bash
+dora run solver=watermark/robustness dset=audio/example sample_rate=24000
+```
+
+If you want to run in debug mode:
+
+```bash
+python3 -m pdb -c c -m dora run solver=watermark/robustness dset=audio/example
+```
+
+[arxiv]: https://arxiv.org/abs/2401.17264
diff --git a/mypy.ini b/mypy.ini
index 6ab60f2f..b0b6c505 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -1,4 +1,4 @@
 [mypy]
-[mypy-treetable,torchaudio.*,soundfile,einops.*,av.*,tqdm.*,num2words.*,spacy,xformers.*,scipy,huggingface_hub,transformers,dac.*]
+[mypy-treetable,torchaudio.*,soundfile,einops.*,av.*,tqdm.*,num2words.*,spacy,xformers.*,scipy,flashy.*,torchmetrics.*,hydra,pesq,demucs.*,huggingface_hub,transformers,dac.*]
 ignore_missing_imports = True
diff --git a/requirements.txt b/requirements.txt
index a30655e1..effd8c7e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -17,9 +17,12 @@ transformers>=4.31.0 # need Encodec there.
 xformers<0.0.23
 demucs
 librosa
+soundfile
 gradio
 torchmetrics
 encodec
 protobuf
 torchvision==0.16.0
 torchtext==0.16.0
+pesq
+pystoi
diff --git a/setup.py b/setup.py
index 9d844ea9..83c40d6c 100644
--- a/setup.py
+++ b/setup.py
@@ -47,6 +47,7 @@
     install_requires=REQUIRED,
     extras_require={
         'dev': ['coverage', 'flake8', 'mypy', 'pdoc3', 'pytest'],
+        'wm': ['audioseal'],
     },
     packages=[p for p in find_packages() if p.startswith('audiocraft')],
     package_data={'audiocraft': ['py.typed']},
diff --git a/tests/data/test_audio_utils.py b/tests/data/test_audio_utils.py
index 0480671b..8f24e9b2 100644
--- a/tests/data/test_audio_utils.py
+++ b/tests/data/test_audio_utils.py
@@ -12,6 +12,8 @@
     _clip_wav,
     convert_audio_channels,
     convert_audio,
+    f32_pcm,
+    i16_pcm,
     normalize_audio
 )
 from ..common_utils import get_batch_white_noise
@@ -78,6 +80,14 @@ def test_convert_audio_resample(self):
             out_j = julius.resample.resample_frac(audio, old_sr=sr, new_sr=new_sr)
             assert torch.allclose(out, out_j)
 
+    def test_convert_pcm(self):
+        b, c, dur = 2, 1, 4.
+        sr = 3
+        i16_audio = torch.randint(-2**15, 2**15, (b, c, int(sr * dur)), dtype=torch.int16)
+        f32_audio = f32_pcm(i16_audio)
+        another_i16_audio = i16_pcm(f32_audio)
+        assert torch.allclose(i16_audio, another_i16_audio)
+
 
 class TestNormalizeAudio:
diff --git a/tests/losses/test_losses.py b/tests/losses/test_losses.py
index b6681e12..1b9120b7 100644
--- a/tests/losses/test_losses.py
+++ b/tests/losses/test_losses.py
@@ -15,6 +15,9 @@
     SISNR,
     STFTLoss,
 )
+from audiocraft.losses.loudnessloss import TFLoudnessRatio
+from audiocraft.losses.wmloss import WMMbLoss
+from tests.common_utils.wav_utils import get_white_noise
 
 
 def test_mel_l1_loss():
@@ -76,3 +79,25 @@ def test_stft_loss():
 
     loss = mrstft(t1, t2)
     assert isinstance(loss, torch.Tensor)
+
+
+def test_wm_loss():
+    N, nbits, T = 2, 16, random.randrange(1000, 100_000)
+    positive = torch.randn(N, 2 + nbits, T)
+    t2 = torch.randn(N, 1, T)
+    message = torch.randn(N, nbits)
+
+    wmloss = WMMbLoss(0.3, "mse")
+    loss = wmloss(positive, None, t2, message)
+
+    assert isinstance(loss, torch.Tensor)
+
+
+def test_loudness_loss():
+    sr = 16_000
+    duration = 1.0
+    wav = get_white_noise(1, int(sr * duration)).unsqueeze(0)
+    tflrloss = TFLoudnessRatio(sample_rate=sr, n_bands=1)
+
+    loss = tflrloss(wav, wav)
+    assert isinstance(loss, torch.Tensor)
diff --git a/tests/metrics/__init__.py b/tests/metrics/__init__.py
new file mode 100644
index 00000000..0952fcc3
--- /dev/null
+++ b/tests/metrics/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/tests/metrics/test_pesq.py b/tests/metrics/test_pesq.py
new file mode 100644
index 00000000..fb1a0eac
--- /dev/null
+++ b/tests/metrics/test_pesq.py
@@ -0,0 +1,45 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import julius
+import pesq
+import torch
+from audiocraft.metrics.pesq import PesqMetric
+from ..common_utils import TempDirMixin, get_batch_white_noise
+
+
+def tensor_pesq(y_pred: torch.Tensor, y: torch.Tensor, sr: int):
+    # pesq raises an error if no speech is detected, so we catch it
+    if sr != 16000:
+        y_pred = julius.resample_frac(y_pred, sr, 16000)
+        y = julius.resample_frac(y, sr, 16000)
+    P, n = 0, 0
+    for ii in range(y_pred.size(0)):
+        try:  # torchmetrics crashes when a single item in the batch errors, so we iterate manually
+            P += pesq.pesq(16000, y[ii, 0].cpu().numpy(), y_pred[ii, 0].cpu().numpy())
+            n += 1
+        except pesq.NoUtterancesError:  # this error can happen when the sample does not contain speech
+            pass
+    p = P / n if n != 0 else 0.0
+    return p
+
+
+class TestPesq(TempDirMixin):
+
+    def test(self):
+        sample_rate = 16_000
+        duration = 20
+        channel = 1
+        bs = 10
+        wavs = get_batch_white_noise(bs, channel, int(sample_rate * duration))
+
+        pesq_metric = PesqMetric(sample_rate=sample_rate)
+        pesq1 = pesq_metric(wavs, wavs)
+        print(f"Pesq between 2 identical white noises: {pesq1}")
+        assert pesq1 > 1
+
+        pesq2 = tensor_pesq(wavs, wavs, 16000)
+        assert torch.allclose(pesq1, torch.tensor(pesq2))
diff --git a/tests/models/test_watermark.py b/tests/models/test_watermark.py
new file mode 100644
index 00000000..ff1422a8
--- /dev/null
+++ b/tests/models/test_watermark.py
@@ -0,0 +1,30 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+from audiocraft.models.watermark import AudioSeal
+from tests.common_utils.wav_utils import get_white_noise
+
+
+class TestWatermarkModel:
+
+    def test_base(self):
+        sr = 16_000
+        duration = 1.0
+        wav = get_white_noise(1, int(sr * duration)).unsqueeze(0)
+        wm = AudioSeal.get_pretrained(name="base")
+
+        secret_message = torch.randint(0, 2, (1, 16), dtype=torch.int32)
+        watermarked_wav = wm(wav, message=secret_message, sample_rate=sr, alpha=0.8)
+        result = wm.detect_watermark(watermarked_wav)
+
+        detected = (
+            torch.count_nonzero(torch.gt(result[:, 1, :], 0.5)) / result.shape[-1]
+        )
+        detect_prob = detected.cpu().item()  # type: ignore
+
+        assert detect_prob >= 0.0
diff --git a/tests/utils/test_audio_effects.py b/tests/utils/test_audio_effects.py
new file mode 100644
index 00000000..e4e1b44d
--- /dev/null
+++ b/tests/utils/test_audio_effects.py
@@ -0,0 +1,112 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pytest
+from omegaconf import OmegaConf
+
+from audiocraft.utils.audio_effects import AudioEffects, get_audio_effects, select_audio_effects
+
+from ..common_utils import get_batch_white_noise
+
+
+class TestAudioEffect:
+    SR = 16_000
+
+    @pytest.fixture(autouse=True)
+    def audio_effects(self):
+        cfg = {
+            "audio_effects": {
+                "speed": {
+                    "sample_rate": self.SR,
+                    "speed_range": [0.8, 1.2]
+                },
+                "updownresample": {
+                    "sample_rate": self.SR,
+                    "intermediate_freq": 32_000,
+                },
+                "echo": {
+                    "sample_rate": self.SR,
+                    "volume_range": [0.1, 0.5],
+                },
+                "random_noise": {
+                    "noise_std": 0.001,
+                },
+                "pink_noise": {
+                    "noise_std": 0.01,
+                },
+                "lowpass_filter": {
+                    "sample_rate": self.SR,
+                    "cutoff_freq": 5_000,
+                },
+                "highpass_filter": {
+                    "sample_rate": self.SR,
+                    "cutoff_freq": 500,
+                },
+                "bandpass_filter": {
+                    "sample_rate": self.SR,
+                    "cutoff_freq_low": 300,
+                    "cutoff_freq_high": 8_000,
+                },
+                "smooth": {
+                    "window_size_range": [2, 10],
+                },
+                "boost_audio": {
+                    "amount": 20,
+                },
+                "duck_audio": {
+                    "amount": 20,
+                },
+                "mp3_compression": {
+                    "sample_rate": self.SR,
+                    "bitrate": "128k",
+                },
+                "aac_compression": {
+                    "sample_rate": self.SR,
+                    "bitrate": "128k",
+                    "lowpass_freq": None,
+                }
+            }
+        }
+        weights = {
+            "speed": 2.0,
+            "updownresample": 0.4,
+            "echo": 1.0,
+            "random_noise": 3.0,
+            "pink_noise": 0.5,
+            "lowpass_filter": 4.0,
+            "highpass_filter": 5.0,
+            "bandpass_filter": 6.0,
+            "smooth": 1.0,
+        }
+        return get_audio_effects(OmegaConf.structured(cfg)), weights
+
+    def test_select_empty_effects(self):
+        effects = select_audio_effects({})
+        assert "identity" in effects and effects["identity"] == AudioEffects.identity
+
+    def test_select_wrong_strategy(self):
+        with pytest.raises(ValueError):
+            _ = select_audio_effects(
+                audio_effects={},
+                mode="some invalid mode"
+            )
+
+    def test_selection(self, audio_effects):
+        effect_cfg, weights = audio_effects
+        effects = select_audio_effects(
+            audio_effects=effect_cfg,
+            weights=weights,
+            mode="weighted"
+        )
+        b, c, t = 2, 4, 32000
+        audio = get_batch_white_noise(b, c, t)
+        for effect_name, effect_func in effects.items():
+            modified_audio = effect_func(audio)
+            # It is quite hard to unit test the content of modified_audio though
+            if effect_name == "speed":  # speeding up audio should result in more frames
+                assert modified_audio.size()[-1] > audio.size()[-1]
+            else:
+                assert modified_audio.size() == audio.size(), f"Wrong dimension in {effect_name}"