From 6f47767666ea87fd3608732f9179714937076ccc Mon Sep 17 00:00:00 2001
From: Hendrych
Date: Wed, 13 Nov 2024 16:26:39 +0100
Subject: [PATCH] Rather than deleting away vertices, filter in the argmin_max function.

---
 optimal_design_loop.jl | 234 +++++++++++++++++++++++++++++++++++++++++
 src/dicg.jl            |  11 +-
 src/utils.jl           |  11 +-
 3 files changed, 242 insertions(+), 14 deletions(-)
 create mode 100644 optimal_design_loop.jl

diff --git a/optimal_design_loop.jl b/optimal_design_loop.jl
new file mode 100644
index 000000000..e25183e72
--- /dev/null
+++ b/optimal_design_loop.jl
@@ -0,0 +1,234 @@
+# Showcases an infinite loop for lazy DICG on the optimal design problem
+
+using FrankWolfe
+using LinearAlgebra
+using StableRNGs
+using Distributions
+using SparseArrays
+
+# Optimal design functions
+"""
+    build_data
+
+seed - seed for the random number generator.
+m - number of experiments.
+"""
+function build_data(seed, m)
+    # set up
+    rng = StableRNG(seed)
+
+    n = Int(floor(m / 10))
+    B = rand(rng, m, n)
+    B = B' * B
+    @assert isposdef(B)
+    D = MvNormal(randn(rng, n), B)
+
+    A = rand(rng, D, m)'
+    @assert rank(A) == n
+
+    return A
+end
+
+"""
+Check if the given point is in the domain of f, i.e. whether
+X = transpose(A) * diagm(x) * A is positive definite.
+"""
+function build_domain_oracle(A, n)
+    return function domain_oracle(x)
+        S = findall(x -> !iszero(x), x)
+        return rank(A[S, :]) == n
+    end
+end
+
+"""
+Find n linearly independent rows of A to build the starting point.
+"""
+function linearly_independent_rows(A)
+    S = []
+    m, n = size(A)
+    for i in 1:m
+        S_i = vcat(S, i)
+        if rank(A[S_i, :]) == length(S_i)
+            S = S_i
+        end
+        if length(S) == n # we only need n linearly independent rows
+            return S
+        end
+    end
+    return S
+end
+
+"""
+Build the start point used in Boscia for A-opt and D-opt.
+The functions are self-concordant, so not every point in the feasible region
+is in the domain of f and grad!.
+"""
+function build_start_point(A)
+    # Get n linearly independent rows of A
+    m, n = size(A)
+    S = linearly_independent_rows(A)
+    @assert length(S) == n
+    V = FrankWolfe.ScaledHotVector{Float64}[]
+
+    for i in S
+        v = FrankWolfe.ScaledHotVector(1.0, i, m)
+        push!(V, v)
+    end
+
+    x = sum(V .* 1 / n)
+    x = convert(SparseArrays.SparseVector, x)
+    active_set = FrankWolfe.ActiveSet(fill(1 / n, n), V, x)
+
+    return x, active_set, S
+end
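+
+# Illustrative usage sketch for the helpers above (hypothetical values, kept
+# commented out so it is not part of the reproduction below):
+#   A = build_data(1234, 50)
+#   x0, active_set, S = build_start_point(A)
+#   domain_oracle = build_domain_oracle(A, size(A, 2))
+#   @assert domain_oracle(x0)   # the start point lies in the domain of f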
+
+# A-optimal
+"""
+Build function for the A-criterion.
+"""
+function build_a_criterion(A; μ=0.0, build_safe=true)
+    m, n = size(A)
+    a = m
+    domain_oracle = build_domain_oracle(A, n)
+
+    function f_a(x)
+        X = transpose(A) * diagm(x) * A + Matrix(μ * I, n, n)
+        X = Symmetric(X)
+        U = cholesky(X)
+        X_inv = U \ I
+        return LinearAlgebra.tr(X_inv) / a
+    end
+
+    function grad_a!(storage, x)
+        X = transpose(A) * diagm(x) * A + Matrix(μ * I, n, n)
+        X = Symmetric(X * X)
+        F = cholesky(X)
+        for i in 1:length(x)
+            storage[i] = LinearAlgebra.tr(-(F \ A[i, :]) * transpose(A[i, :])) / a
+        end
+        return storage
+    end
+
+    function f_a_safe(x)
+        if !domain_oracle(x)
+            return Inf
+        end
+        X = transpose(A) * diagm(x) * A + Matrix(μ * I, n, n)
+        X = Symmetric(X)
+        X_inv = LinearAlgebra.inv(X)
+        return LinearAlgebra.tr(X_inv) / a
+    end
+
+    function grad_a_safe!(storage, x)
+        if !domain_oracle(x)
+            return fill(Inf, length(x))
+        end
+        # x = BigFloat.(x)  # can help on numerically tricky problems
+        X = transpose(A) * diagm(x) * A + Matrix(μ * I, n, n)
+        X = Symmetric(X * X)
+        F = cholesky(X)
+        for i in 1:length(x)
+            storage[i] = LinearAlgebra.tr(-(F \ A[i, :]) * transpose(A[i, :])) / a
+        end
+        return storage
+    end
+
+    if build_safe
+        return f_a_safe, grad_a_safe!
+    end
+
+    return f_a, grad_a!
+end
+
+# D-optimal
+"""
+Build function for the D-criterion.
+"""
+function build_d_criterion(A; μ=0.0, build_safe=true)
+    m, n = size(A)
+    a = m
+    domain_oracle = build_domain_oracle(A, n)
+
+    function f_d(x)
+        X = transpose(A) * diagm(x) * A + Matrix(μ * I, n, n)
+        X = Symmetric(X)
+        return -log(det(X)) / a
+    end
+
+    function grad_d!(storage, x)
+        X = transpose(A) * diagm(x) * A + Matrix(μ * I, n, n)
+        X = Symmetric(X)
+        F = cholesky(X)
+        for i in 1:length(x)
+            storage[i] = 1 / a * LinearAlgebra.tr(-(F \ A[i, :]) * transpose(A[i, :]))
+        end
+        return storage
+    end
+
+    function f_d_safe(x)
+        if !domain_oracle(x)
+            return Inf
+        end
+        X = transpose(A) * diagm(x) * A + Matrix(μ * I, n, n)
+        X = Symmetric(X)
+        return -log(det(X)) / a
+    end
+
+    function grad_d_safe!(storage, x)
+        if !domain_oracle(x)
+            return fill(Inf, length(x))
+        end
+        X = transpose(A) * diagm(x) * A + Matrix(μ * I, n, n)
+        X = Symmetric(X)
+        F = cholesky(X)
+        for i in 1:length(x)
+            storage[i] = 1 / a * LinearAlgebra.tr(-(F \ A[i, :]) * transpose(A[i, :]))
+        end
+        return storage
+    end
+
+    if build_safe
+        return f_d_safe, grad_d_safe!
+    end
+
+    return f_d, grad_d!
+end
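+
+# Optional gradient sanity check (illustrative sketch; assumes f, grad!, and a
+# domain point x0 obtained from the builders above, e.g. via build_start_point):
+#   g = collect(x0); grad!(g, x0)
+#   j = findfirst(!iszero, x0)  # perturb inside supp(x0) to stay in the domain
+#   h = 1e-6; e = [k == j ? h : 0.0 for k in eachindex(x0)]
+#   @assert isapprox(g[j], (f(x0 + e) - f(x0 - e)) / (2h); rtol=1e-3)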
+
+"""
+Returns the FW arguments for the D-criterion.
+"""
+function build_d_opt(; n=100, seed=1234)
+    A = build_data(seed, n)
+    f, grad! = build_d_criterion(A)
+
+    lmo = FrankWolfe.ProbabilitySimplexOracle(1.0)
+    x0, _ = build_start_point(A)
+
+    return f, grad!, lmo, x0
+end
+
+"""
+Returns the FW arguments for the A-criterion.
+"""
+function build_a_opt(; n=100, seed=1234)
+    A = build_data(seed, n)
+    f, grad! = build_a_criterion(A)
+
+    lmo = FrankWolfe.ProbabilitySimplexOracle(1.0)
+    x0, _ = build_start_point(A)
+
+    return f, grad!, lmo, x0
+end
+
+println("D-Criterion")
+f, grad!, lmo, x0 = build_d_opt(n=250, seed=95781326)
+FrankWolfe.decomposition_invariant_conditional_gradient(f, grad!, lmo, copy(x0), verbose=true, max_iteration=2000, print_iter=100, lazy=true);
+
+println("A-Criterion")
+f, grad!, lmo, x0 = build_a_opt(n=250, seed=95781326)
+FrankWolfe.decomposition_invariant_conditional_gradient(f, grad!, lmo, copy(x0), verbose=true, max_iteration=10000, print_iter=1000, lazy=true);
+
+println("Fin.")
diff --git a/src/dicg.jl b/src/dicg.jl
index b2fb6c1f8..ee2d11792 100644
--- a/src/dicg.jl
+++ b/src/dicg.jl
@@ -198,7 +198,7 @@ function decomposition_invariant_conditional_gradient(
         linesearch_workspace,
         memory_mode,
     )
-    
+
     if lazy
         idx = findfirst(x -> x == v, pre_computed_set)
         if idx === nothing
@@ -228,13 +228,6 @@
             end
         end
         x = muladd_memory_mode(memory_mode, x, gamma, d)
-        if lazy
-            x_inds = setdiff(collect(1:length(x)), SparseArrays.nonzeroinds(x))
-            if !iszero(a[x_inds])
-                idy = findfirst(x -> x == a, pre_computed_set)
-                idy !== nothing ? deleteat!(pre_computed_set, idy) : nothing
-            end
-        end
     end

    # recompute everything once more for final verification / do not record to trajectory though
@@ -504,7 +497,7 @@ function lazy_dicg_step(
     memory_mode::MemoryEmphasis=InplaceEmphasis(),
 )
     v_local, v_local_loc, val, a_local, a_local_loc, valM =
-        pre_computed_set_argminmax(pre_computed_set, gradient)
+        pre_computed_set_argminmax(pre_computed_set, gradient, x)
     step_type = ST_REGULAR
     away_index = nothing
     fw_index = nothing
diff --git a/src/utils.jl b/src/utils.jl
index f7f071a35..01b0a7a0a 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -382,9 +382,10 @@ Base.length(storage::DeletedVertexStorage) = length(storage.storage)
 Computes the linear minimizer in the direction on the precomputed_set.
 Precomputed_set stores the vertices computed as extreme points v in each iteration.
 """
-function pre_computed_set_argminmax(pre_computed_set, direction)
+function pre_computed_set_argminmax(pre_computed_set, direction, x)
     val = convert(eltype(direction), Inf)
     valM = convert(eltype(direction), -Inf)
+    x_zeros = setdiff(collect(1:length(x)), SparseArrays.nonzeroinds(x))
     idx = -1
     idxM = -1
     for i in eachindex(pre_computed_set)
@@ -393,16 +394,16 @@ function pre_computed_set_argminmax(pre_computed_set, direction)
         if temp_val < val
             val = temp_val
             idx = i
         end
-        if valM < temp_val
+        if iszero(pre_computed_set[i][x_zeros]) && valM < temp_val
             valM = temp_val
             idxM = i
         end
     end
-    if idx == -1 || idxM == -1
-        error("Infinite minimum $val or maximum $valM in the precomputed set. Does the gradient contain invalid (NaN / Inf) entries?")
+    if idx == -1
+        error("Infinite minimum $val in the precomputed set. Does the gradient contain invalid (NaN / Inf) entries?")
     end
     v_local = pre_computed_set[idx]
-    a_local = pre_computed_set[idxM]
+    a_local = idxM != -1 ? pre_computed_set[idxM] : nothing
     return (v_local, idx, val, a_local, idxM, valM)
 end
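
Below is a minimal, self-contained Julia sketch (toy data, independent of the patch) of the
filtering rule that pre_computed_set_argminmax now applies: a stored vertex is eligible as
the away (max) candidate only if it carries no mass outside the support of the current
iterate x, while every stored vertex remains eligible as the Frank-Wolfe (min) candidate.
filtered_argminmax is a hypothetical stand-in that mirrors the patched logic.

using LinearAlgebra, SparseArrays

# Mirrors the patched argmin/argmax logic on toy dense vertices.
function filtered_argminmax(vertices, gradient, x)
    x_zeros = setdiff(collect(1:length(x)), SparseArrays.nonzeroinds(x))
    val, valM = Inf, -Inf
    idx, idxM = -1, -1
    for (i, v) in enumerate(vertices)
        temp_val = dot(v, gradient)
        if temp_val < val
            val, idx = temp_val, i
        end
        # skip vertices supported outside supp(x) for the away candidate
        if iszero(v[x_zeros]) && valM < temp_val
            valM, idxM = temp_val, i
        end
    end
    return idx, idxM
end

x = sparsevec([1, 2], [0.5, 0.5], 4)   # current iterate with support {1, 2}
vertices = [[1.0, 0, 0, 0], [0.0, 1, 0, 0], [0.0, 0, 1, 0]]
gradient = [0.3, -0.2, 0.9, 0.0]
idx, idxM = filtered_argminmax(vertices, gradient, x)
@assert idx == 2   # FW candidate: global minimum of <gradient, v>
@assert idxM == 1  # away candidate: vertex 3 is excluded despite its larger inner product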