python學習打卡day52

DAY 52 神經網絡調參指南

知識點回顧:

  1. 隨機種子
  2. 內參的初始化
  3. 神經網絡調參指南
    1. 參數的分類
    2. 調參的順序
    3. 各部分參數的調整心得

作業:對于day'41的簡單cnn,看看是否可以借助調參指南進一步提高精度。

day41的簡單CNN最后的結果,今天要做的是使用調參指南中的方法進一步提高精度

?

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np# 定義通道注意力
class ChannelAttention(nn.Module):def __init__(self, in_channels, ratio=16):"""通道注意力機制初始化參數:in_channels: 輸入特征圖的通道數ratio: 降維比例,用于減少參數量,默認為16"""super().__init__()# 全局平均池化,將每個通道的特征圖壓縮為1x1,保留通道間的平均值信息self.avg_pool = nn.AdaptiveAvgPool2d(1)# 全局最大池化,將每個通道的特征圖壓縮為1x1,保留通道間的最顯著特征self.max_pool = nn.AdaptiveMaxPool2d(1)# 共享全連接層,用于學習通道間的關系# 先降維(除以ratio),再通過ReLU激活,最后升維回原始通道數self.fc = nn.Sequential(nn.Linear(in_channels, in_channels // ratio, bias=False),  # 降維層nn.ReLU(),  # 非線性激活函數nn.Linear(in_channels // ratio, in_channels, bias=False)   # 升維層)# Sigmoid函數將輸出映射到0-1之間,作為各通道的權重self.sigmoid = nn.Sigmoid()def forward(self, x):"""前向傳播函數參數:x: 輸入特征圖,形狀為 [batch_size, channels, height, width]返回:調整后的特征圖,通道權重已應用"""# 獲取輸入特征圖的維度信息,這是一種元組的解包寫法b, c, h, w = x.shape# 對平均池化結果進行處理:展平后通過全連接網絡avg_out = self.fc(self.avg_pool(x).view(b, c))# 對最大池化結果進行處理:展平后通過全連接網絡max_out = self.fc(self.max_pool(x).view(b, c))# 將平均池化和最大池化的結果相加并通過sigmoid函數得到通道權重attention = self.sigmoid(avg_out + max_out).view(b, c, 1, 1)# 將注意力權重與原始特征相乘,增強重要通道,抑制不重要通道return x * attention #這個運算是pytorch的廣播機制## 空間注意力模塊
class SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super().__init__()self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):# 通道維度池化avg_out = torch.mean(x, dim=1, keepdim=True)  # 平均池化:(B,1,H,W)max_out, _ = torch.max(x, dim=1, keepdim=True)  # 最大池化:(B,1,H,W)pool_out = torch.cat([avg_out, max_out], dim=1)  # 拼接:(B,2,H,W)attention = self.conv(pool_out)  # 卷積提取空間特征return x * self.sigmoid(attention)  # 特征與空間權重相乘## CBAM模塊
class CBAM(nn.Module):def __init__(self, in_channels, ratio=16, kernel_size=7):super().__init__()self.channel_attn = ChannelAttention(in_channels, ratio)self.spatial_attn = SpatialAttention(kernel_size)def forward(self, x):x = self.channel_attn(x)x = self.spatial_attn(x)return x
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np# 設置中文字體支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解決負號顯示問題# 檢查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用設備: {device}")# 1. 數據預處理
# 訓練集:使用多種數據增強方法提高模型泛化能力
train_transform = transforms.Compose([# 隨機裁剪圖像,從原圖中隨機截取32x32大小的區域transforms.RandomCrop(32, padding=4),# 隨機水平翻轉圖像(概率0.5)transforms.RandomHorizontalFlip(),# 隨機顏色抖動:亮度、對比度、飽和度和色調隨機變化transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),# 隨機旋轉圖像(最大角度15度)transforms.RandomRotation(15),# 將PIL圖像或numpy數組轉換為張量transforms.ToTensor(),# 標準化處理:每個通道的均值和標準差,使數據分布更合理transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])# 測試集:僅進行必要的標準化,保持數據原始特性,標準化不損失數據信息,可還原
test_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])# 2. 加載CIFAR-10數據集
train_dataset = datasets.CIFAR10(root='./data',train=True,download=True,transform=train_transform  # 使用增強后的預處理
)test_dataset = datasets.CIFAR10(root='./data',train=False,transform=test_transform  # 測試集不使用增強
)# 3. 創建數據加載器
batch_size = 80
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# 4. 定義CNN模型的定義(替代原MLP)
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()# 初始卷積層self.conv_init = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),nn.BatchNorm2d(64),nn.ReLU())# 第一卷積塊(含CBAM)self.block1 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),nn.BatchNorm2d(64),nn.ReLU(),nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),nn.BatchNorm2d(64),CBAM(64)  # 在卷積塊后添加CBAM)self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)self.drop1 = nn.Dropout2d(0.1)# 第二卷積塊(含CBAM)self.block2 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False),  # stride=2降維nn.BatchNorm2d(128),nn.ReLU(),nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),nn.BatchNorm2d(128),CBAM(128)  # 在卷積塊后添加CBAM)self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)self.drop2 = nn.Dropout2d(0.2)# 第三卷積塊(含CBAM)self.block3 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False),nn.BatchNorm2d(256),nn.ReLU(),nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),nn.BatchNorm2d(256),CBAM(256)  # 在卷積塊后添加CBAM)self.pool3 = nn.AdaptiveAvgPool2d(4)self.drop3 = nn.Dropout2d(0.3)# 全連接層self.fc = nn.Sequential(nn.Linear(256 * 4 * 4, 512),nn.BatchNorm1d(512),nn.ReLU(),nn.Dropout(0.5),nn.Linear(512, 128),nn.BatchNorm1d(128),nn.ReLU(),nn.Dropout(0.3),nn.Linear(128, 10))def forward(self, x):x = self.conv_init(x)x = self.block1(x)x = self.pool1(x)x = self.drop1(x)x = self.block2(x)x = self.pool2(x)x = self.drop2(x)x = self.block3(x)x = self.pool3(x)x = self.drop3(x)x = x.view(-1, 256 * 4 * 4)x = self.fc(x)return x# 初始化模型
model = CNN()
model = model.to(device)  # 將模型移至GPU(如果可用)
criterion = nn.CrossEntropyLoss()  # 交叉熵損失函數
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam優化器# 引入學習率調度器,在訓練過程中動態調整學習率--訓練初期使用較大的 LR 快速降低損失,訓練后期使用較小的 LR 更精細地逼近全局最優解。
# 在每個 epoch 結束后,需要手動調用調度器來更新學習率,可以在訓練過程中調用 scheduler.step()
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,        # 指定要控制的優化器(這里是Adam)mode='min',       # 監測的指標是"最小化"(如損失函數)patience=3,       # 如果連續3個epoch指標沒有改善,才降低LRfactor=0.5        # 降低LR的比例(新LR = 舊LR × 0.5)
)
# 5. 訓練模型(記錄每個 iteration 的損失)
def train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs):model.train()  # 設置為訓練模式# 記錄每個 iteration 的損失all_iter_losses = []  # 存儲所有 batch 的損失iter_indices = []     # 存儲 iteration 序號# 記錄每個 epoch 的準確率和損失train_acc_history = []test_acc_history = []train_loss_history = []test_loss_history = []# 早停相關參數best_test_acc = 0.0patience = 5  # 早停耐心值,5個epochcounter = 0   # 計數器,記錄連續未改進的epoch數early_stop = False  # 早停標志for epoch in range(epochs):running_loss = 0.0correct = 0total = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)  # 移至GPUoptimizer.zero_grad()  # 梯度清零output = model(data)  # 前向傳播loss = criterion(output, target)  # 計算損失loss.backward()  # 反向傳播optimizer.step()  # 更新參數# 記錄當前 iteration 的損失iter_loss = loss.item()all_iter_losses.append(iter_loss)iter_indices.append(epoch * len(train_loader) + batch_idx + 1)# 統計準確率和損失running_loss += iter_loss_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()# 每100個批次打印一次訓練信息if (batch_idx + 1) % 100 == 0:print(f'Epoch: {epoch+1}/{epochs} | Batch: {batch_idx+1}/{len(train_loader)} 'f'| 單Batch損失: {iter_loss:.4f} | 累計平均損失: {running_loss/(batch_idx+1):.4f}')# 計算當前epoch的平均訓練損失和準確率epoch_train_loss = running_loss / len(train_loader)epoch_train_acc = 100. * correct / totaltrain_acc_history.append(epoch_train_acc)train_loss_history.append(epoch_train_loss)# 測試階段model.eval()  # 設置為評估模式test_loss = 0correct_test = 0total_test = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()_, predicted = output.max(1)total_test += target.size(0)correct_test += predicted.eq(target).sum().item()epoch_test_loss = test_loss / len(test_loader)epoch_test_acc = 100. * correct_test / total_testtest_acc_history.append(epoch_test_acc)test_loss_history.append(epoch_test_loss)# 更新學習率調度器scheduler.step(epoch_test_loss)print(f'Epoch {epoch+1}/{epochs} 完成 | 訓練準確率: {epoch_train_acc:.2f}% | 測試準確率: {epoch_test_acc:.2f}%')# 早停檢查if epoch_test_acc > best_test_acc:best_test_acc = epoch_test_acccounter = 0# 保存最佳模型(可選)torch.save(model.state_dict(), 'best_model.pth')print(f"找到更好的模型,準確率: {best_test_acc:.2f}%,已保存")else:counter += 1print(f"早停計數器: {counter}/{patience}")if counter >= patience:print(f"早停觸發!連續 {patience} 個epoch測試準確率未提高")early_stop = True# 如果觸發早停,跳出訓練循環if early_stop:print(f"訓練在第 {epoch+1} 個epoch提前結束")break# 繪制所有 iteration 的損失曲線plot_iter_losses(all_iter_losses, iter_indices)# 繪制每個 epoch 的準確率和損失曲線plot_epoch_metrics(train_acc_history, test_acc_history, train_loss_history, test_loss_history)return epoch_test_acc  # 返回最終測試準確率# 6. 繪制每個 iteration 的損失曲線
def plot_iter_losses(losses, indices):plt.figure(figsize=(10, 4))plt.plot(indices, losses, 'b-', alpha=0.7, label='Iteration Loss')plt.xlabel('Iteration(Batch序號)')plt.ylabel('損失值')plt.title('每個 Iteration 的訓練損失')plt.legend()plt.grid(True)plt.tight_layout()plt.show()# 7. 繪制每個 epoch 的準確率和損失曲線
def plot_epoch_metrics(train_acc, test_acc, train_loss, test_loss):epochs = range(1, len(train_acc) + 1)plt.figure(figsize=(12, 4))# 繪制準確率曲線plt.subplot(1, 2, 1)plt.plot(epochs, train_acc, 'b-', label='訓練準確率')plt.plot(epochs, test_acc, 'r-', label='測試準確率')plt.xlabel('Epoch')plt.ylabel('準確率 (%)')plt.title('訓練和測試準確率')plt.legend()plt.grid(True)# 繪制損失曲線plt.subplot(1, 2, 2)plt.plot(epochs, train_loss, 'b-', label='訓練損失')plt.plot(epochs, test_loss, 'r-', label='測試損失')plt.xlabel('Epoch')plt.ylabel('損失值')plt.title('訓練和測試損失')plt.legend()plt.grid(True)plt.tight_layout()plt.show()# 8. 執行訓練和測試
epochs = 40  # 增加訓練輪次以獲得更好效果
print("開始使用CNN訓練模型...")
final_accuracy = train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs)
print(f"訓練完成!最終測試準確率: {final_accuracy:.2f}%")# # 保存模型
# torch.save(model.state_dict(), 'cifar10_cnn_model.pth')
# print("模型已保存為: cifar10_cnn_model.pth")

