|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Evaluator for the classfication task."""
|
|
|
|
|
|
import functools
|
|
|
|
from big_vision.evaluators import common
|
|
import big_vision.utils as u
|
|
import jax
|
|
import jax.numpy as jnp
|
|
|
|
|
|
|
|
|
|
# NOTE(review): presumably tells the surrounding evaluator framework which
# calling convention this module follows; the eval function here is built
# with `jax.jit`. Confirm against the framework that reads it.
API = 'jit'
|
|
|
|
|
|
|
|
|
|
@functools.cache
def get_eval_fn(predict_fn, loss_name):
  """Returns a `jax.jit`-compiled function computing per-batch metric sums.

  The result is memoized per `(predict_fn, loss_name)` pair, so evaluators
  sharing both get back the same jitted function object.

  Args:
    predict_fn: Callable `(train_state, batch) -> (logits, ...)`; only the
      first returned value is used.
    loss_name: Name of a loss function in `big_vision.utils`, called with
      `logits=`, `labels=` and `reduction=False`. It is looked up here, so an
      unknown name raises `AttributeError` right away instead of on the first
      evaluation step.

  Returns:
    `_eval_fn(train_state, batch, labels, mask) -> (ncorrect, loss, nseen)`:
    masked sums of the top-1 label values and of the per-example losses,
    plus the sum of the effective mask (the number of counted examples when
    mask and labels are 0/1).
  """
  # Resolve eagerly: fail fast on a bad name and keep the lookup out of jit.
  loss_fn = getattr(u, loss_name)

  @jax.jit
  def _eval_fn(train_state, batch, labels, mask):
    logits, *_ = predict_fn(train_state, batch)

    # Rows whose labels are all zero have a row max of 0 and are thereby
    # excluded from every sum below.
    mask *= labels.max(axis=1)

    loss = loss_fn(logits=logits, labels=labels, reduction=False)
    loss = jnp.sum(loss * mask)

    # Each example contributes its label value at the highest-logit class
    # (1 for a top-1 hit with one-/multi-hot labels, 0 otherwise).
    top1_idx = jnp.argmax(logits, axis=1)
    top1_correct = jnp.take_along_axis(
        labels, top1_idx[:, None], axis=1)[:, 0]
    ncorrect = jnp.sum(top1_correct * mask)
    nseen = jnp.sum(mask)
    return ncorrect, loss, nseen

  return _eval_fn
|
|
|
|
|
|
class Evaluator:
  """Classification evaluator.

  Yields top-1 precision (`prec@1`) and mean per-example loss (`loss`),
  both normalized by the number of examples the eval function counted.
  """

  def __init__(self, predict_fn, loss_name, label_key='labels', **kw):
    """Builds the evaluator.

    Args:
      predict_fn: Callable `(train_state, batch) -> (logits, ...)`; only the
        first returned value is used.
      loss_name: Name of a loss function in `big_vision.utils`.
      label_key: Key under which each batch dict stores its labels.
      **kw: Forwarded to `common.eval_input_pipeline`.
    """
    self.get_data_iter, self.steps = common.eval_input_pipeline(**kw)
    self.eval_fn = get_eval_fn(predict_fn, loss_name)
    self.label_key = label_key

  def run(self, train_state):
    """Computes all metrics.

    Runs `self.steps` batches through the eval function. The label and
    `_mask` entries are popped from each batch before it is passed on.

    Yields:
      `('prec@1', value)` then `('loss', value)`. Both are NaN when nothing
      was counted (zero steps, or every example masked out).
    """
    # Accumulate in Python floats (float64): the per-batch results arrive as
    # float32 scalars, and summing those in float32 drifts on large sets.
    ncorrect, loss, nseen = 0.0, 0.0, 0.0
    for _, batch in zip(range(self.steps), self.get_data_iter()):
      labels, mask = batch.pop(self.label_key), batch.pop('_mask')
      batch_ncorrect, batch_loss, batch_nseen = jax.device_get(
          self.eval_fn(train_state, batch, labels, mask))
      ncorrect += float(batch_ncorrect)
      loss += float(batch_loss)
      nseen += float(batch_nseen)

    if not nseen:
      # Nothing counted: report NaN rather than raising ZeroDivisionError.
      nseen = float('nan')
    yield ('prec@1', ncorrect / nseen)
    yield ('loss', loss / nseen)
|
|
|