【NLP 38、實踐 ⑩ NER 命名實體識別任務 Bert 實現】

去做具體的事,然后穩穩托舉自己

????????????????????????????????????????—— 25.3.17

數據文件:

通過網盤分享的文件:Ner命名實體識別任務
鏈接: https://pan.baidu.com/s/1fUiin2um4PCS5i91V9dJFA?pwd=yc6u 提取碼: yc6u?
--來自百度網盤超級會員v3的分享

一、配置文件 config.py

1.模型與數據路徑

model_path:模型訓練完成后保存的位置。例如:保存最終的模型權重文件。

schema_path:數據結構定義文件,通常用于描述數據的格式(如字段名、標簽類型)。
在NER任務中,可能定義實體類別(如?{"PERSON": "人名", "ORG": "組織"})。

train_data_path:訓練數據集路徑,通常為標注好的文本文件(如?train.txt?或?JSON?格式)。

valid_data_path:?驗證數據集路徑,用于模型訓練時的性能評估和超參數調優。

vocab_path:?字符詞匯表文件,記錄模型中使用的字符集(如中文字符、字母、數字等)。


2.模型架構

max_length:輸入文本的最大序列長度。超過此長度的文本會被截斷或填充(如用?[PAD])。

hidden_size:模型隱藏層神經元的數量,影響模型容量和計算復雜度。

num_layers:模型的堆疊層數(如LSTM、Transformer的編碼器/解碼器層數)。

class_num:?任務類別總數。例如:NER任務中可能有9種實體類型。

vocab_size:詞表大小


3.訓練配置

epoch:訓練輪數。每輪遍歷整個訓練數據集一次。

batch_size:每次梯度更新所使用的樣本數量。較小的批次可能更適合內存受限的環境。

optimizer:優化器類型,用于調整模型參數。Adam是常用優化器,結合動量梯度下降。

learning_rate:學習率,控制參數更新的步長。值過小可能導致訓練緩慢,過大易過擬合。

use_crf:是否啟用條件隨機場(CRF)?層。在序列標注任務(如NER)中,CRF可捕捉標簽間的依賴關系,提升準確性。


4.預訓練模型

bert_path:?預訓練BERT模型的路徑。BERT是一種強大的預訓練語言模型,此處可能用于微調或特征提取。

