Pytorch高階API示范——DNN二分類模型

代碼部分:

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader,TensorDataset"""
準備數據
"""#正負樣本數量
n_positive,n_negative = 2000,2000#生成正樣本, 小圓環分布
r_p = 5.0 + torch.normal(0.0,1.0,size = [n_positive,1])
theta_p = 2*np.pi*torch.rand([n_positive,1])
Xp = torch.cat([r_p*torch.cos(theta_p),r_p*torch.sin(theta_p)],axis = 1)
Yp = torch.ones_like(r_p)#生成負樣本, 大圓環分布
r_n = 8.0 + torch.normal(0.0,1.0,size = [n_negative,1])
theta_n = 2*np.pi*torch.rand([n_negative,1])
Xn = torch.cat([r_n*torch.cos(theta_n),r_n*torch.sin(theta_n)],axis = 1)
Yn = torch.zeros_like(r_n)#匯總樣本
X = torch.cat([Xp,Xn],axis = 0)
Y = torch.cat([Yp,Yn],axis = 0)#可視化
plt.figure(figsize = (6,6))
plt.scatter(Xp[:,0],Xp[:,1],c = "r")
plt.scatter(Xn[:,0],Xn[:,1],c = "g")
plt.legend(["positive","negative"])plt.show()"""
#構建輸入數據管道
"""
ds = TensorDataset(X,Y)
dl = DataLoader(ds,batch_size = 10,shuffle=True)"""
2, 定義模型
"""
class DNNModel(nn.Module):def __init__(self):super(DNNModel, self).__init__()self.fc1 = nn.Linear(2,4)self.fc2 = nn.Linear(4,8)self.fc3 = nn.Linear(8,1)def forward(self,x):x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))y = nn.Sigmoid()(self.fc3(x))return ydef loss_func(self,y_pred,y_true):return nn.BCELoss()(y_pred,y_true)def metric_func(self,y_pred,y_true):y_pred = torch.where(y_pred > 0.5, torch.ones_like(y_pred, dtype=torch.float32),torch.zeros_like(y_pred, dtype=torch.float32))acc = torch.mean(1 - torch.abs(y_true - y_pred))return acc@propertydef optimizer(self):return torch.optim.Adam(self.parameters(), lr=0.001)model = DNNModel()# 測試模型結構
(features,labels) = next(iter(dl))
predictions = model(features)loss = model.loss_func(predictions,labels)
metric = model.metric_func(predictions,labels)print("init loss:",loss.item())
print("init metric:",metric.item())"""
3,訓練模型
"""
def train_step(model, features, labels):# 正向傳播求損失predictions = model(features)loss = model.loss_func(predictions,labels)metric = model.metric_func(predictions,labels)# 反向傳播求梯度loss.backward()# 更新模型參數model.optimizer.step()model.optimizer.zero_grad()return loss.item(),metric.item()# 測試train_step效果
features,labels = next(iter(dl))    #非for循環就用next
train_step(model,features,labels)def train_model(model,epochs):for epoch in range(1,epochs+1):loss_list,metric_list = [],[]for features, labels in dl:lossi,metrici = train_step(model,features,labels)loss_list.append(lossi)metric_list.append(metrici)loss = np.mean(loss_list)metric = np.mean(metric_list)if epoch%100==0:print("epoch =",epoch,"loss = ",loss,"metric = ",metric)train_model(model,epochs = 300)# 結果可視化
fig, (ax1,ax2) = plt.subplots(nrows=1,ncols=2,figsize = (12,5))
ax1.scatter(Xp[:,0],Xp[:,1], c="r")
ax1.scatter(Xn[:,0],Xn[:,1],c = "g")
ax1.legend(["positive","negative"])
ax1.set_title("y_true")Xp_pred = X[torch.squeeze(model.forward(X)>=0.5)]
Xn_pred = X[torch.squeeze(model.forward(X)<0.5)]ax2.scatter(Xp_pred[:,0],Xp_pred[:,1],c = "r")
ax2.scatter(Xn_pred[:,0],Xn_pred[:,1],c = "g")
ax2.legend(["positive","negative"])
ax2.set_title("y_pred")plt.show()

