Pytorch | 從零構建EfficientNet對CIFAR10進行分類

Pytorch | 從零構建EfficientNet對CIFAR10進行分類

  • CIFAR10數據集
  • EfficientNet
    • 設計理念
    • 網絡結構
    • 性能特點
    • 應用領域
    • 發展和改進
  • EfficientNet結構代碼詳解
    • 結構代碼
    • 代碼詳解
      • MBConv 類
        • 初始化方法
        • 前向傳播 forward 方法
      • EfficientNet 類
        • 初始化方法
        • 前向傳播 forward 方法
  • 訓練過程和測試結果
  • 代碼匯總
    • efficientnet.py
    • train.py
    • test.py

前面文章我們構建了AlexNet、Vgg、GoogleNet、ResNet、MobileNet對CIFAR10進行分類:
Pytorch | 從零構建AlexNet對CIFAR10進行分類
Pytorch | 從零構建Vgg對CIFAR10進行分類
Pytorch | 從零構建GoogleNet對CIFAR10進行分類
Pytorch | 從零構建ResNet對CIFAR10進行分類
Pytorch | 從零構建MobileNet對CIFAR10進行分類
這篇文章我們來構建EfficientNet.

CIFAR10數據集

CIFAR-10數據集是由加拿大高級研究所(CIFAR)收集整理的用于圖像識別研究的常用數據集,基本信息如下:

  • 數據規模:該數據集包含60,000張彩色圖像,分為10個不同的類別,每個類別有6,000張圖像。通常將其中50,000張作為訓練集,用于模型的訓練;10,000張作為測試集,用于評估模型的性能。
  • 圖像尺寸:所有圖像的尺寸均為32×32像素,這相對較小的尺寸使得模型在處理該數據集時能夠相對快速地進行訓練和推理,但也增加了圖像分類的難度。
  • 類別內容:涵蓋了飛機(plane)、汽車(car)、鳥(bird)、貓(cat)、鹿(deer)、狗(dog)、青蛙(frog)、馬(horse)、船(ship)、卡車(truck)這10個不同的類別,這些類別都是現實世界中常見的物體,具有一定的代表性。

下面是一些示例樣本:
在這里插入圖片描述

EfficientNet

EfficientNet是由谷歌大腦團隊在2019年提出的一種高效的卷積神經網絡架構,在圖像分類等任務上展現出了卓越的性能和效率,以下是對它的詳細介紹:

設計理念

  • 平衡模型的深度、寬度和分辨率:傳統的神經網絡在提升性能時,往往只是單純地增加網絡的深度、寬度或輸入圖像的分辨率,而EfficientNet則通過一種系統的方法,同時對這三個維度進行優化調整,以達到在計算資源有限的情況下,模型性能的最大化。
    在這里插入圖片描述

網絡結構

  • MBConv模塊:EfficientNet的核心模塊是MBConv(Mobile Inverted Residual Bottleneck),它基于深度可分離卷積和倒置殘差結構。這種結構在減少計算量的同時,能夠有效提取圖像特征,提高模型的表示能力。
  • Compound Scaling方法:使用該方法對網絡的深度、寬度和分辨率進行統一縮放。通過一個固定的縮放系數,同時調整這三個維度,使得模型在不同的計算資源限制下,都能保持較好的性能和效率平衡。
    在這里插入圖片描述
    上圖中是EfficientNet-B0的結構.

性能特點

  • 高效性:在相同的計算資源下,EfficientNet能夠取得比傳統網絡更好的性能。例如,與ResNet-50相比,EfficientNet-B0在ImageNet數據集上取得了相近的準確率,但參數量和計算量卻大大減少。
  • 可擴展性:通過Compound Scaling方法,可以方便地調整模型的大小,以適應不同的應用場景和計算資源限制。從EfficientNet-B0到EfficientNet-B7,模型的復雜度逐漸增加,性能也相應提升,能夠滿足從移動端到服務器端的不同需求。

應用領域

  • 圖像分類:在ImageNet等大規模圖像分類數據集上,EfficientNet取得了領先的性能,成為圖像分類任務的首選模型之一。
  • 目標檢測:與Faster R-CNN等目標檢測框架結合,EfficientNet作為骨干網絡,能夠提高目標檢測的準確率和速度,在Pascal VOC、COCO等數據集上取得了不錯的效果。
  • 語義分割:在語義分割任務中,EfficientNet也展現出了一定的優勢,通過與U-Net等分割網絡結合,能夠對圖像進行像素級的分類,在Cityscapes等數據集上有較好的表現。

發展和改進

  • EfficientNet v2:在EfficientNet基礎上進行了進一步優化,主要改進包括改進了漸進式學習的方法,在訓練過程中逐漸增加圖像的分辨率,使得模型能夠更好地學習到不同尺度的特征,同時優化了網絡結構,提高了模型的訓練速度和性能。
  • 其他改進:研究人員還在EfficientNet的基礎上,結合其他技術如注意力機制、知識蒸餾等,進一步提高模型的性能和泛化能力。

