Commit b08f6c9 · drbh committed
Parent(s): 5268e56

feat: adjust layer params
Browse files

- build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py  +14 -17
- build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers.py  +14 -17
- build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers.py  +14 -17
- build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers.py  +14 -17
- build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers.py  +14 -17
- build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers.py  +14 -17
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers.py  +14 -17
- build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers.py  +14 -17
- build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers.py  +14 -17
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py
CHANGED
The diff below (identical in all nine build variants) rewires MegaBlocksMoeMLP.forward to read its MoE hyperparameters from the router and experts submodules with explicit defaults, defaults expert_parallel_group to torch.distributed.group.WORLD, and computes sort_end_bit and mlp_impl before calling moe_forward. Several removed lines are truncated in the diff view.

@@ -683,26 +683,23 @@ def moe_forward(
 class MegaBlocksMoeMLP(torch.nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        moe_top_k = getattr(self, "
-        moe_num_experts = getattr(self, "
-        gradient_scale = getattr(self, "gradient_scale", None)
-        alpha = getattr(self, "alpha", 1.
-        moe_capacity_factor = getattr(self, "
-        moe_jitter_eps = getattr(self, "
-        moe_normalize_expert_weights = getattr(
-            self, "moe_normalize_expert_weights", None
-        )
+        moe_top_k = getattr(self.router, "top_k", 4)
+        moe_num_experts = getattr(self.experts, "num_experts", 128)
+        gradient_scale = getattr(self.experts, "gradient_scale", None)
+        alpha = getattr(self.experts, "alpha", 1.0)
+        moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+        moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+        moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
         uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
-
+
         has_parallel = hasattr(self, "expert_parallel_group")
-        expert_parallel_group =
+        expert_parallel_group = torch.distributed.group.WORLD
         forward_fn = parallel_forward_once if has_parallel else forward_once
-
-
-        )
-
-
-        output, expert_weights_out, _ = moe_forward(
+
+        sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
+        mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+        output, expert_weights_out, *_ = moe_forward(
             x=x,
             router_weight=self.router.weight,
             moe_top_k=moe_top_k,
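The heart of the change is the lookup pattern: each hyperparameter is now read off the router or experts submodule via getattr, with a hard fallback (top_k=4, num_experts=128, alpha=1.0, capacity_factor=1.0, and so on) when the attribute is missing. A minimal sketch of that pattern, using hypothetical _Router/_Experts containers rather than the real MegaBlocks classes:

import torch

# Hypothetical stand-ins for the real submodules; only the attribute names
# mirror what the rewritten forward() reads.
class _Router(torch.nn.Module):
    def __init__(self, hidden_size: int, num_experts: int, top_k: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(num_experts, hidden_size))
        self.top_k = top_k

class _Experts(torch.nn.Module):
    def __init__(self, num_experts: int):
        super().__init__()
        self.num_experts = num_experts
        self.alpha = 1.702           # illustrative value; the diff's fallback is 1.0
        self.capacity_factor = 2.0   # illustrative value; the diff's fallback is 1.0

mlp = torch.nn.Module()
mlp.router = _Router(hidden_size=16, num_experts=8, top_k=2)
mlp.experts = _Experts(num_experts=8)

# Same getattr-with-default pattern the commit introduces.
moe_top_k = getattr(mlp.router, "top_k", 4)                 # -> 2 (from the submodule)
moe_num_experts = getattr(mlp.experts, "num_experts", 128)  # -> 8
moe_jitter_eps = getattr(mlp.experts, "jitter_eps", None)   # -> None (attribute never set)
print(moe_top_k, moe_num_experts, moe_jitter_eps)

The added sort_end_bit expression is presumably the bit width needed to encode an expert index, clamped to at least one bit; the formula can be checked in isolation:

import torch

def sort_end_bit(moe_num_experts: int) -> int:
    # Expression taken from the diff: ceil(log2(num_experts)), never below 1.
    return max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)

for n in (1, 2, 60, 128):
    print(n, "->", sort_end_bit(n))  # 1 -> 1, 2 -> 1, 60 -> 6, 128 -> 7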
build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers.py
CHANGED
@@ -683,26 +683,23 @@ def moe_forward(
(Identical change to MegaBlocksMoeMLP.forward; see the full diff for build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py above.)
build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers.py
CHANGED
@@ -683,26 +683,23 @@ def moe_forward(
(Identical change to MegaBlocksMoeMLP.forward; see the full diff for build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py above.)
build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers.py
CHANGED
@@ -683,26 +683,23 @@ def moe_forward(
(Identical change to MegaBlocksMoeMLP.forward; see the full diff for build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py above.)
build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers.py
CHANGED
@@ -683,26 +683,23 @@ def moe_forward(
(Identical change to MegaBlocksMoeMLP.forward; see the full diff for build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py above.)
build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers.py
CHANGED
@@ -683,26 +683,23 @@ def moe_forward(
(Identical change to MegaBlocksMoeMLP.forward; see the full diff for build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py above.)
build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers.py
CHANGED
@@ -683,26 +683,23 @@ def moe_forward(
(Identical change to MegaBlocksMoeMLP.forward; see the full diff for build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py above.)
build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers.py
CHANGED
@@ -683,26 +683,23 @@ def moe_forward(
(Identical change to MegaBlocksMoeMLP.forward; see the full diff for build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py above.)
build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers.py
CHANGED
@@ -683,26 +683,23 @@ def moe_forward(
(Identical change to MegaBlocksMoeMLP.forward; see the full diff for build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py above.)