cuda從零開始手搓PB神經網絡

cuda實現PB神經網絡


基于上一篇的矩陣點乘,實現了矩陣的加減乘除、函數調用等。并且復用之前元編程里面寫的梯度下降、Adam、NAdam優化方法。實現PB神經網絡如下:

#ifndef __BP_NETWORK_HPP__
#define __BP_NETWORK_HPP__
#include "matrix.hpp"
#include "mat.hpp"
#include "update_methods.hpp"template<typename activate_type, typename val_type_, template<typename> class update_type_tpl, typename init_type, int input_num_, int output_num_, int ... remain_layer>
struct bp_network
{constexpr static int input_num = input_num_;constexpr static int output_num = output_num_;using val_type = val_type_;using input_type = mat<input_num, 1, val_type>;using input_t_type = mat<1, input_num, val_type>;using output_type = mat<output_num, 1, val_type>;using weight_type = mat<output_num, input_num, val_type>;using forward_func = typename func_pair<activate_type>::forward_func;using backward_func = typename func_pair<activate_type>::backward_func;using next_node_type = typename bp_network<activate_type, val_type, update_type_tpl, init_type, output_num, remain_layer...>;using term_output_type = typename next_node_type::term_output_type;weight_type weight;update_type_tpl<weight_type> weight_update_method;output_type bias;update_type_tpl<output_type> bias_update_method;input_type pre_input;output_type pre_func_input;next_node_type next_node;bp_network():weight_update_method(), bias_update_method(){weight.template reset<init_type>();bias.template reset<init_type>();next_node = bp_network<activate_type, val_type, update_type_tpl, init_type, output_num, remain_layer...>();}auto forward(input_type& input){output_type curr_output;pre_input = input;auto temp = weight.dot(input);pre_func_input = temp + bias;curr_output = pre_func_input.template activate<forward_func>();return next_node.forward(curr_output);}auto backward(term_output_type& delta, val_type lr){output_type curr_delta = next_node.backward(delta, lr);curr_delta = pre_func_input.template activate<backward_func>() * curr_delta;auto ret = weight.t_dot(curr_delta);// 更新參數weight_type delta_weight = curr_delta.dot(pre_input.t());weight = weight_update_method.update(weight, delta_weight);bias = bias_update_method.update(bias, curr_delta);return ret;}   // 更新慣性量void update_inert(){weight_update_method.update_inert();bias_update_method.update_inert();next_node.update_inert();}void print(){weight.print();printf("-----------------\n");bias.print();printf("=================\n");next_node.print();}
};template<typename activate_type, typename val_type_, template<typename> class update_type_tpl, typename init_type, int input_num_, int output_num_>
struct bp_network<activate_type, val_type_, update_type_tpl, init_type, input_num_, output_num_>
{constexpr static int input_num = input_num_;constexpr static int output_num = output_num_;using val_type = val_type_;using input_type = mat<input_num, 1, val_type>;using input_t_type = mat<1, input_num, val_type>;using output_type = mat<output_num, 1, val_type>;using weight_type = mat<output_num, input_num, val_type>;using forward_func = typename func_pair<activate_type>::forward_func;using backward_func = typename func_pair<activate_type>::backward_func;using term_output_type = typename output_type;using weight_update_type = typename update_type_tpl<weight_type>;using bias_update_type = typename update_type_tpl<output_type>;weight_type weight;weight_update_type weight_update;output_type bias;bias_update_type bias_update;output_type pre_func_input;input_type pre_input;bp_network():weight_update(), bias_update(){weight.template reset<init_type>();bias.template reset<init_type>();}auto forward(input_type& input){pre_input = input;auto temp = weight.dot(input);pre_func_input = temp + bias;return pre_func_input.template activate<forward_func>();}auto backward(output_type& delta, val_type lr){output_type curr_delta = pre_func_input.template activate<backward_func>() * delta;auto ret = weight.t_dot(curr_delta);// 更新參數weight_type delta_weight = curr_delta.dot(pre_input.t());weight = weight_update.update(weight, delta_weight);bias = bias_update.update(bias, curr_delta);return ret;}void update_inert(){weight_update.update_inert();bias_update.update_inert();}void print(){weight.print();printf("-----------------\n");bias.print();printf("*****************\n");}
};#endif

下面實驗一下我們的bp神經網絡。

#include <chrono>
#include <thread>
#include "matrix.hpp"
#include "bp_network.hpp"
int main()
{constexpr int row_num = 32;constexpr int adj_num = 32;constexpr int col_num = 32;/*matrix_device_proxy<row_num, adj_num, double> A;eyes(A(), 2.0);matrix_device_proxy<adj_num, col_num, double> B;eyes(B(), 1.0);matrix_device_proxy<row_num, col_num, double> C;mat_dot<sigmoid>(A(), B(), C());print(type_cast(C()));auto A = mat<row_num, adj_num, double>::eyes(2.0);auto B = mat<adj_num, col_num, double>::eyes(1.0);auto C = A.dot(B);C = C + 1.0;C = sqrtl(C);C = C - 2.0;C = C * 3.0;C = C / 4.0;C.print();std::cout << "---------- D ----------" << std::endl;auto D = mat<row_num, col_num, double>::xavier_gaussian();D.print();std::cout << "---------- E ----------" << std::endl;auto E = mat<row_num, col_num, double>::xavier_mean();E.print();std::cout << "---------- F ----------" << std::endl;auto F = mat<row_num, col_num, double>::he_gaussian();F.print();std::cout << "---------- G ----------" << std::endl;auto G = mat<row_num, col_num, double>::he_mean();G.print();*/bp_network<sigmoid, double, nadam, xavier_gaussian_type, row_num, adj_num, col_num> node;auto input = mat<row_num, 1, double>::ones(0.2);auto expect = mat<col_num, 1, double>::ones(0.4);int times = 8000;int update_inert_times = 100;int step = times / update_inert_times;// 計時開始auto start = std::chrono::high_resolution_clock::now();for (int i = 0; i < times; ++i){auto output = node.forward(input);auto delta = (output - expect);node.backward(delta, 0.001);if (i == times - 1){output.t().print();}if (i % step == 0 && i != 0){node.update_inert();}}// 計時結束// 獲取結束時間點auto end = std::chrono::high_resolution_clock::now();// 計算持續時間std::chrono::duration<double> duration = end - start;// 輸出執行時間std::cout << "Execution time: " << duration.count() << " seconds" << std::endl;//node.print();cudaDeviceReset();return 0;
}

以上代碼有個學習率lr沒有地方設置哈,將來優化,見諒。執行結果如下:
在這里插入圖片描述
可以看出,經過8000次的訓練,這個使用sigmoid激活函數、NAdam優化、Xavier-Gaussian初始化的323232的PB能夠將誤差縮減到0.0001這個量級,而訓練時間僅為8.54秒。還是相當給力的。
雖然這對于我的工作沒有任何關系,但是我還是想搞一下。畢竟“越是沒用的知識就越有用,越是有用的東西就越沒用”。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/66381.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/66381.shtml
英文地址,請注明出處:http://en.pswp.cn/web/66381.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Next.js 實戰 (八):使用 Lodash 打包構建產生的“坑”?

前言 最近一直在折騰 Nextjs15 &#xff0c;也在斷斷續續地寫《Next.js15 實戰系列》的文章&#xff0c;后來總感覺文章如果沒有線上效果預覽差點意思&#xff0c;所以就想著先把目前做的項目先部署上線&#xff0c;后續再慢慢添加新功能。 因為之前沒有部署過 Nextjs15 工程…

我的世界-與門、或門、非門等基本門電路實現

一、紅石比較器 (1) 紅石比較器結構 紅石比較器有前端單火把、后端雙火把以及兩個側端 其中后端和側端是輸入信號,前端是輸出信號 (2) 紅石比較器的兩種模式 比較模式 前端火把未點亮時處于比較模式 側端>后端 → 0 當任一側端強度大于后端強度時,輸出…

項目開發實踐——基于SpringBoot+Vue3實現的在線考試系統(七)

文章目錄 一、題庫管理模塊實現1、新增題目功能實現1.1 頁面設計1.2 前端功能實現1.3 后端功能實現1.4 效果展示2、題目列表功能實現2.1 頁面設計2.2 前端功能實現2.3 后端功能實現2.3.1 后端查詢題目列表接口實現2.3.2 后端編輯試題接口實現2.4 效果展示二、代碼下載一、題庫管…

打破編程“鄙視鏈”:探索行業發展新路徑

前言&#xff1a;哈嘍&#xff0c;大家好&#xff0c;今天給大家分享一篇文章&#xff01;并提供具體代碼幫助大家深入理解&#xff0c;徹底掌握&#xff01;創作不易&#xff0c;如果能幫助到大家或者給大家一些靈感和啟發&#xff0c;歡迎收藏關注哦 &#x1f495; 目錄 打破…

【統計的思想】假設檢驗(一)

假設檢驗是統計學里的重要方法&#xff0c;同時也是一種“在理想與現實之間觀察求索”的測試活動。假設檢驗從概率的角度去考察理想與現實之間的關系&#xff0c;籍此來緩解測試可信性問題。 我們先來看一個例子。民航旅客服務系統&#xff0c;簡稱PSS系統&#xff0c;有一種業…

SpringBoot2 + Flowable(UI)

文章目錄 引言I 技術棧軟件架構基于 Vue.js 和 Element UI 的后臺管理系統工程結構II 依賴rest,logic,conf 的依賴工作流flowable jar包flowable-ui所需jar包III 配置jdbc 配置 nullCatalogMeansCurrent = true引言 I 技術棧 軟件架構 前端基于vue 、element-ui框架分模塊設…

Linux探秘坊-------3.開發工具詳解(1)

1 初識vim編輯器 創建第一個vim編輯的代碼 1.新建文件 2.使用vim打開 3.打開默認是命令模式&#xff0c;寫代碼需要在屏幕上輸出“i”字符 1.寫完代碼后要按Esc鍵退出到指令模式2.再按shift:wq即可保存并退出vim &#xff08;因為不支持鼠標&#xff0c;通常 使用鍵盤上的箭…

基于海思soc的智能產品開發(高、中、低soc、以及和fpga的搭配)

【 聲明&#xff1a;版權所有&#xff0c;歡迎轉載&#xff0c;請勿用于商業用途。 聯系信箱&#xff1a;feixiaoxing 163.com】 市場上關于圖像、音頻的soc其實非常多&#xff0c;這里面有高、中、低檔&#xff0c;開發方式也不相同。之所以會這樣&#xff0c;有價格的因素&am…

51單片機——DS18B20溫度傳感器

由于DS18B20數字溫度傳感器是單總線接口&#xff0c;所以需要使用51單片機的一個IO口模擬單總線時序與DS18B20通信&#xff0c;將檢測的環境溫度讀取出來 1、DS18B20模塊電路 傳感器接口的單總線管腳接至單片機P3.7IO口上 2、DS18B20介紹 2.1 DS18B20外觀實物圖 管腳1為GN…

STL容器-- list的模擬實現(附源碼)

STL容器-- list的模擬實現&#xff08;附源碼&#xff09; List的實現主要時考察我們對list這一容器的理解&#xff0c;和代碼的編寫能力&#xff0c;通過上節對list容器的使用&#xff0c;我們對list容器已經有了一些基本的了解&#xff0c;接下來就讓我們來實現一些list容器常…

Redis 學習指南與資料分享

Redis學習資料 Redis學習資料 Redis學習資料 Redis 作為一款高性能內存數據庫&#xff0c;在當今軟件開發領域占據著重要地位。其豐富的數據類型、強大的功能特性以及廣泛的應用場景&#xff0c;吸引著眾多開發者深入學習。以下為你精心整理的 Redis 學習指南與實用資料分享&…

Lynx TiDB 慢日志收集工具

作者&#xff1a; 小龍蝦愛大龍蝦 原文來源&#xff1a; https://tidb.net/blog/7247e68f 簡介 lynx 工具可以定時將 TiDB 集群的慢查詢收集并持久化到后端數據庫中&#xff0c;然后通過 grafana 查詢展示出來&#xff0c;這可以幫助我們更好的分析慢查詢日志。 背景 盡管…

Gin 源碼概覽 - 路由

本文基于gin 1.1 源碼解讀 https://github.com/gin-gonic/gin/archive/refs/tags/v1.1.zip 1. 注冊路由 我們先來看一段gin代碼&#xff0c;來看看最終得到的一顆路由樹長啥樣 func TestGinDocExp(t *testing.T) {engine : gin.Default()engine.GET("/api/user", f…

docker 基礎語法學習,K8s基礎語法學習,零基礎學習

下面是關于Docker和Kubernetes的基礎語法學習資料&#xff0c;包括一些關鍵概念和示例代碼。 Docker 基礎語法 1. 安裝 Docker 首先&#xff0c;你需要安裝 Docker。以下是不同操作系統上的安裝指南&#xff1a; Windows/Mac: 下載并安裝 Docker Desktop。 Linux: 根據你的…

【逆境中綻放:萬字回顧2024我在挑戰中突破自我】

&#x1f308;個人主頁: Aileen_0v0 &#x1f525;熱門專欄: 華為鴻蒙系統學習|計算機網絡|數據結構與算法 ?&#x1f4ab;個人格言:“沒有羅馬,那就自己創造羅馬~” 文章目錄 一、引言二、個人成長與盤點情感與心理成長學習與技能提升其它榮譽 三、年度創作歷程回顧創作內容概…

職場溝通與行為

職場溝通與行為 引言 在職場上&#xff0c;你是否曾遇到過困惑的溝通&#xff1f;是否對同事的行為有過疑慮&#xff1f;這不僅是個別現象&#xff0c;而是我們這個時代工作文化中的普遍問題。許多職場的摩擦&#xff0c;來自溝通不暢或是行為不當。那么&#xff0c;如何才能…

【Linux 重裝】Ubuntu 啟動盤 U盤無法被識別,如何處理?

背景 U盤燒錄了 Ubuntu 系統作為啟動盤&#xff0c;再次插入電腦后無法被識別 解決方案&#xff08;Mac 適用&#xff09; &#xff08;1&#xff09;查找 USB&#xff0c;&#xff08;2&#xff09;格式化&#xff08;1&#xff09;在 terminal 中通過 diskutil list 查看是…

中職網絡建設與運維ansible服務

ansible服務 填寫hosts指定主機范圍和控制節點后創建一個腳本&#xff0c;可以利用簡化腳本 1. 在linux1上安裝系統自帶的ansible-core,作為ansible控制節點,linux2-linux7作為ansible的受控節點 Linux1 Linux1-7 Yum install ansible-core -y Vi /etc/ansible/hosts 添加…

數據庫服務體系結構

1. 數據庫服務應用配置 服務進行配置有什么作用&#xff1f; 實現服務運行啟動 實現某些功能 應用配置有三種方式&#xff1f; 利用編譯安裝進行配置 編寫配置文件信息 ,.默認的配置文件: /etc/my.cnf 利用啟動命令參數配置信息&#xff0c;mysqld_safe --skip-grant-tables --…

Langchain+FastApi+Vue前后端Ai對話(超詳細)

一、引入 首先可以先看下作者的文章 FastApi相關文章&#xff1a;創建最簡單FastApi的項目Vue相關文章&#xff1a;最簡單的aixos二次封裝Langchain相關文章&#xff1a;如何使用LangSmith跟蹤deepseek模型 二、后端搭建 1 項目文件結構 routers&#xff1a;存放api接口se…