1.之前只能做一些圖像預測。我有個大膽的想法:神經網絡正向運行是預測圖片的類別,那麼如果訓練數據只有一個類別,就可以反過來用它生成該類別的圖片,專業術語叫做 GAN(生成對抗網絡)
2.訓練代碼
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as dset
import matplotlib.pyplot as plt
import os

# Set an environment variable.
# NOTE(review): KMP_DUPLICATE_LIB_OK presumably works around the
# "duplicate OpenMP runtime" abort that can occur when several libraries
# bundle their own OpenMP runtime — confirm it is still needed.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Define the generator model
class Generator(nn.Module):
    """Fully connected generator that maps a noise vector to a flat image.

    The network is four linear layers (input_dim -> 256 -> 512 -> 1024 ->
    output_dim) with ReLU between them and tanh on the output.

    Args:
        input_dim: Length of the input noise vector.
        output_dim: Number of values produced per sample.
    """

    def __init__(self, input_dim=100, output_dim=784):
        super().__init__()
        # The attribute names fc1..fc4 become the state_dict keys, so they
        # must stay stable for saved checkpoints to remain loadable.
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 1024)
        self.fc4 = nn.Linear(1024, output_dim)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Return generated samples; tanh bounds every value to [-1, 1]."""
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        x = self.tanh(self.fc4(x))
        return x


# Define the discriminator model
class Discriminator(nn.Module):
    """Fully connected discriminator that scores samples as real or fake.

    The network is four linear layers (input_dim -> 1024 -> 512 -> 256 ->
    output_dim) with ReLU between them and a sigmoid on the output.

    Args:
        input_dim: Number of values per flattened input sample.
        output_dim: Number of output scores per sample.
    """

    def __init__(self, input_dim=784, output_dim=1):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, output_dim)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a score in (0, 1) for each sample.

        Inputs whose last dimension already equals the first layer's input
        size are used as-is. Any other input with more than two dimensions
        (e.g. an image batch) is flattened to (batch, -1) first.
        """
        if x.dim() > 2 and x.size(-1) != self.fc1.in_features:
            x = x.flatten(start_dim=1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        x = self.sigmoid(self.fc4(x))
        return x


# Load the MNIST handwritten-digit image dataset
# Preprocessing: convert each image to a tensor, then normalize with
# mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Replace with the path to the MNIST dataset
dataroot = "path_to_your_mnist_dataset"
dataset = dset.MNIST(
    root=dataroot,
    train=True,
    download=True,
    transform=transform,
)
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    shuffle=True,
    batch_size=128,
)

# Create the generator and discriminator instances
input_dim = 100
output_dim = 784
generator = Generator(input_dim=input_dim, output_dim=output_dim)
# The discriminator's input size is the generator's output size.
discriminator = Discriminator(input_dim=output_dim)

# Define the optimizers and the loss function
lr = 0.0002
beta1 = 0.5
# Both networks use Adam with the same learning rate and betas.
optimizer_g = optim.Adam(
    params=generator.parameters(),
    betas=(beta1, 0.999),
    lr=lr,
)
optimizer_d = optim.Adam(
    params=discriminator.parameters(),
    betas=(beta1, 0.999),
    lr=lr,
)
criterion = nn.BCELoss()

