|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Core data functions, dispatch calls to the requested dataset."""
|
|
import importlib
|
|
|
|
|
|
|
|
|
|
class DataSource:
  """The API that any data source should implement."""

  def get_tfdata(self, ordered, *, process_split=True, allow_cache=True):
    """Creates this data object as a tf.data.Dataset.

    This will be called separately in each process, and it is up to the dataset
    implementation to shard it accordingly if desired!

    Args:
      ordered: if True, the dataset should use deterministic ordering, if False
        it may have undefined ordering. Think of True == val, False == train.
      process_split: if False then every process receives the entire dataset
        (e.g. for evaluators running in a single process).
      allow_cache: whether to allow caching the opened data or not.

    Returns:
      A tf.data.Dataset object.

    Raises:
      RuntimeError: if not implemented by the dataset, but called.
    """
    # NOTE: f-prefix added so the concrete subclass name actually appears in
    # the error message instead of the literal "{self.__class__.__name__}".
    raise RuntimeError(f"not implemented for {self.__class__.__name__}")

  @property
  def total_examples(self):
    """Returns number of examples in the dataset, regardless of sharding."""
    raise RuntimeError(f"not implemented for {self.__class__.__name__}")

  def num_examples_per_process(self):
    """Returns a list of the number of examples for each process.

    This is only needed for datasets that should go through make_for_inference.

    Returns:
      Returns a list of the number of examples for each process.

      Ideally, this would always be `[total() / nprocess] * nprocess`, but in
      reality we can almost never perfectly shard a dataset across arbitrary
      number of processes.

      One alternative option that can work in some cases is to not even shard
      the dataset and thus return `[num_examples()] * nprocess.

    Raises:
      RuntimeError: if not implemented by the dataset, but called.
    """
    raise RuntimeError(f"not implemented for {self.__class__.__name__}")
|
|
|
|
|
|
def get(name, **kw):
  """Instantiates the DataSource for the dataset named `name`.

  Names of the form "bv:module" are resolved to the big_vision-internal
  dataset module `big_vision.datasets.module`; any other name is treated as
  a TFDS dataset name and routed through `big_vision.datasets.tfds`.
  Remaining keyword arguments are forwarded to the DataSource constructor.
  """
  bv_prefix = "bv:"
  if name.startswith(bv_prefix):
    module_name = name[len(bv_prefix):]
    dataset_module = importlib.import_module(f"big_vision.datasets.{module_name}")
    return dataset_module.DataSource(**kw)
  # Fall through: treat `name` as a TFDS dataset identifier.
  dataset_module = importlib.import_module("big_vision.datasets.tfds")
  return dataset_module.DataSource(name, **kw)
|
|
|