Skip to content

Testing new progress bars #83

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 131 additions & 0 deletions src/logging.jl
Original file line number Diff line number Diff line change
@@ -43,3 +43,134 @@ function progresslogger()
return TerminalLoggers.TerminalLogger()
end
end

##

using ProgressLogging: _asprogress, ProgressString, ProgressLevel, uuid4
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO the use of these internals indicates that these extensions should be discussed and potentially be added in ProgressLogging but not in AbstractMCMC.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No doubt, I will test these extensions and may open a PR in ProgressLogging later, this is just a quick hack

using Base.Meta: isexpr

const _id_var_children = gensym(:progress_id_children)
const _name_var_children = gensym(:progress_name_children)

macro progressid_child(N)
id_err = "`@progressid_child` must be used inside `@withprogress_children`"

quote
$Base.@isdefined($_id_var_children) ? $_id_var_children[$N] : $error($id_err)
end |> esc
end

macro logprogress_child(args...)
_logprogress_child(args...)
end

function _logprogress_child(N, name, progress = nothing, args...)
name_expr = :($Base.@isdefined($_name_var_children) ? $_name_var_children[$N] : "")
if progress == nothing
# Handle: @logprogress progress
kwargs = (:(progress = $name), args...)
progress = name
name = name_expr
elseif isexpr(progress, :(=)) && progress.args[1] isa Symbol
# Handle: @logprogress progress key1=val1 ...
kwargs = (:(progress = $name), progress, args...)
progress = name
name = name_expr
else
# Otherwise, it's: @logprogress name progress key1=val1 ...
kwargs = (:(progress = $progress), args...)
end

id_err = "`@logprogress_child` must be used inside `@withprogress_children`"
id_expr = :($Base.@isdefined($_id_var_children) ? $_id_var_children[$N] : $error($id_err))

@gensym id_tmp
# Emitting progress log record as old/open API (i.e., using
# `progress` key) and _also_ as new API based on `Progress` type.
msgexpr = :($ProgressString($_asprogress(
$name,
$id_tmp,
$(ProgressLogging._id_var);
progress = $progress,
)))
quote
$id_tmp = $id_expr
$Logging.@logmsg($ProgressLevel, $msgexpr, $(kwargs...), _id = $id_tmp)
end |> esc
end

macro withprogress_children(N, exprs...)
_withprogress_children(N, exprs...)
end

function _withprogress_children(N, exprs...)
length(exprs) == 0 &&
throw(ArgumentError("`@withprogress_children` requires at least one number and one expression."))

m = ProgressLogging.@__MODULE__

kwargs = Dict{Symbol,Any}(:names => :(["" for i in 1:$N]))
unsupported = []
for kw in exprs[1:end-1]
if isexpr(kw, :(=)) && length(kw.args) == 2 && haskey(kwargs, kw.args[1])
kwargs[kw.args[1]] = kw.args[2]
else
push!(unsupported, kw)
end
end

# Error on invalid input expressions:
if !isempty(unsupported)
msg = sprint() do io
println(io, "Unsupported optional arguments:")
for kw in unsupported
println(io, kw)
end
print(io, "`@withprogress_children` supports only following keyword arguments: ")
join(io, keys(kwargs), ", ")
end
throw(ArgumentError(msg))
end

ex = exprs[end]
id_err = "`@withprogress_children` must be used inside `@withprogress`"

i_var = gensym()
quote
let
$Base.@isdefined($(ProgressLogging._id_var)) || $error($id_err)
$_id_var_children = [ $uuid4() for i in 1:$N ]
$_name_var_children = $(kwargs[:names])

for $i_var in 1:$N
@logprogress_child $i_var nothing
end

try
$ex
finally
for $i_var in 1:$N
@logprogress_child $i_var nothing
end
end
end
end |> esc
end

macro ifwithprogresslogger_children(progress, exprs...)
return quote
if $progress
if $hasprogresslevel($Logging.current_logger())
@withprogress_children $(exprs...)
else
$with_progresslogger($Base.@__MODULE__, $Logging.current_logger()) do
@withprogress_children $(exprs...)
end
end
else
$(exprs[end])
end
end |> esc
end


