最小二乘求解器lstsq,處理帶權重和L2正則的線性回歸

目錄

代碼注釋版:

關鍵功能說明:

torch.linalg.cholesky 的原理

代碼示例

Cholesky 分解的應用

與 torch.cholesky 的區別

總結


代碼注釋版:

from typing import Optionalimport torchdef lstsq(matrix: torch.Tensor, rhs: torch.Tensor, weights: torch.Tensor, l2_regularizer: Optional[torch.Tensor] = None,l2_regularizer_rhs: Optional[torch.Tensor] = None,shared: bool = False
) -> torch.Tensor:"""帶權重和L2正則化的最小二乘求解器,使用Cholesky分解解決形如 (A^T W A + λI) x = A^T W b 的線性系統支持多任務共享參數(通過shared參數合并Gram矩陣和右側項)Args:matrix: 設計矩陣A,形狀為 [batch_size, n_obs, n_params]rhs: 右側項b,形狀為 [batch_size, n_obs, n_outputs]weights: 權重矩陣W的對角元素,形狀為 [batch_size, n_obs]l2_regularizer: L2正則化項λ的對角矩陣,形狀為 [batch_size, n_params, n_params]l2_regularizer_rhs: 正則化項對右側的修正,形狀為 [batch_size, n_params, n_outputs]shared: 是否共享參數(將多個系統的Gram矩陣和右側項求和)Returns:最小二乘解,形狀為 [batch_size, n_params, n_outputs]"""# 加權設計矩陣: W^(1/2) * Aweighted_matrix = weights.unsqueeze(-1) * matrix# 計算正則化的Gram矩陣: A^T W A + λIregularized_gramian = weighted_matrix.mT @ matrixif l2_regularizer is not None:regularized_gramian += l2_regularizer  # 添加L2正則項# 計算右側項: A^T W b + λ_rhsATb = weighted_matrix.mT @ rhsif l2_regularizer_rhs is not None:ATb += l2_regularizer_rhs# 如果共享參數,合并所有batch的貢獻if shared:regularized_gramian = regularized_gramian.sum(dim=0, keepdim=True)ATb = ATb.sum(dim=0, keepdim=True)# Cholesky分解求解chol = torch.linalg.cholesky(regularized_gramian)return torch.cholesky_solve(ATb, chol)def lstsq_partial_share(matrix: torch.Tensor,rhs: torch.Tensor,weights: torch.Tensor,l2_regularizer: torch.Tensor,n_shared: int = 0
) -> torch.Tensor:"""部分參數共享的最小二乘求解器將參數分為共享部分和獨立部分:- 共享參數在所有樣本間共享- 獨立參數每個樣本單獨估計通過分塊回歸實現高效求解Args:matrix: 設計矩陣A,形狀為 [batch_size, n_obs, n_params]rhs: 右側項b,形狀為 [batch_size, n_obs, n_outputs]weights: 權重矩陣的對角元素,形狀為 [batch_size, n_obs]l2_regularizer: 正則化強度,形狀為 [batch_size, n_params]n_shared: 共享參數的數量Returns:參數矩陣,前n_shared列為共享參數,其余為獨立參數形狀為 [batch_size, n_params, n_outputs]"""n_params = matrix.shape[-1]n_rhs_outputs = rhs.shape[-1]n_indep = n_params - n_shared# 全共享情況直接返回廣播結果if n_indep == 0:result = lstsq(matrix, rhs, weights, l2_regularizer, shared=True)return result.expand(matrix.shape[0], -1, -1)# 將正則化項轉換為設計矩陣的擴展部分# 相當于添加 λI 的正則化項matrix = torch.cat([matrix, batch_eye(n_params, matrix.shape[0])], dim=1)rhs = torch.nn.functional.pad(rhs, (0, 0, 0, n_params))  # 右側添加0weights = torch.cat([weights, l2_regularizer.unsqueeze(0).expand(matrix.shape[0], -1)], dim=1)# 分割共享和獨立參數對應的設計矩陣matrix_shared, matrix_indep = torch.split(matrix, [n_shared, n_indep], dim=-1)# 步驟1:求解獨立參數對共享參數和輸出的影響indep_coeffs = lstsq(matrix_indep, torch.cat([matrix_shared, rhs], dim=-1), weights)coeff_indep2shared, coeff_indep2rhs = torch.split(indep_coeffs, [n_shared, n_rhs_outputs], dim=-1)# 步驟2:用殘差求解共享參數shared_residual = matrix_shared - matrix_indep @ coeff_indep2sharedrhs_residual = rhs - matrix_indep @ coeff_indep2rhscoeff_shared2rhs = lstsq(shared_residual, rhs_residual, weights, shared=True)# 步驟3:更新獨立參數系數coeff_indep2rhs = coeff_indep2rhs - coeff_indep2shared @ coeff_shared2rhs# 合并結果:共享參數廣播,獨立參數保持獨立coeff_shared2rhs = coeff_shared2rhs.expand(matrix.shape[0], -1, -1)return torch.cat([coeff_shared2rhs, coeff_indep2rhs], dim=1)def batch_eye(n_params: int, batch_size: int) -> torch.Tensor:"""生成批次對角矩陣Args:n_params: 矩陣維度batch_size: 批次大小Returns:形狀為 [batch_size, n_params, n_params] 的單位矩陣批次"""return torch.eye(n_params).reshape(1, n_params, n_params).expand(batch_size, -1, -1)

