ds證據理論python實現_ALI模型理論以及Python實現

https://openreview.net/forum?id=B1ElR4cgg

模型結構和明天要發BiGAN模型一模一樣,但是兩篇論文的作者都是獨立完成自己的內容的。而且從寫作的風格來看emmm完全不一樣

ALI跟BiGAN的設計一模一樣,但是就是沒有加Latent regressor。雖然在ALI中也簡要地談到了這個Latent regressor。

并且根據ALI中的模型(G, E,D的架構)更容易實現,條理更加清晰,模型的結構設計實現也很容易。

ALI和BiGAN的對比

整體的設計上一模一樣,這是共同點。并且兩者都是獨立設計的。

  1. ALI雖然提到Latent regressor但是,并沒有使用。(只是說可以用來作為一個正則化,提高精度的額外的方法);BiGAN則放了較大的筆墨在這個regressor上。

  2. ALI結構更加清晰,并且各個模塊的訓練對應的損失也較大的很清晰;BiGAN雖然有在語言上大致描述為什么,但是描述的不夠直觀清晰,而且GAN訓練本來就存在大量的坑,稍微'合理'修改某個小細節,就會導致訓練不出結果。

  3. ALI對于E的解釋上做得比較好(E可以理解為另外的一種G),這樣看來都是來fool D的所以也是一種對抗,較為直觀。并且ALI的數學分析部分和GAN承接的更加好,寫得更加清晰。

b0a2954811b60d3c3dec7a0c41e998ea.png

計算方式,使用L2范數

兩者雖然都談到了Latent regressor,但是ALI更側重筆墨于模型的結構的設計(但是畫圖不行)。BiGAN雖然更側重于Latent?regressor,但是結構畫圖相當不錯。可以說是非常喜劇了。

給個對比:

32922ec45396e2d8933df2fc033e11b0.png

兩個結構說的是一回事,當時看到真的笑死。

c5d40640ba5e9e66dc6b156510d2777d.png

雖然談到了latent regressor,但是算法中并沒有交代使用。

800a3a7ee9123b0e855c662c57263ff5.png

BiGAN雖然交代了使用,但是BiGAN沒有給損失的具體寫法,對于E的訓練要自己設計。

可能是兩者都或多或少有點問題,所以17年的ICLR就把兩篇都錄用了。

(或少:應該是ALI,或多多半是BiGAN)

后來就常用 ALI/BiGAN來表示這個模型。


恰飯


實驗

實驗相比于BiGAN沒有使用latent regressor,但是效果居然也還行。

按照論文實驗操作一樣,第一行是G(E(x)),第二行是x。

x來源是真實數據。通過E學習到x的隱式特征z,輸入給G,讓G生成。

99553413a77536ef11e2d7f767cedeb5.png

1a00d70f4520cf43f3db8c21f1a44da9.png

edf7730360c69e5ec35bd78725636fc8.png

18060262250841c74d4318118840d72d.png

main.py

