GANs生成對抗網絡生成手寫數字的Pytorch實現

目錄

一、第三方庫導入

二、數據集準備

三、使用轉置卷積的生成器

四、使用卷積的判別器

五、生成器生成圖像

六、主程序

七、運行結果

7.1 生成器和判別器的損失函數圖像

7.2 訓練過程中生成器生成的圖像

八、完整的pytorch代碼


由于之前寫gans的代碼時,我的生成器和判別器不是使用的全連接網絡就是卷積,但是無論這兩種方法怎么組合,最后生成器生成的圖像效果都很不好。因此最后我選擇了生成器使用轉置卷積,而判別器使用卷積,最后得到的生成圖像確實效果比之前好很多了。

一、第三方庫導入

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # 設置中文字體
plt.rcParams['axes.unicode_minus'] = False  # 正常顯示負號
from torchvision import transforms
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader

二、數據集準備

# 手寫數字數據集
class MINISTDataset(Dataset):def __init__(self, files, root_dir, transform=None):self.files = filesself.root_dir = root_dirself.transform = transformself.labels = []for f in files:parts = f.split("_")p = parts[2].split(".")[0]self.labels.append(int(p))def __len__(self):return len(self.files)def __getitem__(self, idx):img_path = os.path.join(self.root_dir, self.files[idx])img = Image.open(img_path).convert("L")if self.transform:img = self.transform(img)label = self.labels[idx]return img, label

三、使用轉置卷積的生成器

