-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathencode_vision.py
54 lines (44 loc) · 1.67 KB
/
encode_vision.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os.path
import torch
import collections
import numpy as np
import matplotlib.pyplot as plt
from torch.autograd import Variable
from define_network import Compression_encoder,AutoEncoder
from sample_set import Sample_set
if __name__ == '__main__':
    # Visualize autoencoder reconstructions: for each test sample, plot the
    # ground-truth series against the decoder output, annotate the plot with
    # the 5 latent-code values, and save one PNG per sample under ./vision/.
    path_ = os.path.abspath('.')
    n = 100  # stop after sample index n (inclusive) -> n + 1 figures total

    # Restore the trained autoencoder and the compression encoder from disk.
    ae = AutoEncoder()
    ae.load_state_dict(torch.load(path_ + '/conv_autoencoder.pth'))
    ce = Compression_encoder()
    ce.load_state_dict(torch.load(path_ + '/compression_encoder.pth'))

    # FIX: savefig fails if the output directory is missing; create it up front.
    os.makedirs(path_ + '/vision', exist_ok=True)

    testset = Sample_set(path_ + '/test')
    testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False)

    # Inference only: torch.no_grad() replaces the deprecated Variable wrapper
    # and skips autograd graph construction.
    with torch.no_grad():
        for i, data in enumerate(testloader, 0):
            sample, target = data  # renamed from `input` (shadowed the builtin)
            actual = target[0][0].numpy()                    # ground truth, 1-d series
            output = ae(sample.float()).data[0][0].numpy()   # reconstructed series
            code = ce(sample.float()).data[0].numpy()        # 5-element latent code

            min_, max_ = actual.min(), actual.max()  # numpy reductions, not builtins
            X = range(len(actual))

            fig = plt.figure(figsize=(12, 8), dpi=80)
            plt.plot(X, actual, color='black', linewidth=1, label='original')
            plt.plot(X, output, color='red', linewidth=1, label='encoding_recover')
            plt.legend(loc='upper right', frameon=False, fontsize=20)
            # Stack the 5 code values down the left edge at 70%..30% of the y-range.
            for j in range(5):
                plt.text(5, min_ + (max_ - min_) * (0.7 - 0.1 * j), code[j], fontsize=20)
            plt.savefig(path_ + '/vision/vision_%d.png' % i)
            # FIX: figures were never closed -> every iteration leaked one figure
            # (unbounded memory growth over the ~101-sample run).
            plt.close(fig)
            if i == n:
                break