cifar10數據集測試有多少張圖_pytorch VGG11識別cifar10數據集(訓練+預測單張輸入圖片操作)...

首先這是VGG的結構圖,VGG11則是紅色框里的結構,共分五個block,如紅框中的VGG11第一個block就是一個conv3-64卷積層:

一,寫VGG代碼時,首先定義一個 vgg_block(n,in,out)方法,用來構建VGG中每個block中的卷積核和池化層:

n是這個block中卷積層的數目,in是輸入的通道數,out是輸出的通道數

有了block以后,我們還需要一個方法把形成的block疊在一起,我們定義這個方法叫vgg_stack:

def vgg_stack(num_convs, channels): # vgg_net = vgg_stack((1, 1, 2, 2, 2), ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512)))

net = []

for n, c in zip(num_convs, channels):

in_c = c[0]

out_c = c[1]

net.append(vgg_block(n, in_c, out_c))

return nn.Sequential(*net)

右邊的注釋

vgg_net = vgg_stack((1, 1, 2, 2, 2), ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512)))

里,(1, 1, 2, 2, 2)表示五個block里,各自的卷積層數目,((3, 64), (64, 128), (128, 256), (256, 512), (512, 512))表示每個block中的卷積層的類型,如(3,64)表示這個卷積層輸入通道數是3,輸出通道數是64。vgg_stack方法返回的就是完整的vgg11模型了。

接著定義一個vgg類,包含vgg_stack方法:

#vgg類

class vgg(nn.Module):

def __init__(self):

super(vgg, self).__init__()

self.feature = vgg_net

self.fc = nn.Sequential(

nn.Linear(512, 100),

nn.ReLU(True),

nn.Linear(100, 10)

)

def forward(self, x):

x = self.feature(x)

x = x.view(x.shape[0], -1)

x = self.fc(x)

return x

最后:

net = vgg() #就能獲取到vgg網絡

那么構建vgg網絡完整的pytorch代碼是:

def vgg_block(num_convs, in_channels, out_channels):

net = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(True)]

for i in range(num_convs - 1): # 定義后面的許多層

net.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))

net.append(nn.ReLU(True))

net.append(nn.MaxPool2d(2, 2)) # 定義池化層

return nn.Sequential(*net)

# 下面我們定義一個函數對這個 vgg block 進行堆疊

def vgg_stack(num_convs, channels): # vgg_net = vgg_stack((1, 1, 2, 2, 2), ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512)))

net = []

for n, c in zip(num_convs, channels):

in_c = c[0]

out_c = c[1]

net.append(vgg_block(n, in_c, out_c))

return nn.Sequential(*net)

#確定vgg的類型,是vgg11 還是vgg16還是vgg19

vgg_net = vgg_stack((1, 1, 2, 2, 2), ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512)))

#vgg類

class vgg(nn.Module):

def __init__(self):

super(vgg, self).__init__()

self.feature = vgg_net

self.fc = nn.Sequential(

nn.Linear(512, 100),

nn.ReLU(True),

nn.Linear(100, 10)

)

def forward(self, x):

x = self.feature(x)

x = x.view(x.shape[0], -1)

x = self.fc(x)

return x

#獲取vgg網絡

net = vgg()

基于VGG11的cifar10訓練代碼:

import sys

import numpy as np

import torch

from torch import nn

from torch.autograd import Variable

from torchvision.datasets import CIFAR10

import torchvision.transforms as transforms

def vgg_block(num_convs, in_channels, out_channels):

net = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(True)]

for i in range(num_convs - 1): # 定義后面的許多層

net.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))

net.append(nn.ReLU(True))

net.append(nn.MaxPool2d(2, 2)) # 定義池化層

return nn.Sequential(*net)

# 下面我們定義一個函數對這個 vgg block 進行堆疊

def vgg_stack(num_convs, channels): # vgg_net = vgg_stack((1, 1, 2, 2, 2), ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512)))

net = []

for n, c in zip(num_convs, channels):

in_c = c[0]

out_c = c[1]

net.append(vgg_block(n, in_c, out_c))

return nn.Sequential(*net)

#vgg類

class vgg(nn.Module):

def __init__(self):

