import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from sklearn.datasets import load_iris
import warnings
# 忽略不必要的警告信息
warnings.filterwarnings("ignore")
# --------------------------
# 1. 配置訓練參數與設備
# --------------------------
# 潛在空間維度(生成器的輸入維度)
latent_dim = 10 ?
# 訓練總輪數(GAN通常需要較多迭代才能收斂)
train_epochs = 10000 ?
# 批次大小(根據數據集規模調整)
batch_size = 32 ?
# 學習率(控制參數更新幅度)
learning_rate = 0.0002 ?
# Adam優化器的動量參數(影響收斂穩定性)
beta1 = 0.5 ?
# 自動選擇運算設備(優先GPU,沒有則用CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"當前使用設備: {device}")
# --------------------------
# 2. 數據加載與預處理
# --------------------------
# 加載鳶尾花數據集
iris_dataset = load_iris()
# 提取特征數據和標簽
features = iris_dataset.data
labels = iris_dataset.target
# 只選取Setosa類別(標簽為0)的數據進行訓練
setosa_features = features[labels == 0]
# 將數據縮放到[-1, 1]區間(配合生成器的Tanh輸出激活)
scaler = MinMaxScaler(feature_range=(-1, 1))
scaled_features = scaler.fit_transform(setosa_features)
# 轉換為PyTorch張量并創建數據加載器
# 注意:必須轉為float類型才能與模型參數兼容
data_tensor = torch.from_numpy(scaled_features).float()
dataset = TensorDataset(data_tensor)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# 打印數據基本信息
print(f"訓練樣本數量: {len(scaled_features)}")
print(f"特征維度: {scaled_features.shape[1]}") ?# 鳶尾花數據集固定為4維特征
# --------------------------
# 3. 定義生成器和判別器
# --------------------------
class Generator(nn.Module):
? ? """生成器:將隨機噪聲轉換為模擬的鳶尾花特征數據"""
? ? def __init__(self):
? ? ? ? super(Generator, self).__init__()
? ? ? ? # 簡單的全連接網絡結構
? ? ? ? self.net = nn.Sequential(
? ? ? ? ? ? nn.Linear(latent_dim, 16), ?# 從潛在空間映射到16維
? ? ? ? ? ? nn.ReLU(), ?# 激活函數增加非線性
? ? ? ? ? ? nn.Linear(16, 32), ?# 進一步映射到32維
? ? ? ? ? ? nn.ReLU(),
? ? ? ? ? ? nn.Linear(32, 4), ?# 輸出4維特征(與真實數據一致)
? ? ? ? ? ? nn.Tanh() ?# 確保輸出在[-1, 1]范圍內
? ? ? ? )
? ??
? ? def forward(self, x):
? ? ? ? # 前向傳播:輸入噪聲,輸出生成的數據
? ? ? ? return self.net(x)
class Discriminator(nn.Module):
? ? """判別器:區分輸入數據是真實樣本還是生成器偽造的"""
? ? def __init__(self):
? ? ? ? super(Discriminator, self).__init__()
? ? ? ? # 簡單的全連接網絡結構
? ? ? ? self.net = nn.Sequential(
? ? ? ? ? ? nn.Linear(4, 32), ?# 輸入4維特征
? ? ? ? ? ? nn.LeakyReLU(0.2), ?# LeakyReLU避免梯度消失問題
? ? ? ? ? ? nn.Linear(32, 16), ?# 壓縮到16維
? ? ? ? ? ? nn.LeakyReLU(0.2),
? ? ? ? ? ? nn.Linear(16, 1), ?# 輸出單個概率值
? ? ? ? ? ? nn.Sigmoid() ?# 將輸出壓縮到[0,1](表示真實數據的概率)
? ? ? ? )
? ??
? ? def forward(self, x):
? ? ? ? # 前向傳播:輸入數據,輸出判斷概率
? ? ? ? return self.net(x)
# 初始化模型并移動到運算設備
generator = Generator().to(device)
discriminator = Discriminator().to(device)
# 打印模型結構
print("\n生成器結構:")
print(generator)
print("\n判別器結構:")
print(discriminator)
# --------------------------
# 4. 配置訓練組件
# --------------------------
# 定義損失函數(二元交叉熵,適合二分類問題)
criterion = nn.BCELoss()
# 定義優化器(分別優化生成器和判別器)
gen_optimizer = optim.Adam(generator.parameters(), lr=learning_rate, betas=(beta1, 0.999))
dis_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(beta1, 0.999))
# --------------------------
# 5. 開始訓練
# --------------------------
print("\n--- 訓練開始 ---")
for epoch in range(train_epochs):
? ? # 遍歷數據加載器中的每一批次
? ? for batch_idx, (real_data,) in enumerate(data_loader):
? ? ? ? # 將真實數據移動到運算設備
? ? ? ? real_data = real_data.to(device)
? ? ? ? current_batch_size = real_data.size(0) ?# 獲取當前批次的實際樣本數(最后一批可能不滿)
? ? ? ??
? ? ? ? # 創建標簽:真實數據標為1,生成數據標為0
? ? ? ? real_labels = torch.ones(current_batch_size, 1).to(device)
? ? ? ? fake_labels = torch.zeros(current_batch_size, 1).to(device)
? ? ? ??
? ? ? ? # --------------------
? ? ? ? # 訓練判別器
? ? ? ? # --------------------
? ? ? ? dis_optimizer.zero_grad() ?# 清空判別器的梯度緩存
? ? ? ??
? ? ? ? # 1. 用真實數據訓練
? ? ? ? real_output = discriminator(real_data)
? ? ? ? # 計算真實數據的損失(希望判別器能認出真實數據)
? ? ? ? loss_real = criterion(real_output, real_labels)
? ? ? ??
? ? ? ? # 2. 用生成的數據訓練
? ? ? ? # 生成隨機噪聲(作為生成器的輸入)
? ? ? ? noise = torch.randn(current_batch_size, latent_dim).to(device)
? ? ? ? # 生成假數據,并阻斷梯度流向生成器(避免影響生成器參數)
? ? ? ? fake_data = generator(noise).detach()
? ? ? ? fake_output = discriminator(fake_data)
? ? ? ? # 計算假數據的損失(希望判別器能認出假數據)
? ? ? ? loss_fake = criterion(fake_output, fake_labels)
? ? ? ??
? ? ? ? # 總損失反向傳播并更新判別器參數
? ? ? ? dis_loss = loss_real + loss_fake
? ? ? ? dis_loss.backward()
? ? ? ? dis_optimizer.step()
? ? ? ??
? ? ? ? # --------------------
? ? ? ? # 訓練生成器
? ? ? ? # --------------------
? ? ? ? gen_optimizer.zero_grad() ?# 清空生成器的梯度緩存
? ? ? ??
? ? ? ? # 重新生成假數據(這次需要計算生成器的梯度)
? ? ? ? noise = torch.randn(current_batch_size, latent_dim).to(device)
? ? ? ? fake_data = generator(noise)
? ? ? ? fake_output = discriminator(fake_data)
? ? ? ??
? ? ? ? # 生成器的損失:希望判別器把假數據當成真的(所以標簽用real_labels)
? ? ? ? gen_loss = criterion(fake_output, real_labels)
? ? ? ? gen_loss.backward()
? ? ? ? gen_optimizer.step()
? ??
? ? # 每1000輪打印一次訓練狀態
? ? if (epoch + 1) % 1000 == 0:
? ? ? ? print(
? ? ? ? ? ? f"輪次 [{epoch+1}/{train_epochs}], "
? ? ? ? ? ? f"判別器損失: {dis_loss.item():.4f}, "
? ? ? ? ? ? f"生成器損失: {gen_loss.item():.4f}"
? ? ? ? )
print("\n--- 訓練完成 ---")