淺談 PyTorch 中的 tensor 及使用

淺談 PyTorch 中的 tensor 及使用

轉自:淺談 PyTorch 中的 tensor 及使用

這篇文章主要是圍繞 PyTorch 中的 tensor 展開的,討論了張量的求導機制,在不同設備之間的轉換,神經網絡中權重的更新等內容。面向的讀者是使用過 PyTorch 一段時間的用戶。本文中的代碼例子基于 Python 3 和 PyTorch 1.1,如果文章中有錯誤或者沒有說明白的地方,歡迎在評論區指正和討論。

文章具體內容分為以下6個部分:

  1. tensor.requires_grad
  2. torch.no_grad()
  3. 反向傳播及網絡的更新
  4. tensor.detach()
  5. CPU and GPU
  6. tensor.item()

因為本文大部分內容是聽著冷鳥的歌完成的,故用此標題封面。

1. requires_grad

當我們創建一個張量 (tensor) 的時候,如果沒有特殊指定的話,那么這個張量是默認是不需要求導的。我們可以通過 tensor.requires_grad 來檢查一個張量是否需要求導。

在張量間的計算過程中,如果在所有輸入中,有一個輸入需要求導,那么輸出一定會需要求導;相反,只有當所有輸入都不需要求導的時候,輸出才會不需要 [1]。

舉一個比較簡單的例子,比如我們在訓練一個網絡的時候,我們從 DataLoader 中讀取出來的一個 mini-batch 的數據,這些輸入默認是不需要求導的,其次,網絡的輸出我們沒有特意指明需要求導吧,Ground Truth 我們也沒有特意設置需要求導吧。這么一想,哇,那我之前的那些 loss 咋還能自動求導呢?其實原因就是上邊那條規則,雖然輸入的訓練數據是默認不求導的,但是,我們的 model 中的所有參數,它默認是求導的,這么一來,其中只要有一個需要求導,那么輸出的網絡結果必定也會需要求的。來看個實例:

input = torch.randn(8, 3, 50, 100)
print(input.requires_grad)
# Falsenet = nn.Sequential(nn.Conv2d(3, 16, 3, 1),nn.Conv2d(16, 32, 3, 1))
for param in net.named_parameters():print(param[0], param[1].requires_grad)
# 0.weight True
# 0.bias True
# 1.weight True
# 1.bias Trueoutput = net(input)
print(output.requires_grad)
# True

誠不欺我!但是,大家請注意前邊只是舉個例子來說明。在寫代碼的過程中,不要把網絡的輸入和 Ground Truth 的 requires_grad 設置為 True。雖然這樣設置不會影響反向傳播,但是需要額外計算網絡的輸入和 Ground Truth 的導數,增大了計算量和內存占用不說,這些計算出來的導數結果也沒啥用。因為我們只需要神經網絡中的參數的導數,用來更新網絡,其余的導數都不需要。

好了,有個這個例子做鋪墊,那么我們來得寸進尺一下。我們試試把網絡參數的 requires_grad 設置為 False 會怎么樣,同樣的網絡:

input = torch.randn(8, 3, 50, 100)
print(input.requires_grad)
# Falsenet = nn.Sequential(nn.Conv2d(3, 16, 3, 1),nn.Conv2d(16, 32, 3, 1))
for param in net.named_parameters():param[1].requires_grad = Falseprint(param[0], param[1].requires_grad)
# 0.weight False
# 0.bias False
# 1.weight False
# 1.bias Falseoutput = net(input)
print(output.requires_grad)
# False

這樣有什么用處?用處大了。我們可以通過這種方法,在訓練的過程中凍結部分網絡,讓這些層的參數不再更新,這在遷移學習中很有用處。我們來看一個 官方 Tutorial: FINETUNING TORCHVISION MODELS 給的例子:

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():param.requires_grad = False# 用一個新的 fc 層來取代之前的全連接層
# 因為新構建的 fc 層的參數默認 requires_grad=True
model.fc = nn.Linear(512, 100)# 只更新 fc 層的參數
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)# 通過這樣,我們就凍結了 resnet 前邊的所有層,
# 在訓練過程中只更新最后的 fc 層中的參數。

