介紹 C++ 中的智能指針及其應用:以 PyTorch框架自動梯度AutogradMeta為例

介紹 C++ 中的智能指針及其應用:以 AutogradMeta 為例

在 C++ 中,智能指針(Smart Pointer)是用于管理動態分配內存的一種工具。它們不僅自動管理內存的生命周期,還能幫助避免內存泄漏和野指針等問題。在深度學習框架如 PyTorch 的實現中,智能指針被廣泛應用于復雜的數據結構和計算圖的管理中。本文將結合 AutogradMeta 類,詳細介紹 C++ 中的智能指針,解釋 std::shared_ptrstd::weak_ptrstd::unique_ptr 等智能指針的使用場景及區別。

Source: https://github.com/pytorch/pytorch/blob/00df63f09f07546bacec734f37132edc58ccf574/torch/csrc/autograd/variable.h#L102

//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
//                            AutogradMeta
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~/// Each `Variable` has one unique `AutogradMeta` struct, which stores autograd
/// metadata fields that are necessary for tracking the Variable's autograd
/// history. As an optimization, a Variable may store a nullptr, in lieu of a
/// default constructed AutogradMeta.struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {std::string name_;Variable grad_;std::shared_ptr<Node> grad_fn_;std::weak_ptr<Node> grad_accumulator_;// This field is used to store all the forward AD gradients// associated with this AutogradMeta (and the Tensor it corresponds to)// There is a semantic 1:1 correspondence between AutogradMeta and// ForwardGrad but://   - This field is lazily populated.//   - This field is a shared_ptr but it must never be//     shared by multiple Tensors. See Note [ Using ForwardGrad ]// Any transition from not_initialized to initialized// must be protected by mutex_mutable std::shared_ptr<ForwardGrad> fw_grad_;// The hooks_ field is actually reused by both python and cpp logic// For both cases, we have a data structure, cpp_hooks_list_ (cpp)// or dict (python) which is the canonical copy.// Then, for both cases, we always register a single hook to// hooks_ which wraps all the hooks in the list/dict.// And, again in both cases, if the grad_fn exists on that tensor// we will additionally register a single hook to the grad_fn.//// Note that the cpp and python use cases aren't actually aware of// each other, so using both is not defined behavior.std::vector<std::unique_ptr<FunctionPreHook>> hooks_;std::shared_ptr<hooks_list> cpp_hooks_list_;// The post_acc_grad_hooks_ field stores only Python hooks// (PyFunctionTensorPostAccGradHooks) that are called after the// .grad field has been accumulated into. This is less complicated// than the hooks_ field, which encapsulates a lot more.std::unique_ptr<PostAccumulateGradHook> post_acc_grad_hooks_ = nullptr;// Only meaningful on leaf variables (must be false otherwise)bool requires_grad_{false};// Only meaningful on non-leaf variables (must be false otherwise)bool retains_grad_{false};bool is_view_{false};// The "output number" of this variable; e.g., if this variable// was the second output of a function, then output_nr == 1.// We use this to make sure we can setup the backwards trace// correctly when this variable is passed to another function.uint32_t output_nr_;// Mutex to ensure that concurrent read operations that modify internal// state are still thread-safe. Used by grad_fn(), grad_accumulator(),// fw_grad() and set_fw_grad()// This is mutable because we need to be able to acquire this from const// version of this class for the functions abovemutable std::mutex mutex_;/// Sets the `requires_grad` property of `Variable`. This should be true for/// leaf variables that want to accumulate gradients, and false for all other/// variables.void set_requires_grad(bool requires_grad, at::TensorImpl* self_impl) final {TORCH_CHECK(!requires_grad ||isDifferentiableType(at::typeMetaToScalarType(self_impl->dtype())),"Only Tensors of floating point and complex dtype can require gradients");requires_grad_ = requires_grad;}bool requires_grad() const override {return requires_grad_ || grad_fn_;}/// Accesses the gradient `Variable` of this `Variable`.Variable& mutable_grad() override {return grad_;}const Variable& grad() const override {return grad_;}const Variable& fw_grad(uint64_t level, const at::TensorBase& self)const override;void set_fw_grad(const at::TensorBase& new_grad,const at::TensorBase& self,uint64_t level,bool is_inplace_op) override;AutogradMeta(at::TensorImpl* self_impl = nullptr,bool requires_grad = false,Edge gradient_edge = Edge()): grad_fn_(std::move(gradient_edge.function)),output_nr_(gradient_edge.input_nr) {// set_requires_grad also checks error conditions.if (requires_grad) {TORCH_INTERNAL_ASSERT(self_impl);set_requires_grad(requires_grad, self_impl);}TORCH_CHECK(!grad_fn_ || !requires_grad_,"requires_grad should be false if grad_fn is set");}~AutogradMeta() override {// If AutogradMeta is being destroyed, it means that there is no other// reference to its corresponding Tensor. It implies that no other thread// can be using this object and so there is no need to lock mutex_ here to// guard the check if fw_grad_ is populated.if (fw_grad_) {// See note [ Using ForwardGrad ]fw_grad_->clear();}}
};
1. AutogradMeta 類中的智能指針

