-
Notifications
You must be signed in to change notification settings - Fork 0
/
metric.jl
226 lines (190 loc) · 8.25 KB
/
metric.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
using JSON, SimilaritySearch, JLD2, LinearAlgebra, SimSearchManifoldLearning
"""
load_sentence_embeddings(filename; key, normalize)
Loads sentence embeddings stored in h5 format, e.g., generated by `encode-file.py`.
"""
function load_sentence_embeddings(filename; key="emb", normalize=true)
    # Read the dataset stored under `key` while the file is open inside the
    # `do` block; the block's value (the embeddings) is what gets returned.
    return jldopen(filename) do file
        embeddings = file[key]
        # Optionally rescale each column in place with LinearAlgebra's
        # `normalize!` so every embedding ends up with unit norm.
        normalize && foreach(normalize!, eachcol(embeddings))
        embeddings
    end
end
"""
knngraph(dist::SemiMetric, input::String; output::String, k::Int=16, minrecall::Float64=0.95, normalize=true, key::String="emb")
Computes an approximation of the ``k`` nearest neighbor graph
- `dist`: Distance function
- `input`: input embedding file (h5 format)
- `output`: filename to save the knn graph (using two matrices `knns` and `dists` of identifiers and distances)
- `k`: the number of neighbors
- `minrecall`: controls the quality of the approximation (between 0 and 1)
- `normalize`: true if the embeddings must be adjusted to have unitary norm
- `key`: the name of the dataset inside of the input file
"""
function knngraph(dist::SemiMetric, input::String; output::String, k::Int=16, minrecall::Float64=0.95, normalize=true, key::String="emb")
    # Read the embeddings and build a search index on top of them.
    embeddings = load_sentence_embeddings(input; key=key, normalize=normalize)
    index = create_index(dist, StrideMatrixDatabase(embeddings); k=k, minrecall=minrecall)
    # Query the index with all of its own items and persist both results
    # under the names `knns` and `dists`.
    neighbors, distances = allknn(index, k)
    jldsave(output; knns=neighbors, dists=distances)
    return neighbors, distances
end
"""
create_index(dist::SemiMetric, db::AbstractDatabase; k::Int=16, minrecall::Float64=0.95, verbose=true)
Creates an index for the given database
- `dist`: Distance function
- `db`: input database
- `k`: the number of neighbors (only for optimization purposes)
- `minrecall`: controls the quality of the approximation (between 0 and 1)
- `verbose`: set `verbose=false` to reduce the output of the index's building
"""
function create_index(dist::SemiMetric, db::AbstractDatabase; k::Int=16, minrecall::Float64=0.95, verbose::Bool=true)
    index = SearchGraph(; dist=dist, db=db, verbose=verbose)
    # Wrap the numeric target in a `MinRecall` object; the same object drives
    # the construction callbacks and the final optimization pass.
    recall_goal = MinRecall(minrecall)
    index!(index; callbacks=SearchGraphCallbacks(recall_goal; ksearch=k, verbose=verbose))
    optimize!(index, recall_goal)
    return index
end
"""
distsample(dist::SemiMetric, input::String; output::String, prob=0.1, normalize=true, key::String="emb")
Computes a random sample of the pairwise distance matrix. Loads the embeddings from an h5 file and saves the sampled distances in another h5 file.
- `dist`: Distance function
- `input`: input embedding file (h5 format)
- `output`: filename to save the sample
- `prob`: sampling probability (on the upper triangle pairwise distance matrix)
- `normalize`: true if the embeddings must be adjusted to have unitary norm
- `key`: the name of the dataset inside of the input file
"""
function distsample(dist::SemiMetric, input::String; output::String, prob::Float64=0.1, key::String="emb", normalize::Bool=true)
    # Load the embeddings and delegate the sampling to the database method.
    embeddings = load_sentence_embeddings(input; key=key, normalize=normalize)
    sample = distsample(dist, StrideMatrixDatabase(embeddings); prob=prob)
    # The sample is stored under the name `dists`.
    jldsave(output; dists=sample)
    return sample
