PCGrad解決多任務沖突

論文解讀:"Gradient Surgery for Multi-Task Learning"

1. 論文標題直譯
  • Gradient Surgery: 梯度手術
  • for Multi-Task Learning: 應用于多任務學習

合在一起就是:為多任務學習量身定制的梯度手術。這個名字非常形象地概括了它的核心思想。

2. 它要解決的核心問題:多任務學習中的“梯度沖突”

想象一下,你正在訓練一個AI模型來開一輛車,它需要同時完成兩個任務:

  • 任務A: 識別紅綠燈(要求模型關注圖像上方的顏色區域)。
  • 任務B: 保持在車道線內(要求模型關注圖像下方的白色線條)。

在訓練時,模型會根據任務A的錯誤計算出一個梯度?g_A,根據任務B的錯誤計算出另一個梯度?g_B。梯度本質上是告訴模型參數“應該朝哪個方向更新才能做得更好”。

問題來了:?如果某次更新中,g_A?說“參數應該向東調整”,而?g_B?恰好說“參數應該向西調整”,那么把它們簡單相加(g_A + g_B)的結果可能接近于零,模型幾乎學不到任何東西。

更常見的情況是,g_A?想讓參數向東走,g_B?想讓參數向西北走。它們的合力會是一個“折衷”的方向,這個方向可能對兩個任務都不是最優的,甚至可能提升一個任務的性能卻損害了另一個。

這種現象就叫做梯度沖突 (Gradient Conflict)?或?負遷移 (Negative Transfer)。這是多任務學習中一個長期存在的痛點,它會導致訓練不穩定,模型性能難以提升。

3. PCGrad 的解決方案:“梯度手術”

PCGrad?(Projected Gradient Descent) 提出了一種非常聰明的解決方案,就像一個外科醫生一樣,在更新模型參數之前,先對這些相互沖突的梯度做一次“手術”。

手術流程如下:

第1步:分別計算每個任務的梯度?和傳統方法不同,它不把所有損失加起來,而是為每個任務的損失?loss_A,?loss_B... 單獨計算梯度?g_A,?g_B...

第2步:診斷是否存在“沖突”?PCGrad 遍歷所有梯度對(如?g_A?和?g_B),并通過計算它們的點積 (dot product)?來判斷它們是否沖突。

  • 如果?dot(g_A, g_B) > 0: 說明兩個梯度的夾角小于90度,它們大方向一致,是“盟友”。無需手術
  • 如果?dot(g_A, g_B) < 0: 說明兩個梯度的夾角大于90度,它們的方向是“敵對”的。診斷為沖突,需要手術!

第3步:執行“手術”——投影和矯正?當檢測到?g_A?和?g_B?沖突時,PCGrad 會執行以下操作:

  1. 投影 (Project):將梯度?g_A?投影到梯度?g_B?的方向上,得到一個分量?proj_B(g_A)。這個分量可以被理解為?g_A?中與?g_B?“正面沖突”的那一部分。
  2. 矯正 (Correct):從原始梯度?g_A?中減去這個沖突分量:g_{A_{new}} = g_A - proj_B(g_A)

手術效果:?經過手術后的新梯度?g_{A_{new}}?與?g_B?變成了正交的(夾角為90度)。這意味著,g_{A_{new}}?的更新方向中,已經完全剔除了與?g_B?直接對抗的部分。它只保留了對自己有益,且不傷害對方的部分。

PCGrad 會對所有發生沖突的梯度對都執行這個“手術”。

第4步:合并與更新?將所有經過“手術”矯正后的新梯度相加,得到最終的、和諧的、沒有內斗的梯度,然后用這個梯度去更新模型參數。

4. TensorFlow 實現中的 PCGrad

你在代碼中看到的?PCGrad?通常是一個優化器包裝器 (Optimizer Wrapper)。它的用法一般是這樣的:

  1. 首先,定義一個基礎的優化器,比如 Adam。

    base_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
  2. 然后,用?PCGrad?包裝它

    from .PCGrad import PCGrad
    optimizer = PCGrad(base_optimizer)
  3. 在訓練循環中,用法會稍有不同。 你不再是計算一個總的 loss 然后調用?apply_gradients。而是:

    # 1. 分別計算每個任務的 loss
    loss_A = compute_loss_A(y_true_A, y_pred_A)
    loss_B = compute_loss_B(y_true_B, y_pred_B)
    list_of_losses = [loss_A, loss_B]# 2. PCGrad 優化器會接管梯度的計算和矯正
    # 這一步是 PCGrad 內部實現的,它會:
    #   - 為每個 loss 計算梯度
    #   - 執行梯度手術
    #   - 返回最終的梯度
    # 通常會通過一個自定義的 train_step 來實現
    final_gradients = optimizer.get_gradients(list_of_losses, model.trainable_variables)# 3. 應用經過手術后的梯度
    optimizer.apply_gradients(zip(final_gradients, model.trainable_variables))

總結

方面解釋
它是什么?PCGrad?是一種優化策略,而非損失函數或模型架構。
解決什么問題?解決多任務學習中的梯度沖突 (Gradient Conflict)?問題。
核心思想?梯度手術 (Gradient Surgery):在更新模型前,先檢測并消除梯度之間的沖突部分。
如何實現?通過向量投影,將沖突的梯度分量從原始梯度中移除,使它們變得正交
最終效果?1. 訓練過程更穩定。 2. 避免了任務間的“內耗”,有助于所有任務性能的同步提升。

因此,當你看到代碼中使用了?PCGrad,就可以立刻明白:這個項目正在處理一個多任務學習的場景,并且使用了一種相當先進的技術來確保不同任務能夠“和平共處”,協同進步。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/97851.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/97851.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/97851.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Nvidia顯卡架構解析與cuda應用生態淺析

文章目錄 0. Nvidia顯卡簡介 一、主要顯卡系列 二、主要GPU架構與代表產品 1.main 1.1 CUDA 13.0 的重大變化 1.2 V100 的硬件短板已顯現 1.3 這意味著什么? 1.4 寫在后面 彩蛋:V100 0. Nvidia顯卡簡介 一、主要顯卡系列 GeForce 系列(消費級) 用途:游戲、創作、日常圖形…

開發指南:使用 MQTTNet 庫構建 .Net 物聯網 MQTT 應用程序

一、背景介紹 隨著物聯網的興起&#xff0c;.Net 框架在構建物聯網應用程序方面變得越來越流行。微軟的 .Net Core 和 .Net 框架為開發人員提供了一組工具和庫&#xff0c;以構建可以在 Raspberry Pi、HummingBoard、BeagleBoard、Pine A64 等平臺上運行的物聯網應用程序。 MQT…

突破性能瓶頸:基于騰訊云EdgeOne的AI圖片生成器全球加速實踐

1. 項目背景與挑戰 1.1 開發背景 隨著AIGC技術爆發&#xff0c;我們團隊決定開發一款多模型支持的AI圖片生成器&#xff0c;主要解決以下痛點&#xff1a; 不同AI模型的參數規范不統一生成結果難以系統化管理缺乏企業級的安全水印方案全球用戶訪問延遲高&#xff0c;中國用戶…

一、Java 基礎入門:從 0 到 1 認識 Java(詳細筆記)

1.1 Java 語言簡介與發展歷程 Java 是一門面向對象的高級編程語言&#xff0c;以“跨平臺、安全、穩定”為核心特性&#xff0c;自誕生以來長期占據編程語言排行榜前列&#xff0c;廣泛應用于后端開發、移動端開發、大數據等領域。 1.1.1 起源與核心人物 起源背景&#xff1…

uniapp:根據目的地經緯度,名稱,喚起高德/百度地圖來導航,兼容App,H5,小程序

1、需要自行申請高德地圖的key,配置manifest.json 2、MapSelector選擇組件封裝 <template><view><u-action-sheet :list="mapList" v-model="show" @click="changeMap"></u-action-sheet></view> </template&…

我對 WPF 動搖時的選擇:.NET Framework 4.6.2+WPF+Islands+UWP+CompostionApi

目錄 NET Framework 4.6.2的最大亮點 為什么固守462不升級 WPF-開發體驗的巔峰 為什么對WPF動搖了 基于IslandsUWP的濾鏡嘗試 總結 NET Framework 4.6.2的最大亮點 安全性能大提升&#xff1a; 默認啟用TLS1.2協議&#xff0c;更安全&#xff0c;它為后續的版本提供了重…

SpringBoot大文件下載失敗解決方案

SpringBoot大文件下載失敗解決方案 后端以文件流方式給前端接收下載文件,文件過大時出現下載失敗的情況或者打開后提示文件損壞,實際是字節未完全讀取寫入。 針對大文件下載失敗的情況,以下是詳細的解決方案: 大文件下載失敗的主要原因 內存溢出:一次性加載大文件到內存…

torch.gather

torch.gather 介紹 torch.gather(input, dim, index, *, sparse_gradFalse, outNone) → Tensor 沿由 dim 指定的軸收集值。 對于三維張量&#xff0c;輸出按如下方式確定&#xff1a; out[i][j][k] input[index[i][j][k]][j][k] # 如果 dim 0 out[i][j][k] input[i][i…

Golang | http/server Gin框架簡述

http/server http指的是Golang中的net/http包&#xff0c;這里用的是1.23.10。 概覽 http包的作用文檔里寫的很簡明&#xff1a;Package http provides HTTP client and server implementations. 主要是提供http的客戶端和服務端&#xff0c;也就是能作為客戶端發http請求&a…

Vision Transformer (ViT) :Transformer在computer vision領域的應用(三)

