深入解析 Loss 減少方式:mean和sum的區別及其在大語言模型中的應用 (中英雙語)

深入解析 Loss 減少方式:meansum 的區別及其在大語言模型中的應用

在訓練大語言模型(Large Language Models, LLM)時,損失函數(Loss Function)的處理方式對模型的性能和優化過程有顯著影響。本文以 reduce_loss 參數為例,詳細探討 meansum 兩種方式的定義、適用場景及其對對話模型性能的潛在提升原因,并通過代碼實例加深理解。


1. 什么是 reduce_loss

reduce_loss 決定了在每個 batch 中,如何對 token-level 的損失進行歸一化或累加處理。常見的選項是:

  • mean: 取每個 token 損失的平均值。
  • sum: 將每個 token 損失直接累加。

參數定義示例(在代碼中通過 dataclass 定義):參考來源:https://github.com/allenai/open-instruct

from dataclasses import dataclass, field@dataclass
class TrainingArguments:reduce_loss: str = field(default="mean",metadata={"help": ("How to reduce loss over tokens. Options are 'mean' or 'sum'.""Using 'sum' can improve chat model performance.")},)

2. meansum 的定義

2.1 mean 模式
  • 定義:將 batch 中所有 token 的損失值取平均。
  • 公式
    Loss mean = ∑ i = 1 N Loss i N \text{Loss}_{\text{mean}} = \frac{\sum_{i=1}^{N} \text{Loss}_i}{N} Lossmean?=Ni=1N?Lossi??
    其中 ( N N N) 是當前 batch 中的 token 總數。
  • 特性:每個 token 的損失對最終的 loss 貢獻相等,損失值與 batch 中的 token 數無關。
2.2 sum 模式
  • 定義:將 batch 中所有 token 的損失值直接累加。
  • 公式
    Loss sum = ∑ i = 1 N Loss i \text{Loss}_{\text{sum}} = \sum_{i=1}^{N} \text{Loss}_i Losssum?=i=1N?Lossi?
  • 特性:長序列(更多 token)的損失對總 loss 的貢獻更大,損失值直接與 token 數成正比。

3. meansum 的區別

模式特點優點缺點
mean損失對 token 數歸一化,獨立于 batch size。穩定性強,適用于 token 數差異大的批次。長序列與短序列對損失的貢獻相同,可能弱化長序列的重要性。
sum損失值與 token 總數成正比,長序列貢獻更大。在注重長序列表現的任務中效果更好(如對話生成)。損失值隨 batch size 變化波動,需要動態調整學習率。

4. 適用場景分析

4.1 mean
  • 適用任務:大多數語言建模任務,如 GPT 或 BERT 的預訓練。
  • 適用場景:當訓練數據中序列長度差異較大時,mean 可以避免因長序列的損失值過大而導致梯度更新不均衡。
4.2 sum
  • 適用任務:對長序列表現要求較高的任務,如對話生成(Chat Models)和長文本生成。
  • 適用場景:長序列的損失占比更高,從而使優化過程更加關注全局上下文的建模。

5. 為什么 sum 能提升對話模型性能?

對話模型(Chat Models)的訓練中,長序列往往包含豐富的上下文信息,而短序列則可能無法體現模型的上下文理解能力。在 sum 模式下:

  1. 長序列的重要性增加:長序列的損失對總損失的貢獻更大,這促使模型更關注上下文的建模。
  2. 對全局一致性更敏感sum 模式下,模型的優化方向更傾向于全序列的一致性,特別適合需要長距離依賴的任務。

示例
假設一個 batch 包含以下兩個樣本:

  • 樣本 A: 長度為 10,損失總和為 5。
  • 樣本 B: 長度為 50,損失總和為 25。

計算損失貢獻:

  • mean 模式
    Loss mean = 5 + 25 10 + 50 = 0.5 \text{Loss}_{\text{mean}} = \frac{5 + 25}{10 + 50} = 0.5 Lossmean?=10+505+25?=0.5
    樣本 A 和 B 的貢獻權重相同。
  • sum 模式
    Loss sum = 5 + 25 = 30 \text{Loss}_{\text{sum}} = 5 + 25 = 30 Losssum?=5+25=30
    樣本 B 的貢獻權重顯著增加,優化更關注長序列。

