With the same random seed, torch.nn.init.xavier_normal_ gives different results on CPU and GPU
- 1. Test code
- 2. Output
While training a PyTorch model, the same random seed produced different losses on different servers. Debugging showed that the weights on the two platforms also differed, and a standalone test of torch.nn.init.xavier_normal_ confirmed the mismatch. If everything is placed on the CPU instead, the two servers produce identical results. This is also why Megatron-DeepSpeed provides the --use-cpu-initialization flag, which initializes weights on the CPU.
1. Test code
cat > test_torch_rand.py <<-'EOF'
import torch
import numpy as np
import random

def init_test(device):
    shape = (1, 4)
    RANDOM_SEED = 42
    # Seed every generator we might touch
    random.seed(RANDOM_SEED)
    np.random.seed(RANDOM_SEED)
    torch.manual_seed(RANDOM_SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(RANDOM_SEED)
    print(f"------------------------test torch init on {device}-------------------------------")
    # torch.rand samples on the CPU here; only the result is moved to `device`
    weight = torch.rand(shape, dtype=torch.float16).to(device)
    print("torch.rand:", weight.detach().cpu().float().numpy())
    # The init functions below sample directly on `device`
    weight = torch.nn.Parameter(torch.empty(shape, device=device, dtype=torch.float16))
    torch.nn.init.xavier_normal_(weight)
    print("xavier_normal_:", weight.detach().cpu().float().numpy())
    weight = torch.nn.Parameter(torch.empty(shape, device=device, dtype=torch.float16))
    torch.nn.init.uniform_(weight, a=0.0, b=1.0)
    print("uniform_:", weight.detach().cpu().float().numpy())
    weight = torch.nn.Parameter(torch.empty(shape, device=device, dtype=torch.float16))
    torch.nn.init.normal_(weight)
    print("normal_:", weight.detach().cpu().float().numpy())
    weight = torch.nn.Parameter(torch.empty(shape, device=device, dtype=torch.float16))
    torch.nn.init.kaiming_uniform_(weight)
    print("kaiming_uniform_:", weight.detach().cpu().float().numpy())

init_test("cpu")
if torch.cuda.is_available():
    init_test("cuda")
EOF
python3 test_torch_rand.py
2. Output
------------------------test torch init on cpu-------------------------------
torch.rand: [[0.5498047 0.71240234 0.41992188 0.63183594]]
xavier_normal_: [[ 0.14831543 0.14562988 -0.70996094 -0.11785889]]
uniform_: [[0.16113281 0.7236328 0.04248047 0.6816406 ]]
normal_: [[0.46166992 0.26733398 0.53466797 0.8095703 ]]
kaiming_uniform_: [[ 0.58740234 0.49389648 -0.2619629 -0.76416016]]
------------------------test torch init on cuda-------------------------------
torch.rand: [[0.5498047 0.71240234 0.41992188 0.63183594]]
xavier_normal_: [[ 0.12268066 1.3671875 -0.10882568 0.5371094 ]]
uniform_: [[0.98779297 0.12890625 0.5620117 0.52197266]]
normal_: [[-0.5185547 1.2265625 0.6254883 -0.9116211]]
kaiming_uniform_: [[ 0.5625 -0.9321289 -1.0996094 -0.640625 ]]
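Note that torch.rand matches on both devices because the tensor is sampled on the CPU and only then moved, while the in-place init functions sample directly on the target device with its own generator. That observation suggests a workaround in the spirit of Megatron-DeepSpeed's --use-cpu-initialization: always draw the initial values with the CPU generator, then move the finished tensor to the GPU. The sketch below illustrates the idea; `cpu_init_weight` is a hypothetical helper name, not an API from any library.

```python
import torch

def cpu_init_weight(shape, seed, device):
    # Assumption: sampling on the CPU generator, which behaves identically
    # across machines, then moving the result to the target device.
    torch.manual_seed(seed)
    weight = torch.empty(shape)          # defaults to CPU
    torch.nn.init.xavier_normal_(weight) # random draw happens on CPU
    return torch.nn.Parameter(weight.to(device))

# Same seed, different target devices: the values are identical because
# the random sampling itself always happened on the CPU.
dev = "cuda" if torch.cuda.is_available() else "cpu"
w_cpu = cpu_init_weight((1, 4), seed=42, device="cpu")
w_dev = cpu_init_weight((1, 4), seed=42, device=dev)
print(torch.equal(w_cpu.detach(), w_dev.detach().cpu()))
```

The trade-off is slower startup for large models (every tensor is materialized on the CPU first), which is exactly the cost Megatron-DeepSpeed accepts when the flag is enabled.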