Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Segfault in AdvancedVI when involving Diagonal specializations #2042

Closed
Red-Portal opened this issue Nov 3, 2024 · 12 comments
Closed

Segfault in AdvancedVI when involving Diagonal specializations #2042

Red-Portal opened this issue Nov 3, 2024 · 12 comments

Comments

@Red-Portal
Copy link

Red-Portal commented Nov 3, 2024

Hitting segfaults in the following example. This seems unique to the case where the scale is a Diagonal (i.e., S <: Diagonal in MvLocationScale{S,D,L,E}).

using ADTypes
using Enzyme
using Distributions
using LinearAlgebra
using StableRNGs
using Random
using Optimisers
using Functors
using Test
using LogDensityProblems

# Multivariate distribution type holding a `location`, a `scale`, a base
# distribution `dist`, and a `scale_eps` value. Every field type is a type
# parameter; `D` must be a `ContinuousDistribution` and `E` a `Real`.
struct MvLocationScale{
    S, D <: ContinuousDistribution, L, E <: Real
} <: ContinuousMultivariateDistribution
    location ::L
    scale    ::S
    dist     ::D
    scale_eps::E
end

# Restrict the Functors.jl view of the struct to the `location` and `scale` fields.
Functors.@functor MvLocationScale (location, scale)

# Wrapper around a template `MvLocationScale` whose scale type is a `Diagonal`.
# Instances are callable (see the call method defined for this type).
struct RestructureMeanField{S<:Diagonal,D,L,E}
    model::MvLocationScale{S,D,L,E}
end

# Rebuild an `MvLocationScale` from the flat vector `flat`: the first
# `div(length(flat), 2)` entries become `location`, the last that many become the
# diagonal of `scale`; `dist` and `scale_eps` are taken from the stored template.
function (re::RestructureMeanField)(flat::AbstractVector)
    n_dims = div(length(flat), 2)  # integer division: an odd-length `flat` leaves its middle entry unused
    location = first(flat, n_dims)
    scale = Diagonal(last(flat, n_dims))
    return MvLocationScale(location, scale, re.model.dist, re.model.scale_eps)
end

# `Optimisers.destructure` method for a `Diagonal`-scale `MvLocationScale`.
# Returns the flat vector `vcat(location, diag(scale))` together with a
# `RestructureMeanField` wrapping `q`.
function Optimisers.destructure(q::MvLocationScale{<:Diagonal,D,L,E}) where {D,L,E}
    (; location, scale, dist) = q  # `dist` is destructured but not used below
    flat = vcat(location, diag(scale))
    return flat, RestructureMeanField(q)
end

# This specialization improves AD performance of the sampling path
# Sampling method for a `Diagonal`-scale `MvLocationScale`: draws base noise with
# `rand(rng, dist, n_dims, num_samples)`, scales it elementwise by the diagonal of
# `scale`, and shifts it by `location` via broadcasting.
# NOTE(review): the result is presumably an n_dims × num_samples matrix — confirm
# against the Distributions.jl `rand` method for `dist`.
function Distributions.rand(
    rng::AbstractRNG, q::MvLocationScale{<:Diagonal,D,L}, num_samples::Int
) where {L,D}
    (; location, scale, dist) = q
    n_dims = length(location)
    scale_diag = diag(scale)  # plain vector of the diagonal entries
    return scale_diag .* rand(rng, dist, n_dims, num_samples) .+ location
end


# Empty type; an instance is passed as the `problem` entry of `aux` in `mwe`.
struct Problem end

# Log-density of the toy target used by this reproducer: the sum of the entries
# of `params`.
#
# The method is restricted to `Problem` so that it does not install an untyped
# catch-all method on a function owned by LogDensityProblems (type piracy),
# which would change `logdensity` for every other type in the session.
function LogDensityProblems.logdensity(problem::Problem, params)
    return sum(params)
end

# Average of `LogDensityProblems.logdensity(prob, x)` over the columns `x` of
# `samples`.
function estimate_energy_with_samples(prob, samples)
    return mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachcol(samples))
end

