# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import torch

from ._ops import ops
from .grouped_gemm import backend as gg_backend
from .grouped_gemm import ops as gg_ops

from ._layers.arguments import Arguments
from ._layers.dmoe import ParallelDroplessMLP, dMoE
from ._layers.glu import SparseGLU
from ._layers.mlp import MLP, SparseMLP
from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss

from . import layers


# This section contains the direct kernel exports (not included in the original code).
def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
    """
    Compute the exclusive cumulative sum along the specified dimension.

    Args:
        x: Input tensor.
        dim: Dimension along which to compute the cumsum.
        out: Output tensor (modified in-place).

    Returns:
        The output tensor.
    """
    result = ops.exclusive_cumsum(x, dim)
    out.copy_(result)
    return out


def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
    """
    Compute the inclusive cumulative sum along the specified dimension.

    Args:
        x: Input tensor.
        dim: Dimension along which to compute the cumsum.
        out: Output tensor (modified in-place).

    Returns:
        The output tensor.
    """
    result = ops.inclusive_cumsum(x, dim)
    out.copy_(result)
    return out
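
# Usage sketch for the two cumsum wrappers above (illustrative; assumes a CUDA
# device and an integer dtype, as in typical MoE routing code):
#
#   x = torch.tensor([1, 2, 3, 4], device="cuda", dtype=torch.int32)
#   out = torch.empty_like(x)
#   exclusive_cumsum(x, 0, out)   # out -> [0, 1, 3, 6]
#   inclusive_cumsum(x, 0, out)   # out -> [1, 3, 6, 10]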


def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
    """
    Compute a histogram of the input tensor's values.

    Args:
        x: Input tensor.
        num_bins: Number of histogram bins.

    Returns:
        Histogram tensor with a count for each bin.
    """
    return ops.histogram(x, num_bins)
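
# Usage sketch (illustrative; assumes non-negative integer inputs in
# [0, num_bins), e.g. counting how many tokens were routed to each expert):
#
#   top_experts = torch.tensor([0, 1, 1, 3], device="cuda", dtype=torch.int32)
#   histogram(top_experts, 4)   # -> [1, 2, 0, 1]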


def indices(
    padded_bins: torch.Tensor,
    block_size: int,
    output_block_rows: int,
    output_block_columns: int,
) -> torch.Tensor:
    """
    Construct indices from padded bins for sparse operations.

    Args:
        padded_bins: Tensor containing bin boundaries.
        block_size: Size of each block.
        output_block_rows: Number of rows in the output blocks.
        output_block_columns: Number of columns in the output blocks.

    Returns:
        Tensor containing the constructed indices.
    """
    return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
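
# Call-shape sketch only (illustrative; padded_bins is typically a padded
# inclusive cumsum of tokens per expert, and the block counts come from the
# sparse topology, so treat the concrete numbers below as assumptions):
#
#   padded_bins = torch.tensor([128, 384], device="cuda", dtype=torch.int32)
#   idx = indices(padded_bins, 128, 3, 8)   # 3 block-rows x 8 block-columns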


def replicate_forward(
    x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
) -> torch.Tensor:
    """
    Forward pass of the replicate operation: replicate values according to bin sizes.

    Args:
        x: Input tensor with values to replicate.
        bins: Tensor containing bin sizes.
        out: Output tensor (modified in-place).

    Returns:
        The output tensor.
    """
    return ops.replicate_forward(x, bins, out)


def replicate_backward(
    grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
) -> torch.Tensor:
    """
    Backward pass of the replicate operation: reduce gradients back to bins.

    Args:
        grad: Gradient tensor to reduce.
        bins: Tensor containing bin sizes.
        out: Output tensor (modified in-place).

    Returns:
        The output tensor.
    """
    return ops.replicate_backward(grad, bins, out)
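
# Usage sketch for the replicate pair (illustrative, following the docstrings
# above; the exact bins convention (sizes vs. cumulative boundaries) is defined
# by the underlying kernel, so treat the shapes here as assumptions):
#
#   x = torch.tensor([[1.0, 2.0]], device="cuda")     # one value per bin
#   bins = torch.tensor([2, 3], device="cuda", dtype=torch.int32)
#   out = torch.empty((1, 5), device="cuda")
#   replicate_forward(x, bins, out)    # expands each value across its bin
#   # replicate_backward reduces a same-shaped gradient back to one value per bin.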


def sort(
    x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
) -> torch.Tensor:
    """
    Radix sort with index tracking.

    Args:
        x: Input tensor to sort.
        end_bit: Number of bits to consider when sorting.
        x_out: Output tensor for the sorted values.
        iota_out: Output tensor for the sorted indices.

    Returns:
        The sorted values tensor.
    """
    return ops.sort(x, end_bit, x_out, iota_out)
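
# Usage sketch (illustrative; end_bit bounds the radix width, so for expert
# ids in [0, num_experts) it can be set to ceil(log2(num_experts))):
#
#   x = torch.tensor([3, 1, 2, 0], device="cuda", dtype=torch.int32)
#   x_out, iota_out = torch.empty_like(x), torch.empty_like(x)
#   sort(x, 2, x_out, iota_out)
#   # x_out    -> [0, 1, 2, 3]
#   # iota_out -> [3, 1, 2, 0]   (original position of each sorted value)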


# Convenience functions for common use cases
def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
    """
    Compute a cumulative sum with automatic output allocation.

    Args:
        x: Input tensor.
        dim: Dimension along which to compute the cumsum (default: last dimension).
        exclusive: Whether to compute an exclusive (True) or inclusive (False) cumsum.

    Returns:
        A new tensor containing the cumulative sum.
    """
    out = torch.empty_like(x)
    if exclusive:
        return exclusive_cumsum(x, dim, out)
    else:
        return inclusive_cumsum(x, dim, out)
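
# Usage sketch (illustrative; unlike the in-place wrappers above, this
# allocates the output internally):
#
#   x = torch.tensor([1, 2, 3], device="cuda", dtype=torch.int32)
#   cumsum(x)                    # -> [1, 3, 6]
#   cumsum(x, exclusive=True)    # -> [0, 1, 3]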


def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Sort a tensor and return both the sorted values and the indices.

    Args:
        x: Input tensor to sort.
        end_bit: Number of bits to consider when sorting.

    Returns:
        Tuple of (sorted_values, sorted_indices).
    """
    x_out = torch.empty_like(x)
    iota_out = torch.empty_like(x)
    sort(x, end_bit, x_out, iota_out)
    return x_out, iota_out
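
# Usage sketch (illustrative; the result is named indices_out here to avoid
# shadowing the indices() kernel export above):
#
#   values, indices_out = argsort(torch.tensor([2, 0, 1], device="cuda", dtype=torch.int32))
#   # values -> [0, 1, 2], indices_out -> [1, 2, 0]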


# Export public API
__all__ = [
    # Re-exported submodule
    "layers",
    # Direct kernel exports
    "exclusive_cumsum",
    "inclusive_cumsum",
    "histogram",
    "indices",
    "replicate_forward",
    "replicate_backward",
    "sort",
    "cumsum",
    "argsort",
    # Original exports
    "Arguments",
    "ParallelDroplessMLP",
    "dMoE",
    "SparseGLU",
    "MLP",
    "SparseMLP",
    "MoE",
    "ParallelMLP",
    "get_load_balancing_loss",
]