結果展示:

數據部分:

在這里插入圖片描述

結果分類:

在這里插入圖片描述

思考:

本文中的DNN模型,將loss(損失),metric(準確率),optimizer(優化器)的定義放在了DNN網絡中,這也產生了一系列的問題。首先在調用這些函數時,需要用網絡名+“.”來調用。例如:loss = model.loss_func(predictions,labels)
但是這里最重要的一點是 def optimizer(self):在optimizer函數上面必須有:@property,若沒有將會出現AttributeError: 'function' object has no attribute 'step'的報錯。
@property裝飾器的作用:我們可以使用@property裝飾器來創建只讀屬性,@property裝飾器會將方法轉換為相同名稱的只讀屬性,可以與所定義的屬性配合使用,這樣可以防止屬性被修改。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/389376.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/389376.shtml
英文地址,請注明出處:http://en.pswp.cn/news/389376.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

OO期末總結

$0 寫在前面 善始善終&#xff0c;臨近期末&#xff0c;為一學期的收獲和努力畫一個圓滿的句號。 $1 測試與正確性論證的比較 $1-0 什么是測試&#xff1f; 測試是使用人工操作或者程序自動運行的方式來檢驗它是否滿足規定的需求或弄清預期結果與實際結果之間的差別的過程。 它…

puppet puppet模塊、file模塊

轉載&#xff1a;http://blog.51cto.com/ywzhou/1577356 作用&#xff1a;通過puppet模塊自動控制客戶端的puppet配置&#xff0c;當需要修改客戶端的puppet配置時不用在客戶端一一設置。 1、服務端配置puppet模塊 &#xff08;1&#xff09;模塊清單 [rootpuppet ~]# tree /et…

數據可視化及其重要性:Python

Data visualization is an important skill to possess for anyone trying to extract and communicate insights from data. In the field of machine learning, visualization plays a key role throughout the entire process of analysis.對于任何試圖從數據中提取和傳達見…

熊貓數據集_熊貓邁向數據科學的第三部分

熊貓數據集Data is almost never perfect. Data Scientist spend more time in preprocessing dataset than in creating a model. Often we come across scenario where we find some missing data in data set. Such data points are represented with NaN or Not a Number i…

Pytorch有關張量的各種操作

一&#xff0c;創建張量 1. 生成float格式的張量: a torch.tensor([1,2,3],dtype torch.float)2. 生成從1到10&#xff0c;間隔是2的張量: b torch.arange(1,10,step 2)3. 隨機生成從0.0到6.28的10個張量 注意&#xff1a; (1).生成的10個張量中包含0.0和6.28&#xff…

mongodb安裝失敗與解決方法(附安裝教程)

安裝mongodb遇到的一些坑 浪費了大量的時間 在此記錄一下 主要是電腦系統win10企業版自帶的防火墻 當然還有其他的一些坑 一般的問題在第6步驟都可以解決&#xff0c;本教程的安裝步驟不夠詳細的話 請自行百度或谷歌 安裝教程很多 我是基于node.js使用mongodb結合Robo 3T數…

【洛谷算法題】P1046-[NOIP2005 普及組] 陶陶摘蘋果【入門2分支結構】Java題解

&#x1f468;?&#x1f4bb;博客主頁&#xff1a;花無缺 歡迎 點贊&#x1f44d; 收藏? 留言&#x1f4dd; 加關注?! 本文由 花無缺 原創 收錄于專欄 【洛谷算法題】 文章目錄 【洛谷算法題】P1046-[NOIP2005 普及組] 陶陶摘蘋果【入門2分支結構】Java題解&#x1f30f;題目…

web性能優化(理論)

什么是性能優化&#xff1f; 就是讓用戶感覺你的網站加載速度很快。。。哈哈哈。 分析 讓我們來分析一下從用戶按下回車鍵到網站呈現出來經歷了哪些和前端相關的過程。 緩存 首先看本地是否有緩存&#xff0c;如果有符合使用條件的緩存則不需要向服務器發送請求了。DNS查詢建立…

python多項式回歸_如何在Python中實現多項式回歸模型

