采用自動微分進行模型的訓練

?自動微分訓練模型

?簡單代碼實現:

import torch
import torch.nn as nn
import torch.optim as optim# 定義一個簡單的線性回歸模型
class LinearRegression(nn.Module):def __init__(self):super(LinearRegression, self).__init__()self.linear = nn.Linear(1, 1)  # 輸入維度是1,輸出維度也是1def forward(self, x):return self.linear(x)# 準備訓練數據
x_train = torch.tensor([[1.0], [2.0], [3.0]])
y_train = torch.tensor([[2.0], [4.0], [6.0]])# 實例化模型、損失函數和優化器
model = LinearRegression()
criterion = nn.MSELoss()  # 均方誤差損失函數
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 隨機梯度下降優化器# 訓練模型
epochs = 1000
for epoch in range(epochs):# 前向傳播outputs = model(x_train)loss = criterion(outputs, y_train)# 反向傳播optimizer.zero_grad()  # 清空之前的梯度loss.backward()  # 自動計算梯度optimizer.step()  # 更新模型參數if (epoch+1) % 100 == 0:print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')# 測試模型
x_test = torch.tensor([[4.0]])
predicted = model(x_test)
print(f'預測值: {predicted.item():.4f}')

代碼分解:

1.定義一個簡單的線性回歸模型:

  • LinearRegression?類繼承自nn.Module,這是所有神經網絡模型的基類
  • 在?__init__?方法中,定義了一個線性層?self.linear,它的輸入維度是1,輸出維度也是1。
  • forward?方法定義了數據在模型中的傳播路徑,即輸入?x?經過?self.linear?層后得到輸出。
    class LinearRegression(nn.Module):def __init__(self):super(LinearRegression, self).__init__()self.linear = nn.Linear(1, 1)  # 輸入維度是1,輸出維度也是1def forward(self, x):return self.linear(x)
    

2.準備訓練數據:

  • x_train?和?y_train?分別是輸入和目標輸出的訓練數據。每個張量表示一個樣本,x_train?中的每個元素是一個維度為1的張量,因為模型的輸入維度是1。
    x_train = torch.tensor([[1.0], [2.0], [3.0]])
    y_train = torch.tensor([[2.0], [4.0], [6.0]])
    

3.實例化模型,損失函數和優化器:

  • model?是我們定義的?LinearRegression?類的一個實例,即我們要訓練的線性回歸模型。
  • criterion?是損失函數,這里選擇了均方誤差損失(MSE Loss),用于衡量預測值與實際值之間的差異。
  • optimizer?是優化器,這里選擇了隨機梯度下降(SGD),用于更新模型參數以最小化損失。
    model = LinearRegression()
    criterion = nn.MSELoss()  # 均方誤差損失函數
    optimizer = optim.SGD(model.parameters(), lr=0.01)  # 隨機梯度下降優化器
    

4.訓練模型:

  • 這里進行了1000次迭代的訓練過程。
  • 在每個迭代中,首先進行前向傳播,計算模型對?x_train?的預測輸出?outputs,然后計算損失?loss
  • 調用?optimizer.zero_grad()?來清空之前的梯度,然后調用?loss.backward()?自動計算梯度,最后調用?optimizer.step()?來更新模型參數
    epochs = 1000
    for epoch in range(epochs):# 前向傳播outputs = model(x_train)loss = criterion(outputs, y_train)# 反向傳播optimizer.zero_grad()  # 清空之前的梯度loss.backward()  # 自動計算梯度optimizer.step()  # 更新模型參數if (epoch+1) % 100 == 0:print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')
    

5.測試模型:

  • x_test?是用來測試模型的輸入數據,這里表示輸入為4.0。
  • model(x_test)?對?x_test?進行前向傳播,得到預測結果?predicted
  • predicted.item()?取出預測結果的標量值并打印出來。
    x_test = torch.tensor([[4.0]])
    predicted = model(x_test)
    print(f'預測值: {predicted.item():.4f}')
    

運行結果:

運行結果如下:

