Skip to content

Commit 1445449

Browse files
feat(HomotopyContinuation): enable more performant jacobians with Enzyme (#528)
* feat(HomotopyContinuation): enable more performant jacobians with Enzyme * build: bump SciMLBase compat
1 parent 7bb635d commit 1445449

File tree

7 files changed

+322
-146
lines changed

7 files changed

+322
-146
lines changed

Diff for: lib/NonlinearSolveHomotopyContinuation/Project.toml

+4-2
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,22 @@ CommonSolve = "0.2.4"
2323
ConcreteStructs = "0.2.3"
2424
DifferentiationInterface = "0.6.27"
2525
DocStringExtensions = "0.9.3"
26+
Enzyme = "0.13"
2627
HomotopyContinuation = "2.12.0"
2728
LinearAlgebra = "1.10"
2829
NonlinearSolve = "4"
2930
NonlinearSolveBase = "1.3.3"
30-
SciMLBase = "2.71"
31+
SciMLBase = "2.72.2"
3132
SymbolicIndexingInterface = "0.3.36"
3233
TaylorDiff = "0.3.1"
3334
Test = "1.10"
3435
julia = "1.10"
3536

3637
[extras]
3738
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
39+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
3840
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
3941
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4042

4143
[targets]
42-
test = ["Aqua", "Test", "NonlinearSolve"]
44+
test = ["Aqua", "Test", "NonlinearSolve", "Enzyme"]

Diff for: lib/NonlinearSolveHomotopyContinuation/src/NonlinearSolveHomotopyContinuation.jl

+5
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,12 @@ end
6060

6161
HomotopyContinuationJL(; kwargs...) = HomotopyContinuationJL{false}(; kwargs...)
6262

63+
function HomotopyContinuationJL(alg::HomotopyContinuationJL{R}; kwargs...) where {R}
64+
HomotopyContinuationJL{R}(; autodiff = alg.autodiff, alg.kwargs..., kwargs...)
65+
end
66+
6367
include("interface_types.jl")
68+
include("jacobian_handling.jl")
6469
include("solve.jl")
6570

6671
end

Diff for: lib/NonlinearSolveHomotopyContinuation/src/interface_types.jl

+9-103
Original file line numberDiff line numberDiff line change
@@ -4,47 +4,6 @@ struct Inplace <: HomotopySystemVariant end
44
struct OutOfPlace <: HomotopySystemVariant end
55
struct Scalar <: HomotopySystemVariant end
66

7-
"""
8-
$(TYPEDEF)
9-
10-
A simple struct that wraps a polynomial function which takes complex input and returns
11-
complex output in a form that supports automatic differentiation. If the wrapped
12-
function if ``f: \\mathbb{C}^n \\rightarrow \\mathbb{C}^n`` then it is assumed
13-
the input arrays are real-valued and have length ``2n``. They are `reinterpret`ed
14-
into complex arrays and passed into the function. This struct has an in-place signature
15-
regardless of the signature of ``f``.
16-
"""
17-
@concrete struct ComplexJacobianWrapper{variant <: HomotopySystemVariant}
18-
f
19-
end
20-
21-
function (cjw::ComplexJacobianWrapper{Inplace})(
22-
u::AbstractVector{T}, x::AbstractVector{T}, p) where {T}
23-
x = reinterpret(Complex{T}, x)
24-
u = reinterpret(Complex{T}, u)
25-
cjw.f(u, x, p)
26-
u = parent(u)
27-
return u
28-
end
29-
30-
function (cjw::ComplexJacobianWrapper{OutOfPlace})(
31-
u::AbstractVector{T}, x::AbstractVector{T}, p) where {T}
32-
x = reinterpret(Complex{T}, x)
33-
u_tmp = cjw.f(x, p)
34-
u_tmp = reinterpret(T, u_tmp)
35-
copyto!(u, u_tmp)
36-
return u
37-
end
38-
39-
function (cjw::ComplexJacobianWrapper{Scalar})(
40-
u::AbstractVector{T}, x::AbstractVector{T}, p) where {T}
41-
x = reinterpret(Complex{T}, x)
42-
u_tmp = cjw.f(x[1], p)
43-
u[1] = real(u_tmp)
44-
u[2] = imag(u_tmp)
45-
return u
46-
end
47-
487
"""
498
$(TYPEDEF)
509
@@ -62,34 +21,24 @@ $(FIELDS)
6221
"""
6322
f
6423
"""
65-
The jacobian function, if provided to the `NonlinearProblem` being solved. Otherwise,
66-
a `ComplexJacobianWrapper` wrapping `f` used for automatic differentiation.
24+
A function which calculates both the polynomial and the jacobian. Must be a function
25+
of the form `f(u, U, x, p)` where `x` is the current unknowns and `p` is the parameter
26+
object, writing the value of the polynomial to `u` and the jacobian to `U`. Must be able
27+
to handle complex `x`.
6728
"""
6829
jac
6930
"""
7031
The parameter object.
7132
"""
7233
p
7334
"""
74-
The ADType for automatic differentiation.
75-
"""
76-
autodiff
77-
"""
78-
The result from `DifferentiationInterface.prepare_jacobian`.
79-
"""
80-
prep
81-
"""
8235
HomotopyContinuation.jl's symbolic variables for the system.
8336
"""
8437
vars
8538
"""
8639
The `TaylorDiff.TaylorScalar` objects used to compute the taylor series of `f`.
8740
"""
8841
taylorvars
89-
"""
90-
Preallocated intermediate buffers used for calculating the jacobian.
91-
"""
92-
jacobian_buffers
9342
end
9443

9544
Base.size(sys::HomotopySystemWrapper) = (length(sys.vars), length(sys.vars))
@@ -112,54 +61,11 @@ function HC.ModelKit.evaluate!(u, sys::HomotopySystemWrapper{Scalar}, x, p = not
11261
end
11362

11463
function HC.ModelKit.evaluate_and_jacobian!(
115-
u, U, sys::HomotopySystemWrapper{Inplace}, x, p = nothing)
116-
p = sys.p
117-
sys.f(u, x, p)
118-
sys.jac(U, x, p)
119-
return u, U
120-
end
121-
122-
function HC.ModelKit.evaluate_and_jacobian!(
123-
u, U, sys::HomotopySystemWrapper{OutOfPlace}, x, p = nothing)
124-
p = sys.p
125-
u_tmp = sys.f(x, p)
126-
copyto!(u, u_tmp)
127-
j_tmp = sys.jac(x, p)
128-
copyto!(U, j_tmp)
64+
u, U, sys::HomotopySystemWrapper, x, p = nothing)
65+
sys.jac(u, U, x, sys.p)
12966
return u, U
13067
end
13168

132-
function HC.ModelKit.evaluate_and_jacobian!(
133-
u, U, sys::HomotopySystemWrapper{Scalar}, x, p = nothing)
134-
p = sys.p
135-
u[1] = sys.f(x[1], p)
136-
U[1] = sys.jac(x[1], p)
137-
return u, U
138-
end
139-
140-
for V in (Inplace, OutOfPlace, Scalar)
141-
@eval function HC.ModelKit.evaluate_and_jacobian!(
142-
u, U, sys::HomotopySystemWrapper{$V, F, J}, x,
143-
p = nothing) where {F, J <: ComplexJacobianWrapper}
144-
p = sys.p
145-
U_tmp = sys.jacobian_buffers
146-
x = reinterpret(Float64, x)
147-
u = reinterpret(Float64, u)
148-
DI.value_and_jacobian!(sys.jac, u, U_tmp, sys.prep, sys.autodiff, x, DI.Constant(p))
149-
U = reinterpret(Float64, U)
150-
@inbounds for j in axes(U, 2)
151-
jj = 2j - 1
152-
for i in axes(U, 1)
153-
U[i, j] = U_tmp[i, jj]
154-
end
155-
end
156-
u = parent(u)
157-
U = parent(U)
158-
159-
return u, U
160-
end
161-
end
162-
16369
function update_taylorvars_from_taylorvector!(
16470
vars, x::HC.ModelKit.TaylorVector)
16571
for i in eachindex(x)
@@ -185,14 +91,14 @@ end
18591

18692
function check_taylor_equality(vars, x::HC.ModelKit.TaylorVector)
18793
for i in eachindex(x)
188-
TaylorDiff.flatten(vars[2i-1]) == map(real, x[i]) || return false
94+
TaylorDiff.flatten(vars[2i - 1]) == map(real, x[i]) || return false
18995
TaylorDiff.flatten(vars[2i]) == map(imag, x[i]) || return false
19096
end
19197
return true
19298
end
19399
function check_taylor_equality(vars, x::AbstractVector)
194100
for i in eachindex(x)
195-
TaylorDiff.value(vars[2i-1]) != real(x[i]) && return false
101+
TaylorDiff.value(vars[2i - 1]) != real(x[i]) && return false
196102
TaylorDiff.value(vars[2i]) != imag(x[i]) && return false
197103
end
198104
return true
@@ -212,7 +118,7 @@ function update_maybe_taylorvector_from_taylorvars!(
212118
for i in eachindex(vars)
213119
rval = TaylorDiff.flatten(real(buffer[i]))
214120
ival = TaylorDiff.flatten(imag(buffer[i]))
215-
u[i] = ntuple(i -> rval[i] + im * ival[i], Val(length(rval)))
121+
u[i] = ntuple(i -> rval[i] + im * ival[i], Val(length(rval)))
216122
end
217123
end
218124

0 commit comments

Comments
 (0)