-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsample.py
58 lines (41 loc) · 1.8 KB
/
sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
from argparse import ArgumentParser
from pathlib import Path
from torchvision.utils import save_image
import torch
from pae.model import NADE
def get_arg_parser():
    """Return the ``ArgumentParser`` used by this sampling script.

    Every option is parsed as ``str``:

    * ``--data-dir``: dataset root, default ``os.path.join('data', 'mnist')``;
      not read by ``run`` in this file.
    * ``--log-dir``: root directory for sample logs, default ``'sample_log'``.
    * ``--cuda``: value that ``setup`` writes to ``CUDA_VISIBLE_DEVICES``,
      default ``'0,'``.
    * ``-m/--model-name``: name of the model to sample from, default ``'NADE'``.
    * ``-cp/--checkpoint``: path of the saved state dict, default ``''``.
    """
    # (flags, default, help) -- registration order matches the --help listing.
    option_table = [
        # 1. setting
        (['--data-dir'], os.path.join('data', 'mnist'), 'root path of dataset'),
        (['--log-dir'], 'sample_log', 'root log dir'),
        (['--cuda'], '0,', "cuda devices"),
        # 2. model
        (['-m', '--model-name'], 'NADE', 'the name of model'),
        (['-cp', '--checkpoint'], '', 'saved state dict path'),
    ]
    cli = ArgumentParser(description='pytorch-auto-encoder')
    for flags, fallback, blurb in option_table:
        cli.add_argument(*flags, type=str, default=fallback, help=blurb)
    return cli
def setup(args):
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda
args.device = 'cuda:0'
log_root_path = os.path.join(args.log_dir, args.model_name)
Path(log_root_path).mkdir(exist_ok=True, parents=True)
args.run_id = f"v{len(os.listdir(log_root_path))}"
args.log_dir = os.path.join(log_root_path, args.run_id)
Path(args.log_dir).mkdir(exist_ok=True, parents=True)
def run(args):
    """Restore a model from a checkpoint and save a grid of generated images.

    Draws 16 samples, reshapes them to ``(16, 1, 28, 28)`` and writes them
    with ``save_image`` to ``<run log dir>/<model_name>_sampled_img.jpg``.

    Args:
        args: parsed command-line namespace from ``get_arg_parser()``;
            ``setup`` adds ``device`` / ``run_id`` and rewrites ``log_dir``.

    Raises:
        AssertionError: if ``args.model_name`` is not a supported model.
    """
    # Validate before setup() so an unsupported model does not leave an
    # empty run directory behind.
    if args.model_name != 'NADE':
        raise AssertionError(f"{args.model_name} is not supported yet!")

    setup(args)
    f = NADE().to(args.device)

    # NOTE(review): torch.load unpickles the file at this user-supplied path;
    # only load checkpoints from trusted sources.
    state_dict = torch.load(args.checkpoint, map_location='cpu')
    # A list/tuple checkpoint is accepted; its first element is used as the
    # model state dict.
    if isinstance(state_dict, (list, tuple)):
        state_dict = state_dict[0]
    f.load_state_dict(state_dict)

    # Sampling is inference only, so skip building an autograd graph.
    with torch.no_grad():
        sampled_img = f.sample(16, args.device).reshape(16, 1, 28, 28)
    save_image(sampled_img, os.path.join(args.log_dir, f'{args.model_name}_sampled_img.jpg'))
if __name__ == '__main__':
    # Script entry point: parse the command line, then sample from the model.
    run(get_arg_parser().parse_args())