Skip to content

Commit 8c4b379

Browse files
authored
Merge pull request #4 from slimgroup/f32
prevent f32 overflow
2 parents 2be0d1f + ded0928 commit 8c4b379

File tree

9 files changed

+30
-30
lines changed

9 files changed

+30
-30
lines changed

.github/workflows/ci-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ jobs:
1919
fail-fast: false
2020

2121
matrix:
22-
version: ['1.5', '1.6', '1.7', 'nightly']
22+
version: ['1.6', '1.7']
2323
os: [ubuntu-latest]
2424

2525
include:

src/PARSDMM.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ for i=1:maxit #main loop
247247
end #end Q-update timer
248248

249249
if i==maxit
250-
println("PARSDMM reached maxit")
250+
constr_log("PARSDMM reached maxit")
251251
(TD_OP,AtA,log_PARSDMM) = output_check_PARSDMM(x,TD_OP,AtA,log_PARSDMM,i,counter)
252252
end
253253

src/PARSDMM_initialize.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,14 @@ function PARSDMM_initialize(
9999
end
100100
end
101101
if maximum(feasibility_initial)<options.feas_tol #accept input as feasible and return
102-
println("input to PARSDMM is feasible, returning")
102+
constr_log("input to PARSDMM is feasible, returning")
103103
stop = true
104104
end
105105

106106
# if one of the sets is non-convex, use different lambda and rho update frequency, don't update gamma and set a different fixed gamma
107107
for ii=1:pp
108108
if set_Prop.ncvx[ii] == true
109-
println("non-convex set(s) involved, using special settings")
109+
constr_log("non-convex set(s) involved, using special settings")
110110
rho_update_frequency = 3;
111111
adjust_gamma = false
112112
gamma_ini = TF(0.75)

src/SetIntersectionProjection.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ using TimerOutputs
2121

2222
export log_type_PARSDMM, set_properties, PARSDMM_options, set_definitions
2323

24+
const _verbose = false
25+
constr_log(msg...) = _verbose ? nothing : println(msg...)
26+
2427
#main scripts
2528
include("PARSDMM.jl")
2629
include("PARSDMM_multi_level.jl")

src/cg.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ function cg(A::Function,b::Vector{TF}; tol::Real=1e-2,maxIter::Integer=100,M::Fu
6262

6363

6464
if out==2
65-
println("=== cg ===")
66-
println(@sprintf("%4s\t%7s","iter","relres"))
65+
constr_log("=== cg ===")
66+
constr_log(@sprintf("%4s\t%7s","iter","relres"))
6767
end
6868

6969
resvec = zeros(TF,maxIter)
@@ -99,7 +99,7 @@ function cg(A::Function,b::Vector{TF}; tol::Real=1e-2,maxIter::Integer=100,M::Fu
9999
#resvec[iter] = BLAS.nrm2(n, r, 1) / nr0#
100100
resvec[iter] = norm(r)/nr0
101101
if out==2
102-
println(iter,resvec[iter])
102+
constr_log(iter,resvec[iter])
103103
end
104104
if resvec[iter] <= tol
105105
flag = 0; break
@@ -116,12 +116,12 @@ function cg(A::Function,b::Vector{TF}; tol::Real=1e-2,maxIter::Integer=100,M::Fu
116116

117117
if out>=0
118118
if flag==-1
119-
println("cg iterated maxIter (=%d) times but reached only residual norm %1.2e instead of tol=%1.2e.",
119+
constr_log("cg iterated maxIter (=%d) times but reached only residual norm %1.2e instead of tol=%1.2e.",
120120
maxIter,resvec[lastIter],tol)
121121
elseif flag==-2
122-
println("Matrix A in cg has to be positive definite.")
122+
constr_log("Matrix A in cg has to be positive definite.")
123123
elseif flag==0 && out>=1
124-
println("cg achieved desired tolerance at iteration %d. Residual norm is %1.2e.",lastIter,resvec[lastIter])
124+
constr_log("cg achieved desired tolerance at iteration %d. Residual norm is %1.2e.",lastIter,resvec[lastIter])
125125
end
126126
end
127127
return x,flag,resvec[lastIter],lastIter,resvec[1:lastIter]
@@ -194,12 +194,12 @@ end
194194
#
195195
# if out>=0
196196
# if flag==-1
197-
# println("cg iterated maxIter (=%d) times but reached only residual norm %1.2e instead of tol=%1.2e.",
197+
# constr_log("cg iterated maxIter (=%d) times but reached only residual norm %1.2e instead of tol=%1.2e.",
198198
# maxIter,resvec[lastIter],tol)
199199
# elseif flag==-2
200-
# println("Matrix A in cg has to be positive definite.")
200+
# constr_log("Matrix A in cg has to be positive definite.")
201201
# elseif flag==0 && out>=1
202-
# println("cg achieved desired tolerance at iteration %d. Residual norm is %1.2e.",lastIter,resvec[lastIter])
202+
# constr_log("cg achieved desired tolerance at iteration %d. Residual norm is %1.2e.",lastIter,resvec[lastIter])
203203
# end
204204
# end
205205

src/default_PARSDMM_options.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ export default_PARSDMM_options
33
"""
44
Returns a set of default options for the PARSDMM solver
55
"""
6-
function default_PARSDMM_options(options,TF)
6+
function default_PARSDMM_options(options,TF; verbose=false)
77

88
if TF == Float64
99
TI = Int64
@@ -28,5 +28,7 @@ function default_PARSDMM_options(options,TF)
2828
options.parallel = false #comput proximal mappings, multiplier updates, rho and gamma updates in parallel
2929
options.zero_ini_guess = true #zero initial guess for primal, auxilliary, and multipliers
3030
Minkowski = false #the intersection of sets includes a Minkowski set
31+
32+
_verbose = verbose
3133
return options
3234
end

src/projectors/project_l1_Duchi!.jl

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ function project_l1_Duchi!(v::Union{Vector{TF},Vector{Complex{TF}}}, b::TF) wher
2626
u = similar(v)
2727
sv = Vector{TF}(undef, lv)
2828

29-
#use RadixSort for Float32 (short keywords)
29+
# use RadixSort for Float32 (short keywords)
3030
copyto!(u, v)
3131
u .= abs.(u)
3232
u = convert(Vector{TF},u)
@@ -35,19 +35,14 @@ function project_l1_Duchi!(v::Union{Vector{TF},Vector{Complex{TF}}}, b::TF) wher
3535
else
3636
u = sort!(u, rev=true, alg=QuickSort)
3737
end
38-
39-
40-
# if TF==Float32
41-
# u = sort!(abs.(u), rev=true, alg=RadixSort)
42-
# else
43-
# u = sort!(abs.(u), rev=true, alg=QuickSort)
44-
# end
45-
4638
cumsum!(sv, u)
4739

4840
# Thresholding level
49-
temp = TF(1.0):TF(1.0):TF(lv)
50-
rho = max(1, min(lv, findlast(u .> ((sv.-b) ./ temp ) ) ))::Int
41+
rho = 0
42+
while u[rho+1] > ((sv[rho+1] - b)/(rho+1)) && (rho+1) < lv
43+
rho += 1
44+
end
45+
rho = max(1, rho)
5146
theta = max.(TF(0) , (sv[rho] .- b) ./ rho)::TF
5247

5348
# Projection as soft thresholding

src/setup_multi_level_PARSDMM.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ for i=2:n_levels
8787
constraint_level = constraint2coarse(constraint_level,comp_grid_levels[i],coarsening_factor)
8888

8989
#set up constraints on new level
90-
println(TF)
90+
constr_log(TF)
9191
(P_sub_l,TD_OP_l,set_Prop_l) = setup_constraints(constraint_level,comp_grid_levels[i],TF)
9292
(TD_OP_l,AtA_l,dummy1,dummy2) = PARSDMM_precompute_distribute(TD_OP_l,set_Prop_l,comp_grid_levels[i],options)
9393

src/stop_PARSDMM.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,19 @@ function stop_PARSDMM(
2121

2222
#stop if objective value does not change and x is sufficiently feasible for all sets
2323
if i>6 && maximum(log_PARSDMM.set_feasibility[counter-1,:])<feas_tol && maximum(abs.( (log_PARSDMM.obj[i-5:i]-log_PARSDMM.obj[i-1-5:i-1])./log_PARSDMM.obj[i-1-5:i-1] )) < obj_tol
24-
println("stationary objective and reached feasibility, exiting PARSDMM (iteration ",i,")")
24+
constr_log("stationary objective and reached feasibility, exiting PARSDMM (iteration ",i,")")
2525
stop=true;
2626
end
2727

2828
#stop if x doesn't change significantly anyjore
2929
if i>5 && maximum(log_PARSDMM.evol_x[i-5:i])<evol_rel_tol
30-
println("relative evolution to small, exiting PARSDMM (iteration ",i,")")
30+
constr_log("relative evolution to small, exiting PARSDMM (iteration ",i,")")
3131
stop=true;
3232
end
3333

3434
# fix rho to ensure regular ADMM convergence if primal residual does not decrease over a 20 iteration window
3535
if i>20 && adjust_rho==true && log_PARSDMM.r_pri_total[i]>maximum(log_PARSDMM.r_pri_total[(i-1):-1:max((i-50),1)])
36-
println("no primal residual reduction, fixing PARSDMM rho & gamma (iteration ",i,")")
36+
constr_log("no primal residual reduction, fixing PARSDMM rho & gamma (iteration ",i,")")
3737
adjust_rho = false;
3838
adjust_feasibility_rho = false;
3939
adjust_gamma = false;
@@ -47,7 +47,7 @@ function stop_PARSDMM(
4747

4848
#if rho is fixed and still no decrease in primal residual is observed over a window, we give up
4949
if adjust_rho==false && i>(ind_ref+25) && log_PARSDMM.r_pri_total[i]>maximum(log_PARSDMM.r_pri_total[(i-1):-1:max(ind_ref,max((i-50),1))])
50-
println("no primal residual reduction, exiting PARSDMM (iteration ",i,")")
50+
constr_log("no primal residual reduction, exiting PARSDMM (iteration ",i,")")
5151
stop = true;
5252
end
5353
return stop,adjust_rho,adjust_gamma,adjust_feasibility_rho,ind_ref

0 commit comments

Comments
 (0)