訓練完成!最終測試準確率: 87.04%

@浙大疏精行?

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/84737.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/84737.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/84737.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

自定義線程池 4.0

自定義線程池 4.0 1. 簡介 上次我們實現了自定義線程池的 3.1 版本,提供了線程工廠創建線程和工具類創建簡單線程池的功能,增強了線程池的靈活性,并且用起來更加方便了,本文我們將做如下的優化: 給線程池添加關閉的…

list is not in GROUPBY clause and contains nonaggregated column ‘*.*‘

SELECT list is not in GROUP BY clause and contains nonaggregated column mydb.t.address which is not functionally dependent on columns in GROUP BY clause; this is incompatible with sql_modeonly_full_group_by 關于查詢列不在分組字段內觸發錯誤 之前我一直使用其…

Linux vmware image iso qcow2鏡像大全

Download Linux VMware Images | Linux VMware Images

城市排水管網液位流量監測系統解決方案

一、方案背景 城市排水管網作為城市的“生命線”,其運行狀況直接關系到城市的防洪排澇、水環境質量以及居民的生活質量。隨著城市化進程的加速,城市排水管網規模不斷擴大,結構日益復雜,傳統的人工巡檢和簡單監測手段已難以滿足對排…

算法學習筆記:3.廣度優先搜索 (BFS)——二叉樹的層序遍歷

