"""Packed Sequence Op."""

from typing import Dict, Optional, List, Union

import tensorflow as tf

AUTOTUNE = tf.data.experimental.AUTOTUNE


def pack_dataset(dataset: tf.data.Dataset,
                 key2length: Union[int, Dict[str, int]],
                 keys: Optional[List[str]] = None) -> tf.data.Dataset:
  """Creates a 'packed' version of a dataset on-the-fly.

  Adapted from the mesh-tf implementation.

  This is meant to replace the irritation of having to create a separate
  "packed" version of a dataset to train efficiently on TPU.
  Each example in the output dataset represents several examples in the
  input dataset.

  For each key in the input dataset, two additional keys are created:
    <key>_seg: an int32 tensor identifying the parts
      representing the original example.
    <key>_pos: an int32 tensor identifying the position within the original
      example.

  Example:
    Two input examples get combined to form an output example.
    The input examples are:
      {"inputs": [8, 7, 1, 0], "targets": [4, 1, 0]}
      {"inputs": [2, 3, 4, 1], "targets": [5, 6, 1]}
    The output example is:
      {
        "inputs": [8, 7, 1, 2, 3, 4, 1, 0, 0, 0],
        "inputs_seg": [1, 1, 1, 2, 2, 2, 2, 0, 0, 0],
        "inputs_pos": [0, 1, 2, 0, 1, 2, 3, 0, 0, 0],
        "targets": [4, 1, 5, 6, 1, 0, 0, 0, 0, 0],
        "targets_seg": [1, 1, 2, 2, 2, 0, 0, 0, 0, 0],
        "targets_pos": [0, 1, 0, 1, 2, 0, 0, 0, 0, 0],
      }
    0 represents padding in both the inputs and the outputs.

  Sequences in the incoming examples are truncated to the length given for
  their key in key2length, and the sequences in the output examples all have
  that fixed (padded) length.

  Args:
    dataset: a tf.data.Dataset
    key2length: an integer, or a dict from feature-key to integer
    keys: a list of strings (e.g. ["inputs", "targets"])

  Returns:
    a tf.data.Dataset
  """
  shapes = tf.nest.map_structure(lambda spec: spec.shape, dataset.element_spec)
  if keys is None:
    keys = list(shapes.keys())
  for k in keys:
    if k not in shapes:
      raise ValueError(f'Key {k} not found in dataset. '
                       f'Available keys are {list(shapes.keys())}')
    if not shapes[k].is_compatible_with(tf.TensorShape([None])):
      raise ValueError('Tensors to be packed must be one-dimensional.')
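
  # Make sure the length dict covers every key as well as the "_seg" and
  # "_pos" keys created below.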
  if isinstance(key2length, int):
    key2length = {k: key2length for k in keys}
  else:
    key2length = dict(key2length)
  for k in keys:
    for suffix in ['_seg', '_pos']:
      key2length[k + suffix] = key2length[k]
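
  # Trim each feature to its target length before packing.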
  dataset = dataset.map(
      lambda x: {k: x[k][:key2length[k]] for k in keys},
      num_parallel_calls=AUTOTUNE)
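
  # Setting batch_size to the maximum packed length ensures that each batch
  # holds enough (nonempty) sequences to fill at least one packed example.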
  batch_size = max(key2length.values())
  dataset = dataset.padded_batch(
      batch_size, padded_shapes={k: [-1] for k in keys})
  dataset = _pack_with_tf_ops(dataset, keys, key2length)
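
  # Static Tensor shapes are lost in the packing step; restore them so
  # downstream code sees fixed-length features.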
  def my_fn(x):
    return {k: tf.reshape(v, [key2length[k]]) for k, v in x.items()}

  return dataset.map(my_fn, num_parallel_calls=AUTOTUNE)


def _pack_with_tf_ops(dataset: tf.data.Dataset, keys: List[str],
                      key2length: Dict[str, int]) -> tf.data.Dataset:
  """Helper function for packing a dataset which has already been batched.

  Helper for pack_dataset(); uses tf.while_loop.

  Args:
    dataset: a dataset containing padded batches of examples.
    keys: a list of strings
    key2length: a dict from feature-key to integer

  Returns:
    a dataset.
  """
  empty_example = {}
  for k in keys:
    empty_example[k] = tf.zeros([0], dtype=tf.int32)
    empty_example[k + '_pos'] = tf.zeros([0], dtype=tf.int32)
  keys_etc = empty_example.keys()
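
  # Pads the partially-assembled example out to full length, appends it to
  # the TensorArrays of finished examples, and starts a fresh partial example.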
  def write_packed_example(partial, outputs):
    new_partial = empty_example.copy()
    new_outputs = {}
    for k in keys_etc:
      new_outputs[k] = outputs[k].write(
          outputs[k].size(),
          tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]]))
    return new_partial, new_outputs

  def map_fn(x):
    """Internal function to map over.

    Consumes a batch of input examples and produces a variable number of
    output examples.

    Args:
      x: a padded batch of examples

    Returns:
      a dictionary of packed, padded Tensors
    """
    partial = empty_example.copy()
    i = tf.zeros([], dtype=tf.int32)
    dynamic_batch_size = tf.shape(x[keys[0]])[0]
    outputs = {}
    for k in keys:
      outputs[k] = tf.TensorArray(
          tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]])
      outputs[k + '_pos'] = tf.TensorArray(
          tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]])

    def body_fn(i, partial, outputs):
      """Body function for while_loop.

      Args:
        i: integer scalar
        partial: dictionary of Tensor (partially-constructed example)
        outputs: dictionary of TensorArray

      Returns:
        A triple containing the new values of the inputs.
      """
      can_append = True
      one_example = {}
      for k in keys:
        val = tf.cast(x[k][i], tf.int32)
        # Strip trailing padding; this assumes 0 appears only as padding.
        val = val[:tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))]
        one_example[k] = val
      # Check whether appending this example would overflow any key.
      for k in keys:
        can_append = tf.logical_and(
            can_append,
            tf.less_equal(
                tf.size(partial[k]) + tf.size(one_example[k]), key2length[k]))

      def false_fn():
        # The example does not fit: flush the current partial pack first.
        return write_packed_example(partial, outputs)

      def true_fn():
        return partial, outputs

      partial, outputs = tf.cond(can_append, true_fn, false_fn)
      new_partial = {}
      for k in keys:
        new_seq = one_example[k][:key2length[k]]
        new_seq_len = tf.size(new_seq)
        new_partial[k] = tf.concat([partial[k], new_seq], 0)
        new_partial[k + '_pos'] = tf.concat(
            [partial[k + '_pos'],
             tf.range(new_seq_len)], 0)
      partial = new_partial
      return i + 1, partial, outputs
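
    # Loop over every example in the batch; cond is always true, so
    # maximum_iterations is what actually bounds the loop.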
    i, partial, outputs = tf.while_loop(
        cond=lambda *_: True,
        body=body_fn,
        loop_vars=(i, partial, outputs),
        shape_invariants=(
            tf.TensorShape([]),
            {k: tf.TensorShape([None]) for k in keys_etc},
            {k: tf.TensorShape(None) for k in keys_etc},
        ),
        maximum_iterations=dynamic_batch_size)
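    # Flush the final partially-packed example.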
    _, outputs = write_packed_example(partial, outputs)
    packed = {k: outputs[k].stack() for k in keys_etc}
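    # Derive segment ids from the positions: a position of 0 marks the start
    # of a new segment, so a cumulative sum over those starts numbers the
    # segments 1, 2, ...; multiplying by the non-padding mask zeroes out the
    # ids at padded locations.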
    for k in keys:
      packed[k + '_seg'] = (
          tf.cumsum(
              tf.cast(tf.equal(packed[k + '_pos'], 0), tf.int32), axis=1) *
          tf.cast(tf.not_equal(packed[k], 0), tf.int32))
    return packed

  dataset = dataset.map(map_fn, num_parallel_calls=AUTOTUNE)
  return dataset.unbatch()
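

if __name__ == '__main__':
  # Usage sketch (not part of the original op): packs the two examples from
  # the pack_dataset docstring into a single length-10 example. The feature
  # names and the length of 10 are illustrative assumptions.
  _examples = [
      {'inputs': [8, 7, 1], 'targets': [4, 1]},
      {'inputs': [2, 3, 4, 1], 'targets': [5, 6, 1]},
  ]
  _ds = tf.data.Dataset.from_generator(
      lambda: iter(_examples),
      output_signature={
          'inputs': tf.TensorSpec([None], tf.int32),
          'targets': tf.TensorSpec([None], tf.int32),
      })
  for _example in pack_dataset(_ds, key2length=10):
    print({k: v.numpy() for k, v in _example.items()})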