本節實現一個簡單的 Seq2Seq(Sequence to Sequence)模型 的編碼器(Encoder)和解碼器(Decoder)部分。
?重點把握Seq2Seq 模型的整體工作流程
理解編碼器(Encoder)和解碼器(Decoder)代碼
本小節引入了nn.GRU API的調用,nn.GRU具體參數將在下一小節進行補充講解
1.?編碼器(Encoder
類定義
class Encoder(nn.Module):def __init__(self, vocab_size, embedding_dim, hidden_size):super().__init__()self.emb = nn.Embedding(vocab_size, embedding_dim)self.rnn = nn.GRU(embedding_dim, hidden_size, batch_first=True)
-
vocab_size
:輸入詞匯表的大小,即輸入序列中可能出現的不同單詞或標記的數量。 -
embedding_dim
:嵌入層的維度,即每個單詞或標記被映射到的向量空間的維度。 -
hidden_size
:GRU(門控循環單元)的隱藏狀態維度,決定了模型的內部狀態大小。
主要組件
-
嵌入層(
nn.Embedding
)-
嵌入層會將輸入序列形狀轉換為?
[batch_size, seq_len, embedding_dim]
的張量。 -
這種映射是通過學習嵌入矩陣實現的,每個單詞索引對應嵌入矩陣中的一行。
-
-
GRU(
nn.GRU
)-
embedding_dim
是 GRU 的輸入維度,hidden_size
是隱藏狀態的維度。 -
batch_first=True
表示輸入和輸出的張量的第一個維度是批量大小(batch_size
),而不是序列長度(seq_len
)。
-
前向傳播(forward
)
def forward(self, x):embs = self.emb(x) #batch * token * embedding_dimgru_out, hidden = self.rnn(embs) #batch * token * hidden_sizereturn gru_out, hidden
-
輸入
x
是一個形狀為[batch_size, seq_len]
的張量,表示一個批次的輸入序列。 -
embs
是嵌入層的輸出,形狀為[batch_size, seq_len, embedding_dim]
。 -
gru_out
是 GRU 的輸出,形狀為[batch_size, seq_len, hidden_size]
,表示每個時間步的隱藏狀態。 -
hidden
是 GRU 的最終隱藏狀態,形狀為[1, batch_size, hidden_size]
,用于傳遞給解碼器。
?
2.?解碼器(Decoder)
類定義
class Decoder(nn.Module):def __init__(self, vocab_size, embedding_dim, hidden_size):super().__init__()self.emb = nn.Embedding(vocab_size, embedding_dim)self.rnn = nn.GRU(embedding_dim, hidden_size, batch_first=True)
-
解碼器的結構與編碼器類似,但它的作用是將編碼器生成的上下文向量(
hidden
)解碼為目標序列。
主要組件
-
嵌入層(
nn.Embedding
)-
與編碼器類似,將目標序列的單詞索引映射到嵌入向量。
-
-
GRU(
nn.GRU
)-
與編碼器中的 GRU 類似,但其輸入是目標序列的嵌入向量,初始隱藏狀態是編碼器的最終隱藏狀態。
-
前向傳播(forward
)
def forward(self, x, hx):embs = self.emb(x)gru_out, hidden = self.rnn(embs, hx=hx) #batch * token * hidden_size# batch * token * hidden_size# 1 * token * hidden_sizereturn gru_out, hidden
-
輸入
x
是目標序列的單詞索引,形狀為[batch_size, seq_len]
。 -
hx
是編碼器的最終隱藏狀態,形狀為[1, batch_size, hidden_size]
,作為解碼器的初始隱藏狀態。 -
embs
是目標序列的嵌入向量,形狀為[batch_size, seq_len, embedding_dim]
。 -
gru_out
是解碼器 GRU 的輸出,形狀為[batch_size, seq_len, hidden_size]
。 -
hidden
是解碼器 GRU 的最終隱藏狀態,形狀為[1, batch_size, hidden_size]
。
3.?Seq2Seq 模型的整體工作流程?
-
編碼階段
-
輸入序列通過編碼器的嵌入層,將單詞索引映射為嵌入向量。
-
嵌入向量通過 GRU,生成每個時間步的隱藏狀態和最終的隱藏狀態(上下文向量)。
-
最終隱藏狀態(
hidden
)作為編碼器的輸出,傳遞給解碼器。
-
-
解碼階段
-
解碼器的初始隱藏狀態是編碼器的最終隱藏狀態。
-
解碼器逐個生成目標序列的單詞,每次生成一個單詞后,將該單詞的嵌入向量作為下一次輸入,同時更新隱藏狀態。
-
通過這種方式,解碼器逐步生成目標序列。
-