Does it support num_beams > 1?

#9
by Yehor - opened

It seems to crash when I try to generate with num_beams=10:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
/tmp/ipython-input-12-626539799.py in <cell line: 0>()
     12   ).to(model.device)
     13 
---> 14   output = model.generate(
     15       input_ids,
     16       # do_sample=True,

4 frames
/usr/local/lib/python3.11/dist-packages/transformers/models/lfm2/modeling_lfm2.py in reorder_cache(self, beam_idx)
    204         for layer_idx in range(len(self.key_cache)):
    205             device = self.key_cache[layer_idx].device
--> 206             self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
    207             device = self.value_cache[layer_idx].device
    208             self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

IndexError: index out of range in self
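
For what it's worth, the failing call is an `index_select` along dimension 0 using `beam_idx`. Below is a minimal pure-Python sketch of how this kind of error can arise, assuming (I have not verified this against the model code) that a cache entry still has batch size 1 while `beam_idx` refers to 10 beams. `index_select_dim0` is a hypothetical stand-in for the tensor operation, not the transformers or PyTorch API.

```python
def index_select_dim0(rows, indices):
    """Hypothetical stand-in for tensor.index_select(0, indices) on nested lists."""
    for i in indices:
        if not 0 <= i < len(rows):
            raise IndexError("index out of range in self")
    return [rows[i] for i in indices]


num_beams = 10

# Assumption: a cache entry whose batch dimension was never expanded to num_beams.
unexpanded_cache = [[0.0, 0.0]]  # batch size 1

# The same entry with one row per beam.
expanded_cache = [[float(b), 0.0] for b in range(num_beams)]  # batch size 10

# Example beam indices, each in the range [0, num_beams).
beam_idx = [0, 0, 3, 3, 5, 5, 7, 7, 9, 9]

# Works: every index is below the batch size of 10.
reordered = index_select_dim0(expanded_cache, beam_idx)

try:
    index_select_dim0(unexpanded_cache, beam_idx)
    failed = False
except IndexError:
    # Index 3 does not exist in a batch of size 1.
    failed = True
```

If that assumption holds, the reorder would only succeed once every cache entry has one row per beam; whether that is the actual cause here is a guess on my side.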
