深度學習Note.5(機器學習.6)

1.Runner類

一個任務應用機器學習方法流程:

數據集構建

模型構建

損失函數定義

優化器

模型訓練

模型評價

模型預測

所以根據以上,我們把機器學習模型基本要素封裝成一個Runner類(加上模型保存、模型加載等功能。)

Runner類的成員函數定義如下:

  • __init__函數:實例化Runner類時默認調用,需要傳入模型、損失函數、優化器和評價指標等;
  • train函數:完成模型訓練,指定模型訓練需要的訓練集和驗證集;
  • evaluate函數:通過對訓練好的模型進行評價,在驗證集或測試集上查看模型訓練效果;
  • predict函數:選取一條數據對訓練好的模型進行預測;
  • save_model函數:模型在訓練過程和訓練結束后需要進行保存;
  • load_model函數:調用加載之前保存的模型。
class Runner(object):def __init__(self, model, optimizer, loss_fn, metric):self.model = model         # 模型self.optimizer = optimizer # 優化器self.loss_fn = loss_fn     # 損失函數   self.metric = metric       # 評估指標# 模型訓練def train(self, train_dataset, dev_dataset=None, **kwargs):pass# 模型評價def evaluate(self, data_set, **kwargs):pass# 模型預測def predict(self, x, **kwargs):pass# 模型保存def save_model(self, save_path):pass# 模型加載def load_model(self, model_path):pass

1.2Runner類流程

①初始化:傳入模型、損失函數、優化器和評價指標

②訓練:基于訓練集調用train()函數訓練模型,基于驗證集通過evaluate()函數驗證模型。通過save_model()函數保存模型

③評價:基于測試集通過evaluate()函數得到指標性能。

④預測:給定樣本,通過predict()函數得到該樣本標簽

2.案例:波士頓房價預測

波士頓房價預測基于線性回歸模型和Runner類實現

2.1數據處理

2.1.1構建

開源庫pandas導入。

import pandas as pd # 開源數據分析和操作工具# 利用pandas加載波士頓房價的數據集
data=pd.read_csv("/home/aistudio/work/boston_house_prices.csv")
# 預覽前5行數據
data.head()

2.1.2數據集劃分

訓練集 和 測試集。

import paddlepaddle.seed(10)# 劃分訓練集和測試集
def train_test_split(X, y, train_percent=0.8):n = len(X)shuffled_indices = paddle.randperm(n) # 返回一個數值在0到n-1、隨機排列的1-D Tensortrain_set_size = int(n*train_percent)train_indices = shuffled_indices[:train_set_size]test_indices = shuffled_indices[train_set_size:]X = X.valuesy = y.valuesX_train=X[train_indices]y_train = y[train_indices]X_test = X[test_indices]y_test = y[test_indices]return X_train, X_test, y_train, y_test X = data.drop(['MEDV'], axis=1)
y = data['MEDV']X_train, X_test, y_train, y_test = train_test_split(X,y)# X_train每一行是個樣本,shape[N,D]

2.1.3特征化工程

避免數據之間的可比性:對特征數據進行歸一化處理,將數據縮放到[0, 1]區間

import paddleX_train = paddle.to_tensor(X_train,dtype='float32')
X_test = paddle.to_tensor(X_test,dtype='float32')
y_train = paddle.to_tensor(y_train,dtype='float32')
y_test = paddle.to_tensor(y_test,dtype='float32')X_min = paddle.min(X_train,axis=0)
X_max = paddle.max(X_train,axis=0)X_train = (X_train-X_min)/(X_max-X_min)X_test  = (X_test-X_min)/(X_max-X_min)# 訓練集構造
train_dataset=(X_train,y_train)
# 測試集構造
test_dataset=(X_test,y_test)

2.2模型構建

rom nndl.op import Linear# 模型實例化
input_size = 12
model=Linear(input_size)

2.3完善Runner類

測試集上使用MSE對模型性能進行評估。本案例利用飛槳框架提供的MSELoss API實現

import paddle
import os
from nndl.opitimizer import optimizer_lsmclass Runner(object):def __init__(self, model, optimizer, loss_fn, metric):# 優化器和損失函數為None,不再關注# 模型self.model=model# 評估指標self.metric = metric# 優化器self.optimizer = optimizerdef train(self,dataset,reg_lambda,model_dir):X,y = datasetself.optimizer(self.model,X,y,reg_lambda)# 保存模型self.save_model(model_dir)def evaluate(self, dataset, **kwargs):X,y = datasety_pred = self.model(X)result = self.metric(y_pred, y)return resultdef predict(self, X, **kwargs):return self.model(X)def save_model(self, model_dir):if not os.path.exists(model_dir):os.makedirs(model_dir)params_saved_path = os.path.join(model_dir,'params.pdtensor')paddle.save(model.params,params_saved_path)def load_model(self, model_dir):params_saved_path = os.path.join(model_dir,'params.pdtensor')self.model.params=paddle.load(params_saved_path)optimizer = optimizer_lsm# 實例化Runner
runner = Runner(model, optimizer=optimizer,loss_fn=None, metric=mse_loss)

2.4模型訓練

組裝完成Runner之后,我們將開始進行模型訓練、評估和測試

# 模型保存文件夾
saved_dir = '/home/aistudio/work/models'# 啟動訓練
runner.train(train_dataset,reg_lambda=0,model_dir=saved_dir)columns_list = data.columns.to_list()
weights = runner.model.params['w'].tolist()
b = runner.model.params['b'].item()for i in range(len(weights)):print(columns_list[i],"weight:",weights[i])print("b:",b)

2.5模型測試

加載訓練好的模型參數,在測試集上得到模型的MSE指標

# 加載模型權重
runner.load_model(saved_dir)mse = runner.evaluate(test_dataset)
print('MSE:', mse.item())

2.6模型預測

load_model函數加載保存好的模型,使用predict進行模型預測

runner.load_model(saved_dir)
pred = runner.predict(X_test[:1])
print("真實房價:",y_test[:1].item())
print("預測的房價:",pred.item())
真實房價: 33.099998474121094
預測的房價: 33.04654312133789

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/899839.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/899839.shtml
英文地址,請注明出處:http://en.pswp.cn/news/899839.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

linux服務器專題1------redis的安裝及簡單配置

在 linux上安裝 Redis 可以按照以下步驟進行(此處用Ubuntu 服務器進行講解): 步驟 1: 更新系統包 打開終端并運行以下命令以確保你的系統是最新的: sudo apt update sudo apt upgrade步驟 2: 安裝 Redis 使用 apt 包管理器安裝 Redis: s…

面試問題總結:qt工程師/c++工程師

C 語言相關問題答案 面試問題總結:qt工程師/c工程師 C 語言相關問題答案 目錄基礎語法與特性內存管理預處理與編譯 C 相關問題答案面向對象編程模板與泛型編程STL 標準模板庫 Qt 相關問題答案Qt 基礎與信號槽機制Qt 界面設計與布局管理Qt 多線程與并發編程 目錄 基礎…

實現實時數據推送:SpringBoot中SSE接口的兩種方法

🌟 前言 歡迎來到我的技術小宇宙!🌌 這里不僅是我記錄技術點滴的后花園,也是我分享學習心得和項目經驗的樂園。📚 無論你是技術小白還是資深大牛,這里總有一些內容能觸動你的好奇心。🔍 &#x…

LXC 導入多Linux系統

前提要求 ubuntu下安裝lxd 參考Rockylinux下安裝lxd 參考LXC 源替換參考LXC 容器端口發布參考LXC webui 管理<

ES的文檔更新機制

想獲取更多高質量的Java技術文章&#xff1f;歡迎訪問Java技術小館官網&#xff0c;持續更新優質內容&#xff0c;助力技術成長 Java技術小館官網https://www.yuque.com/jtostring ES的文檔更新機制 在現代應用中&#xff0c;數據的動態性越來越強&#xff0c;我們不僅需要快…

trae.ai 編輯器:前端開發者的智能效率革命

一、為什么我們需要更智能的編輯器&#xff1f; 作為從業5年的前端開發者&#xff0c;我使用過從Sublime到VSCode的各種編輯器。但隨著現代前端技術的復雜度爆炸式增長&#xff08;想想一個React組件可能涉及JSX、CSS-in-JS、TypeScript和GraphQL&#xff09;&#xff0c;傳統…

MySQL篇(一):慢查詢定位及索引、B樹相關知識詳解

MySQL篇&#xff08;一&#xff09;&#xff1a;慢查詢定位及索引、B樹相關知識詳解 MySQL篇&#xff08;一&#xff09;&#xff1a;慢查詢定位及索引、B樹相關知識詳解一、MySQL中慢查詢的定位&#xff08;一&#xff09;慢查詢日志的開啟&#xff08;二&#xff09;慢查詢日…

uniapp APP端在線升級(簡版)

設計思路&#xff1a; 1.版本比較&#xff1a;應用程序檢查其當前版本與遠程服務器上可用的最新版本 2. 更新狀態指示&#xff1a;如果應用程序是不是最新的版本&#xff0c;則頁面提示下載最新版本。 3.下載啟動&#xff1a;通過plus.downloader.createDownload()啟動新應用…

基于javaweb的SpringBoot教務課程管理設計與實現(源碼+文檔+部署講解)

技術范圍&#xff1a;SpringBoot、Vue、SSM、HLMT、Jsp、PHP、Nodejs、Python、爬蟲、數據可視化、小程序、安卓app、大數據、物聯網、機器學習等設計與開發。 主要內容&#xff1a;免費功能設計、開題報告、任務書、中期檢查PPT、系統功能實現、代碼編寫、論文編寫和輔導、論文…

