Handle truncated image boundaries in `_convert` to avoid tensor size mismatch
Pull request #54, opened by maikezu
Changed file: processing_minicpmo.py (+12 -5)
@@ -269,7 +269,7 @@ class MiniCPMOProcessor(ProcessorMixin):
|
|
269 |
image_start_idx += 1
|
270 |
image_end_idx = torch.where(end_cond)[0]
|
271 |
|
272 |
-
valid_image_nums =
|
273 |
|
274 |
image_bounds = torch.hstack(
|
275 |
[
|
@@ -278,16 +278,23 @@ class MiniCPMOProcessor(ProcessorMixin):
|
|
278 |
]
|
279 |
)
|
280 |
|
|
|
281 |
## audio bound
|
282 |
audio_start_idx = torch.where(input_ids == self.tokenizer.audio_start_id)[0]
|
283 |
audio_end_idx = torch.where(input_ids == self.tokenizer.audio_end_id)[0]
|
284 |
-
|
285 |
-
audio_bounds = torch.hstack([
|
|
|
|
|
|
|
286 |
|
287 |
spk_start_idx = torch.where(input_ids == self.tokenizer.spk_start_id)[0]
|
288 |
spk_end_idx = torch.where(input_ids == self.tokenizer.spk_end_id)[0]
|
289 |
-
|
290 |
-
spk_bounds = torch.hstack([
|
|
|
|
|
|
|
291 |
|
292 |
return input_ids, image_bounds, audio_bounds, spk_bounds
|
293 |
|
|
|
269 |
image_start_idx += 1
|
270 |
image_end_idx = torch.where(end_cond)[0]
|
271 |
|
272 |
+
valid_image_nums = min(len(image_start_idx), len(image_end_idx))
|
273 |
|
274 |
image_bounds = torch.hstack(
|
275 |
[
|
|
|
278 |
]
|
279 |
)
|
280 |
|
281 |
+
|
282 |
## audio bound
|
283 |
audio_start_idx = torch.where(input_ids == self.tokenizer.audio_start_id)[0]
|
284 |
audio_end_idx = torch.where(input_ids == self.tokenizer.audio_end_id)[0]
|
285 |
+
valid_audio_nums = min(len(audio_start_idx), len(audio_end_idx))
|
286 |
+
audio_bounds = torch.hstack([
|
287 |
+
(audio_start_idx[:valid_audio_nums] + 1).unsqueeze(-1),
|
288 |
+
audio_end_idx[:valid_audio_nums].unsqueeze(-1)
|
289 |
+
])
|
290 |
|
291 |
spk_start_idx = torch.where(input_ids == self.tokenizer.spk_start_id)[0]
|
292 |
spk_end_idx = torch.where(input_ids == self.tokenizer.spk_end_id)[0]
|
293 |
+
valid_spk_nums = min(len(spk_start_idx), len(spk_end_idx))
|
294 |
+
spk_bounds = torch.hstack([
|
295 |
+
(spk_start_idx[:valid_spk_nums] + 1).unsqueeze(-1),
|
296 |
+
spk_end_idx[:valid_spk_nums].unsqueeze(-1)
|
297 |
+
])
|
298 |
|
299 |
return input_ids, image_bounds, audio_bounds, spk_bounds
|
300 |
|