PyTorch 數據類型和使用

關于PyTorch的數據類型和使用的學習筆記??系統介紹了PyTorch的核心數據類型Tensor及其應用。Tensor作為多維矩陣數據容器,支持0-4維數據結構(標量到批量圖像),并提供了多種數值類型(float32/int64等)。通過積木類比闡述了Tensor的維度概念,展示了創建、變形、隨機生成等基礎操作。重點演示了FashionMNIST數據集分類任務實戰:構建包含兩個全連接層的神經網絡(QYNN),使用交叉熵損失和SGD優化器進行訓練。

1 介紹

? PyTorch 是Torch的Python版本 是開源的神經網絡框架 針對于GPU加速的深度神經網絡編程

? Torch是一個經典的多維矩陣數據進行操作的張量(Tensor)庫 在機器學習和其他數學密集型應用廣泛應用 PyTorch的計算圖是動態的 可以按照計算需求實時改變計算圖

? PyTorch追求最少的封裝 設計遵循Tensor->Variable->nn.Module 三個由低到高的抽象層次 分別代表高維數組(張量) 自動求導(變量) 神經網絡(層/模塊)三個抽象間聯系緊密


2 基礎數據類型

?2.1 圖文說明

PyTorch處理的最基本的操作對象就是張量(Tensor)表示的就是一個多維矩陣 接下來將進行一個通俗的說明

? ?我們類比一下積木 ,Tensor就是構建一切模型和計算的最基本積木塊。而PyThorch就是一個裝數字的??盒子??,并且這個盒子可以有很多??維度??(幾層架子)。

維度類比描述例子具體場景說明
0維張量(標量)一粒積木5, 3.14單個數值(如溫度、概率值)
1維張量(向量)一行/一列整齊擺放的積木[1, 2, 3, 4]物體位置坐標、心電圖波形數據
2維張量(矩陣)行列組成的積木板[[1,2,3],
[4,5,6]]
灰度圖像(28x28像素)、
Excel表格數據
3維張量一摞多個積木板尺寸示例:[3,224,224]彩色圖像(通道×高×寬)
MRI切片掃描數據
4維張量多個3維張量
打包的箱子
尺寸示例:[32,3,224,224]批量處理32張
224x224像素的RGB圖像

?

?

而每一個數字也自己本身的數據類型:浮點型?和?整型

??數據類型????位寬/精度????通俗解釋????典型應用場景????PyTorch創建方法????內存占用??
??torch.float32??
(torch.float)
32位
單精度浮點
帶小數點的數
(如3.14159)
深度學習模型參數
激活函數計算
.float()
dtype=torch.float32
4字節/元素
??torch.float64??
(torch.double)
64位
雙精度浮點
高精度浮點數
(更多小數位)
科學計算
精密數值分析
.double()
dtype=torch.float64
8字節/元素
??torch.int32??
(torch.int)
32位整數普通整數
(如-1, 0, 42)
一般計數
簡單索引
.int()
dtype=torch.int32
4字節/元素
??torch.int64??
(torch.long)
64位整數大范圍整數
(更大或更精確)
??標簽數據??
復雜索引
位置信息
.long()
dtype=torch.int64
8字節/元素
??torch.uint8??8位無符號整數0-255整數
(無負數)
??圖像像素值??
(0=黑, 255=白)
.byte()
dtype=torch.uint8
1字節/元素
??torch.bool??布爾值True/False
(是/否)
條件判斷
數據掩碼
(如x>5)
.bool()
dtype=torch.bool
1字節/元素
??torch.complex64??64位復數復數表示
(實部+虛部浮點)
信號處理
量子計算
dtype=torch.complex648字節/元素
??torch.complex128??128位復數高精度復數高級物理計算
電磁場模擬
dtype=torch.complex12816字節/元素

? ? ? ? ??之所以說 Tensor 是核心數據類型。是因為 PyTorch 幾乎所有操作(神經網絡運算、求梯度)都建立在處理?Tensor?之上。你需要把你的數據(數字、圖像、文本數值化表示等)最終都放進這些不同形狀(維度)的?Tensor?盒子里,PyTorch 才能處理和計算。? 所有東西最終都變成 Tensor 的某種形式。維度 (shape) 決定了數據的基本結構(標量、向量、矩陣、圖片、批量)。而 ?dtype?去指定格子的內容類型

2.2 代碼實現

? ? ?然后 PyTorch實際的數據類型我們再使用代碼實操一下

