@@ -122,26 +122,102 @@ def loop(self) -> None:
                         # receive data from producers
                         for r in range(self.num_producers):
                             print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
-                            self.buffer.extend(
-                                unbind_batch(
-                                    ray_broadcast_tensor_dict(
-                                        None, src=0, device=self.device, group_name=f"sync_data_{r}"
-                                    )
-                                )
+                            raw_batch = ray_broadcast_tensor_dict(
+                                None, src=0, device=self.device, group_name=f"sync_data_{r}"
                             )
-                        while len(self.buffer) >= self.dp_size * self.minibatch_size:
-                            batches = self.buffer[
-                                self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
+                            # calculate group reward and other metrics before filtering. Only the filtered groups are used for
+                            # training (an incomplete subset), so compute the metrics here, on the full batch, for logging
+                            # [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
+                            raw_batch_with_reward = self.calculate_reward(
+                                {k: v.view(-1, v.size(-1)) if k != "temperature" else v for k, v in raw_batch.items()}
+                            )
+                            raw_batch_with_reward = {
+                                k: v.view(-1, self.num_generations, v.size(-1)) if k != "temperature" else v
+                                for k, v in raw_batch_with_reward.items()
+                            }
+                            # [batch_size, num_generations, 1] -> [batch_size, num_generations]
+                            reward = raw_batch_with_reward["reward"][:, :, 0]
+                            format_acc = raw_batch_with_reward["format_acc"][:, :, 0]
+                            ans_acc = raw_batch_with_reward["ans_acc"][:, :, 0]
+                            response_len = (
+                                raw_batch_with_reward["response_idx"][:, :, 1]
+                                - raw_batch_with_reward["response_idx"][:, :, 0]
+                                + 1
+                            ).type(torch.float32)
+                            effective_group_mask = None
+                            if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
+                                # keep only groups whose mean answer accuracy lies inside filter_range
+                                group_ans_acc_mean = ans_acc.mean(dim=1)
+                                effective_group_mask = torch.logical_and(
+                                    group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1]
+                                )
+                            raw_batch_with_reward = unbind_batch(raw_batch_with_reward)  # List[Dict[str, torch.Tensor]]
+                            for group_idx, group_with_reward in enumerate(raw_batch_with_reward):
+                                self.buffer.append(
+                                    [
+                                        (
+                                            group_with_reward
+                                            if effective_group_mask is None or effective_group_mask[group_idx]
+                                            else None
+                                        ),
+                                        reward[group_idx],
+                                        format_acc[group_idx],
+                                        ans_acc[group_idx],
+                                        response_len[group_idx],
+                                    ]
+                                )
+                            if effective_group_mask is not None:
+                                print(
+                                    f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch_with_reward)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
+                                )
+                        # mapping the effective group to the raw group for indexing
+                        effective_group_to_raw_group_mapping = {}
+                        for buffer_idx in range(len(self.buffer)):
+                            if self.buffer[buffer_idx][0] is not None:
+                                effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
+                                    buffer_idx
+                                )
+                        print(
+                            f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
+                        )
+
+                        while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
+                            # on each dp_rank, we use minibatch_size effective samples to form a batch
+                            batches = [
+                                self.buffer[effective_group_to_raw_group_mapping[i]]
+                                for i in range(
+                                    self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size
+                                )
                             ]
-                            batch = bind_batch(batches)
+                            # every dp_rank receives a complete mini-batch, so no extra sync is needed inside step() later
+                            # each mini-batch uses the first self.dp_size * minibatch_size effective samples
+                            raw_mini_batches = self.buffer[
+                                : effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
+                            ]  # include the last effective sample
+                            raw_mini_batches_metric_dict = {
+                                "raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
+                                "raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
+                                "raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
+                                "raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
+                            }
+                            batch = bind_batch([t[0] for t in batches])
                             batch = post_recv(batch)
-                            loss, excessive_prompts_idx = self.step(i, pbar, **batch)
-
-                            if excessive_prompts_idx is not None:
-                                excessive_prompts = [self.buffer[idx] for idx in excessive_prompts_idx]
-                                self.buffer = excessive_prompts + self.buffer[self.dp_size * self.minibatch_size :]
-                            else:
-                                self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
+                            loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
+                            self.buffer = self.buffer[
+                                effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
+                            ]
+                            # recalculate the effective group to raw group mapping
+                            effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
+                            effective_group_to_raw_group_mapping = {}
+                            for buffer_idx in range(len(self.buffer)):
+                                if self.buffer[buffer_idx][0] is not None:
+                                    effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
+                                        buffer_idx
+                                    )
+                            assert (
+                                len(effective_group_to_raw_group_mapping)
+                                == effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
+                            )
                             if loss is not None:
                                 pbar.set_postfix({"loss": loss})
                             i += 1
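
For reference, a minimal standalone sketch of the group-filtering step introduced above. This is not part of the commit: the toy shapes, the filter_range bounds, and the simplified buffer entries are assumptions made only for illustration, mirroring self.filter_range and the [group, reward, format_acc, ans_acc, response_len] entries appended to self.buffer in the diff.

import torch

# assumed toy dimensions: each prompt ("group") has num_generations sampled answers
num_groups, num_generations = 4, 8
filter_range = (0.05, 0.95)  # assumed bounds standing in for self.filter_range

# ans_acc[i, j] = 1.0 if generation j of group i is judged correct, else 0.0
ans_acc = (torch.rand(num_groups, num_generations) > 0.5).float()

# a group is "effective" when its mean answer accuracy lies strictly inside filter_range,
# i.e. the prompt is neither trivially solved nor hopeless for the current policy
group_ans_acc_mean = ans_acc.mean(dim=1)  # [num_groups]
effective_group_mask = torch.logical_and(
    group_ans_acc_mean > filter_range[0], group_ans_acc_mean < filter_range[1]
)

# filtered groups are stored as None so the buffer keeps one slot per raw group,
# while the per-group metrics are kept for logging regardless of filtering
buffer = []
for group_idx in range(num_groups):
    payload = {"ans_acc": ans_acc[group_idx]} if effective_group_mask[group_idx] else None
    buffer.append([payload, group_ans_acc_mean[group_idx]])

print(f"kept {int(effective_group_mask.sum())}/{num_groups} groups for training")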
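
The buffer bookkeeping in the new while-loop can be sketched the same way. Again, this is not the commit's code: dp_size, dp_rank, minibatch_size and the string placeholders are toy values, and build_mapping is a hypothetical helper standing in for the inline mapping reconstruction in the diff.

def build_mapping(buffer):
    # map the i-th effective (non-None) group to its index in the raw buffer
    mapping = {}
    for buffer_idx, entry in enumerate(buffer):
        if entry[0] is not None:
            mapping[len(mapping)] = buffer_idx
    return mapping

dp_size, dp_rank, minibatch_size = 2, 1, 2
# each buffer entry is [group_or_None, metric]; None marks a filtered-out group
buffer = [["g0", 0.5], [None, 1.0], ["g2", 0.4], ["g3", 0.6], [None, 0.0], ["g5", 0.3]]

mapping = build_mapping(buffer)
while len(mapping) >= dp_size * minibatch_size:
    # this rank picks its own minibatch_size effective groups
    batches = [
        buffer[mapping[i]]
        for i in range(dp_rank * minibatch_size, (dp_rank + 1) * minibatch_size)
    ]
    # drop everything up to and including the last effective group consumed by any rank,
    # then rebuild the mapping over the shrunken buffer
    last_consumed = mapping[dp_size * minibatch_size - 1]
    buffer = buffer[last_consumed + 1 :]
    mapping = build_mapping(buffer)
    print("rank", dp_rank, "trains on", [b[0] for b in batches], "| groups left:", len(buffer))

Keeping filtered groups as None placeholders is what lets the raw_train_mini_batch_* metric lists be sliced off the front of the buffer before it is trimmed, so logging still reflects every received group rather than only the ones that survive filtering.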