Update raven_modeling_minimal.py
raven_modeling_minimal.py  (+14 -13)  CHANGED
@@ -23,6 +23,16 @@ import torch.nn.functional as F
 from transformers import GenerationConfig
 
 
+@cache
+def _init_func(dim, num_layers) -> dict[str, float]:
+    return {
+        "std": math.sqrt(2 / (5 * dim)),
+        "out_proj": math.sqrt(2 / (5 * dim)) / math.sqrt(2 * num_layers),
+        "embedding": math.sqrt(2 / (5 * dim)),
+        "embed_scale": math.sqrt(dim),
+    }
+
+
 class RavenPreTrainedModel(PreTrainedModel):
     config_class = RavenConfig
     base_model_prefix = "model"
@@ -37,18 +47,9 @@ class RavenPreTrainedModel(PreTrainedModel):
     _supports_static_cache = True
     _tp_plan = {}
 
-    @cache
-    def _init_func(self, dim, num_layers):
-        return {
-            "std": math.sqrt(2 / (5 * dim)),
-            "out_proj": math.sqrt(2 / (5 * dim)) / math.sqrt(2 * num_layers),
-            "embedding": math.sqrt(2 / (5 * dim)),
-            "embed_scale": math.sqrt(dim),
-        }
-
     @property
     def emb_scale(self):
-        return
+        return _init_func(self.config.n_embd, self.config.effective_expected_depth)["embed_scale"]
 
     def _normal_(self, tensor, std):
         return torch.nn.init.trunc_normal_(tensor, mean=0.0, std=std, a=-3 * std, b=3 * std)
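Aside, not part of the commit: functools.cache on an instance method includes self in the cache key, so every model instance gets its own cached entries and stays referenced by the cache, while the module-level function above is keyed only on (dim, num_layers). A minimal sketch with hypothetical names:

from functools import cache

class ModelSketch:
    # As a method, the cache key is (self, dim, num_layers): one entry per instance,
    # and the cache keeps a reference to every instance it has seen.
    @cache
    def _init_func_method(self, dim, num_layers):
        return {"std": (2 / (5 * dim)) ** 0.5}

# At module level, the cache key is just (dim, num_layers): one entry shared by all callers.
@cache
def _init_func_sketch(dim, num_layers):
    return {"std": (2 / (5 * dim)) ** 0.5}

a, b = ModelSketch(), ModelSketch()
a._init_func_method(4096, 32)   # new cache entry keyed on `a`
b._init_func_method(4096, 32)   # a second entry keyed on `b`
assert _init_func_sketch(4096, 32) is _init_func_sketch(4096, 32)  # single shared entry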
@@ -86,7 +87,7 @@ class RavenPreTrainedModel(PreTrainedModel):
 
     @torch.no_grad()
     def _init_weights(self, module):
-        _init_values =
+        _init_values = _init_func(self.config.n_embd, self.config.effective_expected_depth)
         name = self._full_name_of_module_lookup[id(module)]
         if isinstance(module, RMSNorm):
             torch.nn.init.ones_(module.weight)
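For a rough sense of the numbers that _init_weights reads from this table, here is the same arithmetic with illustrative sizes (not this model's actual n_embd or effective_expected_depth):

import math

dim, num_layers = 4096, 32   # illustrative only

init = {
    "std": math.sqrt(2 / (5 * dim)),                                   # ~0.00988
    "out_proj": math.sqrt(2 / (5 * dim)) / math.sqrt(2 * num_layers),  # ~0.00988 / 8 ~ 0.00124
    "embedding": math.sqrt(2 / (5 * dim)),                             # ~0.00988
    "embed_scale": math.sqrt(dim),                                     # = 64.0
}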
@@ -703,14 +704,14 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
             loss = torch.nn.functional.cross_entropy(
                 logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=-100
             )
-            log_ppl = loss.clone().detach()
+            log_ppl = loss.clone().detach()
         else:
             logits = self.lm_head(x).float()
             loss, log_ppl = torch.as_tensor(0.0), torch.as_tensor(0.0)
 
         return CausalLMOutputRecurrentLatents(
             loss=loss,
-            log_ppl=log_ppl,
+            log_ppl=log_ppl,  # this value is returned only for compatibility reasons. For this model loss=log-ppl
             logits=logits if output_details["return_logits"] else None,
             past_key_values=past_key_values,
             hidden_states=x if output_details["return_head"] else None,
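Usage sketch, not taken from the repository: assuming the forward pass accepts input_ids and labels as in the branch above, the returned fields can be consumed like this; since log_ppl is a detached copy of the cross-entropy loss for this model, perplexity can be read off either value.

# Hypothetical names: `model` is a loaded RavenForCausalLM, `input_ids` a batch of token ids.
outputs = model(input_ids=input_ids, labels=input_ids)

loss = outputs.loss          # attached to the autograd graph, used for backward()
log_ppl = outputs.log_ppl    # detached copy of the same value (loss == log-ppl here)
perplexity = log_ppl.exp()   # convenient scalar for logging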