EfficientNet結構代碼詳解

結構代碼

import torch
import torch.nn as nn
import torch.nn.functional as Fclass MBConv(nn.Module):def __init__(self, in_channels, out_channels, expand_ratio, kernel_size, stride, padding, se_ratio=0.25):super(MBConv, self).__init__()self.stride = strideself.use_res_connect = (stride == 1 and in_channels == out_channels)# 擴展通道數(如果需要)expanded_channels = in_channels * expand_ratioself.expand_conv = nn.Conv2d(in_channels, expanded_channels, kernel_size=1, padding=0, bias=False)self.bn1 = nn.BatchNorm2d(expanded_channels)# 深度可分離卷積self.depthwise_conv = nn.Conv2d(expanded_channels, expanded_channels, kernel_size=kernel_size,stride=stride, padding=padding, groups=expanded_channels, bias=False)self.bn2 = nn.BatchNorm2d(expanded_channels)# 壓縮和激勵(SE)模塊(可選,根據se_ratio判斷是否添加)if se_ratio > 0:se_channels = int(in_channels * se_ratio)self.se = nn.Sequential(nn.AdaptiveAvgPool2d(1),nn.Conv2d(expanded_channels, se_channels, kernel_size=1),nn.ReLU(inplace=True),nn.Conv2d(se_channels, expanded_channels, kernel_size=1),nn.Sigmoid())else:self.se = None# 投影卷積,恢復到輸出通道數self.project_conv = nn.Conv2d(expanded_channels, out_channels, kernel_size=1, padding=0, bias=False)self.bn3 = nn.BatchNorm2d(out_channels)def forward(self, x):identity = x# 擴展通道數out = F.relu(self.bn1(self.expand_conv(x)))# 深度可分離卷積out = F.relu(self.bn2(self.depthwise_conv(out)))# SE模塊操作(如果存在)if self.se is not None:se_out = self.se(out)out = out * se_out# 投影卷積out = self.bn3(self.project_conv(out))# 殘差連接(如果滿足條件)if self.use_res_connect:out += identityreturn outclass EfficientNet(nn.Module):def __init__(self, num_classes, width_coefficient=1.0, depth_coefficient=1.0, dropout_rate=0.2):super(EfficientNet, self).__init__()self.stem_conv = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(32)mbconv_config = [# (in_channels, out_channels, expand_ratio, kernel_size, stride, padding)(32, 16, 1, 3, 1, 1),(16, 24, 6, 3, 2, 1),(24, 40, 6, 5, 2, 2),(40, 80, 6, 3, 2, 1),(80, 112, 6, 5, 1, 2),(112, 192, 6, 5, 2, 2),(192, 320, 6, 3, 1, 1)]# 根據深度系數調整每個MBConv模塊的重復次數,這里簡單地向下取整,你也可以根據實際情況采用更合理的方式repeat_counts = [max(1, int(depth_coefficient * 1)) for _ in mbconv_config]layers = []for i, config in enumerate(mbconv_config):in_channels, out_channels, expand_ratio, kernel_size, stride, padding = configfor _ in range(repeat_counts[i]):layers.append(MBConv(int(in_channels * width_coefficient),int(out_channels * width_coefficient),expand_ratio, kernel_size, stride, padding))self.mbconv_layers = nn.Sequential(*layers)self.last_conv = nn.Conv2d(int(320 * width_coefficient), 1280, kernel_size=1, bias=False)self.bn2 = nn.BatchNorm2d(1280)self.avgpool = nn.AdaptiveAvgPool2d(1)self.dropout = nn.Dropout(dropout_rate)self.fc = nn.Linear(1280, num_classes)def forward(self, x):out = F.relu(self.bn1(self.stem_conv(x)))out = self.mbconv_layers(out)out = F.relu(self.bn2(self.last_conv(out)))out = self.avgpool(out)out = out.view(out.size(0), -1)out = self.dropout(out)out = self.fc(out)return out

代碼詳解

以下是對上述EfficientNet代碼的詳細解釋,代碼整體定義了EfficientNet網絡結構,主要由MBConv模塊堆疊以及一些常規的卷積、池化和全連接層構成,下面按照類和方法分別進行剖析:

MBConv 類

這是EfficientNet中的核心模塊,實現了MBConv(Mobile Inverted Residual Bottleneck Convolution)結構,其代碼如下:

