Rigid Node: 表示 car
或者trucks
Deformable Node : 表示一些 分布之外的 non-rigid 的運動物體, 比如遠處的行人等和Cyclist。
在 load_objects
會讀取每一個 dynamic objects 的 'bounding box’的信息,具體如下:
frame_instances
記錄了每一幀都有哪些 instance, 以及對應 每一幀 其 位姿信息等;
instances_info
包含每一幀 對于哪些 instance 是可見的。
1. 讀取 Bounding Box 的基本信息
邏輯是先遍歷場景的 instance, 然后 再每一個 instance 的信息。
## 存放每一幀 的instance 的pose
instances_pose = np.zeros((num_full_frames, num_instances, 4, 4))for k, v in instances_info.items():instances_model_types[int(k)] = OBJECT_CLASS_NODE_MAPPING[v["class_name"]]for frame_idx, obj_to_world, box_size in zip(v["frame_annotations"]["frame_idx"], v["frame_annotations"]["obj_to_world"], v["frame_annotations"]["box_size"]):# the first ego pose as the origin of the world coordinate system.obj_to_world = np.array(obj_to_world).reshape(4, 4)obj_to_world = np.linalg.inv(ego_to_world_start) @ obj_to_worldinstances_pose[frame_idx, int(k)] = np.array(obj_to_world)instances_size[frame_idx, int(k)] = np.array(box_size)
根據 per_frame_instance_mask
來得到 每一幀對于哪些instance 是可見的。
per_frame_instance_mask = np.zeros((num_full_frames, num_instances))for frame_idx, valid_instances in frame_instances.items():per_frame_instance_mask[int(frame_idx), valid_instances] = 1
2. 使用Bounding Box 的信息初始化高斯
這里需要介紹一個 dynamic vehicle 非常重要的坐標系,物體坐標系(Object系),
其通常位于汽車的車輛中心
。所以任何一幀的 Lidar 通過w2o
矩陣可以將Lidar 點轉換到 canonical space, 完成對于多幀 Lidar 的聚集
將 Lidar 點 (世界坐標系下面的) 通過 轉化矩陣 w2o
轉換到 Object 坐標系下面 , 然后 根據 Bounding Box 的 Size 去保留 在 BBX 內部的點云,準備進行初始化。
o2w = self.pixel_source.instances_pose[fi, ins_id]o_size = self.pixel_source.instances_size[ins_id]# convert the lidar points to the instance's coordinate systemw2o = torch.inverse(o2w)o_pts = transform_points(lidar_pts, w2o)# 將BBX 之外的點通過 Mask 濾除,這一步是在局部 Object 坐標系下面進行的mask = ((o_pts[:, 0] > -o_size[0] / 2)& (o_pts[:, 0] < o_size[0] / 2)& (o_pts[:, 1] > -o_size[1] / 2)& (o_pts[:, 1] < o_size[1] / 2)& (o_pts[:, 2] > -o_size[2] / 2)& (o_pts[:, 2] < o_size[2] / 2))valid_pts = o_pts[mask]valid_colors = self.lidar_source.colors[lidar_dict["lidar_mask"]][mask]
通過 比較 在 instances_pose
的pose (O2W系) 移動,僅僅對于 動態的 instance 進行保留 。
因為車輛的移動其實可以看成是
O2W
坐標系的移動。 相當于車輛是靜止的,但是環境是運動的
if only_moving:# consider only the instances with non-zero flowslogger.info(f"Filtering out the instances with non-moving trajectories")new_instance_dict = {}for k, v in instance_dict.items():if v["num_pts"] > 0: ## 僅僅考慮有點的 instance# flows = v["flows"]# if flows.norm(dim=-1).mean() > moving_thres:# v.pop("flows")# new_instance_dict[k] = v# logger.info(f"Instance {k} has {v['num_pts']} lidar sample points")frame_info = self.pixel_source.per_frame_instance_mask[:, k]instances_pose = self.pixel_source.instances_pose[:, k]instances_trans = instances_pose[:, :3, 3]valid_trans = instances_trans[frame_info]traj_length = valid_trans[1:] - valid_trans[:-1]traj_length = torch.norm(traj_length, dim=-1).sum()if traj_length > traj_length_thres:new_instance_dict[k] = vlogger.info(f"Instance {k} has {v['num_pts']} lidar sample points")instance_dict = new_instance_dict
將所有幀的 Lidar Aggregated 到 Canonical Space 下面,如圖所示:
靜態高斯的初始化
靜態的 高斯初始化 = Lidar_samples + 半球內的隨機采樣點。 隨機采樣點是 PVG
這篇文章所介紹的, 在 球內部 和 球外面進行均勻采樣。
Rigid 高斯的初始化
從 Canonical Space
累計的 點云進行 高斯的各項屬性的初始化, 讀取 點云的 坐標和顏色,然后進行初始化。 并記錄了 每個bbx 的大小以及 每個instance 在每一幀的可見性,分別用 self.instances_size
和 self.instances_fv
表示。
## (num_instances, 3) BBX 的大小self.instances_size = torch.stack(instances_size).to(self.device) # # (num_frame, num_instances) instance 在每一幀的可見性
self.instances_fv = torch.cat(instances_fv, dim=1).to(self.device)
值得注意的是, Drivestudio 將每一幀的每一個 instance 的 BBX 的 的 pose 也作為參數去考慮優化:
# (num_frame, num_instances, 4) 四元數self.instances_quats = Parameter(self.quat_act(instances_quats))# (num_frame, num_instances, 3) 平移
self.instances_trans = Parameter(instances_trans)
高斯參數的優化器設置:
所有的 Rigid Nodes 會把放進一個 優化字典當中,然后一起優化,并不是每個 instance 去獨立的優化。
Rigid 的每一個GS 都是像原始的 3DGS 一樣,配置 每一個屬性的 學習率去進行優化的。
groups.append({'params': params,'name': params_name,'lr': optim_cfg.lr,'eps': optim_cfg.eps,'weight_decay': optim_cfg.weight_decay})
groups 構建好之后,全部一起當作字典丟進 Adam 優化器去進行優化
self.optimizer = torch.optim.Adam(groups, lr=0.0, eps=1e-15)
Sky Model
Drivestudio 使用場景的 Environment map
來對于 天空的顏色進行建模. Sky 被建模成一個 長方體 cube, 然后使用基于光線方向(Opengl系)來在 environment cube 上進行紋理查詢。這個 environment map 雖然沒有任何網絡,但是其本身的參數也是需要被優化的。 對應的 Code 如下
class EnvLight(torch.nn.Module):def __init__(self,class_name: str,resolution=1024,device: torch.device = torch.device("cuda"),**kwargs):super().__init__()self.class_prefix = class_name + "#"self.device = deviceself.to_opengl = torch.tensor([[1, 0, 0], [0, 0, 1], [0, -1, 0]], dtype=torch.float32, device="cuda")## 需要被優化的 environment mapself.base = torch.nn.Parameter(0.5 * torch.ones(6, resolution, resolution, 3, requires_grad=True),)def forward(self, image_infos):l = image_infos["viewdirs"]l = (l.reshape(-1, 3) @ self.to_opengl.T).reshape(*l.shape)l = l.contiguous()prefix = l.shape[:-1]if len(prefix) != 3: # reshape to [B, H, W, -1]l = l.reshape(1, 1, -1, l.shape[-1])light = dr.texture(self.base[None, ...], l, filter_mode='linear', boundary_mode='cube')light = light.view(*prefix, -1)return light
開始訓練:
針對每一個 Node 提取出場景的N 個動態對象高斯。 如果是 Rigid 物體的高斯,前面的代碼是采用 Object
系存儲的,需要轉換到 World
系,然后提取出來。
以 平移變化來分析:
首先我們有 frame_id
標記 我們訓練的是哪一幀,取出這一幀的所有 instance 對應的 旋轉和 rot_cur_frame
平移trans_cur_frame
. 假設我們有M個動態點,將這M個動態點 和 應用在M個旋轉和平移向量上,同時得到了這個所有動態類別在場景frame_id
對應的位置和坐標。
def transform_means(self, means: torch.Tensor) -> torch.Tensor:"""transform the means of instances to world spaceaccording to the pose at the current frame"""assert means.shape[0] == self.point_ids.shape[0], \"its a bug here, we need to pass the mask for points_ids"quats_cur_frame = self.instances_quats[self.cur_frame] # (num_instances, 4)rot_cur_frame = quat_to_rotmat(self.quat_act(quats_cur_frame)) # (num_instances, 3, 3)## 求出每個點的旋轉rot_per_pts = rot_cur_frame[self.point_ids[..., 0]] # (num_points, 3, 3)trans_cur_frame = self.instances_trans[self.cur_frame] # (num_instances, 3)## 求出每個點的平移trans_per_pts = trans_cur_frame[self.point_ids[..., 0]]# transform the means to world spacemeans = torch.bmm(rot_per_pts, means.unsqueeze(-1)).squeeze(-1) + trans_per_ptsreturn means
之后使用 gsplat
作為渲染的框架,執行渲染, 這里的動態和靜態實際上都是轉換到 世界系的 高斯 然后一起渲染的。 為了渲染 動態物體,將場景高斯的 動態物體的 Opacity
設置為0, 其他的屬性不用改變。