TensorFlow代碼邏輯 vs PyTorch代碼邏輯

文章目錄

  • 一、TensorFlow
    • (一)導入必要的庫
    • (二)加載MNIST數據集
    • (三)數據預處理
    • (四)構建神經網絡模型
    • (五)編譯模型
    • (六)訓練模型
    • (七)評估模型
    • (八)將模型的輸出轉化為概率
    • (九)預測測試集的前5個樣本
  • 二、PyTorch
    • (一)導入必要的庫
    • (二)定義神經網絡模型
    • (三)數據預處理和加載
    • (四)初始化模型、損失函數和優化器
    • (五)訓練模型
    • (六)評估模型
    • (七)設置設備為GPU或CPU
    • (八)運行訓練和評估
    • (九)預測測試集的前5個樣本
  • 三、TensorFlow和PyTorch代碼邏輯上的對比
    • (一)模型定義
    • (二)數據處理
    • (三)訓練過程
    • (四)自動求導
  • 四、TensorFlow和PyTorch的應用
  • 五、動態圖計算
    • (一)TensorFlow(靜態圖計算):
    • (二)PyTorch(動態圖計算):

一、TensorFlow

使用TensorFlow構建一個簡單的神經網絡來對MNIST數據集進行分類

(一)導入必要的庫

import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np

(二)加載MNIST數據集

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

(三)數據預處理

將圖像數據歸一化到[0, 1]范圍,以提高模型的訓練效果

x_train, x_test = x_train / 255.0, x_test / 255.0

(四)構建神經網絡模型

  • Flatten層:將輸入的28x28像素圖像展平成784個特征的一維向量。
  • Dense層:全連接層,包含128個神經元,使用ReLU激活函數。
  • Dropout層:在訓練過程中隨機丟棄20%的神經元,防止過擬合。
  • 輸出層:包含10個神經元,對應10個類別(數字0-9)。
model = models.Sequential([layers.Flatten(input_shape=(28, 28)),layers.Dense(128, activation='relu'),layers.Dropout(0.2),layers.Dense(10)
])

(五)編譯模型

  • optimizer=‘adam’:使用Adam優化器。
  • loss=‘SparseCategoricalCrossentropy’:使用交叉熵損失函數。
  • metrics=[‘accuracy’]:使用準確率作為評估指標。
model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])

(六)訓練模型

model.fit(x_train, y_train, epochs=5)

(七)評估模型

model.evaluate(x_test, y_test, verbose=2)

(八)將模型的輸出轉化為概率

probability_model = tf.keras.Sequential([model,tf.keras.layers.Softmax()
])

(九)預測測試集的前5個樣本

predictions = probability_model.predict(x_test[:5])
print(predictions)

二、PyTorch

使用PyTorch來構建、訓練和評估一個用于MNIST數據集的神經網絡模型

(一)導入必要的庫

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

(二)定義神經網絡模型

class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.flatten = nn.Flatten()self.fc1 = nn.Linear(28 * 28, 128)self.dropout = nn.Dropout(0.2)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.flatten(x)x = F.relu(self.fc1(x))x = self.dropout(x)x = self.fc2(x)return x

(三)數據預處理和加載

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

(四)初始化模型、損失函數和優化器

model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

(五)訓練模型

def train(model, device, train_loader, optimizer, epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()if batch_idx % 100 == 0:print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} 'f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

(六)評估模型

def test(model, device, test_loader):model.eval()test_loss = 0correct = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()pred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()test_loss /= len(test_loader.dataset)accuracy = 100. * correct / len(test_loader.dataset)print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} 'f'({accuracy:.0f}%)\n')

(七)設置設備為GPU或CPU

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

(八)運行訓練和評估

for epoch in range(1, 6):train(model, device, train_loader, optimizer, epoch)test(model, device, test_loader)

(九)預測測試集的前5個樣本

model.eval()
with torch.no_grad():samples = next(iter(test_loader))[0][:5].to(device)output = model(samples)predictions = F.softmax(output, dim=1)print(predictions)

三、TensorFlow和PyTorch代碼邏輯上的對比

(一)模型定義

  • 在TensorFlow中,通常使用tf.keras模塊來定義模型。可以使用Sequential API或Functional API。
