PyTorch API 6 - 編譯、fft、fx、函數轉換、調試、符號追蹤

文章目錄

  • torch.compiler
    • 延伸閱讀
  • torch.fft
    • 快速傅里葉變換
    • 輔助函數
  • torch.func
    • 什么是可組合的函數變換?
    • 為什么需要可組合的函數變換?
    • 延伸閱讀
  • torch.futures
  • torch.fx
    • 概述
    • 編寫轉換函數
      • 圖結構快速入門
      • 圖操作
        • 直接操作計算圖
        • 使用 replace_pattern() 進行子圖重寫
        • 圖操作示例
      • 代理/回溯機制
      • 解釋器模式
        • 解釋器模式示例
    • 調試
      • 簡介
      • 變換編寫中的常見陷阱
      • 檢查模塊的正確性
      • 調試生成的代碼
        • 使用 `pdb`
        • 打印生成的代碼
        • 使用 `GraphModule` 中的 `to_folder` 函數
      • 調試轉換過程
      • 可用的調試器
    • 符號追蹤的局限性
      • 動態控制流
        • 靜態控制流
      • 非`torch`函數
      • 使用 `Tracer` 類自定義追蹤功能
        • 葉子模塊
      • 雜項說明
    • API 參考
  • torch.fx.experimental
    • torch.fx.experimental.symbolic_shapes
    • torch.fx.experimental.proxy_tensor


torch.compiler

torch.compiler 是一個命名空間,通過它向用戶開放了一些內部編譯器方法。該命名空間中的主要功能和特性是 torch.compile

torch.compile 是 PyTorch 2.x 引入的一個函數,旨在解決 PyTorch 中精確圖捕獲的問題,最終幫助軟件工程師加速運行他們的 PyTorch 程序。torch.compile 使用 Python 編寫,標志著 PyTorch 從 C++ 向 Python 的過渡。

torch.compile 利用了以下底層技術:

  • TorchDynamo (torch._dynamo) 是一個內部 API,它使用 CPython 的 Frame Evaluation API 功能來安全捕獲 PyTorch 計算圖。通過 torch.compiler 命名空間向 PyTorch 用戶開放可用方法。
  • TorchInductortorch.compile 默認的深度學習編譯器,為多種加速器和后端生成快速代碼。需要通過后端編譯器才能實現 torch.compile 的加速效果。對于 NVIDIA、AMD 和 Intel GPU,它使用 OpenAI Triton 作為關鍵構建塊。
  • AOT Autograd 不僅能捕獲用戶級代碼,還能捕獲反向傳播,實現"提前"捕獲反向傳遞。這使得 TorchInductor 能夠同時加速前向和反向傳遞。

注意:在本文檔中,術語 torch.compile、TorchDynamo 和 torch.compiler 有時會互換使用。

如上所述,要通過 TorchDynamo 運行更快的工作流,torch.compile 需要一個后端將捕獲的計算圖轉換為快速機器碼。不同的后端會帶來不同的優化效果。默認后端是 TorchInductor(也稱為 inductor)。TorchDynamo 還支持由合作伙伴開發的一系列后端,可以通過運行 torch.compiler.list_backends() 查看,每個后端都有其可選依賴項。

一些最常用的后端包括:

訓練和推理后端

后端描述
torch.compile(m, backend="inductor")使用 TorchInductor 后端。了解更多
torch.compile(m, backend="cudagraphs")使用 AOT Autograd 的 CUDA 圖。了解更多
torch.compile(m, backend="ipex")在 CPU 上使用 IPEX。了解更多
torch.compile(m, backend="onnxrt")使用 ONNX Runtime 在 CPU/GPU 上進行訓練。了解更多

僅推理后端

后端描述
torch.compile(m, backend="tensorrt")使用 Torch-TensorRT 進行推理優化。需要在調用腳本中 import torch_tensorrt 來注冊后端。了解更多
torch.compile(m, backend="ipex")在 CPU 上使用 IPEX 進行推理。了解更多
torch.compile(m, backend="tvm")使用 Apache TVM 進行推理優化。了解更多
torch.compile(m, backend="openvino")使用 OpenVINO 進行推理優化。了解更多

延伸閱讀

PyTorch 用戶入門指南

  • 快速入門
  • torch.compiler API 參考
  • torch.compiler.config 配置
  • TorchDynamo 細粒度追蹤 API
  • AOTInductor: Torch.Export 模型的預編譯方案
  • TorchInductor GPU 性能分析
  • torch.compile 性能剖析指南
  • 常見問題解答
  • torch.compile 故障排查
  • PyTorch 2.0 性能看板

PyTorch 開發者深度解析

  • Dynamo 架構概覽
  • Dynamo 技術深潛
  • 動態形狀支持
  • PyTorch 2.0 NNModule 支持
  • 后端開發最佳實踐
  • CUDA 圖樹優化
  • 偽張量機制

PyTorch 后端供應商指南

  • 自定義后端開發
  • ATen IR 圖轉換開發
  • 中間表示層詳解

torch.fft

離散傅里葉變換及相關函數。


快速傅里葉變換

fft計算input的一維離散傅里葉變換
ifft計算input的一維離散傅里葉逆變換
fft2計算input的二維離散傅里葉變換
ifft2計算input的二維離散傅里葉逆變換
fftn計算input的N維離散傅里葉變換
ifftn計算input的N維離散傅里葉逆變換
rfft計算實數input的一維傅里葉變換
irfft計算rfft()的逆變換
rfft2計算實數input的二維離散傅里葉變換
irfft2計算rfft2()的逆變換
rfftn計算實數input的N維離散傅里葉變換
irfftn計算rfftn()的逆變換
hfft計算Hermitian對稱input信號的一維離散傅里葉變換
ihfft計算hfft()的逆變換
hfft2計算Hermitian對稱input信號的二維離散傅里葉變換
ihfft2計算實數input的二維離散傅里葉逆變換
hfftn計算Hermitian對稱input信號的N維離散傅里葉變換
ihfftn計算實數input的N維離散傅里葉逆變換

輔助函數

fftfreq計算大小為 n 的信號的離散傅里葉變換采樣頻率。
rfftfreq計算大小為 n 的信號在使用 rfft() 時的采樣頻率。
fftshift對由 fftn() 提供的 n 維 FFT 數據進行重新排序,使負頻率項優先。
ifftshiftfftshift() 的逆操作。

torch.func

torch.func(前身為"functorch")是為PyTorch提供的JAX風格可組合函數變換工具。


注意:該庫目前處于測試階段。
這意味著這些功能基本可用(除非另有說明),且我們(PyTorch團隊)將持續推進該庫的發展。但API可能會根據用戶反饋進行調整,且尚未完全覆蓋所有PyTorch操作。

如果您對API有改進建議,或希望支持特定使用場景,請提交GitHub issue或直接聯系我們。我們非常期待了解您如何使用這個庫。


什么是可組合的函數變換?

  • 函數變換是一種高階函數,它接受一個數值函數作為輸入,并返回一個新函數來計算不同的量。
  • torch.func 提供了自動微分變換(例如 grad(f) 返回計算 f 梯度的函數)、向量化/批處理變換(例如 vmap(f) 返回對輸入批次執行 f 的函數)等多種變換。
  • 這些函數變換可以任意組合使用。例如,組合 vmap(grad(f)) 可以計算單樣本梯度(per-sample-gradients),這是當前標準 PyTorch 無法高效計算的量。

為什么需要可組合的函數變換?

目前在 PyTorch 中實現以下用例較為棘手:

  • 計算逐樣本梯度(或其他逐樣本量)
  • 在單臺機器上運行模型集成
  • 在 MAML 內循環中高效批處理任務
  • 高效計算雅可比矩陣和海森矩陣
  • 高效計算批量雅可比矩陣和海森矩陣

通過組合使用 vmap()grad()vjp() 變換,我們無需為每個用例單獨設計子系統即可實現上述功能。這種可組合函數變換的理念源自 JAX 框架。


延伸閱讀

  • torch.func 快速指南
    • 什么是 torch.func?
    • 為什么需要可組合函數變換?
    • 有哪些變換方法?
  • torch.func API 參考
    • 函數變換
    • torch.nn.Module 工具集
    • 調試工具
  • 使用限制
    • 通用限制
    • torch.autograd API
    • vmap 限制
    • 隨機性控制
  • 從 functorch 遷移到 torch.func
    • 函數變換
    • 神經網絡模塊工具
    • functorch.compile

torch.futures

該包提供了一種 Future 類型,用于封裝異步執行過程,并提供一組實用函數來簡化對 Future 對象的操作。目前,Future 類型主要被 分布式RPC框架 使用。


class torch.futures.Future(*, devices=None) 

Wrapper around a torch._C.Future which encapsulates an asynchronous
execution of a callable, e.g. rpc_async(). It also exposes a set of APIs to add callback functions and set results.


Warning: GPU support is a beta feature, subject to changes.


add_done_callback(callback)

將給定的回調函數附加到此Future上,該回調函數將在Future完成時運行。可以向同一個Future添加多個回調,但無法保證它們的執行順序。回調函數必須接受一個參數,即對此Future的引用。回調函數可以使用value()方法獲取值。請注意,如果此Future已經完成,給定的回調將立即內聯執行。

我們建議使用then()方法,因為它提供了一種在回調完成后進行同步的方式。如果回調不返回任何內容,add_done_callback可能更高效。但then()add_done_callback在底層使用相同的回調注冊API。

對于GPU張量,此方法的行為與then()相同。

參數

  • callback (Future) – 一個可調用對象,接受一個參數,即對此Future的引用。

注意:請注意,如果回調函數拋出異常,無論是由于原始future以異常完成并調用fut.wait(),還是由于回調中的其他代碼,都必須仔細處理錯誤。例如,如果此回調隨后完成了其他future,這些future不會被標記為以錯誤完成,用戶需要獨立處理這些future的完成/等待。


示例

>>> def callback(fut):
...     print("This will run after the future has finished.")
...     print(fut.wait())
>>> fut = torch.futures.Future()
>>> fut.add_done_callback(callback)
>>> fut.set_result(5)
This will run after the future has finished.
5

done()

如果該Future已完成則返回True。當Future包含結果或異常時即視為完成。

如果值包含位于GPU上的張量,即使填充這些張量的異步內核尚未在設備上完成運行,Future.done()仍會返回True,因為在此階段結果已可被使用(前提是執行適當的同步操作,參見wait())。

返回類型:bool


set_exception(result)

為這個 Future 設置一個異常,這將標記該 Future 以錯誤狀態完成,并觸發所有已附加的回調。請注意,當對此 Future 調用 wait()/value() 時,此處設置的異常

將被內聯拋出。

