TensorFlow 2.x 核心 API 與模型構建

TensorFlow 2.x 核心 API 與模型構建
TensorFlow 是一個強大的開源機器學習庫,尤其在深度學習領域應用廣泛。TensorFlow 2.x 在易用性和效率方面做了大量改進,引入了Keras作為其高級API,使得模型構建和訓練更加直觀和便捷。本文將介紹 TensorFlow 2.x 的核心 API 以及如何使用它們來構建和訓練一個深度學習模型。
一、 TensorFlow 2.x 的核心理念
TensorFlow 2.x 的核心理念是:
易用性 (Ease of Use): 通過Keras作為首選的高級API,簡化了模型的開發流程。
聲明式編程 (Declarative Programming): 允許開發者定義計算圖,但通過Eager Execution(即時執行)模式,使得構建和調試更加直觀,類似于Python的命令式編程。
端到端 (End-to-End): 支持從數據準備、模型訓練到模型部署的完整流程。
跨平臺 (Cross-Platform): 可以在CPU、GPU、TPU以及服務器、桌面、移動設備等多種平臺上運行。
二、 TensorFlow 2.x 的核心 API
TensorFlow 2.x 的API龐大且功能全面,但以下幾個是構建和訓練模型最常用的核心部分:
2.1 tf.keras:高級 API
tf.keras 是TensorFlow 2.x推薦并集成的首選高級API,它封裝了模型構建、層定義、損失函數、優化器、評估指標等常用功能,提供了一套面向對象且易于使用的接口。
模型 (tf.keras.Model 和 tf.keras.Sequential):
tf.keras.Sequential: 用于構建線性的、堆疊的層模型。非常適合順序結構的網絡。
tf.keras.Model: 更靈活的API,可以構建復雜的、具有多輸入/輸出、共享層、多分支的網絡結構。通過子類化(subclassing)tf.keras.Model 來定義。
層 (tf.keras.layers.*):
提供了構建神經網絡的基本單元,如 Dense (全連接層), Conv2D (卷積層), MaxPooling2D (池化層), Flatten (展平層), Dropout (正則化層), BatchNormalization (批歸一化層) 等。
每一層都有其可訓練的權重(kernel 和 bias)。
損失函數 (tf.keras.losses.*):
定義了模型預測與真實標簽之間的差距,如 CategoricalCrossentropy, SparseCategoricalCrossentropy, MeanSquaredError。
優化器 (tf.keras.optimizers.*):
實現了各種梯度下降的變種,用于更新模型的權重,如 Adam, SGD, RMSprop。
指標 (tf.keras.metrics.*):
用于評估模型的性能,如 Accuracy, Precision, Recall, AUC。
2.2 tf.data:數據處理管道
tf.data API 提供了一種高效、靈活地構建輸入數據管道的方式,能夠處理大規模數據集,并與 tf.keras 無縫集成。
創建數據集: 可以從NumPy數組、TensorFlow張量、CSV文件、TFRecords等多種數據源創建 tf.data.Dataset 對象。
數據轉換:
map(): 對數據集中的每個元素應用一個函數(如數據增強、特征工程)。
shuffle(): 隨機打亂數據集,通常在訓練開始前使用。
batch(): 將數據集中的元素分組打包成批。
prefetch(): 在模型訓練時,預先加載下一個批次的數據,避免CPU/GPU等待。
cache(): 將數據集內容緩存到內存或本地文件中,加快重復訪問的速度。
2.3 tf.Tensor:張量(Tensors)
張量是 TensorFlow 的核心數據結構,類似于 NumPy 的數組。它們是多維數組,可以存儲標量、向量、矩陣,乃至更高維度的數據。
創建張量:
tf.constant(): 創建一個不可更改的張量。
tf.Variable(): 創建一個可更改的張量,通常用于存儲模型的可訓練權重。
張量操作: TensorFlow 提供了豐富的張量運算函數,如 tf.add, tf.multiply, tf.matmul, tf.reduce_sum, tf.reshape 等。
Eager Execution: 在 TensorFlow 2.x 中,張量操作會立即執行并返回結果,這使得調試和交互式開發非常方便。
2.4 自動微分 (tf.GradientTape)
自動微分是深度學習模型訓練的關鍵。TensorFlow 2.x 使用 tf.GradientTape API 來記錄計算過程,并計算損失函數關于模型變量的梯度。
三、 使用 tf.keras 構建模型
有兩種主要方式構建 tf.keras 模型:
3.1 順序模型 (tf.keras.Sequential)
適用于線性堆疊的層,非常簡單直觀。
步驟:
創建一個 tf.keras.Sequential 實例。
通過 add() 方法將層依次添加到模型中。
最后,編譯模型(指定優化器、損失函數、評估指標)。
使用 fit() 方法訓練模型。
示例:構建一個簡單的全連接網絡進行MNIST圖像分類
<PYTHON>

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# 1. 定義模型
model = keras.Sequential([
# 輸入層:展平28x28的圖像為784維的向量
layers.Flatten(input_shape=(28, 28), name='input_layer'),
# 第一個隱藏層:全連接層,256個神經元,ReLU激活函數
layers.Dense(256, activation='relu', name='hidden_layer_1'),
# Dropout層:防止過擬合,以0.2的比例丟棄神經元
layers.Dropout(0.2),
# 輸出層:全連接層,10個神經元(對應0-9數字),softmax激活函數,輸出概率分布
layers.Dense(10, activation='softmax', name='output_layer')
])

