-
Notifications
You must be signed in to change notification settings - Fork 3
/
inference.py
95 lines (77 loc) · 3.43 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
import argparse
import multiprocessing
import warnings
import pandas as pd
import torch
from torch.utils.data import DataLoader
from sklearn import preprocessing
from importlib import import_module
from tqdm import tqdm
from dataset.dataset import CustomDataset, get_data
def inference(model, test_loader, device, mode, tta=None):
    """Run ``model`` over ``test_loader`` and collect predictions.

    Args:
        model: a ``torch.nn.Module`` classifier; moved to ``device`` and
            switched to eval mode here.
        test_loader: DataLoader yielding image batches.
        device: torch device (string or object) to run inference on.
        mode: one of ``'answer'``, ``'logit'``, ``'both'`` — selects which
            of the two result lists get filled.
        tta: enable horizontal-flip test-time augmentation. Defaults to
            ``None``, which falls back to the module-level ``args.tta``.
            (The original implementation read the global ``args`` directly,
            which raised NameError when this function was imported from
            another module; the parameter makes it callable standalone
            while remaining backward compatible.)

    Returns:
        ``(model_preds, logits)`` — list of argmax class indices (empty
        unless mode is 'answer'/'both') and list of raw logit rows as
        nested Python lists (empty unless mode is 'logit'/'both').
    """
    if tta is None:
        # Backward-compatible fallback to the script-level CLI namespace.
        tta = args.tta
    model.to(device)
    model.eval()
    model_preds = []
    logits = []
    with torch.no_grad():
        for img in tqdm(iter(test_loader)):
            img = img.float().to(device)
            if tta:
                # Average the logits of the image and its horizontal flip.
                model_pred = model(img) / 2
                model_pred += model(torch.flip(img, dims=(-1,))) / 2
            else:
                model_pred = model(img)
            if mode in ['logit', 'both']:
                logits.extend(model_pred.detach().cpu().numpy().tolist())
            if mode in ['answer', 'both']:
                model_preds += model_pred.argmax(1).detach().cpu().numpy().tolist()
    return model_preds, logits
def parse_arg():
    """Build and parse the command-line options for inference.

    Returns:
        argparse.Namespace carrying data/model paths, image and batch
        sizes, the output mode ('answer'/'logit'/'both') and the
        horizontal-flip TTA flag.
    """
    cli = argparse.ArgumentParser(add_help=True)
    add = cli.add_argument
    add('--data_dir', type=str, default='data/')
    add('--model', type=str, default='BaseModel')
    add('--img_size', type=int, default=224)
    add('--batch_size', type=int, default=64)
    add('--output_dir', type=str, default='./output/submission/')
    add('--model_path', type=str, default='./output/model/exp/latest.pt')
    add('--mode', type=str, default='answer')
    add('--tta', action='store_true')
    return cli.parse_args()
if __name__ == "__main__":
    args = parse_arg()
    print(args)
    # Validate early so a typo fails before any model/data loading.
    if args.mode not in ['answer', 'logit', 'both']:
        raise ValueError(f'Unknown mode ({args.mode})')
    warnings.filterwarnings('ignore')

    device = "cuda" if torch.cuda.is_available() else "cpu"
    num_workers = multiprocessing.cpu_count() // 2

    # Resolve the model class by name from models/model.py and restore weights.
    model_module = getattr(import_module("models.model"), args.model)
    model = model_module(num_classes=50)
    # map_location=device lets a checkpoint saved on GPU load on a CPU-only
    # machine; the original bare torch.load() raised a deserialization error
    # there because CUDA storages could not be materialized.
    model.load_state_dict(torch.load(args.model_path, map_location=device))
    model.eval()

    # Build the test set. x[2:] drops the first two characters of each
    # img_path (presumably a leading './' in test.csv — confirm against the
    # data) before joining with data_dir.
    test_df = pd.read_csv(os.path.join(args.data_dir, 'test.csv'))
    test_df['img_path'] = test_df['img_path'].apply(
        lambda x: os.path.join(args.data_dir, x[2:]))
    test_img_paths = get_data(test_df, infer=True)
    test_transform_module = getattr(import_module('dataset.augmentation'), 'TestAugmentation')
    test_transform = test_transform_module(args.img_size)
    test_dataset = CustomDataset(test_img_paths, None, test_transform)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
                             shuffle=False, num_workers=num_workers)

    preds, logits = inference(model, test_loader, device, args.mode)

    os.makedirs(args.output_dir, exist_ok=True)
    submit = pd.read_csv(os.path.join(args.data_dir, 'sample_submission.csv'))
    if args.mode in ['answer', 'both']:
        # Re-fit the label encoder on train.csv so inverse_transform maps
        # the predicted class indices back to artist-name strings.
        df = pd.read_csv(os.path.join(args.data_dir, 'train.csv'))
        le = preprocessing.LabelEncoder()
        df['artist'] = le.fit_transform(df['artist'].values)
        preds = le.inverse_transform(preds)
        submit['artist'] = preds
        path = os.path.join(args.output_dir, 'answer.csv')
        submit.to_csv(path, index=False)
    if args.mode in ['logit', 'both']:
        # Raw logit rows are written into the same column, e.g. for
        # downstream ensembling.
        submit["artist"] = logits
        path = os.path.join(args.output_dir, 'logit.csv')
        submit.to_csv(path, index=False)