"""Tests for the FlexiViT model."""
|
|
|
|
from absl.testing import absltest
|
|
from big_vision.models.proj.flexi import vit
|
|
import jax
|
|
from jax import config
|
|
from jax import numpy as jnp
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
|
|
config.update("jax_enable_x64", True)
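# Float64 matters for the comparisons below: the resampled kernels and the
# "highest"-precision convolutions are compared with atol=1e-4, which may not
# hold under float32 round-off.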


class PatchEmbTest(absltest.TestCase):

  def _test_patch_emb_resize(self, old_shape, new_shape, n_patches=100):
    # Embed random patches with the original kernel, then check that resizing
    # the patches and resampling the kernel accordingly yields (approximately)
    # the same embeddings.
    patch_shape = old_shape[:-2]
    resized_patch_shape = new_shape[:-2]
    patches = np.random.randn(n_patches, *old_shape[:-1])
    w_emb = jnp.asarray(np.random.randn(*old_shape))

    old_embeddings = jax.lax.conv_general_dilated(
        patches, w_emb, window_strides=patch_shape, padding="VALID",
        dimension_numbers=("NHWC", "HWIO", "NHWC"), precision="highest")

    patch_resized = tf.image.resize(
        tf.constant(patches), resized_patch_shape, method="bilinear").numpy()
    patch_resized = jnp.asarray(patch_resized).astype(jnp.float64)
    w_emb_resampled = vit.resample_patchemb(w_emb, resized_patch_shape)
    self.assertEqual(w_emb_resampled.shape, new_shape)

    new_embeddings = jax.lax.conv_general_dilated(
        patch_resized, w_emb_resampled, window_strides=resized_patch_shape,
        padding="VALID", dimension_numbers=("NHWC", "HWIO", "NHWC"),
        precision="highest")

    self.assertEqual(old_embeddings.shape, new_embeddings.shape)
    np.testing.assert_allclose(
        old_embeddings, new_embeddings, rtol=1e-1, atol=1e-4)
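
  def test_resample_to_same_shape(self):
    # Hypothetical extra check (not in the original suite): resampling a
    # kernel to its own patch shape should leave it essentially unchanged,
    # since the resize operator, and hence its pseudo-inverse, is the
    # identity. A minimal sketch, assuming resample_patchemb accepts a target
    # shape equal to the current one.
    old = jnp.asarray(np.random.randn(5, 5, 3, 256))
    resampled = vit.resample_patchemb(old, (5, 5))
    self.assertEqual(resampled.shape, old.shape)
    np.testing.assert_allclose(old, resampled, rtol=1e-6, atol=1e-6)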

  def test_resize_square(self):
    out_channels = 256
    patch_sizes = [48, 40, 30, 24, 20, 16, 15, 12, 10, 8, 6, 5]
    for s in patch_sizes:
      old_shape = (s, s, 3, out_channels)
      for t in patch_sizes:
        new_shape = (t, t, 3, out_channels)
        if s <= t:
          # Only upsampling preserves embeddings; the s > t case is covered
          # by test_downsampling below.
          self._test_patch_emb_resize(old_shape, new_shape)

  def test_resize_rectangular(self):
    out_channels = 256
    old_shape = (8, 10, 3, out_channels)
    new_shape = (10, 12, 3, out_channels)
    self._test_patch_emb_resize(old_shape, new_shape)

    old_shape = (8, 6, 3, out_channels)
    new_shape = (9, 15, 3, out_channels)
    self._test_patch_emb_resize(old_shape, new_shape)

    old_shape = (8, 6, 3, out_channels)
    new_shape = (15, 9, 3, out_channels)
    self._test_patch_emb_resize(old_shape, new_shape)

  def test_input_channels(self):
    out_channels = 256
    for c in [1, 3, 10]:
      old_shape = (8, 10, c, out_channels)
      new_shape = (10, 12, c, out_channels)
      self._test_patch_emb_resize(old_shape, new_shape)

  def _test_works(self, old_shape, new_shape):
    # Shape/dtype smoke test only; does not check embedding preservation.
    old = jnp.asarray(np.random.randn(*old_shape))
    resampled = vit.resample_patchemb(old, new_shape[:2])
    self.assertEqual(resampled.shape, new_shape)
    self.assertEqual(resampled.dtype, old.dtype)

  def test_downsampling(self):
    # Downsampling cannot preserve the embeddings (information is lost), so
    # only check that resampling runs and produces the right shape and dtype.
    out_channels = 256
    for t in [4, 5, 6, 7]:
      for c in [1, 3, 5]:
        old_shape = (8, 8, c, out_channels)
        new_shape = (t, t, c, out_channels)
        self._test_works(old_shape, new_shape)

  def _test_raises(self, old_shape, new_shape):
    old = jnp.asarray(np.random.randn(*old_shape))
    with self.assertRaises(AssertionError):
      vit.resample_patchemb(old, new_shape)

  def test_raises_incorrect_dims(self):
    # The input-channel and output dimensions cannot change when resampling.
    old_shape = (8, 10, 3, 256)
    new_shape = (10, 12, 1, 256)
    self._test_raises(old_shape, new_shape)

    old_shape = (8, 10, 1, 256)
    new_shape = (10, 12, 3, 256)
    self._test_raises(old_shape, new_shape)

    old_shape = (8, 10, 3, 128)
    new_shape = (10, 12, 3, 256)
    self._test_raises(old_shape, new_shape)
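

# For reference, a standalone sketch (not part of the original suite) of the
# pseudo-inverse resize ("PI-resize") rule that resample_patchemb is expected
# to implement: build the matrix B of the bilinear resize operator, then solve
# B^T w_new = w_old in the least-squares sense, so that <resize(x), w_new> is
# approximately <x, w_old> for upsampling. Illustrative only; the name below
# is hypothetical.
def _pi_resize_single_channel_sketch(kernel, new_hw):
  # One-hot "images" enumerate the columns of the resize operator B.
  basis = np.eye(kernel.size).reshape(-1, *kernel.shape)
  resized = tf.image.resize(basis[..., None], new_hw, method="bilinear")
  b_mat = resized.numpy().reshape(basis.shape[0], -1).T  # old px -> new px
  # Least-squares solution of B^T w_new = w_old via the pseudo-inverse.
  w_new = np.linalg.pinv(b_mat.T) @ kernel.reshape(-1)
  return w_new.reshape(new_hw)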


if __name__ == "__main__":
  absltest.main()