?2.2.1 基礎張量創建?
# 張量的定義方式和Numpy一樣 傳入矩陣即可生成張量
import torch
a = torch.Tensor([[1,2],[3,4]])
print(a) # <class 'torch.Tensor'>
a = torch.eye(2)  # 創建2x2單位矩陣
print(a)           # 輸出: tensor([[1., 0.], [0., 1.]])
?2.2.2 特殊張量初始化?
 = torch.zeros(3, 3)  # 3x3全0張量
c = torch.ones(3, 3)   # 3x3全1張量
d = torch.arange(1, 10, 2)  # [1,10)區間步長為2: [1,3,5,7,9]
e = torch.linspace(1, 10, 10)  # 1-10的10個等差值
f = torch.logspace(1, 10, 10)  # 10^1到10^10的10個對數間隔值
g = torch.logspace(1, 2, 10)   # 10^1到10^2的10個值
2.2.3 隨機張量生成?
a1 = torch.rand(3, 3) # [0,1)均勻分布 a2 = torch.randn(3, 3) # 標準正態分布(μ=0, σ=1) a3 = torch.randint(1, 10, (5, 5)) # [1,10)區間的隨機整數
2.2.4 ?NumPy互操作?
import numpy as np
a4 = np.array([1, 2])          # 創建NumPy數組
a5 = torch.from_numpy(a4)       # NumPy轉PyTorch張量
# 類型轉換: <class 'numpy.ndarray'> -> <class 'torch.Tensor'>
?2.2.5?張量形狀操作?
a = torch.Tensor(2, 3, 128, 128)
print(a.shape)              # torch.Size([2, 3, 128, 128])
print(a[0].shape)          # torch.Size([3, 128, 128])
print(a[0][0].shape)       # torch.Size([128, 128])# 高級切片
print(a[:1, :1, :64, :64].shape)    # torch.Size([1, 1, 64, 64])
print(a[:1, :1, :64:2, :64:2].shape) # torch.Size([1, 1, 32, 32])
2.2.6 維度變換?
# 重塑形狀
B = a.reshape(2, 3, -1)      # 展平后兩維: (2,3,16384)
C = a.reshape(4, -1)        # (4, 24576)# 增刪維度
a = a.unsqueeze(2)          # 添加維度: (2,3,1,128,128)
a = a.squeeze(1)            # 刪除大小為1的維度# 維度交換
a = a.transpose(0, 1)       # 交換維度0和1: (3,2,128,128)
a = a.permute(1, 0, 3, 2)   # 維度重排: (3,2,128,128)
?2.2.7?維度擴展??
a = torch.randn(2, 1, 128, 128)
a = a.expand(2, 3, 128, 64)  # 復制數據擴展維度
# 要求: 擴展維度必須為1或與原尺寸一致
2.2.8 函數總結
??操作類型????函數/語法????關鍵特性??
基礎創建eye(),?zeros(),?ones()初始化特殊矩陣
序列生成arange(),?linspace()控制步長/數量
隨機生成rand(),?randn(),?randint()均勻/正態/整數分布
維度操作reshape(),?view()數據不復制改變形狀
維度增刪unsqueeze(),?squeeze()添加/移除大小為1的維度
維度交換transpose(),?permute()調整維度順序
數據擴展expand()復制數據擴展張量(僅支持1->N的擴展)
NumPy互操作from_numpy()零拷貝數據共享

3 實戰使用

使用FashionMNIST數據集(FashionMNIST 是一個經典的計算機視覺基準數據集,由德國電商巨頭 Zalando 的研究團隊于 2017 年創建,旨在替代過于簡單的 MNIST 手寫數字數據集。它包含 70,000 張 28x28 像素的時尚單品灰度圖像,涵蓋 10 個類別。)完成一個基本的圖形分類任務

? ? ? ? 我們將從環境配置 模型訓練與評估 模型使用三個階段講起

3.1 環境配置

模型訓練

import torch
import torch.nn as nn 
import torch.optim as optim # 導入優化器
from torchvision import datasets, transforms # 導入數據集和數據預處理庫
from torch.utils.data import DataLoader # 數據加載庫

模型使用

import os # 用于操作文件
import torch
import matplotlib.pyplot as plt
from torchvision import datasets,transforms # 用于數據集和數據變換
from PIL import Image # 用于圖形操作
from torchvision.datasets import FashionMNIST # 用于加載FashionMNIST數據集
from train import QYNN, transform

一定要添加一個本地的解釋器配置環境 以免沖突?

3.2 模型訓練與評估

? train.py


