The squeeze_features function in AlphaFold3's data_transforms module
squeeze_features removes unnecessary singleton dimensions (axes of size 1) and repeated dimensions from protein feature tensors, so that the features match the input format AlphaFold3 expects.
Source code:
    import numpy as np
    import torch

    def squeeze_features(protein):
        """Remove singleton and repeated dimensions in protein features."""
        # Step 1: one-hot aatype -> integer indices
        protein["aatype"] = torch.argmax(protein["aatype"], dim=-1)
        # Step 2: drop a trailing dimension of size 1
        for k in [
            "domain_name",
            "msa",
            "num_alignments",
            "seq_length",
            "sequence",
            "superfamily",
            "deletion_matrix",
            "resolution",
            "between_segment_residues",
            "residue_index",
            "template_all_atom_mask",
        ]:
            if k in protein:
                final_dim = protein[k].shape[-1]
                if isinstance(final_dim, int) and final_dim == 1:
                    if torch.is_tensor(protein[k]):
                        protein[k] = torch.squeeze(protein[k], dim=-1)
                    else:
                        protein[k] = np.squeeze(protein[k], axis=-1)
        # Step 3: keep a single value for repeated per-residue scalars
        for k in ["seq_length", "num_alignments"]:
            if k in protein:
                protein[k] = protein[k][0]
        return protein
Code walkthrough:
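- The function takes protein (a dictionary of protein features) as input.
- Its main tasks are:
  - Convert the one-hot aatype into an index representation.
  - Remove a trailing singleton dimension, i.e. the final 1 in shapes like (N, ..., 1).
  - Extract the actual values of seq_length and num_alignments.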
Step 1: Convert aatype

    protein["aatype"] = torch.argmax(protein["aatype"], dim=-1)
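- The input aatype (amino acid type) is usually one-hot encoded, with the class dimension last.
- torch.argmax(..., dim=-1) returns, for each residue, the index of the largest entry, i.e. the position of the 1 in the one-hot vector.
- Purpose: simplify the representation of aatype so that it stores amino-acid indices directly instead of a one-hot matrix. Because the argmax is applied unconditionally, aatype must actually be one-hot when the function is called.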
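A small sketch of Step 1 on a toy input. The 21-class alphabet (20 standard amino acids plus one unknown type) is an assumption borrowed from AlphaFold2-style features; only the argmax call itself comes from squeeze_features. The last lines show why the step must run exactly once.

```python
import torch
import torch.nn.functional as F

# Toy one-hot aatype: 4 residues, 21 classes (assumption: 20 standard
# amino acids plus one unknown type, as in AlphaFold2-style features).
NUM_CLASSES = 21
true_indices = torch.tensor([0, 4, 20, 7])
aatype_one_hot = F.one_hot(true_indices, num_classes=NUM_CLASSES).float()
print(aatype_one_hot.shape)  # torch.Size([4, 21])

# Step 1 of squeeze_features: argmax over the last (class) dimension.
aatype_idx = torch.argmax(aatype_one_hot, dim=-1)
print(aatype_idx.shape)     # torch.Size([4])
print(aatype_idx.tolist())  # [0, 4, 20, 7]

# Pitfall: applying the same argmax to already-decoded indices collapses
# them to the position of the largest index (a 0-dim tensor).
wrong = torch.argmax(aatype_idx, dim=-1)
print(wrong.shape, wrong.item())  # torch.Size([]) 2
```

In other words, the transform is not idempotent for aatype: a second call would silently destroy the sequence.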
Step 2: Remove singleton dimensions
    for k in [
        "domain_name", "msa", "num_alignments", "seq_length", "sequence",
        "superfamily", "deletion_matrix", "resolution",
        "between_segment_residues", "residue_index", "template_all_atom_mask",
    ]:
        if k in protein:
            final_dim = protein[k].shape[-1]  # size of the last dimension
            if isinstance(final_dim, int) and final_dim == 1:
                if torch.is_tensor(protein[k]):
                    protein[k] = torch.squeeze(protein[k], dim=-1)  # drop the singleton dimension
                else:
                    protein[k] = np.squeeze(protein[k], axis=-1)
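- Iterate over a fixed list of protein feature keys, skipping any that are not present.
- If the size of the last dimension, final_dim, is 1 (the isinstance check ensures it is a plain integer before comparing), that dimension carries no information and is removed:
  - For a PyTorch tensor (torch.Tensor), use torch.squeeze(dim=-1).
  - For a NumPy array, use np.squeeze(axis=-1).
- Only the last dimension is squeezed; size-1 dimensions elsewhere in the shape are left in place.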
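The per-key logic of Step 2 can be isolated into a helper and tried on toy arrays. The name squeeze_last_if_singleton and all of the shapes below are hypothetical, chosen only for illustration; the body mirrors the branch inside squeeze_features.

```python
import numpy as np
import torch


def squeeze_last_if_singleton(x):
    """Hypothetical helper mirroring the per-key logic of squeeze_features:
    drop the last dimension only when its size is exactly 1."""
    final_dim = x.shape[-1]
    if isinstance(final_dim, int) and final_dim == 1:
        if torch.is_tensor(x):
            return torch.squeeze(x, dim=-1)  # PyTorch path
        return np.squeeze(x, axis=-1)  # NumPy path
    return x  # last dimension is not a singleton: leave unchanged


# Toy shapes (assumptions, chosen only for illustration).
resolution = torch.tensor([2.5])            # shape (1,)
residue_index = np.arange(5).reshape(5, 1)  # shape (5, 1)
msa = np.zeros((1, 5), dtype=np.int32)      # one sequence, five residues

print(squeeze_last_if_singleton(resolution).shape)     # torch.Size([])
print(squeeze_last_if_singleton(residue_index).shape)  # (5,)
print(squeeze_last_if_singleton(msa).shape)            # (1, 5): leading size-1 dim kept
print(np.squeeze(msa).shape)                           # (5,): squeezing every axis loses it
```

Restricting the squeeze to the last axis is what keeps a single-row array like msa two-dimensional. The explicit size check also matters for NumPy: np.squeeze(x, axis=-1) raises a ValueError when that axis does not have length 1, whereas torch.squeeze(x, dim=-1) simply returns the input unchanged.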
Step 3: Reduce seq_length and num_alignments to scalars
    for k in ["seq_length", "num_alignments"]:
        if k in protein:
            protein[k] = protein[k][0]
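seq_length and num_alignments are typically stored as arrays or tensors in which the same value is repeated (for example, once per residue), even though only a single number is meaningful. Taking element [0] therefore reduces each of them to a scalar. Note that for a PyTorch tensor, protein[k][0] yields a 0-dimensional tensor rather than a Python int; call .item() if a plain integer is needed.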
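A minimal sketch of Step 3, assuming an AlphaFold2/OpenFold-style feature layout in which each per-chain scalar is repeated once per residue. The values 6 and 128 are made up for illustration. It shows what protein[k][0] actually returns for a tensor and for a NumPy array.

```python
import numpy as np
import torch

num_res = 6
# Assumption: per-chain scalars stored repeated once per residue,
# so every element holds the same value.
seq_length = torch.full((num_res,), num_res, dtype=torch.int32)
num_alignments = np.array([128] * num_res, dtype=np.int32)

# Step 3 of squeeze_features: keep only the first element.
seq_length_scalar = seq_length[0]
num_alignments_scalar = num_alignments[0]

print(seq_length_scalar.shape)      # torch.Size([]): still a tensor, just 0-dim
print(type(num_alignments_scalar))  # <class 'numpy.int32'>

# Convert explicitly if downstream code needs plain Python ints.
print(seq_length_scalar.item(), int(num_alignments_scalar))  # 6 128
```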
Conclusion
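1. Convert aatype: from one-hot encoding to an index representation.
2. Remove meaningless singleton dimensions: bring msa, resolution, deletion_matrix, and similar features into the format AlphaFold3 expects.
3. Reduce seq_length and num_alignments to scalars: keep a single value instead of a repeated array. For PyTorch inputs the result is a 0-dimensional tensor, not a Python int.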
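💡 Overall effect: the dimensions of the input features match what AlphaFold3 expects during training, so downstream transforms can use them without handling extra size-1 or repeated dimensions.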