Python訓練營---Day44

DAY 44 預訓練模型

知識點回顧:

  1. 預訓練的概念
  2. 常見的分類預訓練模型
  3. 圖像預訓練模型的發展史
  4. 預訓練的策略
  5. 預訓練代碼實戰:resnet18

作業:

  1. 嘗試在cifar10對比如下其他的預訓練模型,觀察差異,盡可能和他人選擇的不同
  2. 嘗試通過ctrl進入resnet的內部,觀察殘差究竟是什么

選用?DenseNet121預訓練模型,注意DenseNet121 模型的最后分類層名為classifier,而不是 ResNet 中的fc

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os
from torchvision.models import resnet18, densenet121, vgg16# 設置中文字體支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解決負號顯示問題# 檢查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用設備: {device}")# 1. 數據預處理(訓練集增強,測試集標準化)
train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),transforms.RandomRotation(15),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])test_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])# 2. 加載CIFAR-10數據集
train_dataset = datasets.CIFAR10(root='./cifar_data',train=True,download=True,transform=train_transform
)test_dataset = datasets.CIFAR10(root='./cifar_data',train=False,transform=test_transform
)# 3. 創建數據加載器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# 4. 定義DenseNet121模型
def create_densenet121(pretrained=True, num_classes=10):model = models.densenet121(pretrained=pretrained)# 修改最后一層全連接層in_features = model.classifier.in_featuresmodel.classifier = nn.Linear(in_features, num_classes) # DenseNet121 的最后一層分類器名稱是classifierreturn model.to(device)# 5. 凍結/解凍模型層的函數
# 這種設計允許我們在遷移學習中保留預訓練模型的特征提取部分(卷積層),只訓練新添加的分類層(全連接層)。
def freeze_model(model, freeze=True):"""凍結或解凍模型的卷積層參數"""# 凍結/解凍除fc層外的所有參數for name, param in model.named_parameters():if 'classifier' not in name:    #排除名稱中包含 "fc" 的參數,這些通常是全連接層的參數param.requires_grad = not freeze    #param.requires_grad是 PyTorch 中控制參數是否參與反向傳播和梯度更新的標志# 打印凍結狀態frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)   #統計所有requires_grad=False的參數數量total_params = sum(p.numel() for p in model.parameters())if freeze:print(f"已凍結模型卷積層參數 ({frozen_params}/{total_params} 參數)")else:print(f"已解凍模型所有參數 ({total_params}/{total_params} 參數可訓練)")return model# 6. 訓練函數(支持階段式訓練)
def train_with_freeze_schedule(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs, freeze_epochs=5):"""前freeze_epochs輪凍結卷積層,之后解凍所有層進行訓練"""train_loss_history = []test_loss_history = []train_acc_history = []test_acc_history = []all_iter_losses = []iter_indices = []# 初始凍結卷積層if freeze_epochs > 0:model = freeze_model(model, freeze=True)for epoch in range(epochs):# 解凍控制:在指定輪次后解凍所有層if epoch == freeze_epochs:model = freeze_model(model, freeze=False)# 解凍后調整優化器(可選)optimizer.param_groups[0]['lr'] = 1e-4  # 降低學習率防止過擬合model.train()  # 設置為訓練模式running_loss = 0.0correct_train = 0total_train = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()# 記錄Iteration損失iter_loss = loss.item()all_iter_losses.append(iter_loss)iter_indices.append(epoch * len(train_loader) + batch_idx + 1)# 統計訓練指標running_loss += iter_loss_, predicted = output.max(1)total_train += target.size(0)correct_train += predicted.eq(target).sum().item()# 每100批次打印進度if (batch_idx + 1) % 100 == 0:print(f"Epoch {epoch+1}/{epochs} | Batch {batch_idx+1}/{len(train_loader)} "f"| 單Batch損失: {iter_loss:.4f}")# 計算 epoch 級指標epoch_train_loss = running_loss / len(train_loader)epoch_train_acc = 100. * correct_train / total_train# 測試階段model.eval()correct_test = 0total_test = 0test_loss = 0.0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()_, predicted = output.max(1)total_test += target.size(0)correct_test += predicted.eq(target).sum().item()epoch_test_loss = test_loss / len(test_loader)epoch_test_acc = 100. * correct_test / total_test# 記錄歷史數據train_loss_history.append(epoch_train_loss)test_loss_history.append(epoch_test_loss)train_acc_history.append(epoch_train_acc)test_acc_history.append(epoch_test_acc)# 更新學習率調度器if scheduler is not None:scheduler.step(epoch_test_loss)# 打印 epoch 結果print(f"Epoch {epoch+1} 完成 | 訓練損失: {epoch_train_loss:.4f} "f"| 訓練準確率: {epoch_train_acc:.2f}% | 測試準確率: {epoch_test_acc:.2f}%")# 繪制損失和準確率曲線plot_iter_losses(all_iter_losses, iter_indices)plot_epoch_metrics(train_acc_history, test_acc_history, train_loss_history, test_loss_history)return epoch_test_acc  # 返回最終測試準確率# 7. 繪制Iteration損失曲線
def plot_iter_losses(losses, indices):plt.figure(figsize=(10, 4))plt.plot(indices, losses, 'b-', alpha=0.7)plt.xlabel('Iteration(Batch序號)')plt.ylabel('損失值')plt.title('訓練過程中的Iteration損失變化')plt.grid(True)plt.show()# 8. 繪制Epoch級指標曲線
def plot_epoch_metrics(train_acc, test_acc, train_loss, test_loss):epochs = range(1, len(train_acc) + 1)plt.figure(figsize=(12, 5))# 準確率曲線plt.subplot(1, 2, 1)plt.plot(epochs, train_acc, 'b-', label='訓練準確率')plt.plot(epochs, test_acc, 'r-', label='測試準確率')plt.xlabel('Epoch')plt.ylabel('準確率 (%)')plt.title('準確率隨Epoch變化')plt.legend()plt.grid(True)# 損失曲線plt.subplot(1, 2, 2)plt.plot(epochs, train_loss, 'b-', label='訓練損失')plt.plot(epochs, test_loss, 'r-', label='測試損失')plt.xlabel('Epoch')plt.ylabel('損失值')plt.title('損失值隨Epoch變化')plt.legend()plt.grid(True)plt.tight_layout()plt.show()# 主函數:訓練模型
def main():# 參數設置epochs = 40  # 總訓練輪次freeze_epochs = 5  # 凍結卷積層的輪次learning_rate = 1e-3  # 初始學習率weight_decay = 1e-4  # 權重衰減# 創建DenseNet121模型(加載預訓練權重)model = create_densenet121(pretrained=True, num_classes=10)# 定義優化器和損失函數optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)criterion = nn.CrossEntropyLoss()# 定義學習率調度器scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)# 開始訓練(前5輪凍結卷積層,之后解凍)final_accuracy = train_with_freeze_schedule(model=model,train_loader=train_loader,test_loader=test_loader,criterion=criterion,optimizer=optimizer,scheduler=scheduler,device=device,epochs=epochs,freeze_epochs=freeze_epochs)print(f"訓練完成!最終測試準確率: {final_accuracy:.2f}%")# # 保存模型# torch.save(model.state_dict(), 'resnet18_cifar10_finetuned.pth')# print("模型已保存至: resnet18_cifar10_finetuned.pth")if __name__ == "__main__":main()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/908347.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/908347.shtml
英文地址,請注明出處:http://en.pswp.cn/news/908347.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

