【數據準備】——深度學習.全連接神經網絡

目錄

1 數據加載器

1.1 構建數據類

1.1.1 Dataset類

1.1.2 TensorDataset類

1.2 數據加載器

2 數據加載案例

2.1 加載csv數據集

2.2 加載圖片數據集

2.3 加載官方數據集

2.4 pytorch實現線性回歸


1 數據加載器

分數據集和加載器2個步驟~

1.1 構建數據類

1.1.1 Dataset類

Dataset是一個抽象類,是所有自定義數據集應該繼承的基類。它定義了數據集必須實現的方法。

必須實現的方法

  1. __len__: 返回數據集的大小

  2. __getitem__: 支持整數索引,返回對應的樣本

在 PyTorch 中,構建自定義數據加載類通常需要繼承 torch.utils.data.Dataset 并實現以下幾個方法:

  1. __init__ 方法 用于初始化數據集對象:通常在這里加載數據,或者定義如何從存儲中獲取數據的路徑和方法。
def __init__(self, data, labels):self.data = dataself.labels = labels

??????2.__len__ 方法 返回樣本數量:需要實現,以便 Dataloader加載器能夠知道數據集的大小。

def __len__(self):return len(self.data)

??????3.__getitem__ 方法 根據索引返回樣本:將從數據集中提取一個樣本,并可能對樣本進行預處理或變換。

def __getitem__(self, index):sample = self.data[index]label = self.labels[index]return sample, label

如果你需要進行更多的預處理或數據變換,可以在 __getitem__ 方法中添加額外的邏輯。

import torch
from torch.utils.data import Dataset,DataLoader,TensorDataset
from sklearn.datasets import make_regression
from torch import nn,optim
# 自定義數據集類
# 1.繼承dataset類
# 2.實現__init__方法,初始化外部的數據
# 3.實現__len__方法,用來返回數據集的長度
# 4.實現__getitem__方法,根據索引獲取對應位置的數據
class MyDataset(Dataset):def __init__(self,data,labels):self.data=dataself.labels=labelsdef __len__(self):return len(self.data)def __getitem__(self, index):sample=self.data[index]label=self.labels[ index]return sample,labeldef test01():x=torch.randn(100,20)y=torch.randn(100,1)dataset=MyDataset(x,y)print( dataset[0])if __name__=='__main__':test01()

1.1.2 TensorDataset類

TensorDatasetDataset的一個簡單實現,它封裝了張量數據,適用于數據已經是張量形式的情況。

特點

  1. 簡單快捷:當數據已經是張量形式時,無需自定義Dataset類

  2. 多張量支持:可以接受多個張量作為輸入,按順序返回

  3. 索引一致:所有張量的第一個維度必須相同,表示樣本數量

def test03():torch.manual_seed(0)# 創建特征張量和標簽張量features = torch.randn(100, 5)  # 100個樣本,每個樣本5個特征labels = torch.randint(0, 2, (100,))  # 100個二進制標簽# 創建TensorDatasetdataset = TensorDataset(features, labels)# 使用方式與自定義Dataset相同print(len(dataset))  # 輸出: 100print(dataset[0])  # 輸出: (tensor([...]), tensor(0))

1.2 數據加載器

在訓練或者驗證的時候,需要用到數據加載器批量的加載樣本。

DataLoader 是一個迭代器,用于從 Dataset 中批量加載數據。它的主要功能包括:

  • 批量加載:將多個樣本組合成一個批次。

  • 打亂數據:在每個 epoch 中隨機打亂數據順序。

  • 多線程加載:使用多線程加速數據加載。

創建DataLoader:

# 創建 DataLoader
dataloader = DataLoader(dataset,          # 數據集batch_size=10,    # 批量大小shuffle=True,     # 是否打亂數據num_workers=2     # 使用 2 個子進程加載數據
)

遍歷:

# 遍歷 DataLoader
# enumerate返回一個枚舉對象(iterator),生成由索引和值組成的元組
for batch_idx, (samples, labels) in enumerate(dataloader):print(f"Batch {batch_idx}:")print("Samples:", samples)print("Labels:", labels)

案例:

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader# 定義數據加載類
class CustomDataset(Dataset):#略......def test01():# 簡單的數據集準備data_x = torch.randn(666, 20, requires_grad=True, dtype=torch.float32)data_y = torch.randn(data_x.size(0), 1, dtype=torch.float32)dataset = CustomDataset(data_x, data_y)# 構建數據加載器data_loader = DataLoader(dataset, batch_size=8, shuffle=True)for i, (batch_x, batch_y) in enumerate(data_loader):print(batch_x, batch_y)breakif __name__ == "__main__":test01()