import tensorflow as tf# Sequential API
model = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')
])# Functional API
inputs = tf.keras.Input(shape=(784,))
x = tf.keras.layers.Dense(64, activation='relu')(inputs)
outputs = tf.keras.layers.Dense(10, activation='softmax')(x)
model = tf.keras.Model(inputs, outputs)
  • PyTorch中,定義模型時需要繼承nn.Module類并實現forward方法
import torch
import torch.nn as nnclass Model(nn.Module):def __init__(self):super(Model, self).__init__()self.dense1 = nn.Linear(784, 64)self.relu = nn.ReLU()self.dense2 = nn.Linear(64, 10)self.softmax = nn.Softmax(dim=1)def forward(self, x):x = self.relu(self.dense1(x))x = self.softmax(self.dense2(x))return xmodel = Model()

(二)數據處理

  • TensorFlow有tf.data模塊來處理數據管道
import tensorflow as tfdef preprocess(data):# 數據預處理邏輯return datadataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
dataset = dataset.map(preprocess).batch(32)
  • PyTorch使用torch.utils.data.DataLoader和Dataset類來處理數據管道
import torch
from torch.utils.data import DataLoader, Datasetclass CustomDataset(Dataset):def __init__(self, data, labels):self.data = dataself.labels = labelsdef __len__(self):return len(self.data)def __getitem__(self, idx):x = self.data[idx]y = self.labels[idx]return x, ydataset = CustomDataset(X_train, y_train)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

(三)訓練過程

  • TensorFlow的tf.keras提供了高階API來進行模型編譯和訓練
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(X_train, y_train, epochs=5, batch_size=32)
  • PyTorch中,訓練過程需要手動編寫,包括前向傳播、損失計算、反向傳播和優化步驟
import torch.optim as optimcriterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)for epoch in range(5):for data, labels in dataloader:optimizer.zero_grad()outputs = model(data)loss = criterion(outputs, labels)loss.backward()optimizer.step()

(四)自動求導

  • TensorFlow在后端自動處理梯度計算和應用
# 使用model.fit自動處理
  • PyTorch的自動求導功能非常靈活,可以使用autograd模塊
# 使用loss.backward()和optimizer.step()手動處理

四、TensorFlow和PyTorch的應用

總體來說,PyTorch提供了更多的靈活性和控制,適合需要自定義復雜模型和訓練過程的場景。而TensorFlow則更加高級和簡潔,適合快速原型和標準模型的開發。

TensorFlow:

  • 高階API:使用tf.keras簡化模型定義、訓練和評估,適合快速原型開發和生產部署。
  • 性能優化:支持圖計算,優化執行速度和資源使用,適合大規模分布式訓練。
  • 廣泛生態:擁有豐富的工具和庫,如TensorBoard用于可視化,TensorFlow Lite用于移動端部署。
  • 企業支持:由Google支持,廣泛應用于工業界,提供穩定的長期支持和更新。

PyTorch:

  • 靈活性:采用動態圖計算,代碼易于調試和修改,適合研究和實驗。
  • 簡單直觀:符合Python語言習慣,API設計簡潔明了,降低學習曲線。
  • 社區活躍:由Facebook支持,擁有活躍的開源社區,快速響應用戶需求和改進。
  • 科研應用:廣泛應用于學術界,支持多種前沿研究,如自定義損失函數和復雜模型結構。

五、動態圖計算

動態圖計算是PyTorch的一個顯著特點,它讓模型的計算圖在每次前向傳播時動態生成,而不是像TensorFlow那樣預先定義和編譯。

動態圖計算的定義與特性:

  • 動態生成:每次執行前向傳播時,計算圖都會根據當前輸入數據動態構建。
  • 即時調試:允許在代碼執行時使用標準的Python調試工具(如pdb),進行逐步調試和檢查。
  • 靈活性高:支持更復雜和動態的模型結構,如條件控制流和遞歸神經網絡,更適合研究實驗和快速原型開發。

(一)TensorFlow(靜態圖計算):

在TensorFlow中,計算圖是預先定義并編譯的。在模型定義和編譯之后,圖結構固定,隨后輸入數據進行計算。

