# Motif-Image-6B-Preview / models / modeling_dit.py
import math
from typing import List, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from loguru import logger
# Prefer the custom `motif` torch extension ops when they are registered; otherwise
# fall back to the pure-PyTorch implementations used elsewhere in this module.
try:
    motif_ops = torch.ops.motif
    MotifRMSNorm = motif_ops.T5LayerNorm
    ScaledDotProductAttention = None
    MotifFlashAttention = motif_ops.flash_attention
except (ImportError, AttributeError, RuntimeError):  # custom motif ops are not available
    MotifRMSNorm = None
    ScaledDotProductAttention = None
    MotifFlashAttention = None
NUM_MODULATIONS = 6  # adaLN modulation vectors per stream: scale/shift/gate for the attention and MLP branches
SD3_LATENT_CHANNEL = 16  # channel count of the SD3 VAE latent
LOW_RES_POSEMB_BASE_SIZE = 16  # positional-embedding base grid size for <= 512px resolution
HIGH_RES_POSEMB_BASE_SIZE = 64  # positional-embedding base grid size for > 512px resolution
class IdentityConv2d(nn.Module):
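    """
    Conv2d initialized as an identity mapping: all weights are zero except a 1.0 at
    the kernel center of each channel's own filter, and the bias starts at zero, so
    the layer is a no-op at initialization and can learn a refinement during training.
    """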
def __init__(self, channels, kernel_size=3, stride=1, padding=1, bias=True):
super().__init__()
self.conv = nn.Conv2d(
in_channels=channels,
out_channels=channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=bias,
)
self._initialize_identity()
def _initialize_identity(self):
k = self.conv.kernel_size[0]
nn.init.zeros_(self.conv.weight)
center = k // 2
for i in range(self.conv.in_channels):
self.conv.weight.data[i, i, center, center] = 1.0
if self.conv.bias is not None:
nn.init.zeros_(self.conv.bias)
def forward(self, x):
return self.conv(x)
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
self.mask = None
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float)
if self.mask is not None:
hidden_states = self.mask.to(hidden_states.device).to(hidden_states.dtype) * hidden_states
variance = hidden_states.pow(2).sum(-1, keepdim=True)
if self.mask is not None:
variance /= torch.count_nonzero(self.mask)
else:
variance /= hidden_states.shape[-1]
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class MLP(nn.Module):
def __init__(self, input_size, hidden_size=None):
super().__init__()
if hidden_size is None:
self.input_size, self.hidden_size = input_size, input_size * 4
else:
self.input_size, self.hidden_size = input_size, hidden_size
self.gate_proj = nn.Linear(self.input_size, self.hidden_size)
self.down_proj = nn.Linear(self.hidden_size, self.input_size)
self.act_fn = nn.SiLU()
def forward(self, x):
down_proj = self.act_fn(self.gate_proj(x))
down_proj = self.down_proj(down_proj)
return down_proj
class TextTimeEmbToGlobalParams(nn.Module):
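    """
    Projects the combined text/time embedding into NUM_MODULATIONS (= 6) per-block
    modulation vectors. MotifDiTBlock unpacks them as (alpha, beta, gamma, delta,
    epsilon, zeta): scale/shift before attention, a gate after attention, scale/shift
    before the MLP, and a gate after the MLP (adaLN-style modulation).
    """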
def __init__(self, emb_dim, hidden_dim):
super().__init__()
self.projection = nn.Linear(emb_dim, hidden_dim * NUM_MODULATIONS)
    def forward(self, emb):
        emb = F.silu(emb)  # (B, D)
        params = self.projection(emb)  # (B, NUM_MODULATIONS * C)
        params = params.reshape(params.shape[0], NUM_MODULATIONS, params.shape[-1] // NUM_MODULATIONS)  # (B, 6, C)
        return params.chunk(NUM_MODULATIONS, dim=1)  # 6 x (B, 1, C)
class TextTimeEmbedding(nn.Module):
"""
Input:
pooled_text_emb (B x C_l)
time_steps (B)
Output:
()
"""
def __init__(self, time_channel, text_channel, embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0):
super().__init__()
self.time_proj = Timesteps(
time_channel, flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=downscale_freq_shift
)
self.time_emb = TimestepEmbedding(time_channel, time_channel * 4, out_dim=embed_dim) # Encode time emb with MLP
self.pooled_text_emb = TimestepEmbedding(
text_channel, text_channel * 4, out_dim=embed_dim
) # Encode pooled text with MLP
def forward(self, pooled_text_emb, time_steps):
time_steps = self.time_proj(time_steps)
time_emb = self.time_emb(time_steps.to(dtype=torch.bfloat16))
pooled_text_emb = self.pooled_text_emb(pooled_text_emb)
return time_emb + pooled_text_emb
class LatentPatchModule(nn.Module):
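    """
    Patchifies an SD3 VAE latent into a token sequence and provides the inverse
    operation (`unpatchify`). Shape walk-through (illustrative, patch_size=2 and a
    64x64 latent): forward maps (B, 16, 64, 64) -> (B, 1024, embedding_dim), and
    unpatchify maps (B, 1024, 2 * 2 * 16) back to (B, 16, 64, 64).
    """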
def __init__(self, patch_size, embedding_dim, latent_channels, vae_type):
super().__init__()
self.patch_size = patch_size
self.embedding_dim = embedding_dim
self.projection_SD3 = nn.Conv2d(SD3_LATENT_CHANNEL, embedding_dim, kernel_size=patch_size, stride=patch_size)
self.latent_channels = latent_channels
def forward(self, x):
        assert (
            x.shape[1] == SD3_LATENT_CHANNEL
        ), f"Expected a VAE latent with {SD3_LATENT_CHANNEL} channels; got shape {x.shape}"
patches = self.projection_SD3(
x.to(dtype=torch.bfloat16)
) # Shape: (B, embedding_dim, num_patches_h, num_patches_w)
patches = patches.to(dtype=torch.bfloat16)
patches = patches.contiguous()
patches = patches.flatten(2) # Shape: (B, embedding_dim, num_patches)
patches = patches.transpose(1, 2) # Shape: (B, num_patches, embedding_dim)
patches = patches.contiguous()
return patches
def unpatchify(self, x):
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
n = x.shape[0]
c = self.latent_channels
p = self.patch_size
        # infer a square patch grid and check that it matches the token count
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]
        x = x.contiguous()
        # (N, T, C * patch_size**2) -> (N, H, W, P, P, C)
        x = x.reshape(shape=(n, h, w, p, p, c))
        # (N, H, W, P, P, C) -> (N, C, H, P, W, P); equivalent to einsum('nhwpqc->nchpwq')
        x = x.permute(0, 5, 1, 3, 2, 4)
        return x.reshape(shape=(n, c, h * p, w * p)).contiguous()
class TextConditionModule(nn.Module):
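    """
    Builds the token-level text conditioning: the two CLIP embeddings are concatenated
    along the channel dimension, zero-padded up to the T5-XXL width, stacked with the
    T5-XXL tokens along the sequence dimension, and projected to the DiT latent width.
    """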
def __init__(self, text_dim, latent_dim):
super().__init__()
self.projection = nn.Linear(text_dim, latent_dim)
def forward(self, t5_xxl, clip_a, clip_b):
clip_emb = torch.cat([clip_a, clip_b], dim=-1)
clip_emb = torch.nn.functional.pad(clip_emb, (0, t5_xxl.shape[-1] - clip_emb.shape[-1]))
text_emb = torch.cat([clip_emb, t5_xxl], dim=-2)
text_emb = self.projection(text_emb.to(torch.bfloat16))
return text_emb
class MotifDiTBlock(nn.Module):
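    """
    SD3/MMDiT-style dual-stream block: image tokens (x) and text tokens (c) are
    modulated with adaLN parameters derived from the text/time embedding, mixed by a
    joint attention layer, and refined by per-stream MLPs, with gated residuals
    around both sub-layers.
    """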
def __init__(self, emb_dim, t_emb_dim, attn_emb_dim, mlp_dim, attn_config, text_dim=4096):
super().__init__()
self.affine_params_c = TextTimeEmbToGlobalParams(t_emb_dim, emb_dim)
self.affine_params_x = TextTimeEmbToGlobalParams(t_emb_dim, emb_dim)
self.norm_1_c = nn.LayerNorm(emb_dim, elementwise_affine=False)
self.norm_1_x = nn.LayerNorm(emb_dim, elementwise_affine=False)
self.linear_1_c = nn.Linear(emb_dim, attn_emb_dim)
self.linear_1_x = nn.Linear(emb_dim, attn_emb_dim)
self.attn = JointAttn(attn_config)
self.norm_2_c = nn.LayerNorm(emb_dim, elementwise_affine=False)
self.norm_2_x = nn.LayerNorm(emb_dim, elementwise_affine=False)
self.mlp_3_c = MLP(emb_dim, mlp_dim)
self.mlp_3_x = MLP(emb_dim, mlp_dim)
def forward(self, x_emb, c_emb, t_emb, perturbed=False):
"""
x_emb (N, TOKEN_LENGTH x 2, C)
c_emb (N, T + REGISTER_TOKENS, C)
t_emb (N, modulation_dim)
"""
device = x_emb.device
# get global affine transformation parameters
alpha_x, beta_x, gamma_x, delta_x, epsilon_x, zeta_x = self.affine_params_x(t_emb) # scale and shift for image
alpha_c, beta_c, gamma_c, delta_c, epsilon_c, zeta_c = self.affine_params_c(t_emb) # scale and shift for text
# projection and affine transform before attention
x_emb_pre_attn = self.linear_1_x((1 + alpha_x) * self.norm_1_x(x_emb) + beta_x)
c_emb_pre_attn = self.linear_1_c((1 + alpha_c) * self.norm_1_c(c_emb) + beta_c)
        # joint attention over both streams; returns mixed features (N, T_x, C) and (N, T_c, C)
        x_emb_post_attn, c_emb_post_attn = self.attn(x_emb_pre_attn, c_emb_pre_attn, perturbed)
        # gate with gamma and add the residual from the original inputs
        x_emb_post_attn = x_emb_post_attn.to(gamma_x.device)
        x_emb_post_attn = (1 + gamma_x) * x_emb_post_attn + x_emb
        c_emb_post_attn = c_emb_post_attn.to(gamma_c.device)
        c_emb_post_attn = (1 + gamma_c) * c_emb_post_attn + c_emb
# norm the features -> affine transform with modulation -> MLP
normalized_x_emb = self.norm_2_x(x_emb_post_attn).to(delta_x.device)
normalized_c_emb = self.norm_2_c(c_emb_post_attn).to(delta_c.device)
x_emb_final = self.mlp_3_x(delta_x * normalized_x_emb + epsilon_x)
c_emb_final = self.mlp_3_c(delta_c * normalized_c_emb + epsilon_c)
# final scaling with zeta and residual with the original inputs
x_emb_final = zeta_x.to(device) * x_emb_final.to(device) + x_emb.to(device)
c_emb_final = zeta_c.to(device) * c_emb_final.to(device) + c_emb.to(device)
return x_emb_final, c_emb_final
class MotifDiT(nn.Module):
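    """
    Multi-modal diffusion transformer for text-to-image generation. The forward pass
    patchifies the VAE latent, adds 2D sin/cos positional embeddings and optional
    register tokens, runs the MMDiT blocks conditioned on text and timestep
    embeddings, applies a final adaLN modulation, and unpatchifies back to the SD3
    latent space.
    """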
ENCODED_TEXT_DIM = 4096
def __init__(self, config):
super(MotifDiT, self).__init__()
self.patch_size = config.patch_size
self.h, self.w = config.height // config.vae_compression, config.width // config.vae_compression
        self.latent_channels = SD3_LATENT_CHANNEL
# Embedding for (1) text; (2) input image; (3) time
self.text_cond = TextConditionModule(self.ENCODED_TEXT_DIM, config.hidden_dim)
        self.patching = LatentPatchModule(config.patch_size, config.hidden_dim, self.latent_channels, config.vae_type)
self.time_emb = TextTimeEmbedding(config.time_embed_dim, config.pooled_text_dim, config.modulation_dim)
# main multi-modal DiT blocks
self.mmdit_blocks = nn.ModuleList(
[
MotifDiTBlock(
config.hidden_dim, config.modulation_dim, config.hidden_dim, config.mlp_hidden_dim, config
)
for layer_idx in range(config.num_layers)
]
)
self.final_modulation = nn.Linear(config.modulation_dim, config.hidden_dim * 2)
self.final_linear_SD3 = nn.Linear(config.hidden_dim, SD3_LATENT_CHANNEL * config.patch_size**2)
self.skip_register_token_num = config.skip_register_token_num
if getattr(config, "pos_emb_size", None):
pos_emb_size = config.pos_emb_size
else:
pos_emb_size = HIGH_RES_POSEMB_BASE_SIZE if config.height > 512 else LOW_RES_POSEMB_BASE_SIZE
logger.info(f"Positional embedding of Motif-DiT is set to {pos_emb_size}")
self.pos_embed = torch.from_numpy(
get_2d_sincos_pos_embed(
config.hidden_dim, (self.h // self.patch_size, self.w // self.patch_size), base_size=pos_emb_size
)
).to(device="cuda", dtype=torch.bfloat16)
# set register tokens (https://arxiv.org/abs/2309.16588)
if config.register_token_num > 0:
self.register_token_num = config.register_token_num
self.register_tokens = nn.Parameter(torch.randn(1, self.register_token_num, config.hidden_dim))
self.register_parameter("register_tokens", self.register_tokens)
# if needed, add additional register tokens for higher resolution training
self.additional_register_token_num = config.additional_register_token_num
if config.additional_register_token_num > 0:
self.register_tokens_highres = nn.Parameter(
torch.randn(1, self.additional_register_token_num, config.hidden_dim)
)
self.register_parameter("register_tokens_highres", self.register_tokens_highres)
if config.use_final_layer_norm:
self.final_norm = nn.LayerNorm(config.hidden_dim)
if config.conv_header:
logger.info("use convolution header after de-patching")
self.depatching_conv_header = IdentityConv2d(SD3_LATENT_CHANNEL)
if config.use_time_token_in_attn:
self.t_token_proj = nn.Linear(config.modulation_dim, config.hidden_dim)
def forward(self, latent, t, text_embs: List[torch.Tensor], pooled_text_embs, guiding_feature=None):
"""
latent (torch.Tensor)
t (torch.Tensor)
text_embs (List[torch.Tensor])
pooled_text_embs (torch.Tensor)
"""
# 1. get inputs for the MMDiT blocks
emb_c = self.text_cond(*text_embs) # (N, L, D), text conditions
emb_t = self.time_emb(pooled_text_embs, t).to(emb_c.device) # (N, D), time and pooled text conditions
emb_x = (self.patching(latent) + self.pos_embed).to(
emb_c.device
) # (N, T, D), where T = H*W / (patch_size ** 2), input latent patches
# additional "register" tokens, to convey the global information and prevent high-norm abnormal patch
# see https://openreview.net/forum?id=2dnO3LLiJ1
if hasattr(self, "register_tokens"):
if hasattr(self, "register_tokens_highres"):
emb_x = torch.cat(
(
self.register_tokens_highres.expand(emb_x.shape[0], -1, -1),
self.register_tokens.expand(emb_x.shape[0], -1, -1),
emb_x,
),
dim=1,
)
else:
emb_x = torch.cat((self.register_tokens.expand(emb_x.shape[0], -1, -1), emb_x), dim=1)
        # optionally append a projected time token to the text tokens
        if hasattr(self, "t_token_proj"):  # only created when config.use_time_token_in_attn is set
            t_token = self.t_token_proj(emb_t).unsqueeze(1)
            emb_c = torch.cat([emb_c, t_token], dim=1)  # (N, [T_c + 1], C)
# 2. MMDiT Blocks
for block_idx, block in enumerate(self.mmdit_blocks):
emb_x, emb_c = block(emb_x, emb_c, emb_t)
# accumulating the feature_similarity loss
# TODO: add modeling_dit related test
if hasattr(self, "num_feature_align_layers") and block_idx == self.num_feature_align_layers:
self.feature_alignment_loss = self.feature_align_mlp(emb_x, guiding_feature) # exclude register tokens
# Remove the register tokens at the certain layer (the last layer as default).
if block_idx == len(self.mmdit_blocks) - (1 + self.skip_register_token_num):
if hasattr(self, "register_tokens_highres"):
emb_x = emb_x[
:, self.register_token_num + self.additional_register_token_num :
] # remove the register tokens for the output layer
elif hasattr(self, "register_tokens"):
emb_x = emb_x[:, self.register_token_num :] # remove the register tokens for the output layer
# 3. final modulation (shift-and-scale)
scale, shift = self.final_modulation(emb_t).chunk(2, -1) # (N, D) x 2
scale, shift = scale.unsqueeze(1), shift.unsqueeze(1) # (N, 1, D) x 2
if hasattr(self, "final_norm"):
emb_x = self.final_norm(emb_x)
final_emb = (scale + 1) * emb_x + shift
# 4. final linear layer to reduce channel and un-patching
emb_x = self.final_linear_SD3(final_emb) # (N, T, D) to (N, T, out_channels * patch_size**2)
emb_x = self.patching.unpatchify(emb_x) # (N, out_channels, H, W)
if hasattr(self, "depatching_conv_header"):
emb_x = self.depatching_conv_header(emb_x)
return emb_x
class JointAttn(nn.Module):
"""
SD3 style joint-attention layer
"""
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_dim
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.add_q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.add_k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.add_v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
self.add_o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
self.q_norm_x = MotifRMSNorm(self.head_dim) if MotifRMSNorm else RMSNorm(self.head_dim)
self.k_norm_x = MotifRMSNorm(self.head_dim) if MotifRMSNorm else RMSNorm(self.head_dim)
self.q_norm_c = MotifRMSNorm(self.head_dim) if MotifRMSNorm else RMSNorm(self.head_dim)
self.k_norm_c = MotifRMSNorm(self.head_dim) if MotifRMSNorm else RMSNorm(self.head_dim)
self.q_scale = nn.Parameter(torch.ones(self.num_heads))
# Attention mode : {'sdpa', 'flash', None}
self.attn_mode = config.attn_mode
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor,
*args,
**kwargs,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
context_input_ndim = encoder_hidden_states.ndim
if context_input_ndim == 4:
batch_size, channel, height, width = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size = encoder_hidden_states.shape[0]
# `sample` projections.
query = self.q_proj(hidden_states)
key = self.k_proj(hidden_states)
value = self.v_proj(hidden_states)
# `context` projections.
query_c = self.add_q_proj(encoder_hidden_states)
key_c = self.add_k_proj(encoder_hidden_states)
value_c = self.add_v_proj(encoder_hidden_states)
# head first
inner_dim = key.shape[-1]
head_dim = inner_dim // self.num_heads
def norm_qk(x, f_norm):
x = x.view(batch_size, -1, self.num_heads, head_dim)
b, l, h, d_h = x.shape
x = x.reshape(b * l, h, d_h)
x = f_norm(x)
return x.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2) # [b, h, l, d_h]
query = norm_qk(query, self.q_norm_x) # [b, h, l, d_h]
key = norm_qk(key, self.k_norm_x) # [b, h, l, d_h]
value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2) # [b, h, l, d_h]
query_c = norm_qk(query_c, self.q_norm_c) * self.q_scale.reshape(1, self.num_heads, 1, 1) # [b, h, l_c, d]
key_c = norm_qk(key_c, self.k_norm_c) # [b, h, l_c, d]
value_c = value_c.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2) # [b, h, l_c, d]
# attention
query = torch.cat([query, query_c], dim=2).contiguous() # [b, h, l + l_c, d]
key = torch.cat([key, key_c], dim=2).contiguous() # [b, h, l + l_c, d]
value = torch.cat([value, value_c], dim=2).contiguous() # [b, h, l + l_c, d]
        # run joint attention with the selected backend (custom sdpa / motif flash / torch SDPA fallback)
        hidden_states = self.joint_attention(batch_size, query, key, value, head_dim)
hidden_states = hidden_states.to(query.dtype)
# Split the attention outputs.
hidden_states, encoder_hidden_states = (
hidden_states[:, : residual.shape[1]],
hidden_states[:, residual.shape[1] :],
)
# linear proj
hidden_states = self.o_proj(hidden_states)
encoder_hidden_states = self.add_o_proj(encoder_hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if context_input_ndim == 4:
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
return hidden_states, encoder_hidden_states
def joint_attention(self, batch_size, query, key, value, head_dim):
if self.attn_mode == "sdpa" and ScaledDotProductAttention is not None:
# NOTE: SDPA does not support high-resolution (long-context).
q_len = query.size(-2)
masked_bias = torch.zeros((batch_size, self.num_heads, query.size(-2), key.size(-2)), device="cuda")
query = query.transpose(1, 2).reshape(batch_size, q_len, self.hidden_size).contiguous()
key = key.transpose(1, 2).reshape(batch_size, q_len, self.hidden_size).contiguous()
value = value.transpose(1, 2).reshape(batch_size, q_len, self.hidden_size).contiguous()
scale_factor = 1.0
scale_factor /= float(self.head_dim) ** 0.5
hidden_states = ScaledDotProductAttention(
query,
key,
value,
masked_bias,
dropout_rate=0.0,
training=self.training,
attn_weight_scale_factor=scale_factor,
num_kv_groups=1,
)
elif self.attn_mode == "flash" and MotifFlashAttention is not None:
query = query.permute(0, 2, 1, 3).contiguous() # [b, l + l_c, h, d]
key = key.permute(0, 2, 1, 3).contiguous() # [b, l + l_c, h, d]
value = value.permute(0, 2, 1, 3).contiguous() # [b, l + l_c, h, d]
scale_factor = 1.0 / math.sqrt(self.head_dim)
# NOTE (1): masking of motif flash-attention uses (`1`: un-mask, `0`: mask) and has [Batch, Seq] shape
# NOTE (2): Q,K,V must be [Batch, Seq, Heads, Dim] and contiguous.
mask = torch.ones((batch_size, query.size(-3))).cuda()
hidden_states = MotifFlashAttention(
query,
key,
value,
padding_mask=mask,
softmax_scale=scale_factor,
causal=False,
)
hidden_states = hidden_states.reshape(batch_size, -1, self.num_heads * head_dim).contiguous()
else:
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * head_dim)
return hidden_states
@staticmethod
def alt_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, scale=None) -> torch.Tensor:
"""
Pure-pytorch version of the xformers.scaled_dot_product_attention
(or F.scaled_dot_product_attention from torch>2.0.0)
Args:
query (Tensor): query tensor
key (Tensor): key tensor
value (Tensor): value tensor
attn_mask (Tensor, optional): attention mask. Defaults to None.
dropout_p (float, optional): attention dropout probability. Defaults to 0.0.
scale (Tensor or float, optional): scaling for QK. Defaults to None.
Returns:
torch.Tensor: attention score (after softmax)
"""
L, S = query.size(-2), key.size(-2)
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias += attn_mask
attn_weight = query @ key.transpose(-2, -1) * scale_factor # B, L, S
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1) # B, L, S
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
return attn_weight @ value # B, L, S * S, D -> B, L, D
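    # Illustrative equivalence check (not executed): for small float32 tensors the
    # pure-PyTorch reference above should match torch's fused kernel up to tolerance:
    #   q = k = v = torch.randn(1, 2, 4, 8)
    #   ref = JointAttn.alt_scaled_dot_product_attention(q, k, v)
    #   assert torch.allclose(ref, F.scaled_dot_product_attention(q, k, v), atol=1e-5)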
# ===============================================
# Sine/Cosine Positional Embedding Functions
# ===============================================
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, scale=1.0, base_size=None):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
if not isinstance(grid_size, tuple):
grid_size = (grid_size, grid_size)
grid_h = np.arange(grid_size[0], dtype=np.float32) / scale
grid_w = np.arange(grid_size[1], dtype=np.float32) / scale
if base_size is not None:
grid_h *= base_size / grid_size[0]
grid_w *= base_size / grid_size[1]
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed(embed_dim, length, scale=1.0):
pos = np.arange(0, length)[..., None] / scale
return get_1d_sincos_pos_embed_from_grid(embed_dim, pos)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
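

if __name__ == "__main__":
    # Minimal sanity check for the sin/cos positional-embedding helpers above.
    # The values here are illustrative only; the model itself uses config.hidden_dim
    # and the latent patch grid derived from the configured resolution.
    pos = get_2d_sincos_pos_embed(embed_dim=64, grid_size=(8, 8), base_size=LOW_RES_POSEMB_BASE_SIZE)
    assert pos.shape == (8 * 8, 64)  # (num_patches, embed_dim)
    pos_1d = get_1d_sincos_pos_embed(embed_dim=64, length=10)
    assert pos_1d.shape == (10, 64)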