|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Training loop for GIVT-style autoregressive and masked models."""
|
|
|
|
|
|
import functools
|
|
import importlib
|
|
import multiprocessing.pool
|
|
import os
|
|
|
|
from absl import app
|
|
from absl import flags
|
|
from absl import logging
|
|
import big_vision.evaluators.common as eval_common
|
|
import big_vision.input_pipeline as input_pipeline
|
|
from big_vision.models.proj.givt import parallel_decode
|
|
import big_vision.models.proj.givt.decode as softar_decode
|
|
import big_vision.optax as bv_optax
|
|
import big_vision.sharding as bv_sharding
|
|
import big_vision.trainers.proj.givt.utils as trainer_utils
|
|
from big_vision.trainers.proj.uvim import panoptic_task
|
|
import big_vision.utils as u
|
|
from clu import parameter_overview
|
|
import flax
|
|
import jax
|
|
from jax.experimental import mesh_utils
|
|
from jax.experimental import multihost_utils
|
|
from jax.experimental.array_serialization import serialization as array_serial
|
|
import jax.numpy as jnp
|
|
from ml_collections import config_flags
|
|
import numpy as np
|
|
import optax
|
|
import tensorflow as tf
|
|
|
|
from tensorflow.io import gfile
|
|
|
|
|
|
|
|
|
|
# Command-line flags. `--config` names the experiment config file, which is
# loaded with lock_config=True.
config_flags.DEFINE_config_file(
    "config", None, "Training configuration.", lock_config=True)

flags.DEFINE_string("workdir", default=None, help="Work unit directory.")
flags.DEFINE_boolean(
    "cleanup",
    default=False,
    help="Delete workdir (only) after successful completion.",
)

# Let absl-parsed command-line flags configure JAX's own options too.
jax.config.parse_flags_with_absl()

# Process-wide JAX settings. With the transfer guard set to "disallow",
# implicit host<->device transfers are rejected; main() explicitly re-allows
# them via `jax.transfer_guard("allow")` around the chrono broadcast.
jax.config.update("jax_transfer_guard", "disallow")
jax.config.update("jax_threefry_partitionable", True)

# Short aliases for the JAX sharding types used throughout main().
NamedSharding, P = jax.sharding.NamedSharding, jax.sharding.PartitionSpec
|
|
|
|
|
|
def main(argv):
  """Trains a GIVT-style model on latents of a frozen VAE and runs evaluators.

  The experiment is read from the `--config` flag. Metrics and checkpoints
  are written under `--workdir` when it is set. Training resumes from
  `<workdir>/checkpoint.bv` if a `-LAST` marker exists there. Otherwise it
  resumes from `config.resume`, or initializes params from
  `config.model_init` when given.

  Args:
    argv: Unused; positional command-line arguments left after flag parsing.
  """
  del argv
  jax.distributed.initialize()

  # TF is only used for the input pipeline; keep it off the GPUs.
  tf.config.set_visible_devices([], "GPU")

  config = flags.FLAGS.config
  workdir = flags.FLAGS.workdir
  logging.info(
      f"\u001b[33mHello from process {jax.process_index()} holding "
      f"{jax.local_device_count()}/{jax.device_count()} devices and "
      f"writing to workdir {workdir}.\u001b[0m")

  save_ckpt_path = None
  if workdir:
    gfile.makedirs(workdir)
    save_ckpt_path = os.path.join(workdir, "checkpoint.bv")

  # NOTE(review): nothing in this function submits work to this pool; it is
  # only closed and joined at the end.
  pool = multiprocessing.pool.ThreadPool()

  # Importing these modules registers their preprocessing ops.
  for m in config.get("pp_modules", ["ops_general", "ops_image", "ops_text",
                                     "proj.uvim.pp_ops", "proj.givt.pp_ops"]):
    importlib.import_module(f"big_vision.pp.{m}")

  # Experiment / work-unit ids (none in this setup).
  xid, wid = -1, -1

  def fillin(s):
    """Hook for expanding user-given paths (e.g. `config.resume`).

    No placeholders are defined in this setup, so `s` is returned unchanged.
    """
    return s

  def info(s, *a):
    logging.info("\u001b[33mNOTE\u001b[0m: " + s, *a)

  def write_note(note):
    # Only the first process logs notes, to avoid duplicated output.
    if jax.process_index() == 0:
      info("%s", note)

  mw = u.BigVisionMetricWriter(xid, wid, workdir, config)

  # Hook up chrono's reporting as early as possible so timings are captured.
  u.chrono.inform(measure=mw.measure, write_note=write_note)

  # `config.mesh` is a list of (axis_name, axis_size) pairs; by default a
  # single "data" axis spanning all devices.
  config_mesh = config.get("mesh", [("data", jax.device_count())])
  sharding_rules = config.get("sharding_rules", [("act_batch", "data")])
  mesh_axes, mesh_size = tuple(zip(*config_mesh))

  # Resolve any -1 entry in the axis sizes to a concrete value by reshaping
  # the device list; the mesh utils below need concrete sizes.
  mesh_size = np.array(jax.devices()).reshape(mesh_size).shape
  device_mesh = mesh_utils.create_device_mesh(mesh_size)

  # One flat device order, reused by the input pipeline and the evaluators so
  # every component agrees on which device gets which data.
  devices_flat = device_mesh.flatten()

  # ---------------------------------------------------------------------------
  # Input pipeline.
  write_note("Initializing train dataset...")
  batch_size = config.input.batch_size
  if batch_size % jax.device_count() != 0:
    raise ValueError(f"Batch size ({batch_size}) must "
                     f"be divisible by device number ({jax.device_count()})")
  info("Global batch size %d on %d hosts results in %d local batch size. With "
       "%d dev per host (%d dev total), that's a %d per-device batch size.",
       batch_size, jax.process_count(), batch_size // jax.process_count(),
       jax.local_device_count(), jax.device_count(),
       batch_size // jax.device_count())

  train_ds, ntrain_img = input_pipeline.training(config.input)

  total_steps = u.steps("total", config, ntrain_img, batch_size)

  def get_steps(name, default=ValueError, cfg=config):
    return u.steps(name, cfg, ntrain_img, batch_size, total_steps, default)

  u.chrono.inform(total_steps=total_steps, global_bs=batch_size,
                  steps_per_epoch=ntrain_img / batch_size)

  info("Running for %d steps, that means %f epochs",
       total_steps, total_steps * batch_size / ntrain_img)

  # Start the input pipeline (with on-device prefetch) as early as possible.
  n_prefetch = config.get("prefetch_to_device", 1)
  train_iter = input_pipeline.start_global(train_ds, devices_flat, n_prefetch)

  # ---------------------------------------------------------------------------
  # Models: the VAE (encoder/decoder), the GIVT model, and an optional adaptor
  # applied between VAE latents and the model.
  write_note(f"Creating {config.vae.model_name} model...")
  vae_mod = importlib.import_module(
      f"big_vision.models.{config.vae.model_name}")
  vae = vae_mod.Model(**config.vae.get("model", {}))

  write_note(f"Creating {config.model_name} model...")
  model_mod = importlib.import_module(f"big_vision.models.{config.model_name}")
  model_config = config.get("model", {})
  model = model_mod.Model(**model_config)

  if config.get("adaptor_name"):
    write_note(f"Creating {config.adaptor_name} model...")
    adaptor_mod = importlib.import_module(
        f"big_vision.models.{config.adaptor_name}")
    adaptor = adaptor_mod.Model(num_channels=model_config.out_dim,
                                **config.adaptor.model)
  else:
    adaptor = None

  def init(rng):
    """Initializes trainable params (model, plus adaptor under `params_adaptor`)."""
    def _get_dummy_input(input_name, dtype=jnp.int64):
      # Zeros shaped like the dataset feature, or None if it has no such key.
      if input_name in train_ds.element_spec:
        return jnp.zeros(train_ds.element_spec[input_name].shape, dtype=dtype)
      return None

    dummy_img = _get_dummy_input("image", dtype=jnp.float32)
    dummy_labels = _get_dummy_input("labels")
    dummy_cond_img = _get_dummy_input("cond_image", dtype=jnp.float32)
    local_batch_size = dummy_img.shape[0]

    code_shape = (
        local_batch_size, model_config.seq_len, model_config.out_dim)
    dummy_code = jnp.zeros(code_shape, jnp.float32)

    input_mask = model.get_input_mask_training(
        jax.random.PRNGKey(0), (local_batch_size, model_config.seq_len)
    )
    params = model.init(rng, dummy_code, dummy_labels, image=dummy_cond_img,
                        input_mask=input_mask)["params"]

    if adaptor is not None:
      _, rng_adaptor = jax.random.split(rng)
      adaptor_variables = adaptor.init(rng_adaptor, dummy_code)
      params_adaptor = flax.core.unfreeze(adaptor_variables["params"])
      # `model.init` may return an immutable FrozenDict (older flax), which
      # does not support item assignment; unfreeze before adding a key.
      params = flax.core.unfreeze(params)
      params["params_adaptor"] = params_adaptor

    return params

  rng = jax.random.PRNGKey(u.put_cpu(config.get("seed", 0)))

  write_note("Inferring parameter shapes...")
  rng, rng_init = jax.random.split(rng)
  params_shape = jax.eval_shape(init, rng_init)

  write_note("Inferring optimizer state shapes...")
  tx, sched_fns = bv_optax.make(config, params_shape, sched_kw=dict(
      total_steps=total_steps, batch_size=batch_size, data_size=ntrain_img))
  opt_shape = jax.eval_shape(tx.init, params_shape)
  # CPU-jitted copies of the schedules, used only for logging.
  sched_fns_cpu = [u.jit_cpu()(sched_fn) for sched_fn in sched_fns]

  # The VAE params are closed over by vae_encode / the predict fns and are
  # never part of the optimized train_state, so the VAE is not trained here.
  assert "model_init" in config.vae
  params_vae = vae_mod.load(None, config.vae.model_init,
                            **config.vae.get("model_load", {}))

  def vae_encode(images, rng=None, reparametrize=True):
    """Encodes `images`; samples via `vae.reparametrize` unless disabled.

    When `reparametrize` is True, `rng` must be a dict with a "dropout" key;
    otherwise the encoder mean `mu` is returned.
    """
    mu, logvar = vae.apply({"params": params_vae}, images, method=vae.encode)
    if reparametrize:
      assert rng is not None and "dropout" in rng
      return vae.apply({"params": params_vae}, mu, logvar,
                       method=vae.reparametrize, rngs=rng)
    return mu

  if jax.process_index() == 0:
    num_params = sum(
        np.prod(p.shape) for p in jax.tree_util.tree_leaves(params_shape))
    mw.measure("num_params", num_params)

  # ---------------------------------------------------------------------------
  # Device mesh, shardings and the initial train state.
  write_note("Creating device mesh...")
  mesh = jax.sharding.Mesh(device_mesh, mesh_axes)
  repl_sharding = NamedSharding(mesh, P())

  write_note("Inferring shardings...")
  train_state_shape = {"params": params_shape, "opt": opt_shape}

  strategy = config.get("sharding_strategy", [(".*", "replicate")])
  train_state_sharding = bv_sharding.infer_sharding(
      train_state_shape, strategy=strategy, mesh=mesh)

  write_note("Transferring train_state to devices...")
  # RNG keys are kept replicated across the mesh.
  rng_init = u.reshard(rng_init, repl_sharding)

  # Create params and optimizer state directly with their target shardings.
  params = jax.jit(init, out_shardings=train_state_sharding["params"])(rng_init)
  opt = jax.jit(tx.init, out_shardings=train_state_sharding["opt"])(params)

  rng, rng_loop = jax.random.split(rng, 2)
  rng_loop = u.reshard(rng_loop, repl_sharding)
  del rng  # Not used past this point.

  # Everything the update step reads and writes.
  train_state = {"params": params, "opt": opt}
  del params, opt  # Avoid accidental reuse of the pre-train_state handles.

  write_note("Logging parameter overview...")
  parameter_overview.log_parameter_overview(
      train_state["params"], msg="Init params",
      include_stats="global", jax_logging_process=0)

  # ---------------------------------------------------------------------------
  # Loss, update step and evaluation/prediction functions.
  def loss_fn(params, images, labels, cond_images, rng):
    """Computes the training loss on VAE latents of `images`.

    Args:
      params: Trainable params; includes `params_adaptor` when an adaptor is
        configured.
      images: Inputs passed to `vae_encode`.
      labels: Class labels, or None.
      cond_images: Conditioning images passed as `image=` to the model, or
        None.
      rng: PRNG key, split into dropout, input-mask and label-drop keys.

    Returns:
      (loss, metrics). `loss` is the mean NLL, restricted to positions where
      the input mask is set when the model returns one. `metrics` holds the
      unreduced "nll" and, when masked, "fraction_masked_out" (the mask's
      mean over axis 1).
    """
    rng, rng_dropout = jax.random.split(rng, 2)
    rng, rng_mask = jax.random.split(rng, 2)
    _, rng_droplabels = jax.random.split(rng, 2)

    rng_dropout = {"dropout": rng_dropout}

    sequence = vae_encode(images, rng_dropout)
    if adaptor is not None:
      sequence = adaptor.apply({"params": params["params_adaptor"]},
                               sequence, method=adaptor.forward)

    b, s, _ = sequence.shape

    input_mask = model.get_input_mask_training(rng_mask, (b, s))
    drop_labels = model.get_drop_labels(rng_droplabels, batch_size=b)

    _, pdf = model.apply(
        {"params": params}, sequence, labels,
        image=cond_images,
        train=True,
        input_mask=input_mask,
        drop_labels=drop_labels,
        rngs=rng_dropout)

    nll = -pdf.log_prob(sequence)
    metrics = {"nll": nll}
    if input_mask is not None:
      metrics["fraction_masked_out"] = input_mask.astype(jnp.float32).mean(
          axis=1
      )
      if nll.ndim == 3:
        # Broadcast the (b, s) mask over a trailing feature axis of `nll`.
        input_mask = input_mask[:, :, None]
      # Zero out excluded positions before the masked mean; presumably so
      # non-finite values there cannot leak into the loss or its gradient.
      nll = jnp.where(input_mask, nll, 0.0)
      loss = nll.mean(where=input_mask)
    else:
      loss = nll.mean()

    return loss, metrics

  @functools.partial(
      jax.jit,
      donate_argnums=(0,),
      out_shardings=(train_state_sharding, repl_sharding))
  def update_fn(train_state, rng, batch):
    """Runs one optimizer step; returns (new train_state, measurements)."""
    images = batch["image"]
    labels, cond_images = batch.get("labels"), batch.get("cond_image")

    # Derive a fresh key for every step from the fixed loop key.
    step_count = bv_optax.get_count(train_state["opt"], jittable=True)
    rng = jax.random.fold_in(rng, step_count)

    measurements = {}

    _, rng_model = jax.random.split(rng, 2)

    params, opt = train_state["params"], train_state["opt"]

    (loss, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(
        params, images, labels, cond_images, rng_model)
    updates, opt = tx.update(grads, opt, params)
    params = optax.apply_updates(params, updates)
    train_state = {"params": params, "opt": opt}

    measurements["training_loss"] = loss
    gs = jax.tree_util.tree_leaves(
        bv_optax.replace_frozen(config.schedule, grads, 0.))
    measurements["l2_grads"] = jnp.sqrt(sum([jnp.vdot(g, g) for g in gs]))
    ps = jax.tree_util.tree_leaves(params)
    measurements["l2_params"] = jnp.sqrt(sum([jnp.vdot(p, p) for p in ps]))
    us = jax.tree_util.tree_leaves(updates)
    measurements["l2_updates"] = jnp.sqrt(
        sum([jnp.vdot(upd, upd) for upd in us]))

    if adaptor is not None:
      ps_a = jax.tree_util.tree_leaves(params["params_adaptor"])
      measurements["l2_params_adaptor"] = jnp.sqrt(
          sum([jnp.vdot(p, p) for p in ps_a]))

    measurements.update({f"train/{k}": v.mean() for k, v in metrics.items()})

    return train_state, measurements

  def validation_fn(train_state, batch, seed=0):
    """Returns loss_fn's metrics averaged over all but the leading axis."""
    params = train_state["params"]
    local_rng = trainer_utils.get_local_rng(seed, batch)

    _, aux = loss_fn(
        params, batch["image"], batch.get("labels"),
        batch.get("cond_image"), local_rng)
    return {
        key: jnp.mean(value, axis=tuple(range(1, value.ndim)))
        for key, value in aux.items()
    }

  def predict_fn_teacher_forcing(train_state, batch, seed=0):
    """Decodes one sample of the teacher-forced predictive distribution."""
    params = train_state["params"]
    image, labels = batch["image"], batch.get("labels")

    local_rng = trainer_utils.get_local_rng(seed, batch)

    rng_dropout = {"dropout": local_rng}
    sequence = vae_encode(image, rng_dropout)

    if adaptor is not None:
      sequence = adaptor.apply({"params": params["params_adaptor"]},
                               sequence, method=adaptor.forward)

    b, s, _ = sequence.shape
    input_mask = model.get_input_mask_teacher_forced((b, s))

    # Pass the conditioning image (None when absent), as loss_fn does.
    _, pdf = model.apply(
        {"params": params}, sequence, labels,
        image=batch.get("cond_image"),
        train=True, input_mask=input_mask, rngs=rng_dropout)

    rng_sample, _ = jax.random.split(local_rng, 2)
    sampled_sequence = pdf.sample(seed=rng_sample)

    if adaptor is not None:
      sampled_sequence = adaptor.apply({"params": params["params_adaptor"]},
                                       sampled_sequence, method=adaptor.inverse)
    logits = vae.apply(
        {"params": params_vae}, sampled_sequence, method=vae.decode)

    return {"logits": logits}

  def predict_fn_rep(train_state, batch, seed=0):
    """Returns model representations (`return_reps=True`) of `batch["image"]`.

    Takes the evaluator batch like the other predict fns. Previously this
    took `image` and read an unrelated outer `batch`, which failed before the
    first training step.
    """
    assert model.style == "ar"
    assert model.drop_labels_probability == 1.0
    params = train_state["params"]

    local_rng = trainer_utils.get_local_rng(seed, batch)

    rng_dropout = {"dropout": local_rng}
    sequence = vae_encode(batch["image"], rng_dropout)
    if adaptor is not None:
      # The model is trained on adapted latents (see loss_fn).
      sequence = adaptor.apply({"params": params["params_adaptor"]},
                               sequence, method=adaptor.forward)
    placeholder_labels = jnp.zeros((sequence.shape[0],), dtype=jnp.int32)

    return model.apply({"params": params}, sequence, labels=placeholder_labels,
                       return_reps=True, method=model.decode)

  def predict_fn_sampling(train_state, batch, seed=0):
    """Samples latents with the model and decodes them with the VAE.

    Returns a dict with "logits" (VAE decoder output) and, for "ar" models,
    "logprobs" as returned by the AR decoder.
    """
    params = train_state["params"]
    labels = batch.get("labels")

    local_rng = trainer_utils.get_local_rng(seed, batch)
    code_logprobs = None

    if model.style == "ar":
      if labels is None:
        # Without labels, take the number of samples from an image input if
        # present, else from the config.
        if "image" in batch:
          sampling_batch_size = batch["image"].shape[0]
        elif "cond_image" in batch:
          sampling_batch_size = batch["cond_image"].shape[0]
        else:
          sampling_batch_size = config.get("sampling_batch_size", 4)
      else:
        sampling_batch_size = None
      sampled_codes, code_logprobs = softar_decode.generate(
          params={"params": params},
          seed=local_rng,
          model=model,
          seq_len=config.model.seq_len,
          feature_dim=config.model.out_dim,
          labels=labels,
          cond_image=batch.get("cond_image"),
          batch_size=sampling_batch_size,
          config=config.get("ar_generation_config"),
      )
    elif model.style == "masked":
      assert "cond_image" not in batch
      sampled_codes = parallel_decode.decode_masked(
          rng=local_rng,
          labels=labels,
          seq_len=config.model.seq_len,
          feature_dim=config.model.out_dim,
          model=model,
          variables={"params": params},
          config=parallel_decode.MaskedGenerationConfig(
              **config.get("masked_generation_config", {})
          ),
      ).current_inputs_q
    else:
      raise NotImplementedError

    if adaptor is not None:
      sampled_codes = adaptor.apply({"params": params["params_adaptor"]},
                                    sampled_codes, method=adaptor.inverse)

    sampled_images = vae.apply(
        {"params": params_vae}, sampled_codes, method=vae.decode)

    sampling_results = {"logits": sampled_images}
    if code_logprobs is not None:
      sampling_results["logprobs"] = code_logprobs
    return sampling_results

  def predict_fn_sampling_panoptic(
      train_state, batch, seed=0, min_fraction=0.0):
    logits = predict_fn_sampling(train_state, batch, seed)["logits"]
    return panoptic_task.panoptic_predictions_from_logits(
        logits["semantics"], logits["instances"], min_fraction=min_fraction)

  def predict_fn_sampling_depth(train_state, batch, seed=0):
    depth = predict_fn_sampling(train_state, batch, seed)["logits"]["depth"]
    depth = trainer_utils.unbin_depth(
        depth, min_depth=config.min_depth, max_depth=config.max_depth,
        num_bins=config.vae.model.inout_specs["depth"][1])
    return {"depth": depth}

  # Built lazily on first use and cached for the rest of training.
  @functools.lru_cache(maxsize=None)
  def evaluators():
    return eval_common.from_config(
        config,
        {
            "validation": validation_fn,
            "sample_teacher_forced": predict_fn_teacher_forcing,
            "sample": predict_fn_sampling,
            "sample_panoptic": predict_fn_sampling_panoptic,
            "sample_depth": predict_fn_sampling_depth,
            "representation": predict_fn_rep,
        },
        lambda s: write_note(f"Init evaluator: {s}…\n{u.chrono.note}"),
        lambda key, cfg: get_steps(key, default=None, cfg=cfg),
        devices_flat,
    )

  # ---------------------------------------------------------------------------
  # Checkpoint resume / model init. A `-LAST` marker in the workdir takes
  # precedence over `config.resume`.
  resume_ckpt_path = None
  if save_ckpt_path and gfile.exists(f"{save_ckpt_path}-LAST"):
    resume_ckpt_path = save_ckpt_path
  elif config.get("resume"):
    resume_ckpt_path = fillin(config.resume)

  ckpt_mngr = None
  if save_ckpt_path or resume_ckpt_path:
    ckpt_mngr = array_serial.GlobalAsyncCheckpointManager()

  if resume_ckpt_path:
    write_note(f"Resuming training from checkpoint {resume_ckpt_path}...")
    # Free the freshly initialized arrays before loading the checkpoint.
    jax.tree_util.tree_map(lambda x: x.delete(), train_state)
    del train_state
    shardings = {
        **train_state_sharding,
        "chrono": jax.tree_util.tree_map(lambda _: repl_sharding,
                                         u.chrono.save()),
    }
    loaded = u.load_checkpoint_ts(
        resume_ckpt_path, tree=shardings, shardings=shardings)
    train_state = {key: loaded[key] for key in train_state_sharding.keys()}

    u.chrono.load(jax.device_get(loaded["chrono"]))
    del loaded
  elif config.get("model_init"):
    write_note(f"Initialize model from {config.model_init}...")
    train_state["params"] = model_mod.load(
        train_state["params"], config.model_init, config.get("model"),
        **config.get("model_load", {}))

    # Reshard so the loaded params have the inferred train-state shardings.
    train_state["params"] = u.reshard(
        train_state["params"], train_state_sharding["params"])

    parameter_overview.log_parameter_overview(
        train_state["params"], msg="restored params",
        include_stats="global", jax_logging_process=0)

  write_note("Inferring the first step number...")
  first_step_device = bv_optax.get_count(train_state["opt"], jittable=True)
  first_step = int(jax.device_get(first_step_device))
  u.chrono.inform(first_step=first_step)

  # Run evals at the very start, or when (re)started at the final step.
  if first_step in (total_steps, 0):
    write_note("Running initial or final evals...")
    mw.step_start(first_step)
    for (name, evaluator, _, prefix) in evaluators():
      if config.evals[name].get("skip_first") and first_step != total_steps:
        continue
      write_note(f"{name} evaluation...\n{u.chrono.note}")
      with u.chrono.log_timing(f"z/secs/eval/{name}"):
        with mesh, flax.linen.logical_axis_rules(sharding_rules):
          for key, value in evaluator.run(train_state):
            # Explicit transfer, as in the in-loop evals below.
            mw.measure(f"{prefix}{key}", jax.device_get(value))

  # ---------------------------------------------------------------------------
  # Training loop.
  prof = None  # Profiler state, threaded through u.startstop_prof.

  write_note("Starting training loop, compiling the first step...")
  for step, batch in zip(range(first_step + 1, total_steps + 1), train_iter):
    if config.get("eval_only", False):
      break
    mw.step_start(step)

    with jax.profiler.StepTraceAnnotation("train_step", step_num=step):
      # Only the first step is timed here (it includes compilation).
      with u.chrono.log_timing("z/secs/update0", noop=step > first_step + 1):
        with mesh, flax.linen.logical_axis_rules(sharding_rules):
          train_state, measurements = update_fn(train_state, rng_loop, batch)

    if jax.process_index() == 0:
      prof = u.startstop_prof(prof, step, first_step, get_steps("log_training"))

    # Report training progress.
    if (u.itstime(step, get_steps("log_training"), total_steps, host=0)
        or (u.chrono.warmup and jax.process_index() == 0)):
      for i, sched_fn_cpu in enumerate(sched_fns_cpu):
        mw.measure(f"global_schedule{i if i else ''}",
                   sched_fn_cpu(u.put_cpu(step - 1)))
      measurements = jax.device_get(measurements)
      for name, value in measurements.items():
        mw.measure(name, value)
      u.chrono.tick(step)
      if not np.isfinite(measurements["training_loss"]):
        raise RuntimeError(f"The loss became nan or inf somewhere within steps "
                           f"[{step - get_steps('log_training')}, {step}]")

    # Checkpoint saving. `keep` marks checkpoints retained permanently.
    keep_ckpt_steps = get_steps("keep_ckpt", None) or total_steps
    if save_ckpt_path and (
        (keep := u.itstime(step, keep_ckpt_steps, total_steps, first=False))
        or u.itstime(step, get_steps("ckpt", None), total_steps, first=True)
    ):
      u.chrono.pause(wait_for=train_state)

      # Shallow copy, since extra entries are added below.
      ckpt = {**train_state}

      # Broadcast chrono's state from one host to all, then store it as a
      # replicated global array. The broadcast needs host<->device transfers,
      # which the module-level transfer guard otherwise disallows.
      with jax.transfer_guard("allow"):
        chrono_ckpt = multihost_utils.broadcast_one_to_all(u.chrono.save())
      chrono_shardings = jax.tree_util.tree_map(lambda _: repl_sharding,
                                                chrono_ckpt)
      ckpt = ckpt | {"chrono": u.reshard(chrono_ckpt, chrono_shardings)}

      u.save_checkpoint_ts(ckpt_mngr, ckpt, save_ckpt_path, step, keep)
      u.chrono.resume()

    for (name, evaluator, log_steps, prefix) in evaluators():
      if u.itstime(step, log_steps, total_steps, first=False, last=True):
        u.chrono.pause(wait_for=train_state)
        u.chrono.tick(step)
        write_note(f"{name} evaluation...\n{u.chrono.note}")
        with u.chrono.log_timing(f"z/secs/eval/{name}"):
          with mesh, flax.linen.logical_axis_rules(sharding_rules):
            for key, value in evaluator.run(train_state):
              mw.measure(f"{prefix}{key}", jax.device_get(value))
        u.chrono.resume()
    mw.step_end()

  # Stop the profiler if it is still running, however the loop ended.
  if jax.process_index() == 0 and prof is not None:
    u.startstop_prof(prof)

  write_note(f"Done!\n{u.chrono.note}")

  pool.close()
  pool.join()
  mw.close()

  # Wait for any in-flight asynchronous checkpoint writes.
  if ckpt_mngr:
    ckpt_mngr.wait_until_finished()

  # NOTE(review): presumably a cross-host barrier so no process exits early.
  u.sync()

  u.maybe_cleanup_workdir(workdir, flags.FLAGS.cleanup, info)
|
|
|
|
|
|
if __name__ == "__main__":
  # Hand control to absl: it parses the command-line flags defined above and
  # then invokes main with the remaining argv.
  app.run(main)
|
|
|