import osimport torchfrom torch.utils.data import Dataset, DataLoaderimport torch.nn as nnfrom model import Generator, Discriminator, Encoderimport torchvisionimport itertoolsimport matplotlib.pyplot as pltimport torchvision.utils as vutilsimport numpy as npif __name__ == '__main__':    LR = 0.0002    EPOCH = 100  # 50    BATCH_SIZE = 100    N_IDEAS = 128    lam = 1    DOWNLOAD_MNIST = False    TRAINED = False    mnist_root = '../Conditional-GAN/mnist/'    if not (os.path.exists(mnist_root)) or not os.listdir(mnist_root):        # not mnist dir or mnist is empyt dir        DOWNLOAD_MNIST = True    train_data = torchvision.datasets.MNIST(        root=mnist_root,        train=True,  # this is training data        transform=torchvision.transforms.ToTensor(),  # Converts a PIL.Image or numpy.ndarray to        # torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]        download=DOWNLOAD_MNIST,    )    train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)    torch.cuda.empty_cache()    if TRAINED:        G = torch.load('G.pkl').cuda()        D = torch.load('D.pkl').cuda()        E = torch.load('E.pkl').cuda()    else:        G = Generator(N_IDEAS).cuda()        D = Discriminator().cuda()        E = Encoder(input_size=1, out_size=N_IDEAS).cuda()    optimizerG_E = torch.optim.Adam(itertools.chain(G.parameters(), E.parameters()), lr=LR)    optimizerD = torch.optim.Adam(D.parameters(), lr=LR)    l_c = nn.MSELoss()    for epoch in range(EPOCH):        tmpD, tmpG_E, tmpE = 0, 0, 0        for step, (x, y) in enumerate(train_loader):            # x            x = x.cuda()            z = torch.randn((x.shape[0], N_IDEAS, 1, 1)).cuda()            # z, G, D            G_z = G(z)            D_G_z = torch.mean(D(G_z, z))  # fake            # x, E, D            E_x = E(x)            D_E_x = torch.mean(D(x, E_x))  # real            D_loss = -torch.mean(torch.log(D_E_x) + torch.log(1 - D_G_z))            Latent_regress = l_c(z, E(G_z))            G_E_loss = -torch.mean(torch.log(1 - D_E_x) + torch.log(D_G_z))  # + lam * Latent_regress            optimizerD.zero_grad()            D_loss.backward(retain_graph=True)            optimizerD.step()            optimizerG_E.zero_grad()            G_E_loss.backward(retain_graph=True)            optimizerG_E.step()            tmpD_ = D_loss.cpu().detach().data            tmpG_E_ = G_E_loss.cpu().detach().data            tmpE_ = Latent_regress.cpu().detach().data            tmpD += tmpD_            tmpG_E += tmpG_E_            tmpE += tmpE_        tmpD /= (step + 1)        tmpG_E /= (step + 1)        tmpE /= (step + 1)        print(            'epoch %d avg of loss: D: %.6f, G_E: %.6f, latent: %.6f' % (epoch, tmpD, tmpG_E, tmpE)        )        if epoch % 2 == 0:            # x = x.cuda()            G_imgs = G(E(x)).cpu().detach()            fig = plt.figure(figsize=(10, 10))            plt.axis("off")            plt.imshow(                np.transpose(vutils.make_grid(torch.cat([G_imgs, x.cpu().detach()]), nrow=10, padding=0, normalize=True,                                              scale_each=True), (1, 2, 0)))            plt.savefig('E_%d_.png' % step)            plt.show()    torch.save(G, 'G.pkl')    torch.save(D, 'D.pkl')    torch.save(E, 'E.pkl')

model.py

