-
Notifications
You must be signed in to change notification settings - Fork 8
/
utils.py
96 lines (80 loc) · 3.96 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import faiss
import logging
import numpy as np
from typing import Tuple
from torch.utils.data import Dataset
import os
from os.path import join
import sys
import traceback
import visualizations
# Cutoffs N at which Recall@N is reported: R@1, R@5, R@10, R@20.
# compute_recalls retrieves max(RECALL_VALUES) predictions per query.
RECALL_VALUES = [1, 5, 10, 20]
def compute_recalls(eval_ds: Dataset, queries_descriptors : np.ndarray, database_descriptors : np.ndarray,
                    output_folder : str = None, num_preds_to_save : int = 0,
                    save_only_wrong_preds : bool = True) -> Tuple[np.ndarray, str]:
    """Compute the recalls given the queries and database descriptors.

    The dataset is needed to know the ground truth positives for each query
    (via ``eval_ds.get_positives()``) and the number of queries
    (``eval_ds.queries_num``).

    Args:
        eval_ds: evaluation dataset providing ``get_positives()`` and ``queries_num``.
        queries_descriptors: 2D array with one descriptor per query; its second
            dimension is used as the dimensionality of the search index.
        database_descriptors: 2D array with one descriptor per database item.
        output_folder: folder handed to ``visualizations.save_preds``; only used
            when ``num_preds_to_save != 0``.
        num_preds_to_save: how many top predictions per query to pass to
            ``visualizations.save_preds``. Only ``max(RECALL_VALUES)`` predictions
            are retrieved, so at most that many are available.
        save_only_wrong_preds: forwarded unchanged to ``visualizations.save_preds``.

    Returns:
        A tuple ``(recalls, recalls_str)``: the recalls as percentages (one value
        per entry of ``RECALL_VALUES``) and a string such as ``"R@1: 50.0, R@5: ..."``.
    """
    # Use a kNN (flat L2 index) to find predictions
    faiss_index = faiss.IndexFlatL2(queries_descriptors.shape[1])
    faiss_index.add(database_descriptors)
    # Drop this function's reference to the database descriptors
    del database_descriptors

    logging.debug("Calculating recalls")
    _, predictions = faiss_index.search(queries_descriptors, max(RECALL_VALUES))

    # For each query, check if the predictions are correct
    positives_per_query = eval_ds.get_positives()
    recalls = np.zeros(len(RECALL_VALUES))
    for query_index, preds in enumerate(predictions):
        # Membership is computed once per query and sliced for each cutoff.
        # np.isin replaces np.in1d, which is deprecated in NumPy 2.0.
        hits = np.isin(preds, positives_per_query[query_index])
        for i, n in enumerate(RECALL_VALUES):
            if hits[:n].any():
                # A hit within the first n predictions is also credited to
                # every later entry of RECALL_VALUES.
                recalls[i:] += 1
                break

    # Divide by queries_num and multiply by 100, so the recalls are in percentages
    recalls = recalls / eval_ds.queries_num * 100
    recalls_str = ", ".join([f"R@{val}: {rec:.1f}" for val, rec in zip(RECALL_VALUES, recalls)])

    # Save visualizations of predictions
    if num_preds_to_save != 0:
        # For each query save num_preds_to_save predictions
        visualizations.save_preds(predictions[:, :num_preds_to_save], eval_ds, output_folder, save_only_wrong_preds)

    return recalls, recalls_str
def setup_logging(save_dir, console="debug",
                  info_filename="info.log", debug_filename="debug.log"):
    """Set up logging files and console output.

    Creates one file for INFO logs and one for DEBUG logs, attaches the
    handlers to the root logger (whose level is set to DEBUG), and installs a
    ``sys.excepthook`` that logs uncaught exceptions through that logger.

    Args:
        save_dir (str): folder to create and where to save the files.
            It must not exist yet.
        console (str):
            if == "debug" prints on console debug messages and higher
            if == "info"  prints on console info messages and higher
            if == None does not use console (useful when a logger has already been set)
            any other value adds a console handler without an explicit level
        info_filename (str): the name of the info file. if None, don't create info file
        debug_filename (str): the name of the debug file. if None, don't create debug file

    Raises:
        FileExistsError: if save_dir already exists.
    """
    if os.path.exists(save_dir):
        raise FileExistsError(f"{save_dir} already exists!")
    # No exist_ok: if the folder appears between the check above and this call,
    # os.makedirs raises FileExistsError too, matching the documented contract.
    os.makedirs(save_dir)

    # logging.Logger.manager.loggerDict.keys() to check which loggers are in use
    base_formatter = logging.Formatter('%(asctime)s   %(message)s', "%Y-%m-%d %H:%M:%S")
    logger = logging.getLogger('')
    logger.setLevel(logging.DEBUG)

    if info_filename is not None:
        info_file_handler = logging.FileHandler(join(save_dir, info_filename))
        info_file_handler.setLevel(logging.INFO)
        info_file_handler.setFormatter(base_formatter)
        logger.addHandler(info_file_handler)

    if debug_filename is not None:
        debug_file_handler = logging.FileHandler(join(save_dir, debug_filename))
        debug_file_handler.setLevel(logging.DEBUG)
        debug_file_handler.setFormatter(base_formatter)
        logger.addHandler(debug_file_handler)

    if console is not None:
        console_handler = logging.StreamHandler()
        if console == "debug":
            console_handler.setLevel(logging.DEBUG)
        elif console == "info":
            console_handler.setLevel(logging.INFO)
        console_handler.setFormatter(base_formatter)
        logger.addHandler(console_handler)

    def exception_handler(type_, value, tb):
        # Pass the hook's own type_ argument (not the builtin `type`) so the
        # formatted traceback is built from the actual exception triple.
        logger.info("\n" + "".join(traceback.format_exception(type_, value, tb)))

    # Route uncaught exceptions through the logger so they reach the log files
    sys.excepthook = exception_handler