# Objective that `mwe` hands to Enzyme. Rebuilds the distribution from the flat
# parameters `params′`, draws 3 samples from it, and returns the negated mean
# log-density of those samples.
function estimate_repgradelbo_ad_forward(params′, aux)
    (; rng, problem, adtype, restructure, q_stop) = aux  # `adtype` and `q_stop` are not used below
    q = restructure(params′)::typeof(restructure.model)  # assert the rebuilt type equals the template's type
    samples = rand(rng, q, 3)
    energy = estimate_energy_with_samples(problem, samples)
    elbo = energy
    return -elbo
end

function mwe(T)
	adtype = AutoEnzyme()
    d    = 10
    seed = (0x38bef07cf9cc549d)
    rng  = StableRNG(seed)

    q = MvLocationScale(
        zeros(T, d), Diagonal(ones(T, d)), Normal{T}(zero(T), one(T)), T(1e-5)
    )
    params, restructure, = Optimisers.destructure(q)

    aux = (
        rng=rng,
        adtype=adtype,
        problem=Problem(),
        restructure=restructure,
        q_stop=q,
    )
    
    ∇x = zero(params)
    fill!(∇x, zero(eltype(∇x)))
    _, y = Enzyme.autodiff(
        set_runtime_activity(Enzyme.ReverseWithPrimal, true),
        estimate_repgradelbo_ad_forward,
        Enzyme.Active,
        Enzyme.Duplicated(params, ∇x),
        Enzyme.Const(aux)
    )
    ∇x
end

# Run the reproducer with Float64 parameters.
mwe(Float64)
[16416] signal (11.1): Segmentation fault
in expression starting at REPL[22]:1
getindex at ./essentials.jl:13 [inlined]
getindex at ./subarray.jl:323 [inlined]
mapreduce_impl at ./reduce.jl:260
unknown function (ip: 0x2)
macro expansion at /home/krkim/.julia/packages/Enzyme/VSRgT/src/compiler.jl:8197 [inlined]
enzyme_call at /home/krkim/.julia/packages/Enzyme/VSRgT/src/compiler.jl:7760 [inlined]
CombinedAdjointThunk at /home/krkim/.julia/packages/Enzyme/VSRgT/src/compiler.jl:7533 [inlined]
autodiff at /home/krkim/.julia/packages/Enzyme/VSRgT/src/Enzyme.jl:491 [inlined]
autodiff at /home/krkim/.julia/packages/Enzyme/VSRgT/src/Enzyme.jl:512 [inlined]
mwe at ./REPL[21]:22
unknown function (ip: 0x7f05dbd13282)
_jl_invoke at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:3077
jl_apply at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
do_call at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/interpreter.c:126
eval_value at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/interpreter.c:223
eval_stmt_value at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/interpreter.c:174 [inlined]
eval_body at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/interpreter.c:617
jl_interpret_toplevel_thunk at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/interpreter.c:775
jl_toplevel_eval_flex at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/toplevel.c:934
jl_toplevel_eval_flex at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/toplevel.c:877
ijl_toplevel_eval_in at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/toplevel.c:985
eval at ./boot.jl:385 [inlined]
eval_user_input at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:150
repl_backend_loop at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:246
#start_repl_backend#46 at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:231
start_repl_backend at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:228
_jl_invoke at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:3077
#run_repl#59 at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:389
run_repl at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:375
jfptr_run_repl_91949.1 at /home/krkim/.julia/juliaup/julia-1.10.6+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:3077
#1013 at ./client.jl:437
jfptr_YY.1013_82918.1 at /home/krkim/.julia/juliaup/julia-1.10.6+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:3077
jl_apply at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
jl_f__call_latest at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/builtins.c:812
#invokelatest#2 at ./essentials.jl:892 [inlined]
invokelatest at ./essentials.jl:889 [inlined]
run_main_repl at ./client.jl:421
exec_options at ./client.jl:338
_start at ./client.jl:557
jfptr__start_82944.1 at /home/krkim/.julia/juliaup/julia-1.10.6+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/gf.c:3077
jl_apply at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
true_main at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/jlapi.c:582
jl_repl_entrypoint at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/src/jlapi.c:731
main at /cache/build/builder-amdci5-5/julialang/julia-release-1-dot-10/cli/loader_exe.c:58
unknown function (ip: 0x7f06480d5e07)
__libc_start_main at /usr/lib/libc.so.6 (unknown line)
unknown function (ip: 0x4010b8)
Allocations: 38356665 (Pool: 38283792; Big: 72873); GC: 56
zsh: segmentation fault (core dumped)  julia

