python torch exp_學習Pytorch過程遇到的坑(持續更新中)

1. 關于單機多卡的處理:

在pytorch官網上有一個簡單的示例:函數使用為:torch.nn.DataParallel(model, deviceids, outputdevice, dim)關鍵的在于model、device_ids這兩個參數。DATA PARALLELISM?pytorch.org

但是官網的例子中沒有講到一個核心的問題:即所有的tensor必須要在同一個GPU上。這是網絡運行的前提。這篇文章給了我很大的幫助,里面的例子也很好懂,很直觀:pytorch: 一機多卡訓練的嘗試?www.jianshu.com

一般來說有兩種數據遷移的方法:

1)是先定義一個device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')【這里面已經定義了device在卡0上“cuda:0”】

然后將model = torch.nn.DataParallel(model,devices_ids=[0, 1, 2])(假設有三張卡)

此后需要將tensor 也遷移到GPU上去。注意所有的tensor必須要在同一張GPU上面

即:tensor1 = tensor1.to(device), tensor2 = tensor2.to(device)等等

(可能有人會問了,我并沒有指定那一塊GPU啊,怎么這樣也沒有出錯啊?

原因很簡單,因為一開始的device中已經指定了那一塊卡了(卡的id為0))

2)第二中方法就是直接用tensor.cuda()的方法

即先model = torch.nn.DataParallel(model, device_ids=[0, 1, 2]) (假設有三塊卡, 卡的ID 為0, 1, 2)

然后tensor1 = tensor1.cuda(0), tensor2=tensor2.cuda(0)等等。(我這里面把所有的tensor全放進ID 為 0 的卡里面,也可以將全部的tensor都放在ID 為1 的卡里面)

2 關于DataParallel的封裝問題

在DataParallel中,沒有和nn.Module一樣多的特性。但是有些時候我們可能需要使用到如.fc這樣的性質(.fc性質在nn.Module中有, 但是在DataParallel中沒有)這個時候我們需要一個.Module屬性來進行過渡。操作如下:

model = Model() # 這里實例化Model類得到一個model

model.fc # 這樣做不會報錯

# DataParallel情況下

parallel_model = torch.nn.DataParallel(model)

parallel_model.fc # 會報錯。解決辦法,很簡單, 在fc前加一個.module即可

parall_model.module.fc # 不會報錯

3 Pytorch中的數據導入潛規則

All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225]

所以我們在transform的時候可以先定義:normalized = torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 然后用的時候直接調用normalized就行了。

4 python中的某些包的版本不同也會導致程序運行失敗。

如,今天遇到一個pillow包的問題。原先裝的包的6.0.0版本的,但是在制作數據集的時候,訓練集跑的好好的,一到驗證集就開始無端報錯。在確定程序無誤之后,將程序放在別的環境中跑(也是pytorch環境),正常運行。于是經過幾番查找,發現是pillow出了問題,于是乎卸載了原來的版本,重新裝一個低一點的版本問題就解決了。這種版本問題的坑其實很多,而且每個人遇到的還都不盡相同,所以需要慢慢的去摸索才能發現問題所在。

5 關于CUDA 內存溢出的問題。

這個一般是因為batch_size 設置的比較大。(8G顯存的話大概batch_size < = 64都ok, 如果還是報錯的話,就在對半分 64, 32, 16, 8, 4等等)。而且這個和你的數據大小沒什么太大的關系。因為我剛剛開始也是想可能是我訓練集太大了,于是將數據集縮小了十倍,還是同樣的報錯。所以就想可能 batch_size的問題。最后果然是batchsize的問題。

6 關于模型導入

一般來說如果你的模型是再GPU上面訓練的,那么如果你繼續再GPU上面進行其他的后續操作(如遷移學習等)那么直接使用:

import torch

from torchvision import models

pre_trained_weight = torch.load('pre_trained_weight.pt') # pre_trained_weight.pt 是我在resnet18上面訓練好的模型

resnet18 = models.resnet18(pretrained=False) # 導入框架

resnet18.load_state_dict(pre_trained_weight) # load_state_dict()函數表示導入當前權值,因為這個權值都是以字典的形式保存的

# 如果你模型在GPU上訓練的,而且后續操作也在GPU上進行,那上面的操作就沒啥毛病。但是…………