# 2. 編譯模型
model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.001),
loss=keras.losses.SparseCategoricalCrossentropy(), # MNIST標簽是整數,使用SparseCategoricalCrossentropy
metrics=['accuracy'])

# (假設已加載并預處理好MNIST數據集: train_images, train_labels, test_images, test_labels)
# 例如:
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()
# 需要將像素值歸一化到 [0, 1]
train_images = train_images.astype('float32') / 255.0
test_images = test_images.astype('float32') / 255.0

# 3. 訓練模型
history = model.fit(train_images, train_labels,
epochs=10, # 訓練輪數
batch_size=32, # 每個批次的大小
validation_split=0.2) # 從訓練數據中劃分20%作為驗證集

# 4. 評估模型
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print(f'\nTest accuracy: {test_acc}')

# (可選)進行預測
# predictions = model.predict(test_images[:5])
# print(f'\nPredictions for first 5 test images:\n {predictions}')
3.2 函數式 API (tf.keras.Model 子類化)
適用于構建更復雜的模型,如多輸入、多輸出、共享層、非線性連接的模型。
步驟:
創建一個類,繼承自 tf.keras.Model。
在 __init__ 方法中定義模型所需的層。
在 call() 方法中實現模型的前向傳播邏輯,定義數據如何通過這些層。
實例化該類,然后編譯和訓練。
示例:構建一個更復雜的模型(例如,帶殘差連接)
<PYTHON>

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# 定義一個可以重用的殘差塊
def residual_block(x, filters, kernel_size=3):
# 存儲輸入,以便進行殘差連接
shortcut = x

# 第一個卷積層
x = layers.Conv2D(filters, kernel_size, padding='same')(x)
x = layers.BatchNormalization()(x)
x = layers.Activation('relu')(x)

# 第二個卷積層
x = layers.Conv2D(filters, kernel_size, padding='same')(x)
x = layers.BatchNormalization()(x)

# 殘差連接:如果輸入和輸出的特征維度不匹配,需要通過1x1卷積進行轉換
if shortcut.shape[-1] != filters:
shortcut = layers.Conv2D(filters, (1, 1), padding='same')(shortcut)
shortcut = layers.BatchNormalization()(shortcut)

# 激活函數
x = layers.add([x, shortcut])
x = layers.Activation('relu')(x)
return x

# 定義主模型
class ComplexModel(keras.Model):
def __init__(self, num_classes=10):
super(ComplexModel, self).__init__()

# 輸入層 - 假設輸入尺寸為 (height, width, channels)
self.conv1 = layers.Conv2D(32, 3, activation='relu', padding='same', input_shape=(32, 32, 3))
self.pool1 = layers.MaxPooling2D((2, 2))

# 第一個殘差塊
self.res1 = residual_block(32, 32) # 32通道

# 第二個殘差塊(特征通道加倍)
self.res2 = residual_block(32, 64) # 64通道
self.pool2 = layers.MaxPooling2D((2, 2))

# 展平層
self.flatten = layers.Flatten()

# 全連接層
self.dense1 = layers.Dense(128, activation='relu')

# 輸出層
self.dropout = layers.Dropout(0.5)
self.output_dense = layers.Dense(num_classes, activation='softmax')

def call(self, inputs, training=False): # training參數用于控制Dropout等層的行為
x = self.conv1(inputs)
x = self.pool1(x)