import osimport torchimport torch.nn as nnimport torch.utils.data as Dataimport torchvisionfrom torch.utils.data import DataLoaderclass Generator(nn.Module):    def __init__(self, input_size):        super(Generator, self).__init__()        strides = [1, 2, 2, 2]        padding = [0, 1, 1, 1]        channels = [input_size,                    256, 128, 64, 32]  # 1表示一維        kernels = [4, 3, 4, 4]        model = []        for i, stride in enumerate(strides):            model.append(                nn.ConvTranspose2d(                    in_channels=channels[i],                    out_channels=channels[i + 1],                    stride=stride,                    kernel_size=kernels[i],                    padding=padding[i]                )            )            model.append(                nn.BatchNorm2d(channels[i + 1])            )            model.append(                nn.LeakyReLU(.1)            )        self.Conv_T = nn.Sequential(*model)        self.Conv = nn.Sequential(            nn.Conv2d(kernel_size=1, stride=1, in_channels=channels[-1], out_channels=channels[-1]),            nn.BatchNorm2d(channels[-1]),            nn.LeakyReLU(.1),            nn.Conv2d(kernel_size=1, stride=1, in_channels=channels[-1], out_channels=1),            nn.Sigmoid()        )    def forward(self, x):        x = self.Conv_T(x)        x = self.Conv(x)        return xclass Encoder(nn.Module):    def __init__(self, input_size=1, out_size=128):        super(Encoder, self).__init__()        strides = [2, 2, 2, 1, 1, 1]        padding = [1, 1, 1, 0, 0, 0]        channels = [input_size, 32, 64, 128, 256, out_size, out_size]  # 1表示一維        kernels = [4, 4, 4, 3, 1, 1]        model = []        for i, stride in enumerate(strides):            model.append(                nn.Conv2d(                    in_channels=channels[i],                    out_channels=channels[i + 1],                    stride=stride,                    kernel_size=kernels[i],                    padding=padding[i]                )            )            if i != len(strides) - 1:                model.append(                    nn.BatchNorm2d(channels[i + 1])                )                model.append(                    nn.ReLU()                )        self.main = nn.Sequential(*model)    def forward(self, x):        x = self.main(x)        return xclass Discriminator(nn.Module):    def __init__(self, x_in=1, z_in=128):        super(Discriminator, self).__init__()        self.D_x = nn.Sequential(            nn.Conv2d(in_channels=x_in, out_channels=32, kernel_size=4, stride=2),            nn.Dropout2d(.2),            nn.LeakyReLU(.1),            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),            nn.BatchNorm2d(64),            nn.Dropout2d(.2),            nn.LeakyReLU(.1),            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2),            nn.BatchNorm2d(128),            nn.Dropout2d(.2),            nn.LeakyReLU(.1),        )        self.D_z = nn.Sequential(            nn.Conv2d(in_channels=z_in, out_channels=256, kernel_size=1, stride=1),            nn.Dropout2d(.2),            nn.LeakyReLU(.1),            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=1, stride=1),            nn.Dropout2d(.2),            nn.LeakyReLU(.1),        )        self.D_x_z = nn.Sequential(            nn.Conv2d(in_channels=256 + 128, out_channels=512, kernel_size=1, stride=1),            nn.Dropout2d(.2),            nn.LeakyReLU(.1),            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=1, stride=1),            nn.Dropout2d(.2),            nn.LeakyReLU(.1),            nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1),            nn.Dropout2d(.2),            nn.Sigmoid(),        )    def forward(self, x, z):        x = self.D_x(x)        z = self.D_z(z)        cat_x_z = torch.cat([x, z], dim=1)        return self.D_x_z(cat_x_z)if __name__ == '__main__':    N_IDEAS = 128    G = Generator(N_IDEAS, )    rand_noise = torch.randn((10, N_IDEAS, 1, 1))    print(G(rand_noise).shape)    E = Encoder(input_size=1, out_size=N_IDEAS)    print(E(G(rand_noise)).shape)    D = Discriminator()    print(D(G(rand_noise), rand_noise).shape)

judge.py

import numpy as npimport torchimport matplotlib.pyplot as pltfrom model import Generator, Discriminatorimport torchvision.utils as vutilsimport osimport torchvisionfrom torch.utils.data import Dataset, DataLoaderif __name__ == '__main__':    BATCH_SIZE = 100    N_IDEAS = 12    TIME = 10????G?=?torch.load("G.pkl").cuda()    mnist_root = '../Conditional-GAN/mnist/'    DOWNLOAD_MNIST = False    if not (os.path.exists(mnist_root)) or not os.listdir(mnist_root):        # not mnist dir or mnist is empyt dir        DOWNLOAD_MNIST = True    train_data = torchvision.datasets.MNIST(        root=mnist_root,        train=True,  # this is training data        transform=torchvision.transforms.ToTensor(),  # Converts a PIL.Image or numpy.ndarray to        # torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]        download=DOWNLOAD_MNIST,    )    train_loader = DataLoader(dataset=train_data, batch_size=10, shuffle=True)    E = torch.load('E.pkl')    for t in range(TIME):        tmp = []        for step, (x, y) in enumerate(train_loader):            # x            x = x.cuda()            G_imgs = G(E(x)).cpu().detach()            tmp.append(torch.cat([G_imgs, x.cpu().detach()]))            if step == 5:                break        fig = plt.figure(figsize=(10, 10))        plt.axis("off")        plt.imshow(            np.transpose(vutils.make_grid(torch.cat(tmp), nrow=10, padding=0, normalize=True,                                          scale_each=True), (1, 2, 0)))        plt.savefig('E_%d.png' % t)        plt.show()

