【Python訓練營打卡】day44 @浙大疏錦行

DAY 44 預訓練模型

知識點回顧:

1.? 預訓練的概念

2.? 常見的分類預訓練模型

3.? 圖像預訓練模型的發展史

4.? 預訓練的策略

5.? 預訓練代碼實戰:resnet18

作業:

1.? 嘗試在cifar10對比如下其他的預訓練模型,觀察差異,盡可能和他人選擇的不同

2.? 嘗試通過ctrl進入resnet的內部,觀察殘差究竟是什么

一、預訓練的概念

我們之前在訓練中發現,準確率最開始隨著epoch的增加而增加。隨著循環的更新,參數在不斷發生更新。

所以參數的初始值對訓練結果有很大的影響:

1. 如果最開始的初始值比較好,后續訓練輪數就會少很多

2. 很有可能陷入局部最優值,不同的初始值可能導致陷入不同的局部最優值

所以很自然的想到,如果最開始能有比較好的參數,即可能導致未來訓練次數少,也可能導致未來訓練避免陷入局部最優解的問題。這就引入了一個概念,即預訓練模型。

如果別人在某些和我們目標數據類似的大規模數據集上做過訓練,我們可以用他的訓練參數來初始化我們的模型,這樣我們的模型就比較容易收斂。

Q&A:

1. 那為什么要選擇類似任務的數據集預訓練的模型參數呢?

因為任務差不多,他提取特征的能力才有用,如果任務相差太大,他的特征提取能力就沒那么好。

所以本質預訓練就是拿別人已經具備的通用特征提取能力來接著強化能力使之更加適應我們的數據集和任務。

2. 為什么要求預訓練模型是在大規模數據集上訓練的,小規模不行么?

因為提取的是通用特征,所以如果數據集數據少、尺寸小,就很難支撐復雜任務學習通用的數據特征。比如你是一個物理的博士,讓你去做小學數學題,很快就能上手;但是你是一個小學數學速算高手,讓你做物理博士的課題,就很困難。所以預訓練模型一般就挺強的。

我們把用預訓練模型的參數,然后接著在自己數據集上訓練來調整該參數的過程叫做微調,這種思想叫做遷移學習。把預訓練的過程叫做上游任務,把微調的過程叫做下游任務。

二、經典的預訓練模型

2.1 CNN 架構預訓練模型

模型預訓練數據集核心特點在 CIFAR10 上的適配要點
AlexNetImageNet首次引入 ReLU / 局部響應歸一化,參數量 6000 萬 +需修改首層卷積核大小(原 11x11→適配 32x32)
VGG16ImageNet純卷積堆疊,結構統一,參數量 1.38 億凍結前 10 層卷積,僅微調全連接層
ResNet18ImageNet殘差連接解決梯度消失,參數量 1100 萬直接適配 32x32 輸入,需調整池化層步長
MobileNetV2ImageNet深度可分離卷積,參數量 350 萬 +輕量級設計,適合計算資源有限的場景

2.2 Transformer 類預訓練模型

適用于較大尺寸圖像(如 224×224),在 CIFAR10 上需上采樣圖像尺寸或調整 Patch 大小。

模型預訓練數據集核心特點在 CIFAR10 上的適配要點
ViT-BaseImageNet-21K純 Transformer 架構,參數量 8600 萬圖像 Resize 至 224×224,Patch 大小設為 4×4
Swin TransformerImageNet-22K分層窗口注意力,參數量 8000 萬 +需調整窗口大小適配小圖像
DeiTImageNet結合 CNN 歸納偏置,參數量 2200 萬輕量級 Transformer,適合中小尺寸圖像

2.3 自監督預訓練模型

無需人工標注,通過 pretext task(如掩碼圖像重建)學習特征,適合數據稀缺場景。

模型預訓練方式典型數據集在 CIFAR10 上的優勢
MoCo v3對比學習ImageNet無需標簽即可遷移,適合無標注數據
BEiT掩碼圖像建模ImageNet-22K特征語義豐富,微調時收斂更快

三、常見的分類預訓練模型介紹

3.1 預訓練模型的發展史

模型對比表

模型年份提出團隊關鍵創新點層數參數量ImageNet Top-5 錯誤率典型應用場景預訓練權重可用性
LeNet-51998Yann LeCun 等首個 CNN 架構,卷積層 + 池化層 + 全連接層,Sigmoid 激活函數7~60KN/A手寫數字識別(MNIST)無(歷史模型)
AlexNet2012Alex Krizhevsky 等ReLU 激活函數、Dropout、數據增強、GPU 訓練860M15.3%大規模圖像分類PyTorch/TensorFlow 官方支持
VGGNet2014Oxford VGG 團隊統一 3×3 卷積核、多尺度特征提取、結構簡潔16/19138M/144M7.3%/7.0%圖像分類、目標檢測基礎骨干網絡PyTorch/TensorFlow 官方支持
GoogLeNet2014GoogleInception 模塊(多分支并行卷積)、1×1 卷積降維、全局平均池化225M6.7%大規模圖像分類PyTorch/TensorFlow 官方支持
ResNet2015何愷明等殘差連接(解決梯度消失)、Batch Normalization18/50/15211M/25M/60M3.57%/3.63%/3.58%圖像 / 視頻分類、檢測、分割PyTorch/TensorFlow 官方支持
DenseNet2017Gao Huang 等密集連接(每層與后續所有層相連)、特征復用、參數效率高121/1698M/14M2.80%小數據集、醫學圖像處理PyTorch/TensorFlow 官方支持
MobileNet2017Google深度可分離卷積(減少 75% 計算量)、輕量級設計284.2M7.4%移動端圖像分類 / 檢測PyTorch/TensorFlow 官方支持
EfficientNet2019Google復合縮放(同時優化深度、寬度、分辨率)、NAS 搜索最佳配置B0-B75.3M-66M2.6% (B7)高精度圖像分類(資源受限場景)PyTorch/TensorFlow 官方支持

補充說明

  • 層數含義:代表模型不同版本,如 ResNet 有 18/50/152 層,EfficientNet 有 B0-B7 等變體。
  • ImageNet Top-5 準確率:模型預測概率前五的類別中包含正確類別的比例,是圖像分類任務的重要評估指標(共 1000 類)。

模型架構演進關鍵點總結

  1. 深度突破:從 LeNet 的 7 層到 ResNet152 的 152 層,殘差連接解決了深度網絡的訓練難題。
    (沒上過復試班 CV 部分的自行了解殘差連接,非常重要!)
  2. 計算效率:GoogLeNet(Inception)和 MobileNet 通過結構優化,在保持精度的同時大幅降低參數量。
  3. 特征復用:DenseNet 的密集連接設計使模型能更好地利用淺層特征,適合小數據集。
  4. 自動化設計:EfficientNet 使用神經架構搜索(NAS)自動尋找最優網絡配置,開創了 AutoML 在 CNN 中的應用。

預訓練模型使用建議

任務需求推薦模型理由
快速原型開發ResNet50/18結構平衡,預訓練權重穩定,社區支持完善
移動端部署MobileNetV3參數量小,計算高效,專為移動設備優化
高精度分類(資源充足)EfficientNet-B7目前 ImageNet 準確率領先,適合 GPU/TPU 環境
小數據集或特征復用需求DenseNet密集連接設計減少過擬合,特征復用能力強
多尺度特征提取Inception-ResNet結合 Inception 多分支和 ResNet 殘差連接,適合復雜場景

說明
模型預訓練權重均可通過主流框架(如 PyTorch 的torchvision.models、Keras 的applications模塊)直接加載,便于快速遷移到新任務。

總結:CNN 架構發展脈絡

1. 早期探索(1990s-2010s):LeNet 驗證 CNN 可行性,但受限于計算和數據。

2. 深度學習復興(2012-2015):AlexNet、VGGNet、GoogLeNet 通過加深網絡和結構創新突破性能。

3. 超深網絡時代(2015 年后):ResNet 解決退化問題,開啟殘差連接范式,后續模型圍繞效率(MobileNet)、特征復用(DenseNet)、多分支結構(Inception)等方向優化。

3.2預訓練模型的訓練策略

那么什么模型會被選為預訓練模型呢?比如一些調參后表現很好的cnn神經網絡(固定的神經元個數+固定的層數等)。

所以調用預訓練模型做微調,本質就是 用這些固定的結構+之前訓練好的參數 接著訓練

所以需要找到預訓練的模型結構并且加載模型參數

相較于之前用自己定義的模型有以下幾個注意點

