-
Notifications
You must be signed in to change notification settings - Fork 18
/
wav_kan.py
232 lines (202 loc) · 9 KB
/
wav_kan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# code modified from https://github.com/zavareh1/Wav-KAN
class WavKANLinear(nn.Module):
    """Wavelet-based KAN layer followed by 1-D batch normalization.

    Every (output, input) pair owns a learnable scale, translation and
    wavelet weight.  The input is shifted and scaled per pair, passed through
    the selected mother wavelet, weighted, and summed over the input axis.

    Args:
        in_features: Number of input features.
        out_features: Number of output features.
        wavelet_type: One of ``'mexican_hat'``, ``'morlet'``, ``'dog'``,
            ``'meyer'``, ``'shannon'`` or ``'bump'``.  The value is only
            checked when ``wavelet_transform`` runs.
    """

    def __init__(self, in_features, out_features, wavelet_type='mexican_hat'):
        super(WavKANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.wavelet_type = wavelet_type

        # Parameters for the wavelet transformation, one per (output, input) pair.
        self.scale = nn.Parameter(torch.ones(out_features, in_features))
        self.translation = nn.Parameter(torch.zeros(out_features, in_features))

        # weight1 is not used by forward(); it stays registered so the set of
        # parameters is unchanged.  It could be used for weighting a base
        # activation and adding it, as in the Spl-KAN paper.
        self.weight1 = nn.Parameter(torch.empty(out_features, in_features))
        # Linear weights for combining the wavelet responses.
        self.wavelet_weights = nn.Parameter(torch.empty(out_features, in_features))

        # Initialisation order is kept (wavelet_weights first) so that seeded
        # runs draw the same random values as before.
        nn.init.kaiming_uniform_(self.wavelet_weights, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))

        # Base activation function; not used for this experiment.
        self.base_activation = nn.SiLU()

        # Batch normalization over the output features.
        self.bn = nn.BatchNorm1d(out_features)

    def _mother_wavelet(self, x_scaled):
        """Evaluate the configured mother wavelet element-wise.

        Args:
            x_scaled: Shifted and scaled input tensor.

        Returns:
            Tensor of the same shape as ``x_scaled``.

        Raises:
            ValueError: If ``self.wavelet_type`` is not a supported name.
        """
        if self.wavelet_type == 'mexican_hat':
            polynomial = (x_scaled ** 2) - 1
            gaussian = torch.exp(-0.5 * x_scaled ** 2)
            return (2 / (math.sqrt(3) * math.pi ** 0.25)) * polynomial * gaussian

        if self.wavelet_type == 'morlet':
            omega0 = 5.0  # Central frequency
            carrier = torch.cos(omega0 * x_scaled)
            envelope = torch.exp(-0.5 * x_scaled ** 2)
            return envelope * carrier

        if self.wavelet_type == 'dog':
            # Derivative of Gaussian wavelet.
            return -x_scaled * torch.exp(-0.5 * x_scaled ** 2)

        if self.wavelet_type == 'meyer':
            v = torch.abs(x_scaled)
            pi = math.pi

            def nu(t):
                # Polynomial used inside the cosine taper between 1/2 and 1.
                return t ** 4 * (35 - 84 * t + 70 * t ** 2 - 20 * t ** 3)

            # Taper: 1 for v <= 1/2, 0 for v >= 1, cosine blend in between.
            taper = torch.where(
                v <= 1 / 2,
                torch.ones_like(v),
                torch.where(
                    v >= 1,
                    torch.zeros_like(v),
                    torch.cos(pi / 2 * nu(2 * v - 1)),
                ),
            )
            return torch.sin(pi * v) * taper

        if self.wavelet_type == 'shannon':
            pi = math.pi
            # torch.sinc(t) = sin(pi*t) / (pi*t), so this evaluates sin(x) / x.
            sinc = torch.sinc(x_scaled / pi)
            # Hamming window, sized by the last axis of x_scaled, to limit the
            # infinite support of the sinc function.
            window = torch.hamming_window(
                x_scaled.size(-1),
                periodic=False,
                dtype=x_scaled.dtype,
                device=x_scaled.device,
            )
            return sinc * window

        if self.wavelet_type == 'bump':
            # The bump wavelet is only defined on the open interval (-1, 1).
            inside_interval = (x_scaled > -1.0) & (x_scaled < 1.0)
            # Evaluate the exponent only where it is defined.  Computing
            # exp(-1 / (1 - x^2)) everywhere and masking afterwards overflows
            # to inf just outside the interval (inf * 0 = NaN) and produces
            # NaN gradients at |x| = 1.
            safe_denominator = torch.where(
                inside_interval,
                1 - x_scaled ** 2,
                torch.ones_like(x_scaled),
            )
            return torch.where(
                inside_interval,
                torch.exp(-1.0 / safe_denominator),
                torch.zeros_like(x_scaled),
            )

        raise ValueError("Unsupported wavelet type")

    def wavelet_transform(self, x):
        """Apply the weighted wavelet transform to ``x``.

        Args:
            x: Input tensor.  A 2-D input gets a singleton axis inserted at
                position 1 so it broadcasts against the per-pair parameters;
                any other rank is used as given.

        Returns:
            The weighted wavelet responses summed over ``dim=2``.

        Raises:
            ValueError: If ``self.wavelet_type`` is not a supported name.
        """
        x_expanded = x.unsqueeze(1) if x.dim() == 2 else x

        batch = x.size(0)
        translation_expanded = self.translation.unsqueeze(0).expand(batch, -1, -1)
        scale_expanded = self.scale.unsqueeze(0).expand(batch, -1, -1)
        x_scaled = (x_expanded - translation_expanded) / scale_expanded

        wavelet = self._mother_wavelet(x_scaled)
        wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0).expand_as(wavelet)
        return wavelet_weighted.sum(dim=2)

    def forward(self, x):
        """Return the batch-normalized wavelet transform of ``x``.

        The base-activation branch (as in Spl-KAN) is intentionally disabled,
        so ``weight1`` and ``base_activation`` do not contribute to the output
        and no base projection is computed.
        """
        wavelet_output = self.wavelet_transform(x)
        # Apply batch normalization.
        return self.bn(wavelet_output)
