NLP實踐——LLM生成過程中防止重復循環

NLP實踐——LLM生成過程中防止重復

  • 1. 準備工作
  • 2. 問題分析
  • 3. 創建processor
    • 3.1 防止重復生成的processor
    • 3.2 防止數字無規則循環的processor
  • 4. 使用

本文介紹如何使用LogitsProcessor避免大模型在生成過程中出現重復的問題。

1. 準備工作

首先實例化一個大模型,以GLM2為例:

import re
import os
import json
import random
from typing import *
from copy import deepcopyimport torch
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList, MaxNewTokensCriteria, StoppingCriteria

創建模型:

tokenizer = AutoTokenizer.from_pretrained(".../ChatGLM2/", trust_remote_code=True)
model = AutoModel.from_pretrained(".../ChatGLM2/", trust_remote_code=True).half()
model.to('cuda:0')

2. 問題分析

接下來思考一下,如何防止模型不停的重復呢?重復分為幾種情況,一個字符循環出現,或者多個字符循環出現,例如:

'abcdeeeeee'
'abcdededede'

從生成的過程來考慮,防止模型生成重復的內容,第一步自然是要判斷模型陷入了重復,第二步就是打斷它重復的過程,也就是將重復的token,在當前step生成的時候,將其概率設置為-inf,那么重復的過程自然就停止了。

3. 創建processor

3.1 防止重復生成的processor

先來解決如何判定重復。這里直接去leetcode上找一個題,獲取一個字符串中最大的重復片段,解法如下:

def longest_dup_substring(s: str) -> str:# 生成兩個進制a1, a2 = random.randint(26, 100), random.randint(26, 100)# 生成兩個模mod1, mod2 = random.randint(10**9+7, 2**31-1), random.randint(10**9+7, 2**31-1)n = len(s)# 先對所有字符進行編碼arr = [ord(c)-ord('a') for c in s]# 二分查找的范圍是[1, n-1]l, r = 1, n-1length, start = 0, -1while l <= r:m = l + (r - l + 1) // 2idx = check(arr, m, a1, a2, mod1, mod2)# 有重復子串,移動左邊界if idx != -1:l = m + 1length = mstart = idx# 無重復子串,移動右邊界else:r = m - 1return s[start:start+length] if start != -1 else ""def check(arr, m, a1, a2, mod1, mod2):n = len(arr)aL1, aL2 = pow(a1, m, mod1), pow(a2, m, mod2)h1, h2 = 0, 0for i in range(m):h1 = (h1 * a1 + arr[i]) % mod1h2 = (h2 * a2 + arr[i]) % mod2# 存儲一個編碼組合是否出現過seen = {(h1, h2)}for start in range(1, n - m + 1):h1 = (h1 * a1 - arr[start - 1] * aL1 + arr[start + m - 1]) % mod1h2 = (h2 * a2 - arr[start - 1] * aL2 + arr[start + m - 1]) % mod2# 如果重復,則返回重復串的起點if (h1, h2) in seen:return startseen.add((h1, h2))# 沒有重復,則返回-1return -1

效果如下:

longestDupSubstring('埃爾多安經濟學可以重振經濟,土耳其土耳其')
# '土耳其'

那么我們就可以寫一個processor,在每一個step即將生成的時候,判定一下,是否之前已經生成的結果中,出現了重復。以及,如果出現了重復,則禁止重復部分的第一個token(例如上面例子中,土耳其的土字),在當前step被生成。

針對實際使用中由這個processor引發的一些其他的問題,我又對這個processor增加了一點規則限制,一個比較好用的版本如下。

其中的參數threshold是判斷重復多少的情況算作循環,例如將threshold設置為10,那么如果重復部分的長度是3,重復了3次,3×3=9,則不被判定為陷入了循環,而如果重復了4次,3×4=12,則被判定為循環,此時processor將發揮效果了。