882427abd25326e081a98344d54a7cbd.png

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/271363.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/271363.shtml
英文地址,請注明出處:http://en.pswp.cn/news/271363.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

IO操作總結

1,讀取文件將文件轉換為二進制流 1 InputStream in new FileInputStream("C:/test.png"); 2 byte[] photo new byte[in.available()]; 3 in.read(photo); 4 in.close(); View Code2,寫文件 1 Outpu…

計算機網絡基礎:網絡標準相關知識介紹

1、常見的制定網絡標準的機構 國際標準化組織、國際電信聯盟、電子工業協會、電氣和電子工程協會、因特網活動委員會 2、常見的網絡標準 2.1 電信標準 國際電信聯盟(ITU)1947年成為聯合國的一個組織,包括ITU-R、ITU-T、ITU-D組成。 ITU-R:無線…

Long類型轉為String類型

如果java返回給前端的字段有Long類型的,比如主鍵id,那么就要把這個Long類型轉為String類型才可以,不然前端拿到這個字段再傳回給你后端用的時候會導致精度缺失,也就是這個字段的值會改變,原因是java的Long類型是18位&a…

1-5Tomcat 目錄結構 和 web項目目錄結構

對應我的安裝路徑: web項目目錄結構 轉載于:https://www.cnblogs.com/huiziz/p/5671612.html

execjs執行js出現window對象未定義時的解決_10個常見的JS語言錯誤總匯

1、 Uncaught TypeError: Cannot Read Property這是 JavaScript 開發人員最常遇到的錯誤。當你讀取一個屬性或調用一個未定義對象的方法時,Chrome 中就會報出這樣的錯誤。導致這個錯誤發生的原因有很多,常見的一種情況是在渲染 UI 組件時,不正…

安卓logcat工具apk_backdoorapk 安卓APK后門捆綁腳本

項目地址https://github.com/dana-at-cp/backdoor-apk項目介紹backdoor-apk是一個bash寫的腳本,通過msfvenom生成一個android的payload,然后再使用apktools將payload捆綁到正常的apk文件中。使用方法rootkali:~/Android/evol-lab/BaiduBrowserRat# ./bac…

java8 supplier 接口

Supplier 接口 Supplier 接口是一個供給型的接口,其實,說白了就是一個容器,可以用來存儲數據,然后可以供其他方法使用的這么一個接口 *** Supplier接口測試,supplier相當一個容器或者變量,可以存儲值*/Tes…

mantis apache mysql_軟件測試(軟件安裝:php+mysql+apache+mantis過程遇到的問題以及解決方法)...

實驗環境: Windows 7 64位操作系統瀏覽器版本: Mozilla Firefox 41.0.0.5378一.PHP的安裝① 版本: php-5.4.45-Win32-VC9-x86安裝步驟:安裝將PHP安裝到 D:\PHP下(目錄可以自行更改)配置找到PHP目錄里的類似 php.ini-dist ,…

c#程序中使用like“查詢access數據庫查詢為空的問題

今天,在開發的過程中發現了一個特別奇怪的問題:access中like查詢時候,在Access數據庫中執行,發現可以查詢出結果,這是在數據庫上執行,select * from KPProj where KpName like *測試*,但是同樣的…

html登錄界面_使用數據庫制作一套注冊登錄系統