super(vgg, self).__init__()

self.feature = vgg_net

self.fc = nn.Sequential(

nn.Linear(512, 100),

nn.ReLU(True),

nn.Linear(100, 10)

)

def forward(self, x):

x = self.feature(x)

x = x.view(x.shape[0], -1)

x = self.fc(x)

return x

# 然后我們可以訓練我們的模型看看在 cifar10 上的效果

def data_tf(x):

x = np.array(x, dtype='float32') / 255

x = (x - 0.5) / 0.5

x = x.transpose((2, 0, 1)) ## 將 channel 放到第一維,只是 pytorch 要求的輸入方式

x = torch.from_numpy(x)

return x

transform = transforms.Compose([transforms.ToTensor(),

transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),

])

def get_acc(output, label):

total = output.shape[0]

_, pred_label = output.max(1)

num_correct = (pred_label == label).sum().item()

return num_correct / total

def train(net, train_data, valid_data, num_epochs, optimizer, criterion):

if torch.cuda.is_available():

net = net.cuda()

for epoch in range(num_epochs):

train_loss = 0

train_acc = 0

net = net.train()

for im, label in train_data:

if torch.cuda.is_available():

im = Variable(im.cuda())

label = Variable(label.cuda())

else:

im = Variable(im)

label = Variable(label)

# forward

output = net(im)

loss = criterion(output, label)

# forward

optimizer.zero_grad()

loss.backward()

optimizer.step()

train_loss += loss.item()

train_acc += get_acc(output, label)

if valid_data is not None:

valid_loss = 0

valid_acc = 0

net = net.eval()

for im, label in valid_data:

if torch.cuda.is_available():

with torch.no_grad():

im = Variable(im.cuda())

label = Variable(label.cuda())

else:

with torch.no_grad():

im = Variable(im)

label = Variable(label)

output = net(im)

loss = criterion(output, label)

valid_loss += loss.item()

valid_acc += get_acc(output, label)

epoch_str = (

"Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "

% (epoch, train_loss / len(train_data),

train_acc / len(train_data), valid_loss / len(valid_data),

valid_acc / len(valid_data)))

else:

epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %

(epoch, train_loss / len(train_data),

train_acc / len(train_data)))

# prev_time = cur_time

print(epoch_str)

if __name__ == '__main__':

# 作為實例,我們定義一個稍微簡單一點的 vgg11 結構,其中有 8 個卷積層

vgg_net = vgg_stack((1, 1, 2, 2, 2), ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512)))

print(vgg_net)

train_set = CIFAR10('./data', train=True, transform=transform, download=True)

train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)

test_set = CIFAR10('./data', train=False, transform=transform, download=True)

test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)

net = vgg()

optimizer = torch.optim.SGD(net.parameters(), lr=1e-1)

criterion = nn.CrossEntropyLoss() #損失函數為交叉熵

train(net, train_data, test_data, 50, optimizer, criterion)

torch.save(net, 'vgg_model.pth')

結束后,會出現一個模型文件vgg_model.pth

二,然后網上找張圖片,把圖片縮成32x32,放到預測代碼中,即可有預測結果出現,預測代碼如下:

import torch

import cv2

import torch.nn.functional as F

from vgg2 import vgg ##重要,雖然顯示灰色(即在次代碼中沒用到),但若沒有引入這個模型代碼,加載模型時會找不到模型

from torch.autograd import Variable

from torchvision import datasets, transforms

import numpy as np

classes = ('plane', 'car', 'bird', 'cat',

'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

if __name__ == '__main__':

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = torch.load('vgg_model.pth') # 加載模型

model = model.to(device)

model.eval() # 把模型轉為test模式

img = cv2.imread("horse.jpg") # 讀取要預測的圖片

trans = transforms.Compose(

[

transforms.ToTensor(),

transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))

])

img = trans(img)

img = img.to(device)

img = img.unsqueeze(0) # 圖片擴展多一維,因為輸入到保存的模型中是4維的[batch_size,通道,長,寬],而普通圖片只有三維,[通道,長,寬]

# 擴展后,為[1,1,28,28]

output = model(img)

prob = F.softmax(output,dim=1) #prob是10個分類的概率

