【PyTorch】模型訓練過程優化分析

文章目錄

  • 1. 模型訓練過程劃分
    • 1.1. 定義過程
      • 1.1.1. 全局參數設置
      • 1.1.2. 模型定義
    • 1.2. 數據集加載過程
      • 1.2.1. Dataset類:創建數據集
      • 1.2.2. Dataloader類:加載數據集
    • 1.3. 訓練循環
  • 2. 模型訓練過程優化的總體思路
    • 2.1. 提升數據從硬盤轉移到CPU內存的效率
    • 2.2. 提升CPU的運算效率
    • 2.3. 提升數據從CPU轉移到GPU的效率
    • 2.4. 提升GPU的運算效率
  • 3. 模型訓練過程優化分析
    • 3.1. 定義過程
    • 3.2. 數據集加載過程
    • 3.3. 訓練循環
      • 3.3.1. 訓練模型
      • 3.3.2. 評估模型

1. 模型訓練過程劃分

  • 主過程在__main__下。
if __name__ == '__main__':...
  • 主過程分為定義過程數據集配置過程訓練循環

1.1. 定義過程

1.1.1. 全局參數設置

參數名作用
num_epochs指定在訓練集上訓練的輪數
batch_size指定每批數據的樣本數
num_workers指定加載數據集的進程數
prefetch_factor指定每個進程的預加載因子(要求num_workers>0
device指定模型訓練使用的設備(CPU或GPU)
lr學習率,控制模型參數的更新步長

1.1.2. 模型定義

組件作用
writer定義tensorboard的事件記錄器
net定義神經網絡結構
net.apply(init_weights)模型參數初始化
criterion定義損失函數
optimizer定義優化器

1.2. 數據集加載過程

1.2.1. Dataset類:創建數據集

  • 作用:定義數據集的結構和訪問數據集中樣本的方式。定義過程中通常需要讀取數據文件,但這并不意味著將整個數據集加載到內存中
  • 如何創建數據集
    • 繼承Dataset抽象類自定義數據集
    • TensorDataset類:通過包裝張量創建數據集

1.2.2. Dataloader類:加載數據集

  • 作用:定義數據集的加載方式,但這并不意味著正在加載數據集
    • 數據批量加載:將數據集分成多個批次(batches),并逐批次地加載數據。
    • 數據打亂(可選):在每個訓練周期(epoch)開始時,DataLoader會對數據集進行隨機打亂,以確保在訓練過程中每個樣本被均勻地使用。
  • 主要參數
    參數作用
    dataset指定數據集
    batch_size指定每批數據的樣本數
    shuffle=False指定是否在每個訓練周期(epoch)開始時進行數據打亂
    sampler=None指定如何從數據集中選擇樣本,如果指定這個參數,那么shuffle必須設置為False
    batch_sampler=None指定生成每個批次中應包含的樣本數據的索引。與batch_size、shuffle 、sampler and drop_last參數不兼容
    num_workers=0指定進行數據加載的進程數
    collate_fn=None指定將一列表的樣本合成mini-batch的方法,用于映射型數據集
    pin_memory=False是否將數據緩存在物理RAM中以提高GPU傳輸效率
    drop_last=False是否在批次結束時丟棄剩余的樣本(當樣本數量不是批次大小的整數倍時)
    timeout=0定義在每個批次上等待可用數據的最大秒數。如果超過這個時間還沒有數據可用,則拋出一個異常。默認值為0,表示永不超時。
    worker_init_fn=None指定在每個工作進程啟動時進行的初始化操作。可以用于設置共享的隨機種子或其他全局狀態。
    multiprocessing_context=None指定多進程數據加載的上下文環境,即多進程庫
    generator=None指定一個生成器對象來生成數據批次
    prefetch_factor=2控制數據加載器預取數據的數量,默認預取比實際所需的批次數量多2倍的數據
    persistent_workers=False控制數據加載器的工作進程是否在數據加載完成后繼續存在

1.3. 訓練循環

  • 外層循環控制在訓練集上訓練的輪數
for epoch in trange(num_epochs):...
  • 循環內部主要有以下模塊:
    • 訓練模型
    for X, y in dataloader_train:X, y = X.to(device), y.to(device)loss = criterion(net(X), y)optimizer.zero_grad()loss.mean().backward()optimizer.step()
    
    • 評估模型
      • 每輪訓練后在數據集上損失
        • 每輪訓練損失
        • 每輪測試損失
    def evaluate_loss(dataloader):"""評估給定數據集上模型的損失"""metric = d2l.Accumulator(2)  # 損失的總和, 樣本數量with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)loss = criterion(net(X), y)metric.add(loss.sum(), loss.numel())return metric[0] / metric[1]
    

2. 模型訓練過程優化的總體思路

注意: 以下只區分變量、對象是在GPU還是在CPU內存中處理。實際處理過程使用的硬件是CPU、內存和GPU,其中CPU有緩存cache,GPU有顯存。忽略具體的數據傳輸路徑和數據處理設備。談GPU包括GPU和顯存,談CPU內存包括CPU、緩存cache和內存

主過程子過程追蹤情況
定義過程全局參數設置變量的定義都是由CPU完成的
模型定義
  • 對象的定義都是由CPU完成的
  • 模型參數和梯度信息可以轉移到GPU
數據集配置過程——對象的定義都是由CPU完成的
訓練循環訓練模型
  • 每批數據的加載是由CPU完成的,先加載到CPU內存,然后可以轉移到GPU
  • 數據的前向傳播可以由GPU完成
  • 誤差反向傳播(包括梯度計算)可以由GPU完成的
  • 模型參數更新可以由GPU完成的
評估模型
  • 每批數據的加載是由CPU完成的,先加載到CPU內存,然后可以轉移到GPU
  • 數據的前向傳播可以由GPU完成,此時可以禁用自動求導機制

由此,要提升硬件資源的利用率和訓練效率,總體上有以下角度:

2.1. 提升數據從硬盤轉移到CPU內存的效率

  • 如果數據集較小,可以一次性讀入CPU內存,之后注意要num_workers設置為0,由主進程加載數據集。否則會增加多余的過程(數據從CPU內存到CPU內存),而且隨進程數num_workers增加而增加。
  • 如果數據集很大,可以采用多進程讀取num_workers設置為大于0的數,小于CPU內核數,加載數據集的效率隨著進程數num_workers增加而增加;也隨著預讀取因子prefetch_factor的增加而增加,之后大致不變,因為預讀取到了極限。
  • 如果數據集較小,但是需要逐元素的預處理,可以采用多進程讀取,以稍微增加訓練時間為代價降低操作的復雜度。

2.2. 提升CPU的運算效率

2.3. 提升數據從CPU轉移到GPU的效率

  • 數據傳輸未準備好也傳輸(即非阻塞模式):non_blocking=True
  • 將張量固定在CPU內存 :pin_memory=True

2.4. 提升GPU的運算效率

  • 使用自動混合精度(AMP,要求pytorch>=1.6.0):通過將模型和數據轉換為低精度的形式(如FP16),可以顯著減少GPU內存使用。

3. 模型訓練過程優化分析

3.1. 定義過程

  • 特點:每次程序運行只需要進行一次。
  • 優化思路:將模型轉移到GPU,同時non_blocking=True

3.2. 數據集加載過程

  • 特點:只是定義數據加載的方式,并沒有加載數據。
  • 優化思路:合理設置數據加載參數,如
    • batch_size:一般取能被訓練集大小整除的值。過小,則每次參數更新時所用的樣本數較少,模型無法充分地學習數據的特征和分布,同時參數更新頻繁,模型收斂速度提高,CPU到GPU的數據傳輸次數增加,CPU內存的消耗總量增加;過大,則每次參數更新時所用的樣本數較多,模型性能更穩定,對GPU、CPU內存的單次消耗增加,對硬件配置要求更高,同時參數更新緩慢,模型收斂速度下降。
    • num_workers:取小于CPU內核數的合適值,比如先取CPU內核數的一半。過小,則數據加載進程少,數據加載緩慢;過大,則數據加載進程多,對CPU要求高,同時也影響效率。
    • pin_memory:當設置為True時,它告訴DataLoader將加載的數據張量固定在CPU內存中,使數據傳輸到GPU的過程更快。
    • prefetch_factor:決定每次從磁盤加載多少個batch的數據到內存中,預先加載batch越多,在處理數據時,不會因為數據加載的延遲而影響整體的訓練速度,同時可以讓GPU在處理數據時保持忙碌,從而提高GPU利用率;過大,則會導致CPU內存消耗增加。

3.3. 訓練循環

  • 優化思路:
    • 訓練和評估過程分離或者減少評估的次數:模型從訓練到評估需要進行狀態切換,模型評估過程開銷很大。
    • 盡量使用非局部變量:減少變量、對象的創建和銷毀過程

3.3.1. 訓練模型

  • 特點:訓練結構固定
  • 優化思路:
    • 將數據轉移到GPU,同時non_blocking=True
    • 優化訓練結構:比如使用自動混合精度:
    from torch.cuda.amp import autocast, GradScalergrad_scaler = GradScaler()
    for epoch in range(num_epochs):start_time = time.perf_counter()for X, y in dataloader_train:X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)with autocast():loss = criterion(net(X), y)optimizer.zero_grad()grad_scaler.scale(loss.mean()).backward()grad_scaler.step(optimizer)grad_scaler.update()
    

