NLP實踐——LLM生成過程中防止重復
- 1. 準備工作
- 2. 問題分析
- 3. 創建processor
- 3.1 防止重復生成的processor
- 3.2 防止數字無規則循環的processor
- 4. 使用
本文介紹如何使用LogitsProcessor避免大模型在生成過程中出現重復的問題。
1. 準備工作
首先實例化一個大模型,以GLM2為例:
import re
import os
import json
import random
from typing import *
from copy import deepcopyimport torch
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList, MaxNewTokensCriteria, StoppingCriteria
創建模型:
tokenizer = AutoTokenizer.from_pretrained(".../ChatGLM2/", trust_remote_code=True)
model = AutoModel.from_pretrained(".../ChatGLM2/", trust_remote_code=True).half()
model.to('cuda:0')
2. 問題分析
接下來思考一下,如何防止模型不停的重復呢?重復分為幾種情況,一個字符循環出現,或者多個字符循環出現,例如:
'abcdeeeeee'
'abcdededede'
從生成的過程來考慮,防止模型生成重復的內容,第一步自然是要判斷模型陷入了重復,第二步就是打斷它重復的過程,也就是將重復的token,在當前step生成的時候,將其概率設置為-inf,那么重復的過程自然就停止了。
3. 創建processor
3.1 防止重復生成的processor
先來解決如何判定重復。這里直接去leetcode上找一個題,獲取一個字符串中最大的重復片段,解法如下:
def longest_dup_substring(s: str) -> str:# 生成兩個進制a1, a2 = random.randint(26, 100), random.randint(26, 100)# 生成兩個模mod1, mod2 = random.randint(10**9+7, 2**31-1), random.randint(10**9+7, 2**31-1)n = len(s)# 先對所有字符進行編碼arr = [ord(c)-ord('a') for c in s]# 二分查找的范圍是[1, n-1]l, r = 1, n-1length, start = 0, -1while l <= r:m = l + (r - l + 1) // 2idx = check(arr, m, a1, a2, mod1, mod2)# 有重復子串,移動左邊界if idx != -1:l = m + 1length = mstart = idx# 無重復子串,移動右邊界else:r = m - 1return s[start:start+length] if start != -1 else ""def check(arr, m, a1, a2, mod1, mod2):n = len(arr)aL1, aL2 = pow(a1, m, mod1), pow(a2, m, mod2)h1, h2 = 0, 0for i in range(m):h1 = (h1 * a1 + arr[i]) % mod1h2 = (h2 * a2 + arr[i]) % mod2# 存儲一個編碼組合是否出現過seen = {(h1, h2)}for start in range(1, n - m + 1):h1 = (h1 * a1 - arr[start - 1] * aL1 + arr[start + m - 1]) % mod1h2 = (h2 * a2 - arr[start - 1] * aL2 + arr[start + m - 1]) % mod2# 如果重復,則返回重復串的起點if (h1, h2) in seen:return startseen.add((h1, h2))# 沒有重復,則返回-1return -1
效果如下:
longestDupSubstring('埃爾多安經濟學可以重振經濟,土耳其土耳其')
# '土耳其'
那么我們就可以寫一個processor,在每一個step即將生成的時候,判定一下,是否之前已經生成的結果中,出現了重復。以及,如果出現了重復,則禁止重復部分的第一個token(例如上面例子中,土耳其的土字),在當前step被生成。
針對實際使用中由這個processor引發的一些其他的問題,我又對這個processor增加了一點規則限制,一個比較好用的版本如下。
其中的參數threshold是判斷重復多少的情況算作循環,例如將threshold設置為10,那么如果重復部分的長度是3,重復了3次,3×3=9,則不被判定為陷入了循環,而如果重復了4次,3×4=12,則被判定為循環,此時processor將發揮效果了。
class ForbidDuplicationProcessor(LogitsProcessor):"""防止生成的內容陷入循環。當循環內容與循環次數之乘積大于指定次數則在生成下一個token時將循環內容的第一個token概率設置為0---------------ver: 2023-08-17by: changhongyu"""def __init__(self, tokenizer, threshold: int = 10):self.tokenizer = tokenizerself.threshold = thresholddef __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:current_sequence = self.tokenizer.decode(input_ids[0][current_token_len: ])current_dup_str = longest_dup_substring(current_sequence)if len(current_dup_str):# 如果存在重復子序列,則根據其長度與重復次數判斷是否禁止循環if len(current_dup_str) > 1 or (len(current_dup_str) == 1 and current_dup_str * self.threshold in current_sequence):if len(current_dup_str) * current_sequence.count(current_dup_str) >= self.threshold:token_ids = self.tokenizer.encode(current_dup_str)# 獲取截止目前的上一個tokenlast_token = input_ids[0][-1].detach().cpu().numpy().tolist()if len(token_ids) and last_token == token_ids[-1]:# 如果截止目前的上一個token,與重復部分的最后一個token一致# 說明即將觸發重復, 先把重復部分的第一個token禁掉scores[:, token_ids[0]] = 0# 然后按出現比率判斷是否重復部分內還有其他重復for token_id in token_ids:if token_ids.count(token_id) * len(token_ids) > 1.2:scores[:, token_id] = 0return scores
需要注意的是,為了獲取當前的序列已經生成的長度,需要在processor的外部,也就是與model.generate同級的結構處,定義一個全局變量current_token_len
。
global current_token_len
3.2 防止數字無規則循環的processor
出了上述的情況,還有一種常見的循環,無法利用上面的規則解決,即數字無規則循環的情況。針對這個場景,創建另一個processor,只要連續出現的數字出現次數,大于一定的閾值,則禁止當前step再次生成數字。
class MaxConsecutiveProcessor(LogitsProcessor):"""給定一個集合,集合中的字符最多連續若干次下一次生成時不能再出現該集合中的字符---------------ver: 2023-08-17by: changhongyu---------------修復bugver: 2023-09-11"""def __init__(self, consecutive_token_ids, max_num: int = 10):self.consecutive_token_ids = consecutive_token_idsself.max_num = max_numdef __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:input_ids_list = input_ids.squeeze(0).detach().cpu().numpy().tolist()cur_num = 0for token in input_ids_list[::-1]:if token in self.consecutive_token_ids:cur_num += 1else:breakif cur_num >= self.max_num:# 如果連續次數超過閾值,那集合中的所有token在下一個step都不可以再出現for token_id in self.consecutive_token_ids:scores[..., token_id] = 0return scores
4. 使用
使用方法非常簡單,首先創建processor容器。對processor不熟悉的同學,可以去看之前的文章,有非常詳細的介紹。
logits_processor = LogitsProcessorList()
然后對于ChatGLM而言,需要先添加其默認的processor:
logits_processor.append(InvalidScoreLogitsProcessor())
接下來,再添加防止陷入循環的兩個processor:
number_tokens = [str(i) for i in range(10)] + ['.', '-']
number_token_ids = [tokenizer.convert_tokens_to_ids(tok) for tok in number_tokens]
logits_processor.append(ForbidDuplicationProcessor(tokenizer))
logits_processor.append(MaxConsecutiveProcessor(number_token_ids))
最后在調用generate的時候,把logits_processor作為參數傳進去就可以了。
以上便是使用logits_processor來防止大模型在生成過程中陷入循環的方法。經過我的反復調整,基本可以覆蓋大多數情景,如果在使用中遇到了bug,也歡迎指出。