diff --git a/src/rrules.jl b/src/rrules.jl
index a94c838ef..96729e79f 100644
--- a/src/rrules.jl
+++ b/src/rrules.jl
@@ -18,14 +18,19 @@ Parameters
 * `F`: the JUDI propgator
 * `q`: The source to compute F*q
 """
-struct LazyPropagation
+mutable struct LazyPropagation
     post::Function
     F::judiPropagator
     q
+    val # store F * q
 end
 
-eval_prop(F::LazyPropagation) = F.post(F.F * F.q)
+function eval_prop(F::LazyPropagation)
+    isnothing(F.val) && (F.val = F.F * F.q)
+    return F.post(F.val)
+end
 Base.collect(F::LazyPropagation) = eval_prop(F)
+LazyPropagation(post, F::judiPropagator, q) = LazyPropagation(post, F, q, nothing)
 LazyPropagation(F::judiPropagator, q) = LazyPropagation(identity, F, q)
 
 # Only a few arithmetic operation are supported
@@ -47,10 +52,10 @@ end
 
 for op in [:*, :/]
     @eval begin
-        $(op)(F::LazyPropagation, y::T) where T <: Number = LazyPropagation(F.post, F.F, $(op)(F.q, y))
-        $(op)(y::T, F::LazyPropagation) where T <: Number = LazyPropagation(F.post, F.F, $(op)(y, F.q))
-        broadcasted(::typeof($op), F::LazyPropagation, y::T) where T <: Number = LazyPropagation(F.post, F.F, broadcasted($(op), F.q, y))
-        broadcasted(::typeof($op), y::T, F::LazyPropagation) where T <: Number = LazyPropagation(F.post, F.F, broadcasted($(op), y, F.q))
+        $(op)(F::LazyPropagation, y::T) where T <: Number = LazyPropagation(F.post, F.F, $(op)(F.q, y), isnothing(F.val) ? nothing : $(op)(F.val, y))
+        $(op)(y::T, F::LazyPropagation) where T <: Number = LazyPropagation(F.post, F.F, $(op)(y, F.q), isnothing(F.val) ? nothing : $(op)(y, F.val))
+        broadcasted(::typeof($op), F::LazyPropagation, y::T) where T <: Number = LazyPropagation(F.post, F.F, broadcasted($(op), F.q, y), isnothing(F.val) ? nothing : broadcasted($(op), F.val, y))
+        broadcasted(::typeof($op), y::T, F::LazyPropagation) where T <: Number = LazyPropagation(F.post, F.F, broadcasted($(op), y, F.q), isnothing(F.val) ? nothing : broadcasted($(op), y, F.val))
     end
 end
 
@@ -66,7 +71,7 @@ end
 broadcasted(::typeof(^), y::LazyPropagation, p::Real) = eval_prop(y).^(p)
 
 *(F::judiPropagator, q::LazyPropagation) = F*eval_prop(q)
-reshape(F::LazyPropagation, dims...) = LazyPropagation(x->reshape(x, dims...), F.F, Q.q)
+reshape(F::LazyPropagation, dims...) = LazyPropagation(x->reshape(x, dims...), F.F, F.q, F.val)
 copyto!(x::AbstractArray, F::LazyPropagation) = copyto!(x, eval_prop(F))
 dot(x::AbstractArray, F::LazyPropagation) = dot(x, eval_prop(F))
 dot(F::LazyPropagation, x::AbstractArray) = dot(x, F)
diff --git a/test/test_rrules.jl b/test/test_rrules.jl
index 5b86ee664..57f879923 100644
--- a/test/test_rrules.jl
+++ b/test/test_rrules.jl
@@ -32,9 +32,8 @@ perturb(x::judiVector) = judiVector(x.geometry, [randx(x.data[i]) for i=1:x.nsrc
 reverse(x::judiVector) = judiVector(x.geometry, [x.data[i][end:-1:1, :] for i=1:x.nsrc])
 
 misfit_objective_2p(d_obs, q0, m0, F) = .5f0*norm(F(m0, q0) - d_obs)^2
-misfit_objective_1p(d_obs, q0, m0, F) = .5f0*norm(F(m0)*q0 - d_obs)^2
-misfit_objective_1p_(d_obs, q0, m0, F) = .5f0*norm(F(1f0*m0)*q0 - d_obs)^2
-
+misfit_objective_1p(d_obs, q0, m0, F) = .5f0*norm(F(1f0*m0)*q0 - d_obs)^2
+
 function loss(misfit, d_obs, q0, m0, F)
     local ϕ
     # Reshape as ML size if returns array
@@ -84,7 +83,6 @@ ftol = sqrt(eps(1f0))
 gs_inv = gradient(x -> misfit_objective_2p(d_obs, q0, x, F), m0)
 if ~ra
     gs_inv1 = gradient(x -> misfit_objective_1p(d_obs, q0, x, F), model0.m)
-    gs_inv1_ = gradient(x -> misfit_objective_1p_(d_obs, q0, x, F), model0.m)
     @test gs_inv[1][:] ≈ gs_inv1[1][:] rtol=ftol
 end
 # Gradient with m PhysicalParameter