參數

  • result ([BaseException](https://docs.python.org/3/library/exceptions.html#BaseException "(in Python v3.13)")) – 該 Future 的異常對象。

示例

>>> fut = torch.futures.Future()
>>> fut.set_exception(ValueError("foo"))
>>> fut.wait()
Traceback (most recent call last):
...
ValueError: foo

set_result(result)

為這個Future設置結果,這將標記該Future為已完成狀態并觸發所有關聯的回調。需要注意的是,一個Future不能被標記為已完成兩次。

如果結果包含位于GPU上的張量,即使填充這些張量的異步內核尚未在設備上完成運行,只要調用此方法時這些內核所入隊的流被設置為當前流,仍可調用此方法。簡而言之,在啟動這些內核后立即調用此方法是安全的,無需額外同步,前提是期間不切換流。此方法會在所有相關當前流上記錄事件,并利用它們確保此Future的所有消費者都能得到正確調度。

參數

  • result ( object ) - 該Future的結果對象。

示例:


>>> import threading
>>> import time
>>> def slow_set_future(fut, value):
...     time.sleep(0.5)
...     fut.set_result(value)
>>> fut = torch.futures.Future()
>>> t = threading.Thread(
...     target=slow_set_future, 
...     args=(fut, torch.ones(2) * 3)
... )
>>> t.start()
>>> print(fut.wait())
tensor([3., 3.])
>>> t.join()

then(callback)

將給定的回調函數附加到此Future上,該回調函數將在Future完成時運行。可以向同一個Future添加多個回調,但無法保證它們的執行順序(如需確保特定順序,請考慮鏈式調用:fut.then(cb1).then(cb2))。回調函數必須接受一個參數,即對此Future的引用。回調函數可通過value()方法獲取值。請注意,如果此Future已完成,給定的回調將立即內聯執行。

如果Future的值包含位于GPU上的張量,回調可能在填充這些張量的異步內核尚未在設備上完成執行時就被調用。不過,回調將通過設置為當前的一些專用流(從全局池中獲取)被調用,這些流將與那些內核同步。因此,回調對這些張量執行的任何操作都將在內核完成后調度到設備上。換句話說,只要回調不切換流,它就可以安全地操作結果而無需額外同步。這與wait()的非阻塞行為類似。

類似地,如果回調返回的值包含位于GPU上的張量,即使生成這些張量的內核仍在設備上運行,回調也可以這樣做,前提是回調在執行期間沒有切換流。如果想要切換流,必須注意與原始流重新同步,即回調被調用時當前的流。

參數

  • callback (Callable) – 一個以該Future為唯一參數的可調用對象。

返回

一個新的Future對象,它持有callback的返回值,并將在給定callback完成時標記為已完成。

返回類型

Future[S]

注意:請注意,如果回調函數拋出異常,無論是通過原始future以異常完成并調用fut.wait(),還是通過回調中的其他代碼,then返回的future將適當地標記為遇到錯誤。但是,如果此回調隨后完成其他future,這些future不會標記為以錯誤完成,用戶需負責獨立處理這些future的完成/等待。


示例

>>> def callback(fut):
...     print(f"RPC return value is {fut.wait()}.")
>>> fut = torch.futures.Future()
>>> # The inserted callback will print the return value when
>>> # receiving the response from "worker1"
>>> cb_fut = fut.then(callback)
>>> chain_cb_fut = cb_fut.then(
...     lambda x : print(f"Chained cb done. {x.wait()}")
... )
>>> fut.set_result(5)
RPC return value is 5、Chained cb done. None

value()

獲取已完成的Future對象的值。

此方法僅應在調用wait()完成后,或在傳遞給then()的回調函數內部使用。其他情況下,該Future可能尚未持有值,調用value()可能會失敗。

如果值包含位于GPU上的張量,此方法將不會執行任何額外的同步操作。此類同步應事先通過調用wait()單獨完成(回調函數內部除外,因為then()已自動處理此情況)。

返回值
Future持有的值。如果創建該值的函數(回調或RPC)拋出錯誤,此value()方法同樣會拋出錯誤。

返回類型:T


wait()

等待直到該 Future 的值準備就緒。

如果值包含位于 GPU 上的張量,則會與設備上異步填充這些張量的內核執行額外的同步操作。此類同步是非阻塞的,這意味著 wait() 會在當前流中插入必要的指令,以確保后續在這些流上排隊的操作能正確安排在異步內核之后執行。但一旦完成指令插入,即使這些內核仍在運行,wait() 也會立即返回。只要不切換流,在訪問和使用這些值時無需進一步同步。

返回值:此 Future 持有的值。如果創建該值的函數(回調或 RPC)拋出錯誤,此 wait 方法同樣會拋出錯誤。

返回類型:T


torch.futures.collect_all(futures)

將提供的 Future 對象收集到一個統一的組合 Future 中,該組合 Future 會在所有子 Future 完成時完成。

參數

  • futures (list) – 一個包含 Future 對象的列表。

返回

返回一個 Future 對象,該對象關聯到傳入的 Future 列表。

返回類型

Future[list [torch.jit.Future]]


示例

>>> fut0 = torch.futures.Future()
>>> fut1 = torch.futures.Future()
>>> fut = torch.futures.collect_all([fut0, fut1])
>>> fut0.set_result(0)
>>> fut1.set_result(1)
>>> fut_list = fut.wait()
>>> print(f"fut0 result = {fut_list[0].wait()}")
fut0 result = 0
>>> print(f"fut1 result = {fut_list[1].wait()}")
fut1 result = 1

torch.futures.wait_all(futures)

等待所有提供的 futures 完成,并返回已完成值的列表。如果任一 future 遇到錯誤,該方法將提前退出并報告錯誤,而不會等待其他 futures 完成。

參數

  • futures (list) – 一個 Future 對象列表。

返回值:已完成 Future 結果的列表。如果對任何 Future 調用 wait 時拋出錯誤,該方法也會拋出錯誤。

返回類型:list


torch.fx


概述

FX 是一個供開發者使用的工具包,用于轉換 nn.Module 實例。FX 包含三個核心組件:符號追蹤器中間表示Python 代碼生成。以下是這些組件的實際應用演示:

import torch# Simple module for demonstration
class MyModule(torch.nn.Module):def __init__(self) -None:super().__init__()self.param = torch.nn.Parameter(torch.rand(3, 4))self.linear = torch.nn.Linear(4, 5)def forward(self, x):return self.linear(x + self.param).clamp(min=0.0, max=1.0)module = MyModule()from torch.fx import symbolic_trace# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph():%x : [num_users=1] = placeholder[target=x]%param : [num_users=1] = get_attr[target=param]%add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})%linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})%clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})return clamp
"""# Code generation - valid Python code
print(symbolic_traced.code)
"""
def forward(self, x):param = self.paramadd = x + param;  x = param = Nonelinear = self.linear(add);  add = Noneclamp = linear.clamp(min = 0.0, max = 1.0);  linear = Nonereturn clamp
"""

符號追蹤器(symbolic tracer)對Python代碼執行"符號執行"。它通過代碼傳遞稱為Proxy的虛擬值,并記錄對這些Proxy的操作。有關符號追蹤的更多信息,請參閱symbolic_trace()Tracer文檔。

中間表示(intermediate representation)是符號追蹤過程中記錄操作的容器。它由一組節點組成,這些節點表示函數輸入、調用點(指向函數、方法或torch.nn.Module實例)以及返回值。有關IR的更多信息,請參閱Graph文檔。IR是應用轉換的基礎格式。

Python代碼生成功能使FX成為Python到Python(或Module到Module)的轉換工具包。對于每個Graph IR,我們都可以生成符合Graph語義的有效Python代碼。這個功能被封裝在GraphModule中,它是一個torch.nn.Module實例,包含一個Graph以及從Graph生成的forward方法。

這些組件(符號追蹤→中間表示→轉換→Python代碼生成)共同構成了FX的Python到Python轉換流程。此外,這些組件也可以單獨使用。例如,符號追蹤可以單獨用于捕獲代碼形式進行分析(而非轉換)目的。代碼生成可以用于通過編程方式生成模型,例如從配置文件生成。FX有許多用途!

在示例庫中可以找到幾個轉換示例。


編寫轉換函數

什么是FX轉換?本質上,它是一個形如下列的函數。


import torch
import torch.fxdef transform(m: nn.Module,        tracer_class : type = torch.fx.Tracer) -torch.nn.Module:# Step 1: Acquire a Graph representing the code in `m`# NOTE: torch.fx.symbolic_trace is a wrapper around a call to     # fx.Tracer.trace and constructing a GraphModule. We'll# split that out in our transform to allow the caller to     # customize tracing behavior.graph : torch.fx.Graph = tracer_class().trace(m)# Step 2: Modify this Graph or create a new onegraph = ...# Step 3: Construct a Module to returnreturn torch.fx.GraphModule(m, graph)

您的轉換器將接收一個 torch.nn.Module,從中獲取 Graph,進行一些修改后返回一個新的 torch.nn.Module。您應該將 FX 轉換器返回的 torch.nn.Module 視為與常規 torch.nn.Module 完全相同——可以將其傳遞給另一個 FX 轉換器、傳遞給 TorchScript 或直接運行它。確保 FX 轉換器的輸入和輸出均為 torch.nn.Module 將有助于實現組合性。

注意:也可以直接修改現有的 GraphModule 而不創建新實例,例如:

import torch
import torch.fxdef transform(m : nn.Module) -nn.Module:gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m)# Modify gm.graph# <...># Recompile the forward() method of `gm` from its Graphgm.recompile()return gm

請注意,你必須調用 GraphModule.recompile() 方法,使生成的 forward() 方法與修改后的 Graph 保持同步。

假設你已經傳入了一個經過追蹤轉換為 Graphtorch.nn.Module,現在主要有兩種方法來構建新的 Graph


圖結構快速入門

關于圖的語義完整說明可以參考 Graph 文檔,這里我們主要介紹基礎概念。Graph 是一種數據結構,用于表示 GraphModule 上的方法。其核心需要描述以下信息:

  • 方法的輸入參數是什么?
  • 方法內部運行了哪些操作?
  • 方法的輸出(即返回值)是什么?

這三個概念都通過 Node 實例來表示。下面通過一個簡單示例來說明:

import torch
import torch.fxclass MyModule(torch.nn.Module):def __init__(self):super().__init__()self.param = torch.nn.Parameter(torch.rand(3, 4))self.linear = torch.nn.Linear(4, 5)def forward(self, x):return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)m = MyModule()
gm = torch.fx.symbolic_trace(m)gm.graph.print_tabular()

這里我們定義一個演示用的模塊 MyModule,實例化后進行符號追蹤,然后調用 Graph.print_tabular() 方法打印該 Graph 的節點表格:

操作碼名稱目標參數關鍵字參數
placeholderxx(){}
get_attrlinear_weightlinear.weight(){}
call_functionadd_1<built-in function add(x, linear_weight){}
call_modulelinear_1linear(add_1,){}
call_methodrelu_1relu(linear_1,){}
call_functionsum_1<built-in method sum …(relu_1,){‘dim’: -1}
call_functiontopk_1<built-in method topk …(sum_1, 3){}
outputoutputoutput(topk_1,){}

通過這些信息,我們可以回答之前提出的問題:

  • 方法的輸入是什么?
    在FX中,方法輸入通過特殊的 placeholder 節點指定。本例中有一個目標為 xplaceholder 節點,表示存在一個名為x的(非self)參數。
  • 方法內部有哪些操作?
    get_attrcall_functioncall_modulecall_method 節點表示方法中的操作。這些節點的完整語義說明可參考 Node 文檔。
  • 方法的返回值是什么?
    Graph 中,返回值由特殊的 output 節點指定。

現在我們已經了解FX中代碼表示的基本原理,接下來可以探索如何編輯 Graph


圖操作


直接操作計算圖

構建新Graph的一種方法是直接操作原有計算圖。為此,我們可以簡單地獲取通過符號追蹤得到的Graph并進行修改。例如,假設我們需要將所有torch.add()調用替換為torch.mul()調用。


import torch
import torch.fx# Sample module
class M(torch.nn.Module):def forward(self, x, y):return torch.add(x, y)def transform(m: torch.nn.Module,        tracer_class : type = fx.Tracer) -torch.nn.Module:graph : fx.Graph = tracer_class().trace(m)# FX represents its Graph as an ordered list of     # nodes, so we can iterate through them.for node in graph.nodes:# Checks if we're calling a function (i.e:# torch.add)if node.op == 'call_function':# The target attribute is the function# that call_function calls.if node.target == torch.add:node.target = torch.mulgraph.lint() # Does some checks to make sure the                  # Graph is well-formed.return fx.GraphModule(m, graph)

我們還可以進行更復雜的 Graph 重寫操作,例如刪除或追加節點。為了輔助這些轉換,FX 提供了一些用于操作計算圖的實用函數,這些函數可以在 Graph 文檔中找到。

下面展示了一個使用這些 API 追加 torch.relu() 調用的示例。


# Specifies the insertion point. Any nodes added to the # Graph within this scope will be inserted after `node` with traced.graph.inserting_after(node):# Insert a new `call_function` node calling `torch.relu`new_node = traced.graph.call_function(torch.relu, args=(node,))# We want all places that used the value of `node` to     # now use that value after the `relu` call we've added.# We use the `replace_all_uses_with` API to do this.node.replace_all_uses_with(new_node)

對于僅包含替換操作的簡單轉換,您也可以使用子圖重寫器。


使用 replace_pattern() 進行子圖重寫

FX 在直接圖操作的基礎上提供了更高層次的自動化能力。replace_pattern() API 本質上是一個用于編輯 Graph 的"查找/替換"工具。它允許你指定一個 pattern(模式)和 replacement(替換)函數,然后會追蹤這些函數,在圖中找到與 pattern 圖匹配的操作組實例,并用 replacement 圖的副本替換這些實例。這可以極大地自動化繁瑣的圖操作代碼,隨著轉換邏輯變得復雜,手動操作會變得難以維護。


圖操作示例
  • 替換單個操作符
  • 卷積/批量歸一化融合
  • replace_pattern:基礎用法
  • 量化
  • 逆變換

代理/回溯機制

另一種操作 Graph 的方式是復用符號追蹤中使用的 Proxy 機制。例如,假設我們需要編寫一個將 PyTorch 函數分解為更小操作的轉換器:將每個 F.relu(x) 調用轉換為 (x > 0) * x。傳統做法可能是通過圖重寫來插入比較和乘法操作,然后清理原始的 F.relu。但借助 Proxy 對象,我們可以自動將操作記錄到 Graph 中來實現這一過程。

具體實現時,只需將需要插入的操作寫成常規 PyTorch 代碼,并用 Proxy 對象作為參數調用該代碼。這些 Proxy 對象會捕獲對其執行的操作,并將其追加到 Graph 中。


# Note that this decomposition rule can be read as regular Python
def relu_decomposition(x):return (x 0) * xdecomposition_rules = {}
decomposition_rules[F.relu] = relu_decompositiondef decompose(model: torch.nn.Module,        tracer_class : type = fx.Tracer) -torch.nn.Module:"""Decompose `model` into smaller constituent operations.Currently,this only supports decomposing ReLU into itsmathematical definition: (x 0) * x"""graph : fx.Graph = tracer_class().trace(model)new_graph = fx.Graph()env = {}tracer = torch.fx.proxy.GraphAppendingTracer(new_graph)for node in graph.nodes:if node.op == 'call_function' and node.target in decomposition_rules:# By wrapping the arguments with proxies,      # we can dispatch to the appropriate# decomposition rule and implicitly add it# to the Graph by symbolically tracing it.proxy_args = [fx.Proxy(env[x.name], tracer) if isinstance(x, fx.Node) else x for x in node.args]output_proxy = decomposition_rules[node.target](proxy_args)# Operations on `Proxy` always yield new `Proxy`s, and the             # return value of our decomposition rule is no exception.# We need to extract the underlying `Node` from the `Proxy`# to use it in subsequent iterations of this transform.new_node = output_proxy.nodeenv[node.name] = new_nodeelse:# Default case: we don't have a decomposition rule for this             # node, so just copy the node over into the new graph.new_node = new_graph.node_copy(node, lambda x: env[x.name])env[node.name] = new_nodereturn fx.GraphModule(model, new_graph)

除了避免顯式的圖操作外,使用Proxy還允許您將重寫規則指定為原生Python代碼。對于需要大量重寫規則的轉換(如vmap或grad),這通常可以提高規則的可讀性和可維護性。

需要注意的是,在調用Proxy時,我們還傳遞了一個指向底層變量圖的追蹤器。這樣做是為了防止當圖中的操作是n元操作時(例如add是二元運算符),調用Proxy不會創建多個圖追蹤器實例,否則可能導致意外的運行時錯誤。特別是在底層操作不能安全地假設為一元操作時,我們推薦使用這種Proxy方法。

一個使用Proxy進行Graph操作的實際示例可以在這里找到。


解釋器模式

在FX中,一個實用的代碼組織模式是遍歷Graph中的所有Node并執行它們。這種模式可用于多種場景,包括:

  • 運行時分析流經計算圖的值
  • 通過Proxy重新追蹤來實現代碼轉換

例如,假設我們想運行一個GraphModule,并在運行時記錄節點上torch.Tensor的形狀和數據類型屬性。實現代碼可能如下:

import torch
import torch.fx
from torch.fx.node import Nodefrom typing import Dictclass ShapeProp:"""Shape propagation. This class takes a `GraphModule`.Then, its `propagate` method executes the `GraphModule`node-by-node with the given arguments. As each operationexecutes, the ShapeProp class stores away the shape and     element type for the output values of each operation on     the `shape` and `dtype` attributes of the operation's`Node`."""def __init__(self, mod):self.mod = modself.graph = mod.graphself.modules = dict(self.mod.named_modules())def propagate(self, args):args_iter = iter(args)env : Dict[str, Node] = {}def load_arg(a):return torch.fx.graph.map_arg(a, lambda n: env[n.name])def fetch_attr(target : str):target_atoms = target.split('.')attr_itr = self.modfor i, atom in enumerate(target_atoms):if not hasattr(attr_itr, atom):raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")attr_itr = getattr(attr_itr, atom)return attr_itrfor node in self.graph.nodes:if node.op == 'placeholder':result = next(args_iter)elif node.op == 'get_attr':result = fetch_attr(node.target)elif node.op == 'call_function':result = node.target(load_arg(node.args), *load_arg(node.kwargs))elif node.op == 'call_method':self_obj, args = load_arg(node.args)kwargs = load_arg(node.kwargs)result = getattr(self_obj, node.target)(args, *kwargs)elif node.op == 'call_module':result = self.modules[node.target](load_arg(node.args), *load_arg(node.kwargs))# This is the only code specific to shape propagation.# you can delete this `if` branch and this becomes# a generic GraphModule interpreter.if isinstance(result, torch.Tensor):node.shape = result.shapenode.dtype = result.dtypeenv[node.name] = resultreturn load_arg(self.graph.result)

如你所見,為FX實現一個完整的解釋器并不復雜,但卻非常實用。為了簡化這一模式的使用,我們提供了Interpreter類,它封裝了上述邏輯,允許通過方法重寫來覆蓋解釋器執行的某些方面。

除了執行操作外,我們還可以通過向解釋器傳遞Proxy值來生成新的計算圖。

類似地,我們提供了Transformer類來封裝這種模式。Transformer的行為與Interpreter類似,但不同于調用run方法從模塊獲取具體輸出值,你需要調用Transformer.transform()方法來返回一個新的GraphModule,該模塊會應用你通過重寫方法設置的任何轉換規則。


解釋器模式示例
  • 形狀傳播
  • 性能分析器

調試


簡介

在編寫轉換代碼的過程中,我們的代碼往往不會一開始就完全正確。這時就需要進行調試。關鍵在于采用逆向思維:首先檢查調用生成模塊的結果,驗證其正確性;接著審查并調試生成的代碼;最后追溯導致生成代碼的轉換過程并進行調試。

如果您不熟悉調試工具,請參閱輔助章節可用調試工具。


變換編寫中的常見陷阱

  • set迭代順序的不確定性。在Python中,set數據類型是無序的。例如,使用set來存儲Node等對象集合可能導致意外的非確定性行為。比如當迭代一組Node并將其插入Graph時,由于set數據類型是無序的,輸出程序中操作的順序將是非確定性的,且每次程序調用都可能變化。

推薦的替代方案是使用dict數據類型。自Python 3.7起(以及cPython 3.6起),dict保持了插入順序。通過將需要去重的值存儲在dict的鍵中,可以等效地實現set的功能。


檢查模塊的正確性

由于大多數深度學習模塊的輸出都是浮點型 torch.Tensor 實例,因此檢查兩個 torch.nn.Module 的結果是否相等并不像簡單的相等性檢查那樣直接。為了說明這一點,我們來看一個示例:

import torch
import torch.fx
import torchvision.models as modelsdef transform(m : torch.nn.Module) -torch.nn.Module:gm = torch.fx.symbolic_trace(m)# Imagine we're doing some transforms here# <...>gm.recompile()return gmresnet18 = models.resnet18()
transformed_resnet18 = transform(resnet18)input_image = torch.randn(5, 3, 224, 224)assert resnet18(input_image) == transformed_resnet18(input_image)
"""
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
"""

在這里,我們嘗試使用==相等運算符來檢查兩個深度學習模型的值是否相等。然而,這種做法存在兩個問題:首先,該運算符返回的是張量而非布爾值;其次,浮點數值的比較應考慮誤差范圍(或epsilon),以解決浮點運算不可交換性的問題(詳見此處)。

我們可以改用torch.allclose()函數,它會基于相對和絕對容差閾值進行近似比較:

assert torch.allclose(resnet18(input_image), transformed_resnet18(input_image))

這是我們工具箱中的第一個工具,用于檢查轉換后的模塊與參考實現相比是否按預期運行。


調試生成的代碼

由于 FX 在 GraphModule 上生成 forward() 函數,使用傳統的調試技術(如 print 語句或 pdb)會不太直觀。幸運的是,我們有多種方法可以用來調試生成的代碼。


使用 pdb

通過調用 pdb 可以進入正在運行的程序進行調試。雖然表示 Graph 的代碼不在任何源文件中,但當執行前向傳播時,我們仍然可以手動使用 pdb 進入該代碼進行調試。


import torch
import torch.fx
import torchvision.models as modelsdef my_pass(inp: torch.nn.Module, tracer_class : type = fx.Tracer) -torch.nn.Module:graph = tracer_class().trace(inp)# Transformation logic here# <...># Return new Modulereturn fx.GraphModule(inp, graph)my_module = models.resnet18()
my_module_transformed = my_pass(my_module)input_value = torch.randn(5, 3, 224, 224)# When this line is executed at runtime, we will be dropped into an # interactive `pdb` prompt. We can use the `step` or `s` command to # step into the execution of the next line
import pdb; pdb.set_trace()my_module_transformed(input_value)

打印生成的代碼

如果需要多次運行相同的代碼,使用pdb逐步調試到目標代碼可能會有些繁瑣。這種情況下,一個簡單的方法是將生成的forward傳遞代碼直接復制粘貼到你的代碼中,然后在那里進行檢查。


# Assume that `traced` is a GraphModule that has undergone some
# number of transforms# Copy this code for later
print(traced)
# Print the code generated from symbolic tracing. This outputs:
"""
def forward(self, y):x = self.xadd_1 = x + y;  x = y = Nonereturn add_1
"""# Subclass the original Module
class SubclassM(M):def __init__(self):super().__init__()# Paste the generated `forward` function (the one we printed and     # copied above) heredef forward(self, y):x = self.xadd_1 = x + y;  x = y = Nonereturn add_1# Create an instance of the original, untraced Module. Then, create an # instance of the Module with the copied `forward` function. We can # now compare the output of both the original and the traced version.
pre_trace = M()
post_trace = SubclassM()

使用 GraphModule 中的 to_folder 函數

GraphModule.to_folder()GraphModule 中的一個方法,它允許你將生成的 FX 代碼導出到一個文件夾。雖然像打印生成的代碼中那樣直接復制前向傳播代碼通常已經足夠,但使用 to_folder 可以更方便地檢查模塊和參數。


m = symbolic_trace(M())
m.to_folder("foo", "Bar")
from foo import Bar
y = Bar()

運行上述示例后,我們可以查看foo/module.py中的代碼,并根據需要進行修改(例如添加print語句或使用pdb)來調試生成的代碼。


調試轉換過程

既然我們已經確認是轉換過程生成了錯誤代碼,現在就該調試轉換本身了。首先,我們會查閱文檔中的符號追蹤限制部分。在確認追蹤功能按預期工作后,我們的目標就轉變為找出GraphModule轉換過程中出現的問題。編寫轉換部分可能有快速解決方案,如果沒有的話,我們還可以通過多種方式來檢查追蹤模塊:

# Sample Module
class M(torch.nn.Module):def forward(self, x, y):return x + y# Create an instance of `M`
m = M()# Symbolically trace an instance of `M` (returns a GraphModule). In
# this example, we'll only be discussing how to inspect a # GraphModule, so we aren't showing any sample transforms for the # sake of brevity.
traced = symbolic_trace(m)# Print the code produced by tracing the module.
print(traced)
# The generated `forward` function is:
"""
def forward(self, x, y):add = x + y;  x = y = Nonereturn add
"""# Print the internal Graph.
print(traced.graph)
# This print-out returns:
"""
graph():%x : [num_users=1] = placeholder[target=x]%y : [num_users=1] = placeholder[target=y]%add : [num_users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {})return add
"""# Print a tabular representation of the internal Graph.
traced.graph.print_tabular()
# This gives us:
"""
opcode         name    target                   args    kwargs
-------------  ------  -----------------------  ------  --------
placeholder    x       x                        ()      {}
placeholder    y       y                        ()      {}
call_function  add     <built-in function add (x, y)  {}
output         output  output                   (add,)  {}
"""

通過使用上述工具函數,我們可以對比應用轉換前后的追蹤模塊。有時,簡單的視覺對比就足以定位錯誤。如果問題仍不明確,下一步可以嘗試使用 pdb 這類調試器。

以上述示例為基礎,請看以下代碼:

# Sample user-defined function
def transform_graph(module: torch.nn.Module, tracer_class : type = fx.Tracer) -torch.nn.Module:# Get the Graph from our traced Moduleg = tracer_class().trace(module)"""Transformations on `g` go here"""return fx.GraphModule(module, g)# Transform the Graph
transformed = transform_graph(traced)# Print the new code after our transforms. Check to see if it was
# what we expected
print(transformed)

以上述示例為例,假設調用print(traced)時發現轉換過程中存在錯誤。我們需要通過調試器定位問題根源。啟動pdb調試會話后,可以在transform_graph(traced)處設置斷點,然后按s鍵"步入"該函數調用,實時觀察轉換過程。

另一個有效方法是修改print_tabular方法,使其輸出圖中節點的不同屬性(例如查看節點的input_nodesusers關系)。


可用的調試器

最常用的Python調試器是pdb。你可以通過在命令行輸入python -m pdb FILENAME.py來以"調試模式"啟動程序,其中FILENAME是你要調試的文件名。之后,你可以使用pdb的調試器命令逐步執行正在運行的程序。通常的做法是在啟動pdb時設置一個斷點(b LINE-NUMBER),然后調用c讓程序運行到該斷點處。這樣可以避免你不得不使用sn逐行執行代碼才能到達想要檢查的部分。或者,你也可以在想中斷的代碼行前寫入import pdb; pdb.set_trace()。如果添加了pdb.set_trace(),當你運行程序時它會自動進入調試模式(換句話說,你只需在命令行輸入python FILENAME.py而不用輸入python -m pdb FILENAME.py)。一旦以調試模式運行文件,你就可以使用特定命令逐步執行代碼并檢查程序的內部狀態。網上有很多關于pdb的優秀教程,包括RealPython的《Python Debugging With Pdb》。

像PyCharm或VSCode這樣的IDE通常內置了調試器。在你的IDE中,你可以選擇:a)通過調出IDE中的終端窗口(例如在VSCode中選擇View → Terminal)使用pdb,或者b)使用內置的調試器(通常是pdb的圖形化封裝)。


