Enable scalar/broadcast operation for LazyPropagation #167
base: master
src/TimeModeling/Types/abstract.jl
```diff
@@ -73,7 +73,15 @@ vec(x::judiMultiSourceVector) = vcat(vec.(x.data)...)

 time_sampling(ms::judiMultiSourceVector) = [1 for i=1:ms.nsrc]

-reshape(ms::judiMultiSourceVector, dims::Dims{N}) where N = reshape(vec(ms), dims)
+function reshape(ms::judiMultiSourceVector, dims::Dims{N}) where N
```
Review comment: This is for the failing example:

```julia
using JUDI
using Flux
using ArgParse, Test, Printf, Aqua
using SegyIO, LinearAlgebra, Distributed, JOLI
using TimerOutputs: TimerOutputs, @timeit

Flux.Random.seed!(2022)

### Model
tti = false
viscoacoustic = false
nsrc = 1
dt = 1f0

include(joinpath(JUDIPATH, "../test/seismic_utils.jl"))
model, model0, dm = setup_model(tti, viscoacoustic, 4)
m, m0 = model.m.data, model0.m.data
q, srcGeometry, recGeometry, f0 = setup_geom(model; nsrc=nsrc, dt=dt)

# Common op
Pr = judiProjection(recGeometry)
Ps = judiProjection(srcGeometry)
opt = Options(sum_padding=true, f0=f0)
A_inv = judiModeling(model; options=opt)
A_inv0 = judiModeling(model0; options=opt)

# Operators
F = Pr*A_inv*adjoint(Ps)
J = judiJacobian(F, q)
dm = vec(m - m0)

gs_inv = gradient(q -> norm(J(q)*dm), q)
```

which fails with:

```
ERROR: LoadError: DimensionMismatch: new dimensions (1,) must be consistent with array size 1501
Stacktrace:
  [1] (::Base.var"#throw_dmrsa#289")(dims::Tuple{Int64}, len::Int64)
    @ Base ./reshapedarray.jl:41
  [2] reshape
    @ ./reshapedarray.jl:45 [inlined]
  [3] reshape
    @ ~/.julia/dev/JUDI/src/TimeModeling/Types/abstract.jl:76 [inlined]
  [4] reshape(parent::judiVector{Float32, Matrix{Float32}}, shp::Tuple{Base.OneTo{Int64}})
    @ Base ./reshapedarray.jl:111
  [5] (::ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}})(dx::JUDI.LazyPropagation)
    @ JUDI ~/.julia/dev/JUDI/src/rrules.jl:142
  [6] _project
    @ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:184 [inlined]
  [7] map(f::typeof(Zygote._project), t::Tuple{judiVector{Float32, Matrix{Float32}}}, s::Tuple{JUDI.LazyPropagation})
    @ Base ./tuple.jl:246
  [8] gradient(f::Function, args::judiVector{Float32, Matrix{Float32}})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:98
  [9] top-level scope
    @ ~/.julia/dev/JUDI/test/MFE.jl:33
 [10] include(fname::String)
    @ Base.MainInclude ./client.jl:476
 [11] top-level scope
    @ REPL[1]:1
```

Reply: It wasn't failing before, what changed?

Reply: The example above fails on the master branch.

Reply: To be more specific:

```julia
julia> gs_inv = gradient(() -> norm(J(q)*dm), Flux.params(q))
Operator `born` ran in 0.75 s
Grads(...)
```

this doesn't fail, but

```julia
julia> gs_inv = gradient(q -> norm(J(q)*dm), q)
Operator `born` ran in 0.72 s
Operator `born` ran in 0.73 s
ERROR: DimensionMismatch: new dimensions (1,) must be consistent with array size 1501
Stacktrace:
  [1] (::Base.var"#throw_dmrsa#289")(dims::Tuple{Int64}, len::Int64)
    @ Base ./reshapedarray.jl:41
  [2] reshape
    @ ./reshapedarray.jl:45 [inlined]
  [3] reshape
    @ ~/.julia/dev/JUDI/src/TimeModeling/Types/abstract.jl:76 [inlined]
  [4] reshape(parent::judiVector{Float32, Matrix{Float32}}, shp::Tuple{Base.OneTo{Int64}})
    @ Base ./reshapedarray.jl:111
  [5] (::ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}})(dx::JUDI.LazyPropagation)
    @ JUDI ~/.julia/dev/JUDI/src/rrules.jl:142
  [6] _project
    @ ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:184 [inlined]
  [7] map(f::typeof(Zygote._project), t::Tuple{judiVector{Float32, Matrix{Float32}}}, s::Tuple{JUDI.LazyPropagation})
    @ Base ./tuple.jl:246
  [8] gradient(f::Function, args::judiVector{Float32, Matrix{Float32}})
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:98
  [9] top-level scope
    @ REPL[25]:1
 [10] top-level scope
    @ ~/.julia/packages/CUDA/DfvRa/src/initialization.jl:52
```

this fails.

Reply: Hum ok, then split into
```diff
+    try
+        return reshape(vec(ms), dims)
+    catch e
+        @assert dims[1] == ms.nsrc ### during AD, size(ms::judiVector) = ms.nsrc
+        return ms
+    end
+end

 ############################################################################################################################
 # Linear algebra `*`
+(msv::judiMultiSourceVector{mT})(x::AbstractVector{T}) where {mT, T<:Number} = x
```
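For context on why the fallback is needed: during reverse-mode AD, ChainRulesCore's `ProjectTo` reshapes the cotangent onto the primal's axes, and a `judiVector`'s axes have length `nsrc` while its flattened data holds every time sample (1501 in the example above). Below is a minimal self-contained sketch of that mismatch using a toy type, not JUDI's actual implementation:

```julia
import Base: size, vec, reshape

struct ToyMSV
    data::Vector{Matrix{Float32}}  # one (nt × nrec) block per source
    nsrc::Int
end

size(ms::ToyMSV) = (ms.nsrc,)             # AD-facing size: one entry per source
vec(ms::ToyMSV) = vcat(vec.(ms.data)...)  # flattened data: nsrc * nt * nrec entries

# Mirrors the method added in this PR: flatten when possible, otherwise
# assume we are inside AD projection and return the container unchanged.
function reshape(ms::ToyMSV, dims::Dims{N}) where N
    try
        return reshape(vec(ms), dims)
    catch e
        @assert dims[1] == ms.nsrc  # during AD, size(ms) == (ms.nsrc,)
        return ms
    end
end

ms = ToyMSV([randn(Float32, 1501, 1)], 1)
reshape(ms, (1501,))  # works: 1501 samples
reshape(ms, (1,))     # used to throw DimensionMismatch; now returns `ms`
```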
src/rrules.jl
```diff
@@ -18,14 +18,19 @@ Parameters

 * `F`: the JUDI propgator
 * `q`: The source to compute F*q
 """
-struct LazyPropagation
+mutable struct LazyPropagation
     post::Function
     F::judiPropagator
     q
+    val # store F * q
 end

-eval_prop(F::LazyPropagation) = F.post(F.F * F.q)
+function eval_prop(F::LazyPropagation)
+    isnothing(F.val) && (F.val = F.F * F.q)
+    return F.post(F.val)
+end
 Base.collect(F::LazyPropagation) = eval_prop(F)
+LazyPropagation(post, F::judiPropagator, q) = LazyPropagation(post, F, q, nothing)
```

Review comment (on the new constructor): `::Function`
```diff
 LazyPropagation(F::judiPropagator, q) = LazyPropagation(identity, F, q)
```

Review comment: Constructor

```diff

 # Only a few arithmetic operation are supported
```
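The `val` field plus the `nothing`-defaulting constructor turn `LazyPropagation` into a memoized container: the expensive wave-equation solve runs at most once, even if the lazy value is consumed several times. A toy sketch of the pattern (the names `LazyEval` and `eval_lazy` are illustrative stand-ins, not JUDI's):

```julia
mutable struct LazyEval
    post::Function
    compute::Function  # stands in for the expensive F.F * F.q solve
    val                # cached result; `nothing` until first evaluation
end
LazyEval(post, compute) = LazyEval(post, compute, nothing)

function eval_lazy(F::LazyEval)
    isnothing(F.val) && (F.val = F.compute())  # run the solve at most once
    return F.post(F.val)
end

calls = Ref(0)
F = LazyEval(identity, () -> (calls[] += 1; [1.0, 2.0, 3.0]))
eval_lazy(F)
eval_lazy(F)
calls[]  # == 1: repeated norm/dot-style consumers no longer re-propagate
```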
```diff
@@ -45,15 +50,35 @@ for op in [:+, :-, :*, :/]
     end
 end

+for op in [:*, :/]
+    @eval begin
+        $(op)(F::LazyPropagation, y::T) where T <: Number = LazyPropagation(F.post, F.F, $(op)(F.q, y), isnothing(F.val) ? nothing : $(op)(F.val, y))
+        $(op)(y::T, F::LazyPropagation) where T <: Number = LazyPropagation(F.post, F.F, $(op)(y, F.q), isnothing(F.val) ? nothing : $(op)(y, F.val))
+        broadcasted(::typeof($op), F::LazyPropagation, y::T) where T <: Number = LazyPropagation(F.post, F.F, broadcasted($(op), F.q, y), isnothing(F.val) ? nothing : broadcasted($(op), F.val, y))
+        broadcasted(::typeof($op), y::T, F::LazyPropagation) where T <: Number = LazyPropagation(F.post, F.F, broadcasted($(op), y, F.q), isnothing(F.val) ? nothing : broadcasted($(op), y, F.val))
+    end
+end
```
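Scalar `*` and `/` can stay lazy because a `judiPropagator` is linear: `a*(F*q) == F*(a*q)`, so the scalar can be folded into the cheap source `q` (and into the cached `val` when it exists) without triggering a solve. A one-line sanity check of that identity, with a plain matrix standing in for the propagator:

```julia
# Linearity is what makes the rescaling trick valid (toy matrix, not a judiPropagator).
F = [1.0 2.0; 3.0 4.0]
q = [1.0, -1.0]
a = 2.5

@assert a * (F * q) ≈ F * (a * q)  # scale the source, defer the propagation
```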
```diff
+for op in [:+, :-]
+    @eval begin
+        $(op)(F::LazyPropagation, y::T) where T <: Number = $(op)(eval_prop(F), y)
+        $(op)(y::T, F::LazyPropagation) where T <: Number = $(op)(y, eval_prop(F))
+        broadcasted(::typeof($op), F::LazyPropagation, y::T) where T <: Number = broadcasted($(op), eval_prop(F), y)
+        broadcasted(::typeof($op), y::T, F::LazyPropagation) where T <: Number = broadcasted($(op), y, eval_prop(F))
+    end
+end
```
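No such identity exists for `+` and `-` (in general `F*q .+ a != F*(q .+ a)` even for a linear `F`), which is presumably why these methods call `eval_prop` and operate on the realized array instead. The same toy matrix shows the mismatch:

```julia
# Shifts do not commute with a linear operator, so evaluation is forced first.
F = [1.0 2.0; 3.0 4.0]
q = [1.0, -1.0]
a = 2.5

(F * q) .+ a ≈ F * (q .+ a)  # false: [1.5, 1.5] vs [6.5, 16.5]
```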
```diff

 broadcasted(::typeof(^), y::LazyPropagation, p::Real) = eval_prop(y).^(p)
 *(F::judiPropagator, q::LazyPropagation) = F*eval_prop(q)

-reshape(F::LazyPropagation, dims...) = LazyPropagation(x->reshape(x, dims...), F.F, Q.q)
+reshape(F::LazyPropagation, dims...) = LazyPropagation(x->reshape(x, dims...), F.F, F.q, F.val)
+vec(F::LazyPropagation) = LazyPropagation(vec, F.F, F.q, F.val)
 copyto!(x::AbstractArray, F::LazyPropagation) = copyto!(x, eval_prop(F))
 dot(x::AbstractArray, F::LazyPropagation) = dot(x, eval_prop(F))
 dot(F::LazyPropagation, x::AbstractArray) = dot(x, F)
 norm(F::LazyPropagation, p::Real=2) = norm(eval_prop(F), p)
+adjoint(F::JUDI.LazyPropagation) = F
+length(F::JUDI.LazyPropagation) = size(F.F, 1)

 ############################ Two params rules ############################################
 function rrule(F::judiPropagator{T, O}, m::AbstractArray{T}, q::AbstractArray{T}) where {T, O}
```
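Taken together, these methods should let an objective scale or broadcast over the lazy Jacobian product without forcing an extra solve. A hedged usage sketch, reusing `J`, `dm`, and `q` from the failing example above (illustrative, not a test from this PR):

```julia
# Scalar/broadcast ops hit the LazyPropagation produced inside Zygote's pullback;
# with this PR they rescale the lazy source/cached value instead of erroring.
g1 = gradient(q -> norm(2f0 * (J(q) * dm)), q)
g2 = gradient(q -> norm((J(q) * dm) ./ 2f0), q)
```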
Review comment: That's not quite true: it's just a container for a single time trace; there is nothing about "everywhere in space" in it.