2 changes: 1 addition & 1 deletion llm/alignment/ppo/README.md
@@ -80,7 +80,7 @@ wget https://paddlenlp.bj.bcebos.com/datasets/examples/ppo-kk.tgz && tar zxf ppo
- `top_p`: Decoding hyperparameter for generation
- `temperature`: Decoding hyperparameter for generation
- `repetition_penalty`: Decoding hyperparameter for generation
- `num_return_sequences`: Decoding hyperparameter for generation
- `rollout_n`: Decoding hyperparameter for generation
- `min_learning_rate`: Minimum learning rate of the Actor model
- `critic_learning_rate`: Learning rate of the Critic model
- `recompute`: Whether the Actor model uses recomputation; enabling it saves training memory
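
These flags live in the PPO training config; a minimal sketch of the affected slice as a Python dict follows, with every value hypothetical rather than a documented default:

ppo_args = {
    "top_p": 0.8,  # decoding hyperparameters
    "temperature": 1.0,
    "repetition_penalty": 1.0,
    "rollout_n": 4,  # renamed from num_return_sequences in this PR
    "min_learning_rate": 1e-7,
    "critic_learning_rate": 1e-6,
    "recompute": True,  # trades recomputation for lower training memory
}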
188 changes: 183 additions & 5 deletions paddlenlp/datasets/rlhf_datasets/protocol.py
@@ -16,15 +16,18 @@
Implement base data transfer protocol between any two functions, modules.
We can subclass Protocol to define more detailed batch info with specific keys
"""

import copy
from dataclasses import dataclass, field
from typing import Dict, List, Union
from typing import Dict, List, Optional, Sequence, Union

import numpy as np
import paddle
import pandas as pd
from paddle.io import DataLoader
from paddle.utils import map_structure

# Keep a handle on the built-in concat before it is monkey-patched below
original_concat = paddle.concat
__all__ = [
"DataProto",
"union_tensor_dict",
@@ -49,7 +52,18 @@ def __setitem__(self, key: str, tensor: paddle.Tensor):
self._tensors[key] = tensor

def __getitem__(self, key):
return self._tensors[key]
if isinstance(key, str):
return self._tensors[key]
        elif isinstance(key, slice):
            # paddle.strided_slice rejects None bounds, so fill in explicit defaults
            starts = [0 if key.start is None else key.start]
            strides = [1 if key.step is None else key.step]
            tensor_dict_slice = {
                k: paddle.strided_slice(
                    v, axes=[0], starts=starts, ends=[v.shape[0] if key.stop is None else key.stop], strides=strides
                )
                for k, v in self._tensors.items()
            }
batch_size = tensor_dict_slice[list(tensor_dict_slice.keys())[0]].shape[: self.num_batch_dims]
return TensorDict(tensor_dict_slice, batch_size=batch_size, num_batch_dims=self.num_batch_dims)
else:
raise KeyError(f"Unsupported key type: {type(key)}")

def keys(self):
return self._tensors.keys()
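
The slice branch above leaves string lookups untouched and returns a sliced TensorDict for ranges. A minimal usage sketch, assuming the constructor form this diff itself uses (a tensor dict plus batch_size and num_batch_dims):

td = TensorDict(
    {"input_ids": paddle.arange(12).reshape([6, 2]), "reward": paddle.rand([6])},
    batch_size=[6],
    num_batch_dims=1,
)
middle = td[1:5]  # TensorDict keeping rows 1..4, batch_size [4]
evens = td[::2]  # every other row, relying on the defaulted bounds
ids = td["input_ids"]  # a string key still returns the raw tensor
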
Expand All @@ -62,6 +76,58 @@ def to(self, device: str):
self._tensors[key] = self._tensors[key].to(device)
return self

@classmethod
def concat(cls, tensordict_list, axis=0):
if not tensordict_list:
raise ValueError("tensordict_list must not be empty")

        # Record the keys and tensor shapes of the first TensorDict to validate the rest against
first_tensordict = tensordict_list[0]
first_keys = first_tensordict.keys()
first_shapes = {key: tensor.shape for key, tensor in first_tensordict.items()}

        # Verify that all TensorDicts share the same keys and tensor shapes (except along the concatenation axis)
for tensordict in tensordict_list:
if tensordict.keys() != first_keys:
raise ValueError("All TensorDict objects must have the same keys")
for key in first_keys:
if (
tensordict[key].shape[:axis] + tensordict[key].shape[axis + 1 :]
!= first_shapes[key][:axis] + first_shapes[key][axis + 1 :]
):
raise ValueError(f"Shapes of tensor '{key}' do not match except on concatenation axis {axis}")

        # Concatenate the tensors for each key
concatenated_tensors = {
key: paddle.concat([tensordict[key] for tensordict in tensordict_list], axis=axis) for key in first_keys
}

        # Build a new TensorDict and return it
batch_size = concatenated_tensors[list(concatenated_tensors.keys())[0]].shape[
: tensordict_list[0].num_batch_dims
]
return cls(concatenated_tensors, batch_size=batch_size, num_batch_dims=tensordict_list[0].num_batch_dims)


def tensordict_concat(
    x: Union[Sequence[paddle.Tensor], Sequence[TensorDict]],
    axis: Union[int, paddle.Tensor] = 0,
    name: Optional[str] = None,
):
    def is_tensor_sequence():
        return isinstance(x[0], paddle.Tensor)

    if not is_tensor_sequence() and paddle.in_dynamic_mode():
        return TensorDict.concat(x, axis=axis)
    else:
        return original_concat(x, axis, name)


# Route every paddle.concat call through the TensorDict-aware wrapper
paddle.concat = tensordict_concat
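
Once patched, paddle.concat dispatches on element type. A small sketch under the same TensorDict assumptions as above:

a = TensorDict({"x": paddle.ones([2, 3])}, batch_size=[2], num_batch_dims=1)
b = TensorDict({"x": paddle.zeros([4, 3])}, batch_size=[4], num_batch_dims=1)
merged = paddle.concat([a, b])  # TensorDict.concat path; batch_size becomes [6]
plain = paddle.concat([paddle.ones([2]), paddle.zeros([2])])  # original tensor path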


def union_two_dict(dict1: Dict, dict2: Dict):
"""Union two dict. Will throw an error if there is an item not the same object with the same key.
@@ -239,9 +305,39 @@ def __len__(self):
return 0

def __getitem__(self, item):
tensor_data = self.batch[item]
non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)
"""
Enhanced indexing for DataProto objects.

Args:
item: Can be one of:
- int: A single index
- slice: A slice object (start:stop:step)
- list: A list of indices
- numpy.ndarray: An array of indices
                - paddle.Tensor: A tensor of indices

Returns:
DataProto: For all indexing types except single integers
DataProtoItem: Only for single integer indices
"""
# Case 1: Slice object - use the slice method
if isinstance(item, slice):
return self.slice(item.start, item.stop, item.step)

        # Case 2: List, numpy array, or paddle tensor - use select_idxs
elif isinstance(item, (list, np.ndarray, paddle.Tensor)):
return self.select_idxs(item)

# Case 3: Single integer - return DataProtoItem for backward compatibility
elif isinstance(item, (int, np.integer)):
tensor_data = self.batch[item]
non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
            return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)

# Case 4: Unsupported type
else:
raise TypeError(f"Indexing with {type(item)} is not supported")

def print_size(self, prefix=""):
size_of_tensordict = 0
@@ -385,6 +481,87 @@ def select(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=Non

return DataProto(batch=sub_batch, non_tensor_batch=non_tensor_batch, meta_info=sub_meta_info)

def select_idxs(self, idxs):
"""
Select specific indices from the DataProto.

Args:
            idxs (paddle.Tensor or numpy.ndarray or list): Indices to select

Returns:
DataProto: A new DataProto containing only the selected indices
"""
        if isinstance(idxs, list):
            idxs = paddle.to_tensor(idxs, dtype=paddle.int32)

        if isinstance(idxs, np.ndarray):
            idxs_np = idxs
            idxs_paddle = paddle.to_tensor(idxs)
        else:  # paddle.Tensor
            idxs_paddle = idxs
            idxs_np = idxs.numpy()

if self.batch is not None:
# Use TensorDict's built-in indexing capabilities
selected_batch = TensorDict(
source={key: tensor[idxs_paddle] for key, tensor in self.batch.items()},
batch_size=(idxs_paddle.shape[0],),
)
else:
selected_batch = None

selected_non_tensor = {}
for key, val in self.non_tensor_batch.items():
selected_non_tensor[key] = val[idxs_np]

return DataProto(batch=selected_batch, non_tensor_batch=selected_non_tensor, meta_info=self.meta_info)
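
A quick sketch of select_idxs on a toy batch; the DataProto construction mirrors the dataclass fields used throughout this diff and is otherwise illustrative:

dp = DataProto(
    batch=TensorDict({"ids": paddle.arange(6)}, batch_size=[6]),
    non_tensor_batch={"tag": np.array(["a", "b", "c", "d", "e", "f"])},
    meta_info={},
)
picked = dp.select_idxs([0, 2, 5])  # keeps rows 0, 2 and 5 in tensor and non-tensor data alike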

def slice(self, start=None, end=None, step=None):
"""
Slice the DataProto and return a new DataProto object.
        An improved version of the old direct slicing, which returned a DataProtoItem; this returns a full DataProto.

Args:
start (int, optional): Start index. Defaults to None (start from beginning).
end (int, optional): End index (exclusive). Defaults to None (go to end).
step (int, optional): Step size. Defaults to None (step=1).

Returns:
DataProto: A new DataProto containing the sliced data

Examples:
# Using the slice method directly
sliced_data = data_proto.slice(10, 20)

# Using enhanced indexing (returns DataProto)
sliced_data = data_proto[10:20]
sliced_data = data_proto[::2] # Every other element

# Using list indexing (returns DataProto)
indices = [1, 5, 10]
selected_data = data_proto[indices]

# Single index still returns DataProtoItem
single_item = data_proto[5]
"""
# Create a slice object
slice_obj = slice(start, end, step)

# Handle the batch data
if self.batch is not None:
# Use TensorDict's built-in slicing capabilities
sliced_batch = self.batch[slice_obj]
else:
sliced_batch = None

# Handle the non-tensor batch data
sliced_non_tensor = {}
for key, val in self.non_tensor_batch.items():
sliced_non_tensor[key] = val[slice_obj]

# Return a new DataProto object
return DataProto(batch=sliced_batch, non_tensor_batch=sliced_non_tensor, meta_info=self.meta_info)

def pop(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) -> "DataProto":
"""Pop a subset of the DataProto via `batch_keys` and `meta_info_keys`

@@ -447,6 +624,7 @@ def validate_input(keys):
def union(self, other: "DataProto") -> "DataProto":
"""Union with another DataProto. Union batch and meta_info separately.
Throw an error if

- there are conflict keys in batch and they are not equal
- the batch size of two data batch is not the same
- there are conflict keys in meta_info and they are not the same.
98 changes: 98 additions & 0 deletions paddlenlp/rl/algos/advantage.py
@@ -20,6 +20,61 @@
from ..utils.comm_utils import masked_whiten


def compute_gae_advantage_return(
token_level_rewards: paddle.Tensor,
values: paddle.Tensor,
sequence_mask: paddle.Tensor,
start: int,
gamma: paddle.Tensor,
lam: paddle.Tensor,
use_tgt_len_return: bool = True,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute advantages and returns using Generalized Advantage Estimation (GAE)."""
# Modified from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py
lastgaelam = 0.0
advantages_reversed = []
gen_len = token_level_rewards.shape[-1]

values = values * sequence_mask
token_level_rewards = token_level_rewards * sequence_mask
if use_tgt_len_return and start > 0:
# consistent with Beaver
# values length is src+tgt-1, start is src-1, return length is tgt
pass
elif use_tgt_len_return:
# values length is tgt, start is 0, return length is tgt
assert start == 0
else:
# values length is src+tgt-1, start is src-1, return length is src+tgt-1
pass
for t in reversed(range(start, gen_len)): # pylint: disable=invalid-name
next_values = values[:, t + 1] if t < gen_len - 1 else 0.0
delta = token_level_rewards[:, t] + gamma * next_values - values[:, t]
lastgaelam = delta + gamma * lam * lastgaelam
advantages_reversed.append(lastgaelam)
advantages = paddle.stack(advantages_reversed[::-1], axis=1)

returns = advantages + values[:, start:].contiguous()

if not use_tgt_len_return:
advantages = paddle.concat(
[
paddle.zeros([advantages.shape[0], start], dtype=advantages.dtype),
advantages,
],
axis=-1,
)
returns = paddle.concat(
[
paddle.zeros([returns.shape[0], start], dtype=returns.dtype),
returns,
],
axis=-1,
)

return advantages.detach(), returns
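
To make the recurrence concrete, a tiny smoke test under assumed toy shapes; scalar gamma and lam broadcast fine even though the signature types them as tensors:

# delta_t = r_t + gamma * V_{t+1} - V_t ;  A_t = delta_t + gamma * lam * A_{t+1}
rewards = paddle.to_tensor([[0.0, 0.0, 1.0]])
values = paddle.to_tensor([[0.5, 0.6, 0.7]])
mask = paddle.ones_like(rewards)
adv, ret = compute_gae_advantage_return(rewards, values, mask, start=0, gamma=0.99, lam=0.95)
# adv and ret have shape [1, 3]; ret == adv + values because start == 0 here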


@paddle.no_grad()
def compute_grpo_advantages(
rewards: paddle.Tensor,
@@ -101,3 +156,46 @@ def compute_reinforce_plus_plus_advantages_and_returns(
advantages = masked_whiten(returns, eos_mask)
advantages = advantages * eos_mask
return advantages, returns


def add_kl_divergence_regularization(
prompt: paddle.Tensor, # size = (B, S) # pylint: disable=unused-argument
log_probs: paddle.Tensor, # size = (B, L)
ref_log_probs: paddle.Tensor, # size = (B, L)
reward_score: paddle.Tensor, # size = (B,)
sequence_mask: paddle.Tensor, # size = (B, L)
kl_coeff: float,
clip_range_score: float,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""
Calculate the KL divergence regularization gain and add it to the reward.

Args:
prompt (paddle.Tensor, shape=(B, S)): The prompt of the input sequence, not used.
log_probs (paddle.Tensor, shape=(B, L)): The log probability distribution of the current predictions.
ref_log_probs (paddle.Tensor, shape=(B, L)): The log probability distribution of the baseline predictions.
reward_score (paddle.Tensor, shape=(B,)): The base reward score based on the prompt and output sequence.
        sequence_mask (paddle.Tensor, shape=(B, L)): The mask of the sequence, used to locate its final token.
        kl_coeff (float): Coefficient of the per-token KL penalty.
        clip_range_score (float): Clipping bound applied to the base reward score.

    Returns:
        Tuple[paddle.Tensor, paddle.Tensor]: The per-token rewards with the clipped reward score added at
            the final token of each sequence (shape=(B, L)), and the raw KL divergence estimate (shape=(B, L)).
"""

kl_divergence_estimate = -kl_coeff * (log_probs - ref_log_probs) # size = (B, L)
rewards = kl_divergence_estimate # size = (B, L)
reward_clip = paddle.clip( # size = (B,)
reward_score,
min=-clip_range_score,
max=clip_range_score,
)
# TODO(guosheng): use scatter_add/put_along_axis
index = paddle.cumsum(sequence_mask.cast(paddle.int64), axis=-1).argmax(-1, keepdim=True)

rewards = paddle.put_along_axis(
rewards,
index,
reward_clip.unsqueeze(axis=-1),
axis=-1,
reduce="add",
)
return rewards, kl_divergence_estimate
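
An illustrative call with hypothetical tensors: each token pays the KL penalty, and the final token of each sequence additionally receives the clipped reward score:

B, L = 2, 5
log_probs = paddle.zeros([B, L])
ref_log_probs = paddle.full([B, L], -0.1)
reward_score = paddle.to_tensor([1.5, -3.0])
mask = paddle.ones([B, L])
rewards, kl = add_kl_divergence_regularization(
    None, log_probs, ref_log_probs, reward_score, mask, kl_coeff=0.1, clip_range_score=2.0
)
# kl is -0.01 everywhere; rewards matches kl except at the last token, where 1.5 and -2.0 are added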