使用八股搭建神經網絡

神經網絡搭建八股

使用tf.keras

六步法搭建模型

1.import

2.train, test 指定輸入特征/標簽

3.model = tf.keras.model.Sequential

在Squential,搭建神經網絡

4.model.compile

配置訓練方法,選擇哪種優化器、損失函數、評測指標

5.model.fit 執行訓練過程,告知訓練集輸入特征,batch,epoch

6.model.summary打印網絡結構和參數統計

model = tf.keras.model.Sequential

Sequential是個容器,封裝了網絡結構

網絡結構例子:

拉直層:tf.keras.layers.Flatten()

全連接層:tf.keras.layers.Dense(神經元個數,activetion="激活函數",kernel_regularizer=那種正則化)

卷積層:

tf.keras.layers.Conv2D(filters= 卷積核個數,kernel_size=卷積核尺寸,strides=卷積步長,padding="valid"or"same"

LSTM層:

tf.keras.layers.LSTM()

model.compile

model.compile(optimizer=優化器,loss=損失函數,metrics=["準確率"]

后期可通過tensorflow官網查詢函數的具體用法,調節超參數

有些網絡輸出經過softmax概率分布輸出,有些不經過概率分布輸出?

當網絡評測指標和蒙的概率一樣,例如十分類概率為.1/10.可能概率分布錯了

獨熱碼y_和y是[010]網絡輸出則為[0.xx, 0.xx, 0.xx]

?第三種方法 y_= [1] y =[0.2xx,0xx,0xx]

model.fit

model.fit(訓練集的輸入特征,訓練集的標簽,batch_size, epochs=,?

validation_data=(測試集的輸入特征,標簽),

validation_split=從訓練集劃分多少比例給測試集,

validation_freq=多少次epoch測試一次)

model.summary

重構Iris分類

import tensorflow as tf
from sklearn import datasets
import numpy as npx_train = datasets.load_iris().data
y_train = datasets.load_iris().targetnp.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)model = tf.keras.models.Sequential([tf.keras.layers.Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())
])model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)model.summary()

自定義搭建模型

swquential可以搭建上層輸出就是下層輸入的網絡結構,但是無法搭建帶有跳連特征的非順序網絡結構

class MyModel(Model)

? ? ? ? def __init__(self):

? ? ? ? ? ? ? ? super(MyModel, self) __init()

? ? ? ? ????????定義網絡結構塊

? ? ? ? def call(self, x): #寫出前向傳播

? ? ? ? ? ? ? ?調用網絡結構塊,實現前向傳播

? ? ? ? return y? ? ?

model = MyModel

__init__定義出積木

call調用積木,實現前向傳播

import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Model
from sklearn import datasets
import numpy as npx_train = datasets.load_iris().data
y_train = datasets.load_iris().targetnp.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)class IrisModel(Model):def __init__(self):super(IrisModel, self).__init__()self.d1 = Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())def call(self, x):y = self.d1(x)return ymodel = IrisModel()model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)
model.summary()

每循環一次train,計算一次test的測試指標

MNIST數據集

1.導入MNIST數據集

mnist=tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) =? mnist.load_data(

2.作為輸入特征,輸入神經網絡時,將數據拉伸成一維數組:

tf.keras.layers.Flatten()

把784個像素點的灰度值作為輸入特征放入神經網絡

plt.imshow(x_train[0], cmap='gray')#繪制灰度圖

plt.show()

0表示純黑色255表示純白色

需要對測試集和數據集進行歸一化處理,把數值變小,更適合神經網絡吸收,使用sequental訓練模型,由于輸入特征為數組,輸出為概率分布,所以我們選擇sparse_categorical_accuracy

import tensorflow as tfmnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')
])model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()

準確率是測試集的準確率

自定義Model實現 __init__中定義cell函數中用到的層

import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras import Modelmnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0class MnistModel(Model):def __init__(self):super(MnistModel, self).__init__()self.flatten = Flatten()self.d1 = Dense(128, activation='relu')self.d2 = Dense(10, activation='softmax')def call(self, x):x = self.flatten(x)x = self.d1(x)y = self.d2(x)return ymodel = MnistModel()model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()

FASHION數據集

import tensorflow as tffashion = tf.keras.datasets.fashion_mnist
(x_train, y_train),(x_test, y_test) = fashion.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')
])model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/42947.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/42947.shtml
英文地址,請注明出處:http://en.pswp.cn/web/42947.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

送給我親愛的Python

親愛的 Python, 在萬物皆代碼的世界里,你是我最優雅、最高效的算法。自從第一次遇見你,在那行“Hello, World!”之后,我的世界就被點亮了。你的簡潔性和強大的功能,讓我深深著迷,就像一個精心設計的函數&am…

數據結構雙向循環鏈表

主程序 #include "fun.h" int main(int argc, const char *argv[]) { double_p Hcreate_head(); insert_head(H,10); insert_head(H,20); insert_head(H,30); insert_head(H,40); insert_tail(H,50); show_link(H); del_tail(H); …

Python 傳遞參數和返回值

Python是一種功能強大的編程語言,它以其簡潔和易用性而廣受歡迎。在Python編程中,參數傳遞和返回值是函數調用中兩個非常重要的概念。理解這些概念對于編寫高效且可維護的代碼至關重要。 一、參數傳遞 在Python中,函數參數可以通過以下幾種…

Linux 網絡時間同步:NTP 與 Chrony 的終極對決

Linux 網絡時間同步:NTP 與 Chrony 的終極對決 在網絡世界中,時間同步是一項至關重要的任務。無論是確保分布式系統的一致性,還是維護安全協議的完整性,準確的時間同步都是必不可少的。網絡時間協議(NTP)和…

