from typing import Any, Dict, Optional, Tuple, Union
import torch
from torch import nn
from dataclasses import dataclass
from functools import partial
from models.transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer
@dataclass
class CLIPVisionCfg:
layers: Union[Tuple[int, int, int, int], int] = 6
width: int = 512
head_width: int = 64
mlp_ratio: float = 4.0
patch_size: int = 16
image_size: Union[Tuple[int, int], int] = 224
ls_init_value: Optional[float] = None # layer scale initial value
patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer (overrides pool_type)
attn_pooler_queries: int = 256 # n_queries for attentional pooler
attn_pooler_heads: int = 8 # n heads for attentional_pooling
no_ln_pre: bool = False # disable pre transformer LayerNorm
pos_embed_type: str = 'none'
final_ln_after_pool: bool = True # apply final LayerNorm after pooling
pool_type: str = 'none'
output_tokens: bool = False
act_kwargs: Optional[dict] = None
norm_kwargs: Optional[dict] = None
timm_model_name: Optional[str] = None # a valid model name overrides layers, width, patch_size
timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
timm_proj_bias: bool = False # enable bias final projection
timm_drop: float = 0. # head dropout
timm_drop_path: Optional[float] = None # backbone stochastic depth
img_embed: bool = False
cls_embed: bool = False
    projection: bool = False  # annotated so these become real dataclass fields, not shared class attributes
    use_flex: bool = True
def get_cast_dtype(precision: str):
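    """Return the dtype used to cast model weights for the given precision string, or None for full precision."""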
cast_dtype = None
if precision == 'bf16':
cast_dtype = torch.bfloat16
elif precision == 'fp16':
cast_dtype = torch.float16
return cast_dtype
def get_input_dtype(precision: str):
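    """Return the dtype expected for model inputs; 'pure_bf16'/'pure_fp16' map like their mixed counterparts."""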
input_dtype = None
if precision in ('bf16', 'pure_bf16'):
input_dtype = torch.bfloat16
elif precision in ('fp16', 'pure_fp16'):
input_dtype = torch.float16
return input_dtype
def _build_vision_tower(
embed_dim: int,
vision_cfg: CLIPVisionCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
dropout: float = 0.1,
num_registers: int = 0,
):
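    """Build the VisionTransformer tower described by `vision_cfg` (dict configs are coerced to CLIPVisionCfg)."""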
if isinstance(vision_cfg, dict):
vision_cfg = CLIPVisionCfg(**vision_cfg)
act_layer = QuickGELU if quick_gelu else nn.GELU
vision_heads = vision_cfg.width // vision_cfg.head_width
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
if vision_cfg.norm_kwargs:
norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs)
if vision_cfg.act_kwargs is not None:
act_layer = partial(act_layer, **vision_cfg.act_kwargs)
visual = VisionTransformer(
width=vision_cfg.width,
layers=vision_cfg.layers,
heads=vision_heads,
mlp_ratio=vision_cfg.mlp_ratio,
ls_init_value=vision_cfg.ls_init_value,
output_dim=embed_dim,
patch_dropout=vision_cfg.patch_dropout,
no_ln_pre=vision_cfg.no_ln_pre,
pool_type=vision_cfg.pool_type,
final_ln_after_pool=vision_cfg.final_ln_after_pool,
act_layer=act_layer,
norm_layer=norm_layer,
output_tokens=vision_cfg.output_tokens,
        img_embed=vision_cfg.img_embed,
        use_flex=vision_cfg.use_flex,  # take the flag from the config instead of hard-coding True
        dropout=dropout,
        num_registers=num_registers,
        use_rel_bias=True,
)
return visual
class MixedOmicsModel(nn.Module):
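    """Transformer encoder over token/coordinate inputs with masked mean pooling and a linear projection head."""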
def __init__(
self,
embed_dim: int,
vision_cfg: CLIPVisionCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
drop_rate: float = 0.25,
num_registers: int = 0,
*args,
**kwargs,
):
super().__init__()
self.drop_prob = drop_rate
self.num_registers = num_registers
vision_cfg.cls_embed = False
self.visual = _build_vision_tower(embed_dim,
vision_cfg,
quick_gelu,
cast_dtype,
dropout=drop_rate,
num_registers=0,
)
self.image_proj = nn.Linear(embed_dim, embed_dim)
self.image_proj.apply(self.init_weights)
self.ln_post = LayerNorm(embed_dim)
def init_weights(self, module):
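        """Initialize Linear/Embedding weights with N(0, 0.02), LayerNorm to identity, and biases to zero."""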
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
    def _check_tensor(self, tensor, name):
        """Debug helper: print the shape of `tensor` and warn if it contains NaN or Inf values."""
        print(name, ":", tuple(tensor.shape))
        if torch.isnan(tensor).any():
            print(f"Tensor {name} contains NaN values.")
        if torch.isinf(tensor).any():
            print(f"Tensor {name} contains Inf values.")
def forward(
self,
image,
coords=None,
im_mask=None,
*args,
**kwargs,
):
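        """Encode `image` tokens; return (projected pooled embedding, per-token embeddings, pooled embedding)."""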
## image embedding
        image_embeds = self.visual(
            image.contiguous(),
            coords=None if coords is None else coords.contiguous(),  # guard the coords=None default
            key_padding_mask=None if im_mask is None else (~im_mask.bool()).contiguous(),
        )
image_embeds = self.ln_post(image_embeds)
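        # Mean-pool over tokens, ignoring padded positions when a validity mask is given.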
if im_mask is not None:
            mask = im_mask.unsqueeze(-1).to(image_embeds.dtype)  # [N, L, 1]
masked_embeds = image_embeds * mask
sum_embeds = masked_embeds.sum(dim=1)
valid_counts = mask.sum(dim=1).clamp(min=1) # [N, 1]
mean_embeds = sum_embeds / valid_counts # [N, dim]
else:
mean_embeds = image_embeds.mean(-2)
image_embeds_final = self.image_proj(mean_embeds)
return image_embeds_final, image_embeds, mean_embeds
def make_model(
embed_dim=768,
droprate=0.1,
num_registers=0,
depth=4,
):
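    """Construct a MixedOmicsModel with a `depth`-layer, `embed_dim`-wide vision tower."""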
    # Instantiate the config rather than mutating CLIPVisionCfg's class attributes,
    # so successive make_model calls don't leak settings into one another.
    vCfg = CLIPVisionCfg(width=embed_dim, layers=depth)
model = MixedOmicsModel(
embed_dim=embed_dim,
vision_cfg=vCfg,
drop_rate=droprate,
num_registers=num_registers,
)
return model
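
# Minimal smoke-test sketch. Assumptions (not confirmed by this file): the underlying
# VisionTransformer consumes pre-embedded tokens of shape [batch, tokens, width] together
# with per-token coords and a boolean validity mask, mirroring MixedOmicsModel.forward above.
if __name__ == "__main__":
    model = make_model(embed_dim=768, droprate=0.1, depth=4)
    tokens = torch.randn(2, 64, 768)              # hypothetical token features
    coords = torch.rand(2, 64, 2)                 # hypothetical per-token coordinates
    valid = torch.ones(2, 64, dtype=torch.bool)   # True marks valid (non-padded) tokens
    proj, per_token, pooled = model(tokens, coords=coords, im_mask=valid)
    print(proj.shape, per_token.shape, pooled.shape)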