深度學習訓練中的顯存溢出問題分析與優化:以UNet圖像去噪為例

最近在訓練一個基于 Tiny-UNet 的圖像去噪模型時,我遇到了經典但棘手的錯誤:
RuntimeError: CUDA out of memory。本文記錄了我如何從復現、分析,到逐步優化并成功解決該問題的全過程,希望對深度學習開發者有所借鑒。

  • 訓練數據:SIDD 小型圖像去噪數據集

  • 模型結構:簡化版 U-Net(Tiny-UNet)

  • class UNetDenoiser(nn.Module):def __init__(self):super(UNetDenoiser, self).__init__()# Encoderself.enc1 = self.conv_block(3, 64)self.enc2 = self.conv_block(64, 128)self.enc3 = self.conv_block(128, 256)self.pool = nn.MaxPool2d(2)# Bottleneckself.bottleneck = self.conv_block(256, 512)# Decoderself.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)self.dec3 = self.conv_block(512, 256)self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)self.dec2 = self.conv_block(256, 128)self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)self.dec1 = self.conv_block(128, 64)# Outputself.final = nn.Conv2d(64, 3, kernel_size=1)def conv_block(self, in_channels, out_channels):return nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, padding=1), nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, 3, padding=1), nn.ReLU(inplace=True))def forward(self, x):# Encodere1 = self.enc1(x)            # [B, 64, H, W]e2 = self.enc2(self.pool(e1))  # [B, 128, H/2, W/2]e3 = self.enc3(self.pool(e2))  # [B, 256, H/4, W/4]# Bottleneckb = self.bottleneck(self.pool(e3))  # [B, 512, H/8, W/8]# Decoderd3 = self.up3(b)           # [B, 256, H/4, W/4]d3 = self.dec3(torch.cat([d3, e3], dim=1))d2 = self.up2(d3)          # [B, 128, H/2, W/2]d2 = self.dec2(torch.cat([d2, e2], dim=1))d1 = self.up1(d2)          # [B, 64, H, W]d1 = self.dec1(torch.cat([d1, e1], dim=1))return self.final(d1)

    源代碼:
    ?

    # train_denoiser.py
    import os
    import math
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import Dataset, DataLoader
    from torchvision import transforms
    from torchvision.utils import save_image
    from PIL import Image# --- 數據集定義 ---
    class DenoisingDataset(Dataset):def __init__(self, noisy_dir, clean_dir, transform=None):self.noisy_paths = sorted([os.path.join(noisy_dir, f) for f in os.listdir(noisy_dir)])self.clean_paths = sorted([os.path.join(clean_dir, f) for f in os.listdir(clean_dir)])self.transform = transform if transform else transforms.ToTensor()def __len__(self):return len(self.noisy_paths)def __getitem__(self, idx):noisy_img = Image.open(self.noisy_paths[idx]).convert("RGB")clean_img = Image.open(self.clean_paths[idx]).convert("RGB")return self.transform(noisy_img), self.transform(clean_img)# --- 簡單 CNN 去噪模型 ---
    # class SimpleDenoiser(nn.Module):
    #     def __init__(self):
    #         super(SimpleDenoiser, self).__init__()
    #         self.encoder = nn.Sequential(
    #             nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(),
    #             nn.Conv2d(64, 64, 3, padding=1), nn.ReLU()
    #         )
    #         self.decoder = nn.Sequential(
    #             nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(),
    #             nn.Conv2d(64, 3, 3, padding=1)
    #         )
    #
    #     def forward(self, x):
    #         x = self.encoder(x)
    #         x = self.decoder(x)
    #         return x
    class UNetDenoiser(nn.Module):def __init__(self):super(UNetDenoiser, self).__init__()# Encoderself.enc1 = self.conv_block(3, 64)self.enc2 = self.conv_block(64, 128)self.enc3 = self.conv_block(128, 256)self.pool = nn.MaxPool2d(2)# Bottleneckself.bottleneck = self.conv_block(256, 512)# Decoderself.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)self.dec3 = self.conv_block(512, 256)self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)self.dec2 = self.conv_block(256, 128)self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)self.dec1 = self.conv_block(128, 64)# Outputself.final = nn.Conv2d(64, 3, kernel_size=1)def conv_block(self, in_channels, out_channels):return nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, padding=1), nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, 3, padding=1), nn.ReLU(inplace=True))def forward(self, x):# Encodere1 = self.enc1(x)            # [B, 64, H, W]e2 = self.enc2(self.pool(e1))  # [B, 128, H/2, W/2]e3 = self.enc3(self.pool(e2))  # [B, 256, H/4, W/4]# Bottleneckb = self.bottleneck(self.pool(e3))  # [B, 512, H/8, W/8]# Decoderd3 = self.up3(b)           # [B, 256, H/4, W/4]d3 = self.dec3(torch.cat([d3, e3], dim=1))d2 = self.up2(d3)          # [B, 128, H/2, W/2]d2 = self.dec2(torch.cat([d2, e2], dim=1))d1 = self.up1(d2)          # [B, 64, H, W]d1 = self.dec1(torch.cat([d1, e1], dim=1))return self.final(d1)# --- PSNR 計算函數 ---
    def calculate_psnr(img1, img2):mse = torch.mean((img1 - img2) ** 2)if mse == 0:return float("inf")return 20 * torch.log10(1.0 / torch.sqrt(mse))# --- 主訓練過程 ---
    def train_denoiser():noisy_dir = r"F:\SIDD數據集\archive\SIDD_Small_sRGB_Only\noisy"clean_dir = r"F:\SIDD數據集\archive\SIDD_Small_sRGB_Only\clean"batch_size = 1num_epochs = 50lr = 0.0005device = torch.device("cuda" if torch.cuda.is_available() else "cpu")dataset = DenoisingDataset(noisy_dir, clean_dir, transform=transforms.ToTensor())dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)# model = SimpleDenoiser().to(device)# 替換為 UNetmodel = UNetDenoiser().to(device)criterion = nn.MSELoss()optimizer = optim.Adam(model.parameters(), lr=lr)for epoch in range(num_epochs):model.train()total_loss = 0.0total_psnr = 0.0for noisy, clean in dataloader:noisy, clean = noisy.to(device), clean.to(device)denoised = model(noisy)loss = criterion(denoised, clean)optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item()total_psnr += calculate_psnr(denoised, clean).item()avg_loss = total_loss / len(dataloader)avg_psnr = total_psnr / len(dataloader)print(f"Epoch [{epoch+1}/{num_epochs}] - Loss: {avg_loss:.4f}, PSNR: {avg_psnr:.2f} dB")# 保存模型os.makedirs("weights", exist_ok=True)torch.save(model.state_dict(), "weights/denoiser.pth")print("模型已保存為 weights/denoiser.pth")if __name__ == "__main__":train_denoiser()
    

  • 顯卡:8GB 顯存的 RTX GPU