class MBConv(nn.Module):def __init__(self, in_channels, out_channels, expand_ratio, kernel_size, stride, padding, se_ratio=0.25):super(MBConv, self).__init__()self.stride = strideself.use_res_connect = (stride == 1 and in_channels == out_channels)# 擴展通道數(如果需要)expanded_channels = in_channels * expand_ratioself.expand_conv = nn.Conv2d(in_channels, expanded_channels, kernel_size=1, padding=0, bias=False)self.bn1 = nn.BatchNorm2d(expanded_channels)# 深度可分離卷積self.depthwise_conv = nn.Conv2d(expanded_channels, expanded_channels, kernel_size=kernel_size,stride=stride, padding=padding, groups=expanded_channels, bias=False)self.bn2 = nn.BatchNorm2d(expanded_channels)# 壓縮和激勵(SE)模塊(可選,根據se_ratio判斷是否添加)if se_ratio > 0:se_channels = int(in_channels * se_ratio)self.se = nn.Sequential(nn.AdaptiveAvgPool2d(1),nn.Conv2d(expanded_channels, se_channels, kernel_size=1),nn.ReLU(inplace=True),nn.Conv2d(se_channels, expanded_channels, kernel_size=1),nn.Sigmoid())else:self.se = None# 投影卷積,恢復到輸出通道數self.project_conv = nn.Conv2d(expanded_channels, out_channels, kernel_size=1, padding=0, bias=False)self.bn3 = nn.BatchNorm2d(out_channels)def forward(self, x):identity = x# 擴展通道數out = F.relu(self.bn1(self.expand_conv(x)))# 深度可分離卷積out = F.relu(self.bn2(self.depthwise_conv(out)))# SE模塊操作(如果存在)if self.se is not None:se_out = self.se(out)out = out * se_out# 投影卷積out = self.bn3(self.project_conv(out))# 殘差連接(如果滿足條件)if self.use_res_connect:out += identityreturn out
初始化方法
  • 參數說明
    • in_channels:輸入張量的通道數。
    • out_channels:輸出張量的通道數。
    • expand_ratio:用于確定擴展通道數時的比例系數,決定是否對輸入通道數進行擴展以及擴展的倍數。
    • kernel_size:卷積核的大小,用于深度可分離卷積等操作。
    • stride:卷積的步長,控制特征圖在卷積過程中的下采樣程度等。
    • padding:卷積操作時的填充大小,保證輸入輸出特征圖尺寸在特定要求下的一致性等。
    • se_ratio(可選,默認值為0.25):用于控制壓縮和激勵(SE)模塊中通道壓縮的比例,若為0則不添加SE模塊。
  • 初始化操作
    • 首先保存傳入的stride參數,并根據stride和輸入輸出通道數判斷是否使用殘差連接(use_res_connect),只有當步長為1且輸入輸出通道數相等時才使用殘差連接,這符合殘差網絡的基本原理,有助于梯度傳播和特征融合。
    • 根據expand_ratio計算擴展后的通道數expanded_channels,并創建expand_conv卷積層用于擴展通道數,同時搭配對應的bn1批歸一化層,對擴展后的特征進行歸一化處理,有助于加速網絡收斂和提升模型穩定性。
    • 定義depthwise_conv深度可分離卷積層,其分組數設置為expanded_channels,意味著每個通道單獨進行卷積操作,這種方式可以在減少計算量的同時保持較好的特征提取能力,同時搭配bn2批歸一化層。
    • 根據se_ratio判斷是否添加壓縮和激勵(SE)模塊。如果se_ratio大于0,則創建一個包含自適應平均池化、卷積、激活函數(ReLU)、卷積和Sigmoid激活的序列模塊se,用于對特征進行通道維度上的重加權,增強模型對不同通道特征的關注度;若se_ratio為0,則將se設為None
    • 最后創建project_conv投影卷積層用于將擴展和處理后的特征恢復到指定的輸出通道數,并搭配bn3批歸一化層。
前向傳播 forward 方法
  • 首先將輸入張量x保存為identity,用于后續可能的殘差連接。
  • 通過F.relu(self.bn1(self.expand_conv(x)))對輸入進行通道擴展,并使用ReLU激活函數和批歸一化進行處理,得到擴展后的特征表示。
  • 接著對擴展后的特征執行深度可分離卷積操作F.relu(self.bn2(self.depthwise_conv(out))),同樣使用ReLU激活和批歸一化處理。
  • 如果存在SE模塊(self.se不為None),則將經過深度可分離卷積后的特征傳入SE模塊進行通道重加權,即se_out = self.se(out),然后將特征與重加權后的結果相乘out = out * se_out
  • 通過self.bn3(self.project_conv(out))進行投影卷積操作,將特征恢復到輸出通道數,并進行批歸一化處理。
  • 最后,如果滿足殘差連接條件(self.use_res_connectTrue),則將投影卷積后的特征與最初保存的輸入特征identity相加,實現殘差連接,最終返回處理后的特征張量。

EfficientNet 類

