知識點:
1.對抗生成網絡的思想:關注損失從何而來
2.生成器、判別器
3.nn.sequential容器:適合于按順序運算的情況,簡化前向傳播寫法
4.leakyReLU介紹:避免relu的神經元失活現象
ps:如果你學有余力,對于gan的損失函數的理解,建議去找找視頻看看,如果只是用,沒必要學
作業:對于心臟病數據集,對于病人這個不平衡的樣本用GAN來學習并生成病人樣本,觀察不用GAN和用GAN的F1分數差異。
對抗生成網絡(GAN,Generative Adversarial Network)是一種深度學習模型架構,由生成器(Generator)和判別器(Discriminator)兩部分組成,通過兩個模型相互對抗、博弈,以達到生成高質量數據樣本的目的。
工作原理
生成器 :負責從隨機噪聲中生成逼真的數據樣本,如圖像、文本等。它類似于一個偽劣藝術家,試圖通過學習訓練數據的分布,生成能夠以假亂真的作品。判別器 :負責判斷給定的數據樣本是來自真實訓練數據還是生成器生成的假數據。它就像一個專業的藝術評論家,通過不斷地審視作品,給出真偽判斷。對抗過程 :訓練過程中,生成器和判別器相互博弈。生成器不斷嘗試生成更逼真的樣本以欺騙判別器,而判別器則不斷學習如何更準確地識別真假樣本。在這一過程中,生成器逐漸學習到訓練數據的分布規律,生成的樣本質量越來越高,判別器的判別能力也越來越強,最終達到納什均衡,此時生成器生成的樣本幾乎可以以假亂真。
網絡結構
生成器結構 :通常以隨機噪聲作為輸入,經過一系列的線性變換、激活函數等操作,逐步將噪聲轉化為具有一定結構和特征的數據樣本,常見的結構有全連接層、反卷積層、批量歸一化層等。例如,DCGAN(Deep Convolutional GAN)中的生成器采用反卷積層逐步上采樣,將低維噪聲映射到高維圖像空間。
判別器結構 :一般是一個卷積神經網絡(CNN),用于接收數據樣本并輸出其為真實數據的概率值。它通過卷積層、池化層等提取樣本的特征,并經過全連接層和激活函數(如 sigmoid)得到概率輸出。判別器的設計需要考慮如何有效地捕捉數據樣本的真實特征,以便準確地區分真實數據和生成數據。
訓練過程
初始化 :隨機初始化生成器和判別器的網絡參數。訓練判別器 :固定生成器的參數,使用真實數據和生成器生成的假數據訓練判別器,通過優化損失函數(如交叉熵損失)來調整判別器的參數,使其能夠更好地判斷數據的真偽。
訓練生成器 :固定判別器的參數,使用生成器生成的假數據訓練生成器,通過優化損失函數(通常也是基于判別器對假數據的判斷結果)來調整生成器的參數,使生成器生成的樣本更有可能被誤判為真實數據。
迭代交替訓練 :重復上述訓練判別器和生成器的過程,直到達到一定的訓練輪數或生成器生成的樣本質量達到預期。
應用領域
圖像生成與編輯 :可以用于生成高質量的圖像,如人物肖像、風景圖等;還可以進行圖像的風格轉換、超分辨率重建、圖像修復等圖像編輯任務。
文本生成 :在自然語言處理領域,GAN 可以用于文本生成,如生成新聞報道、故事、詩歌等,也可以用于文本到文本的轉換任務,如機器翻譯、文本摘要等。
語音生成與合成 :能夠生成逼真的語音信號,實現語音合成、語音轉換等功能,在語音助手、語音識別等應用中具有潛在價值。
數據增強 :通過生成與真實數據分布相似的樣本,為其他機器學習任務提供更多的訓練數據,提高模型的性能和泛化能力,尤其在數據稀缺的情況下具有重要意義。
GAN 自提出以來,不斷涌現出各種改進和變體,如 WGAN(Wasserstein GAN)、CGAN(Conditional GAN)、StyleGAN 等,這些改進在不同方面提升了 GAN 的性能和應用效果。