生成式對抗網絡(GAN)模型原理概述

??生成對抗網絡(Generative Adversarial Network, GAN)是一種通過對抗訓練生成數據的深度學習模型,由生成器(Generator)和判別器(Discriminator)兩部分組成,其核心思想源于博弈論中的零和博弈。

一、核心組成

生成器(G)

??目標:生成逼真的假數據(如圖像、文本),試圖欺騙判別器。
輸入:隨機噪聲(通常服從高斯分布或均勻分布)。
輸出:合成數據(如假圖像)。

判別器(D)

??目標:區分真實數據(來自訓練集)和生成器合成的假數據。
輸出:概率值(0到1),表示輸入數據是真實的概率。

二、關于對抗訓練

1. 動態博弈

??1)生成器嘗試生成越來越逼真的數據,使得判別器無法區分真假。
2)判別器則不斷優化自身,以更準確地區分真假數據。
3)兩者交替訓練,最終達到納什均衡(生成器生成的數據與真實數據分布一致,判別器無法區分,輸出概率恒為0.5)。

2. 優化目標(極小極大博弈)

min?Gmax?DV(D,G)=Ex~pdata[logD(x)]+Ez~pz[log(1?D(G(z)))]\min_{G}{\max_D}V(D,G)=E_{x\sim p_{data}}[logD(x)]+E_{z\sim p_z}[log(1-D(G(z)))]Gmin?Dmax?V(D,G)=Expdata??[logD(x)]+Ezpz??[log(1?D(G(z)))]

??其中,
D(x)D(x)D(x):判別器對真實數據的判別結果;
G(z)G(z)G(z):生成器生成的假數據;
判別器希望最大化V(D,G)V(D,G)V(D,G)(正確分類真假數據);
生成器希望最小化V(D,G)V(D,G)V(D,G)(讓判別器無法區分)。

3.交替更新

1) 固定生成器,訓練判別器:

??用真實數據(標簽1)和生成數據(標簽0)訓練判別器,提高其鑒別能力。

2) 固定判別器,訓練生成器:

??通過反向傳播調整生成器參數,使得判別器對生成數據的輸出概率接近1(即欺騙判別器)。

三、典型應用

??圖像生成:生成逼真的人臉、風景、藝術畫(如 DCGAN、StyleGAN);
圖像編輯:圖像修復(填補缺失區域)、風格遷移(如將照片轉為油畫風格);
數據增強:為小樣本任務生成額外的訓練數據;
超分辨率重建:將低分辨率圖像恢復為高分辨率圖像。

四、優勢與挑戰

優勢

??無監督學習:無需對數據進行標注,僅通過真實數據即可訓練(適用于標注成本高的場景)。

??生成高質量數據:相比其他生成模型(如變分自編碼器 VAE),GAN 在圖像生成等任務中往往能生成更逼真、細節更豐富的數據。

??靈活性:生成器和判別器可以采用不同的網絡結構(如卷積神經網絡 CNN、循環神經網絡 RNN 等),適用于多種數據類型(圖像、文本、音頻等)。

挑戰

??訓練不穩定:容易出現 “模式崩潰”(生成器只生成少數幾種相似數據,缺乏多樣性)或難以收斂;

??平衡難題:生成器和判別器的能力需要匹配,否則可能一方過強導致另一方無法學習(如判別器太弱,生成器無需優化即可欺騙它);

??可解釋性差:生成器的內部工作機制難以解釋,生成結果的可控性較弱(近年通過改進模型如 StyleGAN 緩解了這一問題)。

五、Python示例

??使用 PyTorch 實現簡單 的GAN 模型,生成手寫數字圖像。

