drbh committed
Commit 9a1816c
1 Parent(s): b08f6c9

fix: adjust layer params in source

Files changed (1)
  1. torch-ext/megablocks/layers.py +15 -18
torch-ext/megablocks/layers.py CHANGED
@@ -683,26 +683,23 @@ def moe_forward(
 class MegaBlocksMoeMLP(torch.nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        moe_top_k = getattr(self, "moe_top_k", 4)
-        moe_num_experts = getattr(self, "moe_num_experts", 128)
-        gradient_scale = getattr(self, "gradient_scale", None)
-        alpha = getattr(self, "alpha", 1.702)
-        moe_capacity_factor = getattr(self, "moe_capacity_factor", 1.0)
-        moe_jitter_eps = getattr(self, "moe_jitter_eps", None)
-        moe_normalize_expert_weights = getattr(
-            self, "moe_normalize_expert_weights", None
-        )
+        moe_top_k = getattr(self.router, "top_k", 4)
+        moe_num_experts = getattr(self.experts, "num_experts", 128)
+        gradient_scale = getattr(self.experts, "gradient_scale", None)
+        alpha = getattr(self.experts, "alpha", 1.0)
+        moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+        moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+        moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
         uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
-
+
         has_parallel = hasattr(self, "expert_parallel_group")
-        expert_parallel_group = getattr(self, "expert_parallel_group", None)
+        expert_parallel_group = torch.distributed.group.WORLD
         forward_fn = parallel_forward_once if has_parallel else forward_once
-        sort_end_bit = max(
-            int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
-        )
-        mlp_impl = getattr(self, "mlp_impl", "grouped")  # or sparse
-
-        output, expert_weights_out, _ = moe_forward(
+
+        sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
+        mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+        output, expert_weights_out, *_ = moe_forward(
             x=x,
             router_weight=self.router.weight,
             moe_top_k=moe_top_k,
@@ -725,4 +722,4 @@ class MegaBlocksMoeMLP(torch.nn.Module):
             hidden_size=self.experts.hidden_size,
             mlp_impl=mlp_impl,
         )
-        return output, expert_weights_out
+        return output, expert_weights_out
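
For context, the commit moves the hyperparameter lookups off the MegaBlocksMoeMLP module itself and onto its router and experts submodules, changes the default alpha from 1.702 to 1.0, and pins expert_parallel_group to torch.distributed.group.WORLD. The snippet below is a minimal illustrative sketch of the new lookup pattern only, assuming submodules that carry these attributes; SimpleRouter and SimpleExperts are hypothetical stand-ins and are not part of megablocks.

# Minimal sketch (not library code): mimics where forward() now reads its
# hyperparameters from. SimpleRouter/SimpleExperts and their fields are
# illustrative assumptions, not megablocks classes.
import torch


class SimpleRouter(torch.nn.Module):
    def __init__(self, hidden_size: int, num_experts: int, top_k: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(num_experts, hidden_size))
        self.top_k = top_k  # read via getattr(self.router, "top_k", 4)


class SimpleExperts(torch.nn.Module):
    def __init__(self, num_experts: int, hidden_size: int):
        super().__init__()
        self.num_experts = num_experts  # read via getattr(self.experts, "num_experts", 128)
        self.hidden_size = hidden_size
        self.alpha = 1.0                # new default; the old code defaulted to 1.702 on self
        self.capacity_factor = 1.0
        self.jitter_eps = None
        self.normalize_expert_weights = None
        self.gradient_scale = None


router = SimpleRouter(hidden_size=1024, num_experts=8, top_k=2)
experts = SimpleExperts(num_experts=8, hidden_size=1024)

# After this commit, the forward pass resolves its settings from the submodules:
moe_top_k = getattr(router, "top_k", 4)                  # 2
moe_num_experts = getattr(experts, "num_experts", 128)   # 8
sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
print(moe_top_k, moe_num_experts, sort_end_bit)          # 2 8 3

The practical consequence is that callers configure top_k on the router and num_experts, alpha, capacity_factor, jitter_eps, normalize_expert_weights, and gradient_scale on the experts child, rather than setting them on the parent MLP module.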