符號追蹤的局限性

FX 采用符號追蹤(又稱符號執行)系統,以可轉換/可分析的形式捕獲程序語義。該系統具有以下特點:

  • 追蹤性:通過實際執行程序(實際是torch.nn.Module或函數)來記錄操作
  • 符號性:執行過程中流經程序的數據并非真實數據,而是符號(FX術語中稱為Proxy)

雖然符號追蹤適用于大多數神經網絡代碼,但它仍存在一些局限性。


動態控制流

符號追蹤的主要局限在于目前不支持動態控制流。也就是說,當循環或if語句的條件可能依賴于程序輸入值時,就無法處理。

例如,我們來看以下程序:

def func_to_trace(x):if x.sum() 0:return torch.relu(x)else:return torch.neg(x)traced = torch.fx.symbolic_trace(func_to_trace)
"""<...>File "dyn.py", line 6, in func_to_traceif x.sum() 0:File "pytorch/torch/fx/proxy.py", line 155, in __bool__return self.tracer.to_bool(self)File "pytorch/torch/fx/proxy.py", line 85, in to_boolraise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
"""

if語句的條件依賴于x.sum()的值,而該值又依賴于函數輸入x。由于x可能發生變化(例如向追蹤函數傳入新的輸入張量時),這就形成了動態控制流。回溯信息會沿著代碼向上追溯,展示這種情況發生的位置。


靜態控制流

另一方面,系統支持所謂的靜態控制流。靜態控制流指的是那些在多次調用中值不會改變的循環或if語句。通常在PyTorch程序中,這種控制流出現在根據超參數決定模型架構的代碼中。舉個具體例子:

import torch
import torch.fxclass MyModule(torch.nn.Module):def __init__(self, do_activation : bool = False):super().__init__()self.do_activation = do_activationself.linear = torch.nn.Linear(512, 512)def forward(self, x):x = self.linear(x)# This if-statement is so-called static control flow.# Its condition does not depend on any input valuesif self.do_activation:x = torch.relu(x)return xwithout_activation = MyModule(do_activation=False)
with_activation = MyModule(do_activation=True)traced_without_activation = torch.fx.symbolic_trace(without_activation)
print(traced_without_activation.code)
"""
def forward(self, x):linear_1 = self.linear(x);  x = Nonereturn linear_1
"""traced_with_activation = torch.fx.symbolic_trace(with_activation)
print(traced_with_activation.code)
"""
import torch
def forward(self, x):linear_1 = self.linear(x);  x = Nonerelu_1 = torch.relu(linear_1);  linear_1 = Nonereturn relu_1
"""

