Skip to content

Commit

Permalink
Merge pull request #394 from SciML/dw/get_u
Browse files Browse the repository at this point in the history
Do not fall back to inplace interpolation method
  • Loading branch information
ChrisRackauckas authored Feb 7, 2025
2 parents 8ead701 + e3cfc01 commit b983363
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 44 deletions.
42 changes: 17 additions & 25 deletions src/DataInterpolations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,37 +25,29 @@ include("online.jl")
include("show.jl")

(interp::AbstractInterpolation)(t::Number) = _interpolate(interp, t)

function (interp::AbstractInterpolation)(t::AbstractVector)
u = get_u(interp.u, t)
interp(u, t)
end

function get_u(u::AbstractVector, t)
return similar(t, promote_type(eltype(u), eltype(t)))
end

function get_u(u::AbstractVector{<:AbstractVector}, t)
type = promote_type(eltype(eltype(u)), eltype(t))
return [zeros(type, length(first(u))) for _ in eachindex(t)]
end

function get_u(u::AbstractMatrix, t)
type = promote_type(eltype(u), eltype(t))
return zeros(type, (size(u, 1), length(t)))
if interp.u isa AbstractVector
# Return a vector of interpolated values, on for each element in `t`
return map(interp, t)
elseif interp.u isa AbstractArray
# Stack interpolated values if `u` was stored in matrix/... form
return stack(interp, t)
end
end

function (interp::AbstractInterpolation)(u::AbstractMatrix, t::AbstractVector)
@inbounds for i in eachindex(t)
u[:, i] = interp(t[i])
function (interp::AbstractInterpolation)(out::AbstractVector, t::AbstractVector)
if length(out) != length(t)
throw(DimensionMismatch("number of evaluation points and length of the result vector must be equal"))
end
u
map!(interp, out, t)
return out
end
function (interp::AbstractInterpolation)(u::AbstractVector, t::AbstractVector)
@inbounds for i in eachindex(u, t)
u[i] = interp(t[i])
function (interp::AbstractInterpolation)(out::AbstractArray, t::AbstractVector)
if size(out, ndims(out)) != length(t)
throw(DimensionMismatch("number of evaluation points and last dimension of the result array must be equal"))
end
u
map!(interp, eachslice(out; dims = ndims(out)), t)
return out
end

const EXTRAPOLATION_ERROR = "Cannot extrapolate as `extrapolate` keyword passed was `false`"
Expand Down
24 changes: 18 additions & 6 deletions test/extrapolation_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,16 @@ end
# Left extrapolation
A = ConstantInterpolation(u_un, t_un; extrapolation_left = extrapolation_type)
t_eval = 0.0u"s"
@test A(t_eval) == 1.0u"m"
@test @inferred(A(t_eval)) == 1.0u"m"
@test @inferred(A([t_eval])) == [1.0u"m"]
@test A([t_eval]) isa Vector{typeof(1.0u"m")}

# Right extrapolation
A = ConstantInterpolation(u_un, t_un; extrapolation_right = extrapolation_type)
t_eval = 3.0u"s"
@test A(t_eval) == 2.0u"m"
@test @inferred(A(t_eval)) == 2.0u"m"
@test @inferred(A([t_eval])) == [2.0u"m"]
@test A([t_eval]) isa Vector{typeof(2.0u"m")}
end
end

Expand All @@ -68,22 +72,30 @@ end
# Left constant extrapolation
A = LinearInterpolation(u_un, t_un; extrapolation_left = ExtrapolationType.Constant)
t_eval = 0.0u"s"
@test A(t_eval) == 1.0u"m"
@test @inferred(A(t_eval)) == 1.0u"m"
@test @inferred(A([t_eval])) == [1.0u"m"]
@test A([t_eval]) isa Vector{typeof(1.0u"m")}

# Right constant extrapolation
A = LinearInterpolation(u_un, t_un; extrapolation_right = ExtrapolationType.Constant)
t_eval = 3.0u"s"
@test A(t_eval) == 2.0u"m"
@test @inferred(A(t_eval)) == 2.0u"m"
@test @inferred(A([t_eval])) == [2.0u"m"]
@test A([t_eval]) isa Vector{typeof(2.0u"m")}

