目錄
推理代碼:
EnvLight 代碼:
推理代碼:
sky_model = self.models["Sky"]outputs["rgb_sky"] = sky_model(image_info)outputs["rgb_sky_blend"] = outputs["rgb_sky"] * (1.0 - outputs["opacity"])
EnvLight 代碼:
import torch# 定義環境光類(EnvLight),繼承自 torch.nn.Module
class EnvLight(torch.nn.Module):def __init__(self, class_name: str, resolution: int = 1024, device: torch.device = torch.device("cuda"), **kwargs):# 初始化函數,接收類名、分辨率、設備(默認 GPU)以及其他關鍵字參數super().__init__()# 設置類的前綴,方便后續參數管理self.class_prefix = class_name + "#"# 設置設備(默認為 GPU)self.device = device# 定義 OpenGL 轉換矩陣,將世界坐標系轉換為 OpenGL 坐標系# 該矩陣的作用是轉換方向向量self.to_opengl = torch.tensor([[1, 0, 0], [0, 0, 1], [0, -1, 0]], dtype=torch.float32, device="cuda")# 定義基礎光照參數:初始化為一個 6 x resolution x resolution 的全 0.5 張量,# 每個光照樣本有 3 個值(RGB)。該參數是可訓練的(requires_grad=True)self.base = torch.nn.Parameter(0.5 * torch.ones(6, resolution, resolution, 3, requires_grad=True),)def forward(self, image_info: ImageInfo):# 前向傳播函數,接受一個 ImageInfo 類型的輸入(包含射線信息)# 獲取傳入圖像信息中的方向向量(viewdirs),表示視角方向directions = image_info.rays.viewdirs# 將方向向量從世界坐標系轉換到 OpenGL 坐標系directions = (directions.reshape(-1, 3) @ self.to_opengl.T).reshape(*directions.shape)# 重新調整方向向量的內存布局為連續的,以便后續操作directions = directions.contiguous()# 獲取方向向量的前綴尺寸,用于后續的形狀調整prefix = directions.shape[:-1]# 如果前綴尺寸不是三維(即 [B, H, W]),則將方向向量重塑為 [1, 1, -1, 3]# 目的是將其轉換為適合批量處理的形狀if len(prefix) != 3: # reshape to [B, H, W, -1]directions = directions.reshape(1, 1, -1, directions.shape[-1])# 使用 dr.texture 函數計算光照(dr 是某個光照計算庫)# `self.base[None, ...]` 代表基礎光照紋理,`directions` 是輸入的方向向量# `filter_mode="linear"` 表示紋理的過濾模式,`boundary_mode="cube"` 表示紋理的邊界模式light = dr.texture(self.base[None, ...], directions, filter_mode="linear", boundary_mode="cube")# 將輸出的光照結果 reshaped 為適合的形狀light = light.view(*prefix, -1)return lightdef get_param_groups(self):# 獲取模型參數分組,返回一個字典# 這里我們將所有參數歸為一個組,鍵為 "class_name + all"return {self.class_prefix + "all": self.parameters(),}