Pytorch高階API示范——線性回歸模型

本文與《20天吃透Pytorch》有所不同,《20天吃透Pytorch》中是繼承之前的模型進行擬合,本文是單獨建立網絡進行擬合。

代碼實現:

import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader,TensorDataset"""
1.準備數據
"""
n=800   #樣本數量#生成測試用的數據集
X = 10*torch.rand([n,2])-5.0    #torch.rand是均勻分布
w0 = torch.tensor([[2.0],[-3.0]])
b0 = torch.tensor([10.0])
Y = X@w0 + b0 + torch.normal(0.0,2.0,size=[n,1])    ## @表示矩陣乘法,增加正態擾動#數據可視化
plt.figure(figsize= (12,5))
ax1 = plt.subplot(121)
ax1.scatter(X[:,0],Y[:,0],c = 'b',label = 'samples')
ax1.legend()    #圖例
plt.xlabel("x1")
plt.ylabel("y",rotation = 0)
ax2 = plt.subplot(122)
ax2.scatter(X[:,1],Y[:,0],c = 'g',label = 'samples')
ax2.legend()
plt.xlabel('x2')
plt.ylabel('y',rotation = 0)
plt.show()"""
構建通道
"""ds = TensorDataset(X,Y)
ds_train,ds_valid = torch.utils.data.random_split(ds,[int (n*0.7),n-int(n*0.7)])  #選取總樣本的70%為訓練數據
dl_train = DataLoader(ds_train,batch_size=10,shuffle=True)
dl_valid = DataLoader(ds_valid,batch_size=10,shuffle=True)"""
2.定義模型
"""class LinearRegression(torch.nn.Module):def __init__(self):super(LinearRegression, self).__init__()self.fc = nn.Linear(2,1)def forward(self,x):x = self.fc(x)return xnet = LinearRegression()
"""
3.訓練模型
"""
loss_func = torch.nn.MSELoss()
optimizer= torch.optim.Adam(net.parameters(),lr = 0.01)eporchs = 10
log_step_freq = 20for eporch in range(1,eporchs+1):net.train()loss_sum = 0.0metric_sum = 0.0step = 1for step,(features,labels) in enumerate(dl_train,1):predictions = net(features)loss = loss_func(predictions,labels)optimizer.zero_grad()loss.backward()optimizer.step()w = net.state_dict()["fc.weight"]b = net.state_dict()["fc.bias"]print("step =", step, "loss = ", loss)print("w =", w)print("b =", b)loss_sum += loss.item()"""
結果可視化
"""
w,b = net.state_dict()["fc.weight"],net.state_dict()["fc.bias"]plt.figure(figsize = (12,5))
ax1 = plt.subplot(121)
ax1.scatter(X[:,0],Y[:,0], c = "b",label = "samples")
ax1.plot(X[:,0],w[0,0]*X[:,0]+b[0],"-r",linewidth = 5.0,label = "model")
ax1.legend()
plt.xlabel("x1")
plt.ylabel("y",rotation = 0)ax2 = plt.subplot(122)
ax2.scatter(X[:,1],Y[:,0], c = "g",label = "samples")
ax2.plot(X[:,1],w[0,1]*X[:,1]+b[0],"-r",linewidth = 5.0,label = "model")
ax2.legend()
plt.xlabel("x2")
plt.ylabel("y",rotation = 0)plt.show()

結果展示:

數據部分:

在這里插入圖片描述

線性回歸結果:

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/389381.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/389381.shtml
英文地址,請注明出處:http://en.pswp.cn/news/389381.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

vue 上傳圖片限制大小和格式

<div class"upload-box clear"><span class"fl">上傳圖片</span><div class"artistDet-logo-box fl"><el-upload :action"this.baseServerUrl/fileUpload/uploadPic?filepathartwork" list-type"pic…

作業要求 20181023-3 每周例行報告

本周要求參見&#xff1a;https://edu.cnblogs.com/campus/nenu/2018fall/homework/2282 1、本周PSP 總計&#xff1a;927min 2、本周進度條 代碼行數 博文字數 用到的軟件工程知識點 217 757 PSP、版本控制 3、累積進度圖 &#xff08;1&#xff09;累積代碼折線圖 &…

算命數據_未來的數據科學家或算命精神向導

算命數據Real Estate Sale Prices, Regression, and Classification: Data Science is the Future of Fortune Telling房地產銷售價格&#xff0c;回歸和分類&#xff1a;數據科學是算命的未來 As we all know, I am unusually blessed with totally-real psychic abilities.眾…

openai-gpt_為什么到處都看到GPT-3?

openai-gptDisclaimer: My opinions are informed by my experience maintaining Cortex, an open source platform for machine learning engineering.免責聲明&#xff1a;我的看法是基于我維護 機器學習工程的開源平臺 Cortex的 經驗而 得出 的。 If you frequent any part…

Pytorch高階API示范——DNN二分類模型

代碼部分&#xff1a; import numpy as np import pandas as pd from matplotlib import pyplot as plt import torch from torch import nn import torch.nn.functional as F from torch.utils.data import Dataset,DataLoader,TensorDataset""" 準備數據 &qu…

OO期末總結

