Day08 【基于jieba分詞實現詞嵌入的文本多分類】

基于jieba分詞的文本多分類

      • 目標
      • 數據準備
      • 參數配置
      • 數據處理
      • 模型構建
      • 主程序
      • 測試與評估
      • 測試結果

目標

本文基于給定的詞表,將輸入的文本基于jieba分詞分割為若干個詞,然后將詞基于詞表進行初步編碼,之后經過網絡層,輸出在已知類別標簽上的概率分布,從而實現一個簡單文本的多分類。

數據準備

詞表文件chars.txt

類別標簽文件schema.json

{"停機保號": 0,"密碼重置": 1,"寬泛業務問題": 2,"親情號碼設置與修改": 3,"固話密碼修改": 4,"來電顯示開通": 5,"親情號碼查詢": 6,"密碼修改": 7,"無線套餐變更": 8,"月返費查詢": 9,"移動密碼修改": 10,"固定寬帶服務密碼修改": 11,"UIM反查手機號": 12,"有限寬帶障礙報修": 13,"暢聊套餐變更": 14,"呼叫轉移設置": 15,"短信套餐取消": 16,"套餐余量查詢": 17,"緊急停機": 18,"VIP密碼修改": 19,"移動密碼重置": 20,"彩信套餐變更": 21,"積分查詢": 22,"話費查詢": 23,"短信套餐開通立即生效": 24,"固話密碼重置": 25,"解掛失": 26,"掛失": 27,"無線寬帶密碼修改": 28
}

訓練集數據train.json訓練集數據

驗證集數據valid.json驗證集數據

參數配置

config.py

