nieshen committed
Commit b4d75ba (verified) · 1 Parent(s): ce71e3c

Supports the input of an attention mask
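
For context, the diff below routes a caller-supplied attention bias into F.scaled_dot_product_attention (previously hard-coded to attn_mask=None) and swaps the causal default for a cached all-zeros bidirectional bias, so padding can be masked out without reintroducing causal masking. A minimal sketch of how a standard 0/1 padding mask maps onto that additive-bias convention, mirroring the masked_fill_ line in the last hunk; the helper name and shapes here are illustrative only, not part of this commit:

import torch

def padding_mask_to_bias(attention_mask: torch.Tensor) -> torch.Tensor:
    # attention_mask: [batch, seq_len], 1 for real tokens, 0 for padding (HF convention).
    bias = attention_mask[:, None, None, :].to(torch.float)  # [batch, 1, 1, seq_len]
    # Same convention as the model's forward(): 0 entries become a very large negative
    # number; the leftover 1.0 on visible keys is a uniform offset that softmax cancels,
    # so only padded keys are suppressed.
    return bias.masked_fill(bias == 0.0, torch.finfo(bias.dtype).min)

mask = torch.tensor([[1, 1, 1, 0, 0]])  # one sequence whose last two tokens are padding
bias = padding_mask_to_bias(mask)       # [1, 1, 1, 5], broadcastable over heads and queries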

Files changed (1)
  1. modeling_llada.py +18 -5
modeling_llada.py CHANGED
@@ -649,12 +649,12 @@ class LLaDABlock(nn.Module):
             k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
             v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
 
-        # Modify: MDM set causal to False, and with no attn_mask.
+        # Modify: MDM set causal to False.
         return F.scaled_dot_product_attention(
             q,
             k,
             v,
-            attn_mask=None,
+            attn_mask=attn_mask,
             dropout_p=dropout_p,
             is_causal=False,
         )
@@ -712,7 +712,7 @@ class LLaDABlock(nn.Module):
             q,
             k,
             v,
-            attn_mask=None,
+            attn_mask=attention_bias,
             dropout_p=0.0 if not self.training else self.config.attention_dropout,
             is_causal=False,
         )
@@ -1157,7 +1157,20 @@ class LLaDAModel(nn.Module):
             alibi_bias = alibi_attention_bias(seq_len, self.config, device)
         self.__cache["alibi_attention_bias"] = alibi_bias
         return alibi_bias
-
+
+    def get_bidirectional_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
+        if (bidirectional_bias := self.__cache.get("bidirectional_attention_bias")) is not None and bidirectional_bias.shape[
+            -1
+        ] >= seq_len:
+            if bidirectional_bias.device != device:
+                bidirectional_bias = bidirectional_bias.to(device)
+                self.__cache["bidirectional_attention_bias"] = bidirectional_bias
+            return bidirectional_bias
+        with torch.autocast(device.type, enabled=False):
+            bidirectional_bias = torch.zeros((1, 1, seq_len, seq_len), device=device, dtype=torch.float)
+        self.__cache["bidirectional_attention_bias"] = bidirectional_bias
+        return bidirectional_bias
+
     def forward(
         self,
         input_ids: torch.LongTensor,
@@ -1257,7 +1270,7 @@ class LLaDAModel(nn.Module):
                     self.__cache, past_length + seq_len, x.device
                 ) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
             elif attention_bias is None:
-                attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device)
+                attention_bias = self.get_bidirectional_attention_bias(past_length + seq_len, x.device)
             elif attention_bias.dtype in (torch.int8, torch.bool):
                 attention_bias = attention_bias.to(dtype=torch.float)
                 attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)
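
To see the end effect of the patched call, a standalone sketch (plain tensors, not the model's real call path) of F.scaled_dot_product_attention with is_causal=False and an additive float mask of the kind the block now receives; padded key positions end up with essentially zero attention weight:

import torch
import torch.nn.functional as F

batch, heads, seq, head_dim = 1, 2, 5, 8
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
v = torch.randn(batch, heads, seq, head_dim)

# Additive bias: 0 for the first three (visible) keys, a huge negative value for the
# last two (padded) keys; broadcasts over heads and query positions.
attn_mask = torch.zeros(batch, 1, 1, seq)
attn_mask[..., 3:] = torch.finfo(attn_mask.dtype).min

# Bidirectional (non-causal) attention: every query may attend to every visible key.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=False)

# Recomputing the weights by hand confirms the padded keys receive ~0 attention mass.
scores = (q @ k.transpose(-2, -1)) / head_dim ** 0.5 + attn_mask
weights = scores.softmax(dim=-1)
assert weights[..., 3:].max().item() < 1e-6
print(out.shape)  # torch.Size([1, 2, 5, 8])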