think all the samples are eligible for training
lucidrains committed Feb 3, 2025
1 parent c829fd4 commit 6d5cbb1
Showing 2 changed files with 13 additions and 28 deletions.
palm_rlhf_pytorch/grpo.py (39 changes: 12 additions & 27 deletions)
@@ -143,9 +143,7 @@ def forward(
     'mask',
     'action_prob',
     'action_log_prob',
-    'reward',
-    'reward_mean',
-    'reward_variance'
+    'group_relative_normalized_reward',
 ])
 
 class ExperienceDataset(Dataset):
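
For context on the renamed field: in GRPO each sample's reward is normalized against the other samples in its group, so a single scalar per sample is all that needs to be stored in a Memory. A minimal sketch of that normalization with a made-up group of rewards (not code from the repository; the clamp guard mirrors the one used further down in this diff):

    import torch

    # hypothetical group of G completions sampled for one prompt, one scalar reward each
    group_rewards = torch.tensor([0.1, 0.7, 0.4, 0.9])

    # group-relative normalization: center on the group mean, scale by the group std
    mean = group_rewards.mean()
    std = group_rewards.var(unbiased = False).clamp(min = 1e-5).sqrt()
    group_relative_normalized_reward = (group_rewards - mean) / std

    print(group_relative_normalized_reward)   # zero mean, unit variance across the group
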
@@ -406,14 +404,9 @@ def learn(
                 old_action_probs,
                 old_log_probs,
                 rewards,
-                rewards_mean,
-                rewards_variance
             ) in dl:
                 action_masks = ~prompt_masks & masks
 
-                values = torch.tensor(0.)
-                old_values = torch.tensor(0.)
-
                 action_logits = self.actor(
                     sequences,
                     mask = action_masks
@@ -444,13 +437,9 @@ def learn(
                 # calculate clipped surrogate objective, classic PPO loss
 
                 ratios = (action_log_probs - old_log_probs).exp()
-                advantages = (rewards - rewards_mean) / rewards_variance.clamp(min = 1e-5).sqrt()
 
-                if advantages.ndim == 1:
-                    advantages = rearrange(advantages, 'b -> b 1')
-
-                surr1 = ratios * advantages
-                surr2 = ratios.clamp(1 - self.eps_clip, 1 + self.eps_clip) * advantages
+                surr1 = ratios * rewards
+                surr2 = ratios.clamp(1 - self.eps_clip, 1 + self.eps_clip) * rewards
                 policy_loss = - torch.min(surr1, surr2) - self.beta_s * entropies
 
                 # combine losses
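
Since each memory now carries an already-normalized reward, the clipped surrogate can consume it directly as the advantage, with no per-batch (reward - mean) / std step or broadcasting fix. A standalone sketch with made-up tensors and assumed hyperparameters (eps_clip, beta_s), not the repository's learn() method:

    import torch

    # assumed hyperparameters, not taken from the repository config
    eps_clip, beta_s = 0.2, 0.01

    batch, action_len = 4, 10
    action_log_probs = torch.randn(batch, action_len)                   # new policy log probs per action token
    old_log_probs = action_log_probs + 0.05 * torch.randn(batch, action_len)
    entropies = torch.rand(batch, action_len)
    rewards = torch.randn(batch, 1)                                     # group-relative normalized reward, kept as (batch, 1) so it broadcasts over action_len

    ratios = (action_log_probs - old_log_probs).exp()
    surr1 = ratios * rewards
    surr2 = ratios.clamp(1 - eps_clip, 1 + eps_clip) * rewards
    policy_loss = (- torch.min(surr1, surr2) - beta_s * entropies).mean()

    print(policy_loss)
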
@@ -551,24 +540,20 @@ def train(
 
-                # use the first reward for training, the rest of them to derive statistics for normalization, iiuc
 
-                reward, rewards = rewards[0], rewards[1:]
 
-                rewards_mean, rewards_variance = rewards.mean(), rewards.var(unbiased = False)
+                normalized_rewards = (rewards - rewards.mean()) / rewards.var(unbiased = False).clamp(min = 1e-5).sqrt()
 
                 # store memory for learning
 
                 detach_to_cpu_ = lambda t: t.detach().cpu()
 
-                memories.append(Memory(*map(detach_to_cpu_, (
-                    first(sequence),
-                    first(prompt_mask),
-                    first(mask),
-                    first(action_prob),
-                    first(action_log_prob),
-                    reward,
-                    rewards_mean,
-                    rewards_variance
-                ))))
+                memories.extend([Memory(*memories) for memories in zip(*map(detach_to_cpu_, (
+                    sequence,
+                    prompt_mask,
+                    mask,
+                    action_prob,
+                    action_log_prob,
+                    normalized_rewards,
+                )))])
 
                 # learn from the stored memories
 
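The net effect in train(): instead of keeping only the first sample of a group and using the rest just for mean/variance statistics, every sample is stored as its own Memory carrying its group-relative normalized reward. A self-contained sketch with hypothetical shapes (group size G and sequence length N are made up; the Memory field names are inferred from the tuples in this diff):

    import torch
    from collections import namedtuple

    # field names inferred from this diff's Memory tuple and train() call
    Memory = namedtuple('Memory', [
        'sequence',
        'prompt_mask',
        'mask',
        'action_prob',
        'action_log_prob',
        'group_relative_normalized_reward',
    ])

    G, N = 4, 12                                        # hypothetical group size and sequence length
    sequence = torch.randint(0, 100, (G, N))
    prompt_mask = torch.zeros(G, N, dtype = torch.bool)
    mask = torch.ones(G, N, dtype = torch.bool)
    action_prob = torch.rand(G, N).clamp(min = 1e-5)
    action_log_prob = action_prob.log()
    rewards = torch.randn(G)                            # one scalar reward per sample in the group

    normalized_rewards = (rewards - rewards.mean()) / rewards.var(unbiased = False).clamp(min = 1e-5).sqrt()

    detach_to_cpu_ = lambda t: t.detach().cpu()

    memories = []

    # every sample in the group becomes a training memory, each with its own normalized reward
    memories.extend([Memory(*fields) for fields in zip(*map(detach_to_cpu_, (
        sequence,
        prompt_mask,
        mask,
        action_prob,
        action_log_prob,
        normalized_rewards,
    )))])

    assert len(memories) == G
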
setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'PaLM-rlhf-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.4.1',
+  version = '0.4.3',
   license='MIT',
   description = 'PaLM + Reinforcement Learning with Human Feedback - Pytorch',
   author = 'Phil Wang',