import matplotlib
matplotlib.use('TkAgg')import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as npplt.rcParams['font.sans-serif']=['SimHei']  # 中文支持
plt.rcParams['axes.unicode_minus']=False  # 負號顯示# 設置隨機種子,確保結果可復現
torch.manual_seed(42)
np.random.seed(42)# 定義設備
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 數據加載和預處理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))  # 將圖像歸一化到 [-1, 1]
])train_dataset = datasets.MNIST(root='./data', train=True,download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)# 定義生成器網絡
class Generator(nn.Module):def __init__(self, latent_dim=100, img_dim=784):super(Generator, self).__init__()self.model = nn.Sequential(nn.Linear(latent_dim, 256),nn.LeakyReLU(0.2),nn.BatchNorm1d(256),nn.Linear(256, 512),nn.LeakyReLU(0.2),nn.BatchNorm1d(512),nn.Linear(512, img_dim),nn.Tanh()  # 輸出范圍 [-1, 1])def forward(self, z):return self.model(z).view(z.size(0), 1, 28, 28)# 定義判別器網絡
class Discriminator(nn.Module):def __init__(self, img_dim=784):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(img_dim, 512),nn.LeakyReLU(0.2),nn.Dropout(0.3),nn.Linear(512, 256),nn.LeakyReLU(0.2),nn.Dropout(0.3),nn.Linear(256, 1),nn.Sigmoid()  # 輸出概率值)def forward(self, img):img_flat = img.view(img.size(0), -1)return self.model(img_flat)# 初始化模型
latent_dim = 100
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)# 定義損失函數和優化器
criterion = nn.BCELoss()
lr = 0.0002
g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))# 訓練函數
def train_gan(epochs):for epoch in range(epochs):for i, (real_imgs, _) in enumerate(train_loader):batch_size = real_imgs.size(0)real_imgs = real_imgs.to(device)# 創建標簽real_labels = torch.ones(batch_size, 1).to(device)fake_labels = torch.zeros(batch_size, 1).to(device)# ---------------------#  訓練判別器# ---------------------d_optimizer.zero_grad()# 計算判別器對真實圖像的損失real_pred = discriminator(real_imgs)d_real_loss = criterion(real_pred, real_labels)# 生成假圖像z = torch.randn(batch_size, latent_dim).to(device)fake_imgs = generator(z)# 計算判別器對假圖像的損失fake_pred = discriminator(fake_imgs.detach())d_fake_loss = criterion(fake_pred, fake_labels)# 總判別器損失d_loss = d_real_loss + d_fake_lossd_loss.backward()d_optimizer.step()# ---------------------#  訓練生成器# ---------------------g_optimizer.zero_grad()# 生成假圖像fake_imgs = generator(z)# 計算判別器對假圖像的預測fake_pred = discriminator(fake_imgs)# 生成器希望判別器將假圖像判斷為真g_loss = criterion(fake_pred, real_labels)g_loss.backward()g_optimizer.step()# 打印訓練進度if i % 100 == 0:print(f"Epoch [{epoch}/{epochs}] Batch {i}/{len(train_loader)} "f"Discriminator Loss: {d_loss.item():.4f}, Generator Loss: {g_loss.item():.4f}")# 每個epoch結束后,生成一些樣本圖像if (epoch + 1) % 10 == 0:generate_samples(generator, epoch + 1, latent_dim, device)# 生成樣本圖像
def generate_samples(generator, epoch, latent_dim, device, n_samples=16):generator.eval()z = torch.randn(n_samples, latent_dim).to(device)with torch.no_grad():samples = generator(z).cpu()# 可視化生成的樣本fig, axes = plt.subplots(4, 4, figsize=(8, 8))for i, ax in enumerate(axes.flatten()):ax.imshow(samples[i][0].numpy(), cmap='gray')ax.axis('off')plt.tight_layout()plt.savefig(f"gan_samples/gan_samples_epoch_{epoch}.png")plt.close()generator.train()# 訓練模型
train_gan(epochs=50)# 生成最終樣本
generate_samples(generator, "final", latent_dim, device)

最終生成的樣本:
在這里插入圖片描述

