From 4218a77b277b12d58054a8e602606542c1dfae57 Mon Sep 17 00:00:00 2001
From: kaandocal <26488673+kaandocal@users.noreply.github.com>
Date: Wed, 1 Sep 2021 15:59:42 +0100
Subject: [PATCH 1/2] Testing new progress bars

---
 src/logging.jl | 166 +++++++++++++++++++++++++++++++++++++++++++++++++
 src/sample.jl  | 117 +++++++++++++++++++++++++++----------------
 2 files changed, 242 insertions(+), 41 deletions(-)

diff --git a/src/logging.jl b/src/logging.jl
index a550c532..9c69e4c8 100644
--- a/src/logging.jl
+++ b/src/logging.jl
@@ -43,3 +43,169 @@ function progresslogger()
         return TerminalLoggers.TerminalLogger()
     end
 end
+
+##
+
+using ProgressLogging: _asprogress, ProgressString, ProgressLevel, uuid4
+using Base.Meta: isexpr
+
+# Names of the hidden local variables that `@withprogress_children` defines in the
+# caller's scope: a vector of child progress IDs and a vector of child bar names.
+const _id_var_children = gensym(:progress_id_children)
+const _name_var_children = gensym(:progress_name_children)
+
+"""
+    @progressid_child N
+
+Return the progress ID of the `N`-th child bar of the enclosing
+`@withprogress_children` block. Throw an error outside of such a block.
+"""
+macro progressid_child(N)
+    id_err = "`@progressid_child` must be used inside `@withprogress_children`"
+
+    quote
+        $Base.@isdefined($_id_var_children) ? $_id_var_children[$N] : $error($id_err)
+    end |> esc
+end
+
+"""
+    @logprogress_child N [name] progress [key=value ...]
+
+Log `progress` for the `N`-th child bar of the enclosing `@withprogress_children` block.
+
+If `name` is omitted, the name registered for child `N` is used. Throw an error outside
+of a `@withprogress_children` block.
+"""
+macro logprogress_child(args...)
+    _logprogress_child(args...)
+end
+
+# Build the expression emitted by `@logprogress_child`.
+function _logprogress_child(N, name, progress = nothing, args...)
+    name_expr = :($Base.@isdefined($_name_var_children) ? $_name_var_children[$N] : "")
+    if progress === nothing
+        # Handle: @logprogress_child N progress
+        kwargs = (:(progress = $name), args...)
+        progress = name
+        name = name_expr
+    elseif isexpr(progress, :(=)) && progress.args[1] isa Symbol
+        # Handle: @logprogress_child N progress key1=val1 ...
+        kwargs = (:(progress = $name), progress, args...)
+        progress = name
+        name = name_expr
+    else
+        # Otherwise, it's: @logprogress_child N name progress key1=val1 ...
+        kwargs = (:(progress = $progress), args...)
+    end
+
+    id_err = "`@logprogress_child` must be used inside `@withprogress_children`"
+    id_expr = :($Base.@isdefined($_id_var_children) ? $_id_var_children[$N] : $error($id_err))
+
+    @gensym id_tmp
+    # Emitting progress log record as old/open API (i.e., using
+    # `progress` key) and _also_ as new API based on `Progress` type.
+    # The ID of the enclosing `@withprogress` bar is passed along with the child ID.
+    msgexpr = :($ProgressString($_asprogress(
+        $name,
+        $id_tmp,
+        $(ProgressLogging._id_var);
+        progress = $progress,
+    )))
+    quote
+        $id_tmp = $id_expr
+        $Logging.@logmsg($ProgressLevel, $msgexpr, $(kwargs...), _id = $id_tmp)
+    end |> esc
+end
+
+"""
+    @withprogress_children N [names=...] ex
+
+Create `N` child progress bars inside a `@withprogress` block and evaluate `ex`.
+
+`names` is an optional collection of `N` bar names (empty strings by default). Inside
+`ex`, `@logprogress_child i ...` updates the `i`-th bar. All child bars are marked as
+done when `ex` returns or throws.
+"""
+macro withprogress_children(N, exprs...)
+    _withprogress_children(N, exprs...)
+end
+
+# Build the expression emitted by `@withprogress_children`.
+function _withprogress_children(N, exprs...)
+    length(exprs) == 0 &&
+        throw(ArgumentError("`@withprogress_children` requires at least one number and one expression."))
+
+    kwargs = Dict{Symbol,Any}(:names => :(["" for i in 1:$N]))
+    unsupported = []
+    for kw in exprs[1:end-1]
+        if isexpr(kw, :(=)) && length(kw.args) == 2 && haskey(kwargs, kw.args[1])
+            kwargs[kw.args[1]] = kw.args[2]
+        else
+            push!(unsupported, kw)
+        end
+    end
+
+    # Error on invalid input expressions:
+    if !isempty(unsupported)
+        msg = sprint() do io
+            println(io, "Unsupported optional arguments:")
+            for kw in unsupported
+                println(io, kw)
+            end
+            print(io, "`@withprogress_children` supports only following keyword arguments: ")
+            join(io, keys(kwargs), ", ")
+        end
+        throw(ArgumentError(msg))
+    end
+
+    ex = exprs[end]
+    id_err = "`@withprogress_children` must be used inside `@withprogress`"
+
+    i_var = gensym()
+    quote
+        let
+            $Base.@isdefined($(ProgressLogging._id_var)) || $error($id_err)
+            $_id_var_children = [ $uuid4() for i in 1:$N ]
+            $_name_var_children = $(kwargs[:names])
+
+            # Log an initial record (progress `nothing`) for every child bar.
+            for $i_var in 1:$N
+                @logprogress_child $i_var nothing
+            end
+
+            try
+                $ex
+            finally
+                # Close every child bar even if the body threw (as `@withprogress` does).
+                for $i_var in 1:$N
+                    @logprogress_child $i_var "done"
+                end
+            end
+        end
+    end |> esc
+end
+
+"""
+    @ifwithprogresslogger_children progress N [names=...] ex
+
+Evaluate `ex` directly if `progress` is `false`. Otherwise evaluate
+`@withprogress_children N [names=...] ex`, wrapped in `with_progresslogger` unless
+`hasprogresslevel` returns `true` for the current logger.
+"""
+macro ifwithprogresslogger_children(progress, exprs...)
+    return quote
+        if $progress
+            if $hasprogresslevel($Logging.current_logger())
+                @withprogress_children $(exprs...)
+            else
+                $with_progresslogger($Base.@__MODULE__, $Logging.current_logger()) do
+                    @withprogress_children $(exprs...)
+                end
+            end
+        else
+            $(exprs[end])
+        end
+    end |> esc
+end
+
+
diff --git a/src/sample.jl b/src/sample.jl
index df76caf0..9af182cd 100644
--- a/src/sample.jl
+++ b/src/sample.jl
@@ -275,6 +275,9 @@ function mcmcsample(
     )
 end
 