1. 需要調用預訓練模型和加載權重

2. 需要resize 圖片讓其可以適配模型

3. 需要修改最后的全連接層以適應數據集

其中,訓練過程中,為了不破壞最開始的特征提取器的參數,最開始往往先凍結住特征提取器的參數,然后訓練全連接層,大約在5-10個epoch后解凍訓練。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt# 設置中文字體支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解決負號顯示問題# 檢查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用設備: {device}")# 1. 數據預處理(訓練集增強,測試集標準化)
train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),transforms.RandomRotation(15),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])test_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])# 2. 加載CIFAR-10數據集
train_dataset = datasets.CIFAR10(root='./data',train=True,download=True,transform=train_transform
)test_dataset = datasets.CIFAR10(root='./data',train=False,transform=test_transform
)# 3. 創建數據加載器(可調整batch_size)
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# 4. 訓練函數(支持學習率調度器)
def train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs):model.train()  # 設置為訓練模式train_loss_history = []test_loss_history = []train_acc_history = []test_acc_history = []all_iter_losses = []iter_indices = []for epoch in range(epochs):running_loss = 0.0correct_train = 0total_train = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()# 記錄Iteration損失iter_loss = loss.item()all_iter_losses.append(iter_loss)iter_indices.append(epoch * len(train_loader) + batch_idx + 1)# 統計訓練指標running_loss += iter_loss_, predicted = output.max(1)total_train += target.size(0)correct_train += predicted.eq(target).sum().item()# 每100批次打印進度if (batch_idx + 1) % 100 == 0:print(f"Epoch {epoch+1}/{epochs} | Batch {batch_idx+1}/{len(train_loader)} "f"| 單Batch損失: {iter_loss:.4f}")# 計算 epoch 級指標epoch_train_loss = running_loss / len(train_loader)epoch_train_acc = 100. * correct_train / total_train# 測試階段model.eval()correct_test = 0total_test = 0test_loss = 0.0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()_, predicted = output.max(1)total_test += target.size(0)correct_test += predicted.eq(target).sum().item()epoch_test_loss = test_loss / len(test_loader)epoch_test_acc = 100. * correct_test / total_test# 記錄歷史數據train_loss_history.append(epoch_train_loss)test_loss_history.append(epoch_test_loss)train_acc_history.append(epoch_train_acc)test_acc_history.append(epoch_test_acc)# 更新學習率調度器if scheduler is not None:scheduler.step(epoch_test_loss)# 打印 epoch 結果print(f"Epoch {epoch+1} 完成 | 訓練損失: {epoch_train_loss:.4f} "f"| 訓練準確率: {epoch_train_acc:.2f}% | 測試準確率: {epoch_test_acc:.2f}%")# 繪制損失和準確率曲線plot_iter_losses(all_iter_losses, iter_indices)plot_epoch_metrics(train_acc_history, test_acc_history, train_loss_history, test_loss_history)return epoch_test_acc  # 返回最終測試準確率# 5. 繪制Iteration損失曲線
def plot_iter_losses(losses, indices):plt.figure(figsize=(10, 4))plt.plot(indices, losses, 'b-', alpha=0.7)plt.xlabel('Iteration(Batch序號)')plt.ylabel('損失值')plt.title('訓練過程中的Iteration損失變化')plt.grid(True)plt.show()# 6. 繪制Epoch級指標曲線
def plot_epoch_metrics(train_acc, test_acc, train_loss, test_loss):epochs = range(1, len(train_acc) + 1)plt.figure(figsize=(12, 5))# 準確率曲線plt.subplot(1, 2, 1)plt.plot(epochs, train_acc, 'b-', label='訓練準確率')plt.plot(epochs, test_acc, 'r-', label='測試準確率')plt.xlabel('Epoch')plt.ylabel('準確率 (%)')plt.title('準確率隨Epoch變化')plt.legend()plt.grid(True)# 損失曲線plt.subplot(1, 2, 2)plt.plot(epochs, train_loss, 'b-', label='訓練損失')plt.plot(epochs, test_loss, 'r-', label='測試損失')plt.xlabel('Epoch')plt.ylabel('損失值')plt.title('損失值隨Epoch變化')plt.legend()plt.grid(True)plt.tight_layout()plt.show()
# 導入ResNet模型
from torchvision.models import resnet18# 定義ResNet18模型(支持預訓練權重加載)
def create_resnet18(pretrained=True, num_classes=10):# 加載預訓練模型(ImageNet權重)model = resnet18(pretrained=pretrained)# 修改最后一層全連接層,適配CIFAR-10的10分類任務in_features = model.fc.in_featuresmodel.fc = nn.Linear(in_features, num_classes)# 將模型轉移到指定設備(CPU/GPU)model = model.to(device)return model
# 創建ResNet18模型(加載ImageNet預訓練權重,不進行微調)
model = create_resnet18(pretrained=True, num_classes=10)
model.eval()  # 設置為推理模式# 測試單張圖片(示例)
from torchvision import utils# 從測試數據集中獲取一張圖片
dataiter = iter(test_loader)
images, labels = next(dataiter)
images = images[:1].to(device)  # 取第1張圖片# 前向傳播
with torch.no_grad():outputs = model(images)_, predicted = torch.max(outputs.data, 1)# 顯示圖片和預測結果
plt.imshow(utils.make_grid(images.cpu(), normalize=True).permute(1, 2, 0))
plt.title(f"預測類別: {predicted.item()}")
plt.axis('off')
plt.show()