1.文件操作相關的庫

一、filesystem(C17) 和 fstream 1.std::filesystem::path - cppreference.cn - C參考手冊 std::filesystem::path 表示路徑 構造函數: path( string_type&& source, format fmt auto_format ); 可以用string進行構造,也可以用string進行隱式類…

【 java 集合知識 第二篇 】

目錄 1.Map集合 1.1.快速遍歷Map 1.2.HashMap實現原理 1.3.HashMap的擴容機制 1.4.HashMap在多線程下的問題 1.5.解決哈希沖突的方法 1.6.HashMap的put過程 1.7.HashMap的key使用什么類型 1.8.HashMapkey可以為null的原因 1.9.HashMap為什么不采用平衡二叉樹 1.10.Hash…

【Dify 知識庫 API】“根據文本更新文檔” 真的是差異更新嗎?一文講透真實機制!

在使用 Dify 知識庫 API 過程中,很多開發者在調用 /datasets/{dataset_id}/document/update-by-text 接口時,常常會產生一個疑問: ?? 這個接口到底是 “智能差異更新” 還是 “純覆蓋更新”? 網上的資料并不多,很多人根據接口名誤以為是增量更新。今天我結合官方源碼 …

大模型如何革新用戶價值、內容匹配與ROI預估

寫在前面 在數字營銷的戰場上,理解用戶、精準觸達、高效轉化是永恒的追求。傳統方法依賴結構化數據和機器學習模型,在用戶價值評估、人群素材匹配以及策略ROI預估等核心問題上取得了顯著成就。然而,隨著數據維度日益復雜,用戶行為愈發多變,傳統方法也面臨著特征工程繁瑣、…

