Mayank Mishra
commited on
Commit
·
6bb0180
1
Parent(s):
448e236
update script
Browse files- modeling_granite.py +12 -10
modeling_granite.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
|
|
|
| 1 |
import numbers
|
|
|
|
| 2 |
from enum import Enum
|
| 3 |
from typing import Optional, Tuple, Union
|
| 4 |
|
|
@@ -846,7 +848,7 @@ class GranitePreTrainedModel(PreTrainedModel):
|
|
| 846 |
self.initializer_range = config.initializer_range
|
| 847 |
|
| 848 |
def _init_weights(self, module: nn.Module) -> None:
|
| 849 |
-
if isinstance(module, (nn.LayerNorm, RMSNorm, RoPE)):
|
| 850 |
module.reset_parameters()
|
| 851 |
elif isinstance(module, nn.Linear):
|
| 852 |
nn.init.normal_(module.weight, mean=0, std=self.initializer_range)
|
|
@@ -1104,15 +1106,15 @@ class GraniteModel(GranitePreTrainedModel):
|
|
| 1104 |
|
| 1105 |
def _prepare_a_bunch_of_stuff(
|
| 1106 |
self,
|
| 1107 |
-
input_ids: torch.Tensor
|
| 1108 |
-
past_key_values: DynamicCache
|
| 1109 |
-
attention_mask: torch.Tensor
|
| 1110 |
-
token_type_ids: torch.Tensor
|
| 1111 |
-
position_ids: torch.Tensor
|
| 1112 |
-
inputs_embeds: torch.Tensor
|
| 1113 |
-
use_cache: bool
|
| 1114 |
-
output_hidden_states: bool
|
| 1115 |
-
return_dict: bool
|
| 1116 |
) -> Tuple[
|
| 1117 |
bool,
|
| 1118 |
bool,
|
|
|
|
| 1 |
+
import math
|
| 2 |
import numbers
|
| 3 |
+
import warnings
|
| 4 |
from enum import Enum
|
| 5 |
from typing import Optional, Tuple, Union
|
| 6 |
|
|
|
|
| 848 |
self.initializer_range = config.initializer_range
|
| 849 |
|
| 850 |
def _init_weights(self, module: nn.Module) -> None:
|
| 851 |
+
if isinstance(module, (nn.LayerNorm, RMSNorm, Alibi, RoPE)):
|
| 852 |
module.reset_parameters()
|
| 853 |
elif isinstance(module, nn.Linear):
|
| 854 |
nn.init.normal_(module.weight, mean=0, std=self.initializer_range)
|
|
|
|
| 1106 |
|
| 1107 |
def _prepare_a_bunch_of_stuff(
|
| 1108 |
self,
|
| 1109 |
+
input_ids: torch.Tensor,
|
| 1110 |
+
past_key_values: DynamicCache,
|
| 1111 |
+
attention_mask: torch.Tensor,
|
| 1112 |
+
token_type_ids: torch.Tensor,
|
| 1113 |
+
position_ids: torch.Tensor,
|
| 1114 |
+
inputs_embeds: torch.Tensor,
|
| 1115 |
+
use_cache: bool,
|
| 1116 |
+
output_hidden_states: bool,
|
| 1117 |
+
return_dict: bool,
|
| 1118 |
) -> Tuple[
|
| 1119 |
bool,
|
| 1120 |
bool,
|