針對Helsinki-NLP/opus-mt-zh-en模型進行雙向互翻的微調

引言
?題目聽起來有點怪怪的,但是實際上就是對Helsinki-NLP/opus-mt-en-es模型進行微調。但是這個模型是單向的,只支持中到英的翻譯,反之則不行。這樣的話,如果要做中英雙向互翻就需要兩個模型,那模型體積直接大了兩倍。尤其是部署到手機上,模型的體積是一個非常重要的考慮因素。于是自己就對這個模型的微調過程進行了一些改動,實現了單個模型進行雙向互翻的能力。

原生模型
?這里給出原生模型的使用方法:

from transformers import AutoModel , AutoTokenizer,MarianMTModeltext ="你好,你是誰?"
name ='Helsinki-NLP/opus-mt-zh-en'
tokenizer = AutoTokenizer.from_pretrained(name)
model = MarianMTModel.from_pretrained(name)
input_ids = tokenizer.encode(text, return_tensors="pt")
outputs = model.generate(input_ids)
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(decoded)

需要改動的地方
?因為涉及到互翻,所以首先要告訴模型翻譯的方向,具體就是在文本數據之前加一個目標語言的標識符,比如中翻英,原文“你好,你是誰?”,處理后就是“>>eng<< 你好,你是誰?”,英翻中則是“>>zho<< Hello,who are you?”

?因此就引出了一個問題,詞表vocab.json中并沒有“>>eng<<”和“>>zho<<”,那么分詞就會出現問題。我嘗試過兩種方法來解決:

  • 首先是常規的解決辦法,我最開始直接將這兩個標識當做新的token加入詞表中,最后也能跑通。這里只描述思想,具體的實現在下面的代碼中會體現。
  • 然后就是我自己想的非常規方法,為啥自己又想了非常規的方法呢,那是因為我在訓練好模型之后,要將模型轉換為CT2的格式,但是這個轉換過程中因為添加了2個新token導致了報錯,搞了一圈也沒有解決,于是直接把詞表中兩個極其罕見的token給刪除了,用兩個語言標識替代,這樣既不會對翻譯產生大的影響,又能完成模型格式的轉換。當然,這是需要先改詞表后進行微調,順序不能反了。

解決辦法一
?通過下面的代碼微調之后,就能得到一個雙向的翻譯能力的模型了,使用的方法和原生模型一樣,直接加載就能推理了。

