VAE-變分自編碼器(Variational Autoencoder,VAE)

變分自編碼器(Variational Autoencoder,VAE)是一種生成模型,結合了概率圖模型與神經網絡技術,廣泛應用于數據生成、表示學習和數據壓縮等領域。以下是對VAE的詳細解釋和理解:

基本概念

1. 自編碼器(Autoencoder)

自編碼器是一種無監督學習模型,通常用于降維和特征提取。它由兩個主要部分組成:

  • 編碼器(Encoder):將輸入數據映射到一個低維隱變量空間。
  • 解碼器(Decoder):從低維隱變量空間重建輸入數據。
    自編碼器的目標是使重建的數據盡可能與原始輸入數據相似。

2. 變分自編碼器(VAE)

VAE 是自編碼器的一種擴展,它通過引入概率分布的概念來對隱變量空間進行建模。VAE 的目標不僅是重建輸入數據,還要使隱變量遵循某種已知的概率分布(通常是標準正態分布)。這樣可以通過采樣隱變量來生成新數據。

VAE的工作原理

  1. 編碼器
    在VAE中,編碼器不是直接輸出一個隱變量,而是輸出隱變量的參數(均值 μ 和標準差 σ)。這些參數定義了隱變量的一個概率分布,通常假設為正態分布 N(μ, σ^2)。

  2. 重新參數化技巧(Reparameterization Trick)
    為了使模型能夠通過梯度下降進行訓練,VAE引入了重新參數化技巧。通過采樣一個標準正態分布的變量 ε ~ N(0, 1),然后進行線性變換得到隱變量 z:
    在這里插入圖片描述

這樣,采樣操作變成了一個確定性的操作,允許梯度反向傳播。

  1. 解碼器
    解碼器接受從上述分布中采樣的隱變量 z,并嘗試重建輸入數據。解碼器的目標是最大化重建數據的概率。

損失函數

VAE 的損失函數由兩部分組成:

  • 重構損失(Reconstruction Loss):衡量重建數據與原始數據的相似度,通常使用均方誤差(MSE)或交叉熵損失。 KL

  • 散度(KL Divergence):衡量隱變量分布與標準正態分布的差異。通過最小化KL散度,使隱變量分布接近標準正態分布。

綜合起來,VAE的損失函數為:

在這里插入圖片描述

VAE的優點

  1. 生成能力:可以從隱變量空間采樣生成新數據,具有良好的生成能力。
  2. 隱變量解釋性:通過將隱變量空間約束為標準正態分布,隱變量具有一定的解釋性和可操作性。
  3. 無監督學習:VAE是一種無監督學習模型,不需要標簽數據即可進行訓練。

VAE的缺點

  1. **生成質量有限:**生成數據的質量有時不如GAN(生成對抗網絡)等其他生成模型。
  2. **訓練復雜:**VAE的訓練涉及到復雜的概率推斷和優化過程。

總結

變分自編碼器通過引入概率分布和重新參數化技巧,使得隱變量具有良好的生成能力和解釋性。其核心思想是在保持重建數據質量的同時,使隱變量遵循標準正態分布,從而實現數據生成和表示學習。盡管存在一些缺點,但VAE在許多應用場景中仍然表現出色,并為生成模型的研究提供了重要的理論基礎。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable# 定義VAE模型
class VAE(nn.Module):def __init__(self, input_dim, hidden_dim, latent_dim):super(VAE, self).__init__()self.fc1 = nn.Linear(input_dim, hidden_dim)self.fc21 = nn.Linear(hidden_dim, latent_dim)self.fc22 = nn.Linear(hidden_dim, latent_dim)self.fc3 = nn.Linear(latent_dim, hidden_dim)self.fc4 = nn.Linear(hidden_dim, input_dim)def encode(self, x):h1 = F.relu(self.fc1(x))return self.fc21(h1), self.fc22(h1)def reparameterize(self, mu, logvar):std = torch.exp(0.5*logvar)eps = torch.randn_like(std)return mu + eps*stddef decode(self, z):h3 = F.relu(self.fc3(z))return torch.sigmoid(self.fc4(h3))def forward(self, x):mu, logvar = self.encode(x.view(-1, 784))z = self.reparameterize(mu, logvar)return self.decode(z), mu, logvar# 定義損失函數
def loss_function(recon_x, x, mu, logvar):BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())return BCE + KLD# 加載MNIST數據集
train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=True, download=True,transform=transforms.ToTensor()),batch_size=128, shuffle=True)# 初始化模型
vae = VAE(input_dim=784, hidden_dim=512, latent_dim=20)
optimizer = optim.Adam(vae.parameters(), lr=1e-3)# 訓練模型
def train(epoch):vae.train()train_loss = 0for batch_idx, (data, _) in enumerate(train_loader):optimizer.zero_grad()recon_batch, mu, logvar = vae(data)loss = loss_function(recon_batch, data, mu, logvar)loss.backward()train_loss += loss.item()optimizer.step()if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader),loss.item() / len(data)))print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(train_loader.dataset)))# 開始訓練
for epoch in range(1, 11):train(epoch)

