#!/usr/bin/env python3
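"""Build a DEG (deglib) search graph over a LAION2B CLIP embedding subset and benchmark it.

Summary of what the code below does: the database vectors are read from an HDF5 file,
optionally compressed with a CompressionNet, inserted into an even-regular graph using the
hyperparameters in BUILD_HPARAMS, queried once for every eps value in EPS_SETTINGS, and the
results are written in the SISAP challenge result format (see store_results).
"""
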
import multiprocessing
import argparse
import gc
import time
from typing import Optional, Callable, Any, List
from pathlib import Path
import h5py
import numpy as np
from compress import CompressionNet
import deglib

# if True, derive the search entry vertices from precomputed cluster centers (cluster_centers.npy)
SMART_ENTRY = False

# hyperparameters passed to build_deglib_from_data / the EvenRegularGraphBuilder
BUILD_HPARAMS = {
    'quantize': True,               # uint8-quantize compressed vectors (uses the L2_Uint8 metric)
    'edges_per_vertex': 32,         # out-degree of every vertex in the graph
    'lid': deglib.builder.LID.Low,  # LID (local intrinsic dimensionality) mode of the builder
    'extend_k': 64,                 # candidate list size used while extending the graph
    'extend_eps': 0.1,              # search eps used while extending the graph
}

# eps values evaluated during the query benchmark (one result file is written per eps)
EPS_SETTINGS = [
    0.0, 0.001, 0.002, 0.005, 0.01,
    0.02, 0.03, 0.04, 0.05, 0.06,
    0.07, 0.08, 0.09, 0.1, 0.11,
    0.12, 0.13, 0.14, 0.15, 0.16,
    0.17, 0.18, 0.19, 0.2, 0.21,
    0.22, 0.25, 0.27, 0.3, 0.35,
]


def parse_args():
    parser = argparse.ArgumentParser(description='Run task 1 with --compression=512 and task 3 with --compression=64')
    parser.add_argument(
        '--dbsize',
        type=str,
        choices=["100K", "300K", "10M", "30M", "100M"],
        help='The database size to use'
    )
    parser.add_argument(
        '-k',
        type=int,
        default=30,
        help='Number of results per query'
    )
    # note: with default=True and action='store_true' this flag is effectively always enabled
    parser.add_argument(
        '--show-progress',
        default=True,
        action='store_true',
        help='show progress during graph building'
    )
    parser.add_argument(
        '--compression', '-c',
        type=int,
        default=0,
        help='use compression net to reduce data dimensionality'
    )
    parser.add_argument(
        '--query-file',
        type=Path,
        default=Path("data") / "public-queries-2024-laion2B-en-clip768v2-n=10k.h5",
        help='The query file'
    )
    return parser.parse_args()
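
# Illustrative invocations (a sketch based on the parser above; the dbsize value is only an
# example, any of the supported choices works):
#   task 1: python run_task.py --dbsize 300K --compression 512
#   task 3: python run_task.py --dbsize 300K --compression 64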


def main():
    # parse DB_SIZE argument
    args = parse_args()
    dbsize: str = args.dbsize
    k: int = args.k

    # load data (database, queries) using h5-file in batches (convert to f32)
    data_file = Path("data") / "laion2B-en-clip768v2-n={}.h5".format(dbsize)
    query_file = args.query_file
    print('Use the following files:\ndata: "{}"\nqueries: "{}"\n'.format(data_file, query_file))

    # load compression network
    comp_net = None
    if args.compression:
        print('Load compression network {}D'.format(args.compression))
        comp_net = CompressionNet(target_dim=args.compression)

    # build graph
    callback = 'progress' if args.show_progress else None
    with h5py.File(data_file, 'r') as data_f:
        assert 'emb' in data_f.keys()
        data = data_f['emb']
        build_start_time = time.perf_counter()
        graph = build_deglib_from_data(data, comp_net, **BUILD_HPARAMS, callback=callback)
        build_duration = time.perf_counter() - build_start_time

    # benchmark graph
    print('\nStart benchmarking the graph:')

    # load cluster centers
    entry_indices = [0]
    if SMART_ENTRY:
        cluster_centers = np.load('cluster_centers.npy', allow_pickle=True)
        if comp_net is not None:
            cluster_centers = comp_net.compress(cluster_centers, quantize=BUILD_HPARAMS['quantize'], batch_size=cluster_centers.shape[0])
        entry_indices, _ = graph.search(
            cluster_centers, eps=0.2, k=1, threads=min(multiprocessing.cpu_count(), cluster_centers.shape[0]),
            thread_batch_size=1
        )
        entry_indices = list(entry_indices.flatten())
    print('Seed vertex indices for the evaluation: {}'.format(entry_indices))

    # evaluate on a test set
    queries = load_queries(query_file)
    benchmark_graph(graph, queries, comp_net, k, dbsize, build_duration, entry_indices)


def build_deglib_from_data(
        data: h5py.Dataset, comp_net: Optional[CompressionNet], quantize: bool, edges_per_vertex: int,
        lid: deglib.builder.LID, extend_k: Optional[int] = None, extend_eps: float = 0.2,
        callback: Callable[[Any], None] | str | None = None
):
    print('\n\nBuilding graph with hyperparameters: {}'.format(BUILD_HPARAMS))
    num_samples = data.shape[0]

    if comp_net is not None:
        dim = comp_net.target_dim
        metric = deglib.Metric.L2_Uint8 if quantize else deglib.Metric.L2
    else:
        dim = data.shape[1]
        metric = deglib.Metric.InnerProduct

    graph = deglib.graph.SizeBoundedGraph.create_empty(num_samples, dim, edges_per_vertex, metric)
    builder = deglib.builder.EvenRegularGraphBuilder(
        graph, rng=None, lid=lid, extend_k=extend_k, extend_eps=extend_eps, improve_k=0
    )

    # feed the data to the builder in chunks, compressing or converting each chunk on the fly
    print(f"Start adding {num_samples} data points to builder", flush=True)
    chunk_size = 50_000
    start_time = time.perf_counter()
    labels = np.arange(num_samples, dtype=np.uint32)
    for counter, min_index in enumerate(range(0, num_samples, chunk_size)):
        if counter != 0 and counter % 10 == 0:
            print('Added {} data points after {:5.1f}s'.format(min_index, time.perf_counter() - start_time), flush=True)
        max_index = min(min_index + chunk_size, num_samples)
        chunk = data[min_index:max_index]
        if comp_net is not None:
            chunk = comp_net.compress(chunk, quantize=quantize, batch_size=chunk.shape[0])
        else:
            chunk = chunk.astype(np.float32)
        builder.add_entry(
            labels[min_index:max_index],
            chunk
        )
    print('Added {} data points after {:5.1f}s\n'.format(num_samples, time.perf_counter() - start_time), flush=True)

    print('Start building graph:', flush=True)
    builder.build(callback=callback)

    # remove builder to free memory
    del builder
    gc.collect()

    print('Removing all non-MRNG-conform edges ... ', flush=True)
    graph.remove_non_mrng_edges()

    graph = deglib.graph.ReadOnlyGraph.from_graph(graph)
    gc.collect()

    return graph


def load_queries(query_file: Path):
    assert query_file.is_file(), 'Could not find query file: {}'.format(query_file)
    with h5py.File(query_file, 'r') as infile:
        assert 'emb' in infile.keys(), 'Could not find "emb" key in query file "{}"'.format(query_file)
        queries = infile['emb'][()]
    return queries
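
# Note: the query file is expected to use the same layout as the database file, a single HDF5
# dataset named 'emb' with one embedding per row (the default file name suggests 768-dimensional
# CLIP vectors).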


def benchmark_graph(
        graph: deglib.graph.SearchGraph, queries: np.ndarray, comp_net: Optional[CompressionNet], k: int, dbsize: str,
        build_time: float, entry_indices: List[int] | None
):
    print('queries:', queries.shape, queries.dtype)
    print(f'{"eps":<8} {"query time":<13} {"comp time":<13} {"graph time":<13}')
    for eps in EPS_SETTINGS:
        start_time_benchmark = time.perf_counter()

        if comp_net is not None:
            compressed_queries = comp_net.compress(
                queries, quantize=BUILD_HPARAMS['quantize'], batch_size=queries.shape[0]
            )
        else:
            compressed_queries = queries

        start_time_search = time.perf_counter()
        prediction, distances = graph.search(
            compressed_queries, k=k, eps=eps, threads=multiprocessing.cpu_count(), thread_batch_size=32,
            entry_vertex_indices=entry_indices
        )
        end_time = time.perf_counter()
        query_time = end_time - start_time_benchmark
        print(f'{eps:<8} {query_time:<13.4f} {start_time_search - start_time_benchmark:<13.4f} '
              f'{end_time - start_time_search:<13.4f}')

        # store the information
        # https://github.com/sisap-challenges/sisap23-laion-challenge-evaluation/tree/0a6f90debe73365abee210d3950efc07223c846d
        algo = "deglib"
        data = "normal"
        if comp_net is not None:
            data = "uint{}".format(comp_net.target_dim)
        index_identifier = "epv={}_extendK={}_extendEps={}".format(BUILD_HPARAMS['edges_per_vertex'], BUILD_HPARAMS['extend_k'], BUILD_HPARAMS['extend_eps'])
        identifier = f"index=({index_identifier}),query=(eps={eps})"
        destination = Path('results') / dbsize / data / 'deglib_{}.h5'.format(identifier)

        prediction += 1  # offset prediction to start at index 1
        store_results(
            destination, algo, data, distances, prediction, build_time, query_time, identifier, dbsize
        )


def store_results(dst: Path, algo, kind, distances, result_indices, buildtime, querytime, params, size):
    dst.parent.mkdir(exist_ok=True, parents=True)
    with h5py.File(dst, 'w') as f:
        f.attrs['algo'] = algo
        f.attrs['data'] = kind
        f.attrs['buildtime'] = buildtime
        f.attrs['querytime'] = querytime
        f.attrs['size'] = size
        f.attrs['params'] = params
        f.create_dataset('knns', result_indices.shape, dtype=result_indices.dtype)[:] = result_indices
        f.create_dataset('dists', distances.shape, dtype=distances.dtype)[:] = distances
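
# The files written by store_results follow the result layout used by the SISAP LAION challenge
# evaluation (see the URL in benchmark_graph): HDF5 attributes 'algo', 'data', 'buildtime',
# 'querytime', 'size' and 'params', plus the datasets 'knns' (1-based neighbor indices) and
# 'dists' (the corresponding distances).
#
# A minimal sketch for inspecting one of these files (the path is only a placeholder):
#
#   import h5py
#   with h5py.File('results/300K/uint512/deglib_<identifier>.h5', 'r') as f:
#       print(dict(f.attrs))                      # algo, data, buildtime, querytime, size, params
#       print(f['knns'].shape, f['dists'].shape)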


if __name__ == '__main__':
    main()