Training a classifier

你已經學習了如何定義神經網絡,計算損失和執行網絡權重的更新。

現在你或許在思考。

What about data?

通常當你需要處理圖像,文本,音頻,視頻數據,你能夠使用標準的python包將數據加載進numpy數組。之后你能夠轉換這些數組到torch.*Tensor。

  • 對于圖片,類似于Pillow,OPenCV的包很有用
  • 對于音頻,類似于scipy和librosa的包
  • 對于文字,無論是基于原生python和是Cython的加載,或者NLTK和SpaCy都有效

對于視覺,我們特意創建了一個包叫做torchvision,它有常見數據集的數據加載,比如ImageNet,CIFAR10,MNIST等,還有圖片的數據轉換,torchvision.datasets和torch.utils.data.Dataloader。

這提供了很方便的實現,避免了寫樣板代碼。

對于這一文章,我們將使用CIFAR10數據集。它擁有飛機,汽車,鳥,貓,鹿,狗,霧,馬,船,卡車等類別。CIFAR-10的圖片尺寸為3*32*32,也就是3個顏色通道和32*32個像素。

?

Training? an image classifier

?我們將按照順序執行如下步驟:

  1. 使用torchvision加載并且標準化CIFAR10訓練和測試數據集
  2. 定義一個卷積神經網絡
  3. 定義損失函數
  4. 使用訓練數據訓練網絡
  5. 使用測試數據測試網絡

?

1.加載并標準化CIFAR10

使用torchvision,加載CIFAR10非常簡單

import torch
import torchvision
import torchvision.transforms as transforms

torchvision數據集的輸出是PIL圖片庫圖片,范圍為[0,1]。我們將它們轉換為tensor并標準化為[-1,1]。

import torch
import torchvision
import torchvision.transforms as transformstransform=transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
trainset
=torchvision.datasets.CIFAR10(root='./data',train=True,download=True,transform=transform) trainloader=torch.data.Dataloader(trainset,batch_size=4,shuffle=True,num_workers=2)
testset
=torchvision.datasets.CIFAR10(root='./data',train=False,download=True,transform=transform) testloader=torch.utils.data.Dataloader(testset,batch_size=4,shuffle=False,num_workers=2)
classes
= ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
out:
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
Files already downloaded and verified

?我們來觀察一下訓練集圖片

import matplotlib.pyplot as plt
import numpy as npdef imshow(img):img=img/2+0.5npimg=img.numpy()plt.imshow(np.transpose(npimg,(1,2,0)))dataiter=iter(trainloader)
images,labels=dataiter.next()imshow(torchvision.utils.make_grid(images))
plt.show()
print(''.join('%5s'%classes[labels[j]] for j in range(4)))

?

out:
truck truck  dog truck

?

?2.定義卷積神經網絡

從前面神經網絡章節復制神經網絡,并把它改成接受3維圖片輸入(而不是之前定義的一維圖片)。

import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net,self).__init__()self.conv1=nn.Conv2d(3,6,5)self.pool=nn.MaxPool2d(2,2)self.conv2=nn.Conv2d(6,16,5)self.fc1=nn.Linear(16*5*5,120)self.fc2=nn.Linear(120,84)self.fc3=nn.Linear(84,10)def forward(self,x):x=self.pool(F.relu(self.conv1(x)))x=self.pool(F.relu(self.conv2(x)))x=x.view(-1,16*5*5)x=F.relu(self.fc1(x))x=F.relu(self.fc2(x))x=self.fc3(x)return xnet=Net()

?

?3.定義損失函數和優化器

我們使用分類交叉熵損失和帶有動量的SGD

import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer=optim.SGD(net.parameters(),lr=0.001,momentum=0.9)

?

4.訓練網絡

我們只需要簡單地迭代數據,把輸入喂進網絡并優化。

import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(),lr=0.001,momentum=0.9)for epoch in range(2):running_loss=0.0for i,data in enumerate(trainloader,0):inputs,labels=dataoptimizer.zero_grad()outputs=net(inputs)loss=criterion(outputs,labels)loss.backward()optimizer.step()running_loss+=loss.item()if i%2000==1999:print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 2000))running_loss=0
print('Finished Training')
out:
[1,  2000] loss: 2.208
[1,  4000] loss: 1.797
[1,  6000] loss: 1.627
[1,  8000] loss: 1.534
[1, 10000] loss: 1.508
[1, 12000] loss: 1.453
[2,  2000] loss: 1.378
[2,  4000] loss: 1.365
[2,  6000] loss: 1.326
[2,  8000] loss: 1.309
[2, 10000] loss: 1.290
[2, 12000] loss: 1.262
Finished Training

?

?

4.在測試數據集上測試網絡

我們已經遍歷了兩遍訓練集來訓練網絡。需要檢查下網絡是不是已經學習到了什么。

