生成式任務輸入就是標簽
transformers在進入compute_metrics前會有一個判斷,源碼如下:
# 版本 transformers==4.41.2
# 在trainer.py 的 3842 行
# Metrics!
if (self.compute_metrics is not Noneand all_preds is not Noneand all_labels is not Noneand not self.args.batch_eval_metrics
):if args.include_inputs_for_metrics:metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs))else:metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
elif metrics is None:metrics = {}
生成式任務如果沒有標簽字段,即labels
那么這里的all_labels is not None
就會是false,從而無法進入compute_metrics方法。
此時可以在TrainingArguments
中加入一個變量label_names
把輸入文本作為標簽,如下:
training_args = TrainingArguments(
...
label_names=['input_ids'], # 這里假設我的文本輸入叫 ‘input_ids’
...
)
這樣就可以進入compute_metrics函數了。
此外,若需要將輸入的變量傳入compute_metrics,可以在TrainingArguments
中設置include_inputs_for_metrics=True