在 CIFAR-10 數據集中,類別標簽是固定的 10 個,分別對應:

標簽(數字)類別名稱說明
0airplane飛機
1automobile汽車(含轎車、卡車等)
2bird鳥類
3cat
4deer鹿
5dog
6frog青蛙
7horse
8ship
9truck卡車(重型貨車等)
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os# 設置中文字體支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解決負號顯示問題# 檢查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用設備: {device}")# 1. 數據預處理(訓練集增強,測試集標準化)
train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),transforms.RandomRotation(15),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])test_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])# 2. 加載CIFAR-10數據集
train_dataset = datasets.CIFAR10(root='./data',train=True,download=True,transform=train_transform
)test_dataset = datasets.CIFAR10(root='./data',train=False,transform=test_transform
)# 3. 創建數據加載器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# 4. 定義ResNet18模型
def create_resnet18(pretrained=True, num_classes=10):model = models.resnet18(pretrained=pretrained)# 修改最后一層全連接層in_features = model.fc.in_featuresmodel.fc = nn.Linear(in_features, num_classes)return model.to(device)# 5. 凍結/解凍模型層的函數
def freeze_model(model, freeze=True):"""凍結或解凍模型的卷積層參數"""# 凍結/解凍除fc層外的所有參數for name, param in model.named_parameters():if 'fc' not in name:param.requires_grad = not freeze# 打印凍結狀態frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)total_params = sum(p.numel() for p in model.parameters())if freeze:print(f"已凍結模型卷積層參數 ({frozen_params}/{total_params} 參數)")else:print(f"已解凍模型所有參數 ({total_params}/{total_params} 參數可訓練)")return model# 6. 訓練函數(支持階段式訓練)
def train_with_freeze_schedule(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs, freeze_epochs=5):"""前freeze_epochs輪凍結卷積層,之后解凍所有層進行訓練"""train_loss_history = []test_loss_history = []train_acc_history = []test_acc_history = []all_iter_losses = []iter_indices = []# 初始凍結卷積層if freeze_epochs > 0:model = freeze_model(model, freeze=True)for epoch in range(epochs):# 解凍控制:在指定輪次后解凍所有層if epoch == freeze_epochs:model = freeze_model(model, freeze=False)# 解凍后調整優化器(可選)optimizer.param_groups[0]['lr'] = 1e-4  # 降低學習率防止過擬合model.train()  # 設置為訓練模式running_loss = 0.0correct_train = 0total_train = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()# 記錄Iteration損失iter_loss = loss.item()all_iter_losses.append(iter_loss)iter_indices.append(epoch * len(train_loader) + batch_idx + 1)# 統計訓練指標running_loss += iter_loss_, predicted = output.max(1)total_train += target.size(0)correct_train += predicted.eq(target).sum().item()# 每100批次打印進度if (batch_idx + 1) % 100 == 0:print(f"Epoch {epoch+1}/{epochs} | Batch {batch_idx+1}/{len(train_loader)} "f"| 單Batch損失: {iter_loss:.4f}")# 計算 epoch 級指標epoch_train_loss = running_loss / len(train_loader)epoch_train_acc = 100. * correct_train / total_train# 測試階段model.eval()correct_test = 0total_test = 0test_loss = 0.0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()_, predicted = output.max(1)total_test += target.size(0)correct_test += predicted.eq(target).sum().item()epoch_test_loss = test_loss / len(test_loader)epoch_test_acc = 100. * correct_test / total_test# 記錄歷史數據train_loss_history.append(epoch_train_loss)test_loss_history.append(epoch_test_loss)train_acc_history.append(epoch_train_acc)test_acc_history.append(epoch_test_acc)# 更新學習率調度器if scheduler is not None:scheduler.step(epoch_test_loss)# 打印 epoch 結果print(f"Epoch {epoch+1} 完成 | 訓練損失: {epoch_train_loss:.4f} "f"| 訓練準確率: {epoch_train_acc:.2f}% | 測試準確率: {epoch_test_acc:.2f}%")# 繪制損失和準確率曲線plot_iter_losses(all_iter_losses, iter_indices)plot_epoch_metrics(train_acc_history, test_acc_history, train_loss_history, test_loss_history)return epoch_test_acc  # 返回最終測試準確率# 7. 繪制Iteration損失曲線
def plot_iter_losses(losses, indices):plt.figure(figsize=(10, 4))plt.plot(indices, losses, 'b-', alpha=0.7)plt.xlabel('Iteration(Batch序號)')plt.ylabel('損失值')plt.title('訓練過程中的Iteration損失變化')plt.grid(True)plt.show()# 8. 繪制Epoch級指標曲線
def plot_epoch_metrics(train_acc, test_acc, train_loss, test_loss):epochs = range(1, len(train_acc) + 1)plt.figure(figsize=(12, 5))# 準確率曲線plt.subplot(1, 2, 1)plt.plot(epochs, train_acc, 'b-', label='訓練準確率')plt.plot(epochs, test_acc, 'r-', label='測試準確率')plt.xlabel('Epoch')plt.ylabel('準確率 (%)')plt.title('準確率隨Epoch變化')plt.legend()plt.grid(True)# 損失曲線plt.subplot(1, 2, 2)plt.plot(epochs, train_loss, 'b-', label='訓練損失')plt.plot(epochs, test_loss, 'r-', label='測試損失')plt.xlabel('Epoch')plt.ylabel('損失值')plt.title('損失值隨Epoch變化')plt.legend()plt.grid(True)plt.tight_layout()plt.show()# 主函數:訓練模型
def main():# 參數設置epochs = 40  # 總訓練輪次freeze_epochs = 5  # 凍結卷積層的輪次learning_rate = 1e-3  # 初始學習率weight_decay = 1e-4  # 權重衰減# 創建ResNet18模型(加載預訓練權重)model = create_resnet18(pretrained=True, num_classes=10)# 定義優化器和損失函數optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)criterion = nn.CrossEntropyLoss()# 定義學習率調度器scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)# 開始訓練(前5輪凍結卷積層,之后解凍)final_accuracy = train_with_freeze_schedule(model=model,train_loader=train_loader,test_loader=test_loader,criterion=criterion,optimizer=optimizer,scheduler=scheduler,device=device,epochs=epochs,freeze_epochs=freeze_epochs)print(f"訓練完成!最終測試準確率: {final_accuracy:.2f}%")# # 保存模型# torch.save(model.state_dict(), 'resnet18_cifar10_finetuned.pth')# print("模型已保存至: resnet18_cifar10_finetuned.pth")if __name__ == "__main__":main()

?

幾個明顯的現象

1. 解凍后幾個epoch即可達到之前cnn訓練20輪的效果,這是預訓練的優勢

2. 由于訓練集用了 RandomCrop(隨機裁剪)、RandomHorizontalFlip(隨機水平翻轉)、ColorJitter(顏色抖動)等數據增強操作,這會讓訓練時模型看到的圖片有更多 “干擾” 或變形。比如一張汽車圖片,訓練時可能被裁剪成只顯示局部、顏色也有變化,模型學習難度更高;而測試集是標準的、沒增強的圖片,模型預測相對輕松,就可能出現訓練集準確率暫時低于測試集的情況,尤其在訓練前期增強對模型影響更明顯。隨著訓練推進,模型適應增強后會緩解。

3. 最后收斂后的效果超過非預訓練模型的80%,大幅提升

作業

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os# 設置中文字體支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False# 設備配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用設備: {device}")# 數據預處理
train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),transforms.RandomRotation(15),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])test_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])# 加載數據集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=test_transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)def create_resnet50(pretrained=True, num_classes=10):model = models.resnet50(pretrained=pretrained) in_features = model.fc.in_featuresmodel.fc = nn.Linear(in_features, num_classes)  return model.to(device)# 訓練函數
def train_with_freeze_schedule(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs, freeze_epochs=5):"""前freeze_epochs輪凍結卷積層,之后解凍所有層進行訓練"""train_loss_history = []test_loss_history = []train_acc_history = []test_acc_history = []all_iter_losses = []iter_indices = []# 初始凍結卷積層if freeze_epochs > 0:model = freeze_model(model, freeze=True)for epoch in range(epochs):# 解凍控制:在指定輪次后解凍所有層if epoch == freeze_epochs:model = freeze_model(model, freeze=False)# 解凍后調整優化器(可選)optimizer.param_groups[0]['lr'] = 1e-4  # 降低學習率防止過擬合model.train()  # 設置為訓練模式running_loss = 0.0correct_train = 0total_train = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()# 記錄Iteration損失iter_loss = loss.item()all_iter_losses.append(iter_loss)iter_indices.append(epoch * len(train_loader) + batch_idx + 1)# 統計訓練指標running_loss += iter_loss_, predicted = output.max(1)total_train += target.size(0)correct_train += predicted.eq(target).sum().item()# 每100批次打印進度if (batch_idx + 1) % 100 == 0:print(f"Epoch {epoch+1}/{epochs} | Batch {batch_idx+1}/{len(train_loader)} "f"| 單Batch損失: {iter_loss:.4f}")# 計算 epoch 級指標epoch_train_loss = running_loss / len(train_loader)epoch_train_acc = 100. * correct_train / total_train# 測試階段model.eval()correct_test = 0total_test = 0test_loss = 0.0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()_, predicted = output.max(1)total_test += target.size(0)correct_test += predicted.eq(target).sum().item()epoch_test_loss = test_loss / len(test_loader)epoch_test_acc = 100. * correct_test / total_test# 記錄歷史數據train_loss_history.append(epoch_train_loss)test_loss_history.append(epoch_test_loss)train_acc_history.append(epoch_train_acc)test_acc_history.append(epoch_test_acc)# 更新學習率調度器if scheduler is not None:scheduler.step(epoch_test_loss)# 打印 epoch 結果print(f"Epoch {epoch+1} 完成 | 訓練損失: {epoch_train_loss:.4f} "f"| 訓練準確率: {epoch_train_acc:.2f}% | 測試準確率: {epoch_test_acc:.2f}%")# 繪制損失和準確率曲線plot_iter_losses(all_iter_losses, iter_indices)plot_epoch_metrics(train_acc_history, test_acc_history, train_loss_history, test_loss_history)return epoch_test_acc  # 返回最終測試準確率# 繪圖函數
def plot_iter_losses(losses, indices):plt.figure(figsize=(10, 4))plt.plot(indices, losses, 'b-', alpha=0.7)plt.xlabel('Iteration(Batch序號)')plt.ylabel('損失值')plt.title('訓練過程中的Iteration損失變化')plt.grid(True)plt.show()def plot_epoch_metrics(train_acc, test_acc, train_loss, test_loss):epochs = range(1, len(train_acc) + 1)plt.figure(figsize=(12, 5))# 準確率曲線plt.subplot(1, 2, 1)plt.plot(epochs, train_acc, 'b-', label='訓練準確率')plt.plot(epochs, test_acc, 'r-', label='測試準確率')plt.xlabel('Epoch')plt.ylabel('準確率 (%)')plt.title('準確率隨Epoch變化')plt.legend()plt.grid(True)# 損失曲線plt.subplot(1, 2, 2)plt.plot(epochs, train_loss, 'b-', label='訓練損失')plt.plot(epochs, test_loss, 'r-', label='測試損失')plt.xlabel('Epoch')plt.ylabel('損失值')plt.title('損失值隨Epoch變化')plt.legend()plt.grid(True)plt.tight_layout()plt.show()def main():epochs = 40freeze_epochs = 5learning_rate = 5e-4 weight_decay = 1e-4# 創建ResNet50模型model = create_resnet50(pretrained=True)# 定義優化器optimizer = optim.SGD(model.parameters(), lr=learning_rate, weight_decay=weight_decay, momentum=0.9  # 添加動量加速收斂)criterion = nn.CrossEntropyLoss()scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)# 訓練流程final_accuracy = train_with_freeze_schedule(model=model,train_loader=train_loader,test_loader=test_loader,criterion=criterion,optimizer=optimizer,scheduler=scheduler,device=device,epochs=epochs,freeze_epochs=freeze_epochs)print(f"\nResNet50在CIFAR-10上的最終測試準確率: {final_accuracy:.2f}%")# 保存模型(可選)# torch.save(model.state_dict(), 'resnet50_cifar10_finetuned.pth')if __name__ == "__main__":main()

