生成任務,大模型

一個生成項目

輸入:文字描述(但是給的數據集是一串數字,id,ct描述,醫生描述)
輸出:診斷報告

一、數據處理

import pandas as pd  #處理表格數據pre_train_file= "data/train.csv"train_df = pd.read_csv(pre_train_file,header=None,names=["id","input","tgt"]) #讀入數據print(train_df.head())train_data = train_df.sample(frac=0.9, random_state=0, axis=0)   #采樣0.9的比例val_data = train_df[~train_df.index.isin(train_data.index)]       #干啥的,  過來用train_data.to_csv("data/pro_train_data.csv", index=False,header=False)val_data.to_csv("data/pro_val_data.csv", index=False,header=False)

主要是用于從一個CSV文件中讀取數據,并將其劃分為訓練集和驗證集,然后將這兩個數據集分別保存到新的CSV文件中。

代碼逐行解釋

導入必要的庫
import pandas as pd  # 處理表格數據
  • pandas:一個強大的數據分析和處理庫,特別適合處理表格數據(如CSV文件)。
定義文件路徑并讀取數據
pre_train_file = "data/train.csv"train_df = pd.read_csv(pre_train_file, header=None, names=["id", "input", "tgt"])  # 讀入數據print(train_df.head())
  • pre_train_file:指定要讀取的CSV文件路徑。
  • pd.read_csv
    • header=None:表示CSV文件沒有表頭(第一行不是列名)。
    • names=["id", "input", "tgt"]:為每一列指定名稱。
  • print(train_df.head()):打印前五行數據,以便檢查讀取是否正確。
數據劃分
train_data = train_df.sample(frac=0.9, random_state=0, axis=0)  # 采樣0.9的比例val_data = train_df[~train_df.index.isin(train_data.index)]  # 干啥的, 過來用
  • train_data

    • 使用 sample 方法隨機采樣90%的數據作為訓練集。
    • frac=0.9:表示采樣的比例為90%。
    • random_state=0:設置隨機種子以確保結果可重復。
    • axis=0:表示沿行方向進行采樣(默認行為)。
  • val_data

    • 使用 ~train_df.index.isin(train_data.index) 來獲取不在訓練集中的數據作為驗證集。
    • isin(train_data.index) 返回一個布爾數組,指示哪些索引在訓練集中。
    • ~ 取反操作符,返回不在訓練集中的索引。
保存數據
train_data.to_csv("data/pro_train_data.csv", index=False, header=False)val_data.to_csv("data/pro_val_data.csv", index=False, header=False)
  • to_csv 方法
    • 將DataFrame保存為CSV文件。
    • index=False:不保存行索引。
    • header=False:不保存列名。

二、處理詞表

import sys
import torch
from collections import Counter
from transformers import BertTokenizer
from transformers import BartConfig
from transformers import BartForConditionalGeneration
from model_utils.config import parse_argsargs = parse_args()         #設置 ,字典, 屬性類  config  {}def load_data(path):with open(path, 'r', encoding='utf-8') as f:lines = f.readlines()datas = []for line in lines:line = line.strip().split(",")if len(line) == 3:# 訓練集text, target = line[1].split(" "), line[2].split(" ")datas.append(text + target)else:text = line[1].split(" ")datas.append(text)return datastrain_data = load_data('./data/train.csv')token2count = Counter()     #計數工具 哈希表for i in train_data:token2count.update(i)       #不需要知道原理tail = []
ct = 0
for k, v in token2count.items():if v >= ct:tail.append(k)
tail.sort()
vocab = tailvocab.insert(0,"[PAD]")
vocab.insert(100,"[UNK]")
vocab.insert(101,"[CLS]")
vocab.insert(102,"[SEP]")
vocab.insert(103,"[MASK]")
vocab.insert(104,"[EOS]")
# tokenizer = BertTokenizer.from_pretrained(args.pre_model_path)
# vocabs = tokenizer.get_vocab()   #獲取模型詞表# new_vocabs = list(vocabs.keys())
# print(len(vocabs))
# count = 0
# for v in vocab:         #mn復雜度
#     if v not in vocabs:
#         count += 1
#         new_vocabs.append(v)
# print(len(new_vocabs))
new_vocabs = vocab
with open(args.pre_model_path+'/vocab.txt', 'w', encoding='utf-8') as f:for v in new_vocabs:f.write(f"{v}\n")    #保存model = BartForConditionalGeneration.from_pretrained(args.pre_model_path)      #模型
model.resize_token_embeddings(len(new_vocabs))
state_dict = model.state_dict()
torch.save(state_dict, args.pre_model_path+'/pytorch_model.bin')
bartconfig = BartConfig.from_pretrained(args.pre_model_path)
bartconfig.vocab_size = len(new_vocabs)
bartconfig.save_pretrained(args.pre_model_path)

