lstm需要優化的參數_使用PyTorch手寫代碼從頭構建LSTM,更深入的理解其工作原理...

這是一個造輪子的過程,但是從頭構建LSTM能夠使我們對體系結構進行更加了解,并將我們的研究帶入下一個層次。

LSTM單元是遞歸神經網絡深度學習研究領域中最有趣的結構之一:它不僅使模型能夠從長序列中學習,而且還為長、短期記憶創建了一個數值抽象,可以在需要時相互替換。

751f844ccf62f3619125d59300241223.png

在這篇文章中,我們不僅將介紹LSTM單元的體系結構,還將通過PyTorch手工實現它。

最后但最不重要的是,我們將展示如何對我們的實現做一些小的調整,以實現一些新的想法,這些想法確實出現在LSTM研究領域,如peephole。

LSTM體系結構

LSTM被稱為門結構:一些數學運算的組合,這些運算使信息流動或從計算圖的那里保留下來。因此,它能夠“決定”其長期和短期記憶,并輸出對序列數據的可靠預測:

1b6a8415463c8ddebb44ba41672bdee1.png

LSTM單元中的預測序列。注意,它不僅會傳遞預測值,而且還會傳遞一個c,c是長期記憶的代表

遺忘門

遺忘門(forget gate)是輸入信息與候選者一起操作的門,作為長期記憶。請注意,在輸入、隱藏狀態和偏差的第一個線性組合上,應用一個sigmoid函數:

344929c78260868950e1737331a9dde8.png

sigmoid將遺忘門的輸出“縮放”到0-1之間,然后,通過將其與候選者相乘,我們可以將其設置為0,表示長期記憶中的“遺忘”,或者將其設置為更大的數字,表示我們從長期記憶中記住的“多少”。

新型長時記憶的輸入門及其解決方案

輸入門是將包含在輸入和隱藏狀態中的信息組合起來,然后與候選和部分候選c''u t一起操作的地方:

d0dfb1555009820c429fca38a56a8402.png

在這些操作中,決定了多少新信息將被引入到內存中,如何改變——這就是為什么我們使用tanh函數(從-1到1)。我們將短期記憶和長期記憶中的部分候選組合起來,并將其設置為候選。

單元的輸出門和隱藏狀態(輸出)

之后,我們可以收集ot作為LSTM單元的輸出門,然后將其乘以候選單元(長期存儲器)的tanh,后者已經用正確的操作進行了更新。網絡輸出為ht。

2eda49c81569ac2a102ddce6a557e602.png

LSTM單元方程

95673b00cbe1aeebef348057b550a9c7.png

在PyTorch上實現

import math
import torch
import torch.nn as nn

我們現在將通過繼承nn.Module,然后還將引用其參數和權重初始化,如下所示(請注意,其形狀由網絡的輸入大小和輸出大小決定):

class NaiveCustomLSTM(nn.Module):
def __init__(self, input_sz: int, hidden_sz: int):
super().__init__()
self.input_size = input_sz
self.hidden_size = hidden_sz
#i_t
self.U_i = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
self.V_i = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
self.b_i = nn.Parameter(torch.Tensor(hidden_sz))
#f_t
self.U_f = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
self.V_f = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
self.b_f = nn.Parameter(torch.Tensor(hidden_sz))
#c_t
self.U_c = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
self.V_c = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
self.b_c = nn.Parameter(torch.Tensor(hidden_sz))
#o_t
self.U_o = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
self.V_o = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
self.b_o = nn.Parameter(torch.Tensor(hidden_sz))
self.init_weights()

要了解每個操作的形狀,請看:

矩陣的輸入形狀是(批量大小、序列長度、特征長度),因此將序列的每個元素相乘的權重矩陣必須具有該形狀(特征長度、輸出長度)。

序列上每個元素的隱藏狀態(也稱為輸出)都具有形狀(批大小、輸出大小),這將在序列處理結束時產生輸出形狀(批大小、序列長度、輸出大小)。-因此,將其相乘的權重矩陣必須具有與單元格的參數hiddensz相對應的形狀(outputsize,output_size)。

這里是權重初始化,我們將其用作PyTorch默認值中的權重初始化nn.Module:

def init_weights(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
weight.data.uniform_(-stdv, stdv)

前饋操作

前饋操作接收initstates參數,該參數是上面方程的(ht,ct)參數的元組,如果不引入,則設置為零。然后,我們對每個保留(ht,c_t)的序列元素執行LSTM方程的前饋,并將其作為序列下一個元素的狀態引入。

最后,我們返回預測和最后一個狀態元組。讓我們看看它是如何發生的:

def forward(self,x,init_states=None):
"""
assumes x.shape represents (batch_size, sequence_size, input_size)
"""
bs, seq_sz, _ = x.size()
hidden_seq = []
if init_states is None:
h_t, c_t = (
torch.zeros(bs, self.hidden_size).to(x.device),
torch.zeros(bs, self.hidden_size).to(x.device),
)
else:
h_t, c_t = init_states
for t in range(seq_sz):
x_t = x[:, t, :]
i_t = torch.sigmoid(x_t @ self.U_i + h_t @ self.V_i + self.b_i)
f_t = torch.sigmoid(x_t @ self.U_f + h_t @ self.V_f + self.b_f)
g_t = torch.tanh(x_t @ self.U_c + h_t @ self.V_c + self.b_c)
o_t = torch.sigmoid(x_t @ self.U_o + h_t @ self.V_o + self.b_o)
c_t = f_t * c_t + i_t * g_t
h_t = o_t * torch.tanh(c_t)
hidden_seq.append(h_t.unsqueeze(0))
#reshape hidden_seq p/ retornar
hidden_seq = torch.cat(hidden_seq, dim=0)
hidden_seq = hidden_seq.transpose(0, 1).contiguous()
return hidden_seq, (h_t, c_t)

優化版本

這個LSTM在運算上是正確的,但在計算時間上沒有進行優化:我們分別執行8個矩陣乘法,這比矢量化的方式慢得多。我們現在將演示如何通過將其減少到2個矩陣乘法來完成,這將使它更快。

為此,我們設置了兩個矩陣U和V,它們的權重包含在4個矩陣乘法上。然后,我們對已經通過線性組合+偏置操作的矩陣執行選通操作。

通過矢量化操作,LSTM單元的方程式為:

afc365239768ca48a6fb365d64482711.png

class CustomLSTM(nn.Module):
def __init__(self, input_sz, hidden_sz):
super().__init__()
self.input_sz = input_sz
self.hidden_size = hidden_sz
self.W = nn.Parameter(torch.Tensor(input_sz, hidden_sz * 4))
self.U = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz * 4))
self.bias = nn.Parameter(torch.Tensor(hidden_sz * 4))
self.init_weights()
def init_weights(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
weight.data.uniform_(-stdv, stdv)
def forward(self, x,
init_states=None):
"""Assumes x is of shape (batch, sequence, feature)"""
bs, seq_sz, _ = x.size()
hidden_seq = []
if init_states is None:
h_t, c_t = (torch.zeros(bs, self.hidden_size).to(x.device),
torch.zeros(bs, self.hidden_size).to(x.device))
else:
h_t, c_t = init_states
HS = self.hidden_size
for t in range(seq_sz):
x_t = x[:, t, :]
# batch the computations into a single matrix multiplication
gates = x_t @ self.W + h_t @ self.U + self.bias
i_t, f_t, g_t, o_t = (
torch.sigmoid(gates[:, :HS]), # input
torch.sigmoid(gates[:, HS:HS*2]), # forget
torch.tanh(gates[:, HS*2:HS*3]),
torch.sigmoid(gates[:, HS*3:]), # output
)
c_t = f_t * c_t + i_t * g_t
h_t = o_t * torch.tanh(c_t)
hidden_seq.append(h_t.unsqueeze(0))
hidden_seq = torch.cat(hidden_seq, dim=0)
# reshape from shape (sequence, batch, feature) to (batch, sequence, feature)
hidden_seq = hidden_seq.transpose(0, 1).contiguous()
return hidden_seq, (h_t, c_t)

最后但并非最不重要的是,我們可以展示如何優化,以使用LSTM peephole connections。

LSTM peephole

LSTM peephole對其前饋操作進行了細微調整,從而將其更改為優化的情況:

1f1bf717535b3a64fa898c7c1ba5c12b.png

如果LSTM實現得很好并經過優化,我們可以添加peephole選項,并對其進行一些小的調整:

class CustomLSTM(nn.Module):
def __init__(self, input_sz, hidden_sz, peephole=False):
super().__init__()
self.input_sz = input_sz
self.hidden_size = hidden_sz
self.peephole = peephole
self.W = nn.Parameter(torch.Tensor(input_sz, hidden_sz * 4))
self.U = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz * 4))
self.bias = nn.Parameter(torch.Tensor(hidden_sz * 4))
self.init_weights()
def init_weights(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
weight.data.uniform_(-stdv, stdv)
def forward(self, x,
init_states=None):
"""Assumes x is of shape (batch, sequence, feature)"""
bs, seq_sz, _ = x.size()
hidden_seq = []
if init_states is None:
h_t, c_t = (torch.zeros(bs, self.hidden_size).to(x.device),
torch.zeros(bs, self.hidden_size).to(x.device))
else:
h_t, c_t = init_states
HS = self.hidden_size
for t in range(seq_sz):
x_t = x[:, t, :]
# batch the computations into a single matrix multiplication
if self.peephole:
gates = x_t @ U + c_t @ V + bias
else:
gates = x_t @ U + h_t @ V + bias
g_t = torch.tanh(gates[:, HS*2:HS*3])
i_t, f_t, o_t = (
torch.sigmoid(gates[:, :HS]), # input
torch.sigmoid(gates[:, HS:HS*2]), # forget
torch.sigmoid(gates[:, HS*3:]), # output
)
if self.peephole:
c_t = f_t * c_t + i_t * torch.sigmoid(x_t @ U + bias)[:, HS*2:HS*3]
h_t = torch.tanh(o_t * c_t)
else:
c_t = f_t * c_t + i_t * g_t
h_t = o_t * torch.tanh(c_t)
hidden_seq.append(h_t.unsqueeze(0))
hidden_seq = torch.cat(hidden_seq, dim=0)
# reshape from shape (sequence, batch, feature) to (batch, sequence, feature)
hidden_seq = hidden_seq.transpose(0, 1).contiguous()
return hidden_seq, (h_t, c_t)

我們的LSTM就這樣結束了。如果有興趣大家可以將他與torch LSTM內置層進行比較。

代碼:https://github.com/piEsposito/pytorch-lstm-by-hand

作者:Piero Esposito

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/276897.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/276897.shtml
英文地址,請注明出處:http://en.pswp.cn/news/276897.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

有哪些漂亮的中國風 LOGO 設計?

提到中國風的logo,我覺得首先登場的應該是北京故宮博物院的logo,鐺!故宮博物院的logo,從顏色,到外形,到元素,無一例外,充滿了中國風的味道,可謂是中國風中的典型。同一風…

大家放松下,仿《大腕》經典對白

仿《大腕》經典對白: 一定要找那最流行的框架, 用功能最強大編輯器, 做就要做最復雜的系統, 輕量級的絕對不行, 框架最簡單也得是SPRING&…

MySQL-8.0.12源碼安裝實例

1、通過官網下載對應的版本后,通過FTP上傳至云服務器的/usr/local/src 目錄 2、解壓縮文件 [rootJSH-01 src]# ls mysql-boost-8.0.12.tar.gz [rootJSH-01 src]# tar zxvf mysql-boost-8.0.12.tar.gz [rootJSH-01 src]# ls mysql-8.0.12 mysql-boost-8.0.12.tar.gz…

python3常用模塊_Python3 常用模塊

一、time與datetime模塊 在Python中,通常有這幾種方式來表示時間: 時間戳(timestamp):通常來說,時間戳表示的是從1970年1月1日00:00:00開始按秒計算的偏移量。我們運行“type(time.time())”,返回的是float類型。 格式…

Windows下的HEAP溢出及其利用

Windows下的HEAP溢出及其利用 作者: isno 一、概述 前一段時間ASP的溢出鬧的沸沸揚揚,這個漏洞并不是普通的堆棧溢出,而是發生在HEAP中的溢出,這使大家重新認識到了Windows下的HEAP溢出的可利用性。其實WIN下的HEAP溢出比Linux和SOLARIS下面的…

地方政府不愿房價下跌 救市或化解房地產調控

地方政府不愿房價下跌 "救市"或化解房地產調控 2008年05月09日 07:29:38  來源:上海證券報 漫畫 劉道偉 由于房地產業與地方政府利益攸關,地方政府最不愿意看到房價下跌。中央房地產調控政策剛剛導致部分城市的房價步入調整,一些…

App移動端性能工具調研

使用GT的差異化場景平臺描述release版本development版本Android在Android平臺上,如果希望使用GT的高級功能,如“插樁”等,就必須將GT的SDK嵌入到被調測的應用的工程里,再配合安裝好的GT使用。支持AndroidiOS在iOS平臺上&#xff0…

UITabBar Contoller

。UITabBar中的UIViewController獲得控制權:在TabBar文件中添加:IBOutlet UITabBar *myTabBar; //在xib中連接tabBar;(void)tabBarController:(UITabBarController *)tabBarController didSelectViewController:      (UIViewControlle…

python3.5安裝pip_win10上python3.5.2第三方庫安裝(運用pip)

1 首先在python官網下載并安裝python。我這兒用的是python3.5.2,其自帶了pip。如果你選擇的版本沒有自帶pip,那么請查找其他的安裝教程。 2 python安裝好以后,我在其自帶的命令提示符窗口中輸入了pip,結果尷尬了,提示我…

C語言程序設計 練習題參考答案 第八章 文件(2)

/* 8.8從文件ex88_1.txt中取出成績,排序后,按降序存放EX88_2.TXT中 */ #include "stdio.h" #define N 10 struct student { int num; char name[20]; int score[3]; /*不能使用float*/ float average; }; void sort(struc…

語法上的小trick

語法上的小trick 構造函數 雖然不寫構造函數也是可以的,但是可能會開翻車,所以還是寫上吧。: 提供三種寫法: ? 使用的時候只用: 注意,這里的A[i]gg(3,3,3)的“gg”不能打括號,否則就是強制轉換…

Ubuntu18.04如何讓桌面軟件默認root權限運行?

什么是gksu? 什么是gksu:Linxu中的gksu是系統中的su/sudo工具,如果安裝了gksu,在終端中鍵入gksu會彈出一個對話框. 安裝gksu: 在Ubuntu之前的版本中是繼承gksu工具的,但是在Ubutu18.04中并沒有集成, 在Elementary OS中連gksu的APT源都沒有. Ubuntu18.04 安裝和使用gksu: seven…

win10診斷啟動后聯網_小技巧:win10網絡共享文件夾出現錯誤無法訪問如何解決?...

win10系統共享文件夾時在資源管理器中的網絡里能夠看到所共享的文件夾,但在打開文件夾時卻出現 Windows無法訪問 Desktop-r8ceh55新建文件夾 請檢查名稱的拼寫。否則,網絡可能有問題。要嘗試識別并解決網絡問題,請單擊“診斷”的錯誤提示&…

兩段關于統計日期的sql語句

統計月份:selectleft(convert(char(10),[Article_TimeDate],102),7) as月份, count(*) as數量from[hdsource].[dbo].[article]groupbyleft(convert(char(10),[Article_TimeDate],102),7)orderby1統計年份: selectleft(convert(char(10),[Article_TimeDat…

apache配置文件詳解與優化

apache配置文件詳解與優化 一、總結 一句話總結&#xff1a;結合apache配置文件中的英文說明和配置詳解一起看 1、apache模塊配置用的什么標簽&#xff1f; IfModule 例如&#xff1a; <IfModule dir_module>DirectoryIndex index.html 索引文件 首頁文件&#xff08;首頁…

帆軟報表(finereport)單元格函數,OP參數

單元格模型&#xff1a;單元格數據和引用&#xff1a;數據類型、實際值與顯示值、單元格支持的操作單元格樣式&#xff1a;行高列寬、隱藏行列、自動換行、上下標、文字豎排、大文本字段分頁時斷開、標識說明、格式刷單元格Web屬性&#xff1a;web顯示、web編輯風格、控件實際值…

sklearn 安裝_sklearn-classification_report

原型sklearn.metrics.classification_report(y_true, y_pred, labelsNone, target_namesNone, sample_weightNone, digits2)參數y_true&#xff1a;1維數組或標簽指示數組/離散矩陣&#xff0c;樣本實際類別值列表y_pred&#xff1a;1維數組或標簽指示數組/離散矩陣&#xff0c…

effective c++條款11擴展——關于拷貝構造函數和賦值運算符

effective c條款11擴展——關于拷貝構造函數和賦值運算符 作者&#xff1a;馮明德重點:包含動態分配成員的類 應提供拷貝構造函數,并重載""賦值操作符。 以下討論中將用到的例子: class CExample { public: CExample(){pBufferNULL; nSize0;} ~CExample(){delete pB…

SparkSQL 之 Shuffle Join 內核原理及應用深度剖析-Spark商業源碼實戰

本套技術專欄是作者&#xff08;秦凱新&#xff09;平時工作的總結和升華&#xff0c;通過從真實商業環境抽取案例進行總結和分享&#xff0c;并給出商業應用的調優建議和集群環境容量規劃等內容&#xff0c;請持續關注本套博客。版權聲明&#xff1a;禁止轉載&#xff0c;歡迎…

Python標準庫之csv(1)

1.Python處理csv文件之csv.writer() import csvdef csv_write(path,data):with open(path,w,encodingutf-8,newline) as f:writer csv.writer(f,dialectexcel)for row in data:writer.writerow(row)return True 調用上面的函數 data [[Name,Height],[Keys,176cm],[HongPing,1…