@浙大疏錦行

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/908237.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/908237.shtml
英文地址,請注明出處:http://en.pswp.cn/news/908237.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

MySQL中關于事務和鎖的常見執行命令整理包括版本區別

MySQL中關于事務和鎖的常見執行命令實例整理,并標注了不同版本下的區別(如MySQL 8.0與舊版本的差異): 一、事務相關命令 1. 事務控制 命令描述版本差異START TRANSACTION; 或 BEGIN;顯式開啟事務通用語法,無版本差異…

PyTorch-Transforms的使用(二)

對圖像進行處理 安裝open cv ctrlP 看用法 ToTensor的使用 常見的Transforms 歸一化的圖片 兩個長度為三的數組,分別表示三個通道的平均值和標準差 Resize() Compose() 合并執行功能,輸入進去一個列表&a…

vscode實用配置

前端開發安裝插件: 1.可以更好看的顯示文件圖標 2.用戶快速打開文件 使用步驟:在html文件下右鍵點擊 open with live server 即可 刷力扣: 安裝這個插件 還需要安裝node.js即可

Day130 | 靈神 | 回溯算法 | 子集型 電話號碼的字母組合

Day130 | 靈神 | 回溯算法 | 子集型 電話號碼的字母組合 17.電話號碼的字母組合 17. 電話號碼的字母組合 - 力扣(LeetCode) 思路: 筆者用index代替i,這里的index其實就是digits數組的下標 按照靈神的回溯三問,那就…