6. 實戰代碼

以下是一個完整的訓練腳本,展示如何在 Hugging Face 的 transformers 框架中使用 reduce_loss 參數。

from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
import torch# 模型和數據集
model_name = "meta-llama/Llama-3.1-8B"
dataset_name = "allenai/tulu-3-sft-mixture"model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)dataset = load_dataset(dataset_name)
tokenized_dataset = dataset.map(lambda x: tokenizer(x['text'], truncation=True, padding="max_length"), batched=True)
train_loader = DataLoader(tokenized_dataset["train"], batch_size=2, shuffle=True)# 訓練設置
reduce_loss = "sum"  # 改為 "mean" 可對比效果
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)# 訓練循環
for epoch in range(2):for batch in train_loader:inputs = torch.tensor(batch["input_ids"]).to(device)labels = inputs.clone()outputs = model(inputs, labels=labels)if reduce_loss == "sum":loss = outputs.loss.sum()else:  # 默認 "mean"loss = outputs.loss.mean()loss.backward()optimizer.step()optimizer.zero_grad()print(f"Epoch: {epoch}, Loss: {loss.item()}")

7. 注意事項與優化建議

  1. 動態調整學習率

    • 使用 sum 時,由于損失值放大,建議適配學習率,如降低到 mean 模式的 ( 1 / N 1/N 1/N )。
    • 配合學習率調度器(如 linear)優化訓練。
  2. 對長短序列的平衡

    • 若長序列權重過大導致模型性能退化,可結合 curriculum learning 或混合訓練策略(如對長短序列按比例采樣)。
  3. 性能評估

    • 在驗證集上,關注長序列和短序列的生成性能對比。

8. 總結

reduce_loss 的選擇對模型性能有直接影響:

  • mean 更通用,適合大多數語言建模任務。
  • sum 在對話生成等長序列敏感任務中表現更優。

希望本文能為 LLM 研究人員提供思路和參考,在具體任務中靈活選擇合適的損失歸一化方式,從而提升模型性能。

Understanding the Difference Between mean and sum Loss Reduction in LLM Training

When training large language models (LLMs), the way token-level loss is reduced across a batch can significantly impact optimization and model performance. This article delves into the reduce_loss parameter, exploring the differences between mean and sum reduction modes, their definitions, use cases, and why sum might improve the performance of chat-oriented models. Practical code examples are also provided for clarity.


1. What is reduce_loss?

The reduce_loss parameter determines how the token-level loss values in a batch are aggregated. The two most common options are:

  • mean: Averages the loss over all tokens in a batch.
  • sum: Sums the loss of all tokens in a batch.

Example definition (from the codebase using Python dataclass):

from dataclasses import dataclass, field@dataclass
class TrainingArguments:reduce_loss: str = field(default="mean",metadata={"help": ("How to reduce loss over tokens. Options are 'mean' or 'sum'.""Using 'sum' can improve chat model performance.")},)

2. Definitions of mean and sum

2.1 mean
  • Definition: Averages the loss across all tokens in a batch.
  • Formula:
    Loss mean = ∑ i = 1 N Loss i N \text{Loss}_{\text{mean}} = \frac{\sum_{i=1}^{N} \text{Loss}_i}{N} Lossmean?=Ni=1N?Lossi??
    where ( N N N ) is the total number of tokens in the batch.
  • Characteristics: The contribution of each token to the final loss is normalized, making the loss independent of the batch’s token count.
2.2 sum
  • Definition: Sums up the loss across all tokens in a batch.
  • Formula:
    Loss sum = ∑ i = 1 N Loss i \text{Loss}_{\text{sum}} = \sum_{i=1}^{N} \text{Loss}_i Losssum?=i=1N?Lossi?
  • Characteristics: The total loss is proportional to the number of tokens, giving longer sequences more weight in the optimization process.

3. Key Differences Between mean and sum

Reduction ModeCharacteristicsAdvantagesDisadvantages
meanNormalizes the loss by token count.Stable and robust for datasets with variable-length sequences.Long sequences are underweighted relative to short ones.
sumLoss scales with the number of tokens.Places greater emphasis on longer sequences, improving performance in tasks requiring context modeling.Loss values vary with batch size, necessitating dynamic learning rate adjustment.

4. Use Cases for mean and sum

4.1 mean
  • Best Suited For: Pretraining or general language modeling tasks like GPT or BERT.
  • Scenario: When the dataset contains sequences of widely varying lengths, mean ensures that longer sequences do not disproportionately influence gradient updates.
4.2 sum
  • Best Suited For: Tasks requiring high performance on long sequences, such as dialogue generation or document-level text generation.
  • Scenario: Encourages the model to prioritize sequences with richer contexts, as their loss contributes more to the overall optimization.

5. Why Does sum Improve Chat Model Performance?

In chat-oriented models, sequences are typically longer and require the model to understand and generate coherent responses over extended contexts. Using sum mode:

  1. Enhances Long Sequence Weighting: Longer sequences contribute more to the total loss, emphasizing the importance of context modeling.
  2. Encourages Global Consistency: By assigning more weight to longer contexts, the model better captures dependencies across the entire sequence.
  3. Balances Token Importance: Since chat models are often evaluated on dialogue-level coherence, sum ensures that tokens from the context and the response are proportionally weighted.

Example:
Consider a batch with two samples:

  • Sample A: Sequence length = 10, loss = 5.
  • Sample B: Sequence length = 50, loss = 25.

Loss calculations:

  • mean mode:
    Loss mean = 5 + 25 10 + 50 = 0.5 \text{Loss}_{\text{mean}} = \frac{5 + 25}{10 + 50} = 0.5 Lossmean?=10+505+25?=0.5
    Both samples contribute equally to the loss.
  • sum mode:
    Loss sum = 5 + 25 = 30 \text{Loss}_{\text{sum}} = 5 + 25 = 30 Losssum?=5+25=30
    Sample B contributes much more to the total loss, focusing the optimization on longer contexts.

6. Practical Implementation

Here’s a practical training script that demonstrates the use of reduce_loss in both modes.

from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
import torch# Model and dataset
model_name = "meta-llama/Llama-3.1-8B"
dataset_name = "allenai/tulu-3-sft-mixture"model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)dataset = load_dataset(dataset_name)
tokenized_dataset = dataset.map(lambda x: tokenizer(x['text'], truncation=True, padding="max_length"), batched=True)
train_loader = DataLoader(tokenized_dataset["train"], batch_size=2, shuffle=True)# Training setup
reduce_loss = "sum"  # Change to "mean" to compare effects
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)# Training loop
for epoch in range(2):for batch in train_loader:inputs = torch.tensor(batch["input_ids"]).to(device)labels = inputs.clone()outputs = model(inputs, labels=labels)if reduce_loss == "sum":loss = outputs.loss.sum()else:  # Default: "mean"loss = outputs.loss.mean()loss.backward()optimizer.step()optimizer.zero_grad()print(f"Epoch: {epoch}, Loss: {loss.item()}")

7. Practical Considerations

  1. Learning Rate Adjustment:

    • When using sum, the loss magnitude increases with batch size, so you may need to adjust the learning rate (e.g., scale it down by ( 1 / N 1/N 1/N )).
  2. Balancing Long and Short Sequences:

    • Overweighting long sequences can sometimes harm generalization. Using curriculum learning or sampling strategies (e.g., proportional sampling) can help mitigate this.
  3. Validation:

    • Evaluate model performance on both short and long sequences to confirm improvements in the intended metrics.

8. Conclusion

The choice between mean and sum loss reduction modes depends on the specific task and dataset:

  • Use mean for general-purpose language modeling tasks where sequence lengths vary significantly.
  • Use sum for tasks that prioritize long-sequence performance, such as chat models or long-text generation.

Understanding and experimenting with these settings can lead to better-optimized models, particularly in the nuanced field of LLM fine-tuning.

后記

2024年12月3日16點04分于上海,在GPT4o大模型輔助下完成。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/62183.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/62183.shtml
英文地址,請注明出處:http://en.pswp.cn/web/62183.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

基于 AutoFlow 快速搭建基于 TiDB 向量搜索的本地知識庫問答機器人

導讀 本文將詳細介紹如何通過 PingCAP 開源項目 AutoFlow 實現快速搭建基于 TiDB 的本地知識庫問答機器人。如果提前準備好 Docker、TiDB 環境,整個搭建過程估計在 10 分鐘左右即可完成,無須開發任何代碼。 文中使用一篇 TiDB 文檔作為本地數據源作為示…

生信技能63 - 構建gnomAD變異位點的SQLite查詢數據庫

將數據量巨大的gnomAD數據庫,通過SQLite數據庫尋找gnomAD中存在的各種變異注釋信息(如等位基因計數,深度,次要等位基因頻率等),查詢300.000個變量的查詢需要大約40秒,通過染色體編號+位置+REF+ALT即可進行快速查詢。 1. gnomAD變異注釋VCF文件字段 gnomAD VCF各版本包…

【前端】將vue的方法掛載到window上供全局使用,也方便跟原生js做交互

【前端】將vue的方法掛載到window上供全局使用&#xff0c;也方便跟原生js做交互 <template><div><el-button click"start">調用方法</el-button></div> </template> <script> // import { JScallbackProc } from ./JScal…

基于XML的AOP開發

AOP 為 Aspect Oriented Programming 的縮寫&#xff0c;意思為面向切面編程。 AOP相關術語&#xff1a; 目標對象(Target)&#xff1a; 你要去代理的對象&#xff0c;可以理解為之前很單純的那個對象。 代理對象(Proxy)&#xff1a; 你把你那個單純的對象給我&#xff0c…

記錄blender學習過程中遇到的問題

物體發射的方向不對 被發射物體&#xff08;例如一棵樹&#xff09;n鍵看旋轉歸0 切換正視圖 將被發射物體的局部坐標的Z軸 指向 全局方向的X軸時 并且把粒子系統設置的物體旋轉勾選上 方向就對了 做倒角發現有問題 檢查縮放應用、面朝向、有沒有重合點&#xff08;融合點&am…

Ubuntu系統中Redis的安裝步驟及服務配置

目錄 內容概括 系統環境 安裝方式 1、apt包管理器安裝 &#xff08;1&#xff09;安裝redis服務 &#xff08;2&#xff09;安裝客戶端&#xff08;進入命令行操作使用&#xff0c;包含redis-cli&#xff09; &#xff08;3&#xff09;安裝檢驗 &#xff08;4&#xf…

半導體設備中的微型導軌應如何選擇合適的潤滑油?

微型導軌的潤滑對于保證其高精度和高穩定性至關重要&#xff0c;尤其是在半導體設備中&#xff0c;微型導軌的潤滑油選擇需要考慮多個因素&#xff0c;以確保設備的最佳性能和壽命。以下是一些關鍵點&#xff1a; 1、黏度&#xff1a;潤滑油的黏度是影響其流動性和潤滑效果的重…

RocketMq詳解:六、RocketMq的負載均衡機制

上一章&#xff1a;《SpringBootAop實現RocketMq的冪等》 文章目錄 1.背景1.1 什么是負載均衡1.2 負載均衡的意義 2.RocketMQ消息消費2.1 消息的流轉過程2.2 Consumer消費消息的流程 3.RocketMq的負載均衡策略3.1 Broker負載均衡3.2 Producer發送消息負載均衡3.3 消費端的負載均…

yocto的xxx.bb文件在什么時候會拷貝文件到build目錄

在 Yocto 中&#xff0c;.bb 文件用于描述如何構建和安裝一個軟件包&#xff0c;而文件在構建過程中的拷貝操作通常會在某些特定的步驟中進行。具體來說&#xff0c;文件會在以下幾個階段被拷貝到 build 目錄&#xff08;或者更準確地說&#xff0c;拷貝到目標目錄 ${D}&#x…

主打極致性價比,AMD RX 8600/8800顯卡定了

*以下內容僅為網絡爆料及傳聞&#xff0c;一切以官方消息為準。 這誰能想到&#xff0c;率先掏出下一代桌面獨立顯卡的不是老大哥 NVIDIA&#xff0c;也不是 AMD&#xff0c;反而是三家中存在感最弱的 Intel&#xff01; 就在 12 月 3 日&#xff0c;Intel 正式發布了自家第二…

