Upload model
- config.json +475 -0
- configuration_magi.py +38 -0
- modelling_magi.py +486 -0
- processing_magi.py +274 -0
- pytorch_model.bin +3 -0
- utils.py +391 -0
    	
config.json ADDED
@@ -0,0 +1,475 @@
{
  "_name_or_path": "to_push",
  "architectures": [
    "MagiModel"
  ],
  "auto_map": {
    "AutoConfig": "configuration_magi.MagiConfig",
    "AutoModel": "modelling_magi.MagiModel"
  },
  "crop_embedding_image_preprocessing_config": {
    "_processor_class": null,
    "do_normalize": true,
    "do_rescale": true,
    "do_resize": true,
    "image_mean": [
      0.485,
      0.456,
      0.406
    ],
    "image_processor_type": "ViTImageProcessor",
    "image_std": [
      0.229,
      0.224,
      0.225
    ],
    "resample": 2,
    "rescale_factor": 0.00392156862745098,
    "size": {
      "height": 224,
      "width": 224
    }
  },
  "crop_embedding_model_config": {
    "_name_or_path": "facebook/vit-mae-base",
    "add_cross_attention": false,
    "architectures": [
      "ViTMAEForPreTraining"
    ],
    "attention_probs_dropout_prob": 0.0,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_hidden_size": 512,
    "decoder_intermediate_size": 2048,
    "decoder_num_attention_heads": 16,
    "decoder_num_hidden_layers": 8,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": null,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.0,
    "hidden_size": 768,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "image_size": 224,
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-12,
    "length_penalty": 1.0,
    "mask_ratio": 0.75,
    "max_length": 20,
    "min_length": 0,
    "model_type": "",
    "no_repeat_ngram_size": 0,
    "norm_pix_loss": false,
    "num_attention_heads": 12,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_channels": 3,
    "num_hidden_layers": 12,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": null,
    "patch_size": 16,
    "prefix": null,
    "problem_type": null,
    "pruned_heads": {},
    "qkv_bias": true,
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "suppress_tokens": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": "float32",
    "torchscript": false,
    "typical_p": 1.0,
    "use_bfloat16": false
  },
  "detection_image_preprocessing_config": {
    "_processor_class": null,
    "do_normalize": true,
    "do_pad": true,
    "do_rescale": true,
    "do_resize": true,
    "format": "coco_detection",
    "image_mean": [
      0.485,
      0.456,
      0.406
    ],
    "image_processor_type": "ConditionalDetrImageProcessor",
    "image_std": [
      0.229,
      0.224,
      0.225
    ],
    "resample": 2,
    "rescale_factor": 0.00392156862745098,
    "size": {
      "longest_edge": 1333,
      "shortest_edge": 800
    }
  },
  "detection_model_config": {
    "_name_or_path": "microsoft/conditional-detr-resnet-50",
    "activation_dropout": 0.0,
    "activation_function": "relu",
    "add_cross_attention": false,
    "architectures": [
      "ConditionalDETRForObjectDetection"
    ],
    "attention_dropout": 0.0,
    "auxiliary_loss": false,
    "backbone": "resnet50",
    "backbone_config": null,
    "bad_words_ids": null,
    "bbox_cost": 5,
    "bbox_loss_coefficient": 5,
    "begin_suppress_tokens": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "class_cost": 2,
    "cls_loss_coefficient": 2,
    "cross_attention_hidden_size": null,
    "d_model": 256,
    "decoder_attention_heads": 8,
    "decoder_ffn_dim": 2048,
    "decoder_layerdrop": 0.0,
    "decoder_layers": 6,
    "decoder_start_token_id": null,
    "dice_loss_coefficient": 1,
    "dilation": false,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "dropout": 0.1,
    "early_stopping": false,
    "encoder_attention_heads": 8,
    "encoder_ffn_dim": 2048,
    "encoder_layerdrop": 0.0,
    "encoder_layers": 6,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": null,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "focal_alpha": 0.25,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "giou_cost": 2,
    "giou_loss_coefficient": 2,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1",
      "2": "LABEL_2"
    },
    "init_std": 0.02,
    "init_xavier_std": 1.0,
    "is_decoder": false,
    "is_encoder_decoder": true,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1,
      "LABEL_2": 2
    },
    "length_penalty": 1.0,
    "mask_loss_coefficient": 1,
    "max_length": 20,
    "max_position_embeddings": 1024,
    "min_length": 0,
    "model_type": "",
    "no_repeat_ngram_size": 0,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_channels": 3,
    "num_hidden_layers": 6,
    "num_queries": 305,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": null,
    "position_embedding_type": "sine",
    "prefix": null,
    "problem_type": null,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "scale_embedding": false,
    "sep_token_id": null,
    "suppress_tokens": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": "float32",
    "torchscript": false,
    "typical_p": 1.0,
    "use_bfloat16": false,
    "use_pretrained_backbone": true,
    "use_timm_backbone": true
  },
  "disable_crop_embeddings": false,
  "disable_detections": false,
  "disable_ocr": false,
  "model_type": "magi",
  "ocr_model_config": {
    "_name_or_path": "/work/rs/logs/manga_ocr/nt8rn2ul/",
    "add_cross_attention": false,
    "architectures": [
      "VisionEncoderDecoderModel"
    ],
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder": {
      "_name_or_path": "",
      "activation_dropout": 0.0,
      "activation_function": "gelu",
      "add_cross_attention": true,
      "architectures": null,
      "attention_dropout": 0.0,
      "bad_words_ids": null,
      "begin_suppress_tokens": null,
      "bos_token_id": 0,
      "chunk_size_feed_forward": 0,
      "classifier_dropout": 0.0,
      "cross_attention_hidden_size": 768,
      "d_model": 1024,
      "decoder_attention_heads": 16,
      "decoder_ffn_dim": 4096,
      "decoder_layerdrop": 0.0,
      "decoder_layers": 12,
      "decoder_start_token_id": 2,
      "diversity_penalty": 0.0,
      "do_sample": false,
      "dropout": 0.1,
      "early_stopping": false,
      "encoder_no_repeat_ngram_size": 0,
      "eos_token_id": 2,
      "exponential_decay_length_penalty": null,
      "finetuning_task": null,
      "forced_bos_token_id": null,
      "forced_eos_token_id": null,
      "id2label": {
        "0": "LABEL_0",
        "1": "LABEL_1"
      },
      "init_std": 0.02,
      "is_decoder": true,
      "is_encoder_decoder": false,
      "label2id": {
        "LABEL_0": 0,
        "LABEL_1": 1
      },
      "layernorm_embedding": true,
      "length_penalty": 1.0,
      "max_length": 20,
      "max_position_embeddings": 512,
      "min_length": 0,
      "model_type": "trocr",
      "no_repeat_ngram_size": 0,
      "num_beam_groups": 1,
      "num_beams": 1,
      "num_return_sequences": 1,
      "output_attentions": false,
      "output_hidden_states": false,
      "output_scores": false,
      "pad_token_id": 1,
      "prefix": null,
      "problem_type": null,
      "pruned_heads": {},
      "remove_invalid_values": false,
      "repetition_penalty": 1.0,
      "return_dict": true,
      "return_dict_in_generate": false,
      "scale_embedding": false,
      "sep_token_id": null,
      "suppress_tokens": null,
      "task_specific_params": null,
      "temperature": 1.0,
      "tf_legacy_loss": false,
      "tie_encoder_decoder": false,
      "tie_word_embeddings": true,
      "tokenizer_class": null,
      "top_k": 50,
      "top_p": 1.0,
      "torch_dtype": null,
      "torchscript": false,
      "typical_p": 1.0,
      "use_bfloat16": false,
      "use_cache": false,
      "use_learned_position_embeddings": true,
      "vocab_size": 50265
    },
    "decoder_start_token_id": 0,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": true,
    "encoder": {
      "_name_or_path": "",
      "add_cross_attention": false,
      "architectures": null,
      "attention_probs_dropout_prob": 0.0,
      "bad_words_ids": null,
      "begin_suppress_tokens": null,
      "bos_token_id": null,
      "chunk_size_feed_forward": 0,
      "cross_attention_hidden_size": null,
      "decoder_start_token_id": null,
      "diversity_penalty": 0.0,
      "do_sample": false,
      "early_stopping": false,
      "encoder_no_repeat_ngram_size": 0,
      "encoder_stride": 16,
      "eos_token_id": null,
      "exponential_decay_length_penalty": null,
      "finetuning_task": null,
      "forced_bos_token_id": null,
      "forced_eos_token_id": null,
      "hidden_act": "gelu",
      "hidden_dropout_prob": 0.0,
      "hidden_size": 768,
      "id2label": {
        "0": "LABEL_0",
        "1": "LABEL_1"
      },
      "image_size": 384,
      "initializer_range": 0.02,
      "intermediate_size": 3072,
      "is_decoder": false,
      "is_encoder_decoder": false,
      "label2id": {
        "LABEL_0": 0,
        "LABEL_1": 1
      },
      "layer_norm_eps": 1e-12,
      "length_penalty": 1.0,
      "max_length": 20,
      "min_length": 0,
      "model_type": "vit",
      "no_repeat_ngram_size": 0,
      "num_attention_heads": 12,
      "num_beam_groups": 1,
      "num_beams": 1,
      "num_channels": 3,
      "num_hidden_layers": 12,
      "num_return_sequences": 1,
      "output_attentions": false,
      "output_hidden_states": false,
      "output_scores": false,
      "pad_token_id": null,
      "patch_size": 16,
      "prefix": null,
      "problem_type": null,
      "pruned_heads": {},
      "qkv_bias": false,
      "remove_invalid_values": false,
      "repetition_penalty": 1.0,
      "return_dict": true,
      "return_dict_in_generate": false,
      "sep_token_id": null,
      "suppress_tokens": null,
      "task_specific_params": null,
      "temperature": 1.0,
      "tf_legacy_loss": false,
      "tie_encoder_decoder": false,
      "tie_word_embeddings": true,
      "tokenizer_class": null,
      "top_k": 50,
      "top_p": 1.0,
      "torch_dtype": null,
      "torchscript": false,
      "typical_p": 1.0,
      "use_bfloat16": false
    },
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": 2,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "is_decoder": false,
    "is_encoder_decoder": true,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "length_penalty": 2.0,
    "max_length": 300,
    "min_length": 0,
    "model_type": "vision-encoder-decoder",
    "no_repeat_ngram_size": 3,
    "num_beam_groups": 1,
    "num_beams": 4,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": 1,
    "prefix": null,
    "problem_type": null,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "suppress_tokens": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": false,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": "float32",
    "torchscript": false,
    "typical_p": 1.0,
    "use_bfloat16": false,
    "vocab_size": 50265
  },
  "ocr_pretrained_processor_path": "microsoft/trocr-base-printed",
  "torch_dtype": "float32",
  "transformers_version": "4.34.0.dev0"
}
    	
configuration_magi.py ADDED
@@ -0,0 +1,38 @@
from transformers import PretrainedConfig, VisionEncoderDecoderConfig
from typing import List


class MagiConfig(PretrainedConfig):
    model_type = "magi"

    def __init__(
        self,
        disable_ocr: bool = False,
        disable_crop_embeddings: bool = False,
        disable_detections: bool = False,
        detection_model_config: dict = None,
        ocr_model_config: dict = None,
        crop_embedding_model_config: dict = None,
        detection_image_preprocessing_config: dict = None,
        ocr_pretrained_processor_path: str = None,
        crop_embedding_image_preprocessing_config: dict = None,
        **kwargs,
    ):
        self.disable_ocr = disable_ocr
        self.disable_crop_embeddings = disable_crop_embeddings
        self.disable_detections = disable_detections

        self.detection_model_config = None
        self.ocr_model_config = None
        self.crop_embedding_model_config = None
        if detection_model_config is not None:
            self.detection_model_config = PretrainedConfig.from_dict(detection_model_config)
        if ocr_model_config is not None:
            self.ocr_model_config = VisionEncoderDecoderConfig.from_dict(ocr_model_config)
        if crop_embedding_model_config is not None:
            self.crop_embedding_model_config = PretrainedConfig.from_dict(crop_embedding_model_config)

        self.detection_image_preprocessing_config = detection_image_preprocessing_config
        self.ocr_pretrained_processor_path = ocr_pretrained_processor_path
        self.crop_embedding_image_preprocessing_config = crop_embedding_image_preprocessing_config
        super().__init__(**kwargs)
    	
modelling_magi.py ADDED
@@ -0,0 +1,486 @@
from transformers import PreTrainedModel, VisionEncoderDecoderModel, ViTMAEModel, ConditionalDetrModel
from transformers.models.conditional_detr.modeling_conditional_detr import (
    ConditionalDetrMLPPredictionHead,
    ConditionalDetrModelOutput,
    ConditionalDetrHungarianMatcher,
    inverse_sigmoid,
)
from .configuration_magi import MagiConfig
from .processing_magi import MagiProcessor
from torch import nn
from typing import Optional, List
import torch
from einops import rearrange, repeat, einsum
from .utils import move_to_device, visualise_single_image_prediction, sort_panels, sort_text_boxes_in_reading_order

class MagiModel(PreTrainedModel):
    config_class = MagiConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.processor = MagiProcessor(config)
        if not config.disable_ocr:
            self.ocr_model = VisionEncoderDecoderModel(config.ocr_model_config)
        if not config.disable_crop_embeddings:
            self.crop_embedding_model = ViTMAEModel(config.crop_embedding_model_config)
        if not config.disable_detections:
            self.num_non_obj_tokens = 5
            self.detection_transformer = ConditionalDetrModel(config.detection_model_config)
            self.bbox_predictor = ConditionalDetrMLPPredictionHead(
                input_dim=config.detection_model_config.d_model,
                hidden_dim=config.detection_model_config.d_model,
                output_dim=4, num_layers=3
            )
            self.is_this_text_a_dialogue = ConditionalDetrMLPPredictionHead(
                input_dim=config.detection_model_config.d_model,
                hidden_dim=config.detection_model_config.d_model,
                output_dim=1,
                num_layers=3
            )
            self.character_character_matching_head = ConditionalDetrMLPPredictionHead(
                input_dim = 3 * config.detection_model_config.d_model + (2 * config.crop_embedding_model_config.hidden_size if not config.disable_crop_embeddings else 0),
                hidden_dim=config.detection_model_config.d_model,
                output_dim=1, num_layers=3
            )
            self.text_character_matching_head = ConditionalDetrMLPPredictionHead(
                input_dim = 3 * config.detection_model_config.d_model,
                hidden_dim=config.detection_model_config.d_model,
                output_dim=1, num_layers=3
            )
            self.class_labels_classifier = nn.Linear(
                config.detection_model_config.d_model, config.detection_model_config.num_labels
            )
            self.matcher = ConditionalDetrHungarianMatcher(
                class_cost=config.detection_model_config.class_cost,
                bbox_cost=config.detection_model_config.bbox_cost,
                giou_cost=config.detection_model_config.giou_cost
            )

    def move_to_device(self, input):
        return move_to_device(input, self.device)

    def predict_detections_and_associations(
            self,
            images,
            move_to_device_fn=None,
            character_detection_threshold=0.3,
            panel_detection_threshold=0.2,
            text_detection_threshold=0.25,
            character_character_matching_threshold=0.7,
            text_character_matching_threshold=0.4,
        ):
        assert not self.config.disable_detections
        move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn

        inputs_to_detection_transformer = self.processor.preprocess_inputs_for_detection(images)
        inputs_to_detection_transformer = move_to_device_fn(inputs_to_detection_transformer)

        detection_transformer_output = self._get_detection_transformer_output(**inputs_to_detection_transformer)
        predicted_class_scores, predicted_bboxes = self._get_predicted_bboxes_and_classes(detection_transformer_output)

        # create callback fn
        def get_character_character_matching_scores(batch_character_indices, batch_bboxes):
            predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(detection_transformer_output)
            predicted_c2c_tokens_for_batch = self._get_predicted_c2c_tokens(detection_transformer_output)
            crop_bboxes = [batch_bboxes[i][batch_character_indices[i]] for i in range(len(batch_character_indices))]
            crop_embeddings_for_batch = self.predict_crop_embeddings(images, crop_bboxes, move_to_device_fn)
            character_obj_tokens_for_batch = []
            c2c_tokens_for_batch = []
            for predicted_obj_tokens, predicted_c2c_tokens, character_indices in zip(predicted_obj_tokens_for_batch, predicted_c2c_tokens_for_batch, batch_character_indices):
                character_obj_tokens_for_batch.append(predicted_obj_tokens[character_indices])
                c2c_tokens_for_batch.append(predicted_c2c_tokens)
            return self._get_character_character_affinity_matrices(
                character_obj_tokens_for_batch=character_obj_tokens_for_batch,
                crop_embeddings_for_batch=crop_embeddings_for_batch,
                c2c_tokens_for_batch=c2c_tokens_for_batch,
                apply_sigmoid=True,
            )

        # create callback fn
        def get_text_character_matching_scores(batch_text_indices, batch_character_indices):
            predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(detection_transformer_output)
            predicted_t2c_tokens_for_batch = self._get_predicted_t2c_tokens(detection_transformer_output)
            text_obj_tokens_for_batch = []
            character_obj_tokens_for_batch = []
            t2c_tokens_for_batch = []
            for predicted_obj_tokens, predicted_t2c_tokens, text_indices, character_indices in zip(predicted_obj_tokens_for_batch, predicted_t2c_tokens_for_batch, batch_text_indices, batch_character_indices):
                text_obj_tokens_for_batch.append(predicted_obj_tokens[text_indices])
                character_obj_tokens_for_batch.append(predicted_obj_tokens[character_indices])
                t2c_tokens_for_batch.append(predicted_t2c_tokens)
            return self._get_text_character_affinity_matrices(
                character_obj_tokens_for_batch=character_obj_tokens_for_batch,
                text_obj_tokens_for_this_batch=text_obj_tokens_for_batch,
                t2c_tokens_for_batch=t2c_tokens_for_batch,
                apply_sigmoid=True,
            )

        # create callback fn
        def get_dialog_confidence_scores(batch_text_indices):
            predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(detection_transformer_output)
            dialog_confidence = []
            for predicted_obj_tokens, text_indices in zip(predicted_obj_tokens_for_batch, batch_text_indices):
                confidence = self.is_this_text_a_dialogue(predicted_obj_tokens[text_indices]).sigmoid()
                dialog_confidence.append(rearrange(confidence, "i 1 -> i"))
            return dialog_confidence

        return self.processor.postprocess_detections_and_associations(
            predicted_bboxes=predicted_bboxes,
            predicted_class_scores=predicted_class_scores,
            original_image_sizes=torch.stack([torch.tensor(img.shape[:2]) for img in images], dim=0).to(predicted_bboxes.device),
            get_character_character_matching_scores=get_character_character_matching_scores,
            get_text_character_matching_scores=get_text_character_matching_scores,
            get_dialog_confidence_scores=get_dialog_confidence_scores,
            character_detection_threshold=character_detection_threshold,
            panel_detection_threshold=panel_detection_threshold,
            text_detection_threshold=text_detection_threshold,
            character_character_matching_threshold=character_character_matching_threshold,
            text_character_matching_threshold=text_character_matching_threshold,
        )

    def predict_crop_embeddings(self, images, crop_bboxes, move_to_device_fn=None, mask_ratio=0.0, batch_size=256):
         | 
| 142 | 
            +
                    if self.config.disable_crop_embeddings:
         | 
| 143 | 
            +
                        return None
         | 
| 144 | 
            +
                    
         | 
| 145 | 
            +
                    assert isinstance(crop_bboxes, List), "please provide a list of bboxes for each image to get embeddings for"
         | 
| 146 | 
            +
                    
         | 
| 147 | 
            +
                    move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
         | 
| 148 | 
            +
                    
         | 
| 149 | 
            +
                    # temporarily change the mask ratio from default to the one specified
         | 
| 150 | 
            +
                    old_mask_ratio = self.crop_embedding_model.embeddings.config.mask_ratio
         | 
| 151 | 
            +
                    self.crop_embedding_model.embeddings.config.mask_ratio = mask_ratio
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                    crops_per_image = []
         | 
| 154 | 
            +
                    num_crops_per_batch = [len(bboxes) for bboxes in crop_bboxes]
         | 
| 155 | 
            +
                    for image, bboxes, num_crops in zip(images, crop_bboxes, num_crops_per_batch):
         | 
| 156 | 
            +
                        crops = self.processor.crop_image(image, bboxes)
         | 
| 157 | 
            +
                        assert len(crops) == num_crops
         | 
| 158 | 
            +
                        crops_per_image.extend(crops)
         | 
| 159 | 
            +
                    
         | 
| 160 | 
            +
                    if len(crops_per_image) == 0:
         | 
| 161 | 
            +
                        return [[] for _ in crop_bboxes]
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                    crops_per_image = self.processor.preprocess_inputs_for_crop_embeddings(crops_per_image)
         | 
| 164 | 
            +
                    crops_per_image = move_to_device_fn(crops_per_image)
         | 
| 165 | 
            +
                    
         | 
| 166 | 
            +
                    # process the crops in batches to avoid OOM
         | 
| 167 | 
            +
                    embeddings = []
         | 
| 168 | 
            +
                    for i in range(0, len(crops_per_image), batch_size):
         | 
| 169 | 
            +
                        crops = crops_per_image[i:i+batch_size]
         | 
| 170 | 
            +
                        embeddings_per_batch = self.crop_embedding_model(crops).last_hidden_state[:, 0]
         | 
| 171 | 
            +
                        embeddings.append(embeddings_per_batch)
         | 
| 172 | 
            +
                    embeddings = torch.cat(embeddings, dim=0)
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                    crop_embeddings_for_batch = []
         | 
| 175 | 
            +
                    for num_crops in num_crops_per_batch:
         | 
| 176 | 
            +
                        crop_embeddings_for_batch.append(embeddings[:num_crops])
         | 
| 177 | 
            +
                        embeddings = embeddings[num_crops:]
         | 
| 178 | 
            +
                    
         | 
| 179 | 
            +
                    # restore the mask ratio to the default
         | 
| 180 | 
            +
                    self.crop_embedding_model.embeddings.config.mask_ratio = old_mask_ratio
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                    return crop_embeddings_for_batch
         | 
| 183 | 
            +
                
         | 
| 184 | 
            +
                def predict_ocr(self, images, crop_bboxes, move_to_device_fn=None, use_tqdm=False, batch_size=32):
         | 
| 185 | 
            +
                    assert not self.config.disable_ocr
         | 
| 186 | 
            +
                    move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
         | 
| 187 | 
            +
             | 
| 188 | 
            +
                    crops_per_image = []
         | 
| 189 | 
            +
                    num_crops_per_batch = [len(bboxes) for bboxes in crop_bboxes]
         | 
| 190 | 
            +
                    for image, bboxes, num_crops in zip(images, crop_bboxes, num_crops_per_batch):
         | 
| 191 | 
            +
                        crops = self.processor.crop_image(image, bboxes)
         | 
| 192 | 
            +
                        assert len(crops) == num_crops
         | 
| 193 | 
            +
                        crops_per_image.extend(crops)
         | 
| 194 | 
            +
                    
         | 
| 195 | 
            +
                    if len(crops_per_image) == 0:
         | 
| 196 | 
            +
                        return [[] for _ in crop_bboxes]
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                    crops_per_image = self.processor.preprocess_inputs_for_ocr(crops_per_image)
         | 
| 199 | 
            +
                    crops_per_image = move_to_device_fn(crops_per_image)
         | 
| 200 | 
            +
                    
         | 
| 201 | 
            +
                    # process the crops in batches to avoid OOM
         | 
| 202 | 
            +
                    all_generated_texts = []
         | 
| 203 | 
            +
                    if use_tqdm:
         | 
| 204 | 
            +
                        from tqdm import tqdm
         | 
| 205 | 
            +
                        pbar = tqdm(range(0, len(crops_per_image), batch_size))
         | 
| 206 | 
            +
                    else:
         | 
| 207 | 
            +
                        pbar = range(0, len(crops_per_image), batch_size)
         | 
| 208 | 
            +
                    for i in pbar:
         | 
| 209 | 
            +
                        crops = crops_per_image[i:i+batch_size]
         | 
| 210 | 
            +
                        generated_ids = self.ocr_model.generate(crops)
         | 
| 211 | 
            +
                        generated_texts = self.processor.postprocess_ocr_tokens(generated_ids)
         | 
| 212 | 
            +
                        all_generated_texts.extend(generated_texts)
         | 
| 213 | 
            +
             | 
| 214 | 
            +
                    texts_for_images = []
         | 
| 215 | 
            +
                    for num_crops in num_crops_per_batch:
         | 
| 216 | 
            +
                        texts_for_images.append([x.replace("\n", "") for x in all_generated_texts[:num_crops]])
         | 
| 217 | 
            +
                        all_generated_texts = all_generated_texts[num_crops:]
         | 
| 218 | 
            +
             | 
| 219 | 
            +
                    return texts_for_images
         | 
| 220 | 
            +
                
         | 
| 221 | 
            +
                def visualise_single_image_prediction(
         | 
| 222 | 
            +
                        self, image_as_np_array, predictions, filename=None
         | 
| 223 | 
            +
                ):
         | 
| 224 | 
            +
                    return visualise_single_image_prediction(image_as_np_array, predictions, filename)
         | 
| 225 | 
            +
             | 
| 226 | 
            +
                def generate_transcript_for_single_image(
         | 
| 227 | 
            +
                        self, predictions, ocr_results, filename=None
         | 
| 228 | 
            +
                ):
         | 
| 229 | 
            +
        character_clusters = predictions["character_cluster_labels"]
        text_to_character = predictions["text_character_associations"]
        text_to_character = {k: v for k, v in text_to_character}
        transcript = " ### Transcript ###\n"
        for index, text in enumerate(ocr_results):
            if index in text_to_character:
                speaker = character_clusters[text_to_character[index]]
                speaker = f"<{speaker}>"
            else:
                speaker = "<?>"
            transcript += f"{speaker}: {text}\n"
        if filename is not None:
            with open(filename, "w") as file:
                file.write(transcript)
        return transcript

    def get_text_character_affinity_matrices_given_annotations(
            self, images, annotations, move_to_device_fn=None, apply_sigmoid=True
    ):
        assert not self.config.disable_detections
        move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn

        inputs_to_detection_transformer = self.processor.preprocess_inputs_for_detection(images, annotations)
        inputs_to_detection_transformer = move_to_device_fn(inputs_to_detection_transformer)
        processed_targets = inputs_to_detection_transformer.pop("labels")

        detection_transformer_output = self._get_detection_transformer_output(**inputs_to_detection_transformer)
        predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(detection_transformer_output)
        predicted_t2c_tokens_for_batch = self._get_predicted_t2c_tokens(detection_transformer_output)

        predicted_class_scores, predicted_bboxes = self._get_predicted_bboxes_and_classes(detection_transformer_output)
        matching_dict = {
            "logits": predicted_class_scores,
            "pred_boxes": predicted_bboxes,
        }
        indices = self.matcher(matching_dict, processed_targets)

        matched_char_obj_tokens_for_batch = []
        matched_text_obj_tokens_for_batch = []
        t2c_tokens_for_batch = []

        text_bboxes_for_batch = []
        character_bboxes_for_batch = []

        for j, (pred_idx, tgt_idx) in enumerate(indices):
            target_idx_to_pred_idx = {tgt.item(): pred.item() for pred, tgt in zip(pred_idx, tgt_idx)}
            targets_for_this_image = processed_targets[j]
            indices_of_text_boxes_in_annotation = [i for i, label in enumerate(targets_for_this_image["class_labels"]) if label == 1]
            indices_of_char_boxes_in_annotation = [i for i, label in enumerate(targets_for_this_image["class_labels"]) if label == 0]
            predicted_text_indices = [target_idx_to_pred_idx[i] for i in indices_of_text_boxes_in_annotation]
            predicted_char_indices = [target_idx_to_pred_idx[i] for i in indices_of_char_boxes_in_annotation]

            text_bboxes_for_batch.append(
                [annotations[j]["bboxes_as_x1y1x2y2"][k] for k in indices_of_text_boxes_in_annotation]
            )
            character_bboxes_for_batch.append(
                [annotations[j]["bboxes_as_x1y1x2y2"][k] for k in indices_of_char_boxes_in_annotation]
            )

            matched_char_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_char_indices])
            matched_text_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_text_indices])
            t2c_tokens_for_batch.append(predicted_t2c_tokens_for_batch[j])

        text_character_affinity_matrices = self._get_text_character_affinity_matrices(
            character_obj_tokens_for_batch=matched_char_obj_tokens_for_batch,
            text_obj_tokens_for_this_batch=matched_text_obj_tokens_for_batch,
            t2c_tokens_for_batch=t2c_tokens_for_batch,
            apply_sigmoid=apply_sigmoid,
        )

        return {
            "text_character_affinity_matrices": text_character_affinity_matrices,
            "text_bboxes_for_batch": text_bboxes_for_batch,
            "character_bboxes_for_batch": character_bboxes_for_batch,
        }

    def get_obj_embeddings_corresponding_to_given_annotations(
            self, images, annotations, move_to_device_fn=None
    ):
        assert not self.config.disable_detections
        move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn

        inputs_to_detection_transformer = self.processor.preprocess_inputs_for_detection(images, annotations)
        inputs_to_detection_transformer = move_to_device_fn(inputs_to_detection_transformer)
        processed_targets = inputs_to_detection_transformer.pop("labels")

        detection_transformer_output = self._get_detection_transformer_output(**inputs_to_detection_transformer)
        predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(detection_transformer_output)
        predicted_t2c_tokens_for_batch = self._get_predicted_t2c_tokens(detection_transformer_output)
        predicted_c2c_tokens_for_batch = self._get_predicted_c2c_tokens(detection_transformer_output)

        predicted_class_scores, predicted_bboxes = self._get_predicted_bboxes_and_classes(detection_transformer_output)
        matching_dict = {
            "logits": predicted_class_scores,
            "pred_boxes": predicted_bboxes,
        }
        indices = self.matcher(matching_dict, processed_targets)

        matched_char_obj_tokens_for_batch = []
        matched_text_obj_tokens_for_batch = []
        matched_panel_obj_tokens_for_batch = []
        t2c_tokens_for_batch = []
        c2c_tokens_for_batch = []

        for j, (pred_idx, tgt_idx) in enumerate(indices):
            target_idx_to_pred_idx = {tgt.item(): pred.item() for pred, tgt in zip(pred_idx, tgt_idx)}
            targets_for_this_image = processed_targets[j]
            indices_of_char_boxes_in_annotation = [i for i, label in enumerate(targets_for_this_image["class_labels"]) if label == 0]
            indices_of_text_boxes_in_annotation = [i for i, label in enumerate(targets_for_this_image["class_labels"]) if label == 1]
            indices_of_panel_boxes_in_annotation = [i for i, label in enumerate(targets_for_this_image["class_labels"]) if label == 2]
            predicted_text_indices = [target_idx_to_pred_idx[i] for i in indices_of_text_boxes_in_annotation]
            predicted_char_indices = [target_idx_to_pred_idx[i] for i in indices_of_char_boxes_in_annotation]
            predicted_panel_indices = [target_idx_to_pred_idx[i] for i in indices_of_panel_boxes_in_annotation]

            matched_char_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_char_indices])
            matched_text_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_text_indices])
            matched_panel_obj_tokens_for_batch.append(predicted_obj_tokens_for_batch[j][predicted_panel_indices])
            t2c_tokens_for_batch.append(predicted_t2c_tokens_for_batch[j])
            c2c_tokens_for_batch.append(predicted_c2c_tokens_for_batch[j])

        return {
            "character": matched_char_obj_tokens_for_batch,
            "text": matched_text_obj_tokens_for_batch,
            "panel": matched_panel_obj_tokens_for_batch,
            "t2c": t2c_tokens_for_batch,
            "c2c": c2c_tokens_for_batch,
        }

    def sort_panels_and_text_bboxes_in_reading_order(
        self,
        batch_panel_bboxes,
        batch_text_bboxes,
    ):
        batch_sorted_panel_indices = []
        batch_sorted_text_indices = []
        for batch_index in range(len(batch_text_bboxes)):
            panel_bboxes = batch_panel_bboxes[batch_index]
            text_bboxes = batch_text_bboxes[batch_index]
            sorted_panel_indices = sort_panels(panel_bboxes)
            sorted_panels = [panel_bboxes[i] for i in sorted_panel_indices]
            sorted_text_indices = sort_text_boxes_in_reading_order(text_bboxes, sorted_panels)
            batch_sorted_panel_indices.append(sorted_panel_indices)
            batch_sorted_text_indices.append(sorted_text_indices)
        return batch_sorted_panel_indices, batch_sorted_text_indices

    def _get_detection_transformer_output(
            self,
            pixel_values: torch.FloatTensor,
            pixel_mask: Optional[torch.LongTensor] = None
    ):
        if self.config.disable_detections:
            raise ValueError("Detection model is disabled. Set disable_detections=False in the config.")
        return self.detection_transformer(
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
            return_dict=True
        )

    def _get_predicted_obj_tokens(
            self,
            detection_transformer_output: ConditionalDetrModelOutput
    ):
        return detection_transformer_output.last_hidden_state[:, :-self.num_non_obj_tokens]

    def _get_predicted_c2c_tokens(
            self,
            detection_transformer_output: ConditionalDetrModelOutput
    ):
        return detection_transformer_output.last_hidden_state[:, -self.num_non_obj_tokens]

    def _get_predicted_t2c_tokens(
            self,
            detection_transformer_output: ConditionalDetrModelOutput
    ):
        return detection_transformer_output.last_hidden_state[:, -self.num_non_obj_tokens+1]

    def _get_predicted_bboxes_and_classes(
            self,
            detection_transformer_output: ConditionalDetrModelOutput,
    ):
        if self.config.disable_detections:
            raise ValueError("Detection model is disabled. Set disable_detections=False in the config.")

        obj = self._get_predicted_obj_tokens(detection_transformer_output)

        predicted_class_scores = self.class_labels_classifier(obj)
        reference = detection_transformer_output.reference_points[:-self.num_non_obj_tokens]
        reference_before_sigmoid = inverse_sigmoid(reference).transpose(0, 1)
        predicted_boxes = self.bbox_predictor(obj)
        predicted_boxes[..., :2] += reference_before_sigmoid
        predicted_boxes = predicted_boxes.sigmoid()

        return predicted_class_scores, predicted_boxes

    def _get_character_character_affinity_matrices(
            self,
            character_obj_tokens_for_batch: List[torch.FloatTensor] = None,
            crop_embeddings_for_batch: List[torch.FloatTensor] = None,
            c2c_tokens_for_batch: List[torch.FloatTensor] = None,
            apply_sigmoid=True,
    ):
        assert self.config.disable_detections or (character_obj_tokens_for_batch is not None and c2c_tokens_for_batch is not None)
        assert self.config.disable_crop_embeddings or crop_embeddings_for_batch is not None
        assert not self.config.disable_detections or not self.config.disable_crop_embeddings

        if self.config.disable_detections:
            affinity_matrices = []
            for crop_embeddings in crop_embeddings_for_batch:
                crop_embeddings = crop_embeddings / crop_embeddings.norm(dim=-1, keepdim=True)
                affinity_matrix = crop_embeddings @ crop_embeddings.T  # cosine similarity between the normalised crop embeddings
                affinity_matrices.append(affinity_matrix)
            return affinity_matrices
        affinity_matrices = []
        for batch_index, (character_obj_tokens, c2c) in enumerate(zip(character_obj_tokens_for_batch, c2c_tokens_for_batch)):
            if character_obj_tokens.shape[0] == 0:
                affinity_matrices.append(torch.zeros(0, 0).type_as(character_obj_tokens))
                continue
            if not self.config.disable_crop_embeddings:
                crop_embeddings = crop_embeddings_for_batch[batch_index]
                assert character_obj_tokens.shape[0] == crop_embeddings.shape[0]
                character_obj_tokens = torch.cat([character_obj_tokens, crop_embeddings], dim=-1)
            char_i = repeat(character_obj_tokens, "i d -> i repeat d", repeat=character_obj_tokens.shape[0])
            char_j = repeat(character_obj_tokens, "j d -> repeat j d", repeat=character_obj_tokens.shape[0])
            char_ij = rearrange([char_i, char_j], "two i j d -> (i j) (two d)")
            c2c = repeat(c2c, "d -> repeat d", repeat=char_ij.shape[0])
            char_ij_c2c = torch.cat([char_ij, c2c], dim=-1)
            character_character_affinities = self.character_character_matching_head(char_ij_c2c)
            character_character_affinities = rearrange(character_character_affinities, "(i j) 1 -> i j", i=char_i.shape[0])
            if apply_sigmoid:
                character_character_affinities = character_character_affinities.sigmoid()
            affinity_matrices.append(character_character_affinities)
        return affinity_matrices

    def _get_text_character_affinity_matrices(
            self,
            character_obj_tokens_for_batch: List[torch.FloatTensor] = None,
            text_obj_tokens_for_this_batch: List[torch.FloatTensor] = None,
            t2c_tokens_for_batch: List[torch.FloatTensor] = None,
            apply_sigmoid=True,
    ):
        assert not self.config.disable_detections
        assert character_obj_tokens_for_batch is not None and text_obj_tokens_for_this_batch is not None and t2c_tokens_for_batch is not None
        affinity_matrices = []
        for character_obj_tokens, text_obj_tokens, t2c in zip(character_obj_tokens_for_batch, text_obj_tokens_for_this_batch, t2c_tokens_for_batch):
            if character_obj_tokens.shape[0] == 0 or text_obj_tokens.shape[0] == 0:
                affinity_matrices.append(torch.zeros(text_obj_tokens.shape[0], character_obj_tokens.shape[0]).type_as(character_obj_tokens))
                continue
            text_i = repeat(text_obj_tokens, "i d -> i repeat d", repeat=character_obj_tokens.shape[0])
            char_j = repeat(character_obj_tokens, "j d -> repeat j d", repeat=text_obj_tokens.shape[0])
            text_char = rearrange([text_i, char_j], "two i j d -> (i j) (two d)")
            t2c = repeat(t2c, "d -> repeat d", repeat=text_char.shape[0])
            text_char_t2c = torch.cat([text_char, t2c], dim=-1)
            text_character_affinities = self.text_character_matching_head(text_char_t2c)
            text_character_affinities = rearrange(text_character_affinities, "(i j) 1 -> i j", i=text_i.shape[0])
            if apply_sigmoid:
                text_character_affinities = text_character_affinities.sigmoid()
            affinity_matrices.append(text_character_affinities)
        return affinity_matrices
    	
        processing_magi.py
    ADDED
    
@@ -0,0 +1,274 @@
from transformers import ConditionalDetrImageProcessor, TrOCRProcessor, ViTImageProcessor
from transformers.image_transforms import center_to_corners_format
import torch
from typing import List
from shapely.geometry import box
from .utils import UnionFind, sort_panels, sort_text_boxes_in_reading_order, x1y1x2y2_to_xywh
import numpy as np

class MagiProcessor():
    def __init__(self, config):
        self.config = config
        self.detection_image_preprocessor = None
        self.ocr_preprocessor = None
        self.crop_embedding_image_preprocessor = None
        if not config.disable_detections:
            assert config.detection_image_preprocessing_config is not None
            self.detection_image_preprocessor = ConditionalDetrImageProcessor.from_dict(config.detection_image_preprocessing_config)
        if not config.disable_ocr:
            assert config.ocr_pretrained_processor_path is not None
            self.ocr_preprocessor = TrOCRProcessor.from_pretrained(config.ocr_pretrained_processor_path)
        if not config.disable_crop_embeddings:
            assert config.crop_embedding_image_preprocessing_config is not None
            self.crop_embedding_image_preprocessor = ViTImageProcessor.from_dict(config.crop_embedding_image_preprocessing_config)

    def preprocess_inputs_for_detection(self, images, annotations=None):
        images = list(images)
        assert isinstance(images[0], np.ndarray)
        annotations = self._convert_annotations_to_coco_format(annotations)
        inputs = self.detection_image_preprocessor(images, annotations=annotations, return_tensors="pt")
        return inputs

    def preprocess_inputs_for_ocr(self, images):
        images = list(images)
        assert isinstance(images[0], np.ndarray)
        return self.ocr_preprocessor(images, return_tensors="pt").pixel_values

    def preprocess_inputs_for_crop_embeddings(self, images):
        images = list(images)
        assert isinstance(images[0], np.ndarray)
        return self.crop_embedding_image_preprocessor(images, return_tensors="pt").pixel_values

    def postprocess_detections_and_associations(
            self,
            predicted_bboxes,
            predicted_class_scores,
            original_image_sizes,
            get_character_character_matching_scores,
            get_text_character_matching_scores,
            get_dialog_confidence_scores,
            character_detection_threshold=0.3,
            panel_detection_threshold=0.2,
            text_detection_threshold=0.25,
            character_character_matching_threshold=0.7,
            text_character_matching_threshold=0.4,
        ):
        assert self.config.disable_detections is False
        batch_scores, batch_labels = predicted_class_scores.max(-1)
        batch_scores = batch_scores.sigmoid()
        batch_labels = batch_labels.long()
        batch_bboxes = center_to_corners_format(predicted_bboxes)

        # scale the bboxes back to the original image size
        if isinstance(original_image_sizes, List):
            img_h = torch.Tensor([i[0] for i in original_image_sizes])
            img_w = torch.Tensor([i[1] for i in original_image_sizes])
        else:
            img_h, img_w = original_image_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(batch_bboxes.device)
        batch_bboxes = batch_bboxes * scale_fct[:, None, :]

        batch_panel_indices = self._get_indices_of_panels_to_keep(batch_scores, batch_labels, batch_bboxes, panel_detection_threshold)
        batch_character_indices = self._get_indices_of_characters_to_keep(batch_scores, batch_labels, batch_bboxes, character_detection_threshold)
        batch_text_indices = self._get_indices_of_texts_to_keep(batch_scores, batch_labels, batch_bboxes, text_detection_threshold)

        batch_character_character_matching_scores = get_character_character_matching_scores(batch_character_indices, batch_bboxes)
        batch_text_character_matching_scores = get_text_character_matching_scores(batch_text_indices, batch_character_indices)
        batch_dialog_confidence_scores = get_dialog_confidence_scores(batch_text_indices)

        # sort panels and texts in the reading order
        for batch_index in range(len(batch_scores)):
            panel_bboxes = batch_bboxes[batch_index][batch_panel_indices[batch_index]]
            panel_scores = batch_scores[batch_index][batch_panel_indices[batch_index]]
            text_bboxes = batch_bboxes[batch_index][batch_text_indices[batch_index]]
            text_scores = batch_scores[batch_index][batch_text_indices[batch_index]]

            sorted_panel_indices = sort_panels(panel_bboxes)
            batch_bboxes[batch_index][batch_panel_indices[batch_index]] = panel_bboxes[sorted_panel_indices]
            batch_scores[batch_index][batch_panel_indices[batch_index]] = panel_scores[sorted_panel_indices]
            sorted_panels = batch_bboxes[batch_index][batch_panel_indices[batch_index]]

            sorted_text_indices = sort_text_boxes_in_reading_order(text_bboxes, sorted_panels)
            batch_bboxes[batch_index][batch_text_indices[batch_index]] = text_bboxes[sorted_text_indices]
            batch_scores[batch_index][batch_text_indices[batch_index]] = text_scores[sorted_text_indices]
            batch_text_character_matching_scores[batch_index] = batch_text_character_matching_scores[batch_index][sorted_text_indices]
            batch_dialog_confidence_scores[batch_index] = batch_dialog_confidence_scores[batch_index][sorted_text_indices]

        results = []
        for batch_index in range(len(batch_scores)):
            panel_bboxes = batch_bboxes[batch_index][batch_panel_indices[batch_index]]
            panel_scores = batch_scores[batch_index][batch_panel_indices[batch_index]]
            text_bboxes = batch_bboxes[batch_index][batch_text_indices[batch_index]]
            text_scores = batch_scores[batch_index][batch_text_indices[batch_index]]
            character_bboxes = batch_bboxes[batch_index][batch_character_indices[batch_index]]
            character_scores = batch_scores[batch_index][batch_character_indices[batch_index]]
            char_i, char_j = torch.where(batch_character_character_matching_scores[batch_index] > character_character_matching_threshold)
            character_character_associations = torch.stack([char_i, char_j], dim=1)
            text_boxes_to_match = batch_dialog_confidence_scores[batch_index] > text_character_matching_threshold
            if 0 in batch_text_character_matching_scores[batch_index].shape:
                text_character_associations = torch.zeros((0, 2), dtype=torch.long)
            else:
                most_likely_speaker_for_each_text = torch.argmax(batch_text_character_matching_scores[batch_index], dim=1)[text_boxes_to_match]
                text_indices = torch.arange(len(text_bboxes)).type_as(most_likely_speaker_for_each_text)[text_boxes_to_match]
                text_character_associations = torch.stack([text_indices, most_likely_speaker_for_each_text], dim=1)

            character_ufds = UnionFind.from_adj_matrix(
                batch_character_character_matching_scores[batch_index] > character_character_matching_threshold
            )
            results.append({
                "panels": panel_bboxes.tolist(),
                "panel_scores": panel_scores.tolist(),
                "texts": text_bboxes.tolist(),
                "text_scores": text_scores.tolist(),
                "characters": character_bboxes.tolist(),
                "character_scores": character_scores.tolist(),
                "character_character_associations": character_character_associations.tolist(),
                "text_character_associations": text_character_associations.tolist(),
                "character_cluster_labels": character_ufds.get_labels_for_connected_components(),
                "dialog_confidences": batch_dialog_confidence_scores[batch_index].tolist(),
            })
        return results

    def postprocess_ocr_tokens(self, generated_ids, skip_special_tokens=True):
        return self.ocr_preprocessor.batch_decode(generated_ids, skip_special_tokens=skip_special_tokens)

    def crop_image(self, image, bboxes):
        crops_for_image = []
        for bbox in bboxes:
            x1, y1, x2, y2 = bbox

            # fix the bounding box in case it is out of bounds or too small
            x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
            x1, y1, x2, y2 = min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)  # just in case
            x1, y1 = max(0, x1), max(0, y1)
            x1, y1 = min(image.shape[1], x1), min(image.shape[0], y1)
            x2, y2 = max(0, x2), max(0, y2)
            x2, y2 = min(image.shape[1], x2), min(image.shape[0], y2)
            if x2 - x1 < 10:
                if image.shape[1] - x1 > 10:
                    x2 = x1 + 10
                else:
                    x1 = x2 - 10
            if y2 - y1 < 10:
                if image.shape[0] - y1 > 10:
                    y2 = y1 + 10
                else:
                    y1 = y2 - 10

            crop = image[y1:y2, x1:x2]
            crops_for_image.append(crop)
        return crops_for_image

    def _get_indices_of_characters_to_keep(self, batch_scores, batch_labels, batch_bboxes, character_detection_threshold):
        indices_of_characters_to_keep = []
        for scores, labels, _ in zip(batch_scores, batch_labels, batch_bboxes):
            indices = torch.where((labels == 0) & (scores > character_detection_threshold))[0]
            indices_of_characters_to_keep.append(indices)
        return indices_of_characters_to_keep

    def _get_indices_of_panels_to_keep(self, batch_scores, batch_labels, batch_bboxes, panel_detection_threshold):
        indices_of_panels_to_keep = []
        for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes):
            indices = torch.where(labels == 2)[0]
            bboxes = bboxes[indices]
            scores = scores[indices]
            labels = labels[indices]
            if len(indices) == 0:
                indices_of_panels_to_keep.append([])
                continue
            scores, labels, indices, bboxes = zip(*sorted(zip(scores, labels, indices, bboxes), reverse=True))
            panels_to_keep = []
            union_of_panels_so_far = box(0, 0, 0, 0)
            for ps, pb, pl, pi in zip(scores, bboxes, labels, indices):
                panel_polygon = box(pb[0], pb[1], pb[2], pb[3])
                if ps < panel_detection_threshold:
                    continue
                if union_of_panels_so_far.intersection(panel_polygon).area / panel_polygon.area > 0.5:
                    continue
                panels_to_keep.append((ps, pl, pb, pi))
                union_of_panels_so_far = union_of_panels_so_far.union(panel_polygon)
            indices_of_panels_to_keep.append([p[3].item() for p in panels_to_keep])
        return indices_of_panels_to_keep

    def _get_indices_of_texts_to_keep(self, batch_scores, batch_labels, batch_bboxes, text_detection_threshold):
        indices_of_texts_to_keep = []
        for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes):
            indices = torch.where((labels == 1) & (scores > text_detection_threshold))[0]
            bboxes = bboxes[indices]
            scores = scores[indices]
                        labels = labels[indices]
         | 
| 200 | 
            +
                        if len(indices) == 0:
         | 
| 201 | 
            +
                            indices_of_texts_to_keep.append([])
         | 
| 202 | 
            +
                            continue
         | 
| 203 | 
            +
                        scores, labels, indices, bboxes  = zip(*sorted(zip(scores, labels, indices, bboxes), reverse=True))
         | 
| 204 | 
            +
                        texts_to_keep = []
         | 
| 205 | 
            +
                        texts_to_keep_as_shapely_objects = []
         | 
| 206 | 
            +
                        for ts, tb, tl, ti in zip(scores, bboxes, labels, indices):
         | 
| 207 | 
            +
                            text_polygon = box(tb[0], tb[1], tb[2], tb[3])
         | 
| 208 | 
            +
                            should_append = True
         | 
| 209 | 
            +
                            for t in texts_to_keep_as_shapely_objects:
         | 
| 210 | 
            +
                                if t.intersection(text_polygon).area / t.union(text_polygon).area > 0.5:
         | 
| 211 | 
            +
                                    should_append = False
         | 
| 212 | 
            +
                                    break
         | 
| 213 | 
            +
                            if should_append:
         | 
| 214 | 
            +
                                texts_to_keep.append((ts, tl, tb, ti))
         | 
| 215 | 
            +
                                texts_to_keep_as_shapely_objects.append(text_polygon)
         | 
| 216 | 
            +
                        indices_of_texts_to_keep.append([t[3].item() for t in texts_to_keep])
         | 
| 217 | 
            +
                    return indices_of_texts_to_keep
         | 
| 218 | 
            +
                    
         | 
| 219 | 
            +
                def _convert_annotations_to_coco_format(self, annotations):
         | 
| 220 | 
            +
                    if annotations is None:
         | 
| 221 | 
            +
                        return None
         | 
| 222 | 
            +
                    self._verify_annotations_are_in_correct_format(annotations)
         | 
| 223 | 
            +
                    coco_annotations = []
         | 
| 224 | 
            +
                    for annotation in annotations:
         | 
| 225 | 
            +
                        coco_annotation = {
         | 
| 226 | 
            +
                            "image_id": annotation["image_id"],
         | 
| 227 | 
            +
                            "annotations": [],
         | 
| 228 | 
            +
                        }
         | 
| 229 | 
            +
                        for bbox, label in zip(annotation["bboxes_as_x1y1x2y2"], annotation["labels"]):
         | 
| 230 | 
            +
                            coco_annotation["annotations"].append({
         | 
| 231 | 
            +
                                "bbox": x1y1x2y2_to_xywh(bbox),
         | 
| 232 | 
            +
                                "category_id": label,
         | 
| 233 | 
            +
                                "area": (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]),
         | 
| 234 | 
            +
                            })
         | 
| 235 | 
            +
                        coco_annotations.append(coco_annotation)
         | 
| 236 | 
            +
                    return coco_annotations
         | 
| 237 | 
            +
                
         | 
| 238 | 
            +
                def _verify_annotations_are_in_correct_format(self, annotations):
         | 
| 239 | 
            +
                    error_msg = """
         | 
| 240 | 
            +
                    Annotations must be in the following format:
         | 
| 241 | 
            +
                    [
         | 
| 242 | 
            +
                        {
         | 
| 243 | 
            +
                            "image_id": 0,
         | 
| 244 | 
            +
                            "bboxes_as_x1y1x2y2": [[0, 0, 10, 10], [10, 10, 20, 20], [20, 20, 30, 30]],
         | 
| 245 | 
            +
                            "labels": [0, 1, 2],
         | 
| 246 | 
            +
                        },
         | 
| 247 | 
            +
                        ...
         | 
| 248 | 
            +
                    ]
         | 
| 249 | 
            +
                    Labels: 0 for characters, 1 for text, 2 for panels.
         | 
| 250 | 
            +
                    """
         | 
| 251 | 
            +
                    if annotations is None:
         | 
| 252 | 
            +
                        return
         | 
| 253 | 
            +
                    if not isinstance(annotations, List) and not isinstance(annotations, tuple):
         | 
| 254 | 
            +
                        raise ValueError(
         | 
| 255 | 
            +
                            f"{error_msg} Expected a List/Tuple, found {type(annotations)}."
         | 
| 256 | 
            +
                        )
         | 
| 257 | 
            +
                    if len(annotations) == 0:
         | 
| 258 | 
            +
                        return
         | 
| 259 | 
            +
                    if not isinstance(annotations[0], dict):
         | 
| 260 | 
            +
                        raise ValueError(
         | 
| 261 | 
            +
                            f"{error_msg} Expected a List[Dict], found {type(annotations[0])}."
         | 
| 262 | 
            +
                        )
         | 
| 263 | 
            +
                    if "image_id" not in annotations[0]:
         | 
| 264 | 
            +
                        raise ValueError(
         | 
| 265 | 
            +
                            f"{error_msg} Dict must contain 'image_id'."
         | 
| 266 | 
            +
                        )
         | 
| 267 | 
            +
                    if "bboxes_as_x1y1x2y2" not in annotations[0]:
         | 
| 268 | 
            +
                        raise ValueError(
         | 
| 269 | 
            +
                            f"{error_msg} Dict must contain 'bboxes_as_x1y1x2y2'."
         | 
| 270 | 
            +
                        )
         | 
| 271 | 
            +
                    if "labels" not in annotations[0]:
         | 
| 272 | 
            +
                        raise ValueError(
         | 
| 273 | 
            +
                            f"{error_msg} Dict must contain 'labels'."
         | 
| 274 | 
            +
                        )
         | 
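
For reference, a minimal sketch of the annotation format that _verify_annotations_are_in_correct_format documents above; the image id, boxes, and labels below are illustrative placeholder values, not data from this repository:

example_annotations = [
    {
        "image_id": 0,
        "bboxes_as_x1y1x2y2": [[12, 30, 180, 220], [200, 40, 320, 110]],  # placeholder boxes
        "labels": [0, 1],  # 0 = character, 1 = text, 2 = panel
    },
]
# Entries in this form are what _convert_annotations_to_coco_format accepts; it turns each
# one into COCO-style records with "bbox" as [x, y, w, h], "category_id", and "area".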
    	
        pytorch_model.bin
    ADDED
    
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:219c2b80e741b1d02e92f22701a38358a5606d6460ad8b6335091e909b212011
size 2063428286
    	
        utils.py
    ADDED
    
@@ -0,0 +1,391 @@
import torch
import numpy as np
import random
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from shapely.geometry import Point, box
import networkx as nx
from copy import deepcopy
from itertools import groupby

def move_to_device(inputs, device):
    # recursively move tensors (and dicts/lists/tuples/arrays of tensors) to the given device
    if hasattr(inputs, "keys"):
        return {k: move_to_device(v, device) for k, v in inputs.items()}
    elif isinstance(inputs, list):
        return [move_to_device(v, device) for v in inputs]
    elif isinstance(inputs, tuple):
        return tuple([move_to_device(v, device) for v in inputs])
    elif isinstance(inputs, np.ndarray):
        return torch.from_numpy(inputs).to(device)
    else:
        return inputs.to(device)

class UnionFind:
    # disjoint-set (union-find) with path compression and union by size
    def __init__(self, n):
        self.parent = list(range(n))
        self.size = [1] * n
        self.num_components = n

    @classmethod
    def from_adj_matrix(cls, adj_matrix):
        ufds = cls(adj_matrix.shape[0])
        for i in range(adj_matrix.shape[0]):
            for j in range(adj_matrix.shape[1]):
                if adj_matrix[i, j] > 0:
                    ufds.unite(i, j)
        return ufds

    @classmethod
    def from_adj_list(cls, adj_list):
        ufds = cls(len(adj_list))
        for i in range(len(adj_list)):
            for j in adj_list[i]:
                ufds.unite(i, j)
        return ufds

    @classmethod
    def from_edge_list(cls, edge_list, num_nodes):
        ufds = cls(num_nodes)
        for edge in edge_list:
            ufds.unite(edge[0], edge[1])
        return ufds

    def find(self, x):
        if self.parent[x] == x:
            return x
        self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def unite(self, x, y):
        x = self.find(x)
        y = self.find(y)
        if x != y:
            if self.size[x] < self.size[y]:
                x, y = y, x
            self.parent[y] = x
            self.size[x] += self.size[y]
            self.num_components -= 1

    def get_components_of(self, x):
        x = self.find(x)
        return [i for i in range(len(self.parent)) if self.find(i) == x]

    def are_connected(self, x, y):
        return self.find(x) == self.find(y)

    def get_size(self, x):
        return self.size[self.find(x)]

    def get_num_components(self):
        return self.num_components

    def get_labels_for_connected_components(self):
        map_parent_to_label = {}
        labels = []
        for i in range(len(self.parent)):
            parent = self.find(i)
            if parent not in map_parent_to_label:
                map_parent_to_label[parent] = len(map_parent_to_label)
            labels.append(map_parent_to_label[parent])
        return labels
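
A small usage sketch of the UnionFind helper above; the edge list and node count are made up, and the import assumes utils.py is importable as a local module:

from utils import UnionFind

# nodes 0-4; suppose matches were found between (0, 1), (1, 2) and (3, 4)
ufds = UnionFind.from_edge_list([(0, 1), (1, 2), (3, 4)], num_nodes=5)
print(ufds.get_num_components())                   # 2
print(ufds.are_connected(0, 2))                    # True
print(ufds.get_labels_for_connected_components())  # [0, 0, 0, 1, 1]

This is the same kind of labelling that the model's character_cluster_labels output is built from.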
def visualise_single_image_prediction(image_as_np_array, predictions, filename):
    figure, subplot = plt.subplots(1, 1, figsize=(10, 10))
    subplot.imshow(image_as_np_array)
    plot_bboxes(subplot, predictions["panels"], color="green")
    plot_bboxes(subplot, predictions["texts"], color="red", add_index=True)
    plot_bboxes(subplot, predictions["characters"], color="blue")

    COLOURS = [
        "#b7ff51",  # green
        "#f50a8f",  # pink
        "#4b13b6",  # purple
        "#ddaa34",  # orange
        "#bea2a2",  # brown
    ]
    colour_index = 0
    character_cluster_labels = predictions["character_cluster_labels"]
    unique_label_sorted_by_frequency = sorted(list(set(character_cluster_labels)), key=lambda x: character_cluster_labels.count(x), reverse=True)
    for label in unique_label_sorted_by_frequency:
        # pick a "root" character for this cluster and connect the others to it
        root = None
        others = []
        for i in range(len(predictions["characters"])):
            if character_cluster_labels[i] == label:
                if root is None:
                    root = i
                else:
                    others.append(i)
        if colour_index >= len(COLOURS):
            # preset colours exhausted; generate a random colour not already in the list
            random_colour = COLOURS[0]
            while random_colour in COLOURS:
                random_colour = "#" + "".join([random.choice("0123456789ABCDEF") for j in range(6)])
        else:
            random_colour = COLOURS[colour_index]
            colour_index += 1
        bbox_i = predictions["characters"][root]
        x1 = bbox_i[0] + (bbox_i[2] - bbox_i[0]) / 2
        y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
        subplot.plot([x1], [y1], color=random_colour, marker="o", markersize=5)
        for j in others:
            # draw line from centre of bbox i to centre of bbox j
            bbox_j = predictions["characters"][j]
            x1 = bbox_i[0] + (bbox_i[2] - bbox_i[0]) / 2
            y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
            x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
            y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
            subplot.plot([x1, x2], [y1, y2], color=random_colour, linewidth=2)
            subplot.plot([x2], [y2], color=random_colour, marker="o", markersize=5)

    # draw dashed text-to-character association lines, with opacity given by the dialog confidence
    for (i, j) in predictions["text_character_associations"]:
        score = predictions["dialog_confidences"][i]
        bbox_i = predictions["texts"][i]
        bbox_j = predictions["characters"][j]
        x1 = bbox_i[0] + (bbox_i[2] - bbox_i[0]) / 2
        y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
        x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
        y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
        subplot.plot([x1, x2], [y1, y2], color="red", linewidth=2, linestyle="dashed", alpha=score)

    subplot.axis("off")
    if filename is not None:
        plt.savefig(filename, bbox_inches="tight", pad_inches=0)

    figure.canvas.draw()
    image = np.array(figure.canvas.renderer._renderer)
    plt.close()
    return image

def plot_bboxes(subplot, bboxes, color="red", add_index=False):
    for id, bbox in enumerate(bboxes):
        w = bbox[2] - bbox[0]
        h = bbox[3] - bbox[1]
        rect = patches.Rectangle(
            bbox[:2], w, h, linewidth=1, edgecolor=color, facecolor="none", linestyle="solid"
        )
        subplot.add_patch(rect)
        if add_index:
            cx, cy = bbox[0] + w / 2, bbox[1] + h / 2
            subplot.text(cx, cy, str(id), color=color, fontsize=10, ha="center", va="center")

def sort_panels(rects):
    before_rects = convert_to_list_of_lists(rects)
    # slightly erode all rectangles initially to account for imperfect detections
    rects = [erode_rectangle(rect, 0.05) for rect in before_rects]
    G = nx.DiGraph()
    G.add_nodes_from(range(len(rects)))
    for i in range(len(rects)):
        for j in range(len(rects)):
            if i == j:
                continue
            if is_there_a_directed_edge(i, j, rects):
                G.add_edge(i, j, weight=get_distance(rects[i], rects[j]))
            else:
                G.add_edge(j, i, weight=get_distance(rects[i], rects[j]))
    # break any cycles by removing the heaviest (most distant) edge in each cycle,
    # then read the panel order off a topological sort
    while True:
        cycles = sorted(nx.simple_cycles(G))
        cycles = [cycle for cycle in cycles if len(cycle) > 1]
        if len(cycles) == 0:
            break
        cycle = cycles[0]
        edges = [e for e in zip(cycle, cycle[1:] + cycle[:1])]
        max_cyclic_edge = max(edges, key=lambda x: G.edges[x]["weight"])
        G.remove_edge(*max_cyclic_edge)
    return list(nx.topological_sort(G))
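
A hedged usage sketch of sort_panels with four made-up panel boxes arranged in a 2x2 grid (boxes are [x1, y1, x2, y2] with y growing downwards); under the edge rules defined below, the expected order is top rows first and, within a row, right before left:

from utils import sort_panels

panels = [
    [0, 0, 90, 90],        # 0: top-left
    [110, 0, 200, 90],     # 1: top-right
    [0, 110, 90, 200],     # 2: bottom-left
    [110, 110, 200, 200],  # 3: bottom-right
]
print(sort_panels(panels))  # [1, 0, 3, 2]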
def is_strictly_above(rectA, rectB):
    x1A, y1A, x2A, y2A = rectA
    x1B, y1B, x2B, y2B = rectB
    return y2A < y1B

def is_strictly_below(rectA, rectB):
    x1A, y1A, x2A, y2A = rectA
    x1B, y1B, x2B, y2B = rectB
    return y2B < y1A

def is_strictly_left_of(rectA, rectB):
    x1A, y1A, x2A, y2A = rectA
    x1B, y1B, x2B, y2B = rectB
    return x2A < x1B

def is_strictly_right_of(rectA, rectB):
    x1A, y1A, x2A, y2A = rectA
    x1B, y1B, x2B, y2B = rectB
    return x2B < x1A

def intersects(rectA, rectB):
    return box(*rectA).intersects(box(*rectB))

def is_there_a_directed_edge(a, b, rects):
    # decide whether panel a should be read before panel b:
    # panels strictly above come first; within a row, panels to the right come first
    rectA = rects[a]
    rectB = rects[b]
    centre_of_A = [rectA[0] + (rectA[2] - rectA[0]) / 2, rectA[1] + (rectA[3] - rectA[1]) / 2]
    centre_of_B = [rectB[0] + (rectB[2] - rectB[0]) / 2, rectB[1] + (rectB[3] - rectB[1]) / 2]
    if np.allclose(np.array(centre_of_A), np.array(centre_of_B)):
        return box(*rectA).area > (box(*rectB)).area
    copy_A = [rectA[0], rectA[1], rectA[2], rectA[3]]
    copy_B = [rectB[0], rectB[1], rectB[2], rectB[3]]
    while True:
        if is_strictly_above(copy_A, copy_B) and not is_strictly_left_of(copy_A, copy_B):
            return 1
        if is_strictly_above(copy_B, copy_A) and not is_strictly_left_of(copy_B, copy_A):
            return 0
        if is_strictly_right_of(copy_A, copy_B) and not is_strictly_below(copy_A, copy_B):
            return 1
        if is_strictly_right_of(copy_B, copy_A) and not is_strictly_below(copy_B, copy_A):
            return 0
        if is_strictly_below(copy_A, copy_B) and is_strictly_right_of(copy_A, copy_B):
            return use_cuts_to_determine_edge_from_a_to_b(a, b, rects)
        if is_strictly_below(copy_B, copy_A) and is_strictly_right_of(copy_B, copy_A):
            return use_cuts_to_determine_edge_from_a_to_b(a, b, rects)
        # otherwise they intersect; erode both and try again
        copy_A = erode_rectangle(copy_A, 0.05)
        copy_B = erode_rectangle(copy_B, 0.05)

def get_distance(rectA, rectB):
    return box(rectA[0], rectA[1], rectA[2], rectA[3]).distance(box(rectB[0], rectB[1], rectB[2], rectB[3]))

def use_cuts_to_determine_edge_from_a_to_b(a, b, rects):
    rects = deepcopy(rects)
    while True:
        xmin, ymin, xmax, ymax = min(rects[a][0], rects[b][0]), min(rects[a][1], rects[b][1]), max(rects[a][2], rects[b][2]), max(rects[a][3], rects[b][3])
        rect_index = [i for i in range(len(rects)) if intersects(rects[i], [xmin, ymin, xmax, ymax])]
        rects_copy = [rect for rect in rects if intersects(rect, [xmin, ymin, xmax, ymax])]

        # try to split the panels using "horizontal" lines
        overlapping_y_ranges = merge_overlapping_ranges([(y1, y2) for x1, y1, x2, y2 in rects_copy])
        panel_index_to_split = {}
        for split_index, (y1, y2) in enumerate(overlapping_y_ranges):
            for i, index in enumerate(rect_index):
                if y1 <= rects_copy[i][1] <= rects_copy[i][3] <= y2:
                    panel_index_to_split[index] = split_index

        if panel_index_to_split[a] != panel_index_to_split[b]:
            return panel_index_to_split[a] < panel_index_to_split[b]

        # try to split the panels using "vertical" lines
        overlapping_x_ranges = merge_overlapping_ranges([(x1, x2) for x1, y1, x2, y2 in rects_copy])
        panel_index_to_split = {}
        for split_index, (x1, x2) in enumerate(overlapping_x_ranges[::-1]):
            for i, index in enumerate(rect_index):
                if x1 <= rects_copy[i][0] <= rects_copy[i][2] <= x2:
                    panel_index_to_split[index] = split_index
        if panel_index_to_split[a] != panel_index_to_split[b]:
            return panel_index_to_split[a] < panel_index_to_split[b]

        # otherwise, erode the rectangles and try again
        rects = [erode_rectangle(rect, 0.05) for rect in rects]

def erode_rectangle(bbox, erosion_factor):
    # shrink the box about its centre; the shorter side is eroded proportionally less
    x1, y1, x2, y2 = bbox
    w, h = x2 - x1, y2 - y1
    cx, cy = x1 + w / 2, y1 + h / 2
    if w < h:
        aspect_ratio = w / h
        erosion_factor_width = erosion_factor * aspect_ratio
        erosion_factor_height = erosion_factor
    else:
        aspect_ratio = h / w
        erosion_factor_width = erosion_factor
        erosion_factor_height = erosion_factor * aspect_ratio
    w = w - w * erosion_factor_width
    h = h - h * erosion_factor_height
    x1, y1, x2, y2 = cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2
    return [x1, y1, x2, y2]

def merge_overlapping_ranges(ranges):
    """
    ranges: list of tuples (x1, x2); returns the ranges with overlapping ones merged
    """
    if len(ranges) == 0:
        return []
    ranges = sorted(ranges, key=lambda x: x[0])
    merged_ranges = []
    for i, r in enumerate(ranges):
        if i == 0:
            prev_x1, prev_x2 = r
            continue
        x1, x2 = r
        if x1 > prev_x2:
            merged_ranges.append((prev_x1, prev_x2))
            prev_x1, prev_x2 = x1, x2
        else:
            prev_x2 = max(prev_x2, x2)
    merged_ranges.append((prev_x1, prev_x2))
    return merged_ranges
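
A tiny example of merge_overlapping_ranges with made-up intervals: (3, 6) and (1, 4) overlap and merge into (1, 6), while (8, 10) stays separate:

from utils import merge_overlapping_ranges

print(merge_overlapping_ranges([(3, 6), (1, 4), (8, 10)]))  # [(1, 6), (8, 10)]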
def sort_text_boxes_in_reading_order(text_bboxes, sorted_panel_bboxes):
    text_bboxes = convert_to_list_of_lists(text_bboxes)
    sorted_panel_bboxes = convert_to_list_of_lists(sorted_panel_bboxes)

    if len(text_bboxes) == 0:
        return []

    def indices_of_same_elements(nums):
        groups = groupby(range(len(nums)), key=lambda i: nums[i])
        return [list(indices) for _, indices in groups]

    # assign each text box to a panel, order texts by panel, then sort within each panel
    panel_id_for_text = get_text_to_panel_mapping(text_bboxes, sorted_panel_bboxes)
    indices_of_texts = list(range(len(text_bboxes)))
    indices_of_texts, panel_id_for_text = zip(*sorted(zip(indices_of_texts, panel_id_for_text), key=lambda x: x[1]))
    indices_of_texts = list(indices_of_texts)
    grouped_indices = indices_of_same_elements(panel_id_for_text)
    for group in grouped_indices:
        subset_of_text_indices = [indices_of_texts[i] for i in group]
        text_bboxes_of_subset = [text_bboxes[i] for i in subset_of_text_indices]
        sorted_subset_indices = sort_texts_within_panel(text_bboxes_of_subset)
        indices_of_texts[group[0] : group[-1] + 1] = [subset_of_text_indices[i] for i in sorted_subset_indices]
    return indices_of_texts

def get_text_to_panel_mapping(text_bboxes, sorted_panel_bboxes):
    # map each text box to the panel it overlaps most; if it overlaps none, use the closest panel
    text_to_panel_mapping = []
    for text_bbox in text_bboxes:
        shapely_text_polygon = box(*text_bbox)
        all_intersections = []
        all_distances = []
        if len(sorted_panel_bboxes) == 0:
            text_to_panel_mapping.append(-1)
            continue
        for j, annotation in enumerate(sorted_panel_bboxes):
            shapely_annotation_polygon = box(*annotation)
            if shapely_text_polygon.intersects(shapely_annotation_polygon):
                all_intersections.append((shapely_text_polygon.intersection(shapely_annotation_polygon).area, j))
            all_distances.append((shapely_text_polygon.distance(shapely_annotation_polygon), j))
        if len(all_intersections) == 0:
            text_to_panel_mapping.append(min(all_distances, key=lambda x: x[0])[1])
        else:
            text_to_panel_mapping.append(max(all_intersections, key=lambda x: x[0])[1])
    return text_to_panel_mapping

def sort_texts_within_panel(rects):
    smallest_y = float("inf")
    greatest_x = float("-inf")
    for i, rect in enumerate(rects):
        x1, y1, x2, y2 = rect
        smallest_y = min(smallest_y, y1)
        greatest_x = max(greatest_x, x2)

    # reference point is the top-right corner of the texts' combined extent
    reference_point = Point(greatest_x, smallest_y)

    polygons_and_index = []
    for i, rect in enumerate(rects):
        x1, y1, x2, y2 = rect
        polygons_and_index.append((box(x1, y1, x2, y2), i))
    # sort text boxes by distance to the reference point (closest first)
    polygons_and_index = sorted(polygons_and_index, key=lambda x: reference_point.distance(x[0]))
    indices = [x[1] for x in polygons_and_index]
    return indices

def x1y1wh_to_x1y1x2y2(bbox):
    x1, y1, w, h = bbox
    return [x1, y1, x1 + w, y1 + h]

def x1y1x2y2_to_xywh(bbox):
    x1, y1, x2, y2 = bbox
    return [x1, y1, x2 - x1, y2 - y1]

def convert_to_list_of_lists(rects):
    if isinstance(rects, torch.Tensor):
        return rects.tolist()
    if isinstance(rects, np.ndarray):
        return rects.tolist()
    return [[a, b, c, d] for a, b, c, d in rects]
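
A hedged end-to-end sketch of the reading-order utilities above, using made-up boxes for a two-panel page; the exact outputs depend on the box geometry, but for this layout the right-hand panel and its text come first:

from utils import sort_panels, sort_text_boxes_in_reading_order

panel_bboxes = [[0, 0, 90, 90], [110, 0, 200, 90]]    # left panel, right panel
text_bboxes = [[10, 10, 40, 30], [150, 10, 190, 30]]  # one text box inside each panel

panel_order = sort_panels(panel_bboxes)                              # [1, 0]
sorted_panels = [panel_bboxes[i] for i in panel_order]
print(sort_text_boxes_in_reading_order(text_bboxes, sorted_panels))  # [1, 0]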