1. 導入必要的庫

import sys
import torch
from collections import Counter
from transformers import BertTokenizer
from transformers import BartConfig
from transformers import BartForConditionalGeneration
from model_utils.config import parse_args
  • sys:用于系統相關的操作(如命令行參數)。
  • torch:PyTorch的核心庫,用于深度學習模型。
  • Counter:來自 collections 模塊,用于統計元素出現的次數。
  • BertTokenizer, BartConfig, BartForConditionalGeneration:來自 transformers 庫,分別用于分詞、配置和加載預訓練模型。
  • parse_args:自定義函數,用于解析命令行參數或配置文件,返回一個包含配置參數的對象。

2. 解析參數

args = parse_args()  # 設置,字典,屬性類 config {}
  • parse_args:調用自定義函數解析配置參數,并將其存儲在 args 對象中。假設 args 包含諸如 pre_model_path 等路徑信息。

3. 定義數據加載函數

def load_data(path):with open(path, 'r', encoding='utf-8') as f:lines = f.readlines()datas = []for line in lines:line = line.strip().split(",")if len(line) == 3:# 訓練集text, target = line[1].split(" "), line[2].split(" ")datas.append(text + target)else:text = line[1].split(" ")datas.append(text)return datas
  • load_data 函數
    • 打開指定路徑的文件并讀取每一行。
    • 使用 strip() 去除每行的前后空白字符,并使用 split(",") 將其按逗號分割為列表。
    • 如果列表長度為3(假設是訓練集),則將第二列和第三列的數據拆分為單詞列表,并合并后添加到 datas 列表中。
    • 如果列表長度不為3,則僅處理第二列的數據,并將其拆分為單詞列表后添加到 datas 列表中。
    • 返回 datas 列表。

4. 加載數據

train_data = load_data('./data/train.csv')
  • 調用 load_data 函數加載訓練數據,并將結果存儲在 train_data 變量中。

5. 統計詞頻

token2count = Counter()  # 計數工具 哈希表for i in train_data:token2count.update(i)  # 不需要知道原理
  • token2count:使用 Counter 類創建一個哈希表來統計每個單詞出現的次數。
  • 遍歷 train_data 中的每一行數據,并使用 update 方法更新 token2count,記錄每個單詞出現的次數。

6. 創建詞匯表

tail = []
ct = 0
for k, v in token2count.items():if v >= ct:tail.append(k)
tail.sort()
vocab = tailvocab.insert(0, "[PAD]")
vocab.insert(100, "[UNK]")
vocab.insert(101, "[CLS]")
vocab.insert(102, "[SEP]")
vocab.insert(103, "[MASK]")
vocab.insert(104, "[EOS]")
  • tail:篩選出頻率大于等于 ct 的單詞,并按字母順序排序。注意這里 ct 設為0,因此所有單詞都會被包含進來。
  • vocab:將 tail 賦值給 vocab
  • 插入特殊標記:在 vocab 中插入一些特殊的標記符號(如 [PAD], [UNK], [CLS], [SEP], [MASK], [EOS]),這些標記在自然語言處理任務中具有特定含義。

7. 保存詞匯表

new_vocabs = vocab
with open(args.pre_model_path + '/vocab.txt', 'w', encoding='utf-8') as f:for v in new_vocabs:f.write(f"{v}\n")  # 保存
  • new_vocabs:直接賦值為 vocab
  • 保存詞匯表:將詞匯表中的每個單詞寫入 vocab.txt 文件中,文件路徑由 args.pre_model_path 指定。

8. 加載預訓練模型并調整詞匯表大小

model = BartForConditionalGeneration.from_pretrained(args.pre_model_path)  # 模型
model.resize_token_embeddings(len(new_vocabs))
state_dict = model.state_dict()
torch.save(state_dict, args.pre_model_path + '/pytorch_model.bin')bartconfig = BartConfig.from_pretrained(args.pre_model_path)
bartconfig.vocab_size = len(new_vocabs)
bartconfig.save_pretrained(args.pre_model_path)
  • 加載預訓練模型:使用 BartForConditionalGeneration.from_pretrained 加載預訓練模型。
  • 調整詞匯表大小:使用 resize_token_embeddings 方法調整模型的嵌入層大小以適應新的詞匯表。
  • 保存模型狀態:將模型的狀態字典保存到 pytorch_model.bin 文件中,文件路徑由 args.pre_model_path 指定。
  • 更新配置:更新 BartConfig 中的 vocab_size 屬性,并保存配置。

