Seq2Seq模型概述
Seq2Seq(Sequence-to-Sequence)是一種基于深度學習的序列生成模型,主要用于處理輸入和輸出均為序列的任務,如機器翻譯、文本摘要、對話生成等。其核心思想是將可變長度的輸入序列映射為另一個可變長度的輸出序列。
核心結構
Seq2Seq模型通常由兩部分組成:編碼器(Encoder)和解碼器(Decoder)。編碼器將輸入序列壓縮為一個固定長度的上下文向量(Context Vector),解碼器根據該向量逐步生成輸出序列。
- 編碼器:通常是一個循環神經網絡(RNN),如LSTM或GRU,逐時間步處理輸入序列,最終隱藏狀態作為上下文向量。
- 解碼器:另一個RNN,以編碼器的上下文向量為初始狀態,逐步生成輸出序列的每個元素。
注意力機制
傳統Seq2Seq的瓶頸在于上下文向量的固定長度限制了模型處理長序列的能力。注意力機制(Attention)通過動態分配權重解決這一問題:
- 解碼器在每一步生成時,會關注編碼器所有時間步的隱藏狀態,而非僅依賴單一上下文向量。
- 注意力權重計算通常采用點積、加性或乘性方式,例如:
其中,為編碼器隱藏狀態,
為解碼器隱藏狀態,
、
、
為可學習參數。
典型應用場景
- 機器翻譯:輸入源語言句子,輸出目標語言句子。
- 文本摘要:輸入長文本,輸出概括性短文本。
- 語音識別:輸入音頻特征序列,輸出文本序列。
- 對話系統:輸入用戶語句,生成系統回復。
改進與變體
- Transformer:完全基于自注意力機制的架構,摒棄RNN結構,提升并行計算能力。
- 指針網絡(Pointer Networks):解決輸出詞匯來自輸入序列的任務,如抽取式摘要。
- 雙向編碼器:結合正向和反向RNN,增強上下文理解能力。
代碼示例(PyTorch實現片段)
import torch
import torch.nn as nnclass Seq2Seq(nn.Module):def __init__(self, input_dim, hidden_dim, output_dim):super().__init__()self.encoder = nn.LSTM(input_dim, hidden_dim)self.decoder = nn.LSTM(hidden_dim, hidden_dim)self.fc = nn.Linear(hidden_dim, output_dim)def forward(self, src, trg):# 編碼器處理輸入_, (hidden, cell) = self.encoder(src)# 解碼器逐步生成outputs = []for t in range(trg.shape[0]):out, (hidden, cell) = self.decoder(trg[t].unsqueeze(0), (hidden, cell))outputs.append(self.fc(out.squeeze(0)))return torch.stack(outputs)
挑戰與局限性
- 長序列依賴:盡管注意力機制有所改善,超長序列仍可能導致性能下降。
- 曝光偏差(Exposure Bias):訓練時使用真實標簽,推理時依賴模型自身預測,累積誤差可能放大。
- 計算效率:RNN的串行特性限制了訓練速度,部分場景需改用Transformer等架構。