AutogradMeta 是一個用于存儲與自動求導(Autograd)相關元數據的數據結構。它包含了多種智能指針,例如:

  • std::shared_ptr<Node> grad_fn_;
  • std::weak_ptr<Node> grad_accumulator_;
  • mutable std::shared_ptr<ForwardGrad> fw_grad_;
  • std::unique_ptr<PostAccumulateGradHook> post_acc_grad_hooks_;

這些智能指針的應用各有不同,它們的主要作用是管理計算圖中的節點、梯度、鉤子函數等數據結構的生命周期。

2. std::shared_ptrstd::weak_ptr 的區別

首先,讓我們從 std::shared_ptr<Node>std::weak_ptr<Node> 開始講解。

  • std::shared_ptr<Node> grad_fn_; 是一個共享指針,表示該對象(grad_fn_)的所有者。一個 std::shared_ptr 會通過引用計數來管理對象的生命周期。當一個 shared_ptr 被復制時,引用計數會增加,而當指針超出作用域或被重置時,引用計數會減少,直到計數為 0 時對象會被銷毀。

    示例代碼:

    std::shared_ptr<Node> grad_fn = std::make_shared<Node>();
    // 在此,grad_fn 是 Node 類型對象的所有者
    

    應用場景:

    • std::shared_ptr 適用于需要共享資源所有權的場景。比如,在 AutogradMeta 中,grad_fn_ 指向的是梯度計算的計算圖節點,該節點可能會被多個 Variable 共享,因此使用 std::shared_ptr 可以確保計算圖在不再使用時被自動銷毀。
  • std::weak_ptr<Node> grad_accumulator_; 是一個弱指針,通常與 std::shared_ptr 配合使用。它不會影響對象的引用計數,因此不會阻止對象的銷毀。std::weak_ptr 適用于觀察共享資源但不擁有其所有權的場景。

    示例代碼:

    std::shared_ptr<Node> shared_ptr_node = std::make_shared<Node>();
    std::weak_ptr<Node> weak_ptr_node = shared_ptr_node;
    

    應用場景:

    • std::weak_ptr 常用于防止循環引用。在 AutogradMeta 中,grad_accumulator_ 可能指向一個梯度累加器對象,但我們并不想讓它擁有該對象的所有權,因此使用 std::weak_ptr。這樣,當沒有任何 shared_ptr 指向該對象時,累加器會被銷毀,避免內存泄漏。
3. mutable std::shared_ptr<ForwardGrad> fw_grad_;

mutable 關鍵字在這里的作用是允許即使在 const 對象上也能修改 fw_grad_ 成員變量。在 AutogradMeta 中,fw_grad_ 用于存儲與正向自動求導相關的梯度。由于該對象的生命周期是動態管理的,所以它使用了 std::shared_ptr

示例代碼:

mutable std::shared_ptr<ForwardGrad> fw_grad_;

應用場景:

  • AutogradMeta 類中,fw_grad_ 可能在對象生命周期內多次更新,因此需要一個 std::shared_ptr 來管理其內存。同時,mutable 允許即使 AutogradMeta 對象是 const 類型時,也可以修改 fw_grad_,這對線程安全和優化非常重要。
4. std::unique_ptr<PostAccumulateGradHook> post_acc_grad_hooks_;

std::unique_ptr 是獨占指針,表示某個資源只能由一個指針管理。當 std::unique_ptr 被銷毀時,它所管理的資源會被釋放。

示例代碼:

std::unique_ptr<PostAccumulateGradHook> post_acc_grad_hooks_ = nullptr;

應用場景:

  • AutogradMeta 中,post_acc_grad_hooks_ 用于存儲 Python 特定的鉤子函數。這些鉤子函數會在梯度累加后執行,因此使用 std::unique_ptr 確保鉤子對象的獨占管理,避免多個指針同時擁有該對象的所有權。
