improved-diffusion代碼逐行理解之train

目錄

  • 代碼
  • 理解
    • 1、解析命令行參數
    • 2、分布式設置和日志配置
    • 3、創建模型和擴散過程
    • 4、加載數據
    • 5、訓練循環
    • 6、訓練過程中的關鍵點
    • 7、日志和模型保存

代碼

improved-diffusion代碼地址:https://github.com/openai/improved-diffusion
運行代碼會遇到的幾個問題:
1、源代碼訓練過程沒有設置結束條件,會一直運行,你需要手動終止。
2、源代碼的采樣過程可能會非常慢,需要耐心等待。
下面是image_train.py的部分代碼

def main():args = create_argparser().parse_args()dist_util.setup_dist()logger.configure()logger.log("creating model and diffusion...")model, diffusion = create_model_and_diffusion(**args_to_dict(args, model_and_diffusion_defaults().keys()))model.to(dist_util.dev())schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)logger.log("creating data loader...")data = load_data(data_dir=args.data_dir,batch_size=args.batch_size,image_size=args.image_size,class_cond=args.class_cond,)logger.log("training...")TrainLoop(model=model,diffusion=diffusion,data=data,batch_size=args.batch_size,microbatch=args.microbatch,lr=args.lr,ema_rate=args.ema_rate,log_interval=args.log_interval,save_interval=args.save_interval,resume_checkpoint=args.resume_checkpoint,use_fp16=args.use_fp16,fp16_scale_growth=args.fp16_scale_growth,schedule_sampler=schedule_sampler,weight_decay=args.weight_decay,lr_anneal_steps=args.lr_anneal_steps,).run_loop()

理解

1、解析命令行參數

使用create_argparser().parse_args()解析命令行參數,這些參數可能包括模型配置、訓練數據路徑、批量大小、學習率等。

2、分布式設置和日志配置

dist_util.setup_dist():設置分布式訓練環境,包括初始化分布式后端(如PyTorch的torch.distributed)。
logger.configure():配置日志記錄器,以便在訓練過程中記錄關鍵信息。

3、創建模型和擴散過程

通過create_model_and_diffusion函數,根據命令行參數和默認配置創建模型和擴散過程對象。這些對象被用于后續的訓練過程。
使用model.to(dist_util.dev())將模型發送到分布式訓練環境中的指定設備(如GPU)。
根據命令行參數args.schedule_sampler和擴散過程對象創建時間步采樣器schedule_sampler。

4、加載數據

使用load_data函數加載訓練數據,該函數根據指定的數據目錄(args.data_dir)、批量大小(args.batch_size)、圖像大小(args.image_size)和其他條件(如args.class_cond,表示是否進行類別條件訓練)來準備數據加載器。

5、訓練循環

實例化TrainLoop類,并傳入模型、擴散過程、數據加載器以及其他訓練相關的參數(如學習率、指數移動平均率、日志記錄間隔、保存間隔等)。
調用TrainLoop實例的run_loop方法開始訓練過程。該方法將迭代數據加載器提供的數據,執行前向傳播、損失計算、反向傳播和梯度更新等步驟,直到滿足訓練結束的條件(如達到預定的迭代次數或學習率衰減步數)。

6、訓練過程中的關鍵點

在TrainLoop的run_loop方法中,通常會包括微批次迭代、梯度清零、模型參數更新、學習率調整、模型保存和日志記錄等步驟。
如果啟用了半精度訓練(args.use_fp16),則可能需要對損失進行縮放以避免數值下溢,并在反向傳播后恢復梯度比例。
schedule_sampler用于在訓練過程中采樣不同的時間步,這對于控制擴散模型的訓練過程至關重要。

7、日志和模型保存

在訓練過程中,會定期記錄關鍵指標(如損失值)并保存到日志文件中,以便后續分析和可視化。
還會根據save_interval參數定期保存模型檢查點,以便在訓練中斷后能夠恢復訓練或進行模型評估。
這段代碼展示了深度學習訓練過程的一個高度模塊化和可配置的框架,通過命令行參數和配置文件可以輕松調整訓練參數,以適應不同的任務和硬件環境。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/42911.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/42911.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/42911.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

LDR6282-顯示器:從技術革新到視覺盛宴

顯示器,作為我們日常工作和娛樂生活中不可或缺的一部分,承載著將虛擬世界呈現為現實圖像的重要使命。它不僅是我們與電子設備交互的橋梁,更是我們感知信息、享受視覺盛宴的重要窗口。顯示器在各個領域的應用也越來越廣泛。在辦公領域&#xf…

Gradle使用插件SonatypeUploader-v2.6上傳到maven組件到遠程中央倉庫

本文基于sonatypeUploader 2.6版本 插件的使用實例:https://github.com/jeadyx/SonatypeUploaderSample 發布步驟 提前準備好sonatype賬號和signing配置 注:如果沒有,請參考1.0博文的生成步驟: https://jeady.blog.csdn.net/art…

收銀系統源碼-營銷活動-幸運抽獎

1. 功能描述 營運抽獎:智慧新零售收銀系統,線上商城營銷插件,商戶/門店在小程序商城上設置抽獎活動,中獎人員可內定; 2.適用場景 新店開業、門店周年慶、節假日等特定時間促銷;會員拉新,需會…

SQLServer連接異常

2. 文件夾對應的是[internal].[folders]表,與之相關的權限在[internal].[folder_permissions]表 項目對應的是[internal].[projects]表,與之相關的權限在[internal].[project_permissions],版本在[internal].[object_versions]表。 環境對應…

MongoDB本地配置分片

mongodb server version: 7.0.12 社區版 mongo shell version: 2.2.10 平臺:win10 64位 控制臺:Git Bash 分片相關節點結構示意圖 大概步驟 1. 配置 配置服務器 副本集 (最少3個節點) -- 創建數據目錄 mkdir -p ~/dbs/confi…