end
"""
distsample(dist::SemiMetric, X::AbstractDatabase; prob=0.1)
Computes a random sample of the pairwise distance matrix and returns it sorted in ascending order
- `dist`: Distance function
- `X`: input database
- returns: the vector of sampled distances (this method takes no `output` argument and does not write any file)
- `prob`: sampling probability (on the upper triangle pairwise distance matrix)
"""
function distsample(dist::SemiMetric, X::AbstractDatabase; prob::Float64=0.1)
    n = length(X)
    S = Float32[]
    # Reserve room for the expected number of sampled pairs: `prob` times the
    # n(n-1)/2 entries of the strict upper triangle of the distance matrix.
    sizehint!(S, ceil(Int, clamp(prob, 0.0, 1.0) * n * (n - 1) / 2))
    for i in 1:n
        # `j` has to reach `n`: stopping at `n - 1` would never pair any item
        # with the last one, leaving the last column of the triangle unsampled.
        for j in (i + 1):n
            # Keep each pair independently with probability `prob`.
            if rand() <= prob
                push!(S, evaluate(dist, X[i], X[j]))
            end
        end
    end
    # Callers receive the sampled distances in ascending order.
    sort!(S)
    return S
end
"""
neardup_analysis(dist::SemiMetric, input::String; output::String, epsilon::Float64=0.1, key::String="emb", minrecall::Float64=0.9, normalize::Bool=true)
Removes near duplicates from the embeddings read from `input`
- `dist`: Distance function
- `epsilon`: items are appended incrementally; an item is accepted if there is not an item at distance `epsilon` and rejected if already exists an item at that distance
- `input`: input embedding file (h5 format)
- `output`: filename to save the `idx`, `map`, `nn`, and `dist` fields of the result (h5 format)
- `minrecall`: controls the quality of the approximation (between 0 and 1)
- `normalize`: true if the embeddings must be adjusted to have unitary norm
- `key`: the name of the dataset inside of the input file
- `verbose`: verbose output
"""
function neardup_analysis(dist::SemiMetric, input::String; output::String, epsilon::Float64=0.1, key::String="emb", minrecall::Float64=0.9, normalize::Bool=true, verbose=true)
    # Load the embeddings and run the analysis with the database method.
    embeddings = load_sentence_embeddings(input; key=key, normalize=normalize)
    result = neardup_analysis(dist, StrideMatrixDatabase(embeddings); epsilon=epsilon, minrecall=minrecall, verbose=verbose)
    # Store the four fields of the result, each under its own name.
    jldsave(output; idx=result.idx, map=result.map, nn=result.nn, dist=result.dist)
    return result
end
"""
neardup_analysis(dist::SemiMetric, X::AbstractDatabase; epsilon::Float64=0.1, minrecall::Float64=0.9, verbose=true)
Finds near duplicates in the `X` database using an incremental algorithm. Returns a named tuple with the following fields:
- `idx`: an index with the set of non near duplicates under the given parameters
- `map`: maps each entry of `idx` to its original index in `X`
- `nn`: nearest neighbors identifiers of `nn` in `idx`
- `dist`: corresponding distance values of `nn`
# Arguments
- `dist`: Distance function
- `X`: input database
- `epsilon`: items are appended incrementally; an item is accepted if there is not an item at distance `epsilon` and rejected if already exists an item at that distance
- `minrecall`: controls the quality of the approximation (between 0 and 1)
- `verbose`: verbose output
"""
function neardup_analysis(dist::SemiMetric, X::AbstractDatabase; epsilon::Float64=0.1, minrecall::Float64=0.9, verbose=true)
    # Start from an empty database of `Float32` vectors; `neardup` receives the
    # index built on it together with the items of `X` and the threshold.
    # NOTE(review): `minrecall` is accepted but not used by this method.
    empty_db = VectorDatabase(Vector{Float32}[])
    index = SearchGraph(; db=empty_db, dist=dist, verbose=verbose)
    return neardup(index, X, epsilon)
