import torch
from torch import nn
import torch.nn.functional as F


class GlobalClassificationHead(nn.Module):
    """Simple linear head: LayerNorm -> Dropout -> Linear over a pooled feature vector."""

    def __init__(self, input_dim, num_classes=32, dropout_rate=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        x = self.norm(x)
        x = self.dropout(x)
        logits = self.fc(x)  # (batch_size, num_classes)
        return logits
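
# Illustrative usage sketch for GlobalClassificationHead. The shapes below are
# example assumptions for demonstration only, not values required by this module:
#
#     head = GlobalClassificationHead(input_dim=768, num_classes=32)
#     pooled = torch.randn(4, 768)   # (batch_size, input_dim) pooled feature
#     logits = head(pooled)          # (4, 32)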


class CLSHead(nn.Module):
    """MLP classification head with selectable pooling over backbone features."""

    def __init__(
        self,
        embed_dim,
        num_classes,
        dropout=0.1,
        use_norm=True,
        hidden_dim=None,
        activation="silu",
        pooling_type="cls",
    ):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = embed_dim

        if activation == "gelu":
            self.activation = nn.GELU()
        elif activation == "relu":
            self.activation = nn.ReLU(inplace=True)
        elif activation == "silu":
            self.activation = nn.SiLU(inplace=True)
        else:
            raise ValueError(f"Unsupported activation: {activation}")

        self.pooling_type = pooling_type
        if pooling_type == "attention":
            # Scores each token with a single linear unit; softmax is applied in forward().
            self.attention_pool = nn.Sequential(
                nn.LayerNorm(embed_dim),
                nn.Linear(embed_dim, 1),
            )

        self.norm = nn.LayerNorm(embed_dim) if use_norm else nn.Identity()
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

        self._init_weights()
    def _init_weights(self):
        nn.init.trunc_normal_(self.fc1.weight, std=0.02)
        nn.init.zeros_(self.fc1.bias)
        nn.init.trunc_normal_(self.fc2.weight, std=0.02)
        nn.init.zeros_(self.fc2.bias)
        if self.pooling_type == "attention":
            # Initialize the scoring Linear layer (index 1); index 0 is the LayerNorm,
            # whose default ones/zeros init is left untouched.
            nn.init.trunc_normal_(self.attention_pool[1].weight, std=0.02)
            nn.init.zeros_(self.attention_pool[1].bias)
    def forward(self, x, x2=None, x3=None, attn_mask=None):
        if self.pooling_type == "mlp":
            # Combine the final feature with the auxiliary (mean) feature.
            pooled = x + x3
        elif self.pooling_type == "attention":
            # Prepend the final feature as an extra token and attention-pool over all tokens.
            x = torch.cat([x.unsqueeze(1), x2], dim=1)  # [batch_size, 1 + num_tokens, embed_dim]
            weights = self.attention_pool(x)            # [batch_size, 1 + num_tokens, 1]
            if attn_mask is not None:
                # The prepended token is always valid, so pad the mask with a leading zero/False.
                pad = torch.zeros(attn_mask.shape[0], 1, dtype=attn_mask.dtype, device=attn_mask.device)
                attn_mask = torch.cat([pad, attn_mask], dim=1)
                new_attn_mask = torch.zeros_like(attn_mask, dtype=weights.dtype)
                new_attn_mask.masked_fill_(attn_mask, float("-inf"))
                attn_mask = new_attn_mask
                if attn_mask.dim() == 2:
                    attn_mask = attn_mask[..., None]    # [batch_size, 1 + num_tokens, 1]
                weights = weights + attn_mask
            weights = F.softmax(weights, dim=1)
            pooled = torch.sum(x * weights, dim=1)
        else:
            raise ValueError(f"Unsupported pooling type: {self.pooling_type}")

        pooled = self.norm(pooled)
        x = self.fc1(pooled)
        x = self.activation(x)
        x = self.dropout1(x)
        x = self.fc2(x)
        return x
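
# Illustrative usage sketch for CLSHead with attention pooling. The mask convention
# (True marks tokens to ignore) follows forward() above; the concrete shapes and numbers
# are assumptions for demonstration only:
#
#     head = CLSHead(embed_dim=768, num_classes=4, pooling_type="attention")
#     feat_final = torch.randn(2, 768)                  # [batch_size, embed_dim] final feature
#     tokens = torch.randn(2, 16, 768)                  # [batch_size, num_tokens, embed_dim]
#     pad_mask = torch.zeros(2, 16, dtype=torch.bool)   # True = padded token
#     logits = head(feat_final, x2=tokens, attn_mask=pad_mask)   # (2, 4)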


class BaseClassifier(nn.Module):
    def __init__(
        self,
        base_model,
        num_classes,
        cls_head_kwargs=None,
    ):
        super().__init__()
        self.backbone = base_model
        embed_dim = self.backbone.embed_dim if hasattr(self.backbone, 'embed_dim') else 768
        cls_head_config = {
            'embed_dim': embed_dim,
            'num_classes': num_classes,
        }
        if cls_head_kwargs:
            cls_head_config.update(cls_head_kwargs)
        self.cls_head = CLSHead(**cls_head_config)
    def forward(self, image, coords=None, im_mask=None):
        feat_final, feats, m_feats = self.backbone(image=image, coords=coords, im_mask=im_mask)
        # im_mask marks valid tokens; the head expects a mask that marks tokens to ignore.
        attn_mask = ~im_mask.bool() if im_mask is not None else None
        logits = self.cls_head(feat_final, feats, m_feats, attn_mask=attn_mask)
        Y_hat = torch.topk(logits, 1, dim=1)[1]
        Y_prob = F.softmax(logits, dim=-1)
        return logits, Y_prob, Y_hat
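
# Note on the expected backbone contract (an assumption inferred from forward() above,
# not enforced here): backbone(image=..., coords=..., im_mask=...) should return a tuple
# (feat_final, feats, m_feats), e.g.
#
#     feat_final: [batch_size, embed_dim]              final / CLS-like feature
#     feats:      [batch_size, num_tokens, embed_dim]  per-token features
#     m_feats:    [batch_size, embed_dim]              mean-pooled feature
#
# with im_mask set to 1 for valid tokens and 0 for padding.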


class LinearClassifier(nn.Module):
    def __init__(
        self,
        base_model,
        num_classes=2,
        pool="final",
    ):
        super().__init__()
        self.backbone = base_model
        self.pool = pool
        print("Linear classifier pool type:", self.pool)
        embed_dim = self.backbone.embed_dim if hasattr(self.backbone, 'embed_dim') else 768
        self.cls_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes),
        )
    def forward(self, image, coords=None, im_mask=None):
        feat_final, _, feat_mean = self.backbone(image=image, coords=coords, im_mask=im_mask)
        if self.pool == "mean":
            logits = self.cls_head(feat_mean)
        elif self.pool == "final":
            logits = self.cls_head(feat_final)
        else:
            raise ValueError(f"Unsupported pool type: {self.pool}")
        Y_hat = torch.topk(logits, 1, dim=1)[1]
        Y_prob = F.softmax(logits, dim=-1)
        return logits, Y_prob, Y_hat
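

# A minimal, runnable smoke-test sketch. `DummyBackbone` is a hypothetical stand-in for the
# real feature extractor; it only mimics the assumed return contract
# (feat_final, per-token feats, mean feats), and all shapes below are illustrative assumptions.
if __name__ == "__main__":
    class DummyBackbone(nn.Module):
        def __init__(self, embed_dim=768):
            super().__init__()
            self.embed_dim = embed_dim
            self.proj = nn.Linear(embed_dim, embed_dim)

        def forward(self, image, coords=None, im_mask=None):
            tokens = self.proj(image)          # [batch_size, num_tokens, embed_dim]
            feat_final = tokens[:, 0]          # [batch_size, embed_dim] "final" feature
            feat_mean = tokens.mean(dim=1)     # [batch_size, embed_dim] mean feature
            return feat_final, tokens, feat_mean

    batch_size, num_tokens, embed_dim, num_classes = 2, 16, 768, 4
    image = torch.randn(batch_size, num_tokens, embed_dim)
    im_mask = torch.ones(batch_size, num_tokens, dtype=torch.long)  # 1 = valid token

    model = BaseClassifier(
        DummyBackbone(embed_dim),
        num_classes=num_classes,
        cls_head_kwargs={"pooling_type": "attention"},
    )
    logits, Y_prob, Y_hat = model(image, im_mask=im_mask)
    print("BaseClassifier:", logits.shape, Y_prob.shape, Y_hat.shape)  # (2, 4) (2, 4) (2, 1)

    linear = LinearClassifier(DummyBackbone(embed_dim), num_classes=num_classes, pool="mean")
    logits, Y_prob, Y_hat = linear(image, im_mask=im_mask)
    print("LinearClassifier:", logits.shape)  # (2, 4)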