class ForbidDuplicationProcessor(LogitsProcessor):"""防止生成的內容陷入循環。當循環內容與循環次數之乘積大于指定次數則在生成下一個token時將循環內容的第一個token概率設置為0---------------ver: 2023-08-17by: changhongyu"""def __init__(self, tokenizer, threshold: int = 10):self.tokenizer = tokenizerself.threshold = thresholddef __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:current_sequence = self.tokenizer.decode(input_ids[0][current_token_len: ])current_dup_str = longest_dup_substring(current_sequence)if len(current_dup_str):# 如果存在重復子序列,則根據其長度與重復次數判斷是否禁止循環if len(current_dup_str) > 1 or (len(current_dup_str) == 1 and current_dup_str * self.threshold in current_sequence):if len(current_dup_str) * current_sequence.count(current_dup_str) >= self.threshold:token_ids = self.tokenizer.encode(current_dup_str)# 獲取截止目前的上一個tokenlast_token = input_ids[0][-1].detach().cpu().numpy().tolist()if len(token_ids) and last_token == token_ids[-1]:# 如果截止目前的上一個token,與重復部分的最后一個token一致# 說明即將觸發重復, 先把重復部分的第一個token禁掉scores[:, token_ids[0]] = 0# 然后按出現比率判斷是否重復部分內還有其他重復for token_id in token_ids:if token_ids.count(token_id) * len(token_ids) > 1.2:scores[:, token_id] = 0return scores

需要注意的是,為了獲取當前的序列已經生成的長度,需要在processor的外部,也就是與model.generate同級的結構處,定義一個全局變量current_token_len

global current_token_len

3.2 防止數字無規則循環的processor

出了上述的情況,還有一種常見的循環,無法利用上面的規則解決,即數字無規則循環的情況。針對這個場景,創建另一個processor,只要連續出現的數字出現次數,大于一定的閾值,則禁止當前step再次生成數字。

class MaxConsecutiveProcessor(LogitsProcessor):"""給定一個集合,集合中的字符最多連續若干次下一次生成時不能再出現該集合中的字符---------------ver: 2023-08-17by: changhongyu---------------修復bugver: 2023-09-11"""def __init__(self, consecutive_token_ids, max_num: int = 10):self.consecutive_token_ids = consecutive_token_idsself.max_num = max_numdef __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:input_ids_list = input_ids.squeeze(0).detach().cpu().numpy().tolist()cur_num = 0for token in input_ids_list[::-1]:if token in self.consecutive_token_ids:cur_num += 1else:breakif cur_num >= self.max_num:# 如果連續次數超過閾值,那集合中的所有token在下一個step都不可以再出現for token_id in self.consecutive_token_ids:scores[..., token_id] = 0return scores

4. 使用

使用方法非常簡單,首先創建processor容器。對processor不熟悉的同學,可以去看之前的文章,有非常詳細的介紹。

logits_processor = LogitsProcessorList()

然后對于ChatGLM而言,需要先添加其默認的processor:

logits_processor.append(InvalidScoreLogitsProcessor())

接下來,再添加防止陷入循環的兩個processor:

number_tokens = [str(i) for i in range(10)] + ['.', '-']
number_token_ids = [tokenizer.convert_tokens_to_ids(tok) for tok in number_tokens]
logits_processor.append(ForbidDuplicationProcessor(tokenizer))
logits_processor.append(MaxConsecutiveProcessor(number_token_ids))

最后在調用generate的時候,把logits_processor作為參數傳進去就可以了。

以上便是使用logits_processor來防止大模型在生成過程中陷入循環的方法。經過我的反復調整,基本可以覆蓋大多數情景,如果在使用中遇到了bug,也歡迎指出。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/166756.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/166756.shtml
英文地址,請注明出處:http://en.pswp.cn/news/166756.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

實時語音克隆:5 秒內生成任意文本的語音 | 開源日報 No.84

CorentinJ/Real-Time-Voice-Cloning Stars: 43.3k License: NOASSERTION 這個開源項目是一個實時語音克隆工具&#xff0c;可以在5秒內復制一種聲音&#xff0c;并生成任意文本的語音。 該項目的主要功能包括&#xff1a; 從幾秒鐘的錄音中創建聲紋模型根據給定文本使用參考…

