|
|
|
|
|
|
|
import torch |
|
|
|
from ._ops import ops |
|
|
|
from .grouped_gemm import backend as gg_backend |
|
from .grouped_gemm import ops as gg_ops |
|
|
|
|
|
from ._layers.arguments import Arguments |
|
from ._layers.dmoe import ParallelDroplessMLP, dMoE |
|
from ._layers.glu import SparseGLU |
|
from ._layers.mlp import MLP, SparseMLP |
|
from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss |
|
|
|
from . import layers |
|
|
|
|
|
def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: |
|
""" |
|
Compute exclusive cumulative sum along the specified dimension. |
|
|
|
Args: |
|
x: Input tensor |
|
dim: Dimension along which to compute cumsum |
|
out: Output tensor (modified in-place) |
|
|
|
Returns: |
|
The output tensor |
|
""" |
|
result = ops.exclusive_cumsum(x, dim) |
|
out.copy_(result) |
|
return out |
|
|
|
|
|
def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: |
|
""" |
|
Compute inclusive cumulative sum along the specified dimension. |
|
|
|
Args: |
|
x: Input tensor |
|
dim: Dimension along which to compute cumsum |
|
out: Output tensor (modified in-place) |
|
|
|
Returns: |
|
The output tensor |
|
""" |
|
result = ops.inclusive_cumsum(x, dim) |
|
out.copy_(result) |
|
return out |
|
|
|
|
|
def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor: |
|
""" |
|
Compute histogram of input tensor values. |
|
|
|
Args: |
|
x: Input tensor |
|
num_bins: Number of histogram bins |
|
|
|
Returns: |
|
Histogram tensor with counts for each bin |
|
""" |
|
return ops.histogram(x, num_bins) |
|
|
|
|
|
def indices( |
|
padded_bins: torch.Tensor, |
|
block_size: int, |
|
output_block_rows: int, |
|
output_block_columns: int, |
|
) -> torch.Tensor: |
|
""" |
|
Construct indices from padded bins for sparse operations. |
|
|
|
Args: |
|
padded_bins: Tensor containing bin boundaries |
|
block_size: Size of each block |
|
output_block_rows: Number of rows in output blocks |
|
output_block_columns: Number of columns in output blocks |
|
|
|
Returns: |
|
Tensor containing constructed indices |
|
""" |
|
return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns) |
|
|
|
|
|
def replicate_forward( |
|
x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor |
|
) -> torch.Tensor: |
|
""" |
|
Forward pass of replicate operation - replicate values according to bin sizes. |
|
|
|
Args: |
|
x: Input tensor with values to replicate |
|
bins: Tensor containing bin sizes |
|
out: Output tensor (modified in-place) |
|
|
|
Returns: |
|
The output tensor |
|
""" |
|
return ops.replicate_forward(x, bins, out) |
|
|
|
|
|
def replicate_backward( |
|
grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor |
|
) -> torch.Tensor: |
|
""" |
|
Backward pass of replicate operation - reduce gradients back to bins. |
|
|
|
Args: |
|
grad: Gradient tensor to reduce |
|
bins: Tensor containing bin sizes |
|
out: Output tensor (modified in-place) |
|
|
|
Returns: |
|
The output tensor |
|
""" |
|
return ops.replicate_backward(grad, bins, out) |
|
|
|
|
|
def sort( |
|
x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor |
|
) -> torch.Tensor: |
|
""" |
|
Radix sort with index tracking. |
|
|
|
Args: |
|
x: Input tensor to sort |
|
end_bit: Number of bits to consider in sorting |
|
x_out: Output tensor for sorted values |
|
iota_out: Output tensor for sorted indices |
|
|
|
Returns: |
|
The sorted values tensor |
|
""" |
|
return ops.sort(x, end_bit, x_out, iota_out) |
|
|
|
|
|
|
|
def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor: |
|
""" |
|
Compute cumulative sum with automatic output allocation. |
|
|
|
Args: |
|
x: Input tensor |
|
dim: Dimension along which to compute cumsum (default: last dimension) |
|
exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum |
|
|
|
Returns: |
|
New tensor containing the cumulative sum |
|
""" |
|
out = torch.empty_like(x) |
|
if exclusive: |
|
return exclusive_cumsum(x, dim, out) |
|
else: |
|
return inclusive_cumsum(x, dim, out) |
|
|
|
|
|
def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]: |
|
""" |
|
Sort tensor and return both sorted values and indices. |
|
|
|
Args: |
|
x: Input tensor to sort |
|
end_bit: Number of bits to consider in sorting |
|
|
|
Returns: |
|
Tuple of (sorted_values, sorted_indices) |
|
""" |
|
x_out = torch.empty_like(x) |
|
iota_out = torch.empty_like(x) |
|
sort(x, end_bit, x_out, iota_out) |
|
return x_out, iota_out |
|
|
|
|
|
|
|
__all__ = [ |
|
"MyReplacementLayer", |
|
|
|
"exclusive_cumsum", |
|
"inclusive_cumsum", |
|
"histogram", |
|
"indices", |
|
"replicate_forward", |
|
"replicate_backward", |
|
"sort", |
|
"cumsum", |
|
"argsort", |
|
|
|
"Arguments", |
|
"ParallelDroplessMLP", |
|
"dMoE", |
|
"SparseGLU", |
|
"MLP", |
|
"SparseMLP", |
|
"MoE", |
|
"ParallelMLP", |
|
"get_load_balancing_loss", |
|
] |
|
|