attention_weights = torch.ones_like(prompt_embedding[:, :, 0]):切片操作獲取第1 維度,第二維度
attention_weights = torch.ones_like(prompt_embedding[:, :, 0])
這行代碼的作用是創建一個與 prompt_embedding[:, :, 0]
形狀相同且所有元素都為 1
的張量,它用于初始化注意力權重。
代碼解釋
torch.ones_like()
:這是PyTorch中的一個函數,它創建一個形狀與輸入張量相同且所有元素都為1
的張量。prompt_embedding[:, :, 0]
:這部分是對prompt_embedding
張量的切片操作。prompt_embedding
是一個三維張量,[:, :, 0]
表示取每個二維切片的第0
個元素,得到一個二維張量。
因此,torch.ones_like(prompt_embedd