Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Specialize sample for sparse weights #943

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

AntonOresten
Copy link

@AntonOresten AntonOresten commented Nov 26, 2024

This PR adds a new sample method for sparse weights, as well as tests. It brings the time complexity from O(n) to O(n_nonzero).

This would be useful for e.g. top-p sampling, where one might have on the order of 100k tokens to sample from, but only a few are considered.

Benchmarks across different sizes and densities

Results

This shows the dense baseline, and the relative performance increase to invoking sample with the generic method for sparse weights.

Dense vs Sparse vs Generic sampling:
size    density         dense_time      generic_time    sparse_time     speedup_dense   speedup_generic
----------------------------------------------------------------------------------------------------
10      0.10            8.200 ns        16.717 ns       20.900 ns       0.4x            0.8x
10      0.25            12.312 ns       29.648 ns       25.726 ns       0.5x            1.2x
10      0.50            15.000 ns       31.426 ns       31.000 ns       0.5x            1.0x
10      1.00            17.918 ns       46.821 ns       33.835 ns       0.5x            1.4x

100     0.01            53.799 ns       133.853 ns      21.421 ns       2.5x            6.2x
100     0.10            44.803 ns       235.024 ns      34.303 ns       1.3x            6.9x
100     0.25            54.095 ns       380.565 ns      40.302 ns       1.3x            9.4x
100     0.50            52.775 ns       454.237 ns      50.655 ns       1.0x            9.0x
100     1.00            51.160 ns       553.982 ns      69.706 ns       0.7x            7.9x

1000    0.01            376.093 ns      2.613 μs        34.102 ns       11.0x           76.6x
1000    0.10            405.793 ns      6.025 μs        70.871 ns       5.7x            85.0x
1000    0.25            393.353 ns      7.775 μs        128.072 ns      3.1x            60.7x
1000    0.50            383.527 ns      8.743 μs        219.973 ns      1.7x            39.7x
1000    1.00            384.167 ns      8.444 μs        398.155 ns      1.0x            21.2x

10000   0.01            3.533 μs        44.500 μs       69.706 ns       50.7x           638.4x
10000   0.10            3.778 μs        88.300 μs       403.333 ns      9.4x            218.9x
10000   0.25            3.689 μs        121.850 μs      940.230 ns      3.9x            129.6x
10000   0.50            3.720 μs        152.863 μs      1.880 μs        2.0x            81.3x
10000   1.00            3.744 μs        131.150 μs      3.750 μs        1.0x            35.0x
Benchmark setup
function benchmark_sparse_sample(; sizes=[10, 100, 1000, 10_000], densities=[0.01, 0.1, 0.25, 0.5, 1.0])
    println("Dense vs Sparse vs Generic sampling:")
    println("size\tdensity\t\tdense_time\tgeneric_time\tsparse_time\tspeedup_dense\tspeedup_generic")
    println("-" ^ 100)
    
    for n in sizes
        for density in densities
            n * density < 1 && continue
            nnz = round(Int, n * density)

            indices = sort!(sample(1:n, nnz, replace=false))
            values = rand(nnz)
            values ./= sum(values)
            sparse_vector = sparsevec(indices, values, n)

            sparse_weights = Weights(sparse_vector)
            dense_weights = Weights(collect(sparse_vector))
            
            dense = @benchmark sample($dense_weights)
            sparse = @benchmark sample($sparse_weights)
            generic = @benchmark invoke(sample, Tuple{AbstractRNG, AbstractWeights}, 
                    $(Random.default_rng()), $sparse_weights)
            
            dense_time = median(dense).time
            generic_time = median(generic).time
            sparse_time = median(sparse).time
            speedup_dense = dense_time / sparse_time
            speedup_generic = generic_time / sparse_time
            
            @printf("%-8d%-16.2f%-16s%-16s%-16s%.1fx\t\t%.1fx\n",
                    n, density, BenchmarkTools.prettytime(dense_time),
                    BenchmarkTools.prettytime(generic_time),
                    BenchmarkTools.prettytime(sparse_time), speedup_dense, speedup_generic)
        end
        println()
    end
end

Note: For small vector lengths (~10) and low densities (~0.2) the performance difference becomes noisy and less meaningful. The generic method can sometimes be faster in these cases due to less overhead when it happens to find the target probability mass early in the vector. However, for these small cases the absolute timing differences are negligible (few nanoseconds) and sparse storage isn't really beneficial anyway.

Note: The implementation uses SparseArrays.nonzeroinds, which is not public.

Copy link
Member

@nalimilan nalimilan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@@ -608,6 +608,11 @@ sample(wv::AbstractWeights) = sample(default_rng(), wv)
sample(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights) = a[sample(rng, wv)]
sample(a::AbstractArray, wv::AbstractWeights) = sample(default_rng(), a, wv)

function sample(rng::AbstractRNG, wv::AbstractWeights{<:Real,T,V}) where {T<:Real,V<:SparseVector{T}}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be enough, right?

Suggested change
function sample(rng::AbstractRNG, wv::AbstractWeights{<:Real,T,V}) where {T<:Real,V<:SparseVector{T}}
function sample(rng::AbstractRNG, wv::AbstractWeights{<:Real,<:Real,<:SparseVector}

@@ -608,6 +608,11 @@ sample(wv::AbstractWeights) = sample(default_rng(), wv)
sample(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights) = a[sample(rng, wv)]
sample(a::AbstractArray, wv::AbstractWeights) = sample(default_rng(), a, wv)

function sample(rng::AbstractRNG, wv::AbstractWeights{<:Real,T,V}) where {T<:Real,V<:SparseVector{T}}
i = sample(rng, Weights(nonzeros(wv.values), sum(wv)))
return SparseArrays.nonzeroinds(wv.values)[i]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return SparseArrays.nonzeroinds(wv.values)[i]
return rowvals(wv.values)[i]

@@ -41,6 +41,7 @@ for wv in (
weights([0.2, 0.8, 0.4, 0.6]),
weights([2, 8, 4, 6]),
weights(Float32[0.2, 0.8, 0.4, 0.6]),
weights(sparsevec([0, 8, 0, 6])),
Weights(Float32[0.2, 0.8, 0.4, 0.6], 2),
Weights([2, 8, 4, 6], 20.0),
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would make sense to also test line 137 below on sparse weights, right?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants