Update modeling_GOT.py
Browse files- modeling_GOT.py +13 -14
modeling_GOT.py
CHANGED
|
@@ -575,17 +575,16 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
| 575 |
stopping_criteria=[stopping_criteria]
|
| 576 |
)
|
| 577 |
|
| 578 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 579 |
if render:
|
| 580 |
print('==============rendering===============')
|
| 581 |
from .render_tools import svg_to_html, content_mmd_to_html, tik_html, translation_table
|
| 582 |
|
| 583 |
-
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
|
| 584 |
-
|
| 585 |
-
if outputs.endswith(stop_str):
|
| 586 |
-
outputs = outputs[:-len(stop_str)]
|
| 587 |
-
outputs = outputs.strip()
|
| 588 |
-
|
| 589 |
if '**kern' in outputs:
|
| 590 |
import verovio
|
| 591 |
from cairosvg import svg2png
|
|
@@ -813,16 +812,16 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
| 813 |
max_new_tokens=4096,
|
| 814 |
stopping_criteria=[stopping_criteria]
|
| 815 |
)
|
| 816 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 817 |
if render:
|
| 818 |
print('==============rendering===============')
|
| 819 |
from .render_tools import content_mmd_to_html
|
| 820 |
-
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
|
| 821 |
-
|
| 822 |
-
if outputs.endswith(stop_str):
|
| 823 |
-
outputs = outputs[:-len(stop_str)]
|
| 824 |
-
outputs = outputs.strip()
|
| 825 |
-
|
| 826 |
html_path_2 = save_render_file
|
| 827 |
right_num = outputs.count('\\right')
|
| 828 |
left_num = outputs.count('\left')
|
|
|
|
| 575 |
stopping_criteria=[stopping_criteria]
|
| 576 |
)
|
| 577 |
|
| 578 |
+
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
|
| 579 |
+
|
| 580 |
+
if outputs.endswith(stop_str):
|
| 581 |
+
outputs = outputs[:-len(stop_str)]
|
| 582 |
+
outputs = outputs.strip()
|
| 583 |
+
|
| 584 |
if render:
|
| 585 |
print('==============rendering===============')
|
| 586 |
from .render_tools import svg_to_html, content_mmd_to_html, tik_html, translation_table
|
| 587 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 588 |
if '**kern' in outputs:
|
| 589 |
import verovio
|
| 590 |
from cairosvg import svg2png
|
|
|
|
| 812 |
max_new_tokens=4096,
|
| 813 |
stopping_criteria=[stopping_criteria]
|
| 814 |
)
|
| 815 |
+
|
| 816 |
+
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
|
| 817 |
+
|
| 818 |
+
if outputs.endswith(stop_str):
|
| 819 |
+
outputs = outputs[:-len(stop_str)]
|
| 820 |
+
outputs = outputs.strip()
|
| 821 |
+
|
| 822 |
if render:
|
| 823 |
print('==============rendering===============')
|
| 824 |
from .render_tools import content_mmd_to_html
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 825 |
html_path_2 = save_render_file
|
| 826 |
right_num = outputs.count('\\right')
|
| 827 |
left_num = outputs.count('\left')
|