x = self.res1(x)
x = self.res2(x)
x = self.pool2(x)

x = self.flatten(x)
x = self.dense1(x)

if training: # 只在訓練時應用Dropout
x = self.dropout(x)

return self.output_dense(x)

# 實例化模型
complex_model = ComplexModel(num_classes=10)

# 編譯模型
complex_model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.001),
loss=keras.losses.SparseCategoricalCrossentropy(),
metrics=['accuracy'])

# (假設已加載并預處理好CIFAR-10數據集)
# (train_images, train_labels), (test_images, test_labels) = keras.datasets.cifar10.load_data()
# ... 數據預處理 ...

# 訓練模型
# history = complex_model.fit(train_images, train_labels, epochs=20, batch_size=64, validation_split=0.2)

# 評估模型
# test_loss, test_acc = complex_model.evaluate(test_images, test_labels, verbose=2)
# print(f'\nTest accuracy: {test_acc}')
四、 數據處理管道 tf.data
使用 tf.data 可以高效地準備訓練數據。
示例:構建MNIST數據集的 tf.data 管道
<PYTHON>

import tensorflow as tf
from tensorflow import keras

# 加載數據
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()

# 數據歸一化和重塑
train_images = train_images.astype('float32') / 255.0
test_images = test_images.astype('float32') / 255.0
# 對于 Conv2D 層,輸入數據需要一個通道維度 (batch, height, width, channels)
# MNIST 是灰度圖,所以通道是 1
train_images = train_images[..., tf.newaxis]
test_images = test_images[..., tf.newaxis]

# 定義超參數
BATCH_SIZE = 64
BUFFER_SIZE = tf.data.AUTOTUNE # AUTOTUNE 會自動選擇最佳的緩沖區大小

# 構建訓練數據集管道
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
train_dataset = train_dataset.shuffle(BUFFER_SIZE) # 打亂數據
train_dataset = train_dataset.batch(BATCH_SIZE) # 分批
train_dataset = train_dataset.prefetch(buffer_size=BUFFER_SIZE) # 預取數據

# 構建測試數據集管道 (通常不需要shuffle,但需要batch和prefetch)
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels))
test_dataset = test_dataset.batch(BATCH_SIZE)
test_dataset = test_dataset.prefetch(buffer_size=BUFFER_SIZE)

# 現在可以直接將 train_dataset 和 test_dataset 傳遞給 model.fit() 和 model.evaluate()
# 示例:
# model = keras.Sequential([...]) # 假設模型已定義
# model.compile(...)
# history = model.fit(train_dataset, epochs=10, validation_data=test_dataset) # 可以直接傳入dataset
# test_loss, test_acc = model.evaluate(test_dataset)
五、 訓練、評估與預測
model.fit(): 這是模型訓練的核心方法。
接受訓練數據(X, y)或 tf.data.Dataset。
epochs: 訓練的總輪數。
batch_size: m?i批次樣本數。
validation_data 或 validation_split: 用于驗證模型的性能。
callbacks: 可以在訓練過程中執行特定動作,如保存模型、早停(Early Stopping)。
model.evaluate(): 用于評估模型在測試集或驗證集上的性能。
接受測試數據(X, y)或 tf.data.Dataset。
返回損失值和指定的評估指標。
model.predict(): 用于在新數據上進行預測。
接受輸入數據。
對于分類任務,通常返回預測屬于每個類別的概率;對于回歸任務,返回預測值。
六、 保存與加載模型
訓練好的模型可以保存下來,以便后續使用或部署。
保存整個模型: 包括模型結構、權重、優化器狀態。
<PYTHON>

model.save('my_model.keras') # 新格式
# 或者
# model.save('my_model_h5', save_format='h5') # 舊格式
加載模型:
<PYTHON>

loaded_model = keras.models.load_model('my_model.keras')
僅保存權重:
<PYTHON>

model.save_weights('my_model_weights.weights.h5') # 會自動選擇合適的格式
加載權重:
<PYTHON>

# 需要先構建模型結構
# complex_model_for_weights = ComplexModel()
# complex_model_for_weights.load_weights('my_model_weights.weights.h5')
七、 總結
TensorFlow 2.x 通過 Keras API 極大地簡化了深度學習模型的構建和訓練過程。掌握 tf.keras.Sequential 和 tf.keras.Model 的使用,結合 tf.data 構建高效的數據管道,并理解 tf.Tensor 和 tf.GradientTape 的概念,是成為一名TensorFlow開發者的基礎。
通過以上介紹,你應該已經對 TensorFlow 2.x 的核心 API 和模型構建有了初步的認識。在實際應用中,還需要不斷探索更多的層類型、激活函數、優化器、正則化技術以及更復雜的數據處理方法,來解決各種實際的機器學習問題。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/98489.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/98489.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/98489.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

