Pruning
Model pruning is a model compression technique. The core idea:
Deep neural networks contain many redundant parameters that contribute very little to the prediction.
Removing these redundant connections, channels, or kernels makes the model smaller and faster while losing as little accuracy as possible.
Unstructured pruning
Set a threshold on individual weights and zero out any weight below it.
Pros: the original network structure is preserved, and it is easy to implement.
Cons: sparse-matrix computation gets little speedup on ordinary hardware (dedicated sparse libraries are needed). A minimal sketch follows.
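For reference, here is a minimal sketch of the element-wise variant just described (my addition, not part of the original experiment): prune.l1_unstructured masks the 30% of individual weights with the smallest absolute value, leaving all tensor shapes untouched.

import torch.nn as nn
import torch.nn.utils.prune as prune

conv = nn.Conv2d(16, 32, kernel_size=3)                  # toy layer, for illustration only
prune.l1_unstructured(conv, name="weight", amount=0.3)   # mask the 30% smallest |w|
print((conv.weight == 0).float().mean())                 # ~0.30: weights are zeroed, not removed
prune.remove(conv, "weight")                             # fold the mask into the weight permanently

The experiment below instead masks whole output channels with prune.ln_structured, but the key property is the same: weights are only zeroed, never physically removed.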
# Zero out 30% of the channels in every convolutional layer
for module in pruned_model.modules():
    if isinstance(module, nn.Conv2d):
        # Mask the 30% least important output channels, ranked by L2 norm along dim 0
        prune.ln_structured(module, name="weight", amount=0.3, n=2, dim=0)
On ResNet18, pruning this way makes almost no difference, either in accuracy or in the time for one round of inference.
Analysis: 30% of the conv kernels really are zeroed out, but the module structure is unchanged. Each Conv2d is still its original size, with some weights merely set to zero. PyTorch's default implementation does not skip these "dead" channels, so the FLOPs are unchanged, the numbers reported by ptflops do not shrink, the GPU still executes the full convolution, and the inference time barely moves, which the sketch below confirms.
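A quick check (my own snippet, assuming the masking snippet above was applied to pruned_model): the weights contain zeros, but the shapes, and hence the FLOPs, are unchanged.

for name, module in pruned_model.named_modules():
    if isinstance(module, nn.Conv2d):
        w = module.weight
        sparsity = (w == 0).float().mean().item()
        # shape is identical to the unpruned model; only the values are masked
        print(f"{name}: shape={tuple(w.shape)}, sparsity={sparsity:.2%}")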
Structured pruning
Delete entire kernels, channels, or layers.
Pros: directly reduces computation and inference time.
Cons: pruning too aggressively easily hurts accuracy. There is also a subtlety, illustrated in the sketch after this list: deleted channels create dependencies between layers.
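A toy illustration (mine, not from the post) of why structured pruning has to track inter-layer dependencies: dropping output channels of one conv forces the next conv to be rebuilt with fewer input channels. In a real ResNet, BatchNorm layers and residual connections add further dependencies, which is what torch_pruning's DependencyGraph automates in the full code below.

import torch.nn as nn

conv1 = nn.Conv2d(3, 16, 3, padding=1)
conv2 = nn.Conv2d(16, 32, 3, padding=1)

keep = [i for i in range(16) if i % 4 != 0]          # keep 12 of 16 channels (arbitrary choice)
conv1_small = nn.Conv2d(3, len(keep), 3, padding=1)
conv1_small.weight.data = conv1.weight.data[keep].clone()
conv1_small.bias.data = conv1.bias.data[keep].clone()
# conv2 must shrink along its input-channel axis to match conv1's new output:
conv2_small = nn.Conv2d(len(keep), 32, 3, padding=1)
conv2_small.weight.data = conv2.weight.data[:, keep].clone()
conv2_small.bias.data = conv2.bias.data.clone()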
Full code
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.utils.prune as prune
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import time
from tqdm import tqdm
from ptflops import get_model_complexity_info
import torch_pruning as tp

# ======================
# 1. Data preparation
# ======================
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ======================
# 2. Training and test functions
# ======================
def train(model, optimizer, criterion, epoch):
    model.train()
    for inputs, targets in trainloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

def test(model, criterion, epoch, tag=""):
    model.eval()
    start = time.time()
    correct, total, loss_sum = 0, 0, 0.0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss_sum += criterion(outputs, targets).item()  # accumulate test loss over batches
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    acc = 100. * correct / total
    time_cost = time.time() - start
    print(f"{tag} Epoch {epoch}: Loss={loss_sum:.4f}, Acc={acc:.2f}%, Time={time_cost:.2f}s")
    return acc, time_cost

def print_model_stats(model, tag=""):
    # Report parameter count and FLOPs with ptflops
    mac, params = get_model_complexity_info(model, (3, 32, 32), as_strings=True,
                                            print_per_layer_stat=False, verbose=False)
    print(f"{tag} Params:{params}, FLOPs:{mac}")

# ======================
# 3. Train the baseline model
# ======================
print("===============BaseLine ResNet18")
baseline_model = models.resnet18(pretrained = True)
baseline_model.fc = nn.Linear(baseline_model.fc.in_features,10)
baseline_model = baseline_model.to(device)
print_model_stats(baseline_model,"Baseline")criterion = nn.CrossEntropyLoss()
optimer = optim.SGD(baseline_model.parameters(),lr = 0.01,momentum = 0.9,weight_decay = 5e-4)
baseline_acc = []
baseline_time = []
for epoch in tqdm(range(10)):train(baseline_model,optimer,criterion,epoch)acc,time_cost = test(baseline_model,criterion,epoch,"Baseline")baseline_acc.append(acc)baseline_time.append(time_cost)# ======================
# 4. 剪枝 + 微調
# ======================
pruned_model = models.resnet18(pretrained=True)
pruned_model.fc = nn.Linear(pruned_model.fc.in_features, 10)
pruned_model = pruned_model.to(device)

# =============== Unstructured pruning (channel masking) =====================
# # Zero out 30% of the channels in every convolutional layer
# for module in pruned_model.modules():
#     if isinstance(module, nn.Conv2d):
#         # Mask the 30% least important output channels, ranked by L2 norm along dim 0
#         prune.ln_structured(module, name="weight", amount=0.3, n=2, dim=0)

# ========================== Structured pruning =====================
# Dependency graph object: handles the inter-layer dependencies created by pruning
DG = tp.DependencyGraph()
# Build the model's dependency graph; example_inputs is used to trace the
# forward pass and determine how the layers depend on each other
DG.build_dependency(pruned_model, example_inputs=torch.randn(1, 3, 32, 32).to(device))

def prune_conv_by_ratio(conv, ratio=0.3):
    # L1 norm of each output channel: sum of |w| over the last three dims
    # (C_in, kH, kW), used as the channel-importance score
    weight = conv.weight.data.abs().sum((1, 2, 3))
    # Number of channels to remove at the given ratio
    num_remove = int(weight.numel() * ratio)
    # Indices of the num_remove channels with the smallest L1 norm
    # (torch.topk with largest=False returns the smallest k elements)
    _, idxs = torch.topk(weight, k=num_remove, largest=False)
    # Get the pruning group for this layer: prune along the output-channel axis
    group = DG.get_pruning_group(conv, tp.prune_conv_out_channels, idxs=idxs.tolist())
    # Execute the pruning, physically removing the selected channels
    group.prune()

# Walk all modules of the model
for m in pruned_model.modules():
    if isinstance(m, nn.Conv2d):
        # Remove 30% of this conv layer's output channels
        prune_conv_by_ratio(m, ratio=0.3)
# =======================================================
print_model_stats(pruned_model, "Pruned")
criterion1 = nn.CrossEntropyLoss()
optimizer1 = optim.SGD(pruned_model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
pruned_acc = []
pruned_time = []
for epoch in tqdm(range(10)):
    train(pruned_model, optimizer1, criterion1, epoch)
    acc, time_cost = test(pruned_model, criterion1, epoch, "Pruned")
    pruned_acc.append(acc)
    pruned_time.append(time_cost)

# ======================
# 5. Compare results
# ======================
print("\n==== Final Accuracy Comparison ====")print(f" Baseline={max(baseline_acc):.2f}% time={sum(baseline_time)/len(baseline_time):.2f}, Pruned={max(pruned_acc):.2f}% time={sum(pruned_time)/len(pruned_time):.2f}")
After 10 epochs of training, accuracy drops by about 7 percentage points while the parameter count shrinks roughly 4x, which feels like an acceptable trade-off:
Params: 11.18 M -> 2.7 M
FLOPs: 37.25 MMac -> 9.48 MMac
acc: 82.86% -> 75.77%
time: 1.20 s -> 1.12 s
Pruning based on regularization / sparsity constraints
Add a sparsity regularization term during training so the network itself learns to push low-importance weights toward zero, then prune. A minimal sketch of this idea follows.
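This sketch is my addition, assuming the Network Slimming approach of putting an L1 penalty on BatchNorm scale factors; l1_lambda is an assumed hyperparameter.

import torch.nn as nn

l1_lambda = 1e-4  # assumed penalty strength

def sparsity_penalty(model):
    # Sum of |gamma| over all BatchNorm layers; L1 pushes channel scales toward zero
    return sum(m.weight.abs().sum()
               for m in model.modules() if isinstance(m, nn.BatchNorm2d))

# Inside the usual training step:
#   loss = criterion(outputs, targets) + l1_lambda * sparsity_penalty(model)
#   loss.backward()
# After training, channels whose gamma stays near zero can be pruned away,
# e.g. with the torch_pruning DependencyGraph flow shown above.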