3D Gaussian splatting 05: 代碼閱讀-訓練整體流程

目錄

  • 3D Gaussian splatting 01: 環境搭建
  • 3D Gaussian splatting 02: 快速評估
  • 3D Gaussian splatting 03: 用戶數據訓練和結果查看
  • 3D Gaussian splatting 04: 代碼閱讀-提取相機位姿和稀疏點云
  • 3D Gaussian splatting 05: 代碼閱讀-訓練整體流程
  • 3D Gaussian splatting 06: 代碼閱讀-訓練參數
  • 3D Gaussian splatting 07: 代碼閱讀-訓練載入數據和保存結果
  • 3D Gaussian splatting 08: 代碼閱讀-渲染

訓練整體流程

程序入參

訓練程序入參除了訓練過程參數, 另外設置了ModelParams, OptimizationParams, PipelineParams三個參數組, 分別控制數據加載、渲染計算和優化訓練環節, 詳細的說明查看下一節 06: 代碼閱讀-訓練參數

    # 命令行參數解析器parser = ArgumentParser(description="Training script parameters")# 模型相關參數lp = ModelParams(parser)op = OptimizationParams(parser)pp = PipelineParams(parser)parser.add_argument('--ip', type=str, default="127.0.0.1")parser.add_argument('--port', type=int, default=6009)parser.add_argument('--debug_from', type=int, default=-1)parser.add_argument('--detect_anomaly', action='store_true', default=False)parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000, 30_000])parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000])parser.add_argument("--quiet", action="store_true")parser.add_argument('--disable_viewer', action='store_true', default=False)parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])parser.add_argument("--start_checkpoint", type=str, default = None)args = parser.parse_args(sys.argv[1:])

開始訓練

程序調用 training() 這個方法開始訓練

torch.autograd.set_detect_anomaly(args.detect_anomaly)
training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from)

初始化

以下是 training() 這個方法中初始化訓練的代碼和對應的注釋說明

# 如果指定了 sparse_adam 加速器, 檢查是否已經安裝
if not SPARSE_ADAM_AVAILABLE and opt.optimizer_type == "sparse_adam":sys.exit(f"Trying to use sparse adam but it is not installed, please install the correct rasterizer using pip install [3dgs_accel].")
# first_iter用于記錄當前是第幾次迭代
first_iter = 0
# 創建本次訓練的輸出目錄和日志記錄器, 每次執行訓練, 都會在 output 目錄下創建一個隨機目錄名
tb_writer = prepare_output_and_logger(dataset)
# 初始化 Gaussian 模型
gaussians = GaussianModel(dataset.sh_degree, opt.optimizer_type)
# 初始化訓練場景, 這里會載入相機參數和稀疏點云等數據
scene = Scene(dataset, gaussians)
# 初始化訓練參數
gaussians.training_setup(opt)
# 如果存在檢查點, 則載入
if checkpoint:(model_params, first_iter) = torch.load(checkpoint)gaussians.restore(model_params, opt)# 設置背景顏色
bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")# 初始化CUDA事件
iter_start = torch.cuda.Event(enable_timing = True)
iter_end = torch.cuda.Event(enable_timing = True)# 是否使用 sparse adam 加速器
use_sparse_adam = opt.optimizer_type == "sparse_adam" and SPARSE_ADAM_AVAILABLE 
# Get depth L1 weight scheduling function
depth_l1_weight = get_expon_lr_func(opt.depth_l1_weight_init, opt.depth_l1_weight_final, max_steps=opt.iterations)# Initialize viewpoint stack and indices
viewpoint_stack = scene.getTrainCameras().copy()
viewpoint_indices = list(range(len(viewpoint_stack)))
# Initialize exponential moving averages for logging
ema_loss_for_log = 0.0
ema_Ll1depth_for_log = 0.0# 初始化進度條
progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")

迭代訓練

從大約73行開始, 進行迭代訓練

first_iter += 1
for iteration in range(first_iter, opt.iterations + 1):

