Pass-through for attention mask creation, since the mask is never used:

- For regular attention, the custom sdpa op in causal mode creates its own attention mask.
- For sliding window attention, the attention mask from the attention mask API is discarded and re-created inside the attention API, since it needs to know about cache internals.

Additionally, there were some vmap export issues with sliding window attention mask creation in Transformers.

Args:
    batch_size (`int`):
        The batch size of the input sequence.
    cache_position (`torch.Tensor`):
        A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
    kv_length (`int`):
        The size that the key and value states will have during the attention computation.
    kv_offset (`int`, optional):
        An optional offset indicating the first position that the key and value states refer to.
    mask_function (`Callable`):
        The mask factory function describing the mask pattern.
    attention_mask (`torch.Tensor`, optional):
        The 2D attention mask corresponding to padded tokens, of shape (batch_size, number_of_seen_tokens + q_length).
    local_size (`int`, optional):
        The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True`,
        to try to skip mask creation if possible.
    allow_is_causal_skip (`bool`, optional):
        Whether to allow returning `None` for the mask under conditions where the `is_causal` argument of
        `torch.sdpa` can be used instead. Defaults to `True`.
    allow_torch_fix (`bool`, optional):
        Whether to update the mask when a query is not attending to any tokens, to work around a bug in older
        torch versions. We need an arg to skip it when using eager. Defaults to `True`.
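As a rough illustration of the idea, here is a minimal sketch of what such a pass-through could look like. This is not the PR's actual code: the function name `passthrough_attention_mask` is hypothetical, and the signature simply mirrors the arguments documented above.

```python
from typing import Callable, Optional

import torch


def passthrough_attention_mask(
    batch_size: int,
    cache_position: torch.Tensor,
    kv_length: int,
    kv_offset: int = 0,
    mask_function: Optional[Callable] = None,
    attention_mask: Optional[torch.Tensor] = None,
    local_size: Optional[int] = None,
    allow_is_causal_skip: bool = True,
    allow_torch_fix: bool = True,
    **kwargs,
) -> Optional[torch.Tensor]:
    # Returning None signals that no explicit mask is supplied: the custom sdpa
    # op builds its own causal mask, and the sliding-window path re-creates its
    # mask later, once it knows about the cache internals.
    return None
```

Since the pass-through never constructs a mask, the export also avoids tracing the sliding-window mask creation that caused the vmap issues mentioned above.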