數字化轉型沒錢?沒人?沒IT?低代碼平臺輕松幫你搞定

隨著數字技術的不斷滲透&#xff0c;數字化已經不僅僅是一個趨勢&#xff0c;而是深入人心的日常生活部分。在這樣的時代背景下&#xff0c;企業面臨的挑戰也愈發嚴峻&#xff1a;如何不斷創新&#xff0c;滿足用戶日益增長的業務需求&#xff1f; 傳統的開發方式&#xff0c;隨…

基于單片機設計的大氣氣壓檢測裝置(STC89C52+BMP180實現)

一、前言 本項目設計一個大氣氣壓檢測裝置&#xff0c;該裝置以單片機為基礎&#xff0c;采用STC89C52作為核心控制芯片&#xff0c;結合BMP180模塊作為氣壓傳感器。大氣氣壓&#xff0c;也就是由氣體重力在大氣層中產生的壓力&#xff0c;其變化與天氣預報、氣象觀測以及高度…

江蘇某市人民醫院實現IT基礎資源統一監控

一、背景介紹 江蘇某市人民醫院是一家擁有豐富醫療資源和龐大患者群體的醫療機構。隨著醫療業務的不斷發展&#xff0c;其IT系統的規模和復雜性也不斷增加&#xff0c;涉及各類IT資源&#xff0c;包括服務器、網絡設備、數據庫、應用軟件等。為了提高IT系統的可靠性和穩定性&am…

11.7統一功能處理

一.登錄攔截器 1.實現一個普通的類,實現HeadlerInterceptor接口,重寫preHeadler方法. 2.將攔截器添加到配置中,并設定攔截規則. 二.訪問前綴添加 方法1: 方法2:properties 三.統一異常處理 以上返回的是空指針異常,如果是別的異常就不會識別,建議加上最終異常 . 四.統一數據格…

英語學習軟件 Eudic歐路詞典 mac中文版介紹說明

歐路詞典 mac (Eudic) 是一個功能強大的英語學習工具&#xff0c;它包含了豐富的英語詞匯、短語和例句&#xff0c;并提供了發音、例句朗讀、單詞筆記等功能。 Eudic歐路詞典 mac 軟件介紹 多語種支持&#xff1a;歐路詞典支持多種語言&#xff0c;包括英語、中文、日語、法語…

uni微信小程序 map 添加padding

問題背景&#xff1a; 規劃駕車線路的時候&#xff0c;使用uni的include-points指定可視范圍的時候&#xff0c;會很極限。導致marker不能完全顯示。 解決方法 給地圖顯示范圍添加padding (推薦) <mapid"myMap":markers"markers":polyline"pol…

視頻服務網關的三大部署(二)

視頻網關是軟硬一體的一款產品&#xff0c;可提供多協議&#xff08;RTSP/ONVIF/GB28181/海康ISUP/EHOME/大華、海康SDK等&#xff09;的設備視頻接入、采集、處理、存儲和分發等服務&#xff0c; 配合視頻網關云管理平臺&#xff0c;可廣泛應用于安防監控、智能檢測、智慧園區…

spark寫入關系型數據庫的duplicateIncs參數使用

在看一段spark寫數據到關系型數據庫代碼時&#xff0c;發現一個參數沒有見過&#xff1a; df.write.format("org.apache.spark.sql.execution.datasources.jdbc2").options(Map("savemode" -> JDBCSaveMode.Update.toString,"driver" -> …

Android13 launcher循環切頁

launcher 常規切頁&#xff1a;https://blog.csdn.net/a396604593/article/details/125305234 循環切頁 我們知道&#xff0c;launcher切頁是在packages\apps\Launcher3\src\com\android\launcher3\PagedView.java的onTouchEvent中實現的。 1、滑動限制 public boolean onT…

Python與設計模式--門面模式

8-Python與設計模式–門面模式 一、火警報警器&#xff08;1&#xff09; 假設有一組火警報警系統&#xff0c;由三個子元件構成&#xff1a;一個警報器&#xff0c;一個噴水器&#xff0c; 一個自動撥打電話的裝置。其抽象如下&#xff1a; class AlarmSensor:def run(self):…

