BART模型
BART(Bidirectional and Auto-Regressive Transformers)是由 Facebook AI Research(FAIR)在 2019 年提出的序列到序列(seq2seq)預訓練模型,論文發表于《BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension》。
它結合了 BERT 的雙向編碼器 和 GPT 的自回歸解碼器,專為文本生成任務(如摘要、翻譯、對話)設計,同時在理解任務(如分類、問答)上也表現優異。
BART 通過靈活的預訓練任務和統一的編解碼架構,成為生成與理解任務的通用基礎模型,尤其適合需要同時處理輸入理解和輸出生成的場景。
核心特點
架構:標準 Transformer 編解碼器
編碼器:雙向 Transformer(類似 BERT),理解上下文。
解碼器:自回歸 Transformer(類似 GPT),從左到右生成文本。
參數規模:從 BART-Base(140M)到 BART-Large(400M)。
預訓練任務:文本破壞與還原(Denoising Autoencoder) 通過多種噪聲破壞輸入文本,再讓模型還原原始文本,提升生成與理解能力:
Token Masking(類似 BERT):隨機遮蓋詞(如
[MASK]
)。Token Deletion:隨機刪除詞,需還原位置和內容。
Text Infilling:用單個
[MASK]
替換連續片段(如 SpanBERT),需生成缺失片段。Sentence Permutation:打亂句子順序,需重排。
Document Rotation:隨機選擇詞作為開頭,需還原原文起始點。
微調靈活性:可直接用于下游任務:
生成任務:摘要(CNN/DailyMail)、對話、翻譯(需多語言預訓練)。
理解任務:文本分類、問答(將輸入編碼,解碼為答案)。
推理示例代碼:
from transformers import BertTokenizer, BartForConditionalGeneration, Text2TextGenerationPipelineclass ChineseBart:def __init__(self):model_path = "/path/to/bart-base-chinese"self.load_model(model_path)def load_model(self, model_path):# 加載一個中文BART模型(假設已經有微調好的改寫模型權重)self.tokenizer = BertTokenizer.from_pretrained(model_path)self.model = BartForConditionalGeneration.from_pretrained(model_path)self.text2text_generator = Text2TextGenerationPipeline(self.model, self.tokenizer, device=0) def rewrite_text(self, text):# text = "機器學習模型在圖像識別領域取得了突破性的進展。"# 構造輸入(BART可以直接輸入文本)ret = self.text2text_generator(text, max_length=512, do_sample=False)if len(ret) > 0:rewritten_texts = []for obj in ret:ret_text = obj.get('generated_text').replace(" ", "")rewritten_texts.append(ret_text)rewritten_text = "\n\n".join(rewritten_texts)print("改寫結果:", rewritten_text)return rewritten_textreturn text
T5模型
T5(Text-to-Text Transfer Transformer)是 Google Research 在 2019 年提出的統一文本到文本框架,論文發表于《Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer》。它將所有 NLP 任務(翻譯、摘要、問答、分類等)統一為“文本輸入 → 文本輸出”的范式,通過大規模預訓練 + 微調實現通用能力。
核心特點
統一框架:所有任務都是 Text-to-Text
- 輸入和輸出均為純文本,無需任務特定架構。
- 任務前綴:通過在輸入前加提示詞區分任務,例如:
translate English to German: ...
summarize: ...
cola sentence: ...
(分類任務輸出acceptable
或unacceptable
)。
架構:標準 Encoder-Decoder Transformer
- 完全基于原始 Transformer(Vaswani et al., 2017),未做架構創新。
- 規模:從 T5-Small(60M)到 T5-11B(110億參數,最大版本)。
預訓練任務:Span Corruption(改進的 MLM)
- 類似 BERT 的掩碼語言模型(MLM),但連續片段(span)被掩碼(平均長度3),需解碼器還原。
- 預訓練數據:C4(Colossal Clean Crawled Corpus),750GB cleaned English text。
微調靈活性
- 單任務微調:針對特定任務(如翻譯)微調。
- 多任務微調:混合多個任務前綴聯合訓練(如翻譯+摘要+QA)。
- 零樣本/少樣本:通過任務前綴泛化到新任務(如未微調的數學題)。
推理示例代碼:
from transformers import T5Tokenizer, T5ForConditionalGenerationclass ChineseT5:def __init__(self):print("ChineseT5")model_path = "/path/to/flan-t5-base"self.load_model(model_path)def load_model(self, model_name):# 加載一個中文T5模型(假設已經有微調好的改寫模型權重)self.tokenizer = T5Tokenizer.from_pretrained(model_name, legacy=False)self.model = T5ForConditionalGeneration.from_pretrained(model_name)def rewrite_text(self, input_text):# 構造輸入(添加適當的前綴)input_text = "rewrite: " + input_textinput_ids = self.tokenizer(input_text, return_tensors="pt").input_idsoutputs = self.model.generate(input_ids)if len(outputs) > 0:rewritten_text = self.tokenizer.decode(outputs[0])print("改寫結果:", rewritten_text)return rewritten_textreturn input_text