基于端到端深度學習模型的語音控制人機交互系統

基于端到端深度學習模型的語音控制人機交互系統 摘要 本文設計并實現了一個基于端到端深度學習模型的人機交互系統,通過語音指令控制其他設備的程序運行,并將程序運行結果通過語音合成方式反饋給用戶。系統采用Python語言開發,使用PyTorch框架實現端到端的語音識別(ASR)…

【2025年】解決Burpsuite抓不到https包的問題

環境:windows11 burpsuite:2025.5 在抓取https網站時,burpsuite抓取不到https數據包,只顯示: 解決該問題只需如下三個步驟: 1、瀏覽器中訪問 http://burp 2、下載 CA certificate 證書 3、在設置--隱私與安全--…

Jenkins 工作流程

1. 觸發構建 Jenkins 的工作流程從觸發構建開始。構建可以由以下幾種方式觸發: 代碼提交觸發:通過與版本控制系統(如 Git、SVN)集成,當代碼倉庫有新的提交時,Jenkins 會自動觸發構建。 定時觸發&#xff…

Jmeter如何進行多服務器遠程測試?

🍅 點擊文末小卡片 ,免費獲取軟件測試全套資料,資料在手,漲薪更快 JMeter是Apache軟件基金會的開源項目,主要來做功能和性能測試,用Java編寫。 我們一般都會用JMeter在本地進行測試,但是受到…

Kafka入門-生產者

生產者 生產者發送流程: 延遲時間為0ms時,也就意味著每當有數據就會直接發送 異步發送API 異步發送和同步發送的不同在于:異步發送不需要等待結果,同步發送必須等待結果才能進行下一步發送。 普通異步發送 首先導入所需的k…

分類預測 | Matlab實現CNN-LSTM-Attention高光譜數據分類

分類預測 | Matlab實現CNN-LSTM-Attention高光譜數據分類 目錄 分類預測 | Matlab實現CNN-LSTM-Attention高光譜數據分類分類效果功能概述程序設計參考資料 分類效果 功能概述 代碼功能 該MATLAB代碼實現了一個結合CNN、LSTM和注意力機制的高光譜數據分類模型,核心…

gemini和chatgpt數據對比:誰在卷性能、價格和場景?

先把結論“劇透”給趕時間的朋友:頂配 Gemini Ultra/2.5 Pro 在紙面成績上普遍領先,而 ChatGPT 家族(GPT-4o / o3 / 4.1)則在延遲、生態和穩定性上占優。下面把核心數據拆開講,方便你對號入座。附帶參考來源&#xff0…

