昇思25天學習打卡營第7天 | 基于MindSpore的GPT2文本摘要

本次打卡基于gpt2的文本摘要

數據加載及預處理

from mindnlp.utils import http_get# download dataset
url = 'https://download.mindspore.cn/toolkits/mindnlp/dataset/text_generation/nlpcc2017/train_with_summ.txt'
path = http_get(url, './')from mindspore.dataset import TextFileDataset# load dataset
dataset = TextFileDataset(str(path), shuffle=False)
dataset.get_dataset_size()# split into training and testing dataset
train_dataset, test_dataset = dataset.split([0.9, 0.1], randomize=False)import json
import numpy as np# preprocess dataset
def process_dataset(dataset, tokenizer, batch_size=6, max_seq_len=1024, shuffle=False):def read_map(text):data = json.loads(text.tobytes())return np.array(data['article']), np.array(data['summarization'])def merge_and_pad(article, summary):# tokenization# pad to max_seq_length, only truncate the articletokenized = tokenizer(text=article, text_pair=summary,padding='max_length', truncation='only_first', max_length=max_seq_len)return tokenized['input_ids'], tokenized['input_ids']dataset = dataset.map(read_map, 'text', ['article', 'summary'])# change column names to input_ids and labels for the following trainingdataset = dataset.map(merge_and_pad, ['article', 'summary'], ['input_ids', 'labels'])dataset = dataset.batch(batch_size)if shuffle:dataset = dataset.shuffle(batch_size)return datasetfrom mindnlp.transformers import BertTokenizer# We use BertTokenizer for tokenizing chinese context.
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
len(tokenizer)

?模型構建?

from mindspore import ops
from mindnlp.transformers import GPT2LMHeadModelclass GPT2ForSummarization(GPT2LMHeadModel):def construct(self,input_ids = None,attention_mask = None,labels = None,):outputs = super().construct(input_ids=input_ids, attention_mask=attention_mask)shift_logits = outputs.logits[..., :-1, :]shift_labels = labels[..., 1:]# Flatten the tokensloss = ops.cross_entropy(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1), ignore_index=tokenizer.pad_token_id)return lossfrom mindspore import ops
from mindspore.nn.learning_rate_schedule import LearningRateScheduleclass LinearWithWarmUp(LearningRateSchedule):"""Warmup-decay learning rate."""def __init__(self, learning_rate, num_warmup_steps, num_training_steps):super().__init__()self.learning_rate = learning_rateself.num_warmup_steps = num_warmup_stepsself.num_training_steps = num_training_stepsdef construct(self, global_step):if global_step < self.num_warmup_steps:return global_step / float(max(1, self.num_warmup_steps)) * self.learning_ratereturn ops.maximum(0.0, (self.num_training_steps - global_step) / (max(1, self.num_training_steps - self.num_warmup_steps))) * self.learning_ratenum_epochs = 1
warmup_steps = 2000
learning_rate = 1.5e-4num_training_steps = num_epochs * train_dataset.get_dataset_size()from mindspore import nn
from mindnlp.transformers import GPT2Config, GPT2LMHeadModelconfig = GPT2Config(vocab_size=len(tokenizer))
model = GPT2ForSummarization(config)lr_scheduler = LinearWithWarmUp(learning_rate=learning_rate, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps)
optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=lr_scheduler)# 記錄模型參數數量
print('number of model parameters: {}'.format(model.num_parameters()))from mindnlp._legacy.engine import Trainer
from mindnlp._legacy.engine.callbacks import CheckpointCallbackckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt2_summarization',epochs=1, keep_checkpoint_max=2)trainer = Trainer(network=model, train_dataset=train_dataset,epochs=1, optimizer=optimizer, callbacks=ckpoint_cb)
trainer.set_amp(level='O1')  # 開啟混合精度trainer.run(tgt_columns="labels")

?

?結論

