Python訓練Day39

@浙大疏錦行

  1. 圖像數據的格式:灰度和彩色數據
  2. 模型的定義
  3. 顯存占用的4種地方
    1. 模型參數+梯度參數
    2. 優化器參數
    3. 數據批量所占顯存
    4. 神經元輸出中間狀態
  4. batchisize和訓練的關系

一、 圖像數據的介紹

? ? 圖像數據,相較于結構化數據(表格數據)他的特點在于他每個樣本的的形狀并不是(特征數,),而是(寬,高,通道數)

? ? 結構化數據(如表格)的形狀通常是 (樣本數, 特征數),例如 (1000, 5) 表示 1000 個樣本,每個樣本有 5 個特征。圖像數據的形狀更復雜,需要保留空間信息(高度、寬度、通道),因此不能直接用一維向量表示。其中顏色信息往往是最開始輸入數據的通道的含義,因為每個顏色可以用紅綠藍三原色表示,因此一般輸入數據的通道數是 3。? ?

1.1 灰度圖像

# 隨機選擇一張圖片,可以重復運行,每次都會隨機選擇
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 隨機選擇一張圖片的索引
# len(train_dataset) 表示訓練集的圖片數量;size=(1,)表示返回一個索引;torch.randint() 函數用于生成一個指定范圍內的隨機數,item() 方法將張量轉換為 Python 數字
image, label = train_dataset[sample_idx] # 獲取圖片和標簽
# 可視化原始圖像(需要反歸一化)
def imshow(img):img = img * 0.3081 + 0.1307  # 反標準化npimg = img.numpy()plt.imshow(npimg[0], cmap='gray') # 顯示灰度圖像plt.show()print(f"Label: {label}")
imshow(image)

? ? MNIST 數據集是手寫數字的 灰度圖像,每個像素點的取值范圍為 0-255(黑白程度),因此 通道數為 1。圖像尺寸統一為 28×28 像素。

1.2 彩色圖像

? ? 在 PyTorch 中,圖像數據的形狀通常遵循 (通道數, 高度, 寬度) 的格式(即 Channel First 格式),這與常見的 (高度, 寬度, 通道數)(Channel Last,如 NumPy 數組)不同。---注意順序關系,

注意點:

1. 如果用matplotlib庫來畫圖,需要轉換下順序,我們后續介紹

2. 模型輸入通常需要 批次維度(Batch Size),形狀變為 (批次大小, 通道數, 高度, 寬度)。例如,批量輸入 10 張 MNIST 圖像時,形狀為 (10, 1, 28, 28)。

# 打印一張彩色圖像,用cifar-10數據集
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np# 設置隨機種子確保結果可復現
torch.manual_seed(42)
# 定義數據預處理步驟
transform = transforms.Compose([transforms.ToTensor(),  # 轉換為張量并歸一化到[0,1]transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 標準化處理
])# 加載CIFAR-10訓練集
trainset = torchvision.datasets.CIFAR10(root='./data',train=True,download=True,transform=transform
)# 創建數據加載器
trainloader = torch.utils.data.DataLoader(trainset,batch_size=4,shuffle=True
)# CIFAR-10的10個類別
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')# 隨機選擇一張圖片
sample_idx = torch.randint(0, len(trainset), size=(1,)).item()
image, label = trainset[sample_idx]# 打印圖片形狀
print(f"圖像形狀: {image.shape}")  # 輸出: torch.Size([3, 32, 32])
print(f"圖像類別: {classes[label]}")# 定義圖像顯示函數(適用于CIFAR-10彩色圖像)
def imshow(img):img = img / 2 + 0.5  # 反標準化處理,將圖像范圍從[-1,1]轉回[0,1]npimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))  # 調整維度順序:(通道,高,寬) → (高,寬,通道)plt.axis('off')  # 關閉坐標軸顯示plt.show()# 顯示圖像
imshow(image)

二、 圖像相關的神經網絡的定義

2.1 黑白圖像模型的定義

