梯度下降(Gradient Descent)是深度學習中最核心的優化算法之一。大模型(如GPT、BERT)在訓練時需要優化數十億甚至上千億的參數,而梯度下降及其變體(如SGD、Adam)正是實現這一優化的關鍵工具。它通過計算損失函數相對于參數的梯度,并沿梯度負方向迭代更新參數,從而最小化損失。
梯度下降解決的問題
在大模型訓練中,我們需要最小化一個高維、非凸的損失函數。梯度下降的目標就是找到損失函數的局部甚至全局最優點,以使模型在訓練數據和測試數據上表現良好。
主要解決的問題包括:
損失最小化:通過迭代不斷減少模型預測與真實值之間的誤差。
收斂效率:改進的優化算法(如Adam)可以加速收斂。
避免困在鞍點:高維空間中鞍點比局部極小值更常見,因此優化器需具備跳出鞍點的能力。
2. 原理與數學推導
2.1 基本公式
梯度下降的更新規則為:
公式如下:
θt+1=θt?η??θL(θt) \theta_{t+1} = \theta_t - \eta \cdot \nabla_\theta L(\theta_t) θt+1?=θt??η??θ?L(θt?)
其中:
- θ\thetaθ 是模型參數;
- L(θ)L(\theta)L(θ) 是損失函數;
- η\etaη 是學習率(Learning Rate);
- ?θL\nabla_\theta L?θ?L 是損失函數對參數的梯度。
2.2 損失函數的幾何意義
損失函數可以看作一個“地形”,梯度下降就是沿著最陡峭的下坡路一步步走到山谷底部(全局或局部最小值)。
3. 梯度下降的種類與應用
算法 | 特點 | 適用場景 |
---|---|---|
Batch GD | 使用全量數據,穩定但計算量大 | 小數據集 |
SGD | 每次用一個樣本,更新快但噪聲大 | 深度學習初期 |
Mini-Batch GD | 折中方案,批量樣本 | 大模型訓練首選 |
4. 在大模型訓練中的實踐
- 優化器:Adam / AdamW 廣泛用于 LLM 訓練;
- Loss:交叉熵(Cross Entropy)是語言建模的常見選擇;
- 技巧:學習率調度(Warm-up)、梯度裁剪(Gradient Clipping)、正則化(Weight Decay)。
5. 可視化示例:梯度下降過程
以下示例演示了如何用 Python + Matplotlib 畫出梯度下降在二維損失曲面上的收斂軌跡。
import numpy as np
import matplotlib.pyplot as plt# 損失函數: f(x) = x^2 + 2x + 1
def loss(x):return x**2 + 2*x + 1# 梯度: f'(x) = 2x + 2
def grad(x):return 2*x + 2# 參數初始化
x = 5.0
eta = 0.2 # 學習率
history = [x]# 迭代梯度下降
for _ in range(15):x -= eta * grad(x)history.append(x)# 繪圖
xs = np.linspace(-4, 6, 100)
ys = loss(xs)plt.figure(figsize=(8,4))
plt.plot(xs, ys, label="Loss Curve")
plt.scatter(history, [loss(h) for h in history], c="red", label="Steps", zorder=5)
plt.title("Gradient Descent Optimization Path")
plt.xlabel("Parameter x")
plt.ylabel("Loss")
plt.legend()
plt.grid(True)
plt.show()
運行后會顯示:
- 藍色曲線:損失函數 L(x)=x2+2x+1L(x)=x^2+2x+1L(x)=x2+2x+1
- 紅點:梯度下降的更新軌跡,逐步逼近最小值。
6. 圖示(直觀理解)
損失 L(θ)
│ ? ← 初始參數 θ0
│ ?
│ ?
│ ?
└──────────────────────────→ 參數 θ
7. 示例:PyTorch 訓練循環(簡化版)
import torch
import torch.nn as nn
import torch.optim as optim# 簡單線性模型 y = wx + b
model = nn.Linear(1, 1)
criterion = nn.MSELoss()
optimizer = optim.AdamW(model.parameters(), lr=0.01)x = torch.randn(100, 1)
y = 3 * x + 1 + 0.1 * torch.randn(100, 1)for epoch in range(100):optimizer.zero_grad()y_pred = model(x)loss = criterion(y_pred, y)loss.backward()optimizer.step()if epoch % 10 == 0:print(f"Epoch {epoch}: Loss = {loss.item():.4f}")
這段代碼模擬了一個使用 AdamW + MSE Loss 的小型訓練過程。
7. Jupyter Notebook詳細版本
可視化與軌跡演示的demo示意
pip install numpy matplotlib torch pillow
import matplotlib
matplotlib.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'SimHei'] # Mac/Windows 中文字體
matplotlib.rcParams['axes.unicode_minus'] = Falseimport numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import torch
import torch.nn as nn
import torch.optim as optim#############################
# 1. 一維梯度下降動畫
#############################def loss_1d(x):return x**2 + 2*x + 1def grad_1d(x):return 2*x + 2x_init = 5.0
eta = 0.2
steps = [x_init]
x = x_init
for _ in range(15):x -= eta * grad_1d(x)steps.append(x)xs = np.linspace(-4, 6, 200)
ys = loss_1d(xs)
plt.figure(figsize=(8,4))
plt.plot(xs, ys, label="Loss Curve")
plt.scatter(steps, [loss_1d(s) for s in steps], c="red", label="Steps", zorder=5)
plt.title("1D 梯度下降路徑")
plt.xlabel("參數 x")
plt.ylabel("損失 Loss")
plt.legend()
plt.grid(True)
plt.show()fig, ax = plt.subplots()
ax.plot(xs, ys, label="Loss Curve")
point, = ax.plot([], [], 'ro')
ax.legend()
ax.set_title("1D 梯度下降動畫")
ax.set_xlabel("參數 x")
ax.set_ylabel("損失 Loss")def init():point.set_data([], [])return point,def update(frame):x_val = steps[frame]y_val = loss_1d(x_val)point.set_data([x_val], [y_val])return point,ani = animation.FuncAnimation(fig, update, frames=len(steps), init_func=init, blit=True)
plt.close(fig)
ani.save("gradient_descent_1d.gif", writer="pillow", fps=2)#############################
# 2. 三維損失曲面 + 路徑
#############################def loss_2d(w):x, y = wreturn x**2 + y**2 + x*y + 2*x + 3*y + 5def grad_2d(w):x, y = wreturn np.array([2*x + y + 2, 2*y + x + 3])eta = 0.1
w = np.array([4.0, 4.0])
path = [w.copy()]
for _ in range(30):w -= eta * grad_2d(w)path.append(w.copy())X = np.linspace(-5, 5, 50)
Y = np.linspace(-5, 5, 50)
X, Y = np.meshgrid(X, Y)
Z = loss_2d([X, Y])fig = plt.figure(figsize=(8,6))
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(X, Y, Z, cmap='viridis', alpha=0.7)
path = np.array(path)
ax.plot(path[:,0], path[:,1], [loss_2d(p) for p in path], 'r-o')
ax.set_title("3D 損失曲面與梯度下降路徑")
plt.show()#############################
# 3. 優化器對比:SGD vs Adam
#############################torch.manual_seed(0)
X = torch.randn(200,1)
y = 3*X + 1 + 0.1*torch.randn(200,1)def build_model():return nn.Linear(1,1)def train(optimizer_type, lr=0.01):model = build_model()criterion = nn.MSELoss()optimizer = optimizer_type(model.parameters(), lr=lr)losses = []for epoch in range(50):optimizer.zero_grad()y_pred = model(X)loss = criterion(y_pred, y)loss.backward()optimizer.step()losses.append(loss.item())return lossesloss_sgd = train(optim.SGD, lr=0.05)
loss_adam = train(optim.Adam, lr=0.01)plt.figure(figsize=(8,4))
plt.plot(loss_sgd, label="SGD")
plt.plot(loss_adam, label="Adam")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("優化器收斂速度對比:SGD vs Adam")
plt.legend()
plt.grid(True)
plt.show()