-
Notifications
You must be signed in to change notification settings - Fork 0
/
uncertainty_rnn.py
161 lines (113 loc) · 5.75 KB
/
uncertainty_rnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import torch
from torch import nn
class BoardGuesserNet(nn.Module):
    """Offline board-guessing network: small CNN encoder + 2-layer LSTM over one game.

    Every position of a game is encoded by a convolutional backbone. The
    per-position features are then run through an LSTM as a single sequence
    (batch size 1), and each LSTM step is decoded into an <18, 65> block of
    raw scores.

    NOTE(review): this docstring used to say the network is for "the original
    emission matrix 19 channels", but ``conv1`` has always been built with 16
    input channels, so a 19-channel input raises. ``in_channels`` defaults to 16
    to keep existing checkpoints loadable; pass ``in_channels=19`` for a
    19-channel encoding. For a description of the channels see
    ``create_blank_emission_matrix`` in fen_string_convert.py.
    """

    def __init__(self, in_channels: int = 16):
        """
        :param in_channels: number of channels of each input board encoding
            (default 16, matching the layer shapes this class always had).
        """
        super().__init__()
        # Small convolutional backbone. For 8x8 boards:
        # conv1 (k=3) -> 6x6, conv2 (k=2) -> 5x5, pool2 -> 2x2, giving
        # 64 * 2 * 2 = 256 flattened features, the LSTM input size.
        self.conv1 = torch.nn.Conv2d(in_channels, 32, 3)
        torch.nn.init.xavier_uniform_(self.conv1.weight, gain=1)
        self.relu1 = torch.nn.LeakyReLU()
        # Not applied in forward(): pooling here would shrink the feature map
        # so the flattened size no longer matches the 256-wide LSTM input.
        # Kept (it has no parameters) so the module layout is unchanged.
        self.pool1 = torch.nn.MaxPool2d(2)
        self.conv2 = torch.nn.Conv2d(32, 64, 2)
        torch.nn.init.xavier_uniform_(self.conv2.weight, gain=1)
        self.relu2 = torch.nn.LeakyReLU()
        self.pool2 = torch.nn.MaxPool2d(2)
        self.flatten = torch.nn.Flatten()
        # Two stacked LSTM layers, 256 -> 256.
        self.lstm = torch.nn.LSTM(256, 256, 2, batch_first=True)
        # Decoder: recast each LSTM step to the truth-board encoding.
        self.dense1 = torch.nn.Linear(256, 640)
        torch.nn.init.xavier_uniform_(self.dense1.weight, gain=1)
        self.relu3 = torch.nn.LeakyReLU()
        # dense2 keeps PyTorch's default initialization (no Xavier init here).
        self.dense2 = torch.nn.Linear(640, 18 * 65)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Predict truth-board scores for every position of one game.

        :param state: <N, in_channels, 8, 8> tensor holding the N positions of
            one game, in order (they form one LSTM sequence).
        :return: <N, 18, 65> tensor of raw scores. No softmax is applied, see
            https://discuss.pytorch.org/t/do-i-need-to-use-softmax-before-nn-crossentropyloss/16739
            For the encoding of the output see ``get_truncated_truth_board`` in
            fen_string_convert.py.
        """
        features = self.relu1(self.conv1(state))
        features = self.relu2(self.conv2(features))
        features = self.flatten(self.pool2(features))  # <N, 256>
        # Treat the N positions as one sequence of length N (batch size 1).
        sequence = torch.unsqueeze(features, 0)  # <1, N, 256>
        # nn.LSTM returns (output, (h_n, c_n)); only the per-step output is used.
        lstm_output, _ = self.lstm(sequence)
        per_position = lstm_output.squeeze(0)  # <N, 256>
        scores = self.dense2(self.relu3(self.dense1(per_position)))
        return scores.reshape((state.shape[0], 18, 65))
class BoardGuesserNetOnline(nn.Module):
    """Streaming variant of the board guesser that carries LSTM state between calls.

    Layer names and shapes match ``BoardGuesserNet`` exactly, so a state dict
    saved from that network can be loaded here. ``forward`` accepts and returns
    the LSTM ``(h, c)`` state, so a game can be fed in consecutive chunks
    (e.g. one position at a time) while keeping the recurrent context.
    """

    def __init__(self, in_channels: int = 16):
        """
        :param in_channels: number of channels of each input board encoding
            (default 16, matching the layer shapes this class always had).
        """
        super().__init__()
        # Small convolutional backbone. For 8x8 boards the flattened output is
        # 64 * 2 * 2 = 256 features, the LSTM input size.
        self.conv1 = torch.nn.Conv2d(in_channels, 32, 3)
        torch.nn.init.xavier_uniform_(self.conv1.weight, gain=1)
        self.relu1 = torch.nn.LeakyReLU()
        # Not applied in forward() (it would break the 256-wide LSTM input);
        # kept, parameter-free, so the module layout is unchanged.
        self.pool1 = torch.nn.MaxPool2d(2)
        self.conv2 = torch.nn.Conv2d(32, 64, 2)
        torch.nn.init.xavier_uniform_(self.conv2.weight, gain=1)
        self.relu2 = torch.nn.LeakyReLU()
        self.pool2 = torch.nn.MaxPool2d(2)
        self.flatten = torch.nn.Flatten()
        # Two stacked LSTM layers, 256 -> 256.
        self.lstm = torch.nn.LSTM(256, 256, 2, batch_first=True)
        # Decoder: recast each LSTM step to the truth-board encoding.
        self.dense1 = torch.nn.Linear(256, 640)
        torch.nn.init.xavier_uniform_(self.dense1.weight, gain=1)
        self.relu3 = torch.nn.LeakyReLU()
        # dense2 keeps PyTorch's default initialization (no Xavier init here).
        self.dense2 = torch.nn.Linear(640, 18 * 65)

    def forward(self, state: torch.Tensor, hidden=None):
        """Predict truth-board scores for a chunk of positions, continuing from ``hidden``.

        :param state: <N, in_channels, 8, 8> tensor holding N consecutive
            positions of one game.
        :param hidden: ``(h, c)`` tuple returned by a previous call, each of
            shape <2, 1, 256> (2 LSTM layers, batch size 1), or ``None`` to
            start from a zero state.
        :return: ``(scores, (h_n, c_n))``. ``scores`` is an <N, 18, 65> tensor
            of raw scores (no softmax, see
            https://discuss.pytorch.org/t/do-i-need-to-use-softmax-before-nn-crossentropyloss/16739);
            ``(h_n, c_n)`` is the LSTM state to pass into the next call. For the
            output encoding see ``get_truncated_truth_board`` in
            fen_string_convert.py.
        """
        features = self.relu1(self.conv1(state))
        features = self.relu2(self.conv2(features))
        features = self.flatten(self.pool2(features))  # <N, 256>
        # Treat the N positions as one sequence of length N (batch size 1).
        sequence = torch.unsqueeze(features, 0)  # <1, N, 256>
        # nn.LSTM treats hx=None as a zero initial state, so no branch is needed.
        lstm_output, (h_n, c_n) = self.lstm(sequence, hidden)
        per_position = lstm_output.squeeze(0)  # <N, 256>
        scores = self.dense2(self.relu3(self.dense1(per_position)))
        return scores.reshape((state.shape[0], 18, 65)), (h_n, c_n)
if __name__ == "__main__":
    import io

    # Smoke test: build both networks, check output shapes, and check that the
    # offline network's weights transfer to the online network.
    guessNet = BoardGuesserNet()
    # Size the dummy input from the model itself. conv1 expects 16 channels,
    # so the previous hard-coded 19-channel input crashed the forward pass.
    num_channels = guessNet.conv1.in_channels
    network_input = torch.zeros((50, num_channels, 8, 8))
    network_guess = guessNet(network_input)
    assert network_guess.shape == (50, 18, 65)

    # Round-trip the weights through memory instead of leaving a "test_model"
    # file behind in the working directory.
    buffer = io.BytesIO()
    torch.save(guessNet.state_dict(), buffer)
    buffer.seek(0)

    # Test the online network. With the online network we pass hidden states in manually.
    guessNetOnline = BoardGuesserNetOnline()
    guessNetOnline.load_state_dict(torch.load(buffer))
    network_input_online = torch.zeros((50, num_channels, 8, 8))
    truth_board, hidden_state = guessNetOnline(network_input_online)
    assert truth_board.shape == (50, 18, 65)
    # Same weights, same input and a fresh state must give the offline prediction.
    assert torch.allclose(truth_board, network_guess)
    truth_board2, hidden_state2 = guessNetOnline(network_input_online, hidden_state)
    assert truth_board2.shape == (50, 18, 65)