Upload model
- README.md +1 -1
- adaptor_generic.py +19 -4
- adaptor_mlp.py +35 -17
- common.py +16 -0
- dual_hybrid_vit.py +213 -0
- enable_cpe_support.py +4 -1
- enable_spectral_reparam.py +1 -1
- extra_timm_models.py +133 -1
- forward_intermediates.py +5 -2
- radio_model.py +7 -7
- vit_patch_generator.py +4 -19
README.md
CHANGED
@@ -1,9 +1,9 @@
 ---
 library_name: transformers
-pipeline_tag: image-feature-extraction
 license: other
 license_name: nvidia-open-model-license
 license_link: https://developer.download.nvidia.com/licenses/nvidia-open-model-license-agreement-june-2024.pdf
+pipeline_tag: image-feature-extraction
 ---
 
 # Model Overview

adaptor_generic.py
CHANGED
@@ -19,9 +19,23 @@ class GenericAdaptor(AdaptorBase):
     def __init__(self, main_config: Namespace, adaptor_config, state, mlp_config=None):
         super().__init__()
 
+        extra_args = dict()
+        ups = None
+        ups_rank = None
+        if adaptor_config is not None:
+            ups = adaptor_config.get('fd_upsample_factor', None)
+            ups_rank = adaptor_config.get('fd_upsample_rank', None)
+        elif mlp_config is not None:
+            ups = mlp_config["feature"].get('upsample_factor', None)
+            ups_rank = mlp_config["feature"].get('upsample_rank', None)
+        if ups is not None:
+            extra_args['upsample_factor'] = ups
+            extra_args['upsample_rank'] = ups_rank
+
         if state is not None:
-            self.head_mlp = create_mlp_from_state(main_config.mlp_version, state, 'summary.')
-            self.feat_mlp = create_mlp_from_state(main_config.mlp_version, state, 'feature.')
+            spectral_heads = getattr(main_config, 'spectral_heads', False)
+            self.head_mlp = create_mlp_from_state(main_config.mlp_version, state, 'summary.', spectral_weights=spectral_heads)
+            self.feat_mlp = create_mlp_from_state(main_config.mlp_version, state, 'feature.', spectral_weights=spectral_heads, **extra_args)
         else:
             assert mlp_config is not None, "Config must not be None if state is None"
 
@@ -38,16 +52,17 @@ class GenericAdaptor(AdaptorBase):
                 mlp_config["feature"]["hidden_dim"],
                 mlp_config["feature"]["output_dim"],
                 mlp_config["feature"]["num_inner"],
+                **extra_args
             )
 
     def forward(self, input: AdaptorInput) -> RadioOutput:
         # Convert input'd type to the type of the first parameter of the adaptor.
         first_param = next(self.parameters())
         summary = self.head_mlp(input.summary.to(dtype=first_param.dtype)).to(dtype=input.summary.dtype)
-        feat = self.feat_mlp(input.features.to(dtype=first_param.dtype)).to(dtype=input.features.dtype)
+        feat = self.feat_mlp(input.features.to(dtype=first_param.dtype), images=input.images, patch_size=input.patch_size).to(dtype=input.features.dtype)
 
         if input.feature_fmt == 'NCHW':
-            feat = (feat.reshape(feat.shape[0], input.images.shape[-2] // input.patch_size, input.images.shape[-1] // input.patch_size, feat.shape[2])
+            feat = (feat.reshape(feat.shape[0], input.images.shape[-2] // input.patch_size * self.feat_mlp.upsample_factor, input.images.shape[-1] // input.patch_size * self.feat_mlp.upsample_factor, feat.shape[2])
                 .permute(0, 3, 1, 2)
             )
 
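Note on the NCHW path: the feature head can now emit an upsampled token grid, so the reshape multiplies each spatial dimension by `feat_mlp.upsample_factor`. A minimal sketch of the shape bookkeeping, with illustrative sizes (not values from any checkpoint):

    import torch

    B, H, W, patch_size, ups, out_dim = 2, 224, 224, 16, 2, 32

    rows, cols = H // patch_size, W // patch_size            # 14 x 14 patch grid
    feat = torch.randn(B, rows * cols * ups ** 2, out_dim)   # head emits ups^2 tokens per patch

    # Same reshape as the updated forward(): the grid grows by `ups` along each axis.
    nchw = (feat.reshape(B, rows * ups, cols * ups, out_dim)
                .permute(0, 3, 1, 2))
    print(nchw.shape)  # torch.Size([2, 32, 28, 28])

When `upsample_factor == 1` this reduces exactly to the previous behavior.
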
adaptor_mlp.py
CHANGED
@@ -6,7 +6,7 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 import math
-from typing import Dict
+from typing import Dict, Optional
 
 import torch
 from torch import nn
@@ -14,6 +14,8 @@ from torch import nn
 from einops import rearrange
 from timm.models.vision_transformer import Block
 
+from .enable_spectral_reparam import disable_spectral_reparam, enable_spectral_reparam
+
 
 class MLP(nn.Module):
     def __init__(self, input_size: int, hidden_size: int, output_size: int,
@@ -51,6 +53,8 @@ class MLP2(nn.Module):
                  num_inner: int = 0,
                  pre_norm: bool = False, device: torch.device = None,
                  upsample_factor: int = 1,
+                 upsample_rank: int = None,
+                 from_config: bool = False,
                  **kwargs):
         super().__init__()
 
@@ -60,10 +64,12 @@ class MLP2(nn.Module):
         ) if pre_norm else nn.Identity()
 
         self.upsample_factor = upsample_factor
-        self._real_output_dim = output_size
+        sq_ups = upsample_factor ** 2
+
+        self._real_output_dim = output_size // sq_ups
 
-        hidden_size *= upsample_factor
-        output_size *= (upsample_factor ** 2)
+        # hidden_size *= upsample_factor
+        # output_size *= (upsample_factor ** 2)
 
         self.fc1 = nn.Linear(input_size, hidden_size, device=device)
 
@@ -82,7 +88,7 @@ class MLP2(nn.Module):
             nn.Linear(hidden_size, output_size, device=device),
         )
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, images: Optional[torch.Tensor] = None, patch_size: Optional[int] = None) -> torch.Tensor:
         x = self.pre_norm(x)
         x = self.fc1(x)
         for block in self.blocks:
@@ -90,8 +96,12 @@ class MLP2(nn.Module):
         x = self.final(x)
 
         if self.upsample_factor > 1:
-            h = w = int(math.sqrt(x.shape[1]))
-            x = rearrange(x, 'b (h w) (u1 u2 c) -> b (h u1 w u2) c',
+            if images is None:
+                raise ValueError(f'`images` cannot be `None` when the head\'s `upsample_factor > 1`!')
+            if patch_size is None:
+                raise ValueError(f'`patch_size` cannot be `None` when the head\'s `upsample_factor > 1`!')
+            h, w = tuple(d // patch_size for d in images.shape[-2:])
+            x = rearrange(x, 'b (h w) (u1 u2 c) -> b (h u1 w u2) c',
                           h=h, w=w, u1=self.upsample_factor, u2=self.upsample_factor,
                           c=self._real_output_dim)
 
@@ -113,20 +123,22 @@ def strip_prefix(state: Dict[str, torch.Tensor], prefix: str):
     return state
 
 
-def get_mlp_info_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = ''):
+def get_mlp_info_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = '', spectral_weights: bool = False):
     state = strip_prefix(state, prefix)
 
+    weight_suffix = 'weight' if not spectral_weights else 'parametrizations.weight.original'
+
     if version == 'v1':
-        hidden_dim, input_dim = state['fc1.weight'].shape
-        output_dim = state['fc2.weight'].shape[0]
+        hidden_dim, input_dim = state[f'fc1.{weight_suffix}'].shape
+        output_dim = state[f'fc2.{weight_suffix}'].shape[0]
 
         for num_inner in range(1000):
             k = f'inner.{num_inner}.0.weight'
             if k not in state:
                 break
     elif version == 'v2':
-        hidden_dim, input_dim = state['fc1.weight'].shape
-        output_dim = state['final.2.weight'].shape[0]
+        hidden_dim, input_dim = state[f'fc1.{weight_suffix}'].shape
+        output_dim = state[f'final.2.{weight_suffix}'].shape[0]
 
         for num_inner in range(1000):
             k = f'blocks.{num_inner}.0.weight'
@@ -138,19 +150,25 @@ def get_mlp_info_from_state(version: str, state: Dict[str, torch.Tensor], prefix
     return input_dim, hidden_dim, output_dim, num_inner
 
 
-def create_mlp_from_config(version: str, input_dim: int, hidden_dim: int, output_dim: int, num_inner: int):
-    ret: nn.Module = MLP_FACTORY[version](input_dim, hidden_dim, output_dim, num_inner)
+def create_mlp_from_config(version: str, input_dim: int, hidden_dim: int, output_dim: int, num_inner: int, **kwargs):
+    ret: nn.Module = MLP_FACTORY[version](input_dim, hidden_dim, output_dim, num_inner, from_config=True, **kwargs)
 
     return ret
 
 
-def create_mlp_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = ''):
+def create_mlp_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = '', spectral_weights: bool = False, **kwargs):
     state = strip_prefix(state, prefix)
 
-    input_dim, hidden_dim, output_dim, num_inner = get_mlp_info_from_state(version, state)
+    input_dim, hidden_dim, output_dim, num_inner = get_mlp_info_from_state(version, state, spectral_weights=spectral_weights)
+
+    ret: nn.Module = create_mlp_from_config(version, input_dim, hidden_dim, output_dim, num_inner, **kwargs)
 
-    ret: nn.Module = create_mlp_from_config(version, input_dim, hidden_dim, output_dim, num_inner)
+    if spectral_weights:
+        enable_spectral_reparam(ret, init_norm_to_current=False, state_dict_guidance=state)
 
     ret.load_state_dict(state)
 
+    if spectral_weights:
+        disable_spectral_reparam(ret)
+
     return ret
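The upsampling itself lives in `MLP2.forward`: each input token now carries `upsample_factor**2` sub-tokens, and the einops pattern unshuffles them into a finer grid (a token-space analogue of pixel shuffle). A small self-contained demo of that rearrange with toy sizes; the real head derives `h` and `w` from `images.shape` and `patch_size` rather than assuming them:

    import torch
    from einops import rearrange

    B, h, w, u, c = 1, 2, 3, 2, 4            # 2x3 token grid, 2x upsampling, 4 channels
    x = torch.randn(B, h * w, u * u * c)     # each token packs u*u sub-tokens of c channels

    y = rearrange(x, 'b (h w) (u1 u2 c) -> b (h u1 w u2) c',
                  h=h, w=w, u1=u, u2=u, c=c)
    print(y.shape)  # torch.Size([1, 24, 4]): a (h*u) x (w*u) = 4 x 6 grid, flattened row-major
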
common.py
CHANGED
@@ -87,6 +87,22 @@ RESOURCE_MAP = {
         max_resolution=2048,
         preferred_resolution=Resolution(512, 512),
     ),
+    # C-RADIO
+    "c-radio_v2.5-g": RadioResource(
+        "https://huggingface.co/nvidia/C-RADIOv2-g/resolve/main/c-radio_v2-g_half.pth.tar",
+        patch_size=16,
+        max_resolution=2048,
+        preferred_resolution=(768, 768),
+        vitdet_num_global=8,
+    ),
+    "c-radio_v3-l": RadioResource(
+        # NOTE: Currently, this model cannot be loaded via TorchHub. Instead, use the transformers API at https://huggingface.co/nvidia/C-RADIOv3-L
+        # and accept the license terms.
+        "https://huggingface.co/nvidia/C-RADIOv3-L/resolve/main/c-radio-v3_l_half.pth.tar?download=true",
+        patch_size=16,
+        max_resolution=2048,
+        preferred_resolution=Resolution(512, 512),
+    ),
 }
 
 DEFAULT_VERSION = "radio_v2.5-h"
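For reference, a hedged usage sketch of the new `c-radio_v2.5-g` entry. The `torch.hub.load('NVlabs/RADIO', 'radio_model', version=...)` call follows the public RADIO README; treat the exact kwargs as assumptions rather than something this diff defines (and note the comment above: `c-radio_v3-l` is only available through the transformers API after accepting the license):

    import torch

    model = torch.hub.load('NVlabs/RADIO', 'radio_model',
                           version='c-radio_v2.5-g',   # new RESOURCE_MAP key added above
                           progress=True)
    model.eval()

    x = torch.rand(1, 3, 768, 768)                     # matches preferred_resolution for this entry
    summary, features = model(x)
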
dual_hybrid_vit.py
ADDED
@@ -0,0 +1,213 @@
+from logging import getLogger
+from typing import Tuple
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from timm.models import register_model
+from timm.models import vision_transformer as tvit
+from timm.models import convnext as tconv
+
+from einops import rearrange
+
+from . import extra_timm_models as et
+
+
+class Fuser(nn.Module):
+    def __init__(self, src_dim: int, tgt_dim: int, gated: bool = True):
+        super().__init__()
+        self.gated = gated
+
+        mid_dim = max(src_dim, tgt_dim) * 2
+
+        self.fwd = nn.Sequential(
+            nn.Conv2d(src_dim, mid_dim, kernel_size=3, stride=1, padding=1),
+            nn.GELU(),
+            nn.Conv2d(mid_dim, tgt_dim * (2 if gated else 1), kernel_size=3, stride=1, padding=1),
+        )
+
+    def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
+        if src.ndim == 3:
+            shape = tgt.shape[-2:]
+        else:
+            shape = src.shape[-2:]
+
+        nd = shape[0] * shape[1]
+
+        if src.ndim == 3:
+            src = src[:, -nd:].reshape(src.shape[0], src.shape[2], *shape)
+
+        if tgt.ndim == 3:
+            tgt_pre = tgt[:, :-nd]
+            tgt = tgt[:, -nd:].reshape(tgt.shape[0], tgt.shape[2], *shape)
+        else:
+            tgt_pre = None
+
+        pred = self.fwd(src)
+
+        if self.gated:
+            g, pred = torch.chunk(pred, 2, dim=1)
+
+            g = F.sigmoid(g)
+
+            pred = g * pred
+
+        tgt = tgt + pred
+
+        if tgt_pre is not None:
+            tgt = rearrange(tgt, 'b c h w -> b (h w) c')
+            tgt = torch.cat([tgt_pre, tgt], dim=1)
+
+        return tgt
+
+
+class AttnDownsample(nn.Module):
+    def __init__(self, dim: int, window_size: int, num_heads: int = 16):
+        super().__init__()
+        self.q = nn.Parameter(torch.randn(1, num_heads, 1, dim // num_heads) * 0.01)
+        self.kv = nn.Linear(dim, dim * 2)
+        self.proj = nn.Linear(dim, dim)
+        self.window_size = window_size
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.scale = self.head_dim ** -0.5
+
+    def forward(self, x: torch.Tensor, twod_shape: Tuple[int, int]) -> torch.Tensor:
+        ntok = twod_shape[0] * twod_shape[1]
+        x_pre = x[:, :-ntok]
+
+        B = x.shape[0]
+        ds_hw = tuple(s // self.window_size for s in twod_shape)
+
+        x_spat = rearrange(
+            x[:, -ntok:],
+            'b (h d1 w d2) c -> (b h w) (d1 d2) c',
+            h=ds_hw[0], w=ds_hw[1],
+            d1=self.window_size, d2=self.window_size,
+        )
+
+        B, N, C = x_spat.shape
+
+        k, v = self.kv(x_spat).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+
+        q = (self.q * self.scale).expand(B, -1, -1, -1)
+        attn = q @ k.transpose(-2, -1)
+        attn = F.softmax(attn, dim=-1)
+        x = attn @ v
+
+        x = x.transpose(1, 2).reshape(B, C)
+        x = self.proj(x)
+
+        x = rearrange(x, '(b h w) c -> b (h w) c', b=x_pre.shape[0], h=ds_hw[0], w=ds_hw[1])
+
+        x = torch.cat([x_pre, x], dim=1)
+        return x
+
+
+class HybridModel(nn.Module):
+    def __init__(self, vit: tvit.VisionTransformer, conv: tconv.ConvNeXt, pretrained: bool = False,
+                 concatenate: bool = False, **kwargs):
+        super().__init__()
+        self.conv = conv
+        self.vit = vit
+        self.concatenate = concatenate
+
+        conv.stages = nn.ModuleList(conv.stages)
+        vit.blocks = nn.ModuleList(vit.blocks)
+
+        self._half_vit_idx = len(vit.blocks) // 2 + 1
+
+        self._half_conv_idx = None
+        x = torch.empty(1, 3, 256, 256)
+        x = self.conv.stem(x)
+        for i in range(len(conv.stages)):
+            x = conv.stages[i](x)
+            if self._half_conv_idx is None and x.shape[-2:] == (16, 16):
+                self._half_conv_idx = i + 1
+                half_conv_dim = x.shape[1]
+        final_conv_dim = x.shape[1]
+
+        self.vit_to_conv_fusion = Fuser(vit.embed_dim, half_conv_dim)
+        self.conv_to_vit_fusion = Fuser(half_conv_dim, vit.embed_dim)
+        self.vit_ds = AttnDownsample(vit.embed_dim, window_size=2)
+
+        embed_dim = vit.embed_dim + (final_conv_dim if concatenate else 0)
+        if not concatenate:
+            self.final_fuse = Fuser(final_conv_dim, vit.embed_dim, gated=False)
+        self.final_block = tvit.Block(embed_dim, num_heads=16)
+
+        self.embed_dim = embed_dim
+
+    @property
+    def patch_size(self):
+        return 32
+
+    @property
+    def no_fsdp_wrap_types(self):
+        return {tvit.VisionTransformer, tconv.ConvNeXt}
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.forward_features(x)
+
+    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+        y_vit = self.vit.patch_generator(x)
+
+        for i in range(self._half_vit_idx):
+            y_vit = self.vit.blocks[i](y_vit)
+
+        y_conv = self.conv.stem(x)
+        for i in range(self._half_conv_idx):
+            y_conv = self.conv.stages[i](y_conv)
+
+        y_vit, y_conv = self.conv_to_vit_fusion(y_conv, y_vit), self.vit_to_conv_fusion(y_vit, y_conv)
+
+        y_vit = self.vit_ds(y_vit, y_conv.shape[-2:])
+
+        for i in range(self._half_vit_idx, len(self.vit.blocks)):
+            y_vit = self.vit.blocks[i](y_vit)
+
+        for i in range(self._half_conv_idx, len(self.conv.stages)):
+            y_conv = self.conv.stages[i](y_conv)
+
+        if self.concatenate:
+            y_conv = rearrange(y_conv, 'b c h w -> b (h w) c')
+            # Average pool across the board, and replicate for each cls/register token
+            conv_summary = y_conv.mean(dim=1, keepdim=True).expand(-1, self.vit.patch_generator.num_cls_patches, -1)
+            y_conv = torch.cat([conv_summary, y_conv], dim=1)
+            y = torch.cat([y_vit, y_conv], dim=2)
+        else:
+            y = self.final_fuse(y_conv, y_vit)
+        y = self.final_block(y)
+
+        summary = y[:, :self.vit.patch_generator.num_cls_tokens]
+        features = y[:, self.vit.patch_generator.num_cls_patches:]
+
+        return summary, features
+
+
+@register_model
+def hybrid_base(pretrained=False, concatenate: bool = False, weight_init: str = 'skip', **kwargs):
+    cfg = dict(num_classes=0, **kwargs)
+    conv = tconv.convnextv2_base(pretrained=pretrained, **cfg)
+    vit = tvit.vit_base_patch16_224(pretrained=pretrained, weight_init=weight_init, **cfg)
+
+    return HybridModel(vit, conv, pretrained, concatenate=concatenate)
+
+
+@register_model
+def hybrid_large(pretrained=False, concatenate: bool = False, weight_init: str = 'skip', **kwargs):
+    cfg = dict(num_classes=0, **kwargs)
+    conv = tconv.convnextv2_large(pretrained=pretrained, **cfg)
+    vit = tvit.vit_large_patch16_224(pretrained=pretrained, weight_init=weight_init, **cfg)
+
+    return HybridModel(vit, conv, pretrained, concatenate=concatenate)
+
+
+@register_model
+def hybrid_huge(pretrained=False, concatenate: bool = False, weight_init: str = 'skip', **kwargs):
+    cfg = dict(num_classes=0, **kwargs)
+    conv = tconv.convnextv2_huge(pretrained=pretrained, **cfg)
+    vit = et.vit_huge_patch16_224(pretrained=pretrained, weight_init=weight_init, **cfg)
+
+    return HybridModel(vit, conv, pretrained, concatenate=concatenate)
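Shape-wise, the new hybrid model runs a ViT/16 and a ConvNeXt in parallel, exchanges information at the midpoint through the gated `Fuser` blocks, then attention-pools the ViT tokens with `AttnDownsample(window_size=2)` so both branches finish on the same /32 grid (hence `patch_size == 32`). A quick sanity check of that arithmetic under an assumed 256x256 input (the same size the constructor uses to probe the ConvNeXt stages):

    H = W = 256
    vit_grid = (H // 16, W // 16)                  # 16 x 16 ViT tokens at the fusion point
    conv_mid_grid = (16, 16)                       # ConvNeXt stage located by the shape probe
    conv_final_grid = (H // 32, W // 32)           # 8 x 8 after the remaining stages

    window = 2                                     # AttnDownsample window_size
    vit_ds_grid = tuple(s // window for s in vit_grid)   # 8 x 8, matching conv_final_grid

    print(H // vit_ds_grid[0])                     # 32 -> effective patch size of the fused output
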
enable_cpe_support.py
CHANGED
@@ -19,6 +19,7 @@ from .feature_normalizer import IntermediateFeatureNormalizerBase, NullIntermedi
 from .extra_models import DinoWrapper
 from .vit_patch_generator import ViTPatchGenerator
 from .forward_intermediates import forward_intermediates
+from .dual_hybrid_vit import HybridModel
 
 
 def _forward_cpe(self: VisionTransformer, x: torch.Tensor) -> torch.Tensor:
@@ -161,7 +162,9 @@ def enable_cpe(model: nn.Module,
 ):
     if isinstance(model, VisionTransformer):
         _enable_cpe_for_timm_vit(model, *args, **kwargs)
-    elif
+    elif isinstance(model, DinoWrapper):
         _enable_cpe_for_dv2_reg_vit(model, *args, **kwargs)
+    elif isinstance(model, HybridModel):
+        _enable_cpe_for_timm_vit(model.vit, *args, **kwargs)
     else:
         raise ValueError(f'CPE not supported for this model type: {type(model)}')
enable_spectral_reparam.py
CHANGED
@@ -155,7 +155,7 @@ def enable_spectral_reparam(model: Union[nn.Module, List[nn.Module]],
             return True
 
         p_name = f'{name}.parametrizations'
-        is_prm = any(k for k in state_dict_guidance if k.startswith(p_name))
+        is_prm = any(k for k in state_dict_guidance if k.startswith(p_name) and k.endswith('_sn_version'))
         return is_prm
 
     def parametrize_linear(linear: nn.Linear):
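The tightened guidance check means a module only counts as spectrally reparametrized when its parametrization keys carry the `_sn_version` marker, not merely any `<name>.parametrizations.*` entry. A toy illustration of the predicate; the key names below are made up for the example:

    state_dict_guidance = {
        'blocks.0.mlp.fc1.parametrizations.weight.original': None,
        'blocks.0.mlp.fc1.parametrizations.weight.0._sn_version': None,
        'blocks.0.attn.qkv.parametrizations.weight.original': None,  # parametrized, but not spectral norm
    }

    def is_spectral(name: str) -> bool:
        p_name = f'{name}.parametrizations'
        return any(k for k in state_dict_guidance
                   if k.startswith(p_name) and k.endswith('_sn_version'))

    print(is_spectral('blocks.0.mlp.fc1'))   # True
    print(is_spectral('blocks.0.attn.qkv'))  # False (the old check would have returned True)
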
extra_timm_models.py
CHANGED
@@ -6,10 +6,24 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
+import math
+import warnings
+
+import torch
 from torch import nn
+from torch.nn import functional as F
 
 from timm.models import register_model
-from timm.models.vision_transformer import VisionTransformer, _create_vision_transformer, Mlp
+from timm.models.vision_transformer import (
+    VisionTransformer,
+    _create_vision_transformer as _timm_create_vision_transformer,
+    Mlp,
+    Block,
+    LayerScale as TIMMLayerScale,
+)
+
+# Import these to also register them
+from . import dinov2_arch
 
 
 @register_model
@@ -40,6 +54,34 @@ def vit_base_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
     return model
 
 
+@register_model
+def vit_base_patch16_v2_224(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
+        reg_tokens=4, no_embed_class=True, img_size=518 * 16 // 14
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_large_patch16_v2_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
+    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
+    """
+    name = 'vit_large_patch14_reg4_dinov2'
+    model_args = dict(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5,
+        reg_tokens=4, no_embed_class=True, img_size=518 * 16 // 14
+    )
+    model = _create_vision_transformer(name, pretrained=pretrained, **dict(model_args, **kwargs))
+
+    return model
+
 @register_model
 def vit_huge_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-Huge model (ViT-H/16) from original paper (https://arxiv.org/abs/2010.11929).
@@ -66,9 +108,99 @@ def vit_huge_patch16_224_mlpnorm(pretrained=False, **kwargs) -> VisionTransforme
     return model
 
 
+@register_model
+def vit_giant_patch16_224(pretrained=False, scaled_ln: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT-giant model (ViT-g/16) from original paper (https://arxiv.org/abs/2010.11929).
+    """
+    model_args = dict(patch_size=16, embed_dim=1536, depth=40, num_heads=24)
+    model = _create_vision_transformer('vit_giant_patch16_224', pretrained=False, **dict(model_args, **kwargs))
+    if scaled_ln:
+        _apply_scaled_ln(model)
+    return model
+
+
 @register_model
 def vit_bigG_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
     model_args = dict(patch_size=14, embed_dim=1664, depth=48, num_heads=16, init_values=1e-6)
     model = _create_vision_transformer('vit_bigG_patch14', pretrained=False, **dict(model_args, **kwargs))
     return model
 
+
+def _create_vision_transformer(*args, **kwargs):
+    model = _timm_create_vision_transformer(*args, **kwargs)
+    _patch_layer_scale(model)
+    return model
+
+
+def _patch_layer_scale(model: VisionTransformer):
+    def replace_ls(old_ls: TIMMLayerScale):
+        new_ls = dinov2_arch.LayerScale(old_ls.gamma.shape[0], inplace=old_ls.inplace)
+        new_ls.load_state_dict(old_ls.state_dict())
+        return new_ls
+
+    # Monkey patch: Replace TIMM's LayerScale with our modified DINOv2 one, that uses a param name
+    # other than gamma, so that HFHub doesn't mess with it!
+    for mod in model.modules():
+        if isinstance(mod, Block):
+            if isinstance(mod.ls1, TIMMLayerScale):
+                mod.ls1 = replace_ls(mod.ls1)
+            if isinstance(mod.ls2, TIMMLayerScale):
+                mod.ls2 = replace_ls(mod.ls2)
+    pass
+
+
+class ScaledLayerNorm(nn.LayerNorm):
+    '''
+    https://arxiv.org/pdf/2502.05795v1
+    '''
+    def __init__(self, ln_base: nn.LayerNorm, depth: int = 0):
+        super().__init__(ln_base.normalized_shape, eps=ln_base.eps, elementwise_affine=ln_base.elementwise_affine)
+        self.load_state_dict(ln_base.state_dict())
+        self.register_buffer('ln_scale', torch.tensor(1.0 / math.sqrt(depth)), persistent=False)
+
+    def forward(self, x):
+        y = super().forward(x)
+        y = y * self.ln_scale
+        return y
+
+
+class DyT(nn.Module):
+    def __init__(self, C: int, init_alpha: float):
+        super().__init__()
+        self.alpha = nn.Parameter(torch.full((1,), init_alpha))
+        self.gamma = nn.Parameter(torch.ones(C))
+        self.beta = nn.Parameter(torch.zeros(C))
+
+    def forward(self, x: torch.Tensor):
+        x = F.tanh(self.alpha * x)
+        return self.gamma * x + self.beta
+
+@register_model
+def vit_large_dyt_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
+    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
+    """
+    model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16)
+    model = _create_vision_transformer('vit_large_dyt_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
+
+    def _replace_ln_with_dyt(ln: nn.LayerNorm, depth: int):
+        return DyT(ln.normalized_shape[0], init_alpha=0.9)
+    _replace_ln(model, _replace_ln_with_dyt)
+
+    return model
+
+
+def _apply_scaled_ln(model: VisionTransformer):
+    warnings.warn('Post-LayerNorm scaling activated!')
+
+    _replace_ln(model, lambda ln, depth: ScaledLayerNorm(ln, depth=depth))
+
+def _replace_ln(model: VisionTransformer, fn):
+    def _inner_replace_ln(block: Block, depth: int, key: str):
+        prev = getattr(block, key)
+        if isinstance(prev, nn.LayerNorm):
+            setattr(block, key, fn(prev, depth=depth))
+
+    for i, block in enumerate(model.blocks):
+        _inner_replace_ln(block, i + 1, 'norm1')
+        _inner_replace_ln(block, i + 1, 'norm2')
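Among the additions, `DyT` is a tanh-based drop-in for LayerNorm (y = gamma * tanh(alpha * x) + beta) and `ScaledLayerNorm` scales the LayerNorm output by 1/sqrt(depth). A tiny numeric check of the DyT forward, using the same default init as `vit_large_dyt_patch16_224` (init_alpha=0.9):

    import torch

    C = 4
    alpha = torch.full((1,), 0.9)
    gamma, beta = torch.ones(C), torch.zeros(C)

    x = torch.randn(2, 3, C)
    y = gamma * torch.tanh(alpha * x) + beta   # identical math to DyT.forward
    print(y.shape)                             # torch.Size([2, 3, 4]) -> shape-preserving, like LayerNorm
    print(bool(y.abs().max() <= 1.0))          # True: bounded by |gamma| since tanh is in (-1, 1)
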
forward_intermediates.py
CHANGED
@@ -6,7 +6,7 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
-from typing import Callable, List, Optional, Set, Tuple, Union, Any, Iterable
+from typing import Callable, Dict, List, Optional, Set, Tuple, Union, Any, Iterable
 from types import MethodType
 
 import torch
@@ -42,6 +42,7 @@ def forward_intermediates(
         aggregation: Optional[str] = "sparse",
         inter_feature_normalizer: Optional[IntermediateFeatureNormalizerBase] = None,
         norm_alpha_scheme = "post-alpha",
+        block_kwargs: Dict = None,
 ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
     """ Forward features that returns intermediates.
 
@@ -65,6 +66,8 @@ def forward_intermediates(
     reshape = output_fmt == 'NCHW'
     intermediates = []
 
+    block_kwargs = block_kwargs or dict()
+
     blocks = model.blocks
 
     take_indices, max_index = _take_indices(len(blocks), indices)
@@ -90,7 +93,7 @@ def forward_intermediates(
     take_off = 0
 
     for i, blk in enumerate(blocks):
-        x = blk(x)
+        x = blk(x, **block_kwargs)
         if aggregation == "dense":
             # Arbitrarily use the rotation matrix from the final layer in the dense group
             y, alpha = inter_feature_normalizer(x, i, rot_index=take_indices[take_off], skip=num_summary_tokens)
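`block_kwargs` is forwarded verbatim to every block (`x = blk(x, **block_kwargs)`), so whatever is passed must be a keyword the backbone's blocks actually accept. Whether the top-level RADIOModel wrapper exposes the argument is not shown in this diff; the sketch below assumes it does and is otherwise modeled on the documented hub usage:

    import torch

    model = torch.hub.load('NVlabs/RADIO', 'radio_model', version='radio_v2.5-h')
    x = torch.rand(1, 3, 512, 512)

    feats = model.forward_intermediates(
        x,
        indices=[7, 15, 23, 31],
        output_fmt='NCHW',
        block_kwargs={},      # empty dict reproduces the previous behavior, blk(x)
    )
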
radio_model.py
CHANGED
@@ -18,6 +18,7 @@ from .adaptor_base import AdaptorBase, RadioOutput, AdaptorInput
 from . import eradio_model
 from .enable_spectral_reparam import configure_spectral_reparam_from_args
 from .feature_normalizer import FeatureNormalizer, IntermediateFeatureNormalizer
+from . import dual_hybrid_vit
 
 
 class Resolution(NamedTuple):
@@ -69,7 +70,7 @@ class RADIOModel(nn.Module):
         patch_gen = getattr(self.model, "patch_generator", None)
         if patch_gen is not None:
             return patch_gen.num_skip
-        elif self.model.global_pool == 'avg':
+        elif getattr(self.model, 'global_pool', None) == 'avg':
             return 0
         return 1
@@ -81,7 +82,7 @@ class RADIOModel(nn.Module):
         patch_gen = getattr(self.model, 'patch_generator', None)
         if patch_gen is not None:
             return patch_gen.num_cls_tokens
-        elif self.model.global_pool == 'avg':
+        elif getattr(self.model, 'global_pool', None) == 'avg':
             return 0
         return 1
 
@@ -218,7 +219,10 @@ class RADIOModel(nn.Module):
             ret = dict(backbone=ret)
             for name, adaptor in self.adaptors.items():
                 if all_summary.ndim == 3:
-                    summary = all_summary[:, adaptor.head_idx]
+                    if all_summary.shape[1] == 1:
+                        summary = all_summary[:, 0]
+                    else:
+                        summary = all_summary[:, adaptor.head_idx]
                 else:
                     summary = all_summary
                 ada_input = AdaptorInput(images=x, summary=summary.float(), features=all_feat, feature_fmt=feature_fmt, patch_size=self.patch_size)
@@ -326,10 +330,6 @@ def create_model_from_args(args) -> nn.Module:
 
     model.head = nn.Identity()
 
-    assert (
-        not args.cls_token_per_teacher or args.cpe_max_size is not None
-    ), "CPE must be enabled for multiple CLS tokens!"
-
     if args.cpe_max_size is not None:
         uq_teachers = set(t['name'] for t in args.teachers)
         enable_cpe(
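The adaptor summary selection now tolerates backbones that emit a single summary token: when the token axis has length 1, every adaptor shares token 0 instead of indexing by `head_idx`. A toy sketch of the branch:

    import torch

    def pick_summary(all_summary, head_idx):
        if all_summary.ndim == 3:
            if all_summary.shape[1] == 1:
                return all_summary[:, 0]        # new: single shared summary token
            return all_summary[:, head_idx]     # unchanged path for per-teacher summaries
        return all_summary

    multi = torch.randn(2, 3, 8)    # one summary token per teacher head
    single = torch.randn(2, 1, 8)   # e.g. a backbone with a single CLS token
    print(pick_summary(multi, head_idx=2).shape)   # torch.Size([2, 8])
    print(pick_summary(single, head_idx=2).shape)  # torch.Size([2, 8])
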
vit_patch_generator.py
CHANGED
@@ -106,6 +106,10 @@ class ViTPatchGenerator(nn.Module):
     def num_cls_tokens(self):
         return self.cls_token.num_tokens
 
+    @property
+    def num_cls_patches(self):
+        return self.cls_token.num_patches
+
     @property
     def num_registers(self):
         return self.cls_token.num_registers
@@ -119,10 +123,6 @@ class ViTPatchGenerator(nn.Module):
         'pos_embed',
     ]
 
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
-        if self.abs_pos:
-            self._load_embed(state_dict[f'{prefix}pos_embed'], self.pos_embed)
-
     def _load_embed(self, src_embed: torch.Tensor, targ_embed: nn.Parameter):
         if src_embed.shape != targ_embed.shape:
             src_size = int(math.sqrt(src_embed.shape[1]))
@@ -285,18 +285,3 @@ class ViTPatchLinear(nn.Linear):
             **factory
         )
         self.patch_size = patch_size
-
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
-        if self.bias is not None:
-            self.bias.data.copy_(state_dict[f'{prefix}bias'])
-
-        chk_weight = state_dict[f'{prefix}weight']
-        if chk_weight.shape != self.weight.shape:
-            src_patch_size = int(math.sqrt(chk_weight.shape[1] // 3))
-
-            assert (src_patch_size ** 2) * 3 == chk_weight.shape[1], 'Unable to interpolate non-square patch size'
-
-            chk_weight = rearrange(chk_weight, 'b (c h w) -> b c h w', c=3, h=src_patch_size, w=src_patch_size)
-            chk_weight = F.interpolate(chk_weight, size=(self.patch_size, self.patch_size), mode='bicubic', align_corners=True, antialias=False)
-            chk_weight = rearrange(chk_weight, 'b c h w -> b (c h w)')
-            self.weight.data.copy_(chk_weight)
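The new `num_cls_patches` property exposes the full count of non-spatial tokens, which `HybridModel.forward_features` uses to split its output into summary and spatial features, while `num_cls_tokens` stays the CLS-only count. A hedged sketch of that split; the token counts and the cls-plus-registers relationship are assumptions for illustration:

    import torch

    num_cls_tokens = 1
    num_registers = 4
    num_cls_patches = num_cls_tokens + num_registers   # assumed: all non-spatial tokens

    tokens = torch.randn(2, num_cls_patches + 196, 768)  # [cls | registers | 14x14 spatial tokens]

    summary = tokens[:, :num_cls_tokens]                  # as in HybridModel.forward_features
    features = tokens[:, num_cls_patches:]
    print(summary.shape, features.shape)  # torch.Size([2, 1, 768]) torch.Size([2, 196, 768])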