2. torch.no_grad()

當我們在做 evaluating 的時候(不需要計算導數),我們可以將推斷(inference)的代碼包裹在 with torch.no_grad(): 之中,以達到 暫時 不追蹤網絡參數中的導數的目的,總之是為了減少可能存在的計算和內存消耗。看 官方 Tutorial 給出的例子:

x = torch.randn(3, requires_grad = True)
print(x.requires_grad)
# True
print((x ** 2).requires_grad)
# Truewith torch.no_grad():print((x ** 2).requires_grad)# Falseprint((x ** 2).requires_grad)
# True

3. 反向傳播及網絡的更新

這部分我們比較簡單地講一講,有了網絡輸出之后,我們怎么根據這個結果來更新我們的網絡參數呢。我們以一個非常簡單的自定義網絡來講解這個問題,這個網絡包含2個卷積層,1個全連接層,輸出的結果是20維的,類似分類問題中我們一共有20個類別,網絡如下:

class Simple(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(3, 16, 3, 1, padding=1, bias=False)self.conv2 = nn.Conv2d(16, 32, 3, 1, padding=1, bias=False)self.linear = nn.Linear(32*10*10, 20, bias=False)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.linear(x.view(x.size(0), -1))return x

接下來我們用這個網絡,來研究一下整個網絡更新的流程:

# 創建一個很簡單的網絡:兩個卷積層,一個全連接層
model = Simple()
# 為了方便觀察數據變化,把所有網絡參數都初始化為 0.1
for m in model.parameters():m.data.fill_(0.1)criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)model.train()
# 模擬輸入8個 sample,每個的大小是 10x10,
# 值都初始化為1,讓每次輸出結果都固定,方便觀察
images = torch.ones(8, 3, 10, 10)
targets = torch.ones(8, dtype=torch.long)output = model(images)
print(output.shape)
# torch.Size([8, 20])loss = criterion(output, targets)print(model.conv1.weight.grad)
# None
loss.backward()
print(model.conv1.weight.grad[0][0][0])
# tensor([-0.0782, -0.0842, -0.0782])
# 通過一次反向傳播,計算出網絡參數的導數,
# 因為篇幅原因,我們只觀察一小部分結果print(model.conv1.weight[0][0][0])
# tensor([0.1000, 0.1000, 0.1000], grad_fn=<SelectBackward>)
# 我們知道網絡參數的值一開始都初始化為 0.1 的optimizer.step()
print(model.conv1.weight[0][0][0])
# tensor([0.1782, 0.1842, 0.1782], grad_fn=<SelectBackward>)
# 回想剛才我們設置 learning rate 為 1,這樣,
# 更新后的結果,正好是 (原始權重 - 求導結果) !optimizer.zero_grad()
print(model.conv1.weight.grad[0][0][0])
# tensor([0., 0., 0.])
# 每次更新完權重之后,我們記得要把導數清零啊,
# 不然下次會得到一個和上次計算一起累加的結果。
# 當然,zero_grad() 的位置,可以放到前邊去,
# 只要保證在計算導數前,參數的導數是清零的就好。

這里,我們多提一句,我們把整個網絡參數的值都傳到 optimizer 里面了,這種情況下我們調用 model.zero_grad(),效果是和 optimizer.zero_grad() 一樣的。這個知道就好,建議大家堅持用 optimizer.zero_grad()。我們現在來看一下如果沒有調用 zero_grad(),會怎么樣吧:

