序號 | 系列文章 |
---|---|
1 | 深度學習訓練中GPU內存管理 |
2 | 深度學習PyTorch之數據加載DataLoader |
3 | 深度學習 PyTorch 中 18 種數據增強策略與實現 |
4 | 深度學習pytorch之簡單方法自定義9類卷積即插即用 |
5 | 深度學習PyTorch之13種模型精度評估公式及調用方法 |
6 | 深度學習pytorch之4種歸一化方法(Normalization)原理公式解析 |
7 | 深度學習pytorch之19種優化算法(optimizer)解析 |
8 | 深度學習pytorch之22種損失函數數學公式和代碼定義 |
9 | DIY損失函數–以自適應邊界損失為例 |
10 | 深度學習PyTorch之動態計算圖可視化 - 使用 torchviz 生成計算圖 |
文章目錄
- 前言
- 1. 什么是動態計算圖?
- 2. 為什么要可視化計算圖?
- 3. 使用 `torchviz` 生成計算圖
- 3.1 安裝 `torchviz`
- 3.2 生成計算圖完整代碼示例
- 3.3 在訓練過程中生成計算圖
- 3.4 代碼解讀
- 3.5 生成的計算圖
- 4. `torchviz` 的更多應用
- 5. 總結
- 參考文獻
前言
在深度學習模型的開發過程中,理解和可視化模型的計算圖對于調試、優化和教學都具有重要意義。PyTorch 采用的是動態圖機制,這使得每次前向傳播時計算圖都被動態創建。而 torchviz
是一個非常有用的工具,它可以將這些動態圖轉化為可視化圖形,幫助我們更直觀地理解模型的計算過程。在本篇博客中,我們將重點介紹如何使用 torchviz
生成和保存 PyTorch 模型的計算圖,并結合實際訓練代碼進行展示。
1. 什么是動態計算圖?
在 PyTorch 中,計算圖并不是在模型初始化時構建好的,而是通過前向傳播過程動態地構建的。這種動態特性意味著每次運行時,計算圖會根據輸入數據的形狀和大小而變化,因此我們可以靈活地進行調試和優化。PyTorch 的動態圖提供了較高的靈活性,允許在計算圖中嵌入復雜的控制流結構(例如循環和條件判斷)。
2. 為什么要可視化計算圖?
可視化計算圖的優勢在于:
- 調試:通過查看每一層的輸入輸出,可以快速發現模型設計上的問題。
- 優化:通過分析計算圖,可以識別瓶頸和不必要的計算,進而優化模型性能。
- 教學:對于新手來說,計算圖能夠幫助他們理解深度學習模型的前向傳播過程。
雖然 PyTorch 的動態圖功能非常強大,但由于它不提供直接的計算圖展示方式,因此我們需要借助外部工具 torchviz
進行可視化。
3. 使用 torchviz
生成計算圖
torchviz
是一個能夠將 PyTorch 計算圖轉化為圖形的庫,具體來說,它能夠將計算圖渲染為 DOT
格式并生成可視化圖像文件(如 PNG 或 PDF)。我們通過以下幾步可以生成計算圖:
3.1 安裝 torchviz
首先,你需要安裝 torchviz
庫。可以通過 pip
安裝:
pip install torchviz
此時會直接將graphviz,torchziv兩個都安裝好,但是這種方法無法將graphviz導入系統路徑。出現報錯graphviz.backend.ExecutableNotFound: failed to execute ‘dot‘, make sure the Graphviz executables are***,需要從網址 Download | Graphviz下載graphviz的zip格式文件,解壓后復制到以下python路徑下即可。
3.2 生成計算圖完整代碼示例
核心語句只包括make_dot和render兩個函數,其中:
- make_dot(y) 會根據輸入張量 y 的計算過程生成計算圖。
- render(“model_graph”, format=“png”) 將計算圖保存為 PNG 圖片。
import torch
import torch.nn as nn
import torch.optim as optim
from torchviz import make_dot# 定義一個簡單的神經網絡
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(2, 2)self.fc2 = nn.Linear(2, 1)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 創建模型實例
model = SimpleNN()# 輸入數據
x = torch.randn(1, 2)# 前向傳播
y = model(x)# 可視化計算圖
dot = make_dot(y, params=dict(model.named_parameters()))
dot.render("model_graph", format="png") # 保存圖像為png文件
復制以上代碼運行后生成model_graph.png如
3.3 在訓練過程中生成計算圖
假設你已經有了一個標準的 PyTorch 訓練代碼,并且希望在訓練過程中生成計算圖。我們可以在每次前向傳播時使用 torchviz.make_dot
來生成計算圖,并保存為 PNG 文件。
以下是一個集成計算圖生成的訓練代碼示例:
import torch
from torchviz import make_dotfor epo in range(epo_num):print(epo)train_loss = 0train_acc = 0.0seg_model.train()for index, (img, label) in enumerate(train_dataloader):img = img.to(device)label = label.to(device)optimizer.zero_grad()output = seg_model(img) # 得到模型輸出# 使用 torch.sigmoid 激活函數output = torch.sigmoid(output)# 生成計算圖并保存為 PNG 文件if index == 0: # 只在第一個batch時生成計算圖dot = make_dot(output, params=dict(seg_model.named_parameters()))dot.render("model_graph_epoch_{}_batch_{}".format(epo, index), format="png") # 保存為 epoch_x_batch_y.png# 計算損失loss = criterion(output, label)loss.backward()iter_loss = loss.item()all_train_iter_loss.append(iter_loss)train_loss += iter_lossoptimizer.step()# 計算準確率output_1 = output.argmax(dim=1)label_1 = label.argmax(dim=1)correct = torch.eq(output_1, label_1).sum().item()iter_acc = correct / label_1.numel()all_train_iter_acc.append(iter_acc)train_acc += iter_acc
3.4 代碼解讀
-
前向傳播:
output = seg_model(img)
這一行執行了前向傳播,計算了模型的輸出。 -
計算圖生成:在每個 epoch 的第一個 batch 中,使用
make_dot(output, params=dict(seg_model.named_parameters()))
來生成計算圖。output
是模型的輸出,而seg_model.named_parameters()
則提供了模型的參數信息,這對于生成完整的計算圖非常有幫助。 -
保存計算圖:通過
dot.render()
將計算圖保存為 PNG 格式的文件。文件名包含當前的 epoch 和 batch 索引,以便于區分。dot.render("model_graph_epoch_{}_batch_{}".format(epo, index), format="png")
3.5 生成的計算圖
計算圖會包含模型中的每個操作(如矩陣乘法、加法等),以及這些操作之間的連接關系。通過計算圖(以下示例),你可以清楚地看到模型的每一步計算如何進行。
4. torchviz
的更多應用
除了在訓練過程中生成計算圖,torchviz
還可以用于以下場景:
-
單步調試:如果你的模型非常復雜,可以在某個特定步驟(如單個前向傳播)生成計算圖,幫助調試。
-
模型設計:在設計新的網絡架構時,通過生成計算圖,可以確保每一層的輸入輸出形狀是正確的。
-
計算性能分析:通過分析計算圖中的每個節點,可以識別出性能瓶頸并進行優化。
5. 總結
PyTorch 的動態圖特性使得每次前向傳播時計算圖都是動態生成的,而 torchviz
則提供了一個簡便的工具,可以將這些動態生成的計算圖可視化為圖像文件。通過將 torchviz
集成到訓練代碼中,我們可以在訓練過程中實時生成計算圖,這不僅有助于我們調試模型,還可以為教學和研究提供更清晰的解釋。
參考文獻
- torchviz GitHub
- PyTorch 官方文檔