PyTorch實現邏輯回歸

最終效果

先看下最終效果:
1
這里用一條直線把二維平面上不同的點分開。

生成隨機數據

#創建訓練數據
x = torch.rand(10,1)*10 #shape(10,1)
y = 2*x + (5 + torch.randn(10,1))#構建線性回歸參數
w = torch.randn((1))#隨機初始化w,要用到自動梯度求導
b = torch.zeros((1))#使用0初始化b,要用到自動梯度求導n_data = torch.ones(100, 2)
xy0 = torch.normal(2 * n_data, 1.5)  # 生成均值為2.標準差為1.5的隨機數組成的矩陣
c0 = torch.zeros(100)
xy1 = torch.normal(-2 * n_data, 1.5)  # 生成均值為-2.標準差為1.5的隨機數組成的矩陣
c1 = torch.ones(100)x,y = torch.cat((xy0,xy1),0).type(torch.FloatTensor).split(1, dim=1)
x = x.squeeze()
y = y.squeeze()
c = torch.cat((c0,c1),0).type(torch.FloatTensor)

數據可視化

def plot(x, y, c):ax = plt.gca()sc = ax.scatter(x, y, color='black')paths = []for i in range(len(x)):if c[i].item() == 0:marker_obj = mmarkers.MarkerStyle('o')else:marker_obj = mmarkers.MarkerStyle('x')path = marker_obj.get_path().transformed(marker_obj.get_transform())paths.append(path)sc.set_paths(paths)return sc
plot(x, y, c)
plt.show()

使用x和o來表示兩種不同類別的數據。
1

定義模型和損失函數

#構建邏輯回歸參數
w = torch.tensor([1.,],requires_grad=True)  # 隨機初始化w
b = torch.zeros((1),requires_grad=True)  # 使用0初始化bwx = torch.mul(w,x) # w*x
y_pred = torch.add(wx,b) # y = w*x + b
loss = (0.5*(y-y_pred)**2).mean()

這里使用了平方損失函數來估算模型準確度。

訓練模型

最多訓練100次,每次都會更新模型參數,當損失值小于0.03時停止訓練。

xx = torch.arange(-4, 5)
lr = 0.02 #學習率
for iteration in range(100):#前向傳播loss = ((torch.sigmoid(x*w+b-y) - c)**2).mean()#反向傳播loss.backward()#更新參數b.data.sub_(lr*b.grad) # b = b - lr*b.gradw.data.sub_(lr*w.grad) # w = w - lr*w.grad#繪圖if iteration % 3 == 0:plot(x, y, c)yy = w*xx + bplt.plot(xx.data.numpy(),yy.data.numpy(),'r-',lw=5)plt.text(-4,2,'Loss=%.4f'%loss.data.numpy(),fontdict={'size':20,'color':'black'})plt.xlim(-4,4)plt.ylim(-4,4)plt.title("Iteration:{}\nw:{},b:{}".format(iteration,w.data.numpy(),b.data.numpy()))plt.show()if loss.data.numpy() < 0.03:  # 停止條件break

全部代碼

import torch
import matplotlib.pyplot as plt
import matplotlib.markers as mmarkers#創建訓練數據
x = torch.rand(10,1)*10 #shape(10,1)
y = 2*x + (5 + torch.randn(10,1))#構建線性回歸參數
w = torch.randn((1))#隨機初始化w,要用到自動梯度求導
b = torch.zeros((1))#使用0初始化b,要用到自動梯度求導wx = torch.mul(w,x) # w*x
y_pred = torch.add(wx,b) # y = w*x + bn_data = torch.ones(100, 2)
xy0 = torch.normal(2 * n_data, 1.5)  # 生成均值為2.標準差為1.5的隨機數組成的矩陣
c0 = torch.zeros(100)
xy1 = torch.normal(-2 * n_data, 1.5)  # 生成均值為-2.標準差為1.5的隨機數組成的矩陣
c1 = torch.ones(100)x,y = torch.cat((xy0,xy1),0).type(torch.FloatTensor).split(1, dim=1)
x = x.squeeze()
y = y.squeeze()
c = torch.cat((c0,c1),0).type(torch.FloatTensor)def plot(x, y, c):ax = plt.gca()sc = ax.scatter(x, y, color='black')paths = []for i in range(len(x)):if c[i].item() == 0:marker_obj = mmarkers.MarkerStyle('o')else:marker_obj = mmarkers.MarkerStyle('x')path = marker_obj.get_path().transformed(marker_obj.get_transform())paths.append(path)sc.set_paths(paths)return sc
plot(x, y, c)
plt.show()#構建邏輯回歸參數
w = torch.tensor([1.,],requires_grad=True)#隨機初始化w
b = torch.zeros((1),requires_grad=True)#使用0初始化bwx = torch.mul(w,x) # w*x
y_pred = torch.add(wx,b) # y = w*x + b
loss = (0.5*(y-y_pred)**2).mean()xx = torch.arange(-4, 5)
lr = 0.02 #學習率
for iteration in range(100):#前向傳播loss = ((torch.sigmoid(x*w+b-y) - c)**2).mean()#反向傳播loss.backward()#更新參數b.data.sub_(lr*b.grad) # b = b - lr*b.gradw.data.sub_(lr*w.grad) # w = w - lr*w.grad#繪圖if iteration % 3 == 0:plot(x, y, c)yy = w*xx + bplt.plot(xx.data.numpy(),yy.data.numpy(),'r-',lw=5)plt.text(-4,2,'Loss=%.4f'%loss.data.numpy(),fontdict={'size':20,'color':'black'})plt.xlim(-4,4)plt.ylim(-4,4)plt.title("Iteration:{}\nw:{},b:{}".format(iteration,w.data.numpy(),b.data.numpy()))plt.show()if loss.data.numpy() < 0.03:#停止條件break

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/210885.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/210885.shtml
英文地址,請注明出處:http://en.pswp.cn/news/210885.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