if self.do_activation 這個條件語句不依賴于任何函數輸入,因此它是靜態的。do_activation 可以被視為一個超參數,當 MyModule 的不同實例使用不同參數值時,生成的代碼軌跡也會不同。這是一種有效模式,符號追蹤功能支持這種模式。

許多動態控制流的實例在語義上其實是靜態控制流。通過消除對輸入值的數據依賴,這些實例可以支持符號追蹤。具體方法包括:

  • 將值移至 Module 屬性中
  • 在符號追蹤期間將具體值綁定到參數上

def f(x, flag):if flag: return xelse: return x*2fx.symbolic_trace(f) # Fails!fx.symbolic_trace(f, concrete_args={'flag': True})

在真正動態控制流的情況下,包含此類代碼的程序部分可以被追蹤為對方法的調用(參見使用Tracer類自定義追蹤)或函數調用(參見wrap()),而不是直接追蹤這些代碼本身。


torch函數

FX采用__torch_function__作為攔截調用的機制(更多技術細節請參閱技術概覽)。某些函數(如Python內置函數或math模塊中的函數)不受__torch_function__覆蓋,但我們仍希望在符號追蹤中捕獲它們。例如:

import torch
import torch.fx
from math import sqrtdef normalize(x):"""Normalize `x` by the size of the batch dimension"""return x / sqrt(len(x))# It's valid Python code
normalize(torch.rand(3, 4))traced = torch.fx.symbolic_trace(normalize)
"""<...>File "sqrt.py", line 9, in normalizereturn x / sqrt(len(x))File "pytorch/torch/fx/proxy.py", line 161, in __len__raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
"""

錯誤提示表明內置函數 len 不被支持。

我們可以通過 wrap() API 將此類函數記錄為跟蹤中的直接調用:

torch.fx.wrap('len')
torch.fx.wrap('sqrt')traced = torch.fx.symbolic_trace(normalize)print(traced.code)
"""
import math
def forward(self, x):len_1 = len(x)sqrt_1 = math.sqrt(len_1);  len_1 = Nonetruediv = x / sqrt_1;  x = sqrt_1 = Nonereturn truediv
"""

使用 Tracer 類自定義追蹤功能

Tracer 類是 symbolic_trace 功能的基礎實現類。通過繼承 Tracer 類,可以自定義追蹤行為,例如:

class MyCustomTracer(torch.fx.Tracer):# Inside here you can override various methods# to customize tracing. See the `Tracer` API# referencepass# Let's use this custom tracer to trace through this module
class MyModule(torch.nn.Module):def forward(self, x):return torch.relu(x) + torch.ones(3, 4)mod = MyModule()traced_graph = MyCustomTracer().trace(mod)
# trace() returns a Graph. Let's wrap it up in a # GraphModule to make it runnable
traced = torch.fx.GraphModule(mod, traced_graph)

葉子模塊

葉子模塊是指在符號追蹤過程中作為調用出現,而不會被繼續追蹤的模塊。默認的葉子模塊集合由標準torch.nn模塊實例組成。例如:

class MySpecialSubmodule(torch.nn.Module):def forward(self, x):return torch.neg(x)class MyModule(torch.nn.Module):def __init__(self):super().__init__()self.linear = torch.nn.Linear(3, 4)self.submod = MySpecialSubmodule()def forward(self, x):return self.submod(self.linear(x))traced = torch.fx.symbolic_trace(MyModule())
print(traced.code)
# `linear` is preserved as a call, yet `submod` is traced though.
# This is because the default set of "Leaf Modules" includes all
# standard `torch.nn` modules.
"""
import torch
def forward(self, x):linear_1 = self.linear(x);  x = Noneneg_1 = torch.neg(linear_1);  linear_1 = Nonereturn neg_1
"""

可以通過重寫 Tracer.is_leaf_module() 來自定義葉子模塊集合。


雜項說明

  • 當前無法追蹤張量構造函數(如torch.zerostorch.onestorch.randtorch.randntorch.sparse_coo_tensor):
    • 確定性構造函數(zerosones)仍可使用,其生成的值會作為常量嵌入追蹤記錄。僅當這些構造函數的參數涉及動態輸入大小時才會出現問題,此時可改用ones_likezeros_like作為替代方案。
    • 非確定性構造函數(randrandn)會將單個隨機值嵌入追蹤記錄,這通常不符合預期行為。變通方法是將torch.randn包裝在torch.fx.wrap函數中并調用該包裝函數。

(注:保留所有代碼塊及技術術語原貌,被動語態轉為主動表述,長句拆分后保持技術嚴謹性)


@torch.fx.wrap
def torch_randn(x, shape):return torch.randn(shape)def f(x):return x + torch_randn(x, 5)
fx.symbolic_trace(f)

此行為可能在未來的版本中修復。

  • 類型注解
  • 支持 Python 3 風格的類型注解(例如
    func(x : torch.Tensor, y : int) -torch.Tensor),
    并且會通過符號追蹤保留這些注解。

  • 目前不支持 Python 2 風格的注釋類型注解
    # type: (torch.Tensor, int) -torch.Tensor

  • 目前不支持函數內部局部變量的類型注解。

  • 關于 training 標志和子模塊的注意事項
  • 當使用像 torch.nn.functional.dropout 這樣的函數時,通常會傳入 self.training 作為訓練參數。在 FX 追蹤過程中,這個值很可能會被固定為一個常量。

import torch
import torch.fxclass DropoutRepro(torch.nn.Module):def forward(self, x):return torch.nn.functional.dropout(x, training=self.training)traced = torch.fx.symbolic_trace(DropoutRepro())
print(traced.code)
"""
def forward(self, x):dropout = torch.nn.functional.dropout(x, p = 0.5, training = True, inplace = False);  x = Nonereturn dropout
"""traced.eval()x = torch.randn(5, 3)
torch.testing.assert_close(traced(x), x)
"""
AssertionError: Tensor-likes are not close!Mismatched elements: 15 / 15 (100.0%)
Greatest absolute difference: 1.6207983493804932 at index (0, 2) (up to 1e-05 allowed)
Greatest relative difference: 1.0 at index (0, 0) (up to 0.0001 allowed)
"""

然而,當使用標準的 nn.Dropout() 子模塊時,訓練標志會被封裝起來,并且由于保留了 nn.Module 對象模型,可以對其進行修改。


class DropoutRepro2(torch.nn.Module):def __init__(self):super().__init__()self.drop = torch.nn.Dropout()def forward(self, x):return self.drop(x)traced = torch.fx.symbolic_trace(DropoutRepro2())
print(traced.code)
"""
def forward(self, x):drop = self.drop(x);  x = Nonereturn drop
"""traced.eval()x = torch.randn(5, 3)
torch.testing.assert_close(traced(x), x)

由于這一差異,建議將與動態training標志交互的模塊標記為葉模塊。


API 參考


torch.fx.symbolic_trace(root, concrete_args=None)

符號追蹤 API

給定一個 nn.Module 或函數實例 root,該 API 會返回一個 GraphModule,這是通過記錄追蹤 root 時觀察到的操作構建而成的。

concrete_args 參數允許你對函數進行部分特化,無論是為了移除控制流還是數據結構。

例如:

def f(a, b):if b == True:return a     else:return a * 2

由于控制流的存在,FX通常無法追蹤此過程。不過,我們可以使用concrete_args來針對變量b的值進行特化處理,從而實現追蹤:

f = fx.symbolic_trace(f, concrete_args={"b": False})
assert f(3, False) == 6

請注意,雖然您仍可以傳入不同的b值,但這些值將被忽略。

我們還可以使用concrete_args來消除函數中對數據結構的處理。這將利用pytrees來展平您的輸入。為了避免過度特化,對于不應特化的值,請傳入fx.PH。例如:

def f(x):out = 0for v in x.values():out += vreturn outf = fx.symbolic_trace(f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}})
assert f({"a": 1, "b": 2, "c": 4}) == 7

參數

  • root (Union[torch.nn.Module, Callable]) - 待追蹤并轉換為圖表示形式的模塊或函數
  • concrete_args (Optional[Dict[str, any]]) - 需要部分特化的輸入參數

返回從root記錄的操作所創建的模塊。

返回類型:GraphModule

注意:此API保證向后兼容性。


torch.fx.wrap(fn_or_name)

該函數可在模塊級作用域調用,將fn_or_name注冊為"葉子函數"。

"葉子函數"在FX跟蹤中會保留為CallFunction節點,而不會被進一步跟蹤。


# foo/bar/baz.py
def my_custom_function(x, y):return x * x + y * ytorch.fx.wrap("my_custom_function")def fn_to_be_traced(x, y):# When symbolic tracing, the below call to my_custom_function will be inserted into# the graph rather than tracing it.return my_custom_function(x, y)

該函數也可以等效地用作裝飾器:

# foo/bar/baz.py
@torch.fx.wrap
def my_custom_function(x, y):return x * x + y * y

包裝函數可以被視為"葉子函數",類似于"葉子模塊"的概念,也就是說,這些函數在FX跟蹤中會保留為調用點,而不會被進一步追蹤。

參數

  • fn_or_name (Union[str, Callable]) - 當被調用時,要插入到圖中的函數或全局函數名稱

注意:此API保證向后兼容性。


class torch.fx.GraphModule(*args, **kwargs)

GraphModule 是由 fx.Graph 生成的 nn.Module。GraphModule 具有一個 graph 屬性,以及從該 graph 生成的 codeforward 屬性。

警告:當重新分配 graph 時,codeforward 將自動重新生成。但如果你編輯了 graph 的內容而沒有重新分配 graph 屬性本身,則必須調用 recompile() 來更新生成的代碼。

注意:此 API 保證向后兼容性。


__init__(root, graph, class_name='GraphModule')

構建一個 GraphModule。

參數

  • root (Union[torch.nn.Module , Dict[str, Any])root 可以是 nn.Module 實例,也可以是將字符串映射到任意屬性類型的字典。

root 是 Module 時,Graph 的 Nodes 中 target 字段對基于 Module 的對象(通過限定名稱引用)的任何引用,都會從 root 的 Module 層次結構中的相應位置復制到 GraphModule 的模塊層次結構中。

root 是字典時,Node 的 target 中找到的限定名稱將直接在字典的鍵中查找。字典映射的對象將被復制到 GraphModule 模塊層次結構中的適當位置。

  • graph (Graph)graph 包含此 GraphModule 用于代碼生成的節點
  • class_name (str)name 表示此 GraphModule 的名稱,用于調試目的。如果未設置,所有錯誤消息將報告為源自 GraphModule。將其設置為 root 的原始名稱或在轉換上下文中合理的名稱可能會有所幫助。

注意:此 API 保證向后兼容性。


add_submodule(target, m)

將給定的子模塊添加到self中。

如果target是子路徑且對應位置尚未存在模塊,此方法會安裝空的模塊。

參數

  • target (str) - 新子模塊的完整限定字符串名稱
    (參見nn.Module.get_submodule中的示例了解如何指定完整限定字符串)
  • m (Module) - 子模塊本身;即我們想要安裝到當前模塊中的實際對象

返回

子模塊是否能夠被插入。要使該方法返回True,target表示的鏈中每個對象必須滿足以下條件之一:
a) 尚不存在,或
b) 引用的是nn.Module(而非參數或其他屬性)

返回類型:bool

注意:此API保證向后兼容性。


property code:  str 

返回從該 GraphModule 底層 Graph 生成的 Python 代碼。

delete_all_unused_submodules()

***
Deletes all unused submodules from `self`.A Module is considered “used” if any one of the following is true:
1、It has children that are used
2、Its forward is called directly via a `call_module` node
3、It has a non-Module attribute that is used from a `get_attr` nodeThis method can be called to clean up an `nn.Module` without
manually calling `delete_submodule` on each unused submodule.
***
Note: Backwards-compatibility for this API is guaranteed.delete_submodule(target)

self中刪除指定的子模塊。

如果target不是有效的目標,則不會刪除該模塊。

參數

  • target (str) - 新子模塊的完全限定字符串名稱
    (有關如何指定完全限定字符串的示例,請參閱nn.Module.get_submodule

返回值
表示目標字符串是否引用了我們要刪除的子模塊。返回值為False意味著target不是有效的子模塊引用。

返回類型 : bool

注意:此API保證向后兼容性。


property graph: [Graph](https://pytorch.org/docs/stable/data.html#torch.fx.Graph "torch.fx.graph.Graph")

返回該 GraphModule 底層對應的 Graph


print_readable(print_output=True, include_stride=False, include_device=False, colored=False)

返回為當前 GraphModule 及其子 GraphModule 生成的 Python 代碼

警告:此 API 為實驗性質,且保證向后兼容性。


recompile()

根據其 graph 屬性重新編譯該 GraphModule。在編輯包含的 graph 后應調用此方法,否則該 GraphModule 生成的代碼將過期。

注意:此 API 保證向后兼容性。

返回類型:PythonCode


to_folder(folder, module_name='FxModule')

將模塊以 module_name 名稱轉儲到 folder 目錄下,以便可以通過 from <folder> import <module_name> 方式導入。

參數:

folder (Union [str, os.PathLike]): 用于輸出代碼的目標文件夾路徑
module_name (str): 在輸出代碼時使用的頂層模塊名稱

警告:此 API 為實驗性質,不保證向后兼容性。


class torch.fx.Graph(owning_module=None, tracer_cls=None, tracer_extras=None)

Graph 是 FX 中間表示中使用的主要數據結構。

它由一系列 Node 組成,每個節點代表調用點(或其他語法結構)。這些 Node 的集合共同構成了一個有效的 Python 函數。

例如,以下代碼


import torch
import torch.fxclass MyModule(torch.nn.Module):def __init__(self):super().__init__()self.param = torch.nn.Parameter(torch.rand(3, 4))self.linear = torch.nn.Linear(4, 5)def forward(self, x):return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)m = MyModule()
gm = torch.fx.symbolic_trace(m)

將生成以下圖表:

print(gm.graph)

graph(x):%linear_weight : [num_users=1] = self.linear.weight%add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {})%linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})%relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})%sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1})%topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {})return topk_1

