kaitos255 committed
Commit f472680 · verified · 1 Parent(s): 5f5fbe6

Fix _prepare_4d_causal_attention_mask_for_sdpa

Files changed (1):
  1. modeling_plamo.py +98 -2
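
For context on the diff below: the patch vendors a bidirectional variant of the SDPA mask helper (built with `is_causal=False`, where the upstream transformers helper uses `is_causal=True`) and maps the `"sdpa"` attention implementation to `ModifiedAttention`. The following sketch is a hedged illustration of how that code path would typically be reached through remote code; the repository id is a placeholder, and it assumes a transformers version (roughly 4.36+) that accepts `attn_implementation` plus an `auto_map` in the repo's config that points `AutoModel` at this modeling file.

import torch
from transformers import AutoModel, AutoTokenizer

repo_id = "org/plamo-bidirectional"  # placeholder -- substitute the Hub repo that serves this modeling_plamo.py

tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(
    repo_id,
    trust_remote_code=True,       # loads modeling_plamo.py from the repo
    attn_implementation="sdpa",   # expected to resolve via PLAMO_ATTENTION_CLASSES["sdpa"] -> ModifiedAttention
    torch_dtype=torch.float32,
)
model.eval()

batch = tokenizer(["PLaMo with a non-causal SDPA mask"], return_tensors="pt")
with torch.no_grad():
    # hidden states computed with the bidirectional SDPA mask prepared by the vendored helper
    hidden_states = model(**batch).last_hidden_state
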
modeling_plamo.py CHANGED
@@ -6,13 +6,109 @@ from torch import nn
 from torch.nn import functional as F
 from transformers import AutoTokenizer, PretrainedConfig, PreTrainedModel
 from transformers.modeling_attn_mask_utils import (
+    AttentionMaskConverter,
     _prepare_4d_causal_attention_mask,
-    _prepare_4d_causal_attention_mask_for_sdpa,
 )
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.tokenization_utils_base import BatchEncoding
 
 
+# From: https://github.com/McGill-NLP/llm2vec/blob/main/llm2vec/models/attn_mask_utils.py
+def _prepare_4d_causal_attention_mask_for_sdpa(
+    attention_mask: Optional[torch.Tensor],
+    input_shape: Union[torch.Size, Tuple, List],
+    inputs_embeds: torch.Tensor,
+    past_key_values_length: int,
+    sliding_window: Optional[int] = None,
+):
+    """
+    Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`.
+
+    In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and
+    `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks,
+    allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
+    """
+    attn_mask_converter = AttentionMaskConverter(
+        is_causal=False, sliding_window=sliding_window
+    )  # is_causal=True in original implementation
+
+    key_value_length = input_shape[-1] + past_key_values_length
+    batch_size, query_length = input_shape
+
+    # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
+    # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
+    # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
+    is_tracing = (
+        torch.jit.is_tracing()
+        or isinstance(inputs_embeds, torch.fx.Proxy)
+        or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
+    )
+
+    if attention_mask is not None:
+        # 4d mask is passed through
+        if len(attention_mask.shape) == 4:
+            expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
+            if tuple(attention_mask.shape) != expected_shape:
+                raise ValueError(
+                    f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
+                )
+            else:
+                # if the 4D mask has correct shape - invert it and fill with negative infinity
+                inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype)
+                attention_mask = inverted_mask.masked_fill(
+                    inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
+                )
+                return attention_mask
+
+        elif not is_tracing and torch.all(attention_mask == 1):
+            if query_length == 1:
+                # For query_length == 1, causal attention and bi-directional attention are the same.
+                attention_mask = None
+            elif key_value_length == query_length:
+                attention_mask = None
+            else:
+                # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
+                # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
+                # Reference: https://github.com/pytorch/pytorch/issues/108108
+                pass
+    elif query_length > 1 and key_value_length != query_length:
+        # See the comment above (https://github.com/pytorch/pytorch/issues/108108).
+        # Ugly: we set it to True here to dispatch in the following controlflow to `to_causal_4d`.
+        attention_mask = True
+    elif is_tracing:
+        raise ValueError(
+            'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.'
+        )
+
+    if attention_mask is None:
+        expanded_4d_mask = None
+    elif attention_mask is True:
+        expanded_4d_mask = attn_mask_converter.to_causal_4d(
+            input_shape[0],
+            input_shape[-1],
+            key_value_length,
+            dtype=inputs_embeds.dtype,
+            device=inputs_embeds.device,
+        )
+    else:
+        expanded_4d_mask = attn_mask_converter.to_4d(
+            attention_mask,
+            input_shape[-1],
+            dtype=inputs_embeds.dtype,
+            key_value_length=key_value_length,
+        )
+
+        # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
+        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+        # Details: https://github.com/pytorch/pytorch/issues/110213
+        if not is_tracing and expanded_4d_mask.device.type == "cuda":
+            expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
+                expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min
+            )
+
+    return expanded_4d_mask
+
+
 def _swiglu(h: torch.Tensor) -> torch.Tensor:
     h0, h1 = h.chunk(2, dim=-1)
     return torch.nn.functional.silu(h0) * h1
@@ -817,7 +913,7 @@ class ModifiedAttention(Attention):
 
 
 PLAMO_ATTENTION_CLASSES = {
-    "sdpa": Attention,
+    "sdpa": ModifiedAttention,
 }
 
 
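
As a quick sanity check of the vendored helper's two main outcomes, the following hedged sketch assumes this commit's modeling_plamo.py is importable locally and a transformers version whose `AttentionMaskConverter._unmask_unattended` takes `min_dtype` (roughly 4.39+): when no token is masked and the query and key/value lengths match, the helper returns `None` so SDPA can take its fast path; with padding it returns a `(batch, 1, query_len, kv_len)` float mask, and because the converter is built with `is_causal=False`, unpadded positions attend bidirectionally.

import torch

from modeling_plamo import _prepare_4d_causal_attention_mask_for_sdpa

batch_size, seq_len, hidden = 2, 4, 8
inputs_embeds = torch.zeros(batch_size, seq_len, hidden)

# Case 1: nothing is masked and query_length == key_value_length -> None,
# letting torch.nn.functional.scaled_dot_product_attention run without a custom attn_mask.
full_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
print(
    _prepare_4d_causal_attention_mask_for_sdpa(
        full_mask, (batch_size, seq_len), inputs_embeds, past_key_values_length=0
    )
)  # None

# Case 2: left padding on the first sequence -> a 4D float mask with large negative
# values on the padded key positions only (no causal triangle, since is_causal=False).
padded_mask = torch.tensor([[0, 1, 1, 1], [1, 1, 1, 1]], dtype=torch.long)
mask_4d = _prepare_4d_causal_attention_mask_for_sdpa(
    padded_mask, (batch_size, seq_len), inputs_embeds, past_key_values_length=0
)
print(mask_4d.shape)  # torch.Size([2, 1, 4, 4])

The `None` fast path is what allows dispatch to the fused SDPA kernels; the 4D-mask path trades that for correct handling of padded tokens.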