-
Notifications
You must be signed in to change notification settings - Fork 1
/
RoshamboNet.py
60 lines (50 loc) · 1.77 KB
/
RoshamboNet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from tensorflow.keras import layers
def RoshamboNetPruningPolicy(pruning_policy):
    """Resolve the pruning target for RoshamboNet from a policy mapping.

    Only the ``"fixed"`` mode is supported; in that mode the policy's
    ``"target"`` entry is returned unchanged.

    Args:
        pruning_policy: Mapping with a ``"mode"`` key and, for ``"fixed"``
            mode, a ``"target"`` key.

    Returns:
        ``pruning_policy["target"]`` when ``pruning_policy["mode"]`` is
        ``"fixed"``.

    Raises:
        KeyError: If ``"mode"`` is missing, or ``"target"`` is missing in
            fixed mode.
        AttributeError: If the mode is anything other than ``"fixed"``.
    """
    mode = pruning_policy["mode"]
    if mode != "fixed":
        # AttributeError (rather than ValueError) is kept so existing callers
        # that catch it keep working; the message now names the bad mode.
        raise AttributeError(
            "Unsupported RoshamboNet pruning mode {!r}; only 'fixed' is "
            "supported".format(mode)
        )
    return pruning_policy["target"]
def RoshamboNet(
    input_tensor,
    classes=4,
    include_top=True,
    pooling="max",
    num_3x3_blocks=3,
    **kwargs
):
    """Build the RoshamboNet convolutional graph on top of ``input_tensor``.

    Every convolution uses ReLU activation and ``'valid'`` padding. Each
    convolution is followed by a 2x2, stride-2 pooling layer whose type is
    selected by ``pooling``. The layers are:

    * ``layer1``: 5x5 conv with 16 filters, then ``pool1``.
    * ``layer2`` .. ``layer{num_3x3_blocks + 1}``: 3x3 convs that double the
      filter count at each block, each followed by ``pool{k}``.
    * ``layer{num_3x3_blocks + 2}``: 1x1 conv that keeps the last filter
      count, then ``pool{num_3x3_blocks + 2}``.
    * If ``include_top``: a flatten layer, then a softmax ``Dense`` layer
      named ``layer{num_3x3_blocks + 3}`` with ``classes`` units.

    Args:
        input_tensor: Keras tensor to build the graph on. Presumably an
            image batch; its spatial size must survive the valid-padding
            convs and repeated 2x pooling (TODO confirm expected input size).
        classes: Number of units in the softmax classifier head.
        include_top: Whether to append the flatten + softmax head.
        pooling: ``"max"`` or ``"avg"``. Selects the pooling layer type used
            after every convolution.
        num_3x3_blocks: Number of 3x3 conv + pool blocks.
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        The softmax output tensor if ``include_top``, otherwise the feature
        map produced by the last pooling layer.

    Raises:
        ValueError: If ``pooling`` is neither ``"max"`` nor ``"avg"``.
    """
    if pooling == "max":
        pool_layer = layers.MaxPooling2D
    elif pooling == "avg":
        pool_layer = layers.AveragePooling2D
    else:
        raise ValueError(
            "Unsupported pooling {!r}; expected 'max' or 'avg'".format(pooling)
        )

    # Block 1: 5x5 stem convolution.
    num_out_channels = 16
    x = layers.Conv2D(num_out_channels, (5, 5),
                      activation='relu',
                      padding='valid',
                      name='layer1')(input_tensor)
    x = pool_layer((2, 2), strides=(2, 2), name='pool1')(x)

    # 3x3 blocks: layer/pool indices continue from 2; filters double each time.
    for block_idx in range(num_3x3_blocks):
        num_out_channels = num_out_channels * 2
        layer_idx = block_idx + 2
        x = layers.Conv2D(num_out_channels, (3, 3),
                          activation='relu',
                          padding='valid',
                          name="layer{}".format(layer_idx))(x)
        x = pool_layer((2, 2), strides=(2, 2),
                       name="pool{}".format(layer_idx))(x)

    # 1x1 block: keeps the channel count reached by the last 3x3 block.
    layer_idx = num_3x3_blocks + 2
    x = layers.Conv2D(num_out_channels, (1, 1),
                      activation='relu',
                      padding='valid',
                      name="layer{}".format(layer_idx))(x)
    x = pool_layer((2, 2), strides=(2, 2),
                   name="pool{}".format(layer_idx))(x)

    # Block FC: optional classifier head, named after the last conv layer.
    if include_top:
        x = layers.Flatten()(x)
        x = layers.Dense(classes,
                         name="layer{}".format(layer_idx + 1),
                         activation="softmax")(x)
    return x