PAE.py
import torch
from torch import nn
from torch.nn import Linear as Lin


class PAE(torch.nn.Module):
    """Encodes a pair of feature vectors and scores their similarity in [0, 1]."""

    def __init__(self, input_dim, dropout=0.2):
        super().__init__()
        hidden = 128
        # Shared MLP that embeds each input vector into a hidden space.
        self.parser = nn.Sequential(
            nn.Linear(input_dim, hidden, bias=True),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(hidden),
            nn.Dropout(dropout),
            nn.Linear(hidden, hidden, bias=True),
        )
        self.cos = nn.CosineSimilarity(dim=1, eps=1e-8)
        self.input_dim = input_dim
        self.model_init()

    def forward(self, x):
        # Split the concatenated pair into its two feature vectors.
        x1 = x[:, 0:self.input_dim]
        x2 = x[:, self.input_dim:]
        # Embed both vectors with the shared parser network.
        h1 = self.parser(x1)
        h2 = self.parser(x2)
        # Map cosine similarity from [-1, 1] to a score in [0, 1].
        p = (self.cos(h1, h2) + 1) * 0.5
        return p

    def model_init(self):
        # Kaiming initialisation for all linear layers; biases start at zero.
        for m in self.modules():
            if isinstance(m, Lin):
                torch.nn.init.kaiming_normal_(m.weight)
                m.weight.requires_grad = True
                if m.bias is not None:
                    m.bias.data.zero_()
                    m.bias.requires_grad = True
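
A minimal usage sketch (not part of the original file): the batch size, feature dimension, and random inputs below are illustrative assumptions, chosen only to show the expected input layout of two concatenated feature vectors per row.

if __name__ == "__main__":
    # Hypothetical values; the real input_dim depends on the dataset's feature vectors.
    input_dim = 10
    model = PAE(input_dim=input_dim, dropout=0.2)
    model.eval()  # disable dropout and use running BatchNorm statistics for a deterministic pass

    # Each row concatenates two feature vectors of length input_dim.
    x = torch.randn(4, 2 * input_dim)
    with torch.no_grad():
        scores = model(x)
    print(scores.shape)  # torch.Size([4]); each entry is a pairwise similarity score in [0, 1]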