Update modeling_GOT.py
Adds a stream_flag argument to chat() and chat_crop(): when it is set, the TextStreamer is passed to generate() so decoded text streams as it is produced; when left at the default False, generation runs without a streamer, as before.

modeling_GOT.py  (+50 -26)
@@ -484,7 +484,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
         setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
 
-    def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None, print_prompt=False, gradio_input=False):
+    def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False):
 
         self.disable_torch_init()
 
@@ -565,18 +565,30 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
         streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
-        with torch.autocast("cuda", dtype=torch.bfloat16):
-            output_ids = self.generate(
-                input_ids,
-                images=[image_tensor_1.unsqueeze(0).half().cuda()],
-                do_sample=False,
-                num_beams = 1,
-                no_repeat_ngram_size = 20,
-                streamer=streamer,
-                max_new_tokens=4096,
-                stopping_criteria=[stopping_criteria]
-                )
-
+        if stream_flag:
+            with torch.autocast("cuda", dtype=torch.bfloat16):
+                output_ids = self.generate(
+                    input_ids,
+                    images=[image_tensor_1.unsqueeze(0).half().cuda()],
+                    do_sample=False,
+                    num_beams = 1,
+                    no_repeat_ngram_size = 20,
+                    streamer=streamer,
+                    max_new_tokens=4096,
+                    stopping_criteria=[stopping_criteria]
+                    )
+        else:
+            with torch.autocast("cuda", dtype=torch.bfloat16):
+                output_ids = self.generate(
+                    input_ids,
+                    images=[image_tensor_1.unsqueeze(0).half().cuda()],
+                    do_sample=False,
+                    num_beams = 1,
+                    no_repeat_ngram_size = 20,
+                    # streamer=streamer,
+                    max_new_tokens=4096,
+                    stopping_criteria=[stopping_criteria]
+                    )
 
         outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
 
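For reference, a minimal usage sketch of the new flag. The checkpoint id and image path below are illustrative placeholders, and the loading boilerplate assumes the usual trust_remote_code pattern for this model family:

    from transformers import AutoModel, AutoTokenizer

    # Illustrative checkpoint id; substitute the repo this file belongs to.
    tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
    model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True,
                                      low_cpu_mem_usage=True, device_map='cuda',
                                      use_safetensors=True)
    model = model.eval().cuda()

    # stream_flag=True hands the TextStreamer to generate(), so decoded text
    # is printed token by token as it is produced.
    res = model.chat(tokenizer, 'image.png', ocr_type='ocr', stream_flag=True)

    # stream_flag defaults to False, keeping the previous non-streaming behavior.
    res = model.chat(tokenizer, 'image.png', ocr_type='ocr')
    print(res)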
@@ -716,7 +728,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         return processed_images
 
 
-    def chat_crop(self, tokenizer, image_file, ocr_type, render=False, save_render_file=None, print_prompt=False, gradio_input=False):
+    def chat_crop(self, tokenizer, image_file, ocr_type, render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False):
         # Model
         self.disable_torch_init()
         multi_page=False
@@ -807,18 +819,30 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
         streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
-        with torch.autocast("cuda", dtype=torch.bfloat16):
-            output_ids = self.generate(
-                input_ids,
-                images=[image_list.half().cuda()],
-                do_sample=False,
-                num_beams = 1,
-                # no_repeat_ngram_size = 20,
-                streamer=streamer,
-                max_new_tokens=4096,
-                stopping_criteria=[stopping_criteria]
-                )
-
+        if stream_flag:
+            with torch.autocast("cuda", dtype=torch.bfloat16):
+                output_ids = self.generate(
+                    input_ids,
+                    images=[image_list.half().cuda()],
+                    do_sample=False,
+                    num_beams = 1,
+                    # no_repeat_ngram_size = 20,
+                    streamer=streamer,
+                    max_new_tokens=4096,
+                    stopping_criteria=[stopping_criteria]
+                    )
+        else:
+            with torch.autocast("cuda", dtype=torch.bfloat16):
+                output_ids = self.generate(
+                    input_ids,
+                    images=[image_list.half().cuda()],
+                    do_sample=False,
+                    num_beams = 1,
+                    # no_repeat_ngram_size = 20,
+                    # streamer=streamer,
+                    max_new_tokens=4096,
+                    stopping_criteria=[stopping_criteria]
+                    )
 
         outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
 
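chat_crop() gains the same switch; it differs from chat() in passing the batched crops (image_list) to generate() and in leaving no_repeat_ngram_size commented out in both branches. Note that the streamer is still constructed unconditionally in both methods; the flag only controls whether it is handed to generate(). A usage sketch under the same assumptions as above:

    # 'format' requests formatted (e.g. markdown/LaTeX-style) OCR output,
    # while 'ocr' requests plain text; chat() accepts both values as well.
    res = model.chat_crop(tokenizer, 'page.png', ocr_type='format', stream_flag=True)
    print(res)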
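Both hunks rely on KeywordsStoppingCriteria, which is defined elsewhere in modeling_GOT.py and is not part of this diff. As a rough illustration of the pattern (not the repo's actual implementation), a keywords-based criterion typically decodes the newly generated tail and stops once a keyword appears:

    import torch
    from transformers import StoppingCriteria

    class KeywordsStoppingCriteriaSketch(StoppingCriteria):
        def __init__(self, keywords, tokenizer, input_ids):
            self.keywords = keywords
            self.tokenizer = tokenizer
            self.prompt_len = input_ids.shape[1]  # prompt length, skipped when checking

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
            # During generation, input_ids holds the prompt plus all tokens decoded
            # so far; decode only the new part and stop once any keyword shows up.
            text = self.tokenizer.decode(input_ids[0, self.prompt_len:])
            return any(keyword in text for keyword in self.keywords)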