關于Graph中操作的具體語義,請參閱Node文檔。

注意:本API保證向后兼容性。


__init__(owning_module=None, tracer_cls=None, tracer_extras=None)

構建一個空圖。

注意:此 API 保證向后兼容性。


call_function(the_function, args=None, kwargs=None, type_expr=None)

Graph中插入一個call_function類型的Nodecall_function節點表示對Python可調用對象的調用,由the_function指定。

參數

  • the_function (Callable[...*, Any]) – 要調用的函數。可以是任何PyTorch運算符、Python函數,或屬于builtinsoperator命名空間的成員。
  • args (Optional[Tuple[Argument*, ...]]) – 傳遞給被調用函數的位置參數。
  • kwargs (Optional[Dict[str, Argument]]) – 傳遞給被調用函數的關鍵字參數。
  • type_expr (Optional[Any]) – 可選的類型注解,表示該節點輸出值的Python類型。

返回

新創建并插入的call_function節點。

返回類型

Node

注意:此方法的插入點和類型表達式規則與Graph.create_node()相同。

注意:此API保證向后兼容性。


call_method(method_name, args=None, kwargs=None, type_expr=None)

Graph中插入一個call_method節點。call_method節點表示對args第0個元素調用指定方法。

參數

  • method_name (str) - 要應用于self參數的方法名稱。例如,如果args[0]是一個表示TensorNode,那么要對該Tensor調用relu()方法時,需將relu作為method_name傳入。
  • args (Optional[Tuple[Argument*, ...]]) - 要傳遞給被調用方法的位置參數。注意這應該包含一個self參數。
  • kwargs (Optional[Dict[str, Argument]]) - 要傳遞給被調用方法的關鍵字參數
  • type_expr (Optional[Any]) - 可選的類型注解,表示該節點輸出結果的Python類型。

返回

新創建并插入的call_method節點。

返回類型

Node

注意:本方法的插入點和類型表達式規則與Graph.create_node()相同。

注意:此API保證向后兼容性。


call_module(module_name, args=None, kwargs=None, type_expr=None)

Graph中插入一個call_module類型的Node節點。call_module節點表示對Module層級結構中某個Module的forward()函數的調用。

參數

  • module_name (str) - 要調用的Module在層級結構中的限定名稱。例如,若被追蹤的Module有一個名為foo的子模塊,而該子模塊又包含名為bar的子模塊,則應以foo.bar作為module_name來調用該模塊。
  • args (Optional[Tuple[Argument*, ...]]) - 傳遞給被調用方法的位置參數。注意:此處不應包含self參數。
  • kwargs (Optional[Dict[str, Argument]]) - 傳遞給被調用方法的關鍵字參數
  • type_expr (Optional[Any]) - 可選類型注解,表示該節點輸出值的Python類型。

返回

新創建并插入的call_module節點。

返回類型:Node

注意:本方法的插入點與類型表達式規則與Graph.create_node()相同。

注意:本API保證向后兼容性。


create_node(op, target, args=None, kwargs=None, name=None, type_expr=None)

創建一個 Node 并將其添加到當前插入點的 Graph 中。

注意:當前插入點可以通過 Graph.inserting_before()Graph.inserting_after() 進行設置。

參數

  • op (str) - 該節點的操作碼。可選值包括 ‘call_function’、‘call_method’、‘get_attr’、‘call_module’、‘placeholder’ 或 ‘output’。這些操作碼的語義在 Graph 的文檔字符串中有詳細說明。
  • args (Optional[Tuple[Argument*, ...]]) - 該節點的參數元組。
  • kwargs (Optional[Dict[str, Argument]]) - 該節點的關鍵字參數。
  • name (Optional[str]) - 為 Node 指定的可選字符串名稱。這將影響生成的 Python 代碼中賦值給該節點的變量名。
  • type_expr (Optional[Any]) - 可選類型注解,表示該節點輸出值的 Python 類型。

返回

新創建并插入的節點。

返回類型:Node

注意:此 API 保證向后兼容。


eliminate_dead_code(is_impure_node=None)

根據圖中各節點的用戶數量及是否具有副作用,移除所有死代碼。調用前必須確保圖已完成拓撲排序。

參數

  • is_impure_node (Optional[Callable[[Node],* [bool]]]) —— 用于判斷節點是否為非純函數的回調函數。若未提供該參數,則默認使用 Node.is_impure 方法。

返回值:返回布爾值,表示該過程是否導致圖結構發生變更。

返回類型:bool

示例

在消除死代碼前,下方表達式 a = x + 1 中的變量 a 無用戶引用,因此可從圖中安全移除而不影響結果。


def forward(self, x):a = x + 1return x + self.attr_1

消除死代碼后,a = x + 1 已被移除,前向傳播部分的其他代碼保留不變。


def forward(self, x):return x + self.attr_1

警告:死代碼消除機制雖然采用了一些啟發式方法來避免刪除具有副作用的節點(參見 Node.is_impure),但總體覆蓋率非常不理想。因此,除非你明確知道當前 FX 計算圖完全由無副作用的操作構成,或者自行提供了檢測副作用節點的自定義函數,否則不應假設調用此方法是安全可靠的。

注意:本 API 保證向后兼容性。


erase_node(to_erase)

Graph中刪除一個Node。如果該節點在Graph中仍被使用,將拋出異常。

參數

  • to_erase (Node) – 要從Graph中刪除的Node

注意:此API保證向后兼容性。


find_nodes(*, op, target=None, sort=True)

支持快速查詢節點

參數

  • op (str) – 操作名稱
  • target (Optional[Target]) – 節點目標。對于call_function操作,target為必填項;其他操作中target為可選參數。
  • sort ([bool]) – 是否按節點在圖中出現的順序返回結果。

返回值:返回符合指定op和target條件的節點迭代器。

警告:此API為實驗性質,且保證向后兼容。


get_attr(qualified_name, type_expr=None)

向圖中插入一個 get_attr 節點。get_attr 類型的 Node 表示從 Module 層次結構中獲取某個屬性。

參數

  • qualified_name (str) - 要獲取屬性的全限定名稱。例如,若被追蹤的 Module 包含名為 foo 的子模塊,該子模塊又包含名為 bar 的子模塊,而 bar 擁有名為 baz 的屬性,則應將全限定名稱 foo.bar.baz 作為 qualified_name 傳入。
  • type_expr (Optional[Any]) - 可選的類型注解,用于表示該節點輸出值的 Python 類型。

返回

新創建并插入的 get_attr 節點。

返回類型:Node

注意:本方法的插入點與類型表達式規則與 Graph.create_node 方法保持一致。

注意:此 API 保證向后兼容性。


graph_copy(g, val_map, return_output_node=False)

將給定圖中的所有節點復制到 self 中。

參數

  • g (Graph) – 作為節點復制來源的原始圖。
  • val_map (Dict[Node,* Node]) – 用于存儲節點映射關系的字典,鍵為 g 中的節點,值為 self 中的對應節點。注意:val_map 可預先包含值以實現特定值的復制覆蓋。

返回值:如果 g 包含輸出節點,則返回 self 中與 g 輸出值等效的值;否則返回 None

返回類型:Optional[Union [tuple [Argument, …], Sequence [Argument], Mapping [str , Argument], slice , range , Node, str , int , float, bool , complex , [dtype](tensor_attributes.html#torch.dtype "torch.dtype"), Tensor , device , memory_format , layout , OpOverload, [SymInt](torch.html#torch.SymInt "torch.SymInt"), SymBool , SymFloat ]]

注意:本API保證向后兼容性。


inserting_after(n=None)

設置 create_node 及相關方法在圖中插入節點的位置。當在 with 語句中使用時,這會臨時設置插入點,并在 with 語句退出時恢復原位置。


with g.inserting_after(n):...  # inserting after node n
...  # insert point restored to what it was previously
g.inserting_after(n)  #  set the insert point permanently

參數:

n (可選[Node]): 要在其之前插入的節點。如果為None,則會在整個圖的起始位置之后插入。

返回:
一個資源管理器,它會在__exit__時恢復插入點。

注意:此API保證向后兼容性。


inserting_before(n=None)

設置 create_node 及相關方法在圖中插入節點的基準位置。當在 with 語句中使用時,這將臨時設置插入點,并在 with 語句退出時恢復原位置。


with g.inserting_before(n):...  # inserting before node n
...  # insert point restored to what it was previously
g.inserting_before(n)  #  set the insert point permanently

參數:

n (Optional[Node]): 要插入位置的前一個節點。如果為None,則會在整個圖的起始位置前插入。

返回:
一個資源管理器,該管理器會在__exit__時恢復插入點。

注意:此API保證向后兼容性。


lint()

對該圖執行多項檢查以確保其結構正確。具體包括:

  • 檢查節點是否具有正確的所有權(由本圖所有)
  • 檢查節點是否按拓撲順序排列
  • 若該圖擁有所屬的GraphModule,則檢查目標是否存在該GraphModule中

注:本API保證向后兼容性。


node_copy(node, arg_transform=<function Graph.<lambda>>)

將節點從一個圖復制到另一個圖中。arg_transform需要將節點所在圖的參數轉換為目標圖(self)的參數。示例:

# Copying all the nodes in `g` into `new_graph`
g: torch.fx.Graph = ...
new_graph = torch.fx.graph()
value_remap = {}for node in g.nodes:value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])

參數

  • node (Node) – 要復制到 self 中的節點。
  • arg_transform (Callable[[Node], Argument]) – 一個函數,用于將節點 argskwargs 中的 Node 參數轉換為 self 中的等效參數。最簡單的情況下,該函數應從原始圖中節點到 self 的映射表中檢索值。

返回類型:Node

注意:此 API 保證向后兼容性。


property nodes: _node_list

獲取構成該圖的所有節點列表。

請注意,這個Node列表是以雙向鏈表的形式表示的。在迭代過程中進行修改(例如刪除節點、添加節點)是安全的。

返回值:一個雙向鏈表結構的節點列表。注意可以對該列表調用reversed方法來切換迭代順序。


on_generate_code(make_transformer)

在生成 Python 代碼時注冊轉換器函數

參數:

make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]):返回待注冊代碼轉換器的函數。

該函數由 on_generate_code 調用以獲取代碼轉換器。

此函數的輸入參數為當前已注冊的代碼轉換器(若未注冊則為 None),以便在不需要覆蓋時使用。該機制可用于串聯多個代碼轉換器。

返回值:一個上下文管理器,當在 with 語句中使用時,會自動恢復先前注冊的代碼轉換器。


示例:

gm: fx.GraphModule = ...# This is a code transformer we want to register. This code
# transformer prepends a pdb import and trace statement at the very
# beginning of the generated torch.fx code to allow for manual
# debugging with the PDB library.
def insert_pdb(body):return ["import pdb; pdb.set_trace()\n", body]# Registers `insert_pdb`, and overwrites the current registered
# code transformer (given by `_` to the lambda):
gm.graph.on_generate_code(lambda _: insert_pdb)# Or alternatively, registers a code transformer which first
# runs `body` through existing registered transformer, then
# through `insert_pdb`:
gm.graph.on_generate_code(lambda current_trans: (lambda body: insert_pdb(current_trans(body) if current_trans else body))
)gm.recompile()
gm(inputs)  # drops into pdb

該函數也可作為上下文管理器使用,其優勢在于能自動恢復之前注冊的代碼轉換器。


# ... continue from previous examplewith gm.graph.on_generate_code(lambda _: insert_pdb):# do more stuff with `gm`...gm.recompile()gm(inputs)  # drops into pdb# now previous code transformer is restored (but `gm`'s code with pdb
# remains - that means you can run `gm` with pdb here too, until you # run next `recompile()`).

警告:此 API 為實驗性質,且向后兼容。


output(result, type_expr=None)

output Node 插入到 Graph 中。output 節點代表 Python 代碼中的 return 語句。result 是應當返回的值。

參數

  • result (Argument) – 要返回的值。
  • type_expr (Optional[Any]) – 可選的類型注解,表示此節點輸出將具有的 Python 類型。

注意:此方法的插入點和類型表達式規則與 Graph.create_node 相同。

注意:此 API 保證向后兼容性。


output_node()

警告:此 API 為實驗性質,且向后兼容。

返回值類型:Node


placeholder(name, type_expr=None, default_value)

在圖中插入一個placeholder節點。placeholder表示函數的輸入參數。

參數

  • name (str) - 輸入值的名稱。這對應于該Graph所表示函數的位置參數名稱。
  • type_expr (Optional[Any]) - 可選的類型注解,表示該節點輸出值的Python類型。在某些情況下(例如當函數后續用于TorchScript編譯時),這是生成正確代碼所必需的。
  • default_value (Any) - 該函數參數的默認值。注意:為了允許None作為默認值,當參數沒有默認值時,應傳遞inspect.Signature.empty來指定。

返回類型:Node

注意:此方法的插入點和類型表達式規則與Graph.create_node相同。

注意:此API保證向后兼容性。


print_tabular()

以表格形式打印圖的中間表示。注意:此API需要安裝tabulate模塊。

注:該API保證向后兼容性。


process_inputs(*args)

處理參數以便它們可以傳遞到 FX 計算圖中。

警告:此 API 為實驗性質,且向后兼容。


process_outputs(out)

