Skip to content

do-notation API for logging everything in one step #128

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions examples/Scalars.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,28 @@ with_logger(logger) do
@info "scalar/complex" y = z
end
end


################control step increments with context################
with_logger(logger) do
for epoch in 1:10
for i=1:100
# increments global_step by default
with_TBLogger_hold_step() do
# all of these are logged at the same global_step
# and the logger global_step is only then increased
@info "train1/scalar" val=i
@info "train2/scalar" val2=i/2
@info "train3/scalar" val3=100-i
end
end
# step increment at end can be disabled for easy train/test sync
with_TBLogger_hold_step(;step_at_end=false) do
# all of these are logged at the same global_step
# and the logger global_step is only then increased
@info "test1/scalar" epoch=epoch
@info "test2/scalar" epoch2=epoch^2
@info "test3/scalar" epoch3=epoch^3
end
end
end
50 changes: 50 additions & 0 deletions src/TBLogger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -298,3 +298,53 @@ Base.show(io::IO, mime::MIME"text/plain", tbl::TBLogger) = begin
"""
Base.print(io, str)
end

"""
`with_TBLogger_hold_step(f, [step]; step_at_end::Bool=true)`
Context function to ease control of logging steps and synchronization.
Amount of step increment can be controlled via `set_step_increment!``.

Example:
```julia
with_logger(lg) do
for epoch in 1:10
for i=1:100
# increments global_step by default
with_TBLogger_hold_step() do
# all of these are logged at the same global_step
# and the logger global_step is only then increased
@info "train1/scalar" i=i
@info "train2/scalar" i2=i^2
@info "train3/scalar" i3=i^3
end
end
# step increment at end can be disabled for easy train/test sync
with_TBLogger_hold_step(;step_at_end=false) do
# all of these are logged at the same global_step
# and the logger global_step is only then increased
@info "test1/scalar" i=i
@info "test2/scalar" i2=i^2
@info "test3/scalar" i3=i^3
end
end
end
```

"""
function with_TBLogger_hold_step(f, step::Int; step_at_end::Bool=true)
logger = CoreLogging.current_logger()
@assert logger isa TBLogger "with_TBLogger_hold_step: current logger is not a TBLogger, cannot establish current step automatically"
curr_step = logger.global_step
curr_increment = logger.step_increment
set_step!(logger, step)
set_step_increment!(logger, 0)
f()
set_step!(logger, curr_step)
set_step_increment!(logger, curr_increment)
step_at_end && increment_step!(logger, curr_increment)
end
function with_TBLogger_hold_step(f; step_at_end::Bool=true)
logger = CoreLogging.current_logger()
isa(logger, TBLogger) || error("with_TBLogger_hold_step: current logger is not a TBLogger, cannot establish current step automatically")
with_TBLogger_hold_step(f, logger.global_step; step_at_end=step_at_end)
end
2 changes: 1 addition & 1 deletion src/TensorBoardLogger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ using Base.CoreLogging: CoreLogging, AbstractLogger, LogLevel, Info,
handle_message, shouldlog, min_enabled_level, catch_exceptions, with_logger,
NullLogger

export TBLogger, reset!, set_step!, increment_step!, set_step_increment!
export TBLogger, reset!, set_step!, increment_step!, set_step_increment!, with_TBLogger_hold_step
export log_histogram, log_value, log_vector, log_text, log_image, log_images,
log_audio, log_audios, log_graph, log_embeddings, log_custom_scalar
export map_summaries, TBReader
Expand Down