
Commit b50c226

Metal Extension with fix for tanh_fast (#666)
* metal ext
* tests
* add test
* don't test on nightly
1 parent f880a87 commit b50c226

File tree

7 files changed (+122 / -7 lines)


.buildkite/pipeline.yml

Lines changed: 39 additions & 6 deletions
@@ -19,8 +19,7 @@ steps:
     env:
       JULIA_NUM_THREADS: 4
       NNLIB_TEST_CUDA: "true"
-      NNLIB_TEST_CPU: "true" # Could be useful to uncover multithreading related issues
-                             # Buildkite workers have more threads.
+      NNLIB_TEST_CPU: "false"
     if: build.message !~ /\[skip tests\]/
     timeout_in_minutes: 180
     matrix:
@@ -34,6 +33,7 @@ steps:
           julia: "nightly"
           soft_fail: true
 
+
   - label: ":julia: Julia {{matrix.julia}} - AMD GPU"
     command:
       - echo 'AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"' >> test/Project.toml
@@ -65,10 +65,43 @@ steps:
           # - "1.10"
           - "1"
           # - "nightly"
-    adjustments:
-      - with:
-          julia: "nightly"
-          soft_fail: true
+    # adjustments:
+    #   - with:
+    #       julia: "nightly"
+    #       soft_fail: true
+
+
+  - label: ":julia: Julia {{matrix.julia}} - Metal GPU"
+    command:
+      - echo 'Metal = "dde4c033-4e86-420c-a63e-0dd931031962"' >> test/Project.toml
+    plugins:
+      - JuliaCI/julia#v1:
+          version: "{{matrix.julia}}"
+      - JuliaCI/julia-test#v1:
+          test_args: "--quickfail"
+      - JuliaCI/julia-coverage#v1:
+          codecov: true
+          dirs:
+            - src
+            - ext
+    agents:
+      os: "macos"
+      arch: "aarch64"
+    timeout_in_minutes: 180
+    env:
+      NNLIB_TEST_METAL: "true"
+      NNLIB_TEST_CPU: "false"
+      JULIA_NUM_THREADS: 4
+    matrix:
+      setup:
+        julia:
+          # - "1.10"
+          - "1"
+          # - "nightly"
+    # adjustments:
+    #   - with:
+    #       julia: "nightly"
+    #       soft_fail: true
 
 
   - label: "Benchmarks"
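Note: the new Metal GPU job mirrors the CUDA and AMDGPU jobs, appending Metal to the test project, running on macOS/aarch64 agents, and opting in via NNLIB_TEST_METAL. A rough local equivalent of what the job runs, as a sketch only (assumes an Apple-silicon Mac; the exact CI invocation goes through the Buildkite plugins):

    # Approximate local equivalent of the Metal CI job (sketch, not the exact CI command)
    using Pkg
    ENV["NNLIB_TEST_METAL"] = "true"   # opt in to the Metal suite in test/runtests.jl
    ENV["NNLIB_TEST_CPU"] = "false"    # skip CPU tests, matching the job's env block
    Pkg.test("NNlib"; test_args=["--quickfail"])  # forwards --quickfail as a test argument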

Project.toml

Lines changed: 4 additions & 1 deletion
@@ -19,6 +19,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
 FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
 SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
 
@@ -29,6 +30,7 @@ NNlibCUDAExt = "CUDA"
 NNlibEnzymeCoreExt = "EnzymeCore"
 NNlibFFTWExt = "FFTW"
 NNlibForwardDiffExt = "ForwardDiff"
+NNlibMetalExt = "Metal"
 NNlibSpecialFunctionsExt = "SpecialFunctions"
 
 [compat]
@@ -43,9 +45,10 @@ ForwardDiff = "1"
 GPUArraysCore = "0.2"
 KernelAbstractions = "0.9.2"
 LinearAlgebra = "1"
+Metal = "1.6"
 Random = "1"
 ScopedValues = "1.3.0"
 SpecialFunctions = "2"
 Statistics = "1"
 cuDNN = "1"
-julia = "1.10"
+julia = "1.10"
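Note: Metal is wired in as a weak dependency, so the NNlibMetalExt extension only loads when the user's environment also loads Metal.jl. A minimal check of that behavior (sketch, using the Julia >= 1.9 package-extension mechanism):

    # Sketch: package extensions load lazily once the weak dependency is present
    using NNlib
    Base.get_extension(NNlib, :NNlibMetalExt)  # -> nothing; Metal not loaded yet

    using Metal                                # loading Metal triggers NNlibMetalExt
    Base.get_extension(NNlib, :NNlibMetalExt)  # -> the extension module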

ext/NNlibMetalExt.jl

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+module NNlibMetalExt
+
+
+using Metal: method_table, @device_override
+using NNlib: NNlib
+
+@device_override NNlib.tanh_fast(x) = Base.FastMath.tanh_fast(x)
+
+end
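Note: this one-liner is the whole tanh_fast fix. Metal.jl's @device_override registers a method in its GPU method_table, so code compiled for the device dispatches NNlib.tanh_fast to Base.FastMath.tanh_fast, which Metal knows how to lower; the CPU method is untouched. A minimal usage sketch, assuming a functional Metal device:

    # Sketch: broadcasting tanh_fast over an MtlArray now compiles, thanks to the override
    using NNlib, Metal            # loading both activates NNlibMetalExt

    x = MtlArray(randn(Float32, 1024))
    y = NNlib.tanh_fast.(x)       # GPU kernel uses Base.FastMath.tanh_fast
    y_cpu = tanh_fast.(collect(x))  # CPU path is unchanged
    collect(y) ≈ y_cpu              # results agree up to floating-point tolerance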

test/Project.toml

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
 KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
+MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
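Note: MLDataDevices is added so the tests can pick the GPU backend generically via gpu_device rather than hard-coding MtlArray. A small sketch of how that helper behaves (assumes Metal is loaded and functional):

    # Sketch: gpu_device resolves to whichever functional GPU backend is loaded
    using Metal
    using MLDataDevices: gpu_device

    dev = gpu_device(force=true)   # errors if no functional GPU backend is found
    x = dev(rand(Float32, 4))      # the device object is callable and moves data to the GPU
    typeof(x)                      # expected: a Metal array type on Apple hardware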

test/ext_metal/activations.jl

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+@testset "activation broadcast" begin
+    broken_f = (:hardσ, :leakyrelu)
+    for name in NNlib.ACTIVATIONS
+        # println("Testing forward diff for activation: ", name)
+        f = @eval $name
+        @test gputest(DEVICE, x -> f.(x), rand(5)) broken = name ∈ broken_f
+    end
+end
+
+@testset "forward diff" begin
+    broken_f = (:hardσ, :leakyrelu)
+    for name in NNlib.ACTIVATIONS
+        # println("Testing forward diff for activation: ", name)
+        f = @eval $name
+        @test gputest(DEVICE, x -> f.(x), Dual.(rand(Float32, 5), 1)) broken = name ∈ broken_f
+    end
+end
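Note: two details worth calling out. `@eval $name` turns each Symbol in NNlib.ACTIVATIONS into the corresponding function, and `broken = name ∈ broken_f` uses Test's `broken` keyword so the known Metal failures (hardσ and leakyrelu) are recorded as Broken rather than failing CI. In isolation, on CPU for clarity (a sketch):

    using NNlib, Test

    name = :relu
    f = @eval $name                   # resolve the activation function from its Symbol
    @test f(-1f0) == 0f0              # ordinary assertion
    @test f(2f0) == 0f0 broken=true   # false, but recorded as Broken rather than Fail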

test/ext_metal/runtests.jl

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+using NNlib
+using Test
+using Metal
+using Zygote: gradient
+using MLDataDevices: gpu_device
+using ForwardDiff: Dual
+
+Metal.allowscalar(false)
+
+# TODO move this to test/test_utils.jl and use it with all backends
+function gputest(device, f, xs...; checkgrad=true, atol=1e-6, kws...)
+    cpu_in = xs
+    gpu_in = device(xs)
+
+    cpu_out = f(cpu_in...; kws...)
+    gpu_out = f(gpu_in...; kws...)
+    @test collect(cpu_out) ≈ collect(gpu_out)
+
+    if checkgrad
+        cpu_grad = gradient((x...) -> sum(f(x...; kws...)), cpu_in...)
+        gpu_grad = gradient((x...) -> sum(f(x...; kws...)), gpu_in...)
+        for (cpu_g, gpu_g) in zip(cpu_grad, gpu_grad)
+            if cpu_g === nothing
+                @test gpu_g === nothing
+            else
+                @test collect(cpu_g) ≈ collect(gpu_g) atol=atol
+            end
+        end
+    end
+    return true
+end
+
+DEVICE = gpu_device(force=true)
+
+include("activations.jl")
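Note: gputest runs f on both CPU and device inputs, compares the outputs, and (unless checkgrad=false) also compares Zygote gradients elementwise, treating `nothing` gradients as equal on both sides. Hypothetical direct calls, assuming the definitions above have been loaded and DEVICE resolved to a Metal device:

    # Sketch: exercising the helper directly
    gputest(DEVICE, x -> relu.(x), rand(Float32, 16))    # forward + gradient check
    gputest(DEVICE, (a, b) -> a .* b, rand(Float32, 4), rand(Float32, 4);
            checkgrad=false)                             # forward-only comparison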

test/runtests.jl

Lines changed: 17 additions & 0 deletions
@@ -24,6 +24,7 @@ DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib, UnicodePlots); recursiv
 
 # ENV["NNLIB_TEST_CUDA"] = "true" # uncomment to run CUDA tests
 # ENV["NNLIB_TEST_AMDGPU"] = "true" # uncomment to run AMDGPU tests
+# ENV["NNLIB_TEST_METAL"] = "true" # uncomment to run Metal tests
 # ENV["NNLIB_TEST_CPU"] = "false" # uncomment to skip CPU tests
 
 const rng = StableRNG(123)
@@ -184,4 +185,20 @@ end
 else
     @info "Skipping AMDGPU tests, set NNLIB_TEST_AMDGPU=true to run them."
 end
+
+if get(ENV, "NNLIB_TEST_METAL", "false") == "true"
+    Pkg.add("Metal")
+
+    using Metal
+    if Metal.functional()
+        @testset "Metal" begin
+            # nnlib_testsuite(MetalBackend)
+            include("ext_metal/runtests.jl")
+        end
+    else
+        @info "Insufficient version or Metal not found; Skipping Metal tests"
+    end
+else
+    @info "Skipping Metal tests, set NNLIB_TEST_METAL=true to run them"
+end
 end
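Note: the Metal suite is doubly gated. It runs only when NNLIB_TEST_METAL=true, and even then only if Metal.functional() reports a usable device, so CI on non-Mac agents is unaffected. A quick REPL check of that gate (sketch; output assumes supported Apple hardware):

    julia> using Metal

    julia> Metal.functional()   # the gate used above; true only with a usable Apple GPU
    true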
