zhoukz committed · Commit c9ab6b8 · 1 Parent(s): dea8813

Upload folder using huggingface_hub

README.md CHANGED
@@ -51,7 +51,7 @@ base_model:
 
 >>> import torch
 >>> with torch.no_grad():
-...     model_inputs = processor(text=text, audio=audio)
+...     model_inputs = processor(text=text, audio=audio, sampling_rate=sr)
 ...     generation = model.generate(**model_inputs)
 ...     output = processor.batch_decode(generation, skip_special_tokens=True)
 
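For context, the updated README snippet now passes the audio sampling rate to the processor explicitly. A minimal end-to-end sketch of how that call might be used; the repo id, audio file, and prompt are placeholders and the `trust_remote_code` loading path is an assumption, not part of this diff:

```python
# Hypothetical usage sketch; "<repo-id>" and "example.wav" are placeholders.
import soundfile as sf
import torch
from transformers import AutoModelForCausalLM, AutoProcessor

processor = AutoProcessor.from_pretrained("<repo-id>", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("<repo-id>", trust_remote_code=True)

audio, sr = sf.read("example.wav")  # waveform plus its native sampling rate
text = "Describe this audio."

with torch.no_grad():
    # The processor now expects the sampling rate alongside the raw audio.
    model_inputs = processor(text=text, audio=audio, sampling_rate=sr)
    generation = model.generate(**model_inputs)
    output = processor.batch_decode(generation, skip_special_tokens=True)
```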
config.json CHANGED
@@ -37,15 +37,10 @@
     "AutoConfig": "configuration_midashenglm.MiAudioLLMHFConfig",
     "AutoModelForCausalLM": "modeling_midashenglm.DashengQwen25OmniModelInstruct"
   },
-  "freeze": null,
-  "gradient_checkpoint_decoder": false,
-  "lora": null,
-  "model": "DashengQwen25OmniModelInstruct",
   "model_type": "miaudiollm",
   "resize_tokenizer": false,
   "subsample_factor": 5,
-  "text_model_config": {
-    "_attn_implementation_autoset": true,
+  "text_config": {
     "attention_dropout": 0.0,
     "hidden_act": "silu",
     "hidden_size": 2048,
configuration_midashenglm.py CHANGED
@@ -1,5 +1,5 @@
 from ast import Dict
-from typing import Literal, Tuple, Union
+from typing import Tuple, Union
 
 from transformers import PretrainedConfig
 from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import (
@@ -66,22 +66,16 @@ class MiAudioLLMHFConfig(PretrainedConfig):
 
     def __init__(
         self,
-        model: str = "DashengQwen2ModelInstruct",
         audio_encoder_config: Dict = {},
-        freeze: Literal["audio", "text"] | str | None = None,
-        lora: Literal["encoder", "decoder"] | None = None,
         subsample_factor: int = 5,
-        text_model_config: Dict = None,
+        text_config: Dict = None,
         **kwargs,
     ):
-        self.model = model
         self.audio_encoder_config = DashengConfig(**audio_encoder_config)
-        self.freeze = freeze
-        self.lora = lora
         self.subsample_factor = subsample_factor
-        self.text_model_config = (
-            Qwen2_5OmniTextConfig(**text_model_config)
-            if text_model_config
+        self.text_config = (
+            Qwen2_5OmniTextConfig(**text_config)
+            if text_config
             else Qwen2_5OmniTextConfig()
         )
         super().__init__(**kwargs)
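After this change, `MiAudioLLMHFConfig` takes `text_config` as a plain dict and wraps it in `Qwen2_5OmniTextConfig` (or falls back to the default config when it is omitted). A minimal construction sketch, assuming it is run from the repo directory; the `hidden_size` value is illustrative only:

```python
from configuration_midashenglm import MiAudioLLMHFConfig

config = MiAudioLLMHFConfig(
    audio_encoder_config={},            # defaults to DashengConfig()
    subsample_factor=5,
    text_config={"hidden_size": 2048},  # wrapped into Qwen2_5OmniTextConfig
)
assert config.text_config.hidden_size == 2048
```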
model.safetensors.index.json CHANGED
@@ -390,20 +390,20 @@
     "audio_encoder.freq_pos_embed": "model-00001-of-00002.safetensors",
     "audio_encoder.front_end.0.mel_scale.fb": "model-00001-of-00002.safetensors",
     "audio_encoder.front_end.0.spectrogram.window": "model-00001-of-00002.safetensors",
-    "audio_encoder.init_bn.1.bias": "model-00001-of-00002.safetensors",
-    "audio_encoder.init_bn.1.num_batches_tracked": "model-00001-of-00002.safetensors",
-    "audio_encoder.init_bn.1.running_mean": "model-00001-of-00002.safetensors",
-    "audio_encoder.init_bn.1.running_var": "model-00001-of-00002.safetensors",
-    "audio_encoder.init_bn.1.weight": "model-00001-of-00002.safetensors",
+    "audio_encoder.init_bn.bias": "model-00001-of-00002.safetensors",
+    "audio_encoder.init_bn.num_batches_tracked": "model-00001-of-00002.safetensors",
+    "audio_encoder.init_bn.running_mean": "model-00001-of-00002.safetensors",
+    "audio_encoder.init_bn.running_var": "model-00001-of-00002.safetensors",
+    "audio_encoder.init_bn.weight": "model-00001-of-00002.safetensors",
     "audio_encoder.norm.bias": "model-00001-of-00002.safetensors",
     "audio_encoder.norm.weight": "model-00001-of-00002.safetensors",
     "audio_encoder.patch_embed.proj.bias": "model-00001-of-00002.safetensors",
     "audio_encoder.patch_embed.proj.weight": "model-00001-of-00002.safetensors",
     "audio_encoder.time_pos_embed": "model-00001-of-00002.safetensors",
-    "audio_projector.net.0.bias": "model-00002-of-00002.safetensors",
-    "audio_projector.net.0.weight": "model-00002-of-00002.safetensors",
-    "audio_projector.net.2.bias": "model-00002-of-00002.safetensors",
-    "audio_projector.net.2.weight": "model-00002-of-00002.safetensors",
+    "audio_projector.net.0.bias": "model-00001-of-00002.safetensors",
+    "audio_projector.net.0.weight": "model-00001-of-00002.safetensors",
+    "audio_projector.net.2.bias": "model-00001-of-00002.safetensors",
+    "audio_projector.net.2.weight": "model-00001-of-00002.safetensors",
     "decoder.lm_head.weight": "model-00002-of-00002.safetensors",
     "decoder.model.embed_tokens.weight": "model-00001-of-00002.safetensors",
     "decoder.model.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
@@ -442,11 +442,11 @@
     "decoder.model.layers.10.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
     "decoder.model.layers.10.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
     "decoder.model.layers.10.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
-    "decoder.model.layers.11.input_layernorm.weight": "model-00001-of-00002.safetensors",
-    "decoder.model.layers.11.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+    "decoder.model.layers.11.input_layernorm.weight": "model-00002-of-00002.safetensors",
+    "decoder.model.layers.11.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
     "decoder.model.layers.11.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
-    "decoder.model.layers.11.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
-    "decoder.model.layers.11.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+    "decoder.model.layers.11.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
+    "decoder.model.layers.11.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
     "decoder.model.layers.11.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
     "decoder.model.layers.11.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
     "decoder.model.layers.11.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
modeling_midashenglm.py CHANGED
@@ -249,21 +249,12 @@ class Block(nn.Module):
         return x
 
 
-# TODO
-class RearranceReplace(nn.Module):
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        # rearrange(x, "b c f t -> b f c t")
-        # or
-        # rearrange(x, "b f c t -> b c f t")
-        return torch.permute(x, (0, 2, 1, 3))
-
-
-class AudioTransformer(nn.Module):
-    def __init__(
-        self,
-        config: DashengConfig,
-    ):
-        super().__init__()
+class AudioTransformer(PreTrainedModel):
+    config_class = DashengConfig
+
+    def __init__(self, config: DashengConfig):
+        super().__init__(config)
+
         self.target_length = config.target_length
         self.embed_dim = config.embed_dim
         self.hop_length = config.hop_length
@@ -282,13 +273,7 @@ class AudioTransformer(nn.Module):
             audio_transforms.AmplitudeToDB(top_db=120),
         )
 
-        self.init_bn = nn.Sequential(
-            # Rearrange("b c f t -> b f c t"),
-            RearranceReplace(),
-            nn.BatchNorm2d(config.n_mels, momentum=0.01),
-            # Rearrange("b f c t -> b c f t"),
-            RearranceReplace(),
-        )
+        self.init_bn = nn.BatchNorm2d(config.n_mels, momentum=0.01)
 
         self.patch_embed = AudioPatchEmbed(
             input_size=(config.n_mels, config.target_length),
@@ -327,6 +312,8 @@ class AudioTransformer(nn.Module):
         )
         self.norm = norm_layer(config.embed_dim)
 
+        self.post_init()
+
     def forward_features(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
         t = x.shape[-1]
         x = x + self.time_pos_embed[:, :, :, :t]
@@ -357,7 +344,9 @@ class AudioTransformer(nn.Module):
         x = self.front_end(x)
         target_length_in_patches = self.target_length // 4
         x = x.unsqueeze(1)
+        x = torch.permute(x, (0, 2, 1, 3))
         x = self.init_bn(x)
+        x = torch.permute(x, (0, 2, 1, 3))
 
         x = self.patch_embed(x)
         t = x.shape[-1]
@@ -427,23 +416,21 @@ class DashengQwen25OmniModelInstructOutput(ModelOutput):
 
 class Decoder(PreTrainedModel, GenerationMixin):
     config_class = Qwen2_5OmniTextConfig
+    _supports_flash_attn_2 = Qwen2_5OmniThinkerTextModel._supports_flash_attn_2
+    _supports_sdpa = Qwen2_5OmniThinkerTextModel._supports_sdpa
+    _supports_flex_attn = Qwen2_5OmniThinkerTextModel._supports_flex_attn
+    _supports_cache_class = Qwen2_5OmniThinkerTextModel._supports_cache_class
+    _supports_static_cache = Qwen2_5OmniThinkerTextModel._supports_static_cache
+    _supports_quantized_cache = Qwen2_5OmniThinkerTextModel._supports_quantized_cache
 
     def __init__(self, config: Qwen2_5OmniTextConfig):
         super().__init__(config)
-        self.model = Qwen2_5OmniThinkerTextModel._from_config(
-            config,
-            attn_implementation="sdpa",  # TODO
-        )
+        self.model = Qwen2_5OmniThinkerTextModel._from_config(config)
         self.lm_head = nn.Linear(
             config.hidden_size,
             config.vocab_size,
             bias=False,
         )
-        # TODO fix dtype
-        self.lm_head.weight.data = self.lm_head.weight.data.to(
-            self.model.embed_tokens.weight.dtype
-        )
-        # TODO tie weight?
         self.post_init()
 
     def forward(
@@ -481,30 +468,25 @@ class Decoder(PreTrainedModel, GenerationMixin):
 
 class DashengQwen25OmniModelInstruct(PreTrainedModel):
     config_class = MiAudioLLMHFConfig
+    _supports_flash_attn_2 = Qwen2_5OmniThinkerTextModel._supports_flash_attn_2
+    _supports_sdpa = Qwen2_5OmniThinkerTextModel._supports_sdpa
+    _supports_flex_attn = Qwen2_5OmniThinkerTextModel._supports_flex_attn
+    _supports_cache_class = Qwen2_5OmniThinkerTextModel._supports_cache_class
+    _supports_static_cache = Qwen2_5OmniThinkerTextModel._supports_static_cache
+    _supports_quantized_cache = Qwen2_5OmniThinkerTextModel._supports_quantized_cache
 
     def __init__(self, config: MiAudioLLMHFConfig):
        super().__init__(config)
 
-        freeze = config.freeze
-        lora = config.lora
-        subsample_factor = config.subsample_factor
-
-        self.subsample_factor = subsample_factor
-        self.lora = lora
-        # Encoder part
-        self.audio_encoder = AudioTransformer(config.audio_encoder_config)
-        assert lora != "encoder"
-
-        # decoder
-        self.decoder = Decoder(config.text_model_config)
-        assert lora != "decoder"
-        assert freeze is None
-
-        # audio projector
+        self.audio_encoder = AudioTransformer._from_config(config.audio_encoder_config)
         self.audio_projector = AudioProjectorSubsample(
             self.audio_encoder.embed_dim,
-            config.text_model_config.hidden_size,
-            self.subsample_factor,
+            config.text_config.hidden_size,
+            config.subsample_factor,
+        )
+        self.decoder = Decoder._from_config(
+            config.text_config,
+            attn_implementation=config._attn_implementation,
        )
 
         self.post_init()
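Two consequences of this refactor are worth spelling out: the `RearranceReplace` wrapper is replaced by explicit `torch.permute` calls around `BatchNorm2d`, so the batch norm still runs over the mel dimension, and the forwarded `_supports_*` flags plus `Decoder._from_config(..., attn_implementation=...)` let a requested attention backend propagate to the text decoder at load time. A standalone sketch of the permute/batch-norm pattern; the batch size, `n_mels`, and time length are illustrative assumptions:

```python
import torch
import torch.nn as nn

# Mirror of the new init_bn path: BatchNorm2d over n_mels, applied by permuting
# (batch, channel, mel, time) -> (batch, mel, channel, time) and back.
n_mels = 64  # illustrative value
init_bn = nn.BatchNorm2d(n_mels, momentum=0.01)

x = torch.randn(2, 1, n_mels, 101)   # (batch, channel, mel, time), shapes illustrative
x = torch.permute(x, (0, 2, 1, 3))   # -> (batch, mel, channel, time)
x = init_bn(x)                       # normalize per mel bin
x = torch.permute(x, (0, 2, 1, 3))   # -> back to (batch, channel, mel, time)
```

With the support flags in place, a loader can presumably request a backend such as `attn_implementation="sdpa"` (or flash attention, where installed) in `from_pretrained`, rather than having "sdpa" hard-coded in the decoder as before.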
processing_midashenglm.py CHANGED
@@ -55,32 +55,35 @@ class MiAudioLLMProcessor(ProcessorMixin):
         tokenizer: Qwen2Tokenizer | Qwen2TokenizerFast | None = None,
         model_subsampling: int = 5,
         chat_template: str | None = None,
-        # TODO can this be removed?
-        audio_token: str = "<|AUDIO|>",
-        audio_bos_token: str = "<|audio_bos|>",
-        audio_eos_token: str = "<|audio_eos|>",
+        audio_token: str | None = None,
+        audio_bos_token: str | None = None,
+        audio_eos_token: str | None = None,
     ):
-        if chat_template is None:
-            chat_template = self.default_chat_template
         assert tokenizer is not None, "Tokenizer Needs to be passed"
-        self.audio_token = (
-            tokenizer.audio_token if hasattr(tokenizer, "audio_token") else audio_token
+        assert audio_token is not None or hasattr(tokenizer, "audio_token"), (
+            "Either `audio_token` must be provided or tokenizer must have `audio_token` attribute."
         )
-        self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)
-        self.audio_bos_token = (
-            tokenizer.audio_bos_token
-            if hasattr(tokenizer, "audio_bos_token")
-            else audio_bos_token
+        assert audio_bos_token is not None or hasattr(tokenizer, "audio_bos_token"), (
+            "Either `audio_bos_token` must be provided or tokenizer must have `audio_bos_token` attribute."
         )
-        self.audio_eos_token = (
-            tokenizer.audio_eos_token
-            if hasattr(tokenizer, "audio_eos_token")
-            else audio_eos_token
+        assert audio_eos_token is not None or hasattr(tokenizer, "audio_eos_token"), (
+            "Either `audio_eos_token` must be provided or tokenizer must have `audio_eos_token` attribute."
         )
+
+        if chat_template is None:
+            chat_template = self.default_chat_template
+
+        self.audio_token: str = audio_token or tokenizer.audio_token
+        self.audio_bos_token = audio_bos_token or tokenizer.audio_bos_token
+        self.audio_eos_token = audio_eos_token or tokenizer.audio_eos_token
+        self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)
         self.model_subsampling = model_subsampling
-        # Fix Normalization
-        if feature_extractor is not None and feature_extractor.do_normalize is True:
-            feature_extractor.do_normalize = False
+
+        if feature_extractor is not None:
+            assert not feature_extractor.do_normalize, (
+                "This model does not use normalization. Please set `do_normalize=False` in the feature extractor."
+            )
+
         super().__init__(feature_extractor, tokenizer, chat_template=chat_template)
 
     def __call__(
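With the new constructor, the audio marker tokens must come either from the tokenizer or be passed explicitly, and the feature extractor must already have normalization disabled (the processor no longer silently flips `do_normalize` off). A hedged construction sketch; the token strings mirror the old defaults removed from the signature, and passing `feature_extractor` first matches the `super().__init__` call shown above but is otherwise an assumption:

```python
# Illustrative: pass the tokens explicitly if the tokenizer does not define
# audio_token / audio_bos_token / audio_eos_token attributes.
processor = MiAudioLLMProcessor(
    feature_extractor=feature_extractor,  # must already have do_normalize=False
    tokenizer=tokenizer,
    audio_token="<|AUDIO|>",
    audio_bos_token="<|audio_bos|>",
    audio_eos_token="<|audio_eos|>",
)
```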