Skip to content

Commit

Permalink
Merge pull request #448 from probcomp/mrb/unfold_gradient_default_args
Browse files Browse the repository at this point in the history
Fix Unfold propagating gradients to default parameters.
  • Loading branch information
alex-lew authored Mar 4, 2022
2 parents 3cc93fd + 56115e3 commit dce003c
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 5 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
Manifest.toml
*.pdf
*.png
*.jld
Expand All @@ -7,4 +8,4 @@
*.log
docs/build/
docs/site/
.DS_Store
.DS_Store
7 changes: 5 additions & 2 deletions src/modeling_library/unfold/backprop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
# Combine two gradient accumulators, treating `nothing` as the additive
# identity (i.e. "no gradient contribution yet").
@inline fold_sum(::Nothing, b::B) where {B} = b
@inline fold_sum(x::T, y::T) where {T} = x + y

# Collapse a per-step gradient accumulator: `nothing` (no gradient recorded)
# passes through unchanged; a vector of per-step gradients is summed.
@inline _sum(::Nothing) = nothing
@inline function _sum(grads::Vector)
    return sum(grads)
end

function choice_gradients(trace::VectorTrace{UnfoldType,T,U}, selection::Selection, retval_grad) where {T,U}
kernel_has_grads = has_argument_grads(trace.gen_fn.kernel)
if kernel_has_grads[1]
Expand Down Expand Up @@ -44,7 +47,7 @@ function choice_gradients(trace::VectorTrace{UnfoldType,T,U}, selection::Selecti
end
end
end
((nothing, kernel_arg_grads[2], params_grad...), value_choices, gradient_choices)
((nothing, kernel_arg_grads[2], map(_sum, params_grad)...), value_choices, gradient_choices)
end

function accumulate_param_gradients!(trace::VectorTrace{UnfoldType,T,U}, retval_grad) where {T,U}
Expand Down Expand Up @@ -82,5 +85,5 @@ function accumulate_param_gradients!(trace::VectorTrace{UnfoldType,T,U}, retval_
end
end
end
(nothing, kernel_arg_grads[2], params_grad...)
(nothing, kernel_arg_grads[2], map(_sum, params_grad)...)
end
7 changes: 5 additions & 2 deletions test/modeling_library/unfold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -506,10 +506,13 @@ foo = Unfold(kernel)

# Check that accumulate_param_gradients! propagates gradients both to the
# Unfold kernel's arguments (alpha, beta) and to its trainable parameter (std).
zero_param_grad!(kernel, :std)
input_grads = accumulate_param_gradients!(trace, nothing)
# Hand-computed gradients of the two normal log-densities w.r.t. alpha and beta.
expected_xs_grad = logpdf_grad(normal, x1, x_init * alpha + beta, std)[2] * x_init + logpdf_grad(normal, x2, x1 * alpha + beta, std)[2] * x1
expected_ys_grad = logpdf_grad(normal, x1, x_init * alpha + beta, std)[2] + logpdf_grad(normal, x2, x1 * alpha + beta, std)[2]

@test input_grads[1] === nothing # length has no gradient
@test input_grads[2] === nothing # initial state has no gradient
@test isapprox(input_grads[3], expected_xs_grad) # alpha
@test isapprox(input_grads[4], expected_ys_grad) # beta
expected_std_grad = (logpdf_grad(normal, x1, x_init * alpha + beta, std)[3]
    + logpdf_grad(normal, x2, x1 * alpha + beta, std)[3])
@test isapprox(get_param_grad(kernel, :std), expected_std_grad)
Expand Down

0 comments on commit dce003c

Please sign in to comment.