| | """ PyTorch Wav2Vec2-Ebranchformer model.""" |
| |
|
| | from typing import Optional |
| |
|
| | import torch |
| | import torch.utils.checkpoint |
| | from torch import nn |
| | from transformers.activations import ACT2FN |
| | from transformers.models.wav2vec2.modeling_wav2vec2 import ( |
| | Wav2Vec2Config, |
| | Wav2Vec2ForCTC, |
| | Wav2Vec2ForPreTraining, |
| | ) |
| | from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import ( |
| | Wav2Vec2ConformerConfig, |
| | Wav2Vec2ConformerEncoder, |
| | ) |
| | from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import ( |
| | Wav2Vec2ConformerFeedForward as Wav2Vec2EBranchformerFeedForward, |
| | ) |
| | from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import ( |
| | Wav2Vec2ConformerModel, |
| | ) |
| | from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import ( |
| | Wav2Vec2ConformerSelfAttention as Wav2Vec2EBranchformerSelfAttention, |
| | ) |
| | from transformers.utils import logging |
| |
|
| | logger = logging.get_logger(__name__) |
| |
|
| |
|
class Wav2Vec2EBranchformerConfig(Wav2Vec2ConformerConfig, Wav2Vec2Config):
    """Config for the E-Branchformer model, extending the Conformer config.

    Adds the hyper-parameters of the E-Branchformer encoder layer: the
    convolutional spatial gating unit (CSGU) inside the cgMLP local branch,
    the depthwise-convolution branch merging, and the optional macaron-style
    feed-forward modules. All other options are inherited from
    ``Wav2Vec2ConformerConfig`` / ``Wav2Vec2Config`` via ``**kwargs``.
    """

    model_type = "wav2vec2-ebranchformer"

    def __init__(
        self,
        ebranchformer_conv_dropout=0.1,
        csgu_activation="identity",
        csgu_kernel_size=31,
        csgu_use_linear_after_conv=False,
        merge_conv_kernel=31,
        use_macaron_ff=True,
        **kwargs,
    ):
        """
        Args:
            ebranchformer_conv_dropout: dropout probability applied after the
                CSGU gating (stored as ``csgu_conv_dropout``).
            csgu_activation: activation applied to the gate branch; the string
                "identity" selects ``torch.nn.Identity``, anything else is
                looked up in ``ACT2FN``.
            csgu_kernel_size: kernel size of the depthwise conv in the CSGU.
            csgu_use_linear_after_conv: whether the CSGU applies an extra
                linear layer after its depthwise convolution.
            merge_conv_kernel: kernel size of the depthwise conv used to merge
                the global (attention) and local (cgMLP) branches.
            use_macaron_ff: whether each encoder layer wraps the two branches
                with two half-step macaron feed-forward modules.
            **kwargs: forwarded to the parent config constructors.
        """
        super().__init__(**kwargs)
        self.csgu_kernel_size = csgu_kernel_size
        self.csgu_activation = csgu_activation
        # NOTE: exposed to callers as `ebranchformer_conv_dropout` but stored
        # under the attribute name the CSGU module actually reads.
        self.csgu_conv_dropout = ebranchformer_conv_dropout
        self.csgu_use_linear_after_conv = csgu_use_linear_after_conv
        self.merge_conv_kernel = merge_conv_kernel
        self.use_macaron_ff = use_macaron_ff
| |
|
| |
|
class ConvolutionalSpatialGatingUnit(torch.nn.Module):
    """Convolutional Spatial Gating Unit (CSGU).

    Splits the input channels into two halves, processes the second half
    (layer norm, depthwise temporal conv, optional linear, activation) and
    uses it to gate the first half element-wise. Output has half the input
    channel count.
    """

    def __init__(self, config: "Wav2Vec2EBranchformerConfig"):
        super().__init__()
        half_channels = config.intermediate_size // 2

        self.norm = torch.nn.LayerNorm(half_channels)
        # Depthwise (groups == channels) conv over time, stride 1, with
        # "same" padding so the sequence length is preserved.
        self.conv = torch.nn.Conv1d(
            half_channels,
            half_channels,
            config.csgu_kernel_size,
            1,
            (config.csgu_kernel_size - 1) // 2,
            groups=half_channels,
        )
        self.linear = (
            torch.nn.Linear(half_channels, half_channels)
            if config.csgu_use_linear_after_conv
            else None
        )

        if config.csgu_activation == "identity":
            self.act = torch.nn.Identity()
        else:
            self.act = ACT2FN[config.csgu_activation]

        self.dropout = torch.nn.Dropout(config.csgu_conv_dropout)

    def forward(self, hidden_states: torch.FloatTensor):
        """Gate the first channel half by the processed second half.

        Args:
            hidden_states (torch.Tensor): (N, T, D)

        Returns:
            out (torch.Tensor): (N, T, D/2)
        """
        residual_half, gate_half = hidden_states.chunk(2, dim=-1)

        gate_half = self.norm(gate_half)
        # Conv1d expects (N, C, T); transpose around the depthwise conv.
        gate_half = self.conv(gate_half.transpose(1, 2)).transpose(1, 2)
        if self.linear is not None:
            gate_half = self.linear(gate_half)
        gate_half = self.act(gate_half)

        return self.dropout(residual_half * gate_half)
