FastSpeech2中文語音合成就步解析:TTS數據訓練實戰篇

  1. 參考github網址:

GitHub - roedoejet/FastSpeech2: An implementation of Microsoft’s “FastSpeech 2: Fast and High-Quality End-to-End Text to Speech”

  1. 數據訓練所用python 命令:

python3 train.py -p config/AISHELL3/preprocess.yaml -m config/AISHELL3/model.yaml -t config/AISHELL3/train.yaml

  1. 數據訓練代碼解析

3.1 代碼架構overview:

通過 if __name__ == "__main__"運行整個py文件:

調用 “train.txt"和dataset.py加載數據,

調用utils文件夾下的model.py加載模型,聲碼器,

調用model文件夾下的loss.py中的FastSpeech2Loss class 設置損失函數,

用前面加載的模型和損失函數開始訓練模型,導出結果并記錄日志。

3.2 按訓練步驟分解代碼:

Step?0?: 定義可控訓練參數, 調動main函數

if __name__ == "__main__":#Define Argsparser = argparse.ArgumentParser()parser.add_argument("--restore_step", type=int, default=0)parser.add_argument("-p","--preprocess_config",type=str,required=True,help="path to preprocess.yaml",)parser.add_argument("-m", "--model_config", type=str, required=True, help="path to model.yaml")parser.add_argument("-t", "--train_config", type=str, required=True, help="path to train.yaml")args = parser.parse_args() #args為可控訓練參數# Read Configpreprocess_config = yaml.load(open(args.preprocess_config, "r"), Loader=yaml.FullLoader)model_config = yaml.load(open(args.model_config, "r"), Loader=yaml.FullLoader)train_config = yaml.load(open(args.train_config, "r"), Loader=yaml.FullLoader)configs = (preprocess_config, model_config, train_config)#Run _main_ functionmain(args, configs)

Step 1 : 啟動main函數,加載可控訓練參數

def main(args, configs): print("Prepare training ...")#加載可控訓練參數preprocess_config, model_config, train_config = configs

Step 2 : 從train.txt加載數據,并經由dataset.py和torch里的Dataloader處理

def main(args, configs):# Get datasetdataset = Dataset("train.txt", preprocess_config, train_config, sort=True, drop_last=True) #從 train.txt 中獲取datasetbatch_size = train_config["optimizer"]["batch_size"]group_size = 4  # Set this larger than 1 to enable sorting in Dataset,初始值為4assert batch_size * group_size < len(dataset)loader = DataLoader(dataset,batch_size=batch_size * group_size,shuffle=True,collate_fn=dataset.collate_fn,)

Step 3 : 定義模型,聲碼器,損失函數

def main(args, configs):# Prepare modelmodel, optimizer = get_model(args, configs, device, train=True) #設置優化器# 將模型并行訓練并移入計算設備中model = nn.DataParallel(model) # Model Has Been Defined# 計算模型參數量num_param = get_param_num(model) # Number of TTS Parameters: num_paramprint("Number of FastSpeech2 Parameters:", num_param)# 設置損失函數Loss = FastSpeech2Loss(preprocess_config, model_config).to(device)# 加載聲碼器vocoder = get_vocoder(model_config, device)

Step 4 : 加載日志,在"./output/log/AISHELL3"目錄建立train, val兩個文件夾來記錄日志

def main(args, configs):# Init loggerfor p in train_config["path"].values():os.makedirs(p, exist_ok=True)train_log_path = os.path.join(train_config["path"]["log_path"], "train")val_log_path = os.path.join(train_config["path"]["log_path"], "val")os.makedirs(train_log_path, exist_ok=True)os.makedirs(val_log_path, exist_ok=True)train_logger = SummaryWriter(train_log_path)val_logger = SummaryWriter(val_log_path)

Step 5 : 準備訓練,加載可控訓練參數

