Skip to content

Commit b19ffd6

Browse files
Merge pull request #116 from ErikQQY/qqy/sct
Add extension for SparseConnectivityTracer
2 parents 9b1b2d7 + d494f72 commit b19ffd6

6 files changed

+84
-30
lines changed

.github/workflows/Downgrade.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ jobs:
2121
group:
2222
- Core
2323
version:
24-
- '1'
24+
- '1.10'
2525
os:
2626
- ubuntu-latest
2727
- macos-latest

Project.toml

+24-18
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "PreallocationTools"
22
uuid = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "0.4.24"
4+
version = "0.4.25"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -10,31 +10,36 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1010

1111
[weakdeps]
1212
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
13+
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
1314

1415
[extensions]
1516
PreallocationToolsReverseDiffExt = "ReverseDiff"
17+
PreallocationToolsSparseConnectivityTracerExt = "SparseConnectivityTracer"
1618

1719
[compat]
18-
Adapt = "3.4, 4"
19-
Aqua = "0.8"
20-
ArrayInterface = "7.7"
21-
ForwardDiff = "0.10.19"
22-
LabelledArrays = "1.15"
23-
LinearAlgebra = "1"
24-
Optimization = "3.19"
25-
OptimizationOptimJL = "0.1.5"
26-
OrdinaryDiffEq = "6.65"
27-
Pkg = "1"
28-
Random = "1"
29-
RecursiveArrayTools = "3.2"
30-
ReverseDiff = "1"
20+
Adapt = "4.1.1"
21+
ADTypes = "1.13"
22+
Aqua = "0.8.11"
23+
ArrayInterface = "7.18.0"
24+
ForwardDiff = "0.10.38"
25+
LabelledArrays = "1.16.0"
26+
LinearAlgebra = "1.10"
27+
Optimization = "4.1.1"
28+
OptimizationOptimJL = "0.4.1"
29+
OrdinaryDiffEq = "6.91.0"
30+
Pkg = "1.10"
31+
Random = "1.10.8"
32+
RecursiveArrayTools = "3.29.0"
33+
ReverseDiff = "1.15.3"
3134
SafeTestsets = "0.1"
32-
SparseArrays = "1"
33-
Symbolics = "5.12"
34-
Test = "1"
35+
SparseArrays = "1.10"
36+
SparseConnectivityTracer = "0.6.12"
37+
Symbolics = "6.29.0"
38+
Test = "1.10"
3539
julia = "1.10"
3640

3741
[extras]
42+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
3843
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3944
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
4045
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -47,8 +52,9 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
4752
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
4853
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
4954
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
55+
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
5056
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
5157
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5258

5359
[targets]
54-
test = ["Aqua", "Random", "LabelledArrays", "LinearAlgebra", "OrdinaryDiffEq", "Test", "RecursiveArrayTools", "Pkg", "SafeTestsets", "Optimization", "OptimizationOptimJL", "SparseArrays", "Symbolics"]
60+
test = ["Aqua", "ADTypes", "Random", "LabelledArrays", "LinearAlgebra", "OrdinaryDiffEq", "Test", "RecursiveArrayTools", "Pkg", "SafeTestsets", "Optimization", "OptimizationOptimJL", "SparseArrays", "Symbolics", "SparseConnectivityTracer"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
module PreallocationToolsSparseConnectivityTracerExt
2+
3+
using PreallocationTools
4+
isdefined(Base, :get_extension) ? (import SparseConnectivityTracer) :
5+
(import ..SparseConnectivityTracer)
6+
7+
function PreallocationTools.get_tmp(
8+
dc::DiffCache, u::T) where {T <: SparseConnectivityTracer.Dual}
9+
if isbitstype(T)
10+
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
11+
if nelem > length(dc.dual_du)
12+
PreallocationTools.enlargediffcache!(dc, nelem)
13+
end
14+
PreallocationTools._restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
15+
else
16+
PreallocationTools._restructure(dc.du, zeros(T, size(dc.du)))
17+
end
18+
end
19+
20+
function PreallocationTools.get_tmp(
21+
dc::DiffCache, ::Type{T}) where {T <: SparseConnectivityTracer.Dual}
22+
if isbitstype(T)
23+
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
24+
if nelem > length(dc.dual_du)
25+
PreallocationTools.enlargediffcache!(dc, nelem)
26+
end
27+
PreallocationTools._restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
28+
else
29+
PreallocationTools._restructure(dc.du, zeros(T, size(dc.du)))
30+
end
31+
end
32+
33+
function PreallocationTools.get_tmp(
34+
dc::DiffCache, u::AbstractArray{T}) where {T <: SparseConnectivityTracer.Dual}
35+
if isbitstype(T)
36+
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
37+
if nelem > length(dc.dual_du)
38+
PreallocationTools.enlargediffcache!(dc, nelem)
39+
end
40+
PreallocationTools._restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
41+
else
42+
PreallocationTools._restructure(dc.du, zeros(T, size(dc.du)))
43+
end
44+
end
45+
46+
end

test/core_nesteddual.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,13 @@ newtonsol = solve(optprob, Newton())
9090
cache = DiffCache(zeros(ps, ps), [4, 4, 2])
9191
prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, ones(ps, ps), (0.0, 1.0),
9292
(coeffs, cache))
93-
realsol = solve(prob, TRBDF2(chunk_size = 2), saveat = 0.0:0.1:10.0, reltol = 1e-8)
93+
realsol = solve(prob, TRBDF2(autodiff = AutoForwardDiff(chunksize = 2)),
94+
saveat = 0.0:0.1:10.0, reltol = 1e-8)
9495

9596
function objfun(x, prob, realsol, cache)
9697
prob = remake(prob, u0 = eltype(x).(prob.u0), p = (x, cache))
97-
sol = solve(prob, TRBDF2(chunk_size = 2), saveat = 0.0:0.1:10.0, reltol = 1e-8)
98+
sol = solve(prob, TRBDF2(autodiff = AutoForwardDiff(chunksize = 2)),
99+
saveat = 0.0:0.1:10.0, reltol = 1e-8)
98100

99101
ofv = 0.0
100102
if any((s.retcode != ReturnCode.Success for s in sol))

test/core_odes.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using LinearAlgebra,
22
OrdinaryDiffEq, Test, PreallocationTools, LabelledArrays,
3-
RecursiveArrayTools
3+
RecursiveArrayTools, ADTypes
44

55
#Base array
66
function foo(du, u, (A, tmp), t)
@@ -10,17 +10,17 @@ function foo(du, u, (A, tmp), t)
1010
nothing
1111
end
1212
#with defined chunk_size
13-
chunk_size = 5
13+
chunk_size = 9
1414
u0 = ones(5, 5)
1515
A = ones(5, 5)
1616
cache = DiffCache(zeros(5, 5), chunk_size)
1717
prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, u0, (0.0, 1.0), (A, cache))
18-
sol = solve(prob, Rodas5P(chunk_size = chunk_size))
18+
sol = solve(prob, Rodas5P(autodiff = AutoForwardDiff(chunksize = chunk_size)))
1919
@test sol.retcode == ReturnCode.Success
2020

2121
cache = FixedSizeDiffCache(zeros(5, 5), chunk_size)
2222
prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, u0, (0.0, 1.0), (A, cache))
23-
sol = solve(prob, Rodas5P(chunk_size = chunk_size))
23+
sol = solve(prob, Rodas5P(autodiff = AutoForwardDiff(chunksize = chunk_size)))
2424
@test sol.retcode == ReturnCode.Success
2525

2626
#with auto-detected chunk_size
@@ -60,7 +60,7 @@ end
6060
chunk_size = 4
6161
prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, u0, (0.0, 1.0),
6262
(A, DiffCache(c, chunk_size)))
63-
sol = solve(prob, Rodas5P(chunk_size = chunk_size))
63+
sol = solve(prob, Rodas5P(autodiff = AutoForwardDiff(chunksize = chunk_size)))
6464
@test sol.retcode == ReturnCode.Success
6565
#with auto-detected chunk_size
6666
prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, u0, (0.0, 1.0), (A, DiffCache(c)))

test/gpu_all.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using LinearAlgebra,
2-
OrdinaryDiffEq, Test, PreallocationTools, CUDA, ForwardDiff
2+
OrdinaryDiffEq, Test, PreallocationTools, CUDA, ForwardDiff, ADTypes
33

44
# upstream
55
OrdinaryDiffEq.DiffEqBase.anyeltypedual(x::FixedSizeDiffCache, counter = 0) = Any
@@ -56,17 +56,17 @@ function foo(du, u, (A, tmp), t)
5656
nothing
5757
end
5858
#with specified chunk_size
59-
chunk_size = 10
59+
chunk_size = 9
6060
u0 = cu(rand(10, 10)) #example kept small for test purposes.
6161
A = cu(-randn(10, 10))
6262
cache = DiffCache(cu(zeros(10, 10)), chunk_size)
6363
prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, u0, (0.0f0, 1.0f0), (A, cache))
64-
sol = solve(prob, TRBDF2(chunk_size = chunk_size))
64+
sol = solve(prob, TRBDF2(autodiff = AutoForwardDiff(chunksize = chunk_size)))
6565
@test sol.retcode == ReturnCode.Success
6666

6767
cache = FixedSizeDiffCache(cu(zeros(10, 10)), chunk_size)
6868
prob = ODEProblem{true, SciMLBase.FullSpecialize}(foo, u0, (0.0f0, 1.0f0), (A, cache))
69-
sol = solve(prob, TRBDF2(chunk_size = chunk_size))
69+
sol = solve(prob, TRBDF2(autodiff = AutoForwardDiff(chunksize = chunk_size)))
7070
@test sol.retcode == ReturnCode.Success
7171

7272
#with auto-detected chunk_size

0 commit comments

Comments
 (0)