Compute metrics for generation tasks in SFTTrainer #862

@wei-ann-Github

Hi, I want to add a custom generation-based compute_metrics function (e.g., BLEU) to the SFTTrainer. However, I am running into difficulties because:

  1. The input to compute_metrics, eval_preds, has a .predictions attribute, but its value is logits rather than generations. There doesn't seem to be an attribute for generations.
  2. Seq2SeqTrainingArguments has a predict_with_generate argument intended for generation-based metrics, but it does not work with the SFTTrainer, so I am not able to compute the metrics that I want.
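
To illustrate point 1, here is a minimal sketch of what can be recovered from logits alone, assuming .predictions has shape (batch, seq_len, vocab) and .label_ids uses -100 for ignored positions. The helper names, the toy arrays, and pad_id=0 are hypothetical and only for illustration. Note that an argmax over these logits gives teacher-forced next-token predictions, not free-running generations, which is why it is not a real substitute for predict_with_generate.

```python
import numpy as np

IGNORE_INDEX = -100  # assumed label value for positions excluded from the loss


def logits_to_token_ids(logits: np.ndarray) -> np.ndarray:
    """Greedy (argmax) token ids from logits of shape (batch, seq_len, vocab)."""
    return logits.argmax(axis=-1)


def align_for_causal_lm(pred_ids: np.ndarray, label_ids: np.ndarray, pad_id: int):
    """Align predictions with labels for a causal LM.

    The logit at position i scores the token at position i + 1, so the last
    prediction and the first label are dropped. Ignored labels are replaced
    with pad_id so that a tokenizer could decode them.
    """
    shifted_preds = pred_ids[:, :-1]
    shifted_labels = label_ids[:, 1:]
    shifted_labels = np.where(shifted_labels != IGNORE_INDEX, shifted_labels, pad_id)
    return shifted_preds, shifted_labels


# Toy stand-ins for eval_preds.predictions and eval_preds.label_ids (hypothetical data).
logits = np.zeros((1, 4, 5))
logits[0, 0, 2] = 1.0
logits[0, 1, 3] = 1.0
logits[0, 2, 1] = 1.0
logits[0, 3, 4] = 1.0
labels = np.array([[IGNORE_INDEX, IGNORE_INDEX, 3, 1]])

pred_ids = logits_to_token_ids(logits)
preds, refs = align_for_causal_lm(pred_ids, labels, pad_id=0)
```

The resulting preds and refs could then be decoded with a tokenizer and passed to a metric such as BLEU, but the scores would reflect teacher-forced predictions only.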

Can anyone advise how I can get the generations into the compute_metrics function? I am using trl==0.5.0.

Here are my TrainingArgs and SFTTrainer.

from transformers import Seq2SeqTrainingArguments
from trl import SFTTrainer

training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    eval_steps=eval_steps,
    evaluation_strategy=evaluation_strategy,
    gradient_accumulation_steps=gradient_accumulation_steps,
    learning_rate=lr,
    max_steps=max_steps,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    save_steps=save_steps,
    save_strategy=save_strategy,
    warmup_steps=num_warmup_steps,
    weight_decay=weight_decay,
    save_total_limit=save_total_limit,
    load_best_model_at_end=True,
    metric_for_best_model=metric_for_best_model,
    greater_is_better=metric_greater_is_better,
    generation_config=generation_config,
    predict_with_generate=True,
    generation_max_length=10,
)

trainer = SFTTrainer(
    model,
    args=training_args,
    train_dataset=dataset_train,
    eval_dataset=dataset_eval,
    formatting_func=formatting_prompts_func,
    data_collator=collator,
    compute_metrics=compute_metrics,
)

Attributes of eval_preds are:

>>> dir(eval_preds)
['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__iter__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 'inputs',
 'label_ids',
 'predictions']

Thank you in advance!
