add confidence scores for `od` and `description_with_bboxes` tasks (#25)
Commit 15bd11834572457e1e8e8fec7f25c83283902db2
Co-authored-by: Fatih C. Akyon <[email protected]>
- processing_florence2.py +83 -23
processing_florence2.py (CHANGED)

--- a/processing_florence2.py
+++ b/processing_florence2.py
@@ -20,6 +20,7 @@ import re
 import logging
 from typing import List, Optional, Union
 import numpy as np
+import math
 
 import torch
 
@@ -32,6 +33,7 @@ from transformers.tokenization_utils_base import (
     TextInput,
     TruncationStrategy,
 )
+from transformers import BartTokenizer, BartTokenizerFast
 from transformers.utils import TensorType
 
 
@@ -304,7 +306,7 @@ class Florence2Processor(ProcessorMixin):
         image_processor_input_names = self.image_processor.model_input_names
         return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
 
-    def post_process_generation(self, text, task, image_size):
+    def post_process_generation(self, text=None, sequence=None, transition_beam_score=None, task=None, image_size=None):
         """
         Post-process the output of the model to each of the task outputs.
 
@@ -317,6 +319,8 @@ class Florence2Processor(ProcessorMixin):
         task_answer_post_processing_type = self.tasks_answer_post_processing_type.get(task, 'pure_text')
         task_answer = self.post_processor(
             text=text,
+            sequence=sequence,
+            transition_beam_score=transition_beam_score,
             image_size=image_size,
             parse_tasks=task_answer_post_processing_type,
         )[task_answer_post_processing_type]
@@ -330,6 +334,9 @@ class Florence2Processor(ProcessorMixin):
             bboxes_od = [_od_instance['bbox'] for _od_instance in od_instances]
             labels_od = [str(_od_instance['cat_name']) for _od_instance in od_instances]
             final_answer = {'bboxes': bboxes_od, 'labels': labels_od}
+            if len(od_instances) and 'score' in od_instances[0]:
+                scores_od = [_od_instance['score'] for _od_instance in od_instances]
+                final_answer['scores'] = scores_od
         elif task_answer_post_processing_type in ['ocr']:
             bboxes = [_od_instance['quad_box'] for _od_instance in task_answer]
             labels = [str(_od_instance['text']) for _od_instance in task_answer]
@@ -591,7 +598,8 @@ class Florence2PostProcesser(object):
             'PARSE_TASKS': [
                 {
                     'TASK_NAME': 'od',
-                    'PATTERN': r'([a-zA-Z0-9 ]+)<loc_(\\d+)><loc_(\\d+)><loc_(\\d+)><loc_(\\d+)>'
+                    'PATTERN': r'([a-zA-Z0-9 ]+)<loc_(\\d+)><loc_(\\d+)><loc_(\\d+)><loc_(\\d+)>',
+                    'SCORE_MODE': 'avg_loc_scores'
                 },
                 {
                     'TASK_NAME': 'ocr',
@@ -607,6 +615,7 @@ class Florence2PostProcesser(object):
                 },
                 {
                     'TASK_NAME': 'description_with_bboxes',
+                    'SCORE_MODE': 'avg_loc_scores'
                 },
                 {
                     'TASK_NAME': 'description_with_polygons',
@@ -648,9 +657,6 @@ class Florence2PostProcesser(object):
             token_ids, skip_special_tokens=False)
         assert len(filtered_tokens) == len(token_ids)
 
-        # To avoid mixing byte-level and unicode for byte-level BPT
-        # we need to build string separately for added tokens and byte-level tokens
-        # cf. https://github.com/huggingface/transformers/issues/1133
         sub_texts = []
         for token in filtered_tokens:
             if token in self.all_special_tokens:
@@ -658,10 +664,6 @@ class Florence2PostProcesser(object):
             else:
                 if isinstance(tokenizer, (BartTokenizer, BartTokenizerFast)):
                     sub_text = tokenizer.convert_tokens_to_string([token])
-                elif isinstance(tokenizer, (T5Tokenizer, T5TokenizerFast)):
-                    # Ref: https://github.com/google/sentencepiece#whitespace-is-treated-as-a-basic-symbol
-                    # Note: Do not strip sub_text as it may have functional whitespace
-                    sub_text = token.replace('▁', ' ')
                 else:
                     raise ValueError(f'type {type(tokenizer)} not supported')
                 sub_texts.append(sub_text)
@@ -673,13 +675,6 @@ class Florence2PostProcesser(object):
             text += sub_text
             spans.append(span)
 
-        # Text format:
-        # 1. T5Tokenizer/T5TokenizerFast:
-        #    "<loc_1><loc_2><loc_3><loc_4> transplanting dog<loc_1><loc_2><loc_3><loc_4> cat</s>"
-        #    Equivalent to t5_tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False)
-        # 2. BartTokenizer (need to double check):
-        #    "<s><loc_1><loc_2><loc_3><loc_4>transplanting dog<loc_1><loc_2><loc_3><loc_4>cat</s>"
-        #    Equivalent to bart_tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False)
         return text, spans
 
     def parse_od_from_text_and_spans(
@@ -714,7 +709,7 @@ class Florence2PostProcesser(object):
         return instances
 
     def parse_ocr_from_text_and_spans(self,
-
+        text,
         pattern,
         image_size,
         area_threshold=-1.0,
@@ -818,9 +813,26 @@ class Florence2PostProcesser(object):
 
         return instances
 
-    def parse_description_with_bboxes_from_text_and_spans(
-
-
+    def parse_description_with_bboxes_from_text_and_spans(
+        self,
+        text,
+        spans=None,
+        scores=None,
+        score_mode=None,
+        pattern=None,
+        image_size=None,
+        allow_empty_phrase=False
+    ):
+        def find_matched_token_indices(cur_span, token_spans):
+            inds = []
+            for i, token_span in enumerate(token_spans):
+                if not (token_span[1] <= cur_span[0] or token_span[0] >= cur_span[1]):
+                    inds.append(i)
+            return inds
+
+        cur_span = 0
+        if text.startswith('<s>'):
+            cur_span += 3
 
        text = text.replace('<s>', '')
        text = text.replace('</s>', '')
@@ -842,13 +854,16 @@ class Florence2PostProcesser(object):
            phrase_text_strip = pharse_text.replace('<obj>', '', 1)

            if phrase_text_strip == '' and not allow_empty_phrase:
+                cur_span += len(pharse_text)
                continue

            # parse phrase, get string
            phrase = re.search(pattern, phrase_text_strip)
            if phrase is None:
+                cur_span += len(pharse_text)
                continue

+            phrase_span = phrase.span()
            phrase = phrase.group()
            # remove leading and trailing spaces
            phrase = phrase.strip()
@@ -856,6 +871,7 @@ class Florence2PostProcesser(object):
            # parse bboxes by box_pattern
            bboxes_parsed = list(re.finditer(box_pattern, pharse_text))
            if len(bboxes_parsed) == 0:
+                cur_span += len(pharse_text)
                continue

            # a list of list
@@ -866,14 +882,42 @@ class Florence2PostProcesser(object):
                size=image_size
            ).tolist()

+            if score_mode == 'avg_loc_scores':
+                if spans is None or scores is None:
+                    all_scores = None
+                else:
+                    bbox_end_spans = [_bboxes_parsed.span(0) for _bboxes_parsed in bboxes_parsed]
+                    all_scores = []
+                    for _spans in bbox_end_spans:
+                        token_inds = find_matched_token_indices((_spans[0] + cur_span, _spans[1] + cur_span), spans)
+                        loc_scores = [scores[token_i] for token_i in token_inds]
+                        score = sum(loc_scores) / len(loc_scores)
+                        all_scores.append(score)
+            elif score_mode == 'avg_cat_name_scores':
+                if spans is None or scores is None:
+                    all_scores = None
+                else:
+                    cat_name_token_inds = find_matched_token_indices((phrase_span[0] + cur_span, phrase_span[1] + cur_span), spans)
+                    cat_name_scores = [scores[token_i] for token_i in cat_name_token_inds]
+                    score = sum(cat_name_scores) / len(cat_name_scores)
+                    all_scores = [score] * len(bboxes)
+            elif score_mode is None:
+                all_scores = None
+            else:
+                raise ValueError('Unknown score mode: {}'.format(score_mode))
+
            phrase = phrase.encode('ascii',errors='ignore').decode('ascii')
-            for _bboxes in bboxes:
+            for _idx, _bboxes in enumerate(bboxes):
                # Prepare instance.
                instance = {}
                instance['bbox'] = _bboxes
                # exclude non-ascii characters
                instance['cat_name'] = phrase
+                if all_scores is not None:
+                    instance['score'] = math.exp(all_scores[_idx])
                instances.append(instance)
+
+            cur_span += len(pharse_text)

        return instances

@@ -991,6 +1035,8 @@ class Florence2PostProcesser(object):
    def __call__(
        self,
        text=None,
+        sequence=None,
+        transition_beam_score=None,
        image_size=None,
        parse_tasks=None,
    ):
@@ -999,7 +1045,6 @@ class Florence2PostProcesser(object):
            text: model outputs
            image_size: (width, height)
            parse_tasks: a list of tasks to parse, if None, parse all tasks.
-
        """
        if parse_tasks is not None:
            if isinstance(parse_tasks, str):
@@ -1008,7 +1053,18 @@ class Florence2PostProcesser(object):
                assert _parse_task in self.parse_tasks, f'parse task {_parse_task} not supported'

        # sequence or text should be provided
-        assert text is not None, 'text should be provided'
+        assert sequence is not None or text is not None, 'sequence or text should be provided'
+        assert sequence is None or text is None, 'only one of sequence and text should be provided'
+
+        if sequence is not None:
+            sequence = sequence.tolist()[1:]
+            text, spans = self.decode_with_spans(self.tokenizer, sequence)
+            if transition_beam_score is not None:
+                transition_beam_score = transition_beam_score.tolist()
+                assert len(sequence) == len(transition_beam_score)
+        else:
+            spans = None
+            transition_beam_score = None

        parsed_dict = {
            'text': text
@@ -1019,6 +1075,7 @@ class Florence2PostProcesser(object):
                continue

            pattern = self.parse_tasks_configs[task].get('PATTERN', None)
+            score_mode = self.parse_tasks_configs[task].get('SCORE_MODE', None)

            if task == 'ocr':
                instances = self.parse_ocr_from_text_and_spans(
@@ -1040,6 +1097,9 @@ class Florence2PostProcesser(object):
            elif task == 'description_with_bboxes':
                instances = self.parse_description_with_bboxes_from_text_and_spans(
                    text,
+                    spans=spans,
+                    scores=transition_beam_score,
+                    score_mode=score_mode,
                    pattern=pattern,
                    image_size=image_size,
                )
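For reference, a minimal usage sketch of the new arguments (not part of this commit): it assumes a Florence-2 checkpoint whose repository ships this revision of processing_florence2.py and is loaded with trust_remote_code=True; the checkpoint id, image path, and generation settings below are placeholders, and the per-token scores come from the standard transformers compute_transition_scores API during beam search.

```python
# Hedged sketch: repo_id, image path, and generation settings are placeholders.
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

repo_id = "microsoft/Florence-2-large"  # any repo containing this processor revision
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)

task = "<OD>"
image = Image.open("example.jpg")
inputs = processor(text=task, images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3,
        return_dict_in_generate=True,  # needed so scores and beam indices are returned
        output_scores=True,
    )

# Per-token log-probabilities of the selected beam (standard GenerationMixin API).
transition_beam_scores = model.compute_transition_scores(
    sequences=outputs.sequences,
    scores=outputs.scores,
    beam_indices=outputs.beam_indices,
)

# Pass the raw sequence and its transition scores instead of decoded text.
parsed = processor.post_process_generation(
    sequence=outputs.sequences[0],
    transition_beam_score=transition_beam_scores[0],
    task=task,
    image_size=(image.width, image.height),
)
# e.g. {'<OD>': {'bboxes': [...], 'labels': [...], 'scores': [...]}}
```

Per the diff, each box score is math.exp of the mean transition log-score over that box's <loc_*> tokens ('avg_loc_scores'), so values fall in (0, 1]; when sequence and transition_beam_score are not provided, the output simply omits the 'scores' key as before.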