@@ -298,3 +298,53 @@ Base.show(io::IO, mime::MIME"text/plain", tbl::TBLogger) = begin
298
298
"""
299
299
Base. print (io, str)
300
300
end
301
+
302
+ """
303
+ `with_TBLogger_hold_step(f, [step]; step_at_end::Bool=true)`
304
+ Context function to ease control of logging steps and synchronization.
305
+ Amount of step increment can be controlled via `set_step_increment!``.
306
+
307
+ Example:
308
+ ```julia
309
+ with_logger(lg) do
310
+ for epoch in 1:10
311
+ for i=1:100
312
+ # increments global_step by default
313
+ with_TBLogger_hold_step() do
314
+ # all of these are logged at the same global_step
315
+ # and the logger global_step is only then increased
316
+ @info "train1/scalar" i=i
317
+ @info "train2/scalar" i2=i^2
318
+ @info "train3/scalar" i3=i^3
319
+ end
320
+ end
321
+ # step increment at end can be disabled for easy train/test sync
322
+ with_TBLogger_hold_step(;step_at_end=false) do
323
+ # all of these are logged at the same global_step
324
+ # and the logger global_step is only then increased
325
+ @info "test1/scalar" i=i
326
+ @info "test2/scalar" i2=i^2
327
+ @info "test3/scalar" i3=i^3
328
+ end
329
+ end
330
+ end
331
+ ```
332
+
333
+ """
334
+ function with_TBLogger_hold_step (f, step:: Int ; step_at_end:: Bool = true )
335
+ logger = CoreLogging. current_logger ()
336
+ @assert logger isa TBLogger " with_TBLogger_hold_step: current logger is not a TBLogger, cannot establish current step automatically"
337
+ curr_step = logger. global_step
338
+ curr_increment = logger. step_increment
339
+ set_step! (logger, step)
340
+ set_step_increment! (logger, 0 )
341
+ f ()
342
+ set_step! (logger, curr_step)
343
+ set_step_increment! (logger, curr_increment)
344
+ step_at_end && increment_step! (logger, curr_increment)
345
+ end
346
+ function with_TBLogger_hold_step (f; step_at_end:: Bool = true )
347
+ logger = CoreLogging. current_logger ()
348
+ isa (logger, TBLogger) || error (" with_TBLogger_hold_step: current logger is not a TBLogger, cannot establish current step automatically" )
349
+ with_TBLogger_hold_step (f, logger. global_step; step_at_end= step_at_end)
350
+ end
0 commit comments