Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DenseMetric and Component arrays (solve #344) #345

Open
wants to merge 4 commits into
base: master
Choose a base branch
from

Conversation

erathorn
Copy link

@erathorn erathorn commented Aug 3, 2023

This PR attempts to solve #344

I went for the solution to preallocate the result in ∂H∂r such that the type of the input r matches the type of the output.
I added tests that not only check the correct numerical output but also check the type.

Additionally, I found a typo in the inner constructor of PhasePoint, where it originally was length(ℓπ.gradient) == length(ℓπ.gradient).

Copy link
Member

@torfjelde torfjelde left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @erathorn ! And nice catch on the constructor for Phasepoint!

I'm wondering if we should just make compat with ComponentArrays.jl an extension instead of complicating the existing code, and then we can just overload whatever we need there. Thoughts?

Comment on lines +45 to +49
function ∂H∂r(h::Hamiltonian{<:DenseEuclideanMetric,<:GaussianKinetic}, r::AbstractVecOrMat)
out = similar(r) # Make sure the output of this function is of the same type as r
mul!(out, h.metric.M⁻¹, r)
out
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit uncertain about this change as it "complicates" code to mainly just stay compatible with ComponentArrays.jl, and thus I'd be more in favour of just making it an extension instead, I think 😕 Then in the extension, we just overload whatever we need to be compatible.

Also, will this code break if, say, h.metric.M⁻¹ has eltype Float64 but r has eltype Float32, rather than just promoting, as is current behavior?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

p1 = ComponentArray(m=one(Float32), s = one(Float32))
r = similar(p1)
M = diagm(randn(Float64, 2))
mul!(r, M, p1)

This works on my machine, and returns r as a component array of eltype Float32 as expected.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AHMC supports vectorised sampling, when passing arguments in a suitable type. In this case, r::AbstractVecOrMat could be a single momentum realization or a vector of momentum realizations. Therefore, the new code needs to be able to handle the vectorized sampling mode for the tests to pass.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the silence. Thank you for the suggestion, it totally makes sense to me. However, I looked into this a bit more and am honestly slightly lost. The call to the rand function, which fails in the tests only works in the test case. Calling this function in a plain Julia session fails for me (on the main branch). A brute force solution, which dispatches on r::AbtractVecOrMat{AbstractVecOrMat}, does unfortunately not do the trick either.

test/hamiltonian.jl Outdated Show resolved Hide resolved
src/hamiltonian.jl Show resolved Hide resolved
@erathorn
Copy link
Author

erathorn commented Aug 3, 2023

I'm wondering if we should just make compat with ComponentArrays.jl an extension instead of complicating the existing code, and then we can just overload whatever we need there. Thoughts?

I went for this "complication" because of the comment next to safe_rsimilar and the phasepoint function taking different types. Which I understood as "workarounds" without explicit dependence.

Co-authored-by: Tor Erlend Fjelde <[email protected]>
@erathorn
Copy link
Author

erathorn commented Aug 3, 2023

The tests fail at this line:
https://github.com/TuringLang/AdvancedHMC.jl/blob/eb9b2e0d60ef3dd85768d6e6a9f19de15b8f7130/test/metric.jl#L13C1-L13C34

The problem seems to be, that the implementations of rand do not have the correct signature.

https://github.com/TuringLang/AdvancedHMC.jl/blob/eb9b2e0d60ef3dd85768d6e6a9f19de15b8f7130/src/metric.jl#L128C1-L136C38

Base.rand(rng::AbstractVector{<:AbstractRNG}) = rand.(rng)

https://github.com/TuringLang/AdvancedHMC.jl/blob/eb9b2e0d60ef3dd85768d6e6a9f19de15b8f7130/src/utilities.jl#L9C1-L12C4

However, I have not touched any of this at all. 🤔

@torfjelde
Copy link
Member

The tests fail at this line:

This is indeed strange given that the CI on master is working just fine 😕

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants