梯度下降在大模型訓練中的作用與實現

梯度下降(Gradient Descent)是深度學習中最核心的優化算法之一。大模型(如GPT、BERT)在訓練時需要優化數十億甚至上千億的參數,而梯度下降及其變體(如SGD、Adam)正是實現這一優化的關鍵工具。它通過計算損失函數相對于參數的梯度,并沿梯度負方向迭代更新參數,從而最小化損失。

梯度下降解決的問題

在大模型訓練中,我們需要最小化一個高維、非凸的損失函數。梯度下降的目標就是找到損失函數的局部甚至全局最優點,以使模型在訓練數據和測試數據上表現良好。

主要解決的問題包括:

損失最小化:通過迭代不斷減少模型預測與真實值之間的誤差。

收斂效率:改進的優化算法(如Adam)可以加速收斂。

避免困在鞍點:高維空間中鞍點比局部極小值更常見,因此優化器需具備跳出鞍點的能力。

2. 原理與數學推導

2.1 基本公式

梯度下降的更新規則為:

公式如下:

θt+1=θt?η??θL(θt) \theta_{t+1} = \theta_t - \eta \cdot \nabla_\theta L(\theta_t) θt+1?=θt??η??θ?L(θt?)

其中:

  • θ\thetaθ 是模型參數;
  • L(θ)L(\theta)L(θ) 是損失函數;
  • η\etaη 是學習率(Learning Rate);
  • ?θL\nabla_\theta L?θ?L 是損失函數對參數的梯度。

2.2 損失函數的幾何意義

損失函數可以看作一個“地形”,梯度下降就是沿著最陡峭的下坡路一步步走到山谷底部(全局或局部最小值)。


3. 梯度下降的種類與應用

算法特點適用場景
Batch GD使用全量數據,穩定但計算量大小數據集
SGD每次用一個樣本,更新快但噪聲大深度學習初期
Mini-Batch GD折中方案,批量樣本大模型訓練首選

4. 在大模型訓練中的實踐

  • 優化器:Adam / AdamW 廣泛用于 LLM 訓練;
  • Loss:交叉熵(Cross Entropy)是語言建模的常見選擇;
  • 技巧:學習率調度(Warm-up)、梯度裁剪(Gradient Clipping)、正則化(Weight Decay)。

5. 可視化示例:梯度下降過程

以下示例演示了如何用 Python + Matplotlib 畫出梯度下降在二維損失曲面上的收斂軌跡

import numpy as np
import matplotlib.pyplot as plt# 損失函數: f(x) = x^2 + 2x + 1
def loss(x):return x**2 + 2*x + 1# 梯度: f'(x) = 2x + 2
def grad(x):return 2*x + 2# 參數初始化
x = 5.0
eta = 0.2  # 學習率
history = [x]# 迭代梯度下降
for _ in range(15):x -= eta * grad(x)history.append(x)# 繪圖
xs = np.linspace(-4, 6, 100)
ys = loss(xs)plt.figure(figsize=(8,4))
plt.plot(xs, ys, label="Loss Curve")
plt.scatter(history, [loss(h) for h in history], c="red", label="Steps", zorder=5)
plt.title("Gradient Descent Optimization Path")
plt.xlabel("Parameter x")
plt.ylabel("Loss")
plt.legend()
plt.grid(True)
plt.show()

運行后會顯示

  • 藍色曲線:損失函數 L(x)=x2+2x+1L(x)=x^2+2x+1L(x)=x2+2x+1
  • 紅點:梯度下降的更新軌跡,逐步逼近最小值。

6. 圖示(直觀理解)

損失 L(θ)
│       ?            ← 初始參數 θ0
│     ?
│   ?
│ ?
└──────────────────────────→ 參數 θ

7. 示例:PyTorch 訓練循環(簡化版)

import torch
import torch.nn as nn
import torch.optim as optim# 簡單線性模型 y = wx + b
model = nn.Linear(1, 1)
criterion = nn.MSELoss()
optimizer = optim.AdamW(model.parameters(), lr=0.01)x = torch.randn(100, 1)
y = 3 * x + 1 + 0.1 * torch.randn(100, 1)for epoch in range(100):optimizer.zero_grad()y_pred = model(x)loss = criterion(y_pred, y)loss.backward()optimizer.step()if epoch % 10 == 0:print(f"Epoch {epoch}: Loss = {loss.item():.4f}")

這段代碼模擬了一個使用 AdamW + MSE Loss 的小型訓練過程。

7. Jupyter Notebook詳細版本

可視化與軌跡演示的demo示意

pip install numpy matplotlib torch pillow
import matplotlib
matplotlib.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'SimHei']  # Mac/Windows 中文字體
matplotlib.rcParams['axes.unicode_minus'] = Falseimport numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import torch
import torch.nn as nn
import torch.optim as optim#############################
# 1. 一維梯度下降動畫
#############################def loss_1d(x):return x**2 + 2*x + 1def grad_1d(x):return 2*x + 2x_init = 5.0
eta = 0.2
steps = [x_init]
x = x_init
for _ in range(15):x -= eta * grad_1d(x)steps.append(x)xs = np.linspace(-4, 6, 200)
ys = loss_1d(xs)
plt.figure(figsize=(8,4))
plt.plot(xs, ys, label="Loss Curve")
plt.scatter(steps, [loss_1d(s) for s in steps], c="red", label="Steps", zorder=5)
plt.title("1D 梯度下降路徑")
plt.xlabel("參數 x")
plt.ylabel("損失 Loss")
plt.legend()
plt.grid(True)
plt.show()fig, ax = plt.subplots()
ax.plot(xs, ys, label="Loss Curve")
point, = ax.plot([], [], 'ro')
ax.legend()
ax.set_title("1D 梯度下降動畫")
ax.set_xlabel("參數 x")
ax.set_ylabel("損失 Loss")def init():point.set_data([], [])return point,def update(frame):x_val = steps[frame]y_val = loss_1d(x_val)point.set_data([x_val], [y_val])return point,ani = animation.FuncAnimation(fig, update, frames=len(steps), init_func=init, blit=True)
plt.close(fig)
ani.save("gradient_descent_1d.gif", writer="pillow", fps=2)#############################
# 2. 三維損失曲面 + 路徑
#############################def loss_2d(w):x, y = wreturn x**2 + y**2 + x*y + 2*x + 3*y + 5def grad_2d(w):x, y = wreturn np.array([2*x + y + 2, 2*y + x + 3])eta = 0.1
w = np.array([4.0, 4.0])
path = [w.copy()]
for _ in range(30):w -= eta * grad_2d(w)path.append(w.copy())X = np.linspace(-5, 5, 50)
Y = np.linspace(-5, 5, 50)
X, Y = np.meshgrid(X, Y)
Z = loss_2d([X, Y])fig = plt.figure(figsize=(8,6))
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(X, Y, Z, cmap='viridis', alpha=0.7)
path = np.array(path)
ax.plot(path[:,0], path[:,1], [loss_2d(p) for p in path], 'r-o')
ax.set_title("3D 損失曲面與梯度下降路徑")
plt.show()#############################
# 3. 優化器對比:SGD vs Adam
#############################torch.manual_seed(0)
X = torch.randn(200,1)
y = 3*X + 1 + 0.1*torch.randn(200,1)def build_model():return nn.Linear(1,1)def train(optimizer_type, lr=0.01):model = build_model()criterion = nn.MSELoss()optimizer = optimizer_type(model.parameters(), lr=lr)losses = []for epoch in range(50):optimizer.zero_grad()y_pred = model(X)loss = criterion(y_pred, y)loss.backward()optimizer.step()losses.append(loss.item())return lossesloss_sgd = train(optim.SGD, lr=0.05)
loss_adam = train(optim.Adam, lr=0.01)plt.figure(figsize=(8,4))
plt.plot(loss_sgd, label="SGD")
plt.plot(loss_adam, label="Adam")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("優化器收斂速度對比:SGD vs Adam")
plt.legend()
plt.grid(True)
plt.show()

在這里插入圖片描述

在這里插入圖片描述

在這里插入圖片描述

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/917334.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/917334.shtml
英文地址,請注明出處:http://en.pswp.cn/news/917334.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【JVS更新日志】開源框架、APS排產、企業計劃、物聯網、邏輯引擎7.30更新說明!

項目介紹 JVS是企業級數字化服務構建的基礎腳手架,主要解決企業信息化項目交付難、實施效率低、開發成本高的問題,采用微服務配置化的方式,提供了低代碼數據分析物聯網的核心能力產品,并構建了協同辦公、企業常用的管理工具等&…

Eclipse中導入新項目,右鍵項目沒有Run on Server,Tomcat的add and remove找不到項目

原因分析沒有勾選Dynamic Web Module、Java、JavaScriptDynamic Web Module版本問題解決方法Eclipse中右鍵項目選擇Properties左側點擊project facets勾選Dynamic Web Module、Java、JavaScript,注意Dynamic Web Module版本問題,要和tomcat版本對應。- Dynamic Web …

IntelliJ IDEA 2025系列通用軟件安裝教程(Windows版)

前言 JetBrains系列開發工具(如IntelliJ IDEA、PyCharm、WebStorm等)是程序員們非常喜愛的集成開發環境。2025年最新版本帶來了更多強大的功能和改進。本教程將詳細介紹如何在Windows系統上安裝JetBrains 2025系列軟件。 最近挖到一個寶藏級人工智能學習…

烏鶇科技前端二面

1. 你能給我介紹一下你參與的重要項目,并重點介紹一下做的內容?通俗解釋: 挑一個你覺得最拿得出手、技術含量最高的項目,說說這個項目是干什么的(比如一個電商網站、一個后臺管理系統),你在里面具體負責了…

《c++面向對象入門與實戰》筆記

前年的書,翻出來整理一下7章.指針指針 sizeof為4*指針 sizeof為 所指類型的sizeof注意free后置空,避免野指針11章.類

easyExcel生成多個sheet的動態表頭的實現

在使用 EasyExcel 實現“多個 Sheet 且每個 Sheet 表頭是動態的”需求時&#xff0c;思路如下&#xff1a;? 實現思路概述 EasyExcel 的 ExcelWriter 支持多個 Sheet 寫入。每個 Sheet&#xff1a; 使用 WriteSheet 創建&#xff1b;可以綁定一個動態生成的表頭 List<List&…

SQL 連接類型示例:內連接與外連接

SQL 連接類型示例&#xff1a;內連接與外連接 示例數據表 假設我們有兩個表&#xff1a; employees 表:emp_idemp_namedept_id1張三1012李四1023王五1034趙六NULLdepartments 表:dept_iddept_name101銷售部102技術部104財務部1. 內連接 (INNER JOIN) 內連接只返回兩個表中匹配的…

Ubuntu安裝gpu驅動,cuda

系統初始化 1、安裝基礎軟件 apt-get update apt-get -y install openssh-server openssh-client apt-utils freeipmi ipmitool sshpass ethtool zip unzip nano less git netplan.io iputils-ping mtr ipvsadm smartmontools python3-pip socat conntrack libvirt-clients li…

ctfshow_源碼壓縮包泄露

根據題目信息直接dirsearch解壓下來一個.txt文件&#xff0c;一個index.phpflag{flag_here}不對那么就去看index.php也沒有東西&#xff0c;于是查看wp發現是訪問/fl000g.txt這才是對的還有很多源碼泄露需要去了解? git源碼泄露? svn源碼泄露? DS_Store 文件泄露? 網站備份…

Python 程序設計講義(54):Python 的函數——函數概述

Python 程序設計講義&#xff08;54&#xff09;&#xff1a;Python 的函數——函數概述 目錄Python 程序設計講義&#xff08;54&#xff09;&#xff1a;Python 的函數——函數概述一、函數的類型1、內置函數2、自定義函數二、調用函數Python 提供了函數機制&#xff0c;把實…

學習Python中Selenium模塊的基本用法(3:下載瀏覽器驅動續)

前一篇文章主要介紹下載針對火狐瀏覽器的WebDriver&#xff0c;寫那篇文章時才找到能夠下最新版本Chrome的WebDriver地址&#xff08;參考文獻6&#xff09;&#xff0c;本文繼續學習并驗證針對Chrome瀏覽器的WebDriver下載和使用方法。Chrome的WebDriver版本與操作系統相關&am…

AIDL當Parcelable序列化的數據類通信時報“Class not found when unmarshalling“找不到該類時的解決方案

1. 報錯棧 &#xff1a;cusText這個類找不到 2 16:01:29.796 1044 5718 E Parcel : Class not found when unmarshalling: com.cus.sdk.cusText 08-02 16:01:29.796 1044 5718 E Parcel : java.lang.ClassNotFoundException: com.cus.sdk.cusText 08-02 16:01:29.796 1…

Django模型查詢與性能調優:告別N+1問題

文章目錄一、查詢基礎QuerySet 詳解一對多關聯查詢多對多關聯查詢二、N1查詢問題問題分析檢測方法解決方案三、高級查詢優化values()values_list()values()和values_list()對比Q() 對象復雜查詢查看生成的 SQL四、項目實戰場景實戰一、查詢基礎 QuerySet 詳解 Django 中通過模…

PyTorch 中 Tensor 統計學函數及相關概念

文章目錄PyTorch 中 Tensor 統計學函數及相關概念一、引言二、基礎統計學函數&#xff08;一&#xff09;torch.mean()——均值計算&#xff08;二&#xff09;torch.sum()——總和計算&#xff08;三&#xff09;torch.prod()——元素積計算&#xff08;四&#xff09;torch.m…

淺拷貝與深拷貝的區別

淺拷貝和深拷貝是兩種不同的對象復制方式&#xff0c;主要區別在于它們如何處理對象內部的引用類型字段。淺拷貝 (Shallow Copy)特點&#xff1a;只復制對象本身&#xff08;基本類型字段&#xff09;和對象中的引用&#xff08;地址&#xff09;不復制引用指向的實際對象原始對…

腳本統計MongoDB集合表數據量

腳本&#xff1a; #!/bin/bashipxxx.xx.xx.xx portxxxx dbxxxdb #user #passwmongo -host ${ip}:${port} <<EOF 2>/dev/null|grep -vE version|not match|session|compressors||Warning|delivers|upcoming|installation|https|switched|bye >collec use ${db}; sho…

圖漾AGV行業常用相機使用文檔

文章目錄1.圖漾相機設置IP1.1 前期準備2.FM851-E2相機2.1 FM851-E2適用場景2.2 FM851-E2 IO線和數據線定義2.2.1 IO接口定義2.2.2 數據接口線2.2.3 相機正面安裝方向2.2.4 相機IO指示燈2.3 FM851-E2/FM855-E2-7相機RGB顏色異常【解決措施1】&#xff1a;【解決措施2】&#xff…

電力系統分析學習筆記(二)- 標幺值計算與變壓器建模

電力系統分析學習筆記&#xff08;二&#xff09;- 標幺值計算與變壓器建模 1. 電力系統參數計算的基本原理 1.1 基本級的概念與選擇 基本級定義&#xff1a; 在多電壓等級的電力系統中&#xff0c;需要將所有參數歸算到同一個電壓等級這個統一的電壓等級稱為基本級 基本級選擇…

防火墻相關技術內容

防火墻的狀態檢測和會話技術一、防火墻的檢測機制早期包過濾防火墻采用逐包檢測機制&#xff0c;對每個報文獨立檢測其源地址、目的地址、端口等信息&#xff0c;根據預設規則決定轉發或丟棄。安全隱患&#xff1a;僅基于單包信息判斷&#xff0c;無法識別連接狀態。例如&#…

在 Mac 上用 Vagrant 安裝 K8s

文章目錄&#x1f4cb; 1. 環境準備1.1 系統要求1.2 軟件清單&#x1f680; 2. 安裝步驟2.1 安裝Parallels Desktop2.2 配置網絡代理&#xff08;可選&#xff09;2.3 安裝Homebrew2,4 準備項目目錄2.5 安裝Vagrant及插件2.6 配置Python環境2.6.1 安裝Python管理工具2.6.2 配置…