Upload BD3LM
modeling_bd3lm.py  CHANGED  (+9 -8)
@@ -302,7 +302,6 @@ class DDiTBlock(nn.Module):
   def __init__(self, n, block_size, dim, n_heads, cond_dim, mlp_ratio=4,
                dropout=0.1, max_seqlen=1024, attn_backend='flash_attn'):
     super().__init__()
-    self.max_seqlen = max_seqlen
     self.n = n
     self.block_size = block_size
     self.n_heads = n_heads
@@ -341,7 +340,7 @@ class DDiTBlock(nn.Module):
     qkv = self.attn_qkv(x)
     # store kv cache in a sliding window (can't exceed context len)
     if store_kv:
-      self.kv_cache = qkv[:, -(self.
+      self.kv_cache = qkv[:, -(self.n-self.block_size):]
 
     qkv = einops.rearrange(
       qkv,
@@ -389,8 +388,9 @@ class DDiTBlock(nn.Module):
 
     # get qkvs
     if mask is not None and not sample_mode:
-
-
+      n = mask.shape[-1] // 2
+      qkv_x = self.get_qkv(x[:,:n], rotary_cos_sin)
+      qkv_x0 = self.get_qkv(x[:,n:], rotary_cos_sin)
       qkv = torch.cat((qkv_x, qkv_x0), dim=1)
     else:
       qkv = self.get_qkv(x, rotary_cos_sin, store_kv=store_kv)
@@ -518,12 +518,13 @@ class DITBackbone(nn.Module):
       all_hidden_states.append(x)
     c = F.silu(self.sigma_map(sigma))
     if self.cross_attn:
-
+      n = self.mask.shape[-1] // 2
+      rotary_cos_sin = self.rotary_emb(x[:, :n])
       mask = self.mask.to(x.device)
       # use block-causal mask only during sampling
       if sample_mode:
         mask = mask[
-
+          n:n+x.shape[1], n:n+x.shape[1]]
     else:
       mask = None
       rotary_cos_sin = self.rotary_emb(x)
@@ -540,8 +541,8 @@ class DITBackbone(nn.Module):
       all_hidden_states.append(x)
     logits = self.output_layer(x, c)
     if self.cross_attn and not sample_mode:
-      logits = logits[:, :
-      all_hidden_states = [hidden_states[:, :
+      logits = logits[:, :n]
+      all_hidden_states = [hidden_states[:, :n] for hidden_states in all_hidden_states]
     return logits, all_hidden_states
 
 class BD3LM(transformers.PreTrainedModel):
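The sliding-window KV cache change in the second hunk keeps only the most recent n - block_size positions of the projected qkv tensor, so the cache plus the next block of block_size new tokens never exceeds the context length n. A minimal sketch of that slice on a dummy tensor (all sizes below are illustrative assumptions, not values from the model config):

import torch

n, block_size = 8, 2            # hypothetical context length and block size
qkv = torch.randn(1, 10, 12)    # (batch, cached positions, projected qkv dim)
kv_cache = qkv[:, -(n - block_size):]   # same slice as new line 343
print(kv_cache.shape)           # torch.Size([1, 6, 12])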
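The DITBackbone changes assume self.mask covers the concatenated noisy/clean sequence of length 2n, so n = self.mask.shape[-1] // 2 recovers the single-sequence length; during sampling only the block-causal quadrant for the current x.shape[1] tokens is sliced out, while outside sample_mode the logits and hidden states are truncated back to the first n positions. A rough sketch of the sampling-time slice, using a plain lower-triangular matrix as a stand-in for the model's block-causal mask (the mask construction and sizes are assumptions for illustration only):

import torch

n, cur_len = 6, 4                               # hypothetical half-length and current sample length
mask = torch.ones(2 * n, 2 * n).tril().bool()   # stand-in for self.mask, not the real buffer
half = mask.shape[-1] // 2                      # same computation as new line 521
sample_mask = mask[half:half + cur_len, half:half + cur_len]  # slice as in new line 527
print(sample_mask.shape)                        # torch.Size([4, 4])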