gpt2相較bert等模型,在文本識別、文本摘要、命名體識別中有著優秀的表現,但其模型規模相對較大,訓練時間較長,打卡中展示的沒有完成訓練,這里需要更好的gpu來輔助訓練。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/45449.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/45449.shtml
英文地址,請注明出處:http://en.pswp.cn/web/45449.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

以太坊(以太坊solidity合約)

以太坊&#xff08;以太坊solidity合約&#xff09; 1&#xff0c;以太坊2&#xff0c;開發名詞解釋&#xff08;1&#xff09;錢包&#xff08;2&#xff09;Solidity&#xff08;3&#xff09;Ether&#xff08;以太幣&#xff09;&#xff08;4&#xff09;Truffle&#xff…

Redis 7.x 系列【23】哨兵模式

有道無術&#xff0c;術尚可求&#xff0c;有術無道&#xff0c;止于術。 本系列Redis 版本 7.2.5 源碼地址&#xff1a;https://gitee.com/pearl-organization/study-redis-demo 文章目錄 1. 概述2. 工作原理2.1 監控2.2 標記下線2.3 哨兵領袖2.4 新的主節點2.5 通知更新 3. …

請求響應(后端必備)

一、請求 1.簡單參數 原始方式&#xff1a; 在原始的web程序中&#xff0c;獲取請求參數&#xff0c;需要通過HttpServletRequest對象手動獲取 RequestMapping("/simpleParam")public String simpleParam(HttpServletRequest request){String name request.getP…

什么叫價內期權?直接帶你了解期權價內期權怎么使用?!

今天帶你了解什么叫價內期權&#xff1f;直接帶你了解期權價內期權怎么使用&#xff1f;&#xff01;價內期權是具有內在價值的期權。期權持有人行權時&#xff0c;對看漲期權而言&#xff0c;行權價格低于標的證券結算價格&#xff1b;對看跌期權而言&#xff0c;標的證券結算…

js 請求blob:https:// 圖片

