JonasGeiping committed on
Commit bb6621b · verified · 1 Parent(s): 06ac94c

Update raven_modeling_minimal.py

Files changed (1)
  1. raven_modeling_minimal.py +14 -13
raven_modeling_minimal.py CHANGED
@@ -23,6 +23,16 @@ import torch.nn.functional as F
 from transformers import GenerationConfig
 
 
+@cache
+def _init_func(dim, num_layers) -> dict[str, float]:
+    return {
+        "std": math.sqrt(2 / (5 * dim)),
+        "out_proj": math.sqrt(2 / (5 * dim)) / math.sqrt(2 * num_layers),
+        "embedding": math.sqrt(2 / (5 * dim)),
+        "embed_scale": math.sqrt(dim),
+    }
+
+
 class RavenPreTrainedModel(PreTrainedModel):
     config_class = RavenConfig
     base_model_prefix = "model"
@@ -37,18 +47,9 @@ class RavenPreTrainedModel(PreTrainedModel):
     _supports_static_cache = True
     _tp_plan = {}
 
-    @cache
-    def _init_func(self, dim, num_layers):
-        return {
-            "std": math.sqrt(2 / (5 * dim)),
-            "out_proj": math.sqrt(2 / (5 * dim)) / math.sqrt(2 * num_layers),
-            "embedding": math.sqrt(2 / (5 * dim)),
-            "embed_scale": math.sqrt(dim),
-        }
-
     @property
     def emb_scale(self):
-        return self._init_func(self.config.n_embd, self.config.effective_expected_depth)["embed_scale"]
+        return _init_func(self.config.n_embd, self.config.effective_expected_depth)["embed_scale"]
 
     def _normal_(self, tensor, std):
         return torch.nn.init.trunc_normal_(tensor, mean=0.0, std=std, a=-3 * std, b=3 * std)
@@ -86,7 +87,7 @@ class RavenPreTrainedModel(PreTrainedModel):
 
     @torch.no_grad()
     def _init_weights(self, module):
-        _init_values = self._init_func(self.config.n_embd, self.config.effective_expected_depth)
+        _init_values = _init_func(self.config.n_embd, self.config.effective_expected_depth)
         name = self._full_name_of_module_lookup[id(module)]
         if isinstance(module, RMSNorm):
             torch.nn.init.ones_(module.weight)
@@ -703,14 +704,14 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
             loss = torch.nn.functional.cross_entropy(
                 logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=-100
            )
-            log_ppl = loss.clone().detach().exp()
+            log_ppl = loss.clone().detach()
         else:
             logits = self.lm_head(x).float()
             loss, log_ppl = torch.as_tensor(0.0), torch.as_tensor(0.0)
 
         return CausalLMOutputRecurrentLatents(
             loss=loss,
-            log_ppl=log_ppl,
+            log_ppl=log_ppl,  # this value is returned only for compatibility reasons. For this model loss=log-ppl
             logits=logits if output_details["return_logits"] else None,
             past_key_values=past_key_values,
             hidden_states=x if output_details["return_head"] else None,
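
A minimal, self-contained sketch (not part of the commit) of the two behaviors the diff touches. First, with _init_func moved to module level, the functools cache is keyed only on the plain (dim, num_layers) arguments; one likely motivation is that @cache on an instance method also stores self in the cache and so keeps the model instance alive. Second, the cross-entropy loss returned by F.cross_entropy is already the mean negative log-likelihood, i.e. the log-perplexity, so log_ppl is reported as the detached loss rather than its exponential (which would be the perplexity itself). The tensors fake_logits/fake_labels and the example sizes below are invented for illustration.

# Sketch only; names and sizes are illustrative, not taken from the repository.
import math
from functools import cache

import torch
import torch.nn.functional as F


@cache  # module-level: the cache key is just (dim, num_layers); no model instance is retained
def _init_func(dim, num_layers) -> dict[str, float]:
    return {
        "std": math.sqrt(2 / (5 * dim)),
        "out_proj": math.sqrt(2 / (5 * dim)) / math.sqrt(2 * num_layers),
        "embedding": math.sqrt(2 / (5 * dim)),
        "embed_scale": math.sqrt(dim),
    }


# Same arguments -> the cached dict is reused across repeated _init_weights calls.
assert _init_func(4096, 64) is _init_func(4096, 64)

# Cross-entropy is already the mean log-perplexity; exponentiating it gives the perplexity.
fake_logits = torch.randn(8, 50_000)
fake_labels = torch.randint(0, 50_000, (8,))
loss = F.cross_entropy(fake_logits, fake_labels, ignore_index=-100)
log_ppl = loss.clone().detach()  # what the commit now returns as log_ppl
ppl = log_ppl.exp()              # what the old `.exp()` value actually represented
print(f"log-ppl: {log_ppl:.3f}, ppl: {ppl:.3f}")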