pytorch小記(十):pytorch中torch.tril 和 torch.triu 詳解
- PyTorch `torch.tril` 和 `torch.triu` 詳解
- 1. `torch.tril`(計算下三角矩陣)
- 📌 作用
- 🔍 語法
- 🔹 參數
- 📌 示例
- 🔍 `diagonal` 參數
- 🔍 `torch.tril` 的應用
- 2. `torch.triu`(計算上三角矩陣)
- 📌 作用
- 🔍 語法
- 🔹 參數
- 📌 示例
- 🔍 `diagonal` 參數
- 3. `torch.tril` vs `torch.triu` 對比
- 總結
PyTorch torch.tril
和 torch.triu
詳解
在數值計算和深度學習中,下三角矩陣(Lower Triangular Matrix) 和 上三角矩陣(Upper Triangular Matrix) 是非常常見的矩陣操作。PyTorch 提供了 torch.tril()
和 torch.triu()
這兩個函數,分別用于計算下三角矩陣和上三角矩陣。
1. torch.tril
(計算下三角矩陣)
📌 作用
torch.tril
返回輸入張量的 下三角部分,即:
- 保留 主對角線及其以下的元素。
- 主對角線以上的元素全部變為 0。
🔍 語法
torch.tril(input, diagonal=0)
🔹 參數
參數 | 說明 |
---|---|
input | 輸入張量 |
diagonal | 控制對角線位置(默認 0 ) |
diagonal=0 | 保留主對角線 及其以下的元素 |
diagonal>0 | 向上偏移,保留主對角線以上 diagonal 行 |
diagonal<0 | 向下偏移,移除 -diagonal 行的主對角線元素 |
📌 示例
import torch# 創建一個 4×4 的矩陣
A = torch.tensor([[1, 2, 3, 4],[5, 6, 7, 8],[9, 10, 11, 12],[13, 14, 15, 16]
])print("原始矩陣 A:")
print(A)# 計算 A 的下三角矩陣
L = torch.tril(A)
print("\nA 的下三角矩陣(diagonal=0):")
print(L)
輸出:
原始矩陣 A:
tensor([[ 1, 2, 3, 4],[ 5, 6, 7, 8],[ 9, 10, 11, 12],[13, 14, 15, 16]])A 的下三角矩陣(diagonal=0):
tensor([[ 1, 0, 0, 0],[ 5, 6, 0, 0],[ 9, 10, 11, 0],[13, 14, 15, 16]])
💡 說明:主對角線上的元素保留,其上的元素變為
0
。
🔍 diagonal
參數
print(torch.tril(A, diagonal=1)) # 保留主對角線以上 1 行
print(torch.tril(A, diagonal=-1)) # 移除主對角線
輸出:
A 的下三角矩陣(diagonal=1):
tensor([[ 1, 2, 0, 0],[ 5, 6, 7, 0],[ 9, 10, 11, 12],[13, 14, 15, 16]])A 的下三角矩陣(diagonal=-1):
tensor([[ 0, 0, 0, 0],[ 5, 0, 0, 0],[ 9, 10, 0, 0],[13, 14, 15, 0]])
🔺 diagonal=1:向上偏移,保留
1
行主對角線以上的元素。
🔻 diagonal=-1:向下偏移,移除主對角線。
🔍 torch.tril
的應用
📌 用于 Masking(掩碼)
seq_length = 5
mask = torch.tril(torch.ones(seq_length, seq_length)) # 創建一個下三角 Mask
print(mask)
輸出:
tensor([[1., 0., 0., 0., 0.],[1., 1., 0., 0., 0.],[1., 1., 1., 0., 0.],[1., 1., 1., 1., 0.],[1., 1., 1., 1., 1.]])
💡 Transformer 中,這種 Mask 用于防止模型在訓練時提前看到未來的信息。
2. torch.triu
(計算上三角矩陣)
📌 作用
torch.triu
返回輸入張量的 上三角部分,即:
- 保留 主對角線及其以上的元素。
- 主對角線以下的元素全部變為 0。
🔍 語法
torch.triu(input, diagonal=0)
🔹 參數
參數 | 說明 |
---|---|
input | 輸入張量 |
diagonal=0 | 保留主對角線及其以上的元素 |
diagonal>0 | 移除 diagonal 行的主對角線元素 |
diagonal<0 | 保留主對角線以下 -diagonal 行 |
📌 示例
U = torch.triu(A)
print("A 的上三角矩陣(diagonal=0):")
print(U)
輸出:
A 的上三角矩陣(diagonal=0):
tensor([[ 1, 2, 3, 4],[ 0, 6, 7, 8],[ 0, 0, 11, 12],[ 0, 0, 0, 16]])
💡 說明:主對角線上的元素及其上的元素保留,下面的元素變為
0
。
🔍 diagonal
參數
print(torch.triu(A, diagonal=1)) # 移除主對角線元素
print(torch.triu(A, diagonal=-1)) # 保留主對角線以下 1 行
輸出:
A 的上三角矩陣(diagonal=1):
tensor([[ 0, 2, 3, 4],[ 0, 0, 7, 8],[ 0, 0, 0, 12],[ 0, 0, 0, 0]])A 的上三角矩陣(diagonal=-1):
tensor([[ 1, 2, 3, 4],[ 5, 6, 7, 8],[ 0, 10, 11, 12],[ 0, 0, 15, 16]])
🔺 diagonal=1:移除主對角線的元素,僅保留主對角線以上的元素。
🔻 diagonal=-1:允許保留主對角線以下1
行的元素。
3. torch.tril
vs torch.triu
對比
作用 | torch.tril(A) | torch.triu(A) |
---|---|---|
計算結果 | 取下三角部分 | 取上三角部分 |
對角線控制 | diagonal=0 保留主對角線 | diagonal=0 保留主對角線 |
diagonal>0 | 保留主對角線以上元素 | 移除主對角線部分元素 |
diagonal<0 | 移除主對角線部分元素 | 保留主對角線以下部分 |
總結
torch.tril()
取 下三角矩陣,可以用于 Cholesky 分解、Transformer Masking。torch.triu()
取 上三角矩陣,常用于 線性代數計算 和 矩陣變換。
🚀 你可以根據不同的需求選擇合適的函數,在 PyTorch 中高效處理矩陣運算!