這是整體的EfficientNet網絡模型類,代碼如下:

class EfficientNet(nn.Module):def __init__(self, num_classes, width_coefficient=1.0, depth_coefficient=1.0, dropout_rate=0.2):super(EfficientNet, self).__init__()self.stem_conv = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(32)mbconv_config = [# (in_channels, out_channels, expand_ratio, kernel_size, stride, padding)(32, 16, 1, 3, 1, 1),(16, 24, 6, 3, 2, 1),(24, 40, 6, 5, 2, 2),(40, 80, 6, 3, 2, 1),(80, 112, 6, 5, 1, 2),(112, 192, 6, 5, 2, 2),(192, 320, 6, 3, 1, 1)]# 根據深度系數調整每個MBConv模塊的重復次數,這里簡單地向下取整,你也可以根據實際情況采用更合理的方式repeat_counts = [max(1, int(depth_coefficient * 1)) for _ in mbconv_config]layers = []for i, config in enumerate(mbconv_config):in_channels, out_channels, expand_ratio, kernel_size, stride, padding = configfor _ in range(repeat_counts[i]):layers.append(MBConv(int(in_channels * width_coefficient),int(out_channels * width_coefficient),expand_ratio, kernel_size, stride, padding))self.mbconv_layers = nn.Sequential(*layers)self.last_conv = nn.Conv2d(int(320 * width_coefficient), 1280, kernel_size=1, bias=False)self.bn2 = nn.BatchNorm2d(1280)self.avgpool = nn.AdaptiveAvgPool2d(1)self.dropout = nn.Dropout(dropout_rate)self.fc = nn.Linear(1280, num_classes)def forward(self, x):out = F.relu(self.bn1(self.stem_conv(x)))out = self.mbconv_layers(out)out = F.relu(self.bn2(self.last_conv(out)))out = self.avgpool(out)out = out.view(out.size(0), -1)out = self.dropout(out)out = self.fc(out)return out
初始化方法
  • 參數說明
    • num_classes:最終分類任務的類別數量,用于確定全連接層的輸出維度。
    • width_coefficient(默認值為1.0):用于控制模型各層的通道數,實現對模型寬度的縮放調整。
    • depth_coefficient(默認值為1.0):用于控制模型中MBConv模塊的重復次數,實現對模型深度的縮放調整。
    • dropout_rate(默認值為0.2):在全連接層之前使用的Dropout概率,用于防止過擬合。
  • 初始化操作
    • 首先創建stem_conv卷積層,它將輸入的圖像數據(通常通道數為3,對應RGB圖像)進行初始的卷積操作,步長為2,起到下采樣的作用,同時不使用偏置(bias=False),并搭配bn1批歸一化層對卷積后的特征進行歸一化處理。
    • 定義mbconv_config列表,其中每個元素是一個元組,包含了MBConv模塊的各項配置參數,如輸入通道數、輸出通道數、擴展比例、卷積核大小、步長和填充等,這是構建MBConv模塊的基礎配置信息。
    • 根據depth_coefficient計算每個MBConv模塊的重復次數,通過列表推導式 repeat_counts = [max(1, int(depth_coefficient * 1)) for _ in mbconv_config] 實現,這里簡單地將每個配置對應的重復次數設置為與depth_coefficient成比例(向下取整且保證至少重復1次),你可以根據更精細的設計規則來調整這個計算方式。
    • 構建self.mbconv_layers,通過兩層嵌套循環實現。外層循環遍歷mbconv_config配置列表,內層循環根據對應的重復次數來多次添加同一個MBConv模塊實例到layers列表中,最后將layers列表轉換為nn.Sequential類型的模塊,這樣就實現了根據depth_coefficient對模型深度進行調整以及MBConv模塊的堆疊搭建。
    • 創建last_conv卷積層,它將經過MBConv模塊處理后的特征進行進一步的卷積操作,將通道數轉換為1280,同樣不使用偏置,搭配bn2批歸一化層。
    • 定義avgpool自適應平均池化層,將特征圖轉換為固定大小(這里為1x1),方便后續全連接層處理。
    • 創建dropout Dropout層,按照指定的dropout_rate在全連接層之前進行隨機失活操作,防止過擬合。
    • 最后定義fc全連接層,其輸入維度為1280(經過池化后的特征維度),輸出維度為num_classes,用于最終的分類預測。
前向傳播 forward 方法
  • 首先將輸入x傳入stem_conv卷積層進行初始卷積,然后通過F.relu(self.bn1(self.stem_conv(x)))進行激活和批歸一化處理,得到初始的特征表示。
  • 將初始特征傳入self.mbconv_layers,即經過一系列堆疊的MBConv模塊進行特征提取和變換,充分挖掘圖像中的特征信息。
  • 接著對經過MBConv模塊處理后的特征執行F.relu(self.bn2(self.last_conv(out)))操作,進行最后的卷積以及激活、批歸一化處理。
  • 使用self.avgpool(out)進行自適應平均池化,將特征圖尺寸變為1x1,實現特征的壓縮和固定維度表示。
  • 通過out = out.view(out.size(0), -1)將池化后的特征張量展平為一維向量,方便全連接層處理,這里-1表示自動根據張量元素總數和已知的批量大小維度(out.size(0))來推斷展平后的維度大小。
  • 然后將展平后的特征傳入self.dropout(out)進行Dropout操作,隨機丟棄一部分神經元,防止過擬合。
  • 最后將特征傳入self.fc(out)全連接層進行分類預測,得到最終的輸出結果,輸出的維度與設定的num_classes一致,表示每個樣本屬于各個類別的預測概率(或得分等,具體取決于任務和后續處理),并返回該輸出結果。

訓練過程和測試結果

訓練過程損失函數變化曲線:

在這里插入圖片描述

訓練過程準確率變化曲線:
在這里插入圖片描述

測試結果:
在這里插入圖片描述

代碼匯總

項目github地址
項目結構:

|--data
|--models|--__init__.py|-efficientnet.py|--...
|--results
|--weights
|--train.py
|--test.py

efficientnet.py

import torch
import torch.nn as nn
import torch.nn.functional as Fclass MBConv(nn.Module):def __init__(self, in_channels, out_channels, expand_ratio, kernel_size, stride, padding, se_ratio=0.25):super(MBConv, self).__init__()self.stride = strideself.use_res_connect = (stride == 1 and in_channels == out_channels)# 擴展通道數(如果需要)expanded_channels = in_channels * expand_ratioself.expand_conv = nn.Conv2d(in_channels, expanded_channels, kernel_size=1, padding=0, bias=False)self.bn1 = nn.BatchNorm2d(expanded_channels)# 深度可分離卷積self.depthwise_conv = nn.Conv2d(expanded_channels, expanded_channels, kernel_size=kernel_size,stride=stride, padding=padding, groups=expanded_channels, bias=False)self.bn2 = nn.BatchNorm2d(expanded_channels)# 壓縮和激勵(SE)模塊(可選,根據se_ratio判斷是否添加)if se_ratio > 0:se_channels = int(in_channels * se_ratio)self.se = nn.Sequential(nn.AdaptiveAvgPool2d(1),nn.Conv2d(expanded_channels, se_channels, kernel_size=1),nn.ReLU(inplace=True),nn.Conv2d(se_channels, expanded_channels, kernel_size=1),nn.Sigmoid())else:self.se = None# 投影卷積,恢復到輸出通道數self.project_conv = nn.Conv2d(expanded_channels, out_channels, kernel_size=1, padding=0, bias=False)self.bn3 = nn.BatchNorm2d(out_channels)def forward(self, x):identity = x# 擴展通道數out = F.relu(self.bn1(self.expand_conv(x)))# 深度可分離卷積out = F.relu(self.bn2(self.depthwise_conv(out)))# SE模塊操作(如果存在)if self.se is not None:se_out = self.se(out)out = out * se_out# 投影卷積out = self.bn3(self.project_conv(out))# 殘差連接(如果滿足條件)if self.use_res_connect:out += identityreturn outclass EfficientNet(nn.Module):def __init__(self, num_classes, width_coefficient=1.0, depth_coefficient=1.0, dropout_rate=0.2):super(EfficientNet, self).__init__()self.stem_conv = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(32)mbconv_config = [# (in_channels, out_channels, expand_ratio, kernel_size, stride, padding)(32, 16, 1, 3, 1, 1),(16, 24, 6, 3, 2, 1),(24, 40, 6, 5, 2, 2),(40, 80, 6, 3, 2, 1),(80, 112, 6, 5, 1, 2),(112, 192, 6, 5, 2, 2),(192, 320, 6, 3, 1, 1)]# 根據深度系數調整每個MBConv模塊的重復次數,這里簡單地向下取整,你也可以根據實際情況采用更合理的方式repeat_counts = [max(1, int(depth_coefficient * 1)) for _ in mbconv_config]layers = []for i, config in enumerate(mbconv_config):in_channels, out_channels, expand_ratio, kernel_size, stride, padding = configfor _ in range(repeat_counts[i]):layers.append(MBConv(int(in_channels * width_coefficient),int(out_channels * width_coefficient),expand_ratio, kernel_size, stride, padding))self.mbconv_layers = nn.Sequential(*layers)self.last_conv = nn.Conv2d(int(320 * width_coefficient), 1280, kernel_size=1, bias=False)self.bn2 = nn.BatchNorm2d(1280)self.avgpool = nn.AdaptiveAvgPool2d(1)self.dropout = nn.Dropout(dropout_rate)self.fc = nn.Linear(1280, num_classes)def forward(self, x):out = F.relu(self.bn1(self.stem_conv(x)))out = self.mbconv_layers(out)out = F.relu(self.bn2(self.last_conv(out)))out = self.avgpool(out)out = out.view(out.size(0), -1)out = self.dropout(out)out = self.fc(out)return out

train.py

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from models import *
import matplotlib.pyplot as pltimport ssl
ssl._create_default_https_context = ssl._create_unverified_context# 定義數據預處理操作
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])# 加載CIFAR10訓練集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,shuffle=True, num_workers=2)# 定義設備(GPU優先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 實例化模型
model_name = 'EfficientNet'
if model_name == 'AlexNet':model = AlexNet(num_classes=10).to(device)
elif model_name == 'Vgg_A':model = Vgg(cfg_vgg='A', num_classes=10).to(device)
elif model_name == 'Vgg_A-LRN':model = Vgg(cfg_vgg='A-LRN', num_classes=10).to(device)
elif model_name == 'Vgg_B':model = Vgg(cfg_vgg='B', num_classes=10).to(device)
elif model_name == 'Vgg_C':model = Vgg(cfg_vgg='C', num_classes=10).to(device)
elif model_name == 'Vgg_D':model = Vgg(cfg_vgg='D', num_classes=10).to(device)
elif model_name == 'Vgg_E':model = Vgg(cfg_vgg='E', num_classes=10).to(device)
elif model_name == 'GoogleNet':model = GoogleNet(num_classes=10).to(device)
elif model_name == 'ResNet18':model = ResNet18(num_classes=10).to(device)
elif model_name == 'ResNet34':model = ResNet34(num_classes=10).to(device)
elif model_name == 'ResNet50':model = ResNet50(num_classes=10).to(device)
elif model_name == 'ResNet101':model = ResNet101(num_classes=10).to(device)
elif model_name == 'ResNet152':model = ResNet152(num_classes=10).to(device)
elif model_name == 'MobileNet':model = MobileNet(num_classes=10).to(device)
elif model_name == 'EfficientNet':model = EfficientNet(num_classes=10).to(device)criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 訓練輪次
epochs = 15def train(model, trainloader, criterion, optimizer, device):model.train()running_loss = 0.0correct = 0total = 0for i, data in enumerate(trainloader, 0):inputs, labels = data[0].to(device), data[1].to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()epoch_loss = running_loss / len(trainloader)epoch_acc = 100. * correct / totalreturn epoch_loss, epoch_accif __name__ == "__main__":loss_history, acc_history = [], []for epoch in range(epochs):train_loss, train_acc = train(model, trainloader, criterion, optimizer, device)print(f'Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')loss_history.append(train_loss)acc_history.append(train_acc)# 保存模型權重,每5輪次保存到weights文件夾下if (epoch + 1) % 5 == 0:torch.save(model.state_dict(), f'weights/{model_name}_epoch_{epoch + 1}.pth')# 繪制損失曲線plt.plot(range(1, epochs+1), loss_history, label='Loss', marker='o')plt.xlabel('Epoch')plt.ylabel('Loss')plt.title('Training Loss Curve')plt.legend()plt.savefig(f'results\\{model_name}_train_loss_curve.png')plt.close()# 繪制準確率曲線plt.plot(range(1, epochs+1), acc_history, label='Accuracy', marker='o')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.title('Training Accuracy Curve')plt.legend()plt.savefig(f'results\\{model_name}_train_acc_curve.png')plt.close()

