【PyTorch][chapter 20][李宏毅深度學習]【無監督學習][ GAN]【實戰】

前言

?本篇主要是結合手寫數字例子,結合PyTorch 介紹一下Gan 實戰

第一輪訓練效果

第20輪訓練效果,已經可以生成數字了

68 輪


目錄:?

  1. ? 谷歌云服務器(Google Colab)
  2. ? 整體訓練流程
  3. ? Python 代碼

一? 谷歌云服務器(Google Colab)

? ? ?個人用的一直是聯想小新筆記本,雖然非常穩定方便。但是現在跑深度學習,性能確實有點跟不上.?

? ?1.1? ? 打開谷歌云服務器(Google Colab)

? ? ? https://colab.research.google.com/

? ? 1. 2? 新建筆記

? ? ? ? ? ? ? ? ?

1

?1.4? 選擇T4GPU?

1.5? 點擊運行按鈕

可以看到當前硬件的情況

? ? ?


二? 整體訓練流程


三? ? PyTorch 例子

# -*- coding: utf-8 -*-
"""
Created on Fri Mar  1 13:27:49 2024@author: chengxf2
"""
import torch.optim as optim #優化器
import numpy as np 
import matplotlib.pyplot  as plt
import torchvision
from torchvision import transforms
import torch
import torch.nn as nn#第一步加載手寫數字集
def loadData():#同時歸一化數據集(-1,1)style = transforms.Compose([transforms.ToTensor(),   #0-1 歸一化0-1, channel,height,widthtransforms.Normalize(mean=0.5, std=0.5) #變成了-1,1 ])trainData = torchvision.datasets.MNIST('data',train=True,transform=style,download=True)dataloader = torch.utils.data.DataLoader(trainData,batch_size= 16,shuffle=True)imgs,_ = next(iter(dataloader))#torch.Size([64, 1, 28, 28])print("\n imgs shape ",imgs.shape)return dataloaderclass Generator(nn.Module):'''定義生成器輸入:z 隨機噪聲[batch, input_size]輸出:x: 圖片 [batch, height, width, channel]'''def __init__(self,input_size):super(Generator,self).__init__()self.net = nn.Sequential(nn.Linear(in_features = input_size , out_features =256),nn.ReLU(),nn.Linear(in_features = 256 , out_features =512),nn.ReLU(),nn.Linear(in_features = 512 , out_features =28*28),nn.Tanh())def forward(self, z):# z 隨機輸入[batch, dim]x = self.net(z)#[batch, height, width, channel]#print(x.shape)x = x.view(-1,28,28,1)return xclass Discriminator(nn.Module):'''定義鑒別器輸入:x: 圖片 [batch, height, width, channel]輸出:y:  二分類圖片的概率: BCELoss 計算交叉熵損失'''def __init__(self):super(Discriminator,self).__init__()#開始的維度和終止的維度,默認值分別是1和-1self.flatten = nn.Flatten()self.net = nn.Sequential(nn.Linear(in_features = 28*28 , out_features =512),nn.LeakyReLU(), #負值的時候保留梯度信息nn.Linear(in_features = 512 , out_features =256),nn.LeakyReLU(),nn.Linear(in_features = 256 , out_features =1),nn.Sigmoid())def forward(self, x):x = self.flatten(x)#print(x.shape)out =self.net(x)return outdef gen_img_plot(model, epoch, test_input):out = model(test_input).detach().cpu()out = out.numpy()imgs = np.squeeze(out)fig = plt.figure(figsize=(4,4))for i in range(out.shape[0]):plt.subplot(4,4,i+1)img = (imgs[i]+1)/2.0#[-1,1]plt.imshow(img)plt.axis('off')plt.show()def train():#1 初始化參數device ='cuda' if torch.cuda.is_available() else 'cpu'#2 加載訓練數據dataloader = loadData()test_input  = torch.randn(16,100,device=device)#3 超參數maxIter = 20 #最大訓練次數input_size = 100batchNum = 16input_size =100#4 初始化模型gen = Generator(100).to(device)dis = Discriminator().to(device)#5 優化器,損失函數d_optim = torch.optim.Adam(dis.parameters(), lr=1e-4)g_optim = torch.optim.Adam(gen.parameters(),lr=1e-4)loss_fn = torch.nn.BCELoss()#6 loss 變化列表D_loss =[]G_loss= []for epoch in range(0,maxIter):d_epoch_loss = 0.0g_epoch_loss  =0.0#count = len(dataloader)for step ,(realImgs, _) in enumerate(dataloader):realImgs = realImgs.to(device)random_noise = torch.randn(batchNum, input_size).to(device)#先訓練判別器d_optim.zero_grad()real_output = dis(realImgs)d_real_loss = loss_fn(real_output, torch.ones_like(real_output))d_real_loss.backward()#不要訓練生成器,所以要生成器detachfake_img = gen(random_noise)fake_output = dis(fake_img.detach())d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))d_fake_loss.backward()d_loss = d_real_loss+d_fake_lossd_optim.step()#優化生成器g_optim.zero_grad()fake_output = dis(fake_img.detach())g_loss = loss_fn(fake_output, torch.ones_like(fake_output))g_loss.backward()g_optim.step()with torch.no_grad():d_epoch_loss+= d_lossg_epoch_loss+= g_losscount = 16       with torch.no_grad():d_epoch_loss/=countg_epoch_loss/=countD_loss.append(d_epoch_loss)G_loss.append(g_epoch_loss)gen_img_plot(gen, epoch, test_input)print("Epoch: ",epoch)print("-----finised-----")if __name__ == "__main__":train()

