pytorch 對抗樣本_【煉丹技巧】功守道:NLP中的對抗訓練 + PyTorch實現

本文分享一個“萬物皆可盤”的NLP對抗訓練實現,只需要四行代碼即可調用。盤他。

最近,微軟的FreeLB-Roberta [1] 靠著對抗訓練 (Adversarial Training)在GLUE榜上超越了Facebook原生的Roberta,追一科技也用到了這個方法僅憑單模型 [2] 就在CoQA榜單中超過了人類,似乎“對抗訓練”一下子變成了NLP任務的一把利器。剛好筆者最近也在看這方面的內容,所以開一篇博客,講一下。GLUE LeaderboardCoQA Leaderboard

提到“對抗”,相信大多數人的第一反應都是CV中的對抗生成網絡 (GAN),殊不知,其實對抗也可以作為一種防御機制,并且經過簡單的修改,便能用在NLP任務上,提高模型的泛化能力。關鍵是,對抗訓練可以寫成一個插件的形式,用幾行代碼就可以在訓練中自由地調用,簡單有效,使用成本低。不過網上的大多數博客對于NLP中的對抗訓練都介紹得比較零散且無代碼實現,筆者在這篇博客中,對NLP任務中的對抗訓練做了一個簡單的綜述,并提供了插件形式的PyTorch實現。

本文專注于NLP對抗訓練的介紹,對對抗攻擊基礎感興趣的讀者,可以看這幾篇博客及論文 [3] [4] [5],這里就不贅述了。不想要理解理論細節的讀者也可以直接看最后的代碼實現。

1. 對抗樣本

我們常常會聽到“對抗樣本”、“對抗攻擊”、“對抗訓練”等等這些令人頭禿的概念,為了讓大家對“對抗”有個更清晰的認識,我們先把這些概念捋捋清楚。Taxonomy

Szegedy在14年的ICLR中 [6] 提出了對抗樣本這個概念。如上圖,對抗樣本可以用來攻擊和防御,而對抗訓練其實是“對抗”家族中防御的一種方式,其基本的原理呢,就是通過添加擾動構造一些對抗樣本,放給模型去訓練,以攻為守,提高模型在遇到對抗樣本時的魯棒性,同時一定程度也能提高模型的表現和泛化能力。

那么,什么樣的樣本才是好的對抗樣本呢?對抗樣本一般需要具有兩個特點:相對于原始輸入,所添加的擾動是微小的;

能使模型犯錯。

下面是一個對抗樣本的例子,決定就是你啦,胖達:一只胖達加了點擾動就被識別成了長臂猿

2. 對抗訓練的基本概念

GAN之父Ian Goodfellow在15年的ICLR中 [7] 第一次提出了對抗訓練這個概念,簡而言之,就是在原始輸入樣本

上加一個擾動

,得到對抗樣本后,用其進行訓練。也就是說,問題可以被抽象成這么一個模型:

其中,

為gold label,

為模型參數。那擾動要如何計算呢?Goodfellow認為,神經網絡由于其線性的特點,很容易受到線性擾動的攻擊。This linear behavior suggests that cheap, analytical perturbations of a linear model should also damage neural networks.

于是,他提出了 Fast Gradient Sign Method (FGSM) ,來計算輸入樣本的擾動。擾動可以被定義為:

其中,

為符號函數,

為損失函數。Goodfellow發現,令

,用這個擾動能給一個單層分類器造成99.9%的錯誤率。看似這個擾動的發現有點拍腦門,但是仔細想想,其實這個擾動計算的思想可以理解為:將輸入樣本向著損失上升的方向再進一步,得到的對抗樣本就能造成更大的損失,提高模型的錯誤率。回想我們上一節提到的對抗樣本的兩個要求,FGSM剛好可以完美地解決。

在 [7] 中,Goodfellow還總結了對抗訓練的兩個作用:提高模型應對惡意對抗樣本時的魯棒性;

作為一種regularization,減少overfitting,提高泛化能力。

3. Min-Max 公式

在 [7] 中,對抗訓練的理論部分被闡述得還是比較intuitive,Madry在2018年的ICLR中 [8]總結了之前的工作,并從優化的視角,將問題重新定義成了一個找鞍點的問題,也就是大名鼎鼎的Min-Max公式:

該公式分為兩個部分,一個是內部損失函數的最大化,一個是外部經驗風險的最小化。內部max是為了找到worst-case的擾動,也就是攻擊,其中,

為損失函數,

為擾動的范圍空間。

外部min是為了基于該攻擊方式,找到最魯棒的模型參數,也就是防御,其中

是輸入樣本的分布。

Madry認為,這個公式簡單清晰地定義了對抗樣本攻防“矛與盾”的兩個問題:如何構造足夠強的對抗樣本?以及,如何使模型變得刀槍不入?剩下的,就是如何求解的問題了。

4. 從 CV 到 NLP

以上提到的一些工作都還是停留在CV領域的,那么問題來了,可否將對抗訓練遷移到NLP上呢?答案是肯定的,但是,我們得考慮這么幾個問題:

首先,CV任務的輸入是連續的RGB的值,而NLP問題中,輸入是離散的單詞序列,一般以one-hot vector的形式呈現,如果直接在raw text上進行擾動,那么擾動的大小和方向可能都沒什么意義。Goodfellow在17年的ICLR中 [9] 提出了可以在連續的embedding上做擾動:Because the set of high-dimensional one-hot vectors does not admit in?nitesimal perturbation, we de?ne the perturbation on continuous word embeddings instead of discrete word inputs.

乍一思考,覺得這個解決方案似乎特別完美。然而,對比圖像領域中直接在原始輸入加擾動的做法,在embedding上加擾動會帶來這么一個問題:這個被構造出來的“對抗樣本”并不能map到某個單詞,因此,反過來在inference的時候,對手也沒有辦法通過修改原始輸入得到這樣的對抗樣本。我們在上面提到,對抗訓練有兩個作用,一是提高模型對惡意攻擊的魯棒性,二是提高模型的泛化能力。在CV任務,根據經驗性的結論,對抗訓練往往會使得模型在非對抗樣本上的表現變差,然而神奇的是,在NLP任務中,模型的泛化能力反而變強了,如[1]中所述:While adversarial training boosts the robustness, it is widely accepted by computer vision researchers that it is at odds with generalization, with classi?cation accuracy on non-corrupted images dropping as much as 10% on CIFAR-10, and 15% on Imagenet (Madry et al., 2018; Xie et al., 2019). Surprisingly, people observe the opposite result for language models (Miyato et al., 2017; Cheng et al., 2019), showing that adversarial training can improve both generalization and robustness.

因此,在NLP任務中,對抗訓練的角色不再是為了防御基于梯度的惡意攻擊,反而更多的是作為一種regularization,提高模型的泛化能力。

有了這些“思想準備”,我們來看看NLP對抗訓練的常用的幾個方法和具體實現吧。

5. NLP中的兩種對抗訓練 + PyTorch實現

a. Fast Gradient Method(FGM)

上面我們提到,Goodfellow在15年的ICLR [7] 中提出了Fast Gradient Sign Method(FGSM),隨后,在17年的ICLR [9]中,Goodfellow對FGSM中計算擾動的部分做了一點簡單的修改。假設輸入的文本序列的embedding vectors

,embedding的擾動為:

實際上就是取消了符號函數,用二范式做了一個scale,需要注意的是:這里的norm計算的是,每個樣本的輸入序列中出現過的詞組成的矩陣的梯度norm。原作者提供了一個TensorFlow的實現 [10],在他的實現中,公式里的

是embedding后的中間結果(batch_size, timesteps, hidden_dim),對其梯度

的后面兩維計算norm,得到的是一個(batch_size, 1, 1)的向量

。為了實現插件式的調用,筆者將一個batch抽象成一個樣本,一個batch統一用一個norm,由于本來norm也只是一個scale的作用,影響不大。筆者的實現如下:

import torch

class FGM():

def __init__(self, model):

self.model = model

self.backup = {}

def attack(self, epsilon=1., emb_name='emb.'):

# emb_name這個參數要換成你模型中embedding的參數名

for name, param in self.model.named_parameters():

if param.requires_grad and emb_name in name:

self.backup[name] = param.data.clone()

norm = torch.norm(param.grad)

if norm != 0 and not torch.isnan(norm):

r_at = epsilon * param.grad / norm

param.data.add_(r_at)

def restore(self, emb_name='emb.'):

# emb_name這個參數要換成你模型中embedding的參數名

for name, param in self.model.named_parameters():

if param.requires_grad and emb_name in name:

assert name in self.backup

param.data = self.backup[name]

self.backup = {}

需要使用對抗訓練的時候,只需要添加五行代碼:

# 初始化

fgm = FGM(model)

for batch_input, batch_label in data:

# 正常訓練

loss = model(batch_input, batch_label)

loss.backward() # 反向傳播,得到正常的grad

# 對抗訓練

fgm.attack() # 在embedding上添加對抗擾動

loss_adv = model(batch_input, batch_label)

loss_adv.backward() # 反向傳播,并在正常的grad基礎上,累加對抗訓練的梯度

fgm.restore() # 恢復embedding參數

# 梯度下降,更新參數

optimizer.step()

model.zero_grad()

PyTorch為了節約內存,在backward的時候并不保存中間變量的梯度。因此,如果需要完全照搬原作的實現,需要用register_hook接口[11]將embedding后的中間變量的梯度保存成全局變量,norm后面兩維,計算出擾動后,在對抗訓練forward時傳入擾動,累加到embedding后的中間變量上,得到新的loss,再進行梯度下降。不過這樣實現就與我們追求插件式簡單好用的初衷相悖,這里就不贅述了,感興趣的讀者可以自行實現。

b. Projected Gradient Descent(PGD)

內部max的過程,本質上是一個非凹的約束優化問題,FGM解決的思路其實就是梯度上升,那么FGM簡單粗暴的“一步到位”,是不是有可能并不能走到約束內的最優點呢?當然是有可能的。于是,一個很intuitive的改進誕生了:Madry在18年的ICLR中[8],提出了用Projected Gradient Descent(PGD)的方法,簡單的說,就是“小步走,多走幾步”,如果走出了擾動半徑為

的空間,就映射回“球面”上,以保證擾動不要過大:

其中

為擾動的約束空間,

為小步的步長。

import torch

class PGD():

def __init__(self, model):

self.model = model

self.emb_backup = {}

self.grad_backup = {}

def attack(self, epsilon=1., alpha=0.3, emb_name='emb.', is_first_attack=False):

# emb_name這個參數要換成你模型中embedding的參數名

for name, param in self.model.named_parameters():

if param.requires_grad and emb_name in name:

if is_first_attack:

self.emb_backup[name] = param.data.clone()

norm = torch.norm(param.grad)

if norm != 0 and not torch.isnan(norm):

r_at = alpha * param.grad / norm

param.data.add_(r_at)

param.data = self.project(name, param.data, epsilon)

def restore(self, emb_name='emb.'):

# emb_name這個參數要換成你模型中embedding的參數名

for name, param in self.model.named_parameters():

if param.requires_grad and emb_name in name:

assert name in self.emb_backup

param.data = self.emb_backup[name]

self.emb_backup = {}

def project(self, param_name, param_data, epsilon):

r = param_data - self.emb_backup[param_name]

if torch.norm(r) > epsilon:

r = epsilon * r / torch.norm(r)

return self.emb_backup[param_name] + r

def backup_grad(self):

for name, param in self.model.named_parameters():

if param.requires_grad:

self.grad_backup[name] = param.grad.clone()

def restore_grad(self):

for name, param in self.model.named_parameters():

if param.requires_grad:

param.grad = self.grad_backup[name]

使用的時候,要麻煩一點:

pgd = PGD(model)

K = 3

for batch_input, batch_label in data:

# 正常訓練

loss = model(batch_input, batch_label)

loss.backward() # 反向傳播,得到正常的grad

pgd.backup_grad()

# 對抗訓練

for t in range(K):

pgd.attack(is_first_attack=(t==0)) # 在embedding上添加對抗擾動, first attack時備份param.data

if t != K-1:

model.zero_grad()

else:

pgd.restore_grad()

loss_adv = model(batch_input, batch_label)

loss_adv.backward() # 反向傳播,并在正常的grad基礎上,累加對抗訓練的梯度

pgd.restore() # 恢復embedding參數

# 梯度下降,更新參數

optimizer.step()

model.zero_grad()

在[8]中,作者將這一類通過一階梯度得到的對抗樣本稱之為“一階對抗”,在實驗中,作者發現,經過PGD訓練過的模型,對于所有的一階對抗都能得到一個低且集中的損失值,如下圖所示:

我們可以看到,面對約束空間

內隨機采樣的十萬個擾動,PGD模型能夠得到一個非常低且集中的loss分布,因此,在論文中,作者稱PGD為“一階最強對抗”。也就是說,只要能搞定PGD對抗,別的一階對抗就不在話下了。

6. 實驗對照

為了說明對抗訓練的作用,筆者選了四個GLUE中的任務進行了對照試驗。實驗代碼是用的Huggingface的transfomers/examples/run_glue.py [12],超參都是默認的,對抗訓練用的也是相同的超參。

我們可以看到,對抗訓練還是有效的,在MRPC和RTE任務上甚至可以提高三四個百分點。不過,根據我們使用的經驗來看,是否有效有時也取決于數據集。畢竟:緣,妙不可言~

7. 總結

這篇博客梳理了NLP對抗訓練發展的來龍去脈,介紹了對抗訓練的數學定義,并對于兩種經典的對抗訓練方法,提供了插件式的實現,做了簡單的實驗對照。由于筆者接觸對抗訓練的時間也并不長,如果文中有理解偏差的地方,希望讀者不吝指出。

8. 一個彩蛋:Virtual Adversarial Training

除了監督訓練,對抗訓練還可以用在半監督任務中,尤其對于NLP任務來說,很多時候輸入的無監督文本多的很,但是很難大規模地進行標注,那么就可以參考[13]中提到的Virtual Adversarial Training進行半監督訓練。

首先,我們抽取一個隨機標準正態擾動(

),加到embedding上,并用KL散度計算梯度:

然后,用得到的梯度,計算對抗擾動,并進行對抗訓練:

實現方法跟FGM差不多,這里就不給出了。

更優雅的排版請見我的博客:瓦特蘭蒂斯

ReferenceFreeLB: Enhanced Adversarial Training for Language Understanding. https://arxiv.org/abs/1909.11764

Technical report on Conversational Question Answering. https://arxiv.org/abs/1909.10772

Towards a Robust Deep Neural Network in Text Domain A Survey. https://arxiv.org/abs/1902.07285

Adversarial Attacks on Deep Learning Models in Natural Language Processing: A Survey. https://arxiv.org/abs/1901.06796

Intriguing properties of neural networks. https://arxiv.org/abs/1312.6199

Explaining and Harnessing Adversarial Examples. https://arxiv.org/abs/1412.6572

Towards Deep Learning Models Resistant to Adversarial Attacks. https://arxiv.org/abs/1706.06083

Adversarial Training Methods for Semi-Supervised Text Classification. https://arxiv.org/abs/1605.07725

Distributional Smoothing with Virtual Adversarial Training. https://arxiv.org/abs/1507.00677

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/538288.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/538288.shtml
英文地址,請注明出處:http://en.pswp.cn/news/538288.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

linux 開源郵件 系統,4 個開源的命令行郵件客戶端

無論你承認與否,email并沒有消亡。對那些對命令行至死不渝的 Linux 高級用戶而言,離開 shell 轉而使用傳統的桌面或網頁版郵件客戶端并不適應。歸根結底,命令行最善于處理文件,特別是文本文件,能使效率倍增。幸運的是&…

kafka清理數據日志

背景問題: 使用kafka的路上踩過不少坑,其中一個就是在測試環境使用kafka一陣子以后,發現其日志目錄變的很大,占了磁盤很大空間,定位到指定目錄下發現其中一個系統自動創建的 topic,__consumer_offsets-45&a…

修改docker-倉庫資源地址Error response from daemon: Get https://index.docker.io/v1/search

[rootzengmg /]# docker search centosError response from daemon: Get https://index.docker.io/v1/search?qcentos: read tcp 52.200.132.201:443: i/o timeout docker在中國已經有了倉庫:https://www.docker-cn.com/registry-mirror 根據上面網站提供的修改方法…

oracle19c的版本號_Windows10安裝Oracle19c數據庫詳細記錄(圖文詳解)

1. 下載資源官網下載地址: 點此進入直接點擊下載,會自動開始下載。2. 開始安裝將下載的安裝包解壓到本地,右鍵-以管理員身份運行setup.exe,開始安裝(一定要以管理員身份運行,不然后面會報錯)。step1:選擇創…

qt調用Linux腳本范例,QT下實現對Linux Shell調用的幾種方法

使用QProcess QThread#include int main(){QProcess::execute("ls");return 0;}QProcess *poc new QProcess;poc-> start( "ping 222.207.53.1> hh ");打開hh文檔 讀取里面的內容給QTextEditQProcess *proc new QProcess;proc->addArgument(&qu…

Apache發布Groovy 2.5正式版及3.0預覽版

Apache基金會最近發布了Groovy2.5,新功能包括:\\AST轉換的改進\新的宏支持\其他雜項改進\運行Groovy 2.5至少需要JDK 7,在JDK 9上運行可以忽略良性警告。\\盡管最近人們把關注點轉到了其他JVM語言上(如Kotlin)&#xf…

virtualbox 命令

原文鏈接:http://418684644-qq-com.iteye.com/blog/1451000 ----------------------------------------------------------------------------------------- 查看當前虛擬機 VBxoManage list vms 查看當前正在運行的虛擬機 VBoxManage list runningvms 啟動虛擬機 …

js小學生圖區_推薦12個最好的 JavaScript 圖形繪制庫

眾多周知,圖形和圖表要比文本更具表現力和說服力。圖表是數據圖形化的表示,通過形象的圖表來展示數據,比如條形圖,折線圖,餅圖等等。可視化圖表可以幫助開發者更容易理解復雜的數據,提高生產的效率和 Web 應…

linux 關閉登錄權限,linux – /var/www/html的權限[已關閉]

我有一個虛擬CentOS服務器與GoDaddy,我無法設置/ var / www / html的權限。用戶不能以root用戶身份登錄,甚至不能將自己添加到根組中,因此,我將自己寫入的角落:>我使用以下命令更改了其所有者(我使用httpd.conf中的…

cifar10數據集測試有多少張圖_pytorch VGG11識別cifar10數據集(訓練+預測單張輸入圖片操作)...

首先這是VGG的結構圖,VGG11則是紅色框里的結構,共分五個block,如紅框中的VGG11第一個block就是一個conv3-64卷積層:一,寫VGG代碼時,首先定義一個 vgg_block(n,in,out)方法,用來構建VGG中每個blo…

npm ERR! Please try running this command again as root/Administrator.

win10操作系統下 webstrom的控制臺使用 npm install angular-file-upload 安裝組件,報錯:npm ERR! Please try running this command again as root/Administrator. 解決方法: 開始按鈕右鍵---- windows powershell(管理員&…

map flatmap mappartition flatMapToPair四種用法區別

原文鏈接:http://blog.csdn.net/u013086392/article/details/55666912 ----------------------------------------------------------------------------------- map: 我們可以看到數據的每一行在map之后產生了一個數組,那么rdd存儲的是一個數組的集合…

eve可以在linux運行嗎,ubuntu下為eve游戲搭載 wine環境

援引該地址的參考,本文僅做整理:http://bbs.eve-china.com/thread-626756-1-1.htmllinux的顯卡是否驅動成功,依次鍵入如下命令察看:sudo apt-get install mesa-utils /*安裝 mesa-utils 的指令*/glxinfo | grep r…

自動飛行控制系統_波音公司將重設計737MAX自動飛行控制系統!力求十月前復飛...

據西雅圖時報8月1日報道,美國聯邦航空管理局(FAA)在6月份對波音737 MAX飛行控制系統進行新的嚴格測試時,發現了一個潛在的缺陷,該缺陷促使波音公司對其基本的軟件設計進行變革。波音公司如今正在改變737 MAX的自動飛行控制系統軟件&#xff0…

每日一題——LeetCode141.環形鏈表

個人主頁:白日依山璟 專欄:Java|數據結構與算法|每日一題 文章目錄 1. 題目描述示例1:示例2:示例3:提示: 2. 思路3. 代碼 1. 題目描述 給你一個鏈表的頭節點 head ,判斷鏈表中是否有環。 如果鏈表中有某…

Android O 獲取APK文件權限 Demo案例

1. 通過 aapt 工具查看 APK權限 C:\Users\zh>adb pull /system/priv-app/Settings . /system/priv-app/Settings/: 3 files pulled. 10.8 MB/s (48840608 bytes in 4.325s)C:\Users\zh>aapt d permissions C:\Users\zh\Settings\Settings.apk package: com.android.sett…

VBoxManage命令更詳盡版

原文鏈接:http://418684644-qq-com.iteye.com/blog/1451000 ------------------------------------- VBoxManage命令詳解(一) 本人對vboxmange命令按我個人的理解作了解釋,由于本人水平有限難免有錯誤的地方,希望大…

linux make命令實現,Linux make命令主要參數詳解

-C dir或者 --directoryDIR在讀取makefile文件前,先切換到“dir”目錄下,即把dir作為當前目錄。如果存在多個-C選項,make的最終當前目錄是第一個目錄的相對路徑,如“make –C /home/leowang –C document”,等價于“ma…

行人屬性數據集pa100k_基于InceptionV3的多數據集聯合訓練的行人外觀屬性識別方法與流程...

本發明涉及模式識別技術、智能監控技術等領域,具體的說,是基于Inception V3的多數據集聯合訓練的行人外觀屬性識別方法。背景技術:近年來,視頻監控系統已經被廣泛應用于安防領域。安防人員通過合理的攝像頭布局,實現對…

VBoxManage獲取虛擬機IP地址

在宿主機Linux上安裝VirtualBox,然后VirtualBox上安裝linux虛擬機,在Virtualbox非界面啟動虛擬機時,ip地址無法查看。怎么辦? 使用命令: VBoxManage guestproperty enumerate 虛擬機名 | grep "Net.*V4.*IP"…