華為eNSP:HCIA匯總實驗

本次拓撲實驗需求: 1、內網地址用DHCP 2、VLAN10不能訪問外網 3、使用靜態NAT 實驗用到的技術有DHCP、劃分VLAN、IP配置、VLAN間的通信:單臂路由、VLANIF,靜態NAT、基本ACL DHCP是一種用于自動分配IP地址和其他網絡參數的協議。 劃分VLA…

新型模型架構(參數化狀態空間模型、狀態空間模型變種)

文章目錄 參數化狀態空間模型狀態空間模型變種Transformer 模型自問世以來,在自然語言處理、計算機視覺等多個領域得到了廣泛應用,并展現出卓越的數據表示與建模能力。然而,Transformer 的自注意力機制在計算每個詞元時都需要利用到序列中所有詞元的信息,這導致計算和存儲復…

Butterfly主題添加動畫加載效果

安裝插件 安裝插件,在博客根目錄[Blogroot]下打開終端,運行以下指令: npm install hexo-butterfly-wowjs --save添加配置 添加配置信息,以下為寫法示例 在站點配置文件_config.yml或者主題配置文件_config.butterfly.yml中添加 wowjs:ena…

簡單介紹 Dagger2 的入門使用

依賴注入 在介紹 Dagger2 這個之前,必須先解釋一下什么是依賴注入,因為這個庫就是用來做依賴注入的。所以這里先簡單用一句話來介紹一下依賴注入: 依賴注入是一種設計模式,它允許對象在運行時注入其依賴項。而不是在編譯時確定&a…

Andorid 11 InputDispatcher FocusedApplication設置過程分析

在Input ANR中,有一類ANR打印的reason 為 “xx does not have a focused window” ,表明 輸入事件 5s 內,只有FocusedApplication,而沒找到focused window。本文分析下FocusedApplication的設置過程。 setFocusedApp 源碼路徑&am…

iOS 應用內存超過多少會收到系統內存警告 ?

iOS 應用內存超過多少會收到系統內存警告 ? 在 iOS 應用中,系統內存警告的觸發是由 iOS 操作系統動態決定的,并不是一個固定的閾值。系統會根據當前設備的可用內存、正在運行的其他應用程序的內存需求以及當前應用程序的內存占用情況來判斷是…

用PlantUML可視化顯示JSON

概述 PlantUML除了繪制UML中的一些標準圖之外,也可以以圖形化的方式顯示一些其他圖形或數據形式的結構,這其中就包括JSON。 它以一種簡單且優美的圖形形式,表達了JSON的結構。你可以用它來作為設計JSON數據文件的依據,輔助設計或…

day01:項目概述,環境搭建

文章目錄 軟件開發整體介紹軟件開發流程角色分工軟件環境 外賣平臺項目介紹項目介紹定位功能架構 產品原型技術選型 開發環境搭建整體結構:前后端分離開發前后端混合開發缺點前后端分離開發 前端環境搭建Nginx 后端環境搭建熟悉項目結構使用Git進行版本控制數據庫環…

【C++】AVL樹(旋轉、平衡因子)

🌈個人主頁:秦jh_-CSDN博客🔥 系列專欄:https://blog.csdn.net/qinjh_/category_12575764.html?spm1001.2014.3001.5482 ? 目錄 前言 AVL樹的概念 節點 插入 AVL樹的旋轉 新節點插入較高左子樹的左側---左左:…

【C++】stack和queue的模擬實現 雙端隊列deque的介紹

🔥個人主頁: Forcible Bug Maker 🔥專欄: STL || C 目錄 🌈前言🔥stack的模擬實現🔥queue的模擬實現🔥deque(雙端隊列)deque的缺陷 🌈為什么選擇…

基于Go 1.19的站點模板爬蟲

創建一個基于Go 1.19的站點模板爬蟲涉及到幾個關鍵步驟:初始化項目,安裝必要的包,編寫爬蟲邏輯,以及處理和存儲抓取的數據。下面是一個簡單的示例,使用goquery庫來解析HTML,并使用net/http來發起HTTP請求。…

【containerd】解決敲擊crictl images命令報錯問題

【Containerd】解決輸入crictl images命令報錯問題 文章目錄 【Containerd】解決輸入crictl images命令報錯問題問題復現解決辦法驗證結果參考鏈接 問題復現 [rootmaster01 ~]# crictl images WARN[0000] image connect using default endpoints: [unix:///var/run/dockershim…

七、Docker常規軟件安裝

目錄 一、總體步驟 二、安裝tomcat 1、docker hub上查找tomcat鏡像 三、安裝MySQL 1、查看MySQL鏡像 2、拉取MySQL鏡像到本地,本次拉取MySQL5.7 3、使用MySQL鏡像創建容器 4、使用Windows數據庫工具,連接MySQL實例 5、常見問題 6、創建MySQL容器實例 7、新…

DDP:微軟提出動態detection head選擇,適配計算資源有限場景 | CVPR 2022

DPP能夠對目標檢測proposal進行非統一處理,根據proposal選擇不同復雜度的算子,加速整體推理過程。從實驗結果來看,效果非常不錯 來源:曉飛的算法工程筆記 公眾號 論文: Should All Proposals be Treated Equally in Object Detect…

同聲傳譯app哪個好免費?對話交流推薦這5個

暑期到,也是旅游出行的好日子~自打周邊不少國家都開放免簽政策之后,出國游也變得更加方便了~對于外語水平不高的朋友來講,想要保證出行體驗,其實手上只要備好一個同聲傳譯app就OK! 倘若你還不清楚都有哪些同聲傳譯app…