|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Generic tensor preprocessing ops.
|
|
|
|
All preprocessing ops should return a data processing functor. A data example
|
|
is represented as a dictionary of (TF) tensors. The functors output a modified
|
|
dictionary.
|
|
"""
|
|
|
|
import collections
|
|
|
|
from big_vision.pp import utils
|
|
from big_vision.pp.registry import Registry
|
|
import big_vision.utils as bv_utils
|
|
import jax
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
|
|
|
|
@Registry.register("preprocess_ops.value_range")
@utils.InKeyOutKey()
def get_value_range(vmin=-1, vmax=1, in_min=0, in_max=255.0, clip_values=False):
  """Linearly maps values from [in_min, in_max] to [vmin, vmax].

  `in_min` and `in_max` may be equal-length lists, in which case the individual
  channels are rescaled independently.

  Args:
    vmin: A scalar. Lower bound of the output range.
    vmax: A scalar. Upper bound of the output range.
    in_min: A scalar or a list of input min values to scale. If a list, the
      length should match the number of channels in the image.
    in_max: A scalar or a list of input max values to scale. If a list, the
      length should match the number of channels in the image.
    clip_values: Whether to clip the output values to [vmin, vmax].

  Returns:
    A function to rescale the values.
  """

  def _value_range(image):
    """Rescales `image`; the computation is done in float32."""
    lower = tf.constant(in_min, tf.float32)
    upper = tf.constant(in_max, tf.float32)
    # First normalize to [0, 1] w.r.t. the input range...
    unit = (tf.cast(image, tf.float32) - lower) / (upper - lower)
    # ...then stretch/shift into the requested output range.
    scaled = vmin + unit * (vmax - vmin)
    if not clip_values:
      return scaled
    return tf.clip_by_value(scaled, vmin, vmax)

  return _value_range
|
|
|
|
|
|
@Registry.register("preprocess_ops.lookup")
@utils.InKeyOutKey()
def get_lookup(mapping, npzkey="fnames", sep=None):
  """Maps strings to numbers using a table loaded from a file.

  Args:
    mapping: Path of the file holding the mapping. If it ends in ".npz", the
      keys are read from the array stored under `npzkey` and each key maps to
      its position in that array. Otherwise the file is read as text with one
      entry per line.
    npzkey: Name of the array holding the keys (only used for ".npz" files).
    sep: Only used for text files. If None, every line is a key that maps to
      its line index. Otherwise every line is split on `sep` into a key and an
      integer value.

  Returns:
    A function looking up its input in the table; unknown keys map to -1.
  """
  if mapping.endswith(".npz"):
    with tf.io.gfile.GFile(mapping, "rb") as npz_file:
      keys = np.array(np.load(npz_file, allow_pickle=False)[npzkey])
    vals = np.arange(len(keys))
  else:
    with tf.io.gfile.GFile(mapping, "r") as text_file:
      lines = text_file.read().splitlines()
    if sep is None:
      keys, vals = lines, np.arange(len(lines))
    else:
      keys, raw_vals = zip(*(line.split(sep) for line in lines))
      vals = [int(raw) for raw in raw_vals]

  def _do_the_mapping(needle):
    """Returns the number associated with `needle`, or -1 if unknown."""
    # The table is created under init_scope so that it is not built inside the
    # function graph currently being traced.
    with tf.init_scope():
      initializer = tf.lookup.KeyValueTensorInitializer(keys, vals)
      table = tf.lookup.StaticHashTable(initializer, -1)
    return table.lookup(needle)

  return _do_the_mapping
|
|
|
|
|
|
@Registry.register("preprocess_ops.onehot")
def get_onehot(depth,
               key="labels",
               key_result=None,
               multi=True,
               on=1.0,
               off=0.0):
  """One-hot encodes the input.

  Args:
    depth: Length of the one-hot vector (how many classes).
    key: Key of the data to be one-hot encoded.
    key_result: Key under which to store the result (same as `key` if None).
    multi: If there are multiple labels, whether to merge them into the same
      "multi-hot" vector (True) or keep them as an extra dimension (False).
    on: Value to fill in for the positive label (default: 1).
    off: Value to fill in for negative labels (default: 0).

  Returns:
    A function that takes the data dictionary, stores the encoded labels in it
    and returns it.
  """

  def _onehot(data):
    """Encodes `data[key]` and writes the result back into `data`."""
    labels = tf.cast(data[key], tf.int64)
    if labels.shape.rank > 0 and multi:
      # Scatter a 1 per label into a `depth`-sized vector. Repeated labels add
      # up, hence the clipping to [0, 1] before applying the on/off values.
      hits = tf.scatter_nd(
          labels[:, None], tf.ones(tf.shape(labels)[0]), (depth,))
      encoded = tf.clip_by_value(hits, 0, 1) * (on - off) + off
    else:
      encoded = tf.one_hot(labels, depth, on_value=on, off_value=off)
    data[key_result or key] = encoded
    return data

  return _onehot
|
|
|
|
|
|
@Registry.register("preprocess_ops.keep")
def get_keep(*keys):
  """Returns an op that keeps only the entries whose key is in `keys`."""

  def _keep(data):
    """Builds a new dictionary holding just the requested entries."""
    kept = {}
    for name, value in data.items():
      if name in keys:
        kept[name] = value
    return kept

  return _keep
|
|
|
|
|
|
@Registry.register("preprocess_ops.drop")
def get_drop(*keys):
  """Returns an op that removes the entries whose key is in `keys`."""

  def _drop(data):
    """Builds a new dictionary without the unwanted entries."""
    remaining = {}
    for name, value in data.items():
      if name not in keys:
        remaining[name] = value
    return remaining

  return _drop
|
|
|
|
|
|
@Registry.register("preprocess_ops.copy")
def get_copy(inkey, outkey):
  """Copies value of `inkey` into `outkey`."""

  def _identity(leaf):
    return leaf

  def _copy(data):
    """Stores a tree-mapped copy of `data[inkey]` under `outkey`."""
    # Mapping the identity over the tree yields a new tree structure with the
    # same leaves, so nested containers are not shared between both keys.
    source = data[inkey]
    data[outkey] = jax.tree.map(_identity, source)
    return data

  return _copy
|
|
|
|
|
|
@Registry.register("preprocess_ops.squeeze_last_dim")
@utils.InKeyOutKey()
def get_squeeze_last_dim():
  """Returns an op that squeezes the last axis of a tensor."""

  def _squeeze_last_dim(tensor):
    """Removes the trailing axis (`axis=-1`) of `tensor`."""
    squeezed = tf.squeeze(tensor, axis=-1)
    return squeezed

  return _squeeze_last_dim
|
|
|
|
|
|
@Registry.register("preprocess_ops.concat")
def get_concat(inkeys, outkey=None, axis=-1):
  """Concatenates the tensors stored under `inkeys` along `axis`.

  Args:
    inkeys: Keys of the tensors to concatenate, in order.
    outkey: Key to store the result under; defaults to the first of `inkeys`.
    axis: Axis along which to concatenate.

  Returns:
    A function that updates the data dictionary and returns it.
  """

  def _concat(data):
    """Writes the concatenation into `data`."""
    merged = tf.concat([data[name] for name in inkeys], axis)
    data[outkey or inkeys[0]] = merged
    return data

  return _concat
|
|
|
|
|
|
@Registry.register("preprocess_ops.rag_tensor")
@utils.InKeyOutKey()
def get_rag_tensor():
  """Converts the specified feature to ragged tensor."""

  def rag_tensor(raw_tensor):
    """Adds a leading axis of size 1 and converts to a `tf.RaggedTensor`."""
    with_leading_axis = raw_tensor[None]
    return tf.RaggedTensor.from_tensor(with_leading_axis)

  return rag_tensor
|
|
|
|
|
|
@Registry.register("preprocess_ops.pad_to_shape")
@utils.InKeyOutKey()
def get_pad_to_shape(shape, pad_value=0, where="after"):
  """Pads tensor to specified `shape`.

  Args:
    shape: Target shape, one entry per axis of the tensor. An entry of None
      leaves the corresponding axis unpadded.
    pad_value: Value used to fill the padded area.
    where: One of "before", "after" or "both"; the side(s) of each axis that
      receive the padding. With "both", the trailing side gets the extra
      element when the amount is odd.

  Returns:
    A function that pads a tensor.
  """

  def _pads(cur, tgt):
    """Returns the [leading, trailing] padding amounts for a single axis."""
    if tgt is None:
      return [0, 0]
    missing = tgt - cur
    if where == "before":
      return [missing, 0]
    if where == "after":
      return [0, missing]
    if where == "both":
      leading = missing // 2
      return [leading, missing - leading]
    raise KeyError(where)

  def _pad_to_shape(x):
    """Pads `x` and sets the static shape of the result to `shape`."""
    assert len(x.shape.as_list()) == len(shape)
    # Current sizes are taken from the dynamic shape, so axes that are not
    # statically known can be padded as well.
    current = tf.shape(x)
    paddings = [_pads(cur=current[axis], tgt=target)
                for axis, target in enumerate(shape)]
    fill = tf.constant(pad_value, x.dtype)
    padded = tf.pad(x, paddings, constant_values=fill)
    padded.set_shape(shape)
    return padded

  return _pad_to_shape
|
|
|
|
|
|
@Registry.register("preprocess_ops.flatten")
def get_flatten():
  """Flattens the keys of data with separator '/'."""

  def flatten(data):
    """Returns a flat dictionary built from the named leaves of `data`."""
    # Only the (name, value) pairs are needed; the second return value of
    # `tree_flatten_with_names` is discarded.
    named_leaves = bv_utils.tree_flatten_with_names(data)[0]
    return dict(named_leaves)

  return flatten
|
|
|
|
|
|
@Registry.register("preprocess_ops.reshape")
@utils.InKeyOutKey()
def get_reshape(new_shape):
  """Reshapes tensor to a given new shape.

  Args:
    new_shape: new shape for the tensor.

  Returns:
    A function for reshaping a tensor.
  """

  def _reshape(tensor):
    """Reshapes `tensor` to `new_shape`, keeping its dtype."""
    original_dtype = tensor.dtype
    reshaped = tf.reshape(tensor, new_shape)
    return tf.cast(reshaped, original_dtype)

  return _reshape
|
|
|
|
|
|
@Registry.register("preprocess_ops.setdefault")
def get_setdefault(key, value):
  """If `key` is an empty tensor or missing, set it to `value`."""

  def _setdefault(data):
    """Replaces a missing or zero-size `data[key]` by `value`."""
    current = data.get(key, tf.constant(value))
    # The fallback takes the dtype of the existing tensor and its static shape,
    # with dimensions that are zero or unknown replaced by 1.
    fallback = tf.broadcast_to(
        tf.constant(value, dtype=current.dtype),
        [dim or 1 for dim in current.shape])
    is_nonempty = tf.size(current) > 0
    data[key] = tf.cond(is_nonempty, lambda: current, lambda: fallback)
    return data

  return _setdefault
|
|
|
|
|
|
@Registry.register("preprocess_ops.choice")
def get_choice(n="single", key=None, fewer_ok=False, inkey=None, outkey=None):
  """Chooses the same `n` random entries of all `keys`.

  Args:
    n: how many entries to randomly sample (without repeat). Possible values:
      - int: that many entries (or fewer if there's fewer, see `fewer_ok`.)
      - "single": The string "single" only chooses one and drop the leading dim.
      - [min, max]: A pair means randomly take between min/max examples (incl.).
    key: str or list of str: See Note.
    fewer_ok: if False, fail when there's fewer than `n` elements to choose
      from (and hence set static shape to `n` for an integer `n`). If True,
      allow it (and hence leave the leading dimension statically unknown).
    inkey: str or list of str: See Note.
    outkey: str or list of str: See Note.

  Note:
    If key/inkey/outkey is a list, then the same random entries are chosen for
    all of the keys. Other than that, they function the same as InKeyOutKey.

    The outkey can also contain the placeholder `{key}` that'll be replaced
    by the inkey name.

  Examples:
    choice(key="alt_text/text")
    choice(n=128, key=["patches", "positions"])
    choice(inkey=["questions_i18n", "answers_i18n"], outkey=["q", "a"])

  Returns:
    The pp op.
  """

  # Normalize key arguments via `maybe_repeat` so they can be zipped below,
  # then substitute the `{key}` placeholder of every outkey.
  inkeys = utils.maybe_repeat(inkey or key, 1)
  outkeys = utils.maybe_repeat(outkey or key, 1)
  outkeys = [ok.format(key=ik) for ok, ik in zip(outkeys, inkeys)]

  is_single = n == "single"
  is_varlen = isinstance(n, (list, tuple))
  min_n = n[0] if is_varlen else 1 if is_single else n
  # The leading dimension is only guaranteed to be exactly `n` for a fixed
  # integer `n` when having fewer items is treated as an error. Pinning the
  # shape with `fewer_ok=True` would make `tf.ensure_shape` fail at runtime for
  # the very inputs that `fewer_ok` is meant to allow.
  has_static_n = not (is_single or is_varlen or fewer_ok)

  def _choice(data):
    """Samples the same entries from all `inkeys` and stores them in `data`."""
    nitems = tf.shape(data[inkeys[0]])[0]

    # Sanity check that all keys have the same leading dimension, and that it
    # is at least as large as the minimum requested output.
    lengths = [tf.shape(data[k])[0] for k in inkeys]
    checks = [tf.debugging.assert_equal(length, nitems) for length in lengths]
    if not fewer_ok:  # All lengths are equal, so checking one suffices.
      checks.append(tf.debugging.assert_greater_equal(nitems, min_n))
    with tf.control_dependencies(checks):
      nitems = tf.identity(nitems)

    if is_single:
      index = tf.random.uniform([], 0, nitems, dtype=tf.int32)
    else:
      # Subsample by shuffling and taking the first `end` indices, but...
      indices = tf.random.shuffle(tf.range(nitems))
      end = n
      if is_varlen:
        end = tf.random.uniform([], n[0], n[1] + 1, dtype=tf.int32)
      # ...keep the original order among the chosen entries.
      indices = tf.sort(indices[:end])

    for ik, ok in zip(inkeys, outkeys):
      if is_single:
        result = data[ik][index]
      else:
        result = tf.gather(data[ik], indices, axis=0)
        if has_static_n:  # Give static shape when we can.
          result = tf.ensure_shape(result, [n] + [None] * (result.ndim - 1))
      data[ok] = result

    return data

  return _choice
|
|
|
|
|
|
def _shuffled_index(count, nitems, seed):
  """Returns index from a shuffled sequence (items only repeat after epoch).

  Args:
    count: Integer tensor; position in the (conceptually endless) sequence.
    nitems: Number of items; cast to the dtype of `count`.
    seed: Seed that the epoch number is folded into before shuffling.

  Returns:
    The entry at position `count % nitems` of a permutation of `range(nitems)`
    which is seeded by `seed` and the epoch `count // nitems`.
  """
  total = tf.cast(nitems, count.dtype)
  epoch = count // total
  offset = count % total
  # Folding the epoch into the seed gives every pass its own permutation.
  epoch_seed = tf.random.fold_in(seed, epoch)
  permutation = tf.random.experimental.stateless_shuffle(
      tf.range(total), seed=epoch_seed)
  return permutation[offset]
|
|
|
|
|
|
@Registry.register("preprocess_ops.choice_no_replacement")
def get_choice_no_replacement(key=None, inkey=None, outkey=None):
  """Chooses the same random (no replacement) entry of all `keys`.

  Note: Consider using this for iterating over small datasets with a small
  number of epochs. It differs from `choice(n='single')` in that if an example,
  as identified by its `_id` field, is seen N times then it will cycle through
  all the inkeys values before repeating them. Additionally each repetition
  uses a different order.

  Caveats: requires dataset to provide a _id field and uses host RAM to keep a
  counter how often each id is seen. It is also not robust to preemptions.

  Args:
    key: str or list of str: See Note.
    inkey: str or list of str: See Note.
    outkey: str or list of str: See Note.

  Note:
    If key/inkey/outkey is a list, then the same random entries are chosen for
    all of the keys. Other than that, they function the same as InKeyOutKey.

    The outkey can also contain the placeholder `{key}` that'll be replaced
    by the inkey name.

  Examples:
    choice_no_replacement(key="alt_text/text")
    choice_no_replacement(key=["patches", "positions"])
    choice_no_replacement(
        inkey=["questions_i18n", "answers_i18n"], outkey=["q", "a"])

  Returns:
    The pp op.
  """

  inkeys = utils.maybe_repeat(inkey or key, 1)
  outkeys = utils.maybe_repeat(outkey or key, 1)
  outkeys = [ok.format(key=ik) for ok, ik in zip(outkeys, inkeys)]

  # Host-side (Python) state: how often each example id was seen so far.
  # Entries start at -1 so that the first visit of an id yields a count of 0.
  times_seen = collections.defaultdict(lambda: -1)

  def _seen_count(example_id):
    """Bumps and returns the visit counter of `example_id`."""
    uid = example_id.item()
    visits = times_seen[uid] + 1
    times_seen[uid] = visits
    return visits

  # Drawn once when the op is constructed; the per-example and per-epoch seeds
  # in `_choice` are derived from it.
  seed = tf.random.uniform(
      [2], minval=tf.int32.min, maxval=tf.int32.max, dtype=tf.int32)

  def _choice(data):
    """Picks the next entry of this example's current permutation."""
    nitems = tf.shape(data[inkeys[0]])[0]

    # Sanity check that all keys have the same leading dimension.
    same_length = [
        tf.debugging.assert_equal(tf.shape(data[name])[0], nitems)
        for name in inkeys
    ]
    with tf.control_dependencies(same_length):
      nitems = tf.identity(nitems)

    example_id = data["_id"]

    # The counter lives in Python, hence the stateful numpy_function call.
    raw_count = tf.numpy_function(
        _seen_count, (example_id,), Tout=tf.int64, stateful=True)
    count = tf.cast(raw_count, tf.int32)
    nitems = tf.cast(nitems, tf.int32)
    epoch = count // nitems
    offset = count % nitems

    # Seed depends on the example and on how many full passes over its entries
    # were completed, so each pass uses its own permutation.
    per_example_seed = tf.random.fold_in(seed, example_id)
    per_epoch_seed = tf.random.fold_in(per_example_seed, epoch)
    permutation = tf.random.experimental.stateless_shuffle(
        tf.range(nitems), seed=per_epoch_seed)
    index = permutation[offset]

    for ik, ok in zip(inkeys, outkeys):
      data[ok] = data[ik][index]
    return data

  return _choice
|
|
|