三、自監督預訓練

from model_utils.pre_data import PreTrainDataset, loadData, MLM_Data
from torch.utils.data import DataLoader, Dataset
from model_utils.models import preModel
import logging        #日志
import os
from model_utils.config import parse_args
from model_utils.utils import setup_device, setup_seed, setup_logging, build_optimizer
import torch
import time
# os.environ['CUDA_VISIBLE_DEVICES']='0'def train_and_validate(args):# 1. load data  modelmodel = preModel(args)     #加載預訓練模型optimizer, scheduler = build_optimizer(args, model)# model = model.to(args.device)use_pre = Falseif use_pre:checkpoint = torch.load(args.pre_file, map_location='cpu')new_KEY = model.load_state_dict(checkpoint['model_state_dict'],strict=False)if args.device == 'cuda':if args.paral == True:model = torch.nn.parallel.DataParallel(model.to(args.device))else:model = model.to(args.device)# model = BalancedDataParallel(16, model, dim=0).to(args.device)# model = model.to(args.device)#-------ema here-----------------all_data = loadData(args.data_path)train_MLM_data = MLM_Data(all_data, args)train_dataloader = DataLoader(train_MLM_data, batch_size=args.batch_size, shuffle=True,collate_fn=train_MLM_data.collate)step = 0start_time = time.time()num_total_steps = len(train_dataloader) * args.max_epochsfor epoch in range(args.max_epochs):    #開始訓練了for batch in train_dataloader:model.train()loss= model(batch)loss = loss.mean()loss.backward()optimizer.step()optimizer.zero_grad()scheduler.step()step += 1if step % args.print_steps == 0:time_per_step = (time.time() - start_time) / max(1, step)remaining_time = time_per_step * (num_total_steps - step)remaining_time = time.strftime('%H:%M:%S', time.gmtime(remaining_time))logging.info(f"Epoch {epoch} step {step} eta {remaining_time}: loss {loss:.3f}")logging.info(f"VAL_Epoch {epoch} step {step}: loss {loss:.3f}")if epoch % 5 == 0:torch.save({'epoch': epoch, 'model_state_dict': model.module.state_dict()},f'{args.savedmodel_path}/lr{args.learning_rate}epoch{epoch}loss{loss:.3f}pre_model.bin')def main():args = parse_args()           #設置   字典setup_logging()setup_device(args)setup_seed(args)os.makedirs(args.savedmodel_path, exist_ok=True)logging.info("Training/evaluation parameters: %s", args)         #LINUXtrain_and_validate(args)if __name__ == '__main__':main()

實現了一個完整的訓練和驗證流程,包括數據加載、模型初始化、訓練循環、日志記錄以及模型保存等功能

1. 導入必要的庫

from model_utils.pre_data import PreTrainDataset, loadData, MLM_Data
from torch.utils.data import DataLoader, Dataset
from model_utils.models import preModel
import logging        # 日志
import os
from model_utils.config import parse_args
from model_utils.utils import setup_device, setup_seed, setup_logging, build_optimizer
import torch
import time
  • PreTrainDataset, loadData, MLM_Data:自定義模塊,用于數據處理。
  • DataLoader, Dataset:PyTorch提供的類,用于數據加載和管理。
  • preModel:自定義模型類。
  • logging:用于記錄日志信息。
  • os:用于操作系統相關的操作(如文件路徑處理)。
  • parse_args:自定義函數,解析命令行參數或配置文件。
  • setup_device, setup_seed, setup_logging, build_optimizer:自定義工具函數,分別用于設置設備、隨機種子、日志記錄和優化器構建。
  • torch:PyTorch核心庫。
  • time:用于時間相關操作。

2. 定義訓練和驗證函數

def train_and_validate(args):# 1. 加載數據和模型model = preModel(args)     # 加載預訓練模型optimizer, scheduler = build_optimizer(args, model)use_pre = Falseif use_pre:checkpoint = torch.load(args.pre_file, map_location='cpu')new_KEY = model.load_state_dict(checkpoint['model_state_dict'], strict=False)if args.device == 'cuda':if args.paral == True:model = torch.nn.parallel.DataParallel(model.to(args.device))else:model = model.to(args.device)all_data = loadData(args.data_path)train_MLM_data = MLM_Data(all_data, args)train_dataloader = DataLoader(train_MLM_data, batch_size=args.batch_size, shuffle=True, collate_fn=train_MLM_data.collate)step = 0start_time = time.time()num_total_steps = len(train_dataloader) * args.max_epochsfor epoch in range(args.max_epochs):    # 開始訓練了for batch in train_dataloader:model.train()loss = model(batch)loss = loss.mean()loss.backward()optimizer.step()optimizer.zero_grad()scheduler.step()step += 1if step % args.print_steps == 0:time_per_step = (time.time() - start_time) / max(1, step)remaining_time = time_per_step * (num_total_steps - step)remaining_time = time.strftime('%H:%M:%S', time.gmtime(remaining_time))logging.info(f"Epoch {epoch} step {step} eta {remaining_time}: loss {loss:.3f}")logging.info(f"VAL_Epoch {epoch} step {step}: loss {loss:.3f}")if epoch % 5 == 0:torch.save({'epoch': epoch, 'model_state_dict': model.module.state_dict()},f'{args.savedmodel_path}/lr{args.learning_rate}epoch{epoch}loss{loss:.3f}pre_model.bin')
解釋
  • 加載數據和模型

    • 使用 preModel 類加載預訓練模型。
    • 使用 build_optimizer 函數構建優化器和學習率調度器。
    • 如果 use_pre 為真,則從指定路徑加載預訓練模型的權重。
    • 根據 args.deviceargs.paral 參數決定是否使用多GPU并行訓練。
  • 數據加載

    • 使用 loadData 函數加載所有數據。
    • 使用 MLM_Data 類將數據轉換為適合訓練的數據集格式。
    • 使用 DataLoader 創建數據加載器,支持批量加載和數據打亂。
  • 訓練循環

    • 對每個epoch進行遍歷。
    • 對每個batch進行前向傳播計算損失,反向傳播更新權重。
    • 記錄訓練進度和剩余時間,并在特定步數時打印日志。
    • 每隔5個epoch保存一次模型。

3. 主函數

def main():args = parse_args()           # 設置   字典setup_logging()setup_device(args)setup_seed(args)os.makedirs(args.savedmodel_path, exist_ok=True)logging.info("Training/evaluation parameters: %s", args)         # LINUXtrain_and_validate(args)if __name__ == '__main__':main()
  • main 函數
    • 調用 parse_args 解析命令行參數。
    • 調用 setup_logging 配置日志記錄。
    • 調用 setup_devicesetup_seed 分別設置設備和隨機種子。
    • 創建保存模型的目錄(如果不存在)。
    • 打印訓練和評估參數。
    • 調用 train_and_validate 函數開始訓練和驗證過程。

四、微調

import logging
import os
import time
import torch
from transformers import PretrainedBartModel
from model_utils.config import parse_args
from model_utils.data import create_dataloaders
from model_utils.models import myModel
from model_utils.score import CiderD, CE
from model_utils.utils import setup_device, setup_seed, setup_logging, build_optimizer,array2str
from torch.cuda.amp import autocast as ac
from tqdm import tqdm as tqdmos.environ['CUDA_VISIBLE_DEVICES']='0'# 不需要完全理解,  知道每一塊在做什么就行   知道之后,  以后再用到, 搬過去就行def validate(model, loader, args, output_file=None, beam=1, n=-1):res, gts = [], {}tot = 0for (source, targets) in tqdm(loader):if n>0 and tot>n:breaksource = source.cuda()pred = model(source[:, :args. input_l])pred = pred.cpu().detach().numpy()#print(pred.shape)for i in range(pred.shape[0]):# res.append({'image_id':tot, 'caption': [array2str(pred[i][2:], args)]})# gts[tot] = [array2str(targets[i][1:], args)]res.append({'image_id':tot, 'caption': [array2str(pred[i], args)]})gts[tot] = [array2str(targets[i][1:], args)]tot += 1CiderD_scorer = CiderD(df='corpus', sigma=15)cider_score, cider_scores = CiderD_scorer.compute_score(gts, res)return cider_scoredef train_and_validate(args):# 1. load datatrain_dataloader, val_dataloader = create_dataloaders(args)model = myModel(args)use_pre = Trueif use_pre:print('use_pre')checkpoint = torch.load(args.my_pre_model_path, map_location='cpu')new_KEY = model.load_state_dict(checkpoint['model_state_dict'],strict=True)optimizer, scheduler = build_optimizer(args, model)model = model.to(args.device)#-------ema here-----------------model.train()#-------------------------------# loss, results = validate(model, val_dataloader)# 3. trainingstep = 0best_score = args.best_score     #評估指標  準確率for epoch in range(args.max_epochs):for (source, targets) in tqdm(train_dataloader):source = source.cuda()targets = targets.cuda()model.train()pred = model(source[:, :args. input_l], targets[:, :args.output_l])loss  = CE(pred[:, :-1], targets[:, 1:])loss = loss.mean()loss.backward()optimizer.step()model.zero_grad()scheduler.step()step += 1if epoch % 1 == 0:cider_score = validate(model, val_dataloader, args)logging.info(f"Epoch {epoch} step {step}: loss {loss:.3f}, cider_score {cider_score}")if cider_score >= best_score:best_score = cider_scoretorch.save({'epoch': epoch, 'model_state_dict': model.state_dict()},f'{args.savedmodel_path}/model_epoch_{epoch}_cider_score_{cider_score}.bin')def main():args = parse_args()setup_logging()setup_device(args)setup_seed(args)os.makedirs(args.savedmodel_path, exist_ok=True)logging.info("Training/evaluation parameters: %s", args)train_and_validate(args)if __name__ == '__main__':main()