# 先歸一化,再標準化
transform = transforms.Compose([transforms.ToTensor(),  # 轉換為張量并歸一化到[0,1]transforms.Normalize((0.1307,), (0.3081,))  # MNIST數據集的均值和標準差,這個值很出名,所以直接使用
])
import matplotlib.pyplot as plt# 2. 加載MNIST數據集,如果沒有會自動下載
train_dataset = datasets.MNIST(root='./data',train=True,download=True,transform=transform
)test_dataset = datasets.MNIST(root='./data',train=False,transform=transform
)
# 定義兩層MLP神經網絡
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.flatten = nn.Flatten()  # 將28x28的圖像展平為784維向量self.layer1 = nn.Linear(784, 128)  # 第一層:784個輸入,128個神經元self.relu = nn.ReLU()  # 激活函數self.layer2 = nn.Linear(128, 10)  # 第二層:128個輸入,10個輸出(對應10個數字類別)def forward(self, x):x = self.flatten(x)  # 展平圖像x = self.layer1(x)   # 第一層線性變換x = self.relu(x)     # 應用ReLU激活函數x = self.layer2(x)   # 第二層線性變換,輸出logitsreturn x# 初始化模型
model = MLP()device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)  # 將模型移至GPU(如果可用)from torchsummary import summary  # 導入torchsummary庫
print("\n模型結構信息:")
summary(model, input_size=(1, 28, 28))  # 輸入尺寸為MNIST圖像尺寸

我們關注和之前結構化MLP的差異

1. 輸入需要展平操作

? ? MLP 的輸入層要求輸入是一維向量,但 MNIST 圖像是二維結構(28×28 像素),形狀為 [1, 28, 28](通道 × 高 × 寬)。nn.Flatten()展平操作 將二維圖像 “拉成” 一維向量(784=28×28 個元素),使其符合全連接層的輸入格式。

? ? 其中不定義這個flatten方法,直接在前向傳播的過程中用 x = x.view(-1, 28 * 28) 將圖像展平為一維向量也可以實現

2. 輸入數據的尺寸包含了通道數input_size=(1, 28, 28)

3. 參數的計算

  • 第一層 layer1(全連接層)

權重參數:輸入維度 × 輸出維度 = 784 × 128 = 100,352

偏置參數:輸出維度 = 128

合計:100,352 + 128 = 100,480

  • 第二層 layer2(全連接層)

權重參數:輸入維度 × 輸出維度 = 128 × 10 = 1,280

偏置參數:輸出維度 = 10

合計:1,280 + 10 = 1,290

  • 總參數:100,480(layer1) + 1,290(layer2) = 101,770

2.2 彩色圖像模型的定義

class MLP(nn.Module):def __init__(self, input_size=3072, hidden_size=128, num_classes=10):super(MLP, self).__init__()# 展平層:將3×32×32的彩色圖像轉為一維向量# 輸入尺寸計算:3通道 × 32高 × 32寬 = 3072self.flatten = nn.Flatten()# 全連接層self.fc1 = nn.Linear(input_size, hidden_size)  # 第一層self.relu = nn.ReLU()self.fc2 = nn.Linear(hidden_size, num_classes)  # 輸出層def forward(self, x):x = self.flatten(x)  # 展平:[batch, 3, 32, 32] → [batch, 3072]x = self.fc1(x)      # 線性變換:[batch, 3072] → [batch, 128]x = self.relu(x)     # 激活函數x = self.fc2(x)      # 輸出層:[batch, 128] → [batch, 10]return x# 初始化模型
model = MLP()device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)  # 將模型移至GPU(如果可用)from torchsummary import summary  # 導入torchsummary庫
print("\n模型結構信息:")
summary(model, input_size=(3, 32, 32))  # CIFAR-10 彩色圖像(3×32×32)
  • ?第一層 layer1(全連接層)

權重參數:輸入維度 × 輸出維度 = 3072 × 128 = 393,216

偏置參數:輸出維度 = 128

合計:393,216 + 128 = 393,344

  • -第二層 layer2(全連接層)

權重參數:輸入維度 × 輸出維度 = 128 × 10 = 1,280

偏置參數:輸出維度 = 10

合計:1,280 + 10 = 1,290

  • ?總參數:393,344(layer1) + 1,290(layer2) = 394,634

?2.3 模型定義與batchsize的關系

? ? 實際定義中,輸入圖像還存在batchsize這一維度。在 PyTorch 中,模型定義和輸入尺寸的指定不依賴于 batch_size,無論設置多大的 batch_size,模型結構和輸入尺寸的寫法都是不變的。