經過了那么多個星期的學習&#xff0c;終于到了使用數據庫的階段了&#xff0c;最基本的也就是制作注冊登錄與數據庫連接。首先要制作一個注冊窗口先是html界面<效果如圖&#xff1a;&#xff08;樣子怎么樣不重要&#xff0c;重要的是測試&#xff09;這主要是將form數值發…

java8中Predicate用法

Predicate是個斷言式接口其參數是<T,boolean>&#xff0c;也就是給一個參數T&#xff0c;返回boolean類型的結果。跟Function一樣&#xff0c;Predicate的具體實現也是根據傳入的lambda表達式來決定的。 Testpublic void predicate(){/*** Predicate謂詞測試&#xff0c…

計算機網絡基礎:局域網協議相關知識

1、局域網協議的概念 局域網絡中的通信被限制在中等規模的地理范圍內&#xff0c;比如一所學校&#xff1b;能夠使用具體中等或較高數據速率的物理信道&#xff0c;并且具有較低的誤碼率&#xff1b;局域網絡是專用的&#xff0c; 由單一組織機構所使用。 局域網特點&#xff1…

mysql數據庫交叉連接_【數據庫】內連接、外連接、交叉連接

基本概念關系模型(表)關系模型由關系數據結構、關系操作集合和關系完整性約束三部分組成。關系模型的數據結構非常簡單&#xff1a;一張扁平的二維表。元組&#xff1a;二維表中的具有相同數據類型的某一行屬性&#xff1a;二維表中的具有相同數據類型的某一列笛卡爾積(Cartesi…

C#實現GDI+基本圖的縮放、拖拽、移動

C#實現GDI基本圖的縮放、拖拽、移動示例代碼如下&#xff1a; using System;using System.Collections.Generic;using System.ComponentModel;using System.Data;using System.Drawing;using System.Text;using System.Windows.Forms; namespace ResizableControls{ public …

網絡資產管理系統_固定資產管理系統的三種網絡架構方式

隨著互聯網技術的發展和信息技術的廣泛使用&#xff0c;固定資產管理系統在各行業的應用越來越普及&#xff0c;固定資產管理系統作為當今主流的企業固定資產信息化管理模式&#xff0c;能夠對企業固定資產進行有效管理并提升企業的管理水平。對于固定資產管理系統的網絡結構方…

計算機網絡基礎:廣域網協議相關知識筆記

廣域網常指覆蓋范圍廣、數據傳輸速率較低&#xff0c;以數據通信為目的的數據通信網。廣域網主要是通過專用的或交換式的連接把計算機連接起來。廣域網傳輸協議主要包括&#xff1a;PPP&#xff08;點對點協議&#xff09;、DDN、ISDN&#xff08;綜合業務數字網&#xff09;、…

mysql check table_修復MySQL的MyISAM表命令check table用法

MyISAM如果損壞了修復方法是比較簡單了我們只要使用check table命令就可以了&#xff0c;下面我們來看一篇關于修復MySQL的MyISAM表命令check table用法&#xff0c;具體如下所示。MySQL日志文件里出現以下錯誤&#xff0c;MySQL表通常不會發生crash情況&#xff0c;一般是在更…

python字典append_python中字典重復賦值,append到list中引發的異常

今天遇到了一個關于python 字典的誤用。先上代碼&#xff1a; data [{id: 1, name: 管理員, role: admin, desc: 系統管理員, acl: None}, {id: 2, name: 研發, role: dev, desc: 研發人員, acl: None}, {id: 3, name: 測試, role: qa, desc: 測試人員, acl: None}, {id: 4, n…

計算機網絡基礎:TCP/IP協議相關知識筆記?

1、TCP/IP特性邏輯編址&#xff1a;每一塊網卡會在出廠時由廠家分配了唯一的永久性物理地址。針對Internet&#xff0c;會為每臺連入因特網的計算機分配一個邏輯地址也就是IP地址。路由選擇&#xff1a;專門用于定義路由器如何選擇網絡路徑的協議&#xff0c;即IP數據包的路由選…