如果你模型在GPU上訓練的,后續操作是在CPU上進行的話。那么還用上面的代碼的話就會報錯了。因為你模型在GPU上訓練,其實其內部的某些數據格式和CPU上的不大一樣。所以需要一個函數將GPU上的模型轉化為CPU上的模型。這個工作在pytorch里面其實很簡單。只要把上面的代碼簡單修改一下即可:(在torch.load函數里面加一個map_location='cpu'即可!)

import torch

from torchvision import models

pre_trained_weight = torch.load('pre_trained_weight.pt',map_location='cpu') # pre_trained_weight.pt 是我在resnet18上面訓練好的模型

resnet18 = models.resnet18(pretrained=False) # 導入框架

resnet18.load_state_dict(pre_trained_weight) # load_state_dict()函數表示導入當前權值,因為這個權值都是以字典的形式保存的

7. 關于兩次sort操作:

前幾天看SSD pytorch的源碼發現了,有這樣的一步操作,不得解,

于是查閱了一下資料和動手操作后發現了兩次sort操作的神奇之處。

首先 sort操作沒什么好說。它接收兩個參數:dim和descending參數。dim表示的是從哪個維度進行排列,descending參數接收布爾類型的輸入,表示結果是否按降序排列。兩次sort操作的具體實施為。

import torch

x = torch.randon(3, 4)

>>>x

tensor([[-0.1361, 0.4076, -0.8244, 0.9163],

[-0.0997, -1.1689, -2.3145, 1.2334],

[-0.4384, -1.6083, 1.7621, -0.9648]])

_, indices = x.sort(dim=1, descending=False)

>>>indices

tensor([[2, 0, 1, 3],

[2, 1, 0, 3],

[1, 3, 0, 2]])

# 上面的是進行第一次的sort, 得到的結果關于x的每一行的元素的升序排列

# 下面進行第二次sort操作。

_, idx = indices.sort(dim=1, descending=False)

>>>idx

tensor([[1, 2, 0, 3],

[2, 1, 0, 3],

[2, 0, 3, 1]])

# 我們來分析一下這個得到的idx和原始數據x的關系。

>>>x

tensor([[-0.1361, 0.4076, -0.8244, 0.9163],

[-0.0997, -1.1689, -2.3145, 1.2334],

[-0.4384, -1.6083, 1.7621, -0.9648]])

按升序排列的話,x的【第一行】中的第一個元素對應的是第二小,第二個元素對應的第三小,第三個元素對應是最小, 最后一個元素應該是最大的

所以這個排列的大小和位置可以從二次sort操作的idx中能看到。現在分析idx,取其第一行【1, 2, 0, 3】, 表示的意思是x[0,0]處在x[0]這一行

的第二位,x[0, 1]處在下x[0]中的第三位, x[0, 2]處在x[0]這一行的第一位, 下x[0, 3]處在x[0]行的最后一位。

(注:這里的第幾位表示的是每一行按升序排列原則,其中的元素所處的位置)

從上面的分析中可以看到,兩次sort操作得到的idx的意義是: 在保證原始元素的位置不變的情況下,可以表示排序情況(升序or降序)。

以上是原理,那么兩次sort究竟用在什么地方呢?

還是上面哪個例子:

>>>x

tensor([[-0.1361, 0.4076, -0.8244, 0.9163],

[-0.0997, -1.1689, -2.3145, 1.2334],

[-0.4384, -1.6083, 1.7621, -0.9648]])

我想取x的第一行元素的前1個最小值, 第二行元素的前2個最小值,第三行元素的前3個最小值。該怎么操作呢?

根據上面的兩次sort操作,我們得到idx

tensor([[1, 2, 0, 3],

[2, 1, 0, 3],

[2, 0, 3, 1]])

# 定義criterion

criterion = torch.tensor([1, 2, 3]).view(3, -1)

criterion = criterion.expand_as(idx)

>>>criterion

tensor([[1, 1, 1, 1],

[2, 2, 2, 2],

[3, 3, 3, 3]])

mask = idx < criterion

>>>mask

tensor([[0, 0, 1, 0],

[0, 1, 1, 0],

[1, 1, 0, 1]], dtype=torch.uint8)

# 可以看到,mask得到的就是我們所需要的索引。可以看到mask第一行只有一個1, 第二行有兩個1,第三行有三個1.這里的1表示的True的意思,即得到這個數

>>>x[mask]

tensor([-0.8244, -1.1689, -2.3145, -0.4384, -1.6083, -0.9648]) # 最終結果

8. log_sum_exp的trick:機器學習常見模式LogSumExp解密人工智能_機器人之家?www.jqr.com