方式1 def get_file_content_chrome(driver, uri):result driver.execute_async_script("""var uri arguments[0];var callback arguments[1];var toBase64 function(buffer){for(var r,nnew Uint8Array(buffer),tn.length,anew Uint8Array(4*Math.ceil(t/…

前端Vue組件化實踐:自定義加載組件的探索與應用

在前端開發領域&#xff0c;隨著業務邏輯復雜度的提升和系統規模的不斷擴大&#xff0c;傳統的開發方式逐漸暴露出效率低下、維護困難等問題。為了解決這些挑戰&#xff0c;組件化開發作為一種高效、靈活的開發模式&#xff0c;受到了越來越多開發者的青睞。本文將結合實踐&…

Java基礎及進階

JAVA特性 基礎語法 一、Java程序的命令行工具 二、final、finally、finalize 三、繼承 class 父類 { //代碼 }class 子類 extends 父類 { //代碼 }四、Vector、ArrayList、LinkedList 五、原始數據類型和包裝類 六、接口和抽象類 JAVA進階 Java引用隊列 Object counter ne…

PostgreSQL行級安全策略探究

前言 最近和朋友討論oracle行級安全策略(VPD)時&#xff0c;查看了下官方文檔&#xff0c;看起來VPD的原理是針對應用了Oracle行級安全策略的表、視圖或同義詞發出的 SQL 語句動態添加where子句。通俗理解就是將行級安全策略動態添加為where 條件。那么PG中的行級安全策略是怎…

搭建基于 ChatGPT 的問答系統

搭建基于 ChatGPT 的問答系統 &#x1f4e3;1.簡介&#x1f4e3;2.語言模型&#xff0c;提問范式和 token?2.1語言模型?2.2Tokens?2.3Helper function輔助函數&#xff08;提問范式&#xff09; &#x1f4e3;3.評估輸入-分類&#x1f4e3;4.檢查輸入-審核?4.1審核4.1.1 我…

使用UDP通信接收與發送Mavlink2.0協議心跳包完整示例

1.克隆mavlink源碼 https://github.com/mavlink/mavlink.git 2.進入mavlink目錄,安裝依賴 python3 -m pip install -r pymavlink/requirements.txt 3.生成Mavlink的C頭文件 mavlink % python3 -m pymavlink.tools.mavgen --lang=C --wire-protocol=2.0 --output=generated…

1-5歲幼兒胼胝體的表面形態測量

摘要 胼胝體(CC)是大腦中的一個大型白質纖維束&#xff0c;它參與各種認知、感覺和運動過程。盡管CC與多種發育和精神疾病有關&#xff0c;但關于這一結構的正常發育(特別是在幼兒階段)還有很多待解開的謎團。雖然早期文獻中報道了性別二態性&#xff0c;但這些研究的觀察結果…

【Linux網絡】select{理解認識select/select與多線程多進程/認識select函數/使用select開發并發echo服務器}

文章目錄 0.理解/認識回顧回調函數select/pollread與直接使用 read 的效率差異 1.認識selectselect/多線程&#xff08;Multi-threading&#xff09;/多進程&#xff08;Multi-processing&#xff09;select函數socket就緒條件select的特點總結 2.select下echo服務器封裝套接字…

C++ 類和對象 賦值運算符重載

前言&#xff1a; 在上文我們知道數據類型分為自定義類型和內置類型&#xff0c;當我想用內置類型比較大小是非常容易的但是在C中成員變量都是在類(自定義類型)里面的&#xff0c;那我想給類比較大小那該怎么辦呢&#xff1f;這時候運算符重載就出現了 一 運算符重載概念&…

安全防御:防火墻基本模塊

目錄 一、接口 1.1 物理接口 1.2 虛擬接口 二、區域 三、模式 3.1 路由模式 3.2 透明模式 3.3 旁路檢測模式 3.4 混合模式 四、安全策略 五、防火墻的狀態檢測和會話表技術 一、接口 1.1 物理接口 三層口 --- 可以配置IP地址的接口 二層口&#xff1a; 普通二層…

Java面試題:分庫分表

分庫分表 當數據量非常大時,就需要通過分庫分表的方式進行壓力分攤,避免數據庫訪問壓力過大 分庫分表的前提: 業務數據達到一定量級:單表數據量達到1000w或20g 優化解決不了性能問題 分庫分表策略 垂直拆分 垂直分庫 以表為依據,根據業務將不同表拆分到不同庫中 eg:根…

車載終端_RTK定位|4路攝像頭|駕駛輔助系統ADAS定制方案

現代車輛管理行業的發展趨勢逐漸向智能化和高效化方向發展&#xff0c;車載終端成為關鍵的工具之一。在這個背景下&#xff0c;一款特別為車隊管理行業設計的車載終端應運而生。該車載終端采用8寸多點觸控電容屏&#xff0c;搭載聯發科四核處理器&#xff0c;主頻2.0GHz&#x…

如何安裝node.js

Node.js Node.js 是一個基于 Chrome V8 引擎的 JavaScript 運行時環境。 主要特點和優勢&#xff1a; 非阻塞 I/O 和事件驅動&#xff1a;能夠高效處理大量并發連接&#xff0c;非常適合構建高并發的網絡應用&#xff0c;如 Web 服務器、實時聊天應用等。 例如&#xff0c;在…

FeignClient詳解

FeignClient 是 Spring Cloud Open Feign 中的一個注解&#xff0c;它用于定義一個 Feign 客戶端&#xff0c;Feign 是一個聲明式的 Web 服務客戶端&#xff0c;使得編寫 Web 服務客戶端變得更加簡單。以下是 FeignClient 注解的詳細說明&#xff1a; 定義 Feign 客戶端&#x…

網絡安全——防御(防火墻)帶寬以及雙機熱備實驗

12&#xff0c;對現有網絡進行改造升級&#xff0c;將當個防火墻組網改成雙機熱備的組網形式&#xff0c;做負載分擔模式&#xff0c;游客區和DMZ區走FW3&#xff0c;生產區和辦公區的流量走FW1 13&#xff0c;辦公區上網用戶限制流量不超過100M&#xff0c;其中銷售部人員在其…