r"""A config for transferring vit-augreg.
|
|
|
|
Best HP selected on (mini)val, expected test results (repeated 5 times):
|
|
|
|
ViT-Augreg-B/32:
|
|
Dataset, crop, learning rate, mean (%), range (%)
|
|
- ImageNet, inception_crop, 0.03, 83.27, [83.22...83.33]
|
|
- Cifar10, resmall_crop, 0.003, 98.55, [98.46...98.6]
|
|
- Cifar100, resmall_crop, 0.01, 91.35, [91.09...91.62]
|
|
- Pets, inception_crop, 0.003, 93.78, [93.62...94.00]
|
|
- Flowers, inception_crop, 0.003, 99.43, [99.42...99.45]
|
|
|
|
|
|
Command to run:
|
|
big_vision.train \
|
|
--config big_vision/configs/transfer.py:model=vit-i21k-augreg-b/32,dataset=cifar10,crop=resmall_crop \
|
|
--workdir gs://$GS_BUCKET_NAME/big_vision/workdir/`date '+%m-%d_%H%M'` --config.lr=0.03
|
|
"""

import big_vision.configs.common as bvcc
import ml_collections as mlc


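# The classification head ('head/kernel', 'head/bias') is deliberately not
# loaded from the pre-trained checkpoint below, so it gets re-initialized to
# match the number of classes of the transfer dataset.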
def _set_model(config, model):
  """Load pre-trained models: vit or bit."""
  config.model_load = dict(dont_load=['head/kernel', 'head/bias'])

  if model == 'vit-i21k-augreg-b/32':
    config.model_name = 'vit'
    config.model_init = 'howto-i21k-B/32'
    config.model = dict(variant='B/32', pool_type='tok')
  elif model == 'vit-i21k-augreg-l/16':
    config.model_name = 'vit'
    config.model_init = 'howto-i21k-L/16'
    config.model = dict(variant='L/16', pool_type='tok')
  elif model == 'vit-s16':
    config.model_name = 'vit'
    config.model_init = 'i1k-s16-300ep'
    config.model = dict(variant='S/16', pool_type='gap', posemb='sincos2d',
                        rep_size=True)
  elif model == 'bit-m-r50x1':
    config.model_name = 'bit_paper'
    config.model_init = 'M'
    config.model = dict(depth=50, width=1)
  else:
    raise ValueError(f'Unknown model: {model}, please define customized model.')


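# Split convention: a small slice held out from "train" serves as the (mini)val
# split used for hyper-parameter selection (see the module docstring), while
# the official validation/test split is kept for reporting.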
def _set_dataset(config, dataset, crop='inception_crop', h_res=448, l_res=384):
  if dataset == 'cifar10':
    _set_task(config, 'cifar10', 'train[:98%]', 'train[98%:]', 'test', 10, steps=10_000, warmup=500, crop=crop, h_res=h_res, l_res=l_res)
  elif dataset == 'cifar100':
    _set_task(config, 'cifar100', 'train[:98%]', 'train[98%:]', 'test', 100, steps=10_000, warmup=500, crop=crop, h_res=h_res, l_res=l_res)
  elif dataset == 'imagenet2012':
    _set_task(config, 'imagenet2012', 'train[:99%]', 'train[99%:]', 'validation', 1000, steps=20_000, warmup=500, crop=crop, h_res=h_res, l_res=l_res)
    _set_imagenet_variants(config)
  elif dataset == 'oxford_iiit_pet':
    _set_task(config, 'oxford_iiit_pet', 'train[:90%]', 'train[90%:]', 'test', 37, steps=500, warmup=100, crop=crop, h_res=h_res, l_res=l_res)
  elif dataset == 'oxford_flowers102':
    _set_task(config, 'oxford_flowers102', 'train[:90%]', 'train[90%:]', 'test', 102, steps=500, warmup=100, crop=crop, h_res=h_res, l_res=l_res)
  else:
    raise ValueError(
        f'Unknown dataset: {dataset}, please define customized dataset.')


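# Pre-processing is specified as a string of ops chained with '|': the training
# pipeline uses the requested crop (plus an optional horizontal flip), while
# evaluation always uses resize_small followed by a central crop.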
def _set_task(config, dataset, train, val, test, n_cls,
              steps=20_000, warmup=500, lbl='label', crop='resmall_crop',
              flip=True, h_res=448, l_res=384):
  """Vision task with val and test splits."""
  config.total_steps = steps
  config.schedule = dict(
      warmup_steps=warmup,
      decay_type='cosine',
  )

  config.input.data = dict(name=dataset, split=train)
  pp_common = (
      '|value_range(-1, 1)|'
      f'onehot({n_cls}, key="{lbl}", key_result="labels")|'
      'keep("image", "labels")'
  )

  if crop == 'inception_crop':
    pp_train = f'decode|inception_crop({l_res})'
  elif crop == 'resmall_crop':
    pp_train = f'decode|resize_small({h_res})|random_crop({l_res})'
  elif crop == 'resize_crop':
    pp_train = f'decode|resize({h_res})|random_crop({l_res})'
  else:
    raise ValueError(f'Unknown crop: {crop}. Must be one of: '
                     'inception_crop, resmall_crop, resize_crop')
  if flip:
    pp_train += '|flip_lr'
  config.input.pp = pp_train + pp_common

  pp = f'decode|resize_small({h_res})|central_crop({l_res})' + pp_common
  config.num_classes = n_cls

  def get_eval(split):
    return dict(
        type='classification',
        data=dict(name=dataset, split=split),
        loss_name='softmax_xent',
        log_steps=100,
        pp_fn=pp,
    )
  config.evals = dict(val=get_eval(val), test=get_eval(test))


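# For ImageNet, the held-out train slice becomes "minival" and the official
# validation split is reported as "val"; ImageNet-ReaL and ImageNet-V2 reuse
# the evaluation pre-processing with the label key substituted via pp.format().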
def _set_imagenet_variants(config, h_res=448, l_res=384):
  """Evaluation tasks on ImageNet variants: v2 and real."""
  pp = (f'decode|resize_small({h_res})|central_crop({l_res})'
        '|value_range(-1, 1)|onehot(1000, key="{lbl}", key_result="labels")|'
        'keep("image", "labels")')

  config.evals.minival = config.evals.val
  config.evals.val = config.evals.test

  config.evals.real = dict(type='classification')
  config.evals.real.data = dict(name='imagenet2012_real', split='validation')
  config.evals.real.pp_fn = pp.format(lbl='real_label')
  config.evals.real.loss_name = config.loss
  config.evals.real.log_steps = 100

  config.evals.v2 = dict(type='classification')
  config.evals.v2.data = dict(name='imagenet_v2', split='test')
  config.evals.v2.pp_fn = pp.format(lbl='label')
  config.evals.v2.loss_name = config.loss
  config.evals.v2.log_steps = 100


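# Entry point called by the trainer. `arg` is the "key=value,..." string that
# follows the colon in `--config .../transfer.py:model=...,dataset=...`; it is
# parsed against the defaults below by bvcc.parse_arg.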
def get_config(arg=None):
  """Config for adaptation."""
  arg = bvcc.parse_arg(arg, model='vit', dataset='cifar10', crop='resmall_crop',
                       h_res=448, l_res=384, batch_size=512, fsdp=False,
                       runlocal=False)
  config = mlc.ConfigDict()

  config.input = {}
  config.input.batch_size = arg.batch_size if not arg.runlocal else 8
  config.input.shuffle_buffer_size = 50_000 if not arg.runlocal else 100

  config.log_training_steps = 10
  config.ckpt_steps = 1000
  config.ckpt_timeout = 600

  config.optax_name = 'big_vision.momentum_hp'
  config.grad_clip_norm = 1.0
  config.wd = None
  config.loss = 'softmax_xent'
  config.lr = 0.01
  config.mixup = dict(p=0.0)

  config.seed = 0

  _set_dataset(config, arg.dataset, arg.crop, arg.h_res, arg.l_res)
  _set_model(config, arg.model)

  if arg.fsdp:
    config.mesh = [('data', -1)]
    config.sharding_strategy = [('.*', 'fsdp(axis="data")')]
    config.sharding_rules = [('act_batch', ('data',))]
    config.model.scan = True

  return config
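

# A minimal local sanity-check sketch (not used by the trainer); it assumes the
# big_vision package is importable. It builds the config for the example command
# in the module docstring and prints a few of the resolved fields.
if __name__ == '__main__':
  cfg = get_config('model=vit-i21k-augreg-b/32,dataset=cifar10,crop=resmall_crop')
  print(cfg.model_name, cfg.model_init, cfg.total_steps, cfg.lr)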