print(prob)

value, predicted = torch.max(output.data, 1)

print(predicted.item())

print(value)

pred_class = classes[predicted.item()]

print(pred_class)

# prob = F.softmax(output, dim=1)

# prob = Variable(prob)

# prob = prob.cpu().numpy() # 用GPU的數據訓練的模型保存的參數都是gpu形式的,要顯示則先要轉回cpu,再轉回numpy模式

# print(prob) # prob是10個分類的概率

# pred = np.argmax(prob) # 選出概率最大的一個

# # print(pred)

# # print(pred.item())

# pred_class = classes[pred]

# print(pred_class)

縮成32x32的圖片:

運行結果:

以上這篇pytorch VGG11識別cifar10數據集(訓練+預測單張輸入圖片操作)就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持腳本之家。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/538278.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/538278.shtml
英文地址,請注明出處:http://en.pswp.cn/news/538278.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

npm ERR! Please try running this command again as root/Administrator.

win10操作系統下 webstrom的控制臺使用 npm install angular-file-upload 安裝組件,報錯:npm ERR! Please try running this command again as root/Administrator. 解決方法: 開始按鈕右鍵---- windows powershell(管理員&…

map flatmap mappartition flatMapToPair四種用法區別

原文鏈接:http://blog.csdn.net/u013086392/article/details/55666912 ----------------------------------------------------------------------------------- map: 我們可以看到數據的每一行在map之后產生了一個數組,那么rdd存儲的是一個數組的集合…

eve可以在linux運行嗎,ubuntu下為eve游戲搭載 wine環境

援引該地址的參考,本文僅做整理:http://bbs.eve-china.com/thread-626756-1-1.htmllinux的顯卡是否驅動成功,依次鍵入如下命令察看:sudo apt-get install mesa-utils /*安裝 mesa-utils 的指令*/glxinfo | grep r…

自動飛行控制系統_波音公司將重設計737MAX自動飛行控制系統!力求十月前復飛...

據西雅圖時報8月1日報道,美國聯邦航空管理局(FAA)在6月份對波音737 MAX飛行控制系統進行新的嚴格測試時,發現了一個潛在的缺陷,該缺陷促使波音公司對其基本的軟件設計進行變革。波音公司如今正在改變737 MAX的自動飛行控制系統軟件&#xff0…

每日一題——LeetCode141.環形鏈表

個人主頁:白日依山璟 專欄:Java|數據結構與算法|每日一題 文章目錄 1. 題目描述示例1:示例2:示例3:提示: 2. 思路3. 代碼 1. 題目描述 給你一個鏈表的頭節點 head ,判斷鏈表中是否有環。 如果鏈表中有某…

Android O 獲取APK文件權限 Demo案例

1. 通過 aapt 工具查看 APK權限 C:\Users\zh>adb pull /system/priv-app/Settings . /system/priv-app/Settings/: 3 files pulled. 10.8 MB/s (48840608 bytes in 4.325s)C:\Users\zh>aapt d permissions C:\Users\zh\Settings\Settings.apk package: com.android.sett…

VBoxManage命令更詳盡版

原文鏈接:http://418684644-qq-com.iteye.com/blog/1451000 ------------------------------------- VBoxManage命令詳解(一) 本人對vboxmange命令按我個人的理解作了解釋,由于本人水平有限難免有錯誤的地方,希望大…

linux make命令實現,Linux make命令主要參數詳解

-C dir或者 --directoryDIR在讀取makefile文件前,先切換到“dir”目錄下,即把dir作為當前目錄。如果存在多個-C選項,make的最終當前目錄是第一個目錄的相對路徑,如“make –C /home/leowang –C document”,等價于“ma…

行人屬性數據集pa100k_基于InceptionV3的多數據集聯合訓練的行人外觀屬性識別方法與流程...

本發明涉及模式識別技術、智能監控技術等領域,具體的說,是基于Inception V3的多數據集聯合訓練的行人外觀屬性識別方法。背景技術:近年來,視頻監控系統已經被廣泛應用于安防領域。安防人員通過合理的攝像頭布局,實現對…

VBoxManage獲取虛擬機IP地址

在宿主機Linux上安裝VirtualBox,然后VirtualBox上安裝linux虛擬機,在Virtualbox非界面啟動虛擬機時,ip地址無法查看。怎么辦? 使用命令: VBoxManage guestproperty enumerate 虛擬機名 | grep "Net.*V4.*IP"…

springboot系列(十)springboot整合shiro實現登錄認證

關于shiro的概念和知識本篇不做詳細介紹,但是shiro的概念還是需要做做功課的要不無法理解它的運作原理就無法理解使用shiro; 本篇主要講解如何使用shiro實現登錄認證,下篇講解使用shiro實現權限控制 要實現shiro和springboot的整合需要以下幾…

recyclerview item動畫_這可能是你見過的迄今為止最簡單的RecyclerView Item加載動畫...

如何實現RecyclerView Item動畫? 這個問題想必有很多人都會講,我可以用ItemAnimator實現啊,這是RecyclerView官方定義的接口,專門擴展Item動畫的,那我為什么要尋求另外一種方法實現呢?因為最近反思了一個問…

群暉編譯LCD4Linux,LCD4LINUX配置文件一些參數使用解釋。

#LCD顯示配置Display dpf {Driver DPF #LCD驅動類型Port usb0 #連接端口Font 6x8 #字體大小Foreground ffffff #字體…

VBoxManage: error: Nonexistent host networking interface, name 'vboxnet0' (VERR_INTERNAL_ERROR)

錯誤: VBoxManage: error: Nonexistent host networking interface, name vboxnet0 (VERR_INTERNAL_ERROR) 原因: 原來配置的網卡發生了變更,找不到了,啟動失敗。 解決方法: 第一步,命令: V…

捷信達溫泉管理軟件員工卡SQL查詢

捷信達溫泉管理軟件員工卡SQL查詢 select * from snkey where v_name2 like %員工% 網名:浩秦; 郵箱:root#landv.pw; 只要我能控制一個國家的貨幣發行,我不在乎誰制定法律。金錢一旦作響,壞話隨之戛然而止。

Linux 軟件安裝到 /usr,/usr/local/ 還是 /opt 目錄?

Linux 的軟件安裝目錄是也是有講究的,理解這一點,在對系統管理是有益的 /usr:系統級的目錄,可以理解為C:/Windows/,/usr/lib理解為C:/Windows/System32。 /usr/local:用戶級的程序目錄,可以理解…

winpe裝雙系統linux_使用syslinux在u盤安裝pubbylinux和winpe雙系統

使用syslinux在u盤安裝pubbylinux和winpe雙系統1,在u盤里安裝winpe,請參見"比較簡單的制作U盤winpe啟動盤方法"比較簡單的制作U盤winpe啟動盤方法 收藏1,下載一個深度winpev3.iso2,用winrar或ultraISO解壓深度winpev3.iso3,進入解壓出來的文件夾下,找到se…

esp32 嵌入式linux,初體驗樂鑫 ESP32 AT 指令-嵌入式系統-與非網

樂鑫 AT 固件初體驗初步體驗 AT 指令下 TCP 數傳,為了驗證 AT 命令解析器。前往樂鑫官網 下載最新版本 AT 固件和 AT 指令集手冊。硬件準備本文使用樂鑫的 ESP-WROOM-32(ESP-WROOM-32 是 ESP32-WROOM-32 的曾用名)模塊,4MB Flash,無 PSRAM。E…

主機ping不通Virtualbox里的虛擬機

在redhat上安裝了VirtualBox,虛擬了三臺Linux機器。 宿主機網卡更換過了。三臺虛擬機無法啟動了,搭建虛擬機的運維離職了。 VirtualBox的圖形界面壞了,啟動不了。只能用命令行,今天時間就花在命令行上了。 第一個問題是&#xf…

python后端開發靠譜嗎_【后端開發】python有這么強大嗎

因為Python是一種代表簡單主義思想的語言。除此之外,Python所擁有的標準庫更是金融、營銷類人群選擇它的理由。Python 易于學習可靠且高效(推薦學習:Python視頻教程)好吧,相較于其它許多你可以拿來用的編程語言而言,它“更容易一些…