Merged

Commits (74)
c5f4b5c  .  (wensun, Aug 26, 2025)
60f016a  .  (wensun, Aug 26, 2025)
50b3a70  .  (wensun, Aug 26, 2025)
bed1179  .  (wensun, Aug 26, 2025)
c69e228  .  (wensun, Aug 26, 2025)
665b80c  .  (wensun, Aug 26, 2025)
8ae2b01  .  (wensun, Aug 26, 2025)
752995b  .  (wensun, Aug 26, 2025)
83e8a8f  .  (wensun, Aug 26, 2025)
4a3a1ac  .  (wensun, Aug 26, 2025)
0c45113  .  (wensun, Aug 26, 2025)
689fd19  .  (wensun, Aug 26, 2025)
f27e5e8  .  (wensun, Aug 26, 2025)
3cae5d1  .  (wensun, Aug 26, 2025)
3ad15d3  .  (wensun, Aug 26, 2025)
3faab3c  .  (wensun, Aug 26, 2025)
70730bc  .  (wensun, Aug 26, 2025)
d309886  .  (wensun, Aug 26, 2025)
9090acd  .  (wensun, Aug 26, 2025)
a1d5dfe  .  (wensun, Aug 26, 2025)
808dcd1  .  (wensun, Aug 26, 2025)
746310a  .  (wensun, Aug 26, 2025)
6434316  .  (wensun, Aug 26, 2025)
b578daf  .  (wensun, Aug 26, 2025)
fa59b25  .  (wensun, Aug 26, 2025)
1f6af37  .  (wensun, Aug 26, 2025)
1a8adc8  .  (wensun, Aug 26, 2025)
8cb130e  .  (wensun, Aug 26, 2025)
b65384f  .  (wensun, Aug 26, 2025)
68b9bfb  .  (wensun, Aug 26, 2025)
ee71caf  start clean up  (wensun, Aug 27, 2025)
1e9d123  recreating wrong example beta string  (wensun, Aug 27, 2025)
2badf80  bug reproduced, revert back to the working version  (wensun, Aug 27, 2025)
918f6b7  creating second bebugging example: kl/policy_kl  (wensun, Aug 27, 2025)
b5f5da1  convert back to the correct version and check in the yaml for smd  (wensun, Aug 27, 2025)
0abe985  .  (wensun, Aug 27, 2025)
a93c8d5  addressed comments from bowen  (wensun, Aug 28, 2025)
f4d4cc5  .  (wensun, Aug 28, 2025)
4fa329a  add vllm logp and importance weight  (wensun, Sep 1, 2025)
e414dcd  .  (wensun, Sep 1, 2025)
603d748  .  (wensun, Sep 1, 2025)
35a365e  .  (wensun, Sep 1, 2025)
549428f  .  (wensun, Sep 1, 2025)
be53de6  .  (wensun, Sep 1, 2025)
e5f8164  .  (wensun, Sep 1, 2025)
2e159c0  .  (wensun, Sep 1, 2025)
49a451a  .  (wensun, Sep 1, 2025)
3799abb  .  (wensun, Sep 1, 2025)
83c4494  .  (wensun, Sep 1, 2025)
5572515  .  (wensun, Sep 1, 2025)
0536ad2  .  (wensun, Sep 1, 2025)
a42f45a  .  (wensun, Sep 1, 2025)
96bb36e  .  (wensun, Sep 1, 2025)
5690a5d  .  (wensun, Sep 1, 2025)
1d21901  .  (wensun, Sep 1, 2025)
472519c  .  (wensun, Sep 1, 2025)
e11b2a5  .  (wensun, Sep 1, 2025)
a693b1d  delete ray debug, not useful  (wensun, Sep 1, 2025)
fec31ba  remove some debug print  (wensun, Sep 1, 2025)
d810e04  .  (wensun, Sep 1, 2025)
8736cb0  first draft of decoupled ppo  (wensun, Sep 1, 2025)
3b6d7f2  .  (wensun, Sep 2, 2025)
e3212ef  .  (wensun, Sep 2, 2025)
ef2b7f4  .  (wensun, Sep 2, 2025)
f047c33  .  (wensun, Sep 2, 2025)
869923e  .  (wensun, Sep 2, 2025)
d6a9922  add importance weight option  (wensun, Sep 2, 2025)
28c10bb  .  (wensun, Sep 2, 2025)
c917afb  clean up logging  (wensun, Sep 2, 2025)
ed9e1f0  more comments  (wensun, Sep 2, 2025)
0fd0458  revert the yamls back but added importance weight  (wensun, Sep 2, 2025)
1641314  include all math evals  (wensun, Sep 2, 2025)
70b40a3  .  (wensun, Sep 2, 2025)
693d1f0  addressed comments from bowen  (wensun, Sep 3, 2025)
@@ -92,13 +92,14 @@ def _vllm_generate(
pad_token_id: int, # type: ignore
all_prompts: list,
batch_sizes: list,
) -> list:
) -> tuple[list, list]:
futs = []
sampling_params = {
'temperature': generation_kwargs.get('temperature', 1.0),
'top_p': generation_kwargs.get('top_p', 1.0),
'top_k': generation_kwargs.get('top_k', -1),
'max_tokens': max_gen_len,
'logprobs': 1, # to get the logprobs directly from vllm
}

# We have to remove all pad tokens here
@@ -138,25 +139,33 @@ def _vllm_generate(
start_time = time.time()
results = ray.get(futs)
all_responses = []
all_logprobs = []

# Get all of the ray futures
for i, result in enumerate(results):
# Each result is a list of responses; this assumes one output per input
all_responses.extend([resp.outputs[0].token_ids for resp in result])
all_logprobs.extend([[list(datum.values())[0].logprob for datum in resp.outputs[0].logprobs] for resp in result])

log.info(
f'took: {time.time() - start_time} to gather futures',
)

# Distribute padded responses back to the correct device
split_responses = []
split_logprobs = []
start = 0
for size in batch_sizes:
split_responses.append(
all_responses[start:start + size],
)
split_logprobs.append(
all_logprobs[start:start + size],
)
start += size
return split_responses


return split_responses, split_logprobs


def _vllm_chat(
@@ -254,7 +263,7 @@ def vllm_generate(
generation_kwargs: dict,
tokenizer: Tokenizer,
vllm_generate_function: str,
) -> torch.Tensor:
) -> tuple[torch.Tensor, torch.Tensor]:
"""Run vllm chat on the prompts using messages.

Runs generate over a set of sequences in the batch. It also does extra computation
Expand Down Expand Up @@ -320,7 +329,7 @@ def vllm_generate(
batch_sizes,
)
else:
split_responses = _vllm_generate(
split_responses, split_logprobs = _vllm_generate(
vllm_engines,
max_gen_len,
generation_kwargs,
@@ -335,6 +344,7 @@ def vllm_generate(
all_prompts = None
all_messages = None
split_responses = None
split_logprobs = None

# Do another garbage collection and empty the cache
gc.collect()
@@ -345,12 +355,19 @@ def vllm_generate(

# Scatter the generated responses back to the correct rank
local_responses = [None]
local_logprobs = [None]
start_time = time.time()
torch.distributed.scatter_object_list(
local_responses,
split_responses,
src=0,
)
torch.distributed.scatter_object_list(
local_logprobs,
split_logprobs,
src=0,
)
local_logprobs = local_logprobs[0]
local_responses = local_responses[0]

log.info(f'took: {time.time() - start_time} to scatter prompts')
@@ -379,7 +396,21 @@ def vllm_generate(
# Construct full sequences from the prompt and padded responses
sequences = torch.cat([prompt_tokens, padded_responses], dim=-1)
num_tokens_generated = sequences.size(1) - prompt_tokens.size(1)

padded_logprobs = []
for logprobs in local_logprobs: # type: ignore
logprobs = list(logprobs)
if len(logprobs) < max_vllm_generated_len:
logprobs = logprobs + [0] * (max_vllm_generated_len - len(logprobs))
padded_logprobs.append(logprobs)

vllm_logprobs = torch.tensor(
padded_logprobs,
dtype=torch.float,
device=cur_device,
)

log.info(
f'It took {time.time() - start_gen_time} to generate {num_tokens_generated} tokens',
)
return sequences
return sequences, vllm_logprobs
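
The logprob plumbing added above follows a simple pattern: vLLM is asked for logprobs=1, each generated token then carries a one-entry dict mapping token id to an object with a .logprob attribute, and shorter responses are right-padded with a log-probability of 0 before being stacked into a tensor. The sketch below is not part of the diff; it reproduces that pattern with stand-in dataclasses (FakeLogprob, FakeCompletion, and FakeRequestOutput are illustrative stand-ins, not vLLM classes) so the extraction and padding can be run in isolation.

from dataclasses import dataclass

import torch


@dataclass
class FakeLogprob:
    logprob: float


@dataclass
class FakeCompletion:
    token_ids: list
    logprobs: list  # one {token_id: FakeLogprob} dict per generated token


@dataclass
class FakeRequestOutput:
    outputs: list


# One engine's results: two requests, generating 2 and 1 tokens respectively.
result = [
    FakeRequestOutput(outputs=[
        FakeCompletion(
            token_ids=[5, 7],
            logprobs=[{5: FakeLogprob(-0.1)}, {7: FakeLogprob(-1.2)}],
        ),
    ]),
    FakeRequestOutput(outputs=[
        FakeCompletion(token_ids=[9], logprobs=[{9: FakeLogprob(-0.4)}]),
    ]),
]

# Same extraction as _vllm_generate: the first (only) value of each per-token
# dict is the sampled token's log-probability.
all_logprobs = [
    [list(datum.values())[0].logprob for datum in resp.outputs[0].logprobs]
    for resp in result
]

# Same padding as vllm_generate: right-pad with 0.0 to the longest response.
max_len = max(len(lp) for lp in all_logprobs)
padded = [lp + [0.0] * (max_len - len(lp)) for lp in all_logprobs]
vllm_logprobs = torch.tensor(padded, dtype=torch.float)
print(vllm_logprobs.shape)  # torch.Size([2, 2]); second row padded with 0.0
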
4 changes: 4 additions & 0 deletions compose_rl/algorithms/online/model.py
@@ -280,6 +280,7 @@ def __init__(
kl_estimator: str = 'k3',
kl_clip_range: float = 40.0,
temperature: float = 1.0,
importance_weighting: bool = True,
**kwargs: Any,
):
"""Initialize the ComposerHFCriticFreePolicyModel.
@@ -297,6 +298,7 @@ def __init__(
kl_clip_range (float): The KL clip range. Default: ``40.0``.
beta (float): pi_ref KL hyperparameter for APO. Default: ``1e-3``
temperature (float): Sampling temperature used for generations to properly scale logits.
importance_weighting (bool): Whether to apply importance weighting for off-policy rollouts. Default: ``True``.
"""
super().__init__(**kwargs)
self.policy_kl = []
@@ -312,6 +314,7 @@ def __init__(
self.kl_clip_range = kl_clip_range
self.entropy_loss_weight = entropy_loss_weight
self.temperature = temperature
self.importance_weighting = importance_weighting

def forward(self, batch: MutableMapping):
ret_val = composer_online_rl_forward(
@@ -340,6 +343,7 @@ def loss(self, outputs: MutableMapping, batch: MutableMapping):
kl_estimator=self.kl_estimator,
kl_clip_range=self.kl_clip_range,
entropy_loss_weight=self.entropy_loss_weight,
importance_weighting=self.importance_weighting,
)

self.policy_kl.append(return_dict['kl/policy_kl'])
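
For context on what the new importance_weighting flag controls: rollouts come from a vLLM copy of the policy (pi_behavior) that can be slightly stale relative to the policy whose log-probs the trainer recomputes (pi_prox), so each token's loss can be reweighted by the ratio pi_prox / pi_behavior. The toy below uses made-up numbers and skips masking; the actual computation lives in policy_loss in model_methods.py further down.

import torch

# Made-up example with shape (batch=1, gen_len=3).
old_log_probs = torch.tensor([[-1.0, -0.5, -2.0]])  # log pi_prox, recomputed by the trainer
vllm_logprobs = torch.tensor([[-1.1, -0.5, -1.7]])  # log pi_behavior, returned by vLLM
per_token_loss = torch.tensor([[0.2, -0.1, 0.4]])   # any per-token policy loss

# Token-level ratio pi_prox / pi_behavior, clamped in log space for stability.
token_log_ratio = torch.clamp(old_log_probs - vllm_logprobs, min=-100.0, max=100.0)
token_is_ratio = torch.exp(token_log_ratio)

# With importance_weighting=True the per-token loss is reweighted; the mean
# ratio is logged either way under importance_ratio/mean.
weighted_loss = per_token_loss * token_is_ratio
print(token_is_ratio)  # approximately [[1.105, 1.000, 0.741]]
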
109 changes: 80 additions & 29 deletions compose_rl/algorithms/online/model_methods.py
@@ -15,14 +15,16 @@ class OnPolicyEnum(Enum):
PPO = 'ppo'
GRPO = 'grpo'
APO = 'apo' #add A-star PO
SMD = 'smd' # SMD


class ALGORITHM_TYPE(set, Enum):
CRITIC_FREE = {OnPolicyEnum.GRPO, OnPolicyEnum.APO}
CRITIC_FREE = {OnPolicyEnum.GRPO, OnPolicyEnum.APO, OnPolicyEnum.SMD}
ACTOR_CRITIC = {OnPolicyEnum.PPO}
CLIPPED_PG = {OnPolicyEnum.PPO, OnPolicyEnum.GRPO}
REGRESSION = {
OnPolicyEnum.APO,
OnPolicyEnum.SMD,
}


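Because ALGORITHM_TYPE mixes set into Enum, each member behaves as an ordinary set, so checks such as loss_type in ALGORITHM_TYPE.REGRESSION are plain membership tests. The standalone snippet below mirrors the definitions above and shows how adding SMD to both CRITIC_FREE and REGRESSION routes it through the regression branch while keeping it critic-free.

from enum import Enum


class OnPolicyEnum(Enum):
    PPO = 'ppo'
    GRPO = 'grpo'
    APO = 'apo'
    SMD = 'smd'


class ALGORITHM_TYPE(set, Enum):
    CRITIC_FREE = {OnPolicyEnum.GRPO, OnPolicyEnum.APO, OnPolicyEnum.SMD}
    ACTOR_CRITIC = {OnPolicyEnum.PPO}
    CLIPPED_PG = {OnPolicyEnum.PPO, OnPolicyEnum.GRPO}
    REGRESSION = {OnPolicyEnum.APO, OnPolicyEnum.SMD}


# Each ALGORITHM_TYPE member is itself a set, so `in` is a set-membership test.
print(OnPolicyEnum.SMD in ALGORITHM_TYPE.REGRESSION)   # True
print(OnPolicyEnum.SMD in ALGORITHM_TYPE.CLIPPED_PG)   # False
print(OnPolicyEnum.SMD in ALGORITHM_TYPE.ACTOR_CRITIC) # False
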
@@ -231,13 +233,20 @@ def policy_loss(
length_normalize_policy_loss: bool = True,
kl_estimator: Optional[str] = 'k3',
kl_clip_range: Optional[float] = 40.0,
importance_weighting: bool = True,
) -> MutableMapping:

if loss_type in ALGORITHM_TYPE.CLIPPED_PG:
assert advantages is not None
assert advantages.dim() == 2 #(bs, max_gen_len)
online_log_probs, old_log_probs = outputs['online_log_probs'], batch[
'old_log_probs']
old_entropies = batch['old_entropies']

vllm_logprobs = batch['vllm_logprobs']
token_log_ratio = torch.clamp(old_log_probs - vllm_logprobs, min = -100.0, max = 100.0) # [ ln (pi_prox_t / pi_behavior_t) ]
token_IS_ratio = torch.exp(token_log_ratio) # [ pi_prox_t / pi_behavior_t ], (bs, gen_len)

gen_logits = utils.get_batched_generated_values(
batched_values=outputs['logits'],
prompt_len=batch['prompt_len'],
@@ -322,6 +331,9 @@ def policy_loss(
batch['action_mask'],
)

if importance_weighting:
policy_loss = policy_loss * token_IS_ratio # [ pi_prox_t / pi_behavior_t * policy_loss_t ]_t

if length_normalize_policy_loss:
policy_loss = utils.sample_wise_masked_mean(
policy_loss,
@@ -384,6 +396,8 @@ def policy_loss(
seq_entropies,
'advantages/mean':
utils.sample_wise_masked_mean(advantages, batch['action_mask']),
'importance_ratio/mean':  # always logged by default, regardless of importance weighting: tracks how far the vllm logp is from log pi_old
utils.sample_wise_masked_mean(token_IS_ratio, batch['action_mask']),
}
# Add entropy percentiles to policy_dict
for i, p in enumerate(percentiles):
@@ -392,55 +406,89 @@ def policy_loss(
return policy_dict

elif loss_type in ALGORITHM_TYPE.REGRESSION:
# assume batch contains (1) V-star values (key 'vstar'), (2) rewards (key 'rewards'), (3) ref_log_probs
# currently this branch only supports SMD
# TODO: add APO support
prompt_advantages = batch['prompt_advantages'].detach()
assert prompt_advantages is not None
assert prompt_advantages.dim() == 1 # (bs,)

online_log_probs = outputs['online_log_probs']
ref_log_probs = batch['ift_log_probs']
log_probs_diff = online_log_probs - ref_log_probs
old_entropies = batch['old_entropies']
old_log_probs = batch['old_log_probs'] # note this is the log prob of pi_prox, the usual pi_old in ppo language.
vllm_logprobs = batch['vllm_logprobs'] # note this is the log prob from vllm when generating the rollouts, i.e., log pi_behavior
assert old_log_probs.shape == vllm_logprobs.shape, f'old_log_probs and vllm_logprobs have different shapes {old_log_probs.shape=}, {vllm_logprobs.shape=}'

token_log_ratio = old_log_probs - vllm_logprobs # [ ln (pi_prox_t / pi_behavior_t) ]
online_to_old_diff = online_log_probs - old_log_probs # ln(π/π_old) for SMD

#compute KL to pi_ref to keep track the divergence to \pi_ref
policy_kl_dict = utils.approx_kl(
ref_policy_kl_dict = utils.approx_kl(
log_p=ref_log_probs,
log_q=online_log_probs, #log_q - log_p = log pi - log pi_ref
kl_clip_range=kl_clip_range,
)

old_policy_kl_dict = utils.approx_kl(
log_p=old_log_probs,
log_q=online_log_probs, #log_q - log_p = log pi - log pi_old
kl_clip_range=kl_clip_range,
)

with torch.no_grad():
policy_kl = utils.masked_mean(
policy_kl_dict[kl_estimator], # pyright: ignore
old_policy_kl_dict[kl_estimator], # pyright: ignore
batch['action_mask'],
) # plain average over all tokens (KL to pi_prox, i.e., pi_old)

#compute the policy loss
ref_policy_kl = utils.masked_mean(
ref_policy_kl_dict[kl_estimator], # pyright: ignore
batch['action_mask'],
) #plain average over all tokens (KL to pi_ref)

#compute the policy loss for SMD;
masked_log_probs_diff = utils.masked_sum(
log_probs_diff,
online_to_old_diff, # ln(π/π_old)
batch['action_mask'],
dim=-1,
) #size: (batch_size,)
vstars = batch['vstar']
masked_log_ratio = utils.masked_sum(
token_log_ratio,
batch['action_mask'],
dim=-1,
) #size: (batch_size,) # \sum_t ln (pi_prox_t / pi_behavior_t)
masked_log_ratio = torch.clamp(masked_log_ratio, min = -100.0, max = 100.0) # clip to avoid overflow
masked_importance_ratio = torch.exp(masked_log_ratio) # pi_prox / pi_behavior
assert masked_importance_ratio.shape == masked_log_probs_diff.shape, f'masked_importance_ratio and masked_log_probs_diff have different shapes {masked_importance_ratio.shape=}, {masked_log_probs_diff.shape=}'

beta_float = float(beta) # convert it to float to avoid type error
seq_level_policy_loss = (beta_float * masked_log_probs_diff - prompt_advantages)**2 # (bs,)
if importance_weighting:
seq_level_policy_loss = masked_importance_ratio * seq_level_policy_loss # IS at the sequence level
policy_loss = seq_level_policy_loss.mean() # (1,)

rewards = utils.masked_sum(
batch['rewards'],
batch['action_mask'],
dim=-1,
)
assert vstars.size() == rewards.size() == masked_log_probs_diff.size(
) # should have the same shape which is (batch_size, )

policy_loss = ((beta * masked_log_probs_diff -
(rewards - vstars))**2).mean()

policy_dict = {
'loss/policy_loss': policy_loss,
'kl/policy_kl': policy_kl,
'kl/policy_kl': policy_kl, # Required by calling code in model.py
'kl/online_ift_kl': ref_policy_kl,
'gen/gen_length': batch['action_mask'].sum(dim=1).to(torch.float32),
'gen/entropy': old_entropies,
'rewards/mean': torch.mean(
rewards,
), #compute the average reward of the current batch
'vstars/mean': torch.mean(
vstars,
), #compute the average of the vstar of the current batch
), # compute the average reward of the current batch
'advantages/mean': torch.mean(
prompt_advantages, # SMD uses prompt_advantages, not advantages
), # compute the average of the prompt advantages for SMD
'importance_ratio/mean': torch.mean(
masked_importance_ratio,
), # always logged by default, regardless of importance weighting: tracks how far the vllm logp is from log pi_old
}
return policy_dict

else:
raise ValueError(f'Policy loss not implemented for {loss_type}')

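A self-contained numerical sketch of the regression (SMD) branch above, using made-up tensors and an implicit all-ones action mask so that plain sums stand in for masked_sum: the loss regresses beta * sum_t ln(pi/pi_prox) onto the prompt-level advantage, and when importance_weighting is enabled the squared error is reweighted by the sequence-level ratio pi_prox / pi_behavior.

import torch

# Made-up example: batch of 2 sequences, 3 generated tokens each.
beta = 0.1
online_log_probs = torch.tensor([[-1.0, -0.8, -1.2], [-0.5, -0.9, -1.1]])  # log pi
old_log_probs = torch.tensor([[-1.1, -0.7, -1.2], [-0.6, -0.9, -1.0]])     # log pi_prox
vllm_logprobs = torch.tensor([[-1.0, -0.7, -1.3], [-0.6, -1.0, -1.0]])     # log pi_behavior
prompt_advantages = torch.tensor([0.3, -0.2])                               # one advantage per prompt

# Sequence-level sums; the real code uses masked_sum over action_mask.
online_to_old = (online_log_probs - old_log_probs).sum(dim=-1)  # sum_t ln(pi/pi_prox)
log_ratio = (old_log_probs - vllm_logprobs).sum(dim=-1)         # sum_t ln(pi_prox/pi_behavior)
importance_ratio = torch.exp(torch.clamp(log_ratio, min=-100.0, max=100.0))

# SMD regression objective, optionally importance-weighted at the sequence level.
seq_loss = (beta * online_to_old - prompt_advantages) ** 2
policy_loss = (importance_ratio * seq_loss).mean()
print(importance_ratio, policy_loss)
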
@@ -459,6 +507,7 @@ def online_rl_loss(
entropy_loss_weight: float | None = None,
kl_estimator: Optional[str] = 'k3',
kl_clip_range: Optional[float] = 40.0,
importance_weighting: bool = True,
) -> MutableMapping:
"""Compute the online RL loss.

@@ -489,7 +538,7 @@ def online_rl_loss(

return_dict = {}
advantages = None
if loss_type not in ALGORITHM_TYPE.REGRESSION:
if loss_type not in ALGORITHM_TYPE.REGRESSION: # basically grpo/ppo
advantages = batch['advantages']

# 1. Critic Loss
@@ -535,10 +584,12 @@ def online_rl_loss(
length_normalize_policy_loss=length_normalize_policy_loss,
kl_estimator=kl_estimator,
kl_clip_range=kl_clip_range,
importance_weighting=importance_weighting,
)

return_dict.update(**policy_dict)


for key, value in batch.items():
# This logic handles reward logging a little differently than other quantities.
# For rewards shaped as [batch, actions] we log (1) the per-sequence masked average
@@ -553,28 +604,28 @@ def online_rl_loss(
batch['action_mask'],
dim=1,
).mean(dim=0)

# Total reward over timesteps
return_dict['env/' + str(key) + '_total'] = utils.masked_sum(
value,
batch['action_mask'],
dim=1,
).mean(dim=0)
else:
# If this value is not [batch, actions] shaped, just do a
# vanilla mean.
return_dict['env/' + str(key)] = value.mean(dim=0)
if 'ift_kl' == key:
elif 'ift_kl' == key:
return_dict['kl/' + str(key)] = utils.masked_mean(
value,
batch['action_mask'],
)

# 3. Compute the total loss
# 3. Compute the total loss
return_dict['total'] = return_dict['loss/policy_loss']

if loss_type in ALGORITHM_TYPE.ACTOR_CRITIC:
# Add value loss to total loss
return_dict['total'] += value_loss_weight * return_dict[
'loss/value_loss'] # pyright: ignore


# If we want to directly minimize the KL Divergence, we can do so here
# and it will not include the KL in the reward.
if add_direct_kl_loss:
Expand All @@ -593,7 +644,7 @@ def online_rl_loss(
# breakpoint()
return_dict['loss/entropy'] = entropy_loss
return_dict['total'] += entropy_loss

if 'lbl' in outputs and outputs['lbl'] is not None:
return_dict['loss/lbl'] = outputs['lbl']
return_dict['total'] += outputs['lbl']
@@ -603,4 +654,4 @@ def online_rl_loss(
if key not in 'total':
return_dict[key] = value.detach().cpu()

return return_dict
return return_dict