【深度學習】2.單層感知機

目標:

實現一個簡單的二分類模型的訓練過程,通過模擬數據集進行訓練和優化,訓練目標是使模型能夠根據輸入特征正確分類數據。

演示:

1.通過PyTorch生成了一個模擬的二分類數據集,包括特征矩陣data_x和對應的標簽數據data_y。標簽數據通過基于特征的線性組合生成,并轉換成獨熱編碼的形式。

import torch
# 從torch庫中導入神經網絡模塊nn,用于構建神經網絡模型
from torch import nn
# 導入torch.nn模塊中的functional子模塊,可用于訪問各種函數,例如激活函數
import torch.nn.functional as Fn_item = 1000
n_feature = 2
learning_rate = 0.01
epochs = 100# 生成一個模擬的數據集,其中包括一個隨機生成的特征矩陣data_x和相應生成的標簽數據data_y。標簽數據通過基于特征的線性組合生成,并且轉換成獨熱編碼的形式。# 設置隨機數生成器的種子為123,通過設置隨機種子,我們可以確保在每次運行代碼時生成的隨機數相同,這對于結果的可重現性非常重要。
torch.manual_seed(123)
# 生成一個隨機數矩陣data_x,其中包含n_item行和n_feature列。矩陣中的元素是從標準正態分布(均值為0,標準差為1)中隨機采樣的。
data_x = torch.randn(size=(n_item, n_feature)).float()
# torch.where(...): 根據條件返回兩個張量中相應位置的值。如果條件成立,將為0,否則為1。  long(): 用于將張量轉換為Long型數據類型。
data_y = torch.where(torch.subtract(data_x[:, 0]*0.5, data_x[:, 1]*1.5)+0.02 > 0, 0, 1).long()
# 將標簽數據data_y轉換為獨熱編碼形式,即將每個標簽轉換為一個相應長度的獨熱向量
data_y = F.one_hot(data_y)# print(data_x)
# print(data_y)

2.定義了一個簡單的二分類模型BinaryClassificationModel,包含一個單層感知器(Single Perceptron)結構,其中使用了一個線性層和sigmoid激活函數,用于將輸入特征映射到概率空間。

# 定義了一個簡單的二分類模型,采用單層感知器的結構,包含一個線性層和sigmoid激活函數,用于將輸入特征映射到概率空間。這樣的模型可以用來對數據集進行二分類任務的預測。# 定義了一個名為BinaryClassificationModel的類,其繼承自nn.Module類,這意味著這個類是一個PyTorch模型。
class BinaryClassificationModel(nn.Module):def __init__(self, in_feature):# 調用了父類nn.Module的構造函數,確保正確初始化模型。super(BinaryClassificationModel, self).__init__()"""single perception"""# 這行代碼定義了模型的第一層,是一個線性層(Fully Connected Layer)。in_features參數指定輸入特征的數量,out_features指定輸出特征的數量,這里設置為2表示二分類問題。bias=True表示該層包含偏置項。self.layer_1 = nn.Linear(in_features=in_feature, out_features=2, bias=True)# 定義模型前向傳播的方法,即輸入數據x通過模型前向計算得到輸出。def forward(self, x):# 輸入數據x首先通過定義的線性層self.layer_1進行線性變換,然后通過F.sigmoid()函數進行激活函數處理。return F.sigmoid(self.layer_1(x))

3.創建了該二分類模型的實例model、使用隨機梯度下降(SGD)優化器opt、以及二分類問題常用的損失函數BCELoss(Binary Cross Entropy Loss)。

4.在訓練過程中,通過多個epoch和每個樣本的批處理(在這里是一次處理一個樣本),計算模型預測輸出和真實標簽之間的損失值,進行反向傳播計算梯度,并更新模型參數以最小化損失函數。

