redmoe-ai-v1 / warrenwjk committed
Commit 4d84cbc · verified · 1 Parent(s): 325ed02

Update VisionSdpaAttention to support memory efficient backend. (#27)


- Update VisionSdpaAttention to support memory efficient backend. (fc8b0b11b92c381639e616506cca574f1b05af09)


Co-authored-by: warren wang <[email protected]>

Files changed (1):
  modeling_dots_vision.py  +14 -5
modeling_dots_vision.py CHANGED
@@ -274,12 +274,21 @@ class VisionSdpaAttention(nn.Module):
             for i in range(1, len(cu_seqlens)):
                 attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = True
 
-        q = q.transpose(0, 1)
-        k = k.transpose(0, 1)
-        v = v.transpose(0, 1)
+        # Convert q, k, v to 4D for SDPA: (1, num_heads, seq_length, head_dim)
+        q = q.transpose(0, 1).unsqueeze(0)  # (1, num_heads, seq_length, head_dim)
+        k = k.transpose(0, 1).unsqueeze(0)
+        v = v.transpose(0, 1).unsqueeze(0)
 
-        attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
-        attn_output = attn_output.transpose(0, 1)
+        # See: https://github.com/pytorch/pytorch/issues/127523
+        if attention_mask.stride(-1) != 1:
+            attention_mask = torch.empty_like(attention_mask, memory_format=torch.contiguous_format).copy_(attention_mask)
+
+        # Use the memory-efficient SDPA backend
+        from torch.nn.attention import SDPBackend, sdpa_kernel
+        with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
+            attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
+
+        attn_output = attn_output.squeeze(0).transpose(0, 1)  # (seq_length, num_heads, head_dim)
         attn_output = attn_output.reshape(seq_length, -1)
 
         attn_output = self.proj(attn_output)
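
For reference, below is a minimal, self-contained sketch (not part of the commit) of the pattern the new code adopts: build a block-diagonal mask from cu_seqlens, make its last dimension contiguous, and run F.scaled_dot_product_attention under the memory-efficient backend. Tensor shapes, names, and the MATH fallback backend are illustrative assumptions, not the model's actual configuration.

# Minimal sketch of the memory-efficient SDPA pattern adopted by this commit.
# All shapes below are illustrative; they are not the model's real dimensions.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

num_heads, head_dim = 4, 32
cu_seqlens = torch.tensor([0, 3, 7, 10])  # cumulative lengths of packed sequences
seq_length = int(cu_seqlens[-1])

q = torch.randn(seq_length, num_heads, head_dim)
k = torch.randn(seq_length, num_heads, head_dim)
v = torch.randn(seq_length, num_heads, head_dim)

# Block-diagonal boolean mask: tokens attend only within their own sequence.
attention_mask = torch.zeros(1, seq_length, seq_length, dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
    attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i], cu_seqlens[i - 1]:cu_seqlens[i]] = True

# SDPA expects 4D inputs: (batch, num_heads, seq_length, head_dim).
q = q.transpose(0, 1).unsqueeze(0)
k = k.transpose(0, 1).unsqueeze(0)
v = v.transpose(0, 1).unsqueeze(0)

# The efficient backend requires the mask's last dimension to be contiguous
# (see https://github.com/pytorch/pytorch/issues/127523).
if attention_mask.stride(-1) != 1:
    attention_mask = torch.empty_like(attention_mask, memory_format=torch.contiguous_format).copy_(attention_mask)

# The commit restricts SDPA to EFFICIENT_ATTENTION; MATH is added here only as a
# fallback so this sketch also runs on machines without that kernel.
with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
    attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)

# Back to (seq_length, num_heads * head_dim) for the output projection.
attn_output = attn_output.squeeze(0).transpose(0, 1).reshape(seq_length, -1)
print(attn_output.shape)  # torch.Size([10, 128])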