from dataclasses import dataclass
from functools import partial
from typing import Optional, Tuple, Union

import torch
from torch import nn

from models.transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer

@dataclass
class CLIPVisionCfg:
    layers: Union[Tuple[int, int, int, int], int] = 6
    width: int = 512
    head_width: int = 64
    mlp_ratio: float = 4.0
    patch_size: int = 16
    image_size: Union[Tuple[int, int], int] = 224
    ls_init_value: Optional[float] = None  # layer scale initial value
    patch_dropout: float = 0.  # fraction of patches to drop during training (0 disables; 0.5-0.75 recommended in the paper)
    attentional_pool: bool = False  # whether to use attentional pooler in the last embedding layer (overrides pool_type)
    attn_pooler_queries: int = 256  # n_queries for attentional pooler
    attn_pooler_heads: int = 8  # n_heads for attentional pooler
    no_ln_pre: bool = False  # disable pre-transformer LayerNorm
    pos_embed_type: str = 'none'
    final_ln_after_pool: bool = True  # apply final LayerNorm after pooling
    pool_type: str = 'none'
    output_tokens: bool = False
    act_kwargs: Optional[dict] = None
    norm_kwargs: Optional[dict] = None

    timm_model_name: Optional[str] = None  # a valid model name overrides layers, width, patch_size
    timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model
    timm_pool: str = 'avg'  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
    timm_proj: str = 'linear'  # linear projection for timm model output ('linear', 'mlp', '')
    timm_proj_bias: bool = False  # enable bias on the final projection
    timm_drop: float = 0.  # head dropout
    timm_drop_path: Optional[float] = None  # backbone stochastic depth

    img_embed: bool = False
    cls_embed: bool = False
    projection: bool = False
    use_flex: bool = True

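# Illustrative config usage (values are placeholders, not a tuned recipe):
#   cfg = CLIPVisionCfg(width=768, layers=12, head_width=64, patch_size=16)
#   cfg.width // cfg.head_width gives the number of attention heads (12 here).
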
def get_cast_dtype(precision: str):
    cast_dtype = None
    if precision == 'bf16':
        cast_dtype = torch.bfloat16
    elif precision == 'fp16':
        cast_dtype = torch.float16
    return cast_dtype

def get_input_dtype(precision: str):
    input_dtype = None
    if precision in ('bf16', 'pure_bf16'):
        input_dtype = torch.bfloat16
    elif precision in ('fp16', 'pure_fp16'):
        input_dtype = torch.float16
    return input_dtype

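# Example mapping (follows directly from the two helpers above):
#   get_cast_dtype('bf16')        -> torch.bfloat16
#   get_input_dtype('pure_fp16')  -> torch.float16
#   get_cast_dtype('fp32')        -> None (no casting applied)
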
def _build_vision_tower(
    embed_dim: int,
    vision_cfg: CLIPVisionCfg,
    quick_gelu: bool = False,
    cast_dtype: Optional[torch.dtype] = None,
    dropout: float = 0.1,
    num_registers: int = 0,
):
    if isinstance(vision_cfg, dict):
        vision_cfg = CLIPVisionCfg(**vision_cfg)

    act_layer = QuickGELU if quick_gelu else nn.GELU
    vision_heads = vision_cfg.width // vision_cfg.head_width
    # Use the fp32-accurate LayerNorm when the tower is cast to fp16/bf16.
    norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
    if vision_cfg.norm_kwargs:
        norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs)
    if vision_cfg.act_kwargs is not None:
        act_layer = partial(act_layer, **vision_cfg.act_kwargs)

    visual = VisionTransformer(
        width=vision_cfg.width,
        layers=vision_cfg.layers,
        heads=vision_heads,
        mlp_ratio=vision_cfg.mlp_ratio,
        ls_init_value=vision_cfg.ls_init_value,
        output_dim=embed_dim,
        patch_dropout=vision_cfg.patch_dropout,
        no_ln_pre=vision_cfg.no_ln_pre,
        pool_type=vision_cfg.pool_type,
        final_ln_after_pool=vision_cfg.final_ln_after_pool,
        act_layer=act_layer,
        norm_layer=norm_layer,
        output_tokens=vision_cfg.output_tokens,
        img_embed=vision_cfg.img_embed,
        use_flex=True,
        dropout=dropout,
        num_registers=num_registers,
        use_rel_bias=True,
    )
    return visual

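# Sketch of a direct call (assumes the repo's models.transformer.VisionTransformer
# is importable; the dict form is expanded into CLIPVisionCfg inside the function):
#   visual = _build_vision_tower(
#       embed_dim=768,
#       vision_cfg={'width': 768, 'layers': 6, 'head_width': 64},
#       cast_dtype=torch.bfloat16,  # selects LayerNormFp32 for the tower
#   )
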
class MixedOmicsModel(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        vision_cfg: CLIPVisionCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
        drop_rate: float = 0.25,
        num_registers: int = 0,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.drop_prob = drop_rate
        self.num_registers = num_registers
        vision_cfg.cls_embed = False  # no CLS token; features are pooled in forward()
        self.visual = _build_vision_tower(
            embed_dim,
            vision_cfg,
            quick_gelu,
            cast_dtype,
            dropout=drop_rate,
            num_registers=0,
        )
        self.image_proj = nn.Linear(embed_dim, embed_dim)
        self.image_proj.apply(self.init_weights)
        self.ln_post = LayerNorm(embed_dim)

    def init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def _check_tensor(self, tensor, name):
        # Debug helper: report shape and flag NaN/Inf values.
        print(name, " : ", tensor.shape)
        if torch.isnan(tensor).any():
            print(tensor.shape)
            print(f"Tensor {name} contains NaN values.")
        if torch.isinf(tensor).any():
            print(tensor.shape)
            print(f"Tensor {name} contains Inf values.")

    def forward(
        self,
        image,
        coords=None,
        im_mask=None,
        *args,
        **kwargs,
    ):
        # image embedding
        image_embeds = self.visual(
            image.contiguous(),
            coords=None if coords is None else coords.contiguous(),
            key_padding_mask=None if im_mask is None else (~im_mask.bool()).contiguous(),
        )
        image_embeds = self.ln_post(image_embeds)

        if im_mask is not None:
            # Masked mean: average only over valid (unpadded) tokens.
            mask = im_mask.unsqueeze(-1).contiguous()
            masked_embeds = image_embeds * mask
            sum_embeds = masked_embeds.sum(dim=1)
            valid_counts = mask.sum(dim=1).clamp(min=1)  # [N, 1]
            mean_embeds = sum_embeds / valid_counts  # [N, dim]
        else:
            mean_embeds = image_embeds.mean(-2)

        image_embeds_final = self.image_proj(mean_embeds)
        return image_embeds_final, image_embeds, mean_embeds

def make_model(
    embed_dim=768,
    droprate=0.1,
    num_registers=0,
    depth=4,
):
    # Instantiate the config instead of mutating CLIPVisionCfg class attributes,
    # which would leak the overrides into every other use of the class.
    vision_cfg = CLIPVisionCfg(width=embed_dim, layers=depth)
    model = MixedOmicsModel(
        embed_dim=embed_dim,
        vision_cfg=vision_cfg,
        drop_rate=droprate,
        num_registers=num_registers,
    )
    return model
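
# Smoke-test sketch. The token/coordinate shapes below are assumptions for
# illustration only; the custom VisionTransformer in models.transformer defines
# the real input contract.
#   model = make_model(embed_dim=768, depth=4)
#   feats = torch.randn(2, 196, 768)             # [batch, tokens, embed_dim] (assumed)
#   coords = torch.zeros(2, 196, 2)              # per-token coordinates (assumed)
#   mask = torch.ones(2, 196, dtype=torch.bool)  # all tokens valid
#   projected, tokens, pooled = model(feats, coords=coords, im_mask=mask)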