class MLP(nn.Module):def __init__(self):super().__init__()self.flatten = nn.Flatten() # nn.Flatten()會將每個樣本的圖像展平為 784 維向量,但保留 batch 維度。self.layer1 = nn.Linear(784, 128)self.relu = nn.ReLU()self.layer2 = nn.Linear(128, 10)def forward(self, x):x = self.flatten(x)  # 輸入:[batch_size, 1, 28, 28] → [batch_size, 784]x = self.layer1(x)   # [batch_size, 784] → [batch_size, 128]x = self.relu(x)x = self.layer2(x)   # [batch_size, 128] → [batch_size, 10]return x

? ? PyTorch 模型會自動處理 batch 維度(即第一維),無論 batch_size 是多少,模型的計算邏輯都不變。batch_size 是在數據加載階段定義的,與模型結構無關。

? ? summary(model, input_size=(1, 28, 28))中的input_size不包含 batch 維度,只需指定樣本的形狀(通道 × 高 × 寬)。

三、顯存占用的主要組成部分

? ? 昨天說到了在面對數據集過大的情況下,由于無法一次性將數據全部加入到顯存中,所以采取了分批次加載這種方式。即一次只加載一部分數據,保證在顯存的范圍內。

? ? 那么顯存設置多少合適呢?如果設置的太小,那么每個batchsize的訓練不足以發揮顯卡的能力,浪費計算資源;如果設置的太大,會出現OOT(out of memory)

顯存一般被以下內容占用:

1. 模型參數與梯度:模型的權重(Parameters)和對應的梯度(Gradients)會占用顯存,尤其是深度神經網絡(如 Transformer、ResNet 等),一個 1 億參數的模型(如 BERT-base),單精度(float32)參數占用約 400MB(1e8×4Byte),加上梯度則翻倍至 800MB(每個權重參數都有其對應的梯度)。

2. 部分優化器(如 Adam)會為每個參數存儲動量(Momentum)和平方梯度(Square Gradient),進一步增加顯存占用(通常為參數大小的 2-3 倍)

3. 其他開銷。

from torch.utils.data import DataLoader# 定義訓練集的數據加載器,并指定batch_size
train_loader = DataLoader(dataset=train_dataset,  # 加載的數據集batch_size=64,          # 每次加載64張圖像shuffle=True            # 訓練時打亂數據順序
)# 定義測試集的數據加載器(通常batch_size更大,減少測試時間)
test_loader = DataLoader(dataset=test_dataset,batch_size=1000,shuffle=False
)

3.1 模型參數與梯度(FP32 精度)

  • 1字節(Byte)= 8位(bit),是計算機存儲的最小尋址單位。 ?
  • 位(bit)是二進制數的最小單位(0或1),例如`0b1010`表示4位二進制數。
  • 1KB=1024字節;1MB=1024KB=1,048,576字節

3.2 優化器狀態

? SGD

  • SGD優化器**不存儲額外動量**,因此無額外顯存占用。 ?
  • SGD 隨機梯度下降,最基礎的優化器,直接沿梯度反方向更新參數。
  • 參數更新公式:w = w - learning_rate * gradient

?Adam

  • Adam優化器:自適應學習率優化器,結合了動量(Momentum)和梯度平方的指數移動平均。 ?
  • 每個參數存儲動量(m)和平方梯度(v),占用約 `101,770 × 8 Byte ≈ 806 KB` ?
  • 動量(m):每個參數對應一個動量值,數據類型與參數相同(float32),占用 403 KB。
  • 梯度平方(v):每個參數對應一個梯度平方值,數據類型與參數相同(float32),占用 403 KB。

3.3.數據批量(batch_size)的顯存占用

  • 單張圖像尺寸:`1×28×28`(通道×高×寬),歸一化轉換為張量后為`float32`類型 ?

? ? ? ? ? 單張圖像顯存占用:`1×28×28×4 Byte = 3,136 Byte ≈ 3 KB` ?

  • 批量數據占用:`batch_size × 單張圖像占用` ?

? ? ? ? ? 例如:`batch_size=64` 時,數據占用為 `64×3 KB ≈ 192 KB` ?

? ? ? ? ?`batch_size=1024` 時,數據占用為 `1024×3 KB ≈ 3 MB`

3.4. 前向/反向傳播中間變量

  • 對于兩層MLP,中間變量(如`layer1`的輸出)占用較小: ?

