神經網絡代碼入門解析

神經網絡代碼入門解析

import torch
import matplotlib.pyplot as pltimport randomdef create_data(w, b, data_num):  # 數據生成x = torch.normal(0, 1, (data_num, len(w)))y = torch.matmul(x, w) + b  # 矩陣相乘再加bnoise = torch.normal(0, 0.01, y.shape)  # 為y添加噪聲y += noisereturn x, ynum = 500true_w = torch.tensor([8.1, 2, 2, 4])
true_b = 1.1X, Y = create_data(true_w, true_b, num)# plt.scatter(X[:, 3], Y, 1)  # 畫散點圖 對X取全部的行的第三列,標簽Y,點大小
# plt.show()def data_provider(data, label, batchsize):  # 每次取batchsize個數據length = len(label)indices = list(range(length))# 這里需要把數據打亂random.shuffle(indices)for each in range(0, length, batchsize):get_indices = indices[each: each+batchsize]get_data = data[get_indices]get_label = label[get_indices]yield get_data, get_label  # 有存檔點的returnbatchsize = 16
# for batch_x, batch_y in data_provider(X, Y, batchsize):
#     print(batch_x, batch_y)
#     break# 定義模型
def fun(x, w, b):pred_y = torch.matmul(x, w) + breturn pred_y# 定義loss
def maeLoss(pre_y, y):return torch.sum(abs(pre_y-y))/len(y)# sgd(梯度下降)
def sgd(paras, lr):with torch.no_grad():  # 這部分代碼不計算梯度for para in paras:para -= para.grad * lr  # 不能寫成 para = para - paras.grad * lr !!!! 這句相當于要創建一個新的para,會導致報錯para.grad.zero_()  # 將使用過的梯度歸零lr = 0.01
w_0 = torch.normal(0, 0.01, true_w.shape, requires_grad=True)
b_0 = torch.tensor(0.01, requires_grad=True)
print(w_0, b_0)epochs = 50
for epoch in range(epochs):data_loss = 0for batch_x, batch_y in data_provider(X, Y, batchsize):pred_y = fun(batch_x, w_0, b_0)loss = maeLoss(pred_y, batch_y)loss.backward()sgd([w_0, b_0], lr)data_loss += lossprint("epoch %03d: loss: %.6f" % (epoch, data_loss))print("真實函數值:", true_w, true_b)
print("訓練得到的函數值:", w_0, b_0)idx = 0
plt.plot(X[:, idx].detach().numpy(), X[:, idx].detach().numpy()*w_0[idx].detach().numpy()+b_0.detach().numpy())
plt.scatter(X[:, idx].detach().numpy(), Y, 1)
plt.show()

逐步分析代碼

1.數據生成

image-20250301120222530

首先設計一個函數create_data,提供我們所需要的數據集的x與y

def create_data(w, b, data_num):  # 數據生成x = torch.normal(0, 1, (data_num, len(w)))  # 生成特征數據,形狀為 (data_num, len(w))y = torch.matmul(x, w) + b  # 計算目標值 y = x * w + bnoise = torch.normal(0, 0.01, y.shape)  # 生成噪聲,形狀與 y 相同y += noise  # 為 y 添加噪聲,模擬真實數據中的隨機誤差return x, y
  • torch.normal() 生成一個張量

    • torch.normal(0, 1, (data_num, len(w))):生成一個形狀為 (data_num, len(w)) 的張量,其中的元素是從均值為 0、標準差為 1 的正態分布中隨機采樣的。
  • torch.matmul() 讓矩陣相乘

    matmul: matrix multiply

  • 再使用torch.normal()生成一個張量,添加到y上,相當于為y添加了隨機的噪聲

    噪聲的引入是為了模擬真實數據中的隨機誤差,使生成的數據更接近現實場景。

2.設計一個數據加載器