?

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/44963.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/44963.shtml
英文地址,請注明出處:http://en.pswp.cn/web/44963.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【Linux】數據流重定向

數據流重定向(redirect)由字面上的意思來看,好像就是將【數據給它定向到其他地方去】的樣子? 沒錯,數據流重定向就是將某個命令執行后應該要出現在屏幕上的數據,給它傳輸到其他的地方,例如文件或…

[圖解]企業應用架構模式2024新譯本講解26-層超類型2

1 00:00:00,510 --> 00:00:03,030 這個時候,如果再次查找所有人員 2 00:00:03,040 --> 00:00:03,750 我們會發現 3 00:00:05,010 --> 00:00:06,370 這一次所有的對象 4 00:00:06,740 --> 00:00:08,690 都是來自標識映射的 5 00:00:10,540 --> 00…

VB 上位機開發

VB 上位機開發第一節 在 VB(Visual Basic)上位機開發的第一節課程中涵蓋以下基礎內容: 一、上位機開發簡介 解釋上位機的概念和作用,它是與硬件設備進行通信和控制的軟件應用程序。舉例說明上位機在工業自動化、智能家居、監控系統等領域的應用。二、VB 開發環境介紹 展示如…

2024遼寧省數學建模C題【改性生物碳對水中洛克沙胂和砷離子的吸附】原創論文分享

大家好呀,從發布賽題一直到現在,總算完成了2024 年遼寧省大學數學建模競賽C題改性生物碳對水中洛克沙胂和砷離子的吸附完整的成品論文。 本論文可以保證原創,保證高質量。絕不是隨便引用一大堆模型和代碼復制粘貼進來完全沒有應用糊弄人的垃…

Rubber Duck Debugging: History and Benefits 橡皮鴨調試:歷史和優勢

注:機翻,未校對。 Discover the origins of rubber duck debugging, why it works, and why it has become so popular among programmers. 了解橡皮鴨調試的起源,它為什么有效,以及為什么它在程序員中如此受歡迎。 Debugging co…

AMD CPU加 vega 顯卡運行ollama本地大模型

顯卡是VEGA56,這個卡代號是gfx900 雖然ollama頁面上寫著這個卡可以,但是實際是不可以的 報錯如下: levelWARN sourceamd_windows.go:97 msg"amdgpu is not supported" gpu0 gpu_typegfx900:xnack 它認為的GPU型號是 gfx900:xna…

【JavaScript】解決 JavaScript 語言報錯:Uncaught SyntaxError: Unexpected identifier

文章目錄 一、背景介紹常見場景 二、報錯信息解析三、常見原因分析1. 缺少必要的標點符號2. 使用了不正確的標識符3. 關鍵詞拼寫錯誤4. 變量名與保留字沖突 四、解決方案與預防措施1. 檢查和添加必要的標點符號2. 使用正確的標識符3. 檢查關鍵詞拼寫4. 避免使用保留字作為變量名…

全棧 Discord 克隆:Next.js 13、React、Socket.io、Prisma、Tailwind、MySQL筆記(一)

前言 閱讀本文你需要有 Next.js 基礎 React 基礎 Prisma 基礎 tailwind 基礎 MySql基礎 準備工作 打開網站 https://ui.shadcn.com/docs 這不是一個組件庫。它是可重用組件的集合,您可以將其復制并粘貼到應用中。 打開installation 選擇Next.js 也就是此頁面…

Python3 第十七課 -- 編程第一步

目錄 一. 前言 二. end 關鍵字 一. 前言 在前面的教程中我們已經學習了一些 Python3 的基本語法知識,接下來我們來嘗試一些實例。 打印字符串: print("Hello, world!") 輸出結果為: Hello, world! 輸出變量值: i 256*256…

智慧校園服務監控功能

智慧校園系統中的服務監控功能,扮演著維護整個校園數字化生態系統穩定與高效運作的重要角色。它如同一位全天候的守護者,通過實時跟蹤、分析并響應系統各層面的運行狀況,確保教學、管理等核心業務流程的順暢進行。 服務監控功能覆蓋了智慧校園…