參考:

10.完整課程簡介_嗶哩嗶哩_bilibili

理論【PyTorch][chapter 19][李宏毅深度學習]【無監督學習][ GAN]【理論】-CSDN博客

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/711894.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/711894.shtml
英文地址,請注明出處:http://en.pswp.cn/news/711894.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Linux學習-字符串數組和字符串

目錄 使用場景 字符型數組定義: 初始化 數組儲存 打印 字符型數組常見函數 常見操作 strcpy:字符串拷貝 strcat(str1,str2)字符串拼接 strcmp:字符串比較 注意: 二維字符型數…

Open CASCADE學習|曲線曲面連續性

1、曲線的連續性 曲線的連續性是三維建模、動畫設計等領域中非常重要的一個概念,它涉及到曲線在不同點之間的連接方式和光滑程度。下面將詳細介紹曲線的連續性,包括C連續性和G連續性。 1.1C連續性(參數連續性) C連續性是指曲線…

使用MyBatisPlus實現向數據庫中存儲List類型的數據

使用MyBatisPlus實現向數據庫中存儲List類型的數據 問題描述 建表時,表中的這五個字段為json類型 但是在入庫的時候既不能寫入數據,也不能查詢出數據。 解決方案: 1.首先明確,數據存入的時候是經過了數據類型轉化&#xff0c…

中國電子學會2020年06月真題C語言軟件編程等級考試三級(含詳細解析答案)

中國電子學會考評中心歷屆真題(含解析答案) C語言軟件編程等級考試三級 2020年06月 編程題五道 總分:100分一、最接近的分數(20分) 分母不超過N且小于A/B的最大最簡分數是多少? 時間限制: 1000ms 內存限制: 65536kb 輸入…

數據之光:探索數據庫技術的演進之路

?? 歡迎大家來訪Srlua的博文(づ ̄3 ̄)づ╭?~?? 🌟🌟 歡迎各位親愛的讀者,感謝你們抽出寶貴的時間來閱讀我的文章。 我是Srlua,在這里我會分享我的知識和經驗。&#x…

喜訊!持安科技CEO何藝獲評安全419《2023年度十大優秀創業者》

近日,由網絡安全產業資訊媒體安全419主辦的《年度策劃》2023年度十大優秀創業者正式出爐,零信任辦公安全技術創新企業持安科技創始人兼CEO何藝,獲評十大優秀創業者。 這是安全419第二屆推出該項目的評選活動,安全419編輯老師在多年…

抽象類、模板方法模式

抽象類概述 在Java中abstract是抽象的意思,如果一個類中的某個方法的具體實現不能確定,就可以申明成abstract修飾的抽象方法(不能寫方法體了),這個類必須用abstract修飾,被稱為抽象類。 抽象方法定義&…

【解決】修改 UI界面渲染層級 的常見誤區

開發平臺:Unity 2021版本 ? 問題描述 Unity 中管理 UI 上顯示元素的前后層級關系大致為以下兩種方式: 方式一:修改UI元素隊列順序與層級方式二:使用 Canvas 組件中的 Override Sort 屬性配置 方式二 對應復雜的 UI 層級關系將常…

這些單片機匯編語言的錯誤,你還在犯錯嗎?

在單片機開發中,很多工程師會選擇匯編語言來作為底層編程,來直接控制硬件和高校執行命令,然而因為匯編語言是直接與硬件交互,所以很容易出現錯誤,本文將基于Keil C51匯編器的環境總結單片機匯編語言常見的錯誤&#xf…

