Update modeling_quasrav4.py
modeling_quasrav4.py (+19, -19) CHANGED

@@ -8,7 +8,7 @@ from transformers import PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.utils import logging
 
-from .configuration_quasrav4 import QuasraV4Config
+from .configuration_quasrav4 import InfinityFormerConfig
 
 logger = logging.get_logger(__name__)
 
@@ -41,7 +41,7 @@ class RotaryPositionEmbedding(nn.Module):
         return self.apply_rotary_pos_emb(x, cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2))
 
 class KernelFunction(nn.Module):
-    def __init__(self, config: QuasraV4Config):
+    def __init__(self, config: InfinityFormerConfig):
         super().__init__()
         self.kernel_type = config.kernel_type
         self.epsilon = config.kernel_epsilon
@@ -59,7 +59,7 @@ class KernelFunction(nn.Module):
             raise ValueError(f"Unknown kernel type: {self.kernel_type}")
 
 class GatedFeedForward(nn.Module):
-    def __init__(self, config: QuasraV4Config):
+    def __init__(self, config: InfinityFormerConfig):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
@@ -80,7 +80,7 @@ class GatedFeedForward(nn.Module):
         return hidden_states + residual
 
 class LinearAttention(nn.Module):
-    def __init__(self, config: QuasraV4Config, layer_idx: int = 0):
+    def __init__(self, config: InfinityFormerConfig, layer_idx: int = 0):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
@@ -129,8 +129,8 @@ class LinearAttention(nn.Module):
 
 # --- Main Model Components ---
 
-class QuasraV4Layer(nn.Module):
-    def __init__(self, config: QuasraV4Config, layer_idx: int):
+class InfinityFormerLayer(nn.Module):
+    def __init__(self, config: InfinityFormerConfig, layer_idx: int):
         super().__init__()
         self.embed_dim = config.hidden_size
         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
@@ -145,8 +145,8 @@ class QuasraV4Layer(nn.Module):
         hidden_states = self.ffn(hidden_states)
         return (hidden_states,)
 
-class QuasraV4Embeddings(nn.Module):
-    def __init__(self, config: QuasraV4Config):
+class InfinityFormerEmbeddings(nn.Module):
+    def __init__(self, config: InfinityFormerConfig):
         super().__init__()
         self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id if hasattr(config, 'pad_token_id') else 0)
         self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
@@ -166,11 +166,11 @@ class QuasraV4Embeddings(nn.Module):
         embeddings = self.dropout(embeddings)
         return embeddings
 
-class QuasraV4PreTrainedModel(PreTrainedModel):
-    config_class = QuasraV4Config
-    base_model_prefix = "
+class InfinityFormerPreTrainedModel(PreTrainedModel):
+    config_class = InfinityFormerConfig
+    base_model_prefix = "infinity_former"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["QuasraV4Layer"]
+    _no_split_modules = ["InfinityFormerLayer"]
 
     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -183,12 +183,12 @@ class QuasraV4PreTrainedModel(PreTrainedModel):
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
 
-class QuasraV4Model(QuasraV4PreTrainedModel):
-    def __init__(self, config: QuasraV4Config):
+class InfinityFormerModel(InfinityFormerPreTrainedModel):
+    def __init__(self, config: InfinityFormerConfig):
         super().__init__(config)
         self.config = config
-        self.embeddings = QuasraV4Embeddings(config)
-        self.layers = nn.ModuleList([QuasraV4Layer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
+        self.embeddings = InfinityFormerEmbeddings(config)
+        self.layers = nn.ModuleList([InfinityFormerLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
         self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.gradient_checkpointing = False
         self.post_init()
@@ -213,12 +213,12 @@ class QuasraV4Model(QuasraV4PreTrainedModel):
            return (hidden_states,)
        return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=None, hidden_states=None, attentions=None)
 
-class QuasraV4ForCausalLM(QuasraV4PreTrainedModel):
+class InfinityFormerForCausalLM(InfinityFormerPreTrainedModel):
     _auto_class = "AutoModelForCausalLM"
 
-    def __init__(self, config: QuasraV4Config):
+    def __init__(self, config: InfinityFormerConfig):
         super().__init__(config)
-        self.model = QuasraV4Model(config)
+        self.model = InfinityFormerModel(config)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         self.post_init()
 
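
Since InfinityFormerForCausalLM sets _auto_class = "AutoModelForCausalLM", the natural way to exercise the renamed classes is through the Transformers auto classes with trust_remote_code=True. The following is a minimal loading sketch, not part of this commit: the repo id is a placeholder, and it assumes the repository's config.json maps the auto classes to InfinityFormerConfig and InfinityFormerForCausalLM.

# Loading sketch (illustrative; repo id is hypothetical).
from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "your-namespace/quasrav4"  # placeholder, not a real repository

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

print(type(config).__name__)  # expected: InfinityFormerConfig
print(type(model).__name__)   # expected: InfinityFormerForCausalLM

trust_remote_code=True is needed because both classes live in the repository's own configuration_quasrav4.py and modeling_quasrav4.py rather than inside the transformers package.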
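
Two attributes in the diff above carry follow-on behaviour worth showing: supports_gradient_checkpointing = True means the standard PreTrainedModel toggle applies, and _auto_class means save_pretrained should copy this custom code next to the weights. The sketch below assumes the same hypothetical repo id as above; the output directory is a placeholder.

# Save-side sketch (illustrative; repo id and paths are hypothetical).
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("your-namespace/quasrav4", trust_remote_code=True)

# supports_gradient_checkpointing = True exposes the standard toggle; the model's
# forward still has to act on self.gradient_checkpointing for it to take effect.
model.gradient_checkpointing_enable()

# Because _auto_class is set, save_pretrained() should also copy
# configuration_quasrav4.py / modeling_quasrav4.py and record an auto_map entry,
# keeping the checkpoint loadable with trust_remote_code=True.
model.save_pretrained("./infinityformer-checkpoint")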
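
For background on two of the building blocks renamed throughout the diff, KernelFunction and LinearAttention are named after kernelized linear attention: a positive feature map is applied to queries and keys, and associativity is used to avoid materialising the quadratic attention matrix. The sketch below is generic background with illustrative shapes and an assumed elu(x) + 1 feature map (suggested by the kernel_epsilon field in the config), not a claim about the exact implementation in this file.

# Generic (non-causal) linear-attention sketch; background only, not this repo's code.
import torch
import torch.nn.functional as F

def feature_map(x: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor:
    # A common positive feature map: elu(x) + 1, plus a small epsilon for stability.
    return F.elu(x) + 1.0 + epsilon

def linear_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q, k, v: (batch, heads, seq_len, head_dim)
    q, k = feature_map(q), feature_map(k)
    # Associativity: phi(Q) @ (phi(K)^T V) costs O(n * d^2) instead of O(n^2 * d).
    kv = torch.einsum("bhnd,bhne->bhde", k, v)
    normaliser = 1.0 / (torch.einsum("bhnd,bhd->bhn", q, k.sum(dim=2)) + 1e-6)
    return torch.einsum("bhnd,bhde,bhn->bhne", q, kv, normaliser)

q = k = v = torch.randn(1, 8, 16, 64)  # illustrative shapes
print(linear_attention(q, k, v).shape)  # torch.Size([1, 8, 16, 64])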