class Mexican_Hat_KAN(nn.Module):
    """Stack of ``WavKANLinear`` layers, defaulting to the Mexican-hat wavelet.

    Args:
        layers_hidden: Sequence of layer widths.  Each consecutive pair
            ``(layers_hidden[i], layers_hidden[i + 1])`` becomes the
            ``(in_features, out_features)`` of one ``WavKANLinear`` layer.
        grid_size: Placeholder; accepted but not used.
        spline_order: Placeholder; accepted but not used.
        wavelet_type: Wavelet name passed to every ``WavKANLinear`` layer.
    """

    def __init__(self, layers_hidden,
                 grid_size=5,  # placeholder
                 spline_order=0.,  # placeholder
                 wavelet_type='mexican_hat'):
        # Initialise nn.Module once; a second call would only reset the
        # freshly created module state.
        super().__init__()
        self.layers = nn.ModuleList([
            WavKANLinear(
                in_features=in_dim,
                out_features=out_dim,
                wavelet_type=wavelet_type,
            )
            for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])
        ])

    def forward(self, x):
        """Pass ``x`` through every layer in order and return the result."""
        for layer in self.layers:
            x = layer(x)
        return x
class Morlet_KAN(nn.Module):
    """Stack of ``WavKANLinear`` layers, defaulting to the Morlet wavelet.

    Args:
        layers_hidden: Sequence of layer widths.  Each consecutive pair
            ``(layers_hidden[i], layers_hidden[i + 1])`` becomes the
            ``(in_features, out_features)`` of one ``WavKANLinear`` layer.
        grid_size: Placeholder; accepted but not used.
        spline_order: Placeholder; accepted but not used.
        wavelet_type: Wavelet name passed to every ``WavKANLinear`` layer.
    """

    def __init__(self, layers_hidden,
                 grid_size=5,  # placeholder
                 spline_order=0.,  # placeholder
                 wavelet_type='morlet'):
        # Initialise nn.Module once; a second call would only reset the
        # freshly created module state.
        super().__init__()
        self.layers = nn.ModuleList([
            WavKANLinear(
                in_features=in_dim,
                out_features=out_dim,
                wavelet_type=wavelet_type,
            )
            for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])
        ])

    def forward(self, x):
        """Pass ``x`` through every layer in order and return the result."""
        for layer in self.layers:
            x = layer(x)
        return x
