摘要
我們探討了生成思維鏈——一系列中間推理步驟——如何顯著提升大型語言模型執行復雜推理的能力。特別地,我們展示了在足夠大的語言模型中,這種推理能力如何通過一種簡單的方法——思維鏈提示(chain-of-thought prompting)自然地顯現出來。該方法在提示中提供少量思維鏈示例作為范例。
在三種大型語言模型上的實驗表明,思維鏈提示能夠提升模型在算術、常識和符號推理任務上的表現。實驗結果非常顯著。例如,僅使用八個思維鏈示例提示PaLM 540B,就在數學文字題基準GSM8K上達到了最先進的準確率,甚至超越了經過微調并帶有驗證器的GPT-3。
1.引言
自然語言處理領域最近因語言模型的發展而發生了革命性變化(Peters 等,2018;Devlin 等,2019;Brown 等,2020 等)。擴大語言模型的規模已被證明帶來多方面的好處,例如提升性能和樣本效率(Kaplan 等,2020;Brown 等,2020 等)。然而,僅僅增加模型規模并不足以在算術、常識和符號推理等具有挑戰性的任務上取得高性能(Rae 等,2021)。
本研究探討了一種簡單方法,如何激發大型語言模型的推理能力,該方法基于兩個核心思想。首先,算術推理技術可以通過生成引導至最終答案的自然語言推理過程獲益。此前的工作使模型具備了生成自然語言中間步驟的能力,這包括從零開始訓練(Ling 等,2017)或對預訓練模型進行微調(Cobbe 等,2021),此外還有使用形式語言而非自然語言的神經符號方法(Roy 和 Roth,2015;Chiang 和 Chen,2019;Amini 等,2019;Chen 等,2019)。其次,大型語言模型帶來了通過提示實現上下文內少量樣本學習的令人興奮的可能性。也就是說,不需要針對每個新任務微調單獨的語言模型檢查點,而是可以簡單地通過“提示”模型幾個輸入-輸出示例來演示任務。令人驚訝的是,這種方法已在多種簡單問答任務中取得了成功(Brown 等,2020)。
上述兩種方法均存在關鍵限制。對于增強推理過程的訓練和微調方法,制作大量高質量的推理步驟成本較高,這比普通機器學習中使用的簡單輸入-輸出對復雜得多。而傳統的少樣本提示方法(如Brown等,2020年所用)在需要推理能力的任務上表現不佳,且隨著語言模型規模的增大,性能提升往往不明顯(Rae等,2021年)。本文結合了這兩種思想的優勢,同時避免了它們的缺陷。具體而言,我們探索了語言模型在推理任務中利用少樣本提示的能力,提示中包含三元組:〈輸入,推理鏈,輸出〉。推理鏈是一系列中間的自然語言推理步驟,最終導向答案,我們將這種方法稱為“推理鏈提示法(chain-of-thought prompting)”。圖1展示了一個示例提示。
我們在算術、常識和符號推理基準測試中進行了實證評估,結果顯示推理鏈提示法優于標準提示,有時差距顯著。圖2展示了其中一個結果——在數學文字題基準GSM8K(Cobbe等,2021)上,使用PaLM 540B的推理鏈提示法遠超標準提示,取得了新的最先進性能。僅通過提示的方式尤為重要,因為它不需要大量訓練數據,并且單一模型檢查點可以在多種任務上通用而不損失表現。此工作強調了大型語言模型如何通過少量示例和自然語言的任務信息學習(相比于通過大量訓練數據自動學習輸入輸出之間的模式)。
2 推理鏈提示法
想象一下自己在解決復雜推理任務(例如多步驟的數學文字題)時的思考過程。通常,我們會將問題分解為若干中間步驟,逐步求解,最后給出答案:“簡給媽媽2朵花后,她還有10朵……然后她給爸爸3朵后,就剩7朵……所以答案是7。”本文的目標是賦予語言模型生成類似推理鏈的能力——即生成一系列連貫的中間推理步驟,最終得到問題的答案。我們將展示,只要在少樣本提示中提供了推理鏈的示范,足夠大的語言模型就能生成這樣的推理鏈。
圖1展示了一個模型為了解決一題數學文字題而生成的推理鏈,這道題如果不使用推理鏈,模型本來會答錯。此處的推理鏈類似于解題過程,也可以看作是一個解決方案,但我們仍稱之為推理鏈,以更好地體現它模擬了逐步思考的過程(而且通常解答和解釋是在最終答案之后給出的,見Narang等,2020;Wiegreffe等,2022;Lampinen等,2022等研究)。
推理鏈提示法作為促進語言模型推理能力的手段,具有以下幾個顯著優勢:
- 首先,推理鏈原則上允許模型將多步驟問題拆解成中間步驟,意味著可以為需要更多推理步驟的問題分配額外計算資源。
- 其次,推理鏈為模型的行為提供了可解釋的窗口,幫助理解模型如何得出某一答案,并為調試推理路徑中的錯誤提供機會(盡管完整刻畫支持答案的計算過程仍是一個開放問題)。
- 第三,推理鏈推理可以應用于數學文字題、常識推理和符號運算等任務,且原則上適用于任何人類能通過語言解決的任務。
- 最后,只需在少樣本提示示例中加入推理鏈示范,就能輕松在足夠大的現成語言模型中激發推理鏈推理能力。
在實證實驗中,我們將觀察推理鏈提示法在算術推理(第3節)、常識推理(第4節)和符號推理(第5節)上的效用。
3 算術推理
我們首先考慮如圖1所示形式的數學文字題,這類題用于測量語言模型的算術推理能力。盡管對人類來說很簡單,但算術推理是語言模型經常感到困難的任務(見Hendrycks等,2021;Patel等,2021等)。令人驚訝的是,使用推理鏈提示法并配合5400億參數的語言模型,在多個任務中表現可與針對特定任務微調的模型相媲美,甚至在具有挑戰性的GSM8K基準測試(Cobbe等,2021)上取得了新的最先進成績。
3.1 實驗設置
我們在多種語言模型和多個基準測試上探索推理鏈提示法的效果。
基準測試。我們考慮以下五個數學文字題基準數據集:
(1) GSM8K數學文字題基準(Cobbe等,2021);
(2) 結構多樣的SVAMP數學文字題數據集(Patel等,2021);
(3) 多樣化數學文字題的ASDiv數據集(Miao等,2020);
(4) 代數文字題的AQuA數據集;
(5) MAWPS基準(Koncel-Kedziorski等,2016)。
示例題目見附錄表12。
標準提示法。作為基線,我們采用Brown等(2020)推廣的標準少樣本提示法,即給語言模型提供輸入-輸出對的上下文示例,之后模型對測試樣本給出預測。示例格式為問題和答案,模型直接給出答案,如圖1左側所示。
推理鏈提示法。我們提出的方法是在少樣本提示的每個示例中,增加與答案對應的推理鏈,如圖1右側所示。由于大多數數據集僅有評估集,我們手工編寫了一組包含八個推理鏈的少樣本示例用于提示——圖1右側展示了其中一個推理鏈示例,完整示例集見附錄表20。(這些示例未經過提示工程處理;第3.4節和附錄A.2討論了方法的魯棒性。)
為了驗證這種形式的推理鏈提示法是否能夠在多種數學文字題上成功激發推理能力,我們對除AQuA外的所有基準均使用了同一組八個推理鏈示例。由于AQuA是多選題而非自由回答,我們使用了訓練集中四個示例及其解答,見附錄表21。
語言模型。我們評估了五個大型語言模型。
第一個是GPT-3(Brown等,2020),我們使用了text-ada-001、text-babbage-001、text-curie-001和text-davinci-002,這些模型大致對應參數規模為3.5億、13億、67億和1750億的InstructGPT模型(Ouyang等,2022)。
第二個是LaMDA(Thoppilan等,2022),包含參數規模為4.22億、20億、80億、680億和1370億的多個模型。
第三個是PaLM,擁有參數規模為80億、620億和5400億的模型。
第四個是UL2 20B(Tay等,2022),第五個是Codex(Chen等,2021,OpenAI API中的code-davinci-002)。
我們通過貪心解碼從模型中采樣(盡管后續工作表明,鏈式思維提示法可通過對多次生成結果進行多數投票以提升性能(Wang等,2022a))。
對于LaMDA,我們報告了五個隨機種子下的平均結果,每個種子對應一組不同隨機打亂的示例順序。由于LaMDA實驗中不同種子的結果方差不大,為節省計算資源,其它模型僅報告單一示例順序的結果。
3.2 結果
鏈式思維提示法的最強結果總結在圖4中,所有模型集合、模型規模和基準測試的實驗輸出詳見附錄表2。有三個關鍵要點:首先,圖4顯示鏈式思維提示是一種隨模型規模自然出現的能力(Wei等,2022b)。也就是說,對于小規模模型,鏈式思維提示并不會帶來性能提升,只有當模型規模達到約1000億參數時,才開始產生性能提升。我們通過定性分析發現,小規模模型雖然能夠生成流暢的鏈式思維文本,但邏輯不夠嚴密,反而導致表現低于標準提示法。
第二,鏈式思維提示在更復雜的問題上帶來更大的性能提升。例如,在GSM8K數據集(基線表現最低的數據集)上,最大的GPT和PaLM模型的性能提升超過一倍。相反,對于MAWPS中最簡單的SingleOp子集(只需要一步即可解決的問題),性能提升則要么是負面,要么非常小(詳見附錄表3)。
第三,使用GPT-3 175B和PaLM 540B的鏈式思維提示,在性能上優于以往通常通過在標注訓練集上微調任務專用模型所達到的最新水平。圖4顯示,PaLM 540B通過鏈式思維提示,在GSM8K、SVAMP和MAWPS上達到了新的最先進成績(盡管標準提示在SVAMP上已超過了之前的最好成績)。在另外兩個數據集AQuA和ASDiv上,PaLM使用鏈式思維提示的表現也接近最先進水平,差距不到2%(詳見附錄表2)。
為了更好地理解鏈式思維提示為何有效,我們對LaMDA 137B在GSM8K上的模型生成的思維鏈進行了人工檢查。在50個模型給出正確最終答案的隨機樣本中,除了兩例偶然得到正確答案的情況,所有生成的思維鏈在邏輯和數學上均正確(詳見附錄D.1及表8中的正確示例)。我們還隨機檢查了50個模型答錯的樣本,分析總結顯示,46%的思維鏈幾乎正確,僅存在一些小錯誤(如計算器錯誤、符號映射錯誤或缺少一步推理),而其余54%的思維鏈則存在語義理解或連貫性上的重大錯誤(詳見附錄D.2)。
為了探究模型規模為何能提升鏈式思維的推理能力,我們對PaLM 62B的錯誤進行了類似分析,并比較這些錯誤是否被擴展到PaLM 540B時修正。結果表明,將模型擴展到540B參數,修正了62B模型中大量缺失一步推理和語義理解的錯誤(詳見附錄A.1)。
3.3 消融研究
使用鏈式思維提示帶來的性能提升,引出了一個自然的問題:是否通過其他類型的提示也能帶來類似的性能提升?圖5展示了三個鏈式思維變體的消融實驗,具體如下。
僅輸出方程。鏈式思維提示之所以有效的一個原因可能是它生成了需要計算的數學方程,因此我們測試了一個變體,模型僅被提示先輸出數學方程,再給出答案。圖5顯示,僅輸出方程的提示對GSM8K幫助不大,這暗示GSM8K中問題的語義過于復雜,難以在沒有鏈式思維中自然語言推理步驟的情況下,直接轉換成方程。然而,對于只有一步或兩步問題的數據集,我們發現僅輸出方程的提示確實能提升性能,因為這些問題的方程可以很容易地從問題中推導出來(詳見附錄表6)。
變量計算。另一種直覺認為,鏈式思維讓模型能在更復雜的問題上投入更多計算(即生成更多中間推理步驟的標記)。為了將變量計算的效果與鏈式思維推理區分開,我們測試了一個變體:模型被提示輸出一串點(“: : :”),點的數量等于解題所需方程字符數。這個變體的表現與基線差不多,說明單純的變量計算并不是鏈式思維提示成功的原因,表明通過自然語言表達中間步驟確實帶來了額外價值。
答案后鏈式思維。鏈式思維提示的另一個潛在好處可能是幫助模型更好地訪問預訓練期間獲得的相關知識。為此,我們測試了另一種設置——鏈式思維提示僅在答案之后給出,以檢驗模型是否真正依賴鏈式思維過程來得出最終答案。這個變體的表現與基線相當,表明鏈式思維中體現的順序推理過程對于提升性能有重要作用,不只是激活已有知識那么簡單。
3.4 Chain of Thought 的魯棒性
示例的敏感性是提示方法中的一個關鍵考慮因素——例如,改變few-shot示例的排列順序,可能導致GPT-3在SST-2任務上的準確率從接近隨機的54.3%波動到接近最先進水平的93.4%(Zhao等,2021)。在本節的最后一小節中,我們評估了不同標注者編寫的鏈式思維的魯棒性。除了上述使用標注者A編寫的鏈式思維外,本文的另外兩位合著者(標注者B和C)也獨立為相同的few-shot示例編寫了鏈式思維(見附錄H)。標注者A還編寫了另一組比原始版本更簡潔的鏈式思維,風格參考了Cobbe等(2021)中給出的解題示例。
圖6展示了LaMDA 137B模型在GSM8K和MAWPS數據集上的實驗結果(其他數據集的消融結果見附錄表6和表7)。雖然不同鏈式思維注釋之間存在差異,這在使用基于示例的提示時是預料之中的(Le Scao和Rush,2021;Reynolds和McDonell,2021;Zhao等,2021),但所有鏈式思維提示集均遠超標準基線。這一結果表明,鏈式思維的成功應用不依賴于特定的語言風格。
溫馨提示:
閱讀全文請訪問"AI深語解構" Cot2:思維鏈提示激發大型語言模型的推理能力