TENGJUN防水TYPE-C連接器:工業級防護,認證級可靠,賦能嚴苛場景連接

在工業控制、戶外電子、水下設備等對連接穩定性與防護性要求極致的場景中&#xff0c;TENGJUN防水TYPE-C連接器以“硬核性能全面認證”的雙重優勢&#xff0c;成為關鍵連接環節的信賴之選。從結構設計到認證標準&#xff0c;每一處細節都為應對復雜環境而生&#xff0c;重新定義…

【小呆的隨機振動力學筆記】概率論基礎

文章目錄0. 概率論基礎0.1 概率的初步認知0.2 隨機變量的分布0.3 隨機變量的數字特征0.3.1 隨機變量的期望算子0.3.2 隨機變量的矩0.4 隨機變量的特征函數0.5 高數基礎附錄A 典型分布0. 概率論基礎 \quad\quad在生活中或自然中&#xff0c;處處都存在隨機現象&#xff0c;比如每…

使用海康機器人相機SDK實現基本參數配置(C語言示例)

在機器視覺項目開發中&#xff0c;相機的初始化、參數讀取與設置是最基礎也是最關鍵的環節。本文基于海康機器人&#xff08;Hikrobot&#xff09;提供的MVS SDK&#xff0c;使用C語言實現了一個簡潔的控制程序&#xff0c;完成設備枚舉、連接以及常用參數的獲取與設置。 &…

【IoTDB】時序數據庫選型指南:為何IoTDB成為工業大數據場景的首選?

【作者主頁】Francek Chen 【專欄介紹】???大數據與數據庫應用??? 大數據是規模龐大、類型多樣且增長迅速的數據集合&#xff0c;需特殊技術處理分析以挖掘價值。數據庫作為數據管理的關鍵工具&#xff0c;具備高效存儲、精準查詢與安全維護能力。二者緊密結合&#xff0…

用計算思維“破解”復雜Excel考勤表的自動化之旅

在我們日常工作中&#xff0c;經常會遇到一些看似簡單卻極其繁瑣的任務。手動處理一份結構復雜的Excel考勤表&#xff0c;就是典型的例子。它充滿了合并單元格、不規則的布局和隱藏的格式陷阱。面對這樣的挑戰&#xff0c;我們是選擇“卷起袖子&#xff0c;日復一日地手動復制粘…

PAT 1006 Sign In and Sign Out

1006 Sign In and Sign Out分數 25作者 CHEN, Yue單位 浙江大學At the beginning of every day, the first person who signs in the computer room will unlock the door, and the last one who signs out will lock the door. Given the records of signing ins and outs, yo…

【git】首次clone的使用采用-b指定了分支,還使用了--depth=1 后續在這個基礎上拉取所有的分支代碼方法

要解決當前問題&#xff08;從淺克隆轉換為完整克隆并獲取所有分支&#xff09;&#xff0c;請按照以下步驟操作&#xff1a; 步驟 1&#xff1a;檢查當前遠程地址 首先確認遠程倉庫地址是否正確&#xff1a; git remote -v步驟 2&#xff1a;修改遠程配置以獲取所有分支 默認淺…

蘿卜切丁機 機構筆記

蘿卜切丁機_STEP_模型圖紙免費下載 – 懶石網 機械工程師設計手冊 1是傳送帶 2是曲柄滑塊機構&#xff1f; 擠壓動作

多張圖片生成視頻模型技術深度解析

多張圖片生成視頻模型測試相比純文本輸入&#xff0c;有視覺參考約束的生成通常質量更穩定&#xff0c;細節更豐富 1. 技術原理和工作機制 多張圖片生成視頻模型是一種先進的AI技術&#xff0c;能夠接收多張輸入圖像&#xff0c;理解場景變化關系&#xff0c;并合成具有時間連…

中電金信:AI重構測試體系·智能化時代的軟件工程新范式

AI技術的迅猛發展正加速推動軟件工程3.0時代的到來&#xff0c;深刻地重塑了測試行業的運作邏輯&#xff0c;推動測試角色從“后置保障”轉變為“核心驅動力”。在大模型技術的助力下&#xff0c;測試質量和效能將顯著提升。9月5日至6日&#xff0c;Gtest2025全球軟件測試技術峰…