# ...
# 代碼和之前一樣
model.train()# 第一輪
images = torch.ones(8, 3, 10, 10)
targets = torch.ones(8, dtype=torch.long)output = model(images)
loss = criterion(output, targets)
loss.backward()
print(model.conv1.weight.grad[0][0][0])
# tensor([-0.0782, -0.0842, -0.0782])# 第二輪
output = model(images)
loss = criterion(output, targets)
loss.backward()
print(model.conv1.weight.grad[0][0][0])
# tensor([-0.1564, -0.1684, -0.1564])

我們可以看到,第二次的結果正好是第一次的2倍。第一次結束之后,因為我們沒有更新網絡權重,所以第二次反向傳播的求導結果和第一次結果一樣,加上上次我們沒有將 loss 清零,所以結果正好是2倍。另外大家可以看一下這個博客 (torch 代碼解析 為什么要使用 optimizer.zero_grad() ),我覺得講得很好。

4. tensor.detach()

接下來我們來探討兩個 0.4.0 版本更新產生的遺留問題。第一個,tensor.datatensor.detach()

在 0.4.0 版本以前,.data 是用來取 Variable 中的 tensor 的,但是之后 Variable 被取消,.data 卻留了下來。現在我們調用 tensor.data,可以得到 tensor的數據 + requires_grad=False 的版本,而且二者共享儲存空間,也就是如果修改其中一個,另一個也會變。因為 PyTorch 的自動求導系統不會追蹤 tensor.data 的變化,所以使用它的話可能會導致求導結果出錯。官方建議使用 tensor.detach() 來替代它,二者作用相似,但是 detach 會被自動求導系統追蹤,使用起來很安全[2]。多說無益,我們來看個例子吧:

a = torch.tensor([7., 0, 0], requires_grad=True)
b = a + 2
print(b)
# tensor([9., 2., 2.], grad_fn=<AddBackward0>)loss = torch.mean(b * b)b_ = b.detach()
b_.zero_()
print(b)
# tensor([0., 0., 0.], grad_fn=<AddBackward0>)
# 儲存空間共享,修改 b_ , b 的值也變了loss.backward()
# RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

這個例子中,b 是用來計算 loss 的一個變量,我們在計算完 loss 之后,進行反向傳播之前,修改 b 的值。這么做會導致相關的導數的計算結果錯誤,因為我們在計算導數的過程中還會用到 b 的值,但是它已經變了(和正向傳播過程中的值不一樣了)。在這種情況下,PyTorch 選擇報錯來提醒我們。但是,如果我們使用 tensor.data 的時候,結果是這樣的:

a = torch.tensor([7., 0, 0], requires_grad=True)
b = a + 2
print(b)
# tensor([9., 2., 2.], grad_fn=<AddBackward0>)loss = torch.mean(b * b)b_ = b.data
b_.zero_()
print(b)
# tensor([0., 0., 0.], grad_fn=<AddBackward0>)loss.backward()print(a.grad)
# tensor([0., 0., 0.])# 其實正確的結果應該是:
# tensor([6.0000, 1.3333, 1.3333])

這個導數計算的結果明顯是錯的,但沒有任何提醒,之后再 Debug 會非常痛苦。所以,建議大家都用 tensor.detach() 啊。上邊這個代碼例子是受 這里 啟發。

5. CPU and GPU

接下來我們來說另一個問題,是關于 tensor.cuda()tensor.to(device) 的。后者是 0.4.0 版本之后后添加的,當 device 是 GPU 的時候,這兩者并沒有區別。那為什么要在新版本增加后者這個表達呢,是因為有了它,我們直接在代碼最上邊加一句話指定 device ,后面的代碼直接用to(device) 就可以了:

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")a = torch.rand([3,3]).to(device)
# 干其他的活
b = torch.rand([3,3]).to(device)
# 干其他的活
c = torch.rand([3,3]).to(device)

而之前版本的話,當我們每次在不同設備之間切換的時候,每次都要用 if cuda.is_available() 判斷能否使用 GPU,很麻煩。這個精彩的解釋來自于 這里 。

