Supports the input of an attention mask
modeling_llada.py (+18, -5)
@@ -649,12 +649,12 @@ class LLaDABlock(nn.Module):
                 k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
                 v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)

-            # Modify: MDM set causal to False
+            # Modify: MDM set causal to False.
             return F.scaled_dot_product_attention(
                 q,
                 k,
                 v,
-                attn_mask=None,
+                attn_mask=attn_mask,
                 dropout_p=dropout_p,
                 is_causal=False,
             )
@@ -712,7 +712,7 @@ class LLaDABlock(nn.Module):
             q,
             k,
             v,
-            attn_mask=None,
+            attn_mask=attention_bias,
             dropout_p=0.0 if not self.training else self.config.attention_dropout,
             is_causal=False,
         )
@@ -1157,7 +1157,20 @@ class LLaDAModel(nn.Module):
         alibi_bias = alibi_attention_bias(seq_len, self.config, device)
         self.__cache["alibi_attention_bias"] = alibi_bias
         return alibi_bias
-
+
+    def get_bidirectional_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
+        if (bidirectional_bias := self.__cache.get("bidirectional_attention_bias")) is not None and bidirectional_bias.shape[
+            -1
+        ] >= seq_len:
+            if bidirectional_bias.device != device:
+                bidirectional_bias = bidirectional_bias.to(device)
+                self.__cache["bidirectional_attention_bias"] = bidirectional_bias
+            return bidirectional_bias
+        with torch.autocast(device.type, enabled=False):
+            bidirectional_bias = torch.zeros((1, 1, seq_len, seq_len), device=device, dtype=torch.float)
+        self.__cache["bidirectional_attention_bias"] = bidirectional_bias
+        return bidirectional_bias
+
     def forward(
         self,
         input_ids: torch.LongTensor,
@@ -1257,7 +1270,7 @@ class LLaDAModel(nn.Module):
                     self.__cache, past_length + seq_len, x.device
                 ) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
             elif attention_bias is None:
-                attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device)
+                attention_bias = self.get_bidirectional_attention_bias(past_length + seq_len, x.device)
             elif attention_bias.dtype in (torch.int8, torch.bool):
                 attention_bias = attention_bias.to(dtype=torch.float)
                 attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)
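For reference, the `attention_bias.dtype in (torch.int8, torch.bool)` branch kept in the last hunk is what turns a boolean / 0-1 mask into the additive float bias that now actually reaches F.scaled_dot_product_attention. A minimal stand-alone sketch of that conversion; the batch size, sequence length, and `keep` pattern are invented for illustration and are not part of the commit:

import torch

# Illustrative padding pattern: the last key position of a 4-token sequence is padding.
batch_size, seq_len = 1, 4
keep = torch.tensor([[True, True, True, False]])

# Broadcast to the (batch, 1, q_len, k_len) layout accepted by SDPA, then convert the
# same way the branch above does: cast to float and fill masked slots with the dtype
# minimum so they contribute essentially nothing after softmax.
attention_bias = keep[:, None, None, :].expand(batch_size, 1, seq_len, seq_len)
attention_bias = attention_bias.to(dtype=torch.float)
attention_bias = attention_bias.masked_fill(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)

print(attention_bias.shape)  # torch.Size([1, 1, 4, 4])

A bias shaped like this broadcasts against the (batch, n_heads, q_len, k_len) attention scores, which is also why the new bidirectional default is simply an all-zero (1, 1, seq_len, seq_len) tensor.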
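In practice, the change means a padding mask supplied at call time is no longer dropped on the SDPA path. A minimal usage sketch, assuming the file ships in a LLaDA checkpoint loaded through transformers with trust_remote_code=True; the checkpoint id, the pad-token fallback, and the .logits field on the output are assumptions for illustration, not something this diff guarantees:

import torch
from transformers import AutoModel, AutoTokenizer

model_id = "GSAI-ML/LLaDA-8B-Base"  # assumed checkpoint id for illustration

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModel.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.bfloat16)
model.eval()

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # assumption: fall back to EOS for padding

# Two prompts of different length, so the shorter one is padded and the
# attention_mask contains zeros that should now actually mask attention.
batch = tokenizer(["What is diffusion?", "Hi"], return_tensors="pt", padding=True)

with torch.no_grad():
    out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])

print(out.logits.shape)  # (batch_size, padded_seq_len, vocab_size)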