|
from dataclasses import dataclass
from functools import partial
from typing import Optional, Tuple, Union

import torch
from torch import nn

from models.transformer import LayerNormFp32, LayerNorm, QuickGELU, VisionTransformer


@dataclass
class CLIPVisionCfg:
    """Configuration for the vision tower."""
    layers: Union[Tuple[int, int, int, int], int] = 6
    width: int = 512
    head_width: int = 64
    mlp_ratio: float = 4.0
    patch_size: int = 16
    image_size: Union[Tuple[int, int], int] = 224

    ls_init_value: Optional[float] = None  # layer scale initial value; None disables layer scale
    patch_dropout: float = 0.  # fraction of patches dropped during training; 0. disables
    attentional_pool: bool = False
    attn_pooler_queries: int = 256
    attn_pooler_heads: int = 8
    no_ln_pre: bool = False  # disable the pre-transformer LayerNorm
    pos_embed_type: str = 'none'
    final_ln_after_pool: bool = True  # apply final LayerNorm after (rather than before) pooling
    pool_type: str = 'none'
    output_tokens: bool = False
    act_kwargs: Optional[dict] = None
    norm_kwargs: Optional[dict] = None

    # timm backbone options (only used when timm_model_name is set)
    timm_model_name: Optional[str] = None
    timm_model_pretrained: bool = False
    timm_pool: str = 'avg'
    timm_proj: str = 'linear'
    timm_proj_bias: bool = False
    timm_drop: float = 0.
    timm_drop_path: Optional[float] = None

    img_embed: bool = False
    cls_embed: bool = False
    # Annotations are required for these to be dataclass fields; unannotated
    # assignments would be shared class attributes instead.
    projection: bool = False
    use_flex: bool = True


def get_cast_dtype(precision: str) -> Optional[torch.dtype]:
    """Return the dtype model weights should be cast to for the given precision string."""
    cast_dtype = None
    if precision == 'bf16':
        cast_dtype = torch.bfloat16
    elif precision == 'fp16':
        cast_dtype = torch.float16
    return cast_dtype


def get_input_dtype(precision: str) -> Optional[torch.dtype]:
    """Return the dtype inputs should be converted to for the given precision string."""
    input_dtype = None
    if precision in ('bf16', 'pure_bf16'):
        input_dtype = torch.bfloat16
    elif precision in ('fp16', 'pure_fp16'):
        input_dtype = torch.float16
    return input_dtype
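

# Usage sketch for the two precision helpers above. Any unlisted precision
# string (e.g. 'fp32') falls through to None, which callers treat as "no cast":
#
#     assert get_cast_dtype('bf16') is torch.bfloat16
#     assert get_input_dtype('pure_fp16') is torch.float16
#     assert get_cast_dtype('fp32') is None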


def _build_vision_tower(
        embed_dim: int,
        vision_cfg: CLIPVisionCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
        dropout: float = 0.1,
        num_registers: int = 0,
):
    if isinstance(vision_cfg, dict):
        vision_cfg = CLIPVisionCfg(**vision_cfg)

    # OpenAI's original CLIP used QuickGELU; most newer models use nn.GELU.
    act_layer = QuickGELU if quick_gelu else nn.GELU

    vision_heads = vision_cfg.width // vision_cfg.head_width
    # With fp16/bf16 weights, keep LayerNorm statistics in fp32 for numerical stability.
    norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
    if vision_cfg.norm_kwargs:
        norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs)
    if vision_cfg.act_kwargs is not None:
        act_layer = partial(act_layer, **vision_cfg.act_kwargs)

    visual = VisionTransformer(
        width=vision_cfg.width,
        layers=vision_cfg.layers,
        heads=vision_heads,
        mlp_ratio=vision_cfg.mlp_ratio,
        ls_init_value=vision_cfg.ls_init_value,
        output_dim=embed_dim,
        patch_dropout=vision_cfg.patch_dropout,
        no_ln_pre=vision_cfg.no_ln_pre,
        pool_type=vision_cfg.pool_type,
        final_ln_after_pool=vision_cfg.final_ln_after_pool,
        act_layer=act_layer,
        norm_layer=norm_layer,
        output_tokens=vision_cfg.output_tokens,
        img_embed=vision_cfg.img_embed,
        use_flex=vision_cfg.use_flex,
        dropout=dropout,
        num_registers=num_registers,
        use_rel_bias=True,
    )
    return visual
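

# Example (a sketch): `vision_cfg` may be passed as a plain dict and is coerced
# to CLIPVisionCfg above; unspecified fields keep their dataclass defaults.
#
#     visual = _build_vision_tower(
#         embed_dim=512,
#         vision_cfg={'width': 512, 'layers': 6, 'patch_size': 16},
#         cast_dtype=get_cast_dtype('bf16'),
#     )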


class MixedOmicsModel(nn.Module):
    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            drop_rate: float = 0.25,
            num_registers: int = 0,
            *args,
            **kwargs,
    ):
        super().__init__()

        self.drop_prob = drop_rate
        self.num_registers = num_registers

        # Tokens are pooled explicitly in forward(), so no class token is needed.
        vision_cfg.cls_embed = False

        self.visual = _build_vision_tower(
            embed_dim,
            vision_cfg,
            quick_gelu,
            cast_dtype,
            dropout=drop_rate,
            num_registers=num_registers,
        )

        self.image_proj = nn.Linear(embed_dim, embed_dim)
        self.image_proj.apply(self.init_weights)

        self.ln_post = LayerNorm(embed_dim)

    def init_weights(self, module):
        """Initialize Linear/Embedding weights to N(0, 0.02); reset LayerNorm and biases."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def _check_tensor(self, tensor, name):
        """Debugging helper: report NaN/Inf entries in a tensor."""
        print(name, " : ", tensor.shape)
        if torch.isnan(tensor).any():
            print(tensor.shape)
            print(f"Tensor {name} contains NaN values.")
        if torch.isinf(tensor).any():
            print(tensor.shape)
            print(f"Tensor {name} contains Inf values.")

    def forward(
            self,
            image,
            coords=None,
            im_mask=None,
            *args,
            **kwargs,
    ):
        # im_mask marks valid tokens (True = keep); the transformer's
        # key_padding_mask expects the inverse (True = padding).
        image_embeds = self.visual(
            image.contiguous(),
            coords=None if coords is None else coords.contiguous(),
            key_padding_mask=None if im_mask is None else (~im_mask.bool()).contiguous(),
        )
        image_embeds = self.ln_post(image_embeds)

        if im_mask is not None:
            # Masked mean pooling: sum valid token embeddings and divide by the
            # number of valid tokens (clamped to avoid division by zero).
            mask = im_mask.unsqueeze(-1).contiguous()
            masked_embeds = image_embeds * mask
            sum_embeds = masked_embeds.sum(dim=1)
            valid_counts = mask.sum(dim=1).clamp(min=1)
            mean_embeds = sum_embeds / valid_counts
        else:
            mean_embeds = image_embeds.mean(-2)

        image_embeds_final = self.image_proj(mean_embeds)

        return image_embeds_final, image_embeds, mean_embeds


def make_model(
        embed_dim=768,
        droprate=0.1,
        num_registers=0,
        depth=4,
):
    # Use a fresh config instance; mutating the CLIPVisionCfg class itself
    # would change the defaults seen by every other instance.
    vision_cfg = CLIPVisionCfg()
    vision_cfg.width = embed_dim
    vision_cfg.layers = depth

    model = MixedOmicsModel(
        embed_dim=embed_dim,
        vision_cfg=vision_cfg,
        drop_rate=droprate,
        num_registers=num_registers,
    )
    return model
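

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the training pipeline). The
    # input layout expected by the customized VisionTransformer in
    # models.transformer is an assumption here: pre-embedded tokens of shape
    # (B, N, width), per-token (x, y) coordinates, and a boolean validity
    # mask with True marking real (non-padding) tokens.
    model = make_model(embed_dim=768, droprate=0.1, depth=4)
    B, N = 2, 196
    tokens = torch.randn(B, N, 768)                   # assumed token-feature input
    coords = torch.randint(0, 14, (B, N, 2)).float()  # assumed (x, y) grid coordinates
    im_mask = torch.ones(B, N, dtype=torch.bool)      # all tokens valid
    pooled, token_embeds, mean_embeds = model(tokens, coords=coords, im_mask=im_mask)
    print(pooled.shape, token_embeds.shape, mean_embeds.shape)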