使用 ROS 和 Geomagic Haptic 驅動 Franka 機械臂

文章目錄 前言一、安裝 franka_ros二、安裝 OpenHaptics for Linux三、安裝 3D Systems Geomagic Touch ROS Driver四、安裝 franka_interactive_controllers五、使用 Geomagic Haptic 驅動 Franka 機械臂 前言 本文為在雙系統上使用 ROS 和 Geomagic Haptic 驅動 Franka 機械…

滑動窗口(單調隊列)

154. 滑動窗口 - AcWing題庫 給定一個大小為 n≤10^6≤10^6 的數組。 有一個大小為 k 的滑動窗口&#xff0c;它從數組的最左邊移動到最右邊。 你只能在窗口中看到 k 個數字。 每次滑動窗口向右移動一個位置。 以下是一個例子&#xff1a; 該數組為 [1 3 -1 -3 5 3 6 7]&…

HashMap的那些事

一、HashMap與HashTable的區別 1.來歷 HashTable是一種鍵值映射的數據結構&#xff0c;自從java發布就存在&#xff0c;而HashMap是jdk1.2后才出現的&#xff0c;雖然說HashTable出現得早且線程安全&#xff0c;但是效率很低已經棄用了&#xff0c;現在HashMap逐漸成為主流 …

Nmap腳本未來的發展趨勢

Nmap腳本技術的發展趨勢和前景 Nmap腳本是一種基于Lua語言開發的腳本&#xff0c;可以擴展Nmap的功能&#xff0c;用于自動化掃描、漏洞檢測、服務探測、設備管理等方面。隨著網絡安全的不斷發展和漏洞的不斷出現&#xff0c;Nmap腳本技術也在不斷發展和壯大。在本文中&#xf…

小米手機鎖屏時間設置為永不休眠_手機不息屏_保持亮屏

環境&#xff1a;打開手機自帶的鎖屏時間設置發現沒有 永不息屏的選項 原因&#xff1a;采用了三星OLED屏幕&#xff0c;所以根據OLED屏幕特性&#xff0c;這個是為了防止燒屏而特意設計的。非OLED機型支持設置“永不” 解決方案1&#xff1a;原生系統是支持永不鎖屏的&#…

Android 13 - Media框架(20)- ACodec(二)

這一節開始我們就來學習 ACodec 的實現 1、創建 ACodec ACodec 是在 MediaCodec 中創建的&#xff0c;這里先貼出創建部分的代碼&#xff1a; mCodec mGetCodecBase(name, owner);if (mCodec NULL) {ALOGE("Getting codec base with name %s (owner%s) failed", n…

ES 如何將國際標準時間格式進行格式化與調整時區

需求&#xff0c;日志收集的時候&#xff0c;時間格式是國際標準時間格式。形如yyyy-MM-ddTHH:mm:ss.SSS。 &#xff08;2023-12-05T02:45:50.282Z&#xff09;這個時區也不對&#xff0c;那如何將此類型的時間&#xff0c;進行格式化呢&#xff1f; 本篇文章體統一個案例&…

Other -- ChatGPT 原理

本文為個人理解&#xff0c;幫助小白&#xff08;本人就是&#xff09;了解正在創建新時代的 AI 產品&#xff0c;如文中理解有誤歡迎留言。 [參考鏈接--](https://baijiahao.baidu.com/s?id1765556782543603120&wfrspider&forpc) 1. 了解一些基本概念 大語言模型&a…