def main(args, configs):# Trainingstep = args.restore_step + 1epoch = 1grad_acc_step = train_config["optimizer"]["grad_acc_step"]grad_clip_thresh = train_config["optimizer"]["grad_clip_thresh"]total_step = train_config["step"]["total_step"]log_step = train_config["step"]["log_step"]save_step = train_config["step"]["save_step"]synth_step = train_config["step"]["synth_step"]val_step = train_config["step"]["val_step"]outer_bar = tqdm(total=total_step, desc="Training", position=0)outer_bar.n = args.restore_stepouter_bar.update()

Step 6 : 準備訓練,加載進度條,調動utils文件夾下tools.py中的to_device function來提取數據

    while True:inner_bar = tqdm(total=len(loader), desc="Epoch {}".format(epoch), position=1)for batchs in loader:for batch in batchs:batch = to_device(batch, device)

Step 7 :開始訓練,前向傳播,計算損失,反向傳播,梯度剪枝,更新模型權重參數

    #Load Datafor batch in batchs:batch = to_device(batch, device)# Forwardoutput = model(*(batch[2:]))# Cal Losslosses = Loss(batch, output)total_loss = losses[0]# Backwardtotal_loss = total_loss / grad_acc_steptotal_loss.backward()if step % grad_acc_step == 0:# Clipping gradients to avoid gradient explosionnn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh)# Update weightsoptimizer.step_and_update_lr()optimizer.zero_grad()

Step 8 : 當訓練步數到達預先設定的log_step時,調動utils文件夾下tool.py里的log function,記錄loss和step

                if step % log_step == 0:losses = [l.item() for l in losses]message1 = "Step {}/{}, ".format(step, total_step)message2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Pitch Loss: {:.4f}, Energy Loss: {:.4f}, Duration Loss: {:.4f}".format(*losses)with open(os.path.join(train_log_path, "log.txt"), "a") as f:f.write(message1 + message2 + "\n")outer_bar.write(message1 + message2)log(train_logger, step, losses=losses)

Step 9 : 當訓練步數到達預先設定的synth_step時,調動utils文件夾下tool.py里的log function 和?synth_one_sample function(具體用來干什么沒看懂)

                if step % synth_step == 0:fig, wav_reconstruction, wav_prediction, tag = synth_one_sample(batch,output,vocoder,model_config,preprocess_config,)log(train_logger,fig=fig,tag="Training/step_{}_{}".format(step, tag),)sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]log(train_logger,audio=wav_reconstruction,sampling_rate=sampling_rate,tag="Training/step_{}_{}_reconstructed".format(step, tag),)log(train_logger,audio=wav_prediction,sampling_rate=sampling_rate,tag="Training/step_{}_{}_synthesized".format(step, tag),)

Step 10 : 當訓練步數到達預先設定的val_step時,調動evaluate.py里的evaluate function來進行evaluation,并記錄在log/AISHELL3/val/log.txt

                if step % val_step == 0:model.eval()message = evaluate(model, step, configs, val_logger, vocoder)with open(os.path.join(val_log_path, "log.txt"), "a") as f:f.write(message + "\n")outer_bar.write(message)model.train()

Step 11 : 當訓練步數到達預先設定的save_step時,保存訓練模型

                if step % save_step == 0:torch.save({"model": model.module.state_dict(),"optimizer": optimizer._optimizer.state_dict(),},os.path.join(train_config["path"]["ckpt_path"],"{}.pth.tar".format(step),),)

Step 12 : 當訓練步數到達預先設定的total_step時,退出訓練

                if step == total_step:quit()step += 1outer_bar.update(1)inner_bar.update(1)epoch += 1
  1. 數據訓練代碼的輸出

在train_log_path和val_log_path輸出日志

在ckpt_path輸出訓練過程中按照save_step存儲的模型

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/41122.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/41122.shtml
英文地址,請注明出處:http://en.pswp.cn/web/41122.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

ida動態調試-cnblog

ida動態調試 傳遞啟動ida服務 android_server在ida\dbgsrv目錄中 adb push android_server /data/local/tmp/chmod 755 /data/local/tmp/android_server /data/local/tmp/android_serveradb forward tcp:23946 tcp:23946ida報錯:大多是手機端口被占用 報錯提示&#xff1a; …