數組哪些方法會觸發Vue監聽,哪些不會觸發監聽

發現寶藏 前些天發現了一個巨牛的人工智能學習網站&#xff0c;通俗易懂&#xff0c;風趣幽默&#xff0c;忍不住分享一下給大家。【寶藏入口】。 在 Vue 中&#xff0c;數組的變化是通過 響應式 系統來監聽的。Vue 使用 getter 和 setter 來追蹤數組的變化&#xff0c;并在數…

npm, yarn, pnpm之間的區別

前言 在現代化的開發中&#xff0c;一個人可能同時開發多個項目&#xff0c;安裝的項目越來越多&#xff0c;所隨之安裝的依賴包也越來越臃腫&#xff0c;而且有時候所安裝的速度也很慢&#xff0c;甚至會安裝失敗。 因此我們就需要去了解一下&#xff0c;我們的包管理器&#…

工業檢測基礎-工業相機選型及應用場景

以下是一些常見的工業檢測相機種類、檢測原理、應用場景及選型依據&#xff1a; 2D相機 檢測原理&#xff1a;基于二維圖像捕獲&#xff0c;通過分析圖像的明暗、紋理、顏色等信息來檢測物體的特征和缺陷.應用場景&#xff1a;廣泛應用于平面工件的外觀檢測&#xff0c;如檢測…

C語言連接數據庫

文章目錄 一、初始化數據庫二、創建數據庫連接三、執行增刪改查語句1、增刪改2、查 四、執行增刪改查語句 接下來我簡單的介紹一下怎么用C語言連接數據庫。 初始化數據庫創建數據庫連接執行增刪改查語句關閉數據庫連接 一、初始化數據庫 // 數據庫初始化 MYSQL mysql; MYSQL* r…

優化LabVIEW數據運算效率的方法

在LabVIEW中進行大量數據運算時&#xff0c;提升計算效率并減少時間占用是開發過程中常遇到的挑戰。為此&#xff0c;可以從多個角度著手優化&#xff0c;包括合理選擇數據結構與算法、并行處理、多線程技術、硬件加速、內存管理和界面優化等。通過采用這些策略&#xff0c;可以…

開源模型應用落地-安全合規篇-用戶輸入價值觀判斷(四)

一、前言 在深度合規功能中,對用戶輸入內容的價值觀判斷具有重要意義。這一功能不僅僅是對信息合法性和合規性的簡單審核,更是對信息背后隱含的倫理道德和社會責任的深刻洞察。通過對價值觀的判斷,系統能夠識別可能引發不當影響或沖突的內容,從而為用戶提供更安全、更和諧的…

計算機的錯誤計算(一百七十六)

摘要 利用某一大語言模型計算 的值&#xff0c;輸出為 0 . 例1. 在某一大語言模型下&#xff0c;計算 的值。其中sin中值取弧度。結果保留16位有效數字。 直接貼圖吧&#xff1a; 點評&#xff1a; &#xff08;1&#xff09;以上為一個大模型給的答案。從其回答可知&…

數據結構與算法——1204—遞歸分治法

1、斐波那契數列優化 使用滾動變量&#xff0c;保存當前計算結果和前兩項值 (1)RAB (2)更新計算對象&#xff0c;AB&#xff0c;BR #include<iostream> using namespace std;int fun(int n) {if (n 0)return 0;if (n 1 || n 2)return 1;int num11;int num21;int su…

openstack內部rpc消息通信源碼分析

我們知道openstack內部消息隊列基于AMQP協議&#xff0c;默認使用的rabbitmq 消息隊列。談到rabbitmq&#xff0c;大家或許并不陌生&#xff0c;但或許會對oslo message有些陌生。openstack內部并不是直接使用rabbitmq&#xff0c;而是使用了oslo.message 。oslo.message 后端的…

Python 3 和 MongoDB 的集成使用

Python 3 和 MongoDB 的集成使用 MongoDB 是一個流行的 NoSQL 數據庫&#xff0c;以其靈活的數據模型和強大的查詢功能而聞名。Python 3 作為一種廣泛使用的編程語言&#xff0c;與 MongoDB 的集成變得日益重要。本文將介紹如何在 Python 3 環境中集成和使用 MongoDB&#xff…