class Generator(nn.Module):def __init__(self, latent_dim=100):super().__init__()self.main = nn.Sequential(# 輸入: latent_dim維噪聲 -> 輸出: 7x7x256nn.ConvTranspose2d(latent_dim, 256, kernel_size=7, stride=1, padding=0, bias=False),nn.BatchNorm2d(256),nn.ReLU(True),# 上采樣: 7x7 -> 14x14nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(128),nn.ReLU(True),# 上采樣: 14x14 -> 28x28nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(64),nn.ReLU(True),# 輸出層: 28x28x1nn.ConvTranspose2d(64, 1, kernel_size=3, stride=1, padding=1, bias=False),nn.Tanh())def forward(self, x):# 將噪聲重塑為 (batch_size, latent_dim, 1, 1)x = x.view(x.size(0), -1, 1, 1)return self.main(x)

四、使用卷積的判別器

class Discriminator(nn.Module):def __init__(self):super().__init__()self.main = nn.Sequential(# 輸入: 1x28x28nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),  # 輸出: 32x14x14nn.LeakyReLU(0.2, inplace=True),nn.Dropout2d(0.3),nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # 輸出: 64x7x7nn.BatchNorm2d(64),nn.LeakyReLU(0.2, inplace=True),nn.Dropout2d(0.3),nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),  # 輸出: 128x7x7nn.BatchNorm2d(128),nn.LeakyReLU(0.2, inplace=True),nn.Dropout2d(0.3),nn.Flatten(),nn.Linear(128 * 7 * 7, 1),nn.Sigmoid())def forward(self, x):return self.main(x)

五、生成器生成圖像

# 展示生成器生成的圖像
def gen_img_plot(test_input, save_path):gen_imgs = gen(test_input).detach().cpu()gen_imgs = gen_imgs.view(-1, 28, 28)plt.figure(figsize=(4, 4))for i in range(16):plt.subplot(4, 4, i + 1)plt.imshow(gen_imgs[i], cmap="gray")plt.axis("off")plt.savefig(save_path, dpi=300)plt.close()

六、主程序

if __name__ == "__main__":# 對數據做歸一化處理transforms = transforms.Compose([transforms.Resize((28, 28)),transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))])# 路徑base_dir = 'C:\\Users\\Administrator\\PycharmProjects\\CNN'train_dir = os.path.join(base_dir, "minist_train")# 獲取文件夾里圖像的名稱train_files = [f for f in os.listdir(train_dir) if f.endswith('.jpg')]# 創建數據集和數據加載器train_dataset = MINISTDataset(train_files, train_dir, transform=transforms)train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)# 參數epochs = 50lr = 0.0002# 初始化模型的優化器和損失函數gen = Generator()dis = Discriminator()d_optim = torch.optim.Adam(dis.parameters(), lr=lr, betas=(0.5, 0.999))  # 判別器的優化器g_optim = torch.optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))  # 生成器的優化器loss_fn = torch.nn.BCELoss()  # 二分類交叉熵損失函數# 記錄lossD_loss = []G_loss = []# 訓練for epoch in range(epochs):d_epoch_loss = 0g_epoch_loss = 0count = len(train_loader)  # 返回批次數for step, (img, _) in enumerate(train_loader):# 每個批次的大小size = img.size(0)random_noise = torch.randn(size, 100)# 判別器訓練d_optim.zero_grad()real_output = dis(img)d_real_loss = loss_fn(real_output, torch.ones_like(real_output))# d_real_loss.backward()gen_img = gen(random_noise)gen_img = gen_img.view(size, 1, 28, 28)fake_output = dis(gen_img.detach())d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))# d_fake_loss.backward()d_loss = (d_real_loss + d_fake_loss) / 2d_loss.backward()d_optim.step()# 生成器的訓練g_optim.zero_grad()fake_output = dis(gen_img)g_loss = loss_fn(fake_output, torch.ones_like(fake_output))g_loss.backward()g_optim.step()# 計算在一個epoch里面所有的g_loss和d_losswith torch.no_grad():d_epoch_loss += d_lossg_epoch_loss += g_loss# 計算平均損失值with torch.no_grad():d_epoch_loss = d_epoch_loss / countg_epoch_loss = g_epoch_loss / countD_loss.append(d_epoch_loss.item())G_loss.append(g_epoch_loss.item())print("Epoch:", epoch, "  D loss:", d_epoch_loss.item(), "  G Loss:", g_epoch_loss.item())# 每隔2個epoch繪制生成器生成的圖像if (epoch + 1) % 2 == 0:test_input = torch.randn(16, 100)name = f"gen_img_{epoch}.jpg"save_path = os.path.join('C:\\Users\\Administrator\\PycharmProjects\\CNN\\gen_img_11', name)gen_img_plot(test_input, save_path)# 繪制損失曲線圖plt.figure(figsize=(12, 6))plt.plot(D_loss, label="判別器", color="tomato")plt.plot(G_loss, label="生成器", color="orange")plt.xlabel("epoch")plt.ylabel("loss")plt.title("生成器和判別器的損失曲線圖")plt.legend()plt.grid()plt.savefig("C:\\Users\\Administrator\\PycharmProjects\\CNN\\gen_dis_loss_11.jpg", dpi=300, bbox_inches="tight")plt.close()

七、運行結果

7.1 生成器和判別器的損失函數圖像

7.2 訓練過程中生成器生成的圖像

這里只展示一部分

gen_img_1.jpg

gen_img_25.jpg

gen_img_49.jpg

八、完整的pytorch代碼

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # 設置中文字體
plt.rcParams['axes.unicode_minus'] = False  # 正常顯示負號
from torchvision import transforms
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader# 手寫數字數據集
class MINISTDataset(Dataset):def __init__(self, files, root_dir, transform=None):self.files = filesself.root_dir = root_dirself.transform = transformself.labels = []for f in files:parts = f.split("_")p = parts[2].split(".")[0]self.labels.append(int(p))def __len__(self):return len(self.files)def __getitem__(self, idx):img_path = os.path.join(self.root_dir, self.files[idx])img = Image.open(img_path).convert("L")if self.transform:img = self.transform(img)label = self.labels[idx]return img, label# 改進的生成器(使用轉置卷積)
class Generator(nn.Module):def __init__(self, latent_dim=100):super().__init__()self.main = nn.Sequential(# 輸入: latent_dim維噪聲 -> 輸出: 7x7x256nn.ConvTranspose2d(latent_dim, 256, kernel_size=7, stride=1, padding=0, bias=False),nn.BatchNorm2d(256),nn.ReLU(True),# 上采樣: 7x7 -> 14x14nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(128),nn.ReLU(True),# 上采樣: 14x14 -> 28x28nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),nn.BatchNorm2d(64),nn.ReLU(True),# 輸出層: 28x28x1nn.ConvTranspose2d(64, 1, kernel_size=3, stride=1, padding=1, bias=False),nn.Tanh())def forward(self, x):# 將噪聲重塑為 (batch_size, latent_dim, 1, 1)x = x.view(x.size(0), -1, 1, 1)return self.main(x)# 改進的判別器(使用深度卷積網絡)
class Discriminator(nn.Module):def __init__(self):super().__init__()self.main = nn.Sequential(# 輸入: 1x28x28nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),  # 輸出: 32x14x14nn.LeakyReLU(0.2, inplace=True),nn.Dropout2d(0.3),nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # 輸出: 64x7x7nn.BatchNorm2d(64),nn.LeakyReLU(0.2, inplace=True),nn.Dropout2d(0.3),nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),  # 輸出: 128x7x7nn.BatchNorm2d(128),nn.LeakyReLU(0.2, inplace=True),nn.Dropout2d(0.3),nn.Flatten(),nn.Linear(128 * 7 * 7, 1),nn.Sigmoid())def forward(self, x):return self.main(x)# 展示生成器生成的圖像
def gen_img_plot(test_input, save_path):gen_imgs = gen(test_input).detach().cpu()gen_imgs = gen_imgs.view(-1, 28, 28)plt.figure(figsize=(4, 4))for i in range(16):plt.subplot(4, 4, i + 1)plt.imshow(gen_imgs[i], cmap="gray")plt.axis("off")plt.savefig(save_path, dpi=300)plt.close()if __name__ == "__main__":# 對數據做歸一化處理transforms = transforms.Compose([transforms.Resize((28, 28)),transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))])# 路徑base_dir = 'C:\\Users\\Administrator\\PycharmProjects\\CNN'train_dir = os.path.join(base_dir, "minist_train")# 獲取文件夾里圖像的名稱train_files = [f for f in os.listdir(train_dir) if f.endswith('.jpg')]# 創建數據集和數據加載器train_dataset = MINISTDataset(train_files, train_dir, transform=transforms)train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)# 參數epochs = 50lr = 0.0002# 初始化模型的優化器和損失函數gen = Generator()dis = Discriminator()d_optim = torch.optim.Adam(dis.parameters(), lr=lr, betas=(0.5, 0.999))  # 判別器的優化器g_optim = torch.optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))  # 生成器的優化器loss_fn = torch.nn.BCELoss()  # 二分類交叉熵損失函數# 記錄lossD_loss = []G_loss = []# 訓練for epoch in range(epochs):d_epoch_loss = 0g_epoch_loss = 0count = len(train_loader)  # 返回批次數for step, (img, _) in enumerate(train_loader):# 每個批次的大小size = img.size(0)random_noise = torch.randn(size, 100)# 判別器訓練d_optim.zero_grad()real_output = dis(img)d_real_loss = loss_fn(real_output, torch.ones_like(real_output))# d_real_loss.backward()gen_img = gen(random_noise)gen_img = gen_img.view(size, 1, 28, 28)fake_output = dis(gen_img.detach())d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))# d_fake_loss.backward()d_loss = (d_real_loss + d_fake_loss) / 2d_loss.backward()d_optim.step()# 生成器的訓練g_optim.zero_grad()fake_output = dis(gen_img)g_loss = loss_fn(fake_output, torch.ones_like(fake_output))g_loss.backward()g_optim.step()# 計算在一個epoch里面所有的g_loss和d_losswith torch.no_grad():d_epoch_loss += d_lossg_epoch_loss += g_loss# 計算平均損失值with torch.no_grad():d_epoch_loss = d_epoch_loss / countg_epoch_loss = g_epoch_loss / countD_loss.append(d_epoch_loss.item())G_loss.append(g_epoch_loss.item())print("Epoch:", epoch, "  D loss:", d_epoch_loss.item(), "  G Loss:", g_epoch_loss.item())# 每隔2個epoch繪制生成器生成的圖像if (epoch + 1) % 2 == 0:test_input = torch.randn(16, 100)name = f"gen_img_{epoch}.jpg"save_path = os.path.join('C:\\Users\\Administrator\\PycharmProjects\\CNN\\gen_img_11', name)gen_img_plot(test_input, save_path)# 繪制損失曲線圖plt.figure(figsize=(12, 6))plt.plot(D_loss, label="判別器", color="tomato")plt.plot(G_loss, label="生成器", color="orange")plt.xlabel("epoch")plt.ylabel("loss")plt.title("生成器和判別器的損失曲線圖")plt.legend()plt.grid()plt.savefig("C:\\Users\\Administrator\\PycharmProjects\\CNN\\gen_dis_loss_11.jpg", dpi=300, bbox_inches="tight")plt.close()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/93412.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/93412.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/93412.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

ubuntu 通過NAT模式上網

這里必須使用VMnet8 設置為NAT模式 下面設置Ip地址區域ubuntu ip地址設置來自于上面

盲盒抽谷機小程序系統開發:從0到1的完整方法論

開發一款成功的盲盒抽谷機小程序系統,需兼顧技術實現、用戶體驗與商業邏輯。本文將從需求分析、UI/UX設計、技術架構、測試上線到運營增長,系統梳理從0到1的完整方法論。需求分析:明確“為誰而做”盲盒抽谷機的核心用戶是18-35歲的二次元愛好…

web開發,在線%射擊比賽管理%系統開發demo,基于html,css,jquery,python,django,三層mysql數據庫

經驗心得 兩業務單,業務crud開發很簡單了,自行學習,我說一下學習流程。什么是前端,用到那些技術html,css,javascript分別是什么?進階jquery,bootstrap,各種常見前端組件又是什么,前端框架react,angular以及…

Centos9傻瓜式linux部署CRMEB 開源商城系統(PHP)

服務器環境推薦要求* Nignx(必須) * PHP 7.1 ~ 7.4(必須此版本內,版本過大會警告不兼容) * MySQL 5.7 ~ 8.0(必須) * Redis(非必須)后臺頁面展示:…

AI 云電競游戲盒子:從“盒子”到“云-端-芯”一體化競技平臺的架構實踐

摘要 AI 云電競游戲盒子(以下簡稱“電競盒”)不再是一臺簡單的客廳游戲主機,而是一套以 AI 調度為核心、以云原生架構為骨架、以邊緣渲染為肌肉、以端側感知為神經的“云-端-芯”協同競技系統。本文基于 2024 年 Q2 落地的量產方案&#xff0…

基于kuboard實現kubernetes的集群管理

1、前提條件安裝docker-compose2、步驟在本地目錄創建kuboard-v4\在該目錄下創建文件docker-compose.yaml,內容如下:configs:create_db_sql:content: |CREATE DATABASE kuboard DEFAULT CHARACTER SET utf8mb4 DEFAULT COLLATE utf8mb4_unicode_ci;cre…

Linux操作系統軟件編程——多線程

什么是線程線程的定義是輕量級的進程,可以實現多任務的并發。線程是操作系統任務調度的最小單位線程的創建由某個進程創建,且進程創建線程時,會為其分配獨立的棧區空間(默認8M)。線程和所在的進程,以及進程…

linux下找到指定目錄下最新日期log文件

以下是一個完整的C函數&#xff0c;用于在指定目錄下自動查找最近更新的日志文件&#xff08;根據文件名中的時間戳選擇最新的文件&#xff09;&#xff1a;#include <stdio.h> #include <stdlib.h> #include <string.h> #include <dirent.h> #include…

《數學模型》經典案例——鋼管的訂購與運輸

一、問題描述 要鋪設一條 A1→A2→?→A15A_1 \rightarrow A_2 \rightarrow \cdots \rightarrow A_{15}A1?→A2?→?→A15? 的輸送天然氣的主管道&#xff0c;如圖 6.22 所示。經篩選后可以生產這種主管道鋼管的鋼廠有 S1,S2,?,S7S_1, S_2, \cdots, S_7S1?,S2?,?,S7? 。…

Java Web部署

今天小編來分享下如何將本地寫的Java Web程序部署到Linux上。 小編介紹兩種方式&#xff1a; 部署基于Linux Systemd服務、基于Docker容器化部署 首先部署基于Linux Systemd服務 那么部署之前&#xff0c;要對下載所需的環境 軟件下載 Linux&#xff08;以ubuntu&#xf…

告別AI“煉丹術”:“策略懸崖”理論如何為大模型對齊指明科學路徑

摘要&#xff1a;當前&#xff0c;我們訓練大模型的方式&#xff0c;尤其是RLHF&#xff0c;充滿了不確定性&#xff0c;時常產生“諂媚”、“欺騙”等怪異行為&#xff0c;被戲稱為“煉丹”。一篇來自上海AI Lab的重磅論文提出的“策略懸崖”理論&#xff0c;首次為這個混沌的…

深入理解C#特性:從應用到自定義

——解鎖元數據標記的高級玩法&#x1f4a1; 核心認知&#xff1a;特性本質揭秘 public sealed class ReviewCommentAttribute : System.Attribute { ... }特性即特殊類&#xff1a;所有自定義特性必須繼承 System.Attribute&#xff08;基礎規則&#xff09;命名規范&#xff…

機器學習-集成學習(EnsembleLearning)

0 結果展示 0.1 鳶尾花分類 import pandas as pd import numpy as npfrom sklearn.ensemble import RandomForestClassifier from sklearn.model_selection import train_test_split from sklearn.metrics import accuracy_score, recall_score, f1_score, classification_repo…

Golang database/sql 包深度解析(一)

database/sql 是 Go 語言標準庫中用于與 SQL&#xff08;或類 SQL&#xff09;數據庫交互的核心包&#xff0c;提供了一套輕量級、通用的接口&#xff0c;使得開發者可以用統一的方式操作各種不同的數據庫&#xff0c;而無需關心底層數據庫驅動的具體實現。 核心設計理念 datab…

文章自然潤色 API 數據接口

文章自然潤色 API 數據接口 ai / 文本處理 基于 AI 的文章潤色 專有模型 / 智能糾錯。 1. 產品功能 基于自有專業模型進行 AI 智能潤色對原始內容進行智能糾錯高效的文本潤色性能全接口支持 HTTPS&#xff08;TLS v1.0 / v1.1 / v1.2 / v1.3&#xff09;&#xff1b;全面兼容…

【狀壓DP】3276. 選擇矩陣中單元格的最大得分|2403

本文涉及知識點 C動態規劃 3276. 選擇矩陣中單元格的最大得分 給你一個由正整數構成的二維矩陣 grid。 你需要從矩陣中選擇 一個或多個 單元格&#xff0c;選中的單元格應滿足以下條件&#xff1a; 所選單元格中的任意兩個單元格都不會處于矩陣的 同一行。 所選單元格的值 互…

IDEA 清除 ctrl+shift+r 全局搜索記錄

定位文件&#xff1a;在Windows系統中&#xff0c;文件通常位于C:Users/用戶名/AppData/Roaming/JetBrains/IntelliJIdea(idea版本)/workspace目錄下&#xff0c;文件名為一小串隨機字符&#xff1b;在Mac系統中&#xff0c;文件位于/Users/用戶名/Library/Application /Suppor…

解鎖AI大模型:Prompt工程全面解析

解鎖AI大模型&#xff1a;Prompt工程全面解析 本文較長&#xff0c;建議點贊收藏&#xff0c;以免遺失。更多AI大模型開發 學習視頻/籽料/面試題 都在這>>Github<< 從新手到高手&#xff0c;Prompt 工程究竟是什么&#xff1f; 在當今數字化時代&#xff0c;AI …

HTTP0.9/1.0/1.1/2.0

在HTTP0.9中&#xff0c;只有GET方法&#xff0c;沒有請求頭headers&#xff0c;沒有狀態碼&#xff0c;只能用于傳輸HTML文件。到了HTTP1.0(1996)&#xff0c;HTTP1.0傳輸請求頭&#xff0c;有狀態碼&#xff0c;并且新增了POST和HEAD方法。HTTP1.0中&#xff0c;使用短連接&a…

gitee 流水線+docker-compose部署 nodejs服務+mysql+redis

文章中的方法是自己琢磨出來的&#xff0c;或許有更優解&#xff0c;共同學習&#xff0c;共同進步&#xff01; docker-compose.yml 文件配置&#xff1a; 說明&#xff1a;【配置中有個別字段冗余&#xff0c;但不影響使用】該文件推薦放在nodejs項目的根目錄中&#xff0c…