| |
|
| |
|
class ConvolutionalGatingMLP(torch.nn.Module):
    """Convolutional Gating MLP (cgMLP): up-project, gate, down-project."""

    def __init__(self, config: "Wav2Vec2EBranchformerConfig"):
        super().__init__()
        # hidden_size -> intermediate_size with a GELU non-linearity.
        self.channel_proj1 = torch.nn.Sequential(
            torch.nn.Linear(config.hidden_size, config.intermediate_size),
            torch.nn.GELU(),
        )
        self.csgu = ConvolutionalSpatialGatingUnit(config)
        # The CSGU halves the channel count, so project D/2 back to hidden_size.
        self.channel_proj2 = torch.nn.Linear(
            config.intermediate_size // 2, config.hidden_size
        )

    def forward(self, hidden_states: torch.FloatTensor):
        projected = self.channel_proj1(hidden_states)
        gated = self.csgu(projected)
        return self.channel_proj2(gated)
| |
|
| |
|
class Wav2Vec2EBranchformerEncoderLayer(nn.Module):
    """E-Branchformer encoder layer.

    Runs a global branch (pre-norm self-attention) and a local branch
    (pre-norm cgMLP) in parallel, merges them with a depthwise-convolution
    fusion plus linear projection, and optionally wraps the whole thing in
    two half-step macaron feed-forward modules.
    """

    def __init__(self, config: "Wav2Vec2EBranchformerConfig"):
        super().__init__()
        embed_dim = config.hidden_size
        dropout = config.attention_dropout

        # Optional first macaron feed-forward (applied with a 0.5 residual
        # scale in `forward`). Always assign the attribute so `forward` can
        # test it safely — previously it was only created when
        # `use_macaron_ff` was True, and `forward` raised AttributeError
        # otherwise.
        self.ff1 = (
            nn.Sequential(nn.LayerNorm(embed_dim), Wav2Vec2EBranchformerFeedForward(config))
            if config.use_macaron_ff
            else None
        )

        # Global branch: self-attention.
        self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
        self.self_attn_dropout = torch.nn.Dropout(dropout)
        self.self_attn = Wav2Vec2EBranchformerSelfAttention(config)

        # Local branch: convolutional gating MLP.
        # NOTE(review): `cgMLP_dropout` is constructed but never applied in
        # `forward`; kept for state-dict compatibility.
        self.cgMLP = ConvolutionalGatingMLP(config)
        self.cgMLP_layer_norm = nn.LayerNorm(config.hidden_size)
        self.cgMLP_dropout = torch.nn.Dropout(dropout)

        # Branch merging: concat (2 * embed_dim), depthwise conv fusion with
        # a residual, then a linear projection back to embed_dim.
        self.final_dropout = torch.nn.Dropout(dropout)
        self.merge_proj = torch.nn.Linear(embed_dim + embed_dim, embed_dim)
        self.depthwise_conv_fusion = torch.nn.Conv1d(
            embed_dim + embed_dim,
            embed_dim + embed_dim,
            kernel_size=config.merge_conv_kernel,
            stride=1,
            padding=(config.merge_conv_kernel - 1) // 2,
            groups=embed_dim + embed_dim,
            bias=True,
        )
        self.final_layer_norm = nn.LayerNorm(embed_dim)

        # Optional second macaron feed-forward (same fix as `ff1`).
        self.ff2 = (
            nn.Sequential(nn.LayerNorm(embed_dim), Wav2Vec2EBranchformerFeedForward(config))
            if config.use_macaron_ff
            else None
        )

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.Tensor] = None,
        relative_position_embeddings: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ):
        """Run one E-Branchformer encoder layer.

        Args:
            hidden_states: (batch, time, hidden_size) input features —
                shape assumed from the LayerNorm/attention usage; confirm
                against the caller.
            attention_mask: optional mask forwarded to self-attention.
            relative_position_embeddings: optional relative position
                embeddings forwarded to self-attention.
            output_attentions: whether self-attention should return its
                attention weights.

        Returns:
            Tuple of (hidden_states, attention_weights).
        """
        # 1. Optional macaron feed-forward, half-step residual.
        if self.ff1 is not None:
            residual = hidden_states
            hidden_states = residual + 0.5 * self.ff1(hidden_states)

        residual = hidden_states
        global_branch = hidden_states
        local_branch = hidden_states

        # 2. Global branch: pre-norm self-attention.
        global_branch = self.self_attn_layer_norm(global_branch)
        global_branch, attn_weights = self.self_attn(
            hidden_states=global_branch,
            attention_mask=attention_mask,
            relative_position_embeddings=relative_position_embeddings,
            output_attentions=output_attentions,
        )
        global_branch = self.self_attn_dropout(global_branch)

        # 3. Local branch: pre-norm cgMLP.
        local_branch = self.cgMLP_layer_norm(local_branch)
        local_branch = self.cgMLP(local_branch)

        # 4. Merge branches: concat, residual depthwise conv, project back.
        hidden_states = torch.cat([global_branch, local_branch], dim=-1)
        merge_residual = hidden_states
        hidden_states = merge_residual + self.depthwise_conv_fusion(
            hidden_states.transpose(1, 2)
        ).transpose(1, 2)
        hidden_states = self.final_dropout(self.merge_proj(hidden_states))

        # 5. Residual connection around both branches.
        hidden_states = residual + hidden_states

        # 6. Optional second macaron feed-forward, half-step residual.
        if self.ff2 is not None:
            residual = hidden_states
            hidden_states = residual + 0.5 * self.ff2(hidden_states)

        hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states, attn_weights
| |
|
| |
|
class Wav2Vec2EBranchformerEncoder(Wav2Vec2ConformerEncoder):
    """Conformer encoder with its layers swapped for E-Branchformer layers."""

    def __init__(self, config: "Wav2Vec2EBranchformerConfig"):
        super().__init__(config)
        # Replace the Conformer layers built by the parent constructor.
        self.layers = nn.ModuleList(
            Wav2Vec2EBranchformerEncoderLayer(config)
            for _ in range(config.num_hidden_layers)
        )
        # NOTE(review): the parent's convolutional positional embedding is
        # disabled here — presumably positions come from the relative
        # position embeddings instead; confirm the parent's `forward`
        # tolerates `pos_conv_embed` being None.
        self.pos_conv_embed = None
| |
|
| |
|
class Wav2Vec2EBranchformerModel(Wav2Vec2ConformerModel):
    """Wav2Vec2 Conformer model whose encoder is an E-Branchformer encoder."""

    def __init__(self, config: "Wav2Vec2EBranchformerConfig"):
        super().__init__(config)
        # Swap the parent's Conformer encoder for the E-Branchformer one.
        self.encoder = Wav2Vec2EBranchformerEncoder(config)

        # Re-run weight init / final processing over the replaced modules.
        self.post_init()
| |
|
| |
|
class Wav2Vec2EBranchformerForPreTraining(Wav2Vec2ForPreTraining):
    """Wav2Vec2 pre-training head on top of the E-Branchformer backbone."""

    base_model_prefix = "wav2vec2"
    config_class = Wav2Vec2EBranchformerConfig

    def __init__(self, config: "Wav2Vec2EBranchformerConfig"):
        super().__init__(config)
        # Swap the backbone built by the parent for the E-Branchformer model.
        self.wav2vec2 = Wav2Vec2EBranchformerModel(config)
        self.post_init()
| |
|
| |
|
class Wav2Vec2EBranchformerForCTC(Wav2Vec2ForCTC):
    """CTC head on top of the E-Branchformer backbone."""

    base_model_prefix = "wav2vec2"
    config_class = Wav2Vec2EBranchformerConfig

    def __init__(self, config: "Wav2Vec2EBranchformerConfig"):
        super().__init__(config)
        # Swap the backbone built by the parent for the E-Branchformer model.
        self.wav2vec2 = Wav2Vec2EBranchformerModel(config)
        self.post_init()
| |
|