"""MLP-Mixer model."""

from typing import Optional, Tuple

from absl import logging

from big_vision import utils
from big_vision.models import common

import einops
import flax.linen as nn
import flax.training.checkpoints
import jax
import jax.numpy as jnp


class MlpBlock(nn.Module):
  mlp_dim: int

  @nn.compact
  def __call__(self, x):
    y = nn.Dense(self.mlp_dim)(x)
    y = nn.gelu(y)
    return nn.Dense(x.shape[-1])(y)


class MixerBlock(nn.Module):
  """Mixer block layer."""
  tokens_mlp_dim: int
  channels_mlp_dim: int
  drop_p: float

  @nn.compact
  def __call__(self, x, *, train=False):
    # Token-mixing MLP: mix information across the token (spatial) dimension.
    y = nn.LayerNorm()(x)
    y = jnp.swapaxes(y, 1, 2)
    y = MlpBlock(self.tokens_mlp_dim, name="token_mixing")(y)
    y = jnp.swapaxes(y, 1, 2)
    x = x + y * _stoch_depth_mask(x, self.drop_p, not train, self.make_rng)
    # Channel-mixing MLP: mix information across the channel dimension.
    y = nn.LayerNorm()(x)
    y = MlpBlock(self.channels_mlp_dim, name="channel_mixing")(y)
    return x + y * _stoch_depth_mask(x, self.drop_p, not train, self.make_rng)


class MlpMixer(nn.Module):
  """Mixer architecture."""
  patch_size: Tuple[int, int]
  num_classes: Optional[int]
  num_blocks: int
  hidden_dim: int
  tokens_mlp_dim: int
  channels_mlp_dim: int
  model_name: Optional[str] = None
  stoch_depth: float = 0.0

  @nn.compact
  def __call__(self, image, *, train=False):
    out = {}
    # Patchify stem: embed non-overlapping patches into `hidden_dim` channels.
    x = out["stem"] = nn.Conv(self.hidden_dim, self.patch_size,
                              strides=self.patch_size, name="stem")(image)
    x = out["input_tokens"] = einops.rearrange(x, "n h w c -> n (h w) c")
    for i in range(self.num_blocks):
      # Stochastic-depth drop rate grows linearly with block index.
      drop_p = (i / max(self.num_blocks - 1, 1)) * self.stoch_depth
      x = out[f"block_{i}"] = MixerBlock(
          self.tokens_mlp_dim, self.channels_mlp_dim, drop_p)(x, train=train)
    x = nn.LayerNorm(name="pre_head_layer_norm")(x)
    # Global average pooling over tokens, followed by an optional linear head.
    x = out["pre_logits"] = jnp.mean(x, axis=1)
    if self.num_classes:
      x = out["logits"] = nn.Dense(
          self.num_classes, kernel_init=nn.initializers.zeros, name="head")(x)
    # Returns the final representation (or logits) plus a dict of intermediates.
    return x, out


def Model(num_classes=None, *, variant=None, **kw):
  """Factory function to easily create a Model variant like "L/16"."""
  if variant is not None:
    model_size, patch = variant.split("/")
    kw.setdefault("patch_size", (int(patch), int(patch)))
    config = {
        "S": {
            "hidden_dim": 512,
            "num_blocks": 8,
            "channels_mlp_dim": 2048,
            "tokens_mlp_dim": 256,
        },
        "B": {
            "hidden_dim": 768,
            "num_blocks": 12,
            "channels_mlp_dim": 3072,
            "tokens_mlp_dim": 384,
        },
        "L": {
            "hidden_dim": 1024,
            "num_blocks": 24,
            "channels_mlp_dim": 4096,
            "tokens_mlp_dim": 512,
        },
        "H": {
            "hidden_dim": 1280,
            "num_blocks": 32,
            "channels_mlp_dim": 5120,
            "tokens_mlp_dim": 640,
        },
    }[model_size]

    for k, v in config.items():
      kw.setdefault(k, v)

  logging.info("Mixer config: %s", kw)
  return MlpMixer(num_classes=num_classes, **kw)
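

# Example usage (a minimal sketch; assumes the standard Flax init/apply flow and
# a batch of 224x224 RGB images):
#
#   model = Model(num_classes=1000, variant="B/16")
#   images = jnp.zeros((1, 224, 224, 3))
#   variables = model.init(jax.random.PRNGKey(0), images)
#   logits, out = model.apply(variables, images, train=False)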


def load(init_params, init_file, model_cfg, dont_load=()):
  """Load checkpoint."""
  del model_cfg  # Unused.

  # Shortcut names for the released Mixer checkpoints.
  init_file = {
      "B-i1k/16": "gs://mixer_models/imagenet1k/Mixer-B_16.npz",
      "L-i1k/16": "gs://mixer_models/imagenet1k/Mixer-L_16.npz",
      "B-i21k/16": "gs://mixer_models/imagenet21k/Mixer-B_16.npz",
      "L-i21k/16": "gs://mixer_models/imagenet21k/Mixer-L_16.npz",
  }.get(init_file, init_file)
  restored_params = utils.load_params(init_file)
  restored_params = flax.training.checkpoints.convert_pre_linen(restored_params)

  if "Mixer" in restored_params:
    # Rename the pre-Linen parameter tree of the released checkpoints to the
    # module names used by this implementation.
    restored_params["pre_head_layer_norm"] = restored_params["Mixer"].pop(
        "encoder_norm")
    restored_params["stem"] = restored_params.pop("embedding")

    def unflatten_dense(d):
      return {
          "Dense_0": {
              "bias": d["bias1"].squeeze(),
              "kernel": d["kernel1"].squeeze(),
          },
          "Dense_1": {
              "bias": d["bias2"].squeeze(),
              "kernel": d["kernel2"].squeeze(),
          },
      }

    for k, v in restored_params["Mixer"].items():
      assert k.startswith("encoderblock_"), k
      v["token_mixing"] = unflatten_dense(v.pop("token_mixing_phase_0"))
      v["channel_mixing"] = unflatten_dense(v.pop("channel_mixing_phase_0"))
      restored_params["MixerBlock_" + k[len("encoderblock_"):]] = v
    del restored_params["Mixer"]

  restored_params = common.merge_params(restored_params, init_params, dont_load)
  return restored_params
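

# Example (a sketch; assumes access to the public `gs://mixer_models` buckets and
# `variables` produced by `model.init` as in the usage example above):
#
#   params = load(variables["params"], "B-i21k/16", model_cfg=None)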


def _stoch_depth_mask(x, drop_p, deterministic, make_rng):
  """Returns a per-example keep mask for stochastic depth (0 drops the branch)."""
  if not deterministic and drop_p:
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    return 1.0 - jax.random.bernoulli(make_rng("dropout"), drop_p, shape)
  return 1.0