實現了一個完整的訓練和驗證流程,包括數據加載、模型初始化、訓練循環、驗證評估以及模型保存等功能。

1. 導入必要的庫

import logging
import os
import time
import torch
from transformers import PretrainedBartModel
from model_utils.config import parse_args
from model_utils.data import create_dataloaders
from model_utils.models import myModel
from model_utils.score import CiderD, CE
from model_utils.utils import setup_device, setup_seed, setup_logging, build_optimizer, array2str
from torch.cuda.amp import autocast as ac
from tqdm import tqdm as tqdmos.environ['CUDA_VISIBLE_DEVICES'] = '0'
  • logging:用于記錄日志信息。
  • os:用于操作系統相關的操作(如文件路徑處理)。
  • time:用于時間相關操作。
  • torch:PyTorch核心庫。
  • PretrainedBartModel:來自 transformers 庫的預訓練模型基類。
  • parse_args:自定義函數,解析命令行參數或配置文件。
  • create_dataloaders:自定義函數,創建數據加載器。
  • myModel:自定義模型類。
  • CiderD, CE:自定義評分函數,分別用于計算CIDEr-D分數和交叉熵損失。
  • setup_device, setup_seed, setup_logging, build_optimizer, array2str:自定義工具函數,分別用于設置設備、隨機種子、日志記錄、構建優化器和數組轉字符串。
  • autocast:用于混合精度訓練。
  • tqdm:用于顯示進度條。

2. 定義驗證函數

def validate(model, loader, args, output_file=None, beam=1, n=-1):res, gts = [], {}tot = 0for (source, targets) in tqdm(loader):if n > 0 and tot > n:breaksource = source.cuda()pred = model(source[:, :args.input_l])pred = pred.cpu().detach().numpy()for i in range(pred.shape[0]):res.append({'image_id': tot, 'caption': [array2str(pred[i], args)]})gts[tot] = [array2str(targets[i][1:], args)]tot += 1CiderD_scorer = CiderD(df='corpus', sigma=15)cider_score, cider_scores = CiderD_scorer.compute_score(gts, res)return cider_score
解釋
  • 輸入參數

    • model: 需要驗證的模型。
    • loader: 數據加載器。
    • args: 命令行參數或配置對象。
    • output_file: 輸出文件路徑(可選)。
    • beam: 束搜索寬度(可選,默認為1)。
    • n: 驗證樣本數限制(可選,默認為-1,表示不限制)。
  • 邏輯

    • 初始化結果列表 res 和真實標簽字典 gts
    • 使用 tqdm 顯示進度條遍歷數據加載器中的每個批次 (source, targets)
    • source 移動到 GPU 并進行前向傳播得到預測結果 pred
    • 將預測結果和真實標簽轉換為字符串格式并添加到 resgts 中。
    • 使用 CiderD 計算預測結果與真實標簽之間的 CIDEr-D 分數。
    • 返回 CIDEr-D 分數。

3. 定義訓練和驗證函數

