|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""AutoAugment and RandAugment policies for enhanced image preprocessing.
|
|
|
|
AutoAugment Reference: https://arxiv.org/abs/1805.09501
|
|
RandAugment Reference: https://arxiv.org/abs/1909.13719
|
|
|
|
This code is forked from
|
|
https://github.com/tensorflow/tpu/blob/11d0db15cf1c3667f6e36fecffa111399e008acd/models/official/efficientnet/autoaugment.py
|
|
"""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import dataclasses
|
|
import inspect
|
|
import math
|
|
import tensorflow.compat.v1 as tf
|
|
from tensorflow_addons import image as contrib_image
|
|
|
|
|
|
|
|
# Normalisation constant for augmentation levels: the level-to-argument
# helpers in this module divide a level by this value before scaling it to
# the range of the individual operation.
_MAX_LEVEL = 10.
|
|
|
|
|
|
@dataclasses.dataclass
class HParams:
  """Scaling constants shared by the AutoAugment and RandAugment policies.

  Attributes:
    cutout_const: Multiplier used when a policy level is converted into the
      `pad_size` argument of the `Cutout` operation.
    translate_const: Multiplier used when a policy level is converted into
      the pixel offset of the `TranslateX` / `TranslateY` operations.
  """

  cutout_const: int
  translate_const: int
|
|
|
|
|
|
def policy_v0():
  """Returns the `v0` AutoAugment policy (the one used in the paper).

  Returns:
    A freshly built list of 25 sub-policies. Every sub-policy is a list of
    two `(operation_name, probability, level)` tuples that are applied in
    order.
  """
  sub_policies = (
      (('Equalize', 0.8, 1), ('ShearY', 0.8, 4)),
      (('Color', 0.4, 9), ('Equalize', 0.6, 3)),
      (('Color', 0.4, 1), ('Rotate', 0.6, 8)),
      (('Solarize', 0.8, 3), ('Equalize', 0.4, 7)),
      (('Solarize', 0.4, 2), ('Solarize', 0.6, 2)),
      (('Color', 0.2, 0), ('Equalize', 0.8, 8)),
      (('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)),
      (('ShearX', 0.2, 9), ('Rotate', 0.6, 8)),
      (('Color', 0.6, 1), ('Equalize', 1.0, 2)),
      (('Invert', 0.4, 9), ('Rotate', 0.6, 0)),
      (('Equalize', 1.0, 9), ('ShearY', 0.6, 3)),
      (('Color', 0.4, 7), ('Equalize', 0.6, 0)),
      (('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)),
      (('Solarize', 0.6, 8), ('Color', 0.6, 9)),
      (('Solarize', 0.2, 4), ('Rotate', 0.8, 9)),
      (('Rotate', 1.0, 7), ('TranslateY', 0.8, 9)),
      (('ShearX', 0.0, 0), ('Solarize', 0.8, 4)),
      (('ShearY', 0.8, 0), ('Color', 0.6, 4)),
      (('Color', 1.0, 0), ('Rotate', 0.6, 2)),
      (('Equalize', 0.8, 4), ('Equalize', 0.0, 8)),
      (('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)),
      (('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)),
      (('Posterize', 0.8, 2), ('Solarize', 0.6, 10)),
      (('Solarize', 0.6, 8), ('Equalize', 0.6, 1)),
      (('Color', 0.8, 6), ('Rotate', 0.4, 5)),
  )
  # Callers receive plain (mutable) lists, one new object per call.
  return [list(operations) for operations in sub_policies]
|
|
|
|
|
|
def policy_vtest():
  """Returns a minimal AutoAugment policy intended for debugging.

  Returns:
    A list holding a single sub-policy; both of its operations have
    probability 1.0, so they are always applied.
  """
  only_sub_policy = [('TranslateX', 1.0, 4), ('Equalize', 1.0, 10)]
  return [only_sub_policy]
|
|
|
|
|
|
def blend(image1, image2, factor):
  """Blends `image1` and `image2` according to `factor`.

  The result is `image1 + factor * (image2 - image1)`:
    * `factor == 0.0` returns `image1` unchanged,
    * `factor == 1.0` returns `image2` unchanged,
    * values strictly between 0.0 and 1.0 interpolate linearly,
    * any other value extrapolates, and the result is clipped to [0, 255].

  Args:
    image1: An image Tensor of type uint8.
    image2: An image Tensor of type uint8.
    factor: A floating point value above 0.0.

  Returns:
    A blended image Tensor of type uint8.
  """
  # The two trivial factors need no arithmetic at all.
  if factor == 0.0:
    return tf.convert_to_tensor(image1)
  if factor == 1.0:
    return tf.convert_to_tensor(image2)

  start = tf.to_float(image1)
  end = tf.to_float(image2)
  blended = start + factor * (end - start)

  # Interpolation between two uint8 images cannot leave [0, 255], so the
  # clip is only needed when extrapolating.
  if 0.0 < factor < 1.0:
    return tf.cast(blended, tf.uint8)
  return tf.cast(tf.clip_by_value(blended, 0.0, 255.0), tf.uint8)
|
|
|
|
|
|
def cutout(image, pad_size, replace=0):
  """Applies cutout (https://arxiv.org/abs/1708.04552) to `image`.

  A rectangular region of at most (2*pad_size x 2*pad_size) pixels, centred
  on a location drawn uniformly over the whole image, is overwritten with
  `replace`. The region is truncated where it would extend past the image
  border.

  Args:
    image: An image Tensor of type uint8.
    pad_size: Half of the side length of the region that is masked out; the
      full region is (2*pad_size x 2*pad_size) before truncation.
    replace: Pixel value written into the masked region.

  Returns:
    An image Tensor that is of type uint8.
  """
  height = tf.shape(image)[0]
  width = tf.shape(image)[1]

  # Draw the centre of the masked region (row first, then column).
  center_row = tf.random_uniform(
      shape=[], minval=0, maxval=height, dtype=tf.int32)
  center_col = tf.random_uniform(
      shape=[], minval=0, maxval=width, dtype=tf.int32)

  # Number of untouched rows/columns on each side of the masked region.
  pad_top = tf.maximum(0, center_row - pad_size)
  pad_bottom = tf.maximum(0, height - center_row - pad_size)
  pad_left = tf.maximum(0, center_col - pad_size)
  pad_right = tf.maximum(0, width - center_col - pad_size)

  hole_shape = [height - (pad_top + pad_bottom),
                width - (pad_left + pad_right)]
  paddings = [[pad_top, pad_bottom], [pad_left, pad_right]]

  # The mask is 0 inside the hole and 1 everywhere else.
  mask = tf.pad(
      tf.zeros(hole_shape, dtype=image.dtype), paddings, constant_values=1)
  mask = tf.tile(tf.expand_dims(mask, -1), [1, 1, 3])

  fill = tf.ones_like(image, dtype=image.dtype) * replace
  return tf.where(tf.equal(mask, 0), fill, image)
|
|
|
|
|
|
def solarize(image, threshold=128):
  """Inverts every pixel that is not below `threshold`.

  Args:
    image: An image Tensor.
    threshold: Pixels strictly below this value are kept as they are; every
      other pixel `p` becomes `255 - p`.

  Returns:
    The solarized image Tensor.
  """
  keep_original = image < threshold
  return tf.where(keep_original, image, 255 - image)
|
|
|
|
|
|
def solarize_add(image, addition=0, threshold=128):
  """Adds `addition` to every pixel that is below `threshold`.

  The sum is computed in int64 and clipped to [0, 255] before it is cast
  back to uint8, so it cannot wrap around.

  Args:
    image: An image Tensor.
    addition: Amount added to the pixels that are below `threshold`.
    threshold: Pixels at or above this value are left untouched.

  Returns:
    The resulting image Tensor.
  """
  shifted = tf.cast(image, tf.int64) + addition
  shifted = tf.clip_by_value(shifted, 0, 255)
  shifted = tf.cast(shifted, tf.uint8)
  return tf.where(image < threshold, shifted, image)
|
|
|
|
|
|
def color(image, factor):
  """Equivalent of PIL Color.

  Blends a grayscale copy of `image` (the degenerate image) with `image`
  itself using `factor`; see `blend` for how `factor` is interpreted.
  """
  grayscale = tf.image.rgb_to_grayscale(image)
  degenerate = tf.image.grayscale_to_rgb(grayscale)
  return blend(degenerate, image, factor)
|
|
|
|
|
|
def contrast(image, factor):
  """Equivalent of PIL Contrast.

  Blends `image` with a uniform gray image (the degenerate image) whose
  value is the rounded mean of the grayscale version of `image`.

  Args:
    image: An RGB image Tensor of type uint8.
    factor: Blend factor passed on to `blend`; with 0.0 the uniform gray
      image is returned and with 1.0 the original image is returned.

  Returns:
    The contrast-adjusted image Tensor of type uint8.
  """
  grayscale = tf.image.rgb_to_grayscale(image)

  # PIL uses int(mean(grayscale) + 0.5) as the gray level of the degenerate
  # image. The previous implementation summed the histogram *counts* and
  # divided by 256, which yields (number of pixels / 256) instead of the
  # mean pixel value.
  mean = tf.reduce_mean(tf.cast(grayscale, tf.float32))
  mean = tf.floor(mean + 0.5)

  degenerate = tf.ones_like(grayscale, dtype=tf.float32) * mean
  degenerate = tf.clip_by_value(degenerate, 0.0, 255.0)
  degenerate = tf.image.grayscale_to_rgb(tf.cast(degenerate, tf.uint8))
  return blend(degenerate, image, factor)
|
|
|
|
|
|
def brightness(image, factor):
  """Equivalent of PIL Brightness.

  Blends an all-zero (black) image with `image` using `factor`; see `blend`
  for how `factor` is interpreted.
  """
  black = tf.zeros_like(image)
  return blend(black, image, factor)
|
|
|
|
|
|
def posterize(image, bits):
  """Equivalent of PIL Posterize.

  Clears the `8 - bits` low-order bits of every pixel by shifting right and
  then left again by that amount.

  Args:
    image: An image Tensor of an integer type.
    bits: Number of high-order bits to keep.

  Returns:
    The posterized image Tensor.
  """
  dropped_bits = 8 - bits
  truncated = tf.bitwise.right_shift(image, dropped_bits)
  return tf.bitwise.left_shift(truncated, dropped_bits)
|
|
|
|
|
|
def rotate(image, degrees, replace):
  """Rotates the image by `degrees`.

  The image is wrapped with an extra marker channel before the rotation so
  that pixels left empty by the rotation can be found afterwards and filled
  with `replace` (see `wrap` / `unwrap`).

  Args:
    image: An image Tensor of type uint8.
    degrees: Float, a scalar angle in degrees to rotate the image by. Its
      sign selects the direction of rotation, as defined by
      `contrib_image.rotate`.
    replace: A one or three value 1D tensor to fill empty pixels caused by
      the rotate operation.

  Returns:
    The rotated version of image.
  """
  # `contrib_image.rotate` is given the angle in radians.
  radians = degrees * (math.pi / 180.0)
  rotated = contrib_image.rotate(wrap(image), radians)
  return unwrap(rotated, replace)
|
|
|
|
|
|
def translate_x(image, pixels, replace):
  """Equivalent of PIL Translate in X dimension.

  Args:
    image: An image Tensor.
    pixels: Horizontal shift; it is negated before being handed to
      `contrib_image.translate`.
    replace: A one or three value 1D tensor to fill the pixels that the
      shift leaves empty.

  Returns:
    The translated image Tensor.
  """
  offset = [-pixels, 0]
  shifted = contrib_image.translate(wrap(image), offset)
  return unwrap(shifted, replace)
|
|
|
|
|
|
def translate_y(image, pixels, replace):
  """Equivalent of PIL Translate in Y dimension.

  Args:
    image: An image Tensor.
    pixels: Vertical shift; it is negated before being handed to
      `contrib_image.translate`.
    replace: A one or three value 1D tensor to fill the pixels that the
      shift leaves empty.

  Returns:
    The translated image Tensor.
  """
  offset = [0, -pixels]
  shifted = contrib_image.translate(wrap(image), offset)
  return unwrap(shifted, replace)
|
|
|
|
|
|
def shear_x(image, level, replace):
  """Equivalent of PIL Shearing in X dimension.

  Args:
    image: An image Tensor.
    level: Shear coefficient; it is placed in the second entry of the
      8-element transform vector given to `contrib_image.transform`.
    replace: A one or three value 1D tensor to fill the pixels that the
      shear leaves empty.

  Returns:
    The sheared image Tensor.
  """
  shear_transform = [1., level, 0., 0., 1., 0., 0., 0.]
  sheared = contrib_image.transform(wrap(image), shear_transform)
  return unwrap(sheared, replace)
|
|
|
|
|
|
def shear_y(image, level, replace):
  """Equivalent of PIL Shearing in Y dimension.

  Args:
    image: An image Tensor.
    level: Shear coefficient; it is placed in the fourth entry of the
      8-element transform vector given to `contrib_image.transform`.
    replace: A one or three value 1D tensor to fill the pixels that the
      shear leaves empty.

  Returns:
    The sheared image Tensor.
  """
  shear_transform = [1., 0., 0., level, 1., 0., 0., 0.]
  sheared = contrib_image.transform(wrap(image), shear_transform)
  return unwrap(sheared, replace)
|
|
|
|
|
|
def autocontrast(image):
  """Implements Autocontrast function from PIL using TF ops.

  Each of the three channels is stretched independently so that its minimum
  maps to 0 and its maximum maps to 255. A channel whose minimum equals its
  maximum is returned unchanged.

  Args:
    image: A 3D uint8 tensor.

  Returns:
    The image after it has had autocontrast applied to it and will be of type
    uint8.
  """

  def scale_channel(channel):
    """Stretches one 2D channel to the full [0, 255] range."""
    lo = tf.to_float(tf.reduce_min(channel))
    hi = tf.to_float(tf.reduce_max(channel))

    def stretch():
      # Linear map that sends `lo` to 0 and `hi` to 255.
      scale = 255.0 / (hi - lo)
      offset = -lo * scale
      stretched = tf.to_float(channel) * scale + offset
      stretched = tf.clip_by_value(stretched, 0.0, 255.0)
      return tf.cast(stretched, tf.uint8)

    # Only rescale when the channel is not constant (avoids dividing by 0).
    return tf.cond(hi > lo, stretch, lambda: channel)

  scaled_channels = [scale_channel(image[:, :, c]) for c in range(3)]
  return tf.stack(scaled_channels, 2)
|
|
|
|
|
|
def sharpness(image, factor):
  """Implements Sharpness function from PIL using TF ops.

  A smoothed copy of the image is built with a 3x3 kernel and then blended
  with the original image using `factor` (see `blend`).

  Args:
    image: A 3D image Tensor with 3 channels, of type uint8.
    factor: Blend factor; 0.0 returns the smoothed image and 1.0 returns the
      original image.

  Returns:
    The resulting image Tensor of type uint8.
  """
  orig_image = image

  # The convolution works on a float, batched (4D) tensor.
  batched = tf.expand_dims(tf.cast(image, tf.float32), 0)

  # 3x3 smoothing kernel, normalised to sum to 1 and replicated per channel.
  smooth_kernel = tf.constant(
      [[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=tf.float32,
      shape=[3, 3, 1, 1]) / 13.
  smooth_kernel = tf.tile(smooth_kernel, [1, 1, 3, 1])

  # The depthwise convolution is pinned to the CPU.
  with tf.device('/cpu:0'):
    smoothed = tf.nn.depthwise_conv2d(
        batched, smooth_kernel, [1, 1, 1, 1], padding='VALID', rate=[1, 1])
  smoothed = tf.clip_by_value(smoothed, 0.0, 255.0)
  smoothed = tf.squeeze(tf.cast(smoothed, tf.uint8), [0])

  # 'VALID' padding drops a one pixel border. Pad the smoothed image back to
  # full size and take the border pixels from the original image instead.
  border = [[1, 1], [1, 1], [0, 0]]
  interior_mask = tf.pad(tf.ones_like(smoothed), border)
  padded_smoothed = tf.pad(smoothed, border)
  degenerate = tf.where(
      tf.equal(interior_mask, 1), padded_smoothed, orig_image)

  return blend(degenerate, orig_image, factor)
|
|
|
|
|
|
def equalize(image):
  """Implements Equalize function from PIL using TF ops.

  Every one of the three channels is equalized independently through a
  lookup table derived from that channel's histogram.

  Args:
    image: A 3D image Tensor with 3 channels.

  Returns:
    The equalized image Tensor of type uint8.
  """

  def scale_channel(im, c):
    """Scale the data in the channel to implement equalize."""
    channel = tf.cast(im[:, :, c], tf.int32)

    # Histogram of the channel values, one bin per value.
    histogram = tf.histogram_fixed_width(channel, [0, 255], nbins=256)

    # The step is computed from the non-empty bins only.
    nonzero_bins = tf.where(tf.not_equal(histogram, 0))
    nonzero_counts = tf.reshape(tf.gather(histogram, nonzero_bins), [-1])
    step = (tf.reduce_sum(nonzero_counts) - nonzero_counts[-1]) // 255

    def build_lut(histo, step):
      # Cumulative sum, shifted by step // 2 and normalised by step.
      lut = (tf.cumsum(histo) + (step // 2)) // step
      # Shift the table by one entry, prepending a 0.
      lut = tf.concat([[0], lut[:-1]], 0)
      # Keep the table entries inside the valid pixel range.
      return tf.clip_by_value(lut, 0, 255)

    # With a zero step the channel is kept as is; otherwise every pixel is
    # mapped through the lookup table built from the full histogram.
    equalized = tf.cond(
        tf.equal(step, 0),
        lambda: channel,
        lambda: tf.gather(build_lut(histogram, step), channel))

    return tf.cast(equalized, tf.uint8)

  equalized_channels = [scale_channel(image, c) for c in range(3)]
  return tf.stack(equalized_channels, 2)
|
|
|
|
|
|
def invert(image):
  """Inverts the image pixels, mapping every value `p` to `255 - p`."""
  pixels = tf.convert_to_tensor(image)
  return 255 - pixels
|
|
|
|
|
|
def wrap(image):
  """Returns 'image' with an extra channel set to all 1s.

  The extra channel has the same dtype as `image` and is appended along
  axis 2; `unwrap` uses it to find pixels that a transform left empty.
  """
  dims = tf.shape(image)
  marker_channel = tf.ones([dims[0], dims[1], 1], image.dtype)
  return tf.concat([image, marker_channel], 2)
|
|
|
|
|
|
def unwrap(image, replace):
  """Unwraps an image produced by wrap.

  Operations like translate and shear on a wrapped Tensor leave 0s in the
  locations they empty, including in the extra channel added by `wrap`.
  Every spatial position whose extra channel is 0 gets its three image
  channels set to `replace`, so that later intensity-based operations do
  not see these empty pixels as pure black. The extra channel is then
  dropped.

  Args:
    image: A 3D Image Tensor with 4 channels.
    replace: A one or three value 1D tensor to fill empty pixels.

  Returns:
    image: A 3D image Tensor with 3 channels.
  """
  image_shape = tf.shape(image)

  # Work on a [num_pixels, channels] view so whole pixels can be selected.
  pixels = tf.reshape(image, [-1, image_shape[2]])

  # The channel added by `wrap` is the one at index 3.
  marker = pixels[:, 3]

  # Extend the fill value with a 1 for the marker channel.
  fill_value = tf.concat([replace, tf.ones([1], image.dtype)], 0)

  # Overwrite every pixel whose marker is 0 with the fill value.
  pixels = tf.where(
      tf.equal(marker, 0),
      tf.ones_like(pixels, dtype=image.dtype) * fill_value,
      pixels)

  restored = tf.reshape(pixels, image_shape)
  return tf.slice(restored, [0, 0, 0], [image_shape[0], image_shape[1], 3])
|
|
|
|
|
|
# Maps the operation names used in policies to the functions implementing
# them. Every function takes the image as its first positional argument.
NAME_TO_FUNC = {
    # Operations that take no level-derived argument.
    'AutoContrast': autocontrast,
    'Equalize': equalize,
    'Invert': invert,
    # Operations parameterised by a level-derived argument.
    'Rotate': rotate,
    'Posterize': posterize,
    'Solarize': solarize,
    'SolarizeAdd': solarize_add,
    'Color': color,
    'Contrast': contrast,
    'Brightness': brightness,
    'Sharpness': sharpness,
    'ShearX': shear_x,
    'ShearY': shear_y,
    'TranslateX': translate_x,
    'TranslateY': translate_y,
    'Cutout': cutout,
}
|
|
|
|
|
|
def _randomly_negate_tensor(tensor):
  """With 50% prob turn the tensor negative."""
  # floor(u + 0.5) with u uniform in [0, 1) is 0 or 1 with equal probability.
  coin = tf.floor(tf.random_uniform([]) + 0.5)
  keep_sign = tf.cast(coin, tf.bool)
  return tf.cond(keep_sign, lambda: tensor, lambda: -tensor)
|
|
|
|
|
|
def _rotate_level_to_arg(level):
  """Converts level to the (randomly signed) rotation angle in degrees.

  A level equal to `_MAX_LEVEL` corresponds to 30 degrees.

  Returns:
    A 1-tuple holding the angle.
  """
  degrees = (level / _MAX_LEVEL) * 30.
  return (_randomly_negate_tensor(degrees),)
|
|
|
|
|
|
def _shrink_level_to_arg(level):
|
|
"""Converts level to ratio by which we shrink the image content."""
|
|
if level == 0:
|
|
return (1.0,)
|
|
|
|
level = 2. / (_MAX_LEVEL / level) + 0.9
|
|
return (level,)
|
|
|
|
|
|
def _enhance_level_to_arg(level):
  """Converts level to the blend factor of the enhancement operations.

  Level 0 maps to 0.1 and a level equal to `_MAX_LEVEL` maps to 1.9.

  Returns:
    A 1-tuple holding the factor.
  """
  factor = (level / _MAX_LEVEL) * 1.8 + 0.1
  return (factor,)
|
|
|
|
|
|
def _shear_level_to_arg(level):
  """Converts level to the (randomly signed) shear coefficient.

  A level equal to `_MAX_LEVEL` corresponds to a coefficient of 0.3.

  Returns:
    A 1-tuple holding the coefficient.
  """
  shear = (level / _MAX_LEVEL) * 0.3
  # The direction of the shear is chosen at random.
  return (_randomly_negate_tensor(shear),)
|
|
|
|
|
|
def _translate_level_to_arg(level, translate_const):
  """Converts level to the (randomly signed) translation offset.

  A level equal to `_MAX_LEVEL` corresponds to `translate_const`.

  Returns:
    A 1-tuple holding the offset.
  """
  offset = (level / _MAX_LEVEL) * float(translate_const)
  # The direction of the translation is chosen at random.
  return (_randomly_negate_tensor(offset),)
|
|
|
|
|
|
def level_to_arg(hparams):
  """Builds the table that converts a policy level into operation arguments.

  Args:
    hparams: Object providing `cutout_const` and `translate_const`; both are
      read only when the corresponding converter is called.

  Returns:
    A dict mapping each operation name to a function that takes a level and
    returns the tuple of extra positional arguments for that operation.
  """

  def no_args(level):
    del level  # These operations only take the image.
    return ()

  def posterize_bits(level):
    return (int((level / _MAX_LEVEL) * 4),)

  def solarize_threshold(level):
    return (int((level / _MAX_LEVEL) * 256),)

  def solarize_addition(level):
    return (int((level / _MAX_LEVEL) * 110),)

  def cutout_pad_size(level):
    return (int((level / _MAX_LEVEL) * hparams.cutout_const),)

  def translate_offset(level):
    return _translate_level_to_arg(level, hparams.translate_const)

  return {
      'AutoContrast': no_args,
      'Equalize': no_args,
      'Invert': no_args,
      'Rotate': _rotate_level_to_arg,
      'Posterize': posterize_bits,
      'Solarize': solarize_threshold,
      'SolarizeAdd': solarize_addition,
      'Color': _enhance_level_to_arg,
      'Contrast': _enhance_level_to_arg,
      'Brightness': _enhance_level_to_arg,
      'Sharpness': _enhance_level_to_arg,
      'ShearX': _shear_level_to_arg,
      'ShearY': _shear_level_to_arg,
      'Cutout': cutout_pad_size,
      'TranslateX': translate_offset,
      'TranslateY': translate_offset,
  }
|
|
|
|
|
|
def _parse_policy_info(name, prob, level, replace_value, augmentation_hparams):
  """Return the function that corresponds to `name` and update `level` param.

  Args:
    name: Operation name, a key of `NAME_TO_FUNC`.
    prob: Probability of applying the operation.
    level: Policy level, converted to the operation's arguments through
      `level_to_arg`.
    replace_value: Fill value appended to the arguments of operations that
      accept a `replace` parameter.
    augmentation_hparams: Hyper-parameters handed to `level_to_arg`.

  Returns:
    A `(func, prob, args)` tuple where `args` is the tuple of positional
    arguments to pass to `func` after the image.
  """
  func = NAME_TO_FUNC[name]
  args = level_to_arg(augmentation_hparams)[name](level)

  parameter_names = inspect.getfullargspec(func).args

  # Operations that declare a `prob` parameter receive it as the first
  # argument after the image.
  if 'prob' in parameter_names:
    args = (prob,) + tuple(args)

  # Operations that declare a `replace` parameter receive the fill value as
  # their last argument, so `replace` has to be the last parameter.
  if 'replace' in parameter_names:
    assert parameter_names[-1] == 'replace'
    args = tuple(args) + (replace_value,)

  return (func, prob, args)
|
|
|
|
|
|
def _apply_func_with_prob(func, image, args, prob):
  """Apply `func` to image w/ `args` as input with probability `prob`.

  Args:
    func: Augmentation function; called as `func(image, *args)`.
    image: The image Tensor to augment.
    args: Tuple of extra positional arguments for `func`.
    prob: Probability of applying `func`. It is replaced by 1.0 when `func`
      declares its own `prob` parameter.

  Returns:
    Either `func(image, *args)` or the unchanged `image`, selected at graph
    execution time.
  """
  assert isinstance(args, tuple)

  # A function that takes `prob` itself is always invoked.
  if 'prob' in inspect.getfullargspec(func).args:
    prob = 1.0

  # floor(u + prob) with u uniform in [0, 1) equals 1 with probability prob.
  draw = tf.random_uniform([], dtype=tf.float32)
  should_apply_op = tf.cast(tf.floor(draw + prob), tf.bool)
  return tf.cond(
      should_apply_op,
      lambda: func(image, *args),
      lambda: image)
|
|
|
|
|
|
def select_and_apply_random_policy(policies, image):
  """Select a random policy from `policies` and apply it to `image`.

  Args:
    policies: List of callables, each mapping an image Tensor to an image
      Tensor.
    image: The image Tensor to augment.

  Returns:
    The image after the randomly selected policy has been applied.
  """
  chosen_index = tf.random_uniform([], maxval=len(policies), dtype=tf.int32)

  # One `tf.cond` is chained per policy; only the one whose index matches
  # the random draw alters the image.
  for index, policy_fn in enumerate(policies):
    image = tf.cond(
        tf.equal(index, chosen_index),
        lambda selected_policy=policy_fn: selected_policy(image),
        lambda: image)
  return image
|
|
|
|
|
|
def build_and_apply_nas_policy(policies, image,
                               augmentation_hparams):
  """Build a policy from the given policies passed in and apply to image.

  Args:
    policies: list of lists of tuples in the form `(func, prob, level)`, `func`
      is a string name of the augmentation function, `prob` is the probability
      of applying the `func` operation, `level` is the input argument for
      `func`.
    image: tf.Tensor that the resulting policy will be applied to.
    augmentation_hparams: Hparams associated with the NAS learned policy.

  Returns:
    A version of image that now has data augmentation applied to it based on
    the `policies` pass into the function.
  """
  # Fill value for pixels left empty by the geometric operations.
  replace_value = [128, 128, 128]

  def make_final_policy(tf_policy_):
    """Wraps a parsed sub-policy into a single image -> image callable."""

    def final_policy(image_):
      for func, prob, args in tf_policy_:
        image_ = _apply_func_with_prob(func, image_, args, prob)
      return image_

    return final_policy

  # Turn every sub-policy into one callable that applies its operations in
  # order, each with its own probability.
  tf_policies = []
  for policy in policies:
    parsed_operations = [
        _parse_policy_info(
            *(list(policy_info) + [replace_value, augmentation_hparams]))
        for policy_info in policy
    ]
    tf_policies.append(make_final_policy(parsed_operations))

  return select_and_apply_random_policy(tf_policies, image)
|
|
|
|
|
|
def distort_image_with_autoaugment(image, augmentation_name):
  """Applies the AutoAugment policy to `image`.

  AutoAugment is from the paper: https://arxiv.org/abs/1805.09501.

  Args:
    image: `Tensor` of shape [height, width, 3] representing an image.
    augmentation_name: The name of the AutoAugment policy to use. The
      available options are `v0` (see `policy_v0`) and `test` (see
      `policy_vtest`).

  Returns:
    The augmented version of `image`.

  Raises:
    ValueError: If `augmentation_name` is not one of the available policies.
  """
  available_policies = {'v0': policy_v0,
                        'test': policy_vtest}
  build_policy = available_policies.get(augmentation_name)
  if build_policy is None:
    raise ValueError('Invalid augmentation_name: {}'.format(augmentation_name))

  augmentation_hparams = HParams(cutout_const=100, translate_const=250)
  return build_and_apply_nas_policy(
      build_policy(), image, augmentation_hparams)
|
|
|
|
|
|
def distort_image_with_randaugment(image, num_layers, magnitude):
  """Applies the RandAugment policy to `image`.

  RandAugment is from the paper https://arxiv.org/abs/1909.13719,

  Args:
    image: `Tensor` of shape [height, width, 3] representing an image.
    num_layers: Integer, the number of augmentation transformations to apply
      sequentially to an image. Represented as (N) in the paper.
    magnitude: Integer, shared magnitude across all augmentation operations.
      Represented as (M) in the paper. It is converted to a float and used as
      the level of every operation.

  Returns:
    The augmented version of `image`.
  """
  # Fill value for pixels left empty by the geometric operations.
  replace_value = [128, 128, 128]
  tf.logging.info('Using RandAug.')
  augmentation_hparams = HParams(cutout_const=40, translate_const=100)
  available_ops = [
      'AutoContrast', 'Equalize', 'Invert', 'Rotate', 'Posterize',
      'Solarize', 'Color', 'Contrast', 'Brightness', 'Sharpness',
      'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Cutout', 'SolarizeAdd']
  num_ops = len(available_ops)

  for layer_num in range(num_layers):
    # Every layer draws the index of the single operation it applies.
    chosen_op = tf.random_uniform([], maxval=num_ops, dtype=tf.int32)
    level = float(magnitude)
    with tf.name_scope('randaug_layer_{}'.format(layer_num)):
      # One `tf.cond` is chained per operation; only the one whose index
      # matches the draw alters the image.
      for op_index, op_name in enumerate(available_ops):
        prob = tf.random_uniform([], minval=0.2, maxval=0.8, dtype=tf.float32)
        func, _, args = _parse_policy_info(
            op_name, prob, level, replace_value, augmentation_hparams)
        image = tf.cond(
            tf.equal(op_index, chosen_op),
            lambda selected_func=func, selected_args=args: selected_func(
                image, *selected_args),
            lambda: image)
  return image
|
|
|