【2025】Datawhale AI春訓營-RNA結構預測(AI+創新藥)-Task2筆記

【2025】Datawhale AI春訓營-RNA結構預測(AI+創新藥)-Task2筆記

本文對Task2提供的進階代碼進行理解。

任務描述

Task2的任務仍然是基于給定的RNA三維骨架結構,生成一個或多個RNA序列,使得這些序列能夠折疊并盡可能接近給定的目標三維骨架結構。這是一個RNA逆折疊的過程。

將RNA序列折疊成特定三維結構的過程是一個RNA折疊的過程。

在Task2中,繼續使用算法進行RNA逆折疊。評估標準是序列的恢復率,即算法生成的RNA序列在多大程度上能與真實能夠折疊成目標結構的RNA序列相似。

代碼理解

1、導入模塊

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.nn import TransformerConv, LayerNorm
from torch_geometric.nn import radius_graph
from Bio import SeqIO
import math

2、配置參數

# 配置參數
class Config:seed = 42device = "cuda" if torch.cuda.is_available() else "cpu"batch_size = 16 if torch.cuda.is_available() else 8  # 根據顯存調整lr = 0.001epochs = 50seq_vocab = "AUCG"coord_dims = 7  hidden_dim = 256num_layers = 4  # 減少層數防止顯存溢出k_neighbors = 20  dropout = 0.1rbf_dim = 16num_heads = 4amp_enabled = True  # 混合精度訓練

3、定義幾何生成器

# 幾何特征生成器
class GeometricFeatures:@staticmethoddef rbf(D, D_min=0., D_max=20., D_count=16):device = D.deviceD_mu = torch.linspace(D_min, D_max, D_count, device=device)D_mu = D_mu.view(*[1]*len(D.shape), -1)D_sigma = (D_max - D_min) / D_countD_expand = D.unsqueeze(-1)return torch.exp(-((D_expand - D_mu)/D_sigma) ** 2)@staticmethoddef dihedrals(X, eps=1e-7):X = X.to(torch.float32)L = X.shape[0]dX = X[1:] - X[:-1]U = F.normalize(dX, dim=-1)# 計算連續三個向量u_prev = U[:-2]u_curr = U[1:-1]u_next = U[2:]# 計算法向量n_prev = F.normalize(torch.cross(u_prev, u_curr, dim=-1), dim=-1)n_curr = F.normalize(torch.cross(u_curr, u_next, dim=-1), dim=-1)# 計算二面角cosD = (n_prev * n_curr).sum(-1)cosD = torch.clamp(cosD, -1+eps, 1-eps)D = torch.sign((u_prev * n_curr).sum(-1)) * torch.acos(cosD)# 填充處理if D.shape[0] < L:D = F.pad(D, (0,0,0,L-D.shape[0]), "constant", 0)return torch.stack([torch.cos(D[:,:5]), torch.sin(D[:,:5])], -1).view(L,-1)@staticmethoddef direction_feature(X):dX = X[1:] - X[:-1]return F.pad(F.normalize(dX, dim=-1), (0,0,0,1))

4、定義圖構建器

# 圖構建器
class RNAGraphBuilder:@staticmethoddef build_graph(coord, seq):assert coord.shape[1:] == (7,3), f"坐標維度錯誤: {coord.shape}"coord = torch.tensor(coord, dtype=torch.float32)# 節點特征node_feats = [coord.view(-1, 7 * 3),  # [L,21]GeometricFeatures.dihedrals(coord[:,:6,:]),  # [L,10]GeometricFeatures.direction_feature(coord[:,4,:])  # [L,3]]x = torch.cat(node_feats, dim=-1)  # [L,34]# 邊構建pos = coord[:,4,:]edge_index = radius_graph(pos, r=20.0, max_num_neighbors=Config.k_neighbors)# 邊特征row, col = edge_indexedge_vec = pos[row] - pos[col]edge_dist = torch.norm(edge_vec, dim=-1, keepdim=True)edge_feat = torch.cat([GeometricFeatures.rbf(edge_dist).squeeze(1),  # [E,16]F.normalize(edge_vec, dim=-1)  # [E,3]], dim=-1)  # [E,19]# 標簽y = torch.tensor([Config.seq_vocab.index(c) for c in seq], dtype=torch.long)return Data(x=x, edge_index=edge_index, edge_attr=edge_feat, y=y)

5、定義模型結構

