marriola committed on
Commit
53f8c8f
·
verified ·
1 Parent(s): 38d71e1

Upload BD3LM

Browse files
Files changed (1) hide show
  1. modeling_bd3lm.py +2 -2
modeling_bd3lm.py CHANGED
@@ -599,10 +599,10 @@ class BD3LM(transformers.PreTrainedModel):
599
  'sampling_eps_max',
600
  torch.tensor(config.sampling_eps_max))
601
 
602
- def reset_kv_cache(self):
603
  for block in self.backbone.blocks:
604
  block.kv_cache = torch.zeros(
605
- self.config.loader.eval_batch_size,
606
  self.n,
607
  self.config.model.hidden_size * 3,
608
  device='cuda',
 
599
  'sampling_eps_max',
600
  torch.tensor(config.sampling_eps_max))
601
 
602
+ def reset_kv_cache(self, eval_batch_size=1):
603
  for block in self.backbone.blocks:
604
  block.kv_cache = torch.zeros(
605
+ eval_batch_size,
606
  self.n,
607
  self.config.model.hidden_size * 3,
608
  device='cuda',