問題定位

我們從報錯堆棧中看到:

e3 = self.enc3(self.pool(e2))
RuntimeError: CUDA out of memory. Tried to allocate 746.00 MiB

說明問題發生在模型第三層 encoder(enc3)前的 pooling 后,這說明:

  1. 當前的輸入尺寸、batch size 占用了太多顯存;

  2. 或者模型本身結構太重;

  3. 又或者顯存未被合理管理(例如碎片化)。

分析與優化過程

第一步:降低 batch size

原始 batch size 設置為 16,直接觸發爆顯存。

我們嘗試逐步調小 batch size:

batch_size = 6  # 從16降低到6

觀察顯存變化,發現仍有波動。為更穩定,設置為 4 或動態適配:

batch_size = min(8, torch.cuda.get_device_properties(0).total_memory // estimated_sample_size)

?發現同樣的錯誤,顯存不知。分析可能是網絡參數太大了,或者訓練過程沒有啟動內存優化。導致的內存不足,這些可以通過策略進行改進,達到訓練的目的。

第二步:開啟 cuDNN 自動優化

torch.backends.cudnn.benchmark = True

cuDNN 會根據不同卷積輸入尺寸自動尋找最優算法,可能減少顯存使用。

第三步:開啟混合精度訓練 AMP(Automatic Mixed Precision)

from torch.cuda.amp import autocast, GradScalerscaler = GradScaler()with autocast():output = model(input)loss = criterion(output, target)scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
  • autocast() 自動在部分層使用 float16,提高速度并減小顯存壓力;

  • GradScaler 確保在 float16 條件下梯度依然穩定。

實測顯存使用降低近 30%,OOM 問題明顯緩解!

但以上訓練的預加載時間太慢,顯卡占有率過低,有點顯卡當前沒有任務----“偷懶”的意思。可能是數據的加載或者顯存抖動造成的。

第四步:優化 DataLoader 性能(間接緩解顯存抖動)

 dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,num_workers=num_workers, pin_memory=True)
  • num_workers?啟用多進程加載數據;

  • pin_memory=True 啟用固定內存,更快傳輸到 GPU。

