TVM Relay源碼深度解讀
文章目錄
- TVM Relay源碼深度解讀
- 一 、從Constant看Relay表達式的設計哲學
- 1. 類定義概述
- 2. `ConstantNode` 詳解
- 1. 核心成員
- 2. 關鍵方法
- 3. 類型系統注冊
- 3. `Constant` 詳解
- 1. 核心功能
- 二. 核心內容概述
- (1) Relay表達式基類
- 1. RelayExprNode 和 RelayExpr 的區別與用法
- 2. 主要區別
- 3. 使用模式
- 例子1:常量表達式
- 例子2:變量表達式
- 例子3:函數應用
- 4. 實際使用建議
- (2) 具體表達式類型
- 1. 表達式類型 VarNode舉例子
- 1. 核心設計理念
- 2. 關鍵成員解析
- (1) 核心字段
- (2) 特殊方法
- 3. 變量標識系統
- (1) vid (Unique ID)
- (2) name_hint 與 vid 的關系
- 4. 類型系統整合
- (1) 類型注解流程
- (2) 類型推導規則
- 5. 內存模型與跨語言交互
- (1) C++ 層構造
- (2) Python 綁定
- **(3) 對象生命周期**
- 6. 關鍵應用場景
- (1) 函數參數定義
- (2) 優化 Pass 中的變量處理
- (3) 類型檢查
- 7. 設計亮點總結
- 8. 典型問題分析
- (3) TVM_DECLARE_BASE_OBJECT_INFO 宏詳解
- 1. 宏的參數
- 2. 靜態斷言檢查(防止非法繼承)
- 2. 運行時類型索引(RuntimeTypeIndex)
- 3. 動態分配類型索引(_GetOrAllocRuntimeTypeIndex)
- 通俗版解釋:TVM的類型身份證系統
- 1. 為什么要辦身份證?
- 2. 辦證過程(宏的作用)
- 3. 特殊班級(FINAL版)
- 4. 實際有什么用?
- 舉個栗子🌰
- 一句話總結
- (4) 遍歷接口
- 1. C++ 場景示例
- (1) 模型序列化(保存為JSON)
- (2) 優化Pass中的常量修改
- (3) 調試打印
- 2. Python 場景示例
- (1) 直接屬性訪問
- (2) 模型保存與加載
- (3) 自定義屬性訪問器
一 、從Constant看Relay表達式的設計哲學
??在TVM的Relay IR中,即使是看似簡單的常量表達式relay.const(1),其背后也隱藏著整個類型系統的精妙設計。讓我們從include/tvm/relay/expr.h
中的Constant類入手,逐步拆解…"
1. 類定義概述
類名 | 繼承關系 | 角色 | 關鍵特性 |
---|---|---|---|
ConstantNode | public ExprNode | 常量表達式的實際數據存儲 | 包含常量數據(NDArray )、類型信息,并實現屬性訪問、哈希和相等比較邏輯。 |
Constant | public RelayExpr | 常量表達式的智能指針封裝 | 提供用戶友好的構造函數和訪問方法,隱藏內存管理細節。 |
2. ConstantNode
詳解
class ConstantNode : public ExprNode {public:/*! \brief The data of the tensor */runtime::NDArray data;/*! \return The corresponding tensor type of the data */TensorType tensor_type() const;/*! \return Whether it is scalar(rank-0 tensor) */bool is_scalar() const { return data->ndim == 0; }void VisitAttrs(tvm::AttrVisitor* v) {v->Visit("data", &data);v->Visit("span", &span);v->Visit("mdata", &mdata);v->Visit("_checked_type_", &checked_type_);}bool SEqualReduce(const ConstantNode* other, SEqualReducer equal) const {return equal(data, other->data);}void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(data); }static constexpr const char* _type_key = "relay.Constant";TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode);
};
1. 核心成員
-
data
(runtime::NDArray
)- 存儲常量張量的實際數據(如權重、偏置等),TVM 使用
NDArray
統一表示多維數組。 - 示例:卷積層的權重矩陣會被存儲在這里。
- 存儲常量張量的實際數據(如權重、偏置等),TVM 使用
-
tensor_type()
- 根據
data
的維度(shape
)和數據類型(dtype
)自動生成對應的TensorType
。 - 用途:類型推斷時確定常量的類型。
- 根據
-
is_scalar()
- 判斷常量是否為標量(0維張量),如
data->ndim == 0
。
- 判斷常量是否為標量(0維張量),如
2. 關鍵方法
-
VisitAttrs
- 實現屬性的序列化/反序列化,支持以下字段:
v->Visit("data", &data); // 張量數據 v->Visit("span", &span); // 源碼位置信息 v->Visit("mdata", &mdata); // 元數據(如調試信息) v->Visit("_checked_type_", &checked_type_); // 類型檢查后的類型
- 實現屬性的序列化/反序列化,支持以下字段:
-
SEqualReduce
和SHashReduce
- 結構化相等比較:比較兩個
ConstantNode
的data
是否相同(用于優化中的常量折疊)。 - 哈希計算:基于
data
生成哈希值(用于快速查找重復常量)。
- 結構化相等比較:比較兩個
3. 類型系統注冊
TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode);
_type_key = "relay.Constant"
:唯一標識常量節點類型。FINAL
:禁止繼承,確保常量節點的行為不可被修改。
3. Constant
詳解
class Constant : public Expr {public:/*!* \brief The constructor* \param data The data of the constant tensor.* \param span The source span of the expression.*/TVM_DLL explicit Constant(runtime::NDArray data, Span span = Span(), MetaData mdata = MetaData());TVM_DEFINE_OBJECT_REF_METHODS(Constant, RelayExpr, ConstantNode);
};
1. 核心功能
-
構造函數
explicit Constant(runtime::NDArray data, Span span = Span(), MetaData mdata = MetaData());
- 接收
NDArray
數據,構造一個常量表達式。 - 示例:
# Python 前端等價代碼 data = np.array([1, 2, 3], dtype="float32") const_expr = relay.Constant(tvm.nd.array(data))
- 接收
-
智能指針方法
TVM_DEFINE_OBJECT_REF_METHODS(Constant, RelayExpr, ConstantNode);
展開后提供:
operator->()
:直接訪問ConstantNode
成員(如const_expr->data
)。get()
:獲取底層ConstantNode
指針。- 自動內存管理(通過
ObjectRef
的引用計數)。
二. 核心內容概述
??在TVM源碼中,include/tvm/relay/expr.h
是 Relay IR(中間表示)的核心頭文件,定義了所有Relay表達式的基礎數據結構和類型系統。它是實現TVM高層計算圖表示的關鍵組成部分。以下是該文件的詳細解析:
相關重要文件
文件路徑 | 關聯內容 |
---|---|
include/tvm/relay/type.h | 類型系統(TensorType等) |
include/tvm/relay/op.h | 運算符定義 |
include/tvm/relay/adt.h | 代數數據類型支持 |
src/relay/ir/expr.cc | 表達式方法的實現 |
include/tvm/relay/expr.h
文件主要包含:
- (1) Relay表達式基類(
RelayExpr
/RelayExprNode
) - (2) 所有具體表達式類型的聲明(如變量、常量、函數調用等)
- (3) 表達式類型的遍歷和轉換接口
- (4) 類型系統和屬性訪問的支持
(1) Relay表達式基類
class RelayExprNode : public BaseExprNode { /*...*/ };
class RelayExpr : public BaseExpr { /*...*/ };
- 角色:所有Relay表達式的公共基類
- 功能:
- 提供類型系統支持(通過
checked_type_
字段) - 實現屬性訪問(
VisitAttrs
) - 支持結構化相等比較(
SEqualReduce
)
- 提供類型系統支持(通過
1. RelayExprNode 和 RelayExpr 的區別與用法
??RelayExprNode
是 Relay 表達式的實際實現類,是一個 C++ 類,包含了表達式的所有數據和功能實現。它是所有 Relay 表達式類型的基類。
??RelayExpr
是一個智能指針(relay::Expr
),它指向 RelayExprNode
或其子類的實例。它提供了對 RelayExprNode
的安全訪問和管理。
2. 主要區別
特性 | RelayExprNode | RelayExpr |
---|---|---|
類型 | C++ 類 | 智能指針(std::shared_ptr 的封裝) |
生命周期管理 | 需要手動管理 | 自動管理 |
使用方式 | 通常不直接使用,作為實現細節 | 用戶主要交互的接口 |
繼承關系 | 作為基類定義表達式結構 | 作為訪問接口 |
3. 使用模式
在 TVM 中,通常的模式是:
- 定義一個繼承自
RelayExprNode
的具體表達式節點類 - 使用
RelayExpr
作為這些節點的引用
例子1:常量表達式
// 創建一個常量表達式
auto const_node = relay::ConstantNode::make(tvm::runtime::NDArray::Zeros(...));
RelayExpr const_expr = const_node;// 通常更簡潔的寫法
RelayExpr const_expr = relay::Constant(tvm::runtime::NDArray::Zeros(...));
例子2:變量表達式
// 創建一個變量表達式
auto var_node = relay::VarNode::make("x", relay::Type());
RelayExpr var_expr = var_node;// 或者更簡潔地
RelayExpr var_expr = relay::Var("x", relay::Type());
例子3:函數應用
// 創建函數應用表達式
RelayExpr func = ...; // 某個函數
RelayExpr arg = ...; // 某個參數
auto call_node = relay::CallNode::make(func, {arg});
RelayExpr call_expr = call_node;// 或者
RelayExpr call_expr = relay::Call(func, {arg});
4. 實際使用建議
-
用戶代碼:在大多數情況下,你應該使用
RelayExpr
而不是直接操作RelayExprNode
。 -
擴展 Relay:如果你想定義新的表達式類型,需要繼承
RelayExprNode
并實現相應接口。 -
類型轉換:可以使用
as<T>
方法將RelayExpr
向下轉換為特定類型的節點指針:
RelayExpr expr = ...;
if (const auto* call = expr.as<CallNode>()) {// 現在可以訪問 CallNode 的特定成員call->op;call->args;
}
- 創建新表達式:TVM 提供了輔助函數來創建表達式,通常以節點類型名去掉 “Node” 命名(如
relay::Var()
創建VarNode
的RelayExpr
)。
這種分離設計使得 Relay IR 既靈活又安全,同時保持了良好的性能特性
(2) 具體表達式類型
表達式類型 | 說明 | 關鍵成員/方法 |
---|---|---|
VarNode | 變量(輸入/中間結果) | String name_hint , Type type_annotation , Id vid |
ConstantNode | 常量張量(如模型權重) | runtime::NDArray data , tensor_type() , is_scalar() |
CallNode | 函數/運算符調用 | Expr op , Array<Expr> args , Attrs attrs , Array<Type> type_args |
LetNode | Let綁定(實現變量作用域) | Var var , Expr value , Expr body |
TupleNode | 元組結構(多返回值) | Array<Expr> fields |
TupleGetItemNode | 從元組中獲取元素 | Expr tuple , int index |
IfNode | 條件表達式 | Expr cond , Expr true_branch , Expr false_branch |
OpNode | 基本運算符(如add/concat) | 通過Op::Get("op_name") 獲取 |
FunctionNode | 函數定義(在function.h 中聲明,但屬于表達式) | Array<Var> params , Expr body , Type ret_type , Array<TypeVar> type_params |
RefCreateNode | 創建可變引用(用于狀態更新) | Expr value |
RefReadNode | 讀取引用值 | Expr ref |
RefWriteNode | 更新引用值 | Expr ref , Expr value |
ConstructorNode | 代數數據類型(ADT)的構造器(在adt.h 中聲明) | String tag , Array<Type> inputs |
MatchNode | 模式匹配(ADT處理) | Expr data , Array<Clause> clauses |
TempExprNode | 臨時表達式(用于優化過程中的中間表示) | 通常作為優化Pass的中間載體 |
GlobalVarNode | 全局函數引用(跨模塊調用) | String name_hint |
SeqExprNode | 順序執行多個表達式(類似語句塊) | Array<Binding> bindings , Expr body |
1. 表達式類型 VarNode舉例子
include/tvm/relay/expr.h
class Var;
/*! \brief Container for Var */
class VarNode : public ExprNode {public:/*!* \brief The unique identifier of the Var.** vid will be preserved for the same Var during type inference* and other rewritings, while the VarNode might be recreated* to attach additional information.* This property can be used to keep track of parameter Var* information across passes.*/Id vid;/*!* \brief type annotaion of the variable.* This field records user provided type annotation of the Var.* This field is optional and can be None.*/Type type_annotation;/*! \return The name hint of the variable */const String& name_hint() const { return vid->name_hint; }void VisitAttrs(tvm::AttrVisitor* v) {v->Visit("vid", &vid);v->Visit("type_annotation", &type_annotation);v->Visit("span", &span);v->Visit("mdata", &mdata);v->Visit("_checked_type_", &checked_type_);}bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {return equal(type_annotation, other->type_annotation) && equal.FreeVarEqualImpl(this, other);}void SHashReduce(SHashReducer hash_reduce) const {hash_reduce(type_annotation);hash_reduce.FreeVarHashImpl(this);}static constexpr const char* _type_key = "relay.Var";TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, ExprNode);
};class Var : public Expr {public:/*!* \brief The constructor* \param name_hint The name hint of a variable.* \param type_annotation The type annotation of a variable.* \param span The source span of the expression.*/TVM_DLL Var(String name_hint, Type type_annotation, Span span = Span(), MetaData mdata = MetaData()): Var(Id(name_hint), type_annotation, span, mdata) {}/*!* \brief The constructor* \param vid The unique id of a variable.* \param type_annotation The type annotation of a variable.* \param span The source span of the expression.*/TVM_DLL Var(Id vid, Type type_annotation, Span span = Span(), MetaData mdata = MetaData());TVM_DEFINE_OBJECT_REF_METHODS(Var, RelayExpr, VarNode);
};
1. 核心設計理念
VarNode
和 Var
共同實現了 Relay IR 的變量系統,采用 TVM 標準的 Object-ObjectRef 設計模式:
VarNode
:存儲實際數據的節點類(繼承自ExprNode
)Var
:管理VarNode
的智能指針包裝類(繼承自Expr
)
2. 關鍵成員解析
(1) 核心字段
成員 | 類型 | 作用 |
---|---|---|
vid | Id | 唯一標識符,跨 Pass 保持不變(即使節點被重建) |
type_annotation | Type | 用戶顯式指定的類型注解(可空) |
name_hint() | String | 通過 vid->name_hint 獲取的可讀名稱(非唯一) |
span | Span | 源碼位置信息(用于錯誤定位) |
mdata | MetaData | 擴展元數據 |
(2) 特殊方法
方法 | 功能 |
---|---|
SEqualReduce | 結構化相等比較(用于優化 Pass 的重復檢測) |
SHashReduce | 哈希計算(支持快速查找) |
VisitAttrs | 屬性序列化/反序列化 |
3. 變量標識系統
(1) vid (Unique ID)
class IdNode : public Object {public:String name_hint;// ... 其他元數據
};
- 核心特性:
- 通過
Id(name_hint)
構造,但系統會保證其唯一性 - 即使優化 Pass 重建變量節點,
vid
保持不變 - 用于跨 Pass 跟蹤參數變量(如梯度更新時識別同一參數)
- 通過
(2) name_hint 與 vid 的關系
x = relay.var("input", shape=(1,3)) # 實際創建:# vid = Id("input_0x7f") (自動去重)# name_hint = "input" (用戶友好)
4. 類型系統整合
(1) 類型注解流程
graph TDA[用戶構造] -->|relay.var(..., dtype="float32")| B(type_annotation)B --> C[類型檢查]C -->|更新| D(_checked_type_)
(2) 類型推導規則
- 若
type_annotation
存在:必須與實際使用類型兼容 - 若為空:從上下文推斷類型
5. 內存模型與跨語言交互
(1) C++ 層構造
// 方式1:通過 name_hint
Var x("data", TensorType({1,3}, DataType::Float(32)));// 方式2:直接指定 Id
Var x(Id("data_0x7f"), TensorType({1,3}, DataType::Float(32)));
(2) Python 綁定
# Python 前端接口
x = relay.var(name="input",shape=(1,3),dtype="float32",span=SourceSpan(...)
)
(3) 對象生命周期
sequenceDiagramPython->>C++: relay.var() 創建請求C++->>Heap: 分配 VarNodeC++->>Python: 返回 Var(ObjectRef)Python->>C++: 析構時觸發引用計數-1
6. 關鍵應用場景
(1) 函數參數定義
def build_linear():x = relay.var("x", shape=(1,3))w = relay.var("w", shape=(3,2))b = relay.var("b", shape=(2,))y = relay.add(relay.matmul(x, w), b)return relay.Function([x, w, b], y)
(2) 優化 Pass 中的變量處理
// 在 ConstantFolding 中識別變量引用
if (const VarNode* var = expr.as<VarNode>()) {if (var_map.count(var->vid)) {// 替換為已知常量}
}
(3) 類型檢查
// 檢查變量類型是否匹配
bool CheckType(const VarNode* var, const Type& expected) {return var->checked_type().as<TensorType>()->dtype == expected;
}
7. 設計亮點總結
- 穩定性:
vid
保證變量在優化過程中的持久標識 - 靈活性:
type_annotation
支持顯式/隱式類型指定 - 安全性:
TVM_DECLARE_FINAL_OBJECT_INFO
防止錯誤繼承 - 可調試性:
span
和name_hint
增強錯誤可讀性 - 性能:
SEqualReduce
/SHashReduce
優化圖操作效率
8. 典型問題分析
Q: 為什么需要同時存在 vid
和 name_hint
?
A: 分工不同:
name_hint
:面向用戶,提供可讀性(允許重復)vid
:面向系統,保證唯一性和跨Pass一致性
Q: 何時會重建 VarNode
?
A: 典型場景:
- 類型推斷后附加
_checked_type_
- 優化 Pass 中克隆表達式時保留原
vid
但新建節點
(3) TVM_DECLARE_BASE_OBJECT_INFO 宏詳解
??這個宏是 TVM 類型系統的核心,用于在 C++ 中動態注冊和管理對象的類型信息。它的核心作用是: 為每個類自動生成類型注冊代碼,使其能被 TVM 運行時識別和操作。
1. 宏的參數
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
TypeName
:當前類名(如ConstantNode
)ParentType
:父類名(如ExprNode
)
2. 靜態斷言檢查(防止非法繼承)
static_assert(!ParentType::_type_final, "ParentObj marked as final");
- 作用:如果父類被標記為
final
(通過_type_final
),則禁止子類繼承。
2. 運行時類型索引(RuntimeTypeIndex)
static uint32_t RuntimeTypeIndex() {// 檢查子類槽位配置是否合法static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 ||TypeName::_type_child_slots < ParentType::_type_child_slots,"子類槽位數不能超過父類限制");// 如果已預分配類型ID,直接返回if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) {return TypeName::_type_index;}// 否則動態分配return _GetOrAllocRuntimeTypeIndex();
}
- 功能:返回類的唯一類型 ID(
uint32_t
)。 - 優化:優先使用預分配的
_type_index
(性能更高),否則動態分配。
3. 動態分配類型索引(_GetOrAllocRuntimeTypeIndex)
static uint32_t _GetOrAllocRuntimeTypeIndex() {static uint32_t tidx = Object::GetOrAllocRuntimeTypeIndex(TypeName::_type_key, // 類型名稱字符串(如 "relay.Constant")TypeName::_type_index, // 預分配的類型IDParentType::RuntimeTypeIndex(), // 父類類型IDTypeName::_type_child_slots, // 為子類預留的槽位數TypeName::_type_child_slots_can_overflow // 是否允許超額);return tidx;
}
- 作用:向 TVM 運行時注冊類型,并分配唯一 ID。
- 關鍵參數:
_type_child_slots
:限制子類數量(防止類型爆炸)。_type_child_slots_can_overflow
:為true
時允許突破限制。
通俗版解釋:TVM的類型身份證系統
你可以把TVM的類型系統想象成一個學校的學生管理系統,而TVM_DECLARE_BASE_OBJECT_INFO
就是給學生(類)辦身份證的機器:
1. 為什么要辦身份證?
- 每個學生(類)需要唯一學號(類型ID)
- 需要知道他的班主任是誰(父類)
- 防止有人冒充轉校生(非法繼承)
2. 辦證過程(宏的作用)
// 給"小明同學"辦證,班主任是"李老師"
TVM_DECLARE_BASE_OBJECT_INFO(小明, 李老師)
這個宏會自動做三件事:
-
檢查家世清白
static_assert(!李老師::是final班, "班主任明確不收新學生!");
- 如果班主任聲明"我們班不接收轉學生",就報錯
-
分配學號
- 優先用預留的VIP學號(
_type_index
) - 沒有就現場搖號(
_GetOrAllocRuntimeTypeIndex
)
- 優先用預留的VIP學號(
-
登記親屬關系
學號 = 教務處.登記(姓名:"小明",班主任:李老師.學號,可帶小弟人數:3 // _type_child_slots );
3. 特殊班級(FINAL版)
TVM_DECLARE_FINAL_OBJECT_INFO(學霸班, 實驗班)
- 相當于在班級門口掛**“禁止轉入”**牌子
- 其他班同學想轉學過來會直接報錯
4. 實際有什么用?
- 查身份證快:
obj->IsInstance<小明>()
比查戶口本快 - 安全轉班:
obj.as<小明>()
能安全轉換類型 - 防止冒名頂替:禁止隨便認爹(錯誤繼承)
舉個栗子🌰
# Python前端定義一個"漢堡店"類
@register_relay_node("food.HamburgerShop")
class HamburgerShopNode(ExprNode):_type_key = "food.HamburgerShop"_type_child_slots = 2 # 允許開2家分店
C++層通過這個宏:
- 給漢堡店分配類型ID(比如9527)
- 記錄它的父類是ExprNode
- 允許最多2個子類(比如
CheeseBurgerShop
、ChickenBurgerShop
)
一句話總結
這個宏就是TVM給類發身份證+建家族檔案的工具,讓系統能:
- ? 快速識別"你是誰"(類型檢查)
- ? 知道"你爸是誰"(繼承關系)
- ? 防止"亂認親戚"(非法繼承)
(4) 遍歷接口
void VisitAttrs(tvm::AttrVisitor* v) {v->Visit("data", &data);v->Visit("span", &span);v->Visit("mdata", &mdata);v->Visit("_checked_type_", &checked_type_);}
??VisitAttrs
是 TVM 中用于統一序列化、反序列化和屬性訪問的核心接口。以下是 ConstantNode
使用該函數的具體示例,涵蓋 C++ 和 Python 場景:
1. C++ 場景示例
(1) 模型序列化(保存為JSON)
// 創建常量節點
runtime::NDArray arr = runtime::NDArray::Empty({2, 2}, DLDataType{kDLFloat, 32, 1}, DLContext{kDLCPU, 0});
ConstantNode* const_node = new ConstantNode();
const_node->data = arr;// 序列化為JSON
JSONAttrVisitor visitor;
const_node->VisitAttrs(&visitor); // 觸發以下調用:// visitor.Visit("data", &data)// visitor.Visit("span", &span)...
std::string json = visitor.GetJSON();
輸出JSON片段:
{"type_key": "relay.Constant","data": {"b64": "AABAA...", "dtype": "float32", "shape": [2, 2]},"span": null,"_checked_type_": "TensorType([2,2], float32)"
}
(2) 優化Pass中的常量修改
class ConstantMutator : public AttrMutator {public:void VisitAttrs(AttrVisitor* v) override {if (v->IsMutator()) { // 檢查是否為修改模式runtime::NDArray new_data = ...; // 生成新數據v->Visit("data", &new_data); // 修改data字段}}
};// 調用示例:
ConstantMutator mutator;
const_node->VisitAttrs(&mutator); // 修改常量數據
(3) 調試打印
class DebugPrinter : public AttrVisitor {public:void Visit(const char* key, runtime::NDArray* data) override {std::cout << key << ": shape=" << data.Shape();}
};DebugPrinter printer;
const_node->VisitAttrs(&printer); // 輸出:data: shape=[2,2]
2. Python 場景示例
(1) 直接屬性訪問
import tvm
from tvm import relay# 創建常量
data = tvm.nd.array(np.zeros((2,2), dtype="float32"))
const = relay.Constant(data)# Python屬性訪問(背后調用VisitAttrs)
print(const.data) # 訪問NDArray → 觸發Visit("data", &data)
print(const.span) # 訪問源碼位置 → Visit("span", &span)
輸出:
<tvm.nd.NDArray shape=(2, 2), dtype=float32>
None # 未設置span時的默認值
(2) 模型保存與加載
# 保存模型(觸發序列化)
mod = tvm.IRModule.from_expr(const)
mod.save("const.json") # 內部調用VisitAttrs# 加載模型(觸發反序列化)
loaded_mod = tvm.ir.load_json("const.json")
loaded_const = loaded_mod["main"].body
assert isinstance(loaded_const, relay.Constant)
(3) 自定義屬性訪問器
class MyVisitor(tvm.ir.AttrVisitor):def visit(self, name, value):print(f"Attribute {name} has type {type(value)}")visitor = MyVisitor()
const.visit_attrs(visitor) # 顯式調用VisitAttrs
輸出:
Attribute data has type <class 'tvm.runtime.ndarray.NDArray'>
Attribute span has type <class 'tvm.ir.Span'>
...
class Constant;
/*!* \brief Constant tensor type.*/
class ConstantNode : public ExprNode {public:/*! \brief The data of the tensor */runtime::NDArray data;/*! \return The corresponding tensor type of the data */TensorType tensor_type() const;/*! \return Whether it is scalar(rank-0 tensor) */bool is_scalar() const { return data->ndim == 0; }void VisitAttrs(tvm::AttrVisitor* v) {v->Visit("data", &data);v->Visit("span", &span);v->Visit("mdata", &mdata);v->Visit("_checked_type_", &checked_type_);}bool SEqualReduce(const ConstantNode* other, SEqualReducer equal) const {return equal(data, other->data);}void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(data); }static constexpr const char* _type_key = "relay.Constant";TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode);
};class Constant : public Expr {public:/*!* \brief The constructor* \param data The data of the constant tensor.* \param span The source span of the expression.*/TVM_DLL explicit Constant(runtime::NDArray data, Span span = Span(), MetaData mdata = MetaData());TVM_DEFINE_OBJECT_REF_METHODS(Constant, RelayExpr, ConstantNode);
};
以下是關于 ConstantNode
和 Constant
類的詳細解釋與概括,結合它們在 TVM Relay IR 中的作用和實現設計: