File size: 3,568 Bytes
3f7c489
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e26714
3f7c489
 
 
 
 
 
 
9e26714
 
3f7c489
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import torch
import os
from collections import OrderedDict

def freeze(model):
    """Disable gradient tracking on every parameter of `model`."""
    for param in model.parameters():
        param.requires_grad = False

def unfreeze(model):
    """Re-enable gradient tracking on every parameter of `model`."""
    for param in model.parameters():
        param.requires_grad = True

def is_frozen(model):
    """Return True if at least one parameter has gradients disabled.

    NOTE(review): this reports "any parameter frozen", not "all frozen" —
    a partially frozen model also returns True. Preserved as-is since
    callers may rely on it; confirm intent before tightening.
    """
    return not all(param.requires_grad for param in model.parameters())

def save_checkpoint(model_dir, state, session):
    """Serialize `state` to <model_dir>/model_epoch_<epoch>_<session>.pth.

    `state` must carry an 'epoch' key, which is embedded in the filename.
    """
    filename = "model_epoch_{}_{}.pth".format(state['epoch'], session)
    torch.save(state, os.path.join(model_dir, filename))

def load_checkpoint(model, weights, strict=True):
    """Load model weights from a checkpoint file (CPU-mapped).

    Expects the checkpoint to hold the weights under the 'state_dict' key.
    First tries a direct load; if that fails (key mismatch), retries after
    stripping the 'module.' prefix that DataParallel/DDP adds to every key.

    Args:
        model: the nn.Module to load weights into.
        weights: path to a checkpoint produced by torch.save.
        strict: passed through to Module.load_state_dict.

    Raises:
        RuntimeError: if the keys still mismatch after stripping the prefix
            (and strict=True).
        KeyError: if the checkpoint has no 'state_dict' entry.
    """
    checkpoint = torch.load(weights, map_location=torch.device('cpu'))
    state_dict = checkpoint["state_dict"]
    try:
        # Common case: keys already match the model.
        model.load_state_dict(state_dict, strict=strict)
    except RuntimeError:
        # Checkpoint saved from a DataParallel wrapper: drop the
        # leading 'module.' (7 chars) from each key and retry.
        # startswith (not substring test) so keys that merely contain
        # 'module.' mid-name are left intact.
        stripped = OrderedDict(
            (k[len('module.'):] if k.startswith('module.') else k, v)
            for k, v in state_dict.items()
        )
        model.load_state_dict(stripped, strict=strict)

def load_checkpoint_multigpu(model, weights):
    """Load a DataParallel checkpoint by dropping the first 7 characters
    ('module.') from every key in checkpoint['state_dict'].

    Assumes ALL keys carry the prefix — a key without it is silently
    truncated, matching the original unconditional slice.
    """
    checkpoint = torch.load(weights)
    renamed = OrderedDict(
        (key[7:], value) for key, value in checkpoint["state_dict"].items()
    )
    model.load_state_dict(renamed)

def load_start_epoch(weights):
    """Return the 'epoch' recorded in a checkpoint file (CPU-mapped load)."""
    checkpoint = torch.load(weights, map_location=torch.device('cpu'))
    return checkpoint["epoch"]

def load_optim(optimizer, weights):
    """Restore optimizer state from a checkpoint and return its learning rate.

    Loads checkpoint['optimizer'] into `optimizer` (CPU-mapped), then returns
    the 'lr' of the last param group — the original looped over all groups
    and kept only the final value, so indexing directly is equivalent and
    avoids an unbound `lr` (NameError) on an optimizer with no groups.

    Args:
        optimizer: the torch optimizer to restore into.
        weights: path to a checkpoint with an 'optimizer' state dict.

    Returns:
        The learning rate of the last param group after restoration.
    """
    checkpoint = torch.load(weights, map_location=torch.device('cpu'))
    optimizer.load_state_dict(checkpoint['optimizer'])
    return optimizer.param_groups[-1]['lr']

def get_arch(opt):
    """Instantiate the restoration model named by opt.arch.

    Supported values of opt.arch: 'ShadowFormer', 'DenseSR'. Both are built
    with the same constructor kwargs taken from `opt` (train_ps, embed_dim,
    win_size, token_projection, token_mlp), so those are assembled once
    instead of being duplicated per branch.

    Returns:
        The constructed model (not yet moved to any device).

    Raises:
        ValueError: if opt.arch is not a supported architecture name.
            (ValueError subclasses Exception, so existing `except Exception`
            callers still catch it.)
    """
    from model import ShadowFormer, DenseSR
    arch = opt.arch

    print('You choose ' + arch + '...')

    # Shared constructor arguments for every supported architecture.
    kwargs = dict(
        img_size=opt.train_ps,
        embed_dim=opt.embed_dim,
        win_size=opt.win_size,
        token_projection=opt.token_projection,
        token_mlp=opt.token_mlp,
    )
    if arch == 'ShadowFormer':
        return ShadowFormer(**kwargs)
    if arch == 'DenseSR':
        return DenseSR(**kwargs)
    raise ValueError("Arch error: unknown architecture '{}'".format(arch))


def window_partition(x, win_size):
    """Split a (B, C, H, W) tensor into non-overlapping square windows.

    H and W must be divisible by win_size. Returns a tensor of shape
    (B * H//win_size * W//win_size, C, win_size, win_size), with windows
    ordered batch-first, then row-major over the spatial grid.
    """
    batch, channels, height, width = x.shape
    # Channels-last, then carve both spatial axes into win_size tiles.
    tiles = x.permute(0, 2, 3, 1).reshape(
        batch, height // win_size, win_size, width // win_size, win_size, channels
    )
    # Bring the two tile indices together and flatten them into one axis.
    windows = tiles.permute(0, 1, 3, 2, 4, 5).reshape(-1, win_size, win_size, channels)
    # Back to channels-first per window.
    return windows.permute(0, 3, 1, 2)

# def distributed_concat(var, num_total):
#     var_list = [torch.zeros(1, dtype=var.dtype).cuda() for _ in range(torch.distributed.get_world_size())]
#     torch.distributed.all_gather(var_list, var)
#     # truncate the dummy elements added by SequentialDistributedSampler
#     return var_list[:num_total]

def distributed_concat(var, num_total):
    """All-gather `var` from every rank and return the first `num_total` pieces.

    Requires an initialized torch.distributed process group and CUDA.
    A 0-d tensor is promoted to shape (1,) so all_gather has a dimension
    to work with.
    """
    if var.dim() == 0:
        var = var.view(1)

    world_size = torch.distributed.get_world_size()
    gathered = [torch.zeros_like(var).cuda() for _ in range(world_size)]
    torch.distributed.all_gather(gathered, var)

    # Truncate the dummy elements added by SequentialDistributedSampler.
    return gathered[:num_total]