雖然不直接節省顯存,但顯著減少顯存峰值抖動(尤其在小 batch 訓練時)。

第五步:檢查圖像輸入尺寸是否太大

原始圖像尺寸為 512×512:

transform = transforms.Compose([transforms.Resize((256, 256)),  # 降低分辨率transforms.ToTensor()
])

最終訓練代碼結構

我們將上述策略集成到了 train.py 腳本中(如下),包括:

  • Dataset & Dataloader 加速

  • 混合精度訓練

  • cuDNN 優化

  • 實時 PSNR 顯示

  • 自動保存模型權重

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm# --- 啟用 cuDNN 自動優化 ---
torch.backends.cudnn.benchmark = True# --- 數據集定義 ---
class DenoisingDataset(Dataset):def __init__(self, noisy_dir, clean_dir, transform=None):self.noisy_paths = sorted([os.path.join(noisy_dir, f) for f in os.listdir(noisy_dir)])self.clean_paths = sorted([os.path.join(clean_dir, f) for f in os.listdir(clean_dir)])self.transform = transform if transform else transforms.Compose([transforms.Resize((256, 256)),transforms.ToTensor()])def __len__(self):return len(self.noisy_paths)def __getitem__(self, idx):noisy_img = Image.open(self.noisy_paths[idx]).convert("RGB")clean_img = Image.open(self.clean_paths[idx]).convert("RGB")return self.transform(noisy_img), self.transform(clean_img)# --- Tiny UNet 模型 ---
class TinyUNet(nn.Module):def __init__(self):super(TinyUNet, self).__init__()self.enc1 = self.conv_block(3, 16)self.enc2 = self.conv_block(16, 32)self.enc3 = self.conv_block(32, 64)self.pool = nn.MaxPool2d(2)self.bottleneck = self.conv_block(64, 128)self.up3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)self.dec3 = self.conv_block(128, 64)self.up2 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)self.dec2 = self.conv_block(64, 32)self.up1 = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)self.dec1 = self.conv_block(32, 16)self.final = nn.Conv2d(16, 3, kernel_size=1)def conv_block(self, in_channels, out_channels):return nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, padding=1), nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, 3, padding=1), nn.ReLU(inplace=True))def forward(self, x):e1 = self.enc1(x)e2 = self.enc2(self.pool(e1))e3 = self.enc3(self.pool(e2))b = self.bottleneck(self.pool(e3))d3 = self.up3(b)d3 = self.dec3(torch.cat([d3, e3], dim=1))d2 = self.up2(d3)d2 = self.dec2(torch.cat([d2, e2], dim=1))d1 = self.up1(d2)d1 = self.dec1(torch.cat([d1, e1], dim=1))return self.final(d1)# --- PSNR 計算 ---
def calculate_psnr(img1, img2):mse = torch.mean((img1 - img2) ** 2)if mse == 0:return float("inf")return 20 * torch.log10(1.0 / torch.sqrt(mse))# --- 訓練函數 ---
def train_denoiser():noisy_dir = r"F:\SIDD數據集\archive\SIDD_Small_sRGB_Only\noisy"clean_dir = r"F:\SIDD數據集\archive\SIDD_Small_sRGB_Only\clean"batch_size = 6num_epochs = 50lr = 0.0005num_workers = 4device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(f"Using device: {device}")transform = transforms.Compose([transforms.Resize((256, 256)),transforms.ToTensor()])dataset = DenoisingDataset(noisy_dir, clean_dir, transform=transform)dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,num_workers=num_workers, pin_memory=True)model = TinyUNet().to(device)criterion = nn.MSELoss()optimizer = optim.Adam(model.parameters(), lr=lr)scaler = GradScaler()  # AMP 梯度縮放器os.makedirs("weights", exist_ok=True)for epoch in range(num_epochs):model.train()total_loss = 0.0total_psnr = 0.0for noisy, clean in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):noisy = noisy.to(device, non_blocking=True)clean = clean.to(device, non_blocking=True)optimizer.zero_grad()with autocast():  # 混合精度推理denoised = model(noisy)loss = criterion(denoised, clean)scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()total_loss += loss.item()total_psnr += calculate_psnr(denoised.detach(), clean).item()avg_loss = total_loss / len(dataloader)avg_psnr = total_psnr / len(dataloader)print(f"? Epoch [{epoch+1}/{num_epochs}] - Loss: {avg_loss:.4f}, PSNR: {avg_psnr:.2f} dB")torch.save(model.state_dict(), f"weights/tiny_unet_epoch{epoch+1}.pth")print("🎉 模型訓練完成,所有權重已保存至 weights/ 目錄")if __name__ == "__main__":train_denoiser()

