生成式人工智能實戰 | 生成對抗網絡(Generative Adversarial Network, GAN)

生成式人工智能實戰 | 生成對抗網絡

    • 0. 前言
    • 1. 生成對抗網絡
    • 2. 模型構建
      • 2.1 生成器
      • 2.2 判別器
    • 3. 模型訓練
      • 3.1 數據加載
      • 3.2 訓練流程

0. 前言

生成對抗網絡 (Generative Adversarial Networks, GAN) 是一種由兩個相互競爭的神經網絡組成的深度學習模型,它由一個生成網絡和一個判別網絡組成,通過彼此之間的博弈來提高生成網絡的性能。生成對抗網絡使用神經網絡生成與原始圖像集非常相似的新圖像,它在圖像生成中應用廣泛,且 GAN 的相關研究正在迅速發展,以生成與真實圖像難以區分的逼真圖像。在本節中,我們將學習 GAN 網絡的原理并使用 PyTorch 實現 GAN

1. 生成對抗網絡

生成對抗網絡 (Generative Adversarial Networks, GAN) 包含兩個網絡:生成網絡( Generator,也稱生成器)和判別網絡( discriminator,也稱判別器)。在 GAN 網絡訓練過程中,需要有一個合理的圖像樣本數據集,生成網絡從圖像樣本中學習圖像表示,然后生成與圖像樣本相似的圖像。判別網絡接收(由生成網絡)生成的圖像和原始圖像樣本作為輸入,并將圖像分類為原始(真實)圖像或生成(偽造)圖像:

  • 生成器 G ( z ; θ G ) G(z;θ_G) G(z;θG?) 接受噪聲 z ~ p z z~p_z zpz??,學習映射到數據空間,以“欺騙”判別器
  • 判別器 D ( x ; θ D ) D(x;θ_D) D(x;θD?) 輸出樣本 x x x 屬于真實數據的概率,旨在區分真實與生成數據
    兩者通過以下最小–最大化 (minimax) 目標函數進行博弈:
    m i n G m a x D ? V ( D , G ) = E x ~ p d a t a [ l o g ? D ( x ) ] + E z ~ p z [ l o g ? ( 1 ? D ( G ( z ) ) ) ] \underset G {min} \underset D {max} ?V(D,G)=\mathbb E_{x~p_{data}}[log?D(x)]+\mathbb E_{z~p_z}[log?(1?D(G(z)))] Gmin?Dmax??V(D,G)=Expdata??[log?D(x)]+Ezpz??[log?(1?D(G(z)))]
    生成網絡的目標是生成逼真的偽造圖像騙過判別網絡,判別網絡的目標是將生成的圖像分類為偽造圖像,將原始圖像樣本分類為真實圖像。本質上,GAN 中的對抗表示兩個網絡的相反性質,生成網絡生成圖像來欺騙判別網絡,判別網絡通過判別圖像是生成圖像還是原始圖像來對輸入圖像進行分類:

GAN原理

在上圖中,生成網絡根據輸入隨機噪聲生成圖像,判別網絡接收生成網絡生成的圖像,并將它們與真實圖像樣本進行比較,以判斷生成的圖像是真實的還是偽造的。生成網絡嘗試生成盡可能逼真的圖像,而判別網絡嘗試判定生成網絡生成圖像的真實性,從而學習生成盡可能逼真的圖像。
GAN 的關鍵思想是生成網絡和判別網絡之間的競爭和動態平衡,通過不斷的訓練和迭代,生成網絡和判別網絡會逐漸提高性能,生成網絡能夠生成更加逼真的樣本,而判別網絡則能夠更準確地區分真實和偽造的樣本。
通常,生成網絡和判別網絡交替訓練,將生成網絡和判別網絡視為博弈雙方,并通過兩者之間的對抗來推動模型性能的提升,直到生成網絡生成的樣本能夠以假亂真,判別網絡無法分辨真實樣本和生成樣本之間的差異:

  • 生成網絡的訓練過程:凍結判別網絡權重,生成網絡以噪聲 z 作為輸入,通過最小化生成網絡與真實數據之間的差異來學習如何生成更好的樣本,以便判別網絡將圖像分類為真實圖像
  • 判別網絡的訓練過程:凍結生成網絡權重,判別網絡通過最小化真實樣本和假樣本之間的分類誤差來更新判別網絡,區分真實樣本和生成樣本,將生成網絡生成的圖像分類為偽造圖像

