python pytorch tensorflow transforms 模型培訓腳本

環境準備
https://www.doubao.com/thread/w5e26d6401c003bb2

執行培訓腳本

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, AdamW
import numpy as np# 自定義數據集類
class SentimentDataset(Dataset):def __init__(self, texts, labels, tokenizer, max_length):self.texts = textsself.labels = labelsself.tokenizer = tokenizerself.max_length = max_lengthdef __len__(self):return len(self.texts)def __getitem__(self, idx):text = self.texts[idx]label = self.labels[idx]encoding = self.tokenizer.encode_plus(text,add_special_tokens=True,max_length=self.max_length,padding='max_length',truncation=True,return_tensors='pt')return {'input_ids': encoding['input_ids'].flatten(),'attention_mask': encoding['attention_mask'].flatten(),'labels': torch.tensor(label, dtype=torch.long)}# 新增問答數據
qa_texts = ["李白是那個朝代的詩人?", "地球的衛星是什么?", "中國的首都是哪里?", "水的化學式是什么?", "蘋果公司的創始人是誰?"]
qa_labels = [1, 2, 3, 4, 5]  # 為每個問題分配一個唯一的標簽# 合并數據
all_texts = qa_texts
all_labels = qa_labels# 初始化 tokenizer
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')# 創建數據集和數據加載器
dataset = SentimentDataset(all_texts, all_labels, tokenizer, max_length=128)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)# 初始化模型,注意 num_labels 需要根據總標簽數調整
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=8)# 定義優化器,加入 weight_decay 進行正則化
optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)# 訓練模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)# 早停策略相關參數
best_loss = float('inf')
patience = 3
counter = 0for epoch in range(20):  # 增加訓練輪數,但結合早停策略total_loss = 0model.train()for batch in dataloader:input_ids = batch['input_ids'].to(device)attention_mask = batch['attention_mask'].to(device)labels = batch['labels'].to(device)optimizer.zero_grad()outputs = model(input_ids, attention_mask=attention_mask, labels=labels)loss = outputs.lossloss.backward()optimizer.step()total_loss += loss.item()avg_loss = total_loss / len(dataloader)print(f'Epoch {epoch + 1}, Loss: {avg_loss}')# 早停策略if avg_loss < best_loss:best_loss = avg_losscounter = 0else:counter += 1if counter >= patience:print("Early stopping triggered!")break# 處理提問的函數
def handle_query(model, tokenizer, query_text, device):model.eval()with torch.no_grad():encoding = tokenizer(query_text, return_tensors='pt', padding=True, truncation=True, max_length=128).to(device)logits = model(**encoding).logitspredictions = np.argmax(logits.cpu().numpy(), axis=1)return predictions# 情感標簽映射
label_mapping = {1: "唐朝",2: "月球",3: "北京",4: "H2O",5: "史蒂夫·喬布斯、史蒂夫·沃茲尼亞克和羅恩·韋恩"
}# 針對新增的5條數據提問
query_text = ["李白是那個朝代的詩人?", "地球的衛星是什么?", "中國的首都是哪里?", "水的化學式是什么?", "蘋果公司的創始人是誰?"]
result = handle_query(model, tokenizer, query_text, device)
readable_result = [label_mapping[pred] for pred in result]
print("Query result:", readable_result)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/72755.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/72755.shtml
英文地址,請注明出處:http://en.pswp.cn/web/72755.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

request庫基礎學習

requests安裝 Windows &#xff1a;pip install requests mac &#xff1a; python3 -m pip install requests requests模塊常用方法 方法含義requests.get()發起get請求requests.post()發起post請求requests.put()發起put請求requests.delete()發起delete請求requests.sess…

Redis客戶端Jedis、Lettuce 和 Redisson優缺點總結

https://developer.huawei.com/consumer/cn/blog/topic/03825550899620047 Redis 官方推薦的 Java 客戶端有Jedis、Lettuce 和 Redisson。本文總結這些客服端的優缺點 1. Jedis Jedis 是老牌的 Redis 的 Java 實現客戶端&#xff0c;提供了比較全面的 Redis 命令的支持&#…

在 Spring Boot 中調用 AnythingLLM 的發消息接口

整體邏輯: 自建系統的web UI界面調用接口: 1.SpringBoot接口&#xff1a;/anything/chatMessageAnything 2.調用anythingLLM - 調用知識庫deepseek r1 . Windows Installation ~ AnythingLLMhttps://docs.anythingllm.com/installation-desktop/windows http://localhost:3…

kubectl describe pod 命令以及輸出詳情講解

kubectl describe pod 命令格式 kubectl describe pod <pod-name> -n <namespace><pod-name>&#xff1a;Pod 的名稱。 -n <namespace>&#xff1a;指定命名空間&#xff0c;默認是當前命名空間。 controlplane ~ ? kubectl describe pod newpods-d…

Python生成和安裝requirements.txt

概述 看到別的大佬項目中&#xff0c;requirements.txt文件&#xff0c;里面包含了所需要的依賴及版本&#xff0c;方便項目管理和安裝。 生成 requirements.txt 文件 pip3 freeze > requirements.txt生成的依賴包有點多&#xff0c;感覺可以根據自己需要整理。 安裝req…

WebGL學習2

WebGL&#xff08;Web Graphics Library&#xff09;是一種基于 OpenGL ES 2.0 的 JavaScript API&#xff0c;用于在網頁上實現高性能的 3D 圖形渲染。 1. 初始化 WebGL 上下文 在使用 WebGL 之前&#xff0c;需要獲取<canvas>元素并創建 WebGL 上下文。 // 獲取canv…

零知識證明:區塊鏈隱私保護的變革力量

&#x1f9d1; 博主簡介&#xff1a;CSDN博客專家&#xff0c;歷代文學網&#xff08;PC端可以訪問&#xff1a;https://literature.sinhy.com/#/literature?__c1000&#xff0c;移動端可微信小程序搜索“歷代文學”&#xff09;總架構師&#xff0c;15年工作經驗&#xff0c;…

【java】集合的基本使用

集合是 Java 中用來存儲一組對象的容器。與數組相比&#xff0c;集合更加靈活和強大&#xff0c;支持動態增刪元素、自動擴容、多種數據結構等特性。下面我會用通俗易懂的語言解釋集合的基本使用。 1. 什么是集合&#xff1f; 集合就像是一個“容器”&#xff0c;可以用來裝很多…

WPF-實現按鈕的動態變化

MVVM 模式基礎 視圖模型&#xff08;ViewModel&#xff09;&#xff1a;MainViewModel類作為視圖模型&#xff0c;封裝了與視圖相關的屬性和命令。它實現了INotifyPropertyChanged接口&#xff0c;當屬性值發生改變時&#xff0c;通過OnPropertyChanged方法通知視圖進行更新&am…

主流NoSQL數據庫類型及選型分析

在數據庫領域&#xff0c;不同類型的數據庫針對不同場景設計&#xff0c;以下是四類主流NoSQL數據庫的對比分析&#xff1a; 一、核心特性對比 鍵值數據庫&#xff08;Key-Value&#xff09; 數據模型&#xff1a;簡單鍵值對存儲 特點&#xff1a;毫秒級讀寫、高并發、無固定…

西門子PLC

西門子PLC與C#通信全解析&#xff1a;從協議選型到實戰開發 一、西門子PLC通信協議概述 西門子PLC支持多種通信協議&#xff0c;需根據設備型號及項目需求選擇&#xff1a; S7協議 西門子私有協議&#xff0c;適用于S7-200/300/400/1200/1500系列PLC特點&#xff1a;直接訪問…

Visual Studio(VS)的 Release 配置中生成程序數據庫(PDB)文件

最近工作中的一個測試工具在測試多臺設備上使用過程中閃退&#xff0c;存了dump&#xff0c;但因為是release版本&#xff0c;沒有pdb&#xff0c;無法根據dump定位代碼哪塊出了問題&#xff0c;很苦惱&#xff0c;查了下怎么加pdb生成&#xff0c;記錄一下。以下是具體的設置步…

★ Linux ★ 進程(上)

Ciallo&#xff5e;(∠?ω< )⌒☆ ~ 今天&#xff0c;我將和大家一起學習 linux 進程~ ????????????????????????????? 澄嵐主頁&#xff1a;椎名澄嵐-CSDN博客 Linux專欄&#xff1a;https://blog.csdn.net/2302_80328146/category_12815302…

JAVA并發-volatile底層原理

volatile相當于是一個輕量級的synchronized&#xff0c;一般作用在變量上&#xff0c;它具有三個特性&#xff1a;可見性、有序性&#xff0c;相比于synchronized&#xff0c;他的執行成本更低。 先來說可見性&#xff0c;java存在共享變量不可見性的原因就是&#xff0c;線程…

Java面試第十一山!《SpringCloud框架》

大家好&#xff0c;我是陳一。如果文章對你有幫助&#xff0c;請留下一個寶貴的三連哦&#xff5e; 萬分感謝&#xff01; 目錄 一、Spring Cloud 是什么? 二、Spring Cloud 核心組件? 1. 服務發現 - Eureka? 2. ?負載均衡 - Ribbon? 3. 斷路器 - Hystrix? ??4. …

Transaction rolled back because it has been marked as rollback-only問題解決

transaction rolled back because it has been marked as rollback-only 簡略總結> 發生場景&#xff1a;try-catch多業務場景 發生原因&#xff1a;業務嵌套&#xff0c;事務管理混亂&#xff0c;外層業務與內層業務拋出異常節點與回滾節點不一致。 解決方式&#xff1a;修…

sql server數據遷移,springboot搭建開發環境遇到的問題及解決方案

最近搭建springboot項目開發環境&#xff0c;數據庫連的是sql server&#xff0c;遇到許多問題在此記錄一下。 1、sql server安裝教程 參考&#xff1a;https://www.bilibili.com/opus/944736210624970769 2、sql server導出、導入數據庫 參考&#xff1a;https://blog.csd…

【數學建模】灰色關聯分析模型詳解與應用

灰色關聯分析模型詳解與應用 文章目錄 灰色關聯分析模型詳解與應用引言灰色系統理論簡介灰色關聯分析基本原理灰色關聯分析計算步驟1. 確定分析序列2. 數據無量綱化處理3. 計算關聯系數4. 計算關聯度 灰色關聯分析應用實例實例&#xff1a;某企業生產效率影響因素分析 灰色關聯…

Spring配置文件-Bean實例化三種方式

無參構造方法實例化 工廠靜態方法實例化 工廠實例方法實例化

SSL 和 TLS 認證

SSL&#xff08;Secure Sockets Layer&#xff0c;安全套接層&#xff09;認證是一種用于加密網絡通信和驗證服務器身份的安全技術。它是TLS&#xff08;Transport Layer Security&#xff0c;傳輸層安全協議&#xff09;的前身&#xff0c;雖然現在大多數應用使用的是TLS&…