Skip to content

Commit

Permalink
Rather than deleting away vertices, filter in the argmin_max function.
Browse files Browse the repository at this point in the history
  • Loading branch information
Hendrych committed Nov 13, 2024
1 parent a86fa88 commit 6f47767
Show file tree
Hide file tree
Showing 3 changed files with 242 additions and 14 deletions.
234 changes: 234 additions & 0 deletions optimal_design_loop.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
# Showcases infinite loop for Lazy DICG on Optimal Design problem

using FrankWolfe
using LinearAlgebra
using StableRNGs
using Distributions
using SparseArrays

# Optimal Design functions
"""
build_data
seed - for the Random Number Generator.
m - number of experiments.
"""
function build_data(seed, m)
# set up
rng = StableRNG(seed)

n = Int(floor(m/10))
B = rand(rng, m,n)
B = B'*B
@assert isposdef(B)
D = MvNormal(randn(rng, n),B)

A = rand(rng, D, m)'
@assert rank(A) == n

return A
end

"""
Check if given point is in the domain of f, i.e. X = transpose(A) * diagm(x) * A
positive definite.
"""
function build_domain_oracle(A, n)
return function domain_oracle(x)
S = findall(x-> !iszero(x),x)
#@show rank(A[S,:]) == n
return rank(A[S,:]) == n #&& sum(x .< 0) == 0
end
end

"""
Find n linearly independent rows of A to build the starting point.
"""
function linearly_independent_rows(A)
S = []
m, n = size(A)
for i in 1:m
S_i= vcat(S, i)
if rank(A[S_i,:])==length(S_i)
S=S_i
end
if length(S) == n # we only n linearly independent points
return S
end
end
return S
end

"""
Build start point used in Boscia in case of A-opt and D-opt.
The functions are self concordant and so not every point in the feasible region
is in the domain of f and grad!.
"""
function build_start_point(A)
# Get n linearly independent rows of A
m, n = size(A)
S = linearly_independent_rows(A)
@assert length(S) == n
V = FrankWolfe.ScaledHotVector{Float64}[]

for i in S
v = FrankWolfe.ScaledHotVector(1.0, i, m)
push!(V, v)
end

x = sum(V .* 1/n)
x = convert(SparseArrays.SparseVector, x)
active_set= FrankWolfe.ActiveSet(fill(1/n, n), V, x)

return x, active_set, S
end

# A Optimal
"""
Build function for the A-criterion.
"""
function build_a_criterion(A; μ=0.0, build_safe=true)
m, n = size(A)
a=m
domain_oracle = build_domain_oracle(A, n)

function f_a(x)
X = transpose(A)*diagm(x)*A + Matrix*I, n, n)
X = Symmetric(X)
U = cholesky(X)
X_inv = U \ I
return LinearAlgebra.tr(X_inv)/a
end

function grad_a!(storage, x)
X = Symmetric(X*X)
F = cholesky(X)
for i in 1:length(x)
storage[i] = LinearAlgebra.tr(- (F \ A[i,:]) * transpose(A[i,:]))/a
end
return storage #float.(storage) # in case of x .= BigFloat(x)
end

function f_a_safe(x)
if !domain_oracle(x)
return Inf
end
X = transpose(A)*diagm(x)*A + Matrix*I, n, n)
X = Symmetric(X)
X_inv = LinearAlgebra.inv(X)
return LinearAlgebra.tr(X_inv)/a
end

function grad_a_safe!(storage, x)
if !domain_oracle(x)
return fill(Inf, length(x))
end
#x = BigFloat.(x) # Setting can be useful for numerical tricky problems
X = transpose(A)*diagm(x)*A + Matrix*I, n, n)
X = Symmetric(X*X)
F = cholesky(X)
for i in 1:length(x)
storage[i] = LinearAlgebra.tr(- (F \ A[i,:]) * transpose(A[i,:]))/a
end
return storage #float.(storage) # in case of x .= BigFloat(x)
end

if build_safe
return f_a_safe, grad_a_safe!
end

return f_a, grad_a!
end

# D Optimal
"""
Build function for the D-criterion.
"""
function build_d_criterion(A; μ =0.0, build_safe=true)
m, n = size(A)
a=m
domain_oracle = build_domain_oracle(A, n)

function f_d(x)
X = transpose(A)*diagm(x)*A + Matrix*I, n, n)
X = Symmetric(X)
return -log(det(X))/a
end

function grad_d!(storage, x)
X = transpose(A)*diagm(x)*A + Matrix*I, n, n)
X= Symmetric(X)
F = cholesky(X)
for i in 1:length(x)
storage[i] = 1/a * LinearAlgebra.tr(-(F \ A[i,:] )*transpose(A[i,:]))
end
# https://stackoverflow.com/questions/46417005/exclude-elements-of-array-based-on-index-julia
return storage
end

