Add HungarianMatcher (#2)
Browse files
- added conditional_detr hungarian matcher (6da1a70d484731a1f9182e8d786f2aa5efed4f05)
Co-authored-by: Emanuele Vivoli <[email protected]>
- conditional_detr_utils.py +179 -0
- modelling_magi.py +4 -3
conditional_detr_utils.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 Microsoft Research Asia and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Hungarian matcher and bounding-box IoU helpers for the PyTorch Conditional DETR model."""
|
| 16 |
+
|
| 17 |
+
from transformers.utils import (
|
| 18 |
+
is_scipy_available,
|
| 19 |
+
is_vision_available,
|
| 20 |
+
logging
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
from torch import Tensor, nn
|
| 25 |
+
|
| 26 |
+
if is_scipy_available():
|
| 27 |
+
from scipy.optimize import linear_sum_assignment
|
| 28 |
+
|
| 29 |
+
if is_vision_available():
|
| 30 |
+
from transformers.image_transforms import center_to_corners_format
|
| 31 |
+
|
| 32 |
+
logger = logging.get_logger(__name__)
|
| 33 |
+
|
| 34 |
+
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrHungarianMatcher with DeformableDetr->ConditionalDetr
|
| 35 |
+
class ConditionalDetrHungarianMatcher(nn.Module):
    """
    Computes a one-to-one assignment between the ground-truth targets and the predictions of the network.

    For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more
    predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are
    un-matched (and thus treated as non-objects).

    Args:
        class_cost:
            The relative weight of the classification error in the matching cost.
        bbox_cost:
            The relative weight of the L1 error of the bounding box coordinates in the matching cost.
        giou_cost:
            The relative weight of the giou loss of the bounding box in the matching cost.

    Raises:
        ValueError: If `class_cost`, `bbox_cost` and `giou_cost` are all 0.
    """

    def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1):
        super().__init__()

        self.class_cost = class_cost
        self.bbox_cost = bbox_cost
        self.giou_cost = giou_cost
        if class_cost == 0 and bbox_cost == 0 and giou_cost == 0:
            raise ValueError("All costs of the Matcher can't be 0")

    @torch.no_grad()
    def forward(self, outputs, targets):
        """
        Args:
            outputs (`dict`):
                A dictionary that contains at least these entries:
                * "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                * "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates.
            targets (`List[dict]`):
                A list of targets (len(targets) = batch_size), where each target is a dict containing:
                * "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of
                  ground-truth objects in the target) containing the class labels
                * "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates.

        Returns:
            `List[Tuple]`: A list of size `batch_size`, containing tuples of (index_i, index_j) where:
            - index_i is the indices of the selected predictions (in order)
            - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes)

        Raises:
            ImportError: If scipy or the transformers vision backend is not available.
        """
        # `linear_sum_assignment` and `center_to_corners_format` are only imported at module level when their
        # backend is available. Fail with an actionable message here instead of an opaque NameError below.
        # The check lives in `forward` (not `__init__`) so that merely constructing the matcher never
        # requires these optional dependencies.
        if not is_scipy_available():
            raise ImportError(
                "ConditionalDetrHungarianMatcher.forward requires scipy (`scipy.optimize.linear_sum_assignment`), "
                "but scipy is not available in this environment."
            )
        if not is_vision_available():
            raise ImportError(
                "ConditionalDetrHungarianMatcher.forward requires `transformers.image_transforms."
                "center_to_corners_format`, but the transformers vision backend is not available in this "
                "environment."
            )

        batch_size, num_queries = outputs["logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        out_prob = outputs["logits"].flatten(0, 1).sigmoid()  # [batch_size * num_queries, num_classes]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

        # Also concat the target labels and boxes
        target_ids = torch.cat([v["class_labels"] for v in targets])
        target_bbox = torch.cat([v["boxes"] for v in targets])

        # Compute the classification cost (focal-loss-style weighting of the per-class sigmoid probabilities).
        alpha = 0.25
        gamma = 2.0
        neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
        pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
        class_cost = pos_cost_class[:, target_ids] - neg_cost_class[:, target_ids]

        # Compute the L1 cost between boxes
        bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)

        # Compute the giou cost between boxes (negated: a higher GIoU means a cheaper match)
        giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))

        # Final cost matrix
        cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
        cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()

        # Split the target axis back into per-image chunks and solve one assignment per batch element:
        # chunk `i` holds the columns of image `i`'s targets, and `c[i]` selects that image's queries.
        sizes = [len(v["boxes"]) for v in targets]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
# Copied from transformers.models.detr.modeling_detr._upcast
|
| 115 |
+
def _upcast(t: Tensor) -> Tensor:
|
| 116 |
+
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
|
| 117 |
+
if t.is_floating_point():
|
| 118 |
+
return t if t.dtype in (torch.float32, torch.float64) else t.float()
|
| 119 |
+
else:
|
| 120 |
+
return t if t.dtype in (torch.int32, torch.int64) else t.int()
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
# Copied from transformers.models.detr.modeling_detr.box_area
|
| 124 |
+
def box_area(boxes: Tensor) -> Tensor:
    """
    Computes the area of a set of bounding boxes given by their (x1, y1, x2, y2) corner coordinates.

    Args:
        boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
            Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with
            `0 <= x1 < x2` and `0 <= y1 < y2`.

    Returns:
        `torch.FloatTensor`: a tensor containing the area for each box.
    """
    # Upcast first so that width * height is computed in the wider dtype.
    corners = _upcast(boxes)
    widths = corners[:, 2] - corners[:, 0]
    heights = corners[:, 3] - corners[:, 1]
    return widths * heights
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
# Copied from transformers.models.detr.modeling_detr.box_iou
|
| 141 |
+
def box_iou(boxes1, boxes2):
    """
    Computes the pairwise intersection-over-union between two sets of boxes in (x1, y1, x2, y2) corner format.

    Args:
        boxes1: Tensor indexed as `[N, 4]`.
        boxes2: Tensor indexed as `[M, 4]`.

    Returns:
        A tuple `(iou, union)` of two `[N, M]` tensors: the pairwise IoU and the pairwise union area.
    """
    areas_a = box_area(boxes1)
    areas_b = box_area(boxes2)

    # Corners of the overlap rectangle for every (boxes1[i], boxes2[j]) pair; `None` adds the broadcast axis.
    overlap_min = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    overlap_max = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    # Non-overlapping pairs would yield negative extents; clamp them to an empty intersection.
    overlap_wh = (overlap_max - overlap_min).clamp(min=0)  # [N,M,2]
    inter = overlap_wh[:, :, 0] * overlap_wh[:, :, 1]  # [N,M]

    union = areas_a[:, None] + areas_b - inter

    return inter / union, union
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# Copied from transformers.models.detr.modeling_detr.generalized_box_iou
|
| 158 |
+
def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.

    Returns:
        `torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)

    Raises:
        ValueError: If any box in `boxes1` or `boxes2` has x1 < x0 or y1 < y0.
    """
    # Degenerate boxes give inf / nan results, so validate both inputs up front.
    boxes1_well_formed = (boxes1[:, 2:] >= boxes1[:, :2]).all()
    if not boxes1_well_formed:
        raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}")
    boxes2_well_formed = (boxes2[:, 2:] >= boxes2[:, :2]).all()
    if not boxes2_well_formed:
        raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}")

    iou, union = box_iou(boxes1, boxes2)

    # Smallest axis-aligned box enclosing each (boxes1[i], boxes2[j]) pair.
    hull_min = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    hull_max = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
    hull_wh = (hull_max - hull_min).clamp(min=0)  # [N,M,2]
    hull_area = hull_wh[:, :, 0] * hull_wh[:, :, 1]

    # GIoU = IoU minus the fraction of the enclosing box not covered by the union.
    return iou - (hull_area - union) / hull_area
