|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""TensorFlow Datasets as data source for big_vision."""
|
|
import functools
|
|
|
|
import big_vision.datasets.core as ds_core
|
|
import jax
|
|
import numpy as np
|
|
import overrides
|
|
import tensorflow as tf
|
|
import tensorflow_datasets as tfds
|
|
|
|
|
|
class DataSource(ds_core.DataSource):
  """Use TFDS as a data source."""

  def __init__(self, name, split, data_dir=None, skip_decode=("image",)):
    """Creates a TFDS-backed data source.

    Args:
      name: TFDS dataset name, or the sentinel "from_data_dir" to load the
        dataset found directly at `data_dir`.
      split: TFDS split spec string for the full dataset.
      data_dir: Optional directory the dataset lives in.
      skip_decode: Feature names to leave undecoded (e.g. raw image bytes).
    """
    self.builder = _get_builder(name, data_dir)
    self.split = split
    # Shard the split evenly across JAX processes; each process only ever
    # reads its own sub-split.
    self.process_split = tfds.even_splits(
        split, jax.process_count())[jax.process_index()]
    self.skip_decode = skip_decode

  @overrides.overrides
  def get_tfdata(
      self, ordered=False, *, process_split=True, allow_cache=True, **kw):
    """Returns a tf.data.Dataset for this split.

    Args:
      ordered: If True, read files deterministically (no file shuffling).
      process_split: If True, read only this process's shard of the split.
      allow_cache: If True, reuse a previously built identical pipeline.
      **kw: Forwarded to the dataset factory (e.g. `shuffle_seed`).
    """
    make_dataset = _cached_get_dataset if allow_cache else _get_dataset
    chosen_split = self.process_split if process_split else self.split
    return make_dataset(
        self.builder, self.skip_decode,
        split=chosen_split,
        shuffle_files=not ordered,
        **kw)

  @property
  @overrides.overrides
  def total_examples(self):
    """Number of examples in the full (unsharded) split."""
    return self.builder.info.splits[self.split].num_examples

  @overrides.overrides
  def num_examples_per_process(self):
    """Returns the per-process example counts, matching `even_splits`."""
    sub_splits = tfds.even_splits(self.split, jax.process_count())
    return [self.builder.info.splits[sub].num_examples
            for sub in sub_splits]
|
|
|
|
|
|
@functools.cache
def _get_builder(dataset, data_dir):
  """Returns a (memoized) TFDS builder for `dataset` at `data_dir`."""
  # The sentinel name "from_data_dir" means: load whatever dataset is stored
  # at `data_dir` directly, without a registered TFDS builder class.
  if dataset != "from_data_dir":
    return tfds.builder(dataset, data_dir=data_dir, try_gcs=True)
  return tfds.builder_from_directory(data_dir)
|
|
|
|
|
|
|
|
|
|
def _get_dataset(builder, skip_decode, **kw):
  """Returns a tf.data to be used."""
  # Split off kwargs that belong to the ReadConfig rather than as_dataset.
  read_kw = {}
  if "shuffle_seed" in kw:
    read_kw["shuffle_seed"] = kw.pop("shuffle_seed")

  # Leave the requested features undecoded (only those the dataset has).
  decoders = {
      feature: tfds.decode.SkipDecoding()
      for feature in skip_decode
      if feature in builder.info.features
  }

  ds = builder.as_dataset(
      read_config=tfds.ReadConfig(
          skip_prefetch=True,  # Presumably the outer pipeline prefetches.
          try_autocache=False,  # Presumably caching is decided by callers.
          add_tfds_id=True,  # Needed below to derive a stable example id.
          **read_kw,
      ),
      decoders=decoders,
      **kw)

  def _hash_tfds_id(example):
    # Derive a deterministic integer id from the unique `tfds_id` string.
    # The two constants are fixed keys for the keyed strong hash, so the
    # mapping is stable across runs.
    hashed = tf.strings.to_hash_bucket_strong(
        example["tfds_id"],
        np.iinfo(np.uint32).max,
        [3714561454027272724, 8800639020734831960])
    # Reinterpret the int64 bucket as two int32s and keep the first.
    example["_id"] = tf.bitcast(hashed, tf.int32)[0]
    return example

  return ds.map(_hash_tfds_id)
|
|
# Process-wide memoized variant: `DataSource.get_tfdata(allow_cache=True)`
# uses this so repeated calls with identical arguments reuse the same
# tf.data pipeline object instead of rebuilding it.
_cached_get_dataset = functools.cache(_get_dataset)
|
|
|