# Left linear extrapolation
A = LinearInterpolation(u_un, t_un; extrapolation_left = ExtrapolationType.Linear)
t_eval = 0.0u"s"
@test A(t_eval) == 0.0u"m"
@test @inferred(A(t_eval)) == 0.0u"m"
@test @inferred(A([t_eval])) == [0.0u"m"]
@test A([t_eval]) isa Vector{typeof(0.0u"m")}

# Right constant extrapolation
A = LinearInterpolation(u_un, t_un; extrapolation_right = ExtrapolationType.Linear)
t_eval = 3.0u"s"
@test A(t_eval) == 3.0u"m"
@test @inferred(A(t_eval)) == 3.0u"m"
@test @inferred(A([t_eval])) == [3.0u"m"]
@test A([t_eval]) isa Vector{typeof(3.0u"m")}
end

@testset "Linear Interpolation" begin
Expand Down
46 changes: 33 additions & 13 deletions test/interpolation_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,13 @@ end
itp = ConstantInterpolation([2], [0.0]; extrapolation = ExtrapolationType.Constant)
@test itp(1.0) === 2
@test itp(-1.0) === 2

# Test output type of vector evaluation (issue #388)
u = [2, 3]
t = [0.0, 1.0]
itp = ConstantInterpolation(u, t)
@test @inferred(itp(t)) == itp.(t)
@test typeof(itp(t)) === typeof(itp.(t)) === Vector{Int}
end

@testset "QuadraticSpline Interpolation" begin
Expand Down Expand Up @@ -855,33 +862,46 @@ end

@testset "Type of vector returned" begin
# Issue https://github.com/SciML/DataInterpolations.jl/issues/253
t1 = Float32[0.1, 0.2, 0.3, 0.4, 0.5]
t2 = Float64[0.1, 0.2, 0.3, 0.4, 0.5]
interps_and_types = [
(LinearInterpolation(t1, t1), Float32),
(LinearInterpolation(t1, t2), Float32),
(LinearInterpolation(t2, t1), Float64),
(LinearInterpolation(t2, t2), Float64)
]
for i in eachindex(interps_and_types)
@test eltype(interps_and_types[i][1](t1)) == interps_and_types[i][2]
ut1 = Float32[0.1, 0.2, 0.3, 0.4, 0.5]
ut2 = Float64[0.1, 0.2, 0.3, 0.4, 0.5]
for u in (ut1, ut2), t in (ut1, ut2)
interp = LinearInterpolation(ut1, ut2)
for xs in (u, t)
ys = @inferred(interp(xs))
@test ys isa Vector{typeof(interp(first(xs)))}
@test all(y == interp(x) for (x, y) in zip(xs, ys))
end
end
end

@testset "Plugging vector timepoints" begin
# Issue https://github.com/SciML/DataInterpolations.jl/issues/267
t = Float64[1.0, 2.0, 3.0, 4.0, 5.0]
x = Float64[1.3, 2.2, 4.1]
@testset "utype - Vectors" begin
interp = LinearInterpolation(rand(5), t)
@test interp(t) isa Vector{Float64}
y = interp(x)
@test y isa Vector{Float64}
@test length(y) == 3
end
@testset "utype - Vector of Vectors" begin
interp = LinearInterpolation([rand(2) for _ in 1:5], t)
@test interp(t) isa Vector{Vector{Float64}}
y = interp(x)
@test y isa Vector{Vector{Float64}}
@test length(y) == 3
@test all(length(yi) == 2 for yi in y)
end
@testset "utype - Matrix" begin
interp = LinearInterpolation(rand(2, 5), t)
@test interp(t) isa Matrix{Float64}
y = interp(x)
@test y isa Matrix{Float64}
@test size(y) == (2, 3)
end
@testset "utype - Array" begin
interp = LinearInterpolation(rand(2, 3, 4, 5), t)
y = interp(x)
@test y isa Array{Float64, 4}
@test size(y) == (2, 3, 4, 3)
end
end

Expand Down

0 comments on commit b983363

Please sign in to comment.