深入理解JavaScript設計模式之閉包與高階函數

前言小序 一場失敗面試 2023年的某一天,一場讓我印象深刻的面試: 面試官: “你了解閉包嗎?請說一下你對閉包的理解。” 我自信滿滿地答道: “閉包就是函數里面套函數,里面的函數可以訪問外部函數的變量。…

使用 Spring Boot 3.3 和 JdbcTemplate 操作 MySQL 數據庫

在現代的 Java 應用開發中,Spring Boot 提供了強大的工具來簡化數據庫操作。JdbcTemplate 是 Spring 提供的一個核心類,用于簡化 JDBC 操作,減少樣板代碼。本文將介紹如何在 Spring Boot 3.3 項目中使用 JdbcTemplate 來操作 MySQL 數據庫&am…

如何做好一份技術文檔?(下篇)

如何做好一份技術文檔?(下篇) 下篇:文檔體驗的極致優化 ——從可用性到愉悅性的跨越 文檔用戶體驗地圖 新手路徑 專家路徑 [安裝] → [配置] → [示例] [API] → [參數] → [源碼] │ ▲ …

Windows 12確認沒了,Win11 重心偏移修Bug

微軟悄然擱置了傳說中的Windows 12開發計劃,轉身將精力投入到Windows 11的持續進化中。今年秋季的主角已經確定——Windows 11 25H2,它將于9月或10月間與我們正式見面。 與去年24H2的大規模更新不同,25H2更像是場精心策劃的“功能解鎖”。微軟…

JavaScript中的正則表達式:文本處理的瑞士軍刀

JavaScript中的正則表達式:文本處理的瑞士軍刀 在編程世界中,正則表達式(Regular Expression,簡稱RegExp)被譽為“文本處理的瑞士軍刀”。它能夠高效地完成字符串匹配、替換、提取和驗證等任務。無論是前端開發中的表…

基于LEAP模型在能源環境發展、碳排放建模預測及分析中實踐應用

