
Commit 1d2bf0d

Merge pull request #5 from LuxDL/ap/luxdeviceutils

Use centralized device management repo

avik-pal authored Jun 26, 2023
2 parents c66999b + 9d2e251
Showing 4 changed files with 19 additions and 99 deletions.
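In practice, "centralized device management" means the hand-rolled CPU/CUDA/AMDGPU adaptors vendored into src/LuxTestUtils.jl (deleted below) give way to device objects from LuxDeviceUtils. A minimal sketch of the API this commit adopts, assuming LuxDeviceUtils 0.1 as pinned in the new compat; the NamedTuple payload is illustrative only:

    using LuxDeviceUtils

    cpu_dev = cpu_device()  # callable device object; always available
    gpu_dev = gpu_device()  # first functional GPU backend, CPU fallback otherwise

    x = (; weight = rand(Float32, 3, 4), bias = zeros(Float32, 3))
    x_gpu = gpu_dev(x)      # device objects recurse through nested structures
    x_cpu = cpu_dev(x_gpu)  # and round-trip the leaves back to plain Arrays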
1 change: 0 additions & 1 deletion .github/workflows/CI.yml
@@ -20,7 +20,6 @@ jobs:
        version:
          - "1"
          - "1.6"
-         - "~1.9.0-0"
    steps:
      - uses: actions/checkout@v3
      - uses: julia-actions/setup-julia@v1
1 change: 1 addition & 0 deletions .gitignore
@@ -7,3 +7,4 @@
/docs/Manifest.toml
/test/coverage/Manifest.toml
LocalPreferences.toml
+.vscode
14 changes: 5 additions & 9 deletions Project.toml
@@ -1,17 +1,16 @@
name = "LuxTestUtils"
uuid = "ac9de150-d08f-4546-94fb-7472b5760531"
authors = ["Avik Pal <[email protected]>"]
version = "0.1.9"
version = "0.1.10"

[deps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -20,23 +19,20 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[compat]
AMDGPU = "0.4"
Adapt = "3"
CUDA = "4"
ComponentArrays = "0.13"
FiniteDifferences = "0.12"
ForwardDiff = "0.10"
Functors = "0.4"
JET = "0.4, 0.5, 0.6, 0.7, 0.8"
LuxCore = "0.1"
LuxDeviceUtils = "0.1"
Optimisers = "0.2"
Preferences = "1"
ReverseDiff = "1"
Tracker = "0.2"
Zygote = "0.6"
cuDNN = "1"
julia = "1.6"

[extras]
102 changes: 13 additions & 89 deletions src/LuxTestUtils.jl
@@ -1,92 +1,11 @@
module LuxTestUtils

-using ComponentArrays, Optimisers, Preferences, Test
+using ComponentArrays, Optimisers, Preferences, LuxCore, LuxDeviceUtils, Test
using ForwardDiff, ReverseDiff, Tracker, Zygote, FiniteDifferences
# TODO: Yota, Enzyme

const JET_TARGET_MODULES = @load_preference("target_modules", nothing)

-### Device Functionalities: REMOVE once moved out of Lux into a separate package
-using Adapt, AMDGPU, CUDA, cuDNN, Functors, Random, SparseArrays
-import Adapt: adapt_storage
-
-const use_cuda = Ref{Union{Nothing, Bool}}(nothing)
-const use_amdgpu = Ref{Union{Nothing, Bool}}(nothing)
-
-abstract type LuxTestUtilsDeviceAdaptor end
-
-struct LuxTestUtilsCPUAdaptor <: LuxTestUtilsDeviceAdaptor end
-struct LuxTestUtilsCUDAAdaptor <: LuxTestUtilsDeviceAdaptor end
-struct LuxTestUtilsAMDGPUAdaptor <: LuxTestUtilsDeviceAdaptor end
-
-adapt_storage(::LuxTestUtilsCUDAAdaptor, x) = cu(x)
-adapt_storage(::LuxTestUtilsCUDAAdaptor, rng::AbstractRNG) = rng
-
-adapt_storage(::LuxTestUtilsAMDGPUAdaptor, x) = roc(x)
-adapt_storage(::LuxTestUtilsAMDGPUAdaptor, rng::AbstractRNG) = rng
-
-function adapt_storage(::LuxTestUtilsCPUAdaptor,
-        x::Union{AbstractRange, SparseArrays.AbstractSparseArray})
-    return x
-end
-adapt_storage(::LuxTestUtilsCPUAdaptor, x::AbstractArray) = adapt(Array, x)
-adapt_storage(::LuxTestUtilsCPUAdaptor, rng::AbstractRNG) = rng
-function adapt_storage(::LuxTestUtilsCPUAdaptor, x::CUSPARSE.AbstractCuSparseMatrix)
-    return adapt(Array, x)
-end
-
-_isbitsarray(::AbstractArray{<:Number}) = true
-_isbitsarray(::AbstractArray{T}) where {T} = isbitstype(T)
-_isbitsarray(x) = false
-
-_isleaf(::AbstractRNG) = true
-_isleaf(x) = _isbitsarray(x) || Functors.isleaf(x)
-
-cpu(x) = fmap(x -> adapt(LuxTestUtilsCPUAdaptor(), x), x)
-
-function cuda_gpu(x)
-    check_use_cuda()
-    return use_cuda[] ? fmap(x -> adapt(LuxTestUtilsCUDAAdaptor(), x), x; exclude=_isleaf) :
-           x
-end
-
-function amdgpu_gpu(x)
-    check_use_amdgpu()
-    return use_amdgpu[] ?
-           fmap(x -> adapt(LuxTestUtilsAMDGPUAdaptor(), x), x; exclude=_isleaf) : x
-end
-
-function check_use_cuda()
-    if use_cuda[] === nothing
-        use_cuda[] = CUDA.functional()
-        if use_cuda[] && !cuDNN.has_cudnn()
-            @warn """CUDA.jl found cuda, but did not find libcudnn. Some functionality
-                     will not be available."""
-        end
-        if !(use_cuda[])
-            @info """The GPU function is being called but the GPU is not accessible.
-                     Defaulting back to the CPU. (No action is required if you want
-                     to run on the CPU).""" maxlog=1
-        end
-    end
-end
-
-function check_use_amdgpu()
-    if use_amdgpu[] === nothing
-        use_amdgpu[] = AMDGPU.functional()
-        if use_amdgpu[] && !AMDGPU.functional(:MIOpen)
-            @warn "MIOpen is not functional in AMDGPU.jl, some functionality will not be \
-                   available." maxlog=1
-        end
-        if !(use_amdgpu[])
-            @info """The GPU function is being called but the GPU is not accessible.
-                     Defaulting back to the CPU. (No action is required if you want
-                     to run on the CPU).""" maxlog=1
-        end
-    end
-end
-### REMOVE once moved out of Lux into a separate package

# JET Testing
try
using JET
@@ -182,14 +101,20 @@ end
struct GradientComputationSkipped end

@generated function check_approx(x::X, y::Y; kwargs...) where {X, Y}
+    device = cpu_device()
    (X == GradientComputationSkipped || Y == GradientComputationSkipped) && return :(true)
-    hasmethod(isapprox, (X, Y)) && return :(isapprox(cpu(x), cpu(y); kwargs...))
+    hasmethod(isapprox, (X, Y)) && return :(isapprox($(device)(x), $(device)(y); kwargs...))
    return quote
        @warn "No `isapprox` method found for types $(X) and $(Y). Using `==` instead."
-        return cpu(x) == cpu(y)
+        return $(device)(x) == $(device)(y)
    end
end

+function check_approx(x::LuxCore.AbstractExplicitLayer,
+        y::LuxCore.AbstractExplicitLayer;
+        kwargs...)
+    return x == y
+end
check_approx(x::Tuple, y::Tuple; kwargs...) = all(check_approx.(x, y; kwargs...))

function check_approx(x::Optimisers.Leaf, y::Optimisers.Leaf; kwargs...)
@@ -474,14 +399,13 @@ __test_broken(test_type, orig_expr, source) = Test.Broken(test_type, orig_expr)

__correct_arguments(x::AbstractArray) = x
function __correct_arguments(x::NamedTuple)
-    xc = cpu(x)
+    cpu_dev = cpu_device()
+    gpu_dev = gpu_device()
+    xc = cpu_dev(x)
    ca = ComponentArray(xc)
    # Hacky check to see if there are any non-CPU arrays in the NamedTuple
    typeof(xc) == typeof(x) && return ca

-    ca_cuda = cuda_gpu(ca)
-    typeof(ca_cuda) == typeof(x) && return ca_cuda
-    return amdgpu_gpu(ca)
+    return gpu_dev(ca)
end
__correct_arguments(x) = x
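The __correct_arguments hunk above now round-trips a NamedTuple through the CPU: flatten it into a ComponentArray there, then move the result back only if the input held non-CPU leaves (the "hacky check" in the comment). A sketch of that logic in isolation, assuming a functional GPU; on a CPU-only machine gpu_device() hands back a CPU device and the moves become no-ops:

    using ComponentArrays, LuxDeviceUtils

    cpu_dev = cpu_device()
    gpu_dev = gpu_device()

    nt = gpu_dev((; weight = rand(Float32, 4, 4), bias = zeros(Float32, 4)))

    xc = cpu_dev(nt)         # bring every leaf back to a plain Array
    ca = ComponentArray(xc)  # flatten into one vector-backed structure
    # If nothing moved, the input was already on the CPU; otherwise
    # restore the device the input came from:
    result = typeof(xc) == typeof(nt) ? ca : gpu_dev(ca)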

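Earlier in this file's diff, check_approx also gains a method for LuxCore.AbstractExplicitLayer that compares layers with == instead of isapprox. That is the right notion here because Lux layers are stateless configuration structs; their parameters live in a separate NamedTuple. A toy illustration, where the MyDense type is invented for this sketch:

    using LuxCore, LuxTestUtils

    # Hypothetical stateless layer; real Lux layers follow the same pattern
    # (parameters are held outside the struct).
    struct MyDense <: LuxCore.AbstractExplicitLayer
        in_dims::Int
        out_dims::Int
    end

    LuxTestUtils.check_approx(MyDense(2, 3), MyDense(2, 3))  # true: same structure
    LuxTestUtils.check_approx(MyDense(2, 3), MyDense(2, 4))  # false: sizes differ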

2 comments on commit 1d2bf0d

@avik-pal (Member, Author):

@JuliaRegistrator register()

@JuliaRegistrator:

Registration pull request created: JuliaRegistries/General/86316

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or it can be done manually through the GitHub interface or via:

git tag -a v0.1.10 -m "<description of version>" 1d2bf0df6b663de8694351bcf87b5e069735cb72
git push origin v0.1.10
