視頻序列和射頻信號多模態融合算法Fusion-Vital解讀
- 概述
- 模型整體流程
- 視頻幀時間差分歸一化
- TSM模塊
- 視頻序列特征融合模塊
- 跨模態特征融合模塊
概述
最近看了Fusion-Vital的視頻-射頻(RGB-RF)融合Transformer模型。記錄一下,對于實際項目中的多模態數據融合有一定參考價值。原始論文,參考實現源碼。
具體來說,Fusion-Vital模型首先將多幀視頻RGB圖像投影到一個共享的時間差分域中,以有效捕捉微小的生理信號,同時避免全局運動的干擾。
對于RF射頻模態,利用多普勒特性,通過短時傅里葉變換(STFT)生成時間-頻率圖像,作為時間差分域的替代指標。
模型采用并行編碼分支,分別處理RGB和RF數據,并引入多級特征融合模塊,利用交叉注意力機制在時間差分域中對齊和融合兩種模態的特征。
模型整體流程
整個模型有以下模塊組成:
- 時序移動模塊(TSM)
- RGB通路(運動分支 + 外觀分支)
- RF通路(射頻分支)
- 注意力機制(Attention)
- 兩次跨模態交互塊(CrossAttentionModule)
- 池化與Dropout
- 最終MLP輸出
輸入輸出
- 輸入:
rgb_input: 形狀為 (B, C=3, T, H, W),視頻幀序列。
rf_input: 形狀為 (B, C=4, T, F),射頻信號時頻特征。 - 輸出:
bvp :形狀為 (B, T),這里為脈搏波形。
視頻幀時間差分歸一化
- 連續的視頻幀時間差分歸一化處理。
- 維度變換?:
輸入:(B, C, T, H, W)
輸出:(B, C, T, H, W)
def diff_normalize_data(x):"""Calculate discrete difference in video data along the time-axis and nornamize by its standard deviation."""B, C, T, H, W = x.shape# denominatordenominator = torch.ones((B, C, T, H, W), dtype=torch.float32, device=x.device)for j in range(T - 1):denominator[:, :, j, :, :] = x[:, :, j + 1, :, :] + x[:, :, j, :, :] + 1e-7x_diff = torch.cat([torch.zeros((B, C, 1, H, W), device=x.device), x.diff(dim=2)], dim=2) / denominatorx_diff = x_diff / x_diff.view(B, -1).std(dim=1)[:, None, None, None, None]x_diff[torch.isnan(x_diff)] = 0return x_diff
TSM模塊
- 通道分割?:將特征通道分為3部分:
前1/3:向前時序移位(用下一幀的特征替換當前幀)
中1/3:向后時序移位(用上一幀的特征替換當前幀)
后1/3:保持不變 - 維度變換?:
輸入:(B×T, C, H, W)
class TSM(nn.Module):def __init__(self, n_segment=32, fold_div=3):super(TSM, self).__init__()self.n_segment = n_segmentself.fold_div = fold_divdef forward(self, x):nt, c, h, w = x.size()n_batch = nt // self.n_segmentx = x.view(n_batch, self.n_segment, c, h, w)fold = c // self.fold_divout = torch.zeros_like(x)out[:, :-1, :fold] = x[:, 1:, :fold] # shift leftout[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold] # shift rightout[:, :, 2 * fold:] = x[:, :, 2 * fold:] # not shiftreturn out.view(nt, c, h, w)
視頻序列特征融合模塊
- 運動分支:對輸入的時序差分特征(diff_input)做兩次卷積和兩次TSM,得到 d2。
外觀分支:對原始輸入做兩次卷積,得到 r2。
注意力分支:對外觀分支 r2 做 1x1 卷積和 sigmoid,得到注意力權重 g1。
融合:d2 * g1,得到融合后的 rgb1 - 輸入:diff_input, raw_input → (BT, C, H, W)
輸出:rgb1 → (BT, nb_filters1=32, H, W)
# 第一次和第二次卷積
diff_input = self.TSM_1(diff_input)
d1 = torch.tanh(self.motion_conv1(diff_input))
d1 = self.TSM_2(d1)
d2 = torch.tanh(self.motion_conv2(d1))
r1 = torch.tanh(self.apperance_conv1(raw_input))
r2 = torch.tanh(self.apperance_conv2(r1))
g1 = torch.sigmoid(self.apperance_att_conv1(r2))
g1 = self.attn_mask_1(g1)
rgb1 = d2 * g1 # 第一次輸入Fusion Block的rgb1
- 注意力掩碼attn_mask的計算:
- 空間求和?:
首先沿高度(H)維度求和,保持維度(keepdim=True)
然后沿寬度(W)維度求和,得到每個空間位置的總和 - 歸一化處理?:
將輸入特征圖除以其空間總和,實現初步歸一化
乘以高度和寬度恢復數值范圍
乘以0.5的縮放因子
def forward(self, x):# 計算空間維度的總和 (高度和寬度)xsum = torch.sum(x, dim=2, keepdim=True) # 沿高度維度(H)求和xsum = torch.sum(xsum, dim=3, keepdim=True) # 沿寬度維度(W)求和# 獲取輸入形狀xshape = tuple(x.size())# 計算注意力掩碼并應用return x / xsum * xshape[2] * xshape[3] * 0.5
跨模態特征融合模塊
交叉注意力融合
- 展平空間/頻率維度,視頻幀序列維度變為
(B, T, H*W*C)
,射頻序列維度(B, T, F*C)
- 加入時間位置編碼,時間位置編碼由
time_indices = torch.arange(T, device=rgb.device)
然后經過Embedding實現,輸出維度不變:(B, T, H*W*C)
和(B, T, F*C)
- Transform1,Q為視頻特征,K和V為射頻特征:Q為視頻特征維度變換得到,維度
(T, B, D=64)
,K為射頻特征變換為得到,維度(T, B, D=64)
,這里把T和B的位置進行了變換,因為nn.MultiheadAttention 默認的輸入格式是 (seq_len, batch, embed_dim)。經過MultiheadAttention之后變換回視頻特征維度,也就是(B, T, H*W*C)
。 - Transform2: Q為射頻特征,K和V為視頻特征,同上,經過MultiheadAttention之后變換回射頻特征維度,也就是
(B, T, F*C)
。
def forward(self, rgb, rf):B, C, T, H, W = rgb.shape_, _, _, F = rf.shape# Flatten spatial dimensions to create sequences for cross-attentionrgb = rgb.permute(0, 2, 3, 4, 1).reshape(B, T, -1) # (B, T, H*W*C)rf = rf.permute(0, 2, 3, 1).reshape(B, T, -1) # (B, T, F*C)# print(rgb.shape, rf.shape)# Temporal Embeddingtime_indices = torch.arange(T, device=rgb.device)rgb_time_embeddings = self.rgb_embedding(time_indices).unsqueeze(0) # (1, T, H*W*C)rf_time_embeddings = self.rf_embedding(time_indices).unsqueeze(0) # (1, T, F*C)rgb = rgb + rgb_time_embeddingsrf = rf + rf_time_embeddings# nn.MultiheadAttention 默認的輸入格式是 (seq_len, batch, embed_dim),即 (T, B, D)。rgb, rf = rgb.permute(1, 0, 2), rf.permute(1, 0, 2)# Transform RGB to Q and RF to K and VQ = self.rgb_to_q(rgb) # (T, B, D)K = self.rf_to_k(rf) # (T, B, D)V = self.rf_to_v(rf) # (T, B, D)# Apply cross-attention: RGB as query, RF as key and valuergb_prime, _ = self.attention_rgb_rf(Q, K, V) # (T, B, D)rgb_prime = self.proj_rgb(rgb_prime) # (T, B, D)# Reverse the flattening process for RGB'rgb_prime = rgb_prime.view(T, B, H, W, C).permute(1, 4, 0, 2, 3) # (B, C, T, H, W)# Transform RF to Q and RGB to K and VQ = self.rf_to_q(rf) # (T, B, D)K = self.rgb_to_k(rgb) # (T, B, D)V = self.rgb_to_v(rgb) # (T, B, D)# Apply cross-attention: RF as query, RGB as key and valuerf_prime, _ = self.attention_rf_rgb(Q, K, V) # (T, B, D)rf_prime = self.proj_rf(rf_prime) # (T, B, D)# Reverse the flattening process for RF'rf_prime = rf_prime.view(T, B, F, C).permute(1, 3, 0, 2) # (B, C, T, F)return rgb_prime, rf_prime