def data_provider(data, label, batchsize):  # 每次取 batchsize 個數據length = len(label)indices = list(range(length))random.shuffle(indices)  # 打亂數據順序,避免模型學習到順序特征for each in range(0, length, batchsize):get_indices = indices[each: each+batchsize]  # 獲取當前批次的索引get_data = data[get_indices]  # 獲取當前批次的數據get_label = label[get_indices]  # 獲取當前批次的標簽yield get_data, get_label  # 返回當前批次的數據和標簽

data_provider可以分批提供數據,并通過yield來返回已實現記憶功能

首先把list y順序打亂,這樣就相當于從生成的訓練集y中隨機讀取,若不打亂數據,可能造成訓練結果的不理想

打亂數據可以避免模型在訓練過程中學習到數據的順序特征,從而提高模型的泛化能力。

之后分段遍歷打亂的y,返回對應的局部的數據集來給神經網絡進行訓練

3.定義模型函數

image-20250301122853184

def fun(x, w, b):pred_y = torch.matmul(x, w) + b  # 計算預測值 y = x * w + breturn pred_y

fun(x, w, b) 是一個線性模型,形式為 y = x * w + b,其中 x 是輸入特征,w 是權重,b 是偏置。

4.定義Loss函數

image-20250301122958888

def maeLoss(pre_y, y):return torch.sum(abs(pre_y - y)) / len(y)  # 計算平均絕對誤差 (MAE)
  • maeLoss 是平均絕對誤差(Mean Absolute Error, MAE),它計算預測值 pre_y 和真實值 y 之間的絕對誤差的平均值。
  • 公式為:MAE = (1/n) * Σ|pre_y - y|,其中 n 是樣本數量。

5.梯度下降sgd函數

# sgd(梯度下降)
def sgd(paras, lr):with torch.no_grad():  # 這部分代碼不計算梯度for para in paras:para -= para.grad * lr  # 不能寫成 para = para - paras.grad * lr !!!! 這句相當于要創建一個新的para,會導致報錯para.grad.zero_()  # 將使用過的梯度歸零

這里需要使用torch.no_grad()來避免重復計算梯度

image-20250301123531781

在前向過程中已經累計過一次梯度了,如果在梯度下降過程中又累計了梯度,那么就會造成不必要的麻煩

PyTorch 會累積梯度,如果不手動清零,梯度會不斷累積,導致參數更新錯誤。

para -= para.grad * lr就是將參數w修正的過程(w=w-(dy^/dw)*learningRate)

torch.no_grad() 是一個上下文管理器,用于禁用梯度計算。在參數更新時,禁用梯度計算可以避免不必要的計算和內存占用。

5.開始訓練

epochs = 50
for epoch in range(epochs):data_loss = 0num_batches = len(Y) // batchsize  # 計算批次數量for batch_x, batch_y in data_provider(X, Y, batchsize):pred_y = fun(batch_x, w_0, b_0)  # 前向傳播loss = maeLoss(pred_y, batch_y)  # 計算損失loss.backward()  # 反向傳播sgd([w_0, b_0], lr)  # 更新參數data_loss += loss.item()  # 累積損失print("epoch %03d: loss: %.6f" % (epoch, data_loss / num_batches))  # 打印平均損失

先定義一個訓練輪次epochs=50,表示訓練50輪

在每輪訓練中將loss記錄下來,以此評價訓練的效果

首先用data_provider來獲取數據集中隨機的一部分

接著傳入相應數據給模型函數,通過前向傳播獲得預測y值pred_y

調用Loss計算函數,獲取這次的loss,再通過反向傳播loss.backward()計算梯度

loss.backward() 是反向傳播的核心步驟,用于計算損失函數對模型參數的梯度。

再通過梯度下降sgd([w_0, b_0], lr)來更新模型的參數

最終將這組數據的loss累加到這輪數據的loss中

6.結果繪制

idx = 0
plt.plot(X[:, idx].detach().numpy(), X[:, idx].detach().numpy() * w_0[idx].detach().numpy() + b_0.detach().numpy())  # 繪制預測直線
plt.scatter(X[:, idx].detach().numpy(), Y, 1)  # 繪制真實數據點
plt.show()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/71113.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/71113.shtml
英文地址,請注明出處:http://en.pswp.cn/web/71113.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

