-
Notifications
You must be signed in to change notification settings - Fork 63
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Segfault in AdvancedVI when involving `Diagonal` specializations (#2042)
Comments
Related: TuringLang/AdvancedVI.jl#129
Hm, are you able to reduce this?
Let me give it a shot.
@wsmoses Would this be more useful? Here is a custom reproducer:
using Enzyme
using Distributions
using LinearAlgebra
using StableRNGs
using Random
using Optimisers
using Functors
# Minimal location-scale distribution reduced from AdvancedVI for this reproducer.
# `location` and `scale` parameterize it, and `dist` is the base distribution that
# `rand` draws from. `scale_eps` is only carried through constructors; no
# computation in this snippet reads it.
struct MvLocationScale{
S, D <: ContinuousDistribution, L, E <: Real
} <: ContinuousMultivariateDistribution
location ::L
scale ::S
dist ::D
scale_eps::E
end
# Register only `location` and `scale` as Functors.jl children. `dist` and
# `scale_eps` are left out of functor traversal.
Functors.@functor MvLocationScale (location, scale)
# Callable that rebuilds a mean-field (Diagonal-scale) `MvLocationScale` from a
# flat parameter vector. `model` is the template whose `dist` and `scale_eps`
# are reused.
struct RestructureMeanField{S<:Diagonal,D,L,E}
model::MvLocationScale{S,D,L,E}
end
# Inverse of the mean-field `destructure` below: split `flat = [location; diag(scale)]`.
# The first half becomes `location` and the second half becomes the diagonal of
# `scale`; `dist` and `scale_eps` are copied from the template model.
# NOTE(review): assumes `length(flat)` is even. With an odd length, `div` makes
# the middle entry silently unused.
function (re::RestructureMeanField)(flat::AbstractVector)
n_dims = div(length(flat), 2)
location = first(flat, n_dims)
scale = Diagonal(last(flat, n_dims))
return MvLocationScale(location, scale, re.model.dist, re.model.scale_eps)
end
# Mean-field specialization of `Optimisers.destructure`. It flattens `q` into
# `[location; diag(scale)]` and returns that vector together with a
# `RestructureMeanField` that rebuilds `q` from it.
# (`dist` is unpacked here but not used.)
function Optimisers.destructure(q::MvLocationScale{<:Diagonal,D,L,E}) where {D,L,E}
(; location, scale, dist) = q
flat = vcat(location, diag(scale))
return flat, RestructureMeanField(q)
end
# Extend `Distributions.rand` for Diagonal-scale `MvLocationScale`.
# Returns an `n_dims × num_samples` matrix in three steps:
#   1. draw base samples from `dist`;
#   2. scale row i by the i-th diagonal entry of `scale`;
#   3. shift each row by the matching `location` entry.
function Distributions.rand(
rng::AbstractRNG, q::MvLocationScale{<:Diagonal,D,L}, num_samples::Int
) where {L,D}
(; location, scale, dist) = q
n_dims = length(location)
scale_diag = diag(scale)
return scale_diag .* rand(rng, dist, n_dims, num_samples) .+ location
end
# Scalar objective differentiated by Enzyme. It rebuilds `q` from the flat
# parameters, draws 3 samples, and returns the mean over all sample entries.
# The `::typeof(restructure.model)` assertion requires the rebuilt model to have
# exactly the template's concrete type.
function estimate_repgradelbo_ad_forward(params′, aux)
(; restructure,) = aux
q = restructure(params′)::typeof(restructure.model)
samples = rand(Random.default_rng(), q, 3)
mean(samples)
end
# Reproducer entry point. It builds a 2-dimensional mean-field model with element
# type `T` (zero location, identity Diagonal scale, standard Normal base),
# flattens it, and reverse-differentiates the objective with Enzyme with runtime
# activity enabled. Returns the gradient buffer `∇x`. This is the call reported
# to segfault.
function mwe(T)
d = 2
q = MvLocationScale(
zeros(T, d), Diagonal(ones(T, d)), Normal{T}(zero(T), one(T)), T(1e-5)
)
params, restructure, = Optimisers.destructure(q)
aux = (
restructure=restructure,
)
# `zero(params)` already yields zeros; the `fill!` is redundant but harmless.
∇x = zero(params)
fill!(∇x, zero(eltype(∇x)))
# Active scalar return. The gradient w.r.t. `params` is accumulated into
# `∇x` (Duplicated), and `aux` is held constant. `y` receives the primal
# value and is not used afterwards.
_, y = Enzyme.autodiff(
set_runtime_activity(Enzyme.ReverseWithPrimal, true),
estimate_repgradelbo_ad_forward,
Enzyme.Active,
Enzyme.Duplicated(params, ∇x),
Enzyme.Const(aux)
)
∇x
end
# Run the reproducer; the result is the gradient buffer returned by `mwe`.
mwe(Float64)
That is, but anything further you can do will help.
Stripped it down a little more:
using Enzyme
using Distributions
using LinearAlgebra
using Random
# Stripped-down location-scale distribution: `location` and `scale`
# parameterize it, and `dist` is the base distribution `rand` draws from.
struct MvLocationScale{
S, D <: ContinuousDistribution, L
} <: ContinuousMultivariateDistribution
location ::L
scale ::S
dist ::D
end
# Callable that rebuilds a Diagonal-scale `MvLocationScale` from a flat vector,
# reusing `dist` from the template `model`.
struct RestructureMeanField{S<:Diagonal,D,L}
model::MvLocationScale{S,D,L}
end
# Split `flat = [location; diag(scale)]` back into a model. The first half
# becomes `location` and the second half becomes the Diagonal `scale`; `dist`
# comes from the template.
# NOTE(review): assumes an even-length `flat`; an odd middle entry is ignored.
function (re::RestructureMeanField)(flat::AbstractVector)
n_dims = div(length(flat), 2)
location = first(flat, n_dims)
scale = Diagonal(last(flat, n_dims))
return MvLocationScale(location, scale, re.model.dist)
end
# Local (non-Optimisers) flattening. Returns `[location; diag(scale)]` and the
# `RestructureMeanField` that inverts it. (`dist` is unpacked but unused.)
function destructure(q::MvLocationScale{<:Diagonal,D,L}) where {D,L}
(; location, scale, dist) = q
flat = vcat(location, diag(scale))
return flat, RestructureMeanField(q)
end
# Extend `Distributions.rand` for Diagonal-scale `MvLocationScale`.
# Returns an `n_dims × num_samples` matrix: base draws from `dist`, with row i
# scaled by the i-th diagonal entry of `scale` and shifted by `location[i]`.
function Distributions.rand(
rng::AbstractRNG, q::MvLocationScale{<:Diagonal,D,L}, num_samples::Int
) where {L,D}
(; location, scale, dist) = q
n_dims = length(location)
scale_diag = diag(scale)
return scale_diag .* rand(rng, dist, n_dims, num_samples) .+ location
end
# Scalar objective for Enzyme. It rebuilds `q` (type-asserted to the template's
# concrete type), draws 3 samples, and returns their overall mean.
function estimate_repgradelbo_ad_forward(params′, aux)
(; restructure,) = aux
q = restructure(params′)::typeof(restructure.model)
samples = rand(Random.default_rng(), q, 3)
mean(samples)
end
# Reproducer entry point. It builds a 2-d Diagonal-scale model with a
# `Normal{T}` base, flattens it, and runs Enzyme reverse mode (with primal and
# runtime activity) on the objective. Returns the gradient buffer `∇x`.
function mwe(T)
d = 2
q = MvLocationScale(zeros(T, d), Diagonal(ones(T, d)), Normal{T}(zero(T), one(T)))
params, restructure, = destructure(q)
aux = (restructure=restructure,)
# `zero(params)` is already zeroed; the `fill!` is redundant but harmless.
∇x = zero(params)
fill!(∇x, zero(eltype(∇x)))
# The gradient w.r.t. `params` is accumulated into `∇x`; `aux` is constant.
# `y` receives the primal value and is unused.
_, y = Enzyme.autodiff(
set_runtime_activity(Enzyme.ReverseWithPrimal, true),
estimate_repgradelbo_ad_forward,
Enzyme.Active,
Enzyme.Duplicated(params, ∇x),
Enzyme.Const(aux)
)
∇x
end
# Run the reproducer; the result is the gradient buffer returned by `mwe`.
mwe(Float64)
Any chance you can get rid of the Distributions dependency? (`ccall`s are fine.)
The problem seemed to go away when I removed the dependency. Edit: actually, let me try something.
Okay, the following works, so it's not obvious to me how to remove `Distributions`:
using Enzyme
using LinearAlgebra
using Random
using Statistics
# Singleton standing in for a Distributions.jl base distribution. The
# `rand(rng, ::FakeNormal, n_dims, num_samples)` method in this snippet draws
# standard-normal samples for it via `randn`.
struct FakeNormal end
# Location-scale container without any supertype or field-type constraints.
struct MvLocationScale{S, D, L}
location ::L
scale ::S
dist ::D
end
# Callable that rebuilds a Diagonal-scale `MvLocationScale` from a flat vector,
# reusing `dist` from the template `model`.
struct RestructureMeanField{S<:Diagonal,D,L}
model::MvLocationScale{S,D,L}
end
# Split `flat = [location; diag(scale)]` back into a model. The first half
# becomes `location` and the second half becomes the Diagonal `scale`; `dist`
# comes from the template.
# NOTE(review): assumes an even-length `flat`; an odd middle entry is ignored.
function (re::RestructureMeanField)(flat::AbstractVector)
n_dims = div(length(flat), 2)
location = first(flat, n_dims)
scale = Diagonal(last(flat, n_dims))
return MvLocationScale(location, scale, re.model.dist)
end
# Flatten a Diagonal-scale model into `[location; diag(scale)]` and return it
# with the `RestructureMeanField` that inverts it. (`dist` is unpacked but unused.)
function destructure(q::MvLocationScale{<:Diagonal,D,L}) where {D,L}
(; location, scale, dist) = q
flat = vcat(location, diag(scale))
return flat, RestructureMeanField(q)
end
# Base draws for `FakeNormal`: an `n_dims × num_samples` matrix of standard-normal
# values from `rng`.
# NOTE(review): the unqualified `function rand` likely creates a new `rand` in
# this module instead of extending the library `rand`, unless `rand` was already
# resolved here before this definition. Confirm this is intended; the snippets
# using Distributions extend `Distributions.rand` explicitly.
function rand(
rng::AbstractRNG, ::FakeNormal, n_dims::Int, num_samples::Int
)
return randn(rng, n_dims, num_samples)
end
# Draw `num_samples` samples from a Diagonal-scale `MvLocationScale`.
# Returns an `n_dims × num_samples` matrix: base draws from `dist`, with row i
# scaled by the i-th diagonal entry of `scale` and shifted by `location[i]`.
function rand(
rng::AbstractRNG, q::MvLocationScale{<:Diagonal,D,L}, num_samples::Int
) where {L,D}
(; location, scale, dist) = q
n_dims = length(location)
scale_diag = diag(scale)
return scale_diag .* rand(rng, dist, n_dims, num_samples) .+ location
end
# Scalar objective for Enzyme. It rebuilds `q` (type-asserted to the template's
# concrete type), draws 3 samples, and returns their overall mean.
function estimate_repgradelbo_ad_forward(params′, aux)
(; restructure,) = aux
q = restructure(params′)::typeof(restructure.model)
samples = rand(Random.default_rng(), q, 3)
mean(samples)
end
# Distributions-free variant of the reproducer, using `FakeNormal` as the base.
# It builds a 2-d Diagonal-scale model, flattens it, and runs Enzyme reverse
# mode (with primal and runtime activity). Returns the gradient buffer `∇x`.
# Reported in the thread as running without the crash.
function mwe(T)
d = 2
q = MvLocationScale(zeros(T, d), Diagonal(ones(T, d)), FakeNormal())
params, restructure, = destructure(q)
aux = (restructure=restructure,)
# `zero(params)` is already zeroed; the `fill!` is redundant but harmless.
∇x = zero(params)
fill!(∇x, zero(eltype(∇x)))
# The gradient w.r.t. `params` is accumulated into `∇x`; `aux` is constant.
# `y` receives the primal value and is unused.
_, y = Enzyme.autodiff(
set_runtime_activity(Enzyme.ReverseWithPrimal, true),
estimate_repgradelbo_ad_forward,
Enzyme.Active,
Enzyme.Duplicated(params, ∇x),
Enzyme.Const(aux)
)
∇x
end
# Run the reproducer; the result is the gradient buffer returned by `mwe`.
mwe(Float64)
using Enzyme
using LinearAlgebra
using Random
using Statistics
# Debugging switches used while reducing the crash. Judging by their names, they
# make Enzyme print the IR it processes and dump post-optimization IR.
# NOTE(review): confirm against the Enzyme.jl docs; these do not affect the math.
Enzyme.API.printall!(true)
Enzyme.Compiler.DumpPostOpt[] = true
# Toy base distribution. Values are mapped through `muladd(sigma, z, off)`
# (see `myxval`).
struct MyNormal
sigma::Float64
off::Float64
end
# Location-scale container without supertype or field-type constraints.
# In this snippet `law` fills `location` with `nothing`.
struct MvLocationScale{
S, D, L
}
location ::L
scale ::S
dist ::D
end
# Build a model whose Diagonal `scale` comes from the first half of `flat`.
# `location` is `nothing`, and the second half of `flat` is ignored.
# `@noinline` keeps this as a separate call in the differentiated code.
@noinline function law(dist, flat::AbstractVector)
# Debug-print `flat` through the runtime's internal `jl_` C entry point.
ccall(:jl_, Cvoid, (Any,), flat)
n_dims = div(length(flat), 2)
data = first(flat, n_dims)
scale = Diagonal(data)
return MvLocationScale(nothing, scale, dist)
end
# Flatten only the scale: return the diagonal of `q.scale` (location is dropped).
function destructure(q::MvLocationScale)
return diag(q.scale)
end
# Affine transform of a base value: `d.sigma * z + d.off`, computed with `muladd`.
myxval(d::MyNormal, z::Real) = muladd(d.sigma, z, d.off)
# In-place "sampling": overwrite each element of `A` with `myxval(d, A[i])`,
# then return `A`. `rng` is unused while the `randn!` fill below is disabled.
function myrand!(rng::AbstractRNG, d::MyNormal, A::AbstractArray{<:Real})
# randn!(rng, A)  -- disabled in this reduction, so A's existing values are transformed
map!(Base.Fix1(myxval, d), A, A)
return A
end
# Produce a 2×3 "sample" matrix scaled by `q.scale[1]`.
# It ignores `q.dist` and builds its own `MyNormal(1.0, 0.0)`. Under that
# distribution `myxval` leaves values unchanged, so `out` stays all ones and
# the result is `q.scale[1] .* ones(2, 3)`.
function man(q::MvLocationScale)
dist = MyNormal(1.0, 0.0)
out = ones(2,3) # was Array{Float64}(undef, (2,3)); ones gives defined values since randn! is disabled
@inbounds myrand!(Random.default_rng(), dist, out)
return q.scale[1] * out
end
# Scalar objective for Enzyme. It builds the model from `params`, forms the
# sample matrix, and returns its mean (equal to `params[1]` given `law` and `man`).
function estimate_repgradelbo_ad_forward(params, dist)
q = law(dist, params)
samples = man(q)
mean(samples)
end
# Fully reduced reproducer with no Distributions dependency. It flattens a
# 2-d identity-Diagonal model to `params = [1, 1]`, runs the objective once
# normally, then differentiates it with Enzyme reverse mode. Returns `∇x`.
# Worked by hand from the code: objective = params[1], so the expected
# gradient is [1.0, 0.0].
function mwe(T)
d = 2
dist = MyNormal(1.0, 0.0)
q = MvLocationScale(zeros(T, d), Diagonal(ones(T, d)), dist)
params = destructure(q)
# `zero(params)` is already zeroed; the `fill!` is redundant but harmless.
∇x = zero(params)
fill!(∇x, zero(eltype(∇x)))
# Plain (non-AD) call first; result discarded.
estimate_repgradelbo_ad_forward(params, dist)
# One-argument `set_runtime_activity` form; presumably equivalent to passing
# `true` as in the other snippets (confirm). `y` receives the primal value
# and is unused.
_, y = Enzyme.autodiff(
set_runtime_activity(Enzyme.ReverseWithPrimal),
estimate_repgradelbo_ad_forward,
Enzyme.Active,
Enzyme.Duplicated(params, ∇x),
Enzyme.Const(dist)
)
∇x
end
# Run the reproducer; the result is the gradient buffer returned by `mwe`.
mwe(Float64)
Fixed by #2047.
@wsmoses Any plans for the next release?
Hitting segfaults in the following example. This seems unique to the case where `D <: Diagonal`. This is on 1.10, BTW:
The text was updated successfully, but these errors were encountered: