當你在深度學習中進入“多任務學習(Multi-task Learning)”的領域,第一道關卡可能不是設計網絡結構,也不是準備數據集,而是:多個Loss到底是加起來一起backward,還是分別backward?
這個問題看似簡單,卻涉及PyTorch計算圖的構建邏輯、自動求導機制、內存管理、任務耦合性、優化目標權衡等多重復雜因素。
1. 多任務學習中的Loss定義
1.1 多任務Loss形式
在一個多任務模型中,我們一般會有若干個子任務,設任務數為 ,每個任務都有一個對應的Loss函數 ,我們最終優化的Loss是:
其中, 是任務的權重系數。
1.2 PyTorch中的基本寫法
在PyTorch中,多任務Loss通常如下所示:
loss_task1 = criterion1(output1, target1)
loss_task2 = criterion2(output2, target2)
total_loss &