| 
							 | 
						import argparse | 
					
					
						
						| 
							 | 
						import functools | 
					
					
						
						| 
							 | 
						import os | 
					
					
						
						| 
							 | 
						from typing import List | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						import numpy as np | 
					
					
						
						| 
							 | 
						import rasterio | 
					
					
						
						| 
							 | 
						import torch | 
					
					
						
						| 
							 | 
						import yaml | 
					
					
						
						| 
							 | 
						from einops import rearrange | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						from Prithvi import MaskedAutoencoderViT | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						NO_DATA = -9999 | 
					
					
						
						| 
							 | 
						NO_DATA_FLOAT = 0.0001 | 
					
					
						
						| 
							 | 
						PERCENTILES = (0.1, 99.9) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def process_channel_group(orig_img, new_img, channels, data_mean, data_std): | 
					
					
						
						| 
							 | 
						    """ Process *orig_img* and *new_img* for RGB visualization. Each band is rescaled back to the | 
					
					
						
						| 
							 | 
						        original range using *data_mean* and *data_std* and then lowest and highest percentiles are | 
					
					
						
						| 
							 | 
						        removed to enhance contrast. Data is rescaled to (0, 1) range and stacked channels_first. | 
					
					
						
						| 
							 | 
						    Args: | 
					
					
						
						| 
							 | 
						        orig_img: torch.Tensor representing original image (reference) with shape = (bands, H, W). | 
					
					
						
						| 
							 | 
						        new_img: torch.Tensor representing image with shape = (bands, H, W). | 
					
					
						
						| 
							 | 
						        channels: list of indices representing RGB channels. | 
					
					
						
						| 
							 | 
						        data_mean: list of mean values for each band. | 
					
					
						
						| 
							 | 
						        data_std: list of std values for each band. | 
					
					
						
						| 
							 | 
						    Returns: | 
					
					
						
						| 
							 | 
						        torch.Tensor with shape (num_channels, height, width) for original image | 
					
					
						
						| 
							 | 
						        torch.Tensor with shape (num_channels, height, width) for the other image | 
					
					
						
						| 
							 | 
						    """ | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    stack_c = [], [] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    for c in channels: | 
					
					
						
						| 
							 | 
						        orig_ch = orig_img[c, ...] | 
					
					
						
						| 
							 | 
						        valid_mask = torch.ones_like(orig_ch, dtype=torch.bool) | 
					
					
						
						| 
							 | 
						        valid_mask[orig_ch == 0.0001] = False | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        orig_ch = (orig_ch * data_std[c]) + data_mean[c] | 
					
					
						
						| 
							 | 
						        new_ch = (new_img[c, ...] * data_std[c]) + data_mean[c] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        min_value, max_value = np.percentile(orig_ch[valid_mask], PERCENTILES) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        orig_ch = torch.clamp((orig_ch - min_value) / (max_value - min_value), 0, 1) | 
					
					
						
						| 
							 | 
						        new_ch = torch.clamp((new_ch - min_value) / (max_value - min_value), 0, 1) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        orig_ch[~valid_mask] = 0 | 
					
					
						
						| 
							 | 
						        new_ch[~valid_mask] = 0 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        stack_c[0].append(orig_ch) | 
					
					
						
						| 
							 | 
						        stack_c[1].append(new_ch) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    stack_orig = torch.stack(stack_c[0], dim=0) | 
					
					
						
						| 
							 | 
						    stack_rec = torch.stack(stack_c[1], dim=0) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    return stack_orig, stack_rec | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def read_geotiff(file_path: str): | 
					
					
						
						| 
							 | 
						    """ Read all bands from *file_path* and returns image + meta info. | 
					
					
						
						| 
							 | 
						    Args: | 
					
					
						
						| 
							 | 
						        file_path: path to image file. | 
					
					
						
						| 
							 | 
						    Returns: | 
					
					
						
						| 
							 | 
						        np.ndarray with shape (bands, height, width) | 
					
					
						
						| 
							 | 
						        meta info dict | 
					
					
						
						| 
							 | 
						    """ | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    with rasterio.open(file_path) as src: | 
					
					
						
						| 
							 | 
						        img = src.read() | 
					
					
						
						| 
							 | 
						        meta = src.meta | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    return img, meta | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def save_geotiff(image, output_path: str, meta: dict): | 
					
					
						
						| 
							 | 
						    """ Save multi-band image in Geotiff file. | 
					
					
						
						| 
							 | 
						    Args: | 
					
					
						
						| 
							 | 
						        image: np.ndarray with shape (bands, height, width) | 
					
					
						
						| 
							 | 
						        output_path: path where to save the image | 
					
					
						
						| 
							 | 
						        meta: dict with meta info. | 
					
					
						
						| 
							 | 
						    """ | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    with rasterio.open(output_path, "w", **meta) as dest: | 
					
					
						
						| 
							 | 
						        for i in range(image.shape[0]): | 
					
					
						
						| 
							 | 
						            dest.write(image[i, :, :], i + 1) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    return | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def _convert_np_uint8(float_image: torch.Tensor): | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    image = float_image.numpy() * 255.0 | 
					
					
						
						| 
							 | 
						    image = image.astype(dtype=np.uint8) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    return image | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def load_example(file_paths: List[str], mean: List[float], std: List[float]): | 
					
					
						
						| 
							 | 
						    """ Build an input example by loading images in *file_paths*. | 
					
					
						
						| 
							 | 
						    Args: | 
					
					
						
						| 
							 | 
						        file_paths: list of file paths . | 
					
					
						
						| 
							 | 
						        mean: list containing mean values for each band in the images in *file_paths*. | 
					
					
						
						| 
							 | 
						        std: list containing std values for each band in the images in *file_paths*. | 
					
					
						
						| 
							 | 
						    Returns: | 
					
					
						
						| 
							 | 
						        np.array containing created example | 
					
					
						
						| 
							 | 
						        list of meta info for each image in *file_paths* | 
					
					
						
						| 
							 | 
						    """ | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    imgs = [] | 
					
					
						
						| 
							 | 
						    metas = [] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    for file in file_paths: | 
					
					
						
						| 
							 | 
						        img, meta = read_geotiff(file) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        img = np.moveaxis(img, 0, -1)    | 
					
					
						
						| 
							 | 
						        img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        imgs.append(img) | 
					
					
						
						| 
							 | 
						        metas.append(meta) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    imgs = np.stack(imgs, axis=0)     | 
					
					
						
						| 
							 | 
						    imgs = np.moveaxis(imgs, -1, 0).astype('float32')   | 
					
					
						
						| 
							 | 
						    imgs = np.expand_dims(imgs, axis=0)   | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    return imgs, metas | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def run_model(model: torch.nn.Module, input_data: torch.Tensor, mask_ratio: float, device: torch.device): | 
					
					
						
						| 
							 | 
						    """ Run *model* with *input_data* and create images from output tokens (mask, reconstructed + visible). | 
					
					
						
						| 
							 | 
						    Args: | 
					
					
						
						| 
							 | 
						        model: MAE model to run. | 
					
					
						
						| 
							 | 
						        input_data: torch.Tensor with shape (B, C, T, H, W). | 
					
					
						
						| 
							 | 
						        mask_ratio: mask ratio to use. | 
					
					
						
						| 
							 | 
						        device: device where model should run. | 
					
					
						
						| 
							 | 
						    Returns: | 
					
					
						
						| 
							 | 
						        3 torch.Tensor with shape (B, C, T, H, W). | 
					
					
						
						| 
							 | 
						    """ | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    with torch.no_grad(): | 
					
					
						
						| 
							 | 
						        x = input_data.to(device) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        _, pred, mask = model(x, mask_ratio) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    mask_img = model.unpatchify(mask.unsqueeze(-1).repeat(1, 1, pred.shape[-1])).detach().cpu() | 
					
					
						
						| 
							 | 
						    pred_img = model.unpatchify(pred).detach().cpu() | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    rec_img = input_data.clone() | 
					
					
						
						| 
							 | 
						    rec_img[mask_img == 1] = pred_img[mask_img == 1]   | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    mask_img = (~(mask_img.to(torch.bool))).to(torch.float) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    return rec_img, mask_img | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def save_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std, output_dir, meta_data): | 
					
					
						
						| 
							 | 
						    """ Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp. | 
					
					
						
						| 
							 | 
						    Args: | 
					
					
						
						| 
							 | 
						        input_img: input torch.Tensor with shape (C, T, H, W). | 
					
					
						
						| 
							 | 
						        rec_img: reconstructed torch.Tensor with shape (C, T, H, W). | 
					
					
						
						| 
							 | 
						        mask_img: mask torch.Tensor with shape (C, T, H, W). | 
					
					
						
						| 
							 | 
						        channels: list of indices representing RGB channels. | 
					
					
						
						| 
							 | 
						        mean: list of mean values for each band. | 
					
					
						
						| 
							 | 
						        std: list of std values for each band. | 
					
					
						
						| 
							 | 
						        output_dir: directory where to save outputs. | 
					
					
						
						| 
							 | 
						        meta_data: list of dicts with geotiff meta info. | 
					
					
						
						| 
							 | 
						    """ | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    for t in range(input_img.shape[1]): | 
					
					
						
						| 
							 | 
						        rgb_orig, rgb_pred = process_channel_group(orig_img=input_img[:, t, :, :], | 
					
					
						
						| 
							 | 
						                                                   new_img=rec_img[:, t, :, :], | 
					
					
						
						| 
							 | 
						                                                   channels=channels, data_mean=mean, | 
					
					
						
						| 
							 | 
						                                                   data_std=std) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        rgb_mask = mask_img[channels, t, :, :] * rgb_orig | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        save_geotiff(image=_convert_np_uint8(rgb_orig), | 
					
					
						
						| 
							 | 
						                     output_path=os.path.join(output_dir, f"original_rgb_t{t}.tiff"), | 
					
					
						
						| 
							 | 
						                     meta=meta_data[t]) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        save_geotiff(image=_convert_np_uint8(rgb_pred), | 
					
					
						
						| 
							 | 
						                     output_path=os.path.join(output_dir, f"predicted_rgb_t{t}.tiff"), | 
					
					
						
						| 
							 | 
						                     meta=meta_data[t]) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        save_geotiff(image=_convert_np_uint8(rgb_mask), | 
					
					
						
						| 
							 | 
						                     output_path=os.path.join(output_dir, f"masked_rgb_t{t}.tiff"), | 
					
					
						
						| 
							 | 
						                     meta=meta_data[t]) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def main(data_files: List[str], yaml_file_path: str, checkpoint: str, output_dir: str, mask_ratio: float): | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    os.makedirs(output_dir, exist_ok=True) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    with open(yaml_file_path, 'r') as f: | 
					
					
						
						| 
							 | 
						        params = yaml.safe_load(f) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    num_frames = params['num_frames'] | 
					
					
						
						| 
							 | 
						    img_size = params['img_size'] | 
					
					
						
						| 
							 | 
						    bands = params['bands'] | 
					
					
						
						| 
							 | 
						    mean = params['data_mean'] | 
					
					
						
						| 
							 | 
						    std = params['data_std'] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    depth = params['depth'] | 
					
					
						
						| 
							 | 
						    patch_size = params['patch_size'] | 
					
					
						
						| 
							 | 
						    embed_dim = params['embed_dim'] | 
					
					
						
						| 
							 | 
						    num_heads = params['num_heads'] | 
					
					
						
						| 
							 | 
						    tubelet_size = params['tubelet_size'] | 
					
					
						
						| 
							 | 
						    decoder_embed_dim = params['decoder_embed_dim'] | 
					
					
						
						| 
							 | 
						    decoder_num_heads = params['decoder_num_heads'] | 
					
					
						
						| 
							 | 
						    decoder_depth = params['decoder_depth'] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    batch_size = params['batch_size'] | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    mask_ratio = params['mask_ratio'] if mask_ratio is None else mask_ratio | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    assert len(data_files) == num_frames, "File list must be equal to expected number of frames." | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    if torch.cuda.is_available(): | 
					
					
						
						| 
							 | 
						        device = torch.device('cuda') | 
					
					
						
						| 
							 | 
						    else: | 
					
					
						
						| 
							 | 
						        device = torch.device('cpu') | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    print(f"Using {device} device.\n") | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    input_data, meta_data = load_example(file_paths=data_files, mean=mean, std=std) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    model = MaskedAutoencoderViT( | 
					
					
						
						| 
							 | 
						            img_size=img_size, | 
					
					
						
						| 
							 | 
						            patch_size=patch_size, | 
					
					
						
						| 
							 | 
						            num_frames=num_frames, | 
					
					
						
						| 
							 | 
						            tubelet_size=tubelet_size, | 
					
					
						
						| 
							 | 
						            in_chans=len(bands), | 
					
					
						
						| 
							 | 
						            embed_dim=embed_dim, | 
					
					
						
						| 
							 | 
						            depth=depth, | 
					
					
						
						| 
							 | 
						            num_heads=num_heads, | 
					
					
						
						| 
							 | 
						            decoder_embed_dim=decoder_embed_dim, | 
					
					
						
						| 
							 | 
						            decoder_depth=decoder_depth, | 
					
					
						
						| 
							 | 
						            decoder_num_heads=decoder_num_heads, | 
					
					
						
						| 
							 | 
						            mlp_ratio=4., | 
					
					
						
						| 
							 | 
						            norm_layer=functools.partial(torch.nn.LayerNorm, eps=1e-6), | 
					
					
						
						| 
							 | 
						            norm_pix_loss=False) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) | 
					
					
						
						| 
							 | 
						    print(f"\n--> model has {total_params / 1e6} Million params.\n") | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    model.to(device) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    state_dict = torch.load(checkpoint, map_location=device) | 
					
					
						
						| 
							 | 
						    model.load_state_dict(state_dict) | 
					
					
						
						| 
							 | 
						    print(f"Loaded checkpoint from {checkpoint}") | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    model.eval() | 
					
					
						
						| 
							 | 
						    channels = [bands.index(b) for b in ['B04', 'B03', 'B02']]   | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    batch = torch.tensor(input_data, device='cpu') | 
					
					
						
						| 
							 | 
						    windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size) | 
					
					
						
						| 
							 | 
						    h1, w1 = windows.shape[3:5] | 
					
					
						
						| 
							 | 
						    windows = rearrange(windows, 'b c t h1 w1 h w -> (b h1 w1) c t h w', h=img_size, w=img_size) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1 | 
					
					
						
						| 
							 | 
						    windows = torch.tensor_split(windows, num_batches, dim=0) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    rec_imgs = [] | 
					
					
						
						| 
							 | 
						    mask_imgs = [] | 
					
					
						
						| 
							 | 
						    for x in windows: | 
					
					
						
						| 
							 | 
						        rec_img, mask_img = run_model(model, x, mask_ratio, device) | 
					
					
						
						| 
							 | 
						        rec_imgs.append(rec_img) | 
					
					
						
						| 
							 | 
						        mask_imgs.append(mask_img) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    rec_imgs = torch.concat(rec_imgs, dim=0) | 
					
					
						
						| 
							 | 
						    mask_imgs = torch.concat(mask_imgs, dim=0) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    rec_imgs = rearrange(rec_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)', | 
					
					
						
						| 
							 | 
						                         h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1) | 
					
					
						
						| 
							 | 
						    mask_imgs = rearrange(mask_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)', | 
					
					
						
						| 
							 | 
						                          h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    h, w = rec_imgs.shape[-2:] | 
					
					
						
						| 
							 | 
						    rec_imgs_full = batch.clone() | 
					
					
						
						| 
							 | 
						    rec_imgs_full[..., :h, :w] = rec_imgs | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    mask_imgs_full = torch.ones_like(batch) | 
					
					
						
						| 
							 | 
						    mask_imgs_full[..., :h, :w] = mask_imgs | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
						    for d in meta_data: | 
					
					
						
						| 
							 | 
						        d.update(count=3, dtype='uint8', compress='lzw', nodata=0) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    save_rgb_imgs(batch[0, ...], rec_imgs_full[0, ...], mask_imgs_full[0, ...], | 
					
					
						
						| 
							 | 
						                  channels, mean, std, output_dir, meta_data) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    print("Done!") | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						if __name__ == "__main__": | 
					
					
						
						| 
							 | 
						    parser = argparse.ArgumentParser('MAE run inference', add_help=False) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    parser.add_argument('--data_files', required=True, type=str, nargs='+', | 
					
					
						
						| 
							 | 
						                        help='Path to the data files. Assumes multi-band files.') | 
					
					
						
						| 
							 | 
						    parser.add_argument('--yaml_file_path', type=str, required=True, | 
					
					
						
						| 
							 | 
						                        help='Path to yaml file containing model training parameters.') | 
					
					
						
						| 
							 | 
						    parser.add_argument('--checkpoint', required=True, type=str, | 
					
					
						
						| 
							 | 
						                        help='Path to a checkpoint file to load from.') | 
					
					
						
						| 
							 | 
						    parser.add_argument('--output_dir', required=True, type=str, | 
					
					
						
						| 
							 | 
						                        help='Path to the directory where to save outputs.') | 
					
					
						
						| 
							 | 
						    parser.add_argument('--mask_ratio', default=None, type=float, | 
					
					
						
						| 
							 | 
						                        help='Masking ratio (percentage of removed patches). ' | 
					
					
						
						| 
							 | 
						                             'If None (default) use same value used for pretraining.') | 
					
					
						
						| 
							 | 
						    args = parser.parse_args() | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    main(**vars(args)) |