DPO Study Notes
- 1 Theory
- 1.0 Terminology
- 1.1 preference model
- 1.2 RLHF
- 1.3 From RLHF to DPO
- A. The optimal form of the solution
- B. Parameter estimation under DPO
- C. Gradient updates under DPO
- D. Stability of DPO training
- 2 Source code
- 2.1 Dataset structure
- 2.2 Computing log prob
- 2.3 DPO loss
1 Theory
1.0 Terminology
- preference model: a model of human preferences; this "model" is not a DL model
- policy model: the LLM $\pi_\theta$ that we ultimately want to train
- reward model: used to score how well the LLM's output matches human preferences
1.1 preference model
- A paradigm/definition: a model that predicts the probability of a human's relative preference between different outputs. For example, when comparing two responses, a preference model can estimate the probability that "response A is preferred over response B".
- DPO uses the Bradley–Terry model to define the probabilistic form of preference. Given two options $y_w$ and $y_l$, Bradley–Terry defines the probability that $y_w$ beats $y_l$ as
$$p(y_w \succ y_l) = \frac{\exp(\theta_w)}{\exp(\theta_w) + \exp(\theta_l)}$$
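As a quick sanity check (not from the original notes; the scores are made up for illustration), the Bradley–Terry probability is easy to compute from two scalar strengths:

import math

def bt_preference(theta_w: float, theta_l: float) -> float:
    """Bradley-Terry probability that the option with strength theta_w beats theta_l."""
    # Equivalent to a sigmoid of the score gap: 1 / (1 + exp(theta_l - theta_w))
    return math.exp(theta_w) / (math.exp(theta_w) + math.exp(theta_l))

print(bt_preference(2.0, 1.0))  # a score gap of 1.0 gives ~0.73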
1.2 RLHF
RLHF requires human-annotated preference pairs: first a reward model is trained, and then the reward model and the LLM are used together for reinforcement learning.
【1】SFT-train the LLM: the model trained on the target task's data is denoted $\pi^{SFT}$.
【2】Train the reward model: another batch of target-task inputs $x$ is fed to $\pi^{SFT}$, and each input yields two outputs, $(y_1, y_2) \sim \pi^{SFT}(y \mid x)$. These pairs are given to human annotators for preference labeling: within $(y_1, y_2)$, the response the annotator judges better is $y_w$ and the worse one is $y_l$, giving the dataset $\mathcal{D} = \{x^{(i)}, y_w^{(i)}, y_l^{(i)}\}_{i=1}^{N}$. Assume these preferences are generated by a latent reward model $r^*(x, y)$; modeling them with Bradley–Terry, the human preference distribution $p^*$ can be written as
$$p^*(y_w \succ y_l \mid x) = \frac{\exp(r^*(x, y_w))}{\exp(r^*(x, y_w)) + \exp(r^*(x, y_l))}$$
We can parameterize a reward model $r_\phi(x, y)$ and estimate its parameters by maximum likelihood on $\mathcal{D}$, framed as a binary classification problem. The loss can be written as (other forms are possible; the score difference is the most intuitive):
$$\mathcal{L}_R(r_\phi, \mathcal{D}) = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\big[\log \sigma\big(r_\phi(x, y_w) - r_\phi(x, y_l)\big)\big]$$
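A minimal PyTorch sketch of this pairwise loss (my own illustration, assuming some scoring step has already produced one scalar reward per chosen and per rejected response):

import torch
import torch.nn.functional as F

def reward_model_loss(chosen_scores: torch.Tensor, rejected_scores: torch.Tensor) -> torch.Tensor:
    """Pairwise Bradley-Terry loss: -log sigmoid(r(x, y_w) - r(x, y_l)), averaged over the batch."""
    return -F.logsigmoid(chosen_scores - rejected_scores).mean()

# Made-up scores for illustration: the loss shrinks as chosen scores exceed rejected ones.
print(reward_model_loss(torch.tensor([1.5, 0.3, 2.0]), torch.tensor([0.5, 0.9, -1.0])))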
【3】RL fine-tuning: in the RL stage, the optimization objective carries a KL constraint
$$\max_{\pi_\theta} \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi_\theta(y \mid x)}\big[r_\phi(x, y)\big] - \beta\, \mathbb{D}_{KL}\big[\pi_\theta(y \mid x) \,\|\, \pi_{ref}(y \mid x)\big]$$
1.3 From RLHF to DPO
A. The optimal form of the solution
First, given the form of the RL objective with reward function $r$, the optimal policy $\pi$ takes the form
$$\pi_r(y \mid x) = \frac{1}{Z(x)}\,\pi_{ref}(y \mid x)\exp\!\Big(\frac{1}{\beta} r(x, y)\Big)$$
where $Z(x) = \sum_y \pi_{ref}(y \mid x)\exp\big(\frac{1}{\beta} r(x, y)\big)$. The derivation of this form is given in the appendix of the original paper.
The step from line 3 to line 4 of that derivation works because we can introduce $Z(x)$ to construct a new probability distribution; $Z(x)$ is the normalization factor that makes $\tilde{\pi}(y \mid x)$ a valid probability distribution:
$$\tilde{\pi}(y \mid x) = \frac{1}{Z(x)}\,\pi_{ref}(y \mid x)\exp\!\Big(\frac{1}{\beta} r(x, y)\Big)$$
With this, the original expression becomes
$$\log \frac{\pi(y \mid x)}{\pi_{ref}(y \mid x)} - \frac{1}{\beta} r(x, y) = \log \pi(y \mid x) - \log \pi_{ref}(y \mid x) - \log\Big[\exp\!\Big(\frac{1}{\beta} r(x, y)\Big)\Big] = \log \frac{\pi(y \mid x)}{\tilde{\pi}(y \mid x)} - \log Z(x)$$
Since $Z(x)$ is not a function of $y$ (or of $\pi$), the $\log Z(x)$ term is unaffected when the expectation is taken, so minimizing the objective reduces to minimizing $\mathbb{E}\big[\log \frac{\pi(y \mid x)}{\tilde{\pi}(y \mid x)}\big] = \mathbb{D}_{KL}\big[\pi(y \mid x) \,\|\, \tilde{\pi}(y \mid x)\big]$, which is minimized exactly when $\pi = \tilde{\pi}$; and since $\pi$ only needs to be a valid probability distribution, it can indeed take that form. This gives the form of the optimal policy, i.e. the probability of outputting $y$ given $x$ (in RL terms, the probability of entering the next state $S'$ from state $S$):
$$\pi^*(y \mid x) = \frac{1}{Z(x)}\,\pi_{ref}(y \mid x)\exp\!\Big(\frac{1}{\beta} r(x, y)\Big)$$
B. Parameter estimation under DPO
- Even with the form of the optimal policy $\pi_r$, and even if the $r(x, y)$ inside it is replaced by the MLE-estimated reward, there is still a $Z(x)$ to evaluate. Computing $Z(x)$ is very expensive: it sums over all $y$, and the "state space", i.e. the space of possible outputs over the vocabulary, is huge.
- But we can rearrange the equation and re-express the reward function:
$$r(x, y) = \beta \log \frac{\pi_r(y \mid x)}{\pi_{ref}(y \mid x)} + \beta \log Z(x)$$
- Substituting this back into the original Bradley–Terry expression, the $\beta \log Z(x)$ terms cancel, so the final preference function no longer requires computing $Z(x)$ (see the worked substitution after this list).
- So DPO's objective is to raise the probability that $y_w \succ y_l$, and the loss takes the form
$$\mathcal{L}_{DPO}(\pi_\theta; \pi_{ref}) = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\Big[\log \sigma\Big(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{ref}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{ref}(y_l \mid x)}\Big)\Big]$$
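Spelling out the cancellation mentioned above: the Bradley–Terry probability can be written as $p^*(y_w \succ y_l \mid x) = \sigma\big(r(x, y_w) - r(x, y_l)\big)$, so plugging in $r(x, y) = \beta \log \frac{\pi(y \mid x)}{\pi_{ref}(y \mid x)} + \beta \log Z(x)$ gives
$$p^*(y_w \succ y_l \mid x) = \sigma\Big(\beta \log \frac{\pi(y_w \mid x)}{\pi_{ref}(y_w \mid x)} - \beta \log \frac{\pi(y_l \mid x)}{\pi_{ref}(y_l \mid x)}\Big)$$
because the two $\beta \log Z(x)$ terms subtract away; maximizing the likelihood of the observed preferences under this expression is exactly the DPO loss above.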
C. Gradient updates under DPO
- The more the model's implicit ranking disagrees with the human preference, the larger the weighting coefficient in front of the gradient (see the gradient expression below).
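For reference, the gradient of the DPO loss given in the original paper, which makes this weighting explicit:
$$\nabla_\theta \mathcal{L}_{DPO}(\pi_\theta; \pi_{ref}) = -\beta\, \mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\Big[\sigma\big(\hat{r}_\theta(x, y_l) - \hat{r}_\theta(x, y_w)\big)\,\big[\nabla_\theta \log \pi_\theta(y_w \mid x) - \nabla_\theta \log \pi_\theta(y_l \mid x)\big]\Big]$$
where $\hat{r}_\theta(x, y) = \beta \log \frac{\pi_\theta(y \mid x)}{\pi_{ref}(y \mid x)}$ is the implicit reward; the sigmoid factor is large exactly when the implicit reward ranks the pair incorrectly.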
D. Stability of DPO training
- The second term (the normalization term $\beta \log Z(x)$) is a constant, because for the current $x$ it sums over all $y$.
- Reduced impact of extreme values: through the exponentially weighted averaging, the influence of extreme values is attenuated, making the reward function smoother.
- More stable gradient estimates: because the reward function becomes smoother, the policy-gradient estimates are also more stable and their variance is markedly reduced.
2 Source code
RLAIF-V:https://github.com/RLHF-V/RLAIF-V/tree/main
2.1 Dataset structure
- chosen: the human-preferred response
- rejected: the response produced by the SFT-stage model
- ref_win_logp: sum of the log probabilities (under the reference model) of all tokens of the preferred response
- ref_rej_logp: sum of the log probabilities (under the reference model) of all tokens of the model's (rejected) response
- ref_win_avg_logp: sum of the log probabilities of all tokens of the preferred response divided by the number of tokens in the response
data_dict = {
    'image': image,
    "question": question,
    "chosen": chosen,
    "rejected": rejected,
    "idx": sample['idx'],
    "metainfo": metainfo
}
logps = json.loads(sample['logps'])  # produced by ./eval/muffin_inference_logp.py under /muffin

if type(logps) == type([]):
    (data_dict['ref_win_logp'], data_dict['ref_win_avg_logp'], data_dict['ref_win_per_token_logp'],
     data_dict['ref_rej_logp'], data_dict['ref_rej_avg_logp'], data_dict['ref_rej_per_token_logp']) = logps
else:
    (data_dict['ref_win_logp'], data_dict['ref_win_avg_logp'], data_dict['ref_win_per_token_logp'],
     data_dict['ref_rej_logp'], data_dict['ref_rej_avg_logp'], data_dict['ref_rej_per_token_logp']) = logps['logps']

return data_dict
2.2 Computing log prob
def get_batch_logps(logits: torch.FloatTensor, labels: torch.LongTensor, return_per_token_logp=False, return_all=False, tokenizer=None) -> torch.FloatTensor:
    """Compute the log probabilities of the given labels under the given logits.

    Args:
        logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
        labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length)

    Returns:
        A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
    """
    assert logits.shape[:-1] == labels.shape, f'logits.shape[:-1]={logits.shape[:-1]}, labels.shape={labels.shape}'

    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]
    loss_mask = (labels != -100)

    # dummy token; we'll ignore the losses on these tokens later
    labels[labels == -100] = 0

    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2,
                                   index=labels.unsqueeze(2)).squeeze(2)  # get log probabilities for each token in labels

    log_prob = (per_token_logps * loss_mask).sum(-1)
    average_log_prob = log_prob / loss_mask.sum(-1)
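To make the shift-and-gather indexing concrete, here is a small standalone toy example (not from the repo; shapes and labels are made up) that mirrors the same computation on dummy tensors:

import torch

batch, seq_len, vocab = 2, 5, 10
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))
labels[:, :2] = -100  # pretend the first two positions are prompt tokens (ignored)

# Shift: position t's logits predict the token at position t+1.
shifted_labels = labels[:, 1:].clone()
shifted_logits = logits[:, :-1, :]
loss_mask = (shifted_labels != -100)
shifted_labels[shifted_labels == -100] = 0  # dummy index; masked out below

per_token_logps = torch.gather(shifted_logits.log_softmax(-1), dim=2,
                               index=shifted_labels.unsqueeze(2)).squeeze(2)
sum_logp = (per_token_logps * loss_mask).sum(-1)  # corresponds to ref_win_logp / ref_rej_logp
avg_logp = sum_logp / loss_mask.sum(-1)           # corresponds to ref_win_avg_logp / ref_rej_avg_logp
print(sum_logp.shape, avg_logp.shape)  # torch.Size([2]) torch.Size([2])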
2.3 DPO loss
- policy model refers to the model currently being trained; ref model is the model from the earlier SFT stage
- Note that policy_chosen_logps etc. are log probabilities, so the code below is exactly equivalent to the original DPO loss formula (see the short rearrangement after this list)
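The equivalence is just a regrouping of the same four log probabilities:
$$\beta\Big(\log\frac{\pi_\theta(y_w \mid x)}{\pi_{ref}(y_w \mid x)} - \log\frac{\pi_\theta(y_l \mid x)}{\pi_{ref}(y_l \mid x)}\Big) = \beta\Big[\big(\log\pi_\theta(y_w \mid x) - \log\pi_\theta(y_l \mid x)\big) - \big(\log\pi_{ref}(y_w \mid x) - \log\pi_{ref}(y_l \mid x)\big)\Big]$$
which is beta * (pi_logratios - ref_logratios) in the code, so losses = -F.logsigmoid(beta * logits) computes exactly the per-example $\mathcal{L}_{DPO}$.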
def get_beta_and_logps(data_dict, model, args, is_minicpm=False, is_llava15=False):
    win_input_ids = data_dict.pop('win_input_ids')
    rej_input_ids = data_dict.pop('rej_input_ids')
    ref_win_logp = data_dict.pop('ref_win_logp')
    ref_rej_logp = data_dict.pop('ref_rej_logp')
    # ... (building the concatenated win/rej batch and the model forward pass are omitted in this excerpt)
    log_prob, average_log_prob = get_batch_logps(output.logits, concatenated_labels, return_per_token_logp=False)
    if args.dpo_use_average:
        concatenated_logp = average_log_prob
    win_size = win_input_ids.shape[0]
    rej_size = rej_input_ids.shape[0]
    policy_win_logp, policy_rej_logp = concatenated_logp.split([win_size, rej_size])  # by default the averaged per-token log probs; larger means more confident
    return policy_win_logp, policy_rej_logp, ref_win_logp, ref_rej_logp, beta


def dpo_loss(policy_chosen_logps: torch.FloatTensor,
             policy_rejected_logps: torch.FloatTensor,
             reference_chosen_logps: torch.FloatTensor,
             reference_rejected_logps: torch.FloatTensor,
             beta: float,
             reference_free: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Compute the DPO loss for a batch of policy and reference model log probabilities.

    Args:
        policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
        policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
        reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
        reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
        beta: Temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0.
        reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses.

    Returns:
        A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
        The losses tensor contains the DPO loss for each example in the batch.
        The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
    """
    pi_logratios = policy_chosen_logps - policy_rejected_logps  # log(pi(y_w|x)) - log(pi(y_l|x)) = log(pi(y_w|x) / pi(y_l|x))
    ref_logratios = reference_chosen_logps - reference_rejected_logps  # the same quantity under the reference model

    if reference_free:
        ref_logratios = 0

    logits = pi_logratios - ref_logratios
    losses = -F.logsigmoid(beta * logits)

    chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
    rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()

    return losses, chosen_rewards, rejected_rewards


# Call site:
policy_win_logp, policy_rej_logp, ref_win_logp, ref_rej_logp, beta = get_beta_and_logps(
    data_dict, model, self.args, is_llava15=True)  # all of these are averaged per-token log probs
losses, chosen_rewards, rejected_rewards = dpo_loss(
    policy_win_logp, policy_rej_logp, ref_win_logp, ref_rej_logp, beta=beta)
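As a quick, self-contained sanity check (made-up numbers; it re-implements the same core computation as dpo_loss above rather than calling it, so the snippet stands alone), the loss should shrink as the policy widens its chosen-vs-rejected margin relative to the reference:

import torch
import torch.nn.functional as F

def dpo_loss_check(pol_w, pol_l, ref_w, ref_l, beta=0.1):
    # Same core computation as dpo_loss above: -log sigmoid(beta * ((pol_w - pol_l) - (ref_w - ref_l)))
    return -F.logsigmoid(beta * ((pol_w - pol_l) - (ref_w - ref_l)))

ref_w = torch.tensor([-20.0])  # made-up averaged log prob of the chosen response under the reference
ref_l = torch.tensor([-22.0])  # made-up averaged log prob of the rejected response under the reference

same_margin = dpo_loss_check(torch.tensor([-20.0]), torch.tensor([-22.0]), ref_w, ref_l)   # ~0.693 (log 2)
wider_margin = dpo_loss_check(torch.tensor([-19.0]), torch.tensor([-25.0]), ref_w, ref_l)  # smaller
print(same_margin.item(), wider_margin.item())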