Hello, I think that if you want the additive attention to handle batched inputs, where the inputs are:
Inputs: query, value
- query (batch_size, q_len, hidden_dim): tensor containing the output features from the decoder.
- value (batch_size, v_len, hidden_dim): tensor containing features of the encoded input sequence
then the code in the forward function should be like this:

```python
def forward(self, query: Tensor, key: Tensor, value: Tensor):
    score = self.score_proj(
        torch.tanh(self.key_proj(key.unsqueeze(1)) + self.query_proj(query.unsqueeze(2)) + self.bias)
    ).squeeze()
    attn = F.softmax(score, dim=-1)
    context = torch.bmm(attn, value)
    return context, attn
```
Otherwise, the sizes of self.key_proj(key.unsqueeze(1)) and self.query_proj(query.unsqueeze(2)) will mismatch on the second dimension and the tensors cannot be added. With the unsqueezes, key becomes (batch_size, 1, k_len, hidden_dim) and query becomes (batch_size, q_len, 1, hidden_dim), so the sum broadcasts to (batch_size, q_len, k_len, hidden_dim) and the score ends up as (batch_size, q_len, k_len).
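To make the shape argument concrete, here is a minimal self-contained sketch of the batched additive attention described above. The layer names (query_proj, key_proj, score_proj, bias) follow the snippet, but the layer definitions and sizes are my assumptions, not necessarily the repo's; I also use squeeze(-1) instead of a bare squeeze() so a batch or q_len of size 1 is not accidentally collapsed.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class AdditiveAttention(nn.Module):
    """Sketch of batched additive (Bahdanau-style) attention.

    Layer names mirror the snippet in this issue; exact layer
    configurations are illustrative assumptions.
    """

    def __init__(self, hidden_dim: int) -> None:
        super().__init__()
        self.query_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.key_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.bias = nn.Parameter(torch.zeros(hidden_dim))
        self.score_proj = nn.Linear(hidden_dim, 1)

    def forward(self, query: Tensor, key: Tensor, value: Tensor):
        # key.unsqueeze(1):   (B, 1, k_len, H)
        # query.unsqueeze(2): (B, q_len, 1, H)
        # broadcast sum:      (B, q_len, k_len, H)
        score = self.score_proj(
            torch.tanh(
                self.key_proj(key.unsqueeze(1))
                + self.query_proj(query.unsqueeze(2))
                + self.bias
            )
        ).squeeze(-1)                    # (B, q_len, k_len)
        attn = F.softmax(score, dim=-1)  # normalize over key positions
        context = torch.bmm(attn, value) # (B, q_len, H)
        return context, attn


batch, q_len, k_len, hidden = 2, 3, 5, 8
attention = AdditiveAttention(hidden)
q = torch.rand(batch, q_len, hidden)
kv = torch.rand(batch, k_len, hidden)
context, attn = attention(q, kv, kv)
print(context.shape)  # torch.Size([2, 3, 8])
print(attn.shape)     # torch.Size([2, 3, 5])
```

Without the two unsqueeze calls, the projections would be (B, k_len, H) and (B, q_len, H), which cannot be added whenever q_len != k_len, which is exactly the mismatch this issue points out.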