Upload folder using huggingface_hub
Browse files- processing_maira2.py +23 -18
processing_maira2.py
CHANGED
@@ -3,17 +3,18 @@
|
|
3 |
|
4 |
|
5 |
import re
|
6 |
-
from typing import Any,
|
7 |
|
8 |
import numpy as np
|
9 |
from PIL import Image
|
10 |
from transformers import BaseImageProcessor, LlavaProcessor, PreTrainedTokenizer
|
11 |
-
from transformers.models.llava.processing_llava import LlavaProcessorKwargs
|
12 |
from transformers.feature_extraction_utils import BatchFeature
|
13 |
from transformers.image_utils import ImageInput, get_image_size, to_numpy_array
|
14 |
-
from transformers.
|
|
|
15 |
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
16 |
|
|
|
17 |
# SingleChatMessageType: TypeAlias = dict[str, str | int | None]
|
18 |
# ChatMessageListType: TypeAlias = list[dict[str, str | list[SingleChatMessageType]]]
|
19 |
# BoxType: TypeAlias = tuple[float, float, float, float]
|
@@ -55,9 +56,9 @@ class Maira2Processor(LlavaProcessor):
|
|
55 |
self,
|
56 |
image_processor: BaseImageProcessor = None,
|
57 |
tokenizer: PreTrainedTokenizer = None,
|
58 |
-
patch_size
|
59 |
-
vision_feature_select_strategy
|
60 |
-
chat_template
|
61 |
image_token: str = "<image>",
|
62 |
phrase_start_token: str = "<obj>",
|
63 |
phrase_end_token: str = "</obj>",
|
@@ -301,12 +302,12 @@ class Maira2Processor(LlavaProcessor):
|
|
301 |
)
|
302 |
messages = [{"content": prompt, "role": "user"}]
|
303 |
if assistant_text is not None:
|
304 |
-
messages.append(
|
|
|
|
|
305 |
return messages
|
306 |
|
307 |
-
def _construct_chat_messages_phrase_grounding(
|
308 |
-
self, phrase: str, assistant_text: str = None
|
309 |
-
):
|
310 |
"""
|
311 |
This function constructs the chat messages for phrase grounding used in the phrase grounding task.
|
312 |
|
@@ -331,7 +332,9 @@ class Maira2Processor(LlavaProcessor):
|
|
331 |
]
|
332 |
messages = [{"content": prompt, "role": "user"}]
|
333 |
if assistant_text is not None:
|
334 |
-
messages.append(
|
|
|
|
|
335 |
return messages
|
336 |
|
337 |
def format_reporting_input(
|
@@ -388,7 +391,9 @@ class Maira2Processor(LlavaProcessor):
|
|
388 |
assistant_text=assistant_text,
|
389 |
)
|
390 |
add_generation_prompt = assistant_text is None
|
391 |
-
text = self.tokenizer.apply_chat_template(
|
|
|
|
|
392 |
return text, images
|
393 |
|
394 |
def format_phrase_grounding_input(
|
@@ -419,7 +424,9 @@ class Maira2Processor(LlavaProcessor):
|
|
419 |
)
|
420 |
messages = self._construct_chat_messages_phrase_grounding(phrase)
|
421 |
add_generation_prompt = assistant_text is None
|
422 |
-
text = self.tokenizer.apply_chat_template(
|
|
|
|
|
423 |
return text, images
|
424 |
|
425 |
def format_and_preprocess_reporting_input(
|
@@ -542,9 +549,7 @@ class Maira2Processor(LlavaProcessor):
|
|
542 |
assert len(text) == 0
|
543 |
return split_text
|
544 |
|
545 |
-
def convert_output_to_plaintext_or_grounded_sequence(
|
546 |
-
self, text: str
|
547 |
-
):
|
548 |
"""
|
549 |
This function converts the input text to a grounded sequence by extracting the grounded phrases and bounding
|
550 |
boxes from the text. If the text is plaintext without any grounded phrases, it returns the text as is.
|
@@ -725,6 +730,6 @@ class Maira2Processor(LlavaProcessor):
|
|
725 |
sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
|
726 |
prompt_strings.append(sample)
|
727 |
|
728 |
-
output_kwargs.pop("return_mm_token_type_ids", None)
|
729 |
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
|
730 |
-
return BatchFeature(data={**text_inputs, **image_inputs})
|
|
|
3 |
|
4 |
|
5 |
import re
|
6 |
+
from typing import Any, List, Union
|
7 |
|
8 |
import numpy as np
|
9 |
from PIL import Image
|
10 |
from transformers import BaseImageProcessor, LlavaProcessor, PreTrainedTokenizer
|
|
|
11 |
from transformers.feature_extraction_utils import BatchFeature
|
12 |
from transformers.image_utils import ImageInput, get_image_size, to_numpy_array
|
13 |
+
from transformers.models.llava.processing_llava import LlavaProcessorKwargs
|
14 |
+
from transformers.processing_utils import Unpack, _validate_images_text_input_order
|
15 |
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
16 |
|
17 |
+
|
18 |
# SingleChatMessageType: TypeAlias = dict[str, str | int | None]
|
19 |
# ChatMessageListType: TypeAlias = list[dict[str, str | list[SingleChatMessageType]]]
|
20 |
# BoxType: TypeAlias = tuple[float, float, float, float]
|
|
|
56 |
self,
|
57 |
image_processor: BaseImageProcessor = None,
|
58 |
tokenizer: PreTrainedTokenizer = None,
|
59 |
+
patch_size=None,
|
60 |
+
vision_feature_select_strategy=None,
|
61 |
+
chat_template=None,
|
62 |
image_token: str = "<image>",
|
63 |
phrase_start_token: str = "<obj>",
|
64 |
phrase_end_token: str = "</obj>",
|
|
|
302 |
)
|
303 |
messages = [{"content": prompt, "role": "user"}]
|
304 |
if assistant_text is not None:
|
305 |
+
messages.append(
|
306 |
+
{"content": [{"index": None, "text": assistant_text, "type": "text"}], "role": "assistant"}
|
307 |
+
)
|
308 |
return messages
|
309 |
|
310 |
+
def _construct_chat_messages_phrase_grounding(self, phrase: str, assistant_text: str = None):
|
|
|
|
|
311 |
"""
|
312 |
This function constructs the chat messages for phrase grounding used in the phrase grounding task.
|
313 |
|
|
|
332 |
]
|
333 |
messages = [{"content": prompt, "role": "user"}]
|
334 |
if assistant_text is not None:
|
335 |
+
messages.append(
|
336 |
+
{"content": [{"index": None, "text": assistant_text, "type": "text"}], "role": "assistant"}
|
337 |
+
)
|
338 |
return messages
|
339 |
|
340 |
def format_reporting_input(
|
|
|
391 |
assistant_text=assistant_text,
|
392 |
)
|
393 |
add_generation_prompt = assistant_text is None
|
394 |
+
text = self.tokenizer.apply_chat_template(
|
395 |
+
messages, add_generation_prompt=add_generation_prompt, tokenize=False
|
396 |
+
)
|
397 |
return text, images
|
398 |
|
399 |
def format_phrase_grounding_input(
|
|
|
424 |
)
|
425 |
messages = self._construct_chat_messages_phrase_grounding(phrase)
|
426 |
add_generation_prompt = assistant_text is None
|
427 |
+
text = self.tokenizer.apply_chat_template(
|
428 |
+
messages, add_generation_prompt=add_generation_prompt, tokenize=False
|
429 |
+
)
|
430 |
return text, images
|
431 |
|
432 |
def format_and_preprocess_reporting_input(
|
|
|
549 |
assert len(text) == 0
|
550 |
return split_text
|
551 |
|
552 |
+
def convert_output_to_plaintext_or_grounded_sequence(self, text: str):
|
|
|
|
|
553 |
"""
|
554 |
This function converts the input text to a grounded sequence by extracting the grounded phrases and bounding
|
555 |
boxes from the text. If the text is plaintext without any grounded phrases, it returns the text as is.
|
|
|
730 |
sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
|
731 |
prompt_strings.append(sample)
|
732 |
|
733 |
+
output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
|
734 |
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
|
735 |
+
return BatchFeature(data={**text_inputs, **image_inputs})
|