-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathdemo1.py
49 lines (46 loc) · 2.01 KB
/
demo1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from src.models.sr_module import *
from pytorch_lightning import Trainer
from argparse import ArgumentParser
from torchvision.io import read_image, ImageReadMode
from torchvision.utils import save_image
from torchvision import transforms
from torch.nn.functional import avg_pool2d
import torch
import glob
import os
import math
from pathlib import Path
parser = ArgumentParser()
parser.add_argument('-s', '--scales', type=float, nargs='+', required=True,
                    help="downscaling factors to evaluate, e.g. -s 2 3 4")
# Optional so the checkpoint-free 'bicubic' baseline can run without a dummy
# path; it is still enforced below for every other model.
parser.add_argument("--ckpt_path", type=str, default=None,
                    help="checkpoint to load (required unless --model_name bicubic)")
parser.add_argument("--model_name", type=str, default='default_model',
                    help="output sub-directory / file prefix; 'bicubic' selects "
                         "the bicubic baseline instead of a checkpoint")
parser.add_argument("--file_ext", type=str, default='.png',
                    help="extension of the HR images picked up from ./demo/")
args = parser.parse_args()
# demo1 only loads a checkpoint when model_name != 'bicubic'.
if args.model_name != 'bicubic' and args.ckpt_path is None:
    parser.error("--ckpt_path is required unless --model_name is 'bicubic'")
# demo1 divides the HR size by each scale: 0 would raise ZeroDivisionError and
# negative values would request a negative image size.
if any(s <= 0 for s in args.scales):
    parser.error("all --scales must be positive")
def resize_fn(img, size):
    """Resize ``img`` to ``size`` using antialiased bicubic interpolation.

    Thin wrapper around ``torchvision.transforms.Resize``; ``size`` is passed
    through unchanged (demo1 passes a ``(height, width)`` pair).
    """
    bicubic = transforms.InterpolationMode.BICUBIC
    resizer = transforms.Resize(size=size, interpolation=bicubic, antialias=True)
    return resizer(img)
@torch.no_grad()
def demo1(args):
    """Downscale each HR image in ./demo/ at every scale, then super-resolve it back.

    For every ``./demo/*<file_ext>`` image (the HR reference) and every factor
    ``s`` in ``args.scales``, an LR image is produced by bicubic downscaling to
    ``round(H / s) x round(W / s)``.  The model then super-resolves that LR
    image back to the HR resolution.

    Outputs are written to ``./demo/<model_name>/``:
      * ``<filename>_lrx<s>.png``            -- the downscaled LR input
      * ``<model_name>_<filename>x<s>.png``  -- the model's SR output

    Args:
        args: parsed CLI namespace.  Reads ``scales``, ``model_name``,
            ``file_ext`` and, when ``model_name != 'bicubic'``, ``ckpt_path``.
    """
    hr_path = "./demo/"
    names_hr = sorted(glob.glob(os.path.join(str(hr_path), '*' + args.file_ext)))
    if not names_hr:
        print("No '*{}' images found in {}".format(args.file_ext, hr_path))
        return

    if args.model_name == 'bicubic':
        model = SRLitModule(arch='bicubic')
    else:
        model = SRLitModule.load_from_checkpoint(args.ckpt_path)
    # Inference only: make sure any train-mode-dependent layers (e.g. dropout,
    # BatchNorm), if the architecture has them, behave deterministically.
    model.eval()

    # All HR images come from the single hr_path directory, so one output
    # directory suffices; create it once instead of per image per scale.
    out_dir = Path(hr_path) / args.model_name
    out_dir.mkdir(parents=True, exist_ok=True)

    for f_hr in names_hr:
        filename, _ = os.path.splitext(os.path.basename(f_hr))
        # Batch of one 3-channel RGB image, rescaled from uint8 to [0, 1].
        hr = read_image(f_hr, ImageReadMode.RGB).unsqueeze(0) / 255.
        hr_h, hr_w = hr.shape[-2:]
        for s in args.scales:
            # Clamp to 1 px so very large scales never request an empty image.
            lr_size = max(1, round(hr_h / s)), max(1, round(hr_w / s))
            lr = resize_fn(hr, lr_size)
            save_image(lr, str(out_dir / "{}_lrx{}.png".format(filename, s)))
            # Ask the model for an output at the original HR spatial size.
            sr = model(lr, hr.shape[-2:])
            save_image(sr, str(out_dir / "{}_{}x{}.png".format(args.model_name, filename, s)))
if __name__ == '__main__':
    # Script entry point: run the demo with the module-level CLI namespace.
    demo1(args)