Experiment 上來的一段話就概括了整章的內容。 We evaluate the representation learning capabilities of ResNet, Vision Transformer (ViT), and the hybrid. 章節的一開頭就說明了,對比的模型就是 ResNet,CNN領域中的代碼模型。 ViT。 上一篇中提到的Hybrid模型,也就是…

5-12 WPS JS宏 Range數組規范性測試

Range()數組是JS宏中不缺少的組成部分,了解Range()數組的特性必不可少,下面我們一起測試一下各種Range()數組。 1.Range()數組特性 單元格區域:Range("a2:m2")與Range("a2","m2")的類型都是:Range/Object,功能都為單元格區域,功能…

uniapp微信小程序保存海報到手機相冊canvas

在uniapp中實現微信小程序保存海報到手機相冊&#xff0c;主要涉及Canvas繪制和圖片保存。以下是關鍵步驟和代碼示例&#xff1a; 一、關鍵代碼展示&#xff1a; 1. 模板配置&#xff1a;頁面展示該海報&#xff0c;可直接查看&#xff0c;也可下載保存到手機相冊&#xff0c;h…

glib2-2.62.5-7.ky10.x86_64.rpm怎么安裝?Kylin Linux RPM包安裝詳細步驟

一、準備工作 ?確認系統版本? 這個包是 ky10的&#xff08;也就是 openEuler 20.03 LTS SP3 或類似版本&#xff09;&#xff0c;而且是 ?x86_64 架構&#xff08;就是常見的64位電腦&#xff09;?。 你要先確認你的系統是不是這個版本&#xff0c;不然可能裝不上或者出問題…

webrtc之語音活動下——VAD人聲判定原理以及源碼詳解

文章目錄前言一、高斯混合模型介紹1.高斯模型舉例1&#xff09;定義2&#xff09;舉例說明2.高斯混合模型(GMM)1&#xff09;定義2&#xff09;舉例說明3&#xff09;一維曲線二、VAD高斯混合模型1.模型訓練介紹1&#xff09;訓練方法2&#xff09;訓練結果2.噪聲高斯模型分布1…

【Redis】-- 主從復制

文章目錄1. 主從復制1.1 主從復制是怎么個事&#x1f914;1.2 拓撲結構1.2.1 一主一從拓撲1.2.2 一主多從拓撲1.2.3 樹形拓撲1.3 主從復制原理1.3.1 復制過程1.3.2 數據同步PSYNC1.3.2.1 replicationid/replid (復制id)1.3.2.2 復制偏移量維護1.3.3 psync運行流程1.3.4 全量復制…

開源炸場!阿里通義千問Qwen3-Next發布:80B參數僅激活3B,訓練成本降90%,長文本吞吐提升10倍?

開源炸場&#xff01;阿里通義千問Qwen3-Next發布&#xff1a;80B參數僅激活3B&#xff0c;訓練成本降90%&#xff0c;長文本吞吐提升10倍? 開源世界迎來震撼突破&#xff01; 通義千問團隊最新發布的Qwen3-Next架構&#xff0c;以其獨創的"小而精"設計理念&#x…

【C++入門】C++基礎

目錄 1. 命名空間 1.1 命名空間的創建和使用 2. 輸入輸出 2.1 輸出 2.2 輸入 3. 缺省參數 3.1 全缺省 3.2 半缺省 4.函數重載 4.1 為什么C支持重載而C語言不支持&#xff1f; 4.1.2 編譯的四個過程 4.2 extern是什么 5.引用 5.1 引用的特性 5.1.1 引用的“隱式類…

如何往mp4視頻添加封面圖和獲取封面圖?

前言&#xff1a;大家好&#xff0c;之前有給大家分享過mp4錄像的方案&#xff0c;今天給大家分享的內容是&#xff1a;如何在添加自定義的封面圖到mp4里面去&#xff0c;以及在進入回放mp4視頻列表的時候&#xff0c;怎么獲取mp4視頻里面的封面圖&#xff0c;當然這個獲取到的…

你的第一個Transformer模型:從零實現并訓練一個迷你ChatBot

點擊 “AladdinEdu&#xff0c;同學們用得起的【H卡】算力平臺”&#xff0c;注冊即送-H卡級別算力&#xff0c;80G大顯存&#xff0c;按量計費&#xff0c;靈活彈性&#xff0c;頂級配置&#xff0c;學生更享專屬優惠。 引言&#xff1a;破除神秘感&#xff0c;擁抱核心思想 …

【20期】滬深指數《實時交易數據》免費獲取股票數據API:PythonJava等5種語言調用實例演示與接口API文檔說明

? 隨著量化投資在金融市場的快速發展&#xff0c;高質量數據源已成為量化研究的核心基礎設施。本文將系統介紹股票量化分析中的數據獲取解決方案&#xff0c;涵蓋實時行情、歷史數據及基本面信息等關鍵數據類型。 本文將重點演示這些接口在以下技術棧中的實現&#xff1a; P…