-
Notifications
You must be signed in to change notification settings - Fork 1
/
model.py
53 lines (44 loc) · 1.68 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
import torch.nn as nn
class SudokuSolver(nn.Module):
def __init__(self, constraint_mask, n=9, hidden1=100):
super(SudokuSolver, self).__init__()
self.constraint_mask = constraint_mask.view(1, n * n, 3, n * n, 1)
self.n = n
self.hidden1 = hidden1
# Feature vector is the 3 constraints
self.input_size = 3 * n
self.l1 = nn.Linear(self.input_size,
self.hidden1, bias=False)
self.a1 = nn.ReLU()
self.l2 = nn.Linear(self.hidden1,
n, bias=False)
self.softmax = nn.Softmax(dim=1)
# x is a (batch, n^2, n) tensor
def forward(self, x):
n = self.n
bts = x.shape[0]
c = self.constraint_mask
min_empty = (x.sum(dim=2) == 0).sum(dim=1).max()
x_pred = x.clone()
for a in range(min_empty):
# score empty numbers
constraints = (x.view(bts, 1, 1, n * n, n) * c).sum(dim=3)
# empty cells
empty_mask = (x.sum(dim=2) == 0)
f = constraints.reshape(bts, n * n, 3 * n)
y_ = self.l2(self.a1(self.l1(f[empty_mask])))
s_ = self.softmax(y_)
# Score the rows
x_pred[empty_mask] = s_
s = torch.zeros_like(x_pred)
s[empty_mask] = s_
# find most probable guess
score, score_pos = s.max(dim=2)
mmax = score.max(dim=1)[1]
# fill it in
nz = empty_mask.sum(dim=1).nonzero().view(-1)
mmax_ = mmax[nz]
ones = torch.ones(nz.shape[0])
x.index_put_((nz, mmax_, score_pos[nz, mmax_]), ones)
return x_pred, x