from typing import Any, Dict, Optional, Tuple, Union

import torch
from torch import nn
from dataclasses import dataclass
from functools import partial

from models.transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer


# @dataclass
# class VisionCfg:
#     layers: Union[Tuple[int, int, int, int], int] = 6
#     width: int = 512
#     head_width: int = 64
#     mlp_ratio: float = 4.0
#     ls_init_value: Optional[float] = None  # layer scale initial value
#     patch_dropout: float = 0.  # fraction of patches to drop during training (0 = disabled); 0.5 to 0.75 recommended in the paper
#     no_ln_pre: bool = False  # disable pre-transformer LayerNorm
#     pool_type: str = 'none'
#     final_ln_after_pool: bool = True  # apply final LayerNorm after pooling
#     output_tokens: bool = False
#     act_kwargs: Optional[dict] = None
#     norm_kwargs: Optional[dict] = None


@dataclass
class CLIPVisionCfg:
    layers: Union[Tuple[int, int, int, int], int] = 6
    width: int = 512
    head_width: int = 64
    mlp_ratio: float = 4.0
    patch_size: int = 16
    image_size: Union[Tuple[int, int], int] = 224

    ls_init_value: Optional[float] = None  # layer scale initial value
    patch_dropout: float = 0.  # fraction of patches to drop during training (0 = disabled); 0.5 to 0.75 recommended in the paper
    attentional_pool: bool = False  # whether to use attentional pooler in the last embedding layer (overrides pool_type)
    attn_pooler_queries: int = 256  # n_queries for attentional pooler
    attn_pooler_heads: int = 8  # n heads for attentional pooling
    no_ln_pre: bool = False  # disable pre-transformer LayerNorm
    pos_embed_type: str = 'none'
    final_ln_after_pool: bool = True  # apply final LayerNorm after pooling
    pool_type: str = 'none'
    output_tokens: bool = False
    act_kwargs: Optional[dict] = None
    norm_kwargs: Optional[dict] = None

    timm_model_name: Optional[str] = None  # a valid model name overrides layers, width, patch_size
    timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model
    timm_pool: str = 'avg'  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
    timm_proj: str = 'linear'  # linear projection for timm model output ('linear', 'mlp', '')
    timm_proj_bias: bool = False  # enable bias on final projection
    timm_drop: float = 0.  # head dropout
    timm_drop_path: Optional[float] = None  # backbone stochastic depth

    img_embed: bool = False
    cls_embed: bool = False
    projection: bool = False  # annotated so it is a proper dataclass field
    use_flex: bool = True


def get_cast_dtype(precision: str):
    cast_dtype = None
    if precision == 'bf16':
        cast_dtype = torch.bfloat16
    elif precision == 'fp16':
        cast_dtype = torch.float16
    return cast_dtype


def get_input_dtype(precision: str):
    input_dtype = None
    if precision in ('bf16', 'pure_bf16'):
        input_dtype = torch.bfloat16
    elif precision in ('fp16', 'pure_fp16'):
        input_dtype = torch.float16
    return input_dtype


def _build_vision_tower(
        embed_dim: int,
        vision_cfg: CLIPVisionCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
        dropout: float = 0.1,
        num_registers: int = 0,
):
    if isinstance(vision_cfg, dict):
        vision_cfg = CLIPVisionCfg(**vision_cfg)

    act_layer = QuickGELU if quick_gelu else nn.GELU
    vision_heads = vision_cfg.width // vision_cfg.head_width
    norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
    if vision_cfg.norm_kwargs:
        norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs)
    if vision_cfg.act_kwargs is not None:
        act_layer = partial(act_layer, **vision_cfg.act_kwargs)

    visual = VisionTransformer(
        width=vision_cfg.width,
        layers=vision_cfg.layers,
        heads=vision_heads,
        mlp_ratio=vision_cfg.mlp_ratio,
        ls_init_value=vision_cfg.ls_init_value,
        output_dim=embed_dim,
        patch_dropout=vision_cfg.patch_dropout,
        no_ln_pre=vision_cfg.no_ln_pre,
        pool_type=vision_cfg.pool_type,
        final_ln_after_pool=vision_cfg.final_ln_after_pool,
        act_layer=act_layer,
        norm_layer=norm_layer,
        output_tokens=vision_cfg.output_tokens,
        img_embed=vision_cfg.img_embed,
        use_flex=True,
        dropout=dropout,
        num_registers=num_registers,
        use_rel_bias=True,
    )
    return visual


class MixedOmicsModel(nn.Module):
    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            drop_rate: float = 0.25,
            num_registers: int = 0,
            *args,
            **kwargs,
    ):
        super().__init__()
        self.drop_prob = drop_rate
        self.num_registers = num_registers

        vision_cfg.cls_embed = False
        self.visual = _build_vision_tower(
            embed_dim,
            vision_cfg,
            quick_gelu,
            cast_dtype,
            dropout=drop_rate,
            num_registers=0,
        )

        self.image_proj = nn.Linear(embed_dim, embed_dim)
        self.image_proj.apply(self.init_weights)
        self.ln_post = LayerNorm(embed_dim)

    def init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def _check_tensor(self, tensor, name):
        print(name, " : ", tensor.shape)
        if torch.isnan(tensor).any():
            print(tensor.shape)
            print(f"Tensor {name} contains NaN values.")
        if torch.isinf(tensor).any():
            print(tensor.shape)
            print(f"Tensor {name} contains Inf values.")

    def forward(
            self,
            image,
            coords=None,
            im_mask=None,
            *args,
            **kwargs,
    ):
        # image embedding
        image_embeds = self.visual(
            image.contiguous(),
            coords=None if coords is None else coords.contiguous(),
            key_padding_mask=None if im_mask is None else (~im_mask.bool()).contiguous(),
        )
        image_embeds = self.ln_post(image_embeds)

        if im_mask is not None:
            # masked mean over valid (unpadded) tokens only
            mask = im_mask.unsqueeze(-1).contiguous()
            masked_embeds = image_embeds * mask
            sum_embeds = masked_embeds.sum(dim=1)
            valid_counts = mask.sum(dim=1).clamp(min=1)  # [N, 1]
            mean_embeds = sum_embeds / valid_counts  # [N, dim]
        else:
            mean_embeds = image_embeds.mean(-2)

        image_embeds_final = self.image_proj(mean_embeds)

        return image_embeds_final, image_embeds, mean_embeds
def make_model(
        embed_dim=768,
        droprate=0.1,
        num_registers=0,
        depth=4,
):
    # Instantiate the config rather than mutating CLIPVisionCfg class attributes,
    # so repeated calls do not leak settings into other users of the class.
    vCfg = CLIPVisionCfg(width=embed_dim, layers=depth)
    model = MixedOmicsModel(
        embed_dim=embed_dim,
        vision_cfg=vCfg,
        drop_rate=droprate,
        num_registers=num_registers,
    )
    return model
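
if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the training pipeline. It assumes the
    # VisionTransformer in models.transformer consumes a batch of pre-extracted patch
    # features of shape [N, L, width] together with per-patch coordinates [N, L, 2],
    # and that im_mask is 1 for valid patches and 0 for padding; adjust the shapes to
    # whatever that class actually expects.
    model = make_model(embed_dim=768, droprate=0.1, depth=4)
    model.eval()

    N, L = 2, 196  # hypothetical batch size and number of patches
    image = torch.randn(N, L, 768)
    coords = torch.randint(0, 224, (N, L, 2)).float()
    im_mask = torch.ones(N, L)
    im_mask[:, L // 2:] = 0  # treat the second half of each sequence as padding

    with torch.no_grad():
        image_embeds_final, image_embeds, mean_embeds = model(image, coords=coords, im_mask=im_mask)
    print(image_embeds_final.shape, image_embeds.shape, mean_embeds.shape)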