drbh committed · Commit 9a1816c · Parent(s): b08f6c9
fix: adjust layer params in source

Files changed: torch-ext/megablocks/layers.py (+15 -18)

torch-ext/megablocks/layers.py CHANGED
@@ -683,26 +683,23 @@ def moe_forward(
 class MegaBlocksMoeMLP(torch.nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        moe_top_k = getattr(self, "
-        moe_num_experts = getattr(self, "
-        gradient_scale = getattr(self, "gradient_scale", None)
-        alpha = getattr(self, "alpha", 1.
-        moe_capacity_factor = getattr(self, "
-        moe_jitter_eps = getattr(self, "
-        moe_normalize_expert_weights = getattr(
-            self, "moe_normalize_expert_weights", None
-        )
+        moe_top_k = getattr(self.router, "top_k", 4)
+        moe_num_experts = getattr(self.experts, "num_experts", 128)
+        gradient_scale = getattr(self.experts, "gradient_scale", None)
+        alpha = getattr(self.experts, "alpha", 1.0)
+        moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+        moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+        moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
         uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
-
+
         has_parallel = hasattr(self, "expert_parallel_group")
-        expert_parallel_group =
+        expert_parallel_group = torch.distributed.group.WORLD
         forward_fn = parallel_forward_once if has_parallel else forward_once
-
-
-        )
-
-
-        output, expert_weights_out, _ = moe_forward(
+
+        sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
+        mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+        output, expert_weights_out, *_ = moe_forward(
             x=x,
             router_weight=self.router.weight,
             moe_top_k=moe_top_k,
@@ -725,4 +722,4 @@ class MegaBlocksMoeMLP(torch.nn.Module):
             hidden_size=self.experts.hidden_size,
             mlp_impl=mlp_impl,
         )
-        return output, expert_weights_out
+        return output, expert_weights_out
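The substance of the fix is that forward() now pulls its MoE hyper-parameters from the router and experts submodules rather than from the wrapper module itself. Below is a minimal sketch of what those submodules are expected to expose, using hypothetical Router/Experts containers and made-up sizes (1024 hidden units, 128 experts) that are not taken from this repo; only the attributes read via getattr in the diff are shown.

import torch

# Hypothetical stand-ins for the router / experts submodules.
class Router(torch.nn.Module):
    def __init__(self, hidden_size: int, num_experts: int, top_k: int = 4):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(num_experts, hidden_size))
        self.top_k = top_k

class Experts(torch.nn.Module):
    def __init__(self, hidden_size: int, num_experts: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_experts = num_experts
        # alpha, capacity_factor, jitter_eps, normalize_expert_weights and
        # gradient_scale are optional; forward() falls back to the defaults
        # shown in the diff (1.0, 1.0, None, None, None) when they are absent.

router = Router(hidden_size=1024, num_experts=128)
experts = Experts(hidden_size=1024, num_experts=128)

# The same lookups the patched forward() performs:
print(getattr(router, "top_k", 4))           # 4
print(getattr(experts, "num_experts", 128))  # 128
print(getattr(experts, "alpha", 1.0))        # 1.0 (attribute absent, default used)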
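The added sort_end_bit line computes the smallest bit width that can index every expert (presumably used when sorting token-to-expert assignments). A quick standalone check of that expression, copied from the diff:

import torch

def sort_end_bit(num_experts: int) -> int:
    # Same expression as the added line in the diff, evaluated in isolation.
    return max(int(torch.ceil(torch.log2(torch.tensor(num_experts)))), 1)

assert sort_end_bit(1) == 1     # log2(1) = 0, clamped to a minimum of 1
assert sort_end_bit(128) == 7   # exact power of two
assert sort_end_bit(130) == 8   # non-powers of two round up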