import torch
import torch.nn as nn 
import torch.optim as optim # 導入優化器
from torchvision import datasets, transforms # 導入數據集和數據預處理庫
from torch.utils.data import DataLoader # 數據加載庫# 設置隨機種子
torch.manual_seed(21)# 定義數據預處理
transform = transforms.Compose([transforms.Resize((28,28)),transforms.Grayscale(), #強制灰度圖像(1通道)transforms.ToTensor(),  # 將圖像轉換為張量transforms.Normalize((0.5,), (0.5,))  # 標準化圖像數據 灰度圖,只需要一個0.5
])# 加載FashionMNIST數據集
train_dataset = datasets.FashionMNIST(root='./FashionMNIST_images/train', train=True, download=True, transform=transform)  # 下載訓練集
test_dataset = datasets.FashionMNIST(root='./FashionMNIST_images/test', train=False, download=True, transform=transform) #  下載測試集# 創建數據加載器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) # 對訓練集進行打包,指定批次為64
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False) # 對測試集進行打包# 打印數據集大小和樣本檢查
print(f"訓練集大小: {len(train_dataset)}")
print(f"測試集大小: {len(test_dataset)}")# 定義神經網絡模型
class QYNN(nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(28*28, 128)  # 第一個全連接層 先轉換為一維向量self.fc2 = nn.Linear(128, 10)  # 第二個全連接層 輸出10個類別def forward(self, x):x = torch.flatten(x, start_dim=1)  # 展平數據,方便進行全連接x = torch.relu(self.fc1(x))  # 非線性x = self.fc2(x) # 十分類 輸出層return x # 檢查是否有 GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 初始化模型
model = QYNN().to(device) # 將模型移植到 GPU 或 CPU 上# 定義損失函數和優化器
criterion = nn.CrossEntropyLoss() # 交叉熵
optimizer = optim.SGD(model.parameters(), lr=0.01) # lr 學習率 用來調整模型收斂速度# 訓練模型
epochs = 10
best_acc = 0 # 初始化最佳準確率
best_model_wts = None # 用于保存最佳權重
for epoch in range(epochs): # 0-9running_loss = 0.0model.train()  # 設置模型為訓練模式for inputs, labels in train_loader:# 移動數據到GPU 或 CPUinputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad() # 梯度清零outputs = model(inputs) # 將圖片塞進網絡訓練獲得 輸出 前向傳播loss = criterion(outputs, labels) # 根據輸出和標簽做對比計算損失loss.backward() # 反向傳播optimizer.step() # 更新參數running_loss += loss.item() # loss值累加print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(train_loader)}")# 測試模型
model.eval() # 設置模型為評估模式
correct = 0 # 正確的數量
total = 0 # 樣本總數
with torch.no_grad(): # 不用進行梯度計算for inputs, labels in test_loader:outputs = model(inputs)_, predicted = torch.max(outputs, 1) # _取到的最大值,可以不要, 我們需要的是最大值對應的索引 也就是label(predicted)total += labels.size(0) # 獲取當前批次樣本數量correct += (predicted == labels).sum().item() # 對預測對的值進行累加accuracy = 100 * correct / total # 計算準確率
print(f"Epoch{epoch+1}/{epochs},Accuracy on test set: {correct/total:.2%}")# 如果當前模型的準確率比之前的最佳準率好 則保存模型權重
if accuracy > best_acc: best_acc = accuracy
best_model_wts = model.state_dict() # 保存最佳模型的權重torch.save(model.state_dict(), "./FashionMNIST_images/model.pt")
print("Best model weights saved !")

3.3 模型使用?

test.py

import os # 用于操作文件
import torch
import matplotlib.pyplot as plt
from torchvision import datasets,transforms # 用于數據集和數據變換
from PIL import Image # 用于圖形操作
from torchvision.datasets import FashionMNIST # 用于加載FashionMNIST數據集
from train import QYNN, transform# 定義數據集保存路徑
data_dir = './FashionMNIST_images' # 數據集的根目錄
train_dir = os.path.join(data_dir, 'train') # 訓練集保存路經)
test_dir = os.path.join(data_dir, 'test')  # 測試集保存路徑# 定義分類標簽 FashionMNIST共有10個類別
class_names = ['T-shirt/top',   # 0: T恤/上衣'Trouser',       # 1: 褲子'Pullover',      # 2: 套頭衫'Dress',         # 3: 連衣裙'Coat',          # 4: 外套'Sandal',        # 5: 涼鞋'Shirt',         # 6: 襯衫'Sneaker',       # 7: 運動鞋'Bag',           # 8: 包'Ankle boot'     # 9: 短靴
]model = QYNN()
model.load_state_dict(torch.load("./FashionMNIST_images/model.pt"))
model.eval()# 定義推理和可視化函數
def infer_and_visualize_image(image_path,model,classses):# 打開圖形并進行預處理img = Image.open(image_path).convert('L') # 確保灰度圖片img = transform(img).unsqueueeze(0) # 增加一個批次維度# 推理with torch.no_grad():output = model(img)_, predicted = torch.max(output, 1)# 可視化圖形和預測結果plt.imshow(img.squeeze(), cmap='gray')plt.title(f"Predicted {classses[predicted[0]]}")plt.axis('off')plt.show()# 輸入圖形路徑image_path = r""infer_and_visualize_image(image_path, model, class_names)

?訓練效果展示


本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/93417.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/93417.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/93417.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

[python刷題模板] LogTrick

[python刷題模板] LogTrick 一、 算法&數據結構1. 描述2. 復雜度分析3. 常見應用4. 常用優化二、 模板代碼1. 特定或值的最短子數組2. 找特定值3. 找位置j的最后一次被誰更新4. 問某個或和的數量三、其他四、更多例題五、參考鏈接一、 算法&數據結構 1. 描述 LogTric…

Vim與VS Code

Vim is a clone, with additions, of Bill Joys vi text editor program for Unix. It was written by Bram Moolenaar based on source for a port of the Stevie editor to the Amiga and first released publicly in 1991.其實這個本身不是 IDE &#xff08;只有在加入和配置…

[2025CVPR-圖象分類方向]CATANet:用于輕量級圖像超分辨率的高效內容感知標記聚合

?1. 研究背景與動機? ?問題?&#xff1a;Transformer在圖像超分辨率&#xff08;SR&#xff09;中計算復雜度隨空間分辨率呈二次增長&#xff0c;現有方法&#xff08;如局部窗口、軸向條紋&#xff09;因內容無關性無法有效捕獲長距離依賴。?現有局限?&#xff1a; SPI…

課題學習筆記3——SBERT

1 引言在構建基于知識庫的問答系統時&#xff0c;"語義匹配" 是核心難題 —— 如何讓系統準確識別 "表述不同但含義相同" 的問題&#xff1f;比如用戶問 "對親人的期待是不是欲&#xff1f;"&#xff0c;系統能匹配到知識庫中 "追名逐利是欲…

在Word和WPS文字中把全角數字全部改為半角

大部分情況下我們在Word或WPS文字中使用的數字或標點符號都是半角&#xff0c;但是有時不小心按錯了快捷鍵或者點到了輸入法的全角半角切換圖標&#xff0c;就輸入了全角符號和數字。不用擔心&#xff0c;使用它們自帶的全角、半角轉換功能即可快速全部轉換回來。一、為什么會輸…

數據結構的基本知識

一、集合框架1、什么是集合框架Java集合框架(Java Collection Framework),又被稱為容器(container),是定義在java.util包下的一組接口(interfaces)和其實現類(classes).主要表現為把多個元素(element)放在一個單元中,用于對這些元素進行快速、便捷的存儲&#xff08;store&…

WebStack-Hugo | 一個靜態響應式導航主題

WebStack-Hugo | 一個靜態響應式導航主題 #10 shenweiyan announced in 1.3-折騰 WebStack-Hugo | 一個靜態響應式導航主題#10 ?編輯shenweiyan on Oct 23, 2023 6 comments 7 replies Return to top shenweiyan on Oct 23, 2023 Maintainer Via&#xff1a;我給自己…

01 基于sklearn的機械學習-機械學習的分類、sklearn的安裝、sklearn數據集、數據集的劃分、特征工程中特征提取與無量綱化

文章目錄機械學習機械學習分類1. 監督學習2. 半監督學習3. 無監督學習4. 強化學習機械學習的項目開發步驟scikit-learn1 scikit-learn安裝2 sklearn數據集1. sklearn 玩具數據集鳶尾花數據集糖尿病數據集葡萄酒數據集2. sklearn現實世界數據集20 新聞組數據集3. 數據集的劃分特…

攜全雙工語音通話大模型亮相WAIC,Soul重塑人機互動新范式

近日&#xff0c;WAIC 2025在上海隆重開幕。作為全球人工智能領域的頂級盛會&#xff0c;本屆WAIC展覽聚焦底層能力的演進與具體垂類場景的融合落地。堅持“模應一體”方向、立足“AI社交”的具體場景&#xff0c;Soul App此次攜最新升級的自研端到端全雙工語音通話大模型亮相&…

第2章 cmd命令基礎:常用基礎命令(1)

Hi~ 我是李小咖&#xff0c;主要從事網絡安全技術開發和研究。 本文取自《李小咖網安技術庫》&#xff0c;歡迎一起交流學習&#x1fae1;&#xff1a;https://imbyter.com 本節介紹的命令有目錄操作&#xff08;cd&#xff09;、清屏操作&#xff08;cls&#xff09;、設置顏色…

Java 10 新特性解析

Java 10 新特性解析 文章目錄Java 10 新特性解析1. 引言2. 本地變量類型推斷&#xff08;JEP 286&#xff09;2.1. 概述2.2. 使用場景2.3. 限制2.4. 與之前版本的對比2.5. 風格指南2.6. 示例代碼2.7. 優點與注意事項3. 應用程序類數據共享&#xff08;JEP 310&#xff09;3.1. …

【WRF工具】服務器中安裝編譯GrADS

目錄 安裝編譯 GrADS 所需的依賴庫 conda下載庫包 安裝編譯 GrADS 編譯前檢查依賴可用性 安裝編譯 GrADS 參考 安裝編譯 GrADS 所需的依賴庫 以統一方式在 $HOME/WRFDA_LIBS/grads_deps 下安裝所有依賴: # 選擇一個目錄用于安裝所有依賴庫 export DIR=$HOME/WRFDA_LIBS庫包1…

數據結構之隊列(C語言)

1.隊列的定義&#xff1a; 隊列&#xff08;Queue&#xff09;是一種基礎且重要的線性數據結構&#xff0c;遵循先進先出&#xff08;FIFO&#xff09;?? 原則&#xff0c;即最早入隊的元素最先出隊&#xff0c;與棧不同的是出隊列的順序是固定的。隊列具有以下特點&#xff…

C#開發基礎之深入理解“集合遍歷時不可修改”的異常背后的設計

前言 歡迎關注【dotnet研習社】&#xff0c;今天我們聊聊一個基礎問題“集合已修改&#xff1a;可能無法執行枚舉操作”背后的設計。 在日常 C# 開發中&#xff0c;我們常常會操作集合&#xff08;如 List<T>、Dictionary<K,V> 等&#xff09;。一個新手開發者極…

【工具】圖床完全指南:從選擇到搭建的全方位解決方案

前言 在數字化內容創作的時代&#xff0c;圖片已經成為博客、文檔、社交媒體等平臺不可或缺的元素。然而&#xff0c;如何高效、穩定地存儲和分發圖片資源&#xff0c;一直是內容創作者面臨的重要問題。圖床&#xff08;Image Hosting&#xff09;作為專門的圖片存儲和分發服務…

深度學習篇---PaddleDetection模型選擇

PaddleDetection 是百度飛槳推出的目標檢測開發套件&#xff0c;提供了豐富的模型庫和工具鏈&#xff0c;覆蓋從輕量級移動端到高性能服務器的全場景需求。以下是核心模型分類、適用場景及大小選擇建議&#xff08;通俗易懂版&#xff09;&#xff1a;一、主流模型分類及適用場…

cmseasy靶機密碼爆破通關教程

靶場安裝1.首先我們需要下載一個cms靶場CmsEasy_7.6.3.2_UTF-8_20200422,下載后解壓在phpstudy_pro的網站根目錄下。2.然后我們去訪問一下安裝好的網站&#xff0c;然后注冊和鏈接數據庫3.不知道自己數據庫密碼的可以去小皮面板里面查看4.安裝好后就可以了來到后臺就可以了。練…

【C語言】指針深度剖析(一)

文章目錄一、內存和地址1.1 內存的基本概念1.2 編址的原理二、指針變量和地址2.1 取地址操作符&#xff08;&&#xff09;2.2 指針變量和解引用操作符&#xff08;*&#xff09;2.2.1 指針變量2.2.2 指針類型的解讀2.2.3 解引用操作符2.3 指針變量的大小三、指針變量類型的…

半導體企業選用的跨網文件交換系統到底應該具備什么功能?

在半導體行業的數字化轉型過程中&#xff0c;跨網文件交換已成為連接研發、生產、供應鏈的關鍵紐帶。半導體企業的跨網文件交換不僅涉及設計圖紙、工藝參數等核心知識產權&#xff0c;還需要滿足跨國協同、合規審計等復雜需求。那么&#xff0c;一款適合半導體行業的跨網文件交…

影刀RPA_初級課程_玩轉影刀自動化_網頁操作自動化

聲明&#xff1a;相關內容來自影刀學院&#xff0c;本文章為自用筆記&#xff0c;切勿商用&#xff01;&#xff08;若有侵權&#xff0c;請聯絡刪除&#xff09; 1. 基本概念與操作 1.1 正確處理下拉框元素&#xff08;先判斷頁面元素&#xff0c;后進行流程編制&#xff09;…