什么是廣度優先搜索 (BFS)? 想象一下你在玩一個迷宮游戲,你需要找到從起點到終點的最短路徑。廣度優先搜索 (BFS) 就像是你在迷宮中逐層探索的過程: 先探索距離起點最近的所有位置然后探索距離起點第二近的所有位置以此類推,直到找到終點 …

并發編程-Synchronized

Mark Word 什么是Mark Word? Mark Word是Java對象頭中的一個字段,它是一個32位或64位的字段(取決于系統架構),用于存儲對象的元數據信息。這些信息包括對象的哈希碼、鎖狀態、年齡等。 Mark Word有什么用&#xff1f…

【51單片機】5. 矩陣鍵盤與矩陣鍵盤密碼鎖Demo

1. 矩陣鍵盤原理 通過矩陣連接的模式,原本需要16個引腳連接的按鈕只需要8個引腳就能連接好,減少了I/O口的占用。 矩陣按鈕是通過掃描來讀取狀態的。 2. 掃描的概念 輸出掃描示例:數碼管掃描 原理:顯示第1位→顯示第2位→顯示第…

Android Studio jetpack compose折疊日歷日期選擇器【折疊日歷】

今天寫一個日期選擇器,大家根據自己需求改代碼,記得點贊支持,謝謝~ 這是進入的默認狀態 折疊狀態選中本周其他日期狀態 切換上下周狀態 展開日歷狀態 切換上下月狀態 選中狀態 代碼如下: import android.content.C…

