GenerationMixin概述

類名簡單說明
GenerateDecoderOnlyOutput繼承自 ModelOutput,適用于非束搜索方法的解碼器-only模型輸出類。
GenerateEncoderDecoderOutput繼承自 ModelOutput,適用于非束搜索方法的編碼器-解碼器模型輸出類。
GenerateBeamDecoderOnlyOutput繼承自 ModelOutput,適用于束搜索方法的解碼器-only模型輸出類。
GenerateBeamEncoderDecoderOutput繼承自 ModelOutput,適用于束搜索方法的編碼器-解碼器模型輸出類。
GreedySearchDecoderOnlyOutputGenerateDecoderOnlyOutput 相同,保留用于向后兼容的別名。
ContrastiveSearchDecoderOnlyOutputGenerateDecoderOnlyOutput 相同,保留用于向后兼容的別名。
SampleDecoderOnlyOutputGenerateDecoderOnlyOutput 相同,保留用于向后兼容的別名。
GreedySearchEncoderDecoderOutputGenerateEncoderDecoderOutput 相同,保留用于向后兼容的別名。
ContrastiveSearchEncoderDecoderOutputGenerateEncoderDecoderOutput 相同,保留用于向后兼容的別名。
SampleEncoderDecoderOutputGenerateEncoderDecoderOutput 相同,保留用于向后兼容的別名。
BeamSearchDecoderOnlyOutputGenerateBeamDecoderOnlyOutput 相同,保留用于向后兼容的別名。
BeamSampleDecoderOnlyOutputGenerateBeamDecoderOnlyOutput 相同,保留用于向后兼容的別名。
BeamSearchEncoderDecoderOutputGenerateBeamEncoderDecoderOutput 相同,保留用于向后兼容的別名。
BeamSampleEncoderDecoderOutputGenerateBeamEncoderDecoderOutput 相同,保留用于向后兼容的別名。
GreedySearchOutputGreedySearchEncoderDecoderOutputGreedySearchDecoderOnlyOutput 的聯合類型。
SampleOutputSampleEncoderDecoderOutputSampleDecoderOnlyOutput 的聯合類型。
BeamSearchOutputBeamSearchEncoderDecoderOutputBeamSearchDecoderOnlyOutput 的聯合類型。
BeamSampleOutputBeamSampleEncoderDecoderOutputBeamSampleDecoderOnlyOutput 的聯合類型。
ContrastiveSearchOutputContrastiveSearchEncoderDecoderOutputContrastiveSearchDecoderOnlyOutput 的聯合類型。
GenerateNonBeamOutputGenerateDecoderOnlyOutputGenerateEncoderDecoderOutput 的聯合類型。
GenerateBeamOutputGenerateBeamDecoderOnlyOutputGenerateBeamEncoderDecoderOutput 的聯合類型。
GenerateOutputGenerateNonBeamOutputGenerateBeamOutput 的聯合類型。
GenerationMixin包含自動回歸文本生成所有功能的類,可作為 PreTrainedModel 的 mixin 使用。
  • 定義了多個數據類(@dataclass),這些類繼承自 ModelOutput,用于表示生成模型在不同情況下的輸出結果。
    Python:@dataclass裝飾器

  • 定義了一些等價的類和類型簡寫(typing shortcuts),主要是為了兼容舊版本的代碼,也方便在代碼中進行類型提示。

重點解釋以下三個類:

  1. GenerateDecoderOnlyOutput

  2. GenerateEncoderDecoderOutput

  3. GenerateNonBeamOutput


1. GenerateDecoderOnlyOutput

描述:

GenerateDecoderOnlyOutput 是一個數據類,用于表示 僅解碼器模型(decoder-only models) 在使用 非束搜索方法(non-beam methods) 進行生成時的輸出結果。

主要用途:

此類主要用于像 GPT-2、GPT-3 等僅包含解碼器的模型,當它們使用貪婪搜索(Greedy Search)、隨機采樣(Sampling)、對比搜索(Contrastive Search)等非束搜索方法進行文本生成時,封裝和返回生成的結果。

2. GenerateEncoderDecoderOutput

描述:

GenerateEncoderDecoderOutput 是一個數據類,用于表示 編碼器-解碼器模型(encoder-decoder models) 在使用 非束搜索方法(non-beam methods) 進行生成時的輸出結果。

主要用途:

此類主要用于像 BART、T5 等包含編碼器和解碼器的模型,當它們使用貪婪搜索、隨機采樣、對比搜索等非束搜索方法進行文本生成時,封裝和返回生成的結果。

GenerateDecoderOnlyOutputGenerateEncoderDecoderOutput

字段名GenerateDecoderOnlyOutputGenerateEncoderDecoderOutput
sequences必填
torch.LongTensor
形狀:(batch_size, sequence_length)
生成的序列。
必填
torch.LongTensor
形狀:(batch_size * num_return_sequences, sequence_length)
生成的序列。
scores可選
Optional[Tuple[torch.FloatTensor]]
output_scores=True 時返回。
處理后的預測分數(每步)。
可選
Optional[Tuple[torch.FloatTensor]]
同左。
logits可選
Optional[Tuple[torch.FloatTensor]]
output_logits=True 時返回。
未經處理的預測分數(每步)。
可選
Optional[Tuple[torch.FloatTensor]]
同左。
attentions可選
Optional[Tuple[Tuple[torch.FloatTensor]]]
output_attentions=True 時返回。
解碼器每層的注意力權重。
名稱不同
GenerateEncoderDecoderOutput 中,該字段為 decoder_attentions
hidden_states可選
Optional[Tuple[Tuple[torch.FloatTensor]]]
output_hidden_states=True 時返回。
解碼器每層的隱藏狀態。
名稱不同
GenerateEncoderDecoderOutput 中,該字段為 decoder_hidden_states
past_key_values可選
Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]]
use_cache=True 時返回。
模型的緩存狀態。
可選
Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]]
同左。
encoder_attentions無此字段可選
Optional[Tuple[torch.FloatTensor]]
output_attentions=True 時返回。
編碼器每層的注意力權重。
encoder_hidden_states無此字段可選
Optional[Tuple[torch.FloatTensor]]
output_hidden_states=True 時返回。
編碼器每層的隱藏狀態。
decoder_attentions無此字段
對應于 attentions 字段。
可選
Optional[Tuple[Tuple[torch.FloatTensor]]]
解碼器每層的注意力權重。
decoder_hidden_states無此字段
對應于 hidden_states 字段。
可選
Optional[Tuple[Tuple[torch.FloatTensor]]]
解碼器每層的隱藏狀態。
cross_attentions無此字段可選
Optional[Tuple[Tuple[torch.FloatTensor]]]
output_attentions=True 時返回。
解碼器每層的跨注意力權重。

字段詳解

共有字段
  • sequences

    • 描述:生成的序列。
    • GenerateDecoderOnlyOutput:形狀為 (batch_size, sequence_length)
    • GenerateEncoderDecoderOutput:形狀為 (batch_size * num_return_sequences, sequence_length)
  • scores

    • 描述:處理后的預測分數(即在 SoftMax 之前的 logits),每步生成一個。
    • 類型Optional[Tuple[torch.FloatTensor]]
    • 返回條件output_scores=True
  • logits

    • 描述:未經處理的預測分數(logits),每步生成一個。
    • 類型Optional[Tuple[torch.FloatTensor]]
    • 返回條件output_logits=True
  • past_key_values

    • 描述:模型的緩存狀態,用于加速解碼。
    • 類型Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]]
    • 返回條件use_cache=True
僅在 GenerateDecoderOnlyOutput
  • attentions

    • 描述:解碼器的注意力權重。
    • 類型Optional[Tuple[Tuple[torch.FloatTensor]]]
    • 返回條件output_attentions=True
  • hidden_states

    • 描述:解碼器的隱藏狀態。
    • 類型Optional[Tuple[Tuple[torch.FloatTensor]]]
    • 返回條件output_hidden_states=True
僅在 GenerateEncoderDecoderOutput
  • encoder_attentions

    • 描述:編碼器的注意力權重。
    • 類型Optional[Tuple[torch.FloatTensor]]
    • 返回條件output_attentions=True
  • encoder_hidden_states

    • 描述:編碼器的隱藏狀態。
    • 類型Optional[Tuple[torch.FloatTensor]]
    • 返回條件output_hidden_states=True
  • decoder_attentions

    • 描述:解碼器的注意力權重(相當于 GenerateDecoderOnlyOutput 中的 attentions)。
    • 類型Optional[Tuple[Tuple[torch.FloatTensor]]]
    • 返回條件output_attentions=True
  • decoder_hidden_states

    • 描述:解碼器的隱藏狀態(相當于 GenerateDecoderOnlyOutput 中的 hidden_states)。
    • 類型Optional[Tuple[Tuple[torch.FloatTensor]]]
    • 返回條件output_hidden_states=True
  • cross_attentions

    • 描述:解碼器的跨注意力權重(解碼器與編碼器之間的注意力)。
    • 類型Optional[Tuple[Tuple[torch.FloatTensor]]]
    • 返回條件output_attentions=True

3. GenerateNonBeamOutput

描述:

GenerateNonBeamOutput 是一個類型別名,用于表示在使用 非束搜索方法(non-beam methods) 進行生成時,模型的輸出結果。

定義:
GenerateNonBeamOutput = Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput]
含義:
  • 它可以是 GenerateDecoderOnlyOutput 類型,也可以是 GenerateEncoderDecoderOutput 類型。

  • 這個類型別名的存在,使得在處理非束搜索生成輸出時,可以統一處理,不用區分模型是僅解碼器模型還是編碼器-解碼器模型。


附加說明:

  • 非束搜索方法(Non-beam methods)

    指在生成文本時,不使用束搜索(Beam Search)算法的生成方法,例如貪婪搜索、隨機采樣、對比搜索等。這些方法通常速度更快,但可能生成的結果質量不如束搜索。

  • 緩存機制(Past Key Values)

    在生成長序列時,模型可以緩存之前計算的鍵和值,以避免重復計算,提高生成效率。緩存的內容和格式因模型而異。


GenerationMixin

以下是對 GenerationMixin 類中各個方法和屬性的整理,包括分類和功能描述:

類別名稱功能描述
靜態方法_expand_inputs_for_generation擴展輸入以用于生成,將張量從 [batch_size, ...] 擴展為 [batch_size * expand_size, ...]
實例方法prepare_inputs_for_generation準備生成所需的模型輸入,包括計算注意力掩碼或根據緩存裁剪輸入等操作。
實例方法_prepare_model_inputs提取用于生成的模型特定輸入。
實例方法_maybe_initialize_input_ids_for_generation在必要時初始化用于生成的 input_ids
實例方法_prepare_attention_mask_for_generation為生成準備注意力掩碼。
實例方法_prepare_encoder_decoder_kwargs_for_generation在生成期間為編碼器-解碼器模型準備 kwargs
實例方法_prepare_decoder_input_ids_for_generation為編碼器-解碼器模型準備用于生成的 decoder_input_ids
實例方法_update_model_kwargs_for_generation更新下一步生成所需的 model_kwargs
實例方法_reorder_cache重新排序緩存,需要在子類中實現,以適應 beam search 等方法。
實例方法_get_candidate_generator返回在輔助生成中使用的候選生成器。
實例方法_get_logits_processor返回一個 LogitsProcessorList,其中包含所有用于修改分數的相關 LogitsProcessor 實例。
實例方法_get_stopping_criteria返回用于生成的 StoppingCriteriaList,包括各種停止條件。
實例方法_merge_criteria_processor_list合并默認和自定義的 criteria 或 processor 列表。
實例方法compute_transition_scores根據生成的得分計算序列的轉移得分。
實例方法_validate_model_class驗證模型類是否兼容生成操作。
實例方法_validate_assistant驗證輔助模型(如果提供)是否兼容和正確配置。
實例方法_validate_model_kwargs驗證用于生成的 model_kwargs 參數。
實例方法_validate_generated_length執行與生成長度相關的驗證,確保參數設置正確。
實例方法_prepare_generated_length在生成配置中準備最大和最小長度,避免參數沖突。
實例方法_prepare_generation_config準備基礎的生成配置,并應用來自 kwargs 的任何選項。
實例方法_get_initial_cache_position計算預填充階段的 cache_position
實例方法_get_cache根據參數為生成設置緩存。
實例方法_supports_default_dynamic_cache返回模型是否支持將 DynamicCache 實例作為 past_key_values
實例方法_prepare_cache_for_generation為生成準備緩存,并將其寫入 model_kwargs
實例方法_supports_logits_to_keep返回模型是否支持 logits_to_keep 參數,用于節省內存。
實例方法_prepare_special_tokens為生成準備特殊的 tokens,如 bos_token_ideos_token_id 等。
實例方法generate為具有語言模型頭的模型生成 token 序列,是生成過程的主要入口方法。
實例方法_has_unfinished_sequences檢查設備中是否仍然存在未完成的序列,用于確定是否繼續生成循環。
實例方法heal_tokens生成 token 序列,其中每個序列的尾部 token 替換為適當的擴展,用于修復不完整的 token。
實例方法_dola_decoding使用 DoLa 解碼生成序列,一種改進生成質量的解碼策略。
實例方法_contrastive_search使用對比搜索生成序列,旨在改善生成文本的質量和多樣性。
實例方法_sample使用多項式采樣生成序列,可以實現隨機性和多樣性。
實例方法_temporary_reorder_cache臨時函數,用于處理不同類型的緩存重新排序。
實例方法_beam_search使用 beam search 解碼生成序列,支持高質量的序列生成。
實例方法_group_beam_search使用分組 beam search 解碼生成序列,引入多樣性。
實例方法_constrained_beam_search使用受限 beam search 解碼生成序列,支持強制包含特定詞語等約束。
實例方法_assisted_decoding使用輔助解碼生成序列,利用輔助模型加速和改善生成過程。

GenerationMixin中最核心的方法是generate方法,其它方法都是generate方法的輔助方法:
GenerationMixin:generate

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/75760.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/75760.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/75760.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【備賽】藍橋杯嵌入式實現led閃爍

原理 由于藍橋杯的板子帶有鎖存器,并且與lcd屏幕有沖突,所以這個就成了考點。 主要就是用定時器來實現,同時也要兼顧lcd的沖突。 一、處理LCD函數 首先來解決與lcd屏幕沖突的問題,把我們所有用到的lcd函數改裝一下。 以下是基…

C++ 并發性能優化實戰:提升多線程應用的效率與穩定性

🧑 博主簡介:CSDN博客專家、CSDN平臺優質創作者,獲得2024年博客之星榮譽證書,高級開發工程師,數學專業,擁有高級工程師證書;擅長C/C、C#等開發語言,熟悉Java常用開發技術&#xff0c…

Python----計算機視覺處理(Opencv:道路檢測之車道線擬合)

完整版: Python----計算機視覺處理(Opencv:道路檢測完整版:透視變換,提取車道線,車道線擬合,車道線顯示) 一、獲取左右車道線的原始位置 導入模塊 import cv2 import numpy as np from matplot…

優選算法的妙思之流:分治——歸并專題

專欄:算法的魔法世界 個人主頁:手握風云 目錄 一、歸并排序 二、例題講解 2.1. 排序數組 2.2. 交易逆序對的總數 2.3. 計算右側小于當前元素的個數 2.4. 翻轉對 一、歸并排序 歸并排序也是采用了分治的思想,將數組劃分為多個長度為1的子…

C語言查漏補缺:基礎篇

1.原理 C語言是一門編譯型計算機語言,要編寫C代碼,C源代碼文本文件本身無法直接執行,必須通過編譯器翻譯和鏈接器的鏈接,生成二進制的可執行文件,然后才能執行。這里的二進制的可執行文件就是我們最終要形成的可執行程…

TPS入門DAY02 服務器篇

1.創建空白插件 2.導入在線子系統以及在線steam子系統庫 MultiplayerSessions.uplugin MultiplayerSessions.Build.cs 3.創建游戲實例以及初始化會話創建流程 創建會話需要的函數,委托,委托綁定的回調,在線子系統接口綁定某一個委托的控制其…

產品經理課程

原型工具 一、土耳其機器人 這個說法來源于 1770 年出現的一個騙局,一個叫沃爾夫岡馮肯佩倫(Wolfgang von Kempelen)的人為了取悅奧地利女皇瑪麗婭特蕾莎(Maria Theresia),“制造”了一個會下國際象棋的機…

nginx中的limit_req 和 limit_conn

在 Nginx 中,limit_req 和 limit_conn 是兩個用于限制客戶端請求的指令,它們分別用于限制請求速率和并發連接數。 limit_req limit_req 用于限制請求速率,防止客戶端發送過多請求影響服務器性能。它通過 limit_req_zone 指令定義一個共享內存…

基于winform的串口調試助手

