Skip to content

Commit

Permalink
fix reshape judivector in ad
Browse files Browse the repository at this point in the history
  • Loading branch information
ziyiyin97 committed Jan 6, 2023
1 parent 2252837 commit f045df7
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
10 changes: 9 additions & 1 deletion src/TimeModeling/Types/abstract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,15 @@ vec(x::judiMultiSourceVector) = vcat(vec.(x.data)...)

time_sampling(ms::judiMultiSourceVector) = [1 for i=1:ms.nsrc]

reshape(ms::judiMultiSourceVector, dims::Dims{N}) where N = reshape(vec(ms), dims)
function reshape(ms::judiMultiSourceVector, dims::Dims{N}) where N
try
return reshape(vec(ms), dims)
catch e
@assert dims[1] == ms.nsrc ### during AD, size(ms::judiVector) = ms.nsrc
return ms
end
end

############################################################################################################################
# Linear algebra `*`
(msv::judiMultiSourceVector{mT})(x::AbstractVector{T}) where {mT, T<:Number} = x
Expand Down
1 change: 1 addition & 0 deletions src/rrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ broadcasted(::typeof(^), y::LazyPropagation, p::Real) = eval_prop(y).^(p)
*(F::judiPropagator, q::LazyPropagation) = F*eval_prop(q)

reshape(F::LazyPropagation, dims...) = LazyPropagation(x->reshape(x, dims...), F.F, F.q, F.val)
vec(F::LazyPropagation) = LazyPropagation(vec, F.F, F.q, F.val)
copyto!(x::AbstractArray, F::LazyPropagation) = copyto!(x, eval_prop(F))
dot(x::AbstractArray, F::LazyPropagation) = dot(x, eval_prop(F))
dot(F::LazyPropagation, x::AbstractArray) = dot(x, F)
Expand Down

0 comments on commit f045df7

Please sign in to comment.