【Next Token Prediction】VLM模型訓練中數據集標簽預處理詳解

源代碼來自:https://github.com/huggingface/nanoVLM/blob/main/data/collators.py

詳解如下所示:

import torch#-------------------------------#
# 主要是在數據加載器的構建中被使用
#-------------------------------#class BaseCollator(object):def __init__(self, tokenizer):self.tokenizer               = tokenizerrandom_string_5_letters      = "xzyvd" # 作為“錨點”,查找它在模板化后的完整文本中的位置# 將輸入消息轉換成Chat模板格式的字符串 例如 "<|start|>assistant\nxzyvd<|end|>" 此為純文本而不是被編碼后得到的token idsrandom_string_chat_templated = self.tokenizer.apply_chat_template([{"role": "assistant", "content": random_string_5_letters}], tokenize=False, add_special_tokens=False)random_string_location       = random_string_chat_templated.find(random_string_5_letters) # 查找我們之前插入的“隨機標記”出現的位置# 例如回復為<|start|>assistant\nxzyvd<|end|># 獲取到nxzyvd開始后的位置, 然后從而獲取到前綴的長度# 目的是在后續設置loss_mask時能夠精準跳過模板前綴,只對assistant回復的實際內容進行監督self.prefix_len              = len(self.tokenizer.encode(random_string_chat_templated[:random_string_location])) # 找到前綴模板結束的位置#----------------------------------------------------------## 用于處理批量對話消息# 隨后返回模型需要的token ids、attention mask以及loss mask# 1.將消息轉換為模型所需的 token 格式# 2.根據消息中的role(例如 assistant)標記哪些token需要計算損失(loss_mask),即只對assistant的具體輸出進行損失計算,而不對user的內容進行計算# 3.將所有輸入統一padding到最大長度max_len,確保批次的輸入大小一致#----------------------------------------------------------#def prepare_inputs_and_loss_mask(self, batched_messages, max_length=None):batch_token_ids: list[list[int]]  = [] # 保存每個批次消息的token idsbatch_masks:     list[list[int]]  = [] # 保存每個批次消息的loss_mask,即哪些token需要計算損失batch_attentions: list[list[int]] = [] # 保存每個批次消息的attention mask,模型用來指示哪些部分是有效輸入,哪些是 paddingfor messages in batched_messages: # 每一條消息中都包含若干user和assistant的內容#---------------------------------------------------------------------------------------## 對于此處生成的attention mask# tokenizer會自動將padding部分的attention mask設為0,其余為1# 其作用為告訴模型哪些token是“真正需要注意的內容”,哪些只是為了湊長度而padding的垃圾位# 它是Transformer中注意力機制不可或缺的一部分,尤其在處理變長輸入(如自然語言對話)時非常關鍵# NOTE:此處,tokenizer沒有做統一長度 padding,而是保留了變長的attention_mask#---------------------------------------------------------------------------------------#conv_ids = self.tokenizer.apply_chat_template(messages,tokenize=True, # 控制attention mask相關內容add_special_tokens=False,return_dict=True,) # conv_ids是面向整個對話的一個字典,包含了對應的 input_ids(token ids)和 attention_maskmask   = [0] * len(conv_ids["input_ids"]) # 為每個對話消息初始化一個全零的 mask 列表# Locate each assistant turn and flip its mask to 1cursor = 0 # 用來記錄當前已經處理過的token數量for msg in messages: # 對user與assistant的內容均進行處理segment_ids = self.tokenizer.apply_chat_template([msg], tokenize=True, add_special_tokens=False) # 將每條消息msg轉換為token ids # 只包含這一條消息的內容seg_len = len(segment_ids) # 獲取消息的長度, 即為每條消息的實際token數目#---------------------------------------## 當處理角色為assistant的時候展開下述操作# 只對其具體回復的內容進行操作#---------------------------------------#if msg["role"] == "assistant":start = cursor + self.prefix_len # 確定消息的起點end   = cursor + seg_len         # 根據消息的長度去確定終點mask[start:end] = [1] * (end - start)  # attend to these tokens # 將assistant的回復部分的mask設置為1cursor += seg_len # 因為一組對話中assistant回復的內容可能有多處, 因此需要進行累積batch_token_ids.append(conv_ids["input_ids"]) # token idsbatch_masks.append(mask) # 哪些token需要去計算batch_attentions.append(conv_ids["attention_mask"]) # 哪些部分是有效輸入# NOTE:主要針對assistant回復過長的情況進行處理if max_length is not None:  # We need to keep the tokens to allow for the img embed replacing logic to work. Otherwise, we would need to track which images correspond to long samples.batch_token_ids  = [ids[:max_length] for ids in batch_token_ids] # 對超過max length的樣本進行裁剪, 使其長度滿足要求# 如果長度超過 max_length,則將其截斷為全零的 mask(表示忽略該樣本)batch_masks      = [m if len(m) <= max_length else [0]*max_length for m in batch_masks] # Ignore samples that are longer than max_lengthbatch_attentions = [a[:max_length] for a in batch_attentions] # 同樣進行截取# Pad samples to max lengthif max_length is not None:max_len = max_lengthelse:max_len = max(map(len, batch_token_ids))# 對每個樣本均展開padding操作batch_token_ids  = [[self.tokenizer.pad_token_id]*(max_len-len(ids)) + ids for ids in batch_token_ids] # 使用pad_token_id將長度填充到max lengthbatch_masks      = [[0]*(max_len-len(m)) + m         for m   in batch_masks]                           # 填充至最大長度max_len,使用0填充batch_attentions = [[0]*(max_len-len(a)) + a         for a   in batch_attentions]                      # 填充至最大長度max_len,使用0填充 # NOTE: 相當于是在tokenzier的基礎上 根據max length去展開補充性paddingreturn torch.tensor(batch_token_ids), torch.tensor(batch_attentions), torch.tensor(batch_masks).to(torch.bool)#-------------------------------------#
# Visual Question Answering Collator
# 訓練與驗證數據集
#-------------------------------------#
class VQACollator(BaseCollator):def __init__(self, tokenizer, max_length):self.max_length  = max_lengthsuper().__init__(tokenizer)def __call__(self, batch):images           = [item["images"] for item in batch]messages_batched = [item["text_data"] for item in batch]# Stack imagesimgs   = [img for sublist in images for img in sublist]images = torch.stack(imgs)# Create inputs by concatenating special image tokens, question, and answerbatch_input_ids, batch_attention_mask, loss_masks = self.prepare_inputs_and_loss_mask(messages_batched, max_length=self.max_length)#--------------------------------------------------------------------------------------------------------------------------------------------------------------------------## Create labels where only answer tokens are predicted# 1. 首先將模型回復的內容全部復制一份出來, 然后將為mask為0的區域全部填充為-100, 表明直接忽視不參與計算# 2. 為適應因果語言建模, 展開標簽平移操作, 作用為確保模型在展開語言生成任務時, 能夠預測當前時間步的下一個token# 具體而言, labels[:, :-1]為選擇每個樣本的所有token中除去最后一個token的部分, labels[:, 1:]為獲取每個樣本中從第二個token到最后一個token的所有內容# 這樣就可以將每個樣本的所有token都可以向左移動一位, 從而將每個位置對應的token都用它的下一個token去進行預測。這樣每個token的標簽都變成了它的下一個token, 即為next token prediction# 3. 這樣最后一個token由于沒有標簽目標, 直接設置為-100即可, 表明到了結尾# 例子:# batch_input_ids為[[101, 2001, 2023, 2045, 102]], 其中2001處的loss mask為0, 那么labels即為[[101, 2023, 2045, 102]]# 然后第一個樣本的0 1 2 3四個位置上對應的label即變為[2023, 2045, 102, -100]# 這樣就形成了真值標簽[[2023, 2045, 102, -100]]#--------------------------------------------------------------------------------------------------------------------------------------------------------------------------#labels         = batch_input_ids.clone().masked_fill(~loss_masks, -100) # 將~loss_masks為1的地方填充為-100 NOTE:此處相當于就是無效的地方labels[:, :-1] = labels[:, 1:] # Shift labels for causal LMlabels[:, -1]  = -100 # Last token has no targetreturn {"image": images, # 圖像"input_ids": batch_input_ids, # 輸入內容"attention_mask": batch_attention_mask, # 告訴模型在等長序列中, 哪些是需要關注的實際token, 哪些是padding token"labels": labels, #標簽}#--------------------------------------------------------#
# 測試數據集
# https://huggingface.co/datasets/Lin-Chen/MMStar
#--------------------------------------------------------#
class MMStarCollator(BaseCollator): def __init__(self, tokenizer):super().__init__(tokenizer)def __call__(self, batch):images           = [item["image"] for item in batch]messages_batched = [item["text_data"] for item in batch]# Stack imagesimages = torch.stack(images)# Create inputs by concatenating special image tokens, question, and answerbatch_input_ids, batch_attention_mask, loss_masks = self.prepare_inputs_and_loss_mask(messages_batched)#---------------------------------------------------------------------------------------------------------------------------------------------## 1. 把需要預測的位置(即 loss_masks=1)設成pad token, 這意味著這些位置不會被送去模型作為“輸入”,因為它們是模型需要生成的內容# 2. 把要預測的部分在attention mask里屏蔽掉, 導致模型不會“看到”這些 token,符合推理階段的auto-regressive decoding 邏輯# 3. 只保留需要預測的token作為標簽,其余地方用pad填充#---------------------------------------------------------------------------------------------------------------------------------------------#"""example:query: "User: What color is the sky?\nAssistant: The sky is"prediction: "blue."那么 loss_mask 會標記 "blue." 這一段, collator就會:把 input_ids 中 "blue." 變成pad(輸入時忽略)把 attention_mask 中對應位置設為0(不關注)把 labels 中 "blue." 保留, 其余是pad(只評估藍天這個詞)"""input_ids      = batch_input_ids.masked_fill(loss_masks, self.tokenizer.pad_token_id)attention_mask = batch_attention_mask.masked_fill(loss_masks, 0)labels         = batch_input_ids.clone().masked_fill(~loss_masks, self.tokenizer.pad_token_id)return {"images": images,"input_ids": input_ids,"attention_mask": attention_mask,"labels": labels,}

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/88907.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/88907.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/88907.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Istio 簡介

Istio 簡介 什么是 Istio Istio 是一個開源的 服務網格&#xff08;Service Mesh&#xff09; 框架&#xff0c;由 Google、IBM 和 Lyft 聯合開發&#xff0c;目前屬于 CNCF&#xff08;云原生計算基金會&#xff09;項目。它主要用于管理和連接微服務架構中的服務&#xff0…

融云在華為開發者大會分享智能辦公平臺的鴻蒙化探索實踐

6 月 20 日-22 日&#xff0c;“華為開發者大會&#xff08;HDC 2025&#xff09;”在東莞隆重召開&#xff0c;融云受邀出席并在“政企內部應用論壇”發表主旨演講。 鴻蒙為千行百業的生態伙伴創新帶來了獨特的歷史機遇&#xff0c;其蓬勃發展也為我國數字經濟高質量發展提供…

滾珠導軌如何助力自動化生產實現高質量輸出?

在自動化生產線的蓬勃發展中&#xff0c;高效、精準與穩定是核心追求。滾珠導軌作為關鍵的傳動部件&#xff0c;以其獨特的優勢&#xff0c;在眾多自動化生產場景里大放異彩&#xff0c;為生產流程的優化和產品質量的提升顯著提高設備系統的穩定性和可靠性。 汽車自動化裝配線 …

消息隊列的推拉模式詳解:實現原理與代碼實戰

消息隊列是現代分布式系統中不可或缺的中間件&#xff0c;它通過"生產者-消費者"模式實現了系統間的解耦和異步通信。本文將深入探討消息隊列中的兩種核心消息傳遞模式&#xff1a;推送(Push)和拉取(Pull)&#xff0c;并通過代碼示例展示它們的實現方式。 目錄 消息…

OpenCV圖像噪點消除五大濾波方法

在數字圖像處理中&#xff0c;噪點消除是提高圖像質量的關鍵步驟。本文將基于OpenCV庫&#xff0c;詳細講解五種經典的圖像去噪濾波方法&#xff1a;均值濾波、方框濾波、高斯濾波、中值濾波和雙邊濾波&#xff0c;并通過豐富的代碼示例展示它們的實際應用效果。 一、圖像噪點…

Rust宏和普通函數的區別

Rust 中的宏&#xff08;macro&#xff09;和普通函數有以下核心區別&#xff0c;分別從用途、擴展方式、性能影響和語法特征等多個方面來解釋&#xff1a; &#x1f4cc; 1. 定義方式 項目宏函數定義方式macro_rules! 或 macro&#xff08;新版&#xff09;fn 關鍵字調用方式…

基于Qt C++的影像重采樣批處理工具設計與實現

摘要 本文介紹了一種基于Qt C++框架開發的高效影像重采樣批處理工具。該工具支持按分辨率(DPI) 和按縮放倍率兩種重采樣模式,提供多種插值算法選擇,具備強大的批量處理能力和直觀的用戶界面。工具實現了影像處理的自動化流程,顯著提高了圖像處理效率,特別適用于遙感影像處…

TypeScript 中的 WebSocket 入門

如何開始使用 Typescript 和 React 中的 WebSockets 創建一個簡單的聊天應用程序 示例源碼&#xff1a;ws 下一篇&#xff1a;https://blog.csdn.net/hefeng_aspnet/article/details/148898147 介紹 WebSocket 是一項我目前還沒有在工作中使用過的技術&#xff0c;但我知道…

TMS汽車熱管理系統HILRCP解決方案

TMS汽車熱管理系統介紹 隨著汽車電動化和智能化的發展&#xff0c;整車能量管理內容增多&#xff0c;對汽車能量管理的要求也越來越高&#xff0c;從整車層面出發對各子系統進行能量統籌管理將成為電動汽車未來的發展趨勢&#xff0c;其中汽車熱管理是整車能量管理的重要組成部…

CCleaner Pro v6.29.11342 綠色便攜版

CCleaner Pro v6.29.11342 綠色便攜版 CCleaner是Piriform&#xff08;梨子公司&#xff09;最著名廣受好評的系統清理優化及隱私保護軟件&#xff0c;也是該公司主打和首發產品&#xff0c;它體積小、掃描速度快&#xff0c;具有強大的自定義清理規則擴展能力。CCleaner是一款…

不做手機控APP:戒掉手機癮,找回專注與自律

在當今數字化時代&#xff0c;手機已經成為我們生活中不可或缺的一部分。然而&#xff0c;過度依賴手機不僅會分散我們的注意力&#xff0c;影響學習和工作效率&#xff0c;還可能對身心健康造成負面影響。為了幫助用戶擺脫手機依賴&#xff0c;重拾自律和專注&#xff0c;一款…

Go 語言中的接口

1、接口與鴨子類型 在 Go 語言中&#xff0c;接口&#xff08;interface&#xff09;是一個核心且至關重要的概念。它為構建靈活、可擴展的軟件提供了堅實的基礎。要深入理解 Go 的接口&#xff0c;我們必須首先了解一個在動態語言中非常普遍的設計哲學——鴨子類型&#xff0…

在項目中如何巧妙使用緩存

緩存 對于經常訪問的數據&#xff0c;每次都從數據庫&#xff08;硬盤&#xff09;中獲取是比較慢&#xff0c;可以利用性能更高的存儲來提高系統響應速度&#xff0c;俗稱緩存 。合理使用緩存可以顯著降低數據庫的壓力、提高系統性能。 那么&#xff0c;什么樣的數據適合緩存…

SLAM中的非線性優化-2D圖優化之零空間(十五)

這節在進行講解SLAM中一個重要概念&#xff0c;零空間&#xff0c;講它有啥用呢&#xff1f;因為SLAM中零空間的存在&#xff0c;才需要FEJ或固定約束存在&#xff0c;本節內容不屬于2D圖優化獨有&#xff0c;先看看什么是零空間概念&#xff1b;零空間是一個核心概念&#xff…

如何解決本地DNS解析失敗問題?以連接AWS ElastiCache Redis為例

在云服務開發中,DNS解析問題常常成為困擾開發者的隱形障礙。本文將通過AWS ElastiCache Redis連接失敗的實際案例,詳細介紹如何診斷和解決DNS解析問題,幫助你快速恢復服務連接。 引言 在使用 telnet 或 redis-cli 連接 AWS ElastiCache Redis 時,有時會遇到類似以下錯誤:…

探索釘釘生態中的宜搭:創建與分享應用的新視界

在當今快速發展的數字化時代&#xff0c;企業對于高效協作和信息管理的需求日益增長。作為阿里巴巴集團旗下的智能工作平臺&#xff0c;釘釘不僅為企業提供了強大的溝通工具&#xff0c;其開放的生態系統也為用戶帶來了無限可能。其中&#xff0c;宜搭&#xff08;YiDa&#xf…

深入理解事務和MVCC

文章目錄 事務定義并發事務代碼實現 MVCC定義核心機制 事務 定義 什么是事務&#xff1f; 事務是指一組操作要么全部成功&#xff0c;要么全部失敗的執行單位。 在數據庫中&#xff0c;一個事務通常包含一組SQL語句&#xff0c;系統保證這些語句作為一個整體執行。 為什么引…

用 Python 繪制精美雷達圖:多維度材料屬性對比可視化全指南

&#x1f31f; 為什么選擇雷達圖&#xff1f;從材料科學到多維數據對比的可視化利器 在科研和數據分析領域&#xff0c;當我們需要同時展示多個維度的數據對比時&#xff0c;傳統的柱狀圖或折線圖往往顯得力不從心。這時候&#xff0c;雷達圖&#xff08;Radar Chart&#xff…

Excel學習03

超級表與圖表 Excel中具有超級表的功能。所謂超級表&#xff08;官方名稱為“表格”&#xff0c;快捷鍵CtrlT&#xff09;是Excel中一個強大的數據管理工具&#xff0c;它將普通的數據區域轉換為具有只能功能的交互式表格。 這就是表格變為超級表的樣子。超級表默認具備凍結窗…

Netflix 網飛的架構演進過程、Java在網飛中的應用|圖解

寫在前面 上一篇文章中&#xff0c;我們講解了網飛當前的架構&#xff0c;但網飛的架構并不是一開始就是這樣的&#xff0c;而是不斷演進發展才是當前的樣子。 這篇文章我們就來講講網飛架構的演進過程。 第一階段&#xff1a;Zuul Gateway REST API 使用 Zuul 作為API網關…