Pytorch實現一個簡單的貝葉斯卷積神經網絡模型

?

貝葉斯深度模型的主要特點和實現說明:

  1. 模型結構

    • 結合了常規卷積層(用于特征提取)和貝葉斯線性層(用于分類)
    • 貝葉斯層將權重視為隨機變量,而非傳統神經網絡中的確定值
    • 使用變分推斷來近似權重的后驗分布
  2. 貝葉斯特性

    • 通過重參數化技巧實現隨機變量的采樣,使得模型可訓練
    • 損失函數包含兩部分:分類損失(交叉熵)和 KL 散度(衡量近似后驗與先驗的差異)
    • 測試時通過多次采樣獲取預測分布,體現模型的不確定性
  3. 使用方法

    • 代碼會自動下載 MNIST 數據集并進行預處理
    • 支持 GPU 加速(如果可用)
    • 訓練完成后會繪制損失和準確率曲線,并保存模型
  4. 與傳統神經網絡的區別

    • 貝葉斯模型能夠提供預測的不確定性估計
    • 通常具有更好的泛化能力,不易過擬合
    • 訓練過程更復雜,計算成本更高
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt# 定義貝葉斯線性層 - 使用變分推斷近似后驗分布
class BayesianLinear(nn.Module):def __init__(self, in_features, out_features):super(BayesianLinear, self).__init__()self.in_features = in_featuresself.out_features = out_features# 先驗分布參數 (高斯分布)self.prior_mu = 0.0self.prior_sigma = 1.0# 變分參數 - 權重的均值和標準差self.mu_weight = nn.Parameter(torch.Tensor(out_features, in_features).normal_(0, 0.1))self.sigma_weight = nn.Parameter(torch.Tensor(out_features, in_features).fill_(0.1))# 變分參數 - 偏置的均值和標準差self.mu_bias = nn.Parameter(torch.Tensor(out_features).normal_(0, 0.1))self.sigma_bias = nn.Parameter(torch.Tensor(out_features).fill_(0.1))# 用于重參數化技巧的噪聲變量self.epsilon_weight = Noneself.epsilon_bias = Nonedef forward(self, x):# 重參數化技巧:將隨機采樣轉換為確定性操作,便于反向傳播if self.training:# 訓練時從近似后驗分布中采樣self.epsilon_weight = torch.normal(torch.zeros_like(self.mu_weight))self.epsilon_bias = torch.normal(torch.zeros_like(self.mu_bias))weight = self.mu_weight + self.sigma_weight * self.epsilon_weightbias = self.mu_bias + self.sigma_bias * self.epsilon_biaselse:# 測試時使用均值(最大后驗估計)weight = self.mu_weightbias = self.mu_bias# 計算KL散度(衡量近似后驗與先驗的差異)kl_loss = self._kl_divergence()return nn.functional.linear(x, weight, bias), kl_lossdef _kl_divergence(self):# 計算KL散度:KL(q(w) || p(w))kl_weight = 0.5 * torch.sum(1 + 2 * torch.log(self.sigma_weight) - torch.square(self.mu_weight) - torch.square(self.sigma_weight)) / (self.prior_sigma ** 2)kl_bias = 0.5 * torch.sum(1 + 2 * torch.log(self.sigma_bias) - torch.square(self.mu_bias) - torch.square(self.sigma_bias)) / (self.prior_sigma ** 2)return kl_weight + kl_bias# 定義貝葉斯卷積神經網絡模型
class BayesianCNN(nn.Module):def __init__(self, num_classes=10):super(BayesianCNN, self).__init__()# 卷積層使用常規卷積(為簡化模型)self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)self.pool = nn.MaxPool2d(kernel_size=2, stride=2)# 全連接層使用貝葉斯層self.fc1 = BayesianLinear(64 * 7 * 7, 128)self.fc2 = BayesianLinear(128, num_classes)self.relu = nn.ReLU()def forward(self, x):# 卷積特征提取部分x = self.pool(self.relu(self.conv1(x)))x = self.pool(self.relu(self.conv2(x)))x = x.view(-1, 64 * 7 * 7)  # 展平特征圖# 貝葉斯全連接部分x, kl1 = self.fc1(x)x = self.relu(x)x, kl2 = self.fc2(x)# 總KL散度total_kl = kl1 + kl2return x, total_kl# 訓練函數
def train(model, train_loader, optimizer, criterion, epoch, device):model.train()train_loss = 0correct = 0total = 0# KL散度的權重(根據數據集大小調整)kl_weight = 1.0 / len(train_loader.dataset)for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()# 前向傳播output, kl_loss = model(data)# 總損失 = 分類損失 + KL散度正則化loss = criterion(output, target) + kl_weight * kl_loss# 反向傳播和優化loss.backward()optimizer.step()# 統計train_loss += loss.item()_, predicted = torch.max(output.data, 1)total += target.size(0)correct += (predicted == target).sum().item()# 打印訓練進度if batch_idx % 100 == 0:print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} 'f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')train_loss /= len(train_loader)train_acc = 100. * correct / totalprint(f'Train set: Average loss: {train_loss:.4f}, Accuracy: {correct}/{total} ({train_acc:.2f}%)')return train_loss, train_acc# 測試函數
def test(model, test_loader, criterion, device, num_samples=10):model.eval()test_loss = 0correct = 0total = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)# 多次采樣以獲取預測分布(體現貝葉斯模型的不確定性)outputs = []for _ in range(num_samples):output, _ = model(data)outputs.append(output.unsqueeze(0))# 平均多次采樣的結果output = torch.mean(torch.cat(outputs, dim=0), dim=0)test_loss += criterion(output, target).item()# 統計準確率_, predicted = torch.max(output.data, 1)total += target.size(0)correct += (predicted == target).sum().item()test_loss /= len(test_loader)test_acc = 100. * correct / totalprint(f'Test set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{total} ({test_acc:.2f}%)')return test_loss, test_acc# 主函數
def main():# 超參數設置batch_size = 64test_batch_size = 1000epochs = 10lr = 0.001seed = 42num_samples = 10  # 測試時的采樣次數,用于獲取預測分布# 設置設備(GPU或CPU)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(f"Using device: {device}")# 設置隨機種子,保證結果可復現torch.manual_seed(seed)# 數據預處理transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))  # MNIST數據集的均值和標準差])# 加載MNIST數據集train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)# 創建數據加載器train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)# 初始化模型、損失函數和優化器model = BayesianCNN().to(device)criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=lr)# 記錄訓練過程中的損失和準確率train_losses = []train_accs = []test_losses = []test_accs = []# 開始訓練和測試for epoch in range(1, epochs + 1):train_loss, train_acc = train(model, train_loader, optimizer, criterion, epoch, device)test_loss, test_acc = test(model, test_loader, criterion, device, num_samples)train_losses.append(train_loss)train_accs.append(train_acc)test_losses.append(test_loss)test_accs.append(test_acc)# 繪制訓練和測試損失曲線plt.figure(figsize=(12, 5))plt.subplot(1, 2, 1)plt.plot(range(1, epochs + 1), train_losses, label='Train Loss')plt.plot(range(1, epochs + 1), test_losses, label='Test Loss')plt.xlabel('Epoch')plt.ylabel('Loss')plt.title('Loss vs Epoch')plt.legend()# 繪制訓練和測試準確率曲線plt.subplot(1, 2, 2)plt.plot(range(1, epochs + 1), train_accs, label='Train Accuracy')plt.plot(range(1, epochs + 1), test_accs, label='Test Accuracy')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.title('Accuracy vs Epoch')plt.legend()plt.tight_layout()plt.show()# 保存模型torch.save(model.state_dict(), 'bayesian_cnn_mnist.pth')print("Model saved as 'bayesian_cnn_mnist.pth'")if __name__ == '__main__':main()

