diff --git a/compose_rl/algorithms/offline/callback.py b/compose_rl/algorithms/offline/callback.py index 266c9951..3733447a 100644 --- a/compose_rl/algorithms/offline/callback.py +++ b/compose_rl/algorithms/offline/callback.py @@ -58,9 +58,7 @@ def after_load(self, state: State, logger: Logger) -> None: strict_model_weights=callback.strict_model_weights, ignore_keys=callback.ignore_keys, event=callback.event, - ) - for callback in state.callbacks - if isinstance(callback, LoadCheckpoint) + ) for callback in state.callbacks if isinstance(callback, LoadCheckpoint) ] # For HF checkpoint, load_path is unset and should be handled in llmfoundry code. diff --git a/compose_rl/algorithms/offline/model.py b/compose_rl/algorithms/offline/model.py index 806842c1..15d66b3e 100644 --- a/compose_rl/algorithms/offline/model.py +++ b/compose_rl/algorithms/offline/model.py @@ -67,8 +67,7 @@ def eval_forward( ) -> None: raise ValueError('Eval forward is not implemented for ComposerDPOLM.') - def loss(self, outputs: CausalLMOutputWithPast, - batch: Mapping) -> dict[str, torch.Tensor]: + def loss(self, outputs: CausalLMOutputWithPast, batch: Mapping) -> dict[str, torch.Tensor]: return pairwise_offline_loss( outputs, batch, @@ -119,8 +118,7 @@ def eval_forward( ) -> None: raise ValueError('Eval forward is not implemented for ComposerHFDPOLM.') - def loss(self, outputs: CausalLMOutputWithPast, - batch: Mapping) -> dict[str, torch.Tensor]: + def loss(self, outputs: CausalLMOutputWithPast, batch: Mapping) -> dict[str, torch.Tensor]: return pairwise_offline_loss( outputs, batch, diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 4da46019..6e54f73b 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -210,10 +210,7 @@ def pairwise_offline_loss( losses = torch.zeros_like(logits) if loss_type == PairwiseOfflineEnum.DPO: - losses = ( - -F.logsigmoid(beta * logits) * (1 - label_smoothing) - - F.logsigmoid(-beta * logits) * label_smoothing - ) + losses = (-F.logsigmoid(beta * logits) * (1 - label_smoothing) - F.logsigmoid(-beta * logits) * label_smoothing) elif loss_type == PairwiseOfflineEnum.RCDPO: # Adding reward-difference based label_smoothing = 1 - reward_bt_prob chosen_reward = outputs['chosen_reward'] @@ -239,9 +236,8 @@ def pairwise_offline_loss( logsigmoid_not_a = F.logsigmoid(-beta * logits) logsigmoid_not_b = F.logsigmoid(-eta * reward_diff) - losses = torch.exp(logsigmoid_b) * ( - logsigmoid_b - logsigmoid_a - ) + torch.exp(logsigmoid_not_b) * (logsigmoid_not_b - logsigmoid_not_a) + losses = torch.exp(logsigmoid_b) * (logsigmoid_b - logsigmoid_a + ) + torch.exp(logsigmoid_not_b) * (logsigmoid_not_b - logsigmoid_not_a) elif loss_type == PairwiseOfflineEnum.REBEL: # Reproducing the REBEL loss from paper: https://arxiv.org/pdf/2404.16767 page 4 # Code: https://github.com/ZhaolinGao/REBEL/blob/e0a6a190108a45c70b4920b58a4ccac8a09ab22b/src/tldr/rebel.py#L761-L777 @@ -259,8 +255,7 @@ def pairwise_offline_loss( losses = (logits - 1 / (2 * beta))**2 elif loss_type == PairwiseOfflineEnum.KTO: chosen_KL = (policy_chosen_logp - ref_chosen_logp).mean().clamp(min=0) - rejected_KL = (policy_rejected_logp - - ref_rejected_logp).mean().clamp(min=0) + rejected_KL = (policy_rejected_logp - ref_rejected_logp).mean().clamp(min=0) chosen_logratios = policy_chosen_logp - ref_chosen_logp rejected_logratios = policy_rejected_logp - ref_rejected_logp @@ -283,8 +278,7 @@ def pairwise_offline_loss( losses = 
losses.mean() chosen_rewards = beta * (policy_chosen_logp - ref_chosen_logp).detach() - rejected_rewards = beta * (policy_rejected_logp - - ref_rejected_logp).detach() + rejected_rewards = beta * (policy_rejected_logp - ref_rejected_logp).detach() # Logging KL margins for comparing different methods chosen_KL = (policy_chosen_logp - ref_chosen_logp).detach() diff --git a/compose_rl/algorithms/online/__init__.py b/compose_rl/algorithms/online/__init__.py index 3857e92f..b826c01e 100644 --- a/compose_rl/algorithms/online/__init__.py +++ b/compose_rl/algorithms/online/__init__.py @@ -13,14 +13,12 @@ ComposerHFPolicyLM, ComposerMPTPolicyLM, ) -from compose_rl.algorithms.online.model_methods import \ - CausalLMOutputWithPastAndValues +from compose_rl.algorithms.online.model_methods import CausalLMOutputWithPastAndValues from compose_rl.algorithms.online.policy_configuration import ( HFPolicyConfig, MPTPolicyConfig, ) -from compose_rl.algorithms.online.single_controller_callback import \ - SingleControllerOnPolicyCallback +from compose_rl.algorithms.online.single_controller_callback import SingleControllerOnPolicyCallback from compose_rl.registry import kl_controllers kl_controllers.register('adaptive', func=AdaptiveKLController) diff --git a/compose_rl/algorithms/online/callback.py b/compose_rl/algorithms/online/callback.py index 41d96d59..c6fdf59b 100644 --- a/compose_rl/algorithms/online/callback.py +++ b/compose_rl/algorithms/online/callback.py @@ -154,9 +154,7 @@ def env_reward( # we can pad this with pad tokens, since we switch the padding between left and right # padding based on the sequence length + max_sequence_length. if prompt_tokens.size(1) + max_gen_len > sequences.size(1): - len_to_pad = max_gen_len - ( - sequences.size(1) - prompt_tokens.size(1) - ) + len_to_pad = max_gen_len - (sequences.size(1) - prompt_tokens.size(1)) extra_padding = torch.ones( (batch_size, len_to_pad), @@ -206,7 +204,8 @@ def env_reward( untokenized_prompt_and_responses = [] for i in range(batch_size): prompt = tokenizer.decode( # type: ignore - right_padded_obs[i, :prompt_len[i]]) + right_padded_obs[i, :prompt_len[i]], + ) generated_text = tokenizer.decode( # type: ignore get_decoded_sequence(actions[i], generated_len[i], max_gen_len)) @@ -283,8 +282,7 @@ def env_reward( value_action_mask = torch.cat([ action_mask, torch.zeros((batch_size, 1), device=cur_device), - ], - dim=-1) + ], dim=-1) device_train_microbatch_values *= value_action_mask partial_env_output['values'] = device_train_microbatch_values # Future implementations may change the way reward_seq_len is defined @@ -350,8 +348,7 @@ def __init__( kl_estimator = train_config['model'].get('kl_estimator', 'k1') if kl_estimator not in ['k1', 'k2', 'k3', 'k3_offpolicy']: raise ValueError( - f'Invalid kl estimator: {kl_estimator}. ' + - 'Valid options are: k1, k2, k3, k3_offpolicy.', + f'Invalid kl estimator: {kl_estimator}. ' + 'Valid options are: k1, k2, k3, k3_offpolicy.', ) self.kl_estimator = kl_estimator @@ -366,8 +363,7 @@ def __init__( kl_clip_range = train_config['model'].get('kl_clip_range', 40.0) if kl_clip_range <= 0: raise ValueError( - f'Invalid kl clip range: {kl_clip_range}. ' + - 'Must be greater than 0.', + f'Invalid kl clip range: {kl_clip_range}. 
' + 'Must be greater than 0.', ) # check for precision and clip range precision = train_config['precision'] @@ -421,8 +417,7 @@ def __init__( self.num_unique_prompts_per_iter: int = var_config.get( 'num_unique_prompts_per_iter', - self.num_batches_per_update * self.global_train_batch_size // - self.generations_per_prompt, + self.num_batches_per_update * self.global_train_batch_size // self.generations_per_prompt, ) log.info( @@ -492,8 +487,7 @@ def __init__( 'vllm_generate_function needs to be either `generate` or `chat`', ) - self.vllm_model_name = train_config['model'][ - 'pretrained_model_name_or_path'] + self.vllm_model_name = train_config['model']['pretrained_model_name_or_path'] # set vllm tensor parallel size total_num_nodes = os.getenv('TOTAL_NUM_NODES', None) @@ -662,9 +656,7 @@ def _get_next_iter_prompts(self): """Gets the next iteration's batch of prompts.""" # Sample fewer batches for the Online RL interation depending on the number of generations per prompt n_unique_batches = self.num_unique_prompts_per_iter // self.global_train_batch_size - batches = [ - self._get_single_batch_prompts() for _ in range(n_unique_batches) - ] + batches = [self._get_single_batch_prompts() for _ in range(n_unique_batches)] ret_batch = {} for key in batches[0].keys(): @@ -695,8 +687,7 @@ def _get_next_iter_prompts(self): padding_key = self.pad_token_idx if (batch[key][:, -1] == padding_key).any(): raise ValueError( - 'The last token in the prompt should not be the pad token. Please double ' - + + 'The last token in the prompt should not be the pad token. Please double ' + 'check the dataloader and prompt and dataloader.', ) elif key == 'prompt_attention_mask': @@ -874,10 +865,7 @@ def _extract_minibatch( """ start_idx = idx * minibatch_size end_idx = (idx + 1) * minibatch_size - curr_gen_batch = { - batch_key: tensor[start_idx:end_idx] - for batch_key, tensor in batch.items() - } + curr_gen_batch = {batch_key: tensor[start_idx:end_idx] for batch_key, tensor in batch.items()} return curr_gen_batch def _resolve_outputs( @@ -989,8 +977,7 @@ def _resolve_outputs( env_outs['advantages'] = advantages else: raise ValueError( - f'Invalid loss type: {self.actor_critic.loss_type}. ' + - 'Valid options are: ppo, grpo.', + f'Invalid loss type: {self.actor_critic.loss_type}. 
' + 'Valid options are: ppo, grpo.', ) batch_adv_mean, batch_adv_var = dist_compute_masked_mean_and_var( @@ -1043,11 +1030,10 @@ def _log_generations_to_logger(self, state: State): 'verified_answer', ] save_data = [[prompt_id, reward, prompt, generation, verified_answer] - for (prompt_id, reward, - verified_answer), (prompt, generation) in zip( - prompt_ids_rewards_and_answers, - prompts_and_gens, - )] + for (prompt_id, reward, verified_answer), (prompt, generation) in zip( + prompt_ids_rewards_and_answers, + prompts_and_gens, + )] # Sort the save_data by reward in descending order save_data = sorted(save_data, key=lambda x: x[1], reverse=True) @@ -1067,8 +1053,7 @@ def _log_generations_to_logger(self, state: State): artifact.add(text_table, 'predictions') wandb.log_artifact(artifact) - wandb.log({'generations': text_table}, - step=state.timestamp.batch.value) + wandb.log({'generations': text_table}, step=state.timestamp.batch.value) if self.mlflow_logger is not None: self.mlflow_logger.log_table( @@ -1088,8 +1073,7 @@ def _update_ift_kl(self): self.kl_ctl.update( ift_kl_update, - self.num_batches_per_update * self.device_train_batch_size * - dist.get_world_size(), + self.num_batches_per_update * self.device_train_batch_size * dist.get_world_size(), ) self.kl_ift = [] @@ -1102,8 +1086,7 @@ def _create_vllm_engines(self): self.model_update_group = None self.vllm_engines = [] - if os.getenv('NODE_RANK', - None) == '0' and os.getenv('LOCAL_RANK', None) == '0': + if os.getenv('NODE_RANK', None) == '0' and os.getenv('LOCAL_RANK', None) == '0': log.info('Creating vLLM engines.') os.environ['NCCL_CUMEM_ENABLE'] = '0' @@ -1124,7 +1107,7 @@ def _create_vllm_engines(self): ) log.info('After creating vLLM engines.') - master_address = ray._private.services.get_node_ip_address( # type: ignore + master_address = ray._private.services.get_node_ip_address( # type: ignore ) with socket.socket() as sock: sock.bind(('', 0)) @@ -1173,8 +1156,7 @@ def state_dict(self): return { 'KL_ctl_state_dict': self.kl_ctl.state_dict(), 'iter_num': self.iter_num, - 'train_prompt_loader': - self.train_prompt_loader.state_dict(), # pyright: ignore + 'train_prompt_loader': self.train_prompt_loader.state_dict(), # pyright: ignore } def load_state_dict(self, state_dict: dict[str, Any]): diff --git a/compose_rl/algorithms/online/generation_utils/generation_utils.py b/compose_rl/algorithms/online/generation_utils/generation_utils.py index beea9176..8d1f9fd9 100644 --- a/compose_rl/algorithms/online/generation_utils/generation_utils.py +++ b/compose_rl/algorithms/online/generation_utils/generation_utils.py @@ -102,12 +102,9 @@ def _vllm_generate( } # We have to remove all pad tokens here - all_prompts = [[ - token - for token in prompt.detach().cpu().tolist() - if token != pad_token_id + all_prompts = [ + [token for token in prompt.detach().cpu().tolist() if token != pad_token_id] for prompt in all_prompts ] - for prompt in all_prompts] # Generate with vllm # Calculate the base batch size @@ -177,12 +174,9 @@ def _vllm_chat( } # We have to remove all pad tokens here. 
Keep here to check - all_prompts = [[ - token - for token in prompt.detach().cpu().tolist() - if token != pad_token_id + all_prompts = [ + [token for token in prompt.detach().cpu().tolist() if token != pad_token_id] for prompt in all_prompts ] - for prompt in all_prompts] # Generate with vllm # Calculate the base batch size @@ -227,10 +221,7 @@ def _vllm_chat( ) # Remove pad tokens from vllm prompts - all_vllm_prompts = [[token - for token in prompt - if token != pad_token_id] - for prompt in all_vllm_prompts] + all_vllm_prompts = [[token for token in prompt if token != pad_token_id] for prompt in all_vllm_prompts] # Checking vllm prompts with all prompts assert all_prompts == all_vllm_prompts diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 4da62390..b13abd82 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -32,8 +32,7 @@ from ray.exceptions import GetTimeoutError from ray.util.placement_group import placement_group from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy -from torch.distributed.distributed_c10d import \ - _new_process_group_helper # type: ignore +from torch.distributed.distributed_c10d import _new_process_group_helper # type: ignore from torch.distributed.distributed_c10d import _world # type: ignore from torch.distributed.distributed_c10d import ( Backend, @@ -64,9 +63,7 @@ def init_process_group( store: Optional[Store] = None, group_name: Optional[str] = None, ) -> torch.distributed.ProcessGroup: - assert (store is None) or ( - init_method is None - ), 'Cannot specify both init_method and store.' + assert (store is None) or (init_method is None), 'Cannot specify both init_method and store.' if store is not None: assert world_size > 0, 'world_size must be positive if using store' @@ -124,8 +121,7 @@ def init_process_group( backend: str, ): """Init torch process group for model weights update.""" - assert torch.distributed.is_initialized( - ), 'default torch process group must be initialized' + assert torch.distributed.is_initialized(), 'default torch process group must be initialized' assert group_name != '', 'group name must not be empty' rank = torch.distributed.get_rank() + rank_offset @@ -138,8 +134,8 @@ def init_process_group( ) log.info(f'init process group for: {torch.distributed.get_rank()}') log.info( - f'init_process_group: master_address={master_address}, master_port={master_port}, ' - + f'rank={rank}, world_size={world_size}, group_name={group_name}', + f'init_process_group: master_address={master_address}, master_port={master_port}, ' + + f'rank={rank}, world_size={world_size}, group_name={group_name}', ) def update_weight( @@ -415,9 +411,7 @@ def broadcast_to_vllm( refss = [] cache_reset_refss = [] if enable_prefix_caching and dist.get_global_rank() == 0: - cache_reset_refss = [ - engine.reset_prefix_cache.remote() for engine in vllm_engines - ] + cache_reset_refss = [engine.reset_prefix_cache.remote() for engine in vllm_engines] # These apply to llama modules, it might change for other modules valid_non_leaf_module_names = [ @@ -434,18 +428,12 @@ def broadcast_to_vllm( # Adding a dummy forwards call. # We need this otherwise FSDP throws an error during a standard forward pass. 
dummy_batch = { - 'obs': - torch.tensor([[0]], dtype=torch.long, device=device), - 'right_padded_attn_mask': - torch.tensor([[1]], dtype=torch.bool, device=device), - 'actions': - torch.tensor([[0]], dtype=torch.long, device=device), - 'prompt_len': - torch.tensor([1], device=device), - 'max_gen_len': - torch.tensor([1], device=device), - 'action_mask': - torch.tensor([[0]], dtype=torch.long, device=device), + 'obs': torch.tensor([[0]], dtype=torch.long, device=device), + 'right_padded_attn_mask': torch.tensor([[1]], dtype=torch.bool, device=device), + 'actions': torch.tensor([[0]], dtype=torch.long, device=device), + 'prompt_len': torch.tensor([1], device=device), + 'max_gen_len': torch.tensor([1], device=device), + 'action_mask': torch.tensor([[0]], dtype=torch.long, device=device), } model(dummy_batch) start_time = time.time() @@ -518,8 +506,8 @@ def broadcast_to_vllm( # Check if the number of parameters updated is equal to the number of parameters # This can only be done on global rank 0, since it is the one that is updating the parameters. assert num_params == count, ( - f'Number of parameters updated {count} does not match the number of parameters {num_params}' - + f'This means that some parameters were not updated.' + f'Number of parameters updated {count} does not match the number of parameters {num_params}' + + f'This means that some parameters were not updated.' ) log.info(f'for loop took: {time.time() - start_time}') @@ -550,6 +538,4 @@ def ray_noset_visible_devices(env_vars: os._Environ = os.environ): 'RAY_EXPERIMENTAL_NOSET_TPU_VISIBLE_CHIPS', 'RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR', ] - return any( - env_vars.get(env_var) for env_var in NOSET_VISIBLE_DEVICES_ENV_VARS_LIST - ) + return any(env_vars.get(env_var) for env_var in NOSET_VISIBLE_DEVICES_ENV_VARS_LIST) diff --git a/compose_rl/algorithms/online/kl_controller.py b/compose_rl/algorithms/online/kl_controller.py index 50054995..3941fee4 100644 --- a/compose_rl/algorithms/online/kl_controller.py +++ b/compose_rl/algorithms/online/kl_controller.py @@ -102,9 +102,7 @@ def __init__( device: str = 'cpu', ): super().__init__(device=device) - self._value: torch.Tensor = torch.tensor([init_kl_coef], - requires_grad=True, - device=self.device) + self._value: torch.Tensor = torch.tensor([init_kl_coef], requires_grad=True, device=self.device) self._target = target self._horizon = horizon self._optim = torch.optim.Adam([self._value], lr=kl_lr) @@ -161,9 +159,7 @@ def __init__( device: str = 'cpu', ): super().__init__(device=device) - self._value: torch.Tensor = torch.tensor([init_kl_coef], - requires_grad=True, - device=self.device) + self._value: torch.Tensor = torch.tensor([init_kl_coef], requires_grad=True, device=self.device) self._target = target self._horizon = horizon self._optim = torch.optim.Adam([self._value], lr=kl_lr) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index 04a8970f..8edb0f3b 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -93,9 +93,7 @@ def prepare_critic_values_for_training( values *= action_mask if zero_pad: - zero_pad_tensor = torch.zeros((bs, 1), - device=values.device, - dtype=values.dtype) + zero_pad_tensor = torch.zeros((bs, 1), device=values.device, dtype=values.dtype) values = torch.cat([values, zero_pad_tensor], dim=-1) return values @@ -190,8 +188,7 @@ def critic_loss( batch['action_mask'], ) - val_error = utils.sample_wise_masked_mean((v_preds - returns)**2, - 
batch['action_mask']) + val_error = utils.sample_wise_masked_mean((v_preds - returns)**2, batch['action_mask']) critic_dict = { 'loss/value_loss': @@ -235,8 +232,7 @@ def policy_loss( if loss_type in ALGORITHM_TYPE.CLIPPED_PG: assert advantages is not None - online_log_probs, old_log_probs = outputs['online_log_probs'], batch[ - 'old_log_probs'] + online_log_probs, old_log_probs = outputs['online_log_probs'], batch['old_log_probs'] old_entropies = batch['old_entropies'] gen_logits = utils.get_batched_generated_values( batched_values=outputs['logits'], @@ -247,8 +243,8 @@ def policy_loss( logits=gen_logits, ) assert token_entropies.shape == batch['action_mask'].shape, ( - 'Token entropies shape {token_entropies_shape} does not match action mask shape {action_mask_shape}.' - .format( + 'Token entropies shape {token_entropies_shape} does not match action mask shape {action_mask_shape}.'. + format( token_entropies_shape=token_entropies.shape, action_mask_shape=batch['action_mask'].shape, ), @@ -262,13 +258,11 @@ def policy_loss( flattened_entropies = masked_token_entropies.flatten() # Calculate entropies at different percentiles - percentiles = torch.tensor([0, 20, 40, 60, 80, 100], - device=token_entropies.device) + percentiles = torch.tensor([0, 20, 40, 60, 80, 100], device=token_entropies.device) num_entropies = flattened_entropies.numel() if num_entropies > 0: # Calculate indices for percentiles (excluding 0 and 100) - indices = ((percentiles / 100.0) * - (num_entropies - 1)).ceil().long() + indices = ((percentiles / 100.0) * (num_entropies - 1)).ceil().long() # Get sorted values sorted_entropies = flattened_entropies.sort().values @@ -335,59 +329,45 @@ def policy_loss( policy_loss = utils.masked_sum(policy_loss, batch['action_mask']) policy_token_kl_logging_dict = { - f'token_kl/policy_token_kl_{k}_estimate': - utils.sample_wise_masked_mean( - v, - batch['action_mask'], - ) for k, v in policy_kl_dict.items() + f'token_kl/policy_token_kl_{k}_estimate': utils.sample_wise_masked_mean( + v, + batch['action_mask'], + ) for k, v in policy_kl_dict.items() } policy_seq_kl_logging_dict = { - f'seq_kl/policy_seq_kl_{k}_estimate': - utils.masked_sum( - v, - batch['action_mask'], - ) for k, v in policy_kl_dict.items() + f'seq_kl/policy_seq_kl_{k}_estimate': utils.masked_sum( + v, + batch['action_mask'], + ) for k, v in policy_kl_dict.items() } online_ift_token_kl_logging_dict = { - f'token_kl/online_ift_token_kl_{k}_estimate': - utils.sample_wise_masked_mean( - v, - batch['action_mask'], - ) for k, v in online_ift_kl_dict.items() + f'token_kl/online_ift_token_kl_{k}_estimate': utils.sample_wise_masked_mean( + v, + batch['action_mask'], + ) for k, v in online_ift_kl_dict.items() } online_ift_seq_kl_logging_dict = { - f'seq_kl/online_ift_seq_kl_{k}_estimate': - utils.masked_sum( - v, - batch['action_mask'], - ) for k, v in online_ift_kl_dict.items() + f'seq_kl/online_ift_seq_kl_{k}_estimate': utils.masked_sum( + v, + batch['action_mask'], + ) for k, v in online_ift_kl_dict.items() } policy_dict = { - 'loss/policy_loss': - policy_loss, - 'kl/policy_kl': - policy_kl, - 'kl/online_ift_kl': - online_ift_kl, - 'kl/ift_kl_scalar': - batch['ift_kl_scalar'], + 'loss/policy_loss': policy_loss, + 'kl/policy_kl': policy_kl, + 'kl/online_ift_kl': online_ift_kl, + 'kl/ift_kl_scalar': batch['ift_kl_scalar'], **policy_token_kl_logging_dict, **policy_seq_kl_logging_dict, **online_ift_token_kl_logging_dict, **online_ift_seq_kl_logging_dict, - 'policy_loss/clip_frac': - policy_clip_frac, - 'policy_loss/ratio': - 
utils.sample_wise_masked_mean(ratio, batch['action_mask']), - 'gen/gen_length': - batch['action_mask'].sum(dim=1).to(torch.float32), - 'gen/prev_seq_entropy': - old_entropies, - 'gen/cur_seq_entropy': - seq_entropies, - 'advantages/mean': - utils.sample_wise_masked_mean(advantages, batch['action_mask']), + 'policy_loss/clip_frac': policy_clip_frac, + 'policy_loss/ratio': utils.sample_wise_masked_mean(ratio, batch['action_mask']), + 'gen/gen_length': batch['action_mask'].sum(dim=1).to(torch.float32), + 'gen/prev_seq_entropy': old_entropies, + 'gen/cur_seq_entropy': seq_entropies, + 'advantages/mean': utils.sample_wise_masked_mean(advantages, batch['action_mask']), } # Add entropy percentiles to policy_dict for i, p in enumerate(percentiles): @@ -429,8 +409,7 @@ def policy_loss( assert vstars.size() == rewards.size() == masked_log_probs_diff.size( ) # should have the same shape which is (batch_size, ) - policy_loss = ((beta * masked_log_probs_diff - - (rewards - vstars))**2).mean() + policy_loss = ((beta * masked_log_probs_diff - (rewards - vstars))**2).mean() policy_dict = { 'loss/policy_loss': policy_loss, 'kl/policy_kl': policy_kl, @@ -577,23 +556,17 @@ def online_rl_loss( return_dict['total'] = return_dict['loss/policy_loss'] if loss_type in ALGORITHM_TYPE.ACTOR_CRITIC: # Add value loss to total loss - return_dict['total'] += value_loss_weight * return_dict[ - 'loss/value_loss'] # pyright: ignore + return_dict['total'] += value_loss_weight * return_dict['loss/value_loss'] # pyright: ignore # If we want to directly minimize the KL Divergence, we can do so here # and it will not include the KL in the reward. if add_direct_kl_loss: - return_dict['total'] += batch['ift_kl_scalar'][0] * return_dict[ - 'kl/online_ift_kl'] - return_dict['loss/online_ift_kl'] = ( - batch['ift_kl_scalar'][0] * return_dict['kl/online_ift_kl'] - ) + return_dict['total'] += batch['ift_kl_scalar'][0] * return_dict['kl/online_ift_kl'] + return_dict['loss/online_ift_kl'] = (batch['ift_kl_scalar'][0] * return_dict['kl/online_ift_kl']) # Entropy Loss. Meant to promote diversity. if entropy_loss_weight is not None: # We want to maximize entropy so we deduct it from the loss. - entropy_loss = -1.0 * ( - entropy_loss_weight * return_dict['gen/cur_seq_entropy'] - ).mean() + entropy_loss = -1.0 * (entropy_loss_weight * return_dict['gen/cur_seq_entropy']).mean() # breakpoint() return_dict['loss/entropy'] = entropy_loss return_dict['total'] += entropy_loss diff --git a/compose_rl/algorithms/online/modeling_hf.py b/compose_rl/algorithms/online/modeling_hf.py index dfd58e9b..79091332 100644 --- a/compose_rl/algorithms/online/modeling_hf.py +++ b/compose_rl/algorithms/online/modeling_hf.py @@ -33,9 +33,7 @@ class ComposerHFPolicy(BaseHuggingFaceModel): See base class for argument documentation. 
""" - model_cls: Union[ - type[_BaseAutoModelClass], - type[PreTrainedModel]] = AutoModelForCausalLMAsPolicy # type: ignore + model_cls: Union[type[_BaseAutoModelClass], type[PreTrainedModel]] = AutoModelForCausalLMAsPolicy # type: ignore default_train_metrics: tuple = () default_eval_metrics: tuple = () @@ -153,13 +151,10 @@ def prepare_inner_model( if hasattr(model, 'peft_type') and model.peft_type is not None: peft_type = model.peft_type.lower() # type: ignore active_adapters = [ - adapter.lower() - for adapter in model.active_adapters # type: ignore + adapter.lower() for adapter in model.active_adapters # type: ignore ] for name, module in model.named_modules(): - if peft_type in name.lower() and any( - adapter in name.lower() for adapter in active_adapters - ): + if peft_type in name.lower() and any(adapter in name.lower() for adapter in active_adapters): has_parameters = next(module.parameters(), None) is not None has_buffers = next(module.buffers(), None) is not None if has_parameters or has_buffers: diff --git a/compose_rl/algorithms/online/reward_manager.py b/compose_rl/algorithms/online/reward_manager.py index 94187d41..e9fd9b1a 100644 --- a/compose_rl/algorithms/online/reward_manager.py +++ b/compose_rl/algorithms/online/reward_manager.py @@ -89,8 +89,7 @@ def __init__( self.functional_rewards: list[str] = [] self.local_reward_models: list[str] = [] - ref_model_config: dict[str, - Any] = self.ref_config.get('model_config', None) + ref_model_config: dict[str, Any] = self.ref_config.get('model_config', None) self.reference_model = self.initialize_composer_model( model_config=ref_model_config, @@ -165,8 +164,7 @@ def __init__( self.pool = None if self.inference_rewards or self.functional_rewards: self.pool = Pool( - processes=len(self.inference_rewards) + - len(self.functional_rewards), + processes=len(self.inference_rewards) + len(self.functional_rewards), context=get_context('spawn'), ) @@ -258,12 +256,9 @@ def call_reward_model( # We need to do this to handle getting rewards at multiple points in a # single input sequence with a deployed RM. 
if isinstance(reward_model, InferenceRewardModel): - rm_seq_lens = [ - [idx + prompt_len - for idx in gather_indices] - for gather_indices, prompt_len in - zip(batch['end_idxs_gather'], batch['reward_prompt_lens']) - ] + rm_seq_lens = [[idx + prompt_len + for idx in gather_indices] + for gather_indices, prompt_len in zip(batch['end_idxs_gather'], batch['reward_prompt_lens'])] else: rm_seq_lens = batch['reward_seq_lens'] @@ -459,8 +454,7 @@ def _create_batch( } elif isinstance(reward_model, RewardModel): granularity = self.granularities[reward_name] - curr_inputs = processed_inputs['end_reward_inputs_dict'][granularity - ] + curr_inputs = processed_inputs['end_reward_inputs_dict'][granularity] tok_formatted_reward_inputs = torch.tensor( curr_inputs.input_ids, ).type(base_batch['input_ids'].dtype) @@ -469,26 +463,16 @@ def _create_batch( ).type(base_batch['attention_mask'].dtype) return { - 'tok_formatted_reward_inputs': - tok_formatted_reward_inputs, - 'tok_formatted_reward_attn_masks': - tok_formatted_reward_attn_masks, - 'reward_seq_lens': - processed_inputs['reward_seq_lens_dict'][granularity], - 'reward_prompt_lens': - processed_inputs['reward_prompt_lens_dict'][granularity], - 'reward_generated_lens': - processed_inputs['reward_generated_lens_dict'][granularity], - 'end_idxs_gather': - processed_inputs['end_idxs_gather_dict'][granularity], - 'end_idxs_scatter': - processed_inputs['end_idxs_scatter_dict'][granularity], - 'prompt_lens': - base_batch['prompt_len'], - 'generated_lens': - base_batch['generated_lens'], - 'seq_lens': - base_batch['seq_lens'], + 'tok_formatted_reward_inputs': tok_formatted_reward_inputs, + 'tok_formatted_reward_attn_masks': tok_formatted_reward_attn_masks, + 'reward_seq_lens': processed_inputs['reward_seq_lens_dict'][granularity], + 'reward_prompt_lens': processed_inputs['reward_prompt_lens_dict'][granularity], + 'reward_generated_lens': processed_inputs['reward_generated_lens_dict'][granularity], + 'end_idxs_gather': processed_inputs['end_idxs_gather_dict'][granularity], + 'end_idxs_scatter': processed_inputs['end_idxs_scatter_dict'][granularity], + 'prompt_lens': base_batch['prompt_len'], + 'generated_lens': base_batch['generated_lens'], + 'seq_lens': base_batch['seq_lens'], } else: raise TypeError( @@ -594,9 +578,7 @@ def resolve_outputs( bad_end_generation_name = name bad_generation_row_mask = torch.any(resolved_reward != 0, dim=1) - bad_end_generation_mask = ( - ~bad_generation_row_mask - ).unsqueeze(1).expand_as(resolved_reward) + bad_end_generation_mask = (~bad_generation_row_mask).unsqueeze(1).expand_as(resolved_reward) bad_end_generation_mask = bad_end_generation_mask.to( device=device, ) diff --git a/compose_rl/algorithms/reward_modeling/__init__.py b/compose_rl/algorithms/reward_modeling/__init__.py index ea6779b5..c7b48d14 100644 --- a/compose_rl/algorithms/reward_modeling/__init__.py +++ b/compose_rl/algorithms/reward_modeling/__init__.py @@ -21,8 +21,7 @@ AutoModelForCausalLMWithRM, RewardModelConfig, ) -from compose_rl.algorithms.reward_modeling.inference_model import \ - InferenceRewardModel +from compose_rl.algorithms.reward_modeling.inference_model import InferenceRewardModel from compose_rl.algorithms.reward_modeling.model import ( ComposerHFCausalClassifierRewardModel, ComposerHFClassifierRewardModel, diff --git a/compose_rl/algorithms/reward_modeling/functional.py b/compose_rl/algorithms/reward_modeling/functional.py index 7681753f..1ba4c39b 100644 --- a/compose_rl/algorithms/reward_modeling/functional.py +++ 
b/compose_rl/algorithms/reward_modeling/functional.py @@ -68,11 +68,7 @@ def __call__( curr_rewards = [] for gen_text in all_generated_texts: gen_tokens = gen_text.split() - number_tokens = [ - float(token) - for token in gen_tokens - if IncreasingNumbersReward.is_number(token) - ] + number_tokens = [float(token) for token in gen_tokens if IncreasingNumbersReward.is_number(token)] if len(number_tokens) > 0: sorted_count = 1 previous_token = number_tokens[0] @@ -110,8 +106,7 @@ def __init__(self, reward: float, len_threshold: int, tokenizer: Tokenizer): self.len_threshold = len_threshold log.info( - f'Adding a reward of {self.reward} if a model generates ' + - f'tokens under the length {self.len_threshold}', + f'Adding a reward of {self.reward} if a model generates ' + f'tokens under the length {self.len_threshold}', ) def __call__( @@ -165,9 +160,8 @@ def __init__( self.eos_penalty = eos_penalty # Extra special tokens for any other formats with pseudo EOS alternatives like ChatML - self.extra_special_tokens = [ - str(tok) for tok in extra_special_tokens - ] if extra_special_tokens is not None else [] + self.extra_special_tokens = [str(tok) for tok in extra_special_tokens + ] if extra_special_tokens is not None else [] self.extra_special_token_ids = [] if self.extra_special_tokens != []: self.extra_special_token_ids.extend([ diff --git a/compose_rl/algorithms/reward_modeling/inference_model.py b/compose_rl/algorithms/reward_modeling/inference_model.py index f3ed5f6e..71be1023 100644 --- a/compose_rl/algorithms/reward_modeling/inference_model.py +++ b/compose_rl/algorithms/reward_modeling/inference_model.py @@ -114,8 +114,7 @@ def __call__( deployment_inputs = [] batch_indices = [] reward_indices = [] - for bidx, (seq_input_ids, - seq_reward_indices) in enumerate(zip(input_ids, seq_lens)): + for bidx, (seq_input_ids, seq_reward_indices) in enumerate(zip(input_ids, seq_lens)): for seq_reward_index in seq_reward_indices: deployment_inputs.append({ 'input_ids': seq_input_ids[:seq_reward_index + 1], @@ -146,10 +145,7 @@ def call_predict_with_backoff( ) response = response.json() # Currently, all outputs will contain a single reward, coming from the last token. - rewards = [ - choice['metadata']['rewards'][-1] - for choice in response['choices'] - ] + rewards = [choice['metadata']['rewards'][-1] for choice in response['choices']] return rewards try: @@ -161,8 +157,7 @@ def call_predict_with_backoff( # Retry limit has been reached. Raise the error :( error_msg = ( 'REWARD MODEL DEPLOYMENT BACKOFF LIMIT EXCEEDED. ' + - 'Printing deployment inputs then raising last error...' + - f'\nDeployment inputs:\n{deployment_inputs}' + 'Printing deployment inputs then raising last error...' 
+ f'\nDeployment inputs:\n{deployment_inputs}' ) raise RuntimeError(error_msg) from e diff --git a/compose_rl/algorithms/reward_modeling/model.py b/compose_rl/algorithms/reward_modeling/model.py index cb6dc449..843c102b 100644 --- a/compose_rl/algorithms/reward_modeling/model.py +++ b/compose_rl/algorithms/reward_modeling/model.py @@ -16,8 +16,7 @@ RewardModel, Tokenizer, ) -from compose_rl.algorithms.reward_modeling.hf_utils import \ - SequenceClassifierOutput +from compose_rl.algorithms.reward_modeling.hf_utils import SequenceClassifierOutput from compose_rl.algorithms.reward_modeling.model_methods import ( ClassifierRewardEnum, PairwiseRewardEnum, @@ -27,10 +26,8 @@ pairwise_forward, pairwise_loss, ) -from compose_rl.algorithms.reward_modeling.modeling_hf import \ - ComposerHFSequenceClassification -from compose_rl.algorithms.reward_modeling.modeling_mpt import \ - MPTForSequenceClassification +from compose_rl.algorithms.reward_modeling.modeling_hf import ComposerHFSequenceClassification +from compose_rl.algorithms.reward_modeling.modeling_mpt import MPTForSequenceClassification log = logging.getLogger(__name__) @@ -98,8 +95,7 @@ def eval_forward( ) -> dict[str, torch.Tensor]: return outputs if outputs is not None else self.forward(batch) - def loss(self, outputs: SequenceClassifierOutput, - batch: Mapping) -> dict[str, torch.Tensor]: + def loss(self, outputs: SequenceClassifierOutput, batch: Mapping) -> dict[str, torch.Tensor]: return pairwise_loss( outputs, batch, @@ -164,8 +160,7 @@ def eval_forward( ) -> dict[str, torch.Tensor]: return outputs if outputs is not None else self.forward(batch) - def loss(self, outputs: SequenceClassifierOutput, - batch: Mapping) -> dict[str, torch.Tensor]: + def loss(self, outputs: SequenceClassifierOutput, batch: Mapping) -> dict[str, torch.Tensor]: return classifier_loss( outputs, batch, @@ -190,9 +185,7 @@ def __init__( self.return_lm_logits = return_lm_logits self.return_last = return_last - kwargs[ - 'loss_fn' - ] = 'torch_crossentropy' # NOTE: passing in dummy value to overwrite + kwargs['loss_fn'] = 'torch_crossentropy' # NOTE: passing in dummy value to overwrite super().__init__( tokenizer=tokenizer, use_train_metrics=use_train_metrics, @@ -231,8 +224,7 @@ def eval_forward( ) -> dict[str, torch.Tensor]: return outputs if outputs is not None else self.forward(batch) - def loss(self, outputs: SequenceClassifierOutput, - batch: Mapping) -> dict[str, torch.Tensor]: + def loss(self, outputs: SequenceClassifierOutput, batch: Mapping) -> dict[str, torch.Tensor]: return pairwise_loss( outputs, batch, @@ -354,8 +346,7 @@ def eval_forward( batch, ) # type: ignore - def loss(self, outputs: SequenceClassifierOutput, - batch: Mapping) -> dict[str, torch.Tensor]: + def loss(self, outputs: SequenceClassifierOutput, batch: Mapping) -> dict[str, torch.Tensor]: return classifier_loss( outputs, batch, diff --git a/compose_rl/algorithms/reward_modeling/model_methods.py b/compose_rl/algorithms/reward_modeling/model_methods.py index 45df58f5..6c50dd2f 100644 --- a/compose_rl/algorithms/reward_modeling/model_methods.py +++ b/compose_rl/algorithms/reward_modeling/model_methods.py @@ -15,8 +15,7 @@ PreTrainedTokenizerFast, ) -from compose_rl.algorithms.reward_modeling.hf_utils import \ - SequenceClassifierOutput +from compose_rl.algorithms.reward_modeling.hf_utils import SequenceClassifierOutput from compose_rl.utils import ( clear_mb_load_balancing_loss, extract_packed_chosen_rejected, @@ -284,13 +283,10 @@ def pairwise_loss( losses = losses.mean() loss_dict = { - 
'chosen_rewards': - chosen_scores.detach(), - 'rejected_rewards': - rejected_scores.detach(), + 'chosen_rewards': chosen_scores.detach(), + 'rejected_rewards': rejected_scores.detach(), 'margin': (chosen_scores - rejected_scores).detach(), - 'accuracy': (chosen_scores - > rejected_scores).detach().to(torch.float32), + 'accuracy': (chosen_scores > rejected_scores).detach().to(torch.float32), } loss_dict.update(partial_loss_dict) diff --git a/compose_rl/algorithms/reward_modeling/modeling_hf.py b/compose_rl/algorithms/reward_modeling/modeling_hf.py index 22ee944b..39a6cf48 100644 --- a/compose_rl/algorithms/reward_modeling/modeling_hf.py +++ b/compose_rl/algorithms/reward_modeling/modeling_hf.py @@ -68,9 +68,7 @@ class ComposerHFSequenceClassification(BaseHuggingFaceModel): use_flash_attention_2 (bool, optional): Whether to use flash-attention 2. Default: ``False``. tokenizer (PreTrainedTokenizer): The tokenizer that the model will use. """ - model_cls: Union[ - type[_BaseAutoModelClass], - type[PreTrainedModel]] = AutoModelForCausalLMWithRM # type: ignore + model_cls: Union[type[_BaseAutoModelClass], type[PreTrainedModel]] = AutoModelForCausalLMWithRM # type: ignore default_train_metrics: tuple = () default_eval_metrics: tuple = () @@ -219,13 +217,10 @@ def prepare_inner_model( if hasattr(model, 'peft_type') and model.peft_type is not None: peft_type = model.peft_type.lower() # type: ignore active_adapters = [ - adapter.lower() - for adapter in model.active_adapters # type: ignore + adapter.lower() for adapter in model.active_adapters # type: ignore ] for name, module in model.named_modules(): - if peft_type in name.lower() and any( - adapter in name.lower() for adapter in active_adapters - ): + if peft_type in name.lower() and any(adapter in name.lower() for adapter in active_adapters): has_parameters = next(module.parameters(), None) is not None has_buffers = next(module.buffers(), None) is not None if has_parameters or has_buffers: diff --git a/compose_rl/data/buffer.py b/compose_rl/data/buffer.py index 685ae2ad..c9b9f9ec 100644 --- a/compose_rl/data/buffer.py +++ b/compose_rl/data/buffer.py @@ -131,4 +131,5 @@ def set_state_dict(self, state_dict: dict[str, Any], epoch: int): state_dict['epoch'] = epoch log.info(f'Saving state dict to: {state_dict}') self.dataset.set_state_dict( # pyright: ignore[reportGeneralTypeIssues] - state_dict) + state_dict, + ) diff --git a/compose_rl/data/dataloader.py b/compose_rl/data/dataloader.py index 617492f7..6b04c997 100644 --- a/compose_rl/data/dataloader.py +++ b/compose_rl/data/dataloader.py @@ -145,8 +145,7 @@ def get_num_tokens_in_batch_online( return int(relevant_tokens) else: log.warning( - 'No action_mask/prompt_len/sequences in batch. ' + - 'Using default value of 0 for num_tokens_in_batch.', + 'No action_mask/prompt_len/sequences in batch. 
' + 'Using default value of 0 for num_tokens_in_batch.', ) return 0 diff --git a/compose_rl/data/messages_data.py b/compose_rl/data/messages_data.py index 16b836ff..18516090 100644 --- a/compose_rl/data/messages_data.py +++ b/compose_rl/data/messages_data.py @@ -62,8 +62,7 @@ def messages_dataset_collate_fn( raise ValueError(f'Invalid key: {key}') collated_batch['prompt_attention_mask'] = torch.logical_not( - torch.eq(collated_batch['prompt'], - tokenizer.pad_token_id), # type: ignore + torch.eq(collated_batch['prompt'], tokenizer.pad_token_id), # type: ignore ) return collated_batch diff --git a/compose_rl/data/preference_data.py b/compose_rl/data/preference_data.py index c08d2757..a7295d13 100644 --- a/compose_rl/data/preference_data.py +++ b/compose_rl/data/preference_data.py @@ -103,8 +103,7 @@ def pairwise_preference_dataset_collate_fn( cat_batch = torch.cat( [ cat_batch, - torch.ones(int(pad_len.item()), dtype=cat_batch.dtype) * - tokenizer.pad_token_id, # type: ignore + torch.ones(int(pad_len.item()), dtype=cat_batch.dtype) * tokenizer.pad_token_id, # type: ignore ], dim=-1, # type: ignore ) @@ -113,8 +112,7 @@ def pairwise_preference_dataset_collate_fn( torch.eq(cat_batch, tokenizer.pad_token_id), # type: ignore ) - cur_sequence_id = torch.tensor(([0] * chosen_len) + - ([1] * rejected_len) + + cur_sequence_id = torch.tensor(([0] * chosen_len) + ([1] * rejected_len) + ([-1] * max(0, int(pad_len.item()))),) sequence_id.append(cur_sequence_id) @@ -177,11 +175,7 @@ def finegrained_preference_dataset_collate_fn( cur_values = [item[key] for item in data] if key == 'prompt_mask': max_len = max([len(val) for val in cur_values]) - mask = torch.stack([ - torch.cat([torch.Tensor(val), - torch.ones(max_len - len(val))]) - for val in cur_values - ]) + mask = torch.stack([torch.cat([torch.Tensor(val), torch.ones(max_len - len(val))]) for val in cur_values]) mask = ~mask.to(torch.bool) batch[key] = mask.to(torch.int8) continue @@ -221,8 +215,7 @@ def _read_binary_tokenized_sample(self, sample: dict[str, Any], key: str): ) log.info(f'Truncating: {truncated}') decoded_arr = torch.from_numpy( - np.frombuffer(sample[key], - dtype=np.int64)[:self.max_seq_len].copy(), + np.frombuffer(sample[key], dtype=np.int64)[:self.max_seq_len].copy(), ) return decoded_arr @@ -303,8 +296,7 @@ def _read_binary_tokenized_sample(self, sample: dict[str, Any], key: str): ) log.info(f'Truncated sample: {truncated}') decoded_arr = torch.from_numpy( - np.frombuffer(sample[key], - dtype=np.int64)[:self.max_seq_len].copy(), + np.frombuffer(sample[key], dtype=np.int64)[:self.max_seq_len].copy(), ) else: decoded_arr = torch.from_numpy( diff --git a/compose_rl/data/prompt_data.py b/compose_rl/data/prompt_data.py index 82135bd0..d1fa6ff1 100644 --- a/compose_rl/data/prompt_data.py +++ b/compose_rl/data/prompt_data.py @@ -59,8 +59,7 @@ def prompt_dataset_collate_fn( collated_batch[key] = ref_collate_fn(cur_values)['input_ids'] collated_batch['prompt_attention_mask'] = torch.logical_not( - torch.eq(collated_batch['prompt'], - tokenizer.pad_token_id), # type: ignore + torch.eq(collated_batch['prompt'], tokenizer.pad_token_id), # type: ignore ) return collated_batch @@ -81,8 +80,7 @@ def __init__( def _read_binary_tokenized_sample(self, sample: dict[str, Any], key: str): decoded_arr = torch.from_numpy( - np.frombuffer(sample[key], - dtype=np.int64)[:self.max_seq_len].copy(), + np.frombuffer(sample[key], dtype=np.int64)[:self.max_seq_len].copy(), ) return decoded_arr diff --git a/compose_rl/metrics/__init__.py 
b/compose_rl/metrics/__init__.py index c5ad2368..759dd563 100644 --- a/compose_rl/metrics/__init__.py +++ b/compose_rl/metrics/__init__.py @@ -1,8 +1,7 @@ # Copyright 2024 MosaicML ComposeRL authors # SPDX-License-Identifier: Apache-2.0 -from compose_rl.metrics.reward_model_metrics import \ - PairwiseRewardClassificationAccuracy +from compose_rl.metrics.reward_model_metrics import PairwiseRewardClassificationAccuracy __all__ = [ 'PairwiseRewardClassificationAccuracy', diff --git a/compose_rl/metrics/reward_model_metrics.py b/compose_rl/metrics/reward_model_metrics.py index 88343779..06b8706b 100644 --- a/compose_rl/metrics/reward_model_metrics.py +++ b/compose_rl/metrics/reward_model_metrics.py @@ -33,8 +33,7 @@ def update(self, batch: dict, output_logits: torch.Tensor): bs, _ = batch['chosen_scores'].shape self.total += bs - self.correct += (batch['chosen_scores'] - > batch['rejected_scores']).sum().detach().cpu() + self.correct += (batch['chosen_scores'] > batch['rejected_scores']).sum().detach().cpu() def compute(self): assert isinstance(self.correct, Tensor) diff --git a/compose_rl/utils/utils.py b/compose_rl/utils/utils.py index b51b7fc2..94dfb386 100644 --- a/compose_rl/utils/utils.py +++ b/compose_rl/utils/utils.py @@ -176,8 +176,7 @@ def get_token_entropies( """ # Calculate entropy using the logsumexp trick pd = F.softmax(logits, dim=-1) - token_entropies = torch.logsumexp(logits, - dim=-1) - torch.sum(pd * logits, dim=-1) + token_entropies = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1) return token_entropies @@ -236,10 +235,7 @@ def remove_left_padding( max_gen_len (int): the maximum generation length. """ batch_size, _ = sequences.shape - unpadded_obs = [ - sequences[i, -(seq_length[i] + max_gen_len):] - for i in range(batch_size) - ] + unpadded_obs = [sequences[i, -(seq_length[i] + max_gen_len):] for i in range(batch_size)] return unpadded_obs @@ -258,8 +254,7 @@ def add_right_padding( right_padded_obs = [ torch.cat([ seq, - torch.ones(max_len - len(seq), device=seq.device, dtype=seq.dtype) * - pad_token, + torch.ones(max_len - len(seq), device=seq.device, dtype=seq.dtype) * pad_token, ]) for seq in unpadded_sequences ] return torch.stack(right_padded_obs, dim=0) @@ -285,8 +280,7 @@ def get_batched_generated_values( assert not curr_max_gen_len.is_floating_point() generations.append( - batched_values[i, prompt_len[i] - 1:prompt_len[i] + - curr_max_gen_len - 1], + batched_values[i, prompt_len[i] - 1:prompt_len[i] + curr_max_gen_len - 1], ) return torch.stack(generations, dim=0) @@ -466,7 +460,8 @@ def get_training_dataloader_state_dict( ): num_samples = per_iter_global_train_batch_size * cur_dataloader_iter state_dict: dict = dataset.state_dict( # pyright: ignore[reportGeneralTypeIssues] - num_samples, True) + num_samples, True, + ) return state_dict else: warnings.warn( @@ -514,8 +509,7 @@ def mask_eos( seen_eos_batches.add(batch_idx) # We need to refix all of the padding since we now always generate max_gen_len tokens - req_pad_start_idx = prompt_len[eos_idx[0] - ] + generated_len[eos_idx[0]] + req_pad_start_idx = prompt_len[eos_idx[0]] + generated_len[eos_idx[0]] right_padded_obs[eos_idx[0], req_pad_start_idx:] = pad_token right_padded_attn_mask[eos_idx[0], req_pad_start_idx:] = False @@ -535,9 +529,7 @@ def get_decoded_sequence( def split_text_to_sentences(long_text: str, parser: spacy.Language): doc = parser(long_text) - return [0] + [ - sent.end_char for sent in doc.sents if len(str(sent).strip()) > 0 - ] + return [0] + [sent.end_char for sent in 
doc.sents if len(str(sent).strip()) > 0] def split_text_to_subsentences( @@ -610,8 +602,7 @@ def tokenize_with_indices(text: str): sentence_start_char_idxs[:-1], ): - sentence = long_text[ - sentence_start_char_idx:sentence_start_char_idxs[sentence_idx + 1]] + sentence = long_text[sentence_start_char_idx:sentence_start_char_idxs[sentence_idx + 1]] tokens_with_indices = tokenize_with_indices(sentence) @@ -692,10 +683,7 @@ def process_fine_granularities( end_char_idxs = [0, len(generated)] else: raise NotImplementedError(f'{granularity=} is not supported.') - generated_sequences = [ - generated[end_char_idxs[i]:end_char_idxs[i + 1]] - for i in range(len(end_char_idxs) - 1) - ] + generated_sequences = [generated[end_char_idxs[i]:end_char_idxs[i + 1]] for i in range(len(end_char_idxs) - 1)] # Initialize an empty list to store the end token indices of each sentence unaligned_end_indices = [] @@ -719,9 +707,7 @@ def process_fine_granularities( tokenized_reward_input = [ t for t in tokenized_reward_input if t is not None # type: ignore ] - concatenated_subseq_tokens = [ - t for t in concatenated_subseq_tokens if t is not None - ] + concatenated_subseq_tokens = [t for t in concatenated_subseq_tokens if t is not None] # Truncate here to prevent scatter gather indices from going over tokenized_reward_input = tokenized_reward_input[:max_seq_len] @@ -756,8 +742,7 @@ def process_fine_granularities( ) # The original tokenized obses RL training sees, without the decode step - original_generated_token_ids = original_obs[prompt_len:prompt_len + - generated_len] + original_generated_token_ids = original_obs[prompt_len:prompt_len + generated_len] original_generated_tokens = tokenizer.convert_ids_to_tokens( original_generated_token_ids, # type: ignore ) @@ -790,21 +775,13 @@ def process_fine_granularities( ) failed_align_end_idxs.append(i) # Get rid of indices in the final gather where the sequence alignment fails - end_indices_aligned_gather = [ - u for i, u in enumerate(end_indices_aligned_gather) - if i not in failed_align_end_idxs - ] + end_indices_aligned_gather = [u for i, u in enumerate(end_indices_aligned_gather) if i not in failed_align_end_idxs] assert len(end_indices_aligned_gather) == len(end_indices_aligned_scatter) # last token cutoffs and additions - end_indices_aligned_gather = [ - min(item, reward_generated_len - 1) - for item in end_indices_aligned_gather - ] + end_indices_aligned_gather = [min(item, reward_generated_len - 1) for item in end_indices_aligned_gather] - end_indices_aligned_scatter = [ - min(item, generated_len - 1) for item in end_indices_aligned_scatter - ] + end_indices_aligned_scatter = [min(item, generated_len - 1) for item in end_indices_aligned_scatter] # Special edge case for document level rewards if granularity == 'document': @@ -962,13 +939,10 @@ def scatter_gather_rewards( for i in range(batch_size): # Prompt length and generated length together should gives us sequence length assert (prompt_lens[i] + generated_lens[i]).item() == seq_lens[i].item() - assert (reward_prompt_lens[i] + - reward_generated_lens[i]).item() == reward_seq_lens[i].item() + assert (reward_prompt_lens[i] + reward_generated_lens[i]).item() == reward_seq_lens[i].item() # The number of indices you gather rews from outputs is same as scatter assert end_idxs_scatter[i].shape[-1] == end_idxs_gather[i].shape[-1] - batch_curr_rewards = curr_rewards[ - i, reward_prompt_lens[i]:reward_prompt_lens[i] + - reward_generated_lens[i]] + batch_curr_rewards = curr_rewards[i, 
reward_prompt_lens[i]:reward_prompt_lens[i] + reward_generated_lens[i]]
         gathered_rewards = batch_curr_rewards.gather(
             dim=0,
             index=end_idxs_gather[i],
@@ -1100,8 +1074,7 @@ def extract_packed_chosen_rejected(
         )
         chosen_values.append(padded_chosen)
 
-        unpadded_rejected = input_tensor[i, chosen_len[i]:chosen_len[i] +
-                                         rejected_len[i]]
+        unpadded_rejected = input_tensor[i, chosen_len[i]:chosen_len[i] + rejected_len[i]]
         padded_rejected = make_padded_tensor(
             unpadded_rejected,
             max_seq_len,
@@ -1134,8 +1107,7 @@ def make_padded_tensor(
             dtype=input_tensor.dtype,
         ) * pad_token_id
     elif len(input_tensor.shape) == 2:
-        pad_tensor = torch.ones((pad_len, input_tensor.size(1)),
-                                device=input_tensor.device,
+        pad_tensor = torch.ones((pad_len, input_tensor.size(1)), device=input_tensor.device,
                                 dtype=input_tensor.dtype) * pad_token_id
     else:
         raise NotImplementedError(
diff --git a/pyproject.toml b/pyproject.toml
index 278acc71..1b0e8fce 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -80,7 +80,7 @@ ppo_load_planner = "compose_rl.utils.load_planner:ActorCriticModelLoadPlanner"
 # iSort
 [tool.isort]
 multi_line_output = 0
-line_length = 80
+line_length = 120
 skip = [ "env", "wandb", "runs", "build", "node_modules" ]
 include_trailing_comma = true
 split_on_trailing_comma = true
@@ -269,7 +269,7 @@ blank_line_before_nested_class_or_def = true
 coalesce_brackets = true
 
 # The column limit.
-column_limit = 80
+column_limit = 120
 
 # The style for continuation alignment. Possible values are:
 #
diff --git a/scripts/data/messages_dataset_to_mds.py b/scripts/data/messages_dataset_to_mds.py
index af516cde..49b0b204 100644
--- a/scripts/data/messages_dataset_to_mds.py
+++ b/scripts/data/messages_dataset_to_mds.py
@@ -98,8 +98,7 @@ def get_preprocess_fn(self, dataset_path: str):
         elif 'math' in dataset_path.lower():
             log.info('Using MATH preprocessing function')
             return prepare_math_messages
-        elif 'stem' in dataset_path.lower() or 'science' in dataset_path.lower(
-        ):
+        elif 'stem' in dataset_path.lower() or 'science' in dataset_path.lower():
             log.info('Using STEM preprocessing function')
             return prepare_messages
         elif 'allenai/ultrafeedback_binarized_cleaned' in dataset_path.lower():
diff --git a/scripts/data/messages_preprocessing_utils.py b/scripts/data/messages_preprocessing_utils.py
index b63bb95f..a5a15295 100644
--- a/scripts/data/messages_preprocessing_utils.py
+++ b/scripts/data/messages_preprocessing_utils.py
@@ -55,10 +55,8 @@ def prepare_ultrafeedback_summarization_messages(
 ) -> tuple[list[dict[str, str]], dict]:
     prompt = sample['prompt']
     messages = [{
-        'role':
-            'user',
-        'content':
-            f'Can you summarize the following content in 50 words or less: {prompt}',
+        'role': 'user',
+        'content': f'Can you summarize the following content in 50 words or less: {prompt}',
     }]
     return messages, {}
 
diff --git a/scripts/data/unified_tokenize_dataset.py b/scripts/data/unified_tokenize_dataset.py
index 35f711a6..b78dace5 100644
--- a/scripts/data/unified_tokenize_dataset.py
+++ b/scripts/data/unified_tokenize_dataset.py
@@ -44,8 +44,7 @@ def __init__(
         split: str,
         tokenizer: PreTrainedTokenizerBase,
         max_length: int,
-        dataset_type: Literal['preference', 'single_prompt',
-                              'verifiable_answers'],
+        dataset_type: Literal['preference', 'single_prompt', 'verifiable_answers'],
         subset: str | None = None,
         token: str | None = None,
     ):
@@ -115,10 +114,8 @@ def _process_single_prompt_sample(self, sample: Any):
         """
         prompt = sample['prompt']
         messages = [{
-            'role':
-                'user',
-            'content':
-                f'Can you summarize the following content in 50 words or less: {prompt}',
+            'role': 'user',
+            'content': f'Can you summarize the following content in 50 words or less: {prompt}',
         }]
         encoded_prompt = self.tokenizer.apply_chat_template(
             messages,
diff --git a/scripts/launch_composer_ray.py b/scripts/launch_composer_ray.py
index b093c220..4e738813 100644
--- a/scripts/launch_composer_ray.py
+++ b/scripts/launch_composer_ray.py
@@ -295,8 +295,7 @@ def reassign_train_and_inference_ranks(
     # This is just a worker to coordinate from global rank 0 on training
     # to signal to inference nodes training is done
     sync_actor = None
-    if os.getenv('NODE_RANK',
-                 None) == '0' and os.getenv('LOCAL_RANK', None) == '0':
+    if os.getenv('NODE_RANK', None) == '0' and os.getenv('LOCAL_RANK', None) == '0':
         train_world_size = os.getenv('TRAIN_WORLD_SIZE', None)
         train_num_nodes = os.getenv('TRAIN_NUM_NODES', None)
         master_port = os.getenv('TRAIN_MASTER_PORT', None)
@@ -322,8 +321,7 @@ def reassign_train_and_inference_ranks(
     if train_num_nodes is not None:
         train_from_yaml(yaml_path, args_list)
     log.info('After calling `train_from_yaml`')
-    if os.getenv('NODE_RANK',
-                 None) == '0' and os.getenv('LOCAL_RANK', None) == '0':
+    if os.getenv('NODE_RANK', None) == '0' and os.getenv('LOCAL_RANK', None) == '0':
         status = ray.get(sync_actor.mark_done.remote())  # type: ignore
 
     else:
diff --git a/tests/common/actor.py b/tests/common/actor.py
index a2eab75f..545dcf09 100644
--- a/tests/common/actor.py
+++ b/tests/common/actor.py
@@ -44,8 +44,7 @@ def __init__(
         os.environ['RANK'] = str(rank)
 
         # Set LOCAL_RANK based on Ray GPU allocation
-        os.environ['LOCAL_RANK'] = '0' if is_cuda_visible_devices_set(
-        ) else str(ray.get_gpu_ids()[0])
+        os.environ['LOCAL_RANK'] = '0' if is_cuda_visible_devices_set() else str(ray.get_gpu_ids()[0])
 
         # If this is rank 0 and no master_addr/master_port provided, allocate them
         if rank == 0 and (master_addr is None or master_port is None):
diff --git a/tests/common/datasets.py b/tests/common/datasets.py
index 795d264c..62502942 100644
--- a/tests/common/datasets.py
+++ b/tests/common/datasets.py
@@ -109,12 +109,8 @@ def __getitem__(self, index: int):
         }]
         # bit of a hack, but it works
         mock_prompt = torch.ones((len(messages[0]['content']),)).int()
         return {
-            'messages':
-                messages,
-            'prompt':
-                mock_prompt,
-            'prompt_len':
-                torch.Tensor([len(messages[0]['content'])]).to(torch.int64),
-            'verified_answer':
-                'Paris',
+            'messages': messages,
+            'prompt': mock_prompt,
+            'prompt_len': torch.Tensor([len(messages[0]['content'])]).to(torch.int64),
+            'verified_answer': 'Paris',
         }
diff --git a/tests/common/markers.py b/tests/common/markers.py
index 1b76abd4..0ec47a9a 100644
--- a/tests/common/markers.py
+++ b/tests/common/markers.py
@@ -27,22 +27,19 @@ def device(*args: str, precision: bool = False):
 
     if precision:
         devices = {
-            'cpu':
-                pytest.param('cpu', Precision.FP32, id='cpu-fp32'),
-            'gpu':
-                pytest.param(
-                    'gpu',
-                    Precision.FP32,
-                    id='gpu-fp32',
-                    marks=pytest.mark.gpu,
-                ),
-            'gpu-amp':
-                pytest.param(
-                    'gpu',
-                    Precision.AMP_FP16,
-                    id='gpu-amp',
-                    marks=pytest.mark.gpu,
-                ),
+            'cpu': pytest.param('cpu', Precision.FP32, id='cpu-fp32'),
+            'gpu': pytest.param(
+                'gpu',
+                Precision.FP32,
+                id='gpu-fp32',
+                marks=pytest.mark.gpu,
+            ),
+            'gpu-amp': pytest.param(
+                'gpu',
+                Precision.AMP_FP16,
+                id='gpu-amp',
+                marks=pytest.mark.gpu,
+            ),
         }
         name = 'device,precision'
     else:
diff --git a/tests/functional_rewards/test_bad_generation.py b/tests/functional_rewards/test_bad_generation.py
index 4d9160bb..a06e1201 100644
--- a/tests/functional_rewards/test_bad_generation.py
+++ b/tests/functional_rewards/test_bad_generation.py
@@ -34,18 +34,14 @@ def reward() -> BadGenerationEndReward:
     [
         (
             {
-                'zero_rewards':
-                    torch.zeros((3, 5)),
-                'seq_lens':
-                    torch.tensor([5, 5, 5]),
-                'input_ids':
-                    torch.tensor([
-                        [1, 2, 3, 4, 5],
-                        [1, 2, 3, 4, 0],
-                        [1, 2, 3, 4, 50277],
-                    ]),
-                'generated_lens':
-                    torch.tensor([5, 5, 5]),
+                'zero_rewards': torch.zeros((3, 5)),
+                'seq_lens': torch.tensor([5, 5, 5]),
+                'input_ids': torch.tensor([
+                    [1, 2, 3, 4, 5],
+                    [1, 2, 3, 4, 0],
+                    [1, 2, 3, 4, 50277],
+                ]),
+                'generated_lens': torch.tensor([5, 5, 5]),
             },
             [(0, 4, -1.0), (1, 4, 0.0), (2, 4, 0.0)],
         ),
diff --git a/tests/test_offline.py b/tests/test_offline.py
index 2838700b..92b3e380 100644
--- a/tests/test_offline.py
+++ b/tests/test_offline.py
@@ -86,8 +86,7 @@ def test_load_checkpoint_with_offline_callback(
     # Remove the _load added by the mock
     comparison_dict.pop('_load')
     # The callback passed into the dummy trainer should match the original callback
-    assert mock_trainer.call_args.kwargs['callbacks'][
-        0].__dict__ == comparison_dict
+    assert mock_trainer.call_args.kwargs['callbacks'][0].__dict__ == comparison_dict
 
     load_checkpoint_callback._load.assert_called_once()
 
@@ -307,8 +306,7 @@ def test_checkpoint_reloading(
     trainer1.fit(duration='4ba')
     margins = in_memory_logger.data['loss/train/margin']
     # The first margin should be 0.0
-    assert margins[0][
-        1] == 0.0, 'The margin should be 0.0 in the first trainer fit'
+    assert margins[0][1] == 0.0, 'The margin should be 0.0 in the first trainer fit'
 
     # Restart the training from the intermediate checkpoint
     in_memory_logger = InMemoryLogger()
@@ -328,6 +326,4 @@ def test_checkpoint_reloading(
     trainer2.fit()
     margins = in_memory_logger.data['loss/train/margin']
     # After resuming the training, the first margin should not be 0.0
-    assert margins[0][
-        1
-    ] != 0.0, 'The margin should not be 0.0 in the second trainer fit after resuming'
+    assert margins[0][1] != 0.0, 'The margin should not be 0.0 in the second trainer fit after resuming'
diff --git a/tests/test_registry.py b/tests/test_registry.py
index 5de63f82..c9923954 100644
--- a/tests/test_registry.py
+++ b/tests/test_registry.py
@@ -12,8 +12,7 @@
 
 def test_expected_registries_exist():
     existing_registries = {
-        name for name in dir(registry)
-        if isinstance(getattr(registry, name), registry_utils.TypedRegistry)
+        name for name in dir(registry) if isinstance(getattr(registry, name), registry_utils.TypedRegistry)
     }
     expected_registry_names = {
         'rewards',
diff --git a/tests/test_reward_modeling.py b/tests/test_reward_modeling.py
index 054a4af9..ce3b4010 100644
--- a/tests/test_reward_modeling.py
+++ b/tests/test_reward_modeling.py
@@ -24,8 +24,7 @@
 from transformers import AutoTokenizer
 from transformers.models.llama.modeling_llama import LlamaAttention
 
-from compose_rl.algorithms.reward_modeling.hf_utils import \
-    AutoModelForCausalLMWithRM
+from compose_rl.algorithms.reward_modeling.hf_utils import AutoModelForCausalLMWithRM
 from compose_rl.data import (
     finegrained_preference_dataset_collate_fn,
     pairwise_preference_dataset_collate_fn,
@@ -146,18 +145,14 @@ def gen_random_batch(
         size=(batch_size, test_cfg.max_seq_len * 2),
         dtype=torch.int64,
     ).to(device)
-    batch['chosen_len'] = (
-        torch.ones(
-            size=(batch_size,),
-            dtype=torch.int64,
-        ) * test_cfg.max_seq_len
-    ).to(device)
-    batch['rejected_len'] = (
-        torch.ones(
-            size=(batch_size,),
-            dtype=torch.int64,
-        ) * test_cfg.max_seq_len
-    ).to(device)
+    batch['chosen_len'] = (torch.ones(
+        size=(batch_size,),
+        dtype=torch.int64,
+    ) * test_cfg.max_seq_len).to(device)
+    batch['rejected_len'] = (torch.ones(
+        size=(batch_size,),
+        dtype=torch.int64,
+    ) * test_cfg.max_seq_len).to(device)
     return batch
 
 
@@ -189,8 +184,7 @@ def test_forward_backward_hf_automodel():
         pytest.param(
             'tests/yamls/testing_hf_classifier.yaml',
             marks=pytest.mark.skip(
-                reason=
-                'TODO: reenable. temporarily skipping to turn GPU CI back on.',
+                reason='TODO: reenable. temporarily skipping to turn GPU CI back on.',
             ),
         ),
     ],
@@ -398,10 +392,8 @@ def test_flashattention2(world_size: int):
     out_flash = {k: v.to('cpu') for k, v in out_flash.items()}
 
     assert torch.all(
-        out['chosen_scores'].bfloat16()
-        != out['rejected_scores'].bfloat16(),
+        out['chosen_scores'].bfloat16() != out['rejected_scores'].bfloat16(),
     )
     assert torch.all(
-        out_flash['chosen_scores'].bfloat16()
-        != out_flash['rejected_scores'].bfloat16(),
+        out_flash['chosen_scores'].bfloat16() != out_flash['rejected_scores'].bfloat16(),
     )
diff --git a/tests/test_single_controller.py b/tests/test_single_controller.py
index 1df09c74..348cf5fd 100644
--- a/tests/test_single_controller.py
+++ b/tests/test_single_controller.py
@@ -136,9 +136,7 @@ def test_distributed_ray_actors(
     assert results == [num_train_actors] * num_train_actors
 
     vllm_tensor_parallel_size = world_size - num_train_actors
-    num_vllm_engines = (
-        world_size - num_train_actors
-    ) // vllm_tensor_parallel_size
+    num_vllm_engines = (world_size - num_train_actors) // vllm_tensor_parallel_size
     logger.info(f'num_vllm_engines: {num_vllm_engines}')
     vllm_engines = create_vllm_engines(
         num_engines=num_vllm_engines,
diff --git a/tests/test_single_controller_ppo.py b/tests/test_single_controller_ppo.py
index 401683cf..50533254 100644
--- a/tests/test_single_controller_ppo.py
+++ b/tests/test_single_controller_ppo.py
@@ -120,22 +120,14 @@ def build_train_config(self, pretrain_model_name: str):
                 'kl_estimator': 'k1',
                 'kl_clip_range': 40.0,
             },
-            'fsdp_config':
-                self.fsdp_config,
-            'seed':
-                17,
-            'precision':
-                self.precision,
-            'variables':
-                variables,
-            'max_seq_len':
-                self.max_seq_len,
-            'global_train_batch_size':
-                self.device_train_batch_size * self.world_size,
-            'device_train_batch_size':
-                self.device_train_batch_size,
-            'device_train_microbatch_size':
-                self.device_train_batch_size,
+            'fsdp_config': self.fsdp_config,
+            'seed': 17,
+            'precision': self.precision,
+            'variables': variables,
+            'max_seq_len': self.max_seq_len,
+            'global_train_batch_size': self.device_train_batch_size * self.world_size,
+            'device_train_batch_size': self.device_train_batch_size,
+            'device_train_microbatch_size': self.device_train_batch_size,
         }
 
     def build_dataloader(self):
@@ -233,8 +225,7 @@ def build_ref_model(self):
             parallelism_config={'fsdp': self.fsdp_config},
             save_folder=tmp_ref_path,
             save_weights_only=True,
-            device_train_microbatch_size=self.
-            device_train_microbatch_size,  # type: ignore
+            device_train_microbatch_size=self.device_train_microbatch_size,  # type: ignore
         )
         temp_trainer.fit()
 
@@ -421,43 +412,30 @@ class TrainActorGroup(SPMDActorGroup):
     def build_models(self, pretrain_model_name: str):
         """Build reference models and PPO trainers for all actors."""
         build_train_config_tasks = [
-            actor.build_train_config.remote(pretrain_model_name)
-            for actor in self._train_actors
+            actor.build_train_config.remote(pretrain_model_name) for actor in self._train_actors
         ]
         ray.get(build_train_config_tasks)
 
-        init_task = [
-            actor.init_composer_dist.remote() for actor in self._train_actors
-        ]
+        init_task = [actor.init_composer_dist.remote() for actor in self._train_actors]
         ray.get(init_task)
 
         # Build reference models
-        build_ref_model_tasks = [
-            actor.build_ref_model.remote() for actor in self._train_actors
-        ]
+        build_ref_model_tasks = [actor.build_ref_model.remote() for actor in self._train_actors]
         ray.get(build_ref_model_tasks)
         print('build ref model done')
 
         # Build PPO trainers
-        build_ppo_trainer_tasks = [
-            actor.build_ppo_trainer.remote() for actor in self._train_actors
-        ]
+        build_ppo_trainer_tasks = [actor.build_ppo_trainer.remote() for actor in self._train_actors]
         ray.get(build_ppo_trainer_tasks)
         print('build ppo trainer done')
 
     def update_inference_model(self, vllm_engines: list[Any]):
-        refs = [
-            actor.update_inference_model.remote(vllm_engines)
-            for actor in self._train_actors
-        ]
+        refs = [actor.update_inference_model.remote(vllm_engines) for actor in self._train_actors]
        ray.get(refs)
         print('update inference model done')
 
     def query_inference_engines(self, vllm_engines: list[Any]):
-        refs = [
-            actor.query_inference_engines.remote(vllm_engines)
-            for actor in self._train_actors
-        ]
+        refs = [actor.query_inference_engines.remote(vllm_engines) for actor in self._train_actors]
         ray.get(refs)
         print('query inference engines done')
 
@@ -550,9 +528,7 @@ def _run_single_controller_ppo(
 
     # Create vLLM engines (or inference actors)
     vllm_tensor_parallel_size = world_size - num_train_actors
-    num_vllm_engines = (
-        world_size - num_train_actors
-    ) // vllm_tensor_parallel_size
+    num_vllm_engines = (world_size - num_train_actors) // vllm_tensor_parallel_size
     # TODO: Encapsulate this into a inference server manager class
     vllm_engines = create_vllm_engines(
         num_engines=num_vllm_engines,
diff --git a/tests/test_utils.py b/tests/test_utils.py
index c913ccbc..9b64db5e 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -24,10 +24,8 @@ def test_mask_eos_basic_functionality():
 
     # right_padded_obs structure: [prompt tokens, action tokens, padding]
     right_padded_obs = torch.tensor([
-        [101, 102, 103, 104, 105, 1, 2, 3, 50, 5, 6, 7, 8, 9,
-         10],  # 5 prompt tokens + 10 action tokens
-        [201, 202, 203, 204, 205, 11, 12, 13, 14, 15, 50, 17, 18, 19,
-         20],  # 5 prompt tokens + 10 action tokens
+        [101, 102, 103, 104, 105, 1, 2, 3, 50, 5, 6, 7, 8, 9, 10],  # 5 prompt tokens + 10 action tokens
+        [201, 202, 203, 204, 205, 11, 12, 13, 14, 15, 50, 17, 18, 19, 20],  # 5 prompt tokens + 10 action tokens
     ])
     right_padded_attn_mask = torch.ones_like(right_padded_obs, dtype=torch.bool)
 
@@ -115,10 +113,8 @@ def test_mask_eos_no_eos():
 def test_mask_eos_multiple_eos_tokens():
     # Test with multiple possible EOS tokens
     actions = torch.tensor([
-        [1, 2, 3, 50, 5, 6, 7, 8, 9,
-         10],  # First sequence has EOS (50) at index 3
-        [11, 12, 13, 14, 15, 51, 17, 18, 19,
-         20],  # Second sequence has EOS (51) at index 5
+        [1, 2, 3, 50, 5, 6, 7, 8, 9, 10],  # First sequence has EOS (50) at index 3
+        [11, 12, 13, 14, 15, 51, 17, 18, 19, 20],  # Second sequence has EOS (51) at index 5
     ])
 
     # right_padded_obs includes prompt + actions
@@ -264,10 +260,8 @@ def test_mask_eos_varying_prompt_lengths():
 
     # right_padded_obs with different prompt lengths
    right_padded_obs = torch.tensor([
-        [101, 102, 103, 1, 2, 3, 50, 5, 6, 7, 8, 9, 10, 999,
-         999],  # 3 prompt tokens + 10 action tokens + padding
-        [201, 202, 203, 204, 205, 206, 207, 11, 12, 13, 14, 15, 50, 17,
-         18],  # 7 prompt tokens + 8 action tokens
+        [101, 102, 103, 1, 2, 3, 50, 5, 6, 7, 8, 9, 10, 999, 999],  # 3 prompt tokens + 10 action tokens + padding
+        [201, 202, 203, 204, 205, 206, 207, 11, 12, 13, 14, 15, 50, 17, 18],  # 7 prompt tokens + 8 action tokens
     ])
     right_padded_attn_mask = torch.ones_like(right_padded_obs, dtype=torch.bool)
 
@@ -449,8 +443,7 @@ def test_get_token_entropies_batch_variation():
     # Batch 0 should have high entropy (close to log(vocab_size))
     assert torch.allclose(
         token_entropies[0, :],
-        torch.log(torch.tensor(vocab_size, dtype=torch.float)) *
-        torch.ones(seq_len),
+        torch.log(torch.tensor(vocab_size, dtype=torch.float)) * torch.ones(seq_len),
         atol=1e-5,
     )
 
@@ -600,8 +593,7 @@ def test_get_sequence_entropies_single_item_batch():
     # Calculate expected entropy
     expected_entropy = -(
         0.5 * torch.log(torch.tensor(0.5)) + (vocab_size - 1) *
-        (0.5 /
-         (vocab_size - 1)) * torch.log(torch.tensor(0.5 / (vocab_size - 1)))
+        (0.5 / (vocab_size - 1)) * torch.log(torch.tensor(0.5 / (vocab_size - 1)))
     )
 
     token_entropies = get_token_entropies(logits)
@@ -623,8 +615,7 @@ def mock_fn(
     ):
         # For test purposes, just return a tensor with the right shape
         batch_size = batched_values.size(0)
-        gen_len = max_gen_len if isinstance(max_gen_len,
-                                            int) else max_gen_len.item()
+        gen_len = max_gen_len if isinstance(max_gen_len, int) else max_gen_len.item()
         gen_len = int(gen_len)
         vocab_size = batched_values.size(2)
         return torch.randn((batch_size, gen_len, vocab_size))
@@ -681,9 +672,7 @@ def test_get_entropies_integration():
 
     # Fill with different distributions
     # For prompt tokens (these shouldn't matter except for the last prompt token)
-    logits[:, :prompt_seq_len, :] = torch.randn(
-        (batch_size, prompt_seq_len, vocab_size),
-    )
+    logits[:, :prompt_seq_len, :] = torch.randn((batch_size, prompt_seq_len, vocab_size),)
 
     # IMPORTANT: get_batched_generated_values will extract tokens from prompt_len-1 to prompt_len+max_gen_len-1
     # This includes the last token of the prompt and excludes the last token of the generation