我們將檢查神經網絡輸出的預測標簽是否與真實標簽相同。如果預測是正確的,我們將這一樣本加入到正確預測的列表。

我們先來熟悉一下訓練圖片。

dataiter=iter(testloader)
images,labes=dataiter.next()imshow(torchvision.utils.make_grid(images))
plt.show()
print('GroundTruth: ',' '.join('%5s' % classes[labels[j]] for j in range(4)))

?

out:
GroundTruth:  plane  deer   dog horse

?ok,現在讓我們看一下神經網絡認為這些樣本是什么。

outputs=net(images)

?輸出是10個類別的量值,大的值代表網絡認為某一類的可能性更大。所以我們來獲得最大值得索引:

_,predicted=torch.max(outputs,1)
print("Predicted: ",' '.join('%5s' %classes[predicted[j]] for j in range(4)))
out:
Predicted:   bird   dog  deer horse

?讓我們看看整個數據集上的模型表現。

out:
Accuracy of the network on the 10000 test images: 54 %

?這看起來要好過瞎猜,隨機的話只要10%的準確率(因為是10類)。看來網絡是學習到了一些東西。

我們來繼續看看在哪些類上的效果好,在哪些類上的效果比較差:

out:
Accuracy of plane : 56 %
Accuracy of   car : 70 %
Accuracy of  bird : 27 %
Accuracy of   cat : 16 %
Accuracy of  deer : 44 %
Accuracy of   dog : 64 %
Accuracy of  frog : 61 %
Accuracy of horse : 73 %
Accuracy of  ship : 68 %
Accuracy of truck : 61 %

好了,接下來該干點啥?

我們怎樣將這個神經網絡運行在GPU上呢?

Trainning on GPU

就像你怎么把一個Tensor轉移到GPU上一樣,現在把神經網絡轉移到GPU上。

如果我們有一個可用的CUDA,首先將我們的設備定義為第一個可見的cuda設備:

device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
out:
cuda:0

?剩下的章節我們假定我們的設備是CUDA。

之后這些方法將遞歸到所有模塊,將其參數和緩沖區轉換為CUDA張量:

net.to(device)

?記得你還需要在每步循環里將數據轉移到GPU上:

inputs,labels=inputs.to(device),labels.to(device)

為什么沒注意到相對于CPU巨大的速度提升?這是因為你的網絡還非常小。

?

練習:嘗試增加你網絡的寬度(第一個nn.Conv2d的參數2應該與第二個nn.Conv2d的參數1是相等的數字),觀察你得到的速度提升。

達成目標:

  • 更深一步理解Pytorch的Tensor庫和神經網絡
  • 訓練一個小神經網絡來分類圖片

?Trainning on multiple GPUs

如果你想看到更加顯著的GPU加速,請移步:https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html

?

轉載于:https://www.cnblogs.com/Thinker-pcw/p/9637411.html

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/249527.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/249527.shtml
英文地址,請注明出處:http://en.pswp.cn/news/249527.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

19歲白帽子通過bug懸賞賺到一百萬美元--轉

出處:https://news.cnblogs.com/n/620858/ 19 歲的 Santiago Lopez 通過 bug 懸賞平臺 HackerOne 報告漏洞,成為第一位通過 bug 懸賞賺到一百萬美元的白帽子黑客。他的白帽子生涯始于 2015 年,至今共報告了超過 1600 個安全漏洞。他在 16 歲時…

代碼分層的設計

分層思想,是應用系統最常見的一種架構模式,我們會將系統橫向切割,根據業務職責劃分。MVC 三層架構就是非常典型架構模式,劃分的目的是規劃軟件系統的邏輯結構便于開發維護。MVC:英文即 Model-View-Controller&#xff…

【24小時內第四更】為什么我們要堅持寫博客?

前言 從2018年7月份,我開始了寫作博客之路。開始之前,我打算分享下之前的經歷。去年初公司來了個架構師,內部分享過docker原理,TDD單元測試驅動,并發并行異步編程等內容,讓我著實驚呆了,因為確實…

sqoop快速入門

轉自http://www.aboutyun.com/thread-22549-1-1.html 轉載于:https://www.cnblogs.com/drjava/p/10473297.html

ListableBeanFactory接口

ListableBeanFactory獲取bean時,Spring 鼓勵使用這個接口定義的api. 還有個Beanfactory方便使用.其他的4個接口都是不鼓勵使用的. 提供容器中bean迭代的功能,不再需要一個個bean地查找.比如可以一次獲取全部的bean(太暴力了),根據類型獲取bean.在看SpringMVC時,掃描包路徑下的…

HDU 4035 Maze

Maze http://acm.hdu.edu.cn/showproblem.php?pid4035 分析: 在樹上走來走去,然后在一個點可以k的概率回到1,可以e的概率走出去,可以1-k-e的概率走到其他的位置(分為父節點和子節點討論)。 轉移方程就是&a…

面向對象之三大特性:繼承,封裝,多態