參考這篇文章,寫的通俗易懂。大概介紹一下問題:

發現這個問題是前幾天,這里面在進行exp操作的時候用x-x_max。當時很是疑惑。后來一看上面這篇文章才明白了。

一般來說

是有一個確切的值與之對應的。但是在計算里面卻不是這樣的。輸入torch.exp(1000), 結果是:

這樣的結果并不意外,因為計算機的存儲階段誤差導致的。基于這種情況的存在,所以人們想到了一個比較好的解決方法。具體怎么實現看看上面的鏈接便清楚了。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/538951.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/538951.shtml
英文地址,請注明出處:http://en.pswp.cn/news/538951.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

php 讀文件返回字符串,PHP:file_get_contents('php:// input')返回JSON消息的字符串...

我正在嘗試在我的PHP應用程序中讀取JSON消息&#xff0c;這是我的php代碼&#xff1a;$json file_get_contents(php://input);$obj json_decode($json, TRUE);echo $obj->{S3URL};當我這樣做時&#xff0c;出現以下錯誤&#xff1a;Trying to get property of non-object …

Android中ListView數據使用sAdapter.notifyDataSetChanged();方法不刷新的問題

原文鏈接&#xff1a;http://blog.csdn.net/caihongdao123/article/details/51513410 點擊閱讀原文 ------------------------- 1.涉及到數據庫 當要動態顯示更數據庫改動&#xff0c;相信大家應該都用過notifyDataSetChanged();. 例如&#xff1a; ...... …

keepalived配置高可用集群

準備工作 分別在主從上安裝keepalived和nginxyum install -y keepalivedyum install -y nginx關閉主從上的防火墻和SELinuxsystemctl stop firewalldsetenforce 0 配置主機 查看主機ip [rootlynn-04 ~]# ifconfig ens33: flags4163<UP,BROADCAST,RUNNING,MULTICAST> mtu…

如何快速掌握python包_如何快速掌握一個python模塊?

初學者就別想快了。 我自己是這樣的。先上網看看一些基礎的教程&#xff0c;非常快的過一下&#xff0c;十幾分鐘&#xff0c;主要是了解這個module能干什么&#xff0c;特別是一些基本的功能&#xff0c;頭腦中建立起初步映射。 然后就是用了&#xff0c;不用看了也白搭。我假…

python設計一個函數定義計算并返回n價調和函數_python函數的調和平均值?

我有兩個函數&#xff0c;給出精度和召回分數&#xff0c;我需要做一個調和平均函數&#xff0c;定義在同一個庫中&#xff0c;使用這兩個分數。函數如下所示&#xff1a;功能如下&#xff1a;def precision(ref, hyp):"""Calculates precision.Args:- ref: a l…

jsp超鏈接到java文件,jsp頁面超鏈接傳中文終極解決辦法

在做web前端頁面的時候&#xff0c;經常碰到傳中文問題。網上有許多方案&#xff0c;但仍不能根治&#xff0c;最終要用js或者java的encode相關方法。常規方案有三部&#xff1a;1.改tomcat的server.xml中URIEncodeing為utf-82.頁面中編碼設置為utf-83.整個項目編碼使用utf-8我…

自定義ListView中的分割線

原文&#xff1a;http://blog.csdn.net/zuolongsnail/article/details/7187302點擊閱讀 --------------------------------------- ListView中每個Item項之間都有分割線&#xff0c;設置Android:footerDividersEnabled表示是否顯示分割線&#xff0c;此屬性默認為true。 1.不顯…

隱藏域input里面放當前時間_【小A問答】Win10的隱藏小秘密,被我發現了!

無驚無險又到小A問答環節辣~~今天的小A要來給大家分享一些小秘密&#xff01;當然&#xff0c;這可不是小A自己的小秘密&#xff0c;是關于你電腦的小秘密哦&#xff01;知道嗎&#xff1f;Windows10每一次升級更新&#xff0c;都會伴隨著新功能的增加。這些隱藏的功能你都發現…

網絡相關的面試題

1&#xff09;簡述tcp/ip的三次交互過程&#xff08;個人理解&#xff1a;syn是握手信號&#xff0c;ack是確認信號&#xff0c;ack就相當于前面的syn值1&#xff0c;簡單一點理解就是客戶端發送握手請求&#xff0c;服務器收到握手請求后&#xff0c;回復一個包確認它接收到了…

h5文字垂直居中_CSS中垂直居中和水平垂直居中的方法