重復訓練生成網絡與判別網絡,直到達到平衡,當判別網絡能夠很好地檢測到生成的圖像時,生成網絡對應的損失比判別網絡對應的損失要高得多。通過不斷訓練生成網絡和判別網絡,直到生成網絡可以生成逼真圖像,而判別網絡無法區分真實圖像和生成圖像。

2. 模型構建

2.1 生成器

生成器由若干全連接層與 LeakyReLU 激活構成,最后用 Tanh 將輸出映射至 [?1,1] 范圍內:

# 定義生成器 G
import torch.nn as nnclass Generator(nn.Module):def __init__(self, z_dim=10):super().__init__()self.net = nn.Sequential(nn.Linear(z_dim, 256),nn.LeakyReLU(0.2),nn.Linear(256, 512),nn.LeakyReLU(0.2),nn.Linear(512, 1024),nn.LeakyReLU(0.2),nn.Linear(1024, 28*28),nn.Tanh()  # 輸出像素映射到 [-1,1])def forward(self, z):return self.net(z).view(-1,1,28,28)

2.2 判別器

判別器使用全連接層與 LeakyReLU 激活,末端使用 Sigmoid 激活函數,輸出一個標量真值估計:

# 定義判別器 D,對輸入圖片輸出真偽概率
class Discriminator(nn.Module):def __init__(self, img_dim=28*28):super().__init__()self.model = nn.Sequential(nn.Flatten(),nn.Linear(img_dim, 1024),nn.LeakyReLU(0.2),nn.Linear(1024, 512),nn.LeakyReLU(0.2),nn.Linear(512, 256),nn.LeakyReLU(0.2),nn.Linear(256, 1),nn.Sigmoid())def forward(self, x):return self.model(x)

3. 模型訓練

接下來,使用 MNIST 數據集訓練 GAN 模型。

3.1 數據加載

MNIST 像素值歸一化到生成器 Tanh 輸出所需的 [?1,1] 區間:

# 加載并歸一化 MNIST 數據集
from torchvision import datasets, transforms
from torch.utils.data import DataLoadertransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))  # 映射到 [-1,1]
])
train_ds = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)

3.2 訓練流程

首先,初始化模型、優化器與損失函數:

import torch
import torch.optim as optimdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
z_dim = 20
G = Generator(z_dim).to(device)
D = Discriminator().to(device)opt_G = optim.Adam(G.parameters(), lr=1e-4, betas=(0.5, 0.999))
opt_D = optim.Adam(D.parameters(), lr=1e-4, betas=(0.5, 0.999))
loss_fn = nn.BCELoss()

訓練模型 50epoch

epochs = 50for epoch in range(epochs):for real, _ in train_loader:real = real.to(device)batch_size = real.size(0)real_labels = torch.ones(batch_size, 1, device=device)fake_labels = torch.zeros(batch_size, 1, device=device)# 訓練判別器# 在真實樣本上進行訓練D_real = D(real)loss_D_real = loss_fn(D_real, real_labels)opt_D.zero_grad()loss_D_real.backward()opt_D.step()# 在虛假樣本上進行訓練z = torch.randn(batch_size, z_dim, device=device)fake = G(z)D_fake = D(fake.detach())loss_D_fake = loss_fn(D_fake, fake_labels)opt_D.zero_grad()loss_D_fake.backward()opt_D.step()d_loss = (loss_D_real + loss_D_fake) / 2# 訓練生成器z = torch.randn(batch_size, z_dim, device=device)fake = G(z)D_fake = D(fake)loss_G = loss_fn(D_fake, real_labels)  # 生成器希望D認為它生成的是真的opt_G.zero_grad()loss_G.backward()opt_G.step()print(f"Epoch [{epoch+1}/{epochs}]  Loss_D: {d_loss.item():.4f}  Loss_G: {loss_G.item():.4f}")

使用訓練后的模型生成偽造數據:

# 采樣生成圖片并顯示
import matplotlib.pyplot as pltG.eval()
with torch.no_grad():z = torch.randn(16, z_dim, device=device)fake_images = G(z).cpu()fig, axes = plt.subplots(4, 4, figsize=(6, 6))
for i, ax in enumerate(axes.flatten()):ax.imshow(fake_images[i].squeeze().reshape(28, 28), cmap='gray')ax.axis('off')
plt.tight_layout()
plt.show()

生成結果

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/86496.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/86496.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/86496.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

緩存與加速技術實踐-MongoDB數據庫應用

一.什么是MongoDB MongoDB 是一個文檔型數據庫,數據以類似 JSON 的文檔形式存儲。 MongoDB 的設計理念是為了應對大數據量、高性能和靈活性需求。 MongoDB 使用集合(Collections)來組織文檔(Documents)&#xff0…

聲網對話式AI把“答疑機器人”變成“有思維的助教”

作為一家專注初高中學生的線上教育平臺,我們精心打磨的系統化課程收獲了不少認可,但課后無人答疑的難題卻始終橫亙在前。學生課后遇到疑惑,要么只能默默憋在心里,要么就得苦苦等待下一節課,家長們也頻繁抱怨 “花了錢&…

常見的排序方法

目錄 1. 插入排序 2. 希爾排序 3. 選擇排序 4. 堆排序 5. 冒泡排序 6. 快速排序 1. 快速排序的實現 1. 思路(以從小到大排序為例) 2. 選取基準元素的方法(Hoare) 3. 選取基準元素的方法(挖坑法) …

【matlab定位例程】基于AOA和TDOA混合的定位方法,背景為三維空間,自適應錨點數量,附下載鏈接