# 模型架構
class RNAGNN(nn.Module):def __init__(self):super().__init__()# 節點特征編碼self.feat_encoder = nn.Sequential(nn.Linear(34, Config.hidden_dim),nn.ReLU(),LayerNorm(Config.hidden_dim),nn.Dropout(Config.dropout))# 邊特征編碼(關鍵修復)self.edge_encoder = nn.Sequential(nn.Linear(19, Config.hidden_dim),nn.ReLU(),LayerNorm(Config.hidden_dim),nn.Dropout(Config.dropout))# Transformer卷積層self.convs = nn.ModuleList([TransformerConv(Config.hidden_dim,Config.hidden_dim // Config.num_heads,heads=Config.num_heads,edge_dim=Config.hidden_dim,  # 匹配編碼后維度dropout=Config.dropout) for _ in range(Config.num_layers)])# 殘差連接self.mlp_skip = nn.ModuleList([nn.Sequential(nn.Linear(Config.hidden_dim, Config.hidden_dim),nn.ReLU(),LayerNorm(Config.hidden_dim)) for _ in range(Config.num_layers)])# 分類頭self.cls_head = nn.Sequential(nn.Linear(Config.hidden_dim, Config.hidden_dim),nn.ReLU(),LayerNorm(Config.hidden_dim),nn.Dropout(Config.dropout),nn.Linear(Config.hidden_dim, len(Config.seq_vocab)))self.apply(self._init_weights)def _init_weights(self, module):if isinstance(module, nn.Linear):nn.init.xavier_uniform_(module.weight)if module.bias is not None:nn.init.constant_(module.bias, 0)def forward(self, data):x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr# 邊特征編碼(關鍵步驟)edge_attr = self.edge_encoder(edge_attr)  # [E,19] -> [E,256]# 節點編碼h = self.feat_encoder(x)# 消息傳遞for i, (conv, skip) in enumerate(zip(self.convs, self.mlp_skip)):h_res = conv(h, edge_index, edge_attr=edge_attr)h = h + skip(h_res)if i < len(self.convs)-1:h = F.relu(h)h = F.dropout(h, p=Config.dropout, training=self.training)return self.cls_head(h)

6、定義數據增強類

# 數據增強
class CoordTransform:@staticmethoddef random_rotation(coords):device = torch.device(Config.device)coords_tensor = torch.from_numpy(coords).float().to(device)angle = np.random.uniform(0, 2*math.pi)rot_mat = torch.tensor([[math.cos(angle), -math.sin(angle), 0],[math.sin(angle), math.cos(angle), 0],[0, 0, 1]], device=device)return (coords_tensor @ rot_mat.T).cpu().numpy()

7、定義數據集類

# 數據集類
class RNADataset(torch.utils.data.Dataset):def __init__(self, coords_dir, seqs_dir, augment=False):self.samples = []self.augment = augmentfor fname in os.listdir(coords_dir):# 加載坐標coord = np.load(os.path.join(coords_dir, fname))coord = np.nan_to_num(coord, nan=0.0)# 數據增強if self.augment and np.random.rand() > 0.5:coord = CoordTransform.random_rotation(coord)# 加載序列seq_id = os.path.splitext(fname)[0]seq_path = os.path.join(seqs_dir, f"{seq_id}.fasta")seq = str(next(SeqIO.parse(seq_path, "fasta")).seq)# 構建圖self.samples.append(RNAGraphBuilder.build_graph(coord, seq))def __len__(self): return len(self.samples)def __getitem__(self, idx): return self.samples[idx]

8、訓練函數

# 訓練函數
def train(model, loader, optimizer, scheduler, criterion):model.train()scaler = torch.cuda.amp.GradScaler(enabled=Config.amp_enabled)total_loss = 0for batch in loader:batch = batch.to(Config.device)optimizer.zero_grad()with torch.cuda.amp.autocast(enabled=Config.amp_enabled):logits = model(batch)loss = criterion(logits, batch.y)scaler.scale(loss).backward()torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)scaler.step(optimizer)scaler.update()total_loss += loss.item()scheduler.step()return total_loss / len(loader)

9、評估函數

# 評估函數
def evaluate(model, loader):model.eval()total_correct = total_nodes = 0with torch.no_grad():for batch in loader:batch = batch.to(Config.device)logits = model(batch)preds = logits.argmax(dim=1)total_correct += (preds == batch.y).sum().item()total_nodes += batch.y.size(0)return total_correct / total_nodes

10、主函數

