pytorch 中 drop_last與 nn.Parameter

1. drop_last

在使用深度學習,pytorch 的DataLoader 中,

from torch.utils.data import DataLoader# Define your dataset and other necessary configurations
# Create DataLoader
train_loader = DataLoader(dataset, batch_size=batch_size, drop_last=True)

drop_last=True :DataLoader 中的此設置會刪除不完整的最后一批(如果它小于指定的批量大小)。這確保了訓練期間處理的每個批次包含相同數量的樣本。

1.1 drop_last = True

dataset_size = 100
batch_size = 32
train_loader = DataLoader(dataset, batch_size=batch_size, drop_last=True)

使用 drop_last=True ,DataLoader 確保每個批次包含 32 個樣本,刪除不完整的最終批次。例如,在這種情況下,訓練期間將處理 3 個批次(32、32、32),其余 4 個樣本將不會用于訓練。

適用情況:

當網絡模型的初始化中,需要用到batch size 時, 這種情況下, 需要注意的是此時, drop_last = False , 會影響網絡模型結構, 由于模型的初始化過程中,使用了batch size 參數, 所有此時應該設置為 True;

1.2 drop_last = False

而當 drop_last = False, 當最后一個批次中, 剩余的樣本個數不足 batch 樣本數目時, 會保留這剩余的樣本,使用剩余的樣本進行訓練。

當數據不均衡, 并且某一類中樣本數量很少時, 此時 drop_last = True 會嚴重影響到模型的精度,此時應該使用 False;

原因是,本身的某個類別中訓練集和測試集的數量就已經小于batch size 時, 此時使用 drop last, 會嚴重該類別的訓練和測試效果。

如下面的情況:

遇到了這樣的問題。一共16類,第15 16類的訓練集數量是15、15,測試集分別為14、5。其他1-14類訓練集分別有50個,測試集均為200左右。

當我在pytorch的dataloader中設置了drop_last=True時,無論怎么訓練,使用怎么樣的數據增強,第15 16類才測試集上的準確率永遠為0.

原因分析:
當dataloader設置了drop_last=True時,在訓練時如果數據總量無法整除batch_size,那么這個dataloader就會丟掉最后一個batch,也就是說訓練的時候有部分數據是被丟掉的。而我遇到的情況可能是正好把第15 16類的測試數據給丟掉了部分,導致模型很好的學習到這兩類的特征。

解決方案:
將drop_last改為False,即可解決該問題。

2. nn.Parameter

在深度學習訓練過程中, 通常需要自己創建出一個初始化的張量, 并且希望通過模型訓練過程中, 更新該張量。

torch.randn(bt, 3, 256)

而普通的使用torch 隨機初始化的方式,如上面的這種方式,
在大多數情況下,隨機初始化張量不會使其參數變得可學習。在沒有任何相關學習過程或梯度更新的情況下隨機初始化的張量在網絡訓練期間不會適應或改變。

2.1 可學習參數

為了使得創建的張量,在網絡訓練過程中,可以得到更新。

在 PyTorch 中, nn.Parameter 是一個繼承自 torch.Tensor 的類。它允許您向框架指示該張量應被視為模型參數的一部分。當您將其分配為 nn.Module 中的屬性時,它在優化過程中變得可訓練。

import torch
import torch.nn as nn# Creating a tensor as a learnable parameter
param_tensor = nn.Parameter(torch.randn(1, 3))

param_tensor 將在訓練過程中進行優化因為它們被視為模型可學習參數的一部分。

放到 cuda 設備上

 self.cuda_param = nn.Parameter(torch.randn(1, 2).cuda())

2.2 nn.ParameterList

同樣, 當想創建一個列表都是可學習的參數時, 使用如下的方式;

self.parameters = nn.ParameterList([nn.Parameter(torch.randn(256)) for _ in range(5)])

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/213710.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/213710.shtml
英文地址,請注明出處:http://en.pswp.cn/news/213710.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