# 完成對模型的訓練過程,每個epoch中通過優化器進行參數更新,計算損失,反向傳播更新梯度。最終我們會得到訓練過程中每個epoch的損失值,并可以觀察損失的變化情況。# 創建了一個二分類模型實例model,參數n_feature表示輸入特征的數量。
model = BinaryClassificationModel(n_feature)
# 創建了一個隨機梯度下降(SGD)優化器opt,用于根據計算出的梯度更新模型參數。
opt = torch.optim.SGD(model.parameters(), lr=learning_rate)
# 創建了一個二分類問題常用的損失函數BCELoss(Binary Cross Entropy Loss),用于衡量模型輸出與真實標簽之間的差異。
criteria = nn.BCELoss()for epoch in range(epochs):# 對每個樣本進行訓練。for step in range(n_item):x = data_x[step]y = data_y[step]# 梯度清零,避免梯度累加影響優化結果。opt.zero_grad()# 將輸入特征x通過模型前向傳播得到預測輸出y_hat。unsqueeze(0)是因為我們的模型期望輸入是(batch_size, n_feature)的形式。y_hat = model(x.unsqueeze(0))# 計算預測輸出y_hat和真實標簽y之間的損失值。loss = criteria(y_hat, y.unsqueeze(0).float())# 反向傳播計算梯度。loss.backward()# 根據計算出的梯度更新模型參數。opt.step()print("Epoch: %03d, Loss: %.3f" % (epoch, loss.item()))

5.打印出每個epoch的序號和損失值,用于監控訓練過程中損失值的變化情況。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/15936.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/15936.shtml
英文地址,請注明出處:http://en.pswp.cn/web/15936.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

加密與安全_AES RSA 密鑰對生成及PEM格式的代碼實現

文章目錄 RSA(非對稱)和AES(對稱)加密算法一、RSA(Rivest-Shamir-Adleman)二、AES(Advanced Encryption Standard) RSA加密三種填充模式一、RSA填充模式二、常見的RSA填充模式組合三…

新業務 新市場 | 靈途科技新品亮相馬來西亞亞洲防務展

5月6日,靈途科技攜新品模組與武漢長盈通光電(股票代碼:688143)攜手參加第18屆馬來西亞亞洲防務展。首次亮相海外,靈途科技便收獲全球客戶的廣泛關注,為公司海外市場開拓打下堅實基礎。 靈途科技與長盈通共同…

Dbs封裝_連接池

1.Dbs封裝 每一個數據庫都對應著一個dao 每個dao勢必存在公共部分 我們需要將公共部分抽取出來 封裝成一個工具類 保留個性化代碼即可 我們的工具類一般命名為xxxs 比如Strings 就是字符串相關的工具類 而工具類 我們將其放置于util包中我們以是否有<T>區分泛型方法和非泛…

Python并發編程學習記錄

1、初識并發編程 1.1、串行&#xff0c;并行&#xff0c;并發 串行(serial)&#xff1a;一個cpu上按順序完成多個任務&#xff1b; 并行(parallelism)&#xff1a;任務數小于或等于cup核數&#xff0c;多個任務是同時執行的&#xff1b; 并發(concurrency)&#xff1a;一個…

計算機SCI期刊,IF=8+,專業性強,潛力新刊!

一、期刊名稱 Journal of Big data 二、期刊簡介概況 期刊類型&#xff1a;SCI 學科領域&#xff1a;計算機科學 影響因子&#xff1a;8.1 中科院分區&#xff1a;2區 出版方式&#xff1a;開放出版 版面費&#xff1a;$1990 三、期刊征稿范圍 《大數據雜志》發表了關于…

2024年【T電梯修理】考試內容及T電梯修理新版試題

題庫來源&#xff1a;安全生產模擬考試一點通公眾號小程序 2024年【T電梯修理】考試內容及T電梯修理新版試題&#xff0c;包含T電梯修理考試內容答案和解析及T電梯修理新版試題練習。安全生產模擬考試一點通結合國家T電梯修理考試最新大綱及T電梯修理考試真題匯總&#xff0c;…

線性dp合集,藍橋杯

貿易航線 0貿易航線 - 藍橋云課 (lanqiao.cn) n,m,kmap(int ,input().split()) #貪心的想&#xff0c;如果買某個東西利潤最大&#xff0c;那我肯定直接拉滿啊&#xff0c;所以買k個和買一個沒區別 p[0] for i in range(n):p.append([-1]list(map(int,input().split())))dp[[…

(2024,SDE,對抗薛定諤橋匹配,離散時間迭代馬爾可夫擬合,去噪擴散 GAN)

Adversarial Schrdinger Bridge Matching 公眾號&#xff1a;EDPJ&#xff08;進 Q 交流群&#xff1a;922230617 或加 VX&#xff1a;CV_EDPJ 進 V 交流群&#xff09; 目錄 0. 摘要 1. 簡介 4. 實驗 0. 摘要 薛定諤橋&#xff08;Schrdinger Bridge&#xff0c;SB&…

el-autocomplete后臺遠程搜索

