|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Evaluator for computing mean of per-example metrics.
|
|
|
|
This evaluator can be used in two ways:
|
|
1. Create a new evaluator with reduced boilerplate by inheriting from it.
|
|
2. For quick prototyping, use it directly with a predict_fn that returns the per-example metrics.
|
|
"""
|
|
from functools import partial
|
|
from typing import Mapping
|
|
|
|
from big_vision.evaluators import common
|
|
|
|
import jax
|
|
import jax.numpy as jnp
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
# Evaluator API flavor tag for this module. NOTE(review): presumably read by
# the surrounding evaluator framework to decide how `Evaluator.run` is driven
# (e.g. with a jittable predict_fn) — confirm against the evaluator loader.
API = 'jit'
|
|
|
|
|
|
|
|
@partial(jax.jit, static_argnums=0)
|
|
def _run_predict_fn(predict_fn, train_state, batch):
|
|
"""Sum per-example metrics weighted by `_mask`."""
|
|
mask = batch['_mask']
|
|
metrics = predict_fn(train_state, batch)
|
|
|
|
assert isinstance(metrics, Mapping), 'predict_fn must return a dict'
|
|
for y in jax.tree.leaves(metrics):
|
|
if y.shape != mask.shape:
|
|
raise ValueError(
|
|
f'Expected per-example metrics of shape {mask.shape} found '
|
|
f'{jax.tree.map(lambda x: x.shape, metrics)}.')
|
|
metrics = {**metrics, '_mask': mask}
|
|
return jax.tree.map(lambda x: jnp.sum(jnp.where(mask, x, 0)), metrics)
|
|
|
|
|
|
class Evaluator:
  """Report the mean of per-example metrics computed by predict_fn.

  `predict_fn(train_state, batch)` must return a dict from metric name to
  per-example metrics of shape [batch_size]. Each metric is summed over the
  examples selected by the batch's `_mask` across all evaluated batches, then
  divided by the summed mask.
  """

  def __init__(self, predict_fn, **kw):
    # `get_data_iter` is called once per `run` to obtain a fresh iterator of
    # batches; `steps` bounds how many batches are consumed.
    self.get_data_iter, self.steps = common.eval_input_pipeline(**kw)
    # Bind the user's predict_fn as the static first argument of the jitted
    # per-batch reducer.
    self.predict_fn = partial(_run_predict_fn, predict_fn)

  def run(self, train_state):
    """Computes all metrics.

    Args:
      train_state: Forwarded to `predict_fn` for every batch.

    Yields:
      `(metric_name, mean_value)` pairs, one per metric returned by
      `predict_fn`.

    Raises:
      ValueError: if no batch was evaluated (e.g. `steps` is 0 or the data
        iterator is empty), since no mean can be computed.
    """
    metrics = []
    # `range` comes first in the zip so the data iterator is not advanced
    # past the last requested step.
    for _, batch in zip(range(self.steps), self.get_data_iter()):
      metrics.append(self.predict_fn(train_state, batch))

    if not metrics:
      # `jax.tree.map` needs at least one tree; fail with a clear message
      # instead of an opaque missing-argument error.
      raise ValueError(
          f'No batches were evaluated (steps={self.steps}); cannot compute '
          'mean metrics.')

    # Transfer all per-batch sums to host in one call, after the loop.
    metrics = jax.device_get(metrics)

    metrics_sum = jax.tree.map(lambda *x: np.sum(x), *metrics)
    mask_sum = metrics_sum.pop('_mask')
    for key, value_sum in metrics_sum.items():
      yield (key, value_sum / mask_sum)
|
|
|