quant_weights is a CPU function
Browse files
torch-ext/torch_binding.cpp
CHANGED
|
@@ -13,7 +13,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|
| 13 |
ops.impl("preprocess_weights", torch::kCUDA, &preprocess_weights_cuda);
|
| 14 |
ops.def("quant_weights(Tensor origin_weight, ScalarType quant_type,"
|
| 15 |
"bool return_unprocessed_quantized_tensor) -> Tensor[]");
|
| 16 |
-
ops.impl("quant_weights", torch::kCUDA, &symmetric_quantize_last_axis_of_tensor);
|
| 17 |
}
|
| 18 |
|
| 19 |
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
|
|
|
| 13 |
ops.impl("preprocess_weights", torch::kCUDA, &preprocess_weights_cuda);
|
| 14 |
ops.def("quant_weights(Tensor origin_weight, ScalarType quant_type,"
|
| 15 |
"bool return_unprocessed_quantized_tensor) -> Tensor[]");
|
| 16 |
+
ops.impl("quant_weights", torch::kCPU, &symmetric_quantize_last_axis_of_tensor);
|
| 17 |
}
|
| 18 |
|
| 19 |
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|