"""Model definitions for CapPa (https://arxiv.org/abs/2306.07915).
|
|
|
|
Used abbreviations for dimension annotations:
|
|
B: batch size.
|
|
H: image height.
|
|
W: image width.
|
|
P: number of patches (PH/PW: number of patches in height/width dimensions).
|
|
E: embedding size.
|
|
L: sequence length of text tokens.
|
|
V: vocab size.
|
|
"""

from collections.abc import Sequence

from big_vision import utils
from big_vision.models import common
from big_vision.models import vit
import flax
import flax.linen as nn
from flax.linen import partitioning
import jax
import jax.numpy as jnp


def shift_right(x, axis=1, constant_values=0):
  """Shift to the right on given axis with padding value 0."""
  pad_widths = [(0, 0)] * len(x.shape)
  pad_widths[axis] = (1, 0)
  padded = jnp.pad(x, pad_widths, constant_values=constant_values)
  return padded[tuple(slice(-1 if i == axis else None) for i in range(x.ndim))]
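

# Example (illustration only): shift_right pads one `constant_values` entry at
# the front of `axis` and drops the last entry, keeping the input shape, e.g.
#   shift_right(jnp.array([[3, 1, 4], [1, 5, 9]]))
#   -> [[0, 3, 1], [0, 1, 5]]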


class MlpBlock(nn.Module):
  """Transformer MLP / feed-forward block with option to deactivate bias."""
  mlp_dim: int | None = None
  dropout: float = 0.0
  use_bias: bool = True

  @nn.compact
  def __call__(self, x, deterministic=True):
    """Applies Transformer MlpBlock module."""
    inits = dict(
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6),
    )

    _, _, d = x.shape
    x = nn.Dense(self.mlp_dim or 4 * d, use_bias=self.use_bias, **inits)(x)
    x = nn.gelu(x)
    x = nn.Dropout(rate=self.dropout)(x, deterministic)
    x = nn.Dense(d, use_bias=self.use_bias, **inits)(x)
    return x
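

# Usage sketch (illustrative values): the block maps [B, L, E] -> [B, L, E];
# with mlp_dim=None the hidden width defaults to 4 * E.
#   mlp = MlpBlock(dropout=0.1)
#   x = jnp.zeros((2, 16, 512))
#   params = mlp.init(jax.random.PRNGKey(0), x)
#   y = mlp.apply(params, x)  # y.shape == (2, 16, 512)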


class EncoderDecoderBlock(nn.Module):
  """Transformer encoder-decoder layer."""
  mlp_dim: int
  num_heads: int
  dropout_rate: float = 0.
  decode: bool = False
  use_bias: bool = True

  @nn.compact
  def __call__(self, targets, encoded, decoder_mask=None, deterministic=True):
    """Applies EncoderDecoderBlock module.

    Args:
      targets: target text embeddings [B, L, E].
      encoded: encoded image patches from encoder [B, P, E].
      decoder_mask: decoder self-attention mask.
      deterministic: bool, deterministic or not (to apply dropout).

    Returns:
      output after transformer encoder-decoder block [B, L, E].
    """
    def wlc(f):
      dim_names = ("act_batch", "act_len", "act_emb")
      return nn.with_logical_constraint(f, dim_names)

    # Decoder self-attention block.
    x = wlc(nn.LayerNorm(name="LayerNorm1", use_bias=self.use_bias)(targets))
    x = wlc(nn.SelfAttention(
        num_heads=self.num_heads, use_bias=False, broadcast_dropout=False,
        dropout_rate=self.dropout_rate, decode=self.decode, name="SelfAttn")(
            x, decoder_mask, deterministic=deterministic))
    x = wlc(nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic))
    x = wlc(x + targets)

    if encoded is not None:
      # Cross-attention into the encoded image patches.
      y = wlc(nn.LayerNorm(name="LayerNorm2", use_bias=self.use_bias)(x))
      y = wlc(nn.MultiHeadDotProductAttention(
          num_heads=self.num_heads, use_bias=False, broadcast_dropout=False,
          dropout_rate=self.dropout_rate, name="CrossAttn")(
              y, encoded, deterministic=deterministic))
      y = wlc(
          nn.Dropout(rate=self.dropout_rate)(y, deterministic=deterministic))
      y = wlc(y + x)
    else:
      # Decoder-only block (no cross-attention).
      y = x

    # MLP block.
    z = wlc(nn.LayerNorm(name="LayerNorm3", use_bias=self.use_bias)(y))
    z = wlc(MlpBlock(
        mlp_dim=self.mlp_dim, dropout=self.dropout_rate, use_bias=self.use_bias,
        name="MLP")(z, deterministic=deterministic))

    return wlc(y + z), None
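

# Usage sketch (illustrative shapes; assumes no logical axis rules are set, in
# which case the logical sharding constraints are no-ops):
#   blk = EncoderDecoderBlock(mlp_dim=2048, num_heads=8)
#   t = jnp.zeros((2, 16, 512))    # target text embeddings [B, L, E]
#   e = jnp.zeros((2, 196, 512))   # encoded image patches [B, P, E]
#   mask = nn.make_causal_mask(jnp.zeros((2, 16)))
#   params = blk.init(jax.random.PRNGKey(0), t, e, mask)
#   y, _ = blk.apply(params, t, e, mask)  # y.shape == (2, 16, 512)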


class Decoder(nn.Module):
  """Transformer decoder with parallel prediction."""
  emb_dim: int
  mlp_dim: int
  num_heads: int
  num_layers: int
  dropout_rate: float = 0.
  output_vocab_size: int = 32_000

  # Probability of replacing the autoregressive (shifted, causal) prediction
  # task with parallel (masked) prediction for a given sequence, and the
  # fraction of tokens masked out in that case.
  masked_pred_prob: float = 0.
  masking_ratio: float = 0.

  # Whether to use bias in MLP blocks and LayerNorm.
  use_bias: bool = True

  # If True, stack the decoder layers with nn.scan; remat_policy selects the
  # jax.checkpoint policy used for rematerialization.
  scan: bool = False
  remat_policy: str = "nothing_saveable"

  @nn.compact
  def __call__(self,
               encoded,
               targets,
               pos_emb,
               decoder_mask=None,
               decode=False,
               deterministic=True,
               max_decode_length=None):
    """Applies Transformer model on the inputs.

    Args:
      encoded: encoded image patches from encoder [B, P, E].
      targets: target text tokens [B, L].
      pos_emb: positional embeddings.
      decoder_mask: decoder self-attention mask.
      decode: bool, whether to perform fast autoregressive decoding with cache.
      deterministic: bool, deterministic or not (to apply dropout).
      max_decode_length: optional max length for positional embeddings.

    Returns:
      output of a transformer decoder [B, L, V].
    """
    y = targets.astype("int32")
    if not decode:
      if self.masked_pred_prob > 0.0 and not deterministic:
        # For a random subset of the sequences in the batch, train with
        # parallel (masked) prediction instead of autoregressive prediction.

        def _add_random_masks(a):
          # Replace `masking_ratio` of the token positions (chosen uniformly
          # at random per sequence) with the mask token id (= vocab size).
          n_masked = int(self.masking_ratio * a.shape[1])
          mask_locations = jnp.zeros(a.shape[:2], dtype=jnp.int32)
          mask_locations = mask_locations.at[:, :n_masked].set(1)
          mask_locations = jax.random.permutation(
              self.make_rng("dropout"), mask_locations, axis=1, independent=True
          )
          a_masked = jnp.where(mask_locations, self.output_vocab_size, a)
          return a_masked

        def where(mask, x, y):
          mask = mask.reshape((-1,) + (1,) * (x.ndim - 1))
          return jnp.where(mask, x, y)

        do_masked_pred = (
            jax.random.uniform(self.make_rng("dropout"), (len(y),))
            < self.masked_pred_prob
        )
        y = where(do_masked_pred, _add_random_masks(y), shift_right(y))
        decoder_mask = where(
            do_masked_pred, jnp.ones_like(decoder_mask), decoder_mask
        )

      else:
        y = shift_right(y)
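
    # Illustration: with masking_ratio=0.25 and L=8, a sequence selected for
    # masked prediction gets int(0.25 * 8) = 2 positions replaced by the mask
    # token and an all-ones decoder_mask (full bidirectional attention); every
    # other sequence keeps shift_right() inputs and the causal mask, i.e.
    # standard autoregressive captioning.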

    embed = nn.Embed(
        self.output_vocab_size + (1 if self.masked_pred_prob > 0.0 else 0),
        self.emb_dim,
        name="EmbedTargets",
        embedding_init=nn.initializers.normal(stddev=1.0),
    )
    y = embed(y)

    y = common.AddPositionEmbs(
        decode=decode, name="PosEmbedTargets")(y, pos_emb)

    if self.scan:
      # Wrap the block in remat so activations are rematerialized in the
      # backward pass according to `remat_policy`.
      enc_dec_block_remat = nn.remat(
          EncoderDecoderBlock,
          prevent_cse=False,
          static_argnums=(-1,),
          policy=getattr(jax.checkpoint_policies, self.remat_policy, None))
      # Scan the block over the layers, stacking its params along axis 1.
      initializing = self.is_mutable_collection("params")
      param_scan_axis = 1
      params_spec = (param_scan_axis if initializing
                     else partitioning.ScanIn(param_scan_axis))
      dec_scanned = nn.scan(enc_dec_block_remat,
                            variable_axes={
                                "params": params_spec,
                                "cache": 0,
                            },
                            split_rngs={"params": True, "dropout": True},
                            in_axes=nn.broadcast,
                            length=self.num_layers)
      y, _ = dec_scanned(num_heads=self.num_heads, mlp_dim=self.mlp_dim,
                         dropout_rate=self.dropout_rate, decode=decode,
                         use_bias=self.use_bias, name="EncDecBlock")(
                             y, encoded, decoder_mask, deterministic)
    else:
      for lyr in range(self.num_layers):
        y, _ = EncoderDecoderBlock(
            num_heads=self.num_heads, mlp_dim=self.mlp_dim,
            dropout_rate=self.dropout_rate, decode=decode,
            use_bias=self.use_bias, name=f"EncDecBlock{lyr}")(
                y, encoded, decoder_mask=decoder_mask,
                deterministic=deterministic)

    y = nn.LayerNorm(name="LayerNorm")(y)

    logits = nn.Dense(
        self.output_vocab_size,
        kernel_init=nn.initializers.zeros,
        name="LogitsDense",
    )(y)
    return logits
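

# Note (assumption about the resulting param tree): with scan=True the decoder
# has a single "EncDecBlock" entry whose leaves carry an extra per-layer axis,
# whereas with scan=False there is one "EncDecBlock{i}" entry per layer; this
# is why load() below refuses to mix scan and non-scan checkpoints.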


class Model(nn.Module):
  """CapPa encoder-decoder model: ViT image encoder + Transformer text decoder."""

  num_heads: int = 8
  num_layers: int = 6
  mlp_dim: int = 2048
  emb_dim: int = 512
  enc_dropout_rate: float = 0.
  vocab_size: int = 32_000
  seq_len: int = 256

  # Encoder (ViT) specific hyper-parameters.
  patches: Sequence[int] = (16, 16)
  input_seq_len: int = 768
  posemb_type: str = "learn"
  patch_dropout: float = 0.

  # Decoder-specific hyper-parameters (0 means: reuse the encoder value).
  decoder_num_heads: int = 0
  decoder_num_layers: int = 0
  decoder_mlp_dim: int = 0
  decoder_emb_dim: int = 0
  dec_dropout_rate: float = 0.

  # Probability per sequence of training with parallel (masked) prediction.
  masked_pred_prob: float = 0.
  # Fraction of tokens masked out for parallel prediction.
  masking_ratio: float = 0.

  # Whether to use bias in the decoder MLP blocks and LayerNorm.
  decoder_bias: bool = True

  scan: bool = False
  remat_policy: str = "nothing_saveable"

  def setup(self):
    self.encoder = vit.Model(
        patch_size=self.patches,
        width=self.emb_dim,
        depth=self.num_layers,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        dropout=self.enc_dropout_rate,
        posemb=self.posemb_type,
        scan=self.scan,
        remat_policy=self.remat_policy,
    )

    self.pos_emb_for_decoder = vit.get_posemb(
        self,
        self.posemb_type,
        (1, self.seq_len),
        self.decoder_emb_dim or self.emb_dim,
        "pos_embedding_decoder",
    )
    self.decoder = Decoder(
        num_layers=self.decoder_num_layers or self.num_layers,
        mlp_dim=self.decoder_mlp_dim or self.mlp_dim,
        num_heads=self.decoder_num_heads or self.num_heads,
        dropout_rate=self.dec_dropout_rate,
        emb_dim=self.decoder_emb_dim or self.emb_dim,
        output_vocab_size=self.vocab_size,
        masked_pred_prob=self.masked_pred_prob,
        masking_ratio=self.masking_ratio,
        use_bias=self.decoder_bias,
        scan=self.scan,
        remat_policy=self.remat_policy,
    )

  def encode(self, image, train=False, return_enc_features=False):
    """Encodes input image or embeddings."""
    _, out = self.encoder(image, train=train)
    encoded = out["encoded"]

    if return_enc_features:
      return encoded, out

    return encoded

  def decode(self, encoded, targets, decode=False, train=False,
             max_decode_length=None):
    """Applies the Transformer decoder branch to encoded inputs and targets.

    Args:
      encoded: encoded image patches from encoder [B, P, E].
      targets: target text tokens [B, L].
      decode: whether to prepare and use an autoregressive cache.
      train: whether it is training.
      max_decode_length: optional max length for positional embeddings.

    Returns:
      logits array from transformer decoder [B, L, V].
    """
    decoder_mask = None if decode else nn.make_causal_mask(targets)
    logits = self.decoder(
        encoded,
        targets,
        pos_emb=self.pos_emb_for_decoder,
        decoder_mask=decoder_mask,
        decode=decode,
        deterministic=not train,
        max_decode_length=max_decode_length)
    return logits

  def __call__(self, image, text, *, decode=False,
               train=False, return_enc_features=False):
    """Applies Transformer model on the inputs.

    Args:
      image: batch of images [B, H, W, 3].
      text: batch of tokenized texts [B, L].
      decode: whether to prepare and use an autoregressive cache.
      train: whether it is training.
      return_enc_features: whether to return the encoder features.

    Returns:
      logits array from full transformer [B, L, V].
    """
    if return_enc_features:
      encoded, out = self.encode(image, train=train, return_enc_features=True)
      return encoded, out

    encoded = self.encode(image, train=train)
    decoded = self.decode(encoded, text, decode=decode, train=train)
    return decoded
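

# Minimal usage sketch (hypothetical shapes/values; requires the big_vision
# vit model on the import path):
#   model = Model(num_layers=12, num_heads=12, mlp_dim=3072, emb_dim=768,
#                 vocab_size=32_000, seq_len=64)
#   images = jnp.zeros((2, 224, 224, 3))
#   tokens = jnp.zeros((2, 64), jnp.int32)
#   params = model.init(jax.random.PRNGKey(0), images, tokens)["params"]
#   logits = model.apply({"params": params}, images, tokens)  # [B, L, V]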


def load(init_params, init_files, model_params=None,
         dont_load=("head/kernel", "head/bias", "cls")):
  """Loads params from init checkpoint and merges into init_params."""
  if isinstance(init_files, str):
    ckpt_params = utils.load_params(init_files)
    ckpt_params = flax.training.checkpoints.convert_pre_linen(ckpt_params)
    ckpt_params = common.merge_params(ckpt_params, init_params, dont_load)

    if (model_params.get("scan") and
        "encoderblock" not in ckpt_params["encoder"]["Transformer"]):
      raise NotImplementedError("Loading a non-scan checkpoint into a "
                                "scan model is not supported yet!")
    if (not model_params.get("scan")
        and "encoderblock" in ckpt_params["encoder"]["Transformer"]):
      assert "decoder.*" in dont_load or "decoder/.*" in dont_load, (
          "Converting scan decoder to a non-scan one is not supported yet!")
      ckpt_params["encoder"] = utils.jit_cpu()(
          vit.scan_to_pyloop)(ckpt_params["encoder"])

  else:
    assert set(init_files) == {"encoder"}, "Only encoder init supported"
    enc_init = init_files["encoder"]
    ckpt_params = flax.core.freeze(init_params).unfreeze()
    vit_params = ckpt_params["encoder"]
    encoder_params = vit.load(
        vit_params, enc_init, model_cfg={},
        dont_load=dont_load)
    ckpt_params["encoder"] = encoder_params

  ckpt_params["encoder"]["pos_embedding"] = vit.resample_posemb(
      old=ckpt_params["encoder"]["pos_embedding"],
      new=init_params["encoder"]["pos_embedding"])

  return ckpt_params
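

# Usage sketch (hypothetical paths; the checkpoint layout is an assumption):
#   # Restore a full CapPa checkpoint into freshly-initialized params...
#   params = load(init_params, "/path/to/cappa_checkpoint.npz",
#                 model_params={"scan": False})
#   # ...or initialize only the encoder from a ViT checkpoint:
#   params = load(init_params, {"encoder": "/path/to/vit_checkpoint.npz"})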