import torch
from torch import nn
import torch.nn.functional as F


class GlobalClassificationHead(nn.Module):
    """Simple norm -> dropout -> linear classification head."""

    def __init__(self, input_dim, num_classes=32, dropout_rate=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        x = self.norm(x)
        x = self.dropout(x)
        logits = self.fc(x)  # (batch_size, num_classes)
        return logits


class CLSHead(nn.Module):
    """Classification head with optional attention pooling over token features.

    The forward pass handles two pooling types:
      - "mlp": sums the final feature with an auxiliary (e.g. mean-pooled) feature.
      - "attention": learns per-token weights over [final token; token features].
    Note that the default pooling_type "cls" is not handled in forward; pass
    "mlp" or "attention" explicitly.
    """

    def __init__(
        self,
        embed_dim,
        num_classes,
        dropout=0.1,
        use_norm=True,
        hidden_dim=None,
        activation="silu",
        pooling_type="cls",
    ):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = embed_dim

        if activation == "gelu":
            self.activation = nn.GELU()
        elif activation == "relu":
            self.activation = nn.ReLU(inplace=True)
        elif activation == "silu":
            self.activation = nn.SiLU(inplace=True)
        else:
            raise ValueError(f"Unsupported activation: {activation}")

        self.pooling_type = pooling_type
        if pooling_type == "attention":
            # Scores each token with LayerNorm + Linear; the softmax is applied in forward.
            self.attention_pool = nn.Sequential(
                nn.LayerNorm(embed_dim),
                nn.Linear(embed_dim, 1),
            )

        self.norm = nn.LayerNorm(embed_dim) if use_norm else nn.Identity()
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, num_classes)
        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.fc1.weight, std=0.02)
        nn.init.zeros_(self.fc1.bias)
        nn.init.trunc_normal_(self.fc2.weight, std=0.02)
        nn.init.zeros_(self.fc2.bias)
        if self.pooling_type == "attention":
            # Initialize the Linear scoring layer (index 1); the LayerNorm keeps its default init.
            nn.init.trunc_normal_(self.attention_pool[1].weight, std=0.02)
            nn.init.zeros_(self.attention_pool[1].bias)

    def forward(self, x, x2=None, x3=None, attn_mask=None):
        # x:  final/summary feature      [batch_size, embed_dim]
        # x2: per-token features         [batch_size, num_tokens, embed_dim]
        # x3: auxiliary pooled feature   [batch_size, embed_dim]
        # attn_mask: True where a token should be ignored  [batch_size, num_tokens]
        if self.pooling_type == "mlp":
            pooled = x + x3
        elif self.pooling_type == "attention":
            # Prepend the summary feature as an extra token, then score every token.
            x = torch.cat([x.unsqueeze(1), x2], dim=1)
            weights = self.attention_pool(x)  # [batch_size, num_tokens + 1, 1]
            if attn_mask is not None:
                # The prepended summary token is never masked.
                attn_mask = torch.cat(
                    [torch.zeros(attn_mask.shape[0], 1, dtype=attn_mask.dtype, device=attn_mask.device), attn_mask],
                    dim=1,
                )
                # Convert the boolean mask into an additive -inf bias before the softmax.
                new_attn_mask = torch.zeros_like(attn_mask, dtype=weights.dtype)
                new_attn_mask.masked_fill_(attn_mask, float("-inf"))
                attn_mask = new_attn_mask
                if attn_mask.dim() == 2:
                    attn_mask = attn_mask[..., None]  # [batch_size, num_tokens + 1, 1]
                weights = weights + attn_mask
            weights = F.softmax(weights, dim=1)
            pooled = torch.sum(x * weights, dim=1)
        else:
            raise ValueError(f"Unsupported pooling type: {self.pooling_type}")

        pooled = self.norm(pooled)
        x = self.fc1(pooled)
        x = self.activation(x)
        x = self.dropout1(x)
        x = self.fc2(x)
        return x


class BaseClassifier(nn.Module):
    """Backbone + CLSHead classifier returning logits, probabilities, and predictions."""

    def __init__(self, base_model, num_classes, cls_head_kwargs=None):
        super().__init__()
        self.backbone = base_model
        embed_dim = getattr(self.backbone, "embed_dim", 768)
        cls_head_config = {
            "embed_dim": embed_dim,
            "num_classes": num_classes,
        }
        if cls_head_kwargs:
            cls_head_config.update(cls_head_kwargs)
        self.cls_head = CLSHead(**cls_head_config)

    def forward(self, image, coords=None, im_mask=None):
        feat_final, feats, m_feats = self.backbone(image=image, coords=coords, im_mask=im_mask)
        # im_mask marks valid tokens; CLSHead expects the inverse (True = ignore).
        attn_mask = ~im_mask.bool() if im_mask is not None else None
        logits = self.cls_head(feat_final, feats, m_feats, attn_mask=attn_mask)
        Y_hat = torch.topk(logits, 1, dim=1)[1]
        Y_prob = F.softmax(logits, dim=-1)
        return logits, Y_prob, Y_hat


class LinearClassifier(nn.Module):
    """Linear probe over backbone features (mean-pooled or final)."""

    def __init__(self, base_model, num_classes=2, pool="final"):
        super().__init__()
        self.backbone = base_model
        self.pool = pool
        print("Linear classifier pool type:", self.pool)
        embed_dim = getattr(self.backbone, "embed_dim", 768)
        self.cls_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes),
        )

    def forward(self, image, coords=None, im_mask=None):
        feat_final, _, feat_mean = self.backbone(image=image, coords=coords, im_mask=im_mask)
        if self.pool == "mean":
            logits = self.cls_head(feat_mean)
        elif self.pool == "final":
            logits = self.cls_head(feat_final)
        else:
            raise ValueError(f"Unsupported pool type: {self.pool}")
        Y_hat = torch.topk(logits, 1, dim=1)[1]
        Y_prob = F.softmax(logits, dim=-1)
        return logits, Y_prob, Y_hat
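

# --- Usage sketch (illustrative only, not part of the module) ---
# A minimal example of wiring a backbone into BaseClassifier with attention
# pooling. `_DummyBackbone` is a hypothetical stand-in: any encoder that
# returns (feat_final [B, D], feats [B, N, D], feat_mean [B, D]) and accepts
# image/coords/im_mask keyword arguments should work the same way; the shapes
# and the "True = valid token" mask convention below are assumptions.
if __name__ == "__main__":
    class _DummyBackbone(nn.Module):
        def __init__(self, embed_dim=768):
            super().__init__()
            self.embed_dim = embed_dim
            self.proj = nn.Linear(embed_dim, embed_dim)

        def forward(self, image, coords=None, im_mask=None):
            feats = self.proj(image)        # per-token features [B, N, D]
            feat_final = feats[:, 0]        # summary feature    [B, D]
            feat_mean = feats.mean(dim=1)   # mean-pooled        [B, D]
            return feat_final, feats, feat_mean

    B, N, D = 2, 16, 768
    image = torch.randn(B, N, D)
    im_mask = torch.ones(B, N, dtype=torch.bool)  # True = valid token (assumed convention)

    model = BaseClassifier(
        _DummyBackbone(D),
        num_classes=4,
        cls_head_kwargs={"pooling_type": "attention"},
    )
    logits, Y_prob, Y_hat = model(image, im_mask=im_mask)
    print(logits.shape, Y_prob.shape, Y_hat.shape)  # (2, 4) (2, 4) (2, 1)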