-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathvalidation.py
79 lines (61 loc) · 1.93 KB
/
validation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import functools

import numpy as np
import torch
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from shapely.geometry.point import Point
from skimage.draw import circle_perimeter_aa

from network import Net
def draw_circle(img, row, col, rad):
    """Write an anti-aliased circle outline into ``img`` in place.

    The outline pixels come from ``circle_perimeter_aa``; any that fall
    outside the first two dimensions of ``img`` are dropped. Pixel values
    are overwritten (not blended) with the anti-aliasing intensities.
    Returns None.
    """
    rows, cols, intensity = circle_perimeter_aa(row, col, rad)
    height = img.shape[0]
    width = img.shape[1]
    # Mask out perimeter points that lie beyond the image borders.
    inside = (rows >= 0) & (rows < height) & (cols >= 0) & (cols < width)
    img[rows[inside], cols[inside]] = intensity[inside]
def noisy_circle(size, radius, noise):
    """Generate a square image with one random circle outline plus noise.

    Args:
        size: side length of the square image in pixels; the circle centre
            (row, col) is drawn uniformly from [0, size).
        radius: exclusive upper bound for the circle radius, which is drawn
            uniformly from [10, radius). Values <= 10 yield a radius of 10.
        noise: scale of the additive uniform noise, drawn from
            [0, noise) per pixel.

    Returns:
        ((row, col, rad), img) - the ground-truth circle parameters and the
        float64 image of shape (size, size).
    """
    # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
    img = np.zeros((size, size), dtype=np.float64)

    # Circle: the centre may lie near an edge, in which case draw_circle
    # clips the outline to the image.
    row = np.random.randint(size)
    col = np.random.randint(size)
    # randint's upper bound is exclusive, so keep it above the lower bound
    # (10); max(10, radius) raised ValueError for radius <= 10.
    rad = np.random.randint(10, max(11, radius))
    draw_circle(img, row, col, rad)

    # Noise
    img += noise * np.random.rand(*img.shape)
    return (row, col, rad), img
@functools.lru_cache(maxsize=None)
def _load_model(path='model.pth.tar'):
    """Build a Net, load its weights from ``path`` and return it in eval mode.

    Cached per path so repeated find_circle calls share one model instead of
    re-reading the checkpoint from disk on every image.
    """
    model = Net()
    # map_location='cpu' lets a checkpoint saved with CUDA tensors load on a
    # CPU-only machine; the Net built above lives on the CPU either way.
    # NOTE(review): torch.load unpickles the file - only load trusted
    # checkpoints.
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint)
    model.eval()
    return model


def find_circle(img):
    """Predict the circle parameters in ``img`` with the trained Net.

    Args:
        img: image array-like; presumably 2-D like noisy_circle's output
            (the single-channel Normalize below relies on that - TODO
            confirm for other callers).

    Returns:
        list of three ints: the model's outputs multiplied by 200 and
        rounded, in the (row, col, rad) order that iou() expects.
        NOTE(review): the factor 200 presumably undoes a normalisation used
        in training on 200x200 images - confirm before using other sizes.
    """
    model = _load_model()
    with torch.no_grad():
        # Add a leading channel axis and convert to a float32 tensor.
        image = np.expand_dims(np.asarray(img), axis=0)
        image = torch.from_numpy(np.array(image, dtype=np.float32))
        normalize = transforms.Normalize(mean=[0.5], std=[0.5])
        image = normalize(image)
        # Add the batch axis expected by the model.
        image = image.unsqueeze(0)
        output = model(image)
        return [round(i) for i in (200 * output).tolist()[0]]
def iou(params0, params1):
    """Return the intersection-over-union of two circles.

    Each argument is a (row, col, rad) triple. Circles are approximated by
    shapely's polygonal ``buffer`` around the centre point, so the result
    is the IoU of those polygons.
    """
    disks = [Point(r, c).buffer(radius) for r, c, radius in (params0, params1)]
    first, second = disks
    overlap = first.intersection(second).area
    combined = first.union(second).area
    return overlap / combined
def main():
    """Score find_circle on 1000 random noisy images and print the hit rate.

    Each trial draws a 200x200 image (radius bound 75, noise scale 1),
    runs the detector and computes the IoU against the ground truth. The
    printed value is the fraction of trials whose IoU exceeds 0.7.
    """
    def _trial():
        truth, image = noisy_circle(200, 75, 1)
        return iou(truth, find_circle(image))

    scores = np.array([_trial() for _ in range(1000)])
    print((scores > 0.7).mean())
if __name__ == '__main__':
    # Run the validation only when executed as a script, not on import.
    main()