forked from elitehacker0802/Train_TTS
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_training_pipeline.py
139 lines (121 loc) · 4.34 KB
/
run_training_pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import argparse
import os
import random
import sys
import torch
from TrainingInterfaces.TrainingPipelines.Avocodo_combined import run as hifi_codo
from TrainingInterfaces.TrainingPipelines.BigVGAN_combined import run as bigvgan
from TrainingInterfaces.TrainingPipelines.FastSpeech2Embedding_IntegrationTest import (
run as fs_integration_test,
)
from TrainingInterfaces.TrainingPipelines.GST_FastSpeech2 import run as embedding
from TrainingInterfaces.TrainingPipelines.StochasticToucanTTS_Nancy import (
run as nancystoch,
)
from TrainingInterfaces.TrainingPipelines.ToucanTTS_IntegrationTest import (
run as tt_integration_test,
)
from TrainingInterfaces.TrainingPipelines.ToucanTTS_MetaCheckpoint import run as meta
from TrainingInterfaces.TrainingPipelines.ToucanTTS_Nancy import run as nancy
from TrainingInterfaces.TrainingPipelines.finetuning_example import (
run as fine_tuning_example,
)
from TrainingInterfaces.TrainingPipelines.custom import run as custom
from TrainingInterfaces.TrainingPipelines.pretrain_aligner import run as aligner
# Registry mapping each CLI pipeline name to its training entry point
# (every value is a `run` function imported from TrainingInterfaces above).
pipeline_dict = dict(
    # the finetuning example
    custom=custom,
    fine_ex=fine_tuning_example,
    # integration tests
    fs_it=fs_integration_test,
    tt_it=tt_integration_test,
    # regular ToucanTTS pipelines
    nancy=nancy,
    nancystoch=nancystoch,
    meta=meta,
    # training vocoders (not recommended, best to use provided checkpoint)
    avocodo=hifi_codo,
    bigvgan=bigvgan,
    # training the GST embedding jointly with FastSpeech 2 on expressive data (not recommended, best to use provided checkpoint)
    embedding=embedding,
    # training the aligner from scratch (not recommended, best to use provided checkpoint)
    aligner=aligner,
)
if __name__ == "__main__":
    # Command-line entry point: parse arguments, pin the visible GPU (or fall
    # back to CPU), seed all RNGs, then dispatch to the selected pipeline.
    parser = argparse.ArgumentParser(
        description="Training with the IMS Toucan Speech Synthesis Toolkit"
    )
    parser.add_argument(
        "pipeline", choices=list(pipeline_dict.keys()), help="Select pipeline to train."
    )
    parser.add_argument(
        "--gpu_id",
        type=str,
        help="Which GPU to run on. If not specified runs on CPU, but other than for integration tests that doesn't make much sense.",
        default="cpu",
    )
    parser.add_argument(
        "--resume_checkpoint",
        type=str,
        help="Path to checkpoint to resume from.",
        default=None,
    )
    parser.add_argument(
        "--resume",
        action="store_true",
        help="Automatically load the highest checkpoint and continue from there.",
        default=False,
    )
    parser.add_argument(
        "--finetune",
        action="store_true",
        help="Whether to fine-tune from the specified checkpoint.",
        default=False,
    )
    parser.add_argument(
        "--model_save_dir",
        type=str,
        help="Directory where the checkpoints should be saved to.",
        default=None,
    )
    parser.add_argument(
        "--wandb",
        action="store_true",
        help="Whether to use weights and biases to track training runs. Requires you to run wandb login and place your auth key before.",
        default=False,
    )
    parser.add_argument(
        "--wandb_resume_id",
        type=str,
        help="ID of a stopped wandb run to continue tracking",
        default=None,
    )
    args = parser.parse_args()

    # Fine-tuning needs a starting checkpoint unless auto-resume will find one.
    # sys.exit(message) prints to stderr and exits with status 1, so wrapping
    # scripts can detect the usage error (a bare sys.exit() would exit 0).
    if args.finetune and args.resume_checkpoint is None and not args.resume:
        sys.exit("Need to provide path to checkpoint to fine-tune from!")

    # Restrict CUDA device visibility via environment variables before any
    # CUDA context is created; the pipelines themselves receive gpu_id.
    if args.gpu_id == "cpu":
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
        print("No GPU specified, using CPU. Training will likely not work without GPU.")
    else:
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.gpu_id}"
        print(
            f"Making GPU {os.environ['CUDA_VISIBLE_DEVICES']} the only visible device."
        )

    # Seed every RNG source for reproducibility. torch.manual_seed already
    # seeds torch's CPU generator (and all CUDA devices), so a separate
    # torch.random.manual_seed call would be redundant.
    torch.manual_seed(131714)
    random.seed(131714)

    # Dispatch to the selected pipeline's run() function.
    pipeline_dict[args.pipeline](
        gpu_id=args.gpu_id,
        resume_checkpoint=args.resume_checkpoint,
        resume=args.resume,
        finetune=args.finetune,
        model_dir=args.model_save_dir,
        use_wandb=args.wandb,
        wandb_resume_id=args.wandb_resume_id,
    )