diff --git a/llm/alignment/ppo/README.md b/llm/alignment/ppo/README.md index 0840bbfc4868..1ddda800ce66 100644 --- a/llm/alignment/ppo/README.md +++ b/llm/alignment/ppo/README.md @@ -80,7 +80,7 @@ wget https://paddlenlp.bj.bcebos.com/datasets/examples/ppo-kk.tgz && tar zxf ppo - `top_p`: 生成解码超参数 - `temperature`: 生成解码超参数 - `repetition_penalty`: 生成解码超参数 -- `num_return_sequences`: 生成解码超参数 +- `rollout_n`: 生成解码超参数 - `min_learning_rate`: Actor 模型的最小学习率 - `critic_learning_rate`: Critic 模型的最小学习率 - `recompute`: Actor 模型是否使用重计算策略,开启后可节省训练显存 diff --git a/paddlenlp/datasets/rlhf_datasets/protocol.py b/paddlenlp/datasets/rlhf_datasets/protocol.py index ef47106c2c69..0107c8a9cdeb 100644 --- a/paddlenlp/datasets/rlhf_datasets/protocol.py +++ b/paddlenlp/datasets/rlhf_datasets/protocol.py @@ -16,15 +16,18 @@ Implement base data transfer protocol between any two functions, modules. We can subclass Protocol to define more detailed batch info with specific keys """ + import copy from dataclasses import dataclass, field -from typing import Dict, List, Union +from typing import Dict, List, Sequence, Union import numpy as np import paddle import pandas as pd from paddle.io import DataLoader +from paddle.utils import map_structure +original_concat = paddle.concat __all__ = [ "DataProto", "union_tensor_dict", @@ -49,7 +52,18 @@ def __setitem__(self, key: str, tensor: paddle.Tensor): self._tensors[key] = tensor def __getitem__(self, key): - return self._tensors[key] + if isinstance(key, str): + return self._tensors[key] + elif isinstance(key, slice): + strides = [1] if key.step is None else [key.step] + tensor_dict_slice = { + k: paddle.strided_slice(v, axes=[0], starts=[key.start], ends=[key.stop], strides=strides) + for k, v in self._tensors.items() + } + batch_size = tensor_dict_slice[list(tensor_dict_slice.keys())[0]].shape[: self.num_batch_dims] + return TensorDict(tensor_dict_slice, batch_size=batch_size, num_batch_dims=self.num_batch_dims) + else: + raise KeyError(f"Unsupported key type: {type(key)}") def keys(self): return self._tensors.keys() @@ -62,6 +76,58 @@ def to(self, device: str): self._tensors[key] = self._tensors[key].to(device) return self + @classmethod + def concat(cls, tensordict_list, axis=0): + if not tensordict_list: + raise ValueError("tensordict_list must not be empty") + + # 获取第一个 TensorDict 的键和对应的张量形状,用于验证后续 TensorDict 的一致性 + first_tensordict = tensordict_list[0] + first_keys = first_tensordict.keys() + first_shapes = {key: tensor.shape for key, tensor in first_tensordict.items()} + + # 验证所有 TensorDict 是否具有相同的键和对应的张量形状(除了拼接维度) + for tensordict in tensordict_list: + if tensordict.keys() != first_keys: + raise ValueError("All TensorDict objects must have the same keys") + for key in first_keys: + if ( + tensordict[key].shape[:axis] + tensordict[key].shape[axis + 1 :] + != first_shapes[key][:axis] + first_shapes[key][axis + 1 :] + ): + raise ValueError(f"Shapes of tensor '{key}' do not match except on concatenation axis {axis}") + + # 拼接每个键对应的张量 + concatenated_tensors = { + key: paddle.concat([tensordict[key] for tensordict in tensordict_list], axis=axis) for key in first_keys + } + + # 创建一个新的 TensorDict 对象并返回 + batch_size = concatenated_tensors[list(concatenated_tensors.keys())[0]].shape[ + : tensordict_list[0].num_batch_dims + ] + return cls(concatenated_tensors, batch_size=batch_size, num_batch_dims=tensordict_list[0].num_batch_dims) + + +def tensordict_concat( + x: Union[Sequence[paddle.Tensor], Sequence[TensorDict]], + axis: int | paddle.Tensor = 0, + name: str | None = None, +): + 
def is_tensor_sequence(): + if isinstance(x[0], paddle.Tensor): + return True + else: + return False + + if not is_tensor_sequence() and paddle.in_dynamic_mode(): + return TensorDict.concat(x) + else: + return original_concat(x, axis, name) + + +paddle.concat = tensordict_concat + def union_two_dict(dict1: Dict, dict2: Dict): """Union two dict. Will throw an error if there is an item not the same object with the same key. @@ -239,9 +305,39 @@ def __len__(self): return 0 def __getitem__(self, item): - tensor_data = self.batch[item] - non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()} - return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info) + """ + Enhanced indexing for DataProto objects. + + Args: + item: Can be one of: + - int: A single index + - slice: A slice object (start:stop:step) + - list: A list of indices + - numpy.ndarray: An array of indices + - torch.Tensor: A tensor of indices + + Returns: + DataProto: For all indexing types except single integers + DataProtoItem: Only for single integer indices + """ + # Case 1: Slice object - use the slice method + if isinstance(item, slice): + return self.slice(item.start, item.stop, item.step) + + # Case 2: List, numpy array, or torch tensor - use sel_idxs + elif isinstance(item, (list, np.ndarray, paddle.Tensor)): + return self.select_idxs(item) + + # Case 3: Single integer - return DataProtoItem for backward compatibility + elif isinstance(item, (int, np.integer)): + tensor_data = self.batch[item] + non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()} + return_type = DataProto if isinstance(item, slice) else DataProtoItem + return return_type(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info) + + # Case 4: Unsupported type + else: + raise TypeError(f"Indexing with {type(item)} is not supported") def print_size(self, prefix=""): size_of_tensordict = 0 @@ -385,6 +481,87 @@ def select(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=Non return DataProto(batch=sub_batch, non_tensor_batch=non_tensor_batch, meta_info=sub_meta_info) + def select_idxs(self, idxs): + """ + Select specific indices from the DataProto. + + Args: + idxs (torch.Tensor or numpy.ndarray or list): Indices to select + + Returns: + DataProto: A new DataProto containing only the selected indices + """ + if isinstance(idxs, list): + idxs = paddle.tensor(idxs, dtype=paddle.int32) + + if isinstance(idxs, np.ndarray): + idxs_np = idxs + idxs_paddle = paddle.from_numpy(idxs) + else: # torch.Tensor + idxs_paddle = idxs + idxs_np = idxs.numpy() + + if self.batch is not None: + # Use TensorDict's built-in indexing capabilities + selected_batch = TensorDict( + source={key: tensor[idxs_paddle] for key, tensor in self.batch.items()}, + batch_size=(idxs_paddle.shape[0],), + ) + else: + selected_batch = None + + selected_non_tensor = {} + for key, val in self.non_tensor_batch.items(): + selected_non_tensor[key] = val[idxs_np] + + return DataProto(batch=selected_batch, non_tensor_batch=selected_non_tensor, meta_info=self.meta_info) + + def slice(self, start=None, end=None, step=None): + """ + Slice the DataProto and return a new DataProto object. + This is an improved version of direct slicing which returns a DataProtoItem. + + Args: + start (int, optional): Start index. Defaults to None (start from beginning). + end (int, optional): End index (exclusive). Defaults to None (go to end). + step (int, optional): Step size. 
Defaults to None (step=1). + + Returns: + DataProto: A new DataProto containing the sliced data + + Examples: + # Using the slice method directly + sliced_data = data_proto.slice(10, 20) + + # Using enhanced indexing (returns DataProto) + sliced_data = data_proto[10:20] + sliced_data = data_proto[::2] # Every other element + + # Using list indexing (returns DataProto) + indices = [1, 5, 10] + selected_data = data_proto[indices] + + # Single index still returns DataProtoItem + single_item = data_proto[5] + """ + # Create a slice object + slice_obj = slice(start, end, step) + + # Handle the batch data + if self.batch is not None: + # Use TensorDict's built-in slicing capabilities + sliced_batch = self.batch[slice_obj] + else: + sliced_batch = None + + # Handle the non-tensor batch data + sliced_non_tensor = {} + for key, val in self.non_tensor_batch.items(): + sliced_non_tensor[key] = val[slice_obj] + + # Return a new DataProto object + return DataProto(batch=sliced_batch, non_tensor_batch=sliced_non_tensor, meta_info=self.meta_info) + def pop(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) -> "DataProto": """Pop a subset of the DataProto via `batch_keys` and `meta_info_keys` @@ -447,6 +624,7 @@ def validate_input(keys): def union(self, other: "DataProto") -> "DataProto": """Union with another DataProto. Union batch and meta_info separately. Throw an error if + - there are conflict keys in batch and they are not equal - the batch size of two data batch is not the same - there are conflict keys in meta_info and they are not the same. diff --git a/paddlenlp/rl/algos/advantage.py b/paddlenlp/rl/algos/advantage.py index 48d68b5d1cc5..b0bc1891aba1 100644 --- a/paddlenlp/rl/algos/advantage.py +++ b/paddlenlp/rl/algos/advantage.py @@ -20,6 +20,61 @@ from ..utils.comm_utils import masked_whiten +def compute_gae_advantage_return( + token_level_rewards: paddle.Tensor, + values: paddle.Tensor, + sequence_mask: paddle.Tensor, + start: int, + gamma: paddle.Tensor, + lam: paddle.Tensor, + use_tgt_len_return: bool = True, +) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Compute advantages and returns using Generalized Advantage Estimation (GAE).""" + # Modified from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py + lastgaelam = 0.0 + advantages_reversed = [] + gen_len = token_level_rewards.shape[-1] + + values = values * sequence_mask + token_level_rewards = token_level_rewards * sequence_mask + if use_tgt_len_return and start > 0: + # consistent with Beaver + # values length is src+tgt-1, start is src-1, return length is tgt + pass + elif use_tgt_len_return: + # values length is tgt, start is 0, return length is tgt + assert start == 0 + else: + # values length is src+tgt-1, start is src-1, return length is src+tgt-1 + pass + for t in reversed(range(start, gen_len)): # pylint: disable=invalid-name + next_values = values[:, t + 1] if t < gen_len - 1 else 0.0 + delta = token_level_rewards[:, t] + gamma * next_values - values[:, t] + lastgaelam = delta + gamma * lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = paddle.stack(advantages_reversed[::-1], axis=1) + + returns = advantages + values[:, start:].contiguous() + + if not use_tgt_len_return: + advantages = paddle.concat( + [ + paddle.zeros([advantages.shape[0], start], dtype=advantages.dtype), + advantages, + ], + axis=-1, + ) + returns = paddle.concat( + [ + paddle.zeros([returns.shape[0], start], dtype=returns.dtype), + returns, + ], + axis=-1, + ) + + return advantages.detach(), 
returns + + @paddle.no_grad() def compute_grpo_advantages( rewards: paddle.Tensor, @@ -101,3 +156,46 @@ def compute_reinforce_plus_plus_advantages_and_returns( advantages = masked_whiten(returns, eos_mask) advantages = advantages * eos_mask return advantages, returns + + +def add_kl_divergence_regularization( + prompt: paddle.Tensor, # size = (B, S) # pylint: disable=unused-argument + log_probs: paddle.Tensor, # size = (B, L) + ref_log_probs: paddle.Tensor, # size = (B, L) + reward_score: paddle.Tensor, # size = (B,) + sequence_mask: paddle.Tensor, # size = (B, L) + kl_coeff: float, + clip_range_score: float, +) -> paddle.Tensor: + """ + Calculate the KL divergence regularization gain and add it to the reward. + + Args: + prompt (paddle.Tensor, shape=(B, S)): The prompt of the input sequence, not used. + log_probs (paddle.Tensor, shape=(B, L)): The log probability distribution of the current predictions. + ref_log_probs (paddle.Tensor, shape=(B, L)): The log probability distribution of the baseline predictions. + reward_score (paddle.Tensor, shape=(B,)): The base reward score based on the prompt and output sequence. + sequence_mask (paddle.Tensor, shape=(B, L)): The mask of the sequence, used to determine the length of the sequence. + + Returns: + paddle.Tensor, shape=(B, L): A vector containing the KL divergence regularization gain. + """ + + kl_divergence_estimate = -kl_coeff * (log_probs - ref_log_probs) # size = (B, L) + rewards = kl_divergence_estimate # size = (B, L) + reward_clip = paddle.clip( # size = (B,) + reward_score, + min=-clip_range_score, + max=clip_range_score, + ) + # TODO(guosheng): use scatter_add/put_along_axis + index = paddle.cumsum(sequence_mask.cast(paddle.int64), axis=-1).argmax(-1, keepdim=True) + + rewards = paddle.put_along_axis( + rewards, + index, + reward_clip.unsqueeze(axis=-1), + axis=-1, + reduce="add", + ) + return rewards, kl_divergence_estimate diff --git a/paddlenlp/rl/trainer/actor_trainer.py b/paddlenlp/rl/trainer/actor_trainer.py index e37795ce9dca..b8e9a5fb1c72 100644 --- a/paddlenlp/rl/trainer/actor_trainer.py +++ b/paddlenlp/rl/trainer/actor_trainer.py @@ -16,11 +16,14 @@ import numpy as np import paddle +import paddle.distributed as dist from paddle import nn from paddle.distributed.fleet.meta_parallel import ParallelCrossEntropy from paddle.io import Dataset +from paddle.utils import map_structure from ...data import DataCollator +from ...datasets.rlhf_datasets.protocol import DataProto, TensorDict from ...generation import GenerationConfig from ...trainer.trainer import ( EvalPrediction, @@ -33,9 +36,11 @@ RLHFPPOMixedLoss, create_startend_row_indices, gather_log_probabilities, + make_position_ids_from_input_ids, ) +from ..utils.infer_utils import infer_guard from .rl_trainer import RLTrainer -from .trainer_utils import guard_set_args +from .trainer_utils import guard_set_args, process_row class ActorReferenceTrainer(RLTrainer): @@ -72,7 +77,7 @@ def __init__( self.generation_config = GenerationConfig( max_new_tokens=self.args.max_dec_len, - num_return_sequences=self.args.num_return_sequences, + rollout_n=self.args.rollout_n, temperature=self.args.temperature, top_p=self.args.top_p, top_k=0, # to disable top_k sampling, default is 50 @@ -101,28 +106,80 @@ def loss_identifier(self, inputs: Dict) -> str: """ return "actor_loss" + def truncate_batch_data(self, batch, truncate_max_len): + if len(batch) > truncate_max_len: + batch = self.tokenizer.truncate_sequences( + batch, + num_tokens_to_remove=len(batch) - truncate_max_len, + 
truncation_strategy="longest_first", + )[0] + return batch + + def pad_batch_data(self, batches, padding_strategy="longest", padding_max_len=None, pad_to_multiple_of=None): + input_ids = self.tokenizer.pad( + {"input_ids": batches}, + padding=padding_strategy, + padding_side="right", + max_length=padding_max_len, + return_attention_mask=False, + pad_to_multiple_of=pad_to_multiple_of, + )["input_ids"] + + position_ids = make_position_ids_from_input_ids(input_ids) + return input_ids, position_ids + + @paddle.no_grad() + def generate_sequences(self, prompts: DataProto, do_eval=False) -> List[Dict[str, Any]]: + cleanup_batches, indices, label_ids_batches = [], [], [] + total_batch_size = prompts.batch["input_ids"].shape[0] + per_device_rollout_batch_size = self.args.per_device_rollout_batch_size + with infer_guard(self): + generated_batches = [] + for i in range(0, total_batch_size, per_device_rollout_batch_size): + micro_batch = prompts[i : i + per_device_rollout_batch_size] + + # generate for multi batches and then disable FuseMT model + generated_batch = self._generate_sequences(micro_batch) + generated_batches.append(generated_batch) + + batch = DataProto.concat(generated_batches) + cur_batch = self.truncate_batch_data( + cleanup_batches, truncate_max_len=self._model_config.max_position_embeddings + ) + if self._model_config.sequence_parallel: + pad_to_multiple_of = self.args.tensor_parallel_degree + else: + pad_to_multiple_of = None + input_ids, position_ids = self.pad_batch_data(cur_batch, pad_to_multiple_of=pad_to_multiple_of) + prompt = batch["input_ids"] + + batch = { + "prompt": prompt, + "input_ids": input_ids, + "position_ids": position_ids, + "index": indices, + **({"label_ids": label_ids_batches} if self.args.use_rm_server else {}), + } + + self.timers and dist.get_world_size() > 1 and dist.barrier() + return cleanup_batches, indices, label_ids_batches + @paddle.no_grad() - def generate_sequences(self, prompt_only_batch: Dict, do_eval=False) -> List[Dict[str, Any]]: + def _generate_sequences(self, micro_batch: DataProto, do_eval=False) -> DataProto: """Rollout a batch of experiences.""" - input_ids = prompt_only_batch["input_ids"] - # attention_mask = prompt_only_batch["attention_mask"] + input_ids = micro_batch.batch["input_ids"] + batch_size = input_ids.shape[0] if do_eval: - train_num_return_sequences = self.args.num_return_sequences - self.args.num_return_sequences = 1 - - # position_ids = ( - # prompt_only_batch["position_ids"] - # if "position_ids" in prompt_only_batch - # else make_position_ids(attention_mask) - # ) - - if self.args.num_return_sequences > 1: - input_ids = input_ids.repeat_interleave(self.args.num_return_sequences, axis=0) - # raw_dtype = attention_mask.dtype - # attention_mask = ( - # attention_mask.cast("int32").repeat_interleave(self.args.num_return_sequences, axis=0).cast(raw_dtype) - # ) - # position_ids = position_ids.repeat_interleave(self.args.num_return_sequences, axis=0) + train_rollout_n = self.args.rollout_n + self.args.rollout_n = 1 + + if self.args.rollout_n > 1: + input_ids = input_ids.repeat_interleave(self.args.rollout_n, axis=0) + if self.args.use_rm_server: + label_ids = micro_batch.batch["label_ids"] + label_ids = label_ids.repeat_interleave(self.args.rollout_n, axis=0) + else: + label_ids = None with guard_set_args(self.model.config, {"use_fused_head_and_loss_fn": False}): sequences = self.get_model(False).generate( @@ -134,41 +191,44 @@ def generate_sequences(self, prompt_only_batch: Dict, do_eval=False) -> List[Dic do_eval=do_eval, 
)[0] - if self.args.use_rm_server: - label_ids = prompt_only_batch["label_ids"] - if self.args.num_return_sequences > 1: - label_ids = label_ids.repeat_interleave(self.args.num_return_sequences, axis=0) + indices = [] + for _ in range(batch_size): + indices.extend([str(uuid.uuid4())] * self.args.rollout_n) + indices = np.array(indices, dtype=object) + + # sequences, label_ids = self._post_process_generate_outputs(sequences, label_ids) - sequences = sequences.reshape( - [input_ids.shape[0] // self.args.num_return_sequences, self.args.num_return_sequences, -1] - ) if do_eval: - self.args.num_return_sequences = train_num_return_sequences - sequences = sequences.transpose([1, 0, 2]) - # prompt, sequence, attention_mask - return [ + self.args.rollout_n = train_rollout_n + + # prompt : [batch_size*rollout_n, seq_len] + # input_ids : [batch_size*rollout_n, seq_len] + # indexes : [batch_size*rollout_n] + batch = TensorDict( { "prompt": input_ids, - "input_ids": seq, - **({"label_ids": label_ids[idx * len(seq) : (idx + 1) * len(seq)]} if self.args.use_rm_server else {}), - "index": np.array([str(uuid.uuid4())] * len(seq), dtype=object), - # "attention_mask": make_attention_mask( - # seq, - # pad_id=self.tokenizer.pad_token_id, - # eos_id=None, - # unk_id=self.tokenizer.unk_token_id, - # causal_mask=True, - # ).cast(self._model_config.dtype), - # "sequence_mask": make_attention_mask( - # seq, - # pad_id=self.tokenizer.pad_token_id, - # eos_id=None, - # unk_id=self.tokenizer.unk_token_id, - # causal_mask=False, - # ).cast(self._model_config.dtype), - } - for idx, seq in enumerate(sequences) - ] + "input_ids": sequences, + **({"label_ids": label_ids} if self.args.use_rm_server else {}), + }, + batch_size=[batch_size * self.args.rollout_n], + ) + non_tensor_batch = { + "index": indices, + } + return DataProto(batch, non_tensor_batch) + + def _post_process_generate_outputs(self, sequences: paddle.Tensor, label_ids: paddle.Tensor) -> List[List[int]]: + output_sequences, output_label_ids = [], [] + + for sequence, label_id in zip(sequences, label_ids): + output_sequences.append( + process_row(sequence, remove_value=self.tokenizer.pad_token_id, remove_side="right") + ) + output_label_ids.append( + process_row(label_id, remove_value=self.tokenizer.pad_token_id, remove_side="left") + ) + + return output_sequences, output_label_ids @paddle.no_grad() def compute_logprob(self, input_ids: paddle.Tensor, position_ids: paddle.Tensor = None, **kwargs) -> paddle.Tensor: @@ -197,21 +257,8 @@ def compute_logprob(self, input_ids: paddle.Tensor, position_ids: paddle.Tensor """ log_probs_list = [] batch_size, sequence_length = input_ids.shape - if self.args.rollout_logprob_batch_size is None: - rollout_logprob_batch_size = batch_size - else: - if str(self.args.rollout_logprob_batch_size).lower() == "auto": - # Auto compute - if sequence_length > 4096 - 128: - rollout_logprob_batch_size = 2 - elif sequence_length > 2048 - 128: - rollout_logprob_batch_size = 4 - else: - rollout_logprob_batch_size = batch_size - else: - rollout_logprob_batch_size = int(self.args.rollout_logprob_batch_size) - - num_batches = (batch_size + rollout_logprob_batch_size - 1) // rollout_logprob_batch_size + per_device_logprob_batch_size = self.args.per_device_logprob_batch_size + num_batches = (batch_size + per_device_logprob_batch_size - 1) // per_device_logprob_batch_size # Pipe model outputs a logits tensor with LMHead, while non-pipe model # outputs a tuple with logits tensor as the only one element. 
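For context between the two `compute_logprob` hunks above: the new `per_device_logprob_batch_size` path is a plain ceil-division micro-batching loop. A minimal standalone sketch of that pattern follows; `batched_apply`, `fn`, and the toy tensor are illustrative stand-ins and not part of this PR.

```python
import paddle


def batched_apply(fn, input_ids, micro_batch_size):
    """Apply `fn` to `input_ids` in micro-batches of `micro_batch_size` rows
    (ceil division) and concatenate the per-batch results along axis 0."""
    batch_size = input_ids.shape[0]
    num_batches = (batch_size + micro_batch_size - 1) // micro_batch_size
    outputs = []
    for i in range(num_batches):
        start = i * micro_batch_size
        end = min(start + micro_batch_size, batch_size)
        outputs.append(fn(input_ids[start:end]))
    return paddle.concat(outputs, axis=0)


# Toy usage: the "log probs" here are just a per-row sum, to show the shapes.
fake_scores = paddle.rand([10, 16])  # 10 sequences, 16 scores each
result = batched_apply(lambda x: x.sum(axis=-1), fake_scores, micro_batch_size=4)
print(result.shape)  # [10], assembled from micro-batches of 4, 4, 2 rows
```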
@@ -220,8 +267,8 @@ def compute_logprob(self, input_ids: paddle.Tensor, position_ids: paddle.Tensor for i in range(num_batches): # Calculate the start and end indices for the current batch - start_index = i * rollout_logprob_batch_size - end_index = min(start_index + rollout_logprob_batch_size, batch_size) + start_index = i * per_device_logprob_batch_size + end_index = min(start_index + per_device_logprob_batch_size, batch_size) # Extract the current batch current_input_ids = input_ids[start_index:end_index] diff --git a/paddlenlp/rl/trainer/ppo_trainer.py b/paddlenlp/rl/trainer/ppo_trainer.py index 3a2f09ab9012..7c9fae78a83d 100644 --- a/paddlenlp/rl/trainer/ppo_trainer.py +++ b/paddlenlp/rl/trainer/ppo_trainer.py @@ -28,11 +28,11 @@ from paddle.distributed import fleet from paddle.distributed.fleet.meta_parallel import PipelineLayer from paddle.io import DataLoader, Dataset, DistributedBatchSampler -from paddle.utils import map_structure from rich.console import Console from rich.table import Table from ...data import DataCollator +from ...datasets.rlhf_datasets.protocol import DataProto from ...trainer.trainer import ( EvalLoopOutput, EvalPrediction, @@ -55,6 +55,8 @@ from ...transformers.model_utils import _add_variant from ...utils.env import PADDLE_WEIGHTS_NAME from ..algos.advantage import ( + add_kl_divergence_regularization, + compute_gae_advantage_return, compute_grpo_advantages, compute_reinforce_plus_plus_advantages_and_returns, ) @@ -84,7 +86,6 @@ batch_retokenize, guard_set_args, is_same_tokenizer, - process_row, ) @@ -1010,11 +1011,7 @@ def init_train_num( if not self._is_iterable_dataset(self.train_dataset): len_dataloader = len(train_dataloader) num_train_sub_steps = ( - len_dataloader - * self.args.update_iters - * self.args.per_device_prompt_batch_size - * self.args.num_return_sequences - // self.args.per_device_train_batch_size + len_dataloader * self.args.update_iters * self.args.global_batch_size * self.args.rollout_n ) num_update_steps_per_epoch = num_train_sub_steps // args.gradient_accumulation_steps num_examples = len(self.train_dataset) @@ -1078,59 +1075,6 @@ def get_step_loss(self, loss_prefix: str = "") -> Dict: rl_loss.update(value_loss) return rl_loss - def remove_pad_tokens_after_generate(self, generated_batches): - cleanup_batches, indices, label_ids_batches = [], [], [] - - for batch in generated_batches: - cleanup_batches.extend( - [ - process_row( - row, - remove_value=self.tokenizer.pad_token_id, - remove_side="right", - eos_token_id=self.tokenizer.eos_token_id, - ) - for row in batch["input_ids"] - ] - ) - if self.args.use_rm_server: - label_ids_batches.extend( - [ - process_row( - row, - remove_value=self.tokenizer.pad_token_id, - remove_side="left", - eos_token_id=self.tokenizer.eos_token_id, - ) - for row in batch["label_ids"] - ] - ) - indices.append(batch["index"]) - - return cleanup_batches, indices, label_ids_batches - - def truncate_batch_data(self, batch, truncate_max_len): - if len(batch) > truncate_max_len: - batch = self.tokenizer.truncate_sequences( - batch, - num_tokens_to_remove=len(batch) - truncate_max_len, - truncation_strategy="longest_first", - )[0] - return batch - - def pad_batch_data(self, batches, padding_strategy="longest", padding_max_len=None, pad_to_multiple_of=None): - input_ids = self.tokenizer.pad( - {"input_ids": batches}, - padding=padding_strategy, - padding_side="right", - max_length=padding_max_len, - return_attention_mask=False, - pad_to_multiple_of=pad_to_multiple_of, - )["input_ids"] - - position_ids = 
make_position_ids_from_input_ids(input_ids) - return input_ids, position_ids - def distribute_gather_and_pad_data(self, micro_batches): # group index for grpo index = [micro_batch["index"] for micro_batch in micro_batches] @@ -1149,6 +1093,7 @@ def distribute_gather_and_pad_data(self, micro_batches): dp_group = hcg.get_data_parallel_group() except AttributeError: pass + new_batch = { "index": gather_and_pad(index, dp_group, sd_group, pad=False), "rewards": gather_and_pad(rewards, dp_group, sd_group, pad=False), @@ -1277,7 +1222,7 @@ def train( with ( guard_set_args( args, - {"per_device_train_batch_size": self.args.per_device_prompt_batch_size}, + {"per_device_train_batch_size": self.args.global_batch_size // self.args.dataset_world_size}, ), guard_set_args( self, @@ -1358,6 +1303,7 @@ def train( step = -1 for prompt_only_batch in self.prompt_only_dataloader: + batch: DataProto = DataProto.from_single_dict(prompt_only_batch) self.control = self.callback_handler.on_step_begin(args, self.state, self.control) # step 1-1: rollout data with actor model (eval) and reward model self.set_eval() @@ -1639,8 +1585,7 @@ def train( with reload_and_offload_scope(self, self.actor_model, self.actor_trainer.optimizer): with TimerScope(self.timers, ActorStages.RL_STEP): # timer_info = {} # prepare for each micro_step - - for micro_step, rl_batch in enumerate(train_batch): + for micro_step, rl_batch in enumerate(train_batch * self.args.update_iters): step = 0 if step == -1 else step with TimerScopeManualLabel( self.timers, @@ -1809,47 +1754,6 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval, with guard_set_args(self.control, {"should_log": False}): super()._maybe_log_save_evaluate(tr_loss, model, epoch, ignore_keys_for_eval) - def add_kl_divergence_regularization( - self, - prompt: paddle.Tensor, # size = (B, S) # pylint: disable=unused-argument - log_probs: paddle.Tensor, # size = (B, L) - ref_log_probs: paddle.Tensor, # size = (B, L) - reward_score: paddle.Tensor, # size = (B,) - sequence_mask: paddle.Tensor, # size = (B, L) - ) -> paddle.Tensor: - """ - Calculate the KL divergence regularization gain and add it to the reward. - - Args: - prompt (paddle.Tensor, shape=(B, S)): The prompt of the input sequence, not used. - log_probs (paddle.Tensor, shape=(B, L)): The log probability distribution of the current predictions. - ref_log_probs (paddle.Tensor, shape=(B, L)): The log probability distribution of the baseline predictions. - reward_score (paddle.Tensor, shape=(B,)): The base reward score based on the prompt and output sequence. - sequence_mask (paddle.Tensor, shape=(B, L)): The mask of the sequence, used to determine the length of the sequence. - - Returns: - paddle.Tensor, shape=(B, L): A vector containing the KL divergence regularization gain. 
- """ - - kl_divergence_estimate = -self.kl_coeff * (log_probs - ref_log_probs) # size = (B, L) - rewards = kl_divergence_estimate # size = (B, L) - reward_clip = paddle.clip( # size = (B,) - reward_score, - min=-self.clip_range_score, - max=self.clip_range_score, - ) - # TODO(guosheng): use scatter_add/put_along_axis - index = paddle.cumsum(sequence_mask.cast(paddle.int64), axis=-1).argmax(-1, keepdim=True) - - rewards = paddle.put_along_axis( - rewards, - index, - reward_clip.unsqueeze(axis=-1), - axis=-1, - reduce="add", - ) - return rewards, kl_divergence_estimate - def get_advantages_and_returns( self, values: paddle.Tensor, @@ -1966,29 +1870,35 @@ def compute_advantage(self, rl_batches, use_tgt_len_value): elif self.args.rl_algorithm == "ppo": start = rl_batch["prompt"].shape[-1] - 1 eos_mask = (rl_batch["input_ids"] != self.tokenizer.pad_token_id)[:, 1:].to(old_log_probs.dtype) - rewards_with_kl, kl_rewards = self.add_kl_divergence_regularization( + rewards_with_kl, kl_rewards = add_kl_divergence_regularization( None, # prompt, old_log_probs, ref_log_probs, rewards, eos_mask[:, start:], + self.kl_coeff, + self.clip_range_score, ) # length: tgt if use_tgt_len_value src + tgt -1 - reward_advantages, reward_returns = self.get_advantages_and_returns( - old_reward_values, + reward_advantages, reward_returns = compute_gae_advantage_return( rewards_with_kl, + old_reward_values, eos_mask[:, start:], start=0 if use_tgt_len_value else start, + gamma=self.gamma, + lam=self.gae_lambda, use_tgt_len_return=use_tgt_len_value, ) # length: tgt if use_tgt_len_value src + tgt -1 elif self.args.rl_algorithm == "reinforce_plus_plus": start = 0 eos_mask = rl_batch["eos_mask"] - rewards_with_kl, kl_rewards = self.add_kl_divergence_regularization( + rewards_with_kl, kl_rewards = add_kl_divergence_regularization( None, # prompt, old_log_probs, ref_log_probs, rewards, eos_mask[:, start:], + self.kl_coeff, + self.clip_range_score, ) # length: tgt if use_tgt_len_value src + tgt -1 reward_advantages, reward_returns = compute_reinforce_plus_plus_advantages_and_returns( rewards_with_kl, diff --git a/paddlenlp/rl/trainer/reward_trainer.py b/paddlenlp/rl/trainer/reward_trainer.py index 7704d0c9ad16..4875d39123f3 100644 --- a/paddlenlp/rl/trainer/reward_trainer.py +++ b/paddlenlp/rl/trainer/reward_trainer.py @@ -85,45 +85,53 @@ def compute_reward( label_ids: paddle.Tensor = None, **kwargs, ) -> Dict[str, paddle.Tensor]: - if not self.args.use_rm_server: - if self.tokenizer is not input_ids_tokenizer: - # right padding - reward_tokenize_output = batch_retokenize( - input_ids, - src_tokenizer=input_ids_tokenizer, - dest_tokenizer=self.tokenizer, + pre_device_reward_batch_size = self.args.per_device_eval_batch_size + reward_scores = [] + + for i in range(0, input_ids.shape[0], pre_device_reward_batch_size): + cur_input_ids = input_ids[i : i + pre_device_reward_batch_size] + cur_position_ids = position_ids[i : i + pre_device_reward_batch_size] + cur_label_ids = label_ids[i : i + pre_device_reward_batch_size] + + if not self.args.use_rm_server: + if self.tokenizer is not input_ids_tokenizer: + # right padding + reward_tokenize_output = batch_retokenize( + cur_input_ids, + src_tokenizer=input_ids_tokenizer, + dest_tokenizer=self.tokenizer, + ) + reward_input_ids = reward_tokenize_output["input_ids"] + reward_position_ids = reward_tokenize_output["position_ids"] + else: + reward_input_ids = cur_input_ids + reward_position_ids = cur_position_ids + + attn_mask_startend_row_indices = create_startend_row_indices( + 
reward_input_ids, self.tokenizer.pad_token_id ) - reward_input_ids = reward_tokenize_output["input_ids"] - reward_position_ids = reward_tokenize_output["position_ids"] + reward_score = self.model( + reward_input_ids, + attention_mask=None, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + position_ids=reward_position_ids, + )[1] else: - reward_input_ids = input_ids - reward_position_ids = position_ids - - attn_mask_startend_row_indices = create_startend_row_indices(reward_input_ids, self.tokenizer.pad_token_id) - reward_score = self.model( - reward_input_ids, - attention_mask=None, - attn_mask_startend_row_indices=attn_mask_startend_row_indices, - position_ids=reward_position_ids, - )[1] - else: - prompt_len = kwargs["prompt"].shape[-1] - if label_ids is None: - raise ValueError("Rule-based reward needs labels.") - src = input_ids_tokenizer.batch_decode(input_ids[:, :prompt_len], skip_special_tokens=False) - tgt = input_ids_tokenizer.batch_decode(label_ids, skip_special_tokens=False) - response = input_ids_tokenizer.batch_decode(input_ids[:, prompt_len:], skip_special_tokens=False) - reward_score = self.request_reward_server( - [i.replace(self.tokenizer.pad_token, "") for i in src], - [i.replace(self.tokenizer.pad_token, "") for i in tgt], - [i.replace(self.tokenizer.pad_token, "") for i in response], - ) - - reward_score = reward_score.squeeze(axis=-1) + prompt_len = kwargs["prompt"].shape[-1] + if cur_label_ids is None: + raise ValueError("Rule-based reward needs labels.") + src = input_ids_tokenizer.batch_decode(cur_input_ids[:, :prompt_len], skip_special_tokens=False) + tgt = input_ids_tokenizer.batch_decode(cur_label_ids, skip_special_tokens=False) + response = input_ids_tokenizer.batch_decode(cur_input_ids[:, prompt_len:], skip_special_tokens=False) + reward_score = self.request_reward_server( + [i.replace(self.tokenizer.pad_token, "") for i in src], + [i.replace(self.tokenizer.pad_token, "") for i in tgt], + [i.replace(self.tokenizer.pad_token, "") for i in response], + ) - return reward_score - # if self.args.rl_algorithm in ["grpo", "reinforce_plus_plus"]: - # return {"rewards": reward_score} + reward_score = reward_score.squeeze(axis=-1) + reward_scores.append(reward_score) + return paddle.concat(reward_scores, axis=0) def request_reward_server(self, src, tgt, response): data = {"src": src, "tgt": tgt, "response": response} diff --git a/paddlenlp/rl/utils/config_utils.py b/paddlenlp/rl/utils/config_utils.py index 08fb229b0a8b..3ae2d57da1f3 100644 --- a/paddlenlp/rl/utils/config_utils.py +++ b/paddlenlp/rl/utils/config_utils.py @@ -15,7 +15,6 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Optional import paddle @@ -25,9 +24,37 @@ @dataclass class TrainingArguments(TrainingArguments): - rollout_logprob_batch_size: str = field( - default=None, - metadata={"help": "The log prob batch size."}, + # rollout_logprob_batch_size: str = field( + # default=None, + # metadata={"help": "The log prob batch size."}, + # ) + global_batch_size: int = field( + default=8, + metadata={"help": "Global batch size for input prompt."}, + ) + mini_batch_size: int = field( + default=-1, + metadata={"help": "Mini-batch size (global) for the training dataloader."}, + ) + per_device_train_batch_size: int = field( + default=1, + metadata={"help": "Batch size (per device) for the training dataloader."}, + ) + per_device_rollout_batch_size: int = field( + default=-1, + metadata={"help": "Batch size (per device) for the training dataloader."}, + 
) + per_device_logprob_batch_size: int = field( + default=-1, + metadata={"help": "Batch size (per device) for the training dataloader."}, + ) + per_device_reward_batch_size: int = field( + default=-1, + metadata={"help": "Batch size (per device) for the training dataloader."}, + ) + per_device_value_batch_size: int = field( + default=-1, + metadata={"help": "Batch size (per device) for the training dataloader."}, ) use_fused_rms_norm: bool = field( default=False, @@ -99,10 +126,6 @@ class TrainingArguments(TrainingArguments): "clip_range_value, value_estimate + clip_range_value] during training." }, ) - ptx_coeff: float = field( - default=0.0, - metadata={"help": "The coefficient for the ptx loss."}, - ) update_iters: int = field( default=1, metadata={"help": "The number of repeated updates on a generated batch."}, @@ -146,7 +169,7 @@ class TrainingArguments(TrainingArguments): "with probabilities that add up to`top_p` or higher are kept for generation." }, ) - num_return_sequences: int = field( + rollout_n: int = field( default=1, metadata={"help": "The number of independently computed returned sequences for each element in the batch."}, ) @@ -228,28 +251,6 @@ class TrainingArguments(TrainingArguments): "will use the min_learning_rate." }, ) - unified_checkpoint: bool = field( - default=True, - metadata={ - "help": "Enable fused linear grad add strategy, which will reduce elementwise " - "add for grad accumulation in the backward of nn.Linear ." - }, - ) - unified_checkpoint_config: Optional[str] = field( - default="", - metadata={ - "help": ( - "Configs to unify hybrid parallel checkpoint.\n" - "Following options are supports:\n" - "- skip_save_model_weight: do not save model weights when the masters weight exist\n" - "- master_weight_compatible: 1. if the master weights exist, only load when needed\n" - " 2. if master weights does not exist, convert model weights" - " to master weights when needed\n" - "- async_save: enable asynchronous saving checkpoints to disk\n" - "- enable_all_options: enable all optimization configurations\n" - ) - }, - ) autotuner_benchmark: bool = field( default=False, metadata={"help": "Whether to run benchmark by autotuner. True for from_scratch."}, @@ -276,10 +277,6 @@ class TrainingArguments(TrainingArguments): default=True, metadata={"help": "use tensor_parallel_output."}, ) - per_device_rollout_batch_size: int = field( - default=-1, - metadata={"help": "Batch size per GPU core/CPU for rollout."}, - ) # save_generation_output: bool = field( # default=False, # metadata={"help": "Whether to save generated text to file when eval"}, @@ -331,6 +328,30 @@ def __post_init__(self): Raises: None. 
""" + # obtain the parallrl degree from the training arguments + # for auto config the accumulation steps + self._post_init_parallel_degree() + + if self.mini_batch_size < 0: + self.mini_batch_size = self.global_batch_size + + if self.per_device_rollout_batch_size < 0: + self.per_device_train_batch_size = self.per_device_train_batch_size + if self.per_device_logprob_batch_size < 0: + self.per_device_logprob_batch_size = self.per_device_train_batch_size + if self.per_device_reward_batch_size < 0: + self.per_device_reward_batch_size = self.per_device_train_batch_size + if self.per_device_value_batch_size < 0: + self.per_device_value_batch_size = self.per_device_train_batch_size + + self.gradient_accumulation_steps = ( + self.mini_batch_size + * self.rollout_n + * self.update_iters + // self.per_device_train_batch_size + // self.dataset_world_size + ) + super().__post_init__() if self.autotuner_benchmark: self.num_train_epochs = 1 @@ -354,8 +375,6 @@ def __post_init__(self): paddle.set_device(self.device) - if self.per_device_rollout_batch_size < 0: - self.per_device_rollout_batch_size = self.per_device_train_batch_size assert self.rl_algorithm in [ "ppo", "grpo", @@ -365,14 +384,14 @@ def __post_init__(self): self.normalize_reward = False self.normalize_advantage = False - if self.per_device_eval_batch_size > self.per_device_rollout_batch_size * self.num_return_sequences: + if self.per_device_eval_batch_size > self.per_device_rollout_batch_size * self.rollout_n: logger.warning( f"per_device_eval_batch_size: {self.per_device_eval_batch_size} is larger than " - f"per_device_rollout_batch_size: {self.per_device_rollout_batch_size} * num_return_sequences: " - f"{self.num_return_sequences}, which may cause infer error. " - f"We will set it to per_device_rollout_batch_size * num_return_sequences!" + f"per_device_rollout_batch_size: {self.per_device_rollout_batch_size} * rollout_n: " + f"{self.rollout_n}, which may cause infer error. " + f"We will set it to per_device_rollout_batch_size * rollout_n!" 
) - self.per_device_eval_batch_size = self.per_device_rollout_batch_size * self.num_return_sequences + self.per_device_eval_batch_size = self.per_device_rollout_batch_size * self.rollout_n self.offload_level = self.offload_level.split() diff --git a/paddlenlp/rl/utils/infer_utils.py b/paddlenlp/rl/utils/infer_utils.py index 64a97e4773a4..39b2dc29050c 100644 --- a/paddlenlp/rl/utils/infer_utils.py +++ b/paddlenlp/rl/utils/infer_utils.py @@ -185,7 +185,7 @@ def create_predictor(trainer: Trainer): min_length=trainer.args.min_dec_len, max_length=trainer.args.max_dec_len, total_max_length=trainer.args.max_src_len + trainer.args.max_dec_len, - batch_size=trainer.args.per_device_rollout_batch_size * trainer.args.num_return_sequences, + batch_size=trainer.args.per_device_rollout_batch_size * trainer.args.rollout_n, top_p=trainer.args.top_p, temperature=trainer.args.temperature, repetition_penalty=trainer.args.repetition_penalty, diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 1201e50a2044..c957b9bded48 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -1059,6 +1059,7 @@ class TrainingArguments: ) def __post_init__(self): + world_size = paddle.distributed.get_world_size() if in_auto_parallel_align_mode(): self.max_grad_norm = 0.0 os.environ["FLAGS_max_inplace_grad_add"] = "65536" @@ -1155,118 +1156,7 @@ def __post_init__(self): if self.optim == OptimizerNames.ADAMW_MINI and self.tensor_parallel_degree > 1: raise ValueError("AdamW Mini currently doesn't support tensor parallelism.") - self.use_hybrid_parallel = False - - if isinstance(self.sharding, bool): - self.sharding = "stage1" if self.sharding else "" - if isinstance(self.sharding, str): - self.sharding = [ShardingOption(s) for s in self.sharding.split()] - if self.sharding == [ShardingOption.OFFLOAD]: - raise ValueError( - "`--sharding offload` can't work on its own. It needs to be added to `--sharding stage2` or " - '`--sharding stage3`. For example, `--sharding "stage2 offload"`.' 
- ) - elif len(self.sharding) > (ShardingOption.OFFLOAD in self.sharding) + 1: - raise ValueError("`--sharding` recived too many arguments.") - - if self.sharding_degree > 0: - warnings.warn("`sharding_degree` is deprecated, please use `sharding_parallel_degree`") - self.sharding_parallel_degree = max(self.sharding_degree, self.sharding_parallel_degree) - self.data_parallel_degree = 1 - - delattr(self, "sharding_degree") - - if len(self.sharding) == 0 and self.sharding_parallel_degree > 0: - warnings.warn("`--sharding_parallel_degree` is useful only when `--sharding` is specified.") - - world_size = paddle.distributed.get_world_size() - - if world_size > 1: - tensor_parallel_degree = max(self.tensor_parallel_degree, 1) - sep_parallel_degree = max(self.sep_parallel_degree, 1) - context_parallel_degree = max(self.context_parallel_degree, 1) - pipeline_parallel_degree = max(self.pipeline_parallel_degree, 1) - expert_parallel_degree = max(self.expert_parallel_degree, 1) - expert_tensor_parallel_degree = max(self.expert_tensor_parallel_degree, 1) - - # TODO(@gexiao): support expert_tensor_parallel_degree > 1 in the future - assert ( - expert_tensor_parallel_degree == 1 - ), f"Currently only support expert_tensor_parallel_degree=1, but got expert_tensor_parallel_degree of {expert_tensor_parallel_degree}" - - assert ( - world_size % (self.tensor_parallel_degree * self.pipeline_parallel_degree) == 0 - ), f"Total world_size:{world_size} shoule be devided by tensor_parallel_degree: {self.tensor_parallel_degree} and pipeline_parallel_degree: {self.pipeline_parallel_degree}." - - assert not ( - sep_parallel_degree > 1 and context_parallel_degree > 1 - ), f"sep parallel and context parallel cannot be used together, sep_parallel_degree:{sep_parallel_degree}, context_parallel_degree:{context_parallel_degree}." - - if self.sharding_parallel_degree == -1: - if len(self.sharding) > 0: - self.sharding_parallel_degree = world_size // ( - tensor_parallel_degree - * sep_parallel_degree - * context_parallel_degree - * pipeline_parallel_degree - ) - - sharding_parallel_degree = max(self.sharding_parallel_degree, 1) - if sharding_parallel_degree == 1 and len(self.sharding) > 0: - logger.warning("sharding_parallel_degree=1 means no sharding, please set sharding to empty!") - self.sharding = [] - - if sharding_parallel_degree > 1: - assert ( - sharding_parallel_degree % expert_parallel_degree == 0 - ), f"sharding_parallel_degree should be divided by expert_parallel_degree, current sharding_parallel_degree: {sharding_parallel_degree}, expert_parallel_degree: {expert_parallel_degree}." - - self.data_parallel_degree = world_size // ( - sharding_parallel_degree - * tensor_parallel_degree - * sep_parallel_degree - * context_parallel_degree - * pipeline_parallel_degree - ) - - assert not ( - self.data_parallel_degree > 1 and expert_parallel_degree > 1 - ), f"Currently only support use expert_data_parallel strategy together with sharding_parallel strategy, but not with data_parallel strategy. Currently data_parallel_degree is {self.data_parallel_degree}." 
- - if ( - sharding_parallel_degree > 1 - or tensor_parallel_degree > 1 - or pipeline_parallel_degree > 1 - or self.sep_parallel_degree > 1 - or self.context_parallel_degree > 1 - or expert_parallel_degree > 1 - or expert_tensor_parallel_degree > 1 - ): - self.use_hybrid_parallel = True - self.sharding_parallel_degree = sharding_parallel_degree - self.tensor_parallel_degree = tensor_parallel_degree - self.pipeline_parallel_degree = pipeline_parallel_degree - self.sep_parallel_degree = sep_parallel_degree - self.context_parallel_degree = context_parallel_degree - self.expert_parallel_degree = expert_parallel_degree - self.expert_tensor_parallel_degree = expert_tensor_parallel_degree - - if not self.use_hybrid_parallel: - self.sharding = [] - self.sharding_parallel_degree = -1 - self.tensor_parallel_degree = -1 - self.pipeline_parallel_degree = -1 - self.sep_parallel_degree = -1 - self.context_parallel_degree = -1 - self.expert_parallel_degree = -1 - self.expert_tensor_parallel_degree = -1 - - if self.hybrid_parallel_topo_order is None: - self.hybrid_parallel_topo_order = "pp_first" - assert self.hybrid_parallel_topo_order in ["pp_first", "sharding_first"] - - if self.use_hybrid_parallel and self.enable_auto_parallel: - self.use_hybrid_parallel = False + self._post_init_parallel_degree() if self.to_static: assert world_size == 1 or self.enable_auto_parallel, ( @@ -2014,6 +1904,120 @@ def is_segment_parallel_supported(): self.zcc_workers_num == 1 ), "EMA function in zero cost checkpoint mode does not support zcc_workers_num > 1 for now." + def _post_init_parallel_degree(self): + self.use_hybrid_parallel = False + + if isinstance(self.sharding, bool): + self.sharding = "stage1" if self.sharding else "" + if isinstance(self.sharding, str): + self.sharding = [ShardingOption(s) for s in self.sharding.split()] + if self.sharding == [ShardingOption.OFFLOAD]: + raise ValueError( + "`--sharding offload` can't work on its own. It needs to be added to `--sharding stage2` or " + '`--sharding stage3`. For example, `--sharding "stage2 offload"`.' 
+ ) + elif len(self.sharding) > (ShardingOption.OFFLOAD in self.sharding) + 1: + raise ValueError("`--sharding` recived too many arguments.") + + if self.sharding_degree > 0: + warnings.warn("`sharding_degree` is deprecated, please use `sharding_parallel_degree`") + self.sharding_parallel_degree = max(self.sharding_degree, self.sharding_parallel_degree) + self.data_parallel_degree = 1 + + delattr(self, "sharding_degree") + + if len(self.sharding) == 0 and self.sharding_parallel_degree > 0: + warnings.warn("`--sharding_parallel_degree` is useful only when `--sharding` is specified.") + + world_size = paddle.distributed.get_world_size() + + if world_size > 1: + tensor_parallel_degree = max(self.tensor_parallel_degree, 1) + sep_parallel_degree = max(self.sep_parallel_degree, 1) + context_parallel_degree = max(self.context_parallel_degree, 1) + pipeline_parallel_degree = max(self.pipeline_parallel_degree, 1) + expert_parallel_degree = max(self.expert_parallel_degree, 1) + expert_tensor_parallel_degree = max(self.expert_tensor_parallel_degree, 1) + + # TODO(@gexiao): support expert_tensor_parallel_degree > 1 in the future + assert ( + expert_tensor_parallel_degree == 1 + ), f"Currently only support expert_tensor_parallel_degree=1, but got expert_tensor_parallel_degree of {expert_tensor_parallel_degree}" + + assert ( + world_size % (self.tensor_parallel_degree * self.pipeline_parallel_degree) == 0 + ), f"Total world_size:{world_size} shoule be devided by tensor_parallel_degree: {self.tensor_parallel_degree} and pipeline_parallel_degree: {self.pipeline_parallel_degree}." + + assert not ( + sep_parallel_degree > 1 and context_parallel_degree > 1 + ), f"sep parallel and context parallel cannot be used together, sep_parallel_degree:{sep_parallel_degree}, context_parallel_degree:{context_parallel_degree}." + + if self.sharding_parallel_degree == -1: + if len(self.sharding) > 0: + self.sharding_parallel_degree = world_size // ( + tensor_parallel_degree + * sep_parallel_degree + * context_parallel_degree + * pipeline_parallel_degree + ) + + sharding_parallel_degree = max(self.sharding_parallel_degree, 1) + if sharding_parallel_degree == 1 and len(self.sharding) > 0: + logger.warning("sharding_parallel_degree=1 means no sharding, please set sharding to empty!") + self.sharding = [] + + if sharding_parallel_degree > 1: + assert ( + sharding_parallel_degree % expert_parallel_degree == 0 + ), f"sharding_parallel_degree should be divided by expert_parallel_degree, current sharding_parallel_degree: {sharding_parallel_degree}, expert_parallel_degree: {expert_parallel_degree}." + + self.data_parallel_degree = world_size // ( + sharding_parallel_degree + * tensor_parallel_degree + * sep_parallel_degree + * context_parallel_degree + * pipeline_parallel_degree + ) + + assert not ( + self.data_parallel_degree > 1 and expert_parallel_degree > 1 + ), f"Currently only support use expert_data_parallel strategy together with sharding_parallel strategy, but not with data_parallel strategy. Currently data_parallel_degree is {self.data_parallel_degree}." 
+ + if ( + sharding_parallel_degree > 1 + or tensor_parallel_degree > 1 + or pipeline_parallel_degree > 1 + or self.sep_parallel_degree > 1 + or self.context_parallel_degree > 1 + or expert_parallel_degree > 1 + or expert_tensor_parallel_degree > 1 + ): + self.use_hybrid_parallel = True + self.sharding_parallel_degree = sharding_parallel_degree + self.tensor_parallel_degree = tensor_parallel_degree + self.pipeline_parallel_degree = pipeline_parallel_degree + self.sep_parallel_degree = sep_parallel_degree + self.context_parallel_degree = context_parallel_degree + self.expert_parallel_degree = expert_parallel_degree + self.expert_tensor_parallel_degree = expert_tensor_parallel_degree + + if not self.use_hybrid_parallel: + self.sharding = [] + self.sharding_parallel_degree = -1 + self.tensor_parallel_degree = -1 + self.pipeline_parallel_degree = -1 + self.sep_parallel_degree = -1 + self.context_parallel_degree = -1 + self.expert_parallel_degree = -1 + self.expert_tensor_parallel_degree = -1 + + if self.hybrid_parallel_topo_order is None: + self.hybrid_parallel_topo_order = "pp_first" + assert self.hybrid_parallel_topo_order in ["pp_first", "sharding_first"] + + if self.use_hybrid_parallel and self.enable_auto_parallel: + self.use_hybrid_parallel = False + def add_moe_comm_group(self): hcg = fleet.get_hybrid_communicate_group() topo = hcg._topo
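As a worked illustration of how the refactored `_post_init_parallel_degree` and the new batch-size fields in `paddlenlp/rl/utils/config_utils.py` fit together (all numeric values below are hypothetical, chosen for the example and not taken from this PR): `data_parallel_degree` is derived from `world_size` and the other parallel degrees, and `gradient_accumulation_steps` from `mini_batch_size`, `rollout_n`, `update_iters`, `per_device_train_batch_size`, and `dataset_world_size`.

```python
# Illustrative numbers only; a real run takes these from TrainingArguments.
world_size = 16
tensor_parallel_degree = 4
pipeline_parallel_degree = 2
sep_parallel_degree = 1
context_parallel_degree = 1
sharding_parallel_degree = 2

# data_parallel_degree as derived in _post_init_parallel_degree
data_parallel_degree = world_size // (
    sharding_parallel_degree
    * tensor_parallel_degree
    * sep_parallel_degree
    * context_parallel_degree
    * pipeline_parallel_degree
)
assert data_parallel_degree == 1

# gradient_accumulation_steps as computed in the RL config __post_init__
mini_batch_size = 8          # falls back to global_batch_size when left at -1
rollout_n = 4                # responses sampled per prompt
update_iters = 1
per_device_train_batch_size = 2
dataset_world_size = 2       # number of ranks consuming distinct data

gradient_accumulation_steps = (
    mini_batch_size * rollout_n * update_iters
    // per_device_train_batch_size
    // dataset_world_size
)
assert gradient_accumulation_steps == 8
```

In the same spirit, each `per_device_*_batch_size` field that is left at its `-1` default is resolved to `per_device_train_batch_size` during `__post_init__`, so only the knobs that actually differ need to be set explicitly.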