SPOT(Sequential Predictive Modeling of Clinical Trial Outcome with Meta-Learning)模型是用于臨床試驗結果預測的模型,
借鑒了模型無關元學習(MAML,Model-Agnostic Meta-Learning)的框架,將模型參數分為全局共享參數和任務特定參數,以平衡跨任務泛化與任務內適配:
一、任務定義:將每個試驗主題序列視為獨立任務
SPOT通過主題發現模塊(Topic Discovery)將臨床試驗數據聚類為多個主題(topic),每個主題包含具有相似特征(如疾病類型、治療方案、試驗設計)的試驗。由于同一主題的試驗在時間上具有連續性(按時間戳排序),SPOT將每個主題的時序試驗序列定義為一個“任務”。
- 動機:臨床試驗數據存在嚴重的不平衡性(如某些疾病或治療方案的試驗數量少,屬于“小眾任務”)。元學習的核心優勢是“學習如何學習”,能在少量數據上快速適應新任務,因此適合處理這類不平衡場景。
- 具體操作:每個主題的時序序列(如某類腫瘤藥物的I期試驗按年份排列的序列)被視為一個獨立任務,模型需要為每個任務學習特定的預測模式。
二、參數設計:全局參數與任務特定參數分離
SPOT借鑒了模型無關元學習(MAML,Model-Agnostic Meta-Learning)的框架,將模型參數分為全局共享參數和任務特定參數,以平衡跨任務泛化與任務內適配:
-
全局參數(θ?和θ?):
- θ?:來自靜態試驗嵌入模塊(如疾病編碼器GRAM、治療方案編碼器MPNN、入排標準編碼器Trial2Vec),負責提取所有試驗的通用特征(如疾病本體、分子結構的共性),在所有任務中共享。
- θ?:對應序列建模模塊(RNN和序列預測網絡)的基礎參數,用于捕捉時序模式的通用規律(如試驗設計隨時間演進的共性趨勢)。
-
任務特定參數(θ??):
- 針對每個主題任務k,θ??是θ?的微調版本,通過局部更新適配該主題的獨特時序模式(如某類罕見病試驗的成功率波動規律)。
- 設計目的:讓模型在保留全局共性的同時,為每個任務定制參數,避免“多數類任務”主導模型學習,提升對“小眾任務”的預測能力。
其算法流程可以分為以下幾個主要步驟:
1. 模型初始化
在初始化階段,會對模型的各種參數進行設置,并創建模型對象和主題發現器。以下是初始化部分的代碼:
class SPOT(TrialOutcomeBase):def __init__(self,num_topics=50,n_trial_projector=2,n_timestemp_projector=2,n_rnn_layer=1,criteria_column='criteria',batch_size=1,n_trial_per_batch=None,learning_rate=1e-4,weight_decay=1e-4,epochs=10,evaluation_steps=50,warmup_ratio=0,device="cuda:0",seed=42,output_dir="./checkpoints/spot",):self.config = {'num_topics': num_topics,'n_trial_projector': n_trial_projector,'n_timestemp_projector': n_timestemp_projector,'n_rnn_layer': n_rnn_layer,'criteria_column': criteria_column,'batch_size': batch_size,'n_trial_per_batch':n_trial_per_batch,'learning_rate': learning_rate,'epochs': epochs,'weight_decay': weight_decay,'evaluation_steps':eva