vue項目列表跳轉詳情返回列表頁保留搜索條件

需求 列表進入詳情后,返回詳情的時候保留搜索的條件,第幾頁進入的返回還在第幾頁 1.在詳情頁設置定義一個字段 mounted() {sessionStorage.setItem("msgInfo", true);},2.在獲取列表數據的時候在mounted里面判斷定義的字段 if (sessionStor…

【EI會議征稿】第二屆純數學、應用數學與計算數學國際學術會議(PACM 2024)

第二屆純數學、應用數學與計算數學國際學術會議(PACM 2024) 2024 2nd International Cnference on Pure, Applied and Computational Mathematics (PACM 2024) 第二屆純數學、應用數學計算數學國際學術會議 (PACM2024) 將于2024年1月19-21日在中國廈門隆…

報錯:AttributeError: ‘DataFrame‘ object has no attribute ‘reshape‘

這個錯誤通常發生在你試圖在 Pandas DataFrame 上直接使用 reshape 方法時。reshape 方法通常與 NumPy 數組相關聯,而不是 Pandas DataFrame。 如果你正在使用 Pandas DataFrame 并希望重新塑造它,你應該使用 Pandas 的重塑函數,如 pivot、m…

linux常用命令大全50個Linux常用命令

Linux有許多常用的命令,這些命令可以用來管理文件、運行程序、查看系統狀態等。以下是一些常用的Linux命令: pwd:顯示當前所在的工作目錄的全路徑名稱。cd:用于更改當前工作目錄,例如,若要進入Documents目…

UE5 樹葉飄落 學習筆記

一個Plane是由兩個三角形構成的,所以World Position Offset,只會從中間這條線折疊 所有材質 這里前幾篇博客有說這種邏輯,就是做一個對稱的漸變數值 這里用粒子的A值來做樹葉折疊的程度,當然你也可以用Dynamic Param 這樣就可以讓…

Android 11.0 長按按鍵切換SIM卡默認移動數據

Android 11.0 長按按鍵切換SIM卡默認移動數據 近來收到客戶需求想要通過長按按鍵實現切換SIM卡默認移動數據的功能,該功能主要通過長按按鍵發送廣播來實現,具體修改參照如下: 首先創建廣播,具體修改參照如下: /vend…

麒麟KYLINOS上刪除多余有線連接

原文鏈接:麒麟KYLINOS上刪除多余網絡有線連接 hello,大家好啊,今天我要給大家介紹的是在麒麟KYLINOS操作系統中,如何刪除通過Parallels Desktop虛擬機安裝時產生的多余有線連接。在使用Parallels Desktop虛擬機安裝麒麟桌面操作系…

C/C++ 題目:給定字符串s1和s2,判斷s1是否是s2的子序列

判斷子序列一個字符串是否是另一個字符串的子序列 解釋:字符串的一個子序列是原始字符串刪除一些(也可以不刪除)字符,不改變剩余字符相對位置形成的新字符串。 如,"ace"是"abcde"的一個子序…

服務器數據恢復—raid5少盤狀態下新建raid5如何恢復原raid5數據?

服務器數據恢復環境: 一臺服務器上搭建了一組由5塊硬盤組建的raid5陣列,服務器上層存放單位重要數據,無備份文件。 服務器故障&分析: 服務器上raid5有一塊硬盤掉線,外聘運維人員在沒有了解服務器具體情況下&#x…

如何在linux中使用rpm管理軟件

本章主要介紹使用rpm對軟件包進行管理。 使用rpm查詢軟件的信息 使用rpm安裝及卸載軟件 使用rpm對軟件進行更新 使用rpm對軟件進行驗證 rpm 全稱是redhat package manager,后來改成rpm package manager,這是根據源 碼包編譯出來的包。先從光盤中拷貝一…

[算法每日一練]-雙指針 (保姆級教程篇 1) #A-B數對 #求和 #元音字母 #最短連續子數組 #無重復字符的最長子串 #最小子串覆蓋 #方塊桶

目錄 A-B數對 解法一:雙指針 解法二:STL二分查找 解法三:map 求和 元音字母 最短連續子數組 無重復字符的最長子串 最小子串覆蓋 方塊桶 雙指針特點:雙指針絕不回頭 A-B數對 解法一:雙指針 先把數列排列成…

《C++新經典設計模式》之第8章 外觀模式

《C新經典設計模式》之第8章 外觀模式 外觀模式.cpp 外觀模式.cpp #include <iostream> #include <memory> using namespace std;// 中間層角色&#xff0c;隔離接口&#xff0c;兩部分模塊通過中間層打交道 // 提供簡單接口&#xff0c;不與底層直接打交道 // 提…

Grounding DINO、TAG2TEXT、RAM、RAM++論文解讀

提示&#xff1a;Grounding DINO、TAG2TEXT、RAM、RAM論文解讀 文章目錄 前言一、Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection1、摘要2、背景3、部分文獻翻譯4、貢獻5、模型結構解讀a.模型整體結構b.特征增強結構c.解碼結構 6、實…

使用Sourcetrail解析C項目

閱讀源碼的工具很多&#xff0c;今天給大家推薦一款別具一格的源碼閱讀神器。 它就是 Sourcetrail&#xff0c;一個免費開源、跨平臺的可視化源碼探索項目 使用

釋放深度學習的力量:使用 CUDA 和 Turing GPU 構建 AI

深度學習是一種人工智能的分支,它使用神經網絡模擬人類大腦的學習過程,從大量的數據中學習特征和規律。深度學習已經徹底改變了無數領域,從圖像和語音識別到自然語言處理和自動駕駛汽車。但是,要充分利用深度學習的強大功能,需要強大的工具,而 NVIDIA 的 Turing GPU 就是…

Faster R-CNN pytorch源碼血細胞檢測實戰(二)數據增強

Faster R-CNN pytorch源碼血細胞檢測實戰&#xff08;二&#xff09;數據增強 文章目錄 Faster R-CNN pytorch源碼血細胞檢測實戰&#xff08;二&#xff09;數據增強1. 資源&參考2. 數據增強2.1 代碼運行2.2 文件存放 3 數據集劃分4. 訓練&測試5. 總結 1. 資源&參…

靜態SOCKS5的未來發展趨勢和新興應用場景

隨著網絡技術的不斷發展和進步&#xff0c;靜態SOCKS5代理也在不斷地完善和發展。未來&#xff0c;靜態SOCKS5代理將會呈現以下發展趨勢和新興應用場景。 一、發展趨勢 安全性更高&#xff1a;隨著網絡安全問題的日益突出&#xff0c;用戶對代理服務器的安全性要求也越來越高…

AcWing 3425:小白鼠排隊 ← 北京大學考研機試題

【題目來源】https://www.acwing.com/problem/content/3428/【題目描述】 N 只小白鼠&#xff0c;每只鼠頭上戴著一頂有顏色的帽子。 現在稱出每只白鼠的重量&#xff0c;要求按照白鼠重量從大到小的順序輸出它們頭上帽子的顏色。 帽子的顏色用 red&#xff0c;blue 等字符串來…

c#下載微信跟支付寶交易賬單

下載微信交易賬單 //賬單日期只能下載前一天的string datetime DateTime.Now.AddDays(-1).ToString("yyyy-MM-dd");string body "";string URL "/v3/bill/fundflowbill" "?bill_date" datetime;//生成簽名認證var auth BuildAu…

nodejs 異步函數加 await 和不加 await 的區別

在 nodejs 中&#xff0c;異步函數加上 await 和不加 await 的區別在于函數的返回值。 當一個異步函數加上 await 時&#xff0c;它會暫停當前函數的執行&#xff0c;直到異步操作完成并返回結果。這意味著可以直接使用異步操作的結果&#xff0c;而不需要使用 .then() 方法或…