import os
from collections import OrderedDict

import torch


def freeze(model):
    """Disable gradients for every parameter of `model`."""
    for p in model.parameters():
        p.requires_grad = False


def unfreeze(model):
    """Re-enable gradients for every parameter of `model`."""
    for p in model.parameters():
        p.requires_grad = True


def is_frozen(model):
    """Return True only when every parameter of `model` is frozen."""
    return not any(p.requires_grad for p in model.parameters())


def save_checkpoint(model_dir, state, session):
    epoch = state['epoch']
    model_out_path = os.path.join(
        model_dir, "model_epoch_{}_{}.pth".format(epoch, session))
    torch.save(state, model_out_path)


def load_checkpoint(model, weights, strict=True):
    checkpoint = torch.load(weights, map_location=torch.device('cpu'))
    state_dict = checkpoint["state_dict"]
    try:
        model.load_state_dict(state_dict, strict=strict)
    except RuntimeError:
        # Checkpoints saved from nn.DataParallel / DistributedDataParallel
        # prefix every key with 'module.'; strip the prefix and retry.
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[len('module.'):] if k.startswith('module.') else k
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict, strict=strict)


def load_checkpoint_multigpu(model, weights):
    # Assumes the checkpoint was saved from a DataParallel-wrapped model,
    # so every key carries the 'module.' prefix.
    checkpoint = torch.load(weights)
    state_dict = checkpoint["state_dict"]
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[len('module.'):]  # strip the 'module.' prefix
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)


def load_start_epoch(weights):
    checkpoint = torch.load(weights, map_location=torch.device('cpu'))
    return checkpoint["epoch"]


def load_optim(optimizer, weights):
    checkpoint = torch.load(weights, map_location=torch.device('cpu'))
    optimizer.load_state_dict(checkpoint['optimizer'])
    # Return the learning rate of the last parameter group.
    for p in optimizer.param_groups:
        lr = p['lr']
    return lr


def get_arch(opt):
    from model import ShadowFormer, DenseSR
    arch = opt.arch
    print('You chose ' + arch + '...')
    if arch == 'ShadowFormer':
        model_restoration = ShadowFormer(img_size=opt.train_ps,
                                         embed_dim=opt.embed_dim,
                                         win_size=opt.win_size,
                                         token_projection=opt.token_projection,
                                         token_mlp=opt.token_mlp)
    elif arch == 'DenseSR':
        model_restoration = DenseSR(img_size=opt.train_ps,
                                    embed_dim=opt.embed_dim,
                                    win_size=opt.win_size,
                                    token_projection=opt.token_projection,
                                    token_mlp=opt.token_mlp)
    else:
        raise ValueError("Unknown architecture: {}".format(arch))
    return model_restoration


def window_partition(x, win_size):
    """Split a (B, C, H, W) tensor into non-overlapping windows.

    Returns a tensor of shape (B * H/win_size * W/win_size, C, win_size,
    win_size). H and W must be divisible by win_size.
    """
    B, C, H, W = x.shape
    x = x.permute(0, 2, 3, 1)  # B, H, W, C
    x = x.reshape(B, H // win_size, win_size, W // win_size, win_size, C)
    x = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, win_size, win_size, C)
    return x.permute(0, 3, 1, 2)


def distributed_concat(var, num_total):
    # Ensure var is a 1D tensor (shape: [1]).
    var = var.view(1) if var.dim() == 0 else var
    var_list = [torch.zeros_like(var).cuda()
                for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(var_list, var)
    # Truncate the dummy elements added by SequentialDistributedSampler.
    return var_list[:num_total]
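

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the training pipeline): a
# minimal round trip through freeze/save_checkpoint/load_checkpoint plus a
# shape check for window_partition. A toy nn.Linear stands in for
# ShadowFormer; `demo_model`, `demo_dir`, and the "demo" session name are
# hypothetical names introduced here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import tempfile

    import torch.nn as nn

    demo_model = nn.Linear(4, 4)
    freeze(demo_model)
    assert is_frozen(demo_model)
    unfreeze(demo_model)
    assert not is_frozen(demo_model)

    with tempfile.TemporaryDirectory() as demo_dir:
        # save_checkpoint names the file model_epoch_{epoch}_{session}.pth.
        state = {"epoch": 1, "state_dict": demo_model.state_dict()}
        save_checkpoint(demo_dir, state, session="demo")
        weights = os.path.join(demo_dir, "model_epoch_1_demo.pth")
        load_checkpoint(demo_model, weights)
        assert load_start_epoch(weights) == 1

    # window_partition: (B, C, H, W) -> (B * (H//w) * (W//w), C, w, w).
    x = torch.randn(2, 3, 8, 8)
    windows = window_partition(x, win_size=4)
    assert windows.shape == (2 * 2 * 2, 3, 4, 4)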