test.py

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from models import *import ssl
ssl._create_default_https_context = ssl._create_unverified_context
# 定義數據預處理操作
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])# 加載CIFAR10測試集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,shuffle=False, num_workers=2)# 定義設備(GPU優先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 實例化模型
model_name = 'EfficientNet'
if model_name == 'AlexNet':model = AlexNet(num_classes=10).to(device)
elif model_name == 'Vgg_A':model = Vgg(cfg_vgg='A', num_classes=10).to(device)
elif model_name == 'Vgg_A-LRN':model = Vgg(cfg_vgg='A-LRN', num_classes=10).to(device)
elif model_name == 'Vgg_B':model = Vgg(cfg_vgg='B', num_classes=10).to(device)
elif model_name == 'Vgg_C':model = Vgg(cfg_vgg='C', num_classes=10).to(device)
elif model_name == 'Vgg_D':model = Vgg(cfg_vgg='D', num_classes=10).to(device)
elif model_name == 'Vgg_E':model = Vgg(cfg_vgg='E', num_classes=10).to(device)
elif model_name == 'GoogleNet':model = GoogleNet(num_classes=10).to(device)
elif model_name == 'ResNet18':model = ResNet18(num_classes=10).to(device)
elif model_name == 'ResNet34':model = ResNet34(num_classes=10).to(device)
elif model_name == 'ResNet50':model = ResNet50(num_classes=10).to(device)
elif model_name == 'ResNet101':model = ResNet101(num_classes=10).to(device)
elif model_name == 'ResNet152':model = ResNet152(num_classes=10).to(device)
elif model_name == 'MobileNet':model = MobileNet(num_classes=10).to(device)
elif model_name == 'EfficientNet':model = EfficientNet(num_classes=10).to(device)criterion = nn.CrossEntropyLoss()# 加載模型權重
weights_path = f"weights/{model_name}_epoch_15.pth"  
model.load_state_dict(torch.load(weights_path, map_location=device))def test(model, testloader, criterion, device):model.eval()running_loss = 0.0correct = 0total = 0with torch.no_grad():for data in testloader:inputs, labels = data[0].to(device), data[1].to(device)outputs = model(inputs)loss = criterion(outputs, labels)running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()epoch_loss = running_loss / len(testloader)epoch_acc = 100. * correct / totalreturn epoch_loss, epoch_accif __name__ == "__main__":test_loss, test_acc = test(model, testloader, criterion, device)print(f"================{model_name} Test================")print(f"Load Model Weights From: {weights_path}")print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/63926.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/63926.shtml
英文地址,請注明出處:http://en.pswp.cn/web/63926.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Vue 2 中實現雙擊事件的幾種方法