3.3.2. 評估模型

  • 特點:評估結構固定
  • 優化思路:
    • 將數據轉移到GPU,同時non_blocking=True
    • 減少不必要的運算:比如梯度計算,即:
    with torch.no_grad():...
    

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/211438.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/211438.shtml
英文地址,請注明出處:http://en.pswp.cn/news/211438.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

SPRD Android 13 需要在設置--顯示--鎖定屏幕--雙行時鐘--<關閉>

開始去改默認值沒生效 --- a/frameworks/base/packages/SettingsProvider/res/values/defaults.xml +++ b/frameworks/base/packages/SettingsProvider/res/values/defaults.xml @@ -336,4 +336,6 @@<integer name="def_navigation_bar_config">0</integer…

西南科技大學數字電子技術實驗三(MSI邏輯器件設計組合邏輯電路及FPGA的實現)FPGA部分

一、實驗目的 進一步掌握MIS(中規模集成電路)設計方法。通過用MIS譯碼器、數據選擇器實現電路功能,熟悉它們的應用。進一步學習如何記錄實驗中遇到的問題及解決方法。二、實驗原理 1、4位奇偶校驗器 Y=S7i=0DiMi D0=D3=D5=D6=D D1=D2=D4=D7= `D 2、組合邏輯電路 F=A`B C …

