介紹 C++ 中的智能指針及其應用:以 AutogradMeta
為例
在 C++ 中,智能指針(Smart Pointer)是用于管理動態分配內存的一種工具。它們不僅自動管理內存的生命周期,還能幫助避免內存泄漏和野指針等問題。在深度學習框架如 PyTorch 的實現中,智能指針被廣泛應用于復雜的數據結構和計算圖的管理中。本文將結合 AutogradMeta
類,詳細介紹 C++ 中的智能指針,解釋 std::shared_ptr
、std::weak_ptr
、std::unique_ptr
等智能指針的使用場景及區別。
Source: https://github.com/pytorch/pytorch/blob/00df63f09f07546bacec734f37132edc58ccf574/torch/csrc/autograd/variable.h#L102
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// AutogradMeta
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~/// Each `Variable` has one unique `AutogradMeta` struct, which stores autograd
/// metadata fields that are necessary for tracking the Variable's autograd
/// history. As an optimization, a Variable may store a nullptr, in lieu of a
/// default constructed AutogradMeta.struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {std::string name_;Variable grad_;std::shared_ptr<Node> grad_fn_;std::weak_ptr<Node> grad_accumulator_;// This field is used to store all the forward AD gradients// associated with this AutogradMeta (and the Tensor it corresponds to)// There is a semantic 1:1 correspondence between AutogradMeta and// ForwardGrad but:// - This field is lazily populated.// - This field is a shared_ptr but it must never be// shared by multiple Tensors. See Note [ Using ForwardGrad ]// Any transition from not_initialized to initialized// must be protected by mutex_mutable std::shared_ptr<ForwardGrad> fw_grad_;// The hooks_ field is actually reused by both python and cpp logic// For both cases, we have a data structure, cpp_hooks_list_ (cpp)// or dict (python) which is the canonical copy.// Then, for both cases, we always register a single hook to// hooks_ which wraps all the hooks in the list/dict.// And, again in both cases, if the grad_fn exists on that tensor// we will additionally register a single hook to the grad_fn.//// Note that the cpp and python use cases aren't actually aware of// each other, so using both is not defined behavior.std::vector<std::unique_ptr<FunctionPreHook>> hooks_;std::shared_ptr<hooks_list> cpp_hooks_list_;// The post_acc_grad_hooks_ field stores only Python hooks// (PyFunctionTensorPostAccGradHooks) that are called after the// .grad field has been accumulated into. This is less complicated// than the hooks_ field, which encapsulates a lot more.std::unique_ptr<PostAccumulateGradHook> post_acc_grad_hooks_ = nullptr;// Only meaningful on leaf variables (must be false otherwise)bool requires_grad_{false};// Only meaningful on non-leaf variables (must be false otherwise)bool retains_grad_{false};bool is_view_{false};// The "output number" of this variable; e.g., if this variable// was the second output of a function, then output_nr == 1.// We use this to make sure we can setup the backwards trace// correctly when this variable is passed to another function.uint32_t output_nr_;// Mutex to ensure that concurrent read operations that modify internal// state are still thread-safe. Used by grad_fn(), grad_accumulator(),// fw_grad() and set_fw_grad()// This is mutable because we need to be able to acquire this from const// version of this class for the functions abovemutable std::mutex mutex_;/// Sets the `requires_grad` property of `Variable`. This should be true for/// leaf variables that want to accumulate gradients, and false for all other/// variables.void set_requires_grad(bool requires_grad, at::TensorImpl* self_impl) final {TORCH_CHECK(!requires_grad ||isDifferentiableType(at::typeMetaToScalarType(self_impl->dtype())),"Only Tensors of floating point and complex dtype can require gradients");requires_grad_ = requires_grad;}bool requires_grad() const override {return requires_grad_ || grad_fn_;}/// Accesses the gradient `Variable` of this `Variable`.Variable& mutable_grad() override {return grad_;}const Variable& grad() const override {return grad_;}const Variable& fw_grad(uint64_t level, const at::TensorBase& self)const override;void set_fw_grad(const at::TensorBase& new_grad,const at::TensorBase& self,uint64_t level,bool is_inplace_op) override;AutogradMeta(at::TensorImpl* self_impl = nullptr,bool requires_grad = false,Edge gradient_edge = Edge()): grad_fn_(std::move(gradient_edge.function)),output_nr_(gradient_edge.input_nr) {// set_requires_grad also checks error conditions.if (requires_grad) {TORCH_INTERNAL_ASSERT(self_impl);set_requires_grad(requires_grad, self_impl);}TORCH_CHECK(!grad_fn_ || !requires_grad_,"requires_grad should be false if grad_fn is set");}~AutogradMeta() override {// If AutogradMeta is being destroyed, it means that there is no other// reference to its corresponding Tensor. It implies that no other thread// can be using this object and so there is no need to lock mutex_ here to// guard the check if fw_grad_ is populated.if (fw_grad_) {// See note [ Using ForwardGrad ]fw_grad_->clear();}}
};
1. AutogradMeta
類中的智能指針
AutogradMeta
是一個用于存儲與自動求導(Autograd)相關元數據的數據結構。它包含了多種智能指針,例如:
std::shared_ptr<Node> grad_fn_;
std::weak_ptr<Node> grad_accumulator_;
mutable std::shared_ptr<ForwardGrad> fw_grad_;
std::unique_ptr<PostAccumulateGradHook> post_acc_grad_hooks_;
這些智能指針的應用各有不同,它們的主要作用是管理計算圖中的節點、梯度、鉤子函數等數據結構的生命周期。
2. std::shared_ptr
和 std::weak_ptr
的區別
首先,讓我們從 std::shared_ptr<Node>
和 std::weak_ptr<Node>
開始講解。
-
std::shared_ptr<Node> grad_fn_;
是一個共享指針,表示該對象(grad_fn_
)的所有者。一個std::shared_ptr
會通過引用計數來管理對象的生命周期。當一個shared_ptr
被復制時,引用計數會增加,而當指針超出作用域或被重置時,引用計數會減少,直到計數為 0 時對象會被銷毀。示例代碼:
std::shared_ptr<Node> grad_fn = std::make_shared<Node>(); // 在此,grad_fn 是 Node 類型對象的所有者
應用場景:
std::shared_ptr
適用于需要共享資源所有權的場景。比如,在AutogradMeta
中,grad_fn_
指向的是梯度計算的計算圖節點,該節點可能會被多個Variable
共享,因此使用std::shared_ptr
可以確保計算圖在不再使用時被自動銷毀。
-
std::weak_ptr<Node> grad_accumulator_;
是一個弱指針,通常與std::shared_ptr
配合使用。它不會影響對象的引用計數,因此不會阻止對象的銷毀。std::weak_ptr
適用于觀察共享資源但不擁有其所有權的場景。示例代碼:
std::shared_ptr<Node> shared_ptr_node = std::make_shared<Node>(); std::weak_ptr<Node> weak_ptr_node = shared_ptr_node;
應用場景:
std::weak_ptr
常用于防止循環引用。在AutogradMeta
中,grad_accumulator_
可能指向一個梯度累加器對象,但我們并不想讓它擁有該對象的所有權,因此使用std::weak_ptr
。這樣,當沒有任何shared_ptr
指向該對象時,累加器會被銷毀,避免內存泄漏。
3. mutable std::shared_ptr<ForwardGrad> fw_grad_;
mutable
關鍵字在這里的作用是允許即使在 const
對象上也能修改 fw_grad_
成員變量。在 AutogradMeta
中,fw_grad_
用于存儲與正向自動求導相關的梯度。由于該對象的生命周期是動態管理的,所以它使用了 std::shared_ptr
。
示例代碼:
mutable std::shared_ptr<ForwardGrad> fw_grad_;
應用場景:
- 在
AutogradMeta
類中,fw_grad_
可能在對象生命周期內多次更新,因此需要一個std::shared_ptr
來管理其內存。同時,mutable
允許即使AutogradMeta
對象是const
類型時,也可以修改fw_grad_
,這對線程安全和優化非常重要。
4. std::unique_ptr<PostAccumulateGradHook> post_acc_grad_hooks_;
std::unique_ptr
是獨占指針,表示某個資源只能由一個指針管理。當 std::unique_ptr
被銷毀時,它所管理的資源會被釋放。
示例代碼:
std::unique_ptr<PostAccumulateGradHook> post_acc_grad_hooks_ = nullptr;
應用場景:
- 在
AutogradMeta
中,post_acc_grad_hooks_
用于存儲 Python 特定的鉤子函數。這些鉤子函數會在梯度累加后執行,因此使用std::unique_ptr
確保鉤子對象的獨占管理,避免多個指針同時擁有該對象的所有權。
5. std::shared_ptr
、std::weak_ptr
、std::unique_ptr
對比
智能指針類型 | 主要特點 | 使用場景 |
---|---|---|
std::shared_ptr | 共享所有權,通過引用計數管理資源的生命周期;多個指針可以共享資源 | 用于共享資源所有權,確保資源在最后一個指針被銷毀時被釋放。 |
std::weak_ptr | 不增加資源的引用計數,不控制資源的生命周期;可以觀察資源 | 用于避免循環引用和觀察對象的生命周期。 |
std::unique_ptr | 獨占所有權,確保資源只被一個指針管理 | 用于資源的獨占管理,確保資源在超出作用域時被釋放。 |
6. 總結與應用場景
在 C++ 中,智能指針是非常強大的工具,可以有效避免內存泄漏、野指針和循環引用等問題。std::shared_ptr
、std::weak_ptr
和 std::unique_ptr
各有各的特點,能夠應對不同的資源管理需求。結合 AutogradMeta
這樣的復雜數據結構,智能指針幫助我們確保計算圖、梯度和鉤子等資源的安全管理。
std::shared_ptr
適用于需要共享資源所有權的場景,如計算圖的節點。std::weak_ptr
適用于觀察資源但不控制其生命周期的場景,如梯度累加器。std::unique_ptr
適用于獨占資源所有權的場景,如梯度累加后的鉤子函數。
通過合理選擇智能指針類型,能夠顯著提升代碼的安全性和可維護性,減少內存管理上的錯誤。
以上就是對 C++ 中智能指針的詳細介紹及其在 AutogradMeta
類中的應用。希望通過這個例子,讀者能夠更加清晰地理解智能指針的區別及其適用場景。
附錄:override關鍵字
override
關鍵字詳解
在 C++ 中,override
是一個用于顯式聲明虛函數重寫的關鍵字。它告訴編譯器,當前成員函數是用來重寫基類中的虛函數的。如果基類中沒有對應的虛函數,編譯器將生成一個錯誤,從而幫助開發者捕獲潛在的錯誤。
語法和作用
在類的成員函數后面加上 override
,表示該函數是重寫了基類中的一個虛函數。如果基類中沒有定義該函數,或者該函數的簽名不匹配,編譯器將報錯。
語法示例:
class Base {
public:virtual void foo() {// 基類的實現}
};class Derived : public Base {
public:void foo() override { // 重寫基類的 foo 函數// 派生類的實現}
};
在上面的代碼中,Derived
類中的 foo
函數用 override
關鍵字顯式地聲明為重寫基類 Base
中的 foo
函數。如果基類的 foo
函數沒有定義為虛函數,或者派生類中的 foo
函數簽名與基類的不一致,編譯器會給出錯誤。
為什么要使用 override
?
-
避免拼寫錯誤和簽名錯誤: 使用
override
可以幫助程序員確保派生類中的函數簽名完全匹配基類中的虛函數簽名。如果有拼寫錯誤或簽名不匹配,編譯器會在編譯時提醒我們,避免在運行時遇到潛在的問題。 -
提高代碼可讀性:
override
顯示了一個函數是基類虛函數的重寫,有助于代碼閱讀者理解該函數是被派生類特意重寫的,而不是無意間添加的。 -
增強可維護性: 如果將來基類的虛函數發生了修改,
override
可以幫助發現派生類中需要更新的地方,從而避免一些潛在的bug。
override
在 set_fw_grad
中的應用
在您提供的代碼片段中:
void set_fw_grad(const at::TensorBase& new_grad,const at::TensorBase& self,uint64_t level,bool is_inplace_op) override;
override
表示 set_fw_grad
函數重寫了基類中的一個虛函數。這個函數的作用可能是設置正向梯度(forward gradient)。如果基類中沒有定義 set_fw_grad
或其簽名不同,編譯器會報錯,提醒開發者檢查是否正確實現了虛函數。
示例:虛函數重寫的完整示例
#include <iostream>class Base {
public:// 聲明一個虛函數virtual void set_fw_grad(const std::string& new_grad) {std::cout << "Base class set_fw_grad: " << new_grad << std::endl;}
};class Derived : public Base {
public:// 重寫基類的虛函數,并加上 override 關鍵字void set_fw_grad(const std::string& new_grad) override {std::cout << "Derived class set_fw_grad: " << new_grad << std::endl;}
};int main() {Base* obj = new Derived();obj->set_fw_grad("Gradient Data"); // 調用的是 Derived 類的重寫函數delete obj;return 0;
}
輸出:
Derived class set_fw_grad: Gradient Data
總結
override
是 C++11 引入的一個關鍵字,用于顯式標識派生類中的成員函數重寫了基類中的虛函數。- 它增強了代碼的安全性,幫助避免常見的編程錯誤,如函數簽名不匹配等問題。
- 在
set_fw_grad
中,override
確保該函數是正確地重寫了基類的虛函數。如果基類中沒有定義相應的虛函數,編譯器會發出錯誤提示。
后記
2025年1月3日15點33分于上海,在GPT4o大模型輔助下完成。