在深度學習中,預訓練模型(Pretrained Model)?是提升開發效率和模型性能的 “利器”。無論是圖像識別、自然語言處理還是語音識別,預訓練模型都被廣泛使用。下面從概念、使用原因、場景、作用等方面詳細介紹,并結合 Python 代碼展示常用預訓練模型的使用。
一、什么是預訓練模型?(通俗易懂版)
可以把預訓練模型理解為:“別人已經訓練好的‘半成品模型’,你可以直接拿來用,或者稍作修改就能適配自己的任務”。
舉個例子:假設你想訓練一個 “識別貓和狗” 的模型,需要大量圖片和算力。但有人已經用百萬張圖片(如 ImageNet 數據集)訓練了一個 “能識別 1000 種物體” 的模型,這個模型已經學會了 “邊緣、紋理、形狀” 等通用視覺特征(比如 “貓有耳朵、狗有尾巴”)。你可以直接用這個模型,要么直接預測貓和狗,要么在它的基礎上再用少量貓和狗的圖片 “微調”,就能快速得到一個好模型。
簡言之,預訓練模型是 “前人訓練好的成果,你可以站在它的肩膀上做開發”。
二、為什么要用預訓練模型?
節省時間和算力
訓練一個復雜模型(如 ResNet、BERT)可能需要幾天甚至幾周,還需要高性能 GPU。預訓練模型已經完成了大部分計算,直接用或微調只需幾小時,適合個人或小團隊(沒有超強算力)。數據量少時也能出效果
深度學習需要大量數據(如幾十萬張圖片),但實際場景中可能只有幾千張數據(如自己拍的貓和狗圖片)。預訓練模型已經 “見過” 海量數據,學到了通用特征,用少量數據微調就能達到不錯的效果(否則從頭訓練可能過擬合)。性能更優
預訓練模型通常基于大規模數據集(如 ImageNet 有 1400 萬張圖片)和優化的網絡結構,其學到的特征更通用、更魯棒。在此基礎上微調的模型,性能往往比 “從頭訓練” 好很多。
三、預訓練模型的使用場景
快速開發原型
當你需要快速驗證一個想法(比如 “用模型識別工廠的零件是否合格”),可以直接用預訓練模型做初步測試,不需要從零開始訓練。數據量有限的任務
比如醫學影像識別(數據少且標注成本高)、小眾物體識別(如特定品種的花),用預訓練模型微調能顯著提升精度。遷移學習任務
從 “通用任務” 遷移到 “具體任務”:比如用 “識別 1000 類物體” 的預訓練模型,遷移到 “識別 5 類水果” 的任務;用 “通用文本分類” 的 BERT,遷移到 “情感分析” 任務。邊緣設備部署
很多預訓練模型有 “輕量化版本”(如 MobileNet、EfficientNet-Lite),適合在手機、攝像頭等邊緣設備上部署(算力有限但需要快速推理)。
四、預訓練模型的作用
提供通用特征提取能力
預訓練模型的前半部分(如 CNN 的卷積層、Transformer 的編碼器)已經學會了通用特征(如圖像的邊緣、紋理,文本的語義關系),可以直接作為 “特征提取器” 使用。加速模型收斂
微調時,模型參數不需要從 0 開始學習,而是在預訓練的 “好起點” 上優化,訓練速度更快(比如原本需要 100 個 epoch,微調可能只需 20 個)。降低過擬合風險
預訓練模型學到的通用特征能 “抵抗” 小數據集的噪聲,減少模型對訓練數據的過度依賴(過擬合)。
五、Python 中常用的預訓練模型及使用代碼
以計算機視覺(圖像任務)?為例,PyTorch 的torchvision
庫和 TensorFlow 的tf.keras.applications
提供了大量預訓練模型。下面以 PyTorch 為例,介紹最常用的模型及代碼。
常用預訓練模型(圖像任務)
模型名稱 | 特點 | 適用場景 |
---|---|---|
ResNet(ResNet50/101) | 結構深、精度高,適合需要高精度的任務 | 圖像分類、特征提取 |
VGG16/VGG19 | 結構簡單、特征提取能力強 | 遷移學習、細粒度分類 |
MobileNetV2/V3 | 輕量化、計算量小 | 手機、攝像頭等邊緣設備部署 |
EfficientNet | 精度與效率平衡(比 ResNet 好且更輕量) | 兼顧精度和速度的場景 |
Faster R-CNN | 經典目標檢測模型 | 目標檢測(定位 + 分類) |
代碼示例:使用預訓練模型進行圖像分類
以ResNet50
為例,展示 “加載預訓練模型→預處理圖像→推理預測” 的完整流程。
步驟 1:安裝依賴
確保安裝了torch
和torchvision
:
pip install torch torchvision
步驟 2:加載預訓練模型并查看結構
import torch
from torchvision import models# 加載預訓練的ResNet50(pretrained=True表示加載預訓練權重)
resnet50 = models.resnet50(pretrained=True)
# 設置為評估模式(關閉 dropout、batchnorm等訓練時的層)
resnet50.eval()# 查看模型結構(簡化輸出)
print("ResNet50結構概覽:")
print(resnet50)
模型結構說明:
- 前半部分是
conv1
到layer4
的卷積層(特征提取); - 后半部分是
avgpool
(全局平均池化)和fc
(全連接層,輸出 1000 類,對應 ImageNet 的 1000 個類別)。
步驟 3:圖像預處理(必須與預訓練一致)
預訓練模型對輸入圖像有固定要求(如尺寸、歸一化參數),需嚴格匹配:
from torchvision import transforms
from PIL import Image# 定義預處理流程(與ResNet訓練時的預處理一致)
preprocess = transforms.Compose([transforms.Resize(256), # 縮放到256x256transforms.CenterCrop(224), # 中心裁剪到224x224transforms.ToTensor(), # 轉為張量并歸一化到0-1# 標準化(使用ImageNet的均值和標準差,必須與預訓練一致)transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
步驟 4:用預訓練模型進行推理(預測圖像類別)
# 讀取一張測試圖片(如一只貓)
img = Image.open("cat.jpg") # 替換為你的圖片路徑
# 預處理
input_tensor = preprocess(img)
# 增加批次維度(模型要求輸入是(batch_size, channels, H, W),這里batch_size=1)
input_batch = input_tensor.unsqueeze(0)# 用GPU加速(如果有)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resnet50.to(device)
input_batch = input_batch.to(device)# 推理(關閉梯度計算,加快速度)
with torch.no_grad():output = resnet50(input_batch)# 輸出是1000類的概率(logits),取最大概率的類別
predicted_class = torch.argmax(output[0]).item()# 加載ImageNet的類別名稱(1000類)
from torchvision.datasets import ImageNet
# 注意:ImageNet數據集需手動下載,這里簡化為加載類別名稱(可網上搜索獲取)
with open("imagenet_classes.txt") as f: # 包含1000類名稱的文件classes = [line.strip() for line in f.readlines()]print(f"預測類別:{classes[predicted_class]}")
說明:imagenet_classes.txt
包含 ImageNet 的 1000 個類別名稱(如 “貓”“狗”“汽車”),可從網上下載(搜索 “imagenet classes list”)。
步驟 5:微調預訓練模型(適配自定義任務)
如果要解決自己的分類任務(如識別 “貓、狗、鳥”3 類),需要微調模型:
# 1. 修改輸出層(將1000類改為3類)
num_classes = 3 # 自定義類別數
resnet50.fc = torch.nn.Linear(resnet50.fc.in_features, num_classes)# 2. 凍結部分層(可選,加速訓練)
# 凍結前幾層(保留預訓練的通用特征),只訓練最后幾層
for param in list(resnet50.parameters())[:-10]: # 凍結除最后10層外的參數param.requires_grad = False# 3. 定義損失函數和優化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(resnet50.parameters(), lr=0.001)# 4. 加載自定義數據集(假設已通過DataLoader準備好)
# train_loader = ...(自定義數據集的DataLoader)# 5. 微調訓練
resnet50.train()
for epoch in range(10): # 訓練10個epochrunning_loss = 0.0for images, labels in train_loader:images, labels = images.to(device), labels.to(device)# 前向傳播outputs = resnet50(images)loss = criterion(outputs, labels)# 反向傳播+優化optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}")
微調關鍵:
- 修改輸出層以匹配自定義類別數;
- 可選凍結部分層(減少計算量,保留通用特征);
- 用較小的學習率(避免破壞預訓練的好參數)。
六、總結(通俗易懂版)
預訓練模型就像 “已經學過基礎知識的學霸”:
- 如果你想快速解決一個問題(如識別圖片里的東西),可以直接讓學霸幫你 “答題”(推理);
- 如果你想讓學霸學新技能(如識別你的 3 種寵物),只需讓他在已有知識上 “稍作練習”(微調),比教一個零基礎的人(從頭訓練)快得多,效果也好得多。
無論是小團隊、個人開發者還是企業,合理使用預訓練模型都能大幅提升效率,少走彎路。