import torch
import evaluate
import zhconv
from datasets import load_dataset, Dataset
import sacrebleu
import os
from transformers import (AutoTokenizer, MarianMTModel,Seq2SeqTrainer, Seq2SeqTrainingArguments,DataCollatorForSeq2Seq
)# 加載 tokenizer,并添加語言標簽
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-zh-en")
special_tokens = [">>eng<<", ">>zho<<"]
tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})# 加載模型,并擴展嵌入層大小
model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-zh-en")
model.resize_token_embeddings(len(tokenizer))# 設置 token ID
model.config.eos_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id'''
加載 Tatoeba 數據集(中英句對)
這里我使用的是公開的數據集,可以通過下面的代碼來加載到本地。加載到本地之后就可以把data_files換成你自己的地址
import kagglehub
alvations_tatoeba_path = kagglehub.dataset_download('alvations/tatoeba')
'''
tatoeba_dataset = load_dataset("csv",data_files="./data/tatoeba-sentpairs.tsv",delimiter="\t",encoding="utf-8",split="train"
)# 過濾中英句對(zh→en 和 en→zh)
zh2en_dataset = tatoeba_dataset.filter(lambda x: x['SRC LANG'] == "cmn" and x['TRG LANG'] == 'eng')
en2zh_dataset = tatoeba_dataset.filter(lambda x: x['SRC LANG'] == "eng" and x['TRG LANG'] == 'cmn')# 預處理函數:添加語言標簽 + 分詞
def preprocess_zh2en(batch):# 添加目標語言 tokeninputs = [">>eng<< " + x for x in batch['SRC']]# 可選:進行繁轉簡inputs = [zhconv.convert(x, 'zh-cn') for x in inputs]targets = batch['TRG']# 編碼inputs_encoded = tokenizer(inputs, max_length=128, truncation=True, padding="max_length")outputs_encoded = tokenizer(targets, max_length=128, truncation=True, padding="max_length" )return {"input_ids": inputs_encoded["input_ids"],"attention_mask": inputs_encoded["attention_mask"],"decoder_input_ids": outputs_encoded["input_ids"],"decoder_attention_mask": outputs_encoded["attention_mask"],"labels": outputs_encoded["input_ids"].copy(),  # labels 通常跟 decoder_input_ids 相同(訓練時用于 loss)}def preprocess_en2zh(batch):# 添加目標語言 tokeninputs = [">>zho<< " + x for x in batch['SRC']]# 可選:進行繁轉簡targets = batch['TRG']targets = [zhconv.convert(x, 'zh-cn') for x in targets]# 編碼inputs_encoded = tokenizer(inputs, max_length=128, truncation=True, padding="max_length")outputs_encoded = tokenizer(targets, max_length=128, truncation=True, padding="max_length" )return {"input_ids": inputs_encoded["input_ids"],"attention_mask": inputs_encoded["attention_mask"],"decoder_input_ids": outputs_encoded["input_ids"],"decoder_attention_mask": outputs_encoded["attention_mask"],"labels": outputs_encoded["input_ids"].copy(),  # labels 通常跟 decoder_input_ids 相同(訓練時用于 loss)}# 數據清洗 + 映射分詞
zh2en_dataset = zh2en_dataset.map(preprocess_zh2en, batched=True)
en2zh_dataset = en2zh_dataset.map(preprocess_en2zh, batched=True)# 合并中→英和英→中雙向數據
combined_dataset = Dataset.from_dict({key: zh2en_dataset[key] + en2zh_dataset[key] for key in zh2en_dataset.features
})# 拆分訓練集和測試集
split_dataset = combined_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = split_dataset["train"]
eval_dataset = split_dataset["test"]def compute_metrics(pred):pred_ids = pred.predictionslabel_ids = pred.label_idspred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)label_ids[label_ids == -100] = tokenizer.pad_token_idlabel_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)bleu = sacrebleu.corpus_bleu(pred_str, [label_str])# 保存驗證的結果到本地文件,這樣可以實時查看微調的效果save_dir = "./eval_logs"os.makedirs(save_dir, exist_ok=True)eval_id = f"step_{trainer.state.global_step}" if hasattr(trainer, "state") else "eval"output_file = os.path.join(save_dir, f"pred_vs_ref_{eval_id}.txt")with open(output_file, "w", encoding="utf-8") as f:for i, (pred, ref) in enumerate(zip(pred_str, label_str)):f.write(f"Sample {i + 1}:\n")f.write(f"Prediction: {pred}\n")f.write(f"Reference : {ref}\n")f.write("=" * 50 + "\n")return {"bleu": bleu.score}# 訓練參數
training_args = Seq2SeqTrainingArguments(output_dir='./model/marian-zh-en-bidirectional',num_train_epochs=30,per_device_train_batch_size=16,per_device_eval_batch_size=16,logging_steps=50,save_steps=100,eval_steps=100,eval_strategy="steps",predict_with_generate=True,save_total_limit=10,report_to="tensorboard",  # 啟用 TensorBoard 日志記錄logging_dir='./logs',  # 指定 TensorBoard 日志的保存路徑
)# 構建 Trainer
trainer = Seq2SeqTrainer(model=model,args=training_args,train_dataset=train_dataset.with_format("torch"),eval_dataset=eval_dataset.with_format("torch"),tokenizer=tokenizer,data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),compute_metrics=compute_metrics
)# 開始訓練
trainer.train(resume_from_checkpoint=False)# 保存模型和 tokenizer
model.save_pretrained("./model/marian-zh-en-bidirectional")
tokenizer.save_pretrained("./model/marian-zh-en-bidirectional")

解決辦法二
?上面是針對大眾場景,具體的場景需要做具體的改動。本方法就是根據我的業務場景來修改的。

?方法一訓練得到的模型是使用tokenizer來編解碼,因為目標語言標識符已經加入到詞表里了,所以編解碼沒問題。但是我轉為CT2格式之后,分詞使用的是sentencepiece模型,具體就是用source.spm、target.spm分別對中文和英文進行分詞,然后根據共享詞表轉換為token的id。 共享詞表中是有語言標識符的,但是sentencepiece模型里卻沒有添加兩個新token,所以就無法識別,導致分詞錯誤。我的做法就是推理的時候先不加目標語言的標識符,先分詞,然后手動加上去。這樣分詞就不會出問題了,然后進行編碼就能根據共享詞表來編碼了。

?還有一個問題就是,輸入是中英混合的文本,這樣sentencepiece分詞器也無法正確識別,一個辦法就是將中英文分開,分別進行分詞,然后將分詞的結果按順序進行拼接。

?最后,以上都是基于不重新訓練分詞模型的做法,如果可以重新訓練分詞模型,那么就不需要搞上面哪些操作了。