el-complete可以實現后臺遠程搜索功能&#xff0c;但有時傳入數據為空時&#xff0c;接口可能會報錯。此時可在querySearchAsync方法中&#xff0c;根據queryString判斷&#xff0c;若為空&#xff0c;則不掉用接口&#xff0c;直接callback([])&#xff0c;反之則調用接口&…

浮點型比較大小

浮點數的存儲形式 浮點數按照在內存中所占字節數和數值范圍&#xff0c;可以分為浮點型&#xff0c;雙精度浮點型和長雙浮點型數。 代碼&#xff1a; printf("lgn:%e \n", pow(exp(1), 100));printf("lgn:%f ", pow(exp(1), 100));輸出結果&#xff1a; …

Stanford斯坦福 CS 224R: 深度強化學習 (5)

離線強化學習:第一部分 強化學習(RL)旨在讓智能體通過與環境交互來學習最優策略,從而最大化累積獎勵。傳統的RL訓練都是在線(online)進行的,即智能體在訓練過程中不斷與環境交互,實時生成新的狀態-動作數據,并基于新數據來更新策略。這種在線學習雖然簡單直觀,但也存在一些局限…

【Could not find Chrome This can occur if either】

爬蟲練習中遇到的問題 使用puppeteer執行是提示一下錯誤 Error: Could not find Chrome (ver. 125.0.6422.78). This can occur if either you did not perform an installation before running the script (e.g. npx puppeteer browsers install chrome) oryour cache path…

CLIP 論文的關鍵內容

CLIP 論文整體架構 該論文總共有 48 頁&#xff0c;除去最后的補充材料十頁去掉&#xff0c;正文也還有三十多頁&#xff0c;其中大部分篇幅都留給了實驗和響應的一些分析。 從頭開始的話&#xff0c;第一頁就是摘要&#xff0c;接下來一頁多是引言&#xff0c;接下來的兩頁就…

常用 CSS 寫法

不是最后一個 :not(:last-child)漸變色 background: linear-gradient(270deg, #15aaff 0%, #02396a 100%);文字漸變色 background-image: linear-gradient(to right, #ff7e5f, #feb47b); -webkit-background-clip: text; background-clip: text; color: transparent;

python文件IO基礎知識

目錄 1.open函數打開文件 2.文件對象讀寫數據和關閉 3.文本文件和二進制文件的區別 4.編碼和解碼 讀寫文本文件時 讀寫二進制文件時 5.文件指針位置 6.文件緩存區與flush()方法 1.open函數打開文件 使用 open 函數創建一個文件對象&#xff0c;read 方法來讀取數據&…

談談磁盤的那些操作

磁盤格式化 是指把一張空白的盤劃分成一個個小區域并編號&#xff0c;以供計算機存儲和讀取數據。格式化是一種純物理操作&#xff0c;是在磁盤的所有數據區上寫零的操作過程&#xff0c;同時對硬盤介質做一致性檢測&#xff0c;并且標記出不可讀和壞的扇區。由于大部分硬盤在…

電子技術學習路線

在小破站上看到大佬李皆寧的技術路線分析&#xff0c;再結合自己這幾年的工作。發現的確是這樣&#xff0c;跟著大佬的技術路線去學習是會輕松很多&#xff0c;現在想想&#xff0c;這路線其實跟大學四年的學習順序是很像的。 本期記錄學習路線&#xff0c;方便日后查看。 傳統…

python 深度圖生成點云(方法二)

深度圖生成點云 一、介紹1.1 概念1.2 思路1.3 函數講解二、代碼示例三、結果示例接上篇:深度圖生成點云(方法1) 一、介紹 1.1 概念 深度圖生成點云:根據深度圖像(depth image)和相機內參(camera intrinsics)生成點云(PointCloud)。 1.2 思路 點云坐標的計算公式如…

pillow學習7

繪制驗證碼 from PIL import Image,ImageFilter,ImageFont,ImageDraw import random width100 hight100 imImage.new(RGB,(width,hight),(255,255,255)) drawImageDraw.Draw(im) #獲取顏色 def get_color1():return (random.randint(200, 255), random.randint(200, 255), ran…

京東Java社招面試題真題,最新面試題

Java中接口與抽象類的區別是什么&#xff1f; 1、定義方式&#xff1a; 接口是完全抽象的&#xff0c;只能定義抽象方法和常量&#xff0c;不能有實現&#xff1b;而抽象類可以有抽象方法和具體實現的方法&#xff0c;也可以定義成員變量。 2、實現與繼承&#xff1a; 一個類…