知識蒸餾的蒸餾損失方法代碼總結(包括:基于logits的方法:KLDiv,dist,dkd等,基于中間層提示的方法:)

有兩種知識蒸餾方法:一種利用教師模型的輸出概率(基于logits的方法)[15,14,11],另一種利用教師模型的中間表示(基于提示的方法)[12,13,18,17]。基于logits的方法利用教師的輸出作為輔助信號來訓練一個較小的模型,即學生模型:

利用教師模型的輸出概率(基于logits的方法)

該類方法損失函數為:
在這里插入圖片描述

DIST

Tao Huang,Shan You,Fei Wang,Chen Qian,and Chang Xu.Knowledge distillation from a strongerteacher.In Advances in Neural Information Processing Systems,2022.

import torch.nn as nndef cosine_similarity(a, b, eps=1e-8):return (a * b).sum(1) / (a.norm(dim=1) * b.norm(dim=1) + eps)def pearson_correlation(a, b, eps=1e-8):return cosine_similarity(a - a.mean(1).unsqueeze(1),b - b.mean(1).unsqueeze(1), eps)def inter_class_relation(soft_student_outputs, soft_teacher_outputs):return 1 - pearson_correlation(soft_student_outputs, soft_teacher_outputs).mean()def intra_class_relation(soft_student_outputs, soft_teacher_outputs):return inter_class_relation(soft_student_outputs.transpose(0, 1), soft_teacher_outputs.transpose(0, 1))class DIST(nn.Module):def __init__(self, beta=1.0, gamma=1.0, temp=1.0):super(DIST, self).__init__()self.beta = betaself.gamma = gammaself.temp = tempdef forward(self, student_preds, teacher_preds, **kwargs):soft_student_outputs = (student_preds / self.temp).softmax(dim=1)soft_teacher_outputs = (teacher_preds / self.temp).softmax(dim=1)inter_loss = self.temp ** 2 * inter_class_relation(soft_student_outputs, soft_teacher_outputs)intra_loss = self.temp ** 2 * intra_class_relation(soft_student_outputs, soft_teacher_outputs)kd_loss = self.beta * inter_loss + self.gamma * intra_lossreturn kd_loss

KLDiv (2015年的原始方法)

import torch.nn as nn
import torch.nn.functional as F# loss = alpha * hard_loss + (1-alpha) * kd_loss,此處是單單的kd_loss
class KLDiv(nn.Module):def __init__(self, temp=1.0):super(KLDiv, self).__init__()self.temp = tempdef forward(self, student_preds, teacher_preds, **kwargs):soft_student_outputs = F.log_softmax(student_preds / self.temp, dim=1)soft_teacher_outputs = F.softmax(teacher_preds / self.temp, dim=1)kd_loss = F.kl_div(soft_student_outputs, soft_teacher_outputs, reduction="none").sum(1).mean()kd_loss *= self.temp ** 2return kd_loss

dkd (Decoupled KD(CVPR 2022) )

Borui Zhao,Quan Cui,Renjie Song,Yiyu Qiu,and Jiajun Liang.Decoupled knowledge distillation.InIEEE/CVF Conference on Computer Vision and Pattern Recognition,2022.

import torch
import torch.nn as nn
import torch.nn.functional as Fdef dkd_loss(logits_student, logits_teacher, target, alpha, beta, temperature):gt_mask = _get_gt_mask(logits_student, target)other_mask = _get_other_mask(logits_student, target)pred_student = F.softmax(logits_student / temperature, dim=1)pred_teacher = F.softmax(logits_teacher / temperature, dim=1)pred_student = cat_mask(pred_student, gt_mask, other_mask)pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask)log_pred_student = torch.log(pred_student)tckd_loss = (F.kl_div(log_pred_student, pred_teacher, reduction='batchmean')* (temperature ** 2))pred_teacher_part2 = F.softmax(logits_teacher / temperature - 1000.0 * gt_mask, dim=1)log_pred_student_part2 = F.log_softmax(logits_student / temperature - 1000.0 * gt_mask, dim=1)nckd_loss = (F.kl_div(log_pred_student_part2, pred_teacher_part2, reduction='batchmean')* (temperature ** 2))return alpha * tckd_loss + beta * nckd_lossdef _get_gt_mask(logits, target):target = target.reshape(-1)mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool()return maskdef _get_other_mask(logits, target):target = target.reshape(-1)mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool()return maskdef cat_mask(t, mask1, mask2):t1 = (t * mask1).sum(dim=1, keepdims=True)t2 = (t * mask2).sum(1, keepdims=True)rt = torch.cat([t1, t2], dim=1)return rtclass DKD(nn.Module):def __init__(self, alpha=1., beta=2., temperature=1.):super(DKD, self).__init__()self.alpha = alphaself.beta = betaself.temperature = temperaturedef forward(self, z_s, z_t, **kwargs):target = kwargs['target']if len(target.shape) == 2:  # mixup / smoothingtarget = target.max(1)[1]kd_loss = dkd_loss(z_s, z_t, target, self.alpha, self.beta, self.temperature)return kd_loss

利用教師模型的中間表示(基于提示的方法)

該類方法損失函數為:
[ L_{hint} = D_{hint}(T_s(F_s), T_t(F_t)) ]

ReviewKD (CVPR2021)

論文:

Pengguang Chen,Shu Liu,Hengshuang Zhao,and Jiaya Jia.Distilling knowledge via knowledge review.In IEEE/CVF Conference on Computer Vision and Pattern Recognition,2021.

代碼:

https://github.com/dvlab-research/ReviewKD

Adriana Romero,Nicolas Ballas,Samira Ebrahimi Kahou,Antoine Chassang,Carlo Gatta,and YoshuaBengio.Fitnets:Hints for thin deep nets.arXiv preprint arXiv:1412.6550,2014.

Yonglong Tian,Dilip Krishnan,and Phillip Isola.Contrastive representation distillation.In IEEE/CVFInternational Conference on Learning Representations,2020.

Baoyun Peng,Xiao Jin,Jiaheng Liu,Dongsheng Li,Yichao Wu,Yu Liu,Shunfeng Zhou,and ZhaoningZhang.Correlation congruence for knowledge distillation.In International Conference on ComputerVision,2019.

關于知識蒸餾損失函數的文章

FitNet(ICLR 2015)、Attention(ICLR 2017)、Relational KD(CVPR 2019)、ICKD (ICCV 2021)、Decoupled KD(CVPR 2022) 、ReviewKD(CVPR 2021)等方法的介紹:

https://zhuanlan.zhihu.com/p/603748226?utm_id=0

待更新

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/206637.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/206637.shtml
英文地址,請注明出處:http://en.pswp.cn/news/206637.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

VBA語法結構及編程思想

VBA(Visual Basic for Applications)是一種編程語言,它被用于Microsoft Office應用程序的自動化,允許用戶編寫宏來執行常規任務。VBA是基于Microsoft的Visual Basic語言,但專為Office應用程序定制。 VBA語法格式 VBA的…

【STM32】TIM定時器輸出比較

1 輸出比較 1.1 輸出比較簡介 OC(Output Compare)輸出比較;IC(Input Capture)輸入捕獲;CC(Capture/Compare)輸入捕獲和輸出比較的單元輸出比較可以通過比較CNT與CCR寄存器值&#…

JavaWeb-HTTP協議

1. 什么是HTTP協議 HTTP超文本傳輸協(Hyper Text transfer protocol),是一種用于用于分布式、協作式和超媒體信息系統的應用層協議。它于1990年提出,經過十幾年的使用與發展,得到不斷地完善和擴展。HTTP 是為 Web 瀏覽器與 Web 服務器之間的…

AI自動生成代碼工具

AI自動生成代碼工具是一種利用人工智能技術來輔助或自動化軟件開發過程中的編碼任務的工具。這些工具使用機器學習和自然語言處理等技術,根據開發者的需求生成相應的源代碼。以下是一些常見的AI自動生成代碼工具,希望對大家有所幫助。北京木奇移動技術有…

Redisson的基本使用

Redisson官網描述:Redisson 是一個在 Redis 的基礎上實現的 Java 駐內存數據網格客戶端(In-Memory Data Grid)。它不僅提供了一系列的 redis 常用數據結構命令服務,還提供了許多分布式服務,例如分布式鎖、分布式對象、…

HCIP —— BGP 基礎 (上)

BGP --- 邊界網關協議 (路徑矢量協議) IGP --- 內部網關協議 --- OSPF RIP ISIS EGP --- 外部網關協議 --- EGP BGP AS --- 自治系統 由單一的組織或者機構獨立維護的網絡設備以及網絡資源的集合。 因 網絡范圍太大 需 自治 。 為區分不同的AS&#…

vim常見操作

vim常見操作 文章目錄 vim常見操作1. 回退/前進2. 搜索3. 刪除4. 定位到50行5. 顯示行號6. 復制粘貼7. 剪貼8. 替換9. vim打開文件的時候出現 1. 回退/前進 1.esc進入命令模式 2.ctrlr 前進 u 回退2. 搜索 1) esc進入命令模式 2) /text  查找text&am…

Docker load 命令

