Upload BD3LM
Browse files- modeling_bd3lm.py +4 -4
modeling_bd3lm.py
CHANGED
@@ -547,9 +547,12 @@ class DITBackbone(nn.Module):
|
|
547 |
if self.adaln:
|
548 |
c = F.silu(self.sigma_map(sigma))
|
549 |
if self.cross_attn:
|
|
|
550 |
n = self.mask.shape[-1] // 2
|
551 |
# use block-causal mask only during sampling
|
552 |
-
if sample_mode:
|
|
|
|
|
553 |
if self.blocks[0].kv_cache is not None:
|
554 |
mask = None
|
555 |
accum_length = self.blocks[0].cache_idx + self.block_size
|
@@ -558,12 +561,9 @@ class DITBackbone(nn.Module):
|
|
558 |
x.shape[0], accum_length, x.shape[2]), device=x.device)
|
559 |
rotary_cos_sin = self.rotary_emb(x_full)
|
560 |
else:
|
561 |
-
mask = self.mask.to(x.device)
|
562 |
rotary_cos_sin = self.rotary_emb(x[:, :n])
|
563 |
mask = mask[
|
564 |
n:n+x.shape[1], n:n+x.shape[1]]
|
565 |
-
else:
|
566 |
-
rotary_cos_sin = self.rotary_emb(x[:, :self.n])
|
567 |
else:
|
568 |
mask = None
|
569 |
rotary_cos_sin = self.rotary_emb(x)
|
|
|
547 |
if self.adaln:
|
548 |
c = F.silu(self.sigma_map(sigma))
|
549 |
if self.cross_attn:
|
550 |
+
mask = self.mask.to(x.device)
|
551 |
n = self.mask.shape[-1] // 2
|
552 |
# use block-causal mask only during sampling
|
553 |
+
if not sample_mode:
|
554 |
+
rotary_cos_sin = self.rotary_emb(x[:, :self.n])
|
555 |
+
else:
|
556 |
if self.blocks[0].kv_cache is not None:
|
557 |
mask = None
|
558 |
accum_length = self.blocks[0].cache_idx + self.block_size
|
|
|
561 |
x.shape[0], accum_length, x.shape[2]), device=x.device)
|
562 |
rotary_cos_sin = self.rotary_emb(x_full)
|
563 |
else:
|
|
|
564 |
rotary_cos_sin = self.rotary_emb(x[:, :n])
|
565 |
mask = mask[
|
566 |
n:n+x.shape[1], n:n+x.shape[1]]
|
|
|
|
|
567 |
else:
|
568 |
mask = None
|
569 |
rotary_cos_sin = self.rotary_emb(x)
|