文本分類作為自然語言處理中的基礎任務,能夠幫助我們將海量醫學摘要自動歸類到具體疾病領域中。本文將基于NVIDIA NeMo框架,構建一個用于醫學疾病摘要分類的深度學習應用,支持將摘要劃分為三類:癌癥類疾病、神經系統疾病及障礙、以及其他類型。我們將通過命令行和配置文件快速完成訓練,也將深入探討PyTorch Lightning的核心機制,手動實現訓練過程,以便更靈活地調試與擴展模型。
文章目錄
- 1 NeMo和PyTorch Lightning
- 2 從命令行進行文本分類
- 2.1 準備數據
- 2.2 配置文件
- 2.3 訓練模型
- 3 PyTorch Lightning詳解
- 3.1 LightningModule、Trainer和exp_manager
- 3.2 腳本解析
- 4 總結
1 NeMo和PyTorch Lightning
NeMo是一個用于構建對話式AI應用程序的開源工具包。NeMo是圍繞神經模塊(Neural Modules)構建的,它們是神經網絡的概念模塊,接受類型化輸入并產生類型化輸出。這些模塊通常表示數據層、編碼器、解碼器、語言模型、損失函數或激活組合方法。
NeMo深度學習框架基于PyTorch Lightning,這是一個PyTorch的封裝器,用于組織神經網絡訓練代碼。PyTorch Lightning提供了簡單且高性能的多GPU/多節點混合精度訓練選項。
使用PyTorch Lightning創建一個深度神經網絡項目通常需要兩個主要組件:
- LightningModule:將PyTorch代碼組織為訓練、驗證和測試所需的計算、優化器和循環。該抽象層使深度學習實驗更容易理解和復現。
- Trainer:可以接收
LightningModule
并自動完成深度學習訓練所需的所有內容。
2 從命令行進行文本分類
我們要解決的問題是:給定一個醫學疾病摘要,該摘要是關于癌癥、神經系統疾病,還是其他?
2.1 準備數據
上一篇文章中,我們已經探索了NCBI-disease corpus
,回憶一下,文本分類數據文件由制表符分隔的摘要和標簽組成,并帶有標題行。
# 設置數據目錄變量
TC3_DATA_DIR = './data/NCBI_tc-3'
# 列出所有 .tsv 文件
!ls $TC3_DATA_DIR/*.tsv
輸出如下:
/dli/task/data/NCBI_tc-3/dev.tsv /dli/task/data/NCBI_tc-3/train.tsv
/dli/task/data/NCBI_tc-3/test.tsv
大致看一下文件內容:
print("*****\ntrain.tsv sample\n*****")
!head -n 3 $TC3_DATA_DIR/train.tsv
print("\n\n*****\ndev.tsv sample\n*****")
!head -n 3 $TC3_DATA_DIR/dev.tsv
print("\n\n*****\ntest.tsv sample\n*****")
!head -n 3 $TC3_DATA_DIR/test.tsv
部分輸出:
數據的幾個特征:
-
預處理后數據已符合文檔中規定的格式:
[單詞][空格][單詞][空格][單詞][TAB][標簽]
-
標題行需要刪除:標題行不是訓練數據,它會干擾模型的學習。
-
文本較長,因此訓練時需要考慮
max_seq_length
值。
首先移除標題行,我們可以使用bash流編輯器sed
:
# 刪除每個文件的第一行(標題),生成新格式文件
!sed 1d $TC3_DATA_DIR/train.tsv > $TC3_DATA_DIR/train_nemo_format.tsv
!sed 1d $TC3_DATA_DIR/dev.tsv > $TC3_DATA_DIR/dev_nemo_format.tsv
!sed 1d $TC3_DATA_DIR/test.tsv > $TC3_DATA_DIR/test_nemo_format.tsv
# 查看新生成的訓練、驗證、測試文件(無標題)
print("*****\ntrain_nemo_format.tsv sample\n*****")
!head -n 3 $TC3_DATA_DIR/train_nemo_format.tsv
print("\n\n*****\ndev_nemo_format.tsv sample\n*****")
!head -n 3 $TC3_DATA_DIR/dev_nemo_format.tsv
print("\n\n*****\ntest_nemo_format.tsv sample\n*****")
!head -n 3 $TC3_DATA_DIR/test_nemo_format.tsv
部分輸出如下,可以看到已經去掉了標題行:
2.2 配置文件
模型訓練的所有配置參數都在text_classification_config.yaml中。注意鍵的層次結構,特別是三個頂層鍵:trainer
、model
和exp_manager
。
trainer:gpus:num_nodes:max_epochs:...model:nemo_path:tokenizer: language_model:classifier_head:...exp_manager:...
# 查看 YAML 配置文件內容
CONFIG_DIR = "/dli/task/nemo/examples/nlp/text_classification/conf"
CONFIG_FILE = "text_classification_config.yaml"
!cat $CONFIG_DIR/$CONFIG_FILE
必須修改的參數
參數 | 位置 | 作用 |
---|---|---|
dataset.num_classes | model.dataset | 設置分類的類別數(如2表示二分類) |
train_ds.file_path | model.train_ds | 指定訓練數據的路徑 |
validation_ds.file_path | model.validation_ds | 指定驗證數據的路徑 |
test_ds.file_path | model.test_ds | 指定測試數據的路徑 |
建議修改的參數
參數 | 位置 | 作用 | 修改建議和原因 |
---|---|---|---|
dataset.max_seq_length | model.dataset | 輸入文本的最大長度,默認是256 | 修改為128 ,以減少內存使用,適合資源有限的設備 |
infer_samples | model.infer_samples | 用于訓練完成后測試模型效果的示例句子 | 替換為疾病相關句子,更貼合目標應用領域 |
trainer.max_epochs | trainer | 最大訓練輪數 | 設置為較小值(如3-10),便于快速測試驗證訓練流程是否通順 |
這些參數在初次運行時可以先使用默認值,后續根據效果再微調:
trainer.devices
(當前已設置為1)trainer.precision
(默認32,也可根據硬件支持設為16)trainer.gradient_clip_val
、accumulate_grad_batches
等優化細節參數exp_manager
的所有項(如exp_dir
,create_tensorboard_logger
等)
2.3 訓練模型
Hydra是一個用于配置管理的Python框架,它允許用戶輕松地從命令行覆蓋配置文件(如.yaml
)中的參數,而無需手動修改配置文件本身。所以這個腳本可以通過命令行靈活地傳入配置參數來運行,比如更改訓練集路徑、模型參數等。
下面是我們訓練的python腳本text_classification_with_bert.py的部分內容:
import lightning.pytorch as pl
from omegaconf import DictConfig, OmegaConf # 用于處理 Hydra 加載的配置對象
from nemo.core.config import hydra_runner # 用于接入 Hydra 的裝飾器# 使用 hydra_runner 裝飾器加載配置文件
# config_path 指向配置文件目錄(如 conf/),config_name 是配置文件名(不帶 .yaml 后綴)
@hydra_runner(config_path="conf", config_name="text_classification_config")
def main(cfg: DictConfig) -> None: # cfg 是通過 Hydra 加載的配置對象# 打印當前加載的配置,便于調試和確認print(OmegaConf.to_yaml(cfg))...# 訪問配置中任意字段,如:# cfg.model.train_ds.file_path# cfg.trainer.max_epochs...if __name__ == '__main__':main()
該腳本使用Hydra管理配置文件,也就是說我們可以通過命令行直接覆蓋想要修改的值:
%%time
# 訓練大約耗時 2 分鐘# 設置模型所在目錄
TC_DIR = "/dli/task/nemo/examples/nlp/text_classification"# 設置我們要覆蓋的值
NUM_CLASSES = 3
MAX_SEQ_LENGTH = 128
PATH_TO_TRAIN_FILE = "/dli/task/data/NCBI_tc-3/train_nemo_format.tsv"
PATH_TO_VAL_FILE = "/dli/task/data/NCBI_tc-3/dev_nemo_format.tsv"
PATH_TO_TEST_FILE = "/dli/task/data/NCBI_tc-3/test_nemo_format.tsv"
# 推理樣本對應類別應分別為 0, 1, 2
INFER_SAMPLES_0 = "In contrast no mutations were detected in the p53 gene suggesting that this tumour suppressor is not frequently altered in this leukaemia "
INFER_SAMPLES_1 = "The first predictive testing for Huntington disease was based on analysis of linked polymorphic DNA markers to estimate the likelihood of inheriting the mutation for HD"
INFER_SAMPLES_2 = "Further studies suggested that low dilutions of C5D serum contain a factor or factors interfering at some step in the hemolytic assay of C5 rather than a true C5 inhibitor or inactivator"
MAX_EPOCHS = 3# 運行訓練腳本并通過命令行覆蓋默認配置參數
!python $TC_DIR/text_classification_with_bert.py \model.dataset.num_classes=$NUM_CLASSES \model.dataset.max_seq_length=$MAX_SEQ_LENGTH \model.train_ds.file_path=$PATH_TO_TRAIN_FILE \model.validation_ds.file_path=$PATH_TO_VAL_FILE \model.test_ds.file_path=$PATH_TO_TEST_FILE \model.infer_samples=["$INFER_SAMPLES_0","$INFER_SAMPLES_1","$INFER_SAMPLES_2"] \trainer.max_epochs=$MAX_EPOCHS
每次訓練實驗開始時,都會打印實驗配置的日志,包括通過命令行添加或覆蓋的參數。它還會顯示一些附加信息,例如可用GPU、日志保存位置、數據集樣本及其對應的模型輸入。日志中還提供了數據集中序列長度的統計信息。
每個訓練epoch結束后,會輸出一張驗證集指標表,包括準確率、召回率和F1分數。F1分數同時考慮了假陽性和假陰性,被認為比單純的準確率更有意義。
訓練結束后,NeMo會將最后的檢查點保存在model.nemo_file_path
指定的路徑。由于我們使用的是默認值,它應該已保存在.nemo
格式的工作區中。
# 列出當前目錄下的 .nemo 模型文件
!ls *.nemo# 輸出 text_classification_model.nemo
上面代碼的結果可能不是很好,但僅需微調幾個參數,就可以很容易嘗試另一個實驗。通過延長訓練時間、調整學習率以及更改訓練和驗證集的batch size都可能提升結果。
試著做以下優化,再運行一次代碼:
- 設置混合精度
amp_level
為 “O1”,precision
為16,這樣模型運行更快,精度下降很小甚至不下降。 - 將訓練epoch數調高,以取得更好的結果。
- 略微提升學習率,使模型權重更新時對誤差響應更敏感。
下面用TensorBoard
可視化一下訓練的過程,我們選擇訓練損失標量。圖中的橙色是第一次運行,藍色是第二次運行。可以看到第二次的loss更小。
更換語言模型
運行以下單元格查看NeMo支持的BERT類語言模型列表:
# 查看支持的預訓練語言模型列表
from nemo.collections import nlp as nemo_nlp
nemo_nlp.modules.get_pretrained_lm_models_list()
我們還可以通過修改yaml
中的PRETRAINED_MODEL_NAME
字段來選擇一個新的語言模型,例如megatron-bert-345m-cased
。
- 為了節省GPU顯存,你還可以將
batch_size
降到32、16甚至8,將max_seq_length
(每條文本的token
長度) 降到64。
3 PyTorch Lightning詳解
雖然NeMo提供了諸如text_classification_with_bert.py
的訓練腳本,能夠通過配置文件一鍵完成模型訓練、評估和推理。但在需要更靈活控制訓練流程的場景下(例如自定義模型結構、修改損失函數或逐步調試訓練過程),你可以跳出NeMo的封裝,直接采用PyTorch Lightning的原生工作方式:手動構建Trainer
、初始化模型、并調用fit()
、test()
等方法,實現對整個訓練過程的精細化掌控。
import lightning.pytorch as pl
from omegaconf import DictConfig, OmegaConffrom nemo.collections.nlp.models.text_classification import TextClassificationModel
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager@hydra_runner(config_path="conf", config_name="text_classification_config")
def main(cfg: DictConfig) -> None:try:strategy = NLPDDPStrategy(find_unused_parameters=True)except (ImportError, ModuleNotFoundError):strategy = 'auto'trainer = pl.Trainer(strategy=strategy, **cfg.trainer)exp_manager(trainer, cfg.get("exp_manager", None))if not cfg.model.train_ds.file_path:raise ValueError("'train_ds.file_path' need to be set for the training!")model = TextClassificationModel(cfg.model, trainer=trainer)trainer.fit(model)if cfg.model.nemo_path:model.save_to(cfg.model.nemo_path)logging.info(f'Model is saved into `.nemo` file: {cfg.model.nemo_path}')# We evaluate the trained model on the test set if test_ds is set in the config fileif cfg.model.test_ds.file_path:trainer.test(model=model, ckpt_path=None, verbose=False)# perform inference on a list of queries.if "infer_samples" in cfg.model and cfg.model.infer_samples:logging.info("Starting the inference on some sample queries...")# max_seq_length=512 is the maximum length BERT supports.results = model.classifytext(queries=cfg.model.infer_samples, batch_size=16, max_seq_length=512)for query, result in zip(cfg.model.infer_samples, results):logging.info(f'Query : {query}')logging.info(f'Predicted label: {result}')if __name__ == '__main__':main()
3.1 LightningModule、Trainer和exp_manager
1.LightningModule
:定義模型邏輯
LightningModule
,它是你定義模型邏輯的地方,也是Lightning框架解耦模型代碼與工程代碼的關鍵。
LightningModule
是PyTorch Lightning的核心類,用于封裝模型結構與訓練邏輯,它是對原生PyTorch中nn.Module
的擴展。你在其中實現所有與模型有關的邏輯,包括:
- 模型結構定義(如BERT、CNN等)
- 前向傳播 (
forward
) - 損失函數計算(在
training_step
中定義) - 驗證與測試流程
- 優化器和學習率調度器配置
在原生PyTorch中,訓練邏輯分散在多個位置:
- 模型結構定義在
nn.Module
子類中 - 訓練循環手動寫
- 損失函數、優化器單獨配置
- 日志記錄、GPU分發、checkpoint都要手動處理
隨著模型規模變大,這種方式變得難以維護和復現。PyTorch Lightning提出了解耦思想:模型邏輯(做什么)放在LightningModule
里,工程控制(怎么做)交給Trainer
自動處理。
以下是一個典型的LightningModule
模板:
import pytorch_lightning as pl
import torch.nn as nn
import torchclass MyModel(pl.LightningModule):def __init__(self):super().__init__()self.model = nn.Linear(128, 3)self.loss_fn = nn.CrossEntropyLoss()def forward(self, x):return self.model(x)def training_step(self, batch, batch_idx):x, y = batchlogits = self(x)loss = self.loss_fn(logits, y)self.log("train_loss", loss)return lossdef validation_step(self, batch, batch_idx):...def test_step(self, batch, batch_idx):...def configure_optimizers(self):return torch.optim.Adam(self.parameters(), lr=1e-3)
你只需實現5個核心函數:
函數名 | 作用 |
---|---|
__init__() | 定義模型結構和初始化組件 |
forward() | 定義前向傳播邏輯 |
training_step() | 每個batch的訓練邏輯(計算loss等) |
validation_step() / test_step() | 驗證與測試時的邏輯 |
configure_optimizers() | 定義優化器和學習率調度策略 |
NeMo中的LightningModule應用
在NVIDIA NeMo框架中,有這樣的代碼:
from nemo.collections.nlp.models import TextClassificationModel
這里的TextClassificationModel
就是一個LightningModule
的子類。你無需手寫training_step()
、configure_optimizers()
等邏輯,NeMo已為你封裝好。
但如果你需要:修改損失函數、替換語言模型結構或添加多任務損失或自定義輸出,你就需要深入理解LightningModule
,并可能繼承它進行擴展。
2. trainer
:訓練控制器
trainer
是PyTorch Lightning提供的核心組件,用于統一控制整個訓練流程。它將訓練的細節如設備管理、訓練輪數、分布式訓練、混合精度等都統一封裝,用戶只需通過配置指定參數,即可自動完成這些操作。
Lightning的設計理念之一就是將工程代碼與模型代碼解耦,使用戶只關注模型結構和損失函數等核心內容,而不必重復編寫訓練循環、GPU分發、日志記錄等代碼。
模塊 | 你要做的事 | Lightning自動幫你做 |
---|---|---|
模型結構 | 寫forward() | Lightning自動調用 |
損失計算 | 實現training_step() | Lightning自動收集loss |
Optimizer | 實現configure_optimizers() | 自動調用優化器、scheduler |
訓練邏輯 | 不寫for循環 | Lightning管理訓練輪、batch |
GPU訓練 | 設置gpus=1 | 自動.cuda() 、分布式訓練 |
日志記錄 | 用self.log() | 自動寫入TensorBoard |
Checkpoint | 不手寫保存代碼 | 自動保存/加載ckpt |
混合精度 | precision=16 | 自動使用AMP |
pl.Trainer
提供了以下核心功能:
- 控制訓練輪數(
max_epochs
) - 管理硬件資源(
gpus
,tpus
,strategy
等) - 自動支持分布式訓練(
DDP
) - 啟用混合精度訓練(
precision=16
) - 控制日志頻率、驗證頻率(
log_every_n_steps
,val_check_interval
) - 梯度累計與裁剪(
accumulate_grad_batches
,gradient_clip_val
) - 自動斷點恢復(
resume_from_checkpoint
)
一般使用如下方式創建trainer
實例:
import pytorch_lightning as pltrainer = pl.Trainer(**config.trainer)
其中config.trainer
是通過YAML配置文件或者OmegaConf
加載后的字典結構,定義了所有訓練相關參數。
例如,YAML配置中可能包含:
trainer:gpus: 1max_epochs: 5precision: 16amp_level: O1log_every_n_steps: 10
3. exp_manager
:實驗管理器
exp_manager
是NeMo特有的實驗管理工具,其主要目的是為訓練過程提供自動化的日志記錄、模型checkpoint保存、目錄結構管理等功能。它是對PyTorch Lightning的日志系統的進一步封裝與增強。
在機器學習項目中,隨著實驗的增多,如何規范地組織日志、模型文件、超參數記錄,成為一個影響效率的問題。exp_manager
正是為了解決這一問題而設計。它有如下功能:
- 自動創建實驗目錄
- 自動保存訓練日志(TensorBoard, MLFlow, WandB等可選)
- 自動保存模型checkpoint(包括last.ckpt和best.ckpt)
- 支持從上一次訓練中斷處恢復
- 自動記錄超參數(寫入hparams.yaml)
- 管理日志命名規則(按模型類型、時間戳、版本號分類)
調用方式如下:
from nemo.utils.exp_manager import exp_managerexp_manager(trainer, config.exp_manager)
其中config.exp_manager
是包含日志和實驗設置的配置段。例如:
exp_manager:exp_dir: nullname: TextClassificationModelcreate_tensorboard_logger: truecreate_checkpoint_callback: true
如果exp_dir
設置為null
,則會默認創建nemo_experiments/<model_name>/version_x/
目錄結構。
運行完一個訓練流程后,輸出目錄可能如下:
nemo_experiments/└── TextClassificationModel/└── version_0/├── checkpoints/│ ├── TextClassificationModel--val_loss=0.432.ckpt├── hparams.yaml└── events.out.tfevents... (TensorBoard 日志)
這樣,每次實驗都有完整的、隔離的輸出記錄,方便追蹤、對比和復現。
3.2 腳本解析
現在我們就來看一下這個訓練腳本到底做了什么。執行以下單元格以重啟內核,清除變量和GPU內存。
# 重啟 Notebook 內核,釋放內存與資源
import IPython
app = IPython.Application.instance()
app.kernel.do_shutdown(True)
我們首先導入所需模塊:NeMo NLP模塊、實驗管理器、PyTorch Lightning、OmegaConf。
# 導入必需模塊
from nemo.collections import nlp as nemo_nlp
from nemo.utils.exp_manager import exp_managerimport torch
import pytorch_lightning as pl
from omegaconf import OmegaConf
這里我們用OmegaConf
(使用python配置yaml字段) 來修改配置對象,然后將其傳遞給Trainer
、exp_manager
和TextClassificationModel
。如果你想使用非默認語言模型,可添加以下代碼:
PRETRAINED_MODEL_NAME = 'bert-base-cased'
config.model.language_model.pretrained_model_name = PRETRAINED_MODEL_NAME
# 載入 YAML 配置文件
TC_DIR = "/dli/task/nemo/examples/nlp/text_classification"
CONFIG_FILE = "text_classification_config.yaml"
config = OmegaConf.load(TC_DIR + "/conf/" + CONFIG_FILE)# 設置訓練所需參數
NUM_CLASSES = 3
MAX_SEQ_LENGTH = 128
PATH_TO_TRAIN_FILE = "/dli/task/data/NCBI_tc-3/train_nemo_format.tsv"
PATH_TO_VAL_FILE = "/dli/task/data/NCBI_tc-3/dev_nemo_format.tsv"
PATH_TO_TEST_FILE = "/dli/task/data/NCBI_tc-3/test_nemo_format.tsv"# 設置推理樣本(標簽應分別為 0, 1, 2)
INFER_SAMPLES = ["Germline mutations in BRCA1 are responsible for most cases of inherited breast and ovarian cancer ","The first predictive testing for Huntington disease was based on analysis of linked polymorphic DNA markers to estimate the likelihood of inheriting the mutation for HD","Further studies suggested that low dilutions of C5D serum contain a factor or factors interfering at some step in the hemolytic assay of C5 rather than a true C5 inhibitor or inactivator"
]MAX_EPOCHS = 5
AMP_LEVEL = 'O1'
PRECISION = 16
LR = 5.0e-05# 使用 OmegaConf 修改配置對象
config.model.dataset.num_classes = NUM_CLASSES
config.model.dataset.max_seq_length = MAX_SEQ_LENGTH
config.model.train_ds.file_path = PATH_TO_TRAIN_FILE
config.model.validation_ds.file_path = PATH_TO_VAL_FILE
config.model.test_ds.file_path = PATH_TO_TEST_FILE
config.model.infer_samples = INFER_SAMPLES
config.trainer.max_epochs = MAX_EPOCHS
config.trainer.amp_level = AMP_LEVEL
config.trainer.precision = PRECISION
config.model.optim.lr = LR
現在配置完成,初始化Trainer和實驗管理器:
# 初始化 Trainer 和實驗管理器
trainer = pl.Trainer(**config.trainer)
exp_manager(trainer, config.exp_manager)
# 使用更新后的 config 初始化文本分類模型
model = nemo_nlp.models.TextClassificationModel(config.model, trainer=trainer)
%%time
# 開始訓練并保存模型
trainer.fit(model)
model.save_to(config.model.nemo_path)
訓練完成后,使用測試集進行評估:
# 使用 test 集評估模型
trainer.test(model=model, verbose=False)
輸出如下:
現在運行推理,對配置中設置的句子進行分類:
# 查看推理樣本
print(config.model.infer_samples)
# 對推理樣本執行文本分類
model.classifytext(queries=config.model.infer_samples, batch_size=64, max_seq_length=128)
batch_size
控制每次推理處理的文本數量,可根據顯存和并發需求靈活設置。max_seq_length
控制輸入文本的最大截斷長度,但不能超過模型支持的上限(如BERT的512)和訓練時的最大值
如果你想對新的文本做推理,不必將其添加到配置文件中,可以直接傳入列表:
# 設置你自己的推理語句列表
my_queries = ['Clustering of missense mutations in the ataxia-telangiectasia gene in a sporadic T-cell leukaemia','Myotonic dystrophy protein kinase is involved in the modulation of the Ca2+ homeostasis in skeletal muscle cells.','Constitutional RB1-gene mutations in patients with isolated unilateral retinoblastoma.','Hereditary deficiency of the fifth component of complement in man. I. Clinical, immunochemical, and family studies.'
]
運行推理:
model.classifytext(queries=my_queries, batch_size=16, max_seq_length=64)
理想結果應為 [0, 1, 2, 2],但輸出如下,說明模型準確率還有提高空間。
[2, 1, 2, 2]
4 總結
通過本篇實踐,我們不僅完成了一個醫學文本分類器的搭建,還深入理解了NeMo與PyTorch Lightning的協作關系。NeMo的封裝加速了模型落地,而Lightning的模塊化設計則為進一步定制打下基礎。未來你可以基于該框架輕松擴展至多分類、多語言模型或其他醫學NLP任務,構建更具實際價值的AI應用。