python多項式回歸Let’s start with an example. We want to predict the Price of a home based on the Area and Age. The function below was used to generate Home Prices and we can pretend this is “real-world data” and our “job” is to create a model which wi…

充分利用UC berkeleys數據科學專業

By Kyra Wong and Kendall Kikkawa黃凱拉(Kyra Wong)和菊川健多 ( Kendall Kikkawa) 什么是“數據科學”&#xff1f; (What is ‘Data Science’?) Data collection, an important aspect of “data science”, is not a new idea. Before the tech boom, every industry al…

文本二叉樹折半查詢及其截取值

using System;using System.ComponentModel;using System.Data;using System.Drawing;using System.Text;using System.Windows.Forms;using System.Collections;using System.IO;namespace CS_ScanSample1{ /// <summary> /// Logic 的摘要說明。 /// </summary> …

nn.functional 和 nn.Module入門講解

本文來自《20天吃透Pytorch》 一&#xff0c;nn.functional 和 nn.Module 前面我們介紹了Pytorch的張量的結構操作和數學運算中的一些常用API。 利用這些張量的API我們可以構建出神經網絡相關的組件(如激活函數&#xff0c;模型層&#xff0c;損失函數)。 Pytorch和神經網絡…

10.30PMP試題每日一題

SC>0&#xff0c;CPI<1&#xff0c;說明項目截止到當前&#xff1a;A、進度超前&#xff0c;成本超值B、進度落后&#xff0c;成本結余C、進度超前&#xff0c;成本結余D、無法判斷 答案將于明天和新題一起揭曉&#xff01; 10.29試題答案&#xff1a;A轉載于:https://bl…

02-web框架

1 while True:print(server is waiting...)conn, addr server.accept()data conn.recv(1024) print(data:, data)# 1.得到請求的url路徑# ------------dict/obj d["path":"/login"]# d.get(”path“)# 按著http請求協議解析數據# 專注于web業…

ai驅動數據安全治理_AI驅動的Web數據收集解決方案的新起點

ai驅動數據安全治理Data gathering consists of many time-consuming and complex activities. These include proxy management, data parsing, infrastructure management, overcoming fingerprinting anti-measures, rendering JavaScript-heavy websites at scale, and muc…

從Text文本中讀值插入到數據庫中

/// <summary> /// 轉換數據&#xff0c;從Text文本中導入到數據庫中 /// </summary> private void ChangeTextToDb() { if(File.Exists("Storage Card/Zyk.txt")) { try { this.RecNum.Visibletrue; SqlCeCommand sqlCreateTable…

Dataset和DataLoader構建數據通道

重點在第二部分的構建數據通道和第三部分的加載數據集 Pytorch通常使用Dataset和DataLoader這兩個工具類來構建數據管道。 Dataset定義了數據集的內容&#xff0c;它相當于一個類似列表的數據結構&#xff0c;具有確定的長度&#xff0c;能夠用索引獲取數據集中的元素。 而D…

鐵拳nat映射_鐵拳如何重塑我的數據可視化設計流程

鐵拳nat映射It’s been a full year since I’ve become an independent data visualization designer. When I first started, projects that came to me didn’t relate to my interests or skills. Over the past eight months, it’s become very clear to me that when cl…

Django2 Web 實戰03-文件上傳

作者&#xff1a;Hubery 時間&#xff1a;2018.10.31 接上文&#xff1a;接上文&#xff1a;Django2 Web 實戰02-用戶注冊登錄退出 視頻是一種可視化媒介&#xff0c;因此視頻數據庫至少應該存儲圖像。讓用戶上傳文件是個很大的隱患&#xff0c;因此接下來會討論這倆話題&#…

BZOJ.2738.矩陣乘法(整體二分 二維樹狀數組)

題目鏈接 BZOJ洛谷 整體二分。把求序列第K小的樹狀數組改成二維樹狀數組就行了。 初始答案區間有點大&#xff0c;離散化一下。 因為這題是一開始給點&#xff0c;之后詢問&#xff0c;so可以先處理該區間值在l~mid的修改&#xff0c;再處理詢問。即二分標準可以直接用點的標號…