end
"""
filter_neardup(dist::SemiMetric, X::AbstractDatabase; epsilon::Float64=0.1, minrecall::Float64=0.9, verbose=true)
Removes near duplicates from the database `X` and returns the remaining items as a `MatrixDatabase`
- `dist`: Distance function
- `X`: input database
- `epsilon`: items are appended incrementally; an item is accepted if there is not an item at distance `epsilon` and rejected if already exists an item at that distance
- `minrecall`: controls the quality of the approximation (between 0 and 1)
- `verbose`: verbose output
"""
function filter_neardup(dist::SemiMetric, X::AbstractDatabase; epsilon::Float64=0.1, minrecall::Float64=0.9, verbose=true)
    analysis = neardup_analysis(dist, X; epsilon=epsilon, minrecall=minrecall, verbose=verbose)
    # Only the database held by the resulting index is kept; the other fields
    # of the analysis are discarded.
    kept = database(analysis.idx)
    return MatrixDatabase(kept)
end
"""
umap_embeddings(
index::AbstractSearchIndex,
k = 15,
n_epochs = 100,
neg_sample_rate = 3,
tol = 1e-4,
layout = SpectralLayout(),
)
Computes 2D and 3D UMAP embeddings for the given database (i.e., in the form of `index`)
- `index`: the indexed metric database
- `k`: the number of neighbors for the embedding
- `n_epochs`: number of epochs to refine the embedding
- `neg_sample_rate`: negative sampling rate (3 or 5 give good results)
- `tol`: stop tolerance for the iterative optimization (early stopping)
- `layout`: initialization method for the embedding
"""
function umap_embeddings(
    index::AbstractSearchIndex,
    k = 15,
    n_epochs = 100,
    neg_sample_rate = 3,
    tol = 1e-4,
    layout = SpectralLayout(),
)
    # NOTE: every parameter after `index` is an optional *positional* argument
    # (the signature has no `;`).
    # Raising both `n_epochs` and `neg_sample_rate` improves the projection.
    # `@time` prints the timing of each fit.
    @time model2 = fit(UMAP, index; k=k, neg_sample_rate=neg_sample_rate, layout=layout, n_epochs=n_epochs, tol=tol)
    # The 3-dimensional model is derived from `model2` instead of being fit
    # from scratch.
    @time model3 = fit(model2, 3; neg_sample_rate=neg_sample_rate, n_epochs=n_epochs, tol=tol)
    return (e2 = predict(model2), e3 = predict(model3))
end
"""
    normcolors(V)

Rescale the values of `V` in place so that its minimum maps to 0 and its maximum maps to 1,
and return `V`.

When all the entries of `V` are equal there is no range to rescale by; `V` is filled
with zeros in that case instead of dividing by zero (which would produce `NaN`s).
Like `extrema`, it throws on an empty collection.
"""
function normcolors(V)
    lo, hi = extrema(V)
    span = hi - lo
    if iszero(span)
        # Constant input: avoid 0/0 and map everything to the lower bound.
        fill!(V, zero(eltype(V)))
    else
        # Single fused pass: shift, scale, and clamp against rounding overshoot.
        V .= clamp.((V .- lo) ./ span, 0, 1)
    end
    return V
end
#=
D = DataFrame(JSON.parse.(eachline("datasets/comp2023/IberLEF2023_HOMO-MEX_Es_train.json")))
D = DataFrame(JSON.parse.(eachline("datasets/comp2023/IberLEF2023_HOMO-MEX_Es_train.json")))
for r in eachrow(D[enns[:, 1], :])
println(r)
end
for r in eachrow(D[knns[:, 1], :])
println(r.klass => r.text)
end
for r in eachrow(D[knns[:, 2], :])
println(r.klass => r.text)
end
=#