Upload BD3LM
Browse files - modeling_bd3lm.py (+4 -4)
modeling_bd3lm.py
CHANGED
|
@@ -547,9 +547,12 @@ class DITBackbone(nn.Module):
|
|
| 547 |
if self.adaln:
|
| 548 |
c = F.silu(self.sigma_map(sigma))
|
| 549 |
if self.cross_attn:
|
|
|
|
| 550 |
n = self.mask.shape[-1] // 2
|
| 551 |
# use block-causal mask only during sampling
|
| 552 |
-
if sample_mode:
|
|
|
|
|
|
|
| 553 |
if self.blocks[0].kv_cache is not None:
|
| 554 |
mask = None
|
| 555 |
accum_length = self.blocks[0].cache_idx + self.block_size
|
|
@@ -558,12 +561,9 @@ class DITBackbone(nn.Module):
|
|
| 558 |
x.shape[0], accum_length, x.shape[2]), device=x.device)
|
| 559 |
rotary_cos_sin = self.rotary_emb(x_full)
|
| 560 |
else:
|
| 561 |
-
mask = self.mask.to(x.device)
|
| 562 |
rotary_cos_sin = self.rotary_emb(x[:, :n])
|
| 563 |
mask = mask[
|
| 564 |
n:n+x.shape[1], n:n+x.shape[1]]
|
| 565 |
-
else:
|
| 566 |
-
rotary_cos_sin = self.rotary_emb(x[:, :self.n])
|
| 567 |
else:
|
| 568 |
mask = None
|
| 569 |
rotary_cos_sin = self.rotary_emb(x)
|
|
|
|
| 547 |
if self.adaln:
|
| 548 |
c = F.silu(self.sigma_map(sigma))
|
| 549 |
if self.cross_attn:
|
| 550 |
+
mask = self.mask.to(x.device)
|
| 551 |
n = self.mask.shape[-1] // 2
|
| 552 |
# use block-causal mask only during sampling
|
| 553 |
+
if not sample_mode:
|
| 554 |
+
rotary_cos_sin = self.rotary_emb(x[:, :self.n])
|
| 555 |
+
else:
|
| 556 |
if self.blocks[0].kv_cache is not None:
|
| 557 |
mask = None
|
| 558 |
accum_length = self.blocks[0].cache_idx + self.block_size
|
|
|
|
| 561 |
x.shape[0], accum_length, x.shape[2]), device=x.device)
|
| 562 |
rotary_cos_sin = self.rotary_emb(x_full)
|
| 563 |
else:
|
|
|
|
| 564 |
rotary_cos_sin = self.rotary_emb(x[:, :n])
|
| 565 |
mask = mask[
|
| 566 |
n:n+x.shape[1], n:n+x.shape[1]]
|
|
|
|
|
|
|
| 567 |
else:
|
| 568 |
mask = None
|
| 569 |
rotary_cos_sin = self.rotary_emb(x)
|