prompt_vector = torch.sum(prompt_embedding * attention_weights.unsqueeze(-1), dim=1) # [1, hidden_dim]
prompt_vector = torch.sum(prompt_embedding * attention_weights.unsqueeze(-1), dim=1)
主要作用是通過將 prompt_embedding
與 attention_weights
相乘后再按指定維度求和,得到一個新的張量 prompt_vector
。
代碼解釋
prompt_embedding
:這是一個包含提示詞嵌入向量的張量,通常形狀為[batch_size, seq_len, hidden_dim]
,表示批次大小、序列長度和隱藏層維度。attention_weights
:這是一個注意力權重張量,形狀通常為[batch_size, seq_len]
,表示每個位置的注意力權重。