在模型規模相似(例如參數總量、網絡深度和寬度相近)的情況下,普通卷積神經網絡(CNN)的訓練效率通常更高,訓練速度更快。這主要源于貝葉斯卷積神經網絡(Bayesian CNN)的特殊結構和訓練機制帶來的額外計算開銷,具體原因如下:

1.?參數數量與計算復雜度差異

普通 CNN 中,每個權重是確定值,每個層僅需存儲和優化一組權重參數(例如卷積核權重、偏置)。
而貝葉斯 CNN 中,權重被視為隨機變量(通常假設服從高斯分布),需要用變分推斷近似其 posterior 分布。這意味著每個權重需要學習兩個參數:均值(μ)?和標準差(σ)(或精度),參數數量幾乎是普通 CNN 的 2 倍(對于貝葉斯層而言)。

更多的參數直接導致:

  • 前向傳播時需要計算更多變量的組合(例如通過重參數化技巧采樣權重:weight = μ + σ·ε);
  • 反向傳播時需要計算更多參數的梯度(不僅是均值,還有標準差),增加了梯度計算的復雜度。

2.?額外的損失項計算

普通 CNN 的損失函數通常僅包含任務相關損失(例如分類問題的交叉熵損失)。
而貝葉斯 CNN 的損失函數必須包含兩部分:

  • 任務相關損失(與普通 CNN 相同);
  • KL 散度(KL divergence):用于衡量近似后驗分布與先驗分布的差異,作為正則化項。

