-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtest_bitnet_chan.py
110 lines (94 loc) · 4.63 KB
/
test_bitnet_chan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import glob
import time
import numpy as np
from scipy import misc
import tensorflow as tf
import data_test as data
import models.bitnet_chan as net
# Set to True to run on the CPU only.
cpu = False
# GPU selection: number devices by PCI bus ID, then expose either no CUDA
# device ("-1", the CPU-only case) or just device "0".
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" if cpu else "0"
# Model name
model_name = 'bitnet_chan'
# Session
sess = tf.InteractiveSession()
# Placeholders
low = tf.placeholder(tf.float32, [1])   # low bit depth handed to the data reader
high = tf.placeholder(tf.float32, [1])  # bit depth the metrics are computed at
full_dqimage_dir_ph = tf.placeholder(tf.string)
# 4-channel input: channels 0-2 are fed one at a time to the net, each paired
# with the shared channel 3 (see _channel_net below).
image_ph = tf.placeholder(tf.float32, [1, None, None, 4])
label_ph = tf.placeholder(tf.float32, [1, None, None, 3])
infer_ph = tf.placeholder(tf.float32, [1, None, None, 3])
# Infer, Metrics
# Largest integer code at bit depth `high`, e.g. 255 for 8 bits.
res = (2**high - 1)


def _channel_net(c):
    """Build the net for input channel *c* concatenated with channel 3.

    tf.AUTO_REUSE is passed through to net.net, presumably as the
    variable-scope reuse flag so all three channels share one set of weights.
    """
    return net.net(tf.concat([image_ph[:, :, :, c:c + 1], image_ph[:, :, :, 3:4]], axis=3), tf.AUTO_REUSE)


# Built in channel order 0, 1, 2, exactly as three sequential calls would be.
infer1, infer2, infer3 = (_channel_net(c) for c in range(3))


def _quantize(x):
    """Map values in [0, 1] to integer codes in [0, res], rounding to nearest.

    Rounding (not truncation) matters: a float such as k/res multiplied back
    by res can come out just below k, and a plain int cast would yield k - 1.
    """
    return tf.cast(tf.round(x * res), tf.int32)


# Quantised prediction (clipped to the valid range first) and label, shared by
# both metrics.
infer_q = _quantize(tf.clip_by_value(infer_ph, 0., 1.))
label_q = _quantize(label_ph)
psnr = tf.image.psnr(infer_q, label_q, max_val=res)
ssim = tf.image.ssim(infer_q, label_q, max_val=res)
# Restore Model: load the saved 'model_100' checkpoint of this model into the
# session created above.
ckpt_path = './checkpoint/{}/model_100.ckpt'.format(model_name)
saver = tf.train.Saver()
saver.restore(sess, ckpt_path)
print('Model restored')
# Run
result_dir = './test_result/' + model_name + '/'
log_dir = './log/' + model_name + '/'
# Sorted so datasets are processed (and logged) in a stable order.
data_dirs = sorted(glob.glob('./dataset/*'))


def _bit_depth_pairs(database):
    """Return the [low_bits, high_bits] pairs to evaluate for *database*.

    Returns None for a database that is skipped entirely ('mit').
    """
    if database == 'mit':
        return None
    if database in ('espl', 'kodak'):
        return [[3, 8], [4, 8]]
    return [[3, 16], [4, 16], [5, 16], [6, 16]]


def _run_nets(image_):
    """Run infer1, infer2, infer3 on *image_*, one sess.run call each.

    Separate calls are kept on purpose so the measured time per image is
    comparable with the original per-net execution.
    """
    return [sess.run(t, feed_dict={image_ph: image_}) for t in (infer1, infer2, infer3)]


for data_dir in data_dirs:
    # Different settings for each database. basename copes with whatever path
    # separator glob returns, unlike data_dir.split('/')[2].
    database = os.path.basename(os.path.normpath(data_dir))
    l_b_h_b = _bit_depth_pairs(database)
    if l_b_h_b is None:
        continue
    # Call image reader
    reader = data.DataReader()
    full_dqimage_dirs, img_num = reader.read_file(data_dir + '/*')
    if not img_num:
        # Nothing to evaluate: the warm-up read of full_dqimage_dirs[0] would
        # raise IndexError and the averages would divide by zero.
        print('Data: %s, no images found, skipped' % database)
        continue
    image, label = reader.read_data(full_dqimage_dir_ph, database, low)
    os.makedirs(result_dir + database, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)
    log_name = 'test_' + database + ('_cpu.txt' if cpu else '.txt')
    # Logging; the with-block closes the file even if a run raises.
    with open(log_dir + log_name, 'a') as log:
        for l_b, h_b in l_b_h_b:
            # Warm up: one untimed pass on the first image.
            image_, _ = sess.run([image, label], feed_dict={full_dqimage_dir_ph: full_dqimage_dirs[0], low: [l_b]})
            _run_nets(image_)
            t_p_, t_s_, t_t_ = 0., 0., 0.
            for i in range(1, img_num + 1):
                # Read image and label
                image_, label_ = sess.run([image, label], feed_dict={full_dqimage_dir_ph: full_dqimage_dirs[i - 1], low: [l_b]})
                # Inference, measure time (perf_counter is monotonic)
                start = time.perf_counter()
                outs = _run_nets(image_)
                t_ = time.perf_counter() - start
                # Measure PSNR, SSIM; the batch holds one image, so reduce the
                # per-batch result arrays to plain floats.
                infer_ = np.concatenate(outs, axis=3)
                p_, s_ = sess.run([psnr, ssim], feed_dict={high: [h_b], infer_ph: infer_, label_ph: label_})
                p_, s_ = float(np.mean(p_)), float(np.mean(s_))
                t_p_, t_s_, t_t_ = t_p_ + p_, t_s_ + s_, t_t_ + t_
                # Save results. image_ has 4 channels (image_ph is
                # [1, H, W, 4]); only channels 0-2 are written, otherwise the
                # shared channel 3 ends up as the PNG alpha channel.
                input_image = np.uint8(image_[0, :, :, :3] * 255.)
                infer_image = np.uint8(np.clip(infer_[0], 0., 1.) * 255.)
                # NOTE: scipy.misc.imsave was removed in SciPy 1.2; this
                # requires an older SciPy (or switching to imageio.imwrite).
                misc.imsave(result_dir + database + '/%d_%d_%d_1_input.png' % (l_b, h_b, i), input_image)
                misc.imsave(result_dir + database + '/%d_%d_%d_2_infer.png' % (l_b, h_b, i), infer_image)
                # Logging: same line to stdout and to the log file.
                msg = 'Data: %s, Low: %d, High: %d, %d/%d, PSNR: %f, SSIM: %f, Time per Img: %f' % (database, l_b, h_b, i, img_num, p_, s_, t_)
                print(msg)
                log.write(msg + '\n')
            msg = 'Data: %s, Low: %d, High: %d, Avg PSNR: %f, SSIM: %f, Time per Img: %f' % (database, l_b, h_b, t_p_ / img_num, t_s_ / img_num, t_t_ / img_num)
            print(msg)
            log.write(msg + '\n\n')
            log.flush()