forked from soprof/face-identification-tpe
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path1_train_cnn.py
82 lines (65 loc) · 2.02 KB
/
1_train_cnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import json
import os.path
import numpy as np
from keras.callbacks import ModelCheckpoint
from keras.preprocessing.image import ImageDataGenerator
from sklearn.preprocessing import OneHotEncoder
from cnn import build_cnn
# --- Training configuration -------------------------------------------------
# Directory that receives the weight files. The trailing slash matters: file
# names are appended to this string by plain concatenation elsewhere.
WEIGHTS_DIR = 'data/weights/'
NB_EPOCH = 100       # passed to Keras as `epochs`
BATCH_SIZE = 32      # passed to Keras as `batch_size`
AUGMENTATION = True  # True: train on ImageDataGenerator output; False: raw arrays

# Create the weights directory (and any missing parents). `exist_ok=True`
# replaces the separate exists()/mkdir() pair, which was racy and failed when
# the parent `data/` directory did not exist yet.
os.makedirs(WEIGHTS_DIR, exist_ok=True)
# --- Load data and encode labels --------------------------------------------
# Read the train/test arrays from the .npy files under data/.
train_x = np.load('data/train_x.npy')
train_y = np.load('data/train_y.npy')
test_x = np.load('data/test_x.npy')
test_y = np.load('data/test_y.npy')

# Dataset statistics. These must be computed BEFORE the labels are replaced by
# their one-hot form below. np.unique counts distinct labels without boxing
# every element into a Python set and works for any label array shape.
n_subjects = np.unique(train_y).size
n_train = train_x.shape[0]
n_test = test_x.shape[0]

# One-hot encode the labels. The encoder is fitted on the training labels only
# and then applied to both splits, so both share the same column layout.
oh = OneHotEncoder()
oh.fit(train_y.reshape(-1, 1))

# toarray() yields a regular ndarray; todense() would return a numpy.matrix,
# a discouraged type with matrix-specific indexing/arithmetic semantics.
train_y = oh.transform(train_y.reshape(-1, 1)).toarray()
test_y = oh.transform(test_y.reshape(-1, 1)).toarray()

print('n_train: {}'.format(n_train))
print('n_test: {}'.format(n_test))
print('n_subjects: {}'.format(n_subjects))

# Record the number of classes in data/meta.json.
with open('data/meta.json', 'w') as f:
    json.dump({'n_subjects': n_subjects}, f)
# --- Model, checkpointing and resume ----------------------------------------
# A single file serves two purposes: the checkpoint callback writes to it, and
# an existing copy from a previous run is loaded to resume training.
weights_to_load = WEIGHTS_DIR + 'weights.best.h5'

# Overwrite the file only when the monitored validation accuracy improves.
mc1 = ModelCheckpoint(
    weights_to_load,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    verbose=0,
)

# NOTE(review): 227 is presumably the input image side length expected by
# build_cnn -- confirm against cnn.build_cnn.
model = build_cnn(227, n_subjects)
model.summary()

# Resume from earlier weights when a checkpoint file is already present.
if os.path.exists(weights_to_load):
    model.load_weights(weights_to_load)
# --- Training ----------------------------------------------------------------
# The try/finally guarantees that the most recent weights are written to disk
# even when training is interrupted or raises.
try:
    if AUGMENTATION:
        # Random geometric perturbations applied on the fly to training images.
        data_gen = ImageDataGenerator(
            rotation_range=20,
            width_shift_range=0.2,
            height_shift_range=0.2,
            zoom_range=0.1,
            horizontal_flip=True
        )

        # steps_per_epoch counts BATCHES, not samples. Passing the sample
        # count made every "epoch" BATCH_SIZE passes over the data. Use
        # ceil(n_samples / BATCH_SIZE) so one epoch is one pass.
        batches_per_epoch = (train_x.shape[0] + BATCH_SIZE - 1) // BATCH_SIZE

        # NOTE(review): fit_generator is deprecated in newer Keras releases in
        # favour of fit(); confirm the installed version before switching.
        model.fit_generator(
            data_gen.flow(train_x, train_y, batch_size=BATCH_SIZE),
            steps_per_epoch=batches_per_epoch,
            epochs=NB_EPOCH,
            validation_data=(test_x, test_y),
            callbacks=[mc1]
        )
    else:
        # No augmentation: train directly on the in-memory arrays.
        model.fit(
            train_x, train_y,
            batch_size=BATCH_SIZE,
            epochs=NB_EPOCH,
            validation_data=(test_x, test_y),
            callbacks=[mc1]
        )
finally:
    # Saved separately from the best-checkpoint file written by the callback.
    model.save_weights(WEIGHTS_DIR + 'weights.finally.h5')