-
-
Notifications
You must be signed in to change notification settings - Fork 66
Expand file tree
/
Copy pathcuda_tests.jl
More file actions
74 lines (62 loc) · 2.53 KB
/
cuda_tests.jl
File metadata and controls
74 lines (62 loc) · 2.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# Smoke-test every solver in `SOLVERS` on a small GPU-resident linear root-finding
# problem, A * u + b = 0, in both in-place (IIP) and out-of-place (OOP) form.
# Each solve only has to finish without emitting any output on stderr
# (`@test_nowarn`); convergence of the result is not asserted here.
@testitem "CUDA Tests" tags=[:cuda] begin
    using CUDA, NonlinearSolve, LinearSolve, StableRNGs

    # The whole item is a no-op when CUDA reports no usable device/driver.
    if CUDA.functional()
        # Turn any accidental scalar indexing of a GPU array into an error.
        CUDA.allowscalar(false)

        # Draw every array from ONE seeded stream. Re-seeding a fresh
        # `StableRNG(0)` per array (the previous approach) made `u0` and `b`
        # bit-identical, so the test ran on correlated data.
        rng = StableRNG(0)
        A = cu(rand(rng, 4, 4))
        u0 = cu(rand(rng, 4))
        b = cu(rand(rng, 4))

        # Residual A * u + b. Distinct names for the mutating (IIP) and
        # allocating (OOP) forms keep IIP auto-detection of `iip_prob` from
        # depending on which methods happen to be defined at that moment.
        linear_f!(du, u, p) = (du .= A * u .+ b)
        linear_f(u, p) = A * u .+ b

        SOLVERS = (
            NewtonRaphson(),
            LevenbergMarquardt(; linsolve = QRFactorization()),
            LevenbergMarquardt(; linsolve = KrylovJL_GMRES()),
            PseudoTransient(),
            Klement(),
            Broyden(; linesearch = LiFukushimaLineSearch()),
            LimitedMemoryBroyden(; threshold = 2, linesearch = LiFukushimaLineSearch()),
            DFSane(),
            TrustRegion(; linsolve = QRFactorization()),
            # concrete_jac: Needed if Zygote not loaded
            TrustRegion(; linsolve = KrylovJL_GMRES(), concrete_jac = true),
            # `nothing` lets `solve` pick its default algorithm (SciML convention).
            nothing
        )

        # In-placeness is auto-detected from the 3-argument signature.
        iip_prob = NonlinearProblem(linear_f!, u0)
        @testset "[IIP] GPU Solvers" begin
            @testset "$(nameof(typeof(alg)))" for alg in SOLVERS
                @test_nowarn solve(iip_prob, alg; abstol = 1.0f-5, reltol = 1.0f-5)
            end
        end

        # Out-of-place form is requested explicitly via the `{false}` parameter.
        oop_prob = NonlinearProblem{false}(linear_f, u0)
        @testset "[OOP] GPU Solvers" begin
            @testset "$(nameof(typeof(alg)))" for alg in SOLVERS
                @test_nowarn solve(oop_prob, alg; abstol = 1.0f-5, reltol = 1.0f-5)
            end
        end
    end
end
# Check that `NonlinearSolveBase.check_convergence` runs on GPU arrays without
# emitting any output on stderr, for every termination mode. Norm-based modes are
# additionally exercised with several norm functions.
# NOTE(review): despite the item's name, nothing below measures allocations; it
# only asserts `@test_nowarn`.
@testitem "Termination Conditions: Allocations" tags=[:cuda] begin
    using CUDA, NonlinearSolveBase, Test, LinearAlgebra, StableRNGs

    # Same guard as the GPU solver tests: skip cleanly when no usable device exists.
    if CUDA.functional()
        # Turn any accidental scalar indexing of a GPU array into an error.
        CUDA.allowscalar(false)

        # Seeded, single-stream data so a failure can be reproduced exactly.
        rng = StableRNG(0)
        du = cu(rand(rng, 4))
        u = cu(rand(rng, 4))
        uprev = cu(rand(rng, 4))

        # Modes constructed with no arguments.
        TERMINATION_CONDITIONS = [
            RelTerminationMode, AbsTerminationMode
        ]
        # Modes constructed with a norm function as their single argument.
        NORM_TERMINATION_CONDITIONS = [
            AbsNormTerminationMode, RelNormTerminationMode, RelNormSafeTerminationMode,
            AbsNormSafeTerminationMode, RelNormSafeBestTerminationMode,
            AbsNormSafeBestTerminationMode
        ]
        NORM_FUNCTIONS = (Base.Fix1(maximum, abs), Base.Fix2(norm, 2), Base.Fix2(norm, Inf))

        @testset "check_convergence on GPU arrays" begin
            # The trailing `1e-3, 1e-3` are presumably abstol/reltol — confirm
            # against the `check_convergence` signature.
            @testset "Mode: $(tcond)" for tcond in TERMINATION_CONDITIONS
                @test_nowarn NonlinearSolveBase.check_convergence(
                    tcond(), du, u, uprev, 1e-3, 1e-3)
            end
            @testset "Mode: $(tcond)" for tcond in NORM_TERMINATION_CONDITIONS
                for nfn in NORM_FUNCTIONS
                    @test_nowarn NonlinearSolveBase.check_convergence(
                        tcond(nfn), du, u, uprev, 1e-3, 1e-3)
                end
            end
        end
    end
end