if torch.cuda.is_available():a = torch.rand([3,3]).cuda()
# 干其他的活
if  torch.cuda.is_available():b = torch.rand([3,3]).cuda()
# 干其他的活
if  torch.cuda.is_available():c = torch.rand([3,3]).cuda()

關于使用 GPU 還有一個點,在我們想把 GPU tensor 轉換成 Numpy 變量的時候,需要先將 tensor 轉換到 CPU 中去,因為 Numpy 是 CPU-only 的。其次,如果 tensor 需要求導的話,還需要加一步 detach,再轉成 Numpy 。例子如下:

x  = torch.rand([3,3], device='cuda')
x_ = x.cpu().numpy()y  = torch.rand([3,3], requires_grad=True, device='cuda').
y_ = y.cpu().detach().numpy()
# y_ = y.detach().cpu().numpy() 也可以
# 二者好像差別不大?我們來比比時間:
start_t = time.time()
for i in range(10000):y_ = y.cpu().detach().numpy()
print(time.time() - start_t)
# 1.1049120426177979start_t = time.time()
for i in range(10000):y_ = y.detach().cpu().numpy()
print(time.time() - start_t)
# 1.115112543106079
# 時間差別不是很大,當然,這個速度差別可能和電腦配置
# (比如 GPU 很貴,CPU 卻很爛)有關。

6. tensor.item()

我們在提取 loss 的純數值的時候,常常會用到 loss.item(),其返回值是一個 Python 數值 (python number)。不像從 tensor 轉到 numpy (需要考慮 tensor 是在 cpu,還是 gpu,需不需要求導),無論什么情況,都直接使用 item() 就完事了。如果需要從 gpu 轉到 cpu 的話,PyTorch 會自動幫你處理。

但注意 item() 只適用于 tensor 只包含一個元素的時候。因為大多數情況下我們的 loss 就只有一個元素,所以就經常會用到 loss.item()。如果想把含多個元素的 tensor 轉換成 Python list 的話,要使用 tensor.tolist()

x  = torch.randn(1, requires_grad=True, device='cuda')
print(x)
# tensor([-0.4717], device='cuda:0', requires_grad=True)y = x.item()
print(y, type(y))
# -0.4717346727848053 <class 'float'>x = torch.randn([2, 2])
y = x.tolist()
print(y)
# [[-1.3069953918457031, -0.2710231840610504], [-1.26217520236969, 0.5559719800949097]]

結語

以上內容就是我平時在寫代碼的時候,覺得需要注意的地方。文章中用了一些簡單的代碼作為例子,旨在幫助大家理解。文章內容不少,看到這里的大家都辛苦了, 感謝閱讀。

最后還是那句話,希望本文能對大家學習和理解 PyTorch 有所幫助。

參考

  1. PyTorch Docs: AUTOGRAD MECHANICS https://pytorch.org/docs/stable/notes/autograd.html
  2. PyTorch 0.4.0 release notes https://github.com/pytorch/pytorch/releases/tag/v0.4.0

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/532399.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/532399.shtml
英文地址,請注明出處:http://en.pswp.cn/news/532399.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

簡述springmvc過程_spring mvc的工作流程是什么?

展開全部SpringMVC工作流程描述向服務器發送HTTP請求&#xff0c;請求被前端控制器 DispatcherServlet 捕獲。DispatcherServlet 根據 -servlet.xml 中的配置對請62616964757a686964616fe59b9ee7ad9431333365646233求的URL進行解析&#xff0c;得到請求資源標識符(URI)。 然后根…

PyTorch 的 Autograd

PyTorch 的 Autograd 轉自&#xff1a;PyTorch 的 Autograd PyTorch 作為一個深度學習平臺&#xff0c;在深度學習任務中比 NumPy 這個科學計算庫強在哪里呢&#xff1f;我覺得一是 PyTorch 提供了自動求導機制&#xff0c;二是對 GPU 的支持。由此可見&#xff0c;自動求導 (a…

商場樓層導視牌圖片_百寶圖商場電子導視軟件中預約產品功能簡介