import torch
import evaluate
import zhconv
from datasets import load_dataset, Dataset
import sacrebleu
import os
from transformers import (AutoTokenizer, MarianMTModel,Seq2SeqTrainer, Seq2SeqTrainingArguments,DataCollatorForSeq2Seq
)# 加載 tokenizer
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-zh-en")# 加載模型
model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-zh-en")# 設置 token ID
model.config.eos_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id'''
加載 Tatoeba 數據集(中英句對)
這里我使用的是公開的數據集,可以通過下面的代碼來加載到本地。加載到本地之后就可以把data_files換成你自己的地址
import kagglehub
alvations_tatoeba_path = kagglehub.dataset_download('alvations/tatoeba')
'''
tatoeba_dataset = load_dataset("csv",data_files="./data/tatoeba-sentpairs.tsv",delimiter="\t",encoding="utf-8",split="train"
)# 過濾中英句對(zh→en 和 en→zh)
zh2en_dataset = tatoeba_dataset.filter(lambda x: x['SRC LANG'] == "cmn" and x['TRG LANG'] == 'eng')
en2zh_dataset = tatoeba_dataset.filter(lambda x: x['SRC LANG'] == "eng" and x['TRG LANG'] == 'cmn')# 預處理函數:添加語言標簽 + 分詞
def preprocess_zh2en(batch):# 添加目標語言 tokeninputs = [">>eng<< " + x for x in batch['SRC']]# 可選:進行繁轉簡inputs = [zhconv.convert(x, 'zh-cn') for x in inputs]targets = batch['TRG']# 編碼inputs_encoded = tokenizer(inputs, max_length=128, truncation=True, padding="max_length")outputs_encoded = tokenizer(targets, max_length=128, truncation=True, padding="max_length" )return {"input_ids": inputs_encoded["input_ids"],"attention_mask": inputs_encoded["attention_mask"],"decoder_input_ids": outputs_encoded["input_ids"],"decoder_attention_mask": outputs_encoded["attention_mask"],"labels": outputs_encoded["input_ids"].copy(),  # labels 通常跟 decoder_input_ids 相同(訓練時用于 loss)}def preprocess_en2zh(batch):# 添加目標語言 tokeninputs = [">>zho<< " + x for x in batch['SRC']]# 可選:進行繁轉簡targets = batch['TRG']targets = [zhconv.convert(x, 'zh-cn') for x in targets]# 編碼inputs_encoded = tokenizer(inputs, max_length=128, truncation=True, padding="max_length")outputs_encoded = tokenizer(targets, max_length=128, truncation=True, padding="max_length" )return {"input_ids": inputs_encoded["input_ids"],"attention_mask": inputs_encoded["attention_mask"],"decoder_input_ids": outputs_encoded["input_ids"],"decoder_attention_mask": outputs_encoded["attention_mask"],"labels": outputs_encoded["input_ids"].copy(),  # labels 通常跟 decoder_input_ids 相同(訓練時用于 loss)}# 數據清洗 + 映射分詞
zh2en_dataset = zh2en_dataset.map(preprocess_zh2en, batched=True)
en2zh_dataset = en2zh_dataset.map(preprocess_en2zh, batched=True)# 合并中→英和英→中雙向數據
combined_dataset = Dataset.from_dict({key: zh2en_dataset[key] + en2zh_dataset[key] for key in zh2en_dataset.features
})# 拆分訓練集和測試集
split_dataset = combined_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = split_dataset["train"]
eval_dataset = split_dataset["test"]def compute_metrics(pred):pred_ids = pred.predictionslabel_ids = pred.label_idspred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)label_ids[label_ids == -100] = tokenizer.pad_token_idlabel_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)bleu = sacrebleu.corpus_bleu(pred_str, [label_str])# 保存驗證的結果到本地文件,這樣可以實時查看微調的效果save_dir = "./eval_logs"os.makedirs(save_dir, exist_ok=True)eval_id = f"step_{trainer.state.global_step}" if hasattr(trainer, "state") else "eval"output_file = os.path.join(save_dir, f"pred_vs_ref_{eval_id}.txt")with open(output_file, "w", encoding="utf-8") as f:for i, (pred, ref) in enumerate(zip(pred_str, label_str)):f.write(f"Sample {i + 1}:\n")f.write(f"Prediction: {pred}\n")f.write(f"Reference : {ref}\n")f.write("=" * 50 + "\n")return {"bleu": bleu.score}# 訓練參數
training_args = Seq2SeqTrainingArguments(output_dir='./model/marian-zh-en-bidirectional',num_train_epochs=30,per_device_train_batch_size=16,per_device_eval_batch_size=16,logging_steps=50,save_steps=100,eval_steps=100,eval_strategy="steps",predict_with_generate=True,save_total_limit=10,report_to="tensorboard",  # 啟用 TensorBoard 日志記錄logging_dir='./logs',  # 指定 TensorBoard 日志的保存路徑
)# 構建 Trainer
trainer = Seq2SeqTrainer(model=model,args=training_args,train_dataset=train_dataset.with_format("torch"),eval_dataset=eval_dataset.with_format("torch"),tokenizer=tokenizer,data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),compute_metrics=compute_metrics
)# 開始訓練
trainer.train(resume_from_checkpoint=False)# 保存模型和 tokenizer
model.save_pretrained("./model/marian-zh-en-bidirectional")
tokenizer.save_pretrained("./model/marian-zh-en-bidirectional")

