r"""Trains CLIP with Pixels Only (CLIPPO), https://arxiv.org/abs/2212.08045
|
|
|
|
IMPORTANT NOTE: This config uses coco_captions by default for demonstration
|
|
purposes since the TFDS catalog does not provide any large image/alt-text data
|
|
set; the training will not produce a model with useful accuracy. Please
|
|
replace the data set below (marked by a comment) with an appropriate image/
|
|
alt-text data set wrapped in TFDS (for example LAION-400M) and run the config
|
|
with the suffix `:test_with_coco=False` to train on your data set. Refer to
|
|
the following guide to build a TFDS wrapper for your favorite image/alt-text
|
|
data set:
|
|
https://www.tensorflow.org/datasets/add_dataset
|
|
|
|
Also note that evaluation on ImageNet requires manual TFDS setup, see
|
|
https://github.com/google-research/big_vision#preparing-tfds-data
|
|
|
|
|
|
Example training:
|
|
|
|
big_vision.trainers.proj.image_text.contrastive \
|
|
--config big_vision/configs/proj/clippo/train_clippo.py \
|
|
--workdir gs://[your_bucket]/big_vision/`date '+%Y-%m-%d_%H%M'`
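For example, to train on your own data set once it is plugged in below, pass
the config suffix mentioned above:

big_vision.trainers.proj.image_text.contrastive \
    --config big_vision/configs/proj/clippo/train_clippo.py:test_with_coco=False \
    --workdir gs://[your_bucket]/big_vision/`date '+%Y-%m-%d_%H%M'`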

"""

import big_vision.configs.common as bvcc
from big_vision.configs.common_fewshot import get_fewshot_lsr
from big_vision.configs.proj.image_text import common
from ml_collections import ConfigDict


def get_config(arg=None):
  """The base configuration."""
  arg = bvcc.parse_arg(
      arg, res=224, runlocal=False, variant='B/16',
      test_with_coco=True, i1k_eval=False)
  config = ConfigDict()

  config.input = {}
  if arg.test_with_coco:
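    # COCO captions are used here only as a small stand-in data set for
    # demonstration/sanity-checking; they will not yield a useful model.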
    config.input.data = dict(name='coco_captions', split='train')
    val_data = dict(config.input.data)
    val_data['split'] = 'val'
    config.input.batch_size = 4000 if not arg.runlocal else 32
    config.input.shuffle_buffer_size = 50_000 if not arg.runlocal else 50
    config.total_steps = 400 if not arg.runlocal else 10
  else:
    config.input.data = None
    val_data = None
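    # Plug in a large image/alt-text data set wrapped in TFDS by replacing the
    # two placeholders above. The name below is purely illustrative and assumes
    # you have built your own TFDS wrapper (see the docstring):
    # config.input.data = dict(name='my_image_alt_text_dataset', split='train')
    # val_data = dict(name='my_image_alt_text_dataset', split='validation')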
    assert config.input.data is not None and val_data is not None, (
        config.input.data, val_data)
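
    # Defaults for a full training run; adjust the batch size to what fits
    # your accelerator setup.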
    config.input.batch_size = 8 * 1024 if not arg.runlocal else 32
    config.input.shuffle_buffer_size = 250_000 if not arg.runlocal else 50
    config.total_steps = 100_000 if not arg.runlocal else 10

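  # CLIPPO has no text tokenizer: captions are rendered as images with the
  # unifont bitmap font and then encoded by the same ViT as the actual images.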
  def tokenizer(inkey, outkey='labels'):
    return (f'render_unifont('
            f'inkey="{inkey}", '
            f'outkey="{outkey}", '
            f'image_size={arg.res}, '
            f'lower=True, '
            f'font_size=16, '
            f'text_brightness=0, '
            f'background_brightness=127)|'
            f'value_range(-1, 1, inkey="{outkey}", outkey="{outkey}")')

  pp_image = f'decode|resize({arg.res})|value_range(-1,1)'
  if arg.test_with_coco:
    pp_image_aug = (
        f'decode|resize({arg.res})|flip_lr|randaug(2,10)|value_range(-1,1)')
    config.input.pp = pp_eval = (
        f'{pp_image_aug}|flatten|{tokenizer("captions/text")}|'
        f'keep("image", "labels")')
  else:
    config.input.pp = pp_eval = (
        f'{pp_image}|flatten|{tokenizer("text")}|keep("image", "labels")')

  config.pp_modules = [
      'ops_general', 'ops_image', 'ops_text', 'proj.clippo.pp_ops']

  config.log_training_steps = 50
  config.ckpt_steps = 1000
  config.keep_ckpt_steps = 5000

  config.loss_use_global_batch = True

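  # Model: a single shared ViT tower ("one tower") encodes both the images and
  # the rendered text images, in contrast to CLIP's separate image/text towers.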
  config.model_name = 'proj.clippo.one_tower'

  config.model = ConfigDict()
  config.model.image_model = 'vit'
  config.model.image = ConfigDict({
      'variant': arg.variant,
      'pool_type': 'map',
      'head_zeroinit': False,
  })

  if arg.test_with_coco:
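    # For the COCO sanity-check, initialize the tower from a pretrained
    # ImageNet-21k ViT-B/16 checkpoint, skipping the head and pooling weights.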
    assert arg.variant == 'B/16', arg.variant
    config.model_init = {'image': 'howto-i21k-B/16'}
    config.model_load = {}
    config.model_load['img_load_kw'] = {
        'dont_load': ['^head/.*', '^MAPHead_0/.*', 'cls']}

  config.model.temperature_init = 10.0
  config.model.out_dim = 768

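  # Optimizer: big_vision's Adafactor variant with gradient clipping and an
  # rsqrt decay schedule (shorter schedule for the COCO sanity-check).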
  config.optax_name = 'big_vision.scale_by_adafactor'
  config.grad_clip_norm = 1.0

  if arg.test_with_coco:
    config.lr = 0.0001
    config.wd = 0.0003
    config.schedule = dict(decay_type='rsqrt',
                           timescale=100,
                           warmup_steps=100 if not arg.runlocal else 5,
                           cooldown_steps=100 if not arg.runlocal else 5)
  else:
    config.lr = 0.001
    config.wd = 0.0001
    config.schedule = dict(decay_type='rsqrt',
                           timescale=10_000,
                           warmup_steps=10_000 if not arg.runlocal else 5,
                           cooldown_steps=10_000 if not arg.runlocal else 5)

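  # Evaluators: contrastive validation loss, COCO retrieval, optional ImageNet
  # zero-shot classification, and few-shot linear probes.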
  eval_common = dict(
      type='proj.image_text.contrastive',
      use_global_batch=config.loss_use_global_batch,
      log_steps=1000 if not arg.runlocal else 5,
  )
  config.evals = {}
  sub = '[:4]' if arg.runlocal else ''
  config.evals.val = {
      **eval_common,
      'data': val_data,
      'pp_fn': pp_eval,
  }
  config.evals.coco = {
      **eval_common,
      'data': dict(name='coco_captions', split=f'val{sub}'),
      'pp_fn': (
          f'{pp_image}|flatten|{tokenizer("captions/text")}|'
          f'keep("image", "labels")'),
  }

  if arg.i1k_eval:
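    # These evaluations need the imagenet2012 TFDS data set, which requires a
    # manual setup, see
    # https://github.com/google-research/big_vision#preparing-tfds-data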
    config.evals.imagenet = {
        **eval_common,
        'data': dict(name='imagenet2012', split=f'validation{sub}'),
        'pp_fn': (
            f'{pp_image}|clip_i1k_label_names|'
            f'{tokenizer("labels")}|keep("image", "labels")'),
    }
    config.evals.disclf = dict(
        type='proj.image_text.discriminative_classifier',
        pp_txt=tokenizer('texts', 'labels'),
        prefix='z/0shot/',
        log_steps=5_000 if not arg.runlocal else 5)

  config.evals.retrieval_coco = common.get_coco(
      pp_img=f'resize({arg.res})|value_range(-1, 1)',
      pp_txt=tokenizer('texts'),
      log_steps=5_000 if not arg.runlocal else 5,
  )

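  # Few-shot linear probes on the frozen pre-logits representation.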
  config.evals.fewshot = get_fewshot_lsr()
  config.evals.fewshot.log_steps = 5_000 if not arg.runlocal else 5
  config.evals.fewshot.representation_layer = 'img/pre_logits'

  config.seed = 0

  return config