Golang期末作業之電子商城(源碼)

作品介紹 1.網頁作品簡介方面 :主要有:首頁 商品詳情 購物車 訂單 評價 支付 總共 5個頁面 2.作品使用的技術:這個作品基于Golang語言,并且結合一些前端的知識,例如:HTML、CSS、JS、AJAX等等知識點,同時連接數據庫的&…

統信UOS軟件包標識化工具deepin-sbom-tools使用

原文鏈接:統信UOS上使用軟件包標識化工具deepin-sbom-tools Hello,大家好啊!今天給大家帶來一篇關于在統信UOS上使用軟件包標識化工具deepin-sbom-tools的文章。deepin-sbom-tools是一個強大的工具,可以幫助開發者和系統管理員更好…

Linux初始化新的git倉庫

1.在git服務器上找到項目常部署的git地址可以根據其他項目的git地址確認 例如ssh://git192.168.10.100/opt/git/repository.git 用戶名:git(前面的是用戶) 服務器地址:192.168.10.100 git倉庫路徑:/opt/git/ 2.在服務器…

數據結構之折半查找

折半查找的算法思想: 折半查找又稱二分查找,它僅僅適用于有序的順表。 折半查找的基本思想:首先將給定值key與表中中間位置的元素(mid的指向元素)比較。midlowhigh/2(向下取整) 若key與中間元…

C#—Json序列化和反序列化

C#—Json序列化和反序列化 在C#中,可以使用System.Web.Script.Serialization.JavaScriptSerializer類來序列化和反序列化JSON數據。 可以使用Newtonsoft.Json庫進行JSON的序列化。 可以使用.NET內置的System.Text.Json庫來進行JSON的序列化。 json文件格式 [ { …

搜索引擎優化培訓機構怎么選?這篇文章告訴你答案

搜索引擎優化(SEO)已成為網絡生存必備技能。然而面對眾多培訓機構,如何選擇優秀者?本文將為您揭曉此事,助您找到騰飛之地。 一、培訓機構的多樣性:琳瑯滿目的選擇 當前SEO培訓市場繁蕪復雜,既…

C++ 八股(1)

C語言中strcpy為什么不安全?如何解決? 主要原因是缺乏對輸入長度的邊界檢查,容易導致緩沖區溢出漏洞。 解決:可以使用strncpy函數替代,或者在程序最頂端加入代碼段 #define _CRT_SECURE_NO_WARNINGS 緩沖區溢出 …

javascript高級部分筆記

javascript高級部分 Function方法 與 函數式編程 call 語法:call([thisObj[,arg1[, arg2[, [,.argN]]]]]) 定義:調用一個對象的一個方法,以另一個對象替換當前對象。 說明:call 方法可以用來代替另一個對象調用一個方法。cal…

MySQL運維實戰之ProxySQL(9.5)proxysql和MySQL Group Replication配合使用

作者:俊達 如果后端MySQL使用了Group Replication,可通過配置mysql_group_replication_hostgroups表來實現高可用 1 mysql_group_replication_hostgroups 字段描述writer_hostgroup寫hostgroup。read_only和super_read_only OFF的節點。backup_writer…

Vue3 pdf.js將二進制文件流轉成pdf預覽

好久沒寫東西,19年之前寫過一篇Vue2將pdf二進制文件流轉換成pdf文件,如果Vue2換成Vue3了,順帶來一篇文章,pdf.js這個東西用來解決內網pdf預覽,是個不錯的選擇。 首先去pdfjs官網,下載需要的文件 然后將下載…

第4章 IT服務規劃設計

第4章 IT服務規劃設計 4.1 概述 規劃設計處于整個IT服務生命周期中的前端,可以幫助IT服務供方了解客戶的需求,并對其進行全面的需求分析,然后通過對服務要素(包括人員、資源、技術和過程)、服務模式和服務方案的具體…

OpenHarmony4.x 系統模擬器環境

先下載源碼和編譯程序: 首先查看 OpenHarmony4.1源碼下載、編譯,生成OHOS_Image可執行文件的最簡易流程 準備在QEMU模擬器中運行ARM Cortex-M4的輕型開源鴻蒙系統 官方支持的開發板和模擬器種類-編譯形態整體說明OpenAtom OpenHarmony 已支持的示例工…

ArduPilot開源代碼之AP_MSP

ArduPilot開源代碼之AP_MSP 1. 源由2. Library設計2.1 啟動代碼2.2 支持特性2.3 MSP DisplayPort v.s. DJI FPV OSD 3. 重要例程3.1 AP_MSP::init3.2 AP_MSP::loop3.3 AP_MSP::init_backend 4. 實例理解5. 總結6. 參考資料 1. 源由 AP_MSP是處理MSP協議格式的報文數據應用類。…

反向業務判斷邏輯

業務功能需求: 根據id扣減用戶余額 包括:判斷用戶狀態是否正常判斷用戶余額是否充足 正向邏輯: 判斷用戶為正常下,判斷用戶余額充足,進行余額扣減; 》正向邏輯,多重嵌套,代碼不美觀…

??一文帶你入門【NestJS】

??引言 在現代Web開發領域,框架和技術的迭代速度令人咋舌。其中,NestJS作為一款基于Node.js的后端框架,以其卓越的設計理念和強大的功能集,迅速吸引了眾多開發者的眼球。本文將帶你深入了解NestJS的起源、發展,以及…

SpringIOC原理

SpringIOC原理 1.概念 Spring通過一個配置文件描述Bean及Bean之間的依賴關系,利用Java語言的反射功能實例化Bean并建立Bean之間的依賴關系。Spring的IOC容器在完成這些底層工作的基礎上,還提供了Bean實例緩存、生命周期管理、Bean實例代理、事件發布、…