# -*- coding: utf-8 -*-"""
配置參數信息
"""Config = {"model_path": "model_output","schema_path": "ner_data/schema.json","train_data_path": "ner_data/train","valid_data_path": "ner_data/test","vocab_path":"chars.txt","max_length": 100,"hidden_size": 256,"num_layers": 2,"epoch": 20,"batch_size": 16,"optimizer": "adam","learning_rate": 1e-3,"use_crf": False,"class_num": 9,"bert_path": r"F:\人工智能NLP/NLP資料\week6 語言模型/bert-base-chinese","vocab_size": 20000
}

二、數據加載 loader.py

1.初始化數據加載類

def __init__(self, data_path, config):構造函數接收數據路徑和配置對象。

data_path:數據文件存儲路徑

config:包含訓練 / 數據配置的字典

self.config:保存包含訓練 / 數據配置的字典

self.path:保存數據文件存儲路徑

self.tokenizer:將文本數據轉換為深度學習模型(如 BERT)可處理的輸入格式的核心工具

self.sentences:初始化句子列表

self.schema:加載實體標簽與索引的映射關系表

self.load:調用?load()?方法從?data_path?加載原始數據,進行分詞、編碼、填充/截斷等預處理。

    def __init__(self, data_path, config):self.config = configself.path = data_pathself.tokenizer = load_vocab(config["bert_path"])self.sentences = []self.schema = self.load_schema(config["schema_path"])self.load()

2.加載數據并預處理

①?初始化數據容器 ——>

②?文件讀取與分段處理 ——>

③ 逐段解析字符與標簽 ——>

④ 句子編碼與填充 ——>

⑤ 數據封裝與返回?

self.path:數據文件的存儲路徑(如?train.txt),由類初始化時傳入的?data_path?參數賦值。

f:文件對象,用于讀取?self.path?指向的原始數據文件。

segments:是按雙換行符分隔的段落列表,每個段落對應一個樣本(如一個句子及其標注序列)。

segment:遍歷?segments?時的單個樣本段落,進一步按行分割處理為字符和標簽

labels:存儲當前樣本的標簽序列,[8]可能表示?[CLS]?標記的 ID,用于序列起始符,之后將每個字符的標簽轉換為ID。

char:當前行的字符(如?"中"),屬于句子中的一個基本單元。

lable:當前行的原始標簽字符串(如?"B-LOC"),?尚未映射為 ID

input_ids:將字符序列編碼為模型輸入所需的 ID 序列(如 BERT 分詞后的 Token ID)

self.data:列表,存儲預處理后的數據樣本,每個樣本由輸入張量和標簽張量組成

sentence:由字符列表拼接而成的完整句子(如?"中國科技大學"),存入?self.sentences?供后續可視化或調試。

open():打開文件并返回文件對象,支持讀/寫/追加等模式。

參數名類型說明
file字符串文件路徑(絕對/相對路徑)
mode字符串打開模式(如?r-只讀、w-寫入、a-追加)
encoding字符串文件編碼(如?utf-8,文本模式需指定)
errors字符串編碼錯誤處理方式(如?ignorereplace

文件對象.read():讀取文件內容,返回字符串或字節流

參數名類型說明
size整數可選,指定讀取的字節數(默認讀取全部內容)

split():按分隔符分割字符串,返回子字符串列表

參數名類型說明
delimiter字符串分隔符(默認空格)
maxsplit整數可選,最大分割次數(默認-1表示全部)

strip():去除字符串首尾指定字符(默認空白字符)

參數名類型說明
chars字符串可選,指定需去除的字符集合

join():用分隔符連接可迭代對象的元素,返回新字符串

參數名類型說明
iterable可迭代對象需連接的元素集合(如列表、元組)
sep字符串分隔符(默認空字符串)

列表.append():在列表末尾添加元素

參數名類型說明
obj任意類型要添加的元素
    def load(self):self.data = []with open(self.path, encoding="utf8") as f:segments = f.read().split("\n\n")for segment in segments:sentence = []labels = [8]  # cls_tokenfor line in segment.split("\n"):if line.strip() == "":continuechar, label = line.split()sentence.append(char)labels.append(self.schema[label])sentence = "".join(sentence)self.sentences.append(sentence)input_ids = self.encode_sentence(sentence)labels = self.padding(labels, -1)# print(self.decode(sentence, labels))# input()self.data.append([torch.LongTensor(input_ids), torch.LongTensor(labels)])return

3.加載字 / 詞表

vocab_path:字 / 詞表的存儲路徑

BertTokenizer.from_pretrained():Hugging Face Transformers 庫中用于加載預訓練 BERT 分詞器的核心方法。它支持從 Hugging Face 模型庫或本地路徑加載預訓練的分詞器,并允許通過參數調整分詞行為。

參數名?類型?默認值?說明
pretrained_model_name_or_pathstr必填預訓練模型名稱(如?bert-base-uncased)或本地路徑。若為名稱,自動從 Hugging Face 下載
cache_dirstrNone模型緩存目錄。若指定,下載的模型文件會存儲在此路徑下
force_downloadboolFalse是否強制重新下載模型,即使本地已緩存
resume_downloadboolFalse是否斷點續傳下載任務
do_lower_caseboolTrue(英文模型)是否將文本轉為小寫。?中文模型需注意:若設為?False,可能導致英文單詞被識別為?[UNK]
add_special_tokensboolTrue是否在輸入文本中添加?[CLS]?和?[SEP]?等特殊標記
tokenize_chinese_charsboolTrue是否對中文字符進行逐字分詞(如將“你好”拆分為“你”和“好”)
strip_accentsboolNone是否去除重音符號(如將?é?轉換為?e
use_fastboolTrue是否啟用快速分詞模式(基于 Rust 實現,速度更快)
def load_vocab(vocab_path):return BertTokenizer.from_pretrained(vocab_path)

4.加載映射關系表?

????????加載位于指定路徑的 JSON 格式的模式文件,并將其內容解析為 Python 對象以便在數據生成過程中使用。

path:映射關系表schema的存儲路徑

open():打開文件并返回文件對象,用于讀寫文件內容。

參數名類型默認值說明
file_namestr文件路徑(需包含擴展名)
modestr'r'文件打開模式:
-?'r': 只讀
-?'w': 只寫(覆蓋原文件)
-?'a': 追加寫入
-?'b': 二進制模式
-?'x': 創建新文件(若存在則報錯)
bufferingintNone緩沖區大小(僅二進制模式有效)
encodingstrNone文件編碼(僅文本模式有效,如?'utf-8'
newlinestr'\n'行結束符(僅文本模式有效)
closefdboolTrue是否在文件關閉時自動關閉文件描述符
dir_fdint-1文件描述符(高級用法,通常忽略)
flagsint0Linux 系統下的額外標志位
modestr(重復參數,實際使用中只需指定?mode

json.load():從已打開的 JSON 文件對象中加載數據,并將其轉換為 Python 對象(如字典、列表)。

參數名類型默認值說明
fpio.TextIO已打開的文件對象(需處于讀取模式)
indentint/strNone縮進空格數(美化輸出,如?4?或?" "
sort_keysboolFalse是否對 JSON 鍵進行排序
load_hookcallableNone自定義對象加載回調函數
object_hookcallableNone自定義對象解析回調函數
    def load_schema(self, path):with open(path, encoding="utf8") as f:return json.load(f)

5.封裝數據

Ⅰ、初始化DataGenerator:初始化DataGenerator實例dg,傳入data_path和config

Ⅱ、創建?DataLoader?對象:創建DataLoader實例dl,使用dg、batch_size和shuffle參數

Ⅲ、返回?DataLoader?迭代器:返回dl

data_path:數據文件的路徑(如?train.txt),用于初始化?DataGenerator,指向原始數據文件。

config:配置參數字典,通常包含?batch_sizebert_pathschema_path?等參數,用于控制數據加載邏輯。

dg:自定義數據集對象,繼承?torch.utils.data.Dataset,負責數據加載、預處理和樣本生成。

dl:封裝?DataGenerator?的迭代器,實現批量加載、多進程加速等功能,直接用于模型訓練。

DataLoader():PyTorch 模型訓練的標配工具,通過合理的參數配置(如?batch_sizenum_workersshuffle),可以顯著提升數據加載效率,尤其適用于大規模數據集和復雜預處理任務。其與?Dataset?類的配合使用,是構建高效訓練管道的核心。

參數名類型默認值說明
datasetDatasetNone必須參數,自定義數據集對象(需繼承?torch.utils.data.Dataset)。
batch_sizeint1每個批次的樣本數量。
shuffleboolFalse是否在每個 epoch 開始時打亂數據順序(訓練時推薦設為?True)。
num_workersint0使用多線程加載數據的工人數量(需大于 0 時生效)。
pin_memoryboolFalse是否將數據存儲在 pinned memory 中(加速 GPU 數據傳輸)。
drop_lastboolFalse如果數據集長度無法被?batch_size?整除,是否丟棄最后一個不完整的批次。
persistent_workersboolFalse是否保持工作線程在 epoch 之間持續運行(減少多線程初始化開銷)。
worker_init_fncallableNone自定義工作線程初始化函數。
# 用torch自帶的DataLoader類封裝數據
def load_data(data_path, config, shuffle=True):dg = DataGenerator(data_path, config)dl = DataLoader(dg, batch_size=config["batch_size"], shuffle=shuffle)return dl

6.對于輸入文本做截斷 / 填充

Ⅰ、截斷過長序列?(超過預設最大長度)

Ⅱ、填充過短序列?(用?pad_token?補齊到預設最大長度)

    #補齊或截斷輸入的序列,使其可以在一個batch內運算def padding(self, input_id, pad_token=0):input_id = input_id[:self.config["max_length"]]input_id += [pad_token] * (self.config["max_length"] - len(input_id))return input_id

7.類內魔術方法

self.data:表示數據集對象本身存儲的數據容器

index:表示數據集中某個樣本的索引值,用于定位并返回特定位置的樣本。

__len__():用于定義對象的“長度”,通過內置函數?len()?調用時返回該值。它通常用于容器類(如列表、字典、自定義數據結構),表示容器中元素的個數

__getitem__():允許對象通過索引或鍵值訪問元素,支持?obj[index]?或?obj[key]?語法。它使對象表現得像序列(如列表)或映射(如字典)

    def __len__(self):return len(self.data)def __getitem__(self, index):return self.data[index]

8.對于輸入的文本編碼

調用分詞器編碼(參數控制標準化)

self.tokenizer:將文本數據轉換為深度學習模型(如 BERT)可處理的輸入格式的核心工具

self.tokenizer.encode():Hugging Face Transformers 庫中?BertTokenizer?的核心方法,用于將原始文本轉換為模型可處理的輸入形式。

參數名類型默認值說明
textstr?或?List[str]?必填輸入文本(單句或句子對)。
text_pairstrNone第二段文本(用于句子對任務,如問答),與?text?拼接后生成?[CLS] text [SEP] text_pair [SEP]
add_special_tokensboolTrue是否添加?[CLS]?和?[SEP]?標記。關閉后僅返回原始分詞索引
max_lengthint512最大序列長度。超長文本會被截斷,不足則填充
paddingstr?或?boolFalse填充策略:True/'longest'(按批次最長填充)、'max_length'(按?max_length?填充)
truncationstr?或?boolFalse截斷策略:True(按?max_length?截斷)、'only_first'(僅截斷第一句)
return_tensorsstrNone

返回張量類型:

'pt'(PyTorch)、'tf'(TensorFlow)、'np'(NumPy)

return_attention_maskboolTrue是否生成?attention_mask,標識有效內容(1)與填充部分(0)
    def encode_sentence(self, text, padding=True):return self.tokenizer.encode(text,padding="max_length",max_length=self.config["max_length"],truncation=True)

9.對于編碼后的輸入文本作解碼

(04+): 匹配以?0(B-LOCATION)開頭,后接多個?4(I-LOCATION)的連續標簽

(15+)(26+)(37+)分別對應 ORGANIZATION(B=1, I=5)、PERSON(B=2, I=6)、TIME(B=3, I=7)的標簽模式。

sentence:輸入的原句(添加?$?后的版本),用于根據標簽索引提取實體文本。

lables:模型輸出的標簽序列,轉換為字符串后通過正則匹配定位實體位置。

results:存儲提取的實體,鍵為實體類型(如?"LOCATION"),值為該類型實體的文本列表。

location:正則匹配結果,通過?span()?獲取實體在?sentence?中的起止位置,用于提取具體文本片段。

join():將可迭代對象(列表、元組等)中的元素按指定分隔符連接成一個字符串。調用該方法的字符串作為分隔符。

參數名類型默認值說明
iterable可迭代對象必填需連接的元素集合,所有元素必須是字符串類型。若為空,返回空字符串。

str():將其他數據類型(整數、浮點數、布爾值等)轉換為字符串類型。支持格式化輸出和復雜對象的字符串表示。

參數名類型默認值說明
object任意類型必填需轉換的對象,如整數、列表、字典等。
encoding字符串可選編碼格式(僅對字節類型有效),如?utf-8
errors字符串可選編碼錯誤處理策略,如?ignorereplace

defaultdict():創建字典的子類,為不存在的鍵自動生成默認值。需指定?default_factory(如?listint)定義默認值類型。

參數名類型默認值說明
default_factory可調用對象或無參數函數None用于生成默認值的函數。若未指定,訪問不存在的鍵會拋出?KeyError
**kwargs關鍵字參數可選其他初始化字典的鍵值對,如?name="Alice"

re.finditer():在字符串中全局搜索正則表達式匹配項,返回一個迭代器,每個元素為?Match?對象

參數名類型說明
patternstr?或正則表達式對象要匹配的正則表達式模式
stringstr要搜索的字符串
flagsint?(可選)正則匹配標志(如?re.IGNORECASE

.span():返回正則匹配的起始和結束索引(左閉右開區間)

列表.append():向列表末尾添加單個元素,直接修改原列表

參數名類型說明
element任意要添加的元素
    def decode(self, sentence, labels):sentence = "$" + sentencelabels = "".join([str(x) for x in labels[:len(sentence) + 2]])results = defaultdict(list)for location in re.finditer("(04+)", labels):s, e = location.span()print("location", s, e)results["LOCATION"].append(sentence[s:e])for location in re.finditer("(15+)", labels):s, e = location.span()print("org", s, e)results["ORGANIZATION"].append(sentence[s:e])for location in re.finditer("(26+)", labels):s, e = location.span()print("per", s, e)results["PERSON"].append(sentence[s:e])for location in re.finditer("(37+)", labels):s, e = location.span()print("time", s, e)results["TIME"].append(sentence[s:e])return results

完整代碼

DataLoader():PyTorch 中用于高效加載和管理數據集的核心工具

?參數名?類型?默認值?說明
datasetDataset必填加載的數據集對象,需實現?__len__?和?__getitem__?方法
batch_sizeint1每個批次包含的樣本數
shuffleboolFalse是否在每個訓練周期(epoch)開始時打亂數據順序。若?sampler?被指定,則忽略此參數。
samplerSamplerNone自定義數據采樣策略(如隨機采樣?RandomSampler?或順序采樣?SequentialSampler
batch_samplerSamplerNone自定義批次采樣策略(需與?batch_sizeshuffle?等參數互斥)
num_workersint0用于加載數據的子進程數。0?表示在主進程加載;大于?0?時啟用多進程加速
collate_fnCallableNone合并多個樣本為批次的函數(如填充序列長度)。默認將 NumPy 數組轉為 Tensor
pin_memoryboolFalse若為?True,將數據復制到 CUDA 固定內存中,加速 GPU 數據傳輸
drop_lastboolFalse若為?True,丟棄最后一個不完整的批次(當數據集樣本數無法被?batch_size?整除時)
timeoutfloat0等待從子進程收集批次的超時時間(秒)。0?表示無限等待
worker_init_fnCallableNone子進程初始化函數(如設置隨機種子)
prefetch_factorint2每個子進程預加載的批次數量(需?num_workers > 0
persistent_workersboolFalse是否在訓練周期結束后保留子進程(減少重復創建進程的開銷)

.shape:??NumPy 數組或 ?PyTorch 張量的屬性,用于獲取數據的維度信息。

input():Python 的內置函數,用于從標準輸入(如鍵盤)讀取用戶輸入的字符串。

參數名?類型?默認值?說明
promptstr""可選提示信息,顯示在輸入前(如?input("請輸入:")
?返回值str-返回用戶輸入的字符串,需手動轉換為其他類型(如?int(input())
# -*- coding: utf-8 -*-import json
import re
import os
import torch
import random
import jieba
import numpy as np
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict
from transformers import BertTokenizer"""
數據加載
"""class DataGenerator:def __init__(self, data_path, config):self.config = configself.path = data_pathself.tokenizer = load_vocab(config["bert_path"])self.sentences = []self.schema = self.load_schema(config["schema_path"])self.load()def load(self):self.data = []with open(self.path, encoding="utf8") as f:segments = f.read().split("\n\n")for segment in segments:sentenece = []labels = [8]  # cls_tokenfor line in segment.split("\n"):if line.strip() == "":continuechar, label = line.split()sentenece.append(char)labels.append(self.schema[label])sentence = "".join(sentenece)self.sentences.append(sentence)input_ids = self.encode_sentence(sentenece)labels = self.padding(labels, -1)# print(self.decode(sentence, labels))# input()self.data.append([torch.LongTensor(input_ids), torch.LongTensor(labels)])returndef encode_sentence(self, text, padding=True):return self.tokenizer.encode(text,padding="max_length",max_length=self.config["max_length"],truncation=True)def decode(self, sentence, labels):sentence = "$" + sentencelabels = "".join([str(x) for x in labels[:len(sentence) + 2]])results = defaultdict(list)for location in re.finditer("(04+)", labels):s, e = location.span()print("location", s, e)results["LOCATION"].append(sentence[s:e])for location in re.finditer("(15+)", labels):s, e = location.span()print("org", s, e)results["ORGANIZATION"].append(sentence[s:e])for location in re.finditer("(26+)", labels):s, e = location.span()print("per", s, e)results["PERSON"].append(sentence[s:e])for location in re.finditer("(37+)", labels):s, e = location.span()print("time", s, e)results["TIME"].append(sentence[s:e])return results# 補齊或截斷輸入的序列,使其可以在一個batch內運算def padding(self, input_id, pad_token=0):input_id = input_id[:self.config["max_length"]]input_id += [pad_token] * (self.config["max_length"] - len(input_id))return input_iddef __len__(self):return len(self.data)def __getitem__(self, index):return self.data[index]def load_schema(self, path):with open(path, encoding="utf8") as f:return json.load(f)def load_vocab(vocab_path):return BertTokenizer.from_pretrained(vocab_path)# 用torch自帶的DataLoader類封裝數據
def load_data(data_path, config, shuffle=True):dg = DataGenerator(data_path, config)dl = DataLoader(dg, batch_size=config["batch_size"], shuffle=shuffle)return dlif __name__ == "__main__":from config import Configdg = DataGenerator("ner_data/train", Config)dl = DataLoader(dg, batch_size=32)for x, y in dl:print(x.shape, y.shape)print(x[1], y[1])input()

三、模型建立 model.py

1.代碼運行流程

輸入 x → 嵌入層 → 雙向LSTM → 全連接分類層 → 分支判斷:│├── 有 target → CRF? → 是:計算 CRF 損失(通過維特比算法計算序列概率)│                 ││                 └→ 否:計算交叉熵損失(logits 展平后與標簽計算交叉熵)│└── 無 target → CRF? → 是:解碼最優標簽序列(使用CRF的decode方法)│└→ 否:返回原始 logits(全連接層輸出的未歸一化分數)

2.模型初始化

代碼運行流程

輸入 x → BERT預訓練模型 → 分類層 → 分支判斷:│├── 有 target → CRF? → 是:計算 CRF 損失(通過轉移矩陣計算序列聯合概率)│                 ││                 └→ 否:計算交叉熵損失(logits 與標簽的逐位置交叉熵)│└── 無 target → CRF? → 是:維特比解碼最優路徑(考慮標簽轉移約束)│└→ 否:返回原始 logits(全連接層輸出的未歸一化分數)

hidden_size:定義LSTM隱藏層的維度(即每個時間步輸出的特征數量

vocab_size:詞表大小,即嵌入層(Embedding)可處理的詞匯總數

max_length:輸入序列的最大長度,用于數據預處理(如截斷或填充)

class_num:分類任務的類別數量,決定線性層(nn.Linear)的輸出維度

num_layers:堆疊的LSTM層數,用于增加模型復雜度

BertModel.from_pretrained():加載預訓練的 BERT 模型,支持從本地或 Hugging Face 模型庫加載

參數名類型默認值說明
pretrained_model_name字符串預訓練模型名稱或路徑(如?bert-base-chinese
config字典/類默認配置自定義模型配置,覆蓋默認參數(如隱藏層維度、注意力頭數)
cache_dir字符串None模型緩存目錄
output_hidden_states布爾值False是否返回所有隱藏層輸出(用于特征提取)

nn.Linear():實現全連接層的線性變換(y = xW^T + b

參數名類型默認值說明
in_features整數輸入特征維度(如詞向量維度?hidden_size
out_features整數輸出特征維度(如分類類別數?class_num
bias布爾值True是否啟用偏置項

CRF():條件隨機場層,用于序列標注任務中約束標簽轉移邏輯。

參數名類型默認值說明
num_tags整數標簽類別數(如?class_num
batch_first布爾值False輸入張量是否為?(batch_size, seq_len)?格式

torch.nn.CrossEntropyLoss():計算交叉熵損失,常用于分類任務

參數名類型默認值說明
ignore_index整數-1忽略指定索引的標簽(如填充符?-1
reduction字符串mean損失聚合方式(可選?nonesummean
    def __init__(self, config):super(TorchModel, self).__init__()hidden_size = config["hidden_size"]vocab_size = config["vocab_size"] + 1max_length = config["max_length"]class_num = config["class_num"]num_layers = config["num_layers"]# self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)self.layer = BertModel.from_pretrained(config["bert_path"], hidden_size=hidden_size, num_layers=num_layers)# self.layer = nn.LSTM(hidden_size, hidden_size, batch_first=True, bidirectional=True, num_layers=num_layers)self.classify = nn.Linear(hidden_size * 2, class_num)self.crf_layer = CRF(class_num, batch_first=True)self.use_crf = config["use_crf"]self.loss = torch.nn.CrossEntropyLoss(ignore_index=-1)  #loss采用交叉熵損失

3.前向計算 

代碼運行流程

輸入 x → 嵌入層 → LSTM層 → 分類層 → 分支判斷:│├── 有 target → CRF? → 是:計算 CRF 損失│ ? ? ? ? ? ? ? ? ││ ? ? ? ? ? ? ? ? └→ 否:計算交叉熵損失│└── 無 target → CRF? → 是:解碼最優標簽序列│└→ 否:返回預測 logits

x:輸入序列的 Token ID 矩陣,代表一個批次的文本數據(如?[[101, 234, ...], [103, 456, ...]])。

target:真實標簽序列(如實體標注),若不為?None?表示訓練階段,需計算損失;否則為預測階段。

predict:分類層輸出的每個位置標簽的未歸一化分數(logits),用于后續的 CRF 或交叉熵損失計算。

mask:標記序列中有效 Token 的位置(非填充部分),target.gt(-1)?表示標簽值大于?-1?的位置有效。

gt():張量的逐元素比較函數,返回布爾型張量,標記輸入張量中大于指定值的元素位置。常用于生成掩碼(如忽略填充符)

參數名類型默認值說明
otherTensor/標量比較的閾值或張量。若為標量,則張量中每個元素與該值比較;若為張量,需與輸入張量形狀相同。
outTensorNone可選輸出張量,用于存儲結果。

shape():返回張量的維度信息,描述各軸的大小。

view():調整張量的形狀,支持自動推斷維度(通過-1占位符)。常用于數據展平或維度轉換。

參數名類型默認值說明
*shape可變參數目標形狀的維度序列,如view(2, 3)view(-1, 28)-1表示自動計算。
    #當輸入真實標簽,返回loss值;無真實標簽,返回預測值def forward(self, x, target=None):x = self.embedding(x)  #input shape:(batch_size, sen_len)x, _ = self.layer(x)      #input shape:(batch_size, sen_len, input_dim)predict = self.classify(x) #ouput:(batch_size, sen_len, num_tags) -> (batch_size * sen_len, num_tags)if target is not None:if self.use_crf:mask = target.gt(-1)# loss 是 crf 的相反數,即 - crf(predict, target, mask)return - self.crf_layer(predict, target, mask, reduction="mean")else:#(number, class_num), (number)return self.loss(predict.view(-1, predict.shape[-1]), target.view(-1))else:if self.use_crf:return self.crf_layer.decode(predict)else:return predict

4.選擇優化器 

代碼運行流程

輸入 config → 提取參數 → 分支判斷:│├── optimizer == "adam" → 返回 Adam 優化器實例│└── optimizer == "sgd" → 返回 SGD 優化器實例

config:這個參數應該是一個字典,里面存儲了配置信息。

model:這是傳入的模型對象,通常是一個神經網絡模型。優化器需要模型的參數來更新權重

optimizer:從config中獲取的字符串,決定使用哪種優化器。比如"adam"對應Adam優化器,"sgd"對應隨機梯度下降。

learning_rate:學習率,是優化器的一個重要超參數,控制權重更新的步長

Adam():自適應矩估計優化器(Adaptive Moment Estimation),結合動量和 RMSProp 的優點。

參數名類型默認值說明
lrfloat1e-3學習率。
betastuple(0.9, 0.999)動量系數(β?, β?)。
epsfloat1e-8防止除零誤差。
weight_decayfloat0權重衰減率。
amsgradboolFalse是否啟用 AMSGrad 優化。
foreachboolFalse是否為每個參數單獨計算梯度。

SGD():隨機梯度下降優化器(Stochastic Gradient Descent)

參數名類型默認值說明
lrfloat1e-3學習率。
momentumfloat0動量系數(如?momentum=0.9)。
weight_decayfloat0權重衰減率。
dampeningfloat0動力衰減系數(用于 SGD with Momentum)。
nesterovboolFalse是否啟用 Nesterov 動量。
foreachboolFalse是否為每個參數單獨計算梯度。

parameters():返回模型所有可訓練參數的迭代器,常用于參數初始化或梯度清零。

參數名類型默認值說明
filtercallableNone過濾條件函數(如?lambda p: p.requires_grad)。默認返回所有參數。
def choose_optimizer(config, model):optimizer = config["optimizer"]learning_rate = config["learning_rate"]if optimizer == "adam":return Adam(model.parameters(), lr=learning_rate)elif optimizer == "sgd":return SGD(model.parameters(), lr=learning_rate)

5.模型建立

# -*- coding: utf-8 -*-import torch
import torch.nn as nn
from torch.optim import Adam, SGD
from torchcrf import CRF
import torch
from transformers import BertModel"""
建立網絡模型結構
"""class TorchModel(nn.Module):def __init__(self, config):super(TorchModel, self).__init__()hidden_size = config["hidden_size"]vocab_size = config["vocab_size"] + 1max_length = config["max_length"]class_num = config["class_num"]num_layers = config["num_layers"]self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)# self.layer = BertModel.from_pretrained(config["bert_path"], hidden_size=hidden_size, num_layers=num_layers)self.layer = nn.LSTM(hidden_size, hidden_size, batch_first=True, bidirectional=True, num_layers=num_layers)self.classify = nn.Linear(hidden_size * 2, class_num)self.crf_layer = CRF(class_num, batch_first=True)self.use_crf = config["use_crf"]self.loss = torch.nn.CrossEntropyLoss(ignore_index=-1)  #loss采用交叉熵損失#當輸入真實標簽,返回loss值;無真實標簽,返回預測值def forward(self, x, target=None):x = self.embedding(x)  #input shape:(batch_size, sen_len)x, _ = self.layer(x)      #input shape:(batch_size, sen_len, input_dim)predict = self.classify(x) #ouput:(batch_size, sen_len, num_tags) -> (batch_size * sen_len, num_tags)if target is not None:if self.use_crf:mask = target.gt(-1)# loss 是 crf 的相反數,即 - crf(predict, target, mask)return - self.crf_layer(predict, target, mask, reduction="mean")else:#(number, class_num), (number)return self.loss(predict.view(-1, predict.shape[-1]), target.view(-1))else:if self.use_crf:return self.crf_layer.decode(predict)else:return predictdef choose_optimizer(config, model):optimizer = config["optimizer"]learning_rate = config["learning_rate"]if optimizer == "adam":return Adam(model.parameters(), lr=learning_rate)elif optimizer == "sgd":return SGD(model.parameters(), lr=learning_rate)if __name__ == "__main__":from config import Configmodel = TorchModel(Config)

四、模型效果測試 evaluate.py

1.代碼運行流程

輸入驗證集 → 數據加載 → 模型預測 → 分支判斷:│├── 啟用CRF → 直接解碼標簽序列 → 實體提取│└── 禁用CRF → argmax獲取預測標簽 → 實體提取→ 統計指標計算 → 分支判斷:│├── 按實體類別統計 → 計算precision/recall/F1(LOCATION/TIME/PERSON/ORGANIZATION)│└── 全局統計 → 計算micro-F1 → 輸出綜合評估結果

2.初始化

Ⅰ、加載配置文件、模型及日志模塊 ——>

Ⅱ、讀取驗證集數據(固定順序,避免隨機性干擾評估)——>

Ⅲ、初始化統計字典?stats_dict,按實體類別記錄正確識別數、樣本實體數等

config:?存儲運行時配置,例如數據路徑、超參數(如批次大小?batch_size)、是否使用CRF層等。通過?config["valid_data_path"]?動態獲取驗證集路徑。

model:待評估的模型實例,用于調用預測方法(如?model(input_id)),需提前完成訓練和加載。

logger:?記錄運行日志,例如輸出評估指標(準確率、F1值)到文件或控制臺,便于調試和監控。

valid_data:驗證數據集,用于模型訓練時的性能評估和超參數調優。

load_data():數據加載類中,用torch自帶的DataLoader類封裝數據的函數

    def __init__(self, config, model, logger):self.config = configself.model = modelself.logger = loggerself.valid_data = load_data(config["valid_data_path"], config, shuffle=False)

?3.統計模型效果

Ⅰ、?輸入驗證與初始化

????????通過?assert?確保輸入的三組數據長度一致(labels,?pred_results,?sentences)。

????????若模型未使用 CRF 層(use_crf=False),將預測結果通過?torch.argmax?轉換為標簽索引序列


Ⅱ、逐樣本處理

????????遍歷每個樣本的真實標簽、預測標簽及原始句子。

????????若未使用 CRF,將預測標簽從 GPU Tensor 轉換為 CPU List(避免內存泄漏)。

????????調用?decode()?方法解碼標簽序列,得到真實實體字典?true_entities?和預測實體字典?pred_entities?


Ⅲ、實體統計

對每個實體類別(如?PERSON,?LOCATION):

????????正確識別數:遍歷預測實體列表,統計與真實實體完全匹配的數量(ent in true_entities[key])。

?????????樣本實體數:統計真實實體列表的長度。

?????????識別出實體數:統計預測實體列表的長度。


Ⅳ、輸出統計結果

????????最終統計結果存儲在?self.stats_dict?中,后續可通過該字典計算準確率(正確識別數 / 識別出實體數)和召回率(正確識別數 / 樣本實體數

labels:真實標簽序列(如實體標注的整數 ID 列表),用于與預測結果對比計算評估指標

pred_results:模型預測結果,若使用 CRF,為標簽序列,否則為每個位置的 logits(未歸一化概率)。

sentences:原始文本句子列表(如?["中國北京", "今天天氣"]),用于解碼標簽序列到具體實體。

use_crf:控制是否使用 CRF 層

pred_label:單個樣本的預測標簽序列,若未使用 CRF,需從 logits 中提取(argmax)并轉換為列表。

true_label:單個樣本的真實標簽序列(如?[0, 4, 4, 8]),已從 GPU 張量轉換為 CPU 列表。

true_entities:解碼后的真實實體字典,如?{"LOCATION": ["北京"], "PERSON": []}

pred_entities:解碼后的預測實體字典,用于與真實實體對比統計正確識別數。

key:字符串,實體類別名稱(如?"PERSON"),遍歷四類實體以分別統計指標。

assert:Python 的 ?調試斷言工具,主要用于在開發階段驗證程序內部的邏輯條件是否成立

????????assert expression [, message] ?

?參數?類型?是否必填?作用
?expression布爾表達式需要驗證的條件。若結果為?False,則觸發斷言失敗;若為?True,程序繼續執行。
?message字符串(可選)斷言失敗時輸出的自定義錯誤信息,用于輔助調試。若省略,則輸出默認錯誤提示。

len():返回對象的元素數量(字符串、列表、元組、字典等)

參數名類型說明
object任意可迭代對象如字符串、列表、字典等

torch.argmax():返回張量中最大值所在的索引

參數名類型說明
inputTensor輸入張量
dimint沿指定維度查找最大值
keepdimbool是否保持輸出維度一致

cpu():將張量從GPU移動到CPU內存

zip():將多個可迭代對象打包成元組列表

參數名類型說明
iterables多個可迭代對象如列表、元組、字符串

.detach():從計算圖中分離張量,阻止梯度傳播

.tolist():將張量或數組轉換為Python列表

    def write_stats(self, labels, pred_results, sentences):assert len(labels) == len(pred_results) == len(sentences)if not self.config["use_crf"]:pred_results = torch.argmax(pred_results, dim=-1)for true_label, pred_label, sentence in zip(labels, pred_results, sentences):if not self.config["use_crf"]:pred_label = pred_label.cpu().detach().tolist()true_label = true_label.cpu().detach().tolist()true_entities = self.decode(sentence, true_label)pred_entities = self.decode(sentence, pred_label)# 正確率 = 識別出的正確實體數 / 識別出的實體數# 召回率 = 識別出的正確實體數 / 樣本的實體數for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]:self.stats_dict[key]["正確識別"] += len([ent for ent in pred_entities[key] if ent in true_entities[key]])self.stats_dict[key]["樣本實體數"] += len(true_entities[key])self.stats_dict[key]["識別出實體數"] += len(pred_entities[key])return

4.可視化統計模型效果

?精確率 (Precision):正確預測實體數 / 總預測實體數

?召回率 (Recall):正確預測實體數 / 總真實實體數?

F1值:精確率與召回率的調和平均

F1:F1分數:準確率與召回率的調和平均數,綜合衡量模型的精確性與覆蓋能力。

F1_scores:存儲四個實體類別的 F1 分數,用于計算宏觀平均。

precision:準確率:模型預測為某類實體的結果中,正確的比例。反映模型預測的精確度。

recall:召回率:真實存在的某類實體中,被模型正確識別的比例。反映模型對實體的覆蓋能力。

key:當前處理的實體類別(如?"PERSON""LOCATION")。

correct_pred:?總正確識別數:所有類別中被正確識別的實體總數。

total_pred:總識別實體數:模型預測出的所有實體數量(含錯誤識別)。

true_enti:?總樣本實體數:驗證數據中真實存在的所有實體數量。

micro_precision:微觀準確率:全局視角下的準確率,所有實體類別的正確識別數與總識別數的比例。

micro_recall:微觀召回率:全局視角下的召回率,所有實體類別的正確識別數與總樣本實體數的比例。

micro_f1:微觀F1分數:微觀準確率與微觀召回率的調和平均數。

列表.append():在列表末尾添加元素

參數名類型說明
element任意要添加的元素

logger.info():記錄日志信息(需配置日志模塊)

參數名類型說明
formatstr格式化字符串
*args可變參數格式化參數

sum():計算可迭代對象的元素總和

參數名類型說明
iterable可迭代對象如列表、元組
start數值(可選)初始累加值

列表推導式:通過簡潔語法生成新列表,語法:[表達式 for item in iterable if 條件]

    def show_stats(self):F1_scores = []for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]:# 正確率 = 識別出的正確實體數 / 識別出的實體數# 召回率 = 識別出的正確實體數 / 樣本的實體數precision = self.stats_dict[key]["正確識別"] / (1e-5 + self.stats_dict[key]["識別出實體數"])recall = self.stats_dict[key]["正確識別"] / (1e-5 + self.stats_dict[key]["樣本實體數"])F1 = (2 * precision * recall) / (precision + recall + 1e-5)F1_scores.append(F1)self.logger.info("%s類實體,準確率:%f, 召回率: %f, F1: %f" % (key, precision, recall, F1))self.logger.info("Macro-F1: %f" % np.mean(F1_scores))correct_pred = sum([self.stats_dict[key]["正確識別"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])total_pred = sum([self.stats_dict[key]["識別出實體數"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])true_enti = sum([self.stats_dict[key]["樣本實體數"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])micro_precision = correct_pred / (total_pred + 1e-5)micro_recall = correct_pred / (true_enti + 1e-5)micro_f1 = (2 * micro_precision * micro_recall) / (micro_precision + micro_recall + 1e-5)self.logger.info("Micro-F1 %f" % micro_f1)self.logger.info("--------------------")return

5.評估模型效果

?模型切換為評估模式:關閉Dropout等訓練層?

批次處理數據

? ? ? ? ?提取原始句子?sentences

? ?將數據遷移至GPU(若可用)

? ? ? ? ?預測時禁用梯度計算(torch.no_grad())優化內存

統計結果:調用?write_stats?對比預測與真實標簽

epoch:當前訓練輪次,用于日志。

logger:記錄日志的工具。

stats_dict:統計字典,記錄各實體類別的指標。

valid_data:驗證數據集,通常由?load_data?加載(如?config["valid_data_path"]?指定路徑)

index:?循環中的批次索引

batch_data:?循環中的數據。

sentences:當前批次的原始句子

pred_results:模型預測結果

write_stats():寫入統計信息

show_stats():顯示統計結果

logger.info():記錄日志信息(需配置日志模塊)

參數名類型說明
formatstr格式化字符串
*args可變參數格式化參數

defaultdict():創建帶有默認值工廠的字典

參數名類型說明
default_factory可調用對象如int、list、自定義函數

model.eval():將模型設置為評估模式(關閉Dropout等訓練層)

enumerate():返回索引和元素組成的枚舉對象

參數名類型說明
iterable可迭代對象如列表、字符串
startint(可選)起始索引,默認為0

torch.cuda.is_available():檢查當前環境是否支持CUDA(GPU加速)

cuda():將張量或模型移動到GPU

參數名類型說明
deviceint/str指定GPU設備號,如"cuda:0"

torch.no_grad():禁用梯度計算,節省內存并加速推理

    def eval(self, epoch):self.logger.info("開始測試第%d輪模型效果:" % epoch)self.stats_dict = {"LOCATION": defaultdict(int),"TIME": defaultdict(int),"PERSON": defaultdict(int),"ORGANIZATION": defaultdict(int)}self.model.eval()for index, batch_data in enumerate(self.valid_data):sentences = self.valid_data.dataset.sentences[index * self.config["batch_size"]: (index+1) * self.config["batch_size"]]if torch.cuda.is_available():batch_data = [d.cuda() for d in batch_data]input_id, labels = batch_data   #輸入變化時這里需要修改,比如多輸入,多輸出的情況with torch.no_grad():pred_results = self.model(input_id) #不輸入labels,使用模型當前參數進行預測self.write_stats(labels, pred_results, sentences)self.show_stats()return

6.解碼

?根據代碼中,Schema文件映射的定義對標簽序列預處理:將數值標簽拼接為字符串(如?[0,4,4]?→?"044"

正則匹配實體

???04+B-LOCATION(0)后接多個I-LOCATION(4)

???15+B-ORGANIZATION(1)后接I-ORGANIZATION(5)

? ? ? ? ? ?其他實體類別同理

?索引對齊:根據匹配位置截取原始句子中的實體文本

Ⅰ、輸入預處理

在原句首添加?$?符號,通常用于對齊標簽與字符位置(例如避免索引越界)

        sentence = "$" + sentence

Ⅱ、標簽序列轉換

將整數標簽序列轉換為字符串,并截取長度與?sentence?對齊

str.join():將可迭代對象中的字符串元素按指定分隔符連接成一個新字符串

參數名類型說明
iterable可迭代對象元素必須為字符串類型

str():將對象轉換為字符串表示形式,支持自定義類的?__str__?方法

參數名類型說明
object任意要轉換的對象

len():返回對象的長度或元素個數(適用于字符串、列表、字典等)

參數名類型說明
object可迭代對象如字符串、列表等

列表推導式:通過簡潔語法生成新列表,支持條件過濾和多層循環

????????[expression for item in iterable if condition]

部分類型說明
expression表達式對?item?處理后的結果
item變量迭代變量
iterable可迭代對象如列表、range()?生成的序列
condition條件表達式 (可選)過濾不符合條件的元素
        labels = "".join([str(x) for x in labels[:len(sentence)+1]])

Ⅲ、初始化結果容器

創建默認值為列表的字典,存儲四類實體:

????????(LOCATION、ORGANIZATION、PERSON、TIME)的識別結果

defaultdict():創建默認值字典,當鍵不存在時自動生成默認值(基于工廠函數)

參數名類型說明
default_factory可調用對象如?intlist?或自定義函數
        results = defaultdict(list)

Ⅳ、正則表達式匹配

? ? (04+): 匹配以?0(B-LOCATION)開頭,后接多個?4(I-LOCATION)的連續標簽

? ? (15+)(26+)(37+)分別對應 ORGANIZATION(B=1, I=5)、PERSON(B=2, I=6)、TIME(B=3, I=7)的標簽模式。

re.finditer():在字符串中全局搜索正則表達式匹配項,返回一個迭代器,每個元素為?Match?對象

參數名類型說明
patternstr?或正則表達式對象要匹配的正則表達式模式
stringstr要搜索的字符串
flagsint?(可選)正則匹配標志(如?re.IGNORECASE

.span():返回正則匹配的起始和結束索引(左閉右開區間)

列表.append():向列表末尾添加單個元素,直接修改原列表

參數名類型說明
element任意要添加的元素
        for location in re.finditer("(04+)", labels):s, e = location.span()results["LOCATION"].append(sentence[s:e])

Ⅴ、完整代碼?

    '''Schema文件{"B-LOCATION": 0,"B-ORGANIZATION": 1,"B-PERSON": 2,"B-TIME": 3,"I-LOCATION": 4,"I-ORGANIZATION": 5,"I-PERSON": 6,"I-TIME": 7,"O": 8}'''def decode(self, sentence, labels):sentence = "$" + sentencelabels = "".join([str(x) for x in labels[:len(sentence)+1]])results = defaultdict(list)for location in re.finditer("(04+)", labels):s, e = location.span()results["LOCATION"].append(sentence[s:e])for location in re.finditer("(15+)", labels):s, e = location.span()results["ORGANIZATION"].append(sentence[s:e])for location in re.finditer("(26+)", labels):s, e = location.span()results["PERSON"].append(sentence[s:e])for location in re.finditer("(37+)", labels):s, e = location.span()results["TIME"].append(sentence[s:e])return results

7.完整代碼?

# -*- coding: utf-8 -*-
import torch
import re
import numpy as np
from collections import defaultdict
from loader import load_data"""
模型效果測試
"""class Evaluator:def __init__(self, config, model, logger):self.config = configself.model = modelself.logger = loggerself.valid_data = load_data(config["valid_data_path"], config, shuffle=False)def eval(self, epoch):self.logger.info("開始測試第%d輪模型效果:" % epoch)self.stats_dict = {"LOCATION": defaultdict(int),"TIME": defaultdict(int),"PERSON": defaultdict(int),"ORGANIZATION": defaultdict(int)}self.model.eval()for index, batch_data in enumerate(self.valid_data):sentences = self.valid_data.dataset.sentences[index * self.config["batch_size"]: (index+1) * self.config["batch_size"]]if torch.cuda.is_available():batch_data = [d.cuda() for d in batch_data]input_id, labels = batch_data   #輸入變化時這里需要修改,比如多輸入,多輸出的情況with torch.no_grad():pred_results = self.model(input_id) #不輸入labels,使用模型當前參數進行預測self.write_stats(labels, pred_results, sentences)self.show_stats()returndef write_stats(self, labels, pred_results, sentences):assert len(labels) == len(pred_results) == len(sentences)if not self.config["use_crf"]:pred_results = torch.argmax(pred_results, dim=-1)for true_label, pred_label, sentence in zip(labels, pred_results, sentences):if not self.config["use_crf"]:pred_label = pred_label.cpu().detach().tolist()true_label = true_label.cpu().detach().tolist()true_entities = self.decode(sentence, true_label)pred_entities = self.decode(sentence, pred_label)# 正確率 = 識別出的正確實體數 / 識別出的實體數# 召回率 = 識別出的正確實體數 / 樣本的實體數for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]:self.stats_dict[key]["正確識別"] += len([ent for ent in pred_entities[key] if ent in true_entities[key]])self.stats_dict[key]["樣本實體數"] += len(true_entities[key])self.stats_dict[key]["識別出實體數"] += len(pred_entities[key])returndef show_stats(self):F1_scores = []for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]:# 正確率 = 識別出的正確實體數 / 識別出的實體數# 召回率 = 識別出的正確實體數 / 樣本的實體數precision = self.stats_dict[key]["正確識別"] / (1e-5 + self.stats_dict[key]["識別出實體數"])recall = self.stats_dict[key]["正確識別"] / (1e-5 + self.stats_dict[key]["樣本實體數"])F1 = (2 * precision * recall) / (precision + recall + 1e-5)F1_scores.append(F1)self.logger.info("%s類實體,準確率:%f, 召回率: %f, F1: %f" % (key, precision, recall, F1))self.logger.info("Macro-F1: %f" % np.mean(F1_scores))correct_pred = sum([self.stats_dict[key]["正確識別"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])total_pred = sum([self.stats_dict[key]["識別出實體數"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])true_enti = sum([self.stats_dict[key]["樣本實體數"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])micro_precision = correct_pred / (total_pred + 1e-5)micro_recall = correct_pred / (true_enti + 1e-5)micro_f1 = (2 * micro_precision * micro_recall) / (micro_precision + micro_recall + 1e-5)self.logger.info("Micro-F1 %f" % micro_f1)self.logger.info("--------------------")return'''{"B-LOCATION": 0,"B-ORGANIZATION": 1,"B-PERSON": 2,"B-TIME": 3,"I-LOCATION": 4,"I-ORGANIZATION": 5,"I-PERSON": 6,"I-TIME": 7,"O": 8}'''def decode(self, sentence, labels):sentence = "$" + sentencelabels = "".join([str(x) for x in labels[:len(sentence)+1]])results = defaultdict(list)for location in re.finditer("(04+)", labels):s, e = location.span()results["LOCATION"].append(sentence[s:e])for location in re.finditer("(15+)", labels):s, e = location.span()results["ORGANIZATION"].append(sentence[s:e])for location in re.finditer("(26+)", labels):s, e = location.span()results["PERSON"].append(sentence[s:e])for location in re.finditer("(37+)", labels):s, e = location.span()results["TIME"].append(sentence[s:e])return results

五、主函數文件 main.py

1.代碼運行流程

配置參數 → 創建模型目錄 → 加載訓練數據 → 初始化模型 → 設備檢測:│├── GPU可用 → 遷移模型至GPU│└── GPU不可用 → 保持CPU模式→ 選擇優化器 → 初始化評估器 → 進入訓練循環:│├── 當前epoch → 訓練模式 → 遍歷數據批次:│                 ││                 ├── 清空梯度 → 數據遷移至GPU → 前向計算 → 分支判斷:│                 │             ││                 │             ├── 啟用CRF → 計算CRF損失 → 反向傳播 → 參數更新│                 │             ││                 │             └── 禁用CRF → 計算交叉熵損失 → 反向傳播 → 參數更新│                 ││                 └── 記錄批次損失 → 周期中點打印日志│└── 計算epoch平均損失 → 驗證集評估 → 保存當前模型權重

2.導入文件

# -*- coding: utf-8 -*-import torch
import os
import random
import numpy as np
import logging
from config import Config
from model import TorchModel, choose_optimizer
from evaluate import Evaluator
from loader import load_data

3.日志配置

logging.basicConfig():配置日志系統的基礎參數(一次性設置,應在首次日志調用前調用)

參數名類型是否必需默認值說明
filename字符串None日志輸出文件名(若指定,日志寫入文件而非控制臺)
filemode字符串'a'文件打開模式(如'w'覆蓋,'a'追加)
format字符串基礎格式日志格式模板(如'%(asctime)s - %(levelname)s - %(message)s'
datefmt字符串時間格式(如'%Y-%m-%d %H:%M:%S'
level整數WARNING日志級別(如logging.INFOlogging.DEBUG
stream對象None指定日志輸出流(如sys.stderr,與filename互斥)

logging.getLogger():獲取或創建指定名稱的日志記錄器(Logger)。若nameNone,返回根日志記錄器

參數名類型是否必需默認值說明
name字符串None日志記錄器名稱(分層結構,如'module.sub'
logging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

4.主函數 main  

Ⅰ、創建模型保存目錄

os.path.isdir():檢查指定路徑是否為目錄(文件夾)

參數名類型是否必需默認值說明
path字符串要檢查的路徑(絕對或相對)

os.mkdir():創建單個目錄(若父目錄不存在會拋出異常)

參數名類型是否必需默認值說明
path字符串要創建的目錄路徑
mode整數0o777目錄權限(八進制格式,某些系統可能忽略此參數)
    #創建保存模型的目錄if not os.path.isdir(config["model_path"]):os.mkdir(config["model_path"])

Ⅱ、加載訓練數據

    #加載訓練數據train_data = load_data(config["train_data_path"], config)

Ⅲ、加載模型

    #加載模型model = TorchModel(config)

Ⅳ、檢查GPU并遷移模型

torch.cuda.is_available():檢查系統是否滿足 CUDA 環境要求

logger.info():記錄日志信息,輸出訓練過程中的關鍵狀態

參數類型必須說明示例
msgstr日志消息(支持格式化字符串)logger.info("Epoch: %d", epoch)
*argsAny格式化參數(用于%占位符)

cuda():將張量或模型移動到GPU顯存,加速計算

參數類型必須說明示例
deviceint/str指定GPU設備(如0"cuda:0"tensor.cuda(device=0)
non_blockingbool是否異步傳輸數據(默認False)tensor.cuda(non_blocking=True)
    # 標識是否使用gpucuda_flag = torch.cuda.is_available()if cuda_flag:logger.info("gpu可以使用,遷移模型至gpu")model = model.cuda()

Ⅴ、加載優化器

    #加載優化器optimizer = choose_optimizer(config, model)

Ⅵ、加載評估器

    #加載效果測試類evaluator = Evaluator(config, model, logger)

Ⅶ、模型訓練 ?

① Epoch循環控制

range():Python 內置函數,用于生成一個不可變的整數序列,?核心功能是為循環控制提供高效的數值迭代支持

參數名類型默認值說明
start整數0序列起始值(包含)。若省略,則默認從?0?開始。例如?range(3)?等價于?range(0,3)
stop整數?必填序列結束值(不包含)。例如?range(2, 5)?生成?2,3,4
step整數1步長(正/負):
- ?正步長需滿足?start < stop,否則無輸出(如?range(5, 2)?無效)。
- ?負步長需滿足?start > stop,例如?range(5, 0, -1)?生成?5,4,3,2,1
?**不能為?0**?(否則觸發?ValueError
for epoch in range(config["epoch"]):epoch += 1
② 模型設置訓練模式?

train_loss:計算當前批次的損失值,通常結合損失函數(如交叉熵、均方誤差)使用

model.train():設置模型為訓練模式,啟用Dropout、BatchNorm等層的訓練行為

參數類型默認值說明示例
modeboolTrue是否啟用訓練模式(True)或評估模式(False)model.train(True)

logger.info():記錄日志信息,輸出訓練過程中的關鍵狀態

參數類型必須說明示例
msgstr日志消息(支持格式化字符串)logger.info("Epoch: %d", epoch)
*argsAny格式化參數(用于%占位符)
        model.train()logger.info("epoch %d begin" % epoch)train_loss = []

③?Batch數據遍歷

enumerate():遍歷可迭代對象時返回索引和元素,支持自定義起始索引

參數類型必須說明示例
iterableIterable可迭代對象(如列表、生成器)enumerate(["a", "b"])
startint索引起始值(默認0)enumerate(data, start=1)
        for index, batch_data in enumerate(train_data):

④?梯度清零與設備切換

optimizer.zero_grad():清空模型參數的梯度,防止梯度累積

參數類型必須說明示例
set_to_nonebool是否將梯度置為None(高效但危險)optimizer.zero_grad(True)

cuda():將張量或模型移動到GPU顯存,加速計算

參數類型必須說明示例
deviceint/str指定GPU設備(如0"cuda:0"tensor.cuda(device=0)
non_blockingbool是否異步傳輸數據(默認False)tensor.cuda(non_blocking=True)
            optimizer.zero_grad()if cuda_flag:batch_data = [d.cuda() for d in batch_data]

⑤ 前向傳播與損失計算
            input_id, labels = batch_data   #輸入變化時這里需要修改,比如多輸入,多輸出的情況loss = model(input_id, labels)

⑥ 反向傳播與參數更新

loss.backward():反向傳播計算梯度,基于損失值更新模型參數的.grad屬性

參數類型必須說明示例
retain_graphbool是否保留計算圖(用于多次反向傳播)loss.backward(retain_graph=True)

optimizer.step():根據梯度更新模型參數,執行優化算法(如SGD、Adam)

參數類型必須說明示例
closureCallable重新計算損失的閉包函數(如LBFGS)optimizer.step(closure)
            loss.backward()optimizer.step()

⑦ 損失記錄與日志輸出

列表.append():在列表末尾添加元素,直接修改原列表

參數類型必須說明示例
objectAny要添加到列表末尾的元素train_loss.append(loss.item())

int():將字符串或浮點數轉換為整數,支持進制轉換

參數類型必須說明示例
xstr/float待轉換的值(如字符串或浮點數)int("10", base=2)(輸出2進制10=2)
baseint進制(默認10)

len():返回對象(如列表、字符串)的長度或元素個數

參數類型必須說明示例
objSequence/Collection可計算長度的對象(如列表、字符串)len([1, 2, 3])(返回3)

logger.info():記錄日志信息,輸出訓練過程中的關鍵狀態

參數類型必須說明示例
msgstr日志消息(支持格式化字符串)logger.info("Epoch: %d", epoch)
*argsAny格式化參數(用于%占位符)
            train_loss.append(loss.item())if index % int(len(train_data) / 2) == 0:logger.info("batch loss %f" % loss)

⑧?Epoch評估與日志

item():從張量中提取標量值(僅當張量包含單個元素時可用)

列表.append():Python 列表(list)的內置方法,用于向列表的 ?末尾?添加一個元素。

參數名?類型?默認值?說明
element任意類型要添加到列表末尾的元素。可以是單個值(如?42)、對象(如?[1, 2, 3])等。

logger.info():記錄日志信息,輸出訓練過程中的關鍵狀態

參數類型必須說明示例
msgstr日志消息(支持格式化字符串)logger.info("Epoch: %d", epoch)
*argsAny格式化參數(用于%占位符)
            train_loss.append(loss.item())if index % int(len(train_data) / 2) == 0:logger.info("batch loss %f" % loss)

⑨ 完整訓練代碼
    #訓練for epoch in range(config["epoch"]):epoch += 1model.train()logger.info("epoch %d begin" % epoch)train_loss = []for index, batch_data in enumerate(train_data):optimizer.zero_grad()if cuda_flag:batch_data = [d.cuda() for d in batch_data]input_id, labels = batch_data   #輸入變化時這里需要修改,比如多輸入,多輸出的情況loss = model(input_id, labels)loss.backward()optimizer.step()train_loss.append(loss.item())if index % int(len(train_data) / 2) == 0:logger.info("batch loss %f" % loss)logger.info("epoch average loss: %f" % np.mean(train_loss))evaluator.eval(epoch)

Ⅷ、模型保存

os.path.join():Python 中用于拼接路徑的核心函數,其核心價值在于自動處理不同操作系統的路徑分隔符,從而保證代碼的跨平臺兼容性

參數類型必填說明
path1字符串初始路徑組件
*paths可變參數后續路徑組件(可傳多個)

torch.save():??PyTorch 中用于序列化保存模型、張量或字典等對象的核心函數,支持將數據持久化存儲為?.pth?或?.pt?文件,便于后續加載和復用

參數名?類型?默認值?說明
obj任意 PyTorch 對象必填待保存的對象,如模型、張量或字典。
fstr?或文件對象必填保存路徑(如?'model.pth')或已打開的文件對象(需二進制寫入模式?'wb'
pickle_protocolint2指定 pickle 協議版本(通常無需修改,高版本可能提升效率但需兼容性驗證)
_use_new_zipfile_serializationboolTrue啟用新版序列化格式(壓縮率更高,推薦保持默認)
    model_path = os.path.join(config["model_path"], "epoch_%d.pth" % epoch)# torch.save(model.state_dict(), model_path)return model, train_data

5.調用模型預測

# -*- coding: utf-8 -*-import torch
import os
import random
import numpy as np
import logging
from config import Config
from model import TorchModel, choose_optimizer
from evaluate import Evaluator
from loader import load_datalogging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)"""
模型訓練主程序
"""def main(config):#創建保存模型的目錄if not os.path.isdir(config["model_path"]):os.mkdir(config["model_path"])#加載訓練數據train_data = load_data(config["train_data_path"], config)#加載模型model = TorchModel(config)# 標識是否使用gpucuda_flag = torch.cuda.is_available()if cuda_flag:logger.info("gpu可以使用,遷移模型至gpu")model = model.cuda()#加載優化器optimizer = choose_optimizer(config, model)#加載效果測試類evaluator = Evaluator(config, model, logger)#訓練for epoch in range(config["epoch"]):epoch += 1model.train()logger.info("epoch %d begin" % epoch)train_loss = []for index, batch_data in enumerate(train_data):optimizer.zero_grad()if cuda_flag:batch_data = [d.cuda() for d in batch_data]input_id, labels = batch_data   #輸入變化時這里需要修改,比如多輸入,多輸出的情況loss = model(input_id, labels)loss.backward()optimizer.step()train_loss.append(loss.item())if index % int(len(train_data) / 2) == 0:logger.info("batch loss %f" % loss)logger.info("epoch average loss: %f" % np.mean(train_loss))evaluator.eval(epoch)# 保存模型model_path = os.path.join(config["model_path"], "epoch_%d.pth" % epoch)torch.save(model.state_dict(), model_path)return model, train_dataif __name__ == "__main__":model, train_data = main(Config)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/72534.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/72534.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/72534.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

藍橋杯學習-11棧

11棧 先進后出 例題–藍橋19877 用數組來設置棧 1.向棧頂插入元素--top位置標記元素 2.刪除棧頂元素--top指針減減 3.輸出棧頂元素--輸出top位置元素使用arraylist import java.util.ArrayList; import java.util.Scanner;public class Main {public static void main(Str…

Linux 藍牙音頻軟件棧實現分析

Linux 藍牙音頻軟件棧實現分析 藍牙協議棧簡介藍牙控制器探測BlueZ 插件系統及音頻插件藍牙協議棧簡介 藍牙協議棧是實現藍牙通信功能的軟件架構,它由多個層次組成,每一層負責特定的功能。藍牙協議棧的設計遵循藍牙標準 (由藍牙技術聯盟,Bluetooth SIG 定義),支持多種藍牙…

JetBrains(全家桶: IDEA、WebStorm、GoLand、PyCharm) 2024.3+ 2025 版免費體驗方案

JetBrains&#xff08;全家桶: IDEA、WebStorm、GoLand、PyCharm&#xff09; 2024.3 2025 版免費體驗方案 前言 JetBrains IDE 是許多開發者的主力工具&#xff0c;但從 2024.02 版本起&#xff0c;JetBrains 調整了試用政策&#xff0c;新用戶不再享有默認的 30 天免費試用…

1.8PageTable

頁表的作用 虛擬地址空間映射&#xff1a;頁表記錄了進程的虛擬頁號到物理頁號的映射關系。每個進程都有自己的頁表&#xff0c;操作系統為每個進程維護一個獨立的頁表。內存管理&#xff1a;頁表用于實現虛擬內存管理&#xff0c;支持進程的虛擬地址空間和物理地址空間之間的…

Prosys OPC UA Gateway:實現 OPC Classic 與 OPC UA 無縫連接

在工業自動化的數字化轉型中&#xff0c;設備與系統之間的高效通信至關重要。然而&#xff0c;許多企業仍依賴于基于 COM/DCOM 技術的 OPC 產品&#xff0c;這給與現代化的 OPC UA 架構的集成帶來了挑戰。 Prosys OPC UA Gateway 正是為解決這一問題而生&#xff0c;它作為一款…

數據結構------線性表

一、線性表順序存儲詳解 &#xff08;一&#xff09;線性表核心概念 1. 結構定義 // 數據元素類型 typedef struct person {char name[32];char sex;int age;int score; } DATATYPE;// 順序表結構 typedef struct list {DATATYPE *head; // 存儲空間基地址int tlen; …

【WPF】在System.Drawing.Rectangle中限制鼠標保持在Rectangle中移動?

方案一&#xff0c;在OnMouseMove方法限制 在WPF應用程序中&#xff0c;鼠標在移動過程中保持在這個矩形區域內&#xff0c;可以通過監聽鼠標的移動事件并根據鼠標的當前位置調整其坐標來實現。不過需要注意的是&#xff0c;WPF原生使用的是System.Windows.Rect而不是System.D…

基于銀河麒麟系統ARM架構安裝達夢數據庫并配置主從模式

達夢數據庫簡要概述 達夢數據庫&#xff08;DM Database&#xff09;是一款由武漢達夢公司開發的關系型數據庫管理系統&#xff0c;支持多種高可用性和數據同步方案。在主從模式&#xff08;也稱為 Master-Slave 或 Primary-Secondary 模式&#xff09;中&#xff0c;主要通過…

系統思考全球化落地

感謝加密貨幣公司Bybit的再次邀請&#xff0c;為全球團隊分享系統思考課程&#xff01;雖然大家來自不同國家&#xff0c;線上學習的形式依然讓大家充滿熱情與互動&#xff0c;思維的碰撞不斷激發新的靈感。 盡管時間存在挑戰&#xff0c;但我看到大家的討論異常積極&#xff…

Figma的漢化

Figma的漢化插件有客戶端版本與Chrome版本&#xff0c;大家可根據自己的需要進行選擇。 下載插件 進入Figma軟件漢化-Figma中文版下載-Figma中文社區使用客戶端&#xff1a;直接下載客戶端使用網頁版&#xff1a;安裝chrome瀏覽器漢化插件國外推薦前往chrome商店安裝國內推薦下…

【Go語言圣經2.5】

目標 了解類型定義不僅告訴編譯器如何在內存中存儲和處理數據&#xff0c;還對程序設計產生深遠影響&#xff1a; 內存結構&#xff1a;類型決定了變量的底層存儲&#xff08;比如占用多少字節、內存布局等&#xff09;。操作符與方法集&#xff1a;類型決定了哪些內置運算符…

IDEA 一鍵完成:打包 + 推送 + 部署docker鏡像

1、本方案要解決場景&#xff1f; 想直接通過本地 IDEA 將最新的代碼部署到遠程服務器上。 2、本方案適用于什么樣的項目&#xff1f; 項目是一個 Spring Boot 的 Java 項目。項目用 maven 進行管理。項目的運行基于 docker 容器&#xff08;即項目將被打成 docker image&am…

SpringBoot 第一課(Ⅲ) 配置類注解

目錄 一、PropertySource 二、ImportResource ①SpringConfig &#xff08;Spring框架全注解&#xff09; ②ImportResource注解實現 三、Bean 四、多配置文件 多Profile文件的使用 文件命名約定&#xff1a; 激活Profile&#xff1a; YAML文件支持多文檔塊&#xff…

深度解析React Native底層核心架構

React Native 工作原理深度解析 一、核心架構&#xff1a;三層異構協作體系 React Native 的跨平臺能力源于其獨特的 JS層-Shadow層-Native層 架構設計&#xff0c;三者在不同線程中協同工作&#xff1a; JS層 運行于JavaScriptCore&#xff08;iOS&#xff09;或Hermes&…

對話智能體的正確打開方式:解析主流AI聊天工具的核心能力與使用方式

一、人機對話的黃金法則 在與人工智能對話系統交互時&#xff0c;掌握以下七項核心原則可顯著提升溝通效率&#xff1a;文末有教程分享地址 意圖精準表達術 采用"背景需求限定條件"的結構化表達 示例優化&#xff1a;"請用Python編寫一個網絡爬蟲&#xff08…

Xinference大模型配置介紹并通過git-lfs、hf-mirror安裝

文章目錄 一、Xinference開機服務systemd二、語言&#xff08;LLM&#xff09;模型2.1 配置介紹2.2 DeepSeek-R1-Distill-Qwen-32B&#xff08;大杯&#xff09;工具下載git-lfs&#xff08;可以繞過Hugging Face&#xff09; 2.3 DeepSeek-R1-Distill-Qwen-32B-Q4_K_M-GGUF&am…

MyBatis操縱數據庫-XML實現(補充)

目錄 一.多表查詢二.MyBatis參數賦值(#{ }和${ })2.1 #{ }和${ }的使用2.2 #{ }和${ }的區別2.3 SQL注入2.3 ${ }的應用場景2.3.1 排序功能2.3.2 like查詢 一.多表查詢 多表查詢的操作和單表查詢基本相同&#xff0c;只需改變一下SQL語句&#xff0c;同時也要在實體類中創建出…

快速導出接口設計表——基于DOMParser的Swagger接口詳情半自動化提取方法

作者聲明&#xff1a;不想看作者聲明的&#xff08;需要生成接口設計表的&#xff09;直接前往https://capujin.github.io/A2T/。 注&#xff1a;Github Pages生成的頁面可能會出現訪問不穩定&#xff0c;暫時沒將源碼上傳至Github&#xff0c;如有需要&#xff0c;可聯系我私…

TS常見內置映射類型的實現及應用場景

以下是 TypeScript 在前端項目中 常用的映射類型&#xff08;Mapped Types&#xff09;&#xff0c;結合具體場景和代碼示例&#xff0c;幫助開發者高效處理復雜類型&#xff1a; 一、基礎映射類型 1. Partial<T> 作用&#xff1a;將對象類型 T 的所有屬性變為可選。 實…

介紹如何使用YOLOv8模型進行基于深度學習的吸煙行為檢測

下面為你詳細介紹如何使用YOLOv8模型進行基于深度學習的吸煙行為檢測&#xff0c;包含環境配置、數據準備、模型訓練以及推理等步驟。 1. 環境配置 首先&#xff0c;你需要安裝必要的庫&#xff0c;主要是ultralytics庫&#xff0c;它包含了YOLOv8模型。你可以使用以下命令進…