百寶圖商場電子導視軟件中預約產品功能簡介 管理端&#xff0c;可配合百寶圖商場電子導視軟件配套使用 1&#xff1a;數據展示&#xff1a;圖形展示總預約數/預約時間峰值/預約途徑/各途徑數量對比 2&#xff1a;數據統計&#xff1a;有效預約數量/無效預約數量/無效預約原因備…

Pytorch autograd.grad與autograd.backward詳解

Pytorch autograd.grad與autograd.backward詳解 引言 平時在寫 Pytorch 訓練腳本時&#xff0c;都是下面這種無腦按步驟走&#xff1a; outputs model(inputs) # 模型前向推理 optimizer.zero_grad() # 清除累積梯度 loss.backward() # 模型反向求導 optimizer.step()…

相對熵與交叉熵_熵、KL散度、交叉熵

公眾號關注 “ML_NLP”設為 “星標”&#xff0c;重磅干貨&#xff0c;第一時間送達&#xff01;機器學習算法與自然語言處理出品公眾號原創專欄作者 思婕的便攜席夢思單位 | 哈工大SCIR實驗室KL散度 交叉熵 - 熵1. 熵(Entropy)抽象解釋&#xff1a;熵用于計算一個隨機變量的信…

動手實現一個帶自動微分的深度學習框架

動手實現一個帶自動微分的深度學習框架 轉自&#xff1a;Automatic Differentiation Tutorial 參考代碼&#xff1a;https://github.com/borgwang/tinynn-autograd (主要看 core/tensor.py 和 core/ops.py) 目錄 簡介自動求導設計自動求導實現一個例子總結參考資料 簡介 梯度…

git安裝后找不見版本_結果發現git版本為1.7.4,(git --version)而官方提示必須是1.7.10及以后版本...

結果發現git版本為1.7.4,(git --version)而官方提示必須是1.7.10及以后版本升級增加ppasudo apt-add-repository ppa:git-core/ppasudo apt-get updatesudo apt-get install git如果本地已經安裝過Git&#xff0c;可以使用升級命令&#xff1a;sudo apt-get dist-upgradeapt命令…

隨機數生成算法:K進制逐位生成+拒絕采樣

隨機數生成算法&#xff1a;K進制逐位生成拒絕采樣 轉自&#xff1a;【宮水三葉】k 進制諸位生成 拒絕采樣 基本分析 給定一個隨機生成 1 ~ 7 的函數&#xff0c;要求實現等概率返回 1 ~ 10 的函數。 首先需要知道&#xff0c;在輸出域上進行定量整體偏移&#xff0c;仍然滿…

深入理解NLP Subword算法:BPE、WordPiece、ULM

深入理解NLP Subword算法&#xff1a;BPE、WordPiece、ULM 本文首發于微信公眾號【AI充電站】&#xff0c;感謝大家的贊同、收藏和轉發(▽) 轉自&#xff1a;深入理解NLP Subword算法&#xff1a;BPE、WordPiece、ULM 前言 Subword算法如今已經成為了一個重要的NLP模型性能提升…

http 錯誤 404.0 - not found_電腦Regsvr32 用法和錯誤消息的說明

? 對于那些可以自行注冊的對象鏈接和嵌入 (OLE) 控件&#xff0c;例如動態鏈接庫 (DLL) 文件或 ActiveX 控件 (OCX) 文件&#xff0c;您可以使用 Regsvr32 工具 (Regsvr32.exe) 來將它們注冊和取消注冊。Regsvr32.exe 的用法RegSvr32.exe 具有以下命令行選項&#xff1a; Regs…

mysql error 1449_MySql錯誤:ERROR 1449 (HY000)

筆者系統為 mac &#xff0c;不知怎的&#xff0c;Mysql 竟然報如下錯誤&#xff1a;ERROR 1449 (HY000): The user specified as a definer (mysql.infoschemalocalhost) does not exist一時沒有找到是什么操作導致的這個錯誤。然后經過查詢&#xff0c;參考文章解決了問題。登…