? - `batch_size×128`維向量:`batch_size×128×4 Byte = batch_size×512 Byte` ?

? - 例如`batch_size=1024`時,中間變量約 `512 KB`

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/92730.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/92730.shtml
英文地址,請注明出處:http://en.pswp.cn/web/92730.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

十八、MySQL-DML-數據操作-插入(增加)、更新(修改)、刪除

DML數據操作添加數據更新(修改)數據刪除數據總結代碼: -- DML:數據操作語言-- -- DML:插入數據-insert -- 1.為tb_emp表的username,name,gender 字股插入值insert into tb_emp(username,name,gender,create_time,update_time) values (Toki,小時,2,now()…

Linux 安裝 JDK 8u291 教程(jdk-8u291-linux-x64.tar.gz 解壓配置詳細步驟)?

一、準備工作 ?下載 JDK 安裝包? 去 Oracle 官網或者可信的鏡像站下載: ?jdk-8u291-linux-x64.tar.gz? (這是一個壓縮包,不是安裝程序,解壓就能用) ?jdk-8u291-linux-x64.tar.gz?下載鏈接:https://pa…

藍橋杯----鎖存器、LED、蜂鳴器、繼電器、Motor

(七)、鎖存器1、原理藍橋杯中數據傳入口都是P0,也就是數碼管段選、位選數據、LED亮滅的數據、蜂鳴器啟動或禁用的數據,外設啟動或者關閉都需要通過P0寫入數據,那么如何這樣共用一個端口會造成沖突嘛,答案是肯定的。所以藍橋杯加入…

AI熱點周報(8.3~8.9):OpenAI重返開源,Anthropic放大招,Claude4.1、GPT5相繼發布

名人說:博觀而約取,厚積而薄發。——蘇軾《稼說送張琥》 創作者:Code_流蘇(CSDN)(一個喜歡古詩詞和編程的Coder😊) 目錄一、OpenAI的"開源回歸":時隔5年的戰略大轉彎1. GPT-OSS系列&a…

《Kubernetes部署篇:基于x86_64+aarch64架構CPU+containerd一鍵離線部署容器版K8S1.33.3高可用集群》

總結:整理不易,如果對你有幫助,可否點贊關注一下? 更多詳細內容請參考:企業級K8s集群運維實戰 一、部署背景 由于業務系統的特殊性,我們需要針對不同的客戶環境部署基于containerd容器版 K8S 1.33.3集群&a…

Linux抓包命令tcpdump詳解筆記

文章目錄一、tcpdump 是什么?二、基本語法三、常用參數說明四、抓包示例(通俗易懂)1. 抓所有數據包(默認 eth0)2. 指定接口抓包3. 抓取端口 80 的數據包(即 HTTP 請求)4. 抓取訪問某個 IP 的數據…

抖音、快手、視頻號等多平臺視頻解析下載 + 磁力嗅探下載、視頻加工(提取音頻 / 壓縮等)

跟你們說個安卓上的下載工具,還挺厲害的。它能支持好多種下載方式,具體多少種我沒細數,反正挺全乎的。? 平時用得最多的就是視頻解析,像抖音、快手、B 站上那些視頻,想存下來直接用它就行,連海外視頻的也能…

【iOS】JSONModel源碼學習

JSONModel源碼學習前言JSONModel的使用最基礎的使用轉換屬性名稱自定義錯誤模型嵌套JSONModel的繼承源碼實現initWithDictionaryinit__doesDictionaryimportDictionary優點前言 之前了解過JSONModel的一些使用方法等,但是對于底層實現并不清楚了解,今天…

SmartMediaKit 模塊化音視頻框架實戰指南:場景鏈路 + 能力矩陣全解析

?? 引言:從“內核能力”到“模塊體系”的演進 自 2015 年起,大牛直播SDK(SmartMediaKit)便致力于打造一個可深度嵌入、跨平臺兼容、模塊自由組合的實時音視頻基礎能力框架。經過多輪技術迭代與場景打磨,該 SDK 已覆…

【第5話:相機模型1】針孔相機、魚眼相機模型的介紹及其在自動駕駛中的作用及使用方法

相機模型介紹及相機模型在自動駕駛中的作用及使用方法 相機模型是計算機視覺中的核心概念,用于描述真實世界中的點如何投影到圖像平面上。在自動駕駛系統中,相機模型用于環境感知,如物體檢測和場景理解。下面我將詳細介紹針孔相機模型和魚眼相…

推薦一款優質的開源博客與內容管理系統

Halo是一款由Java Spring Boot打造的開源博客與內容管理系統(CMS),在 GitHub上擁有超過36K Start的活躍開發者社區。它使用GPL?3.0授權開源,穩定性與可維護性極高。 Halo的設計簡潔、注重性能,同時保持高度靈活性&a…

【GPT入門】第43課 使用LlamaFactory微調Llama3

【GPT入門】第43課 使用LlamaFactory微調Llama31.環境準備2. 下載基座模型3.LLaMA-Factory部署與啟動4. 重新訓練![在這里插入圖片描述](https://i-blog.csdnimg.cn/direct/e7aa869f8e2c4951a0983f0918e1b638.png)1.環境準備 采購autodl服務器,24G,GPU,型號3090&am…

計算機網絡:如何理解目的網絡不再是一個完整的分類網絡

這一理解主要源于無分類域間路由(CIDR)技術的廣泛應用,它打破了傳統的基于類的IP地址分配方式。具體可從以下方面理解: 傳統分類網絡的局限性:在早期互聯網中,IP地址被分為A、B、C等固定類別,每…

小米開源大模型 MiDashengLM-7B:不僅是“聽懂”,更能“理解”聲音

目錄 前言 一、一枚“重磅炸彈”:開源,意味著一扇大門的敞開 二、揭秘MiDashengLM-7B:它究竟“神”在哪里? 2.1 “超級耳朵” 與 “智慧大腦” 的協作 2.2 突破:從 “聽見文字” 到 “理解世界” 2.3 創新訓練&a…

mysql出現大量redolog、undolog排查以及解決方案

排查步驟 監控日志增長情況 -- 查看InnoDB狀態 SHOW ENGINE INNODB STATUS;-- 查看redo log配置和使用情況 SHOW VARIABLES LIKE innodb_log_file%; SHOW VARIABLES LIKE innodb_log_buffer_size;-- 查看undo log信息 SHOW VARIABLES LIKE innodb_undo%;檢查長時間運行的事務 -…

華為網路設備學習-28(BGP協議 三)路由策略

目錄: 一、BGP路由匯總1、注:使用network命令注入的BGP不會被自動匯總2、主類網絡號計算過程如下:3.示例 開啟BGP路由自動匯總bgp100 開啟BGP路由自動匯總import-route 直連路由 11.1.1.0 /24對端 為 10.1.12.2 AS 2004.手動配置BGP路…

微信小程序中實現表單數據實時驗證的方法

一、實時驗證的基本實現思路表單實時時驗證通過監聽表單元素的輸入事件,在用戶輸入過程中即時對數據進行校驗,并并即時反饋驗證結果,主要實現步驟包括:為每個表單字段綁定輸入事件在事件處理函數中獲取當前輸入值應用驗證規則進行…

openpnp - 頂部相機如果超過6.5米影響通訊質量,可以加USB3.0信號放大器延長線

文章目錄openpnp - 頂部相機如果超過6.5米影響通訊質量,可以加USB3.0信號放大器延長線概述備注ENDopenpnp - 頂部相機如果超過6.5米影響通訊質量,可以加USB3.0信號放大器延長線 概述 手頭有1080x720x60FPS的攝像頭模組備件,換上后&#xff…

【驅動】RK3576-Debian系統使用ping報錯:socket operation not permitted

1、問題描述 在RK3576-Debian系統中,連接了Wifi后,測試網絡通斷時,報錯: ping www.csdn.net ping: socktype: SOCK_RAW ping: socket: Operation not permitted ping: => missing cap_net_raw+p capability or setuid?2、原因分析 2.1 分析打印日志 socktype: SOCK…

opencv:圖像輪廓檢測與輪廓近似(附代碼)

目錄 圖像輪廓 cv2.findContours(img, mode, method) 繪制輪廓 輪廓特征與近似 輪廓特征 輪廓近似 輪廓近似原理 opencv 實現輪廓近似 輪廓外接矩形 輪廓外接圓 圖像輪廓 cv2.findContours(img, mode, method) mode:輪廓檢索模式(通常使用第四個模式&am…