文章目錄 代碼概述代碼功能概述核心算法原理AOA定位模型TDOA定位迭代算法混合定位策略關鍵技術創新 運行結果4個錨點的情況40個錨點的情況 MATLAB源代碼 代碼概述 代碼功能概述 本代碼實現了一種三維空間中的混合定位算法,結合到達角( A O A AOA AOA&a…

專題:2025醫療AI應用研究報告|附200+份報告PDF匯總下載

原文鏈接:https://tecdat.cn/?p42748 本報告匯總解讀聚焦醫療行業人工智能應用的前沿動態與市場機遇,以數據驅動視角剖析技術演進與商業落地的關鍵路徑。從GenAI在醫療領域的爆發式增長,到細分場景的成熟度矩陣,再到運營成本壓力…

推薦一個前端基于vue3.x,vite7.x,后端基于springboot3.4.x的完全開源的前后端分離的中后臺管理系統基礎項目(純凈版)

XHan Admin 簡介 🎉🎉 XHan Admin 是一個開箱即用的開源中后臺管理系統基礎解決方案, 項目為前后端分離架構。采用最新的技術棧全新構建,純凈的項目代碼,沒有歷史包袱。 前端使用最新發布的 vite7.0 版本構建&#xf…

MySQL誤刪數據急救指南:基于Binlog日志的實戰恢復詳解

背景 數據誤刪是一個比較嚴重的場景 1.典型誤操作場景 場景1:DELETE FROM orders WHERE status0 → 漏寫AND create_time>‘2025-06-20’ 場景2:DROP TABLE customer → 誤執行于生產環境 認識 binlog 1.binlog 的核心作用 記錄所有 DDL/DML 操…

高效數據采集方案:快速部署與應用 AnyCrawl 網頁爬蟲工具實操指南

以下是對 AnyCrawl 的簡單介紹: AnyCrawl 提供高性能網頁數據爬取,其功能專為 LLM 集成和數據處理而設計支持利用搜索引擎直接查詢獲取結果內容,類似 searxng提供開發者友好的API,支持動態內容抓取,并輸出結構化數據&…

vue3可以分頁、搜索的select

下載 npm i v-selectpage基本使用 import { SelectPageList } from v-selectpage;<SelectPageListlanguage"zh-chs"key-prop"id"label-prop"name"fetch-data"fetchData" />const fetchData (data,callback) > {const { sea…

C# 入門學習教程 (一)

文章目錄 一、解決方案與項目1. Solution 與 project 二、類與名稱空間1.類與名稱空間2.類庫的引用1. DLL引用&#xff08;黑盒引用&#xff0c;無源代碼&#xff09;2. Nuget 引用3. 項目引用&#xff08;白盒引用&#xff0c;有源代碼&#xff09; 3.依賴關系 三、類&#xf…

76、單元測試-參數化測試

76、單元測試-參數化測試 參數化測試是一種單元測試技術&#xff0c;通過將測試數據與測試邏輯分離&#xff0c;使用不同的輸入參數多次運行相同的測試用例&#xff0c;從而提高測試效率和代碼復用性。 #### 基本原理 - **數據驅動測試**&#xff1a;將測試數據參數化&#xf…

SQL學習筆記3

SQL常用函數 1、字符串函數 函數調用的語法&#xff1a;select 函數&#xff08;參數); 常用的字符串函數有&#xff1a; 拼接字符串&#xff0c;將幾個字符串拼到一起&#xff1a;concat (s1,s2,……); select concat(你好,hello); update mytable set wherefo concat(中…

Golang 面向對象編程,如何實現 封裝、繼承、多態

Go語言雖然不是純粹的面向對象語言&#xff0c;但它通過結構體(struct)、接口(interface)和方法(method)提供了面向對象編程的能力。下面我將通過具體示例展示Go中如何實現類、封裝、繼承、多態以及構造函數等概念。 1. 類與封裝 在Go中&#xff0c;使用結構體(struct)來定義…

為什么android要使用Binder機制

1.linux中大多數標準 IPC 場景&#xff08;如管道、消息隊列、ioctl 等&#xff09;的進程間通信機制 ------------------ ------------------ ------------------ | 用戶進程 A | | 內核空間 | | 用戶進程 B | | (User Spa…

OpenCV CUDA模塊設備層-----雙曲余弦函數cosh()

操作系統&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 編程語言&#xff1a;C11 算法描述 該函數用于計算四維浮點向量&#xff08;float4類型&#xff09;的雙曲余弦值&#xff0c;作用于CUDA設備端。雙曲余弦函數定義為cosh(x) (e? …

48頁PPT | 企業數字化轉型關鍵方法論:實踐路徑、案例和落地評估框架

目錄 一、什么是企業數據化轉型&#xff1f; 二、為什么要進行數據化轉型&#xff1f; 1. 市場復雜性與不確定性上升 2. 內部流程效率與協同難題突出 3. 數字資產沉淀不足&#xff0c;智能化基礎薄弱 三、數據化流程管理&#xff1a;從“業務流程”到“數據流程”的對齊 …

VTK中的形態學處理

VTK圖像處理代碼解析:閾值化與形態學開閉運算 這段代碼展示了使用VTK進行醫學圖像處理的兩個關鍵步驟:閾值分割和形態學開閉運算。下面我將詳細解析每個部分的功能和實現原理。 處理前 處理后 1. 閾值分割部分 (vtkImageThreshold) vtkSmartPointer<vtkImageThresho…

xlsx.utils.sheet_to_json() 方法詳解

sheet_to_json() 是 SheetJS/xlsx 庫中最常用的方法之一&#xff0c;用于將 Excel 工作表&#xff08;Worksheet&#xff09;轉換為 JSON 格式數據。下面我將全面講解它的用法、參數配置和實際應用場景。 基本語法 javascript 復制 下載 const jsonData XLSX.utils.sheet…

〔從零搭建〕BI可視化平臺部署指南

&#x1f525;&#x1f525; AllData大數據產品是可定義數據中臺&#xff0c;以數據平臺為底座&#xff0c;以數據中臺為橋梁&#xff0c;以機器學習平臺為中層框架&#xff0c;以大模型應用為上游產品&#xff0c;提供全鏈路數字化解決方案。 ?杭州奧零數據科技官網&#xf…

合規型區塊鏈RWA系統解決方案報告——機構資產數字化的終極武器

&#xff08;跨境金融科技解決方案白皮書&#xff09; 一、直擊機構客戶四大痛點 痛點傳統方案缺陷我們的破局點?? 跨境資產流動性差結算周期30天&#xff0c;摩擦成本超8%?? 724h全球實時交易&#xff08;速度提升90%&#xff09;?? 合規成本飆升KYC/AML人工審核占成本…