def train_and_validate(args):# 1. load datatrain_dataloader, val_dataloader = create_dataloaders(args)model = myModel(args)use_pre = Trueif use_pre:print('use_pre')checkpoint = torch.load(args.my_pre_model_path, map_location='cpu')new_KEY = model.load_state_dict(checkpoint['model_state_dict'], strict=True)optimizer, scheduler = build_optimizer(args, model)model = model.to(args.device)model.train()step = 0best_score = args.best_score  # 評估指標 準確率for epoch in range(args.max_epochs):for (source, targets) in tqdm(train_dataloader):source = source.cuda()targets = targets.cuda()model.train()pred = model(source[:, :args.input_l], targets[:, :args.output_l])loss = CE(pred[:, :-1], targets[:, 1:])loss = loss.mean()loss.backward()optimizer.step()model.zero_grad()scheduler.step()step += 1if epoch % 1 == 0:cider_score = validate(model, val_dataloader, args)logging.info(f"Epoch {epoch} step {step}: loss {loss:.3f}, cider_score {cider_score}")if cider_score >= best_score:best_score = cider_scoretorch.save({'epoch': epoch, 'model_state_dict': model.state_dict()},f'{args.savedmodel_path}/model_epoch_{epoch}_cider_score_{cider_score}.bin')
解釋
  • 加載數據

    • 使用 create_dataloaders 函數加載訓練和驗證數據加載器。
  • 初始化模型和優化器

    • 使用 myModel 類加載模型。
    • 如果 use_pre 為真,則從指定路徑加載預訓練模型的權重。
    • 使用 build_optimizer 函數構建優化器和學習率調度器。
    • 將模型移動到指定設備(CPU或GPU)。
  • 訓練循環

    • 對每個epoch進行遍歷。
    • 對每個batch進行前向傳播計算損失,反向傳播更新權重。
    • 每個epoch結束后調用 validate 函數計算驗證集上的 CIDEr-D 分數。
    • 如果當前 CIDEr-D 分數優于歷史最佳分數,則保存模型。

4. 主函數

def main():args = parse_args()  # 設置   字典setup_logging()setup_device(args)setup_seed(args)os.makedirs(args.savedmodel_path, exist_ok=True)logging.info("Training/evaluation parameters: %s", args)  # LINUXtrain_and_validate(args)if __name__ == '__main__':main()
解釋
  • 主函數
    • 調用 parse_args 解析命令行參數。
    • 調用 setup_logging 配置日志記錄。
    • 調用 setup_devicesetup_seed 分別設置設備和隨機種子。
    • 創建保存模型的目錄(如果不存在)。
    • 打印訓練和評估參數。
    • 調用 train_and_validate 函數開始訓練和驗證過程。

五、inference

from tqdm import tqdm
import csv
from model_utils.utils import to_device, array2str
from model_utils.models import myModel
from model_utils.data import create_dataloaders
import torch
from model_utils.config import parse_argsdef inference(args):test_loader = create_dataloaders(args,test=True)model = myModel(args)print(args.ckpt_file)checkpoint = torch.load(args.ckpt_file, map_location='cpu')model.load_state_dict(checkpoint['model_state_dict'],strict=False)model.to('cuda:0')model.eval()fp = open(args.test_output_csv, 'w', newline='')writer = csv.writer(fp)tot = 0for source in tqdm(test_loader):source = to_device(source, 'cuda:0')pred = model(source)pred = pred.cpu().numpy()for i in range(pred.shape[0]):writer.writerow([tot, array2str(pred[i][2:], args)])tot += 1fp.close()if __name__ == '__main__':args = parse_args()inference(args)

實現了一個推理(inference)流程,包括數據加載、模型加載、前向傳播以及結果保存等功能。

1. 導入必要的庫

from tqdm import tqdm
import csv
from model_utils.utils import to_device, array2str
from model_utils.models import myModel
from model_utils.data import create_dataloaders
import torch
from model_utils.config import parse_args
  • tqdm:用于顯示進度條。
  • csv:用于處理CSV文件的讀寫操作。
  • to_device:自定義函數,將數據移動到指定設備(CPU或GPU)。
  • array2str:自定義函數,將數組轉換為字符串。
  • myModel:自定義模型類。
  • create_dataloaders:自定義函數,創建數據加載器。
  • torch:PyTorch核心庫。
  • parse_args:自定義函數,解析命令行參數或配置文件。

2. 定義推理函數

