"""Utils for few-shot evaluation."""
import functools

import big_vision.datasets.core as ds_core
import big_vision.input_pipeline as input_pipeline
import big_vision.pp.builder as pp_builder
import big_vision.utils as u
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding as Sharding
from jax.sharding import PartitionSpec as P
import numpy as np
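
# Extra constant input feature appended to x in _precompute_cache below: its
# learned weight acts as a bias, and making the constant large means the L2
# penalty on that weight is effectively divided by BIAS_CONSTANT**2, leaving
# the bias almost unregularized.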
BIAS_CONSTANT = 100.0

# Marks this evaluator as implementing the jit-based evaluator API.
API = "jit"
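
# The solvers below are jit-compiled for CPU: few-shot feature matrices are
# small, and solving there keeps the accelerators free.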
@u.jit_cpu(static_argnums=(2,))
def _precompute_cache(x, y, num_classes):
  """Cache quantities to speed-up the computation of L2-regularized least-sq."""
  # Whiten the features using training-set statistics.
  mean = jnp.mean(x, axis=0, keepdims=True)
  std = jnp.std(x, axis=0, keepdims=True) + 1e-5
  x = (x - mean) / std

  # Add a constant feature for the bias, large so it is almost unregularized.
  x = jnp.pad(x, ((0, 0), (0, 1)), constant_values=BIAS_CONSTANT)

  # Turn labels into {-1, +1} one-vs-rest regression targets.
  y = 2.0 * jax.nn.one_hot(y, num_classes) - 1.0

  num_points, dim = x.shape
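  # The L2-regularized least-squares solution is
  #   w = (X^T X + l2_reg * I)^-1 X^T Y
  # or, via the identity (X^T X + a*I)^-1 X^T = X^T (X X^T + a*I)^-1,
  #   w = X^T (X X^T + l2_reg * I)^-1 Y.
  # Eigendecomposing whichever Gram matrix is smaller (dim x dim vs.
  # num_points x num_points) pays the O(min(N, D)^3) cost once; afterwards,
  # solving for any l2_reg only needs matrix products (see _eig_fewshot_acc_fn).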
  if num_points >= dim:
    # Primal form: eigendecompose the (dim, dim) matrix X^T X.
    eigs, q = jnp.linalg.eigh(x.T @ x)
    rhs = q.T @ (x.T @ y)
    lhs = q
  else:
    # Dual (kernel) form: eigendecompose the (num_points, num_points) X X^T.
    eigs, q = jnp.linalg.eigh(x @ x.T)
    rhs = q.T @ y
    lhs = x.T @ q

  cache = {
      "eigs": eigs,
      "rhs": rhs,
      "lhs": lhs,
      "mean": mean,
      "std": std,
  }
  return cache


@u.jit_cpu()
def _eig_fewshot_acc_fn(cache, x_test, y_test, l2_reg):
  """Computes accuracy on (x_test, y_test) for the cached least-squares fit."""
  # Apply the whitening statistics and bias feature from the training set.
  x_test = (x_test - cache["mean"]) / cache["std"]
  x_test = jnp.pad(x_test, ((0, 0), (0, 1)), constant_values=BIAS_CONSTANT)

  rhs = cache["rhs"]
  lhs = cache["lhs"]
  eigs = cache["eigs"]
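  # Rescale the cached eigenbasis by 1 / (eig + l2_reg); combined with lhs and
  # rhs this reconstructs w = (X^T X + l2_reg * I)^-1 X^T Y without repeating
  # the eigendecomposition.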
  scaling = 1.0 / (eigs + l2_reg * jnp.ones_like(eigs))
  scaling = scaling.reshape((1, -1))
  w = (lhs * scaling) @ rhs

  preds = jnp.argmax(x_test @ w, axis=1)
  return jnp.mean(preds == y_test)


class Evaluator:
  """Class for few-shot evaluation."""

  def __init__(self, predict_fn, batch_size,
               datasets, shots, l2_reg,
               pp_train, pp_eval, display_first,
               representation_layer=None, num_seeds=3,
               label_key="label", mask_key="_mask", data_dir=None, *,
               devices):
    self.datasets = datasets
    self.shots = shots
    self.l2_reg = l2_reg
    self.batch_size = batch_size
    self.pp_tr = pp_train
    self.pp_te = pp_eval
    self.display_first = display_first
    self._datasets = {}  # Lazily-loaded datasets, keyed by (name, splits).
    self._repr = {}  # Cached representations, reset at the start of run().
    self.num_seeds = num_seeds
    self.label_key = label_key
    self.mask_key = mask_key
    self.data_dir = data_dir
    self.devices = devices
    self.mesh = jax.sharding.Mesh(devices, ("devices",))
    self.repr_fn = self.get_representation_fn(
        predict_fn, representation_layer)

  def get_representation_fn(self, predict_fn, representation_layer):
    # Fully replicate the outputs (empty PartitionSpec) so they can be
    # fetched whole on the host by _get_repr.
    @functools.partial(jax.jit, out_shardings=Sharding(self.mesh, P()))
    def _repr_fn(train_state, batch, labels, mask):
      zimg, *_, out = predict_fn(train_state, batch)
      if representation_layer is not None:
        rep = u.tree_get(out, representation_layer)
      else:
        rep = zimg
      return rep, labels, mask
    return _repr_fn

  def _get_dataset(self, dataset, train_split, test_split):
    """Lazily loads and caches the given dataset splits."""
    key = (dataset, train_split, test_split)
    try:
      return self._datasets[key]
    except KeyError:
      train_data = ds_core.get(
          name=dataset, split=train_split, data_dir=self.data_dir)
      test_data = ds_core.get(
          name=dataset, split=test_split, data_dir=self.data_dir)
      train_ds, batches_tr = input_pipeline.make_for_inference(
          train_data.get_tfdata(ordered=True),
          num_ex_per_process=train_data.num_examples_per_process(),
          batch_size=self.batch_size,
          preprocess_fn=pp_builder.get_preprocess_fn(self.pp_tr))
      test_ds, batches_te = input_pipeline.make_for_inference(
          test_data.get_tfdata(ordered=True),
          num_ex_per_process=test_data.num_examples_per_process(),
          batch_size=self.batch_size,
          preprocess_fn=pp_builder.get_preprocess_fn(self.pp_te))

      num_classes = train_data.builder.info.features[self.label_key].num_classes
      return self._datasets.setdefault(
          key, (train_ds, batches_tr, test_ds, batches_te, num_classes))

  def _get_repr(self, params, data, steps):
    """Compute representations for the whole dataset."""
    pre_logits_list = []
    labels_list = []
    for batch, _ in zip(
        input_pipeline.start_global(data, self.devices, 0), range(steps)):
      labels, mask = batch.pop(self.label_key), batch.pop(self.mask_key)
      pre_logits, labels, mask = jax.device_get(self.repr_fn(
          params, batch, labels, mask))
      # Drop padding examples that were added to fill the final batch.
      mask = mask.astype(bool)
      pre_logits_list.append(pre_logits[mask])
      labels_list.append(labels[mask])
    pre_logits = np.concatenate(pre_logits_list, axis=0)
    labels = np.concatenate(labels_list, axis=0)
    return pre_logits, labels

  def compute_fewshot_metrics(self, train_state, seed,
                              dataset, train_split, test_split):
    """Compute few-shot metrics on one dataset."""
    if dataset in self._repr:
      repr_train, labels_train, repr_test, labels_test, num_classes = (
          self._repr[dataset])
    else:
      train_ds, steps_tr, test_ds, steps_te, num_classes = self._get_dataset(
          dataset, train_split, test_split)
      repr_train, labels_train = self._get_repr(train_state, train_ds, steps_tr)
      repr_test, labels_test = self._get_repr(train_state, test_ds, steps_te)
      self._repr[dataset] = (repr_train, labels_train,
                             repr_test, labels_test,
                             num_classes)

    # Deterministic per-seed shuffle of the example indices of each class.
    rng = np.random.default_rng(seed)
    class_indices = [rng.permutation(np.where(labels_train == cls_i)[0])
                     for cls_i in range(num_classes)]

    results = {}
    for shots in self.shots:
      # The first `shots` shuffled examples of each class form the episode.
      all_idx = [indices[:shots] for indices in class_indices]
      all_idx = np.concatenate(all_idx, axis=0)
      x = u.put_cpu(repr_train[all_idx])
      y = u.put_cpu(labels_train[all_idx])
      repr_test, labels_test = u.put_cpu((repr_test, labels_test))
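      # Both solver calls run on CPU (u.jit_cpu / u.put_cpu), keeping the
      # accelerators free while training continues.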
      cache = _precompute_cache(x, y, num_classes)
      acc = _eig_fewshot_acc_fn(
          cache, repr_test, labels_test, u.put_cpu(self.l2_reg))
      results[shots] = jax.device_get(acc)

    return results

  def run(self, train_state):
    """New API executed in terms of old API."""
    self._repr = {}
    for seed in range(self.num_seeds):
      for name, dataset_args in self.datasets.items():
        result = self.compute_fewshot_metrics(train_state, seed, *dataset_args)
        for shots, v in result.items():
          prefix = "a/" if (name, shots) in self.display_first else "z/"
          suffix = f"-seed-{seed}"
          yield f"{prefix}{name}_{shots}shot{suffix}", v
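
# Example wiring (illustrative sketch only): the dataset name, splits, and pp
# strings below are hypothetical placeholders, not part of this module.
#
#   fewshot = Evaluator(
#       predict_fn=predict_fn,  # must return (zimg, ..., out) as used above
#       batch_size=256,
#       datasets={"imagenet": ("imagenet2012", "train[:10000]", "validation")},
#       shots=(1, 5, 10, 25),
#       l2_reg=2.0 ** 10,
#       pp_train="decode|resize(224)|value_range(-1,1)|keep('image','label')",
#       pp_eval="decode|resize(224)|value_range(-1,1)|keep('image','label')",
#       display_first=[("imagenet", 5)],
#       devices=jax.devices(),
#   )
#   for metric_name, value in fewshot.run(train_state):
#     print(metric_name, value)  # e.g. ("a/imagenet_5shot-seed-0", <accuracy>)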