if __name__ == "__main__":# 初始化torch.manual_seed(Config.seed)if torch.cuda.is_available():torch.cuda.manual_seed_all(Config.seed)torch.backends.cudnn.benchmark = True# 數據集train_set = RNADataset("./RNA_design_public/RNAdesignv1/train/coords","./RNA_design_public/RNAdesignv1/train/seqs",augment=True)# 劃分數據集train_size = int(0.8 * len(train_set))val_size = (len(train_set) - train_size) // 2test_size = len(train_set) - train_size - val_sizetrain_set, val_set, test_set = torch.utils.data.random_split(train_set, [train_size, val_size, test_size])# 數據加載train_loader = torch_geometric.loader.DataLoader(train_set, batch_size=Config.batch_size, shuffle=True,pin_memory=True,num_workers=4)val_loader = torch_geometric.loader.DataLoader(val_set, batch_size=Config.batch_size)test_loader = torch_geometric.loader.DataLoader(test_set, batch_size=Config.batch_size)# 模型初始化model = RNAGNN().to(Config.device)optimizer = optim.AdamW(model.parameters(), lr=Config.lr, weight_decay=0.01)scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=Config.epochs)criterion = nn.CrossEntropyLoss()# 訓練循環best_acc = 0for epoch in range(Config.epochs):train_loss = train(model, train_loader, optimizer, scheduler, criterion)val_acc = evaluate(model, val_loader)print(f"Epoch {epoch+1}/{Config.epochs} | Loss: {train_loss:.4f} | Val Acc: {val_acc:.4f}")if val_acc > best_acc:best_acc = val_acctorch.save(model.state_dict(), "best_model.pth")# 最終測試model.load_state_dict(torch.load("best_model.pth"))test_acc = evaluate(model, test_loader)print(f"\nFinal Test Accuracy: {test_acc:.4f}")

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/77895.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/77895.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/77895.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

vim 命令復習

命令模式下的命令及快捷鍵 # dd刪除光所在行的內容 # ndd從光標所在行開始向下刪除n行 # yy復制光標所在行的內容 # nyy復制光標所在行向下n行的內容 # p將復制的內容粘貼到光標所在行以下&#xff08;小寫&#xff09; # P將復制的內容粘貼到光標所在行以上&#xff08;大寫&…

哪些心電圖表現無緣事業編體檢呢?

根據《公務員錄用體檢通用標準》心血管系統條款及事業單位體檢實施細則&#xff0c;心電圖不合格主要涉及以下類型及處置方案&#xff1a; 一、心律失常類 早搏&#xff1a;包括房性早搏、室性早搏和交界性早搏。如果每分鐘早搏次數較多&#xff08;如超過5次&#xff09;&…

Linux學習——UDP

編程的整體框架 bind&#xff1a;綁定服務器&#xff1a;TCP地址和端口號 receivefrom()&#xff1a;阻塞等待客戶端數據 sendto():指定服務器的IP地址和端口號&#xff0c;要發送的數據 無連接盡力傳輸&#xff0c;UDP:是不可靠傳輸 實時的音視頻傳輸&#x…

ReAct Agent 實戰:基于DeepSeek從0到1實現大模型Agent的探索模式

寫在前面:動態思考,邊想邊做 大型語言模型(LLM)的崛起開啟了通用人工智能(AGI)的無限遐想。但要讓 LLM 從一個被動的“文本生成器”轉變為能夠主動解決問題、與環境交互的智能體(Agent),我們需要賦予它思考、行動和學習的能力。ReAct (Reason + Act) 框架正是實現這一…

從物理到預測:數據驅動的深度學習的結構化探索及AI推理

在當今科學探索的時代&#xff0c;理解的前沿不再僅僅存在于我們書寫的方程式中&#xff0c;也存在于我們收集的數據和構建的模型中。在物理學和機器學習的交匯處&#xff0c;一個快速發展的領域正在興起&#xff0c;它不僅觀察宇宙&#xff0c;更是在學習宇宙。 AI推理 我們…

結合地理數據處理

CSV 文件不僅可以存儲表格數據&#xff0c;還可以與地理空間數據結合&#xff0c;實現更強大的地理處理功能。例如&#xff0c;你可以將 CSV 文件中的坐標數據轉換為點要素類&#xff0c;然后進行空間分析。 示例&#xff1a;將 CSV 文件中的坐標數據轉換為點要素類 假設我們有…

SpringBoot中6種自定義starter開發方法

在SpringBoot生態中,starter是一種特殊的依賴,它能夠自動裝配相關組件,簡化項目配置。 自定義starter的核心價值在于: ? 封裝復雜的配置邏輯,實現開箱即用 ? 統一技術組件的使用規范,避免"輪子"泛濫 ? 提高開發效率,減少重復代碼 方法一:基礎配置類方式 …

滾珠導軌松動會導致哪些影響?

直線導軌用于高精度或快速直線往復運動場所&#xff0c;且能夠擔負一定的扭矩&#xff0c;在高負載的情況下實現高精度的直線運動。它主要由導軌和滑塊組成&#xff0c;其中導軌作為固定元件&#xff0c;滑塊則在其上進行往復直線運動。但是滾珠導軌松動會導致哪些影響&#xf…

從零開始搭建Django博客②--Django的服務器內容搭建

本文主要在Ubuntu環境上搭建&#xff0c;為便于研究理解&#xff0c;采用SSH連接在虛擬機里的ubuntu-24.04.2-desktop系統搭建&#xff0c;當涉及一些文件操作部分便于通過桌面化進行理解&#xff0c;通過Nginx代理綁定域名&#xff0c;對外發布。 此為從零開始搭建Django博客…