警告:此 API 為實驗性質,且向后兼容。


python_code(root_module, *, verbose=False, include_stride=False, include_device=False, colored=False)

將這段Graph轉換為有效的Python代碼。

參數

  • root_module (str) – 用于查找限定名稱目標的根模塊名稱。通常為’self’。

返回值:src: 表示該對象的Python源代碼
globals: 包含src中全局名稱及其引用對象的字典

返回類型:一個包含兩個字段的PythonCode對象

注意:此API保證向后兼容性。


set_codegen(codegen)

警告:此 API 為實驗性功能,且向后兼容。


class torch.fx.Node(graph, name, op, target, args, kwargs, return_type=None)

Node 是表示 Graph 中單個操作的數據結構。在大多數情況下,Node 表示對各種實體的調用點,例如運算符、方法和模塊(某些例外包括指定函數輸入和輸出的節點)。每個 Node 都有一個由其 op 屬性指定的函數。不同 op 值的 Node 語義如下:

  • placeholder 表示函數輸入。name 屬性指定該值的名稱。target 同樣是參數的名稱。args 包含:1) 空值,或 2) 表示函數輸入默認參數的單個參數。kwargs 無關緊要。占位符對應于圖形輸出中的函數參數(例如 x)。
  • get_attr 從模塊層次結構中檢索參數。name 同樣是獲取結果后賦值的名稱。target 是參數在模塊層次結構中的完全限定名稱。argskwargs 無關緊要。
  • call_function 將自由函數應用于某些值。name 同樣是賦值目標的名稱。target 是要應用的函數。argskwargs 表示函數的參數,遵循 Python 調用約定。
  • call_module 將模塊層次結構中的 forward() 方法應用于給定參數。name 同前。target 是要調用的模塊在模塊層次結構中的完全限定名稱。argskwargs 表示調用模塊時的參數(不包括 self 參數*)。
  • call_method 調用值的方法。name 類似。target 是要應用于 self 參數的方法名稱字符串。argskwargs 表示調用模塊時的參數(包括 self 參數*)。
  • output 在其 args[0] 屬性中包含跟蹤函數的輸出。這對應于圖形輸出中的 “return” 語句。

注意:此 API 保證向后兼容。


property all_input_nodes:  list ['Node'] 
Return all Nodes that are inputs to this Node. This is equivalent to iterating over `args` and `kwargs` and only collecting the values that are Nodes.Returns
List of `Nodes` that appear in the `args` and `kwargs` of this `Node`, in that order.append(x)

在圖的節點列表中,將 x 插入到當前節點之后。

等價于調用 self.next.prepend(x)

參數

  • x (Node) – 要插入到當前節點后的節點。必須屬于同一個圖。

注意:此 API 保證向后兼容。


property args:  tuple [Union [tuple ['Argument', 
...], collections.abc.Sequence ['Argument'], collections.abc.Mapping[str , 'Argument'], slice , range , torch.fx.node.Node, str , int , float, bool , complex , torch.dtype , torch.Tensor, torch.device , torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType], 
...] 

Node的參數元組。參數的具體含義取決于節點的操作碼(opcode)。更多信息請參閱Node文檔字符串。

允許對此屬性進行賦值操作。所有關于使用情況和用戶的記錄都會在賦值時自動更新。


format_node(placeholder_names=None, maybe_return_typename=None)

返回一個描述性的字符串表示形式self

該方法可不帶參數使用,作為調試工具。

此函數也用于Graph__str__方法內部。placeholder_namesmaybe_return_typename中的字符串共同構成了該Graph所屬GraphModule中自動生成的forward函數的簽名。placeholder_namesmaybe_return_typename不應在其他情況下使用。

參數

  • placeholder_names (Optional[list[str]]) - 一個列表,用于存儲表示生成的forward函數中占位符的格式化字符串。僅供內部使用。
  • maybe_return_typename (Optional[list[str]]) - 一個單元素列表,用于存儲表示生成的forward函數輸出的格式化字符串。僅供內部使用。

返回

如果1)我們在Graph__str__方法中將format_node用作內部輔助工具,且2)self是一個占位符Node,則返回None。否則,返回當前Node的描述性字符串表示形式。

返回類型:str

注意:此API保證向后兼容。


insert_arg(idx, arg)

在參數列表的指定索引位置插入一個位置參數。

參數

  • idx ( int ) – 要插入到self.args中元素之前的索引位置。
  • arg (Argument) – 要插入到args中的新參數值

注意:本API保證向后兼容性。


is_impure()

返回該操作是否為不純操作,即判斷其操作是否為占位符或輸出,或者是否為不純的call_functioncall_module

返回值:指示該操作是否不純。

返回類型:bool

警告:此API為實驗性質,且向后兼容。


property kwargs:  dict[str , Union [tuple ['Argument', 
...], collections.abc.Sequence['Argument'], collections.abc.Mapping, [str , 'Argument'], slice , range , torch.fx.node.Node, str , int , float, bool , complex , torch.dtype , torch.Tensor, torch.device , torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType]] 

Node的關鍵字參數字典。參數的解析取決于節點的操作碼。更多信息請參閱Node文檔字符串。

允許對此屬性進行賦值。所有關于使用情況和用戶的統計都會在賦值時自動更新。


property next: Node

返回鏈表中下一個Node節點。

返回值:鏈表中下一個Node節點。


normalized_arguments(root, arg_types=None, kwarg_types=None, normalize_to_only_use_kwargs=False)

返回經過標準化的Python目標參數。這意味著當normalize_to_only_use_kwargs為真時,args/kwargs將與模塊/函數的簽名匹配,并按位置順序僅返回kwargs。

同時會填充默認值。不支持僅限位置參數或可變參數。

支持模塊調用。

可能需要arg_typeskwarg_types來消除重載歧義。

參數

  • root (torch.nn.Module) – 用于解析模塊目標的基模塊
  • arg_types (Optional[Tuple[Any]]) – 參數的元組類型
  • kwarg_types (Optional[Dict[str, Any]]) – 關鍵字參數的字典類型
  • normalize_to_only_use_kwargs ([bool]) – 是否標準化為僅使用kwargs

返回
返回命名元組ArgsKwargsPair,若失敗則返回None

返回類型
Optional[ArgsKwargsPair]

警告:該API為實驗性質,不保證向后兼容。


prepend(x)

在圖的節點列表中,在此節點前插入x。示例:

Before: p -selfbx -x -ax
After:  p -x -selfbx -ax

參數

  • x (Node) – 要放置在該節點之前的節點。必須是同一圖的成員。

注意:此 API 保證向后兼容。


property prev: Node

返回鏈表中當前節點的前一個Node

返回值:鏈表中當前節點的前一個Node


replace_all_uses_with(replace_with, delete_user_cb=<function Node.<lambda>>, *, propagate_meta=False)

將圖中所有使用 self 的地方替換為節點 replace_with

參數

  • replace_with (Node) – 用于替換所有 self 的節點。
  • delete_user_cb (Callable) – 回調函數,用于判斷是否應移除某個使用原 self 節點的用戶節點。
  • propagate_meta ([bool]) – 是否將原節點 .meta 字段的所有屬性復制到替換節點上。出于安全考慮,僅當替換節點本身沒有 .meta 字段時才允許此操作。

返回值

返回受此變更影響的節點列表。

返回類型:list [Node]

注意:此 API 保證向后兼容。


replace_input_with(old_input, new_input)

遍歷 self 的輸入節點,將所有 old_input 實例替換為 new_input

參數

  • old_input (Node) – 需要被替換的舊輸入節點。
  • new_input (Node) – 用于替換 old_input 的新輸入節點。

注意:此 API 保證向后兼容性。


property stack_trace: Optional[str ] 

返回在追蹤過程中記錄的 Python 堆棧跟蹤信息(如果有)。

當使用 fx.Tracer 進行追蹤時,該屬性通常由 Tracer.create_proxy 填充。若需在追蹤過程中記錄堆棧跟蹤以用于調試,請在 Tracer 實例上設置 record_stack_traces = True

當使用 dynamo 進行追蹤時,該屬性默認會由 OutputGraph.create_proxy 填充。

stack_trace 的字符串末尾將包含最內層的調用幀。


update_arg(idx, arg)

更新現有位置參數以包含新值

調用后,self.args[idx] == arg 將成立。

參數

  • idx ( int ) - 要更新元素在 self.args 中的索引位置
  • arg (Argument) - 要寫入 args 的新參數值

注意:此 API 保證向后兼容性。


update_kwarg(key, arg)

更新現有關鍵字參數以包含新值

arg。調用后,self.kwargs[key] == arg

參數

  • key (str) - 要更新的元素在self.kwargs中的鍵名
  • arg (Argument) - 要寫入kwargs的新參數值

注意:此API保證向后兼容性。


class torch.fx.Tracer(autowrap_modules=(math,), autowrap_functions=())

Tracer 是實現 torch.fx.symbolic_trace 符號追蹤功能的類。調用 symbolic_trace(m) 等價于執行 Tracer().trace(m)

可以通過繼承 Tracer 類來覆蓋追蹤過程中的各種行為。具體可覆蓋的行為詳見該類方法的文檔字符串。

注意:此 API 保證向后兼容。


call_module(m, forward, args, kwargs)

該方法定義了當Tracer遇到對nn.Module實例調用時的行為。

默認行為是通過is_leaf_module檢查被調用的模塊是否為葉子模塊。如果是,則在Graph中生成指向mcall_module節點;否則正常調用該Module,并跟蹤其forward函數中的操作。

可通過重寫此方法實現自定義行為,例如:

  • 創建嵌套的追蹤GraphModules
  • 實現跨Module邊界追蹤時的特殊處理

參數說明:

  • m (Module) - 當前被調用的模塊實例
  • forward (Callable) - 待調用模塊的forward()方法
  • args (Tuple) - 模塊調用點的參數元組
  • kwargs (Dict) - 模塊調用點的關鍵字參數字典

返回值:

  • 若生成call_module節點,則返回Proxy代理值
  • 否則返回模塊調用的原始結果

返回類型:任意類型

注意:本API保證向后兼容性。


create_arg(a)

一種方法,用于指定在準備值作為Graph中節點的參數時追蹤的行為。

默認行為包括:

1、遍歷集合類型(如元組、列表、字典)并遞歸地對元素調用create_args

2、給定一個Proxy對象,返回底層IR Node的引用。

3、給定一個非Proxy的Tensor對象,為以下情況生成IR:

  • 對于Parameter,生成一個引用該Parameter的get_attr節點。

  • 對于非Parameter的Tensor,將該Tensor存儲在一個特殊屬性中,并引用該屬性。

可以重寫此方法以支持更多類型。


參數

  • a (Any) – 將被作為ArgumentGraph中使用的值。

返回值:將值a轉換為適當的Argument

返回類型:Argument


注意:此API保證向后兼容。


create_args_for_root(root_fn, is_module, concrete_args=None)

root模塊的簽名創建對應的placeholder節點。該方法會檢查root模塊的簽名并據此生成這些節點,同時支持*args**kwargs參數。

警告:此API為實驗性質,且向后兼容。


create_node(kind, target, args, kwargs, name=None, type_expr=None)

根據給定的目標、參數、關鍵字參數和名稱插入一個圖節點。

該方法可以被重寫,用于在節點創建過程中對使用的值進行額外檢查、驗證或修改。例如,可能希望禁止記錄原地操作。

注意:此API保證向后兼容性。

返回類型:Node


create_proxy(kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None)

根據給定的參數創建一個節點,然后返回包裹在 Proxy 對象中的節點。

如果 kind = ‘placeholder’,則表示我們正在創建一個代表函數參數的節點。若需要編碼默認參數,則使用 args 元組。對于 placeholder 類型的節點,args 在其他情況下為空。

注意:此 API 保證向后兼容性。


get_fresh_qualname(prefix)

獲取一個基于前綴的新名稱并返回。該函數確保生成的名稱不會與圖中現有屬性發生沖突。

注意:此API保證向后兼容。

返回類型:str


getattr(attr, attr_val, parameter_proxy_cache)

該方法定義了當對nn.Module實例調用getattr時,該Tracer的行為表現。

默認情況下,其行為是返回該屬性的代理值。同時會將代理值存入parameter_proxy_cache中,以便后續調用能復用該代理而非新建。

可通過重寫此方法來實現不同行為——例如在查詢參數時不返回代理。

參數說明:

  • attr (str) - 被查詢的屬性名稱
  • attr_val (Any) - 該屬性的值
  • parameter_proxy_cache (Dict[str, Any]) - 屬性名到代理值的映射緩存

返回值:
getattr調用的返回結果。

警告:此API屬于實驗性質,且不保證向后兼容。


is_leaf_module(m, module_qualified_name)

一種用于判斷給定nn.Module是否為"葉子"模塊的方法。

葉子模塊是指出現在IR(中間表示)中的原子單元,通過call_module調用進行引用。默認情況下,PyTorch標準庫命名空間(torch.nn)中的模塊都屬于葉子模塊。除非通過本參數特別指定,否則其他所有模塊都會被追蹤并記錄其組成操作。

參數說明:

  • m (Module) - 被查詢的模塊
  • module_qualified_name (str) - 該模塊到根模塊的路徑。例如,若模塊層級結構中子模塊foo包含子模塊bar,而bar又包含子模塊baz,則該模塊的限定名將顯示為foo.bar.baz

返回類型:bool

注意:本API保證向后兼容性。


iter(obj)

當代理對象被迭代時調用,例如在控制流中使用時。通常我們不知道如何處理,因為我們不知道代理的值,但自定義跟蹤器可以通過 create_node 向圖節點附加更多信息,并可以選擇返回一個迭代器。

注意:此 API 保證向后兼容性。

返回類型:迭代器


keys(obj)