111 changes: 74 additions & 37 deletions src/sample.jl
Original file line number Diff line number Diff line change
@@ -275,6 +275,9 @@ function mcmcsample(
)
end


using ProgressLogging: Progress

function mcmcsample(
rng::Random.AbstractRNG,
model::AbstractModel,
@@ -284,6 +287,7 @@ function mcmcsample(
nchains::Integer;
progress = PROGRESS[],
progressname = "Sampling ($(min(nchains, Threads.nthreads())) threads)",
callback=nothing,
kwargs...
)
# Check if actually multiple threads are used.
@@ -312,51 +316,83 @@ function mcmcsample(
chains = Vector{Any}(undef, nchains)

@ifwithprogresslogger progress name=progressname begin
# Create a channel for progress logging.
if progress
channel = Channel{Bool}(length(interval))
end
threshold = nchains ÷ 200
itotals = zeros(Int, nchains)
next_updates = [ threshold for i in 1:nchains ]

Distributed.@sync begin
threshold_chains = nchains ÷ 200
nextprogress_chains = threshold_chains

progress_chains = 0

chain_names = ["$progressname (Chain $i of $nchains)" for i in 1:nchains]
@ifwithprogresslogger_children progress nchains names=chain_names begin
# Create a channel for progress logging.
if progress
# Update the progress bar.
Distributed.@async begin
# Determine threshold values for progress logging
# (one update per 0.5% of progress)
threshold = nchains ÷ 200
nextprogresschains = threshold
channel = Channel{Int}(length(interval) * 10)
end

progresschains = 0
while take!(channel)
progresschains += 1
if progresschains >= nextprogresschains
ProgressLogging.@logprogress progresschains/nchains
nextprogresschains = progresschains + threshold
Distributed.@sync begin
if progress
# Update the progress bar.
Distributed.@async begin
# Determine threshold values for progress logging
# (one update per 0.5% of progress)
threshold = nchains ÷ 200
nextprogresschains = threshold

progresschains = 0
while (i = take!(channel)) != 0
if i > 0
itotals[i] += 1
if itotals[1] >= next_updates[i]
@logprogress_child i itotals[i] / N
next_updates[i] = itotals[1] + threshold
end
else
i = -i
@logprogress_child i "done"
progresschains += 1

if progresschains >= nextprogresschains
ProgressLogging.@logprogress progresschains/nchains
nextprogresschains = progresschains + threshold_chains
end
end
end
end
end
end

Distributed.@async begin
try
Threads.@threads for i in 1:nchains
# Obtain the ID of the current thread.
id = Threads.threadid()

# Seed the thread-specific random number generator with the pre-made seed.
subrng = rngs[id]
Random.seed!(subrng, seeds[i])

# Sample a chain and save it to the vector.
chains[i] = StatsBase.sample(subrng, models[id], samplers[id], N;
progress = false, kwargs...)

# Update the progress bar.
progress && put!(channel, true)
Distributed.@async begin
try
Threads.@threads for i in 1:nchains
# Obtain the ID of the current thread.
id = Threads.threadid()

# Seed the thread-specific random number generator with the pre-made seed.
subrng = rngs[id]
Random.seed!(subrng, seeds[i])

if progress
if callback isa Nothing
callback_i = (args...; kwargs...) -> put!(channel, i)
else
callback_i = (args...; kwargs...) -> begin put!(channel, i); callback(args...; kwargs...) end
end
else
callback_i = callback
end

# Sample a chain and save it to the vector.
chains[i] = StatsBase.sample(subrng, models[id], samplers[id], N;
progress = false, callback=callback_i, kwargs...)

progress && put!(channel, -i)
end
finally
# Stop updating the progress bar.
progress && put!(channel, 0)
end
finally
# Stop updating the progress bar.
progress && put!(channel, false)
end
end
end
@@ -366,6 +402,7 @@ function mcmcsample(
return chainsstack(tighten_eltype(chains))
end


function mcmcsample(
rng::Random.AbstractRNG,
model::AbstractModel,