# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
import torch
from ._ops import ops
from .grouped_gemm import backend as gg_backend
from .grouped_gemm import ops as gg_ops
from ._layers.arguments import Arguments
from ._layers.dmoe import ParallelDroplessMLP, dMoE
from ._layers.glu import SparseGLU
from ._layers.mlp import MLP, SparseMLP
from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
from . import layers


# This section contains the direct kernel exports (not included in the original code).
def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
"""
Compute exclusive cumulative sum along the specified dimension.
Args:
x: Input tensor
dim: Dimension along which to compute cumsum
out: Output tensor (modified in-place)
Returns:
The output tensor
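
    Example (illustrative; assumes a CUDA build of the kernels):
        x = torch.tensor([1, 2, 3], dtype=torch.int32, device="cuda")
        out = torch.empty_like(x)
        exclusive_cumsum(x, 0, out)  # out -> [0, 1, 3]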
"""
result = ops.exclusive_cumsum(x, dim)
out.copy_(result)
return out


def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
"""
Compute inclusive cumulative sum along the specified dimension.
Args:
x: Input tensor
dim: Dimension along which to compute cumsum
out: Output tensor (modified in-place)
Returns:
The output tensor
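
    Example (illustrative; assumes a CUDA build of the kernels):
        x = torch.tensor([1, 2, 3], dtype=torch.int32, device="cuda")
        out = torch.empty_like(x)
        inclusive_cumsum(x, 0, out)  # out -> [1, 3, 6]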
"""
result = ops.inclusive_cumsum(x, dim)
out.copy_(result)
return out


def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
"""
Compute histogram of input tensor values.
Args:
x: Input tensor
num_bins: Number of histogram bins
Returns:
Histogram tensor with counts for each bin
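
    Example (illustrative; assumes values lie in [0, num_bins)):
        x = torch.tensor([0, 1, 1, 3], dtype=torch.int32, device="cuda")
        histogram(x, 4)  # -> [1, 2, 0, 1]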
"""
return ops.histogram(x, num_bins)


def indices(
padded_bins: torch.Tensor,
block_size: int,
output_block_rows: int,
output_block_columns: int,
) -> torch.Tensor:
"""
Construct indices from padded bins for sparse operations.
Args:
padded_bins: Tensor containing bin boundaries
block_size: Size of each block
output_block_rows: Number of rows in output blocks
output_block_columns: Number of columns in output blocks
Returns:
Tensor containing constructed indices
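
    Example (illustrative sketch; the sizes below are hypothetical):
        # padded_bins is typically the inclusive cumsum of per-expert padded
        # token counts, e.g. two experts with 128 padded tokens each.
        padded_bins = torch.tensor([128, 256], dtype=torch.int32, device="cuda")
        idx = indices(padded_bins, 128, 2, 4)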
"""
return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)


def replicate_forward(
x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
) -> torch.Tensor:
"""
Forward pass of replicate operation - replicate values according to bin sizes.
Args:
x: Input tensor with values to replicate
bins: Tensor containing bin sizes
out: Output tensor (modified in-place)
Returns:
The output tensor
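
    Example (illustrative sketch; bins are assumed to be cumulative offsets):
        x = torch.tensor([[10.0, 20.0]], device="cuda")
        bins = torch.tensor([3, 5], dtype=torch.int32, device="cuda")
        out = torch.empty((1, 5), device="cuda")
        replicate_forward(x, bins, out)  # out -> [[10., 10., 10., 20., 20.]]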
"""
return ops.replicate_forward(x, bins, out)


def replicate_backward(
grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
) -> torch.Tensor:
"""
Backward pass of replicate operation - reduce gradients back to bins.
Args:
grad: Gradient tensor to reduce
bins: Tensor containing bin sizes
out: Output tensor (modified in-place)
Returns:
The output tensor
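
    Example (illustrative sketch; mirrors the replicate_forward example):
        grad = torch.ones((1, 5), device="cuda")
        bins = torch.tensor([3, 5], dtype=torch.int32, device="cuda")
        out = torch.empty((1, 2), device="cuda")
        replicate_backward(grad, bins, out)  # out -> [[3., 2.]]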
"""
return ops.replicate_backward(grad, bins, out)


def sort(
x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
) -> torch.Tensor:
"""
Radix sort with index tracking.
Args:
x: Input tensor to sort
end_bit: Number of bits to consider in sorting
x_out: Output tensor for sorted values
iota_out: Output tensor for sorted indices
Returns:
The sorted values tensor
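
    Example (illustrative; end_bit can be lowered for small key ranges):
        x = torch.tensor([3, 1, 2], dtype=torch.int32, device="cuda")
        x_out = torch.empty_like(x)
        iota_out = torch.empty_like(x)
        sort(x, 2, x_out, iota_out)  # x_out -> [1, 2, 3], iota_out -> [1, 2, 0]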
"""
return ops.sort(x, end_bit, x_out, iota_out)


# Convenience functions for common use cases
def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
"""
Compute cumulative sum with automatic output allocation.
Args:
x: Input tensor
dim: Dimension along which to compute cumsum (default: last dimension)
exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
Returns:
New tensor containing the cumulative sum
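
    Example (illustrative; assumes a CUDA build of the kernels):
        x = torch.tensor([1, 2, 3], dtype=torch.int32, device="cuda")
        cumsum(x)                  # -> [1, 3, 6]
        cumsum(x, exclusive=True)  # -> [0, 1, 3]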
"""
out = torch.empty_like(x)
if exclusive:
return exclusive_cumsum(x, dim, out)
else:
return inclusive_cumsum(x, dim, out)


def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
"""
Sort tensor and return both sorted values and indices.
Args:
x: Input tensor to sort
end_bit: Number of bits to consider in sorting
Returns:
Tuple of (sorted_values, sorted_indices)
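
    Example (illustrative; assumes integer keys on a CUDA device):
        x = torch.tensor([3, 1, 2], dtype=torch.int32, device="cuda")
        values, idx = argsort(x, end_bit=2)  # values -> [1, 2, 3], idx -> [1, 2, 0]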
"""
x_out = torch.empty_like(x)
iota_out = torch.empty_like(x)
sort(x, end_bit, x_out, iota_out)
return x_out, iota_out


# Export public API
__all__ = [
    # Direct kernel exports
"exclusive_cumsum",
"inclusive_cumsum",
"histogram",
"indices",
"replicate_forward",
"replicate_backward",
"sort",
"cumsum",
"argsort",
# Original exports
"Arguments",
"ParallelDroplessMLP",
"dMoE",
"SparseGLU",
"MLP",
"SparseMLP",
"MoE",
"ParallelMLP",
"get_load_balancing_loss",
]