"""gemma reimplementation for big_vision.
|
|
|
|
We follow this einsum axis naming convention:
|
|
B: batch
|
|
T: query length
|
|
S: k/v length
|
|
N: num query heads
|
|
K: num k/v heads
|
|
G: num query heads per k/v head
|
|
H: head dim
|
|
D: d_model ("features")
|
|
|
|
Example Colab using the models via the PaliGemma decoding logic:
|
|
(internal link)
|
|
|
|
Doc locating the variable initializers in the original code and validating them:
|
|
(internal link)
|
|
"""
|
|
|
|
|
|
from big_vision.models import common
import big_vision.utils as u
import einops
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import orbax.checkpoint


def get_config(variant):
  """Returns config for specified gemma variant."""
  if variant == "gemma_2b":
    return ml_collections.ConfigDict(
        dict(
            variant=variant,
            width=2048,
            depth=18,
            mlp_dim=16_384,
            num_heads=8,
            num_kv_heads=1,
            head_dim=256,
            norm_eps=1e-6,
            vocab_size=256_128,
            scan=True,
            remat_policy="nothing_saveable",
        )
    )
  if variant == "gemma_7b":
    return ml_collections.ConfigDict(
        dict(
            variant=variant,
            width=3072,
            depth=28,
            mlp_dim=24_576,
            num_heads=16,
            num_kv_heads=16,
            head_dim=256,
            norm_eps=1e-6,
            vocab_size=256_128,
            scan=True,
            remat_policy="nothing_saveable",
        )
    )
  raise ValueError(f"Unknown variant: {variant}")


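# Usage sketch (illustrative, not part of the original module): the returned
# ConfigDict keys map directly onto the `Model` fields defined below, so a
# model for a variant can be built as, e.g.:
#
#   cfg = get_config("gemma_2b")
#   model = Model(**cfg)

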
def _apply_rope(x, *, positions, max_wavelength=10_000):
  """Applies RoPE positions [B, L] to x [B, L, H, D]."""
  freq_exponents = (2. / x.shape[-1]) * jnp.arange(x.shape[-1] // 2)
  timescale = (max_wavelength ** freq_exponents)
  radians = positions[..., None] / timescale[None, None, :]
  radians = radians[..., None, :]

  sin, cos = jnp.sin(radians), jnp.cos(radians)
  x1, x2 = jnp.split(x, 2, axis=-1)
  res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
  return res


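# Shape sketch (illustrative only): RoPE rotates pairs of channels within each
# head, so the output has the same shape as the input:
#
#   q = jnp.zeros((1, 7, 8, 256))           # [B, T, N, H]
#   pos = jnp.arange(7)[None, :]            # [B, T]
#   q_rot = _apply_rope(q, positions=pos)   # still [1, 7, 8, 256]

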
def _update_kv_cache(module, k, v, cache_size, cache_dtype):
  """Updates KV cache and returns its current contents."""
  initialized = module.has_variable("cache", "idx")
  batch_size, update_len, num_heads, head_dim = k.shape
  cache_dtype = cache_dtype or k.dtype

  # The write index is the same for all examples (only idx[0] is used), so the
  # cache can be updated with a single dynamic_update_slice; it is still stored
  # per example to keep the variable nicely partitioned.
  idx = module.variable("cache", "idx", jnp.zeros, (batch_size,), jnp.int32)

  kv_shape = (batch_size, cache_size, num_heads, head_dim)
  k_cache = module.variable(
      "cache", "k_cache", jnp.zeros, kv_shape, cache_dtype)
  v_cache = module.variable(
      "cache", "v_cache", jnp.zeros, kv_shape, cache_dtype)

  if initialized:  # Decode step: write a single new k/v row at the current idx.
    assert update_len == 1, update_len
    indices = (0, idx.value[0], 0, 0)
    k_cache.value = jax.lax.dynamic_update_slice(
        k_cache.value, k.astype(cache_dtype), indices)
    v_cache.value = jax.lax.dynamic_update_slice(
        v_cache.value, v.astype(cache_dtype), indices)
    idx.value = idx.value + 1
  else:  # Prefill: initialize the cache with the prompt k/v, padded to cache_size.
    prefill_len = k.shape[1]
    pad_width = ((0, 0), (0, cache_size - prefill_len), (0, 0), (0, 0))
    k_cache.value = jnp.pad(k.astype(cache_dtype), pad_width)
    v_cache.value = jnp.pad(v.astype(cache_dtype), pad_width)
    idx.value = idx.value + prefill_len

  return k_cache.value.astype(k.dtype), v_cache.value.astype(v.dtype)


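# Cache life-cycle sketch (illustrative shapes, assuming cache_size=64):
#
#   first call   k: [B, 16, K, H]  -> rows [0, 16) filled, idx becomes 16
#   later calls  k: [B, 1, K, H]   -> written at row idx, then idx += 1
#
# i.e. the first call acts as the prefill; every later call must pass exactly
# one new position.

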
def trunc_norm_init(in_axis, out_axis, batch_axis):
  return nn.initializers.variance_scaling(
      1.0, "fan_in", "truncated_normal",
      in_axis=in_axis, out_axis=out_axis, batch_axis=batch_axis)


class Einsum(nn.Module):
  shape: tuple[int, ...]
  w_init: nn.initializers.Initializer = nn.initializers.zeros_init()

  @nn.compact
  def __call__(self, eqn, x):
    w = self.param("w", self.w_init, self.shape)
    return jnp.einsum(eqn, x, w)


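# Usage sketch (illustrative): a single learned tensor contracted against the
# input with an explicit equation. For example, the attention output projection
# below does
#
#   Einsum(shape=(N, H, D))("BTNH,NHD->BTD", encoded)
#
# where the axis letters follow the module-level naming convention.

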
class RMSNorm(nn.Module):
  """RMS normalization, scaled by a zero-initialized `(1 + scale)` parameter."""

  @nn.compact
  def __call__(self, x):
    scale = self.param("scale", nn.initializers.zeros_init(), (x.shape[-1]))
    var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06)))
    normed_inputs = normed_inputs * (1 + scale)
    return normed_inputs


class Embedder(nn.Module):
  """Embedder module."""

  vocab_size: int
  embed_dim: int

  def setup(self):
    self.input_embedding_table = self.param(
        "input_embedding",
        nn.initializers.variance_scaling(
            scale=1.0, mode="fan_in", distribution="normal",
            in_axis=1, out_axis=0,),
        (self.vocab_size, self.embed_dim),
    )

  def encode(self, x):
    x = self.input_embedding_table[(x,)]
    x *= jnp.sqrt(self.embed_dim).astype(x.dtype)
    return x

  def decode(self, x):
    return jnp.dot(x, self.input_embedding_table.T)


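# Note (illustrative): the same `input_embedding` table is used both to embed
# tokens (`encode`, scaled by sqrt(embed_dim)) and, transposed, to map final
# activations back to vocabulary logits (`decode`), i.e. input and output
# embeddings are tied. Inside `Model.__call__` below this looks like:
#
#   x = embedder.encode(tokens)       # [B, T] int32 -> [B, T, D]
#   logits = embedder.decode(x)       # [B, T, D]    -> [B, T, vocab_size]

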
class Attention(nn.Module):
  """Attention module."""

  num_heads: int
  num_kv_heads: int
  features: int
  head_dim: int

  cache_dtype: str | None = None

  def setup(self):
    if self.num_kv_heads == self.num_heads:
      self.qkv_einsum = Einsum(
          shape=(3, self.num_heads, self.features, self.head_dim),
          w_init=trunc_norm_init(
              in_axis=(2,), out_axis=(0, 1, 3), batch_axis=()),
      )
    else:
      # Multi-query / grouped-query attention: separate q and shared k/v heads.
      self.q_einsum = Einsum(
          shape=(self.num_heads, self.features, self.head_dim),
          w_init=trunc_norm_init(in_axis=(1,), out_axis=(0, 2), batch_axis=()),
      )
      self.kv_einsum = Einsum(
          shape=(2, self.num_kv_heads, self.features, self.head_dim),
          w_init=trunc_norm_init(
              in_axis=(2,), out_axis=(0, 1, 3), batch_axis=()),
      )
    self.attn_vec_einsum = Einsum(
        shape=(self.num_heads, self.head_dim, self.features),
        w_init=trunc_norm_init(in_axis=(0, 1), out_axis=(2,), batch_axis=()),
    )

  @nn.compact
  def __call__(self, x, positions, attn_mask, decode, deterministic=True):
    if self.num_kv_heads == self.num_heads:
      q, k, v = self.qkv_einsum("BSD,3KDH->3BSKH", x)
    else:
      q = self.q_einsum("BTD,NDH->BTNH", x)
      k, v = self.kv_einsum("BSD,2KDH->2BSKH", x)

    q = _apply_rope(q, positions=positions)
    q *= self.head_dim**-0.5

    k = _apply_rope(k, positions=positions)
    if decode:
      k, v = _update_kv_cache(self, k, v,
                              cache_size=attn_mask.shape[-1],
                              cache_dtype=self.cache_dtype)

    q = einops.rearrange(q, "B T (K G) H -> B T K G H", K=self.num_kv_heads)
    logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k)
    logits = logits.astype(jnp.float32)

    if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]):
      raise ValueError(
          f"Attention mask with shape {attn_mask.shape} but shapes for q and k "
          f"are: {q.shape} and {k.shape}"
      )

    # Large negative fill value to mask out logits before the softmax.
    big_neg = -2.3819763e38
    masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg)

    probs = jax.nn.softmax(masked_logits, axis=-1).astype(k.dtype)

    encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v)
    encoded = einops.rearrange(encoded, "B T K G H -> B T (K G) H")
    attn_output = self.attn_vec_einsum("BTNH,NHD->BTD", encoded)

    return attn_output


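# Shape sketch for grouped-query attention (illustrative, gemma_2b numbers):
# with N=8 query heads and K=1 k/v head, all G=N/K=8 query heads in a group
# attend to the same k/v head:
#
#   q: [B, T, 8, 256]  -> rearranged to [B, T, K=1, G=8, 256]
#   k: [B, S, 1, 256]
#   logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k)   # [B, 1, 8, T, S]

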
class FeedForward(nn.Module):
  """Feed forward module."""

  features: int
  hidden_dim: int

  @nn.compact
  def __call__(self, x):
    w_gating = self.param(
        "gating_einsum",
        trunc_norm_init(in_axis=(1,), out_axis=(0, 2), batch_axis=()),
        ((2, self.features, self.hidden_dim)),
    )
    ff_gate = jnp.dot(x, w_gating[0])
    gate_value = nn.gelu(ff_gate)

    ff1 = jnp.dot(x, w_gating[1])
    activations = gate_value * ff1

    w_linear = self.param(
        "linear",
        trunc_norm_init(in_axis=(0,), out_axis=(1,), batch_axis=()),
        (self.hidden_dim, self.features),
    )
    outputs = jnp.dot(activations, w_linear)

    return outputs


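# This is a GeGLU-style gated MLP: for model width D and mlp_dim M,
#
#   out = (gelu(x @ W_gate) * (x @ W_up)) @ W_down
#
# with W_gate and W_up of shape [D, M] stored stacked as `gating_einsum`
# ([2, D, M]), and W_down of shape [M, D] stored as `linear`.

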
class Block(nn.Module):
  """Transformer block."""

  num_heads: int
  num_kv_heads: int
  embed_dim: int
  head_dim: int
  hidden_dim: int

  dropout: float = 0.0
  dropout_bdims: tuple[int, ...] = ()
  cache_dtype: str | None = None

  def setup(self):
    self.pre_attention_norm = RMSNorm()
    self.attn = Attention(
        num_heads=self.num_heads,
        num_kv_heads=self.num_kv_heads,
        features=self.embed_dim,
        head_dim=self.head_dim,
        cache_dtype=self.cache_dtype,
    )
    self.pre_ffw_norm = RMSNorm()
    self.mlp = FeedForward(features=self.embed_dim, hidden_dim=self.hidden_dim)
    if self.dropout:
      self.drop = nn.Dropout(self.dropout, self.dropout_bdims)
    else:
      self.drop = lambda x, _: x

  def __call__(self, x, unused_scan_arg, positions, attn_mask,
               decode, deterministic=True):
    x = nn.with_logical_constraint(x, ("act_batch", "act_len", "act_emb"))
    inputs_normalized = self.pre_attention_norm(x)
    attn_output = self.attn(inputs_normalized, positions, attn_mask,
                            decode, deterministic)
    attn_output = self.drop(attn_output, deterministic)
    attn_output += x
    residual = attn_output
    attn_output = self.pre_ffw_norm(attn_output)
    outputs = self.mlp(attn_output)
    outputs = self.drop(outputs, deterministic)
    outputs = residual + outputs
    return outputs, unused_scan_arg


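# Residual wiring sketch (illustrative): each block is pre-norm, i.e.
#
#   x = x + drop(attn(norm(x)))
#   x = x + drop(mlp(norm(x)))
#
# `unused_scan_arg` is only threaded through so the block matches the carry
# signature expected by `nn.scan` in `Model` below.

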
class Model(nn.Module):
  """Gemma model."""

  variant: str

  width: int
  depth: int
  mlp_dim: int
  num_heads: int
  num_kv_heads: int
  head_dim: int
  norm_eps: float
  vocab_size: int

  dropout: float = 0.0
  dropout_bdims: tuple[int, ...] = ()
  cache_dtype: str | None = None

  # Dtype used for the embeddings and all block activations.
  embed_dtype: str = "float32"

  scan: bool = False
  remat_policy: str = "none"

  @nn.compact
  def __call__(
      self, tokens, *,
      embedded_prefix=None,
      embed_only=False,
      pre_logits=None,
      positions=None, mask=None,
      decode=False, deterministic=True,
  ):
    """Embed only, or complete forward pass.

    Args:
      tokens: Embedded and then appended to `embedded_prefix`. Can be None.
      embedded_prefix: Optional prefix that is already embedded.
      embed_only: Whether to compute embeddings only.
      pre_logits: If present, computes logits from pre_logits and returns them.
      positions: Optional `[B, T]` tensor specifying the absolute positions of
        the tokens.
      mask: Optional attention mask `[B, T, S]`.
      decode: Whether to use kv-cache. Caller must pass masks and positions.
      deterministic: Forwarded to all dropout layers.

    Returns:
      If `embed_only=False`, then `(logits, out)` will be returned.
      If `embed_only=True`, then the embeddings will be returned.
    """
    out = {}

    embedder = Embedder(
        vocab_size=self.vocab_size,
        embed_dim=self.width,
        name="embedder")

    if pre_logits is not None:
      x = out["pre_logits"] = pre_logits
      logits = out["logits"] = embedder.decode(x)
      return logits, out

    x = []
    if embedded_prefix is not None:
      x.append(embedded_prefix)
    if tokens is not None:
      x.append(embedder.encode(tokens))

    x = jnp.concatenate(x, axis=-2)
    x = x.astype(self.embed_dtype)
    batch_size, seq_len, width = x.shape

    if embed_only:
      return x

    if decode:
      assert positions is not None and mask is not None, (
          "Must explicitly pass positions and mask for decoding.")

    if positions is None:
      positions = jnp.arange(seq_len).astype(jnp.int32)[None, :]
    assert positions.shape[1] == x.shape[1], (positions.shape, x.shape)

    if mask is None:
      mask = nn.attention.make_causal_mask(jnp.ones([batch_size, seq_len]))
    if mask.ndim == 3:
      mask = mask[:, None, :, :]
    cache_size = max(seq_len, mask.shape[-1])
    assert mask.shape == (batch_size, 1, seq_len, cache_size), mask.shape

    if self.remat_policy == "none":
      block_cls = Block
    else:
      block_cls = nn.remat(
          Block,
          prevent_cse=not self.scan,
          static_argnums=(5, 6),  # 0=self, so 5=decode, 6=deterministic.
          policy=getattr(jax.checkpoint_policies, self.remat_policy),
      )

    block_kw = dict(
        num_heads=self.num_heads,
        head_dim=self.head_dim,
        num_kv_heads=self.num_kv_heads,
        embed_dim=width,
        hidden_dim=self.mlp_dim,
        dropout=self.dropout,
        dropout_bdims=self.dropout_bdims,
        cache_dtype=self.cache_dtype,
    )
    layers = self.scope.push("layers")
    if self.scan:
      blocks = [nn.scan(
          block_cls,
          # Cache variables get the layer axis at position 1, keeping batch
          # as the leading dimension; params are stacked along axis 0.
          variable_axes={"params": 0, "cache": 1},
          split_rngs={"params": True, "dropout": True},
          in_axes=nn.broadcast,
          length=self.depth,
      )(
          parent=layers, **block_kw
      )]
    else:
      blocks = [
          block_cls(
              parent=layers.push(str(layer)),
              **block_kw,
          )
          for layer in range(self.depth)
      ]
    unused_scan_arg = ()
    for block in blocks:
      x, unused_scan_arg = block(
          x, unused_scan_arg, positions, mask, decode, deterministic)

    assert x.dtype == jnp.dtype(self.embed_dtype)
    out["encoded"] = x

    x = RMSNorm(name="final_norm")(x)
    out["pre_logits"] = x

    x = embedder.decode(x)
    out["logits"] = x

    return x, out


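# End-to-end usage sketch (illustrative only; the real decode-time plumbing
# lives in the PaliGemma decoding logic referenced in the module docstring):
#
#   cfg = get_config("gemma_2b")
#   model = Model(**cfg)
#   tokens = jnp.zeros([1, 8], jnp.int32)
#   variables = model.init(jax.random.PRNGKey(0), tokens)
#   logits, out = model.apply(variables, tokens)   # logits: [1, 8, vocab_size]

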
_ORBAX_INITS = {}
_BV_INITS = {}


def _load_orbax(path):
  """Loads and converts Orbax gemma checkpoint."""
  checkpointer = orbax.checkpoint.PyTreeCheckpointer()
  params = checkpointer.restore(path)
  params = flax.traverse_util.unflatten_dict(params, sep="/")["transformer"]
  n = sum(1 for k in params if k.startswith("layer_"))
  # Fold the per-layer subtrees into a single "layers" subtree with a leading
  # layer axis, matching the scan-over-layers parameter layout used above.
  params["layers"] = jax.tree.map(
      lambda *xs: np.stack(xs), *(params.pop(f"layer_{i}") for i in range(n))
  )
  mlp = params["layers"]["mlp"]
  mlp["gating_einsum"] = mlp["gating_einsum"].pop("w")
  mlp["linear"] = mlp["linear"].pop("w")
  return params


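# After this conversion, e.g. params["layers"]["mlp"]["gating_einsum"] carries
# a leading layer axis of size `depth` (18 for a gemma_2b checkpoint), which is
# the stacked layout that `nn.scan` in `Model` expects. (Shapes shown here are
# illustrative.)

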
def _del_pad_rows(params):
  """Some checkpoints have 128 unused padding tokens."""
  emb = params["embedder"]["input_embedding"]
  assert emb.shape[0] == 256_128
  params["embedder"]["input_embedding"] = np.asarray(emb)[:256_000]
  return params


def load(init_params, init_file, model_cfg=None, dont_load=()):
  """Loads existing weights."""
  model_cfg = model_cfg or {}
  variant = model_cfg.get("variant", "gemma_2b")
  init_variant = f"{init_file} {variant}"
  if init_variant in _ORBAX_INITS:
    params = _del_pad_rows(_load_orbax(_ORBAX_INITS[init_variant]))
  elif init_variant in _BV_INITS:
    params = _del_pad_rows(u.load_params(_BV_INITS[init_variant]))
  else:
    params = u.load_params(init_file)

  def extend_rows(emb1, target_rows):
    if (missing_rows := target_rows - emb1.shape[0]) == 0:
      return emb1
    assert missing_rows > 0, "You're asking to shrink vocab?!"
    new_rows = np.random.randn(missing_rows, emb1.shape[1])
    new_rows = (new_rows * 0.02).astype(emb1.dtype)
    return np.r_[np.asarray(emb1), new_rows]

  if "vocab_size" in model_cfg:
    params["embedder"]["input_embedding"] = extend_rows(
        params["embedder"]["input_embedding"],
        model_cfg["vocab_size"],
    )

  return common.merge_params(params, init_params, dont_load)


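# Usage sketch (illustrative; the checkpoint path below is a placeholder, not a
# file shipped with this module):
#
#   variables = model.init(jax.random.PRNGKey(0), tokens)
#   params = load(variables["params"], "/path/to/checkpoint",
#                 model_cfg=get_config("gemma_2b"))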