-
Notifications
You must be signed in to change notification settings - Fork 2.3k
added 10 papers (+trainer cross-links) for #4407 #4441
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 8 commits
494657f
70a26e6
85054a4
2e01916
4e2f10b
0cf0649
be590d6
5f2e5e9
597a397
f1d98b0
9348298
a9d9467
66cc83d
e35c02d
98d3086
627db69
33f7f42
1efce43
e50ed53
0856286
e0e0e39
f204d80
a7872c3
d428d41
ab26bcc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -5,8 +5,28 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ## Group Relative Policy Optimization | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Papers relating to the [`GRPOTrainer`] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ### DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| **π Paper**: https://huggingface.co/papers/2402.03300 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Introduces **GRPO** and shows strong math-reasoning gains from math-centric pretraining plus group-relative PPO-style optimization. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| **Used in TRL via:** [`GRPOTrainer`] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ```python | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Minimal GRPO setup (mirrors style used for other papers on the page). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from trl import GRPOConfig | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| training_args = GRPOConfig( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| loss_type="grpo", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| beta=0.0, # GRPO commonly trains without explicit KL in released configs | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| epsilon=2e-4, # clip range (use paper/experiment settings if you mirror them) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| epsilon_high=4e-4, # upper clip (symmetrical if not specified) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
SSusantAchary marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| steps_per_generation=4, # sample multiple completions per prompt | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| gradient_accumulation_steps=1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_generations=8, # completions per prompt (adjust to your compute) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| max_prompt_length=1024, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| max_completion_length=1024, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ``` | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ### Group Sequence Policy Optimization | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| **π Paper**: https://huggingface.co/papers/2507.18071 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -232,10 +252,6 @@ trainer = PAPOTrainer( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ``` | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ## Direct Policy Optimization | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Papers relating to the [`DPOTrainer`] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
SSusantAchary marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ### Direct Preference Optimization (DPO): Your Language Model is Secretly a Reward Model | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| **π Paper**: https://huggingface.co/papers/2305.18290 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -457,6 +473,74 @@ training_args = DPOConfig( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| These parameters only appear in the [published version](https://aclanthology.org/2025.tacl-1.22.pdf) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ### Statistical Rejection Sampling Improves Preference Optimization | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| **π Paper**: https://huggingface.co/papers/2309.06657 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Proposes **RSO**, selecting stronger preference pairs via statistical rejection sampling to boost offline preference optimization; complements DPO/SLiC. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
qgallouedec marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ```python | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Curate DPO pairs with rejection sampling BEFORE training | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from datasets import Dataset | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from trl import DPOConfig, DPOTrainer | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from transformers import AutoModelForCausalLM, AutoTokenizer | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def rso_accept(ex): # replace with your statistic (gap / z-score / judge score) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return ex.get("rso_keep", True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dpo_pairs = dpo_pairs.filter(rso_accept) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model = AutoModelForCausalLM.from_pretrained("..."); tok = AutoTokenizer.from_pretrained("...") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| args = DPOConfig(loss_type="sigmoid", beta=0.1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| trainer = DPOTrainer(model=model, args=args, tokenizer=tok, train_dataset=dpo_pairs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| trainer.train() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from datasets import Dataset | |
| from trl import DPOConfig, DPOTrainer | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| def rso_accept(ex): # replace with your statistic (gap / z-score / judge score) | |
| return ex.get("rso_keep", True) | |
| dpo_pairs = dpo_pairs.filter(rso_accept) | |
| model = AutoModelForCausalLM.from_pretrained("..."); tok = AutoTokenizer.from_pretrained("...") | |
| args = DPOConfig(loss_type="sigmoid", beta=0.1) | |
| trainer = DPOTrainer(model=model, args=args, tokenizer=tok, train_dataset=dpo_pairs) | |
| trainer.train() | |
| from datasets import load_dataset | |
| from trl import DPOConfig, DPOTrainer | |
| train_dataset = load_dataset(...) | |
| def rso_accept(example): # replace with your statistic (gap / z-score / judge score) | |
| return example.get("rso_keep", True) | |
| train_dataset = train_dataset.filter(rso_accept) | |
| training_args = DPOConfig(loss_type="sigmoid", beta=0.1) | |
| trainer = DPOTrainer( | |
| ..., | |
| args=training_args, | |
| train_dataset=train_dataset | |
| ) | |
| trainer.train() |
for consistency
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
minimal
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there no need to pass the value if it matches the default. In other words, remove any occurence of like loss_type="sigmoid" or beta=0.1
qgallouedec marked this conversation as resolved.
Show resolved
Hide resolved
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| # LoRA adapters with SFT (works the same for DPO/GRPO by passing peft_config to those trainers) | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from peft import LoraConfig | |
| from trl import SFTTrainer, SFTConfig | |
| model_id = "meta-llama/Llama-3.1-8B-Instruct" # any causal LM on HF Hub | |
| tok = AutoTokenizer.from_pretrained(model_id) | |
| model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto") | |
| peft_cfg = LoraConfig( | |
| r=16, | |
| lora_alpha=32, | |
| lora_dropout=0.05, | |
| bias="none", | |
| task_type="CAUSAL_LM", | |
| # common modules for LLaMA/Mistral/Qwen/Gemma; adjust per model if needed | |
| target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"], | |
| ) | |
| args = SFTConfig( | |
| max_seq_length=2048, | |
| per_device_train_batch_size=4, | |
| gradient_accumulation_steps=8, | |
| learning_rate=2e-4, | |
| bf16=True, | |
| ) | |
| trainer = SFTTrainer( | |
| model=model, | |
| args=args, | |
| tokenizer=tok, | |
| peft_config=peft_cfg, # <- LoRA enabled | |
| train_dataset=..., | |
| ) | |
| trainer.train() | |
| from peft import LoraConfig | |
| from trl import SFTTrainer | |
| trainer = SFTTrainer( | |
| ..., | |
| peft_config=LoraConfig(), | |
| ) |
the more minimal, the clearer
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
updated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the original paper they don't use beta=0.0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed