Skip to content

Commit

Permalink
Merge pull request #153 from JuliaDynamics/hw/caching
Browse files Browse the repository at this point in the history
improve caching
  • Loading branch information
hexaeder authored Oct 14, 2024
2 parents a65b3e3 + 30cb2f4 commit 3858f6d
Show file tree
Hide file tree
Showing 16 changed files with 314 additions and 206 deletions.
9 changes: 7 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ authors = ["Frank Hellmann <[email protected]>, Michael Lindner <michaelli
version = "0.9.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
Expand All @@ -30,17 +30,22 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"

[weakdeps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"

[extensions]
ModelingToolkitExt = "ModelingToolkit"
CUDAExt = ["CUDA", "Adapt"]
MTKExt = "ModelingToolkit"

[compat]
Adapt = "4.0.4"
ArgCheck = "2.3.0"
Atomix = "0.1.0"
CUDA = "5.5.2"
DocStringExtensions = "0.9.3"
FastClosures = "0.3.2"
ForwardDiff = "0.10.36"
Graphs = "1"
InteractiveUtils = "1"
KernelAbstractions = "0.9.18"
Expand Down
1 change: 1 addition & 0 deletions benchmark/run_benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Pkg.activate(BMPATH);
if VERSION < v"1.11.0-0"
Pkg.develop(; path=NDPATH);
end
Pkg.update();
Pkg.precompile();

using PkgBenchmark
Expand Down
118 changes: 118 additions & 0 deletions ext/CUDAExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
module CUDAExt
using NetworkDynamics: Network, NetworkLayer, VertexBatch, EdgeBatch,
KAAggregator, AggregationMap, SparseAggregator,
LazyGBufProvider, EagerGBufProvider, LazyGBuf,
dispatchT, compf, iscudacompatible, executionstyle
using NetworkDynamics.PreallocationTools: DiffCache
using NetworkDynamics: KernelAbstractions as KA

using CUDA: CuArray
using Adapt: Adapt, adapt

# main entry for bringing Network to GPU
# Main entry point for bringing a `Network` to the GPU.
#
# Only `CuArray{Float32}` / `CuArray{Float64}` adaptors are accepted, because the
# internal caches must be rebuilt with a concrete element type. Validates the
# adaptor and the network, adapts all sub-structures, and reassembles a fully
# parameterized `Network`.
#
# Throws `ArgumentError` for KernelAbstractions backends, non-`CuArray` adaptors,
# `CuArray`s without a Float32/Float64 eltype, or networks whose aggregator /
# execution strategy is not CUDA compatible (checked via `iscudacompatible`).
function Adapt.adapt_structure(to, n::Network)
    if to isa KA.GPU
        throw(ArgumentError("Looks like you passed a KernelAbstractions backend to adapt Network to GPU. \
            This is not supported as the internal cache types cannot be inferred without knowing the eltype. \
            Please adapt using `CuArray{Float32}` or `CuArray{Float64}`!"))
    end
    if !(to isa Type{<:CuArray})
        throw(ArgumentError("Can't handle Adaptor $to. \
            Please adapt using `CuArray{Float32}` or `CuArray{Float64}`!"))
    end
    # FIX: restore the dropped `∉` — without it this line is a syntax error.
    if eltype(to) ∉ (Float32, Float64)
        throw(ArgumentError("Use adapt on Network with either `CuArray{Float32}` or `CuArray{Float64}` \
            such that internal caches can be created with the correct type!"))
    end
    if !iscudacompatible(n)
        throw(ArgumentError("The provided network has non-cuda compatible aggregator or execution strategies."))
    end

    # Adapt every GPU-relevant sub-structure; the index manager `n.im` stays on host.
    vb = adapt(to, n.vertexbatches)
    layer = adapt(to, n.layer)
    mm = adapt(to, n.mass_matrix)
    gbp = adapt(to, n.gbufprovider)
    # DiffCaches need the dedicated helper (see `_adapt_diffcache`) to avoid type piracy.
    caches = (;state = _adapt_diffcache(to, n.caches.state),
               aggregation = _adapt_diffcache(to, n.caches.aggregation))
    exT = typeof(executionstyle(n))
    gT = typeof(n.im.g)

    Network{exT,gT,typeof(layer),typeof(vb),typeof(mm),eltype(caches),typeof(gbp)}(
        vb, layer, n.im, caches, mm, gbp)
end

Adapt.@adapt_structure NetworkLayer


####
#### Adapt Aggregators
####
Adapt.@adapt_structure KAAggregator

# overload to retain int types for aggregation map
Adapt.@adapt_structure AggregationMap
# For float-eltype CuArray adaptors, move the index maps with the untyped
# `CuArray` adaptor instead, so their integer element types are preserved
# (adapting them to a float eltype would corrupt the indices). The ranges
# are plain host data and are carried over unchanged.
function Adapt.adapt_structure(to::Type{<:CuArray{<:AbstractFloat}}, am::AggregationMap)
    gpu_map = adapt(CuArray, am.map)
    gpu_symmap = adapt(CuArray, am.symmap)
    return AggregationMap(am.range, gpu_map, am.symrange, gpu_symmap)
end

Adapt.@adapt_structure SparseAggregator


####
#### Adapt GBufProviders
####
Adapt.@adapt_structure LazyGBufProvider
# Float-eltype CuArray adaptors fall back to the untyped `CuArray` adaptor so
# the provider's `Vector{UnitRange}` fields keep their integer-based types.
function Adapt.adapt_structure(to::Type{<:CuArray{<:AbstractFloat}}, gbp::LazyGBufProvider)
    return adapt(CuArray, gbp)
end
Adapt.@adapt_structure LazyGBuf

# overload to retain int types for eager gbuf map
# Float-eltype CuArray adaptors: adapt the gather map with the untyped
# `CuArray` adaptor (retaining its integer eltype) while the cache is adapted
# to the requested float eltype.
function Adapt.adapt_structure(to::Type{<:CuArray{<:AbstractFloat}}, gbp::EagerGBufProvider)
    return _adapt_eager_gbufp(CuArray, to, gbp)
end

# Generic adaptors are applied to map and cache alike.
function Adapt.adapt_structure(to, gbp::EagerGBufProvider)
    return _adapt_eager_gbufp(to, to, gbp)
end

# Shared worker: adapt the gather map and the DiffCache with possibly
# different adaptors and rebuild the provider.
function _adapt_eager_gbufp(map_adaptor, cache_adaptor, gbp)
    adapted_map = adapt(map_adaptor, gbp.map)
    adapted_cache = _adapt_diffcache(cache_adaptor, gbp.diffcache)
    return EagerGBufProvider(adapted_map, adapted_cache)
end


####
#### Adapt VertexBatch/EdgeBatch
####
# Float-eltype CuArray adaptors drop the eltype and re-dispatch on the untyped
# `CuArray` adaptor, so the integer index vectors keep their element types.
Adapt.adapt_structure(to::Type{<:CuArray{<:AbstractFloat}}, batch::VertexBatch) =
    Adapt.adapt_structure(CuArray, batch)
Adapt.adapt_structure(to::Type{<:CuArray{<:AbstractFloat}}, batch::EdgeBatch) =
    Adapt.adapt_structure(CuArray, batch)

# Generic path: adapt only the index vector, then rebuild the batch with the
# same dispatch type, component function, and stride fields.
function Adapt.adapt_structure(to, batch::VertexBatch)
    indices = adapt(to, batch.indices)
    return VertexBatch{dispatchT(batch), typeof(compf(batch)), typeof(indices)}(
        indices, compf(batch), batch.statestride, batch.pstride, batch.aggbufstride)
end
function Adapt.adapt_structure(to, batch::EdgeBatch)
    indices = adapt(to, batch.indices)
    return EdgeBatch{dispatchT(batch), typeof(compf(batch)), typeof(indices)}(
        indices, compf(batch), batch.statestride, batch.pstride, batch.gbufstride)
end


####
#### utils
####
# define similar to adapt_structure for DiffCache without type piracy
# Adapt the fields of a `DiffCache` one by one and rebuild it. A direct
# `Adapt.adapt_structure` method on `DiffCache` would be type piracy (neither
# the function nor the type is owned by this package), hence this private
# helper, which the Network/EagerGBufProvider adapt methods call explicitly.
function _adapt_diffcache(adaptor, cache::DiffCache)
    primal = adapt(adaptor, cache.du)
    duals = adapt(adaptor, cache.dual_du)
    return DiffCache(primal, duals, cache.any_du)
end

end
4 changes: 2 additions & 2 deletions ext/ModelingToolkitExt.jl → ext/MTKExt.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module ModelingToolkitExt
module MTKExt

using ModelingToolkit: Symbolic, iscall, operation, arguments, build_function
using ModelingToolkit: ModelingToolkit, Equation, ODESystem, Differential
Expand All @@ -11,7 +11,7 @@ using LinearAlgebra: Diagonal, I
using NetworkDynamics: Fiducial
import NetworkDynamics: ODEVertex, StaticEdge

include("ModelingToolkitUtils.jl")
include("MTKUtils.jl")

function ODEVertex(sys::ODESystem, inputs, outputs; verbose=false, name=getname(sys))
warn_events(sys)
Expand Down
File renamed without changes.
8 changes: 3 additions & 5 deletions src/NetworkDynamics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using Graphs: Graphs, AbstractGraph, SimpleEdge, edges, vertices, ne, nv,
using TimerOutputs: @timeit_debug, reset_timer!

using ArgCheck: @argcheck
using PreallocationTools: PreallocationTools, LazyBufferCache, DiffCache
using PreallocationTools: PreallocationTools, DiffCache, get_tmp
using SciMLBase: SciMLBase
using Base.Threads: @threads
using NNlib: NNlib
Expand All @@ -18,15 +18,14 @@ using DocStringExtensions: FIELDS, TYPEDEF
using StyledStrings: StyledStrings, @styled_str
using RecursiveArrayTools: DiffEqArray
using FastClosures: @closure
using ForwardDiff: ForwardDiff

@static if VERSION v"1.11.0-0"
using Base: AnnotatedIOBuffer, AnnotatedString
else
using StyledStrings: AnnotatedIOBuffer, AnnotatedString
end

using Adapt: Adapt, adapt

using Base: @propagate_inbounds
using InteractiveUtils: subtypes

Expand Down Expand Up @@ -54,11 +53,10 @@ include("network_structure.jl")
export NaiveAggregator, KAAggregator, SequentialAggregator,
PolyesterAggregator, ThreadedAggregator
include("aggregators.jl")
include("gbufs.jl")
include("construction.jl")
include("coreloop.jl")

include("adapt.jl")

# XXX: have both, s[:] and uflat(s) ?
export VIndex, EIndex, VPIndex, EPIndex, NWState, NWParameter, uflat, pflat
export vidxs, eidxs, vpidxs, epidxs
Expand Down
25 changes: 0 additions & 25 deletions src/adapt.jl

This file was deleted.

3 changes: 3 additions & 0 deletions src/aggregators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ function aggregate!(a::KAAggregator, aggbuf, data)
# kernel(a.f, aggbuf, view(data, am.range), am.map)
kernel = agg_kernel!(_backend)
kernel(a.f, aggbuf, view(data, am.range), am.map; ndrange=length(am.map))
# TODO: synchronize after both aggregation sweeps?
KernelAbstractions.synchronize(_backend)

if !isempty(am.symrange)
Expand Down Expand Up @@ -268,6 +269,8 @@ function SparseAggregator(f)
SparseAggregator
end
function SparseAggregator(im, batches)
# sparse multiply is faster with Matrix{Float} . Vector{Float} than int!
# (both on GPU and CPU)
I, J, V = Float64[], Float64[], Float64[]
unrolled_foreach(batches) do batch
for eidx in batch.indices
Expand Down
36 changes: 16 additions & 20 deletions src/construction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,15 +102,27 @@ function Network(g::AbstractGraph,
_aggregator = aggregator(im, edgebatches)
end

nl = NetworkLayer(im, edgebatches, _aggregator)
nl = NetworkLayer(im.g, edgebatches, _aggregator, im.edepth, im.vdepth)

@assert isdense(im)
mass_matrix = construct_mass_matrix(im)
nw = Network{typeof(execution),typeof(g),typeof(nl),typeof(vertexbatches),typeof(mass_matrix)}(
N = ForwardDiff.pickchunksize(max(im.lastidx_dynamic, im.lastidx_p))
caches = (;state = DiffCache(zeros(im.lastidx_static), N),
aggregation = DiffCache(zeros(im.lastidx_aggr), N))

gbufprovider = if usebuffer(execution)
EagerGBufProvider(im, edgebatches)
else
LazyGBufProvider(im, edgebatches)
end

nw = Network{typeof(execution),typeof(g),typeof(nl), typeof(vertexbatches),
typeof(mass_matrix),eltype(caches),typeof(gbufprovider)}(
vertexbatches,
nl, im,
LazyBufferCache(),
mass_matrix
caches,
mass_matrix,
gbufprovider
)

end
Expand Down Expand Up @@ -236,22 +248,6 @@ function EdgeBatch(im::IndexManager, idxs::Vector{Int}; verbose)
end
end

function NetworkLayer(im::IndexManager, batches, agg)
map = zeros(Int, ne(im.g) * im.vdepth, 2)
for batch in batches
for i in 1:length(batch)
eidx = batch.indices[i]
e = im.edgevec[eidx]
dst_range = im.v_data[e.src][1:im.vdepth]
src_range = im.v_data[e.dst][1:im.vdepth]
range = gbuf_range(batch, i)
map[range, 1] .= dst_range
map[range, 2] .= src_range
end
end
NetworkLayer(im.g, batches, agg, im.edepth, im.vdepth, map)
end

batch_by_idxs(v, idxs::Vector{Vector{Int}}) = [v for batch in idxs]
function batch_by_idxs(v::AbstractVector, batches::Vector{Vector{Int}})
@assert length(v) == sum(length.(batches))
Expand Down
Loading

0 comments on commit 3858f6d

Please sign in to comment.