|
modelling_magi.py
CHANGED
|
@@ -2,15 +2,15 @@ from transformers import PreTrainedModel, VisionEncoderDecoderModel, ViTMAEModel
|
|
| 2 |
from transformers.models.conditional_detr.modeling_conditional_detr import (
|
| 3 |
ConditionalDetrMLPPredictionHead,
|
| 4 |
ConditionalDetrModelOutput,
|
| 5 |
-
ConditionalDetrHungarianMatcher,
|
| 6 |
inverse_sigmoid,
|
| 7 |
)
|
|
|
|
| 8 |
from .configuration_magi import MagiConfig
|
| 9 |
from .processing_magi import MagiProcessor
|
| 10 |
from torch import nn
|
| 11 |
from typing import Optional, List
|
| 12 |
import torch
|
| 13 |
-
from einops import rearrange, repeat
|
| 14 |
from .utils import move_to_device, visualise_single_image_prediction, sort_panels, sort_text_boxes_in_reading_order
|
| 15 |
|
| 16 |
class MagiModel(PreTrainedModel):
|
|
@@ -498,4 +498,5 @@ class MagiModel(PreTrainedModel):
|
|
| 498 |
if apply_sigmoid:
|
| 499 |
text_character_affinities = text_character_affinities.sigmoid()
|
| 500 |
affinity_matrices.append(text_character_affinities)
|
| 501 |
-
return affinity_matrices
|
|
|
|
|
|
| 2 |
from transformers.models.conditional_detr.modeling_conditional_detr import (
|
| 3 |
ConditionalDetrMLPPredictionHead,
|
| 4 |
ConditionalDetrModelOutput,
|
|
|
|
| 5 |
inverse_sigmoid,
|
| 6 |
)
|
| 7 |
+
from .conditional_detr_utils import ConditionalDetrHungarianMatcher
|
| 8 |
from .configuration_magi import MagiConfig
|
| 9 |
from .processing_magi import MagiProcessor
|
| 10 |
from torch import nn
|
| 11 |
from typing import Optional, List
|
| 12 |
import torch
|
| 13 |
+
from einops import rearrange, repeat
|
| 14 |
from .utils import move_to_device, visualise_single_image_prediction, sort_panels, sort_text_boxes_in_reading_order
|
| 15 |
|
| 16 |
class MagiModel(PreTrainedModel):
|
|
|
|
| 498 |
if apply_sigmoid:
|
| 499 |
text_character_affinities = text_character_affinities.sigmoid()
|
| 500 |
affinity_matrices.append(text_character_affinities)
|
| 501 |
+
return affinity_matrices
|
| 502 |
+
|