import torch
import torch.distributed as dist
from typing import Optional, Any

from . import _layers
from . import ops


# Set the expert model parallel attributes on a tensor
def set_expert_model_parallel_attributes(
    tensor: torch.Tensor,
    is_parallel: bool,
):
    assert not hasattr(tensor, "expert_model_parallel")
    setattr(tensor, "expert_model_parallel", is_parallel)


# Calculate the expert sharding degree based on world size and number of experts
def expert_sharding_degree(
    world_size: int,
    moe_num_experts: int,
) -> int:
    esd = min(world_size, moe_num_experts)
    if (moe_num_experts % esd) != 0:
        raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
    return esd


# Calculate the hidden sharding degree based on world size and expert sharding degree
def hidden_sharding_degree(
    world_size: int,
    moe_num_experts: int,
    ffn_hidden_size: int,
) -> int:
    esd = expert_sharding_degree(world_size, moe_num_experts)
    hsd = world_size // esd
    if (ffn_hidden_size % hsd) != 0:
        raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
    if (esd * hsd) != world_size:
        raise ValueError(
            f"Invalid sharding. expert_sharding_degree ({esd}) * "
            f"hidden_sharding_degree ({hsd}) != world_size ({world_size})."
        )
    return hsd


# Calculate the number of experts per rank based on world size and expert sharding degree
def experts_per_rank(
    moe_num_experts: int,
    world_size: int,
) -> int:
    return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)


# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
def features_per_rank(
    ffn_hidden_size: int, world_size: int, moe_num_experts: int
) -> int:
    return ffn_hidden_size // hidden_sharding_degree(
        world_size, moe_num_experts, ffn_hidden_size
    )


# Apply jitter to the input tensor
def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
    low = 1.0 - moe_jitter_eps
    high = 1.0 + moe_jitter_eps
    noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
    return x * (low + noise * (high - low))


# Compute the top-k scores from the logits
def compute_top_k(scores: torch.Tensor, moe_top_k: int):
    if moe_top_k == 1:
        return scores.max(dim=-1, keepdim=True)
    return torch.topk(scores, moe_top_k, dim=-1)


# Route tokens to experts and compute expert weights and indices
def route_tokens(
    x: torch.Tensor,
    router_weight: torch.Tensor,
    moe_top_k: int,
    moe_num_experts: int,
    moe_jitter_eps: Optional[float] = None,
    moe_normalize_expert_weights: Optional[int] = None,
    uniform_expert_assignment: bool = False,
    training: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    if training and moe_jitter_eps is not None:
        x = apply_jitter(x, moe_jitter_eps)

    x_flat = x.view(-1, x.shape[-1])
    logits = torch.nn.functional.linear(x_flat, router_weight)
    expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
    expert_weights = expert_weights.softmax(dim=-1)
    if moe_normalize_expert_weights is not None:
        expert_weights = expert_weights / torch.norm(
            expert_weights,
            p=moe_normalize_expert_weights,
            dim=-1,
            keepdim=True,
        )
    if uniform_expert_assignment:
        expert_indices = _layers.router._uniform_expert_assignment(
            expert_indices,
            moe_num_experts,
        )
    return logits, expert_weights, expert_indices
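# --- Illustrative usage sketch for `route_tokens` (not part of the module API;
# --- the shapes below are assumptions chosen for demonstration only) ---
#
#   x = torch.randn(2, 4, 16)            # [sl, bs, hs]
#   router_weight = torch.randn(8, 16)   # [moe_num_experts, hs], as expected by F.linear
#   logits, weights, indices = route_tokens(x, router_weight, moe_top_k=2, moe_num_experts=8)
#   # logits: [sl * bs, 8]; weights and indices: [sl * bs, 2]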
# Scale the gradient of the weights
def scale_grad(
    w: torch.Tensor,
    gradient_scale: Optional[float] = None,
) -> torch.Tensor:
    if gradient_scale is None:
        return w
    return _layers.mlp.scale_gradient(w, gradient_scale)


# Forward pass for the MLP layer
def mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale=None, alpha: float = 1.702):
    # Scale weights
    w1 = scale_grad(w1, gradient_scale)
    w2 = scale_grad(w2, gradient_scale)
    w1_bias = scale_grad(w1_bias, gradient_scale)
    w2_bias = scale_grad(w2_bias, gradient_scale)

    # Resolve dtensors
    w1 = _layers.mlp.resolve_dtensor(w1)
    w2 = _layers.mlp.resolve_dtensor(w2)
    w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
    w2_bias = _layers.mlp.resolve_dtensor(w2_bias)

    # Forward pass: GLU-style gated MLP with a sigmoid gate
    gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
    gate, up = gate_up.chunk(2, dim=-1)
    glu = gate * torch.sigmoid(gate * alpha)
    x = (up + 1) * glu
    return torch.bmm(x, w2) + w2_bias[..., None, :]


## START: Load Balancing Loss (unused at the moment)

# Global variable to store load balancing loss
_LOAD_BALANCING_LOSS = []


def save_load_balancing_loss(loss):
    global _LOAD_BALANCING_LOSS
    _LOAD_BALANCING_LOSS.append(loss)


def get_load_balancing_loss():
    global _LOAD_BALANCING_LOSS
    return _LOAD_BALANCING_LOSS


def clear_load_balancing_loss():
    global _LOAD_BALANCING_LOSS
    _LOAD_BALANCING_LOSS.clear()


def batched_load_balancing_loss(args):
    if args.moe_loss_weight == 0:
        return 0.0

    tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())

    num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
    if args.num_layers_per_virtual_pipeline_stage is not None:
        num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage

    if len(tokens_per_expert) != num_layers_per_pipeline_stage:
        raise ValueError(
            f"Expected {num_layers_per_pipeline_stage} token_per_experts "
            f"but found {len(tokens_per_expert)}.\nnum_layers = "
            f"{args.num_layers}\npipeline_model_parallel_size = "
            f"{args.pipeline_model_parallel_size}\n"
            "num_layers_per_virtual_pipeline_stage"
            f" = {args.num_layers_per_virtual_pipeline_stage}",
        )
    if len(expert_scores) != num_layers_per_pipeline_stage:
        raise ValueError(
            f"Expected {num_layers_per_pipeline_stage} expert_scores "
            f"but found {len(expert_scores)}.\nnum_layers = "
            f"{args.num_layers}\npipeline_model_parallel_size = "
            f"{args.pipeline_model_parallel_size}\n"
            "num_layers_per_virtual_pipeline_stage"
            f" = {args.num_layers_per_virtual_pipeline_stage}",
        )

    # Verify the shape of the tokens_per_expert and expert_scores tensors.
    assert all(
        (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
    )

    tokens = expert_scores[0].shape[0]
    assert all(
        (
            (
                x.ndim == 2
                and x.shape[1] == args.moe_num_experts
                and x.shape[0] == tokens
            )
            for x in expert_scores
        )
    )

    # Concatenate the contributions of each layer and convert to
    # the correct types and formats for the dot product.
    expert_scores = torch.cat(expert_scores, dim=1)
    if args.moe_lbl_in_fp32:
        expert_scores = expert_scores.float()
    if tokens != 0:
        expert_scores = expert_scores.mean(dim=0)
    else:
        expert_scores = expert_scores.sum(dim=0)
    tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)

    expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
    assert tokens_per_expert.numel() == expected_values
    assert expert_scores.numel() == expected_values

    # Calculate the total scale across all factors.
    #
    # loss_weight * num_experts / (num_layers * tokens * top_k)
    scale_numerator = args.moe_num_experts * args.moe_loss_weight
    scale_denominator = args.num_layers * tokens * args.moe_top_k
    scale = scale_numerator / scale_denominator
    return scale * torch.dot(tokens_per_expert, expert_scores)


## END Load Balancing Loss
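# --- Worked example of the load-balancing scale above (illustrative numbers only) ---
# Assuming num_experts = 8, loss_weight = 0.01, num_layers = 2, tokens = 1024 and
# top_k = 2:
#   scale = (8 * 0.01) / (2 * 1024 * 2) = 0.08 / 4096 ≈ 1.95e-5
# The batched loss is then scale * dot(tokens_per_expert, expert_scores): an auxiliary
# loss that penalizes routings where token counts and router scores concentrate on the
# same few experts.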
# Calculate the expert capacity based on tokens, top_k, number of experts,
# expert parallel group, capacity factor, and whether expert model parallelism is used.
def expert_capacity(
    tokens: int,
    top_k: int,
    num_experts: int,
    expert_parallel_group: torch.distributed.ProcessGroup,
    moe_capacity_factor: float,
    moe_expert_model_parallelism: bool,
) -> int:
    world_size = (
        dist.get_world_size(expert_parallel_group)
        if moe_expert_model_parallelism
        else 1
    )
    tokens_per_expert = top_k * tokens * world_size / num_experts
    return int(moe_capacity_factor * tokens_per_expert)


def load_balancing_loss(
    tokens_per_expert: torch.Tensor,
    expert_scores: torch.Tensor,
    top_k: int,
    num_experts: int,
):
    assert len(expert_scores.size()) == 2
    tokens, score_num_experts = expert_scores.size()
    assert score_num_experts == num_experts
    assert len(tokens_per_expert.size()) == 1
    (counted_num_experts,) = tokens_per_expert.size()
    assert counted_num_experts == num_experts

    scale = num_experts / (tokens * top_k)
    return scale * torch.dot(
        tokens_per_expert.to(expert_scores.dtype),
        expert_scores.mean(dim=0),
    )


def indices_and_bins(
    top_expert: torch.Tensor,
    sort_end_bit: int,
    num_experts: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    top_expert = top_expert.int()

    # Ensure contiguous memory layout
    top_expert = top_expert.contiguous()

    # Ensure CUB knows which device to use
    with torch.cuda.device(top_expert.device):
        output = ops.sort(top_expert, sort_end_bit)
        bin_ids, indices = output
        tokens_per_expert = ops.histogram(top_expert, num_experts)
        bins = ops.inclusive_cumsum(tokens_per_expert, 0)

    bins = bins.view(1) if not len(bins.size()) else bins
    return indices, bin_ids, bins, tokens_per_expert


# Same capacity computation as `expert_capacity` above; kept under the name used by `forward_once`.
def expert_capacity_fn(
    tokens: int,
    top_k: int,
    num_experts: int,
    expert_parallel_group: torch.distributed.ProcessGroup,
    moe_capacity_factor: float = 1.0,
    moe_expert_model_parallelism: bool = False,
) -> int:
    world_size = (
        dist.get_world_size(expert_parallel_group)
        if moe_expert_model_parallelism
        else 1
    )
    tokens_per_expert = top_k * tokens * world_size / num_experts
    return int(moe_capacity_factor * tokens_per_expert)
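# --- Worked example of the capacity formula above (illustrative numbers only) ---
# With tokens = 1024, top_k = 4, num_experts = 128, moe_capacity_factor = 1.0 and no
# expert model parallelism (world_size = 1):
#   tokens_per_expert = 4 * 1024 * 1 / 128 = 32.0
#   expert_capacity   = int(1.0 * 32.0)    = 32
# i.e. each expert bin is sized to hold 32 token slots in the padded, binned layout.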
def permute_and_compute(
    x,
    tokens_per_expert,
    indices,
    bin_ids,
    expert_weights,
    bins,
    expert_capacity,
    top_k,
    w1,
    w2,
    w1_bias,
    w2_bias,
    gradient_scale,
    alpha,
):
    """Permute tokens and compute expert outputs."""
    # Route tokens to experts
    x = x.view(-1, x.shape[-1])

    # Ensure CUB knows which device to use
    with torch.cuda.device(x.device):
        x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)

    # Expert computation
    x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)

    # Ensure CUB knows which device to use
    with torch.cuda.device(x.device):
        # Route tokens back
        out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
    return out


def forward_once(
    x: torch.Tensor,
    expert_weights: torch.Tensor,
    top_experts: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_bias: torch.Tensor,
    w2_bias: torch.Tensor,
    gradient_scale: Optional[float] = None,
    alpha: float = 1.702,
    sort_end_bit: int = 0,
    top_k: int = 4,
    num_experts: int = 128,
    expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
    moe_capacity_factor: float = 1.0,
    moe_expert_model_parallelism: bool = False,
):
    # x: [sl, bs, hs]
    # expert_weights: [sl * bs, top-k]
    # top_experts: [sl * bs, top-k]
    expert_weights = expert_weights.flatten()
    top_experts = top_experts.flatten()

    with torch.no_grad():
        indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
            top_experts, sort_end_bit, num_experts
        )

        # Calculate expert capacity
        sl, bs, _ = x.size()
        expert_capacity = expert_capacity_fn(
            sl * bs,
            top_k,
            num_experts,
            expert_parallel_group,
            moe_capacity_factor,
            moe_expert_model_parallelism,
        )
        if expert_capacity == 0:
            expert_capacity = torch.max(tokens_per_expert).item()

    x = permute_and_compute(
        x,
        tokens_per_expert,
        indices,
        bin_ids,
        expert_weights,
        bins,
        expert_capacity,
        top_k,
        w1,
        w2,
        w1_bias,
        w2_bias,
        gradient_scale,
        alpha,
    )
    return x, tokens_per_expert


# TODO: replace with functional logic once aligned with ref
def parallel_forward_once(
    x: torch.Tensor,
    expert_weights: torch.Tensor,
    top_experts: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_bias: torch.Tensor,
    w2_bias: torch.Tensor,
    gradient_scale: Optional[float] = None,
    alpha: float = 1.702,
    sort_end_bit: int = 0,
    top_k: int = 4,
    num_experts: int = 128,
    expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
    moe_capacity_factor: float = 1.0,
    moe_expert_model_parallelism: bool = True,
    hidden_size: int = 1152,
):
    pass


class MyReplacementLayer(torch.nn.Module):
    # def __init__(self):
    #     super().__init__()

    def forward(
        # self,
        x: torch.Tensor,
        router_weight: torch.Tensor,
        moe_top_k: int,
        moe_num_experts: int,
        moe_jitter_eps: Optional[float] = None,
        moe_normalize_expert_weights: Optional[int] = None,
        uniform_expert_assignment: bool = False,
        training: bool = False,
        #
        w1: Optional[torch.Tensor] = None,
        w2: Optional[torch.Tensor] = None,
        w1_bias: Optional[torch.Tensor] = None,
        w2_bias: Optional[torch.Tensor] = None,
        gradient_scale: Optional[float] = None,
        alpha: float = 1.702,
        sort_end_bit: int = 0,
        expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
        moe_capacity_factor: float = 1.0,
        moe_expert_model_parallelism: bool = False,
        forward_fn: Any = None,
        hidden_size: Optional[int] = None,  # Required for parallel forward
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Route tokens to experts
        logits, expert_weights, expert_indices = route_tokens(
            x,
            router_weight,
            moe_top_k,
            moe_num_experts,
            moe_jitter_eps,
            moe_normalize_expert_weights,
            uniform_expert_assignment,
            training,
        )

        # Create router scores for output
        router_scores = (
            torch.zeros_like(logits)
            .scatter_(1, expert_indices, expert_weights)
            .transpose(0, 1)
        )

        in_shape = x.size()

        # Prepare forward function arguments
        forward_args = {
            "x": x,
            "expert_weights": expert_weights,
            "top_experts": expert_indices,
            "w1": w1,
            "w2": w2,
            "w1_bias": w1_bias,
            "w2_bias": w2_bias,
            "gradient_scale": gradient_scale,
            "alpha": alpha,
            "sort_end_bit": sort_end_bit,
            "top_k": moe_top_k,
            "num_experts": moe_num_experts,
            "expert_parallel_group": expert_parallel_group,
            "moe_capacity_factor": moe_capacity_factor,
            "moe_expert_model_parallelism": moe_expert_model_parallelism,
        }

        # Add hidden_size for parallel forward
        if moe_expert_model_parallelism and hidden_size is not None:
            forward_args["hidden_size"] = hidden_size
        elif moe_expert_model_parallelism and hidden_size is None:
            # Infer hidden_size from input shape
            forward_args["hidden_size"] = x.shape[-1]

        # Compute expert outputs
        x, tokens_per_expert = forward_fn(**forward_args)

        # Save load balancing loss if needed
        moe_loss_weight = 0.0  # Can be made configurable
        if training and moe_loss_weight > 0:
            save_load_balancing_loss((tokens_per_expert, logits))

        # Restore original shape
        x = x.view(in_shape)

        return x, expert_weights, router_scores
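# --- Data-flow sketch for `MyReplacementLayer.forward` (the shapes stated here are
# --- assumptions inferred from the code above, not a specification) ---
# 1. `route_tokens` flattens x to [sl * bs, hs] and picks top-k expert weights and
#    indices per token.
# 2. `forward_fn` (e.g. `forward_once`) sorts tokens by expert, gathers them into a
#    padded [num_experts, expert_capacity, hs] buffer, runs the gated MLP per expert,
#    and scatters the weighted results back to [sl * bs, hs].
# 3. The result is reshaped to the original [sl, bs, hs] and returned together with
#    the top-k expert weights and the transposed router scores ([num_experts, sl * bs]).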
class MegaBlocksMoeMLP(torch.nn.Module):

    def forward(
        self,
        x: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        router_weight = self.router.weight
        moe_top_k = 4
        moe_num_experts = 128

        w1 = self.experts.gate_up_proj.data
        w2 = self.experts.down_proj.data
        w1_bias = self.experts.gate_up_proj_bias.data
        w2_bias = self.experts.down_proj_bias.data

        expert_parallel_group = None
        sort_end_bit = max(
            int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
        )
        hidden_size = self.experts.hidden_size

        output, expert_weights_out, router_scores = MyReplacementLayer.forward(
            x=x,
            router_weight=router_weight,
            moe_top_k=moe_top_k,
            moe_num_experts=moe_num_experts,
            moe_jitter_eps=None,
            moe_normalize_expert_weights=None,
            uniform_expert_assignment=False,
            training=False,
            w1=w1,
            w2=w2,
            w1_bias=w1_bias,
            w2_bias=w2_bias,
            gradient_scale=None,
            alpha=1.702,
            sort_end_bit=sort_end_bit,
            expert_parallel_group=expert_parallel_group,
            moe_capacity_factor=1.0,
            moe_expert_model_parallelism=False,
            forward_fn=forward_once,
            hidden_size=hidden_size,
        )
        return output, expert_weights_out
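# --- Illustrative wiring sketch (assumptions: the setup below is hypothetical; only
# --- `router.weight` and the `experts.*` fields used above are taken from this module) ---
#
#   mlp = MegaBlocksMoeMLP()
#   mlp.router = torch.nn.Linear(hidden_size, 128, bias=False)
#   mlp.experts = some_expert_container      # hypothetical object providing gate_up_proj,
#                                            # down_proj, their biases, and hidden_size
#   out, expert_weights = mlp(torch.randn(sl, bs, hidden_size))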