Added cached RoPE embedding
modeling_phi4mm.py (+21 -22)
CHANGED
@@ -944,38 +944,37 @@ class Phi4MMLongRoPEScaledRotaryEmbedding(Phi4MMRotaryEmbedding):
         self.short_factor = config.rope_scaling["short_factor"]
         self.long_factor = config.rope_scaling["long_factor"]
         self.original_max_position_embeddings = config.original_max_position_embeddings
+        self.long_inv_freq_expanded = self.seq_freq(self.long_factor)
+        self.short_inv_freq_expanded = self.seq_freq(self.short_factor)
+        scale = self.max_position_embeddings / self.original_max_position_embeddings
+        if scale > 1.0:
+            self.scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
+        else:
+            self.scaling_factor = 1.0
+
+
+    def seq_freq(self, factor):
+        ext_factors = torch.tensor(factor, dtype=torch.float32, device='cuda')
+        inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device='cuda').float() / self.dim
+        inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
+        inv_freq_expanded = inv_freq[None, :, None].float()
+        return inv_freq_expanded

+    ########## INIT FUNCTION COMPUTES VARIABLES #####################
     @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
-        seq_len = torch.max(position_ids) + 1
+        device_type = x.device.type
         if seq_len > self.original_max_position_embeddings:
-            ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
+            inv_freq_expanded = self.long_inv_freq_expanded.expand(position_ids.shape[0], -1, 1).to(device_type)
         else:
-            ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
-
-        inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
-        self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
-
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+            inv_freq_expanded = self.short_inv_freq_expanded.expand(position_ids.shape[0], -1, 1).to(device_type)
         position_ids_expanded = position_ids[:, None, :].float()

-        # Force float32 since bfloat16 loses precision on long contexts
-        # See https://github.com/huggingface/transformers/pull/29285
-        device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
-
-            scale = self.max_position_embeddings / self.original_max_position_embeddings
-            if scale <= 1.0:
-                scaling_factor = 1.0
-            else:
-                scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
-
-            cos = emb.cos() * scaling_factor
-            sin = emb.sin() * scaling_factor
-        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+            cos = emb.cos() * self.scaling_factor
+            sin = emb.sin() * self.scaling_factor


 # Copied from transformers.models.llama.modeling_llama.rotate_half
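For reference, below is a minimal, self-contained sketch of the same caching idea: the long and short inverse-frequency tensors and the LongRoPE attention scaling factor depend only on the config, so they are computed once in __init__ and the per-call forward only selects the cached tensor, forms the angles, and scales cos/sin. This is not the repo's class; the name CachedLongRoPE, the explicit constructor arguments, the device argument (used instead of the commit's hard-coded device='cuda'), and the demo values at the bottom are illustrative assumptions, and the sketch keeps the `return cos..., sin...` line from the pre-change forward.

# Sketch of a cached LongRoPE rotary embedding (assumptions noted above).
import math
import torch


class CachedLongRoPE(torch.nn.Module):
    def __init__(self, dim, base, max_position_embeddings,
                 original_max_position_embeddings, short_factor, long_factor,
                 device=None):
        super().__init__()
        self.dim = dim
        self.base = base
        self.original_max_position_embeddings = original_max_position_embeddings

        # Compute the rescaled inverse frequencies once instead of on every call.
        self.long_inv_freq = self._inv_freq(long_factor, device)
        self.short_inv_freq = self._inv_freq(short_factor, device)

        # Attention scaling for contexts stretched past the original window.
        scale = max_position_embeddings / original_max_position_embeddings
        if scale > 1.0:
            self.scaling_factor = math.sqrt(
                1 + math.log(scale) / math.log(original_max_position_embeddings)
            )
        else:
            self.scaling_factor = 1.0

    def _inv_freq(self, factor, device):
        ext_factors = torch.tensor(factor, dtype=torch.float32, device=device)
        inv_freq_shape = (
            torch.arange(0, self.dim, 2, dtype=torch.int64, device=device).float() / self.dim
        )
        inv_freq = 1.0 / (ext_factors * self.base ** inv_freq_shape)
        return inv_freq[None, :, None]  # shape (1, dim/2, 1)

    @torch.no_grad()
    def forward(self, x, position_ids, seq_len):
        # Pick the cached long- or short-context frequencies.
        cached = (
            self.long_inv_freq
            if seq_len > self.original_max_position_embeddings
            else self.short_inv_freq
        )
        inv_freq_expanded = cached.expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # Keep the angle computation in float32, as in the original code.
        device_type = x.device.type
        autocast_device = device_type if device_type != "mps" else "cpu"
        with torch.autocast(device_type=autocast_device, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.scaling_factor
            sin = emb.sin() * self.scaling_factor
        # Kept from the pre-change forward: hand cos/sin back to the attention layers.
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


if __name__ == "__main__":
    dim = 64
    rope = CachedLongRoPE(
        dim=dim, base=10000.0, max_position_embeddings=131072,
        original_max_position_embeddings=4096,
        short_factor=[1.0] * (dim // 2), long_factor=[4.0] * (dim // 2),
    )
    x = torch.randn(1, 8, dim)
    position_ids = torch.arange(8)[None, :]
    cos, sin = rope(x, position_ids, seq_len=8)
    print(cos.shape, sin.shape)  # torch.Size([1, 8, 64]) torch.Size([1, 8, 64])

The caching trades a small amount of memory for skipping tensor construction on every decoding step; passing a device (or registering the cached tensors as buffers) keeps the module usable on CPU and MPS, which the hard-coded 'cuda' in the hunk above would not.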