java面試-java基礎(下)

文章目錄 一、和equals區別&#xff1f;二、hashcode方法作用&#xff1f;兩個對象的hashCode方法相同&#xff0c;則equals方法也一定為true嗎&#xff1f;三、為什么重寫equals方法就一定要重寫hashCode方法&#xff1f;四、Java中的參數傳遞時傳值呢還是傳引用&#xff1f;五…

期末上分站——計組(3)

復習題21-42 21、指令周期是指__C_。 A. CPU從主存取出一條指令的時間 B. CPU執行一條指令的時間 C. CPU從主存取出一條指令的時間加上執行這條指令的時間。 D. 時鐘周期時間 22、微型機系統中外設通過適配器與主板的系統總線相連接&#xff0c;其功能是__D_。 A. 數據緩沖和…

數據庫可視化管理工具dbeaver試用及問題處理。

本文記錄了在內網離線安裝數據庫可視化管理工具dbeaver的過程和相關問題處理方法。 一、下載dbeaver https://dbeaver.io/download/ 筆者測試時Windows平臺最新版本為&#xff1a;dbeaver-ce-24.1.1-x86_64-setup.exe 二、安裝方法 一路“下一步”即可 三、問題處理 1、問…

【深度學習】vscode 命令行下的debug

其實我一直知道vscode可以再命令行下進行debug。 比如 python aaa.py --bb1 --cc2 以前的做法是 去aaa.py 寫死bb和cc 然后直接debug。 直到今天我遇到這個&#xff1a; hydra hydra.main(version_baseNone, config_name/home/justin/Desktop/code/python_project/WASB-SBDT-m…

Truffle學習筆記

Truffle學習筆記 安裝truffle, 注意: 雖然目前truffle最新版是 5.0.0, 但是經過我實踐之后, 返現和v4有很多不同(比如: web3.eth.accounts; 都獲取不到賬戶), 還是那句話: “nodejs模塊的版本問題會搞死人的 !” 目前4.1.15之前的版本都不能用了, 只能安裝v4.1.15 npm instal…

新手學Cocos報錯 [Assets] Failed to open

兩個都在偏好設置里面調&#xff08;文件下面的偏好設置&#xff09;&#xff1a; 1.設置中文&#xff1f; 2.報錯 [Assets] Failed to open&#xff1f; 這樣在點擊打開ts文件的時候就不會報錯&#xff0c;并且用vscode編輯器打開了&#xff0c; 同樣也可以改成你們自己喜歡…

LabVIEW在圖像處理中的應用

abVIEW作為一種圖形化編程環境&#xff0c;不僅在數據采集和儀器控制領域表現出色&#xff0c;還在圖像處理方面具有強大的功能。借助其Vision Development Module&#xff0c;LabVIEW提供了豐富的圖像處理工具&#xff0c;廣泛應用于工業檢測、醫學影像、自動化控制等多個領域…

Apache Seata應用側啟動過程剖析——RM TM如何與TC建立連接

本文來自 Apache Seata官方文檔&#xff0c;歡迎訪問官網&#xff0c;查看更多深度文章。 本文來自 Apache Seata官方文檔&#xff0c;歡迎訪問官網&#xff0c;查看更多深度文章。 Apache Seata應用側啟動過程剖析——RM & TM如何與TC建立連接 前言 看過官網 README 的第…

Android最近任務顯示的圖片

Android最近任務顯示的圖片 1、TaskSnapshot截圖1.1 snapshotTask1.2 drawAppThemeSnapshot 2、導航欄顯示問題3、Recentan按鍵進入最近任務 1、TaskSnapshot截圖 frameworks/base/services/core/java/com/android/server/wm/TaskSnapshotController.java frameworks/base/cor…

IPython 性能評估工具的較量:%%timeit 與 %timeit 的差異解析

