"""Gradient transformations and other optax utilities."""

import operator

import big_vision.utils as u
import jax
import jax.numpy as jnp
import optax


def find_states(opt_state, cls):
  """Returns all instances of `cls` found in the `opt_state` pytree."""
  leaves = jax.tree.leaves(
      opt_state, is_leaf=lambda node: isinstance(node, cls))
  return [leaf for leaf in leaves if isinstance(leaf, cls)]
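

# Hedged usage sketch (not part of the original module; the toy params are
# invented): pulling the Adam state out of a chained optimizer state.
def _example_find_states():
  tx = optax.chain(optax.scale_by_adam(), optax.scale(-1e-3))
  state = tx.init({"w": jnp.zeros(3)})
  (adam_state,) = find_states(state, optax.ScaleByAdamState)
  return adam_state.count  # 0 right after `init`.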


def get_count(opt_state, jittable=False):
  """Returns `ScaleByScheduleState.count` from `opt_state` as an integer."""
  counts = [
      state.count
      for state in find_states(opt_state, optax.ScaleByScheduleState)
  ]
  if jittable:
    return counts[0]
  else:
    counts = {int(c) for c in counts}
    assert len(counts) == 1, f"Expected exactly 1 ScaleByScheduleState: {counts}"
    return next(iter(counts))
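

# Hedged usage sketch (toy params invented): `optax.scale_by_schedule` is what
# introduces the `ScaleByScheduleState` whose `count` this reads out.
def _example_get_count():
  tx = optax.scale_by_schedule(lambda step: 1.0)
  state = tx.init({"w": jnp.zeros(3)})
  return get_count(state)  # == 0 before any `tx.update` call.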


def replace_frozen(schedule, pytree, replacement, log=None):
  """Replaces values matching frozen params in `pytree` with `replacement`."""
  if not isinstance(schedule, (list, tuple)):
    return pytree
  masks, scheds = _make_mask_trees(pytree, schedule, log=log)
  frozen_mask, _, _ = _split_frozen(masks, scheds)
  return jax.tree.map(
      lambda v, f: replacement if f else v, pytree, frozen_mask)
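

# Hedged illustration (param names invented): a schedule list where the
# `head/*` subtree is frozen via a `None` schedule, so its values get replaced.
def _example_replace_frozen():
  params = {"head": {"kernel": 1.0}, "body": {"kernel": 2.0}}
  schedule = [("head/.*", None), (".*", dict(decay_type="cosine"))]
  return replace_frozen(schedule, params, 0.0)
  # -> {"head": {"kernel": 0.0}, "body": {"kernel": 2.0}}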


def clip_by_per_example_global_norm(
    max_norm: float,
) -> optax.GradientTransformation:
  """Clips the global norm of each per-example gradient to `max_norm`.

  Expects `updates` whose leaves carry a leading batch axis of per-example
  gradients; each example is clipped independently, then summed over the
  batch (see `optax.per_example_global_norm_clip`).
  """

  def init_fn(params):
    del params
    return optax.EmptyState()

  def update_fn(updates, state, params=None):
    del params
    grads_flat, grads_treedef = jax.tree_util.tree_flatten(updates)
    clipped, _ = optax.per_example_global_norm_clip(grads_flat, max_norm)
    return jax.tree_util.tree_unflatten(grads_treedef, clipped), state

  return optax.GradientTransformation(init_fn, update_fn)
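

# Hedged usage sketch (toy shapes invented): the incoming updates must carry a
# leading batch axis of per-example gradients, e.g. from a `jax.vmap`-ed grad;
# each example is clipped to `max_norm`, then summed over the batch.
def _example_per_example_clip():
  per_example_grads = {"w": jnp.ones((8, 3))}  # 8 examples, 3 weights each.
  tx = clip_by_per_example_global_norm(max_norm=1.0)
  state = tx.init(None)
  clipped, _ = tx.update(per_example_grads, state)
  return clipped  # {"w": array of shape (3,)}, summed over examples.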


def make(config, params, *, sched_kw):
  """Returns gradient transform and learning rate functions."""

  # Global schedule spec: a single schedule dict, or a list of
  # (pattern, schedule) pairs. A `None` schedule freezes the matched params.
  schedule = config.get("schedule", {})
  if not isinstance(schedule, (tuple, list)):
    schedule = [(".*", schedule)]
  masks, scheds = _make_mask_trees(params, schedule, "config.schedule")
  frozen_mask, masks, scheds = _split_frozen(masks, scheds)
  not_frozen_mask = jax.tree.map(operator.not_, frozen_mask)

  def create_schedule(mult=1.0, **kw):
    assert "base" not in kw, kw
    return u.create_learning_rate_schedule(base=mult, **kw)

  schedule_fns = [create_schedule(**sched_kw, **sched) for sched in scheds]
  schedule_txs = [
      optax.masked(optax.scale_by_schedule(schedule_fn), mask)
      for schedule_fn, mask in zip(schedule_fns, masks)
  ] + [
      # Zero out the updates of frozen params.
      optax.masked(optax.set_to_zero(), frozen_mask)
  ]

  # Gradient clipping, either of the full gradient or per-example.
  if clip_norm := config.get("grad_clip_norm"):
    if config.get("grad_clip_per_example"):
      clip_tx = clip_by_per_example_global_norm(clip_norm)
    else:
      clip_tx = optax.clip_by_global_norm(clip_norm)
    grad_clip_norm_tx = optax.masked(clip_tx, not_frozen_mask)
  else:
    grad_clip_norm_tx = optax.identity()

  # The core optimizer update, resolved by (possibly dotted) name from optax,
  # which includes the `optax.big_vision` namespace registered below.
  tx_func = operator.attrgetter(config.optax_name)(optax)
  opt_txs = [optax.masked(tx_func(**config.get("optax", {})), not_frozen_mask)]
  assert "optim" not in config, "Deprecated option, use config.optax."

  # Base learning rate, plus optional per-parameter multipliers.
  lr_mult_txs = [optax.scale(config.lr)]
  if config.get("lr_mults"):
    masks, mults = _make_mask_trees(params, config.lr_mults, "config.lr_mults")
    assert all(mult > 0 for mult in mults), (
        f"Use schedule=None for parameter freezing instead of lr_mults={mults}")
    lr_mult_txs += [
        optax.masked(optax.scale(mult), mask)
        for mult, mask in zip(mults, masks)
    ]

  # Decoupled weight decay: `add_decayed_weights` uses the params side-input,
  # so the decay is additive and independent of the learning rate (the
  # schedule still applies to it, since schedules come later in the chain).
  assert "weight_decay" not in config, "Deprecated option. Use wd and schedule."
  assert config.get("weight_decay_decouple", True), (
      "Coupled weight decay not supported anymore.")
  if config.get("wd"):
    wd_mults = config.get("wd_mults", [(".*/kernel$", 1.0)])
    masks, mults = _make_mask_trees(params, wd_mults, "config.wd_mults")
    weight_decay_txs = [
        optax.add_decayed_weights(config.wd * mult, mask)
        for mult, mask in zip(mults, masks)
    ]
  else:
    weight_decay_txs = []

  # Chain everything; the trailing scale(-1) turns the (scheduled, scaled)
  # gradients into descent updates for `optax.apply_updates`.
  return optax.chain(
      grad_clip_norm_tx,
      *opt_txs,
      *lr_mult_txs,
      *weight_decay_txs,
      *schedule_txs,
      optax.scale(-1.0)), schedule_fns
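

# Hedged end-to-end sketch (not part of the original file): the config fields
# below mirror how the big_vision trainers call `make`; the exact kwargs that
# `sched_kw` must carry are whatever `u.create_learning_rate_schedule`
# accepts, and the toy params and values are invented.
def _example_make():
  import ml_collections
  config = ml_collections.ConfigDict()
  config.lr = 1e-3
  config.wd = 1e-4
  config.optax_name = "scale_by_adam"
  config.schedule = dict(decay_type="cosine", warmup_steps=100)
  params = {"dense": {"kernel": jnp.zeros((4, 4)), "bias": jnp.zeros(4)}}
  tx, sched_fns = make(
      config, params,
      sched_kw=dict(total_steps=1_000, batch_size=256, data_size=50_000))
  return tx.init(params), sched_fns  # sched_fns is handy for lr logging.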


def _make_mask_trees(params, patterns_values, log):
  """Returns (mask_trees, values) for a list of `(pattern, value)` pairs."""
  patterns, values = zip(*patterns_values)
  masks = u.make_mask_trees(params, patterns, log=log)
  return masks, values


def _split_frozen(masks, scheds):
  """Splits the masks with `None` schedules into a single `frozen_mask`."""
  # Every param must be matched by exactly one of the pattern masks.
  all_false = jax.tree.map(lambda *bools: not any(bools), *masks)
  not_covered = [k for k, v in u.tree_flatten_with_names(all_false)[0] if v]
  assert not not_covered, (
      f"All params must be covered (use `None` for freezing): {not_covered}")
  frozen_masks = [
      mask for mask, sched in zip(masks, scheds) if sched is None]
  frozen_mask = jax.tree.map(
      lambda *bools: any(bools), *frozen_masks,
      all_false)  # `all_false` is required when `frozen_masks == []`.
  masks, scheds = zip(*(
      (mask, sched) for mask, sched in zip(masks, scheds) if sched is not None))
  return frozen_mask, masks, scheds
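

# Hedged illustration (invented two-leaf masks): the first pattern freezes
# "a", so it ends up in `frozen_mask` and is dropped from the returned pairs.
def _example_split_frozen():
  masks = [{"a": True, "b": False}, {"a": False, "b": True}]
  scheds = [None, dict(decay_type="cosine")]
  frozen_mask, masks, scheds = _split_frozen(masks, scheds)
  return frozen_mask  # {"a": True, "b": False}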


# Namespace for big_vision-specific transforms; configs reach them via
# `config.optax_name = "big_vision.<name>"` (resolved in `make` above).
optax.big_vision = type("", (), {})()


def scale_by_adafactor(min_dim_size_to_factor=32,
                       decay_rate=0.8, decay_offset=0,
                       beta2_cap=0.999,
                       clipping_threshold=None,
                       momentum=0.9, dtype_momentum=jnp.bfloat16,
                       eps=1e-30):
  """The BigVision variant of Adafactor optimizer."""

  def _decay_rate_pow(i, exponent):
    """Second-order moment decay schedule."""
    t = jnp.array(i, jnp.float32) + 1.0
    return jnp.minimum(beta2_cap, 1.0 - t**(-exponent))

  scale_by_rms = optax.scale_by_factored_rms(
      factored=True,
      decay_rate=decay_rate,
      step_offset=decay_offset,
      min_dim_size_to_factor=min_dim_size_to_factor,
      epsilon=eps,
      decay_rate_fn=_decay_rate_pow)

  clip = (optax.clip_by_block_rms(clipping_threshold) if clipping_threshold
          else optax.identity())

  mom = (optax.ema(momentum, debias=False, accumulator_dtype=dtype_momentum)
         if momentum else optax.identity())

  return optax.chain(scale_by_rms, clip, mom)


optax.big_vision.scale_by_adafactor = scale_by_adafactor
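

# Hedged usage sketch (toy shapes invented): the factored second moment only
# applies to dims of at least `min_dim_size_to_factor`, hence the 128x128
# kernel here.
def _example_adafactor():
  params = {"w": jnp.zeros((128, 128))}
  tx = optax.chain(scale_by_adafactor(), optax.scale(-1e-3))
  state = tx.init(params)
  updates, state = tx.update({"w": jnp.ones((128, 128))}, state, params)
  return optax.apply_updates(params, updates)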


def momentum_hp(momentum=0.9, dtype=jnp.bfloat16, nesterov=False):
  """SGD-Momentum with half-precision accumulator."""
  return optax.trace(decay=momentum, accumulator_dtype=dtype, nesterov=nesterov)


optax.big_vision.momentum_hp = momentum_hp
optax.big_vision.sgd = optax.identity
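

# Hedged usage sketch (invented toy values): resolving and running one of the
# registered transforms exactly the way `make` does it.
def _example_resolve_optax_name():
  tx_func = operator.attrgetter("big_vision.momentum_hp")(optax)
  tx = tx_func(momentum=0.9, nesterov=True)
  state = tx.init({"w": jnp.zeros(2)})
  updates, _ = tx.update({"w": jnp.ones(2)}, state)
  return updates  # First Nesterov-momentum step, bfloat16 accumulator.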