File size: 27,495 Bytes
6cd6a16 0012f0c 6cd6a16 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 |
import math
from typing import List
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from loguru import logger
# No custom SDPA kernel is wired up (it is None on both paths below), so the
# "sdpa" attention mode in JointAttn always falls through to another backend.
ScaledDotProductAttention = None
try:
    # Optional fused kernels registered by the custom `motif` torch extension.
    motif_ops = torch.ops.motif
    MotifRMSNorm = motif_ops.T5LayerNorm
    MotifFlashAttention = motif_ops.flash_attention
except Exception:  # the `motif` ops are not registered -> use the pure-PyTorch fallbacks
    MotifRMSNorm = MotifFlashAttention = None
# Number of modulation tensors TextTimeEmbToGlobalParams produces per stream
# (MotifDiTBlock unpacks them as pre-attention scale/shift, attention gate,
# pre-MLP scale/shift and MLP gate).
NUM_MODULATIONS: int = 6
# Channel count of the SD3 VAE latent accepted by LatentPatchModule and produced
# by MotifDiT's final projection.
SD3_LATENT_CHANNEL: int = 16
# Default `base_size` of MotifDiT's 2D sin-cos positional embedding when
# config.pos_emb_size is unset: low-res for config.height <= 512, high-res above.
LOW_RES_POSEMB_BASE_SIZE: int = 16
HIGH_RES_POSEMB_BASE_SIZE: int = 64
class IdentityConv2d(nn.Module):
    """Channel-preserving ``nn.Conv2d`` initialised to the identity mapping.

    Output channel ``i`` starts as a copy of input channel ``i``: the kernel is
    all zeros except a ``1.0`` at its spatial centre on the ``(i, i)`` slice, and
    the bias (if present) is zero. With stride 1 and padding that preserves the
    spatial size (the defaults), the layer is an exact identity at
    initialisation and can be trained from there.

    Args:
        channels: number of input and output channels.
        kernel_size, stride, padding, bias: forwarded to ``nn.Conv2d``; int or
            (h, w) tuple where ``nn.Conv2d`` accepts one.
    """

    def __init__(self, channels, kernel_size=3, stride=1, padding=1, bias=True):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )
        self._initialize_identity()

    def _initialize_identity(self):
        """Zero the conv weights/bias and put a 1 at the kernel centre of each (i, i) slice."""
        # nn.Conv2d normalises kernel_size to a (kh, kw) tuple; use both axes so
        # rectangular kernels are centred correctly.
        kernel_h, kernel_w = self.conv.kernel_size
        center_h, center_w = kernel_h // 2, kernel_w // 2
        diag = torch.arange(self.conv.in_channels)
        with torch.no_grad():
            self.conv.weight.zero_()
            self.conv.weight[diag, diag, center_h, center_w] = 1.0
            if self.conv.bias is not None:
                self.conv.bias.zero_()

    def forward(self, x):
        return self.conv(x)
class RMSNorm(nn.Module):
    """Root-mean-square layer norm (same formulation as LlamaRMSNorm / T5LayerNorm).

    The input is normalised over its last dimension in float32, cast back to the
    input dtype, and multiplied by a learnable per-feature ``weight``.

    ``mask`` is ``None`` by default. When it is assigned, features are multiplied
    by the mask and the mean square is taken over the number of non-zero mask
    entries instead of the full feature width.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
        self.mask = None

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        h = hidden_states.to(torch.float)
        mask = self.mask
        if mask is None:
            denom = h.shape[-1]
        else:
            h = mask.to(h.device).to(h.dtype) * h
            denom = torch.count_nonzero(mask)
        mean_sq = h.pow(2).sum(-1, keepdim=True) / denom
        h = h * torch.rsqrt(mean_sq + self.variance_epsilon)
        return self.weight * h.to(orig_dtype)
class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> SiLU -> Linear.

    Args:
        input_size: width of the input and of the output.
        hidden_size: width of the intermediate layer; ``4 * input_size`` when omitted.
    """

    def __init__(self, input_size, hidden_size=None):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = input_size * 4 if hidden_size is None else hidden_size
        self.gate_proj = nn.Linear(self.input_size, self.hidden_size)
        self.down_proj = nn.Linear(self.hidden_size, self.input_size)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        hidden = self.act_fn(self.gate_proj(x))
        return self.down_proj(hidden)
