《PyTorch深度學習實踐》第八講加載數據集

一、

1、DataSet 是抽象類,不能實例化對象,主要是用于構造我們的數據集

2、DataLoader 需要獲取DataSet提供的索引[i]和len;用來幫助我們加載數據,比如說做shuffle(提高數據集的隨機性),batch_size,能拿出Mini-Batch進行訓練。它幫我們自動完成這些工作。DataLoader可實例化對象。DataLoader is a class to help us loading data in Pytorch.

3、__getitem__目的是為支持下標(索引)操作
?

二、

import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader# prepare datasetclass DiabetesDataset(Dataset):def __init__(self, filepath):xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)self.len = xy.shape[0] # shape(多少行,多少列)self.x_data = torch.from_numpy(xy[:, :-1])self.y_data = torch.from_numpy(xy[:, [-1]])def __getitem__(self, index):return self.x_data[index], self.y_data[index]def __len__(self):return self.lendataset = DiabetesDataset('diabetes.csv')
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=0) #num_workers 多線程# design model using classclass Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.linear1 = torch.nn.Linear(8, 6)self.linear2 = torch.nn.Linear(6, 4)self.linear3 = torch.nn.Linear(4, 1)self.sigmoid = torch.nn.Sigmoid()def forward(self, x):x = self.sigmoid(self.linear1(x))x = self.sigmoid(self.linear2(x))x = self.sigmoid(self.linear3(x))return xmodel = Model()# construct loss and optimizer
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)# training cycle forward, backward, update
if __name__ == '__main__':for epoch in range(100):for i, data in enumerate(train_loader, 0): # train_loader 是先shuffle后mini_batchinputs, labels = datay_pred = model(inputs)loss = criterion(y_pred, labels)print(epoch, i, loss.item())optimizer.zero_grad()loss.backward()optimizer.step()

1、需要mini_batch 就需要import DataSet和DataLoader

2、繼承DataSet的類需要重寫init,getitem,len魔法函數。分別是為了加載數據集,獲取數據索引,獲取數據總量。

3、DataLoader對數據集先打亂(shuffle),然后劃分成mini_batch。

4、len函數的返回值 除以 batch_size 的結果就是每一輪epoch中需要迭代的次數。

5、inputs, labels = data中的inputs的shape是[32,8],labels 的shape是[32,1]。也就是說mini_batch在這個地方體現的

6、diabetes.csv數據集老師給了下載地址,該數據集需和源代碼放在同一個文件夾內。

問題:loss沒有收斂

網友解決:

做了兩個實驗:(1)輸出每批次的loss,不收斂,loss在0.6上下浮動(2)每個epoch都不分批,把所有樣本都輸入,收斂,最后結果在0.6附近。所以猜測:小樣本之間的loss差距相對于0.6而言有點大,所以看著像是沒收斂,實際上從總loss來看已經收斂了

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/712642.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/712642.shtml
英文地址,請注明出處:http://en.pswp.cn/news/712642.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Windows10環境下MongoDB安裝配置

1. 下載對應MongoDB安裝包 進入官網:MongoDB官網 如果不連接外網則在官網下載較慢,這里給出下載好的安裝包,版本為4.2.25:百度網盤 選擇你需要的版本,推薦選擇Package的格式為zip(解壓即可) Pa…

[VNCTF2024]-PWN:preinit解析(逆向花指令,繞過strcmp,函數修改,機器碼)

查看保護: 查看ida: 這邊其實看反匯編沒啥大作用,需要自己動調。 但是前面的繞過strcmp還是要看一下的。 解題: 這里是用linux自帶的產生隨機數的文件urandom來產生一個隨機密碼,然后讓我們輸入密碼,用st…

k8s 存儲卷詳解與動靜部署詳解

目錄 一、Volume 卷 1.1 卷類型 emptyDir : hostPath: persistentVolumeClaim (PVC): configMap 和 secret: 二、 emptyDir存儲卷 2.1 特點 2.2 用途: 2.3 示例 三、 hostPath存儲卷 3.1 特點 3.2 用途 …

前端mock數據 —— 使用Apifox mock頁面所需數據

前端mock數據 —— 使用Apifox 一、使用教程二、本地請求Apifox所mock的接口 一、使用教程 在首頁進行新建項目: 新建項目名稱: 新建接口: 創建json: 請求方法: GET。URL: api/basis。響應類型&#xff1…

可以用numpy為for加速

Numpy除了用于科學計算,還有一個功能是可以代替某些for循環,進行同樣的功能實現,有于是向量矩陣運算,碰到復雜的for時,計算速度可以提高,從而提高程序性能。以下是一些常用的NumPy函數和操作,可…

Socket網絡編程(六)——簡易聊天室案例

目錄 聊天室數據傳輸設計客戶端、服務器數據交互數據傳輸協議服務器、多客戶端模型客戶端如何發送消息到另外一個客戶端2個以上設備如何交互數據? 聊天室消息接收實現代碼結構client客戶端重構server服務端重構自身描述信息的構建重構TCPServer.java基于synchronize…

Nginx多次代理后獲取真實的用戶IP訪問地址

