forked from daixiangzi/PRCV2019
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Cutout.py
32 lines (28 loc) · 948 Bytes
/
Cutout.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import numpy as np
import torch
class Cutout(object):
    """Zero out randomly placed square patches of an image tensor.

    Args:
        n: Number of patches to cut out per call.
        length: Side length, in pixels, of each square patch.
    """

    def __init__(self, n, length):
        self.n = n
        self.length = length

    def __call__(self, img):
        """Return a copy of ``img`` with ``n`` square patches set to zero.

        Args:
            img: Tensor indexed as (C, H, W); dimension 1 is read as the
                height and dimension 2 as the width.

        Returns:
            A new tensor of the same shape as ``img`` (the input is not
            modified). The same mask is applied to every channel. Patches
            whose center lies near a border are clipped to the image, and
            patches may overlap.
        """
        h = img.size(1)
        w = img.size(2)
        mask = np.ones((h, w), np.float32)
        half = self.length // 2
        for _ in range(self.n):
            # Draw x before y so that seeded runs keep choosing the same
            # patch centers as earlier versions of this class.
            x = np.random.randint(w)
            y = np.random.randint(h)
            # Derive the far edge from the near edge plus the full length.
            # Using ``center + length // 2`` would drop one pixel for odd
            # lengths (and mask nothing at all for length == 1).
            left = x - half
            top = y - half
            x1 = np.clip(left, 0, w)
            y1 = np.clip(top, 0, h)
            x2 = np.clip(left + self.length, 0, w)
            y2 = np.clip(top + self.length, 0, h)
            mask[y1:y2, x1:x2] = 0
        # Put the mask on the image's device so non-CPU tensors work too.
        mask = torch.from_numpy(mask).to(img.device)
        # The (H, W) mask is expanded to (C, H, W) before multiplying.
        mask = mask.expand_as(img)
        return img * mask