## TODO
* [x] Resume training from an existing checkpoint:
  * [x] Save/load CTA
  * [x] Save EMA model
* [ ] DDP:
  * [x] Synchronize CTA across processes
  * [ ] Bug: DDP performance is worse than DP during the first epochs
* [x] Logging to an online platform: W&B
* [ ] Replace PIL augmentations with Albumentations
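```python
import random

import numpy as np
from albumentations.core.transforms_interface import ImageOnlyTransform
# Assumption: module path of `F.blur` depends on the Albumentations version.
import albumentations.augmentations.functional as F


class BlurLimitSampler:
    def __init__(self, blur, weights):
        self.blur = blur  # candidate kernel sizes, e.g. [3, 5, 7]
        self.weights = weights  # matching weights, e.g. [0.1, 0.5, 0.4]

    def get_params(self):
        # random.choice has no `p` argument; random.choices accepts weights.
        return {"ksize": int(random.choices(self.blur, weights=self.weights, k=1)[0])}


class Blur(ImageOnlyTransform):
    def __init__(self, blur_limit, always_apply=False, p=0.5):
        super().__init__(always_apply, p)
        self.blur_limit = blur_limit  # (min, max) pair or a BlurLimitSampler

    def apply(self, image, ksize=3, **params):
        return F.blur(image, ksize)

    def get_params(self):
        if isinstance(self.blur_limit, BlurLimitSampler):
            return self.blur_limit.get_params()
        # Default: uniform over min, min + 2, ... up to max.
        return {"ksize": int(random.choice(np.arange(self.blur_limit[0], self.blur_limit[1] + 1, 2)))}

    def get_transform_init_args_names(self):
        return ("blur_limit",)
```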
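
The weighted kernel-size draw that `BlurLimitSampler` is meant to perform can be checked in isolation, without Albumentations installed. The following is a minimal sketch using only the standard library; the helper name `sample_ksize` and the seeded `rng` are hypothetical and exist only for this illustration.

```python
import random
from collections import Counter


def sample_ksize(sizes, weights, rng=random):
    """Draw one kernel size from `sizes` using the given relative weights."""
    if len(sizes) != len(weights):
        raise ValueError("sizes and weights must have the same length")
    # choices() returns a list of k draws; take the single element.
    return int(rng.choices(sizes, weights=weights, k=1)[0])


# Seeded generator so the empirical check is reproducible.
rng = random.Random(0)
n_draws = 10_000
draws = Counter(sample_ksize([3, 5, 7], [0.1, 0.5, 0.4], rng) for _ in range(n_draws))

# Empirical frequency of each kernel size; should approach the weights.
frequencies = {k: draws[k] / n_draws for k in (3, 5, 7)}
```

With enough draws, `frequencies` should land close to the configured weights, which is a quick way to confirm the sampler is not silently falling back to a uniform choice.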