forked from seujung/WaveNet-gluon
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
28 lines (23 loc) · 983 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import mxnet as mx
import argparse
from trainer import Train
def _str2bool(value):
    """Convert a command-line string such as ``"False"`` or ``"yes"`` to a bool.

    ``type=bool`` cannot be used for this: argparse would call
    ``bool("False")``, which is ``True`` because every non-empty string is
    truthy, so ``--use_gpu False`` would silently keep the flag enabled.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean
            spelling; argparse turns this into a normal usage error.
    """
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false), got %r' % (value,))


def main(argv=None):
    """Parse command-line options, train the model, and optionally generate.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) keeps the original
            behaviour of reading ``sys.argv``.

    The parsed namespace is passed unchanged to ``Train``. ``train()`` is
    always called; ``generation()`` is called only when ``--generation``
    is true.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--mu', type=int, default=128)
    parser.add_argument('--n_residue', type=int, default=24)
    parser.add_argument('--n_skip', type=int, default=128)
    parser.add_argument('--dilation_depth', type=int, default=10)
    parser.add_argument('--n_repeat', type=int, default=2)
    parser.add_argument('--seq_size', type=int, default=20000)
    # Boolean switches keep the original ``--flag VALUE`` syntax
    # (e.g. ``--use_gpu False``) but now actually honour a false value.
    # The non-string default ``True`` is used as-is by argparse.
    parser.add_argument('--use_gpu', type=_str2bool, default=True)
    parser.add_argument('--generation', type=_str2bool, default=True)
    parser.add_argument('--input', type=str, default='parametric-2.wav')
    config = parser.parse_args(argv)

    trainer = Train(config)
    trainer.train()
    if config.generation:
        trainer.generation()
if __name__ == '__main__':
    # Run the training entry point only when executed as a script,
    # not when this module is imported.
    main()