Pytorch | 從零構建MobileNet對CIFAR10進行分類
- CIFAR10數據集
- MobileNet
- 設計理念
- 網絡結構
- 技術優勢
- 應用領域
- MobileNet結構代碼詳解
- 結構代碼
- 代碼詳解
- DepthwiseSeparableConv 類
- 初始化方法
- 前向傳播 forward 方法
- MobileNet 類
- 初始化方法
- 前向傳播 forward 方法
- 訓練過程和測試結果
- 代碼匯總
- mobilenet.py
- train.py
- test.py
前面文章我們構建了AlexNet、Vgg、GoogleNet對CIFAR10進行分類:
Pytorch | 從零構建AlexNet對CIFAR10進行分類
Pytorch | 從零構建Vgg對CIFAR10進行分類
Pytorch | 從零構建GoogleNet對CIFAR10進行分類
Pytorch | 從零構建ResNet對CIFAR10進行分類
這篇文章我們來構建MobileNet.
CIFAR10數據集
CIFAR-10數據集是由加拿大高級研究所(CIFAR)收集整理的用于圖像識別研究的常用數據集,基本信息如下:
- 數據規模:該數據集包含60,000張彩色圖像,分為10個不同的類別,每個類別有6,000張圖像。通常將其中50,000張作為訓練集,用于模型的訓練;10,000張作為測試集,用于評估模型的性能。
- 圖像尺寸:所有圖像的尺寸均為32×32像素,這相對較小的尺寸使得模型在處理該數據集時能夠相對快速地進行訓練和推理,但也增加了圖像分類的難度。
- 類別內容:涵蓋了飛機(plane)、汽車(car)、鳥(bird)、貓(cat)、鹿(deer)、狗(dog)、青蛙(frog)、馬(horse)、船(ship)、卡車(truck)這10個不同的類別,這些類別都是現實世界中常見的物體,具有一定的代表性。
下面是一些示例樣本:
MobileNet
MobileNet是由谷歌在2017年提出的一種輕量級卷積神經網絡,主要用于移動端和嵌入式設備等資源受限的環境中進行圖像識別和分類任務,以下是對其的詳細介紹:
設計理念
- 深度可分離卷積:其核心創新是采用了深度可分離卷積(Depthwise Separable Convolution)來替代傳統的卷積操作。深度可分離卷積將標準卷積分解為一個深度卷積(Depthwise Convolution)和一個逐點卷積(Pointwise Convolution),大大減少了計算量和模型參數,同時保持了較好的性能。
網絡結構
- 標準卷積層:輸入層為3通道的彩色圖像,首先經過一個普通的卷積層
conv1
,將通道數從3變為32,同時進行了步長為2的下采樣操作,以減小圖像尺寸。 - 深度可分離卷積層:包含了一系列的深度可分離卷積層
dsconv1
至dsconv13
,這些層按照一定的規律進行排列,通道數逐漸增加,同時通過不同的步長進行下采樣,以提取不同層次的特征。 - 池化層和全連接層:在深度可分離卷積層之后,通過一個自適應平均池化層
avgpool
將特征圖轉換為1x1的大小,然后通過一個全連接層fc
將特征映射到指定的類別數,完成分類任務。
技術優勢
- 模型輕量化:通過深度可分離卷積的使用,大大減少了模型的參數量和計算量,使得模型更加輕量化,適合在移動設備和嵌入式設備上運行。
- 計算效率高:由于減少了計算量,MobileNet在推理時具有較高的計算效率,可以快速地對圖像進行分類和識別,滿足實時性要求較高的應用場景。
- 性能表現較好:盡管模型輕量化,但MobileNet在圖像識別任務上仍然具有較好的性能表現,能夠在保持較高準確率的同時,大大降低模型的復雜度。
應用領域
- 移動端視覺任務:廣泛應用于各種移動端設備,如智能手機、平板電腦等,用于圖像分類、目標檢測、人臉識別等視覺任務。
- 嵌入式設備視覺:在嵌入式設備,如智能攝像頭、自動駕駛汽車等領域,MobileNet可以為這些設備提供高效的視覺處理能力,實現實時的圖像分析和決策。
- 物聯網視覺應用:在物聯網設備中,MobileNet可以幫助實現對圖像數據的快速處理和分析,為智能家居、智能安防等應用提供支持。
MobileNet結構代碼詳解
結構代碼
import torch
import torch.nn as nnclass DepthwiseSeparableConv(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False):super(DepthwiseSeparableConv, self).__init__()self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, groups=in_channels, bias=bias)self.bn1 = nn.BatchNorm2d(in_channels)self.relu1 = nn.ReLU6(inplace=True)self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)self.bn2 = nn.BatchNorm2d(out_channels)self.relu2 = nn.ReLU6(inplace=True)def forward(self, x):out = self.depthwise(x)out = self.bn1(out)out = self.relu1(out)out = self.pointwise(out)out = self.bn2(out)out = self.relu2(out)return outclass MobileNet(nn.Module):def __init__(self, num_classes):super(MobileNet, self).__init__()self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(32)self.relu = nn.ReLU6(inplace=True)self.dsconv1 = DepthwiseSeparableConv(32, 64, stride=1)self.dsconv2 = DepthwiseSeparableConv(64, 128, stride=2)self.dsconv3 = DepthwiseSeparableConv(128, 128, stride=1)self.dsconv4 = DepthwiseSeparableConv(128, 256, stride=2)self.dsconv5 = DepthwiseSeparableConv(256, 256, stride=1)self.dsconv6 = DepthwiseSeparableConv(256, 512, stride=2)self.dsconv7 = DepthwiseSeparableConv(512, 512, stride=1)self.dsconv8 = DepthwiseSeparableConv(512, 512, stride=1)self.dsconv9 = DepthwiseSeparableConv(512, 512, stride=1)self.dsconv10 = DepthwiseSeparableConv(512, 512, stride=1)self.dsconv11 = DepthwiseSeparableConv(512, 512, stride=1)self.dsconv12 = DepthwiseSeparableConv(512, 1024, stride=2)self.dsconv13 = DepthwiseSeparableConv(1024, 1024, stride=1)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(1024, num_classes)def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.dsconv1(out)out = self.dsconv2(out)out = self.dsconv3(out)out = self.dsconv4(out)out = self.dsconv5(out)out = self.dsconv6(out)out = self.dsconv7(out)out = self.dsconv8(out)out = self.dsconv9(out)out = self.dsconv10(out)out = self.dsconv11(out)out = self.dsconv12(out)out = self.dsconv13(out)out = self.avgpool(out)out = out.view(out.size(0), -1)out = self.fc(out)return out
代碼詳解
以下是對上述代碼的詳細解釋:
DepthwiseSeparableConv 類
這是一個自定義的深度可分離卷積層類,繼承自 nn.Module
。
初始化方法
- 參數說明:
in_channels
:輸入通道數,指定輸入數據的通道數量。out_channels
:輸出通道數,即卷積操作后輸出特征圖的通道數量。kernel_size
:卷積核大小,默認為3,用于定義卷積操作中卷積核的尺寸。stride
:步長,默認為1,控制卷積核在輸入特征圖上滑動的步長。padding
:填充大小,默認為1,在輸入特征圖周圍添加的填充像素數量,以保持特征圖尺寸在卷積過程中合適變化。bias
:是否使用偏置,默認為False
,決定卷積層是否添加偏置項。
- 構建的層及作用:
self.depthwise
:這是一個深度卷積層(nn.Conv2d
),通過設置groups=in_channels
,實現了深度可分離卷積中的深度卷積部分,它對每個輸入通道分別進行卷積操作,有效地減少了計算量。self.bn1
:批歸一化層(nn.BatchNorm2d
),用于對深度卷積后的輸出進行歸一化處理,加速模型收斂并提升模型的泛化能力。self.relu1
:激活函數層(nn.ReLU6
),采用ReLU6
激活函數(輸出值限定在0到6之間),并且設置inplace=True
,意味著直接在輸入的張量上進行修改,節省內存空間,增加非線性特性。self.pointwise
:逐點卷積層(nn.Conv2d
),卷積核大小為1,用于將深度卷積后的特征圖在通道維度上進行融合,改變通道數到指定的out_channels
。self.bn2
:又是一個批歸一化層,對逐點卷積后的輸出進行歸一化處理。self.relu2
:同樣是ReLU6
激活函數層,進一步增加非線性,處理逐點卷積歸一化后的結果。
前向傳播 forward 方法
定義了數據在該層的前向傳播過程:
- 首先將輸入
x
通過深度卷積層self.depthwise
進行深度卷積操作,得到輸出特征圖。 - 然后將深度卷積的輸出依次經過批歸一化層
self.bn1
和激活函數層self.relu1
。 - 接著把經過處理后的特征圖通過逐點卷積層
self.pointwise
進行逐點卷積,改變通道數等特征。 - 最后再經過批歸一化層
self.bn2
和激活函數層self.relu2
,并返回最終的輸出結果。
MobileNet 類
這是定義的 MobileNet 網絡模型類,同樣繼承自 nn.Module
。
初始化方法
- 參數說明:
num_classes
:分類的類別數量,用于最后全連接層輸出對應類別數的預測結果。
- 構建的層及作用:
self.conv1
:普通的二維卷積層(nn.Conv2d
),輸入通道數為3(通常對應RGB圖像的三個通道),輸出通道數為32,卷積核大小為3,步長為2,用于對輸入圖像進行初步的特征提取和下采樣,減少特征圖尺寸同時增加通道數。self.bn1
:批歸一化層,對conv1
卷積后的輸出進行歸一化。self.relu
:激活函數層,采用ReLU6
激活函數給特征圖增加非線性。- 一系列的
self.dsconv
層(從dsconv1
到dsconv13
):都是前面定義的深度可分離卷積層DepthwiseSeparableConv
的實例,它們逐步對特征圖進行更精細的特征提取、通道變換以及下采樣等操作,不同的dsconv
層有著不同的輸入輸出通道數以及步長設置,以此構建出 MobileNet 網絡的主體結構,不斷提取和融合特征,逐步降低特征圖尺寸并增加通道數來獲取更高級、更抽象的特征表示。 self.avgpool
:自適應平均池化層(nn.AdaptiveAvgPool2d
),將輸入特征圖轉換為指定大小(1, 1)
的輸出,起到全局平均池化的作用,進一步壓縮特征圖信息,同時保持特征圖的維度一致性,方便后續全連接層處理。self.fc
:全連接層(nn.Linear
),輸入維度為1024(與前面網絡結構最終輸出的特征維度對應),輸出維度為num_classes
,用于將經過前面卷積和池化等操作得到的特征向量映射到對應類別數量的預測分數上,實現分類任務。
前向傳播 forward 方法
定義了 MobileNet 模型整體的前向傳播流程:
- 首先將輸入
x
通過conv1
進行初始卷積、bn1
進行歸一化以及relu
激活。 - 然后依次通過各個深度可分離卷積層(
dsconv1
到dsconv13
),逐步提取和變換特征。 - 接著經過自適應平均池化層
self.avgpool
,將特征圖壓縮為(1, 1)
大小。 - 再通過
out.view(out.size(0), -1)
操作將特征圖展平為一維向量(其中out.size(0)
表示批量大小,-1
表示自動計算剩余維度大小使其展平)。 - 最后將展平后的特征向量通過全連接層
self.fc
得到最終的分類預測結果并返回。
訓練過程和測試結果
訓練過程損失函數變化曲線:
訓練過程準確率變化曲線:
測試結果:
代碼匯總
項目github地址
項目結構:
|--data
|--models|--__init__.py|-mobilenet.py|--...
|--results
|--weights
|--train.py
|--test.py
mobilenet.py
import torch
import torch.nn as nnclass DepthwiseSeparableConv(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False):super(DepthwiseSeparableConv, self).__init__()self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, groups=in_channels, bias=bias)self.bn1 = nn.BatchNorm2d(in_channels)self.relu1 = nn.ReLU6(inplace=True)self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)self.bn2 = nn.BatchNorm2d(out_channels)self.relu2 = nn.ReLU6(inplace=True)def forward(self, x):out = self.depthwise(x)out = self.bn1(out)out = self.relu1(out)out = self.pointwise(out)out = self.bn2(out)out = self.relu2(out)return outclass MobileNet(nn.Module):def __init__(self, num_classes):super(MobileNet, self).__init__()self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(32)self.relu = nn.ReLU6(inplace=True)self.dsconv1 = DepthwiseSeparableConv(32, 64, stride=1)self.dsconv2 = DepthwiseSeparableConv(64, 128, stride=2)self.dsconv3 = DepthwiseSeparableConv(128, 128, stride=1)self.dsconv4 = DepthwiseSeparableConv(128, 256, stride=2)self.dsconv5 = DepthwiseSeparableConv(256, 256, stride=1)self.dsconv6 = DepthwiseSeparableConv(256, 512, stride=2)self.dsconv7 = DepthwiseSeparableConv(512, 512, stride=1)self.dsconv8 = DepthwiseSeparableConv(512, 512, stride=1)self.dsconv9 = DepthwiseSeparableConv(512, 512, stride=1)self.dsconv10 = DepthwiseSeparableConv(512, 512, stride=1)self.dsconv11 = DepthwiseSeparableConv(512, 512, stride=1)self.dsconv12 = DepthwiseSeparableConv(512, 1024, stride=2)self.dsconv13 = DepthwiseSeparableConv(1024, 1024, stride=1)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(1024, num_classes)def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.dsconv1(out)out = self.dsconv2(out)out = self.dsconv3(out)out = self.dsconv4(out)out = self.dsconv5(out)out = self.dsconv6(out)out = self.dsconv7(out)out = self.dsconv8(out)out = self.dsconv9(out)out = self.dsconv10(out)out = self.dsconv11(out)out = self.dsconv12(out)out = self.dsconv13(out)out = self.avgpool(out)out = out.view(out.size(0), -1)out = self.fc(out)return out
train.py
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from models import *
import matplotlib.pyplot as pltimport ssl
ssl._create_default_https_context = ssl._create_unverified_context# 定義數據預處理操作
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])# 加載CIFAR10訓練集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,shuffle=True, num_workers=2)# 定義設備(GPU優先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 實例化模型
model_name = 'MobileNet'
if model_name == 'AlexNet':model = AlexNet(num_classes=10).to(device)
elif model_name == 'Vgg_A':model = Vgg(cfg_vgg='A', num_classes=10).to(device)
elif model_name == 'Vgg_A-LRN':model = Vgg(cfg_vgg='A-LRN', num_classes=10).to(device)
elif model_name == 'Vgg_B':model = Vgg(cfg_vgg='B', num_classes=10).to(device)
elif model_name == 'Vgg_C':model = Vgg(cfg_vgg='C', num_classes=10).to(device)
elif model_name == 'Vgg_D':model = Vgg(cfg_vgg='D', num_classes=10).to(device)
elif model_name == 'Vgg_E':model = Vgg(cfg_vgg='E', num_classes=10).to(device)
elif model_name == 'GoogleNet':model = GoogleNet(num_classes=10).to(device)
elif model_name == 'ResNet18':model = ResNet18(num_classes=10).to(device)
elif model_name == 'ResNet34':model = ResNet34(num_classes=10).to(device)
elif model_name == 'ResNet50':model = ResNet50(num_classes=10).to(device)
elif model_name == 'ResNet101':model = ResNet101(num_classes=10).to(device)
elif model_name == 'ResNet152':model = ResNet152(num_classes=10).to(device)
elif model_name == 'MobileNet':model = MobileNet(num_classes=10).to(device)criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 訓練輪次
epochs = 15def train(model, trainloader, criterion, optimizer, device):model.train()running_loss = 0.0correct = 0total = 0for i, data in enumerate(trainloader, 0):inputs, labels = data[0].to(device), data[1].to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()epoch_loss = running_loss / len(trainloader)epoch_acc = 100. * correct / totalreturn epoch_loss, epoch_accif __name__ == "__main__":loss_history, acc_history = [], []for epoch in range(epochs):train_loss, train_acc = train(model, trainloader, criterion, optimizer, device)print(f'Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')loss_history.append(train_loss)acc_history.append(train_acc)# 保存模型權重,每5輪次保存到weights文件夾下if (epoch + 1) % 5 == 0:torch.save(model.state_dict(), f'weights/{model_name}_epoch_{epoch + 1}.pth')# 繪制損失曲線plt.plot(range(1, epochs+1), loss_history, label='Loss', marker='o')plt.xlabel('Epoch')plt.ylabel('Loss')plt.title('Training Loss Curve')plt.legend()plt.savefig(f'results\\{model_name}_train_loss_curve.png')plt.close()# 繪制準確率曲線plt.plot(range(1, epochs+1), acc_history, label='Accuracy', marker='o')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.title('Training Accuracy Curve')plt.legend()plt.savefig(f'results\\{model_name}_train_acc_curve.png')plt.close()
test.py
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from models import *import ssl
ssl._create_default_https_context = ssl._create_unverified_context
# 定義數據預處理操作
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])# 加載CIFAR10測試集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,shuffle=False, num_workers=2)# 定義設備(GPU優先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 實例化模型
model_name = 'MobileNet'
if model_name == 'AlexNet':model = AlexNet(num_classes=10).to(device)
elif model_name == 'Vgg_A':model = Vgg(cfg_vgg='A', num_classes=10).to(device)
elif model_name == 'Vgg_A-LRN':model = Vgg(cfg_vgg='A-LRN', num_classes=10).to(device)
elif model_name == 'Vgg_B':model = Vgg(cfg_vgg='B', num_classes=10).to(device)
elif model_name == 'Vgg_C':model = Vgg(cfg_vgg='C', num_classes=10).to(device)
elif model_name == 'Vgg_D':model = Vgg(cfg_vgg='D', num_classes=10).to(device)
elif model_name == 'Vgg_E':model = Vgg(cfg_vgg='E', num_classes=10).to(device)
elif model_name == 'GoogleNet':model = GoogleNet(num_classes=10).to(device)
elif model_name == 'ResNet18':model = ResNet18(num_classes=10).to(device)
elif model_name == 'ResNet34':model = ResNet34(num_classes=10).to(device)
elif model_name == 'ResNet50':model = ResNet50(num_classes=10).to(device)
elif model_name == 'ResNet101':model = ResNet101(num_classes=10).to(device)
elif model_name == 'ResNet152':model = ResNet152(num_classes=10).to(device)
elif model_name == 'MobileNet':model = MobileNet(num_classes=10).to(device)criterion = nn.CrossEntropyLoss()# 加載模型權重
weights_path = f"weights/{model_name}_epoch_15.pth"
model.load_state_dict(torch.load(weights_path, map_location=device))def test(model, testloader, criterion, device):model.eval()running_loss = 0.0correct = 0total = 0with torch.no_grad():for data in testloader:inputs, labels = data[0].to(device), data[1].to(device)outputs = model(inputs)loss = criterion(outputs, labels)running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()epoch_loss = running_loss / len(testloader)epoch_acc = 100. * correct / totalreturn epoch_loss, epoch_accif __name__ == "__main__":test_loss, test_acc = test(model, testloader, criterion, device)print(f"================{model_name} Test================")print(f"Load Model Weights From: {weights_path}")print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')