參考資料:
https://github.com/pytorch/pytorch/issues/73515
https://www.cnblogs.com/X1OO/articles/18171700
由于業務原因,需要在Pytorch代碼中使用分布式通訊來把計算負載平均到多張顯卡上。在無數次確認我的業務代碼沒問題之后,我開始把懷疑的對象轉移到分布式通訊的問題上:
單卡推理的中間層輸出
多卡推理的中間層輸出
如上兩圖,在打印了中間層輸出之后,我發現在After gather之后,多卡推理與單卡推理的中間變量的均值、最大值和最小值是完全一致的。但是緊鄰的一個log卻顯示,它們各自在進入decode方法之后,值就隨即發生了變化。這就不符合我的認知了,因為我可以完全保證從gather到decode沒有任何對tensor做特殊處理的操作。并且在我的固有觀念里,只要兩個形狀較大的tensor統計值類似,基本就可以保證兩個tensor是一模一樣的,那么問題到底出現在哪呢?
不信邪的我把After gather之后統計值相同的兩個tensor都保存下載進行了分析,一分析我就傻眼了:
只見兩個tensor統計值完全相同,甚至通過排序之后發現Tensor中的元素也似乎完全相同,但是這兩個Tensor就是不一樣的。在此檢查了一下代碼中沒有對維度進行特殊操作之后,我把目光放到了我寫的分布式gather函數里:
1 def _conv_gather_avg(input_, dim):2 cp_world_size = get_context_parallel_world_size()3 # Bypass the function if context parallel is 14 if cp_world_size == 1:5 return input_.contiguous()6 7 # input_ = input_.contiguous()8 9 group = get_context_parallel_group()
10 cp_rank = get_context_parallel_rank()
11 tensor_list = [torch.empty_like(input_) for _ in range(cp_world_size)]
12 tensor_list[cp_rank] = input_
13 torch.distributed.all_gather(tensor_list, input_, group=group)
14 # Note: torch.cat already creates a contiguous tensor.
15 output = torch.cat(tensor_list, dim=dim).contiguous()
16 # print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape)
17 return output
注意第七行一開始是沒有的,這里的代碼我是借鑒了其他人的,我發現很多地方都強調了contiguous這個方法,難道它真的有這么重要?于是抱著試一試的態度,我在第七行上加上了input_ = input_.contiguous(),然后神奇的事情就發生了,gather之后的tensor居然就能夠精度對上了。總結一下問題就是如果在all_gather之前不對輸入input_運行contiguous的話,會導致gather之后的tensor雖然值都一樣,但是排列順序完全混亂。下面引用參考資料講一下為什么Pytorch分布式通訊中all_gather要求tensor連續。
工作太忙了沒時間講了... 有時間再補充?