深度學習分類回歸(衣帽數據集)

一、步驟

1 加載數據集fashion_minst

2 搭建class NeuralNetwork模型

3 設置損失函數,優化器

4 編寫評估函數

5 編寫訓練函數

6 開始訓練

7 繪制損失,準確率曲線

二、代碼

導包,打印版本號:

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as Fprint(sys.version_info)
for module in mpl, np, pd, sklearn, torch:print(module.__name__, module.__version__)device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

torch的運算過程都是張量,也叫算子(tensor)

torchvision的包可以提供數據集,圖片就是datasets:

這里下載到data目錄,如果已有數據則不會下載。這段代碼可以實現數據向tensor的轉換:

做預處理的時候把圖片變成tensor,啥都沒寫的時候就不會轉換成tensor?

from torchvision import datasets
from torchvision.transforms import ToTensor
from torchvision import transforms# 定義數據集的變換
transform = transforms.Compose([
])
# fashion_mnist圖像分類數據集,衣服分類,60000張訓練圖片,10000張測試圖片
train_ds = datasets.FashionMNIST(root="data",train=True,download=True,transform=transform
)test_ds = datasets.FashionMNIST(root="data",train=False,download=True,transform=transform
)# torchvision 數據集里沒有提供訓練集和驗證集的劃分
# 當然也可以用 torch.utils.data.Dataset 實現人為劃分
type(train_ds[0]) # 元組,第一個元素是圖片,第二個元素是標簽

如果使用了數據類型變換:

img_tensor, label = train_ds[0]
img_tensor.shape  #img這時是一個tensor,shape=(1, 28, 28)

在PyTorch中,DataLoader是一個迭代器,它封裝了數據的加載和預處理過程,使得在訓練機器學習模型時可以方便地批量加載數據。DataLoader主要負責以下幾個方面:

  1. 批量加載數據DataLoader可以將數據集(Dataset)切分為更小的批次(batch),每次迭代提供一小批量數據,而不是單個數據點。這有助于模型學習數據中的統計依賴性,并且可以更高效地利用GPU等硬件的并行計算能力。

  2. 數據打亂:默認情況下,DataLoader會在每個epoch(訓練周期)開始時打亂數據的順序。這有助于模型訓練時避免陷入局部最優解,并且可以提高模型的泛化能力。

  3. 多線程數據加載DataLoader支持多線程(通過參數num_workers)來并行地加載數據,這可以顯著減少訓練過程中的等待時間,尤其是在處理大規模數據集時。

  4. 數據預處理DataLoader可以與transforms結合使用,對加載的數據進行預處理,如歸一化、標準化、數據增強等操作。

  5. 內存管理DataLoader負責管理數據的內存使用,確保在訓練過程中不會耗盡內存資源。

  6. 易用性DataLoader提供了一個簡單的接口,可以很容易地集成到訓練循環中。

# 從數據集到dataloader
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32, shuffle=True) #batch_size分批,shuffle洗牌
val_loader = torch.utils.data.DataLoader(test_ds, batch_size=32, shuffle=False)

這里每32個樣本就會算一次平均損失,更新一次w。

定義模型:繼承nn.Module

class NeuralNetwork(nn.Module):def __init__(self):super().__init__() # 繼承父類的初始化方法,子類有父類的屬性self.flatten = nn.Flatten()  # 展平層self.linear_relu_stack = nn.Sequential(nn.Linear(784, 300),  # in_features=784, out_features=300, 784是輸入特征數,300是輸出特征數nn.ReLU(), # 激活函數nn.Linear(300, 100),#隱藏層神經元數100nn.ReLU(), # 激活函數nn.Linear(100, 10),#輸出層神經元數10 )def forward(self, x): # 前向計算,前向傳播# x.shape [batch size, 1, 28, 28],1是通道數x = self.flatten(x)  # print(f'x.shape--{x.shape}')# 展平后 x.shape [batch size, 784]logits = self.linear_relu_stack(x)# logits.shape [batch size, 10]return logits #沒有經過softmax,稱為logitsmodel = NeuralNetwork()

model的結構:第一層是展平層,然后激活,然后隱藏層,激活,輸出層


?在訓練之前需要測試一下模型能不能用,所以我們隨機一個或者從樣本拿一個,同尺寸就行:

#為了查看模型運算的tensor尺寸
x = torch.randn(32, 1, 28, 28)
print(x.shape)
logits = model(x) # 把x輸入到模型中,得到logits
print(logits.shape)

?然后開始訓練,pytorch的訓練需要自行實現,包括定義損失函數、優化器、訓練步,訓練