python面向對象的三大特性:繼承,封裝,多態。 1. 封裝: 把很多數據封裝到?個對象中. 把固定功能的代碼封裝到?個代碼塊, 函數, 對象, 打包成模塊. 這都屬于封裝的思想. 具體的情況具體分析. 比如. 你寫了?個很?B的函數. 那這個也可以被稱為…

configurablebeanfactory

ConfigurableBeanFactory定義BeanFactory的配置.ConfigurableBeanFactory中定義了太多太多的api,比如類加載器,類型轉化,屬性編輯器,BeanPostProcessor,作用域,bean定義,處理bean依賴關系,合并其他ConfigurableBeanFactory,bean如何銷毀. ConfigurableBeanFactory同時繼承了Hi…

Xlua文件在熱更新中調用方法

Xlua文件在熱更新中調用方法 public class news : MonoBehaviour { LuaEnv luaEnv;//定義Lua初始變量 void Awake() { luaEnv new LuaEnv();//new開辟空間 luaEnv.AddLoader(myload);//調用方法地址、返回字節 luaEnv.DoString("requirefish");//更新文件 } void O…

springboot 使用的配置

1,控制臺打印sql logging:level:com.sdyy.test.mapper: debug 2,開啟駝峰命名 mybatis.configuration.map-underscore-to-camel-casetrue 轉載于:https://www.cnblogs.com/xiaohu1218/p/10477318.html

AutowireCapableBeanFactory接口

AutowireCapableBeanFactory在BeanFactory基礎上實現了對存在實例的管理.可以使用這個接口集成其它框架,捆綁并填充并不由Spring管理生命周期并已存在的實例.像集成WebWork的Actions 和Tapestry Page就很實用. 一般應用開發者不會使用這個接口,所以像ApplicationContext這樣的…

外觀模式

一、什么是外觀模式   有些人可能炒過股票,但其實大部分人都不太懂,這種沒有足夠了解證券知識的情況下做股票是很容易虧錢的,剛開始炒股肯定都會想,如果有個懂行的幫幫手就好,其實基金就是個好幫手,支付寶…

OC內存管理

OC內存管理 一、基本原理 (一)為什么要進行內存管理。 由于移動設備的內存極其有限,所以每個APP所占的內存也是有限制的,當app所占用的內存較多時,系統就會發出內存警告,這時需要回收一些不需要再繼續使用的…

cf1132E. Knapsack(搜索)

題意 題目鏈接 Sol 看了status里面最短的代碼。。感覺自己真是菜的一批。。直接爆搜居然可以過&#xff1f;。。但是現在還沒終測所以可能會fst。。 #include<bits/stdc.h> #define Pair pair<int, int> #define MP(x, y) make_pair(x, y) #define fi first #defi…

ConfigurableListableBeanFactory

ConfigurableListableBeanFactory 提供bean definition的解析,注冊功能,再對單例來個預加載(解決循環依賴問題). 貌似我們一般開發就會直接定義這么個接口了事.而不是像Spring這樣先根據使用情況細分那么多,到這邊再合并 ConfigurableListableBeanFactory具體&#xff1a; 1、…

焦旭超 201771010109《面向對象程序設計課程學習進度條》

《2018面向對象程序設計&#xff08;java&#xff09;課程學習進度條》 周次 &#xff08;閱讀/編寫&#xff09;代碼行數 發布博客量/博客評論量 課堂/課余學習時間&#xff08;小時&#xff09; 最滿意的編程任務 第一周 50/20 1/0 6/4 九九乘法表 第二周 90/5…

面試題集錦

1. L1范式和L2范式的區別 (1) L1范式是對應參數向量絕對值之和 (2) L1范式具有稀疏性 (3) L1范式可以用來作為特征選擇&#xff0c;并且可解釋性較強&#xff08;這里的原理是在實際Loss function 中都需要求最小值&#xff0c;根據L1的定義可知L1最小值只有0&#xff0c;故可以…

Spring注解配置工作原理源碼解析

一、背景知識 在【Spring實戰】Spring容器初始化完成后執行初始化數據方法一文中說要分析其實現原理&#xff0c;于是就從源碼中尋找答案&#xff0c;看源碼容易跑偏&#xff0c;因此應當有個主線&#xff0c;或者帶著問題、目標去看&#xff0c;這樣才能最大限度的提升自身代…

halt

關機 init 0 reboot init6 shutdown -r now 重啟 -h now 關機 轉載于:https://www.cnblogs.com/todayORtomorrow/p/10486123.html

Spring--Context

應用上下文 Spring通過應用上下文&#xff08;Application Context&#xff09;裝載bean的定義并把它們組裝起來。Spring應用上下文全權負責對象的創建和組裝。Spring自帶了多種應用上下文的實現&#xff0c;它們之間主要的區別僅僅在于如何加載配置。 1.AnnotationConfigApp…