-
Notifications
You must be signed in to change notification settings - Fork 4
/
MLproject
104 lines (100 loc) · 4.2 KB
/
MLproject
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# MLflow project definition (MLproject file) for the notes_generator project.
name: notes_generator
# Conda environment file that describes the dependencies for the entry points.
conda_env: conda.yaml
# Named commands of this project; each one declares its parameters and the
# shell command those parameters are substituted into.
entry_points:
# Entry point `fetch`: runs scripts/fetch.sh with data_dir as its only
# argument, using whatever interpreter $SHELL names in the caller's
# environment.
# NOTE(review): the script itself is not visible here; presumably it downloads
# the raw data into data_dir -- confirm in scripts/fetch.sh. Also confirm that
# $SHELL is always set (and bash-compatible) where this is run.
fetch:
  parameters:
    # Typed `string` (not `path`), unlike the same-named parameter of
    # `preprocess`.
    data_dir:
      type: string
      default: data
  command: >
    $SHELL scripts/fetch.sh {data_dir}
# Entry point `preprocess`: runs scripts/preprocess.sh with data_dir as its
# only argument, using whatever interpreter $SHELL names in the caller's
# environment.
# NOTE(review): what the script produces is not visible here -- confirm in
# scripts/preprocess.sh.
preprocess:
  parameters:
    # Typed `path`, so MLflow hands the script an absolute local path.
    data_dir:
      type: path
      default: data
  command: >
    $SHELL scripts/preprocess.sh {data_dir}
# Entry point `train`: runs scripts/onsets_train.py with the repository root
# on PYTHONPATH, forwarding every parameter below as a `--name=value` option.
#
# MLproject parameters can only be typed string, float, path or uri, so the
# integer-valued options (epochs, batch, seq_length, ...) are declared float.
# NOTE(review): resume, with_beats, send_model and is_parallel default to 0/1
# and look like numeric on/off switches -- confirm in the training script.
# NOTE(review): MLflow resolves `path` values to absolute local paths and
# rejects local paths that do not exist, so model_dir, score_dir and mel_dir
# presumably have to exist before the run starts -- confirm.
train:
  parameters:
    app_name: { type: string, default: STEPMANIA }
    model_dir: { type: path, default: data/onsets_model }
    score_dir: { type: path, default: data/train_data/score_onsets_1 }
    mel_dir: { type: path, default: data/train_data/mel_log }
    resume: { type: float, default: 0 }
    epochs: { type: float, default: 35 }
    batch: { type: float, default: 32 }
    # Exponent literals without a decimal point are a string under YAML 1.1
    # but a float under YAML 1.2; quoting them makes every loader substitute
    # exactly this text into the command.
    lr_start: { type: float, default: "9e-4" }
    lr_end: { type: float, default: "9e-4" }
    lr_scheduler: { type: string, default: CosineAnnealingLR }
    seq_length: { type: float, default: 20480 }
    aug_count: { type: float, default: 0 }
    num_layers: { type: float, default: 2 }
    onset_weight: { type: float, default: 64 }
    dropout: { type: float, default: 0.3 }
    fuzzy_width: { type: float, default: 5 }
    fuzzy_scale: { type: float, default: 0.2 }
    with_beats: { type: float, default: 1 }
    difficulties: { type: string, default: "" }
    send_model: { type: float, default: 0 }
    n_saved_model: { type: float, default: 20 }
    log_artifacts: { type: string, default: "" }
    augmentation_setting: { type: string, default: loader_aug_config.yaml }
    warmup_steps: { type: float, default: 400 }
    weight_decay: { type: float, default: 0 }
    is_parallel: { type: float, default: 0 }
    # Quoted for the same reason as lr_start / lr_end above.
    eta_min: { type: float, default: "1e-6" }
    conv_stack_type: { type: string, default: "v7" }
    pretrained_model_path: { type: string, default: "" }
    rnn_dropout: { type: float, default: 0.1 }
  # Folded scalar: the lines below are joined with single spaces into one
  # shell command, so they must all stay at the same indentation.
  command: >
    PYTHONPATH=. python scripts/onsets_train.py --app_name={app_name}
    --model_dir={model_dir} --score_dir={score_dir} --mel_dir={mel_dir}
    --resume={resume} --epochs={epochs} --batch={batch}
    --lr_start={lr_start} --lr_end={lr_end} --seq_length={seq_length}
    --aug_count={aug_count} --num_layers={num_layers}
    --onset_weight={onset_weight} --dropout={dropout}
    --fuzzy_width={fuzzy_width} --fuzzy_scale={fuzzy_scale}
    --with_beats={with_beats} --difficulties={difficulties}
    --send_model={send_model} --n_saved_model={n_saved_model}
    --log_artifacts={log_artifacts}
    --augmentation_setting={augmentation_setting}
    --lr_scheduler={lr_scheduler}
    --warmup_steps={warmup_steps}
    --weight_decay={weight_decay}
    --is_parallel={is_parallel}
    --eta_min={eta_min}
    --conv_stack_type={conv_stack_type}
    --pretrained_model_path={pretrained_model_path}
    --rnn_dropout={rnn_dropout}
# Entry point `test`: runs scripts/model_test.py with the repository root on
# PYTHONPATH. model_dir is passed as the positional argument; every other
# parameter is forwarded as a `--name=value` option.
# MLproject has no integer type, so whole-number options are declared float.
test:
  parameters:
    model_dir:
      type: path
      default: data/onsets_model
    app_name:
      type: string
      default: STEPMANIA
    score_dir:
      type: path
      default: data/train_data/score_onsets_1
    mel_dir:
      type: path
      default: data/train_data/mel_log
    seq_length:
      type: float
      default: 20480
    batch:
      type: float
      default: 1
    num_layers:
      type: float
      default: 2
    onset_weight:
      type: float
      default: 64
    with_beats:
      type: float
      default: 1
    conv_stack_type:
      type: string
      default: "v7"
    # NOTE(review): presumably the directory result CSVs are written to;
    # confirm in scripts/model_test.py.
    csv_save_dir:
      type: path
      default: data/model_test_result
  # Folded scalar: the lines are joined with single spaces into one command.
  command: >
    PYTHONPATH=. python scripts/model_test.py {model_dir}
    --app_name={app_name}
    --score_dir={score_dir}
    --mel_dir={mel_dir}
    --seq_length={seq_length}
    --batch={batch}
    --num_layers={num_layers}
    --onset_weight={onset_weight}
    --with_beats={with_beats}
    --conv_stack_type={conv_stack_type}
    --csv_save_dir={csv_save_dir}
# Entry point `generate`: runs scripts/prediction_stepmania.py with the
# repository root on PYTHONPATH.
# audio_path and bpm_info declare no default, so the caller has to supply
# both (e.g. `-P audio_path=... -P bpm_info=...`).
generate:
  parameters:
    # Forwarded to the script as --onset_model_path.
    model_path:
      type: path
      default: pretrained_model/model.pth
    audio_path:
      type: path
    midi_save_path:
      type: path
      default: data/midi_out
    # NOTE(review): the expected format of this string is not visible here;
    # confirm in scripts/prediction_stepmania.py.
    bpm_info:
      type: string
  # Folded scalar: the lines are joined with single spaces into one command.
  command: >
    PYTHONPATH=. python scripts/prediction_stepmania.py
    --onset_model_path={model_path} --audio_path={audio_path}
    --midi_save_path={midi_save_path} --bpm_info={bpm_info}