馭碼CodeRider 2.0全棧開發實戰指南:從零構建現代化電商平臺

馭碼CodeRider 2.0全棧開發實戰指南:從零構建現代化電商平臺 一、CodeRider 2.0:重新定義全棧智能開發 1.1 革命性升級亮點 #mermaid-svg-AKjytNB4hD95UZtF {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-AKjyt…

大模型智能體AutoGen面試題及參考答案

目錄 AutoGen 的核心是什么? Agent 在 AutoGen 中承擔什么角色? AutoGen 是如何定義 AssistantAgent、UserProxyAgent 等代理類型的? 什么是 GroupChat(組對話)模式? AutoGen 的 system message 在框架中扮演什么作用? 如何通過 Agent 實現自然語言處理? AutoGen…

深度學習筆記26-天氣預測(Tensorflow)

🍨 本文為🔗365天深度學習訓練營中的學習記錄博客🍖 原作者:K同學啊 一、前期準備 1.數據導入 import numpy as np import pandas as pd import warnings import seaborn as sns import matplotlib.pyplot as plt warnings.filt…

day54 python對抗生成網絡

目錄 一、GAN對抗生成網絡思想 二、實踐過程 1. 數據準備 2. 構建生成器和判別器 3. 訓練過程 4. 生成結果與可視化 三、學習總結 一、GAN對抗生成網絡思想 GAN的核心思想非常有趣且富有對抗性。它由兩部分組成:生成器(Generator)和判…