代碼說明

  • 編碼器和解碼器:編碼器將輸入圖像編碼為潛在空間的均值和對數方差,解碼器從潛在變量生成重建的圖像。
  • Sampling層:這是實現重參數化技巧的關鍵部分,將均值和對數方差轉換為潛在變量。
  • VAE類:組合編碼器和解碼器,并實現自定義訓練步驟,包括計算重建損失和KL散度損失。
  • 數據準備和訓練:加載MNIST數據集,對數據進行預處理,然后訓練VAE模型。
    這個示例展示了一個簡單的VAE模型。根據具體的應用需求,你可能需要調整網絡結構和超參數。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/14041.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/14041.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/14041.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

基于 Milvus Cloud + LlamaIndex 實現初級 RAG

初級 RAG 初級 RAG 的定義 初級 RAG 研究范式代表了最早的方法論,在 ChatGPT 廣泛采用后不久就取得了重要地位。初級 RAG 遵循傳統的流程,包括索引創建(Indexing)、檢索(Retrieval)和生成(Generation),常常被描繪成一個“檢索—讀取”框架,其工作流包括三個關鍵步…

AWS安全性身份和合規性之Key Management Service(KMS)

AWS Key Management Service(KMS)是一項用于創建和管理加密密鑰的托管服務,可幫助客戶保護其數據的安全性和機密性。 比如一家醫療保健公司需要在AWS上存儲敏感的病人健康數據,需要對數據進行加密以確保數據的機密性。他們使用AW…

課時134:awk實踐_邏輯控制_自定義函數

1.3.7 自定義函數 學習目標 這一節,我們從 基礎知識、簡單實踐、小結 三個方面來學習。 基礎知識 需求 雖然awk提供了內置的函數來實現相應的內置函數,但是有些功能場景,還是需要我們自己來設定,這就用到了awk的自定義函數功能…

WebSocket簡介

參考:Java NIO實現WebSocket服務器_nio websocket-CSDN博客 WebSocket API是HTML5中的一大特色,能夠使得建立連接的雙方在任意時刻相互推送消息,這意味著不同于HTTP,服務器服務器也可以主動向客戶端推送消息了。 WebSocket協議是…

使用TensorBoard記錄功能時,添加SummaryWriter到callbacks,某些版本可能不適用該如何修改

如果發現將SummaryWriter直接添加到callbacks不被支持,您可以采取另一種方式來集成TensorBoard記錄功能,即通過自定義回調函數來實現。Hugging Face Transformers庫允許用戶自定義訓練回調,這可以用來在訓練過程中向TensorBoard寫入日志。 下…

配置yum源

以下是在 Linux 系統中配置新的 yum 源的一般步驟和命令示例(以 CentOS 系統為例): 備份原有 yum 源配置文件:mv /etc/yum.repos.d/CentOS-Base.repo /etc/yum.repos.d/CentOS-Base.repo.bak 創建新的 yum 源配置文件&#xff08…

【PB案例學習筆記】-08 控件拖動實現

寫在前面 這是PB案例學習筆記系列文章的第8篇,該系列文章適合具有一定PB基礎的讀者。 通過一個個由淺入深的編程實戰案例學習,提高編程技巧,以保證小伙伴們能應付公司的各種開發需求。 文章中設計到的源碼,小凡都上傳到了gitee…

反序列化漏洞的入門知識總結

