# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BiT models as in the paper (ResNet V2) w/ loading of public weights.
See reproduction proof: http://(internal link)/qY70qs6j944
"""
import functools
import re
from typing import Optional, Sequence, Union

from big_vision import utils as u
from big_vision.models import bit
from big_vision.models import common
import flax.linen as nn
import jax.numpy as jnp


def standardize(x, axis, eps):
  """Subtracts the mean and divides by the root-mean-square along `axis`."""
  x = x - jnp.mean(x, axis=axis, keepdims=True)
  x = x / jnp.sqrt(jnp.mean(jnp.square(x), axis=axis, keepdims=True) + eps)
  return x
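

# Illustrative sketch, not used anywhere in this file: `standardize` maps any
# input to (approximately) zero mean and unit variance along `axis`. The
# concrete values below are assumptions for illustration only.
#
#   v = standardize(jnp.array([1., 2., 3., 4.]), axis=-1, eps=1e-5)
#   # jnp.mean(v) is ~0 and jnp.mean(jnp.square(v)) is ~1.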


# We define our own GroupNorm instead of using flax.linen's, because we compute
# the normalizing variance slightly differently, and that difference does
# affect results when loading the pre-trained weights!
class GroupNorm(nn.Module):
"""Group normalization (arxiv.org/abs/1803.08494)."""
ngroups: int = 32
@nn.compact
def __call__(self, x):
input_shape = x.shape
group_shape = x.shape[:-1] + (self.ngroups, x.shape[-1] // self.ngroups)
x = x.reshape(group_shape)
# Standardize along spatial and group dimensions
x = standardize(x, axis=[1, 2, 4], eps=1e-5)
x = x.reshape(input_shape)
bias_scale_shape = tuple([1, 1, 1] + [input_shape[-1]])
x = x * self.param('scale', nn.initializers.ones, bias_scale_shape)
x = x + self.param('bias', nn.initializers.zeros, bias_scale_shape)
return x
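

# A minimal usage sketch (the shapes and `import jax` are assumptions of the
# example): the input is NHWC with a channel count divisible by `ngroups`.
#
#   gn = GroupNorm(ngroups=32)
#   y, variables = gn.init_with_output(
#       jax.random.PRNGKey(0), jnp.ones([2, 8, 8, 64]))
#   # y.shape == (2, 8, 8, 64); 'scale' and 'bias' have shape (1, 1, 1, 64).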


class StdConv(nn.Conv):
  """Convolution with weight standardization (arxiv.org/abs/1903.10520)."""

  def param(self, name, *a, **kw):
param = super().param(name, *a, **kw)
if name == 'kernel':
param = standardize(param, axis=[0, 1, 2], eps=1e-10)
return param
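

# Note: the stored 'kernel' parameter stays raw; it is standardized over the
# spatial and input-channel axes each time it is read, so the applied filters
# always have zero mean and unit variance. A small sketch (assumed shapes):
#
#   conv = StdConv(features=8, kernel_size=(3, 3), use_bias=False)
#   y, variables = conv.init_with_output(
#       jax.random.PRNGKey(0), jnp.ones([1, 16, 16, 4]))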


class RootBlock(nn.Module):
"""Root block of ResNet."""
width: int
@nn.compact
def __call__(self, x):
x = StdConv(self.width, (7, 7), (2, 2), padding=[(3, 3), (3, 3)],
use_bias=False, name='conv_root')(x)
x = nn.max_pool(x, (3, 3), strides=(2, 2), padding=[(1, 1), (1, 1)])
return x
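

# Shape sketch (the 224px input is an assumption): the root block reduces the
# spatial resolution by 4x.
#
#   root = RootBlock(width=64)
#   y, variables = root.init_with_output(
#       jax.random.PRNGKey(0), jnp.ones([1, 224, 224, 3]))
#   # y.shape == (1, 56, 56, 64): 7x7/2 convolution, then 3x3/2 max-pool.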


class ResidualUnit(nn.Module):
"""Bottleneck ResNet block."""
nmid: Optional[int] = None
strides: Sequence[int] = (1, 1)
@nn.compact
def __call__(self, x):
nmid = self.nmid or x.shape[-1] // 4
nout = nmid * 4
conv = functools.partial(StdConv, use_bias=False)
residual = x
x = GroupNorm(name='gn1')(x)
x = nn.relu(x)
if x.shape[-1] != nout or self.strides != (1, 1):
residual = conv(nout, (1, 1), self.strides, name='conv_proj')(x)
x = conv(nmid, (1, 1), name='conv1')(x)
x = GroupNorm(name='gn2')(x)
x = nn.relu(x)
x = conv(nmid, (3, 3), self.strides, padding=[(1, 1), (1, 1)],
name='conv2')(x)
x = GroupNorm(name='gn3')(x)
x = nn.relu(x)
x = conv(nout, (1, 1), name='conv3')(x)
return x + residual
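

# Shape sketch (all concrete numbers are assumptions): `nmid` defaults to a
# quarter of the input channels, and the unit emits `4 * nmid` channels.
#
#   unit = ResidualUnit(nmid=64, strides=(2, 2))
#   y, variables = unit.init_with_output(
#       jax.random.PRNGKey(0), jnp.ones([1, 56, 56, 64]))
#   # y.shape == (1, 28, 28, 256); shapes change, so the residual branch
#   # goes through the 1x1 'conv_proj'.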


class ResNetStage(nn.Module):
"""A stage (sequence of same-resolution blocks)."""
block_size: int
nmid: Optional[int] = None
first_stride: Sequence[int] = (1, 1)
@nn.compact
def __call__(self, x):
out = {}
x = out['unit01'] = ResidualUnit(
self.nmid, strides=self.first_stride, name='unit01')(x)
for i in range(1, self.block_size):
x = out[f'unit{i+1:02d}'] = ResidualUnit(
self.nmid, name=f'unit{i+1:02d}')(x)
return x, out
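

# Usage sketch (sizes are assumptions): a stage returns both its output and a
# dict of per-unit intermediates.
#
#   stage = ResNetStage(block_size=3, nmid=64)
#   (y, unit_outs), variables = stage.init_with_output(
#       jax.random.PRNGKey(0), jnp.ones([1, 56, 56, 256]))
#   # y.shape == (1, 56, 56, 256); unit_outs has keys 'unit01'..'unit03'.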


class Model(nn.Module):
"""ResNetV2."""
num_classes: Optional[int] = None
width: int = 1
depth: Union[int, Sequence[int]] = 50 # 50/101/152, or list of block depths.
head_zeroinit: bool = True
@nn.compact
def __call__(self, image, *, train=False):
blocks = bit.get_block_desc(self.depth)
width = int(64 * self.width)
out = {}
x = out['stem'] = RootBlock(width=width, name='root_block')(image)
# Blocks
x, out['stage1'] = ResNetStage(blocks[0], nmid=width, name='block1')(x)
for i, block_size in enumerate(blocks[1:], 1):
x, out[f'stage{i + 1}'] = ResNetStage(
block_size, width * 2 ** i,
first_stride=(2, 2), name=f'block{i + 1}')(x)
# Pre-head
x = out['norm_pre_head'] = GroupNorm(name='norm-pre-head')(x)
x = out['pre_logits_2d'] = nn.relu(x)
x = out['pre_logits'] = jnp.mean(x, axis=(1, 2))
# Head
if self.num_classes:
kw = {'kernel_init': nn.initializers.zeros} if self.head_zeroinit else {}
head = nn.Dense(self.num_classes, name='head', **kw)
out['logits_2d'] = head(out['pre_logits_2d'])
x = out['logits'] = head(out['pre_logits'])
return x, out
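

# End-to-end usage sketch (input resolution, class count, and `import jax` are
# assumptions of the example): a BiT-R50x1 classifier and one forward pass.
#
#   model = Model(num_classes=1000, width=1, depth=50)
#   variables = model.init(jax.random.PRNGKey(0), jnp.zeros([1, 224, 224, 3]))
#   logits, out = model.apply(variables, jnp.zeros([1, 224, 224, 3]))
#   # logits.shape == (1, 1000); `out` also exposes intermediates such as
#   # out['stem'], out['stage1'], ..., out['pre_logits'].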


def load(init_params, init_file, model_cfg, dont_load=()):
"""Loads the TF-dumped NumPy or big_vision checkpoint.
Args:
init_params: random init params from which the new head is taken.
init_file: comes from `config.model_init`, can either be an absolute
path (ie starts with /) to the checkpoint, or a string like
"L-imagenet2012" describing one of the variants from the paper.
model_cfg: the model configuration.
dont_load: list of param names to be reset to init.
Returns:
The loaded parameters.
"""
# Support for vanity model names from the paper.
vanity = {
'FunMatch-224px-i1k82.8': 'gs://bit_models/distill/R50x1_224.npz',
'FunMatch-160px-i1k80.5': 'gs://bit_models/distill/R50x1_160.npz',
}
if init_file[0] in ('L', 'M', 'S'): # The models from the original paper.
# Supported names are of the following type:
# - 'M' or 'S': the original "upstream" model without fine-tuning.
# - 'M-ILSVRC2012': i21k model fine-tuned on i1k.
# - 'M-run0-caltech101': i21k model fine-tuned on VTAB's caltech101.
# each VTAB fine-tuning was run 3x, so there's run0, run1, run2.
if '-' in init_file:
up, down = init_file[0], init_file[1:]
else:
up, down = init_file, ''
down = {'-imagenet2012': '-ILSVRC2012'}.get(down, down) # normalize
fname = f'BiT-{up}-R{model_cfg.depth}x{model_cfg.width}{down}.npz'
fname = f'gs://bit_models/{fname}'
else:
fname = vanity.get(init_file, init_file)
params = u.load_params(fname)
params = maybe_convert_big_transfer_format(params)
return common.merge_params(params, init_params, dont_load)
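

# Usage sketch: `model_cfg` is assumed to expose `.depth` and `.width` (e.g. a
# config dict with depth=50, width=1), so the name below resolves to
# gs://bit_models/BiT-M-R50x1-ILSVRC2012.npz.
#
#   params = load(init_params, 'M-ILSVRC2012', model_cfg)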


def maybe_convert_big_transfer_format(params_tf):
"""If the checkpoint comes from legacy codebase, convert it."""
# Only do anything at all if we recognize the format.
if 'resnet' not in params_tf:
return params_tf
# For ease of processing and backwards compatibility, flatten again:
params_tf = dict(u.tree_flatten_with_names(params_tf)[0])
# Works around some files containing weird naming of variables:
for k in list(params_tf):
k2 = re.sub('/standardized_conv2d_\\d+/', '/standardized_conv2d/', k)
if k2 != k:
params_tf[k2] = params_tf[k]
del params_tf[k]
params = {
'root_block': {'conv_root': {'kernel': params_tf[
'resnet/root_block/standardized_conv2d/kernel']}},
'norm-pre-head': {
'bias': params_tf['resnet/group_norm/beta'][None, None, None],
'scale': params_tf['resnet/group_norm/gamma'][None, None, None],
},
'head': {
'kernel': params_tf['resnet/head/conv2d/kernel'][0, 0],
'bias': params_tf['resnet/head/conv2d/bias'],
}
}
for block in ('block1', 'block2', 'block3', 'block4'):
params[block] = {}
    units = {re.findall(r'unit\d+', p)[0]
             for p in params_tf if block in p}
for unit in units:
params[block][unit] = {}
for i, group in enumerate('abc', 1):
params[block][unit][f'conv{i}'] = {
'kernel': params_tf[f'resnet/{block}/{unit}/{group}/standardized_conv2d/kernel'] # pylint: disable=line-too-long
}
params[block][unit][f'gn{i}'] = {
'bias': params_tf[f'resnet/{block}/{unit}/{group}/group_norm/beta'][None, None, None], # pylint: disable=line-too-long
'scale': params_tf[f'resnet/{block}/{unit}/{group}/group_norm/gamma'][None, None, None], # pylint: disable=line-too-long
}
      projs = [p for p in params_tf
               if f'{block}/{unit}/a/proj' in p]
assert len(projs) <= 1
if projs:
params[block][unit]['conv_proj'] = {
'kernel': params_tf[projs[0]]
}
return params
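

# For reference, the converted tree mirrors this module's variable layout; a
# few illustrative paths (shapes assume an RGB input and R50x1):
#
#   params['root_block']['conv_root']['kernel']  # (7, 7, 3, 64)
#   params['block1']['unit01']['gn1']['scale']   # (1, 1, 1, 64)
#   params['norm-pre-head']['bias']              # (1, 1, 1, 2048)
#   params['head']['kernel']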