Skip to content

Commit

Permalink
assertion for description_condition
Browse files Browse the repository at this point in the history
  • Loading branch information
simonrouard committed Nov 5, 2024
1 parent 3b59cc5 commit 7b5b009
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 11 deletions.
15 changes: 7 additions & 8 deletions audiocraft/models/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import logging
import math
import typing as tp
from copy import deepcopy

import torch
from torch import nn
Expand All @@ -24,7 +23,7 @@
ConditioningProvider,
ConditioningAttributes,
ConditionType,
_drop_text_condition
_drop_description_condition
)
from ..modules.codebooks_patterns import CodebooksPatternProvider
from ..modules.activations import get_activation_fn
Expand Down Expand Up @@ -347,8 +346,8 @@ def _sample_next_token(self,
top_p (float): P for "top-p" sampling.
cfg_coef (float, optional): classifier free guidance coefficient
cfg_coef_beta (float, optional): If None, simple classifier free guidance is used with cfg_coef.
If not None, we apply double classifier free guidance as introduced in MusicGen-Style
in paragraph 4.3 (https://arxiv.org/pdf/2407.12563). This beta coefficient is meant to
If not None, we apply double classifier free guidance as introduced in MusicGen-Style
in paragraph 4.3 (https://arxiv.org/pdf/2407.12563). This beta coefficient is meant to
push the text condition more than the style condition in the case where both text and style
conditions are being used.
Returns:
Expand All @@ -362,7 +361,7 @@ def _sample_next_token(self,
assert isinstance(cfg_conditions, dict)
condition_tensors = cfg_conditions
if condition_tensors:
# Preparing for CFG, predicting conditional text and style, conditional style
# Preparing for CFG, predicting conditional text and style, conditional style
# and unconditional
sequence = torch.cat([sequence, sequence, sequence], dim=0)
all_logits = model(
Expand Down Expand Up @@ -447,8 +446,8 @@ def generate(self,
top_p (float): P for "top-p" sampling.
cfg_coef (float, optional): Classifier-free guidance coefficient.
cfg_coef_beta (float, optional): If None, simple classifier free guidance is used with cfg_coef.
If not None, we apply double classifier free guidance as introduced in MusicGen-Style
in paragraph 4.3 (https://arxiv.org/pdf/2407.12563). This beta coefficient is meant to
If not None, we apply double classifier free guidance as introduced in MusicGen-Style
in paragraph 4.3 (https://arxiv.org/pdf/2407.12563). This beta coefficient is meant to
push the text condition more than the style condition in the case where both text and style
conditions are being used.
two_step_cfg (bool, optional): Whether to perform classifier-free guidance with two steps generation.
Expand Down Expand Up @@ -488,7 +487,7 @@ def generate(self,
cfg_conditions = {}
if cfg_coef_beta is not None:
if conditions:
wav_conditions = _drop_text_condition(conditions)
wav_conditions = _drop_description_condition(conditions)
null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
conditions = conditions + wav_conditions + null_conditions
tokenized = self.condition_provider.tokenize(conditions)
Expand Down
4 changes: 2 additions & 2 deletions audiocraft/models/musicgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def get_pretrained(name: str = 'facebook/musicgen-melody', device=None):

def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
top_p: float = 0.0, temperature: float = 1.0,
duration: float = 30.0, cfg_coef: float = 3.0,
duration: float = 30.0, cfg_coef: float = 3.0,
cfg_coef_beta: tp.Optional[float] = None,
two_step_cfg: bool = False, extend_stride: float = 18,):
"""Set the generation parameters for MusicGen.
Expand All @@ -109,7 +109,7 @@ def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
cfg_coef_beta (float, optional): beta coefficient in double classifier free guidance.
Should be only used for MusicGen melody if we want to push the text condition more than
the melody conditioning. See paragraph 4.3 in https://arxiv.org/pdf/2407.12563 to understand
the audio conditioning. See paragraph 4.3 in https://arxiv.org/pdf/2407.12563 to understand
double CFG.
two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
instead of batching together the two. This has some impact on how things
Expand Down
8 changes: 7 additions & 1 deletion audiocraft/modules/conditioners.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,17 +177,23 @@ def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition:
seek_time=[0] * embed.wav.shape[0],
)

def _drop_text_condition(conditions: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:

def _drop_description_condition(conditions: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
    """Drop the text (description) condition while keeping the wav condition.

    Used to compute l_style in the double classifier free guidance formula,
    see paragraph 4.3 in https://arxiv.org/pdf/2407.12563

    Args:
        conditions (tp.List[ConditioningAttributes]): List of conditions. Every item
            must carry a 'description' text entry and a 'self_wav' wav entry.
    Returns:
        tp.List[ConditioningAttributes]: Output of the `AttributeDropout` module applied
            to `conditions`, with dropout probability 1.0 on the 'description' text
            attribute and 0.0 on the 'self_wav' wav attribute.
    Raises:
        AssertionError: If a condition has no 'description' text entry
            or no 'self_wav' wav entry.
    """
    # Both attributes must be present on every item before the dropout is applied.
    for attributes in conditions:
        assert 'description' in attributes.text
        assert 'self_wav' in attributes.wav
    # Probability 1.0 for the description and 0.0 for the waveform.
    dropout_probs = {
        'text': {'description': 1.0},
        'wav': {'self_wav': 0.0},
    }
    description_dropout = AttributeDropout(p=dropout_probs)
    return description_dropout(conditions)


class Tokenizer:
"""Base tokenizer implementation
(in case we want to introduce more advances tokenizers in the future).
Expand Down

0 comments on commit 7b5b009

Please sign in to comment.