nathanrchn committed
Commit 2fc0ee1 · 1 Parent(s): 695bee3

Update modeling_phi.py

Files changed (1)
  1. modeling_phi.py +0 -5
modeling_phi.py CHANGED
@@ -355,10 +355,8 @@ class SelfAttention(nn.Module):
         key_padding_mask: Optional[torch.BoolTensor] = None,
         **kwargs,
     ) -> torch.FloatTensor:
-        print(qkv.shape)
         batch_size, seqlen = qkv.shape[0], qkv.shape[1]
         q, k, v = qkv.unbind(dim=2)
-        print(q.shape, k.shape, v.shape)

         q = q.to(torch.float32)
         k = k.to(torch.float32)
@@ -369,7 +367,6 @@ class SelfAttention(nn.Module):
         # Autocast is manually disabled to avoid `torch.einsum` performing the operation
         # using float16, which might lead to overflow
         scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
-        print(scores.shape)

         if key_padding_mask is not None:
             padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device)
@@ -379,14 +376,12 @@ class SelfAttention(nn.Module):

         if causal:
             causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
-            print(causal_mask.shape)
             scores = scores + causal_mask.to(dtype=scores.dtype)

         attention = torch.softmax(scores, dim=-1).to(v.dtype)
         attention = self.drop(attention)

         output = torch.einsum("bhts,bshd->bthd", attention, v)
-        print(output.shape)

         return output
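The commit only deletes debug print statements; the attention computation itself is untouched. For reference, below is a minimal, self-contained sketch of that computation in plain PyTorch. The einsum equations, the float32 upcast, and the -10000.0 masking constants follow the diff context; the standalone naive_self_attention wrapper, the way the padding mask is applied (masked_fill_ plus reshape, which falls outside the shown hunks), the default softmax_scale, and the example inputs are assumptions for illustration only. The module's attention dropout (self.drop) is omitted here.

# Sketch of the attention path shown in the hunks above, not the actual module code.
import math
from typing import Optional

import torch


def naive_self_attention(
    qkv: torch.FloatTensor,                                 # (batch, seqlen, 3, n_heads, head_dim)
    key_padding_mask: Optional[torch.BoolTensor] = None,    # (batch, seqlen), True = valid token
    causal: bool = True,
    softmax_scale: Optional[float] = None,
) -> torch.FloatTensor:
    batch_size, seqlen = qkv.shape[0], qkv.shape[1]
    q, k, v = qkv.unbind(dim=2)

    # Scores are computed in float32 to avoid float16 overflow, as in the diff.
    q = q.to(torch.float32)
    k = k.to(torch.float32)
    # Assumed default; the module receives softmax_scale from elsewhere.
    softmax_scale = softmax_scale or 1.0 / math.sqrt(q.shape[-1])

    scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)

    if key_padding_mask is not None:
        padding_mask = torch.full(
            (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
        )
        # Assumed application of the padding mask (not visible in the shown hunks):
        # valid positions get 0, padded positions keep -10000.
        padding_mask.masked_fill_(key_padding_mask, 0.0)
        scores = scores + padding_mask.reshape(batch_size, 1, 1, seqlen)

    if causal:
        causal_mask = torch.triu(
            torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
        )
        scores = scores + causal_mask.to(dtype=scores.dtype)

    attention = torch.softmax(scores, dim=-1).to(v.dtype)
    # The module applies dropout to `attention` here (self.drop); omitted in this sketch.
    output = torch.einsum("bhts,bshd->bthd", attention, v)
    return output


if __name__ == "__main__":
    # Example inputs: batch=2, seqlen=8, 3 (q/k/v), heads=4, head_dim=16.
    qkv = torch.randn(2, 8, 3, 4, 16)
    out = naive_self_attention(qkv)
    print(out.shape)  # torch.Size([2, 8, 4, 16])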