This is on 1.10 BTW:

Julia Version 1.10.6
Commit 67dffc4a8ae (2024-10-28 12:23 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 20 × 12th Gen Intel(R) Core(TM) i7-12700H
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, alderlake)
Threads: 1 default, 0 interactive, 1 GC (on 20 virtual cores)
@Red-Portal
Copy link
Author

related: TuringLang/AdvancedVI.jl#129

@wsmoses
Copy link
Member

wsmoses commented Nov 3, 2024

Hm, are you able to reduce this?

@Red-Portal
Copy link
Author

Let me give it a shot

@Red-Portal
Copy link
Author

Red-Portal commented Nov 3, 2024

@wsmoses Would this be more useful? The custom Restructure seems to be necessary to reproduce the problem.

using Enzyme
using Distributions
using LinearAlgebra
using StableRNGs
using Random
using Optimisers
using Functors

# Location-scale distribution type: fields `location`, `scale`, base distribution
# `dist`, and `scale_eps`, each with its own type parameter (`D` is bounded by
# `ContinuousDistribution`, `E` by `Real`).
struct MvLocationScale{
    S, D <: ContinuousDistribution, L, E <: Real
} <: ContinuousMultivariateDistribution
    location ::L
    scale    ::S
    dist     ::D
    scale_eps::E
end

# Only `location` and `scale` are exposed to Functors.jl.
Functors.@functor MvLocationScale (location, scale)

# Holds a template `MvLocationScale` with a `Diagonal` scale; the thread reports
# that this custom restructure type is needed to trigger the crash.
struct RestructureMeanField{S<:Diagonal,D,L,E}
    model::MvLocationScale{S,D,L,E}
end

# Calling the wrapper splits `flat` into a leading half (`location`) and a
# trailing half (diagonal of `scale`) and builds a new `MvLocationScale`, reusing
# `dist` and `scale_eps` from `re.model`.
function (re::RestructureMeanField)(flat::AbstractVector)
    n_dims = div(length(flat), 2)
    location = first(flat, n_dims)
    scale = Diagonal(last(flat, n_dims))
    return MvLocationScale(location, scale, re.model.dist, re.model.scale_eps)
end

# Flatten a `Diagonal`-scale `MvLocationScale` into `vcat(location, diag(scale))`
# and return that vector plus a `RestructureMeanField` built from `q`.
function Optimisers.destructure(q::MvLocationScale{<:Diagonal,D,L,E}) where {D,L,E}
    (; location, scale, dist) = q  # `dist` is unpacked but unused here
    flat = vcat(location, diag(scale))
    return flat, RestructureMeanField(q)
end

# Sampling specialization for a `Diagonal` scale: base draws from `dist` are
# multiplied elementwise by the diagonal entries and shifted by `location`.
function Distributions.rand(
    rng::AbstractRNG, q::MvLocationScale{<:Diagonal,D,L}, num_samples::Int
) where {L,D}
    (; location, scale, dist) = q
    n_dims = length(location)
    scale_diag = diag(scale)
    return scale_diag .* rand(rng, dist, n_dims, num_samples) .+ location
end

# Reduced objective: rebuild the distribution from `params′`, draw 3 samples with
# the default RNG, and return the mean over all sample entries.
function estimate_repgradelbo_ad_forward(params′, aux)
    (; restructure,) = aux
    q = restructure(params′)::typeof(restructure.model)  # type assertion against the template's type
    samples = rand(Random.default_rng(), q, 3)
    mean(samples)
end

