Add support for ROCm
Files changed:
- build.toml +6 -0
- flake.lock +12 -12
- flake.nix +1 -1
- torch-ext/torch_binding.cpp +11 -0
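This commit lets the quantization kernels build for AMD GPUs. In build.toml, the fp8 and int8 kernels are marked for HIPIFY translation and given a list of ROCm target architectures; flake.nix now pulls kernel-builder from its public GitHub repository (with flake.lock re-locked accordingly); and torch_binding.cpp guards the CUTLASS- and Marlin-based ops with #ifndef USE_ROCM so that only the portable fp8/int8 quantization ops are registered on ROCm builds.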
build.toml CHANGED

@@ -46,8 +46,12 @@ include = [ "." ]
 depends = [ "cutlass_3_6", "torch" ]
 
 [kernel.fp8_common]
+language = "cuda-hipify"
 cuda-capabilities = [ "7.5", "8.0", "8.6", "8.7", "8.9", "9.0", "9.0a" ]
+rocm-archs = [ "gfx906", "gfx908", "gfx90a", "gfx940", "gfx941", "gfx942", "gfx1030", "gfx1100", "gfx1101" ]
 src = [
+  "fp8/amd/hip_float8.h",
+  "fp8/amd/hip_float8_impl.h",
   "fp8/common.cu",
   "fp8/common.cuh",
   "dispatch_utils.h",

@@ -66,7 +70,9 @@ src = [
 depends = [ "torch" ]
 
 [kernel.int8_common]
+language = "cuda-hipify"
 cuda-capabilities = [ "7.5", "8.0", "8.6", "8.7", "8.9", "9.0", "9.0a" ]
+rocm-archs = [ "gfx906", "gfx908", "gfx90a", "gfx940", "gfx941", "gfx942", "gfx1030", "gfx1100", "gfx1101" ]
 src = [
   "compressed_tensors/int8_quant_kernels.cu",
   "dispatch_utils.h"
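The new language = "cuda-hipify" setting marks these kernels for HIPIFY translation when kernel-builder targets ROCm, and rocm-archs selects the AMD targets (gfx906/gfx908/gfx90a/gfx94x cover the MI50 through MI300 datacenter parts, gfx1030 and gfx1100/gfx1101 the RDNA2/RDNA3 consumer GPUs). The fp8/amd headers added to the fp8_common source list supply an fp8 type implementation for AMD GPUs in place of the NVIDIA fp8 intrinsics. As a rough illustration of what the hipify pass handles, the sketch below is a self-contained toy int8-quantization kernel, not code from this repository: every CUDA runtime call in it has a mechanical HIP counterpart (cudaMalloc -> hipMalloc, cudaMemcpy -> hipMemcpy, <cuda_runtime.h> -> <hip/hip_runtime.h>), which is exactly the kind of rewriting the cuda-hipify setting relies on.

// rocm_hipify_example.cu -- minimal, hypothetical sketch (not part of this
// repository) of CUDA code that also builds for ROCm after HIPIFY rewrites
// the runtime API calls and headers.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Toy per-tensor int8 quantization: out[i] = clamp(round(in[i] / scale)).
__global__ void scale_to_int8(const float* in, int8_t* out, float scale, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    float v = roundf(in[i] / scale);
    v = fminf(fmaxf(v, -128.0f), 127.0f);
    out[i] = static_cast<int8_t>(v);
  }
}

int main() {
  const int n = 1024;
  float* d_in = nullptr;
  int8_t* d_out = nullptr;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, n * sizeof(int8_t));
  cudaMemset(d_in, 0, n * sizeof(float));

  // Triple-chevron launches are supported by HIP as well.
  scale_to_int8<<<(n + 255) / 256, 256>>>(d_in, d_out, 0.05f, n);
  cudaDeviceSynchronize();

  cudaFree(d_in);
  cudaFree(d_out);
  printf("launched toy quantization kernel for %d elements\n", n);
  return 0;
}

The point of the approach is that everything AMD-specific stays on the build-system side; the kernel sources themselves remain plain CUDA.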
flake.lock CHANGED

@@ -41,17 +41,17 @@
         "rocm-nix": "rocm-nix"
       },
       "locked": {
-        "lastModified":
-        "narHash": "sha256-
-        "
-        "
-        "
-        "type": "
-        "url": "ssh://[email protected]/huggingface/kernel-builder"
+        "lastModified": 1743416390,
+        "narHash": "sha256-Krrrq9asF2d5SVWGJQIhQA8UxVcTpiCor8hQU4G5J38=",
+        "owner": "huggingface",
+        "repo": "kernel-builder",
+        "rev": "e57cbde93f29032d32bbab8e32a1c86def6e9365",
+        "type": "github"
       },
       "original": {
-        "
-        "
+        "owner": "huggingface",
+        "repo": "kernel-builder",
+        "type": "github"
       }
     },
     "nixpkgs": {

@@ -78,11 +78,11 @@
         ]
       },
       "locked": {
-        "lastModified":
-        "narHash": "sha256-
+        "lastModified": 1743085847,
+        "narHash": "sha256-uWG29p+nhZmGRV1LffWwRGjwtPIXeu1F0YTQbXgB+GU=",
         "owner": "huggingface",
         "repo": "rocm-nix",
-        "rev": "
+        "rev": "245cdc9bfb4bfafa818711c5f5e0b889afe1ba39",
         "type": "github"
       },
       "original": {
flake.nix CHANGED

@@ -2,7 +2,7 @@
   description = "Flake for quantization kernels";
 
   inputs = {
-    kernel-builder.url = "
+    kernel-builder.url = "github:huggingface/kernel-builder";
   };
 
   outputs =
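Together, the two flake changes switch the kernel-builder input from an SSH URL to the public github:huggingface/kernel-builder flake and re-lock it, and they also bump the pinned revision of rocm-nix, the input that packages the ROCm toolchain for the Nix build.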
torch-ext/torch_binding.cpp CHANGED

@@ -4,6 +4,8 @@
 #include "torch_binding.h"
 
 TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
+#ifndef USE_ROCM
+
   // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
   // quantization, as well as bias
   ops.def(

@@ -26,6 +28,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
   ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
 
+#endif
+
   // Compute FP8 quantized tensor for given scaling factor.
   ops.def(
       "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "

@@ -60,6 +64,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
            &dynamic_scaled_int8_quant);
 
+#ifndef USE_ROCM
+
   // fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
   ops.def(
       "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "

@@ -103,8 +109,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "Tensor s_tok, Tensor s_ch, Tensor s_group, "
       "Tensor! workspace, SymInt size_m, SymInt size_n, "
       "SymInt size_k) -> Tensor");
+#endif
 }
 
+#ifndef USE_ROCM
+
 TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, ops) {
   ops.impl("awq_marlin_repack", &awq_marlin_repack);
   ops.impl("fp8_marlin_gemm", &fp8_marlin_gemm);

@@ -120,4 +129,6 @@ TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, ops) {
   ops.impl("gptq_marlin_repack", &gptq_marlin_repack_meta);
 }
 
+#endif
+
 REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
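With these guards, ROCm builds register only the static/dynamic fp8 and int8 quantization ops; the CUTLASS scaled-mm and Marlin ops are compiled out entirely. Code that must work on both backends can therefore probe the dispatcher before calling them. The sketch below is illustrative only: "quantization" stands in for the real extension namespace, which is derived from TORCH_EXTENSION_NAME at build time and is not known here.

// probe_ops.cpp -- minimal sketch (not part of this commit) that checks at
// runtime whether an op was registered, so callers can fall back when the
// CUTLASS/Marlin ops above were compiled out by the USE_ROCM guards.
#include <ATen/core/dispatch/Dispatcher.h>
#include <iostream>
#include <string>

// Returns true if an operator with the given qualified name is registered.
bool has_op(const std::string& qualified_name) {
  auto handle = c10::Dispatcher::singleton().findSchema(
      c10::OperatorName(qualified_name, /*overload_name=*/""));
  return handle.has_value();
}

int main() {
  // "quantization" is a placeholder namespace; the real one comes from
  // TORCH_EXTENSION_NAME at build time.
  std::cout << "cutlass_scaled_mm available: "
            << has_op("quantization::cutlass_scaled_mm") << std::endl;
  return 0;
}

Note that the op only becomes visible after the extension's shared library has been loaded (for example via torch.ops.load_library from Python), so in practice this probe usually lives on the Python side; the C++ version just shows which dispatcher call is involved.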