5. std::shared_ptrstd::weak_ptrstd::unique_ptr 對比
智能指針類型主要特點使用場景
std::shared_ptr共享所有權,通過引用計數管理資源的生命周期;多個指針可以共享資源用于共享資源所有權,確保資源在最后一個指針被銷毀時被釋放。
std::weak_ptr不增加資源的引用計數,不控制資源的生命周期;可以觀察資源用于避免循環引用和觀察對象的生命周期。
std::unique_ptr獨占所有權,確保資源只被一個指針管理用于資源的獨占管理,確保資源在超出作用域時被釋放。
6. 總結與應用場景

在 C++ 中,智能指針是非常強大的工具,可以有效避免內存泄漏、野指針和循環引用等問題。std::shared_ptrstd::weak_ptrstd::unique_ptr 各有各的特點,能夠應對不同的資源管理需求。結合 AutogradMeta 這樣的復雜數據結構,智能指針幫助我們確保計算圖、梯度和鉤子等資源的安全管理。

  • std::shared_ptr 適用于需要共享資源所有權的場景,如計算圖的節點。
  • std::weak_ptr 適用于觀察資源但不控制其生命周期的場景,如梯度累加器。
  • std::unique_ptr 適用于獨占資源所有權的場景,如梯度累加后的鉤子函數。

通過合理選擇智能指針類型,能夠顯著提升代碼的安全性和可維護性,減少內存管理上的錯誤。


以上就是對 C++ 中智能指針的詳細介紹及其在 AutogradMeta 類中的應用。希望通過這個例子,讀者能夠更加清晰地理解智能指針的區別及其適用場景。

附錄:override關鍵字

override 關鍵字詳解

在 C++ 中,override 是一個用于顯式聲明虛函數重寫的關鍵字。它告訴編譯器,當前成員函數是用來重寫基類中的虛函數的。如果基類中沒有對應的虛函數,編譯器將生成一個錯誤,從而幫助開發者捕獲潛在的錯誤。

語法和作用

在類的成員函數后面加上 override,表示該函數是重寫了基類中的一個虛函數。如果基類中沒有定義該函數,或者該函數的簽名不匹配,編譯器將報錯。

語法示例:

class Base {
public:virtual void foo() {// 基類的實現}
};class Derived : public Base {
public:void foo() override { // 重寫基類的 foo 函數// 派生類的實現}
};

在上面的代碼中,Derived 類中的 foo 函數用 override 關鍵字顯式地聲明為重寫基類 Base 中的 foo 函數。如果基類的 foo 函數沒有定義為虛函數,或者派生類中的 foo 函數簽名與基類的不一致,編譯器會給出錯誤。

為什么要使用 override
  1. 避免拼寫錯誤和簽名錯誤: 使用 override 可以幫助程序員確保派生類中的函數簽名完全匹配基類中的虛函數簽名。如果有拼寫錯誤或簽名不匹配,編譯器會在編譯時提醒我們,避免在運行時遇到潛在的問題。

  2. 提高代碼可讀性: override 顯示了一個函數是基類虛函數的重寫,有助于代碼閱讀者理解該函數是被派生類特意重寫的,而不是無意間添加的。

  3. 增強可維護性: 如果將來基類的虛函數發生了修改,override 可以幫助發現派生類中需要更新的地方,從而避免一些潛在的bug。

overrideset_fw_grad 中的應用

在您提供的代碼片段中:

void set_fw_grad(const at::TensorBase& new_grad,const at::TensorBase& self,uint64_t level,bool is_inplace_op) override;

override 表示 set_fw_grad 函數重寫了基類中的一個虛函數。這個函數的作用可能是設置正向梯度(forward gradient)。如果基類中沒有定義 set_fw_grad 或其簽名不同,編譯器會報錯,提醒開發者檢查是否正確實現了虛函數。

示例:虛函數重寫的完整示例
#include <iostream>class Base {
public:// 聲明一個虛函數virtual void set_fw_grad(const std::string& new_grad) {std::cout << "Base class set_fw_grad: " << new_grad << std::endl;}
};class Derived : public Base {
public:// 重寫基類的虛函數,并加上 override 關鍵字void set_fw_grad(const std::string& new_grad) override {std::cout << "Derived class set_fw_grad: " << new_grad << std::endl;}
};int main() {Base* obj = new Derived();obj->set_fw_grad("Gradient Data"); // 調用的是 Derived 類的重寫函數delete obj;return 0;
}

輸出:

Derived class set_fw_grad: Gradient Data
總結
  • override 是 C++11 引入的一個關鍵字,用于顯式標識派生類中的成員函數重寫了基類中的虛函數。
  • 它增強了代碼的安全性,幫助避免常見的編程錯誤,如函數簽名不匹配等問題。
  • set_fw_grad 中,override 確保該函數是正確地重寫了基類的虛函數。如果基類中沒有定義相應的虛函數,編譯器會發出錯誤提示。

后記

2025年1月3日15點33分于上海,在GPT4o大模型輔助下完成。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/64915.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/64915.shtml
英文地址,請注明出處:http://en.pswp.cn/web/64915.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

python +t kinter繪制彩虹和云朵

python t kinter繪制彩虹和云朵 彩虹&#xff0c;簡稱虹&#xff0c;是氣象中的一種光學現象&#xff0c;當太陽光照射到半空中的水滴&#xff0c;光線被折射及反射&#xff0c;在天空上形成拱形的七彩光譜&#xff0c;由外圈至內圈呈紅、橙、黃、綠、藍、靛、紫七種顏色。事實…

Zabbix5.0版本(監控Nginx+PHP服務狀態信息)

目錄 1.監控Nginx服務狀態信息 &#xff08;1&#xff09;通過Nginx監控模塊&#xff0c;監控Nginx的7種狀態 &#xff08;2&#xff09;開啟Nginx狀態模塊 &#xff08;3&#xff09;配置監控項 &#xff08;4&#xff09;創建模板 &#xff08;5&#xff09;用默認鍵值…

Python入門教程 —— 字符串

字符串介紹 字符串可以理解為一段普通的文本內容,在python里,使用引號來表示一個字符串,不同的引號表示的效果會有區別。 字符串表示方式 a = "Im Tom" # 一對雙引號 b = Tom said:"I am Tom" # 一對單引號c = Tom said:"I\m Tom" # 轉義…

AcWing練習題:差

讀取四個整數 A,B,C,D&#xff0c;并計算 (AB?CD)的值。 輸入格式 輸入共四行&#xff0c;第一行包含整數 A&#xff0c;第二行包含整數 B&#xff0c;第三行包含整數 C&#xff0c;第四行包含整數 D。 輸出格式 輸出格式為 DIFERENCA X&#xff0c;其中 X 為 (AB?CD) 的…

小程序添加購物車業務邏輯

數據庫設計 DTO設計 實現步驟 1 判斷當前加入購物車中的的商品是否已經存在了 2 如果已經存在 只需要將數量加一 3 如果不存在 插入一條購物車數據 4 判斷加到本次購物車的是菜品還是套餐 Impl代碼實現 Service public class ShoppingCartServiceImpl implements Shoppin…

如何在谷歌瀏覽器中使用自定義搜索快捷方式

在數字時代&#xff0c;瀏覽器已經成為我們日常生活中不可或缺的一部分。作為最常用的瀏覽器之一&#xff0c;谷歌瀏覽器憑借其簡潔的界面和強大的功能深受用戶喜愛。本文將詳細介紹如何自定義谷歌瀏覽器的快捷工具欄&#xff0c;幫助你更高效地使用這一工具。 一、如何找到谷歌…

Python 3 與 Python 2 的主要區別

文章目錄 1. 語法與關鍵字print 函數整數除法 2. 字符串處理默認字符串類型字符串格式化 3. 輸入函數4. 迭代器和生成器range 函數map, filter, zip 5. 標準庫變化urllib 模塊configparser 模塊 6. 異常處理7. 移除的功能8. 其他重要改進數據庫操作多線程與并發類型注解 9. 總結…

關于IDE的相關知識之二【插件推薦】

成長路上不孤單&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a; 【14后&#x1f60a;///計算機愛好者&#x1f60a;///持續分享所學&#x1f60a;///如有需要歡迎收藏轉發///&#x1f60a;】 今日分享關于ide插件推薦的相關內容&#xff01…

如何獲取穩定高效的動態代理?

在數據采集的領域&#xff0c;動態代理IP是我們探索網絡世界的小助手&#xff0c;它不僅幫助我們高效地收集信息&#xff0c;還能在保護數據安全方面發揮重要作用。但如何在眾多選擇中找到最適合的那個——即穩定且高效的動態代理也是一大難題。 明確你的需求 首先&#xff0…

基于微信小程序的校園點餐平臺的設計與實現(源碼+SQL+LW+部署講解)

文章目錄 摘 要1. 第1章 選題背景及研究意義1.1 選題背景1.2 研究意義1.3 論文結構安排 2. 第2章 相關開發技術2.1 前端技術2.2 后端技術2.3 數據庫技術 3. 第3章 可行性及需求分析3.1 可行性分析3.2 系統需求分析 4. 第4章 系統概要設計4.1 系統功能模塊設計4.2 數據庫設計 5.…

原生js封裝ajax請求以及css實現提示效果和禁止點擊效果

<!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0,user-scalableno"><title>本地模式網絡切換</title>&l…

Pytorch的自動求導模塊

文章目錄 torch.autograd.backward()基本用法非標量張量的反向傳播保留計算圖指定輸入張量高階梯度計算 與 y.backward() 的區別torch.autograd.grad()基本用法非標量張量的梯度高階梯度計算多輸入、多輸出的梯度計算未使用的輸入張量保留計算圖 與 backward() 的區別 torch.au…

Mac OS

本文來自智譜清言 ------ Mac OS&#xff08;現稱為macOS&#xff09;是蘋果公司開發和銷售的操作系統&#xff0c;自1984年推出以來&#xff0c;它已經經歷了多次重大的演變和發展。 起源&#xff1a;Mac OS 1.0的誕生 - 1984年&#xff0c;蘋果發布了Macintosh計算機&#…

spring中使用@Validated,什么是JSR 303數據校驗,spring boot中怎么使用數據校驗

文章目錄 一、JSR 303后臺數據校驗1.1 什么是 JSR303&#xff1f;1.2 為什么使用 JSR 303&#xff1f; 二、Spring Boot 中使用數據校驗2.1 基本注解校驗2.1.1 使用步驟2.1.2 舉例Valid注解全局統一異常處理 2.2 分組校驗2.2.1 使用步驟2.2.2 舉例Validated注解Validated和Vali…

ubuntu常用快捷鍵和變量記錄

alias b‘cd …/’ alias bb‘cd …/…/’ alias bbb‘cd …/…/…/’ alias bbbb‘cd …/…/…/…/’ alias bbbbb‘cd …/…/…/…/…/’ alias bbbbbb‘cd …/…/…/…/…/…/’ alias apkinfo‘aapt dump badging’ alias npp‘notepad-plus-plus’ export ANDROID_HOME/h…

AWS S3文件存儲工具類

pom依賴 <!--aws-s3--> <dependency><groupId>com.amazonaws</groupId><artifactId>aws-java-sdk-s3</artifactId><version>1.12.95</version></dependency>S3Utils import cn.hutool.core.util.ZipUtil; import com.a…

【SOC 芯片設計 DFT 學習專欄 -- 測試向量生成 ATPG (Automatic Test Pattern Generation) 】

文章目錄 OverviewATPG 的基本功能ATPG 的工作流程ATPG 應用場景示例示例 1&#xff1a;檢測單個信號的 Stuck-at Fault示例 2&#xff1a;針對 Transition Fault 的 ATPG ATPG 工具與常用工具鏈ATPG 優化與挑戰 Overview 本文主要介紹 DFT scan 中的 ATPG 功能。在 DFT (Desi…

2024 高通邊緣智能創新應用大賽智能邊緣計算賽道冠軍方案解讀

2024 高通邊緣智能創新應用大賽聚焦不同細分領域的邊緣智能創新應用落地&#xff0c;共設立三大熱門領域賽道——工業智能質檢賽道、智能邊緣計算賽道和智能機器人賽道。本文為智能邊緣計算賽道冠軍項目《端側大模型智能翻譯機》的開發思路與成果分享。 賽題要求 聚焦邊緣智能…

【Python運維】用Python和Ansible實現高效的自動化服務器配置管理

《Python OpenCV從菜鳥到高手》帶你進入圖像處理與計算機視覺的大門! 解鎖Python編程的無限可能:《奇妙的Python》帶你漫游代碼世界 隨著云計算和大規模數據中心的興起,自動化配置管理已經成為現代IT運維中不可或缺的一部分。通過自動化,企業可以大幅提高效率,降低人為錯…

微信小程序獲取后端數據

在小程序中獲取后端接口數據 通常可以使用 wx.request 方法&#xff0c;以下是一個基本示例&#xff1a; // pages/index/index.js Page({data: {// 用于存儲后端返回的數據resultData: [] },onLoad() {this.fetchData();},fetchData() {wx.request({url: https://your-backe…