function mwe(T)
	d = 2
    q = MvLocationScale(
        zeros(T, d), Diagonal(ones(T, d)), Normal{T}(zero(T), one(T)), T(1e-5)
    )
    params, restructure, = Optimisers.destructure(q)

    aux = (
        restructure=restructure,
    )
    
    ∇x = zero(params)
    fill!(∇x, zero(eltype(∇x)))
    _, y = Enzyme.autodiff(
        set_runtime_activity(Enzyme.ReverseWithPrimal, true),
        estimate_repgradelbo_ad_forward,
        Enzyme.Active,
        Enzyme.Duplicated(params, ∇x),
        Enzyme.Const(aux)
    )
    ∇x
end

# Trigger the reduced reproducer with Float64.
mwe(Float64)

@wsmoses
Copy link
Member

wsmoses commented Nov 3, 2024

That is, but anything further you can do will help.

@Red-Portal
Copy link
Author

Stripped a little more

using Enzyme
using Distributions
using LinearAlgebra
using Random

# Three-field version of the distribution type (no `scale_eps`): `location`,
# `scale`, and base distribution `dist`, with `D` bounded by
# `ContinuousDistribution`.
struct MvLocationScale{
    S, D <: ContinuousDistribution, L
} <: ContinuousMultivariateDistribution
    location ::L
    scale    ::S
    dist     ::D
end

# Stores the template distribution (scale type constrained to `Diagonal`) used
# when rebuilding from a flat vector.
struct RestructureMeanField{S<:Diagonal,D,L}
    model::MvLocationScale{S,D,L}
end

# Build a new `MvLocationScale` from `flat`: first half -> `location`, last half
# -> diagonal of `scale`, `dist` copied from the template.
function (re::RestructureMeanField)(flat::AbstractVector)
    n_dims = div(length(flat), 2)
    location = first(flat, n_dims)
    scale = Diagonal(last(flat, n_dims))
    return MvLocationScale(location, scale, re.model.dist)
end

# Local `destructure` (no longer a method of `Optimisers.destructure`): returns
# `vcat(location, diag(scale))` and a `RestructureMeanField` wrapping `q`.
function destructure(q::MvLocationScale{<:Diagonal,D,L}) where {D,L}
    (; location, scale, dist) = q  # `dist` is unpacked but unused here
    flat = vcat(location, diag(scale))
    return flat, RestructureMeanField(q)
end

# `rand` method for the `Diagonal`-scale case: elementwise
# `diag(scale) .* noise .+ location`, where `noise` comes from
# `rand(rng, dist, n_dims, num_samples)`.
function Distributions.rand(
    rng::AbstractRNG, q::MvLocationScale{<:Diagonal,D,L}, num_samples::Int
) where {L,D}
    (; location, scale, dist) = q
    n_dims = length(location)
    scale_diag = diag(scale)
    return scale_diag .* rand(rng, dist, n_dims, num_samples) .+ location
end

# Function differentiated by Enzyme: restructure `params′` into a distribution,
# sample 3 draws from it, and return the mean of the samples.
function estimate_repgradelbo_ad_forward(params′, aux)
    (; restructure,) = aux
    q = restructure(params′)::typeof(restructure.model)
    samples = rand(Random.default_rng(), q, 3)
    mean(samples)
end

function mwe(T)
	d = 2
    q = MvLocationScale(zeros(T, d), Diagonal(ones(T, d)), Normal{T}(zero(T), one(T)))
    params, restructure, = destructure(q)
    aux = (restructure=restructure,)
    
    ∇x = zero(params)
    fill!(∇x, zero(eltype(∇x)))
    _, y = Enzyme.autodiff(
        set_runtime_activity(Enzyme.ReverseWithPrimal, true),
        estimate_repgradelbo_ad_forward,
        Enzyme.Active,
        Enzyme.Duplicated(params, ∇x),
        Enzyme.Const(aux)
    )
    ∇x
end

# Run the stripped-down reproducer with Float64.
mwe(Float64)

@wsmoses
Copy link
Member

wsmoses commented Nov 3, 2024

Any chance you can get rid of the distributions dependency (ccalls are fine)

@Red-Portal
Copy link
Author

Red-Portal commented Nov 3, 2024

