Does it support num_beams > 1?
Discussion #9, opened by Yehor
It seems to crash when I try to use num_beams=10:
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
/tmp/ipython-input-12-626539799.py in <cell line: 0>()
     12 ).to(model.device)
     13
---> 14 output = model.generate(
     15     input_ids,
     16     # do_sample=True,

4 frames
/usr/local/lib/python3.11/dist-packages/transformers/models/lfm2/modeling_lfm2.py in reorder_cache(self, beam_idx)
    204         for layer_idx in range(len(self.key_cache)):
    205             device = self.key_cache[layer_idx].device
--> 206             self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
    207             device = self.value_cache[layer_idx].device
    208             self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

IndexError: index out of range in self
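One possible reading of this traceback, offered only as a hypothesis: reorder_cache calls index_select(0, beam_idx) on every layer's key_cache entry, so if some layers of this hybrid model hold an empty placeholder instead of a tensor with one row per beam, any nonzero beam index is out of range. The pure-Python sketch below models that situation with lists instead of tensors; index_select_rows, reorder_cache_naive, and reorder_cache_guarded are hypothetical names, and whether the real cache actually contains empty placeholders is an assumption that would need checking against modeling_lfm2.py.

```python
# Hypothetical, pure-Python model of the loop in reorder_cache from the traceback.
# Assumption: some layers keep an *empty* placeholder in key_cache, so their
# first dimension has size 0 instead of batch_size * num_beams.


def index_select_rows(rows, beam_idx):
    """Mimic tensor.index_select(0, beam_idx) on a list of rows."""
    for i in beam_idx:
        if not 0 <= i < len(rows):
            raise IndexError("index out of range in self")
    return [rows[i] for i in beam_idx]


def reorder_cache_naive(key_cache, beam_idx):
    # Same structure as the traceback: every layer is reordered unconditionally.
    return [index_select_rows(layer, beam_idx) for layer in key_cache]


def reorder_cache_guarded(key_cache, beam_idx):
    # Hypothetical guard: leave layers whose cache is still empty untouched.
    return [index_select_rows(layer, beam_idx) if layer else layer
            for layer in key_cache]


# New beam i continues from old beam beam_idx[i].
beam_idx = [2, 0, 0]
key_cache = [
    ["k0", "k1", "k2"],  # attention-style layer: one cached row per beam
    [],                  # placeholder layer: nothing cached
]

try:
    reorder_cache_naive(key_cache, beam_idx)
except IndexError as err:
    print("naive:", err)  # naive: index out of range in self

print("guarded:", reorder_cache_guarded(key_cache, beam_idx))
# guarded: [['k2', 'k0', 'k0'], []]
```

If this hypothesis holds, skipping empty layers would avoid the IndexError, but a complete fix would presumably also have to reorder any other per-layer state the model keeps, which this sketch ignores.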