def inference(args):test_loader = create_dataloaders(args, test=True)model = myModel(args)print(args.ckpt_file)checkpoint = torch.load(args.ckpt_file, map_location='cpu')model.load_state_dict(checkpoint['model_state_dict'], strict=False)model.to('cuda:0')model.eval()fp = open(args.test_output_csv, 'w', newline='')writer = csv.writer(fp)tot = 0for source in tqdm(test_loader):source = to_device(source, 'cuda:0')pred = model(source)pred = pred.cpu().numpy()for i in range(pred.shape[0]):writer.writerow([tot, array2str(pred[i][2:], args)])tot += 1fp.close()
解釋
  • 加載測試數據

    • 使用 create_dataloaders 函數加載測試數據加載器,設置 test=True 表示加載測試集。
  • 初始化模型并加載權重

    • 使用 myModel 類加載模型。
    • 打印預訓練模型路徑 args.ckpt_file
    • 使用 torch.load 加載預訓練模型的權重,并使用 load_state_dict 方法加載到模型中。
    • 將模型移動到 GPU(cuda:0),并設置為評估模式(model.eval())。
  • 推理過程

    • 打開輸出 CSV 文件,并創建 CSV 寫入器。
    • 使用 tqdm 顯示進度條遍歷測試數據加載器中的每個批次 source
    • source 移動到 GPU 并進行前向傳播得到預測結果 pred
    • 將預測結果轉換為 NumPy 數組,并逐個樣本寫入 CSV 文件。

3. 主函數

if __name__ == '__main__':args = parse_args()inference(args)
  • 主函數
    • 調用 parse_args 解析命令行參數。
    • 調用 inference 函數開始推理過程。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/71738.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/71738.shtml
英文地址,請注明出處:http://en.pswp.cn/web/71738.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Spring Boot API 項目中 HAProxy 與 Nginx 的選擇與實踐

在開發 Spring Boot 構建的 RESTful API 項目時,負載均衡和反向代理是提升性能與可用性的關鍵環節。HAProxy 和 Nginx 作為兩種流行的工具,經常被用于流量分發,但它們各有側重。究竟哪一個更適合你的 Spring Boot API 項目?本文將…

Java常用集合與映射的線程安全問題深度解析

Java常用集合與映射的線程安全問題深度解析 一、線程安全基礎認知 在并發編程環境下,當多個線程同時操作同一集合對象時,若未采取同步措施,可能導致以下典型問題: 數據競爭:多個線程同時修改數據導致結果不可預測狀…

DeepLabv3+改進6:在主干網絡中添加SegNext_Attention|助力漲點

??【DeepLabv3+改進專欄!探索語義分割新高度】 ?? 你是否在為圖像分割的精度與效率發愁? ?? 本專欄重磅推出: ? 獨家改進策略:融合注意力機制、輕量化設計與多尺度優化 ? 即插即用模塊:ASPP+升級、解碼器 PS:訂閱專欄提供完整代碼 目錄 論文簡介 步驟一 步驟二…

使用 Elastic-Agent 或 Beats 將 Journald 中的 syslog 和 auth 日志導入 Elastic Stack

作者:來自 Elastic TiagoQueiroz 我們在 Elastic 一直努力將更多 Linux 發行版添加到我們的支持矩陣中,現在 Elastic-Agent 和 Beats 已正式支持 Debian 12! 本文演示了我們正在開發的功能,以支持使用 Journald 存儲系統和身份驗…

3.9[A]csd

在傳統CPU中心架構中,中央處理器通過內存訪問外部存儲器,而數據必須經過網絡接口卡才能到達外部存儲器。這種架構存在集中式計算、DRAM帶寬和容量挑戰、大量數據移動(服務器內和網絡)以及固定計算導致工作負載容量增長等問題。 而…

ESP32S3讀取數字麥克風INMP441的音頻數據

