duzx16 committed
Commit d2bbc82 · 1 Parent(s): 2449bdc
Fix Chinese punctuation

modeling_chatglm.py CHANGED: +18 -4
@@ -4,6 +4,7 @@ import math
 import copy
 import os
 import warnings
+import re
 
 import torch
 import torch.utils.checkpoint
@@ -1085,6 +1086,21 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             for layer_past in past
         )
 
+    def process_response(self, response):
+        response = response.strip()
+        response = response.replace("[[训练时间]]", "2023年")
+        punkts = [
+            [",", ","],
+            ["!", "!"],
+            [":", ":"],
+            [";", ";"],
+            ["\?", "?"],
+        ]
+        for item in punkts:
+            response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
+            response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
+        return response
+
     @torch.no_grad()
     def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
              do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
@@ -1107,8 +1123,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         outputs = self.generate(**input_ids, **gen_kwargs)
         outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
         response = tokenizer.decode(outputs)
-        response = response.strip()
-        response = response.replace("[[训练时间]]", "2023年")
+        response = self.process_response(response)
         history = history + [(query, response)]
         return response, history
@@ -1134,8 +1149,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         for outputs in self.stream_generate(**input_ids, **gen_kwargs):
             outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
             response = tokenizer.decode(outputs)
-            response = response.strip()
-            response = response.replace("[[训练时间]]", "2023年")
+            response = self.process_response(response)
             new_history = history + [(query, response)]
             yield response, new_history
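For reference, a minimal standalone sketch of the process_response helper this commit adds, with the logic copied from the diff above (the sample strings are hypothetical). It trims the decoded reply, substitutes the [[训练时间]] placeholder, and rewrites half-width punctuation as full-width wherever it sits next to a CJK character in the \u4e00-\u9fff range:

    import re

    def process_response(response):
        # Trim whitespace and substitute the training-time placeholder.
        response = response.strip()
        response = response.replace("[[训练时间]]", "2023年")
        # (pattern, replacement) pairs; "?" is escaped because it is a regex metacharacter.
        punkts = [
            [",", ","],
            ["!", "!"],
            [":", ":"],
            [";", ";"],
            ["\\?", "?"],
        ]
        for item in punkts:
            # CJK character followed by half-width punctuation -> full-width.
            response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
            # Half-width punctuation followed by a CJK character -> full-width.
            response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
        return response

    print(process_response("你好,世界!"))    # -> 你好,世界!
    print(process_response("Hello, world!"))  # unchanged: no adjacent CJK characters

Because every pattern requires a neighboring CJK character, punctuation inside purely English output is left untouched, which is what the commit title "Fix Chinese punctuation" is after.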