目錄
代碼注釋版:
關鍵功能說明:
torch.linalg.cholesky 的原理
代碼示例
Cholesky 分解的應用
與 torch.cholesky 的區別
總結
代碼注釋版:
from typing import Optionalimport torchdef lstsq(matrix: torch.Tensor, rhs: torch.Tensor, weights: torch.Tensor, l2_regularizer: Optional[torch.Tensor] = None,l2_regularizer_rhs: Optional[torch.Tensor] = None,shared: bool = False
) -> torch.Tensor:"""帶權重和L2正則化的最小二乘求解器,使用Cholesky分解解決形如 (A^T W A + λI) x = A^T W b 的線性系統支持多任務共享參數(通過shared參數合并Gram矩陣和右側項)Args:matrix: 設計矩陣A,形狀為 [batch_size, n_obs, n_params]rhs: 右側項b,形狀為 [batch_size, n_obs, n_outputs]weights: 權重矩陣W的對角元素,形狀為 [batch_size, n_obs]l2_regularizer: L2正則化項λ的對角矩陣,形狀為 [batch_size, n_params, n_params]l2_regularizer_rhs: 正則化項對右側的修正,形狀為 [batch_size, n_params, n_outputs]shared: 是否共享參數(將多個系統的Gram矩陣和右側項求和)Returns:最小二乘解,形狀為 [batch_size, n_params, n_outputs]"""# 加權設計矩陣: W^(1/2) * Aweighted_matrix = weights.unsqueeze(-1) * matrix# 計算正則化的Gram矩陣: A^T W A + λIregularized_gramian = weighted_matrix.mT @ matrixif l2_regularizer is not None:regularized_gramian += l2_regularizer # 添加L2正則項# 計算右側項: A^T W b + λ_rhsATb = weighted_matrix.mT @ rhsif l2_regularizer_rhs is not None:ATb += l2_regularizer_rhs# 如果共享參數,合并所有batch的貢獻if shared:regularized_gramian = regularized_gramian.sum(dim=0, keepdim=True)ATb = ATb.sum(dim=0, keepdim=True)# Cholesky分解求解chol = torch.linalg.cholesky(regularized_gramian)return torch.cholesky_solve(ATb, chol)def lstsq_partial_share(matrix: torch.Tensor,rhs: torch.Tensor,weights: torch.Tensor,l2_regularizer: torch.Tensor,n_shared: int = 0
) -> torch.Tensor:"""部分參數共享的最小二乘求解器將參數分為共享部分和獨立部分:- 共享參數在所有樣本間共享- 獨立參數每個樣本單獨估計通過分塊回歸實現高效求解Args:matrix: 設計矩陣A,形狀為 [batch_size, n_obs, n_params]rhs: 右側項b,形狀為 [batch_size, n_obs, n_outputs]weights: 權重矩陣的對角元素,形狀為 [batch_size, n_obs]l2_regularizer: 正則化強度,形狀為 [batch_size, n_params]n_shared: 共享參數的數量Returns:參數矩陣,前n_shared列為共享參數,其余為獨立參數形狀為 [batch_size, n_params, n_outputs]"""n_params = matrix.shape[-1]n_rhs_outputs = rhs.shape[-1]n_indep = n_params - n_shared# 全共享情況直接返回廣播結果if n_indep == 0:result = lstsq(matrix, rhs, weights, l2_regularizer, shared=True)return result.expand(matrix.shape[0], -1, -1)# 將正則化項轉換為設計矩陣的擴展部分# 相當于添加 λI 的正則化項matrix = torch.cat([matrix, batch_eye(n_params, matrix.shape[0])], dim=1)rhs = torch.nn.functional.pad(rhs, (0, 0, 0, n_params)) # 右側添加0weights = torch.cat([weights, l2_regularizer.unsqueeze(0).expand(matrix.shape[0], -1)], dim=1)# 分割共享和獨立參數對應的設計矩陣matrix_shared, matrix_indep = torch.split(matrix, [n_shared, n_indep], dim=-1)# 步驟1:求解獨立參數對共享參數和輸出的影響indep_coeffs = lstsq(matrix_indep, torch.cat([matrix_shared, rhs], dim=-1), weights)coeff_indep2shared, coeff_indep2rhs = torch.split(indep_coeffs, [n_shared, n_rhs_outputs], dim=-1)# 步驟2:用殘差求解共享參數shared_residual = matrix_shared - matrix_indep @ coeff_indep2sharedrhs_residual = rhs - matrix_indep @ coeff_indep2rhscoeff_shared2rhs = lstsq(shared_residual, rhs_residual, weights, shared=True)# 步驟3:更新獨立參數系數coeff_indep2rhs = coeff_indep2rhs - coeff_indep2shared @ coeff_shared2rhs# 合并結果:共享參數廣播,獨立參數保持獨立coeff_shared2rhs = coeff_shared2rhs.expand(matrix.shape[0], -1, -1)return torch.cat([coeff_shared2rhs, coeff_indep2rhs], dim=1)def batch_eye(n_params: int, batch_size: int) -> torch.Tensor:"""生成批次對角矩陣Args:n_params: 矩陣維度batch_size: 批次大小Returns:形狀為 [batch_size, n_params, n_params] 的單位矩陣批次"""return torch.eye(n_params).reshape(1, n_params, n_params).expand(batch_size, -1, -1)
關鍵功能說明:
-
lstsq
:-
核心最小二乘求解器,處理帶權重和L2正則的線性回歸
-
使用Cholesky分解提高數值穩定性
-
支持多任務參數共享模式(shared=True時合并所有任務的貢獻)
-
-
lstsq_partial_share
:-
處理部分參數共享的回歸問題
-
通過三步分塊回歸實現:
-
估計獨立參數對共享參數和輸出的影響
-
用殘差估計共享參數
-
修正獨立參數估計值
-
-
通過矩陣拼接技巧將正則化轉換為設計矩陣擴展
-
-
batch_eye
:-
生成批次單位矩陣,用于構建正則化項
-
典型應用:將L2正則轉換為擴展設計矩陣的偽觀測
-
torch.linalg.cholesky
的原理
torch.linalg.cholesky(A)
用于對對稱正定矩陣 AAA 進行 Cholesky 分解,即將其分解為:
A=LLTA = L L^TA=LLT
其中:
-
AAA 是 對稱正定矩陣(必須滿足 A=ATA = A^TA=AT 且所有特征值大于 0)。
-
LLL 是 下三角矩陣
計算 Cholesky 分解 的方式基于逐行計算 LLL:
-
計算對角元素:
Lii=Aii?∑k=1i?1Lik2L_{ii} = \sqrt{ A_{ii} - \sum_{k=1}^{i-1} L_{ik}^2 }Lii?=Aii??k=1∑i?1?Lik2?? -
計算非對角元素:
Lji=1Lii(Aji?∑k=1i?1LjkLik),j>iL_{ji} = \frac{1}{L_{ii}} \left( A_{ji} - \sum_{k=1}^{i-1} L_{jk} L_{ik} \right), \quad j > iLji?=Lii?1?(Aji??k=1∑i?1?Ljk?Lik?),j>i
這個算法 只需要計算下三角部分,所以比 LU 分解 計算量更少,適用于 正定矩陣的快速求解。
代碼示例
import torch# 生成一個對稱正定矩陣
A = torch.tensor([[4.0, 12.0, -16.0], [12.0, 37.0, -43.0], [-16.0, -43.0, 98.0]])# Cholesky 分解
L = torch.linalg.cholesky(A)
print(L)
輸出:
tensor([[ 2.0000, 0.0000, 0.0000],
[ 6.0000, 1.0000, 0.0000],
[-8.0000, 5.0000, 3.0000]])
可以驗證:
print(torch.mm(L, L.T))# 結果應當等于 A
Cholesky 分解的應用
-
解線性方程組 Ax=bAx = bAx=b:
-
先求
L = torch.linalg.cholesky(A)
-
解
Ly = b
(前代法) -
解
L^T x = y
(后代法)
-
-
生成多元正態分布:
-
如果協方差矩陣 Σ\SigmaΣ 進行 Cholesky 分解 Σ=LLT\Sigma = L L^TΣ=LLT,
-
則可以用
L @ torch.randn(n, d)
生成符合協方差 Σ\SigmaΣ 的多元正態分布數據。
-
與 torch.cholesky
的區別
-
torch.cholesky(A)
舊版 API,不推薦使用。 -
torch.linalg.cholesky(A)
現代 API,支持 batch 計算,推薦使用。
總結
-
torch.linalg.cholesky(A)
計算 對稱正定矩陣 的 Cholesky 分解,分解成下三角矩陣L
,使得 A=LLTA = L L^TA=LLT。 -
計算方式比 LU 分解更快,主要用于 正定矩陣的求解、統計學、多元正態分布 等。
-
使用 Cholesky 分解求解線性方程組比直接求逆更穩定高效。