Commit 98dd963
Parent(s): 6fcf8b4

Add streaming support for text generation

- Implemented streaming functionality for real-time text output.
- Added `_decode_stream` method to handle text streaming.
- Updated `chat` method to support streaming mode.
- Adjusted code to process and yield text in chunks for better responsiveness.

This update enhances the user experience by allowing incremental text generation and display.
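A minimal usage sketch of the updated interface, inferred from the new `chat` signature in this commit (the checkpoint name, dtype, and prompt below are illustrative assumptions, not part of the commit):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical checkpoint that ships this modeling file; substitute your own path.
path = "openbmb/MiniCPM3-4B"
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    path, trust_remote_code=True, torch_dtype=torch.bfloat16
).to("cuda")

# Non-streaming call: the updated chat() now returns only the response string
# (it no longer returns a (response, history) tuple).
response = model.chat(tokenizer, "Write a haiku about rain.", stream=False)
print(response)

# Streaming call: chat() returns a generator that yields text chunks as they are decoded.
for chunk in model.chat(tokenizer, "Write a haiku about rain.", stream=True):
    print(chunk, end="", flush=True)
```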
- modeling_minicpm.py +64 -8
modeling_minicpm.py CHANGED

@@ -22,12 +22,14 @@ import math
 import warnings
 from typing import List, Optional, Tuple, Union, Dict

+from threading import Thread
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

+from transformers import TextIteratorStreamer
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_attn_mask_utils import (

@@ -1248,6 +1250,9 @@ class MiniCPM3ForCausalLM(MiniCPM3PreTrainedModel):
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

+        # List of terminator tokens used to indicate the end of a sequence or conversation.
+        self.terminators = ['</s>', '<|im_end|>']
+
         # Initialize weights and apply final processing
         self.post_init()

@@ -1426,11 +1431,52 @@ class MiniCPM3ForCausalLM(MiniCPM3PreTrainedModel):
                 tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
             )
         return reordered_past
+
+    # Internal function to handle streaming of generated text using TextIteratorStreamer.
+    def _decode_stream(self, input_ids, tokenizer, **kwargs):
+        # Convert terminators to token IDs
+        terminators_ids = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
+        # Initialize TextIteratorStreamer for handling streaming output
+        streamer = TextIteratorStreamer(tokenizer=tokenizer, skip_prompt=True, skip_special_tokens=True)
+        # Set up generation parameters, including input IDs, eos token IDs, and streamer
+        generation_kwargs = {
+            'input_ids': input_ids,
+            'eos_token_id': terminators_ids,
+            'streamer': streamer
+        }
+        generation_kwargs.update(kwargs)
+        # Run the generation task in a separate thread to enable streaming output
+        thread = Thread(target=self.generate, kwargs=generation_kwargs)
+        thread.start()
+        # Return the streamer instance for later access to streamed text
+        return streamer
+

     @torch.inference_mode()
-    def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
-    [old lines 1432-1433: remainder of the previous chat() signature, truncated in this diff view]
+    def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user", max_length: int = 4096, num_beams=1,
+             do_sample=True, logits_processor=None, stream=False, top_p=0.8, temperature=0.3, **kwargs):
+        """
+        Main function for handling dialogue generation based on the input query and history.
+
+        Parameters:
+        - tokenizer: Tokenizer instance used for encoding and decoding.
+        - query: The user input query string.
+        - history: Dialogue history, a list of dictionaries where each dictionary contains role and content.
+        - role: The current role, default is "user".
+        - max_length: Maximum length of the generated text.
+        - num_beams: Number of beams for beam search.
+        - do_sample: Whether to use sampling for generation.
+        - logits_processor: Function for processing logits (if any).
+        - stream: Whether to use streaming output.
+        - top_p: Nucleus sampling parameter.
+        - temperature: Temperature parameter for generation.
+        - **kwargs: Additional arguments for generation.
+
+        Returns:
+        - If stream is True, returns a generator function to get the generated text incrementally.
+        - If stream is False, returns the complete generated response string.
+        """
+
         if history is None:
             history = []
         if logits_processor:

@@ -1443,12 +1489,22 @@ class MiniCPM3ForCausalLM(MiniCPM3PreTrainedModel):
         history.append({"role": role, "content": query})
         history_str = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
         inputs = tokenizer(history_str, return_tensors='pt').to(self.device)
-        outputs = self.generate(**inputs, **gen_kwargs)
-        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
-        response = tokenizer.decode(outputs)
-        history.append({"role": "assistant", "content": response})
-        return response, history

+        if stream:
+            res = self._decode_stream(inputs["input_ids"], tokenizer, **gen_kwargs)
+            def stream_gen():
+                for text in res:
+                    # Remove terminators from the text
+                    for term in self.terminators:
+                        text = text.replace(term, '')
+                    yield text
+            return stream_gen()
+
+        else:
+            outputs = self.generate(**inputs, **gen_kwargs)
+            outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
+            response = tokenizer.decode(outputs)
+            return response

 @add_start_docstrings(
     """