pytorch學習 訓練一個分類器(五)

訓練一個分類器

就是這個, 你已經看到了如何定義神經網絡, 計算損失并更新網絡的權重.

現在你可能會想,

數據呢?

一般來說, 當你不得不處理圖像, 文本, 音頻或者視頻數據時, 你可以使用標準的 Python 包將數據加載到一個 numpy 數組中. 然后你可以將這個數組轉換成一個?torch.*Tensor.

  • 對于圖像, 會用到的包有 Pillow, OpenCV .
  • 對于音頻, 會用的包有 scipy 和 librosa.
  • 對于文本, 原始 Python 或基于 Cython 的加載, 或者 NLTK 和 Spacy 都是有用的.

特別是對于?vision, 我們已經創建了一個叫做?torchvision, 其中有對普通數據集如 Imagenet, CIFAR10, MNIST 等和用于圖像數據的轉換器, 即?torchvision.datasets?和?torch.utils.data.DataLoader.

這提供了巨大的便利, 避免了編寫重復代碼.

在本教程中, 我們將使用 CIFAR10 數據集. 它有: ‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’,‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’ 這些類別. CIFAR10 中的圖像大小為 3x32x32 , 即 32x32 像素的 3 通道彩色圖像.

cifar10

cifar10

訓練一個圖像分類器

我們將按順序執行以下步驟:

  1. 加載 CIFAR10 測試和訓練數據集并規范化?torchvision
  2. 定義一個卷積神經網絡
  3. 定義一個損失函數
  4. 在訓練數據上訓練網絡
  5. 在測試數據上測試網絡

1. 加載并規范化 CIFAR10

使用?torchvision, 加載 CIFAR10 非常簡單.

import torch
import torchvision
import torchvision.transforms as transforms

torchvision 數據集的輸出是范圍 [0, 1] 的 PILImage 圖像. 我們將它們轉換為歸一化范圍是[-1,1]的張量

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,shuffle=False, num_workers=2)classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')

讓我們展示一些訓練圖像, 只是為了好玩 (0.0).

import matplotlib.pyplot as plt
import numpy as np# 定義函數來顯示圖像def imshow(img):img = img / 2 + 0.5     # 非標準化npimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))# 得到一些隨機的訓練圖像
dataiter = iter(trainloader)
images, labels = dataiter.next()# 顯示圖像
imshow(torchvision.utils.make_grid(images))
# 輸出類別
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

2. 定義一個卷積神經網絡

從神經網絡部分復制神經網絡, 并修改它以獲取 3 通道圖像(而不是定義的 1 通道圖像).

from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xnet = Net()

3. 定義一個損失函數和優化器

我們使用交叉熵損失函數( CrossEntropyLoss )和隨機梯度下降( SGD )優化器.

import torch.optim as optimcriterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

4. 訓練網絡

這是事情開始變得有趣的時候. 我們只需循環遍歷數據迭代器, 并將輸入提供給網絡和優化器.

for epoch in range(2):  # 循環遍歷數據集多次running_loss = 0.0for i, data in enumerate(trainloader, 0):# 得到輸入數據inputs, labels = data# 包裝數據inputs, labels = Variable(inputs), Variable(labels)# 梯度清零optimizer.zero_grad()# forward + backward + optimizeoutputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 打印信息running_loss += loss.data[0]if i % 2000 == 1999:    # 每2000個小批量打印一次print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 2000))running_loss = 0.0print('Finished Training')

5. 在測試數據上測試網絡

我們在訓練數據集上訓練了2遍網絡, 但是我們需要檢查網絡是否學到了什么.

我們將通過預測神經網絡輸出的類標簽來檢查這個問題, 并根據實際情況進行檢查. 如果預測是正確的, 我們將樣本添加到正確預測的列表中.

好的, 第一步. 讓我們顯示測試集中的圖像以便熟悉.

dataiter = iter(testloader)
images, labels = dataiter.next()# 打印圖像
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))

好的, 現在讓我們看看神經網絡認為這些例子是什么:

outputs = net(Variable(images))

輸出的是10個類別的能量. 一個類別的能量越高, 則可以理解為網絡認為越多的圖像是該類別的. 那么, 讓我們得到最高能量的索引:

_, predicted = torch.max(outputs.data, 1)print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]for j in range(4)))

結果看起來不錯.

讓我們看看網絡如何在整個數據集上執行.

correct = 0
total = 0
for data in testloader:images, labels = dataoutputs = net(Variable(images))_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum()print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