龍虎榜——20250613

上證指數放量下跌收陰線,個股下跌超4000只,受外圍消息影響情緒總體較差。 深證指數放量下跌,收陰線,6月總體外圍風險較高,轉下跌走勢的概率較大,注意風險。 2025年6月13日龍虎榜行業方向分析 1. 石油石化&…

Linux常用命令加強版替代品

Linux常用命令加強版替代品 還在日復一日地使用 ls、grep、cd 這些“上古”命令嗎?是時候給你的終端來一次大升級了!本文將為你介紹一系列強大、高效且設計現代的Linux命令行工具,它們將徹底改變你的工作流,讓你愛上在終端里操作…

Hadoop 003 — JAVA操作MapReduce入門案例

MapReduce入門案例-分詞統計 文章目錄 MapReduce入門案例-分詞統計1.xml依賴2.編寫MapReduce處理邏輯3.上傳統計文件到HDFS3.配置MapReduce作業并測試4.執行結果 1.xml依賴 <dependency><groupId>org.apache.hadoop</groupId><artifactId>hadoop-commo…

Python打卡第53天

浙大疏錦行 作業&#xff1a; 對于心臟病數據集&#xff0c;對于病人這個不平衡的樣本用GAN來學習并生成病人樣本&#xff0c;觀察不用GAN和用GAN的F1分數差異。 import pandas as pd import numpy as np import torch import torch.nn as nn import torch.optim as optim from…

力扣-279.完全平方數

題目描述 給你一個整數 n &#xff0c;返回 和為 n 的完全平方數的最少數量 。 完全平方數 是一個整數&#xff0c;其值等于另一個整數的平方&#xff1b;換句話說&#xff0c;其值等于一個整數自乘的積。例如&#xff0c;1、4、9 和 16 都是完全平方數&#xff0c;而 3 和 1…

前端構建工具Webapck、Vite——>前沿字節開源Rspack詳解——2023D2大會

Rspack 以下是針對主流構建工具&#xff08;Webpack、Vite、Rollup、esbuild&#xff09;的核心不足分析&#xff0c;以及 Rspack 如何基于這些痛點進行針對性改進 的深度解析&#xff1a; 一、主流構建工具的不足 1. Webpack&#xff1a;性能與生態的失衡 核心問題 冷啟動慢…

輸入法,開頭輸入這U I V 三個字母會不顯示 任何中文

1. 漢語拼音規則的限制 漢語拼音中不存在以“V”“U”“I”為聲母的情況&#xff1a; 漢語拼音的聲母是輔音&#xff0c;而“V”“U”“I”在漢語拼音中都是元音&#xff08;或韻母的一部分&#xff09;。漢語拼音的聲母系統中沒有“V”“U”“I”作為聲母的音節。例如&#xf…

Linux文件權限詳解:從入門到精通

前言 權限是什么&#xff1f; 本質&#xff1a;無非就是能做和不能做什么。 為什么要有權限呢&#xff1f; 目的&#xff1a;為了控制用戶行為&#xff0c;防止發生錯誤。 1.權限的理解 在學習下面知識之前要先知道的一點是&#xff1a;linux下一切皆文件&#xff0c;對li…