基于訓練好的模型我還搞了一套使用C++來推理的代碼,方面在更多的平臺使用,具體可以在github上搜"xinliu9451/Opus-Mt_Bidirectional_Translation"。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/82901.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/82901.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/82901.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Object轉Map集合

對象與 Map 轉換詳解&#xff1a; Object.entries() 和 Object.fromEntries() 1&#xff0c;Object.fromEntries() 的主要用途就是將鍵值對集合&#xff08;如 Map&#xff09;轉換為普通對象。 2&#xff0c;Object.entries() 返回一個二維數組&#xff0c;其中每個子數組包…

優先隊列用法

第 5 行定義了一個隊首是最大值的優先隊列,第 10 行的輸出如下: 27 - wuhan 21 - shanghai 11 - beijing 第 13 行定義了一個隊首是最小值的優先隊列,第 19 行的輸出如下: 11 - beijing 21 - shanghai 27 - wuhan #include <bits/stdc.h> using namespace std; int…

Spring Boot3.4.1 集成redis

Spring Boot3.4.1 集成redis 第一步 引入依賴 <!-- redis 緩存操作 --> <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-redis</artifactId> </dependency> <!-- pool 對象池 …

Replacing iptables with eBPF in Kubernetes with Cilium

source: https://archive.fosdem.org/2020/schedule/event/replacing_iptables_with_ebpf/attachments/slides/3622/export/events/attachments/replacing_iptables_with_ebpf/slides/3622/Cilium_FOSDEM_2020.pdf 使用Cilium&#xff0c;結合eBPF、Envoy、Istio和Hubble等技術…

英一真題閱讀單詞筆記 05年

2005 年 Text 1 第一段 序號 單詞 音標 詞義 1 fat [ft] a. 豐厚的&#xff0c;巨額的&#xff1b;肥胖的 2 pay [pe?] n. 薪水 3 rise [ra?z] n. 上漲&#xff0c;增加&#xff1b;斜坡 4 pleasure [ple??(r)] n. 快樂&#xff1b;樂事 5 pleasure a…

FastAPI集成APsecheduler的BackgroundScheduler+mongodb(精簡)

項目架構&#xff1a; FastAPI(folder) >app(folder) >core(folder) >models(folder) >routers(folder) >utils(folder) main.py(file) 1 utils文件夾下新建schedulers.py from apscheduler.schedulers.background import BackgroundScheduler from apschedu…

聊一聊接口測試中耗時請求如何合理安排?

目錄 一、異步處理與輪詢機制 輪詢檢查機制 二、 并行化測試執行 三、模擬與樁技術&#xff08;Mock/Stub&#xff09; 四、動態超時與重試策略 五、測試架構設計優化 分層測試策略 并行化執行 網絡優化 六、測試用例分層管理 金字塔策略 七、 緩存與數據復用 響應…

深入詳解DICOMweb:WADO與STOW-RS的技術解析與實現

&#x1f9d1; 博主簡介&#xff1a;CSDN博客專家、CSDN平臺優質創作者&#xff0c;高級開發工程師&#xff0c;數學專業&#xff0c;10年以上C/C, C#, Java等多種編程語言開發經驗&#xff0c;擁有高級工程師證書&#xff1b;擅長C/C、C#等開發語言&#xff0c;熟悉Java常用開…

Splunk Validated Architecture (SVA):構建企業級可觀測性與安全的基石

Splunk Validated Architecture (SVA) 是 Splunk 官方提供的一套經過嚴格測試、性能驗證和最佳實踐指導的參考架構藍圖。它并非單一固定方案&#xff0c;而是根據企業數據規模、性能需求、高可用性目標和合規要求&#xff0c;提供一系列可落地的部署模型。SVA 的核心價值在于為…

