Commit cac9f1d

Fix Replay Buffer docs. (#4574)

1 parent 547d924

1 file changed: +5 −3 lines

docs/source/grpo_with_replay_buffer.md

@@ -5,7 +5,8 @@ This experimental trainer, trains a model with GRPO but replaces groups (and cor
 ## Usage
 
 ```python
-from trl.experimental.grpo_with_replay_buffer import GRPOWithReplayBufferTrainer
+import torch
+from trl.experimental.grpo_with_replay_buffer import GRPOWithReplayBufferConfig, GRPOWithReplayBufferTrainer
 from datasets import load_dataset
 
 dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
@@ -18,15 +19,16 @@ def custom_reward_func(completions, **kwargs):
     return torch.rand(len(completions)).tolist()
 
 training_args = GRPOWithReplayBufferConfig(
-    output_dir=self.tmp_dir,
+    output_dir="./tmp",
     learning_rate=1e-4,
     per_device_train_batch_size=4,
     num_generations=4,
     max_completion_length=8,
     replay_buffer_size=8,
     report_to="none",
 )
-trainer = GRPOTrainer(
+
+trainer = GRPOWithReplayBufferTrainer(
     model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
     reward_funcs=[custom_reward_func],
     args=training_args,
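
For reference, a minimal sketch of the full usage snippet as it reads after this patch. Everything shown in the diff is copied as-is; the `train_dataset=dataset` argument and the closing `trainer.train()` call fall outside the visible diff context and are assumptions based on standard TRL trainer usage.

```python
# Sketch of the corrected docs example after this commit.
import torch

from trl.experimental.grpo_with_replay_buffer import GRPOWithReplayBufferConfig, GRPOWithReplayBufferTrainer
from datasets import load_dataset

dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")


def custom_reward_func(completions, **kwargs):
    # Dummy reward: one random score per completion.
    return torch.rand(len(completions)).tolist()


training_args = GRPOWithReplayBufferConfig(
    output_dir="./tmp",
    learning_rate=1e-4,
    per_device_train_batch_size=4,
    num_generations=4,
    max_completion_length=8,
    replay_buffer_size=8,  # capacity of the replay buffer (groups kept for replay)
    report_to="none",
)

trainer = GRPOWithReplayBufferTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    reward_funcs=[custom_reward_func],
    args=training_args,
    train_dataset=dataset,  # assumption: not shown in the visible diff context
)
trainer.train()  # assumption: standard trainer entry point, not shown in the diff
```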