import tensorflow as tf# 定義計算圖
x = tf.placeholder(tf.float32, shape=(None, 784))
y = tf.placeholder(tf.float32, shape=(None, 10))
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
logits = tf.matmul(x, W) + b
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))# 創建會話并執行
with tf.Session() as sess:sess.run(tf.global_variables_initializer())for i in range(1000):batch_x, batch_y = ...  # 獲取訓練數據sess.run(loss, feed_dict={x: batch_x, y: batch_y})

(二)PyTorch(動態圖計算):

在PyTorch中,計算圖在每次前向傳播時動態構建,代碼更接近標準的Python編程風格。

import torch
import torch.nn as nn
import torch.optim as optim# 定義模型
class Model(nn.Module):def __init__(self):super(Model, self).__init__()self.dense = nn.Linear(784, 10)def forward(self, x):return self.dense(x)model = Model()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 訓練過程
for epoch in range(1000):for data, target in dataloader:optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/38473.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/38473.shtml
英文地址,請注明出處:http://en.pswp.cn/web/38473.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

@RequestMapping屬性詳解及案例演示

RequestMapping源碼 Target({ElementType.TYPE, ElementType.METHOD}) Retention(RetentionPolicy.RUNTIME) Documented Mapping public interface RequestMapping {String name() default "";AliasFor("path")String[] value() default {};AliasFor(&quo…

智能寫作與痕跡消除:AI在創意文案和論文去痕中的應用

作為一名AI愛好者,我積累了許多實用的AI生成工具。今天,我想分享一些我經常使用的工具,這些工具不僅能幫助提升工作效率,還能激發創意思維。 我們都知道,隨著技術的進步,AI生成工具已經變得越來越智能&…

簡單分享 for循環,從基礎到高級

1. 基礎篇:Hello, For Loop! 想象一下,你想給班上的每位同學發送“Hello!”,怎么辦?那就是for循環啦, eg:首先有個名字的列表,for循環取出,分別打印 names ["Alice", …

Apache APISIX 介紹

Apache APISIX 是一個動態、實時、高性能的云原生API網關,屬于Apache軟件基金會旗下的項目。以下是對Apache APISIX的詳細介紹: 一、基本概述 定義:Apache APISIX是一個提供豐富流量管理功能的云原生API網關。功能:包括負載均衡…

git出現Permission denied問題

Warning: Permanently added ‘icode.baidu.com,10.11.81.103’ (RSA) to the list of known hosts. Permission denied (baas,keyboard-interactive,publickey). fatal: Could not read from remote repository. Please make sure you have the correct access rights and the…

nodejs操作excel文件實例,讀取sheets, 設置cell顏色

本代碼是我幫客戶做的兼職的實例,涉及用node讀取excel文件,遍歷sheets,給單元格設置顏色等操作,希望對大家接活有所幫助。 gen.js let dir"D:\\武漢煙廠\\山東區域\\備檔資料\\銷區零售終端APP維護清單\\走訪檔案\\2024年6月…

Spring之事務失效的場景

Spring事務失效的場景 異常捕獲處理:自己處理了異常,沒有拋出。解決:手動拋出拋出檢查異常:配置rollbackFor屬性為Excetion非public方法導致事務失效,改為public 1、異常捕獲處理 示例: 張三1000元&#…

7月形勢分析-您下一步該如何做,才能走出困境?

馬上工程項目,再有三五天就要結束的了。即便推后也不會超過一周時間了。所以需要考慮將來干啥呢?  一方面就是繼續去濟寧做建筑工程的活。管吃住,但是因為至親之間,難免咋說呢,總之還是不太舒服的樣子。管事情多&…

bigNumber的部分使用方法與屬性

場景:最近做IoT項目的時候碰到一個問題,涉及到雙精度浮點型的數據范圍的校驗問題。業務上其實有三種類型:int、float和double類型三種。他們的范圍分別是: //int int: [-2147483648, 2147483647],//float float: [-3402823466385…

PHP7源碼結構

PHP7程序的執行過程 1.PHP代碼經過詞法分析轉換為有意義的Token; 2.Token經過語法分析生成AST(Abstract Synstract Syntax Tree,抽象語法樹); 3.AST生成對應的opcode,被虛擬機執行。 源碼結構&#xff1…

一切為了安全丨2024中國應急(消防)品牌巡展武漢站成功召開!

消防品牌巡展武漢站 6月28日,由中國安全產業協會指導,中國安全產業協會應急創新分會、應急救援產業網聯合主辦,湖北消防協會協辦的“一切為了安全”2024年中國應急(消防)品牌巡展-武漢站成功舉辦。該巡展旨在展示中國應急(消防&am…

qt QTreeView的簡單使用(多級子節點)

MainWindow::MainWindow(QWidget *parent): QMainWindow(parent), ui(new Ui::MainWindow) {ui->setupUi(this);setWindowTitle("QTreeView的簡單使用");model new QStandardItemModel;model->setHorizontalHeaderLabels(QStringList() << "left&q…

【數據結構 - 時間復雜度和空間復雜度】

文章目錄 <center>時間復雜度和空間復雜度算法的復雜度時間復雜度大O的漸進表示法常見時間復雜度計算舉例 空間復雜度實例 時間復雜度和空間復雜度 算法的復雜度 算法在編寫成可執行程序后&#xff0c;運行時需要耗費時間資源和空間(內存)資源 。因此衡量一個算法的好壞&…

[leetcode]longest-arithmetic-subsequence-of-given-difference. 最長定差子序列

. - 力扣&#xff08;LeetCode&#xff09; class Solution { public:int longestSubsequence(vector<int> &arr, int difference) {int ans 0;unordered_map<int, int> dp;for (int v: arr) {dp[v] dp[v - difference] 1;ans max(ans, dp[v]);}return ans…

Qt源碼分析:窗體繪制與響應

作為一套開源跨平臺的UI代碼庫&#xff0c;窗體繪制與響應自然是最為基本的功能。在前面的博文中&#xff0c;已就Qt中的元對象系統(反射機制)、事件循環等基礎內容進行了分析&#xff0c;并捎帶闡述了窗體響應相關的內容。因此&#xff0c;本文著重分析Qt中窗體繪制相關的內容…

ECharts 快速入門

文章目錄 1. 引入 ECharts2. 初始化 ECharts 實例3. 配置圖表選項4. 使用配置項生成圖表5. 最常用的幾種圖形5.1 柱狀圖&#xff08;Bar Chart&#xff09;5.2 折線圖&#xff08;Line Chart&#xff09;5.3 餅圖&#xff08;Pie Chart&#xff09;5.4 散點圖&#xff08;Scatt…

如何完成域名解析驗證

一&#xff1a;什么是DNS解析&#xff1a; DNS解析是互聯網上將人類可讀的域名&#xff08;如www.example.com&#xff09;轉換為計算機可識別的IP地址&#xff08;如192.0.2.1&#xff09;的過程&#xff0c;大致遵循以下步驟&#xff1a; 查詢本地緩存&#xff1a;當用戶嘗…

Linux內核 -- 多線程之完成量completion的使用

Linux Kernel Completion 使用指南 在Linux內核編程中&#xff0c;completion是一個用于進程同步的機制&#xff0c;常用于等待某個事件的完成。它提供了一種簡單的方式&#xff0c;讓一個線程等待另一個線程完成某項任務。 基本使用方法 初始化 completion結構需要在使用之…

順序串算法庫構建

學習賀利堅老師順序串算法庫 數據結構之自建算法庫——順序串_創建順序串s1,創建順序串s2-CSDN博客 本人詳細解析博客 串的概念及操作_串的基本操作-CSDN博客 版本更新日志 V1.0: 在賀利堅老師算法庫指導下, 結合本人詳細解析博客思路基礎上,進行測試, 加入異常彈出信息 v1.0補…

已解決java.awt.geom.NoninvertibleTransformException:在Java2D中無法逆轉的轉換的正確解決方法,親測有效!!!

已解決java.awt.geom.NoninvertibleTransformException&#xff1a;在Java2D中無法逆轉的轉換的正確解決方法&#xff0c;親測有效&#xff01;&#xff01;&#xff01; 目錄 問題分析 出現問題的場景 報錯原因 解決思路 解決方法 1. 檢查縮放因子 修改后的縮放變換 …