量化(Quantization)這詞兒聽著玄,經常和量化交易Quantitative Trading (量化交易)混淆。
其實機器學習(深度學習)領域的量化Quantization是和節約內存、提高運算效率相關的概念(因大模型的普及,這個量化問題尤為迫切)。
揭秘機器學習“量化”:不止省錢,更讓AI高效跑起來!
“量化(Quantization)”這個詞,在機器學習領域,常常讓人聯想到復雜的數學或是與金融交易相關的“量化交易”,從而感到困惑。但實際上,它與我們日常生活中的數字轉換概念更為接近,而在AI世界里,它扮演的角色是節約內存、提高運算效率的“幕后英雄”(現在已經顯露到“幕布之前”),尤其在大模型時代,其重要性日益凸顯。
那么,機器學習中的“量化”究竟是啥?咱為啥用它?
什么是機器學習中的“量化”(Quantization)?
簡單講,機器學習中的“量化”就是將模型中原本采用高精度浮點數(如32位浮點數,即FP32)表示的權重(weights)和激活值(activations),轉換成低精度表示(如8位整數,即INT8)的過程。
你可以把它想象成“數字的壓縮”。在計算機中,浮點數就像是擁有無限小數位的精確數字,而整數則像只有整數部分的數字。從高精度浮點數到低精度整數的轉換,必然會損失一些信息,但與此同時,它也帶來了顯著的優勢:
- 內存占用大幅減少: 8位整數比32位浮點數少占用4倍的內存空間。這意味著更大的模型可以被部署到內存有限的設備上(如手機、IoT設備),或者在相同內存下可以運行更大的模型。
- 計算速度顯著提升: 整數運算通常比浮點數運算更快、功耗更低。這使得模型在推理(Inference)階段能以更高的效率運行,減少延遲。
為何需要“量化”?
隨著深度學習模型變得越來越大,越來越復雜,它們對計算資源的需求也呈爆炸式增長。一個動輒幾十億甚至上百億參數的大模型,如果全部使用FP32存儲和計算,將對硬件資源提出極高的要求。
- 部署到邊緣設備: 手機、自動駕駛汽車、智能音箱等邊緣設備通常算力有限,內存緊張。量化是讓大模型“瘦身”后成功“登陸”這些設備的必經之路。
- 降低運行成本: 在云端部署大模型時,更低的內存占用和更快的計算速度意味著更低的服務器成本和能耗。
- 提升用戶體驗: 實時響應的AI應用,如語音助手、圖像識別等,對推理速度有極高要求。量化可以有效縮短響應時間。
量化策略:后訓練量化 vs. 量化感知訓練(QAT)
量化并非只有一種方式。根據量化發生的時間點,主要可以分為兩大類:
-
后訓練量化(Post-Training Quantization, PTQ): 顧名思義,PTQ 是在模型訓練完成之后,對已經訓練好的FP32模型進行量化。它操作簡單,不需要重新訓練,是實現量化的最快途徑。然而,由于量化過程中會損失精度,PTQ 可能會導致模型性能(如準確率)的下降。對于對精度要求不那么苛刻的應用,PTQ 是一個不錯的選擇。
-
量化感知訓練*(這是本文重點推介的:Quantization Aware Training, QAT): 這正是我們今天著重講解的明星策略!QAT 的核心思想是——在模型訓練過程中,就“感知”到未來的量化操作。
在QAT中,量化誤差被集成到模型的訓練循環中。這意味著,模型在訓練時就“知道”它最終會被量化成低精度,并會努力學習如何在這種低精度下保持最優性能。
具體來說,QAT通常通過在模型中插入“偽量化”(Fake Quantization)節點來實現。這些節點在訓練過程中模擬量化和反量化操作,使得模型在FP32環境下進行前向傳播和反向傳播時,能夠學習到量化對模型參數和激活值的影響。當訓練完成后,這些偽量化節點會被真正的量化操作所取代,從而得到一個高性能的量化模型。
為什么QAT是量化策略的“王牌”?
相較于PTQ……QAT 的優勢顯而易見:
- 精度損失最小: 這是QAT最大的亮點。通過在訓練過程中模擬量化,模型能夠自我調整以適應量化帶來的精度損失,從而在量化后依然保持接近FP32模型的性能。
- 適用于更苛刻的場景: 對于那些對模型精度要求極高,不能容忍明顯性能下降的應用(如自動駕駛、醫療影像分析),QAT幾乎是唯一的選擇。
- 更好的泛化能力: 在訓練階段就考慮量化,使得模型在量化后對各種輸入數據具有更好的魯棒性。
PyTorch中的QAT實踐
在PyTorch中實現QAT,通常需要以下幾個關鍵步驟:
- 準備量化配置: 定義量化類型(如INT8)、量化方法(如對稱量化、非對稱量化)以及需要量化的模塊。
- 模型轉換: 使用PyTorch提供的
torch.quantization
模塊,將普通的FP32模型轉換為QAT模型。這個過程會在模型中插入偽量化模塊。 - 重新訓練/微調: 在新的數據集上對轉換后的模型進行短時間的微調(Fine-tuning),或者在原有訓練基礎上繼續訓練。這個階段,模型會學習如何適應偽量化帶來的精度損失。
- 模型融合(可選但推薦): 將一些連續的層(如Conv-BN-ReLU)融合為一個操作,可以進一步提高量化后的推理效率。
- 模型量化和保存: 訓練完成后,將微調好的QAT模型轉換為真正的量化模型,并保存。
總結
量化(Quantization)是深度學習模型優化不可或缺的一環,它通過降低模型精度來換取內存和計算效率的大幅提升。而量化感知訓練(QAT)作為一種高級量化策略,通過在訓練階段就考慮量化對模型的影響,極大地減小了量化帶來的精度損失,使得在各種設備上部署高性能AI模型成為可能。
隨著大模型和邊緣AI的普及,掌握量化尤其是QAT的原理和實踐,將成為每一位AI工程師和研究人員的必備技能。讓我們一起,讓AI跑得更快、更高效!
import torch
import torch.nn as nn
import torch.optim as optim
import torch.quantization
import numpy as np
import os# ===== 1. XOR數據集 =====
X = torch.tensor([[0., 0.],[0., 1.],[1., 0.],[1., 1.]
], dtype=torch.float32)
y = torch.tensor([[0.],[1.],[1.],[0.]
], dtype=torch.float32)# ===== 2. 神經網絡模型 (標準FP32) =====
class XORNet(nn.Module):def __init__(self):super(XORNet, self).__init__()self.fc1 = nn.Linear(2, 3)self.relu = nn.ReLU()self.fc2 = nn.Linear(3, 1)# QAT階段不用Sigmoid,直接用BCEWithLogitsLoss# self.sigmoid = nn.Sigmoid()def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.fc2(x)# return self.sigmoid(x)return x# ===== 3. 初始化模型/優化器 =====
model = XORNet()
# He初始化,適合ReLU
for m in model.modules():if isinstance(m, nn.Linear):nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.05)# ===== 4. 訓練(標準模型) =====
print("--- 開始標準模型訓練 ---")
epochs = 900#1000#500 #1500
for epoch in range(epochs):outputs = model(X)loss = criterion(outputs, y)optimizer.zero_grad()loss.backward()optimizer.step()if (epoch + 1) % 300 == 0:print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')with torch.no_grad():probs = torch.sigmoid(model(X))predictions = (probs > 0.5).float()accuracy = (predictions == y).sum().item() / y.numel()print(f"\n標準模型訓練后精度: {accuracy*100:.2f}%")print(f"標準模型預測結果:\n{predictions}")# ===== 5. 構建QAT模型 =====
class XORNetQAT(nn.Module):def __init__(self):super(XORNetQAT, self).__init__()# 量化Stubself.quant = torch.quantization.QuantStub()self.fc1 = nn.Linear(2, 3)self.relu = nn.ReLU()self.fc2 = nn.Linear(3, 1)self.dequant = torch.quantization.DeQuantStub()def forward(self, x):x = self.quant(x)x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.dequant(x)return xdef fuse_model(self):torch.quantization.fuse_modules(self, [['fc1', 'relu']], inplace=True)# ===== 6. QAT前權重遷移、模型融合 =====
model_qat = XORNetQAT()
# 遷移參數
model_qat.load_state_dict(model.state_dict())
# 融合(此步必須!)
model_qat.fuse_model()# 配置QAT
model_qat.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm') # CPU
torch.quantization.prepare_qat(model_qat, inplace=True)optimizer_qat = optim.Adam(model_qat.parameters(), lr=0.01)# ===== 7. QAT訓練 =====
qat_epochs = 700
print("\n--- QAT訓練 ---")
for epoch in range(qat_epochs):model_qat.train()outputs_qat = model_qat(X)loss_qat = criterion(outputs_qat, y)optimizer_qat.zero_grad()loss_qat.backward()optimizer_qat.step()if (epoch + 1) % 150 == 0:print(f'QAT Epoch [{epoch+1}/{qat_epochs}], Loss: {loss_qat.item():.4f}')# ===== 8. 轉換到量化模型/評估精度 =====
print("\n--- 轉換為量化模型 ---")
model_qat.eval()
model_quantized = torch.quantization.convert(model_qat.eval(), inplace=False)with torch.no_grad():probs_quantized = torch.sigmoid(model_quantized(X))predictions_quantized = (probs_quantized > 0.5).float()accuracy_quantized = (predictions_quantized == y).sum().item() / y.numel()print(f"量化模型精度: {accuracy_quantized*100:.2f}%")print(f"量化模型預測:\n{predictions_quantized}")# ===== 9. 模型大小對比 =====
torch.save(model.state_dict(), 'xor_fp32.pth')
torch.save(model_quantized.state_dict(), 'xor_int8.pth')
fp32_size = os.path.getsize('xor_fp32.pth') / (1024 * 1024)
int8_size = os.path.getsize('xor_int8.pth') / (1024 * 1024)
print(f"\nFP32模型大小: {fp32_size:.6f} MB")
print(f"INT8模型大小: {int8_size:.6f} MB")
print(f"模型縮減比例: {fp32_size/int8_size:.2f} 倍")
運行結果:
====================== RESTART: F:/qatXorPytorch250501.py?
--- 開始標準模型訓練 ---
Epoch [300/900], Loss: 0.0033
Epoch [600/900], Loss: 0.0011
Epoch [900/900], Loss: 0.0005
標準模型訓練后精度: 100.00%
標準模型預測結果:
tensor([[0.],
? ? ? ? [1.],
? ? ? ? [1.],
? ? ? ? [0.]])
--- QAT訓練 ---
QAT Epoch [150/700], Loss: 0.0005
QAT Epoch [300/700], Loss: 0.0004
QAT Epoch [450/700], Loss: 0.0004
QAT Epoch [600/700], Loss: 0.0003
--- 轉換為量化模型 ---
量化模型精度: 100.00%
量化模型預測:
tensor([[0.],
? ? ? ? [1.],
? ? ? ? [1.],
? ? ? ? [0.]])
FP32模型大小: 0.001976 MB
INT8模型大小: 0.004759 MB
模型縮減比例: 0.42 倍
最后:
其實,這次量化
量化后的模型是量化前的 4.2倍(咦?不是說好了壓縮嗎?咋變大了?)
魔鬼藏在細節:
咱們看看 量化 之前的 (基線的)模型 參數+權重等等:
===== Model Architecture =====
XORNet(
? (fc1): Linear(in_features=2, out_features=3, bias=True)
? (relu): ReLU()
? (fc2): Linear(in_features=3, out_features=1, bias=True)
)
===== Layer Parameters =====
[fc1.weight] shape: (3, 2)
[[ 1.723932,? 1.551827],
?[ 2.106917,? 1.681809],
?[-0.299378, -0.444912]]
[fc1.bias] shape: (3,)
[-1.725313, -2.509506,? 0.? ? ? ]
[fc2.weight] shape: (1, 3)
[[-2.492318, -3.94821 ,? 0.911841]]
[fc2.bias] shape: (1,)
[0.692789]
===== Extra Info (Hyperparameters) =====
Optimizer: Adam
Learning Rate: 0.05
Epochs: 900
Loss: BCEWithLogitsLoss
Activation: ReLU
量化之后的模型參數等:
超參數部分:
?
===== Model Architecture =====
XORNetQAT(
? (quant): Quantize(scale=tensor([0.0079]), zero_point=tensor([0]), dtype=torch.quint8)
? (fc1): QuantizedLinearReLU(in_features=2, out_features=3, scale=0.0678500160574913, zero_point=0, qscheme=torch.per_channel_affine)
? (relu): Identity()
? (fc2): QuantizedLinear(in_features=3, out_features=1, scale=0.1376650333404541, zero_point=65, qscheme=torch.per_channel_affine)
? (dequant): DeQuantize()
)
===== Layer Parameters =====
===== Extra Info (Hyperparameters) =====
Optimizer: Adam
Learning Rate: 0.01
QAT Epochs: 700
Loss: BCEWithLogitsLoss
QConfig: QConfig(activation=functools.partial(<class 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'>, observer=<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>, quant_min=0, quant_max=255, reduce_range=True){'factory_kwargs': <function _add_module_to_qconfig_obs_ctr.<locals>.get_factory_kwargs_based_on_module_device at 0x00000164FDB16160>}, weight=functools.partial(<class 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'>, observer=<class 'torch.ao.quantization.observer.MovingAveragePerChannelMinMaxObserver'>, quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_channel_symmetric){'factory_kwargs': <function _add_module_to_qconfig_obs_ctr.<locals>.get_factory_kwargs_based_on_module_device at 0x00000164FDB16160>})
看到嗎?
量化后, 參數 變多了哈哈!
那量化的意義到底在哪里呢??
在下面的 參數(非超參)的 權重的部分:
===== 量化 (Quantization) 后參數 =====
[fc1] (QuantizedLinearReLU)
? [weight] shape: torch.Size([3, 2]), dtype: torch.qint8
? weight (quantized):
[[ 3.97? 3.98]
?[ 3.95? 3.96]
?[-0.56 -1.05]]
? weight (raw int):
[[127 127]
?[127 127]
?[-18 -33]]
? scale: 0.06785
? zero_point: 0
? [bias] shape: torch.Size([3]), dtype: torch.float
? bias:
[-3.049073e-04 -3.960315e+00? 0.000000e+00]
[fc2] (QuantizedLinear)
? [weight] shape: torch.Size([1, 3]), dtype: torch.qint8
? weight (quantized):
[[ 3.77 -7.79? 0.41]]
? weight (raw int):
[[ 27 -57? ?3]]
? scale: 0.13766
? zero_point: 65
? [bias] shape: torch.Size([1]), dtype: torch.float
? bias:
[-7.147094]
再看看量化前:
===== Layer Parameters =====
[fc1.weight] shape: (3, 2)
[[ 1.723932,? 1.551827],
?[ 2.106917,? 1.681809],
?[-0.299378, -0.444912]]
[fc1.bias] shape: (3,)
[-1.725313, -2.509506,? 0.? ? ? ]
[fc2.weight] shape: (1, 3)
[[-2.492318, -3.94821 ,? 0.911841]]
[fc2.bias] shape: (1,)
[0.692789]
最后看看 量化 后:
===== 量化 (Quantization) 后參數 =====
[fc1] (QuantizedLinearReLU)
? [weight] shape: torch.Size([3, 2]), dtype: torch.qint8
? weight (quantized):
[[ 3.97? 3.98]
?[ 3.95? 3.96]
?[-0.56 -1.05]]
? weight (raw int):
[[127 127]
?[127 127]
?[-18 -33]]
? scale: 0.06785
? zero_point: 0
? [bias] shape: torch.Size([3]), dtype: torch.float
? bias:
看出區別了嗎?
So:
量化操作 只 適用于 大、中型號的模型……道理就在此:
量化前 的 Weights 的權重 全部都是: Float32浮點型 ……很占內存的!
量化后 是所謂INT8(即一個字節、8bits)……至少(在權重部分)節省了 3/4 的內存!
So: 大模型 必須 要 量化 才 節省內存。
當然前提 是 你的 GPU 硬件 要 支持 INT8(8bits)的 運算哦……(這是后話,下次再聊)。
(over)完!
INT4版本,不調用pytorch,只用numpy:
?
import numpy as np# S型激活函數
def sigmoid(x):return 1 / (1 + np.exp(-np.clip(x, -500, 500))) # 裁剪以防止溢出# S型激活函數的導數
def sigmoid_derivative(x):return x * (1 - x)# INT4 量化函數 (對稱量化)
# scale 和 zero_point 通常在 QAT 中學習,但為簡單起見,
# 這里我們將基于動態確定的 min_val 和 max_val。
def quantize_int4(tensor, min_val, max_val):# 確定將 float32 映射到 INT4 范圍 [-8, 7]q_min = -8q_max = 7epsilon = 1e-7 # 一個小常數,用于防止除以零或在min_val和max_val非常接近時出現問題# 確保 min_val 和 max_val 之間有足夠的間隔if abs(max_val - min_val) < epsilon:# 如果它們都接近0,則將范圍擴大一點點以穩定包含0if abs(min_val) < epsilon and abs(max_val) < epsilon:min_val_eff = -epsilonmax_val_eff = epsilonelse: # 否則,圍繞它們的值稍微擴大范圍min_val_eff = min_val - epsilonmax_val_eff = max_val + epsilonelse:min_val_eff = min_valmax_val_eff = max_val# 計算 scalescale = (q_max - q_min) / (max_val_eff - min_val_eff)# 計算 zero_point (對于非對稱量化是必要的,對于對稱量化,理想情況下是0)# 這里我們使用標準的仿射量化公式,然后對zero_point進行取整和裁剪zero_point = q_min - min_val_eff * scale# 將 zero_point 四舍五入到最近的整數并裁剪zero_point = np.round(zero_point)zero_point = np.clip(zero_point, q_min, q_max).astype(np.int32)# 量化張量quantized_tensor = np.round(tensor * scale + zero_point)# 裁剪到 INT4 范圍quantized_tensor = np.clip(quantized_tensor, q_min, q_max)return quantized_tensor.astype(np.int8), scale, zero_point # 用int8類型存儲INT4的值# INT4 反量化函數
def dequantize_int4(quantized_tensor, scale, zero_point):return (quantized_tensor.astype(np.float32) - zero_point) / scale # 確保運算使用浮點數# 直通估計器 (STE) 用于量化
# 在前向傳播中,應用量化。
# 在反向傳播中,梯度直接通過,好像沒有量化發生。
class STEQuantizer:def __init__(self):self.quantized_val = Noneself.scale = Noneself.zero_point = Nonedef forward(self, x, min_val_for_quant, max_val_for_quant):# 量化在前向傳播中進行quantized_x, scale, zero_point = quantize_int4(x, min_val_for_quant, max_val_for_quant)self.quantized_val = quantized_xself.scale = scaleself.zero_point = zero_point# 為了在網絡中進行實際計算(模擬量化硬件的行為),需要反量化return dequantize_int4(quantized_x, scale, zero_point)def backward(self, grad_output):# 在反向傳播中,梯度直接通過# 這是STE的核心:輸出對輸入的梯度被認為是1return grad_output# 定義神經網絡類
class NeuralNetwork:def __init__(self, input_size, hidden_size, output_size, learning_rate=0.1, l2_lambda=0.0001, ema_decay=0.99):self.input_size = input_sizeself.hidden_size = hidden_sizeself.output_size = output_sizeself.learning_rate = learning_rateself.l2_lambda = l2_lambda # L2正則化系數self.ema_decay = ema_decay # EMA衰減因子,用于更新量化范圍統計# 初始化權重和偏置# 權重初始化為小的隨機值self.weights_input_hidden = np.random.uniform(-0.5, 0.5, (self.input_size, self.hidden_size))self.bias_hidden = np.zeros((1, self.hidden_size))self.weights_hidden_output = np.random.uniform(-0.5, 0.5, (self.hidden_size, self.output_size))self.bias_output = np.zeros((1, self.output_size))# 用于量化的統計量 (min/max),初始為None,將在第一次前向傳播時或通過EMA更新self.min_wih = Noneself.max_wih = Noneself.min_who = Noneself.max_who = None# 權重用的量化器self.quantizer_wih = STEQuantizer()self.quantizer_who = STEQuantizer()def _update_ema_stats(self, tensor_data, current_min_stat, current_max_stat):# 使用指數移動平均(EMA)更新統計的min/max值current_tensor_min = tensor_data.min()current_tensor_max = tensor_data.max()if current_min_stat is None or current_max_stat is None: # 第一次或未初始化new_min = current_tensor_minnew_max = current_tensor_maxelse:new_min = self.ema_decay * current_min_stat + (1 - self.ema_decay) * current_tensor_minnew_max = self.ema_decay * current_max_stat + (1 - self.ema_decay) * current_tensor_maxreturn new_min, new_maxdef forward(self, X, quantize=False, is_training=False):# 如果指定量化并且在訓練階段,則更新權重的EMA統計范圍if quantize and is_training:self.min_wih, self.max_wih = self._update_ema_stats(self.weights_input_hidden, self.min_wih, self.max_wih)self.min_who, self.max_who = self._update_ema_stats(self.weights_hidden_output, self.min_who, self.max_who)# 如果指定量化,則對權重進行量化if quantize:# 確保min/max統計量已初始化 (例如,在第一次EMA更新后)# 如果仍然是None (可能發生在第一次推理且之前沒有訓練),則使用當前權重的瞬時min/maxmin_wih_eff = self.min_wih if self.min_wih is not None else self.weights_input_hidden.min()max_wih_eff = self.max_wih if self.max_wih is not None else self.weights_input_hidden.max()min_who_eff = self.min_who if self.min_who is not None else self.weights_hidden_output.min()max_who_eff = self.max_who if self.max_who is not None else self.weights_hidden_output.max()self.quantized_wih_dequant = self.quantizer_wih.forward(self.weights_input_hidden, min_wih_eff, max_wih_eff)self.quantized_who_dequant = self.quantizer_who.forward(self.weights_hidden_output, min_who_eff, max_who_eff)else:# 使用原始的FP32權重self.quantized_wih_dequant = self.weights_input_hiddenself.quantized_who_dequant = self.weights_hidden_output# 輸入層到隱藏層self.hidden_layer_input = np.dot(X, self.quantized_wih_dequant) + self.bias_hiddenself.hidden_layer_output = sigmoid(self.hidden_layer_input)# 隱藏層到輸出層self.output_layer_input = np.dot(self.hidden_layer_output, self.quantized_who_dequant) + self.bias_outputself.output = sigmoid(self.output_layer_input)return self.outputdef backward(self, X, y, output):# 計算誤差self.error = y - output# 輸出層梯度self.d_output = self.error * sigmoid_derivative(output)# 計算 weights_hidden_output 和 bias_output 的梯度self.gradients_who = np.dot(self.hidden_layer_output.T, self.d_output)self.gradients_bo = np.sum(self.d_output, axis=0, keepdims=True)# 反向傳播到隱藏層# 注意:這里使用前向傳播中使用的(可能已反量化的)權重進行梯度計算self.d_hidden_layer = np.dot(self.d_output, self.quantized_who_dequant.T) * sigmoid_derivative(self.hidden_layer_output)# 計算 weights_input_hidden 和 bias_hidden 的梯度self.gradients_wih = np.dot(X.T, self.d_hidden_layer)self.gradients_bh = np.sum(self.d_hidden_layer, axis=0, keepdims=True)# 對權重梯度應用STE (梯度直接通過量化器)self.gradients_wih = self.quantizer_wih.backward(self.gradients_wih)self.gradients_who = self.quantizer_who.backward(self.gradients_who)# 更新權重和偏置,加入L2正則化 (作用于原始的FP32權重)self.weights_hidden_output += self.learning_rate * self.gradients_who - self.learning_rate * self.l2_lambda * self.weights_hidden_outputself.bias_output += self.learning_rate * self.gradients_boself.weights_input_hidden += self.learning_rate * self.gradients_wih - self.learning_rate * self.l2_lambda * self.weights_input_hiddenself.bias_hidden += self.learning_rate * self.gradients_bh# XOR 輸入和輸出
X_xor = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y_xor = np.array([[0], [1], [1], [0]])# 網絡參數
input_size = 2
hidden_size = 3#注意這里:3個節點.比2個節點容易收斂2 #5 # 隱藏層節點數
output_size = 1
learning_rate = 0.0241 # 學習率 (可能需要進一步調整)
epochs = 150000#200000 # 訓練迭代次數 (減少了次數以便更快看到結果,原為500000)
l2_lambda_val = 0.0001 # L2 正則化系數# 初始化神經網絡
nn = NeuralNetwork(input_size, hidden_size, output_size, learning_rate, l2_lambda=l2_lambda_val)# QAT訓練循環
print("開始使用QAT訓練神經網絡...")
for i in range(epochs):# 前向傳播,應用量化,并標記為訓練階段以更新EMA統計output = nn.forward(X_xor, quantize=True, is_training=True)# 反向傳播和參數更新nn.backward(X_xor, y_xor, output)if i % 10000 == 0 or i == epochs -1 : # 每10000次迭代或最后一次迭代時打印損失loss = np.mean(np.square(y_xor - output))print(f"迭代次數 {i}, 損失: {loss:.6f}")print("\n訓練完成。")# QAT后評估模型
print("\nQAT之后的預測結果:")
# 推理時 quantize=True, is_training=False (不更新EMA統計)
predictions = nn.forward(X_xor, quantize=True, is_training=False)
print(f"輸入:\n{X_xor}")
print(f"期望輸出:\n{y_xor}")
# 將預測結果四舍五入到0或1,因為XOR的輸出是二元的
print(f"預測輸出 (反量化后顯示,并四舍五入):\n{np.round(predictions)}")# 顯示量化后的權重 (INT4值)
# 注意:quantizer_wih.quantized_val 保存的是最近一次前向傳播(推理時)的量化權重
# 我們可以在推理后再進行一次前向傳播以確保這些值是最新的,或者直接使用它們
nn.forward(X_xor, quantize=True, is_training=False) # 確保量化值是最新的print("\n--- 量化后的模型權重 (INT4) ---")
if nn.quantizer_wih.quantized_val is not None:print("輸入層到隱藏層權重 (INT4):\n", nn.quantizer_wih.quantized_val)print("量化參數 (scale, zero_point):", nn.quantizer_wih.scale, nn.quantizer_wih.zero_point)
else:print("輸入層到隱藏層權重未被量化。")if nn.quantizer_who.quantized_val is not None:print("隱藏層到輸出層權重 (INT4):\n", nn.quantizer_who.quantized_val)print("量化參數 (scale, zero_point):", nn.quantizer_who.scale, nn.quantizer_who.zero_point)
else:print("隱藏層到輸出層權重未被量化。")# 為了驗證,顯示反量化后的權重 (表示INT4值的FP32形式)
print("\n--- 反量化后的權重 (INT4值的FP32表示) ---")
if nn.quantizer_wih.quantized_val is not None:dequant_wih = dequantize_int4(nn.quantizer_wih.quantized_val, nn.quantizer_wih.scale, nn.quantizer_wih.zero_point)print("輸入層到隱藏層權重 (反量化后):\n", dequant_wih)
else:print("輸入層到隱藏層權重未量化,無法顯示反量化值。")if nn.quantizer_who.quantized_val is not None:dequant_who = dequantize_int4(nn.quantizer_who.quantized_val, nn.quantizer_who.scale, nn.quantizer_who.zero_point)print("隱藏層到輸出層權重 (反量化后):\n", dequant_who)
else:print("隱藏層到輸出層權重未量化,無法顯示反量化值。")print("\n--- 訓練后的原始 (FP32) 權重 ---")
print("輸入層到隱藏層權重 (FP32):\n", nn.weights_input_hidden)
print("隱藏層到輸出層權重 (FP32):\n", nn.weights_hidden_output)print("\n--- 用于量化的動態范圍統計 (EMA) ---")
print(f"WIH Min: {nn.min_wih}, WIH Max: {nn.max_wih}")
print(f"WHO Min: {nn.min_who}, WHO Max: {nn.max_who}")