Upload BD3LM
Browse files
- modeling_bd3lm.py +2 -2
modeling_bd3lm.py
CHANGED
@@ -599,10 +599,10 @@ class BD3LM(transformers.PreTrainedModel):
|
|
599 |
'sampling_eps_max',
|
600 |
torch.tensor(config.sampling_eps_max))
|
601 |
|
602 |
-
def reset_kv_cache(self):
|
603 |
for block in self.backbone.blocks:
|
604 |
block.kv_cache = torch.zeros(
|
605 |
-
|
606 |
self.n,
|
607 |
self.config.model.hidden_size * 3,
|
608 |
device='cuda',
|
|
|
599 |
'sampling_eps_max',
|
600 |
torch.tensor(config.sampling_eps_max))
|
601 |
|
602 |
+
def reset_kv_cache(self, eval_batch_size=1):
|
603 |
for block in self.backbone.blocks:
|
604 |
block.kv_cache = torch.zeros(
|
605 |
+
eval_batch_size,
|
606 |
self.n,
|
607 |
self.config.model.hidden_size * 3,
|
608 |
device='cuda',
|