Commit
·
b8e0a76
1
Parent(s):
c70cc56
Update Prithvi.py
Browse files - Prithvi.py (+31 -3)
Prithvi.py
CHANGED
|
@@ -15,12 +15,42 @@ import torch
|
|
| 15 |
import torch.nn as nn
|
| 16 |
|
| 17 |
from timm.models.vision_transformer import Block
|
| 18 |
-
from timm.models.layers import to_2tuple
|
| 19 |
|
| 20 |
import numpy as np
|
| 21 |
|
| 22 |
from einops import rearrange
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
def get_3d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
|
| 25 |
"""
|
| 26 |
grid_size: 3d tuple of grid size: t, h, w
|
|
@@ -85,8 +115,6 @@ class PatchEmbed(nn.Module):
|
|
| 85 |
|
| 86 |
def forward(self, x):
|
| 87 |
B, C, T, H, W = x.shape
|
| 88 |
-
_assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
|
| 89 |
-
_assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
|
| 90 |
x = self.proj(x)
|
| 91 |
if self.flatten:
|
| 92 |
x = x.flatten(2).transpose(1, 2) # B,C,T,H,W -> B,C,L -> B,L,C
|
|
|
|
| 15 |
import torch.nn as nn
|
| 16 |
|
| 17 |
from timm.models.vision_transformer import Block
|
| 18 |
+
from timm.models.layers import to_2tuple
|
| 19 |
|
| 20 |
import numpy as np
|
| 21 |
|
| 22 |
from einops import rearrange
|
| 23 |
|
| 24 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
| 25 |
+
"""
|
| 26 |
+
embed_dim: output dimension for each position
|
| 27 |
+
pos: a list of positions to be encoded: size (M,)
|
| 28 |
+
out: (M, D)
|
| 29 |
+
"""
|
| 30 |
+
assert embed_dim % 2 == 0
|
| 31 |
+
omega = np.arange(embed_dim // 2, dtype=np.float32)
|
| 32 |
+
omega /= embed_dim / 2.
|
| 33 |
+
omega = 1. / 10000**omega # (D/2,)
|
| 34 |
+
|
| 35 |
+
pos = pos.reshape(-1) # (M,)
|
| 36 |
+
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
| 37 |
+
|
| 38 |
+
emb_sin = np.sin(out) # (M, D/2)
|
| 39 |
+
emb_cos = np.cos(out) # (M, D/2)
|
| 40 |
+
|
| 41 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
| 42 |
+
return emb
|
| 43 |
+
|
| 44 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
| 45 |
+
assert embed_dim % 2 == 0
|
| 46 |
+
|
| 47 |
+
# use half of dimensions to encode grid_h
|
| 48 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
| 49 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
| 50 |
+
|
| 51 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
| 52 |
+
return emb
|
| 53 |
+
|
| 54 |
def get_3d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
|
| 55 |
"""
|
| 56 |
grid_size: 3d tuple of grid size: t, h, w
|
|
|
|
| 115 |
|
| 116 |
def forward(self, x):
|
| 117 |
B, C, T, H, W = x.shape
|
|
|
|
|
|
|
| 118 |
x = self.proj(x)
|
| 119 |
if self.flatten:
|
| 120 |
x = x.flatten(2).transpose(1, 2) # B,C,T,H,W -> B,C,L -> B,L,C
|