面試計算機網絡八股文五問五答第二期

面試計算機網絡八股文五問五答第二期 作者&#xff1a;程序員小白條&#xff0c;個人博客 相信看了本文后&#xff0c;對你的面試是有一定幫助的&#xff01; ?點贊?收藏?不迷路&#xff01;? 1.OSI七層協議&#xff1f; 2. TCP和UDP傳輸協議的區別&#xff1f; TCP是可…

C語言_常見位操作

C語言_常見位操作 文章目錄 C語言_常見位操作一、位操作函數二、代碼示例 一、位操作函數 設置某位為1或者對某位清0、獲取某位的值、對某位取反 /*對某位置1*/ unsigned Setbit(unsigned x,int n) {return x | 1 << n; }/*對某位清0*/ unsigned Resetbit(unsigned x,…

為什么要用向量檢索

之前寫過一篇文章&#xff0c;是我個人到目前階段的認知&#xff0c;所做的判斷。我個人是做萬億級數據的搜索優化工作的。一直在關注任何和搜索相關的內容。 下一代搜索引擎會什么&#xff1f;-CSDN博客 這篇文章再來講講為什么要使用向量搜索。 在閱讀這篇文章之前呢&#xf…

【網絡安全】網絡設備可能面臨哪些攻擊?

網絡設備通常是網絡基礎設施的核心&#xff0c;并控制著整個網絡的通信和安全&#xff0c;同樣面臨著各種各樣的攻擊威脅。 對網絡設備的攻擊一旦成功&#xff0c;并進行暴力破壞&#xff0c;將會導致網絡服務不可用&#xff0c;且可以對網絡流量進行控制&#xff0c;利用被攻陷…

【JavaEE】線程池

作者主頁&#xff1a;paper jie_博客 本文作者&#xff1a;大家好&#xff0c;我是paper jie&#xff0c;感謝你閱讀本文&#xff0c;歡迎一建三連哦。 本文于《JavaEE》專欄&#xff0c;本專欄是針對于大學生&#xff0c;編程小白精心打造的。筆者用重金(時間和精力)打造&…

springcloud分布式事務

文章目錄 一.為什么引入分布式事務?二.理論基礎1.CAP定理2.BASE理論 三.Seata1.微服務集成Seata2.XA模式(掌握)3.AT模式(重點)4.TCC模式(重點)5.Saga模式(了解) 四.四種模式對比五.Seata高可用 一.為什么引入分布式事務? 事務的ACID原則 在大型的微服務項目中,每一個微服務都…

案例課4——智齒客服

1.公司介紹 智齒科技&#xff0c;一體化客戶聯絡中心解決方案提供商。提供基于「客戶聯絡中心」場景的一體化解決方案&#xff0c;包括公域私域、營銷服務、軟件BPO的三維一體化。 智齒科技不斷整合前沿的人工智能及大數據技術&#xff0c;已構建形成呼叫中心、機器人「在線語音…

Python中函數的遞歸調用

函數調用自己的編程方式被稱為函數的遞歸調用。遞歸通常能夠將一個大型的復雜問題的遞歸條件&#xff0c;一層一層的回溯到終止條件&#xff0c;然后再根據終止條件的運算結果&#xff0c;一層一層的遞進運算到滿足全部的遞歸條件。它能夠使用少量程序描述出解題過程中的重復運…

主機訪問Android模擬器網絡服務方法

0x00 背景 因為公司的一個手機app的開發需求&#xff0c;要嘗試鏈接手機開啟的web服務。于是在Android Studio的Android模擬器上嘗試連接&#xff0c;發現谷歌給模擬器做了網絡限制&#xff0c;不能直接連接。當然這個限制似乎從很久以前就存在了。一直沒有注意到。 0x01 And…

分銷電商結算設計

概述 分銷電商中涉及支付與結算&#xff1b;支付職責是收錢&#xff0c;結算則是出錢給各利益方&#xff1b; 結算核心圍繞業務模式涉及哪些費用&#xff0c;以及這些費用什么時候通過什么出資渠道&#xff0c;由誰給到收方利益方&#xff1b; 結算要素組成費用項結算周期出…

區塊鏈的可拓展性研究【03】擴容整理

為什么擴容&#xff1a;在layer1上&#xff0c;交易速度慢&#xff0c;燃料價格高 擴容的目的&#xff1a;在保證去中心化和安全性的前提下&#xff0c;提升交易速度&#xff0c;更快確定交易&#xff0c;提升交易吞吐量&#xff08;提升每秒交易量&#xff09; 目前方案有&…

詳解進程管理(銀行家算法、死鎖詳解)

處理機是計算機系統的核心資源。操作系統的功能之一就是處理機管理。隨著計算機的迅速發展&#xff0c;處理機管理顯得更為重要&#xff0c;這主要由于計算機的速度越來越快&#xff0c;處理機的充分利用有利于系統效率的大大提高&#xff1b;處理機管理是整個操作系統的重心所…

前后端聯調神器《OpenAPI-Codegen》

在后端開發完接口之后&#xff0c;前端如果再去寫一遍接口來聯調的話&#xff0c;會很浪費時間&#xff0c;這個時候使用OpenAPI接口文檔來生成Axios接口代碼的話&#xff0c;會大大提高我們的開發效率。 Axios引入 Axios是一個基于Promise的HTTP客戶端&#xff0c;用于瀏覽器…

Go壓測工具

前言 在做Go的性能分析調研的時候也使用到了一些壓測方面的工具&#xff0c;go本身也給我們提供了BenchMark性能測試用例&#xff0c;可以很好的去測試我們的單個程序性能&#xff0c;比如測試某個函數&#xff0c;另外還有第三方包go-wrk也可以幫助我們做http接口的性能壓測&…

C# 任務并行類庫Parallel調用示例

寫在前面 Task Parallel Library 是微軟.NET框架基礎類庫&#xff08;BCL&#xff09;中的一個&#xff0c;主要目的是為了簡化并行編程&#xff0c;可以實現在不同的處理器上并行處理不同任務&#xff0c;以提升運行效率。Parallel常用的方法有For/ForEach/Invoke三個靜態方法…

Element-UI定制化Tree 樹形控件

1.復制 說明&#xff1a;復制Tree樹形控件。 <script> export default {data() {return {data: [{label: 一級 1,children: [{label: 二級 1-1,children: [{label: 三級 1-1-1}]}]}, {label: 一級 2,children: [{label: 二級 2-1,children: [{label: 三級 2-1-1}]}, {l…

Linux:進程優先級與命令行參數

目錄 1.進程優先級 1.1 基本概念 1.2 查看系統進程 1.3 修改進程優先級的命令 2.進程間切換 2.1 相關概念 2.2 Linux2.6內核進程調度隊列&#xff08;了解即可&#xff09; 3.命令行參數 1.進程優先級 1.1 基本概念 cpu資源分配的先后順序&#xff0c;就是指進程的優…

【C++】在類外部定義成員函數時,不應該再次指定默認參數值

2023年12月10日&#xff0c;周日下午 錯誤的代碼 #include<iostream>class A { public:void fun(int a10); };void A::fun(int a10) //<----在這里報錯 {}int main() {} 正確的代碼 代碼目前有一個問題&#xff0c;主要是在類外部定義成員函數時&#xff0c;不應該…