From 0890c4b8ef577121425dfbb65d62d86012bdfa05 Mon Sep 17 00:00:00 2001
From: Enno Hermann
Date: Sun, 19 Jan 2025 20:33:55 +0100
Subject: [PATCH] fix(xtts): overwrite model_args in GPTTrainerConfig

This makes it possible to use --continue_path with XTTS
---
 TTS/tts/layers/xtts/trainer/gpt_trainer.py | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/TTS/tts/layers/xtts/trainer/gpt_trainer.py b/TTS/tts/layers/xtts/trainer/gpt_trainer.py
index 0a8af2f950..a8611b594f 100644
--- a/TTS/tts/layers/xtts/trainer/gpt_trainer.py
+++ b/TTS/tts/layers/xtts/trainer/gpt_trainer.py
@@ -23,16 +23,6 @@
 logger = logging.getLogger(__name__)
 
 
-@dataclass
-class GPTTrainerConfig(XttsConfig):
-    lr: float = 5e-06
-    training_seed: int = 1
-    optimizer_wd_only_on_weights: bool = False
-    weighted_loss_attrs: dict = field(default_factory=lambda: {})
-    weighted_loss_multipliers: dict = field(default_factory=lambda: {})
-    test_sentences: list[dict] = field(default_factory=lambda: [])
-
-
 @dataclass
 class GPTArgs(XttsArgs):
     min_conditioning_length: int = 66150
@@ -51,6 +41,17 @@ class GPTArgs(XttsArgs):
     vocoder: str = ""  # overide vocoder key on the config to avoid json write issues
 
 
+@dataclass
+class GPTTrainerConfig(XttsConfig):
+    lr: float = 5e-06
+    training_seed: int = 1
+    optimizer_wd_only_on_weights: bool = False
+    weighted_loss_attrs: dict = field(default_factory=lambda: {})
+    weighted_loss_multipliers: dict = field(default_factory=lambda: {})
+    test_sentences: list[dict] = field(default_factory=lambda: [])
+    model_args: GPTArgs = field(default_factory=GPTArgs)
+
+
 def callback_clearml_load_save(operation_type, model_info):
     # return None means skip the file upload/log, returning model_info will continue with the log/upload
     # you can also change the upload destination file name model_info.upload_filename or check the local file size with Path(model_info.local_model_path).stat().st_size
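
Usage note (a sketch, not part of the patch): declaring model_args as a
GPTArgs field on GPTTrainerConfig lets the config round-trip through the
config.json that the trainer writes into the run directory, which is what
--continue_path reads back when resuming. A minimal sketch of that
round-trip, assuming Coqpit's save_json/load_json deserialize the nested
dataclass by its declared type; the file name is arbitrary:

    from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainerConfig

    # Before this patch, GPTTrainerConfig inherited `model_args: XttsArgs`
    # from XttsConfig, so GPT-specific fields were lost when the saved config
    # was reloaded. With the override, the nested field restores as GPTArgs.
    config = GPTTrainerConfig()
    config.model_args.min_conditioning_length = 11025  # any GPTArgs field

    config.save_json("config.json")  # what the trainer persists in the run dir

    restored = GPTTrainerConfig()
    restored.load_json("config.json")
    assert isinstance(restored.model_args, GPTArgs)
    assert restored.model_args.min_conditioning_length == 11025

Resuming then works as usual from a recipe script, e.g.
`python train_gpt_xtts.py --continue_path <run_dir>` (script name assumed).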