# pt_layers.py
import torch
from torch import nn
from torch.nn import functional as F

import lc0_az_policy_map


class SqueezeExcitation(nn.Module):
    # Tested as equivalent to the TF layer
    def __init__(self, channels, se_ratio):
        super().__init__()
        self.se_ratio = se_ratio
        self.pooler = nn.AdaptiveAvgPool2d(1)
        self.squeeze = nn.Sequential(
            nn.Linear(channels, int(channels // se_ratio), bias=False), nn.ReLU()
        )
        self.expand = nn.Linear(int(channels // se_ratio), channels * 2, bias=False)
        self.channels = channels
        nn.init.xavier_normal_(self.squeeze[0].weight)
        nn.init.xavier_normal_(self.expand.weight)

    def forward(self, x):
        pooled = self.pooler(x).view(-1, self.channels)
        squeezed = self.squeeze(pooled)
        expanded = self.expand(squeezed).view(-1, self.channels * 2, 1, 1)
        gammas, betas = torch.split(expanded, self.channels, dim=1)
        gammas = torch.sigmoid(gammas)
        return gammas * x + betas
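
# Usage sketch (not part of the original module): the SE block pools each channel
# over the board, then predicts a per-channel scale (gamma, squashed through a
# sigmoid) and shift (beta) that are broadcast back over the 8x8 spatial
# dimensions. The channel count and ratio below are illustrative assumptions:
#
#   se = SqueezeExcitation(channels=64, se_ratio=8)   # squeeze to 64 // 8 = 8 units
#   y = se(torch.randn(2, 64, 8, 8))                  # y.shape == (2, 64, 8, 8)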


class ConvBlock(nn.Module):
    def __init__(self, input_channels, output_channels, filter_size):
        super().__init__()
        self.conv_layer = nn.Conv2d(
            input_channels, output_channels, filter_size, bias=False, padding="same"
        )
        # Flag presumably read by the training loop (e.g. for lc0-style weight clipping).
        self.conv_layer.weight.clamp_weights = True
        self.batchnorm = nn.BatchNorm2d(output_channels, affine=True)
        nn.init.xavier_normal_(self.conv_layer.weight)

    def forward(self, inputs):
        out = self.conv_layer(inputs)
        # Cast to float32 so batchnorm runs in full precision (e.g. under mixed precision).
        out = self.batchnorm(out.float())
        return F.relu(out)


class ResidualBlock(nn.Module):
    def __init__(self, channels, se_ratio):
        super().__init__()
        self.conv1 = nn.Conv2d(
            channels,
            channels,
            3,
            bias=False,
            padding="same",
        )
        self.conv1.weight.clamp_weights = True
        self.batch_norm = nn.BatchNorm2d(channels, affine=True)
        self.conv2 = nn.Conv2d(
            channels,
            channels,
            3,
            bias=False,
            padding="same",
        )
        self.conv2.weight.clamp_weights = True
        nn.init.xavier_normal_(self.conv1.weight)
        nn.init.xavier_normal_(self.conv2.weight)
        self.squeeze_excite = SqueezeExcitation(channels, se_ratio)

    def forward(self, inputs):
        out1 = self.conv1(inputs)
        out1 = F.relu(self.batch_norm(out1.float()))
        out2 = self.conv2(out1)
        out2 = self.squeeze_excite(out2)
        return F.relu(inputs + out2)


class ConvolutionalPolicyHead(nn.Module):
    def __init__(self, num_filters):
        super().__init__()
        self.conv_block = ConvBlock(
            filter_size=3, input_channels=num_filters, output_channels=num_filters
        )
        # No l2_reg on the final convolution, because it's not going to be followed by a batchnorm
        self.conv = nn.Conv2d(num_filters, 80, 3, bias=True, padding="same")
        nn.init.xavier_normal_(self.conv.weight)
        # Fixed (non-trainable) mapping from the 80x8x8 policy planes to the
        # 1858-entry lc0 policy vector.
        self.fc1 = nn.parameter.Parameter(
            torch.tensor(
                lc0_az_policy_map.make_map(), requires_grad=False, dtype=torch.float32
            ),
            requires_grad=False,
        )

    def forward(self, inputs):
        flow = self.conv_block(inputs)
        flow = self.conv(flow)
        h_conv_pol_flat = flow.reshape(-1, 80 * 8 * 8)
        return h_conv_pol_flat @ self.fc1.type(h_conv_pol_flat.dtype)
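
# Shape sketch (illustrative; assumes make_map() returns an (80 * 8 * 8, 1858)
# array, which is what the matmul above requires):
#
#   head = ConvolutionalPolicyHead(num_filters=64)
#   logits = head(torch.randn(2, 64, 8, 8))   # logits.shape == (2, 1858)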


class DensePolicyHead(nn.Module):
    def __init__(self, input_dim, hidden_dim=128):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        # No l2_reg on the final layer, because it's not going to be followed by a batchnorm
        self.fc_final = nn.Linear(hidden_dim, 1858)
        nn.init.xavier_normal_(self.fc1.weight)
        nn.init.xavier_normal_(self.fc_final.weight)

    def forward(self, inputs):
        # Flatten input before proceeding
        inputs = inputs.reshape(inputs.shape[0], -1)
        out = F.relu(self.fc1(inputs))
        return self.fc_final(out)


class ConvolutionalValueOrMovesLeftHead(nn.Module):
    def __init__(self, input_dim, output_dim, num_filters, hidden_dim, relu):
        super().__init__()
        self.num_filters = num_filters
        self.conv_block = ConvBlock(
            input_channels=input_dim, filter_size=1, output_channels=num_filters
        )
        # No l2_reg on the final layers, because they're not going to be followed by a batchnorm
        self.fc2 = nn.Linear(self.num_filters * 8 * 8, hidden_dim, bias=True)
        self.fc_out = nn.Linear(hidden_dim, output_dim, bias=True)
        self.relu = relu
        nn.init.xavier_normal_(self.fc_out.weight)

    def forward(self, inputs):
        flow = self.conv_block(inputs)
        flow = flow.reshape(-1, self.num_filters * 8 * 8)
        flow = self.fc2(flow)
        flow = F.relu(flow)
        flow = self.fc_out(flow)
        if self.relu:
            flow = F.relu(flow)
        return flow


class DenseValueOrMovesLeftHead(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim, relu):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, output_dim)
        self.relu = relu
        nn.init.xavier_normal_(self.fc1.weight)
        nn.init.xavier_normal_(self.fc_out.weight)

    def forward(self, inputs):
        if inputs.dim() > 2:
            # Flatten input before proceeding
            inputs = inputs.reshape(inputs.shape[0], -1)
        flow = F.relu(self.fc1(inputs))
        flow = self.fc_out(flow)
        if self.relu:
            flow = F.relu(flow)
        return flow
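

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original file). The channel
    # counts, head sizes, and the 112 input planes below are illustrative
    # assumptions, not values taken from any particular lc0 training config.
    channels, se_ratio = 64, 8
    board = torch.randn(4, 112, 8, 8)

    # Small trunk: input convolution followed by two SE residual blocks.
    trunk = nn.Sequential(
        ConvBlock(input_channels=112, output_channels=channels, filter_size=3),
        ResidualBlock(channels, se_ratio),
        ResidualBlock(channels, se_ratio),
    )
    policy_head = ConvolutionalPolicyHead(num_filters=channels)
    value_head = ConvolutionalValueOrMovesLeftHead(
        input_dim=channels, output_dim=3, num_filters=32, hidden_dim=128, relu=False
    )
    moves_left_head = ConvolutionalValueOrMovesLeftHead(
        input_dim=channels, output_dim=1, num_filters=8, hidden_dim=64, relu=True
    )

    features = trunk(board)
    print(policy_head(features).shape)      # torch.Size([4, 1858])
    print(value_head(features).shape)       # torch.Size([4, 3])
    print(moves_left_head(features).shape)  # torch.Size([4, 1])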