GAN的概念
GAN(Generative Adversarial Network,生成對抗網絡)是一種深度學習模型,由生成器(Generator)和判別器(Discriminator)兩部分組成。生成器負責生成 synthetic data(如假圖像、文本等),判別器則試圖區分生成數據和真實數據。兩者通過對抗訓練不斷優化,最終使生成數據難以被判別器識別。
GAN的核心原理
生成器:接收隨機噪聲作為輸入,生成盡可能逼真的數據,目標是“欺騙”判別器。
判別器:接收真實數據和生成數據,輸出一個概率值判斷輸入的真偽,目標是準確區分兩者。
兩者的目標函數可以表示為以下 minimax 問題:
[ \min_G \max_D V(D, G) = \mathbb{E}{x \sim p{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] ]
其中:
- ( D(x) ) 是判別器對真實數據的判斷概率;
- ( G(z) ) 是生成器從噪聲 ( z ) 生成的數據;
- ( p_{data} ) 和 ( p_z ) 分別是真實數據分布和噪聲分布。
GAN的應用場景
- 圖像生成:如生成人臉(StyleGAN)、藝術作品(DeepDream)。
- 數據增強:為小樣本任務生成補充數據。
- 圖像修復:填充缺失區域(如修復老照片)。
- 風格遷移:將圖像轉換為特定風格(如卡通化)。
GAN的變體與改進
- DCGAN:使用卷積層提升圖像生成質量。
- WGAN:通過 Wasserstein 距離改進訓練穩定性。
- CycleGAN:支持無配對數據的跨域轉換(如馬→斑馬)。
挑戰與局限性
- 訓練不穩定:生成器和判別器可能無法同步收斂。
- 模式坍縮:生成器僅生成單一類型樣本。
- 評估困難:缺乏統一的量化指標衡量生成質量。
GAN 因其強大的生成能力成為 AI 領域的重要研究方向,廣泛應用于計算機視覺、自然語言處理等領域。
生成對抗網絡(GAN)
以下是一個基于C++和StyleGAN實現人臉生成的示例框架,包含關鍵代碼片段和解釋。這些示例假設已配置好StyleGAN模型(如stylegan2-ada-pytorch
)并導出為ONNX或LibTorch格式供C++調用。
環境準備
確保已安裝以下依賴:
- OpenCV(圖像處理)
- LibTorch(PyTorch C++ API)
- ONNX Runtime(可選)
#include <torch/script.h>
#include <opencv2/opencv.hpp>
示例1:加載預訓練模型
torch::jit::script::Module module;
try {module = torch::jit::load("stylegan2-ada.pt");
} catch (const std::exception& e) {std::cerr << "Error loading model: " << e.what() << std::endl;
}
示例2:生成隨機潛在向量(Z空間)
torch::Tensor z = torch::randn({1, 512}); // 512-dim latent vector
示例3:映射網絡(Z→W空間)
torch::Tensor w = module.forward({z}).toTensor(); // 通過StyleGAN的映射網絡
示例4:生成人臉圖像
torch::Tensor img_tensor = module.forward({w}).toTensor(); // 合成圖像
img_tensor = img_tensor.squeeze().detach().clamp(0, 1); // 歸一化到[0,1]
示例5:張量轉OpenCV格式
img_tensor = img_tensor.mul(255).permute({1, 2, 0}).to(torch::kU8);
cv::Mat img(img_tensor.size(0), img_tensor.size(1), CV_8UC3, img_tensor.data_ptr());
cv::cvtColor(img, img, cv::COLOR_RGB2BGR);
示例6:保存生成圖像
cv::imwrite("generated_face.png", img);
示例7:批量生成人臉
torch::Tensor z_batch = torch::randn({10, 512}); // 批量生成10張
torch::Tensor imgs = module.forward({z_batch}).toTensor();
示例8:插值生成(平滑過渡)
torch::Tensor z1 = torch::randn({1, 512});
torch::Tensor z2 = torch::randn({1, 512});
for (float alpha = 0; alpha <= 1; alpha += 0.1) {torch::Tensor z_interp = z1 * (1 - alpha) + z2 * alpha;torch::Tensor img = module.forward({z_interp}).toTensor();
}
示例9:使用StyleGAN的截斷技巧(Truncation Trick)
float psi = 0.7; // 截斷系數
torch::Tensor w_mean = ...; // 預計算W空間均值
torch::Tensor w_truncated = w_mean + psi * (w - w_mean);
示例10:條件生成(添加標簽)
torch::Tensor label = torch::zeros({1, 10}); // 假設10類
label[0][3] = 1; // 選擇第3類
torch::Tensor img = module.forward({z, label}).toTensor();
示例11:圖像分辨率設置
module.attr("resolution").setAttr(1024); // 設置為1024x1024輸出
示例12:GPU加速
module.to(torch::kCUDA);
torch::Tensor z = torch::randn({1, 512}, torch::kCUDA);
示例13:混合風格(Style Mixing)
torch::Tensor z1 = torch::randn({1, 512});
torch::Tensor z2 = torch::randn({1, 512});
torch::Tensor w1 = module.forward({z1}).toTensor();
torch::Tensor w2 = module.forward({z2}).toTensor();
// 混合前4層風格
w1.slice(1, 0, 4) = w2.slice(1, 0, 4);
torch::Tensor img = module.forward({w1}).toTensor();
示例14:生成動畫序列
std::vector<torch::Tensor> frames;
for (int i = 0; i < 60; ++i) {torch::Tensor z = torch::randn({1, 512});frames.push_back(module.forward({z}).toTensor());
}
// 保存為視頻
示例15:使用ONNX Runtime推理
Ort::Env env;
Ort::Session session(env, "stylegan2.onnx", Ort::SessionOptions{});
Ort::AllocatorWithDefaultOptions allocator;
std::vector<int64_t> input_shape = {1, 512};
std::vector<float> z_data(512);
Ort::Value input_tensor = Ort::Value::CreateTensor<float>(allocator, z_data.data(), z_data.size(), input_shape.data(), i