-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmultitest.py
78 lines (69 loc) · 2.75 KB
/
multitest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
import os
import numpy as np
import argparse
from PIL import Image
import torchvision.transforms as transforms
from torch.autograd import Variable
import torchvision.utils as vutils
from network.Transformer import Transformer
from tqdm import tqdm
# Command-line options.  The script is driven entirely by these values; the
# inference loop below reads them through the module-level ``opt`` namespace.
parser = argparse.ArgumentParser(
    description='Stylize every image in --input_dir with each pretrained '
                'style generator and write the results to --output_dir.')
parser.add_argument('--input_dir', default='test_img',
                    help='directory containing the input images')
parser.add_argument('--load_size', default=None, type=int,
                    help='target length (pixels) of the longer image side; '
                         'omit to keep the original size')
parser.add_argument('--model_path', default='./pretrained_model',
                    help='directory holding the <Style>_net_G_float.pth '
                         'generator weights')
parser.add_argument('--output_dir', default='test_output',
                    help='directory for the stylized images '
                         '(created if it does not exist)')
parser.add_argument('--gpu', type=int, default=0,
                    help='CUDA device index; a negative value runs on CPU')
opt = parser.parse_args()

# File extensions (including the leading dot) treated as input images.
valid_ext = ['.jpg', '.png']

# makedirs(exist_ok=True) also creates missing parent directories and avoids
# the race between an exists() check and mkdir().
os.makedirs(opt.output_dir, exist_ok=True)

# Style names; each selects the weight file <name>_net_G_float.pth.
MODELS = ['Hayao', 'Hosoda', 'Paprika', 'Shinkai']
def _scaled_size(width, height, load_size):
    """Return the (width, height) to resize to, preserving the aspect ratio.

    The longer side (the height, for square images) becomes ``load_size``.
    The other side is scaled proportionally and rounded down, but never below
    1 pixel. When ``load_size`` is None, the original size is returned unchanged.
    """
    if load_size is None:
        return width, height
    # Exact integer floor division; this avoids the off-by-one that a float
    # ratio round-trip (int(x / ratio)) can introduce.
    if width > height:
        return load_size, max(1, load_size * height // width)
    return max(1, load_size * width // height), load_size


def _load_input(path, load_size, device):
    """Load an image file as a (1, 3, H, W) BGR float tensor in [-1, 1] on ``device``."""
    # The context manager guarantees the underlying file handle is released.
    with Image.open(path) as img:
        image = img.convert("RGB")
    target = _scaled_size(image.size[0], image.size[1], load_size)
    if target != image.size:
        image = image.resize(target, Image.BICUBIC)
    # RGB -> BGR.  Fancy indexing makes a contiguous copy; a [::-1] view would
    # have negative strides, which torch.from_numpy (used by ToTensor) rejects.
    pixels = np.asarray(image)[:, :, [2, 1, 0]]
    tensor = transforms.ToTensor()(pixels).unsqueeze(0)  # (1, 3, H, W) in [0, 1]
    return (-1 + 2 * tensor).to(device)                   # rescale to [-1, 1]


def _load_generator(model_path, style, device):
    """Build a Transformer, load the weights for ``style`` and move it to ``device``."""
    model = Transformer()
    # Load onto the CPU first so that checkpoints saved from CUDA tensors still
    # load in CPU mode; the model is moved to the target device afterwards.
    state = torch.load(os.path.join(model_path, style + '_net_G_float.pth'),
                       map_location='cpu')
    model.load_state_dict(state)
    model.eval()
    if device.type == 'cuda':
        print('GPU mode')
    else:
        print('CPU mode')
        model.float()
    return model.to(device)


# Honour the requested CUDA index (a bare .cuda() would always use the
# current default device, ignoring --gpu N).
_device = torch.device('cuda', opt.gpu) if opt.gpu > -1 else torch.device('cpu')

# Inputs are listed once, sorted for a deterministic processing order.
# The extension match is case-insensitive, so e.g. "photo.JPG" is included.
# Directories are skipped.
_image_names = sorted(
    name for name in os.listdir(opt.input_dir)
    if os.path.splitext(name)[1].lower() in valid_ext
    and os.path.isfile(os.path.join(opt.input_dir, name))
)

# Inference only: no_grad disables autograd tracking.  It replaces the
# deprecated Variable(..., volatile=True), a no-op since PyTorch 0.4.
with torch.no_grad():
    for style in MODELS:
        model = _load_generator(opt.model_path, style, _device)
        for name in tqdm(_image_names):
            input_image = _load_input(os.path.join(opt.input_dir, name),
                                      opt.load_size, _device)
            output_image = model(input_image)[0]
            # BGR -> RGB, then map the output (presumably in [-1, 1], mirroring
            # the input preprocessing) back to [0, 1] for saving.
            output_image = output_image[[2, 1, 0], :, :].cpu().float() * 0.5 + 0.5
            # splitext handles extensions of any length, unlike name[:-4].
            stem = os.path.splitext(name)[0]
            vutils.save_image(
                output_image,
                os.path.join(opt.output_dir, stem + '_' + style + '.jpg'))
print('Done!')