import torch
from torch import nn
import torch.nn.functional as F


class GlobalClassificationHead(nn.Module):
    def __init__(self, input_dim, num_classes=32, dropout_rate=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        x = self.norm(x)
        x = self.dropout(x)
        logits = self.fc(x)
        return logits


class CLSHead(nn.Module):
    def __init__(
        self,
        embed_dim,
        num_classes,
        dropout=0.1,
        use_norm=True,
        hidden_dim=None,
        activation="silu",
        pooling_type="cls",  # note: forward() implements "mlp" and "attention" pooling
    ):
        super().__init__()

        if hidden_dim is None:
            hidden_dim = embed_dim

        if activation == "gelu":
            self.activation = nn.GELU()
        elif activation == "relu":
            self.activation = nn.ReLU(inplace=True)
        elif activation == "silu":
            self.activation = nn.SiLU(inplace=True)
        else:
            raise ValueError(f"Unsupported activation: {activation}")

        self.pooling_type = pooling_type

        if pooling_type == "attention":
            self.attention_pool = nn.Sequential(
                nn.LayerNorm(embed_dim),
                nn.Linear(embed_dim, 1),
            )

        if use_norm:
            self.norm = nn.LayerNorm(embed_dim)
        else:
            self.norm = nn.Identity()

        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.fc1.weight, std=0.02)
        nn.init.zeros_(self.fc1.bias)
        nn.init.trunc_normal_(self.fc2.weight, std=0.02)
        nn.init.zeros_(self.fc2.bias)
        if self.pooling_type == "attention":
            # Initialize the Linear scoring layer inside attention_pool
            # (index 1; index 0 is the LayerNorm, whose default init is kept).
            nn.init.trunc_normal_(self.attention_pool[1].weight, std=0.02)
            nn.init.zeros_(self.attention_pool[1].bias)
    def forward(self, x, x2=None, x3=None, attn_mask=None):
        if self.pooling_type == "mlp":
            # "mlp" pooling: combine the summary feature with the auxiliary feature x3.
            pooled = x + x3
        elif self.pooling_type == "attention":
            # Prepend the summary feature to the token sequence, then score
            # every position with the attention_pool module.
            x = torch.cat([x.unsqueeze(1), x2], dim=1)
            weights = self.attention_pool(x)

            if attn_mask is not None:
                # The prepended summary position is never masked.
                pad = torch.zeros(attn_mask.shape[0], 1, dtype=attn_mask.dtype, device=attn_mask.device)
                attn_mask = torch.cat([pad, attn_mask], dim=1)
                # Convert the boolean mask (True = padded) into additive -inf logits.
                new_attn_mask = torch.zeros_like(attn_mask, dtype=weights.dtype)
                new_attn_mask.masked_fill_(attn_mask, float("-inf"))
                attn_mask = new_attn_mask

                if len(attn_mask.shape) == 2:
                    attn_mask = attn_mask[..., None]

                weights = weights + attn_mask
            weights = F.softmax(weights, dim=1)

            pooled = torch.sum(x * weights, dim=1)
        else:
            raise ValueError(f"Unsupported pooling type: {self.pooling_type}")

        pooled = self.norm(pooled)

        x = self.fc1(pooled)
        x = self.activation(x)
        x = self.dropout1(x)
        x = self.fc2(x)

        return x
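
# Example (sketch) of using CLSHead on its own with attention pooling. The shapes
# below are assumptions inferred from forward() above, not part of the original code:
#   x  -> (B, D) summary feature, x2 -> (B, N, D) token features,
#   attn_mask -> (B, N) boolean, True marking padded positions.
#
#   head = CLSHead(embed_dim=768, num_classes=4, pooling_type="attention")
#   logits = head(x, x2=x2, attn_mask=pad_mask)   # -> (B, 4)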


class BaseClassifier(nn.Module):
    def __init__(
        self,
        base_model,
        num_classes,
        cls_head_kwargs=None,
    ):
        super().__init__()

        self.backbone = base_model
        embed_dim = self.backbone.embed_dim if hasattr(self.backbone, 'embed_dim') else 768

        cls_head_config = {
            'embed_dim': embed_dim,
            'num_classes': num_classes,
        }

        if cls_head_kwargs:
            cls_head_config.update(cls_head_kwargs)

        self.cls_head = CLSHead(**cls_head_config)

    def forward(self, image, coords=None, im_mask=None):
        feat_final, feats, m_feats = self.backbone(image=image, coords=coords, im_mask=im_mask)
        logits = self.cls_head(feat_final, feats, m_feats, attn_mask=~im_mask.bool())

        Y_hat = torch.topk(logits, 1, dim=1)[1]
        Y_prob = F.softmax(logits, dim=-1)

        return logits, Y_prob, Y_hat


class LinearClassifier(nn.Module):
    def __init__(
        self,
        base_model,
        num_classes=2,
        pool="final",
    ):
        super().__init__()
        self.backbone = base_model
        self.pool = pool
        print("Linear classifier pool type:", self.pool)
        embed_dim = self.backbone.embed_dim if hasattr(self.backbone, 'embed_dim') else 768

        self.cls_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes),
        )

    def forward(self, image, coords=None, im_mask=None):
        feat_final, _, feat_mean = self.backbone(image=image, coords=coords, im_mask=im_mask)
        if self.pool == "mean":
            logits = self.cls_head(feat_mean)
        elif self.pool == "final":
            logits = self.cls_head(feat_final)
        else:
            raise ValueError(f"Unsupported pool type: {self.pool}")

        Y_hat = torch.topk(logits, 1, dim=1)[1]
        Y_prob = F.softmax(logits, dim=-1)

        return logits, Y_prob, Y_hat
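

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module). `_DummyBackbone`
    # below is an assumed stand-in for the real backbone: it only mimics the interface
    # used above -- an `embed_dim` attribute and a forward() that returns
    # (feat_final, feats, feat_mean). Input shapes are likewise assumptions.
    class _DummyBackbone(nn.Module):
        def __init__(self, embed_dim=768):
            super().__init__()
            self.embed_dim = embed_dim
            self.proj = nn.Linear(embed_dim, embed_dim)

        def forward(self, image, coords=None, im_mask=None):
            feats = self.proj(image)           # (B, N, D) token features
            feat_final = feats[:, 0]           # (B, D) summary feature
            feat_mean = feats.mean(dim=1)      # (B, D) mean-pooled feature
            return feat_final, feats, feat_mean

    B, N, D = 2, 16, 768
    image = torch.randn(B, N, D)                   # here: pre-extracted token features
    im_mask = torch.ones(B, N, dtype=torch.long)   # 1 = valid token, 0 = padding

    model = BaseClassifier(
        _DummyBackbone(D),
        num_classes=4,
        cls_head_kwargs={"pooling_type": "attention"},
    )
    logits, Y_prob, Y_hat = model(image, im_mask=im_mask)
    print(logits.shape, Y_prob.shape, Y_hat.shape)  # (2, 4), (2, 4), (2, 1)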