$0 寫在前面 善始善終&#xff0c;臨近期末&#xff0c;為一學期的收獲和努力畫一個圓滿的句號。 $1 測試與正確性論證的比較 $1-0 什么是測試&#xff1f; 測試是使用人工操作或者程序自動運行的方式來檢驗它是否滿足規定的需求或弄清預期結果與實際結果之間的差別的過程。 它…

puppet puppet模塊、file模塊

轉載&#xff1a;http://blog.51cto.com/ywzhou/1577356 作用&#xff1a;通過puppet模塊自動控制客戶端的puppet配置&#xff0c;當需要修改客戶端的puppet配置時不用在客戶端一一設置。 1、服務端配置puppet模塊 &#xff08;1&#xff09;模塊清單 [rootpuppet ~]# tree /et…

數據可視化及其重要性:Python

Data visualization is an important skill to possess for anyone trying to extract and communicate insights from data. In the field of machine learning, visualization plays a key role throughout the entire process of analysis.對于任何試圖從數據中提取和傳達見…

熊貓數據集_熊貓邁向數據科學的第三部分

熊貓數據集Data is almost never perfect. Data Scientist spend more time in preprocessing dataset than in creating a model. Often we come across scenario where we find some missing data in data set. Such data points are represented with NaN or Not a Number i…

Pytorch有關張量的各種操作

一&#xff0c;創建張量 1. 生成float格式的張量: a torch.tensor([1,2,3],dtype torch.float)2. 生成從1到10&#xff0c;間隔是2的張量: b torch.arange(1,10,step 2)3. 隨機生成從0.0到6.28的10個張量 注意&#xff1a; (1).生成的10個張量中包含0.0和6.28&#xff…

mongodb安裝失敗與解決方法(附安裝教程)

安裝mongodb遇到的一些坑 浪費了大量的時間 在此記錄一下 主要是電腦系統win10企業版自帶的防火墻 當然還有其他的一些坑 一般的問題在第6步驟都可以解決&#xff0c;本教程的安裝步驟不夠詳細的話 請自行百度或谷歌 安裝教程很多 我是基于node.js使用mongodb結合Robo 3T數…

【洛谷算法題】P1046-[NOIP2005 普及組] 陶陶摘蘋果【入門2分支結構】Java題解

&#x1f468;?&#x1f4bb;博客主頁&#xff1a;花無缺 歡迎 點贊&#x1f44d; 收藏? 留言&#x1f4dd; 加關注?! 本文由 花無缺 原創 收錄于專欄 【洛谷算法題】 文章目錄 【洛谷算法題】P1046-[NOIP2005 普及組] 陶陶摘蘋果【入門2分支結構】Java題解&#x1f30f;題目…

web性能優化(理論)

什么是性能優化&#xff1f; 就是讓用戶感覺你的網站加載速度很快。。。哈哈哈。 分析 讓我們來分析一下從用戶按下回車鍵到網站呈現出來經歷了哪些和前端相關的過程。 緩存 首先看本地是否有緩存&#xff0c;如果有符合使用條件的緩存則不需要向服務器發送請求了。DNS查詢建立…

python多項式回歸_如何在Python中實現多項式回歸模型

python多項式回歸Let’s start with an example. We want to predict the Price of a home based on the Area and Age. The function below was used to generate Home Prices and we can pretend this is “real-world data” and our “job” is to create a model which wi…

充分利用UC berkeleys數據科學專業

By Kyra Wong and Kendall Kikkawa黃凱拉(Kyra Wong)和菊川健多 ( Kendall Kikkawa) 什么是“數據科學”&#xff1f; (What is ‘Data Science’?) Data collection, an important aspect of “data science”, is not a new idea. Before the tech boom, every industry al…

文本二叉樹折半查詢及其截取值

using System;using System.ComponentModel;using System.Data;using System.Drawing;using System.Text;using System.Windows.Forms;using System.Collections;using System.IO;namespace CS_ScanSample1{ /// <summary> /// Logic 的摘要說明。 /// </summary> …

nn.functional 和 nn.Module入門講解

本文來自《20天吃透Pytorch》 一&#xff0c;nn.functional 和 nn.Module 前面我們介紹了Pytorch的張量的結構操作和數學運算中的一些常用API。 利用這些張量的API我們可以構建出神經網絡相關的組件(如激活函數&#xff0c;模型層&#xff0c;損失函數)。 Pytorch和神經網絡…

10.30PMP試題每日一題

SC>0&#xff0c;CPI<1&#xff0c;說明項目截止到當前&#xff1a;A、進度超前&#xff0c;成本超值B、進度落后&#xff0c;成本結余C、進度超前&#xff0c;成本結余D、無法判斷 答案將于明天和新題一起揭曉&#xff01; 10.29試題答案&#xff1a;A轉載于:https://bl…

02-web框架

1 while True:print(server is waiting...)conn, addr server.accept()data conn.recv(1024) print(data:, data)# 1.得到請求的url路徑# ------------dict/obj d["path":"/login"]# d.get(”path“)# 按著http請求協議解析數據# 專注于web業…

ai驅動數據安全治理_AI驅動的Web數據收集解決方案的新起點

ai驅動數據安全治理Data gathering consists of many time-consuming and complex activities. These include proxy management, data parsing, infrastructure management, overcoming fingerprinting anti-measures, rendering JavaScript-heavy websites at scale, and muc…