danieldk (HF Staff) committed
Commit f701ae7 · 1 Parent(s): a1bca63

Add sources

build.toml ADDED
@@ -0,0 +1,16 @@
+ [general]
+ name = "relu_metal"
+ universal = false
+
+ [torch]
+ src = [
+ "torch-ext/torch_binding.cpp",
+ "torch-ext/torch_binding.h"
+ ]
+
+ [kernel.activation]
+ backend = "metal"
+ src = [
+ "relu/relu.mm",
+ ]
+ depends = [ "torch" ]
flake.lock ADDED
@@ -0,0 +1,167 @@
+ {
+ "nodes": {
+ "flake-compat": {
+ "locked": {
+ "lastModified": 1747046372,
+ "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
+ "owner": "edolstra",
+ "repo": "flake-compat",
+ "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
+ "type": "github"
+ },
+ "original": {
+ "owner": "edolstra",
+ "repo": "flake-compat",
+ "type": "github"
+ }
+ },
+ "flake-compat_2": {
+ "locked": {
+ "lastModified": 1733328505,
+ "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
+ "owner": "edolstra",
+ "repo": "flake-compat",
+ "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
+ "type": "github"
+ },
+ "original": {
+ "owner": "edolstra",
+ "repo": "flake-compat",
+ "type": "github"
+ }
+ },
+ "flake-utils": {
+ "inputs": {
+ "systems": "systems"
+ },
+ "locked": {
+ "lastModified": 1731533236,
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
+ "owner": "numtide",
+ "repo": "flake-utils",
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
+ "type": "github"
+ },
+ "original": {
+ "owner": "numtide",
+ "repo": "flake-utils",
+ "type": "github"
+ }
+ },
+ "flake-utils_2": {
+ "inputs": {
+ "systems": "systems_2"
+ },
+ "locked": {
+ "lastModified": 1731533236,
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
+ "owner": "numtide",
+ "repo": "flake-utils",
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
+ "type": "github"
+ },
+ "original": {
+ "owner": "numtide",
+ "repo": "flake-utils",
+ "type": "github"
+ }
+ },
+ "hf-nix": {
+ "inputs": {
+ "flake-compat": "flake-compat_2",
+ "flake-utils": "flake-utils_2",
+ "nixpkgs": "nixpkgs"
+ },
+ "locked": {
+ "lastModified": 1750234878,
+ "narHash": "sha256-q9DRC9zdpzUf88qqg1qbhP1qgJbE2cMtn8oUmosuyT8=",
+ "owner": "huggingface",
+ "repo": "hf-nix",
+ "rev": "c7132f90763d756da3e77da62e01be0a4546dc57",
+ "type": "github"
+ },
+ "original": {
+ "owner": "huggingface",
+ "repo": "hf-nix",
+ "type": "github"
+ }
+ },
+ "kernel-builder": {
+ "inputs": {
+ "flake-compat": "flake-compat",
+ "flake-utils": "flake-utils",
+ "hf-nix": "hf-nix",
+ "nixpkgs": [
+ "kernel-builder",
+ "hf-nix",
+ "nixpkgs"
+ ]
+ },
+ "locked": {
+ "lastModified": 1751910742,
+ "owner": "huggingface",
+ "repo": "kernel-builder",
+ "rev": "f1099723e3df41950b073051839bc2c5b088c380",
+ "type": "github"
+ },
+ "original": {
+ "owner": "huggingface",
+ "repo": "kernel-builder",
+ "type": "github"
+ }
+ },
+ "nixpkgs": {
+ "locked": {
+ "lastModified": 1747820358,
+ "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
+ "owner": "danieldk",
+ "repo": "nixpkgs",
+ "rev": "d3c1681180717528068082103bf323147de6ab0b",
+ "type": "github"
+ },
+ "original": {
+ "owner": "danieldk",
+ "ref": "cudatoolkit-12.9-kernel-builder",
+ "repo": "nixpkgs",
+ "type": "github"
+ }
+ },
+ "root": {
+ "inputs": {
+ "kernel-builder": "kernel-builder"
+ }
+ },
+ "systems": {
+ "locked": {
+ "lastModified": 1681028828,
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
+ "owner": "nix-systems",
+ "repo": "default",
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
+ "type": "github"
+ },
+ "original": {
+ "owner": "nix-systems",
+ "repo": "default",
+ "type": "github"
+ }
+ },
+ "systems_2": {
+ "locked": {
+ "lastModified": 1681028828,
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
+ "owner": "nix-systems",
+ "repo": "default",
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
+ "type": "github"
+ },
+ "original": {
+ "owner": "nix-systems",
+ "repo": "default",
+ "type": "github"
+ }
+ }
+ },
+ "root": "root",
+ "version": 7
+ }
flake.nix ADDED
@@ -0,0 +1,17 @@
+ {
+ description = "Flake for Metal ReLU test kernel";
+
+ inputs = {
+ kernel-builder.url = "github:huggingface/kernel-builder";
+ };
+
+ outputs =
+ {
+ self,
+ kernel-builder,
+ }:
+ kernel-builder.lib.genFlakeOutputs {
+ path = ./.;
+ rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
+ };
+ }
relu/relu.mm ADDED
@@ -0,0 +1,92 @@
+ #include <torch/torch.h>
+
+ #import <Foundation/Foundation.h>
+ #import <Metal/Metal.h>
+ #include <string>
+
+ char const *CUSTOM_KERNEL = R"(
+ #include <metal_stdlib>
+ using namespace metal;
+
+ kernel void relu_forward_kernel_float(device const float *inA [[buffer(0)]],
+ device float *outC [[buffer(1)]],
+ uint index [[thread_position_in_grid]]) {
+ // Explicitly write to output
+ outC[index] = max(0.0f, inA[index]);
+ }
+
+ kernel void relu_forward_kernel_half(device const half *inA [[buffer(0)]],
+ device half *outC [[buffer(1)]],
+ uint index [[thread_position_in_grid]]) {
+ // Explicitly write to output
+ outC[index] = max(static_cast<half>(0.0), inA[index]);
+ }
+ )";
+
+ static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor& tensor) {
+ return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
+ }
+
+ torch::Tensor &dispatchReluKernel(torch::Tensor const &input, torch::Tensor &output) {
+ @autoreleasepool {
+ id<MTLDevice> device = MTLCreateSystemDefaultDevice();
+ NSError *error = nil;
+
+ NSUInteger numThreads = input.numel();
+
+ id<MTLLibrary> customKernelLibrary = [device newLibraryWithSource:[NSString stringWithUTF8String:CUSTOM_KERNEL]
+ options:nil
+ error:&error];
+ TORCH_CHECK(customKernelLibrary, "Failed to create custom kernel library, error: ", error.localizedDescription.UTF8String);
+
+ std::string kernel_name = std::string("relu_forward_kernel_") + (input.scalar_type() == torch::kFloat ? "float" : "half");
+ id<MTLFunction> customReluFunction = [customKernelLibrary newFunctionWithName:[NSString stringWithUTF8String:kernel_name.c_str()]];
+ TORCH_CHECK(customReluFunction, "Failed to create function state object for ", kernel_name.c_str());
+
+ id<MTLComputePipelineState> reluPSO = [device newComputePipelineStateWithFunction:customReluFunction error:&error];
+ TORCH_CHECK(reluPSO, error.localizedDescription.UTF8String);
+
+ id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
+ TORCH_CHECK(commandBuffer, "Failed to retrieve command buffer reference");
+
+ dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue();
+
+ dispatch_sync(serialQueue, ^(){
+ id<MTLComputeCommandEncoder> computeEncoder = [commandBuffer computeCommandEncoder];
+ TORCH_CHECK(computeEncoder, "Failed to create compute command encoder");
+
+ [computeEncoder setComputePipelineState:reluPSO];
+ [computeEncoder setBuffer:getMTLBufferStorage(input) offset:input.storage_offset() * input.element_size() atIndex:0];
+ [computeEncoder setBuffer:getMTLBufferStorage(output) offset:output.storage_offset() * output.element_size() atIndex:1];
+
+ MTLSize gridSize = MTLSizeMake(numThreads, 1, 1);
+
+ NSUInteger threadGroupSize = reluPSO.maxTotalThreadsPerThreadgroup;
+ if (threadGroupSize > numThreads) {
+ threadGroupSize = numThreads;
+ }
+ MTLSize threadgroupSize = MTLSizeMake(threadGroupSize, 1, 1);
+
+ [computeEncoder dispatchThreads:gridSize
+ threadsPerThreadgroup:threadgroupSize];
+
+ [computeEncoder endEncoding];
+
+ torch::mps::commit();
+ });
+ }
+
+ return output;
+ }
+
+ torch::Tensor mps_relu(const torch::Tensor &input) {
+ TORCH_CHECK(input.device().is_mps(), "input must be an MPS tensor");
+ TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
+
+ TORCH_CHECK(input.scalar_type() == torch::kFloat ||
+ input.scalar_type() == torch::kHalf, "Unsupported data type: ", input.scalar_type());
+
+ torch::Tensor output = torch::empty_like(input);
+
+ return dispatchReluKernel(input, output);
+ }
torch-ext/registration.h ADDED
@@ -0,0 +1,30 @@
+ // Registration macros from vLLM:
+ // https://github.com/vllm-project/vllm/blob/main/csrc/core/registration.h
+
+ #pragma once
+
+ #include <Python.h>
+
+ #define _CONCAT(A, B) A##B
+ #define CONCAT(A, B) _CONCAT(A, B)
+
+ #define _STRINGIFY(A) #A
+ #define STRINGIFY(A) _STRINGIFY(A)
+
+ // A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
+ // could be a macro instead of a literal token.
+ #define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
+
+ // A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
+ // could be a macro instead of a literal token.
+ #define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
+ TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
+
+ // REGISTER_EXTENSION allows the shared library to be loaded and initialized
+ // via python's import statement.
+ #define REGISTER_EXTENSION(NAME) \
+ PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
+ static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \
+ STRINGIFY(NAME), nullptr, 0, nullptr}; \
+ return PyModule_Create(&module); \
+ }
torch-ext/relu_metal/__init__.py ADDED
@@ -0,0 +1,7 @@
+ import torch
+
+ from ._ops import ops
+
+
+ def relu(input: torch.Tensor) -> torch.Tensor:
+ return ops.relu(input)
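
For reference, a minimal usage sketch (not part of the commit): it assumes the extension has been built with kernel-builder and the resulting relu_metal package is importable, and that an MPS device is available. It exercises the custom kernel and checks it against the reference torch.relu.

import torch
import relu_metal

# Requires a Metal-capable Mac with PyTorch's MPS backend available.
x = torch.randn(1024, device="mps", dtype=torch.float16)
y = relu_metal.relu(x)

# ReLU is exact (max(0, x)), so the custom kernel should match the reference.
torch.testing.assert_close(y, torch.relu(x))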
torch-ext/torch_binding.cpp ADDED
@@ -0,0 +1,11 @@
+ #include <torch/library.h>
+
+ #include "registration.h"
+ #include "torch_binding.h"
+
+ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
+ ops.def("relu(Tensor input) -> Tensor");
+ ops.impl("relu", torch::kMPS, mps_relu);
+ }
+
+ REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
torch-ext/torch_binding.h ADDED
@@ -0,0 +1,3 @@
+ #include <torch/torch.h>
+
+ torch::Tensor mps_relu(const torch::Tensor &input);