100、23種設計模式之適配器模式(9/23)

適配器模式&#xff08;Adapter Pattern&#xff09; 是一種結構型設計模式&#xff0c;它允許將不兼容的接口轉換為客戶端期望的接口&#xff0c;使原本由于接口不兼容而不能一起工作的類可以協同工作。 一、核心思想 將一個類的接口轉換成客戶期望的另一個接口使原本因接口不…

線上環境CPU使用率飆升,如何排查

線上環境CPU使用率飆升&#xff0c;如何排查 1.CPU飆升的常見原因 1. 代碼層面問題 死循環&#xff1a;錯誤的循環條件導致無限循環遞歸過深&#xff1a;沒有正確的終止條件算法效率低&#xff1a;O(n)或更高時間復雜度的算法處理大數據集頻繁GC&#xff1a;內存泄漏導致頻繁垃…

《sklearn機器學習——特征提取》

在 sklearn.feature_extraction 模塊中&#xff0c;DictVectorizer 是從字典&#xff08;dict&#xff09;中加載和提取特征的核心工具。它主要用于將包含特征名稱和值的 Python 字典列表轉換為機器學習算法所需的數值型數組或稀疏矩陣。 這種方法在處理結構化數據&#xff08;…

IEEE出版,限時早鳥優惠!|2025年智能制造、機器人與自動化國際學術會議 (IMRA 2025)

2025年智能制造、機器人與自動化國際學術會議 (IMRA2025)2025 International Conference on Intelligent Manufacturing, Robotics, and Automation中國?湛江2025年11月14日-2025年11月16日IMRA2025權威出版大咖云集穩定檢索智能制造、人工智能、機器人、物聯網&#xff08;Io…

C# 基于halcon的視覺工作流-章30-圓圓距離測量

C# 基于halcon的視覺工作流-章30-圓圓距離測量 本章目標&#xff1a; 一、利用圓卡尺找兩圓心&#xff1b; 二、distance_pp算子計算兩圓點距離&#xff1b; 三、匹配批量計算&#xff1b;本章是在章23-圓查找的基礎上進行測量使用&#xff0c;圓查找知識請閱讀章23&#xff0c…

java設計模式二、工廠

概述 工廠方法模式是一種常用的創建型設計模式&#xff0c;它通過將對象的創建過程封裝在工廠類中&#xff0c;實現了創建與使用的分離。這種模式不僅提高了代碼的復用性&#xff0c;還增強了系統的靈活性和可擴展性。本文將詳細介紹工廠方法模式的三種形式&#xff1a;簡單工廠…

Ubuntu 24.04 中 nvm 安裝 Node 權限問題解決

個人博客地址&#xff1a;Ubuntu 24.04 中 nvm 安裝 Node 權限問題解決 | 一張假鈔的真實世界 參考nvm的一個issue&#xff1a;https://github.com/nvm-sh/nvm/issues/3363 異常信息如下&#xff1a; $ nvm install 22 Downloading and installing node v22.19.0... Download…

Java面試-線程安全篇

一、synchronized關鍵字&#xff1a; 基本使用與作用&#xff1a;通過搶票代碼示例&#xff0c;展示了synchronized作為對象鎖&#xff0c;可避免多線程超賣或搶到同一張票問題&#xff0c;保證代碼原子性&#xff0c;同一時刻只有一個線程獲得鎖&#xff0c;其他線程阻塞。底層…

R 語言科研繪圖 --- 其他繪圖-匯總2

在發表科研論文的過程中&#xff0c;科研繪圖是必不可少的&#xff0c;一張好看的圖形會是文章很大的加分項。 為了便于使用&#xff0c;本系列文章介紹的所有繪圖都已收錄到了 sciRplot 項目中&#xff0c;獲取方式&#xff1a; R 語言科研繪圖模板 --- sciRplothttps://mp.…

【數學建模學習筆記】啟發式算法:粒子群算法

零基礎小白看懂粒子群優化算法&#xff08;PSO&#xff09;一、什么是粒子群優化算法&#xff1f;簡單說&#xff0c;粒子群優化算法&#xff08;PSO&#xff09;是一種模擬鳥群 / 魚群覓食的智能算法。想象一群鳥在找食物&#xff1a;每只鳥&#xff08;叫 “粒子”&#xff0…