|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""BiT models as in the paper (ResNet V2) w/ loading of public weights.
|
|
|
|
See reproduction proof: http://(internal link)/qY70qs6j944
|
|
"""
|
|
|
|
import functools
|
|
import re
|
|
from typing import Optional, Sequence, Union
|
|
|
|
from big_vision import utils as u
|
|
from big_vision.models import bit
|
|
from big_vision.models import common
|
|
import flax.linen as nn
|
|
import jax.numpy as jnp
|
|
|
|
|
|
def standardize(x, axis, eps):
  """Shifts `x` to zero mean and rescales it to unit RMS along `axis`.

  Args:
    x: array to standardize.
    axis: axis or axes over which the mean and the mean-square are taken.
    eps: constant added to the mean-square before the square root, guarding
      against division by zero.

  Returns:
    `(x - mean) / sqrt(mean((x - mean) ** 2) + eps)`, with the shape of `x`.
  """
  centered = x - jnp.mean(x, axis=axis, keepdims=True)
  mean_square = jnp.mean(jnp.square(centered), axis=axis, keepdims=True)
  return centered / jnp.sqrt(mean_square + eps)
|
|
|
|
|
|
|
|
|
|
class GroupNorm(nn.Module):
  """Group normalization (arxiv.org/abs/1803.08494).

  The trailing axis holds the channels, which are split into `ngroups`
  equal-sized groups; the leading axis is the batch axis. Each group is
  standardized (zero mean, unit RMS, eps=1e-5) jointly over every axis between
  the batch and channel axes and over the channels belonging to that group.
  A learned per-channel `scale` and `bias` are then applied.

  For a 4D (batch, height, width, channels) input the statistics are pooled
  over height, width and the in-group channels, and `scale`/`bias` have shape
  (1, 1, 1, channels). Inputs of other ranks (>= 2) are handled the same way;
  the params then have shape (1, ..., 1, channels) matching the input rank.

  Attributes:
    ngroups: number of channel groups; must divide the channel count.
  """
  ngroups: int = 32

  @nn.compact
  def __call__(self, x):
    input_shape = x.shape
    nchannels = input_shape[-1]
    if nchannels % self.ngroups:
      raise ValueError(
          f'GroupNorm: {nchannels} channels cannot be split into '
          f'{self.ngroups} equal groups.')
    x = x.reshape(input_shape[:-1] + (self.ngroups, nchannels // self.ngroups))

    # After the reshape, axis 0 is the batch and axis -2 indexes the group;
    # statistics are pooled over all remaining axes. For a 4D input this is
    # (1, 2, 4): height, width and the channels inside each group.
    group_axis = x.ndim - 2
    stat_axes = tuple(a for a in range(1, x.ndim) if a != group_axis)
    x = standardize(x, axis=stat_axes, eps=1e-5)
    x = x.reshape(input_shape)

    # Per-channel affine params that broadcast against the input; for a 4D
    # input the shape is (1, 1, 1, C), exactly as before the generalization.
    bias_scale_shape = (1,) * (len(input_shape) - 1) + (nchannels,)
    x = x * self.param('scale', nn.initializers.ones, bias_scale_shape)
    x = x + self.param('bias', nn.initializers.zeros, bias_scale_shape)
    return x
|
|
|
|
|
|
class StdConv(nn.Conv):
  """`nn.Conv` with weight standardization.

  Overrides `param` so that whenever the parameter named `kernel` is fetched,
  the returned value is standardized (zero mean, unit RMS, eps=1e-10) over its
  first three axes. Any other parameter (e.g. `bias`) is returned unchanged.
  """

  def param(self, name, *a, **kw):
    value = super().param(name, *a, **kw)
    if name != 'kernel':
      return value
    # Presumably flax's 2D kernel layout (kh, kw, in, out), so pooling over
    # axes 0-2 yields one set of statistics per output filter -- confirm.
    return standardize(value, axis=[0, 1, 2], eps=1e-10)
|
|
|
|
|
|
class RootBlock(nn.Module):
  """Root block (stem) of the ResNet: 7x7 stride-2 conv, then 3x3 stride-2 max-pool.

  Attributes:
    width: number of output channels of the stem convolution.
  """
  width: int

  @nn.compact
  def __call__(self, x):
    stem_conv = StdConv(
        self.width,
        (7, 7),
        (2, 2),
        padding=[(3, 3), (3, 3)],
        use_bias=False,
        name='conv_root',
    )
    pooled = nn.max_pool(
        stem_conv(x), (3, 3), strides=(2, 2), padding=[(1, 1), (1, 1)])
    return pooled
|
|
|
|
|
|
class ResidualUnit(nn.Module):
  """Bottleneck ResNet block with pre-activation (ResNet V2) ordering.

  The input goes through GroupNorm + ReLU once. That pre-activated tensor
  feeds the 1x1 -> 3x3 -> 1x1 bottleneck branch. It also feeds a 1x1
  projection shortcut when the channel count or the stride changes; otherwise
  the raw input is the shortcut.

  Attributes:
    nmid: bottleneck channel count; the output has `4 * nmid` channels. If
      None (or 0), defaults to a quarter of the input channels.
    strides: stride of the 3x3 conv and of the projection shortcut.
  """
  nmid: Optional[int] = None
  strides: Sequence[int] = (1, 1)

  @nn.compact
  def __call__(self, x):
    nmid = self.nmid or x.shape[-1] // 4
    nout = nmid * 4
    # Normalize to a tuple: `strides` may be any sequence (e.g. a list from a
    # config), and a list never compares equal to (1, 1), which would create a
    # spurious projection and change the parameter tree.
    strides = tuple(self.strides)
    conv = functools.partial(StdConv, use_bias=False)

    residual = x
    x = GroupNorm(name='gn1')(x)
    x = nn.relu(x)

    # Projection shortcut only when the channel count or resolution changes;
    # it is taken from the pre-activated tensor, like the main branch.
    if x.shape[-1] != nout or strides != (1, 1):
      residual = conv(nout, (1, 1), strides, name='conv_proj')(x)

    x = conv(nmid, (1, 1), name='conv1')(x)
    x = GroupNorm(name='gn2')(x)
    x = nn.relu(x)
    x = conv(nmid, (3, 3), strides, padding=[(1, 1), (1, 1)],
             name='conv2')(x)
    x = GroupNorm(name='gn3')(x)
    x = nn.relu(x)
    x = conv(nout, (1, 1), name='conv3')(x)

    return x + residual
|
|
|
|
|
|
class ResNetStage(nn.Module):
  """A stage: a sequence of `ResidualUnit`s at the same resolution.

  Only the first unit uses `first_stride`; the remaining units use the default
  stride. Calling the stage returns the final activations and a dict mapping
  each unit's name ('unit01', 'unit02', ...) to that unit's output.

  Attributes:
    block_size: number of residual units; the first unit is always created.
    nmid: bottleneck width forwarded to every unit.
    first_stride: stride of the first unit.
  """
  block_size: int
  nmid: Optional[int] = None
  first_stride: Sequence[int] = (1, 1)

  @nn.compact
  def __call__(self, x):
    unit_outputs = {}
    x = unit_outputs['unit01'] = ResidualUnit(
        nmid=self.nmid, strides=self.first_stride, name='unit01')(x)
    for number in range(2, self.block_size + 1):
      unit_name = f'unit{number:02d}'
      x = unit_outputs[unit_name] = ResidualUnit(
          nmid=self.nmid, name=unit_name)(x)
    return x, unit_outputs
|
|
|
|
|
|
class Model(nn.Module):
  """ResNetV2 (BiT) backbone with an optional classification head.

  Attributes:
    num_classes: size of the classification head; if falsy (None or 0) no head
      is created and the pooled features are returned instead of logits.
    width: width multiplier; the stem has `int(64 * width)` channels.
    depth: depth spec forwarded to `bit.get_block_desc`, which yields the
      number of units per stage.
    head_zeroinit: if True, the head kernel is initialized to zeros.
  """
  num_classes: Optional[int] = None
  width: int = 1
  depth: Union[int, Sequence[int]] = 50
  head_zeroinit: bool = True

  @nn.compact
  def __call__(self, image, *, train=False):
    """Returns `(x, out)`: logits (or pooled features without a head) and a
    dict of intermediate outputs. `train` is accepted but not used here."""
    stage_sizes = bit.get_block_desc(self.depth)
    base_width = int(64 * self.width)
    out = {}

    x = out['stem'] = RootBlock(width=base_width, name='root_block')(image)

    # Stage 1 keeps the stem resolution; every later stage starts with a
    # stride-2 unit and doubles the bottleneck width.
    x, out['stage1'] = ResNetStage(
        stage_sizes[0], nmid=base_width, name='block1')(x)
    for idx, size in enumerate(stage_sizes[1:], start=1):
      stage = ResNetStage(
          size, nmid=base_width * 2 ** idx, first_stride=(2, 2),
          name=f'block{idx + 1}')
      x, out[f'stage{idx + 1}'] = stage(x)

    # Final norm + ReLU (pre-activation design), then average over axes 1, 2.
    x = out['norm_pre_head'] = GroupNorm(name='norm-pre-head')(x)
    x = out['pre_logits_2d'] = nn.relu(x)
    x = out['pre_logits'] = jnp.mean(x, axis=(1, 2))

    if not self.num_classes:
      return x, out

    head_kwargs = (
        {'kernel_init': nn.initializers.zeros} if self.head_zeroinit else {})
    head = nn.Dense(self.num_classes, name='head', **head_kwargs)
    # The same head is applied to the unpooled map and the pooled features.
    out['logits_2d'] = head(out['pre_logits_2d'])
    x = out['logits'] = head(out['pre_logits'])
    return x, out
|
|
|
|
|
|
def load(init_params, init_file, model_cfg, dont_load=()):
  """Loads the TF-dumped NumPy or big_vision checkpoint.

  Args:
    init_params: random init params; handed to `common.merge_params` together
      with `dont_load` (e.g. so that a new head is taken from them).
    init_file: comes from `config.model_init`. It is interpreted as one of:
      - a name whose first character is 'L', 'M' or 'S', describing one of the
        variants from the paper, e.g. "M", "L-imagenet2012" or
        "M-ILSVRC2012". With a '-', the first character is the upstream
        variant and the rest (including the '-') is a suffix, where
        "-imagenet2012" is an alias for "-ILSVRC2012". The file is
        gs://bit_models/BiT-<variant>-R<depth>x<width><suffix>.npz. Note that
        any other string starting with L/M/S is also treated this way.
      - one of the vanity names 'FunMatch-224px-i1k82.8' or
        'FunMatch-160px-i1k80.5' (distilled R50x1 checkpoints).
      - anything else, passed verbatim as the checkpoint path (e.g. an
        absolute path).
    model_cfg: the model configuration; its `depth` and `width` select the file
      for paper variants.
    dont_load: list of param names to be reset to init.

  Returns:
    The loaded parameters, converted from the legacy BiT layout if needed and
    merged with `init_params` via `common.merge_params`.
  """
  vanity = {
      'FunMatch-224px-i1k82.8': 'gs://bit_models/distill/R50x1_224.npz',
      'FunMatch-160px-i1k80.5': 'gs://bit_models/distill/R50x1_160.npz',
  }

  if init_file[0] in ('L', 'M', 'S'):
    # Paper variant: 'M' (upstream only) or 'M-<suffix>' (e.g. a fine-tune).
    up, down = ((init_file[0], init_file[1:]) if '-' in init_file
                else (init_file, ''))
    down = {'-imagenet2012': '-ILSVRC2012'}.get(down, down)  # Alias.
    fname = ('gs://bit_models/'
             f'BiT-{up}-R{model_cfg.depth}x{model_cfg.width}{down}.npz')
  else:
    fname = vanity.get(init_file, init_file)

  loaded = maybe_convert_big_transfer_format(u.load_params(fname))
  return common.merge_params(loaded, init_params, dont_load)
|
|
|
|
|
|
def maybe_convert_big_transfer_format(params_tf):
  """If the checkpoint comes from the legacy BiT codebase, converts it.

  Legacy checkpoints are recognized by a top-level 'resnet' key. They are
  flattened to '/'-joined names, and their variables are re-mapped onto this
  module's parameter tree:
  - the root conv goes to 'root_block/conv_root';
  - the final group norm goes to 'norm-pre-head';
  - the head goes to 'head';
  - the per-unit variables of block1..block4 go to conv1-3, gn1-3 and an
    optional conv_proj.

  GroupNorm beta/gamma vectors get three leading singleton axes
  (`[None, None, None]`). The head's conv kernel is reduced to a dense kernel
  by indexing `[0, 0]` on its first two axes.

  Args:
    params_tf: parameters as loaded from the checkpoint.

  Returns:
    `params_tf` itself if it is not in the legacy format, otherwise the
    converted nested parameter dict.

  Raises:
    KeyError: if an expected legacy variable is missing.
    ValueError: if a unit has more than one projection variable.
  """
  # Only do anything at all if the legacy format is recognized.
  if 'resnet' not in params_tf:
    return params_tf

  # Flatten to 'a/b/c' names for easy pattern matching.
  params_tf = dict(u.tree_flatten_with_names(params_tf)[0])

  # Some files name conv variables 'standardized_conv2d_<N>'; normalize them.
  for k in list(params_tf):
    k2 = re.sub(r'/standardized_conv2d_\d+/', '/standardized_conv2d/', k)
    if k2 != k:
      params_tf[k2] = params_tf[k]
      del params_tf[k]

  params = {
      'root_block': {'conv_root': {'kernel': params_tf[
          'resnet/root_block/standardized_conv2d/kernel']}},
      'norm-pre-head': {
          'bias': params_tf['resnet/group_norm/beta'][None, None, None],
          'scale': params_tf['resnet/group_norm/gamma'][None, None, None],
      },
      'head': {
          'kernel': params_tf['resnet/head/conv2d/kernel'][0, 0],
          'bias': params_tf['resnet/head/conv2d/bias'],
      },
  }

  for block in ('block1', 'block2', 'block3', 'block4'):
    # Discover units only under 'resnet/<block>/', the same prefix the lookups
    # in _convert_unit use. The trailing '/' keeps e.g. 'block1' from matching
    # 'block10', and keys without a unit are skipped rather than crashing.
    unit_re = re.compile(rf'resnet/{block}/(unit\d+)/')
    units = {m.group(1) for m in map(unit_re.match, params_tf) if m}
    # Sorted for a deterministic insertion order (set order is not).
    params[block] = {
        unit: _convert_unit(params_tf, block, unit) for unit in sorted(units)
    }

  return params


def _convert_unit(params_tf, block, unit):
  """Maps one legacy residual unit's flat variables onto a ResidualUnit tree."""
  prefix = f'resnet/{block}/{unit}'
  converted = {}
  # Legacy sub-blocks 'a', 'b', 'c' become conv1/gn1, conv2/gn2, conv3/gn3.
  for i, group in enumerate('abc', 1):
    converted[f'conv{i}'] = {
        'kernel': params_tf[f'{prefix}/{group}/standardized_conv2d/kernel']
    }
    converted[f'gn{i}'] = {
        'bias': params_tf[f'{prefix}/{group}/group_norm/beta'][None, None, None],
        'scale': params_tf[f'{prefix}/{group}/group_norm/gamma'][None, None, None],
    }

  # The projection shortcut, if present, lives under sub-block 'a'.
  proj_tag = f'{block}/{unit}/a/proj'
  projs = [k for k in params_tf if proj_tag in k]
  if len(projs) > 1:
    raise ValueError(
        f'Expected at most one projection variable for {block}/{unit}, '
        f'found {projs}.')
  if projs:
    converted['conv_proj'] = {'kernel': params_tf[projs[0]]}
  return converted
|
|
|