在 Vue 2 中處理用戶交互,特別是雙擊事件,是一個常見的需求。Vue 提供了一種簡潔的方式來綁定事件,包括雙擊事件。本文將介紹幾種在 Vue 2 中實現雙擊事件的方法。 1. 使用 dblclick 指令 Vue 允許你直接在模板中使用 dblclick 指令來監聽雙…

音視頻入門基礎:MPEG2-TS專題(20)——ES流簡介

《T-REC-H.222.0-202106-S!!PDF-E.pdf》第27頁對ES進行了定義。ES流是PES packets(PES包)中編碼的視頻、編碼的音頻或其他編碼的比特流。一個ES流(elementary stream)在具有且只有一個stream_id的PES packets序列中攜帶&#xff1…

天水月亮圈圈:舌尖上的歷史與傳承

在天水甘谷縣,有一種美食如同夜空中的明月,散發著獨特的魅力,它就是有著百年歷史的月亮圈圈。月亮圈圈原名甘谷酥圈圈,據傳,由大像山鎮蔣家莊一姓李的廚師創制而成,后經王明玖等廚師的光大傳承,…

YOLOv11融合[CVPR2023]FFTformer中的FSAS模塊

YOLOv11v10v8使用教程: YOLOv11入門到入土使用教程 YOLOv11改進匯總貼:YOLOv11及自研模型更新匯總 《Efficient Frequency Domain-based Transformers for High-Quality Image Deblurring》 一、 模塊介紹 論文鏈接:https://arxiv.org/abs…

