對大模型輸出的 logits 進行處理,從而控制文本的生成
flyfish
在文本生成任務中,模型輸出的 logits
代表了每個詞被選為下一個生成詞的未歸一化概率得分。通過對 logits
進行處理,可以精確地控制文本的生成
基本原理
在每一步生成過程中,模型會輸出一個 logits
向量,其長度等于詞匯表的大小,每個元素對應詞匯表中一個詞的得分。通常,會對 logits
應用 softmax
函數將其轉換為概率分布,然后根據這個概率分布來選擇下一個生成的詞。而 logits_processor
就是在應用 softmax
函數之前,對 logits
進行修改,從而改變最終的概率分布和詞的選擇。
具體控制方式
1. 避免重復
- 重復懲罰(
RepetitionPenaltyLogitsProcessor
)- 機制:對于已經在生成文本中出現過的詞,降低其
logits
的值。具體來說,會將這些詞的logits
除以一個大于 1 的懲罰系數,使得它們在后續生成中被選中的概率降低。 - 示例:假設生成的文本中已經出現了“蘋果”這個詞,當模型再次預測下一個詞時,“蘋果”對應的
logits
會被懲罰,從而減少再次生成“蘋果”的可能性,避免文本中出現過多重復內容。
- 機制:對于已經在生成文本中出現過的詞,降低其
- 不重復 n - gram(
NoRepeatNGramLogitsProcessor
)- 機制:檢查生成的文本中是否已經存在某個 n - gram(連續的 n 個詞),如果存在,則將可能導致該 n - gram 重復出現的詞的
logits
設為負無窮。這樣,在后續的概率計算中,這些詞的概率會變為 0,不會被選中。 - 示例:如果 n = 2,當前生成的文本是“我 喜歡”,那么在選擇下一個詞時,會避免選擇那些會導致“我 喜歡”這個 2 - gram 重復出現的詞,如“我”或“喜歡”,從而提高文本的多樣性。
- 機制:檢查生成的文本中是否已經存在某個 n - gram(連續的 n 個詞),如果存在,則將可能導致該 n - gram 重復出現的詞的
2. 控制生成長度
- 最小長度限制(
MinLengthLogitsProcessor
)- 機制:在生成的文本長度未達到指定的最小長度之前,將結束標記(EOS)的
logits
設為負無窮。這樣,在softmax
處理后,結束標記的概率會變為 0,模型不會選擇結束生成,確保文本達到一定的長度。 - 示例:如果設置最小長度為 10 個詞,在生成的詞數小于 10 時,結束標記的
logits
始終為負無窮,模型會繼續生成,直到達到最小長度要求。
- 機制:在生成的文本長度未達到指定的最小長度之前,將結束標記(EOS)的
- 最小新標記數(
MinNewTokensLengthLogitsProcessor
)- 機制:類似于最小長度限制,不過是針對新生成的標記數量。在新生成的標記數未達到指定數量之前,降低結束標記的
logits
,保證生成足夠數量的新內容。
- 機制:類似于最小長度限制,不過是針對新生成的標記數量。在新生成的標記數未達到指定數量之前,降低結束標記的
3. 采樣策略調整
- 溫度調整(
TemperatureLogitsWarper
)- 機制:將
logits
除以一個溫度參數temperature
。溫度越高,logits
之間的差異會被縮小,經過softmax
處理后,概率分布會更加均勻,采樣會更隨機;溫度越低,logits
之間的差異會被放大,概率分布會更集中,更傾向于選擇概率最大的詞。 - 示例:當
temperature = 1
時,保持原始的logits
分布;當temperature > 1
時,模型可能會生成一些更具創意但可能不太準確的文本;當temperature < 1
時,模型會更保守,生成的文本更符合常見的表達。
- 機制:將
- Top - k 采樣(
TopKLogitsWarper
)- 機制:只保留
logits
中概率最高的 k 個詞,將其余詞的logits
設為負無窮。這樣,在后續的采樣中,只會從這 k 個詞中選擇下一個生成的詞,限制了采樣范圍,提高了生成的穩定性。 - 示例:如果 k = 5,模型會在每次生成時,只考慮概率最高的 5 個詞,排除其他詞的干擾。
- 機制:只保留
- Top - p 采樣(
TopPLogitsWarper
)- 機制:選擇累積概率達到 p 的最小詞集合,只保留這些詞的
logits
,其余詞的logits
設為負無窮。這種方法結合了概率和詞的數量,既能控制采樣范圍,又能適應不同的概率分布。 - 示例:如果 p = 0.9,模型會選擇累積概率達到 0.9 的最小詞集合,從這個集合中進行采樣。
- 機制:選擇累積概率達到 p 的最小詞集合,只保留這些詞的
4. 約束生成內容
- 禁用詞過濾(
NoBadWordsLogitsProcessor
)- 機制:將禁用詞的
logits
設為負無窮,使得這些詞在后續的概率計算中概率為 0,不會被選中,從而避免生成包含禁用詞的文本。 - 示例:如果禁用詞列表中包含“臟話”,那么在生成過程中,“臟話”對應的
logits
會被設為負無窮,不會出現在生成的文本中。
- 機制:將禁用詞的
- 前綴約束(
PrefixConstrainedLogitsProcessor
)- 機制:根據給定的前綴允許標記函數,限制生成的詞必須符合特定的前綴約束。不符合約束的詞的
logits
會被設為負無窮,從而保證生成的文本符合特定的前綴要求。 - 示例:如果要求生成的文本必須以“今天”開頭,那么在生成第一個詞時,只有與“今天”相關的詞的
logits
會被保留,其他詞的logits
會被設為負無窮。
- 機制:根據給定的前綴允許標記函數,限制生成的詞必須符合特定的前綴約束。不符合約束的詞的
配置參數
參數 | 數據類型 | 默認值 | 含義 |
---|---|---|---|
guidance_scale | float | None | 引導比例,用于無批量分類器自由引導,值不為 1 時會添加相應的 logits 處理器,影響生成過程的引導程度。 |
sequence_bias | - | None | 序列偏差,用于控制特定序列的生成概率,設置后會添加序列偏差 logits 處理器。 |
diversity_penalty | float | None | 多樣性懲罰,大于 0 時會添加漢明多樣性 logits 處理器,鼓勵生成結果更具多樣性。 |
encoder_repetition_penalty | float | None | 編碼器重復懲罰,不為 1 且編碼器輸入 ID 形狀符合要求時,會添加編碼器重復懲罰 logits 處理器,減少編碼器輸入相關的重復內容。 |
repetition_penalty | float | None | 重復懲罰,不為 1 時會添加重復懲罰 logits 處理器,防止生成結果出現過多重復。 |
no_repeat_ngram_size | int | None | 不重復 n - gram 大小,大于 0 時會添加不重復 n - gram logits 處理器,避免生成的文本中出現重復的 n - gram 片段。 |
encoder_no_repeat_ngram_size | int | None | 編碼器不重復 n - gram 大小,大于 0 且編碼器輸入 ID 形狀符合要求時,會添加編碼器不重復 n - gram logits 處理器,減少編碼器輸入相關的重復 n - gram 內容。 |
bad_words_ids | - | None | 禁用詞 ID,設置后會添加禁用詞 logits 處理器,防止生成包含指定禁用詞的文本。 |
min_length | int | None | 最小長度,大于 0 且有結束標記張量時,會添加最小長度 logits 處理器,確保生成的文本達到最小長度要求。 |
min_new_tokens | int | None | 最小新標記數,大于 0 且有結束標記張量時,會添加最小新標記長度 logits 處理器,保證生成的新標記數量達到要求。 |
forced_bos_token_id | int | None | 強制起始標記 ID,設置后會添加強制起始標記 logits 處理器,確保生成的文本以指定的標記開始。 |
forced_eos_token_id | int | None | 強制結束標記 ID,設置后會添加強制結束標記 logits 處理器,使生成的文本在達到指定標記時結束。 |
remove_invalid_values | bool | False | 是否移除無效值,為 True 時會添加移除無窮大和 NaN 值的 logits 處理器,保證生成過程中 logits 的有效性。 |
exponential_decay_length_penalty | - | None | 指數衰減長度懲罰,設置后會添加指數衰減長度懲罰處理器,對生成文本的長度進行懲罰控制。 |
suppress_tokens | - | None | 抑制標記,設置后會添加抑制標記 logits 處理器,降低指定標記的生成概率。 |
begin_suppress_tokens | - | None | 起始抑制標記,設置后會添加起始抑制標記 logits 處理器,在生成的起始階段抑制指定標記的生成。 |
forced_decoder_ids | - | None | 強制解碼器 ID,不建議使用,設置后會拋出異常,提示使用 input_ids 或 decoder_input_ids 代替。 |
do_sample | bool | False | 是否使用采樣策略,為 True 時會根據其他采樣相關參數添加相應的 logits 調整器。 |
temperature | float | None | 采樣溫度,不為 1 時會添加溫度 logits 調整器,控制采樣的隨機性,值越大隨機性越強。 |
top_k | int | None | top - k 采樣值,不為 0 時會添加 top - k logits 調整器,只考慮概率最高的 k 個標記進行采樣。 |
top_p | float | None | top - p 采樣值,小于 1 時會添加 top - p logits 調整器,只考慮累積概率達到 p 的標記進行采樣。 |
min_p | float | None | 最小概率閾值,設置后會添加最小概率 logits 調整器,在溫度縮放后應用,控制采樣的最小概率。 |
typical_p | float | None | 典型概率采樣值,小于 1 時會添加典型概率 logits 調整器,基于典型概率進行采樣。 |
epsilon_cutoff | float | None | epsilon 截斷值,在 0 到 1 之間時會添加 epsilon logits 調整器,用于截斷低概率標記。 |
eta_cutoff | float | None | eta 截斷值,在 0 到 1 之間時會添加 eta logits 調整器,結合設備信息對低概率標記進行截斷。 |
watermarking_config | - | None | 水印配置,設置后會添加水印處理器,在生成的文本中添加水印。 |
renormalize_logits | bool | False | 是否重新歸一化 logits,為 True 時會添加 logit 歸一化處理器,確保 logits 歸一化。 |
logits
說明
logits
是模型在進行分類或預測任務時,最后一層神經元的原始輸出值,它是未經過歸一化處理的數值。在文本生成場景中,logits
代表了模型預測詞匯表中每個詞作為下一個生成詞的得分,這些得分反映了模型對每個詞成為下一個詞的相對可能性判斷,但并非是概率值。
數學公式
1. 線性變換得到 logits
在許多深度學習模型中,logits
通常是通過對前一層的輸出進行線性變換得到的。假設模型前一層的輸出為向量 h \mathbf{h} h,權重矩陣為 W \mathbf{W} W,偏置向量為 b \mathbf{b} b,則 logits
向量 z \mathbf{z} z 的計算公式如下:
z = W h + b \mathbf{z} = \mathbf{W}\mathbf{h} + \mathbf{b} z=Wh+b
其中, h \mathbf{h} h 是前一層輸出的特征向量,維度通常為 d h d_h dh?; W \mathbf{W} W 是權重矩陣,維度為 V × d h V \times d_h V×dh?, V V V 是詞匯表的大小; b \mathbf{b} b 是偏置向量,維度為 V V V; z \mathbf{z} z 是 logits
向量,維度為 V V V,每個元素 z i z_i zi? 對應詞匯表中第 i i i 個詞的得分。
2. logits
轉換為概率分布
為了將 logits
轉換為概率分布,通常會使用 softmax
函數。softmax
函數可以將 logits
向量中的每個元素轉換為一個在 [ 0 , 1 ] [0, 1] [0,1] 范圍內的值,且所有元素之和為 1,符合概率分布的定義。softmax
函數的數學公式如下:
P ( y i ) = e z i ∑ j = 1 V e z j P(y_i) = \frac{e^{z_i}}{\sum_{j=1}^{V} e^{z_j}} P(yi?)=∑j=1V?ezj?ezi??
其中, P ( y i ) P(y_i) P(yi?) 是詞匯表中第 i i i 個詞被選為下一個生成詞的概率, z i z_i zi? 是 logits
向量中第 i i i 個元素的值, V V V 是詞匯表的大小。
示例
假設詞匯表大小 V = 3 V = 3 V=3,模型輸出的 logits
向量為 z = [ 2 , 1 , 3 ] \mathbf{z} = [2, 1, 3] z=[2,1,3],下面計算經過 softmax
函數處理后的概率分布:
首先,計算分母的值:
∑ j = 1 3 e z j = e 2 + e 1 + e 3 ≈ 7.389 + 2.718 + 20.086 = 30.193 \sum_{j=1}^{3} e^{z_j} = e^2 + e^1 + e^3 \approx 7.389 + 2.718 + 20.086 = 30.193 j=1∑3?ezj?=e2+e1+e3≈7.389+2.718+20.086=30.193
然后,分別計算每個詞的概率:
P ( y 1 ) = e 2 30.193 ≈ 7.389 30.193 ≈ 0.245 P(y_1) = \frac{e^2}{30.193} \approx \frac{7.389}{30.193} \approx 0.245 P(y1?)=30.193e2?≈30.1937.389?≈0.245
P ( y 2 ) = e 1 30.193 ≈ 2.718 30.193 ≈ 0.090 P(y_2) = \frac{e^1}{30.193} \approx \frac{2.718}{30.193} \approx 0.090 P(y2?)=30.193e1?≈30.1932.718?≈0.090
P ( y 3 ) = e 3 30.193 ≈ 20.086 30.193 ≈ 0.665 P(y_3) = \frac{e^3}{30.193} \approx \frac{20.086}{30.193} \approx 0.665 P(y3?)=30.193e3?≈30.19320.086?≈0.665
可以看到,經過 softmax
函數處理后,得到了一個概率分布 [ 0.245 , 0.090 , 0.665 ] [0.245, 0.090, 0.665] [0.245,0.090,0.665],表示詞匯表中三個詞被選為下一個生成詞的概率。
在模型中的作用
在文本生成任務中,模型會根據 logits
轉換后的概率分布來選擇下一個生成的詞。常見的選擇方法有貪心搜索(選擇概率最大的詞)、采樣搜索(根據概率分布隨機采樣)等。同時,logits_processor
會對 logits
進行調整,從而影響最終的概率分布和詞的選擇,以控制文本生成的行為和質量。
代碼說明
logits_processor 是 _get_logits_processor 方法的一個參數,它是一個可選的 LogitsProcessorList 對象。這個方法會根據 GenerationConfig 中的各種配置參數,創建一系列不同的 LogitsProcessor 實例,并將它們添加到 processors 列表中。最后,如果傳入了 logits_processor,還會將其與新創建的處理器列表進行合并。
def _get_logits_processor(self,generation_config: GenerationConfig,input_ids_seq_length: int,encoder_input_ids: torch.LongTensor,prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],logits_processor: Optional[LogitsProcessorList],device: str = None,model_kwargs: Optional[Dict[str, Any]] = None,negative_prompt_ids: Optional[torch.Tensor] = None,negative_prompt_attention_mask: Optional[torch.Tensor] = None,) -> LogitsProcessorList:"""此函數返回一個 `LogitsProcessorList` 對象,該對象包含所有用于修改語言模型頭部得分的相關 `LogitsProcessor` 實例。這些處理器會對模型預測的 logits 進行調整,以控制文本生成的行為,例如避免重復、控制生成長度等。參數:generation_config (GenerationConfig): 生成配置對象,包含了文本生成過程中的各種配置參數。input_ids_seq_length (int): 輸入 ID 序列的長度。encoder_input_ids (torch.LongTensor): 編碼器的輸入 ID。prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]]): 一個可調用對象,用于指定允許的前綴標記。logits_processor (Optional[LogitsProcessorList]): 可選的 logits 處理器列表。device (str, optional): 設備名稱,如 'cuda' 或 'cpu'。默認為 None。model_kwargs (Optional[Dict[str, Any]], optional): 模型的其他關鍵字參數。默認為 None。negative_prompt_ids (Optional[torch.Tensor], optional): 負提示的 ID。默認為 None。negative_prompt_attention_mask (Optional[torch.Tensor], optional): 負提示的注意力掩碼。默認為 None。返回:LogitsProcessorList: 包含所有 logits 處理器的列表。"""# 實例化一個空的處理器列表processors = LogitsProcessorList()# 如果配置了引導比例且不為 1,則添加無批量分類器自由引導 logits 處理器if generation_config.guidance_scale is not None and generation_config.guidance_scale != 1:processors.append(UnbatchedClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale,self,unconditional_ids=negative_prompt_ids,unconditional_attention_mask=negative_prompt_attention_mask,use_cache=generation_config.use_cache,))# 如果配置了序列偏差,則添加序列偏差 logits 處理器if generation_config.sequence_bias is not None:processors.append(SequenceBiasLogitsProcessor(sequence_bias=generation_config.sequence_bias))# 如果配置了多樣性懲罰且大于 0,則添加漢明多樣性 logits 處理器if generation_config.diversity_penalty is not None and generation_config.diversity_penalty > 0.0:processors.append(HammingDiversityLogitsProcessor(diversity_penalty=generation_config.diversity_penalty,num_beams=generation_config.num_beams,num_beam_groups=generation_config.num_beam_groups,))# 如果配置了編碼器重復懲罰且不為 1,并且編碼器輸入 ID 的形狀為二維,則添加編碼器重復懲罰 logits 處理器if (generation_config.encoder_repetition_penalty is not Noneand generation_config.encoder_repetition_penalty != 1.0):if len(encoder_input_ids.shape) == 2:processors.append(EncoderRepetitionPenaltyLogitsProcessor(penalty=generation_config.encoder_repetition_penalty,encoder_input_ids=encoder_input_ids,))else:# 如果編碼器輸入 ID 形狀不符合要求,發出警告warnings.warn("Passing `encoder_repetition_penalty` requires some form of `input_ids` to be passed to ""`generate`, ignoring the argument.",UserWarning,)# 如果配置了重復懲罰且不為 1,則添加重復懲罰 logits 處理器if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0:processors.append(RepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty))# 如果配置了不重復 n-gram 大小且大于 0,則添加不重復 n-gram logits 處理器if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0:processors.append(NoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size))# 如果配置了編碼器不重復 n-gram 大小且大于 0,并且編碼器輸入 ID 的形狀為二維,則添加編碼器不重復 n-gram logits 處理器if (generation_config.encoder_no_repeat_ngram_size is not Noneand generation_config.encoder_no_repeat_ngram_size > 0):if len(encoder_input_ids.shape) == 2:processors.append(EncoderNoRepeatNGramLogitsProcessor(generation_config.encoder_no_repeat_ngram_size,encoder_input_ids,))else:# 如果編碼器輸入 ID 形狀不符合要求,發出警告warnings.warn("Passing `encoder_no_repeat_ngram_size` requires some form of `input_ids` to be passed to ""`generate`, ignoring the argument.",UserWarning,)# 如果配置了禁用詞 ID,則添加禁用詞 logits 處理器if generation_config.bad_words_ids is not None:processors.append(NoBadWordsLogitsProcessor(generation_config.bad_words_ids,generation_config._eos_token_tensor,))# 如果配置了最小長度且大于 0,并且有結束標記張量,則添加最小長度 logits 處理器if (generation_config.min_length is not Noneand generation_config._eos_token_tensor is not Noneand generation_config.min_length > 0):processors.append(MinLengthLogitsProcessor(generation_config.min_length,generation_config._eos_token_tensor,device=device,))# 如果配置了最小新標記數且大于 0,并且有結束標記張量,則添加最小新標記長度 logits 處理器if (generation_config.min_new_tokens is not Noneand generation_config._eos_token_tensor is not Noneand generation_config.min_new_tokens > 0):processors.append(MinNewTokensLengthLogitsProcessor(input_ids_seq_length,generation_config.min_new_tokens,generation_config._eos_token_tensor,device=device,))# 如果提供了前綴允許標記函數,則添加前綴約束 logits 處理器if prefix_allowed_tokens_fn is not None:processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn,generation_config.num_beams // generation_config.num_beam_groups,))# 如果配置了強制起始標記 ID,則添加強制起始標記 logits 處理器if generation_config.forced_bos_token_id is not None:processors.append(ForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id,))# 如果配置了強制結束標記 ID,則添加強制結束標記 logits 處理器if generation_config.forced_eos_token_id is not None:processors.append(ForcedEOSTokenLogitsProcessor(generation_config.max_length,generation_config.forced_eos_token_id,device=device,))# 如果配置了移除無效值,則添加移除無窮大和 NaN 值的 logits 處理器if generation_config.remove_invalid_values is True:processors.append(InfNanRemoveLogitsProcessor())# 如果配置了指數衰減長度懲罰,則添加指數衰減長度懲罰處理器if generation_config.exponential_decay_length_penalty is not None:processors.append(ExponentialDecayLengthPenalty(generation_config.exponential_decay_length_penalty,generation_config._eos_token_tensor,input_ids_seq_length,))# 如果配置了抑制標記,則添加抑制標記 logits 處理器if generation_config.suppress_tokens is not None:processors.append(SuppressTokensLogitsProcessor(generation_config.suppress_tokens,device=device,))# 如果配置了起始抑制標記,則添加起始抑制標記 logits 處理器if generation_config.begin_suppress_tokens is not None:begin_index = input_ids_seq_lengthbegin_index = (begin_indexif (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None)else begin_index + 1)processors.append(SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens,begin_index,device=device,))# 如果配置了強制解碼器 ID,則拋出異常,提示使用 input_ids 或 decoder_input_ids 代替if generation_config.forced_decoder_ids is not None:# TODO (sanchit): move this exception to GenerationConfig.validate() when TF & FLAX are aligned with PTraise ValueError("You have explicitly specified `forced_decoder_ids`. Please remove the `forced_decoder_ids` argument ""in favour of `input_ids` or `decoder_input_ids` respectively.",)# 合并自定義的 logits 處理器列表processors = self._merge_criteria_processor_list(processors, logits_processor)# 以下處理器之前被稱為 `LogitsWarpers`,僅在采樣策略下應用if generation_config.do_sample:# 在束搜索方法中,我們需要至少保留一個非結束標記來探索可能有更好得分的延續(即保留 len(list(generation_config._eos_token_tensor)) + 1)if generation_config.num_beams > 1:if isinstance(generation_config._eos_token_tensor, list):min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1elif isinstance(generation_config._eos_token_tensor, torch.Tensor):min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1else:min_tokens_to_keep = 2else:min_tokens_to_keep = 1# 以下思路主要借鑒自這個 PR: https://github.com/huggingface/transformers/pull/5420/files# 所有采樣器可以在 `generation_utils_samplers.py` 中找到# 如果配置了溫度且不為 1,則添加溫度 logits 調整器if generation_config.temperature is not None and generation_config.temperature != 1.0:processors.append(TemperatureLogitsWarper(generation_config.temperature))# 如果配置了 top-k 采樣且不為 0,則添加 top-k logits 調整器if generation_config.top_k is not None and generation_config.top_k != 0:processors.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))# 如果配置了 top-p 采樣且小于 1,則添加 top-p logits 調整器if generation_config.top_p is not None and generation_config.top_p < 1.0:processors.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))# 如果配置了最小概率閾值,則添加最小概率 logits 調整器if generation_config.min_p is not None:# 在溫度縮放后應用(見 https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)processors.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))# 如果配置了典型概率采樣且小于 1,則添加典型概率 logits 調整器if generation_config.typical_p is not None and generation_config.typical_p < 1.0:processors.append(TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep))# 如果配置了 epsilon 截斷且在 0 到 1 之間,則添加 epsilon logits 調整器if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:processors.append(EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep))# 如果配置了 eta 截斷且在 0 到 1 之間,則添加 eta logits 調整器if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:processors.append(EtaLogitsWarper(epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device))# 水印處理應該在所有 logits 處理完成后進行(見 #34630)if generation_config.watermarking_config is not None:processors.append(generation_config.watermarking_config.construct_processor(self.config.vocab_size, device))# `LogitNormalization` 應該始終是最后一個 logit 處理器(如果存在)if generation_config.renormalize_logits is True:processors.append(LogitNormalization())return processors
transformers/src/transformers/generation/utils.py