使用大語言模型進行Python圖表可視化

Python使用matplotlib進行可視化一直有2個問題&#xff0c;一是代碼繁瑣&#xff0c;二是默認模板比較丑。因此發展出seaborn等在matplotlib上二次開發&#xff0c;以更少的代碼進行畫圖的和美化的庫&#xff0c;但是這也帶來了定制化不足的問題。在大模型時代&#xff0c;這個…

【JavaEE】MyBatis - Plus

目錄 一、快速使用二、CRUD簡單使用三、常見注解3.1 TableName3.2 TableFiled3.3 TableId 四、條件構造器4.1 QueryWrapper4.2 UpdateWrapper4.3 LambdaQueryWrapper4.4 LambdaUpdateWrapper 五、自定義SQL 一、快速使用 MyBatis Plus官方文檔&#xff1a;MyBatis Plus官方文檔…

采用前端技術開源了一個數據結構算法的可視化工具

今天要推薦的開源項目叫VisuAlgoX,是一個面向計算機科學和游戲開發的 交互式算法可視化工具&#xff0c;幫助用戶通過直觀的動畫理解各種數據結構和算法。 項目的前身 由于最近在做一些關于游戲和圖形化方面的文章&#xff0c;因此做了一部分相關算法的動態可視化來做配圖展示…

體驗智譜清言的AutoGLM進行自動化的操作(Chrome插件)

最近體驗了很多的大模型&#xff0c;大模型我是一直關注著ChatGLM&#xff0c;因為它確實在7b和8b這檔模型里&#xff0c;非常聰明&#xff01; 最近還體驗了很多大模型的應用軟件&#xff0c;比如Agently、5ire、 mcphost、 Dive、 NextChat等。但是這些一般都是圖形界面或者…

pytorch中dataloader自定義數據集

前言 在深度學習中我們需要使用自己的數據集做訓練&#xff0c;因此需要將自定義的數據和標簽加載到pytorch里面的dataloader里&#xff0c;也就是自實現一個dataloader。 數據集處理 以花卉識別項目為例&#xff0c;我們分別做出圖片的訓練集和測試集&#xff0c;訓練集的標…

Blender模型導入虛幻引擎設置

單位系統不一致 Blender默認單位是米&#xff08;Meters&#xff09;&#xff0c;而虛幻引擎默認使用**厘米&#xff08;Centimeters&#xff09;**作為單位。 當模型從Blender導出為FBX或其他格式時&#xff0c;如果沒有調整單位&#xff0c;虛幻引擎會將1米&#xff08;Blen…

Docker基礎詳解

Docker 技術詳解 一、概述 Docker官網&#xff1a;https://docs.docker.com/ 菜鳥教程&#xff1a;https://www.runoob.com/docker/docker-tutorial.html 1.1 什么是Docker&#xff1f; Docker 是一個開源的容器化平臺&#xff0c;它允許開發者將應用程序和其依賴項打包到…

FastPillars:一種易于部署的基于支柱的 3D 探測器

FastPillars&#xff1a;一種易于部署的基于支柱的 3D 探測器Report issue for preceding element Sifan Zhou 1 , Zhi Tian 2 , Xiangxiang Chu 2 , Xinyu Zhang 2 , Bo Zhang 2 , Xiaobo Lu11{}^{1}start_FLOATSUPERSCRIPT 1 end_FLOATSUPERSCRIPT11footnotemark: 1 Chengji…

NLP語言模型訓練里的特殊向量

1. CLS 向量和 DEC 向量的區別及訓練方式 (1) CLS 向量與 DEC 向量是否都是特殊 token&#xff1f; CLS 向量&#xff08;[CLS] token&#xff09;和 DEC 向量&#xff08;Decoder Input token&#xff09;都是特殊的 token&#xff0c;但它們出現在不同類型的 NLP 模型中&am…

字節跳動 UI-TARS 匯總整理報告

1. 摘要 UI-TARS 是字節跳動開發的一種原生圖形用戶界面&#xff08;GUI&#xff09;代理模型 。它將感知、行動、推理和記憶整合到一個統一的視覺語言模型&#xff08;VLM&#xff09;中 。UI-TARS 旨在跨桌面、移動和 Web 平臺實現與 GUI 的無縫交互 。實驗結果表明&#xf…

基于Python深度學習的鯊魚識別分類系統

摘要&#xff1a;鯊魚是海洋環境健康的指標&#xff0c;但受到過度捕撈和數據缺乏的挑戰。傳統的觀察方法成本高昂且難以收集數據&#xff0c;特別是對于具有較大活動范圍的物種。論文討論了如何利用基于媒體的遠程監測方法&#xff0c;結合機器學習和自動化技術&#xff0c;來…