File size: 1,775 Bytes
3f7c489
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
import os
import torchvision.transforms as transforms


class Augment_RGB_torch:
    """Fixed rotate-and-flip augmentations for image tensors.

    Each ``transformN`` method takes a tensor and returns the augmented
    tensor. ``transform0`` is the identity, ``transform1`` applies a rotation
    by ``self.rotate`` degrees followed by a zoom-and-crop, and
    ``transform2`` to ``transform8`` are quarter-turn rotations and/or a flip
    applied over the last two dimensions.
    """

    def __init__(self, rotate=0):
        """Store the rotation angle (in degrees) used by ``transform1``."""
        self.rotate = rotate

    def transform0(self, torch_tensor):
        """Identity: hand back the very same tensor object."""
        return torch_tensor

    def transform1(self, torch_tensor):
        """Rotate by ``self.rotate`` degrees, upscale 1.3x, crop to input size.

        Dimensions 1 and 2 of ``torch_tensor`` are read as height and width,
        so the input is presumably laid out as (C, H, W) -- TODO confirm
        against the caller.
        """
        height = torch_tensor.shape[1]
        width = torch_tensor.shape[2]

        # Identical lower and upper bounds pin the "random" angle to one value.
        fixed_rotation = transforms.RandomRotation(
            (self.rotate, self.rotate),
            interpolation=transforms.InterpolationMode.BILINEAR,
            expand=False,
        )
        # NOTE(review): the 1.3x zoom before cropping presumably trims part of
        # the empty corners left by the rotation -- confirm the intent.
        enlarge = transforms.Resize(
            (int(height * 1.3), int(width * 1.3)),
            antialias=True,
        )
        # CenterCrop pads with zeros (black) wherever the requested size
        # exceeds the size of the image it receives.
        crop_to_input = transforms.CenterCrop([height, width])

        pipeline = transforms.Compose([fixed_rotation, enlarge, crop_to_input])
        return pipeline(torch_tensor)

    def transform2(self, torch_tensor):
        """One quarter turn, rotating from the last dim toward the one before."""
        return torch_tensor.rot90(k=1, dims=[-1, -2])

    def transform3(self, torch_tensor):
        """Two quarter turns (180 degrees) over the last two dims."""
        return torch_tensor.rot90(k=2, dims=[-1, -2])

    def transform4(self, torch_tensor):
        """Three quarter turns over the last two dims."""
        return torch_tensor.rot90(k=3, dims=[-1, -2])

    def transform5(self, torch_tensor):
        """Reverse the order of the second-to-last dimension."""
        return torch.flip(torch_tensor, dims=[-2])

    def transform6(self, torch_tensor):
        """One quarter turn, then reverse the second-to-last dimension."""
        rotated = torch_tensor.rot90(k=1, dims=[-1, -2])
        return torch.flip(rotated, dims=[-2])

    def transform7(self, torch_tensor):
        """Two quarter turns, then reverse the second-to-last dimension."""
        rotated = torch_tensor.rot90(k=2, dims=[-1, -2])
        return torch.flip(rotated, dims=[-2])

    def transform8(self, torch_tensor):
        """Three quarter turns, then reverse the second-to-last dimension."""
        rotated = torch_tensor.rot90(k=3, dims=[-1, -2])
        return torch.flip(rotated, dims=[-2])