關鍵功能說明:

  1. lstsq:

    • 核心最小二乘求解器,處理帶權重和L2正則的線性回歸

    • 使用Cholesky分解提高數值穩定性

    • 支持多任務參數共享模式(shared=True時合并所有任務的貢獻)

  2. lstsq_partial_share:

    • 處理部分參數共享的回歸問題

    • 通過三步分塊回歸實現:

      1. 估計獨立參數對共享參數和輸出的影響

      2. 用殘差估計共享參數

      3. 修正獨立參數估計值

    • 通過矩陣拼接技巧將正則化轉換為設計矩陣擴展

  3. batch_eye:

    • 生成批次單位矩陣,用于構建正則化項

    • 典型應用:將L2正則轉換為擴展設計矩陣的偽觀測

torch.linalg.cholesky 的原理

torch.linalg.cholesky(A) 用于對對稱正定矩陣 AAA 進行 Cholesky 分解,即將其分解為:

A=LLTA = L L^TA=LLT

其中:

  • AAA 是 對稱正定矩陣(必須滿足 A=ATA = A^TA=AT 且所有特征值大于 0)。

  • LLL 是 下三角矩陣

計算 Cholesky 分解 的方式基于逐行計算 LLL:

  1. 計算對角元素:

    Lii=Aii?∑k=1i?1Lik2L_{ii} = \sqrt{ A_{ii} - \sum_{k=1}^{i-1} L_{ik}^2 }Lii?=Aii??k=1∑i?1?Lik2??
  2. 計算非對角元素:

    Lji=1Lii(Aji?∑k=1i?1LjkLik),j>iL_{ji} = \frac{1}{L_{ii}} \left( A_{ji} - \sum_{k=1}^{i-1} L_{jk} L_{ik} \right), \quad j > iLji?=Lii?1?(Aji??k=1∑i?1?Ljk?Lik?),j>i

這個算法 只需要計算下三角部分,所以比 LU 分解 計算量更少,適用于 正定矩陣的快速求解


代碼示例

import torch# 生成一個對稱正定矩陣
A = torch.tensor([[4.0, 12.0, -16.0], [12.0, 37.0, -43.0], [-16.0, -43.0, 98.0]])# Cholesky 分解
L = torch.linalg.cholesky(A)
print(L)

輸出

tensor([[ 2.0000, 0.0000, 0.0000],

[ 6.0000, 1.0000, 0.0000],

[-8.0000, 5.0000, 3.0000]])

可以驗證:

print(torch.mm(L, L.T))# 結果應當等于 A

Cholesky 分解的應用

  1. 解線性方程組 Ax=bAx = bAx=b:

    • 先求 L = torch.linalg.cholesky(A)

    • Ly = b(前代法)

    • L^T x = y(后代法)

  2. 生成多元正態分布

    • 如果協方差矩陣 Σ\SigmaΣ 進行 Cholesky 分解 Σ=LLT\Sigma = L L^TΣ=LLT,

    • 則可以用 L @ torch.randn(n, d) 生成符合協方差 Σ\SigmaΣ 的多元正態分布數據。