ESP32S3 與 INMP441 麥克風模塊的集成通常涉及使用 I2S 接口進行數字音頻數據的傳輸。INMP441 是一款高性能的數字麥克風,它通過 I2S 接口輸出音頻數據。在 Arduino 環境中,ESP32S3 的開發通常使用 ESP-IDF(Espressif IoT Development Framew…

DeepSeek大模型 —— 全維度技術解析

DeepSeek大模型 —— 全維度技術解析 前些天發現了一個巨牛的人工智能學習網站,通俗易懂,風趣幽默,可以分享一下給大家。點擊跳轉到網站。 https://www.captainbed.cn/ccc 文章目錄 DeepSeek大模型 —— 全維度技術解析一、模型架構全景解析1…

[Kubernetes] 7控制平面組件

1. 調度 kube- scheduler what 負責分配調度pod到集群節點監聽kube-apiserver,查詢未分配node的pod根據調度策略分配這些pod(更新pod的nodename)需要考慮的因素: 公平調度,資源有效利用,QoS,affinity, an…

PyTorch系列教程:編寫高效模型訓練流程

當使用PyTorch開發機器學習模型時,建立一個有效的訓練循環是至關重要的。這個過程包括組織和執行對數據、參數和計算資源的操作序列。讓我們深入了解關鍵組件,并演示如何構建一個精細的訓練循環流程,有效地處理數據處理,向前和向后…

LeetCode Hot100刷題——反轉鏈表(迭代+遞歸)

206.反轉鏈表 給你單鏈表的頭節點 head ,請你反轉鏈表,并返回反轉后的鏈表。 示例 1: 輸入:head [1,2,3,4,5] 輸出:[5,4,3,2,1]示例 2: 輸入:head [1,2] 輸出:[2,1]示例 3&#…

機器學習的發展史

機器學習(Machine Learning, ML)作為人工智能(AI)的一個分支,其發展經歷了多個階段。以下是機器學習的發展史概述: 1. 早期探索(20世紀50年代 - 70年代) 1950年:艾倫圖…

Springboot redis bitMap實現用戶簽到以及統計,保姆級教程

項目架構,這是作為demo展示使用: Redis config: package com.zy.config;import com.fasterxml.jackson.annotation.JsonAutoDetect; import com.fasterxml.jackson.annotation.PropertyAccessor; import com.fasterxml.jackson.databind.Ob…

Ardupilot開源無人機之Geek SDK進展2025Q1

Ardupilot開源無人機之Geek SDK進展2025Q1 1. 源由2. 內容匯總2.1 【jetson-fpv】YOLO INT8 coco8 dataset 精度降級2.2 【OpenIPC-Configurator】OpenIPC Configurator 固件升級失敗2.3 【OpenIPC-Adaptive-link】OpenIPC RF信號質量相關顯示2.4 【OpenIPC-msposd】.srt/.osd…

《云原生監控體系構建實錄:從Prometheus到Grafana的觀測革命》

PrometheusGrafana部署配置 Prometheus安裝 下載Prometheus服務端 Download | PrometheusAn open-source monitoring system with a dimensional data model, flexible query language, efficient time series database and modern alerting approach.https://prometheus.io/…

SpringMvc與Struts2

一、Spring MVC 1.1 概述 Spring MVC 是 Spring 框架的一部分,是一個基于 MVC 設計模式的輕量級 Web 框架。它提供了靈活的配置和強大的擴展能力,適合構建復雜的 Web 應用程序。 1.2 特點 輕量級:與 Spring 框架無縫集成,依賴…

數據類設計_圖片類設計之1_矩陣類設計(前端架構基礎)

前言 學的東西多了,要想辦法用出來.C和C是偏向底層的語言,直接與數據打交道.嘗試做一些和數據方面相關的內容 引入 圖形在底層是怎么表示的,用C來表示 認識圖片 圖片是個風景,動物,還是其他內容,人是可以看出來的.那么計算機是怎么看懂的呢?在有自主意識的人工智能被設計出來…

開發者社區測試報告(功能測試+性能測試)

功能測試 測試相關用例 開發者社區功能背景 在當今數字化時代,編程已經成為一項核心技能,越來越多的人開始學習編程,以適應快速變化的科技 環境。基于這一需求,我設計開發了一個類似博客的論壇系統,專注于方便程序員…

EasyRTC嵌入式音視頻通話SDK:基于ICE與STUN/TURN的實時音視頻通信解決方案

在當今數字化時代,實時音視頻通信技術已成為人們生活和工作中不可或缺的一部分。無論是家庭中的遠程看護、辦公場景中的遠程協作,還是工業領域的遠程巡檢和智能設備的互聯互通,高效、穩定的通信技術都是實現這些功能的核心。 EasyRTC嵌入式音…

【OneAPI】網頁截圖API-V2

API簡介 生成指定URL的網頁截圖或縮略圖。 舊版本請參考:網頁截圖 V2版本新增全屏截圖、帶殼截圖等功能,并修復了一些已知問題。 全屏截圖: 支持全屏截圖,通過設置fullscreentrue來支持全屏截圖。全屏模式下,系統…

簡單的 Python 示例,用于生成電影解說視頻的第一人稱獨白解說文案

以下是一個簡單的 Python 示例,用于生成電影解說視頻的第一人稱獨白解說文案。這個示例使用了 OpenAI 的 GPT 模型,因為它在自然語言生成方面表現出色。 實現思路 安裝必要的庫:使用 openai 庫與 OpenAI API 進行交互。設置 API 密鑰&#…