"""Tests for optax."""

from absl.testing import absltest
from absl.testing import parameterized
from big_vision import optax as bv_optax
import chex
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import optax


class OptaxTest(parameterized.TestCase):

  def test_get_count(self):
    params = jax.tree.map(jnp.array, {"a": 1.})
    tx = optax.masked(
        optax.scale_by_schedule(lambda step: step),
        {"a": True},
    )
    opt_state = tx.init(params)
    self.assertEqual(bv_optax.get_count(opt_state), 0)
    _, opt_state = tx.update(params, opt_state)
    self.assertEqual(bv_optax.get_count(opt_state), 1)

  def test_split_frozen(self):
    params = jax.tree.map(jnp.array, {
        "Dense_0": {"kernel": 1., "bias": 2.},
    })
    sched1 = dict(decay_type="cosine")
    sched2 = dict(decay_type="linear")
    schedule = [
        (".*/kernel", sched1),
        (".*/bias", sched2),
    ]
    masks, scheds = bv_optax._make_mask_trees(params, schedule, log="schedule")
    frozen_mask, masks, scheds = bv_optax._split_frozen(masks, scheds)
    chex.assert_trees_all_equal(
        frozen_mask,
        {"Dense_0": {"kernel": False, "bias": False}},
    )
    chex.assert_trees_all_equal(
        masks,
        (
            {"Dense_0": {"kernel": True, "bias": False}},
            {"Dense_0": {"kernel": False, "bias": True}},
        ),
    )
    self.assertEqual(scheds, (sched1, sched2))

    schedule = [
        (".*/bias", None),
        ("Dense_0/.*", sched1),
        (".*", None),
    ]
    masks, scheds = bv_optax._make_mask_trees(params, schedule, log="schedule")
    frozen_mask, masks, scheds = bv_optax._split_frozen(masks, scheds)
    chex.assert_trees_all_equal(
        frozen_mask,
        {"Dense_0": {"kernel": False, "bias": True}},
    )
    chex.assert_trees_all_equal(
        masks,
        ({"Dense_0": {"kernel": True, "bias": False}},),
    )
    self.assertEqual(scheds, (sched1,))

    schedule = [
        (".*/kernel", None),
    ]
    masks, scheds = bv_optax._make_mask_trees(params, schedule, log="schedule")
    with self.assertRaisesRegex(AssertionError, "All params must be covered"):
      _ = bv_optax._split_frozen(masks, scheds)

  def test_replace_frozen(self):
    params = jax.tree.map(jnp.array, {
        "Dense_0": {"kernel": 1., "bias": 2.},
    })
    schedule = [
        (".*/kernel", {}),
        (".*", None),
    ]
    chex.assert_trees_all_equal(
        bv_optax.replace_frozen(schedule, params, 0.),
        {"Dense_0": {"kernel": 1., "bias": 0.}},
    )

  def test_make_simple(self):
    params = jax.tree.map(jnp.array, {
        "Dense_0": {"kernel": 1., "bias": 2.},
    })

    config = ml_collections.ConfigDict()
    config.lr = 0.01
    config.schedule = dict(decay_type="linear")
    config.optax_name = "scale"
    config.optax = ml_collections.ConfigDict()
    g_scale = 0.5
    config.optax.step_size = g_scale

    total_steps = 10
    sched_kw = dict(global_batch_size=1, total_steps=total_steps)
    tx, (schedule_fn,) = bv_optax.make(config, params, sched_kw=sched_kw)
    opt_state = tx.init(params)
    grads = jax.tree.map(jnp.ones_like, params)
    for step in range(total_steps):
      updates, opt_state = tx.update(grads, opt_state)
      self.assertEqual(bv_optax.get_count(opt_state), step + 1)
      sched = schedule_fn(step)
      np.testing.assert_almost_equal(
          sched, 1.0 / total_steps * (total_steps - step))
      make_tx = lambda sched: lambda g: -sched * config.lr * g_scale * g
      chex.assert_trees_all_close(updates, jax.tree.map(make_tx(sched), grads))

  def test_make_wd(self):
    params = jax.tree.map(jnp.array, {
        "Dense_0": {"kernel": 1., "bias": 2., "other": 3.},
    })
    wds = jax.tree.map(jnp.array, {
        "Dense_0": {"kernel": 2e-3, "bias": 5e-4, "other": 0.},
    })

    config = ml_collections.ConfigDict()
    config.lr = 0.01
    config.wd = 1e-3
    config.wd_mults = [
        (".*/kernel", 2.0),
        (".*/bias", 0.5),
    ]
    config.schedule = dict(decay_type="linear")
    config.optax_name = "scale"
    config.optax = ml_collections.ConfigDict()
    g_scale = 0.5
    config.optax.step_size = g_scale

    total_steps = 10
    sched_kw = dict(global_batch_size=1, total_steps=total_steps)
    tx, (sched_fn,) = bv_optax.make(config, params, sched_kw=sched_kw)
    opt_state = tx.init(params)
    grads = jax.tree.map(jnp.ones_like, params)
    for step in range(total_steps):
      updates, opt_state = tx.update(grads, opt_state, params)
      self.assertEqual(bv_optax.get_count(opt_state), step + 1)
      sched = sched_fn(step)
      np.testing.assert_almost_equal(
          sched, 1.0 / total_steps * (total_steps - step))

      def make_tx(sched):
        def inner(p, g, wd):
          return -sched * (config.lr * g_scale * g + p * wd)
        return inner

      chex.assert_trees_all_close(
          updates, jax.tree.map(make_tx(sched), params, grads, wds))

  def test_make_clip_norm(self):
    params = jax.tree.map(jnp.array, {
        "Dense_0": {"kernel": 1., "bias": 2., "other": 3.},
    })

    config = ml_collections.ConfigDict()
    config.lr = 0.01
    config.schedule = dict(decay_type="linear")
    config.optax_name = "scale"
    config.grad_clip_norm = 1.0
    config.optax = ml_collections.ConfigDict()
    g_scale = 0.5
    config.optax.step_size = g_scale

    total_steps = 10
    sched_kw = dict(global_batch_size=1, total_steps=total_steps)
    tx, (sched_fn,) = bv_optax.make(config, params, sched_kw=sched_kw)
    opt_state = tx.init(params)

    grads = jax.tree.map(jnp.ones_like, params)
    gflat = jax.tree.leaves(grads)
    l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in gflat]))
    grad_clip_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
    grads_scaled = jax.tree.map(lambda p: grad_clip_factor * p, grads)

    for step in range(total_steps):
      updates, opt_state = tx.update(grads, opt_state)
      self.assertEqual(bv_optax.get_count(opt_state), step + 1)
      sched = sched_fn(step)
      np.testing.assert_almost_equal(
          sched, 1.0 / total_steps * (total_steps - step))
      make_tx = lambda sched: lambda g: -sched * config.lr * g_scale * g
      chex.assert_trees_all_close(updates,
                                  jax.tree.map(make_tx(sched), grads_scaled))

  def test_make_multi(self):
    params = jax.tree.map(
        jnp.array, {
            "Dense_0": {"kernel": 1.0, "bias": 2.0, "other": 3.0},
            "Dense_1": {"kernel": 4.0, "bias": 5.0, "other": 6.0},
            "Dense_2": {"kernel": 7.0, "bias": 8.0, "other": 9.0},
            "Dense_3": {"kernel": 10., "bias": 11., "other": 12.},
        })

    lrb = 0.01
    lr1 = 2.0
    lr2 = 0.5
    lr_mults = {
        "Dense_0": {"kernel": lr1, "bias": lr1, "other": lr1},
        "Dense_1": {"kernel": lr2, "bias": lr2, "other": lr2},
        "Dense_2": {"kernel": 1.0, "bias": 1.0, "other": 1.0},
        "Dense_3": {"kernel": 1.0, "bias": 1.0, "other": 1.0},
    }
    wdb = 1e-3
    wd1 = 10.0
    wd2 = 0.1
    wds = jax.tree.map(
        jnp.array, {
            "Dense_0": {"kernel": wd1 * wdb, "bias": wd2 * wdb, "other": 0.},
            "Dense_1": {"kernel": wd1 * wdb, "bias": wd2 * wdb, "other": 0.},
            "Dense_2": {"kernel": wd1 * wdb, "bias": wd2 * wdb, "other": 0.},
            "Dense_3": {"kernel": 0.0 * wdb, "bias": 0.0 * wdb, "other": 0.},
        })

    config = ml_collections.ConfigDict()
    config.lr = lrb
    config.lr_mults = [
        ("Dense_0/.*", lr1),
        ("Dense_1/.*", lr2),
    ]
    config.wd = wdb
    config.wd_mults = [
        (".*/kernel", wd1),
        (".*/bias", wd2),
    ]
    mult1 = 1.0
    mult2 = 0.1
    config.schedule = [
        ("Dense_0/.*", dict(decay_type="linear", mult=mult1, linear_end=mult1)),
        ("Dense_[12]/.*", dict(decay_type="linear", mult=mult2)),
        (".*", None),
    ]
    config.optax_name = "scale"
    config.grad_clip_norm = 1.0
    config.optax = ml_collections.ConfigDict()
    g_scale = 0.5
    config.optax.step_size = g_scale

    total_steps = 10
    sched_kw = dict(global_batch_size=1, total_steps=total_steps)
    tx, (sched_fn1,
         sched_fn2) = bv_optax.make(config, params, sched_kw=sched_kw)
    opt_state = tx.init(params)

    frozen_fn = lambda _: jnp.array(0.)
    sched_fns = {
        "Dense_0": {"kernel": sched_fn1, "bias": sched_fn1, "other": sched_fn1},
        "Dense_1": {"kernel": sched_fn2, "bias": sched_fn2, "other": sched_fn2},
        "Dense_2": {"kernel": sched_fn2, "bias": sched_fn2, "other": sched_fn2},
        "Dense_3": {"kernel": frozen_fn, "bias": frozen_fn, "other": frozen_fn},
    }

    grads = jax.tree.map(jnp.ones_like, params)
    gflat, _ = jax.tree.flatten(
        jax.tree.map(lambda g, sched_fn: {frozen_fn: 0}.get(sched_fn, g),
                     grads, sched_fns))
    l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in gflat]))
    grad_clip_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
    grads_scaled = jax.tree.map(lambda p: grad_clip_factor * p, grads)

    def make_tx(step):
      def get_update(p, g, wd, sched_fn, lr_mult):
        return -sched_fn(step) * (lrb * lr_mult * g_scale * g + p * wd)
      return get_update

    for step in range(total_steps):
      updates, opt_state = tx.update(grads, opt_state, params)
      self.assertEqual(bv_optax.get_count(opt_state), step + 1)
      sched1, sched2 = sched_fn1(step), sched_fn2(step)
      np.testing.assert_almost_equal(sched1, mult1)
      np.testing.assert_almost_equal(
          sched2, mult2 * (total_steps - step) / total_steps)
      chex.assert_trees_all_close(
          updates,
          jax.tree.map(
              make_tx(step), params, grads_scaled, wds, sched_fns, lr_mults))

  def test_frozen_no_state(self):
    params = {"small": jnp.zeros([1]), "large": jnp.zeros([1000])}
    config = ml_collections.ConfigDict()
    config.lr = 0.01
    config.schedule = [
        ("small", dict(decay_type="cosine")),
        ("large", None),
    ]
    config.optax_name = "scale_by_adam"

    sched_kw = dict(global_batch_size=1, total_steps=1)
    tx, _ = bv_optax.make(config, params, sched_kw=sched_kw)

    opt_state = tx.init(params)
    adam_state = bv_optax.find_states(opt_state, optax.ScaleByAdamState)
    nbytes = sum(
        jax.tree.flatten(jax.tree.map(lambda x: x.nbytes, adam_state))[0])
    self.assertLess(nbytes, 1_000)

  def test_adafactor(self):
    params = {"Dense_0": {"kernel": jnp.zeros([1024, 1024])}}

    config = ml_collections.ConfigDict()
    config.optax_name = "big_vision.scale_by_adafactor"
    config.lr = 0.01
    config.schedule = dict(decay_type="linear")
    sched_kw = dict(global_batch_size=1, total_steps=1)

    tx, _ = bv_optax.make(config, params, sched_kw=sched_kw)

    opt_state = tx.init(params)
    adafactor_state = bv_optax.find_states(opt_state, optax.FactoredState)
    n_state_params = sum(
        jax.tree.flatten(
            jax.tree.map(lambda x: np.prod(
                x.shape if hasattr(x, "shape") else 0), adafactor_state))[0])
    self.assertEqual(n_state_params, 2 * 1024 + 2)


if __name__ == "__main__":
  absltest.main()