marriola committed on
Commit
6c64458
·
verified ·
1 Parent(s): 58fce4e

Upload BD3LM

Browse files
Files changed (1) hide show
  1. modeling_bd3lm.py +4 -4
modeling_bd3lm.py CHANGED
@@ -547,9 +547,12 @@ class DITBackbone(nn.Module):
547
  if self.adaln:
548
  c = F.silu(self.sigma_map(sigma))
549
  if self.cross_attn:
 
550
  n = self.mask.shape[-1] // 2
551
  # use block-causal mask only during sampling
552
- if sample_mode:
 
 
553
  if self.blocks[0].kv_cache is not None:
554
  mask = None
555
  accum_length = self.blocks[0].cache_idx + self.block_size
@@ -558,12 +561,9 @@ class DITBackbone(nn.Module):
558
  x.shape[0], accum_length, x.shape[2]), device=x.device)
559
  rotary_cos_sin = self.rotary_emb(x_full)
560
  else:
561
- mask = self.mask.to(x.device)
562
  rotary_cos_sin = self.rotary_emb(x[:, :n])
563
  mask = mask[
564
  n:n+x.shape[1], n:n+x.shape[1]]
565
- else:
566
- rotary_cos_sin = self.rotary_emb(x[:, :self.n])
567
  else:
568
  mask = None
569
  rotary_cos_sin = self.rotary_emb(x)
 
547
  if self.adaln:
548
  c = F.silu(self.sigma_map(sigma))
549
  if self.cross_attn:
550
+ mask = self.mask.to(x.device)
551
  n = self.mask.shape[-1] // 2
552
  # use block-causal mask only during sampling
553
+ if not sample_mode:
554
+ rotary_cos_sin = self.rotary_emb(x[:, :self.n])
555
+ else:
556
  if self.blocks[0].kv_cache is not None:
557
  mask = None
558
  accum_length = self.blocks[0].cache_idx + self.block_size
 
561
  x.shape[0], accum_length, x.shape[2]), device=x.device)
562
  rotary_cos_sin = self.rotary_emb(x_full)
563
  else:
 
564
  rotary_cos_sin = self.rotary_emb(x[:, :n])
565
  mask = mask[
566
  n:n+x.shape[1], n:n+x.shape[1]]
 
 
567
  else:
568
  mask = None
569
  rotary_cos_sin = self.rotary_emb(x)