Update modeling_bailing_moe.py
modeling_bailing_moe.py  CHANGED  (+127 -22)
@@ -207,6 +207,90 @@ class BailingMoeDynamicNTKScalingRotaryEmbedding(BailingMoeRotaryEmbedding):
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


+# Inverse dim formula to find dim based on number of rotations
+def yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
+    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
+
+
+# Find dim range bounds based on rotations
+def yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
+    low = math.floor(yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
+    high = math.ceil(yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
+    return max(low, 0), min(high, dim - 1)  # Clamp values just in case
+
+
+def yarn_get_mscale(scale=1, mscale=1):
+    if scale <= 1:
+        return 1.0
+    return 0.1 * mscale * math.log(scale) + 1.0
+
+
+def yarn_linear_ramp_mask(min, max, dim):
+    if min == max:
+        max += 0.001  # Prevent singularity
+
+    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
+    ramp_func = torch.clamp(linear_func, 0, 1)
+    return ramp_func
+
+
+class BailingMoeYarnRotaryEmbedding(BailingMoeRotaryEmbedding):
+
+    def __init__(
+        self,
+        dim,
+        max_position_embeddings=2048,
+        base=10000,
+        device=None,
+        scaling_factor=1.0,
+        original_max_position_embeddings=4096,
+        beta_fast=32,
+        beta_slow=1,
+        mscale=1,
+        mscale_all_dim=0,
+    ):
+        self.scaling_factor = scaling_factor
+        self.original_max_position_embeddings = original_max_position_embeddings
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        self.mscale = mscale
+        self.mscale_all_dim = mscale_all_dim
+        super().__init__(dim, max_position_embeddings, base, device)
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        dim = self.dim
+
+        freq_extra = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
+        freq_inter = 1.0 / (
+            self.scaling_factor * self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
+        )
+
+        low, high = yarn_find_correction_range(
+            self.beta_fast,
+            self.beta_slow,
+            dim,
+            self.base,
+            self.original_max_position_embeddings,
+        )
+        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(device=device, dtype=torch.float32)
+        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+        t = torch.arange(seq_len, device=device, dtype=torch.float32)
+
+        freqs = torch.outer(t, inv_freq)
+
+        _mscale = float(
+            yarn_get_mscale(self.scaling_factor, self.mscale)
+            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
+        )
+
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False)
+        self.register_buffer("sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False)
+
+
 # Copied from transformers.models.llama.modeling_llama.rotate_half
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
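A standalone sketch of the correction-range arithmetic used by yarn_find_correction_range above, with assumed example values (dim=128, base=10000, original_max_position_embeddings=4096, beta_fast=32, beta_slow=1); the mscale line mirrors yarn_get_mscale for an illustrative scale of 4.0:

import math

def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
    # same formula as yarn_find_correction_dim above
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

dim, base, orig_max = 128, 10000, 4096                             # assumed example values
low = math.floor(find_correction_dim(32, dim, base, orig_max))     # beta_fast bound
high = math.ceil(find_correction_dim(1, dim, base, orig_max))      # beta_slow bound
low, high = max(low, 0), min(high, dim - 1)
print(low, high)   # rotary dims inside [low, high] get the linear-ramp blend of the two frequency sets
mscale = 0.1 * 1 * math.log(4.0) + 1.0   # yarn_get_mscale(scale=4.0, mscale=1), applied to the cos/sin caches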
@@ -278,7 +362,7 @@ class BailingMoeGate(nn.Module):

        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

-    def forward(self, hidden_states):
+    def forward(self, hidden_states, sort=False):
        bsz, seq_len, h = hidden_states.shape
        # compute gating score
        hidden_states = hidden_states.view(-1, h)
@@ -286,7 +370,7 @@ class BailingMoeGate(nn.Module):
        scores = logits.softmax(dim=-1, dtype=torch.float32)

        # select top-k experts
-        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
+        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=sort)

        # norm gate to sum 1
        if self.top_k > 1 and self.norm_topk_prob:
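For context, the sorted flag of torch.topk only controls whether the k selected experts come back ordered by score; the selection itself is unchanged. A toy sketch of the selection followed by the sum-to-one renormalization that the norm_topk_prob branch guards (the renormalization line is an assumption, since that branch's body sits outside this hunk):

import torch

scores = torch.softmax(torch.randn(4, 8), dim=-1)               # toy values: 4 tokens, 8 experts
topk_weight, topk_idx = torch.topk(scores, k=2, dim=-1, sorted=False)
# hypothetical renormalization so the kept weights sum to 1 per token
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)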
@@ -305,7 +389,7 @@ class BailingMoeSparseMoeBlock(nn.Module):
        super().__init__()
        self.config = config
        self.num_experts_per_tok = config.num_experts_per_tok
-        self.
+        self._setup_experts()
        self.gate = BailingMoeGate(config)
        if config.num_shared_experts is not None:
            self.shared_experts = BailingMoeMLP(
@@ -313,7 +397,7 @@ class BailingMoeSparseMoeBlock(nn.Module):
            )

    def _setup_experts(self):
-
+        self.experts = nn.ModuleList(
            [
                BailingMoeMLP(config=self.config, intermediate_size=self.config.moe_intermediate_size)
                for _ in range(self.config.num_experts)
@@ -443,6 +527,25 @@ class BailingMoeAttention(nn.Module):
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
+            elif scaling_type == "yarn":
+                kwargs = {
+                    key: self.config.rope_scaling[key]
+                    for key in [
+                        "original_max_position_embeddings",
+                        "beta_fast",
+                        "beta_slow",
+                        "mscale",
+                        "mscale_all_dim",
+                    ]
+                    if key in self.config.rope_scaling
+                }
+                self.rotary_emb = BailingMoeYarnRotaryEmbedding(
+                    self.head_dim,
+                    max_position_embeddings=self.max_position_embeddings,
+                    scaling_factor=scaling_factor,
+                    base=self.rope_theta,
+                    **kwargs,
+                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

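For reference, a hypothetical rope_scaling config entry that would reach this branch; the optional key names are taken from the comprehension above, the "type"/"factor" keys follow the usual transformers convention, and all values are illustrative only:

# hypothetical config fragment; only the keys listed above are forwarded to BailingMoeYarnRotaryEmbedding
rope_scaling = {
    "type": "yarn",
    "factor": 4.0,                              # becomes scaling_factor
    "original_max_position_embeddings": 4096,
    "beta_fast": 32,
    "beta_slow": 1,
    "mscale": 1,
    "mscale_all_dim": 0,
}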
@@ -1258,6 +1361,24 @@ class BailingMoeForCausalLM(BailingMoePreTrainedModel):
    def get_decoder(self):
        return self.model

+    def compute_logit(self, hidden_states):
+        if self.norm_head:
+            if self.training:
+                norm_weight = (
+                    self.lm_head.weight / (torch.norm(self.lm_head.weight, p=2, dim=0, keepdim=True) + 1e-7).detach()
+                )
+                logits = F.linear(hidden_states, norm_weight, None)
+            else:
+                self.lm_head.weight.data = (
+                    self.lm_head.weight.data.float()
+                    / (torch.norm(self.lm_head.weight.data.float(), p=2, dim=0, keepdim=True) + 1e-7)
+                ).to(hidden_states.dtype)
+                logits = F.linear(hidden_states, self.lm_head.weight.data, None)
+                self.norm_head = False
+        else:
+            logits = self.lm_head(hidden_states)
+        return logits
+
    @add_start_docstrings_to_model_forward(BAILINGMOE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
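As a rough illustration of the norm_head path in compute_logit, each column (hidden dim) of lm_head.weight is divided by its L2 norm taken over the vocabulary rows (dim=0) before the linear projection; a minimal sketch with assumed shapes, ignoring the .detach() and in-place weight caching used above:

import torch
import torch.nn.functional as F

hidden_states = torch.randn(2, 5, 1024)   # assumed (batch, seq, hidden)
weight = torch.randn(32000, 1024)         # assumed lm_head weight shape (vocab, hidden)

# column-wise L2 normalization over the vocabulary axis (dim=0)
norm_weight = weight / (torch.norm(weight, p=2, dim=0, keepdim=True) + 1e-7)
logits = F.linear(hidden_states, norm_weight)   # shape (2, 5, 32000)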
@@ -1325,22 +1446,7 @@ class BailingMoeForCausalLM(BailingMoePreTrainedModel):

        hidden_states = outputs[0]

-        if self.norm_head:
-            if self.training:
-                norm_weight = (
-                    self.lm_head.weight / (torch.norm(self.lm_head.weight, p=2, dim=0, keepdim=True) + 1e-7).detach()
-                )
-                logits = F.linear(hidden_states, norm_weight, None)
-            else:
-                self.lm_head.weight.data = (
-                    self.lm_head.weight.data.float()
-                    / (torch.norm(self.lm_head.weight.data.float(), p=2, dim=0, keepdim=True) + 1e-7)
-                ).to(hidden_states.dtype)
-                logits = F.linear(hidden_states, self.lm_head.weight.data, None)
-                self.norm_head = False
-        else:
-            logits = self.lm_head(hidden_states)
-
+        logits = self.compute_logit(hidden_states=hidden_states)
        logits = logits.float()

        loss = None
@@ -1392,8 +1498,7 @@ class BailingMoeForCausalLM(BailingMoePreTrainedModel):

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-            # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
-            # input)
+            # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|