class Dog_KAN(nn.Module):
    """Stack of ``WavKANLinear`` layers, defaulting to the derivative-of-Gaussian wavelet.

    Args:
        layers_hidden: Sequence of layer widths.  Each consecutive pair
            ``(layers_hidden[i], layers_hidden[i + 1])`` becomes the
            ``(in_features, out_features)`` of one ``WavKANLinear`` layer.
        grid_size: Placeholder; accepted but not used.
        spline_order: Placeholder; accepted but not used.
        wavelet_type: Wavelet name passed to every ``WavKANLinear`` layer.
    """

    def __init__(self, layers_hidden,
                 grid_size=5,  # placeholder
                 spline_order=0.,  # placeholder
                 wavelet_type='dog'):
        # Initialise nn.Module once; a second call would only reset the
        # freshly created module state.
        super().__init__()
        self.layers = nn.ModuleList([
            WavKANLinear(
                in_features=in_dim,
                out_features=out_dim,
                wavelet_type=wavelet_type,
            )
            for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])
        ])

    def forward(self, x):
        """Pass ``x`` through every layer in order and return the result."""
        for layer in self.layers:
            x = layer(x)
        return x
class Meyer_KAN(nn.Module):
    """Stack of ``WavKANLinear`` layers, defaulting to the Meyer wavelet.

    Args:
        layers_hidden: Sequence of layer widths.  Each consecutive pair
            ``(layers_hidden[i], layers_hidden[i + 1])`` becomes the
            ``(in_features, out_features)`` of one ``WavKANLinear`` layer.
        grid_size: Placeholder; accepted but not used.
        spline_order: Placeholder; accepted but not used.
        wavelet_type: Wavelet name passed to every ``WavKANLinear`` layer.
    """

    def __init__(self, layers_hidden,
                 grid_size=5,  # placeholder
                 spline_order=0.,  # placeholder
                 wavelet_type='meyer'):
        # Initialise nn.Module once; a second call would only reset the
        # freshly created module state.
        super().__init__()
        self.layers = nn.ModuleList([
            WavKANLinear(
                in_features=in_dim,
                out_features=out_dim,
                wavelet_type=wavelet_type,
            )
            for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])
        ])

    def forward(self, x):
        """Pass ``x`` through every layer in order and return the result."""
        for layer in self.layers:
            x = layer(x)
        return x
class Shannon_KAN(nn.Module):
    """Stack of ``WavKANLinear`` layers, defaulting to the Shannon wavelet.

    Args:
        layers_hidden: Sequence of layer widths.  Each consecutive pair
            ``(layers_hidden[i], layers_hidden[i + 1])`` becomes the
            ``(in_features, out_features)`` of one ``WavKANLinear`` layer.
        grid_size: Placeholder; accepted but not used.
        spline_order: Placeholder; accepted but not used.
        wavelet_type: Wavelet name passed to every ``WavKANLinear`` layer.
    """

    def __init__(self, layers_hidden,
                 grid_size=5,  # placeholder
                 spline_order=0.,  # placeholder
                 wavelet_type='shannon'):
        # Initialise nn.Module once; a second call would only reset the
        # freshly created module state.
        super().__init__()
        self.layers = nn.ModuleList([
            WavKANLinear(
                in_features=in_dim,
                out_features=out_dim,
                wavelet_type=wavelet_type,
            )
            for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])
        ])

    def forward(self, x):
        """Pass ``x`` through every layer in order and return the result."""
        for layer in self.layers:
            x = layer(x)
        return x
class Bump_KAN(nn.Module):
    """Stack of ``WavKANLinear`` layers, defaulting to the bump wavelet.

    Args:
        layers_hidden: Sequence of layer widths.  Each consecutive pair
            ``(layers_hidden[i], layers_hidden[i + 1])`` becomes the
            ``(in_features, out_features)`` of one ``WavKANLinear`` layer.
        grid_size: Placeholder; accepted but not used.
        spline_order: Placeholder; accepted but not used.
        wavelet_type: Wavelet name passed to every ``WavKANLinear`` layer.
    """

    def __init__(self, layers_hidden,
                 grid_size=5,  # placeholder
                 spline_order=0.,  # placeholder
                 wavelet_type='bump'):
        # Initialise nn.Module once; a second call would only reset the
        # freshly created module state.
        super().__init__()
        self.layers = nn.ModuleList([
            WavKANLinear(
                in_features=in_dim,
                out_features=out_dim,
                wavelet_type=wavelet_type,
            )
            for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])
        ])

    def forward(self, x):
        """Pass ``x`` through every layer in order and return the result."""
        for layer in self.layers:
            x = layer(x)
        return x