在國家“3060”碳達峰碳中和的政策背景下,如何尋求經濟-能源-環境的平衡有效發展是國家、省份、城市及園區等不同級別經濟體的重要課題。根據國家政策、當地能源結構、能源技術發展水平以及相關碳排放指標制定合理有效的低碳能源發展規劃需要以科學準確的能源環境發…

Python爬蟲實戰:研究RoboBrowser庫相關技術

1. 引言 1.1 研究背景與意義 隨著電子商務的快速發展,商品信息呈現爆炸式增長。據 Statista 數據顯示,2025 年全球電子商務銷售額預計將達到 7.4 萬億美元,海量的商品數據蘊含著巨大的商業價值。對于電商企業而言,及時獲取競爭對手的產品信息、價格動態和用戶評價,能夠幫…

JVM垃圾回收器-ZGC

一、概述 ZGC(Z Garbage Collector)是一種高效且可擴展的低延遲垃圾回收器。在垃圾回收過程中,ZGC通過優化算法和硬件支持,將Stop-The-World(STW)時間控制在一毫秒以內,使其成為追求低延遲應用…

區間動態規劃

線性 DP 的一種,簡稱為「區間 DP」。以「區間長度」劃分階段,以兩個坐標(區間的左、右端點)作為狀態的維度。一個狀態通常由被它包含且比它更小的區間狀態轉移而來。 一、概念 間 DP 的主要思想就是:先在小區間內得到…

4. 數據類型

4.1 數據類型分類 分類 數據類型 說明 數值類型 BIT(M) 位類型。M指定位數,默認值1,范圍1 - 64 TINYINT [UNSIGNED] 帶符號的范圍 -128 ~ 127,無符號范圍0 ~ 255,默認有符號 BOOL 使用0和1表示真和假 SMALLINT [UNSIGNED] 帶符號是…

設計模式-2 結構型模式

一、代理模式 1、舉例 海外代購 2、代理基本結構圖 3、靜態代理 1、真實類實現一個接口,代理類也實現這個接口。 2、代理類通過真實對象調用真實類的方法。 4、靜態代理和動態代理的區別 1、靜態代理在編譯時就已經實現了,編譯完成后代理類是一個實際…

vue+element-ui一個頁面有多個子組件組成。子組件里面有各種表單,實現點擊enter實現跳轉到下一個表單元素的功能。

一個父組件里面是有各個子組件的form表單組成的。 我想實現點擊enter。焦點直接跳轉到下一個表單元素。 父組件就是由各個子組件構成 子組件就像下圖一樣的都有個el-form的表單。 enterToTab.js let enterToTab {}; (function() {// 返回隨機數enterToTab.addEnterListener …

Open SSL 3.0相關知識以及源碼流程分析

Open SSL 3.0相關知識以及源碼流程分析 編譯 windows環境編譯1、工具安裝 安裝安裝perl腳本解釋器、安裝nasm匯編器(添加到環境變量)、Visual Studio編譯工具 安裝dmake ppm install dmake # 需要過墻2、開始編譯 # 1、找到Visual Studio命令行編譯工具目錄 或者菜單欄直接…

【Redis】筆記|第5節|Redisson實現高并發分布式鎖核心源碼

一、加鎖流程 1. 核心方法調用鏈 RLock lock redisson.getLock("resource"); lock.lock(); // 阻塞式加鎖? lockInterruptibly()? tryAcquire(-1, leaseTime, unit) // leaseTime-1表示啟用看門狗? tryAcquireAsync()? tryLockInnerAsync() // 執行Lua腳本 2…

基于React + TypeScript構建高度可定制的QR碼生成器

前言 在現代Web應用中,QR碼已成為連接線上線下的重要橋梁。本文將詳細介紹如何使用React TypeScript Vite構建一個功能強大、高度可定制的QR碼生成器,支持背景圖片、文本疊加、HTML模塊、圓角導出等高級功能。 前往試試 項目概述 技術棧 前端框架:…

【MATLAB代碼】制導——三點法,二維平面下的例程|運動目標制導,附完整源代碼

三點法制導是一種導彈制導策略,主要用于確保導彈能夠準確追蹤并擊中移動目標。該方法通過計算導彈、目標和制導站之間的相對位置關系,實現對目標的有效制導。 本文給出MATLAB下的三點法例程,模擬平面上捕獲運動目標的情況訂閱專欄后可直接查看源代碼,粘貼到MATLAB空腳本中即…