# 1. 定義損失函數 采用交叉熵損失
loss_fct = nn.CrossEntropyLoss() #內部先做softmax,然后計算交叉熵
# 2. 定義優化器 采用SGD
# Optimizers specified in the torch.optim package,隨機梯度下降
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
from sklearn.metrics import accuracy_score # sk里面有一個算子,可以計算準確率@torch.no_grad() # 裝飾器,禁止反向傳播,節省內存,就是不求導的意思
def evaluating(model, dataloader, loss_fct): # 評估函數,評估也要做一次向前計算,不需要求梯度loss_list = [] # 記錄損失pred_list = [] # 記錄預測label_list = [] # 記錄標簽for datas, labels in dataloader:#10000/32=312datas = datas.to(device) # 轉到GPUlabels = labels.to(device) # 轉到GPU 這兩行代碼torch必寫,把tensor放到GPU上# 前向計算logits = model(datas)  # 進行前向計算loss = loss_fct(logits, labels)         # 驗證集損失,loss尺寸是一個數值loss_list.append(loss.item()) # 記錄損失,item是把tensor轉換為數值preds = logits.argmax(axis=-1)    # 驗證集預測,argmax返回最大值索引,-1就是最后一個維度print(f'評估中的preds.shape--{preds.shape}')pred_list.extend(preds.cpu().numpy().tolist())#將PyTorch張量轉換為NumPy數組。只有當張量在CPU上時,這個轉換才是合法的# print(preds.cpu().numpy().tolist())label_list.extend(labels.cpu().numpy().tolist())acc = accuracy_score(label_list, pred_list) # 計算準確率return np.mean(loss_list), acc
# 訓練
def training(model, train_loader, val_loader, epoch, loss_fct, optimizer, eval_step=500):#參數分別是模型,訓練集,驗證集,訓練epoch,損失函數,優化器,評估步數(500評估一次)record_dict = { # 記錄字典,用于記錄訓練過程中的信息"train": [],"val": []}global_step = 0 # 全局步數,記錄訓練的步數model.train() # 進入訓練模式,模型可以切換模式#tqdm是一個進度條庫with tqdm(total=epoch * len(train_loader)) as pbar: # 進度條 加入epoch等于10,就是所有樣本搞10次,不斷地把樣本帶進去學習,1875*10,60000/32=1875for epoch_id in range(epoch): # 訓練epoch次# trainingfor datas, labels in train_loader: #執行次數是60000/32=1875datas = datas.to(device) #datas尺寸是[batch_size,1,28,28]labels = labels.to(device) #labels尺寸是[batch_size]# 梯度清空optimizer.zero_grad() # 每次訓練前都要把梯度清空,不然會累加# 模型前向計算logits = model(datas)# 計算損失loss = loss_fct(logits, labels)# 梯度回傳,loss.backward()會計算梯度,loss對模型參數求導loss.backward()# 調整優化器,包括學習率的變動等,優化器的學習率會隨著訓練的進行而減小,更新w,boptimizer.step() #梯度是計算并存儲在模型參數的 .grad 屬性中,優化器使用這些存儲的梯度來更新模型參數preds = logits.argmax(axis=-1) # 訓練集預測acc = accuracy_score(labels.cpu().numpy(), preds.cpu().numpy())   # 計算準確率,numpy可以,每個step都算一次loss = loss.cpu().item() # 損失轉到CPU,item()取值,一個數值# tensor如果只有一個值(標量),一維是向量,二維是矩陣,可以用item()取出值,如果有多個值,則需要用tolist()轉為列表# record# recordrecord_dict["train"].append({"loss": loss, "acc": acc, "step": global_step}) # 記錄訓練集信息,每一步的損失,準確率,步數# evaluatingif global_step % eval_step == 0:model.eval() # 進入評估模式,不會求梯度val_loss, val_acc = evaluating(model, val_loader, loss_fct)record_dict["val"].append({"loss": val_loss, "acc": val_acc, "step": global_step})model.train() # 進入訓練模式# udate stepglobal_step += 1 # 全局步數加1pbar.update(1) # 更新進度條pbar.set_postfix({"epoch": epoch_id}) # 設置進度條顯示信息return record_dictepoch = 20 #改為40
model = model.to(device)
record = training(model, train_loader, val_loader, epoch, loss_fct, optimizer, eval_step=1000)
#畫線要注意的是損失是不一定在零到1之間的
def plot_learning_curves(record_dict, sample_step=1000):# build DataFrametrain_df = pd.DataFrame(record_dict["train"]).set_index("step").iloc[::sample_step]val_df = pd.DataFrame(record_dict["val"]).set_index("step")last_step = train_df.index[-1] # 最后一步的步數# print(train_df.columns)print(train_df['acc'])print(val_df['acc'])# plotfig_num = len(train_df.columns) # 畫幾張圖,分別是損失和準確率fig, axs = plt.subplots(1, fig_num, figsize=(5 * fig_num, 5))for idx, item in enumerate(train_df.columns):# print(train_df[item].values)axs[idx].plot(train_df.index, train_df[item], label=f"train_{item}")axs[idx].plot(val_df.index, val_df[item], label=f"val_{item}")axs[idx].grid() # 顯示網格axs[idx].legend() # 顯示圖例axs[idx].set_xticks(range(0, train_df.index[-1], 5000)) # 設置x軸刻度axs[idx].set_xticklabels(map(lambda x: f"{int(x/1000)}k", range(0, last_step, 5000))) # 設置x軸標簽axs[idx].set_xlabel("step")plt.show()plot_learning_curves(record)  #橫坐標是 steps