當代理對象的 keys() 方法被調用時觸發。這是在代理對象上調用 ** 時發生的情況。該方法應返回一個迭代器,如果 ** 需要在自定義追蹤器中生效。

注意:此 API 保證向后兼容。

返回類型:任意


path_of_module(mod)

這是一個輔助方法,用于在root模塊的層級結構中查找mod的限定名稱。例如,如果root有一個名為foo的子模塊,而foo又有一個名為bar的子模塊,那么將bar傳入此函數將返回字符串"foo.bar"。

參數

  • mod (str) – 需要獲取限定名稱的Module

返回類型:str

注意:此API保證向后兼容性。


proxy(node)

注意:此 API 保證向后兼容性。

返回類型:Proxy

to_bool(obj)

當代理對象需要轉換為布爾值時調用,例如在控制流中使用時。通常我們無法確定如何處理,因為不知道代理的具體值,但自定義追蹤器可以通過create_node向圖節點附加更多信息,并選擇返回一個值。

注意:此API保證向后兼容。

返回類型:bool


trace(root, concrete_args=None)

追蹤 root 并返回對應的 FX Graph 表示形式。root 可以是 nn.Module 實例或 Python 可調用對象。

請注意,在此調用后,self.root 可能與傳入的 root 不同。例如,當向 trace() 傳遞自由函數時,我們會創建一個 nn.Module 實例作為根節點,并添加嵌入的常量。

參數

  • root (Union[Module, Callable]) – 需要追蹤的 Module 或函數。該參數保證向后兼容性。
  • concrete_args (Optional[Dict[str, any]]) – 不應被視為代理的具體參數。此參數為實驗性功能,其向后兼容性作保證。

返回值:表示傳入 root 語義的 Graph 對象。

返回類型:Graph

注意:此 API 保證向后兼容性。


class torch.fx.Proxy(node, tracer=None)

Proxy對象是Node包裝器,在符號追蹤過程中流經程序,并記錄它們接觸到的所有操作(包括torch函數調用、方法調用和運算符)到不斷增長的FX Graph中。

如果需要進行圖變換,您可以在原始Node上封裝自己的Proxy方法,這樣就可以使用重載運算符向Graph添加額外內容。