六、小結

??GAN通過對抗機制實現了強大的生成能力,成為生成模型領域的里程碑技術。衍生變體(如CGAN、CycleGAN等)進一步擴展了其應用場景。



End.

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/90825.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/90825.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/90825.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Vue和Element的使用

文章目錄1.vue 腳手架創建步驟2.vue項目開發流程3.vue路由4.Element1.vue 腳手架創建步驟 創建一個文件夾 vue雙擊進入文件夾,在路徑上輸入cmd輸入vue ui, 目的:調出圖形化用戶界面點擊創建 9. 10.在vscode中打開 主要目錄介紹 src目錄介紹 vue項目啟動 圖形化界面中沒有npm…

如何設置直播間的觀看門檻,讓直播間安全有效地運行?

文章目錄前言一、直播間觀看門檻有哪幾種形式?二、設置直播間的觀看門檻,對直播的好處是什么三、如何一站式實現上述功能?總結前言 打造一個安全、高效、互動良好的直播間并非易事。面對海量涌入的觀眾,如何有效識別并阻擋潛在的…

【SkyWalking】配置告警規則并通過 Webhook 推送釘釘通知

🧭 本文為 【SkyWalking 系列】第 3 篇 👉 系列導航:點擊跳轉 【SkyWalking】配置告警規則并通過 Webhook 推送釘釘通知 簡介 介紹 SkyWalking 告警機制、告警規則格式以及如何通過 webhook 方式將告警信息發送到釘釘。 引入 服務響應超時…

關于 驗證碼系統 詳解

驗證碼系統的目的是:阻止自動化腳本訪問網頁資源,驗證訪問者是否為真實人類用戶。它通過各種測試(圖像、行為、計算等)判斷請求是否來自機器人。一、驗證碼系統的整體架構驗證碼系統通常由 客戶端 服務端 風控模型 數據采集 四…

微服務集成snail-job分布式定時任務系統實踐

前言 從事開發工作的同學,應該對定時任務的概念并不陌生,就是我們的系統在運行過程中能夠自動執行的一些任務、工作流程,無需人工干預。常見的使用場景包括:數據庫的定時備份、文件系統的定時上傳云端服務、每天早上的業務報表數…

依賴注入的邏輯基于Java語言

對于一個廚師,要做一道菜。傳統的做法是:你需要什么食材,就自己去菜市場買什么。這意味著你必須知道去哪個菜市場、怎么挑選食材、怎么討價還價等等。你不僅要會做菜,還要會買菜,職責變得復雜了。 而依賴注入就像是有一…

skywalking鏡像應用springboot的例子

目錄 1、skywalking-ui連接skywalking-oap服務失敗問題 2、k8s環境 檢查skywalking-oap服務狀態 3、本地iidea啟動服務連接skywalking oap服務 4、基于apache-skywalking-java-agent-9.4.0.tgz構建skywalking-agent鏡像 4.1、Dockerfile內容如下 4.2、AbstractBuilder.M…

3. java 堆和 JVM 內存結構

1. JVM介紹和運行流程-CSDN博客 2. 什么是程序計數器-CSDN博客 3. java 堆和 JVM 內存結構-CSDN博客 4. 虛擬機棧-CSDN博客 5. JVM 的方法區-CSDN博客 6. JVM直接內存-CSDN博客 7. JVM類加載器與雙親委派模型-CSDN博客 8. JVM類裝載的執行過程-CSDN博客 9. JVM垃圾回收…

UnityShader——SSAO

目錄 1.是什么 2.原理 3.各部分解釋 2.1.從屏幕空間到視圖空間 2.2.以法線半球為基,獲取隨機向量 2.3.應用偏移,并將其轉換為uv坐標 2.4.獲取深度 2.5.比較并計算貢獻 2.6.最后計算 4.改進 4.1.平滑過渡 4.2.模糊 5.變量和語句解釋 5.1._D…

【設計模式】外觀模式(門面模式)

外觀模式(Facade Pattern)詳解一、外觀模式簡介 外觀模式(Facade Pattern) 是一種 結構型設計模式,它為一個復雜的子系統提供一個統一的高層接口,使得子系統更容易使用。 外觀模式又稱為門面模式&#xff0…

【6.1.1 漫畫分庫分表】

漫畫分庫分表 “數據量大了不可怕,可怕的是不知道如何優雅地拆分。” 🎭 人物介紹 架構師老王:資深數據庫架構專家,精通各種分庫分表方案Java小明:對分庫分表充滿疑問的開發者ShardingSphere師傅:Apache S…

Tomcat問題:啟動腳本startup.bat中文亂碼問題解決

一、問題描述 我們第一次下載或者打開Tomcat時可能在控制臺會出現中文亂碼問題二、解決辦法 我的是8.x版本的tomcat用notepad打開:logging.properties 找到:java.util.logging.ConsoleHandler.encoding設置成GBK,重啟tomcat即可

Linux中Gitee的使用

一、Gitee簡介:Gitee(碼云)是中國的一個代碼托管和協作開發平臺,類似于GitHub或GitLab,主要面向開發者提供代碼管理、項目協作及開源生態服務。適用場景個人開發者:托管私有代碼或參與開源項目。中小企業&a…

Oracle大表數據清理優化與注意事項詳解

一、性能優化策略 1. 批量處理優化批量大小選擇: 小批量(1,000-10,000行):減少UNDO生成,但需要更多提交次數中批量(10,000-100,000行):平衡性能與資源消耗大批量(100,000行):適合高配置環境,但需監控資源使…

Anaconda及Conda介紹及使用

文章目錄Anaconda簡介為什么選擇 Anaconda?Anaconda 安裝Win 平臺macOS 平臺Linux 平臺Anaconda 界面使用Conda簡介Conda下載安裝conda 命令環境管理包管理其他常用命令Jupyter Notebook(可選)Anaconda簡介 Anaconda 是一個數據科學和機器學…

外包干了一周,技術明顯退步

我是一名本科生,自2019年起,我便在南京某軟件公司擔任功能測試的工作。這份工作雖然穩定,但日復一日的重復性工作讓我逐漸陷入了舒適區,失去了前進的動力。兩年的時光匆匆流逝,我卻在原地踏步,技術沒有絲毫…

【QT】多線程相關教程

一、核心概念與 Qt 線程模型 1.線程與進程的區別: 線程是程序執行的最小單元,進程是資源分配的最小單元,線程共享進程的內存空間(堆,全局變量等),而進程擁有獨立的內存空間。Qt線程只要關注同一進程內的并發。 2.為什么使用多線程…

VS 版本更新git安全保護問題的解決

問題:我可能移動了一個VS C# 項目,然后,發現里面的git版本檢測不能用了 正在打開存儲庫: X:\Prj_C#\3D fatal: detected dubious ownership in repository at X:/Prj_C#/3DSnapCatch X:/Prj_C#/3D is owned by:S-1-5-32-544 but the current …

Git常用命令一覽

Git 是基于 Linux內核開發的版本控制工具。與常用的版本控制工具 CVS, Subversion 等不同,它采用了分布式版本庫的方式,不必服務器端軟件支持(ps:這得分是用什么樣的服務端,使用http協議或者git協議等不太一樣。并且在…

基于 JSON 文件定位圖片缺陷點并保存

基于JSON的圖片缺陷處理流程 ├── 1. 輸入檢查 │ ├── 驗證圖片文件是否存在 │ └── 驗證JSON文件是否存在 │ ├── 2. 數據加載 │ ├── 打開并加載圖片 │ └── 讀取并解析JSON文件 │ ├── 3. 缺陷信息提取 │ ├── 檢查JSON中是否存在shapes字…