KL 散度的計算需要對每個貝葉斯層的權重分布進行積分近似(即使是簡化的解析解,也需要對所有權重的均值和標準差進行逐元素運算),這會額外增加計算開銷,尤其當貝葉斯層較多時,累積開銷顯著。

3.?采樣操作的開銷

貝葉斯 CNN 在訓練時,為了通過重參數化技巧實現梯度回傳,需要對每個貝葉斯層的權重進行隨機采樣(例如從N(μ, σ2)中采樣噪聲ε,再計算weight = μ + σ·ε)。雖然采樣操作本身不算復雜,但在大規模網絡中,多次采樣(即使每個 batch 一次)會累積計算時間。

普通 CNN 則無需采樣,權重是確定性的,前向傳播更直接高效。

總結

在模型規模相似的情況下,普通 CNN 由于參數更少、計算流程更簡單(無額外的 KL 散度計算和采樣操作),訓練速度顯著快于貝葉斯 CNN。

貝葉斯 CNN 的優勢不在于訓練效率,而在于其能量化預測的不確定性(例如通過多次采樣得到預測分布),并在小樣本、數據噪聲大的場景下可能具有更好的泛化能力,但這是以更高的計算成本為代價的。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/94291.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/94291.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/94291.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Dubbo 3.x源碼(32)—Dubbo Provider處理服務調用請求源碼

基于Dubbo 3.1,詳細介紹了Dubbo Provider處理服務調用請求源碼 上文我們學習了,Dubbo消息的編碼解的源碼。現在我們來學習一下Dubbo Provider處理服務調用請求源碼。 當前consumer發起了rpc請求,經過請求編碼之后到達provider端,…

每日一leetcode:移動零

目錄 解題過程: 描述: 分析條件: 解題思路: 通過這道題可以學到什么: 解題過程: 描述: 給定一個數組 nums,編寫一個函數將所有 0 移動到數組的末尾,同時保持非零元素的相對順序。 請注意 ,必須在不復制數組的情況下原地對數組進行操…

6-Django項目實戰-[dtoken]-用戶登錄模塊

1.創建應用 python manage.py startapp dtoken 2.注冊應用 settings.py中注冊 3.匹配路由4.編寫登錄功能視圖函數 import hashlib import json import timeimport jwt from django.conf import settings from django.http import JsonResponse from user.models import UserPro…

Axure日期日歷高保真動態交互原型

在數字化產品設計中,日期日歷組件作為高頻交互元素,其功能完整性與用戶體驗直接影響著用戶對產品的信任度。本次帶來的日期日歷高保真動態交互原型,依照Element UI、View UI等主流前端框架為參考,通過動態面板、中繼器、函數、交互…

【YOLOv4】

YOLOv4 論文地址::【https://arxiv.org/pdf/2004.10934】 YOLOv4 論文中文翻譯地址:【深度學習論文閱讀目標檢測篇(七)中文版:YOLOv4《Optimal Speed and Accuracy of Object Detection》-CSDN博客】 yol…

【秋招筆試】2025.08.03蝦皮秋招筆試-第一題

?? 點擊直達筆試專欄 ??《大廠筆試突圍》 ?? 春秋招筆試突圍在線OJ ?? 筆試突圍在線刷題 bishipass.com 01. 蛋糕切分的最大收益 問題描述 K小姐經營著一家甜品店,今天她有一塊長度為 n n n 厘米的長條蛋糕需要切分。根據店里的規定,她必須將蛋糕切成至少 2 2

2.0 vue工程項目的創建

前提準備.需要電腦上已經安裝了nodejs 參考 7.nodejs和npm簡單使用_npmjs官網-CSDN博客 創建vue2工程 全局安裝 Vue CLI 在終端中運行以下命令來全局安裝 Vue CLI: npm install -g vue/cli npm install -g 表示全局安裝。vue/cli 是 Vue CLI 的包名。 安裝完成后…

視覺圖像處理中級篇 [2]—— 外觀檢查 / 傷痕模式的原理與優化設置方法

外觀缺陷檢測是工業生產中的關鍵環節,而傷痕模式作為圖像處理的核心算法,能精準識別工件表面的劃痕、污跡等缺陷。掌握其原理和優化方法,對提升檢測效率至關重要。一、利用傷痕模式進行外觀檢查雖然總稱為外觀檢查,但根據檢查對象…

ethtool,lspci,iperf工具常用命令總結

ethtool、lspci 和 iperf 是 Linux 系統中進行網絡硬件查看、配置和性能測試的核心命令行工具。下面是它們的常用命令分析和總結: 核心作用總結: lspci: 偵察兵 - 列出系統所有 PCI/PCIe 總線上的硬件設備信息,主要用于識別網卡型號、制造商、…

