forked from yeyupiaoling/PaddlePaddle-DeepSpeech
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
104 lines (94 loc) · 5.51 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import argparse
import functools
import io
import os
from datetime import datetime
from model_utils.model import DeepSpeech2Model
from data_utils.data import DataGenerator
from utils.utility import add_arguments, print_arguments, get_data_len
import paddle
# Command-line interface for the training script. Every option is registered
# through the project helper `add_arguments` (imported from utils.utility),
# called positionally with: option name, type, default value, help text.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# --- Device and training schedule ---
# use_gpu: whether to train on the GPU.
add_arg('use_gpu', bool, True, "是否使用GPU训练")
# batch_size: size of each training batch.
add_arg('batch_size', int, 16, "训练每一批数据的大小")
# num_epoch: number of training epochs.
add_arg('num_epoch', int, 50, "训练的轮数")
# --- Model architecture ---
# num_conv_layers: number of convolution layers.
add_arg('num_conv_layers', int, 2, "卷积层数量")
# num_rnn_layers: number of recurrent layers.
add_arg('num_rnn_layers', int, 3, "循环神经网络的数量")
# rnn_layer_size: size of each recurrent layer.
add_arg('rnn_layer_size', int, 1024, "循环神经网络的大小")
# --- Optimisation and data filtering ---
# learning_rate: initial learning rate.
add_arg('learning_rate', float, 5e-4, "初始学习率")
# min_duration / max_duration: shortest / longest audio length used for training.
add_arg('min_duration', float, 0.5, "最短的用于训练的音频长度")
add_arg('max_duration', float, 20.0, "最长的用于训练的音频长度")
# test_off: whether to turn testing off.
add_arg('test_off', bool, False, "是否关闭测试")
# --- Checkpoints ---
# resume_model: checkpoint to resume training from; None disables resuming.
add_arg('resume_model', str, None, "恢复训练,当为None则不使用预训练模型")
# pretrained_model: path of a pretrained model; None means none is used.
add_arg('pretrained_model', str, None, "使用预训练模型的路径,当为None是不使用预训练模型")
# --- Dataset and output paths ---
# train_manifest / test_manifest: data lists for training and testing.
add_arg('train_manifest', str, './dataset/manifest.train', "训练的数据列表")
add_arg('test_manifest', str, './dataset/manifest.test', "测试的数据列表")
# mean_std_path: file holding the dataset mean and standard deviation
# (the help text says "npy" while the default path ends in .npz).
add_arg('mean_std_path', str, './dataset/mean_std.npz', "数据集的均值和标准值的npy文件路径")
# vocab_path: vocabulary file of the dataset.
add_arg('vocab_path', str, './dataset/zh_vocab.txt', "数据集的词汇表文件路径")
# output_model_dir: folder in which trained models are saved.
add_arg('output_model_dir', str, './models/param', "保存训练模型的文件夹")
# augment_conf_path: data-augmentation configuration file, in JSON format.
add_arg('augment_conf_path', str, './conf/augmentation.json', "数据增强的配置文件,为json格式")
# shuffle_method: how the training data is shuffled; restricted to `choices`.
add_arg('shuffle_method', str, 'batch_shuffle_clipped', "打乱数据的方法", choices=['instance_shuffle', 'batch_shuffle', 'batch_shuffle_clipped'])
# NOTE: parsing happens at import time, so importing this module reads the
# process command line.
args = parser.parse_args()
# 训练模型
def train():
    """Build the data readers and the DeepSpeech2 model, then start training.

    All settings are read from the module-level ``args`` namespace. The
    function creates one ``DataGenerator`` for training (with augmentation and
    duration limits) and one for testing, wraps both manifests in batch
    readers, constructs a ``DeepSpeech2Model`` and finally calls its ``train``
    method. Nothing is returned.

    Raises:
        OSError: if ``args.augment_conf_path`` is set but cannot be opened.
    """
    # Select the device to run on.
    place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()

    # Read the augmentation configuration as text; an empty JSON object is
    # used when no path is configured. The context manager guarantees the
    # file handle is closed once the text has been read.
    if args.augment_conf_path is not None:
        with io.open(args.augment_conf_path, mode='r', encoding='utf8') as conf_file:
            augmentation_config = conf_file.read()
    else:
        augmentation_config = '{}'

    # Training data generator.
    train_generator = DataGenerator(vocab_filepath=args.vocab_path,
                                    mean_std_filepath=args.mean_std_path,
                                    augmentation_config=augmentation_config,
                                    max_duration=args.max_duration,
                                    min_duration=args.min_duration,
                                    place=place)

    if args.resume_model:
        # Best effort: when the checkpoint's base name starts with an integer
        # (the part before the first '.'), hand it to the generator as the
        # epoch to continue from. Any other naming scheme is ignored.
        pre_epoch = os.path.basename(args.resume_model).split('.')[0]
        try:
            train_generator.epoch = int(pre_epoch)
        except ValueError:
            # Name does not encode an epoch number; leave the generator as is.
            pass

    # Test data generator: no augmentation, transcription text is kept.
    test_generator = DataGenerator(vocab_filepath=args.vocab_path,
                                   mean_std_filepath=args.mean_std_path,
                                   keep_transcription_text=True,
                                   place=place,
                                   is_training=False)

    # Batch reader over the training manifest.
    train_batch_reader = train_generator.batch_reader_creator(manifest_path=args.train_manifest,
                                                              batch_size=args.batch_size,
                                                              shuffle_method=args.shuffle_method)
    # Batch reader over the test manifest (never shuffled).
    test_batch_reader = test_generator.batch_reader_creator(manifest_path=args.test_manifest,
                                                            batch_size=args.batch_size,
                                                            shuffle_method=None)

    # Build the DeepSpeech2 model.
    ds2_model = DeepSpeech2Model(vocab_size=train_generator.vocab_size,
                                 num_conv_layers=args.num_conv_layers,
                                 num_rnn_layers=args.num_rnn_layers,
                                 rnn_layer_size=args.rnn_layer_size,
                                 place=place,
                                 pretrained_model=args.pretrained_model,
                                 resume_model=args.resume_model,
                                 output_model_dir=args.output_model_dir,
                                 vocab_list=train_generator.vocab_list)

    # Number of training samples, using the same duration limits as above.
    train_num_samples = get_data_len(args.train_manifest, args.max_duration, args.min_duration)
    print("[%s] 训练数据数量:%d\n" % (datetime.now(), train_num_samples))
    # Number of test samples, using the same duration limits.
    test_num_samples = get_data_len(args.test_manifest, args.max_duration, args.min_duration)
    print("[%s] 测试数据数量:%d\n" % (datetime.now(), test_num_samples))

    # Start training.
    ds2_model.train(train_batch_reader=train_batch_reader,
                    dev_batch_reader=test_batch_reader,
                    learning_rate=args.learning_rate,
                    gradient_clipping=400,
                    batch_size=args.batch_size,
                    train_num_samples=train_num_samples,
                    test_num_samples=test_num_samples,
                    num_epoch=args.num_epoch,
                    test_off=args.test_off)
def main():
    """Script entry point.

    Hands the parsed command-line namespace to ``print_arguments`` and then
    runs ``train``.
    """
    print_arguments(args)
    train()
# Start training only when this file is executed as a script.
if __name__ == "__main__":
    main()