Proxy對象不可迭代。換句話說,如果在循環中或作為*args/**kwargs函數參數使用Proxy,符號追蹤器會拋出錯誤。

有兩種主要解決方法:

1、將不可追蹤的邏輯提取到頂層函數中,并使用fx.wrap進行處理。
2、如果控制流是靜態的(即循環次數基于某些超參數),可以保持代碼在原位,并重構為類似形式:

for i in range(self.some_hyperparameter):indexed_item = proxied_value[i]

如需更深入了解 Proxy 的內部實現細節,請查閱 torch/fx/README.md 文件中的 “Proxy” 章節。


注意:本 API 保證向后兼容性。


class torch.fx.Interpreter(module, garbage_collect_values=True, graph=None)

解釋器(Interpreter)會逐節點(Node-by-Node)執行FX圖。這種模式在許多場景下非常有用,包括編寫代碼轉換器以及分析過程。

通過重寫Interpreter類中的方法,可以自定義執行行為。以下是按調用層次結構劃分的可重寫方法映射:

run()+-- run_node+-- placeholder()+-- get_attr()+-- call_function()+-- call_method()+-- call_module()+-- output()

示例

假設我們需要將所有 torch.neg 實例與 torch.sigmoid 互換(包括它們對應的 Tensor 方法等價形式)。我們可以通過如下方式繼承 Interpreter 類:

class NegSigmSwapInterpreter(Interpreter):def call_function(self, target: Target, args: Tuple, kwargs: Dict) -Any:if target == torch.sigmoid:return torch.neg(args, *kwargs)return super().call_function(target, args, kwargs)def call_method(self, target: Target, args: Tuple, kwargs: Dict) -Any:if target == "neg":call_self, args_tail = argsreturn call_self.sigmoid(args_tail, *kwargs)return super().call_method(target, args, kwargs)def fn(x):return torch.sigmoid(x).neg()gm = torch.fx.symbolic_trace(fn)
input = torch.randn(3, 4)
result = NegSigmSwapInterpreter(gm).run(input)
torch.testing.assert_close(result, torch.neg(input).sigmoid())

參數

  • module ( torch.nn.Module ) – 待執行的模塊
  • garbage_collect_values ([bool]) – 是否在模塊執行過程中最后一次使用后刪除值。這能確保執行期間內存使用最優。可以禁用此功能,例如通過查看Interpreter.env屬性來檢查執行中的所有中間值。
  • graph (Optional[Graph]) – 如果傳入該參數,解釋器將執行此圖而非module.graph,并使用提供的模塊參數來滿足任何狀態請求。

注意:此API保證向后兼容性。


boxed_run(args_list)

通過解釋方式運行模塊并返回結果。該過程采用"boxed"調用約定,即傳遞一個參數列表(這些參數會被解釋器自動清除),從而確保輸入張量能夠及時釋放。

注意:本API保證向后兼容性。


call_function(target, args, kwargs)

執行一個call_function節點并返回結果。

參數

  • target (Target) – 該節點的調用目標。關于語義的詳細信息請參閱Node
  • args (Tuple) – 本次調用的位置參數元組
  • kwargs (Dict) – 本次調用的關鍵字參數字典

返回類型:任意類型

返回值: 函數調用返回的值

注意:此API保證向后兼容性。


call_method(target, args, kwargs)

執行一個 call_method 節點并返回結果。

參數

  • target (Target) – 該節點的調用目標。有關語義的詳細信息,請參閱 Node
  • args (Tuple) – 該調用的位置參數元組
  • kwargs (Dict) – 該調用的關鍵字參數字典

返回類型:任意

返回值:方法調用返回的值

注意:此 API 保證向后兼容性。


call_module(target, args, kwargs)

執行一個call_module節點并返回結果。

參數

  • target (Target) – 該節點的調用目標。關于語義的詳細信息請參閱
    Node
  • args (Tuple) – 本次調用的位置參數元組
  • kwargs (Dict) – 本次調用的關鍵字參數字典

返回類型:Any

返回值:模塊調用返回的值

注意:此API保證向后兼容性。


fetch_args_kwargs_from_env(n)

從當前執行環境中獲取節點nargskwargs具體值

參數

  • n (Node) – 需要獲取argskwargs的目標節點

返回值
節點n對應的具體argskwargs

返回類型:Tuple[Tuple, Dict]

注意:本API保證向后兼容性


fetch_attr(target)

self.moduleModule 層級結構中獲取一個屬性。

參數

  • target (str) - 要獲取屬性的全限定名稱

返回

該屬性的值。

返回類型

任意類型

注意:此 API 保證向后兼容。


get_attr(target, args, kwargs)

執行一個 get_attr 節點。該操作會從 self.moduleModule 層級結構中獲取屬性值。

參數

  • target (Target) – 該節點的調用目標。關于語義的詳細信息請參閱 Node
  • args (Tuple) – 本次調用的位置參數元組
  • kwargs (Dict) – 本次調用的關鍵字參數字典

返回值
獲取到的屬性值

返回類型
任意類型

注意:此 API 保證向后兼容性。


map_nodes_to_values(args, n)

遞歸遍歷 args 并在當前執行環境中查找每個 Node 的具體值。

參數

  • args (Argument) – 需要查找具體值的數據結構
  • n (Node)args 所屬的節點。僅用于錯誤報告。

返回類型:Optional[Union [tuple [Argument’, …], Sequence [Argument], Mapping [str , Argument], slice , range , Node, str , int , float, bool , complex , dtype, Tensor , device , memory_format , layout , OpOverload, SymInt, SymBool , SymFloat ]]

注意:此 API 保證向后兼容性。


output(target, args, kwargs)

執行一個output節點。該操作實際上只是獲取output節點引用的值并返回它。

參數

  • target (Target) – 該節點的調用目標。有關語義詳情請參閱
    Node
  • args (Tuple) – 本次調用的位置參數元組
  • kwargs (Dict) – 本次調用的關鍵字參數字典

返回值:輸出節點引用的返回值

返回類型:任意類型

注意:此API保證向后兼容。


placeholder(target, args, kwargs)

執行一個placeholder節點。請注意這是有狀態的:

Interpreter內部維護了一個針對run方法傳入參數的迭代器,本方法會返回該迭代器的next()結果。

參數

  • target (Target) – 該節點的調用目標。關于語義的詳細信息請參閱Node
  • args (Tuple) – 本次調用的位置參數元組
  • kwargs (Dict) – 本次調用的關鍵字參數字典

返回值:獲取到的參數值。

返回類型:任意類型

注意:此API保證向后兼容。


run(*args, initial_env=None, enable_io_processing=True)

通過解釋執行模塊并返回結果。

參數

  • *args – 按位置順序傳遞給模塊的運行參數
  • initial_env (Optional[Dict[Node, Any]]) – 可選的執行初始環境。這是一個將節點映射到任意值的字典。例如,可用于預先填充某些節點的結果,從而在解釋器中僅進行部分求值。
  • enable_io_processing ([bool]) – 如果為true,我們會在使用輸入和輸出之前,先用圖的process_inputs和process_outputs函數對它們進行處理。

返回值:執行模塊后返回的值

返回類型:任意

注意:此API保證向后兼容。


run_node(n)

運行特定節點 n 并返回結果。

根據 node.op 的類型,調用對應的占位符、get_attr、call_function、call_method、call_module 或 output 方法。

參數

  • n (Node) – 需要執行的節點

返回值:執行節點 n 的結果

返回類型:任意類型

注意:此 API 保證向后兼容性。


class torch.fx.Transformer(module)

Transformer 是一種特殊類型的解釋器,用于生成新的 Module。它提供了一個 transform() 方法,返回轉換后的 Module。與 Interpreter 不同,Transformer 不需要參數即可運行,完全基于符號化方式工作。


示例

假設我們需要將所有 torch.neg 實例與 torch.sigmoid 互換(包括它們的 Tensor 方法等效形式)。可以通過如下方式子類化 Transformer

class NegSigmSwapXformer(Transformer):def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:if target == torch.sigmoid:return torch.neg(*args, **kwargs)return super().call_function(target, args, kwargs)def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:if target == "neg":call_self, *args_tail = argsreturn call_self.sigmoid(*args_tail, **kwargs)return super().call_method(target, args, kwargs)def fn(x):return torch.sigmoid(x).neg()gm = torch.fx.symbolic_trace(fn)transformed: torch.nn.Module = NegSigmSwapXformer(gm).transform()
input = torch.randn(3, 4)
torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())

參數

  • module ([GraphModule](https://pytorch.org/docs/stable/data.html#torch.fx.GraphModule "torch.fx.GraphModule")) – 待轉換的Module對象。

注意:此API保證向后兼容性。


call_function(target, args, kwargs)

注意:該 API 保證向后兼容。

返回類型

Any


call_module(target, args, kwargs)

注意:此 API 保證向后兼容。

返回類型

Any


get_attr(target, args, kwargs)

執行一個 get_attr 節點。在 Transformer 中,該方法被重寫以便向輸出圖中插入新的 get_attr 節點。

參數

  • target (Target) – 該節點的調用目標。關于語義的詳細信息請參閱
    Node
  • args (Tuple) – 該調用的位置參數元組
  • kwargs (Dict) – 該調用的關鍵字參數字典

返回類型
Proxy

注意:此 API 保證向后兼容。


placeholder(target, args, kwargs)

執行一個 placeholder 節點。在 Transformer 中,該方法被重寫以便向輸出圖中插入新的 placeholder

參數

  • target (Target) – 該節點的調用目標。關于語義的詳細信息請參閱 Node
  • args (Tuple) – 該調用的位置參數元組
  • kwargs (Dict) – 該調用的關鍵字參數字典

返回類型:Proxy

注意:此 API 保證向后兼容。


transform()

轉換 self.module 并返回轉換后的 GraphModule

注意:此 API 保證向后兼容性。

返回類型 : GraphModule


torch.fx.replace_pattern(gm, pattern, replacement)

在GraphModule的圖結構(gm)中,匹配所有可能的非重疊運算符集及其數據依賴關系(pattern),然后將每個匹配到的子圖替換為另一個子圖(replacement)。

參數

  • gm (GraphModule) - 封裝待操作圖的GraphModule
  • pattern (Union[Callable, GraphModule]) - 需要在gm中匹配并替換的子圖
  • replacement (Union[Callable, GraphModule]) - 用于替換pattern的子圖

返回值:返回一個Match對象列表,表示原始圖中與pattern匹配的位置。如果沒有匹配項則返回空列表。Match定義如下:

class Match(NamedTuple):# Node from which the match was foundanchor: Node# Maps nodes in the pattern subgraph to nodes in the larger graphnodes_map: Dict[Node, Node]

返回類型:List[Match]


示例:

import torch
from torch.fx import symbolic_trace, subgraph_rewriterclass M(torch.nn.Module):def __init__(self) -None:super().__init__()def forward(self, x, w1, w2):m1 = torch.cat([w1, w2]).sum()m2 = torch.cat([w1, w2]).sum()return x + torch.max(m1) + torch.max(m2)def pattern(w1, w2):return torch.cat([w1, w2])def replacement(w1, w2):return torch.stack([w1, w2])traced_module = symbolic_trace(M())subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)

上述代碼會先在 traced_moduleforward 方法中匹配 pattern。模式匹配基于使用-定義關系而非節點名稱進行。例如,若 pattern 中包含 p = torch.cat([a, b]),則可以在原始 forward 函數中匹配到 m = torch.cat([a, b]),即使變量名不同(pm)也不影響。

pattern 中的 return 語句僅根據其值進行匹配,它可能與更大圖中的 return 語句匹配,也可能不匹配。換句話說,模式不必延伸至更大圖的末尾。

當模式匹配成功時,它將從更大的函數中被移除,并由 replacement 替換。如果更大函數中存在多個 pattern 匹配項,每個非重疊的匹配項都會被替換。若出現匹配重疊的情況,則替換重疊匹配集中最先找到的匹配項(此處的"最先"定義為節點使用-定義關系拓撲排序中的第一個節點。大多數情況下,第一個節點是緊接 self 后出現的參數,而最后一個節點是函數返回的內容)。

需要特別注意:pattern 可調用對象的參數必須在該可調用對象內部使用,且 replacement 可調用對象的參數必須與模式匹配。第一條規則解釋了為何上述代碼塊中 forward 函數有參數 x, w1, w2,而 pattern 函數只有參數 w1, w2——因為 pattern 未使用 x,故不應將 x 指定為參數。

關于第二條規則的示例,考慮替換…


def pattern(x, y):return torch.neg(x) + torch.relu(y)

with


def replacement(x, y):return torch.relu(x)

在這種情況下,replacement需要與pattern相同數量的參數(包括xy),即使參數yreplacement中并未使用。

調用subgraph_rewriter.replace_pattern后,生成的Python代碼如下所示:

def forward(self, x, w1, w2):stack_1 = torch.stack([w1, w2])sum_1 = stack_1.sum()stack_2 = torch.stack([w1, w2])sum_2 = stack_2.sum()max_1 = torch.max(sum_1)add_1 = x + max_1max_2 = torch.max(sum_2)add_2 = add_1 + max_2return add_2

注意:該 API 保證向后兼容。


torch.fx.experimental


警告:這些API屬于實驗性質,可能會隨時變更而不另行通知。


torch.fx.experimental.symbolic_shapes

ShapeEnv
DimDynamic控制如何為維度分配符號。
StrictMinMaxConstraint對客戶端:該維度的大小必須在’vr’范圍內(指定包含性上下界),且必須為非負數且不應為0或1(但參見下方注意事項)。
RelaxedUnspecConstraint對客戶端:無顯式約束;約束由追蹤過程中的守衛隱式推斷得出。
EqualityConstraint表示并判定輸入源之間的各類相等性約束。
SymbolicContext數據結構,指定在create_symbolic_sizes_strides_storage_offset中如何創建符號;例如,應為靜態還是動態。
StatelessSymbolicContext通過DimDynamicDimConstraint給定的symbolic_context判定,在create_symbolic_sizes_strides_storage_offset中創建符號。
StatefulSymbolicContext通過Source:Symbol緩存給定的symbolic_context判定,在create_symbolic_sizes_strides_storage_offset中創建符號。
SubclassSymbolicContext可追蹤張量子類的內部張量的正確符號上下文可能與外部符號上下文不同。
DimConstraints針對符號維度約束系統的自定義求解器。
ShapeEnvSettings封裝所有可能影響FakeTensor調度的形狀環境設置。
ConvertIntKey
CallMethodKey
PropagateUnbackedSymInts
DivideByKey
InnerTensorKey
hint_int獲取整數的提示值(基于運行時觀察到的底層實際值)。
is_concrete_int檢查SymInt底層對象是否為具體值的實用工具。
is_concrete_bool檢查SymBool底層對象是否為具體值的實用工具。
is_concrete_float檢查SymInt底層對象是否為具體值的實用工具。
has_free_symbolsbool(free_symbols(val))的快速版本
has_free_unbacked_symbolsbool(free_unbacked_symbols(val))的快速版本
definitely_true僅當能確定a為True時返回True,過程中可能引入守衛。
definitely_false僅當能確定a為False時返回True,過程中可能引入守衛。
guard_size_oblivious以無視大小的方式對符號布爾表達式執行守衛。
sym_eq類似==,但在列表/元組上運行時,會遞歸測試相等性并使用sym_and連接結果,不引入守衛。
constrain_range應用約束使傳入的SymInt必須在min-max范圍內(包含邊界),且不引入SymInt的守衛(意味著可用于未綁定的SymInt)。
constrain_unify給定兩個SymInt,約束它們必須相等。
canonicalize_bool_expr通過將布爾表達式轉換為lt/le不等式并將所有非常量項移至右側,實現規范化。
statically_known_true如果x可簡化為常量且為真,則返回True。
lru_cache
check_consistent測試兩個"meta"值(通常為Tensor或SymInt)是否具有相同的值,例如在重追蹤后。
compute_unbacked_bindings在運行fake tensor傳播并生成example_value結果后,遍歷example_value查找新綁定的未支持符號并記錄其路徑供后續使用。
rebind_unbacked假設我們正在重追蹤一個已有FX圖,該圖先前進行過fake tensor傳播(因此存在未支持的SymInt)。
resolve_unbacked_bindings
is_accessor_node

torch.fx.experimental.proxy_tensor

make_fx給定函數f,返回一個新函數。當使用有效參數執行該函數時,會返回一個FX GraphModule,表示執行過程中所執行的操作集合。
handle_sym_dispatch調用當前活動的代理跟蹤模式,對操作SymInt/SymFloat/SymBool參數的函數進行符號調度跟蹤。
get_proxy_mode獲取當前活動的代理跟蹤模式,如果當前未處于跟蹤狀態則返回None。
maybe_enable_thunkify在此上下文管理器內,如果正在進行make_fx跟蹤,將對所有SymNode計算進行thunkify處理,并避免將其跟蹤到圖中,除非確實需要。
maybe_disable_thunkify在某個上下文中禁用thunkification功能。

2025-05-10(六)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/905399.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/905399.shtml
英文地址,請注明出處:http://en.pswp.cn/news/905399.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

可觀測性方案怎么選?SelectDB vs Elasticsearch vs ClickHouse

可觀測性&#xff08;Observability&#xff09;是指通過系統的外部輸出數據&#xff0c;推斷其內部狀態的能力。可觀測性平臺通過采集、存儲、可視化分析三大可觀測性數據&#xff1a;日志&#xff08;Logging&#xff09;、鏈路追蹤&#xff08;Tracing&#xff09;和指標&am…

機器人廚師上崗!AI在餐飲界掀起新風潮!

想要了解人工智能在其他各個領域的應用&#xff0c;可以查看下面一篇文章 《AI在各領域的應用》 餐飲業是與我們日常生活息息相關的行業&#xff0c;而人工智能&#xff08;AI&#xff09;正在迅速改變這個傳統行業的面貌。從智能點餐到食材管理&#xff0c;再到個性化推薦&a…

Linux動態庫靜態庫總結

靜態庫生成 g -c mylib.cpp -o mylib.o ar rcs libmylib.a mylib.o 動態庫生成 g -fPIC -shared mylib.cpp -o libmylib.so -fPIC&#xff1a;生成位置無關代碼&#xff08;Position-Independent Code&#xff09;&#xff0c;對動態庫必需。 庫文件使用&#xff1a; 靜態庫&…

通過user-agent來源判斷阻止爬蟲訪問網站,并防止生成[ error ] NULL日志

一、TP5.0通過行為&#xff08;Behavior&#xff09;攔截爬蟲并避免生成 [ error ] NULL 錯誤日志 1. 創建行為類&#xff08;攔截爬蟲&#xff09; 在 application/common/behavior 目錄下新建BlockBot.php &#xff0c;用于識別并攔截爬蟲請求&#xff1a; <?php name…

OpenHarmony平臺驅動開發(十五),SDIO

OpenHarmony平臺驅動開發&#xff08;十五&#xff09; SDIO 概述 功能簡介 SDIO&#xff08;Secure Digital Input and Output&#xff09;由SD卡發展而來&#xff0c;與SD卡統稱為MMC&#xff08;MultiMediaCard&#xff09;&#xff0c;二者使用相同的通信協議。SDIO接口…

使用FastAPI和React以及MongoDB構建全棧Web應用03 全棧開發快速入門

一、什么是全棧開發 A full-stack web application is a complete software application that encompasses both the frontend and backend components. It’s designed to interact with users through a web browser and perform actions that involve data processing and …

Coco AI 開源應用程序 - 搜索、連接、協作、您的個人 AI 搜索和助手,都在一個空間中。

一、軟件介紹 文末提供程序和源碼下載 Coco AI 是一個統一的搜索平臺&#xff0c;可將您的所有企業應用程序和數據&#xff08;Google Workspace、Dropbox、Confluent Wiki、GitHub 等&#xff09;連接到一個功能強大的搜索界面中。此存儲庫包含為桌面和移動設備構建的 Coco 應…

CSS經典布局之圣杯布局和雙飛翼布局

目標&#xff1a; 中間自適應&#xff0c;兩邊定寬&#xff0c;并且三欄布局在一行展示。 圣杯布局 實現方法&#xff1a; 通過float搭建布局margin使三列布局到一行上relative相對定位調整位置&#xff1b; 給外部容器添加padding&#xff0c;通過相對定位調整左右兩列的…

# 實時英文 OCR 文字識別:從攝像頭到 PyQt5 界面的實現

實時英文 OCR 文字識別&#xff1a;從攝像頭到 PyQt5 界面的實現 引言 在數字化時代&#xff0c;文字識別技術&#xff08;OCR&#xff09;在眾多領域中發揮著重要作用。無論是文檔掃描、車牌識別還是實時視頻流中的文字提取&#xff0c;OCR 技術都能提供高效且準確的解決方案…

<C#>log4net 的配置文件配置項詳細介紹

log4net 是一個功能強大的日志記錄工具&#xff0c;通過配置文件可以靈活地控制日志的輸出方式、格式、級別等。以下是對 log4net 配置文件常見配置項的詳細介紹&#xff1a; 根元素 <log4net> 這是 log4net 配置文件的根元素&#xff0c;所有配置項都要包含在該元素內…

編譯docker版openresty

使用alpine為基礎鏡像 # 使用Alpine作為基礎鏡像 FROM alpine:3.18# 替換為阿里云鏡像源&#xff0c;并安裝必要的依賴 RUN sed -i s|https://dl-cdn.alpinelinux.org/alpine|https://mirrors.aliyun.com/alpine|g /etc/apk/repositories && \apk add --no-cache \bui…

conda 輸出指定python環境的庫 輸出為 yaml文件

conda 輸出指定python環境的庫 輸出為 yaml文件。 有時為了項目部署&#xff0c;需要匹配之前的python環境&#xff0c;需要輸出對應的python依賴庫。 假設你的目標環境名為 myenv&#xff0c;運行以下命令&#xff1a; conda env export -n myenv > myenv_environment.ym…

[Java][Leetcode middle] 121. 買賣股票的最佳時機

暴力循環 總是以最低的價格買入&#xff0c;以最高的價格賣出: 例如第一天買入&#xff0c;去找剩下n-1天的最高價格&#xff0c;計算利潤 依次計算到n-1天買入&#xff1b; 比較上述利潤 // 運行時間超時。 o(n^2)public int maxProfit1(int[] prices) {int profit 0;for (i…

克隆虛擬機組成集群

一、克隆虛擬機 1. 準備基礎虛擬機 確保基礎虛擬機已安裝好操作系統&#xff08;如 Ubuntu&#xff09;、Java 和 Hadoop。關閉防火墻并禁用 SELinux&#xff08;如適用&#xff09;&#xff1a; bash sudo ufw disable # Ubuntu sudo systemctl disable firewalld # CentO…

記錄一次使用thinkphp使用PhpSpreadsheet擴展導出數據,解決身份證號碼等信息科學計數法問題處理

PhpSpreadsheet官網 PhpSpreadsheet安裝 composer require phpoffice/phpspreadsheet使用composer安裝時一定要下載php對應的版本&#xff0c;下載之前使用php -v檢查當前php版本 簡單使用 <?php require vendor/autoload.php;use PhpOffice\PhpSpreadsheet\Spreadshee…

前端工程化:從 Webpack 到 Vite

引言 前端工程化是現代Web開發不可或缺的一部分&#xff0c;它通過自動化流程和標準化實踐&#xff0c;提高了開發效率和代碼質量。在這個領域中&#xff0c;構建工具扮演著核心角色&#xff0c;而Webpack和Vite則是其中的兩位重要角色。本文將探討前端工程化的演進歷程&#…

Leetcode 3543. Maximum Weighted K-Edge Path

Leetcode 3543. Maximum Weighted K-Edge Path 1. 解題思路2. 代碼實現 題目鏈接&#xff1a;3543. Maximum Weighted K-Edge Path 1. 解題思路 這一題思路上就是一個遍歷的思路&#xff0c;我們只需要考察每一個節點作為起點時&#xff0c;所有長為 k k k的線段的長度&…

香橙派zero3 安卓TV12,更換桌面launcher,開機自啟動kodi

打開開發者模式&#xff0c;連擊版本號&#xff0c;基本上都是這樣。 adb連接 查找桌面包名 adb shell dumpsys activity activities | findstr mResumedActivity 禁用原桌面com.android.tv.launcher&#xff0c;已經安裝了projectivylauncher434.apk桌面。 adb shell pm …

半小時快速入門Spring AI:使用騰訊云編程助手CodeBuddy 開發簡易聊天程序

引言 隨著人工智能&#xff08;AI&#xff09;技術的飛速發展&#xff0c;越來越多的開發者開始探索如何將AI集成到自己的應用中。人工智能正在迅速改變各行各業的工作方式&#xff0c;從自動化客服到智能推薦系統&#xff0c;AI的應用幾乎無處不在。Spring AI作為一種開源框架…

【unity游戲開發——編輯器擴展】使用MenuItem自定義菜單欄拓展

免職聲明&#xff1a; 1、目前本博客分享的大部分知識產出方式是&#xff1a;學習別人知識自己實際做一遍自己的理解擴展內容自己整理、歸納、總結再分享。2、正如博客簡介所說&#xff1a;這里沒有教程&#xff0c;這里只做學習分享。所有的內容都是學習筆記&#xff0c;可以說…