"""Train loop for training the stage-II model."""

import functools
import importlib
import multiprocessing.pool
import os

from absl import app
from absl import flags
from absl import logging
from big_vision import input_pipeline
import big_vision.datasets.core as ds_core
import big_vision.evaluators.common as eval_common
import big_vision.models.proj.uvim.decode as decode
import big_vision.optax as bv_optax
import big_vision.pp.builder as pp_builder
import big_vision.utils as u
from clu import parameter_overview
import flax
import jax
import jax.numpy as jnp
from ml_collections import config_flags
import numpy as np
import optax

import tensorflow.io.gfile as gfile


config_flags.DEFINE_config_file(
    "config", None, "Training configuration.", lock_config=True)

flags.DEFINE_string("workdir", default=None, help="Work unit directory.")
flags.DEFINE_boolean("cleanup", default=False,
                     help="Delete workdir (only) after successful completion.")

jax.config.parse_flags_with_absl()

FLAGS = flags.FLAGS
ONE_HOT_AXIS = -2
partial = functools.partial


def get_model(config):
  mod = importlib.import_module(f"big_vision.models.{config.model_name}")
  model = mod.Model(**config.model)
  return model, mod


def setup_task(config):
  """Get functions and params to encode and decode labels as token sequences."""
  config = config.oracle

  task_module = importlib.import_module(f"big_vision.trainers.{config.task}")
  input_fn = partial(task_module.input_pp, config=config)
  predict_outputs_fn = partial(task_module.predict_outputs, config=config)

  oracle, mod = get_model(config)
  if config.get("model_init", None):
    params, state = mod.load(None, config.model_init)
    params = {"params": params, "state": state}
  else:
    params = {}
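
  # The frozen stage-I oracle maps dense labels to a sequence of discrete code
  # tokens (encode) and maps such a sequence back to output logits (decode).
  # The +1/-1 shift below appears to reserve token id 0, presumably as a
  # padding/start symbol for the stage-II sequence model.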
  def encode_labels(params, batch):
    inputs = input_fn(batch)
    code = oracle.apply(params, **inputs, method=oracle.encode)[1]["code"]
    return code + 1

  def decode_labels(params, code, batch, **kwargs):
    code = code - 1
    inputs = input_fn(batch)
    inputs["x"] = code
    logits, _ = oracle.apply(
        params, **inputs, discrete_input=True, **kwargs, method=oracle.decode)
    return logits

  return encode_labels, decode_labels, predict_outputs_fn, params


def main(argv):
  del argv

  config = FLAGS.config
  workdir = FLAGS.workdir
  logging.info("\u001b[33mHello from process %i holding %i/%i devices and "
               "writing to workdir %s.\u001b[0m", jax.process_index(),
               jax.local_device_count(), jax.device_count(), workdir)

  save_ckpt_path = None
  if workdir:
    gfile.makedirs(workdir)
    save_ckpt_path = os.path.join(workdir, "checkpoint.npz")
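
  # This thread pool is only used for lightweight asynchronous side work,
  # e.g. writing checkpoints via `pool.apply_async` further below; it is not
  # part of the input pipeline.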
  pool = multiprocessing.pool.ThreadPool()

  for m in config.get("pp_modules",
                      ["ops_general", "ops_image", "proj.uvim.pp_ops"]):
    importlib.import_module(f"big_vision.pp.{m}")
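
  # The seed below makes the JAX side (init, dropout) reproducible; the
  # tf.data input pipeline is not seeded here, so end-to-end training is
  # likely not fully deterministic.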
  rng = jax.random.PRNGKey(config.get("seed", 0))

  xid, wid = -1, -1
  fillin = lambda s: s
  def info(s, *a):
    logging.info("\u001b[33mNOTE\u001b[0m: " + s, *a)
  def write_note(note):
    if jax.process_index() == 0:
      info("%s", note)

  write_note("Initializing...")

  batch_size = config.input.batch_size
  if batch_size % jax.device_count() != 0:
    raise ValueError(f"Batch size ({batch_size}) must "
                     f"be divisible by device number ({jax.device_count()})")
  info("Global batch size %d on %d hosts results in %d local batch size. With "
       "%d dev per host (%d dev total), that's a %d per-device batch size.",
       batch_size, jax.process_count(), batch_size // jax.process_count(),
       jax.local_device_count(), jax.device_count(),
       batch_size // jax.device_count())

  mw = u.BigVisionMetricWriter(xid, wid, workdir, config)
  chrono = u.Chrono()

  write_note("Initializing train dataset...")
  train_data = ds_core.get(**config.input.data)
  train_ds = input_pipeline.make_for_train(
      data=train_data.get_tfdata(ordered=False),
      batch_size=batch_size,
      preprocess_fn=pp_builder.get_preprocess_fn(config.input.get("pp")),
      shuffle_buffer_size=config.input.get("shuffle_buffer_size"),
      cache_raw=config.input.get("cache_raw", False),
      filter_fn=config.input.get("filter_fn"),
  )

  n_prefetch = config.get("prefetch_to_device", 1)
  train_iter = input_pipeline.start_input_pipeline(train_ds, n_prefetch)
  ntrain_img = train_data.total_examples

  def get_steps(name, default=ValueError):
    return u.steps(name, config, ntrain_img, batch_size, default)
  total_steps = get_steps("total")

  info("Running for %d steps, that means %f epochs",
       total_steps, total_steps * batch_size / ntrain_img)

  write_note(f"Initializing {config.model_name} model...")
  model, model_mod = get_model(config)

  encode_labels, decode_labels, predict_outputs_fn, task_params = (
      setup_task(config))
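
  # Parameters are initialized on the host CPU (`backend="cpu"`) so they are
  # not allocated on an accelerator before replication; a zero-filled batch
  # built from `train_ds.element_spec` stands in for real data during init.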
  @partial(jax.jit, backend="cpu")
  def init(rng):
    batch = jax.tree_map(
        lambda x: jnp.zeros(x.shape, x.dtype.as_numpy_dtype),
        train_ds.element_spec)
    images = batch["image"]
    labels = encode_labels(task_params, batch)
    variables = model.init(rng, images, labels)
    params = flax.core.unfreeze(variables["params"])
    return params

  rng, init_rng = jax.random.split(rng)
  params_cpu = init(init_rng)

  if jax.process_index() == 0:
    num_params = sum(p.size for p in jax.tree_leaves(params_cpu))
    parameter_overview.log_parameter_overview(params_cpu, msg="init params")
    mw.measure("num_params", num_params)

  write_note(f"Initializing {config.optax_name} optimizer...")
  tx, sched_fns = bv_optax.make(config, params_cpu, sched_kw=dict(
      total_steps=total_steps, batch_size=batch_size, data_size=ntrain_img))
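
  # Jitting `tx.init` on the CPU backend keeps the freshly created optimizer
  # state in host RAM rather than on device[0].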
  opt_cpu = jax.jit(tx.init, backend="cpu")(params_cpu)
  sched_fns_cpu = [jax.jit(sched_fn, backend="cpu") for sched_fn in sched_fns]
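
  # The pmap'd train step: the frozen oracle encodes the batch labels into
  # code tokens, the stage-II model is trained with softmax cross-entropy to
  # predict those tokens, and loss/gradients are averaged across devices via
  # `pmean`. A per-device dropout rng is derived by folding the replica's
  # `axis_index` into the shared rng, and `donate_argnums=(0, 1)` lets XLA
  # reuse the buffers of the old params and optimizer state.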
  @partial(jax.pmap, axis_name="batch", donate_argnums=(0, 1))
  def update_fn(params, opt, batch, update_rng, task_params):
    """Update step."""
    images = batch["image"]
    labels = encode_labels(task_params, batch)

    measurements = {}

    rng, new_rng = jax.random.split(update_rng)

    rng_local = jax.random.fold_in(rng, jax.lax.axis_index("batch"))

    def loss_fn(params, images, labels):
      logits = model.apply({"params": params}, images, labels, train=True,
                           rngs={"dropout": rng_local})
      loss = u.weighted_softmax_xent(
          logits=logits, labels=labels,
          reduction=True, normalize=True)
      return loss

    l, grads = jax.value_and_grad(loss_fn)(params, images, labels)
    l, grads = jax.lax.pmean((l, grads), axis_name="batch")
    updates, opt = tx.update(grads, opt, params)
    params = optax.apply_updates(params, updates)

    gs = jax.tree_leaves(bv_optax.replace_frozen(config.schedule, grads, 0.))
    measurements["l2_grads"] = jnp.sqrt(sum([jnp.vdot(g, g) for g in gs]))
    ps = jax.tree_leaves(params)
    measurements["l2_params"] = jnp.sqrt(sum([jnp.vdot(p, p) for p in ps]))
    us = jax.tree_leaves(updates)
    measurements["l2_updates"] = jnp.sqrt(sum([jnp.vdot(u, u) for u in us]))

    return params, opt, l, new_rng, measurements

  def validation_fn(params, batch):
    """Compute per-example metrics."""
    params, task_params = params["params"], params["task_params"]
    images = batch["image"]
    labels = encode_labels(task_params, batch)
    logits = model.apply({"params": params}, images, labels, train=False)
    loss = u.weighted_softmax_xent(
        logits=logits, labels=labels,
        reduction=False, normalize=True)
    losses = {"loss": loss}
    return jax.tree_map(
        lambda x: jnp.mean(x, axis=tuple(range(1, x.ndim))),
        losses)
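
  # Inference: sample a code sequence from the stage-II model (the default
  # temperature of 1e-7 makes this effectively greedy decoding), then run the
  # oracle's decoder on that sequence to obtain task outputs. The sampling rng
  # is built by folding the batch's "image/id" values (or an integer pixel-sum
  # fallback) into the seed, so different batches use different keys.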
  def predict_fn(params, batch, seed=0, temperature=1e-7, **extra):
    params, task_params = params["params"], params["task_params"]

    if "image/id" in batch:
      key = batch["image/id"]
    else:
      key = batch["image"].sum(axis=[1, 2, 3]).astype(jnp.int32)
    local_rng = jax.lax.scan(
        lambda k, x: (jax.random.fold_in(k, x), None),
        jax.random.PRNGKey(seed),
        key,
    )[0]

    images = batch["image"]
    batch_size = images.shape[0]
    prompts = jnp.zeros((batch_size, config.model.seq_len), dtype=jnp.int32)
    seqs, _, _ = decode.temperature_sampling(
        params={"params": params}, model=model, seed=local_rng,
        inputs=images, prompts=prompts,
        num_samples=1, eos_token=-1, prefill=False,
        temperature=temperature)
    seqs = jnp.squeeze(seqs, 1)
    logits = decode_labels(task_params, seqs, batch)
    return predict_outputs_fn(logits, **extra)
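
  # Evaluators are constructed lazily and only once: `lru_cache` on the
  # zero-argument function memoizes the list after the first call.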
  @functools.lru_cache(maxsize=None)
  def evaluators():
    return eval_common.from_config(
        config, {"predict": predict_fn, "validation": validation_fn},
        lambda s: write_note(f"Initializing evaluator: {s}...\n{chrono.note}")
    )
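
  # How training gets initialized, in order of precedence:
  # 1. Resume from an existing checkpoint in this workdir (e.g. after a crash).
  # 2. Resume from a checkpoint given in `config.resume`.
  # 3. Load initial weights from `config.model_init`.
  # 4. Otherwise, train from scratch with the params initialized above.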
  resume_ckpt_path = None
  if save_ckpt_path and gfile.exists(save_ckpt_path):
    resume_ckpt_path = save_ckpt_path
  elif config.get("resume"):
    resume_ckpt_path = fillin(config.resume)
  if resume_ckpt_path:
    write_note("Resume training from checkpoint...")
    checkpoint = {
        "params": params_cpu,
        "opt": opt_cpu,
        "chrono": chrono.save(),
    }
    checkpoint_tree = jax.tree_structure(checkpoint)
    loaded = u.load_checkpoint(checkpoint_tree, resume_ckpt_path)

    checkpoint = jax.tree_map(u.recover_dtype, loaded)
    params_cpu, opt_cpu = checkpoint["params"], checkpoint["opt"]
    chrono.load(checkpoint["chrono"])
  elif config.get("model_init"):
    write_note(f"Initialize model from {config.model_init}...")
    params_cpu = model_mod.load(
        params_cpu, config.model_init, config.model,
        **config.get("model_load", {}))
    if jax.process_index() == 0:
      parameter_overview.log_parameter_overview(
          params_cpu, msg="restored params")

  write_note("Kicking off misc stuff...")
  first_step = bv_optax.get_count(opt_cpu)
  chrono.inform(first_step, total_steps, batch_size, ntrain_img / batch_size)
  prof = None

  write_note(f"Replicating...\n{chrono.note}")
  params_repl = flax.jax_utils.replicate(params_cpu)
  opt_repl = flax.jax_utils.replicate(opt_cpu)
  task_params = flax.jax_utils.replicate(task_params)
  update_rngs = flax.jax_utils.replicate(rng)

  ckpt_writer = None

  write_note(f"First step compilations...\n{chrono.note}")
  error = None
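
  # The step counter is a plain Python int from `range` rather than the
  # replicated, on-device optimizer count, which avoids a device round-trip
  # every iteration.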
  for step, batch in zip(range(first_step + 1, total_steps + 1), train_iter):
    mw.step_start(step)

    with jax.profiler.StepTraceAnnotation("train_step", step_num=step):
      params_repl, opt_repl, loss_value, update_rngs, measurements = (
          update_fn(
              params_repl,
              opt_repl,
              batch,
              update_rng=update_rngs,
              task_params=task_params))

    if jax.process_index() == 0:
      prof = u.startstop_prof(prof, step, first_step, get_steps("log_training"))

    if (u.itstime(step, get_steps("log_training"), total_steps, host=0)
        or chrono.warmup and jax.process_index() == 0):
      for i, sched_fn_cpu in enumerate(sched_fns_cpu):
        mw.measure(f"global_schedule{i if i else ''}", sched_fn_cpu(step - 1))
      l = mw.measure("training_loss", loss_value[0])
      for name, value in measurements.items():
        mw.measure(name, value[0])
      chrono.tick(step, mw.measure, write_note)
      if not np.isfinite(l):
        error = (f"The loss became nan or inf somewhere within steps "
                 f"[{step - get_steps('log_training')}, {step}]")
        break
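
    # Checkpointing: device-0's params/optimizer state are copied into host
    # numpy arrays first, then written asynchronously on the thread pool so
    # training can continue while the file is being saved.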
    if (save_ckpt_path and
        (u.itstime(step, get_steps("ckpt", None), total_steps, host=0) or
         u.itstime(step, get_steps("keep_ckpt", None), total_steps, host=0))):
      chrono.pause(wait_for=(params_repl, opt_repl))
      u.checkpointing_timeout(ckpt_writer, config.get("ckpt_timeout", 1))

      opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)
      params_cpu = jax.tree_map(lambda x: np.array(x[0]), params_repl)

      copy_step = None
      if u.itstime(step, get_steps("keep_ckpt", None), total_steps):
        copy_step = step

      ckpt = {"params": params_cpu, "opt": opt_cpu, "chrono": chrono.save()}
      ckpt_writer = pool.apply_async(
          u.save_checkpoint, (ckpt, save_ckpt_path, copy_step))
      chrono.resume()
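
    # Periodic evaluation: evaluators receive both the replicated model params
    # and the replicated task (oracle) params, matching the structure that
    # `validation_fn` and `predict_fn` unpack.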
    for (name, evaluator, log_steps, prefix) in evaluators():
      if u.itstime(step, log_steps, total_steps, first=log_steps < total_steps,
                   last=False):
        chrono.pause(wait_for=(params_repl, task_params))
        write_note(f"{name} evaluation...\n{chrono.note}")
        for key, value in evaluator.run(
            {"params": params_repl, "task_params": task_params}):
          mw.measure(f"{prefix}{key}", value)
        chrono.resume()
    mw.step_end()

  if jax.process_index() == 0 and prof is not None:
    u.startstop_prof(prof)
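
  # One final evaluation pass after the loop finishes (or breaks early on a
  # non-finite loss).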
  for (name, evaluator, _, prefix) in evaluators():
    write_note(f"{name} evaluation...\n{chrono.note}")
    for key, value in evaluator.run(
        {"params": params_repl, "task_params": task_params}):
      mw.measure(f"{prefix}{key}", value)

  if not error:
    write_note(f"Done!\n{chrono.note}")
  else:
    write_note(f"Failed!\n{error}\n{chrono.note}")

  pool.close()
  pool.join()
  mw.close()
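
  # `u.sync()` acts as a cross-host synchronization point so that every
  # process stays alive until the end of main.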
  u.sync()

  if error is not None:
    raise RuntimeError(error)

  u.maybe_cleanup_workdir(workdir, flags.FLAGS.cleanup, info)


if __name__ == "__main__":
  app.run(main)