DAY 45 Tensorboard使用介紹

@浙大疏錦行https://blog.csdn.net/weixin_45655710知識點回顧:

  1. tensorboard的發展歷史和原理
  2. tensorboard的常見操作
  3. tensorboard在cifar上的實戰:MLP和CNN模型

作業:對resnet18在cifar10上采用微調策略下,用tensorboard監控訓練過程。

核心:

  1. 數據加載和模型創建:復用之前的函數,保持模塊化。

  2. SummaryWriter初始化:創建TensorBoard的寫入器,并自動處理日志目錄,避免覆蓋。

  3. train_and_evaluate函數:創建一個總控函數,封裝了完整的“凍結-解凍”訓練循環,并在其中集成了TensorBoard的各種日志記錄功能。

  4. TensorBoard日志記錄

  • 模型圖譜 (Graph):在訓練開始前,記錄模型的計算圖。
  • 標量 (Scalars):實時記錄訓練集和測試集的損失(Loss)與準確率(Accuracy),以及學習率(Learning Rate)的變化。
  • 圖像 (Images):記錄輸入的樣本圖像和每個epoch結束時預測錯誤的樣本。
  • 直方圖 (Histograms):定期記錄模型各層權重(Weights)和梯度(Gradients)的分布,用于診斷訓練狀態。
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter # 導入TensorBoard的核心類
import matplotlib.pyplot as plt
import os
import time
from tqdm import tqdm
import torchvision # 確保torchvision被導入以使用make_grid# --- 步驟 1: 準備數據加載器 (保持不變) ---
def get_cifar10_loaders(batch_size=128):"""獲取CIFAR-10的數據加載器,包含數據增強"""train_transform = transforms.Compose([transforms.RandomResizedCrop(224), # ResNet通常在224x224的圖像上預訓練transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # ImageNet的標準化參數])test_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)test_dataset = datasets.CIFAR10(root='./data', train=False, transform=test_transform)train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)return train_loader, test_loader# --- 步驟 2: 模型創建與凍結/解凍函數 (保持不變) ---
def create_resnet18(pretrained=True, num_classes=10):"""創建并修改ResNet18模型"""model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None)in_features = model.fc.in_featuresmodel.fc = nn.Linear(in_features, num_classes)return modeldef set_freeze_state(model, freeze=True):"""凍結或解凍模型的特征提取層"""print(f"--- {'凍結' if freeze else '解凍'} 特征提取層 ---")for name, param in model.named_parameters():if 'fc' not in name: # 只訓練最后的全連接層param.requires_grad = not freeze# --- 步驟 3: 封裝了TensorBoard的訓練與評估總控函數 ---
def train_with_tensorboard(model, device, train_loader, test_loader, epochs, freeze_epochs, writer):"""使用TensorBoard監控的完整訓練流程"""# 初始化優化器和損失函數criterion = nn.CrossEntropyLoss()# 初始只優化未凍結的參數optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.5, verbose=True)# --- TensorBoard初始記錄 ---print("正在記錄初始信息到TensorBoard...")dataiter = iter(train_loader)images, _ = next(dataiter)writer.add_graph(model, images.to(device)) # 記錄模型圖img_grid = torchvision.utils.make_grid(images[:16]) # 取16張圖預覽writer.add_image('CIFAR-10 樣本圖像', img_grid)print("? 初始信息記錄完成。")# 開始訓練global_step = 0for epoch in range(1, epochs + 1):# --- 解凍控制 ---if epoch == freeze_epochs + 1:set_freeze_state(model, freeze=False)# 解凍后需要為優化器加入所有參數optimizer = optim.Adam(model.parameters(), lr=1e-4) # 使用更小的學習率進行全局微調print("優化器已更新以包含所有參數,學習率已降低。")# --- 訓練部分 ---model.train()train_loss, train_correct, train_total = 0, 0, 0loop = tqdm(train_loader, desc=f"Epoch [{epoch}/{epochs}] Training", leave=False)for data, target in loop:data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()train_loss += loss.item() * data.size(0)_, pred = output.max(1)train_correct += pred.eq(target).sum().item()train_total += data.size(0)writer.add_scalar('Train/Batch_Loss', loss.item(), global_step)global_step += 1loop.set_postfix(loss=loss.item())loop.close()# 記錄Epoch級訓練指標avg_train_loss = train_loss / train_totalavg_train_acc = 100. * train_correct / train_totalwriter.add_scalar('Train/Epoch_Loss', avg_train_loss, epoch)writer.add_scalar('Train/Epoch_Accuracy', avg_train_acc, epoch)# --- 評估部分 ---model.eval()test_loss, test_correct, test_total = 0, 0, 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)loss = criterion(output, target)test_loss += loss.item() * data.size(0)_, pred = output.max(1)test_correct += pred.eq(target).sum().item()test_total += data.size(0)# 記錄Epoch級測試指標avg_test_loss = test_loss / test_totalavg_test_acc = 100. * test_correct / test_totalwriter.add_scalar('Test/Epoch_Loss', avg_test_loss, epoch)writer.add_scalar('Test/Epoch_Accuracy', avg_test_acc, epoch)# 記錄權重和梯度的直方圖 (每個epoch記錄一次)for name, param in model.named_parameters():writer.add_histogram(f'Weights/{name}', param, epoch)if param.grad is not None:writer.add_histogram(f'Gradients/{name}', param.grad, epoch)# 更新學習率調度器scheduler.step(avg_test_loss)writer.add_scalar('Train/Learning_Rate', optimizer.param_groups[0]['lr'], epoch)print(f"Epoch {epoch} 完成 | 訓練準確率: {avg_train_acc:.2f}% | 測試準確率: {avg_test_acc:.2f}%")# --- 步驟 4: 主執行流程 ---
if __name__ == "__main__":# --- 配置 ---EPOCHS = 15FREEZE_EPOCHS = 5 # 先凍結訓練5輪,再解凍訓練10輪BATCH_SIZE = 64DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")# --- TensorBoard 初始化 ---log_dir = "runs/resnet18_finetune_cifar10"version = 1while os.path.exists(f"{log_dir}_v{version}"):version += 1log_dir = f"{log_dir}_v{version}"writer = SummaryWriter(log_dir)print(f"TensorBoard 日志將保存在: {log_dir}")# --- 開始實驗 ---train_loader, test_loader = get_cifar10_loaders(batch_size=BATCH_SIZE)model = create_resnet18(pretrained=True).to(DEVICE)set_freeze_state(model, freeze=True) # 初始凍結print("\n--- 開始使用ResNet18微調模型 ---")print("訓練完成后,在終端運行 `tensorboard --logdir=runs` 來查看可視化結果。")train_with_tensorboard(model, DEVICE, train_loader, test_loader, EPOCHS, FREEZE_EPOCHS, writer)writer.close() # 關閉writerprint("\n? 訓練完成,TensorBoard日志已保存。")

解析

1.數據預處理適配 (get_cifar10_loaders)

圖像尺寸ResNet系列是在224x224的ImageNet圖像上預訓練的。雖然它們也能處理32x32的CIFAR-10圖像,但為了更好地利用預訓練權重,一個常見的做法是將小圖像放大224x224。我們在transforms中加入了transforms.RandomResizedCrop(224)transforms.Resize(256) / transforms.CenterCrop(224)來實現這一點。

標準化參數:使用了ImageNet數據集的標準化均值和標準差,這是使用在ImageNet上預訓練的模型的標準做法

2.模塊化訓練流程 (train_with_tensorboard)

將整個包含“凍結-解凍”邏輯的訓練循環封裝成一個函數,使得主程序非常簡潔。

該函數接收一個writer對象作為參數,所有TensorBoard的日志記錄都在這個函數內部完成。

3.TensorBoard全面監控

  • 模型圖 (add_graph):在訓練開始前,將模型的結構圖寫入日志,方便在GRAPHS標簽頁查看。
  • 圖像 (add_image):將一批原始訓練樣本寫入日志,可以在IMAGES標簽頁直觀地看到輸入數據。
  • 標量 ( add_scalar )

Batch級:記錄了每個訓練批次的損失(Train/Batch_Loss),可以觀察到最細粒度的訓練動態。

Epoch級:記錄了每個輪次結束后的訓練和測試的損失準確率,以及學習率的變化。這能讓我們在同一個圖表中清晰地對比訓練集和測試集的性能曲線,判斷過擬合。

  • 直方圖 (add_histogram):每個輪次結束后,記錄模型所有可訓練參數的權重分布梯度分布。這對于高級調試非常有用,可以幫助判斷是否存在梯度消失/爆炸,或者權重是否更新正常。

4.清晰的執行邏輯

if __name__ == "__main__":中,代碼邏輯非常清晰:設置參數 -> 初始化TensorBoard寫入器 -> 準備數據 -> 創建模型 -> 調用總控函數開始訓練 -> 結束并關閉寫入器。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/89323.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/89323.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/89323.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

2023年全國碩士研究生招生考試英語(一)試題總結

文章目錄 題型與分值分布完形填空錯誤 1:考察連詞 or 前后內容之間的邏輯關系錯誤2:錯誤3:錯誤4:這個錯得最有價值,因為壓根沒讀懂錯誤5:學到的短語: 仔細閱讀排序/新題型翻譯小作文大作文 題型…

react-數據Mock實現——json-server

什么是mock? 在前后端分離的開發模式下,前端可以在沒有實際后端接口的支持下先進行接口數據的模擬,進行正常的業務功能開發 json-server實現數據Mock json-server是一個node的包,可以在不到30秒內獲得零編碼的完整Mock服務 實現…

使用POI導入解析excel文件

首先校驗 /*** 校驗導入文件* param file 上傳的文件* return 校驗結果,成功返回包含成功狀態的AjaxResult,失敗返回包含錯誤信息的AjaxResult*/private AjaxResult validateImportFile(MultipartFile file) {if (file.isEmpty()) {return AjaxResult.er…

從0開始學習計算機視覺--Day06--反向傳播算法

盡管解析梯度可以讓我們省去巨大的計算量,但如果函數比較復雜,對這個損失函數進行微分計算會變得很困難。我們通常會用反向傳播技術來遞歸地調用鏈式法則來計算向量每一個方向上的梯度。具體來說,我們將整個計算過程的輸入與輸入具體化&#…

企業流程知識:《學習觀察:通過價值流圖創造價值、消除浪費》讀書筆記

《學習觀察:通過價值流圖創造價值、消除浪費》讀書筆記 作者:邁克魯斯(Mike Rother),約翰舒克(John Shook) 出版時間:1999年 歷史地位:精益生產可視化工具的黃金標準&am…

Day02_C語言IO進程線程

01.思維導圖 02.將當前的時間寫入到time. txt的文件中,如果ctrlc退出之后,在再次執行支持斷點續寫 1.2022-04-26 19:10:20 2.2022-04-26 19:10:21 3.2022-04-26 19:10:22 //按下ctrlc停止,再次執行程序 4.2022-04-26 20:00:00 5.2022-04-26 2…

FFmpeg中TS與MP4格式的extradata差異詳解

在視頻處理中,extradata是存儲解碼器初始化參數的核心元數據,直接影響視頻能否正確解碼。本文深入解析TS和MP4格式中extradata的結構差異、存儲邏輯及FFmpeg處理方案。 📌 一、extradata的核心作用 extradata是解碼必需的參數集合&#xff0…

【CV數據集介紹-40】Cityscapes 數據集:助力自動駕駛的語義分割神器

🧑 博主簡介:曾任某智慧城市類企業算法總監,目前在美國市場的物流公司從事高級算法工程師一職,深耕人工智能領域,精通python數據挖掘、可視化、機器學習等,發表過AI相關的專利并多次在AI類比賽中獲獎。CSDN…

SAP月結問題9-FAGLL03H與損益表中研發費用金額不一致(FAGLL03H Bug)

SAP月結問題9-FAGLL03H與損益表中研發費用金額不一致(S4 1709) 財務反饋,月結后核對數據時發現FAGLL03H導出的研發費用與損益表中的研發費用不一致,如下圖所示: 對比FAGLL03H與損益表對應的明細,發現FAGLL03H與損益表數據存在倍數…

HTML inputmode 屬性詳解

inputmode 是一個 HTML 屬性&#xff0c;用于指定用戶在編輯元素或其內容時應使用的虛擬鍵盤布局類型。它主要影響移動設備和平板電腦的輸入體驗。 語法 <input inputmode"value"> <!-- 或 --> <textarea inputmode"value"></texta…

軟考中級【網絡工程師】第6版教材 第1章 計算機網絡概述

考點分析&#xff1a; 本章重要程度&#xff1a;一般&#xff0c;為后續章節做鋪墊&#xff0c;有總體認識即可&#xff0c;選擇題1-2分高頻考點&#xff1a;OSI模型、TCP/IP模型、每個層次的功能、協議層次新教材變化&#xff1a;刪除網絡結構、刪除X.25、更新互聯網發展【基本…

Mysql事務與鎖

數據庫并發事務 數據庫一般都會并發執行多個事務&#xff0c;多個事務可能會并發的對相同的一批數據進行增刪改查操作&#xff0c;可能就會導致我們說的臟寫、臟讀、不可重復讀、幻讀這些問題。為了解決這些并發事務的問題&#xff0c;數據庫設計了事務隔離機制、鎖機制、MVCC多…

Bilibili多語言字幕翻譯擴展:基于上下文的實時翻譯方案設計

Bilibili多語言字幕翻譯擴展&#xff1a;基于上下文的實時翻譯方案設計 本文介紹了一個Chrome擴展的設計與實現&#xff0c;該擴展可以為Bilibili視頻提供實時多語言字幕翻譯功能。重點討論了字幕翻譯中的上下文問題及其解決方案。 該項目已經登陸Chrome Extension Store: http…

熱血三國野地名將列表

<!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>野地名將信息表</title><style>tabl…

【記錄】Word|Word創建自動編號的多級列表標題樣式

文章目錄 前言創建方式第一種方法&#xff1a;從“定義多級列表”中直接綁定已有樣式第二種方法&#xff1a;通過已有段落創建樣式&#xff0c;再綁定補充說明 尾聲 前言 這世上荒唐的事情不少&#xff0c;但若說到吊詭&#xff0c;Word中的多級列表樣式設定&#xff0c;倒是能…

使用mavros啟動多機SITL仿真

使用mavros啟動多機SITL仿真 方式1&#xff1a;使用roslaunch一鍵啟動Step1&#xff1a;創建一個新的 ROS 包或放到現有包里Step2&#xff1a;編輯 multi_mavros.launchStep3&#xff1a;構建工作空間并 source 環境Step4&#xff1a;構建工作空間并 source 環境 方式2&#xf…

Flutter 網絡棧入門,Dio 與 Retrofit 全面指南

面向多年 iOS 開發者的零阻力上手 寫在前面 你在 iOS 項目中也許習慣了 URLSession、Alamofire 或 Moya。 換到 Flutter 后&#xff0c;等價的「組合拳」就是 Dio Retrofit。 本文將帶你一次吃透兩套庫的安裝、核心 API、進階技巧與最佳實踐。 1. Dio&#xff1a;Flutter 里的…

工作室考核源碼(帶后端)

題目內容可更改 下載地址:https://mcwlkj.lanzoub.com/iUF3z300tgfe 如圖所示

數字孿生技術為UI前端提供全面支持:實現產品的可視化配置與定制

hello寶子們...我們是艾斯視覺擅長ui設計、前端開發、數字孿生、大數據、三維建模、三維動畫10年經驗!希望我的分享能幫助到您!如需幫助可以評論關注私信我們一起探討!致敬感謝感恩! 一、引言&#xff1a;數字孿生驅動產品定制的技術革命 在消費升級與工業 4.0 的雙重驅動下&a…

通往物理世界自主智能的二元實在論與羅塞塔協議

序章&#xff1a;AI的“兩種文化”之爭——我們是否在構建錯誤的“神”&#xff1f; 自誕生以來&#xff0c;人工智能領域始終存在著一場隱秘的“兩種文化”之爭。一方是符號主義與邏輯的信徒&#xff0c;他們追求可解釋、嚴謹的推理&#xff0c;相信智能的核心在于對世界規則…