# -*- coding: utf-8 -*-"""
配置參數信息
"""Config = {"model_path": "model_output","schema_path": "../data/schema.json","train_data_path": "../data/train.json","valid_data_path": "../data/valid.json","vocab_path":"../chars.txt","max_length": 20,"hidden_size": 128,"epoch": 10,"batch_size": 32,"optimizer": "adam","learning_rate": 1e-3,
}

數據處理

loader.py

# -*- coding: utf-8 -*-import json
import re
import os
import torch
import random
import jieba
import numpy as np
from torch.utils.data import Dataset, DataLoader"""
數據加載
"""class DataGenerator:def __init__(self, data_path, config):self.config = configself.path = data_pathself.vocab = load_vocab(config["vocab_path"])self.config["vocab_size"] = len(self.vocab)self.schema = load_schema(config["schema_path"])self.config["class_num"] = len(self.schema)self.load()def load(self):self.data = []with open(self.path, encoding="utf8") as f:for line in f:line = json.loads(line)#加載訓練集if isinstance(line, dict):questions = line["questions"]label = line["target"]label_index = torch.LongTensor([self.schema[label]])for question in questions:input_id = self.encode_sentence(question)input_id = torch.LongTensor(input_id)self.data.append([input_id, label_index])else:assert isinstance(line, list)question, label = lineinput_id = self.encode_sentence(question)input_id = torch.LongTensor(input_id)label_index = torch.LongTensor([self.schema[label]])self.data.append([input_id, label_index])returndef encode_sentence(self, text):input_id = []if self.config["vocab_path"] == "words.txt":for word in jieba.cut(text):input_id.append(self.vocab.get(word, self.vocab["[UNK]"]))else:for char in text:input_id.append(self.vocab.get(char, self.vocab["[UNK]"]))input_id = self.padding(input_id)return input_id#補齊或截斷輸入的序列,使其可以在一個batch內運算def padding(self, input_id):input_id = input_id[:self.config["max_length"]]input_id += [0] * (self.config["max_length"] - len(input_id))return input_iddef __len__(self):return len(self.data)def __getitem__(self, index):return self.data[index]#加載字表或詞表
def load_vocab(vocab_path):token_dict = {}with open(vocab_path, encoding="utf8") as f:for index, line in enumerate(f):token = line.strip()token_dict[token] = index + 1  #0留給padding位置,所以從1開始return token_dict#加載schema
def load_schema(schema_path):with open(schema_path, encoding="utf8") as f:return json.loads(f.read())#用torch自帶的DataLoader類封裝數據
def load_data(data_path, config, shuffle=True):dg = DataGenerator(data_path, config)dl = DataLoader(dg, batch_size=config["batch_size"], shuffle=shuffle)return dlif __name__ == "__main__":from config import Configdg = DataGenerator("valid_tag_news.json", Config)print(dg[1])

主要實現一個自定義數據加載器 DataGenerator,用于加載和處理文本數據。它通過詞匯表和標簽映射將輸入文本轉化為索引序列,并進行補齊或截斷。

模型構建

model.py

# -*- coding: utf-8 -*-import torch
import torch.nn as nn
from torch.optim import Adam, SGD
"""
建立網絡模型結構
"""class TorchModel(nn.Module):def __init__(self, config):super(TorchModel, self).__init__()hidden_size = config["hidden_size"]vocab_size = config["vocab_size"] + 1max_length = config["max_length"]class_num = config["class_num"]self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)self.layer = nn.Linear(hidden_size, hidden_size)self.classify = nn.Linear(hidden_size, class_num)self.pool = nn.AvgPool1d(max_length)self.activation = torch.relu     #relu做激活函數self.dropout = nn.Dropout(0.1)self.loss = nn.functional.cross_entropy  #loss采用交叉熵損失#當輸入真實標簽,返回loss值;無真實標簽,返回預測值def forward(self, x, target=None):x = self.embedding(x)  #input shape:(batch_size, sen_len)x = self.layer(x)      #input shape:(batch_size, sen_len, input_dim)x = self.pool(x.transpose(1,2)).squeeze() #input shape:(batch_size, sen_len, input_dim)predict = self.classify(x)                #input shape:(batch_size, input_dim)if target is not None:return self.loss(predict, target.squeeze())else:return predictdef choose_optimizer(config, model):optimizer = config["optimizer"]learning_rate = config["learning_rate"]if optimizer == "adam":return Adam(model.parameters(), lr=learning_rate)elif optimizer == "sgd":return SGD(model.parameters(), lr=learning_rate)

定義了一個神經網絡模型 TorchModel,繼承自 nn.Module,用于文本分類任務。模型包括嵌入層、線性層、平均池化層和分類層,使用 ReLU 激活函數和 Dropout 防止過擬合。前向傳播根據輸入返回預測值或損失值(若提供標簽)。choose_optimizer 函數根據配置選擇 Adam 或 SGD 優化器,并設置學習率。模型通過交叉熵損失進行訓練。

主程序

main.py

# -*- coding: utf-8 -*-import torch
import os
import random
import os
import numpy as np
import loggingfrom config import Config
from model import TorchModel, choose_optimizer
from evaluate import Evaluator
from loader import load_data, load_schemalogging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)"""
模型訓練主程序
"""def main(config):#創建保存模型的目錄if not os.path.isdir(config["model_path"]):os.mkdir(config["model_path"])#加載訓練數據train_data = load_data(config["train_data_path"], config)#加載模型model = TorchModel(config)# 標識是否使用gpucuda_flag = torch.cuda.is_available()if cuda_flag:logger.info("gpu可以使用,遷移模型至gpu")model = model.cuda()#加載優化器optimizer = choose_optimizer(config, model)#加載效果測試類evaluator = Evaluator(config, model, logger)#訓練for epoch in range(config["epoch"]):epoch += 1model.train()logger.info("epoch %d begin" % epoch)train_loss = []for index, batch_data in enumerate(train_data):optimizer.zero_grad()if cuda_flag:batch_data = [d.cuda() for d in batch_data]input_id, labels = batch_data   #輸入變化時這里需要修改,比如多輸入,多輸出的情況loss = model(input_id, labels)train_loss.append(loss.item())if index % int(len(train_data) / 2) == 0:logger.info("batch loss %f" % loss)loss.backward()# print(loss.item())# print(model.classify.weight.grad)optimizer.step()logger.info("epoch average loss: %f" % np.mean(train_loss))evaluator.eval(epoch)model_path = os.path.join(config["model_path"], "epoch_%d.pth" % epoch)torch.save(model.state_dict(), model_path)return model, train_datadef ask(model, question):input_id = train_data.dataset.encode_sentence(question)model.eval()model = model.cpu()cls = torch.argmax(model(torch.LongTensor([input_id])))schemes = load_schema(Config["schema_path"])ans = ""for name, val in schemes.items():if val == cls:ans = namereturn ansif __name__ == "__main__":model, train_data = main(Config)print(ask(model, "積分是怎么積的"))while True:question = input("請輸入問題:")res = ask(model, question)print("命中問題:", res)print("-----------")

實現一個基于 PyTorch 的文本分類模型的訓練和推理過程。首先,通過 main 函數創建模型訓練的主流程。代碼首先檢查是否有 GPU 可用,并將模型遷移至 GPU(如果可用)。然后加載訓練數據、模型、優化器以及效果評估類。訓練過程中,模型使用交叉熵損失函數計算訓練誤差并進行反向傳播更新參數,每個 epoch 后記錄并輸出平均損失。同時,訓練結束后,將模型保存至指定路徑。

在訓練完成后,ask 函數用于推理,輸入問題并通過模型進行預測。它首先將輸入問題轉化為模型所需的格式,然后利用訓練好的模型進行分類,最后返回匹配的答案。整個程序支持通過命令行輸入問題,模型根據訓練結果給出對應的答案。

在主程序中,首先進行一次初始化訓練,之后進入循環,可以持續輸入問題并得到模型的預測答案。

測試與評估

evaluate.py

# -*- coding: utf-8 -*-
import torch
from loader import load_data"""
模型效果測試
"""class Evaluator:def __init__(self, config, model, logger):self.config = configself.model = modelself.logger = loggerself.valid_data = load_data(config["valid_data_path"], config, shuffle=False)self.stats_dict = {"correct":0, "wrong":0}  #用于存儲測試結果def eval(self, epoch):self.logger.info("開始測試第%d輪模型效果:" % epoch)self.stats_dict = {"correct":0, "wrong":0}  #清空前一輪的測試結果self.model.eval()for index, batch_data in enumerate(self.valid_data):if torch.cuda.is_available():batch_data = [d.cuda() for d in batch_data]input_id, labels = batch_data   #輸入變化時這里需要修改,比如多輸入,多輸出的情況with torch.no_grad():pred_results = self.model(input_id) #不輸入labels,使用模型當前參數進行預測self.write_stats(labels, pred_results)self.show_stats()returndef write_stats(self, labels, pred_results):assert len(labels) == len(pred_results)for true_label, pred_label in zip(labels, pred_results):pred_label = torch.argmax(pred_label)if int(true_label) == int(pred_label):self.stats_dict["correct"] += 1else:self.stats_dict["wrong"] += 1returndef show_stats(self):correct = self.stats_dict["correct"]wrong = self.stats_dict["wrong"]self.logger.info("預測集合條目總量:%d" % (correct +wrong))self.logger.info("預測正確條目:%d,預測錯誤條目:%d" % (correct, wrong))self.logger.info("預測準確率:%f" % (correct / (correct + wrong)))self.logger.info("--------------------")return

定義一個 Evaluator 類,用于評估深度學習模型在驗證集上的表現。Evaluator 初始化時接受配置文件、模型和日志記錄器,并加載驗證數據。eval 方法用于進行模型評估,在每輪評估開始時清空統計信息,設置模型為評估模式,然后通過遍歷驗證數據集進行預測。預測結果通過 write_stats 方法與真實標簽進行比對,統計正確和錯誤的預測條目。最后,show_stats 方法輸出總預測條目數、正確條目數、錯誤條目數以及準確率。該類的作用是幫助監控模型在驗證集上的性能,便于調整和優化模型。

測試結果

請輸入問題:在官網上如何修改移動密碼
命中問題: 移動密碼修改
-----------
請輸入問題:我想多加一個號碼作為親情號
命中問題: 親情號碼設置與修改
-----------
請輸入問題:我已經交足了話費請立即幫我開機
命中問題: 話費查詢
-----------
請輸入問題:密碼想換一下
命中問題: 密碼修改

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/901640.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/901640.shtml
英文地址,請注明出處:http://en.pswp.cn/news/901640.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

入門-C編程基礎部分:6、常量

飛書文檔https://x509p6c8to.feishu.cn/wiki/MnkLwEozRidtw6kyeW9cwClbnAg C 常量 常量是固定值,在程序執行期間不會改變,可以讓我們編程更加規范。 常量可以是任何的基本數據類型,比如整數常量、浮點常量、字符常量,或字符串字…

第二階段:數據結構與函數

模塊4:常用數據結構 (Organizing Lots of Data) 在前面的模塊中,我們學習了如何使用變量來存儲單個數據,比如一個數字、一個名字或一個布爾值。但很多時候,我們需要處理一組相關的數據,比如班級里所有學生的名字、一本…

【C++算法】61.字符串_最長公共前綴

文章目錄 題目鏈接:題目描述:解法C 算法代碼:解釋 題目鏈接: 14. 最長公共前綴 題目描述: 解法 解法一:兩兩比較 先算前兩個字符串的最長公共前綴,然后拿這個最長公共前綴和后面一個來比較&…

JVM 調優不再難:AI 工具自動生成內存優化方案

在 Java 應用程序的開發與運行過程中,Java 虛擬機(JVM)的性能調優一直是一項極具挑戰性的任務,尤其是內存優化方面。不合適的 JVM 內存配置可能會導致應用程序出現性能瓶頸,甚至頻繁拋出內存溢出異常,影響業…

紛析云開源財務軟件:企業財務數字化轉型的靈活解決方案

紛析云是一家專注于開源財務軟件研發的公司,自2018年成立以來,始終以“開源開放”為核心理念,致力于通過技術創新助力企業實現財務管理的數字化與智能化轉型。其開源財務軟件憑借高擴展性、靈活部署和全面的功能模塊,成為眾多企業…

【數字圖像處理】數字圖像空間域增強(3)

圖像銳化 圖像細節增強 圖像輪廓:灰度值陡然變化的部分 空間變化:計算灰度變化程度 圖像微分法:微分計算灰度梯度突變的速率 一階微分:單向差值 二階微分:雙向插值 一階微分濾波 1:梯度法 梯度&#xff1…

基于Linux的ffmpeg python的關鍵幀抽取

1.FFmpeg的環境配置 首先強調,ffmpeg-python包與ffmpeg包不一樣。 1) 創建一個虛擬環境env conda create -n yourenv python3.x conda activate yourenv2) ffmpeg-python包的安裝 pip install ffmpeg-python3) 安裝系統級別的 FFmpeg 工具 雖然安裝了 ffmpeg-p…

C#進階學習(四)單向鏈表和雙向鏈表,循環鏈表(上)單向鏈表

目錄 前置知識: 一、鏈表中的結點類LinkedNode 1、申明字段節點類: 2、申明屬性節點類: 二、兩種方式實現單向鏈表 ①定框架: ②增加元素的方法:因為是單鏈表,所以增加元素一定是只能在末尾添加元素,…

RK3588 Buildroot 串口測試工具

RK3588 Buildroot串口測試工具(含代碼) 一、引言 1.1 目的 本文檔旨在指導開發人員能快速測試串口功能 1.2 適用范圍 本文檔適用于linux 系統串口測試。 二、開發環境準備 2.1 硬件環境 開發板:RK3588開發板,確保其串口硬件連接正常,具備電源供應、調試串口等基本硬…

HOJ PZ

https://docs.hdoi.cn/deploy 單體部署 請到~/hoj-deploy/standAlone的目錄下,即是與docker-compose.yml的文件同個目錄下,該目錄下有個叫hoj的文件夾,里面的文件夾介紹如下: hoj ├── file # 存儲了上傳的圖片、上傳的臨…

EtherCAT 的優點與缺點

EtherCAT(以太網控制自動化技術)是一種高性能的工業以太網協議,廣泛應用于實時自動化控制。以下是其核心優缺點分析: ?一、EtherCAT 的核心優點? 1. ?超低延遲 & 高實時性? ?原理?:采用"?Processing…

高并發多級緩存架構實現思路

目錄 1.整體架構 3.安裝環境 1.1 使用docket安裝redis 1.2 配置redis緩存鏈接: 1.3 使用redisTemplate實現 1.4 緩存注解優化 1.4.1 常用緩存注解簡紹 1.4.2 EnableCaching注解的使用 1.4.3使用Cacheable 1.4.4CachePut注解的使用 1.4.5 優化 2.安裝Ngin…

Qt QML實現Windows桌面顏色提取器

前言 實現一個簡單的小工具,使用Qt QML實現Windows桌面顏色提取器,實時顯示鼠標移動位置的顏色值,包括十六進制值和RGB值。該功能在實際應用中比較常見,比如截圖的時候,鼠標移動就會在鼠標位置實時顯示坐標和顏色值&a…

vue3+vite 多個環境配置

同一套代碼 再也不用在不同的環境里來回切換請求地址了 然后踩了一個坑 就是env的文件路徑是在當前項目下 不是在views內 因為公司項目需求只有dev和pro兩個環境 雖然我新增了3個 但是只在這兩個里面配置了 .env是可以配置一些公共配置的 目前需求來說不需要 所以我也懶得配了。…

AI賦能PLC(一):三菱FX-3U編程實戰初級篇

前言 在工業自動化領域,三菱PLC以其高可靠性、靈活性和廣泛的應用場景,成為眾多工程師的首選控制設備。然而,傳統的PLC編程往往需要深厚的專業知識和經驗積累,開發周期長且調試復雜。隨著人工智能技術的快速發展,利用…

XSS 跨站Cookie 盜取表單劫持網絡釣魚溯源分析項目平臺框架

漏洞原理:接受輸入數據,輸出顯示數據后解析執行 基礎類型:反射 ( 非持續 ) ,存儲 ( 持續 ) , DOM-BASE 拓展類型: jquery , mxss , uxss , pdfxss , flashx…

鴻蒙應用(醫院診療系統)開發篇2·Axios網絡請求封裝全流程解析

一、項目初始化與環境準備 1. 創建鴻蒙工程 src/main/ets/ ├── api/ │ ├── api.ets # 接口聚合入口 │ ├── login.ets # 登錄模塊接口 │ └── request.ets # 網絡請求核心封裝 └── pages/ └── login.ets # 登錄頁面邏輯…

ADAS高級駕駛輔助系統詳細介紹

ADAS(高級駕駛輔助系統)核心模塊,通過 “監測→預警→干預” 三層邏輯提升行車安全。用戶選擇車輛時,可關注傳感器配置(如是否標配毫米波雷達)、功能覆蓋場景(如 AEB 是否支持夜間行人&#xff…

Prometheus+Grafana+K8s構建監控告警系統

一、技術介紹 Prometheus、Grafana及K8S服務發現詳解 Prometheus簡介 Prometheus是一個開源的監控系統和時間序列數據庫,最初由SoundCloud開發,現已成為CNCF(云原生計算基金會)的畢業項目?。它專注于實時監控和告警,特別適合云原生和分布式…

MATLAB腳本實現了一個三自由度的通用航空運載器(CAV-H)的軌跡仿真,主要用于模擬升力體在不同飛行階段(初始滑翔段、滑翔段、下壓段)的運動軌跡

%升力體:通用航空運載器CAV-H %讀取數據1 升力系數 alpha = [10 15 20]; Ma = [3.5 5 8 10 15 20 23]; alpha1 = 10:0.1:20; Ma1 = 3.5:0.1:23; [Ma1, alpha1] = meshgrid(Ma1, alpha1); CL = readmatrix(simulation.xlsx, Sheet, Sheet1, Range, B2:H4); CL1 = interp2(…