Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a faster implementation of AliasTables #1848

Merged
merged 7 commits into from
Apr 20, 2024

Conversation

LilithHafner
Copy link
Contributor

@LilithHafner LilithHafner commented Apr 8, 2024

AliasTables.jl provides a high performance, high precision alias table implementation. This PR switches from the use of StatsBase implementation details to the public API of the AliasTables package. After making that switch, I also re-tuned the sampling threshold for Multinomial so that it selects the right algorithm more appropriately.

Benchmarks using Chairmarks.jl ("a => b" means I got the result "a" when running on master and the result "b" when running on PR branch):

Multinomial sampling (some multinomial changes backed out of this PR, see #1848 (comment), #1834. The multinomial bnechamrks here are as of 1c3530c)

julia> using Chairmarks

julia> @b Distributions.Multinomial(100, normalize(rand(100))) rand(_, 100)
420.086 μs (6 allocs: 81.672 KiB) => 17.250 μs (4 allocs: 81.172 KiB)

julia> @b Distributions.Multinomial(100000, normalize(rand(100))) rand(_, 100)
577.754 μs (6 allocs: 81.672 KiB) => 549.170 μs (4 allocs: 81.172 KiB)

julia> @b Distributions.Multinomial(100, normalize(rand(1000))) rand(_, 100)
1.515 ms (6 allocs: 813.047 KiB) => 37.750 μs (4 allocs: 805.359 KiB)

julia> @b Distributions.Multinomial(1000, normalize(rand(1000))) rand(_, 1000)
43.298 ms (6 allocs: 7.660 MiB) => 2.326 ms (4 allocs: 7.653 MiB)

AliasTable sampling (same benchmarks as #1831)

julia> @b Distributions.AliasTable(normalize(rand(10))) rand
6.916 ns => 2.633 ns

julia> @b Distributions.AliasTable(normalize(rand(100))) rand
9.717 ns => 2.713 ns

julia> @b Distributions.AliasTable(normalize(rand(100_000))) rand
11.408 ns => 2.874 ns

AliasTable construction

julia> @b normalize(rand(100_000)) Distributions.AliasTable # Regression
1.218 ms (8 allocs: 3.052 MiB) => 1.949 ms (4 allocs: 2.763 MiB) 

julia> @b normalize(rand(100)) Distributions.AliasTable
936.609 ns (4 allocs: 3.500 KiB) => 901.031 ns (2 allocs: 3.000 KiB)

julia> @b normalize(rand(10)) Distributions.AliasTable # Regression
117.140 ns (4 allocs: 576 bytes) => 165.119 ns (2 allocs: 480 bytes)

However, these constructor time comparisons are not apples to apples. The old ones give the wrong answer without normalization (see #832) while the new ones do not require normalization (fixes #832).

The new version is slow largely because of unreasonably strict precision guarantees in the normalization code. I can add an option (which could be on by default for Distributions.jl) for faster and less precise construction. However, I haven not added this lower precision option because I doubt construction time is typically a limiting factor in AliasTable sampling speed (one would have to generate the weights at a rate of less than 20ns/weight for that to be an issue)

The new version is also faster (and more precise) for integer inputs (not including the old runtimes because the old algorithm produced incorrect answers)

julia> @b rand(1:1000, 10) Distributions.AliasTable
91.000 ns (2 allocs: 480 bytes)

julia> @b rand(1:1000, 100) Distributions.AliasTable
596.056 ns (2 allocs: 3.000 KiB)

julia> @b rand(1:1000, 100_000) Distributions.AliasTable
1.027 ms (4 allocs: 2.763 MiB)

The sampling assembly is now branch-less and shorter

Before

julia> @code_native debuginfo=:none rand(Random.default_rng(), at)
        .text
        .file   "rand"
        .globl  julia_rand_13388                // -- Begin function julia_rand_13388
        .p2align        2
        .type   julia_rand_13388,@function
julia_rand_13388:                       // @julia_rand_13388
; Function Signature: rand(Random.TaskLocalRNG, Distributions.AliasTable)
// %bb.0:                               // %top
        //DEBUG_VALUE: rand:s <- [DW_OP_deref] [$x0+0]
        //DEBUG_VALUE: rand:s <- [DW_OP_deref] [$x0+0]
        sub     sp, sp, #96
        stp     x29, x30, [sp, #48]             // 16-byte Folded Spill
        str     x21, [sp, #64]                  // 8-byte Folded Spill
        stp     x20, x19, [sp, #80]             // 16-byte Folded Spill
        add     x29, sp, #48
        str     xzr, [sp, #32]
        str     xzr, [sp, #24]
        str     xzr, [sp, #16]
        //APP
        mrs     x8, TPIDR_EL0
        //NO_APP
        ldr     x20, [x8, #16]
        mov     w8, #4
        str     x8, [sp, #16]
        ldr     x8, [x20]
        str     x8, [sp, #24]
        add     x8, sp, #16
        str     x8, [x20]
        ldr     x21, [x0, #8]
        ldr     x8, [x21, #16]
        cmp     x8, #0
        b.le    .LBB0_4
// %bb.1:                               // %L17
        mov     x19, x0
        //DEBUG_VALUE: rand:s <- [DW_OP_deref] [$x19+0]
        mov     w9, #1
        stp     x9, x8, [sp]
        mov     x0, sp
        bl      j_rand_13402
        ldp     x8, x9, [x20, #-56]
        ldp     x10, x11, [x20, #-40]
        add     x12, x11, x8
        ror     x12, x12, #41
        add     x12, x12, x8
        eor     x10, x10, x8
        eor     x11, x11, x9
        eor     x13, x10, x9
        eor     x8, x11, x8
        eor     x9, x10, x9, lsl #17
        ror     x10, x11, #19
        stp     x8, x13, [x20, #-56]
        stp     x9, x10, [x20, #-40]
        lsr     x8, x12, #11
        ucvtf   d0, x8
        mov     x8, #4368491638549381120
        fmov    d1, x8
        fmul    d0, d0, d1
        ldr     x8, [x19]
        ldr     x9, [x8]
        sub     x8, x0, #1
        ldr     d1, [x9, x8, lsl #3]
        fcmp    d0, d1
        b.mi    .LBB0_3
// %bb.2:                               // %L75
        ldr     x9, [x21]
        ldr     x0, [x9, x8, lsl #3]
.LBB0_3:                                // %L95
        ldr     x8, [sp, #24]
        str     x8, [x20]
        ldp     x20, x19, [sp, #80]             // 16-byte Folded Reload
        ldr     x21, [sp, #64]                  // 8-byte Folded Reload
        ldp     x29, x30, [sp, #48]             // 16-byte Folded Reload
        add     sp, sp, #96
        ret
.LBB0_4:                                // %L14
        adrp    x0, ".Ljl_global#13395.jit"
        add     x0, x0, :lo12:".Ljl_global#13395.jit"
        bl      j_ArgumentError_13394
        mov     x19, x0
        str     x0, [sp, #32]
        ldr     x0, [x20, #16]
        mov     x20, #36592
        movk    x20, #43617, lsl #16
        movk    x20, #65534, lsl #32
        mov     w1, #752
        mov     w2, #16
        mov     x3, #36592
        movk    x3, #43617, lsl #16
        movk    x3, #65534, lsl #32
        bl      ijl_gc_pool_alloc_instrumented
        stp     x20, x19, [x0, #-8]
        bl      ijl_throw
.Lfunc_end0:
        .size   julia_rand_13388, .Lfunc_end0-julia_rand_13388
                                        // -- End function
.set ".L+Core.ArgumentError#13397.jit", 281469245296368
        .size   ".L+Core.ArgumentError#13397.jit", 8
.set ".Ljl_global#13395.jit", 281469267606560
        .size   ".Ljl_global#13395.jit", 8
        .section        ".note.GNU-stack","",@progbits

After

julia> @code_native debuginfo=:none rand(Random.default_rng(), at)
        .text
        .file   "rand"
        .globl  julia_rand_8708                 // -- Begin function julia_rand_8708
        .p2align        2
        .type   julia_rand_8708,@function
julia_rand_8708:                        // @julia_rand_8708
; Function Signature: rand(Random.TaskLocalRNG, Distributions.AliasTable)
// %bb.0:                               // %top
        //DEBUG_VALUE: rand:s <- [DW_OP_deref] [$x0+0]
        //DEBUG_VALUE: rand:s <- [DW_OP_deref] [$x0+0]
        stp     x29, x30, [sp, #-16]!           // 16-byte Folded Spill
        mov     x29, sp
        //APP
        mrs     x8, TPIDR_EL0
        //NO_APP
        ldr     x8, [x8, #16]
        ldp     x9, x10, [x8, #-56]
        ldp     x11, x12, [x8, #-40]
        add     x13, x12, x9
        ror     x13, x13, #41
        eor     x11, x11, x9
        eor     x12, x12, x10
        eor     x14, x11, x10
        eor     x10, x11, x10, lsl #17
        ror     x11, x12, #19
        eor     x12, x12, x9
        stp     x12, x14, [x8, #-56]
        stp     x10, x11, [x8, #-40]
        ldp     x10, x8, [x0]
        ldp     x11, x8, [x8]
        clz     x12, x11
        add     x12, x12, #1
        add     x9, x13, x9
        and     x10, x10, x9
        lsr     x9, x9, x12
        cmp     x11, #1
        mov     w11, #1
        csinc   x9, x11, x9, ls
        add     x8, x8, x9, lsl #4
        ldp     x11, x8, [x8, #-16]
        cmp     x10, x11
        csel    x8, x8, xzr, lo
        add     x0, x8, x9
        ldp     x29, x30, [sp], #16             // 16-byte Folded Reload
        ret
.Lfunc_end0:
        .size   julia_rand_8708, .Lfunc_end0-julia_rand_8708
                                        // -- End function
        .section        ".note.GNU-stack","",@progbits

For how this speedup is possible, see https://aliastables.lilithhafner.com/dev/#Implementation-details.

@codecov-commenter
Copy link

codecov-commenter commented Apr 8, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 85.96%. Comparing base (f33af97) to head (c9137ef).

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1848      +/-   ##
==========================================
+ Coverage   85.94%   85.96%   +0.02%     
==========================================
  Files         144      144              
  Lines        8656     8647       -9     
==========================================
- Hits         7439     7433       -6     
+ Misses       1217     1214       -3     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@devmotion
Copy link
Member

Given the performance improvements, shouldn't this be used in or even contributed to StatsBase?

src/samplers/aliastable.jl Outdated Show resolved Hide resolved
src/samplers/multinomial.jl Outdated Show resolved Hide resolved
Copy link
Member

@devmotion devmotion left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, that's a great performance improvement!

test/univariate/discrete/categorical.jl Outdated Show resolved Hide resolved
@devmotion devmotion merged commit b670fee into JuliaStats:master Apr 20, 2024
11 of 14 checks passed
@LilithHafner LilithHafner deleted the lh/alias-tables branch April 20, 2024 12:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

AliasTable method is not working for unbalanced values
4 participants