DAY10DAY11-新世紀DL(DeepLearning/深度學習)戰士:序

本文參考視頻[雙語字幕]吳恩達深度學習deeplearning.ai_嗶哩嗶哩_bilibili 參考文章0.0 目錄-深度學習第一課《神經網絡與深度學習》-Stanford吳恩達教授-CSDN博客 1深度學習概論 1.舉例介紹 lg房價預測:房價與面積之間的坐標關系如圖所示,由線性回歸…

flutter release調試插件

chucker_flutter (只有網絡請求的信息,親測可以用) flutter:3.24.3 使用版本 chucker_flutter: 1.8.2 chucker_flutter | Flutter package void main() async {// 可以控制顯示ChuckerFlutter.showNotification false;ChuckerF…

基于開源鏈動2+1模式AI智能名片S2B2C商城小程序的私域流量拉新策略研究

摘要:私域流量運營已成為企業數字化轉型的核心戰略,其本質是通過精細化用戶運營實現流量價值最大化。本文以“定位、拉新、養熟、成交、裂變、留存”全鏈路為框架,聚焦開源鏈動21模式、AI智能名片與S2B2C商城小程序的協同創新,揭示…

華為云云服務高級顧問葉正暉:華為對多模態大模型的思考與實踐

嘉賓介紹:葉正暉,華為云云服務高級顧問,全球化企業信息化專家,從業年限超過23年,在華為任職超過21年,涉及運營商、企業、消費者、云服務、安全與隱私等領域,精通云服務、安全合規、隱私保護等領…

【機器學習(二)】KNN算法與模型評估調優

目錄 一、寫在前面的話 二、KNN(K-Nearest Neighbor) 2.1 KNN算法介紹 2.1.1 概念介紹 2.1.2 算法特點 2.1.3 API 講解 2.2 樣本距離計算 2.2.1 距離的類型 (1)歐幾里得距離(Euclidean Distance) …

《Uniapp-Vue 3-TS 實戰開發》實現自定義頭部導航欄

本文介紹了如何將Vue2組件遷移至Vue3的組合式API。主要內容包括:1) 使用<script setup lang="ts">語法;2) 通過接口定義props類型約束;3) 用defineProps替代props選項;4) 將data變量轉為ref響應式變量;5) 使用computed替代計算屬性;6) 將created生命周期…

GitCode疑難問題診療

問題診斷與解決框架通用問題排查流程&#xff08;適用于大多數場景&#xff09; 版本兼容性驗證方法 網絡連接與權限檢查清單常見錯誤分類與解決方案倉庫克隆失敗場景分析 HTTP/SSH協議錯誤代碼解讀 403/404錯誤深層原因排查高級疑難問題處理分支合并沖突的深度解決 .gitignore…

告別物業思維:科技正重構產業園區的價值坐標系

文 | 方寸控股引言&#xff1a;當產業園區的競爭升維為“科技軍備競賽”&#xff0c;土地紅利消退&#xff0c;政策優勢趨同&#xff0c;傳統園區運營陷入增長困局。當招商團隊還在用Excel統計企業需求&#xff0c;當能耗管理依賴保安夜間巡檢&#xff0c;當企業服務停留在“修…

GitHub 熱門項目 PandaWiki:零門檻搭建智能漏洞庫,支持 10 + 大模型接入

轉自&#xff1a;Khan安全團隊你還沒有自己的漏洞庫嗎&#xff1f;一條命令教你搭建。PandaWiki 是一款 AI 大模型驅動的開源知識庫搭建系統&#xff0c;幫助你快速構建智能化的 產品文檔、技術文檔、FAQ、博客系統&#xff0c;借助大模型的力量為你提供 AI 創作、AI 問答、AI …

Python 程序設計講義(55):Python 的函數——函數的參數

Python 程序設計講義&#xff08;55&#xff09;&#xff1a;Python 的函數——函數的參數 目錄Python 程序設計講義&#xff08;55&#xff09;&#xff1a;Python 的函數——函數的參數一、聲明形參二、傳遞實參&#xff08;位置參數&#xff09;1、在調用函數進行傳遞參數時…

機器學習sklearn:支持向量機svm

概述&#xff1a;現在就只知道這個svm可以畫出決策邊界&#xff0c;對數據的劃分。簡單舉例就是&#xff1a;好的和壞的數據分開&#xff0c;中間的再驗證from sklearn.datasets import make_blobs from sklearn.svm import SVC import matplotlib.pyplot as plt import numpy …