# NOTE: removed extraction artifacts (file-size/commit/line-number metadata) that were not part of the source.
import torch
class TestTimeAugmentation:
    """Test-Time Augmentation (TTA) for image restoration models.

    Runs the wrapped model once per geometric augmentation (horizontal/vertical
    flip, 90/180/270-degree rotation, optionally several scales), undoes each
    augmentation on the corresponding output, and returns the average of all
    de-augmented results.
    """

    def __init__(self, model, dino_net, device, use_flip=True, use_rot=True,
                 use_multi_scale=False, scales=None):
        """
        Args:
            model: The model to apply TTA to.
            dino_net: DINO feature extractor, forwarded to the sliding-window runner.
            device: Device to run inference on (string or torch.device).
            use_flip: Whether to use horizontal and vertical flips.
            use_rot: Whether to use 90-degree rotations.
            use_multi_scale: Whether to use multi-scale testing.
                NOTE(review): this flag is stored but never consulted; multi-scale
                behavior is controlled solely by `scales`. Kept for interface
                compatibility — confirm intent with callers before removing.
            scales: List of scales for multi-scale testing, e.g. [0.8, 1.0, 1.2].
                Defaults to [1.0] (single-scale) when None or empty.
        """
        self.model = model
        self.dino_net = dino_net
        self.device = device
        self.use_flip = use_flip
        self.use_rot = use_rot
        self.use_multi_scale = use_multi_scale
        self.scales = scales or [1.0]

    def _apply_augmentation(self, image, point, normal, aug_type):
        """Apply a single geometric augmentation to the input triple.

        Pixel positions of all three maps are transformed identically; on top of
        that, the in-plane (x, y) channels of the normal map are rotated/negated
        to stay consistent with the new pixel layout.

        NOTE(review): the point map's coordinate channels are NOT transformed,
        only its pixels — verify this matches how the model consumes point maps.

        Args:
            image: Input RGB image, [B, C, H, W].
            point: Point map, [B, C, H, W].
            normal: Normal map, [B, 3, H, W] (channel 0 = x, channel 1 = y).
            aug_type: One of 'original', 'h_flip', 'v_flip', 'rot90', 'rot180',
                'rot270'.

        Returns:
            Tuple of (image, point, normal) with the augmentation applied.

        Raises:
            ValueError: If `aug_type` is not a recognized augmentation name.
        """
        if aug_type == 'original':
            return image, point, normal
        elif aug_type == 'h_flip':
            # torch.flip returns a copy, so the in-place channel negation below
            # cannot clobber the caller's tensors.
            img_aug = torch.flip(image, dims=[3])
            point_aug = torch.flip(point, dims=[3])
            normal_aug = torch.flip(normal, dims=[3])
            # Mirroring left-right reverses the normal's x component.
            normal_aug[:, 0, :, :] = -normal_aug[:, 0, :, :]
            return img_aug, point_aug, normal_aug
        elif aug_type == 'v_flip':
            img_aug = torch.flip(image, dims=[2])
            point_aug = torch.flip(point, dims=[2])
            normal_aug = torch.flip(normal, dims=[2])
            # Mirroring top-bottom reverses the normal's y component.
            normal_aug[:, 1, :, :] = -normal_aug[:, 1, :, :]
            return img_aug, point_aug, normal_aug
        elif aug_type == 'rot90':
            img_aug = torch.rot90(image, k=1, dims=[2, 3])
            point_aug = torch.rot90(point, k=1, dims=[2, 3])
            normal_aug = torch.rot90(normal, k=1, dims=[2, 3])
            # In-plane rotation of the normal vector: (x, y) -> (-y, x).
            # Sign convention assumes image coordinates — TODO confirm against
            # the data pipeline's normal-map convention.
            normal_x = -normal_aug[:, 1, :, :].clone()
            normal_y = normal_aug[:, 0, :, :].clone()
            normal_aug[:, 0, :, :] = normal_x
            normal_aug[:, 1, :, :] = normal_y
            return img_aug, point_aug, normal_aug
        elif aug_type == 'rot180':
            img_aug = torch.rot90(image, k=2, dims=[2, 3])
            point_aug = torch.rot90(point, k=2, dims=[2, 3])
            normal_aug = torch.rot90(normal, k=2, dims=[2, 3])
            # 180-degree rotation negates both in-plane components: (x, y) -> (-x, -y).
            normal_aug[:, 0, :, :] = -normal_aug[:, 0, :, :]
            normal_aug[:, 1, :, :] = -normal_aug[:, 1, :, :]
            return img_aug, point_aug, normal_aug
        elif aug_type == 'rot270':
            img_aug = torch.rot90(image, k=3, dims=[2, 3])
            point_aug = torch.rot90(point, k=3, dims=[2, 3])
            normal_aug = torch.rot90(normal, k=3, dims=[2, 3])
            # Inverse of the rot90 mapping: (x, y) -> (y, -x).
            normal_x = normal_aug[:, 1, :, :].clone()
            normal_y = -normal_aug[:, 0, :, :].clone()
            normal_aug[:, 0, :, :] = normal_x
            normal_aug[:, 1, :, :] = normal_y
            return img_aug, point_aug, normal_aug
        else:
            raise ValueError(f"Unknown augmentation type: {aug_type}")

    def _reverse_augmentation(self, result, aug_type):
        """Undo the pixel-level augmentation on a model output.

        Only pixel positions are reversed (the output is an RGB restoration, not
        a vector field, so no channel sign fixes are needed).

        Args:
            result: Model output tensor, [B, C, H', W'].
            aug_type: The augmentation name that was applied to the input.

        Returns:
            The output mapped back to the original orientation.

        Raises:
            ValueError: If `aug_type` is not a recognized augmentation name.
        """
        if aug_type == 'original':
            return result
        elif aug_type == 'h_flip':
            return torch.flip(result, dims=[3])
        elif aug_type == 'v_flip':
            return torch.flip(result, dims=[2])
        elif aug_type == 'rot90':
            # rot90 by k=3 is the inverse of k=1.
            return torch.rot90(result, k=3, dims=[2, 3])
        elif aug_type == 'rot180':
            # rot180 is its own inverse.
            return torch.rot90(result, k=2, dims=[2, 3])
        elif aug_type == 'rot270':
            return torch.rot90(result, k=1, dims=[2, 3])
        else:
            raise ValueError(f"Unknown augmentation type: {aug_type}")

    def __call__(self, sliding_window, input_img, point, normal):
        """Apply TTA to the model and return the ensemble result.

        Args:
            sliding_window: Callable performing sliding-window inference; invoked
                as sliding_window(model=, input_=, point=, normal=, dino_net=, device=).
            input_img: Input RGB image [B, C, H, W].
            point: Point map [B, C, H, W].
            normal: Normal map [B, 3, H, W].

        Returns:
            Ensemble (uniform average) result with TTA, [B, C, H, W].
        """
        # Build the list of augmentations to run.
        augmentations = ['original']
        if self.use_flip:
            augmentations.extend(['h_flip', 'v_flip'])
        if self.use_rot:
            augmentations.extend(['rot90', 'rot180', 'rot270'])

        # Hoisted out of the loop: previously these were defined only inside the
        # `scale != 1.0` branch and relied on later under the same condition.
        resize_fn = torch.nn.functional.interpolate
        h, w = input_img.shape[2], input_img.shape[3]

        # torch.cuda.amp.autocast() is deprecated and assumes CUDA; derive the
        # device type from self.device so CPU-only inference also works.
        try:
            autocast_device = torch.device(self.device).type
        except (TypeError, RuntimeError):
            autocast_device = 'cuda'

        ensemble_result = torch.zeros_like(input_img)
        ensemble_weight = 0.0

        for scale in self.scales:
            scale_weight = 1.0  # all scales currently weighted equally
            if scale != 1.0:
                new_h, new_w = int(h * scale), int(w * scale)
                input_img_scaled = resize_fn(input_img, size=(new_h, new_w), mode='bilinear', align_corners=False)
                point_scaled = resize_fn(point, size=(new_h, new_w), mode='bilinear', align_corners=False)
                normal_scaled = resize_fn(normal, size=(new_h, new_w), mode='bilinear', align_corners=False)
                # Bilinear interpolation breaks unit length; renormalize
                # (epsilon guards against division by zero).
                normal_norm = torch.sqrt(torch.sum(normal_scaled**2, dim=1, keepdim=True) + 1e-6)
                normal_scaled = normal_scaled / normal_norm
            else:
                input_img_scaled = input_img
                point_scaled = point
                normal_scaled = normal

            for aug_type in augmentations:
                img_aug, point_aug, normal_aug = self._apply_augmentation(
                    input_img_scaled, point_scaled, normal_scaled, aug_type
                )
                # Mixed-precision inference via sliding-window tiling.
                with torch.autocast(device_type=autocast_device):
                    result_aug = sliding_window(
                        model=self.model,
                        input_=img_aug,
                        point=point_aug,
                        normal=normal_aug,
                        dino_net=self.dino_net,
                        device=self.device
                    )
                # Map the output back to the original orientation/size.
                result_aug = self._reverse_augmentation(result_aug, aug_type)
                if scale != 1.0:
                    result_aug = resize_fn(result_aug, size=(h, w), mode='bilinear', align_corners=False)
                ensemble_result += result_aug * scale_weight
                ensemble_weight += scale_weight

        # Uniform average over all (scale, augmentation) runs.
        return ensemble_result / ensemble_weight