Skip to content

Commit e2df7ec

Browse files
Merge pull request #112 from SciML/bigprealloc
Allow bigfloats through dualcaches
2 parents 98e4b02 + 2cf2bd2 commit e2df7ec

File tree

1 file changed

+24
-12
lines changed

1 file changed

+24
-12
lines changed

src/PreallocationTools.jl

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -129,27 +129,39 @@ const dualcache = DiffCache
129129
Returns the `Dual` or normal cache array stored in `dc` based on the type of `u`.
130130
"""
131131
function get_tmp(dc::DiffCache, u::T) where {T <: ForwardDiff.Dual}
132-
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
133-
if nelem > length(dc.dual_du)
134-
enlargediffcache!(dc, nelem)
132+
if isbitstype(T)
133+
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
134+
if nelem > length(dc.dual_du)
135+
enlargediffcache!(dc, nelem)
136+
end
137+
_restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
138+
else
139+
_restructure(dc.du, zeros(T, size(dc.du)))
135140
end
136-
_restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
137141
end
138142

139143
function get_tmp(dc::DiffCache, ::Type{T}) where {T <: ForwardDiff.Dual}
140-
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
141-
if nelem > length(dc.dual_du)
142-
enlargediffcache!(dc, nelem)
144+
if isbitstype(T)
145+
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
146+
if nelem > length(dc.dual_du)
147+
enlargediffcache!(dc, nelem)
148+
end
149+
_restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
150+
else
151+
_restructure(dc.du, zeros(T, size(dc.du)))
143152
end
144-
_restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
145153
end
146154

147155
function get_tmp(dc::DiffCache, u::AbstractArray{T}) where {T <: ForwardDiff.Dual}
148-
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
149-
if nelem > length(dc.dual_du)
150-
enlargediffcache!(dc, nelem)
156+
if isbitstype(T)
157+
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
158+
if nelem > length(dc.dual_du)
159+
enlargediffcache!(dc, nelem)
160+
end
161+
_restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
162+
else
163+
_restructure(dc.du, zeros(T, size(dc.du)))
151164
end
152-
_restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
153165
end
154166

155167
function get_tmp(dc::DiffCache, u::Union{Number, AbstractArray})

0 commit comments

Comments
 (0)