-
Notifications
You must be signed in to change notification settings - Fork 0
/
transforms.py
135 lines (90 loc) · 3.28 KB
/
transforms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
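"""
Data augmentations for 1D signals, plus wrappers that turn one sample into
several augmented views (for BYOL-style multi-view training).
"""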
import numpy as np
import torch
import random
# for BYOL
class MultiViewDataInjector:
    """Apply several transforms to one sample and stack the resulting views.

    Each transform produces one view; the views are stacked along a new
    leading dimension so the result can be fed as a multi-view batch.
    """

    def __init__(self, transform_list):
        """
        :param transform_list: a single callable, or a list/tuple of callables,
            each mapping a sample to a tensor. A single callable is wrapped
            into a one-element list.
        """
        if isinstance(transform_list, tuple):
            # Tuples are sequences of transforms too; previously they were
            # treated as one (non-callable) transform.
            transform_list = list(transform_list)
        elif not isinstance(transform_list, list):
            transform_list = [transform_list]
        self.transform_list = transform_list

    def __call__(self, sample):
        """Return a tensor of shape ``(len(transform_list), *view_shape)``.

        All transforms must return tensors of the same shape; ``torch.stack``
        raises otherwise.
        """
        views = [transform(sample) for transform in self.transform_list]
        return torch.stack(views, dim=0)
class TwoCropsTransform:
    """Produce two independently augmented views of one input (query and key).

    Despite the name, no cropping happens here: the same base transform is
    applied twice. When that transform is random, the two views generally
    differ.
    """

    def __init__(self, base_transform=None):
        """
        :param base_transform: callable used to produce each view. Defaults
            to a new ``BaseTransform()`` instance.
        """
        self.base_transform = BaseTransform() if base_transform is None else base_transform

    def __call__(self, x):
        """Return ``(q, k)``: two separate applications of the base transform to ``x``."""
        q = self.base_transform(x)
        k = self.base_transform(x)
        return q, k
class RandomResizedCrop(torch.nn.Module):
    """Randomly crop a contiguous window from the last dimension of a signal.

    The window length is ``length // k`` with ``k`` drawn uniformly from
    ``[min_factor, max_factor]`` (inclusive), clamped to at least one
    sample. Despite the name, the crop is NOT resized back to the original
    length.
    """

    def __init__(self, min_factor=2, max_factor=10):
        """
        :param min_factor: smallest divisor of the signal length (largest crop).
        :param max_factor: largest divisor of the signal length (smallest crop).
        :raises ValueError: if ``min_factor < 1`` or ``max_factor < min_factor``.
        """
        super().__init__()
        if min_factor < 1 or max_factor < min_factor:
            raise ValueError("min_factor must be >= 1 and max_factor >= min_factor")
        self.min_factor = min_factor
        self.max_factor = max_factor

    def forward(self, signal):
        max_samples = signal.shape[-1]
        if max_samples == 0:
            # Nothing to crop from; return the (empty) signal unchanged.
            return signal
        divisor = random.randint(self.min_factor, self.max_factor)
        # Integer division reaches 0 for short signals; keep at least one sample
        # so the crop is never empty.
        n_samples = max(1, max_samples // divisor)
        start_idx = random.randint(0, max_samples - n_samples)
        return signal[..., start_idx:start_idx + n_samples]
class Flip(torch.nn.Module):
    """Reverse the order of elements along a tensor's last dimension."""

    def __init__(self):
        super().__init__()

    def forward(self, signal):
        # Only the final axis is reversed; leading axes keep their order.
        return signal.flip(-1)
class Noise(torch.nn.Module):
    """Add zero-mean Gaussian noise scaled relative to the signal's spread.

    On every call a ratio ``r`` is drawn uniformly from
    ``[min_snr, max_snr)`` and the noise standard deviation is set to
    ``r * std(signal)``. ``r`` is therefore a noise-to-signal ratio of
    standard deviations (larger values mean MORE noise), despite the
    parameter names.
    """

    def __init__(self, min_snr=0.0001, max_snr=0.01):
        """
        :param min_snr: Lower bound of the noise-std / signal-std ratio.
        :param max_snr: Upper bound of the noise-std / signal-std ratio.
        """
        super().__init__()
        self.min_snr = min_snr
        self.max_snr = max_snr

    def forward(self, signal):
        std = torch.std(signal)
        random_factor = torch.rand(1) * (self.max_snr - self.min_snr) + self.min_snr
        # Match the signal's device and dtype so the output is not promoted
        # to the default float dtype (e.g. for half-precision inputs).
        random_factor = random_factor.to(device=std.device, dtype=std.dtype)
        noise_std = random_factor * std
        # randn_like draws directly on the signal's device with its dtype,
        # avoiding a CPU allocation followed by a host-to-device copy.
        noise = torch.randn_like(signal) * noise_std
        return signal + noise
class BaseTransform(torch.nn.Module):
    """Default single-view augmentation pipeline.

    The pipeline currently contains only additive Gaussian noise, applied
    with probability 0.7. A random flip (p=0.4) is constructed and stored
    as an attribute but is not part of the pipeline.
    """

    def __init__(self):
        super().__init__()
        self.random_noise = RandomApply([Noise()], 0.7)
        # NOTE(review): built but never added to ``compose`` — confirm this is intended.
        self.random_flip = RandomApply([Flip()], 0.4)
        stages = [self.random_noise]  # a RandomResizedCrop stage was left out
        self.compose = torch.nn.Sequential(*stages)

    def forward(self, x):
        augmented = self.compose(x)
        return augmented
class RandomApply(torch.nn.Module):
    """Apply a list of transforms, as a group, with probability ``p``.

    One uniform sample is drawn per call. If it exceeds ``p``, the input is
    returned untouched; otherwise every transform is applied in order.
    """

    def __init__(self, transforms, p=0.5):
        super().__init__()
        # Stored as a plain list, so the transforms are not registered as submodules.
        self.transforms = transforms
        self.p = p

    def forward(self, img):
        skip = torch.rand(1) > self.p
        if skip:
            return img
        out = img
        for transform in self.transforms:
            out = transform(out)
        return out

    def __repr__(self):
        parts = [self.__class__.__name__ + "(", "\n p={}".format(self.p)]
        for t in self.transforms:
            parts.append("\n")
            parts.append(" {0}".format(t))
        parts.append("\n)")
        return "".join(parts)
def get_transforms_list():
    """Return new random-augmentation modules: noise (p=0.7), then flip (p=0.4)."""
    noise_stage = RandomApply([Noise()], 0.7)
    flip_stage = RandomApply([Flip()], 0.4)
    return [noise_stage, flip_stage]