-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference_multi_mean.py
92 lines (78 loc) · 3.58 KB
/
inference_multi_mean.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import timm
import time
import os
# Enumerate CUDA devices in PCI bus order and expose only device index 1 to
# this process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1",
})

# --- Ensemble configuration -------------------------------------------------
model_num = 15  # number of checkpoints combined in the ensemble
num_epoch = 60  # epoch tag that appears in every checkpoint filename

# Training-recipe names. The loading loop indexes this list with i // 3, so
# each name covers three consecutive checkpoints.
model_name = [
    'gaussian_labelsmooth',
    'gaussian',
    'gaussian5',
    'labelsmooth',
    'gaussian5_labelsmooth',
]

# True -> evaluate with the ten-crop pipeline; False -> single center crop.
ten_crop = True

# Seed tag of each checkpoint, indexed by checkpoint number (three per recipe,
# in the same order as model_name).
seed_num = [
    0, 1, 2,       # gaussian_labelsmooth
    0, 1, 2,       # gaussian
    37, 47, 57,    # gaussian5
    0, 1, 2,       # labelsmooth
    37, 47, 57,    # gaussian5_labelsmooth
]
# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Per-channel mean / std handed to transforms.Normalize in both pipelines.
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)


def _crops_to_tensor(crops):
    """Turn each crop into a normalized tensor and stack them on a new axis 0."""
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(_NORM_MEAN, _NORM_STD)
    return torch.stack([normalize(to_tensor(crop)) for crop in crops])


# Ten-crop pipeline: resize, blur, take the TenCrop(224) crops, then convert
# the crops into one stacked, normalized tensor per image.
transform_ten = transforms.Compose([
    transforms.Resize(256),
    transforms.GaussianBlur(kernel_size=3, sigma=1),
    transforms.TenCrop(224),
    transforms.Lambda(_crops_to_tensor),
])

# Single-crop pipeline: same resize and blur, followed by one 224 center crop.
transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.GaussianBlur(kernel_size=3, sigma=1),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])
# Load the CIFAR-10 test split (train=False), downloading it into ./data when
# needed, with the pre-processing pipeline selected by ten_crop.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_ten if ten_crop else transform_test,
)

# Iterate in a fixed order, 100 images per batch, using 8 loader workers.
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=100,
    shuffle=False,
    num_workers=8,
)
# Build the ensemble: one ResNet-18 per checkpoint, each put in eval mode and
# moved to `device`.
models = []
correct = 0  # running count of correctly classified test images
total = 0    # running count of evaluated test images
start = time.time()  # started here, so the reported time includes checkpoint loading
for i in range(model_num):
    # model_name[i // 3] is the recipe shared by three consecutive checkpoints;
    # seed_num[i] is the seed tag of this particular checkpoint.
    weights_path = f"./weights/{model_name[i // 3]}_{num_epoch}_{seed_num[i]}.pth"

    # Create a 10-class ResNet-18 and fill it with the trained weights.
    model = timm.create_model('resnet18', num_classes=10)
    # map_location remaps the stored tensors onto the device actually in use,
    # so a checkpoint written from a CUDA device still loads when this run
    # has fallen back to the CPU.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)

    model.eval()  # evaluation mode
    model = model.to(device)
    models.append(model)
# Evaluate the ensemble on the test set. For every batch the outputs of all
# models are summed, and the index of the largest summed score is the
# predicted class.
with torch.no_grad():
    for images, targets in testloader:
        # Move the batch to the same device as the models.
        images = images.to(device)
        targets = targets.to(device)

        if ten_crop:
            # Ten-crop batches carry an extra crop axis: (bs, ncrops, c, h, w).
            n_images, n_crops, channels, height, width = images.size()
            # Fold the crop axis into the batch axis: (bs * ncrops, c, h, w).
            model_input = images.view(-1, channels, height, width)
        else:
            # Single-crop batches are already (bs, c, h, w); no reshape needed.
            n_images, channels, height, width = images.size()
            model_input = images

        # Accumulator for the summed per-class scores of all models.
        ensemble_logits = torch.zeros(n_images, 10, device=device)
        for net in models:
            logits = net(model_input)
            if ten_crop:
                # Restore the crop axis and average the scores over the crops.
                logits = logits.view(n_images, n_crops, -1).mean(1)
            ensemble_logits += logits

        predicted = torch.max(ensemble_logits, 1)[1]
        total += targets.size(0)
        correct += (predicted == targets).sum().item()
# Report the ensembled recipe names, the accuracy over the evaluated images,
# and the wall-clock time elapsed since `start`.
print(model_name)
# '%.2f' prints two decimals. The earlier '%2f' only requested a minimum field
# width of 2, which has no visible effect and left the default six decimals.
print('Accuracy : %.2f %%' % (100 * correct / total))
elapsed_seconds = time.time() - start
print('Inference time : ', time.strftime("%H:%M:%S", time.gmtime(elapsed_seconds)))