默認的 分辨率是 [160,240] ,基于 Transformer 的方法不能做高分辨率。
Dataloader
輸入是 帶有 pose 信息的 RGB 圖像
eval datasets
## 采樣幀數目 = 20
num_max_future_frames = int(self.timespan * fps)
## 每次間隔多少個時間 timesteps 取一個context image
num_context_timesteps = 4
按照STORM
原來的 setting, future_frames = 20
context_image
每次間隔4幀,所以是 context_frame_idx = [0,5,10,15]
, 在 target_frame
包含了 從[0,20]的所有20幀。
以這樣20 幀的 image 作為一個基本的 batch, 進行預測: 進入 model
所以,輸入網絡的 context_image
對應的 shape (1,4,3,160,240)
輸入4個時刻幀的 frame, 每一個 frame 有 3個相機;對應的 context_camtoworlds
shape (1,4,3,4,4)
train datasets
第一幀 ID
隨機采樣, 之后的 context_image
每次間隔 5 幀,比如: [47 52 57 62]
: target·_frame_id
也是進行隨機選取:
if self.equispaced:context_frame_idx = np.arange(context_frame_idx,context_frame_idx + num_max_future_frames,num_max_future_frames // self.num_context_timesteps,)
隨機在 num_future_id 里面 選擇 self.num_target_timesteps
選擇 4幀作為 target_image
的監督幀
Network
輸入網絡的 有3個 input: context_image, ray 和 time 的信息
- context_image: (1,4,3,3,160,240)
- Ray embedding (1,4,3,6,160,240)
- time_embedding (1,4,3)
- 將 image 和 ray_embedding 進行
concat
操作, 得到 x:(12,9,160,240):
x = rearrange(x, "b t v c h w -> (b t v) c h w")
plucker_embeds = rearrange(plucker_embeds, "b t v h w c-> (b t v) c h w")
x = torch.cat([x, plucker_embeds], dim=1) ## (12,9,160,240)
然后經過3個 embedding , 將這些 feature 映射成為 token:
x = self.patch_embed(x) # (b t v) h w c2x = self._pos_embed(x) # (b t v) (h w) c2x = self._time_embed(x, time, num_views=v)
得到 x.shape (12,600,768)
, 表示一共有12張圖像,每個圖象 是 600 個 token, 每個 token 的 channel 是768. 然后將這些 token concat 在一起 得到了 (7200,768) 的 feature;
給得到的 token 分別加上可學習的 motion_token, affine_token 和 sky_token. 連接方式都是 concat
這樣得到的 feature 為 (7220,768)的 feature
if self.num_motion_tokens > 0:motion_tokens = repeat(self.motion_tokens, "1 k d -> b k d", b=x.shape[0])x = torch.cat([motion_tokens, x], dim=-2)
if self.use_affine_token:affine_token = repeat(self.affine_token, "1 k d -> b k d", b=b)x = torch.cat([affine_token, x], dim=-2)
if self.use_sky_token:sky_token = repeat(self.sky_token, "1 1 d -> b 1 d", b=x.shape[0])x = torch.cat([sky_token, x], dim=-2)
- 使用 Transformer 進行學習, 得到的 feature 維度不變。:
x = self.transformer(x)x = self.norm(x) ## shape(7220,768)
運行完之后,可以將學習到的 token
提取出來:
if self.use_sky_token:sky_token = x[:, :1] ## (1,1,768)x = x[:, 1:]if self.use_affine_token:affine_tokens = x[:, : self.num_cams] ## (1,3,768)x = x[:, self.num_cams :]if self.num_motion_tokens > 0:motion_tokens = x[:, : self.num_motion_tokens] ## (1,16,768)x = x[:, self.num_motion_tokens :]
在 Transformer 內部,沒有上采樣層,也可以實現 這種 per-pixel feature 的學習。
對于 x 進行 GS 的預測,得到 pixel_align 的高斯。 對于每個 patch, 得到的 feature 是 (12,600,768)
, 通過一個CNN,雖然通道數沒有變 (12,600,768)
, 但是之前 768 可以理解為全局的 語義, 之后的 768 為 一個patch 內部不同像素的語義,他們共享著 全局的 語義信息,但是每個pixel 卻又不一樣。 通過下面的 unpatchify
函數將將一個patch 的語義拆成 per-pixel 的語義,將每個768維token
展開為8×8
像素。
b, t, v, h, w, _ = origins.shape## x_shape: (12,600,768)x = rearrange(x, "b (t v hw) c -> (b t v) hw c", t=t, v=v)## gs_params_shape: (12,600,768),這一步雖然通道沒變,但其實是將一個 token 的全局 語義,映射成## token 內部的像素級別的語義gs_params = self.gs_pred(x)## gs_params_shape: (12,12,160,240)### 關鍵步驟:unpatchify將每個768維token展開為8×8像素gs_params = self.unpatchify(gs_params, hw=(h, w), patch_size=self.unpatch_size)
根據 token
展開的 per-pixel feature, 進行3DGS 的屬性預測
gs_params = rearrange(gs_params, "(b t v) c h w -> b t v h w c", t=t, v=v)
depth, scales, quats, opacitys, colors = gs_params.split([1, 3, 4, 1, self.gs_dim], dim=-1)
scales = self.scale_act_fn(scales)
opacitys = self.opacity_act_fn(opacitys)
depths = self.depth_act_fn(depth)
colors = self.rgb_act_fn(colors)
means = origins + directions * depths
除了3DGS 的一半屬性之外, storm 還額外預測了其他的運動屬性,包括:
其中: x: (1,7200,768)
代表 image_token
, motion_tokens 是(1,16,768)
代表 motion_token. 處理的大致思路是 motion_token
作為 query,
然后 image_token
映射的feature 作為 key,
去結合計算每一個 高斯的 moition_weights
和 moition_bases
gs_params = self.forward_motion_predictor(x, motion_tokens, gs_params)
其中:
forward_flow = torch.einsum("b t v h w k, b k c -> b t v h w c", motion_weights, motion_bases)
moition_bases: shape: [1,16,3]
moition_weights: shape: [1,4,3,160,240,16]
forward_flow: shape: [1,4,3,160,240,3]:
是 weights 和bases 結合的結果
GS_param Rendering
- 取出高斯的各項屬性,尤其是 means 和 速度
forward_v
: STORM 假設 在這 20幀是出于勻速直線運動
, 其速度時不變的,可能并不合理。我們的方法直接預測 BBX,可能更為準確。
means = rearrange(gs_params["means"], "b t v h w c -> b (t v h w) c")
scales = rearrange(gs_params["scales"], "b t v h w c -> b (t v h w) c")
quats = rearrange(gs_params["quats"], "b t v h w c -> b (t v h w) c")
opacities = rearrange(gs_params["opacities"], "b t v h w -> b (t v h w)")
colors = rearrange(gs_params["colors"], "b t v h w c -> b (t v h w) c")
forward_v = rearrange(gs_params["forward_flow"], "b t v h w c -> b (t v h w) c")
這里得到的 高斯的 mean
是全部由 context_image
得到的, shape (46800,3)
, 但這其實是 4個 時刻context_frame_idx = [0,5,10,15]
, 得到的高斯,并不處于同一時間刻度。
通過比較 target_time
和 context_time
之間的插值,去得到每一個 target_time
的 3D Gaussian 的坐標means_batched
:
if tgt_time.ndim == 3:tdiff_forward = tgt_time.unsqueeze(2) - ctx_time.unsqueeze(1)tdiff_forward = tdiff_forward.view(b * tgt_t, t * v, 1)tdiff_forward_batched = tdiff_forward.repeat_interleave(h * w, dim=1)else:tdiff_forward = tgt_time.unsqueeze(-1) - ctx_time.unsqueeze(-2)tdiff_forward = tdiff_forward.view(b * tgt_t, t, 1)tdiff_forward_batched = tdiff_forward.repeat_interleave(v * h * w, dim=1)forward_translation = forward_v_batched * tdiff_forward_batchedmeans_batched = means_batched + forward_translation ## (20,460800,3)
使用 gsplat
的 batch_rasterization
函數:
rendered_color, rendered_alpha, _ = rasterization(means=means_batched.float(), ## (20,460800,3)quats=quats_batched.float(),scales=scales_batched.float(),opacities=opacities_batched.float(),colors=colors_batched.float(),viewmats=viewmats_batched, ## (20,3,4,4)Ks=Ks_batched, ## (20,3,3,3)width=tgt_w,height=tgt_h,render_mode="RGB+ED",near_plane=self.near,far_plane=self.far,packed=False,radius_clip=radius_clip,)
bug 記錄:
當使用單個相機的時候,下面這段代碼會把 維度搞錯:
if self.use_affine_token:affine = self.affine_linear(affine_tokens) # b v (gs_dim * (gs_dim + 1))affine = rearrange(affine, "b v (p q) -> b v p q", p=self.gs_dim)images = torch.einsum("b t v h w p, b v p q -> b t v h w p", images, affine)gs_params["affine"] = affine