|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Script that loads a model and only runs evaluators."""
|
|
|
|
from functools import partial
import importlib
import os

from absl import app
from absl import flags
from absl import logging
import big_vision.evaluators.common as eval_common
import big_vision.utils as u
from clu import parameter_overview
import flax
import flax.jax_utils as flax_utils
import jax
import jax.numpy as jnp
from ml_collections import config_flags
import tensorflow as tf
from tensorflow.io import gfile
|
|
|
|
|
|
# --- Command-line flags --------------------------------------------------------

# The ml_collections config file describing the model and evaluators; it is
# loaded with lock_config=True (the config is locked after loading).
config_flags.DEFINE_config_file(
    "config",
    None,
    "Training configuration.",
    lock_config=True,
)

# Directory that receives metrics/outputs for this run.
flags.DEFINE_string("workdir", None, "Work unit directory.")

# When set, main() removes the workdir once evaluation has finished.
flags.DEFINE_boolean(
    "cleanup",
    False,
    "Delete workdir (only) after successful completion.",
)

# Let JAX read its own `--jax_*` options from the absl command line.
jax.config.parse_flags_with_absl()
|
|
|
|
|
|
def main(argv):
  """Builds a model, loads its parameters and runs the configured evaluators.

  No training happens here. Parameters are first initialized on CPU from the
  model definition (fixed PRNG seed 42). If `config.model_init` is set, they
  are overwritten by `model_mod.load`. They are then replicated with
  `flax_utils.replicate` and passed to every evaluator built from the config.
  The full round of evaluators is repeated `config.eval_repeats` times
  (default 1).

  Args:
    argv: Positional command-line arguments forwarded by `absl.app.run`;
      unused, since all inputs come from flags.
  """
  del argv  # Unused.

  config = flags.FLAGS.config
  workdir = flags.FLAGS.workdir
  logging.info("Workdir: %s", workdir)

  # Import the preprocessing modules named by the config (default:
  # ops_general and ops_image). Presumably importing them is what makes their
  # ops available to evaluator pp strings; confirm in big_vision.pp.
  for m in config.get("pp_modules", ["ops_general", "ops_image"]):
    importlib.import_module(f"big_vision.pp.{m}")

  # This standalone script has no experiment / work-unit ids; -1 is passed to
  # the metric writer as a placeholder.
  xid, wid = -1, -1

  def write_note(note):
    # Only process 0 logs, so multi-process runs do not repeat every note.
    if jax.process_index() == 0:
      logging.info("NOTE: %s", note)

  mw = u.BigVisionMetricWriter(xid, wid, workdir, config)
  u.chrono.inform(measure=mw.measure, write_note=write_note)

  write_note(f"Initializing {config.model_name} model...")
  assert config.get("model.reinit") is None, (
      "I don't think you want any part of the model to be re-initialized.")
  model_mod = importlib.import_module(f"big_vision.models.{config.model_name}")
  model_kw = dict(config.get("model", {}))
  if "num_classes" in config:
    model_kw["num_classes"] = config.num_classes
  model = model_mod.Model(**model_kw)

  # Initialize on CPU to obtain the parameter tree. Dummy inputs are zeros
  # with the configured shapes/dtypes (default: one float32 (1, 224, 224, 3)
  # input).
  @partial(jax.jit, backend="cpu")
  def init(rng):
    input_shapes = config.get("init_shapes", [(1, 224, 224, 3)])
    input_types = config.get("init_types", [jnp.float32] * len(input_shapes))
    dummy_inputs = [jnp.zeros(s, t) for s, t in zip(input_shapes, input_types)]
    things = flax.core.unfreeze(model.init(rng, *dummy_inputs))
    return things.get("params", {})

  with u.chrono.log_timing("z/secs/init"):
    params_cpu = init(jax.random.PRNGKey(42))
  if jax.process_index() == 0:
    parameter_overview.log_parameter_overview(params_cpu, msg="init params")
    num_params = sum(p.size for p in jax.tree.leaves(params_cpu))
    mw.measure("num_params", num_params)

  # Without `model_init`, evaluators run on the randomly initialized params.
  if config.get("model_init"):
    write_note(f"Initialize model from {config.model_init}...")
    params_cpu = model_mod.load(
        params_cpu, config.model_init, config.get("model"),
        **config.get("model_load", {}))
    if jax.process_index() == 0:
      parameter_overview.log_parameter_overview(params_cpu, msg="loaded params")

  write_note("Replicating...")
  params_repl = flax_utils.replicate(params_cpu)

  def predict_fn(params, *a, **kw):
    return model.apply({"params": params}, *a, **kw)

  evaluators = eval_common.from_config(
      config, {"predict": predict_fn, "model": model},
      lambda s: write_note(f"Initializing evaluator: {s}..."),
      # Every evaluator's step setting is reported as 1. Presumably this is
      # its run interval; that is irrelevant here because every evaluator
      # runs on every repeat (confirm in eval_common.from_config).
      lambda key, cfg: 1,
  )

  # Each repeat is logged as its own metric step `s`.
  for s in range(config.get("eval_repeats", 1)):
    mw.step_start(s)
    for (name, evaluator, _, prefix) in evaluators:
      write_note(f"{name} evaluation step {s}...")
      with u.profile(name, noop=name in config.get("no_profile", [])):
        with u.chrono.log_timing(f"z/secs/eval/{name}"):
          for key, value in evaluator.run(params_repl):
            mw.measure(f"{prefix}{key}", value)
          # Sync inside the timing context so the measured time covers all
          # processes finishing this evaluator.
          u.sync()
    u.chrono.flush_timings()
    mw.step_end()

  write_note("Done!")
  mw.close()

  # Presumably a cross-process barrier, so no process exits (and cleanup
  # below does not start) before the others are done.
  u.sync()

  if workdir and flags.FLAGS.cleanup and jax.process_index() == 0:
    gfile.rmtree(workdir)
    # Best-effort: also try to remove the parent directory; any filesystem
    # error is deliberately ignored. `tf` must be imported at module level,
    # otherwise a failure here raises NameError instead of being swallowed.
    # NOTE(review): gfile.remove deletes files; whether it can delete an
    # (empty) directory depends on the filesystem backend — confirm intent.
    try:
      gfile.remove(os.path.join(workdir, ".."))
    except tf.errors.OpError:
      pass
|
|
|
|
|
|
# Script entry point: absl parses the flags, then calls main().
if __name__ == "__main__":
  app.run(main)
|
|
|