最后得到的訓練文件,這里我設置的50次訓練迭代:

測試模型的推理效果

?原去帶噪聲圖片:

去噪后(可以看到這里仍然有bug,肉眼看效果并不是很好,需要進一步優化,考慮到模型的泛化性):

總結:處理 CUDA OOM 的思路模板

  1. 先查 batch size,這是最常見爆顯存原因;

  2. 確認輸入尺寸是否太大或未 resize;

  3. 啟用 AMP,簡單又高效;

  4. 合理設計模型結構(Tiny UNet > ResUNet);

  5. 使用 Dataloader 加速,避免數據傳輸抖動;

  6. 手動清理緩存防止 PyTorch 持有多余內存;

  7. 查看 PyTorch 顯存使用報告,加上:

print(torch.cuda.memory_summary())

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/76764.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/76764.shtml
英文地址,請注明出處:http://en.pswp.cn/web/76764.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

FramePack V2版 - 支持首尾幀生成,支持LoRA,支持批量,支持50系顯卡,一個強大的AI視頻生成軟件 本地一鍵整合包下載

FramePack 是斯坦福大學主導開發的視頻生成框架,是一種用于視頻生成的下一幀(下一幀部分)預測神經網絡結構,可以逐步生成視頻。FramePack 主要開發者之一,就是業內大名鼎鼎的張呂敏大佬,AI領域的“賽博佛祖…

STM32 HAL 通用定時器延時函數

使用通用定時器TIM3,實現ms、us延時。 delay.c #include "delay.h" #include "stm32f1xx_hal.h"TIM_HandleTypeDef htim3;/*** brief 初始化定時器3用于延時* param 無* retval 無*/ void Delay_Init(void) {TIM_ClockConfigTypeDef sClock…

軟件功能測試和非功能測試有什么區別和聯系?

軟件測試是保障軟件質量的核心環節,而軟件功能測試和非功能測試作為測試領域的兩大重要組成部分,承擔著不同但又相互關聯的職責。 軟件功能測試指的是通過驗證軟件系統的各項功能是否按照需求規格說明書來正確實現,確保軟件的功能和業務流程…

使用Java調用TensorFlow與PyTorch模型:DJL框架的應用探索

在現代機器學習的應用場景中,Python早已成為廣泛使用的語言,尤其是在深度學習框架TensorFlow和PyTorch的開發和應用中。盡管Java在許多企業級應用中占據一席之地,但因為缺乏直接使用深度學習框架的能力,往往使得Java開發者對機器學…

Docker安裝beef-xss

新版的kali系統中安裝了beef-xss會因為環境問題而無法啟動,可以使用Docker來安裝beef-xss,節省很多時間。 安裝步驟 1.啟動kali虛擬機,打開終端,切換到root用戶,然后執行下面的命令下載beef的docker鏡像 wget https:…

metasploit(2)生成dll木馬

聲明!本文章所有的工具分享僅僅只是供大家學習交流為主,切勿用于非法用途,如有任何觸犯法律的行為,均與本人及團隊無關!!! 一、dll文件基本概念 DLL 是一種包含可由多個程序同時使用的代碼和數…

5V 1A充電標準的由來與技術演進——從USB誕生到智能手機時代的電力革命

點擊下面圖片帶您領略全新的嵌入式學習路線 🔥爆款熱榜 88萬閱讀 1.6萬收藏 一、起源:USB標準與早期電力傳輸需求 1. USB的誕生背景 1996年,由英特爾、微軟、IBM等公司組成的USB-IF(USB Implementers Forum)發布了…

使用Python設置excel單元格的字體(font值)

一、前言 通過使用Python的openpyxl庫,來操作excel單元格,設置單元格的字體,也就是font值。 把學習的過程分享給大家。大佬勿噴! 二、程序展示 1、新建excel import openpyxl from openpyxl.styles import Font wb openpyxl.…

【設計模式】深入解析代理模式(委托模式):代理模式思想、靜態模式和動態模式定義與區別、靜態代理模式代碼實現

代理模式 代理模式,也叫委托模式。 Spring AOP 是基于動態代理來實現 AOP 的 定義 為其他對象提供一種代理 以控制對這個對象的訪問。它的作用就是通過提供一個代理類,讓我們在調用目標方法的時候,不再是直接對目標方法進行調用,而…

