Upload folder using huggingface_hub
Browse files- processing_maira2.py +23 -18
processing_maira2.py
CHANGED
|
@@ -3,17 +3,18 @@
|
|
| 3 |
|
| 4 |
|
| 5 |
import re
|
| 6 |
-
from typing import Any,
|
| 7 |
|
| 8 |
import numpy as np
|
| 9 |
from PIL import Image
|
| 10 |
from transformers import BaseImageProcessor, LlavaProcessor, PreTrainedTokenizer
|
| 11 |
-
from transformers.models.llava.processing_llava import LlavaProcessorKwargs
|
| 12 |
from transformers.feature_extraction_utils import BatchFeature
|
| 13 |
from transformers.image_utils import ImageInput, get_image_size, to_numpy_array
|
| 14 |
-
from transformers.
|
|
|
|
| 15 |
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
| 16 |
|
|
|
|
| 17 |
# SingleChatMessageType: TypeAlias = dict[str, str | int | None]
|
| 18 |
# ChatMessageListType: TypeAlias = list[dict[str, str | list[SingleChatMessageType]]]
|
| 19 |
# BoxType: TypeAlias = tuple[float, float, float, float]
|
|
@@ -55,9 +56,9 @@ class Maira2Processor(LlavaProcessor):
|
|
| 55 |
self,
|
| 56 |
image_processor: BaseImageProcessor = None,
|
| 57 |
tokenizer: PreTrainedTokenizer = None,
|
| 58 |
-
patch_size
|
| 59 |
-
vision_feature_select_strategy
|
| 60 |
-
chat_template
|
| 61 |
image_token: str = "<image>",
|
| 62 |
phrase_start_token: str = "<obj>",
|
| 63 |
phrase_end_token: str = "</obj>",
|
|
@@ -301,12 +302,12 @@ class Maira2Processor(LlavaProcessor):
|
|
| 301 |
)
|
| 302 |
messages = [{"content": prompt, "role": "user"}]
|
| 303 |
if assistant_text is not None:
|
| 304 |
-
messages.append(
|
|
|
|
|
|
|
| 305 |
return messages
|
| 306 |
|
| 307 |
-
def _construct_chat_messages_phrase_grounding(
|
| 308 |
-
self, phrase: str, assistant_text: str = None
|
| 309 |
-
):
|
| 310 |
"""
|
| 311 |
This function constructs the chat messages for phrase grounding used in the phrase grounding task.
|
| 312 |
|
|
@@ -331,7 +332,9 @@ class Maira2Processor(LlavaProcessor):
|
|
| 331 |
]
|
| 332 |
messages = [{"content": prompt, "role": "user"}]
|
| 333 |
if assistant_text is not None:
|
| 334 |
-
messages.append(
|
|
|
|
|
|
|
| 335 |
return messages
|
| 336 |
|
| 337 |
def format_reporting_input(
|
|
@@ -388,7 +391,9 @@ class Maira2Processor(LlavaProcessor):
|
|
| 388 |
assistant_text=assistant_text,
|
| 389 |
)
|
| 390 |
add_generation_prompt = assistant_text is None
|
| 391 |
-
text = self.tokenizer.apply_chat_template(
|
|
|
|
|
|
|
| 392 |
return text, images
|
| 393 |
|
| 394 |
def format_phrase_grounding_input(
|
|
@@ -419,7 +424,9 @@ class Maira2Processor(LlavaProcessor):
|
|
| 419 |
)
|
| 420 |
messages = self._construct_chat_messages_phrase_grounding(phrase)
|
| 421 |
add_generation_prompt = assistant_text is None
|
| 422 |
-
text = self.tokenizer.apply_chat_template(
|
|
|
|
|
|
|
| 423 |
return text, images
|
| 424 |
|
| 425 |
def format_and_preprocess_reporting_input(
|
|
@@ -542,9 +549,7 @@ class Maira2Processor(LlavaProcessor):
|
|
| 542 |
assert len(text) == 0
|
| 543 |
return split_text
|
| 544 |
|
| 545 |
-
def convert_output_to_plaintext_or_grounded_sequence(
|
| 546 |
-
self, text: str
|
| 547 |
-
):
|
| 548 |
"""
|
| 549 |
This function converts the input text to a grounded sequence by extracting the grounded phrases and bounding
|
| 550 |
boxes from the text. If the text is plaintext without any grounded phrases, it returns the text as is.
|
|
@@ -725,6 +730,6 @@ class Maira2Processor(LlavaProcessor):
|
|
| 725 |
sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
|
| 726 |
prompt_strings.append(sample)
|
| 727 |
|
| 728 |
-
output_kwargs.pop("return_mm_token_type_ids", None)
|
| 729 |
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
|
| 730 |
-
return BatchFeature(data={**text_inputs, **image_inputs})
|
|
|
|
| 3 |
|
| 4 |
|
| 5 |
import re
|
| 6 |
+
from typing import Any, List, Union
|
| 7 |
|
| 8 |
import numpy as np
|
| 9 |
from PIL import Image
|
| 10 |
from transformers import BaseImageProcessor, LlavaProcessor, PreTrainedTokenizer
|
|
|
|
| 11 |
from transformers.feature_extraction_utils import BatchFeature
|
| 12 |
from transformers.image_utils import ImageInput, get_image_size, to_numpy_array
|
| 13 |
+
from transformers.models.llava.processing_llava import LlavaProcessorKwargs
|
| 14 |
+
from transformers.processing_utils import Unpack, _validate_images_text_input_order
|
| 15 |
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
| 16 |
|
| 17 |
+
|
| 18 |
# SingleChatMessageType: TypeAlias = dict[str, str | int | None]
|
| 19 |
# ChatMessageListType: TypeAlias = list[dict[str, str | list[SingleChatMessageType]]]
|
| 20 |
# BoxType: TypeAlias = tuple[float, float, float, float]
|
|
|
|
| 56 |
self,
|
| 57 |
image_processor: BaseImageProcessor = None,
|
| 58 |
tokenizer: PreTrainedTokenizer = None,
|
| 59 |
+
patch_size=None,
|
| 60 |
+
vision_feature_select_strategy=None,
|
| 61 |
+
chat_template=None,
|
| 62 |
image_token: str = "<image>",
|
| 63 |
phrase_start_token: str = "<obj>",
|
| 64 |
phrase_end_token: str = "</obj>",
|
|
|
|
| 302 |
)
|
| 303 |
messages = [{"content": prompt, "role": "user"}]
|
| 304 |
if assistant_text is not None:
|
| 305 |
+
messages.append(
|
| 306 |
+
{"content": [{"index": None, "text": assistant_text, "type": "text"}], "role": "assistant"}
|
| 307 |
+
)
|
| 308 |
return messages
|
| 309 |
|
| 310 |
+
def _construct_chat_messages_phrase_grounding(self, phrase: str, assistant_text: str = None):
|
|
|
|
|
|
|
| 311 |
"""
|
| 312 |
This function constructs the chat messages for phrase grounding used in the phrase grounding task.
|
| 313 |
|
|
|
|
| 332 |
]
|
| 333 |
messages = [{"content": prompt, "role": "user"}]
|
| 334 |
if assistant_text is not None:
|
| 335 |
+
messages.append(
|
| 336 |
+
{"content": [{"index": None, "text": assistant_text, "type": "text"}], "role": "assistant"}
|
| 337 |
+
)
|
| 338 |
return messages
|
| 339 |
|
| 340 |
def format_reporting_input(
|
|
|
|
| 391 |
assistant_text=assistant_text,
|
| 392 |
)
|
| 393 |
add_generation_prompt = assistant_text is None
|
| 394 |
+
text = self.tokenizer.apply_chat_template(
|
| 395 |
+
messages, add_generation_prompt=add_generation_prompt, tokenize=False
|
| 396 |
+
)
|
| 397 |
return text, images
|
| 398 |
|
| 399 |
def format_phrase_grounding_input(
|
|
|
|
| 424 |
)
|
| 425 |
messages = self._construct_chat_messages_phrase_grounding(phrase)
|
| 426 |
add_generation_prompt = assistant_text is None
|
| 427 |
+
text = self.tokenizer.apply_chat_template(
|
| 428 |
+
messages, add_generation_prompt=add_generation_prompt, tokenize=False
|
| 429 |
+
)
|
| 430 |
return text, images
|
| 431 |
|
| 432 |
def format_and_preprocess_reporting_input(
|
|
|
|
| 549 |
assert len(text) == 0
|
| 550 |
return split_text
|
| 551 |
|
| 552 |
+
def convert_output_to_plaintext_or_grounded_sequence(self, text: str):
|
|
|
|
|
|
|
| 553 |
"""
|
| 554 |
This function converts the input text to a grounded sequence by extracting the grounded phrases and bounding
|
| 555 |
boxes from the text. If the text is plaintext without any grounded phrases, it returns the text as is.
|
|
|
|
| 730 |
sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
|
| 731 |
prompt_strings.append(sample)
|
| 732 |
|
| 733 |
+
output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
|
| 734 |
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
|
| 735 |
+
return BatchFeature(data={**text_inputs, **image_inputs})
|