MobileNet 系列:從V1到V3

MobileNet 系列&#xff1a;從V1到V3 轉自&#xff1a;輕量級神經網絡“巡禮”&#xff08;二&#xff09;—— MobileNet&#xff0c;從V1到V3 自從2017年由谷歌公司提出&#xff0c;MobileNet可謂是輕量級網絡中的Inception&#xff0c;經歷了一代又一代的更新。成為了學習輕…

mysql 查詢表的key_mysql查詢表和字段的注釋

1,新建表以及添加表和字段的注釋.create table auth_user(ID INT(19) primary key auto_increment comment 主鍵,NAME VARCHAR(300) comment 姓名,CREATE_TIME date comment 創建時間)comment 用戶信息表;2,修改表/字段的注釋.alter table auth_user comment 修改后的表注…

mysql 高級知識點_這是我見過最全的《MySQL筆記》,涵蓋MySQL所有高級知識點!...

作為運維和編程人員&#xff0c;對MySQL一定不會陌生&#xff0c;尤其是互聯網行業&#xff0c;對MySQL的使用是比較多的。MySQL 作為主流的數據庫&#xff0c;是各大廠面試官百問不厭的知識點&#xff0c;但是需要了解到什么程度呢&#xff1f;僅僅停留在 建庫、創表、增刪查改…

teechart mysql_TeeChart 的應用

TeeChart 是一個很棒的繪圖控件&#xff0c;不過由于里面沒有注釋&#xff0c;網上相關的資料也很少&#xff0c;所以在應用的時候只能是一點點的試。為了防止以后用到的時候忘記&#xff0c;我就把自己用到的東西都記錄下來&#xff0c;以便以后使用的時候查詢。1、進制縮放圖…

NLP新寵——淺談Prompt的前世今生

NLP新寵——淺談Prompt的前世今生 轉自&#xff1a;NLP新寵——淺談Prompt的前世今生 作者&#xff1a;閔映乾&#xff0c;中國人民大學信息學院碩士&#xff0c;目前研究方向為自然語言處理。 《Pre-train, Prompt, and Predict: A Systematic Survey of Prompting Methods in…

mysql key_len_淺談mysql explain中key_len的計算方法

mysql的explain命令可以分析sql的性能&#xff0c;其中有一項是key_len(索引的長度)的統計。本文將分析mysql explain中key_len的計算方法。1、創建測試表及數據CREATE TABLE member (id int(10) unsigned NOT NULL AUTO_INCREMENT,name varchar(20) DEFAULT NULL,age tinyint(…

requestfacade 這個是什么類?_Java 的大 Class 到底是什么?

作者在之前工作中&#xff0c;面試過很多求職者&#xff0c;發現有很多面試者對Java的 Class 搞不明白&#xff0c;理解的不到位&#xff0c;一知半解&#xff0c;一到用的時候&#xff0c;就不太會用。想寫一篇關于Java Class 的文章&#xff0c;沒有那么多專業名詞&#xff0…

初學機器學習:直觀解讀KL散度的數學概念

初學機器學習&#xff1a;直觀解讀KL散度的數學概念 轉自&#xff1a;初學機器學習&#xff1a;直觀解讀KL散度的數學概念 譯自&#xff1a;https://towardsdatascience.com/light-on-math-machine-learning-intuitive-guide-to-understanding-kl-divergence-2b382ca2b2a8 解讀…

php mysql讀取數據查詢_PHP MySQL 讀取數據

PHP MySQL 讀取數據從 MySQL 數據庫讀取數據SELECT 語句用于從數據表中讀取數據:SELECT column_name(s) FROM table_name我們可以使用 * 號來讀取所有數據表中的字段&#xff1a;SELECT * FROM table_name如需學習更多關于 SQL 的知識&#xff0c;請訪問我們的 SQL 教程。使用 …