對外連工具展示渲染結果

    # 這部分處理網絡連接, 對外展示當前訓練的渲染結果if network_gui.conn == None:network_gui.try_connect()while network_gui.conn != None:try:net_image_bytes = None# Receive data from GUIcustom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive()if custom_cam != None:# Render image for GUInet_image = render(custom_cam, gaussians, pipe, background, scaling_modifier=scaling_modifer, use_trained_exp=dataset.train_test_exp, separate_sh=SPARSE_ADAM_AVAILABLE)["render"]net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy())# Send image to GUInetwork_gui.send(net_image_bytes, dataset.source_path)if do_training and ((iteration < int(opt.iterations)) or not keep_alive):breakexcept Exception as e:network_gui.conn = None

更新學習率, 選中相機進行渲染

    # 記錄迭代的開始時間iter_start.record()# 更新學習率, 底下都是調用的 get_expon_lr_func(), 一個學習率調度函數, 根據訓練步數計算當前的學習率, 學習率從初始值指數衰減到最終值.gaussians.update_learning_rate(iteration)# 每1000次迭代, 球諧函數(SH, Spherical Harmonics)的階數加1, 直到設置的最大的階數, 默認最大為3, # 每個3D高斯點需要存儲(階數 + 1)^2 個球諧系數, 3階時為16個系數, 每個系數有RGB 3個值所以一共48個值if iteration % 1000 == 0:gaussians.oneupSHdegree()# 當棧為空時, 復制一份訓練幀的相機位姿列表并創建對應的索引列表if not viewpoint_stack:viewpoint_stack = scene.getTrainCameras().copy()viewpoint_indices = list(range(len(viewpoint_stack)))# 從中隨機選取一個相機位姿rand_idx = randint(0, len(viewpoint_indices) - 1)# 從當前棧中彈出, 避免重復選取, 這樣最終會按隨機的順序遍歷完所有的相機位姿viewpoint_cam = viewpoint_stack.pop(rand_idx)vind = viewpoint_indices.pop(rand_idx)# 如果到了開啟debug的迭代次數, 開啟debugif (iteration - 1) == debug_from:pipe.debug = True# 如果設置了隨機背景, 創建隨機背景顏色張量bg = torch.rand((3), device="cuda") if opt.random_background else background# 用當前選中的相機視角, 渲染當前的場景render_pkg = render(viewpoint_cam, gaussians, pipe, bg, use_trained_exp=dataset.train_test_exp, separate_sh=SPARSE_ADAM_AVAILABLE)# 讀出渲染結果image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]# 處理攝像機視角的alpha遮罩(透明度), 將alpha遮罩數據從CPU內存轉移到GPU顯存, 將當前圖像與alpha遮罩進行逐像素相乘, # alpha值為1時保留原像素, alpha值為0時使像素完全透明if viewpoint_cam.alpha_mask is not None:alpha_mask = viewpoint_cam.alpha_mask.cuda()image *= alpha_mask

計算損失

    # 從viewpoint_cam對象中獲取原始圖像數據, 使用.cuda()方法將數據從CPU內存轉移到GPU顯存, # 調用L1損失函數, 計算渲染結果與原圖gt_image之間的像素級絕對差平均值gt_image = viewpoint_cam.original_image.cuda()Ll1 = l1_loss(image, gt_image)# 計算兩個圖像之間的結構相似性指數(SSIM), 如果 fused_ssim 可用則使用 fused_ssim, 否則使用普通的ssim# Calculate SSIM using fused implementation if availableif FUSED_SSIM_AVAILABLE:# 用unsqueeze(0)來增加一個維度,fused_ssim需要批量輸入ssim_value = fused_ssim(image.unsqueeze(0), gt_image.unsqueeze(0))else:ssim_value = ssim(image, gt_image)# 結合L1損失和SSIM損失計算混合損失, (1.0 - ssim_value) 將SSIM相似度轉換為損失值, 因為SSIM值越大損失越小loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim_value)# Depth regularization 深度正則化, 引入單目深度估計作為弱監督信號改善幾何一致性, 緩解漂浮物偽影, 增強遮擋區域的重建效果Ll1depth_pure = 0.0if depth_l1_weight(iteration) > 0 and viewpoint_cam.depth_reliable:# 從渲染結果中獲取逆向深度圖(1/depth)invDepth = render_pkg["depth"]# 獲取單目深度估計的逆向深度圖并轉移到GPUmono_invdepth = viewpoint_cam.invdepthmap.cuda()# 深度有效區域的掩碼(標記可靠區域)depth_mask = viewpoint_cam.depth_mask.cuda()# 計算帶掩碼的L1損失 = 絕對差(渲染深度 - 單目深度) * 掩碼 → 取均值Ll1depth_pure = torch.abs((invDepth  - mono_invdepth) * depth_mask).mean()# 應用動態權重系數(可能隨迭代次數衰減)Ll1depth = depth_l1_weight(iteration) * Ll1depth_pure # 將加權后的深度損失加入總損失loss += Ll1depth# 將Tensor轉換為Python數值用于記錄Ll1depth = Ll1depth.item()else:Ll1depth = 0

