Skip to content

Commit

Permalink
fix(xtts): overwrite model_args in GPTTrainerConfig
Browse files Browse the repository at this point in the history
This makes it possible to use --continue_path with XTTS
  • Loading branch information
eginhard committed Jan 19, 2025
1 parent 420a02f commit 0890c4b
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions TTS/tts/layers/xtts/trainer/gpt_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,6 @@
logger = logging.getLogger(__name__)


@dataclass
class GPTTrainerConfig(XttsConfig):
lr: float = 5e-06
training_seed: int = 1
optimizer_wd_only_on_weights: bool = False
weighted_loss_attrs: dict = field(default_factory=lambda: {})
weighted_loss_multipliers: dict = field(default_factory=lambda: {})
test_sentences: list[dict] = field(default_factory=lambda: [])


@dataclass
class GPTArgs(XttsArgs):
min_conditioning_length: int = 66150
Expand All @@ -51,6 +41,17 @@ class GPTArgs(XttsArgs):
vocoder: str = "" # overide vocoder key on the config to avoid json write issues


@dataclass
class GPTTrainerConfig(XttsConfig):
    """Training-time configuration for the XTTS GPT trainer.

    Extends ``XttsConfig`` with optimizer/training options and, crucially,
    overrides ``model_args`` with ``GPTArgs`` so a serialized config
    round-trips correctly (per the commit message, this is what makes
    ``--continue_path`` work with XTTS).
    """

    lr: float = 5e-06  # learning rate
    training_seed: int = 1  # RNG seed for reproducible runs
    optimizer_wd_only_on_weights: bool = False  # restrict weight decay to weight tensors
    # Use the constructors directly as default factories — no lambda needed.
    weighted_loss_attrs: dict = field(default_factory=dict)
    weighted_loss_multipliers: dict = field(default_factory=dict)
    test_sentences: list[dict] = field(default_factory=list)
    # Overwrite the base class's model_args so checkpoints saved with this
    # config deserialize with the GPT-specific args (needed for --continue_path).
    model_args: GPTArgs = field(default_factory=GPTArgs)


def callback_clearml_load_save(operation_type, model_info):
# return None means skip the file upload/log, returning model_info will continue with the log/upload
# you can also change the upload destination file name model_info.upload_filename or check the local file size with Path(model_info.local_model_path).stat().st_size
Expand Down

0 comments on commit 0890c4b

Please sign in to comment.