-
Notifications
You must be signed in to change notification settings - Fork 3
/
model_A.py
28 lines (19 loc) · 823 Bytes
/
model_A.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from keras import applications
from keras.engine import Model
from keras.layers import Flatten, Dense, BatchNormalization, Activation, Dropout
class Model_A(object):
    """Factory for a ResNet50-based image classifier.

    The network built by ``create_model`` has two parts:
    - an ImageNet-pretrained ResNet50 without its top classifier, globally
      average-pooled;
    - a small head: Dropout -> Dense -> BatchNormalization -> softmax.
    """

    def __init__(self) -> None:
        super().__init__()

    def create_model(self, image_width, image_height, num_classes,
                     dropout_rate=0.8, freeze_base=False):
        """Build and return the (uncompiled) Keras classification model.

        Args:
            image_width: Input image width in pixels.
            image_height: Input image height in pixels.
            num_classes: Number of output classes (units of the final Dense
                layer / softmax).
            dropout_rate: Fraction of pooled features dropped during training.
                Defaults to 0.8, the value previously hard-coded here.
            freeze_base: If True, mark every ResNet50 layer as non-trainable
                so only the new head is trained. Defaults to False, which
                matches the original behavior of fine-tuning the whole network.

        Returns:
            A ``Model`` mapping ResNet50's input to softmax class
            probabilities.
        """
        # Keras image tensors are (height, width, channels) under the default
        # "channels_last" data format. Passing (width, height, 3) would
        # silently transpose the expected input for non-square images.
        base = applications.ResNet50(
            weights="imagenet",
            include_top=False,
            pooling="avg",
            input_shape=(image_height, image_width, 3),
        )

        if freeze_base:
            for layer in base.layers:
                layer.trainable = False

        # pooling="avg" already yields a (batch, features) tensor, so no
        # Flatten layer is needed before the classifier head.
        x = base.output
        x = Dropout(dropout_rate)(x)
        x = Dense(num_classes)(x)
        # NOTE(review): batch-normalizing the logits right before softmax is
        # unusual; kept to preserve the original architecture.
        x = BatchNormalization()(x)
        predictions = Activation('softmax')(x)
        return Model(base.input, predictions)