java如何使用poi-tl在word模板里渲染多張圖片

1、poi-tl官網地址 http://deepoove.com/poi-tl/ 2、引入poi-tl的依賴 <dependency><groupId>com.deepoove</groupId><artifactId>poi-tl</artifactId><version>1.12.1</version></dependency>3、定義word模板 釋義&#xf…

《信管通低代碼信息管理系統開發平臺》Windows環境安裝說明

1 簡介 《信管通低代碼信息管理系統應用平臺》提供多環境軟件產品開發服務&#xff0c;包括單機、局域網和互聯網。我們專注于適用國產硬件和操作系統應用軟件開發應用。為事業單位和企業提供行業軟件定制開發&#xff0c;滿足其獨特需求。無論是簡單的應用還是復雜的系統&…

8K+Red+Raw+ProRes422分享5個影視級視頻素材網站

Hello&#xff0c;大家好&#xff0c;我是后期圈&#xff01; 在視頻創作中&#xff0c;電影級的視頻素材能夠為作品增添專業質感&#xff0c;讓畫面更具沖擊力。無論是廣告、電影短片&#xff0c;還是品牌宣傳&#xff0c;高質量的視頻素材都是不可或缺的資源。然而&#xff…

Git遠程倉庫的使用

一.遠程倉庫注冊 1.github&#xff1a;GitHub Build and ship software on a single, collaborative platform GitHub 2.gitee&#xff1a;GitHub Build and ship software on a single, collaborative platform GitHub github需要使用魔法&#xff0c;而gitee是國內的倉…

Echarts連接數據庫,實時繪制圖表詳解

文章目錄 Echarts連接數據庫&#xff0c;實時繪制圖表詳解一、引言二、步驟一&#xff1a;環境準備與數據庫連接1、環境搭建2、數據庫連接 三、步驟二&#xff1a;數據獲取與處理1、查詢數據庫2、數據處理 四、步驟三&#xff1a;ECharts圖表配置與渲染1、配置ECharts選項2、動…

