深度學習訓練框架——監督學習為例

訓練框架

文章目錄

  • 訓練框架
    • 1. 模型網絡結構
    • 2. 數據讀取與數據加載
      • 2.1Dataloater參數
      • 2.2 collate_fn
    • 3. 優化器與學習率調整
      • 3.1 優化器
      • 3.2 學習率調度
    • 4迭代訓練
    • 4.1 train_epoch
    • 4.2 train iteration
  • 5.1 保存模型權重

本文內容以pytorch為例

1. 模型網絡結構

自定義網絡模型繼承‘nn.Module’,實現模型的參數的初始化與前向傳播;自定義網絡模型可以添加權重初始化、網絡模塊組合等其他方法

        import torch.nn as nnimport torch.nn.functional as Fclass Model(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(1, 20, 5)self.conv2 = nn.Conv2d(20, 20, 5)def forward(self, x):x = F.relu(self.conv1(x))return F.relu(self.conv2(x))

2. 數據讀取與數據加載

數據集類的基礎方法

  • dataset:

需要包含數據迭代方法

def __getitem__(self, index):image, target = self.list(index)return image, target

利用torch.utils.data.DataLoader封裝后,用于迭代遍歷數據元素;

數據長度方法

def __len__(self):return self.dataset_size

數據加載

  • dataloader:
    對數據集類(通常實現了 getitemlen 方法)時,
    你可以使用 DataLoader 來輕松地進行批量加載、打亂數據、并行加載以及多進程數據加載。
    collate_fn:將字典或數組數據流進行拆分,拆分為圖像、label、邊界框、文字編碼等不同類型數據與模型的輸入與輸出相匹配

2.1Dataloater參數

參數:

  • dataset (Dataset): 加載數據的數據集。
  • batch_size (int, 可選): 每批加載的樣本數量(默認:1)。
  • shuffle (bool, 可選): 設置為 True 以在每個 epoch 重新洗牌數據(默認:False)。
  • sampler (Sampler 或 Iterable, 可選): 定義從數據集中抽取樣本的策略。可以是任何實現了 len 的 Iterable。如果指定了 sampler,則不能指定 :attr:shuffle。
  • batch_sampler (Sampler 或 Iterable, 可選): 與 :attr:sampler 類似,但一次返回一批索引。與 :attr:batch_size, :attr:shuffle, :attr:sampler, 和 :attr:drop_last 互斥。
    num_workers (int, 可選): 數據加載使用的子進程數量。0 表示數據將在主進程中加載(默認:0)。
  • collate_fn (Callable, 可選): 將樣本列表合并以形成 Tensor(s) 的 mini-batch。在使用 map-style 數據集的批量加載時使用。
  • pin_memory (bool, 可選): 如果設置為 True,則數據加載器將在返回它們之前將 Tensors 復制到設備/CUDA 固定內存中。如果你的數據元素是自定義類型,或者你的 :attr:collate_fn 返回的批次是自定義類型,請參見下面的例子。
  • drop_last (bool, 可選): 設置為 True 以丟棄最后一個不完整的批次,如果數據集大小不能被批量大小整除。如果設置為 False 并且數據集大小不能被批量大小整除,則最后一個批次會較小(默認:False)。
  • timeout (numeric, 可選): 如果為正數,這是從工作進程收集一個批次的超時值。應始終為非負數(默認:0)。
  • worker_init_fn (Callable, 可選): 如果不是 None,這將在每個工作進程子進程上調用,輸入為工作進程 id(一個在 [0, num_workers - 1] 中的 int),在設置種子和數據加載之前。(默認:None)
  • generator (torch.Generator, 可選): 如果不是 None,這個 RNG 將被 RandomSampler 用來生成隨機索引,并被多進程用來為工作進程生成 base_seed。(默認:None)
  • prefetch_factor (int, 可選,關鍵字參數): 每個工作進程預先加載的批次數量。2 意味著將有總共 2*num_workers 個批次被預先加載。(默認值取決于 num_workers 的設定值。如果 num_workers=0,默認是 None。否則如果 num_workers>0,默認是 2)。
  • persistent_workers (bool, 可選): 如果設置為 True,則數據加載器在數據集被消費一次后不會關閉工作進程。這允許保持工作進程的 Dataset 實例存活。(默認:False)。
  • pin_memory_device (str, 可選): 如果將 pin_memory 設置為 true,則數據加載器將在返回它們之前將 Tensors 復制到設備固定內存中。

2.2 collate_fn

class CollateFunc(object):def __call__(self, batch):targets = []images = []for sample in batch:image = sample[0]target = sample[1]images.append(image)targets.append(target)images = torch.stack(images, 0) # [B, C, H, W]return images, targets

3. 優化器與學習率調整

3.1 優化器

在訓練過程中,根據梯度變化、損失函數、動量(momontum)、學習率來調整模型參數

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = ExponentialLR(optimizer, gamma=0.9)for epoch in range(20):for input, target in dataset:optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()optimizer.step()scheduler.step()

優化器調整的模型參數參數包含權重的偏置與歸一化項;
優化器可以為不同的網絡層設置學習率與權重衰減

import torch
import torch.nn as nn# 定義一個簡單的神經網絡模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.conv1 = nn.Conv2d(1, 10, kernel_size=5)self.conv2 = nn.Conv2d(10, 20, kernel_size=5)self.fc = nn.Linear(320, 50)def forward(self, x):x = torch.relu(torch.max_pool2d(self.conv1(x), 2))x = torch.relu(torch.max_pool2d(self.conv2(x), 2))x = x.view(-1, 320)x = self.fc(x)return x# 創建模型實例
model = SimpleModel()# 使用 named_parameters() 獲取參數名稱和參數
for name, param in model.named_parameters():print(name, param.size())'''
輸出結果: 
conv1.weight torch.Size([10, 1, 5, 5])
conv1.bias torch.Size([10])
conv2.weight torch.Size([20, 10, 5, 5])
conv2.bias torch.Size([20])
fc.weight torch.Size([320, 50])
fc.bias torch.Size([50])
'''

Optimizer.add_param_group 將參數組添加到Optimizerparam_groups中。
Optimizer.load_state_dict 加載優化器狀態。
Optimizer.state_dict 以字典的形式返回優化器的狀態dict。
Optimizer.step 參數更新
Optimizer.zero_grad 重置累計梯度梯度(梯度累計發生在反向傳播之前)

優化器在模型訓練中的作用是調整模型的參數,以最小化損失函數。訓練過程通常遵循以下步驟:

  • 重置梯度:在每次迭代開始時,需要將模型參數的梯度清零,以避免累積。
  • 前向傳播:模型接收輸入數據,通過其參數進行計算,得到預測值。
  • 計算損失:使用損失函數(如均方誤差、交叉熵等)計算模型預測值與真實值之間的差異,這個差異被稱為損失值。損失函數為模型提供了優化的方向。
  • 反向傳播:根據損失值對模型參數進行反向傳播,計算每個參數的梯度,這些梯度指示了- 如何調整參數以減少損失。
  • 梯度累計
    梯度:表示模型參數發生微小變化,損失函數該如何變化
    學習率:控制參數更新的步長,學習率在參數更新前進行更新
    Momentum :考慮過去梯度的指數加權平均值來調整參數的更新規則,從而幫助模型更快地收斂,并在梯度很小時減少震蕩
    累加梯度:在每次迭代中(反向傳播后),將計算得到的梯度累加到梯度累積器中,而不是立即更新模型參數。(梯度累計是一種靈活的技術,它使得在資源有限的情況下訓練大型模型成為可能,并且可以幫助優化訓練過程。在進行反向傳播之前,如果沒有直接進行模型的梯度更新,一般會進行梯度累計)

3.2 學習率調度

學習率控制著模型參數的更新變化率,在訓練過程中采用不同的學習率衰減策略,能更幫助模型更好的擬合數據,提升模型的泛化能力,定義優化器時,會設置初始學習率,利用 torch.optim.lr_scheduler中的學習率函數對優化器與學習率調整策略進行封裝,結果返回封裝了optimizer,scheduler對象。更新optimizer的學習率
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

  • 訓練完一個epoch進行更新
  • 迭代一次進行一次更新
  • 可以在訓練過程中設置不同參數層的學習率

4迭代訓練

train_eopch:訓練完全部數據跟新一次學習或優化器參數,或者指定更新優化器參數的更新頻率
iteration: 沒迭代一次更新一次優化器參數;
兩者的主要區別在于遍歷數據的形式不同

4.1 train_epoch

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset# 定義一個簡單的模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.linear = nn.Linear(10, 2)  # 一個簡單的線性層def forward(self, x):return self.linear(x)# 實例化模型、損失函數和優化器
model = SimpleModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 創建數據集和數據加載器
x_dummy = torch.randn(1000, 10)
y_dummy = torch.randint(0, 2, (1000,))
dataset = TensorDataset(x_dummy, y_dummy)
data_loader = DataLoader(dataset, batch_size=100, shuffle=True)# 假設我們想要模擬的批量大小是1000,但由于內存限制,我們只能實際使用批量大小為100
accumulation_steps = 10  # 需要累積10個steps的梯度
model.train()for epoch in range(2):  # 訓練2個epochfor i, (inputs, targets) in enumerate(data_loader):# 前向傳播outputs = model(inputs)loss = criterion(outputs, targets)# 累加梯度而不是立即清零loss.backward()# 每累積一定步數后更新一次參數if (i + 1) % accumulation_steps == 0:# 更新模型參數之前,我們需要梯度optimizer.step()optimizer.zero_grad()  # 清零梯度,準備下一次累積# 打印損失信息if (i + 1) % (accumulation_steps * 10) == 0:  # 每100個iteration打印一次print(f'Epoch [{epoch+1}/{2}], Step [{i+1}/{len(data_loader)*accumulation_steps}], Loss: {loss.item():.4f}')print("Training complete.")

4.2 train iteration

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset# 定義一個簡單的模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.linear = nn.Linear(10, 2)  # 一個簡單的線性層def forward(self, x):return self.linear(x)# 實例化模型、損失函數和優化器
model = SimpleModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 創建數據集和數據加載器
x_dummy = torch.randn(1000, 10)
y_dummy = torch.randint(0, 2, (1000,))
dataset = TensorDataset(x_dummy, y_dummy)
data_loader = DataLoader(dataset, batch_size=100, shuffle=True)# 假設我們想要模擬的批量大小是1000,但由于內存限制,我們只能實際使用批量大小為100
accumulation_steps = 10  # 需要累積10個steps的梯度
model.train()
max_iter = 100
# 設置最大迭代次數
iterator = iter(data_loader)
for iter in range(1,max_iter):  # 訓練max_iter個iter# 迭代數據,若完成數據一輪迭代,則重新初始化iterator = iter(train_loader)# 直至完成max_iter次迭代try:inputs, targets = next(iterator)except:iterator = iter(train_loader)inputs, targets = next(iterator)# 前向傳播outputs = model(inputs)loss = criterion(outputs, targets)# 累加梯度而不是立即清零loss.backward()# 每累積一定步數后更新一次參數if (iter) % accumulation_steps == 0:# 更新模型參數之前,我們需要梯度optimizer.step()optimizer.zero_grad()  # 清零梯度,準備下一次累積# 打印損失信息if (iter) % (accumulation_steps * 10) == 0:  # 每100個iteration打印一次print(f'Epoch [{epoch+1}/{2}], Step [{i+1}/{len(data_loader)*accumulation_steps}], Loss: {loss.item():.4f}')print("Training complete.")

5.1 保存模型權重

以字典形式,保存權重與詳細的參數

torch.save({'model': model_eval.state_dict(),'mAP': -1.,'optimizer': self.optimizer.state_dict(),'epoch': self.epoch,'args': self.args}, checkpoint_path)

只保存模型參數

torch.save(model.state_dict(), save_temp_weights+"_fg{}.pt".format(it))

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/15022.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/15022.shtml
英文地址,請注明出處:http://en.pswp.cn/web/15022.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

測試開發面試題

簡述自動化測試的三大等待 強制等待。直接使用time.sleep()方法讓程序暫停指定的時間。優點是實現簡單,缺點是不夠靈活,可能會導致不必要的等待時間浪費。隱式等待。設置一個固定的等待時間,在這個時間內不斷嘗試去查找元素,如果…

Java17 --- SpringCloud之Sentinel

目錄 一、Sentinel下載并運行 二、創建8401微服務整合Sentinel 三、流控規則 3.1、直接模式 3.2、關聯模式 3.3、鏈路模式 3.3.1、修改8401代碼 3.3.2、創建流控模式 3.4、Warm UP(預熱) ?編輯 3.5、排隊等待 四、熔斷規則 4.1、慢調用比…

【C++】09.vector

一、vector介紹和使用 1.1 vector的介紹 vector是表示可變大小數組的序列容器。就像數組一樣,vector也采用的連續存儲空間來存儲元素。也就是意味著可以采用下標對vector的元素進行訪問,和數組一樣高效。但是又不像數組,它的大小是可以動態改…

操作系統實驗四 (綜合實驗)設計簡單的Shell程序

前言 因為是一年前的實驗,很多細節還有知識點我都已經遺忘了,但我還是盡可能地把各個細節講清楚,請見諒。 1.實驗目的 綜合利用進程控制的相關知識,結合對shell功能的和進程間通信手段的認知,編寫簡易shell程序&…

Excel透視表:快速計算數據分析指標的利器

文章目錄 概述1.數據透視表基本操作1.1準備數據:1.2創建透視表:1.3設置透視表字段:1.4多級分類匯總和交叉匯總的差別1.5計算匯總數據:1.6透視表美化:1.7篩選和排序:1.8更新透視表: 2.數據透視-數…

【B站 heima】小兔鮮Vue3 項目學習筆記Day02

文章目錄 Pinia1.使用2. pinia-計數器案例3. getters實現4. 異步action5. storeToRefsx 數據解構保持響應式6. pinia 調試 項目起步1.項目初始化和git管理2. 使用ElementPlus3. ElementPlus 主題色定制4. axios 基礎配置5. 路由設計6. 靜態資源初始化和 Error lens安裝7.scss自…

Github 2024-05-24 開源項目日報 Top10

根據Github Trendings的統計,今日(2024-05-24統計)共有10個項目上榜。根據開發語言中項目的數量,匯總情況如下: 開發語言項目數量Python項目3非開發語言項目2TypeScript項目2JavaScript項目1Kotlin項目1C#項目1C++項目1Shell項目1Microsoft PowerToys: 最大化Windows系統生產…

軟件設計師備考筆記(十):網絡與信息安全基礎知識

文章目錄 一、網絡概述二、網絡互連硬件(一)網絡的設備(二)網絡的傳輸介質(三)組建網絡 三、網絡協議與標準(一)網絡的標準與協議(二)TCP/IP協議簇 四、Inter…

某神,云手機啟動?

某神自從上線之后,熱度不減,以其豐富的內容和獨特的魅力吸引著眾多玩家; 但是隨著劇情無法跳過,長草期過長等原因,近年脫坑的玩家多之又多,之前米家推出了一款云某神的app,目標是為了減少用戶手…

RedisTemplateAPI:String

文章目錄 ?1 String 介紹?2 命令?3 對應 RedisTemplate API???? 3.1 添加緩存???? 3.2 設置過期時間(單獨設置)???? 3.3 獲取緩存值???? 3.4 刪除key???? 3.5 順序遞增???? 3.6 順序遞減 ?4 以下是一些常用的API?5 應用場景 ?1 String 介紹 Str…

ue引擎游戲開發筆記(47)——設置狀態機解決跳躍問題

1.問題分析: 目前當角色起跳時,只是簡單的上下移動,空中仍然保持行走動作,并沒有設置跳躍動作,因此,給角色設置新的跳躍動作,并優化新的動作動畫。 2.操作實現: 1.實現跳躍不復雜&…

LabVIEW常用的電機控制算法有哪些?

LabVIEW常用的電機控制算法主要包括以下幾種: 1. PID控制(比例-積分-微分控制) 描述:PID控制是一種經典的控制算法,通過調節比例、積分和微分三個參數來控制電機速度和位置。應用:廣泛應用于直流電機、步…

Java中的繼承和多態

繼承 在現實世界中,狗和貓都是動物,這是因為他們都有動物的一些共有的特征。 在Java中,可以通過繼承的方式來讓對象擁有相同的屬性,并且可以簡化很多代碼 例如:動物都有的特征,有名字,有年齡…

Mybatis源碼剖析---第一講

Mybatis源碼剖析 基礎環境搭建 JDK8 Maven3.6.3&#xff08;別的版本也可以…&#xff09; MySQL 8.0.28 --> MySQL 8 Mybatis 3.4.6 準備jar&#xff0c;準備數據庫數據 把依賴導入pom.xml中 <properties><project.build.sourceEncoding>UTF-8</p…

Linux學習筆記:線程

Linux中的線程 什么是線程線程的使用原生線程庫創建線程線程的id線程退出等待線程join分離線程取消一個線程線程的局部存儲在c程序中使用線程使用c自己封裝一個簡易的線程庫 線程互斥(多線程)導致共享數據出錯的原因互斥鎖關鍵函數pthread_mutex_t :創建一個鎖pthread_mutex_in…

雷電預警監控系統:守護安全的重要防線

TH-LD1在自然界中&#xff0c;雷電是一種常見而強大的自然現象。它既有震撼人心的壯觀景象&#xff0c;又潛藏著巨大的安全風險。為了有效應對雷電帶來的威脅&#xff0c;雷電預警監控系統應運而生&#xff0c;成為現代社會中不可或缺的安全防護工具。 雷電預警監控系統的基本…

makefile 編寫規則

1.概念 1.1 什么是makefile Makefile 是一種文本文件&#xff0c;用于描述軟件項目的構建規則和依賴關系&#xff0c;通常用于自動化軟件構建過程。它包含了一系列規則和指令&#xff0c;告訴構建系統如何編譯和鏈接源代碼文件以生成最終的可執行文件、庫文件或者其他目標文件…

Node.js知識點以及案例總結

思考&#xff1a;為什么JavaScript可以在瀏覽器中被執行 每個瀏覽器都有JS解析引擎&#xff0c;不同的瀏覽器使用不同的JavaScript解析引擎&#xff0c;待執行的js代碼會在js解析引擎下執行 為什么JavaScript可以操作DOM和BOM 每個瀏覽器都內置了DOM、BOM這樣的API函數&#xf…

開源模型應用落地-食用指南-以最小成本博最大收獲

一、背景 時間飛逝&#xff0c;我首次撰寫的“開源大語言模型-實際應用落地”專欄已經完成了一半以上的內容。由衷感謝各位朋友的支持,希望這些內容能給正在學習的朋友們帶來一些幫助。 在這里&#xff0c;我想分享一下創作這個專欄的初心以及如何有效的&#xff0c;循序漸進的…

STM32F103C8T6 HC-SR04超聲波模塊——超聲波障礙物測距(HAl庫)

超聲波障礙物測距 一、HC-SR04超聲波模塊&#xff08;一&#xff09;什么是HC-SR04&#xff1f;&#xff08;二&#xff09;HC-SR04工作原理&#xff08;三&#xff09;如何使用HC-SR04&#xff08;四&#xff09;注意事項 二、程序編寫&#xff08;一&#xff09;CubeMX配置1.…