Merge pull request #29 from SciML/RoA_approximation
Added support for smooth loss functions to `RoAAwareDecreaseCondition`
nicholaskl97 authored Jun 14, 2024
2 parents ab3f142 + a963c33 commit 24d9409
Showing 3 changed files with 61 additions and 5 deletions.
45 changes: 40 additions & 5 deletions src/decrease_conditions_RoA_aware.jl
@@ -14,13 +14,15 @@ attraction estimate of ``\\{ x : V(x) ≤ ρ \\}``.
`check_decrease == true`.
- `rectifier::Function`: positive when the input is positive and (approximately) zero when
the input is negative.
- `sigmoid::Function`: approximately one when the input is positive and approximately zero
when the input is negative.
- `ρ`: the level of the sublevel set forming the estimate of the region of attraction.
- `out_of_RoA_penalty::Function`: a loss function applied to penalize points outside
the sublevel set forming the region of attraction estimate.
# Training conditions
If `check_decrease == true`, training will enforce
If `check_decrease == true`, training will attempt to enforce
``\\texttt{rate\\_metric}(V(x), V̇(x)) ≤ - \\texttt{strength}(x, x_0)``
@@ -37,6 +39,15 @@ The inequality will be approximated by the equation
Note that the approximate equation and inequality are identical when
``\\texttt{rectifier}(t) = \\max(0, t)``.
The `sigmoid` function allows for a smooth transition between the ``V(x) ≤ ρ`` case and the
``V(x) > ρ`` case, by combining the above equations into one:
``\\texttt{sigmoid}(ρ - V(x)) (\\text{in-RoA expression}) + \\texttt{sigmoid}(V(x) - ρ) (\\text{out-of-RoA expression}) = 0``.
Note that a hard transition, which only enforces the in-RoA equation when ``V(x) ≤ ρ`` and
the out-of-RoA equation when ``V(x) > ρ``, can be provided by a `sigmoid` which is exactly
one when the input is nonnegative and exactly zero when the input is negative.
If the dynamics truly have a fixed point at ``x_0`` and ``V̇(x)`` is truly the rate of
decrease of ``V(x)`` along the dynamics, then ``V̇(x_0)`` will be ``0`` and there is no need
to train for ``V̇(x_0) = 0``.
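To illustrate the blending described above, here is a minimal sketch (all names here are hypothetical; the actual construction appears in `get_decrease_condition` below):

```julia
# Sketch of the sigmoid blending described above (hypothetical names; the
# actual construction appears in `get_decrease_condition` below).
σ_smooth(t) = 1 / (1 + exp(-t))  # logistic: smooth transition at t = 0
σ_hard(t) = t ≥ zero(t)          # exactly 1 for t ≥ 0, exactly 0 otherwise

# Combine an in-RoA loss and an out-of-RoA loss for a sample with value V:
blend(σ, V, ρ, in_roa, out_roa) = σ(ρ - V) * in_roa + σ(V - ρ) * out_roa

blend(σ_hard, 0.5, 1.0, 0.2, 0.9)    # V ≤ ρ, so only the in-RoA loss: 0.2
blend(σ_smooth, 0.5, 1.0, 0.2, 0.9)  # ≈ 0.46, a smooth mix of both losses
```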
@@ -72,12 +83,15 @@ well as ``ρ``.
In any of these cases, the rectified linear unit `rectifier = (t) -> max(zero(t), t)`
exactly represents the inequality, but differentiable approximations of this function may be
employed.
See also: [`LyapunovDecreaseCondition`](@ref)
"""
struct RoAAwareDecreaseCondition <: AbstractLyapunovDecreaseCondition
check_decrease::Bool
rate_metric::Function
strength::Function
rectifier::Function
sigmoid::Function
ρ::Real
out_of_RoA_penalty::Function
end
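The docstring above leaves the choice of differentiable `rectifier` approximation open; a couple of common candidates, as a sketch (none of these names come from the package, and any function that is positive for positive inputs and roughly zero for negative inputs qualifies):

```julia
# Differentiable stand-ins for the exact rectifier max(0, t) (illustrative):
relu(t) = max(zero(t), t)                  # exact, but not differentiable at 0
softplus(t) = log(one(t) + exp(t))         # smooth everywhere; ≈ relu for large |t|
sharp_softplus(t; k = 10) = log(one(t) + exp(k * t)) / k  # → relu as k grows
```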
@@ -89,10 +103,10 @@ end
function get_decrease_condition(cond::RoAAwareDecreaseCondition)
if cond.check_decrease
return function (V, dVdt, x, fixed_point)
(V(x) ≤ cond.ρ) * cond.rectifier(
cond.sigmoid(cond.ρ - V(x)) * cond.rectifier(
cond.rate_metric(V(x), dVdt(x)) + cond.strength(x, fixed_point)
) +
(V(x) > cond.ρ) * cond.out_of_RoA_penalty(V(x), dVdt(x), x, fixed_point,
cond.sigmoid(V(x) - cond.ρ) * cond.out_of_RoA_penalty(V(x), dVdt(x), x, fixed_point,
cond.ρ)
end
else
@@ -101,29 +115,50 @@ function get_decrease_condition(cond::RoAAwareDecreaseCondition)
end

"""
make_RoA_aware(cond; ρ, out_of_RoA_penalty)
make_RoA_aware(cond; ρ, out_of_RoA_penalty, sigmoid)
Add awareness of the region of attraction (RoA) estimation task to the supplied
[`LyapunovDecreaseCondition`](@ref).
When estimating the region of attraction using a Lyapunov function, the decrease condition
only needs to be met within a bounded sublevel set ``\\{ x : V(x) ≤ ρ \\}``.
The returned [`RoAAwareDecreaseCondition`](@ref) enforces the decrease condition represented
by `cond` only in that sublevel set.
# Arguments
- `cond::LyapunovDecreaseCondition`: specifies the loss to be applied when ``V(x) ≤ ρ``.
- `ρ`: the target level such that the RoA will be ``\\{ x : V(x) ≤ ρ \\}``.
- `out_of_RoA_penalty::Function`: specifies the loss to be applied when ``V(x) > ρ``.
- `sigmoid::Function`: approximately one when the input is positive and approximately zero
when the input is negative.
The loss applied to samples ``x`` such that ``V(x) > ρ`` is
``\\lvert \\texttt{out\\_of\\_RoA\\_penalty}(V(x), V̇(x), x, x_0, ρ) \\rvert^2``.
The `sigmoid` function allows for a smooth transition between the ``V(x) ≤ ρ`` case and the
``V(x) > ρ`` case, by combining the above equations into one:
``\\texttt{sigmoid}(ρ - V(x)) (\\text{in-RoA expression}) + \\texttt{sigmoid}(V(x) - ρ) (\\text{out-of-RoA expression}) = 0``.
Note that a hard transition, which only enforces the in-RoA equation when ``V(x) ≤ ρ`` and
the out-of-RoA equation when ``V(x) > ρ``, can be provided by a `sigmoid` which is exactly
one when the input is nonnegative and exactly zero when the input is negative.
As such, the default value is `sigmoid(t) = t ≥ zero(t)`.
See also: [`RoAAwareDecreaseCondition`](@ref)
"""
function make_RoA_aware(
cond::LyapunovDecreaseCondition;
ρ = 1.0,
out_of_RoA_penalty = (V, dVdt, state, fixed_point, _ρ) -> 0.0
out_of_RoA_penalty = (V, dVdt, state, fixed_point, _ρ) -> 0.0,
sigmoid = (x) -> x .≥ zero.(x)
)::RoAAwareDecreaseCondition
RoAAwareDecreaseCondition(
cond.check_decrease,
cond.rate_metric,
cond.strength,
cond.rectifier,
sigmoid,
ρ,
out_of_RoA_penalty
)
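Putting the pieces together, a usage sketch for the new keyword (assuming `LyapunovDecreaseCondition` exposes the four fields that `make_RoA_aware` copies above; check the package docs for its exact constructor):

```julia
using NeuralLyapunov

# Hypothetical construction; field order mirrors the four fields copied by
# `make_RoA_aware` above (check_decrease, rate_metric, strength, rectifier).
cond = LyapunovDecreaseCondition(
    true,                          # check_decrease
    (V, dVdt) -> dVdt,             # rate_metric: enforce V̇(x) ≤ 0
    (x, x0) -> 0.0,                # strength: no required decrease margin
    t -> log(one(t) + exp(t))      # rectifier: softplus
)

# Default: hard transition at V(x) = ρ, i.e. sigmoid(t) = t ≥ zero(t)
roa_cond = make_RoA_aware(cond; ρ = 1.0)

# Smooth transition, as added in this PR:
smooth_roa_cond = make_RoA_aware(
    cond;
    ρ = 1.0,
    out_of_RoA_penalty = (V, dVdt, x, x0, ρ) -> V - ρ,  # example penalty
    sigmoid = t -> 1 / (1 + exp(-t))                    # logistic sigmoid
)
```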
3 changes: 3 additions & 0 deletions test/runtests.jl
@@ -19,4 +19,7 @@ using SafeTestsets
@time @safetestset "Local Lyapunov function search" begin
include("local_lyapunov.jl")
end
@time @safetestset "Errors for partially-implemented extensions" begin
include("unimplemented.jl")
end
end
18 changes: 18 additions & 0 deletions test/unimplemented.jl
@@ -0,0 +1,18 @@
using NeuralLyapunov
using Test

struct UnimplementedMinimizationCondition <: NeuralLyapunov.AbstractLyapunovMinimizationCondition end

cond = UnimplementedMinimizationCondition()

@test_throws ErrorException NeuralLyapunov.check_nonnegativity(cond)
@test_throws ErrorException NeuralLyapunov.check_minimal_fixed_point(cond)
@test_throws ErrorException NeuralLyapunov.get_minimization_condition(cond)


struct UnimplementedDecreaseCondition <: NeuralLyapunov.AbstractLyapunovDecreaseCondition end

cond = UnimplementedDecreaseCondition()

@test_throws ErrorException NeuralLyapunov.check_decrease(cond)
@test_throws ErrorException NeuralLyapunov.get_decrease_condition(cond)
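
These tests appear to exercise fallback methods on the abstract types that throw until a subtype implements them; a minimal sketch of that pattern (illustrative, not the package's actual source):

```julia
# Sketch of the fallback pattern exercised above (not the package source):
abstract type AbstractLyapunovDecreaseCondition end

check_decrease(cond::AbstractLyapunovDecreaseCondition) =
    error("check_decrease not implemented for conditions of type $(typeof(cond))")
get_decrease_condition(cond::AbstractLyapunovDecreaseCondition) =
    error("get_decrease_condition not implemented for conditions of type $(typeof(cond))")
```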