DeepSeek 開源狂歡周(一)FlashMLA:高效推理加速新時代

上周末,DeepSeek在X平臺(Twitter)宣布將開啟連續一周的開源,整個開源社區為之沸騰,全球AI愛好者紛紛為關注。沒錯,這是一場由DeepSeek引領的開源盛宴,推翻了傳統推理加速的種種限制。這周一&…

EfficientViT模型詳解及代碼復現

核心架構 在EfficientViT模型的核心架構中,作者設計了一種創新的 sandwich布局 作為基礎構建塊,旨在提高內存效率和計算效率。這種布局巧妙地平衡了自注意力層和前饋神經網絡層的比例,具體結構如下: 基于深度卷積的Token Interaction :通過深度卷積操作對輸入特征進行初步…

大語言模型(LLM)如何賦能時間序列分析?

引言 近年來,大語言模型(LLM)在文本生成、推理和跨模態任務中展現了驚人能力。與此同時,時間序列分析作為工業、金融、物聯網等領域的核心技術,長期依賴傳統統計模型(如ARIMA)或深度學習模型&a…

Java 設計模式:軟件開發的精髓與藝

目錄 一、設計模式的起源二、設計模式的分類1. 創建型模式2. 結構型模式3. 行為型模式三、設計模式的實踐1. 單例模式2. 工廠模式3. 策略模式四、設計模式的優勢五、設計模式的局限性六、總結在軟件開發的浩瀚星空中,設計模式猶如一顆顆璀璨的星辰,照亮了開發者前行的道路。它…

【基于Raft的KV共識算法】-序:Raft概述

本文目錄 1.為什么會有Raft?CAP理論 2.Raft基本原理流程為什么要以日志作為中間載體? 3.實現思路任期領導選舉日志同步 1.為什么會有Raft? 簡單來說就是數據會隨著業務和時間的增長,單機不能存的下,這個時候需要以某種…

【愚公系列】《Python網絡爬蟲從入門到精通》040-Matplotlib 概述

標題詳情作者簡介愚公搬代碼頭銜華為云特約編輯,華為云云享專家,華為開發者專家,華為產品云測專家,CSDN博客專家,CSDN商業化專家,阿里云專家博主,阿里云簽約作者,騰訊云優秀博主,騰訊云內容共創官,掘金優秀博主,亞馬遜技領云博主,51CTO博客專家等。近期榮譽2022年度…

EasyRTC嵌入式WebRTC技術與AI大模型結合:從ICE框架優化到AI推理

實時通信技術在現代社會中扮演著越來越重要的角色,從視頻會議到在線教育,再到遠程醫療,其應用場景不斷拓展。WebRTC作為一項開源項目,為瀏覽器和移動應用提供了便捷的實時通信能力。而EasyRTC作為基于WebRTC的嵌入式解決方案&…

javaEE初階————多線程初階(5)

本期是多線程初階的最后一篇文章了,下一篇就是多線程進階的文章了,大家加油! 一,模擬實現線程池 我們上期說過線程池類似一個數組,我們有任務就放到線程池中,讓線程池幫助我們完成任務,我們該如…

工業AR眼鏡的‘芯’動力:FPC讓制造更智能【新立電子】

隨著增強現實(AR)技術的快速發展,工業AR智能眼鏡也正逐步成為制造業領域的重要工具。它不僅為現場工作人員提供了視覺輔助,還極大地提升了遠程協助的效率、優化了倉儲管理。FPC在AI眼鏡中的應用,為工業AR智能眼鏡提供了…

FPGA開發,使用Deepseek V3還是R1(5):temperature設置

以下都是Deepseek生成的答案 FPGA開發,使用Deepseek V3還是R1(1):應用場景 FPGA開發,使用Deepseek V3還是R1(2):V3和R1的區別 FPGA開發,使用Deepseek V3還是R1&#x…

