本文記錄了自己在閱讀《動手學深度學習》時的一些思考,僅用來作為作者本人的學習筆記,不存在商業用途。
正如我們在9.5機器翻譯
中所討論的,機器翻譯是序列轉換模型的一個核心問題,其輸入和輸出都是長度可變的序列。為了處理這種類型的輸入和輸出,我們可以設計一個包含兩個主要組件的架構:第一個組件是一個編碼器(encoder):它接受一個長度可變的序列作為輸入,并將其轉換為具有固定形狀的編碼狀態。第二個組件是解碼器(decoder):它將固定形狀的編碼狀態映射到長度可變的序列。這被稱為編碼器-解碼器(encoder-decoder)架構,如:fig_encoder_decoder
所示。
🏷
fig_encoder_decoder
我們以英語到法語的機器翻譯為例:給定一個英文的輸入序列:“They”“are”“watching”“.”。首先,這種“編碼器-解碼器”架構將長度可變的輸入序列編碼成一個“狀態”,然后對該狀態進行解碼,一個詞元接著一個詞元地生成翻譯后的序列作為輸出:“Ils”“regordent”“.”。由于“編碼器-解碼器”架構是形成后續章節中不同序列轉換模型的基礎,因此本節將把這個架構轉換為接口方便后面的代碼實現。
9.6.1 編碼器
在編碼器接口中,我們只指定長度可變的序列作為編碼器的輸入X
。任何繼承這個Encoder
基類的模型將完成代碼實現。
from torch import nn#@save
class Encoder(nn.Module):"""編碼器-解碼器架構的基本編碼器接口"""# 利用接受任意關鍵字參數**kwargs(參數都是鍵值對)初始化實例def __init__(self, **kwargs):# 調用父類nn.Module的構造函數, 確保Pytorch能正確初始化模塊super(Encoder, self).__init__(**kwargs)# 前向傳播邏輯需要輸入數據X和位置參數*argsdef forward(self, X, *args):# 強制子類必須重寫此方法, 否則調用時會報錯raise NotImplementedError
🏷python中的*args和**kargs的用法
https://blog.csdn.net/GODSuner/article/details/117961990`
🏷python中的raise NotImplementedError
https://blog.csdn.net/qq_40666620/article/details/105026716`
9.6.2 解碼器
在解碼器接口中,我們新增一個init_state函數,用于將編碼器的輸出(enc_outputs)轉換為編碼后的狀態。注意,此步驟可能需要額外的輸入,例如:輸入序列的有效長度,這在 9.5.4節中進行了解釋。為了逐個地生成長度可變的詞元序列,解碼器在每個時間步都會將輸入(例如:在前一時間步生成的詞元)和編碼后的狀態 映射成當前時間步的輸出詞元。
#@save
class Decoder(nn.Module):"""編碼器-解碼器架構的基本解碼器接口"""# 利用接受任意關鍵字參數**kwargs(參數都是鍵值對)初始化實例def __init__(self, **kwargs):# 調用父類nn.Module的構造函數, 確保Pytorch能正確初始化模塊super(Decoder, self).__init__(**kwargs)# 接受編碼器的出書enc_outputs和位置參數*args初始化解碼器的狀態def init_state(self, enc_outputs, *args):# 強制子類必須重寫此方法, 否則調用時會報錯raise NotImplementedError# 前向傳播邏輯需要輸入數據X和編碼后的狀態statedef forward(self, X, state):# 強制子類必須重寫此方法, 否則調用時會報錯raise NotImplementedError
9.6.3 合并編碼器和解碼器
“編碼器-解碼器”架構包含了一個編碼器和一個解碼器, 并且還擁有可選的額外的參數。 在前向傳播中,編碼器的輸出用于生成編碼狀態, 這個狀態又被解碼器作為其輸入的一部分。
#@save
class EncoderDecoder(nn.Module):"""編碼器-解碼器架構的基類"""def __init__(self, encoder, decoder, **kwargs):super(EncoderDecoder, self).__init__(**kwargs)self.encoder = encoderself.decoder = decoder# 前向傳播需要編碼器輸入enc_X和解碼器輸入dec_X以及位置參數*args(比如語句長度)def forward(self, enc_X, dec_X, *args):# 根據編碼器輸入enc_X(語句中的詞元用詞表索引編碼)和語句長度*args得到編碼器輸出enc_outputsenc_outputs = self.encoder(enc_X, *args)# 根據編碼器輸出enc_outputs和語句長度*args得到編碼后的狀態dec_statedec_state = self.decoder.init_state(enc_outputs, *args)# 根據解碼器輸入dec_X和狀態dec_state得到解碼輸出return self.decoder(dec_X, dec_state)
“編碼器-解碼器”體系架構中的術語狀態 會啟發人們使用具有狀態的神經網絡來實現該架構。 在下一節中,我們將學習如何應用循環神經網絡, 來設計基于“編碼器-解碼器”架構的序列轉換模型。
9.6.4 小結
- “編碼器-解碼器”架構可以將長度可變的序列作為輸入和輸出,因此適用于機器翻譯等序列轉換問題。
- 編碼器將長度可變的序列作為輸入,并將其轉換為具有固定形狀的編碼狀態。
- 解碼器將具有固定形狀的編碼狀態映射為長度可變的序列。