File size: 1,775 Bytes
3f7c489 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
import torch
import os
import torchvision.transforms as transforms
class Augment_RGB_torch:
    """Rotation / flip augmentations for image tensors.

    ``transform0`` is the identity, ``transform1`` applies a fixed-angle
    rotation followed by a zoom-and-crop, and ``transform2`` ..
    ``transform8`` are quarter-turn rotations optionally followed by a
    flip along the second-to-last axis.

    Every transform addresses the spatial axes as the last two dimensions
    of the input, so tensors laid out as ``(..., H, W)`` are handled
    uniformly.
    """

    def __init__(self, rotate=0):
        """Store the rotation angle used by ``transform1``.

        Args:
            rotate: Angle handed to ``transforms.RandomRotation`` as both
                the lower and upper bound (degrees, per torchvision).
        """
        self.rotate = rotate

    def transform0(self, torch_tensor):
        """Return ``torch_tensor`` unchanged (identity augmentation)."""
        return torch_tensor

    def transform1(self, torch_tensor):
        """Rotate by ``self.rotate``, enlarge by 1.3x, and crop back.

        The output has the same spatial size as the input.

        Args:
            torch_tensor: Image tensor whose last two dimensions are
                height and width.

        Returns:
            The rotated, zoomed and centre-cropped tensor.
        """
        # Read H/W from the trailing axes, like every other transform in
        # this class. Identical to shape[1], shape[2] for (C, H, W) input,
        # and also correct when leading batch dimensions are present.
        H, W = torch_tensor.shape[-2:]
        train_transform = transforms.Compose([
            # Equal lower/upper bounds make the "random" rotation a fixed
            # rotation by exactly self.rotate.
            transforms.RandomRotation(
                (self.rotate, self.rotate),
                interpolation=transforms.InterpolationMode.BILINEAR,
                expand=False,
            ),
            # Presumably the 1.3x zoom pushes the empty corners left by the
            # rotation outside the crop window -- TODO confirm intent.
            transforms.Resize((int(H * 1.3), int(W * 1.3)), antialias=True),
            # CenterCrop zero-pads (black) only when the requested size
            # exceeds the image; after the enlargement above it just crops
            # back to the original H x W.
            transforms.CenterCrop([H, W]),
        ])
        return train_transform(torch_tensor)

    def transform2(self, torch_tensor):
        """Rotate one quarter turn in the plane of the last two axes.

        With ``dims=[-1, -2]`` the rotation goes from the last axis toward
        the second-to-last one, so the first row becomes the last column.
        """
        return torch.rot90(torch_tensor, k=1, dims=[-1, -2])

    def transform3(self, torch_tensor):
        """Rotate two quarter turns (180 degrees) in the last two axes."""
        return torch.rot90(torch_tensor, k=2, dims=[-1, -2])

    def transform4(self, torch_tensor):
        """Rotate three quarter turns in the plane of the last two axes."""
        return torch.rot90(torch_tensor, k=3, dims=[-1, -2])

    def transform5(self, torch_tensor):
        """Flip along the second-to-last axis (reverse the row order)."""
        return torch_tensor.flip(-2)

    def transform6(self, torch_tensor):
        """Apply the ``transform2`` rotation, then flip along axis -2."""
        rotated = torch.rot90(torch_tensor, k=1, dims=[-1, -2])
        return rotated.flip(-2)

    def transform7(self, torch_tensor):
        """Apply the ``transform3`` rotation, then flip along axis -2."""
        rotated = torch.rot90(torch_tensor, k=2, dims=[-1, -2])
        return rotated.flip(-2)

    def transform8(self, torch_tensor):
        """Apply the ``transform4`` rotation, then flip along axis -2."""
        rotated = torch.rot90(torch_tensor, k=3, dims=[-1, -2])
        return rotated.flip(-2)
|