Skip to content

Commit

Permalink
add a value to store F*q
Browse files Browse the repository at this point in the history
  • Loading branch information
ziyiyin97 committed Jan 6, 2023
1 parent 5754554 commit 889c37e
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 11 deletions.
19 changes: 12 additions & 7 deletions src/rrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,19 @@ Parameters
* `F`: the JUDI propgator
* `q`: The source to compute F*q
"""
struct LazyPropagation
mutable struct LazyPropagation
post::Function
F::judiPropagator
q
val # store F * q
end

eval_prop(F::LazyPropagation) = F.post(F.F * F.q)
function eval_prop(F::LazyPropagation)
isnothing(F.val) && (F.val = F.F * F.q)
return F.post(F.val)
end
Base.collect(F::LazyPropagation) = eval_prop(F)
LazyPropagation(post, F::judiPropagator, q) = LazyPropagation(post, F, q, nothing)
LazyPropagation(F::judiPropagator, q) = LazyPropagation(identity, F, q)

# Only a few arithmetic operation are supported
Expand All @@ -47,10 +52,10 @@ end

for op in [:*, :/]
@eval begin
$(op)(F::LazyPropagation, y::T) where T <: Number = LazyPropagation(F.post, F.F, $(op)(F.q, y))
$(op)(y::T, F::LazyPropagation) where T <: Number = LazyPropagation(F.post, F.F, $(op)(y, F.q))
broadcasted(::typeof($op), F::LazyPropagation, y::T) where T <: Number = LazyPropagation(F.post, F.F, broadcasted($(op), F.q, y))
broadcasted(::typeof($op), y::T, F::LazyPropagation) where T <: Number = LazyPropagation(F.post, F.F, broadcasted($(op), y, F.q))
$(op)(F::LazyPropagation, y::T) where T <: Number = LazyPropagation(F.post, F.F, $(op)(F.q, y), isnothing(F.val) ? nothing : $(op)(F.val, y))
$(op)(y::T, F::LazyPropagation) where T <: Number = LazyPropagation(F.post, F.F, $(op)(y, F.q), isnothing(F.val) ? nothing : $(op)(y, F.val))
broadcasted(::typeof($op), F::LazyPropagation, y::T) where T <: Number = LazyPropagation(F.post, F.F, broadcasted($(op), F.q, y), isnothing(F.val) ? nothing : broadcasted($(op), F.val, y))
broadcasted(::typeof($op), y::T, F::LazyPropagation) where T <: Number = LazyPropagation(F.post, F.F, broadcasted($(op), y, F.q), isnothing(F.val) ? nothing : broadcasted($(op), y, F.val))
end
end

Expand All @@ -66,7 +71,7 @@ end
broadcasted(::typeof(^), y::LazyPropagation, p::Real) = eval_prop(y).^(p)
*(F::judiPropagator, q::LazyPropagation) = F*eval_prop(q)

reshape(F::LazyPropagation, dims...) = LazyPropagation(x->reshape(x, dims...), F.F, Q.q)
reshape(F::LazyPropagation, dims...) = LazyPropagation(x->reshape(x, dims...), F.F, F.q, F.val)
copyto!(x::AbstractArray, F::LazyPropagation) = copyto!(x, eval_prop(F))
dot(x::AbstractArray, F::LazyPropagation) = dot(x, eval_prop(F))
dot(F::LazyPropagation, x::AbstractArray) = dot(x, F)
Expand Down
6 changes: 2 additions & 4 deletions test/test_rrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,8 @@ perturb(x::judiVector) = judiVector(x.geometry, [randx(x.data[i]) for i=1:x.nsrc
reverse(x::judiVector) = judiVector(x.geometry, [x.data[i][end:-1:1, :] for i=1:x.nsrc])

misfit_objective_2p(d_obs, q0, m0, F) = .5f0*norm(F(m0, q0) - d_obs)^2
misfit_objective_1p(d_obs, q0, m0, F) = .5f0*norm(F(m0)*q0 - d_obs)^2
misfit_objective_1p_(d_obs, q0, m0, F) = .5f0*norm(F(1f0*m0)*q0 - d_obs)^2

misfit_objective_1p(d_obs, q0, m0, F) = .5f0*norm(F(1f0*m0)*q0 - d_obs)^2

function loss(misfit, d_obs, q0, m0, F)
local ϕ
# Reshape as ML size if returns array
Expand Down Expand Up @@ -84,7 +83,6 @@ ftol = sqrt(eps(1f0))
gs_inv = gradient(x -> misfit_objective_2p(d_obs, q0, x, F), m0)
if ~ra
gs_inv1 = gradient(x -> misfit_objective_1p(d_obs, q0, x, F), model0.m)
gs_inv1_ = gradient(x -> misfit_objective_1p_(d_obs, q0, x, F), model0.m)
@test gs_inv[1][:] gs_inv1[1][:] rtol=ftol
end
# Gradient with m PhysicalParameter
Expand Down

0 comments on commit 889c37e

Please sign in to comment.