class TextTimeEmbToGlobalParams(nn.Module):
    """Project a conditioning embedding into ``NUM_MODULATIONS`` modulation tensors.

    Args:
        emb_dim: width D of the conditioning embedding.
        hidden_dim: width C of each produced modulation tensor.
    """

    def __init__(self, emb_dim, hidden_dim):
        super().__init__()
        self.projection = nn.Linear(emb_dim, hidden_dim * NUM_MODULATIONS)

    def forward(self, emb):
        """
        Args:
            emb: conditioning embedding, B x D.

        Returns:
            tuple of ``NUM_MODULATIONS`` tensors, each B x 1 x C.
        """
        params = self.projection(F.silu(emb))  # B x (NUM_MODULATIONS * C)
        params = params.reshape(params.shape[0], NUM_MODULATIONS, params.shape[-1] // NUM_MODULATIONS)
        # Split with the same constant that sized the projection so the two cannot
        # drift apart (previously a hard-coded 6).
        return params.chunk(NUM_MODULATIONS, dim=1)
class TextTimeEmbedding(nn.Module):
    """Combine the diffusion timestep with the pooled text embedding.

    Input:
        pooled_text_emb (B x C_l), with C_l == ``text_channel``
        time_steps (B)
    Output:
        sum of the MLP-encoded timestep features and the MLP-encoded pooled
        text features, each of width ``embed_dim``.
    """

    def __init__(self, time_channel, text_channel, embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0):
        super().__init__()
        # Sinusoidal timestep features of width `time_channel`.
        self.time_proj = Timesteps(
            time_channel,
            flip_sin_to_cos=flip_sin_to_cos,
            downscale_freq_shift=downscale_freq_shift,
        )
        # MLP on the timestep features (hidden width 4 * time_channel).
        self.time_emb = TimestepEmbedding(time_channel, time_channel * 4, out_dim=embed_dim)
        # MLP on the pooled text vector (hidden width 4 * text_channel).
        self.pooled_text_emb = TimestepEmbedding(text_channel, text_channel * 4, out_dim=embed_dim)

    def forward(self, pooled_text_emb, time_steps):
        # The sinusoidal features are cast to bfloat16 before the time MLP.
        time_features = self.time_proj(time_steps).to(dtype=torch.bfloat16)
        encoded_time = self.time_emb(time_features)
        encoded_text = self.pooled_text_emb(pooled_text_emb)
        return encoded_time + encoded_text
class LatentPatchModule(nn.Module):
    """Convert SD3-style VAE latents to patch tokens and back.

    Args:
        patch_size: side length p of the square, non-overlapping patches.
        embedding_dim: token width produced by ``forward``.
        latent_channels: channel count C restored by ``unpatchify``.
        vae_type: accepted for config compatibility; not used by this module.
    """

    def __init__(self, patch_size, embedding_dim, latent_channels, vae_type):
        super().__init__()
        self.patch_size = patch_size
        self.embedding_dim = embedding_dim
        # kernel_size == stride == p: one output position per non-overlapping p x p patch.
        self.projection_SD3 = nn.Conv2d(SD3_LATENT_CHANNEL, embedding_dim, kernel_size=patch_size, stride=patch_size)
        self.latent_channels = latent_channels

    def forward(self, x):
        """(B, SD3_LATENT_CHANNEL, H, W) latents -> (B, (H // p) * (W // p), embedding_dim) tokens."""
        assert (
            x.shape[1] == SD3_LATENT_CHANNEL
        ), f"VAE-Latent channel is not matched with '{SD3_LATENT_CHANNEL}'. current shape: {x.shape}"
        patches = self.projection_SD3(x.to(dtype=torch.bfloat16))  # (B, embedding_dim, H // p, W // p)
        patches = patches.to(dtype=torch.bfloat16)
        # (B, D, H // p, W // p) -> (B, D, num_patches) -> (B, num_patches, D)
        return patches.flatten(2).transpose(1, 2).contiguous()

    def unpatchify(self, x, grid_size=None):
        """Reassemble patch tokens into a latent image.

        Args:
            x: (N, T, patch_size**2 * C) tokens; within each token the layout is
                (p1, p2, C), channel-last.
            grid_size: optional ``(h, w)`` patch-grid size with ``h * w == T``.
                When omitted the grid is assumed square (``h == w == sqrt(T)``).

        Returns:
            (N, C, h * patch_size, w * patch_size) tensor.
        """
        n, num_tokens = x.shape[0], x.shape[1]
        c = self.latent_channels
        p = self.patch_size
        if grid_size is None:
            # Integer square root: exact for any token count, unlike int(T ** 0.5).
            h = w = math.isqrt(num_tokens)
        else:
            h, w = grid_size
        assert h * w == num_tokens, f"patch grid {h}x{w} does not match {num_tokens} tokens"
        # (N x T x [C * patch_size**2]) -> (N x h x w x P_1 x P_2 x C)
        x = x.contiguous().reshape(n, h, w, p, p, c)
        # (N x h x w x P_1 x P_2 x C) -> (N x C x h x P_1 x w x P_2)
        x = x.permute(0, 5, 1, 3, 2, 4)
        return x.reshape(n, c, h * p, w * p).contiguous()
class TextConditionModule(nn.Module):
    """Build the text-token sequence from T5 and two CLIP encoders and project it.

    The two CLIP embeddings are concatenated along the feature axis and
    right-padded with zeros to the T5 feature width (a negative pad amount would
    crop instead). They are placed before the T5 tokens along the sequence axis,
    and the result is cast to bfloat16 and projected to ``latent_dim``.
    """

    def __init__(self, text_dim, latent_dim):
        super().__init__()
        self.projection = nn.Linear(text_dim, latent_dim)

    def forward(self, t5_xxl, clip_a, clip_b):
        clip_tokens = torch.cat((clip_a, clip_b), dim=-1)
        pad_right = t5_xxl.shape[-1] - clip_tokens.shape[-1]
        clip_tokens = F.pad(clip_tokens, (0, pad_right))
        tokens = torch.cat((clip_tokens, t5_xxl), dim=-2)
        return self.projection(tokens.to(torch.bfloat16))
class MotifDiTBlock(nn.Module):
    """Joint text/image transformer block with conditioning-driven modulation.

    The image stream (``x``) and text stream (``c``) share one ``JointAttn``. Each
    stream has its own LayerNorms, pre-attention projection, MLP and modulation
    head (``TextTimeEmbToGlobalParams``).

    Args:
        emb_dim: token width of both streams.
        t_emb_dim: width of the conditioning vector ``t_emb``.
        attn_emb_dim: width of the tensors handed to the attention layer.
        mlp_dim: hidden width of each stream's MLP.
        attn_config: configuration object passed to ``JointAttn``.
        text_dim: not used by this block; kept for signature compatibility.
    """

    def __init__(self, emb_dim, t_emb_dim, attn_emb_dim, mlp_dim, attn_config, text_dim=4096):
        super().__init__()
        # Submodules are created in a fixed order; keep it, since it determines
        # parameter-initialisation RNG consumption and state_dict ordering.
        self.affine_params_c = TextTimeEmbToGlobalParams(t_emb_dim, emb_dim)
        self.affine_params_x = TextTimeEmbToGlobalParams(t_emb_dim, emb_dim)
        self.norm_1_c = nn.LayerNorm(emb_dim, elementwise_affine=False)
        self.norm_1_x = nn.LayerNorm(emb_dim, elementwise_affine=False)
        self.linear_1_c = nn.Linear(emb_dim, attn_emb_dim)
        self.linear_1_x = nn.Linear(emb_dim, attn_emb_dim)
        self.attn = JointAttn(attn_config)
        self.norm_2_c = nn.LayerNorm(emb_dim, elementwise_affine=False)
        self.norm_2_x = nn.LayerNorm(emb_dim, elementwise_affine=False)
        self.mlp_3_c = MLP(emb_dim, mlp_dim)
        self.mlp_3_x = MLP(emb_dim, mlp_dim)

    def forward(self, x_emb, c_emb, t_emb, perturbed=False):
        """
        Args:
            x_emb: image-stream tokens (N, T_x, C).
            c_emb: text-stream tokens (N, T_c, C).
            t_emb: conditioning vector (N, modulation_dim).
            perturbed: passed positionally to the attention layer.

        Returns:
            (x_out, c_out), both placed on ``x_emb``'s device.
        """
        out_device = x_emb.device

        # Six modulation tensors per stream, in this order: pre-attention scale,
        # pre-attention shift, attention gate, pre-MLP scale, pre-MLP shift, MLP gate.
        x_mods = self.affine_params_x(t_emb)
        c_mods = self.affine_params_c(t_emb)
        x_attn_scale, x_attn_shift, x_attn_gate, x_mlp_scale, x_mlp_shift, x_mlp_gate = x_mods
        c_attn_scale, c_attn_shift, c_attn_gate, c_mlp_scale, c_mlp_shift, c_mlp_gate = c_mods

        # Modulated LayerNorm, then projection into the attention width.
        x_attn_in = self.linear_1_x((1 + x_attn_scale) * self.norm_1_x(x_emb) + x_attn_shift)
        c_attn_in = self.linear_1_c((1 + c_attn_scale) * self.norm_1_c(c_emb) + c_attn_shift)

        x_attn_out, c_attn_out = self.attn(x_attn_in, c_attn_in, perturbed)

        # Gated residual around attention: (1 + gate) * attn_out + block input.
        x_resid = (1 + x_attn_gate) * x_attn_out.to(x_attn_gate.device) + x_emb
        c_resid = (1 + c_attn_gate) * c_attn_out.to(c_attn_gate.device) + c_emb

        # Pre-MLP modulation; here the scale multiplies directly (no "1 +").
        x_normed = self.norm_2_x(x_resid).to(x_mlp_scale.device)
        c_normed = self.norm_2_c(c_resid).to(c_mlp_scale.device)
        x_mlp_out = self.mlp_3_x(x_mlp_scale * x_normed + x_mlp_shift)
        c_mlp_out = self.mlp_3_c(c_mlp_scale * c_normed + c_mlp_shift)

        # NOTE(review): the gated MLP output is added to the block *input*
        # (x_emb / c_emb), not to the post-attention residual; confirm intended.
        x_out = x_mlp_gate.to(out_device) * x_mlp_out.to(out_device) + x_emb.to(out_device)
        c_out = c_mlp_gate.to(out_device) * c_mlp_out.to(out_device) + c_emb.to(out_device)
        return x_out, c_out
class MotifDiT(nn.Module):
    """Multi-modal diffusion transformer over SD3-style VAE latents.

    ``forward`` builds three inputs:
    - text tokens from T5 + two CLIP encoders (TextConditionModule);
    - a timestep/pooled-text conditioning vector (TextTimeEmbedding);
    - patchified latent tokens plus a fixed 2D sin-cos positional embedding and
      optional register tokens.

    These pass through ``config.num_layers`` MotifDiTBlocks. The image stream is
    then shift-and-scale modulated, projected to SD3_LATENT_CHANNEL * patch_size**2
    per token, and un-patchified.
    """

    # Feature width expected for the T5 text tokens (TextConditionModule's input width).
    ENCODED_TEXT_DIM = 4096

    def __init__(self, config):
        super(MotifDiT, self).__init__()
        self.patch_size = config.patch_size
        # config.height / config.width divided by the VAE compression factor.
        self.h, self.w = config.height // config.vae_compression, config.width // config.vae_compression
        # (sic) attribute name kept as-is; the value matches SD3_LATENT_CHANNEL.
        self.latent_chennels = 16
        # Embedding for (1) text; (2) input image; (3) time
        self.text_cond = TextConditionModule(self.ENCODED_TEXT_DIM, config.hidden_dim)
        self.patching = LatentPatchModule(config.patch_size, config.hidden_dim, self.latent_chennels, config.vae_type)
        self.time_emb = TextTimeEmbedding(config.time_embed_dim, config.pooled_text_dim, config.modulation_dim)
        # main multi-modal DiT blocks
        self.mmdit_blocks = nn.ModuleList(
            [
                MotifDiTBlock(
                    config.hidden_dim, config.modulation_dim, config.hidden_dim, config.mlp_hidden_dim, config
                )
                for layer_idx in range(config.num_layers)
            ]
        )
        # Produces (scale, shift) for the final modulation of the image stream.
        self.final_modulation = nn.Linear(config.modulation_dim, config.hidden_dim * 2)
        self.final_linear_SD3 = nn.Linear(config.hidden_dim, SD3_LATENT_CHANNEL * config.patch_size**2)
        # Register tokens are stripped after block index num_layers - 1 - skip_register_token_num.
        self.skip_register_token_num = config.skip_register_token_num
        if getattr(config, "pos_emb_size", None):
            pos_emb_size = config.pos_emb_size
        else:
            pos_emb_size = HIGH_RES_POSEMB_BASE_SIZE if config.height > 512 else LOW_RES_POSEMB_BASE_SIZE
        logger.info(f"Positional embedding of Motif-DiT is set to {pos_emb_size}")
        # NOTE(review): created directly on "cuda" and stored as a plain tensor
        # attribute (not a registered buffer), so it needs a CUDA device at
        # construction and is not moved or cast by module.to()/.cuda(i).
        self.pos_embed = torch.from_numpy(
            get_2d_sincos_pos_embed(
                config.hidden_dim, (self.h // self.patch_size, self.w // self.patch_size), base_size=pos_emb_size
            )
        ).to(device="cuda", dtype=torch.bfloat16)
        # set register tokens (https://arxiv.org/abs/2309.16588)
        if config.register_token_num > 0:
            self.register_token_num = config.register_token_num
            self.register_tokens = nn.Parameter(torch.randn(1, self.register_token_num, config.hidden_dim))
            # Redundant: assigning an nn.Parameter above already registers it.
            self.register_parameter("register_tokens", self.register_tokens)
            # if needed, add additional register tokens for higher resolution training
            self.additional_register_token_num = config.additional_register_token_num
            if config.additional_register_token_num > 0:
                self.register_tokens_highres = nn.Parameter(
                    torch.randn(1, self.additional_register_token_num, config.hidden_dim)
                )
                self.register_parameter("register_tokens_highres", self.register_tokens_highres)
        if config.use_final_layer_norm:
            self.final_norm = nn.LayerNorm(config.hidden_dim)
        if config.conv_header:
            logger.info("use convolution header after de-patching")
            self.depatching_conv_header = IdentityConv2d(SD3_LATENT_CHANNEL)
        if config.use_time_token_in_attn:
            self.t_token_proj = nn.Linear(config.modulation_dim, config.hidden_dim)

    def forward(self, latent, t, text_embs: List[torch.Tensor], pooled_text_embs, guiding_feature=None):
        """
        Args:
            latent (torch.Tensor): VAE latents; LatentPatchModule asserts SD3_LATENT_CHANNEL channels.
            t (torch.Tensor): diffusion timesteps, passed to TextTimeEmbedding.
            text_embs (List[torch.Tensor]): unpacked into TextConditionModule as
                ``(t5_xxl, clip_a, clip_b)``.
            pooled_text_embs (torch.Tensor): pooled text embedding for TextTimeEmbedding.
            guiding_feature: only used by the optional feature-alignment branch.

        Returns:
            torch.Tensor: the un-patchified output, shaped (N, C, H, W) by ``unpatchify``.
        """
        # 1. get inputs for the MMDiT blocks
        emb_c = self.text_cond(*text_embs)  # (N, L, D), text conditions
        emb_t = self.time_emb(pooled_text_embs, t).to(emb_c.device)  # (N, D), time and pooled text conditions
        emb_x = (self.patching(latent) + self.pos_embed).to(
            emb_c.device
        )  # (N, T, D), where T = H*W / (patch_size ** 2), input latent patches
        # additional "register" tokens, to convey the global information and prevent high-norm abnormal patch
        # see https://openreview.net/forum?id=2dnO3LLiJ1
        # Prepended order: [highres registers, registers, image patches].
        if hasattr(self, "register_tokens"):
            if hasattr(self, "register_tokens_highres"):
                emb_x = torch.cat(
                    (
                        self.register_tokens_highres.expand(emb_x.shape[0], -1, -1),
                        self.register_tokens.expand(emb_x.shape[0], -1, -1),
                        emb_x,
                    ),
                    dim=1,
                )
            else:
                emb_x = torch.cat((self.register_tokens.expand(emb_x.shape[0], -1, -1), emb_x), dim=1)
        # time embedding into text embedding
        # NOTE(review): __init__ never sets `use_time_token_in_attn` (it only creates
        # `t_token_proj`), so unless that attribute is set externally this branch does
        # not run even when config.use_time_token_in_attn is true. Trained weights may
        # depend on the current behavior; confirm before changing.
        if hasattr(self, "use_time_token_in_attn"):
            t_token = self.t_token_proj(emb_t).unsqueeze(1)
            emb_c = torch.cat([emb_c, t_token], dim=1)  # (N, [T_c + 1], C)
        # 2. MMDiT Blocks
        for block_idx, block in enumerate(self.mmdit_blocks):
            emb_x, emb_c = block(emb_x, emb_c, emb_t)
            # accumulating the feature_similarity loss
            # TODO: add modeling_dit related test
            # NOTE(review): `num_feature_align_layers` / `feature_align_mlp` are not
            # created in __init__; this runs only if they are attached externally.
            if hasattr(self, "num_feature_align_layers") and block_idx == self.num_feature_align_layers:
                self.feature_alignment_loss = self.feature_align_mlp(emb_x, guiding_feature)  # exclude register tokens
            # Remove the register tokens at the certain layer (the last layer as default).
            if block_idx == len(self.mmdit_blocks) - (1 + self.skip_register_token_num):
                if hasattr(self, "register_tokens_highres"):
                    emb_x = emb_x[
                        :, self.register_token_num + self.additional_register_token_num :
                    ]  # remove the register tokens for the output layer
                elif hasattr(self, "register_tokens"):
                    emb_x = emb_x[:, self.register_token_num :]  # remove the register tokens for the output layer
        # 3. final modulation (shift-and-scale)
        scale, shift = self.final_modulation(emb_t).chunk(2, -1)  # (N, D) x 2
        scale, shift = scale.unsqueeze(1), shift.unsqueeze(1)  # (N, 1, D) x 2
        if hasattr(self, "final_norm"):
            emb_x = self.final_norm(emb_x)
        final_emb = (scale + 1) * emb_x + shift
        # 4. final linear layer to reduce channel and un-patching
        emb_x = self.final_linear_SD3(final_emb)  # (N, T, D) to (N, T, out_channels * patch_size**2)
        emb_x = self.patching.unpatchify(emb_x)  # (N, out_channels, H, W)
        if hasattr(self, "depatching_conv_header"):
            emb_x = self.depatching_conv_header(emb_x)
        return emb_x
class JointAttn(nn.Module):
    """
    SD3 style joint-attention layer.

    Sample (image) tokens and context (text) tokens have separate Q/K/V and output
    projections and separate per-head Q/K RMS norms, but attend jointly over the
    concatenated sequence ``[sample; context]``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_dim
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.add_q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.add_k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.add_v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.add_o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.q_norm_x = MotifRMSNorm(self.head_dim) if MotifRMSNorm else RMSNorm(self.head_dim)
        self.k_norm_x = MotifRMSNorm(self.head_dim) if MotifRMSNorm else RMSNorm(self.head_dim)
        self.q_norm_c = MotifRMSNorm(self.head_dim) if MotifRMSNorm else RMSNorm(self.head_dim)
        self.k_norm_c = MotifRMSNorm(self.head_dim) if MotifRMSNorm else RMSNorm(self.head_dim)
        # Per-head multiplier applied to the normalised context (text) queries only.
        self.q_scale = nn.Parameter(torch.ones(self.num_heads))
        # Attention mode : {'sdpa', 'flash', None}. A mode whose kernel is unavailable
        # falls back to torch.nn.functional.scaled_dot_product_attention.
        self.attn_mode = config.attn_mode

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        """
        Args:
            hidden_states: sample tokens (B, L, C), or a (B, C, H, W) map.
            encoder_hidden_states: context tokens (B, L_c, C), or a (B, C, H, W) map.
            *args, **kwargs: accepted and ignored (e.g. a positional `perturbed` flag).

        Returns:
            tuple ``(hidden_states, encoder_hidden_states)`` after joint attention
            and the output projections, each in the same layout as its input.
        """
        # Each stream keeps its own original shape so 4D outputs are restored correctly.
        sample_shape = hidden_states.shape
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = sample_shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        context_shape = encoder_hidden_states.shape
        context_input_ndim = encoder_hidden_states.ndim
        if context_input_ndim == 4:
            batch_size, channel, height, width = context_shape
            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        batch_size = encoder_hidden_states.shape[0]
        # Number of sample tokens, measured after flattening: a 4D input must be
        # split at H * W, not at its channel count.
        sample_len = hidden_states.shape[1]

        # `sample` projections.
        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)
        # `context` projections.
        query_c = self.add_q_proj(encoder_hidden_states)
        key_c = self.add_k_proj(encoder_hidden_states)
        value_c = self.add_v_proj(encoder_hidden_states)

        # head first
        inner_dim = key.shape[-1]
        head_dim = inner_dim // self.num_heads

        def norm_qk(x, f_norm):
            # [b, l, h*d] -> [b*l, h, d] so the norm acts on the last (per-head) axis,
            # then -> [b, h, l, d].
            x = x.view(batch_size, -1, self.num_heads, head_dim)
            b, l, h, d_h = x.shape
            x = x.reshape(b * l, h, d_h)
            x = f_norm(x)
            return x.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)  # [b, h, l, d_h]

        query = norm_qk(query, self.q_norm_x)  # [b, h, l, d_h]
        key = norm_qk(key, self.k_norm_x)  # [b, h, l, d_h]
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)  # [b, h, l, d_h]
        query_c = norm_qk(query_c, self.q_norm_c) * self.q_scale.reshape(1, self.num_heads, 1, 1)  # [b, h, l_c, d]
        key_c = norm_qk(key_c, self.k_norm_c)  # [b, h, l_c, d]
        value_c = value_c.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)  # [b, h, l_c, d]

        # attention over the concatenated [sample; context] sequence
        query = torch.cat([query, query_c], dim=2).contiguous()  # [b, h, l + l_c, d]
        key = torch.cat([key, key_c], dim=2).contiguous()  # [b, h, l + l_c, d]
        value = torch.cat([value, value_c], dim=2).contiguous()  # [b, h, l + l_c, d]
        hidden_states = self.joint_attention(batch_size, query, key, value, head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Split the attention outputs back into the two streams.
        hidden_states, encoder_hidden_states = (
            hidden_states[:, :sample_len],
            hidden_states[:, sample_len:],
        )
        # linear proj
        hidden_states = self.o_proj(hidden_states)
        encoder_hidden_states = self.add_o_proj(encoder_hidden_states)
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(sample_shape)
        if context_input_ndim == 4:
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(context_shape)
        return hidden_states, encoder_hidden_states

    def joint_attention(self, batch_size, query, key, value, head_dim):
        """Run attention over the joint sequence.

        Args:
            batch_size: batch size b.
            query, key, value: [b, h, L, d] tensors.
            head_dim: per-head width d.

        Returns:
            [b, L, h * d] attention output for the "flash" and fallback paths.
        """
        if self.attn_mode == "sdpa" and ScaledDotProductAttention is not None:
            # NOTE: SDPA does not support high-resolution (long-context).
            q_len = query.size(-2)
            masked_bias = torch.zeros(
                (batch_size, self.num_heads, query.size(-2), key.size(-2)), device=query.device
            )
            query = query.transpose(1, 2).reshape(batch_size, q_len, self.hidden_size).contiguous()
            key = key.transpose(1, 2).reshape(batch_size, q_len, self.hidden_size).contiguous()
            value = value.transpose(1, 2).reshape(batch_size, q_len, self.hidden_size).contiguous()
            scale_factor = 1.0
            scale_factor /= float(self.head_dim) ** 0.5
            hidden_states = ScaledDotProductAttention(
                query,
                key,
                value,
                masked_bias,
                dropout_rate=0.0,
                training=self.training,
                attn_weight_scale_factor=scale_factor,
                num_kv_groups=1,
            )
        elif self.attn_mode == "flash" and MotifFlashAttention is not None:
            query = query.permute(0, 2, 1, 3).contiguous()  # [b, l + l_c, h, d]
            key = key.permute(0, 2, 1, 3).contiguous()  # [b, l + l_c, h, d]
            value = value.permute(0, 2, 1, 3).contiguous()  # [b, l + l_c, h, d]
            scale_factor = 1.0 / math.sqrt(self.head_dim)
            # NOTE (1): masking of motif flash-attention uses (`1`: un-mask, `0`: mask) and has [Batch, Seq] shape
            # NOTE (2): Q,K,V must be [Batch, Seq, Heads, Dim] and contiguous.
            # The mask is created on the inputs' device instead of the default CUDA device.
            mask = torch.ones((batch_size, query.size(-3)), device=query.device)
            hidden_states = MotifFlashAttention(
                query,
                key,
                value,
                padding_mask=mask,
                softmax_scale=scale_factor,
                causal=False,
            )
            hidden_states = hidden_states.reshape(batch_size, -1, self.num_heads * head_dim).contiguous()
        else:
            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0)
            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * head_dim)
        return hidden_states

    @staticmethod
    def alt_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, scale=None) -> torch.Tensor:
        """
        Pure-pytorch version of the xformers.scaled_dot_product_attention
        (or F.scaled_dot_product_attention from torch>2.0.0)

        Args:
            query (Tensor): query tensor
            key (Tensor): key tensor
            value (Tensor): value tensor
            attn_mask (Tensor, optional): boolean mask (True = keep) or additive bias,
                broadcastable to (L, S). Defaults to None.
            dropout_p (float, optional): attention dropout probability; applied whenever
                > 0, regardless of train/eval mode. Defaults to 0.0.
            scale (Tensor or float, optional): scaling for QK; 1/sqrt(head_dim) when None.

        Returns:
            torch.Tensor: attention output ``softmax(Q K^T * scale + bias) @ V``.
        """
        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
        attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask
        attn_weight = query @ key.transpose(-2, -1) * scale_factor  # ..., L, S
        attn_weight += attn_bias
        attn_weight = torch.softmax(attn_weight, dim=-1)  # ..., L, S
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
        return attn_weight @ value  # (..., L, S) @ (..., S, D) -> (..., L, D)
# ===============================================
# Sine/Cosine Positional Embedding Functions
# ===============================================
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, scale=1.0, base_size=None):
    """Build a fixed 2D sine-cosine positional-embedding table.

    Args:
        embed_dim: embedding width D (must be even).
        grid_size: int for a square grid, or an ``(h, w)`` tuple.
        cls_token: together with ``extra_tokens > 0``, prepend ``extra_tokens`` zero rows.
        extra_tokens: number of zero rows to prepend (only used with ``cls_token``).
        scale: positions are divided by this factor.
        base_size: if given, each axis' coordinates are rescaled to span
            ``base_size`` units regardless of the actual grid length.

    Returns:
        np.ndarray of shape ``(h * w, D)``, or ``(extra_tokens + h * w, D)`` with zero rows prepended.
    """
    if not isinstance(grid_size, tuple):
        grid_size = (grid_size, grid_size)
    size_h, size_w = grid_size[0], grid_size[1]

    coords_h = np.arange(size_h, dtype=np.float32) / scale
    coords_w = np.arange(size_w, dtype=np.float32) / scale
    if base_size is not None:
        coords_h *= base_size / size_h
        coords_w *= base_size / size_w

    # meshgrid(w, h) yields [w-coordinates, h-coordinates]; the w axis comes first.
    grid = np.stack(np.meshgrid(coords_w, coords_h), axis=0).reshape([2, 1, size_w, size_h])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)

    if cls_token and extra_tokens > 0:
        leading_zeros = np.zeros([extra_tokens, embed_dim])
        pos_embed = np.concatenate([leading_zeros, pos_embed], axis=0)
    return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a two-channel coordinate grid with one 1D sin-cos embedding per axis.

    The first ``embed_dim // 2`` columns encode ``grid[0]`` and the remaining
    columns encode ``grid[1]``. ``get_2d_sincos_pos_embed`` stacks the grid as
    (w, h), so ``grid[0]`` carries the horizontal coordinates there.

    Returns:
        (M, embed_dim) array, where M is the number of elements in ``grid[0]``.
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    per_axis = [get_1d_sincos_pos_embed_from_grid(half_dim, axis_coords) for axis_coords in (grid[0], grid[1])]
    return np.concatenate(per_axis, axis=1)  # (M, D)
def get_1d_sincos_pos_embed(embed_dim, length, scale=1.0):
    """Sin-cos embedding of the positions ``0 .. length - 1`` divided by ``scale``.

    Returns:
        (length, embed_dim) array.
    """
    positions = np.arange(length).reshape(-1, 1) / scale
    return get_1d_sincos_pos_embed_from_grid(embed_dim, positions)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Map positions to ``[sin(pos * w_k), cos(pos * w_k)]`` features.

    The frequencies are ``w_k = 1 / 10000 ** (k / (embed_dim / 2))`` for
    ``k = 0 .. embed_dim / 2 - 1``.

    Args:
        embed_dim: output width D (must be even).
        pos: array of positions of any shape; it is flattened to (M,).

    Returns:
        (M, D) array: sine features in the first half, cosine features in the second.
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    exponents = np.arange(half_dim, dtype=np.float64) / (embed_dim / 2.0)
    inv_freq = 1.0 / 10000**exponents  # (D/2,)
    angles = np.outer(pos.reshape(-1), inv_freq)  # (M, D/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
|