The problem seemed to go away if I removed the dist field, so I couldn't simply axe it out. At this point, it is a bit of shooting in the dark for me to try to replicate the problem without dist.

edit: Actually, let me try something

@Red-Portal
Copy link
Author

Red-Portal commented Nov 3, 2024

Okay, the following works, so it's not obvious to me how to remove Distributions while keeping the bug:

using Enzyme
using LinearAlgebra
using Random
using Statistics

# Field-less stand-in for the Distributions.jl normal used in earlier snippets.
struct FakeNormal end

# Distributions-free version of the container: no supertype and no bounds on the
# type parameters. Fields are `location`, `scale`, and `dist`.
struct MvLocationScale{S, D, L}
    location ::L
    scale    ::S
    dist     ::D
end

# Template holder for rebuilding; the scale type parameter must be a `Diagonal`.
struct RestructureMeanField{S<:Diagonal,D,L}
    model::MvLocationScale{S,D,L}
end

# Split `flat` into its first and last `div(length(flat), 2)` entries and build
# an `MvLocationScale` from them, keeping the template's `dist`.
function (re::RestructureMeanField)(flat::AbstractVector)
    n_dims = div(length(flat), 2)
    location = first(flat, n_dims)
    scale = Diagonal(last(flat, n_dims))
    return MvLocationScale(location, scale, re.model.dist)
end

# Return the flat parameter vector `vcat(location, diag(scale))` and the
# matching `RestructureMeanField`.
function destructure(q::MvLocationScale{<:Diagonal,D,L}) where {D,L}
    (; location, scale, dist) = q  # `dist` is unpacked but unused here
    flat = vcat(location, diag(scale))
    return flat, RestructureMeanField(q)
end

# Sampler for `FakeNormal`: returns `randn(rng, n_dims, num_samples)`.
# NOTE(review): this is written as `rand`, not `Base.rand`/`Random.rand`, so it
# presumably defines a new function in the enclosing module instead of extending
# the standard one — confirm this is intended.
function rand(
    rng::AbstractRNG, ::FakeNormal, n_dims::Int, num_samples::Int
)
	return randn(rng, n_dims, num_samples)
end

# Sampler for a `Diagonal`-scale `MvLocationScale`: calls the four-argument
# `rand(rng, dist, n_dims, num_samples)` for base noise, then scales by
# `diag(scale)` and shifts by `location` elementwise.
function rand(
    rng::AbstractRNG, q::MvLocationScale{<:Diagonal,D,L}, num_samples::Int
) where {L,D}
    (; location, scale, dist) = q
    n_dims = length(location)
    scale_diag = diag(scale)
    return scale_diag .* rand(rng, dist, n_dims, num_samples) .+ location
end

# Objective for Enzyme: rebuild the distribution from `params′`, draw 3 samples
# with the default RNG, return their mean.
function estimate_repgradelbo_ad_forward(params′, aux)
    (; restructure,) = aux
    q = restructure(params′)::typeof(restructure.model)
    samples = rand(Random.default_rng(), q, 3)
    mean(samples)
end

# Distributions-free variant of the reproducer; the thread reports that this
# version runs without crashing. Returns the shadow vector `∇x`.
function mwe(T)
    d = 2
    q = MvLocationScale(zeros(T, d), Diagonal(ones(T, d)), FakeNormal())
    params, restructure, = destructure(q)
    aux = (restructure=restructure,)
    
    ∇x = zero(params)
    fill!(∇x, zero(eltype(∇x)))  # redundant after `zero(params)`; kept as posted
    _, y = Enzyme.autodiff(
        set_runtime_activity(Enzyme.ReverseWithPrimal, true),
        estimate_repgradelbo_ad_forward,
        Enzyme.Active,
        Enzyme.Duplicated(params, ∇x),
        Enzyme.Const(aux)
    )
    ∇x
end

# Run the Distributions-free variant with Float64.
mwe(Float64)

@wsmoses
Copy link
Member

wsmoses commented Nov 3, 2024

using Enzyme
using LinearAlgebra
using Random
using Statistics