docker load :導入使用docker save命令導出的鏡像。 語法 docker load [OPTIONS]OPTIONS 說明: --input , -i :指定導入的文件,代替STDIN。 --quiet , -q :精簡輸出信息。 實例: 導入鏡像&#xff1a…

【STM32】TIM定時器輸入捕獲

1 輸入捕獲 1.1 輸入捕獲簡介 IC(Input Capture)輸入捕獲 輸入捕獲模式下,當通道輸入引腳出現指定電平跳變時(上升沿/下降沿),當前CNT的值將被鎖存到CCR中(把CNT的值讀出來,寫入到…

ubuntu16.04安裝ROS+Gazebo

ubuntu16.04安裝ROS參考文章 ros安裝(一鍵最簡安裝,吹爆魚香ROS,請叫我魚吹) ROS篇——Ubuntu快速一鍵安裝ROS或ROS2(通用) ubuntu安裝ROS melodic(最新、超詳細圖文教程) 配置ubuntu以及安裝ros2必要環…

類風濕性關節炎口腔黏膜破裂引發抗瓜氨酸細菌和人蛋白抗體反應

今天給同學們分享一篇實驗文章“Oral mucosal breaks trigger anti-citrullinated bacterial and human protein antibody responses in rheumatoid arthritis”,這篇文章發表在Sci Transl Med期刊上,影響因子為17.1。 結果解讀: 口腔黏膜破…

Redis主從復制的配置和實現原理

Redis的持久化功能在一定程度上保證了數據的安全性,即便是服務器宕機的情況下,也可以保證數據的丟失非常少。通常,為了避免服務的單點故障,會把數據復制到多個副本放在不同的服務器上,且這些擁有數據副本的服務器可以用…

如何快速構建知識服務平臺,打造個人或企業私域流量

隨著互聯網的快速發展,傳統的知識付費平臺已經不能滿足用戶的需求。而SaaS知識付費小程序平臺則是一種新型的知識付費方式,具有靈活、便捷、高效等特點,為用戶提供了更加優質的付費知識服務。本文將介紹如何搭建自己的SaaS知識付費小程序平臺…

如何掌握構建 LMS 網站的藝術

目錄 什么是學習管理系統 (LMS) 在線課程和 LMS 網站的好處 為什么 WordPress 對于 LMS 網站很重要 統一學習中心 多功能性和可擴展性 提高教育參與度 簡化管理和監控 節省時間和費用 技能評估和績效監督 持續學習和技能提升 使用 WordPress 插件構建成功的 LMS 課程 專注于您的…

sparkc程序idea調試提示內存不足

報錯如下: Exception in thread "main" java.lang.IllegalArgumentException: System memory 259522560 must be at least 471859200. Please increase heap size using the --driver-memory option or spark.driver.memory in Spark configuration. 測…

自動駕駛:傳感器初始標定

手眼標定 機器人手眼標定AxxB(eye to hand和eye in hand)及平面九點法標定 Ax xB問題求解,旋轉和平移分步求解法 手眼標定AXXB求解方法(文獻總結) 基于靶的方法 相機標定 (1) ApriTag (2) 棋盤格:cv::f…

富時中國A50指數暴跌

近年來,中國股市的波動一直備受關注,而富時中國A50指數更是其中一項備受矚目的指標之一。然而,近期卻出現了一場引人矚目的暴跌,引發了廣泛的關注和討論。 富時中國A50指數簡介 富時中國A50指數,作為富時羅素指數系列…

【C/PTA】結構體專項練習

本文結合PTA專項練習帶領讀者掌握結構體,刷題為主注釋為輔,在代碼中理解思路,其它不做過多敘述。 目錄 6-1 選隊長6-2 按等級統計學生成績6-3 學生成績比高低6-4 綜合成績6-5 利用“選擇排序算法“對結構體數組進行排序6-6 結構體的最值6-7 復…

香港商標注冊申請所需資料及辦理流程

作為東方明珠,自由港香港是世界上較自由的貿易通商口岸,再加上本身良好的基礎設施和健全的法律制度,這給企業家提供了得天獨厚的營商環境。在香港注冊商標,可以迅速提高企業的知名度,提升企業不斷成長的競爭力&#xf…

全新UI彩虹外鏈網盤系統源碼V5.5/支持批量封禁+優化加載速度+用戶系統與分塊上傳

源碼簡介: 全新UI彩虹外鏈網盤系統源碼V5.5,它可以支持批量封禁優化加載速度。新增用戶系統與分塊上傳。 彩虹外鏈網盤,作為一款PHP網盤與外鏈分享程序,具備廣泛的文件格式支持能力。它不僅能夠實現各種格式文件的上傳&#xff…