Specialize sample for sparse weights #943
base: master
Conversation
Thanks!
@@ -608,6 +608,11 @@ sample(wv::AbstractWeights) = sample(default_rng(), wv)
sample(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights) = a[sample(rng, wv)]
sample(a::AbstractArray, wv::AbstractWeights) = sample(default_rng(), a, wv)

function sample(rng::AbstractRNG, wv::AbstractWeights{<:Real,T,V}) where {T<:Real,V<:SparseVector{T}}
This should be enough, right?
Suggested change:
- function sample(rng::AbstractRNG, wv::AbstractWeights{<:Real,T,V}) where {T<:Real,V<:SparseVector{T}}
+ function sample(rng::AbstractRNG, wv::AbstractWeights{<:Real,<:Real,<:SparseVector})
@@ -608,6 +608,11 @@ sample(wv::AbstractWeights) = sample(default_rng(), wv)
sample(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights) = a[sample(rng, wv)]
sample(a::AbstractArray, wv::AbstractWeights) = sample(default_rng(), a, wv)

function sample(rng::AbstractRNG, wv::AbstractWeights{<:Real,T,V}) where {T<:Real,V<:SparseVector{T}}
    i = sample(rng, Weights(nonzeros(wv.values), sum(wv)))
    return SparseArrays.nonzeroinds(wv.values)[i]
Suggested change:
- return SparseArrays.nonzeroinds(wv.values)[i]
+ return rowvals(wv.values)[i]
@@ -41,6 +41,7 @@ for wv in (
    weights([0.2, 0.8, 0.4, 0.6]),
    weights([2, 8, 4, 6]),
    weights(Float32[0.2, 0.8, 0.4, 0.6]),
    weights(sparsevec([0, 8, 0, 6])),
    Weights(Float32[0.2, 0.8, 0.4, 0.6], 2),
    Weights([2, 8, 4, 6], 20.0),
)
It would make sense to also test line 137 below on sparse weights, right?
This PR adds a new sample method for sparse weights, as well as tests. It brings the time complexity from O(n) to O(n_nonzero). This would be useful for e.g. top-p sampling, where one might have on the order of 100k tokens to sample from, but only a few are considered.
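To illustrate the idea, here is a rough standalone sketch of what the specialized method does (variable names are illustrative; this assumes StatsBase and SparseArrays are loaded): it samples only among the stored nonzero entries, then maps the drawn position back to an index in the full-length vector.

```julia
using Random, SparseArrays, StatsBase

# Illustrative sparse weight vector: only indices 2 and 4 carry mass.
wv = weights(sparsevec([0.0, 0.8, 0.0, 0.2]))

# Sample among the stored (nonzero) entries only: O(n_nonzero) work
# instead of scanning all n entries.
i = sample(Random.default_rng(), Weights(nonzeros(wv.values), sum(wv)))

# Map the draw back to its index in the full-length vector.
idx = SparseArrays.nonzeroinds(wv.values)[i]  # always 2 or 4 here
```

For a vector with 100k entries but only a handful of nonzeros, the inner `sample` call touches just those few stored weights.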
Benchmarks across different sizes and densities
Results
This shows the dense baseline, and the relative performance increase over invoking sample with the generic method on sparse weights.
Benchmark setup
Note: For small vector lengths (~10) and low densities (~0.2) the performance difference becomes noisy and less meaningful. The generic method can sometimes be faster in these cases due to less overhead when it happens to find the target probability mass early in the vector. However, for these small cases the absolute timing differences are negligible (few nanoseconds) and sparse storage isn't really beneficial anyway.
Note: The implementation uses SparseArrays.nonzeroinds, which is not public.
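A possible public alternative, per the review suggestion above, is rowvals. The sketch below assumes rowvals accepts a SparseVector (as the suggestion implies) and shows that it would return the same stored-index vector as the internal accessor:

```julia
using SparseArrays

sv = sparsevec([0, 8, 0, 6])

# Internal accessor used by the PR (not part of the public API).
internal = SparseArrays.nonzeroinds(sv)  # stored indices, here [2, 4]

# Public accessor suggested in review; assumed here to work on a
# SparseVector the way it does on the rows of a SparseMatrixCSC.
public = rowvals(sv)

internal == public  # expected to hold if the rowvals method exists
```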