function f_d_safe(x)
if !domain_oracle(x)
return Inf
end
X = transpose(A)*diagm(x)*A + Matrix*I, n, n)
X = Symmetric(X)
return -log(det(X))/a
end

function grad_d_safe!(storage, x)
if !domain_oracle(x)
return fill(Inf, length(x))
end
X = transpose(A)*diagm(x)*A + Matrix*I, n, n)
X= Symmetric(X)
F = cholesky(X)
for i in 1:length(x)
storage[i] = 1/a * LinearAlgebra.tr(-(F \ A[i,:] )*transpose(A[i,:]))
end
# https://stackoverflow.com/questions/46417005/exclude-elements-of-array-based-on-index-julia
return storage
end

if build_safe
return f_d_safe, grad_d_safe!
end

return f_d, grad_d!
end

"""
Returns FW args for D-Criterion
"""
function build_d_opt(; n=100, seed=1234)
A = build_data(seed, n)
f, grad! = build_d_criterion(A)

lmo = FrankWolfe.ProbabilitySimplexOracle(1.0)
x0, _ = build_start_point(A)

return f, grad!, lmo, x0
end

"""
Returns FW args for A-Criterion
"""
function build_a_opt(; n=100, seed=1234)
A = build_data(seed, n)
A = A
f, grad! = build_a_criterion(A)

lmo = FrankWolfe.ProbabilitySimplexOracle(1.0)
x0, _ = build_start_point(A)

return f, grad!, lmo, x0
end

println("D_Criterion")
f, grad!, lmo, x0 = build_d_opt(n=250, seed=95781326)
FrankWolfe.decomposition_invariant_conditional_gradient(f, grad!, lmo, copy(x0), verbose=true, max_iteration=2000, print_iter=100, lazy=true);

println("A-Criterion")
f, grad!, lmo, x0 = build_a_opt(n=250, seed=95781326)
FrankWolfe.decomposition_invariant_conditional_gradient(f, grad!, lmo, copy(x0), verbose=true, max_iteration=10000, print_iter=1000, lazy=true);

println("Fin.")
11 changes: 2 additions & 9 deletions src/dicg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ function decomposition_invariant_conditional_gradient(
linesearch_workspace,
memory_mode,
)

if lazy
idx = findfirst(x -> x == v, pre_computed_set)
if idx === nothing
Expand Down Expand Up @@ -228,13 +228,6 @@ function decomposition_invariant_conditional_gradient(
end
end
x = muladd_memory_mode(memory_mode, x, gamma, d)
if lazy
x_inds = setdiff(collect(1:length(x)), SparseArrays.nonzeroinds(x))
if !iszero(a[x_inds])
idy = findfirst(x -> x == a, pre_computed_set)
idy !== nothing ? deleteat!(pre_computed_set, idy) : nothing
end
end
end

# recompute everything once more for final verfication / do not record to trajectory though
Expand Down Expand Up @@ -504,7 +497,7 @@ function lazy_dicg_step(
memory_mode::MemoryEmphasis=InplaceEmphasis(),
)
v_local, v_local_loc, val, a_local, a_local_loc, valM =
pre_computed_set_argminmax(pre_computed_set, gradient)
pre_computed_set_argminmax(pre_computed_set, gradient, x)
step_type = ST_REGULAR
away_index = nothing
fw_index = nothing
Expand Down
11 changes: 6 additions & 5 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -382,9 +382,10 @@ Base.length(storage::DeletedVertexStorage) = length(storage.storage)
Computes the linear minimizer in the direction on the precomputed_set.
Precomputed_set stores the vertices computed as extreme points v in each iteration.
"""
function pre_computed_set_argminmax(pre_computed_set, direction)
function pre_computed_set_argminmax(pre_computed_set, direction, x)
val = convert(eltype(direction), Inf)
valM = convert(eltype(direction), -Inf)
x_zeros = setdiff(collect(1:length(x)), SparseArrays.nonzeroinds(x))
idx = -1
idxM = -1
for i in eachindex(pre_computed_set)
Expand All @@ -393,16 +394,16 @@ function pre_computed_set_argminmax(pre_computed_set, direction)
val = temp_val
idx = i
end
if valM < temp_val
if iszero(pre_computed_set[i][x_zeros]) && valM > temp_val
valM = temp_val
idxM = i
end
end
if idx == -1 || idxM == -1
error("Infinite minimum $val or maximum $valM in the precomputed set. Does the gradient contain invalid (NaN / Inf) entries?")
if idx == -1
error("Infinite minimum $val in the precomputed set. Does the gradient contain invalid (NaN / Inf) entries?")
end
v_local = pre_computed_set[idx]
a_local = pre_computed_set[idxM]
a_local = idxM != -1 ? pre_computed_set[idxM] : nothing
return (v_local, idx, val, a_local, idxM, valM)
end

Expand Down

0 comments on commit 6f47767

Please sign in to comment.