一個生成項目
輸入:文字描述(但是給的數據集是一串數字,id,ct描述,醫生描述)
輸出:診斷報告
一、數據處理
import pandas as pd #處理表格數據pre_train_file= "data/train.csv"train_df = pd.read_csv(pre_train_file,header=None,names=["id","input","tgt"]) #讀入數據print(train_df.head())train_data = train_df.sample(frac=0.9, random_state=0, axis=0) #采樣0.9的比例val_data = train_df[~train_df.index.isin(train_data.index)] #干啥的, 過來用train_data.to_csv("data/pro_train_data.csv", index=False,header=False)val_data.to_csv("data/pro_val_data.csv", index=False,header=False)
主要是用于從一個CSV文件中讀取數據,并將其劃分為訓練集和驗證集,然后將這兩個數據集分別保存到新的CSV文件中。
代碼逐行解釋
導入必要的庫
import pandas as pd # 處理表格數據
pandas
:一個強大的數據分析和處理庫,特別適合處理表格數據(如CSV文件)。
定義文件路徑并讀取數據
pre_train_file = "data/train.csv"train_df = pd.read_csv(pre_train_file, header=None, names=["id", "input", "tgt"]) # 讀入數據print(train_df.head())
pre_train_file
:指定要讀取的CSV文件路徑。pd.read_csv
:header=None
:表示CSV文件沒有表頭(第一行不是列名)。names=["id", "input", "tgt"]
:為每一列指定名稱。
print(train_df.head())
:打印前五行數據,以便檢查讀取是否正確。
數據劃分
train_data = train_df.sample(frac=0.9, random_state=0, axis=0) # 采樣0.9的比例val_data = train_df[~train_df.index.isin(train_data.index)] # 干啥的, 過來用
-
train_data
:- 使用
sample
方法隨機采樣90%的數據作為訓練集。 frac=0.9
:表示采樣的比例為90%。random_state=0
:設置隨機種子以確保結果可重復。axis=0
:表示沿行方向進行采樣(默認行為)。
- 使用
-
val_data
:- 使用
~train_df.index.isin(train_data.index)
來獲取不在訓練集中的數據作為驗證集。 isin(train_data.index)
返回一個布爾數組,指示哪些索引在訓練集中。~
取反操作符,返回不在訓練集中的索引。
- 使用
保存數據
train_data.to_csv("data/pro_train_data.csv", index=False, header=False)val_data.to_csv("data/pro_val_data.csv", index=False, header=False)
to_csv
方法:- 將DataFrame保存為CSV文件。
index=False
:不保存行索引。header=False
:不保存列名。
二、處理詞表
import sys
import torch
from collections import Counter
from transformers import BertTokenizer
from transformers import BartConfig
from transformers import BartForConditionalGeneration
from model_utils.config import parse_argsargs = parse_args() #設置 ,字典, 屬性類 config {}def load_data(path):with open(path, 'r', encoding='utf-8') as f:lines = f.readlines()datas = []for line in lines:line = line.strip().split(",")if len(line) == 3:# 訓練集text, target = line[1].split(" "), line[2].split(" ")datas.append(text + target)else:text = line[1].split(" ")datas.append(text)return datastrain_data = load_data('./data/train.csv')token2count = Counter() #計數工具 哈希表for i in train_data:token2count.update(i) #不需要知道原理tail = []
ct = 0
for k, v in token2count.items():if v >= ct:tail.append(k)
tail.sort()
vocab = tailvocab.insert(0,"[PAD]")
vocab.insert(100,"[UNK]")
vocab.insert(101,"[CLS]")
vocab.insert(102,"[SEP]")
vocab.insert(103,"[MASK]")
vocab.insert(104,"[EOS]")
# tokenizer = BertTokenizer.from_pretrained(args.pre_model_path)
# vocabs = tokenizer.get_vocab() #獲取模型詞表# new_vocabs = list(vocabs.keys())
# print(len(vocabs))
# count = 0
# for v in vocab: #mn復雜度
# if v not in vocabs:
# count += 1
# new_vocabs.append(v)
# print(len(new_vocabs))
new_vocabs = vocab
with open(args.pre_model_path+'/vocab.txt', 'w', encoding='utf-8') as f:for v in new_vocabs:f.write(f"{v}\n") #保存model = BartForConditionalGeneration.from_pretrained(args.pre_model_path) #模型
model.resize_token_embeddings(len(new_vocabs))
state_dict = model.state_dict()
torch.save(state_dict, args.pre_model_path+'/pytorch_model.bin')
bartconfig = BartConfig.from_pretrained(args.pre_model_path)
bartconfig.vocab_size = len(new_vocabs)
bartconfig.save_pretrained(args.pre_model_path)
1. 導入必要的庫
import sys
import torch
from collections import Counter
from transformers import BertTokenizer
from transformers import BartConfig
from transformers import BartForConditionalGeneration
from model_utils.config import parse_args
sys
:用于系統相關的操作(如命令行參數)。torch
:PyTorch的核心庫,用于深度學習模型。Counter
:來自collections
模塊,用于統計元素出現的次數。BertTokenizer
,BartConfig
,BartForConditionalGeneration
:來自transformers
庫,分別用于分詞、配置和加載預訓練模型。parse_args
:自定義函數,用于解析命令行參數或配置文件,返回一個包含配置參數的對象。
2. 解析參數
args = parse_args() # 設置,字典,屬性類 config {}
parse_args
:調用自定義函數解析配置參數,并將其存儲在args
對象中。假設args
包含諸如pre_model_path
等路徑信息。
3. 定義數據加載函數
def load_data(path):with open(path, 'r', encoding='utf-8') as f:lines = f.readlines()datas = []for line in lines:line = line.strip().split(",")if len(line) == 3:# 訓練集text, target = line[1].split(" "), line[2].split(" ")datas.append(text + target)else:text = line[1].split(" ")datas.append(text)return datas
load_data
函數:- 打開指定路徑的文件并讀取每一行。
- 使用
strip()
去除每行的前后空白字符,并使用split(",")
將其按逗號分割為列表。 - 如果列表長度為3(假設是訓練集),則將第二列和第三列的數據拆分為單詞列表,并合并后添加到
datas
列表中。 - 如果列表長度不為3,則僅處理第二列的數據,并將其拆分為單詞列表后添加到
datas
列表中。 - 返回
datas
列表。
4. 加載數據
train_data = load_data('./data/train.csv')
- 調用
load_data
函數加載訓練數據,并將結果存儲在train_data
變量中。
5. 統計詞頻
token2count = Counter() # 計數工具 哈希表for i in train_data:token2count.update(i) # 不需要知道原理
token2count
:使用Counter
類創建一個哈希表來統計每個單詞出現的次數。- 遍歷
train_data
中的每一行數據,并使用update
方法更新token2count
,記錄每個單詞出現的次數。
6. 創建詞匯表
tail = []
ct = 0
for k, v in token2count.items():if v >= ct:tail.append(k)
tail.sort()
vocab = tailvocab.insert(0, "[PAD]")
vocab.insert(100, "[UNK]")
vocab.insert(101, "[CLS]")
vocab.insert(102, "[SEP]")
vocab.insert(103, "[MASK]")
vocab.insert(104, "[EOS]")
tail
:篩選出頻率大于等于ct
的單詞,并按字母順序排序。注意這里ct
設為0,因此所有單詞都會被包含進來。vocab
:將tail
賦值給vocab
。- 插入特殊標記:在
vocab
中插入一些特殊的標記符號(如[PAD]
,[UNK]
,[CLS]
,[SEP]
,[MASK]
,[EOS]
),這些標記在自然語言處理任務中具有特定含義。
7. 保存詞匯表
new_vocabs = vocab
with open(args.pre_model_path + '/vocab.txt', 'w', encoding='utf-8') as f:for v in new_vocabs:f.write(f"{v}\n") # 保存
new_vocabs
:直接賦值為vocab
。- 保存詞匯表:將詞匯表中的每個單詞寫入
vocab.txt
文件中,文件路徑由args.pre_model_path
指定。
8. 加載預訓練模型并調整詞匯表大小
model = BartForConditionalGeneration.from_pretrained(args.pre_model_path) # 模型
model.resize_token_embeddings(len(new_vocabs))
state_dict = model.state_dict()
torch.save(state_dict, args.pre_model_path + '/pytorch_model.bin')bartconfig = BartConfig.from_pretrained(args.pre_model_path)
bartconfig.vocab_size = len(new_vocabs)
bartconfig.save_pretrained(args.pre_model_path)
- 加載預訓練模型:使用
BartForConditionalGeneration.from_pretrained
加載預訓練模型。 - 調整詞匯表大小:使用
resize_token_embeddings
方法調整模型的嵌入層大小以適應新的詞匯表。 - 保存模型狀態:將模型的狀態字典保存到
pytorch_model.bin
文件中,文件路徑由args.pre_model_path
指定。 - 更新配置:更新
BartConfig
中的vocab_size
屬性,并保存配置。
三、自監督預訓練
from model_utils.pre_data import PreTrainDataset, loadData, MLM_Data
from torch.utils.data import DataLoader, Dataset
from model_utils.models import preModel
import logging #日志
import os
from model_utils.config import parse_args
from model_utils.utils import setup_device, setup_seed, setup_logging, build_optimizer
import torch
import time
# os.environ['CUDA_VISIBLE_DEVICES']='0'def train_and_validate(args):# 1. load data modelmodel = preModel(args) #加載預訓練模型optimizer, scheduler = build_optimizer(args, model)# model = model.to(args.device)use_pre = Falseif use_pre:checkpoint = torch.load(args.pre_file, map_location='cpu')new_KEY = model.load_state_dict(checkpoint['model_state_dict'],strict=False)if args.device == 'cuda':if args.paral == True:model = torch.nn.parallel.DataParallel(model.to(args.device))else:model = model.to(args.device)# model = BalancedDataParallel(16, model, dim=0).to(args.device)# model = model.to(args.device)#-------ema here-----------------all_data = loadData(args.data_path)train_MLM_data = MLM_Data(all_data, args)train_dataloader = DataLoader(train_MLM_data, batch_size=args.batch_size, shuffle=True,collate_fn=train_MLM_data.collate)step = 0start_time = time.time()num_total_steps = len(train_dataloader) * args.max_epochsfor epoch in range(args.max_epochs): #開始訓練了for batch in train_dataloader:model.train()loss= model(batch)loss = loss.mean()loss.backward()optimizer.step()optimizer.zero_grad()scheduler.step()step += 1if step % args.print_steps == 0:time_per_step = (time.time() - start_time) / max(1, step)remaining_time = time_per_step * (num_total_steps - step)remaining_time = time.strftime('%H:%M:%S', time.gmtime(remaining_time))logging.info(f"Epoch {epoch} step {step} eta {remaining_time}: loss {loss:.3f}")logging.info(f"VAL_Epoch {epoch} step {step}: loss {loss:.3f}")if epoch % 5 == 0:torch.save({'epoch': epoch, 'model_state_dict': model.module.state_dict()},f'{args.savedmodel_path}/lr{args.learning_rate}epoch{epoch}loss{loss:.3f}pre_model.bin')def main():args = parse_args() #設置 字典setup_logging()setup_device(args)setup_seed(args)os.makedirs(args.savedmodel_path, exist_ok=True)logging.info("Training/evaluation parameters: %s", args) #LINUXtrain_and_validate(args)if __name__ == '__main__':main()
實現了一個完整的訓練和驗證流程,包括數據加載、模型初始化、訓練循環、日志記錄以及模型保存等功能
1. 導入必要的庫
from model_utils.pre_data import PreTrainDataset, loadData, MLM_Data
from torch.utils.data import DataLoader, Dataset
from model_utils.models import preModel
import logging # 日志
import os
from model_utils.config import parse_args
from model_utils.utils import setup_device, setup_seed, setup_logging, build_optimizer
import torch
import time
PreTrainDataset
,loadData
,MLM_Data
:自定義模塊,用于數據處理。DataLoader
,Dataset
:PyTorch提供的類,用于數據加載和管理。preModel
:自定義模型類。logging
:用于記錄日志信息。os
:用于操作系統相關的操作(如文件路徑處理)。parse_args
:自定義函數,解析命令行參數或配置文件。setup_device
,setup_seed
,setup_logging
,build_optimizer
:自定義工具函數,分別用于設置設備、隨機種子、日志記錄和優化器構建。torch
:PyTorch核心庫。time
:用于時間相關操作。
2. 定義訓練和驗證函數
def train_and_validate(args):# 1. 加載數據和模型model = preModel(args) # 加載預訓練模型optimizer, scheduler = build_optimizer(args, model)use_pre = Falseif use_pre:checkpoint = torch.load(args.pre_file, map_location='cpu')new_KEY = model.load_state_dict(checkpoint['model_state_dict'], strict=False)if args.device == 'cuda':if args.paral == True:model = torch.nn.parallel.DataParallel(model.to(args.device))else:model = model.to(args.device)all_data = loadData(args.data_path)train_MLM_data = MLM_Data(all_data, args)train_dataloader = DataLoader(train_MLM_data, batch_size=args.batch_size, shuffle=True, collate_fn=train_MLM_data.collate)step = 0start_time = time.time()num_total_steps = len(train_dataloader) * args.max_epochsfor epoch in range(args.max_epochs): # 開始訓練了for batch in train_dataloader:model.train()loss = model(batch)loss = loss.mean()loss.backward()optimizer.step()optimizer.zero_grad()scheduler.step()step += 1if step % args.print_steps == 0:time_per_step = (time.time() - start_time) / max(1, step)remaining_time = time_per_step * (num_total_steps - step)remaining_time = time.strftime('%H:%M:%S', time.gmtime(remaining_time))logging.info(f"Epoch {epoch} step {step} eta {remaining_time}: loss {loss:.3f}")logging.info(f"VAL_Epoch {epoch} step {step}: loss {loss:.3f}")if epoch % 5 == 0:torch.save({'epoch': epoch, 'model_state_dict': model.module.state_dict()},f'{args.savedmodel_path}/lr{args.learning_rate}epoch{epoch}loss{loss:.3f}pre_model.bin')
解釋
-
加載數據和模型:
- 使用
preModel
類加載預訓練模型。 - 使用
build_optimizer
函數構建優化器和學習率調度器。 - 如果
use_pre
為真,則從指定路徑加載預訓練模型的權重。 - 根據
args.device
和args.paral
參數決定是否使用多GPU并行訓練。
- 使用
-
數據加載:
- 使用
loadData
函數加載所有數據。 - 使用
MLM_Data
類將數據轉換為適合訓練的數據集格式。 - 使用
DataLoader
創建數據加載器,支持批量加載和數據打亂。
- 使用
-
訓練循環:
- 對每個epoch進行遍歷。
- 對每個batch進行前向傳播計算損失,反向傳播更新權重。
- 記錄訓練進度和剩余時間,并在特定步數時打印日志。
- 每隔5個epoch保存一次模型。
3. 主函數
def main():args = parse_args() # 設置 字典setup_logging()setup_device(args)setup_seed(args)os.makedirs(args.savedmodel_path, exist_ok=True)logging.info("Training/evaluation parameters: %s", args) # LINUXtrain_and_validate(args)if __name__ == '__main__':main()
main
函數:- 調用
parse_args
解析命令行參數。 - 調用
setup_logging
配置日志記錄。 - 調用
setup_device
和setup_seed
分別設置設備和隨機種子。 - 創建保存模型的目錄(如果不存在)。
- 打印訓練和評估參數。
- 調用
train_and_validate
函數開始訓練和驗證過程。
- 調用
四、微調
import logging
import os
import time
import torch
from transformers import PretrainedBartModel
from model_utils.config import parse_args
from model_utils.data import create_dataloaders
from model_utils.models import myModel
from model_utils.score import CiderD, CE
from model_utils.utils import setup_device, setup_seed, setup_logging, build_optimizer,array2str
from torch.cuda.amp import autocast as ac
from tqdm import tqdm as tqdmos.environ['CUDA_VISIBLE_DEVICES']='0'# 不需要完全理解, 知道每一塊在做什么就行 知道之后, 以后再用到, 搬過去就行def validate(model, loader, args, output_file=None, beam=1, n=-1):res, gts = [], {}tot = 0for (source, targets) in tqdm(loader):if n>0 and tot>n:breaksource = source.cuda()pred = model(source[:, :args. input_l])pred = pred.cpu().detach().numpy()#print(pred.shape)for i in range(pred.shape[0]):# res.append({'image_id':tot, 'caption': [array2str(pred[i][2:], args)]})# gts[tot] = [array2str(targets[i][1:], args)]res.append({'image_id':tot, 'caption': [array2str(pred[i], args)]})gts[tot] = [array2str(targets[i][1:], args)]tot += 1CiderD_scorer = CiderD(df='corpus', sigma=15)cider_score, cider_scores = CiderD_scorer.compute_score(gts, res)return cider_scoredef train_and_validate(args):# 1. load datatrain_dataloader, val_dataloader = create_dataloaders(args)model = myModel(args)use_pre = Trueif use_pre:print('use_pre')checkpoint = torch.load(args.my_pre_model_path, map_location='cpu')new_KEY = model.load_state_dict(checkpoint['model_state_dict'],strict=True)optimizer, scheduler = build_optimizer(args, model)model = model.to(args.device)#-------ema here-----------------model.train()#-------------------------------# loss, results = validate(model, val_dataloader)# 3. trainingstep = 0best_score = args.best_score #評估指標 準確率for epoch in range(args.max_epochs):for (source, targets) in tqdm(train_dataloader):source = source.cuda()targets = targets.cuda()model.train()pred = model(source[:, :args. input_l], targets[:, :args.output_l])loss = CE(pred[:, :-1], targets[:, 1:])loss = loss.mean()loss.backward()optimizer.step()model.zero_grad()scheduler.step()step += 1if epoch % 1 == 0:cider_score = validate(model, val_dataloader, args)logging.info(f"Epoch {epoch} step {step}: loss {loss:.3f}, cider_score {cider_score}")if cider_score >= best_score:best_score = cider_scoretorch.save({'epoch': epoch, 'model_state_dict': model.state_dict()},f'{args.savedmodel_path}/model_epoch_{epoch}_cider_score_{cider_score}.bin')def main():args = parse_args()setup_logging()setup_device(args)setup_seed(args)os.makedirs(args.savedmodel_path, exist_ok=True)logging.info("Training/evaluation parameters: %s", args)train_and_validate(args)if __name__ == '__main__':main()
實現了一個完整的訓練和驗證流程,包括數據加載、模型初始化、訓練循環、驗證評估以及模型保存等功能。
1. 導入必要的庫
import logging
import os
import time
import torch
from transformers import PretrainedBartModel
from model_utils.config import parse_args
from model_utils.data import create_dataloaders
from model_utils.models import myModel
from model_utils.score import CiderD, CE
from model_utils.utils import setup_device, setup_seed, setup_logging, build_optimizer, array2str
from torch.cuda.amp import autocast as ac
from tqdm import tqdm as tqdmos.environ['CUDA_VISIBLE_DEVICES'] = '0'
logging
:用于記錄日志信息。os
:用于操作系統相關的操作(如文件路徑處理)。time
:用于時間相關操作。torch
:PyTorch核心庫。PretrainedBartModel
:來自transformers
庫的預訓練模型基類。parse_args
:自定義函數,解析命令行參數或配置文件。create_dataloaders
:自定義函數,創建數據加載器。myModel
:自定義模型類。CiderD
,CE
:自定義評分函數,分別用于計算CIDEr-D分數和交叉熵損失。setup_device
,setup_seed
,setup_logging
,build_optimizer
,array2str
:自定義工具函數,分別用于設置設備、隨機種子、日志記錄、構建優化器和數組轉字符串。autocast
:用于混合精度訓練。tqdm
:用于顯示進度條。
2. 定義驗證函數
def validate(model, loader, args, output_file=None, beam=1, n=-1):res, gts = [], {}tot = 0for (source, targets) in tqdm(loader):if n > 0 and tot > n:breaksource = source.cuda()pred = model(source[:, :args.input_l])pred = pred.cpu().detach().numpy()for i in range(pred.shape[0]):res.append({'image_id': tot, 'caption': [array2str(pred[i], args)]})gts[tot] = [array2str(targets[i][1:], args)]tot += 1CiderD_scorer = CiderD(df='corpus', sigma=15)cider_score, cider_scores = CiderD_scorer.compute_score(gts, res)return cider_score
解釋
-
輸入參數:
model
: 需要驗證的模型。loader
: 數據加載器。args
: 命令行參數或配置對象。output_file
: 輸出文件路徑(可選)。beam
: 束搜索寬度(可選,默認為1)。n
: 驗證樣本數限制(可選,默認為-1,表示不限制)。
-
邏輯:
- 初始化結果列表
res
和真實標簽字典gts
。 - 使用
tqdm
顯示進度條遍歷數據加載器中的每個批次(source, targets)
。 - 將
source
移動到 GPU 并進行前向傳播得到預測結果pred
。 - 將預測結果和真實標簽轉換為字符串格式并添加到
res
和gts
中。 - 使用
CiderD
計算預測結果與真實標簽之間的 CIDEr-D 分數。 - 返回 CIDEr-D 分數。
- 初始化結果列表
3. 定義訓練和驗證函數
def train_and_validate(args):# 1. load datatrain_dataloader, val_dataloader = create_dataloaders(args)model = myModel(args)use_pre = Trueif use_pre:print('use_pre')checkpoint = torch.load(args.my_pre_model_path, map_location='cpu')new_KEY = model.load_state_dict(checkpoint['model_state_dict'], strict=True)optimizer, scheduler = build_optimizer(args, model)model = model.to(args.device)model.train()step = 0best_score = args.best_score # 評估指標 準確率for epoch in range(args.max_epochs):for (source, targets) in tqdm(train_dataloader):source = source.cuda()targets = targets.cuda()model.train()pred = model(source[:, :args.input_l], targets[:, :args.output_l])loss = CE(pred[:, :-1], targets[:, 1:])loss = loss.mean()loss.backward()optimizer.step()model.zero_grad()scheduler.step()step += 1if epoch % 1 == 0:cider_score = validate(model, val_dataloader, args)logging.info(f"Epoch {epoch} step {step}: loss {loss:.3f}, cider_score {cider_score}")if cider_score >= best_score:best_score = cider_scoretorch.save({'epoch': epoch, 'model_state_dict': model.state_dict()},f'{args.savedmodel_path}/model_epoch_{epoch}_cider_score_{cider_score}.bin')
解釋
-
加載數據:
- 使用
create_dataloaders
函數加載訓練和驗證數據加載器。
- 使用
-
初始化模型和優化器:
- 使用
myModel
類加載模型。 - 如果
use_pre
為真,則從指定路徑加載預訓練模型的權重。 - 使用
build_optimizer
函數構建優化器和學習率調度器。 - 將模型移動到指定設備(CPU或GPU)。
- 使用
-
訓練循環:
- 對每個epoch進行遍歷。
- 對每個batch進行前向傳播計算損失,反向傳播更新權重。
- 每個epoch結束后調用
validate
函數計算驗證集上的 CIDEr-D 分數。 - 如果當前 CIDEr-D 分數優于歷史最佳分數,則保存模型。
4. 主函數
def main():args = parse_args() # 設置 字典setup_logging()setup_device(args)setup_seed(args)os.makedirs(args.savedmodel_path, exist_ok=True)logging.info("Training/evaluation parameters: %s", args) # LINUXtrain_and_validate(args)if __name__ == '__main__':main()
解釋
- 主函數:
- 調用
parse_args
解析命令行參數。 - 調用
setup_logging
配置日志記錄。 - 調用
setup_device
和setup_seed
分別設置設備和隨機種子。 - 創建保存模型的目錄(如果不存在)。
- 打印訓練和評估參數。
- 調用
train_and_validate
函數開始訓練和驗證過程。
- 調用
五、inference
from tqdm import tqdm
import csv
from model_utils.utils import to_device, array2str
from model_utils.models import myModel
from model_utils.data import create_dataloaders
import torch
from model_utils.config import parse_argsdef inference(args):test_loader = create_dataloaders(args,test=True)model = myModel(args)print(args.ckpt_file)checkpoint = torch.load(args.ckpt_file, map_location='cpu')model.load_state_dict(checkpoint['model_state_dict'],strict=False)model.to('cuda:0')model.eval()fp = open(args.test_output_csv, 'w', newline='')writer = csv.writer(fp)tot = 0for source in tqdm(test_loader):source = to_device(source, 'cuda:0')pred = model(source)pred = pred.cpu().numpy()for i in range(pred.shape[0]):writer.writerow([tot, array2str(pred[i][2:], args)])tot += 1fp.close()if __name__ == '__main__':args = parse_args()inference(args)
實現了一個推理(inference)流程,包括數據加載、模型加載、前向傳播以及結果保存等功能。
1. 導入必要的庫
from tqdm import tqdm
import csv
from model_utils.utils import to_device, array2str
from model_utils.models import myModel
from model_utils.data import create_dataloaders
import torch
from model_utils.config import parse_args
tqdm
:用于顯示進度條。csv
:用于處理CSV文件的讀寫操作。to_device
:自定義函數,將數據移動到指定設備(CPU或GPU)。array2str
:自定義函數,將數組轉換為字符串。myModel
:自定義模型類。create_dataloaders
:自定義函數,創建數據加載器。torch
:PyTorch核心庫。parse_args
:自定義函數,解析命令行參數或配置文件。
2. 定義推理函數
def inference(args):test_loader = create_dataloaders(args, test=True)model = myModel(args)print(args.ckpt_file)checkpoint = torch.load(args.ckpt_file, map_location='cpu')model.load_state_dict(checkpoint['model_state_dict'], strict=False)model.to('cuda:0')model.eval()fp = open(args.test_output_csv, 'w', newline='')writer = csv.writer(fp)tot = 0for source in tqdm(test_loader):source = to_device(source, 'cuda:0')pred = model(source)pred = pred.cpu().numpy()for i in range(pred.shape[0]):writer.writerow([tot, array2str(pred[i][2:], args)])tot += 1fp.close()
解釋
-
加載測試數據:
- 使用
create_dataloaders
函數加載測試數據加載器,設置test=True
表示加載測試集。
- 使用
-
初始化模型并加載權重:
- 使用
myModel
類加載模型。 - 打印預訓練模型路徑
args.ckpt_file
。 - 使用
torch.load
加載預訓練模型的權重,并使用load_state_dict
方法加載到模型中。 - 將模型移動到 GPU(
cuda:0
),并設置為評估模式(model.eval()
)。
- 使用
-
推理過程:
- 打開輸出 CSV 文件,并創建 CSV 寫入器。
- 使用
tqdm
顯示進度條遍歷測試數據加載器中的每個批次source
。 - 將
source
移動到 GPU 并進行前向傳播得到預測結果pred
。 - 將預測結果轉換為 NumPy 數組,并逐個樣本寫入 CSV 文件。
3. 主函數
if __name__ == '__main__':args = parse_args()inference(args)
- 主函數:
- 調用
parse_args
解析命令行參數。 - 調用
inference
函數開始推理過程。
- 調用