-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
105 lines (85 loc) · 3.07 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from lib.config import cfg, args
import numpy as np
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
def run_dataset():
    """Walk the ``is_train=False`` data loader once, discarding every batch.

    ``cfg.train.num_workers`` is set to 0 before the loader is built. The
    loop body does nothing with the batches; it only drives the loader
    under a tqdm progress bar.
    """
    from lib.datasets import make_data_loader
    import tqdm

    cfg.train.num_workers = 0
    loader = make_data_loader(cfg, is_train=False)
    progress = tqdm.tqdm(loader)
    for _ in progress:
        continue
def run_network():
    """Time the network's forward pass and print the mean seconds per batch.

    Builds the network from ``cfg``, moves it to CUDA, calls ``load_network``
    with ``cfg.model_dir`` and ``epoch=cfg.test.epoch``, and switches the
    network to eval mode. For each batch of the ``is_train=False`` data
    loader, every entry except ``'meta'`` is moved to CUDA and
    ``network(batch['inp'])`` is run under ``torch.no_grad()``.

    Only the forward call is timed. It is bracketed by
    ``torch.cuda.synchronize()`` so that pending GPU work is finished before
    the clock starts and before it stops.

    Prints the accumulated time divided by ``len(data_loader)``. If the
    loader has no batches, a notice is printed instead of dividing by zero.
    """
    from lib.networks import make_network
    from lib.datasets import make_data_loader
    from lib.utils.net_utils import load_network
    import tqdm
    import torch
    import time

    network = make_network(cfg).cuda()
    load_network(network, cfg.model_dir, epoch=cfg.test.epoch)
    network.eval()

    data_loader = make_data_loader(cfg, is_train=False)
    total_time = 0.0
    for batch in tqdm.tqdm(data_loader):
        # Move everything except the 'meta' entry to the GPU, in place.
        for k in batch:
            if k != 'meta':
                batch[k] = batch[k].cuda()

        with torch.no_grad():
            torch.cuda.synchronize()
            # perf_counter is monotonic and high-resolution; time.time() can
            # jump when the system clock is adjusted.
            start = time.perf_counter()
            network(batch['inp'])
            torch.cuda.synchronize()
            total_time += time.perf_counter() - start

    num_batches = len(data_loader)
    if num_batches == 0:
        # An empty loader would otherwise raise ZeroDivisionError below.
        print('no batches in data loader; nothing to time')
        return
    print(total_time / num_batches)
def run_evaluate():
    """Run the network over the ``is_train=False`` loader and evaluate it.

    The network is built from ``cfg``, moved to CUDA, passed to
    ``load_network`` with ``cfg.model_dir`` and ``epoch=cfg.test.epoch``,
    and put in eval mode. For each batch, the ``'inp'`` entry is moved to
    CUDA and fed through the network under ``torch.no_grad()``; the output
    and the batch are handed to ``evaluator.evaluate``. After the last
    batch, ``evaluator.summarize()`` is called once.
    """
    from lib.datasets import make_data_loader
    from lib.evaluators import make_evaluator
    import tqdm
    import torch
    from lib.networks import make_network
    from lib.utils.net_utils import load_network

    model = make_network(cfg).cuda()
    load_network(model, cfg.model_dir, epoch=cfg.test.epoch)
    model.eval()

    loader = make_data_loader(cfg, is_train=False)
    scorer = make_evaluator(cfg)

    for sample in tqdm.tqdm(loader):
        gpu_input = sample['inp'].cuda()
        with torch.no_grad():
            prediction = model(gpu_input)
        scorer.evaluate(prediction, sample)

    scorer.summarize()
def run_visualize():
    """Run the network over the ``is_train=False`` loader and visualize it.

    The network is built from ``cfg``, moved to CUDA, passed to
    ``load_network`` with ``cfg.model_dir``, ``resume=cfg.resume`` and
    ``epoch=cfg.test.epoch``, and put in eval mode. For each batch, every
    entry except ``'meta'`` is moved to CUDA in place, the network is called
    as ``network(batch['inp'], batch)`` under ``torch.no_grad()``, and the
    output plus the batch are passed to ``visualizer.visualize``.
    """
    from lib.networks import make_network
    from lib.datasets import make_data_loader
    from lib.utils.net_utils import load_network
    import tqdm
    import torch
    from lib.visualizers import make_visualizer

    model = make_network(cfg).cuda()
    load_network(
        model,
        cfg.model_dir,
        resume=cfg.resume,
        epoch=cfg.test.epoch,
    )
    model.eval()

    loader = make_data_loader(cfg, is_train=False)
    drawer = make_visualizer(cfg)

    for sample in tqdm.tqdm(loader):
        for key in sample:
            if key == 'meta':
                # The 'meta' entry is left where it is.
                continue
            sample[key] = sample[key].cuda()

        with torch.no_grad():
            prediction = model(sample['inp'], sample)
        drawer.visualize(prediction, sample)
def run_sbd():
    """Call ``convert_sbd.convert_sbd()`` from the ``tools`` package."""
    from tools import convert_sbd as sbd_converter

    sbd_converter.convert_sbd()
def run_demo():
    """Run the ``test_demo_0618`` entry point of ``tools``' ``demo``."""
    from tools import demo as demo_tools

    # Alternative entry points the author left disabled here:
    #   demo.demo()       -- noted as the one used most of the time (the
    #                        author marked this with a question)
    #   demo.test_demo()  -- temporary switch to make testing easier
    # Per the author's note, test_demo_0618 adds duplicate-detection removal
    # and tracking, to deal with encounter situations.
    demo_tools.test_demo_0618()
if __name__ == '__main__':
    # Look up the module-level callable named 'run_' + args.type in the
    # global namespace and call it (e.g. type 'evaluate' -> run_evaluate).
    entry_name = 'run_' + args.type
    entry_point = globals().get(entry_name)
    if not callable(entry_point):
        # Exit with a readable message instead of a bare KeyError when the
        # type is missing or misspelled.
        known_types = sorted(
            name[len('run_'):]
            for name, obj in globals().items()
            if name.startswith('run_') and callable(obj)
        )
        raise SystemExit(
            'unknown type {!r}; expected one of: {}'.format(
                args.type, ', '.join(known_types)
            )
        )
    entry_point()