From 96a73447a5c31a29a7ea78df884b50257293e4e8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?=
Date: Fri, 5 Sep 2025 22:36:49 +0000
Subject: [PATCH 1/4] Fix passing model kwargs

---
 examples/scripts/dpo_online.py      | 21 +++++++------
 examples/scripts/dpo_vlm.py         | 20 ++++++++-----
 examples/scripts/gkd.py             | 46 ++++++++++++++++++-----------
 examples/scripts/mpo_vlm.py         | 21 +++++++------
 examples/scripts/nash_md.py         | 21 +++++++------
 examples/scripts/ppo/ppo.py         | 20 ++++++++-----
 examples/scripts/ppo/ppo_tldr.py    | 20 ++++++++-----
 examples/scripts/prm.py             | 32 ++++++++++++--------
 examples/scripts/reward_modeling.py | 21 ++++++++-----
 examples/scripts/sft_gpt_oss.py     | 23 ++++++++-------
 examples/scripts/sft_video_llm.py   | 21 ++++++-------
 examples/scripts/sft_vlm.py         | 20 ++++++++-----
 examples/scripts/sft_vlm_gemma3.py  | 21 ++++++++-----
 examples/scripts/xpo.py             | 21 +++++++------
 trl/scripts/dpo.py                  | 19 +++++++-----
 trl/scripts/sft.py                  | 21 ++++++++-----
 trl/trainer/gkd_trainer.py          |  1 -
 trl/trainer/online_dpo_trainer.py   |  1 -
 18 files changed, 217 insertions(+), 153 deletions(-)

diff --git a/examples/scripts/dpo_online.py b/examples/scripts/dpo_online.py
index 83c8265a5e8..a6df6c507ac 100644
--- a/examples/scripts/dpo_online.py
+++ b/examples/scripts/dpo_online.py
@@ -83,16 +83,19 @@
     script_args, training_args, model_args = parser.parse_args_and_config()
     training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}
 
-    dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
+    model_kwargs = {}
+    if model_args.revision is not None:
+        model_kwargs["revision"] = model_args.revision
+    if model_args.trust_remote_code is not None:
+        model_kwargs["trust_remote_code"] = model_args.trust_remote_code
+    if model_args.attn_implementation is not None:
+        model_kwargs["attn_implementation"] = model_args.attn_implementation
+    if model_args.dtype is not None:
+        model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype)
     quantization_config = get_quantization_config(model_args)
-    model_kwargs = dict(
-        revision=model_args.model_revision,
-        attn_implementation=model_args.attn_implementation,
-        dtype=dtype,
-        use_cache=False if training_args.gradient_checkpointing else True,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
-        quantization_config=quantization_config,
-    )
+    if quantization_config is not None:
+        model_kwargs["device_map"] = get_kbit_device_map()
+        model_kwargs["quantization_config"] = quantization_config
 
     model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
diff --git a/examples/scripts/dpo_vlm.py b/examples/scripts/dpo_vlm.py
index 249dfd20123..4052320b04c 100644
--- a/examples/scripts/dpo_vlm.py
+++ b/examples/scripts/dpo_vlm.py
@@ -87,16 +87,20 @@
     ################
     # Model & Tokenizer
    ################
-    dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
+    model_kwargs = {}
+    if model_args.revision is not None:
+        model_kwargs["revision"] = model_args.revision
+    if model_args.trust_remote_code is not None:
+        model_kwargs["trust_remote_code"] = model_args.trust_remote_code
+    if model_args.attn_implementation is not None:
+        model_kwargs["attn_implementation"] = model_args.attn_implementation
+    if model_args.dtype is not None:
+        model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype)
     quantization_config = get_quantization_config(model_args)
+    if quantization_config is not None:
+        model_kwargs["device_map"] = get_kbit_device_map()
+        model_kwargs["quantization_config"] = quantization_config
 
-    model_kwargs = dict(
-        revision=model_args.model_revision,
-        attn_implementation=model_args.attn_implementation,
-        dtype=dtype,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
-        quantization_config=quantization_config,
-    )
     model = AutoModelForImageTextToText.from_pretrained(
         model_args.model_name_or_path,
         trust_remote_code=model_args.trust_remote_code,
diff --git a/examples/scripts/gkd.py b/examples/scripts/gkd.py
index 61a2019db81..8bbfc48cba0 100644
--- a/examples/scripts/gkd.py
+++ b/examples/scripts/gkd.py
@@ -54,6 +54,7 @@
 
 import os
 
+import torch
 from datasets import load_dataset
 from transformers import AutoTokenizer, GenerationConfig
 
@@ -81,27 +82,36 @@
     ################
     # Model & Tokenizer
     ################
+    model_kwargs = {}
+    if model_args.revision is not None:
+        model_kwargs["revision"] = model_args.revision
+    if model_args.trust_remote_code is not None:
+        model_kwargs["trust_remote_code"] = model_args.trust_remote_code
+    if model_args.attn_implementation is not None:
+        model_kwargs["attn_implementation"] = model_args.attn_implementation
+    if model_args.dtype is not None:
+        model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype)
     quantization_config = get_quantization_config(model_args)
-    model_kwargs = dict(
-        revision=model_args.model_revision,
-        trust_remote_code=model_args.trust_remote_code,
-        attn_implementation=model_args.attn_implementation,
-        dtype=model_args.dtype,
-        use_cache=False if training_args.gradient_checkpointing else True,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
-        quantization_config=quantization_config,
-    )
+    if quantization_config is not None:
+        model_kwargs["device_map"] = get_kbit_device_map()
+        model_kwargs["quantization_config"] = quantization_config
     training_args.model_init_kwargs = model_kwargs
 
-    teacher_model_kwargs = dict(
-        revision=model_args.model_revision,
-        trust_remote_code=model_args.trust_remote_code,
-        attn_implementation=model_args.attn_implementation,
-        dtype=model_args.dtype,
-        use_cache=True,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
-        quantization_config=quantization_config,
-    )
+    teacher_model_kwargs = {}
+    if model_args.revision is not None:
+        teacher_model_kwargs["revision"] = model_args.revision
+    if model_args.trust_remote_code is not None:
+        teacher_model_kwargs["trust_remote_code"] = model_args.trust_remote_code
+    if model_args.attn_implementation is not None:
+        teacher_model_kwargs["attn_implementation"] = model_args.attn_implementation
+    if model_args.dtype is not None:
+        teacher_model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype)
+    quantization_config = get_quantization_config(model_args)
+    if quantization_config is not None:
+        teacher_model_kwargs["device_map"] = get_kbit_device_map()
+        teacher_model_kwargs["quantization_config"] = quantization_config
+    teacher_model_kwargs["use_cache"] = True
+
     training_args.teacher_model_init_kwargs = teacher_model_kwargs
 
     tokenizer = AutoTokenizer.from_pretrained(
diff --git a/examples/scripts/mpo_vlm.py b/examples/scripts/mpo_vlm.py
index bf4db4214d8..674e2ff7f0a 100644
--- a/examples/scripts/mpo_vlm.py
+++ b/examples/scripts/mpo_vlm.py
@@ -71,17 +71,20 @@
     ################
     # Model & Processor
     ################
-    dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
+    model_kwargs = {}
+    if model_args.revision is not None:
+        model_kwargs["revision"] = model_args.revision
+    if model_args.trust_remote_code is not None:
+        model_kwargs["trust_remote_code"] = model_args.trust_remote_code
+    if model_args.attn_implementation is not None:
+        model_kwargs["attn_implementation"] = model_args.attn_implementation
+    if model_args.dtype is not None:
+        model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype)
     quantization_config = get_quantization_config(model_args)
+    if quantization_config is not None:
+        model_kwargs["device_map"] = get_kbit_device_map()
+        model_kwargs["quantization_config"] = quantization_config
 
-    model_kwargs = dict(
-        trust_remote_code=model_args.trust_remote_code,
-        revision=model_args.model_revision,
-        attn_implementation=model_args.attn_implementation,
-        dtype=dtype,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
-        quantization_config=quantization_config,
-    )
     model = AutoModelForImageTextToText.from_pretrained(
         model_args.model_name_or_path,
         **model_kwargs,
diff --git a/examples/scripts/nash_md.py b/examples/scripts/nash_md.py
index 1c37bdadeae..da3fdbe2fb3 100644
--- a/examples/scripts/nash_md.py
+++ b/examples/scripts/nash_md.py
@@ -87,16 +87,19 @@
     script_args, training_args, model_args = parser.parse_args_and_config()
     training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}
 
-    dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
+    model_kwargs = {}
+    if model_args.revision is not None:
+        model_kwargs["revision"] = model_args.revision
+    if model_args.trust_remote_code is not None:
+        model_kwargs["trust_remote_code"] = model_args.trust_remote_code
+    if model_args.attn_implementation is not None:
+        model_kwargs["attn_implementation"] = model_args.attn_implementation
+    if model_args.dtype is not None:
+        model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype)
     quantization_config = get_quantization_config(model_args)
-    model_kwargs = dict(
-        revision=model_args.model_revision,
-        attn_implementation=model_args.attn_implementation,
-        dtype=dtype,
-        use_cache=False if training_args.gradient_checkpointing else True,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
-        quantization_config=quantization_config,
-    )
+    if quantization_config is not None:
+        model_kwargs["device_map"] = get_kbit_device_map()
+        model_kwargs["quantization_config"] = quantization_config
 
     model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
diff --git a/examples/scripts/ppo/ppo.py b/examples/scripts/ppo/ppo.py
index 9ed9ae2917e..390a108025c 100644
--- a/examples/scripts/ppo/ppo.py
+++ b/examples/scripts/ppo/ppo.py
@@ -90,15 +90,19 @@
     ################
     # Model & Tokenizer
     ################
-    dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
+    model_kwargs = {}
+    if model_args.revision is not None:
+        model_kwargs["revision"] = model_args.revision
+    if model_args.trust_remote_code is not None:
+        model_kwargs["trust_remote_code"] = model_args.trust_remote_code
+    if model_args.attn_implementation is not None:
+        model_kwargs["attn_implementation"] = model_args.attn_implementation
+    if model_args.dtype is not None:
+        model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype)
     quantization_config = get_quantization_config(model_args)
-    model_kwargs = dict(
-        revision=model_args.model_revision,
-        attn_implementation=model_args.attn_implementation,
-        dtype=dtype,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
-        quantization_config=quantization_config,
-    )
+    if quantization_config is not None:
+        model_kwargs["device_map"] = get_kbit_device_map()
+        model_kwargs["quantization_config"] = quantization_config
 
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code
diff --git a/examples/scripts/ppo/ppo_tldr.py b/examples/scripts/ppo/ppo_tldr.py
index 13301ccfe7b..128023ad0be 100644
--- a/examples/scripts/ppo/ppo_tldr.py
+++ b/examples/scripts/ppo/ppo_tldr.py
@@ -97,15 +97,19 @@
     ################
     # Model & Tokenizer
     ################
-    dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
+    model_kwargs = {}
+    if model_args.revision is not None:
+        model_kwargs["revision"] = model_args.revision
+    if model_args.trust_remote_code is not None:
+        model_kwargs["trust_remote_code"] = model_args.trust_remote_code
+    if model_args.attn_implementation is not None:
+        model_kwargs["attn_implementation"] = model_args.attn_implementation
+    if model_args.dtype is not None:
+        model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype)
     quantization_config = get_quantization_config(model_args)
-    model_kwargs = dict(
-        revision=model_args.model_revision,
-        attn_implementation=model_args.attn_implementation,
-        dtype=dtype,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
-        quantization_config=quantization_config,
-    )
+    if quantization_config is not None:
+        model_kwargs["device_map"] = get_kbit_device_map()
+        model_kwargs["quantization_config"] = quantization_config
 
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code
diff --git a/examples/scripts/prm.py b/examples/scripts/prm.py
index 2df83da5efc..8d362e86fe5 100644
--- a/examples/scripts/prm.py
+++ b/examples/scripts/prm.py
@@ -76,30 +76,36 @@
 
 if __name__ == "__main__":
     parser = HfArgumentParser((ScriptArguments, PRMConfig, ModelConfig))
-    script_args, training_args, model_config = parser.parse_args_into_dataclasses()
+    script_args, training_args, model_args = parser.parse_args_into_dataclasses()
     training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
 
     ################
     # Model & Tokenizer
     ################
-    dtype = model_config.dtype if model_config.dtype in ["auto", None] else getattr(torch, model_config.dtype)
-    quantization_config = get_quantization_config(model_config)
-    model_kwargs = dict(
-        revision=model_config.model_revision,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
-        quantization_config=quantization_config,
-        use_cache=False if training_args.gradient_checkpointing else True,
-    )
+    model_kwargs = {}
+    if model_args.revision is not None:
+        model_kwargs["revision"] = model_args.revision
+    if model_args.trust_remote_code is not None:
+        model_kwargs["trust_remote_code"] = model_args.trust_remote_code
+    if model_args.attn_implementation is not None:
+        model_kwargs["attn_implementation"] = model_args.attn_implementation
+    if model_args.dtype is not None:
+        model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype)
+    quantization_config = get_quantization_config(model_args)
+    if quantization_config is not None:
+        model_kwargs["device_map"] = get_kbit_device_map()
+        model_kwargs["quantization_config"] = quantization_config
+
     tokenizer = AutoTokenizer.from_pretrained(
-        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True
+        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
     )
     model = AutoModelForTokenClassification.from_pretrained(
-        model_config.model_name_or_path, num_labels=2, trust_remote_code=model_config.trust_remote_code, **model_kwargs
+        model_args.model_name_or_path, num_labels=2, trust_remote_code=model_args.trust_remote_code, **model_kwargs
     )
     # Align padding tokens between tokenizer and model
     model.config.pad_token_id = tokenizer.pad_token_id
 
-    if model_config.use_peft and model_config.lora_task_type != "TOKEN_CLS":
+    if model_args.use_peft and model_args.lora_task_type != "TOKEN_CLS":
         logger.warning(
             "You are using a `task_type` that is different than `TOKEN_CLS` for PEFT. This will lead to silent bugs"
             " Make sure to pass --lora_task_type TOKEN_CLS when using this script with PEFT.",
@@ -121,7 +127,7 @@
         args=training_args,
         train_dataset=dataset[script_args.dataset_train_split],
         eval_dataset=dataset[script_args.dataset_test_split],
-        peft_config=get_peft_config(model_config),
+        peft_config=get_peft_config(model_args),
     )
     trainer.train()
diff --git a/examples/scripts/reward_modeling.py b/examples/scripts/reward_modeling.py
index 7b9e2dd4a55..4a2583b66ec 100644
--- a/examples/scripts/reward_modeling.py
+++ b/examples/scripts/reward_modeling.py
@@ -85,15 +85,20 @@
     ################
     # Model & Tokenizer
     ################
-    dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
+    model_kwargs = {}
+    if model_args.revision is not None:
+        model_kwargs["revision"] = model_args.revision
+    if model_args.trust_remote_code is not None:
+        model_kwargs["trust_remote_code"] = model_args.trust_remote_code
+    if model_args.attn_implementation is not None:
+        model_kwargs["attn_implementation"] = model_args.attn_implementation
+    if model_args.dtype is not None:
+        model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype)
     quantization_config = get_quantization_config(model_args)
-    model_kwargs = dict(
-        revision=model_args.model_revision,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
-        quantization_config=quantization_config,
-        use_cache=False if training_args.gradient_checkpointing else True,
-        dtype=dtype,
-    )
+    if quantization_config is not None:
+        model_kwargs["device_map"] = get_kbit_device_map()
+        model_kwargs["quantization_config"] = quantization_config
+
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
     )
diff --git a/examples/scripts/sft_gpt_oss.py b/examples/scripts/sft_gpt_oss.py
index 61085211803..2eb6e484250 100644
--- a/examples/scripts/sft_gpt_oss.py
+++ b/examples/scripts/sft_gpt_oss.py
@@ -31,7 +31,7 @@
     examples/scripts/sft_gpt_oss.py \
     --dtype bfloat16 \
     --model_name_or_path openai/gpt-oss-20b \
-    --packing true packing_strategy wrapped \
+    --packing \
     --run_name 20b-full-eager \
     --attn_implementation kernels-community/vllm-flash-attn3 \
     --dataset_num_proc 12 \
@@ -51,6 +51,7 @@
 
 import os
 
+import torch
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer, Mxfp4Config
 
@@ -63,15 +64,17 @@
 
 def main(script_args, training_args, model_args):
     # Load model & tokenizer
-    quantization_config = Mxfp4Config(dequantize=True)
-    model_kwargs = dict(
-        revision=model_args.model_revision,
-        trust_remote_code=model_args.trust_remote_code,
-        attn_implementation=model_args.attn_implementation,
-        dtype=model_args.dtype,
-        use_cache=False if training_args.gradient_checkpointing else True,
-        quantization_config=quantization_config,
-    )
+    model_kwargs = {}
+    if model_args.revision is not None:
+        model_kwargs["revision"] = model_args.revision
+    if model_args.trust_remote_code is not None:
+        model_kwargs["trust_remote_code"] = model_args.trust_remote_code
+    if model_args.attn_implementation is not None:
+        model_kwargs["attn_implementation"] = model_args.attn_implementation
+    if model_args.dtype is not None:
+        model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype)
+    quantization_config = Mxfp4Config(dequantize=True)
+    model_kwargs["quantization_config"] = quantization_config
 
     model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)
diff --git a/examples/scripts/sft_video_llm.py b/examples/scripts/sft_video_llm.py
index f0b2c174dea..f4df858a210 100644
--- a/examples/scripts/sft_video_llm.py
+++ b/examples/scripts/sft_video_llm.py
@@ -187,9 +187,6 @@ class CustomScriptArguments(ScriptArguments):
     # Load dataset
     dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config, split="train")
 
-    # Setup model
-    dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
-
     # Quantization configuration for 4-bit training
     bnb_config = BitsAndBytesConfig(
         load_in_4bit=True,
@@ -199,13 +196,17 @@ class CustomScriptArguments(ScriptArguments):
     )
 
     # Model initialization
-    model_kwargs = dict(
-        revision=model_args.model_revision,
-        trust_remote_code=model_args.trust_remote_code,
-        dtype=dtype,
-        device_map=get_kbit_device_map(),
-        quantization_config=bnb_config,
-    )
+    model_kwargs = {}
+    if model_args.revision is not None:
+        model_kwargs["revision"] = model_args.revision
+    if model_args.trust_remote_code is not None:
+        model_kwargs["trust_remote_code"] = model_args.trust_remote_code
+    if model_args.attn_implementation is not None:
+        model_kwargs["attn_implementation"] = model_args.attn_implementation
+    if model_args.dtype is not None:
+        model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype)
+    model_kwargs["device_map"] = get_kbit_device_map()
+    model_kwargs["quantization_config"] = bnb_config
 
     model = AutoModelForImageTextToText.from_pretrained(model_args.model_name_or_path, **model_kwargs)
diff --git a/examples/scripts/sft_vlm.py b/examples/scripts/sft_vlm.py
index 3cef50cf829..50d61748db1 100644
--- a/examples/scripts/sft_vlm.py
+++ b/examples/scripts/sft_vlm.py
@@ -84,15 +84,19 @@
     ################
     # Model, Tokenizer & Processor
     ################
-    dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
+    model_kwargs = {}
+    if model_args.revision is not None:
+        model_kwargs["revision"] = model_args.revision
+    if model_args.trust_remote_code is not None:
+        model_kwargs["trust_remote_code"] = model_args.trust_remote_code
+    if model_args.attn_implementation is not None:
+        model_kwargs["attn_implementation"] = model_args.attn_implementation
+    if model_args.dtype is not None:
+        model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype)
     quantization_config = get_quantization_config(model_args)
-    model_kwargs = dict(
-        revision=model_args.model_revision,
-        attn_implementation=model_args.attn_implementation,
-        dtype=dtype,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
-        quantization_config=quantization_config,
-    )
+    if quantization_config is not None:
+        model_kwargs["device_map"] = get_kbit_device_map()
+        model_kwargs["quantization_config"] = quantization_config
 
     model = AutoModelForImageTextToText.from_pretrained(
         model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
diff --git a/examples/scripts/sft_vlm_gemma3.py b/examples/scripts/sft_vlm_gemma3.py
index 1434ffef586..e0612a9182a 100644
--- a/examples/scripts/sft_vlm_gemma3.py
+++ b/examples/scripts/sft_vlm_gemma3.py
@@ -149,15 +149,20 @@ def main():
     ################
     # Model, Tokenizer & Processor
     ################
-    dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
+    model_kwargs = {}
+    if model_args.revision is not None:
+        model_kwargs["revision"] = model_args.revision
+    if model_args.trust_remote_code is not None:
+        model_kwargs["trust_remote_code"] = model_args.trust_remote_code
+    if model_args.attn_implementation is not None:
+        model_kwargs["attn_implementation"] = model_args.attn_implementation
+    if model_args.dtype is not None:
+        model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype)
     quantization_config = get_quantization_config(model_args)
-    model_kwargs = dict(
-        revision=model_args.model_revision,
-        attn_implementation=model_args.attn_implementation,
-        dtype=dtype,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
-        quantization_config=quantization_config,
-    )
+    if quantization_config is not None:
+        model_kwargs["device_map"] = get_kbit_device_map()
+        model_kwargs["quantization_config"] = quantization_config
+
     model = AutoModelForImageTextToText.from_pretrained(
         model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
     )
diff --git a/examples/scripts/xpo.py b/examples/scripts/xpo.py
index b7da476a95f..f844e9cd7dc 100644
--- a/examples/scripts/xpo.py
+++ b/examples/scripts/xpo.py
@@ -72,16 +72,19 @@
     script_args, training_args, model_args = parser.parse_args_and_config()
     training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}
 
-    dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
+    model_kwargs = {}
+    if model_args.revision is not None:
+        model_kwargs["revision"] = model_args.revision
+    if model_args.trust_remote_code is not None:
+        model_kwargs["trust_remote_code"] = model_args.trust_remote_code
+    if model_args.attn_implementation is not None:
+        model_kwargs["attn_implementation"] = model_args.attn_implementation
+    if model_args.dtype is not None:
+        model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype)
     quantization_config = get_quantization_config(model_args)
-    model_kwargs = dict(
-        revision=model_args.model_revision,
-        attn_implementation=model_args.attn_implementation,
-        dtype=dtype,
-        use_cache=False if training_args.gradient_checkpointing else True,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
-        quantization_config=quantization_config,
-    )
+    if quantization_config is not None:
+        model_kwargs["device_map"] = get_kbit_device_map()
+        model_kwargs["quantization_config"] = quantization_config
 
     model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
diff --git a/trl/scripts/dpo.py b/trl/scripts/dpo.py
index 1a2a1dd85e1..0b4f6cea33e 100644
--- a/trl/scripts/dpo.py
+++ b/trl/scripts/dpo.py
@@ -93,15 +93,18 @@ def main(script_args, training_args, model_args, dataset_args):
     ################
     # Model & Tokenizer
     ###################
-    dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
+    model_kwargs = {}
+    if model_args.revision is not None:
+        model_kwargs["revision"] = model_args.revision
+    if model_args.attn_implementation is not None:
+        model_kwargs["attn_implementation"] = model_args.attn_implementation
+    if model_args.dtype is not None:
+        model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype)
     quantization_config = get_quantization_config(model_args)
-    model_kwargs = dict(
-        revision=model_args.model_revision,
-        attn_implementation=model_args.attn_implementation,
-        dtype=dtype,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
-        quantization_config=quantization_config,
-    )
+    if quantization_config is not None:
+        model_kwargs["device_map"] = get_kbit_device_map()
+        model_kwargs["quantization_config"] = quantization_config
+
     model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
     )
diff --git a/trl/scripts/sft.py b/trl/scripts/sft.py
index f108d21389d..352a7bfd700 100644
--- a/trl/scripts/sft.py
+++ b/trl/scripts/sft.py
@@ -65,6 +65,7 @@
 import argparse
 import os
 
+import torch
 from accelerate import logging
 from datasets import load_dataset
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
@@ -95,15 +96,19 @@ def main(script_args, training_args, model_args, dataset_args):
     ################
     # Model init kwargs & Tokenizer
     ################
+    model_kwargs = {}
+    if model_args.revision is not None:
+        model_kwargs["revision"] = model_args.revision
+    if model_args.trust_remote_code is not None:
+        model_kwargs["trust_remote_code"] = model_args.trust_remote_code
+    if model_args.attn_implementation is not None:
+        model_kwargs["attn_implementation"] = model_args.attn_implementation
+    if model_args.dtype is not None:
+        model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype)
     quantization_config = get_quantization_config(model_args)
-    model_kwargs = dict(
-        revision=model_args.model_revision,
-        trust_remote_code=model_args.trust_remote_code,
-        attn_implementation=model_args.attn_implementation,
-        dtype=model_args.dtype,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
-        quantization_config=quantization_config,
-    )
+    if quantization_config is not None:
+        model_kwargs["device_map"] = get_kbit_device_map()
+        model_kwargs["quantization_config"] = quantization_config
 
     # Create model
     config = AutoConfig.from_pretrained(model_args.model_name_or_path)
diff --git a/trl/trainer/gkd_trainer.py b/trl/trainer/gkd_trainer.py
index 9d17f32e417..1429dc981ef 100644
--- a/trl/trainer/gkd_trainer.py
+++ b/trl/trainer/gkd_trainer.py
@@ -155,7 +155,6 @@ def __init__(
             temperature=args.temperature,
             do_sample=True,
             top_k=0,
-            use_cache=False if args.gradient_checkpointing else True,
             pad_token_id=self.processing_class.pad_token_id,
         )
         # Set custom EOS tokens if they are specified by the model's generation
diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py
index f70624130c9..f48d65507dd 100644
--- a/trl/trainer/online_dpo_trainer.py
+++ b/trl/trainer/online_dpo_trainer.py
@@ -303,7 +303,6 @@ def __init__(
             top_k=50,
             top_p=1.0,
             do_sample=True,
-            use_cache=False if args.gradient_checkpointing else True,
         )
 
         # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the

From 64532fd374cf8702999db8040937a04224aee184 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?=
Date: Tue, 9 Sep 2025 19:42:38 +0000
Subject: [PATCH 2/4] focus on quantization config

---
 examples/scripts/dpo_vlm.py         | 17 +++++++------
 examples/scripts/gkd.py             | 37 +++++++++++++----------------
 examples/scripts/mpo_vlm.py         | 18 +++++++-------
 examples/scripts/nash_md.py         | 17 +++++++------
 examples/scripts/online_dpo.py      | 17 +++++++------
 examples/scripts/ppo/ppo.py         | 18 +++++++-------
 examples/scripts/ppo/ppo_tldr.py    | 16 ++++++-------
 examples/scripts/prm.py             | 15 +++++-------
 examples/scripts/reward_modeling.py | 16 ++++++-------
 examples/scripts/sft_gpt_oss.py     | 21 +++++++---------
 examples/scripts/sft_video_llm.py   | 21 ++++++++--------
 examples/scripts/sft_vlm.py         | 16 ++++++-------
 examples/scripts/sft_vlm_gemma3.py  | 16 ++++++-------
 examples/scripts/xpo.py             | 17 +++++++------
 trl/scripts/dpo.py                  | 14 +++++------
 trl/scripts/sft.py                  | 17 ++++++-------
 trl/trainer/gkd_trainer.py          |  1 +
 trl/trainer/online_dpo_trainer.py   |  1 +
 18 files changed, 134 insertions(+), 161 deletions(-)

diff --git a/examples/scripts/dpo_vlm.py b/examples/scripts/dpo_vlm.py
index 4052320b04c..64f902d7926 100644
--- a/examples/scripts/dpo_vlm.py
+++ b/examples/scripts/dpo_vlm.py
@@ -87,17 +87,16 @@
     ################
     # Model & Tokenizer
     ################
-    model_kwargs = {}
-    if model_args.revision is not None:
-        model_kwargs["revision"] = model_args.revision
-    if model_args.trust_remote_code is not None:
-        model_kwargs["trust_remote_code"] = model_args.trust_remote_code
-    if model_args.attn_implementation is not None:
-        model_kwargs["attn_implementation"] = model_args.attn_implementation
-    if model_args.dtype is not None:
-        model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype)
+    dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
+
+    model_kwargs = dict(
+        revision=model_args.model_revision,
+        attn_implementation=model_args.attn_implementation,
+        dtype=dtype,
+    )
     quantization_config = get_quantization_config(model_args)
     if quantization_config is not None:
+        # Passing None would not be treated the same as omitting the argument, so we include it only when valid.
model_kwargs["device_map"] = get_kbit_device_map() model_kwargs["quantization_config"] = quantization_config diff --git a/examples/scripts/gkd.py b/examples/scripts/gkd.py index 8bbfc48cba0..b186662b62f 100644 --- a/examples/scripts/gkd.py +++ b/examples/scripts/gkd.py @@ -54,7 +54,6 @@ import os -import torch from datasets import load_dataset from transformers import AutoTokenizer, GenerationConfig @@ -82,35 +81,31 @@ ################ # Model & Tokenizer ################ - model_kwargs = {} - if model_args.revision is not None: - model_kwargs["revision"] = model_args.revision - if model_args.trust_remote_code is not None: - model_kwargs["trust_remote_code"] = model_args.trust_remote_code - if model_args.attn_implementation is not None: - model_kwargs["attn_implementation"] = model_args.attn_implementation - if model_args.dtype is not None: - model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype) + model_kwargs = dict( + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + attn_implementation=model_args.attn_implementation, + dtype=model_args.dtype, + use_cache=False if training_args.gradient_checkpointing else True, + ) quantization_config = get_quantization_config(model_args) if quantization_config is not None: + # Passing None would not be treated the same as omitting the argument, so we include it only when valid. model_kwargs["device_map"] = get_kbit_device_map() model_kwargs["quantization_config"] = quantization_config + training_args.model_init_kwargs = model_kwargs - teacher_model_kwargs = {} - if model_args.revision is not None: - model_kwargs["revision"] = model_args.revision - if model_args.trust_remote_code is not None: - model_kwargs["trust_remote_code"] = model_args.trust_remote_code - if model_args.attn_implementation is not None: - model_kwargs["attn_implementation"] = model_args.attn_implementation - if model_args.dtype is not None: - model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype) - quantization_config = get_quantization_config(model_args) + teacher_model_kwargs = dict( + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + attn_implementation=model_args.attn_implementation, + dtype=model_args.dtype, + use_cache=True, + ) if quantization_config is not None: model_kwargs["device_map"] = get_kbit_device_map() model_kwargs["quantization_config"] = quantization_config - model_kwargs["use_cache"] = True training_args.teacher_model_init_kwargs = teacher_model_kwargs diff --git a/examples/scripts/mpo_vlm.py b/examples/scripts/mpo_vlm.py index 674e2ff7f0a..4d168e82185 100644 --- a/examples/scripts/mpo_vlm.py +++ b/examples/scripts/mpo_vlm.py @@ -71,17 +71,17 @@ ################ # Model & Processor ################ - model_kwargs = {} - if model_args.revision is not None: - model_kwargs["revision"] = model_args.revision - if model_args.trust_remote_code is not None: - model_kwargs["trust_remote_code"] = model_args.trust_remote_code - if model_args.attn_implementation is not None: - model_kwargs["attn_implementation"] = model_args.attn_implementation - if model_args.dtype is not None: - model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype) + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) + + model_kwargs = dict( + trust_remote_code=model_args.trust_remote_code, + revision=model_args.model_revision, + 
attn_implementation=model_args.attn_implementation, + dtype=dtype, + ) quantization_config = get_quantization_config(model_args) if quantization_config is not None: + # Passing None would not be treated the same as omitting the argument, so we include it only when valid. model_kwargs["device_map"] = get_kbit_device_map() model_kwargs["quantization_config"] = quantization_config diff --git a/examples/scripts/nash_md.py b/examples/scripts/nash_md.py index a7cd89734c7..6d7a91f8e1b 100644 --- a/examples/scripts/nash_md.py +++ b/examples/scripts/nash_md.py @@ -87,17 +87,16 @@ script_args, training_args, model_args = parser.parse_args_and_config() training_args.gradient_checkpointing_kwargs = {"use_reentrant": True} - model_kwargs = {} - if model_args.revision is not None: - model_kwargs["revision"] = model_args.revision - if model_args.trust_remote_code is not None: - model_kwargs["trust_remote_code"] = model_args.trust_remote_code - if model_args.attn_implementation is not None: - model_kwargs["attn_implementation"] = model_args.attn_implementation - if model_args.dtype is not None: - model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype) + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) + model_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + dtype=dtype, + use_cache=False if training_args.gradient_checkpointing else True, + ) quantization_config = get_quantization_config(model_args) if quantization_config is not None: + # Passing None would not be treated the same as omitting the argument, so we include it only when valid. model_kwargs["device_map"] = get_kbit_device_map() model_kwargs["quantization_config"] = quantization_config diff --git a/examples/scripts/online_dpo.py b/examples/scripts/online_dpo.py index 96a18aa79cf..896f9b30950 100644 --- a/examples/scripts/online_dpo.py +++ b/examples/scripts/online_dpo.py @@ -83,17 +83,16 @@ script_args, training_args, model_args = parser.parse_args_and_config() training_args.gradient_checkpointing_kwargs = {"use_reentrant": True} - model_kwargs = {} - if model_args.revision is not None: - model_kwargs["revision"] = model_args.revision - if model_args.trust_remote_code is not None: - model_kwargs["trust_remote_code"] = model_args.trust_remote_code - if model_args.attn_implementation is not None: - model_kwargs["attn_implementation"] = model_args.attn_implementation - if model_args.dtype is not None: - model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype) + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) + model_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + dtype=dtype, + use_cache=False if training_args.gradient_checkpointing else True, + ) quantization_config = get_quantization_config(model_args) if quantization_config is not None: + # Passing None would not be treated the same as omitting the argument, so we include it only when valid. 
model_kwargs["device_map"] = get_kbit_device_map() model_kwargs["quantization_config"] = quantization_config diff --git a/examples/scripts/ppo/ppo.py b/examples/scripts/ppo/ppo.py index 390a108025c..759dbd475be 100644 --- a/examples/scripts/ppo/ppo.py +++ b/examples/scripts/ppo/ppo.py @@ -90,17 +90,15 @@ ################ # Model & Tokenizer ################ - model_kwargs = {} - if model_args.revision is not None: - model_kwargs["revision"] = model_args.revision - if model_args.trust_remote_code is not None: - model_kwargs["trust_remote_code"] = model_args.trust_remote_code - if model_args.attn_implementation is not None: - model_kwargs["attn_implementation"] = model_args.attn_implementation - if model_args.dtype is not None: - model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype) + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) + model_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + dtype=dtype, + ) quantization_config = get_quantization_config(model_args) - if quantization_config is not None: + if quantization_config is not None: + # Passing None would not be treated the same as omitting the argument, so we include it only when valid. model_kwargs["device_map"] = get_kbit_device_map() model_kwargs["quantization_config"] = quantization_config diff --git a/examples/scripts/ppo/ppo_tldr.py b/examples/scripts/ppo/ppo_tldr.py index 128023ad0be..eb0962de530 100644 --- a/examples/scripts/ppo/ppo_tldr.py +++ b/examples/scripts/ppo/ppo_tldr.py @@ -97,17 +97,15 @@ ################ # Model & Tokenizer ################ - model_kwargs = {} - if model_args.revision is not None: - model_kwargs["revision"] = model_args.revision - if model_args.trust_remote_code is not None: - model_kwargs["trust_remote_code"] = model_args.trust_remote_code - if model_args.attn_implementation is not None: - model_kwargs["attn_implementation"] = model_args.attn_implementation - if model_args.dtype is not None: - model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype) + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) + model_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + dtype=dtype, + ) quantization_config = get_quantization_config(model_args) if quantization_config is not None: + # Passing None would not be treated the same as omitting the argument, so we include it only when valid. 
model_kwargs["device_map"] = get_kbit_device_map() model_kwargs["quantization_config"] = quantization_config diff --git a/examples/scripts/prm.py b/examples/scripts/prm.py index 8d362e86fe5..dec154cbf93 100644 --- a/examples/scripts/prm.py +++ b/examples/scripts/prm.py @@ -82,17 +82,14 @@ ################ # Model & Tokenizer ################ - model_kwargs = {} - if model_args.revision is not None: - model_kwargs["revision"] = model_args.revision - if model_args.trust_remote_code is not None: - model_kwargs["trust_remote_code"] = model_args.trust_remote_code - if model_args.attn_implementation is not None: - model_kwargs["attn_implementation"] = model_args.attn_implementation - if model_args.dtype is not None: - model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype) + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) + model_kwargs = dict( + revision=model_args.model_revision, + use_cache=False if training_args.gradient_checkpointing else True, + ) quantization_config = get_quantization_config(model_args) if quantization_config is not None: + # Passing None would not be treated the same as omitting the argument, so we include it only when valid. model_kwargs["device_map"] = get_kbit_device_map() model_kwargs["quantization_config"] = quantization_config diff --git a/examples/scripts/reward_modeling.py b/examples/scripts/reward_modeling.py index 4a2583b66ec..0cedb350bd6 100644 --- a/examples/scripts/reward_modeling.py +++ b/examples/scripts/reward_modeling.py @@ -85,17 +85,15 @@ ################ # Model & Tokenizer ################ - model_kwargs = {} - if model_args.revision is not None: - model_kwargs["revision"] = model_args.revision - if model_args.trust_remote_code is not None: - model_kwargs["trust_remote_code"] = model_args.trust_remote_code - if model_args.attn_implementation is not None: - model_kwargs["attn_implementation"] = model_args.attn_implementation - if model_args.dtype is not None: - model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype) + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) + model_kwargs = dict( + revision=model_args.model_revision, + use_cache=False if training_args.gradient_checkpointing else True, + dtype=dtype, + ) quantization_config = get_quantization_config(model_args) if quantization_config is not None: + # Passing None would not be treated the same as omitting the argument, so we include it only when valid. 
model_kwargs["device_map"] = get_kbit_device_map() model_kwargs["quantization_config"] = quantization_config diff --git a/examples/scripts/sft_gpt_oss.py b/examples/scripts/sft_gpt_oss.py index 2eb6e484250..367177c7028 100644 --- a/examples/scripts/sft_gpt_oss.py +++ b/examples/scripts/sft_gpt_oss.py @@ -51,7 +51,6 @@ import os -import torch from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer, Mxfp4Config @@ -64,17 +63,15 @@ def main(script_args, training_args, model_args): # Load model & tokenizer - model_kwargs = {} - if model_args.revision is not None: - model_kwargs["revision"] = model_args.revision - if model_args.trust_remote_code is not None: - model_kwargs["trust_remote_code"] = model_args.trust_remote_code - if model_args.attn_implementation is not None: - model_kwargs["attn_implementation"] = model_args.attn_implementation - if model_args.dtype is not None: - model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype) - model_kwargs["device_map"] = quantization_config = Mxfp4Config(dequantize=True) - model_kwargs["quantization_config"] = quantization_config + quantization_config = Mxfp4Config(dequantize=True) + model_kwargs = dict( + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + attn_implementation=model_args.attn_implementation, + dtype=model_args.dtype, + use_cache=False if training_args.gradient_checkpointing else True, + quantization_config=quantization_config, + ) model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs) diff --git a/examples/scripts/sft_video_llm.py b/examples/scripts/sft_video_llm.py index f4df858a210..f0b2c174dea 100644 --- a/examples/scripts/sft_video_llm.py +++ b/examples/scripts/sft_video_llm.py @@ -187,6 +187,9 @@ class CustomScriptArguments(ScriptArguments): # Load dataset dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config, split="train") + # Setup model + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) + # Quantization configuration for 4-bit training bnb_config = BitsAndBytesConfig( load_in_4bit=True, @@ -196,17 +199,13 @@ class CustomScriptArguments(ScriptArguments): ) # Model initialization - model_kwargs = {} - if model_args.revision is not None: - model_kwargs["revision"] = model_args.revision - if model_args.trust_remote_code is not None: - model_kwargs["trust_remote_code"] = model_args.trust_remote_code - if model_args.attn_implementation is not None: - model_kwargs["attn_implementation"] = model_args.attn_implementation - if model_args.dtype is not None: - model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype) - model_kwargs["device_map"] = get_kbit_device_map() - model_kwargs["quantization_config"] = bnb_config + model_kwargs = dict( + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + dtype=dtype, + device_map=get_kbit_device_map(), + quantization_config=bnb_config, + ) model = AutoModelForImageTextToText.from_pretrained(model_args.model_name_or_path, **model_kwargs) diff --git a/examples/scripts/sft_vlm.py b/examples/scripts/sft_vlm.py index c77b7cf0f9d..38fe1a482f1 100644 --- a/examples/scripts/sft_vlm.py +++ b/examples/scripts/sft_vlm.py @@ -84,17 +84,15 @@ ################ # Model, Tokenizer & Processor ################ - model_kwargs = {} - if model_args.revision is not None: - model_kwargs["revision"] = 
model_args.revision - if model_args.trust_remote_code is not None: - model_kwargs["trust_remote_code"] = model_args.trust_remote_code - if model_args.attn_implementation is not None: - model_kwargs["attn_implementation"] = model_args.attn_implementation - if model_args.dtype is not None: - model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype) + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) + model_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + dtype=dtype, + ) quantization_config = get_quantization_config(model_args) if quantization_config is not None: + # Passing None would not be treated the same as omitting the argument, so we include it only when valid. model_kwargs["device_map"] = get_kbit_device_map() model_kwargs["quantization_config"] = quantization_config diff --git a/examples/scripts/sft_vlm_gemma3.py b/examples/scripts/sft_vlm_gemma3.py index e0612a9182a..12b80eefb3f 100644 --- a/examples/scripts/sft_vlm_gemma3.py +++ b/examples/scripts/sft_vlm_gemma3.py @@ -149,17 +149,15 @@ def main(): ################ # Model, Tokenizer & Processor ################ - model_kwargs = {} - if model_args.revision is not None: - model_kwargs["revision"] = model_args.revision - if model_args.trust_remote_code is not None: - model_kwargs["trust_remote_code"] = model_args.trust_remote_code - if model_args.attn_implementation is not None: - model_kwargs["attn_implementation"] = model_args.attn_implementation - if model_args.dtype is not None: - model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype) + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) + model_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + dtype=dtype, + ) quantization_config = get_quantization_config(model_args) if quantization_config is not None: + # Passing None would not be treated the same as omitting the argument, so we include it only when valid. 
model_kwargs["device_map"] = get_kbit_device_map() model_kwargs["quantization_config"] = quantization_config diff --git a/examples/scripts/xpo.py b/examples/scripts/xpo.py index 2933f981588..d06ee2740a1 100644 --- a/examples/scripts/xpo.py +++ b/examples/scripts/xpo.py @@ -72,17 +72,16 @@ script_args, training_args, model_args = parser.parse_args_and_config() training_args.gradient_checkpointing_kwargs = {"use_reentrant": True} - model_kwargs = {} - if model_args.revision is not None: - model_kwargs["revision"] = model_args.revision - if model_args.trust_remote_code is not None: - model_kwargs["trust_remote_code"] = model_args.trust_remote_code - if model_args.attn_implementation is not None: - model_kwargs["attn_implementation"] = model_args.attn_implementation - if model_args.dtype is not None: - model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype) + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) + model_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + dtype=dtype, + use_cache=False if training_args.gradient_checkpointing else True, + ) quantization_config = get_quantization_config(model_args) if quantization_config is not None: + # Passing None would not be treated the same as omitting the argument, so we include it only when valid. model_kwargs["device_map"] = get_kbit_device_map() model_kwargs["quantization_config"] = quantization_config diff --git a/trl/scripts/dpo.py b/trl/scripts/dpo.py index 0b4f6cea33e..03024a1a0da 100644 --- a/trl/scripts/dpo.py +++ b/trl/scripts/dpo.py @@ -93,15 +93,15 @@ def main(script_args, training_args, model_args, dataset_args): ################ # Model & Tokenizer ################### - model_kwargs = {} - if model_args.revision is not None: - model_kwargs["revision"] = model_args.revision - if model_args.attn_implementation is not None: - model_kwargs["attn_implementation"] = model_args.attn_implementation - if model_args.dtype is not None: - model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype) + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) + model_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + dtype=dtype, + ) quantization_config = get_quantization_config(model_args) if quantization_config is not None: + # Passing None would not be treated the same as omitting the argument, so we include it only when valid. 
model_kwargs["device_map"] = get_kbit_device_map() model_kwargs["quantization_config"] = quantization_config diff --git a/trl/scripts/sft.py b/trl/scripts/sft.py index b5c26ce68cc..e0cced1ff1e 100644 --- a/trl/scripts/sft.py +++ b/trl/scripts/sft.py @@ -65,7 +65,6 @@ import argparse import os -import torch from accelerate import logging from datasets import load_dataset from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer @@ -95,17 +94,15 @@ def main(script_args, training_args, model_args, dataset_args): ################ # Model init kwargs & Tokenizer ################ - model_kwargs = {} - if model_args.revision is not None: - model_kwargs["revision"] = model_args.revision - if model_args.trust_remote_code is not None: - model_kwargs["trust_remote_code"] = model_args.trust_remote_code - if model_args.attn_implementation is not None: - model_kwargs["attn_implementation"] = model_args.attn_implementation - if model_args.dtype is not None: - model_kwargs["dtype"] = "auto" if model_args.dtype == "auto" else getattr(torch, model_args.dtype) + model_kwargs = dict( + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + attn_implementation=model_args.attn_implementation, + dtype=model_args.dtype, + ) quantization_config = get_quantization_config(model_args) if quantization_config is not None: + # Passing None would not be treated the same as omitting the argument, so we include it only when valid. model_kwargs["device_map"] = get_kbit_device_map() model_kwargs["quantization_config"] = quantization_config diff --git a/trl/trainer/gkd_trainer.py b/trl/trainer/gkd_trainer.py index f91cb300377..0db55700df9 100644 --- a/trl/trainer/gkd_trainer.py +++ b/trl/trainer/gkd_trainer.py @@ -194,6 +194,7 @@ def __init__( temperature=args.temperature, do_sample=True, top_k=0, + use_cache=False if args.gradient_checkpointing else True, pad_token_id=self.processing_class.pad_token_id, ) # Set custom EOS tokens if they are specified by the model's generation diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index d6e9cddc0af..f44523f6d99 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -563,6 +563,7 @@ def __init__( "top_k": self.top_k, "top_p": self.top_p, "repetition_penalty": self.repetition_penalty, + "use_cache": True if not self.args.gradient_checkpointing else False, } # Add min_p if supported if self.min_p is not None: From de4d583231a7b3c161aa383b7909c7ced27debf8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 9 Sep 2025 19:50:31 +0000 Subject: [PATCH 3/4] style --- examples/scripts/ppo/ppo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/scripts/ppo/ppo.py b/examples/scripts/ppo/ppo.py index 759dbd475be..80e9b71e648 100644 --- a/examples/scripts/ppo/ppo.py +++ b/examples/scripts/ppo/ppo.py @@ -97,7 +97,7 @@ dtype=dtype, ) quantization_config = get_quantization_config(model_args) - if quantization_config is not None: + if quantization_config is not None: # Passing None would not be treated the same as omitting the argument, so we include it only when valid. 
model_kwargs["device_map"] = get_kbit_device_map() model_kwargs["quantization_config"] = quantization_config From cd70e8b04805e5f6a1e024304f0127e5702fa51a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 9 Sep 2025 19:53:21 +0000 Subject: [PATCH 4/4] add some missing --- examples/scripts/gkd.py | 1 + examples/scripts/grpo_vlm.py | 8 +++++--- examples/scripts/gspo.py | 8 +++++--- examples/scripts/gspo_vlm.py | 8 +++++--- examples/scripts/online_dpo_vlm.py | 8 +++++--- 5 files changed, 21 insertions(+), 12 deletions(-) diff --git a/examples/scripts/gkd.py b/examples/scripts/gkd.py index b186662b62f..88d951b267b 100644 --- a/examples/scripts/gkd.py +++ b/examples/scripts/gkd.py @@ -104,6 +104,7 @@ use_cache=True, ) if quantization_config is not None: + # Passing None would not be treated the same as omitting the argument, so we include it only when valid. model_kwargs["device_map"] = get_kbit_device_map() model_kwargs["quantization_config"] = quantization_config diff --git a/examples/scripts/grpo_vlm.py b/examples/scripts/grpo_vlm.py index 725f3acde42..ad62e728f56 100644 --- a/examples/scripts/grpo_vlm.py +++ b/examples/scripts/grpo_vlm.py @@ -97,14 +97,16 @@ # Model & Processor ################ dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) - quantization_config = get_quantization_config(model_args) training_args.model_init_kwargs = dict( revision=model_args.model_revision, attn_implementation=model_args.attn_implementation, dtype=dtype, - device_map=get_kbit_device_map() if quantization_config is not None else None, - quantization_config=quantization_config, ) + quantization_config = get_quantization_config(model_args) + if quantization_config is not None: + # Passing None would not be treated the same as omitting the argument, so we include it only when valid. + training_args.model_init_kwargs["device_map"] = get_kbit_device_map() + training_args.model_init_kwargs["quantization_config"] = quantization_config ################ # Dataset diff --git a/examples/scripts/gspo.py b/examples/scripts/gspo.py index d75186c4c00..4a67642255c 100644 --- a/examples/scripts/gspo.py +++ b/examples/scripts/gspo.py @@ -83,14 +83,16 @@ # Model & Processor ################ dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) - quantization_config = get_quantization_config(model_args) training_args.model_init_kwargs = dict( revision=model_args.model_revision, attn_implementation=model_args.attn_implementation, dtype=dtype, - device_map=get_kbit_device_map() if quantization_config is not None else None, - quantization_config=quantization_config, ) + quantization_config = get_quantization_config(model_args) + if quantization_config is not None: + # Passing None would not be treated the same as omitting the argument, so we include it only when valid. 
+        training_args.model_init_kwargs["device_map"] = get_kbit_device_map()
+        training_args.model_init_kwargs["quantization_config"] = quantization_config
 
     ################
     # Dataset
diff --git a/examples/scripts/gspo_vlm.py b/examples/scripts/gspo_vlm.py
index b57aa195c0a..d5532013976 100644
--- a/examples/scripts/gspo_vlm.py
+++ b/examples/scripts/gspo_vlm.py
@@ -84,14 +84,16 @@
     # Model & Processor
     ################
     dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
-    quantization_config = get_quantization_config(model_args)
     training_args.model_init_kwargs = dict(
         revision=model_args.model_revision,
         attn_implementation=model_args.attn_implementation,
         dtype=dtype,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
-        quantization_config=quantization_config,
     )
+    quantization_config = get_quantization_config(model_args)
+    if quantization_config is not None:
+        # Passing None would not be treated the same as omitting the argument, so we include it only when valid.
+        training_args.model_init_kwargs["device_map"] = get_kbit_device_map()
+        training_args.model_init_kwargs["quantization_config"] = quantization_config
 
     ################
     # Dataset
diff --git a/examples/scripts/online_dpo_vlm.py b/examples/scripts/online_dpo_vlm.py
index de89e030f59..a1c3010cd10 100644
--- a/examples/scripts/online_dpo_vlm.py
+++ b/examples/scripts/online_dpo_vlm.py
@@ -115,15 +115,17 @@
     training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}
 
     dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
-    quantization_config = get_quantization_config(model_args)
     model_kwargs = dict(
         revision=model_args.model_revision,
         attn_implementation=model_args.attn_implementation,
         dtype=dtype,
         use_cache=False if training_args.gradient_checkpointing else True,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
-        quantization_config=quantization_config,
     )
+    quantization_config = get_quantization_config(model_args)
+    if quantization_config is not None:
+        # Passing None would not be treated the same as omitting the argument, so we include it only when valid.
+        model_kwargs["device_map"] = get_kbit_device_map()
+        model_kwargs["quantization_config"] = quantization_config
 
     # Load the VLM model using correct architecture (from GRPO pattern)
     config = AutoConfig.from_pretrained(model_args.model_name_or_path)
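
---

Every script touched by this series ends up with the same few lines: build the loading kwargs unconditionally, then add `device_map` and `quantization_config` only when quantization is actually enabled, since passing them as None is not equivalent to omitting them. The following is a minimal, self-contained sketch of that end state, assuming TRL's exported `ModelConfig`, `get_quantization_config`, and `get_kbit_device_map` helpers; the model name is a hypothetical placeholder, not part of the patches above.

    import torch
    from transformers import AutoModelForCausalLM

    from trl import ModelConfig, get_kbit_device_map, get_quantization_config

    # Hypothetical arguments; in the example scripts these come from the CLI parser.
    model_args = ModelConfig(model_name_or_path="Qwen/Qwen2.5-0.5B")

    # Resolve the dtype string ("auto", None, or a name like "bfloat16") to a torch dtype.
    dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
    model_kwargs = dict(
        revision=model_args.model_revision,
        attn_implementation=model_args.attn_implementation,
        dtype=dtype,
    )
    quantization_config = get_quantization_config(model_args)
    if quantization_config is not None:
        # device_map=None / quantization_config=None would not be treated the same as
        # omitting the arguments, so they are added only when quantization is enabled.
        model_kwargs["device_map"] = get_kbit_device_map()
        model_kwargs["quantization_config"] = quantization_config

    model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)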