Replace max_batch_size with batch_size for HybridCache (#3)
Browse files- Replace max_batch_size with batch_size for HybridCache (b7d1cc83f04ada2af8339e9e599693a9c494a4d5)
Co-authored-by: Peter Baylies <[email protected]>
- modeling_ovis.py +3 -3
modeling_ovis.py
CHANGED
|
@@ -552,14 +552,14 @@ class Ovis(OvisPreTrainedModel):
|
|
| 552 |
self.get_text_tokenizer().save_pretrained(save_directory)
|
| 553 |
self.get_visual_tokenizer().get_image_processor().save_pretrained(save_directory)
|
| 554 |
|
| 555 |
-
def _get_hybrid_cache_for_llm(self,
|
| 556 |
cache_cls = HybridCache
|
| 557 |
llm = self.get_llm()
|
| 558 |
|
| 559 |
need_new_cache = (
|
| 560 |
not hasattr(llm, "_cache")
|
| 561 |
or (not isinstance(llm._cache, cache_cls))
|
| 562 |
-
or llm._cache.
|
| 563 |
or llm._cache.max_cache_len < max_cache_len
|
| 564 |
)
|
| 565 |
|
|
@@ -570,7 +570,7 @@ class Ovis(OvisPreTrainedModel):
|
|
| 570 |
cache_dtype = llm.dtype
|
| 571 |
llm._cache = cache_cls(
|
| 572 |
config=llm.config,
|
| 573 |
-
|
| 574 |
max_cache_len=max_cache_len,
|
| 575 |
device=llm.device,
|
| 576 |
dtype=cache_dtype,
|
|
|
|
| 552 |
self.get_text_tokenizer().save_pretrained(save_directory)
|
| 553 |
self.get_visual_tokenizer().get_image_processor().save_pretrained(save_directory)
|
| 554 |
|
| 555 |
+
def _get_hybrid_cache_for_llm(self, batch_size: int, max_cache_len: int):
|
| 556 |
cache_cls = HybridCache
|
| 557 |
llm = self.get_llm()
|
| 558 |
|
| 559 |
need_new_cache = (
|
| 560 |
not hasattr(llm, "_cache")
|
| 561 |
or (not isinstance(llm._cache, cache_cls))
|
| 562 |
+
or llm._cache.batch_size != batch_size
|
| 563 |
or llm._cache.max_cache_len < max_cache_len
|
| 564 |
)
|
| 565 |
|
|
|
|
| 570 |
cache_dtype = llm.dtype
|
| 571 |
llm._cache = cache_cls(
|
| 572 |
config=llm.config,
|
| 573 |
+
batch_size=batch_size,
|
| 574 |
max_cache_len=max_cache_len,
|
| 575 |
device=llm.device,
|
| 576 |
dtype=cache_dtype,
|