Skip to content

Commit

Permalink
Merge pull request #335 from ashutosh-b-b/bb/linear_interp_array
Browse files Browse the repository at this point in the history
LinearInterpolation: Fix scalar indexing and generalize for ndims(u) > 2
  • Loading branch information
ChrisRackauckas authored Sep 24, 2024
2 parents 5fc47d7 + 97529bb commit beba54c
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 9 deletions.
5 changes: 3 additions & 2 deletions src/interpolation_methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,12 @@ function _interpolate(A::LinearInterpolation{<:AbstractVector}, t::Number, igues
val
end

function _interpolate(A::LinearInterpolation{<:AbstractMatrix}, t::Number, iguess)
function _interpolate(A::LinearInterpolation{<:AbstractArray}, t::Number, iguess)
idx = get_idx(A, t, iguess)
Δt = t - A.t[idx]
slope = get_parameters(A, idx)
return A.u[:, idx] + slope * Δt
ax = axes(A.u)[1:(end - 1)]
return A.u[ax..., idx] + slope * Δt
end

# Quadratic Interpolation
Expand Down
15 changes: 15 additions & 0 deletions src/interpolation_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,21 @@ function munge_data(U::StridedMatrix, t::AbstractVector)
return U, t
end

function munge_data(U::AbstractArray{T, N}, t) where {T, N}
TU = Base.nonmissingtype(eltype(U))
Tt = Base.nonmissingtype(eltype(t))
@assert length(t) == size(U, ndims(U))
ax = axes(U)[1:(end - 1)]
non_missing_indices = collect(
i for i in 1:length(t)
if !any(ismissing, U[ax..., i]) && !ismissing(t[i])
)
U = cat([TU.(U[ax..., i]) for i in non_missing_indices]...; dims = ndims(U))
t = Tt.([t[i] for i in non_missing_indices])

return U, t
end

seems_linear(assume_linear_t::Bool, _) = assume_linear_t
seems_linear(assume_linear_t::Number, t) = looks_linear(t; threshold = assume_linear_t)

Expand Down
8 changes: 5 additions & 3 deletions src/parameter_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ function safe_diff(b, a::T) where {T}
b == a ? zero(T) : b - a
end

function linear_interpolation_parameters(u::AbstractArray{T}, t, idx) where {T}
Δu = if u isa AbstractMatrix
[safe_diff(u[j, idx + 1], u[j, idx]) for j in 1:size(u)[1]]
function linear_interpolation_parameters(u::AbstractArray{T, N}, t, idx) where {T, N}
Δu = if N > 1
ax = axes(u)
safe_diff.(
u[ax[1:(end - 1)]..., (idx + 1):(idx + 1)], u[ax[1:(end - 1)]..., idx:idx])
else
safe_diff(u[idx + 1], u[idx])
end
Expand Down
8 changes: 4 additions & 4 deletions test/interpolation_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@ end
A = LinearInterpolation(u, t; extrapolate = true)

for (_t, _u) in zip(t, eachcol(u))
@test A(_t) == _u
@test A(_t) == reshape(_u, :, 1)
end
@test A(0) == [0.0, 0.0]
@test A(5.5) == [11.0, 16.5]
@test A(11) == [22, 33]
@test A(0) == [0.0; 0.0;;]
@test A(5.5) == [11.0; 16.5;;]
@test A(11) == [22; 33;;]

x = 1:10
y = 2:4
Expand Down

0 comments on commit beba54c

Please sign in to comment.