修改 Ganglia 監控 Grid Report timezone 時區 為 東八區 +8 PRC

Ganglia 監控 Grid Report timezone 默認時區 為 零時區 0 現在要修改為 東八區 8 具體操作如下 modify ganglia-web report timezone 0 --> 8 vim /apps/svr/httpd-2.4.48/htdocs/ganglia/header.php // add timezone GMT8 ini_set(date.timezone, PRC);詳細記錄&#x…

【面試】測試/測開(ING)

63. APP端特有的測試 參考&#xff1a;APP專項測試、APP應用測試 crash和anr的區別 1&#xff09;網絡測試 2&#xff09;中斷測試 3&#xff09;安裝、卸載測試 4&#xff09;兼容測試 5&#xff09;性能測試&#xff08;耗電量、流量、內存、服務器端&#xff09; 6&#xf…

畫對比折線圖【Python】

出這一期想必是我做某個課程作業遇到了。 由于去各個官網下載對比圖要錢&#xff0c;我還是不想花錢的&#xff01;真討厭&#xff01;淺淺水一期。 以下是要做的對比圖的數據&#xff1a; 代碼&#xff1a; from matplotlib import pyplot as plt#設置中文顯示plt.rcParams[…

CSS新手入門筆記整理:CSS浮動布局

文檔流概述 正常文檔流 “文檔流”指元素在頁面中出現的先后順序。正常文檔流&#xff0c;又稱為“普通文檔流”或“普通流”&#xff0c;也就是W3C標準所說的“normal flow”。正常文檔流&#xff0c;將一個頁面從上到下分為一行一行&#xff0c;其中塊元素獨占一行&#xf…

ChatGPT OpenAI API請求限制 嘗試解決

1. OpenAI API請求限制 Retrying langchain.chat_models.openai.ChatOpenAI.completion_with_retry.._completion_with_retry in 4.0 seconds as it raised RateLimitError: Rate limit reached for gpt-3.5-turbo-16k in organization org-U7I2eKpAo6xA7RUa2Nq307ae on reques…

讓內存無處可逃:智能指針[C++11]

智能指針 文章目錄 智能指針前言RAII什么是智能指針智能指針的應用示例 C98的auto_ptr共享型智能指針&#xff1a;shared_ptrshared_ptr的使用初始化獲取原生指針指定刪除器默認刪除器default_delete指定刪除器指定刪除器管理動態數組 shared_ptr的偽實現shared_ptr的注意事項避…

【Docker】進階之路:(五)Docker引擎

【Docker】進階之路&#xff1a;&#xff08;五&#xff09;Docker引擎 Docker引擎簡介Docker引擎的組件構成runccontainerd Docker引擎簡介 Docker引擎是用來運行和管理容器的核心部分。Docker首次發布時&#xff0c;Docker 引擎由LXC 和 Docker daemon 兩個核心組件構成。 …

linux驅動開發——內核調試技術

目錄 一、前言 二、內核調試方法 2.1 內核調試概述 2.2 學會分析內核源程序 2.3調試方法介紹 三、內核打印函數 3.1內核鏡像解壓前的串口輸出函數 3.2 內核鏡像解壓后的串口輸出函數 3.3 內核打印函數 四、獲取內核信息 4.1系統請求鍵 4.2 通過/proc 接口 4.3 通過…

算法:有效的括號(入棧出棧)

時間復雜度 O(n) 空間復雜度 O(n∣Σ∣)&#xff0c;其中 Σ 表示字符集&#xff0c;本題中字符串只包含 6 種括號 /*** param {string} s* return {boolean}*/ var isValid function(s) {const map {"(":")","{":"}","["…

List截取指定長度(java截取拼接URL)

場景&#xff1a; N多個參數&#xff0c;截取指定個數&#xff0c;拼接URL public static void main(final String[] args) {int count 0;//每頁數量final int pageSize 5;final List<Integer> memberNos ListUtil.toList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13…

python格式化內容

1.字符串格式化: 定義列表 [{"姓名": "張三", "年齡": 18, "性別": "男"}, {"姓名": "里斯李四李斯", "年齡": 18, "性別": "男"}, {"姓名": "斯托夫斯基…

C++知識 抽象基類

抽象基類通常包含至少一個純虛函數&#xff0c;即一個沒有具體實現的虛函數&#xff0c;通過在基類中聲明純虛函數&#xff0c;它強制派生類提供這個函數的具體實現。 通過在類的聲明中使用 virtual 關鍵字和 0 初始化來創建純虛函數&#xff0c;這樣的類就成為抽象基類。以下…