Update modeling_GOT.py
modeling_GOT.py  (+16 -15)
@@ -249,7 +249,7 @@ class GOTQwenModel(Qwen2Model):
                 image_patches_features = []
                 for image_patch in image_patches:
                     image_p = torch.stack([image_patch])
-
+
                     with torch.set_grad_enabled(False):
                         cnn_feature_p = vision_tower_high(image_p)
                         cnn_feature_p = cnn_feature_p.flatten(2).permute(0, 2, 1)
@@ -257,7 +257,6 @@ class GOTQwenModel(Qwen2Model):
                     image_patches_features.append(image_feature_p)
                 image_feature = torch.cat(image_patches_features, dim=1)
                 image_features.append(image_feature)
-                exit()


             dummy_image_features_2 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
@@ -485,7 +484,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
         setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)

-    def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None):
+    def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None, print_prompt=False):

         self.disable_torch_init()

@@ -549,7 +548,8 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         conv.append_message(conv.roles[1], None)
         prompt = conv.get_prompt()

-
+        if print_prompt:
+            print(prompt)

         inputs = tokenizer([prompt])

@@ -570,7 +570,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                 do_sample=False,
                 num_beams = 1,
                 no_repeat_ngram_size = 20,
-                streamer=streamer,
+                # streamer=streamer,
                 max_new_tokens=4096,
                 stopping_criteria=[stopping_criteria]
                 )
@@ -715,7 +715,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         return processed_images


-    def chat_plus(self, tokenizer,
+    def chat_plus(self, tokenizer, image_file, render=False, save_render_file=None, print_prompt=False):
         # Model
         self.disable_torch_init()
         multi_page=False
@@ -730,8 +730,8 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):

         image_list = []

-        if len(image_file_list)>1:
-            multi_page = True
+        # if len(image_file_list)>1:
+        # multi_page = True

         if multi_page:
             qs = 'OCR with format across multi pages: '
@@ -739,19 +739,19 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
             import glob
             # from natsort import natsorted
             # patches = glob.glob(image_file + '/*png')
-            patches =
+            patches = image_file
             # patches = natsorted(patches)
             sub_images = []
             for sub_image in patches:
                 sub_images.append(self.load_image(sub_image))

             ll = len(patches)
-            print(patches)
-            print("len ll: ", ll)
+            # print(patches)
+            # print("len ll: ", ll)

         else:
             qs = 'OCR with format upon the patch reference: '
-            img = self.load_image(
+            img = self.load_image(image_file)
             sub_images = self.dynamic_preprocess(img)
             ll = len(sub_images)

@@ -762,7 +762,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):

         image_list = torch.stack(image_list)

-        print('====new images batch size======: ',image_list.shape)
+        print('====new images batch size======: \n',image_list.shape)


         if use_im_start_end:
@@ -788,7 +788,8 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         conv.append_message(conv.roles[1], None)
         prompt = conv.get_prompt()

-
+        if print_prompt:
+            print(prompt)

         inputs = tokenizer([prompt])

@@ -807,7 +808,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                 do_sample=False,
                 num_beams = 1,
                 # no_repeat_ngram_size = 20,
-                streamer=streamer,
+                # streamer=streamer,
                 max_new_tokens=4096,
                 stopping_criteria=[stopping_criteria]
                 )
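For reference, a minimal usage sketch of the two updated entry points. Only the chat()/chat_plus() signatures and the new print_prompt flag come from this commit; the checkpoint id, loading flags, ocr_type value, and image paths below are assumed placeholders, not part of the change.

from transformers import AutoModel, AutoTokenizer

# Assumed checkpoint id and loading kwargs (placeholders, not from this commit).
tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True,
                                  use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
model = model.eval().cuda()

# chat() gains an optional print_prompt flag that dumps the assembled
# conversation prompt before generation.
res = model.chat(tokenizer, 'sample.png', ocr_type='format', print_prompt=True)

# chat_plus() now takes a single image_file plus the same print_prompt flag;
# the image is split internally via dynamic_preprocess, and the multi-page
# branch is commented out in this commit.
res_plus = model.chat_plus(tokenizer, 'sample.png', print_prompt=True)

print(res)
print(res_plus)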