@@ -1,14 +1,12 @@
 from contextlib import nullcontext
-from typing import Any, Dict, Optional
+from typing import Any, Optional

 import ray
 import torch
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
-from coati.distributed.reward.verifiable_reward import VerifiableReward
-from coati.distributed.utils import calc_action_log_probs
+from coati.distributed.utils import memory_efficient_logprob
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer

@@ -40,6 +38,8 @@ def __init__(
         project_name: str = None,
         run_name: str = None,
         wandb_group_name: str = None,
+        enable_profiling: bool = False,
+        n_behind: int = 0,
     ):
         print(f"Using GRPO config: {grpo_config}")
         if (
@@ -65,6 +65,8 @@ def __init__(
             minibatch_size,
             save_interval=save_interval,
             save_dir=save_dir,
+            enable_profiling=enable_profiling,
+            n_behind=n_behind,
         )
         path = model_config.pop("path")
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -119,20 +121,7 @@ def __init__(
119121 "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
120122 )
121123 # Initialize verifiable reward.
122- response_format_tags = grpo_config .get ("response_format_tags" , None )
123- reward_model_kwargs = {
124- k : v
125- for k , v in grpo_config .items ()
126- if k in ["soft_over_length_punishment" , "max_new_tokens" , "cache_length" ]
127- }
128- self .reward_model = VerifiableReward (
129- reward_fns = [
130- math_reward_fn if grpo_config .get ("reward_fn_type" ) == "think_answer_tags" else boxed_math_reward_fn
131- ],
132- tokenizer = self .tokenizer ,
133- tags = response_format_tags ,
134- ** reward_model_kwargs ,
135- )
124+ grpo_config .get ("response_format_tags" , None )
136125 self .global_step = 0
137126
138127 self .lr_scheduler = CosineAnnealingWarmupLR (
@@ -295,12 +284,11 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                             )

                         if self.booster.plugin.stage_manager.is_last_stage():
-                            reference_model_logits = reference_model_outputs["outputs"]["logits"]
-                            reference_action_log_probs = calc_action_log_probs(
-                                reference_model_logits / self.generate_config["temperature"],
+                            reference_action_log_probs = memory_efficient_logprob(
+                                reference_model_outputs["outputs"]["logits"],
                                 input_ids_forward_micro_batch,
                                 num_action,
-                                self.plugin.shard_config,
+                                shard_config=self.plugin.shard_config,
                             )
                         else:
                             # Dummy reference logprobs for data iterator.
@@ -323,11 +311,11 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:

                     def _criterion(outputs, inputs):
                         action_logits = outputs.logits
-                        action_log_probs = calc_action_log_probs(
-                            action_logits / self.generate_config["temperature"],
+                        action_log_probs = memory_efficient_logprob(
+                            action_logits,
                             inputs["input_ids"],
                             num_action,
-                            self.plugin.shard_config,
+                            shard_config=self.plugin.shard_config,
                         )
                         if "reference_action_log_probs" in inputs:
                             per_token_kl = (
@@ -370,16 +358,15 @@ def _criterion(outputs, inputs):
                             mean_kl.append(kl)
                         mean_loss.append(all_reduce_mean(loss, self.plugin).data)
                 else:
-
                     policy_model_logits = self.policy_model(
                         input_ids=input_ids_forward_micro_batch,
                         attention_mask=attention_mask_forward_micro_batch,
                     ).logits
-                    action_log_probs = calc_action_log_probs(
+                    action_log_probs = memory_efficient_logprob(
                         policy_model_logits / self.generate_config["temperature"],
                         input_ids_forward_micro_batch,
                         num_action,
-                        self.plugin.shard_config,
+                        shard_config=self.plugin.shard_config,
                     )

                     if self.policy_loss_fn.beta > 0:
@@ -388,11 +375,11 @@ def _criterion(outputs, inputs):
                                 input_ids=input_ids_forward_micro_batch,
                                 attention_mask=attention_mask_forward_micro_batch,
                             ).logits
-                        reference_action_log_probs = calc_action_log_probs(
+                        reference_action_log_probs = memory_efficient_logprob(
                             reference_model_logits / self.generate_config["temperature"],
                             input_ids_forward_micro_batch,
                             num_action,
-                            self.plugin.shard_config,
+                            shard_config=self.plugin.shard_config,
                         )
                         per_token_kl = (
                             torch.exp(reference_action_log_probs - action_log_probs)
@@ -498,40 +485,6 @@ def _criterion(outputs, inputs):
         else:
             return None

-    def calculate_reward(self, rollout: Dict[str, Any]) -> Dict[str, Any]:
-        """
-        Calculate the group reward for the given rollout group.
-
-        Args:
-            rollout_group (Dict[str, Any]):
-                a group of samples generated by the model from the same prompt
-                contain the following keys:
-                "input_ids": torch.Tensor, [num_of_generation, prompt_length + response_length]
-                "attention_mask": torch.Tensor, [num_of_generation, prompt_length + response_length]
-                "action_mask": torch.Tensor, [num_of_generation, response_length]
-                "action_log_probs": torch.Tensor, [num_of_generation, response_length]
-                "response_idx": int, torch.Tensor, [num_of_generation, 2]
-                "gt_answer": torch.Tensor, [num_of_generation, 128]
-                "temperature": torch.Tensor, [] (scalar)
-
-        Returns:
-            Dict[str, Any]: The new group data with calculated reward.
-        """
-        reward_model_output = self.reward_model(
-            rollout["input_ids"],
-            gt_answer=rollout["gt_answer"],
-            response_idx=rollout["response_idx"],
-        )
-        # [num_of_generation]
-        reward = torch.tensor([value[0] for value in reward_model_output]).to(rollout["input_ids"].device)
-        format_acc = torch.tensor([value[1] for value in reward_model_output]).to(rollout["input_ids"].device)
-        ans_acc = torch.tensor([value[2] for value in reward_model_output]).to(rollout["input_ids"].device)
-
-        rollout["reward"] = reward.view((-1, 1))
-        rollout["format_acc"] = format_acc.view((-1, 1))
-        rollout["ans_acc"] = ans_acc.view((-1, 1))
-        return rollout
-
     def state_dict(self):
         self.policy_model._force_wait_all_gather()
         model = self.policy_model.unwrap()
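
Note (not part of the commit): the diff replaces every call to calc_action_log_probs with memory_efficient_logprob and passes shard_config as a keyword argument; temperature scaling, where still needed, stays with the caller (as in the non-PP branch above). As a rough illustration only, a memory-efficient per-token log-prob helper typically avoids materialising the full-vocabulary log-softmax for the whole response at once by processing it in chunks. The sketch below is an assumption about that general technique, not the actual implementation in coati.distributed.utils; the function name, signature, and the absence of shard_config/tensor-parallel handling are all hypothetical.

import torch

def chunked_action_log_probs(
    logits: torch.Tensor,      # [B, S, V] logits for the full sequence
    input_ids: torch.Tensor,   # [B, S] token ids fed to the model
    num_action: int,           # length of the response (action) segment
    chunk_size: int = 1024,    # action tokens processed per chunk
) -> torch.Tensor:
    """Log-probs of the last `num_action` tokens, shape [B, num_action]."""
    # Logits at position t predict token t + 1, hence the one-step shift.
    action_logits = logits[:, -num_action - 1 : -1, :]
    action_ids = input_ids[:, -num_action:]
    chunks = []
    for start in range(0, num_action, chunk_size):
        end = min(start + chunk_size, num_action)
        # log_softmax over a small slice keeps peak activation memory bounded.
        logp = torch.log_softmax(action_logits[:, start:end, :].float(), dim=-1)
        chunks.append(logp.gather(-1, action_ids[:, start:end].unsqueeze(-1)).squeeze(-1))
    return torch.cat(chunks, dim=1)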