反向計算梯度并優化

    # 執行反向傳播算法, 自動計算所有可訓練參數關于loss的梯度loss.backward()# 記錄迭代結束時間iter_end.record()# End iteration timing# torch.no_grad() 臨時關閉梯度計算的上下文管理器with torch.no_grad():# Progress barema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_logema_Ll1depth_for_log = 0.4 * Ll1depth + 0.6 * ema_Ll1depth_for_logif iteration % 10 == 0:progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}", "Depth Loss": f"{ema_Ll1depth_for_log:.{7}f}"})progress_bar.update(10)if iteration == opt.iterations:progress_bar.close()# 輸出日志, 當迭代次數為 testing_iterations 時(默認為7000和30000), 會做一次整體評估, 間隔5取5個樣本, 取一部分相機視角計算L1和SSIM損失, iter_start.elapsed_time(iter_end) 計算耗時training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background, 1., SPARSE_ADAM_AVAILABLE, None, dataset.train_test_exp), dataset.train_test_exp)# 當迭代次數為 saving_iterations(默認為7000和30000)時,保存if (iteration in saving_iterations):print("\n[ITER {}] Saving Gaussians".format(iteration))# 里面會調用 gaussians.save_ply() 保存ply文件scene.save(iteration)# 當迭代次數小于致密化結束的右邊界時if iteration < opt.densify_until_iter:# 可見性半徑更新, 記錄每個高斯點在所有視角下的最大可見半徑, 用于后續剪枝判斷. visibility_filter過濾出當前視角可見的高斯點# Keep track of max radii in image-space for pruninggaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])# 累積視空間位置梯度統計量, 用于后續判斷哪些高斯點需要分裂(高梯度區域)或克隆(高位置變化區域)gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)# 當迭代次數大于致密化開始的左邊界, 并且滿足致密化間隔時, 進行致密化與修剪處理if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:# 如果迭代次數小于不透明度重置間隔(3000)則返回20作為2D尺寸限制, 否則不限制size_threshold = 20 if iteration > opt.opacity_reset_interval else None# 致密化與修剪gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold, radii)# 定期(默認3000一次)重置不透明度, 恢復被錯誤剪枝的高斯點, 調整新生成高斯的可見性, 適配白背景場景的特殊初始化if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):gaussians.reset_opacity()# Optimizer階段, 反向優化模型參數if iteration < opt.iterations:gaussians.exposure_optimizer.step()gaussians.exposure_optimizer.zero_grad(set_to_none = True)if use_sparse_adam:visible = radii > 0gaussians.optimizer.step(visible, radii.shape[0])gaussians.optimizer.zero_grad(set_to_none = True)else:gaussians.optimizer.step()gaussians.optimizer.zero_grad(set_to_none = True)# 到達預設的checkpoint, 默認為7000和30000, 保存當前的訓練進度if (iteration in checkpoint_iterations):print("\n[ITER {}] Saving Checkpoint".format(iteration))torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")

如果有錯誤請留言指出

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/83680.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/83680.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/83680.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【黑馬程序員uniapp】項目配置、請求函數封裝

黑馬程序員前端項目uniapp小兔鮮兒微信小程序項目視頻教程&#xff0c;基于Vue3TsPiniauni-app的最新組合技術棧開發的電商業務全流程_嗶哩嗶哩_bilibili 參考 有代碼&#xff0c;還有app、h5頁面、小程序的演示 小兔鮮兒-vue3ts-uniapp-一套代碼多端部署: 小兔鮮兒-vue3ts-un…

