【深度學習】NLP中的對抗訓練

????????在NLP中,對抗訓練往往都是針對嵌入層(包括詞嵌入,位置嵌入,segment嵌入等等)開展的,思想很簡單,即針對嵌入層添加干擾,從而提高模型的魯棒性和泛化能力,下面結合具體代碼講解一些NLP中常見對抗訓練算法。

1.Fast Gradient Method(FGM)

????????FGM的思想是針對詞嵌入加入梯度方向的干擾,至于干擾的大小是我們可以調節的,增加干擾后的樣本可以作為額外的對抗樣本進行訓練,以此提高模型的效果。由于我們在訓練時會針對每個樣本都進行一次額外的增加干擾后的訓練,所以使用FGM后訓練時間理論上也會大概增加一倍。

? ? ? ? FGM在原訓練代碼的基礎上,主要增加了以下幾個額外的操作:針對嵌入層添加干擾并備份參數,計算添加干擾后的損失,梯度回傳從而累積添加干擾后的梯度,恢復原來的嵌入層參數。

1.1 算法流程

對于每個x:1.計算x的前向loss、反向傳播得到梯度2.根據embedding矩陣的梯度計算出r,并加到當前embedding上,相當于x+r3.計算x+r的前向loss,反向傳播得到對抗的梯度,累加到(1)的梯度上4.將embedding恢復為(1)時的值5.根據(3)的梯度對參數進行更新

? 1.2?具體代碼

import torch
class FGM():def __init__(self, model):self.model = modelself.backup = {}def attack(self, epsilon=1., emb_name='word_embeddings'):# emb_name這個參數要換成你模型中embedding的參數名for name, param in self.model.named_parameters():if param.requires_grad and emb_name in name:#print('增加擾動的對象是', name)#print(type(param.grad))self.backup[name] = param.data.clone()norm = torch.norm(param.grad)if norm != 0 and not torch.isnan(norm):r_at = epsilon * param.grad / normparam.data.add_(r_at)def restore(self, emb_name='word_embeddings'):# emb_name這個參數要換成你模型中embedding的參數名for name, param in self.model.named_parameters():if param.requires_grad and emb_name in name: assert name in self.backupparam.data = self.backup[name]self.backup = {}

1.3 具體用法

fgm = FGM(model) # (#1)初始化
for batch_input, batch_label in data:loss = model(batch_input, batch_label) # 正常訓練loss.backward() # 反向傳播,得到正常的grad# 對抗訓練fgm.attack() # (#2)在embedding上添加對抗擾動loss_adv = model(batch_input, batch_label) # (#3)計算含有擾動的對抗樣本的lossloss_adv.backward() # (#4)反向傳播,并在正常的grad基礎上,累加對抗訓練的梯度fgm.restore() # (#5)恢復embedding參數# 梯度下降,更新參數optimizer.step()model.zero_grad()

2.Projected Gradient Descent?(PGD

????????Project Gradient Descent(PGD)是一種迭代攻擊算法,相比于普通的FGM 僅做一次迭代,PGD是做多次迭代,每次走一小步,每次迭代都會將擾動投射到規定范圍內。其中r為擾動約束空間(一個半徑為r的球體),原始的輸入樣本對應的初識點為球心,避免擾動超過球面。迭代多次后,保證擾動在一定范圍內,如下圖所示:

?2.1 算法流程

對于每個x:1.計算x的前向loss、反向傳播得到梯度并備份對于每步t:2.根據embedding矩陣的梯度計算出r,并加到當前embedding上,相當于x+r(超出范圍則投影回epsilon內)3.t不是最后一步: 將梯度歸0,根據1的x+r計算前后向并得到梯度4.t是最后一步: 恢復(1)的梯度,計算最后的x+r并將梯度累加到(1)上5.將embedding恢復為(1)時的值6.根據(4)的梯度對參數進行更新

?2.2??具體代碼

import torch
class PGD():def __init__(self, model):self.model = modelself.emb_backup = {}self.grad_backup = {}def attack(self, epsilon=1., alpha=0.3, emb_name='word_embeddings', is_first_attack=False):for name, param in self.model.named_parameters():if param.requires_grad and emb_name in name:if is_first_attack:self.emb_backup[name] = param.data.clone()norm = torch.norm(param.grad)if norm != 0 and not torch.isnan(norm):r_at = alpha * param.grad / normparam.data.add_(r_at)param.data = self.project(name, param.data, epsilon)def restore(self, emb_name='word_embeddings'):for name, param in self.model.named_parameters():if param.requires_grad and emb_name in name: assert name in self.emb_backupparam.data = self.emb_backup[name]self.emb_backup = {}def project(self, param_name, param_data, epsilon):r = param_data - self.emb_backup[param_name]if torch.norm(r) > epsilon:r = epsilon * r / torch.norm(r)return self.emb_backup[param_name] + rdef backup_grad(self):for name, param in self.model.named_parameters():if param.requires_grad:self.grad_backup[name] = param.grad.clone()def restore_grad(self):for name, param in self.model.named_parameters():if param.requires_grad:param.grad = self.grad_backup[name]

2.3 具體用法

pgd = PGD(model)
K = 3
for batch_input, batch_label in data:# 正常訓練loss = model(batch_input, batch_label)loss.backward() # 反向傳播,得到正常的gradpgd.backup_grad()# 累積多次對抗訓練——每次生成對抗樣本后,進行一次對抗訓練,并不斷累積梯度for t in range(K):pgd.attack(is_first_attack=(t==0)) # 在embedding上添加對抗擾動, first attack時備份param.dataif t != K-1:model.zero_grad()else:pgd.restore_grad()loss_adv = model(batch_input, batch_label)loss_adv.backward() # 反向傳播,并在正常的grad基礎上,累加對抗訓練的梯度pgd.restore() # 恢復embedding參數# 梯度下降,更新參數optimizer.step()model.zero_grad()

Reference:

1.NLP中的對抗訓練_colourmind的博客-CSDN博客

2.【NLP】NLP中的對抗訓練_風度78的博客-CSDN博客

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/36979.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/36979.shtml
英文地址,請注明出處:http://en.pswp.cn/news/36979.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Spark 學習記錄

基礎 SparkContext是什么?有什么作用? https://blog.csdn.net/Shockang/article/details/118344357 SparkContext 是什么? SparkContext 是通往 Spark 集群的唯一入口,可以用來在 Spark 集群中創建 RDDs 、累加和廣播變量( Br…

【數據庫基礎】Mysql下載安裝及配置

下載 下載地址:https://downloads.mysql.com/archives/community/ 當前最新版本為 8.0版本,可以在Product Version中選擇指定版本,在Operating System中選擇安裝平臺,如下 安裝 MySQL安裝文件分兩種 .msi和.zip [外鏈圖片轉存失…

C++11時間日期庫chrono的使用

chrono是C11中新加入的時間日期操作庫,可以方便地進行時間日期操作,主要包含了:duration, time_point, clock。 時鐘與時間點 chrono中用time_point模板類表示時間點,其支持基本算術操作;不同時鐘clock分別返回其對應…

vue中router路由的原理?兩種路由模式如何實現?(vue2) -(上)

平時我們編寫路由時,通常直接下載插件使用,在main.js文件中引入直接通過引入vue-router中的Router通過Vue.use使用以后定義一個routeMap數組,里邊是我們編寫路由的地方,最后通過實例化一個 Router實例 將routes我們定義的routeMao…

Docker中部署Nginx

1.Nginx部署需求 2.操作教程 3.實際步驟 把配置粘過來。

客戶端遠程啟動服務器腳本文件

目錄 軟件需求 實現 方法1 方法2 方法3 軟件需求 有兩臺計算機,一臺是linux客戶端,另一臺是linux服務器。要求操作員可以在客戶端遠程啟動服務器上的腳本文件,控制服務器。 實現 方法1 客戶端通過ssh登錄服務器,然后通過…

Cookie、Session、Token的區別

有人或許還停留在它們只是驗證身份信息的機制,但是它們之間的關系你真的弄懂了么? 發展史: Coolie: Netscape Communications 公司引入了 Cookie 概念,作為在客戶端存儲狀態信息的一種方法。初始目的是為了解決 HTTP 的無狀態性…

Python爬蟲:單線程、多線程、多進程

前言 在使用爬蟲爬取數據的時候,當需要爬取的數據量比較大,且急需很快獲取到數據的時候,可以考慮將單線程的爬蟲寫成多線程的爬蟲。下面來學習一些它的基礎知識和代碼編寫方法。 一、進程和線程 進程可以理解為是正在運行的程序的實例。進…

python爬蟲數據解析xpath、jsonpath,bs4

數據的解析 解析數據的方式大概有三種 xpathJsonPathBeautifulSoup xpath 安裝xpath插件 打開谷歌瀏覽器擴展程序,打開開發者模式,拖入插件,重啟瀏覽器,ctrlshiftx,打開插件頁面 安裝lxml庫 安裝在python環境中的Scri…

劍指Offer61.撲克牌中的順子 C++

1、題目描述 從若干副撲克牌中隨機抽 5 張牌,判斷是不是一個順子,即這5張牌是不是連續的。2~10為數字本身,A為1,J為11,Q為12,K為13,而大、小王為 0 ,可以看成任意數字。…

并發服務器模型,多線程并發

一、多線程并發完整代碼 #include <stdio.h> #include <sys/types.h> #include <sys/socket.h> #include <arpa/inet.h> #include <string.h> #include <unistd.h> #include <sys/wait.h> #include <stdlib.h> #include <…

突然讓做性能測試?試試RunnerGo

當前&#xff0c;性能測試已經是一名軟件測試工程師必須要了解&#xff0c;甚至熟練使用的一項技能了&#xff0c;在工作時可能每次發版都要跑一遍性能&#xff0c;跑一遍自動化。性能測試入門容易&#xff0c;深入則需要太多的知識量&#xff0c;今天這篇文章給大家帶來&#…

Rocky Linux更換為國內源

Rocky Linux提供的可供切換的源列表&#xff1a;Mirrors - Mirror Manager 其中以 COUNTRY 列為 CN 的是國內源。 選擇其中一個Rocky Linux 源使用幫助 — USTC Mirror Help 文檔 操作前請做好備份 對于 Rocky Linux 8&#xff0c;使用以下命令替換默認的配置 sed -e s|^mirr…

新能源汽車電控系統

新能源汽車電控系統主要分為&#xff1a;三電系統電控系統、高壓系統電控系統、低壓系統電控系統 三電系統電控系統 包括整車控制器、電池管理系統、驅動電機控制器等。 整車控制器VCU 整車控制器作為電動汽車中央控制單元&#xff0c;是整個控制系統的核心&#xff0c;也是…

zabbix監控mysql數據庫、nginx、Tomcat

zabbix監控mysql數據庫、nginx、Tomcat 一.zabbix監控mysql數據庫 1.環境規劃 hostIP部署zabbix-server192.168.198.17zabbix服務器搭建zabbix-mysql192.168.198.15zabbix客戶端搭建 2.zabbix-server安裝部署&#xff08;192.168.198.17&#xff09; 請參考以下配置&#…

Azure概念介紹

云計算定義 云計算是一種使用網絡進行存儲和處理數據的計算方式。它通過將數據和應用程序存儲在云端服務器上&#xff0c;使用戶能夠通過互聯網訪問和使用這些資源&#xff0c;而無需依賴于本地硬件和軟件。 發展歷史 云計算的概念最早可以追溯到20世紀60年代的時候&#x…

mysql 分庫分表淺析

分表是分散數據庫壓力的好方法。 分表&#xff0c;最直白的意思&#xff0c;就是將一個表結構分為多個表&#xff0c;然后&#xff0c;可以再同一個庫里&#xff0c;也可以放到不同的庫。 當然&#xff0c;首先要知道什么情況下&#xff0c;才需要分表。個人覺得單表記錄條數達…

2023河南萌新聯賽第(五)場:鄭州輕工業大學C-數位dp

鏈接&#xff1a;登錄—專業IT筆試面試備考平臺_牛客網 給定一個正整數 n&#xff0c;你可以對 n 進行任意次&#xff08;包括零次&#xff09;如下操作&#xff1a; 選擇 n 上的某一數位&#xff0c;將其刪去&#xff0c;剩下的左右部分合并。例如 123&#xff0c;你可以選擇…

年至年的選擇仿elementui的樣式

組件&#xff1a;<!--* Author: liuyu liuyuxizhengtech.com* Date: 2023-02-01 16:57:27* LastEditors: wangping wangpingxizhengtech.com* LastEditTime: 2023-06-30 17:25:14* Description: 時間選擇年 - 年 --> <template><div class"yearPicker"…

Smart HTML Elements 16.1 Crack

Smart HTML Elements 是一個現代 Vanilla JS 和 ES6 庫以及下一代前端框架。企業級 Web 組件包括輔助功能&#xff08;WAI-ARIA、第 508 節/WCAG 合規性&#xff09;、本地化、從右到左鍵盤導航和主題。與 Angular、ReactJS、Vue.js、Bootstrap、Meteor 和任何其他框架集成。 智…