訓練的準確率遠比隨機猜測(準確率10%)好, 證明網絡確實學到了東西.

嗯, 我們來看看哪些類別表現良好, 哪些類別表現不佳:

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
for data in testloader:images, labels = dataoutputs = net(Variable(images))_, predicted = torch.max(outputs.data, 1)c = (predicted == labels).squeeze()for i in range(4):label = labels[i]class_correct[label] += c[i]class_total[label] += 1for i in range(10):print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))

好的, 接下來呢?

我們如何在 GPU 上運行這些神經網絡?

在 GPU 上訓練

就像你如何將一個張量傳遞給GPU一樣, 你將神經網絡轉移到GPU上. 這將遞歸遍歷所有模塊, 并將其參數和緩沖區轉換為CUDA張量:

net.cuda()

請記住, 您必須將輸入和目標每一步都發送到GPU:

inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())

如果發現在 GPU 上并沒有比 CPU 提速很多, 實際上是因為網絡比較小, GPU 沒有完全發揮自己的真正實力.

練習:?嘗試增加網絡的寬度(第一個?nn.Conv2d?的參數2和第二個?nn.Conv2d?的參數1 它們需要是相同的數字), 看看你得到什么樣的加速.

目標達成:

  • 深入了解PyTorch的張量庫和神經網絡.
  • 訓練一個小的神經網絡來分類圖像.

在多個GPU上進行訓練

如果你希望使用所有 GPU 來看更多的 MASSIVE 加速, 請查看可選?可選: 數據并行.

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/444677.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/444677.shtml
英文地址,請注明出處:http://en.pswp.cn/news/444677.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Git(6)-Git配置文件、底層操作命令

Git基本命令1. 常用(迷糊)命令-冷知識2. git 配置2.1 設置 配置文件2.2 查看 配置文件--git config -l2.3 移除 配置文件設置--unset2.3 命令別名 --alias3.git 對象 (git底層操作命令)3.1 初始化一個版本庫3.2 新建一個簡單的blob 對象3.3 基于散列值查…

【軟考中級】網絡工程師:8.網絡安全

本章考察內容比較廣泛,考題對知識點都會有所涉及。 8.1 網絡安全的基本概念 8.1.1 網絡安全威脅的類型 竊聽 這種情況發生在廣播式網絡系統中,每個節點都可以讀取數據,實現搭線竊聽、安裝通信監視器和讀取網上的信息等。 假冒 當一個實體…

leetcode9 回文數

判斷一個整數是否是回文數。回文數是指正序(從左向右)和倒序(從右向左)讀都是一樣的整數。 示例 1: 輸入: 121 輸出: true 示例 2: 輸入: -121 輸出: false 解釋: 從左向右讀, 為 -121 。 從右向左讀, 為 121- 。因此它不是一個…

caffe各層參數詳解

在prototxt文件中,層都是用layer{}的結構表示,而里面包含的層的參數可以在caffe.proto文件中找到,比如說Data類型的結構由message DataParameter所定義,Convolution類型的結構由message ConvolutionParameter所定義。 具體說明下: name表示該層的名稱type表示該層的類型,…

caffe網絡結構圖繪制

繪制網絡圖通常有兩種方法: 一種是利用python自帶的draw_net.py,首先安裝兩個庫: sudo apt-get install graphviz sudo pip install pydot 接下來就可以用python自帶的draw_net.py文件來繪制網絡圖了。 draw_net.py執行時帶三個參數&…

Git(7)-Git commit

Git提交1.識別不同的提交1.1絕對提交名-ID1.2 引用和符號引用--HEAD2.查看提交的歷史記錄-git log3.提交圖-gitk4.提交的范圍4.1 X..Y4.1 X...Y5.查找bad 提交--git bisect6.查看代碼修改者-git blame命令概覽git commit -a # 直接提交修改和刪除文件有效加了-a,在 …

leetcode111. 二叉樹的最小深度

給定一個二叉樹,找出其最小深度。 最小深度是從根節點到最近葉子節點的最短路徑上的節點數量。 說明: 葉子節點是指沒有子節點的節點。 示例: 給定二叉樹 [3,9,20,null,null,15,7], 3 / \ 9 20 / \ 15 7 返回它的最小深度 2. 思路&#xff1a…

Caffe將圖像數據轉換成leveldb/lmdb

Caffe中convert_imageset projrct將圖像數據轉換成Caffe能讀取的數據格式leveldb/lmdb -gray=true //whether read gray image -shuffle=true //whether mix order -resize_height=28 -resize_width=28 -backend=lmdb …

