-
Notifications
You must be signed in to change notification settings - Fork 41
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
DenseMetric and Component arrays (solve #344) #345
base: master
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @erathorn ! And nice catch on the constructor for Phasepoint
!
I'm wondering if we should just make compat with ComponentArrays.jl an extension instead of complicating the existing code, and then we can just overload whatever we need there. Thoughts?
function ∂H∂r(h::Hamiltonian{<:DenseEuclideanMetric,<:GaussianKinetic}, r::AbstractVecOrMat) | ||
out = similar(r) # Make sure the output of this function is of the same type as r | ||
mul!(out, h.metric.M⁻¹, r) | ||
out | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm a bit uncertain about this change as it "complicates" code to mainly just stay compatible with ComponentArrays.jl, and thus I'd be more in favour of just making it an extension instead, I think 😕 Then in the extension, we just overload whatever we need to be compatible.
Also, will this code break if, say, h.metric.M⁻¹
has eltype Float64
but r
has eltype Float32
, rather than just promoting, as is current behavior?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
p1 = ComponentArray(m=one(Float32), s = one(Float32))
r = similar(p1)
M = diagm(randn(Float64, 2))
mul!(r, M, p1)
This works on my machine, and returns r
as a component array of eltype Float32
as expected.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AHMC supports vectorised sampling, when passing arguments in a suitable type. In this case, r::AbstractVecOrMat
could be a single momentum realization or a vector of momentum realizations. Therefore, the new code needs to be able to handle the vectorized sampling mode for the tests to pass.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the silence. Thank you for the suggestion, it totally makes sense to me. However, I looked into this a bit more and am honestly slightly lost. The call to the rand
function, which fails in the tests only works in the test case. Calling this function in a plain Julia session fails for me (on the main branch). A brute force solution, which dispatches on r::AbtractVecOrMat{AbstractVecOrMat}
, does unfortunately not do the trick either.
I went for this "complication" because of the comment next to |
Co-authored-by: Tor Erlend Fjelde <[email protected]>
The tests fail at this line: The problem seems to be, that the implementations of AdvancedHMC.jl/src/utilities.jl Line 5 in eb9b2e0
However, I have not touched any of this at all. 🤔 |
This is indeed strange given that the CI on master is working just fine 😕 |
This PR attempts to solve #344
I went for the solution to preallocate the result in
∂H∂r
such that the type of the inputr
matches the type of the output.I added tests that not only check the correct numerical output but also check the type.
Additionally, I found a typo in the inner constructor of PhasePoint, where it originally was
length(ℓπ.gradient) == length(ℓπ.gradient)
.