duzx16 committed · Commit 4d0fc39 · 1 parent: 658202d

Update decode method in tokenizer

Files changed: tokenization_chatglm.py (+20 -7)

tokenization_chatglm.py CHANGED
@@ -31,6 +31,9 @@ class TextTokenizer:
     def tokenize(self, text):
         return self.sp.EncodeAsPieces(text)
 
+    def convert_tokens_to_string(self, tokens):
+        return self.sp.DecodePieces(tokens)
+
     def convert_tokens_to_ids(self, tokens):
         return [self.sp.PieceToId(token) for token in tokens]
 
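TextTokenizer wraps a SentencePiece processor, so the new convert_tokens_to_string is simply the inverse of tokenize. A minimal sketch of that round trip (the model path is a placeholder, not taken from this commit):

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load("tokenizer.model")  # placeholder path to a SentencePiece model

pieces = sp.EncodeAsPieces("Hello world")  # e.g. ['▁Hello', '▁world']
text = sp.DecodePieces(pieces)             # joins pieces, restoring spaces
print(text)                                # "Hello world"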
@@ -111,16 +114,25 @@ class SPTokenizer:
         tokens = [x + self.num_image_tokens for x in tmp]
         return tokens if add_dummy_prefix else tokens[2:]
 
-    def decode(self, text_ids: List[int]) -> str:
-        ids = [int(_id) - self.num_image_tokens for _id in text_ids]
-        ids = [_id for _id in ids if _id >= 0]
-        text = self._get_text_tokenizer().decode(ids)
+    def postprocess(self, text):
         text = text.replace("<n>", "\n")
         text = text.replace(SPTokenizer.get_tab_token(), "\t")
         for i in range(2, self.max_blank_length + 1):
             text = text.replace(self.get_blank_token(i), " " * i)
         return text
 
+    def decode(self, text_ids: List[int]) -> str:
+        ids = [int(_id) - self.num_image_tokens for _id in text_ids]
+        ids = [_id for _id in ids if _id >= 0]
+        text = self._get_text_tokenizer().decode(ids)
+        text = self.postprocess(text)
+        return text
+
+    def decode_tokens(self, tokens: List[str]) -> str:
+        text = self._get_text_tokenizer().convert_tokens_to_string(tokens)
+        text = self.postprocess(text)
+        return text
+
     def tokenize(
         self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
     ) -> List[str]:
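The whitespace handling that used to be inlined in decode is factored out into postprocess, so id-based decoding (decode) and the new token-based decoding (decode_tokens) share it. A standalone sketch of its effect, assuming the token spellings produced by this file's get_tab_token and get_blank_token helpers ("<|tab|>" and "<|blank_N|>") and an assumed max_blank_length:

def postprocess(text: str, max_blank_length: int = 80) -> str:
    text = text.replace("<n>", "\n")          # encoded line breaks
    text = text.replace("<|tab|>", "\t")      # encoded tabs
    for i in range(2, max_blank_length + 1):  # runs of 2..N spaces
        text = text.replace(f"<|blank_{i}|>", " " * i)
    return text

print(repr(postprocess("a<n>b<|tab|>c<|blank_4|>d")))  # 'a\nb\tc    d'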
@@ -256,11 +268,12 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
 
         return seq
 
+    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+        return self.sp_tokenizer.decode_tokens(tokens)
+
     def _decode(
         self,
         token_ids: Union[int, List[int]],
-        skip_special_tokens: bool = False,
-        clean_up_tokenization_spaces: bool = True,
         **kwargs
     ) -> str:
         if isinstance(token_ids, int):
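Exposing convert_tokens_to_string on ChatGLMTokenizer is what makes the base-class decode path usable: PreTrainedTokenizer._decode converts ids to token strings, optionally drops special tokens, and then joins them through this hook. Roughly (a paraphrase of the transformers behavior, not its exact source):

def base_decode(tokenizer, token_ids, skip_special_tokens=False):
    # ids -> token strings, with special tokens optionally filtered out
    tokens = tokenizer.convert_ids_to_tokens(
        token_ids, skip_special_tokens=skip_special_tokens
    )
    # token strings -> text, via the subclass hook added in this commit
    return tokenizer.convert_tokens_to_string(tokens)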
@@ -269,7 +282,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
             return ""
         if self.pad_token_id in token_ids:  # remove pad
             token_ids = list(filter((self.pad_token_id).__ne__, token_ids))
-        return self.sp_tokenizer.decode(token_ids)
+        return super()._decode(token_ids, **kwargs)
 
     def _convert_token_to_id(self, token):
         """ Converts a token (str) in an id using the vocab. """
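With _decode now delegating to super() after stripping padding, keyword arguments such as skip_special_tokens reach the standard decode pipeline instead of being swallowed by a fixed signature. A hypothetical usage (the checkpoint name is an assumption, not part of this commit):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "THUDM/chatglm-6b", trust_remote_code=True  # assumed checkpoint
)
ids = tokenizer.encode("Hello")
print(tokenizer.decode(ids, skip_special_tokens=True))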