【Pytorch】生成對抗網絡實戰

GAN框架基于兩個模型的競爭,Generator生成器和Discriminator鑒別器。生成器生成假圖像,鑒別器則嘗試從假圖像中識別真實的圖像。作為這種競爭的結果,生成器將生成更好看的假圖像,而鑒別器將更好地識別它們。

目錄

創建數據集

定義生成器

定義鑒別器

初始化模型權重

定義損失函數

定義優化器

訓練模型

部署生成器


創建數據集

使用 PyTorch torchvision 包中提供的 STL-10 數據集,數據集中有 10 個類:飛機、鳥、車、貓、鹿、狗、馬、猴、船、卡車。圖像為96*96像素的RGB圖像。數據集包含 5,000 張訓練圖像和 8,000 張測試圖像。在訓練數據集和測試數據集中,每個類分別有 500 和 800 張圖像。

?STL-10數據集詳細參考http://t.csdnimg.cn/ojBn6中數據加載和處理部分?

from torchvision import datasets
import torchvision.transforms as transforms
import os# 定義數據集路徑
path2data="./data"
# 創建數據集路徑
os.makedirs(path2data, exist_ok= True)# 定義圖像尺寸
h, w = 64, 64
# 定義均值
mean = (0.5, 0.5, 0.5)
# 定義標準差
std = (0.5, 0.5, 0.5)
# 定義數據預處理
transform= transforms.Compose([transforms.Resize((h,w)),  # 調整圖像尺寸transforms.CenterCrop((h,w)),  # 中心裁剪transforms.ToTensor(),  # 轉換為張量transforms.Normalize(mean, std)])  # 歸一化# 加載訓練集
train_ds=datasets.STL10(path2data, split='train', download=False,transform=transform)

?展示示例圖像張量形狀、最小值和最大值

import torch
for x, _ in train_ds:print(x.shape, torch.min(x), torch.max(x))break

?展示示例圖像

from torchvision.transforms.functional import to_pil_image
import matplotlib.pylab as plt
%matplotlib inline
plt.imshow(to_pil_image(0.5*x+0.5))

?

創建數據加載器?

import torch
batch_size = 32
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)

?示例

for x,y in train_dl:print(x.shape, y.shape)break

定義生成器

GAN框架是基于兩個模型的競爭,generator生成器和discriminator鑒別器。生成器生成假圖像,鑒別器嘗試從假圖像中識別真實的圖像。

作為這種競爭的結果,生成器將生成更好看的假圖像,而鑒別器將更好地識別它們。

定義生成器模型?

from torch import nn
import torch.nn.functional as Fclass Generator(nn.Module):def __init__(self, params):super(Generator, self).__init__()# 獲取參數nz = params["nz"]ngf = params["ngf"]noc = params["noc"]# 定義反卷積層1self.dconv1 = nn.ConvTranspose2d( nz, ngf * 8, kernel_size=4,stride=1, padding=0, bias=False)# 定義批歸一化層1self.bn1 = nn.BatchNorm2d(ngf * 8)# 定義反卷積層2self.dconv2 = nn.ConvTranspose2d(ngf * 8, ngf * 4, kernel_size=4, stride=2, padding=1, bias=False)# 定義批歸一化層2self.bn2 = nn.BatchNorm2d(ngf * 4)# 定義反卷積層3self.dconv3 = nn.ConvTranspose2d( ngf * 4, ngf * 2, kernel_size=4, stride=2, padding=1, bias=False)# 定義批歸一化層3self.bn3 = nn.BatchNorm2d(ngf * 2)# 定義反卷積層4self.dconv4 = nn.ConvTranspose2d( ngf * 2, ngf, kernel_size=4, stride=2, padding=1, bias=False)# 定義批歸一化層4self.bn4 = nn.BatchNorm2d(ngf)# 定義反卷積層5self.dconv5 = nn.ConvTranspose2d( ngf, noc, kernel_size=4, stride=2, padding=1, bias=False)# 前向傳播def forward(self, x):# 反卷積層1x = F.relu(self.bn1(self.dconv1(x)))# 反卷積層2x = F.relu(self.bn2(self.dconv2(x)))            # 反卷積層3x = F.relu(self.bn3(self.dconv3(x)))        # 反卷積層4x = F.relu(self.bn4(self.dconv4(x)))    # 反卷積層5out = torch.tanh(self.dconv5(x))return out

設定生成器模型參數、移動模型到cuda設備并打印模型結構?

params_gen = {"nz": 100,"ngf": 64,"noc": 3,}
model_gen = Generator(params_gen)
device = torch.device("cuda:0")
model_gen.to(device)
print(model_gen)

定義鑒別器

定義鑒別器模型,?用于鑒別真實圖像

class Discriminator(nn.Module):def __init__(self, params):super(Discriminator, self).__init__()# 獲取參數nic= params["nic"]ndf = params["ndf"]# 定義卷積層1self.conv1 = nn.Conv2d(nic, ndf, kernel_size=4, stride=2, padding=1, bias=False)# 定義卷積層2self.conv2 = nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1, bias=False)# 定義批歸一化層2self.bn2 = nn.BatchNorm2d(ndf * 2)            # 定義卷積層3self.conv3 = nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1, bias=False)# 定義批歸一化層3self.bn3 = nn.BatchNorm2d(ndf * 4)# 定義卷積層4self.conv4 = nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1, bias=False)# 定義批歸一化層4self.bn4 = nn.BatchNorm2d(ndf * 8)# 定義卷積層5self.conv5 = nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=0, bias=False)def forward(self, x):# 使用leaky_relu激活函數對卷積層1的輸出進行激活x = F.leaky_relu(self.conv1(x), 0.2, True)# 使用leaky_relu激活函數對卷積層2的輸出進行激活,并使用批歸一化層2進行批歸一化x = F.leaky_relu(self.bn2(self.conv2(x)), 0.2, inplace = True)# 使用leaky_relu激活函數對卷積層3的輸出進行激活,并使用批歸一化層3進行批歸一化x = F.leaky_relu(self.bn3(self.conv3(x)), 0.2, inplace = True)# 使用leaky_relu激活函數對卷積層4的輸出進行激活,并使用批歸一化層4進行批歸一化x = F.leaky_relu(self.bn4(self.conv4(x)), 0.2, inplace = True)        # 使用sigmoid激活函數對卷積層5的輸出進行激活,并返回結果# Sigmoid激活函數是一種常用的非線性激活函數,它將輸入值壓縮到0和1之間,[ \sigma(x) = \frac{1}{1 + e^{-x}} ]out = torch.sigmoid(self.conv5(x))return out.view(-1)

設置模型參數,移動模型到cuda設備,打印模型結構?


params_dis = {"nic": 3,"ndf": 64}
model_dis = Discriminator(params_dis)
model_dis.to(device)
print(model_dis)

初始化模型權重

定義函數,初始化模型權重?

def initialize_weights(model):# 獲取模型類的名稱classname = model.__class__.__name__# 如果模型類名稱中包含'Conv',則初始化權重為均值為0,標準差為0.02的正態分布if classname.find('Conv') != -1:nn.init.normal_(model.weight.data, 0.0, 0.02)# 如果模型類名稱中包含'BatchNorm',則初始化權重為均值為1,標準差為0.02的正態分布,偏置為0elif classname.find('BatchNorm') != -1:nn.init.normal_(model.weight.data, 1.0, 0.02)nn.init.constant_(model.bias.data, 0)

初始化生成器模型和鑒別器模型的權重?

# 對生成器模型應用初始化權重函數
model_gen.apply(initialize_weights);
# 對判別器模型應用初始化權重函數
model_dis.apply(initialize_weights);

定義損失函數

定義二元交叉熵(BCE)損失函數?

loss_func = nn.BCELoss()

定義優化器

定義Adam優化器

from torch import optim
# 學習率
lr = 2e-4 
# Adam優化器的beta1參數
beta1 = 0.5
# 定義鑒別器模型的優化器,學習率為lr,beta1參數為beta1,beta2參數為0.999
opt_dis = optim.Adam(model_dis.parameters(), lr=lr, betas=(beta1, 0.999))
# 定義生成器模型的優化器
opt_gen = optim.Adam(model_gen.parameters(), lr=lr, betas=(beta1, 0.999))

訓練模型

?示例訓練1000個epochs

# 定義真實標簽和虛假標簽
real_label = 1
fake_label = 0
# 獲取生成器的噪聲維度
nz = params_gen["nz"]
# 設置訓練輪數
num_epochs = 1000
# 定義損失歷史記錄
loss_history={"gen": [],"dis": []}
# 定義批次數
batch_count = 0
# 遍歷訓練輪數
for epoch in range(num_epochs):# 遍歷訓練數據for xb, yb in train_dl:# 獲取批大小ba_si = xb.size(0)# 將判別器梯度置零model_dis.zero_grad()# 將輸入數據移動到指定設備xb = xb.to(device)# 將標簽數據轉換為指定設備yb = torch.full((ba_si,), real_label, device=device)# 判別器輸出out_dis = model_dis(xb)# 將輸出和標簽轉換為浮點數out_dis = out_dis.float()yb = yb.float()# 計算真實樣本的損失loss_r = loss_func(out_dis, yb)# 反向傳播loss_r.backward()# 生成噪聲noise = torch.randn(ba_si, nz, 1, 1, device=device)# 生成器輸出out_gen = model_gen(noise)# 判別器輸出out_dis = model_dis(out_gen.detach())# 將標簽數據填充為虛假標簽yb.fill_(fake_label)    # 計算虛假樣本的損失loss_f = loss_func(out_dis, yb)# 反向傳播loss_f.backward()# 計算判別器的總損失loss_dis = loss_r + loss_f  # 更新判別器的參數opt_dis.step()   # 將生成器梯度置零model_gen.zero_grad()# 將標簽數據填充為真實標簽yb.fill_(real_label)  # 判別器輸出out_dis = model_dis(out_gen)# 計算生成器的損失loss_gen = loss_func(out_dis, yb)# 反向傳播loss_gen.backward()# 更新生成器的參數opt_gen.step()# 記錄生成器和判別器的損失loss_history["gen"].append(loss_gen.item())loss_history["dis"].append(loss_dis.item())# 更新批次數batch_count += 1# 每100個批打印一次損失if batch_count % 100 == 0:print(epoch, loss_gen.item(),loss_dis.item())

?繪制損失圖像

plt.figure(figsize=(10,5))
plt.title("Loss Progress")
plt.plot(loss_history["gen"],label="Gen. Loss")
plt.plot(loss_history["dis"],label="Dis. Loss")
plt.xlabel("batch count")
plt.ylabel("Loss")
plt.legend()
plt.show()

存儲模型權重?

import os
path2models = "./models/"
os.makedirs(path2models, exist_ok=True)
path2weights_gen = os.path.join(path2models, "weights_gen_128.pt")
path2weights_dis = os.path.join(path2models, "weights_dis_128.pt")
torch.save(model_gen.state_dict(), path2weights_gen)
torch.save(model_dis.state_dict(), path2weights_dis)

部署生成器

通常情況下,訓練完成后放棄鑒別器模型而保留生成器模型,部署經過訓練的生成器來生成新的圖像。為部署生成器模型,將訓練好的權重加載到模型中,然后給模型提供隨機噪聲。

# 加載生成器模型的權重
weights = torch.load(path2weights_gen)
# 將權重加載到生成器模型中
model_gen.load_state_dict(weights)
# 將生成器模型設置為評估模式
model_gen.eval()

?生成圖像

import numpy as np
with torch.no_grad():# 生成固定噪聲fixed_noise = torch.randn(16, nz, 1, 1, device=device)# 打印噪聲形狀print(fixed_noise.shape)# 生成假圖像img_fake = model_gen(fixed_noise).detach().cpu()    
# 打印假圖像形狀
print(img_fake.shape)
# 創建畫布
plt.figure(figsize=(10,10))
# 遍歷假圖像
for ii in range(16):# 在畫布上繪制圖像plt.subplot(4,4,ii+1)# 將圖像轉換為PIL圖像plt.imshow(to_pil_image(0.5*img_fake[ii]+0.5))# 關閉坐標軸plt.axis("off")

其中一些可能看起來扭曲,而另一些看起來相對真實。為改進結果,可以在單個數據類上訓練模型,而不是在多個類上一起訓練。GAN在使用單個類進行訓練時表現更好。此外,可以嘗試更長時間地訓練模型。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/94982.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/94982.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/94982.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Java基礎第7天總結(代碼塊、內部類、函數式編程)

代碼塊靜態代碼塊:有static修飾,屬于類,與類一起優先加載,自動執行一次實例代碼塊:無static修飾,屬于對象,每次創建對象時,都會優先執行一次。package com.itheima.code;import java…

文獻綜述寫作指南:從海量文獻到邏輯閉環的實戰模板

文獻綜述往往是學術寫作的“第一關難題”:面對成百上千篇文獻,如何避免“簡單羅列”的陷阱,梳理出有邏輯、有洞見的論述體系?本文結合學術寫作實踐,總結出一套模塊化的文獻綜述“實戰模板”,通過結構化方法…

CuTe C++ 簡介01,從示例開始

這里先僅僅關注 C 層的介紹,python DSL 以后再說。在 ubuntu 22.04 X64 中,RTX 50801. 環境搭建1.1 安裝 cuda1.2 下載源碼git clone https://github.com/NVIDIA/cutlass.git1.3 編譯mkdir build/ cmake .. -DCUTLASS_NVCC_ARCHS"120" -DCMAK…

Python實現異步多線程Web服務器:從原理到實踐

目錄Python實現異步多線程Web服務器:從原理到實踐引言第一章:Web服務器基礎1.1 Web服務器的工作原理1.2 HTTP協議簡介1.3 同步 vs 異步 vs 多線程第二章:Python異步編程基礎2.1 異步I/O概念2.2 協程與async/await2.3 事件循環第三章&#xff…

Deep Think with Confidence:llm如何進行高效率COT推理優化

1. 引言:大模型的推理解碼優化 大型語言模型(LLM)在處理數學、編碼等復雜推理任務時,一種強大但“耗能巨大”的技術是self-consistency,也稱并行思考(parallel thinking)。其核心思想是讓模型對同一個問題生成多條不同的“思考路徑”(reasoning traces),然后通過多數…

vscode克隆遠程代碼步驟

一、直接使用VsCode1.復制git的https鏈接代碼2.在vscode中點擊 代碼管理-克隆倉庫3.粘貼(在git里面復制的https鏈接)4.選擇需要存儲的文件位置5.確認6.代碼克隆成功二、使用命令行克隆1.確定文件放置位置,右鍵2.復制git的https鏈接代碼3.粘貼…

spi總線

一、介紹SPI總線(Serial Peripheral Interface,串行外設接口)是一種高速全雙工同步串行通信總線,核心通過“主從架構同步時鐘”實現設備間數據傳輸,因結構簡單、速率高,廣泛用于MCU與傳感器、存儲芯片、顯示…

COLA:大型語言模型高效微調的革命性框架

本文由「大千AI助手」原創發布,專注用真話講AI,回歸技術本質。拒絕神話或妖魔化。搜索「大千AI助手」關注我,一起撕掉過度包裝,學習真實的AI技術! 1 COLA技術概述 COLA(Chain of LoRA)是一種創…

數據結構與算法:線段樹(三):維護更多信息

前言 這次的題思維上倒不是很難&#xff0c;就是代碼量比較大。 一、開關 洛谷的這種板子題寫起來比cf順多了&#xff08;&#xff09; #include <bits/stdc.h> using namespace std;typedef long long ll; typedef pair<int,int> pii; typedef pair<ll,ll&…

【LeetCode_27】移除元素

刷爆LeetCode系列LeetCode27題&#xff1a;github地址前言題目描述題目思路分析代碼實現算法代碼優化LeetCode27題&#xff1a; github地址 有夢想的電信狗 前言 本文用C實現LeetCode 第27題 題目描述 題目鏈接&#xff1a;https://leetcode.cn/problems/remove-element/ …

C++11語言(三)

一、引言上期我們介紹了C11的大部分特性。C11的初始化列表、auto關鍵字、右值引用、萬能引用、STL容器的的emplace函數。要補充的是右值引用是不能取地址的&#xff0c;我們程序員一定要遵守相關的語法。操作是未定義的很危險。二、 仿函數和函數指針我們先從仿函數的形…

性能優化三劍客:`memo`, `useCallback`, `useMemo` 詳解

性能優化三劍客&#xff1a;memo, useCallback, useMemo 詳解 作者&#xff1a;碼力無邊各位React性能調優師&#xff0c;歡迎來到《React奇妙之旅》的第十二站&#xff01;我是你們的伙伴碼力無邊。在之前的旅程中&#xff0c;我們已經掌握了如何構建功能豐富的組件&#xff0…

好用的電腦軟件、工具推薦和記錄

固態硬盤讀寫測試 AS SSD Benchmark https://gitee.com/qlexcel/common-resource-backup/blob/master/AS%20SSD%20Benchmark.exe 可以測試SSD的持續讀寫、4K隨機讀寫等性能。也可以測試HDD的性能。 操作非常簡單&#xff0c;點擊Start(開始)即可測試。 體積小&#xff0c;免安…

Spring Task快速上手

一. 介紹Spring Task 是Spring框架提供的任務調度工具&#xff0c;可以按照約定的時間自動執行某個代碼邏輯&#xff0c;無需依賴額外組件&#xff08;如 Quartz&#xff09;&#xff0c;配置簡單、使用便捷&#xff0c;適合處理周期性執行的任務&#xff08;如定時備份數據、定…

函數(2)

6.定義函數的終極絕殺思路&#xff1a;三個問題&#xff1a;1.我定義函數&#xff0c;是為了干什么事情 函數體、2.我干完這件事&#xff0c;需要什么才能完成 形參3.我干完了&#xff0c;調用處是否需要繼續使用 返回值類型需要繼續使用 必須寫不需要返回 void小程序#include …

BGP路由協議(一):基本概念

###BGP概述 BGP的版本&#xff1a; BGP-1 RFC1105BGP-2 RFC1163BGP-3 RFC1267BGP-4 RFC1771 1994年BGP-4 RFC4271 2006年 AS Autonomous System 自治系統&#xff1a;由一個單一的機構或者組織所管理的一系列IP網絡及其設備所構成的集合 根據工作范圍的不同&#xff0c;動態路…

mit6.031 2023spring 軟件構造 筆記 Testing

當你編碼時&#xff0c;目標是使程序正常工作。 但作為測試設計者&#xff0c;你希望讓它失敗。 這是一個微妙但重要的區別。 為什么軟件測試很難&#xff1f; 做不到十分詳盡&#xff1a;測試一個 32 位浮點乘法運算 。有 2^64 個測試用例&#xff01;隨機或統計測試效果差&am…

【Unity開發】Unity核心學習(三)

四、三維模型導入相關設置 1、Model模型頁簽&#xff08;1&#xff09;場景相關&#xff08;2&#xff09;網格相關&#xff08;3&#xff09;幾何體相關2、Rig操縱&#xff08;骨骼&#xff09;頁簽 &#xff08;1&#xff09;面板基礎信息&#xff08;i&#xff09;None&…

C#語言入門詳解(17)字段、屬性、索引器、常量

C#語言入門詳解&#xff08;17&#xff09;字段、屬性、索引器、常量前言一、字段 Field二、屬性三、索引器四、常量內容來自劉鐵猛C#語言入門詳解課程。 參考文檔&#xff1a;CSharp language specification 5.0 中文版 前言 類的成員是靜態成員 (static member) 或者實例成…

Total PDF Converter多功能 PDF 批量轉換工具,無水印 + 高效處理指南

在辦公場景中&#xff0c;PDF 格式的 “不可編輯性” 常成為效率瓶頸 —— 從提取文字到格式轉換&#xff0c;從批量處理到文檔加密&#xff0c;往往需要多款工具協同。Total PDF Converter 破解專業版作為一站式 PDF 解決方案&#xff0c;不僅支持 11 種主流格式轉換&#xff…