torch.cholesky 的區別

  • torch.cholesky(A) 舊版 API,不推薦使用。

  • torch.linalg.cholesky(A) 現代 API,支持 batch 計算,推薦使用。


總結

  • torch.linalg.cholesky(A) 計算 對稱正定矩陣Cholesky 分解,分解成下三角矩陣 L,使得 A=LLTA = L L^TA=LLT。

  • 計算方式比 LU 分解更快,主要用于 正定矩陣的求解、統計學、多元正態分布 等。

  • 使用 Cholesky 分解求解線性方程組比直接求逆更穩定高效。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/76441.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/76441.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/76441.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

AI輔助下基于ArcGIS Pro的SWAT模型全流程高效建模實踐與深度進階應用

目前,流域水資源和水生態問題逐漸成為制約社會經濟和環境可持續發展的重要因素。SWAT模型是一種基于物理機制的分布式流域水文與生態模擬模型,能夠對流域的水循環過程、污染物遷移等過程進行精細模擬和量化分析。SWAT模型目前廣泛應用于流域水文過程研究…

DHT11數字溫濕度傳感器驅動開發全解析(下) | 零基礎入門STM32第八十八步

主題內容教學目的/擴展視頻DHT11芯片電路連接,手冊分析。驅動程序,讀出數據。能讀出溫濕度值即可。 師從洋桃電子,杜洋老師 📑文章目錄 一、硬件接口與通信原理1.1 硬件連接拓撲1.2 單總線通信時序 二、驅動代碼深度解析&#xff…

24、網絡編程基礎概念

網絡編程基礎概念 網絡結構模式MAC地址IP地址子網掩碼端口網絡模型協議網絡通信的過程(封裝與解封裝) 網絡結構模式 C/S結構,由客戶機和服務器兩部分組成,如QQ、英雄聯盟 B/S結構,通過瀏覽器與服務器進程交互&#xf…

【超詳細】講解Ubuntu上如何配置分區方案

Ubuntu 的分區方案 一、通用分區方案(200G為例) EFI系統分區(僅UEFI啟動模式需要,) 大小:512MB–1GB類型:主分區(FAT32格式)掛載點:/boot/efi說明&#xff1…

函數的局部變量和全局變量的區分,Kimi的回答

