positional_embeddings.py
'''Different methods for positional embeddings. These are not essential for
understanding DDPMs, but are relevant for the ablation study.'''
import math

import torch
from torch import nn
class SinusoidalEmbedding(nn.Module):
    """Transformer-style sinusoidal embedding of a (scaled) scalar input."""

    def __init__(self, size: int, scale: float = 1.0):
        super().__init__()
        self.size = size
        self.scale = scale

    def forward(self, x: torch.Tensor):
        x = x * self.scale
        half_size = self.size // 2
        # Frequencies follow the usual geometric schedule 10000^(-i / (half_size - 1)).
        # torch.arange is created on x's device so the module also works on GPU inputs.
        exponent = math.log(10000.0) / (half_size - 1)
        freqs = torch.exp(-exponent * torch.arange(half_size, device=x.device))
        emb = x.unsqueeze(-1) * freqs.unsqueeze(0)
        # Concatenate sine and cosine components: output shape (..., size).
        emb = torch.cat((torch.sin(emb), torch.cos(emb)), dim=-1)
        return emb

    def __len__(self):
        return self.size
class LinearEmbedding(nn.Module):
    """Embeds the input as a single scalar feature, rescaled by scale / size."""

    def __init__(self, size: int, scale: float = 1.0):
        super().__init__()
        self.size = size
        self.scale = scale

    def forward(self, x: torch.Tensor):
        x = x / self.size * self.scale
        return x.unsqueeze(-1)

    def __len__(self):
        # The output is one-dimensional regardless of `size`.
        return 1
class LearnableEmbedding(nn.Module):
    """Learns the embedding with a linear layer on the normalized scalar input."""

    def __init__(self, size: int):
        super().__init__()
        self.size = size
        self.linear = nn.Linear(1, size)

    def forward(self, x: torch.Tensor):
        return self.linear(x.unsqueeze(-1).float() / self.size)

    def __len__(self):
        return self.size
class IdentityEmbedding(nn.Module):
    """Passes the input through unchanged, as a single scalar feature."""

    def forward(self, x: torch.Tensor):
        return x.unsqueeze(-1)

    def __len__(self):
        return 1
class ZeroEmbedding(nn.Module):
    """Always returns zeros; a baseline for the ablation study."""

    def forward(self, x: torch.Tensor):
        return x.unsqueeze(-1) * 0

    def __len__(self):
        return 1
class PositionalEmbedding(nn.Module):
    """Factory wrapper that selects one of the embedding types above by name."""

    def __init__(self, size: int, type: str, **kwargs):
        super().__init__()
        if type == "sinusoidal":
            self.layer = SinusoidalEmbedding(size, **kwargs)
        elif type == "linear":
            self.layer = LinearEmbedding(size, **kwargs)
        elif type == "learnable":
            self.layer = LearnableEmbedding(size)
        elif type == "zero":
            self.layer = ZeroEmbedding()
        elif type == "identity":
            self.layer = IdentityEmbedding()
        else:
            raise ValueError(f"Unknown positional embedding type: {type}")

    def forward(self, x: torch.Tensor):
        return self.layer(x)
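

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal demo of how these embeddings might be applied to DDPM timesteps,
# printing the output shape of each variant. The embedding size (128) and the
# batch of 8 timesteps are arbitrary assumptions chosen for the demo.
if __name__ == "__main__":
    timesteps = torch.arange(8)  # a hypothetical batch of diffusion timesteps
    for emb_type in ("sinusoidal", "linear", "learnable", "zero", "identity"):
        embedding = PositionalEmbedding(size=128, type=emb_type)
        out = embedding(timesteps)
        print(f"{emb_type:>10}: shape {tuple(out.shape)}, len {len(embedding.layer)}")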