import math
from typing import List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from loguru import logger

try:
    motif_ops = torch.ops.motif
    MotifRMSNorm = motif_ops.T5LayerNorm
    ScaledDotProductAttention = None
    MotifFlashAttention = motif_ops.flash_attention
except Exception:  # if motif_ops is not available
    MotifRMSNorm = None
    ScaledDotProductAttention = None
    MotifFlashAttention = None

NUM_MODULATIONS = 6
SD3_LATENT_CHANNEL = 16
LOW_RES_POSEMB_BASE_SIZE = 16
HIGH_RES_POSEMB_BASE_SIZE = 64


class IdentityConv2d(nn.Module):
    def __init__(self, channels, kernel_size=3, stride=1, padding=1, bias=True):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )
        self._initialize_identity()

    def _initialize_identity(self):
        k = self.conv.kernel_size[0]
        nn.init.zeros_(self.conv.weight)
        center = k // 2
        for i in range(self.conv.in_channels):
            self.conv.weight.data[i, i, center, center] = 1.0
        if self.conv.bias is not None:
            nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        return self.conv(x)


class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
        self.mask = None

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float)
        if self.mask is not None:
            hidden_states = self.mask.to(hidden_states.device).to(hidden_states.dtype) * hidden_states
        variance = hidden_states.pow(2).sum(-1, keepdim=True)
        if self.mask is not None:
            variance /= torch.count_nonzero(self.mask)
        else:
            variance /= hidden_states.shape[-1]
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class MLP(nn.Module):
    def __init__(self, input_size, hidden_size=None):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size if hidden_size is not None else input_size * 4
        self.gate_proj = nn.Linear(self.input_size, self.hidden_size)
        self.down_proj = nn.Linear(self.hidden_size, self.input_size)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        down_proj = self.act_fn(self.gate_proj(x))
        down_proj = self.down_proj(down_proj)
        return down_proj


class TextTimeEmbToGlobalParams(nn.Module):
    def __init__(self, emb_dim, hidden_dim):
        super().__init__()
        self.projection = nn.Linear(emb_dim, hidden_dim * NUM_MODULATIONS)

    def forward(self, emb):
        emb = F.silu(emb)  # emb: B x D
        params = self.projection(emb)  # params: B x (C * 6)
        params = params.reshape(params.shape[0], NUM_MODULATIONS, params.shape[-1] // NUM_MODULATIONS)  # B x 6 x C
        return params.chunk(NUM_MODULATIONS, dim=1)  # [B x 1 x C] x 6
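
# A minimal shape sketch (illustrative only; the dims emb_dim=256 and hidden_dim=1152
# are assumed values, not fixed by this file): TextTimeEmbToGlobalParams produces the
# six adaLN modulation tensors consumed by each MotifDiTBlock.
#
#   to_params = TextTimeEmbToGlobalParams(emb_dim=256, hidden_dim=1152)
#   alpha, beta, gamma, delta, epsilon, zeta = to_params(torch.randn(2, 256))
#   # each of the six tensors has shape (2, 1, 1152) and broadcasts over the token axis
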
class TextTimeEmbedding(nn.Module):
    """
    Input:
        pooled_text_emb (B x C_l)
        time_steps (B)
    Output:
        combined time/text embedding (B x embed_dim)
    """

    def __init__(self, time_channel, text_channel, embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0):
        super().__init__()
        self.time_proj = Timesteps(
            time_channel, flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=downscale_freq_shift
        )
        self.time_emb = TimestepEmbedding(time_channel, time_channel * 4, out_dim=embed_dim)  # Encode time emb with MLP
        self.pooled_text_emb = TimestepEmbedding(
            text_channel, text_channel * 4, out_dim=embed_dim
        )  # Encode pooled text with MLP

    def forward(self, pooled_text_emb, time_steps):
        time_steps = self.time_proj(time_steps)
        time_emb = self.time_emb(time_steps.to(dtype=torch.bfloat16))
        pooled_text_emb = self.pooled_text_emb(pooled_text_emb)
        return time_emb + pooled_text_emb


class LatentPatchModule(nn.Module):
    def __init__(self, patch_size, embedding_dim, latent_channels, vae_type):
        super().__init__()
        self.patch_size = patch_size
        self.embedding_dim = embedding_dim
        self.projection_SD3 = nn.Conv2d(SD3_LATENT_CHANNEL, embedding_dim, kernel_size=patch_size, stride=patch_size)
        self.latent_channels = latent_channels

    def forward(self, x):
        assert (
            x.shape[1] == SD3_LATENT_CHANNEL
        ), f"VAE-Latent channel is not matched with '{SD3_LATENT_CHANNEL}'. current shape: {x.shape}"
        patches = self.projection_SD3(
            x.to(dtype=torch.bfloat16)
        )  # Shape: (B, embedding_dim, num_patches_h, num_patches_w)
        patches = patches.to(dtype=torch.bfloat16)
        patches = patches.contiguous()
        patches = patches.flatten(2)  # Shape: (B, embedding_dim, num_patches)
        patches = patches.transpose(1, 2)  # Shape: (B, num_patches, embedding_dim)
        patches = patches.contiguous()
        return patches

    def unpatchify(self, x):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)
        """
        n = x.shape[0]
        c = self.latent_channels
        p = self.patch_size
        # check the valid patching (assumes a square token grid)
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]
        x = x.contiguous()
        # (N x T x [C * patch_size**2]) -> (N x H x W x P_1 x P_2 x C)
        x = x.reshape(shape=(n, h, w, p, p, c))
        # x = torch.einsum('nhwpqc->nchpwq', x)  # NOTE: einsum may possibly be the problem.
        # (N x H x W x P_1 x P_2 x C) -> (N x C x H x P_1 x W x P_2)
        # (0   1   2   3     4     5) -> (0   5   1   3     2     4)
        x = x.permute(0, 5, 1, 3, 2, 4)
        return x.reshape(shape=(n, c, h * p, w * p)).contiguous()


class TextConditionModule(nn.Module):
    def __init__(self, text_dim, latent_dim):
        super().__init__()
        self.projection = nn.Linear(text_dim, latent_dim)

    def forward(self, t5_xxl, clip_a, clip_b):
        clip_emb = torch.cat([clip_a, clip_b], dim=-1)
        clip_emb = torch.nn.functional.pad(clip_emb, (0, t5_xxl.shape[-1] - clip_emb.shape[-1]))
        text_emb = torch.cat([clip_emb, t5_xxl], dim=-2)
        text_emb = self.projection(text_emb.to(torch.bfloat16))
        return text_emb
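
# Shape sketch for the patching round-trip (illustrative; a 64x64 latent, patch_size=2,
# and hidden_dim=1152 are assumed values): patchify maps (B, 16, 64, 64) latents to
# (B, 1024, 1152) tokens, and unpatchify inverts the layout once the channel count
# is back to patch_size**2 * 16.
#
#   patcher = LatentPatchModule(2, 1152, latent_channels=16, vae_type=None).to(torch.bfloat16)
#   tokens = patcher(torch.randn(1, 16, 64, 64))           # (1, 1024, 1152)
#   latent = patcher.unpatchify(torch.randn(1, 1024, 64))  # (1, 16, 64, 64), 64 = 2*2*16
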
class MotifDiTBlock(nn.Module):
    def __init__(self, emb_dim, t_emb_dim, attn_emb_dim, mlp_dim, attn_config, text_dim=4096):
        super().__init__()
        self.affine_params_c = TextTimeEmbToGlobalParams(t_emb_dim, emb_dim)
        self.affine_params_x = TextTimeEmbToGlobalParams(t_emb_dim, emb_dim)
        self.norm_1_c = nn.LayerNorm(emb_dim, elementwise_affine=False)
        self.norm_1_x = nn.LayerNorm(emb_dim, elementwise_affine=False)
        self.linear_1_c = nn.Linear(emb_dim, attn_emb_dim)
        self.linear_1_x = nn.Linear(emb_dim, attn_emb_dim)
        self.attn = JointAttn(attn_config)
        self.norm_2_c = nn.LayerNorm(emb_dim, elementwise_affine=False)
        self.norm_2_x = nn.LayerNorm(emb_dim, elementwise_affine=False)
        self.mlp_3_c = MLP(emb_dim, mlp_dim)
        self.mlp_3_x = MLP(emb_dim, mlp_dim)

    def forward(self, x_emb, c_emb, t_emb, perturbed=False):
        """
        x_emb (N, TOKEN_LENGTH x 2, C)
        c_emb (N, T + REGISTER_TOKENS, C)
        t_emb (N, modulation_dim)
        """
        device = x_emb.device

        # get global affine transformation parameters
        alpha_x, beta_x, gamma_x, delta_x, epsilon_x, zeta_x = self.affine_params_x(t_emb)  # scale and shift for image
        alpha_c, beta_c, gamma_c, delta_c, epsilon_c, zeta_c = self.affine_params_c(t_emb)  # scale and shift for text

        # projection and affine transform before attention
        x_emb_pre_attn = self.linear_1_x((1 + alpha_x) * self.norm_1_x(x_emb) + beta_x)
        c_emb_pre_attn = self.linear_1_c((1 + alpha_c) * self.norm_1_c(c_emb) + beta_c)

        # attn_output, attn_weight (None), past_key_value (None)
        x_emb_post_attn, c_emb_post_attn = self.attn(
            x_emb_pre_attn, c_emb_pre_attn, perturbed
        )  # mixed feature for both text and image (N, [T_x + T_c], C)

        # scale with gamma and residual with the original inputs
        x_emb_post_attn = x_emb_post_attn.to(gamma_x.device)
        x_emb_post_attn = (1 + gamma_x) * x_emb_post_attn + x_emb  # NOTE: nan loss for self.linear_2_x.bias
        c_emb_post_attn = c_emb_post_attn.to(gamma_c.device)
        c_emb_post_attn = (1 + gamma_c) * c_emb_post_attn + c_emb

        # norm the features -> affine transform with modulation -> MLP
        normalized_x_emb = self.norm_2_x(x_emb_post_attn).to(delta_x.device)
        normalized_c_emb = self.norm_2_c(c_emb_post_attn).to(delta_c.device)
        x_emb_final = self.mlp_3_x(delta_x * normalized_x_emb + epsilon_x)
        c_emb_final = self.mlp_3_c(delta_c * normalized_c_emb + epsilon_c)

        # final scaling with zeta and residual with the original inputs
        x_emb_final = zeta_x.to(device) * x_emb_final.to(device) + x_emb.to(device)
        c_emb_final = zeta_c.to(device) * c_emb_final.to(device) + c_emb.to(device)

        return x_emb_final, c_emb_final
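
# Modulation recap for one block (a sketch of the math implemented above, per stream):
# with t-conditioned parameters (alpha, beta, gamma, delta, epsilon, zeta),
#
#   h = attn( linear_1( (1 + alpha) * LN_1(x) + beta ) )
#   u = (1 + gamma) * h + x
#   y = zeta * MLP( delta * LN_2(u) + epsilon ) + x
#
# Note that both residual paths add back the original input x (not u), which is what
# the code above does for the image stream x and the text stream c alike.
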
class MotifDiT(nn.Module):
    ENCODED_TEXT_DIM = 4096

    def __init__(self, config):
        super().__init__()
        self.patch_size = config.patch_size
        self.h, self.w = config.height // config.vae_compression, config.width // config.vae_compression
        self.latent_channels = 16

        # Embedding for (1) text; (2) input image; (3) time
        self.text_cond = TextConditionModule(self.ENCODED_TEXT_DIM, config.hidden_dim)
        self.patching = LatentPatchModule(config.patch_size, config.hidden_dim, self.latent_channels, config.vae_type)
        self.time_emb = TextTimeEmbedding(config.time_embed_dim, config.pooled_text_dim, config.modulation_dim)

        # main multi-modal DiT blocks
        self.mmdit_blocks = nn.ModuleList(
            [
                MotifDiTBlock(
                    config.hidden_dim, config.modulation_dim, config.hidden_dim, config.mlp_hidden_dim, config
                )
                for layer_idx in range(config.num_layers)
            ]
        )

        self.final_modulation = nn.Linear(config.modulation_dim, config.hidden_dim * 2)
        self.final_linear_SD3 = nn.Linear(config.hidden_dim, SD3_LATENT_CHANNEL * config.patch_size**2)
        self.skip_register_token_num = config.skip_register_token_num

        if getattr(config, "pos_emb_size", None):
            pos_emb_size = config.pos_emb_size
        else:
            pos_emb_size = HIGH_RES_POSEMB_BASE_SIZE if config.height > 512 else LOW_RES_POSEMB_BASE_SIZE
        logger.info(f"Positional embedding of Motif-DiT is set to {pos_emb_size}")
        self.pos_embed = torch.from_numpy(
            get_2d_sincos_pos_embed(
                config.hidden_dim, (self.h // self.patch_size, self.w // self.patch_size), base_size=pos_emb_size
            )
        ).to(device="cuda", dtype=torch.bfloat16)

        # set register tokens (https://arxiv.org/abs/2309.16588)
        if config.register_token_num > 0:
            self.register_token_num = config.register_token_num
            self.register_tokens = nn.Parameter(torch.randn(1, self.register_token_num, config.hidden_dim))
            self.register_parameter("register_tokens", self.register_tokens)

        # if needed, add additional register tokens for higher resolution training
        self.additional_register_token_num = config.additional_register_token_num
        if config.additional_register_token_num > 0:
            self.register_tokens_highres = nn.Parameter(
                torch.randn(1, self.additional_register_token_num, config.hidden_dim)
            )
            self.register_parameter("register_tokens_highres", self.register_tokens_highres)

        if config.use_final_layer_norm:
            self.final_norm = nn.LayerNorm(config.hidden_dim)

        if config.conv_header:
            logger.info("use convolution header after de-patching")
            self.depatching_conv_header = IdentityConv2d(SD3_LATENT_CHANNEL)

        if config.use_time_token_in_attn:
            self.t_token_proj = nn.Linear(config.modulation_dim, config.hidden_dim)

    def forward(self, latent, t, text_embs: List[torch.Tensor], pooled_text_embs, guiding_feature=None):
        """
        latent (torch.Tensor)
        t (torch.Tensor)
        text_embs (List[torch.Tensor])
        pooled_text_embs (torch.Tensor)
        """
        # 1. get inputs for the MMDiT blocks
        emb_c = self.text_cond(*text_embs)  # (N, L, D), text conditions
        emb_t = self.time_emb(pooled_text_embs, t).to(emb_c.device)  # (N, D), time and pooled text conditions
        emb_x = (self.patching(latent) + self.pos_embed).to(
            emb_c.device
        )  # (N, T, D), where T = H*W / (patch_size ** 2), input latent patches

        # additional "register" tokens, to convey global information and prevent high-norm abnormal patches
        # see https://openreview.net/forum?id=2dnO3LLiJ1
        if hasattr(self, "register_tokens"):
            if hasattr(self, "register_tokens_highres"):
                emb_x = torch.cat(
                    (
                        self.register_tokens_highres.expand(emb_x.shape[0], -1, -1),
                        self.register_tokens.expand(emb_x.shape[0], -1, -1),
                        emb_x,
                    ),
                    dim=1,
                )
            else:
                emb_x = torch.cat((self.register_tokens.expand(emb_x.shape[0], -1, -1), emb_x), dim=1)

        # time embedding into text embedding (t_token_proj is set when config.use_time_token_in_attn is true)
        if hasattr(self, "t_token_proj"):
            t_token = self.t_token_proj(emb_t).unsqueeze(1)
            emb_c = torch.cat([emb_c, t_token], dim=1)  # (N, [T_c + 1], C)

        # 2. MMDiT Blocks
        for block_idx, block in enumerate(self.mmdit_blocks):
            emb_x, emb_c = block(emb_x, emb_c, emb_t)

            # accumulating the feature_similarity loss
            # TODO: add modeling_dit related test
            if hasattr(self, "num_feature_align_layers") and block_idx == self.num_feature_align_layers:
                self.feature_alignment_loss = self.feature_align_mlp(emb_x, guiding_feature)  # exclude register tokens

            # Remove the register tokens at a certain layer (the last layer by default).
            if block_idx == len(self.mmdit_blocks) - (1 + self.skip_register_token_num):
                if hasattr(self, "register_tokens_highres"):
                    emb_x = emb_x[
                        :, self.register_token_num + self.additional_register_token_num :
                    ]  # remove the register tokens for the output layer
                elif hasattr(self, "register_tokens"):
                    emb_x = emb_x[:, self.register_token_num :]  # remove the register tokens for the output layer

        # 3. final modulation (shift-and-scale)
        scale, shift = self.final_modulation(emb_t).chunk(2, -1)  # (N, D) x 2
        scale, shift = scale.unsqueeze(1), shift.unsqueeze(1)  # (N, 1, D) x 2
        if hasattr(self, "final_norm"):
            emb_x = self.final_norm(emb_x)
        final_emb = (scale + 1) * emb_x + shift

        # 4. final linear layer to reduce channel and un-patching
        emb_x = self.final_linear_SD3(final_emb)  # (N, T, D) to (N, T, out_channels * patch_size**2)
        emb_x = self.patching.unpatchify(emb_x)  # (N, out_channels, H, W)
        if hasattr(self, "depatching_conv_header"):
            emb_x = self.depatching_conv_header(emb_x)
        return emb_x
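
# Hedged usage sketch. The config fields below are inferred from the constructor
# above; the concrete values are assumptions, not a published configuration.
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(
#       patch_size=2, height=512, width=512, vae_compression=8, vae_type=None,
#       hidden_dim=1152, modulation_dim=1152, mlp_hidden_dim=4608,
#       time_embed_dim=256, pooled_text_dim=2048, num_layers=24,
#       num_attention_heads=16, attn_mode=None, skip_register_token_num=0,
#       register_token_num=4, additional_register_token_num=0,
#       use_final_layer_norm=True, conv_header=False, use_time_token_in_attn=False,
#   )
#   model = MotifDiT(cfg).to("cuda", torch.bfloat16)
#   out = model(latent, t, [t5_emb, clip_a, clip_b], pooled)  # (N, 16, H/8, W/8)
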
class JointAttn(nn.Module):
    """
    SD3 style joint-attention layer
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_dim
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.add_q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.add_k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.add_v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.add_o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.q_norm_x = MotifRMSNorm(self.head_dim) if MotifRMSNorm else RMSNorm(self.head_dim)
        self.k_norm_x = MotifRMSNorm(self.head_dim) if MotifRMSNorm else RMSNorm(self.head_dim)
        self.q_norm_c = MotifRMSNorm(self.head_dim) if MotifRMSNorm else RMSNorm(self.head_dim)
        self.k_norm_c = MotifRMSNorm(self.head_dim) if MotifRMSNorm else RMSNorm(self.head_dim)
        self.q_scale = nn.Parameter(torch.ones(self.num_heads))

        # Attention mode : {'sdpa', 'flash', None}
        self.attn_mode = config.attn_mode

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        residual = hidden_states

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        context_input_ndim = encoder_hidden_states.ndim
        if context_input_ndim == 4:
            batch_size, channel, height, width = encoder_hidden_states.shape
            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size = encoder_hidden_states.shape[0]

        # `sample` projections.
        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        # `context` projections.
        query_c = self.add_q_proj(encoder_hidden_states)
        key_c = self.add_k_proj(encoder_hidden_states)
        value_c = self.add_v_proj(encoder_hidden_states)

        # head first
        inner_dim = key.shape[-1]
        head_dim = inner_dim // self.num_heads

        def norm_qk(x, f_norm):
            x = x.view(batch_size, -1, self.num_heads, head_dim)
            b, l, h, d_h = x.shape
            x = x.reshape(b * l, h, d_h)
            x = f_norm(x)
            return x.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)  # [b, h, l, d_h]

        query = norm_qk(query, self.q_norm_x)  # [b, h, l, d_h]
        key = norm_qk(key, self.k_norm_x)  # [b, h, l, d_h]
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)  # [b, h, l, d_h]

        query_c = norm_qk(query_c, self.q_norm_c) * self.q_scale.reshape(1, self.num_heads, 1, 1)  # [b, h, l_c, d]
        key_c = norm_qk(key_c, self.k_norm_c)  # [b, h, l_c, d]
        value_c = value_c.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)  # [b, h, l_c, d]

        # attention
        query = torch.cat([query, query_c], dim=2).contiguous()  # [b, h, l + l_c, d]
        key = torch.cat([key, key_c], dim=2).contiguous()  # [b, h, l + l_c, d]
        value = torch.cat([value, value_c], dim=2).contiguous()  # [b, h, l + l_c, d]

        # deprecated.
        hidden_states = self.joint_attention(batch_size, query, key, value, head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Split the attention outputs.
        hidden_states, encoder_hidden_states = (
            hidden_states[:, : residual.shape[1]],
            hidden_states[:, residual.shape[1] :],
        )

        # linear proj
        hidden_states = self.o_proj(hidden_states)
        encoder_hidden_states = self.add_o_proj(encoder_hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if context_input_ndim == 4:
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        return hidden_states, encoder_hidden_states
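    # Token layout note: `forward` concatenates image tokens (length l) ahead of the
    # text tokens (length l_c) along the sequence axis, runs one attention pass over
    # the joint sequence, then splits the output back at residual.shape[1].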
    def joint_attention(self, batch_size, query, key, value, head_dim):
        if self.attn_mode == "sdpa" and ScaledDotProductAttention is not None:
            # NOTE: SDPA does not support high-resolution (long-context).
            q_len = query.size(-2)
            masked_bias = torch.zeros((batch_size, self.num_heads, query.size(-2), key.size(-2)), device=query.device)
            query = query.transpose(1, 2).reshape(batch_size, q_len, self.hidden_size).contiguous()
            key = key.transpose(1, 2).reshape(batch_size, q_len, self.hidden_size).contiguous()
            value = value.transpose(1, 2).reshape(batch_size, q_len, self.hidden_size).contiguous()
            scale_factor = 1.0
            scale_factor /= float(self.head_dim) ** 0.5
            hidden_states = ScaledDotProductAttention(
                query,
                key,
                value,
                masked_bias,
                dropout_rate=0.0,
                training=self.training,
                attn_weight_scale_factor=scale_factor,
                num_kv_groups=1,
            )
        elif self.attn_mode == "flash" and MotifFlashAttention is not None:
            query = query.permute(0, 2, 1, 3).contiguous()  # [b, l + l_c, h, d]
            key = key.permute(0, 2, 1, 3).contiguous()  # [b, l + l_c, h, d]
            value = value.permute(0, 2, 1, 3).contiguous()  # [b, l + l_c, h, d]
            scale_factor = 1.0 / math.sqrt(self.head_dim)
            # NOTE (1): masking of motif flash-attention uses (`1`: un-mask, `0`: mask) and has [Batch, Seq] shape
            # NOTE (2): Q,K,V must be [Batch, Seq, Heads, Dim] and contiguous.
            mask = torch.ones((batch_size, query.size(-3)), device=query.device)
            hidden_states = MotifFlashAttention(
                query,
                key,
                value,
                padding_mask=mask,
                softmax_scale=scale_factor,
                causal=False,
            )
            hidden_states = hidden_states.reshape(batch_size, -1, self.num_heads * head_dim).contiguous()
        else:
            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0)
            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * head_dim)
        return hidden_states

    @staticmethod
    def alt_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, scale=None) -> torch.Tensor:
        """
        Pure-PyTorch version of xformers' scaled_dot_product_attention
        (or F.scaled_dot_product_attention from torch>=2.0.0)

        Args:
            query (Tensor): query tensor
            key (Tensor): key tensor
            value (Tensor): value tensor
            attn_mask (Tensor, optional): attention mask. Defaults to None.
            dropout_p (float, optional): attention dropout probability. Defaults to 0.0.
            scale (Tensor or float, optional): scaling for QK. Defaults to None.

        Returns:
            torch.Tensor: attention output (softmax-weighted sum of values)
        """
        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
        attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask
        attn_weight = query @ key.transpose(-2, -1) * scale_factor  # B, L, S
        attn_weight += attn_bias
        attn_weight = torch.softmax(attn_weight, dim=-1)  # B, L, S
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
        return attn_weight @ value  # (B, L, S) @ (B, S, D) -> (B, L, D)
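
# The static helper above mirrors F.scaled_dot_product_attention; a quick equivalence
# check (a sketch with dropout disabled, on arbitrary [B, H, L, D] tensors):
#
#   q = k = v = torch.randn(1, 4, 8, 16)
#   ref = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
#   alt = JointAttn.alt_scaled_dot_product_attention(q, k, v, dropout_p=0.0)
#   torch.testing.assert_close(alt, ref)
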
# ===============================================
# Sine/Cosine Positional Embedding Functions
# ===============================================
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, scale=1.0, base_size=None):
    """
    grid_size: int (or tuple) of the grid height and width
    return:
        pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    if not isinstance(grid_size, tuple):
        grid_size = (grid_size, grid_size)
    grid_h = np.arange(grid_size[0], dtype=np.float32) / scale
    grid_w = np.arange(grid_size[1], dtype=np.float32) / scale
    if base_size is not None:
        grid_h *= base_size / grid_size[0]
        grid_w *= base_size / grid_size[1]
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)
    grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0
    # use half of the dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed(embed_dim, length, scale=1.0):
    pos = np.arange(0, length)[..., None] / scale
    return get_1d_sincos_pos_embed_from_grid(embed_dim, pos)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)
    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
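
if __name__ == "__main__":
    # Smoke test for the positional-embedding helpers only (a minimal sketch;
    # the dims are arbitrary): a 32x32 patch grid with a 768-dim embedding
    # should yield a (1024, 768) sin/cos table.
    pe = get_2d_sincos_pos_embed(768, (32, 32), base_size=LOW_RES_POSEMB_BASE_SIZE)
    assert pe.shape == (32 * 32, 768), pe.shape
    pe_1d = get_1d_sincos_pos_embed(128, length=10)
    assert pe_1d.shape == (10, 128), pe_1d.shape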