AlphaFold3 msa_pairing 模塊的 _correct_post_merged_feats
函數用于對合并后的特征進行修正,確保它們符合預期的格式和要求。這包括可能的對特征值進行調整或進一步的格式化,確保合并后的 FeatureDict
適合于后續模型的輸入。
主要作用是:
- 在多鏈蛋白質 MSA(多序列比對)合并后,重新計算/調整某些特征:
seq_length
(序列長度)num_alignments
(MSA 比對的序列數)
- 為 MSA 生成合適的掩碼(mask),用于模型訓練:
cluster_bias_mask
:控制 MSA 的 query 序列位置。bert_mask
:用于 BERT-style MSA 預訓練掩碼。
源代碼:
def _correct_post_merged_feats(np_example: Mapping[str, np.ndarray],np_chains_list: Sequence[Mapping[str, np.ndarray]],pair_msa_sequences: bool
) -> Mapping[str, np.ndarray]:"""Adds features that need to be computed/recomputed post merging."""np_example['seq_length'] = np.asarray(np_example['aatype'].shape[0],dtype=np.int32)np_example['num_alignments'] = np.asarray(np_example['msa'].shape[0],dtype=np.int32)if not pair_msa_sequences:# Generate a bias that is 1 for the first row of every block in the# block diagonal MSA - i.e. make sure the cluster stack always includes# the query sequences for each chain (since the first row is the query# sequence).cluster_bias_masks = []for chain in np_chains_list:mask = np.zeros(chain['msa'].shape[0])mask[0] = 1cluster_bias_masks.append(mask)np_example['cluster_bias_mask'] = np.concatenate(cluster_bias_masks)# Initialize Bert mask with masked out off diagonals.msa_masks = [np.ones(x['msa'].shape, dtype=np.float32)for x in np_chains_list]np_example['bert_mask'] = block_diag(*msa_masks, pad_value=0)else:np_example['cluster_bias_mask'] = np.zeros(np_example['msa'].shape[0])np_example['cluster_bias_mask'][0] = 1# Initialize Bert mask with masked out off diagonals.msa_masks = [np.ones(x['msa'].shape, dtype=np.float32) forx in np_chains_list]msa_masks_all_seq = [np.ones(x['msa_all_seq'].shape, dtype=np.float32) forx in np_chains_list]