目錄 一、串口助手界面設計 1.1 串口配置 1.2 接收配置 1.3 發送配置 1.4 接收窗口和發送窗口 1.5 狀態顯示窗口 1.6 串口通訊控件 二、程序編寫 2.1 端口號自動識別并顯示在端口號下拉框 功能說明: 2.2 波特率下拉框顯示 2.3 數據位下拉框顯示 2.4 校…

Docker基礎2

如需轉載,標記出處 本次我們將下載一個 Docker 鏡像,從鏡像中啟動容器 上一章,安裝 Docker 時,獲得兩個主要組件: Docker 客戶端 Docker 守護進程(有時稱為“服務器”或“引擎”) 守護進程實…

Rocketmq2

一、生產者端防丟失 1. 發送方式選擇 同步發送:使用 send() 方法,等待 Broker 確認響應(SendResult),確保消息已成功發送。異步發送:使用 sendAsync() 方法并設置回調函數,處理發送成功 / 失敗…

RabbitMQ詳解,RabbitMQ是什么?架構是怎樣的?

目錄 一,RabbitMQ是什么? 二,RabbitMQ架構 2.1 首先我們來看下RabbitMQ里面的心概念Queue是什么? 2.2 交換器Exchange 2.3 RabbitMQ是什么? 2.4 重點看下優先級隊列是什么? 三,RabbitMQ集群 3.1 普通集群模式 3.2 鏡像隊列集群 一,RabbitMQ是什么? 假設我們程序…

【一步步開發AI運動APP】六、運動計時計數能調用

之前我們為您分享了【一步步開發AI運動小程序】開發系列博文,通過該系列博文,很多開發者開發出了很多精美的AI健身、線上運動賽事、AI學生體測、美體、康復鍛煉等應用場景的AI運動小程序;為了幫助開發者繼續深耕AI運動領域市場,今…

MySQL——DQL的多表查詢

一、交叉連接 標準語法:select * from 表1 cross join 表2 where 表1.公共列 表2.公共列; 簡單語法:select * from 表1 , 表2 where 表1.公共列 表2.公共列; 公共列:兩張表具有相同含義的列,不是列名一樣。 …

【Linux內核】如何更加優雅閱讀Linux內核源碼(vscode)

1. 前言 因為已經習慣在Ubuntu下進行嵌入式工作開發,但Linux源碼在Source Insight下進行閱讀,一直很苦惱Linux/Windows來回切換的開發方式,當前發現可以通過 vscode clangd(擴展組件) 方式進行更好的內核源碼閱讀。 2. 環境 操作系統&…

21.OpenCV獲取圖像輪廓信息

OpenCV獲取圖像輪廓信息 在計算機視覺領域,識別和分析圖像中的對象形狀是一項基本任務。OpenCV 庫提供了一個強大的工具——輪廓檢測(Contour Detection),它能夠幫助我們精確地定位對象的邊界。這篇博文將帶你入門 OpenCV 的輪廓…

LETTERS(DFS)

【題目描述】 給出一個rowcolrowcol的大寫字母矩陣,一開始的位置為左上角,你可以向上下左右四個方向移動,并且不能移向曾經經過的字母。問最多可以經過幾個字母。 【輸入】 第一行,輸入字母矩陣行數RR和列數SS,1≤R,S≤…

Day2-2:前端項目uniapp壁紙實戰

再在wallpaper新建一個目錄components 在components下新建組件common-title 記得點擊創建同名目錄 在index加 <view class"select"><common-title></common-title></view> 圖片換了下&#xff0c;原來的有點丑&#xff0c;圖片可按自己喜歡…

其他 vector 操作詳解(四十)

介紹 除去向 vector 添加元素&#xff08;如 push_back&#xff09;之外&#xff0c;vector 還提供了許多其他操作&#xff0c;這些操作大多與 string 的操作類似。通過掌握這些操作&#xff0c;我們可以方便地查詢、修改和比較 vector 中的元素&#xff0c;從而構建靈活、高效…

【Leetcode 每日一題】368. 最大整除子集

問題背景 給你一個由 無重復 正整數組成的集合 n u m s nums nums&#xff0c;請你找出并返回其中最大的整除子集 a n s w e r answer answer&#xff0c;子集中每一元素對 ( a n s w e r [ i ] , a n s w e r [ j ] ) (answer[i], answer[j]) (answer[i],answer[j]) 都應當…