maikezu commited on
Commit
3074dc3
·
verified ·
1 Parent(s): b503c5c

Fix truncated sequences in _convert for audio and speaker spans

Browse files

Extended the previous fix for _convert to also handle truncated audio and speaker spans.

Uses min(len(start), len(end)) for audio_bounds and spk_bounds to avoid runtime errors when max_inp_length truncates sequences.

Files changed (1) hide show
  1. processing_minicpmo.py +11 -4
processing_minicpmo.py CHANGED
@@ -278,16 +278,23 @@ class MiniCPMOProcessor(ProcessorMixin):
278
  ]
279
  )
280
 
 
281
  ## audio bound
282
  audio_start_idx = torch.where(input_ids == self.tokenizer.audio_start_id)[0]
283
  audio_end_idx = torch.where(input_ids == self.tokenizer.audio_end_id)[0]
284
- assert len(audio_start_idx) == len(audio_end_idx)
285
- audio_bounds = torch.hstack([(audio_start_idx + 1).unsqueeze(-1), audio_end_idx.unsqueeze(-1)])
 
 
 
286
 
287
  spk_start_idx = torch.where(input_ids == self.tokenizer.spk_start_id)[0]
288
  spk_end_idx = torch.where(input_ids == self.tokenizer.spk_end_id)[0]
289
- assert len(spk_start_idx) == len(spk_end_idx)
290
- spk_bounds = torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)])
 
 
 
291
 
292
  return input_ids, image_bounds, audio_bounds, spk_bounds
293
 
 
278
  ]
279
  )
280
 
281
+
282
  ## audio bound
283
  audio_start_idx = torch.where(input_ids == self.tokenizer.audio_start_id)[0]
284
  audio_end_idx = torch.where(input_ids == self.tokenizer.audio_end_id)[0]
285
+ valid_audio_nums = min(len(audio_start_idx), len(audio_end_idx))
286
+ audio_bounds = torch.hstack([
287
+ (audio_start_idx[:valid_audio_nums] + 1).unsqueeze(-1),
288
+ audio_end_idx[:valid_audio_nums].unsqueeze(-1)
289
+ ])
290
 
291
  spk_start_idx = torch.where(input_ids == self.tokenizer.spk_start_id)[0]
292
  spk_end_idx = torch.where(input_ids == self.tokenizer.spk_end_id)[0]
293
+ valid_spk_nums = min(len(spk_start_idx), len(spk_end_idx))
294
+ spk_bounds = torch.hstack([
295
+ (spk_start_idx[:valid_spk_nums] + 1).unsqueeze(-1),
296
+ spk_end_idx[:valid_spk_nums].unsqueeze(-1)
297
+ ])
298
 
299
  return input_ids, image_bounds, audio_bounds, spk_bounds
300