這段代碼的目的是通過計算 2**i 和 5**i 的首位數字,并將這兩個首位數字的乘積添加到一個集合中,最終返回這些乘積的總和。下面是具體的解釋和問題的分析。 sum_t的角色: sum_t 是一個累加器,用來存儲所有獨特的(不重復…

RNN模型及NLP應用(5/9)——多層RNN、雙向RNN、預訓練

聲明: 本文基于嗶站博主【Shusenwang】的視頻課程【RNN模型及NLP應用】,結合自身的理解所作,旨在幫助大家了解學習NLP自然語言處理基礎知識。配合著視頻課程學習效果更佳。 材料來源:【Shusenwang】的視頻課程【RNN模型及NLP應用…

【3.軟件工程】3.4 原型及相關模型

軟件開發模型進化論:從原型驅動到混合模型的完整指南 🔄 一、模型進化關系全景圖 #mermaid-svg-GcOFjt54gUs4oPeu {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-GcOFjt54gUs4oPeu .error-i…

硬件與軟件的邊界-從單片機到linux的問答詳解

硬件與軟件的邊界——從單片機到 Linux 設備驅動的問答詳解 在嵌入式開發和操作系統領域,經常會有人問: “如果一個設備里沒有任何代碼,硬件是不是依然會工作?例如,數據收發、寄存器數據存儲、甚至中斷觸發&#xff…

瑪卡巴卡的k8s知識點問答題(七)

25. 說明 Job 與 CronJob 的功能 Job 功能: 用于運行一次性任務(批處理任務),確保一個或多個 Pod 成功完成任務后退出。 適用于數據處理、備份、測試等場景,任務完成后 Pod 不會自動重啟。 特點: 任務…

【NLP 51、一些LLM模型結構上的變化】

目錄 一、multi-head 共享 二、attention結構 1.傳統的Tranformer結構 2.GPTJ —— 平行放置的Transformer結構 三、歸一化層位置的選擇 1.Post LN: 2.Pre-LN【目前主流】: 3.Sandwich-LN: 四、歸一化函數選擇 1.傳統的歸一化函數 LayerNorm …

VS+Qt配置QtXlsx庫實現execl文件導入導出(全教程)

一、配置QtXlsx 1.1 下載解壓QtXlsxWriter(在github下載即可) 網址:https://github.com/dbzhang800/QtXlsxWriter 1.2 使用qt運行 點擊qtxlsx.pro運行QtXlsxWriter 選擇DesktopQt51211MSVC201564bit編譯器(選擇自己本地電腦qt…

Golang的文件處理優化策略

Golang的文件處理優化策略 一、Golang的文件處理優化策略概述 是一門效率高、易于編程的編程語言,它的文件處理能力也非常強大。 在實際開發中,需要注意一些優化策略,以提高文件處理的效率和性能。 本文將介紹Golang中的文件處理優化策略&…

自學-C語言-基礎-數組、函數、指針、結構體和共同體、文件

這里寫自定義目錄標題 代碼環境:?問題思考:一、數組二、函數三、指針四、結構體和共同體五、文件問題答案: 代碼環境: Dev C ?問題思考: 把上門的字母與下面相同的字母相連,線不能…

VMware+Ubuntu+VScode+ROS一站式教學+常見問題解決

目錄 一.VMware的安裝 二.Ubuntu下載 1.前言 2.Ubuntu版本選擇 三.VMware中Ubuntu的安裝 四.Ubuntu系統基本設置 1.中文更改 2.中文輸入法更改 3. 輔助工具 vmware tools 五.VScode的安裝ros基本插件 1.安裝 2.ros輔助插件下載 六.ROS安裝 1.安裝ros 2.配置ROS…

PostgreSQL pg_repack 重新組織表并釋放表空間

pg_repack pg_repack是 PostgreSQL 的一個擴展,它允許您從表和索引中刪除膨脹,并可選擇恢復聚集索引的物理順序。與CLUSTER和VACUUM FULL不同,它可以在線工作,在處理過程中無需對已處理的表保持獨占鎖定。pg_repack 啟動效率高&a…

5G_WiFi_CE_射頻輸出功率、發射功率控制(TPC)和功率密度測試

目錄 一、規范要求 1、法規目錄: (1)RF Output Power (2)Transmit Power Control (TPC) (3)Power Density 2、限值: 二、EIRP測試方法 (1)測試條件 (2&#xff…

掃描線離散化線段樹解決矩形面積并-洛谷P5490

https://www.luogu.com.cn/problem/P5490 題目描述 求 n n n 個四邊平行于坐標軸的矩形的面積并。 輸入格式 第一行一個正整數 n n n。 接下來 n n n 行每行四個非負整數 x 1 , y 1 , x 2 , y 2 x_1, y_1, x_2, y_2 x1?,y1?,x2?,y2?,表示一個矩形的四個…

Java項目之基于ssm的簡易版營業廳寬帶系統(源碼+文檔)

項目簡介 簡易版營業廳寬帶系統實現了以下功能: 此營業廳寬帶系統利用當下成熟完善的SSM框架,使用跨平臺的可開發大型商業網站的Java語言,以及最受歡迎的RDBMS應用軟件之一的Mysql數據庫進行程序開發。實現了營業廳寬帶系統基礎數據的管理&…

從入門到入土,SQLServer 2022慢查詢問題總結

列為,由于公司原因,作者接觸了一個SQLServer 2022作為數據存儲到項目,可能是上一任的哥們兒離開的時候帶有情緒,所以現在項目的主要問題就是,所有功能都實現了,但是就是慢,列表頁3s打底,客戶很生氣,經過幾周摸爬滾打,作以下總結,作為自己的成長記錄。 一、索引問題…

PDF處理控件Aspose.PDF教程:在Python、Java 和 C# 中旋轉 PDF 文檔

您是否希望快速輕松地在線旋轉PDF文檔?無論您需要修復文檔的方向還是只想重新排列頁面,本指南都能滿足您的需求。有簡單的方法可以解決此問題 - 無論您喜歡在線工具還是編程解決方案。 在本指南中,我們將向您展示如何免費在線旋轉 PDF&#…