-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathunet_visualization.py
74 lines (59 loc) · 2.47 KB
/
unet_visualization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# Visualize the segmentation masks that the trained U-Net predicts for the test CXR images.
import os
import glob
import torch
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from cxr_mask_dataset import CXRMaskDataset
from res_unet_model import ResnetUNet, UNet2
from PIL import Image
# --- Runtime configuration -------------------------------------------------
# Directory that receives one predicted-mask image per test image.
OUTPUT_PATH = os.path.join('data', 'predicted', 'masks')
# Prefer the GPU when available; everything below also runs on CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using {DEVICE}')
os.makedirs(OUTPUT_PATH, exist_ok=True)

test_cxr_folder = os.path.join('data', 'processed', 'test', 'imgs')
# Paths of all test CXR images. Sorted so the images are processed (and
# logged) in a reproducible order. Only regular files are kept so a stray
# subdirectory cannot crash Image.open later.
test_cxrs = sorted(
    path
    for path in glob.glob(os.path.join(test_cxr_folder, '*'))
    if os.path.isfile(path)
)

# Converts an image array of shape (H, W[, C]) into a float tensor (C, H, W).
# ToTensor scales uint8 input to [0, 1]. NOTE(review): other dtypes
# (e.g. 16-bit CXRs) are not scaled the same way — confirm the input format.
cxr_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.ConvertImageDtype(torch.float),
])
# --- Load the trained segmentation model ------------------------------------
# NOTE: an earlier variant used ResnetUNet(3) with the 'unet20.pt' checkpoint.
resunet = UNet2()
# map_location remaps the checkpoint's tensors onto DEVICE. Without it, a
# checkpoint saved from a GPU fails to load on a CPU-only machine, even though
# DEVICE falls back to 'cpu'.
resunet.load_state_dict(torch.load('unet_val.pt', map_location=DEVICE))
resunet.to(DEVICE)
# Evaluation mode: changes the behavior of layers such as dropout and
# batch-norm, if the model contains any.
resunet.eval()
# --- Inference: predict and save a mask for every test image ----------------
with torch.no_grad():  # no autograd bookkeeping needed for inference
    total = len(test_cxrs)
    for count, test_cxr in enumerate(test_cxrs, start=1):
        fname = os.path.basename(test_cxr)
        print(f'processing {fname}, {count}/{total}')

        # Read the pixels inside a context manager so the file handle is
        # released right away instead of whenever PIL's object is collected.
        with Image.open(test_cxr) as img:
            cxr_img = np.array(img)
        cxr_img = cxr_transforms(cxr_img)
        cxr_img = torch.unsqueeze(cxr_img, 0).to(DEVICE)  # add batch dimension

        mask_pred = torch.sigmoid(resunet(cxr_img))
        # Drop only the batch dimension. A bare squeeze() would also collapse
        # a size-1 channel axis and make argmax run over image height.
        mask_pred = mask_pred.squeeze(0)  # (num_classes, H, W); original note: (3, 512, 512)
        num_classes = mask_pred.shape[0]

        if num_classes == 1:
            # Single-channel output: binary mask via the usual 0.5 threshold.
            label_map = (mask_pred[0] > 0.5).float()
        else:
            # Per-pixel index of the most probable class.
            label_map = torch.argmax(mask_pred, dim=0).float()

        # Map class indices to fixed gray levels in [0, 1] (class k ->
        # k / (num_classes - 1)). The same class gets the same intensity in
        # every saved image. Per-image min/max normalization did not do this:
        # e.g. a mask containing only the last class came out black.
        label_map = label_map / max(num_classes - 1, 1)
        # NOTE(review): the mask keeps the input's file name/extension. A
        # lossy format such as .jpg will blur class boundaries.
        save_image(label_map, os.path.join(OUTPUT_PATH, fname))

print(f'Done, predicted masks saved in {OUTPUT_PATH}')