IPython 性能評估工具的較量&#xff1a;%%timeit 與 %timeit 的差異解析 在 IPython 的世界中&#xff0c;性能評估是一項至關重要的任務。%%timeit 和 %timeit 是兩個用于測量代碼執行時間的魔術命令&#xff0c;但它們之間存在一些關鍵的差異。本文將深入探討這兩個命令的不…

2786. 訪問數組中的位置使分數最大

2786. 訪問數組中的位置使分數最大 題目鏈接&#xff1a;2786. 訪問數組中的位置使分數最大 代碼如下&#xff1a; //參考鏈接:https://leetcode.cn/problems/visit-array-positions-to-maximize-score/solutions/2810335/dp-by-kkkk-16-tn9f class Solution { public:long …

vue-router 4匯總

一、vue和vue-router版本&#xff1a; "vue": "^3.4.29", "vue-router": "^4.4.0" 二、路由傳參&#xff1a; 方式一&#xff1a; 路由配置&#xff1a;/src/router/index.ts import {createRouter,createWebHistory } from &quo…

探索 WebKit 的緩存迷宮:深入理解其高效緩存機制

探索 WebKit 的緩存迷宮&#xff1a;深入理解其高效緩存機制 在當今快速變化的網絡世界中&#xff0c;WebKit 作為領先的瀏覽器引擎之一&#xff0c;其緩存機制對于提升網頁加載速度、減少服務器負載以及改善用戶體驗起著至關重要的作用。本文將深入探討 WebKit 的緩存機制&am…

代碼隨想錄leetcode200題之額外題目

目錄 1 介紹2 訓練3 參考 1 介紹 本博客用來記錄代碼隨想錄leetcode200題之額外題目相關題目。 2 訓練 題目1&#xff1a;1365. 有多少小于當前數字的數字 解題思路&#xff1a;二分查找。 C代碼如下&#xff0c; class Solution { public:vector<int> smallerNumb…

卷積神經網絡(CNN)和循環神經網絡(RNN) 的區別與聯系

卷積神經網絡&#xff08;CNN&#xff09;和循環神經網絡&#xff08;RNN&#xff09;是兩種廣泛應用于深度學習的神經網絡架構&#xff0c;它們在設計理念和應用領域上有顯著區別&#xff0c;但也存在一些聯系。 ### 卷積神經網絡&#xff08;CNN&#xff09; #### 主要特點…

解決C++編譯時的產生的skipping incompatible xxx 錯誤

問題 我在編譯項目時&#xff0c;產生了一個 /usr/bin/ld: skipping incompatible ../../xxx/ when searching for -lxxx 的編譯錯誤&#xff0c;如下圖所示&#xff1a; 解決方法 由圖中的錯誤可知&#xff0c;在編譯時&#xff0c;是能夠在我們指定目錄下的 *.so 動態庫的…

python函數和c的區別有哪些

Python有很多內置函數&#xff08;build in function&#xff09;&#xff0c;不需要寫頭文件&#xff0c;Python還有很多強大的模塊&#xff0c;需要時導入便可。C語言在這一點上遠不及Python&#xff0c;大多時候都需要自己手動實現。 C語言中的函數&#xff0c;有著嚴格的順…

Java基礎(六)——繼承

個人簡介 &#x1f440;個人主頁&#xff1a; 前端雜貨鋪 ?開源項目&#xff1a; rich-vue3 &#xff08;基于 Vue3 TS Pinia Element Plus Spring全家桶 MySQL&#xff09; &#x1f64b;?♂?學習方向&#xff1a; 主攻前端方向&#xff0c;正逐漸往全干發展 &#x1…

【Web】

1、配倉庫 [rootlocalhost yum.repos.d]# vi rpm.repo ##本地倉庫標準寫法 [baseos] namemiaoshubaseos baseurl/mnt/BaseOS gpgcheck0 [appstream] namemiaoshuappstream baseurlfile:///mnt/AppStream gpgcheck0 2、掛載 [rootlocalhost ~]mount /dev/sr0 /mnt mount: /m…