+
+using ProgressLogging: Progress
+
 function mcmcsample(
     rng::Random.AbstractRNG,
     model::AbstractModel,
@@ -283,7 +286,8 @@ function mcmcsample(
     N::Integer,
     nchains::Integer;
     progress = PROGRESS[],
-    progressname = "Sampling ($(min(nchains, Threads.nthreads())) threads)",
+    progressname = "Sampling",
+    callback=nothing,
     kwargs...
 )
     # Check if actually multiple threads are used.
@@ -312,51 +316,81 @@ function mcmcsample(
     chains = Vector{Any}(undef, nchains)
 
     @ifwithprogresslogger progress name=progressname begin
-        # Create a channel for progress logging.
-        if progress
-            channel = Channel{Bool}(length(interval))
-        end
-
-        Distributed.@sync begin
-            if progress
-                # Update the progress bar.
-                Distributed.@async begin
-                    # Determine threshold values for progress logging
-                    # (one update per 0.5% of progress)
-                    threshold = nchains ÷ 200
-                    nextprogresschains = threshold
-
-                    progresschains = 0
-                    while take!(channel)
-                        progresschains += 1
-                        if progresschains >= nextprogresschains
-                            ProgressLogging.@logprogress progresschains/nchains
-                            nextprogresschains = progresschains + threshold
+        # Per-chain bars: `itotals[i]` counts the progress messages of chain `i`; a bar
+        # is refreshed roughly once per 0.5% of the `N` requested samples.
+        threshold = N ÷ 200
+        itotals = zeros(Int, nchains)
+        next_updates = fill(threshold, nchains)
+
+        # Parent bar: refreshed roughly once per 0.5% of the finished chains.
+        threshold_chains = nchains ÷ 200
+
+        chain_names = ["$progressname (Chain $i of $nchains)" for i in 1:nchains]
+        @ifwithprogresslogger_children progress nchains names=chain_names begin
+            # Create a channel for progress logging.
+            channel = Channel{Int}(length(interval) * 10)
+
+            Distributed.@sync begin
+                if progress
+                    # Update the progress bar.
+                    Distributed.@async begin
+                        # Protocol: a message `i > 0` reports one callback invocation
+                        # of chain `i`, `-i` reports that chain `i` has finished, and
+                        # `0` stops this task.
+                        nextprogresschains = threshold_chains
+
+                        progresschains = 0
+                        while (i = take!(channel)) != 0
+                            if i > 0
+                                itotals[i] += 1
+                                if itotals[i] >= next_updates[i]
+                                    @logprogress_child i itotals[i] / N
+                                    next_updates[i] = itotals[i] + threshold
+                                end
+                            else
+                                i = -i
+                                @logprogress_child i "done"
+                                progresschains += 1
+
+                                if progresschains >= nextprogresschains
+                                    ProgressLogging.@logprogress progresschains/nchains
+                                    nextprogresschains = progresschains + threshold_chains
+                                end
+                            end
                         end
                     end
                 end
-            end
-
-            Distributed.@async begin
-                try
-                    Threads.@threads for i in 1:nchains
-                        # Obtain the ID of the current thread.
-                        id = Threads.threadid()
-
-                        # Seed the thread-specific random number generator with the pre-made seed.
-                        subrng = rngs[id]
-                        Random.seed!(subrng, seeds[i])
 
-                        # Sample a chain and save it to the vector.
-                        chains[i] = StatsBase.sample(subrng, models[id], samplers[id], N;
-                                                     progress = false, kwargs...)
-
-                        # Update the progress bar.
-                        progress && put!(channel, true)
+                Distributed.@async begin
+                    try
+                        Threads.@threads for i in 1:nchains
+                            # Obtain the ID of the current thread.
+                            id = Threads.threadid()
+
+                            # Seed the thread-specific random number generator with the pre-made seed.
+                            subrng = rngs[id]
+                            Random.seed!(subrng, seeds[i])
+
+                            if progress
+                                if callback isa Nothing
+                                    callback_i = (args...; kwargs...) -> put!(channel, i)
+                                else
+                                    callback_i = (args...; kwargs...) -> begin put!(channel, i); callback(args...; kwargs...) end
+                                end
+                            else
+                                callback_i = callback
+                            end
+
+                            # Sample a chain and save it to the vector.
+                            chains[i] = StatsBase.sample(subrng, models[id], samplers[id], N;
+                                                         progress = false, callback=callback_i, kwargs...)
+
+                            progress && put!(channel, -i)
+                        end
+                    finally
+                        # Stop updating the progress bar.
+                        progress && put!(channel, 0)
                     end
-                finally
-                    # Stop updating the progress bar.
-                    progress && put!(channel, false)
                 end
             end
         end
@@ -366,6 +400,7 @@ function mcmcsample(
     return chainsstack(tighten_eltype(chains))
 end
 
+
 function mcmcsample(
     rng::Random.AbstractRNG,
     model::AbstractModel,

From 349799146c2cfe85d24a1efbfbc6b162f8850ca4 Mon Sep 17 00:00:00 2001
From: kaandocal <26488673+kaandocal@users.noreply.github.com>
Date: Wed, 1 Sep 2021 16:23:06 +0100
Subject: [PATCH 2/2] Bugfixes

---
 src/sample.jl | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/sample.jl b/src/sample.jl
index 9af182cd..7c995bfd 100644
--- a/src/sample.jl
+++ b/src/sample.jl
@@ -286,7 +286,7 @@ function mcmcsample(
     N::Integer,
     nchains::Integer;
     progress = PROGRESS[],
-    progressname = "Sampling",
+    progressname = "Sampling ($(min(nchains, Threads.nthreads())) threads)",
     callback=nothing,
     kwargs...
 )
@@ -328,7 +328,9 @@ function mcmcsample(
         chain_names = ["$progressname (Chain $i of $nchains)" for i in 1:nchains]
         @ifwithprogresslogger_children progress nchains names=chain_names begin
             # Create a channel for progress logging.
-            channel = Channel{Int}(length(interval) * 10)
+            if progress
+                channel = Channel{Int}(length(interval) * 10)
+            end
 
             Distributed.@sync begin
                 if progress