# dataload for evaluatingmodel.eval() # 進入評估模式
loss, acc = evaluating(model, val_loader, loss_fct)
print(f"loss:     {loss:.4f}\naccuracy: {acc:.4f}")

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/897454.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/897454.shtml
英文地址,請注明出處:http://en.pswp.cn/news/897454.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【leetcode hot 100 19】刪除鏈表的第N個節點

解法一:將ListNode放入ArrayList中,要刪除的元素為num list.size()-n。如果num 0則將頭節點刪除;否則利用num-1個元素的next刪除第num個元素。 /*** Definition for singly-linked list.* public class ListNode {* int val;* Lis…

【iOS逆向與安全】sms短信轉發插件與上傳服務器開發

一、目標 一步步分析并編寫一個短信自動轉發的deb插件 二、工具 mac系統已越獄iOS設備:脫殼及frida調試IDA Pro:靜態分析測試設備:iphone6s-ios14.1.1三、步驟 1、守護進程 ? 守護進程(daemon)是一類在后臺運行的特殊進程,用于執行特定的系統任務。例如:推送服務、人…

Midjourney繪圖參數詳解:從基礎到高級的全面指南

引言 Midjourney作為當前最受歡迎的AI繪圖工具之一,其強大的參數系統為用戶提供了豐富的創作可能性。本文將深入解析Midjourney的各項參數,幫助開發者更好地掌握這一工具,提升創作效率和質量。 一、基本參數配置 1. 圖像比例調整 使用--ar…

音頻進階學習十九——逆系統(簡單進行回聲消除)

文章目錄 前言一、可逆系統1.定義2.解卷積3.逆系統恢復原始信號過程4.逆系統與原系統的零極點關系 二、使用逆系統去除回聲獲取原信號的頻譜原系統和逆系統幅頻響應和相頻響應使用逆系統恢復原始信號整體代碼如下 總結 前言 在上一篇音頻進階學習十八——幅頻響應相同系統、全…

vue3 使用sass變量

1. 在<style>中使用scss定義的變量和css變量 1. 在/style/variables.scss文件中定義scss變量 // scss變量 $menuText: #bfcbd9; $menuActiveText: #409eff; $menuBg: #304156; // css變量 :root {--el-menu-active-color: $menuActiveText; // 活動菜單項的文本顏色--el…

gbase8s rss集群通信流程

什么是rss RSS是一種將數據從主服務器復制到備服務器的方法 實例級別的復制 (所有啟用日志記錄功能的數據庫) 基于邏輯日志的復制技術&#xff0c;需要傳輸大量的邏輯日志,數據庫需啟用日志模式 通過網絡持續將數據復制到備節點 如果主服務器發生故障&#xff0c;那么備用服務…

熵與交叉熵詳解

前言 本文隸屬于專欄《機器學習數學通關指南》&#xff0c;該專欄為筆者原創&#xff0c;引用請注明來源&#xff0c;不足和錯誤之處請在評論區幫忙指出&#xff0c;謝謝&#xff01; 本專欄目錄結構和參考文獻請見《機器學習數學通關指南》 ima 知識庫 知識庫廣場搜索&#…

程序化廣告行業(3/89):深度剖析行業知識與數據處理實踐

程序化廣告行業&#xff08;3/89&#xff09;&#xff1a;深度剖析行業知識與數據處理實踐 大家好&#xff01;一直以來&#xff0c;我都希望能和各位技術愛好者一起在學習的道路上共同進步&#xff0c;分享知識、交流經驗。今天&#xff0c;咱們聚焦在程序化廣告這個充滿挑戰…

探索在生成擴散模型中基于RAG增強生成的實現與未來

概述 像 Stable Diffusion、Flux 這樣的生成擴散模型&#xff0c;以及 Hunyuan 等視頻模型&#xff0c;都依賴于在單一、資源密集型的訓練過程中通過固定數據集獲取的知識。任何在訓練之后引入的概念——被稱為 知識截止——除非通過 微調 或外部適應技術&#xff08;如 低秩適…

DeepSeek 助力 Vue3 開發:打造絲滑的表格(Table)之添加列寬調整功能,示例Table14基礎固定表頭示例

前言&#xff1a;哈嘍&#xff0c;大家好&#xff0c;今天給大家分享一篇文章&#xff01;并提供具體代碼幫助大家深入理解&#xff0c;徹底掌握&#xff01;創作不易&#xff0c;如果能幫助到大家或者給大家一些靈感和啟發&#xff0c;歡迎收藏關注哦 &#x1f495; 目錄 Deep…

取反符號~

取反符號 ~ 用于對整數進行按位取反操作。它會將二進制表示中的每一位取反&#xff0c;即 0 變 1&#xff0c;1 變 0。 示例 a 5 # 二進制表示為 0000 0101 b ~a # 按位取反&#xff0c;結果為 1111 1010&#xff08;補碼表示&#xff09; print(b) # 輸出 -6解釋 5 的二…

論文閱讀分享——UMDF(AAAI-24)

概述 題目&#xff1a;A Unified Self-Distillation Framework for Multimodal Sentiment Analysis with Uncertain Missing Modalities 發表&#xff1a;The Thirty-Eighth AAAI Conference on Artificial Intelligence (AAAI-24) 年份&#xff1a;2024 Github&#xff1a;暫…

WBC已形成“東亞-美洲雙中心”格局·棒球1號位

世界棒球經典賽&#xff08;WBC&#xff09;作為全球最高水平的國家隊棒球賽事&#xff0c;參賽隊伍按實力、地域和歷史表現可分為多個“陣營”。以下是基于歷屆賽事&#xff08;截至2023年&#xff09;的陣營劃分及代表性隊伍分析&#xff1a; 第一陣營&#xff1a;傳統豪強&a…

django中路由配置規則的詳細說明

在 Django 中,路由配置是將 URL 映射到視圖函數或類視圖的關鍵步驟,它決定了用戶請求的 URL 會觸發哪個視圖進行處理。以下將詳細介紹 Django 中路由配置的規則、高級使用方法以及多個應用配置的規則。 基本路由配置規則 1. 項目級路由配置 在 Django 項目中,根路由配置文…

【報錯】微信小程序預覽報錯”60001“

1.問題描述 我在微信開發者工具寫小程序時&#xff0c;使用http://localhost:8080是可以請求成功的&#xff0c;數據全都可以無報錯&#xff0c;但是點擊【預覽】&#xff0c;用手機掃描二維碼瀏覽時&#xff0c;發現前端圖片無返回且報錯60001&#xff08;打開開發者模式查看日…

柵格裁剪(Python)

在地理數據處理中&#xff0c;矢量裁剪柵格是一個非常重要的操作&#xff0c;它可以幫助我們提取感興趣的區域并獲得更精確的分析結果。其重要性包括&#xff1a; 區域限定&#xff1a;地球科學研究通常需要關注特定的地理區域。通過矢量裁剪柵格&#xff0c;我們可以將柵格數…

【無人機路徑規劃】基于麻雀搜索算法(SSA)的無人機路徑規劃(Matlab)

效果一覽 代碼獲取私信博主基于麻雀搜索算法&#xff08;SSA&#xff09;的無人機路徑規劃&#xff08;Matlab&#xff09; 一、算法背景與核心思想 麻雀搜索算法&#xff08;Sparrow Search Algorithm, SSA&#xff09;是一種受麻雀群體覓食行為啟發的元啟發式算法&#xff0…

MySQL數據庫安裝及基礎用法

安裝數據庫 第一步&#xff1a;下載并解壓mysql-8.4.3-winx64文件夾 鏈接: https://pan.baidu.com/s/1lD6XNNSMhPF29I2_HBAvXw?pwd8888 提取碼: 8888 第二步&#xff1a;打開文件中的my.ini文件 [mysqld]# 設置3306端口port3306# 自定義設置mysql的安裝目錄&#xff0c;即解…

軟件工程:軟件開發之需求分析

物有本末&#xff0c;事有終始。知所先后&#xff0c;則近道矣。對軟件開發而言&#xff0c;軟件需求乃重中之重。必先之事重千鈞&#xff0c;不可或缺如日辰。 汽車行業由于有方法論和各種標準約束&#xff0c;對軟件開發有嚴苛的要求。ASPICE指導如何審核軟件開發&#xff0…

正則表達式,idea,插件anyrule

????package lx;import java.util.regex.Pattern;public class lxx {public static void main(String[] args) {//正則表達式//寫一個電話號碼的正則表達式String regex "1[3-9]\\d{9}";//第一個數字是1&#xff0c;第二個數字是3-9&#xff0c;后面跟著9個數字…