MongoDB 常用操作指南(Docker 環境下)

本文詳細介紹如何在 Docker 中操作 MongoDB&#xff0c;包括如何進入命令行、進行用戶認證、查看數據庫和集合&#xff0c;以及常用的索引操作和其他高頻使用的 MongoDB 方法。小白也能輕松上手 1. 在 Docker 中進入 MongoDB 命令行 進入運行 MongoDB 容器的命令行&#xff1a;…

【Java基礎面試題038】棧和隊列在Java中的區別是什么?

回答重點 棧&#xff08;Stack&#xff09;&#xff1a;遵循后進先出&#xff08;LIFO&#xff0c;Last In&#xff0c;First Out&#xff09;原則。即&#xff0c;最后插入的元素最先被移除。主要操作包括push&#xff08;入棧&#xff09;和pop&#xff08;出棧&#xff09;…

idea2024創建JavaWeb項目以及配置Tomcat詳解

今天呢&#xff0c;博主的學習進度也是步入了JavaWeb&#xff0c;目前正在逐步楊帆旗航&#xff0c;迎接全新的狂潮海浪。 那么接下來就給大家出一期有關JavaWeb的配置教學&#xff0c;希望能對大家有所幫助&#xff0c;也特別歡迎大家指點不足之處&#xff0c;小生很樂意接受正…

由于這些關鍵原因,我總是手邊有一臺虛擬機

概括 虛擬機提供了一個安全的環境來測試有風險的設置或軟件,而不會影響您的主系統。設置和保存虛擬機非常簡單,無需更改主要設備即可方便地訪問多個操作系統。運行虛擬機可能會占用大量資源,但現代 PC 可以很好地處理它,為實驗和工作流程優化提供無限的可能性。如果您喜歡使…

【FPGA】ISE13.4操作手冊,新建工程示例

關注作者了解更多 我的其他CSDN專欄 求職面試 大學英語 過程控制系統 工程測試技術 虛擬儀器技術 可編程控制器 工業現場總線 數字圖像處理 智能控制 傳感器技術 嵌入式系統 復變函數與積分變換 單片機原理 線性代數 大學物理 熱工與工程流體力學 數字信號處…

python環境中阻止相關庫的自動更新

找到conda中的Python虛擬環境位置 這里以conda中的pytorch虛擬環境為例&#xff08;Python環境位置&#xff09;&#xff0c;在.conda下的envs中進入pytorch下的conda-meta路徑下 新建一個空白的pinned文檔 右鍵點擊桌面或文件資源管理器中的空白處&#xff0c;選擇“新建” …

重溫設計模式--外觀模式

文章目錄 外觀模式&#xff08;Facade Pattern&#xff09;概述定義 外觀模式UML圖作用 外觀模式的結構C 代碼示例1C代碼示例2總結 外觀模式&#xff08;Facade Pattern&#xff09;概述 定義 外觀模式是一種結構型設計模式&#xff0c;它為子系統中的一組接口提供了一個統一…

uniapp 微信小程序 頁面部分截圖實現

uniapp 微信小程序 頁面部分截圖實現 ? 原理都是將頁面元素畫成canvas 然后將canvas轉化為圖片&#xff0c;問題是我頁面里邊本來就有一個canvas&#xff0c;ucharts圖畫的canvas我無法畫出這塊。 ? 想了一晚上&#xff0c;既然canvas最后能轉化為圖片&#xff0c;那我直接…

Flutter 基礎知識總結

1、Flutter 介紹與環境安裝 為什么選擇 Dart&#xff1a; 基于 JIT 快速開發周期&#xff1a;Flutter 在開發階段采用 JIT 模式&#xff0c;避免每次改動都進行編譯&#xff0c;極大的節省了開發時間基于 AOT 發布包&#xff1a;Flutter 在發布時可以通過 AOT 生成高效的 ARM…

Jenkins 持續集成部署

Jenkins的安裝與部署 前言 當我們在實施一個項目時&#xff0c;從新代碼中獲得反饋的速度越快&#xff0c;問題越早得到解決&#xff0c;獲得反饋的一種常見方法是在新代碼之后運行測試&#xff0c;但這就導致了當代碼正在編譯并且正在運行測試時&#xff0c;開發人員無法在測…

跨站請求偽造之基本介紹

一.基本概念 1.定義 跨站請求偽造&#xff08;Cross - Site Request Forgery&#xff0c;縮寫為 CSRF&#xff09;漏洞是一種網絡安全漏洞。它是指攻擊者通過誘導用戶訪問一個惡意網站&#xff0c;利用用戶在被信任網站&#xff08;如銀行網站、社交網站等&#xff09;的登錄狀…