From c5f4b5ce1e0a1e4e49a2e0c26aaf5c41dba7fdde Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 25 Aug 2025 20:20:03 -0400 Subject: [PATCH 01/74] . --- test_single_controller_ppo.py | 5 +++++ yamls/single-controller-grpo-workflow.yaml | 1 + 2 files changed, 6 insertions(+) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index dee5593c..c6cbbd83 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -156,6 +156,7 @@ def __init__( self.ref_model_config = None self.global_train_batch_size = None self.max_gen_len = None + self.loss_type = None # KL Penalty and Controller self.kl_ift = [] @@ -176,6 +177,10 @@ def build_train_config(self, config: Any): self.model_config = om.to_container(self.config.model, resolve=True) self.model_config['tokenizer'] = self.tokenizer + self.loss_type = self.model_config.get('loss_type', OnPolicyEnum.GRPO) + print("--------------------------------") + print(f'loss_type: {self.loss_type}') + print("--------------------------------") # Reference Model Initializing self.ref_model_config = om.to_container(self.config.variables.reference_model.model_config, resolve=True) diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index ac59150a..29b66f22 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -37,6 +37,7 @@ parameters: pretrained: true init_device: mixed kl_estimator: k3 + beta: 1e-3 kl_clip_range: 40 use_auth_token: true compute_kl_loss: false From 60f016ad8250198fc648b32e542730b81bd71463 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 25 Aug 2025 20:22:00 -0400 Subject: [PATCH 02/74] . --- test_single_controller_ppo.py | 2 ++ yamls/single-controller-grpo-workflow.yaml | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index c6cbbd83..22f644d8 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -88,6 +88,8 @@ RewardOutput, ) +from compose_rl.algorithms.online.model_methods import OnPolicyEnum + @contextmanager def time_it(name: str): diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index 29b66f22..b9238551 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -14,7 +14,7 @@ integrations: path: /workspace/compose-rl git_repo: databricks/compose-rl ssh_clone: true - git_branch: single-controller-hackathon + git_branch: single-controller-hackathon-smd #single-controller-hackathon - integration_type: git_repo path: /workspace/research-universe git_repo: databricks-mosaic/research-universe From 50b3a7094410f069057e2d79c7b8a2d675772104 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 25 Aug 2025 20:31:54 -0400 Subject: [PATCH 03/74] . 
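The diff below makes compute_advantages branch on loss_type: GRPO keeps its group-normalized, action-masked per-token advantages, while SMD keeps one unnormalized scalar advantage per prompt. A minimal sketch of the two shapes, assuming the same flat_rewards / mean_rewards / std_rewards and action_mask tensors as in the patch (the helper functions themselves are illustrative, not part of the repo):

    import torch

    def grpo_token_advantages(flat_rewards, mean_rewards, std_rewards,
                              action_mask, normalize_advantage=True, eps=1e-4):
        # Group-relative advantage, optionally normalized, then broadcast over
        # generated tokens and zeroed outside the action mask.
        adv = flat_rewards - mean_rewards                    # (bs,)
        if normalize_advantage:
            adv = adv / (std_rewards + eps)
        expanded = adv.unsqueeze(1).expand_as(action_mask)   # (bs, max_gen_len)
        return torch.where(action_mask.bool(), expanded,
                           torch.zeros_like(expanded))

    def smd_prompt_advantages(flat_rewards, mean_rewards):
        # SMD keeps the raw per-prompt advantage; no std normalization and no
        # per-token expansion is needed for the regression-style loss.
        return flat_rewards - mean_rewards                   # (bs,)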
--- test_single_controller_ppo.py | 53 ++++++++++++++-------- yamls/single-controller-grpo-workflow.yaml | 4 +- 2 files changed, 37 insertions(+), 20 deletions(-) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index 22f644d8..0f3b02d8 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -815,25 +815,42 @@ def compute_advantages(self, batch: dict[str, Any], reward_output: dict[str, Any # Calculate GRPO advantage grpo_advantage = (flat_rewards - mean_rewards) # Only normalize the advantage if flag is set - if self.model_config['normalize_advantage']: # type: ignore - grpo_advantage /= (std_rewards + 1e-4) - - # Create advantages of the same shape as original rewards - advantages = torch.zeros_like(rewards) - # Copy the flat grpo_advantage according to action_mask - expanded_advantages = grpo_advantage.unsqueeze(1).expand_as( - batch['action_mask'], - ) - advantages = torch.where( - batch['action_mask'].bool(), - expanded_advantages, - advantages, - ) + if self.loss_type == OnPolicyEnum.GRPO: + if self.model_config['normalize_advantage']: # type: ignore + grpo_advantage /= (std_rewards + 1e-4) + + # Create advantages of the same shape as original rewards + advantages = torch.zeros_like(rewards) + # Copy the flat grpo_advantage according to action_mask + expanded_advantages = grpo_advantage.unsqueeze(1).expand_as( + batch['action_mask'], + ) + advantages = torch.where( + batch['action_mask'].bool(), + expanded_advantages, + advantages, + ) - batch_adv_mean, batch_adv_var = dist_compute_masked_mean_and_var( - advantages, - batch['action_mask'], - ) + batch_adv_mean, batch_adv_var = dist_compute_masked_mean_and_var( + advantages, + batch['action_mask'], + ) + print("-----------------------------------------------") + print(grpo_advantage.shape) + print(batch_adv_mean) + print(batch_adv_var) + print("-----------------------------------------------") + + elif self.loss_type == OnPolicyEnum.SMD: + advantages = grpo_advantage + batch_adv_mean = torch.mean(advantages) + batch_adv_var = torch.std(advantages)**2 + print("-----------------------------------------------") + print(grpo_advantage.shape) + print(batch_adv_mean) + print(batch_adv_var) + print("-----------------------------------------------") + advantage_output = { 'advantages': advantages, diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index b9238551..8d4708cd 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -59,9 +59,9 @@ parameters: ppo: {} orl_eval: evals: - - name: gsm8k + #- name: gsm8k - name: math_500 - - name: math_hard + #- name: math_hard eval_overrides: generation_params: max_tokens: 8192 From bed11799909b7852194c97ba3406771ec720aa5e Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 25 Aug 2025 20:34:06 -0400 Subject: [PATCH 04/74] . 
--- test_single_controller_ppo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index 0f3b02d8..c5d732fb 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -835,7 +835,7 @@ def compute_advantages(self, batch: dict[str, Any], reward_output: dict[str, Any advantages, batch['action_mask'], ) - print("-----------------------------------------------") + print("-----------------------GRPO------------------------") print(grpo_advantage.shape) print(batch_adv_mean) print(batch_adv_var) @@ -845,7 +845,7 @@ def compute_advantages(self, batch: dict[str, Any], reward_output: dict[str, Any advantages = grpo_advantage batch_adv_mean = torch.mean(advantages) batch_adv_var = torch.std(advantages)**2 - print("-----------------------------------------------") + print("------------------------SMD-----------------------") print(grpo_advantage.shape) print(batch_adv_mean) print(batch_adv_var) From c69e228816448b3a00f22e9e3b9c1363415ac428 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 25 Aug 2025 21:03:53 -0400 Subject: [PATCH 05/74] . --- test_single_controller_ppo.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index c5d732fb..31879dc8 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -850,6 +850,8 @@ def compute_advantages(self, batch: dict[str, Any], reward_output: dict[str, Any print(batch_adv_mean) print(batch_adv_var) print("-----------------------------------------------") + else: + raise ValueError(f"Unsupported loss_type: {self.loss_type}") advantage_output = { From 665b80c792405e100888b49d753f50c8ae317194 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 25 Aug 2025 21:26:14 -0400 Subject: [PATCH 06/74] . 
--- test_single_controller_ppo.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index 31879dc8..14853387 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -815,12 +815,15 @@ def compute_advantages(self, batch: dict[str, Any], reward_output: dict[str, Any # Calculate GRPO advantage grpo_advantage = (flat_rewards - mean_rewards) # Only normalize the advantage if flag is set + + advantages = torch.zeros_like(rewards) + batch_adv_mean = torch.tensor(0.0) + batch_adv_var = torch.tensor(0.0) + if self.loss_type == OnPolicyEnum.GRPO: if self.model_config['normalize_advantage']: # type: ignore grpo_advantage /= (std_rewards + 1e-4) - # Create advantages of the same shape as original rewards - advantages = torch.zeros_like(rewards) # Copy the flat grpo_advantage according to action_mask expanded_advantages = grpo_advantage.unsqueeze(1).expand_as( batch['action_mask'], @@ -841,15 +844,15 @@ def compute_advantages(self, batch: dict[str, Any], reward_output: dict[str, Any print(batch_adv_var) print("-----------------------------------------------") - elif self.loss_type == OnPolicyEnum.SMD: - advantages = grpo_advantage - batch_adv_mean = torch.mean(advantages) - batch_adv_var = torch.std(advantages)**2 - print("------------------------SMD-----------------------") - print(grpo_advantage.shape) - print(batch_adv_mean) - print(batch_adv_var) - print("-----------------------------------------------") + #elif self.loss_type == OnPolicyEnum.SMD: + #advantages = grpo_advantage + #batch_adv_mean = torch.mean(advantages) + #batch_adv_var = torch.std(advantages)**2 + #print("------------------------SMD-----------------------") + #print(grpo_advantage.shape) + #print(batch_adv_mean) + #print(batch_adv_var) + #print("-----------------------------------------------") else: raise ValueError(f"Unsupported loss_type: {self.loss_type}") From 8ae2b017853c7068960af40ef660eb4b4e76e9d4 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 25 Aug 2025 21:37:41 -0400 Subject: [PATCH 07/74] . 
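The refactor below funnels both branches through a single dist_compute_masked_mean_and_var call for the logged statistics. A single-process sketch of that statistic, assuming the real helper additionally all-reduces the masked sums across ranks (omitted here):

    import torch

    def masked_mean_and_var(values, mask):
        # Mean and (biased) variance over the masked entries only.
        mask = mask.float()
        n = mask.sum().clamp(min=1.0)
        mean = (values * mask).sum() / n
        var = (((values - mean) ** 2) * mask).sum() / n
        return mean, var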
--- test_single_controller_ppo.py | 47 +++++++++++------------------------ 1 file changed, 15 insertions(+), 32 deletions(-) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index 14853387..54acd79d 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -815,47 +815,30 @@ def compute_advantages(self, batch: dict[str, Any], reward_output: dict[str, Any # Calculate GRPO advantage grpo_advantage = (flat_rewards - mean_rewards) # Only normalize the advantage if flag is set + if self.model_config['normalize_advantage'] and self.loss_type == OnPolicyEnum.GRPO: # type: ignore + grpo_advantage /= (std_rewards + 1e-4) - advantages = torch.zeros_like(rewards) - batch_adv_mean = torch.tensor(0.0) - batch_adv_var = torch.tensor(0.0) - + # Create advantages of the same shape as original rewards + advantages = torch.zeros_like(rewards) #(bs, max_gen_len) + # Copy the flat grpo_advantage according to action_mask + expanded_advantages = grpo_advantage.unsqueeze(1).expand_as( + batch['action_mask'], + ) #(bs, max_gen_len) if self.loss_type == OnPolicyEnum.GRPO: - if self.model_config['normalize_advantage']: # type: ignore - grpo_advantage /= (std_rewards + 1e-4) - - # Copy the flat grpo_advantage according to action_mask - expanded_advantages = grpo_advantage.unsqueeze(1).expand_as( - batch['action_mask'], - ) advantages = torch.where( batch['action_mask'].bool(), expanded_advantages, advantages, ) - - batch_adv_mean, batch_adv_var = dist_compute_masked_mean_and_var( - advantages, - batch['action_mask'], - ) - print("-----------------------GRPO------------------------") - print(grpo_advantage.shape) - print(batch_adv_mean) - print(batch_adv_var) - print("-----------------------------------------------") - - #elif self.loss_type == OnPolicyEnum.SMD: - #advantages = grpo_advantage - #batch_adv_mean = torch.mean(advantages) - #batch_adv_var = torch.std(advantages)**2 - #print("------------------------SMD-----------------------") - #print(grpo_advantage.shape) - #print(batch_adv_mean) - #print(batch_adv_var) - #print("-----------------------------------------------") + elif self.loss_type == OnPolicyEnum.SMD: + advantages = expanded_advantages else: raise ValueError(f"Unsupported loss_type: {self.loss_type}") - + + batch_adv_mean, batch_adv_var = dist_compute_masked_mean_and_var( + advantages, + batch['action_mask'], + ) advantage_output = { 'advantages': advantages, From 752995b76d1d34d51b59820e133fd0f69520304f Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 25 Aug 2025 21:49:56 -0400 Subject: [PATCH 08/74] . 
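The "branch-free" version below mixes the two candidate advantage tensors with is_grpo / is_smd indicator weights. Note that loss_type is a plain Python enum, so the equality check yields a bool rather than a tensor; a sketch of the same idea with the flags lifted to floats (later patches drop the trick and return to an explicit if/elif):

    import torch
    from enum import Enum

    class OnPolicyEnum(Enum):
        GRPO = 'grpo'
        SMD = 'smd'

    loss_type = OnPolicyEnum.SMD
    grpo_advantages = torch.zeros(2, 4)   # stand-ins for the two candidates
    smd_advantages = torch.ones(2, 4)

    # bool -> float scalar, then an indicator-weighted mix of the two tensors.
    is_grpo = float(loss_type == OnPolicyEnum.GRPO)
    is_smd = float(loss_type == OnPolicyEnum.SMD)
    advantages = is_grpo * grpo_advantages + is_smd * smd_advantages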
--- test_single_controller_ppo.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index 54acd79d..3f704846 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -824,16 +824,22 @@ def compute_advantages(self, batch: dict[str, Any], reward_output: dict[str, Any expanded_advantages = grpo_advantage.unsqueeze(1).expand_as( batch['action_mask'], ) #(bs, max_gen_len) - if self.loss_type == OnPolicyEnum.GRPO: - advantages = torch.where( - batch['action_mask'].bool(), - expanded_advantages, - advantages, - ) - elif self.loss_type == OnPolicyEnum.SMD: - advantages = expanded_advantages - else: - raise ValueError(f"Unsupported loss_type: {self.loss_type}") + + # Branch-free approach: Use mathematical operations instead of conditionals + # For GRPO: apply action_mask, For SMD: use raw advantages + is_grpo = (self.loss_type == OnPolicyEnum.GRPO).float() + is_smd = (self.loss_type == OnPolicyEnum.SMD).float() + + # GRPO path: mask with action_mask, SMD path: use raw expanded_advantages + grpo_advantages = torch.where( + batch['action_mask'].bool(), + expanded_advantages, + advantages, + ) + smd_advantages = expanded_advantages + + # Combine both approaches without branching + advantages = is_grpo * grpo_advantages + is_smd * smd_advantages batch_adv_mean, batch_adv_var = dist_compute_masked_mean_and_var( advantages, From 83e8a8f468ce17beaedb8cb72511c099fbfa57a6 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 25 Aug 2025 21:58:13 -0400 Subject: [PATCH 09/74] . --- test_single_controller_ppo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index 3f704846..05a461a6 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -815,7 +815,7 @@ def compute_advantages(self, batch: dict[str, Any], reward_output: dict[str, Any # Calculate GRPO advantage grpo_advantage = (flat_rewards - mean_rewards) # Only normalize the advantage if flag is set - if self.model_config['normalize_advantage'] and self.loss_type == OnPolicyEnum.GRPO: # type: ignore + if self.model_config['normalize_advantage']: # type: ignore grpo_advantage /= (std_rewards + 1e-4) # Create advantages of the same shape as original rewards From 4a3a1ac845701e9d535756946c509256633a7dd6 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 25 Aug 2025 22:08:14 -0400 Subject: [PATCH 10/74] . 
--- test_single_controller_ppo.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index 05a461a6..22f644d8 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -819,27 +819,16 @@ def compute_advantages(self, batch: dict[str, Any], reward_output: dict[str, Any grpo_advantage /= (std_rewards + 1e-4) # Create advantages of the same shape as original rewards - advantages = torch.zeros_like(rewards) #(bs, max_gen_len) + advantages = torch.zeros_like(rewards) # Copy the flat grpo_advantage according to action_mask expanded_advantages = grpo_advantage.unsqueeze(1).expand_as( batch['action_mask'], - ) #(bs, max_gen_len) - - # Branch-free approach: Use mathematical operations instead of conditionals - # For GRPO: apply action_mask, For SMD: use raw advantages - is_grpo = (self.loss_type == OnPolicyEnum.GRPO).float() - is_smd = (self.loss_type == OnPolicyEnum.SMD).float() - - # GRPO path: mask with action_mask, SMD path: use raw expanded_advantages - grpo_advantages = torch.where( + ) + advantages = torch.where( batch['action_mask'].bool(), expanded_advantages, advantages, ) - smd_advantages = expanded_advantages - - # Combine both approaches without branching - advantages = is_grpo * grpo_advantages + is_smd * smd_advantages batch_adv_mean, batch_adv_var = dist_compute_masked_mean_and_var( advantages, From 0c451134c96d26c9d072a97d80a57fb0e8459c62 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 25 Aug 2025 22:25:21 -0400 Subject: [PATCH 11/74] . --- test_single_controller_ppo.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index 22f644d8..f3ed0ae6 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -837,6 +837,7 @@ def compute_advantages(self, batch: dict[str, Any], reward_output: dict[str, Any advantage_output = { 'advantages': advantages, + 'flat_advantages': grpo_advantage, 'adv_masked_mean': torch.ones(bs) * batch_adv_mean.cpu(), 'adv_masked_var': torch.ones(bs) * batch_adv_var.cpu(), 'reward_std': torch.ones(bs) * rewards.std().to('cpu'), From 689fd19294f44dc95e964450b9366abb1a8e89c8 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 25 Aug 2025 22:47:00 -0400 Subject: [PATCH 12/74] . --- compose_rl/algorithms/online/model_methods.py | 10 ++++++++-- test_single_controller_ppo.py | 2 +- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index acd00f4e..d42da572 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -15,14 +15,16 @@ class OnPolicyEnum(Enum): PPO = 'ppo' GRPO = 'grpo' APO = 'apo' #add A-star PO + SMD = 'smd' # SMD class ALGORITHM_TYPE(set, Enum): - CRITIC_FREE = {OnPolicyEnum.GRPO, OnPolicyEnum.APO} + CRITIC_FREE = {OnPolicyEnum.GRPO, OnPolicyEnum.APO, OnPolicyEnum.SMD} ACTOR_CRITIC = {OnPolicyEnum.PPO} CLIPPED_PG = {OnPolicyEnum.PPO, OnPolicyEnum.GRPO} REGRESSION = { OnPolicyEnum.APO, + OnPolicyEnum.SMD, } @@ -489,8 +491,12 @@ def online_rl_loss( return_dict = {} advantages = None - if loss_type not in ALGORITHM_TYPE.REGRESSION: + if loss_type not in ALGORITHM_TYPE.REGRESSION: #GRPO and PPO: advantages = batch['advantages'] + assert advantages.dim() == 2 #(bs, max_gen_len) + elif loss_type == OnPolicyEnum.SMD: + advantages = batch['prompt_advantages'] + assert advantages.dim() == 1 #(bs,) # 1. 
Critic Loss if loss_type in ALGORITHM_TYPE.ACTOR_CRITIC: diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index f3ed0ae6..012b40b8 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -837,7 +837,7 @@ def compute_advantages(self, batch: dict[str, Any], reward_output: dict[str, Any advantage_output = { 'advantages': advantages, - 'flat_advantages': grpo_advantage, + 'prompt_advantages': grpo_advantage, 'adv_masked_mean': torch.ones(bs) * batch_adv_mean.cpu(), 'adv_masked_var': torch.ones(bs) * batch_adv_var.cpu(), 'reward_std': torch.ones(bs) * rewards.std().to('cpu'), From f27e5e8e12c8ba0d84ab096a39d4398b72e76930 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 25 Aug 2025 23:08:49 -0400 Subject: [PATCH 13/74] . --- compose_rl/algorithms/online/model_methods.py | 48 +++++++++++++------ 1 file changed, 33 insertions(+), 15 deletions(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index d42da572..a96bc5fb 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -237,6 +237,7 @@ def policy_loss( if loss_type in ALGORITHM_TYPE.CLIPPED_PG: assert advantages is not None + assert advantages.dim() == 2 #(bs, max_gen_len) online_log_probs, old_log_probs = outputs['online_log_probs'], batch[ 'old_log_probs'] old_entropies = batch['old_entropies'] @@ -394,51 +395,68 @@ def policy_loss( return policy_dict elif loss_type in ALGORITHM_TYPE.REGRESSION: - #assume batch contains (1) V-star values (key 'vstar), (2) rewards (key 'rewards'), (3) ref_log_probs + # current it only supports SMD + # TODO: add APO support + assert advantages is not None + assert advantages.dim() == 1 # (bs,) + online_log_probs = outputs['online_log_probs'] ref_log_probs = batch['ift_log_probs'] - log_probs_diff = online_log_probs - ref_log_probs old_entropies = batch['old_entropies'] + old_log_probs = batch['old_log_probs'] + old_log_probs_diff = old_log_probs - ref_log_probs + #compute KL to pi_ref to keep track the divergence to \pi_ref policy_kl_dict = utils.approx_kl( log_p=ref_log_probs, log_q=online_log_probs, #log_q - log_p = log pi - log pi_ref kl_clip_range=kl_clip_range, ) + old_policy_kl_dict = utils.approx_kl( + log_p=old_log_probs, + log_q=online_log_probs, #log_q - log_p = log pi - log pi_ref + kl_clip_range=kl_clip_range, + ) with torch.no_grad(): policy_kl = utils.masked_mean( policy_kl_dict[kl_estimator], # pyright: ignore batch['action_mask'], ) #plain average over all tokens (KL to pi_ref) + old_policy_kl = utils.masked_mean( + old_policy_kl_dict[kl_estimator], # pyright: ignore + batch['action_mask'], + ) #plain average over all tokens (KL to pi_ref) #compute the policy loss - masked_log_probs_diff = utils.masked_sum( - log_probs_diff, - batch['action_mask'], - dim=-1, - ) #size: (batch_size,) - vstars = batch['vstar'] + if loss_type == OnPolicyEnum.SMD: + masked_log_probs_diff = utils.masked_sum( + old_log_probs_diff, + batch['action_mask'], + dim=-1, + ) #size: (batch_size,) + else: + raise ValueError(f'RegressionPolicy loss not implemented for {loss_type}') + + policy_loss = ((beta * masked_log_probs_diff -advantages)**2).mean() + rewards = utils.masked_sum( batch['rewards'], batch['action_mask'], dim=-1, ) - assert vstars.size() == rewards.size() == masked_log_probs_diff.size( - ) # should have the same shape which is (batch_size, ) - policy_loss = ((beta * masked_log_probs_diff - - (rewards - vstars))**2).mean() policy_dict = { 
'loss/policy_loss': policy_loss, - 'kl/policy_kl': policy_kl, + 'kl/ref_policy_kl': policy_kl, + 'kl/old_policy_kl': old_policy_kl, 'gen/gen_length': batch['action_mask'].sum(dim=1).to(torch.float32), 'gen/entropy': old_entropies, 'rewards/mean': torch.mean( rewards, ), #compute the average reward of the current batch - 'vstars/mean': torch.mean( - vstars, + 'advantages/mean': torch.mean( + advantages, ), #compute the average of the vstar of the current batch } return policy_dict From 3cae5d1b13f8f12260507919be99f96b4e6d2d71 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 25 Aug 2025 23:22:57 -0400 Subject: [PATCH 14/74] . --- compose_rl/algorithms/online/model_methods.py | 4 ++-- yamls/single-controller-grpo-workflow.yaml | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index a96bc5fb..dea8e3e9 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -405,7 +405,7 @@ def policy_loss( old_entropies = batch['old_entropies'] old_log_probs = batch['old_log_probs'] - old_log_probs_diff = old_log_probs - ref_log_probs + online_to_old_diff = online_log_probs - old_log_probs # ln(π/π_old) for SMD #compute KL to pi_ref to keep track the divergence to \pi_ref policy_kl_dict = utils.approx_kl( @@ -431,7 +431,7 @@ def policy_loss( #compute the policy loss if loss_type == OnPolicyEnum.SMD: masked_log_probs_diff = utils.masked_sum( - old_log_probs_diff, + online_to_old_diff, # Correct: ln(π/π_old) batch['action_mask'], dim=-1, ) #size: (batch_size,) diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index 8d4708cd..3b4886b2 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -32,7 +32,7 @@ parameters: seed: 17 model: name: hf_critic_free_lm - loss_type: grpo + loss_type: smd #grpo target_kl: 0.1 pretrained: true init_device: mixed @@ -184,7 +184,8 @@ parameters: download_timeout: 1800 drop_last: true num_workers: 1 - eval_interval: 2iter + eval_interval: 10iter + eval_first: false save_interval: 100iter log_to_console: true save_overwrite: true From 3ad15d3c397e51f5123ce06fbe0931e415032455 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 25 Aug 2025 23:36:56 -0400 Subject: [PATCH 15/74] . 
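The two patches above settle the SMD objective: sum ln(pi/pi_old) over the generated tokens of each sequence and regress beta times that quantity onto the per-prompt advantage. A self-contained sketch with masked_sum inlined (tensor names follow the patch; this is an illustration, not the repo code):

    import torch

    def smd_policy_loss(online_log_probs, old_log_probs, action_mask,
                        prompt_advantages, beta):
        log_ratio = online_log_probs - old_log_probs            # ln(pi / pi_old), (bs, T)
        seq_log_ratio = (log_ratio * action_mask).sum(dim=-1)   # masked_sum over tokens, (bs,)
        return ((beta * seq_log_ratio - prompt_advantages) ** 2).mean()

    # Example shapes: 2 prompts, 5 generated tokens each.
    bs, T = 2, 5
    loss = smd_policy_loss(
        online_log_probs=torch.randn(bs, T),
        old_log_probs=torch.randn(bs, T).detach(),
        action_mask=torch.ones(bs, T),
        prompt_advantages=torch.randn(bs),
        beta=1e-3,
    )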
--- compose_rl/algorithms/online/model_methods.py | 4 ++++ yamls/single-controller-grpo-workflow.yaml | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index dea8e3e9..eb286afb 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -400,6 +400,10 @@ def policy_loss( assert advantages is not None assert advantages.dim() == 1 # (bs,) + print("########################") + print(f'loss_type: {loss_type}') + print("########################") + online_log_probs = outputs['online_log_probs'] ref_log_probs = batch['ift_log_probs'] old_entropies = batch['old_entropies'] diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index 3b4886b2..a582eae3 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -1,4 +1,4 @@ -name: single-controller-hackathon +name: single-controller-hackathon_smd image: mosaicml/dle:nightly-latest scheduling: @@ -44,7 +44,7 @@ parameters: policy_clip_ratio: 0.2 attn_implementation: flash_attention_2 allow_embedding_resizing: true - normalize_advantage: true + normalize_advantage: False use_flash_attention_2: true length_normalize_policy_loss: true pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B From 3faab3c04a4b1b22d6db20c96e825e83f9ca5575 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 25 Aug 2025 23:46:37 -0400 Subject: [PATCH 16/74] . --- compose_rl/algorithms/online/model_methods.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index eb286afb..7325b4d5 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -432,15 +432,12 @@ def policy_loss( batch['action_mask'], ) #plain average over all tokens (KL to pi_ref) - #compute the policy loss - if loss_type == OnPolicyEnum.SMD: - masked_log_probs_diff = utils.masked_sum( - online_to_old_diff, # Correct: ln(π/π_old) - batch['action_mask'], - dim=-1, - ) #size: (batch_size,) - else: - raise ValueError(f'RegressionPolicy loss not implemented for {loss_type}') + #compute the policy loss for SMD; + masked_log_probs_diff = utils.masked_sum( + online_to_old_diff, # Correct: ln(π/π_old) + batch['action_mask'], + dim=-1, + ) #size: (batch_size,) policy_loss = ((beta * masked_log_probs_diff -advantages)**2).mean() From 70730bcfb80931f6cd46f04375be2cb5e32bd4b1 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 25 Aug 2025 23:59:51 -0400 Subject: [PATCH 17/74] . 
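For monitoring, the SMD branch tracks approximate KLs to both pi_ref and pi_old via utils.approx_kl with the estimator named in the yaml (k3). A sketch of the usual per-token k1/k3 estimators under the same convention (log_q - log_p = log pi - log pi_ref); the exact compose_rl implementation and its clipping behavior may differ:

    import torch

    def approx_kl_sketch(log_p, log_q, kl_clip_range=40.0):
        # Samples are drawn from q (the current policy); both estimators
        # approximate KL(q || p) from per-token log-probs.
        log_ratio = (log_p - log_q).clamp(-kl_clip_range, kl_clip_range)
        return {
            'k1': -log_ratio,                              # unbiased, higher variance
            'k3': torch.exp(log_ratio) - 1.0 - log_ratio,  # non-negative, lower variance
        }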
--- compose_rl/algorithms/online/model_methods.py | 27 ++++++++++++------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index 7325b4d5..f5851625 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -224,6 +224,7 @@ def critic_loss( def policy_loss( advantages: torch.Tensor | None, + prompt_advantages: torch.Tensor | None, outputs: MutableMapping, batch: MutableMapping, loss_type: OnPolicyEnum, @@ -397,8 +398,8 @@ def policy_loss( elif loss_type in ALGORITHM_TYPE.REGRESSION: # current it only supports SMD # TODO: add APO support - assert advantages is not None - assert advantages.dim() == 1 # (bs,) + assert prompt_advantages is not None + assert prompt_advantages.dim() == 1 # (bs,) print("########################") print(f'loss_type: {loss_type}') @@ -439,7 +440,7 @@ def policy_loss( dim=-1, ) #size: (batch_size,) - policy_loss = ((beta * masked_log_probs_diff -advantages)**2).mean() + policy_loss = ((beta * masked_log_probs_diff -prompt_advantages)**2).mean() rewards = utils.masked_sum( batch['rewards'], @@ -509,13 +510,18 @@ def online_rl_loss( # tensors in `outputs` are recomputed at the start of each step in the epoch. return_dict = {} - advantages = None - if loss_type not in ALGORITHM_TYPE.REGRESSION: #GRPO and PPO: - advantages = batch['advantages'] - assert advantages.dim() == 2 #(bs, max_gen_len) - elif loss_type == OnPolicyEnum.SMD: - advantages = batch['prompt_advantages'] - assert advantages.dim() == 1 #(bs,) + #advantages = None + advantages = batch['advantages'] + assert advantages.dim() == 2 #(bs, max_gen_len) + prompt_advantages = batch['prompt_advantages'] + assert prompt_advantages.dim() == 1 #(bs,) + + #if loss_type not in ALGORITHM_TYPE.REGRESSION: #GRPO and PPO: + # advantages = batch['advantages'] + # assert advantages.dim() == 2 #(bs, max_gen_len) + #elif loss_type == OnPolicyEnum.SMD: + # advantages = batch['prompt_advantages'] + # assert advantages.dim() == 1 #(bs,) # 1. Critic Loss if loss_type in ALGORITHM_TYPE.ACTOR_CRITIC: @@ -551,6 +557,7 @@ def online_rl_loss( # 2. Policy Loss policy_dict = policy_loss( advantages=advantages, + prompt_advantages=prompt_advantages, outputs=outputs, batch=batch, loss_type=loss_type, From d3098869b4a9d41ef65727868bc6966e6846ad41 Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 26 Aug 2025 00:26:16 -0400 Subject: [PATCH 18/74] . 
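After this and the next patch, the dispatch settles into: clipped-PG losses (PPO/GRPO) read the per-token batch['advantages'], while the regression-style SMD branch reads the per-prompt batch['prompt_advantages'] directly inside policy_loss. A condensed sketch of that selection (OnPolicyEnum and ALGORITHM_TYPE as defined in model_methods.py; the select_advantages helper is illustrative):

    from enum import Enum

    class OnPolicyEnum(Enum):
        PPO = 'ppo'
        GRPO = 'grpo'
        APO = 'apo'
        SMD = 'smd'

    class ALGORITHM_TYPE(set, Enum):
        CRITIC_FREE = {OnPolicyEnum.GRPO, OnPolicyEnum.APO, OnPolicyEnum.SMD}
        ACTOR_CRITIC = {OnPolicyEnum.PPO}
        CLIPPED_PG = {OnPolicyEnum.PPO, OnPolicyEnum.GRPO}
        REGRESSION = {OnPolicyEnum.APO, OnPolicyEnum.SMD}

    def select_advantages(loss_type, batch):
        if loss_type in ALGORITHM_TYPE.CLIPPED_PG:
            return batch['advantages']           # (bs, max_gen_len)
        if loss_type == OnPolicyEnum.SMD:
            return batch['prompt_advantages']    # (bs,)
        raise ValueError(f'Unsupported loss_type: {loss_type}')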
--- compose_rl/algorithms/online/model_methods.py | 2 +- test_single_controller_ppo.py | 27 ++++++------------- yamls/single-controller-grpo-workflow.yaml | 2 +- 3 files changed, 10 insertions(+), 21 deletions(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index f5851625..8d22a800 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -458,7 +458,7 @@ def policy_loss( rewards, ), #compute the average reward of the current batch 'advantages/mean': torch.mean( - advantages, + prompt_advantages, ), #compute the average of the vstar of the current batch } return policy_dict diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index 012b40b8..4f00ecd9 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -12,27 +12,24 @@ import argparse import asyncio -import copy from contextlib import contextmanager import logging import os import pickle import time import datetime -from itertools import chain from functools import partial -from typing import Any, Optional, Union, MutableMapping +from typing import Any, Optional from multiprocessing import get_context from multiprocessing.context import TimeoutError as MultiprocessingTimeoutError from multiprocessing.pool import AsyncResult, Pool from composer.loggers import MLFlowLogger import ray -import spacy import torch import torch.distributed as dist from composer import Trainer -from composer.core import get_precision_context, Precision +from composer.core import get_precision_context from composer.core.data_spec import _default_split_batch from composer.trainer.trainer import _get_initial_device_train_microbatch_size from compose_rl.data.buffer import MinibatchRolloutBuffer @@ -62,21 +59,14 @@ from compose_rl.algorithms.online.callback_utils import preprocess_batches from compose_rl.registry_builders import build_reward from compose_rl.registry import rewards as rewards_registry -from compose_rl.interfaces.base_kl_controller import BaseKLController from compose_rl.algorithms.reward_modeling import ( - BadGenerationEndReward, - BaseReward, - InferenceRewardModel, Reward, - RewardModel, ) from compose_rl.utils import ( approx_kl, - batch_process_fine_granularities, dist_compute_masked_mean_and_var, get_log_probs, get_entropies, - scatter_gather_rewards, switch_left_to_right_padding, mask_eos, masked_sum, @@ -84,7 +74,6 @@ get_decoded_sequence, ) from compose_rl.algorithms.online.reward_manager import ( - ReferenceOutput, RewardOutput, ) @@ -178,8 +167,8 @@ def build_train_config(self, config: Any): self.pretrain_model_name = self.config.model.pretrained_model_name_or_path self.model_config = om.to_container(self.config.model, resolve=True) - self.model_config['tokenizer'] = self.tokenizer - self.loss_type = self.model_config.get('loss_type', OnPolicyEnum.GRPO) + self.model_config['tokenizer'] = self.tokenizer # type: ignore + self.loss_type = self.model_config.get('loss_type', OnPolicyEnum.GRPO) # type: ignore print("--------------------------------") print(f'loss_type: {self.loss_type}') print("--------------------------------") @@ -409,7 +398,7 @@ def create_online_minibatches(self, current_rank_rollouts: dict[str, Any]): # Construct batch bs = partial_batch['prompt_id'].shape[0] batch = { - 'max_gen_len': torch.ones(bs).to(torch.int32) * self.max_gen_len, + 'max_gen_len': torch.ones(bs).to(torch.int32) * self.max_gen_len, # type: ignore 'ift_kl_scalar': torch.ones(bs) * 
self.kl_controller.value, **partial_batch, **reference_output, @@ -520,7 +509,7 @@ def get_log_probs_and_entropy(self, current_rank_rollouts: dict[str, Any], devic batch_size, device=device, dtype=prompt_dtype, - ) * self.max_gen_len + ) * self.max_gen_len # type: ignore # If all the processes early exit generate, then we need to manually pad everything # we can pad this with pad tokens, since we switch the padding between left and right @@ -721,7 +710,7 @@ def update_rewards(self, raw_rewards_dict: dict[str, Any], ref_output: dict[str, if self.kl_penalty_in_reward: rewards: torch.Tensor = -self.kl_controller.value * ref_kl.detach() else: - rewards: torch.Tensor = torch.zeros_like(ref_kl) + rewards = torch.zeros_like(ref_kl) env_rewards = torch.zeros_like(rewards) rews_dict_out: dict[str, torch.Tensor] = {} @@ -1401,7 +1390,7 @@ def __init__( self.tokenizer = ray.get(self.streaming_dataset_actor.get_tokenizer.remote()) self.tokenizer_pad_token_id = ray.get(self.streaming_dataset_actor.get_tokenizer_pad_token_id.remote()) - if self.tokenizer_pad_token_id is None: + if self.tokenizer_pad_token_id is None: # type: ignore raise ValueError( 'Tokenizer does not have a pad token id. Please use a different tokenizer or add a pad token id.', ) diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index a582eae3..1e903a6c 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -32,7 +32,7 @@ parameters: seed: 17 model: name: hf_critic_free_lm - loss_type: smd #grpo + loss_type: smd #grpo #grpo target_kl: 0.1 pretrained: true init_device: mixed From 9090acdfc89442c7261cee62b12d74a4a1852012 Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 26 Aug 2025 00:47:54 -0400 Subject: [PATCH 19/74] . --- compose_rl/algorithms/online/model_methods.py | 23 +++++----------- test_single_controller_ppo.py | 27 +++++++++++++------ 2 files changed, 26 insertions(+), 24 deletions(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index 8d22a800..adaf3ea9 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -224,7 +224,6 @@ def critic_loss( def policy_loss( advantages: torch.Tensor | None, - prompt_advantages: torch.Tensor | None, outputs: MutableMapping, batch: MutableMapping, loss_type: OnPolicyEnum, @@ -398,11 +397,13 @@ def policy_loss( elif loss_type in ALGORITHM_TYPE.REGRESSION: # current it only supports SMD # TODO: add APO support + prompt_advantages = batch['prompt_advantages'] assert prompt_advantages is not None assert prompt_advantages.dim() == 1 # (bs,) print("########################") print(f'loss_type: {loss_type}') + print(f'prompt_advantages shape: {prompt_advantages.shape}') print("########################") online_log_probs = outputs['online_log_probs'] @@ -458,8 +459,8 @@ def policy_loss( rewards, ), #compute the average reward of the current batch 'advantages/mean': torch.mean( - prompt_advantages, - ), #compute the average of the vstar of the current batch + prompt_advantages, # SMD uses prompt_advantages, not advantages + ), #compute the average of the prompt advantages for SMD } return policy_dict @@ -510,18 +511,9 @@ def online_rl_loss( # tensors in `outputs` are recomputed at the start of each step in the epoch. 
return_dict = {} - #advantages = None - advantages = batch['advantages'] - assert advantages.dim() == 2 #(bs, max_gen_len) - prompt_advantages = batch['prompt_advantages'] - assert prompt_advantages.dim() == 1 #(bs,) - - #if loss_type not in ALGORITHM_TYPE.REGRESSION: #GRPO and PPO: - # advantages = batch['advantages'] - # assert advantages.dim() == 2 #(bs, max_gen_len) - #elif loss_type == OnPolicyEnum.SMD: - # advantages = batch['prompt_advantages'] - # assert advantages.dim() == 1 #(bs,) + advantages = None + if loss_type not in ALGORITHM_TYPE.REGRESSION: + advantages = batch['advantages'] # 1. Critic Loss if loss_type in ALGORITHM_TYPE.ACTOR_CRITIC: @@ -557,7 +549,6 @@ def online_rl_loss( # 2. Policy Loss policy_dict = policy_loss( advantages=advantages, - prompt_advantages=prompt_advantages, outputs=outputs, batch=batch, loss_type=loss_type, diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index 4f00ecd9..012b40b8 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -12,24 +12,27 @@ import argparse import asyncio +import copy from contextlib import contextmanager import logging import os import pickle import time import datetime +from itertools import chain from functools import partial -from typing import Any, Optional +from typing import Any, Optional, Union, MutableMapping from multiprocessing import get_context from multiprocessing.context import TimeoutError as MultiprocessingTimeoutError from multiprocessing.pool import AsyncResult, Pool from composer.loggers import MLFlowLogger import ray +import spacy import torch import torch.distributed as dist from composer import Trainer -from composer.core import get_precision_context +from composer.core import get_precision_context, Precision from composer.core.data_spec import _default_split_batch from composer.trainer.trainer import _get_initial_device_train_microbatch_size from compose_rl.data.buffer import MinibatchRolloutBuffer @@ -59,14 +62,21 @@ from compose_rl.algorithms.online.callback_utils import preprocess_batches from compose_rl.registry_builders import build_reward from compose_rl.registry import rewards as rewards_registry +from compose_rl.interfaces.base_kl_controller import BaseKLController from compose_rl.algorithms.reward_modeling import ( + BadGenerationEndReward, + BaseReward, + InferenceRewardModel, Reward, + RewardModel, ) from compose_rl.utils import ( approx_kl, + batch_process_fine_granularities, dist_compute_masked_mean_and_var, get_log_probs, get_entropies, + scatter_gather_rewards, switch_left_to_right_padding, mask_eos, masked_sum, @@ -74,6 +84,7 @@ get_decoded_sequence, ) from compose_rl.algorithms.online.reward_manager import ( + ReferenceOutput, RewardOutput, ) @@ -167,8 +178,8 @@ def build_train_config(self, config: Any): self.pretrain_model_name = self.config.model.pretrained_model_name_or_path self.model_config = om.to_container(self.config.model, resolve=True) - self.model_config['tokenizer'] = self.tokenizer # type: ignore - self.loss_type = self.model_config.get('loss_type', OnPolicyEnum.GRPO) # type: ignore + self.model_config['tokenizer'] = self.tokenizer + self.loss_type = self.model_config.get('loss_type', OnPolicyEnum.GRPO) print("--------------------------------") print(f'loss_type: {self.loss_type}') print("--------------------------------") @@ -398,7 +409,7 @@ def create_online_minibatches(self, current_rank_rollouts: dict[str, Any]): # Construct batch bs = partial_batch['prompt_id'].shape[0] batch = { - 'max_gen_len': 
torch.ones(bs).to(torch.int32) * self.max_gen_len, # type: ignore + 'max_gen_len': torch.ones(bs).to(torch.int32) * self.max_gen_len, 'ift_kl_scalar': torch.ones(bs) * self.kl_controller.value, **partial_batch, **reference_output, @@ -509,7 +520,7 @@ def get_log_probs_and_entropy(self, current_rank_rollouts: dict[str, Any], devic batch_size, device=device, dtype=prompt_dtype, - ) * self.max_gen_len # type: ignore + ) * self.max_gen_len # If all the processes early exit generate, then we need to manually pad everything # we can pad this with pad tokens, since we switch the padding between left and right @@ -710,7 +721,7 @@ def update_rewards(self, raw_rewards_dict: dict[str, Any], ref_output: dict[str, if self.kl_penalty_in_reward: rewards: torch.Tensor = -self.kl_controller.value * ref_kl.detach() else: - rewards = torch.zeros_like(ref_kl) + rewards: torch.Tensor = torch.zeros_like(ref_kl) env_rewards = torch.zeros_like(rewards) rews_dict_out: dict[str, torch.Tensor] = {} @@ -1390,7 +1401,7 @@ def __init__( self.tokenizer = ray.get(self.streaming_dataset_actor.get_tokenizer.remote()) self.tokenizer_pad_token_id = ray.get(self.streaming_dataset_actor.get_tokenizer_pad_token_id.remote()) - if self.tokenizer_pad_token_id is None: # type: ignore + if self.tokenizer_pad_token_id is None: raise ValueError( 'Tokenizer does not have a pad token id. Please use a different tokenizer or add a pad token id.', ) From a1d5dfe213560b312df0c5f88627b6a04e5d726e Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 26 Aug 2025 01:08:58 -0400 Subject: [PATCH 20/74] . --- compose_rl/algorithms/online/model_methods.py | 23 +++++++++++++++++-- yamls/single-controller-grpo-workflow.yaml | 2 +- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index adaf3ea9..113f447c 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -235,6 +235,8 @@ def policy_loss( kl_clip_range: Optional[float] = 40.0, ) -> MutableMapping: + print(f"DEBUG: policy_loss called with loss_type: {loss_type}") + if loss_type in ALGORITHM_TYPE.CLIPPED_PG: assert advantages is not None assert advantages.dim() == 2 #(bs, max_gen_len) @@ -406,24 +408,32 @@ def policy_loss( print(f'prompt_advantages shape: {prompt_advantages.shape}') print("########################") + print("DEBUG: Getting log probs...") online_log_probs = outputs['online_log_probs'] ref_log_probs = batch['ift_log_probs'] old_entropies = batch['old_entropies'] - old_log_probs = batch['old_log_probs'] + print(f"DEBUG: online_log_probs shape: {online_log_probs.shape}") + print(f"DEBUG: old_log_probs shape: {old_log_probs.shape}") + + print("DEBUG: Computing log prob diff...") online_to_old_diff = online_log_probs - old_log_probs # ln(π/π_old) for SMD + print(f"DEBUG: online_to_old_diff shape: {online_to_old_diff.shape}") + print("DEBUG: Computing KL estimates...") #compute KL to pi_ref to keep track the divergence to \pi_ref policy_kl_dict = utils.approx_kl( log_p=ref_log_probs, log_q=online_log_probs, #log_q - log_p = log pi - log pi_ref kl_clip_range=kl_clip_range, ) + print("DEBUG: First KL computed") old_policy_kl_dict = utils.approx_kl( log_p=old_log_probs, log_q=online_log_probs, #log_q - log_p = log pi - log pi_ref kl_clip_range=kl_clip_range, ) + print("DEBUG: Second KL computed") with torch.no_grad(): policy_kl = utils.masked_mean( policy_kl_dict[kl_estimator], # pyright: ignore @@ -433,22 +443,30 @@ def policy_loss( 
old_policy_kl_dict[kl_estimator], # pyright: ignore batch['action_mask'], ) #plain average over all tokens (KL to pi_ref) + print("DEBUG: KL means computed") + print("DEBUG: Computing policy loss...") #compute the policy loss for SMD; masked_log_probs_diff = utils.masked_sum( online_to_old_diff, # Correct: ln(π/π_old) batch['action_mask'], dim=-1, ) #size: (batch_size,) + print(f"DEBUG: masked_log_probs_diff shape: {masked_log_probs_diff.shape}") + print(f"DEBUG: beta: {beta}") - policy_loss = ((beta * masked_log_probs_diff -prompt_advantages)**2).mean() + policy_loss = ((beta * masked_log_probs_diff - prompt_advantages)**2).mean() + print(f"DEBUG: policy_loss computed: {policy_loss}") + print("DEBUG: Computing rewards...") rewards = utils.masked_sum( batch['rewards'], batch['action_mask'], dim=-1, ) + print(f"DEBUG: rewards shape: {rewards.shape}") + print("DEBUG: Creating return dictionary...") policy_dict = { 'loss/policy_loss': policy_loss, 'kl/ref_policy_kl': policy_kl, @@ -462,6 +480,7 @@ def policy_loss( prompt_advantages, # SMD uses prompt_advantages, not advantages ), #compute the average of the prompt advantages for SMD } + print("DEBUG: Policy dict created successfully") return policy_dict else: diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index 1e903a6c..9756ad84 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -32,7 +32,7 @@ parameters: seed: 17 model: name: hf_critic_free_lm - loss_type: smd #grpo #grpo + loss_type: smd #smd #grpo #grpo target_kl: 0.1 pretrained: true init_device: mixed From 808dcd12c9fcf0574e8603f105d6298371624907 Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 26 Aug 2025 01:19:48 -0400 Subject: [PATCH 21/74] . --- compose_rl/algorithms/online/model_methods.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index 113f447c..f9021552 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -399,7 +399,7 @@ def policy_loss( elif loss_type in ALGORITHM_TYPE.REGRESSION: # current it only supports SMD # TODO: add APO support - prompt_advantages = batch['prompt_advantages'] + prompt_advantages = batch['prompt_advantages'].detach() assert prompt_advantages is not None assert prompt_advantages.dim() == 1 # (bs,) @@ -454,6 +454,9 @@ def policy_loss( ) #size: (batch_size,) print(f"DEBUG: masked_log_probs_diff shape: {masked_log_probs_diff.shape}") print(f"DEBUG: beta: {beta}") + print(f"DEBUG: About to compute policy loss with shapes:") + print(f" - masked_log_probs_diff: {masked_log_probs_diff.shape}") + print(f" - prompt_advantages: {prompt_advantages.shape}") policy_loss = ((beta * masked_log_probs_diff - prompt_advantages)**2).mean() print(f"DEBUG: policy_loss computed: {policy_loss}") From 746310a2f4ea158bfb59f0917b6e4cd994011650 Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 26 Aug 2025 08:25:26 -0400 Subject: [PATCH 22/74] . 
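The debugging in the next few patches circles around beta and ends with the float(beta) cast in PATCH 24. One plausible culprit, consistent with the earlier yaml change from beta: 1e-3 to beta: 0.001: under YAML 1.1 rules (PyYAML and loaders built on it), a bare 1e-3 with no decimal point is parsed as a string rather than a float. A small check, assuming the config passes through such a loader:

    import yaml  # PyYAML

    print(type(yaml.safe_load('beta: 1e-3')['beta']))   # <class 'str'> under YAML 1.1 resolution
    print(type(yaml.safe_load('beta: 0.001')['beta']))  # <class 'float'>
    print(float('1e-3'))                                # 0.001, i.e. what float(beta) recovers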
--- compose_rl/algorithms/online/model_methods.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index f9021552..3d6aaf08 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -458,7 +458,23 @@ def policy_loss( print(f" - masked_log_probs_diff: {masked_log_probs_diff.shape}") print(f" - prompt_advantages: {prompt_advantages.shape}") - policy_loss = ((beta * masked_log_probs_diff - prompt_advantages)**2).mean() + #policy_loss = ((beta * masked_log_probs_diff - prompt_advantages)**2).mean() + print(f"DEBUG: Tensor details:") + print(f" - masked_log_probs_diff device: {masked_log_probs_diff.device}, dtype: {masked_log_probs_diff.dtype}") + print(f" - prompt_advantages device: {prompt_advantages.device}, dtype: {prompt_advantages.dtype}") + print(f" - masked_log_probs_diff value: {masked_log_probs_diff}") + print(f" - prompt_advantages value: {prompt_advantages}") + print(f"DEBUG: About to compute: beta * masked_log_probs_diff - prompt_advantages") + + try: + temp_result = beta * masked_log_probs_diff - prompt_advantages + print(f"DEBUG: Subtraction successful, result: {temp_result}") + policy_loss = (temp_result**2).mean() + except Exception as e: + print(f"DEBUG: Error during computation: {e}") + raise + + print(f"DEBUG: policy_loss computed: {policy_loss}") print("DEBUG: Computing rewards...") From 643431684512631157d890b74dcdd2b8ba284efa Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 26 Aug 2025 08:38:37 -0400 Subject: [PATCH 23/74] . --- compose_rl/algorithms/online/model_methods.py | 25 ++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index 3d6aaf08..cf8949c9 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -446,12 +446,16 @@ def policy_loss( print("DEBUG: KL means computed") print("DEBUG: Computing policy loss...") + print(f"DEBUG: Before masked_sum - online_to_old_diff shape: {online_to_old_diff.shape}") + print(f"DEBUG: Before masked_sum - action_mask shape: {batch['action_mask'].shape}") + #compute the policy loss for SMD; masked_log_probs_diff = utils.masked_sum( online_to_old_diff, # Correct: ln(π/π_old) batch['action_mask'], dim=-1, ) #size: (batch_size,) + print(f"DEBUG: After masked_sum - masked_log_probs_diff created successfully") print(f"DEBUG: masked_log_probs_diff shape: {masked_log_probs_diff.shape}") print(f"DEBUG: beta: {beta}") print(f"DEBUG: About to compute policy loss with shapes:") @@ -467,11 +471,26 @@ def policy_loss( print(f"DEBUG: About to compute: beta * masked_log_probs_diff - prompt_advantages") try: - temp_result = beta * masked_log_probs_diff - prompt_advantages - print(f"DEBUG: Subtraction successful, result: {temp_result}") - policy_loss = (temp_result**2).mean() + print("DEBUG: Step 1 - Computing beta * masked_log_probs_diff...") + step1 = beta * masked_log_probs_diff + print(f"DEBUG: Step 1 result: {step1}") + + print("DEBUG: Step 2 - Subtracting prompt_advantages...") + step2 = step1 - prompt_advantages + print(f"DEBUG: Step 2 result: {step2}") + + print("DEBUG: Step 3 - Squaring...") + step3 = step2 ** 2 + print(f"DEBUG: Step 3 result: {step3}") + + print("DEBUG: Step 4 - Taking mean...") + policy_loss = step3.mean() + print(f"DEBUG: Final policy_loss: {policy_loss}") + 
except Exception as e: print(f"DEBUG: Error during computation: {e}") + import traceback + print(f"DEBUG: Full traceback: {traceback.format_exc()}") raise From b578daf78ea4b76e78dd536c9c94110dab208a0d Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 26 Aug 2025 09:48:41 -0400 Subject: [PATCH 24/74] . --- compose_rl/algorithms/online/model_methods.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index cf8949c9..d542fd80 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -472,7 +472,17 @@ def policy_loss( try: print("DEBUG: Step 1 - Computing beta * masked_log_probs_diff...") - step1 = beta * masked_log_probs_diff + print(f"DEBUG: beta type: {type(beta)}") + print(f"DEBUG: beta value: {beta}") + if hasattr(beta, 'shape'): + print(f"DEBUG: beta shape: {beta.shape}") + if hasattr(beta, 'device'): + print(f"DEBUG: beta device: {beta.device}") + + # Convert beta to a simple float to avoid tensor indexing issues + beta_float = float(beta) + print(f"DEBUG: beta_float: {beta_float}") + step1 = beta_float * masked_log_probs_diff print(f"DEBUG: Step 1 result: {step1}") print("DEBUG: Step 2 - Subtracting prompt_advantages...") From fa59b25a8418a71cc8b53bb8646fb19b057c966b Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 26 Aug 2025 10:05:23 -0400 Subject: [PATCH 25/74] . --- compose_rl/algorithms/online/model_methods.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index d542fd80..5f13a16a 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -626,9 +626,13 @@ def online_rl_loss( kl_estimator=kl_estimator, kl_clip_range=kl_clip_range, ) + print("DEBUG: Policy loss function completed successfully") + print("DEBUG: About to update return_dict with policy_dict") return_dict.update(**policy_dict) + print("DEBUG: return_dict updated successfully") + print("DEBUG: Starting batch items processing...") for key, value in batch.items(): # This logic handles reward logging a little differently than other quantities. # For rewards shaped as [batch, actions] we log (1) the per-sequence masked average From 1f6af3730a7d3156983819651bb4b66f3c865cd1 Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 26 Aug 2025 10:19:19 -0400 Subject: [PATCH 26/74] . --- compose_rl/algorithms/online/model_methods.py | 72 +++++++++++++------ yamls/single-controller-grpo-workflow.yaml | 2 +- 2 files changed, 50 insertions(+), 24 deletions(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index 5f13a16a..1ced47b7 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -633,35 +633,61 @@ def online_rl_loss( print("DEBUG: return_dict updated successfully") print("DEBUG: Starting batch items processing...") - for key, value in batch.items(): - # This logic handles reward logging a little differently than other quantities. - # For rewards shaped as [batch, actions] we log (1) the per-sequence masked average - # and (2) the per-sequence masked sum over actions, both size [batch]. - # We then average over [batch], so the interpretation is (1) the average per-token - # reward, and (2) the average total reward. 
- if 'reward' in key: - if value.shape == batch['action_mask'].shape: - # Average reward per timestep - return_dict['env/' + str(key) + '_mean'] = utils.masked_mean( - value, - batch['action_mask'], - dim=1, - ).mean(dim=0) - # Total reward over timesteps - return_dict['env/' + str(key) + '_total'] = utils.masked_sum( + try: + for key, value in batch.items(): + print(f"DEBUG: Processing batch key: {key}") + print(f"DEBUG: Value type: {type(value)}, shape: {getattr(value, 'shape', 'N/A')}") + + # This logic handles reward logging a little differently than other quantities. + # For rewards shaped as [batch, actions] we log (1) the per-sequence masked average + # and (2) the per-sequence masked sum over actions, both size [batch]. + # We then average over [batch], so the interpretation is (1) the average per-token + # reward, and (2) the average total reward. + if 'reward' in key: + print(f"DEBUG: Processing reward key: {key}") + print(f"DEBUG: action_mask shape: {batch['action_mask'].shape}") + print(f"DEBUG: value shape: {value.shape}") + + if value.shape == batch['action_mask'].shape: + print(f"DEBUG: Shapes match, computing masked operations...") + # Average reward per timestep + return_dict['env/' + str(key) + '_mean'] = utils.masked_mean( + value, + batch['action_mask'], + dim=1, + ).mean(dim=0) + print(f"DEBUG: Masked mean computed for {key}") + + # Total reward over timesteps + return_dict['env/' + str(key) + '_total'] = utils.masked_sum( + value, + batch['action_mask'], + dim=1, + ).mean(dim=0) + print(f"DEBUG: Masked sum computed for {key}") + else: + print(f"DEBUG: Shapes don't match, skipping {key}") + elif 'ift_kl' == key: + print(f"DEBUG: Processing ift_kl key: {key}") + return_dict['kl/' + str(key)] = utils.masked_mean( value, batch['action_mask'], - dim=1, - ).mean(dim=0) + ) + print(f"DEBUG: ift_kl processed successfully") else: + print(f"DEBUG: Processing non-reward key: {key}") # If this value is not [batch, actions] shaped, just do a # vanilla mean. return_dict['env/' + str(key)] = value.mean(dim=0) - if 'ift_kl' == key: - return_dict['kl/' + str(key)] = utils.masked_mean( - value, - batch['action_mask'], - ) + print(f"DEBUG: Non-reward key {key} processed successfully") + + print("DEBUG: Batch items processing completed successfully") + + except Exception as e: + print(f"DEBUG: Error in batch processing: {e}") + import traceback + print(f"DEBUG: Traceback: {traceback.format_exc()}") + raise # 3. Compute the total loss return_dict['total'] = return_dict['loss/policy_loss'] diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index 9756ad84..6088f8fe 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -194,7 +194,7 @@ parameters: device_eval_batch_size: 1 eval_subset_num_batches: -1 global_train_batch_size: 64 - device_train_microbatch_size: 1 + device_train_microbatch_size: 4 # Increased for better SMD batch statistics vllm_tensor_parallel_size: 1 vllm_enable_prefix_caching: false save_num_checkpoints_to_keep: 1 From 1a8adc81ed857df1b8e6802ac286c554bd280262 Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 26 Aug 2025 10:31:39 -0400 Subject: [PATCH 27/74] . 
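The logging loop reworked above (and trimmed in the diff below) keeps two statistics for any [batch, actions]-shaped reward tensor: the batch-averaged per-token reward and the batch-averaged total reward per sequence. A sketch with the masked reductions written out (illustrative helper, not the repo's utils):

    import torch

    def reward_logging_stats(value, action_mask):
        mask = action_mask.float()
        per_seq_len = mask.sum(dim=1).clamp(min=1.0)
        per_token_mean = ((value * mask).sum(dim=1) / per_seq_len).mean(dim=0)
        per_seq_total = (value * mask).sum(dim=1).mean(dim=0)
        return {'env/reward_mean': per_token_mean, 'env/reward_total': per_seq_total}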
--- compose_rl/algorithms/online/model_methods.py | 7 ++----- yamls/single-controller-grpo-workflow.yaml | 4 ++-- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index 1ced47b7..ccaf61d7 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -675,11 +675,8 @@ def online_rl_loss( ) print(f"DEBUG: ift_kl processed successfully") else: - print(f"DEBUG: Processing non-reward key: {key}") - # If this value is not [batch, actions] shaped, just do a - # vanilla mean. - return_dict['env/' + str(key)] = value.mean(dim=0) - print(f"DEBUG: Non-reward key {key} processed successfully") + print(f"DEBUG: Skipping non-essential key: {key}") + # Skip all other keys - we only need rewards and ift_kl print("DEBUG: Batch items processing completed successfully") diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index 6088f8fe..948ecead 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -37,7 +37,7 @@ parameters: pretrained: true init_device: mixed kl_estimator: k3 - beta: 1e-3 + beta: 0.001 kl_clip_range: 40 use_auth_token: true compute_kl_loss: false @@ -194,7 +194,7 @@ parameters: device_eval_batch_size: 1 eval_subset_num_batches: -1 global_train_batch_size: 64 - device_train_microbatch_size: 4 # Increased for better SMD batch statistics + device_train_microbatch_size: 1 # Increased for better SMD batch statistics vllm_tensor_parallel_size: 1 vllm_enable_prefix_caching: false save_num_checkpoints_to_keep: 1 From 8cb130e18ef33d72a0a575cceb3f576896254c66 Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 26 Aug 2025 10:42:01 -0400 Subject: [PATCH 28/74] . --- compose_rl/algorithms/online/model_methods.py | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index ccaf61d7..958751a5 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -687,19 +687,39 @@ def online_rl_loss( raise # 3. Compute the total loss - return_dict['total'] = return_dict['loss/policy_loss'] + print("DEBUG: About to compute total loss") + print(f"DEBUG: return_dict keys: {list(return_dict.keys())}") + print(f"DEBUG: Checking for 'loss/policy_loss' key...") + + try: + return_dict['total'] = return_dict['loss/policy_loss'] + print("DEBUG: Total loss assigned successfully") + except KeyError as e: + print(f"DEBUG: KeyError accessing policy_loss: {e}") + print(f"DEBUG: Available keys: {list(return_dict.keys())}") + raise + + print("DEBUG: Checking ACTOR_CRITIC condition...") if loss_type in ALGORITHM_TYPE.ACTOR_CRITIC: + print("DEBUG: Adding value loss to total (ACTOR_CRITIC)") # Add value loss to total loss return_dict['total'] += value_loss_weight * return_dict[ 'loss/value_loss'] # pyright: ignore + else: + print("DEBUG: Skipping value loss (not ACTOR_CRITIC)") + + print("DEBUG: Checking add_direct_kl_loss condition...") # If we want to directly minimize the KL Divergence, we can do so here # and it will not include the KL in the reward. 
if add_direct_kl_loss: + print("DEBUG: Adding direct KL loss") return_dict['total'] += batch['ift_kl_scalar'][0] * return_dict[ 'kl/online_ift_kl'] return_dict['loss/online_ift_kl'] = ( batch['ift_kl_scalar'][0] * return_dict['kl/online_ift_kl'] ) + else: + print("DEBUG: Skipping direct KL loss") # Entropy Loss. Meant to promote diversity. if entropy_loss_weight is not None: From b65384fcfe41505b481d64980bee3e80796ad08b Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 26 Aug 2025 10:55:39 -0400 Subject: [PATCH 29/74] . --- compose_rl/algorithms/online/model_methods.py | 29 ++++++++++++++++--- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index 958751a5..717e7297 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -721,8 +721,11 @@ def online_rl_loss( else: print("DEBUG: Skipping direct KL loss") + print("DEBUG: Checking entropy loss...") # Entropy Loss. Meant to promote diversity. if entropy_loss_weight is not None: + print(f"DEBUG: Processing entropy loss with weight: {entropy_loss_weight}") + print(f"DEBUG: Looking for 'gen/cur_seq_entropy' in return_dict keys: {list(return_dict.keys())}") # We want to maximize entropy so we deduct it from the loss. entropy_loss = -1.0 * ( entropy_loss_weight * return_dict['gen/cur_seq_entropy'] @@ -730,14 +733,32 @@ def online_rl_loss( # breakpoint() return_dict['loss/entropy'] = entropy_loss return_dict['total'] += entropy_loss + print("DEBUG: Entropy loss processed successfully") + else: + print("DEBUG: Skipping entropy loss (weight is None)") + print("DEBUG: Checking label loss...") if 'lbl' in outputs and outputs['lbl'] is not None: + print("DEBUG: Processing label loss") return_dict['loss/lbl'] = outputs['lbl'] return_dict['total'] += outputs['lbl'] + print("DEBUG: Label loss processed successfully") + else: + print("DEBUG: Skipping label loss") - # Detaching all return_dict values - for key, value in return_dict.items(): - if key not in 'total': - return_dict[key] = value.detach().cpu() + print("DEBUG: Starting detachment of return_dict values...") + try: + # Detaching all return_dict values + for key, value in return_dict.items(): + print(f"DEBUG: Detaching key: {key}") + if key not in 'total': + return_dict[key] = value.detach().cpu() + print("DEBUG: All values detached successfully") + except Exception as e: + print(f"DEBUG: Error during detachment: {e}") + import traceback + print(f"DEBUG: Traceback: {traceback.format_exc()}") + raise + print("DEBUG: About to return return_dict") return return_dict From 68b9bfb8280c15b674ef05de6ba0065af297b9fa Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 26 Aug 2025 13:02:45 -0400 Subject: [PATCH 30/74] . 
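The instrumentation added in patches 28 and 29 traces the same short pipeline each time: start from the policy loss, optionally add a weighted value loss, a direct KL penalty, an entropy bonus, and an auxiliary lbl term, then detach everything except the running total. A minimal standalone sketch of that assembly, where the function name and the flat dict-of-scalars interface are assumptions (in the repo this happens inline in online_rl_loss):

import torch

def assemble_total_loss(parts, value_loss_weight=None, kl_scalar=None, entropy_weight=None):
    # keys mirror the return_dict names used in these patches
    total = parts['loss/policy_loss']
    if value_loss_weight is not None:                  # ACTOR_CRITIC branch only
        total = total + value_loss_weight * parts['loss/value_loss']
    if kl_scalar is not None:                          # add_direct_kl_loss branch
        total = total + kl_scalar * parts['kl/online_ift_kl']
    if entropy_weight is not None:                     # entropy is maximized, hence subtracted
        total = total - entropy_weight * parts['gen/cur_seq_entropy']
    if 'loss/lbl' in parts:                            # optional auxiliary loss
        total = total + parts['loss/lbl']
    return total

parts = {
    'loss/policy_loss': torch.tensor(1.2, requires_grad=True),
    'kl/online_ift_kl': torch.tensor(0.05),
    'gen/cur_seq_entropy': torch.tensor(2.0),
}
print(assemble_total_loss(parts, kl_scalar=0.1, entropy_weight=0.01))

What the prints eventually surface is a bookkeeping gap rather than a numerical one: the SMD branch of policy_loss has to emit the 'kl/policy_kl' entry the calling code expects, which is exactly what the hunk below adds.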
--- compose_rl/algorithms/online/model_methods.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index 717e7297..1236bbda 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -517,6 +517,7 @@ def policy_loss( print("DEBUG: Creating return dictionary...") policy_dict = { 'loss/policy_loss': policy_loss, + 'kl/policy_kl': policy_kl, # Required by calling code in model.py 'kl/ref_policy_kl': policy_kl, 'kl/old_policy_kl': old_policy_kl, 'gen/gen_length': batch['action_mask'].sum(dim=1).to(torch.float32), @@ -761,4 +762,16 @@ def online_rl_loss( raise print("DEBUG: About to return return_dict") - return return_dict + print(f"DEBUG: return_dict type: {type(return_dict)}") + print(f"DEBUG: return_dict keys: {list(return_dict.keys())}") + print(f"DEBUG: return_dict size: {len(return_dict)}") + + try: + result = return_dict + print("DEBUG: Return assignment successful") + return result + except Exception as e: + print(f"DEBUG: Error during return: {e}") + import traceback + print(f"DEBUG: Traceback: {traceback.format_exc()}") + raise From ee71caf27f7c83c49eb3426210079c8ae0ebfaea Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 26 Aug 2025 20:57:06 -0400 Subject: [PATCH 31/74] start clean up --- compose_rl/algorithms/online/model_methods.py | 240 ++++-------------- yamls/single-controller-grpo-workflow.yaml | 10 +- 2 files changed, 51 insertions(+), 199 deletions(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index 1236bbda..96238135 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -235,8 +235,6 @@ def policy_loss( kl_clip_range: Optional[float] = 40.0, ) -> MutableMapping: - print(f"DEBUG: policy_loss called with loss_type: {loss_type}") - if loss_type in ALGORITHM_TYPE.CLIPPED_PG: assert advantages is not None assert advantages.dim() == 2 #(bs, max_gen_len) @@ -403,37 +401,25 @@ def policy_loss( assert prompt_advantages is not None assert prompt_advantages.dim() == 1 # (bs,) - print("########################") - print(f'loss_type: {loss_type}') - print(f'prompt_advantages shape: {prompt_advantages.shape}') - print("########################") - - print("DEBUG: Getting log probs...") online_log_probs = outputs['online_log_probs'] ref_log_probs = batch['ift_log_probs'] old_entropies = batch['old_entropies'] old_log_probs = batch['old_log_probs'] - print(f"DEBUG: online_log_probs shape: {online_log_probs.shape}") - print(f"DEBUG: old_log_probs shape: {old_log_probs.shape}") - - print("DEBUG: Computing log prob diff...") online_to_old_diff = online_log_probs - old_log_probs # ln(π/π_old) for SMD - print(f"DEBUG: online_to_old_diff shape: {online_to_old_diff.shape}") - - print("DEBUG: Computing KL estimates...") + #compute KL to pi_ref to keep track the divergence to \pi_ref policy_kl_dict = utils.approx_kl( log_p=ref_log_probs, log_q=online_log_probs, #log_q - log_p = log pi - log pi_ref kl_clip_range=kl_clip_range, ) - print("DEBUG: First KL computed") + old_policy_kl_dict = utils.approx_kl( log_p=old_log_probs, log_q=online_log_probs, #log_q - log_p = log pi - log pi_ref kl_clip_range=kl_clip_range, ) - print("DEBUG: Second KL computed") + with torch.no_grad(): policy_kl = utils.masked_mean( policy_kl_dict[kl_estimator], # pyright: ignore @@ -443,11 +429,6 @@ def policy_loss( 
old_policy_kl_dict[kl_estimator], # pyright: ignore batch['action_mask'], ) #plain average over all tokens (KL to pi_ref) - print("DEBUG: KL means computed") - - print("DEBUG: Computing policy loss...") - print(f"DEBUG: Before masked_sum - online_to_old_diff shape: {online_to_old_diff.shape}") - print(f"DEBUG: Before masked_sum - action_mask shape: {batch['action_mask'].shape}") #compute the policy loss for SMD; masked_log_probs_diff = utils.masked_sum( @@ -455,66 +436,16 @@ def policy_loss( batch['action_mask'], dim=-1, ) #size: (batch_size,) - print(f"DEBUG: After masked_sum - masked_log_probs_diff created successfully") - print(f"DEBUG: masked_log_probs_diff shape: {masked_log_probs_diff.shape}") - print(f"DEBUG: beta: {beta}") - print(f"DEBUG: About to compute policy loss with shapes:") - print(f" - masked_log_probs_diff: {masked_log_probs_diff.shape}") - print(f" - prompt_advantages: {prompt_advantages.shape}") - - #policy_loss = ((beta * masked_log_probs_diff - prompt_advantages)**2).mean() - print(f"DEBUG: Tensor details:") - print(f" - masked_log_probs_diff device: {masked_log_probs_diff.device}, dtype: {masked_log_probs_diff.dtype}") - print(f" - prompt_advantages device: {prompt_advantages.device}, dtype: {prompt_advantages.dtype}") - print(f" - masked_log_probs_diff value: {masked_log_probs_diff}") - print(f" - prompt_advantages value: {prompt_advantages}") - print(f"DEBUG: About to compute: beta * masked_log_probs_diff - prompt_advantages") - - try: - print("DEBUG: Step 1 - Computing beta * masked_log_probs_diff...") - print(f"DEBUG: beta type: {type(beta)}") - print(f"DEBUG: beta value: {beta}") - if hasattr(beta, 'shape'): - print(f"DEBUG: beta shape: {beta.shape}") - if hasattr(beta, 'device'): - print(f"DEBUG: beta device: {beta.device}") - - # Convert beta to a simple float to avoid tensor indexing issues - beta_float = float(beta) - print(f"DEBUG: beta_float: {beta_float}") - step1 = beta_float * masked_log_probs_diff - print(f"DEBUG: Step 1 result: {step1}") - - print("DEBUG: Step 2 - Subtracting prompt_advantages...") - step2 = step1 - prompt_advantages - print(f"DEBUG: Step 2 result: {step2}") - - print("DEBUG: Step 3 - Squaring...") - step3 = step2 ** 2 - print(f"DEBUG: Step 3 result: {step3}") - - print("DEBUG: Step 4 - Taking mean...") - policy_loss = step3.mean() - print(f"DEBUG: Final policy_loss: {policy_loss}") - - except Exception as e: - print(f"DEBUG: Error during computation: {e}") - import traceback - print(f"DEBUG: Full traceback: {traceback.format_exc()}") - raise - - - print(f"DEBUG: policy_loss computed: {policy_loss}") + # Convert beta to a simple float + beta_float = float(beta) + policy_loss = ((beta_float * masked_log_probs_diff - prompt_advantages)**2).mean() - print("DEBUG: Computing rewards...") rewards = utils.masked_sum( batch['rewards'], batch['action_mask'], dim=-1, ) - print(f"DEBUG: rewards shape: {rewards.shape}") - - print("DEBUG: Creating return dictionary...") + policy_dict = { 'loss/policy_loss': policy_loss, 'kl/policy_kl': policy_kl, # Required by calling code in model.py @@ -529,9 +460,7 @@ def policy_loss( prompt_advantages, # SMD uses prompt_advantages, not advantages ), #compute the average of the prompt advantages for SMD } - print("DEBUG: Policy dict created successfully") return policy_dict - else: raise ValueError(f'Policy loss not implemented for {loss_type}') @@ -580,7 +509,7 @@ def online_rl_loss( return_dict = {} advantages = None - if loss_type not in ALGORITHM_TYPE.REGRESSION: + if loss_type not in 
ALGORITHM_TYPE.REGRESSION: # basically grpo/ppo advantages = batch['advantages'] # 1. Critic Loss @@ -627,106 +556,58 @@ def online_rl_loss( kl_estimator=kl_estimator, kl_clip_range=kl_clip_range, ) - print("DEBUG: Policy loss function completed successfully") - print("DEBUG: About to update return_dict with policy_dict") return_dict.update(**policy_dict) - print("DEBUG: return_dict updated successfully") - - print("DEBUG: Starting batch items processing...") - try: - for key, value in batch.items(): - print(f"DEBUG: Processing batch key: {key}") - print(f"DEBUG: Value type: {type(value)}, shape: {getattr(value, 'shape', 'N/A')}") - - # This logic handles reward logging a little differently than other quantities. - # For rewards shaped as [batch, actions] we log (1) the per-sequence masked average - # and (2) the per-sequence masked sum over actions, both size [batch]. - # We then average over [batch], so the interpretation is (1) the average per-token - # reward, and (2) the average total reward. - if 'reward' in key: - print(f"DEBUG: Processing reward key: {key}") - print(f"DEBUG: action_mask shape: {batch['action_mask'].shape}") - print(f"DEBUG: value shape: {value.shape}") - - if value.shape == batch['action_mask'].shape: - print(f"DEBUG: Shapes match, computing masked operations...") - # Average reward per timestep - return_dict['env/' + str(key) + '_mean'] = utils.masked_mean( - value, - batch['action_mask'], - dim=1, - ).mean(dim=0) - print(f"DEBUG: Masked mean computed for {key}") + + + for key, value in batch.items(): + # This logic handles reward logging a little differently than other quantities. + # For rewards shaped as [batch, actions] we log (1) the per-sequence masked average + # and (2) the per-sequence masked sum over actions, both size [batch]. + # We then average over [batch], so the interpretation is (1) the average per-token + # reward, and (2) the average total reward. + if 'reward' in key: + if value.shape == batch['action_mask'].shape: + print(f"DEBUG: Shapes match, computing masked operations...") + # Average reward per timestep + return_dict['env/' + str(key) + '_mean'] = utils.masked_mean( + value, + batch['action_mask'], + dim=1, + ).mean(dim=0) - # Total reward over timesteps - return_dict['env/' + str(key) + '_total'] = utils.masked_sum( - value, - batch['action_mask'], - dim=1, - ).mean(dim=0) - print(f"DEBUG: Masked sum computed for {key}") - else: - print(f"DEBUG: Shapes don't match, skipping {key}") - elif 'ift_kl' == key: - print(f"DEBUG: Processing ift_kl key: {key}") - return_dict['kl/' + str(key)] = utils.masked_mean( + # Total reward over timesteps + return_dict['env/' + str(key) + '_total'] = utils.masked_sum( value, batch['action_mask'], - ) - print(f"DEBUG: ift_kl processed successfully") - else: - print(f"DEBUG: Skipping non-essential key: {key}") - # Skip all other keys - we only need rewards and ift_kl - - print("DEBUG: Batch items processing completed successfully") - - except Exception as e: - print(f"DEBUG: Error in batch processing: {e}") - import traceback - print(f"DEBUG: Traceback: {traceback.format_exc()}") - raise - - # 3. 
Compute the total loss - print("DEBUG: About to compute total loss") - print(f"DEBUG: return_dict keys: {list(return_dict.keys())}") - print(f"DEBUG: Checking for 'loss/policy_loss' key...") - - try: - return_dict['total'] = return_dict['loss/policy_loss'] - print("DEBUG: Total loss assigned successfully") - except KeyError as e: - print(f"DEBUG: KeyError accessing policy_loss: {e}") - print(f"DEBUG: Available keys: {list(return_dict.keys())}") - raise + dim=1, + ).mean(dim=0) + elif 'ift_kl' == key: + return_dict['kl/' + str(key)] = utils.masked_mean( + value, + batch['action_mask'], + ) + + # 3. Compute the total loss + return_dict['total'] = return_dict['loss/policy_loss'] - print("DEBUG: Checking ACTOR_CRITIC condition...") if loss_type in ALGORITHM_TYPE.ACTOR_CRITIC: - print("DEBUG: Adding value loss to total (ACTOR_CRITIC)") # Add value loss to total loss return_dict['total'] += value_loss_weight * return_dict[ 'loss/value_loss'] # pyright: ignore - else: - print("DEBUG: Skipping value loss (not ACTOR_CRITIC)") + - print("DEBUG: Checking add_direct_kl_loss condition...") # If we want to directly minimize the KL Divergence, we can do so here # and it will not include the KL in the reward. if add_direct_kl_loss: - print("DEBUG: Adding direct KL loss") return_dict['total'] += batch['ift_kl_scalar'][0] * return_dict[ 'kl/online_ift_kl'] return_dict['loss/online_ift_kl'] = ( batch['ift_kl_scalar'][0] * return_dict['kl/online_ift_kl'] ) - else: - print("DEBUG: Skipping direct KL loss") - print("DEBUG: Checking entropy loss...") # Entropy Loss. Meant to promote diversity. if entropy_loss_weight is not None: - print(f"DEBUG: Processing entropy loss with weight: {entropy_loss_weight}") - print(f"DEBUG: Looking for 'gen/cur_seq_entropy' in return_dict keys: {list(return_dict.keys())}") # We want to maximize entropy so we deduct it from the loss. 
entropy_loss = -1.0 * ( entropy_loss_weight * return_dict['gen/cur_seq_entropy'] @@ -734,44 +615,15 @@ def online_rl_loss( # breakpoint() return_dict['loss/entropy'] = entropy_loss return_dict['total'] += entropy_loss - print("DEBUG: Entropy loss processed successfully") - else: - print("DEBUG: Skipping entropy loss (weight is None)") - - print("DEBUG: Checking label loss...") + if 'lbl' in outputs and outputs['lbl'] is not None: - print("DEBUG: Processing label loss") return_dict['loss/lbl'] = outputs['lbl'] return_dict['total'] += outputs['lbl'] - print("DEBUG: Label loss processed successfully") - else: - print("DEBUG: Skipping label loss") - - print("DEBUG: Starting detachment of return_dict values...") - try: - # Detaching all return_dict values - for key, value in return_dict.items(): - print(f"DEBUG: Detaching key: {key}") - if key not in 'total': - return_dict[key] = value.detach().cpu() - print("DEBUG: All values detached successfully") - except Exception as e: - print(f"DEBUG: Error during detachment: {e}") - import traceback - print(f"DEBUG: Traceback: {traceback.format_exc()}") - raise - - print("DEBUG: About to return return_dict") - print(f"DEBUG: return_dict type: {type(return_dict)}") - print(f"DEBUG: return_dict keys: {list(return_dict.keys())}") - print(f"DEBUG: return_dict size: {len(return_dict)}") - try: - result = return_dict - print("DEBUG: Return assignment successful") - return result - except Exception as e: - print(f"DEBUG: Error during return: {e}") - import traceback - print(f"DEBUG: Traceback: {traceback.format_exc()}") - raise + # Detaching all return_dict values + for key, value in return_dict.items(): + if key not in 'total': + return_dict[key] = value.detach().cpu() + + #result = return_dict + return return_dict diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index 948ecead..b3f2e04d 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -32,7 +32,7 @@ parameters: seed: 17 model: name: hf_critic_free_lm - loss_type: smd #smd #grpo #grpo + loss_type: smd #grpo target_kl: 0.1 pretrained: true init_device: mixed @@ -44,14 +44,14 @@ parameters: policy_clip_ratio: 0.2 attn_implementation: flash_attention_2 allow_embedding_resizing: true - normalize_advantage: False + normalize_advantage: False #true for grpo use_flash_attention_2: true length_normalize_policy_loss: true pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B loggers: mlflow: tags: - run: test_single_controller_ppo_deepseek_l8b_open_r1_48k + run: test_single_controller_ppo_deepseek_l8b_open_r1_48k_smd group: grpo tracking_uri: databricks experiment_name: test_single_controller_ppo @@ -61,7 +61,7 @@ parameters: evals: #- name: gsm8k - name: math_500 - #- name: math_hard + - name: math_hard eval_overrides: generation_params: max_tokens: 8192 @@ -169,7 +169,7 @@ parameters: max_seq_len: 10240 save_folder: /tmp/checkpoints dist_timeout: 1800 - max_duration: 10iter + max_duration: 200iter progress_bar: false train_loader: name: prompt From 1e9d123db175dbf9955f5836009a15d65a6f7b49 Mon Sep 17 00:00:00 2001 From: wensun Date: Wed, 27 Aug 2025 14:14:59 -0400 Subject: [PATCH 32/74] recreating wrong example beta string --- compose_rl/algorithms/online/model_methods.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index 96238135..885d7f4c 100644 --- 
a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -437,7 +437,8 @@ def policy_loss( dim=-1, ) #size: (batch_size,) # Convert beta to a simple float - beta_float = float(beta) + # TODO: fix that later!! this is for debugging. + beta_float = beta #float(beta) policy_loss = ((beta_float * masked_log_probs_diff - prompt_advantages)**2).mean() rewards = utils.masked_sum( From 2badf807448c45f81ccbd889f0ae3c9a890b4381 Mon Sep 17 00:00:00 2001 From: wensun Date: Wed, 27 Aug 2025 14:29:08 -0400 Subject: [PATCH 33/74] bug reproduced, revert back to the working version --- compose_rl/algorithms/online/model_methods.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index 885d7f4c..96238135 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -437,8 +437,7 @@ def policy_loss( dim=-1, ) #size: (batch_size,) # Convert beta to a simple float - # TODO: fix that later!! this is for debugging. - beta_float = beta #float(beta) + beta_float = float(beta) policy_loss = ((beta_float * masked_log_probs_diff - prompt_advantages)**2).mean() rewards = utils.masked_sum( From 918f6b713db35c0a2281d56b10f630cc8b019c5a Mon Sep 17 00:00:00 2001 From: wensun Date: Wed, 27 Aug 2025 14:55:51 -0400 Subject: [PATCH 34/74] creating second bebugging example: kl/policy_kl --- compose_rl/algorithms/online/model_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index 96238135..a55762ad 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -448,7 +448,7 @@ def policy_loss( policy_dict = { 'loss/policy_loss': policy_loss, - 'kl/policy_kl': policy_kl, # Required by calling code in model.py + #'kl/policy_kl': policy_kl, # Required by calling code in model.py 'kl/ref_policy_kl': policy_kl, 'kl/old_policy_kl': old_policy_kl, 'gen/gen_length': batch['action_mask'].sum(dim=1).to(torch.float32), From b5f5da1a45e484e8f9b47e68cc9b4f9cebbe6d6b Mon Sep 17 00:00:00 2001 From: wensun Date: Wed, 27 Aug 2025 16:12:53 -0400 Subject: [PATCH 35/74] convert back to the correct version and check in the yaml for smd --- compose_rl/algorithms/online/model_methods.py | 3 +- yamls/single-controller-smd-workflow.yaml | 190 ++++++++++++++++++ 2 files changed, 191 insertions(+), 2 deletions(-) create mode 100644 yamls/single-controller-smd-workflow.yaml diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index a55762ad..680346e8 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -448,8 +448,7 @@ def policy_loss( policy_dict = { 'loss/policy_loss': policy_loss, - #'kl/policy_kl': policy_kl, # Required by calling code in model.py - 'kl/ref_policy_kl': policy_kl, + 'kl/policy_kl': policy_kl, # Required by calling code in model.py 'kl/old_policy_kl': old_policy_kl, 'gen/gen_length': batch['action_mask'].sum(dim=1).to(torch.float32), 'gen/entropy': old_entropies, diff --git a/yamls/single-controller-smd-workflow.yaml b/yamls/single-controller-smd-workflow.yaml new file mode 100644 index 00000000..cb2b7196 --- /dev/null +++ b/yamls/single-controller-smd-workflow.yaml @@ -0,0 +1,190 @@ +name: single-controller-hackathon_smd + +image: 
mosaicml/dle:nightly-latest +scheduling: + priority: medium + resumable: false + preemptible: false +compute: + gpus: 8 + cluster: r5z2p1 + instance: oci.bm.gpu.h200.8.oke +integrations: +- integration_type: git_repo + path: /workspace/compose-rl + git_repo: databricks/compose-rl + ssh_clone: true + git_branch: single-controller-hackathon-smd #single-controller-hackathon +- integration_type: git_repo + path: /workspace/research-universe + git_repo: databricks-mosaic/research-universe + ssh_clone: true + git_branch: update-orl-eval +command: |- + python -m uv pip install --system /workspace/research-universe/minieval + python -m uv pip install --system /workspace/research-universe/orl_eval + python -m uv pip install --system /workspace/compose-rl[gpu] --no-deps + + cd /workspace/compose-rl + composer test_single_controller_ppo.py --file_path /mnt/config/parameters.yaml + +parameters: + seed: 17 + model: + name: hf_critic_free_lm + loss_type: smd #grpo + target_kl: 100000 # is it used in SDM? + pretrained: true + init_device: mixed + kl_estimator: k3 + beta: 1e-3 #0.01 + kl_clip_range: 40 + use_auth_token: true + compute_kl_loss: false + policy_clip_ratio: 0.2 + attn_implementation: flash_attention_2 + allow_embedding_resizing: true + normalize_advantage: false + use_flash_attention_2: true + length_normalize_policy_loss: true + pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B + loggers: + mlflow: + tags: + run: test_single_controller_ppo_deepseek_l8b_open_r1_48k_smd_beta_0.01_target_kl_100000 + group: grpo + tracking_uri: databricks + experiment_name: test_single_controller_ppo + callbacks: + ppo: {} + orl_eval: + evals: + #- name: gsm8k + - name: math_500 + - name: math_hard + eval_overrides: + generation_params: + max_tokens: 8192 + lr_monitor: {} + scheduled_gc: + batch_interval: 1000 + speed_monitor: + window_size: 1 + memory_monitor: {} + hf_checkpointer: + overwrite: true + precision: bfloat16 + save_folder: /tmp/hf_checkpoints/ + save_interval: 1dur + runtime_estimator: {} + optimizer: + lr: 1.0e-06 + name: decoupled_adamw + betas: + - 0.9 + - 0.95 + weight_decay: 1.0e-8 + precision: amp_bf16 + scheduler: + name: constant_with_warmup + alpha: 1 + t_warmup: 10iter + tokenizer: + name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B + kwargs: + padding: longest + pad_token: <|finetune_right_pad_id|> + truncation: true + padding_side: left + model_max_length: 10240 + trust_remote_code: true + variables: + gamma: 1 + buffer: + name: MinibatchRolloutBuffer + rewards: + math_verifier: + reward: 1 + reward_type: math_verifier + lambda_gae: 1 + global_seed: 17 + max_gen_len: 8192 + eos_token_ids: + - 128001 + - 128008 + - 128009 + kl_controller: + kl_ctl_type: fixed + init_kl_coef: 0 + tokenizer_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B + num_train_nodes: 1 + reference_model: + precision: amp_bf16 + pretrained: true + model_config: + name: hf_causal_lm + pretrained: true + use_auth_token: true + use_flash_attention_2: true + pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B + generation_kwargs: + top_p: 1 + do_sample: true + use_cache: true + temperature: 1 + epoch_per_iteration: 1 + generations_per_prompt: 8 + num_batches_per_update: 8 + device_generate_batch_size: 1 + algorithms: + gradient_clipping: + clipping_type: norm + clipping_threshold: 1.0 #0.001 + autoresume: true + log_config: true + fsdp_config: + sync_module_states: true + verbose: false + cpu_offload: false + mixed_precision: PURE + state_dict_type: sharded + use_orig_params: true + 
forward_prefetch: true + backward_prefetch: BACKWARD_PRE + sharding_strategy: FULL_SHARD + activation_cpu_offload: false + activation_checkpointing: true + activation_checkpointing_reentrant: false + max_seq_len: 10240 + save_folder: /tmp/checkpoints + dist_timeout: 1800 + max_duration: 200iter + progress_bar: false + train_loader: + name: prompt + dataset: + local: /tmp/dataset/prompt_{timestamp}/ + split: train + remote: dbfs:/Volumes/datasets/ashutoshbaheti/orl_data/open_r1_filtered/dpsk_8b_open_r1_48k/ + shuffle: true + max_gen_len: 8192 + max_seq_len: 10240 + shuffle_seed: 17 + download_timeout: 1800 + drop_last: true + num_workers: 1 + eval_interval: 10iter + eval_first: false + save_interval: 100iter + log_to_console: true + save_overwrite: true + python_log_level: debug + console_log_interval: 1ba + device_eval_batch_size: 1 + eval_subset_num_batches: -1 + global_train_batch_size: 64 + device_train_microbatch_size: 1 # Increased for better SMD batch statistics + vllm_tensor_parallel_size: 1 + vllm_enable_prefix_caching: false + save_num_checkpoints_to_keep: 1 + max_async_step: 0 From 0abe985cb87d1efa2e553a18fc0a8182c5bc94bb Mon Sep 17 00:00:00 2001 From: wensun Date: Wed, 27 Aug 2025 16:29:56 -0400 Subject: [PATCH 36/74] . --- yamls/single-controller-grpo-workflow.yaml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index b3f2e04d..8b927b19 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -1,4 +1,4 @@ -name: single-controller-hackathon_smd +name: single-controller-hackathon_grpo image: mosaicml/dle:nightly-latest scheduling: @@ -32,26 +32,26 @@ parameters: seed: 17 model: name: hf_critic_free_lm - loss_type: smd #grpo + loss_type: grpo target_kl: 0.1 pretrained: true init_device: mixed kl_estimator: k3 - beta: 0.001 + beta: 0.01 kl_clip_range: 40 use_auth_token: true compute_kl_loss: false policy_clip_ratio: 0.2 attn_implementation: flash_attention_2 allow_embedding_resizing: true - normalize_advantage: False #true for grpo + normalize_advantage: true use_flash_attention_2: true length_normalize_policy_loss: true pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B loggers: mlflow: tags: - run: test_single_controller_ppo_deepseek_l8b_open_r1_48k_smd + run: test_single_controller_ppo_deepseek_l8b_open_r1_48k_grpo group: grpo tracking_uri: databricks experiment_name: test_single_controller_ppo @@ -83,7 +83,7 @@ parameters: betas: - 0.9 - 0.95 - weight_decay: 0 + weight_decay: 1.0e-8 precision: amp_bf16 scheduler: name: constant_with_warmup From a93c8d5cdd4010428193997a18acde0217acbc80 Mon Sep 17 00:00:00 2001 From: wensun Date: Wed, 27 Aug 2025 21:47:38 -0400 Subject: [PATCH 37/74] addressed comments from bowen --- compose_rl/algorithms/online/model_methods.py | 8 +++----- yamls/single-controller-grpo-workflow.yaml | 18 ++++++++---------- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index 680346e8..ee98272d 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -416,7 +416,7 @@ def policy_loss( old_policy_kl_dict = utils.approx_kl( log_p=old_log_probs, - log_q=online_log_probs, #log_q - log_p = log pi - log pi_ref + log_q=online_log_probs, #log_q - log_p = log pi - log pi_old kl_clip_range=kl_clip_range, ) @@ -567,7 
+567,6 @@ def online_rl_loss( # reward, and (2) the average total reward. if 'reward' in key: if value.shape == batch['action_mask'].shape: - print(f"DEBUG: Shapes match, computing masked operations...") # Average reward per timestep return_dict['env/' + str(key) + '_mean'] = utils.masked_mean( value, @@ -618,11 +617,10 @@ def online_rl_loss( if 'lbl' in outputs and outputs['lbl'] is not None: return_dict['loss/lbl'] = outputs['lbl'] return_dict['total'] += outputs['lbl'] - + # Detaching all return_dict values for key, value in return_dict.items(): if key not in 'total': return_dict[key] = value.detach().cpu() - #result = return_dict - return return_dict + return return_dict \ No newline at end of file diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index 8b927b19..ac59150a 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -1,4 +1,4 @@ -name: single-controller-hackathon_grpo +name: single-controller-hackathon image: mosaicml/dle:nightly-latest scheduling: @@ -14,7 +14,7 @@ integrations: path: /workspace/compose-rl git_repo: databricks/compose-rl ssh_clone: true - git_branch: single-controller-hackathon-smd #single-controller-hackathon + git_branch: single-controller-hackathon - integration_type: git_repo path: /workspace/research-universe git_repo: databricks-mosaic/research-universe @@ -37,7 +37,6 @@ parameters: pretrained: true init_device: mixed kl_estimator: k3 - beta: 0.01 kl_clip_range: 40 use_auth_token: true compute_kl_loss: false @@ -51,7 +50,7 @@ parameters: loggers: mlflow: tags: - run: test_single_controller_ppo_deepseek_l8b_open_r1_48k_grpo + run: test_single_controller_ppo_deepseek_l8b_open_r1_48k group: grpo tracking_uri: databricks experiment_name: test_single_controller_ppo @@ -59,7 +58,7 @@ parameters: ppo: {} orl_eval: evals: - #- name: gsm8k + - name: gsm8k - name: math_500 - name: math_hard eval_overrides: @@ -83,7 +82,7 @@ parameters: betas: - 0.9 - 0.95 - weight_decay: 1.0e-8 + weight_decay: 0 precision: amp_bf16 scheduler: name: constant_with_warmup @@ -169,7 +168,7 @@ parameters: max_seq_len: 10240 save_folder: /tmp/checkpoints dist_timeout: 1800 - max_duration: 200iter + max_duration: 10iter progress_bar: false train_loader: name: prompt @@ -184,8 +183,7 @@ parameters: download_timeout: 1800 drop_last: true num_workers: 1 - eval_interval: 10iter - eval_first: false + eval_interval: 2iter save_interval: 100iter log_to_console: true save_overwrite: true @@ -194,7 +192,7 @@ parameters: device_eval_batch_size: 1 eval_subset_num_batches: -1 global_train_batch_size: 64 - device_train_microbatch_size: 1 # Increased for better SMD batch statistics + device_train_microbatch_size: 1 vllm_tensor_parallel_size: 1 vllm_enable_prefix_caching: false save_num_checkpoints_to_keep: 1 From f4d4cc5951b5dd719fab1a4f24d9901f5bb775fc Mon Sep 17 00:00:00 2001 From: wensun Date: Thu, 28 Aug 2025 11:25:34 -0400 Subject: [PATCH 38/74] . 
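The comment this patch attaches to global_train_batch_size encodes the rollout bookkeeping. With the values from the SMD workflow YAML above (assuming the fields mean what their names suggest), the arithmetic is easy to check:

global_train_batch_size = 64
num_batches_per_update = 8
generations_per_prompt = 8

unique_prompts_per_update = (
    global_train_batch_size * num_batches_per_update // generations_per_prompt
)
print(unique_prompts_per_update)   # 64 unique prompts, each rolled out 8 times, per policy update

So each policy update consumes 512 sampled sequences drawn from 64 distinct prompts.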
--- compose_rl/algorithms/online/model_methods.py | 2 +- yamls/single-controller-grpo-workflow.yaml | 33 +++++++------------ yamls/single-controller-smd-workflow.yaml | 26 +++++++-------- 3 files changed, 25 insertions(+), 36 deletions(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index ee98272d..e6de90fd 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -565,7 +565,7 @@ def online_rl_loss( # and (2) the per-sequence masked sum over actions, both size [batch]. # We then average over [batch], so the interpretation is (1) the average per-token # reward, and (2) the average total reward. - if 'reward' in key: + if 'reward' in key: if value.shape == batch['action_mask'].shape: # Average reward per timestep return_dict['env/' + str(key) + '_mean'] = utils.masked_mean( diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index ac59150a..436abf97 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -1,4 +1,4 @@ -name: single-controller-hackathon +name: single-controller-hackathon_grpo image: mosaicml/dle:nightly-latest scheduling: @@ -46,11 +46,11 @@ parameters: normalize_advantage: true use_flash_attention_2: true length_normalize_policy_loss: true - pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B + pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-7B loggers: mlflow: tags: - run: test_single_controller_ppo_deepseek_l8b_open_r1_48k + run: test_single_controller_ppo_deepseek_l8b_open_r1_48k_grpo_max_async_step_2 group: grpo tracking_uri: databricks experiment_name: test_single_controller_ppo @@ -89,7 +89,7 @@ parameters: alpha: 1 t_warmup: 10iter tokenizer: - name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B + name: deepseek-ai/DeepSeek-R1-Distill-Qwen-7B kwargs: padding: longest pad_token: <|finetune_right_pad_id|> @@ -103,19 +103,8 @@ parameters: name: MinibatchRolloutBuffer rewards: math_verifier: - reward: 4 - reward_type: math_verifier - bad_generation_end: - reward: -1 - eos_penalty: true - reward_type: bad_generation_end - math_format_verifier: reward: 1 - reward_type: math_format_verifier - penalize_extra_short_responses: - reward: -1 - reward_type: short_response_reward - len_threshold: 10 + reward_type: math_verifier lambda_gae: 1 global_seed: 17 max_gen_len: 8192 @@ -126,7 +115,7 @@ parameters: kl_controller: kl_ctl_type: fixed init_kl_coef: 0 - tokenizer_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B + tokenizer_name: deepseek-ai/DeepSeek-R1-Distill-Qwen-7B num_train_nodes: 1 reference_model: precision: amp_bf16 @@ -136,7 +125,7 @@ parameters: pretrained: true use_auth_token: true use_flash_attention_2: true - pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B + pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-7B generation_kwargs: top_p: 1 do_sample: true @@ -168,7 +157,7 @@ parameters: max_seq_len: 10240 save_folder: /tmp/checkpoints dist_timeout: 1800 - max_duration: 10iter + max_duration: 100iter progress_bar: false train_loader: name: prompt @@ -183,7 +172,7 @@ parameters: download_timeout: 1800 drop_last: true num_workers: 1 - eval_interval: 2iter + eval_interval: 5iter save_interval: 100iter log_to_console: true save_overwrite: true @@ -191,9 +180,9 @@ parameters: console_log_interval: 1ba device_eval_batch_size: 1 eval_subset_num_batches: -1 - global_train_batch_size: 64 + 
global_train_batch_size: 64 # global_train_batch_size * num_batches_per_update / generations_per_prompt = number of unique prompts device_train_microbatch_size: 1 vllm_tensor_parallel_size: 1 vllm_enable_prefix_caching: false save_num_checkpoints_to_keep: 1 - max_async_step: 0 + max_async_step: 2 diff --git a/yamls/single-controller-smd-workflow.yaml b/yamls/single-controller-smd-workflow.yaml index cb2b7196..7363a518 100644 --- a/yamls/single-controller-smd-workflow.yaml +++ b/yamls/single-controller-smd-workflow.yaml @@ -33,11 +33,11 @@ parameters: model: name: hf_critic_free_lm loss_type: smd #grpo - target_kl: 100000 # is it used in SDM? + target_kl: 0.1 # is it used in SDM? pretrained: true init_device: mixed kl_estimator: k3 - beta: 1e-3 #0.01 + beta: 0.001 #0.01 kl_clip_range: 40 use_auth_token: true compute_kl_loss: false @@ -47,11 +47,11 @@ parameters: normalize_advantage: false use_flash_attention_2: true length_normalize_policy_loss: true - pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B + pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-7B loggers: mlflow: tags: - run: test_single_controller_ppo_deepseek_l8b_open_r1_48k_smd_beta_0.01_target_kl_100000 + run: test_single_controller_ppo_deepseek_l8b_open_r1_48k_smd_beta_0.001_max_async_step_2 group: grpo tracking_uri: databricks experiment_name: test_single_controller_ppo @@ -90,7 +90,7 @@ parameters: alpha: 1 t_warmup: 10iter tokenizer: - name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B + name: deepseek-ai/DeepSeek-R1-Distill-Qwen-7B kwargs: padding: longest pad_token: <|finetune_right_pad_id|> @@ -116,7 +116,7 @@ parameters: kl_controller: kl_ctl_type: fixed init_kl_coef: 0 - tokenizer_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B + tokenizer_name: deepseek-ai/DeepSeek-R1-Distill-Qwen-7B num_train_nodes: 1 reference_model: precision: amp_bf16 @@ -126,7 +126,7 @@ parameters: pretrained: true use_auth_token: true use_flash_attention_2: true - pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B + pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-7B generation_kwargs: top_p: 1 do_sample: true @@ -139,7 +139,7 @@ parameters: algorithms: gradient_clipping: clipping_type: norm - clipping_threshold: 1.0 #0.001 + clipping_threshold: 0.001 autoresume: true log_config: true fsdp_config: @@ -158,7 +158,7 @@ parameters: max_seq_len: 10240 save_folder: /tmp/checkpoints dist_timeout: 1800 - max_duration: 200iter + max_duration: 100iter progress_bar: false train_loader: name: prompt @@ -173,7 +173,7 @@ parameters: download_timeout: 1800 drop_last: true num_workers: 1 - eval_interval: 10iter + eval_interval: 5iter eval_first: false save_interval: 100iter log_to_console: true @@ -182,9 +182,9 @@ parameters: console_log_interval: 1ba device_eval_batch_size: 1 eval_subset_num_batches: -1 - global_train_batch_size: 64 - device_train_microbatch_size: 1 # Increased for better SMD batch statistics + global_train_batch_size: 64 # global_train_batch_size * num_batches_per_update / generations_per_prompt = number of unique prompts + device_train_microbatch_size: 1 vllm_tensor_parallel_size: 1 vllm_enable_prefix_caching: false save_num_checkpoints_to_keep: 1 - max_async_step: 0 + max_async_step: 2 From 4fa329a5e008d032108aace0153a228dbf85f84c Mon Sep 17 00:00:00 2001 From: wensun Date: Sun, 31 Aug 2025 22:24:41 -0400 Subject: [PATCH 39/74] add vllm logp and importance weight --- .../generation_utils/generation_utils.py | 39 ++++++++++++++++--- 
compose_rl/algorithms/online/model_methods.py | 9 ++++- test_single_controller_ppo.py | 39 ++++++++++++++++++- yamls/single-controller-grpo-workflow.yaml | 12 +++--- yamls/single-controller-smd-workflow.yaml | 12 +++--- 5 files changed, 90 insertions(+), 21 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/generation_utils.py b/compose_rl/algorithms/online/generation_utils/generation_utils.py index beea9176..fd0ab83a 100644 --- a/compose_rl/algorithms/online/generation_utils/generation_utils.py +++ b/compose_rl/algorithms/online/generation_utils/generation_utils.py @@ -92,13 +92,14 @@ def _vllm_generate( pad_token_id: int, # type: ignore all_prompts: list, batch_sizes: list, -) -> list: +) -> tuple[list, list]: futs = [] sampling_params = { 'temperature': generation_kwargs.get('temperature', 1.0), 'top_p': generation_kwargs.get('top_p', 1.0), 'top_k': generation_kwargs.get('top_k', -1), 'max_tokens': max_gen_len, + 'logprobs': 1, # to get the logprobs directly from vllm } # We have to remove all pad tokens here @@ -138,11 +139,13 @@ def _vllm_generate( start_time = time.time() results = ray.get(futs) all_responses = [] + all_logprobs = [] # Get all of the ray futures for i, result in enumerate(results): # Each result is a list of responses this assumes one output per input all_responses.extend([resp.outputs[0].token_ids for resp in result]) + all_logprobs.extend([[list(datum.values())[0] for datum in resp.outputs[0].logprobs] for resp in result]) log.info( f'took: {time.time() - start_time} to gather futures', @@ -150,13 +153,17 @@ def _vllm_generate( # Distribute padded responses back to the correct device split_responses = [] + split_logprobs = [] start = 0 for size in batch_sizes: split_responses.append( all_responses[start:start + size], ) + split_logprobs.append( + all_logprobs[start:start + size], + ) start += size - return split_responses + return split_responses, split_logprobs def _vllm_chat( @@ -254,7 +261,7 @@ def vllm_generate( generation_kwargs: dict, tokenizer: Tokenizer, vllm_generate_function: str, -) -> torch.Tensor: +) -> tuple[torch.Tensor, torch.Tensor]: """Run vllm chat on the prompts using messages. Runs generate over a set of sequences in the batch. 
It also does extra computation @@ -320,7 +327,7 @@ def vllm_generate( batch_sizes, ) else: - split_responses = _vllm_generate( + split_responses, split_logprobs = _vllm_generate( vllm_engines, max_gen_len, generation_kwargs, @@ -335,6 +342,7 @@ def vllm_generate( all_prompts = None all_messages = None split_responses = None + split_logprobs = None # Do another garbage collection and empty the cache gc.collect() @@ -345,12 +353,19 @@ def vllm_generate( # Scatter the generated responses back to the correct rank local_responses = [None] + local_logprobs = [None] start_time = time.time() torch.distributed.scatter_object_list( local_responses, split_responses, src=0, ) + torch.distributed.scatter_object_list( + local_logprobs, + split_logprobs, + src=0, + ) + local_logprobs = local_logprobs[0] local_responses = local_responses[0] log.info(f'took: {time.time() - start_time} to scatter prompts') @@ -379,7 +394,21 @@ def vllm_generate( # Construct full sequences from the prompt and padded responses sequences = torch.cat([prompt_tokens, padded_responses], dim=-1) num_tokens_generated = sequences.size(1) - prompt_tokens.size(1) + + padded_logprobs = [] + for logprobs in local_logprobs: # type: ignore + logprobs = list(logprobs) + if len(logprobs) < max_vllm_generated_len: + logprobs = logprobs + [0] * (max_vllm_generated_len - len(logprobs)) + padded_logprobs.append(logprobs) + + vllm_logprobs = torch.tensor( + padded_logprobs, + dtype=torch.float, + device=cur_device, + ) + log.info( f'It took {time.time() - start_gen_time} to generate {num_tokens_generated} tokens', ) - return sequences + return sequences, vllm_logprobs diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index e6de90fd..3a7aa5a1 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -404,7 +404,12 @@ def policy_loss( online_log_probs = outputs['online_log_probs'] ref_log_probs = batch['ift_log_probs'] old_entropies = batch['old_entropies'] - old_log_probs = batch['old_log_probs'] + old_log_probs = batch['old_log_probs'] # note this is the log prob of the pi_prox -- the usual pi_old in ppo language. 
+ vllm_logprobs = batch['vllm_logprobs'] # note this the log prob from vllm when generating the rollouts, i.e., log pi_behavior + + importance_ratio = torch.exp(old_log_probs - vllm_logprobs) # pi_prox / pi_behavior + importance_ratio = torch.clamp(importance_ratio, min = 0.0, max = 10) + online_to_old_diff = online_log_probs - old_log_probs # ln(π/π_old) for SMD #compute KL to pi_ref to keep track the divergence to \pi_ref @@ -438,7 +443,7 @@ def policy_loss( ) #size: (batch_size,) # Convert beta to a simple float beta_float = float(beta) - policy_loss = ((beta_float * masked_log_probs_diff - prompt_advantages)**2).mean() + policy_loss = (importance_ratio*((beta_float * masked_log_probs_diff - prompt_advantages)**2)).mean() rewards = utils.masked_sum( batch['rewards'], diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index 012b40b8..3508baf1 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -514,8 +514,9 @@ def get_log_probs_and_entropy(self, current_rank_rollouts: dict[str, Any], devic prompt_dtype = prompt_tokens.dtype assert 'sequences' in current_rank_rollouts, f'sequences is not in batch {current_rank_rollouts.keys()=}' - + assert 'vllm_logprobs' in current_rank_rollouts, f'vllm_logprobs is not in batch {current_rank_rollouts.keys()=}' sequences = current_rank_rollouts['sequences'] + vllm_logprobs = current_rank_rollouts['vllm_logprobs'] generated_len = torch.ones( batch_size, device=device, @@ -540,6 +541,16 @@ def get_log_probs_and_entropy(self, current_rank_rollouts: dict[str, Any], devic dim=-1, # type: ignore ) + extra_zero_padding = torch.zeros( + (batch_size, len_to_pad), + device=device, + dtype=torch.float, + ) + vllm_logprobs = torch.cat( + [vllm_logprobs, extra_zero_padding], # type: ignore + dim=-1, # type: ignore + ) + # Sanity checking we're adding max_gen_len to prompt_tokens if prompt_tokens.size(1) + self.max_gen_len != sequences.size(1): raise ValueError( @@ -548,6 +559,7 @@ def get_log_probs_and_entropy(self, current_rank_rollouts: dict[str, Any], devic # Actions are what tokens the current policy would generate. actions = sequences[:, -self.max_gen_len:] # type: ignore + vllm_logprobs_gen = vllm_logprobs[:, -self.max_gen_len:] # type: ignore right_padded_obs = switch_left_to_right_padding( sequences, @@ -624,6 +636,9 @@ def get_log_probs_and_entropy(self, current_rank_rollouts: dict[str, Any], devic device_train_microbatch_log_probs = torch.cat(log_probs) device_train_microbatch_entropies = torch.cat(entropies) + assert vllm_logprobs_gen.shape == device_train_microbatch_log_probs.shape, f'vllm_logprobs_gen and device_train_microbatch_log_probs have different shapes {vllm_logprobs_gen.shape=}, {device_train_microbatch_log_probs.shape=}' + + partial_env_output = { 'prompt_id': prompt_id, 'old_log_probs': device_train_microbatch_log_probs, @@ -634,6 +649,7 @@ def get_log_probs_and_entropy(self, current_rank_rollouts: dict[str, Any], devic 'action_mask': action_mask, 'generated_len': generated_len, 'prompt_len': prompt_len, + 'vllm_logprobs': vllm_logprobs_gen, } if len(values) > 0: device_train_microbatch_values = torch.cat(values) @@ -1443,7 +1459,7 @@ def get_next_iter_rollouts(self): # TODO: Since this functionality is (somewhat) shared across the OnPolicyCallback and the RolloutAgent, # we should move this to the separate util file. 
with get_precision_context(self.precision), torch.no_grad(): - sequences = _vllm_generate( + sequences, vllm_logprobs = _vllm_generate( vllm_engines=self.inference_server.engines, max_gen_len=self.max_gen_len, generation_kwargs=self.generation_kwargs, @@ -1453,6 +1469,8 @@ def get_next_iter_rollouts(self): ) sequences = sequences[0] + vllm_logprobs = vllm_logprobs[0] + max_vllm_generated_len = max([len(response) for response in sequences]) padded_responses = [] for sequence in sequences: @@ -1470,6 +1488,23 @@ def get_next_iter_rollouts(self): processed_sequences = torch.cat([all_prompts, padded_responses], dim=-1) iter_data['sequences'] = processed_sequences + padded_logprobs = [] + for logprobs in vllm_logprobs: + logprobs = list(logprobs) + if len(logprobs) < max_vllm_generated_len: + logprobs = logprobs + [0] * (max_vllm_generated_len - len(logprobs)) + padded_logprobs.append(logprobs) + padded_logprobs = torch.tensor( + padded_logprobs, + dtype=torch.float, + device=torch.device('cpu'), + ) + temp_zeros = torch.zeros_like(all_prompts, dtype=torch.float, device=torch.device('cpu')) + processed_logprobs = torch.cat([temp_zeros, padded_logprobs], dim=-1) + iter_data['vllm_logprobs'] = processed_logprobs + assert processed_logprobs.shape == processed_sequences.shape, f'vllm_logprobs and sequences have different shapes {processed_logprobs.shape=}, {processed_sequences.shape=}' + + # Calculate the rewards here # Initialize the required variables from the reward actor tokenizer = self.tokenizer diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index 436abf97..aa6c779d 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -46,11 +46,11 @@ parameters: normalize_advantage: true use_flash_attention_2: true length_normalize_policy_loss: true - pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-7B + pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B #deepseek-ai/DeepSeek-R1-Distill-Qwen-7B loggers: mlflow: tags: - run: test_single_controller_ppo_deepseek_l8b_open_r1_48k_grpo_max_async_step_2 + run: test_single_controller_grpo_ds_llama_7b_open_r1_48k_grpo_max_async_step_4 group: grpo tracking_uri: databricks experiment_name: test_single_controller_ppo @@ -89,7 +89,7 @@ parameters: alpha: 1 t_warmup: 10iter tokenizer: - name: deepseek-ai/DeepSeek-R1-Distill-Qwen-7B + name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B #deepseek-ai/DeepSeek-R1-Distill-Qwen-7B kwargs: padding: longest pad_token: <|finetune_right_pad_id|> @@ -115,7 +115,7 @@ parameters: kl_controller: kl_ctl_type: fixed init_kl_coef: 0 - tokenizer_name: deepseek-ai/DeepSeek-R1-Distill-Qwen-7B + tokenizer_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B #deepseek-ai/DeepSeek-R1-Distill-Qwen-7B num_train_nodes: 1 reference_model: precision: amp_bf16 @@ -125,7 +125,7 @@ parameters: pretrained: true use_auth_token: true use_flash_attention_2: true - pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-7B + pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B #deepseek-ai/DeepSeek-R1-Distill-Qwen-7B generation_kwargs: top_p: 1 do_sample: true @@ -185,4 +185,4 @@ parameters: vllm_tensor_parallel_size: 1 vllm_enable_prefix_caching: false save_num_checkpoints_to_keep: 1 - max_async_step: 2 + max_async_step: 4 diff --git a/yamls/single-controller-smd-workflow.yaml b/yamls/single-controller-smd-workflow.yaml index 7363a518..cd6cf35a 100644 --- 
a/yamls/single-controller-smd-workflow.yaml +++ b/yamls/single-controller-smd-workflow.yaml @@ -37,7 +37,7 @@ parameters: pretrained: true init_device: mixed kl_estimator: k3 - beta: 0.001 #0.01 + beta: 0.1 #0.01 kl_clip_range: 40 use_auth_token: true compute_kl_loss: false @@ -47,11 +47,11 @@ parameters: normalize_advantage: false use_flash_attention_2: true length_normalize_policy_loss: true - pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-7B + pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B #deepseek-ai/DeepSeek-R1-Distill-Qwen-7B loggers: mlflow: tags: - run: test_single_controller_ppo_deepseek_l8b_open_r1_48k_smd_beta_0.001_max_async_step_2 + run: test_single_controller_smd_beta_0.1_max_async_step_2_ds_llama_7b group: grpo tracking_uri: databricks experiment_name: test_single_controller_ppo @@ -90,7 +90,7 @@ parameters: alpha: 1 t_warmup: 10iter tokenizer: - name: deepseek-ai/DeepSeek-R1-Distill-Qwen-7B + name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B #deepseek-ai/DeepSeek-R1-Distill-Qwen-7B kwargs: padding: longest pad_token: <|finetune_right_pad_id|> @@ -116,7 +116,7 @@ parameters: kl_controller: kl_ctl_type: fixed init_kl_coef: 0 - tokenizer_name: deepseek-ai/DeepSeek-R1-Distill-Qwen-7B + tokenizer_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B #deepseek-ai/DeepSeek-R1-Distill-Qwen-7B num_train_nodes: 1 reference_model: precision: amp_bf16 @@ -126,7 +126,7 @@ parameters: pretrained: true use_auth_token: true use_flash_attention_2: true - pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-7B + pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B #deepseek-ai/DeepSeek-R1-Distill-Qwen-7B generation_kwargs: top_p: 1 do_sample: true From e414dcde2e32209fa0fb97911518c05f778abec9 Mon Sep 17 00:00:00 2001 From: wensun Date: Sun, 31 Aug 2025 22:25:37 -0400 Subject: [PATCH 40/74] . --- compose_rl/algorithms/online/model_methods.py | 1 + 1 file changed, 1 insertion(+) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index 3a7aa5a1..892c6b0a 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -442,6 +442,7 @@ def policy_loss( dim=-1, ) #size: (batch_size,) # Convert beta to a simple float + assert importance_ratio.shape == masked_log_probs_diff.shape, f'importance_ratio and masked_log_probs_diff have different shapes {importance_ratio.shape=}, {masked_log_probs_diff.shape=}' beta_float = float(beta) policy_loss = (importance_ratio*((beta_float * masked_log_probs_diff - prompt_advantages)**2)).mean() From 603d748a7d4a1970213aa8bda6955903941c6c11 Mon Sep 17 00:00:00 2001 From: wensun Date: Sun, 31 Aug 2025 22:38:57 -0400 Subject: [PATCH 41/74] . 
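Patches 39 through 41 add an off-policy correction: the log-probs the trainer recomputes for a rollout (pi_prox, stored as old_log_probs) can differ from the distribution vLLM actually sampled from (pi_behavior, now carried as vllm_logprobs), a gap that asynchronous rollouts (max_async_step above 0 in these YAMLs) make more likely, so the squared SMD residual is reweighted by their ratio. A minimal sketch of just that correction, mirroring the added code with invented numbers:

import torch

# per-token log-probs under the proximal policy (trainer forward pass)
old_log_probs = torch.tensor([[-1.2, -0.7, -2.1], [-0.9, -1.5, -0.3]])
# per-token log-probs recorded by the vLLM engine at rollout time
vllm_logprobs = torch.tensor([[-1.0, -0.9, -2.0], [-1.1, -1.4, -0.5]])

importance_ratio = torch.exp(old_log_probs - vllm_logprobs)           # pi_prox / pi_behavior, per token
importance_ratio = torch.clamp(importance_ratio, min=0.0, max=10.0)   # cap extreme weights
print(importance_ratio)

The clamp keeps a single badly mismatched token from dominating a batch, and the mean of this quantity is what the importance_ratio/mean metric added below tracks.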
--- compose_rl/algorithms/online/model_methods.py | 3 +++ yamls/single-controller-smd-workflow.yaml | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index 892c6b0a..7534273e 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -464,6 +464,9 @@ def policy_loss( 'advantages/mean': torch.mean( prompt_advantages, # SMD uses prompt_advantages, not advantages ), #compute the average of the prompt advantages for SMD + 'importance_ratio/mean': torch.mean( + importance_ratio, + ), } return policy_dict else: diff --git a/yamls/single-controller-smd-workflow.yaml b/yamls/single-controller-smd-workflow.yaml index cd6cf35a..a56d2838 100644 --- a/yamls/single-controller-smd-workflow.yaml +++ b/yamls/single-controller-smd-workflow.yaml @@ -51,7 +51,7 @@ parameters: loggers: mlflow: tags: - run: test_single_controller_smd_beta_0.1_max_async_step_2_ds_llama_7b + run: test_single_controller_smd_beta_0.1_max_async_step_2_ds_llama_7b_IS group: grpo tracking_uri: databricks experiment_name: test_single_controller_ppo From 35a365eb849c0e3afe8882bd2a8480c724339916 Mon Sep 17 00:00:00 2001 From: wensun Date: Sun, 31 Aug 2025 23:14:10 -0400 Subject: [PATCH 42/74] . --- compose_rl/algorithms/online/model_methods.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index 7534273e..db29359c 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -452,6 +452,10 @@ def policy_loss( dim=-1, ) + print('===============================') + print(f'importance_ratio: {importance_ratio.shape=}, {importance_ratio=}') + print('===============================') + policy_dict = { 'loss/policy_loss': policy_loss, 'kl/policy_kl': policy_kl, # Required by calling code in model.py From 549428f91fd4ae95bd412083ab0926f047b85f55 Mon Sep 17 00:00:00 2001 From: wensun Date: Sun, 31 Aug 2025 23:15:10 -0400 Subject: [PATCH 43/74] . --- yamls/single-controller-smd-workflow.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yamls/single-controller-smd-workflow.yaml b/yamls/single-controller-smd-workflow.yaml index a56d2838..5e3642f8 100644 --- a/yamls/single-controller-smd-workflow.yaml +++ b/yamls/single-controller-smd-workflow.yaml @@ -14,7 +14,7 @@ integrations: path: /workspace/compose-rl git_repo: databricks/compose-rl ssh_clone: true - git_branch: single-controller-hackathon-smd #single-controller-hackathon + git_branch: single-controller-hackathon-smd-vllm #single-controller-hackathon - integration_type: git_repo path: /workspace/research-universe git_repo: databricks-mosaic/research-universe From be53de69e5a797dbe0c58911925fb8d2e250cc4f Mon Sep 17 00:00:00 2001 From: wensun Date: Sun, 31 Aug 2025 23:26:25 -0400 Subject: [PATCH 44/74] . --- compose_rl/algorithms/online/model_methods.py | 22 +++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index db29359c..7d991b3f 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -407,6 +407,14 @@ def policy_loss( old_log_probs = batch['old_log_probs'] # note this is the log prob of the pi_prox -- the usual pi_old in ppo language. 
vllm_logprobs = batch['vllm_logprobs'] # note this the log prob from vllm when generating the rollouts, i.e., log pi_behavior + print('===============================') + print(f'old_log_probs: {old_log_probs.shape=}, {old_log_probs=}') + print(f'vllm_logprobs: {vllm_logprobs.shape=}, {vllm_logprobs=}') + print('===============================') + + assert old_log_probs.shape == vllm_logprobs.shape, f'old_log_probs and vllm_logprobs have different shapes {old_log_probs.shape=}, {vllm_logprobs.shape=}' + + importance_ratio = torch.exp(old_log_probs - vllm_logprobs) # pi_prox / pi_behavior importance_ratio = torch.clamp(importance_ratio, min = 0.0, max = 10) @@ -441,10 +449,16 @@ def policy_loss( batch['action_mask'], dim=-1, ) #size: (batch_size,) + masked_importance_ratio = utils.masked_sum( + importance_ratio, + batch['action_mask'], + dim=-1, + ) #size: (batch_size,) + # Convert beta to a simple float - assert importance_ratio.shape == masked_log_probs_diff.shape, f'importance_ratio and masked_log_probs_diff have different shapes {importance_ratio.shape=}, {masked_log_probs_diff.shape=}' + assert masked_importance_ratio.shape == masked_log_probs_diff.shape, f'masked_importance_ratio and masked_log_probs_diff have different shapes {importance_ratio.shape=}, {masked_log_probs_diff.shape=}' beta_float = float(beta) - policy_loss = (importance_ratio*((beta_float * masked_log_probs_diff - prompt_advantages)**2)).mean() + policy_loss = (masked_importance_ratio*((beta_float * masked_log_probs_diff - prompt_advantages)**2)).mean() rewards = utils.masked_sum( batch['rewards'], @@ -453,7 +467,7 @@ def policy_loss( ) print('===============================') - print(f'importance_ratio: {importance_ratio.shape=}, {importance_ratio=}') + print(f'masked_importance_ratio: {masked_importance_ratio.shape=}, {masked_importance_ratio=}') print('===============================') policy_dict = { @@ -469,7 +483,7 @@ def policy_loss( prompt_advantages, # SMD uses prompt_advantages, not advantages ), #compute the average of the prompt advantages for SMD 'importance_ratio/mean': torch.mean( - importance_ratio, + masked_importance_ratio, ), } return policy_dict From e5f81641a4f9c5e911928202f8aa878b03fcbaed Mon Sep 17 00:00:00 2001 From: wensun Date: Sun, 31 Aug 2025 23:33:26 -0400 Subject: [PATCH 45/74] . --- test_single_controller_ppo.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index 3508baf1..e4c76f41 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -28,6 +28,7 @@ from composer.loggers import MLFlowLogger import ray +ray.init(_internal_config={"logging_level": "DEBUG"}) import spacy import torch import torch.distributed as dist From 2e159c08f56d08af782672f811fe348cb9b377bc Mon Sep 17 00:00:00 2001 From: wensun Date: Sun, 31 Aug 2025 23:35:12 -0400 Subject: [PATCH 46/74] . --- test_single_controller_ppo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index e4c76f41..22d3fde7 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -28,7 +28,7 @@ from composer.loggers import MLFlowLogger import ray -ray.init(_internal_config={"logging_level": "DEBUG"}) +ray.init(logging_level="DEBUG") import spacy import torch import torch.distributed as dist From 49a451aad12ba05d52a10053a236326cde73507d Mon Sep 17 00:00:00 2001 From: wensun Date: Sun, 31 Aug 2025 23:38:51 -0400 Subject: [PATCH 47/74] . 
--- compose_rl/utils/ray_utils.py | 2 +- scripts/launch_composer_ray.py | 2 +- test_single_controller_ppo.py | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/compose_rl/utils/ray_utils.py b/compose_rl/utils/ray_utils.py index adf51d29..d8410ce6 100644 --- a/compose_rl/utils/ray_utils.py +++ b/compose_rl/utils/ray_utils.py @@ -45,7 +45,7 @@ def init_ray_with_torch_distributed(timeout_seconds: int = 30): # Start Ray Server on master node subprocess.run(['ray', 'start', '--head'], check=True) # connect to the ray cluster - ray.init('auto') + ray.init(logging_level="DEBUG") # get existing ray ip and port ctx = ray.get_runtime_context() address = ctx.gcs_address diff --git a/scripts/launch_composer_ray.py b/scripts/launch_composer_ray.py index b093c220..913a1cda 100644 --- a/scripts/launch_composer_ray.py +++ b/scripts/launch_composer_ray.py @@ -161,7 +161,7 @@ def start_ray_nodes(): # Send the local node IP to other ranks broadcast_string(ip, src_rank=0) - ray.init() + ray.init(logging_level="DEBUG") # Wait for all ray clusters to start dist.barrier() diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index 22d3fde7..3508baf1 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -28,7 +28,6 @@ from composer.loggers import MLFlowLogger import ray -ray.init(logging_level="DEBUG") import spacy import torch import torch.distributed as dist From 3799abb00761fa7687b75cb11653dd51f6876dc7 Mon Sep 17 00:00:00 2001 From: wensun Date: Sun, 31 Aug 2025 23:47:50 -0400 Subject: [PATCH 48/74] . --- .../online/generation_utils/generation_utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/compose_rl/algorithms/online/generation_utils/generation_utils.py b/compose_rl/algorithms/online/generation_utils/generation_utils.py index fd0ab83a..1d73109e 100644 --- a/compose_rl/algorithms/online/generation_utils/generation_utils.py +++ b/compose_rl/algorithms/online/generation_utils/generation_utils.py @@ -151,6 +151,10 @@ def _vllm_generate( f'took: {time.time() - start_time} to gather futures', ) + print('===============================') + print(f'all_logprobs: {all_logprobs=}') + print('===============================') + # Distribute padded responses back to the correct device split_responses = [] split_logprobs = [] @@ -163,6 +167,11 @@ def _vllm_generate( all_logprobs[start:start + size], ) start += size + + print('===============================') + print(f'split_logprobs: {split_logprobs=}') + print('===============================') + return split_responses, split_logprobs From 83c44944ad5fc97904b3e4eb5a8dbc4de915b14a Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 1 Sep 2025 00:30:50 -0400 Subject: [PATCH 49/74] . 
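Drop the generation-utils prints and check the rollout side instead: `vllm_logprobs` has to line up token-for-token with `sequences`, i.e. zeros over the prompt region followed by the zero-padded per-response logprobs, which is what the added shape prints and the existing assert verify; the YAML is also trimmed to a single eval and a smaller batch size while debugging. A rough sketch of the alignment being checked (function and argument names are illustrative):

    import torch

    def align_logprobs(prompts: torch.Tensor, response_logprobs: list) -> torch.Tensor:
        # pad ragged per-response logprob lists, then prepend zeros for the prompt tokens
        max_len = max(len(lp) for lp in response_logprobs)
        padded = [list(lp) + [0.0] * (max_len - len(lp)) for lp in response_logprobs]
        padded = torch.tensor(padded, dtype=torch.float)               # (bs, max_gen_len)
        prompt_zeros = torch.zeros_like(prompts, dtype=torch.float)    # (bs, prompt_len)
        return torch.cat([prompt_zeros, padded], dim=-1)               # (bs, prompt_len + max_gen_len)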
--- .../online/generation_utils/generation_utils.py | 7 ------- test_single_controller_ppo.py | 9 +++++++++ yamls/single-controller-smd-workflow.yaml | 8 ++++---- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/generation_utils.py b/compose_rl/algorithms/online/generation_utils/generation_utils.py index 1d73109e..af863ef5 100644 --- a/compose_rl/algorithms/online/generation_utils/generation_utils.py +++ b/compose_rl/algorithms/online/generation_utils/generation_utils.py @@ -151,10 +151,6 @@ def _vllm_generate( f'took: {time.time() - start_time} to gather futures', ) - print('===============================') - print(f'all_logprobs: {all_logprobs=}') - print('===============================') - # Distribute padded responses back to the correct device split_responses = [] split_logprobs = [] @@ -168,9 +164,6 @@ def _vllm_generate( ) start += size - print('===============================') - print(f'split_logprobs: {split_logprobs=}') - print('===============================') return split_responses, split_logprobs diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index 3508baf1..557b2439 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -1471,6 +1471,11 @@ def get_next_iter_rollouts(self): sequences = sequences[0] vllm_logprobs = vllm_logprobs[0] + print('===============================') + print(f'len(sequences): {len(sequences)=}') + print(f'len(vllm_logprobs): {len(vllm_logprobs)=}') + print('===============================') + max_vllm_generated_len = max([len(response) for response in sequences]) padded_responses = [] for sequence in sequences: @@ -1502,6 +1507,10 @@ def get_next_iter_rollouts(self): temp_zeros = torch.zeros_like(all_prompts, dtype=torch.float, device=torch.device('cpu')) processed_logprobs = torch.cat([temp_zeros, padded_logprobs], dim=-1) iter_data['vllm_logprobs'] = processed_logprobs + print('===============================') + print(f"processed_logprobs.shape: {processed_logprobs.shape=}") + print(f"processed_sequences.shape: {processed_sequences.shape=}") + print('===============================') assert processed_logprobs.shape == processed_sequences.shape, f'vllm_logprobs and sequences have different shapes {processed_logprobs.shape=}, {processed_sequences.shape=}' diff --git a/yamls/single-controller-smd-workflow.yaml b/yamls/single-controller-smd-workflow.yaml index 5e3642f8..1be3214d 100644 --- a/yamls/single-controller-smd-workflow.yaml +++ b/yamls/single-controller-smd-workflow.yaml @@ -59,9 +59,9 @@ parameters: ppo: {} orl_eval: evals: - #- name: gsm8k - - name: math_500 - - name: math_hard + - name: gsm8k + #- name: math_500 + #- name: math_hard eval_overrides: generation_params: max_tokens: 8192 @@ -182,7 +182,7 @@ parameters: console_log_interval: 1ba device_eval_batch_size: 1 eval_subset_num_batches: -1 - global_train_batch_size: 64 # global_train_batch_size * num_batches_per_update / generations_per_prompt = number of unique prompts + global_train_batch_size: 16 # global_train_batch_size * num_batches_per_update / generations_per_prompt = number of unique prompts device_train_microbatch_size: 1 vllm_tensor_parallel_size: 1 vllm_enable_prefix_caching: false From 5572515bb01ee9bc18959c5099469571793849c6 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 1 Sep 2025 00:42:37 -0400 Subject: [PATCH 50/74] . 
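After aligning to the longest response in the batch, `sequences` gets padded a second time out to the fixed prompt_len + max_gen_len budget; this patch gives `vllm_logprobs` the same treatment with extra zero padding so the two tensors keep identical widths downstream. A rough sketch of that second padding stage (names and the pad-token handling are simplified assumptions):

    import torch

    def pad_to_generation_budget(sequences, vllm_logprobs, prompt_len, max_gen_len, pad_token_id=0):
        # right-pad responses and their logprobs to a fixed prompt_len + max_gen_len width
        len_to_pad = prompt_len + max_gen_len - sequences.size(1)
        if len_to_pad > 0:
            bs = sequences.size(0)
            pad_tokens = torch.full((bs, len_to_pad), pad_token_id, dtype=sequences.dtype)
            sequences = torch.cat([sequences, pad_tokens], dim=-1)
            zero_lp = torch.zeros((bs, len_to_pad), dtype=torch.float)
            vllm_logprobs = torch.cat([vllm_logprobs, zero_lp], dim=-1)
        return sequences, vllm_logprobs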
--- test_single_controller_ppo.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index 557b2439..f2a6736a 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -1471,11 +1471,6 @@ def get_next_iter_rollouts(self): sequences = sequences[0] vllm_logprobs = vllm_logprobs[0] - print('===============================') - print(f'len(sequences): {len(sequences)=}') - print(f'len(vllm_logprobs): {len(vllm_logprobs)=}') - print('===============================') - max_vllm_generated_len = max([len(response) for response in sequences]) padded_responses = [] for sequence in sequences: @@ -1493,12 +1488,21 @@ def get_next_iter_rollouts(self): processed_sequences = torch.cat([all_prompts, padded_responses], dim=-1) iter_data['sequences'] = processed_sequences + print('===============================') + print(f"processed_sequences.shape: {processed_sequences.shape=}") + print('===============================') + padded_logprobs = [] for logprobs in vllm_logprobs: logprobs = list(logprobs) if len(logprobs) < max_vllm_generated_len: logprobs = logprobs + [0] * (max_vllm_generated_len - len(logprobs)) padded_logprobs.append(logprobs) + + print('===============================') + print(f"len(padded_logprobs): {len(padded_logprobs)=}") + print('===============================') + padded_logprobs = torch.tensor( padded_logprobs, dtype=torch.float, @@ -1509,7 +1513,6 @@ def get_next_iter_rollouts(self): iter_data['vllm_logprobs'] = processed_logprobs print('===============================') print(f"processed_logprobs.shape: {processed_logprobs.shape=}") - print(f"processed_sequences.shape: {processed_sequences.shape=}") print('===============================') assert processed_logprobs.shape == processed_sequences.shape, f'vllm_logprobs and sequences have different shapes {processed_logprobs.shape=}, {processed_sequences.shape=}' @@ -1532,6 +1535,7 @@ def get_next_iter_rollouts(self): assert 'sequences' in iter_data, f'sequences is not in iter_data {iter_data.keys()=}' sequences = iter_data['sequences'] + vllm_logprobs = iter_data['vllm_logprobs'] generated_len = torch.ones( batch_size, device=cur_device, @@ -1556,6 +1560,16 @@ def get_next_iter_rollouts(self): dim=-1, # type: ignore ) + extra_zero_padding = torch.zeros( + (batch_size, len_to_pad), + device=cur_device, + dtype=torch.float, + ) + vllm_logprobs = torch.cat( + [vllm_logprobs, extra_zero_padding], # type: ignore + dim=-1, # type: ignore + ) + # Sanity checking we're adding max_gen_len to prompt_tokens if prompt_tokens.size(1) + max_gen_len != sequences.size(1): raise ValueError( From 0536ad23e1b6d13cb8c5367d9b2643ef003647a5 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 1 Sep 2025 00:49:04 -0400 Subject: [PATCH 51/74] . 
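This and the following patch chase a tensor-construction failure: `torch.tensor` only accepts rectangular nested lists, so any per-response logprob list longer than `max_vllm_generated_len` (or left unpadded) makes the conversion blow up, hence the shape print here and the try/except plus explicit length check added next. A tiny self-contained illustration of the failure mode:

    import torch

    ragged = [[0.1, 0.2, 0.3], [0.4]]            # responses of different lengths
    try:
        torch.tensor(ragged)                      # ValueError: expected sequence of length 3
    except ValueError as err:
        print(f'ragged lists cannot be tensorized directly: {err}')

    max_len = max(len(row) for row in ragged)
    rect = [row + [0.0] * (max_len - len(row)) for row in ragged]
    print(torch.tensor(rect).shape)               # torch.Size([2, 3])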
--- test_single_controller_ppo.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index f2a6736a..5949766c 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -1508,6 +1508,10 @@ def get_next_iter_rollouts(self): dtype=torch.float, device=torch.device('cpu'), ) + print('===============================') + print(f"padded_logprobs.shape: {padded_logprobs.shape=}") + print('===============================') + temp_zeros = torch.zeros_like(all_prompts, dtype=torch.float, device=torch.device('cpu')) processed_logprobs = torch.cat([temp_zeros, padded_logprobs], dim=-1) iter_data['vllm_logprobs'] = processed_logprobs From a42f45a3cd89ee07289f65e9584ff5b9b6e15a37 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 1 Sep 2025 00:55:31 -0400 Subject: [PATCH 52/74] . --- test_single_controller_ppo.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index 5949766c..e98a8d8a 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -1497,21 +1497,29 @@ def get_next_iter_rollouts(self): logprobs = list(logprobs) if len(logprobs) < max_vllm_generated_len: logprobs = logprobs + [0] * (max_vllm_generated_len - len(logprobs)) + else: + raise ValueError(f"logprobs.shape: {logprobs.shape=} is larger than max_vllm_generated_len: {max_vllm_generated_len=}") padded_logprobs.append(logprobs) print('===============================') print(f"len(padded_logprobs): {len(padded_logprobs)=}") print('===============================') - padded_logprobs = torch.tensor( - padded_logprobs, + try: + padded_logprobs = torch.tensor( + padded_logprobs, dtype=torch.float, - device=torch.device('cpu'), - ) + device=torch.device('cpu'), + ) + except Exception as e: + print(f"Error: {e}") + print(f"padded_logprobs: {padded_logprobs=}") + raise e + print('===============================') print(f"padded_logprobs.shape: {padded_logprobs.shape=}") print('===============================') - + temp_zeros = torch.zeros_like(all_prompts, dtype=torch.float, device=torch.device('cpu')) processed_logprobs = torch.cat([temp_zeros, padded_logprobs], dim=-1) iter_data['vllm_logprobs'] = processed_logprobs From 96bb36e5171c43c5ac0d6291df823d06e7cac37e Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 1 Sep 2025 01:03:05 -0400 Subject: [PATCH 53/74] . 
--- test_single_controller_ppo.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index e98a8d8a..17a81f7f 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -1472,6 +1472,13 @@ def get_next_iter_rollouts(self): vllm_logprobs = vllm_logprobs[0] max_vllm_generated_len = max([len(response) for response in sequences]) + max_vllm_logprobs_len = max([len(response) for response in vllm_logprobs]) + + print('===============================') + print(f"max_vllm_generated_len: {max_vllm_generated_len=}") + print(f"max_vllm_logprobs_len: {max_vllm_logprobs_len=}") + print('===============================') + padded_responses = [] for sequence in sequences: sequence = list(sequence) @@ -1498,6 +1505,7 @@ def get_next_iter_rollouts(self): if len(logprobs) < max_vllm_generated_len: logprobs = logprobs + [0] * (max_vllm_generated_len - len(logprobs)) else: + print(f"logprobs.shape: {logprobs.shape=} is larger than max_vllm_generated_len: {max_vllm_generated_len=}") raise ValueError(f"logprobs.shape: {logprobs.shape=} is larger than max_vllm_generated_len: {max_vllm_generated_len=}") padded_logprobs.append(logprobs) @@ -1508,14 +1516,14 @@ def get_next_iter_rollouts(self): try: padded_logprobs = torch.tensor( padded_logprobs, - dtype=torch.float, + dtype=torch.float, device=torch.device('cpu'), ) except Exception as e: print(f"Error: {e}") print(f"padded_logprobs: {padded_logprobs=}") raise e - + print('===============================') print(f"padded_logprobs.shape: {padded_logprobs.shape=}") print('===============================') From 5690a5d7e9963caf0e1b1f2bb3c00f40fd68189c Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 1 Sep 2025 08:43:05 -0400 Subject: [PATCH 54/74] . --- test_single_controller_ppo.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index 17a81f7f..c6d06512 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -1473,11 +1473,6 @@ def get_next_iter_rollouts(self): max_vllm_generated_len = max([len(response) for response in sequences]) max_vllm_logprobs_len = max([len(response) for response in vllm_logprobs]) - - print('===============================') - print(f"max_vllm_generated_len: {max_vllm_generated_len=}") - print(f"max_vllm_logprobs_len: {max_vllm_logprobs_len=}") - print('===============================') padded_responses = [] for sequence in sequences: @@ -1504,9 +1499,6 @@ def get_next_iter_rollouts(self): logprobs = list(logprobs) if len(logprobs) < max_vllm_generated_len: logprobs = logprobs + [0] * (max_vllm_generated_len - len(logprobs)) - else: - print(f"logprobs.shape: {logprobs.shape=} is larger than max_vllm_generated_len: {max_vllm_generated_len=}") - raise ValueError(f"logprobs.shape: {logprobs.shape=} is larger than max_vllm_generated_len: {max_vllm_generated_len=}") padded_logprobs.append(logprobs) print('===============================') From 1d219019331db5d57bf714619c536de8bb19da78 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 1 Sep 2025 08:49:11 -0400 Subject: [PATCH 55/74] . 
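The next few patches strip the remaining debug prints, fix the vLLM logprob extraction in generation_utils (each entry is a {token_id: Logprob} dict, so the float lives on the `.logprob` attribute rather than on the Logprob object itself), and rework the SMD importance weight to sum log-ratios over the action mask and exponentiate once, clamping in log space instead of clamping per-token ratios. A rough sketch of the reworked weight (names assumed):

    import torch

    def sequence_importance_weight(old_lp, vllm_lp, action_mask,
                                   clip_min=-100.0, clip_max=10.0):
        # exp( sum_t [log pi_old - log pi_behavior] ), clamped in log space
        log_ratio = ((old_lp - vllm_lp) * action_mask).sum(dim=-1)     # (bs,)
        return torch.exp(torch.clamp(log_ratio, clip_min, clip_max))   # (bs,)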
--- test_single_controller_ppo.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py index c6d06512..4dd20b12 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -1513,7 +1513,6 @@ def get_next_iter_rollouts(self): ) except Exception as e: print(f"Error: {e}") - print(f"padded_logprobs: {padded_logprobs=}") raise e print('===============================') From 472519c4ad08f5e533fc93879c3298ce19ed4427 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 1 Sep 2025 08:58:23 -0400 Subject: [PATCH 56/74] . --- .../algorithms/online/generation_utils/generation_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/generation_utils/generation_utils.py b/compose_rl/algorithms/online/generation_utils/generation_utils.py index af863ef5..537290d2 100644 --- a/compose_rl/algorithms/online/generation_utils/generation_utils.py +++ b/compose_rl/algorithms/online/generation_utils/generation_utils.py @@ -145,7 +145,7 @@ def _vllm_generate( for i, result in enumerate(results): # Each result is a list of responses this assumes one output per input all_responses.extend([resp.outputs[0].token_ids for resp in result]) - all_logprobs.extend([[list(datum.values())[0] for datum in resp.outputs[0].logprobs] for resp in result]) + all_logprobs.extend([[list(datum.values())[0].logprob for datum in resp.outputs[0].logprobs] for resp in result]) log.info( f'took: {time.time() - start_time} to gather futures', From e11b2a5ebb790a915315f1c6e12a99d4cc81834c Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 1 Sep 2025 09:15:55 -0400 Subject: [PATCH 57/74] . --- compose_rl/algorithms/online/model_methods.py | 24 ++++++------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index 7d991b3f..23a6d361 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -405,18 +405,10 @@ def policy_loss( ref_log_probs = batch['ift_log_probs'] old_entropies = batch['old_entropies'] old_log_probs = batch['old_log_probs'] # note this is the log prob of the pi_prox -- the usual pi_old in ppo language. 
- vllm_logprobs = batch['vllm_logprobs'] # note this the log prob from vllm when generating the rollouts, i.e., log pi_behavior - - print('===============================') - print(f'old_log_probs: {old_log_probs.shape=}, {old_log_probs=}') - print(f'vllm_logprobs: {vllm_logprobs.shape=}, {vllm_logprobs=}') - print('===============================') - + vllm_logprobs = batch['vllm_logprobs'] # note this the log prob from vllm when generating the rollouts, i.e., log pi_behavior assert old_log_probs.shape == vllm_logprobs.shape, f'old_log_probs and vllm_logprobs have different shapes {old_log_probs.shape=}, {vllm_logprobs.shape=}' - - importance_ratio = torch.exp(old_log_probs - vllm_logprobs) # pi_prox / pi_behavior - importance_ratio = torch.clamp(importance_ratio, min = 0.0, max = 10) + token_log_ratio = old_log_probs - vllm_logprobs # [ ln (pi_prox_t / pi_behavior_t) ] online_to_old_diff = online_log_probs - old_log_probs # ln(π/π_old) for SMD @@ -449,11 +441,13 @@ def policy_loss( batch['action_mask'], dim=-1, ) #size: (batch_size,) - masked_importance_ratio = utils.masked_sum( - importance_ratio, + masked_log_ratio = utils.masked_sum( + token_log_ratio, batch['action_mask'], dim=-1, - ) #size: (batch_size,) + ) #size: (batch_size,) # \sum_t ln (pi_prox_t / pi_behavior_t) + masked_log_ratio = torch.clamp(masked_log_ratio, min = -100.0, max = 10.0) # clip to avoid overflow + masked_importance_ratio = torch.exp(masked_log_ratio) # pi_prox / pi_behavior # Convert beta to a simple float assert masked_importance_ratio.shape == masked_log_probs_diff.shape, f'masked_importance_ratio and masked_log_probs_diff have different shapes {importance_ratio.shape=}, {masked_log_probs_diff.shape=}' @@ -466,10 +460,6 @@ def policy_loss( dim=-1, ) - print('===============================') - print(f'masked_importance_ratio: {masked_importance_ratio.shape=}, {masked_importance_ratio=}') - print('===============================') - policy_dict = { 'loss/policy_loss': policy_loss, 'kl/policy_kl': policy_kl, # Required by calling code in model.py From a693b1d96874e13d1fb208fd9dda0d4d7d3bd271 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 1 Sep 2025 09:29:00 -0400 Subject: [PATCH 58/74] delete ray debug, not useful --- compose_rl/utils/ray_utils.py | 2 +- scripts/launch_composer_ray.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/compose_rl/utils/ray_utils.py b/compose_rl/utils/ray_utils.py index d8410ce6..dae3229d 100644 --- a/compose_rl/utils/ray_utils.py +++ b/compose_rl/utils/ray_utils.py @@ -45,7 +45,7 @@ def init_ray_with_torch_distributed(timeout_seconds: int = 30): # Start Ray Server on master node subprocess.run(['ray', 'start', '--head'], check=True) # connect to the ray cluster - ray.init(logging_level="DEBUG") + ray.init() # get existing ray ip and port ctx = ray.get_runtime_context() address = ctx.gcs_address diff --git a/scripts/launch_composer_ray.py b/scripts/launch_composer_ray.py index 913a1cda..b093c220 100644 --- a/scripts/launch_composer_ray.py +++ b/scripts/launch_composer_ray.py @@ -161,7 +161,7 @@ def start_ray_nodes(): # Send the local node IP to other ranks broadcast_string(ip, src_rank=0) - ray.init(logging_level="DEBUG") + ray.init() # Wait for all ray clusters to start dist.barrier() From fec31ba92cdc820981dfac4ad4f5eabd1fc5608e Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 1 Sep 2025 09:30:22 -0400 Subject: [PATCH 59/74] remove some debug print --- test_single_controller_ppo.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git 
a/test_single_controller_ppo.py b/test_single_controller_ppo.py index 4dd20b12..d859a460 100644 --- a/test_single_controller_ppo.py +++ b/test_single_controller_ppo.py @@ -1472,7 +1472,6 @@ def get_next_iter_rollouts(self): vllm_logprobs = vllm_logprobs[0] max_vllm_generated_len = max([len(response) for response in sequences]) - max_vllm_logprobs_len = max([len(response) for response in vllm_logprobs]) padded_responses = [] for sequence in sequences: @@ -1490,20 +1489,12 @@ def get_next_iter_rollouts(self): processed_sequences = torch.cat([all_prompts, padded_responses], dim=-1) iter_data['sequences'] = processed_sequences - print('===============================') - print(f"processed_sequences.shape: {processed_sequences.shape=}") - print('===============================') - padded_logprobs = [] for logprobs in vllm_logprobs: logprobs = list(logprobs) if len(logprobs) < max_vllm_generated_len: logprobs = logprobs + [0] * (max_vllm_generated_len - len(logprobs)) padded_logprobs.append(logprobs) - - print('===============================') - print(f"len(padded_logprobs): {len(padded_logprobs)=}") - print('===============================') try: padded_logprobs = torch.tensor( @@ -1515,16 +1506,9 @@ def get_next_iter_rollouts(self): print(f"Error: {e}") raise e - print('===============================') - print(f"padded_logprobs.shape: {padded_logprobs.shape=}") - print('===============================') - temp_zeros = torch.zeros_like(all_prompts, dtype=torch.float, device=torch.device('cpu')) processed_logprobs = torch.cat([temp_zeros, padded_logprobs], dim=-1) iter_data['vllm_logprobs'] = processed_logprobs - print('===============================') - print(f"processed_logprobs.shape: {processed_logprobs.shape=}") - print('===============================') assert processed_logprobs.shape == processed_sequences.shape, f'vllm_logprobs and sequences have different shapes {processed_logprobs.shape=}, {processed_sequences.shape=}' From d810e0461b047254cfd6296ea4cf51d7c80c05b2 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 1 Sep 2025 09:43:23 -0400 Subject: [PATCH 60/74] . 
--- compose_rl/algorithms/online/model_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index 23a6d361..8836f8ba 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -450,7 +450,7 @@ def policy_loss( masked_importance_ratio = torch.exp(masked_log_ratio) # pi_prox / pi_behavior # Convert beta to a simple float - assert masked_importance_ratio.shape == masked_log_probs_diff.shape, f'masked_importance_ratio and masked_log_probs_diff have different shapes {importance_ratio.shape=}, {masked_log_probs_diff.shape=}' + assert masked_importance_ratio.shape == masked_log_probs_diff.shape, f'masked_importance_ratio and masked_log_probs_diff have different shapes {masked_importance_ratio.shape=}, {masked_log_probs_diff.shape=}' beta_float = float(beta) policy_loss = (masked_importance_ratio*((beta_float * masked_log_probs_diff - prompt_advantages)**2)).mean() From 8736cb0cdb458607f797dd4bf78068f33b5f6947 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 1 Sep 2025 11:22:12 -0400 Subject: [PATCH 61/74] first draft of decoupled ppo --- compose_rl/algorithms/online/model_methods.py | 7 +++++++ yamls/single-controller-smd-workflow.yaml | 13 +++++++------ 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index 8836f8ba..62c10a38 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -241,6 +241,11 @@ def policy_loss( online_log_probs, old_log_probs = outputs['online_log_probs'], batch[ 'old_log_probs'] old_entropies = batch['old_entropies'] + + vllm_logprobs = batch['vllm_logprobs'] + token_log_ratio = torch.clamp(old_log_probs - vllm_logprobs, min = -100.0, max = 100.0) # [ ln (pi_prox_t / pi_behavior_t) ] + token_IS_ratio = torch.exp(token_log_ratio) # [ pi_prox_t / pi_behavior_t ], (bs, gen_len) + gen_logits = utils.get_batched_generated_values( batched_values=outputs['logits'], prompt_len=batch['prompt_len'], @@ -325,6 +330,8 @@ def policy_loss( batch['action_mask'], ) + policy_loss = policy_loss * token_IS_ratio # [ pi_prox_t / pi_behavior_t * policy_loss_t ]_t + if length_normalize_policy_loss: policy_loss = utils.sample_wise_masked_mean( policy_loss, diff --git a/yamls/single-controller-smd-workflow.yaml b/yamls/single-controller-smd-workflow.yaml index 1be3214d..c6d5d81b 100644 --- a/yamls/single-controller-smd-workflow.yaml +++ b/yamls/single-controller-smd-workflow.yaml @@ -51,7 +51,7 @@ parameters: loggers: mlflow: tags: - run: test_single_controller_smd_beta_0.1_max_async_step_2_ds_llama_7b_IS + run: test_single_controller_smd_beta_0.1_max_async_step_4_ds_llama_7b_IS group: grpo tracking_uri: databricks experiment_name: test_single_controller_ppo @@ -59,9 +59,10 @@ parameters: ppo: {} orl_eval: evals: - - name: gsm8k - #- name: math_500 - #- name: math_hard + #- name: gsm8k + - name: math_500 + - name: math_hard + - name: math_competition eval_overrides: generation_params: max_tokens: 8192 @@ -182,9 +183,9 @@ parameters: console_log_interval: 1ba device_eval_batch_size: 1 eval_subset_num_batches: -1 - global_train_batch_size: 16 # global_train_batch_size * num_batches_per_update / generations_per_prompt = number of unique prompts + global_train_batch_size: 64 # global_train_batch_size * num_batches_per_update / generations_per_prompt = number of unique 
prompts device_train_microbatch_size: 1 vllm_tensor_parallel_size: 1 vllm_enable_prefix_caching: false save_num_checkpoints_to_keep: 1 - max_async_step: 2 + max_async_step: 4 From 3b6d7f27e359b94ceff40a3730a9ecae25bedc9f Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 1 Sep 2025 20:31:18 -0400 Subject: [PATCH 62/74] . --- compose_rl/algorithms/online/model_methods.py | 2 ++ yamls/single-controller-grpo-workflow.yaml | 5 +++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index 62c10a38..1eedc8d3 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -394,6 +394,8 @@ def policy_loss( seq_entropies, 'advantages/mean': utils.sample_wise_masked_mean(advantages, batch['action_mask']), + 'importance_ratio/mean': + utils.sample_wise_masked_mean(token_IS_ratio, batch['action_mask']), } # Add entropy percentiles to policy_dict for i, p in enumerate(percentiles): diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index aa6c779d..e57105ec 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -50,7 +50,7 @@ parameters: loggers: mlflow: tags: - run: test_single_controller_grpo_ds_llama_7b_open_r1_48k_grpo_max_async_step_4 + run: test_single_controller_grpo_ds_llama_7b_open_r1_48k_grpo_max_async_step_4_ls group: grpo tracking_uri: databricks experiment_name: test_single_controller_ppo @@ -58,9 +58,10 @@ parameters: ppo: {} orl_eval: evals: - - name: gsm8k + #- name: gsm8k - name: math_500 - name: math_hard + - name: math_competition eval_overrides: generation_params: max_tokens: 8192 From e3212ef92cee0b4a35b7fd2902a1a985f21f7cd0 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 1 Sep 2025 20:37:42 -0400 Subject: [PATCH 63/74] . 
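Ablation: comment out the token-level importance weighting in the clipped-PG path and drop max_async_step to 0, logging into an IS-vs-no-IS MLflow experiment; the following patch re-enables the weighting. A rough sketch of where that correction sits in the loss; the clipped part follows the standard PPO form and is not shown in this diff, only the toggled multiplication is (names assumed):

    import torch

    def clipped_pg_with_is(online_lp, old_lp, vllm_lp, advantages, action_mask, clip=0.2):
        pg_ratio = torch.exp(online_lp - old_lp)                      # pi / pi_old
        clipped = torch.clamp(pg_ratio, 1 - clip, 1 + clip)
        per_token = torch.max(-advantages * pg_ratio, -advantages * clipped)
        is_ratio = torch.exp(old_lp - vllm_lp)                        # pi_old / pi_behavior
        per_token = per_token * is_ratio                              # the line this patch comments out
        return (per_token * action_mask).sum(dim=-1) / action_mask.sum(dim=-1).clamp(min=1)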
--- compose_rl/algorithms/online/model_methods.py | 2 +- yamls/single-controller-grpo-workflow.yaml | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index 1eedc8d3..d5edf566 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -330,7 +330,7 @@ def policy_loss( batch['action_mask'], ) - policy_loss = policy_loss * token_IS_ratio # [ pi_prox_t / pi_behavior_t * policy_loss_t ]_t + #policy_loss = policy_loss * token_IS_ratio # [ pi_prox_t / pi_behavior_t * policy_loss_t ]_t if length_normalize_policy_loss: policy_loss = utils.sample_wise_masked_mean( diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index e57105ec..5ba8b99a 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -50,10 +50,10 @@ parameters: loggers: mlflow: tags: - run: test_single_controller_grpo_ds_llama_7b_open_r1_48k_grpo_max_async_step_4_ls + run: test_single_controller_grpo_ds_llama_7b_open_r1_48k_grpo_max_async_step_0_nols group: grpo tracking_uri: databricks - experiment_name: test_single_controller_ppo + experiment_name: test_single_controller_grpo_IS_vs_no_IS callbacks: ppo: {} orl_eval: @@ -186,4 +186,4 @@ parameters: vllm_tensor_parallel_size: 1 vllm_enable_prefix_caching: false save_num_checkpoints_to_keep: 1 - max_async_step: 4 + max_async_step: 0 From ef2b7f4dbcb1d58d06a84f141ed1e8bfd169df64 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 1 Sep 2025 21:06:10 -0400 Subject: [PATCH 64/74] . --- compose_rl/algorithms/online/model_methods.py | 2 +- yamls/single-controller-grpo-workflow.yaml | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index d5edf566..1eedc8d3 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -330,7 +330,7 @@ def policy_loss( batch['action_mask'], ) - #policy_loss = policy_loss * token_IS_ratio # [ pi_prox_t / pi_behavior_t * policy_loss_t ]_t + policy_loss = policy_loss * token_IS_ratio # [ pi_prox_t / pi_behavior_t * policy_loss_t ]_t if length_normalize_policy_loss: policy_loss = utils.sample_wise_masked_mean( diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index 5ba8b99a..589621bc 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -14,7 +14,7 @@ integrations: path: /workspace/compose-rl git_repo: databricks/compose-rl ssh_clone: true - git_branch: single-controller-hackathon + git_branch: single-controller-hackathon-smd-vllm - integration_type: git_repo path: /workspace/research-universe git_repo: databricks-mosaic/research-universe @@ -50,7 +50,7 @@ parameters: loggers: mlflow: tags: - run: test_single_controller_grpo_ds_llama_7b_open_r1_48k_grpo_max_async_step_0_nols + run: test_single_controller_grpo_ds_llama_7b_open_r1_48k_grpo_max_async_step_2_ls group: grpo tracking_uri: databricks experiment_name: test_single_controller_grpo_IS_vs_no_IS @@ -186,4 +186,4 @@ parameters: vllm_tensor_parallel_size: 1 vllm_enable_prefix_caching: false save_num_checkpoints_to_keep: 1 - max_async_step: 0 + max_async_step: 2 From f047c331b3abdf49a95ba0051dce100679a02564 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 1 Sep 2025 21:32:04 -0400 
Subject: [PATCH 65/74] . --- compose_rl/algorithms/online/model_methods.py | 2 +- yamls/single-controller-grpo-workflow.yaml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index 1eedc8d3..d5edf566 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -330,7 +330,7 @@ def policy_loss( batch['action_mask'], ) - policy_loss = policy_loss * token_IS_ratio # [ pi_prox_t / pi_behavior_t * policy_loss_t ]_t + #policy_loss = policy_loss * token_IS_ratio # [ pi_prox_t / pi_behavior_t * policy_loss_t ]_t if length_normalize_policy_loss: policy_loss = utils.sample_wise_masked_mean( diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index 589621bc..eb4a7716 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -50,7 +50,7 @@ parameters: loggers: mlflow: tags: - run: test_single_controller_grpo_ds_llama_7b_open_r1_48k_grpo_max_async_step_2_ls + run: test_single_controller_grpo_ds_llama_7b_open_r1_48k_grpo_max_async_step_0_non_ls group: grpo tracking_uri: databricks experiment_name: test_single_controller_grpo_IS_vs_no_IS @@ -186,4 +186,4 @@ parameters: vllm_tensor_parallel_size: 1 vllm_enable_prefix_caching: false save_num_checkpoints_to_keep: 1 - max_async_step: 2 + max_async_step: 0 From 869923ec6cca3b4ffd5de09989e88661a9a8b462 Mon Sep 17 00:00:00 2001 From: wensun Date: Mon, 1 Sep 2025 21:32:27 -0400 Subject: [PATCH 66/74] . --- compose_rl/algorithms/online/model_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index d5edf566..1eedc8d3 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -330,7 +330,7 @@ def policy_loss( batch['action_mask'], ) - #policy_loss = policy_loss * token_IS_ratio # [ pi_prox_t / pi_behavior_t * policy_loss_t ]_t + policy_loss = policy_loss * token_IS_ratio # [ pi_prox_t / pi_behavior_t * policy_loss_t ]_t if length_normalize_policy_loss: policy_loss = utils.sample_wise_masked_mean( From d6a9922fe079f9e5f7400856e351bc15e8f64492 Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 2 Sep 2025 08:07:53 -0400 Subject: [PATCH 67/74] add importance weight option --- compose_rl/algorithms/online/model.py | 4 ++++ compose_rl/algorithms/online/model_methods.py | 21 +++++++++++++++++-- yamls/single-controller-grpo-workflow.yaml | 1 + yamls/single-controller-smd-workflow.yaml | 1 + 4 files changed, 25 insertions(+), 2 deletions(-) diff --git a/compose_rl/algorithms/online/model.py b/compose_rl/algorithms/online/model.py index 842fc36d..3c544930 100644 --- a/compose_rl/algorithms/online/model.py +++ b/compose_rl/algorithms/online/model.py @@ -280,6 +280,7 @@ def __init__( kl_estimator: str = 'k3', kl_clip_range: float = 40.0, temperature: float = 1.0, + importance_weighting: bool = True, **kwargs: Any, ): """Initialize the ComposerHFCriticFreePolicyModel. @@ -297,6 +298,7 @@ def __init__( kl_clip_range (float): The KL clip range. Default: ``40.0``. beta (float): pi_ref KL hyperparameter for APO. Default: ``1e-3`` temperature (float): Sampling temperature used for generations to properly scale logits. + importance_weighting (bool): Whether to apply importance weighting for off-policy rollouts. Default: ``True``. 
""" super().__init__(**kwargs) self.policy_kl = [] @@ -312,6 +314,7 @@ def __init__( self.kl_clip_range = kl_clip_range self.entropy_loss_weight = entropy_loss_weight self.temperature = temperature + self.importance_weighting = importance_weighting def forward(self, batch: MutableMapping): ret_val = composer_online_rl_forward( @@ -340,6 +343,7 @@ def loss(self, outputs: MutableMapping, batch: MutableMapping): kl_estimator=self.kl_estimator, kl_clip_range=self.kl_clip_range, entropy_loss_weight=self.entropy_loss_weight, + importance_weighting=self.importance_weighting, ) self.policy_kl.append(return_dict['kl/policy_kl']) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index 1eedc8d3..7cb34822 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -233,6 +233,7 @@ def policy_loss( length_normalize_policy_loss: bool = True, kl_estimator: Optional[str] = 'k3', kl_clip_range: Optional[float] = 40.0, + importance_weighting: bool = True, ) -> MutableMapping: if loss_type in ALGORITHM_TYPE.CLIPPED_PG: @@ -330,7 +331,11 @@ def policy_loss( batch['action_mask'], ) - policy_loss = policy_loss * token_IS_ratio # [ pi_prox_t / pi_behavior_t * policy_loss_t ]_t + if importance_weighting: + print("*"*100) + print('Using importance weighting') + print("*"*100) + policy_loss = policy_loss * token_IS_ratio # [ pi_prox_t / pi_behavior_t * policy_loss_t ]_t if length_normalize_policy_loss: policy_loss = utils.sample_wise_masked_mean( @@ -461,7 +466,17 @@ def policy_loss( # Convert beta to a simple float assert masked_importance_ratio.shape == masked_log_probs_diff.shape, f'masked_importance_ratio and masked_log_probs_diff have different shapes {masked_importance_ratio.shape=}, {masked_log_probs_diff.shape=}' beta_float = float(beta) - policy_loss = (masked_importance_ratio*((beta_float * masked_log_probs_diff - prompt_advantages)**2)).mean() + + if importance_weighting: + print("*"*100) + print('Using importance weighting') + print("*"*100) + policy_loss = (masked_importance_ratio*((beta_float * masked_log_probs_diff - prompt_advantages)**2)).mean() + else: + print("*"*100) + print('Not using importance weighting') + print("*"*100) + policy_loss = ((beta_float * masked_log_probs_diff - prompt_advantages)**2).mean() rewards = utils.masked_sum( batch['rewards'], @@ -504,6 +519,7 @@ def online_rl_loss( entropy_loss_weight: float | None = None, kl_estimator: Optional[str] = 'k3', kl_clip_range: Optional[float] = 40.0, + importance_weighting: bool = True, ) -> MutableMapping: """Compute the online RL loss. 
@@ -580,6 +596,7 @@ def online_rl_loss( length_normalize_policy_loss=length_normalize_policy_loss, kl_estimator=kl_estimator, kl_clip_range=kl_clip_range, + importance_weighting=importance_weighting, ) return_dict.update(**policy_dict) diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index eb4a7716..d577e3b2 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -46,6 +46,7 @@ parameters: normalize_advantage: true use_flash_attention_2: true length_normalize_policy_loss: true + importance_weighting: true # Enable importance sampling for off-policy rollouts pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B #deepseek-ai/DeepSeek-R1-Distill-Qwen-7B loggers: mlflow: diff --git a/yamls/single-controller-smd-workflow.yaml b/yamls/single-controller-smd-workflow.yaml index c6d5d81b..8641dcde 100644 --- a/yamls/single-controller-smd-workflow.yaml +++ b/yamls/single-controller-smd-workflow.yaml @@ -47,6 +47,7 @@ parameters: normalize_advantage: false use_flash_attention_2: true length_normalize_policy_loss: true + importance_weighting: true # Enable importance sampling for off-policy rollouts pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B #deepseek-ai/DeepSeek-R1-Distill-Qwen-7B loggers: mlflow: From 28c10bb967bc847d3a5e902865c7aa0f3b74b967 Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 2 Sep 2025 08:17:26 -0400 Subject: [PATCH 68/74] . --- compose_rl/algorithms/online/model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/compose_rl/algorithms/online/model.py b/compose_rl/algorithms/online/model.py index 3c544930..ab4cd98e 100644 --- a/compose_rl/algorithms/online/model.py +++ b/compose_rl/algorithms/online/model.py @@ -316,6 +316,10 @@ def __init__( self.temperature = temperature self.importance_weighting = importance_weighting + print("*"*100) + print('Importance weighting at initialization: ', self.importance_weighting) + print("*"*100) + def forward(self, batch: MutableMapping): ret_val = composer_online_rl_forward( batch, From c917afb90ba9e4d901cc62f9644ce3edfd0f500b Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 2 Sep 2025 08:33:49 -0400 Subject: [PATCH 69/74] clean up logging --- compose_rl/algorithms/online/model.py | 4 ---- compose_rl/algorithms/online/model_methods.py | 18 ++++-------------- 2 files changed, 4 insertions(+), 18 deletions(-) diff --git a/compose_rl/algorithms/online/model.py b/compose_rl/algorithms/online/model.py index ab4cd98e..3c544930 100644 --- a/compose_rl/algorithms/online/model.py +++ b/compose_rl/algorithms/online/model.py @@ -316,10 +316,6 @@ def __init__( self.temperature = temperature self.importance_weighting = importance_weighting - print("*"*100) - print('Importance weighting at initialization: ', self.importance_weighting) - print("*"*100) - def forward(self, batch: MutableMapping): ret_val = composer_online_rl_forward( batch, diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index 7cb34822..dd202aa5 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -332,9 +332,6 @@ def policy_loss( ) if importance_weighting: - print("*"*100) - print('Using importance weighting') - print("*"*100) policy_loss = policy_loss * token_IS_ratio # [ pi_prox_t / pi_behavior_t * policy_loss_t ]_t if length_normalize_policy_loss: @@ -423,7 +420,6 @@ def policy_loss( assert old_log_probs.shape == 
vllm_logprobs.shape, f'old_log_probs and vllm_logprobs have different shapes {old_log_probs.shape=}, {vllm_logprobs.shape=}' token_log_ratio = old_log_probs - vllm_logprobs # [ ln (pi_prox_t / pi_behavior_t) ] - online_to_old_diff = online_log_probs - old_log_probs # ln(π/π_old) for SMD #compute KL to pi_ref to keep track the divergence to \pi_ref @@ -460,23 +456,17 @@ def policy_loss( batch['action_mask'], dim=-1, ) #size: (batch_size,) # \sum_t ln (pi_prox_t / pi_behavior_t) - masked_log_ratio = torch.clamp(masked_log_ratio, min = -100.0, max = 10.0) # clip to avoid overflow + masked_log_ratio = torch.clamp(masked_log_ratio, min = -100.0, max = 100.0) # clip to avoid overflow masked_importance_ratio = torch.exp(masked_log_ratio) # pi_prox / pi_behavior # Convert beta to a simple float assert masked_importance_ratio.shape == masked_log_probs_diff.shape, f'masked_importance_ratio and masked_log_probs_diff have different shapes {masked_importance_ratio.shape=}, {masked_log_probs_diff.shape=}' beta_float = float(beta) + seq_level_policy_loss = (beta_float * masked_log_probs_diff - prompt_advantages)**2 if importance_weighting: - print("*"*100) - print('Using importance weighting') - print("*"*100) - policy_loss = (masked_importance_ratio*((beta_float * masked_log_probs_diff - prompt_advantages)**2)).mean() - else: - print("*"*100) - print('Not using importance weighting') - print("*"*100) - policy_loss = ((beta_float * masked_log_probs_diff - prompt_advantages)**2).mean() + seq_level_policy_loss = masked_importance_ratio * seq_level_policy_loss # IS at the sequence level + policy_loss = seq_level_policy_loss.mean() rewards = utils.masked_sum( batch['rewards'], From ed9e1f0e54ea2b18e8080579916072ecb19307ee Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 2 Sep 2025 08:37:56 -0400 Subject: [PATCH 70/74] more comments --- compose_rl/algorithms/online/model_methods.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index dd202aa5..b1d1a71c 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -396,7 +396,7 @@ def policy_loss( seq_entropies, 'advantages/mean': utils.sample_wise_masked_mean(advantages, batch['action_mask']), - 'importance_ratio/mean': + 'importance_ratio/mean': # always logging this in default regardless of importance weighting: want to check how far vllm logp is from log pi_old utils.sample_wise_masked_mean(token_IS_ratio, batch['action_mask']), } # Add entropy percentiles to policy_dict @@ -458,15 +458,13 @@ def policy_loss( ) #size: (batch_size,) # \sum_t ln (pi_prox_t / pi_behavior_t) masked_log_ratio = torch.clamp(masked_log_ratio, min = -100.0, max = 100.0) # clip to avoid overflow masked_importance_ratio = torch.exp(masked_log_ratio) # pi_prox / pi_behavior - - # Convert beta to a simple float assert masked_importance_ratio.shape == masked_log_probs_diff.shape, f'masked_importance_ratio and masked_log_probs_diff have different shapes {masked_importance_ratio.shape=}, {masked_log_probs_diff.shape=}' - beta_float = float(beta) - seq_level_policy_loss = (beta_float * masked_log_probs_diff - prompt_advantages)**2 + beta_float = float(beta) # convert it to float to avoid type error + seq_level_policy_loss = (beta_float * masked_log_probs_diff - prompt_advantages)**2 # (bs,) if importance_weighting: seq_level_policy_loss = masked_importance_ratio * seq_level_policy_loss # IS at the sequence 
level - policy_loss = seq_level_policy_loss.mean() + policy_loss = seq_level_policy_loss.mean() # (1,) rewards = utils.masked_sum( batch['rewards'], @@ -482,13 +480,13 @@ def policy_loss( 'gen/entropy': old_entropies, 'rewards/mean': torch.mean( rewards, - ), #compute the average reward of the current batch + ), # compute the average reward of the current batch 'advantages/mean': torch.mean( prompt_advantages, # SMD uses prompt_advantages, not advantages - ), #compute the average of the prompt advantages for SMD + ), # compute the average of the prompt advantages for SMD 'importance_ratio/mean': torch.mean( masked_importance_ratio, - ), + ), # always logging this in default regardless of importance weighting: want to check how far vllm logp is from log pi_old } return policy_dict else: From 0fd0458111fd5fc89277e8d1400f7196f387622a Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 2 Sep 2025 08:40:58 -0400 Subject: [PATCH 71/74] revert the yamls back but added importance weight --- yamls/single-controller-grpo-workflow.yaml | 40 ++++++++++++++-------- yamls/single-controller-smd-workflow.yaml | 31 ++++++++--------- 2 files changed, 40 insertions(+), 31 deletions(-) diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index d577e3b2..29dc3bab 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ -1,4 +1,4 @@ -name: single-controller-hackathon_grpo +name: single-controller-hackathon image: mosaicml/dle:nightly-latest scheduling: @@ -14,7 +14,7 @@ integrations: path: /workspace/compose-rl git_repo: databricks/compose-rl ssh_clone: true - git_branch: single-controller-hackathon-smd-vllm + git_branch: single-controller-hackathon - integration_type: git_repo path: /workspace/research-universe git_repo: databricks-mosaic/research-universe @@ -46,23 +46,22 @@ parameters: normalize_advantage: true use_flash_attention_2: true length_normalize_policy_loss: true - importance_weighting: true # Enable importance sampling for off-policy rollouts - pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B #deepseek-ai/DeepSeek-R1-Distill-Qwen-7B + importance_weighting: true + pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B loggers: mlflow: tags: - run: test_single_controller_grpo_ds_llama_7b_open_r1_48k_grpo_max_async_step_0_non_ls + run: test_single_controller_ppo_deepseek_l8b_open_r1_48k group: grpo tracking_uri: databricks - experiment_name: test_single_controller_grpo_IS_vs_no_IS + experiment_name: test_single_controller_ppo callbacks: ppo: {} orl_eval: evals: - #- name: gsm8k + - name: gsm8k - name: math_500 - name: math_hard - - name: math_competition eval_overrides: generation_params: max_tokens: 8192 @@ -91,7 +90,7 @@ parameters: alpha: 1 t_warmup: 10iter tokenizer: - name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B #deepseek-ai/DeepSeek-R1-Distill-Qwen-7B + name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B kwargs: padding: longest pad_token: <|finetune_right_pad_id|> @@ -105,8 +104,19 @@ parameters: name: MinibatchRolloutBuffer rewards: math_verifier: - reward: 1 + reward: 4 reward_type: math_verifier + bad_generation_end: + reward: -1 + eos_penalty: true + reward_type: bad_generation_end + math_format_verifier: + reward: 1 + reward_type: math_format_verifier + penalize_extra_short_responses: + reward: -1 + reward_type: short_response_reward + len_threshold: 10 lambda_gae: 1 global_seed: 17 max_gen_len: 8192 @@ -117,7 +127,7 @@ parameters: kl_controller: 
kl_ctl_type: fixed init_kl_coef: 0 - tokenizer_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B #deepseek-ai/DeepSeek-R1-Distill-Qwen-7B + tokenizer_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B num_train_nodes: 1 reference_model: precision: amp_bf16 @@ -127,7 +137,7 @@ parameters: pretrained: true use_auth_token: true use_flash_attention_2: true - pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B #deepseek-ai/DeepSeek-R1-Distill-Qwen-7B + pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B generation_kwargs: top_p: 1 do_sample: true @@ -159,7 +169,7 @@ parameters: max_seq_len: 10240 save_folder: /tmp/checkpoints dist_timeout: 1800 - max_duration: 100iter + max_duration: 10iter progress_bar: false train_loader: name: prompt @@ -174,7 +184,7 @@ parameters: download_timeout: 1800 drop_last: true num_workers: 1 - eval_interval: 5iter + eval_interval: 2iter save_interval: 100iter log_to_console: true save_overwrite: true @@ -182,7 +192,7 @@ parameters: console_log_interval: 1ba device_eval_batch_size: 1 eval_subset_num_batches: -1 - global_train_batch_size: 64 # global_train_batch_size * num_batches_per_update / generations_per_prompt = number of unique prompts + global_train_batch_size: 64 device_train_microbatch_size: 1 vllm_tensor_parallel_size: 1 vllm_enable_prefix_caching: false diff --git a/yamls/single-controller-smd-workflow.yaml b/yamls/single-controller-smd-workflow.yaml index 8641dcde..07c9629e 100644 --- a/yamls/single-controller-smd-workflow.yaml +++ b/yamls/single-controller-smd-workflow.yaml @@ -14,7 +14,7 @@ integrations: path: /workspace/compose-rl git_repo: databricks/compose-rl ssh_clone: true - git_branch: single-controller-hackathon-smd-vllm #single-controller-hackathon + git_branch: single-controller-hackathon-smd #single-controller-hackathon - integration_type: git_repo path: /workspace/research-universe git_repo: databricks-mosaic/research-universe @@ -33,11 +33,11 @@ parameters: model: name: hf_critic_free_lm loss_type: smd #grpo - target_kl: 0.1 # is it used in SDM? + target_kl: 100000 # is it used in SDM? 
pretrained: true init_device: mixed kl_estimator: k3 - beta: 0.1 #0.01 + beta: 1e-3 #0.01 kl_clip_range: 40 use_auth_token: true compute_kl_loss: false @@ -47,12 +47,12 @@ parameters: normalize_advantage: false use_flash_attention_2: true length_normalize_policy_loss: true - importance_weighting: true # Enable importance sampling for off-policy rollouts - pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B #deepseek-ai/DeepSeek-R1-Distill-Qwen-7B + importance_weighting: true + pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B loggers: mlflow: tags: - run: test_single_controller_smd_beta_0.1_max_async_step_4_ds_llama_7b_IS + run: test_single_controller_ppo_deepseek_l8b_open_r1_48k_smd_beta_0.01_target_kl_100000 group: grpo tracking_uri: databricks experiment_name: test_single_controller_ppo @@ -63,7 +63,6 @@ parameters: #- name: gsm8k - name: math_500 - name: math_hard - - name: math_competition eval_overrides: generation_params: max_tokens: 8192 @@ -92,7 +91,7 @@ parameters: alpha: 1 t_warmup: 10iter tokenizer: - name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B #deepseek-ai/DeepSeek-R1-Distill-Qwen-7B + name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B kwargs: padding: longest pad_token: <|finetune_right_pad_id|> @@ -118,7 +117,7 @@ parameters: kl_controller: kl_ctl_type: fixed init_kl_coef: 0 - tokenizer_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B #deepseek-ai/DeepSeek-R1-Distill-Qwen-7B + tokenizer_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B num_train_nodes: 1 reference_model: precision: amp_bf16 @@ -128,7 +127,7 @@ parameters: pretrained: true use_auth_token: true use_flash_attention_2: true - pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B #deepseek-ai/DeepSeek-R1-Distill-Qwen-7B + pretrained_model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Llama-8B generation_kwargs: top_p: 1 do_sample: true @@ -141,7 +140,7 @@ parameters: algorithms: gradient_clipping: clipping_type: norm - clipping_threshold: 0.001 + clipping_threshold: 1.0 #0.001 autoresume: true log_config: true fsdp_config: @@ -160,7 +159,7 @@ parameters: max_seq_len: 10240 save_folder: /tmp/checkpoints dist_timeout: 1800 - max_duration: 100iter + max_duration: 200iter progress_bar: false train_loader: name: prompt @@ -175,7 +174,7 @@ parameters: download_timeout: 1800 drop_last: true num_workers: 1 - eval_interval: 5iter + eval_interval: 10iter eval_first: false save_interval: 100iter log_to_console: true @@ -184,9 +183,9 @@ parameters: console_log_interval: 1ba device_eval_batch_size: 1 eval_subset_num_batches: -1 - global_train_batch_size: 64 # global_train_batch_size * num_batches_per_update / generations_per_prompt = number of unique prompts - device_train_microbatch_size: 1 + global_train_batch_size: 64 + device_train_microbatch_size: 1 # Increased for better SMD batch statistics vllm_tensor_parallel_size: 1 vllm_enable_prefix_caching: false save_num_checkpoints_to_keep: 1 - max_async_step: 4 + max_async_step: 0 From 1641314de6cdb4d81c99a5b4e6e1a9a09660b965 Mon Sep 17 00:00:00 2001 From: wensun Date: Tue, 2 Sep 2025 08:41:56 -0400 Subject: [PATCH 72/74] include all math evals --- yamls/single-controller-grpo-workflow.yaml | 1 + yamls/single-controller-smd-workflow.yaml | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml index 29dc3bab..82b74b69 100644 --- a/yamls/single-controller-grpo-workflow.yaml +++ b/yamls/single-controller-grpo-workflow.yaml @@ 

From 1641314de6cdb4d81c99a5b4e6e1a9a09660b965 Mon Sep 17 00:00:00 2001
From: wensun
Date: Tue, 2 Sep 2025 08:41:56 -0400
Subject: [PATCH 72/74] include all math evals

---
 yamls/single-controller-grpo-workflow.yaml | 1 +
 yamls/single-controller-smd-workflow.yaml  | 3 ++-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml
index 29dc3bab..82b74b69 100644
--- a/yamls/single-controller-grpo-workflow.yaml
+++ b/yamls/single-controller-grpo-workflow.yaml
@@ -62,6 +62,7 @@ parameters:
       - name: gsm8k
       - name: math_500
       - name: math_hard
+      - name: math_competition
       eval_overrides:
         generation_params:
           max_tokens: 8192
diff --git a/yamls/single-controller-smd-workflow.yaml b/yamls/single-controller-smd-workflow.yaml
index 07c9629e..6701cb86 100644
--- a/yamls/single-controller-smd-workflow.yaml
+++ b/yamls/single-controller-smd-workflow.yaml
@@ -60,9 +60,10 @@ parameters:
     ppo: {}
     orl_eval:
       evals:
-      #- name: gsm8k
+      - name: gsm8k
       - name: math_500
       - name: math_hard
+      - name: math_competition
       eval_overrides:
         generation_params:
           max_tokens: 8192

From 70b40a3d76d5e8746f93d05523790a90e55e095b Mon Sep 17 00:00:00 2001
From: wensun
Date: Tue, 2 Sep 2025 13:42:45 -0400
Subject: [PATCH 73/74] .

---
 compose_rl/algorithms/online/model_methods.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py
index b1d1a71c..c7603e90 100644
--- a/compose_rl/algorithms/online/model_methods.py
+++ b/compose_rl/algorithms/online/model_methods.py
@@ -423,7 +423,7 @@ def policy_loss(
     online_to_old_diff = online_log_probs - old_log_probs # ln(π/π_old) for SMD

     #compute KL to pi_ref to keep track the divergence to \pi_ref
-    policy_kl_dict = utils.approx_kl(
+    ref_policy_kl_dict = utils.approx_kl(
         log_p=ref_log_probs,
         log_q=online_log_probs, #log_q - log_p = log pi - log pi_ref
         kl_clip_range=kl_clip_range,
@@ -437,11 +437,11 @@ def policy_loss(

     with torch.no_grad():
         policy_kl = utils.masked_mean(
-            policy_kl_dict[kl_estimator],  # pyright: ignore
+            old_policy_kl_dict[kl_estimator],  # pyright: ignore
             batch['action_mask'],
         ) #plain average over all tokens (KL to pi_ref)
-        old_policy_kl = utils.masked_mean(
-            old_policy_kl_dict[kl_estimator],  # pyright: ignore
+        ref_policy_kl = utils.masked_mean(
+            ref_policy_kl_dict[kl_estimator],  # pyright: ignore
             batch['action_mask'],
         ) #plain average over all tokens (KL to pi_ref)
@@ -475,7 +475,7 @@ def policy_loss(
     policy_dict = {
         'loss/policy_loss': policy_loss,
         'kl/policy_kl': policy_kl, # Required by calling code in model.py
-        'kl/old_policy_kl': old_policy_kl,
+        'kl/online_ift_kl': ref_policy_kl,
         'gen/gen_length': batch['action_mask'].sum(dim=1).to(torch.float32),
         'gen/entropy': old_entropies,
         'rewards/mean': torch.mean(
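
For orientation: utils.approx_kl above is driven by the estimator named in the workflow YAMLs (kl_estimator: k3) plus a kl_clip_range. Below is a self-contained sketch of the standard k1/k3 estimators on per-token log-probabilities; it illustrates the idea only and is not the signature or clipping behavior of the repo's approx_kl.

import torch

def approx_kl_sketch(log_p: torch.Tensor, log_q: torch.Tensor, kl_clip_range: float = 40.0) -> dict:
    """Per-token estimates of KL(q || p) from samples drawn under q.

    log_ratio = log q - log p; 'k1' is the plain log-ratio estimator,
    'k3' is the lower-variance, non-negative estimator exp(-log_ratio) - 1 + log_ratio.
    """
    log_ratio = torch.clamp(log_q - log_p, min=-kl_clip_range, max=kl_clip_range)
    return {
        'k1': log_ratio,
        'k3': torch.exp(-log_ratio) - 1 + log_ratio,
    }

With log_p set to the reference-model log-probs and log_q to the online log-probs (as in the call above), this tracks the divergence of the current policy from pi_ref.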

From 693d1f0e7dbca6856935debd90e66a4046ffbf6 Mon Sep 17 00:00:00 2001
From: wensun
Date: Tue, 2 Sep 2025 22:59:35 -0400
Subject: [PATCH 74/74] addressed comments from bowen

---
 compose_rl/algorithms/online/model_methods.py | 13 ++++++-------
 compose_rl/utils/ray_utils.py                 |  2 +-
 test_single_controller_ppo.py                 | 17 ++++++++---------
 3 files changed, 15 insertions(+), 17 deletions(-)

diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py
index c7603e90..40aae4fc 100644
--- a/compose_rl/algorithms/online/model_methods.py
+++ b/compose_rl/algorithms/online/model_methods.py
@@ -244,8 +244,7 @@ def policy_loss(
     old_entropies = batch['old_entropies']
     vllm_logprobs = batch['vllm_logprobs']

-    token_log_ratio = torch.clamp(old_log_probs - vllm_logprobs, min = -100.0, max = 100.0) # [ ln (pi_prox_t / pi_behavior_t) ]
-    token_IS_ratio = torch.exp(token_log_ratio) # [ pi_prox_t / pi_behavior_t ], (bs, gen_len)
+    token_IS_ratio = torch.exp(old_log_probs)/torch.exp(vllm_logprobs)

     gen_logits = utils.get_batched_generated_values(
         batched_values=outputs['logits'],
@@ -332,7 +331,7 @@ def policy_loss(
     )

     if importance_weighting:
-        policy_loss = policy_loss * token_IS_ratio # [ pi_prox_t / pi_behavior_t * policy_loss_t ]_t
+        policy_loss = policy_loss * token_IS_ratio # [ pi_old_t / pi_behavior_t * policy_loss_t ]_t

     if length_normalize_policy_loss:
         policy_loss = utils.sample_wise_masked_mean(
@@ -415,11 +414,11 @@ def policy_loss(
     online_log_probs = outputs['online_log_probs']
     ref_log_probs = batch['ift_log_probs']
     old_entropies = batch['old_entropies']
-    old_log_probs = batch['old_log_probs'] # note this is the log prob of the pi_prox -- the usual pi_old in ppo language.
+    old_log_probs = batch['old_log_probs'] # note this is the log prob of the pi_old -- the usual pi_old in ppo language.
     vllm_logprobs = batch['vllm_logprobs'] # note this is the log prob from vllm when generating the rollouts, i.e., log pi_behavior
     assert old_log_probs.shape == vllm_logprobs.shape, f'old_log_probs and vllm_logprobs have different shapes {old_log_probs.shape=}, {vllm_logprobs.shape=}'

-    token_log_ratio = old_log_probs - vllm_logprobs # [ ln (pi_prox_t / pi_behavior_t) ]
+    token_log_ratio = old_log_probs - vllm_logprobs # [ ln (pi_old_t / pi_behavior_t) ]
     online_to_old_diff = online_log_probs - old_log_probs # ln(π/π_old) for SMD

     #compute KL to pi_ref to keep track the divergence to \pi_ref
@@ -455,9 +454,9 @@ def policy_loss(
         token_log_ratio,
         batch['action_mask'],
         dim=-1,
-    ) #size: (batch_size,)  # \sum_t ln (pi_prox_t / pi_behavior_t)
+    ) #size: (batch_size,)  # \sum_t ln (pi_old_t / pi_behavior_t)
     masked_log_ratio = torch.clamp(masked_log_ratio, min = -100.0, max = 100.0) # clip to avoid overflow
-    masked_importance_ratio = torch.exp(masked_log_ratio) # pi_prox / pi_behavior
+    masked_importance_ratio = torch.exp(masked_log_ratio) # pi_old / pi_behavior
     assert masked_importance_ratio.shape == masked_log_probs_diff.shape, f'masked_importance_ratio and masked_log_probs_diff have different shapes {masked_importance_ratio.shape=}, {masked_log_probs_diff.shape=}'

     beta_float = float(beta) # convert it to float to avoid type error
diff --git a/compose_rl/utils/ray_utils.py b/compose_rl/utils/ray_utils.py
index dae3229d..adf51d29 100644
--- a/compose_rl/utils/ray_utils.py
+++ b/compose_rl/utils/ray_utils.py
@@ -45,7 +45,7 @@ def init_ray_with_torch_distributed(timeout_seconds: int = 30):
     # Start Ray Server on master node
     subprocess.run(['ray', 'start', '--head'], check=True)
     # connect to the ray cluster
-    ray.init()
+    ray.init('auto')
     # get existing ray ip and port
     ctx = ray.get_runtime_context()
     address = ctx.gcs_address
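
Before the test_single_controller_ppo.py hunks below, a standalone sketch of the correction the comments above describe: when importance_weighting is enabled, each token's loss is reweighted by pi_old / pi_behavior, where pi_old is the policy snapshot used for the update and pi_behavior is the vLLM policy that generated the rollout. The clamped log-space form here mirrors the sequence-level path kept in this patch; it is an illustration, not the exact per-token code.

import torch

def token_is_weights(old_log_probs: torch.Tensor,
                     vllm_log_probs: torch.Tensor,
                     clip: float = 100.0) -> torch.Tensor:
    # ln(pi_old_t / pi_behavior_t), clamped before exponentiation so a few
    # extreme tokens cannot overflow the exp; shape (batch_size, gen_len).
    log_ratio = torch.clamp(old_log_probs - vllm_log_probs, min=-clip, max=clip)
    return torch.exp(log_ratio)

# usage: weighted = token_is_weights(old_lp, vllm_lp) * per_token_loss, then mask/average as usual
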
diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py
index d859a460..2246678c 100644
--- a/test_single_controller_ppo.py
+++ b/test_single_controller_ppo.py
@@ -1473,6 +1473,7 @@ def get_next_iter_rollouts(self):
         max_vllm_generated_len = max([len(response) for response in sequences])

+        # TODO: clean this up since this padded_response and padded_log_probs share similarity with the generation_utils.py
         padded_responses = []
         for sequence in sequences:
             sequence = list(sequence)
@@ -1496,15 +1497,12 @@ def get_next_iter_rollouts(self):
             logprobs = logprobs + [0] * (max_vllm_generated_len - len(logprobs))
             padded_logprobs.append(logprobs)

-        try:
-            padded_logprobs = torch.tensor(
-                padded_logprobs,
-                dtype=torch.float,
-                device=torch.device('cpu'),
-            )
-        except Exception as e:
-            print(f"Error: {e}")
-            raise e
+        padded_logprobs = torch.tensor(
+            padded_logprobs,
+            dtype=torch.float,
+            device=torch.device('cpu'),
+        )
+
         temp_zeros = torch.zeros_like(all_prompts, dtype=torch.float, device=torch.device('cpu'))
         processed_logprobs = torch.cat([temp_zeros, padded_logprobs], dim=-1)
@@ -1540,6 +1538,7 @@ def get_next_iter_rollouts(self):
             # If all the processes early exit generate, then we need to manually pad everything
             # we can pad this with pad tokens, since we switch the padding between left and right
             # padding based on the sequence length + max_sequence_length.
+            #TODO: check if this padding is needed? i assume so.
             if prompt_tokens.size(1) + max_gen_len > sequences.size(1):
                 len_to_pad = max_gen_len - (
                     sequences.size(1) - prompt_tokens.size(1)