網站內容更新后百度排名下降怎么辦?有效策略有哪些?

轉自 網站內容更新后百度排名下降怎么辦?有效策略有哪些? 網站內容更新是促進網站優化的關鍵環節,但是頻繁修改網站內容會對網站的搜索引擎排名造成很大的影響。為了保持網站排名,我們需要采取一些措施來最小化對百度排名的影響。…

安裝 cpolar 內網穿透工具的步驟

安裝 cpolar 內網穿透工具的步驟 1. 下載 cpolar 軟件安裝包 步驟: 前往 cpolar 官方下載頁面。 根據您的操作系統(Windows、macOS、Linux 等),選擇對應的安裝包進行下載。 2. 注冊 cpolar 賬號 步驟: 訪問 cpolar…

Linux :進程狀態

目錄 1 引言 2 操作系統的資源分配 3進程狀態 3.1運行狀態 3.2 阻塞狀態 3.3掛起狀態 4.進程狀態詳解 4.1 運行狀態R 4.2 休眠狀態S 4.3深度睡眠狀態D 4.4僵尸狀態Z 5 孤兒進程 6 進程優先級 其他概念 1 引言 🌻在前面的文章中,我們已…

openwebUI訪問vllm加載deepseek微調過的本地大模型

文章目錄 前言一、openwebui安裝二、配置openwebui環境三、安裝vllm四、啟動vllm五、啟動openwebui 前言 首先安裝vllm,然后加載本地模型,會起一個端口好。 在安裝openwebui,去訪問這個端口號。下面具體步驟的演示。 一、openwebui安裝 rootautodl-co…

DeepSeek-V3:AI語言模型的高效訓練與推理之路

參考:【論文學習】DeepSeek-V3 全文翻譯 在人工智能領域,語言模型的發展日新月異。從早期的簡單模型到如今擁有數千億參數的巨無霸模型,技術的進步令人矚目。然而,隨著模型規模的不斷擴大,訓練成本和推理效率成為了擺在…

Spring單例模式 Spring 中的單例 餓漢式加載 懶漢式加載

目錄 核心特性 實現方式詳解 1. 餓漢式(Eager Initialization) 2. 懶漢式(Lazy Initialization) 3. 靜態內部類(Bill Pugh 實現) 4. 枚舉(Enum) 破壞單例的場景及防御 Sprin…

DeepSeek MLA(Multi-Head Latent Attention)算法淺析

目錄 前言1. 從MHA、MQA、GQA到MLA1.1 MHA1.2 瓶頸1.3 MQA1.4 GQA1.5 MLA1.5.1 Part 11.5.2 Part 21.5.3 Part 3 結語參考 前言 學習 DeepSeek 中的 MLA 模塊,究極縫合怪,東抄抄西抄抄,主要 copy 自蘇神的文章,僅供自己參考&#…

uniapp 中引入使用uView UI

文章目錄 一、前言:選擇 uView UI的原因二、完整引入步驟1. 安裝 uView UI2. 配置全局樣式變量(關鍵!)3. 在 pages.json中添加:4. 全局注冊組件5. 直接使用組件 五、自定義主題色(秒換皮膚) 一、…

zookeeper-docker版

Zookeeper-docker版 1 zookeeper概述 1.1 什么是zookeeper Zookeeper是一個分布式的、高性能的、開源的分布式系統的協調(Coordination)服務,它是一個為分布式應用提供一致性服務的軟件。 1.2 zookeeper應用場景 zookeeper是一個經典的分…

【量化金融自學筆記】--開篇.基本術語及學習路徑建議

在當今這個信息爆炸的時代,金融領域正經歷著一場前所未有的變革。傳統的金融分析方法逐漸被更加科學、精準的量化技術所取代。量化金融,這個曾經高不可攀的領域,如今正逐漸走進大眾的視野。它將數學、統計學、計算機科學與金融學深度融合&…