前端使用 preview 插件預覽docx文件

目錄 前言一 引入插件二 JS 處理 前言 前端使用 preview 插件預覽docx文件 一 引入插件 建議下載至本地&#xff0c;靜態引入&#xff0c;核心的文件已打包&#xff08;前端使用 preview 插件預覽docx文件&#xff09;&#xff0c;在文章目錄處下載至本地&#xff0c;復制在項…

如何在運動中保護好半月板?

文章目錄 引言I 半月板的作用穩定作用緩沖作用潤滑作用II 在跳繩運動中保護好半月板III 半月板損傷自測IV 半月板“殺手”半月板損傷必須滿足四個因素:消耗品引言 膝蓋是連接大腿骨和小腿骨的地方,在兩部分骨頭的連接處,墊著兩片半月形的纖維軟骨板,這就是半月板。半月板分…

安科瑞防逆流方案落地內蒙古中高綠能光伏項目,筑牢北疆綠電安全防線

一、項目概況 內蒙古阿拉善中高綠能能源分布式光伏項目&#xff0c;位于內蒙古烏斯太鎮&#xff0c;裝機容量為7MW&#xff0c;采用自發自用、余電不上網模式。 用戶配電站為35kV用戶站&#xff0c;采用兩路電源單母線分段系統。本項目共設置12臺35/0.4kV變壓器&#xff0c;在…

1.3 fs模塊詳解

fs 模塊詳解 Node.js 的 fs 模塊提供了與文件系統交互的能力&#xff0c;是服務器端編程的核心模塊之一。它支持同步、異步&#xff08;回調式&#xff09;和 Promise 三種 API 風格&#xff0c;可滿足不同場景的需求。 1. 模塊引入 const fs require(fs); // 回調…

LeetCode 70 爬樓梯(Java)

爬樓梯問題&#xff1a;動態規劃與斐波那契的巧妙結合 問題描述 假設你正在爬樓梯&#xff0c;需要爬 n 階才能到達樓頂。每次你可以爬 1 或 2 個臺階。求有多少種不同的方法可以爬到樓頂&#xff1f; 示例&#xff1a; n 2 → 輸出 2&#xff08;1階1階 或 2階&#xff0…

【學習分享】shell基礎-參數傳遞

參數傳遞 我們可以在執行 Shell 腳本時&#xff0c;向腳本傳遞參數&#xff0c;腳本內獲取參數的格式為 $n&#xff0c;n 代表一個數字&#xff0c;1 為執行腳本的第一個參數&#xff0c;2 為執行腳本的第二個參數。 例如可以使用 $1、$2 等來引用傳遞給腳本的參數&#xff0…

Fluence推出“Pointless計劃”:五種方式參與RWA算力資產新時代

2025年6月1日&#xff0c;去中心化算力平臺 Fluence 正式宣布啟動“Pointless 計劃”——這是其《Fluence Vision 2026》戰略中四項核心舉措之一&#xff0c;旨在通過貢獻驅動的積分體系&#xff0c;激勵更廣泛的社區參與&#xff0c;為用戶帶來現實世界資產&#xff08;RWA&am…

Excel數據分析:基礎

在現代辦公環境中&#xff0c;Excel 是一款不可或缺的工具&#xff0c;它是 Microsoft&#xff08;微軟&#xff09;開發的電子表格軟件&#xff0c;用于處理和分析結構化數據。市場上還有其他類似的軟件&#xff0c;如 Google Sheets 和 Apple Numbers&#xff0c;但 Excel 以…

12V降5V12A大功率WD5030A,充電器、便攜式設備、網絡及工業領域的理想選擇

WD5030A 高效單片同步降壓型直流 / 直流轉換器 一、芯片核心概述 WD5030A 是一款高性能同步降壓型 DC/DC 轉換器&#xff0c;采用 平均電流模式控制架構&#xff08;帶頻率抖動功能&#xff09;&#xff0c;具備以下核心優勢&#xff1a; 精準電流控制&#xff1a;快速響應負…

企業級AI邁入黃金時代,企業該如何向AI“蝶變”?

