-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathscaling_model.py
87 lines (73 loc) · 2.9 KB
/
scaling_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# Maps the `activation` constructor argument to the callable applied in each layer.
ACTIVATIONS = {"relu": F.relu, "tanh": torch.tanh, "linear": lambda x: x}


class FinalResNet(nn.Module):
    """Generic, fully configurable residual network with square layers.

    Update rule applied for every layer ``l``::

        x[l+1] = x[l] + delta[l] * sigma(x[l] @ A[l] + b[l])

    where ``A[l]`` has shape ``(dim, dim)`` and ``b[l]`` has shape ``(dim,)``.

    The step size ``delta`` depends on ``delta_type``:

    * ``"shared"`` -- one learnable scalar used by all layers; it is applied
      through its absolute value.
    * ``"multi"`` -- one learnable scalar per layer, applied with its sign.
    * anything else (e.g. ``"none"``) -- no scaling (constant ``1.0``).

    All weights are allocated uninitialised; call :meth:`init_values` or
    :meth:`init_from_data` before using the network.

    Args:
        dim: Width of the network (size of the last axis of the input).
        num_layers: Number of residual layers.
        delta_type: ``'none'``, ``'shared'`` or ``'multi'``.
        initial_sd: Base standard deviation used by :meth:`init_values`.
        activation: Key of ``ACTIVATIONS`` (``'relu'``, ``'tanh'``,
            ``'linear'``).

    Raises:
        KeyError: If ``activation`` is not a key of ``ACTIVATIONS``.
    """

    def __init__(
        self,
        dim,
        num_layers,
        delta_type,  # 'none', 'shared', 'multi',
        initial_sd,
        activation,  # 'relu', 'tanh'
    ):
        super().__init__()
        self.num_layers = num_layers
        self.dim = dim
        self.delta_type = delta_type  # 'none', 'shared', 'multi'
        self.sigma = ACTIVATIONS[activation]  # 'relu', 'tanh'
        self.initial_sd = initial_sd

        # ParameterList (rather than a plain list) registers every tensor with
        # the module, so parameters(), state_dict() and .to() all see them.
        self.A = nn.ParameterList(
            [nn.Parameter(torch.empty(dim, dim)) for _ in range(num_layers)]
        )
        self.b = nn.ParameterList(
            [nn.Parameter(torch.empty(dim)) for _ in range(num_layers)]
        )

        if delta_type == "shared":
            self.delta = nn.Parameter(torch.empty(1))
        elif delta_type == "multi":
            self.delta = nn.ParameterList(
                [nn.Parameter(torch.empty(1)) for _ in range(num_layers)]
            )
        else:
            self.delta = 1.0

    def forward(self, x):
        """Apply all residual layers to ``x`` (last axis of size ``dim``)."""
        for layer in range(self.num_layers):
            x = x + self.update_rule(x, layer)
        return x

    def update_rule(self, x, l):
        """Return the residual increment of layer ``l`` for input ``x``."""
        # matmul equals mm for 2-D input and also accepts 1-D / batched input.
        h = self.sigma(torch.matmul(x, self.A[l]) + self.b[l])
        if self.delta_type == "shared":
            return torch.abs(self.delta) * h
        if self.delta_type == "multi":
            return self.delta[l] * h
        return h

    def init_values(self, method="xavier"):
        """Randomly initialise ``delta``, ``A`` and ``b`` in place.

        ``delta`` is set to ``1 / num_layers`` when shared, or drawn from
        ``N(0, (1 / num_layers)^2)`` per layer when ``delta_type`` is
        ``'multi'``.

        Args:
            method: ``'xavier'`` draws ``A`` with std ``initial_sd / dim`` and
                ``b`` with std ``initial_sd / sqrt(dim)``; ``'xavier-depth'``
                additionally divides both by ``num_layers``.

        Raises:
            ValueError: If ``method`` is not one of the two allowed values.
        """
        # Validate before touching any parameter so that a bad call leaves the
        # model unchanged.
        if method not in ("xavier", "xavier-depth"):
            raise ValueError(
                f"Unknown method {method}. Allowed: 'xavier' and 'xavier-depth'."
            )

        n_layers = self.num_layers
        dim = self.dim
        depth = n_layers if method == "xavier-depth" else 1

        with torch.no_grad():
            if self.delta_type == "shared":
                self.delta.fill_(1 / n_layers)
            elif self.delta_type == "multi":
                for layer in range(n_layers):
                    self.delta[layer].normal_(0, 1.0 / n_layers)

            for layer in range(n_layers):
                self.A[layer].normal_(0, self.initial_sd / (depth * dim))
                self.b[layer].normal_(
                    0, float(self.initial_sd / (depth * np.sqrt(dim)))
                )

    def init_from_data(self, A, b, delta):
        """Load weights from array-likes.

        Args:
            A: Sequence of ``num_layers`` arrays of shape ``(dim, dim)``.
            b: Sequence of ``num_layers`` arrays of shape ``(dim,)``.
            delta: For ``delta_type == 'shared'`` a scalar or one-element
                array; for ``'multi'`` a sequence of ``num_layers`` such
                values; ignored otherwise.
        """
        if self.delta_type == "shared":
            self._assign(self.delta, delta)
        for layer in range(self.num_layers):
            self._assign(self.A[layer], A[layer])
            self._assign(self.b[layer], b[layer])
            if self.delta_type == "multi":
                self._assign(self.delta[layer], delta[layer])

    @staticmethod
    def _assign(param, value):
        """Copy ``value`` into ``param`` in place, keeping its dtype, device and shape."""
        data = torch.as_tensor(value, dtype=param.dtype, device=param.device)
        with torch.no_grad():
            # reshape lets a 0-d scalar fill a shape-(1,) delta parameter.
            param.copy_(data.reshape(param.shape))