# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import torch

from ._ops import ops

from .grouped_gemm import backend as gg_backend
from .grouped_gemm import ops as gg_ops


from ._layers.arguments import Arguments
from ._layers.dmoe import ParallelDroplessMLP, dMoE
from ._layers.glu import SparseGLU
from ._layers.mlp import MLP, SparseMLP
from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss

from . import layers

# This section contains the direct kernel exports (not included in the original code)
def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
    """
    Compute exclusive cumulative sum along the specified dimension.

    Args:
        x: Input tensor
        dim: Dimension along which to compute cumsum
        out: Output tensor (modified in-place)

    Returns:
        The output tensor
    """
    result = ops.exclusive_cumsum(x, dim)
    out.copy_(result)
    return out


def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
    """
    Compute inclusive cumulative sum along the specified dimension.

    Args:
        x: Input tensor
        dim: Dimension along which to compute cumsum
        out: Output tensor (modified in-place)

    Returns:
        The output tensor
    """
    result = ops.inclusive_cumsum(x, dim)
    out.copy_(result)
    return out
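

# Usage sketch: a hypothetical example of the pre-allocated-output pattern
# above. It assumes the underlying kernels operate on CUDA integer tensors,
# which is the typical setup for these ops; adjust dtype and device to match
# your build.
def _example_exclusive_cumsum() -> None:
    tokens_per_expert = torch.tensor([4, 0, 2, 3], dtype=torch.int32, device="cuda")
    offsets = torch.empty_like(tokens_per_expert)
    exclusive_cumsum(tokens_per_expert, 0, offsets)
    # offsets is now [0, 4, 4, 6]: the starting row of each expert's tokens.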


def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
    """
    Compute histogram of input tensor values.

    Args:
        x: Input tensor
        num_bins: Number of histogram bins

    Returns:
        Histogram tensor with counts for each bin
    """
    return ops.histogram(x, num_bins)
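

# Usage sketch: a hypothetical routing example that counts how many tokens
# were assigned to each expert. Assumes CUDA integer inputs, as is typical for
# this kernel.
def _example_histogram() -> None:
    num_experts = 4
    expert_ids = torch.tensor([0, 2, 2, 3, 0, 0], dtype=torch.int32, device="cuda")
    counts = histogram(expert_ids, num_experts)
    # counts is [3, 0, 2, 1]: the number of tokens routed to experts 0..3.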


def indices(
    padded_bins: torch.Tensor,
    block_size: int,
    output_block_rows: int,
    output_block_columns: int,
) -> torch.Tensor:
    """
    Construct indices from padded bins for sparse operations.

    Args:
        padded_bins: Tensor containing bin boundaries
        block_size: Size of each block
        output_block_rows: Number of rows in output blocks
        output_block_columns: Number of columns in output blocks

    Returns:
        Tensor containing constructed indices
    """
    return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)


def replicate_forward(
    x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
) -> torch.Tensor:
    """
    Forward pass of replicate operation - replicate values according to bin sizes.

    Args:
        x: Input tensor with values to replicate
        bins: Tensor containing bin sizes
        out: Output tensor (modified in-place)

    Returns:
        The output tensor
    """
    return ops.replicate_forward(x, bins, out)


def replicate_backward(
    grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
) -> torch.Tensor:
    """
    Backward pass of replicate operation - reduce gradients back to bins.

    Args:
        grad: Gradient tensor to reduce
        bins: Tensor containing bin sizes
        out: Output tensor (modified in-place)

    Returns:
        The output tensor
    """
    return ops.replicate_backward(grad, bins, out)


def sort(
    x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
) -> torch.Tensor:
    """
    Radix sort with index tracking.

    Args:
        x: Input tensor to sort
        end_bit: Number of bits to consider in sorting
        x_out: Output tensor for sorted values
        iota_out: Output tensor for sorted indices

    Returns:
        The sorted values tensor
    """
    return ops.sort(x, end_bit, x_out, iota_out)


# Convenience functions for common use cases
def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
    """
    Compute cumulative sum with automatic output allocation.

    Args:
        x: Input tensor
        dim: Dimension along which to compute cumsum (default: last dimension)
        exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum

    Returns:
        New tensor containing the cumulative sum
    """
    out = torch.empty_like(x)
    if exclusive:
        return exclusive_cumsum(x, dim, out)
    else:
        return inclusive_cumsum(x, dim, out)
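

# Usage sketch: the convenience wrapper allocates the output for you.
# exclusive=True yields per-bin start offsets, exclusive=False yields per-bin
# end offsets. Assumes CUDA integer inputs.
def _example_cumsum() -> None:
    counts = torch.tensor([3, 0, 2, 1], dtype=torch.int32, device="cuda")
    starts = cumsum(counts, dim=0, exclusive=True)  # [0, 3, 3, 5]
    ends = cumsum(counts, dim=0, exclusive=False)   # [3, 3, 5, 6]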


def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Sort tensor and return both sorted values and indices.

    Args:
        x: Input tensor to sort
        end_bit: Number of bits to consider in sorting

    Returns:
        Tuple of (sorted_values, sorted_indices)
    """
    x_out = torch.empty_like(x)
    iota_out = torch.empty_like(x)
    sort(x, end_bit, x_out, iota_out)
    return x_out, iota_out
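

# Usage sketch: a hypothetical use of argsort to group tokens by their expert
# assignment. The sorted values give the expert order and the returned indices
# record each slot's original position. end_bit=32 is the safe default; it can
# be lowered to the number of bits needed for the key range if the kernel
# supports it.
def _example_argsort() -> None:
    expert_ids = torch.tensor([2, 0, 3, 0, 2], dtype=torch.int32, device="cuda")
    sorted_ids, gather_indices = argsort(expert_ids)
    # sorted_ids is [0, 0, 2, 2, 3]; indexing tokens with gather_indices.long()
    # arranges them contiguously by expert.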


# Export public API
__all__ = [
    "MyReplacementLayer",
    # Direct kernel exports
    "exclusive_cumsum",
    "inclusive_cumsum",
    "histogram",
    "indices",
    "replicate_forward",
    "replicate_backward",
    "sort",
    "cumsum",
    "argsort",
    # Original exports
    "Arguments",
    "ParallelDroplessMLP",
    "dMoE",
    "SparseGLU",
    "MLP",
    "SparseMLP",
    "MoE",
    "ParallelMLP",
    "get_load_balancing_loss",
]