Commit 516b05d
Parent(s): 9b0fc76
support audio finetuning (#22)
- support audio finetuning (29a824ea848aa2a72b26706a1c641bc339a51869)
Co-authored-by: Zhangchi Feng <[email protected]>
Files changed:
- modeling_minicpmo.py (+14 -1)
modeling_minicpmo.py CHANGED

@@ -466,7 +466,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         else:
             return []
 
-    def get_audio_embedding(self, data, chunk_length=-1):
+    def get_audio_embedding(self, data, chunk_length=-1, dummy=True):
         r"""
         Extract full audio embeddings with optional chunk-based attention.
 
@@ -484,6 +484,8 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
         Returns:
             List[List[torch.Tensor]]: audio embeddings
         """
+        dtype = self.apm.embed_positions.weight.dtype
+        device = self.apm.embed_positions.weight.device
 
         wavforms = data.get("audio_features", [])  # (bs, 80, frames) or [], multi audios need filled in advance
         audio_feature_lens_raw = data.get("audio_feature_lens", [])  # list, [[x1, x2], [y1], [z1]]
@@ -544,6 +546,17 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
                     idx += 1
                 final_audio_embeds.append(target_audio_embeds)
             return final_audio_embeds
+        elif self.training and dummy:
+            dummy_wavs = torch.zeros((1, 80, 100), device=device, dtype=dtype)
+            audio_states = self.apm(dummy_wavs, output_hidden_states=True).hidden_states[self.audio_encoder_layer]
+
+            audio_embeds = self.audio_projection_layer(audio_states)
+
+            audio_embeds = audio_embeds.transpose(1, 2)
+            audio_embeds = self.audio_avg_pooler(audio_embeds)
+            audio_embeds = audio_embeds.transpose(1, 2)
+            return [audio_embeds]
+
         else:
             return []
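Why the diff adds this branch: previously, a finetuning batch with no audio fell straight through to `return []`, so the audio tower (`apm`, `audio_projection_layer`, `audio_avg_pooler`) never entered the autograd graph for that step. Data-parallel training (e.g. PyTorch DDP with its default settings) expects every parameter to receive a gradient each step, so mixed text/audio batches could error or desync. The new `elif self.training and dummy:` path pushes a silent 100-frame dummy mel input through the same encode/project/pool pipeline to keep those parameters live. The commit message does not state this rationale, so treat it as an inference from the code. Below is a minimal, self-contained sketch of the pattern; `ToyAudioBranch` and its layer sizes are hypothetical stand-ins for MiniCPM-o's actual modules, and only the three-way control flow mirrors the diff.

# Minimal sketch of the dummy-forward pattern added in this commit, assuming a
# toy Conv1d encoder in place of MiniCPM-o's `apm` (Whisper-style audio tower).
# `ToyAudioBranch` and its sizes are hypothetical; only the control flow
# (real path / training-time dummy path / empty path) mirrors the diff above.
import torch
import torch.nn as nn

class ToyAudioBranch(nn.Module):
    def __init__(self, n_mels=80, hidden=64):
        super().__init__()
        self.encoder = nn.Conv1d(n_mels, hidden, kernel_size=3, padding=1)
        self.projection = nn.Linear(hidden, hidden)
        self.avg_pooler = nn.AvgPool1d(kernel_size=2, stride=2)

    def _encode(self, wavforms):
        states = self.encoder(wavforms).transpose(1, 2)   # (bs, frames, hidden)
        embeds = self.projection(states)                  # (bs, frames, hidden)
        embeds = self.avg_pooler(embeds.transpose(1, 2))  # pool over frames
        return embeds.transpose(1, 2)                     # (bs, frames/2, hidden)

    def get_audio_embedding(self, wavforms, dummy=True):
        if len(wavforms) > 0:
            # Real audio in the batch: encode it as usual.
            return [self._encode(wavforms)]
        elif self.training and dummy:
            # No audio this step: run a silent dummy input through the branch
            # so every audio parameter still enters the autograd graph.
            # Without this, DDP-style training can error on steps whose
            # batches happen to contain no audio.
            p = self.projection.weight
            dummy_wavs = torch.zeros((1, 80, 100), device=p.device, dtype=p.dtype)
            return [self._encode(dummy_wavs)]
        else:
            return []

# Usage: in training mode an all-text batch still exercises the audio branch;
# in eval mode (or with dummy=False) it is skipped entirely.
branch = ToyAudioBranch().train()
print(branch.get_audio_embedding([])[0].shape)  # torch.Size([1, 50, 64])
branch.eval()
print(branch.get_audio_embedding([]))           # []

Note that the dummy path returns the placeholder embedding rather than discarding it; presumably the caller folds it in with zero contribution to the loss, but that wiring lives outside this hunk. Passing dummy=False restores the pre-commit behavior of returning an empty list.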