在使用tensorboard可視化,經常會將模型通過save_graph方法保存下來,方便查看結構。在使用save_graph經常會遇到錯誤(至少我經常遇到),對于我,最常見的一個錯誤為
Tracing failed sanity checks!
ERROR: Graphs differed across invocations!Graph diff:
.....
First diverging operator:Node diff:
...
我是在模型中用了 pytorch 自帶的 nn.MultiheadAttention 發生了這個錯誤,一個簡單的解決方法是將原本的
self.attn = nn.MultiheadAttention(128, 8, 0.1, batch_first=True)
中的 batch_first = True 刪去,修改之后為
self.attn = nn.MultiheadAttention(128, 8, 0.1)
注意刪除 batch_first = True 后, 輸入格式需要改為 (seq, batch, feature)。