Description
Reproduction
from trl import GRPOConfig, GRPOTrainer  # trl 0.26.0

training_args = GRPOConfig(
    output_dir=adapter_root,
    overwrite_output_dir=True,
    warmup_ratio=0.1,
    num_train_epochs=set_config['PRETRAIN_EPOCHS'],
    per_device_train_batch_size=set_config['PRETRAIN_BATCH_SIZE'],
    gradient_accumulation_steps=set_config['GRADIENT_BACK'],
    lr_scheduler_type="cosine",
    save_steps=500,
    save_total_limit=2,
    logging_dir=adapter_root,
    logging_steps=100,
    learning_rate=set_config['PRETRAIN_LR'],
    weight_decay=set_config['PRETRAIN_WD'],
    bf16=True,
    num_generations=set_config['PRETRAIN_GEN_NUM'],
    max_completion_length=set_config['MAX_GEN_LEN'],
    use_vllm=True,
    vllm_mode="server",
    generation_kwargs={
        "max_tokens": 2048,
        "temperature": 0.7,
        "top_p": 0.9,
    },
)
trainer = GRPOTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    peft_config=lora_config,
    reward_funcs=reward_funcs,
    # vllm_engine_kwargs={
    #     "max_model_len": 24000,
    #     "gpu_memory_utilization": 0.7,
    #     "tensor_parallel_size": 2,
    # },
)
trainer.train()

outputs:
Traceback (most recent call last):
File "/hpc2hdd/home/sguo349/czhaobo/TTLHealth/src/main.py", line 299, in <module>
run_single_purellm_config(pretrain, config, retrain=retrain, exp_num=exp_num)
File "/hpc2hdd/home/sguo349/czhaobo/TTLHealth/src/main.py", line 166, in run_single_purellm_config
log_file, our_model, _, generation_params = create_pretrain_model(train_dataset, set_config, adapter_root, output_root, ttl_root, templates,retrain, benchmark_name=data_name + '-' + task_name + '-' + exp_num)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/hpc2hdd/home/sguo349/czhaobo/TTLHealth/src/train.py", line 38, in create_pretrain_model
train_llm(train_dataset, set_config, adapter_root, output_root)
File "/hpc2hdd/home/sguo349/czhaobo/TTLHealth/src/train.py", line 338, in train_llm
trainer.train()
File "/hpc2hdd/home/sguo349/miniconda3/envs/ttl-test/lib/python3.11/site-packages/transformers/trainer.py", line 2325, in train
return inner_training_loop(
^^^^^^^^^^^^^^^^^^^^
File "/hpc2hdd/home/sguo349/miniconda3/envs/ttl-test/lib/python3.11/site-packages/transformers/trainer.py", line 2674, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/hpc2hdd/home/sguo349/miniconda3/envs/ttl-test/lib/python3.11/site-packages/trl/trainer/grpo_trainer.py", line 1124, in training_step
output = super().training_step(model, inputs, num_items_in_batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/hpc2hdd/home/sguo349/miniconda3/envs/ttl-test/lib/python3.11/site-packages/transformers/trainer.py", line 4014, in training_step
inputs = self._prepare_inputs(inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/hpc2hdd/home/sguo349/miniconda3/envs/ttl-test/lib/python3.11/site-packages/trl/extras/profiling.py", line 98, in wrapper
return func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/hpc2hdd/home/sguo349/miniconda3/envs/ttl-test/lib/python3.11/site-packages/trl/trainer/grpo_trainer.py", line 1153, in _prepare_inputs
generation_batch = self._generate_and_score_completions(generation_batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/hpc2hdd/home/sguo349/miniconda3/envs/ttl-test/lib/python3.11/site-packages/trl/trainer/grpo_trainer.py", line 1775, in _generate_and_score_completions
) = self._generate(prompts)
^^^^^^^^^^^^^^^^^^^^^^^
File "/hpc2hdd/home/sguo349/miniconda3/envs/ttl-test/lib/python3.11/site-packages/trl/trainer/grpo_trainer.py", line 1657, in _generate
prompt_ids, completion_ids, logprobs, extra_fields = self._generate_single_turn(prompts)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/hpc2hdd/home/sguo349/miniconda3/envs/ttl-test/lib/python3.11/site-packages/trl/trainer/grpo_trainer.py", line 1289, in _generate_single_turn
output = self.vllm_client.chat(
^^^^^^^^^^^^^^^^^^^^^^
File "/hpc2hdd/home/sguo349/miniconda3/envs/ttl-test/lib/python3.11/site-packages/trl/extras/vllm_client.py", line 336, in chat
raise NotImplementedError("Tool calling is not yet implemented in VLLMClient.chat().")
NotImplementedError: Tool calling is not yet implemented in VLLMClient.chat().
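For context, vllm_mode="server" routes generation through an external vLLM server that must already be running when trainer.train() is called; it is launched in a separate process beforehand (e.g. trl vllm-serve --model <path-to-Qwen2.5-7B-Instruct>). Note that the reproduction above passes no tools to GRPOTrainer, yet generation still goes through VLLMClient.chat() and hits the NotImplementedError. The sketch below only shows the client-side wiring; the host/port values are assumptions matching the trl vllm-serve defaults, not values needed to trigger the error:

# Sketch only: pointing GRPOConfig at the external vLLM server in server mode.
# Host/port are assumptions (the defaults used by `trl vllm-serve`).
from trl import GRPOConfig

server_side_args = GRPOConfig(
    output_dir="./grpo-server-sketch",   # placeholder output dir
    use_vllm=True,
    vllm_mode="server",                  # generation is delegated to VLLMClient over HTTP
    vllm_server_host="127.0.0.1",        # assumption: server runs on the same node
    vllm_server_port=8000,               # assumption: default trl vllm-serve port
)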
System Info
trl version: 0.26.0
vllm version: 0.10.2
LLM model: Qwen2.5-7B-Instruct
Checklist
- I have checked that my issue isn't already filed (see open issues)
- I have included my system information
- Any code provided is minimal, complete, and reproducible (more on MREs)
- Any code provided is properly formatted in code blocks (no screenshots, more on code blocks)
- Any traceback provided is complete