DAY12 Tensorflow 六步法搭建神經網絡

六步法:

一.import? ?

導入各種庫,比如:

import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras import Model
import numpy as np
import pandas as pd
# 可能還會根據需求導入其他庫,如用于數據可視化的 matplotlib 等
import matplotlib.pyplot as plt

二.train,test

準備訓練數據和測試數據,比如:

# 以 MNIST 數據集為例
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()# 數據預處理
x_train, x_test = x_train / 255.0, x_test / 255.0

    ? 首先,從相應的數據集中加載數據,如這里使用 mnist.load_data() 加載 MNIST 手寫數字數據集,得到訓練集的特征 x_train 和標簽 y_train,以及測試集的特征 x_test 和標簽 y_test。然后,對數據進行預處理,常見的預處理操作包括歸一化、標準化等。在上述代碼中,將圖像像素值除以 255,將其縮放到 0 到 1 的范圍,這有助于模型的訓練和收斂。

    三.model=tf.keras.models.Sequential

    構建模型架構,比如:

    model = tf.keras.models.Sequential([Flatten(input_shape=(28, 28)),Dense(128, activation='relu'),Dense(10, activation='softmax')
    ])

    四.model.compile

    配置模型訓練過程,比如:

    model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])

    這一步用于配置模型的訓練過程,主要設置三個重要參數:

    optimizer:優化器,用于調整模型的參數以最小化損失函數。adam 是一種常用的優化器,它結合了 AdaGrad 和 RMSProp 的優點,具有自適應學習率的能力。


    loss:損失函數,用于衡量模型預測結果與真實標簽之間的差異。其中sparse_categorical_crossentropy 適用于標簽為整數編碼的多分類問題。


    metrics:評估指標,用于在訓練和測試過程中監控模型的性能。accuracy 表示準確率,即模型預測正確的樣本數占總樣本數的比例。

    五,model.fit

    進行模型訓練,使用訓練模型進行迭代訓練。

    model.fit(x_train, y_train, epochs=5)

    六,model.summary

    這一步用于打印模型的結構信息,包括每一層的名稱、輸出形狀和參數數量等。通過查看?model.summary()?的輸出,你可以了解模型的整體架構和參數規模,幫助你檢查模型是否符合預期,以及評估模型的復雜度。

    model.summary()

    各自的使用方法:

    Flatten只是把數值特征拉成一維數組

    Dense全連接

    后面是卷積神經網絡層和循環神經網絡層

    compile配置訓練方法。

    validation_data和validation_split二選一,進行訓練。

    validation_freq 多少輪訓練后用測試集測試一次。

    model.summary()打印出統計結果,其中可以看到,總共的參數15個,可訓練參數15個,不可訓練參數0個。

    以下是用六步法搭建鳶尾花分類。

    import tensorflow as tf
    from sklearn import datasets
    import numpy as npx_train = datasets.load_iris().data
    y_train = datasets.load_iris().targetnp.random.seed(116)
    np.random.shuffle(x_train)
    np.random.seed(116)
    np.random.shuffle(y_train)
    tf.random.set_seed(116)model = tf.keras.models.Sequential([tf.keras.layers.Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())
    ])model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)model.summary()
    

    首先import,然后兩個train交代訓練集和測試集。

    np到tf這幾個random用于打亂數據集設置。

    model中設置網絡結構。3指的是神經元個數是3,activation選用激活函數,最后是選用正則化方法。

    complie中配置訓練方法,SGD優化器,學習率0.1,選用SparseCategoricalCrossentropy當做損失函數,由于神經末端使用softmax函數,輸出不是原始分布,所以logits=False。

    鳶尾花數據集給的是0,1,2是數值,神經網絡前向輸出是概率分布,選擇sparse_categorical_accuracy作為測評指標。

    fit中執行訓練過程,分別是 輸入特征,訓練集標簽,訓練時一次喂給神經網絡多少組數據batch_size,循環迭代次數,validation_split=0.2告知從訓練集中選擇百分之20數據當做測試集,validation_freq=20,表示迭代20次,在測試集中驗證一次準確率。

    運行結果:

    可見,打印出了網絡結構和參數統計。

    本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
    如若轉載,請注明出處:http://www.pswp.cn/bicheng/71482.shtml
    繁體地址,請注明出處:http://hk.pswp.cn/bicheng/71482.shtml
    英文地址,請注明出處:http://en.pswp.cn/bicheng/71482.shtml

    如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

    相關文章

    Zookeeper分布式鎖實現

    zookeeper最初設計的初衷就是為了保證分布式系統的一致性。本文將講解如何利用zookeeper的臨時順序結點,實現分布式鎖。 目錄 1. 理論分析 1.1 結點類型 1.2 監聽器 1.3 實現原理 2. 手寫實現簡易zookeeper分布式鎖 1.1 依賴 1.2 常量定義 1.3 實現zookeeper分布式…

    Git是什么

    簡單介紹: Git是一個分布式版本控制系統,用于跟蹤文件的更改,特別是在多人協作開發的環境中。 Key: 分布式 版本控制 系統 最常用于軟件開發,但也可以用于管理任何類型的文件和文件夾。 Git幫助團隊跟蹤和管理文件的歷史版本&a…

    Pycharm 2024在解釋器提供的python控制臺中運行py文件

    2024版的界面發生了變化, run with python console搬到了這里:

    【分布式理論12】事務協調者高可用:分布式選舉算法

    文章目錄 一、分布式系統中事務協調的問題二、分布式選舉算法1. Bully算法2. Raft算法3. ZAB算法 三、小結與比較 一、分布式系統中事務協調的問題 在分布式系統中,常常有多個節點(應用)共同處理不同的事務和資源。前文 【分布式理論9】分布式…

    免費deepseek的API獲取教程及將API接入word或WPS中

    免費deepseek的API獲取教程: 1 https://cloud.siliconflow.cn/中注冊時填寫邀請碼:GAejkK6X即可獲取2000 萬 Tokens; 2 按照圖中步驟進行操作 將API接入word或WPS中 1 打開一個word,文件-選項-自定義功能區-勾選開發工具-左側的信任中心-信任中心設置…

    【SFRA】筆記

    GK_SFRA_INJECT(x) SFRA小信號注入函數,向控制環路注入一個小信號。如下圖所示,當前程序,小信號注入是在固定占空比的基礎疊加小信號,得到新的占空比,使用該占空比控制環路。 1.2 GK_SFRA_COLLECT(x, y) SFRA數據收集函數,將小信號注入環路后,該函數收集環路的數據,以…

    論文筆記-WSDM2024-LLMRec

    論文筆記-WSDM2024-LLMRec: Large Language Models with Graph Augmentation for Recommendation LLMRec: 基于圖增強的大模型推薦摘要1.引言2.前言2.1使用圖嵌入推薦2.2使用輔助信息推薦2.3使用數據增強推薦 3.方法3.1LLM作為隱式反饋增強器3.2基于LLM的輔助信息增強3.2.1用戶…

    Ubuntu 系統 cuda12.2 安裝 MMDetection3D

    DataBall 助力快速掌握數據集的信息和使用方式,會員享有 百種數據集,持續增加中。 需要更多數據資源和技術解決方案,知識星球: “DataBall - X 數據球(free)” 貴在堅持! ---------------------------------------…

    Tomcat的升級

    Tomcat 是一個開源的 Java Servlet 容器,用于部署 Java Servlet 和 JavaServer Pages(JSP)。隨著新版本的發布,Tomcat 通常會帶來性能改進、安全增強、新特性和對最新 Java 版本的更好支持。升級 Tomcat 服務器通常涉及到以下幾個…

    Python常見面試題的詳解10

    1. 哪些操作會導致 Python 內存溢出,怎么處理? 要點 1. 創建超大列表或字典:當我們一次性創建規模極為龐大的列表或字典時,會瞬間占用大量的內存資源。例如,以下代碼試圖創建一個包含 10 億個元素的列表,在…

    多個用戶如何共用一根網線傳輸數據

    前置知識 一、電信號 網線(如以太網線)中傳輸的信號主要是 電信號,它攜帶著數字信息。這些信號用于在計算機和其他網絡設備之間傳輸數據。下面是一些關于網線傳輸信號的詳細信息: 1. 電信號傳輸 在以太網中,數據是…

    華為昇騰 910B 部署 DeepSeek-R1 蒸餾系列模型詳細指南

    本文記錄 在 華為昇騰 910B(65GB) * 8 上 部署 DeepSeekR1 蒸餾系列模型(14B、32B)全過程與測試結果。 NPU:910B3 (65GB) * 8 (910B 有三個版本 910B1、2、3) 模型:DeepSeek-R1-Distill-Qwen-14B、DeepSeek…

    【前端】Vue組件庫之Element: 一個現代化的 UI 組件庫

    文章目錄 前言一、官網1、官網主頁2、設計原則3、導航4、組件 二、核心功能:開箱即用的組件生態1、豐富的組件體系2、特色功能亮點 三、快速上手:三步開啟組件化開發1、安裝(使用Vue 3)2、全局引入3、按需導入(推薦&am…

    關于uniApp的面試題及其答案解析

    我的血液里流淌著戰意!力量與智慧指引著我! 文章目錄 1. 什么是uniApp?2. uniApp與原生小程序開發有什么區別?3. 如何使用uniApp實現條件編譯?4. uniApp支持哪些平臺,各有什么特點?5. 在uniApp中…

    Ubuntu 下 nginx-1.24.0 源碼分析 - ngx_pool_t 類型

    ngx_pool_t 定義在 src/core/ngx_core.h typedef struct ngx_pool_s ngx_pool_t; ngx_pool_s 定義在 src/core/ngx_palloc.h struct ngx_pool_s {ngx_pool_data_t d;size_t max;ngx_pool_t *current;ngx_chain_t *chain;ng…

    力扣 最長遞增子序列

    動態規劃,二分查找。 題目 由題,從數組中找一個最長子序列,不難想到,當這個子序列遞增子序列的數越接近時是越容易拉長的。從dp上看,當遍歷到這個數,會從前面的dp選一個最大的數加上當前數,注意…

    Linux | 進程控制(進程終止與進程等待)

    文章目錄 Linux | 進程控制 — 進程終止 & 進程等待1、進程終止進程常見退出方法1.1退出碼基本概念獲取退出碼的方式常見退出碼約定使用場景 1.2 strerror函數 & errno宏1.3 _exit函數1.4_exit和exit的區別1.4.1 所屬頭文件與函數原型1.4.2 執行過程差異**結合現象分析…

    Android - Handler使用post之后,Runnable沒有執行

    問題:子線程創建的Handler。如果 post 之后,在Handler.removeCallbacks(run)移除了,下次再使用Handler.postDelayed(Runnable)接口或者使用post時,Runnable是沒有執行。導致沒有收到消息。 解決辦法:只有主線程創建的…

    魚皮面試鴨30天后端面試營

    day1 1. MySQL的索引類型有哪些? MySQL里的索引就像是書的目錄,能幫數據庫快速找到你要的數據。以下是各種索引類型的通俗解釋: 按數據結構分 B樹索引:最常用的一種,數據像在一棵樹上分層存放,能快速定位范圍數據…

    【核心算法篇十二】《深入解剖DeepSeek多任務學習:共享表示層的24個設計細節與實戰密碼 》

    引言:為什么你的模型總在"精神分裂"? 想象你訓練了一個AI實習生: 早上做文本分類時準確率90%下午做實體識別卻把"蘋果"都識別成水果公司晚上做情感分析突然開始輸出亂碼這就是典型的任務沖突災難——模型像被不同任務"五馬分尸"。DeepSeek通…