【多標簽分類問題的樣本挖掘】Pytorch中的TripletMarginLoss的樣本挖掘

多數度量學習的代碼都需要進行挖掘,樣本挖掘過程就是把一個Batch中的所有樣本,根據標簽來劃分成正樣本和負樣本
這里我們只討論多標簽分類問題,標簽是onehot編碼,如果是單標簽分類任務可以去看pytorch_metric_learning這個庫有實現好的挖掘方法
比如輸入樣本為[Batch,Embedding],對應的標簽是[Batch,Class]
對這些樣本進行挖掘后得到以下三部分:

  1. Anchor :錨點樣本,其實就是和輸入的Batch一模一樣,
  2. Positive Sample : 挖掘的正正樣本
  3. Negtive Sample : 挖掘的負樣本
import torch
import torch.nn as nn 
import torchvision# 損失函數
class HibCriterion(nn.Module):def __init__(self):super().__init__()def forward(self, z_samples, alpha, beta, indices_tuple):n_samples = z_samples.shape[1]if len(indices_tuple) == 3:a, p, n = indices_tupleap = an = aelif len(indices_tuple) == 4:ap, p, an, n = indices_tuplealpha = torch.nn.functional.softplus(alpha)loss = 0for i in range(n_samples):z_i = z_samples[:, i, :]for j in range(n_samples):z_j = z_samples[:, j, :]prob_pos = torch.sigmoid(- alpha * torch.sum((z_i[ap] - z_j[p])**2, dim=1) + beta) + 1e-6prob_neg = torch.sigmoid(- alpha * torch.sum((z_i[an] - z_j[n])**2, dim=1) + beta) + 1e-6# maximize the probability of positive pairs and minimize the probability of negative pairsloss += -torch.log(prob_pos) - torch.log(1 - prob_neg)loss = loss / (n_samples ** 2)return loss.mean()def get_matches_and_diffs(labels):matches = (labels.float() @ labels.float().T).byte()diffs = matches ^ 1 # 異或運算得到負標簽的矩陣return matches, diffsdef get_all_triplets_indices_vectorized_method(all_matches, all_diffs):"""Args:all_matches (torch.Tensor): 相同標簽all_diffs (torch.Tensor): 不相同標簽Processing : all_matches.unsqueeze(2) -> [Batch,Batch,1]all_diffs.unsqeeeze(1) -> [Batch,1,Batch] Returns:torch.Tensor: _description_"""triplets = all_matches.unsqueeze(2) * all_diffs.unsqueeze(1)return torch.where(triplets)class TripletMinner(nn.Module):def __init__(self, *args, **kwargs) -> None:super().__init__(*args, **kwargs)self.sim_mat = get_matches_and_diffsself.selctor = get_all_triplets_indices_vectorized_methoddef forward(self,labels):a , b = self.sim_mat(labels)c = self.selctor(a,b)return c

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/15721.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/15721.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/15721.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

學習Uni-app開發小程序Day18

昨天學習了使用輪播顯示圖片和文字,輪播方式縱向和橫向。今天使用擴展組件和scroll-view顯示圖片,使用scroll-view的grid方式、插槽slot、自定義組件、磨砂背景定位布局做專題組件 這就是需要做成的效果,下面將一步一步的完成。 首先&#x…

如何高效創建與配置工程環境:零基礎入門

新書上架~👇全國包郵奧~ python實用小工具開發教程http://pythontoolsteach.com/3 歡迎關注我👆,收藏下次不迷路┗|`O′|┛ 嗷~~ 目錄 一、工程環境的搭建與準備 二、配置虛擬環境與選擇解釋器 三、編寫代碼與自動添加多行注釋 …

git describe --tags報錯 fatal: No names found, cannot describe anything.

文章目錄 git describe --tags報錯 fatal: No names found, cannot describe anything. git describe --tags報錯 fatal: No names found, cannot describe anything. 問題描述: git describe --tags fatal: No names found, cannot describe anything.原因分析&a…

SpringMVC筆記

一、SpringMVC 簡介 1.1 什么是 MVC MVC 是一種軟件架構的思想,將軟件按照模型、視圖、控制器來劃分 1.M:Model 模型層,指工程中的 JavaBean ,作用是處理數據 JavaBean 分為兩類 實體類Bean:專門存儲業務數據的…

C++vector的簡單模擬實現

文章目錄 目錄 文章目錄 前言 一、vector使用時的注意事項 1.typedef的類型 2.vector不是string 3.vector 4.算法sort 二、vector的實現 1.通過源碼進行猜測vector的結構 2.初步vector的構建 2.1 成員變量 2.2成員函數 2.2.1尾插和擴容 2.2.2operator[] 2.2.3 迭代器 2…

云存儲與云計算詳解

1. 云存儲與云計算概述 1.1 云存儲 云存儲(Cloud Storage)是指通過互聯網將數據存儲在遠程服務器上,用戶可以隨時隨地訪問和管理這些數據。云存儲的優點包括高可擴展性、靈活性和成本效益。 1.2 云計算 云計算(Cloud Computin…

前端 控制臺提示invalid date

如果你遇到了 "Invalid Date" 的錯誤,這通常意味著傳遞給 Date 構造函數的字符串或數值無法被解析為一個有效的日期。對于時間戳來說,確保它是一個有效的數字(表示自1970年1月1日00:00:00 UTC以來的毫秒數)。 以下是一…

Java如何設計一個功能

流程說明:實現一組功能的步驟 1,充分了解需求,包括所有的細節,需要知道要做一個什么樣的功能。 2,設計實體/表 正向工程:設計實體、映射文件 --> 建表 反向工程:設計表 --> 映射文件、實體 設計實體類型分析步驟: 1)功能模塊有幾個實體…

【Apache Doris】BE宕機問題排查指南

【Apache Doris】BE宕機問題排查指南 背景BE宕機分類如何判斷是BE進程是Crash還是OOMBE Crash 后如何排查BE OOM 后如何分析Cache 沒及時釋放導致BE OOM(2.0.3-rc04) 關于社區 作者|李淵淵 背景 在實際線上生產環境中,大家可能遇…

校園網撥號上網環境下多開虛擬機,實現宿主機與虛擬機互通,并訪問外部網絡

校園網某些登錄客戶端只允許同一時間一臺設備登錄,因此必須使用NAT模式共享宿主機的真實IP,相當于訪問外網時只使用宿主機IP,此方式通過虛擬網卡與物理網卡之間的數據轉發實現訪問外網及互通 經驗證,將centos的物理地址與主機物理…

有什么好用的語音翻譯軟件推薦?親測實用的語音翻譯工具來了

嘿,大家好!你們有沒有想過,現在世界這么“小”,我們跟不同國家的人打交道的機會越來越多了。 但是呢,語言不通真是個大問題。別擔心,現在有個超棒的解決方案——語音翻譯技術!這玩意兒能實時把…

Spring Cloud學習筆記(Nacos):配置中心基礎和代碼樣例

這是本人學習的總結,主要學習資料如下 - 馬士兵教育 1、Overview2、樣例2.1、Dependency2.2、配置文件的定位2.3、bootstrap.yml2.4、配置中心新增配置2.5、驗證 1、Overview 配置中心用于管理配置項和配置文件,比如平時寫的application.yml就是配置文件…

Python 遍歷字典的方法,你都掌握了嗎

Python中的字典是一種非常靈活的數據結構,它允許通過鍵來存儲和訪問值。在處理字典時,經常需要遍歷字典中的元素,以下是幾種常見的遍歷字典的方法。 1. 使用 for 循環直接遍歷字典的鍵 字典的鍵是唯一的,可以直接通過 for 循環來…

【Spring Security + OAuth2】OAuth2

Spring Security OAuth2 第一章 Spring Security 快速入門 第二章 Spring Security 自定義配置 第三章 Spring Security 前后端分離配置 第四章 Spring Security 身份認證 第五章 Spring Security 授權 第六章 OAuth2 文章目錄 Spring Security OAuth21、OAuth2簡介1.1、OAu…

call、apply和bind

call、apply和bind都是JavaScript中函數對象的方法,用于改變函數的this值。 call:call方法接收一個對象和一系列參數,并立即調用函數,將this值設置為提供的對象。例如: function greet(greeting, punctuation) {cons…

Linux驅動開發筆記(二) 基于字符設備驅動的I/O操作

文章目錄 前言一、設備驅動的作用與本質1. 驅動的作用2. 有無操作系統的區別 二、內存管理單元MMU三、相關函數1. ioremap( )2. iounmap( )3. class_create( )4. class_destroy( ) 四、GPIO的基本知識1. GPIO的寄存器進行讀寫操作流程2. 引腳復用2. 定義GPIO寄存器物理地址 五、…

【2024最新華為OD-C卷試題匯總】傳遞悄悄話的最長時間(100分) - 三語言AC題解(Python/Java/Cpp)

🍭 大家好這里是清隆學長 ,一枚熱愛算法的程序員 ? 本系列打算持續跟新華為OD-C卷的三語言AC題解 💻 ACM銀牌🥈| 多次AK大廠筆試 | 編程一對一輔導 👏 感謝大家的訂閱? 和 喜歡💗 文章目錄 前…

東哥一句兄弟,你還當真了?

關注盧松松,會經常給你分享一些我的經驗和觀點。 你還真把自己當劉強東兄弟了?誰跟你是兄弟了?你在國外的房子又不給我住,你出去旅游也不帶上我!都成人年了,東哥一句客套話,別當真! 今天,東哥在高管會上直言&…

mysql內存結構

一:邏輯存儲結構:表空間->段->區->頁->行、 表空間:一個mysql實例對應多個表空間,用于存儲記錄,索引等數據。 段:分為數據段,索引段,回滾段。innoDB是索引組織表&…