Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions examples/scripts/dpo_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,18 @@
# Model & Tokenizer
################
dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
quantization_config = get_quantization_config(model_args)

model_kwargs = dict(
revision=model_args.model_revision,
attn_implementation=model_args.attn_implementation,
dtype=dtype,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
model_kwargs["device_map"] = get_kbit_device_map()
model_kwargs["quantization_config"] = quantization_config

model = AutoModelForImageTextToText.from_pretrained(
model_args.model_name_or_path,
trust_remote_code=model_args.trust_remote_code,
Expand Down
16 changes: 11 additions & 5 deletions examples/scripts/gkd.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,19 @@
################
# Model & Tokenizer
################
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
attn_implementation=model_args.attn_implementation,
dtype=model_args.dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
model_kwargs["device_map"] = get_kbit_device_map()
model_kwargs["quantization_config"] = quantization_config

training_args.model_init_kwargs = model_kwargs

teacher_model_kwargs = dict(
Expand All @@ -99,9 +102,12 @@
attn_implementation=model_args.attn_implementation,
dtype=model_args.dtype,
use_cache=True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
teacher_model_kwargs["device_map"] = get_kbit_device_map()
teacher_model_kwargs["quantization_config"] = quantization_config

training_args.teacher_model_init_kwargs = teacher_model_kwargs

tokenizer = AutoTokenizer.from_pretrained(
Expand Down
8 changes: 5 additions & 3 deletions examples/scripts/grpo_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,16 @@
# Model & Processor
################
dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
quantization_config = get_quantization_config(model_args)
training_args.model_init_kwargs = dict(
revision=model_args.model_revision,
attn_implementation=model_args.attn_implementation,
dtype=dtype,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
training_args.model_init_kwargs["device_map"] = get_kbit_device_map()
training_args.model_init_kwargs["quantization_config"] = quantization_config

################
# Dataset
Expand Down
8 changes: 5 additions & 3 deletions examples/scripts/gspo.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,16 @@
# Model & Processor
################
dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
quantization_config = get_quantization_config(model_args)
training_args.model_init_kwargs = dict(
revision=model_args.model_revision,
attn_implementation=model_args.attn_implementation,
dtype=dtype,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
training_args.model_init_kwargs["device_map"] = get_kbit_device_map()
training_args.model_init_kwargs["quantization_config"] = quantization_config

################
# Dataset
Expand Down
8 changes: 5 additions & 3 deletions examples/scripts/gspo_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,16 @@
# Model & Processor
################
dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
quantization_config = get_quantization_config(model_args)
training_args.model_init_kwargs = dict(
revision=model_args.model_revision,
attn_implementation=model_args.attn_implementation,
dtype=dtype,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
training_args.model_init_kwargs["device_map"] = get_kbit_device_map()
training_args.model_init_kwargs["quantization_config"] = quantization_config

################
# Dataset
Expand Down
9 changes: 6 additions & 3 deletions examples/scripts/mpo_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,19 @@
# Model & Processor
################
dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
quantization_config = get_quantization_config(model_args)

model_kwargs = dict(
trust_remote_code=model_args.trust_remote_code,
revision=model_args.model_revision,
attn_implementation=model_args.attn_implementation,
dtype=dtype,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
model_kwargs["device_map"] = get_kbit_device_map()
model_kwargs["quantization_config"] = quantization_config

model = AutoModelForImageTextToText.from_pretrained(
model_args.model_name_or_path,
**model_kwargs,
Expand Down
8 changes: 5 additions & 3 deletions examples/scripts/nash_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,17 @@
training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}

dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
revision=model_args.model_revision,
attn_implementation=model_args.attn_implementation,
dtype=dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
model_kwargs["device_map"] = get_kbit_device_map()
model_kwargs["quantization_config"] = quantization_config

model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
Expand Down
8 changes: 5 additions & 3 deletions examples/scripts/online_dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,17 @@
training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}

dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
revision=model_args.model_revision,
attn_implementation=model_args.attn_implementation,
dtype=dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
model_kwargs["device_map"] = get_kbit_device_map()
model_kwargs["quantization_config"] = quantization_config

model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
Expand Down
8 changes: 5 additions & 3 deletions examples/scripts/online_dpo_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,17 @@
training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}

dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
revision=model_args.model_revision,
attn_implementation=model_args.attn_implementation,
dtype=dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
model_kwargs["device_map"] = get_kbit_device_map()
model_kwargs["quantization_config"] = quantization_config

# Load the VLM model using correct architecture (from GRPO pattern)
config = AutoConfig.from_pretrained(model_args.model_name_or_path)
Expand Down
8 changes: 5 additions & 3 deletions examples/scripts/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,16 @@
# Model & Tokenizer
################
dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
revision=model_args.model_revision,
attn_implementation=model_args.attn_implementation,
dtype=dtype,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
model_kwargs["device_map"] = get_kbit_device_map()
model_kwargs["quantization_config"] = quantization_config

tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code
Expand Down
8 changes: 5 additions & 3 deletions examples/scripts/ppo/ppo_tldr.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,16 @@
# Model & Tokenizer
################
dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
revision=model_args.model_revision,
attn_implementation=model_args.attn_implementation,
dtype=dtype,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
model_kwargs["device_map"] = get_kbit_device_map()
model_kwargs["quantization_config"] = quantization_config

tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code
Expand Down
23 changes: 13 additions & 10 deletions examples/scripts/prm.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,30 +76,33 @@

if __name__ == "__main__":
parser = HfArgumentParser((ScriptArguments, PRMConfig, ModelConfig))
script_args, training_args, model_config = parser.parse_args_into_dataclasses()
script_args, training_args, model_args = parser.parse_args_into_dataclasses()
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)

################
# Model & Tokenizer
################
dtype = model_config.dtype if model_config.dtype in ["auto", None] else getattr(torch, model_config.dtype)
quantization_config = get_quantization_config(model_config)
dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
model_kwargs = dict(
revision=model_config.model_revision,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
revision=model_args.model_revision,
use_cache=False if training_args.gradient_checkpointing else True,
)
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
model_kwargs["device_map"] = get_kbit_device_map()
model_kwargs["quantization_config"] = quantization_config

tokenizer = AutoTokenizer.from_pretrained(
model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
)
model = AutoModelForTokenClassification.from_pretrained(
model_config.model_name_or_path, num_labels=2, trust_remote_code=model_config.trust_remote_code, **model_kwargs
model_args.model_name_or_path, num_labels=2, trust_remote_code=model_args.trust_remote_code, **model_kwargs
)
# Align padding tokens between tokenizer and model
model.config.pad_token_id = tokenizer.pad_token_id

if model_config.use_peft and model_config.lora_task_type != "TOKEN_CLS":
if model_args.use_peft and model_args.lora_task_type != "TOKEN_CLS":
logger.warning(
"You are using a `task_type` that is different than `TOKEN_CLS` for PEFT. This will lead to silent bugs"
" Make sure to pass --lora_task_type TOKEN_CLS when using this script with PEFT.",
Expand All @@ -121,7 +124,7 @@
args=training_args,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split],
peft_config=get_peft_config(model_config),
peft_config=get_peft_config(model_args),
)
trainer.train()

Expand Down
9 changes: 6 additions & 3 deletions examples/scripts/reward_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,17 @@
# Model & Tokenizer
################
dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
revision=model_args.model_revision,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
use_cache=False if training_args.gradient_checkpointing else True,
dtype=dtype,
)
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
model_kwargs["device_map"] = get_kbit_device_map()
model_kwargs["quantization_config"] = quantization_config

tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
)
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/sft_gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
examples/scripts/sft_gpt_oss.py \
--dtype bfloat16 \
--model_name_or_path openai/gpt-oss-20b \
--packing true packing_strategy wrapped \
--packing \
--run_name 20b-full-eager \
--attn_implementation kernels-community/vllm-flash-attn3 \
--dataset_num_proc 12 \
Expand Down
8 changes: 5 additions & 3 deletions examples/scripts/sft_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,16 @@
# Model, Tokenizer & Processor
################
dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
revision=model_args.model_revision,
attn_implementation=model_args.attn_implementation,
dtype=dtype,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
model_kwargs["device_map"] = get_kbit_device_map()
model_kwargs["quantization_config"] = quantization_config

model = AutoModelForImageTextToText.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
Expand Down
9 changes: 6 additions & 3 deletions examples/scripts/sft_vlm_gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,17 @@ def main():
# Model, Tokenizer & Processor
################
dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
revision=model_args.model_revision,
attn_implementation=model_args.attn_implementation,
dtype=dtype,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
model_kwargs["device_map"] = get_kbit_device_map()
model_kwargs["quantization_config"] = quantization_config

model = AutoModelForImageTextToText.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
)
Expand Down
8 changes: 5 additions & 3 deletions examples/scripts/xpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,17 @@
training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}

dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
revision=model_args.model_revision,
attn_implementation=model_args.attn_implementation,
dtype=dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
model_kwargs["device_map"] = get_kbit_device_map()
model_kwargs["quantization_config"] = quantization_config

model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
Expand Down
Loading
Loading