科技云報到原創。 近日&#xff0c;微軟&#xff08;MSFT.US&#xff09;在最新全員大會上高調展示企業級AI業務進展&#xff0c;其中與巴克萊銀行達成的10萬份Copilot許可證交易成為焦點。 微軟首席商務官賈德森阿爾索夫在會上披露&#xff0c;這家英國金融巨頭已簽約采購相…

Java編程課(一)

Java編程課 一、java簡介二、Java基礎語法2.1 環境搭建2.2 使用Intellij IDEA新建java項目2.3 Java運行介紹2.4 參數說明2.5 Java基礎語法2.6 注釋2.7 變量和常量一、java簡介 Java是一種廣泛使用的高級編程語言,最初由Sun Microsystems于1995年發布。它被設計為具有簡單、可…

【Java Web】速通Tomcat

參考筆記:JavaWeb 速通Tomcat_tomcat部署java項目-CSDN博客 目錄 一、Tomcat服務 1. 下載和安裝 2. 啟動Tomcat服務 3. 啟動Tomcat服務的注意事項 4. 關閉Tomcat服務 二、Tomcat的目錄結構 1. bin ?? 2. conf ?? 3. lib 4. logs 5. temp 6. webapps 7. work 三、Web項目…

Mysql 身份認證繞過漏洞 CVE-2012-2122

前言&#xff1a;CVE-2012-2122 是一個影響 MySQL 和 MariaDB 的身份驗證漏洞&#xff0c;存在于特定版本中 vulhub/mysql/CVE-2012-2122/README.zh-cn.md at master vulhub/vulhubhttps://github.com/vulhub/vulhub/blob/master/mysql/CVE-2012-2122/README.zh-cn.md 任務一…

Win10停更,Win11不好用?現在Mac電腦比Win11電腦更便宜

最近不少朋友在換電腦前都犯了難。 以前大家最常說的一句是&#xff1a;“Mac太貴了&#xff0c;還是買Windows吧。”但現在不一樣了&#xff0c;國補教育優惠下來&#xff0c;新款M4芯片的Mac mini的入門價已經降到了3000元左右&#xff0c;曾經的價格壁壘&#xff0c;已經不…

C#中Struct與IntPtr轉換:實用擴展方法

C#中Struct與IntPtr轉換&#xff1a;實用擴展方法 在 C# 編程的世界里&#xff0c;我們常常會遇到需要與非托管代碼交互&#xff0c;或者進行一些底層內存操作的場景。這時&#xff0c;IntPtr類型就顯得尤為重要&#xff0c;它可以表示一個指針或句柄&#xff0c;用來指向非托…

手機歸屬地查詢接口如何用Java調用?

一、什么是手機歸屬地查詢接口&#xff1f; 是一種便捷、高效的工具&#xff0c;操作簡單&#xff0c;請求速度快。它不僅能夠提高用戶填寫地址的效率&#xff0c;還能幫助企業更好地了解客戶需求&#xff0c;制定個性化的營銷策略&#xff0c;降低風險。隨著移動互聯網的發展…

43、視圖解析-Thymeleaf初體驗

43、視圖解析-Thymeleaf初體驗 “43、視圖解析-Thymeleaf初體驗”通常是指在學習Spring Boot框架時&#xff0c;關于如何使用Thymeleaf模板引擎進行視圖解析的入門課程或章節。以下是對該主題的詳細介紹&#xff1a; #### Thymeleaf簡介 - **定義**&#xff1a;Thymeleaf是一個…

Day 40訓練

Day 40 訓練 PyTorch 圖像數據訓練與測試的規范寫法單通道圖像的規范訓練流程數據預處理與加載模型定義訓練與測試函數封裝模型訓練執行 彩色圖像的擴展應用數據預處理調整模型結構調整 關鍵要點總結 知識點回顧&#xff1a; 彩色和灰度圖片測試和訓練的規范寫法&#xff1a;封…

杰理可視化SDK--系統死機異常調試

杰理可視化SDK--系統死機異常調試 系統異常原因杰理SDK異常調試準備工作杰理SDK系統異常定位異常代碼示例1異常代碼示例2 在使用杰理可視化SDK進行軟件開發時&#xff0c;往往會遇到一些系統異常問題&#xff0c;系統異常是指芯片在運行代碼時&#xff0c;由于軟件或硬件狀態出…