-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathvisualize.py
65 lines (51 loc) · 1.87 KB
/
visualize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# -*- coding: utf-8 -*-
import argparse
import torch
from torch.autograd import Variable
from torchvision import datasets
from torchvision import transforms
import numpy as np
from tqdm import tqdm
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib import offsetbox
import seaborn as sns
from train import Siamese
def main(args):
    """Embed the FashionMNIST test set with a trained Siamese model and
    save a 2D scatter of image thumbnails to ``vis.png``.

    Args:
        args: argparse namespace with
            ck_path (str): path to a pickled model checkpoint loadable via torch.load.
            n_plots (int): number of image thumbnails to place on the plot.
    """
    sns.set(style="whitegrid", font_scale=1.5)
    test_loader = torch.utils.data.DataLoader(
        datasets.FashionMNIST("./data", train=False, download=False, transform=transforms.Compose([
            transforms.ToTensor()
        ])), batch_size=100, shuffle=True)
    # Run on GPU when available instead of crashing on CPU-only machines;
    # map_location relocates GPU-saved weights accordingly.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
    model = torch.load(args.ck_path, map_location=device)
    model.eval()
    inputs, embs, targets = [], [], []
    # Inference only: disable autograd bookkeeping (replaces the deprecated Variable wrapper).
    with torch.no_grad():
        for x, t in tqdm(test_loader, total=len(test_loader)):
            x = x.to(device)
            o1 = model(x)  # assumes the model maps (B,1,28,28) -> (B,2) embeddings — TODO confirm
            inputs.append(x.cpu().numpy())
            embs.append(o1.cpu().numpy())
            targets.append(t.numpy())
    # concatenate handles a ragged final batch correctly, unlike
    # np.array(list_of_batches).reshape which requires equal batch sizes.
    inputs = np.concatenate(inputs).reshape(-1, 28, 28)
    embs = np.concatenate(embs).reshape(-1, 2)
    targets = np.concatenate(targets)
    n_plots = args.n_plots
    plt.figure(figsize=(8, 8))
    ax = plt.subplot(111)
    ax.set_title("FashionMNIST 2D embeddings")  # typo fix: "embeddigs"
    for x, e, t in zip(inputs[:n_plots], embs[:n_plots], targets[:n_plots]):
        # Draw each sample as its thumbnail anchored at its 2D embedding.
        imagebox = offsetbox.AnnotationBbox(
            offsetbox.OffsetImage(x, zoom=0.5, cmap=plt.cm.gray_r),
            xy=e, frameon=False)
        ax.add_artist(imagebox)
    # Fit axes to the full embedding cloud, not just the plotted subset.
    ax.set_xlim(embs[:, 0].min(), embs[:, 0].max())
    ax.set_ylim(embs[:, 1].min(), embs[:, 1].max())
    plt.tight_layout()
    plt.savefig("vis.png")
if __name__ == "__main__":
    # CLI: checkpoint to load and how many thumbnails to draw on the plot.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--ck_path", type=str, default="./checkpoint/20.tar")
    arg_parser.add_argument("--n_plots", type=int, default=500)
    main(arg_parser.parse_args())