c語言習題1124

分別定義函數求圓的面積和周長。 寫一個函數&#xff0c;分別求三個數當中的最大數。 寫一個函數&#xff0c;計算輸入n個數的乘積 一個判斷素數的函數&#xff0c;在主函數輸入一個整數&#xff0c;輸出是否為素數的信息 寫一個函數求n! ,利用該函數求1&#xff01;2&…

功率半導體器件CV測試系統

概述 電容-電壓(C-V)測量廣泛用于測量半導體參數&#xff0c;尤其是MOS CAP和MOSFET結構。MOS(金屬-氧化物-半導體)結構的電容是外加電壓的函數&#xff0c;MOS電容隨外加電壓變化的曲線稱之為C-V曲線&#xff08;簡稱C-V特性&#xff09;&#xff0c;C-V 曲線測試可以方便的確…

opencv-使用 Haar 分類器進行面部檢測

Haar 分類器是一種用于對象檢測的方法&#xff0c;最常見的應用之一是面部檢測。Haar 分類器基于Haar-like 特征&#xff0c;這些特征可以通過計算圖像中的積分圖來高效地計算。 在OpenCV中&#xff0c;Haar 分類器被廣泛用于面部檢測。以下是一個簡單的使用OpenCV進行面部檢測…

鴻蒙系統使用hdc_std.exe使用身份證讀卡器等外設USB獲得權限方法

hdc_std.exe是OpenHarmony 的命令行工具&#xff0c;由于使用的開源鴻蒙開發板上面沒有文件管理器&#xff0c;所以無法通過U盤等方式進行安裝.hap應用。 下面是使用hdc_std.exe安裝身份證讀卡器的步驟&#xff1a; 1、hdc_std.exe放桌面&#xff0c;然后WINR&#xff0c;打開…

CBTC 2023氫能展倒計時6天,最新同期會議活動Plus版發布

隨著時間的推移&#xff0c;CBTC2023深圳氫能技術展覽會即將拉開序幕。這場盛會將于11月30日在深圳福田會展中心盛大開幕&#xff0c;以“以儲賦能&#xff0c;智造未來”為主題&#xff0c;旨在搭建一個商務交流、供需合作、創新產品發布的平臺&#xff0c;讓氫能全產業鏈之間…

尋找質數 II

題目描述 輸入兩個整數 a&#xff0c;b&#xff0c;計算并輸出小于 a 的 b個質數&#xff0c;所有符合條件的質數里&#xff0c;輸出最大的 b 個質數&#xff0c;按照從大到小輸出&#xff0c;使用空格隔開。 假如符合條件的數量不夠&#xff0c;則輸出已經滿足的質數。 如果…

詳解Java中的異常體系機構(throw,throws,try catch,finally)

目錄 一.異常的概念 二.異常的體系結構 三.異常的處理 異常處理思路 LBYL&#xff1a;Look Before You Leap EAFP: Its Easier to Ask Forgiveness than Permission 異常拋出throw 異常的捕獲 提醒聲明throws try-catch捕獲處理 finally的作用 四.自定義異常類 一.異…

微信小程序:This Mini Program cannot be opened as your Weixin version is out-of-date.

項目場景&#xff1a; 問題描述 升級基礎庫3.2.0&#xff0c;然后PC端整個小程序都打不開了&#xff0c;點擊小程序提示”This Mini Program cannot be opened as your Weixin version is out-of-date. Update Weixin to the latest version.“&#xff0c;并且點擊Update Wei…

一個悄然崛起的國產軟件!!AI 又進化了!!

大家好&#xff0c;我是 Jack。 AI 寫代碼想必很多人都體驗過了&#xff0c;使用 AI 編程工具是一個大趨勢&#xff0c;越早學會使用 AI 輔助你寫代碼&#xff0c;你的效率也會越高。 甚至有些公司已經要求員工具備 AI 編程能力。 對于學生黨&#xff0c;AI 編程可以幫助我們…