compose_rl/algorithms/offline/callback.py (4 changes: 1 addition & 3 deletions)

@@ -58,9 +58,7 @@ def after_load(self, state: State, logger: Logger) -> None:
strict_model_weights=callback.strict_model_weights,
ignore_keys=callback.ignore_keys,
event=callback.event,
- )
- for callback in state.callbacks
- if isinstance(callback, LoadCheckpoint)
+ ) for callback in state.callbacks if isinstance(callback, LoadCheckpoint)
]

# For HF checkpoint, load_path is unset and should be handled in llmfoundry code.

compose_rl/algorithms/offline/model.py (6 changes: 2 additions & 4 deletions)

@@ -67,8 +67,7 @@ def eval_forward(
) -> None:
raise ValueError('Eval forward is not implemented for ComposerDPOLM.')

- def loss(self, outputs: CausalLMOutputWithPast,
- batch: Mapping) -> dict[str, torch.Tensor]:
+ def loss(self, outputs: CausalLMOutputWithPast, batch: Mapping) -> dict[str, torch.Tensor]:
return pairwise_offline_loss(
outputs,
batch,
@@ -119,8 +118,7 @@ def eval_forward(
) -> None:
raise ValueError('Eval forward is not implemented for ComposerHFDPOLM.')

- def loss(self, outputs: CausalLMOutputWithPast,
- batch: Mapping) -> dict[str, torch.Tensor]:
+ def loss(self, outputs: CausalLMOutputWithPast, batch: Mapping) -> dict[str, torch.Tensor]:
return pairwise_offline_loss(
outputs,
batch,

compose_rl/algorithms/offline/model_methods.py (16 changes: 5 additions & 11 deletions)

@@ -210,10 +210,7 @@ def pairwise_offline_loss(

losses = torch.zeros_like(logits)
if loss_type == PairwiseOfflineEnum.DPO:
- losses = (
- -F.logsigmoid(beta * logits) * (1 - label_smoothing) -
- F.logsigmoid(-beta * logits) * label_smoothing
- )
+ losses = (-F.logsigmoid(beta * logits) * (1 - label_smoothing) - F.logsigmoid(-beta * logits) * label_smoothing)
elif loss_type == PairwiseOfflineEnum.RCDPO:
# Adding reward-difference based label_smoothing = 1 - reward_bt_prob
chosen_reward = outputs['chosen_reward']
@@ -239,9 +236,8 @@ def pairwise_offline_loss(
logsigmoid_not_a = F.logsigmoid(-beta * logits)
logsigmoid_not_b = F.logsigmoid(-eta * reward_diff)

- losses = torch.exp(logsigmoid_b) * (
- logsigmoid_b - logsigmoid_a
- ) + torch.exp(logsigmoid_not_b) * (logsigmoid_not_b - logsigmoid_not_a)
+ losses = torch.exp(logsigmoid_b) * (logsigmoid_b - logsigmoid_a
+ ) + torch.exp(logsigmoid_not_b) * (logsigmoid_not_b - logsigmoid_not_a)
elif loss_type == PairwiseOfflineEnum.REBEL:
# Reproducing the REBEL loss from paper: https://arxiv.org/pdf/2404.16767 page 4
# Code: https://github.com/ZhaolinGao/REBEL/blob/e0a6a190108a45c70b4920b58a4ccac8a09ab22b/src/tldr/rebel.py#L761-L777
@@ -259,8 +255,7 @@ def pairwise_offline_loss(
losses = (logits - 1 / (2 * beta))**2
elif loss_type == PairwiseOfflineEnum.KTO:
chosen_KL = (policy_chosen_logp - ref_chosen_logp).mean().clamp(min=0)
- rejected_KL = (policy_rejected_logp -
- ref_rejected_logp).mean().clamp(min=0)
+ rejected_KL = (policy_rejected_logp - ref_rejected_logp).mean().clamp(min=0)

chosen_logratios = policy_chosen_logp - ref_chosen_logp
rejected_logratios = policy_rejected_logp - ref_rejected_logp
@@ -283,8 +278,7 @@ def pairwise_offline_loss(
losses = losses.mean()

chosen_rewards = beta * (policy_chosen_logp - ref_chosen_logp).detach()
- rejected_rewards = beta * (policy_rejected_logp -
- ref_rejected_logp).detach()
+ rejected_rewards = beta * (policy_rejected_logp - ref_rejected_logp).detach()

# Logging KL margins for comparing different methods
chosen_KL = (policy_chosen_logp - ref_chosen_logp).detach()
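
For reference, a minimal standalone sketch of the label-smoothed DPO term that the first hunk above collapses onto one line. It assumes the standard DPO setup in which `logits` is the chosen-minus-rejected difference of policy/reference log-ratios; `dpo_loss_sketch` and its default values are illustrative, not part of the repo.

# Minimal sketch, assuming the standard DPO formulation; names mirror the hunk above.
import torch
import torch.nn.functional as F

def dpo_loss_sketch(
    policy_chosen_logp: torch.Tensor,
    policy_rejected_logp: torch.Tensor,
    ref_chosen_logp: torch.Tensor,
    ref_rejected_logp: torch.Tensor,
    beta: float = 0.1,
    label_smoothing: float = 0.0,
) -> torch.Tensor:
    # Implicit reward margin: (policy vs. reference) log-ratio of chosen minus rejected.
    logits = (policy_chosen_logp - ref_chosen_logp) - (policy_rejected_logp - ref_rejected_logp)
    # Same expression as the single-line form in the diff above.
    losses = (-F.logsigmoid(beta * logits) * (1 - label_smoothing) -
              F.logsigmoid(-beta * logits) * label_smoothing)
    return losses.mean()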

compose_rl/algorithms/online/__init__.py (6 changes: 2 additions & 4 deletions)

@@ -13,14 +13,12 @@
ComposerHFPolicyLM,
ComposerMPTPolicyLM,
)
- from compose_rl.algorithms.online.model_methods import \
- CausalLMOutputWithPastAndValues
+ from compose_rl.algorithms.online.model_methods import CausalLMOutputWithPastAndValues
from compose_rl.algorithms.online.policy_configuration import (
HFPolicyConfig,
MPTPolicyConfig,
)
- from compose_rl.algorithms.online.single_controller_callback import \
- SingleControllerOnPolicyCallback
+ from compose_rl.algorithms.online.single_controller_callback import SingleControllerOnPolicyCallback
from compose_rl.registry import kl_controllers

kl_controllers.register('adaptive', func=AdaptiveKLController)

compose_rl/algorithms/online/callback.py (60 changes: 21 additions & 39 deletions)

@@ -154,9 +154,7 @@ def env_reward(
# we can pad this with pad tokens, since we switch the padding between left and right
# padding based on the sequence length + max_sequence_length.
if prompt_tokens.size(1) + max_gen_len > sequences.size(1):
- len_to_pad = max_gen_len - (
- sequences.size(1) - prompt_tokens.size(1)
- )
+ len_to_pad = max_gen_len - (sequences.size(1) - prompt_tokens.size(1))

extra_padding = torch.ones(
(batch_size, len_to_pad),
@@ -206,7 +204,8 @@ def env_reward(
untokenized_prompt_and_responses = []
for i in range(batch_size):
prompt = tokenizer.decode( # type: ignore
- right_padded_obs[i, :prompt_len[i]])
+ right_padded_obs[i, :prompt_len[i]],
+ )
generated_text = tokenizer.decode( # type: ignore
get_decoded_sequence(actions[i], generated_len[i],
max_gen_len))
@@ -283,8 +282,7 @@ def env_reward(
value_action_mask = torch.cat([
action_mask,
torch.zeros((batch_size, 1), device=cur_device),
- ],
- dim=-1)
+ ], dim=-1)
device_train_microbatch_values *= value_action_mask
partial_env_output['values'] = device_train_microbatch_values
# Future implementations may change the way reward_seq_len is defined
@@ -350,8 +348,7 @@ def __init__(
kl_estimator = train_config['model'].get('kl_estimator', 'k1')
if kl_estimator not in ['k1', 'k2', 'k3', 'k3_offpolicy']:
raise ValueError(
- f'Invalid kl estimator: {kl_estimator}. ' +
- 'Valid options are: k1, k2, k3, k3_offpolicy.',
+ f'Invalid kl estimator: {kl_estimator}. ' + 'Valid options are: k1, k2, k3, k3_offpolicy.',
)
self.kl_estimator = kl_estimator

@@ -366,8 +363,7 @@ def __init__(
kl_clip_range = train_config['model'].get('kl_clip_range', 40.0)
if kl_clip_range <= 0:
raise ValueError(
- f'Invalid kl clip range: {kl_clip_range}. ' +
- 'Must be greater than 0.',
+ f'Invalid kl clip range: {kl_clip_range}. ' + 'Must be greater than 0.',
)
# check for precision and clip range
precision = train_config['precision']
@@ -421,8 +417,7 @@ def __init__(

self.num_unique_prompts_per_iter: int = var_config.get(
'num_unique_prompts_per_iter',
- self.num_batches_per_update * self.global_train_batch_size //
- self.generations_per_prompt,
+ self.num_batches_per_update * self.global_train_batch_size // self.generations_per_prompt,
)

log.info(
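
A quick arithmetic check of the default expression above, using made-up values (none of these numbers come from the PR):

# Hypothetical settings; only meant to illustrate the default above.
num_batches_per_update = 8
global_train_batch_size = 64
generations_per_prompt = 8
num_unique_prompts_per_iter = num_batches_per_update * global_train_batch_size // generations_per_prompt
print(num_unique_prompts_per_iter)  # 64 unique prompts, each generated 8 times per iteration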

@@ -492,8 +487,7 @@ def __init__(
'vllm_generate_function needs to be either `generate` or `chat`',
)

- self.vllm_model_name = train_config['model'][
- 'pretrained_model_name_or_path']
+ self.vllm_model_name = train_config['model']['pretrained_model_name_or_path']

# set vllm tensor parallel size
total_num_nodes = os.getenv('TOTAL_NUM_NODES', None)
@@ -662,9 +656,7 @@ def _get_next_iter_prompts(self):
"""Gets the next iteration's batch of prompts."""
# Sample fewer batches for the Online RL interation depending on the number of generations per prompt
n_unique_batches = self.num_unique_prompts_per_iter // self.global_train_batch_size
- batches = [
- self._get_single_batch_prompts() for _ in range(n_unique_batches)
- ]
+ batches = [self._get_single_batch_prompts() for _ in range(n_unique_batches)]

ret_batch = {}
for key in batches[0].keys():
@@ -695,8 +687,7 @@ def _get_next_iter_prompts(self):
padding_key = self.pad_token_idx
if (batch[key][:, -1] == padding_key).any():
raise ValueError(
- 'The last token in the prompt should not be the pad token. Please double '
- +
+ 'The last token in the prompt should not be the pad token. Please double ' +
'check the dataloader and prompt and dataloader.',
)
elif key == 'prompt_attention_mask':
@@ -874,10 +865,7 @@ def _extract_minibatch(
"""
start_idx = idx * minibatch_size
end_idx = (idx + 1) * minibatch_size
- curr_gen_batch = {
- batch_key: tensor[start_idx:end_idx]
- for batch_key, tensor in batch.items()
- }
+ curr_gen_batch = {batch_key: tensor[start_idx:end_idx] for batch_key, tensor in batch.items()}
return curr_gen_batch

def _resolve_outputs(
@@ -989,8 +977,7 @@ def _resolve_outputs(
env_outs['advantages'] = advantages
else:
raise ValueError(
- f'Invalid loss type: {self.actor_critic.loss_type}. ' +
- 'Valid options are: ppo, grpo.',
+ f'Invalid loss type: {self.actor_critic.loss_type}. ' + 'Valid options are: ppo, grpo.',
)

batch_adv_mean, batch_adv_var = dist_compute_masked_mean_and_var(
@@ -1043,11 +1030,10 @@ def _log_generations_to_logger(self, state: State):
'verified_answer',
]
save_data = [[prompt_id, reward, prompt, generation, verified_answer]
- for (prompt_id, reward,
- verified_answer), (prompt, generation) in zip(
- prompt_ids_rewards_and_answers,
- prompts_and_gens,
- )]
+ for (prompt_id, reward, verified_answer), (prompt, generation) in zip(
+ prompt_ids_rewards_and_answers,
+ prompts_and_gens,
+ )]
# Sort the save_data by reward in descending order
save_data = sorted(save_data, key=lambda x: x[1], reverse=True)

@@ -1067,8 +1053,7 @@ def _log_generations_to_logger(self, state: State):

artifact.add(text_table, 'predictions')
wandb.log_artifact(artifact)
- wandb.log({'generations': text_table},
- step=state.timestamp.batch.value)
+ wandb.log({'generations': text_table}, step=state.timestamp.batch.value)

if self.mlflow_logger is not None:
self.mlflow_logger.log_table(
@@ -1088,8 +1073,7 @@ def _update_ift_kl(self):

self.kl_ctl.update(
ift_kl_update,
- self.num_batches_per_update * self.device_train_batch_size *
- dist.get_world_size(),
+ self.num_batches_per_update * self.device_train_batch_size * dist.get_world_size(),
)

self.kl_ift = []
@@ -1102,8 +1086,7 @@ def _create_vllm_engines(self):
self.model_update_group = None
self.vllm_engines = []

- if os.getenv('NODE_RANK',
- None) == '0' and os.getenv('LOCAL_RANK', None) == '0':
+ if os.getenv('NODE_RANK', None) == '0' and os.getenv('LOCAL_RANK', None) == '0':
log.info('Creating vLLM engines.')

os.environ['NCCL_CUMEM_ENABLE'] = '0'
@@ -1124,7 +1107,7 @@ def _create_vllm_engines(self):
)
log.info('After creating vLLM engines.')

- master_address = ray._private.services.get_node_ip_address( # type: ignore
+ master_address = ray._private.services.get_node_ip_address( # type: ignore
)
with socket.socket() as sock:
sock.bind(('', 0))
@@ -1173,8 +1156,7 @@ def state_dict(self):
return {
'KL_ctl_state_dict': self.kl_ctl.state_dict(),
'iter_num': self.iter_num,
- 'train_prompt_loader':
- self.train_prompt_loader.state_dict(), # pyright: ignore
+ 'train_prompt_loader': self.train_prompt_loader.state_dict(), # pyright: ignore
}

def load_state_dict(self, state_dict: dict[str, Any]):
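
The `_extract_minibatch` hunk above collapses a dict comprehension onto one line; a small standalone sketch of that slicing pattern, with toy tensors that are not from the repo:

import torch

def extract_minibatch_sketch(batch: dict, idx: int, minibatch_size: int) -> dict:
    # Slice every tensor in the batch dict along the leading (batch) dimension.
    start_idx = idx * minibatch_size
    end_idx = (idx + 1) * minibatch_size
    return {batch_key: tensor[start_idx:end_idx] for batch_key, tensor in batch.items()}

batch = {'obs': torch.arange(16).reshape(8, 2), 'rewards': torch.arange(8)}
print(extract_minibatch_sketch(batch, idx=1, minibatch_size=4)['rewards'])  # tensor([4, 5, 6, 7])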

@@ -102,12 +102,9 @@ def _vllm_generate(
}

# We have to remove all pad tokens here
- all_prompts = [[
- token
- for token in prompt.detach().cpu().tolist()
- if token != pad_token_id
+ all_prompts = [
+ [token for token in prompt.detach().cpu().tolist() if token != pad_token_id] for prompt in all_prompts
]
- for prompt in all_prompts]

# Generate with vllm
# Calculate the base batch size
@@ -177,12 +174,9 @@ def _vllm_chat(
}

# We have to remove all pad tokens here. Keep here to check
- all_prompts = [[
- token
- for token in prompt.detach().cpu().tolist()
- if token != pad_token_id
+ all_prompts = [
+ [token for token in prompt.detach().cpu().tolist() if token != pad_token_id] for prompt in all_prompts
]
- for prompt in all_prompts]

# Generate with vllm
# Calculate the base batch size
@@ -227,10 +221,7 @@ def _vllm_chat(
)

# Remove pad tokens from vllm prompts
- all_vllm_prompts = [[token
- for token in prompt
- if token != pad_token_id]
- for prompt in all_vllm_prompts]
+ all_vllm_prompts = [[token for token in prompt if token != pad_token_id] for prompt in all_vllm_prompts]

# Checking vllm prompts with all prompts
assert all_prompts == all_vllm_prompts
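
Both `_vllm_generate` and `_vllm_chat` settle on a one-line comprehension for stripping pad tokens before generation. A toy illustration of what that produces, with an assumed pad id and made-up prompts:

import torch

pad_token_id = 0  # illustrative value; the real id comes from the tokenizer
all_prompts = [torch.tensor([0, 0, 5, 6, 7]), torch.tensor([0, 9, 10, 11, 12])]

# Same pattern as the reformatted comprehension above: move to CPU lists and drop pads.
all_prompts = [
    [token for token in prompt.detach().cpu().tolist() if token != pad_token_id]
    for prompt in all_prompts
]
print(all_prompts)  # [[5, 6, 7], [9, 10, 11, 12]]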