GAN框架基于兩個模型的競爭,Generator生成器和Discriminator鑒別器。生成器生成假圖像,鑒別器則嘗試從假圖像中識別真實的圖像。作為這種競爭的結果,生成器將生成更好看的假圖像,而鑒別器將更好地識別它們。
目錄
創建數據集
定義生成器
定義鑒別器
初始化模型權重
定義損失函數
定義優化器
訓練模型
部署生成器
創建數據集
使用 PyTorch torchvision 包中提供的 STL-10 數據集,數據集中有 10 個類:飛機、鳥、車、貓、鹿、狗、馬、猴、船、卡車。圖像為96*96像素的RGB圖像。數據集包含 5,000 張訓練圖像和 8,000 張測試圖像。在訓練數據集和測試數據集中,每個類分別有 500 和 800 張圖像。
?STL-10數據集詳細參考http://t.csdnimg.cn/ojBn6中數據加載和處理部分?
from torchvision import datasets
import torchvision.transforms as transforms
import os# 定義數據集路徑
path2data="./data"
# 創建數據集路徑
os.makedirs(path2data, exist_ok= True)# 定義圖像尺寸
h, w = 64, 64
# 定義均值
mean = (0.5, 0.5, 0.5)
# 定義標準差
std = (0.5, 0.5, 0.5)
# 定義數據預處理
transform= transforms.Compose([transforms.Resize((h,w)), # 調整圖像尺寸transforms.CenterCrop((h,w)), # 中心裁剪transforms.ToTensor(), # 轉換為張量transforms.Normalize(mean, std)]) # 歸一化# 加載訓練集
train_ds=datasets.STL10(path2data, split='train', download=False,transform=transform)
?展示示例圖像張量形狀、最小值和最大值
import torch
for x, _ in train_ds:print(x.shape, torch.min(x), torch.max(x))break
?展示示例圖像
from torchvision.transforms.functional import to_pil_image
import matplotlib.pylab as plt
%matplotlib inline
plt.imshow(to_pil_image(0.5*x+0.5))
?
創建數據加載器?
import torch
batch_size = 32
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)
?示例
for x,y in train_dl:print(x.shape, y.shape)break
定義生成器
GAN框架是基于兩個模型的競爭,generator生成器和discriminator鑒別器。生成器生成假圖像,鑒別器嘗試從假圖像中識別真實的圖像。
作為這種競爭的結果,生成器將生成更好看的假圖像,而鑒別器將更好地識別它們。
定義生成器模型?
from torch import nn
import torch.nn.functional as Fclass Generator(nn.Module):def __init__(self, params):super(Generator, self).__init__()# 獲取參數nz = params["nz"]ngf = params["ngf"]noc = params["noc"]# 定義反卷積層1self.dconv1 = nn.ConvTranspose2d( nz, ngf * 8, kernel_size=4,stride=1, padding=0, bias=False)# 定義批歸一化層1self.bn1 = nn.BatchNorm2d(ngf * 8)# 定義反卷積層2self.dconv2 = nn.ConvTranspose2d(ngf * 8, ngf * 4, kernel_size=4, stride=2, padding=1, bias=False)# 定義批歸一化層2self.bn2 = nn.BatchNorm2d(ngf * 4)# 定義反卷積層3self.dconv3 = nn.ConvTranspose2d( ngf * 4, ngf * 2, kernel_size=4, stride=2, padding=1, bias=False)# 定義批歸一化層3self.bn3 = nn.BatchNorm2d(ngf * 2)# 定義反卷積層4self.dconv4 = nn.ConvTranspose2d( ngf * 2, ngf, kernel_size=4, stride=2, padding=1, bias=False)# 定義批歸一化層4self.bn4 = nn.BatchNorm2d(ngf)# 定義反卷積層5self.dconv5 = nn.ConvTranspose2d( ngf, noc, kernel_size=4, stride=2, padding=1, bias=False)# 前向傳播def forward(self, x):# 反卷積層1x = F.relu(self.bn1(self.dconv1(x)))# 反卷積層2x = F.relu(self.bn2(self.dconv2(x))) # 反卷積層3x = F.relu(self.bn3(self.dconv3(x))) # 反卷積層4x = F.relu(self.bn4(self.dconv4(x))) # 反卷積層5out = torch.tanh(self.dconv5(x))return out
設定生成器模型參數、移動模型到cuda設備并打印模型結構?
params_gen = {"nz": 100,"ngf": 64,"noc": 3,}
model_gen = Generator(params_gen)
device = torch.device("cuda:0")
model_gen.to(device)
print(model_gen)
定義鑒別器
定義鑒別器模型,?用于鑒別真實圖像
class Discriminator(nn.Module):def __init__(self, params):super(Discriminator, self).__init__()# 獲取參數nic= params["nic"]ndf = params["ndf"]# 定義卷積層1self.conv1 = nn.Conv2d(nic, ndf, kernel_size=4, stride=2, padding=1, bias=False)# 定義卷積層2self.conv2 = nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1, bias=False)# 定義批歸一化層2self.bn2 = nn.BatchNorm2d(ndf * 2) # 定義卷積層3self.conv3 = nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1, bias=False)# 定義批歸一化層3self.bn3 = nn.BatchNorm2d(ndf * 4)# 定義卷積層4self.conv4 = nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1, bias=False)# 定義批歸一化層4self.bn4 = nn.BatchNorm2d(ndf * 8)# 定義卷積層5self.conv5 = nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=0, bias=False)def forward(self, x):# 使用leaky_relu激活函數對卷積層1的輸出進行激活x = F.leaky_relu(self.conv1(x), 0.2, True)# 使用leaky_relu激活函數對卷積層2的輸出進行激活,并使用批歸一化層2進行批歸一化x = F.leaky_relu(self.bn2(self.conv2(x)), 0.2, inplace = True)# 使用leaky_relu激活函數對卷積層3的輸出進行激活,并使用批歸一化層3進行批歸一化x = F.leaky_relu(self.bn3(self.conv3(x)), 0.2, inplace = True)# 使用leaky_relu激活函數對卷積層4的輸出進行激活,并使用批歸一化層4進行批歸一化x = F.leaky_relu(self.bn4(self.conv4(x)), 0.2, inplace = True) # 使用sigmoid激活函數對卷積層5的輸出進行激活,并返回結果# Sigmoid激活函數是一種常用的非線性激活函數,它將輸入值壓縮到0和1之間,[ \sigma(x) = \frac{1}{1 + e^{-x}} ]out = torch.sigmoid(self.conv5(x))return out.view(-1)
設置模型參數,移動模型到cuda設備,打印模型結構?
params_dis = {"nic": 3,"ndf": 64}
model_dis = Discriminator(params_dis)
model_dis.to(device)
print(model_dis)
初始化模型權重
定義函數,初始化模型權重?
def initialize_weights(model):# 獲取模型類的名稱classname = model.__class__.__name__# 如果模型類名稱中包含'Conv',則初始化權重為均值為0,標準差為0.02的正態分布if classname.find('Conv') != -1:nn.init.normal_(model.weight.data, 0.0, 0.02)# 如果模型類名稱中包含'BatchNorm',則初始化權重為均值為1,標準差為0.02的正態分布,偏置為0elif classname.find('BatchNorm') != -1:nn.init.normal_(model.weight.data, 1.0, 0.02)nn.init.constant_(model.bias.data, 0)
初始化生成器模型和鑒別器模型的權重?
# 對生成器模型應用初始化權重函數
model_gen.apply(initialize_weights);
# 對判別器模型應用初始化權重函數
model_dis.apply(initialize_weights);
定義損失函數
定義二元交叉熵(BCE)損失函數?
loss_func = nn.BCELoss()
定義優化器
定義Adam優化器
from torch import optim
# 學習率
lr = 2e-4
# Adam優化器的beta1參數
beta1 = 0.5
# 定義鑒別器模型的優化器,學習率為lr,beta1參數為beta1,beta2參數為0.999
opt_dis = optim.Adam(model_dis.parameters(), lr=lr, betas=(beta1, 0.999))
# 定義生成器模型的優化器
opt_gen = optim.Adam(model_gen.parameters(), lr=lr, betas=(beta1, 0.999))
訓練模型
?示例訓練1000個epochs
# 定義真實標簽和虛假標簽
real_label = 1
fake_label = 0
# 獲取生成器的噪聲維度
nz = params_gen["nz"]
# 設置訓練輪數
num_epochs = 1000
# 定義損失歷史記錄
loss_history={"gen": [],"dis": []}
# 定義批次數
batch_count = 0
# 遍歷訓練輪數
for epoch in range(num_epochs):# 遍歷訓練數據for xb, yb in train_dl:# 獲取批大小ba_si = xb.size(0)# 將判別器梯度置零model_dis.zero_grad()# 將輸入數據移動到指定設備xb = xb.to(device)# 將標簽數據轉換為指定設備yb = torch.full((ba_si,), real_label, device=device)# 判別器輸出out_dis = model_dis(xb)# 將輸出和標簽轉換為浮點數out_dis = out_dis.float()yb = yb.float()# 計算真實樣本的損失loss_r = loss_func(out_dis, yb)# 反向傳播loss_r.backward()# 生成噪聲noise = torch.randn(ba_si, nz, 1, 1, device=device)# 生成器輸出out_gen = model_gen(noise)# 判別器輸出out_dis = model_dis(out_gen.detach())# 將標簽數據填充為虛假標簽yb.fill_(fake_label) # 計算虛假樣本的損失loss_f = loss_func(out_dis, yb)# 反向傳播loss_f.backward()# 計算判別器的總損失loss_dis = loss_r + loss_f # 更新判別器的參數opt_dis.step() # 將生成器梯度置零model_gen.zero_grad()# 將標簽數據填充為真實標簽yb.fill_(real_label) # 判別器輸出out_dis = model_dis(out_gen)# 計算生成器的損失loss_gen = loss_func(out_dis, yb)# 反向傳播loss_gen.backward()# 更新生成器的參數opt_gen.step()# 記錄生成器和判別器的損失loss_history["gen"].append(loss_gen.item())loss_history["dis"].append(loss_dis.item())# 更新批次數batch_count += 1# 每100個批打印一次損失if batch_count % 100 == 0:print(epoch, loss_gen.item(),loss_dis.item())
?繪制損失圖像
plt.figure(figsize=(10,5))
plt.title("Loss Progress")
plt.plot(loss_history["gen"],label="Gen. Loss")
plt.plot(loss_history["dis"],label="Dis. Loss")
plt.xlabel("batch count")
plt.ylabel("Loss")
plt.legend()
plt.show()
存儲模型權重?
import os
path2models = "./models/"
os.makedirs(path2models, exist_ok=True)
path2weights_gen = os.path.join(path2models, "weights_gen_128.pt")
path2weights_dis = os.path.join(path2models, "weights_dis_128.pt")
torch.save(model_gen.state_dict(), path2weights_gen)
torch.save(model_dis.state_dict(), path2weights_dis)
部署生成器
通常情況下,訓練完成后放棄鑒別器模型而保留生成器模型,部署經過訓練的生成器來生成新的圖像。為部署生成器模型,將訓練好的權重加載到模型中,然后給模型提供隨機噪聲。
# 加載生成器模型的權重
weights = torch.load(path2weights_gen)
# 將權重加載到生成器模型中
model_gen.load_state_dict(weights)
# 將生成器模型設置為評估模式
model_gen.eval()
?生成圖像
import numpy as np
with torch.no_grad():# 生成固定噪聲fixed_noise = torch.randn(16, nz, 1, 1, device=device)# 打印噪聲形狀print(fixed_noise.shape)# 生成假圖像img_fake = model_gen(fixed_noise).detach().cpu()
# 打印假圖像形狀
print(img_fake.shape)
# 創建畫布
plt.figure(figsize=(10,10))
# 遍歷假圖像
for ii in range(16):# 在畫布上繪制圖像plt.subplot(4,4,ii+1)# 將圖像轉換為PIL圖像plt.imshow(to_pil_image(0.5*img_fake[ii]+0.5))# 關閉坐標軸plt.axis("off")
其中一些可能看起來扭曲,而另一些看起來相對真實。為改進結果,可以在單個數據類上訓練模型,而不是在多個類上一起訓練。GAN在使用單個類進行訓練時表現更好。此外,可以嘗試更長時間地訓練模型。