總體
https://github.com/nerfstudio-project/gsplat
simple_trainer_mcmc.py
2個關鍵點:
- 高斯狀態轉移(每100iter調用)
- 高斯隨機過程(每1iter調用)
relocate_gs
- 對 alive gs 進行采樣,被采樣的 alive gs 將作為 dead gs 的轉移目標。
- 對被采樣的 alive gs 進行狀態更新,opacities和scales屬性會重新計算。
- 對 dead gs 進行狀態轉移。
add_new_gs
- 對 all gs 進行采樣
- 被采樣的 gs 進行狀態更新,opacities和scales屬性會重新計算
- 再把被采樣的 gs 作為 copy,添加到所有 gs 中。
add_noise_to_gs
- 根據 學習率 和 opacities 控制噪聲的大小
- 根據 quats 和 scales 控制噪聲的分布
- 得到 delt_xyz 噪聲
- 添加到 gs 的 xyz 屬性上
代碼AI解讀
relocate_gs
(add_new_gs 類似)
@torch.no_grad()def relocate_gs(self, min_opacity: float = 0.005) -> int:dead_mask = torch.sigmoid(self.splats["opacities"]) <= min_opacitydead_indices = dead_mask.nonzero(as_tuple=True)[0]alive_indices = (~dead_mask).nonzero(as_tuple=True)[0]num_gs = len(dead_indices)if num_gs <= 0:return num_gs# Sample for new GSseps = torch.finfo(torch.float32).epsprobs = torch.sigmoid(self.splats["opacities"])[alive_indices]probs = probs / (probs.sum() + eps)sampled_idxs = torch.multinomial(probs, num_gs, replacement=True) # 進行多項式采樣,num_gs 是要重新定位的粒子數量,replacement=True 表示允許重復采樣。sampled_idxs = alive_indices[sampled_idxs]new_opacities, new_scales = compute_relocation(opacities=torch.sigmoid(self.splats["opacities"])[sampled_idxs],scales=torch.exp(self.splats["scales"])[sampled_idxs],ratios=torch.bincount(sampled_idxs)[sampled_idxs] + 1, # torch.bincount: 這個函數計算輸入張量中每個整數值的出現次數。對于 sampled_idxs,torch.bincount 的輸出將是一個包含每個索引出現次數的張量。例如,對于 sampled_idxs = [2, 1, 2, 3, 1],torch.bincount(sampled_idxs) 的輸出將是 [0, 2, 2, 1]。這里,0 表示索引 0 沒有出現,2 表示索引 1 出現了兩次,2 表示索引 2 出現了兩次,1 表示索引 3 出現了一次。)new_opacities = torch.clamp(new_opacities, max=1.0 - eps, min=min_opacity)self.splats["opacities"][sampled_idxs] = torch.logit(new_opacities)self.splats["scales"][sampled_idxs] = torch.log(new_scales)# Update splats and optimizersfor k in self.splats.keys():self.splats[k][dead_indices] = self.splats[k][sampled_idxs]for optimizer in self.optimizers:for i, param_group in enumerate(optimizer.param_groups):p = param_group["params"][0]name = param_group["name"]p_state = optimizer.state[p]del optimizer.state[p]for key in p_state.keys():if key != "step":p_state[key][sampled_idxs] = 0p_new = torch.nn.Parameter(self.splats[name])optimizer.param_groups[i]["params"] = [p_new]optimizer.state[p_new] = p_stateself.splats[name] = p_newtorch.cuda.empty_cache()return num_gs
compute_relocation
// Equation (9) in "3D Gaussian Splatting as Markov Chain Monte Carlo"
__global__ void compute_relocation_kernel(int N, float *opacities, float *scales,int *ratios, float *binoms, int n_max,float *new_opacities, float *new_scales) {int idx = threadIdx.x + blockIdx.x * blockDim.x;if (idx >= N)return;int n_idx = ratios[idx];float denom_sum = 0.0f;// compute new opacitynew_opacities[idx] = 1.0f - powf(1.0f - opacities[idx], 1.0f / n_idx);// compute new scalefor (int i = 1; i <= n_idx; ++i) {for (int k = 0; k <= (i - 1); ++k) {float bin_coeff = binoms[(i - 1) * n_max + k];float term = (pow(-1.0f, k) / sqrt(static_cast<float>(k + 1))) *pow(new_opacities[idx], k + 1);denom_sum += (bin_coeff * term);}}float coeff = (opacities[idx] / denom_sum);for (int i = 0; i < 3; ++i)new_scales[idx * 3 + i] = coeff * scales[idx * 3 + i];
}
計算新的透明度(Opacity):
使用公式 new_opacities[idx]=1.0?(1.0?opacities[idx])1.0/n_idx\text{new\_opacities}[idx] = 1.0 - (1.0 - \text{opacities}[idx])^{1.0 / n\_idx}new_opacities[idx]=1.0?(1.0?opacities[idx])1.0/n_idx 來計算新的透明度。這個公式是基于論文中的公式 (9) 推導出來的。
計算新的尺度(Scale):
通過一個嵌套的循環來計算新的尺度。這個過程涉及到二項式系數(
binoms
)和一些數學運算,包括冪運算和平方根運算。具體來說,內核函數計算了一個系數
coeff
,然后用這個系數來調整原始的尺度值,得到新的尺度值。
add_noise_to_gs
@torch.no_grad()def add_noise_to_gs(self, last_lr):opacities = torch.sigmoid(self.splats["opacities"])scales = torch.exp(self.splats["scales"])actual_covariance, _ = quat_scale_to_covar_preci(self.splats["quats"],scales,compute_covar=True,compute_preci=False,triu=False,)def op_sigmoid(x, k=100, x0=0.995):return 1 / (1 + torch.exp(-k * (x - x0)))noise = (torch.randn_like(self.splats["means3d"])* (op_sigmoid(1 - opacities)).unsqueeze(-1)* cfg.noise_lr* last_lr)noise = torch.bmm(actual_covariance, noise.unsqueeze(-1)).squeeze(-1)self.splats["means3d"].add_(noise) # 只改變xyz