# Train the GAN model
num_epochs = 50
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("Device:", device)
# Module.to() moves the parameters in place and returns the module itself.
generator = generator.to(device)
discriminator = discriminator.to(device)
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):
        real_images, _ = data
        real_images = real_images.to(device)
        batch_size = real_images.size(0)  # number of samples in this batch

        # ---- Train the discriminator ----
        optimizer_d.zero_grad()
        real_labels = torch.full((batch_size, 1), 1.0, device=device)
        fake_labels = torch.full((batch_size, 1), 0.0, device=device)
        noise = torch.randn(batch_size, input_dim, device=device)
        fake_images = generator(noise)
        real_outputs = discriminator(real_images.view(batch_size, -1))
        # detach() keeps the discriminator step from back-propagating into
        # the generator.
        fake_outputs = discriminator(fake_images.detach())
        d_loss_real = criterion(real_outputs, real_labels)
        d_loss_fake = criterion(fake_outputs, fake_labels)
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_d.step()

        # ---- Train the generator ----
        optimizer_g.zero_grad()
        noise = torch.randn(batch_size, input_dim, device=device)
        fake_images = generator(noise)
        fake_outputs = discriminator(fake_images)
        # The generator is rewarded when its fakes are scored as real.
        g_loss = criterion(fake_outputs, real_labels)
        g_loss.backward()
        optimizer_g.step()

        # Print training progress
        if i % 100 == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %.4f] [G loss: %.4f]"
                % (epoch, num_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
            )

    # Save the generator weights and a grid of sample images every 10 epochs,
    # and also after the last epoch so the fully trained model is kept.
    if epoch % 10 == 0 or epoch == num_epochs - 1:
        with torch.no_grad():
            noise = torch.randn(64, input_dim, device=device)
            fake_images = generator(noise).view(64, 1, 28, 28).cpu().numpy()
        fig, axes = plt.subplots(
            nrows=8, ncols=8, figsize=(12, 12), sharex=True, sharey=True
        )
        # Pair each axis with one sample; avoids reusing the batch index `i`.
        for ax, sample in zip(axes.flatten(), fake_images):
            ax.imshow(sample[0], cmap='gray')
            ax.axis('off')
        plt.subplots_adjust(wspace=0.05, hspace=0.05)
        plt.savefig("epoch_%d.png" % epoch)
        plt.close()
        torch.save(generator.state_dict(), "generator_epoch_%d.pth" % epoch)
3.測試模型的代碼
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import save_image# 定義生成器模型
class Generator(nn.Module):
    """Fully connected generator used to run a trained checkpoint.

    The forward pass must match the network the weights were trained with:
    the training script applies plain ReLU between the linear layers and
    tanh on the output, so the same activations are used here. A different
    activation (e.g. leaky ReLU) would still load the state dict without
    error but would produce different images.

    Args:
        input_dim: Length of the input noise vector.
        output_dim: Number of values produced per sample.
    """

    def __init__(self, input_dim=100, output_dim=784):
        super().__init__()
        # fc1..fc4 are the state_dict keys expected by the saved checkpoint.
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 1024)
        self.fc4 = nn.Linear(1024, output_dim)

    def forward(self, x):
        """Return generated samples; tanh bounds every value to [-1, 1]."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = torch.tanh(self.fc4(x))
        return x


# Create the generator model
generator = Generator(input_dim=100, output_dim=784)

# Load the pretrained weights
generator_weights = torch.load("generator_epoch_40.pth", map_location=torch.device('cpu'))

# Copy the weights into the generator model and switch to inference mode
generator.load_state_dict(generator_weights)
generator.eval()

# Sample random noise and generate one image; no gradients are needed here
with torch.no_grad():
    noise = torch.randn(1, 100)
    fake_image = generator(noise).view(1, 1, 28, 28)

# Save the generated image.
# The generator ends in tanh, so pixel values lie in [-1, 1]. With
# normalize=False, save_image treats the tensor as [0, 1] and clips
# negative values to black, so map the range to [0, 1] first.
save_image((fake_image + 1) / 2, "generated_image.png", normalize=False)
#測試結果:由於我的訓練集是手寫數字,所以會生成各種各樣的數字,下面這張明顯是1
#應該也是1
#再次運行,我也看不出來,不過只要我訓練只有一個種類的問題就可以生成這個種類的圖像
#搞定黑白圖,那彩色圖應該距離不遠了。我需要改進的是把對抗網絡的代碼改為只訓練一個種類的圖形。不過我感覺這種生成的圖形具有隨機性:雖然通過訓練我們得到了所有圖像的規律,但是如果需要正常點的圖片還是挺難的,就像上面這張,人都不一定知道它是什麼東西(在沒有顏色的情況下)。總結就是精度不夠,而且隨機性太強了。現在普遍的圖片AI生成工具都有這個缺點(生成的物體可能會扭曲,挺陰間的),而且生成圖片的速度慢,如果說誰比較受益,那一定是老黃(英偉達)哈哈哈
#比如下面這個圖片生成視頻的網站:
#https://app.runwayml.com/login
#每一幀看起來都沒有問題,就是連起來變成視頻不自然,如果有改進方法的話那可能需要引入重力/加速度/光處理 等等物理公式,來讓圖片更自然…