內置/自己設計的損失函數使用對比
- 1.內置損失函數
- 2.自己設計損失函數
Pytorch內置了許多常用的損失函數,但是,實際應用中,往往需要依據不同的需求設計不同的損失函數。本篇博文對比總結了使用 內置和 自己設計損失函數的語法規則流程。
1.內置損失函數
使用內嵌的損失函數:
optimizer_d = torch.optim.Adam(netd.parameters(), opt.lr2, betas=(opt.beta1, 0.999)) #要優化網路netd中的參數,使用Adam優化算法
criterion = t.nn.BCELoss().to(device) #官方損失函數
output = netd(real_img) #網絡輸出 output
error_d_real = criterion(output, true_labels) #將網路輸出output傳入損失函數計算真實值,一批次的輸出直接穿傳進去
optimizer_d.zero_grad() #要優化參數梯度歸零
error_d_real.backward() #由損失函數從最后一層開始極端各個參數的梯度
optimizer_d.step() #梯度反向傳播,更新參數
2.自己設計損失函數
optimizer_D = torch.optim.Adam(D.parameters(), lr=LR_D) #要優化網路D中的參數,使用Adam優化算法
prob_artist0 = D(artist_paintings) #網絡輸出prob_artist0
D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1)) #自己設計的損失函數,直接由網絡輸出prob_artist0計算損失
optimizer_D.zero_grad() #要優化參數梯度歸零
D_loss.backward(retain_graph=True) #由損失函數從最后一層開始的各個參數的梯度
optimizer_D.step() #梯度反向傳播,更新參數
核心過程都是一樣的,只不過在實際使用時會因為編程需要調整一些步驟的順序。