ZLMediaKit支持JT1078實時音視頻

ZLMediaKit 對 JT1078 實時音視頻協議的支持主要通過其擴展版本或與其他中間件結合實現。以下是基于搜索結果的綜合分析&#xff1a; 一、ZLMediaKit 原生支持能力 開源版本的基礎支持 ZLMediaKit 開源版本本身未直接集成 JT1078 協議解析模塊&#xff0c;但可通過 RTP 推流功能…

Java隊列(Queue)核心操作與最佳實踐:深入解析與面試指南

文章目錄 概述一、Java隊列核心實現類對比1. LinkedList2. ArrayDeque3. PriorityQueue 二、核心操作API與時間復雜度三、經典使用場景與最佳實踐場景1&#xff1a;BFS層序遍歷&#xff08;樹/圖&#xff09;場景2&#xff1a;滑動窗口最大值&#xff08;單調隊列&#xff09; …

MetaGPT智能體框架深度解析:記憶模塊設計與應用實踐

在AI智能體技術從單點突破邁向系統工程的關鍵階段&#xff0c;MetaGPT憑借其創新的記憶架構重新定義了多智能體協作范式。本文深度解構其革命性的三級記憶系統&#xff0c;揭秘支撐10倍效能提升的知識蒸餾算法與動態上下文控制策略&#xff0c;通過企業級應用案例與性能基準測試…

集結號海螺捕魚服務器調度與房間分配機制詳解:六

本篇圍繞服務器調度核心邏輯進行剖析&#xff0c;重點講解用戶連接過程、房間分配機制、服務端并發策略及常見性能瓶頸優化。適用于具備中高級 C 后端開發經驗的讀者&#xff0c;覆蓋網絡會話池、邏輯服調度器與房間生命周期管理等關鍵模塊。 一、服務器結構概覽 整體系統采用…

【電子通識】熱敏打印機是怎么形成(打印)圖像和文字的?

在我們身邊&#xff0c;熱敏打印方式常見用于裝飾貼紙、便利店的小票。此外&#xff0c;物流及食品條碼標簽、身份證件、機票?火車票、X光片、食品日期印刷等&#xff0c;很多打印都用到了熱敏打印頭。 熱敏打印頭的蓄熱層(涂釉層)上分布著一排加熱元件&#xff08;發熱線&…

SQL注入漏洞中會使用到的函數

目錄 一、信息獲取函數 1. 通用函數 2. 元數據查詢&#xff08;INFORMATION_SCHEMA&#xff09; 二、字符串操作函數 1. 字符串連接 2. 字符串截取 3. 編碼/解碼 三、報錯注入專用函數 1. MySQL 2. SQL Server 3. PostgreSQL 四、時間盲注函數 1. 通用延遲 2. 計…

車載信息安全架構 --- 汽車網絡安全

我是穿拖鞋的漢子,魔都中堅持長期主義的汽車電子工程師。 老規矩,分享一段喜歡的文字,避免自己成為高知識低文化的工程師: 周末洗了一個澡,換了一身衣服,出了門卻不知道去哪兒,不知道去找誰,漫無目的走著,大概這就是成年人最深的孤獨吧! 舊人不知我近況,新人不知我過…

Linux423 刪除用戶

查找 上面已查過&#xff1a;無法使用sudo 新開個終端試試 之前開了一個終端&#xff0c;按照deepseek排查 計劃再開一個進程 開一個終端 后強制刪除時顯示&#xff1a;此事將被報告

《從卷積核到數字解碼:CNN 手寫數字識別實戰解析》

文章目錄 一、手寫數字識別的本質與挑戰二、使用步驟1.導入torch庫以及與視覺相關的torchvision庫2.下載datasets自帶的手寫數字的數據集到本地 三、完整代碼展示 一、手寫數字識別的本質與挑戰 手寫數字識別的核心是&#xff1a;從二維像素矩陣中提取具有判別性的特征&#x…

UniOcc:自動駕駛占用預測和預報的統一基準

25年3月來自 UC Riverside、U Wisconsin 和 TAMU 的論文"UniOcc: A Unified Benchmark for Occupancy Forecasting and Prediction in Autonomous Driving"。 UniOcc 是一個全面統一的占用預測基準&#xff08;即基于歷史信息預測未來占用&#xff09;和基于攝像頭圖…

模型量化核心技術解析:從算法原理到工業級實踐

一、模型量化為何成為大模型落地剛需&#xff1f; 算力困境&#xff1a;175B參數模型FP32推理需0.5TB內存&#xff0c;超出主流顯卡容量 速度瓶頸&#xff1a;FP16推理延遲難以滿足實時對話需求&#xff08;如客服場景<200ms&#xff09; 能效挑戰&#xff1a;邊緣設備運行…