-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #153 from JuliaDynamics/hw/caching
improve caching
- Loading branch information
Showing
16 changed files
with
314 additions
and
206 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,11 +4,11 @@ authors = ["Frank Hellmann <[email protected]>, Michael Lindner <michaelli | |
version = "0.9.0" | ||
|
||
[deps] | ||
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" | ||
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" | ||
Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458" | ||
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" | ||
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" | ||
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" | ||
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" | ||
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" | ||
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" | ||
|
@@ -30,17 +30,22 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" | |
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" | ||
|
||
[weakdeps] | ||
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" | ||
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" | ||
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" | ||
|
||
[extensions] | ||
ModelingToolkitExt = "ModelingToolkit" | ||
CUDAExt = ["CUDA", "Adapt"] | ||
MTKExt = "ModelingToolkit" | ||
|
||
[compat] | ||
Adapt = "4.0.4" | ||
ArgCheck = "2.3.0" | ||
Atomix = "0.1.0" | ||
CUDA = "5.5.2" | ||
DocStringExtensions = "0.9.3" | ||
FastClosures = "0.3.2" | ||
ForwardDiff = "0.10.36" | ||
Graphs = "1" | ||
InteractiveUtils = "1" | ||
KernelAbstractions = "0.9.18" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
module CUDAExt | ||
using NetworkDynamics: Network, NetworkLayer, VertexBatch, EdgeBatch, | ||
KAAggregator, AggregationMap, SparseAggregator, | ||
LazyGBufProvider, EagerGBufProvider, LazyGBuf, | ||
dispatchT, compf, iscudacompatible, executionstyle | ||
using NetworkDynamics.PreallocationTools: DiffCache | ||
using NetworkDynamics: KernelAbstractions as KA | ||
|
||
using CUDA: CuArray | ||
using Adapt: Adapt, adapt | ||
|
||
# main entry for bringing Network to GPU | ||
function Adapt.adapt_structure(to, n::Network) | ||
if to isa KA.GPU | ||
throw(ArgumentError("Looks like to passed an KernelAbstractions backend to adapt Network to GPU. \ | ||
this is not supported as the internal cache types cannot be infered without known the eltype. \ | ||
Please adapt using `CuArray{Float32}` or `CuArray{Float64}`!")) | ||
end | ||
if !(to isa Type{<:CuArray}) | ||
throw(ArgumentError("Can't handle Adaptor $to. \ | ||
Please adapt using `CuArray{Float32}` or `CuArray{Float64}`!")) | ||
end | ||
if eltype(to) ∉ (Float32, Float64) | ||
throw(ArgumentError("Use adapt on Network with either `CuArray{Float32}` or `CuArray{Float64}` \ | ||
such that internal caches can be created with the correct type!")) | ||
end | ||
if !iscudacompatible(n) | ||
throw(ArgumentError("The provided network has non-cuda compatible aggregator or exectuion strategies.")) | ||
end | ||
vb = adapt(to, n.vertexbatches) | ||
layer = adapt(to, n.layer) | ||
mm = adapt(to, n.mass_matrix) | ||
gbp = adapt(to, n.gbufprovider) | ||
caches = (;state = _adapt_diffcache(to, n.caches.state), | ||
aggregation = _adapt_diffcache(to, n.caches.aggregation)) | ||
exT = typeof(executionstyle(n)) | ||
gT = typeof(n.im.g) | ||
|
||
Network{exT,gT,typeof(layer),typeof(vb),typeof(mm),eltype(caches),typeof(gbp)}( | ||
vb, layer, n.im, caches, mm, gbp) | ||
end | ||
|
||
Adapt.@adapt_structure NetworkLayer | ||
|
||
|
||
#### | ||
#### Adapt Aggregators | ||
#### | ||
Adapt.@adapt_structure KAAggregator | ||
|
||
# overload to retain int types for aggregation map | ||
Adapt.@adapt_structure AggregationMap | ||
function Adapt.adapt_structure(to::Type{<:CuArray{<:AbstractFloat}}, am::AggregationMap) | ||
map = adapt(CuArray, am.map) | ||
symmap = adapt(CuArray, am.symmap) | ||
AggregationMap(am.range, map, am.symrange, symmap) | ||
end | ||
|
||
Adapt.@adapt_structure SparseAggregator | ||
|
||
|
||
#### | ||
#### Adapt GBufProviders | ||
#### | ||
Adapt.@adapt_structure LazyGBufProvider | ||
function Adapt.adapt_structure(to::Type{<:CuArray{<:AbstractFloat}}, gbp::LazyGBufProvider) | ||
adapt(CuArray, gbp) # preserve Vector{UnitRange} | ||
end | ||
Adapt.@adapt_structure LazyGBuf | ||
|
||
# overload to retain int types for eager gbuf map | ||
function Adapt.adapt_structure(to::Type{<:CuArray{<:AbstractFloat}}, gbp::EagerGBufProvider) | ||
_adapt_eager_gbufp(CuArray, to, gbp) | ||
end | ||
function Adapt.adapt_structure(to, gbp::EagerGBufProvider) | ||
_adapt_eager_gbufp(to, to, gbp) | ||
end | ||
function _adapt_eager_gbufp(mapto, cacheto, gbp) | ||
map = adapt(mapto, gbp.map) | ||
cache = _adapt_diffcache(cacheto, gbp.diffcache) | ||
EagerGBufProvider(map, cache) | ||
end | ||
|
||
|
||
#### | ||
#### Adapt VertexBatch/EdgeBatch | ||
#### | ||
function Adapt.adapt_structure(to::Type{<:CuArray{<:AbstractFloat}}, b::VertexBatch) | ||
Adapt.adapt_structure(CuArray, b) | ||
end | ||
function Adapt.adapt_structure(to::Type{<:CuArray{<:AbstractFloat}}, b::EdgeBatch) | ||
Adapt.adapt_structure(CuArray, b) | ||
end | ||
function Adapt.adapt_structure(to, b::VertexBatch) | ||
idxs = adapt(to, b.indices) | ||
VertexBatch{dispatchT(b), typeof(compf(b)), typeof(idxs)}( | ||
idxs, compf(b), b.statestride, b.pstride, b.aggbufstride) | ||
end | ||
function Adapt.adapt_structure(to, b::EdgeBatch) | ||
idxs = adapt(to, b.indices) | ||
EdgeBatch{dispatchT(b), typeof(compf(b)), typeof(idxs)}( | ||
idxs, compf(b), b.statestride, b.pstride, b.gbufstride) | ||
end | ||
|
||
|
||
#### | ||
#### utils | ||
#### | ||
# define similar to adapt_structure for DiffCache without type piracy | ||
function _adapt_diffcache(to, c::DiffCache) | ||
du = adapt(to, c.du) | ||
dual_du = adapt(to, c.dual_du) | ||
DiffCache(du, dual_du, c.any_du) | ||
# N = length(c.dual_du) ÷ length(c.du) - 1 | ||
# DiffCache(du, N) | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.