Upload 6 files
- utils/common.py +159 -0
- utils/cond_fn.py +98 -0
- utils/face_restoration_helper.py +517 -0
- utils/helpers.py +216 -0
- utils/inference.py +320 -0
- utils/sampler.py +341 -0

utils/common.py
ADDED
@@ -0,0 +1,159 @@
from typing import Mapping, Any, Tuple, Callable, List
import importlib
import os
from urllib.parse import urlparse

import torch
from torch import Tensor
from torch.nn import functional as F
import numpy as np

from torch.hub import download_url_to_file, get_dir


def get_obj_from_str(string: str, reload: bool=False) -> Any:
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def instantiate_from_config(config: Mapping[str, Any]) -> Any:
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def wavelet_blur(image: Tensor, radius: int):
    """
    Apply wavelet blur to the input tensor.
    """
    # input shape: (1, 3, H, W)
    # convolution kernel
    kernel_vals = [
        [0.0625, 0.125, 0.0625],
        [0.125, 0.25, 0.125],
        [0.0625, 0.125, 0.0625],
    ]
    kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
    # add channel dimensions to the kernel to make it a 4D tensor
    kernel = kernel[None, None]
    # repeat the kernel across all input channels
    kernel = kernel.repeat(3, 1, 1, 1)
    image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
    # apply convolution
    output = F.conv2d(image, kernel, groups=3, dilation=radius)
    return output


def wavelet_decomposition(image: Tensor, levels=5):
    """
    Apply wavelet decomposition to the input tensor.
    This function returns only the low-frequency and high-frequency components.
    """
    high_freq = torch.zeros_like(image)
    for i in range(levels):
        radius = 2 ** i
        low_freq = wavelet_blur(image, radius)
        high_freq += (image - low_freq)
        image = low_freq

    return high_freq, low_freq


def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor):
    """
    Apply wavelet decomposition, so that the content will have the same color as the style.
    """
    # calculate the wavelet decomposition of the content feature
    content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
    del content_low_freq
    # calculate the wavelet decomposition of the style feature
    style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
    del style_high_freq
    # reconstruct the content feature with the style's low frequency
    return content_high_freq + style_low_freq


# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/download_util.py/
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Load file from http url, downloading the model if necessary.

    Ref: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    Args:
        url (str): URL to be downloaded.
        model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
            Default: None.
        progress (bool): Whether to show the download progress. Default: True.
        file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.

    Returns:
        str: The path to the downloaded file.
    """
    if model_dir is None:  # use the pytorch hub_dir
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, 'checkpoints')

    os.makedirs(model_dir, exist_ok=True)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.abspath(os.path.join(model_dir, filename))
    if not os.path.exists(cached_file):
        print(f'Downloading: "{url}" to {cached_file}\n')
        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file


def sliding_windows(h: int, w: int, tile_size: int, tile_stride: int) -> List[Tuple[int, int, int, int]]:
    hi_list = list(range(0, h - tile_size + 1, tile_stride))
    if (h - tile_size) % tile_stride != 0:
        hi_list.append(h - tile_size)

    wi_list = list(range(0, w - tile_size + 1, tile_stride))
    if (w - tile_size) % tile_stride != 0:
        wi_list.append(w - tile_size)

    coords = []
    for hi in hi_list:
        for wi in wi_list:
            coords.append((hi, hi + tile_size, wi, wi + tile_size))
    return coords


# https://github.com/csslc/CCSR/blob/main/model/q_sampler.py#L503
def gaussian_weights(tile_width: int, tile_height: int) -> np.ndarray:
    """Generates a gaussian mask of weights for tile contributions"""
    latent_width = tile_width
    latent_height = tile_height
    var = 0.01
    midpoint = (latent_width - 1) / 2  # -1 because index goes from 0 to latent_width - 1
    x_probs = [
        np.exp(-(x - midpoint) * (x - midpoint) / (latent_width * latent_width) / (2 * var)) / np.sqrt(2 * np.pi * var)
        for x in range(latent_width)]
    midpoint = latent_height / 2
    y_probs = [
        np.exp(-(y - midpoint) * (y - midpoint) / (latent_height * latent_height) / (2 * var)) / np.sqrt(2 * np.pi * var)
        for y in range(latent_height)]
    weights = np.outer(y_probs, x_probs)
    return weights


COUNT_VRAM = bool(os.environ.get("COUNT_VRAM", False))

def count_vram_usage(func: Callable) -> Callable:
    if not COUNT_VRAM:
        return func

    def wrapper(*args, **kwargs):
        peak_before = torch.cuda.max_memory_allocated() / (1024 ** 3)
        ret = func(*args, **kwargs)
        torch.cuda.synchronize()
        peak_after = torch.cuda.max_memory_allocated() / (1024 ** 3)
        print(f"VRAM peak before {func.__name__}: {peak_before:.5f} GB, after: {peak_after:.5f} GB")
        return ret
    return wrapper
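A minimal usage sketch (editor illustration, not part of the diff) of the two main groups of helpers above: wavelet-based color transfer and Gaussian-weighted tiling. The tensor shapes and the `utils.common` import path are assumptions.

import torch
from utils.common import wavelet_reconstruction, sliding_windows, gaussian_weights

# Color transfer: keep the high-frequency detail of a restored image but take the
# low-frequency (color/illumination) component from the original low-quality input.
restored = torch.rand(1, 3, 512, 512)   # stand-in for a restoration model output
lq_input = torch.rand(1, 3, 512, 512)   # stand-in for the (resized) low-quality input
color_fixed = wavelet_reconstruction(restored, lq_input)   # shape (1, 3, 512, 512)

# Tiling: enumerate overlapping tiles over a 1024x1024 plane and build a Gaussian mask
# so that overlapping tile results can be blended without visible seams when averaged.
plane = torch.rand(1, 3, 1024, 1024)
coords = sliding_windows(h=1024, w=1024, tile_size=512, tile_stride=256)
weights = gaussian_weights(tile_width=512, tile_height=512)   # np.ndarray, shape (512, 512)
for hi, hi_end, wi, wi_end in coords:
    tile = plane[..., hi:hi_end, wi:wi_end]   # each tile is (1, 3, 512, 512); accumulate with `weights`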
    	
utils/cond_fn.py
ADDED
@@ -0,0 +1,98 @@
from typing import overload, Tuple
import torch
from torch.nn import functional as F


class Guidance:

    def __init__(self, scale: float, t_start: int, t_stop: int, space: str, repeat: int) -> None:
        """
        Initialize restoration guidance.

        Args:
            scale (float): Gradient scale (denoted as `s` in our paper). The larger the gradient scale,
                the closer the final result will be to the output of the first stage model.
            t_start (int), t_stop (int): The timesteps at which guidance starts and stops. Note that the
                sampling process runs from t=1000 to t=0, so `t_start` should be larger than `t_stop`.
            space (str): The data space for computing the loss function (rgb or latent).

        Our restoration guidance is based on [GDP](https://github.com/Fayeben/GenerativeDiffusionPrior).
        Thanks for their work!
        """
        self.scale = scale * 3000
        self.t_start = t_start
        self.t_stop = t_stop
        self.target = None
        self.space = space
        self.repeat = repeat

    def load_target(self, target: torch.Tensor) -> None:
        self.target = target

    def __call__(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]:
        # avoid propagating gradient out of this scope
        pred_x0 = pred_x0.detach().clone()
        target_x0 = target_x0.detach().clone()
        return self._forward(target_x0, pred_x0, t)

    @overload
    def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]:
        ...


class MSEGuidance(Guidance):

    def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]:
        # inputs: [-1, 1], nchw, rgb
        with torch.enable_grad():
            pred_x0.requires_grad_(True)
            loss = (pred_x0 - target_x0).pow(2).mean((1, 2, 3)).sum()
        scale = self.scale
        g = -torch.autograd.grad(loss, pred_x0)[0] * scale
        return g, loss.item()


class WeightedMSEGuidance(Guidance):

    def _get_weight(self, target: torch.Tensor) -> torch.Tensor:
        # convert RGB to grayscale
        rgb_to_gray_kernel = torch.tensor([0.2989, 0.5870, 0.1140]).view(1, 3, 1, 1)
        target = torch.sum(target * rgb_to_gray_kernel.to(target.device), dim=1, keepdim=True)
        # initialize Sobel kernels along the x and y axes
        G_x = [
            [1, 0, -1],
            [2, 0, -2],
            [1, 0, -1]
        ]
        G_y = [
            [1, 2, 1],
            [0, 0, 0],
            [-1, -2, -1]
        ]
        G_x = torch.tensor(G_x, dtype=target.dtype, device=target.device)[None]
        G_y = torch.tensor(G_y, dtype=target.dtype, device=target.device)[None]
        G = torch.stack((G_x, G_y))

        target = F.pad(target, (1, 1, 1, 1), mode='replicate')  # padding = 1
        grad = F.conv2d(target, G, stride=1)
        mag = grad.pow(2).sum(dim=1, keepdim=True).sqrt()

        n, c, h, w = mag.size()
        block_size = 2
        blocks = mag.view(n, c, h // block_size, block_size, w // block_size, block_size).permute(0, 1, 2, 4, 3, 5).contiguous()
        block_mean = blocks.sum(dim=(-2, -1), keepdim=True).tanh().repeat(1, 1, 1, 1, block_size, block_size).permute(0, 1, 2, 4, 3, 5).contiguous()
        block_mean = block_mean.view(n, c, h, w)
        weight_map = 1 - block_mean

        return weight_map

    def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]:
        # inputs: [-1, 1], nchw, rgb
        with torch.no_grad():
            w = self._get_weight((target_x0 + 1) / 2)
        with torch.enable_grad():
            pred_x0.requires_grad_(True)
            loss = ((pred_x0 - target_x0).pow(2) * w).mean((1, 2, 3)).sum()
        scale = self.scale
        g = -torch.autograd.grad(loss, pred_x0)[0] * scale
        return g, loss.item()
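A sketch of how a guidance object of this kind is typically consumed inside a sampling loop (editor illustration, not part of the diff); the constructor values and the per-step update below are assumptions, not the repository's sampler code.

import torch
from utils.cond_fn import MSEGuidance

# Assumed hyperparameters, for illustration only.
cond_fn = MSEGuidance(scale=0.5, t_start=601, t_stop=-1, space="rgb", repeat=1)
cond_fn.load_target(torch.rand(1, 3, 256, 256) * 2 - 1)   # reference image in [-1, 1]

t = 500                                        # current diffusion timestep
pred_x0 = torch.rand(1, 3, 256, 256) * 2 - 1   # current x0 prediction in [-1, 1]
if cond_fn.t_stop < t < cond_fn.t_start:       # guidance is only active inside this window
    for _ in range(cond_fn.repeat):
        grad, loss = cond_fn(cond_fn.target, pred_x0, t)
        pred_x0 = pred_x0 + grad               # nudge the prediction toward the target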
    	
utils/face_restoration_helper.py
ADDED
@@ -0,0 +1,517 @@
| 1 | 
         
            +
            import cv2
         
     | 
| 2 | 
         
            +
            import numpy as np
         
     | 
| 3 | 
         
            +
            import os
         
     | 
| 4 | 
         
            +
            import torch
         
     | 
| 5 | 
         
            +
            from torchvision.transforms.functional import normalize
         
     | 
| 6 | 
         
            +
             
     | 
| 7 | 
         
            +
            from facexlib.detection import init_detection_model
         
     | 
| 8 | 
         
            +
            from facexlib.parsing import init_parsing_model
         
     | 
| 9 | 
         
            +
            from facexlib.utils.misc import img2tensor, imwrite
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            from utils.common import load_file_from_url
         
     | 
| 12 | 
         
            +
             
     | 
| 13 | 
         
            +
            def get_largest_face(det_faces, h, w):
         
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
                def get_location(val, length):
         
     | 
| 16 | 
         
            +
                    if val < 0:
         
     | 
| 17 | 
         
            +
                        return 0
         
     | 
| 18 | 
         
            +
                    elif val > length:
         
     | 
| 19 | 
         
            +
                        return length
         
     | 
| 20 | 
         
            +
                    else:
         
     | 
| 21 | 
         
            +
                        return val
         
     | 
| 22 | 
         
            +
             
     | 
| 23 | 
         
            +
                face_areas = []
         
     | 
| 24 | 
         
            +
                for det_face in det_faces:
         
     | 
| 25 | 
         
            +
                    left = get_location(det_face[0], w)
         
     | 
| 26 | 
         
            +
                    right = get_location(det_face[2], w)
         
     | 
| 27 | 
         
            +
                    top = get_location(det_face[1], h)
         
     | 
| 28 | 
         
            +
                    bottom = get_location(det_face[3], h)
         
     | 
| 29 | 
         
            +
                    face_area = (right - left) * (bottom - top)
         
     | 
| 30 | 
         
            +
                    face_areas.append(face_area)
         
     | 
| 31 | 
         
            +
                largest_idx = face_areas.index(max(face_areas))
         
     | 
| 32 | 
         
            +
                return det_faces[largest_idx], largest_idx
         
     | 
| 33 | 
         
            +
             
     | 
| 34 | 
         
            +
             
     | 
| 35 | 
         
            +
            def get_center_face(det_faces, h=0, w=0, center=None):
         
     | 
| 36 | 
         
            +
                if center is not None:
         
     | 
| 37 | 
         
            +
                    center = np.array(center)
         
     | 
| 38 | 
         
            +
                else:
         
     | 
| 39 | 
         
            +
                    center = np.array([w / 2, h / 2])
         
     | 
| 40 | 
         
            +
                center_dist = []
         
     | 
| 41 | 
         
            +
                for det_face in det_faces:
         
     | 
| 42 | 
         
            +
                    face_center = np.array([(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2])
         
     | 
| 43 | 
         
            +
                    dist = np.linalg.norm(face_center - center)
         
     | 
| 44 | 
         
            +
                    center_dist.append(dist)
         
     | 
| 45 | 
         
            +
                center_idx = center_dist.index(min(center_dist))
         
     | 
| 46 | 
         
            +
                return det_faces[center_idx], center_idx
         
     | 
| 47 | 
         
            +
             
     | 
| 48 | 
         
            +
             
     | 
| 49 | 
         
            +
            class FaceRestoreHelper(object):
         
     | 
| 50 | 
         
            +
                """Helper for the face restoration pipeline (base class)."""
         
     | 
| 51 | 
         
            +
             
     | 
| 52 | 
         
            +
                def __init__(self,
         
     | 
| 53 | 
         
            +
                             upscale_factor,
         
     | 
| 54 | 
         
            +
                             face_size=512,
         
     | 
| 55 | 
         
            +
                             crop_ratio=(1, 1),
         
     | 
| 56 | 
         
            +
                             det_model='retinaface_resnet50',
         
     | 
| 57 | 
         
            +
                             save_ext='png',
         
     | 
| 58 | 
         
            +
                             template_3points=False,
         
     | 
| 59 | 
         
            +
                             pad_blur=False,
         
     | 
| 60 | 
         
            +
                             use_parse=False,
         
     | 
| 61 | 
         
            +
                             device=None):
         
     | 
| 62 | 
         
            +
                    self.template_3points = template_3points  # improve robustness
         
     | 
| 63 | 
         
            +
                    self.upscale_factor = int(upscale_factor)
         
     | 
| 64 | 
         
            +
                    # the cropped face ratio based on the square face
         
     | 
| 65 | 
         
            +
                    self.crop_ratio = crop_ratio  # (h, w)
         
     | 
| 66 | 
         
            +
                    assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ration only supports >=1'
         
     | 
| 67 | 
         
            +
                    self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
         
     | 
| 68 | 
         
            +
                    self.det_model = det_model
         
     | 
| 69 | 
         
            +
             
     | 
| 70 | 
         
            +
                    if self.det_model == 'dlib':
         
     | 
| 71 | 
         
            +
                        # standard 5 landmarks for FFHQ faces with 1024 x 1024
         
     | 
| 72 | 
         
            +
                        self.face_template = np.array([[686.77227723, 488.62376238], [586.77227723, 493.59405941],
         
     | 
| 73 | 
         
            +
                                                    [337.91089109, 488.38613861], [437.95049505, 493.51485149],
         
     | 
| 74 | 
         
            +
                                                    [513.58415842, 678.5049505]])
         
     | 
| 75 | 
         
            +
                        self.face_template = self.face_template / (1024 // face_size)
         
     | 
| 76 | 
         
            +
                    elif self.template_3points:
         
     | 
| 77 | 
         
            +
                        self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
         
     | 
| 78 | 
         
            +
                    else:
         
     | 
| 79 | 
         
            +
                        # standard 5 landmarks for FFHQ faces with 512 x 512 
         
     | 
| 80 | 
         
            +
                        # facexlib
         
     | 
| 81 | 
         
            +
                        self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
         
     | 
| 82 | 
         
            +
                                                       [201.26117, 371.41043], [313.08905, 371.15118]])
         
     | 
| 83 | 
         
            +
             
     | 
| 84 | 
         
            +
                        # dlib: left_eye: 36:41  right_eye: 42:47  nose: 30,32,33,34  left mouth corner: 48  right mouth corner: 54
         
     | 
| 85 | 
         
            +
                        # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
         
     | 
| 86 | 
         
            +
                        #                                 [198.22603, 372.82502], [313.91018, 372.75659]])
         
     | 
| 87 | 
         
            +
             
     | 
| 88 | 
         
            +
                    self.face_template = self.face_template * (face_size / 512.0)
         
     | 
| 89 | 
         
            +
                    if self.crop_ratio[0] > 1:
         
     | 
| 90 | 
         
            +
                        self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
         
     | 
| 91 | 
         
            +
                    if self.crop_ratio[1] > 1:
         
     | 
| 92 | 
         
            +
                        self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
         
     | 
| 93 | 
         
            +
                    self.save_ext = save_ext
         
     | 
| 94 | 
         
            +
                    self.pad_blur = pad_blur
         
     | 
| 95 | 
         
            +
                    if self.pad_blur is True:
         
     | 
| 96 | 
         
            +
                        self.template_3points = False
         
     | 
| 97 | 
         
            +
             
     | 
| 98 | 
         
            +
                    self.all_landmarks_5 = []
         
     | 
| 99 | 
         
            +
                    self.det_faces = []
         
     | 
| 100 | 
         
            +
                    self.affine_matrices = []
         
     | 
| 101 | 
         
            +
                    self.inverse_affine_matrices = []
         
     | 
| 102 | 
         
            +
                    self.cropped_faces = []
         
     | 
| 103 | 
         
            +
                    self.restored_faces = []
         
     | 
| 104 | 
         
            +
                    self.pad_input_imgs = []
         
     | 
| 105 | 
         
            +
             
     | 
| 106 | 
         
            +
                    if device is None:
         
     | 
| 107 | 
         
            +
                        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         
     | 
| 108 | 
         
            +
                        # self.device = get_device()
         
     | 
| 109 | 
         
            +
                    else:
         
     | 
| 110 | 
         
            +
                        self.device = device
         
     | 
| 111 | 
         
            +
             
     | 
| 112 | 
         
            +
                    # init face detection model
         
     | 
| 113 | 
         
            +
                    self.face_detector = init_detection_model(det_model, half=False, device=self.device)
         
     | 
| 114 | 
         
            +
             
     | 
| 115 | 
         
            +
                    # init face parsing model
         
     | 
| 116 | 
         
            +
                    self.use_parse = use_parse
         
     | 
| 117 | 
         
            +
                    self.face_parse = init_parsing_model(model_name='parsenet', device=self.device)
         
     | 
| 118 | 
         
            +
             
     | 
| 119 | 
         
            +
                def set_upscale_factor(self, upscale_factor):
         
     | 
| 120 | 
         
            +
                    self.upscale_factor = upscale_factor
         
     | 
| 121 | 
         
            +
             
     | 
| 122 | 
         
            +
                def read_image(self, img):
         
     | 
| 123 | 
         
            +
                    """img can be image path or cv2 loaded image."""
         
     | 
| 124 | 
         
            +
                    # self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255]
         
     | 
| 125 | 
         
            +
                    if isinstance(img, str):
         
     | 
| 126 | 
         
            +
                        img = cv2.imread(img)
         
     | 
| 127 | 
         
            +
             
     | 
| 128 | 
         
            +
                    if np.max(img) > 256:  # 16-bit image
         
     | 
| 129 | 
         
            +
                        img = img / 65535 * 255
         
     | 
| 130 | 
         
            +
                    if len(img.shape) == 2:  # gray image
         
     | 
| 131 | 
         
            +
                        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
         
     | 
| 132 | 
         
            +
                    elif img.shape[2] == 4:  # BGRA image with alpha channel
         
     | 
| 133 | 
         
            +
                        img = img[:, :, 0:3]
         
     | 
| 134 | 
         
            +
             
     | 
| 135 | 
         
            +
                    self.input_img = img
         
     | 
| 136 | 
         
            +
                    # self.is_gray = is_gray(img, threshold=10)
         
     | 
| 137 | 
         
            +
                    # if self.is_gray:
         
     | 
| 138 | 
         
            +
                    #     print('Grayscale input: True')
         
     | 
| 139 | 
         
            +
             
     | 
| 140 | 
         
            +
                    if min(self.input_img.shape[:2])<512:
         
     | 
| 141 | 
         
            +
                        f = 512.0/min(self.input_img.shape[:2])
         
     | 
| 142 | 
         
            +
                        self.input_img = cv2.resize(self.input_img, (0,0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR)
         
     | 
| 143 | 
         
            +
             
     | 
| 144 | 
         
            +
                def init_dlib(self, detection_path, landmark5_path):
         
     | 
| 145 | 
         
            +
                    """Initialize the dlib detectors and predictors."""
         
     | 
| 146 | 
         
            +
                    try:
         
     | 
| 147 | 
         
            +
                        import dlib
         
     | 
| 148 | 
         
            +
                    except ImportError:
         
     | 
| 149 | 
         
            +
                        print('Please install dlib by running:' 'conda install -c conda-forge dlib')
         
     | 
| 150 | 
         
            +
                    detection_path = load_file_from_url(url=detection_path, model_dir='weights/dlib', progress=True, file_name=None)
         
     | 
| 151 | 
         
            +
                    landmark5_path = load_file_from_url(url=landmark5_path, model_dir='weights/dlib', progress=True, file_name=None)
         
     | 
| 152 | 
         
            +
                    face_detector = dlib.cnn_face_detection_model_v1(detection_path)
         
     | 
| 153 | 
         
            +
                    shape_predictor_5 = dlib.shape_predictor(landmark5_path)
         
     | 
| 154 | 
         
            +
                    return face_detector, shape_predictor_5
         
     | 
| 155 | 
         
            +
             
     | 
| 156 | 
         
            +
    def get_face_landmarks_5_dlib(self,
                                  only_keep_largest=False,
                                  scale=1):
        det_faces = self.face_detector(self.input_img, scale)

        if len(det_faces) == 0:
            print('No face detected. Try to increase upsample_num_times.')
            return 0
        else:
            if only_keep_largest:
                print('Detected several faces; keeping only the largest one.')
                face_areas = []
                for i in range(len(det_faces)):
                    face_area = (det_faces[i].rect.right() - det_faces[i].rect.left()) * (
                        det_faces[i].rect.bottom() - det_faces[i].rect.top())
                    face_areas.append(face_area)
                largest_idx = face_areas.index(max(face_areas))
                self.det_faces = [det_faces[largest_idx]]
            else:
                self.det_faces = det_faces

        if len(self.det_faces) == 0:
            return 0

        for face in self.det_faces:
            shape = self.shape_predictor_5(self.input_img, face.rect)
            landmark = np.array([[part.x, part.y] for part in shape.parts()])
            self.all_landmarks_5.append(landmark)

        return len(self.all_landmarks_5)

    def get_face_landmarks_5(self,
                             only_keep_largest=False,
                             only_center_face=False,
                             resize=None,
                             blur_ratio=0.01,
                             eye_dist_threshold=None):
        if self.det_model == 'dlib':
            return self.get_face_landmarks_5_dlib(only_keep_largest)

        if resize is None:
            scale = 1
            input_img = self.input_img
        else:
            h, w = self.input_img.shape[0:2]
            scale = resize / min(h, w)
            scale = max(1, scale)  # always scale up
            h, w = int(h * scale), int(w * scale)
            interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
            input_img = cv2.resize(self.input_img, (w, h), interpolation=interp)

        with torch.no_grad():
            bboxes = self.face_detector.detect_faces(input_img)

        if bboxes is None or bboxes.shape[0] == 0:
            return 0
        else:
            bboxes = bboxes / scale

        for bbox in bboxes:
            # remove faces with too small eye distance: side faces or too small faces
            eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]])
            if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
                continue

            if self.template_3points:
                landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)])
            else:
                landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)])
            self.all_landmarks_5.append(landmark)
            self.det_faces.append(bbox[0:5])

        if len(self.det_faces) == 0:
            return 0
        if only_keep_largest:
            h, w, _ = self.input_img.shape
            self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
            self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
        elif only_center_face:
            h, w, _ = self.input_img.shape
            self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
            self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]

        # pad blurry images
        if self.pad_blur:
            self.pad_input_imgs = []
            for landmarks in self.all_landmarks_5:
                # get landmarks
                eye_left = landmarks[0, :]
                eye_right = landmarks[1, :]
                eye_avg = (eye_left + eye_right) * 0.5
                mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
                eye_to_eye = eye_right - eye_left
                eye_to_mouth = mouth_avg - eye_avg

                # Get the oriented crop rectangle
                # x: half width of the oriented crop rectangle
                x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
                #  - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
                # norm with the hypotenuse: get the direction
                x /= np.hypot(*x)  # get the hypotenuse of a right triangle
                rect_scale = 1.5
                x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
                # y: half height of the oriented crop rectangle
                y = np.flipud(x) * [-1, 1]

                # c: center
                c = eye_avg + eye_to_mouth * 0.1
                # quad: (left_top, left_bottom, right_bottom, right_top)
                quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
                # qsize: side length of the square
                qsize = np.hypot(*x) * 2
                border = max(int(np.rint(qsize * 0.1)), 3)

                # get pad
                # pad: (width_left, height_top, width_right, height_bottom)
                pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
                       int(np.ceil(max(quad[:, 1]))))
                pad = [
                    max(-pad[0] + border, 1),
                    max(-pad[1] + border, 1),
                    max(pad[2] - self.input_img.shape[0] + border, 1),
                    max(pad[3] - self.input_img.shape[1] + border, 1)
                ]

                if max(pad) > 1:
                    # pad image
                    pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
                    # modify landmark coords
                    landmarks[:, 0] += pad[0]
                    landmarks[:, 1] += pad[1]
                    # blur pad images
                    h, w, _ = pad_img.shape
                    y, x, _ = np.ogrid[:h, :w, :1]
                    mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
                                                       np.float32(w - 1 - x) / pad[2]),
                                      1.0 - np.minimum(np.float32(y) / pad[1],
                                                       np.float32(h - 1 - y) / pad[3]))
                    blur = int(qsize * blur_ratio)
                    if blur % 2 == 0:
                        blur += 1
                    blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
                    # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)

                    pad_img = pad_img.astype('float32')
                    pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
                    pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0)
                    pad_img = np.clip(pad_img, 0, 255)  # float32, [0, 255]
                    self.pad_input_imgs.append(pad_img)
                else:
                    self.pad_input_imgs.append(np.copy(self.input_img))

        return len(self.all_landmarks_5)

    def align_warp_face(self, save_cropped_path=None, border_mode='constant'):
        """Align and warp faces with face template."""
        if self.pad_blur:
            assert len(self.pad_input_imgs) == len(
                self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}'
        for idx, landmark in enumerate(self.all_landmarks_5):
            # use 5 landmarks to get affine matrix
            # use cv2.LMEDS method for the equivalence to skimage transform
            # ref: https://blog.csdn.net/yichxi/article/details/115827338
            affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0]
            self.affine_matrices.append(affine_matrix)
            # warp and crop faces
            if border_mode == 'constant':
                border_mode = cv2.BORDER_CONSTANT
            elif border_mode == 'reflect101':
                border_mode = cv2.BORDER_REFLECT101
            elif border_mode == 'reflect':
                border_mode = cv2.BORDER_REFLECT
            if self.pad_blur:
                input_img = self.pad_input_imgs[idx]
            else:
                input_img = self.input_img
            cropped_face = cv2.warpAffine(
                input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132))  # gray
            self.cropped_faces.append(cropped_face)
            # save the cropped face
            if save_cropped_path is not None:
                path = os.path.splitext(save_cropped_path)[0]
                save_path = f'{path}_{idx:02d}.{self.save_ext}'
                imwrite(cropped_face, save_path)

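The core of the alignment step above is cv2.estimateAffinePartial2D with LMEDS followed by cv2.warpAffine. A small self-contained sketch with made-up landmark coordinates (the template below is an arbitrary example, not the module's face_template):

import cv2
import numpy as np

# five source landmarks (eyes, nose, mouth corners) detected in some image -- made-up values
landmark = np.array([[180., 210.], [260., 205.], [222., 260.], [190., 300.], [255., 295.]], dtype=np.float32)
# five target positions in a 512x512 crop -- arbitrary example template
template = np.array([[192., 240.], [320., 240.], [256., 314.], [201., 371.], [313., 371.]], dtype=np.float32)

# estimateAffinePartial2D returns (matrix, inlier_mask); only the 2x3 matrix is needed
affine_matrix = cv2.estimateAffinePartial2D(landmark, template, method=cv2.LMEDS)[0]
img = np.zeros((400, 400, 3), dtype=np.uint8)  # stand-in for the input image
cropped_face = cv2.warpAffine(img, affine_matrix, (512, 512), borderValue=(135, 133, 132))
print(affine_matrix.shape, cropped_face.shape)  # (2, 3) (512, 512, 3)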
    def get_inverse_affine(self, save_inverse_affine_path=None):
        """Get inverse affine matrix."""
        for idx, affine_matrix in enumerate(self.affine_matrices):
            inverse_affine = cv2.invertAffineTransform(affine_matrix)
            inverse_affine *= self.upscale_factor
            self.inverse_affine_matrices.append(inverse_affine)
            # save inverse affine matrices
            if save_inverse_affine_path is not None:
                path, _ = os.path.splitext(save_inverse_affine_path)
                save_path = f'{path}_{idx:02d}.pth'
                torch.save(inverse_affine, save_path)

    def add_restored_face(self, restored_face, input_face=None):
        # if self.is_gray:
        #     restored_face = bgr2gray(restored_face)  # convert img into grayscale
        #     if input_face is not None:
        #         restored_face = adain_npy(restored_face, input_face)  # transfer the color
        self.restored_faces.append(restored_face)

    def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None):
        h, w, _ = self.input_img.shape
        h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)

        if upsample_img is None:
            # simply resize the background
            # upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
            upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR)
        else:
            upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)

        assert len(self.restored_faces) == len(
            self.inverse_affine_matrices), 'length of restored_faces and affine_matrices are different.'

        inv_mask_borders = []
        for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):
            if face_upsampler is not None:
                restored_face = face_upsampler.enhance(restored_face, outscale=self.upscale_factor)[0]
                inverse_affine /= self.upscale_factor
                inverse_affine[:, 2] *= self.upscale_factor
                face_size = (self.face_size[0] * self.upscale_factor, self.face_size[1] * self.upscale_factor)
            else:
                # Add an offset to inverse affine matrix, for more precise back alignment
                if self.upscale_factor > 1:
                    extra_offset = 0.5 * self.upscale_factor
                else:
                    extra_offset = 0
                inverse_affine[:, 2] += extra_offset
                face_size = self.face_size
            inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))

            # if draw_box or not self.use_parse:  # use square parse maps
            #     mask = np.ones(face_size, dtype=np.float32)
            #     inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
            #     # remove the black borders
            #     inv_mask_erosion = cv2.erode(
            #         inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
            #     pasted_face = inv_mask_erosion[:, :, None] * inv_restored
            #     total_face_area = np.sum(inv_mask_erosion)  # // 3
            #     # add border
            #     if draw_box:
            #         h, w = face_size
            #         mask_border = np.ones((h, w, 3), dtype=np.float32)
            #         border = int(1400 / np.sqrt(total_face_area))
            #         mask_border[border:h - border, border:w - border, :] = 0
            #         inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
            #         inv_mask_borders.append(inv_mask_border)
            #     if not self.use_parse:
            #         # compute the fusion edge based on the area of face
            #         w_edge = int(total_face_area**0.5) // 20
            #         erosion_radius = w_edge * 2
            #         inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
            #         blur_size = w_edge * 2
            #         inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
            #         if len(upsample_img.shape) == 2:  # upsample_img is gray image
            #             upsample_img = upsample_img[:, :, None]
            #         inv_soft_mask = inv_soft_mask[:, :, None]

            # always use square mask
            mask = np.ones(face_size, dtype=np.float32)
            inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
            # remove the black borders
            inv_mask_erosion = cv2.erode(
                inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
            pasted_face = inv_mask_erosion[:, :, None] * inv_restored
            total_face_area = np.sum(inv_mask_erosion)  # // 3
            # add border
            if draw_box:
                h, w = face_size
                mask_border = np.ones((h, w, 3), dtype=np.float32)
                border = int(1400 / np.sqrt(total_face_area))
                mask_border[border:h - border, border:w - border, :] = 0
                inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
                inv_mask_borders.append(inv_mask_border)
            # compute the fusion edge based on the area of face
            w_edge = int(total_face_area**0.5) // 20
            erosion_radius = w_edge * 2
            inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
            blur_size = w_edge * 2
            inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
            if len(upsample_img.shape) == 2:  # upsample_img is gray image
                upsample_img = upsample_img[:, :, None]
            inv_soft_mask = inv_soft_mask[:, :, None]

            # parse mask
            if self.use_parse:
                # inference
                face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
                face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True)
                normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
                face_input = torch.unsqueeze(face_input, 0).to(self.device)
                with torch.no_grad():
                    out = self.face_parse(face_input)[0]
                out = out.argmax(dim=1).squeeze().cpu().numpy()

                parse_mask = np.zeros(out.shape)
                MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
                for idx, color in enumerate(MASK_COLORMAP):
                    parse_mask[out == idx] = color
                # blur the mask
                parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
                parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
                # remove the black borders
                thres = 10
                parse_mask[:thres, :] = 0
                parse_mask[-thres:, :] = 0
                parse_mask[:, :thres] = 0
                parse_mask[:, -thres:] = 0
                parse_mask = parse_mask / 255.

                parse_mask = cv2.resize(parse_mask, face_size)
                parse_mask = cv2.warpAffine(parse_mask, inverse_affine, (w_up, h_up), flags=3)
                inv_soft_parse_mask = parse_mask[:, :, None]
                # pasted_face = inv_restored
                fuse_mask = (inv_soft_parse_mask < inv_soft_mask).astype('int')
                inv_soft_mask = inv_soft_parse_mask * fuse_mask + inv_soft_mask * (1 - fuse_mask)

            if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4:  # alpha channel
                alpha = upsample_img[:, :, 3:]
                upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
                upsample_img = np.concatenate((upsample_img, alpha), axis=2)
            else:
                upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img

        if np.max(upsample_img) > 256:  # 16-bit image
            upsample_img = upsample_img.astype(np.uint16)
        else:
            upsample_img = upsample_img.astype(np.uint8)

        # draw bounding box
        if draw_box:
            # upsample_input_img = cv2.resize(input_img, (w_up, h_up))
            img_color = np.ones([*upsample_img.shape], dtype=np.float32)
            img_color[:, :, 0] = 0
            img_color[:, :, 1] = 255
            img_color[:, :, 2] = 0
            for inv_mask_border in inv_mask_borders:
                upsample_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_img
                # upsample_input_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_input_img

        if save_path is not None:
            path = os.path.splitext(save_path)[0]
            save_path = f'{path}.{self.save_ext}'
            imwrite(upsample_img, save_path)
        return upsample_img

    def clean_all(self):
        self.all_landmarks_5 = []
        self.restored_faces = []
        self.affine_matrices = []
        self.cropped_faces = []
        self.inverse_affine_matrices = []
        self.det_faces = []
        self.pad_input_imgs = []
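The methods above are meant to be driven in a fixed order. A minimal driver sketch (not part of the uploaded file): helper stands for an instance of the face-restoration helper class this file defines, read_image is assumed from the earlier, unshown part of the file, and restore_512 is any user-supplied restorer for aligned 512x512 crops.

import numpy as np

def restore_faces_in_image(helper, lq_bgr: np.ndarray, restore_512) -> np.ndarray:
    """Run detect -> align -> restore -> paste-back for one BGR uint8 image."""
    helper.clean_all()
    helper.read_image(lq_bgr)  # read_image is assumed from the earlier part of this file
    helper.get_face_landmarks_5(only_center_face=False, eye_dist_threshold=5)
    helper.align_warp_face()  # fills helper.cropped_faces and helper.affine_matrices
    for cropped_face in helper.cropped_faces:
        helper.add_restored_face(restore_512(cropped_face), cropped_face)
    helper.get_inverse_affine()
    return helper.paste_faces_to_input_image(upsample_img=None, draw_box=False)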
    	
utils/helpers.py
ADDED
@@ -0,0 +1,216 @@
from typing import overload, Tuple, Optional

import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from PIL import Image
from einops import rearrange

from model.cldm import ControlLDM
from model.gaussian_diffusion import Diffusion
from model.bsrnet import RRDBNet
from model.swinir import SwinIR
from model.scunet import SCUNet
from utils.sampler import SpacedSampler
from utils.cond_fn import Guidance
from utils.common import wavelet_decomposition, wavelet_reconstruction, count_vram_usage

def bicubic_resize(img: np.ndarray, scale: float) -> np.ndarray:
    pil = Image.fromarray(img)
    res = pil.resize(tuple(int(x * scale) for x in pil.size), Image.BICUBIC)
    return np.array(res)

def resize_short_edge_to(imgs: torch.Tensor, size: int) -> torch.Tensor:
    _, _, h, w = imgs.size()
    if h == w:
        new_h, new_w = size, size
    elif h < w:
        new_h, new_w = size, int(w * (size / h))
    else:
        new_h, new_w = int(h * (size / w)), size
    return F.interpolate(imgs, size=(new_h, new_w), mode="bicubic", antialias=True)

def pad_to_multiples_of(imgs: torch.Tensor, multiple: int) -> torch.Tensor:
    _, _, h, w = imgs.size()
    if h % multiple == 0 and w % multiple == 0:
        return imgs.clone()
    # get_pad = lambda x: (x // multiple + 1) * multiple - x
    get_pad = lambda x: (x // multiple + int(x % multiple != 0)) * multiple - x
    ph, pw = get_pad(h), get_pad(w)
    return F.pad(imgs, pad=(0, pw, 0, ph), mode="constant", value=0)

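A quick sanity check for the two tensor helpers above; a standalone sketch with arbitrary example shapes, assuming both functions are in scope (e.g. imported from this module):

import torch

if __name__ == "__main__":
    x = torch.rand(1, 3, 300, 500)                 # N, C, H, W
    padded = pad_to_multiples_of(x, multiple=64)   # zero-pads on the bottom/right only
    print(padded.shape)                            # torch.Size([1, 3, 320, 512])
    resized = resize_short_edge_to(x, size=512)    # bicubic, keeps the aspect ratio
    print(resized.shape)                           # torch.Size([1, 3, 512, 853])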
class Pipeline:

    def __init__(self, stage1_model: nn.Module, cldm: ControlLDM, diffusion: Diffusion, cond_fn: Optional[Guidance], device: str) -> None:
        self.stage1_model = stage1_model
        self.cldm = cldm
        self.diffusion = diffusion
        self.cond_fn = cond_fn
        self.device = device
        self.final_size: Tuple[int] = None

    def set_final_size(self, lq: torch.Tensor) -> None:
        h, w = lq.shape[2:]
        self.final_size = (h, w)

    @overload
    def run_stage1(self, lq: torch.Tensor) -> torch.Tensor:
        ...

    @count_vram_usage
    def run_stage2(
        self,
        clean: torch.Tensor,
        steps: int,
        strength: float,
        tiled: bool,
        tile_size: int,
        tile_stride: int,
        pos_prompt: str,
        neg_prompt: str,
        cfg_scale: float,
        better_start: bool
    ) -> torch.Tensor:
        ### preprocess
        bs, _, ori_h, ori_w = clean.shape
        # pad: ensure that height & width are multiples of 64
        pad_clean = pad_to_multiples_of(clean, multiple=64)
        h, w = pad_clean.shape[2:]
        # prepare condition
        if not tiled:
            cond = self.cldm.prepare_condition(pad_clean, [pos_prompt] * bs)
            uncond = self.cldm.prepare_condition(pad_clean, [neg_prompt] * bs)
        else:
            cond = self.cldm.prepare_condition_tiled(pad_clean, [pos_prompt] * bs, tile_size, tile_stride)
            uncond = self.cldm.prepare_condition_tiled(pad_clean, [neg_prompt] * bs, tile_size, tile_stride)
        if self.cond_fn:
            self.cond_fn.load_target(pad_clean * 2 - 1)
        old_control_scales = self.cldm.control_scales
        self.cldm.control_scales = [strength] * 13
        if better_start:
            # use the noised low-frequency part of the condition as a better starting point for
            # reverse sampling, which prevents the model from generating noise in the image background
            _, low_freq = wavelet_decomposition(pad_clean)
            if not tiled:
                x_0 = self.cldm.vae_encode(low_freq)
            else:
                x_0 = self.cldm.vae_encode_tiled(low_freq, tile_size, tile_stride)
            x_T = self.diffusion.q_sample(
                x_0,
                torch.full((bs, ), self.diffusion.num_timesteps - 1, dtype=torch.long, device=self.device),
                torch.randn(x_0.shape, dtype=torch.float32, device=self.device)
            )
            # print(f"diffusion sqrt_alphas_cumprod: {self.diffusion.sqrt_alphas_cumprod[-1]}")
        else:
            x_T = torch.randn((bs, 4, h // 8, w // 8), dtype=torch.float32, device=self.device)
        ### run sampler
        sampler = SpacedSampler(self.diffusion.betas)
        z = sampler.sample(
            model=self.cldm, device=self.device, steps=steps, batch_size=bs, x_size=(4, h // 8, w // 8),
            cond=cond, uncond=uncond, cfg_scale=cfg_scale, x_T=x_T, progress=True,
            progress_leave=True, cond_fn=self.cond_fn, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
        )
        if not tiled:
            x = self.cldm.vae_decode(z)
        else:
            x = self.cldm.vae_decode_tiled(z, tile_size // 8, tile_stride // 8)
        ### postprocess
        self.cldm.control_scales = old_control_scales
        sample = x[:, :, :ori_h, :ori_w]
        return sample

                @torch.no_grad()
         
     | 
| 129 | 
         
            +
                def run(
         
     | 
| 130 | 
         
            +
                    self,
         
     | 
| 131 | 
         
            +
                    lq: np.ndarray,
         
     | 
| 132 | 
         
            +
                    steps: int,
         
     | 
| 133 | 
         
            +
                    strength: float,
         
     | 
| 134 | 
         
            +
                    tiled: bool,
         
     | 
| 135 | 
         
            +
                    tile_size: int,
         
     | 
| 136 | 
         
            +
                    tile_stride: int,
         
     | 
| 137 | 
         
            +
                    pos_prompt: str,
         
     | 
| 138 | 
         
            +
                    neg_prompt: str,
         
     | 
| 139 | 
         
            +
                    cfg_scale: float,
         
     | 
| 140 | 
         
            +
                    better_start: bool
         
     | 
| 141 | 
         
            +
                ) -> np.ndarray:
         
     | 
| 142 | 
         
            +
                    # image to tensor
         
     | 
| 143 | 
         
            +
                    lq = torch.tensor((lq / 255.).clip(0, 1), dtype=torch.float32, device=self.device)
         
     | 
| 144 | 
         
            +
                    lq = rearrange(lq, "n h w c -> n c h w").contiguous()
         
     | 
| 145 | 
         
            +
                    # set pipeline output size
         
     | 
| 146 | 
         
            +
                    self.set_final_size(lq)
         
     | 
| 147 | 
         
            +
                    clean = self.run_stage1(lq)
         
     | 
| 148 | 
         
            +
                    sample = self.run_stage2(
         
     | 
| 149 | 
         
            +
                        clean, steps, strength, tiled, tile_size, tile_stride,
         
     | 
| 150 | 
         
            +
                        pos_prompt, neg_prompt, cfg_scale, better_start
         
     | 
| 151 | 
         
            +
                    )
         
     | 
| 152 | 
         
            +
                    # colorfix (borrowed from StableSR, thanks for their work)
         
     | 
| 153 | 
         
            +
                    sample = (sample + 1) / 2
         
     | 
| 154 | 
         
            +
                    sample = wavelet_reconstruction(sample, clean)
         
     | 
| 155 | 
         
            +
                    # resize to desired output size
         
     | 
| 156 | 
         
            +
                    sample = F.interpolate(sample, size=self.final_size, mode="bicubic", antialias=True)
         
     | 
| 157 | 
         
            +
                    # tensor to image
         
     | 
| 158 | 
         
            +
                    sample = rearrange(sample * 255., "n c h w -> n h w c")
         
     | 
| 159 | 
         
            +
                    sample = sample.contiguous().clamp(0, 255).to(torch.uint8).cpu().numpy()
         
     | 
| 160 | 
         
            +
                    return sample
         
     | 
| 161 | 
         
            +
             
     | 
| 162 | 
         
            +
             
     | 
| 163 | 
         
            +
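The pre- and post-processing in run above is plain array bookkeeping; here is a standalone sketch with a random array standing in for a real image (the (bs, 4, h // 8, w // 8) latent shape follows the usual Stable Diffusion VAE geometry of 4 channels at 8x downsampling; the 512x512 size is purely illustrative):

import numpy as np
from einops import rearrange

# "n h w c" uint8 batch, as Pipeline.run expects
lq = np.random.randint(0, 256, (1, 512, 512, 3), dtype=np.uint8)
x = rearrange((lq / 255.).clip(0, 1), "n h w c -> n c h w")         # model-side layout, values in [0, 1]
latent_shape = (x.shape[0], 4, x.shape[2] // 8, x.shape[3] // 8)    # -> (1, 4, 64, 64), like x_T above
decoded = np.random.uniform(-1, 1, x.shape)                         # stand-in for the VAE decoder output
out = rearrange((decoded + 1) / 2 * 255., "n c h w -> n h w c").clip(0, 255).astype(np.uint8)
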
class BSRNetPipeline(Pipeline):

    def __init__(self, bsrnet: RRDBNet, cldm: ControlLDM, diffusion: Diffusion, cond_fn: Optional[Guidance], device: str, upscale: float) -> None:
        super().__init__(bsrnet, cldm, diffusion, cond_fn, device)
        self.upscale = upscale

    def set_final_size(self, lq: torch.Tensor) -> None:
        h, w = lq.shape[2:]
        self.final_size = (int(h * self.upscale), int(w * self.upscale))

    @count_vram_usage
    def run_stage1(self, lq: torch.Tensor) -> torch.Tensor:
        # NOTE: upscale is always set to 4 in our experiments
        clean = self.stage1_model(lq)
        # if self.final_size[0] < 512 and self.final_size[1] < 512:
        if min(self.final_size) < 512:
            clean = resize_short_edge_to(clean, size=512)
        else:
            clean = F.interpolate(clean, size=self.final_size, mode="bicubic", antialias=True)
        return clean

class SwinIRPipeline(Pipeline):

    def __init__(self, swinir: SwinIR, cldm: ControlLDM, diffusion: Diffusion, cond_fn: Optional[Guidance], device: str) -> None:
        super().__init__(swinir, cldm, diffusion, cond_fn, device)

    @count_vram_usage
    def run_stage1(self, lq: torch.Tensor) -> torch.Tensor:
        # NOTE: lq size is always equal to 512 in our experiments
        # resize: ensure the input lq size is at least 512, since SwinIR is trained on 512 resolution
        if min(lq.shape[2:]) < 512:
            lq = resize_short_edge_to(lq, size=512)
        ori_h, ori_w = lq.shape[2:]
        # pad: ensure that height & width are multiples of 64
        pad_lq = pad_to_multiples_of(lq, multiple=64)
        # run
        clean = self.stage1_model(pad_lq)
        # remove padding
        clean = clean[:, :, :ori_h, :ori_w]
        return clean

class SCUNetPipeline(Pipeline):

    def __init__(self, scunet: SCUNet, cldm: ControlLDM, diffusion: Diffusion, cond_fn: Optional[Guidance], device: str) -> None:
        super().__init__(scunet, cldm, diffusion, cond_fn, device)

    @count_vram_usage
    def run_stage1(self, lq: torch.Tensor) -> torch.Tensor:
        clean = self.stage1_model(lq)
        if min(clean.shape[2:]) < 512:
            clean = resize_short_edge_to(clean, size=512)
        return clean
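As a usage sketch (not part of the upload): Pipeline.run expects a uint8 image batch in n h w c layout and returns restored uint8 images. The helper below is hypothetical and all argument values are placeholders; the real construction of the models and pipelines is shown in utils/inference.py further down.

import numpy as np
from PIL import Image

def restore_with_pipeline(pipeline, image_path: str) -> np.ndarray:
    # pipeline is an already-constructed SwinIRPipeline / BSRNetPipeline / SCUNetPipeline
    lq = np.array(Image.open(image_path).convert("RGB"))            # (h, w, 3) uint8
    return pipeline.run(
        lq[None], steps=50, strength=1.0, tiled=False, tile_size=512, tile_stride=256,
        pos_prompt="", neg_prompt="low quality, blurry", cfg_scale=4.0, better_start=True,
    )[0]                                                            # (h', w', 3) uint8
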
    	
utils/inference.py
ADDED
@@ -0,0 +1,320 @@
import os
from typing import overload, Generator, Dict
from argparse import Namespace

import numpy as np
import torch
from PIL import Image
from omegaconf import OmegaConf

from model.cldm import ControlLDM
from model.gaussian_diffusion import Diffusion
from model.bsrnet import RRDBNet
from model.scunet import SCUNet
from model.swinir import SwinIR
from utils.common import instantiate_from_config, load_file_from_url, count_vram_usage
from utils.face_restoration_helper import FaceRestoreHelper
from utils.helpers import (
    Pipeline,
    BSRNetPipeline, SwinIRPipeline, SCUNetPipeline,
    bicubic_resize
)
from utils.cond_fn import MSEGuidance, WeightedMSEGuidance


MODELS = {
    ### stage_1 model weights
    "bsrnet": "https://github.com/cszn/KAIR/releases/download/v1.0/BSRNet.pth",
    # the following checkpoint is up-to-date, but we use the old version in our paper
    # "swinir_face": "https://github.com/zsyOAOA/DifFace/releases/download/V1.0/General_Face_ffhq512.pth",
    "swinir_face": "https://huggingface.co/lxq007/DiffBIR/resolve/main/face_swinir_v1.ckpt",
    "scunet_psnr": "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth",
    "swinir_general": "https://huggingface.co/lxq007/DiffBIR/resolve/main/general_swinir_v1.ckpt",
    ### stage_2 model weights
    "sd_v21": "https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.ckpt",
    "v1_face": "https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/v1_face.pth",
    "v1_general": "https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/v1_general.pth",
    "v2": "https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/v2.pth"
}


def load_model_from_url(url: str) -> Dict[str, torch.Tensor]:
    sd_path = load_file_from_url(url, model_dir="weights")
    sd = torch.load(sd_path, map_location="cpu")
    if "state_dict" in sd:
        sd = sd["state_dict"]
    if list(sd.keys())[0].startswith("module"):
        sd = {k[len("module."):]: v for k, v in sd.items()}
    return sd

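The two clean-up steps in load_model_from_url (unwrapping a checkpoint saved under a state_dict key, and stripping a DataParallel-style "module." prefix) can be seen on a toy dict; this snippet is illustrative only:

import torch

ckpt = {"state_dict": {"module.conv.weight": torch.zeros(3, 3, 1, 1)}}   # toy checkpoint
sd = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
if list(sd.keys())[0].startswith("module"):
    sd = {k[len("module."):]: v for k, v in sd.items()}
print(list(sd.keys()))   # ['conv.weight'] -- keys now match the bare model definition
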
class InferenceLoop:

    def __init__(self, args: Namespace) -> None:
        self.args = args
        self.loop_ctx = {}
        self.pipeline: Pipeline = None
        self.init_stage1_model()
        self.init_stage2_model()
        self.init_cond_fn()
        self.init_pipeline()

    @overload
    def init_stage1_model(self) -> None:
        ...

    @count_vram_usage
    def init_stage2_model(self) -> None:
        ### load unet, vae, clip
        self.cldm: ControlLDM = instantiate_from_config(OmegaConf.load("configs/inference/cldm.yaml"))
        sd = load_model_from_url(MODELS["sd_v21"])
        unused = self.cldm.load_pretrained_sd(sd)
        print(f"strictly load pretrained sd_v2.1, unused weights: {unused}")
        ### load controlnet
        if self.args.version == "v1":
            if self.args.task == "fr":
                control_sd = load_model_from_url(MODELS["v1_face"])
            elif self.args.task == "sr":
                control_sd = load_model_from_url(MODELS["v1_general"])
            else:
                raise ValueError(f"DiffBIR v1 doesn't support task: {self.args.task}, please use v2 by passing '--version v2'")
        else:
            control_sd = load_model_from_url(MODELS["v2"])
        self.cldm.load_controlnet_from_ckpt(control_sd)
        print(f"strictly load controlnet weight")
        self.cldm.eval().to(self.args.device)
        ### load diffusion
        self.diffusion: Diffusion = instantiate_from_config(OmegaConf.load("configs/inference/diffusion.yaml"))
        self.diffusion.to(self.args.device)

    def init_cond_fn(self) -> None:
        if not self.args.guidance:
            self.cond_fn = None
            return
        if self.args.g_loss == "mse":
            cond_fn_cls = MSEGuidance
        elif self.args.g_loss == "w_mse":
            cond_fn_cls = WeightedMSEGuidance
        else:
            raise ValueError(self.args.g_loss)
        self.cond_fn = cond_fn_cls(
            scale=self.args.g_scale, t_start=self.args.g_start, t_stop=self.args.g_stop,
            space=self.args.g_space, repeat=self.args.g_repeat
        )

    @overload
    def init_pipeline(self) -> None:
        ...

    def setup(self) -> None:
        self.output_dir = self.args.output
        os.makedirs(self.output_dir, exist_ok=True)

    def lq_loader(self) -> Generator[np.ndarray, None, None]:
        img_exts = [".png", ".jpg", ".jpeg"]
        if os.path.isdir(self.args.input):
            file_names = sorted([
                file_name for file_name in os.listdir(self.args.input) if os.path.splitext(file_name)[-1] in img_exts
            ])
            file_paths = [os.path.join(self.args.input, file_name) for file_name in file_names]
        else:
            assert os.path.splitext(self.args.input)[-1] in img_exts
            file_paths = [self.args.input]

        def _loader() -> Generator[np.ndarray, None, None]:
            for file_path in file_paths:
                ### load lq
                lq = np.array(Image.open(file_path).convert("RGB"))
                print(f"load lq: {file_path}")
                ### set context for saving results
                self.loop_ctx["file_stem"] = os.path.splitext(os.path.basename(file_path))[0]
                for i in range(self.args.n_samples):
                    self.loop_ctx["repeat_idx"] = i
                    yield lq

        return _loader

    def after_load_lq(self, lq: np.ndarray) -> np.ndarray:
        return lq

    @torch.no_grad()
    def run(self) -> None:
        self.setup()
        # We don't support batch processing since input images may have different sizes
        loader = self.lq_loader()
        for lq in loader():
            lq = self.after_load_lq(lq)
            sample = self.pipeline.run(
                lq[None], self.args.steps, 1.0, self.args.tiled,
                self.args.tile_size, self.args.tile_stride,
                self.args.pos_prompt, self.args.neg_prompt, self.args.cfg_scale,
                self.args.better_start
            )[0]
            self.save(sample)

    def save(self, sample: np.ndarray) -> None:
        file_stem, repeat_idx = self.loop_ctx["file_stem"], self.loop_ctx["repeat_idx"]
        file_name = f"{file_stem}_{repeat_idx}.png" if self.args.n_samples > 1 else f"{file_stem}.png"
        save_path = os.path.join(self.args.output, file_name)
        Image.fromarray(sample).save(save_path)
        print(f"save result to {save_path}")

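The task-specific loops below differ only in which stage-1 model they build and which Pipeline they wrap. A stripped-down, hypothetical subclass written against the same contract (the class name is made up; the real implementations follow):

class ExampleSRInferenceLoop(InferenceLoop):
    # hypothetical, for illustration: mirrors the structure of the classes defined below

    @count_vram_usage
    def init_stage1_model(self) -> None:
        # build the stage-1 cleaner and load its weights
        self.swinir: SwinIR = instantiate_from_config(OmegaConf.load("configs/inference/swinir.yaml"))
        self.swinir.load_state_dict(load_model_from_url(MODELS["swinir_general"]), strict=True)
        self.swinir.eval().to(self.args.device)

    def init_pipeline(self) -> None:
        # wire the stage-1 model and the shared ControlLDM / Diffusion into a Pipeline
        self.pipeline = SwinIRPipeline(self.swinir, self.cldm, self.diffusion, self.cond_fn, self.args.device)

    def after_load_lq(self, lq: np.ndarray) -> np.ndarray:
        # optional hook, e.g. pre-upscale the input before restoration
        return bicubic_resize(lq, self.args.upscale)
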
class BSRInferenceLoop(InferenceLoop):

    @count_vram_usage
    def init_stage1_model(self) -> None:
        self.bsrnet: RRDBNet = instantiate_from_config(OmegaConf.load("configs/inference/bsrnet.yaml"))
        sd = load_model_from_url(MODELS["bsrnet"])
        self.bsrnet.load_state_dict(sd, strict=True)
        self.bsrnet.eval().to(self.args.device)

    def init_pipeline(self) -> None:
        self.pipeline = BSRNetPipeline(self.bsrnet, self.cldm, self.diffusion, self.cond_fn, self.args.device, self.args.upscale)


class BFRInferenceLoop(InferenceLoop):

    @count_vram_usage
    def init_stage1_model(self) -> None:
        self.swinir_face: SwinIR = instantiate_from_config(OmegaConf.load("configs/inference/swinir.yaml"))
        sd = load_model_from_url(MODELS["swinir_face"])
        self.swinir_face.load_state_dict(sd, strict=True)
        self.swinir_face.eval().to(self.args.device)

    def init_pipeline(self) -> None:
        self.pipeline = SwinIRPipeline(self.swinir_face, self.cldm, self.diffusion, self.cond_fn, self.args.device)

    def after_load_lq(self, lq: np.ndarray) -> np.ndarray:
        # For BFR task, super resolution is achieved by directly upscaling lq
        return bicubic_resize(lq, self.args.upscale)


class BIDInferenceLoop(InferenceLoop):

    @count_vram_usage
    def init_stage1_model(self) -> None:
        self.scunet_psnr: SCUNet = instantiate_from_config(OmegaConf.load("configs/inference/scunet.yaml"))
        sd = load_model_from_url(MODELS["scunet_psnr"])
        self.scunet_psnr.load_state_dict(sd, strict=True)
        self.scunet_psnr.eval().to(self.args.device)

    def init_pipeline(self) -> None:
        self.pipeline = SCUNetPipeline(self.scunet_psnr, self.cldm, self.diffusion, self.cond_fn, self.args.device)

    def after_load_lq(self, lq: np.ndarray) -> np.ndarray:
        # For BID task, super resolution is achieved by directly upscaling lq
        return bicubic_resize(lq, self.args.upscale)

class V1InferenceLoop(InferenceLoop):

    @count_vram_usage
    def init_stage1_model(self) -> None:
        self.swinir: SwinIR = instantiate_from_config(OmegaConf.load("configs/inference/swinir.yaml"))
        if self.args.task == "fr":
            sd = load_model_from_url(MODELS["swinir_face"])
        elif self.args.task == "sr":
            sd = load_model_from_url(MODELS["swinir_general"])
        else:
            raise ValueError(f"DiffBIR v1 doesn't support task: {self.args.task}, please use v2 by passing '--version v2'")
        self.swinir.load_state_dict(sd, strict=True)
        self.swinir.eval().to(self.args.device)

    def init_pipeline(self) -> None:
        self.pipeline = SwinIRPipeline(self.swinir, self.cldm, self.diffusion, self.cond_fn, self.args.device)

    def after_load_lq(self, lq: np.ndarray) -> np.ndarray:
        # For BFR task, super resolution is achieved by directly upscaling lq
        return bicubic_resize(lq, self.args.upscale)

class UnAlignedBFRInferenceLoop(InferenceLoop):

    @count_vram_usage
    def init_stage1_model(self) -> None:
        self.bsrnet: RRDBNet = instantiate_from_config(OmegaConf.load("configs/inference/bsrnet.yaml"))
        sd = load_model_from_url(MODELS["bsrnet"])
        self.bsrnet.load_state_dict(sd, strict=True)
        self.bsrnet.eval().to(self.args.device)

        self.swinir_face: SwinIR = instantiate_from_config(OmegaConf.load("configs/inference/swinir.yaml"))
        sd = load_model_from_url(MODELS["swinir_face"])
        self.swinir_face.load_state_dict(sd, strict=True)
        self.swinir_face.eval().to(self.args.device)

    def init_pipeline(self) -> None:
        self.pipes = {
            "bg": BSRNetPipeline(self.bsrnet, self.cldm, self.diffusion, self.cond_fn, self.args.device, self.args.upscale),
            "face": SwinIRPipeline(self.swinir_face, self.cldm, self.diffusion, self.cond_fn, self.args.device)
        }
        self.pipeline = self.pipes["face"]

    def setup(self) -> None:
        super().setup()
        self.cropped_face_dir = os.path.join(self.args.output, "cropped_faces")
        os.makedirs(self.cropped_face_dir, exist_ok=True)
        self.restored_face_dir = os.path.join(self.args.output, "restored_faces")
        os.makedirs(self.restored_face_dir, exist_ok=True)
        self.restored_bg_dir = os.path.join(self.args.output, "restored_backgrounds")
        os.makedirs(self.restored_bg_dir, exist_ok=True)

    def lq_loader(self) -> Generator[np.ndarray, None, None]:
        base_loader = super().lq_loader()
        self.face_helper = FaceRestoreHelper(
            device=self.args.device,
            upscale_factor=1,
            face_size=512,
            use_parse=True,
            det_model="retinaface_resnet50"
        )

        def _loader() -> Generator[np.ndarray, None, None]:
            for lq in base_loader():
                ### set input image
                self.face_helper.clean_all()
                upscaled_bg = bicubic_resize(lq, self.args.upscale)
                self.face_helper.read_image(upscaled_bg)
                ### get face landmarks for each face
                self.face_helper.get_face_landmarks_5(resize=640, eye_dist_threshold=5)
                self.face_helper.align_warp_face()
                print(f"detect {len(self.face_helper.cropped_faces)} faces")
                ### restore each face (has been upscaled)
                for i, lq_face in enumerate(self.face_helper.cropped_faces):
                    self.loop_ctx["is_face"] = True
                    self.loop_ctx["face_idx"] = i
                    self.loop_ctx["cropped_face"] = lq_face
                    yield lq_face
                ### restore background (hasn't been upscaled)
                self.loop_ctx["is_face"] = False
                yield lq

        return _loader

    def after_load_lq(self, lq: np.ndarray) -> np.ndarray:
        if self.loop_ctx["is_face"]:
            self.pipeline = self.pipes["face"]
        else:
            self.pipeline = self.pipes["bg"]
        return lq

    def save(self, sample: np.ndarray) -> None:
        file_stem, repeat_idx = self.loop_ctx["file_stem"], self.loop_ctx["repeat_idx"]
        if self.loop_ctx["is_face"]:
            face_idx = self.loop_ctx["face_idx"]
            file_name = f"{file_stem}_{repeat_idx}_face_{face_idx}.png"
            Image.fromarray(sample).save(os.path.join(self.restored_face_dir, file_name))

            cropped_face = self.loop_ctx["cropped_face"]
            Image.fromarray(cropped_face).save(os.path.join(self.cropped_face_dir, file_name))

            self.face_helper.add_restored_face(sample)
        else:
            self.face_helper.get_inverse_affine()
            # paste each restored face to the input image
            restored_img = self.face_helper.paste_faces_to_input_image(
                upsample_img=sample
            )
            file_name = f"{file_stem}_{repeat_idx}.png"
            Image.fromarray(sample).save(os.path.join(self.restored_bg_dir, file_name))
            Image.fromarray(restored_img).save(os.path.join(self.output_dir, file_name))
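A usage sketch for these loops. The attribute names are taken from the self.args accesses above; the actual CLI definition is outside this upload, so every value below is a placeholder:

from argparse import Namespace
from utils.inference import BSRInferenceLoop   # assuming this module path

args = Namespace(
    version="v2", task="sr", input="inputs", output="results",
    steps=50, better_start=True, tiled=False, tile_size=512, tile_stride=256,
    pos_prompt="", neg_prompt="low quality, blurry", cfg_scale=4.0,
    n_samples=1, upscale=4, device="cuda",
    guidance=False, g_loss="w_mse", g_scale=0.0, g_start=1001, g_stop=-1, g_space="latent", g_repeat=1,
)
BSRInferenceLoop(args).run()   # restored results are written to args.output
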
    	
utils/sampler.py
ADDED
@@ -0,0 +1,341 @@
            from typing import Optional, Tuple, Dict
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
            +
            import torch
         
     | 
| 4 | 
         
            +
            from torch import nn
         
     | 
| 5 | 
         
            +
            import numpy as np
         
     | 
| 6 | 
         
            +
            from tqdm import tqdm
         
     | 
| 7 | 
         
            +
             
     | 
| 8 | 
         
            +
            from model.gaussian_diffusion import extract_into_tensor
         
     | 
| 9 | 
         
            +
            from model.cldm import ControlLDM
         
     | 
| 10 | 
         
            +
            from utils.cond_fn import Guidance
         
     | 
| 11 | 
         
            +
            from utils.common import sliding_windows, gaussian_weights
         
     | 
| 12 | 
         
            +
             
     | 
| 13 | 
         
            +
             
     | 
| 14 | 
         
            +
            # https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py
         
     | 
| 15 | 
         
            +
            def space_timesteps(num_timesteps, section_counts):
         
     | 
| 16 | 
         
            +
                """
         
     | 
| 17 | 
         
            +
                Create a list of timesteps to use from an original diffusion process,
         
     | 
| 18 | 
         
            +
                given the number of timesteps we want to take from equally-sized portions
         
     | 
| 19 | 
         
            +
                of the original process.
         
     | 
| 20 | 
         
            +
                For example, if there's 300 timesteps and the section counts are [10,15,20]
         
     | 
| 21 | 
         
            +
                then the first 100 timesteps are strided to be 10 timesteps, the second 100
         
     | 
| 22 | 
         
            +
                are strided to be 15 timesteps, and the final 100 are strided to be 20.
         
     | 
| 23 | 
         
            +
                If the stride is a string starting with "ddim", then the fixed striding
         
     | 
| 24 | 
         
            +
                from the DDIM paper is used, and only one section is allowed.
         
     | 
| 25 | 
         
            +
                :param num_timesteps: the number of diffusion steps in the original
         
     | 
| 26 | 
         
            +
                                      process to divide up.
         
     | 
| 27 | 
         
            +
                :param section_counts: either a list of numbers, or a string containing
         
     | 
| 28 | 
         
            +
                                       comma-separated numbers, indicating the step count
         
     | 
| 29 | 
         
            +
                                       per section. As a special case, use "ddimN" where N
         
     | 
| 30 | 
         
            +
                                       is a number of steps to use the striding from the
         
     | 
| 31 | 
         
            +
                                       DDIM paper.
         
     | 
| 32 | 
         
            +
                :return: a set of diffusion steps from the original process to use.
         
     | 
| 33 | 
         
            +
                """
         
     | 
| 34 | 
         
            +
                if isinstance(section_counts, str):
         
     | 
| 35 | 
         
            +
                    if section_counts.startswith("ddim"):
         
     | 
| 36 | 
         
            +
                        desired_count = int(section_counts[len("ddim") :])
         
     | 
| 37 | 
         
            +
                        for i in range(1, num_timesteps):
         
     | 
| 38 | 
         
            +
                            if len(range(0, num_timesteps, i)) == desired_count:
         
     | 
| 39 | 
         
            +
                                return set(range(0, num_timesteps, i))
         
     | 
| 40 | 
         
            +
                        raise ValueError(
         
     | 
| 41 | 
         
            +
                            f"cannot create exactly {num_timesteps} steps with an integer stride"
         
     | 
| 42 | 
         
            +
                        )
         
     | 
| 43 | 
         
            +
                    section_counts = [int(x) for x in section_counts.split(",")]
         
     | 
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        if section_count <= 1:
            frac_stride = 1
        else:
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)
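# A minimal usage sketch (illustrative values):
#   space_timesteps(300, [10, 15, 20])   # three 100-step sections -> 45 retained steps
#   space_timesteps(1000, "50")          # one section -> 50 roughly evenly spaced steps
#   space_timesteps(1000, "ddim50")      # fixed DDIM striding -> {0, 20, 40, ..., 980}
# The result is an unordered set; callers sort it, as SpacedSampler.make_schedule does.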


class SpacedSampler(nn.Module):
    """
    Implementation of the spaced sampling schedule proposed in IDDPM. This class is
    designed for sampling ControlLDM.

    https://arxiv.org/pdf/2102.09672.pdf
    """

    def __init__(self, betas: np.ndarray) -> None:
        super().__init__()
        self.num_timesteps = len(betas)
        self.original_betas = betas
        self.original_alphas_cumprod = np.cumprod(1.0 - betas, axis=0)
        self.context = {}
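        # Construction sketch (hypothetical values): `betas` is the 1-D array of
        # length `num_timesteps` used to train the underlying diffusion model, e.g.
        #   betas = np.linspace(1e-4, 2e-2, 1000, dtype=np.float64)
        #   sampler = SpacedSampler(betas)
        # Buffers for the spaced schedule are only created later, in make_schedule().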

    def register(self, name: str, value: np.ndarray) -> None:
        self.register_buffer(name, torch.tensor(value, dtype=torch.float32))

    def make_schedule(self, num_steps: int) -> None:
        # calculate betas for spaced sampling
        # https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py
        used_timesteps = space_timesteps(self.num_timesteps, str(num_steps))
        betas = []
        last_alpha_cumprod = 1.0
        for i, alpha_cumprod in enumerate(self.original_alphas_cumprod):
            if i in used_timesteps:
                # marginal distribution is the same as q(x_{S_t}|x_0)
                betas.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
        assert len(betas) == num_steps
        self.timesteps = np.array(sorted(list(used_timesteps)), dtype=np.int32)  # e.g. [0, 10, 20, ...]

        betas = np.array(betas, dtype=np.float64)
        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # print(f"sampler sqrt_alphas_cumprod: {np.sqrt(alphas_cumprod)[-1]}")
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
        sqrt_recip_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod)
        sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod - 1)
        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        # log calculation clipped because the posterior variance is 0 at the
        # beginning of the diffusion chain.
        posterior_log_variance_clipped = np.log(
            np.append(posterior_variance[1], posterior_variance[1:])
        )
        posterior_mean_coef1 = (
            betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        posterior_mean_coef2 = (
            (1.0 - alphas_cumprod_prev)
            * np.sqrt(alphas)
            / (1.0 - alphas_cumprod)
        )

        self.register("sqrt_recip_alphas_cumprod", sqrt_recip_alphas_cumprod)
        self.register("sqrt_recipm1_alphas_cumprod", sqrt_recipm1_alphas_cumprod)
        self.register("posterior_variance", posterior_variance)
        self.register("posterior_log_variance_clipped", posterior_log_variance_clipped)
        self.register("posterior_mean_coef1", posterior_mean_coef1)
        self.register("posterior_mean_coef2", posterior_mean_coef2)
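        # Illustrative sketch (assuming the usual 1000-step training schedule):
        #   sampler.make_schedule(50)
        #   sampler.timesteps            # 50 indices spread roughly evenly over [0, 999]
        #   sampler.posterior_variance   # buffer of length 50, one entry per spaced step
        # The registered buffers are indexed by the spaced step index (`index` below),
        # not by the original timestep value (`t`).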

    def q_posterior_mean_variance(self, x_start: torch.Tensor, x_t: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Implement the posterior distribution q(x_{t-1}|x_t, x_0).

        Args:
            x_start (torch.Tensor): The predicted images (NCHW) in timestep `t`.
            x_t (torch.Tensor): The sampled intermediate variables (NCHW) of timestep `t`.
            t (torch.Tensor): Timestep (N) of `x_t`. `t` serves as an index to get
                parameters for each timestep.

        Returns:
            posterior_mean (torch.Tensor): Mean of the posterior distribution.
            posterior_variance (torch.Tensor): Variance of the posterior distribution.
            posterior_log_variance_clipped (torch.Tensor): Log variance of the posterior distribution.
        """
        posterior_mean = (
            extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract_into_tensor(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped
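        # For reference, the registered coefficients implement the standard DDPM posterior
        #   q(x_{t-1} | x_t, x_0) = N(mu_t, beta_tilde_t * I), where
        #   mu_t          = coef1 * x_0 + coef2 * x_t
        #   coef1         = beta_t * sqrt(alphabar_{t-1}) / (1 - alphabar_t)
        #   coef2         = (1 - alphabar_{t-1}) * sqrt(alpha_t) / (1 - alphabar_t)
        #   beta_tilde_t  = beta_t * (1 - alphabar_{t-1}) / (1 - alphabar_t)
        # with everything defined over the spaced schedule built in make_schedule().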

    def _predict_xstart_from_eps(self, x_t: torch.Tensor, t: torch.Tensor, eps: torch.Tensor) -> torch.Tensor:
        return (
            extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )
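        # This inverts the forward process x_t = sqrt(alphabar_t) * x_0 + sqrt(1 - alphabar_t) * eps:
        #   x_0 = sqrt(1 / alphabar_t) * x_t - sqrt(1 / alphabar_t - 1) * eps,
        # which is exactly what the two buffers used above store.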

    def apply_cond_fn(
        self,
        model: ControlLDM,
        pred_x0: torch.Tensor,
        t: torch.Tensor,
        index: torch.Tensor,
        cond_fn: Guidance
    ) -> torch.Tensor:
        t_now = int(t[0].item()) + 1
        if not (cond_fn.t_stop < t_now and t_now < cond_fn.t_start):
            # stop guidance
            self.context["g_apply"] = False
            return pred_x0
        grad_rescale = 1 / extract_into_tensor(self.posterior_mean_coef1, index, pred_x0.shape)
        # apply guidance multiple times
        loss_vals = []
        for _ in range(cond_fn.repeat):
            # set target and pred for gradient computation
            target, pred = None, None
            if cond_fn.space == "latent":
                target = model.vae_encode(cond_fn.target)
                pred = pred_x0
            elif cond_fn.space == "rgb":
                # We need to backpropagate the gradient to x0 in latent space, so it's
                # required to trace the computation graph while decoding the latent.
                with torch.enable_grad():
                    target = cond_fn.target
                    pred_x0_rg = pred_x0.detach().clone().requires_grad_(True)
                    pred = model.vae_decode(pred_x0_rg)
                    assert pred.requires_grad
            else:
                raise NotImplementedError(cond_fn.space)
            # compute gradient
            delta_pred, loss_val = cond_fn(target, pred, t_now)
            loss_vals.append(loss_val)
            # update pred_x0 w.r.t. the gradient
            if cond_fn.space == "latent":
                delta_pred_x0 = delta_pred
                pred_x0 = pred_x0 + delta_pred_x0 * grad_rescale
            elif cond_fn.space == "rgb":
                pred.backward(delta_pred)
                delta_pred_x0 = pred_x0_rg.grad
                pred_x0 = pred_x0 + delta_pred_x0 * grad_rescale
            else:
                raise NotImplementedError(cond_fn.space)
        self.context["g_apply"] = True
        self.context["g_loss"] = float(np.mean(loss_vals))
        return pred_x0
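        # Hedged usage sketch: `cond_fn` is a Guidance object (presumably defined in
        # utils/cond_fn.py) exposing `target`, `space` ("latent" or "rgb"), `t_start`,
        # `t_stop`, `repeat`, and a __call__ that returns (gradient, loss), e.g.
        # (hypothetical class name and arguments):
        #   cond_fn = MSEGuidance(scale=..., t_start=1001, t_stop=-1, space="rgb", repeat=1)
        # Guidance is only applied inside the (t_stop, t_start) window, and pred_x0 is
        # nudged `repeat` times per sampling step.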

    def predict_noise(
        self,
        model: ControlLDM,
        x: torch.Tensor,
        t: torch.Tensor,
        cond: Dict[str, torch.Tensor],
        uncond: Optional[Dict[str, torch.Tensor]],
        cfg_scale: float
    ) -> torch.Tensor:
        if uncond is None or cfg_scale == 1.:
            model_output = model(x, t, cond)
        else:
            # apply classifier-free guidance
            model_cond = model(x, t, cond)
            model_uncond = model(x, t, uncond)
            model_output = model_uncond + cfg_scale * (model_cond - model_uncond)
        return model_output
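        # Classifier-free guidance as implemented above:
        #   eps = eps_uncond + cfg_scale * (eps_cond - eps_uncond)
        # cfg_scale == 1.0 (or uncond=None) reduces to a single conditional forward pass;
        # larger values trade diversity for stronger adherence to the condition.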

    @torch.no_grad()
    def predict_noise_tiled(
        self,
        model: ControlLDM,
        x: torch.Tensor,
        t: torch.Tensor,
        cond: Dict[str, torch.Tensor],
        uncond: Optional[Dict[str, torch.Tensor]],
        cfg_scale: float,
        tile_size: int,
        tile_stride: int
    ):
        _, _, h, w = x.shape
        tiles = tqdm(sliding_windows(h, w, tile_size // 8, tile_stride // 8), unit="tile", leave=False)
        eps = torch.zeros_like(x)
        count = torch.zeros_like(x, dtype=torch.float32)
        weights = gaussian_weights(tile_size // 8, tile_size // 8)[None, None]
        weights = torch.tensor(weights, dtype=torch.float32, device=x.device)
        for hi, hi_end, wi, wi_end in tiles:
            tiles.set_description(f"Process tile ({hi} {hi_end}), ({wi} {wi_end})")
            tile_x = x[:, :, hi:hi_end, wi:wi_end]
            tile_cond = {
                "c_img": cond["c_img"][:, :, hi:hi_end, wi:wi_end],
                "c_txt": cond["c_txt"]
            }
            if uncond:
                tile_uncond = {
                    "c_img": uncond["c_img"][:, :, hi:hi_end, wi:wi_end],
                    "c_txt": uncond["c_txt"]
                }
            else:
                tile_uncond = None
            tile_eps = self.predict_noise(model, tile_x, t, tile_cond, tile_uncond, cfg_scale)
            # accumulate noise
            eps[:, :, hi:hi_end, wi:wi_end] += tile_eps * weights
            count[:, :, hi:hi_end, wi:wi_end] += weights
        # average the noise (score)
        eps.div_(count)
        return eps
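        # The `// 8` factors assume the 8x spatial downsampling of the Stable Diffusion
        # VAE, so tile_size / tile_stride are given in pixel space while tiling happens
        # on the latent grid. A plausible call for a large input (illustrative numbers):
        #   eps = self.predict_noise_tiled(model, x, t, cond, uncond, cfg_scale=4.0,
        #                                  tile_size=512, tile_stride=256)
        # Overlapping tiles are blended with Gaussian weights to avoid visible seams.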

    @torch.no_grad()
    def p_sample(
        self,
        model: ControlLDM,
        x: torch.Tensor,
        t: torch.Tensor,
        index: torch.Tensor,
        cond: Dict[str, torch.Tensor],
        uncond: Optional[Dict[str, torch.Tensor]],
        cfg_scale: float,
        cond_fn: Optional[Guidance],
        tiled: bool,
        tile_size: int,
        tile_stride: int
    ) -> torch.Tensor:
        if tiled:
            eps = self.predict_noise_tiled(model, x, t, cond, uncond, cfg_scale, tile_size, tile_stride)
        else:
            eps = self.predict_noise(model, x, t, cond, uncond, cfg_scale)
        pred_x0 = self._predict_xstart_from_eps(x, index, eps)
        if cond_fn:
            assert not tiled, "tiled sampling currently doesn't support guidance"
            pred_x0 = self.apply_cond_fn(model, pred_x0, t, index, cond_fn)
        model_mean, model_variance, _ = self.q_posterior_mean_variance(pred_x0, x, index)
        noise = torch.randn_like(x)
        nonzero_mask = (
            (index != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )
        x_prev = model_mean + nonzero_mask * torch.sqrt(model_variance) * noise
        return x_prev
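        # One ancestral step: predict the noise, convert it to a clean estimate pred_x0,
        # optionally nudge pred_x0 with restoration guidance, then sample
        #   x_{index-1} ~ N(posterior_mean(pred_x0, x), posterior_variance[index]),
        # with the noise term masked out on the final step (index == 0).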

    @torch.no_grad()
    def sample(
        self,
        model: ControlLDM,
        device: str,
        steps: int,
        batch_size: int,
        x_size: Tuple[int, ...],
        cond: Dict[str, torch.Tensor],
        uncond: Dict[str, torch.Tensor],
        cfg_scale: float,
        cond_fn: Optional[Guidance]=None,
        tiled: bool=False,
        tile_size: int=-1,
        tile_stride: int=-1,
        x_T: Optional[torch.Tensor]=None,
        progress: bool=True,
        progress_leave: bool=True,
    ) -> torch.Tensor:
        self.make_schedule(steps)
        self.to(device)
        if x_T is None:
            # TODO: not converting to float32 here may trigger an error
            img = torch.randn((batch_size, *x_size), device=device)
        else:
            img = x_T
        timesteps = np.flip(self.timesteps)  # descending order, e.g. [999, ..., 20, 0]
        total_steps = len(self.timesteps)
        iterator = tqdm(timesteps, total=total_steps, leave=progress_leave, disable=not progress)
        for i, step in enumerate(iterator):
            ts = torch.full((batch_size,), step, device=device, dtype=torch.long)
            index = torch.full_like(ts, fill_value=total_steps - i - 1)
            img = self.p_sample(
                model, img, ts, index, cond, uncond, cfg_scale, cond_fn,
                tiled, tile_size, tile_stride
            )
            if cond_fn and self.context["g_apply"]:
                loss_val = self.context["g_loss"]
                desc = f"Spaced Sampler With Guidance, Loss: {loss_val:.6f}"
            else:
                desc = "Spaced Sampler"
            iterator.set_description(desc)
        return img
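
# End-to-end usage sketch (hypothetical variable names; cond/uncond follow the
# "c_img"/"c_txt" layout consumed above):
#   sampler = SpacedSampler(betas)
#   z = sampler.sample(
#       model=cldm, device="cuda", steps=50, batch_size=1,
#       x_size=(4, h // 8, w // 8),
#       cond={"c_img": cond_img, "c_txt": cond_txt},
#       uncond={"c_img": cond_img, "c_txt": uncond_txt},
#       cfg_scale=4.0,
#   )
#   image = cldm.vae_decode(z)  # back to RGB space, as in apply_cond_fn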