開發個人Ollama-Chat--6 OpenUI

開發個人Ollama-Chat–6 OpenUI Open-webui Open WebUI 是一種可擴展、功能豐富且用戶友好的自托管 WebUI,旨在完全離線運行。它支持各種 LLM 運行器,包括 Ollama 和 OpenAI 兼容的 API。 功能 由于總所周知的原由,OpenAI 的接口需要密鑰才…

知識圖譜與 LLM:微調與檢索增強生成

Midjourney 的知識圖譜聊天機器人的想法。 大型語言模型 (LLM) 的第一波炒作來自 ChatGPT 和類似的基于網絡的聊天機器人,這些模型在理解和生成文本方面非常出色,這讓人們(包括我自己)感到震驚。 我們中的許多人登錄并測試了它寫…

微信視頻號的視頻怎么下載到本地?快速教你下載視頻號視頻

天來說說市面上常見的微信視頻號視頻下載工具,教大家快速下載視頻號視頻! 方法一:緩存方法 該方法來源早期視頻技術,因早期無法將大量視頻通過網絡存儲,故而會有緩存視頻文件到手機,其目的為了提高用戶體驗…

尚硅谷Vue3入門到實戰,最新版vue3+TypeScript前端開發教程

Vue3 編碼規范 創建vue3工程 基于vite創建 快速上手 | Vue.js (vuejs.org) npm create vuelatest 在nodejs環境下運行進行創建 按提示進行創建 用vscode打開項目 安裝依賴 源文件有src 內有main.ts App.vue 簡單分析 編寫src vue2語法在三中適用 vue2中的date metho…

UnityECS學習中問題及總結entityQuery.ToComponentDataArray和entityQuery.ToEntityArray區別

在Unity的ECS&#xff08;Entity Component System&#xff09;開發中&#xff0c;entityQuery.ToComponentDataArray<T>(Allocator.Temp) 和 entityQuery.ToEntityArray(Allocator.Temp) 是兩種不同的方法&#xff0c;用于從實體查詢中獲取數據。除了泛型參數之外&#…

【深度學習入門篇 ⑤ 】PyTorch網絡模型創建

【&#x1f34a;易編橙&#xff1a;一個幫助編程小伙伴少走彎路的終身成長社群&#x1f34a;】 大家好&#xff0c;我是小森( &#xfe61;?o?&#xfe61; ) &#xff01; 易編橙終身成長社群創始團隊嘉賓&#xff0c;橙似錦計劃領銜成員、阿里云專家博主、騰訊云內容共創官…

git、huggingface 學術加速

1、git 有時候服務器不能直接訪問 github&#xff0c;下載代碼會很麻煩&#xff1b;安裝庫的時候&#xff0c;pip xx git 就更難了 比如&#xff0c;這次我需要安裝 unsloth&#xff0c;官方給出的腳本是&#xff1a; pip install “unsloth[cu121-torch220] githttps://git…

【python】函數重構

函數重構 函數重構pycharm函數重構步驟函數重構練習 函數重構 函數重構是指對現有函數進行修改和優化的過程。重構的目的是改善代碼的可讀性、可維護性和靈活性&#xff0c;同時保持其功能不變。函數重構通常包括以下步驟&#xff1a; 理解函數的功能和目的。了解函數的作用和…

OSPF.綜合實驗

1、首先將各個網段基于172.16.0.0 16 進行劃分 1.1、劃分為4個大區域 172.16.0.0 18 172.16.64.0 18 172.16.128.0 18 172.16.192.0 18 四個網段 劃分R4 劃分area2 劃分area3 劃分area1 2、進行IP配置 如圖使用配置指令進行配置 ip address x.x.x.x /x 并且將缺省路由…

Sortable.js板塊拖拽示例

圖例 代碼在圖片后面 點贊??關注&#x1f64f;收藏?? 頁面加載后顯示 拖拽效果 源代碼 由于js庫使用外鏈&#xff0c;所以會加載一會兒 <!DOCTYPE html> <html lang"en"> <head> <meta charset"UTF-8"> <meta name&qu…