利用java語言,怎樣開發和利用各種開源庫和內部/自定義框架,實現“提取-轉換-加載”(ETL)流程的自動化

一、ETL 架構設計的核心要素? 在企業級數據處理場景中,ETL(Extract-Transform-Load)流程自動化是數據倉庫、數據湖建設的核心環節。基于 Java 生態的技術棧,我們可以構建分層解耦的 ETL 架構,主要包含以下四層結構&am…

2023藍帽杯初賽內存取證-8

也是用到pslist模塊,加上grep過濾”chrome“即可: vol.py --plugin/opt/volatility/plugins -f memdump.mem --profile Win7SP1x64 pslist | grep "chrome" 第一個是PID,第二個是PPID,第三個是線程數,第四個…

【C語言】動態內存的常見錯誤

前言&#xff1a; 在上章節中講解了動態內存的概念和管理的核心函數。 在本章節繼續為大家介紹動態內存的常見錯誤&#xff0c;讓大家更好的理解運用。 補充&#xff1a;使用內存函數需要頭文件<stdlib.h> 對NULL指針的解引用操作 當使用malloc、calloc或realloc等函…

uniapp-x 二維碼生成

支持X&#xff0c;二維碼生成&#xff0c;支持微信小程序&#xff0c;android&#xff0c;ios&#xff0c;網頁 - DCloud 插件市場 免費的單純用愛發電的

Linux內核之文件驅動隨筆

前言 近期需要實現linux系統文件防護功能&#xff0c;故此調研了些許知識&#xff0c;如何實現文件防護功能從而實現針對文件目錄防護功能。當被保護的目錄&#xff0c;禁止增刪改操作。通過內核層面實現相關功能&#xff0c;另外在通過跟應用層面交互從而實現具體的業務功能。…

利用大模型實現地理領域文檔中英文自動化翻譯

一、 背景描述 在跨國性企業日常經營過程中&#xff0c;經常會遇到專業性較強的文檔翻譯的需求&#xff0c;例如法律文書、商務合同、技術文檔等&#xff1b;以往遇到此類場景&#xff0c;企業內部往往需要指派專人投入數小時甚至數天來整理和翻譯&#xff0c;效率低下&#x…

鴻蒙Flutter倉庫停止更新?

停止更新 熟悉 Flutter 鴻蒙開發的小伙伴應該知道&#xff0c;Flutter 3.7.12 鴻蒙化 SDK 已經在開源鴻蒙社區發布快一年了&#xff0c; Flutter 3.22.x 的鴻蒙化適配一直由鴻蒙突擊隊倉庫提供&#xff0c;最近有小伙伴反饋已經 2 個多月沒有停止更新了&#xff0c;不少人以為停…

(七)深入了解AVFoundation-采集:采集系統架構與 AVCaptureSession 全面梳理

引言 在 iOS 開發中&#xff0c;AVFoundation 是構建音視頻功能的強大底層框架。而在音視頻功能中&#xff0c;“采集”往往是最基礎也是最關鍵的一環。從攝像頭捕捉圖形、到麥克風獲取聲音&#xff0c;構建一條高效且穩定的采集鏈是開發高質量音視頻應用的前提。 本系列將逐…

QML ShaderEffect(著色器效果)組件

ShaderEffect 是 QML 中用于實現自定義著色器效果的組件&#xff0c;允許開發者使用 GLSL 著色器語言創建圖形效果。 核心屬性 基本屬性 屬性類型默認值說明fragmentShaderstring""片段著色器代碼vertexShaderstring""頂點著色器代碼blendingbooltrue是…

基于javaweb的SSM教材征訂與發放管理系統設計與實現(源碼+文檔+部署講解)

技術范圍&#xff1a;SpringBoot、Vue、SSM、HLMT、Jsp、PHP、Nodejs、Python、爬蟲、數據可視化、小程序、安卓app、大數據、物聯網、機器學習等設計與開發。 主要內容&#xff1a;免費功能設計、開題報告、任務書、中期檢查PPT、系統功能實現、代碼編寫、論文編寫和輔導、論文…

大模型學習筆記------Llama 3模型架構之分組查詢注意力(GQA)

大模型學習筆記------Llama 3模型架構之分組查詢注意力&#xff08;GQA&#xff09; 1、分組查詢注意力&#xff08;GQA&#xff09;的動機2、 多頭注意力&#xff08;Multi-Head Attention, MHA&#xff09;3、 多查詢注意力 (Multi-Query Attention&#xff0c;MQA)4、 分組查…