需求:記錄用戶操作記錄,類似如下表格的這樣 PS: 注意無論你的服務是Http訪問還是Https 訪問的都是可以的,我們服務之前是客戶只給開放了一個端口,但是既要支持https又要支持http協議,nginx 是可以通過stream 模塊配置雙…

2023中國PostgreSQL數據庫生態大會:洞察前沿趨勢,探索無限可能(附核心PPT資料下載)

隨著數字化浪潮的推進,數據庫技術已成為支撐各行各業數字化轉型的核心力量。2023中國PostgreSQL數據庫生態大會的召開,無疑為業界提供了一個深入交流、共同探索PostgreSQL數據庫技術未來發展趨勢的平臺。本文將帶您走進這場盛會,解析大會的亮…

k8s Pod基礎(概念,容器功能及分類,鏡像拉取和容器重啟策略)

目錄 pod概念 Kubernetes設計Pod概念和特殊組成結構的用意 Pod內部結構: 網絡共享: 存儲共享: pause容器主要功能 pod創建方式 pod使用方式 pod分類 pod的容器分類 基礎容器(infrastructure container)&…

加密和簽名的區別及應用場景

原文網址:加密和簽名的區別及應用場景_IT利刃出鞘的博客-CSDN博客 簡介 本文介紹加密和簽名的區別及應用場景。 RSA是一種非對稱加密算法, 可生成一對密鑰(私鑰和公鑰)。(RSA可以同時支持加密和簽名)。 …

元宇宙3D虛擬場景制作深圳華銳視點免費試用

隨著元宇宙興起,3D線上展廳得到了越來越多的關注和應用。基于VR虛擬現實技術的元宇宙3D線上展廳在線編輯系統,更是為企業在展覽展示領域帶來了前所未有的輔助。 高效便捷: 元宇宙3D線上展廳在線編輯無需復雜的施工和搭建過程,只需…

報錯問題解決django.db.utils.OperationalError: (1049, “Unknown database ‘ mxshop‘“)

開發環境:ubuntu22.04 pycharm 功能:django連接使用mysql數據庫,各項配置看似正常 報錯: django.db.utils.OperationalError: (1049, "Unknown database mxshop") 分析檢查原因: Setting的配置文件內&…

gcd+線性dp,[藍橋杯 2018 國 B] 矩陣求和

一、題目 1、題目描述 經過重重筆試面試的考驗,小明成功進入 Macrohard 公司工作。 今天小明的任務是填滿這么一張表: 表有 �n 行 �n 列,行和列的編號都從 11 算起。 其中第 �i 行第 �j 個元素…

GRPC 錯誤碼表

code數描述OK0不是錯誤;成功返回。CANCELLED1操作通常由調用方取消。UNKNOWN2未知錯誤。例如,當從另一個地址空間接收的值屬于此地址空間中未知的錯誤空間時,可能會返回此錯誤。此外,未返回足夠錯誤信息的 API 引發的錯誤可能會轉換為此錯誤。…

ggplot去除背景

在ggplot2中去除背景,通常指的是去除圖表的灰色背景和網格線,使圖表背景變為透明或白色,以及去除或簡化坐標軸的背景。這可以通過調整主題(theme)來實現。ggplot2提供了多種主題設置,可以用來調整圖表的外觀…

Spring MVC 和 Spring Cloud Gateway不兼容性問題

當啟動SpringCloudGateway網關服務的時候,沒注意好依賴問題,出現了這個問題: Spring MVC found on classpath, which is incompatible with Spring Cloud Gateway. 解決辦法就是:刪除SpringMVC的依賴,即下列依賴。 &…

ChatGPT/GPT4科研應用與AI繪圖及論文高效寫作

原文:ChatGPT/GPT4科研應用與AI繪圖及論文高效寫作 第一:2024年AI領域最新技術 1.OpenAI新模型-GPT-5 2.谷歌新模型-Gemini Ultra 3.Meta新模型-LLama3 4.科大訊飛-星火認知 5.百度-文心一言 6.MoonshotAI-Kimi 7.智譜AI-GLM-4 第二:…

【C++從0到王者】第四十六站:圖的深度優先與廣度優先

文章目錄 一、圖的遍歷二、廣度優先遍歷1.思想2.算法實現3.六度好友 三、深度優先遍歷1.思想2.代碼實現 四、其他問題 一、圖的遍歷 對于圖而言,我們的遍歷一般是遍歷頂點,而不是邊,因為邊的遍歷是比較簡單的,就是鄰接矩陣或者鄰接…

《匯編語言》第3版 (王爽)檢測點3.1解析

第三章 檢測點3.1 (1).在Debug中,用“d 0:0 1f”查看內存,結果如下。 下面的程序執行前,AX 0,BX 0,寫出每條匯編指令執行完后相關寄存器中的值。 mov ax,1 ;將1放入AX寄存器中,…

GC如何判定對象已死

GC判定對象已死的2種方法 引用計數法 給對象中添加一個引用計數器,每當有一個地方引用它時,計數器值就加1;當引用失效時,計數器值就減1;Java語言中沒有選用引用計數算法來管理內存,其中最主要的原因是它很…