leetcode155. 最小棧

設計一個支持 push,pop,top 操作,并能在常數時間內檢索到最小元素的棧。 push(x) -- 將元素 x 推入棧中。 pop() -- 刪除棧頂的元素。 top() -- 獲取棧頂元素。 getMin() -- 檢索棧中的最小元素。 示例: MinStack minStack new MinStack()…

理解Caffe的網絡模型

目錄 1. 初見LeNet原始模型2. Caffe LeNet的網絡結構3. 逐層理解Caffe LeNet 3.1 Data Layer3.2 Conv1 Layer3.3 Pool1 Layer3.4 Conv2 Layer3.5 Pool2 Layer3.6 Ip1 Layer3.7 Relu1 Layer3.8 Ip2 Layer3.9 Loss Layer 1. 初見LeNet原始模型 Fig.1. Architecture of original …

Git(8)-分支

分支1. 分支名2. 創建分支-git branch3. 查看分支-git show-branch4. 檢出分支4.1 有未提交的修改時進行檢出4.2 合并變更到不同的分支git checkout -m5. 分離HEAD 分支6.刪除分支分支操作命令概覽 git branch # 列出版本庫中的分支 git branch -r # 列出遠程跟蹤分支…

caffe開始訓練自己的模型(轉載并驗證過)

學習caffe中踩了不少坑,這里我參考了此博主的文章,并體會到了如何訓練自己的模型:http://www.cnblogs.com/denny402/p/5083300.html 學習caffe的目的,不是簡單的做幾個練習,最終還是要用到自己的實際項目或科研中。因…

leetcode169. 多數元素

給定一個大小為 n 的數組,找到其中的多數元素。多數元素是指在數組中出現次數大于 ? n/2 ? 的元素。 你可以假設數組是非空的,并且給定的數組總是存在多數元素。 示例 1: 輸入: [3,2,3] 輸出: 3 示例 2: 輸入: [2,2,1,1,1,2,2] 輸出: 2 思路&…

Git(9)-diff

分支1. diff in Linux/Unix2. diff in Git3. git diff 兩點語法Linux/Unix 系統中存在diff 命令,可以用來顯示兩個文本/工作路徑的差異。Git diff 在此基礎上進行的擴展。 1. diff in Linux/Unix Linux 系統中的diff 命令:提供了一個文件如何轉化為另一…

圖像拼接(一):柱面投影+模板匹配+漸入漸出融合

這種拼接方法的假設前提是:待拼接的兩幅圖像之間的變換模型是平移模型,即兩幅圖像同名點位置之間只相差兩個未知量:ΔxΔx 和ΔyΔy,自由度為2,模型收得最緊。所以只有所有圖像都是用同一水平線或者同一已知傾斜角的攝…

圖像拼接(二):OpenCV同時打開兩個攝像頭捕獲視頻

使用OpenCV實現同時打開兩個USB攝像頭,并實時顯示視頻。如果未檢測有兩個攝像頭,程序會結束并發出“攝像頭未安裝好”的警告。這里推薦一個小巧的攝像頭視頻捕捉軟件:amcap,使用它可以方便的檢查每個攝像頭是否能正常工作。 捕獲…

Git(10)-merge

Merge1. 無沖突合并2. 有沖突合并-手動解決3. git diff in merge4. 廢棄合并5. 合并策略merge相關的操作的命令 git checkout master git merge alternate # 解決沖突 ..... git add file_1 git commit -m "Add slternate line 5, 6" git reset --hard HEAD # b…

elasticsearch的Linux下安裝報錯問題解決

1.啟動報錯如下: vim /etc/security/limits.conf 然后修改如下 * soft nofile 65536 * hard nofile 65536sudo vi /etc/pam.d/common-session 添加 session required pam_limits.so sudo vi /etc/pam.d/common-session-noninteractive 添加 session required pam_limits.so…

leetcode120. 三角形最小路徑和

給定一個三角形,找出自頂向下的最小路徑和。每一步只能移動到下一行中相鄰的結點上。 例如,給定三角形: [ [2], [3,4], [6,5,7], [4,1,8,3] ] 自頂向下的最小路徑和為 11(即,2 3 5 1 11&#xff0…

Elasticsearchan相關插件和工具安裝

1、下載elasticsearch-head的源碼包 地址:https://github.com/mobz/elasticsearch-head/releases 2、安裝node運行環境 地址:https://nodejs.org/en/download/ 3、安裝完node之后編譯elasticsearch-head 執行npm install -g grunt-cli編譯源碼 執行…