lakshmi97 committed
Commit 3bda35a
Parent(s): 73158cd

Added cached RoPE embedding

Files changed (1)
  1. modeling_phi4mm.py +21 -22
modeling_phi4mm.py CHANGED
@@ -944,38 +944,37 @@ class Phi4MMLongRoPEScaledRotaryEmbedding(Phi4MMRotaryEmbedding):
         self.short_factor = config.rope_scaling["short_factor"]
         self.long_factor = config.rope_scaling["long_factor"]
         self.original_max_position_embeddings = config.original_max_position_embeddings
+        self.long_inv_freq_expanded=self.seq_freq(self.long_factor)
+        self.short_inv_freq_expanded=self.seq_freq(self.short_factor)
+        scale = self.max_position_embeddings / self.original_max_position_embeddings
+        if scale > 1.0:
+            self.scaling_factor=math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
+        else:
+            self.scaling_factor=1.0
+
+
+    def seq_freq(self, factor):
+        ext_factors = torch.tensor(factor, dtype=torch.float32, device='cuda')
+        inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device='cuda').float() / self.dim
+        inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
+        inv_freq_expanded = inv_freq[None, :, None].float()
+        return inv_freq_expanded

+    ########## INIT FUNCTION COMPUTES VARIABLES #####################
     @torch.no_grad()
     def forward(self, x, position_ids, seq_len=None):
-        seq_len = seq_len or torch.max(position_ids) + 1
+        device_type = x.device.type
         if seq_len > self.original_max_position_embeddings:
-            ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
+            inv_freq_expanded = self.long_inv_freq_expanded.expand(position_ids.shape[0], -1, 1).to(device_type)
         else:
-            ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
-
-        inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
-        self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
-
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+            inv_freq_expanded = self.short_inv_freq_expanded.expand(position_ids.shape[0], -1, 1).to(device_type)
         position_ids_expanded = position_ids[:, None, :].float()

-        # Force float32 since bfloat16 loses precision on long contexts
-        # See https://github.com/huggingface/transformers/pull/29285
-        device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
-
-        scale = self.max_position_embeddings / self.original_max_position_embeddings
-        if scale <= 1.0:
-            scaling_factor = 1.0
-        else:
-            scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
-
-        cos = emb.cos() * scaling_factor
-        sin = emb.sin() * scaling_factor
-        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+        cos = emb.cos() * self.scaling_factor
+        sin = emb.sin() * self.scaling_factor


     # Copied from transformers.models.llama.modeling_llama.rotate_half
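For reference, below is a minimal, self-contained sketch of the caching idea behind this commit: the short- and long-factor inverse-frequency tables and the LongRoPE magnitude scaling factor are computed once at construction time instead of on every `forward` call. It is not the repository's exact class; the class name, the explicit constructor arguments and `device` parameter, the use of non-persistent buffers instead of a hard-coded `'cuda'` device, and the trailing `return cos.to(...), sin.to(...)` (carried over from the pre-change `forward`) are all assumptions of the sketch.

```python
import math
import torch
import torch.nn as nn


class CachedLongRoPERotaryEmbedding(nn.Module):
    """Illustrative sketch: LongRoPE-style rotary embedding that caches the
    short/long inverse-frequency tables and the scaling factor at init time."""

    def __init__(self, dim, short_factor, long_factor,
                 max_position_embeddings, original_max_position_embeddings,
                 base=10000.0, device=None):
        super().__init__()
        self.dim = dim
        self.base = base
        self.original_max_position_embeddings = original_max_position_embeddings

        # Cache one inverse-frequency table per rescale factor, shape [1, dim/2, 1].
        # Non-persistent buffers let .to()/.cuda() move the cache with the module.
        self.register_buffer("short_inv_freq", self._inv_freq(short_factor, device), persistent=False)
        self.register_buffer("long_inv_freq", self._inv_freq(long_factor, device), persistent=False)

        # The LongRoPE scaling factor is constant, so compute it once as well.
        scale = max_position_embeddings / original_max_position_embeddings
        if scale > 1.0:
            self.scaling_factor = math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))
        else:
            self.scaling_factor = 1.0

    def _inv_freq(self, factor, device):
        ext_factors = torch.tensor(factor, dtype=torch.float32, device=device)
        exponents = torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim
        inv_freq = 1.0 / (ext_factors * self.base ** exponents)
        return inv_freq[None, :, None]  # ready to expand over the batch dimension

    @torch.no_grad()
    def forward(self, x, position_ids, seq_len=None):
        seq_len = seq_len if seq_len is not None else int(position_ids.max()) + 1
        inv_freq = self.long_inv_freq if seq_len > self.original_max_position_embeddings else self.short_inv_freq
        inv_freq_expanded = inv_freq.expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # Keep the angle computation in float32; bfloat16 loses precision at long context.
        device_type = x.device.type if x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos() * self.scaling_factor
        sin = emb.sin() * self.scaling_factor
        # Assumed tail: return cos/sin in the input dtype, as the pre-change forward did.
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


# Quick shape check with hypothetical sizes.
if __name__ == "__main__":
    rope = CachedLongRoPERotaryEmbedding(
        dim=64,
        short_factor=[1.0] * 32,
        long_factor=[2.0] * 32,
        max_position_embeddings=131072,
        original_max_position_embeddings=4096,
    )
    x = torch.randn(2, 16, 64)
    position_ids = torch.arange(16).unsqueeze(0).expand(2, -1)
    cos, sin = rope(x, position_ids)
    print(cos.shape, sin.shape)  # torch.Size([2, 16, 64]) torch.Size([2, 16, 64])
```

The payoff of the caching is that `forward` no longer rebuilds `ext_factors` and `inv_freq` tensors on every call; it only selects between the two precomputed tables and does the matmul, which is the same trade the diff above makes.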