Armv7l或樹莓派32位RPI 4B編譯faiss

pip3 install faiss-cpu當然找不到預編譯的包 手動下載 git clone https://github.com/facebookresearch/faiss.git cd faiss #能需要切換到特定版本標簽&#xff0c;例如 v1.7.1&#xff0c;這個版本Cmake 3.18可以過&#xff0c;因為apt install安裝的cmake只更新到這里&am…

C++之string的模擬實現

string 手寫C字符串類類的基本結構與成員變量一、構造函數與析構函數二、賦值運算符重載三、迭代器支持四、內存管理與擴容機制五、字符串操作函數六、運算符重載總結 手寫C字符串類 從零實現一個簡易版std::string 類的基本結構與成員變量 namespace zzh { class string { …

修改Docker鏡像源

配置文件位置&#xff1a; sudo vim /etc/docker/daemon.json Docker 或 containerd 的鏡像加速器配置&#xff0c;旨在提高從 Docker Hub 拉取鏡像的速度。 { "features": { "buildkit": true, "containerd-snapshotter": true }, …

服務器帶寬線路的區別(GIA、CN2、BGP、CMI等)

服務器帶寬線路的區別&#xff08;GIA、CN2、BGP、CMI等&#xff09; 一、BGP線路 1. 定義與技術特點 BGP&#xff08;Border Gateway Protocol&#xff0c;邊界網關協議&#xff09;是一種用于不同自治系統&#xff08;AS&#xff09;之間交換路由信息的協議&#xff0c;屬…

從0到1搭建AI繪畫模型:Stable Diffusion微調全流程避坑指南

從0到1搭建AI繪畫模型&#xff1a;Stable Diffusion微調全流程避坑指南 系統化學習人工智能網站&#xff08;收藏&#xff09;&#xff1a;https://www.captainbed.cn/flu 文章目錄 從0到1搭建AI繪畫模型&#xff1a;Stable Diffusion微調全流程避坑指南摘要引言一、數據集構…

VSCode + GD32F407 構建燒錄

前言 最近調試一塊 GD32F407VET6&#xff08;168Mhz&#xff0c;8Mhz晶振&#xff09; 板子時&#xff0c;踩了一些“啟動失敗”的坑。本以為是時鐘配置有誤&#xff0c;最后發現是鏈接腳本&#xff08;.ld 文件&#xff09;沒有配置好&#xff0c;導致程序根本沒能正常執行 ma…

AI繪畫提示詞:從零開始掌握Prompt Engineering的藝術

文章目錄 什么是AI繪畫提示詞&#xff1f;提示詞的基本結構主體描述場景/背景風格指定技術參數負面提示人物肖像模板風景模板 高級技巧權重調整混合風格顏色控制情緒氛圍 常見問題與解決方法手部變形問題構圖不理想風格不夠突出 提示詞示例庫科幻場景奇幻人物靜物畫 結語 在當今…

在 Linux 上安裝 Minikube:輕松搭建本地 Kubernetes 單節點集群

&#x1f525;「炎碼工坊」技術彈藥已裝填&#xff01; 點擊關注 → 解鎖工業級干貨【工具實測|項目避坑|源碼燃燒指南】 一、Minikube 是什么&#xff1f; Minikube 是 Kubernetes 官方推出的輕量級工具&#xff0c;專為開發者設計&#xff0c;用于在本地快速搭建單節點 Kube…

day41 python圖像識別任務

目錄 一、數據預處理&#xff1a;為模型打下堅實基礎 二、模型構建&#xff1a;多層感知機的實現 三、訓練過程&#xff1a;迭代優化與性能評估 四、測試結果&#xff1a;模型性能的最終檢驗 五、總結與展望 在深度學習的旅程中&#xff0c;多層感知機&#xff08;MLP&…

JS數組 concat() 與擴展運算符的深度解析與最佳實踐

文章目錄 前言一、語法對比1. Array.prototype.concat()2. 擴展運算符&#xff08;解構賦值&#xff09; 二、性能差異&#xff08;大規模數組&#xff09;關鍵差異原因 三、適用場景建議總結 前言 最近工作中遇到了一個大規模數組合并相關的問題&#xff0c;在數據合并時有些…

一套qt c++的串口通信

實現了創建線程使用串口的功能 具備功能: 1.線程使用串口 2.定時發送隊列內容&#xff0c;防止粘包 3.沒處理接收粘包&#xff0c;根據你的需求來&#xff0c;handleReadyRead函數中&#xff0c;可以通過m_receiveBuffer來緩存接收&#xff0c;然后拆分數據來處理 源碼 seri…