flex垂直居中&#xff1a;第一種&#xff1a;使用flex布局&#xff0c;讓居中元素的父元素為flex屬性,讓它在交叉軸上center就可以達到居中效果了&#xff1a;html代碼: <div class"father"><p>我要垂直居中</p> </div>css代碼: .father {…

ListMap排序

//compareto就是比較兩個數據的大小關系 大于0表示前一個數據比后一個數據大&#xff0c; 0表示相等&#xff0c;小于0表示第一個數據小于第二個數據 public static List<Map<String, String>> sortWifi(List<Map<String, String>> wifiList){if(wif…

thinkphp回調的php調用db類,請問thinkphp中model類自動完成功能 回調函數能不能獲取其他字段的值?...

http://www.thinkphp.cn/api/source-class-Model.html#975protected function _validationFieldItem($data,$val) {switch(strtolower(trim($val[4]))) {case function:// 使用函數進行驗證case callback:// 調用方法進行驗證$args isset($val[6])?(array)$val[6]:array();if…

輸入年份和月份輸出該月有多少天python_Python實現用戶輸入年月日,程序打印出這是這一年的第多少天...

1. 自己造輪子yearint(input(請輸入年份&#xff0c;如2019>>>))monthint(input(請輸入月份&#xff0c;如8>>>))dayint(input(請輸入日期,如25>>>))#下面這塊代碼是按照閏年計算if (year%40 and year%100!0) or (year%4000):calendar{1:31,2:29,3:…

Linux命令之find命令中的-mtime參數

有關find -mtime這個參數的使用有比較多的坑&#xff0c;今天把這個問題在這里記錄下來&#xff1a; mtime參數的理解應該如下&#xff1a; -mtime n 按照文件的更改時間來找文件&#xff0c;n為整數。 n 表示文件更改時間距離為n天-n 表示文件更改時間距離在n天以內n 表示文件…

WifiManager的getScanResults()返回列表為0

這個問題查了好久&#xff0c;花了2個小時。就是出不來。 原來問題在android sdk 版本問題。 在android 6.0的時候&#xff0c;返回為空&#xff0c;且不為null&#xff0c;在華為mate&#xff0c;6.0手機上測試&#xff0c;也不報錯。 官網和網上沒有具體的解決方法。 后來…

c++直角坐標系與極坐標系的轉換_平面向量的奇技淫巧——斜坐標系的一系列低級研究...

事先說明&#xff1a;筆者初三&#xff0c;如在敘述中有不嚴謹的地方&#xff0c;還請諸位指出&#xff0c;自當感激不盡。一.什么是斜坐標系眾所周知&#xff0c;我們目前平面中使用相當廣的坐標系是笛卡爾發明的平面直角坐標系。然而&#xff0c;笛卡爾真的只使用了這一種坐標…

php 字節轉為kb,PHP獲取文件大小并轉化為KB、MB、GB單位

PHP獲取文件大小并轉化為KB、MB、GB單位。function getSize($filesize) {if ($filesize > 1073741824) {$filesize round($filesize / 1073741824 * 100) / 100 . GB;} elseif ($filesize > 1048576) {$filesize round($filesize / 1048576 * 100) / 100 . MB;} else…

python 重定向stdout_Python 犄角旮旯--重定向 stdout

What&#xff1f;在 Python 程序中&#xff0c;使用 print 輸出調試信息的做法非常常見&#xff0c;但有的時候我們需要將 print 的內容改寫到其他位置&#xff0c;比如一個文件中&#xff0c;便于隨時排查。但是又不希望大面積替換 print 函數&#xff0c;這就需要一些技巧實現…

Jetty實戰之 安裝 運行 部署

原文地址&#xff1a;http://blog.csdn.net/kongxx/article/details/7218767 1. 首先從Jetty的官方網站http://wiki.eclipse.org/Jetty/Starting/Downloads下載最新的Jetty&#xff0c;上面有兩個版本7.x和8.x&#xff0c;7.x是運行在JDK5及以上版本&#xff0c;8.x是運行在JD…

一行命令從 APK 文件中提取 Endpoint 及 URL

做IoT的人免不了要接觸Android&#xff0c;接觸Android的人又免不了要研究別人的App應用。 Diggy&#xff0c;一款能夠從 apk 文件中提取 endpoint 及 URL 的工具&#xff0c;只要一行命令就可以幫大家提取出相關Android apk文件的安裝信息和互聯網訪問信息。 下載地址&#xf…