support stream and system prompt
- README.md +19 -2
- modeling_minicpmv.py +44 -4

README.md
CHANGED
@@ -377,10 +377,27 @@ res = model.chat(
     image=image,
     msgs=msgs,
     tokenizer=tokenizer,
-    sampling=True,
-    temperature=0.7
+    sampling=True, # if sampling=False, beam_search will be used by default
+    temperature=0.7,
+    # system_prompt='' # pass system_prompt if needed
 )
 print(res)
+
+## if you want to use streaming, please make sure sampling=True and stream=True
+## the model.chat will return a generator
+res = model.chat(
+    image=image,
+    msgs=msgs,
+    tokenizer=tokenizer,
+    sampling=True,
+    temperature=0.7,
+    stream=True
+)
+
+generated_text = ""
+for new_text in res:
+    generated_text += new_text
+    print(new_text, flush=True, end='')
 ```
 
 Please look at [GitHub](https://github.com/OpenBMB/MiniCPM-V) for more detail about usage.
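
The README snippet keeps `system_prompt` commented out. A minimal sketch of passing it explicitly, assuming the same `model`, `tokenizer`, `image`, and `msgs` setup as in the snippet above (the prompt text is illustrative only):

```python
# Sketch: system_prompt is forwarded to model.chat and prepended internally
# as a {'role': 'system', ...} message before the chat template is applied.
res = model.chat(
    image=image,
    msgs=msgs,
    tokenizer=tokenizer,
    sampling=True,
    temperature=0.7,
    system_prompt='You are a helpful assistant.'  # illustrative prompt text
)
print(res)
```
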
modeling_minicpmv.py
CHANGED
@@ -3,10 +3,11 @@ from typing import List, Optional
 import json
 import torch
 import torchvision
+from threading import Thread
 from copy import deepcopy
 from PIL import Image
 from torchvision import transforms
-from transformers import LlamaTokenizer, LlamaPreTrainedModel, LlamaForCausalLM, AutoModel, PreTrainedTokenizerFast
+from transformers import LlamaTokenizer, LlamaPreTrainedModel, LlamaForCausalLM, AutoModel, PreTrainedTokenizerFast, TextIteratorStreamer
 from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer
 
 from .configuration_minicpm import MiniCPMVConfig
@@ -218,6 +219,25 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
             **kwargs
         )
         return self._decode_text(output, tokenizer)
+
+    def _decode_stream(self, inputs_embeds, tokenizer, **kwargs):
+        terminators = [
+            tokenizer.eos_token_id,
+            tokenizer.convert_tokens_to_ids("<|eot_id|>")
+        ]
+        streamer = TextIteratorStreamer(tokenizer=tokenizer)
+        generation_kwargs = {
+            'inputs_embeds': inputs_embeds,
+            'pad_token_id': 0,
+            'eos_token_id': terminators,
+            'streamer': streamer
+        }
+        generation_kwargs.update(kwargs)
+
+        thread = Thread(target=self.llm.generate, kwargs=generation_kwargs)
+        thread.start()
+
+        return streamer
 
     def _decode_text(self, result_ids, tokenizer):
         result_text = []
@@ -294,6 +314,7 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
         max_inp_length: Optional[int] = None,
         vision_hidden_states=None,
         return_vision_hidden_states=False,
+        stream=False,
         **kwargs
     ):
 
@@ -326,7 +347,10 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
                 vision_hidden_states,
             ) = self.get_vllm_embedding(model_inputs)
 
-            result = self._decode(model_inputs["inputs_embeds"], tokenizer, **kwargs)
+            if stream:
+                result = self._decode_stream(model_inputs["inputs_embeds"], tokenizer, **kwargs)
+            else:
+                result = self._decode(model_inputs["inputs_embeds"], tokenizer, **kwargs)
 
         if return_vision_hidden_states:
             return result, vision_hidden_states
@@ -342,6 +366,8 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
         max_new_tokens=1024,
         sampling=True,
         max_inp_length=2048,
+        system_prompt='',
+        stream=False,
         **kwargs
     ):
         if isinstance(msgs, str):
@@ -349,6 +375,7 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
 
         copy_msgs = deepcopy(msgs)
         assert len(copy_msgs) > 0, 'msgs is empty'
+        assert sampling or not stream, 'if use stream mode, make sure sampling=True'
 
         if image is not None and isinstance(copy_msgs[0]['content'], str):
             copy_msgs[0]['content'] = [image, copy_msgs[0]['content']]
@@ -393,6 +420,10 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
         if tgt_sizes:
            tgt_sizes = torch.vstack(tgt_sizes)
 
+        if system_prompt:
+            sys_msg = {'role': 'system', 'content': system_prompt}
+            copy_msgs = [sys_msg] + copy_msgs
+
         input_ids = tokenizer.apply_chat_template(copy_msgs, tokenize=True, add_generation_prompt=False)
 
         if sampling:
@@ -423,11 +454,20 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
             max_new_tokens=max_new_tokens,
             vision_hidden_states=vision_hidden_states,
             return_vision_hidden_states=True,
+            stream=stream,
             **generation_config
         )
-        answer = res[0]
 
-        return answer
+        if stream:
+            def stream_gen():
+                for text in res:
+                    text = text.replace(tokenizer.eot_token, '').replace(tokenizer.eos_token, '')
+                    yield text
+            return stream_gen()
+
+        else:
+            answer = res[0]
+            return answer
 
 
 class PreTrainedTokenizerFastWrapper(PreTrainedTokenizerFast):
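
For reference, the new `_decode_stream` follows the standard transformers streaming pattern: `generate` runs in a background `Thread` while a `TextIteratorStreamer` yields decoded text as it is produced. A self-contained sketch of that pattern with a placeholder text-only model (the checkpoint name and prompt below are illustrative, not part of this commit):

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")       # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Describe the image.", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until completion, so it runs in a worker thread;
# the streamer is consumed on the main thread as tokens arrive.
generation_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=64)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

generated_text = ""
for new_text in streamer:   # yields decoded text chunks as generation progresses
    generated_text += new_text
    print(new_text, end='', flush=True)
thread.join()
```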