# Enzyme debugging switches.
# NOTE(review): presumably these make Enzyme print the IR it processes — confirm
# against the Enzyme.jl documentation.
Enzyme.API.printall!(true)
Enzyme.Compiler.DumpPostOpt[] = true

# Plain two-field stand-in for a normal distribution; `myxval` uses the fields as
# `sigma * z + off`.
struct MyNormal
    sigma::Float64
    off::Float64
end

# Unconstrained three-field container (`location`, `scale`, `dist`), each field
# with its own type parameter.
struct MvLocationScale{
    S, D, L
}
    location ::L
    scale    ::S
    dist     ::D
end

# Build an `MvLocationScale` whose `scale` is a `Diagonal` over the first half of
# `flat`; `location` is set to `nothing`. Marked `@noinline` so the call is kept
# as a separate function.
@noinline function law(dist, flat::AbstractVector)
    # Call the runtime's `jl_` C function on `flat`.
    # NOTE(review): presumably a debug print of the value — confirm.
    ccall(:jl_, Cvoid, (Any,), flat)
    n_dims = div(length(flat), 2)
    data = first(flat, n_dims)
    scale = Diagonal(data)
    return MvLocationScale(nothing, scale, dist)
end

# Return only the diagonal of `q.scale` as the parameter vector.
function destructure(q::MvLocationScale)
    return diag(q.scale)
end


# Affine map `d.sigma * z + d.off`, computed with `muladd`.
myxval(d::MyNormal, z::Real) = muladd(d.sigma, z, d.off)

# Apply `myxval(d, ·)` to every entry of `A` in place and return `A`.
# `rng` is unused: the random fill is commented out, so `A` keeps whatever values
# the caller put in before the affine map is applied.
function myrand!(rng::AbstractRNG, d::MyNormal, A::AbstractArray{<:Real})
    # randn!(rng, A)
    map!(Base.Fix1(myxval, d), A, A)
    return A
end

# Produce a 2×3 "sample" matrix: a matrix of ones is passed through `myrand!`
# with `MyNormal(1.0, 0.0)` and then scaled by the first entry of `q.scale`.
# `q.dist` is not read; a fresh `MyNormal` is constructed locally.
function man(q::MvLocationScale)
    dist = MyNormal(1.0, 0.0)
    
    out = ones(2,3) # Array{Float64}(undef, (2,3))
    @inbounds myrand!(Random.default_rng(), dist, out)

    return q.scale[1] * out
end

# Objective for Enzyme in this reduction: build `q` with `law`, produce the
# sample matrix with `man`, and return its mean.
function estimate_repgradelbo_ad_forward(params, dist)
    q = law(dist, params)
    samples = man(q)
    mean(samples)
end

# Maintainer's reduction of the reproducer. Builds the parameter vector from a
# 2-dimensional `Diagonal` scale, evaluates the objective once without AD, then
# runs Enzyme reverse mode on it. Returns the shadow vector `∇x`.
function mwe(T)
	d = 2
    dist = MyNormal(1.0, 0.0)
    q = MvLocationScale(zeros(T, d), Diagonal(ones(T, d)), dist)
    params = destructure(q)
    
    ∇x = zero(params)
    fill!(∇x, zero(eltype(∇x)))  # redundant after `zero(params)`; kept as posted
    
    # plain evaluation before differentiating; the result is discarded
    estimate_repgradelbo_ad_forward(params, dist)

    # unlike the earlier snippets, `set_runtime_activity` is called without the
    # explicit `true` argument here
    _, y = Enzyme.autodiff(
        set_runtime_activity(Enzyme.ReverseWithPrimal),
        estimate_repgradelbo_ad_forward,
        Enzyme.Active,
        Enzyme.Duplicated(params, ∇x),
        Enzyme.Const(dist)
    )
    ∇x
end

# Run the maintainer's reduction with Float64.
mwe(Float64)

@wsmoses
Copy link
Member

wsmoses commented Nov 3, 2024

fixed by #2047

@wsmoses wsmoses closed this as completed Nov 3, 2024
@Red-Portal
Copy link
Author

@wsmoses Any plans for the next release?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

No branches or pull requests

2 participants