代碼訓練LeetCode(23)隨機訪問元素

代碼訓練(23)LeetCode之隨機訪問元素 Author: Once Day Date: 2025年6月5日 漫漫長路,才剛剛開始… 全系列文章可參考專欄: 十年代碼訓練_Once-Day的博客-CSDN博客 參考文章: 380. O(1) 時間插入、刪除和獲取隨機元素 - 力扣(LeetCode)力…

C++面試5——對象存儲區域詳解

C++對象存儲區域詳解 核心觀點:內存是程序員的戰場,存儲區域決定對象的生殺大權!棧對象自動赴死,堆對象生死由你,全局對象永生不死,常量區對象只讀不滅。 一、四大地域生死簿 棧區(Stack) ? 特點:自動分配釋放,速度極快(類似高鐵進出站) ? 生存期:函數大括號{}就…

STM32 智能小車項目 L298N 電機驅動模塊

今天開始著手做智能小車的項目了 在智能小車或機器人項目中,我們經常會聽到一個詞叫 “H 橋電機驅動”,尤其是常見的 L298N 模塊,就是基于“雙 H 橋”原理設計的。那么,“H 橋”到底是什么?為什么要用“雙 H 橋”來驅動…

python項目如何創建docker環境

這里寫自定義目錄標題 python項目創建docker環境docker配置國內鏡像源構建一個Docker 鏡像驗證鏡像合理的創建標題,有助于目錄的生成如何改變文本的樣式插入鏈接與圖片如何插入一段漂亮的代碼片生成一個適合你的列表創建一個表格設定內容居中、居左、居右SmartyPant…

MySQL-多表關系、多表查詢

一. 一對多(多對一) 1. 例如;一個部門下有多個員工 在數據庫表中多的一方(員工表)、添加字段,來關聯一的一方(部門表)的主鍵 二. 外鍵約束 1.如將部門表的部門直接刪除,然而員工表還存在其部門下的員工,出現了數據的不一致問題&am…

【 HarmonyOS 5 入門系列 】鴻蒙HarmonyOS示例項目講解

【 HarmonyOS 5 入門系列 】鴻蒙HarmonyOS示例項目講解 一、前言:移動開發聲明式 UI 框架的技術變革 在移動操作系統的發展歷程中,UI 開發模式經歷了從命令式到聲明式的重大變革。 根據華為開發者聯盟 2024 年數據報告顯示,HarmonyOS 設備…

【SSM】SpringMVC學習筆記7:前后端數據傳輸協議和異常處理

這篇學習筆記是Spring系列筆記的第7篇,該筆記是筆者在學習黑馬程序員SSM框架教程課程期間的筆記,供自己和他人參考。 Spring學習筆記目錄 筆記1:【SSM】Spring基礎: IoC配置學習筆記-CSDN博客 對應黑馬課程P1~P20的內容。 筆記2…

借助 Spring AI 和 LM Studio 為業務系統引入本地 AI 能力

Spring AI 1.0.0-SNAPSHOTLM Studio 0.3.16qwen3-4b 參考 Unable to use spring ai with LMStudio using spring-ai openai module Issue #2441 spring-projects/spring-ai GitHub LM Studio 下載安裝 LM Studio下載 qwen3-4b 模型。對于 qwen3 系列模型,測試…

C++學習-入門到精通【13】標準庫的容器和迭代器

C學習-入門到精通【13】標準庫的容器和迭代器 目錄 C學習-入門到精通【13】標準庫的容器和迭代器一、標準模板庫簡介1.容器簡介2.STL容器總覽3.近容器4.STL容器的通用函數5.首類容器的通用typedef6.對容器元素的要求 二、迭代器簡介1.使用istream_iterator輸入,使用…