人工智能_大模型010_Centos7.9中CPU安裝ChatGLM3-6B大模型_安裝使用_010---人工智能工作筆記0145

從一個空的虛擬機開始安裝: https://www.modelscope.cn/models/ZhipuAI/chatglm3-6b/files 可以看到這里有很多的數據文件,那么這里 這里點擊模型文件就可以下載,這個就是chatglm3-6B的文件,需要點擊每個文件,然后點擊右邊的下載,把文件都下載下來 右側有下載按鈕.點擊下載可…

使用Fabric創建的canvas畫布背景圖片,自適應畫布寬高

之前的文章寫過vue2使用fabric實現簡單畫圖demo,完成批閱功能;但是功能不完善,對于很大的圖片就只能顯示一部分出來,不符合我們的需求。這就需要改進,對我們設置的背景圖進行自適應。 有問題的canvas畫布背景 修改后的…

Unity2023.1.19_ECS

Unity2023.1.19_ECS 在學習的路上一往無前的遇到了好東西,官方的EntityComponnentSystemSamples的Repository,這是一個包含實體,圖形,網絡,物理案例的全方位案例教程。 又找見接下來要干的事情了!學習永無…

【rust】11、所有權

文章目錄 一、背景二、Stack 和 Heap2.1 Stack2.2 Heap2.3 性能區別2.4 所有權和堆棧 三、所有權原則3.1 變量作用域3.2 String 類型示例 四、變量綁定背后的數據交互4.1 所有權轉移4.1.1 基本類型: 拷貝, 不轉移所有權4.1.2 分配在 Heap 的類型: 轉移所有權 4.2 Clone(深拷貝)…

Quartz 任務調度框架源碼閱讀解析

概念: quartz 是一個基于JAVA的定時任務調度框架 案例: <dependency><groupId>org.quartz-scheduler</groupId><artifactId>quartz</artifactId><version>2.3.0</version></dependency>JobDetail job JobBuilder.newJob(Sc…

每日一練 | 華為認證真題練習Day191

1、在沒有啟用BGP路徑負載分擔的情況下&#xff0c;哪種BGP路由會發送BGP鄰居? A. 從所有鄰居學到的所有BGP路由。 B. 只有從IBGP學到的路由。 C. 只有從EBGP學到的路由。 D. 只有被BGP優選的最佳路由。 2、第三類LSA的LINK ID是 A. 生成這條LSA的路由器的ROUTER ID B. …

LeetCode 刷題 [C++] 第236題.二叉樹的最近公共祖先

題目描述 給定一個二叉樹, 找到該樹中兩個指定節點的最近公共祖先。 百度百科中最近公共祖先的定義為&#xff1a;“對于有根樹 T 的兩個節點 p、q&#xff0c;最近公共祖先表示為一個節點 x&#xff0c;滿足 x 是 p、q 的祖先且 x 的深度盡可能大&#xff08;一個節點也可以…

大數據分析案例-基于SVM支持向量機算法構建手機價格分類預測模型

&#x1f935;?♂? 個人主頁&#xff1a;艾派森的個人主頁 ?&#x1f3fb;作者簡介&#xff1a;Python學習者 &#x1f40b; 希望大家多多支持&#xff0c;我們一起進步&#xff01;&#x1f604; 如果文章對你有幫助的話&#xff0c; 歡迎評論 &#x1f4ac;點贊&#x1f4…

矩陣爆破逆向之條件斷點的妙用

不知道你是否使用過IDA的條件斷點呢&#xff1f;在IDA進階使用中&#xff0c;它的很多功能都有大作用&#xff0c;比如&#xff1a;ida-trace來跟蹤調用流程。同時IDA的斷點功能也十分強大&#xff0c;配合IDA-python的輸出語句能夠大殺特殺&#xff01; 那么本文就介紹一下這…

【JAVA】JDK內置工具之appletviewer

下載java 下載java的時候會先下載Java jdk&#xff0c;Java Development Kit Java開發工具包。 然后會下載jre&#xff0c;也就是Java Runtime Environment Java運行環境。什么是JDK、JRE&#xff1f;_java中的jdk,jre代表什么-CSDN博客 下載之后先找到java下的bin文件&#x…

yolov9 tensorRT 的 C++ 部署

yolov9 tensorRT C 部署 本示例中&#xff0c;包含完整的代碼、模型、測試圖片、測試結果。 完整的代碼、模型、測試圖片、測試結果【github參考鏈接】 TensorRT版本&#xff1a;TensorRT-7.1.3.4 導出onnx模型 導出適配本實例的onnx模型參考【yolov9 瑞芯微芯片rknn部署、地平…