Handle truncated image, audio, and spk boundaries in `_convert` to avoid tensor size mismatch

#54
by maikezu - opened
Files changed (1) hide show
  1. processing_minicpmo.py +12 -5
processing_minicpmo.py CHANGED
@@ -269,7 +269,7 @@ class MiniCPMOProcessor(ProcessorMixin):
269
  image_start_idx += 1
270
  image_end_idx = torch.where(end_cond)[0]
271
 
272
- valid_image_nums = max(len(image_start_idx), len(image_end_idx))
273
 
274
  image_bounds = torch.hstack(
275
  [
@@ -278,16 +278,23 @@ class MiniCPMOProcessor(ProcessorMixin):
278
  ]
279
  )
280
 
 
281
  ## audio bound
282
  audio_start_idx = torch.where(input_ids == self.tokenizer.audio_start_id)[0]
283
  audio_end_idx = torch.where(input_ids == self.tokenizer.audio_end_id)[0]
284
- assert len(audio_start_idx) == len(audio_end_idx)
285
- audio_bounds = torch.hstack([(audio_start_idx + 1).unsqueeze(-1), audio_end_idx.unsqueeze(-1)])
 
 
 
286
 
287
  spk_start_idx = torch.where(input_ids == self.tokenizer.spk_start_id)[0]
288
  spk_end_idx = torch.where(input_ids == self.tokenizer.spk_end_id)[0]
289
- assert len(spk_start_idx) == len(spk_end_idx)
290
- spk_bounds = torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)])
 
 
 
291
 
292
  return input_ids, image_bounds, audio_bounds, spk_bounds
293
 
 
269
  image_start_idx += 1
270
  image_end_idx = torch.where(end_cond)[0]
271
 
272
+ valid_image_nums = min(len(image_start_idx), len(image_end_idx))
273
 
274
  image_bounds = torch.hstack(
275
  [
 
278
  ]
279
  )
280
 
281
+
282
  ## audio bound
283
  audio_start_idx = torch.where(input_ids == self.tokenizer.audio_start_id)[0]
284
  audio_end_idx = torch.where(input_ids == self.tokenizer.audio_end_id)[0]
285
+ valid_audio_nums = min(len(audio_start_idx), len(audio_end_idx))
286
+ audio_bounds = torch.hstack([
287
+ (audio_start_idx[:valid_audio_nums] + 1).unsqueeze(-1),
288
+ audio_end_idx[:valid_audio_nums].unsqueeze(-1)
289
+ ])
290
 
291
  spk_start_idx = torch.where(input_ids == self.tokenizer.spk_start_id)[0]
292
  spk_end_idx = torch.where(input_ids == self.tokenizer.spk_end_id)[0]
293
+ valid_spk_nums = min(len(spk_start_idx), len(spk_end_idx))
294
+ spk_bounds = torch.hstack([
295
+ (spk_start_idx[:valid_spk_nums] + 1).unsqueeze(-1),
296
+ spk_end_idx[:valid_spk_nums].unsqueeze(-1)
297
+ ])
298
 
299
  return input_ids, image_bounds, audio_bounds, spk_bounds
300