@@ -60,16 +60,17 @@ class _EmptyDataLoader:
6060 """Minimal dataloader for online mode that yields empty dicts.
6161
6262 Compatible with ``cycle_dataloader()`` and ``len()`` expectations.
63-    Each "epoch" produces a single batch of ``batch_size`` empty dicts,
64-    so the training loop collects the correct number of trajectories
65-    before proceeding to a train step.
63+    ``steps_per_epoch`` controls how many steps constitute one epoch,
64+    derived from ``total_train_steps // total_train_epochs`` to ensure
65+    epoch-frequency-gated components (Saver, RecoverHandler) behave correctly.
6666 """
6767
68- def __init__ (self , batch_size : int = 1 ):
68+ def __init__ (self , batch_size : int = 1 , steps_per_epoch : int = 1 ):
6969 self .batch_size = batch_size
70+ self ._steps_per_epoch = steps_per_epoch
7071
7172 def __len__ (self ) -> int :
72- return 1 # 1 step per "epoch" for online mode
73+ return self . _steps_per_epoch
7374
7475 def __iter__ (self ):
7576 while True :
@@ -123,9 +124,26 @@ def __init__(
123124 self .train_dataset = train_dataset
124125 self .valid_dataset = valid_dataset
125126 if train_dataset is None :
126- # Online mode: use empty data generator
127+ # Online mode: require total_train_steps to compute steps_per_epoch.
128+ # Without this, __len__()=1 causes every step to be treated as an
129+ # epoch boundary, making Saver/RecoverHandler fire every step and
130+ # corrupting the LR schedule.
131+ if config .total_train_steps is None :
132+ raise ValueError (
133+ "total_train_steps must be set for online mode "
134+ "(train_dataset is None). Both total_train_epochs and "
135+ "total_train_steps are needed to compute steps_per_epoch."
136+ )
137+ steps_per_epoch = config .total_train_steps // config .total_train_epochs
138+ if steps_per_epoch < 1 :
139+ raise ValueError (
140+ f"total_train_steps ({ config .total_train_steps } ) must be >= "
141+ f"total_train_epochs ({ config .total_train_epochs } ) so that "
142+ f"steps_per_epoch >= 1."
143+ )
127144 self .train_dataloader = _EmptyDataLoader (
128- batch_size = config .train_dataset .batch_size
145+ batch_size = config .train_dataset .batch_size ,
146+ steps_per_epoch = steps_per_epoch ,
129147 )
130148 else :
131149 self .train_dataloader = self ._create_dataloader (
0 commit comments