2 數據加載案例

通過一些數據集的加載案例,真正了解數據類及數據加載器。

2.1 加載csv數據集

import torch
from torch import nn
from torch.utils.data import TensorDataset,DataLoader
import pandas as pd
from torchvision import datasets,transformsdef build_csv_data(filepath):df=pd.read_csv(filepath)df.drop(["學號","姓名"],axis=1,inplace=True)# print(df.head())samples=df.iloc[...,:-1]labels=df.iloc[...,-1]# print(samples.head())# print(labels.head())samples=torch.tensor(samples.values)labels=torch.tensor(labels.values)# print(samples)# print(labels)return samples,labelsdef load_csv_data():filepath="./datasets/大數據答辯成績表.csv"samples,labels=build_csv_data(filepath)dataset=TensorDataset(samples,labels)dataloader=DataLoader(dataset=dataset,batch_size=1,shuffle=True)for sample,label in dataloader:print(sample)print(label)breakif __name__=="__main__":load_csv_data()

2.2 加載圖片數據集

import torch
from torch import nn
from torch.utils.data import TensorDataset,DataLoader
import pandas as pd
from torchvision import datasets,transformsdef load_img_data():path="./datasets/animals"transform=transforms.Compose([# 圖片縮放 把所有圖片縮放到同一尺寸transforms.Resize((224,224)),# 把PIL圖片或numpy數組轉為張量transforms.ToTensor(),])dataset=datasets.ImageFolder(root=path,transform=transform)dataloader=DataLoader(dataset=dataset,batch_size=4,shuffle=True)for x,y in dataloader:print(x.shape)print(x)print(y)breakif __name__=="__main__":load_img_data()

2.3 加載官方數據集

在 PyTorch 中官方提供了一些經典的數據集,如 CIFAR-10、MNIST、ImageNet 等,可以直接使用這些數據集進行訓練和測試。

數據集:Datasets — Torchvision 0.22 documentation

常見數據集:

  • MNIST: 手寫數字數據集,包含 60,000 張訓練圖像和 10,000 張測試圖像。

  • CIFAR10: 包含 10 個類別的 60,000 張 32x32 彩色圖像,每個類別 6,000 張圖像。

  • CIFAR100: 包含 100 個類別的 60,000 張 32x32 彩色圖像,每個類別 600 張圖像。

  • COCO: 通用對象識別數據集,包含超過 330,000 張圖像,涵蓋 80 個對象類別。

torchvision.transforms 和 torchvision.datasets 是 PyTorch 中處理計算機視覺任務的兩個核心模塊,它們為圖像數據的預處理和標準數據集的加載提供了強大支持。

transforms 模塊提供了一系列用于圖像預處理的工具,可以將多個變換組合成處理流水線。

datasets 模塊提供了多種常用計算機視覺數據集的接口,可以方便地下載和加載。

參考如下:

import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms, datasetsdef test():transform = transforms.Compose([transforms.ToTensor(),])# 訓練數據集data_train = datasets.MNIST(root="./data",train=True,download=True,transform=transform,)trainloader = DataLoader(data_train, batch_size=8, shuffle=True)for x, y in trainloader:print(x.shape)print(y)break# 測試數據集data_test = datasets.MNIST(root="./data",train=False,download=True,transform=transform,)testloader = DataLoader(data_test, batch_size=8, shuffle=True)for x, y in testloader:print(x.shape)print(y)breakdef test006():transform = transforms.Compose([transforms.ToTensor(),])# 訓練數據集data_train = datasets.CIFAR10(root="./data",train=True,download=True,transform=transform,)trainloader = DataLoader(data_train, batch_size=4, shuffle=True, num_workers=2)for x, y in trainloader:print(x.shape)print(y)break# 測試數據集data_test = datasets.CIFAR10(root="./data",train=False,download=True,transform=transform,)testloader = DataLoader(data_test, batch_size=4, shuffle=False, num_workers=2)for x, y in testloader:print(x.shape)print(y)breakif __name__ == "__main__":test()test006()

2.4 pytorch實現線性回歸

import torch
from torch.utils.data import Dataset,DataLoader,TensorDataset
from sklearn.datasets import make_regression
from torch import nn,optim# pytorch實現線性回歸
def build_data(in_features,out_features):bias=14.5# 生成的數據需要轉換成tensorx,y,coef=make_regression(n_samples=1000,n_features=in_features,n_targets=out_features,coef=True,bias=bias,noise=0.1,random_state=42)x=torch.tensor(x,dtype=torch.float32)y=torch.tensor(y,dtype=torch.float32).view(-1,1) # 注意要把y轉換成二維數組(本來是一維) 否則會報警告coef=torch.tensor(coef,dtype=torch.float32)bias=torch.tensor(bias,dtype=torch.float32)return x,y,coef,biasdef train():# 數據準備in_features=10out_features=1x,y,coef,bias=build_data(in_features,out_features)dataset=TensorDataset(x,y)dataloader=DataLoader(dataset=dataset,batch_size=100,shuffle=True)# 定義網絡模型model=nn.Linear(in_features,out_features)# 定義損失函數criterion=nn.MSELoss()# 優化器opt=optim.SGD(model.parameters(),lr=0.1)epochs=20for epoch in range(epochs):for tx,ty in dataloader:y_pred=model(tx)loss=criterion(y_pred,ty)opt.zero_grad()loss.backward()opt.step()print(f'epoch:{epoch},loss:{loss.item()}')# detach()、data:作用是將計算圖中的weight參數值獲取出來print(f"真實權重:{coef.numpy()},訓練權重:{model.weight.detach().numpy()}") # datach()相當于把weight從計算圖中抽離出來print(f"真實偏置:{bias},訓練偏置:{model.bias.item()}")if __name__=='__main__':train()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/91552.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/91552.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/91552.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

健康生活,從細節開始

健康生活,從細節開始在當今快節奏的生活中,健康逐漸成為人們關注的焦點。擁有健康的身體,才能更好地享受生活、追求夢想。那么,如何才能擁有健康呢?這就需要我們從生活中的點滴細節入手,培養良好的生活習慣…

javax.servlet.http.HttpServletResponse;API導入報錯解決方案

javax.servlet.http.HttpServletResponse;API導入報錯解決方案與Postman上傳下載文件驗證 1. 主要錯誤:缺少 Servlet API 依賴 錯誤信息顯示 javax.servlet.http 包不存在。這是因為你的項目缺少 Servlet API 依賴。 解決方案: 如果你使用的是 Maven&…

reids依賴刪除,但springboot仍然需要redis參數才能啟動

背景:項目需要刪除redis。我刪除完項目所有配置redis的依賴,啟動報錯。[2025-07-17 15:08:37:561] [DEBUG] [restartedMain] DEBUG _.s.w.s.H.Mappings - [detectHandlerMethods,295] [] - o.s.b.a.w.s.e.BasicErrorController:{ [/error]}: error(HttpS…

【前端】CSS類命名規范指南

在 CSS 中,合理且規范的 class 命名格式對項目的可維護性和協作效率至關重要。以下是主流的 class 命名規范和方法論:一、核心命名原則語義化命名:描述功能而非樣式 ? .search-form(描述功能)? .red-text&#xff08…

C++網絡編程 4.UDP套接字(socket)編程示例程序

以下是基于UDP協議的完整客戶端和服務器代碼。UDP與TCP的核心區別在于無連接特性&#xff0c;因此代碼結構會更簡單&#xff08;無需監聽和接受連接&#xff09;。 UDP服務器代碼&#xff08;udp_server.cpp&#xff09; #include <iostream> #include <sys/socket.h&…

King’s LIMS:實驗室數字化轉型的智能高效之選

實驗室數字化轉型不僅是技術升級&#xff0c;更是管理理念和工作方式的革新。LIMS系統作為這一轉型的核心工具&#xff0c;能夠將分散的實驗數據轉化為可分析、可復用的資產&#xff0c;為科研決策提供支持&#xff1b;規范檢測流程&#xff0c;減少人為干預&#xff0c;確保結…

【力扣 中等 C】97. 交錯字符串

目錄 題目 解法一 題目 待添加 解法一 bool isInterleave(char* s1, char* s2, char* s3) {const int len1 strlen(s1);const int len2 strlen(s2);const int len3 strlen(s3);if (len1 len2 ! len3) {return false;}if (len1 < len2) {return isInterleave(s2, s1,…

Class9簡潔實現

Class9簡潔實現 %matplotlib inline import torch from torch import nn from d2l import torch as d2l# 初始化訓練樣本、測試樣本、樣本特征維度和批量大小 n_train,n_test,num_inputs,batch_size 20,100,200,5 # 設置真實權重和偏置 true_w,true_b torch.ones((num_inputs…

ELK日志分析,涉及logstash、elasticsearch、kibana等多方面應用,必看!

目錄 ELK日志分析 1、下載lrzsc 2、下載源包 3、解壓文件,下載elasticsearch、kibana、 logstash 4、配置elasticsearch 5、配種域名解析 6、配置kibana 7、配置logstash 8、進行測試 ELK日志分析 1、下載lrzsc [rootlocalhost ~]# hostnamectl set-hostname elk ##…

終極剖析HashMap:數據結構、哈希沖突與解決方案全解

文章目錄 引言 一、HashMap底層數據結構&#xff1a;三維存儲架構 1. 核心存儲層&#xff08;硬件優化設計&#xff09; 2. 內存布局對比 二、哈希沖突的本質與數學原理 1. 沖突產生的必然性 2. 沖突概率公式 三、哈希沖突解決方案全景圖 1. 鏈地址法&#xff08;Hash…

1.1.5 模塊與包——AI教你學Django

1.1.5 模塊與包&#xff08;Django 基礎學習細節&#xff09; 模塊和包是 Python 項目組織和代碼復用的基礎。Django 項目本質上就是由多個模塊和包組成。理解和靈活運用模塊與包機制&#xff0c;是寫好大型項目的關鍵。 一、import、from-import、as 的用法 1. import 用于導入…

UE5 相機后處理材質與動態參數修改

一、完整實現流程1. 創建后處理材質材質設置&#xff1a;在材質編輯器中&#xff0c;將材質域(Material Domain)設為后處理(Post Process)設置混合位置(Blendable Location)&#xff08;如After Tonemapping&#xff09;創建標量/向量參數&#xff08;如Intensity, ColorTint&a…

Django基礎(三)———模板

前言 在之前的文章中&#xff0c;視圖函數只是直接返回文本&#xff0c;而在實際生產環境中其實很少這樣用&#xff0c;因為實際的頁面大多是帶有樣式的HTML代碼&#xff0c;這可以讓瀏覽器渲染出非常漂亮的頁面。目前市面上有非常多的模板系統&#xff0c;其中最知名最好用的…

mysql6表清理跟回收空間

mysql6表清理跟回收空間 文章目錄mysql6表清理跟回收空間1.清理表2.備份表或者備份庫3.回收表空間4.查看5.驗證業務1.清理表 ## 登錄 C:\Program Files\MySQL\MySQL Server 5.6\bin>mysql -uroot -p Enter password: ****** Welcome to the MySQL monitor. Commands end w…

Java-74 深入淺出 RPC Dubbo Admin可視化管理 安裝使用 源碼編譯、Docker啟動

點一下關注吧&#xff01;&#xff01;&#xff01;非常感謝&#xff01;&#xff01;持續更新&#xff01;&#xff01;&#xff01; &#x1f680; AI篇持續更新中&#xff01;&#xff08;長期更新&#xff09; AI煉丹日志-30-新發布【1T 萬億】參數量大模型&#xff01;K…

VSCode同時支持Vue2和Vue3開發的插件指南

引言 隨著Vue生態系統的演進&#xff0c;許多開發者面臨著在同一開發環境中同時處理Vue 2和Vue 3項目的需求。Visual Studio Code (VSCode)作為最受歡迎的前端開發工具之一&#xff0c;其插件生態對Vue的支持程度直接影響開發效率。本文將深入探討如何在VSCode中配置插件組合&a…

卷積神經網絡CNN的Python實現

一、環境準備與庫導入 在開始實現卷積神經網絡之前&#xff0c;需要確保開發環境已正確配置&#xff0c;并導入必要的Python庫。常用的深度學習框架有TensorFlow和PyTorch&#xff0c;本示例將基于Keras&#xff08;可使用TensorFlow后端&#xff09;進行實現&#xff0c;因為K…

js是實現記住密碼自動填充功能

記住密碼自動填充使用js實現記住密碼功能&#xff0c;在下次打開登陸頁面的時候進行獲取并自動填充到頁面【cookie和localStorage】使用js實現記住密碼功能&#xff0c;在下次打開登陸頁面的時候進行獲取并自動填充到頁面【cookie和localStorage】 //添加功能----記住上一個登陸…

【Java】文件編輯器

代碼&#xff1a;&#xff08;SimpleEditor.java&#xff09;import java.awt.Color; import java.awt.Font; import java.awt.Insets; import java.awt.BorderLayout;import java.awt.event.ActionEvent; import java.awt.event.ActionListener;import java.io.BufferedReader…

PyTorch中torch.topk()詳解:快速獲取最大值索引

torch.topk(similarities, k=2).indices 是什么意思 torch.topk(similarities, k=2).indices 是 PyTorch 中用于獲取張量中最大值元素及其索引的函數。在你的代碼中,它的作用是從 similarities 向量里找出得分最高的2個元素的位置索引。 1. 核心功能:找出張量中最大的k個值…