1.概念定義 序列化與反序列化的目的是讓數據在傳輸和處理的時候更簡單,更快,反序列化出現在多種同面向對象語言所開發的網站和軟件上,比如java,php,python等等,如果有語言一個都沒學的,可以先去…

1941springboot VUE 服務機構評估管理系統開發mysql數據庫web結構java編程計算機網頁源碼maven項目

一、源碼特點 springboot VUE服務機構評估管理系統是一套完善的完整信息管理類型系統,結合springboot框架和VUE完成本系統,對理解JSP java編程開發語言有幫助系統采用springboot框架(MVC模式開發),系統具有完整的源代…

【NOIP2014普及組復賽】題2:比例簡化

題2:比例簡化 【題目描述】 在社交媒體上,經常會看到針對某一個觀點同意與否的民意調查以及結果。例如,對某一觀點表示支持的有 1498 1498 1498 人,反對的有 902 902 902 人,那么贊同與反對的比例可以簡單的記為 …

計算機-編程相關

在 Linux 中、一切都是文件、硬件設備是文件、管道是文件、網絡套接字也是文件。 for https://juejin.cn/post/6844904103437582344 fork 進程的一些問題 fork 函數比較特殊、一次調用會返回兩次。在父進程和子進程都會返回。 每個進程在內核中都是一個 taskstruct 結構、for…

ECMAScript、BOM與DOM:網頁開發的三大基石

在深入Web開發的世界時,有三個核心概念構成了理解網頁如何工作以及如何與之交互的基礎:ECMAScript、BOM(Browser Object Model),以及DOM(Document Object Model)。本文旨在簡要介紹這三個概念&a…

Thingsboard規則鏈:Entity Type Switch節點詳解

在物聯網(IoT)領域,隨著設備數量的爆炸式增長和數據復雜性的增加,高效、靈活的數據處理機制變得至關重要。作為一款先進的物聯網平臺,ThingsBoard提供了強大的規則鏈(Rule Chains)功能&#xff…

第四節 Starter 加載時機和源碼理解

tips:每個 springBoot 的版本不同,代碼的實現存會存在不同。 上一章,我們聊到 mybatis-spring-boot-starter; 簡單分析了它的結構。 這一章我們將著重分析 Starter 的加載機制,并結合源碼進行分析理解。 一、加載實際…

問題與解決:element ui垂直菜單展開后顯示不全

比如我這個垂直菜單展開后,其實系統管理下面還有其他子菜單,但是顯示不出來了。 解決方法很簡單,只需要在菜單外面包一層el-scrollbar,并且將高度設置為100vh。

Laravel 11 PHP8

一直都是用laravel 7 左右的,現在要求將項目升級到laravel 11 和使用PHP8,隨手記錄一些小問題,laravel 11的包是領導給的,沒有使用composer 安裝,所以我也不確定和官方的是否一致 遇到這問題 可以這樣 env 中默認的數…

基于若依的旅游推薦管理系統(spring boot+vue+mybatis+Ajax)

一、項目目的 隨著社會的高速發展,人們生活水平的不斷提高,以及工作節奏的加快,旅游逐漸成為一個熱門的話題,因為其形式的多樣,涉及的面比較廣,成為人們放松壓力,調節情緒的首要選擇。 傳統的旅…

上位機圖像處理和嵌入式模塊部署(mcu的按鍵輸入)

【 聲明:版權所有,歡迎轉載,請勿用于商業用途。 聯系信箱:feixiaoxing 163.com】 做技術的同學,大部分都會把精力放在技術本身,卻忽視了學的東西有什么實際的用途。就拿gpio來說,一般我們點燈也…

正確認識IP地址和子網掩碼的聯系

IP地址和子網掩碼是計算機網絡中兩個非常重要的概念,它們共同確定了設備在局域網中的地址以及該地址所屬的子網,只要兩者結合,就能確定唯一地址IP66_ip歸屬地在線查詢_免費ip查詢_ip精準定位平臺。 IP地址是用于標識計算機網絡中的每臺設備的…

Ajax用法總結(包括原生Ajax、Jquery、Axois)

HTTP知識 HTTP(hypertext transport protocol)協議『超文本傳輸協議』,協議詳細規定了瀏覽器和萬維網服務器之間互相通信的規則。 請求報文 請求行: GET、POST /s?ieutf-8...(url的一長串參數) HTTP/1.1 請求頭…