Commit 1dc29e9: Import EETQ kernels

Note: this view is limited to 50 files because the commit contains too many changes.
- build.toml +85 -0
- cutlass_extensions/include/cutlass_extensions/arch/mma.h +46 -0
- cutlass_extensions/include/cutlass_extensions/compute_occupancy.h +51 -0
- cutlass_extensions/include/cutlass_extensions/epilogue/epilogue_quant_helper.h +48 -0
- cutlass_extensions/include/cutlass_extensions/epilogue/thread/ft_fused_activations.h +148 -0
- cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h +390 -0
- cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h +285 -0
- cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h +82 -0
- cutlass_extensions/include/cutlass_extensions/ft_gemm_configs.h +58 -0
- cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h +123 -0
- cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h +492 -0
- cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm_with_broadcast.h +447 -0
- cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h +89 -0
- cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h +106 -0
- cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h +346 -0
- cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h +315 -0
- cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h +426 -0
- cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h +527 -0
- cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h +236 -0
- cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h +599 -0
- cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h +385 -0
- cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h +127 -0
- cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h +313 -0
- cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h +469 -0
- cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +429 -0
- cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h +61 -0
- cutlass_kernels/cutlass_heuristic.cu +208 -0
- cutlass_kernels/cutlass_heuristic.h +39 -0
- cutlass_kernels/cutlass_preprocessors.cc +703 -0
- cutlass_kernels/cutlass_preprocessors.h +33 -0
- cutlass_kernels/fpA_intB_gemm.cu +99 -0
- cutlass_kernels/fpA_intB_gemm.h +36 -0
- cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h +118 -0
- cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h +858 -0
- cutlass_kernels/fpA_intB_gemm_wrapper.cu +201 -0
- cutlass_kernels/fpA_intB_gemm_wrapper.h +23 -0
- torch-ext/quantization_eetq/__init__.py +3 -0
- torch-ext/quantization_eetq/custom_ops.py +36 -0
- torch-ext/registration.h +27 -0
- torch-ext/torch_binding.cpp +19 -0
- torch-ext/torch_binding.h +25 -0
- utils/activation_types.h +40 -0
- utils/cuda_utils.cc +55 -0
- utils/cuda_utils.h +76 -0
- utils/logger.cc +59 -0
- utils/logger.h +121 -0
- utils/string_utils.h +54 -0
- utils/torch_utils.h +65 -0
- weightOnlyBatchedGemv/common.h +107 -0
- weightOnlyBatchedGemv/enabled.h +105 -0
build.toml
ADDED
@@ -0,0 +1,85 @@
+[general]
+version = "0.0.1"
+
+[torch]
+name = "quantization_eetq"
+src = [
+  "torch-ext/registration.h",
+  "torch-ext/torch_binding.cpp",
+  "torch-ext/torch_binding.h"
+]
+pyroot = "torch-ext"
+
+[kernel.cutlass_kernels]
+capabilities = [ "7.0", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0" ]
+src = [
+  "cutlass_extensions/include/cutlass_extensions/arch/mma.h",
+  "cutlass_extensions/include/cutlass_extensions/compute_occupancy.h",
+  "cutlass_extensions/include/cutlass_extensions/epilogue/epilogue_quant_helper.h",
+  "cutlass_extensions/include/cutlass_extensions/epilogue/thread/ft_fused_activations.h",
+  "cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h",
+  "cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h",
+  "cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h",
+  "cutlass_extensions/include/cutlass_extensions/ft_gemm_configs.h",
+  "cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h",
+  "cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h",
+  "cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm_with_broadcast.h",
+  "cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h",
+  "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h",
+  "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h",
+  "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h",
+  "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h",
+  "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h",
+  "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h",
+  "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h",
+  "cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h",
+  "cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h",
+  "cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h",
+  "cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h",
+  "cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h",
+  "cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h",
+  "cutlass_kernels/cutlass_heuristic.cu",
+  "cutlass_kernels/cutlass_heuristic.h",
+  "cutlass_kernels/cutlass_preprocessors.cc",
+  "cutlass_kernels/cutlass_preprocessors.h",
+  "cutlass_kernels/fpA_intB_gemm.cu",
+  "cutlass_kernels/fpA_intB_gemm.h",
+  "cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h",
+  "cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h",
+  "cutlass_kernels/fpA_intB_gemm_wrapper.cu",
+  "cutlass_kernels/fpA_intB_gemm_wrapper.h",
+  "weightOnlyBatchedGemv/common.h",
+  "weightOnlyBatchedGemv/enabled.h",
+  "utils/activation_types.h",
+  "utils/cuda_utils.h",
+  "utils/logger.cc",
+  "utils/logger.h",
+  "utils/string_utils.h",
+  "utils/torch_utils.h",
+]
+depends = [ "cutlass_2_10", "torch" ]
+include = [ ".", "utils", "cutlass_extensions/include" ]
+
+[kernel.weight_only_batched_gemv]
+capabilities = [ "7.0", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0" ]
+src = [
+  "cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h",
+  "cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h",
+  "weightOnlyBatchedGemv/common.h",
+  "weightOnlyBatchedGemv/enabled.h",
+  "weightOnlyBatchedGemv/kernel.h",
+  "weightOnlyBatchedGemv/kernelLauncher.cu",
+  "weightOnlyBatchedGemv/kernelLauncher.h",
+  "weightOnlyBatchedGemv/utility.h",
+  "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int4b.cu",
+  "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs1Int8b.cu",
+  "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int4b.cu",
+  "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs2Int8b.cu",
+  "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int4b.cu",
+  "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs3Int8b.cu",
+  "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int4b.cu",
+  "weightOnlyBatchedGemv/weightOnlyBatchedGemvBs4Int8b.cu",
+]
+depends = [ "cutlass_2_10", "torch" ]
+include = [ "cutlass_extensions/include" ]
+
cutlass_extensions/include/cutlass_extensions/arch/mma.h
ADDED
@@ -0,0 +1,46 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Templates exposing architecture support for multiply-add operations
+*/
+
+#pragma once
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace arch {
+
+// Tag which selects the MMA path that dequantizes interleaved B fragments to the type of A
+struct OpMultiplyAddDequantizeInterleavedBToA;
+
+}  // namespace arch
+}  // namespace cutlass
cutlass_extensions/include/cutlass_extensions/compute_occupancy.h
ADDED
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <cuda_runtime_api.h>
+
+#include "cutlass/device_kernel.h"
+#include "utils/cuda_utils.h"
+
+namespace fastertransformer {
+
+template<typename GemmKernel>
+inline int compute_occupancy_for_kernel()
+{
+
+    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
+
+    if (smem_size > (48 << 10)) {
+        cudaError_t status =
+            cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+        if (status == cudaError::cudaErrorInvalidValue) {
+            // Clear the error bit since we can ignore this.
+            // This should mean that smem_size > cudaDevAttrMaxSharedMemoryPerBlockOptin. In that case, we return an
+            // occupancy of 0. This will cause the heuristic to ignore this configuration.
+            status = cudaGetLastError();
+            return 0;
+        }
+        check_cuda_error(status);
+    }
+
+    int max_active_blocks = -1;
+    check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &max_active_blocks, cutlass::Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size));
+
+    return max_active_blocks;
+}
+
+}  // namespace fastertransformer
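As a rough illustration of how this helper is meant to be consumed (the real selection logic lives in cutlass_kernels/cutlass_heuristic.cu, which this truncated view does not expand), here is a minimal host-side sketch; the kernel types KernelA and KernelB are hypothetical placeholders for concrete CUTLASS GemmKernel instantiations, not types defined in this commit:

// Hedged sketch: compare two candidate tile configurations by occupancy.
// A configuration whose shared-memory demand exceeds the opt-in limit reports
// an occupancy of 0 above and is therefore never preferred here.
#include "cutlass_extensions/compute_occupancy.h"

template<typename KernelA, typename KernelB>
int pick_higher_occupancy_config()
{
    const int occ_a = fastertransformer::compute_occupancy_for_kernel<KernelA>();
    const int occ_b = fastertransformer::compute_occupancy_for_kernel<KernelB>();
    return (occ_b > occ_a) ? 1 : 0;  // index of the preferred configuration
}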
cutlass_extensions/include/cutlass_extensions/epilogue/epilogue_quant_helper.h
ADDED
@@ -0,0 +1,48 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+#pragma once
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace epilogue {
+
+// define scaling mode
+enum class QuantMode {
+    PerTensorQuant,
+    PerTokenQuant,
+    PerChannelQuant,
+    PerTokenChannelQuant
+};
+
+}  // namespace epilogue
+}  // namespace cutlass
cutlass_extensions/include/cutlass_extensions/epilogue/thread/ft_fused_activations.h
ADDED
@@ -0,0 +1,148 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Functor performing linear combination with a maximum operation used by epilogues.
+*/
+
+#pragma once
+
+#include "cutlass/array.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/epilogue/thread/activation.h"
+#include "cutlass/epilogue/thread/scale_type.h"
+#include "cutlass/functional.h"
+#include "cutlass/half.h"
+#include "cutlass/numeric_conversion.h"
+#include "cutlass/numeric_types.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace epilogue {
+namespace thread {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+__forceinline__ __device__ float copysignf_pos(float a, float b)
+{
+    float r;
+    r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000));
+    return r;
+}
+
+__forceinline__ __device__ float tanh_opt(float x)
+{
+#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750)
+    const float exp_val = -1.f * fabs(2 * x);
+    return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x);
+#else
+    return fast_tanh(x);
+#endif
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+// DdK: GELU_taylor is incomplete in 2.10. Vendored fixes here.
+
+// GELU operator implemented using the Taylor series approximation
+template <typename T>
+struct GELU_taylor_fixed {
+    static const bool kIsHeavy = true;
+    CUTLASS_HOST_DEVICE
+    T operator()(T const &z) const {
+
+        T k0 = T(0.7978845608028654);
+        T k1 = T(0.044715);
+
+        return T(cutlass::constants::half<T>() * z *
+                 (cutlass::constants::one<T>() + fast_tanh(k0 * z * (cutlass::constants::one<T>() + k1 * z * z))));
+    }
+
+    using Params = LinearCombinationGenericParams<T>;
+
+    CUTLASS_HOST_DEVICE
+    T operator()(T const &scalar, Params const &params_) const {
+        return this->operator()(scalar);
+    }
+};
+
+template<>
+struct GELU_taylor_fixed<float> {
+    static const bool kIsHeavy = true;
+    CUTLASS_DEVICE
+    float operator()(float const& z) const
+    {
+
+        float k0 = float(0.7978845608028654);
+        float k1 = float(0.044715);
+
+        return float(
+            cutlass::constants::half<float>() * z
+            * (cutlass::constants::one<float>() + tanh_opt(k0 * z * (cutlass::constants::one<float>() + k1 * z * z))));
+    }
+
+    using Params = LinearCombinationGenericParams<float>;
+
+    CUTLASS_DEVICE
+    float operator()(float const& scalar, Params const& params_) const
+    {
+        return this->operator()(scalar);
+    }
+};
+
+template <typename T, int N>
+struct GELU_taylor_fixed<Array<T, N> > {
+    static const bool kIsHeavy = true;
+    CUTLASS_HOST_DEVICE
+    Array<T, N> operator()(Array<T, N> const &rhs) const {
+        Array<T, N> y;
+        GELU_taylor<T> gelu_op;
+
+        CUTLASS_PRAGMA_UNROLL
+        for (int i = 0; i < N; ++i) {
+            y[i] = gelu_op(rhs[i]);
+        }
+
+        return y;
+    }
+
+    using Params = LinearCombinationGenericParams<T>;
+    CUTLASS_HOST_DEVICE
+    Array<T, N> operator()(Array<T, N> const &rhs, Params const &params_) const {
+        return this->operator()(rhs);
+    }
+};
+
+}  // namespace thread
+}  // namespace epilogue
+}  // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
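Spelled out, GELU_taylor_fixed evaluates the usual tanh approximation of GELU; with the constants k0 = 0.7978845608028654 ≈ sqrt(2/π) and k1 = 0.044715 used in the code, the functor computes

\[
\mathrm{GELU}(z) \;\approx\; \tfrac{1}{2}\, z \Bigl(1 + \tanh\!\bigl(\sqrt{2/\pi}\,\bigl(z + 0.044715\, z^{3}\bigr)\bigr)\Bigr),
\]

where the tanh itself goes through fast_tanh on CUDA 11 / SM 7.5 and newer targets and through the __expf-based tanh_opt fallback otherwise.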
cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h
ADDED
@@ -0,0 +1,390 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Epilogue visitor for threadblock scoped INT8 GEMMs that uses one scaling factor per row, and one per column.
+
+    original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h
+
+*/
+
+#pragma once
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+#include "../epilogue_quant_helper.h"
+#include "cutlass/arch/memory.h"
+#include "cutlass/arch/memory_sm75.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/fast_math.h"
+#include "cutlass/numeric_conversion.h"
+
+namespace cutlass {
+namespace epilogue {
+namespace threadblock {
+
+template<typename ThreadblockShape_,
+         int ThreadCount,
+         typename ScaleTileIterator_,
+         typename OutputTileIterator_,
+         typename ElementAccumulator_,
+         typename ElementCompute_,
+         typename ElementwiseFunctor_,
+         bool UseMasking_ = false>
+class EpilogueVisitorPerRowPerCol {
+public:
+    using ThreadblockShape = ThreadblockShape_;
+    static int const kThreadCount = ThreadCount;
+
+    using ScaleTileIterator = ScaleTileIterator_;
+    using OutputTileIterator = OutputTileIterator_;
+    using ElementwiseFunctor = ElementwiseFunctor_;
+
+    static int const kIterations = OutputTileIterator::kIterations;
+    static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
+
+    using ElementOutput = typename OutputTileIterator::Element;
+    using LayoutOutput = cutlass::layout::RowMajor;
+    using ElementAccumulator = ElementAccumulator_;
+
+    using AlphaScaleElementType = typename ScaleTileIterator::Element;
+
+    using ElementCompute = ElementCompute_;
+    using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
+    using ComputeFragment = Array<ElementCompute_, kElementsPerAccess>;
+    using OutputVector = Array<ElementOutput, kElementsPerAccess>;
+
+    static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth;
+    static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1);
+
+    /// Argument structure
+    struct Arguments {
+
+        typename ElementwiseFunctor::Params elementwise;
+        int64_t batch_stride_alpha;
+        int64_t batch_stride_C;
+        int64_t batch_stride_D;
+
+        //
+        // Methods
+        //
+        Arguments(): batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}
+
+        Arguments(typename ElementwiseFunctor::Params elementwise_):
+            elementwise(elementwise_), batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0)
+        {
+        }
+
+        Arguments(typename ElementwiseFunctor::Params elementwise_,
+                  int64_t batch_stride_alpha_,
+                  int64_t batch_stride_C_,
+                  int64_t batch_stride_D_):
+            elementwise(elementwise_),
+            batch_stride_alpha(batch_stride_alpha_),
+            batch_stride_C(batch_stride_C_),
+            batch_stride_D(batch_stride_D_)
+        {
+        }
+    };
+
+    struct Params {
+
+        typename ElementwiseFunctor::Params elementwise;
+        int64_t batch_stride_alpha;
+        int64_t batch_stride_C;
+        int64_t batch_stride_D;
+        //
+        // Methods
+        //
+        CUTLASS_HOST_DEVICE
+        Params() {}
+
+        CUTLASS_HOST_DEVICE
+        Params(Arguments const& args):
+            elementwise(args.elementwise),
+            batch_stride_alpha(args.batch_stride_alpha),
+            batch_stride_C(args.batch_stride_C),
+            batch_stride_D(args.batch_stride_D)
+        {
+        }
+    };
+
+    /// Shared storage
+    struct SharedStorage {};
+
+private:
+    Params const& params_;
+    SharedStorage& shared_storage_;
+    MatrixCoord extent_;
+    MatrixCoord extent_real_;
+    ElementwiseFunctor elementwise_;
+
+    const bool per_token_quant_;
+    const bool per_channel_quant_;
+
+    AlphaScaleElementType* ptr_alpha_row_;
+    AlphaScaleElementType* ptr_alpha_col_;
+    ScaleTileIterator iterator_alpha_col_;
+    OutputTileIterator iterator_C_;
+    OutputTileIterator iterator_D_;
+
+    AlphaScaleElementType element_alpha_row_ = 1.0f;
+    AlphaScaleElementType element_alpha_col_ = 1.0f;
+    typename ScaleTileIterator::Fragment fragment_alpha_col_;
+    typename OutputTileIterator::Fragment fragment_C_;
+    typename OutputTileIterator::Fragment fragment_D_;
+
+    ElementAccumulator beta_;
+
+    int column_offset_;
+
+    MatrixCoord thread_offset_;
+
+public:
+    CUTLASS_DEVICE
+    EpilogueVisitorPerRowPerCol(Params const& params,
+                                SharedStorage& shared_storage,
+                                cutlass::MatrixCoord const& problem_size,
+                                int thread_idx,
+                                int warp_idx,
+                                int lane_idx,
+                                typename ScaleTileIterator::Params params_alpha_col,
+                                typename OutputTileIterator::Params params_C,
+                                typename OutputTileIterator::Params params_D,
+                                QuantMode quant_mode,
+                                AlphaScaleElementType* ptr_alpha_row,
+                                AlphaScaleElementType* ptr_alpha_col,
+                                typename OutputTileIterator::Element* ptr_C,
+                                typename OutputTileIterator::Element* ptr_D,
+                                cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0),
+                                int column_offset = 0,
+                                cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0)):
+        params_(params),
+        shared_storage_(shared_storage),
+        extent_(problem_size),
+        elementwise_(params.elementwise),
+        per_token_quant_(quant_mode == QuantMode::PerTokenQuant || quant_mode == QuantMode::PerTokenChannelQuant),
+        per_channel_quant_(quant_mode == QuantMode::PerChannelQuant || quant_mode == QuantMode::PerTokenChannelQuant),
+        ptr_alpha_row_(ptr_alpha_row),
+        ptr_alpha_col_(ptr_alpha_col),
+        iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset),
+        iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset),
+        iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset),
+        extent_real_(problem_size_real)
+    {
+        beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta);
+
+        if (beta_ == ElementAccumulator()) {
+            iterator_C_.clear_mask();
+        }
+    }
+
+    /// Helper to indicate split-K behavior
+    CUTLASS_DEVICE
+    void set_k_partition(int split_k_index,  ///< Index of this threadblock within split-K partitioned scheme
+                         int split_k_slices)
+    {  ///< Total number of split-K slices
+    }
+
+    /// Called to set the batch index
+    CUTLASS_DEVICE
+    void set_batch_index(int batch_idx)
+    {
+        iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha);
+        iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C);
+        iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D);
+    }
+
+    /// Called at the start of the epilogue just before iterating over accumulator slices
+    CUTLASS_DEVICE
+    void begin_epilogue()
+    {
+        if (per_channel_quant_) {
+            iterator_alpha_col_.load(fragment_alpha_col_);
+        }
+        else if (ptr_alpha_col_ != nullptr) {
+            arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
+                element_alpha_col_, ptr_alpha_col_, true);
+        }
+
+        if (!per_token_quant_ && ptr_alpha_row_ != nullptr) {
+            arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
+                element_alpha_row_, ptr_alpha_row_, true);
+        }
+    }
+
+    /// Called at the start of one step before starting accumulator exchange
+    CUTLASS_DEVICE
+    void begin_step(int step_idx)
+    {
+        fragment_D_.clear();
+        fragment_C_.clear();
+
+        if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
+            iterator_C_.load(fragment_C_);
+            ++iterator_C_;
+        }
+
+        // load alpha_row in begin_step only when per token(row) scaling is used
+        if (per_token_quant_) {
+            int thread_offset_row =
+                iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(0).row();
+
+            // element_alpha_row_ = ptr_alpha_row_[thread_offset_row];
+            arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
+                element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row());
+        }
+    }
+
+    /// Called at the start of a row
+    CUTLASS_DEVICE
+    void begin_row(int row_idx)
+    {
+        // Clear accumulators for max and sum when starting a whole row
+    }
+
+    /// Called after accumulators have been exchanged for each accumulator vector
+    CUTLASS_DEVICE
+    void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const& accum)
+    {
+
+        NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess> source_converter;
+
+        ComputeFragment result = source_converter(accum);
+        if (per_channel_quant_) {
+            ComputeFragment alpha_col = reinterpret_cast<ComputeFragment*>(&fragment_alpha_col_)[frag_idx];
+            result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_);
+        }
+        else {
+            result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_);
+        }
+
+        /* printf("%d %e\n", accum[0], result[0]); */
+        /* scale_accumulator_(result, alpha_row_vector[0]); //TODO(mseznec) */
+
+        /* if (elementwise_.kScale == cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { */
+        /*     result = source_converter(elementwise_(result)); */
+        /* } else { */
+        /*     result = source_converter(elementwise_(result, source_vector)); */
+        /* } */
+
+        /* // Convert to the output */
+        NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> output_converter;
+        OutputVector& output = reinterpret_cast<OutputVector*>(&fragment_D_)[frag_idx];
+        output = output_converter(result);
+    }
+
+    /// Called at the end of a row
+    CUTLASS_DEVICE
+    void end_row(int row_idx)
+    {
+
+        /* using ConvertSumOutput = cutlass::NumericConverter<ElementSum, ElementSoftmaxCompute>; */
+        /* using ConvertNormOutput = cutlass::NumericConverter<ElementNorm, ElementSoftmaxCompute>; */
+
+        /* ConvertSumOutput convert_sum_output; */
+        /* ConvertNormOutput convert_norm_output; */
+
+        /* // Compute accumulate sum only in the last step */
+        /* accum_sum_ = warp_reduce_sum_(accum_sum_); */
+
+        /* bool is_first_thread_in_tile = ((threadIdx.x % kThreadsPerRow) == 0); */
+        /* bool row_guard = thread_offset_.row() < extent_.row(); */
+        /* bool is_write_thread = row_guard && is_first_thread_in_tile; */
+
+        /* int block_batch = blockIdx.z; */
+
+        /* ElementNorm *curr_ptr_max = ptr_Max_ + thread_offset_.row() + column_offset_ + block_batch *
+         * params_.batch_stride_Max; */
+        /* ElementSum *curr_ptr_sum = ptr_Sum_ + thread_offset_.row() + column_offset_ + block_batch *
+         * params_.batch_stride_Sum; */
+
+        /* arch::global_store<ElementNorm, sizeof(ElementNorm)>( */
+        /*     convert_norm_output(accum_max_), */
+        /*     (void *)curr_ptr_max, */
+        /*     is_write_thread); */
+
+        /* arch::global_store<ElementSum, sizeof(ElementSum)>( */
+        /*     convert_sum_output(accum_sum_), */
+        /*     (void *)curr_ptr_sum, */
+        /*     is_write_thread); */
+
+        /* // Clear accumulators for max and sum when finishing a whole row */
+        /* clear_accum_(); */
+    }
+
+    /// Called after all accumulator elements have been visited
+    CUTLASS_DEVICE
+    void end_step(int step_idx)
+    {
+
+        iterator_D_.store(fragment_D_);
+        ++iterator_D_;
+    }
+
+    /// Called after all steps have been completed
+    CUTLASS_DEVICE
+    void end_epilogue() {}
+
+private:
+    CUTLASS_DEVICE
+    ComputeFragment per_token_channel_scale_accumulator_(ComputeFragment const& accum,
+                                                         ComputeFragment const& scale_col,
+                                                         AlphaScaleElementType const& scale_row)
+    {
+
+        ComputeFragment result;
+        CUTLASS_PRAGMA_UNROLL
+        for (int i = 0; i < ComputeFragment::kElements; ++i) {
+            result[i] = accum[i] * (scale_col[i] * scale_row);
+        }
+
+        return result;
+    }
+
+    CUTLASS_DEVICE
+    ComputeFragment per_token_scale_accumulator_(ComputeFragment const& accum,
+                                                 AlphaScaleElementType const& scale_col,
+                                                 AlphaScaleElementType const& scale_row)
+    {
+
+        ComputeFragment result;
+        CUTLASS_PRAGMA_UNROLL
+        for (int i = 0; i < ComputeFragment::kElements; ++i) {
+            result[i] = accum[i] * (scale_col * scale_row);
+        }
+
+        return result;
+    }
+};
+
+}  // namespace threadblock
+}  // namespace epilogue
+}  // namespace cutlass
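In terms of the per-element math, visit() converts the accumulator fragment to ElementCompute, applies the row and column dequantization scales, and only then converts to ElementOutput; for output element (i, j) this amounts to

\[
D_{ij} \;=\; \mathrm{convert}\bigl(\alpha^{\mathrm{row}}_{i}\,\alpha^{\mathrm{col}}_{j}\,\mathrm{acc}_{ij}\bigr),
\]

where \(\alpha^{\mathrm{row}}\) collapses to a single scalar unless per-token quantization is enabled and \(\alpha^{\mathrm{col}}\) collapses to a single scalar unless per-channel quantization is enabled, matching the four QuantMode values defined in epilogue_quant_helper.h.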
cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
|
| 33 |
+
|
| 34 |
+
The epilogue rearranges the result of a matrix product through shared memory to match canonical
|
| 35 |
+
tensor layouts in global memory. Epilogues support conversion and reduction operations.
|
| 36 |
+
|
| 37 |
+
original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h
|
| 38 |
+
|
| 39 |
+
*/
|
| 40 |
+
|
| 41 |
+
#pragma once
|
| 42 |
+
|
| 43 |
+
#include "cutlass/array.h"
|
| 44 |
+
#include "cutlass/cutlass.h"
|
| 45 |
+
#include "cutlass/numeric_types.h"
|
| 46 |
+
|
| 47 |
+
#include "cutlass/platform/platform.h"
|
| 48 |
+
|
| 49 |
+
#include "cutlass/gemm/gemm.h"
|
| 50 |
+
|
| 51 |
+
#include "cutlass/epilogue/thread/linear_combination.h"
|
| 52 |
+
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
|
| 53 |
+
#include "cutlass/epilogue/thread/linear_combination_gelu.h"
|
| 54 |
+
#include "cutlass/epilogue/thread/linear_combination_hardswish.h"
|
| 55 |
+
#include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
|
| 56 |
+
#include "cutlass/epilogue/thread/linear_combination_relu.h"
|
| 57 |
+
#include "cutlass/epilogue/thread/linear_combination_relu0.h"
|
| 58 |
+
#include "cutlass/epilogue/thread/linear_combination_sigmoid.h"
|
| 59 |
+
|
| 60 |
+
#include "cutlass/epilogue/thread/conversion_op.h"
|
| 61 |
+
#include "cutlass/epilogue/thread/reduction_op.h"
|
| 62 |
+
|
| 63 |
+
#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
|
| 64 |
+
|
| 65 |
+
#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
|
| 66 |
+
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
|
| 67 |
+
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h"
|
| 68 |
+
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h"
|
| 69 |
+
#include "cutlass/epilogue/threadblock/shared_load_iterator.h"
|
| 70 |
+
#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h"
|
| 71 |
+
#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h"
|
| 72 |
+
#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
|
| 73 |
+
#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
|
| 74 |
+
#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h"
|
| 75 |
+
|
| 76 |
+
#include "cutlass/epilogue/threadblock/epilogue.h"
|
| 77 |
+
#include "cutlass/epilogue/threadblock/interleaved_epilogue.h"
|
| 78 |
+
|
| 79 |
+
#include "cutlass/layout/permute.h"
|
| 80 |
+
|
| 81 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 82 |
+
|
| 83 |
+
namespace cutlass {
|
| 84 |
+
namespace epilogue {
|
| 85 |
+
namespace threadblock {
|
| 86 |
+
|
| 87 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 88 |
+
|
| 89 |
+
namespace detail {
|
| 90 |
+
|
| 91 |
+
/// Partial specialization for half <= int32_t x 8 epilogues avoids shared memory bank conflicts.
|
| 92 |
+
template<typename ThreadblockShape, typename WarpShape, typename InstructionShape, typename ThreadMap>
|
| 93 |
+
struct DefaultIteratorsTensorOp<cutlass::half_t, int32_t, 8, ThreadblockShape, WarpShape, InstructionShape, ThreadMap> {
|
| 94 |
+
|
| 95 |
+
using WarpTileIterator =
|
| 96 |
+
cutlass::epilogue::warp::TileIteratorTensorOp<WarpShape, InstructionShape, int32_t, layout::RowMajor>;
|
| 97 |
+
|
| 98 |
+
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<ThreadMap, int32_t>;
|
| 99 |
+
|
| 100 |
+
static int const kFragmentsPerIteration = 1;
|
| 101 |
+
};
|
| 102 |
+
|
| 103 |
+
/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared memory bank conflicts.
|
| 104 |
+
template<typename ThreadblockShape, typename WarpShape, typename InstructionShape, typename ThreadMap>
|
| 105 |
+
struct DefaultIteratorsTensorOp<cutlass::bfloat16_t,
|
| 106 |
+
int32_t,
|
| 107 |
+
8,
|
| 108 |
+
ThreadblockShape,
|
| 109 |
+
WarpShape,
|
| 110 |
+
InstructionShape,
|
| 111 |
+
ThreadMap> {
|
| 112 |
+
|
| 113 |
+
using WarpTileIterator =
|
| 114 |
+
cutlass::epilogue::warp::TileIteratorTensorOp<WarpShape, InstructionShape, int32_t, layout::RowMajor>;
|
| 115 |
+
|
| 116 |
+
using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<ThreadMap, int32_t>;
|
| 117 |
+
|
| 118 |
+
static int const kFragmentsPerIteration = 1;
|
| 119 |
+
};
|
| 120 |
+
|
| 121 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 122 |
+
|
| 123 |
+
} // namespace detail
|
| 124 |
+
|
| 125 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 126 |
+
|
| 127 |
+
/// Tile iterator used to load output tile from shared memory in epilogue.
|
| 128 |
+
///
|
| 129 |
+
/// Satisfies: ReadableTileIterator
|
| 130 |
+
///
|
| 131 |
+
template<typename ThreadMap_ ///< Thread map (conept: OutputTileThreadMap)
|
| 132 |
+
>
|
| 133 |
+
class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, 16, 8, 8> {
|
| 134 |
+
public:
|
| 135 |
+
using ThreadMap = ThreadMap_;
|
| 136 |
+
using Shape = typename ThreadMap::Shape;
|
| 137 |
+
|
| 138 |
+
using Element = int32_t;
|
| 139 |
+
|
| 140 |
+
using Layout = layout::RowMajor;
|
| 141 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 142 |
+
using ConstTensorRef = typename TensorRef::ConstTensorRef;
|
| 143 |
+
|
| 144 |
+
using Index = typename Layout::Index;
|
| 145 |
+
using LongIndex = typename Layout::LongIndex;
|
| 146 |
+
using TensorCoord = MatrixCoord;
|
| 147 |
+
|
| 148 |
+
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
|
| 149 |
+
|
| 150 |
+
static int const kAlignment = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value / 8;
|
| 151 |
+
|
| 152 |
+
static int const kThreads = ThreadMap::kThreads;
|
| 153 |
+
|
| 154 |
+
/// Fragment object
|
| 155 |
+
using Fragment = Array<Element,
|
| 156 |
+
ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow * ThreadMap::Iterations::kGroup
|
| 157 |
+
* ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>;
|
| 158 |
+
|
| 159 |
+
/// Memory access size
|
| 160 |
+
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess, kAlignment>;
|
| 161 |
+
|
| 162 |
+
/// Vector type used for SMEM loads
|
| 163 |
+
using LoadType = AlignedArray<Element,
|
| 164 |
+
const_min(128 / sizeof_bits<Element>::value, ThreadMap::kElementsPerAccess),
|
| 165 |
+
const_min(16, kAlignment)>;
|
| 166 |
+
|
| 167 |
+
static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements;
|
| 168 |
+
|
| 169 |
+
private:
|
| 170 |
+
//
|
| 171 |
+
// Data members
|
| 172 |
+
//
|
| 173 |
+
|
| 174 |
+
/// Byte-level pointer
|
| 175 |
+
LoadType const* pointers_[kLoadsPerAccess];
|
| 176 |
+
|
| 177 |
+
/// Stride along adjacent rows in units of LoadType
|
| 178 |
+
int stride_;
|
| 179 |
+
|
| 180 |
+
public:
|
| 181 |
+
//
|
| 182 |
+
// Methods
|
| 183 |
+
//
|
| 184 |
+
|
| 185 |
+
/// Constructor
|
| 186 |
+
CUTLASS_DEVICE
|
| 187 |
+
SharedLoadIteratorMixed(TensorRef ref, int thread_idx): stride_((ref.stride(0) / LoadType::kElements))
|
| 188 |
+
{
|
| 189 |
+
|
| 190 |
+
TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx);
|
| 191 |
+
|
| 192 |
+
// Initialize pointers
|
| 193 |
+
CUTLASS_PRAGMA_UNROLL
|
| 194 |
+
for (int i = 0; i < kLoadsPerAccess; ++i) {
|
| 195 |
+
pointers_[i] = reinterpret_cast<LoadType const*>(ref.data());
|
| 196 |
+
|
| 197 |
+
int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess;
|
| 198 |
+
int bank_offset = (col_idx * int(sizeof(LoadType)) / 128) % kLoadsPerAccess;
|
| 199 |
+
|
| 200 |
+
col_idx += (bank_offset + i) % kLoadsPerAccess;
|
| 201 |
+
|
| 202 |
+
pointers_[i] += thread_offset.row() * stride_ + col_idx;
|
| 203 |
+
}
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
/// Adds a pointer offset in units of Element
|
| 207 |
+
CUTLASS_HOST_DEVICE
|
| 208 |
+
void add_pointer_offset(LongIndex pointer_offset)
|
| 209 |
+
{
|
| 210 |
+
CUTLASS_PRAGMA_UNROLL
|
| 211 |
+
for (int i = 0; i < kLoadsPerAccess; ++i) {
|
| 212 |
+
pointers_[i] += pointer_offset / LoadType::kElements;
|
| 213 |
+
}
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
CUTLASS_DEVICE
|
| 217 |
+
void add_tile_offset(TensorCoord const& offset)
|
| 218 |
+
{
|
| 219 |
+
CUTLASS_PRAGMA_UNROLL
|
| 220 |
+
for (int i = 0; i < kLoadsPerAccess; ++i) {
|
| 221 |
+
pointers_[i] +=
|
| 222 |
+
offset.row() * Shape::kRow * stride_ + offset.column() * Shape::kColumn / LoadType::kElements;
|
| 223 |
+
}
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
/// Loads a fragment from memory
|
| 227 |
+
CUTLASS_DEVICE
|
| 228 |
+
void load_with_pointer_offset(Fragment& frag, Index pointer_offset) const
|
| 229 |
+
{
|
| 230 |
+
|
| 231 |
+
CUTLASS_PRAGMA_UNROLL
|
| 232 |
+
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {
|
| 233 |
+
|
| 234 |
+
CUTLASS_PRAGMA_UNROLL
|
| 235 |
+
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
|
| 236 |
+
|
| 237 |
+
CUTLASS_PRAGMA_UNROLL
|
| 238 |
+
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
|
| 239 |
+
|
| 240 |
+
int row_ptr_offset =
|
| 241 |
+
row * ThreadMap::Delta::kRow * stride_ + group * ThreadMap::Delta::kGroup * stride_
|
| 242 |
+
+ cluster * ThreadMap::Delta::kCluster * stride_ + pointer_offset / LoadType::kElements;
|
| 243 |
+
|
| 244 |
+
int frag_row_idx =
|
| 245 |
+
(row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
|
| 246 |
+
|
| 247 |
+
LoadType* frag_ptr = reinterpret_cast<LoadType*>(&frag);
|
| 248 |
+
|
| 249 |
+
CUTLASS_PRAGMA_UNROLL
|
| 250 |
+
for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
|
| 251 |
+
|
| 252 |
+
int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column;
|
| 253 |
+
|
| 254 |
+
CUTLASS_PRAGMA_UNROLL
|
| 255 |
+
for (int v = 0; v < kLoadsPerAccess; ++v) {
|
| 256 |
+
|
| 257 |
+
int vector_idx =
|
| 258 |
+
(column * ThreadMap::Delta::kColumn / kElementsPerAccess * kLoadsPerAccess);
|
| 259 |
+
|
| 260 |
+
LoadType const* memory_pointer = pointers_[v] + row_ptr_offset;
|
| 261 |
+
|
| 262 |
+
frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[vector_idx];
|
| 263 |
+
}
|
| 264 |
+
}
|
| 265 |
+
}
|
| 266 |
+
}
|
| 267 |
+
}
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
/// Loads a fragment
|
| 271 |
+
CUTLASS_DEVICE
|
| 272 |
+
void load(Fragment& frag) const
|
| 273 |
+
{
|
| 274 |
+
|
| 275 |
+
load_with_pointer_offset(frag, 0);
|
| 276 |
+
}
|
| 277 |
+
};
|
| 278 |
+
|
| 279 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 280 |
+
|
| 281 |
+
} // namespace threadblock
|
| 282 |
+
} // namespace epilogue
|
| 283 |
+
} // namespace cutlass
|
| 284 |
+
|
| 285 |
+
////////////////////////////////////////////////////////////////////////////////
|
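For illustration only (not part of the imported files): the int32_t specialization above sizes its shared-memory vectors as const_min(128 / sizeof_bits<Element>::value, kElementsPerAccess), i.e. at most four 32-bit accumulators (16 bytes) per load. The standalone sketch below replays that arithmetic with an assumed ThreadMap::kElementsPerAccess of 4; it does not depend on CUTLASS.

#include <algorithm>
#include <cstdio>

// Mirrors the kAlignment / LoadType / kLoadsPerAccess arithmetic of the
// SharedLoadIteratorMixed<..., int32_t, 32, 16, 8, 8> specialization above.
// kElementsPerAccess = 4 is an assumed thread-map value, not taken from the diff.
int main() {
    constexpr int kElementBits       = 32;                                               // int32_t accumulators
    constexpr int kElementsPerAccess = 4;                                                // assumed ThreadMap value
    constexpr int kAlignment         = kElementsPerAccess * kElementBits / 8;            // 16 bytes
    constexpr int kLoadElements      = std::min(128 / kElementBits, kElementsPerAccess); // elements per LoadType
    constexpr int kLoadsPerAccess    = kElementsPerAccess / kLoadElements;               // SMEM loads per access
    std::printf("alignment=%dB, load width=%d, loads per access=%d\n", kAlignment, kLoadElements, kLoadsPerAccess);
    return 0;
}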
cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h
ADDED
|
@@ -0,0 +1,82 @@
/**
 * @file epilogue_helpers.h
 *
 * This file includes types for the epilogues. The empty structs exist so we can signal to template
 * code the type of epilogue we want to run, and let the underlying code specify the details such as
 * element types, accumulator type and elements per vector access.
 *
 */

#pragma once

#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/linear_combination_generic.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/epilogue/thread/linear_combination_silu.h"
#include "cutlass_extensions/epilogue/thread/ft_fused_activations.h"

namespace fastertransformer {

struct EpilogueOpBiasSilu {};

struct EpilogueOpBiasReLU {};

struct EpilogueOpBiasFtGelu {};

struct EpilogueOpBias {};

struct EpilogueOpNoBias {};

template<typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator, typename Op>
struct Epilogue {
};

template<typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasSilu> {
    using Op = cutlass::epilogue::thread::LinearCombinationSilu<ElementType,
                                                                ElementsPerVectorAccess,
                                                                ElementAccumulator,
                                                                ElementAccumulator,
                                                                cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
};

template<typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasReLU> {
    using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType,
                                                                ElementsPerVectorAccess,
                                                                ElementAccumulator,
                                                                ElementAccumulator,
                                                                cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
};

template<typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasFtGelu> {
    using Op = cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::GELU_taylor_fixed,
                                                                   ElementType,
                                                                   ElementsPerVectorAccess,
                                                                   ElementAccumulator,
                                                                   ElementAccumulator,
                                                                   cutlass::epilogue::thread::ScaleType::NoBetaScaling,
                                                                   cutlass::FloatRoundStyle::round_to_nearest,
                                                                   true>;
};

template<typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBias> {
    using Op = cutlass::epilogue::thread::LinearCombination<ElementType,
                                                            ElementsPerVectorAccess,
                                                            ElementAccumulator,
                                                            ElementAccumulator,
                                                            cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
};

template<typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpNoBias> {
    using Op = cutlass::epilogue::thread::LinearCombination<ElementType,
                                                            ElementsPerVectorAccess,
                                                            ElementAccumulator,
                                                            ElementAccumulator,
                                                            cutlass::epilogue::thread::ScaleType::Default>;
};

} // namespace fastertransformer
|
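For illustration only (not part of the imported files): downstream code never spells out the CUTLASS functor directly; it names one of the empty tag structs above and pulls the functor out of Epilogue<...>::Op. A minimal sketch of that resolution, assuming fp16 output, an fp16 accumulator, and 128-bit (8-element) stores; those choices are assumptions made for the example, not requirements of the header.

#include "cutlass_extensions/epilogue_helpers.h"

// Resolve the bias epilogue functor for half precision with 8 elements per vector access.
using BiasEpilogueOp = typename fastertransformer::
    Epilogue<cutlass::half_t, 8, cutlass::half_t, fastertransformer::EpilogueOpBias>::Op;

// LinearCombination exposes the vector width it was instantiated with.
static_assert(BiasEpilogueOp::kCount == 8,
              "the functor inherits ElementsPerVectorAccess as its vector width");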
cutlass_extensions/include/cutlass_extensions/ft_gemm_configs.h
ADDED
|
@@ -0,0 +1,58 @@
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

namespace fastertransformer {
// Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K shape
// in the kernel layout details when doing weight only quantization.
enum class CutlassTileConfig {
    // Signals that no config has been selected
    Undefined,

    // Signals that we should run heuristics to choose a config
    ChooseWithHeuristic,

    // SIMT config
    CtaShape128x128x8_WarpShape64x64x8,

    // TensorCore configs CTA_N = 128, CTA_K = 64
    // Warp configs for M=32
    CtaShape32x128x64_WarpShape32x32x64,

    // Warp configs for M=64
    CtaShape64x128x64_WarpShape32x64x64,
    CtaShape64x128x64_WarpShape64x32x64,

    // Warp configs for M=128
    CtaShape128x128x64_WarpShape64x32x64,
    CtaShape128x128x64_WarpShape128x32x64
};

enum class SplitKStyle {
    NO_SPLIT_K,
    SPLIT_K_SERIAL,
    // SPLIT_K_PARALLEL // Not supported yet
};

struct CutlassGemmConfig {
    CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic;
    SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K;
    int split_k_factor = -1;
    int stages = -1;
};

} // namespace fastertransformer
|
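For illustration only (not part of the imported files): CutlassGemmConfig defaults to ChooseWithHeuristic with -1 sentinels for the split-k factor and stage count, so a caller that wants a fixed configuration has to set all four fields. A small sketch, with an arbitrary tile and stage choice picked purely for the example:

#include "cutlass_extensions/ft_gemm_configs.h"

// Build a fixed runtime config instead of letting the heuristic pick one.
// The tile shape and stage count below are illustrative, not a recommendation.
fastertransformer::CutlassGemmConfig make_fixed_config() {
    fastertransformer::CutlassGemmConfig config;
    config.tile_config    = fastertransformer::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64;
    config.split_k_style  = fastertransformer::SplitKStyle::NO_SPLIT_K;
    config.split_k_factor = 1;
    config.stages         = 3;  // multistage mainloop, SM80-class GPUs
    return config;
}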
cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h
ADDED
|
@@ -0,0 +1,123 @@
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include "cutlass/arch/arch.h"
|
| 4 |
+
#include "cutlass/arch/mma.h"
|
| 5 |
+
#include "cutlass/bfloat16.h"
|
| 6 |
+
#include "cutlass/cutlass.h"
|
| 7 |
+
#include "cutlass/gemm/gemm.h"
|
| 8 |
+
#include "cutlass/layout/matrix.h"
|
| 9 |
+
|
| 10 |
+
#include "cutlass_extensions/arch/mma.h"
|
| 11 |
+
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
|
| 12 |
+
|
| 13 |
+
namespace cutlass {
|
| 14 |
+
namespace gemm {
|
| 15 |
+
namespace kernel {
|
| 16 |
+
|
| 17 |
+
template<typename TypeA, typename TypeB, typename arch, typename Enable = void>
|
| 18 |
+
struct MixedGemmArchTraits {
|
| 19 |
+
};
|
| 20 |
+
|
| 21 |
+
template<typename arch>
|
| 22 |
+
struct MixedGemmArchTraits<float, float, arch> {
|
| 23 |
+
static constexpr int Stages = 2;
|
| 24 |
+
using OperatorClass = cutlass::arch::OpClassSimt;
|
| 25 |
+
using AccType = float;
|
| 26 |
+
using LayoutB = cutlass::layout::RowMajor;
|
| 27 |
+
|
| 28 |
+
static constexpr int ElementsPerAccessA = 1;
|
| 29 |
+
static constexpr int ElementsPerAccessB = 1;
|
| 30 |
+
static constexpr int ElementsPerAccessC = 1;
|
| 31 |
+
static constexpr int ThreadblockK = 8;
|
| 32 |
+
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
|
| 33 |
+
|
| 34 |
+
using Operator = cutlass::arch::OpMultiplyAdd;
|
| 35 |
+
};
|
| 36 |
+
|
| 37 |
+
// ========================= Volta Traits ===========================
|
| 38 |
+
// Volta will always dequantize after the global memory load.
|
| 39 |
+
// This will instantiate any HMMA tensorcore kernels for Volta.
|
| 40 |
+
// Note that volta does not have native bfloat support so weights and activations will be casted to fp16
|
| 41 |
+
// and compute will happen in fp16 then will be converted for bf16 output.
|
| 42 |
+
template<typename TypeA, typename TypeB>
|
| 43 |
+
struct MixedGemmArchTraits<
|
| 44 |
+
TypeA,
|
| 45 |
+
TypeB,
|
| 46 |
+
cutlass::arch::Sm70,
|
| 47 |
+
typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
|
| 48 |
+
|| cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type> {
|
| 49 |
+
private:
|
| 50 |
+
using LayoutDetails = LayoutDetailsB<TypeB, cutlass::arch::Sm70>;
|
| 51 |
+
|
| 52 |
+
public:
|
| 53 |
+
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
|
| 54 |
+
|
| 55 |
+
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
| 56 |
+
using AccType = float;
|
| 57 |
+
using LayoutB = typename LayoutDetails::Layout;
|
| 58 |
+
|
| 59 |
+
static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
|
| 60 |
+
static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
|
| 61 |
+
static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
|
| 62 |
+
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
|
| 63 |
+
|
| 64 |
+
using Operator = typename LayoutDetails::Operator;
|
| 65 |
+
};
|
| 66 |
+
|
| 67 |
+
// ======================= Turing Traits ==============================
|
| 68 |
+
// Note that turing does not have native bfloat support so weights and activations will be casted to fp16
|
| 69 |
+
// and compute will happen in fp16 then will be converted for bf16 output.
|
| 70 |
+
template<typename TypeA, typename TypeB>
|
| 71 |
+
struct MixedGemmArchTraits<
|
| 72 |
+
TypeA,
|
| 73 |
+
TypeB,
|
| 74 |
+
cutlass::arch::Sm75,
|
| 75 |
+
typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
|
| 76 |
+
|| cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type> {
|
| 77 |
+
private:
|
| 78 |
+
using LayoutDetails = LayoutDetailsB<TypeB, cutlass::arch::Sm75>;
|
| 79 |
+
|
| 80 |
+
public:
|
| 81 |
+
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
|
| 82 |
+
|
| 83 |
+
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
| 84 |
+
using AccType = float;
|
| 85 |
+
using LayoutB = typename LayoutDetails::Layout;
|
| 86 |
+
|
| 87 |
+
static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
|
| 88 |
+
static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
|
| 89 |
+
static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
|
| 90 |
+
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
|
| 91 |
+
|
| 92 |
+
using Operator = typename LayoutDetails::Operator;
|
| 93 |
+
};
|
| 94 |
+
|
| 95 |
+
// ======================= Ampere Traits ==============================
|
| 96 |
+
template<typename TypeA, typename TypeB>
|
| 97 |
+
struct MixedGemmArchTraits<
|
| 98 |
+
TypeA,
|
| 99 |
+
TypeB,
|
| 100 |
+
cutlass::arch::Sm80,
|
| 101 |
+
typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
|
| 102 |
+
|| cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type> {
|
| 103 |
+
private:
|
| 104 |
+
using LayoutDetails = LayoutDetailsB<TypeB, cutlass::arch::Sm80>;
|
| 105 |
+
|
| 106 |
+
public:
|
| 107 |
+
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
|
| 108 |
+
|
| 109 |
+
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
| 110 |
+
using AccType = float;
|
| 111 |
+
using LayoutB = typename LayoutDetails::Layout;
|
| 112 |
+
|
| 113 |
+
static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
|
| 114 |
+
static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
|
| 115 |
+
static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
|
| 116 |
+
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
| 117 |
+
|
| 118 |
+
using Operator = typename LayoutDetails::Operator;
|
| 119 |
+
};
|
| 120 |
+
|
| 121 |
+
} // namespace kernel
|
| 122 |
+
} // namespace gemm
|
| 123 |
+
} // namespace cutlass
|
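For illustration only (not part of the imported files): MixedGemmArchTraits resolves everything at compile time, so its selections can be checked with static_asserts. The sketch below does that for fp16 activations with 8-bit weights on Sm80; using uint8_t as TypeB is an assumption about how the wrappers in this import instantiate the traits.

#include <cstdint>

#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"

using AmpereTraits =
    cutlass::gemm::kernel::MixedGemmArchTraits<cutlass::half_t, uint8_t, cutlass::arch::Sm80>;

// 128-bit loads of half_t give 8 elements per access for operand A.
static_assert(AmpereTraits::ElementsPerAccessA == 8, "unexpected A access width");

// The Sm80 specialization always routes through the tensor-core path.
static_assert(cutlass::platform::is_same<AmpereTraits::OperatorClass,
                                         cutlass::arch::OpClassTensorOp>::value,
              "Ampere traits should select OpClassTensorOp");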
cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h
ADDED
|
@@ -0,0 +1,492 @@
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include "cutlass/cutlass.h"
|
| 39 |
+
|
| 40 |
+
#include "cutlass/arch/arch.h"
|
| 41 |
+
#include "cutlass/gemm/gemm.h"
|
| 42 |
+
#include "cutlass/matrix_coord.h"
|
| 43 |
+
#include "cutlass/semaphore.h"
|
| 44 |
+
|
| 45 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 46 |
+
|
| 47 |
+
namespace cutlass {
|
| 48 |
+
namespace gemm {
|
| 49 |
+
namespace kernel {
|
| 50 |
+
|
| 51 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 52 |
+
|
| 53 |
+
template<typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
|
| 54 |
+
typename Epilogue_, ///! Epilogue
|
| 55 |
+
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
|
| 56 |
+
typename KernelArch, ///! The Architecture this kernel is compiled for. Used since SIMT kernels lose top-level
|
| 57 |
+
/// arch.
|
| 58 |
+
bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled.
|
| 59 |
+
>
|
| 60 |
+
struct GemmFpAIntB {
|
| 61 |
+
|
| 62 |
+
using Mma = Mma_;
|
| 63 |
+
using Epilogue = Epilogue_;
|
| 64 |
+
using EpilogueOutputOp = typename Epilogue::OutputOp;
|
| 65 |
+
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
| 66 |
+
static bool const kSplitKSerial = SplitKSerial;
|
| 67 |
+
|
| 68 |
+
using ElementA = typename Mma::IteratorA::Element;
|
| 69 |
+
using LayoutA = typename Mma::IteratorA::Layout;
|
| 70 |
+
using ElementB = typename Mma::IteratorB::Element;
|
| 71 |
+
using LayoutB = typename Mma::IteratorB::Layout;
|
| 72 |
+
using ElementC = typename Epilogue::OutputTileIterator::Element;
|
| 73 |
+
using LayoutC = typename Mma::LayoutC;
|
| 74 |
+
using ElementScale = ElementC;
|
| 75 |
+
|
| 76 |
+
static ComplexTransform const kTransformA = Mma::kTransformA;
|
| 77 |
+
static ComplexTransform const kTransformB = Mma::kTransformA;
|
| 78 |
+
|
| 79 |
+
// Type definitions about the mainloop.
|
| 80 |
+
using Operator = typename Mma::Operator;
|
| 81 |
+
using OperatorClass = typename Mma::Operator::OperatorClass;
|
| 82 |
+
using ThreadblockShape = typename Mma::Shape;
|
| 83 |
+
using WarpShape = typename Mma::Operator::Shape;
|
| 84 |
+
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
|
| 85 |
+
using ArchTag = typename Mma::ArchTag;
|
| 86 |
+
|
| 87 |
+
static int const kStages = Mma::kStages;
|
| 88 |
+
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
|
| 89 |
+
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
|
| 90 |
+
static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
|
| 91 |
+
|
| 92 |
+
/// Warp count (concept: GemmShape)
|
| 93 |
+
using WarpCount = typename Mma::WarpCount;
|
| 94 |
+
static int const kThreadCount = 32 * WarpCount::kCount;
|
| 95 |
+
|
| 96 |
+
static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
|
| 97 |
+
|
| 98 |
+
/// Parameters structure
|
| 99 |
+
struct Arguments {
|
| 100 |
+
GemmUniversalMode mode = GemmUniversalMode::kGemm;
|
| 101 |
+
|
| 102 |
+
cutlass::gemm::GemmCoord problem_size;
|
| 103 |
+
typename Mma::IteratorA::TensorRef ref_A;
|
| 104 |
+
typename Mma::IteratorB::TensorRef ref_B;
|
| 105 |
+
typename Mma::IteratorScale::TensorRef ref_scale;
|
| 106 |
+
typename Epilogue::OutputTileIterator::TensorRef ref_C;
|
| 107 |
+
typename Epilogue::OutputTileIterator::TensorRef ref_D;
|
| 108 |
+
|
| 109 |
+
// Control serial split-k
|
| 110 |
+
int batch_count;
|
| 111 |
+
|
| 112 |
+
typename EpilogueOutputOp::Params output_op;
|
| 113 |
+
|
| 114 |
+
// For gather+scatter operations
|
| 115 |
+
int const* gather_A_indices;
|
| 116 |
+
int const* gather_B_indices;
|
| 117 |
+
int const* scatter_D_indices;
|
| 118 |
+
|
| 119 |
+
// Included so we can use Gemm Universal
|
| 120 |
+
int batch_stride_D = 0;
|
| 121 |
+
|
| 122 |
+
//
|
| 123 |
+
// Methods
|
| 124 |
+
//
|
| 125 |
+
|
| 126 |
+
CUTLASS_HOST_DEVICE
|
| 127 |
+
Arguments() {}
|
| 128 |
+
|
| 129 |
+
CUTLASS_HOST_DEVICE
|
| 130 |
+
Arguments(cutlass::gemm::GemmCoord const& problem_size,
|
| 131 |
+
typename Mma::IteratorA::TensorRef ref_A,
|
| 132 |
+
typename Mma::IteratorB::TensorRef ref_B,
|
| 133 |
+
typename Mma::IteratorScale::TensorRef ref_scale,
|
| 134 |
+
typename Epilogue::OutputTileIterator::TensorRef ref_C,
|
| 135 |
+
typename Epilogue::OutputTileIterator::TensorRef ref_D,
|
| 136 |
+
int serial_split_k_factor,
|
| 137 |
+
typename EpilogueOutputOp::Params output_op = typename EpilogueOutputOp::Params(),
|
| 138 |
+
int const* gather_A_indices = nullptr,
|
| 139 |
+
int const* gather_B_indices = nullptr,
|
| 140 |
+
int const* scatter_D_indices = nullptr):
|
| 141 |
+
problem_size(problem_size),
|
| 142 |
+
ref_A(ref_A),
|
| 143 |
+
ref_B(ref_B),
|
| 144 |
+
ref_scale(ref_scale),
|
| 145 |
+
ref_C(ref_C),
|
| 146 |
+
ref_D(ref_D),
|
| 147 |
+
batch_count(serial_split_k_factor),
|
| 148 |
+
output_op(output_op),
|
| 149 |
+
gather_A_indices(gather_A_indices),
|
| 150 |
+
gather_B_indices(gather_B_indices),
|
| 151 |
+
scatter_D_indices(scatter_D_indices)
|
| 152 |
+
{
|
| 153 |
+
}
|
| 154 |
+
};
|
| 155 |
+
|
| 156 |
+
/// Parameters structure
|
| 157 |
+
struct Params {
|
| 158 |
+
cutlass::gemm::GemmCoord problem_size;
|
| 159 |
+
cutlass::gemm::GemmCoord grid_tiled_shape;
|
| 160 |
+
int swizzle_log_tile;
|
| 161 |
+
typename Mma::IteratorA::Params params_A;
|
| 162 |
+
typename Mma::IteratorA::TensorRef ref_A;
|
| 163 |
+
typename Mma::IteratorB::Params params_B;
|
| 164 |
+
typename Mma::IteratorB::TensorRef ref_B;
|
| 165 |
+
typename Mma::IteratorScale::Params params_scale;
|
| 166 |
+
typename Mma::IteratorScale::TensorRef ref_scale;
|
| 167 |
+
typename Epilogue::OutputTileIterator::Params params_C;
|
| 168 |
+
typename Epilogue::OutputTileIterator::TensorRef ref_C;
|
| 169 |
+
typename Epilogue::OutputTileIterator::Params params_D;
|
| 170 |
+
typename Epilogue::OutputTileIterator::TensorRef ref_D;
|
| 171 |
+
typename EpilogueOutputOp::Params output_op;
|
| 172 |
+
int* semaphore;
|
| 173 |
+
int gemm_k_size;
|
| 174 |
+
// For gather+scatter operations
|
| 175 |
+
int const* gather_A_indices;
|
| 176 |
+
int const* gather_B_indices;
|
| 177 |
+
int const* scatter_D_indices;
|
| 178 |
+
|
| 179 |
+
//
|
| 180 |
+
// Methods
|
| 181 |
+
//
|
| 182 |
+
|
| 183 |
+
CUTLASS_HOST_DEVICE
|
| 184 |
+
Params(): swizzle_log_tile(0), semaphore(0), gemm_k_size(0) {}
|
| 185 |
+
|
| 186 |
+
CUTLASS_HOST_DEVICE
|
| 187 |
+
Params(Arguments const& args,
|
| 188 |
+
cutlass::gemm::GemmCoord const& grid_tiled_shape,
|
| 189 |
+
const int gemm_k_size,
|
| 190 |
+
void* workspace = nullptr):
|
| 191 |
+
problem_size(args.problem_size),
|
| 192 |
+
grid_tiled_shape(grid_tiled_shape),
|
| 193 |
+
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
|
| 194 |
+
params_A(args.ref_A.layout()),
|
| 195 |
+
ref_A(args.ref_A),
|
| 196 |
+
params_B(args.ref_B.layout()),
|
| 197 |
+
ref_B(args.ref_B),
|
| 198 |
+
params_scale(args.ref_scale.layout()),
|
| 199 |
+
ref_scale(args.ref_scale),
|
| 200 |
+
params_C(args.ref_C.layout()),
|
| 201 |
+
ref_C(args.ref_C),
|
| 202 |
+
params_D(args.ref_D.layout()),
|
| 203 |
+
ref_D(args.ref_D),
|
| 204 |
+
output_op(args.output_op),
|
| 205 |
+
semaphore(static_cast<int*>(workspace)),
|
| 206 |
+
gemm_k_size(gemm_k_size),
|
| 207 |
+
gather_A_indices(args.gather_A_indices),
|
| 208 |
+
gather_B_indices(args.gather_B_indices),
|
| 209 |
+
scatter_D_indices(args.scatter_D_indices)
|
| 210 |
+
{
|
| 211 |
+
}
|
| 212 |
+
};
|
| 213 |
+
|
| 214 |
+
/// Shared memory storage structure
|
| 215 |
+
union SharedStorage {
|
| 216 |
+
typename Mma::SharedStorage main_loop;
|
| 217 |
+
typename Epilogue::SharedStorage epilogue;
|
| 218 |
+
};
|
| 219 |
+
|
| 220 |
+
//
|
| 221 |
+
// Methods
|
| 222 |
+
//
|
| 223 |
+
|
| 224 |
+
CUTLASS_HOST_DEVICE
|
| 225 |
+
GemmFpAIntB() {}
|
| 226 |
+
|
| 227 |
+
/// Determines whether kernel satisfies alignment
|
| 228 |
+
CUTLASS_HOST_DEVICE
|
| 229 |
+
static Status can_implement(Arguments const& args)
|
| 230 |
+
{
|
| 231 |
+
|
| 232 |
+
static int const kAlignmentA =
|
| 233 |
+
(platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<32>>::value) ?
|
| 234 |
+
32 :
|
| 235 |
+
(platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<64>>::value) ?
|
| 236 |
+
64 :
|
| 237 |
+
Mma::IteratorA::AccessType::kElements;
|
| 238 |
+
static int const kAlignmentB =
|
| 239 |
+
(platform::is_same<typename Mma::IteratorB::Layout, layout::RowMajorInterleaved<32>>::value) ?
|
| 240 |
+
32 :
|
| 241 |
+
(platform::is_same<typename Mma::IteratorB::Layout, layout::RowMajorInterleaved<64>>::value) ?
|
| 242 |
+
64 :
|
| 243 |
+
Mma::IteratorB::AccessType::kElements;
|
| 244 |
+
|
| 245 |
+
static int const kAlignmentScale = Mma::IteratorScale::AccessType::kElements;
|
| 246 |
+
|
| 247 |
+
static int const kAlignmentC = (platform::is_same<typename Epilogue::OutputTileIterator::Layout,
|
| 248 |
+
layout::ColumnMajorInterleaved<32>>::value) ?
|
| 249 |
+
32 :
|
| 250 |
+
(platform::is_same<typename Epilogue::OutputTileIterator::Layout,
|
| 251 |
+
layout::ColumnMajorInterleaved<64>>::value) ?
|
| 252 |
+
64 :
|
| 253 |
+
Epilogue::OutputTileIterator::kElementsPerAccess;
|
| 254 |
+
|
| 255 |
+
if (!TensorRef_aligned(args.ref_A, kAlignmentA)) {
|
| 256 |
+
return Status::kErrorMisalignedOperand;
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
if (!TensorRef_aligned(args.ref_B, kAlignmentB)) {
|
| 260 |
+
return Status::kErrorMisalignedOperand;
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
if (!TensorRef_aligned(args.ref_scale, kAlignmentScale)) {
|
| 264 |
+
return Status::kErrorMisalignedOperand;
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
if (!TensorRef_aligned(args.ref_C, kAlignmentC)) {
|
| 268 |
+
return Status::kErrorMisalignedOperand;
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
if (!TensorRef_aligned(args.ref_D, kAlignmentC)) {
|
| 272 |
+
return Status::kErrorMisalignedOperand;
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
return Status::kSuccess;
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape)
|
| 279 |
+
{
|
| 280 |
+
|
| 281 |
+
return 0;
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
// The dummy template parameter is not used and exists so that we can compile this code using
|
| 285 |
+
// a standard earlier than C++17. Prior to C++17, fully specialized templates HAD to exist in
|
| 286 |
+
// a namespace
|
| 287 |
+
template<bool B, typename dummy = void>
|
| 288 |
+
struct KernelRunner {
|
| 289 |
+
CUTLASS_DEVICE
|
| 290 |
+
static void run_kernel(Params const& params, SharedStorage& shared_storage)
|
| 291 |
+
{
|
| 292 |
+
CUTLASS_NOT_IMPLEMENTED();
|
| 293 |
+
}
|
| 294 |
+
};
|
| 295 |
+
|
| 296 |
+
template<typename dummy>
|
| 297 |
+
struct KernelRunner<true, dummy> {
|
| 298 |
+
CUTLASS_DEVICE
|
| 299 |
+
static void run_kernel(Params const& params, SharedStorage& shared_storage)
|
| 300 |
+
{
|
| 301 |
+
using LayoutB = typename Mma::IteratorB::Layout;
|
| 302 |
+
static_assert(platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1
|
| 303 |
+
|| platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1,
|
| 304 |
+
"B must be row major/col major OR col major interleaved.");
|
| 305 |
+
|
| 306 |
+
// Compute threadblock location
|
| 307 |
+
ThreadblockSwizzle threadblock_swizzle;
|
| 308 |
+
|
| 309 |
+
cutlass::gemm::GemmCoord threadblock_tile_offset =
|
| 310 |
+
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
| 311 |
+
|
| 312 |
+
// Early exit if CTA is out of range
|
| 313 |
+
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m()
|
| 314 |
+
|| params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
|
| 315 |
+
|
| 316 |
+
return;
|
| 317 |
+
}
|
| 318 |
+
|
| 319 |
+
// Compute initial location in logical coordinates
|
| 320 |
+
cutlass::MatrixCoord tb_offset_A{
|
| 321 |
+
threadblock_tile_offset.m() * Mma::Shape::kM,
|
| 322 |
+
threadblock_tile_offset.k() * params.gemm_k_size,
|
| 323 |
+
};
|
| 324 |
+
|
| 325 |
+
cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave,
|
| 326 |
+
threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave};
|
| 327 |
+
|
| 328 |
+
cutlass::MatrixCoord tb_offset_scale{0, threadblock_tile_offset.n() * Mma::Shape::kN};
|
| 329 |
+
|
| 330 |
+
// Problem size is a function of threadblock index in the K dimension
|
| 331 |
+
int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size);
|
| 332 |
+
|
| 333 |
+
// Compute threadblock-scoped matrix multiply-add
|
| 334 |
+
int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;
|
| 335 |
+
|
| 336 |
+
// Compute position within threadblock
|
| 337 |
+
int thread_idx = threadIdx.x;
|
| 338 |
+
|
| 339 |
+
// Construct iterators to A and B operands
|
| 340 |
+
typename Mma::IteratorA iterator_A(params.params_A,
|
| 341 |
+
params.ref_A.data(),
|
| 342 |
+
{params.problem_size.m(), problem_size_k},
|
| 343 |
+
thread_idx,
|
| 344 |
+
tb_offset_A,
|
| 345 |
+
params.gather_A_indices);
|
| 346 |
+
|
| 347 |
+
typename Mma::IteratorB iterator_B(params.params_B,
|
| 348 |
+
params.ref_B.data(),
|
| 349 |
+
{problem_size_k * kInterleave, params.problem_size.n() / kInterleave},
|
| 350 |
+
thread_idx,
|
| 351 |
+
tb_offset_B,
|
| 352 |
+
params.gather_B_indices);
|
| 353 |
+
|
| 354 |
+
typename Mma::IteratorScale iterator_scale(params.params_scale,
|
| 355 |
+
params.ref_scale.data(),
|
| 356 |
+
{1, params.problem_size.n()},
|
| 357 |
+
thread_idx,
|
| 358 |
+
tb_offset_scale);
|
| 359 |
+
|
| 360 |
+
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
| 361 |
+
// is compiled as warp-uniform.
|
| 362 |
+
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
| 363 |
+
int lane_idx = threadIdx.x % 32;
|
| 364 |
+
|
| 365 |
+
//
|
| 366 |
+
// Main loop
|
| 367 |
+
//
|
| 368 |
+
// Construct thread-scoped matrix multiply
|
| 369 |
+
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
|
| 370 |
+
|
| 371 |
+
typename Mma::FragmentC accumulators;
|
| 372 |
+
|
| 373 |
+
accumulators.clear();
|
| 374 |
+
|
| 375 |
+
if (!kSplitKSerial || gemm_k_iterations > 0) {
|
| 376 |
+
// Compute threadblock-scoped matrix multiply-add
|
| 377 |
+
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators);
|
| 378 |
+
}
|
| 379 |
+
|
| 380 |
+
//
|
| 381 |
+
// Epilogue
|
| 382 |
+
//
|
| 383 |
+
|
| 384 |
+
EpilogueOutputOp output_op(params.output_op);
|
| 385 |
+
|
| 386 |
+
//
|
| 387 |
+
// Masked tile iterators constructed from members
|
| 388 |
+
//
|
| 389 |
+
|
| 390 |
+
threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
| 391 |
+
|
| 392 |
+
// assume identity swizzle
|
| 393 |
+
MatrixCoord threadblock_offset(threadblock_tile_offset.m() * Mma::Shape::kM,
|
| 394 |
+
threadblock_tile_offset.n() * Mma::Shape::kN);
|
| 395 |
+
|
| 396 |
+
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
|
| 397 |
+
|
| 398 |
+
// Construct the semaphore.
|
| 399 |
+
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
|
| 400 |
+
|
| 401 |
+
// If performing a reduction via split-K, fetch the initial synchronization
|
| 402 |
+
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
|
| 403 |
+
|
| 404 |
+
// Fetch the synchronization lock initially but do not block.
|
| 405 |
+
semaphore.fetch();
|
| 406 |
+
|
| 407 |
+
// Indicate which position in a serial reduction the output operator is currently updating
|
| 408 |
+
output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
| 409 |
+
}
|
| 410 |
+
|
| 411 |
+
// Tile iterator loading from source tensor.
|
| 412 |
+
typename Epilogue::OutputTileIterator iterator_C(params.params_C,
|
| 413 |
+
params.ref_C.data(),
|
| 414 |
+
params.problem_size.mn(),
|
| 415 |
+
thread_idx,
|
| 416 |
+
threadblock_offset,
|
| 417 |
+
params.scatter_D_indices);
|
| 418 |
+
|
| 419 |
+
// Tile iterator writing to destination tensor.
|
| 420 |
+
typename Epilogue::OutputTileIterator iterator_D(params.params_D,
|
| 421 |
+
params.ref_D.data(),
|
| 422 |
+
params.problem_size.mn(),
|
| 423 |
+
thread_idx,
|
| 424 |
+
threadblock_offset,
|
| 425 |
+
params.scatter_D_indices);
|
| 426 |
+
|
| 427 |
+
Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);
|
| 428 |
+
|
| 429 |
+
// Wait on the semaphore - this latency may have been covered by iterator construction
|
| 430 |
+
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
|
| 431 |
+
|
| 432 |
+
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
|
| 433 |
+
if (threadblock_tile_offset.k()) {
|
| 434 |
+
iterator_C = iterator_D;
|
| 435 |
+
}
|
| 436 |
+
|
| 437 |
+
semaphore.wait(threadblock_tile_offset.k());
|
| 438 |
+
}
|
| 439 |
+
|
| 440 |
+
// Execute the epilogue operator to update the destination tensor.
|
| 441 |
+
epilogue(output_op, iterator_D, accumulators, iterator_C);
|
| 442 |
+
|
| 443 |
+
//
|
| 444 |
+
// Release the semaphore
|
| 445 |
+
//
|
| 446 |
+
|
| 447 |
+
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
|
| 448 |
+
|
| 449 |
+
int lock = 0;
|
| 450 |
+
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
|
| 451 |
+
|
| 452 |
+
// The final threadblock resets the semaphore for subsequent grids.
|
| 453 |
+
lock = 0;
|
| 454 |
+
}
|
| 455 |
+
else {
|
| 456 |
+
// Otherwise, the semaphore is incremented
|
| 457 |
+
lock = threadblock_tile_offset.k() + 1;
|
| 458 |
+
}
|
| 459 |
+
|
| 460 |
+
semaphore.release(lock);
|
| 461 |
+
}
|
| 462 |
+
}
|
| 463 |
+
};
|
| 464 |
+
|
| 465 |
+
/*
|
| 466 |
+
To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
|
| 467 |
+
to the ArchTag of the cutlass kernel operator.
|
| 468 |
+
*/
|
| 469 |
+
/// Executes one GEMM
|
| 470 |
+
CUTLASS_DEVICE
|
| 471 |
+
void operator()(Params const& params, SharedStorage& shared_storage)
|
| 472 |
+
{
|
| 473 |
+
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750)
|
| 474 |
+
static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm70>::value;
|
| 475 |
+
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
|
| 476 |
+
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
|
| 477 |
+
static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm75>::value;
|
| 478 |
+
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
|
| 479 |
+
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
|
| 480 |
+
static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm80>::value;
|
| 481 |
+
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
|
| 482 |
+
#else
|
| 483 |
+
CUTLASS_NOT_IMPLEMENTED();
|
| 484 |
+
#endif
|
| 485 |
+
}
|
| 486 |
+
};
|
| 487 |
+
|
| 488 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 489 |
+
|
| 490 |
+
} // namespace kernel
|
| 491 |
+
} // namespace gemm
|
| 492 |
+
} // namespace cutlass
|
cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm_with_broadcast.h
ADDED
|
@@ -0,0 +1,447 @@
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
|
| 3 |
+
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice,
|
| 9 |
+
*this list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
| 22 |
+
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
| 23 |
+
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
| 24 |
+
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
| 25 |
+
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
| 26 |
+
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
| 27 |
+
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
| 28 |
+
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
| 29 |
+
*POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
*
|
| 31 |
+
**************************************************************************************************/
|
| 32 |
+
|
| 33 |
+
/*! \file
|
| 34 |
+
\brief Template for a pipelined GEMM kernel. Does not compute batching or
|
| 35 |
+
support split-K.
|
| 36 |
+
*/
|
| 37 |
+
|
| 38 |
+
#pragma once
|
| 39 |
+
|
| 40 |
+
#include "cutlass/cutlass.h"
|
| 41 |
+
|
| 42 |
+
#include "cutlass/arch/arch.h"
|
| 43 |
+
#include "cutlass/gemm/gemm.h"
|
| 44 |
+
#include "cutlass/matrix_coord.h"
|
| 45 |
+
#include "cutlass/semaphore.h"
|
| 46 |
+
|
| 47 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 48 |
+
|
| 49 |
+
namespace cutlass {
|
| 50 |
+
namespace gemm {
|
| 51 |
+
namespace kernel {
|
| 52 |
+
|
| 53 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 54 |
+
|
| 55 |
+
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
|
| 56 |
+
typename Epilogue_, ///! Epilogue
|
| 57 |
+
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
|
| 58 |
+
typename KernelArch ///! The Architecture this kernel is compiled for.
|
| 59 |
+
/// Used since SIMT kernels lose top-level arch.
|
| 60 |
+
//////
|
| 61 |
+
>
|
| 62 |
+
struct GemmFpAIntBWithBroadcast {
|
| 63 |
+
|
| 64 |
+
using Mma = Mma_;
|
| 65 |
+
using Epilogue = Epilogue_;
|
| 66 |
+
using EpilogueOutputOp = typename Epilogue::OutputOp;
|
| 67 |
+
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
| 68 |
+
|
| 69 |
+
using ElementA = typename Mma::IteratorA::Element;
|
| 70 |
+
using LayoutA = typename Mma::IteratorA::Layout;
|
| 71 |
+
using ElementB = typename Mma::IteratorB::Element;
|
| 72 |
+
using LayoutB = typename Mma::IteratorB::Element;
|
| 73 |
+
using ElementC = typename Epilogue::OutputTileIterator::Element;
|
| 74 |
+
using LayoutC = typename Mma::LayoutC;
|
| 75 |
+
using ElementScale = ElementC;
|
| 76 |
+
|
| 77 |
+
static ComplexTransform const kTransformA = Mma::kTransformA;
|
| 78 |
+
static ComplexTransform const kTransformB = Mma::kTransformA;
|
| 79 |
+
|
| 80 |
+
// Type definitions about the mainloop.
|
| 81 |
+
using Operator = typename Mma::Operator;
|
| 82 |
+
using OperatorClass = typename Mma::Operator::OperatorClass;
|
| 83 |
+
using ThreadblockShape = typename Mma::Shape;
|
| 84 |
+
using WarpShape = typename Mma::Operator::Shape;
|
| 85 |
+
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
|
| 86 |
+
using ArchTag = typename Mma::ArchTag;
|
| 87 |
+
|
| 88 |
+
static int const kStages = Mma::kStages;
|
| 89 |
+
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
|
| 90 |
+
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
|
| 91 |
+
static int const kAlignmentC =
|
| 92 |
+
Epilogue::OutputTileIterator::kElementsPerAccess;
|
| 93 |
+
|
| 94 |
+
/// Warp count (concept: GemmShape)
|
| 95 |
+
using WarpCount = typename Mma::WarpCount;
|
| 96 |
+
static int const kThreadCount = 32 * WarpCount::kCount;
|
| 97 |
+
|
| 98 |
+
static constexpr int kInterleave =
|
| 99 |
+
Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
|
| 100 |
+
|
| 101 |
+
/// Parameters structure
|
| 102 |
+
struct Arguments {
|
| 103 |
+
GemmUniversalMode mode = GemmUniversalMode::kGemm;
|
| 104 |
+
|
| 105 |
+
cutlass::gemm::GemmCoord problem_size;
|
| 106 |
+
int batch_count;
|
| 107 |
+
typename EpilogueOutputOp::Params epilogue;
|
| 108 |
+
|
| 109 |
+
void const *ptr_A;
|
| 110 |
+
void const *ptr_B;
|
| 111 |
+
void const *ptr_scales;
|
| 112 |
+
void const *ptr_C;
|
| 113 |
+
void *ptr_D;
|
| 114 |
+
|
| 115 |
+
void const *ptr_Vector;
|
| 116 |
+
void const *ptr_Tensor;
|
| 117 |
+
|
| 118 |
+
int64_t batch_stride_A;
|
| 119 |
+
int64_t batch_stride_B;
|
| 120 |
+
int64_t batch_stride_C;
|
| 121 |
+
int64_t batch_stride_D;
|
| 122 |
+
int64_t batch_stride_Vector;
|
| 123 |
+
int64_t batch_stride_Tensor;
|
| 124 |
+
|
| 125 |
+
int lda, ldb, ldc, ldd, ldr, ldt;
|
| 126 |
+
|
| 127 |
+
typename EpilogueOutputOp::Params output_op;
|
| 128 |
+
|
| 129 |
+
// For gather+scatter operations
|
| 130 |
+
int const *gather_A_indices;
|
| 131 |
+
int const *gather_B_indices;
|
| 132 |
+
int const *scatter_D_indices;
|
| 133 |
+
|
| 134 |
+
CUTLASS_HOST_DEVICE
|
| 135 |
+
Arguments() {}
|
| 136 |
+
|
| 137 |
+
CUTLASS_HOST_DEVICE
|
| 138 |
+
Arguments(cutlass::gemm::GemmCoord const &problem_size, int batch_count,
|
| 139 |
+
typename EpilogueOutputOp::Params epilogue, void const *ptr_A,
|
| 140 |
+
void const *ptr_B, void const *ptr_scales, void const *ptr_C,
|
| 141 |
+
void *ptr_D, const void *ptr_Vector, const void *ptr_Tensor,
|
| 142 |
+
int64_t batch_stride_A, int64_t batch_stride_B,
|
| 143 |
+
int64_t batch_stride_C, int64_t batch_stride_D,
|
| 144 |
+
int64_t batch_stride_Vector, int64_t batch_stride_Tensor,
|
| 145 |
+
int lda, int ldb, int ldc, int ldd, int ldr, int ldt,
|
| 146 |
+
typename EpilogueOutputOp::Params output_op =
|
| 147 |
+
typename EpilogueOutputOp::Params())
|
| 148 |
+
: problem_size(problem_size), batch_count(batch_count),
|
| 149 |
+
epilogue(epilogue), ptr_A(ptr_A), ptr_B(ptr_B),
|
| 150 |
+
ptr_scales(ptr_scales), ptr_C(ptr_C), ptr_D(ptr_D),
|
| 151 |
+
ptr_Vector(ptr_Vector), ptr_Tensor(ptr_Tensor),
|
| 152 |
+
batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B),
|
| 153 |
+
batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D),
|
| 154 |
+
batch_stride_Vector(batch_stride_Vector),
|
| 155 |
+
batch_stride_Tensor(batch_stride_Tensor), lda(lda), ldb(ldb),
|
| 156 |
+
ldc(ldc), ldd(ldd), ldr(ldr), ldt(ldt), output_op(output_op),
|
| 157 |
+
gather_A_indices(nullptr), gather_B_indices(nullptr),
|
| 158 |
+
scatter_D_indices(nullptr) {}
|
| 159 |
+
};
|
| 160 |
+
|
| 161 |
+
/// Parameters structure
|
| 162 |
+
struct Params {
|
| 163 |
+
cutlass::gemm::GemmCoord problem_size;
|
| 164 |
+
cutlass::gemm::GemmCoord grid_tiled_shape;
|
| 165 |
+
int swizzle_log_tile;
|
| 166 |
+
|
| 167 |
+
typename Mma::IteratorA::Params params_A;
|
| 168 |
+
typename Mma::IteratorB::Params params_B;
|
| 169 |
+
typename Mma::IteratorScale::Params params_scale;
|
| 170 |
+
typename Epilogue::OutputTileIterator::Params params_C;
|
| 171 |
+
typename Epilogue::OutputTileIterator::Params params_D;
|
| 172 |
+
typename Epilogue::TensorTileIterator::Params params_Tensor;
|
| 173 |
+
|
| 174 |
+
typename EpilogueOutputOp::Params output_op;
|
| 175 |
+
|
| 176 |
+
// GemmUniversalMode mode; todo
|
| 177 |
+
int batch_count;
|
| 178 |
+
int gemm_k_size;
|
| 179 |
+
void *ptr_A;
|
| 180 |
+
void *ptr_B;
|
| 181 |
+
void *ptr_C;
|
| 182 |
+
void *ptr_scales;
|
| 183 |
+
void *ptr_D;
|
| 184 |
+
|
| 185 |
+
void *ptr_Vector;
|
| 186 |
+
typename LayoutC::Stride::Index ldr;
|
| 187 |
+
|
| 188 |
+
void *ptr_Tensor;
|
| 189 |
+
|
| 190 |
+
int64_t batch_stride_A;
|
| 191 |
+
int64_t batch_stride_B;
|
| 192 |
+
int64_t batch_stride_C;
|
| 193 |
+
int64_t batch_stride_D;
|
| 194 |
+
int64_t batch_stride_Vector;
|
| 195 |
+
int64_t batch_stride_Tensor;
|
| 196 |
+
|
| 197 |
+
// For gather+scatter operations
|
| 198 |
+
int const *gather_A_indices;
|
| 199 |
+
int const *gather_B_indices;
|
| 200 |
+
int const *scatter_D_indices;
|
| 201 |
+
|
| 202 |
+
//
|
| 203 |
+
// Methods
|
| 204 |
+
//
|
| 205 |
+
|
| 206 |
+
CUTLASS_HOST_DEVICE
|
| 207 |
+
Params() : swizzle_log_tile(0), gemm_k_size(0) {}
|
| 208 |
+
|
| 209 |
+
CUTLASS_HOST_DEVICE
|
| 210 |
+
Params(Arguments const &args,
|
| 211 |
+
cutlass::gemm::GemmCoord const &grid_tiled_shape,
|
| 212 |
+
const int gemm_k_size, void *workspace = nullptr)
|
| 213 |
+
: problem_size(args.problem_size), grid_tiled_shape(grid_tiled_shape),
|
| 214 |
+
swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
|
| 215 |
+
params_A(args.lda), params_B(args.ldb), params_C(args.ldc),
|
| 216 |
+
params_D(args.ldd), params_Tensor(args.ldt), output_op(args.epilogue),
|
| 217 |
+
batch_count(args.batch_count), gemm_k_size(gemm_k_size),
|
| 218 |
+
ptr_A(const_cast<void *>(args.ptr_A)),
|
| 219 |
+
ptr_B(const_cast<void *>(args.ptr_B)),
|
| 220 |
+
ptr_scales(const_cast<void *>(args.ptr_scales)),
|
| 221 |
+
ptr_C(const_cast<void *>(args.ptr_C)), ptr_D(args.ptr_D),
|
| 222 |
+
ptr_Vector(const_cast<void *>(args.ptr_Vector)), ldr(args.ldr),
|
| 223 |
+
ptr_Tensor(const_cast<void *>(args.ptr_Tensor)), batch_stride_A(args.batch_stride_A),
|
| 224 |
+
batch_stride_B(args.batch_stride_B),
|
| 225 |
+
batch_stride_C(args.batch_stride_C),
|
| 226 |
+
batch_stride_D(args.batch_stride_D),
|
| 227 |
+
batch_stride_Vector(args.batch_stride_Vector),
|
| 228 |
+
batch_stride_Tensor(args.batch_stride_Tensor),
|
| 229 |
+
gather_A_indices(args.gather_A_indices),
|
| 230 |
+
gather_B_indices(args.gather_B_indices),
|
| 231 |
+
scatter_D_indices(args.scatter_D_indices) {}
|
| 232 |
+
};
|
| 233 |
+
|
| 234 |
+
/// Shared memory storage structure
|
| 235 |
+
union SharedStorage {
|
| 236 |
+
typename Mma::SharedStorage main_loop;
|
| 237 |
+
typename Epilogue::SharedStorage epilogue;
|
| 238 |
+
};
|
| 239 |
+
|
| 240 |
+
//
|
| 241 |
+
// Methods
|
| 242 |
+
//
|
| 243 |
+
|
| 244 |
+
CUTLASS_HOST_DEVICE
|
| 245 |
+
GemmFpAIntBWithBroadcast() {}
|
| 246 |
+
|
| 247 |
+
CUTLASS_HOST_DEVICE
|
| 248 |
+
static Status can_implement(Arguments const &args) {
|
| 249 |
+
// todo
|
| 250 |
+
return Status::kSuccess;
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
static size_t
|
| 254 |
+
get_extra_workspace_size(Arguments const &args,
|
| 255 |
+
cutlass::gemm::GemmCoord const &grid_tiled_shape) {
|
| 256 |
+
|
| 257 |
+
return 0;
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
// The dummy template parameter is not used and exists so that we can compile
|
| 261 |
+
// this code using a standard earlier than C++17. Prior to C++17, fully
|
| 262 |
+
// specialized templates HAD to exist in a namespace
|
| 263 |
+
template <bool B, typename dummy = void> struct KernelRunner {
|
| 264 |
+
CUTLASS_DEVICE
|
| 265 |
+
static void run_kernel(Params const ¶ms,
|
| 266 |
+
SharedStorage &shared_storage) {
|
| 267 |
+
CUTLASS_NOT_IMPLEMENTED();
|
| 268 |
+
}
|
| 269 |
+
};
|
| 270 |
+
|
| 271 |
+
template <typename dummy> struct KernelRunner<true, dummy> {
|
| 272 |
+
CUTLASS_DEVICE
|
| 273 |
+
static void run_kernel(Params const ¶ms,
|
| 274 |
+
SharedStorage &shared_storage) {
|
| 275 |
+
using LayoutB = typename Mma::IteratorB::Layout;
|
| 276 |
+
static_assert(
|
| 277 |
+
platform::is_same<LayoutB, layout::RowMajor>::value &&
|
| 278 |
+
kInterleave == 1 ||
|
| 279 |
+
platform::is_same<LayoutB, layout::ColumnMajor>::value &&
|
| 280 |
+
kInterleave >= 1,
|
| 281 |
+
"B must be row major/col major OR col major interleaved.");
|
| 282 |
+
|
| 283 |
+
// Compute threadblock location
|
| 284 |
+
ThreadblockSwizzle threadblock_swizzle;
|
| 285 |
+
|
| 286 |
+
cutlass::gemm::GemmCoord threadblock_tile_offset =
|
| 287 |
+
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
| 288 |
+
|
| 289 |
+
// Early exit if CTA is out of range
|
| 290 |
+
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
|
| 291 |
+
params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
|
| 292 |
+
|
| 293 |
+
return;
|
| 294 |
+
}
|
| 295 |
+
|
| 296 |
+
// Compute initial location in logical coordinates
|
| 297 |
+
cutlass::MatrixCoord tb_offset_A{
|
| 298 |
+
threadblock_tile_offset.m() * Mma::Shape::kM,
|
| 299 |
+
threadblock_tile_offset.k() * params.gemm_k_size,
|
| 300 |
+
};
|
| 301 |
+
|
| 302 |
+
cutlass::MatrixCoord tb_offset_B{
|
| 303 |
+
threadblock_tile_offset.k() * params.gemm_k_size * kInterleave,
|
| 304 |
+
threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave};
|
| 305 |
+
|
| 306 |
+
cutlass::MatrixCoord tb_offset_scale{0, threadblock_tile_offset.n() *
|
| 307 |
+
Mma::Shape::kN};
|
| 308 |
+
|
| 309 |
+
// Problem size is a function of threadblock index in the K dimension
|
| 310 |
+
int problem_size_k =
|
| 311 |
+
min(params.problem_size.k(),
|
| 312 |
+
(threadblock_tile_offset.k() + 1) * params.gemm_k_size);
|
| 313 |
+
|
| 314 |
+
// Compute threadblock-scoped matrix multiply-add
|
| 315 |
+
int gemm_k_iterations =
|
| 316 |
+
(problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) /
|
| 317 |
+
Mma::Shape::kK;
|
| 318 |
+
|
| 319 |
+
// Compute position within threadblock
|
| 320 |
+
int thread_idx = threadIdx.x;
|
| 321 |
+
|
| 322 |
+
// Construct iterators to A and B operands
|
| 323 |
+
typename Mma::IteratorA iterator_A(
|
| 324 |
+
params.params_A, static_cast<ElementA *>(params.ptr_A),
|
| 325 |
+
{params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A,
|
| 326 |
+
params.gather_A_indices);
|
| 327 |
+
|
| 328 |
+
typename Mma::IteratorB iterator_B(
|
| 329 |
+
params.params_B, static_cast<ElementB *>(params.ptr_B),
|
| 330 |
+
{problem_size_k * kInterleave, params.problem_size.n() / kInterleave},
|
| 331 |
+
thread_idx, tb_offset_B, params.gather_B_indices);
|
| 332 |
+
|
| 333 |
+
typename Mma::IteratorScale iterator_scale(
|
| 334 |
+
params.params_scale, static_cast<ElementScale *>(params.ptr_scales),
|
| 335 |
+
{1, params.problem_size.n()}, thread_idx, tb_offset_scale);
|
| 336 |
+
|
| 337 |
+
// Broadcast the warp_id computed by lane 0 to ensure dependent code is
|
| 338 |
+
// compiled as warp-uniform.
|
| 339 |
+
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
| 340 |
+
int lane_idx = threadIdx.x % 32;
|
| 341 |
+
|
| 342 |
+
//
|
| 343 |
+
// Main loop
|
| 344 |
+
//
|
| 345 |
+
// Construct thread-scoped matrix multiply
|
| 346 |
+
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
|
| 347 |
+
|
| 348 |
+
typename Mma::FragmentC accumulators;
|
| 349 |
+
|
| 350 |
+
accumulators.clear();
|
| 351 |
+
|
| 352 |
+
if (gemm_k_iterations > 0) {
|
| 353 |
+
// Compute threadblock-scoped matrix multiply-add
|
| 354 |
+
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B,
|
| 355 |
+
iterator_scale, accumulators);
|
| 356 |
+
}
|
| 357 |
+
|
| 358 |
+
//
|
| 359 |
+
// Epilogue
|
| 360 |
+
//
|
| 361 |
+
|
| 362 |
+
EpilogueOutputOp output_op(params.output_op);
|
| 363 |
+
|
| 364 |
+
//
|
| 365 |
+
// Masked tile iterators constructed from members
|
| 366 |
+
//
|
| 367 |
+
|
| 368 |
+
threadblock_tile_offset =
|
| 369 |
+
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
| 370 |
+
|
| 371 |
+
// assume identity swizzle
|
| 372 |
+
MatrixCoord threadblock_offset(
|
| 373 |
+
threadblock_tile_offset.m() * Mma::Shape::kM,
|
| 374 |
+
threadblock_tile_offset.n() * Mma::Shape::kN);
|
| 375 |
+
|
| 376 |
+
int block_idx = threadblock_tile_offset.m() +
|
| 377 |
+
threadblock_tile_offset.n() * params.grid_tiled_shape.m();
|
| 378 |
+
|
| 379 |
+
ElementC *ptr_C = static_cast<ElementC *>(params.ptr_C);
|
| 380 |
+
ElementC *ptr_D = static_cast<ElementC *>(params.ptr_D);
|
| 381 |
+
|
| 382 |
+
// Tile iterator loading from source tensor.
|
| 383 |
+
typename Epilogue::OutputTileIterator iterator_C(
|
| 384 |
+
params.params_C, ptr_C, params.problem_size.mn(),
|
| 385 |
+
thread_idx, threadblock_offset, params.scatter_D_indices);
|
| 386 |
+
|
| 387 |
+
// Tile iterator writing to destination tensor.
|
| 388 |
+
typename Epilogue::OutputTileIterator iterator_D(
|
| 389 |
+
params.params_D, ptr_D, params.problem_size.mn(),
|
| 390 |
+
thread_idx, threadblock_offset, params.scatter_D_indices);
|
| 391 |
+
|
| 392 |
+
typename Epilogue::ElementTensor *ptr_Tensor =
|
| 393 |
+
static_cast<typename Epilogue::ElementTensor *>(params.ptr_Tensor);
|
| 394 |
+
|
| 395 |
+
// Define the reduction output pointer and move to the appropriate place
|
| 396 |
+
typename Epilogue::ElementVector *ptr_Vector =
|
| 397 |
+
static_cast<typename Epilogue::ElementVector *>(params.ptr_Vector);
|
| 398 |
+
|
| 399 |
+
typename Epilogue::TensorTileIterator tensor_iterator(
|
| 400 |
+
params.params_Tensor,
|
| 401 |
+
// Only the final block outputs Tensor
|
| 402 |
+
ptr_Tensor, params.problem_size.mn(), thread_idx, threadblock_offset);
|
| 403 |
+
|
| 404 |
+
Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx,
|
| 405 |
+
lane_idx);
|
| 406 |
+
|
| 407 |
+
if (ptr_Vector) {
|
| 408 |
+
ptr_Vector += threadblock_offset.column() +
|
| 409 |
+
threadblock_tile_offset.m() * params.ldr;
|
| 410 |
+
}
|
| 411 |
+
|
| 412 |
+
epilogue(output_op, ptr_Vector, iterator_D, accumulators, iterator_C,
|
| 413 |
+
tensor_iterator, params.problem_size.mn(), threadblock_offset);
|
| 414 |
+
}
|
| 415 |
+
};
|
| 416 |
+
|
| 417 |
+
/*
|
| 418 |
+
To improve compilation speed, we do not compile the device operator if the
|
| 419 |
+
CUDA_ARCH does not correspond to the ArchTag of the cutlass kernel
|
| 420 |
+
operator.
|
| 421 |
+
*/
|
| 422 |
+
/// Executes one GEMM
|
| 423 |
+
CUTLASS_DEVICE
|
| 424 |
+
void operator()(Params const ¶ms, SharedStorage &shared_storage) {
|
| 425 |
+
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750)
|
| 426 |
+
static constexpr bool compile_needed =
|
| 427 |
+
platform::is_same<KernelArch, arch::Sm70>::value;
|
| 428 |
+
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
|
| 429 |
+
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
|
| 430 |
+
static constexpr bool compile_needed =
|
| 431 |
+
platform::is_same<KernelArch, arch::Sm75>::value;
|
| 432 |
+
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
|
| 433 |
+
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
|
| 434 |
+
static constexpr bool compile_needed =
|
| 435 |
+
platform::is_same<KernelArch, arch::Sm80>::value;
|
| 436 |
+
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
|
| 437 |
+
#else
|
| 438 |
+
CUTLASS_NOT_IMPLEMENTED();
|
| 439 |
+
#endif
|
| 440 |
+
}
|
| 441 |
+
};
|
| 442 |
+
|
| 443 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 444 |
+
|
| 445 |
+
} // namespace kernel
|
| 446 |
+
} // namespace gemm
|
| 447 |
+
} // namespace cutlass
|
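The kernel above gates compilation on __CUDA_ARCH__ and only routes into KernelRunner<true> when the architecture being compiled matches the kernel's ArchTag; the unused dummy parameter keeps the selected case a partial specialization, which is what makes it legal inside the class before C++17. Below is a minimal standalone sketch of that dispatch pattern; the names and the constant flag are illustrative and not taken from the imported sources.

#include <cstdio>

// Primary template: the compiled-out path for architectures this kernel does not target.
template <bool Enabled, typename Dummy = void>
struct KernelRunnerSketch {
    static void run() { std::printf("kernel body not compiled for this arch\n"); }
};

// Partial specialization selected when the compile-time flag is true.
template <typename Dummy>
struct KernelRunnerSketch<true, Dummy> {
    static void run() { std::printf("running arch-specific kernel body\n"); }
};

int main() {
    // In the real kernel this flag comes from comparing __CUDA_ARCH__ ranges against
    // the KernelArch tag (Sm70/Sm75/Sm80); a plain constant stands in for it here.
    constexpr bool compile_needed = true;
    KernelRunnerSketch<compile_needed>::run();
    return 0;
}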
cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h
ADDED
|
@@ -0,0 +1,89 @@
|
| 1 |
+
/*
|
| 2 |
+
This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is
|
| 3 |
+
quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices
|
| 4 |
+
to be consumed by CUTLASS.
|
| 5 |
+
|
| 6 |
+
Note that for int4, ThreadBlockK MUST be 64.
|
| 7 |
+
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#pragma once
|
| 11 |
+
|
| 12 |
+
#include "cutlass/layout/matrix.h"
|
| 13 |
+
#include "cutlass/numeric_types.h"
|
| 14 |
+
|
| 15 |
+
#include "cutlass/arch/arch.h"
|
| 16 |
+
#include "cutlass/arch/mma.h"
|
| 17 |
+
#include "cutlass/platform/platform.h"
|
| 18 |
+
|
| 19 |
+
#include "cutlass_extensions/arch/mma.h"
|
| 20 |
+
#include "cutlass_extensions/tile_interleaved_layout.h"
|
| 21 |
+
|
| 22 |
+
namespace cutlass {
|
| 23 |
+
namespace gemm {
|
| 24 |
+
namespace kernel {
|
| 25 |
+
|
| 26 |
+
template<typename TypeB, typename Arch, typename Enable = void>
|
| 27 |
+
struct LayoutDetailsB {
|
| 28 |
+
};
|
| 29 |
+
|
| 30 |
+
// Volta specializations. Volta will dequantize before STS, so we need a different operator
|
| 31 |
+
template<typename TypeB>
|
| 32 |
+
struct LayoutDetailsB<TypeB, arch::Sm70> {
|
| 33 |
+
static constexpr int ThreadblockK = 64;
|
| 34 |
+
using Layout = layout::RowMajor;
|
| 35 |
+
static constexpr int ElementsPerAccess = 8;
|
| 36 |
+
using Operator = cutlass::arch::OpMultiplyAdd;
|
| 37 |
+
};
|
| 38 |
+
|
| 39 |
+
// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks.
|
| 40 |
+
// TODO - Switch this to column major for weights since gemms should be more performant.
|
| 41 |
+
template<typename Arch>
|
| 42 |
+
struct LayoutDetailsB<half_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
|
| 43 |
+
static constexpr int ThreadblockK = 64;
|
| 44 |
+
using Layout = layout::RowMajor;
|
| 45 |
+
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
|
| 46 |
+
using Operator = cutlass::arch::OpMultiplyAdd;
|
| 47 |
+
};
|
| 48 |
+
|
| 49 |
+
template<typename Arch>
|
| 50 |
+
struct LayoutDetailsB<bfloat16_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
|
| 51 |
+
static constexpr int ThreadblockK = 64;
|
| 52 |
+
using Layout = layout::RowMajor;
|
| 53 |
+
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<bfloat16_t>::value;
|
| 54 |
+
using Operator = cutlass::arch::OpMultiplyAdd;
|
| 55 |
+
};
|
| 56 |
+
|
| 57 |
+
// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA,
|
| 58 |
+
// which signals that we want to dequantize after loading from smem.
|
| 59 |
+
template<typename Arch>
|
| 60 |
+
struct LayoutDetailsB<uint8_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
|
| 61 |
+
static constexpr int ThreadblockK = 64;
|
| 62 |
+
|
| 63 |
+
private:
|
| 64 |
+
static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint8_t>::value;
|
| 65 |
+
static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;
|
| 66 |
+
|
| 67 |
+
public:
|
| 68 |
+
using Layout = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
|
| 69 |
+
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint8_t>::value;
|
| 70 |
+
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
|
| 71 |
+
};
|
| 72 |
+
|
| 73 |
+
template<typename Arch>
|
| 74 |
+
struct LayoutDetailsB<uint4b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
|
| 75 |
+
static constexpr int ThreadblockK = 64;
|
| 76 |
+
|
| 77 |
+
private:
|
| 78 |
+
static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint4b_t>::value;
|
| 79 |
+
static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;
|
| 80 |
+
|
| 81 |
+
public:
|
| 82 |
+
using Layout = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
|
| 83 |
+
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint4b_t>::value;
|
| 84 |
+
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
|
| 85 |
+
};
|
| 86 |
+
|
| 87 |
+
} // namespace kernel
|
| 88 |
+
} // namespace gemm
|
| 89 |
+
} // namespace cutlass
|
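LayoutDetailsB above derives the interleave factor from the cache-line size: with ThreadblockK fixed at 64, a 128-byte line holds 128 int8 or 256 int4 weights, so 2 or 4 columns are interleaved per tile, respectively. The following small standalone check of that arithmetic is not part of the header; the helper name is made up for illustration.

#include <cstdio>

constexpr int kThreadblockK = 64;

// Mirrors ElementsPerCacheLine / ColumnsInterleaved from LayoutDetailsB.
constexpr int columns_interleaved(int bits_per_element) {
    return (128 * 8 / bits_per_element) / kThreadblockK;
}

static_assert(columns_interleaved(8) == 2, "int8 weights interleave 2 columns");
static_assert(columns_interleaved(4) == 4, "int4 weights interleave 4 columns");

int main() {
    std::printf("int8: %d columns, int4: %d columns\n",
                columns_interleaved(8), columns_interleaved(4));
    return 0;
}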
cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h
ADDED
|
@@ -0,0 +1,106 @@
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include "cutlass_extensions/arch/mma.h"
|
| 4 |
+
#include "cutlass_extensions/interleaved_numeric_conversion.h"
|
| 5 |
+
|
| 6 |
+
namespace cutlass {
|
| 7 |
+
namespace gemm {
|
| 8 |
+
namespace threadblock {
|
| 9 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 10 |
+
|
| 11 |
+
// We need to distinguish here, since we want volta support. It is too much effort
|
| 12 |
+
// to write shared memory iterators that are probably needed for volta to function
|
| 13 |
+
// properly. As a result, we allow converters both after the LDG (for volta) and after
|
| 14 |
+
// the LDS for Turing+.
|
| 15 |
+
template<
|
| 16 |
+
/// Iterator for B matrix in global memory
|
| 17 |
+
typename IteratorB,
|
| 18 |
+
/// Warp level Mma
|
| 19 |
+
typename MmaOperator,
|
| 20 |
+
/// Math operation performed by the warp-level operator
|
| 21 |
+
typename MathOperator>
|
| 22 |
+
struct SetConverters {
|
| 23 |
+
};
|
| 24 |
+
|
| 25 |
+
// Dequantize after LDG, so set transforms accordingly
|
| 26 |
+
template<
|
| 27 |
+
/// Iterator for B matrix in global memory
|
| 28 |
+
typename IteratorB,
|
| 29 |
+
/// Mma Policy
|
| 30 |
+
typename MmaOperator>
|
| 31 |
+
struct SetConverters<IteratorB, MmaOperator, arch::OpMultiplyAdd> {
|
| 32 |
+
using TransformAfterLDG =
|
| 33 |
+
FastInterleavedAndBiasedNumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
|
| 34 |
+
typename IteratorB::Element,
|
| 35 |
+
IteratorB::Fragment::kElements>;
|
| 36 |
+
|
| 37 |
+
using TransformAfterLDS = NumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
|
| 38 |
+
typename MmaOperator::ArchMmaOperator::ElementB,
|
| 39 |
+
MmaOperator::FragmentB::kElements>;
|
| 40 |
+
};
|
| 41 |
+
|
| 42 |
+
// Dequantize after LDS, so set transforms accordingly
|
| 43 |
+
|
| 44 |
+
template<
|
| 45 |
+
/// Iterator for B matrix in global memory
|
| 46 |
+
typename IteratorB,
|
| 47 |
+
/// Mma Policy
|
| 48 |
+
typename MmaOperator>
|
| 49 |
+
struct SetConverters<IteratorB, MmaOperator, arch::OpMultiplyAddDequantizeInterleavedBToA> {
|
| 50 |
+
using TransformAfterLDG =
|
| 51 |
+
NumericArrayConverter<typename IteratorB::Element, typename IteratorB::Element, IteratorB::Fragment::kElements>;
|
| 52 |
+
|
| 53 |
+
using TransformAfterLDS =
|
| 54 |
+
FastInterleavedAndBiasedNumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
|
| 55 |
+
typename TransformAfterLDG::result_type::Element,
|
| 56 |
+
MmaOperator::FragmentB::kElements>;
|
| 57 |
+
};
|
| 58 |
+
|
| 59 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 60 |
+
|
| 61 |
+
template<
|
| 62 |
+
/// Element type for A matrix operand
|
| 63 |
+
typename ElementA_,
|
| 64 |
+
/// Layout type for A matrix operand
|
| 65 |
+
typename LayoutA_,
|
| 66 |
+
/// Access granularity of A matrix in units of elements
|
| 67 |
+
int kAlignmentA,
|
| 68 |
+
/// Element type for B matrix operand
|
| 69 |
+
typename ElementB_,
|
| 70 |
+
/// Layout type for B matrix operand
|
| 71 |
+
typename LayoutB_,
|
| 72 |
+
/// Access granularity of B matrix in units of elements
|
| 73 |
+
int kAlignmentB,
|
| 74 |
+
/// Element type for the input scale
|
| 75 |
+
typename ElementScale_,
|
| 76 |
+
/// Layout for the scale operand
|
| 77 |
+
typename LayoutScale_,
|
| 78 |
+
/// Access granularity of Scales in units of elements
|
| 79 |
+
int kAlignmentScale,
|
| 80 |
+
/// Element type for internal accumulation
|
| 81 |
+
typename ElementAccumulator_,
|
| 82 |
+
/// Layout type for C and D matrix operands
|
| 83 |
+
typename LayoutC_,
|
| 84 |
+
/// Operator class tag
|
| 85 |
+
typename OperatorClass_,
|
| 86 |
+
/// Tag indicating architecture to tune for
|
| 87 |
+
typename ArchTag_,
|
| 88 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
| 89 |
+
typename ThreadblockShape_,
|
| 90 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 91 |
+
typename WarpShape_,
|
| 92 |
+
/// Instruction-level tile size (concept: GemmShape)
|
| 93 |
+
typename InstructionShape_,
|
| 94 |
+
/// Number of stages used in the pipelined mainloop
|
| 95 |
+
int Stages,
|
| 96 |
+
/// Operation performed by GEMM
|
| 97 |
+
typename Operator_,
|
| 98 |
+
/// Use zfill or predicate for out-of-bound cp.async
|
| 99 |
+
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
|
| 100 |
+
///
|
| 101 |
+
typename Enable = void>
|
| 102 |
+
struct DqMma;
|
| 103 |
+
|
| 104 |
+
} // namespace threadblock
|
| 105 |
+
} // namespace gemm
|
| 106 |
+
} // namespace cutlass
|
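SetConverters above picks where dequantization happens from the math-operator tag: OpMultiplyAdd converts right after the global-memory load (the Volta path), while OpMultiplyAddDequantizeInterleavedBToA keeps the raw integers in shared memory and converts after the shared-memory load (Turing+). A reduced sketch of that tag dispatch follows, with placeholder converter types standing in for the real NumericArrayConverter classes.

#include <cstdio>

struct OpMultiplyAdd {};                             // dequantize after LDG
struct OpMultiplyAddDequantizeInterleavedBToA {};    // dequantize after LDS

struct ConvertOnLoad  { static const char* name() { return "convert after LDG"; } };
struct Identity       { static const char* name() { return "identity"; } };
struct ConvertOnShare { static const char* name() { return "convert after LDS"; } };

template <typename MathOperator>
struct SetConvertersSketch;

// Volta-style path: transform as data arrives from global memory.
template <>
struct SetConvertersSketch<OpMultiplyAdd> {
    using TransformAfterLDG = ConvertOnLoad;
    using TransformAfterLDS = Identity;
};

// Turing+ path: store quantized data to shared memory, transform after loading it back.
template <>
struct SetConvertersSketch<OpMultiplyAddDequantizeInterleavedBToA> {
    using TransformAfterLDG = Identity;
    using TransformAfterLDS = ConvertOnShare;
};

int main() {
    using VoltaPath  = SetConvertersSketch<OpMultiplyAdd>;
    using TuringPath = SetConvertersSketch<OpMultiplyAddDequantizeInterleavedBToA>;
    std::printf("Volta:   LDG=%s, LDS=%s\n", VoltaPath::TransformAfterLDG::name(),
                VoltaPath::TransformAfterLDS::name());
    std::printf("Turing+: LDG=%s, LDS=%s\n", TuringPath::TransformAfterLDG::name(),
                TuringPath::TransformAfterLDS::name());
    return 0;
}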
cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h
ADDED
|
@@ -0,0 +1,346 @@
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include "cutlass/gemm/threadblock/default_mma.h"
|
| 4 |
+
#include "cutlass_extensions/arch/mma.h"
|
| 5 |
+
|
| 6 |
+
#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage.h"
|
| 7 |
+
#include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h"
|
| 8 |
+
#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
|
| 9 |
+
#include "cutlass_extensions/tile_interleaved_layout.h"
|
| 10 |
+
|
| 11 |
+
#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
|
| 12 |
+
|
| 13 |
+
namespace cutlass {
|
| 14 |
+
namespace gemm {
|
| 15 |
+
namespace threadblock {
|
| 16 |
+
|
| 17 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 18 |
+
|
| 19 |
+
template<
|
| 20 |
+
/// Type for elementA
|
| 21 |
+
typename ElementA,
|
| 22 |
+
/// Layout type for A matrix operand
|
| 23 |
+
typename LayoutA,
|
| 24 |
+
/// Access granularity of A matrix in units of elements
|
| 25 |
+
int kAlignmentA,
|
| 26 |
+
/// Type for element B
|
| 27 |
+
typename ElementB,
|
| 28 |
+
/// Layout type for B matrix operand
|
| 29 |
+
typename LayoutB,
|
| 30 |
+
/// Access granularity of B matrix in units of elements
|
| 31 |
+
int kAlignmentB,
|
| 32 |
+
/// Element type for the input scale
|
| 33 |
+
typename ElementScale,
|
| 34 |
+
/// Layout for the scale operand
|
| 35 |
+
typename LayoutScale,
|
| 36 |
+
/// Access granularity of Scales in units of elements
|
| 37 |
+
int kAlignmentScale,
|
| 38 |
+
/// Element type for internal accumulation
|
| 39 |
+
typename ElementAccumulator,
|
| 40 |
+
/// Operator class tag
|
| 41 |
+
typename OperatorClass,
|
| 42 |
+
/// Tag indicating architecture to tune for
|
| 43 |
+
typename ArchTag,
|
| 44 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
| 45 |
+
typename ThreadblockShape,
|
| 46 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 47 |
+
typename WarpShape,
|
| 48 |
+
/// Instruction-level tile size (concept: GemmShape)
|
| 49 |
+
typename InstructionShape,
|
| 50 |
+
/// Stages in GEMM
|
| 51 |
+
int kStages,
|
| 52 |
+
///
|
| 53 |
+
typename Operator,
|
| 54 |
+
///
|
| 55 |
+
SharedMemoryClearOption SharedMemoryClear>
|
| 56 |
+
struct DqMma<ElementA,
|
| 57 |
+
LayoutA,
|
| 58 |
+
kAlignmentA,
|
| 59 |
+
ElementB,
|
| 60 |
+
LayoutB,
|
| 61 |
+
kAlignmentB,
|
| 62 |
+
ElementScale,
|
| 63 |
+
LayoutScale,
|
| 64 |
+
kAlignmentScale,
|
| 65 |
+
ElementAccumulator,
|
| 66 |
+
layout::RowMajor,
|
| 67 |
+
OperatorClass,
|
| 68 |
+
ArchTag,
|
| 69 |
+
ThreadblockShape,
|
| 70 |
+
WarpShape,
|
| 71 |
+
InstructionShape,
|
| 72 |
+
kStages,
|
| 73 |
+
Operator,
|
| 74 |
+
SharedMemoryClear,
|
| 75 |
+
typename platform::enable_if<(ArchTag::kMinComputeCapability >= 80)>::type> {
|
| 76 |
+
|
| 77 |
+
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
|
| 78 |
+
"Element A must be fp16 or bf16");
|
| 79 |
+
|
| 80 |
+
static_assert(platform::is_same<Operator, arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
|
| 81 |
+
"Mma multistage must dequantize after ldsm");
|
| 82 |
+
|
| 83 |
+
static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
|
| 84 |
+
"Element B must be uint8 or uint4");
|
| 85 |
+
|
| 86 |
+
static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA) == 128) ?
|
| 87 |
+
cutlass::arch::CacheOperation::Global :
|
| 88 |
+
cutlass::arch::CacheOperation::Always;
|
| 89 |
+
|
| 90 |
+
static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits<ElementB>::value * kAlignmentB) == 128) ?
|
| 91 |
+
cutlass::arch::CacheOperation::Global :
|
| 92 |
+
cutlass::arch::CacheOperation::Always;
|
| 93 |
+
|
| 94 |
+
// Define the MmaCore components
|
| 95 |
+
// Mma core does not depend on stages, so pass in at least 3 here so that the mma multistage pieces are created
|
| 96 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
|
| 97 |
+
WarpShape,
|
| 98 |
+
InstructionShape,
|
| 99 |
+
ElementA,
|
| 100 |
+
LayoutA,
|
| 101 |
+
ElementB,
|
| 102 |
+
LayoutB,
|
| 103 |
+
ElementAccumulator,
|
| 104 |
+
layout::RowMajor,
|
| 105 |
+
OperatorClass,
|
| 106 |
+
std::max(kStages, 3),
|
| 107 |
+
Operator,
|
| 108 |
+
false,
|
| 109 |
+
CacheOpA,
|
| 110 |
+
CacheOpB>;
|
| 111 |
+
|
| 112 |
+
// Define iterators over tiles from the A operand
|
| 113 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
| 114 |
+
using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
|
| 115 |
+
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
| 116 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 117 |
+
ElementA,
|
| 118 |
+
LayoutA,
|
| 119 |
+
1,
|
| 120 |
+
ThreadMapA,
|
| 121 |
+
AccessTypeA>;
|
| 122 |
+
|
| 123 |
+
// Define iterators over tiles from the B operand
|
| 124 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
| 125 |
+
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
|
| 126 |
+
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
| 127 |
+
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
|
| 128 |
+
ElementB,
|
| 129 |
+
LayoutB,
|
| 130 |
+
0,
|
| 131 |
+
ThreadMapB,
|
| 132 |
+
AccessTypeB>;
|
| 133 |
+
|
| 134 |
+
// ThreadMap for scale iterator
|
| 135 |
+
static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
|
| 136 |
+
using IteratorScaleThreadMap =
|
| 137 |
+
transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
|
| 138 |
+
MmaCore::Shape::kN / kAlignmentScale,
|
| 139 |
+
kAlignmentScale>;
|
| 140 |
+
|
| 141 |
+
// Define iterators over tiles from the scale operand
|
| 142 |
+
using IteratorScale =
|
| 143 |
+
cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
|
| 144 |
+
ElementScale,
|
| 145 |
+
LayoutScale,
|
| 146 |
+
0,
|
| 147 |
+
IteratorScaleThreadMap,
|
| 148 |
+
kAlignmentScale>;
|
| 149 |
+
|
| 150 |
+
using SmemIteratorScale = IteratorScale;
|
| 151 |
+
|
| 152 |
+
using Converter = FastInterleavedAndBiasedNumericArrayConverter<ElementA,
|
| 153 |
+
ElementB,
|
| 154 |
+
MmaCore::MmaPolicy::Operator::FragmentB::kElements>;
|
| 155 |
+
|
| 156 |
+
// Define the threadblock-scoped pipelined matrix multiply
|
| 157 |
+
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage<typename MmaCore::Shape,
|
| 158 |
+
IteratorA,
|
| 159 |
+
typename MmaCore::SmemIteratorA,
|
| 160 |
+
MmaCore::kCacheOpA,
|
| 161 |
+
IteratorB,
|
| 162 |
+
typename MmaCore::SmemIteratorB,
|
| 163 |
+
MmaCore::kCacheOpB,
|
| 164 |
+
IteratorScale,
|
| 165 |
+
SmemIteratorScale,
|
| 166 |
+
ElementAccumulator,
|
| 167 |
+
layout::RowMajor,
|
| 168 |
+
typename MmaCore::MmaPolicy,
|
| 169 |
+
kStages,
|
| 170 |
+
Converter,
|
| 171 |
+
SharedMemoryClear>;
|
| 172 |
+
};
|
| 173 |
+
|
| 174 |
+
template<
|
| 175 |
+
/// Type for element A
|
| 176 |
+
typename ElementA,
|
| 177 |
+
/// Layout type for A matrix operand
|
| 178 |
+
typename LayoutA,
|
| 179 |
+
/// Access granularity of A matrix in units of elements
|
| 180 |
+
int kAlignmentA,
|
| 181 |
+
/// Type for element B
|
| 182 |
+
typename ElementB,
|
| 183 |
+
/// Access granularity of B matrix in units of elements
|
| 184 |
+
int kAlignmentB,
|
| 185 |
+
/// Element type for the input scale
|
| 186 |
+
typename ElementScale,
|
| 187 |
+
/// Layout for the scale operand
|
| 188 |
+
typename LayoutScale,
|
| 189 |
+
/// Access granularity of Scales in units of elements
|
| 190 |
+
int kAlignmentScale,
|
| 191 |
+
/// Element type for internal accumulation
|
| 192 |
+
typename ElementAccumulator,
|
| 193 |
+
/// Operator class tag
|
| 194 |
+
typename OperatorClass,
|
| 195 |
+
/// Tag indicating architecture to tune for
|
| 196 |
+
typename ArchTag,
|
| 197 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
| 198 |
+
typename ThreadblockShape,
|
| 199 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 200 |
+
typename WarpShape,
|
| 201 |
+
/// Instruction-level tile size (concept: GemmShape)
|
| 202 |
+
typename InstructionShape,
|
| 203 |
+
/// Stages in GEMM
|
| 204 |
+
int kStages,
|
| 205 |
+
///
|
| 206 |
+
typename Operator,
|
| 207 |
+
///
|
| 208 |
+
SharedMemoryClearOption SharedMemoryClear,
|
| 209 |
+
///
|
| 210 |
+
int RowsPerTile,
|
| 211 |
+
///
|
| 212 |
+
int ColumnsInterleaved>
|
| 213 |
+
struct DqMma<ElementA,
|
| 214 |
+
LayoutA,
|
| 215 |
+
kAlignmentA,
|
| 216 |
+
ElementB,
|
| 217 |
+
layout::ColumnMajorTileInterleave<RowsPerTile, ColumnsInterleaved>,
|
| 218 |
+
kAlignmentB,
|
| 219 |
+
ElementScale,
|
| 220 |
+
LayoutScale,
|
| 221 |
+
kAlignmentScale,
|
| 222 |
+
ElementAccumulator,
|
| 223 |
+
layout::RowMajor,
|
| 224 |
+
OperatorClass,
|
| 225 |
+
ArchTag,
|
| 226 |
+
ThreadblockShape,
|
| 227 |
+
WarpShape,
|
| 228 |
+
InstructionShape,
|
| 229 |
+
kStages,
|
| 230 |
+
Operator,
|
| 231 |
+
SharedMemoryClear,
|
| 232 |
+
typename platform::enable_if<(ArchTag::kMinComputeCapability >= 80)>::type> {
|
| 233 |
+
|
| 234 |
+
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
|
| 235 |
+
"Element A must be fp16 or bf16");
|
| 236 |
+
|
| 237 |
+
static_assert(platform::is_same<Operator, arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
|
| 238 |
+
"Mma multistage must dequantize after ldsm");
|
| 239 |
+
|
| 240 |
+
static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
|
| 241 |
+
"Element B must be uint8 or uint4");
|
| 242 |
+
|
| 243 |
+
static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA) == 128) ?
|
| 244 |
+
cutlass::arch::CacheOperation::Global :
|
| 245 |
+
cutlass::arch::CacheOperation::Always;
|
| 246 |
+
|
| 247 |
+
static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits<ElementB>::value * kAlignmentB) == 128) ?
|
| 248 |
+
cutlass::arch::CacheOperation::Global :
|
| 249 |
+
cutlass::arch::CacheOperation::Always;
|
| 250 |
+
|
| 251 |
+
// Define the MmaCore components
|
| 252 |
+
// Mma core does not depend on stages, so pass in at least 3 here so that the mma multistage pieces are created
|
| 253 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
|
| 254 |
+
WarpShape,
|
| 255 |
+
InstructionShape,
|
| 256 |
+
ElementA,
|
| 257 |
+
LayoutA,
|
| 258 |
+
ElementB,
|
| 259 |
+
layout::ColumnMajor,
|
| 260 |
+
ElementAccumulator,
|
| 261 |
+
layout::RowMajor,
|
| 262 |
+
OperatorClass,
|
| 263 |
+
std::max(kStages, 3),
|
| 264 |
+
Operator,
|
| 265 |
+
false,
|
| 266 |
+
CacheOpA,
|
| 267 |
+
CacheOpB>;
|
| 268 |
+
|
| 269 |
+
// Define iterators over tiles from the A operand
|
| 270 |
+
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
| 271 |
+
using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
|
| 272 |
+
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
| 273 |
+
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
|
| 274 |
+
ElementA,
|
| 275 |
+
LayoutA,
|
| 276 |
+
1,
|
| 277 |
+
ThreadMapA,
|
| 278 |
+
AccessTypeA>;
|
| 279 |
+
|
| 280 |
+
private:
|
| 281 |
+
static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), "");
|
| 282 |
+
static_assert(RowsPerTile == MmaCore::Shape::kK, "");
|
| 283 |
+
|
| 284 |
+
using OriginalThreadMap = typename MmaCore::IteratorThreadMapB;
|
| 285 |
+
using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement;
|
| 286 |
+
static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), "");
|
| 287 |
+
|
| 288 |
+
using GmemIteratorShape =
|
| 289 |
+
MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved, MmaCore::Shape::kN / ColumnsInterleaved>;
|
| 290 |
+
using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap<
|
| 291 |
+
layout::PitchLinearShape<GmemIteratorShape::kRow, GmemIteratorShape::kColumn>,
|
| 292 |
+
OriginalThreadMap::kThreads,
|
| 293 |
+
layout::PitchLinearShape<OriginalWarpArrangement::kContiguous * ColumnsInterleaved,
|
| 294 |
+
OriginalWarpArrangement::kStrided / ColumnsInterleaved>,
|
| 295 |
+
MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
|
| 296 |
+
|
| 297 |
+
public:
|
| 298 |
+
// Define iterators over tiles from the B operand
|
| 299 |
+
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
| 300 |
+
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
|
| 301 |
+
using IteratorB = cutlass::transform::threadblock::
|
| 302 |
+
PredicatedTileAccessIterator<GmemIteratorShape, ElementB, layout::ColumnMajor, 0, GmemThreadMapB, AccessTypeB>;
|
| 303 |
+
|
| 304 |
+
// ThreadMap for scale iterator
|
| 305 |
+
static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
|
| 306 |
+
using IteratorScaleThreadMap =
|
| 307 |
+
transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
|
| 308 |
+
MmaCore::Shape::kN / kAlignmentScale,
|
| 309 |
+
kAlignmentScale>;
|
| 310 |
+
|
| 311 |
+
// Define iterators over tiles from the scale operand
|
| 312 |
+
using IteratorScale =
|
| 313 |
+
cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
|
| 314 |
+
ElementScale,
|
| 315 |
+
LayoutScale,
|
| 316 |
+
0,
|
| 317 |
+
IteratorScaleThreadMap,
|
| 318 |
+
kAlignmentScale>;
|
| 319 |
+
|
| 320 |
+
using SmemIteratorScale = IteratorScale;
|
| 321 |
+
|
| 322 |
+
using Converter = FastInterleavedAndBiasedNumericArrayConverter<ElementA,
|
| 323 |
+
ElementB,
|
| 324 |
+
MmaCore::MmaPolicy::Operator::FragmentB::kElements>;
|
| 325 |
+
|
| 326 |
+
// Define the threadblock-scoped pipelined matrix multiply
|
| 327 |
+
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage<typename MmaCore::Shape,
|
| 328 |
+
IteratorA,
|
| 329 |
+
typename MmaCore::SmemIteratorA,
|
| 330 |
+
MmaCore::kCacheOpA,
|
| 331 |
+
IteratorB,
|
| 332 |
+
typename MmaCore::SmemIteratorB,
|
| 333 |
+
MmaCore::kCacheOpB,
|
| 334 |
+
IteratorScale,
|
| 335 |
+
SmemIteratorScale,
|
| 336 |
+
ElementAccumulator,
|
| 337 |
+
layout::RowMajor,
|
| 338 |
+
typename MmaCore::MmaPolicy,
|
| 339 |
+
kStages,
|
| 340 |
+
Converter,
|
| 341 |
+
SharedMemoryClear>;
|
| 342 |
+
};
|
| 343 |
+
|
| 344 |
+
} // namespace threadblock
|
| 345 |
+
} // namespace gemm
|
| 346 |
+
} // namespace cutlass
|
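Both multistage specializations above choose the cache policy from the access width: a full 128-bit (16-byte) access can use CacheOperation::Global, anything narrower falls back to CacheOperation::Always. A standalone restatement of that predicate is sketched below; the enum and helper are illustrative, not the CUTLASS types.

#include <cstdio>

enum class CacheOp { Always, Global };

// Only accesses that cover a full 16 bytes qualify for the Global policy.
constexpr CacheOp pick_cache_op(int element_bits, int alignment_elements) {
    return (element_bits * alignment_elements == 128) ? CacheOp::Global
                                                      : CacheOp::Always;
}

int main() {
    // fp16 activations loaded 8 at a time: 16 * 8 = 128 bits -> Global
    // uint8 weights loaded 8 at a time:     8 * 8 =  64 bits -> Always
    std::printf("A: %s\n", pick_cache_op(16, 8) == CacheOp::Global ? "Global" : "Always");
    std::printf("B: %s\n", pick_cache_op(8, 8)  == CacheOp::Global ? "Global" : "Always");
    return 0;
}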
cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h
ADDED
|
@@ -0,0 +1,315 @@
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include "cutlass/gemm/threadblock/default_mma.h"
|
| 4 |
+
#include "cutlass_extensions/arch/mma.h"
|
| 5 |
+
|
| 6 |
+
#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h"
|
| 7 |
+
#include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h"
|
| 8 |
+
#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
|
| 9 |
+
#include "cutlass_extensions/tile_interleaved_layout.h"
|
| 10 |
+
|
| 11 |
+
#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
|
| 12 |
+
|
| 13 |
+
namespace cutlass {
|
| 14 |
+
namespace gemm {
|
| 15 |
+
namespace threadblock {
|
| 16 |
+
|
| 17 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 18 |
+
|
| 19 |
+
template<
|
| 20 |
+
/// Type for element A
|
| 21 |
+
typename ElementA,
|
| 22 |
+
/// Layout type for A matrix operand
|
| 23 |
+
typename LayoutA,
|
| 24 |
+
/// Access granularity of A matrix in units of elements
|
| 25 |
+
int kAlignmentA,
|
| 26 |
+
/// Type for element B
|
| 27 |
+
typename ElementB,
|
| 28 |
+
/// Layout type for B matrix operand
|
| 29 |
+
typename LayoutB,
|
| 30 |
+
/// Access granularity of B matrix in units of elements
|
| 31 |
+
int kAlignmentB,
|
| 32 |
+
/// Element type for the input scale
|
| 33 |
+
typename ElementScale,
|
| 34 |
+
/// Layout for the scale operand
|
| 35 |
+
typename LayoutScale,
|
| 36 |
+
/// Access granularity of Scales in units of elements
|
| 37 |
+
int kAlignmentScale,
|
| 38 |
+
/// Element type for internal accumulation
|
| 39 |
+
typename ElementAccumulator,
|
| 40 |
+
/// Operator class tag
|
| 41 |
+
typename OperatorClass,
|
| 42 |
+
/// Tag indicating architecture to tune for
|
| 43 |
+
typename ArchTag,
|
| 44 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
| 45 |
+
typename ThreadblockShape,
|
| 46 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 47 |
+
typename WarpShape,
|
| 48 |
+
/// Instruction-level tile size (concept: GemmShape)
|
| 49 |
+
typename InstructionShape,
|
| 50 |
+
/// Operation performed by GEMM
|
| 51 |
+
typename Operator>
|
| 52 |
+
struct DqMma<ElementA,
|
| 53 |
+
LayoutA,
|
| 54 |
+
kAlignmentA,
|
| 55 |
+
ElementB,
|
| 56 |
+
LayoutB,
|
| 57 |
+
kAlignmentB,
|
| 58 |
+
ElementScale,
|
| 59 |
+
LayoutScale,
|
| 60 |
+
kAlignmentScale,
|
| 61 |
+
ElementAccumulator,
|
| 62 |
+
layout::RowMajor,
|
| 63 |
+
OperatorClass,
|
| 64 |
+
ArchTag,
|
| 65 |
+
ThreadblockShape,
|
| 66 |
+
WarpShape,
|
| 67 |
+
InstructionShape,
|
| 68 |
+
2,
|
| 69 |
+
Operator,
|
| 70 |
+
SharedMemoryClearOption::kNone,
|
| 71 |
+
typename platform::enable_if<(ArchTag::kMinComputeCapability < 80)>::type> {
|
| 72 |
+
|
| 73 |
+
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
|
| 74 |
+
"Element A must be fp16 or bf16");
|
| 75 |
+
|
| 76 |
+
static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
|
| 77 |
+
"Element B must be uint8 or uint4");
|
| 78 |
+
|
| 79 |
+
static constexpr bool DqAfterLDG = platform::is_same<arch::OpMultiplyAdd, Operator>::value;
|
| 80 |
+
static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80;
|
| 81 |
+
using MmaCoreElementA = typename platform::conditional<arch_has_bf16_mma, ElementA, half_t>::type;
|
| 82 |
+
using MmaCoreElementB = typename platform::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;
|
| 83 |
+
|
| 84 |
+
// Define the MmaCore components
|
| 85 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
|
| 86 |
+
WarpShape,
|
| 87 |
+
InstructionShape,
|
| 88 |
+
MmaCoreElementA,
|
| 89 |
+
LayoutA,
|
| 90 |
+
MmaCoreElementB,
|
| 91 |
+
LayoutB,
|
| 92 |
+
ElementAccumulator,
|
| 93 |
+
layout::RowMajor,
|
| 94 |
+
OperatorClass,
|
| 95 |
+
2,
|
| 96 |
+
Operator>;
|
| 97 |
+
|
| 98 |
+
// Define iterators over tiles from the A operand
|
| 99 |
+
using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
|
| 100 |
+
cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,
|
| 101 |
+
ElementA,
|
| 102 |
+
LayoutA,
|
| 103 |
+
1,
|
| 104 |
+
typename MmaCore::IteratorThreadMapA,
|
| 105 |
+
kAlignmentA>;
|
| 106 |
+
|
| 107 |
+
// Define iterators over tiles from the B operand
|
| 108 |
+
using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
|
| 109 |
+
cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>,
|
| 110 |
+
ElementB,
|
| 111 |
+
LayoutB,
|
| 112 |
+
0,
|
| 113 |
+
typename MmaCore::IteratorThreadMapB,
|
| 114 |
+
kAlignmentB>;
|
| 115 |
+
|
| 116 |
+
// ThreadMap for scale iterator
|
| 117 |
+
static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
|
| 118 |
+
using IteratorScaleThreadMap =
|
| 119 |
+
transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
|
| 120 |
+
MmaCore::Shape::kN / kAlignmentScale,
|
| 121 |
+
kAlignmentScale>;
|
| 122 |
+
|
| 123 |
+
// Define iterators over tiles from the scale operand
|
| 124 |
+
using IteratorScale =
|
| 125 |
+
cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
|
| 126 |
+
ElementScale,
|
| 127 |
+
LayoutScale,
|
| 128 |
+
0,
|
| 129 |
+
IteratorScaleThreadMap,
|
| 130 |
+
kAlignmentScale>;
|
| 131 |
+
|
| 132 |
+
using SmemScaleType = typename platform::conditional<arch_has_bf16_mma, ElementScale, half_t>::type;
|
| 133 |
+
using SmemIteratorScale =
|
| 134 |
+
cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
|
| 135 |
+
SmemScaleType,
|
| 136 |
+
LayoutScale,
|
| 137 |
+
0,
|
| 138 |
+
IteratorScaleThreadMap,
|
| 139 |
+
kAlignmentScale>;
|
| 140 |
+
|
| 141 |
+
using Converters = SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;
|
| 142 |
+
|
| 143 |
+
// Define the threadblock-scoped pipelined matrix multiply
|
| 144 |
+
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<typename MmaCore::Shape,
|
| 145 |
+
IteratorA,
|
| 146 |
+
typename MmaCore::SmemIteratorA,
|
| 147 |
+
IteratorB,
|
| 148 |
+
typename MmaCore::SmemIteratorB,
|
| 149 |
+
IteratorScale,
|
| 150 |
+
SmemIteratorScale,
|
| 151 |
+
ElementAccumulator,
|
| 152 |
+
layout::RowMajor,
|
| 153 |
+
typename MmaCore::MmaPolicy,
|
| 154 |
+
typename Converters::TransformAfterLDG,
|
| 155 |
+
typename Converters::TransformAfterLDS>;
|
| 156 |
+
};
|
| 157 |
+
|
| 158 |
+
// Specialization to handle column major interleave B
|
| 159 |
+
template<
|
| 160 |
+
/// Type for element A
|
| 161 |
+
typename ElementA,
|
| 162 |
+
/// Layout type for A matrix operand
|
| 163 |
+
typename LayoutA,
|
| 164 |
+
/// Access granularity of A matrix in units of elements
|
| 165 |
+
int kAlignmentA,
|
| 166 |
+
/// Type for element B
|
| 167 |
+
typename ElementB,
|
| 168 |
+
/// Access granularity of B matrix in units of elements
|
| 169 |
+
int kAlignmentB,
|
| 170 |
+
/// Element type for the input scale
|
| 171 |
+
typename ElementScale,
|
| 172 |
+
/// Layout for the scale operand
|
| 173 |
+
typename LayoutScale,
|
| 174 |
+
/// Access granularity of Scales in units of elements
|
| 175 |
+
int kAlignmentScale,
|
| 176 |
+
/// Element type for internal accumulation
|
| 177 |
+
typename ElementAccumulator,
|
| 178 |
+
/// Operator class tag
|
| 179 |
+
typename OperatorClass,
|
| 180 |
+
/// Tag indicating architecture to tune for
|
| 181 |
+
typename ArchTag,
|
| 182 |
+
/// Threadblock-level tile size (concept: GemmShape)
|
| 183 |
+
typename ThreadblockShape,
|
| 184 |
+
/// Warp-level tile size (concept: GemmShape)
|
| 185 |
+
typename WarpShape,
|
| 186 |
+
/// Instruction-level tile size (concept: GemmShape)
|
| 187 |
+
typename InstructionShape,
|
| 188 |
+
/// Operation performed by GEMM
|
| 189 |
+
typename Operator,
|
| 190 |
+
///
|
| 191 |
+
int RowsPerTile,
|
| 192 |
+
///
|
| 193 |
+
int ColumnsInterleaved>
|
| 194 |
+
struct DqMma<ElementA,
|
| 195 |
+
LayoutA,
|
| 196 |
+
kAlignmentA,
|
| 197 |
+
ElementB,
|
| 198 |
+
layout::ColumnMajorTileInterleave<RowsPerTile, ColumnsInterleaved>,
|
| 199 |
+
kAlignmentB,
|
| 200 |
+
ElementScale,
|
| 201 |
+
LayoutScale,
|
| 202 |
+
kAlignmentScale,
|
| 203 |
+
ElementAccumulator,
|
| 204 |
+
layout::RowMajor,
|
| 205 |
+
OperatorClass,
|
| 206 |
+
ArchTag,
|
| 207 |
+
ThreadblockShape,
|
| 208 |
+
WarpShape,
|
| 209 |
+
InstructionShape,
|
| 210 |
+
2,
|
| 211 |
+
Operator,
|
| 212 |
+
SharedMemoryClearOption::kNone,
|
| 213 |
+
typename platform::enable_if<(ArchTag::kMinComputeCapability < 80)>::type> {
|
| 214 |
+
|
| 215 |
+
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
|
| 216 |
+
"Element A must be fp16 or bf16");
|
| 217 |
+
|
| 218 |
+
static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
|
| 219 |
+
"Element B must be uint8 or uint4");
|
| 220 |
+
|
| 221 |
+
static constexpr bool DqAfterLDG = platform::is_same<arch::OpMultiplyAdd, Operator>::value;
|
| 222 |
+
static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80;
|
| 223 |
+
using MmaCoreElementA = typename platform::conditional<arch_has_bf16_mma, ElementA, half_t>::type;
|
| 224 |
+
using MmaCoreElementB = typename platform::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;
|
| 225 |
+
|
| 226 |
+
// Define the MmaCore components
|
| 227 |
+
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
|
| 228 |
+
WarpShape,
|
| 229 |
+
InstructionShape,
|
| 230 |
+
MmaCoreElementA,
|
| 231 |
+
LayoutA,
|
| 232 |
+
MmaCoreElementB,
|
| 233 |
+
layout::ColumnMajor,
|
| 234 |
+
ElementAccumulator,
|
| 235 |
+
layout::RowMajor,
|
| 236 |
+
OperatorClass,
|
| 237 |
+
2,
|
| 238 |
+
Operator>;
|
| 239 |
+
|
| 240 |
+
// Define iterators over tiles from the A operand
|
| 241 |
+
using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
|
| 242 |
+
cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,
|
| 243 |
+
ElementA,
|
| 244 |
+
LayoutA,
|
| 245 |
+
1,
|
| 246 |
+
typename MmaCore::IteratorThreadMapA,
|
| 247 |
+
kAlignmentA>;
|
| 248 |
+
|
| 249 |
+
private:
|
| 250 |
+
static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), "");
|
| 251 |
+
static_assert(RowsPerTile == MmaCore::Shape::kK, "");
|
| 252 |
+
|
| 253 |
+
using OriginalThreadMap = typename MmaCore::IteratorThreadMapB;
|
| 254 |
+
using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement;
|
| 255 |
+
static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), "");
|
| 256 |
+
|
| 257 |
+
using GmemIteratorShape =
|
| 258 |
+
MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved, MmaCore::Shape::kN / ColumnsInterleaved>;
|
| 259 |
+
using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap<
|
| 260 |
+
layout::PitchLinearShape<GmemIteratorShape::kRow, GmemIteratorShape::kColumn>,
|
| 261 |
+
OriginalThreadMap::kThreads,
|
| 262 |
+
layout::PitchLinearShape<OriginalWarpArrangement::kContiguous * ColumnsInterleaved,
|
| 263 |
+
OriginalWarpArrangement::kStrided / ColumnsInterleaved>,
|
| 264 |
+
MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
|
| 265 |
+
|
| 266 |
+
public:
|
| 267 |
+
// Define iterators over tiles from the B operand
|
| 268 |
+
using IteratorB = cutlass::transform::threadblock::
|
| 269 |
+
PredicatedTileIterator<GmemIteratorShape, ElementB, layout::ColumnMajor, 0, GmemThreadMapB, kAlignmentB>;
|
| 270 |
+
|
| 271 |
+
// ThreadMap for scale iterator
|
| 272 |
+
static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
|
| 273 |
+
using IteratorScaleThreadMap =
|
| 274 |
+
transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
|
| 275 |
+
MmaCore::Shape::kN / kAlignmentScale,
|
| 276 |
+
kAlignmentScale>;
|
| 277 |
+
|
| 278 |
+
// Define iterators over tiles from the scale operand
|
| 279 |
+
using IteratorScale =
|
| 280 |
+
cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
|
| 281 |
+
ElementScale,
|
| 282 |
+
LayoutScale,
|
| 283 |
+
0,
|
| 284 |
+
IteratorScaleThreadMap,
|
| 285 |
+
kAlignmentScale>;
|
| 286 |
+
|
| 287 |
+
using SmemScaleType = typename platform::conditional<arch_has_bf16_mma, ElementScale, half_t>::type;
|
| 288 |
+
using SmemIteratorScale =
|
| 289 |
+
cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaCore::Shape::kN>,
|
| 290 |
+
SmemScaleType,
|
| 291 |
+
LayoutScale,
|
| 292 |
+
0,
|
| 293 |
+
IteratorScaleThreadMap,
|
| 294 |
+
kAlignmentScale>;
|
| 295 |
+
|
| 296 |
+
using Converters = SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;
|
| 297 |
+
|
| 298 |
+
// Define the threadblock-scoped pipelined matrix multiply
|
| 299 |
+
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<typename MmaCore::Shape,
|
| 300 |
+
IteratorA,
|
| 301 |
+
typename MmaCore::SmemIteratorA,
|
| 302 |
+
IteratorB,
|
| 303 |
+
typename MmaCore::SmemIteratorB,
|
| 304 |
+
IteratorScale,
|
| 305 |
+
SmemIteratorScale,
|
| 306 |
+
ElementAccumulator,
|
| 307 |
+
layout::RowMajor,
|
| 308 |
+
typename MmaCore::MmaPolicy,
|
| 309 |
+
typename Converters::TransformAfterLDG,
|
| 310 |
+
typename Converters::TransformAfterLDS>;
|
| 311 |
+
};
|
| 312 |
+
|
| 313 |
+
} // namespace threadblock
|
| 314 |
+
} // namespace gemm
|
| 315 |
+
} // namespace cutlass
|
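The interleaved-B specializations above (both the pipelined and the multistage variants) re-shape the K x N weight tile into a (K * ColumnsInterleaved) x (N / ColumnsInterleaved) block, so each global-memory column spans a whole cache line while the element count stays the same. The short sketch below illustrates that reshaping with example tile sizes; the numbers are illustrative and not taken from the shipped configurations.

#include <cstdio>

struct TileShape { int rows, cols; };

// Same arithmetic as GmemIteratorShape in the DqMma interleave specializations.
constexpr TileShape interleaved_gmem_shape(int k, int n, int columns_interleaved) {
    return {k * columns_interleaved, n / columns_interleaved};
}

int main() {
    // Example: threadblock K=64, N=128 with int8 weights (2 columns interleaved).
    // 64x128 and 128x64 hold the same 8192 elements, just addressed differently.
    constexpr TileShape s = interleaved_gmem_shape(64, 128, 2);
    std::printf("gmem tile: %d x %d (= %d elements)\n", s.rows, s.cols, s.rows * s.cols);
    return 0;
}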
cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h
ADDED
|
@@ -0,0 +1,426 @@
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
|
| 4 |
+
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
|
| 5 |
+
#include "cutlass_extensions/gemm/threadblock/default_mma_bf16.h"
|
| 6 |
+
|
| 7 |
+
namespace cutlass {
|
| 8 |
+
namespace gemm {
|
| 9 |
+
namespace threadblock {
|
| 10 |
+
|
| 11 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 12 |
+
|
| 13 |
+
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight
template<
    /// Layout type and access granularity (in elements) of the A and B matrix operands
    typename LayoutA, int kAlignmentA, typename LayoutB, int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-, warp- and instruction-level tile sizes (concept: GemmShape)
    typename ThreadblockShape, typename WarpShape, typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator>
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
                  layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape,
                  InstructionShape, 2, Operator> {
private:
    static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;

    using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
                      kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag,
                      ThreadblockShape, WarpShape, InstructionShape, 2, Operator>;

public:
    // Define the MmaCore components and the iterators over tiles of the A and B operands
    using MmaCore   = typename Mma::MmaCore;
    using IteratorA = typename Mma::IteratorA;
    using IteratorB = typename Mma::IteratorB;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = typename Mma::ThreadblockMma;
};

////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight
template<
    /// Layout type and access granularity (in elements) of the A and B matrix operands
    typename LayoutA, int kAlignmentA, typename LayoutB, int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-, warp- and instruction-level tile sizes (concept: GemmShape)
    typename ThreadblockShape, typename WarpShape, typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator>
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
                  layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape,
                  InstructionShape, 2, Operator> {
private:
    static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;

    using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
                      kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag,
                      ThreadblockShape, WarpShape, InstructionShape, 2, Operator>;

public:
    // Define the MmaCore components and the iterators over tiles of the A and B operands
    using MmaCore   = typename Mma::MmaCore;
    using IteratorA = typename Mma::IteratorA;
    using IteratorB = typename Mma::IteratorB;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = typename Mma::ThreadblockMma;
};

template<
    /// Layout type and access granularity (in elements) of the A and B matrix operands
    typename LayoutA, int kAlignmentA, typename LayoutB, int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-, warp- and instruction-level tile sizes (concept: GemmShape)
    typename ThreadblockShape, typename WarpShape, typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator,
    int kStages,
    /// Shared memory clear option
    SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
                  layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape,
                  InstructionShape, kStages, Operator, false, SharedMemoryClear> {
private:
    static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;

    using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
                      kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag,
                      ThreadblockShape, WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;

public:
    // Define the MmaCore components and the iterators over tiles of the A and B operands
    using MmaCore   = typename Mma::MmaCore;
    using IteratorA = typename Mma::IteratorA;
    using IteratorB = typename Mma::IteratorB;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = typename Mma::ThreadblockMma;
};

////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight
template<
    /// Layout type and access granularity (in elements) of the A and B matrix operands
    typename LayoutA, int kAlignmentA, typename LayoutB, int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-, warp- and instruction-level tile sizes (concept: GemmShape)
    typename ThreadblockShape, typename WarpShape, typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator,
    int kStages,
    /// Shared memory clear option
    SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
                  layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape,
                  InstructionShape, kStages, Operator, false, SharedMemoryClear> {
private:
    static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;

    using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
                      kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag,
                      ThreadblockShape, WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;

public:
    // Define the MmaCore components and the iterators over tiles of the A and B operands
    using MmaCore   = typename Mma::MmaCore;
    using IteratorA = typename Mma::IteratorA;
    using IteratorB = typename Mma::IteratorB;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = typename Mma::ThreadblockMma;
};

// fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on
// large tile when not enough shared mem is present to do 3+ stage
template<
    /// Layout type and access granularity (in elements) of the A and B matrix operands
    typename LayoutA, int kAlignmentA, typename LayoutB, int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Threadblock-, warp- and instruction-level tile sizes (concept: GemmShape)
    typename ThreadblockShape, typename WarpShape, typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear,
    /// Gather operands A / B by using an index array
    bool GatherA, bool GatherB>
struct DefaultMma<half_t, LayoutA, kAlignmentA, half_t, LayoutB, kAlignmentB, ElementAccumulator,
                  layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, WarpShape,
                  InstructionShape, 2, Operator, false, SharedMemoryClear, GatherA, GatherB> {

    // Define the MmaCore components
    // 3 is used on purpose here to trigger components for mma multistage
    using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape,
                                                                        InstructionShape, half_t, LayoutA,
                                                                        half_t, LayoutB, ElementAccumulator,
                                                                        layout::RowMajor, arch::OpClassTensorOp,
                                                                        3, Operator>;

    // Define iterators over tiles from the A operand
    using ThreadMapA  = typename MmaCore::IteratorThreadMapA;
    using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
    using IteratorA   = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, half_t, LayoutA, 1, ThreadMapA,
        AccessTypeA, GatherA>;

    // Define iterators over tiles from the B operand
    using ThreadMapB  = typename MmaCore::IteratorThreadMapB;
    using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
    using IteratorB   = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, half_t, LayoutB, 0, ThreadMapB,
        AccessTypeB, GatherB>;

    // Define the threadblock-scoped multistage matrix multiply
    using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage<typename MmaCore::Shape, IteratorA,
                                                                     typename MmaCore::SmemIteratorA,
                                                                     MmaCore::kCacheOpA, IteratorB,
                                                                     typename MmaCore::SmemIteratorB,
                                                                     MmaCore::kCacheOpB, ElementAccumulator,
                                                                     layout::RowMajor,
                                                                     typename MmaCore::MmaPolicy, 2>;
};

} // namespace threadblock
} // namespace gemm
} // namespace cutlass
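The specializations above only wire existing DqMma components together. As an illustration, the snippet below sketches how such a fp16-activation / int8-weight DefaultMma might be instantiated. Every concrete value here (tile shapes, alignments, the Sm80 tag, the operator tag assumed to come from the bundled cutlass_extensions/arch/mma.h, and all type alias names) is a hypothetical example choice, not taken from this commit, and whether a particular combination compiles is still subject to the DqMma and LayoutDetailsB constraints in the accompanying headers.

// Hypothetical instantiation sketch (not part of the imported sources).
#include "cutlass/arch/arch.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/gemm/threadblock/default_mma.h"

using ExampleDefaultMma = cutlass::gemm::threadblock::DefaultMma<
    cutlass::half_t, cutlass::layout::RowMajor, 8,        // A: fp16 activations, 128-bit accesses
    uint8_t, cutlass::layout::ColumnMajor, 16,            // B: int8 weights, 128-bit accesses
    float, cutlass::layout::RowMajor,                     // fp32 accumulation, row-major output
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,  // tensor cores, Ampere tuning tag
    cutlass::gemm::GemmShape<128, 128, 64>,               // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 64>,                 // warp tile
    cutlass::gemm::GemmShape<16, 8, 16>,                  // tensor-core instruction shape
    2,                                                    // two stages -> selects the DqMma path above
    cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA>;

// The nested typedefs are what a kernel template would consume:
using ExampleThreadblockMma = ExampleDefaultMma::ThreadblockMma;
using ExampleIteratorA      = ExampleDefaultMma::IteratorA;
using ExampleIteratorB      = ExampleDefaultMma::IteratorB;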
cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h
ADDED
@@ -0,0 +1,527 @@
#pragma once

#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"

namespace cutlass {
namespace gemm {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & bf16 weight
template<
    /// Layout type and access granularity (in elements) of the A and B matrix operands
    typename LayoutA, int kAlignmentA, typename LayoutB, int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-, warp- and instruction-level tile sizes (concept: GemmShape)
    typename ThreadblockShape, typename WarpShape, typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear,
    /// Gather operands A / B by using an index array
    bool GatherA, bool GatherB>
struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, bfloat16_t, LayoutB, kAlignmentB, ElementAccumulator,
                  layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape,
                  InstructionShape, 2, Operator, false, SharedMemoryClear, GatherA, GatherB> {

private:
    // Conversions only needed pre-Ampere. This will trigger mma pipeline, so we convert before STS.
    static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80;
    using MmaElementA = typename platform::conditional<arch_has_bf16_mma, bfloat16_t, half_t>::type;
    using MmaElementB = typename platform::conditional<arch_has_bf16_mma, bfloat16_t, half_t>::type;

public:
    // Define the MmaCore components
    using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape,
                                                                        InstructionShape, MmaElementA, LayoutA,
                                                                        MmaElementB, LayoutB, ElementAccumulator,
                                                                        layout::RowMajor, arch::OpClassTensorOp,
                                                                        2, Operator>;

    // Define iterators over tiles from the A operand
    using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
        cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>, bfloat16_t, LayoutA, 1,
        typename MmaCore::IteratorThreadMapA, kAlignmentA, GatherA>;

    // Define iterators over tiles from the B operand
    using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
        cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>, bfloat16_t, LayoutB, 0,
        typename MmaCore::IteratorThreadMapB, kAlignmentB, GatherB>;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined<typename MmaCore::Shape, IteratorA,
                                                                    typename MmaCore::SmemIteratorA, IteratorB,
                                                                    typename MmaCore::SmemIteratorB,
                                                                    ElementAccumulator, layout::RowMajor,
                                                                    typename MmaCore::MmaPolicy>;
};

// bf16 x bf16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on
// large tile when not enough shared mem is present to do 3+ stage
template<
    /// Layout type and access granularity (in elements) of the A and B matrix operands
    typename LayoutA, int kAlignmentA, typename LayoutB, int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Threadblock-, warp- and instruction-level tile sizes (concept: GemmShape)
    typename ThreadblockShape, typename WarpShape, typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear,
    /// Gather operands A / B by using an index array
    bool GatherA, bool GatherB>
struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, bfloat16_t, LayoutB, kAlignmentB, ElementAccumulator,
                  layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, WarpShape,
                  InstructionShape, 2, Operator, false, SharedMemoryClear, GatherA, GatherB> {

    // Define the MmaCore components
    // 3 is used on purpose here to trigger components for mma multistage
    using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape,
                                                                        InstructionShape, bfloat16_t, LayoutA,
                                                                        bfloat16_t, LayoutB, ElementAccumulator,
                                                                        layout::RowMajor, arch::OpClassTensorOp,
                                                                        3, Operator>;

    // Define iterators over tiles from the A operand
    using ThreadMapA  = typename MmaCore::IteratorThreadMapA;
    using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
    using IteratorA   = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, bfloat16_t, LayoutA, 1, ThreadMapA,
        AccessTypeA, GatherA>;

    // Define iterators over tiles from the B operand
    using ThreadMapB  = typename MmaCore::IteratorThreadMapB;
    using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
    using IteratorB   = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, bfloat16_t, LayoutB, 0, ThreadMapB,
        AccessTypeB, GatherB>;

    // Define the threadblock-scoped multistage matrix multiply
    using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage<typename MmaCore::Shape, IteratorA,
                                                                     typename MmaCore::SmemIteratorA,
                                                                     MmaCore::kCacheOpA, IteratorB,
                                                                     typename MmaCore::SmemIteratorB,
                                                                     MmaCore::kCacheOpB, ElementAccumulator,
                                                                     layout::RowMajor,
                                                                     typename MmaCore::MmaPolicy, 2>;
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight
template<
    /// Layout type and access granularity (in elements) of the A and B matrix operands
    typename LayoutA, int kAlignmentA, typename LayoutB, int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-, warp- and instruction-level tile sizes (concept: GemmShape)
    typename ThreadblockShape, typename WarpShape, typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator>
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
                  layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape,
                  InstructionShape, 2, Operator> {
private:
    static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;

    using Mma = DqMma<bfloat16_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, bfloat16_t,
                      layout::RowMajor, kAlignmentScale, ElementAccumulator, layout::RowMajor,
                      arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>;

public:
    // Define the MmaCore components and the iterators over tiles of the A and B operands
    using MmaCore   = typename Mma::MmaCore;
    using IteratorA = typename Mma::IteratorA;
    using IteratorB = typename Mma::IteratorB;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = typename Mma::ThreadblockMma;
};

////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight
template<
    /// Layout type and access granularity (in elements) of the A and B matrix operands
    typename LayoutA, int kAlignmentA, typename LayoutB, int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-, warp- and instruction-level tile sizes (concept: GemmShape)
    typename ThreadblockShape, typename WarpShape, typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator>
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
                  layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape,
                  InstructionShape, 2, Operator> {
private:
    static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;

    using Mma = DqMma<bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, bfloat16_t,
                      layout::RowMajor, kAlignmentScale, ElementAccumulator, layout::RowMajor,
                      arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>;

public:
    // Define the MmaCore components and the iterators over tiles of the A and B operands
    using MmaCore   = typename Mma::MmaCore;
    using IteratorA = typename Mma::IteratorA;
    using IteratorB = typename Mma::IteratorB;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = typename Mma::ThreadblockMma;
};

template<
    /// Layout type and access granularity (in elements) of the A and B matrix operands
    typename LayoutA, int kAlignmentA, typename LayoutB, int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-, warp- and instruction-level tile sizes (concept: GemmShape)
    typename ThreadblockShape, typename WarpShape, typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator,
    int kStages,
    /// Shared memory clear option
    SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
                  layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape,
                  InstructionShape, kStages, Operator, false, SharedMemoryClear> {
private:
    static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;

    using Mma = DqMma<bfloat16_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, bfloat16_t,
                      layout::RowMajor, kAlignmentScale, ElementAccumulator, layout::RowMajor,
                      arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages,
                      Operator, SharedMemoryClear>;

public:
    // Define the MmaCore components and the iterators over tiles of the A and B operands
    using MmaCore   = typename Mma::MmaCore;
    using IteratorA = typename Mma::IteratorA;
    using IteratorB = typename Mma::IteratorB;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = typename Mma::ThreadblockMma;
};

////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight
template<
    /// Layout type and access granularity (in elements) of the A and B matrix operands
    typename LayoutA, int kAlignmentA, typename LayoutB, int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-, warp- and instruction-level tile sizes (concept: GemmShape)
    typename ThreadblockShape, typename WarpShape, typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator,
    int kStages,
    /// Shared memory clear option
    SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
                  layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape,
                  InstructionShape, kStages, Operator, false, SharedMemoryClear> {
private:
    static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;

    using Mma = DqMma<bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, bfloat16_t,
                      layout::RowMajor, kAlignmentScale, ElementAccumulator, layout::RowMajor,
                      arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages,
                      Operator, SharedMemoryClear>;

public:
    // Define the MmaCore components and the iterators over tiles of the A and B operands
    using MmaCore   = typename Mma::MmaCore;
    using IteratorA = typename Mma::IteratorA;
    using IteratorB = typename Mma::IteratorB;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = typename Mma::ThreadblockMma;
};

} // namespace threadblock
} // namespace gemm
} // namespace cutlass
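As a reference point for what all of these DqMma-based specializations compute, here is a deliberately naive scalar sketch of a scale-only weight-dequantizing GEMM, D = A · diag(scale) · B. It ignores the interleaved weight layout, packing and offset handling performed by the CUTLASS preprocessors and converters in this repository; the function name and the plain row-major buffers are assumptions made for this illustration only.

#include <cstdint>
#include <vector>

// Naive reference: D[m][n] = sum_k A[m][k] * (scale[n] * dequant(B[k][n])).
// Here the "dequantized" weight is simply the stored byte cast to float; the real
// kernels additionally undo the interleaving/offset encoding applied at preprocessing time.
void reference_fpA_intB_gemm(const std::vector<float>&   A,      // M x K activations (row-major)
                             const std::vector<uint8_t>& B,      // K x N quantized weights (row-major)
                             const std::vector<float>&   scale,  // N per-output-channel scales
                             std::vector<float>&         D,      // M x N output (row-major)
                             int M, int N, int K) {
    for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
            float acc = 0.f;
            for (int k = 0; k < K; ++k) {
                acc += A[m * K + k] * (scale[n] * static_cast<float>(B[k * N + n]));
            }
            D[m * N + n] = acc;
        }
    }
}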
cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h
ADDED
@@ -0,0 +1,236 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/

#pragma once

#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/mma_base.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////
// SFINAE trick so I can keep the same loop code for Volta and dispatch to the
// correct warp level mma. On volta, all data is stored to shared memory as FP16.
template<typename WarpMma, int kExpansionFactor = 1>
CUTLASS_DEVICE void run_warp_mma(WarpMma&                           warp_mma,
                                 typename WarpMma::FragmentC&       D,
                                 typename WarpMma::FragmentA const& A,
                                 typename WarpMma::FragmentB const& B,
                                 typename WarpMma::FragmentC const& C,
                                 const int                          warp_tileB_k_offset)
{
    warp_mma(D, A, B, C);
}

template<typename WarpMma, int kExpansionFactor = WarpMma::kExpansionFactor>
CUTLASS_DEVICE void run_warp_mma(WarpMma&                                      warp_mma,
                                 typename WarpMma::FragmentC&                  D,
                                 typename WarpMma::TransformedFragmentA const& A,
                                 typename WarpMma::TransformedFragmentB const& B,
                                 typename WarpMma::FragmentC const&            C,
                                 const int                                     warp_tileB_k_offset)
{
    warp_mma(D, A, B, C, warp_tileB_k_offset);
}
////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template<
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// The type of the scales
    typename ElementScale_,
    /// Number of stages
    int Stages,
    /// Used for partial specialization
    typename Enable = bool>
class DqMmaBase {
public:
    ///< Size of the Gemm problem - concept: gemm::GemmShape<>
    using Shape = Shape_;

    ///< Policy describing tuning details
    using Policy = Policy_;

    ///< Type of the scale to be loaded
    using ElementScale = ElementScale_;

    //
    // Dependent types
    //

    /// Warp-level Mma
    using Operator = typename Policy::Operator;

    /// Shape describing the overall GEMM computed from shared memory
    /// by each warp.
    using WarpGemm = typename Policy::Operator::Shape;

    /// Shape describing the number of warps filling the CTA
    using WarpCount = GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>;

    /// Number of warp-level GEMM operations
    static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK);

    static constexpr int kNumKIterationsPerWarpBLoad =
        Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK;

    static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), "");
    static constexpr int kWarpGemmIterationsForB = kWarpGemmIterations / kNumKIterationsPerWarpBLoad;

    /// Number of stages
    static int const kStages = Stages;

    /// Tensor reference to the A operand
    using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;

    /// Tensor reference to the B operand
    using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;

    //
    // Nested structs
    //

    /// Shared storage object needed by threadblock-scoped GEMM
    class SharedStorage {
    public:
        //
        // Type definitions
        //

        /// Shape of the A matrix operand in shared memory
        using ShapeA =
            MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow, Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;

        /// Shape of the B matrix operand in shared memory
        using ShapeB =
            MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow, Shape::kN + Policy::SmemPaddingB::kColumn>;

    public:
        //
        // Data members
        //

        /// Buffer for A operand
        AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;

        /// Buffer for B operand
        AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;

        /// Buffer to hold scales for threadblock
        AlignedBuffer<ElementScale, Shape::kN> operand_scale;

    public:
        //
        // Methods
        //

        /// Returns a layout object for the A matrix
        CUTLASS_DEVICE
        static typename Operator::LayoutA LayoutA()
        {
            return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
        }

        /// Returns a layout object for the B matrix
        CUTLASS_HOST_DEVICE
        static typename Operator::LayoutB LayoutB()
        {
            return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
        }

        /// Returns a TensorRef to the A operand
        CUTLASS_HOST_DEVICE
        TensorRefA operand_A_ref()
        {
            return TensorRefA{operand_A.data(), LayoutA()};
        }

        /// Returns a TensorRef to the B operand
        CUTLASS_HOST_DEVICE
        TensorRefB operand_B_ref()
        {
            return TensorRefB{operand_B.data(), LayoutB()};
        }
    };

protected:
    //
    // Data members
    //

    /// Iterator to load a warp-scoped tile of A operand from shared memory
    typename Operator::IteratorA warp_tile_iterator_A_;

    /// Iterator to load a warp-scoped tile of B operand from shared memory
    typename Operator::IteratorB warp_tile_iterator_B_;

public:
    /// Construct from tensor references
    CUTLASS_DEVICE
    DqMmaBase(
        ///< Shared storage needed for internal use by threadblock-scoped GEMM
        SharedStorage& shared_storage,
        ///< ID within the threadblock
        int thread_idx,
        ///< ID of warp
        int warp_idx,
        ///< ID of each thread within a warp
        int lane_idx):
        warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
        warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx)
    {
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
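To make the SharedStorage layout in DqMmaBase concrete, the following standalone sketch works through the shared-memory footprint for one hypothetical configuration: a 128x128x64 threadblock tile, 3 stages, fp16 A, int8 B, fp16 scales, and zero SMEM padding. The numbers are illustrative arithmetic only, not values taken from this commit.

#include <cstddef>

// Hypothetical configuration for the arithmetic below.
constexpr int kTileM = 128, kTileN = 128, kTileK = 64, kNumStages = 3;

// operand_A: ShapeA = kM x (kK * kStages) elements of fp16 (2 bytes each).
constexpr std::size_t kSmemBytesA = std::size_t(kTileM) * (kTileK * kNumStages) * 2;  // 49,152 bytes
// operand_B: ShapeB = (kK * kStages) x kN elements of int8 (1 byte each).
constexpr std::size_t kSmemBytesB = std::size_t(kTileK * kNumStages) * kTileN * 1;    // 24,576 bytes
// operand_scale: kN fp16 scales for the threadblock's output columns.
constexpr std::size_t kSmemBytesScale = std::size_t(kTileN) * 2;                      //    256 bytes

// Total shared storage the threadblock must fit into its shared-memory carveout.
static_assert(kSmemBytesA + kSmemBytesB + kSmemBytesScale == 73984,
              "roughly 72 KiB of shared memory for this example configuration");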
cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h
ADDED
@@ -0,0 +1,599 @@
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/aligned_buffer.h"
|
| 38 |
+
#include "cutlass/arch/memory.h"
|
| 39 |
+
#include "cutlass/array.h"
|
| 40 |
+
#include "cutlass/cutlass.h"
|
| 41 |
+
#include "cutlass/gemm/gemm.h"
|
| 42 |
+
#include "cutlass/matrix_shape.h"
|
| 43 |
+
#include "cutlass/numeric_types.h"
|
| 44 |
+
|
| 45 |
+
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
|
| 46 |
+
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
|
| 47 |
+
#include "cutlass_extensions/interleaved_numeric_conversion.h"
|
| 48 |
+
|
| 49 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 50 |
+
|
| 51 |
+
namespace cutlass {
|
| 52 |
+
namespace gemm {
|
| 53 |
+
namespace threadblock {
|
| 54 |
+
|
| 55 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 56 |
+
|
| 57 |
+
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
| 58 |
+
/// instructions.
|
| 59 |
+
template<
|
| 60 |
+
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
| 61 |
+
typename Shape_,
|
| 62 |
+
/// Iterates over tiles of A operand in global memory
|
| 63 |
+
// (concept: ReadableTileIterator | ForwardTileIterator |
|
| 64 |
+
// MaskedTileIterator)
|
| 65 |
+
typename IteratorA_,
|
| 66 |
+
/// Iterates over tiles of A operand in shared memory
|
| 67 |
+
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
| 68 |
+
typename SmemIteratorA_,
|
| 69 |
+
/// Cache operation for operand A
|
| 70 |
+
cutlass::arch::CacheOperation::Kind CacheOpA,
|
| 71 |
+
/// Iterates over tiles of B operand in global memory
|
| 72 |
+
// (concept: ReadableTileIterator | ForwardTileIterator |
|
| 73 |
+
// MaskedTileIterator)
|
| 74 |
+
typename IteratorB_,
|
| 75 |
+
/// Iterates over tiles of B operand in shared memory
|
| 76 |
+
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
| 77 |
+
typename SmemIteratorB_,
|
| 78 |
+
/// Cache operation for operand B
|
| 79 |
+
cutlass::arch::CacheOperation::Kind CacheOpB,
|
| 80 |
+
/// Data type for the scales
|
| 81 |
+
typename IteratorScale_,
|
| 82 |
+
/// Iterators over scales in shared memory
|
| 83 |
+
typename SmemIteratorScale_,
|
| 84 |
+
/// Data type of accumulator matrix
|
| 85 |
+
typename ElementC_,
|
| 86 |
+
/// Data type of accumulator matrix
|
| 87 |
+
typename LayoutC_,
|
| 88 |
+
/// Policy describing tuning details (concept: MmaPolicy)
|
| 89 |
+
typename Policy_,
|
| 90 |
+
/// Number of stages,
|
| 91 |
+
int Stages,
|
| 92 |
+
/// Converter for B matrix applited immediately after the LDS
|
| 93 |
+
typename TransformBAfterLDS_,
|
| 94 |
+
/// Use zfill or predicate for out-of-bound cp.async
|
| 95 |
+
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
|
| 96 |
+
/// Used for partial specialization
|
| 97 |
+
typename Enable = bool>
|
| 98 |
+
class DqMmaMultistage: public DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages> {
|
| 99 |
+
public:
|
| 100 |
+
///< Base class
|
| 101 |
+
using Base = DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages>;
|
| 102 |
+
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
| 103 |
+
using Shape = Shape_;
|
| 104 |
+
///< Iterates over tiles of A operand in global memory
|
| 105 |
+
using IteratorA = IteratorA_;
|
| 106 |
+
///< Iterates over tiles of B operand in global memory
|
| 107 |
+
using IteratorB = IteratorB_;
|
| 108 |
+
///< Data type of accumulator matrix
|
| 109 |
+
using ElementC = ElementC_;
|
| 110 |
+
///< Layout of accumulator matrix
|
| 111 |
+
using LayoutC = LayoutC_;
|
| 112 |
+
///< Policy describing tuning details
|
| 113 |
+
using Policy = Policy_;
|
| 114 |
+
|
| 115 |
+
using IteratorScale = IteratorScale_;
|
| 116 |
+
using ElementScale = typename IteratorScale::Element;
|
| 117 |
+
using LayoutScale = typename IteratorScale::Layout;
|
| 118 |
+
|
| 119 |
+
using SmemIteratorA = SmemIteratorA_;
|
| 120 |
+
using SmemIteratorB = SmemIteratorB_;
|
| 121 |
+
using SmemIteratorScale = SmemIteratorScale_;
|
| 122 |
+
|
| 123 |
+
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
|
| 124 |
+
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
|
| 125 |
+
|
| 126 |
+
using TransformBAfterLDS = TransformBAfterLDS_;
|
| 127 |
+
|
| 128 |
+
//
|
| 129 |
+
// Dependent types
|
| 130 |
+
//
|
| 131 |
+
|
| 132 |
+
/// Fragment of operand Scale loaded from global memory;
|
| 133 |
+
using FragmentScale = typename IteratorScale::Fragment;
|
| 134 |
+
|
| 135 |
+
/// Fragment of accumulator tile
|
| 136 |
+
using FragmentC = typename Policy::Operator::FragmentC;
|
| 137 |
+
|
| 138 |
+
/// Warp-level Mma
|
| 139 |
+
using Operator = typename Policy::Operator;
|
| 140 |
+
|
| 141 |
+
/// Minimum architecture is Sm80 to support cp.async
|
| 142 |
+
using ArchTag = arch::Sm80;
|
| 143 |
+
|
| 144 |
+
using Dequantizer =
|
| 145 |
+
warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB, ElementScale, LayoutScale, 32>;
|
| 146 |
+
|
| 147 |
+
/// Complex transform on A operand
|
| 148 |
+
static ComplexTransform const kTransformA = Operator::kTransformA;
|
| 149 |
+
|
| 150 |
+
/// Complex transform on B operand
|
| 151 |
+
static ComplexTransform const kTransformB = Operator::kTransformB;
|
| 152 |
+
|
| 153 |
+
/// Internal structure exposed for introspection.
|
| 154 |
+
struct Detail {
|
| 155 |
+
|
| 156 |
+
static_assert(Base::kWarpGemmIterations > 1,
|
| 157 |
+
"The pipelined structure requires at least two warp-level "
|
| 158 |
+
"GEMM operations.");
|
| 159 |
+
|
| 160 |
+
/// Number of cp.async instructions to load one stage of operand A
|
| 161 |
+
static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount;
|
| 162 |
+
|
| 163 |
+
/// Number of cp.async instructions to load one stage of operand B
|
| 164 |
+
static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount;
|
| 165 |
+
|
| 166 |
+
/// Number of stages
|
| 167 |
+
static int const kStages = Stages;
|
| 168 |
+
|
| 169 |
+
/// Number of cp.async instructions to load one group of operand A
|
| 170 |
+
static int const kAccessesPerGroupA =
|
| 171 |
+
(AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
|
| 172 |
+
|
| 173 |
+
/// Number of cp.async instructions to load one group of operand B
|
| 174 |
+
static int const kAccessesPerGroupB =
|
| 175 |
+
(AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
|
| 176 |
+
};
|
| 177 |
+
|
| 178 |
+
private:
|
| 179 |
+
using WarpFragmentA = typename Operator::FragmentA;
|
| 180 |
+
using WarpFragmentB = typename Operator::FragmentB;
|
| 181 |
+
Dequantizer warp_dequantizer_;
|
| 182 |
+
|
| 183 |
+
using ElementB = typename IteratorB::Element;
|
| 184 |
+
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementB, ArchTag>;
|
| 185 |
+
|
| 186 |
+
static constexpr bool RequiresTileInterleave =
|
| 187 |
+
layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
|
| 188 |
+
static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
|
| 189 |
+
"Layout K must match threadblockK");
|
| 190 |
+
|
| 191 |
+
private:
|
| 192 |
+
//
|
| 193 |
+
// Data members
|
| 194 |
+
//
|
| 195 |
+
|
| 196 |
+
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
| 197 |
+
SmemIteratorA smem_iterator_A_;
|
| 198 |
+
|
| 199 |
+
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
| 200 |
+
SmemIteratorB smem_iterator_B_;
|
| 201 |
+
|
| 202 |
+
/// Iterator to write threadblock-scoped tile of scale operand to shared memory
|
| 203 |
+
SmemIteratorScale smem_iterator_scale_;
|
| 204 |
+
|
| 205 |
+
public:
|
| 206 |
+
/// Construct from tensor references
|
| 207 |
+
CUTLASS_DEVICE
|
| 208 |
+
DqMmaMultistage(
|
| 209 |
+
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
| 210 |
+
typename Base::SharedStorage& shared_storage,
|
| 211 |
+
///< ID within the threadblock
|
| 212 |
+
int thread_idx,
|
| 213 |
+
///< ID of warp
|
| 214 |
+
int warp_idx,
|
| 215 |
+
///< ID of each thread within a warp
|
| 216 |
+
int lane_idx):
|
| 217 |
+
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
| 218 |
+
warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
|
| 219 |
+
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM,
|
| 220 |
+
lane_idx),
|
| 221 |
+
smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
|
| 222 |
+
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
|
| 223 |
+
smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx)
|
| 224 |
+
{
|
| 225 |
+
// Compute warp location within threadblock tile by mapping the warp_id to
|
| 226 |
+
// three coordinates:
|
| 227 |
+
// _m: the warp's position within the threadblock along the M dimension
|
| 228 |
+
// _n: the warp's position within the threadblock along the N dimension
|
| 229 |
+
// _k: the warp's position within the threadblock along the K dimension
|
| 230 |
+
|
| 231 |
+
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
|
| 232 |
+
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
|
| 233 |
+
|
| 234 |
+
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
|
| 235 |
+
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
|
| 236 |
+
|
| 237 |
+
// Add per-warp offsets in units of warp-level tiles
|
| 238 |
+
this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
|
| 239 |
+
this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
|
| 240 |
+
}
|
| 241 |
+
|
| 242 |
+
CUTLASS_DEVICE
|
| 243 |
+
void
|
| 244 |
+
copy_tiles_and_advance(IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0)
|
| 245 |
+
{
|
| 246 |
+
iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector);
|
| 247 |
+
this->smem_iterator_A_.set_iteration_index(group_start_A);
|
| 248 |
+
|
| 249 |
+
// Async Copy for operand A
|
| 250 |
+
CUTLASS_PRAGMA_UNROLL
|
| 251 |
+
for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) {
|
| 252 |
+
if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) {
|
| 253 |
+
typename IteratorA::AccessType* dst_ptr =
|
| 254 |
+
reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());
|
| 255 |
+
|
| 256 |
+
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
|
| 257 |
+
* IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8;
|
| 258 |
+
|
| 259 |
+
CUTLASS_PRAGMA_UNROLL
|
| 260 |
+
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
|
| 261 |
+
auto gmem_ptr = iterator_A.get();
|
| 262 |
+
|
| 263 |
+
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
| 264 |
+
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(dst_ptr + v, gmem_ptr, iterator_A.valid());
|
| 265 |
+
}
|
| 266 |
+
else {
|
| 267 |
+
cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(dst_ptr + v, gmem_ptr, iterator_A.valid());
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
++iterator_A;
|
| 271 |
+
}
|
| 272 |
+
|
| 273 |
+
++this->smem_iterator_A_;
|
| 274 |
+
}
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector);
|
| 278 |
+
this->smem_iterator_B_.set_iteration_index(group_start_B);
|
| 279 |
+
|
| 280 |
+
// Async Copy for operand B
|
| 281 |
+
CUTLASS_PRAGMA_UNROLL
|
| 282 |
+
for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
|
| 283 |
+
if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
|
| 284 |
+
typename IteratorB::AccessType* dst_ptr =
|
| 285 |
+
reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());
|
| 286 |
+
|
| 287 |
+
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
|
| 288 |
+
* IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8;
|
| 289 |
+
|
| 290 |
+
CUTLASS_PRAGMA_UNROLL
|
| 291 |
+
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
|
| 292 |
+
auto gmem_ptr = iterator_B.get();
|
| 293 |
+
|
| 294 |
+
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
| 295 |
+
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(dst_ptr + v, gmem_ptr, iterator_B.valid());
|
| 296 |
+
}
|
| 297 |
+
else {
|
| 298 |
+
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(dst_ptr + v, gmem_ptr, iterator_B.valid());
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
++iterator_B;
|
| 302 |
+
}
|
| 303 |
+
++this->smem_iterator_B_;
|
| 304 |
+
}
|
| 305 |
+
}
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
/// Perform a threadblock-scoped matrix multiply-accumulate
|
| 309 |
+
CUTLASS_DEVICE
|
| 310 |
+
void operator()(
|
| 311 |
+
///< problem size of GEMM
|
| 312 |
+
int gemm_k_iterations,
|
| 313 |
+
///< destination accumulator tile
|
| 314 |
+
FragmentC& accum,
|
| 315 |
+
///< iterator over A operand in global memory
|
| 316 |
+
IteratorA iterator_A,
|
| 317 |
+
///< iterator over B operand in global memory
|
| 318 |
+
IteratorB iterator_B,
|
| 319 |
+
///< iterator over scale operand in global memory
|
| 320 |
+
IteratorScale iterator_scale,
|
| 321 |
+
///< initial value of accumulator
|
| 322 |
+
FragmentC const& src_accum)
|
| 323 |
+
{
|
| 324 |
+
|
| 325 |
+
//
|
| 326 |
+
// Prologue
|
| 327 |
+
//
|
| 328 |
+
|
| 329 |
+
TransformBAfterLDS lds_converter;
|
| 330 |
+
|
| 331 |
+
// NOTE - switch to ldg.sts
|
| 332 |
+
// Issue this first, so cp.async.commit_group will commit this load as well.
|
| 333 |
+
// Note: we do not commit here and this load will commit in the same group as
|
| 334 |
+
// the first load of A.
|
| 335 |
+
FragmentScale tb_frag_scales;
|
| 336 |
+
tb_frag_scales.clear();
|
| 337 |
+
iterator_scale.load(tb_frag_scales);
|
| 338 |
+
this->smem_iterator_scale_.store(tb_frag_scales);
|
| 339 |
+
|
| 340 |
+
// Issue several complete stages
|
| 341 |
+
CUTLASS_PRAGMA_UNROLL
|
| 342 |
+
for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) {
|
| 343 |
+
|
| 344 |
+
iterator_A.clear_mask(gemm_k_iterations == 0);
|
| 345 |
+
iterator_B.clear_mask(gemm_k_iterations == 0);
|
| 346 |
+
|
| 347 |
+
iterator_A.set_iteration_index(0);
|
| 348 |
+
this->smem_iterator_A_.set_iteration_index(0);
|
| 349 |
+
|
| 350 |
+
// Async Copy for operand A
|
| 351 |
+
CUTLASS_PRAGMA_UNROLL
|
| 352 |
+
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
|
| 353 |
+
typename IteratorA::AccessType* dst_ptr =
|
| 354 |
+
reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());
|
| 355 |
+
|
| 356 |
+
CUTLASS_PRAGMA_UNROLL
|
| 357 |
+
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
|
| 358 |
+
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
|
| 359 |
+
* IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector
|
| 360 |
+
/ 8;
|
| 361 |
+
|
| 362 |
+
int src_bytes = (iterator_A.valid() ? kSrcBytes : 0);
|
| 363 |
+
|
| 364 |
+
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
| 365 |
+
dst_ptr + v, iterator_A.get(), iterator_A.valid());
|
| 366 |
+
|
| 367 |
+
++iterator_A;
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
++this->smem_iterator_A_;
|
| 371 |
+
}
|
| 372 |
+
|
| 373 |
+
iterator_B.set_iteration_index(0);
|
| 374 |
+
this->smem_iterator_B_.set_iteration_index(0);
|
| 375 |
+
|
| 376 |
+
// Async Copy for operand B
|
| 377 |
+
CUTLASS_PRAGMA_UNROLL
|
| 378 |
+
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
|
| 379 |
+
typename IteratorB::AccessType* dst_ptr =
|
| 380 |
+
reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());
|
| 381 |
+
|
| 382 |
+
CUTLASS_PRAGMA_UNROLL
|
| 383 |
+
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
|
| 384 |
+
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
|
| 385 |
+
* IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector
|
| 386 |
+
/ 8;
|
| 387 |
+
|
| 388 |
+
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
| 389 |
+
dst_ptr + v, iterator_B.get(), iterator_B.valid());
|
| 390 |
+
|
| 391 |
+
++iterator_B;
|
| 392 |
+
}
|
| 393 |
+
|
| 394 |
+
++this->smem_iterator_B_;
|
| 395 |
+
}
|
| 396 |
+
|
| 397 |
+
// Move to the next stage
|
| 398 |
+
iterator_A.add_tile_offset({0, 1});
|
| 399 |
+
iterator_B.add_tile_offset({1, 0});
|
| 400 |
+
|
| 401 |
+
this->smem_iterator_A_.add_tile_offset({0, 1});
|
| 402 |
+
this->smem_iterator_B_.add_tile_offset({1, 0});
|
| 403 |
+
|
| 404 |
+
// Defines the boundary of a stage of cp.async.
|
| 405 |
+
cutlass::arch::cp_async_fence();
|
| 406 |
+
}
|
| 407 |
+
|
| 408 |
+
// Perform accumulation in the 'd' output operand
|
| 409 |
+
accum = src_accum;
|
| 410 |
+
|
| 411 |
+
//
|
| 412 |
+
// Clear the remaining tiles of SMEM. This is a functional requirement for some kernels
|
| 413 |
+
// so that all accumulator elements outside the GEMM footprint are zero.
|
| 414 |
+
//
|
| 415 |
+
|
| 416 |
+
if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) {
|
| 417 |
+
|
| 418 |
+
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
| 419 |
+
SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);
|
| 420 |
+
|
| 421 |
+
typename IteratorA::AccessType zero_A;
|
| 422 |
+
zero_A.clear();
|
| 423 |
+
|
| 424 |
+
last_smem_iterator_A.set_iteration_index(0);
|
| 425 |
+
|
| 426 |
+
// Async Copy for operand A
|
| 427 |
+
CUTLASS_PRAGMA_UNROLL
|
| 428 |
+
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
|
| 429 |
+
|
| 430 |
+
typename IteratorA::AccessType* dst_ptr =
|
| 431 |
+
reinterpret_cast<typename IteratorA::AccessType*>(last_smem_iterator_A.get());
|
| 432 |
+
|
| 433 |
+
*dst_ptr = zero_A;
|
| 434 |
+
|
| 435 |
+
++last_smem_iterator_A;
|
| 436 |
+
}
|
| 437 |
+
|
| 438 |
+
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
| 439 |
+
SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
|
| 440 |
+
typename IteratorB::AccessType zero_B;
|
| 441 |
+
|
| 442 |
+
zero_B.clear();
|
| 443 |
+
last_smem_iterator_B.set_iteration_index(0);
|
| 444 |
+
|
| 445 |
+
// Async Copy for operand B
|
| 446 |
+
CUTLASS_PRAGMA_UNROLL
|
| 447 |
+
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
|
| 448 |
+
|
| 449 |
+
typename IteratorB::AccessType* dst_ptr =
|
| 450 |
+
reinterpret_cast<typename IteratorB::AccessType*>(last_smem_iterator_B.get());
|
| 451 |
+
|
| 452 |
+
*dst_ptr = zero_B;
|
| 453 |
+
|
| 454 |
+
++last_smem_iterator_B;
|
| 455 |
+
}
|
| 456 |
+
}
|
| 457 |
+
|
| 458 |
+
// Waits until kStages-2 stages have committed.
|
| 459 |
+
cutlass::arch::cp_async_wait<Base::kStages - 2>();
|
| 460 |
+
__syncthreads();
|
| 461 |
+
|
| 462 |
+
// Pair of fragments used to overlap shared memory loads and math
|
| 463 |
+
// instructions
|
| 464 |
+
WarpFragmentA warp_frag_A[2];
|
| 465 |
+
WarpFragmentB warp_frag_B[2];
|
| 466 |
+
typename Dequantizer::FragmentScale warp_frag_scales;
|
| 467 |
+
|
| 468 |
+
Operator warp_mma;
|
| 469 |
+
|
| 470 |
+
this->warp_tile_iterator_A_.set_kgroup_index(0);
|
| 471 |
+
this->warp_tile_iterator_B_.set_kgroup_index(0);
|
| 472 |
+
|
| 473 |
+
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
|
| 474 |
+
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
|
| 475 |
+
warp_dequantizer_.load(warp_frag_scales);
|
| 476 |
+
|
| 477 |
+
++this->warp_tile_iterator_A_;
|
| 478 |
+
++this->warp_tile_iterator_B_;
|
| 479 |
+
|
| 480 |
+
iterator_A.clear_mask(gemm_k_iterations == 0);
|
| 481 |
+
iterator_B.clear_mask(gemm_k_iterations == 0);
|
| 482 |
+
|
| 483 |
+
int smem_write_stage_idx = Base::kStages - 1;
|
| 484 |
+
int smem_read_stage_idx = 0;
|
| 485 |
+
|
| 486 |
+
//
|
| 487 |
+
// Mainloop
|
| 488 |
+
//
|
| 489 |
+
|
| 490 |
+
CUTLASS_GEMM_LOOP
|
| 491 |
+
for (; gemm_k_iterations > (-Base::kStages + 1);) {
|
| 492 |
+
//
|
| 493 |
+
// Loop over GEMM K dimension
|
| 494 |
+
//
|
| 495 |
+
|
| 496 |
+
// Computes a warp-level GEMM on data held in shared memory
|
| 497 |
+
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
|
| 498 |
+
CUTLASS_PRAGMA_UNROLL
|
| 499 |
+
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
|
| 500 |
+
|
| 501 |
+
// Load warp-level tiles from shared memory, wrapping to k offset if
|
| 502 |
+
// this is the last group as the case may be.
|
| 503 |
+
|
| 504 |
+
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
| 505 |
+
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
|
| 506 |
+
++this->warp_tile_iterator_A_;
|
| 507 |
+
|
| 508 |
+
const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
|
| 509 |
+
const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
|
| 510 |
+
if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) {
|
| 511 |
+
this->warp_tile_iterator_B_.set_kgroup_index((warp_tileB_k_load_offset + 1)
|
| 512 |
+
% Base::kWarpGemmIterationsForB);
|
| 513 |
+
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
|
| 514 |
+
++this->warp_tile_iterator_B_;
|
| 515 |
+
}
|
| 516 |
+
|
| 517 |
+
typename TransformBAfterLDS::result_type converted_frag_B =
|
| 518 |
+
lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
|
| 519 |
+
warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
|
| 520 |
+
|
| 521 |
+
run_warp_mma(
|
| 522 |
+
warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
|
| 523 |
+
|
| 524 |
+
// Issue global->shared copies for this stage
|
| 525 |
+
if (warp_mma_k < Base::kWarpGemmIterations - 1) {
|
| 526 |
+
int group_start_iteration_A, group_start_iteration_B;
|
| 527 |
+
|
| 528 |
+
group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
|
| 529 |
+
group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
|
| 530 |
+
|
| 531 |
+
copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
|
| 532 |
+
}
|
| 533 |
+
|
| 534 |
+
if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
|
| 535 |
+
int group_start_iteration_A, group_start_iteration_B;
|
| 536 |
+
group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
|
| 537 |
+
group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;
|
| 538 |
+
|
| 539 |
+
copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
|
| 540 |
+
|
| 541 |
+
// Inserts a memory fence between stages of cp.async instructions.
|
| 542 |
+
cutlass::arch::cp_async_fence();
|
| 543 |
+
|
| 544 |
+
// Waits until kStages-2 stages have committed.
|
| 545 |
+
arch::cp_async_wait<Base::kStages - 2>();
|
| 546 |
+
__syncthreads();
|
| 547 |
+
|
| 548 |
+
// Move to the next stage
|
| 549 |
+
iterator_A.add_tile_offset({0, 1});
|
| 550 |
+
iterator_B.add_tile_offset({1, 0});
|
| 551 |
+
|
| 552 |
+
this->smem_iterator_A_.add_tile_offset({0, 1});
|
| 553 |
+
this->smem_iterator_B_.add_tile_offset({1, 0});
|
| 554 |
+
|
| 555 |
+
// Add negative offsets to return iterators to the 'start' of the
|
| 556 |
+
// circular buffer in shared memory
|
| 557 |
+
if (smem_write_stage_idx == (Base::kStages - 1)) {
|
| 558 |
+
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
|
| 559 |
+
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
|
| 560 |
+
smem_write_stage_idx = 0;
|
| 561 |
+
}
|
| 562 |
+
else {
|
| 563 |
+
++smem_write_stage_idx;
|
| 564 |
+
}
|
| 565 |
+
|
| 566 |
+
if (smem_read_stage_idx == (Base::kStages - 1)) {
|
| 567 |
+
this->warp_tile_iterator_A_.add_tile_offset(
|
| 568 |
+
{0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
|
| 569 |
+
this->warp_tile_iterator_B_.add_tile_offset(
|
| 570 |
+
{-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
|
| 571 |
+
smem_read_stage_idx = 0;
|
| 572 |
+
}
|
| 573 |
+
else {
|
| 574 |
+
++smem_read_stage_idx;
|
| 575 |
+
}
|
| 576 |
+
|
| 577 |
+
--gemm_k_iterations;
|
| 578 |
+
iterator_A.clear_mask(gemm_k_iterations == 0);
|
| 579 |
+
iterator_B.clear_mask(gemm_k_iterations == 0);
|
| 580 |
+
}
|
| 581 |
+
}
|
| 582 |
+
}
|
| 583 |
+
|
| 584 |
+
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
| 585 |
+
// commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop
|
| 586 |
+
cutlass::arch::cp_async_fence();
|
| 587 |
+
cutlass::arch::cp_async_wait<0>();
|
| 588 |
+
__syncthreads();
|
| 589 |
+
}
|
| 590 |
+
}
|
| 591 |
+
};
|
| 592 |
+
|
| 593 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 594 |
+
|
| 595 |
+
} // namespace threadblock
|
| 596 |
+
} // namespace gemm
|
| 597 |
+
} // namespace cutlass
|
| 598 |
+
|
| 599 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h
ADDED
|
@@ -0,0 +1,385 @@
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/aligned_buffer.h"
|
| 38 |
+
#include "cutlass/array.h"
|
| 39 |
+
#include "cutlass/cutlass.h"
|
| 40 |
+
#include "cutlass/numeric_conversion.h"
|
| 41 |
+
|
| 42 |
+
#include "cutlass/matrix_shape.h"
|
| 43 |
+
#include "cutlass/numeric_types.h"
|
| 44 |
+
|
| 45 |
+
#include "cutlass/gemm/gemm.h"
|
| 46 |
+
|
| 47 |
+
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
|
| 48 |
+
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
|
| 49 |
+
#include "cutlass_extensions/interleaved_numeric_conversion.h"
|
| 50 |
+
|
| 51 |
+
#include "cutlass_extensions/ft_gemm_configs.h"
|
| 52 |
+
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
|
| 53 |
+
|
| 54 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 55 |
+
|
| 56 |
+
namespace cutlass {
|
| 57 |
+
namespace gemm {
|
| 58 |
+
namespace threadblock {
|
| 59 |
+
|
| 60 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 61 |
+
|
| 62 |
+
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
|
| 63 |
+
template<
|
| 64 |
+
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
| 65 |
+
typename Shape_,
|
| 66 |
+
/// Iterates over tiles of A operand in global memory
|
| 67 |
+
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
| 68 |
+
typename IteratorA_,
|
| 69 |
+
/// Iterates over tiles of A operand in shared memory
|
| 70 |
+
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
| 71 |
+
typename SmemIteratorA_,
|
| 72 |
+
/// Iterates over tiles of B operand in global memory
|
| 73 |
+
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
| 74 |
+
typename IteratorB_,
|
| 75 |
+
/// Iterates over tiles of B operand in shared memory
|
| 76 |
+
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
| 77 |
+
typename SmemIteratorB_,
|
| 78 |
+
/// Data type for the scales
|
| 79 |
+
typename IteratorScale_,
|
| 80 |
+
/// Iterators over scales in shared memory
|
| 81 |
+
typename SmemIteratorScale_,
|
| 82 |
+
/// Data type of accumulator matrix
|
| 83 |
+
typename ElementC_,
|
| 84 |
+
/// Data type of accumulator matrix
|
| 85 |
+
typename LayoutC_,
|
| 86 |
+
/// Policy describing tuning details (concept: MmaPolicy)
|
| 87 |
+
typename Policy_,
|
| 88 |
+
/// Converter for B matrix applied immediately after the LDG (before STS)
|
| 89 |
+
typename TransformBAfterLDG_,
|
| 90 |
+
/// Converter for B matrix applied immediately after the LDS
|
| 91 |
+
typename TransformBAfterLDS_,
|
| 92 |
+
/// Used for partial specialization
|
| 93 |
+
typename Enable = bool>
|
| 94 |
+
class DqMmaPipelined: public DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2> {
|
| 95 |
+
public:
|
| 96 |
+
///< Base class
|
| 97 |
+
using Base = DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2>;
|
| 98 |
+
|
| 99 |
+
using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
| 100 |
+
using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory
|
| 101 |
+
using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory
|
| 102 |
+
using ElementC = ElementC_; ///< Data type of accumulator matrix
|
| 103 |
+
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
|
| 104 |
+
using Policy = Policy_; ///< Policy describing tuning details
|
| 105 |
+
|
| 106 |
+
using IteratorScale = IteratorScale_;
|
| 107 |
+
using ElementScale = typename IteratorScale::Element;
|
| 108 |
+
using LayoutScale = typename IteratorScale::Layout;
|
| 109 |
+
|
| 110 |
+
using SmemIteratorA = SmemIteratorA_;
|
| 111 |
+
using SmemIteratorB = SmemIteratorB_;
|
| 112 |
+
using SmemIteratorScale = SmemIteratorScale_;
|
| 113 |
+
|
| 114 |
+
using TransformBAfterLDG = TransformBAfterLDG_;
|
| 115 |
+
using TransformBAfterLDS = TransformBAfterLDS_;
|
| 116 |
+
|
| 117 |
+
//
|
| 118 |
+
// Dependent types
|
| 119 |
+
//
|
| 120 |
+
|
| 121 |
+
/// Fragment of operand A loaded from global memory
|
| 122 |
+
using FragmentA = typename IteratorA::Fragment;
|
| 123 |
+
|
| 124 |
+
/// Fragment of operand B loaded from global memory
|
| 125 |
+
using FragmentB = typename IteratorB::Fragment;
|
| 126 |
+
|
| 127 |
+
/// Fragment of operand Scale loaded from global memory;
|
| 128 |
+
using FragmentScale = typename IteratorScale::Fragment;
|
| 129 |
+
|
| 130 |
+
/// Fragment of accumulator tile
|
| 131 |
+
using FragmentC = typename Policy::Operator::FragmentC;
|
| 132 |
+
|
| 133 |
+
/// Warp-level Mma
|
| 134 |
+
using Operator = typename Policy::Operator;
|
| 135 |
+
|
| 136 |
+
/// Obtain the arch tag from the warp-level operator
|
| 137 |
+
using ArchTag = typename Policy::Operator::ArchTag;
|
| 138 |
+
|
| 139 |
+
using Dequantizer = warp::MmaTensorOpDequantizer<Operator,
|
| 140 |
+
typename Base::WarpGemm,
|
| 141 |
+
Operand::kB,
|
| 142 |
+
typename SmemIteratorScale::Fragment::Element,
|
| 143 |
+
LayoutScale,
|
| 144 |
+
32>;
|
| 145 |
+
|
| 146 |
+
/// Complex transform on A operand
|
| 147 |
+
static ComplexTransform const kTransformA = Operator::kTransformA;
|
| 148 |
+
|
| 149 |
+
/// Complex transform on B operand
|
| 150 |
+
static ComplexTransform const kTransformB = Operator::kTransformB;
|
| 151 |
+
|
| 152 |
+
// statically assert kStages for DqMmaPipelined is two (Double-buffered pipeline)
|
| 153 |
+
static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2");
|
| 154 |
+
|
| 155 |
+
private:
|
| 156 |
+
using WarpFragmentA = typename Operator::FragmentA;
|
| 157 |
+
using WarpFragmentB = typename Operator::FragmentB;
|
| 158 |
+
Dequantizer warp_dequantizer_;
|
| 159 |
+
|
| 160 |
+
using ElementB = typename IteratorB::Element;
|
| 161 |
+
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementB, ArchTag>;
|
| 162 |
+
|
| 163 |
+
static constexpr bool RequiresTileInterleave =
|
| 164 |
+
layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
|
| 165 |
+
static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
|
| 166 |
+
"Layout K must match threadblockK");
|
| 167 |
+
|
| 168 |
+
protected:
|
| 169 |
+
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
| 170 |
+
SmemIteratorA smem_iterator_A_;
|
| 171 |
+
|
| 172 |
+
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
| 173 |
+
SmemIteratorB smem_iterator_B_;
|
| 174 |
+
|
| 175 |
+
/// Iterator to write threadblock-scoped tile of scale operand to shared memory
|
| 176 |
+
SmemIteratorScale smem_iterator_scale_;
|
| 177 |
+
|
| 178 |
+
public:
|
| 179 |
+
/// Construct from tensor references
|
| 180 |
+
CUTLASS_DEVICE
|
| 181 |
+
DqMmaPipelined(typename Base::SharedStorage&
|
| 182 |
+
shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
|
| 183 |
+
int thread_idx, ///< ID within the threadblock
|
| 184 |
+
int warp_idx, ///< ID of warp
|
| 185 |
+
int lane_idx ///< ID of each thread within a warp
|
| 186 |
+
):
|
| 187 |
+
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
| 188 |
+
warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
|
| 189 |
+
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM,
|
| 190 |
+
lane_idx),
|
| 191 |
+
smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
|
| 192 |
+
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
|
| 193 |
+
smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx)
|
| 194 |
+
{
|
| 195 |
+
|
| 196 |
+
// Compute warp location within threadblock tile by mapping the warp_id to
|
| 197 |
+
// three coordinates:
|
| 198 |
+
// _m: the warp's position within the threadblock along the M dimension
|
| 199 |
+
// _n: the warp's position within the threadblock along the N dimension
|
| 200 |
+
// _k: the warp's position within the threadblock along the K dimension
|
| 201 |
+
|
| 202 |
+
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
|
| 203 |
+
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
|
| 204 |
+
|
| 205 |
+
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
|
| 206 |
+
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
|
| 207 |
+
|
| 208 |
+
// Add per-warp offsets in units of warp-level tiles
|
| 209 |
+
this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
|
| 210 |
+
this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
/// Perform a threadblock-scoped matrix multiply-accumulate
|
| 214 |
+
CUTLASS_DEVICE
|
| 215 |
+
void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop
|
| 216 |
+
FragmentC& accum, ///< destination accumulator tile
|
| 217 |
+
IteratorA iterator_A, ///< iterator over A operand in global memory
|
| 218 |
+
IteratorB iterator_B, ///< iterator over B operand in global memory
|
| 219 |
+
IteratorScale iterator_scale, ///< iterator over scale operand in global memory
|
| 220 |
+
FragmentC const& src_accum)
|
| 221 |
+
{ ///< source accumulator tile
|
| 222 |
+
|
| 223 |
+
//
|
| 224 |
+
// Prologue
|
| 225 |
+
//
|
| 226 |
+
TransformBAfterLDG ldg_converter;
|
| 227 |
+
TransformBAfterLDS lds_converter;
|
| 228 |
+
|
| 229 |
+
using TransformA =
|
| 230 |
+
NumericArrayConverter<typename WarpFragmentA::Element, typename FragmentA::Element, FragmentA::kElements>;
|
| 231 |
+
|
| 232 |
+
using TransformScale = NumericArrayConverter<typename SmemIteratorScale::Fragment::Element,
|
| 233 |
+
typename FragmentScale::Element,
|
| 234 |
+
FragmentScale::kElements>;
|
| 235 |
+
|
| 236 |
+
// These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want
|
| 237 |
+
// to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS.
|
| 238 |
+
TransformA transformA;
|
| 239 |
+
TransformScale transformScale;
|
| 240 |
+
|
| 241 |
+
// Perform accumulation in the 'd' output operand
|
| 242 |
+
accum = src_accum;
|
| 243 |
+
|
| 244 |
+
FragmentA tb_frag_A;
|
| 245 |
+
FragmentB tb_frag_B;
|
| 246 |
+
FragmentScale tb_frag_scales;
|
| 247 |
+
|
| 248 |
+
using WarpFragmentScale = typename Dequantizer::FragmentScale;
|
| 249 |
+
WarpFragmentScale warp_frag_scales;
|
| 250 |
+
|
| 251 |
+
tb_frag_A.clear();
|
| 252 |
+
tb_frag_B.clear();
|
| 253 |
+
tb_frag_scales.clear();
|
| 254 |
+
|
| 255 |
+
// The last kblock is loaded in the prologue
|
| 256 |
+
iterator_A.load(tb_frag_A);
|
| 257 |
+
iterator_B.load(tb_frag_B);
|
| 258 |
+
iterator_scale.load(tb_frag_scales);
|
| 259 |
+
|
| 260 |
+
++iterator_A;
|
| 261 |
+
++iterator_B;
|
| 262 |
+
|
| 263 |
+
this->smem_iterator_A_.store(transformA(tb_frag_A));
|
| 264 |
+
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
|
| 265 |
+
this->smem_iterator_scale_.store(transformScale(tb_frag_scales));
|
| 266 |
+
|
| 267 |
+
++this->smem_iterator_A_;
|
| 268 |
+
++this->smem_iterator_B_;
|
| 269 |
+
|
| 270 |
+
__syncthreads();
|
| 271 |
+
|
| 272 |
+
warp_dequantizer_.load(warp_frag_scales);
|
| 273 |
+
|
| 274 |
+
// Pair of fragments used to overlap shared memory loads and math instructions
|
| 275 |
+
WarpFragmentA warp_frag_A[2];
|
| 276 |
+
WarpFragmentB warp_frag_B[2];
|
| 277 |
+
|
| 278 |
+
this->warp_tile_iterator_A_.set_kgroup_index(0);
|
| 279 |
+
this->warp_tile_iterator_B_.set_kgroup_index(0);
|
| 280 |
+
|
| 281 |
+
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
|
| 282 |
+
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
|
| 283 |
+
|
| 284 |
+
++this->warp_tile_iterator_A_;
|
| 285 |
+
++this->warp_tile_iterator_B_;
|
| 286 |
+
|
| 287 |
+
Operator warp_mma;
|
| 288 |
+
|
| 289 |
+
int smem_write_stage_idx = 1;
|
| 290 |
+
|
| 291 |
+
// Avoid reading out of bounds
|
| 292 |
+
iterator_A.clear_mask(gemm_k_iterations <= 1);
|
| 293 |
+
iterator_B.clear_mask(gemm_k_iterations <= 1);
|
| 294 |
+
|
| 295 |
+
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
|
| 296 |
+
// shared memory loads (which have the tightest latency requirement).
|
| 297 |
+
|
| 298 |
+
//
|
| 299 |
+
// Mainloop
|
| 300 |
+
//
|
| 301 |
+
|
| 302 |
+
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
|
| 303 |
+
CUTLASS_GEMM_LOOP
|
| 304 |
+
for (; gemm_k_iterations > 0; --gemm_k_iterations) {
|
| 305 |
+
//
|
| 306 |
+
// Loop over GEMM K dimension
|
| 307 |
+
//
|
| 308 |
+
|
| 309 |
+
CUTLASS_PRAGMA_UNROLL
|
| 310 |
+
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
|
| 311 |
+
|
| 312 |
+
// Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
|
| 313 |
+
// as the case may be.
|
| 314 |
+
|
| 315 |
+
if (warp_mma_k == Base::kWarpGemmIterations - 1) {
|
| 316 |
+
|
| 317 |
+
// Write fragments to shared memory
|
| 318 |
+
this->smem_iterator_A_.store(transformA(tb_frag_A));
|
| 319 |
+
|
| 320 |
+
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
|
| 321 |
+
|
| 322 |
+
__syncthreads();
|
| 323 |
+
|
| 324 |
+
++this->smem_iterator_A_;
|
| 325 |
+
++this->smem_iterator_B_;
|
| 326 |
+
|
| 327 |
+
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
|
| 328 |
+
if (smem_write_stage_idx == 1) {
|
| 329 |
+
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
|
| 330 |
+
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
|
| 331 |
+
}
|
| 332 |
+
else {
|
| 333 |
+
this->warp_tile_iterator_A_.add_tile_offset(
|
| 334 |
+
{0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
|
| 335 |
+
this->warp_tile_iterator_B_.add_tile_offset(
|
| 336 |
+
{-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
smem_write_stage_idx ^= 1;
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
| 343 |
+
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
|
| 344 |
+
++this->warp_tile_iterator_A_;
|
| 345 |
+
|
| 346 |
+
const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
|
| 347 |
+
const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
|
| 348 |
+
// We are just about to finish computing on a fragment of B, so initiate the load for the next fragment.
|
| 349 |
+
if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) {
|
| 350 |
+
this->warp_tile_iterator_B_.set_kgroup_index((warp_tileB_k_load_offset + 1)
|
| 351 |
+
% Base::kWarpGemmIterationsForB);
|
| 352 |
+
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
|
| 353 |
+
++this->warp_tile_iterator_B_;
|
| 354 |
+
}
|
| 355 |
+
|
| 356 |
+
if (warp_mma_k == 0) {
|
| 357 |
+
|
| 358 |
+
iterator_A.load(tb_frag_A);
|
| 359 |
+
iterator_B.load(tb_frag_B);
|
| 360 |
+
|
| 361 |
+
++iterator_A;
|
| 362 |
+
++iterator_B;
|
| 363 |
+
|
| 364 |
+
// Avoid reading out of bounds if this was the last loop iteration
|
| 365 |
+
iterator_A.clear_mask(gemm_k_iterations <= 2);
|
| 366 |
+
iterator_B.clear_mask(gemm_k_iterations <= 2);
|
| 367 |
+
}
|
| 368 |
+
|
| 369 |
+
typename TransformBAfterLDS::result_type converted_frag_B =
|
| 370 |
+
lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
|
| 371 |
+
warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
|
| 372 |
+
run_warp_mma(
|
| 373 |
+
warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
|
| 374 |
+
}
|
| 375 |
+
}
|
| 376 |
+
}
|
| 377 |
+
};
|
| 378 |
+
|
| 379 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 380 |
+
|
| 381 |
+
} // namespace threadblock
|
| 382 |
+
} // namespace gemm
|
| 383 |
+
} // namespace cutlass
|
| 384 |
+
|
| 385 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
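Note: DqMmaPipelined is the two-stage (kStages == 2) variant: one shared-memory buffer is consumed by the warp-level MMA while the other is being filled, and smem_write_stage_idx ^= 1 flips the roles every K iteration. The host-side C++ sketch below mirrors that double-buffering schedule in simplified form; Tile, load_tile, and consume_tile are illustrative placeholders, not CUTLASS types.

// Simplified double-buffered mainloop: stage the next tile into one buffer
// while the previously staged tile is consumed from the other, then flip.
#include <array>

struct Tile { /* stands in for one threadblock-scoped tile of operands */ };

Tile load_tile(int /*k_iteration*/) { return Tile{}; }   // placeholder for the gmem iterators
void consume_tile(const Tile& /*tile*/) {}               // placeholder for the warp-level MMA

void pipelined_mainloop(int gemm_k_iterations) {
    if (gemm_k_iterations <= 0) {
        return;
    }
    std::array<Tile, 2> buffers{};   // two shared-memory stages
    int write_stage = 1;             // matches smem_write_stage_idx = 1 after the prologue
    buffers[0] = load_tile(0);       // prologue fills stage 0

    for (int k = 1; k < gemm_k_iterations; ++k) {
        buffers[write_stage] = load_tile(k);     // stage the next tile
        consume_tile(buffers[write_stage ^ 1]);  // compute on the previously staged tile
        write_stage ^= 1;                        // flip buffers, like smem_write_stage_idx ^= 1
    }
    consume_tile(buffers[write_stage ^ 1]);      // drain the last staged tile
}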
cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h
ADDED
|
@@ -0,0 +1,127 @@
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Default warp-level GEMM operators selected by data type, size, and layouts of operands.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/cutlass.h"
|
| 38 |
+
#include "cutlass/gemm/warp/default_mma_tensor_op.h"
|
| 39 |
+
#include "cutlass/gemm/warp/mma_tensor_op.h"
|
| 40 |
+
|
| 41 |
+
#include "cutlass_extensions/arch/mma.h"
|
| 42 |
+
#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
|
| 43 |
+
|
| 44 |
+
namespace cutlass {
|
| 45 |
+
namespace gemm {
|
| 46 |
+
namespace warp {
|
| 47 |
+
|
| 48 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 49 |
+
|
| 50 |
+
/// Partial specialization for m-by-n-by-kgroup
|
| 51 |
+
template<
|
| 52 |
+
/// Shape of one matrix production operation (concept: GemmShape)
|
| 53 |
+
typename WarpShape_,
|
| 54 |
+
/// Shape of one matrix production operation (concept: GemmShape)
|
| 55 |
+
typename InstructionShape_,
|
| 56 |
+
/// Data type of A elements,
|
| 57 |
+
typename ElementA,
|
| 58 |
+
/// Layout of A matrix (concept: MatrixLayout)
|
| 59 |
+
typename LayoutA,
|
| 60 |
+
/// Data type of B elements
|
| 61 |
+
typename ElementB,
|
| 62 |
+
/// Layout of B matrix (concept: MatrixLayout)
|
| 63 |
+
typename LayoutB,
|
| 64 |
+
/// Element type of C matrix
|
| 65 |
+
typename ElementC,
|
| 66 |
+
/// Layout of C matrix (concept: MatrixLayout)
|
| 67 |
+
typename LayoutC,
|
| 68 |
+
/// Number of partitions along K dimension
|
| 69 |
+
int PartitionsK,
|
| 70 |
+
/// Store the accumulators in row major or column major. Row major is used
|
| 71 |
+
/// when output layout is interleaved.
|
| 72 |
+
bool AccumulatorsInRowMajor>
|
| 73 |
+
struct DefaultMmaTensorOp<WarpShape_,
|
| 74 |
+
InstructionShape_,
|
| 75 |
+
ElementA,
|
| 76 |
+
LayoutA,
|
| 77 |
+
ElementB,
|
| 78 |
+
LayoutB,
|
| 79 |
+
ElementC,
|
| 80 |
+
LayoutC,
|
| 81 |
+
arch::OpMultiplyAddDequantizeInterleavedBToA,
|
| 82 |
+
PartitionsK,
|
| 83 |
+
AccumulatorsInRowMajor> {
|
| 84 |
+
|
| 85 |
+
private:
|
| 86 |
+
// Shape for computing the FP16s
|
| 87 |
+
using ComputeInstructionShape = InstructionShape_;
|
| 88 |
+
|
| 89 |
+
// Chosen so we get K=16 for int8 and K=32 for int4.
|
| 90 |
+
static constexpr int LoadInstructionK = 8 * sizeof_bits<ElementA>::value / sizeof_bits<ElementB>::value;
|
| 91 |
+
|
| 92 |
+
// Shape for loading the narrow data type from shared memory
|
| 93 |
+
using LoadInstructionShape = GemmShape<InstructionShape_::kM, InstructionShape_::kN, LoadInstructionK>;
|
| 94 |
+
|
| 95 |
+
public:
|
| 96 |
+
using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<cutlass::arch::Mma<InstructionShape_,
|
| 97 |
+
32,
|
| 98 |
+
ElementA,
|
| 99 |
+
cutlass::layout::RowMajor,
|
| 100 |
+
ElementA,
|
| 101 |
+
cutlass::layout::ColumnMajor,
|
| 102 |
+
ElementC,
|
| 103 |
+
cutlass::layout::RowMajor,
|
| 104 |
+
arch::OpMultiplyAdd>,
|
| 105 |
+
cutlass::MatrixShape<1, 1>>;
|
| 106 |
+
|
| 107 |
+
// Define the warp-level tensor op
|
| 108 |
+
using Type = cutlass::gemm::warp::MmaTensorOpComputeBWithF16<WarpShape_,
|
| 109 |
+
ElementA,
|
| 110 |
+
LayoutA,
|
| 111 |
+
ElementB,
|
| 112 |
+
LayoutB,
|
| 113 |
+
ElementC,
|
| 114 |
+
LayoutC,
|
| 115 |
+
Policy,
|
| 116 |
+
LoadInstructionShape,
|
| 117 |
+
PartitionsK,
|
| 118 |
+
AccumulatorsInRowMajor>;
|
| 119 |
+
};
|
| 120 |
+
|
| 121 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 122 |
+
|
| 123 |
+
} // namespace warp
|
| 124 |
+
} // namespace gemm
|
| 125 |
+
} // namespace cutlass
|
| 126 |
+
|
| 127 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
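Note: LoadInstructionK above is 8 * sizeof_bits&lt;ElementA&gt; / sizeof_bits&lt;ElementB&gt;, which gives the shared-memory load shape K = 16 for int8 weights and K = 32 for int4 weights when A is a 16-bit type, as the in-code comment states. A small compile-time check of that arithmetic, using plain integers in place of cutlass::sizeof_bits for brevity:

// Compile-time check of the LoadInstructionK formula: 8 * bits(A) / bits(B).
constexpr int load_instruction_k(int bits_a, int bits_b) {
    return 8 * bits_a / bits_b;
}

static_assert(load_instruction_k(16, 8) == 16, "fp16/bf16 A with int8 B -> K = 16");
static_assert(load_instruction_k(16, 4) == 32, "fp16/bf16 A with int4 B -> K = 32");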
cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h
ADDED
|
@@ -0,0 +1,313 @@
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Templates implementing warp-level matrix multiply-accumulate operations targeting
|
| 33 |
+
Tensor Cores.
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include "cutlass/array.h"
|
| 39 |
+
#include "cutlass/cutlass.h"
|
| 40 |
+
#include "cutlass/platform/platform.h"
|
| 41 |
+
|
| 42 |
+
#include "cutlass/matrix_shape.h"
|
| 43 |
+
#include "cutlass/numeric_conversion.h"
|
| 44 |
+
#include "cutlass/numeric_types.h"
|
| 45 |
+
|
| 46 |
+
#include "cutlass/arch/memory_sm75.h"
|
| 47 |
+
#include "cutlass/arch/mma_sm75.h"
|
| 48 |
+
#include "cutlass/arch/mma_sm80.h"
|
| 49 |
+
|
| 50 |
+
#include "cutlass/gemm/gemm.h"
|
| 51 |
+
#include "cutlass/gemm/warp/mma.h"
|
| 52 |
+
|
| 53 |
+
#include "cutlass/gemm/warp/mma_tensor_op_policy.h"
|
| 54 |
+
|
| 55 |
+
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h"
|
| 56 |
+
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h"
|
| 57 |
+
|
| 58 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 59 |
+
|
| 60 |
+
namespace cutlass {
|
| 61 |
+
namespace gemm {
|
| 62 |
+
namespace warp {
|
| 63 |
+
|
| 64 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 65 |
+
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
|
| 66 |
+
template<
|
| 67 |
+
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
| 68 |
+
typename Shape_,
|
| 69 |
+
/// Data type of A elements
|
    typename ElementA_,
    /// Layout of A matrix (concept: MatrixLayout)
    typename LayoutA_,
    /// Data type of B elements
    typename ElementB_,
    /// Layout of B matrix (concept: MatrixLayout)
    typename LayoutB_,
    /// Element type of C matrix
    typename ElementC_,
    /// Layout of C matrix (concept: MatrixLayout)
    typename LayoutC_,
    /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
    typename Policy_,
    /// Instruction shape to override shared memory iterators with
    typename SharedMemoryInstructionShape_,
    /// Number of partitions along K dimension
    int PartitionsK_ = 1,
    /// Store the accumulators in row major or column major. Row major is used
    /// when output layout is interleaved.
    bool AccumulatorsInRowMajor = false,
    /// Used for partial specialization
    typename Enable = bool>
class MmaTensorOpComputeBWithF16 {
public:
    /// Shape of warp-level matrix operation (concept: GemmShape)
    using Shape = Shape_;

    /// Data type of multiplicand A
    using ElementA = ElementA_;

    /// Layout of multiplicand A
    using LayoutA = LayoutA_;

    /// Data type of multiplicand B
    using ElementB = ElementB_;

    /// Layout of multiplicand B
    using LayoutB = LayoutB_;

    /// Data type of accumulator matrix C
    using ElementC = ElementC_;

    /// Layout of accumulator matrix C
    using LayoutC = LayoutC_;

    /// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
    using Policy = Policy_;

    /// Underlying matrix multiply operator (concept: arch::Mma)
    using ArchMmaOperator = typename Policy::Operator;

    /// Indicates math operator
    using MathOperator = typename ArchMmaOperator::Operator;

    /// Architecture tag from underlying instruction
    using ArchTag = typename ArchMmaOperator::ArchTag;
    static_assert((platform::is_same<typename ArchMmaOperator::ElementA, half_t>::value
                   && platform::is_same<typename ArchMmaOperator::ElementB, half_t>::value)
                      || (platform::is_same<typename ArchMmaOperator::ElementA, bfloat16_t>::value
                          && platform::is_same<typename ArchMmaOperator::ElementB, bfloat16_t>::value
                          && ArchTag::kMinComputeCapability >= 80),
                  "MmaTensorOpCvtBToA only supports underlying HMMA");

    static_assert(platform::is_same<ElementA, half_t>::value
                      || (platform::is_same<ElementA, bfloat16_t>::value && ArchTag::kMinComputeCapability >= 80),
                  "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+");

    /// Indicates class of matrix operator
    using OperatorClass = arch::OpClassTensorOp;

    /// Shape of underlying instruction
    using InstructionShape = typename ArchMmaOperator::Shape;

    /// Instruction shape to override shared memory iterators with
    using SharedMemoryInstructionShape = SharedMemoryInstructionShape_;

    static_assert(SharedMemoryInstructionShape::kM == InstructionShape::kM,
                  "M dimension of compute instruction must match load");
    static_assert(SharedMemoryInstructionShape::kN == InstructionShape::kN,
                  "N dimension of compute instruction must match load");

    static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK;

    static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), "");

    /// Complex transform on A operand
    static ComplexTransform const kTransformA = ComplexTransform::kNone;

    /// Complex transform on B operand
    static ComplexTransform const kTransformB = ComplexTransform::kNone;

    /// Number of threads participating in warp-level matrix product
    static int const kThreadCount = 32;

    /// Number of partitions along K dimension
    static int const kPartitionsK = PartitionsK_;

public:
    /// Iterates over the A operand in memory
    using IteratorA = MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kM, Shape::kK>,
                                                          Operand::kA,
                                                          ElementA,
                                                          LayoutA,
                                                          MatrixShape<InstructionShape::kM, InstructionShape::kK>,
                                                          Policy::OpDelta::kRow,
                                                          kThreadCount,
                                                          kPartitionsK>;

    /// Storage for A tile
    using FragmentA = typename IteratorA::Fragment;

    /// Storage for transformed A tile
    using TransformedFragmentA = Array<typename ArchMmaOperator::ElementA, FragmentA::kElements>;

    /// Iterates over the B operand in memory
    using IteratorB =
        MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kK, Shape::kN>,
                                            Operand::kB,
                                            ElementB,
                                            LayoutB,
                                            MatrixShape<SharedMemoryInstructionShape::kK, InstructionShape::kN>,
                                            Policy::OpDelta::kRow,
                                            kThreadCount,
                                            kPartitionsK>;

    /// Storage for B tile
    using FragmentB = typename IteratorB::Fragment;

    /// Storage for transformed B tile
    using TransformedFragmentB = Array<typename ArchMmaOperator::ElementB, FragmentB::kElements>;

    /// Iterates over the C operand in memory
    using IteratorC = MmaTensorOpAccumulatorTileIterator<MatrixShape<Shape::kM, Shape::kN>,
                                                         ElementC,
                                                         LayoutC,
                                                         typename ArchMmaOperator::Shape,
                                                         typename Policy::OpDelta>;

    /// Storage for C tile
    using FragmentC = typename IteratorC::Fragment;

    /// Number of mma operations performed
    using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM,
                                      (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>;

public:
    /// Underlying matrix multiply operator (concept: arch::Mma)
    ArchMmaOperator mma;

public:
    //
    // Methods
    //

    /// Ctor
    CUTLASS_DEVICE
    MmaTensorOpComputeBWithF16() {}

    /// Performs a warp-level matrix multiply-accumulate operation
    CUTLASS_DEVICE
    void operator()(FragmentC& D,
                    TransformedFragmentA const& A,
                    TransformedFragmentB const& B,
                    FragmentC const& C,
                    const int warp_tileB_k_offset) const
    {

        using MmaOperandA = typename ArchMmaOperator::FragmentA;
        using MmaOperandB = typename ArchMmaOperator::FragmentB;
        using MmaOperandC = typename ArchMmaOperator::FragmentC;

        static_assert(
            TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn,
            "Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of B");

        D = C;

        MmaOperandA const* ptr_A = reinterpret_cast<MmaOperandA const*>(&A);
        MmaOperandB const* ptr_B = reinterpret_cast<MmaOperandB const*>(&B);
        MmaOperandC* ptr_D = reinterpret_cast<MmaOperandC*>(&D);

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
        // Serpentine visitation order maximizing reuse of Rb
        CUTLASS_PRAGMA_UNROLL
        for (int n = 0; n < MmaIterations::kColumn; ++n) {

            CUTLASS_PRAGMA_UNROLL
            for (int m = 0; m < MmaIterations::kRow; ++m) {

                int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m);

                int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n;
                if (AccumulatorsInRowMajor) {  // matrix B is reordered
                    mma(ptr_D[n + m_serpentine * MmaIterations::kColumn],
                        ptr_A[m_serpentine],
                        ptr_B[n_offsetB],
                        ptr_D[n + m_serpentine * MmaIterations::kColumn]);
                }
                else {
                    mma(ptr_D[m_serpentine + n * MmaIterations::kRow],
                        ptr_A[m_serpentine],
                        ptr_B[n_offsetB],
                        ptr_D[m_serpentine + n * MmaIterations::kRow]);
                }
            }
        }
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
        // Serpentine visitation order maximizing reuse of Ra
        CUTLASS_PRAGMA_UNROLL
        for (int m = 0; m < MmaIterations::kRow; ++m) {

            CUTLASS_PRAGMA_UNROLL
            for (int n = 0; n < MmaIterations::kColumn; ++n) {

                int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n);

                int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine;
                if (AccumulatorsInRowMajor) {  // matrix B is reordered
                    mma(ptr_D[n_serpentine + m * MmaIterations::kColumn],
                        ptr_A[m],
                        ptr_B[n_serpentine_offsetB],
                        ptr_D[n_serpentine + m * MmaIterations::kColumn]);
                }
                else {
                    mma(ptr_D[m + n_serpentine * MmaIterations::kRow],
                        ptr_A[m],
                        ptr_B[n_serpentine_offsetB],
                        ptr_D[m + n_serpentine * MmaIterations::kRow]);
                }
            }
        }
#else
        assert(0);
#endif
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace warp
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
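Note on the loop structure above: on SM80+ the inner loop walks n in a serpentine order so that consecutive mma calls reuse the same ptr_A[m] registers, while kExpansionFactor strides through the pre-expanded B fragment. The short host-side sketch below (hypothetical, not part of the imported file; kRow and kColumn are made-up stand-ins for MmaIterations::kRow/kColumn) just prints that visitation order so the reuse pattern is visible.

// Host-side illustration of the SM80+ serpentine visitation order.
#include <cstdio>

int main() {
    const int kRow = 2, kColumn = 4;  // stand-ins for MmaIterations::kRow / kColumn
    for (int m = 0; m < kRow; ++m) {
        for (int n = 0; n < kColumn; ++n) {
            // Even rows sweep n forward, odd rows sweep it backward.
            int n_serpentine = (m % 2) ? (kColumn - 1 - n) : n;
            std::printf("mma at (m=%d, n=%d)\n", m, n_serpentine);
        }
    }
    return 0;
}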
cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h
ADDED
@@ -0,0 +1,469 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores.
*/

#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/array.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"

#include "cutlass/arch/arch.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/gemm/gemm.h"

#include "cutlass/layout/matrix.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor.h"

#include "cutlass/functional.h"
#include "cutlass/platform/platform.h"


////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace warp {

////////////////////////////////////////////////////////////////////////////////

template<
    /// Matrix multiply operator
    typename MmaOperator_,
    /// Size of the matrix to load (concept: MatrixShape)
    typename Shape_,
    /// Operand identity
    Operand Operand,
    /// Data type of Scale elements
    typename Element_,
    /// Layout of operand
    typename Layout_,
    /// Number of threads participating in one matrix operation
    int Threads,
    ///
    typename Enable = void>
class MmaTensorOpDequantizer;

////////////////////////////////////////////////////////////////////////////////
// Bfloat specialization for Ampere
template<
    /// Underlying matrix multiply operator (concept: MmaTensorOp)
    typename MmaOperator_,
    /// Shape of the warp level matrix multiply (concept: GemmShape)
    typename Shape_>
class MmaTensorOpDequantizer<
    MmaOperator_,
    Shape_,
    Operand::kB,
    bfloat16_t,
    layout::RowMajor,
    32,
    typename platform::enable_if<
        MmaOperator_::ArchTag::kMinComputeCapability >= 80
        && platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type> {

public:
    /// Mma Operator
    using MmaOperator = MmaOperator_;

    // The architecture specific mma operator being used
    using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;

    // Mma Instruction Shape
    using InstructionShape = typename ArchMmaOperator::Shape;

    // This is the ratio of the load instruction vs the compute instruction.
    static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK;

    /// Type of the scales
    using ElementScale = bfloat16_t;

    /// Fragment to hold B data before Mma
    using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;

    // Fragment to hold scale data to apply to B before mma
    // We need 1 fp16 per matrix iteration in the N dimension
    static constexpr int kColsPerMmaPerThread = 1;
    using FragmentScale = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;

    /// Warp mma shape
    using Shape = Shape_;

    /// Layout of the scales in shared memory
    using Layout = layout::RowMajor;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<ElementScale, Layout>;

    CUTLASS_DEVICE
    MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
    {
        const int warp_offset = warp_idx_n * Shape::kN;
        const int quad = lane_idx / 4;
        const int thread_offset = warp_offset + quad;
        pointer_ = smem_scales.data() + thread_offset;
    }

    CUTLASS_DEVICE
    void load(FragmentScale& scale_frag)
    {

        CUTLASS_PRAGMA_UNROLL
        for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) {
            scale_frag[mma_n_iter] = pointer_[mma_n_iter * InstructionShape::kN];
        }
    }

    CUTLASS_DEVICE
    void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
    {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
        using _MmaOperandB = typename ArchMmaOperator::FragmentB;
        using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
        static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
                          == FragmentDequantizedOperand::kElements,
                      "");

        const __nv_bfloat16* scale_ptr = reinterpret_cast<const __nv_bfloat16*>(&scale_frag);

        ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
        CUTLASS_PRAGMA_UNROLL
        for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) {
            static_assert(ExpandedMmaOperandB::kElements % 2 == 0, "");

            __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]);
            __nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]);
            CUTLASS_PRAGMA_UNROLL
            for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) {
                operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2);
            }
        }
#else
        // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should
        // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid
        // numerous conversion instructions in GEMM main loop.
        arch::device_breakpoint();
#endif
    }

private:
    ElementScale const* pointer_;
};

////////////////////////////////////////////////////////////////////////////////

// Specialization for Turing & Ampere
template<
    /// Underlying matrix multiply operator (concept: MmaTensorOp)
    typename MmaOperator_,
    /// Shape of the warp level matrix multiply (concept: GemmShape)
    typename Shape_>
class MmaTensorOpDequantizer<
    MmaOperator_,
    Shape_,
    Operand::kB,
    half_t,
    layout::RowMajor,
    32,
    typename platform::enable_if<
        MmaOperator_::ArchTag::kMinComputeCapability >= 75
        && platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type> {

public:
    /// Mma Operator
    using MmaOperator = MmaOperator_;

    // The architecture specific mma operator being used
    using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;

    // Mma Instruction Shape
    using InstructionShape = typename ArchMmaOperator::Shape;

    // This is the ratio of the load instruction vs the compute instruction.
    static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK;

    /// Type of the scales
    using ElementScale = half_t;

    /// Fragment to hold B data before Mma
    using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;

    // Fragment to hold scale data to apply to B before mma
    // We need 1 fp16 per matrix iteration in the N dimension
    static constexpr int kColsPerMmaPerThread = 1;
    using FragmentScale = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;

    /// Warp mma shape
    using Shape = Shape_;

    /// Layout of the scales in shared memory
    using Layout = layout::RowMajor;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<ElementScale, Layout>;

    CUTLASS_DEVICE
    MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
    {
        const int warp_offset = warp_idx_n * Shape::kN;
        const int quad = lane_idx / 4;
        const int thread_offset = warp_offset + quad;
        pointer_ = smem_scales.data() + thread_offset;
    }

    CUTLASS_DEVICE
    void load(FragmentScale& scale_frag)
    {

        CUTLASS_PRAGMA_UNROLL
        for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) {
            scale_frag[mma_n_iter] = pointer_[mma_n_iter * InstructionShape::kN];
        }
    }

    CUTLASS_DEVICE
    void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
    {
        using _MmaOperandB = typename ArchMmaOperator::FragmentB;
        using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
        static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
                          == FragmentDequantizedOperand::kElements,
                      "");

        multiplies<ExpandedMmaOperandB> mul_op;

        ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
        CUTLASS_PRAGMA_UNROLL
        for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) {
            operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]);
        }
    }

private:
    ElementScale const* pointer_;
};

////////////////////////////////////////////////////////////////////////////////

// Specialization for Volta A x RowMajor B tensorOp, for 32x32x4 interleaved gemm
template<
    /// Underlying matrix multiply operator (concept: MmaTensorOp)
    typename MmaOperator_,
    /// Shape of the warp level matrix multiply (concept: GemmShape)
    typename Shape_>
class MmaTensorOpDequantizer<
    MmaOperator_,
    Shape_,
    Operand::kB,
    half_t,
    layout::RowMajor,
    32,
    typename platform::enable_if<
        platform::is_same<typename MmaOperator_::ArchTag, arch::Sm70>::value
        && platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::RowMajor>::value>::type> {

public:
    static_assert(platform::is_same<typename MmaOperator_::InterleavedTileShape, GemmShape<32, 32, 4>>::value, "");

    /// Mma Operator
    using MmaOperator = MmaOperator_;

    // The architecture specific mma operator being used
    using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;

    // Mma Instruction Shape
    using InstructionShape = typename ArchMmaOperator::Shape;

    /// Type of the scales
    using ElementScale = half_t;

    /// Fragment to hold B data before Mma
    using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;

    /// Warp mma shape
    using Shape = Shape_;

    // Fragment to hold scale data to apply to B before mma
    // Each 32x32x4 matmul uses 8 elements from B.
    static constexpr int ColsPerMmaTile = 32;
    static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile;
    using FragmentScale = Array<ElementScale, TileNIterations * 8>;
    using AccessType = Array<ElementScale, 8>;

    /// Layout of the scales in shared memory
    using Layout = layout::RowMajor;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<ElementScale, Layout>;

    CUTLASS_DEVICE
    MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
    {
        const int warp_offset = warp_idx_n * Shape::kN;
        const int base_col = lane_idx & 0xF8;
        const int thread_offset = warp_offset + base_col;
        pointer_ = smem_scales.data() + thread_offset;
    }

    CUTLASS_DEVICE
    void load(FragmentScale& scale_frag)
    {
        AccessType* scale_frag_ptr = reinterpret_cast<AccessType*>(&scale_frag);

        CUTLASS_PRAGMA_UNROLL
        for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) {
            // We jump by 32 here since volta does <32x32x4> super mmas inside a warp.
            scale_frag_ptr[tile_iter] = *reinterpret_cast<AccessType const*>(pointer_ + ColsPerMmaTile * tile_iter);
        }
    }

    CUTLASS_DEVICE
    void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
    {
        static_assert(FragmentScale::kElements == FragmentDequantizedOperand::kElements, "");

        multiplies<FragmentDequantizedOperand> mul_op;
        operand_frag = mul_op(operand_frag, scale_frag);
    }

private:
    ElementScale const* pointer_;
};

////////////////////////////////////////////////////////////////////////////////

// Specialization for Volta A x ColumnMajor B tensorOp, for 32x32x4 interleaved gemm
template<
    /// Underlying matrix multiply operator (concept: MmaTensorOp)
    typename MmaOperator_,
    /// Shape of the warp level matrix multiply (concept: GemmShape)
    typename Shape_>
class MmaTensorOpDequantizer<
    MmaOperator_,
    Shape_,
    Operand::kB,
    half_t,
    layout::RowMajor,
    32,
    typename platform::enable_if<
        platform::is_same<typename MmaOperator_::ArchTag, arch::Sm70>::value
        && platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type> {

public:
    static_assert(platform::is_same<typename MmaOperator_::InterleavedTileShape, GemmShape<32, 32, 4>>::value, "");

    /// Mma Operator
    using MmaOperator = MmaOperator_;

    // The architecture specific mma operator being used
    using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;

    // Mma Instruction Shape
    using InstructionShape = typename ArchMmaOperator::Shape;

    /// Type of the scales
    using ElementScale = half_t;

    /// Fragment to hold B data before Mma
    using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;

    /// Warp mma shape
    using Shape = Shape_;

    // Fragment to hold scale data to apply to B before mma
    // Each 32x32x4 matmul uses 8 elements from B.
    static constexpr int ColsPerMmaTile = 32;
    static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile;
    using FragmentScale = Array<ElementScale, TileNIterations * 2>;

    /// Layout of the scales in shared memory
    using Layout = layout::RowMajor;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<ElementScale, Layout>;

    CUTLASS_DEVICE
    MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
    {
        const int warp_offset = warp_idx_n * Shape::kN;
        const int base_col = lane_idx & 0xF8 + lane_idx % 4;
        const int thread_offset = warp_offset + base_col;
        pointer_ = smem_scales.data() + thread_offset;
    }

    CUTLASS_DEVICE
    void load(FragmentScale& scale_frag)
    {
        CUTLASS_PRAGMA_UNROLL
        for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) {
            // We jump by 32 here since volta does <32x32x4> super mmas inside a warp.
            // For col major B, each thread will jump 4 cols to get its next value inside
            // of the super mma.
            CUTLASS_PRAGMA_UNROLL
            for (int mma_iter = 0; mma_iter < 2; ++mma_iter) {
                scale_frag[tile_iter * 2 + mma_iter] = pointer_[ColsPerMmaTile * tile_iter + 4 * mma_iter];
            }
        }
    }

    CUTLASS_DEVICE
    void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
    {
        using MmaOperandB = typename ArchMmaOperator::FragmentB;
        static constexpr int total_n_mmas = 2 * TileNIterations;
        static_assert(MmaOperandB::kElements * total_n_mmas == FragmentDequantizedOperand::kElements, "");

        multiplies<MmaOperandB> mul_op;

        MmaOperandB* operand_frag_ptr = reinterpret_cast<MmaOperandB*>(&operand_frag);
        CUTLASS_PRAGMA_UNROLL
        for (int mma_n_iter = 0; mma_n_iter < total_n_mmas; ++mma_n_iter) {
            operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]);
        }
    }

private:
    ElementScale const* pointer_;
};

////////////////////////////////////////////////////////////////////////////////

} // namespace warp
} // namespace gemm
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
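The dequantize() methods above all reduce to the same operation: every element of a column's B fragment is multiplied by that column's per-channel scale loaded from shared memory. The plain-float host sketch below (hypothetical, not part of the imported file; kColumns and kElemsPerColumn are made-up stand-ins for MmaIterations::kColumn and the expanded B fragment size) shows that effect without any warp-level machinery.

// Host-side sketch of per-column scaling as done by MmaTensorOpDequantizer::dequantize().
#include <array>
#include <cstdio>

int main() {
    constexpr int kColumns = 2;         // stand-in for MmaIterations::kColumn
    constexpr int kElemsPerColumn = 4;  // stand-in for the expanded B fragment size
    std::array<float, kColumns * kElemsPerColumn> operand_frag{};
    std::array<float, kColumns> scale_frag{0.5f, 2.0f};  // one scale per output column group

    for (int i = 0; i < kColumns * kElemsPerColumn; ++i)
        operand_frag[i] = float(i);  // pretend these are converted (biased-int -> fp16) weights

    // Apply each column's scale to every element of that column's fragment.
    for (int n = 0; n < kColumns; ++n)
        for (int e = 0; e < kElemsPerColumn; ++e)
            operand_frag[n * kElemsPerColumn + e] *= scale_frag[n];

    for (float v : operand_frag)
        std::printf("%g ", v);
    std::printf("\n");
    return 0;
}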
cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
ADDED
@@ -0,0 +1,429 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*!
    \file
    \brief Boost-like numeric conversion operator for int8 and CUTLASS int4b_t interleaved in a register
*/

#pragma once

#include "cutlass/arch/arch.h"
#include "cutlass/array.h"
#include "cutlass/half.h"
#include "cutlass/numeric_types.h"

namespace cutlass {

// This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low
// bits and the odd elements are in the high bits of the register. In addition, it assumes elements were originally
// signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned.
// This converter will uninterleave the data and subtract the bias while converting to the result type.
template<typename T, typename S, int N>
struct FastInterleavedAndBiasedNumericArrayConverter {
};

template<>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, 4> {
    using result_type = Array<half_t, 4>;
    using source_type = Array<uint8_t, 4>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        result_type result;

        uint32_t* h = reinterpret_cast<uint32_t*>(&result);
        uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);

        static constexpr uint32_t mask_for_elt_01 = 0x5250;
        static constexpr uint32_t mask_for_elt_23 = 0x5351;
        static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
        asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01));
        asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23));

        // Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed integer as fp16.
        static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM));
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM));

        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};

template<int N>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, N> {
    static constexpr int VEC_WIDTH = 4;
    static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");

    using result_type = Array<half_t, N>;
    using source_type = Array<uint8_t, N>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        using scalar_result_type = typename result_type::Element;
        using scalar_source_type = typename source_type::Element;
        FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
            convert_vector_;

        result_type result;
        using vec_result = Array<scalar_result_type, VEC_WIDTH>;
        using vec_source = Array<scalar_source_type, VEC_WIDTH>;

        vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
        vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < N / VEC_WIDTH; ++i) {
            result_ptr[i] = convert_vector_(source_ptr[i]);
        }

        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};

template<>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint8_t, 4> {
    using result_type = Array<bfloat16_t, 4>;
    using source_type = Array<uint8_t, 4>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        result_type result;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))

        uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&result);
        uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);

        static constexpr uint32_t fp32_base = 0x4B000000;
        float fp32_intermediates[4];

        // Construct FP32s, bfloat does not have enough mantissa for IADD trick
        uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
        fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650);
        fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7652);
        fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7651);
        fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653);

        // Subtract out fp32_base + 128 to make the unsigned integer signed.
        CUTLASS_PRAGMA_UNROLL
        for (int ii = 0; ii < 4; ++ii) {
            fp32_intermediates[ii] -= 8388736.f;
        }

        // Truncate the fp32 representation and pack up as bfloat16s.
        CUTLASS_PRAGMA_UNROLL
        for (int ii = 0; ii < 2; ++ii) {
            bf16_result_ptr[ii] =
                __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632);
        }
#else
        // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use
        // HMMA on older hardware, they should convert directly to FP16 using FP16 converters.
        result.clear();  // Suppress compiler warning
        arch::device_breakpoint();
#endif
        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};

template<int N>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint8_t, N> {
    static constexpr int VEC_WIDTH = 4;
    static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");

    using result_type = Array<bfloat16_t, N>;
    using source_type = Array<uint8_t, N>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        using scalar_result_type = typename result_type::Element;
        using scalar_source_type = typename source_type::Element;
        FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
            convert_vector_;

        result_type result;
        using vec_result = Array<scalar_result_type, VEC_WIDTH>;
        using vec_source = Array<scalar_source_type, VEC_WIDTH>;

        vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
        vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < N / VEC_WIDTH; ++i) {
            result_ptr[i] = convert_vector_(source_ptr[i]);
        }

        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};

template<>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint4b_t, 8> {
    using result_type = Array<half_t, 8>;
    using source_type = Array<uint4b_t, 8>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        result_type result;

        uint32_t* h = reinterpret_cast<uint32_t*>(&result);
        uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);

        // First, we extract the i4s and construct an intermediate fp16 number.
        static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
        static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
        static constexpr uint32_t TOP_MASK = 0x00f000f0;
        static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;

        // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
        // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
        // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
        // elt_67 to fp16 without having to shift them to the bottom bits beforehand.

        // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
        // immediately before required.
        const uint32_t top_i4s = i4s >> 8;
        // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[0])
                     : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
        // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[1])
                     : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
        // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[2])
                     : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
        // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[3])
                     : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));

        // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
        // half2 ctor. In this case, I chose performance reliability over code readability.

        // This is the half2 {1032, 1032} represented as an integer.
        static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
        // This is the half2 {1 / 16, 1 / 16} represented as an integer.
        static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
        // This is the half2 {-72, -72} represented as an integer.
        static constexpr uint32_t NEG_72 = 0xd480d480;

        // Finally, we construct the output numbers.
        // Convert elt_01
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
        // Convert elt_23
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72));
        // Convert elt_45
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
        // Convert elt_67
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72));

        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};

template<int N>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint4b_t, N> {
    static constexpr int VEC_WIDTH = 8;
    static_assert(!(N % VEC_WIDTH), "N must be multiple of 8.");

    using result_type = Array<half_t, N>;
    using source_type = Array<uint4b_t, N>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        using scalar_result_type = typename result_type::Element;
        using scalar_source_type = typename source_type::Element;
        FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
            convert_vector_;

        result_type result;
        using vec_result = Array<scalar_result_type, VEC_WIDTH>;
        using vec_source = Array<scalar_source_type, VEC_WIDTH>;

        vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
        vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < N / VEC_WIDTH; ++i) {
            result_ptr[i] = convert_vector_(source_ptr[i]);
        }

        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};

template<>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, 8> {
    using result_type = Array<bfloat16_t, 8>;
    using source_type = Array<uint4b_t, 8>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        result_type result;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))

        uint32_t* h = reinterpret_cast<uint32_t*>(&result);
        uint32_t const source_i4s = reinterpret_cast<uint32_t const&>(source);

        // First, we extract the i4s and construct an intermediate fp16 number.
        static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
        static constexpr uint32_t MASK = 0x000f000f;
        static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;

        // We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop.
        // No shift needed for first item.
        uint32_t i4s = source_i4s;
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[0])
                     : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
        CUTLASS_PRAGMA_UNROLL
        for (int ii = 1; ii < result_type::kElements / 2; ++ii) {
            i4s >>= sizeof_bits<typename source_type::Element>::value;
            // (i4s & 0x000f000f) | 0x43004300
            asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                         : "=r"(h[ii])
                         : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
        }

        // This is the BF16 {-136, -136} represented as an integer.
        static constexpr uint32_t BF16_BIAS = 0xC308C308;
        static constexpr uint32_t BF16_ONE = 0x3F803F80;

        // Finally, we construct the output numbers.
        CUTLASS_PRAGMA_UNROLL
        for (int ii = 0; ii < result_type::kElements / 2; ++ii) {
            // Since this section is for Ampere+, we use bf16 fma to do the bias subtraction
            asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS));
        }
#else
        // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use
        // HMMA on older hardware, they should convert directly to FP16 using FP16 converters.
        arch::device_breakpoint();
        result.clear();  // Suppress compiler warning.
#endif
        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};

template<int N>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, N> {
    static constexpr int VEC_WIDTH = 8;
    static_assert(!(N % VEC_WIDTH), "N must be multiple of 8.");

    using result_type = Array<bfloat16_t, N>;
    using source_type = Array<uint4b_t, N>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        using scalar_result_type = typename result_type::Element;
        using scalar_source_type = typename source_type::Element;
        FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
            convert_vector_;

        result_type result;
        using vec_result = Array<scalar_result_type, VEC_WIDTH>;
        using vec_source = Array<scalar_source_type, VEC_WIDTH>;

        vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
        vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < N / VEC_WIDTH; ++i) {
            result_ptr[i] = convert_vector_(source_ptr[i]);
        }

        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
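The bfloat16 path above relies on a floating-point trick: OR-ing a byte into the mantissa of 2^23 (bit pattern 0x4B000000) produces the float 8388608 + byte, so subtracting 8388736 (that is, 8388608 + 128) recovers byte - 128, the original signed value before the +128 bias. The host-side check below (hypothetical, not part of the imported file) verifies that arithmetic with ordinary floats.

// Host-side check of the 0x4B000000 mantissa-packing trick used for uint8 -> bfloat16.
#include <cstdint>
#include <cstring>
#include <cstdio>

int main() {
    for (uint32_t byte : {0u, 1u, 127u, 128u, 255u}) {
        uint32_t bits = 0x4B000000u | byte;  // pack the biased byte into the mantissa of 2^23
        float f;
        std::memcpy(&f, &bits, sizeof(f));   // reinterpret the bits as a float: 8388608 + byte
        f -= 8388736.f;                      // subtract 2^23 + 128 to undo the packing and the bias
        std::printf("biased %3u -> signed %4.0f\n", byte, f);
    }
    return 0;
}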
cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h
ADDED
@@ -0,0 +1,61 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines new layouts needed for MoE
*/
#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/pitch_linear_coord.h"

namespace cutlass {
namespace layout {

template<int RowsPerTile, int ColumnsInterleaved>
class ColumnMajorTileInterleave {
    static constexpr int kRowsPerTile        = RowsPerTile;
    static constexpr int kColumnsInterleaved = ColumnsInterleaved;
};

template<class T>
struct IsColumnMajorTileInterleave {
    static constexpr bool value = false;
};

template<int U, int V>
struct IsColumnMajorTileInterleave<ColumnMajorTileInterleave<U, V>> {
    static constexpr bool value = true;
};

}  // namespace layout
}  // namespace cutlass
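As a quick illustration of how this trait is consumed downstream (not part of this file; the negative-test type below is hypothetical), callers detect the interleaved B layout at compile time via IsColumnMajorTileInterleave:

// Illustrative only: compile-time detection of the tile-interleaved layout.
#include "cutlass_extensions/tile_interleaved_layout.h"

struct NotInterleavedLayout {};  // hypothetical placeholder for a plain layout type

static_assert(cutlass::layout::IsColumnMajorTileInterleave<
                  cutlass::layout::ColumnMajorTileInterleave<4, 2>>::value,
              "tile-interleaved layout should be detected");
static_assert(!cutlass::layout::IsColumnMajorTileInterleave<NotInterleavedLayout>::value,
              "non-interleaved layouts should not be detected");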
cutlass_kernels/cutlass_heuristic.cu
ADDED
|
@@ -0,0 +1,208 @@
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "cutlass_heuristic.h"
#include "cutlass/gemm/gemm.h"
#include <cuda_runtime_api.h>

#include <vector>
#include <stdexcept>

namespace fastertransformer {

struct TileShape {
    int m;
    int n;
};

TileShape get_cta_shape_for_config(CutlassTileConfig tile_config)
{
    switch (tile_config) {
        case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
            return TileShape{32, 128};
        case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64:
        case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
            return TileShape{64, 128};
        case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8:
        case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64:
        case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
            return TileShape{128, 128};
        default:
            throw std::runtime_error("[FT Error][get_grid_shape_for_config] Invalid config");
    }
}

bool is_valid_split_k_factor(const int64_t   m,
                             const int64_t   n,
                             const int64_t   k,
                             const TileShape tile_shape,
                             const int       split_k_factor,
                             const size_t    workspace_bytes,
                             const bool      is_weight_only)
{
    // All tile sizes have a k_tile of 64.
    static constexpr int k_tile = 64;

    // For weight-only quant, we need k and k_elements_per_split to be a multiple of cta_k
    if (is_weight_only) {
        if ((k % k_tile) != 0) {
            return false;
        }

        if ((k % split_k_factor) != 0) {
            return false;
        }

        const int k_elements_per_split = k / split_k_factor;
        if ((k_elements_per_split % k_tile) != 0) {
            return false;
        }
    }

    // Check that the workspace has sufficient space for this split-k factor
    const int    ctas_in_m_dim     = (m + tile_shape.m - 1) / tile_shape.m;
    const int    ctas_in_n_dim     = (n + tile_shape.n - 1) / tile_shape.n;
    const size_t required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim;

    if (required_ws_bytes > workspace_bytes) {
        return false;
    }

    return true;
}

std::vector<CutlassTileConfig> get_candidate_tiles(const bool is_weight_only, const bool simt_configs_only)
{
    std::vector<CutlassTileConfig> simt_configs{CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8};

    std::vector<CutlassTileConfig> square_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
                                                  CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64,
                                                  CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64};

    std::vector<CutlassTileConfig> quant_B_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
                                                   CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
                                                   CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64};

    const std::vector<CutlassTileConfig> allowed_configs = is_weight_only ? quant_B_configs : square_configs;
    return simt_configs_only ? simt_configs : allowed_configs;
}

std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only)
{
    std::vector<CutlassTileConfig> tiles = get_candidate_tiles(is_weight_only, simt_configs_only);

    std::vector<CutlassGemmConfig> candidate_configs;
    const int                      min_stages = 2;
    const int                      max_stages = sm >= 80 ? 4 : 2;

    for (const auto& tile_config : tiles) {
        for (int stages = min_stages; stages <= max_stages; ++stages) {
            CutlassGemmConfig config{tile_config, SplitKStyle::NO_SPLIT_K, 1, stages};
            candidate_configs.push_back(config);
        }
    }

    return candidate_configs;
}

CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<CutlassGemmConfig>& candidate_configs,
                                                        const std::vector<int>&               occupancies,
                                                        const int64_t                         m,
                                                        const int64_t                         n,
                                                        const int64_t                         k,
                                                        const int64_t                         num_experts,
                                                        const int                             split_k_limit,
                                                        const size_t                          workspace_bytes,
                                                        const int                             multi_processor_count,
                                                        const int                             is_weight_only)
{
    if (occupancies.size() != candidate_configs.size()) {
        throw std::runtime_error("[FT Error][estimate_best_config_from_occupancies] occupancies and "
                                 "candidate configs vectors must have equal length.");
    }

    CutlassGemmConfig best_config;
    // Score will be [0, 1]. The objective is to minimize this score.
    // It represents the fraction of SM resources unused in the last wave.
    float config_score   = 1.0f;
    int   config_waves   = INT_MAX;
    int   current_m_tile = 0;

    const int max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit;
    for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {
        CutlassGemmConfig candidate_config = candidate_configs[ii];
        TileShape         tile_shape       = get_cta_shape_for_config(candidate_config.tile_config);
        int               occupancy        = occupancies[ii];

        if (occupancy == 0) {
            continue;
        }

        // Keep small tile sizes when possible.
        if (best_config.tile_config != CutlassTileConfig::ChooseWithHeuristic && m < current_m_tile
            && current_m_tile < tile_shape.m) {
            continue;
        }

        const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m;
        const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n;

        for (int split_k_factor = 1; split_k_factor <= max_split_k; ++split_k_factor) {
            if (is_valid_split_k_factor(m, n, k, tile_shape, split_k_factor, workspace_bytes, is_weight_only)) {
                const int ctas_per_wave    = occupancy * multi_processor_count;
                const int ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor;

                const int   num_waves_total      = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave;
                const float num_waves_fractional = ctas_for_problem / float(ctas_per_wave);
                const float current_score        = float(num_waves_total) - num_waves_fractional;

                const float score_slack = 0.1f;
                if (current_score < config_score
                    || ((config_waves > num_waves_total) && (current_score < config_score + score_slack))) {
                    config_score = current_score;
                    config_waves = num_waves_total;
                    SplitKStyle split_style =
                        split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K;
                    best_config = CutlassGemmConfig{
                        candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages};
                    current_m_tile = tile_shape.m;
                }
                else if (current_score == config_score
                         && (best_config.stages < candidate_config.stages || split_k_factor < best_config.split_k_factor
                             || current_m_tile < tile_shape.m)) {
                    // Prefer deeper pipeline or smaller split-k
                    SplitKStyle split_style =
                        split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K;
                    best_config = CutlassGemmConfig{
                        candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages};
                    current_m_tile = tile_shape.m;
                    config_waves   = num_waves_total;
                }
            }
        }
    }

    if (best_config.tile_config == CutlassTileConfig::ChooseWithHeuristic) {
        throw std::runtime_error("[FT Error] Heuristic failed to find a valid config.");
    }

    return best_config;
}

}  // namespace fastertransformer
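To make the wave-quantization score concrete, here is a small standalone sketch (the problem numbers are made up) of the quantity the loop above minimizes: the fraction of the last wave of CTAs that sits idle.

// Illustrative only: replicates the score arithmetic with example numbers.
#include <cstdio>

int main()
{
    // Hypothetical problem: 120 CTAs needed, occupancy 2 CTAs/SM, 32 SMs.
    const int   ctas_for_problem = 120;
    const int   ctas_per_wave    = 2 * 32;  // occupancy * multi_processor_count
    const int   num_waves_total  = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave;  // 2
    const float num_waves_frac   = ctas_for_problem / float(ctas_per_wave);                 // 1.875
    const float score            = float(num_waves_total) - num_waves_frac;                 // 0.125
    // A score near 0 means the last wave is nearly full; near 1 means it is mostly idle.
    std::printf("waves=%d score=%.3f\n", num_waves_total, score);
    return 0;
}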
cutlass_kernels/cutlass_heuristic.h
ADDED
|
@@ -0,0 +1,39 @@
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <vector>
#include <cstddef>
#include <cstdint>
#include "cutlass_extensions/ft_gemm_configs.h"

namespace fastertransformer {

std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only);

CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<CutlassGemmConfig>& candidate_configs,
                                                        const std::vector<int>&               occupancies,
                                                        const int64_t                         m,
                                                        const int64_t                         n,
                                                        const int64_t                         k,
                                                        const int64_t                         num_experts,
                                                        const int                             split_k_limit,
                                                        const size_t                          workspace_bytes,
                                                        const int                             multi_processor_count,
                                                        const int                             is_weight_only);

}  // namespace fastertransformer
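A minimal sketch of how a caller might combine these two entry points (the occupancy values and problem sizes below are placeholders; in the real runner the occupancies are measured per instantiated kernel rather than assumed):

// Illustrative only: enumerate candidates, then let the heuristic pick one.
#include "cutlass_heuristic.h"

fastertransformer::CutlassGemmConfig pick_config_example(int sm, int multi_processor_count)
{
    using namespace fastertransformer;

    // Weight-only tensor-op candidates, 2..4 stages on SM80+.
    std::vector<CutlassGemmConfig> candidates =
        get_candidate_configs(sm, /*is_weight_only=*/true, /*simt_configs_only=*/false);

    // Assumed occupancy of 2 CTAs/SM for every candidate (normally measured).
    std::vector<int> occupancies(candidates.size(), 2);

    // Skinny fp16 x int GEMM, no MoE (num_experts == 1), split-k allowed up to 7.
    return estimate_best_config_from_occupancies(candidates, occupancies,
                                                 /*m=*/16, /*n=*/4096, /*k=*/4096,
                                                 /*num_experts=*/1, /*split_k_limit=*/7,
                                                 /*workspace_bytes=*/4 * 1024 * 1024,
                                                 multi_processor_count, /*is_weight_only=*/1);
}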
cutlass_kernels/cutlass_preprocessors.cc
ADDED
|
@@ -0,0 +1,703 @@
| 1 |
+
/*
|
| 2 |
+
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*/
|
| 16 |
+
#include "cutlass_preprocessors.h"
|
| 17 |
+
#include "cuda_utils.h"
|
| 18 |
+
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
|
| 19 |
+
|
| 20 |
+
#include <vector>
|
| 21 |
+
|
| 22 |
+
namespace fastertransformer {
|
| 23 |
+
|
| 24 |
+
int get_bits_in_quant_type(QuantType quant_type) {
|
| 25 |
+
switch (quant_type) {
|
| 26 |
+
case QuantType::INT8_WEIGHT_ONLY:
|
| 27 |
+
return 8;
|
| 28 |
+
case QuantType::PACKED_INT4_WEIGHT_ONLY:
|
| 29 |
+
return 4;
|
| 30 |
+
default:
|
| 31 |
+
return -1;
|
| 32 |
+
}
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
struct LayoutDetails {
|
| 36 |
+
enum class Layout {
|
| 37 |
+
UNKNOWN,
|
| 38 |
+
ROW_MAJOR,
|
| 39 |
+
COLUMN_MAJOR
|
| 40 |
+
};
|
| 41 |
+
|
| 42 |
+
Layout layoutB = Layout::UNKNOWN;
|
| 43 |
+
int rows_per_column_tile = 1;
|
| 44 |
+
int columns_interleaved = 1;
|
| 45 |
+
|
| 46 |
+
bool uses_imma_ldsm = false;
|
| 47 |
+
};
|
| 48 |
+
|
| 49 |
+
template<typename Layout>
|
| 50 |
+
struct getLayoutDetails {
|
| 51 |
+
};
|
| 52 |
+
|
| 53 |
+
template<>
|
| 54 |
+
struct getLayoutDetails<cutlass::layout::RowMajor> {
|
| 55 |
+
LayoutDetails operator()()
|
| 56 |
+
{
|
| 57 |
+
LayoutDetails layout_details;
|
| 58 |
+
layout_details.layoutB = LayoutDetails::Layout::ROW_MAJOR;
|
| 59 |
+
return layout_details;
|
| 60 |
+
}
|
| 61 |
+
};
|
| 62 |
+
|
| 63 |
+
template<>
|
| 64 |
+
struct getLayoutDetails<cutlass::layout::ColumnMajor> {
|
| 65 |
+
LayoutDetails operator()()
|
| 66 |
+
{
|
| 67 |
+
LayoutDetails layout_details;
|
| 68 |
+
layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR;
|
| 69 |
+
return layout_details;
|
| 70 |
+
}
|
| 71 |
+
};
|
| 72 |
+
|
| 73 |
+
template<int RowsPerTile, int ColumnsInterleaved>
|
| 74 |
+
struct getLayoutDetails<cutlass::layout::ColumnMajorTileInterleave<RowsPerTile, ColumnsInterleaved>> {
|
| 75 |
+
LayoutDetails operator()()
|
| 76 |
+
{
|
| 77 |
+
LayoutDetails layout_details;
|
| 78 |
+
layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR;
|
| 79 |
+
layout_details.rows_per_column_tile = RowsPerTile;
|
| 80 |
+
layout_details.columns_interleaved = ColumnsInterleaved;
|
| 81 |
+
return layout_details;
|
| 82 |
+
}
|
| 83 |
+
};
|
| 84 |
+
|
| 85 |
+
template<typename cutlassArch, typename TypeB>
|
| 86 |
+
LayoutDetails getLayoutDetailsForArchAndQuantType()
|
| 87 |
+
{
|
| 88 |
+
|
| 89 |
+
using CompileTraits = cutlass::gemm::kernel::LayoutDetailsB<TypeB, cutlassArch>;
|
| 90 |
+
using LayoutB = typename CompileTraits::Layout;
|
| 91 |
+
using MmaOperator = typename CompileTraits::Operator;
|
| 92 |
+
LayoutDetails details = getLayoutDetails<LayoutB>()();
|
| 93 |
+
details.uses_imma_ldsm = std::is_same<MmaOperator, cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA>::value;
|
| 94 |
+
return details;
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
template<typename cutlassArch>
|
| 98 |
+
LayoutDetails getLayoutDetailsForArch(QuantType quant_type)
|
| 99 |
+
{
|
| 100 |
+
LayoutDetails details;
|
| 101 |
+
if (quant_type == QuantType::INT8_WEIGHT_ONLY) {
|
| 102 |
+
details = getLayoutDetailsForArchAndQuantType<cutlassArch, uint8_t>();
|
| 103 |
+
}
|
| 104 |
+
else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) {
|
| 105 |
+
details = getLayoutDetailsForArchAndQuantType<cutlassArch, cutlass::uint4b_t>();
|
| 106 |
+
}
|
| 107 |
+
else {
|
| 108 |
+
FT_CHECK_WITH_INFO(false, "Unsupported quantization type");
|
| 109 |
+
}
|
| 110 |
+
return details;
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch)
|
| 114 |
+
{
|
| 115 |
+
if (arch >= 70 && arch < 75) {
|
| 116 |
+
return getLayoutDetailsForArch<cutlass::arch::Sm70>(quant_type);
|
| 117 |
+
}
|
| 118 |
+
else if (arch >= 75 && arch < 80) {
|
| 119 |
+
return getLayoutDetailsForArch<cutlass::arch::Sm75>(quant_type);
|
| 120 |
+
}
|
| 121 |
+
else if (arch >= 80 && arch < 90) {
|
| 122 |
+
return getLayoutDetailsForArch<cutlass::arch::Sm80>(quant_type);
|
| 123 |
+
}
|
| 124 |
+
else {
|
| 125 |
+
FT_CHECK_WITH_INFO(false, "Unsupported Arch");
|
| 126 |
+
return LayoutDetails();
|
| 127 |
+
}
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
// Permutes the rows of B for Turing and Ampere. Throws an error for other
|
| 131 |
+
// architectures. The data is permuted such that: For int8, each group of 16
|
| 132 |
+
// rows is permuted using the map below:
|
| 133 |
+
// 0 1 8 9 2 3 10 11 4 5 12 13 6 7 14 15
|
| 134 |
+
// For int4, each group of 32 rows is permuted using the map below:
|
| 135 |
+
// 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 4 5 12 13 20 21 28 29 6 7 14 15 22
|
| 136 |
+
// 23 30 31
|
| 137 |
+
void permute_B_rows_for_mixed_gemm(int8_t *permuted_quantized_tensor,
|
| 138 |
+
const int8_t *quantized_tensor,
|
| 139 |
+
const std::vector<size_t> &shape,
|
| 140 |
+
QuantType quant_type,
|
| 141 |
+
const int64_t arch_version) {
|
| 142 |
+
const size_t num_rows = shape[0];
|
| 143 |
+
const size_t num_cols = shape[1];
|
| 144 |
+
|
| 145 |
+
const int BITS_PER_ELT = get_bits_in_quant_type(quant_type);
|
| 146 |
+
const int K = 16 / BITS_PER_ELT;
|
| 147 |
+
const int ELTS_PER_REG = 32 / BITS_PER_ELT;
|
| 148 |
+
|
| 149 |
+
const uint32_t *input_byte_ptr =
|
| 150 |
+
reinterpret_cast<const uint32_t *>(quantized_tensor);
|
| 151 |
+
uint32_t *output_byte_ptr =
|
| 152 |
+
reinterpret_cast<uint32_t *>(permuted_quantized_tensor);
|
| 153 |
+
|
| 154 |
+
int MMA_SHAPE_N = 8;
|
| 155 |
+
int B_ROWS_PER_MMA = 8 * K;
|
| 156 |
+
const int elts_in_int32 = 32 / BITS_PER_ELT;
|
| 157 |
+
|
| 158 |
+
const int num_vec_cols = num_cols / elts_in_int32;
|
| 159 |
+
|
| 160 |
+
FT_CHECK_WITH_INFO(arch_version >= 75,
|
| 161 |
+
"Unsupported Arch. Pre-volta not supported. Column "
|
| 162 |
+
"interleave not needed on Volta.");
|
| 163 |
+
|
| 164 |
+
FT_CHECK_WITH_INFO(num_rows % B_ROWS_PER_MMA == 0,
|
| 165 |
+
fmtstr("Invalid shape for quantized tensor. Number of "
|
| 166 |
+
"rows of quantized matrix must be a multiple of %d",
|
| 167 |
+
B_ROWS_PER_MMA));
|
| 168 |
+
|
| 169 |
+
FT_CHECK_WITH_INFO(
|
| 170 |
+
num_cols % MMA_SHAPE_N == 0,
|
| 171 |
+
fmtstr("Invalid shape for quantized tensor. On turing/Ampere, the number "
|
| 172 |
+
"of cols must be a multiple of %d.",
|
| 173 |
+
MMA_SHAPE_N));
|
| 174 |
+
|
| 175 |
+
// The code is written as below so it works for both int8
|
| 176 |
+
// and packed int4.
|
| 177 |
+
for (size_t base_row = 0; base_row < num_rows; base_row += B_ROWS_PER_MMA) {
|
| 178 |
+
for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) {
|
| 179 |
+
|
| 180 |
+
for (int write_col = 0; write_col < num_vec_cols; ++write_col) {
|
| 181 |
+
const int write_row = base_row + tile_row;
|
| 182 |
+
const int tile_read_row = 8 * (((tile_row % ELTS_PER_REG) / 2)) +
|
| 183 |
+
tile_row % 2 + 2 * (tile_row / ELTS_PER_REG);
|
| 184 |
+
const int read_row = base_row + tile_read_row;
|
| 185 |
+
const int read_col = write_col;
|
| 186 |
+
|
| 187 |
+
const int64_t read_offset = int64_t(read_row) * num_vec_cols + read_col;
|
| 188 |
+
const int64_t write_offset =
|
| 189 |
+
int64_t(write_row) * num_vec_cols + write_col;
|
| 190 |
+
|
| 191 |
+
output_byte_ptr[write_offset] = input_byte_ptr[read_offset];
|
| 192 |
+
}
|
| 193 |
+
}
|
| 194 |
+
}
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
// We need to use this transpose to correctly handle packed int4 and int8 data
|
| 198 |
+
// The reason this code is relatively complex is that the "trivial" loops took a
|
| 199 |
+
// substantial amount of time to transpose leading to long preprocessing times.
|
| 200 |
+
// This seemed to be a big issue for relatively large models.
|
| 201 |
+
template <QuantType quant_type>
|
| 202 |
+
void subbyte_transpose_impl(int8_t *transposed_quantized_tensor,
|
| 203 |
+
const int8_t *quantized_tensor,
|
| 204 |
+
const std::vector<size_t> &shape) {
|
| 205 |
+
const int bits_per_elt = get_bits_in_quant_type(quant_type);
|
| 206 |
+
const size_t num_rows = shape[0];
|
| 207 |
+
const size_t num_cols = shape[1];
|
| 208 |
+
|
| 209 |
+
const size_t col_bytes = num_cols * bits_per_elt / 8;
|
| 210 |
+
const size_t col_bytes_trans = num_rows * bits_per_elt / 8;
|
| 211 |
+
|
| 212 |
+
const uint8_t *input_byte_ptr =
|
| 213 |
+
reinterpret_cast<const uint8_t *>(quantized_tensor);
|
| 214 |
+
uint8_t *output_byte_ptr =
|
| 215 |
+
reinterpret_cast<uint8_t *>(transposed_quantized_tensor);
|
| 216 |
+
|
| 217 |
+
static constexpr int ELTS_PER_BYTE =
|
| 218 |
+
quant_type == QuantType::INT8_WEIGHT_ONLY ? 1 : 2;
|
| 219 |
+
|
| 220 |
+
static constexpr int M_TILE_L1 = 64;
|
| 221 |
+
static constexpr int N_TILE_L1 = M_TILE_L1 / ELTS_PER_BYTE;
|
| 222 |
+
uint8_t cache_buf[M_TILE_L1][N_TILE_L1];
|
| 223 |
+
|
| 224 |
+
static constexpr int VECTOR_WIDTH = std::min(32, N_TILE_L1);
|
| 225 |
+
|
| 226 |
+
// We assume the dims are a multiple of vector width. Our kernels only handle
|
| 227 |
+
// dims which are multiples of 64 for weight-only quantization. As a result,
|
| 228 |
+
// this seemed like a reasonable tradeoff because it allows GCC to emit vector
|
| 229 |
+
// instructions.
|
| 230 |
+
FT_CHECK_WITH_INFO(
|
| 231 |
+
!(col_bytes_trans % VECTOR_WIDTH) && !(col_bytes % VECTOR_WIDTH),
|
| 232 |
+
fmtstr("Number of bytes for rows and cols must be a multiple of %d. "
|
| 233 |
+
"However, num_rows_bytes = %ld and num_col_bytes = %d.",
|
| 234 |
+
VECTOR_WIDTH, col_bytes_trans, col_bytes));
|
| 235 |
+
|
| 236 |
+
for (size_t row_tile_start = 0; row_tile_start < num_rows;
|
| 237 |
+
row_tile_start += M_TILE_L1) {
|
| 238 |
+
for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes;
|
| 239 |
+
col_tile_start_byte += N_TILE_L1) {
|
| 240 |
+
|
| 241 |
+
const int row_limit = std::min(row_tile_start + M_TILE_L1, num_rows);
|
| 242 |
+
const int col_limit =
|
| 243 |
+
std::min(col_tile_start_byte + N_TILE_L1, col_bytes);
|
| 244 |
+
|
| 245 |
+
for (int ii = 0; ii < M_TILE_L1; ++ii) {
|
| 246 |
+
const int row = row_tile_start + ii;
|
| 247 |
+
|
| 248 |
+
for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) {
|
| 249 |
+
const int col = col_tile_start_byte + jj;
|
| 250 |
+
|
| 251 |
+
const size_t logical_src_offset = row * col_bytes + col;
|
| 252 |
+
|
| 253 |
+
if (row < row_limit && col < col_limit) {
|
| 254 |
+
for (int v = 0; v < VECTOR_WIDTH; ++v) {
|
| 255 |
+
cache_buf[ii][jj + v] = input_byte_ptr[logical_src_offset + v];
|
| 256 |
+
}
|
| 257 |
+
}
|
| 258 |
+
}
|
| 259 |
+
}
|
| 260 |
+
|
| 261 |
+
if (quant_type == QuantType::INT8_WEIGHT_ONLY) {
|
| 262 |
+
for (int ii = 0; ii < M_TILE_L1; ++ii) {
|
| 263 |
+
for (int jj = ii + 1; jj < N_TILE_L1; ++jj) {
|
| 264 |
+
std::swap(cache_buf[ii][jj], cache_buf[jj][ii]);
|
| 265 |
+
}
|
| 266 |
+
}
|
| 267 |
+
} else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) {
|
| 268 |
+
|
| 269 |
+
for (int ii = 0; ii < M_TILE_L1; ++ii) {
|
| 270 |
+
// Using M_TILE_L1 here is deliberate since we assume that the cache
|
| 271 |
+
// tile is square in the number of elements (not necessarily the
|
| 272 |
+
// number of bytes).
|
| 273 |
+
for (int jj = ii + 1; jj < M_TILE_L1; ++jj) {
|
| 274 |
+
const int ii_byte = ii / ELTS_PER_BYTE;
|
| 275 |
+
const int ii_bit_offset = ii % ELTS_PER_BYTE;
|
| 276 |
+
|
| 277 |
+
const int jj_byte = jj / ELTS_PER_BYTE;
|
| 278 |
+
const int jj_bit_offset = jj % ELTS_PER_BYTE;
|
| 279 |
+
|
| 280 |
+
uint8_t src_elt =
|
| 281 |
+
0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset));
|
| 282 |
+
uint8_t tgt_elt =
|
| 283 |
+
0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset));
|
| 284 |
+
|
| 285 |
+
cache_buf[ii][jj_byte] &= (0xF0 >> (4 * jj_bit_offset));
|
| 286 |
+
cache_buf[jj][ii_byte] &= (0xF0 >> (4 * ii_bit_offset));
|
| 287 |
+
|
| 288 |
+
cache_buf[ii][jj_byte] |= (tgt_elt << (4 * jj_bit_offset));
|
| 289 |
+
cache_buf[jj][ii_byte] |= (src_elt << (4 * ii_bit_offset));
|
| 290 |
+
}
|
| 291 |
+
}
|
| 292 |
+
} else {
|
| 293 |
+
FT_CHECK_WITH_INFO(false, "Unsupported quantization type.");
|
| 294 |
+
}
|
| 295 |
+
|
| 296 |
+
const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE;
|
| 297 |
+
const size_t col_tile_start_byte_trans = row_tile_start / ELTS_PER_BYTE;
|
| 298 |
+
|
| 299 |
+
const int row_limit_trans =
|
| 300 |
+
std::min(row_tile_start_trans + M_TILE_L1, num_cols);
|
| 301 |
+
const int col_limit_trans =
|
| 302 |
+
std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans);
|
| 303 |
+
|
| 304 |
+
for (int ii = 0; ii < M_TILE_L1; ++ii) {
|
| 305 |
+
const int row = row_tile_start_trans + ii;
|
| 306 |
+
for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) {
|
| 307 |
+
const int col = col_tile_start_byte_trans + jj;
|
| 308 |
+
|
| 309 |
+
const size_t logical_tgt_offset = row * col_bytes_trans + col;
|
| 310 |
+
|
| 311 |
+
if (row < row_limit_trans && col < col_limit_trans) {
|
| 312 |
+
for (int v = 0; v < VECTOR_WIDTH; ++v) {
|
| 313 |
+
output_byte_ptr[logical_tgt_offset + v] = cache_buf[ii][jj + v];
|
| 314 |
+
}
|
| 315 |
+
}
|
| 316 |
+
}
|
| 317 |
+
}
|
| 318 |
+
}
|
| 319 |
+
}
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
void subbyte_transpose(int8_t *transposed_quantized_tensor,
|
| 323 |
+
const int8_t *quantized_tensor,
|
| 324 |
+
const std::vector<size_t> &shape, QuantType quant_type) {
|
| 325 |
+
|
| 326 |
+
if (quant_type == QuantType::INT8_WEIGHT_ONLY) {
|
| 327 |
+
subbyte_transpose_impl<QuantType::INT8_WEIGHT_ONLY>(
|
| 328 |
+
transposed_quantized_tensor, quantized_tensor, shape);
|
| 329 |
+
} else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) {
|
| 330 |
+
subbyte_transpose_impl<QuantType::PACKED_INT4_WEIGHT_ONLY>(
|
| 331 |
+
transposed_quantized_tensor, quantized_tensor, shape);
|
| 332 |
+
} else {
|
| 333 |
+
FT_CHECK_WITH_INFO(false, "Invalid quant_tye");
|
| 334 |
+
}
|
| 335 |
+
}
|
| 336 |
+
|
| 337 |
+
void add_bias_and_interleave_int8s_inplace(int8_t *int8_tensor,
|
| 338 |
+
const size_t num_elts) {
|
| 339 |
+
for (size_t ii = 0; ii < num_elts; ++ii) {
|
| 340 |
+
int8_tensor[ii] = int8_t(int(int8_tensor[ii]) + 128);
|
| 341 |
+
}
|
| 342 |
+
|
| 343 |
+
// Step 2 will transform the layout of a 32-bit register in CUDA in order to
|
| 344 |
+
// match the int4 layout. This has no performance benefit and is purely so
|
| 345 |
+
// that int4 and int8 have the same layout. Pictorially, this does the
|
| 346 |
+
// following: bit 32 0
|
| 347 |
+
// [elt_3 elt_2 elt_1 elt_0] (each elt occupies 8 bits)
|
| 348 |
+
//
|
| 349 |
+
// And it will rearrange the output 32 bit register to be the following:
|
| 350 |
+
// bit 32 0
|
| 351 |
+
// [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits)
|
| 352 |
+
|
| 353 |
+
FT_CHECK_WITH_INFO(num_elts % 4 == 0, "Dimensions of int8 tensor must be a "
|
| 354 |
+
"multiple of 4 for register relayout");
|
| 355 |
+
for (size_t base = 0; base < num_elts; base += 4) {
|
| 356 |
+
std::swap(int8_tensor[base + 1], int8_tensor[base + 2]);
|
| 357 |
+
}
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
void add_bias_and_interleave_int4s_inplace(int8_t *packed_int4_tensor,
|
| 361 |
+
const size_t num_elts) {
|
| 362 |
+
const size_t num_bytes = num_elts / 2;
|
| 363 |
+
|
| 364 |
+
// Step 1 will be to transform all the int4s to unsigned in order to make the
|
| 365 |
+
// dequantize take as little instructions as possible in the CUDA code.
|
| 366 |
+
for (size_t ii = 0; ii < num_bytes; ++ii) {
|
| 367 |
+
int8_t transformed_packed_int4s = 0;
|
| 368 |
+
int8_t transformed_first_elt =
|
| 369 |
+
(int8_t(packed_int4_tensor[ii] << 4) >> 4) +
|
| 370 |
+
8; // The double shift here is to ensure sign extension
|
| 371 |
+
int8_t transformed_second_elt = (packed_int4_tensor[ii] >> 4) + 8;
|
| 372 |
+
|
| 373 |
+
FT_CHECK_WITH_INFO(transformed_first_elt >= 0 &&
|
| 374 |
+
transformed_first_elt <= 15,
|
| 375 |
+
"Illegal result for int4 transform (first elt)");
|
| 376 |
+
FT_CHECK_WITH_INFO(transformed_second_elt >= 0 &&
|
| 377 |
+
transformed_second_elt <= 15,
|
| 378 |
+
"Illegal result for int4 transform (second elt)");
|
| 379 |
+
|
| 380 |
+
// We don't need to mask in these ops since everything should be in the
|
| 381 |
+
// range 0-15
|
| 382 |
+
transformed_packed_int4s |= transformed_first_elt;
|
| 383 |
+
transformed_packed_int4s |= (transformed_second_elt << 4);
|
| 384 |
+
packed_int4_tensor[ii] = transformed_packed_int4s;
|
| 385 |
+
}
|
| 386 |
+
|
| 387 |
+
// Step 2 will transform the layout of a 32-bit register in CUDA in order to
|
| 388 |
+
// minimize the number of shift & logical instructions That are needed to
|
| 389 |
+
// extract the int4s in the GEMM main loop. Pictorially, the loop below will
|
| 390 |
+
// do the following: Take as input a 32 bit register with layout: bit 32 0
|
| 391 |
+
// [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt
|
| 392 |
+
// occupies 4 bits)
|
| 393 |
+
//
|
| 394 |
+
// And it will rearrange the output 32 bit register to be the following:
|
| 395 |
+
// bit 32 0
|
| 396 |
+
// [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] (each elt
|
| 397 |
+
// occupies 4 bits)
|
| 398 |
+
|
| 399 |
+
FT_CHECK_WITH_INFO(num_bytes % 4 == 0, "Dimensions of int4 tensor must be a "
|
| 400 |
+
"multiple of 8 for register relayout");
|
| 401 |
+
const size_t num_registers = num_bytes / 4;
|
| 402 |
+
|
| 403 |
+
uint32_t *register_ptr = reinterpret_cast<uint32_t *>(packed_int4_tensor);
|
| 404 |
+
for (size_t ii = 0; ii < num_registers; ++ii) {
|
| 405 |
+
const uint32_t current_register = register_ptr[ii];
|
| 406 |
+
uint32_t transformed_register = 0;
|
| 407 |
+
|
| 408 |
+
for (int dest_idx = 0; dest_idx < 8; ++dest_idx) {
|
| 409 |
+
const int src_idx = dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1;
|
| 410 |
+
const int src_shift = 4 * src_idx;
|
| 411 |
+
const int dest_shift = 4 * dest_idx;
|
| 412 |
+
|
| 413 |
+
const uint32_t src_bits = (current_register >> src_shift) & 0xF;
|
| 414 |
+
transformed_register |= (src_bits << dest_shift);
|
| 415 |
+
}
|
| 416 |
+
register_ptr[ii] = transformed_register;
|
| 417 |
+
}
|
| 418 |
+
}
|
| 419 |
+
|
| 420 |
+
void add_bias_and_interleave_quantized_tensor_inplace(int8_t *tensor,
|
| 421 |
+
const size_t num_elts,
|
| 422 |
+
QuantType quant_type) {
|
| 423 |
+
if (quant_type == QuantType::INT8_WEIGHT_ONLY) {
|
| 424 |
+
add_bias_and_interleave_int8s_inplace(tensor, num_elts);
|
| 425 |
+
} else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) {
|
| 426 |
+
add_bias_and_interleave_int4s_inplace(tensor, num_elts);
|
| 427 |
+
} else {
|
| 428 |
+
FT_CHECK_WITH_INFO(false, "Invalid quantization type for interleaving.");
|
| 429 |
+
}
|
| 430 |
+
}
|
| 431 |
+
|
| 432 |
+
void interleave_column_major_tensor(int8_t *interleaved_quantized_tensor,
|
| 433 |
+
const int8_t *quantized_tensor,
|
| 434 |
+
const std::vector<size_t> &shape,
|
| 435 |
+
QuantType quant_type,
|
| 436 |
+
LayoutDetails details) {
|
| 437 |
+
// We only want to run this step for weight only quant.
|
| 438 |
+
FT_CHECK(quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY ||
|
| 439 |
+
quant_type == QuantType::INT8_WEIGHT_ONLY);
|
| 440 |
+
FT_CHECK_WITH_INFO(shape.size() == 2, "Shape must be 2-D");
|
| 441 |
+
|
| 442 |
+
const size_t num_rows = shape[0];
|
| 443 |
+
const size_t num_cols = shape[1];
|
| 444 |
+
|
| 445 |
+
const int BITS_PER_ELT = get_bits_in_quant_type(quant_type);
|
| 446 |
+
const int elts_in_int32 = 32 / BITS_PER_ELT;
|
| 447 |
+
|
| 448 |
+
const int rows_per_tile = details.rows_per_column_tile;
|
| 449 |
+
|
| 450 |
+
FT_CHECK_WITH_INFO(!(num_rows % elts_in_int32),
|
| 451 |
+
fmtstr("The number of rows must be a multiple of %d but "
|
| 452 |
+
"the number of rows is %d.",
|
| 453 |
+
elts_in_int32, num_rows));
|
| 454 |
+
|
| 455 |
+
FT_CHECK_WITH_INFO(!(num_cols % rows_per_tile),
|
| 456 |
+
fmtstr("The number of columns must be a multiple of %d "
|
| 457 |
+
"but the number of columns is %ld",
|
| 458 |
+
rows_per_tile, num_cols));
|
| 459 |
+
|
| 460 |
+
const uint32_t *input_byte_ptr =
|
| 461 |
+
reinterpret_cast<const uint32_t *>(quantized_tensor);
|
| 462 |
+
uint32_t *output_byte_ptr =
|
| 463 |
+
reinterpret_cast<uint32_t *>(interleaved_quantized_tensor);
|
| 464 |
+
|
| 465 |
+
FT_CHECK_WITH_INFO(!(num_cols % rows_per_tile),
|
| 466 |
+
fmtstr("The number of columns must be a multiple of %d "
|
| 467 |
+
"but the number of columns is %d.",
|
| 468 |
+
rows_per_tile, num_cols));
|
| 469 |
+
|
| 470 |
+
const int num_vec_rows = num_rows / elts_in_int32;
|
| 471 |
+
const int vec_rows_per_tile = rows_per_tile / elts_in_int32;
|
| 472 |
+
const int interleave = details.columns_interleaved;
|
| 473 |
+
|
| 474 |
+
for (size_t read_col = 0; read_col < num_cols; ++read_col) {
|
| 475 |
+
const auto write_col = read_col / interleave;
|
| 476 |
+
for (int base_vec_row = 0; base_vec_row < num_vec_rows;
|
| 477 |
+
base_vec_row += vec_rows_per_tile) {
|
| 478 |
+
for (int vec_read_row = base_vec_row;
|
| 479 |
+
vec_read_row <
|
| 480 |
+
std::min(num_vec_rows, base_vec_row + vec_rows_per_tile);
|
| 481 |
+
++vec_read_row) {
|
| 482 |
+
const int64_t vec_write_row =
|
| 483 |
+
interleave * base_vec_row +
|
| 484 |
+
vec_rows_per_tile * (read_col % interleave) +
|
| 485 |
+
vec_read_row % vec_rows_per_tile;
|
| 486 |
+
|
| 487 |
+
const int64_t read_offset =
|
| 488 |
+
int64_t(read_col) * num_vec_rows + vec_read_row;
|
| 489 |
+
const int64_t write_offset =
|
| 490 |
+
int64_t(write_col) * num_vec_rows * interleave + vec_write_row;
|
| 491 |
+
output_byte_ptr[write_offset] = input_byte_ptr[read_offset];
|
| 492 |
+
}
|
| 493 |
+
}
|
| 494 |
+
}
|
| 495 |
+
}
|
| 496 |
+
|
| 497 |
+
void preprocess_weights_for_mixed_gemm(int8_t *preprocessed_quantized_weight,
|
| 498 |
+
const int8_t *row_major_quantized_weight,
|
| 499 |
+
const std::vector<size_t> &shape,
|
| 500 |
+
QuantType quant_type, int arch) {
|
| 501 |
+
LayoutDetails details = getLayoutDetailsForTransform(quant_type, arch);
|
| 502 |
+
|
| 503 |
+
FT_CHECK_WITH_INFO(shape.size() == 2, "Shape must be 2-D");
|
| 504 |
+
|
| 505 |
+
size_t num_elts = 1;
|
| 506 |
+
for (const auto &dim : shape) {
|
| 507 |
+
num_elts *= dim;
|
| 508 |
+
}
|
| 509 |
+
|
| 510 |
+
const size_t num_bytes = num_elts * get_bits_in_quant_type(quant_type) / 8;
|
| 511 |
+
|
| 512 |
+
std::vector<int8_t> src_buf(num_bytes);
|
| 513 |
+
std::vector<int8_t> dst_buf(num_bytes);
|
| 514 |
+
std::copy(row_major_quantized_weight, row_major_quantized_weight + num_bytes, src_buf.begin());
|
| 515 |
+
|
| 516 |
+
// Works on row major data, so issue this permutation first.
|
| 517 |
+
if (details.uses_imma_ldsm) {
|
| 518 |
+
permute_B_rows_for_mixed_gemm(dst_buf.data(), src_buf.data(), shape, quant_type, arch);
|
| 519 |
+
src_buf.swap(dst_buf);
|
| 520 |
+
}
|
| 521 |
+
|
| 522 |
+
if (details.layoutB == LayoutDetails::Layout::COLUMN_MAJOR) {
|
| 523 |
+
subbyte_transpose(dst_buf.data(), src_buf.data(), shape, quant_type);
|
| 524 |
+
src_buf.swap(dst_buf);
|
| 525 |
+
}
|
| 526 |
+
|
| 527 |
+
if (details.columns_interleaved > 1) {
|
| 528 |
+
interleave_column_major_tensor(dst_buf.data(), src_buf.data(), shape, quant_type, details);
|
| 529 |
+
src_buf.swap(dst_buf);
|
| 530 |
+
}
|
| 531 |
+
|
| 532 |
+
add_bias_and_interleave_quantized_tensor_inplace(src_buf.data(), num_elts, quant_type);
|
| 533 |
+
std::copy(src_buf.begin(), src_buf.end(), preprocessed_quantized_weight);
|
| 534 |
+
}
|
| 535 |
+
|
| 536 |
+
void preprocess_weights(int8_t *preprocessed_quantized_weight,
|
| 537 |
+
const int8_t *row_major_quantized_weight, size_t rows,
|
| 538 |
+
size_t cols, bool is_int4, int arch) {
|
| 539 |
+
QuantType qtype = is_int4 ? QuantType::PACKED_INT4_WEIGHT_ONLY
|
| 540 |
+
: QuantType::INT8_WEIGHT_ONLY;
|
| 541 |
+
preprocess_weights_for_mixed_gemm(preprocessed_quantized_weight,
|
| 542 |
+
row_major_quantized_weight, {rows, cols},
|
| 543 |
+
qtype, arch);
|
| 544 |
+
}
|
| 545 |
+
|
| 546 |
+
/*
|
| 547 |
+
Arguments:
|
| 548 |
+
input_weight_ptr - the weight tensor to be quantized. Must be 2-D or 3-D and of type FP16.
|
| 549 |
+
|
| 550 |
+
quant_type - the type of the output quantization weight.
|
| 551 |
+
|
| 552 |
+
This function does symmetric quantization on 2-D or 3-D tensors. It uses the full int range and assumes the
|
| 553 |
+
zero-point is zero and will automatically construct the scales.
|
| 554 |
+
|
| 555 |
+
It always quantizes the last axis of the tensor. For 3-D tensors, it operates in "batched" mode where the tensor is
|
| 556 |
+
viewed as a stack of matrices and a scale is produced for each column of every matrix.
|
| 557 |
+
|
| 558 |
+
Outputs
|
| 559 |
+
processed_quantized_weight - quantized AND processed weight for GEMM. This MUST be used with the CUTLASS GEMM
|
| 560 |
+
unprocessed_quantized_weight - quantized but unprocessed weights. Useful for reference checking.
|
| 561 |
+
scale_ptr - scales for the quantized weight.
|
| 562 |
+
|
| 563 |
+
Note that the returned quantized_weights will be preprocessed in a way to accelerate the mixed type GEMM. The data
|
| 564 |
+
layout may not make sense if printed.
|
| 565 |
+
|
| 566 |
+
Shapes:
|
| 567 |
+
quant_type == int8:
|
| 568 |
+
If weight is a [m,n] matrix, quantized_weights will have shape [m,n] and scales of shape [n]
|
| 569 |
+
If weight is a [b,m,n] tensor, unprocessed_quantized_weight will have shape [b,m,n] and scales of shape [b,n]
|
| 570 |
+
quant_type == int4:
|
| 571 |
+
If weight is a [m,n] matrix, quantized_weights will have shape [m, ceil(n/2)] and scales of shape [n]
|
| 572 |
+
If weight is a [b,m,n] tensor, unprocessed_quantized_weight will have shape [b,m, ceil(n/2)] and scales of shape
|
| 573 |
+
[b,n]
|
| 574 |
+
|
| 575 |
+
The quantized_weight will be of type torch.int8 and have two int4 values packed in a single byte. This is the
|
| 576 |
+
reason for halving the shape. At the time of writing this code, there was not an elegant way to handle this kind
|
| 577 |
+
of batched quantization using torch's quantized tensors (to the best of the author's knowledge). Scale tensors
|
| 578 |
+
must have a dimension of 1, which breaks the semantics we need for batched weights.
|
| 579 |
+
*/
|
| 580 |
+
|
| 581 |
+
template<typename ComputeType, typename WeightType>
|
| 582 |
+
void symmetric_quantize(int8_t* processed_quantized_weight,
|
| 583 |
+
int8_t* unprocessed_quantized_weight,
|
| 584 |
+
ComputeType* scale_ptr,
|
| 585 |
+
const WeightType* input_weight_ptr,
|
| 586 |
+
const std::vector<size_t>& shape,
|
| 587 |
+
QuantType quant_type)
|
| 588 |
+
{
|
| 589 |
+
|
| 590 |
+
FT_CHECK_WITH_INFO(processed_quantized_weight, "Processed quantized tensor is NULL");
|
| 591 |
+
FT_CHECK_WITH_INFO(scale_ptr, "Scale output pointer is NULL");
|
| 592 |
+
FT_CHECK_WITH_INFO(input_weight_ptr, "Input weight pointer is NULL");
|
| 593 |
+
|
| 594 |
+
FT_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D");
|
| 595 |
+
const size_t num_experts = shape.size() == 2 ? 1 : shape[0];
|
| 596 |
+
const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
|
| 597 |
+
const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];
|
| 598 |
+
|
| 599 |
+
const int bits_in_type = get_bits_in_quant_type(quant_type);
|
| 600 |
+
const int bytes_per_out_col = num_cols * bits_in_type / 8;
|
| 601 |
+
|
| 602 |
+
std::vector<int8_t> weight_buf;
|
| 603 |
+
if (unprocessed_quantized_weight == nullptr) {
|
| 604 |
+
weight_buf.resize(num_experts * num_rows * num_cols);
|
| 605 |
+
unprocessed_quantized_weight = weight_buf.data();
|
| 606 |
+
}
|
| 607 |
+
|
| 608 |
+
const int input_mat_size = num_rows * num_cols;
|
| 609 |
+
const int quantized_mat_size = num_rows * bytes_per_out_col;
|
| 610 |
+
const float quant_range_scale = 1.f / float(1 << (bits_in_type - 1));
|
| 611 |
+
|
| 612 |
+
std::vector<float> per_col_max(num_cols);
|
| 613 |
+
|
| 614 |
+
for (int expert = 0; expert < num_experts; ++expert) {
|
| 615 |
+
const WeightType* current_weight = input_weight_ptr + expert * input_mat_size;
|
| 616 |
+
int8_t* current_quantized_weight = unprocessed_quantized_weight + expert * quantized_mat_size;
|
| 617 |
+
|
| 618 |
+
// First we find the per column max for this expert weight.
|
| 619 |
+
for (int jj = 0; jj < num_cols; ++jj) {
|
| 620 |
+
per_col_max[jj] = 0.f;
|
| 621 |
+
}
|
| 622 |
+
|
| 623 |
+
for (int ii = 0; ii < num_rows; ++ii) {
|
| 624 |
+
const WeightType* current_weight_row = current_weight + ii * num_cols;
|
| 625 |
+
for (int jj = 0; jj < num_cols; ++jj) {
|
| 626 |
+
per_col_max[jj] = std::max(per_col_max[jj], std::abs(float(current_weight_row[jj])));
|
| 627 |
+
}
|
| 628 |
+
}
|
| 629 |
+
|
| 630 |
+
// Then, we construct the scales
|
| 631 |
+
ComputeType* current_scales = scale_ptr + expert * num_cols;
|
| 632 |
+
for (int jj = 0; jj < num_cols; ++jj) {
|
| 633 |
+
per_col_max[jj] *= quant_range_scale;
|
| 634 |
+
current_scales[jj] = ComputeType(per_col_max[jj]);
|
| 635 |
+
}
|
| 636 |
+
|
| 637 |
+
// Finally, construct the weights.
|
| 638 |
+
for (int ii = 0; ii < num_rows; ++ii) {
|
| 639 |
+
int8_t* current_quantized_weight_row = current_quantized_weight + ii * bytes_per_out_col;
|
| 640 |
+
const WeightType* current_weight_row = current_weight + ii * num_cols;
|
| 641 |
+
for (int jj = 0; jj < bytes_per_out_col; ++jj) {
|
| 642 |
+
|
| 643 |
+
if (quant_type == QuantType::INT8_WEIGHT_ONLY) {
|
| 644 |
+
const float col_scale = per_col_max[jj];
|
| 645 |
+
const float weight_elt = float(current_weight_row[jj]);
|
| 646 |
+
const float scaled_weight = round(weight_elt / col_scale);
|
| 647 |
+
const int8_t clipped_weight = int8_t(std::max(-128.f, std::min(127.f, scaled_weight)));
|
| 648 |
+
current_quantized_weight_row[jj] = clipped_weight;
|
| 649 |
+
}
|
| 650 |
+
else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) {
|
| 651 |
+
|
| 652 |
+
// We will pack two int4 elements per iteration of the inner loop.
|
| 653 |
+
int8_t packed_int4s = 0;
|
| 654 |
+
for (int packed_idx = 0; packed_idx < 2; ++packed_idx) {
|
| 655 |
+
const int input_idx = 2 * jj + packed_idx;
|
| 656 |
+
if (input_idx < num_cols) {
|
| 657 |
+
const float col_scale = per_col_max[input_idx];
|
| 658 |
+
const float weight_elt = float(current_weight_row[input_idx]);
|
| 659 |
+
const float scaled_weight = round(weight_elt / col_scale);
|
| 660 |
+
int int_weight = int(scaled_weight);
|
| 661 |
+
const int8_t clipped_weight = std::max(-8, std::min(7, int_weight));
|
| 662 |
+
|
| 663 |
+
// Kill the sign extension bits (hence 0x0F mask) then shift to upper bits
|
| 664 |
+
// if packing the second int4 and or the bits into the final result.
|
| 665 |
+
packed_int4s |= ((clipped_weight & 0x0F) << (4 * packed_idx));
|
| 666 |
+
}
|
| 667 |
+
}
|
| 668 |
+
current_quantized_weight_row[jj] = packed_int4s;
|
| 669 |
+
}
|
| 670 |
+
else {
|
| 671 |
+
FT_CHECK_WITH_INFO(false, "Unsupported quantization type");
|
| 672 |
+
}
|
| 673 |
+
}
|
| 674 |
+
}
|
| 675 |
+
}
|
| 676 |
+
const int arch = fastertransformer::getSMVersion();
|
| 677 |
+
preprocess_weights_for_mixed_gemm(processed_quantized_weight, unprocessed_quantized_weight, shape, quant_type, arch);
|
| 678 |
+
}
|
| 679 |
+
|
| 680 |
+
template void
|
| 681 |
+
symmetric_quantize<half, float>(int8_t*, int8_t*, half*, const float*, const std::vector<size_t>&, QuantType);
|
| 682 |
+
|
| 683 |
+
template void
|
| 684 |
+
symmetric_quantize<half, half>(int8_t*, int8_t*, half*, const half*, const std::vector<size_t>&, QuantType);
|
| 685 |
+
|
| 686 |
+
|
| 687 |
+
template<typename ComputeType, typename WeightType>
|
| 688 |
+
void symmetric_quantize(int8_t* processed_quantized_weight,
|
| 689 |
+
ComputeType* scale_ptr,
|
| 690 |
+
const WeightType* input_weight_ptr,
|
| 691 |
+
const std::vector<size_t>& shape,
|
| 692 |
+
QuantType quant_type)
|
| 693 |
+
{
|
| 694 |
+
symmetric_quantize(processed_quantized_weight, nullptr, scale_ptr, input_weight_ptr, shape, quant_type);
|
| 695 |
+
}
|
| 696 |
+
|
| 697 |
+
template void symmetric_quantize<float, float>(int8_t*, float*, const float*, const std::vector<size_t>&, QuantType);
|
| 698 |
+
|
| 699 |
+
template void symmetric_quantize<half, float>(int8_t*, half*, const float*, const std::vector<size_t>&, QuantType);
|
| 700 |
+
|
| 701 |
+
template void symmetric_quantize<half, half>(int8_t*, half*, const half*, const std::vector<size_t>&, QuantType);
|
| 702 |
+
|
| 703 |
+
} // namespace fastertransformer
|
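For orientation, a minimal host-side sketch of how the preprocessing pipeline above is typically driven (buffer sizes and the arch value are placeholders; the input must already be a row-major int8 weight, e.g. produced by symmetric_quantize):

// Illustrative only: run the full weight preprocessing for an int8 weight-only GEMM.
#include <cstdint>
#include <vector>

void preprocess_example()
{
    using namespace fastertransformer;

    const size_t rows = 4096, cols = 4096;            // k x n, already quantized to int8
    std::vector<int8_t> row_major(rows * cols);        // filled elsewhere
    std::vector<int8_t> preprocessed(rows * cols);

    // arch = 80 selects the Sm80 layout details: permute B rows, transpose to
    // column-major, interleave columns, then bias+interleave within registers.
    preprocess_weights(preprocessed.data(), row_major.data(), rows, cols,
                       /*is_int4=*/false, /*arch=*/80);
}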
cutlass_kernels/cutlass_preprocessors.h
ADDED
|
@@ -0,0 +1,33 @@
#pragma once
#pragma GCC diagnostic ignored "-Wstrict-aliasing"

#include <cstddef>
#include <cstdint>
#include <vector>

namespace fastertransformer {

enum class QuantType { INT8_WEIGHT_ONLY, PACKED_INT4_WEIGHT_ONLY };

int get_bits_in_quant_type(QuantType quant_type);

void preprocess_weights(int8_t *preprocessed_quantized_weight,
                        const int8_t *row_major_quantized_weight, size_t rows,
                        size_t cols, bool is_int4, int arch);

template<typename ComputeType, typename WeightType>
void symmetric_quantize(int8_t* processed_quantized_weight,
                        ComputeType* scale_ptr,
                        const WeightType* input_weight_ptr,
                        const std::vector<size_t>& shape,
                        QuantType quant_type);

template<typename ComputeType, typename WeightType>
void symmetric_quantize(int8_t* processed_quantized_weight,
                        int8_t* unprocessed_quantized_weight,
                        ComputeType* scale_ptr,
                        const WeightType* input_weight_ptr,
                        const std::vector<size_t>& shape,
                        QuantType quant_type);

}  // namespace fastertransformer
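A minimal sketch of calling the two-output overload above on an fp16 [k, n] weight matrix (sizes are illustrative, and it assumes `half` here is the same CUDA half type the explicit instantiations in cutlass_preprocessors.cc use):

// Illustrative only: symmetric per-column quantization plus GEMM preprocessing.
#include <cuda_fp16.h>
#include <cstdint>
#include <vector>

void quantize_example(const half* fp16_weight /* [k, n], row-major */, size_t k, size_t n)
{
    using namespace fastertransformer;

    std::vector<int8_t> processed(k * n);  // int8 weight-only: one byte per element
    std::vector<half>   scales(n);         // one scale per output column

    // Zero-point-free quantization of the last axis, followed by the CUTLASS
    // layout preprocessing; `processed` is what the mixed GEMM consumes.
    symmetric_quantize<half, half>(processed.data(), scales.data(), fp16_weight,
                                   {k, n}, QuantType::INT8_WEIGHT_ONLY);
}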
cutlass_kernels/fpA_intB_gemm.cu
ADDED
|
@@ -0,0 +1,99 @@
| 1 |
+
#include "fpA_intB_gemm.h"
|
| 2 |
+
#include "fpA_intB_gemm/fpA_intB_gemm_template.h"
|
| 3 |
+
|
| 4 |
+
namespace fastertransformer
|
| 5 |
+
{
|
| 6 |
+
|
| 7 |
+
ActivationType get_activation(const std::string &activation_name)
|
| 8 |
+
{
|
| 9 |
+
if (activation_name == "identity")
|
| 10 |
+
return ActivationType::Identity;
|
| 11 |
+
if (activation_name == "relu")
|
| 12 |
+
return ActivationType::Relu;
|
| 13 |
+
if (activation_name == "silu")
|
| 14 |
+
return ActivationType::Silu;
|
| 15 |
+
if (activation_name == "gelu")
|
| 16 |
+
return ActivationType::Gelu;
|
| 17 |
+
// todo: more
|
| 18 |
+
return ActivationType::InvalidType;
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
void gemm_fp16_int(const half *A,
|
| 22 |
+
const uint8_t *B,
|
| 23 |
+
const half *weight_scales,
|
| 24 |
+
half *C,
|
| 25 |
+
int m, int n, int k,
|
| 26 |
+
char *workspace_ptr,
|
| 27 |
+
size_t workspace_bytes,
|
| 28 |
+
cudaStream_t stream)
|
| 29 |
+
{
|
| 30 |
+
CutlassFpAIntBGemmRunner<half, uint8_t> runner;
|
| 31 |
+
runner.gemm(A, B, weight_scales,
|
| 32 |
+
C, m, n, k, workspace_ptr, workspace_bytes, stream);
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
template <typename WeightType>
|
| 36 |
+
void gemm_fp16_int_bias_act(const half *A,
|
| 37 |
+
const WeightType *B,
|
| 38 |
+
const half *weight_scales,
|
| 39 |
+
const half *bias,
|
| 40 |
+
half *C,
|
| 41 |
+
std::optional<std::string> activation,
|
| 42 |
+
int m, int n, int k, int bias_stride, char *workspace_ptr,
|
| 43 |
+
size_t workspace_bytes, cudaStream_t stream)
|
| 44 |
+
{
|
| 45 |
+
CutlassFpAIntBGemmRunner<half, WeightType> runner;
|
| 46 |
+
|
| 47 |
+
if (!activation && bias == nullptr)
|
| 48 |
+
{
|
| 49 |
+
runner.gemm(A, B, weight_scales,
|
| 50 |
+
C, m, n, k, workspace_ptr, workspace_bytes, stream);
|
| 51 |
+
}
|
| 52 |
+
else if (!activation)
|
| 53 |
+
{
|
| 54 |
+
runner.gemm_bias_act(A, B, weight_scales, bias,
|
| 55 |
+
C, m, n, k, bias_stride, ActivationType::Identity, workspace_ptr, workspace_bytes, stream);
|
| 56 |
+
}
|
| 57 |
+
else
|
| 58 |
+
{
|
| 59 |
+
runner.gemm_bias_act(A, B, weight_scales, bias,
|
| 60 |
+
C, m, n, k, bias_stride, get_activation(*activation), workspace_ptr, workspace_bytes, stream);
|
| 61 |
+
}
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
template <typename WeightType>
|
| 65 |
+
void gemm_fp16_int_bias_act_residual(
|
| 66 |
+
const half *A, const WeightType *B, const half *weight_scales,
|
| 67 |
+
const half *bias, const half *residual, half *C, const std::string &activation, const std::string &binary_op,
|
| 68 |
+
const std::string &unary_op, int m, int n,
|
| 69 |
+
int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream)
|
| 70 |
+
{
|
| 71 |
+
CutlassFpAIntBGemmRunner<half, WeightType> runner;
|
| 72 |
+
|
| 73 |
+
runner.gemm_bias_act_residual(A, B, weight_scales, bias, residual,
|
| 74 |
+
C, m, n, k, activation, binary_op, unary_op, workspace_ptr, workspace_bytes, stream);
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
template void gemm_fp16_int_bias_act<uint4b_t>(const half *A, const uint4b_t *B,
|
| 78 |
+
const half *weight_scales, const half *bias,
|
| 79 |
+
half *C, std::optional<std::string> activation, int m,
|
| 80 |
+
int n, int k, int bias_stride, char *workspace_ptr,
|
| 81 |
+
size_t workspace_bytes, cudaStream_t stream);
|
| 82 |
+
|
| 83 |
+
template void gemm_fp16_int_bias_act_residual<uint4b_t>(
|
| 84 |
+
const half *A, const uint4b_t *B, const half *weight_scales,
|
| 85 |
+
const half *bias, const half *residual, half *C, const std::string &activation, const std::string &binary_op,
|
| 86 |
+
const std::string &unary_op, int m, int n, int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream);
|
| 87 |
+
|
| 88 |
+
template void gemm_fp16_int_bias_act<uint8_t>(const half *A, const uint8_t *B,
|
| 89 |
+
const half *weight_scales, const half *bias,
|
| 90 |
+
half *C, std::optional<std::string> activation, int m,
|
| 91 |
+
int n, int k, int bias_stride, char *workspace_ptr,
|
| 92 |
+
size_t workspace_bytes, cudaStream_t stream);
|
| 93 |
+
|
| 94 |
+
template void gemm_fp16_int_bias_act_residual<uint8_t>(
|
| 95 |
+
const half *A, const uint8_t *B, const half *weight_scales,
|
| 96 |
+
const half *bias, const half *residual, half *C, const std::string &activation, const std::string &binary_op,
|
| 97 |
+
const std::string &unary_op, int m, int n, int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream);
|
| 98 |
+
|
| 99 |
+
} // namespace fastertransformer
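The entry points above take raw device pointers, so the caller is responsible for allocation, quantized-weight preprocessing, and workspace sizing. Below is a minimal host-side sketch of the plain fp16 x int8 path; it assumes B already holds preprocessed (interleaved) int8 weights and reuses the same worst-case workspace bound that the getWorkspaceSize helpers in this import compute. Buffer names and fill logic are illustrative only.

#include <cstdint>
#include <cuda_runtime.h>
#include "fpA_intB_gemm.h"

namespace ft = fastertransformer;

// Sketch only: buffer layouts and the workspace bound are assumptions.
void example_gemm_fp16_int8(int m, int n, int k, cudaStream_t stream)
{
    ft::half *A, *C, *scales;
    uint8_t *B;
    char *workspace;
    const size_t workspace_bytes = ((m + 31) / 32) * ((n + 127) / 128) * 7 * 4;
    cudaMalloc(reinterpret_cast<void **>(&A), m * k * sizeof(ft::half));
    cudaMalloc(reinterpret_cast<void **>(&B), k * n * sizeof(uint8_t));    // preprocessed int8 weights
    cudaMalloc(reinterpret_cast<void **>(&scales), n * sizeof(ft::half));  // per-output-channel scales
    cudaMalloc(reinterpret_cast<void **>(&C), m * n * sizeof(ft::half));
    cudaMalloc(reinterpret_cast<void **>(&workspace), workspace_bytes);
    // ... fill A, B and scales on the device ...
    ft::gemm_fp16_int(A, B, scales, C, m, n, k, workspace, workspace_bytes, stream);
}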
cutlass_kernels/fpA_intB_gemm.h
ADDED
@@ -0,0 +1,36 @@
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <string>
|
| 4 |
+
#include <optional>
|
| 5 |
+
|
| 6 |
+
#include <cuda_runtime.h>
|
| 7 |
+
#include "cutlass/numeric_types.h"
|
| 8 |
+
#include "cutlass/half.h"
|
| 9 |
+
#include "cutlass/integer_subbyte.h"
|
| 10 |
+
|
| 11 |
+
namespace fastertransformer {
|
| 12 |
+
|
| 13 |
+
using half = cutlass::half_t;
|
| 14 |
+
using uint4b_t = cutlass::uint4b_t;
|
| 15 |
+
|
| 16 |
+
// TODO: Support more general bias shape
|
| 17 |
+
|
| 18 |
+
// base gemm
|
| 19 |
+
void gemm_fp16_int(const half *A, const uint8_t * B, const half *weight_scales,
|
| 20 |
+
half *C, int m, int n, int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream);
|
| 21 |
+
|
| 22 |
+
template <typename WeightType>
|
| 23 |
+
void gemm_fp16_int_bias_act(const half *A, const WeightType *B,
|
| 24 |
+
const half *weight_scales, const half *bias,
|
| 25 |
+
half *C, std::optional<std::string> activation, int m,
|
| 26 |
+
int n, int k, int bias_stride, char *workspace_ptr,
|
| 27 |
+
size_t workspace_bytes, cudaStream_t stream);
|
| 28 |
+
|
| 29 |
+
template <typename WeightType>
|
| 30 |
+
void gemm_fp16_int_bias_act_residual(
|
| 31 |
+
const half *A, const WeightType *B, const half *weight_scales,
|
| 32 |
+
const half *bias, const half *residual, half *C, const std::string& activation, const std::string& binary_op,
|
| 33 |
+
const std::string& unary_op, int m, int n, int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream);
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
} // namespace fastertransformer
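Only the uint8_t and cutlass::uint4b_t specializations of these templates are explicitly instantiated in fpA_intB_gemm.cu above, so an int4 caller passes its packed 4-bit weight buffer as cutlass::uint4b_t. A hedged sketch, assuming the packed buffer is already preprocessed into the interleaved layout and that bias_stride = 0 selects a single broadcast bias vector:

#include <cstdint>
#include <optional>
#include <string>
#include "fpA_intB_gemm.h"

namespace ft = fastertransformer;

// Sketch only: the packed_int4 layout and the bias_stride value are assumptions.
void example_gemm_fp16_int4_bias_gelu(const ft::half *A, const int8_t *packed_int4,
                                      const ft::half *scales, const ft::half *bias,
                                      ft::half *C, int m, int n, int k,
                                      char *workspace, size_t workspace_bytes,
                                      cudaStream_t stream)
{
    const ft::uint4b_t *B = reinterpret_cast<const ft::uint4b_t *>(packed_int4);
    ft::gemm_fp16_int_bias_act<ft::uint4b_t>(A, B, scales, bias, C,
                                             std::optional<std::string>("gelu"),
                                             m, n, k, /*bias_stride=*/0,
                                             workspace, workspace_bytes, stream);
}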
cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h
ADDED
@@ -0,0 +1,118 @@
/*
|
| 2 |
+
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*/
|
| 16 |
+
|
| 17 |
+
#pragma once
|
| 18 |
+
|
| 19 |
+
#include "cutlass_extensions/include/cutlass_extensions/ft_gemm_configs.h"
|
| 20 |
+
#include "utils/activation_types.h"
|
| 21 |
+
#include <cuda_runtime_api.h>
|
| 22 |
+
|
| 23 |
+
namespace fastertransformer {
|
| 24 |
+
|
| 25 |
+
/*
|
| 26 |
+
This runner only supports:
|
| 27 |
+
T in {half, __nv_bfloat16}; WeightType in {int8_t, cutlass::uint4b_t}
|
| 28 |
+
|
| 29 |
+
Activations, biases, scales and outputs are all assumed to be row-major.
|
| 30 |
+
|
| 31 |
+
However, it is assumed that B is in a special format governed by cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.
|
| 32 |
+
In this case, B must be preprocessed using the cutlass weight only quant preprocessors. The weight preprocessor
|
| 33 |
+
will instantiate the layout and preprocess based on the instantiation, so layout changes should only require
|
| 34 |
+
modifications to mixed_gemm_B_layout.h.
|
| 35 |
+
*/
|
| 36 |
+
|
| 37 |
+
template<typename T, typename WeightType>
|
| 38 |
+
class CutlassFpAIntBGemmRunner {
|
| 39 |
+
public:
|
| 40 |
+
CutlassFpAIntBGemmRunner();
|
| 41 |
+
~CutlassFpAIntBGemmRunner();
|
| 42 |
+
|
| 43 |
+
void gemm(const T* A,
|
| 44 |
+
const WeightType* B,
|
| 45 |
+
const T* weight_scales,
|
| 46 |
+
T* C,
|
| 47 |
+
int m,
|
| 48 |
+
int n,
|
| 49 |
+
int k,
|
| 50 |
+
char* workspace_ptr,
|
| 51 |
+
const size_t workspace_bytes,
|
| 52 |
+
cudaStream_t stream);
|
| 53 |
+
|
| 54 |
+
void gemm_bias_act(const T* A,
|
| 55 |
+
const WeightType* B,
|
| 56 |
+
const T* weight_scales,
|
| 57 |
+
const T* biases,
|
| 58 |
+
T* C,
|
| 59 |
+
int m,
|
| 60 |
+
int n,
|
| 61 |
+
int k,
|
| 62 |
+
int bias_stride,
|
| 63 |
+
ActivationType activation_type,
|
| 64 |
+
char* workspace_ptr,
|
| 65 |
+
const size_t workspace_bytes,
|
| 66 |
+
cudaStream_t stream);
|
| 67 |
+
|
| 68 |
+
void gemm_bias_act_residual(const T *A, const WeightType *B,
|
| 69 |
+
const T *weight_scales, const T *biases,
|
| 70 |
+
const T *residual, T *C, int m, int n, int k,
|
| 71 |
+
const std::string& activation, const std::string& binary_op,
|
| 72 |
+
const std::string& unary_op,
|
| 73 |
+
char *workspace_ptr,
|
| 74 |
+
const size_t workspace_bytes,
|
| 75 |
+
cudaStream_t stream);
|
| 76 |
+
|
| 77 |
+
// Returns desired workspace size in bytes.
|
| 78 |
+
int getWorkspaceSize(const int m, const int n, const int k);
|
| 79 |
+
|
| 80 |
+
private:
|
| 81 |
+
template<typename EpilogueTag>
|
| 82 |
+
void dispatch_to_arch(const T* A,
|
| 83 |
+
const WeightType* B,
|
| 84 |
+
const T* weight_scales,
|
| 85 |
+
const T* biases,
|
| 86 |
+
T* C,
|
| 87 |
+
int m,
|
| 88 |
+
int n,
|
| 89 |
+
int k,
|
| 90 |
+
int bias_stride,
|
| 91 |
+
CutlassGemmConfig gemm_config,
|
| 92 |
+
char* workspace_ptr,
|
| 93 |
+
const size_t workspace_bytes,
|
| 94 |
+
cudaStream_t stream,
|
| 95 |
+
int* occupancy = nullptr);
|
| 96 |
+
|
| 97 |
+
template<typename EpilogueTag>
|
| 98 |
+
void run_gemm(const T* A,
|
| 99 |
+
const WeightType* B,
|
| 100 |
+
const T* weight_scales,
|
| 101 |
+
const T* biases,
|
| 102 |
+
T* C,
|
| 103 |
+
int m,
|
| 104 |
+
int n,
|
| 105 |
+
int k,
|
| 106 |
+
int bias_stride,
|
| 107 |
+
char* workspace_ptr,
|
| 108 |
+
const size_t workspace_bytes,
|
| 109 |
+
cudaStream_t stream);
|
| 110 |
+
|
| 111 |
+
private:
|
| 112 |
+
static constexpr int split_k_limit = 7;
|
| 113 |
+
|
| 114 |
+
int sm_;
|
| 115 |
+
int multi_processor_count_;
|
| 116 |
+
};
|
| 117 |
+
|
| 118 |
+
} // namespace fastertransformer
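For callers that want activation fusion and explicit workspace management, the runner can also be driven directly. A minimal sketch, assuming device-resident operands, ActivationType living in the fastertransformer namespace, and a single bias vector; the includes and buffer names are illustrative:

#include <cstdint>
#include <cuda_runtime.h>
#include "cutlass/half.h"
#include "fpA_intB_gemm/fpA_intB_gemm.h"

namespace ft = fastertransformer;

// Sketch only: bias_stride = 0 (single broadcast bias vector) is an assumption.
void example_runner_bias_gelu(const cutlass::half_t *A, const uint8_t *B,
                              const cutlass::half_t *scales, const cutlass::half_t *bias,
                              cutlass::half_t *C, int m, int n, int k, cudaStream_t stream)
{
    ft::CutlassFpAIntBGemmRunner<cutlass::half_t, uint8_t> runner;
    const int workspace_bytes = runner.getWorkspaceSize(m, n, k);
    char *workspace = nullptr;
    cudaMalloc(reinterpret_cast<void **>(&workspace), workspace_bytes);
    runner.gemm_bias_act(A, B, scales, bias, C, m, n, k, /*bias_stride=*/0,
                         ft::ActivationType::Gelu, workspace, workspace_bytes, stream);
    cudaFree(workspace);
}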
cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h
ADDED
@@ -0,0 +1,858 @@
/*
|
| 2 |
+
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*/
|
| 16 |
+
|
| 17 |
+
#pragma GCC diagnostic push
|
| 18 |
+
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
|
| 19 |
+
|
| 20 |
+
#include "cutlass/gemm/device/gemm_universal_base.h"
|
| 21 |
+
#include "cutlass/gemm/kernel/default_gemm.h"
|
| 22 |
+
#include "cutlass/gemm/kernel/default_gemm_with_broadcast.h"
|
| 23 |
+
#include "cutlass/epilogue/thread/linear_combination_residual_block.h"
|
| 24 |
+
#include "cutlass_extensions/compute_occupancy.h"
|
| 25 |
+
|
| 26 |
+
#include "cutlass_extensions/epilogue_helpers.h"
|
| 27 |
+
#include "cutlass_extensions/ft_gemm_configs.h"
|
| 28 |
+
#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
|
| 29 |
+
#include "cutlass_extensions/gemm/kernel/fpA_intB_gemm.h"
|
| 30 |
+
#include "cutlass_extensions/gemm/kernel/fpA_intB_gemm_with_broadcast.h"
|
| 31 |
+
#include "cutlass_extensions/gemm/threadblock/default_mma.h"
|
| 32 |
+
|
| 33 |
+
#pragma GCC diagnostic pop
|
| 34 |
+
|
| 35 |
+
#include "../cutlass_heuristic.h"
|
| 36 |
+
#include "fpA_intB_gemm.h"
|
| 37 |
+
#include "cuda_utils.h"
|
| 38 |
+
|
| 39 |
+
namespace fastertransformer {
|
| 40 |
+
|
| 41 |
+
template <typename T,
|
| 42 |
+
typename WeightType,
|
| 43 |
+
typename arch,
|
| 44 |
+
typename EpilogueTag,
|
| 45 |
+
typename ThreadblockShape,
|
| 46 |
+
typename WarpShape,
|
| 47 |
+
int Stages>
|
| 48 |
+
void generic_mixed_gemm_kernelLauncher(const T *A,
|
| 49 |
+
const WeightType *B,
|
| 50 |
+
const T *weight_scales,
|
| 51 |
+
const T *biases,
|
| 52 |
+
T *C,
|
| 53 |
+
int m,
|
| 54 |
+
int n,
|
| 55 |
+
int k,
|
| 56 |
+
int bias_stride,
|
| 57 |
+
CutlassGemmConfig gemm_config,
|
| 58 |
+
char *workspace,
|
| 59 |
+
size_t workspace_bytes,
|
| 60 |
+
cudaStream_t stream,
|
| 61 |
+
int *occupancy = nullptr)
|
| 62 |
+
{
|
| 63 |
+
FT_LOG_DEBUG(__PRETTY_FUNCTION__);
|
| 64 |
+
static_assert(cutlass::platform::is_same<T, half>::value || cutlass::platform::is_same<T, float>::value,
|
| 65 |
+
"Specialized for half, float");
|
| 66 |
+
|
| 67 |
+
static_assert(cutlass::platform::is_same<T, WeightType>::value || cutlass::platform::is_same<WeightType, uint8_t>::value || cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value,
|
| 68 |
+
"");
|
| 69 |
+
|
| 70 |
+
// The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary.
|
| 71 |
+
using ElementType_ =
|
| 72 |
+
typename cutlass::platform::conditional<cutlass::platform::is_same<T, half>::value, cutlass::half_t, T>::type;
|
| 73 |
+
using ElementType = ElementType_;
|
| 74 |
+
|
| 75 |
+
using CutlassWeightType_ = typename cutlass::platform::
|
| 76 |
+
conditional<cutlass::platform::is_same<WeightType, half>::value, cutlass::half_t, WeightType>::type;
|
| 77 |
+
using CutlassWeightType = CutlassWeightType_;
|
| 78 |
+
|
| 79 |
+
// We need separate config for each architecture since we will target different tensorcore instructions. For float,
|
| 80 |
+
// we do not target TCs.
|
| 81 |
+
using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits<ElementType, CutlassWeightType, arch>;
|
| 82 |
+
using ElementAccumulator = typename MixedGemmArchTraits::AccType;
|
| 83 |
+
|
| 84 |
+
using EpilogueOp =
|
| 85 |
+
typename Epilogue<ElementType, MixedGemmArchTraits::ElementsPerAccessC, ElementAccumulator, EpilogueTag>::Op;
|
| 86 |
+
|
| 87 |
+
using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm<
|
| 88 |
+
ElementType,
|
| 89 |
+
cutlass::layout::RowMajor,
|
| 90 |
+
MixedGemmArchTraits::ElementsPerAccessA,
|
| 91 |
+
CutlassWeightType,
|
| 92 |
+
typename MixedGemmArchTraits::LayoutB,
|
| 93 |
+
MixedGemmArchTraits::ElementsPerAccessB,
|
| 94 |
+
ElementType,
|
| 95 |
+
cutlass::layout::RowMajor,
|
| 96 |
+
ElementAccumulator,
|
| 97 |
+
cutlass::arch::OpClassTensorOp,
|
| 98 |
+
arch,
|
| 99 |
+
ThreadblockShape,
|
| 100 |
+
WarpShape,
|
| 101 |
+
typename MixedGemmArchTraits::InstructionShape,
|
| 102 |
+
EpilogueOp,
|
| 103 |
+
typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
|
| 104 |
+
Stages,
|
| 105 |
+
true,
|
| 106 |
+
typename MixedGemmArchTraits::Operator>::GemmKernel;
|
| 107 |
+
|
| 108 |
+
using GemmKernel = cutlass::gemm::kernel::GemmFpAIntB<typename GemmKernel_::Mma,
|
| 109 |
+
typename GemmKernel_::Epilogue,
|
| 110 |
+
typename GemmKernel_::ThreadblockSwizzle,
|
| 111 |
+
arch, // Ensure top level arch is used for dispatch
|
| 112 |
+
GemmKernel_::kSplitKSerial>;
|
| 113 |
+
|
| 114 |
+
if (occupancy != nullptr)
|
| 115 |
+
{
|
| 116 |
+
*occupancy = compute_occupancy_for_kernel<GemmKernel>();
|
| 117 |
+
return;
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
using Gemm = cutlass::gemm::device::GemmUniversalBase<GemmKernel>;
|
| 121 |
+
|
| 122 |
+
const int ldb =
|
| 123 |
+
cutlass::platform::is_same<cutlass::layout::RowMajor, typename MixedGemmArchTraits::LayoutB>::value ? n : k * GemmKernel::kInterleave;
|
| 124 |
+
|
| 125 |
+
typename Gemm::Arguments args({m, n, k},
|
| 126 |
+
{reinterpret_cast<ElementType *>(const_cast<T *>(A)), k},
|
| 127 |
+
{reinterpret_cast<CutlassWeightType *>(const_cast<WeightType *>(B)), ldb},
|
| 128 |
+
{reinterpret_cast<ElementType *>(const_cast<T *>(weight_scales)), 0},
|
| 129 |
+
// TODO: Support more general bias shape
|
| 130 |
+
{reinterpret_cast<ElementType *>(const_cast<T *>(biases)), bias_stride},
|
| 131 |
+
{reinterpret_cast<ElementType *>(C), n},
|
| 132 |
+
gemm_config.split_k_factor,
|
| 133 |
+
{ElementAccumulator(1.f), ElementAccumulator(0.f)});
|
| 134 |
+
|
| 135 |
+
// This assertion is enabled because, for the column interleaved layout, K MUST be a multiple of
|
| 136 |
+
// threadblockK. The reason for this is that the default pitchlinear iterators are used to handle walking over the
|
| 137 |
+
// interleaved matrix. The way masking is handled in these does not map to the interleaved layout. We need to write
|
| 138 |
+
// our own predicated iterator in order to relax this limitation.
|
| 139 |
+
if (GemmKernel::kInterleave > 1 && ((k % MixedGemmArchTraits::ThreadblockK) || ((k / gemm_config.split_k_factor) % MixedGemmArchTraits::ThreadblockK)))
|
| 140 |
+
{
|
| 141 |
+
throw std::runtime_error("Temp assertion: k must be multiple of threadblockK");
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
Gemm gemm;
|
| 145 |
+
if (gemm.get_workspace_size(args) > workspace_bytes)
|
| 146 |
+
{
|
| 147 |
+
FT_LOG_WARNING(
|
| 148 |
+
"Requested split-k but workspace size insufficient. Falling back to non-split-k implementation.");
|
| 149 |
+
// If requested split-k factor will require more workspace bytes, revert to standard gemm.
|
| 150 |
+
args.batch_count = 1;
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
auto can_implement = gemm.can_implement(args);
|
| 154 |
+
if (can_implement != cutlass::Status::kSuccess)
|
| 155 |
+
{
|
| 156 |
+
std::string err_msg = "fpA_intB cutlass kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement));
|
| 157 |
+
throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg);
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
auto init_status = gemm.initialize(args, workspace, stream);
|
| 161 |
+
if (init_status != cutlass::Status::kSuccess)
|
| 162 |
+
{
|
| 163 |
+
std::string err_msg =
|
| 164 |
+
"Failed to initialize cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(init_status));
|
| 165 |
+
throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg);
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
auto run_status = gemm.run(stream);
|
| 169 |
+
if (run_status != cutlass::Status::kSuccess)
|
| 170 |
+
{
|
| 171 |
+
std::string err_msg =
|
| 172 |
+
"Failed to run cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(run_status));
|
| 173 |
+
throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg);
|
| 174 |
+
}
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
template<typename T,
|
| 178 |
+
typename WeightType,
|
| 179 |
+
typename arch,
|
| 180 |
+
typename EpilogueTag,
|
| 181 |
+
typename ThreadblockShape,
|
| 182 |
+
typename WarpShape,
|
| 183 |
+
int Stages,
|
| 184 |
+
typename Enable = void>
|
| 185 |
+
struct dispatch_stages {
|
| 186 |
+
static void dispatch(const T *A,
|
| 187 |
+
const WeightType *B,
|
| 188 |
+
const T *weight_scales,
|
| 189 |
+
const T *biases,
|
| 190 |
+
T *C,
|
| 191 |
+
int m,
|
| 192 |
+
int n,
|
| 193 |
+
int k,
|
| 194 |
+
int bias_stride,
|
| 195 |
+
CutlassGemmConfig gemm_config,
|
| 196 |
+
char *workspace,
|
| 197 |
+
size_t workspace_bytes,
|
| 198 |
+
cudaStream_t stream,
|
| 199 |
+
int *occupancy = nullptr)
|
| 200 |
+
{
|
| 201 |
+
|
| 202 |
+
FT_LOG_DEBUG(__PRETTY_FUNCTION__);
|
| 203 |
+
std::string err_msg = "Cutlass fpA_intB gemm not instantiated for arch " + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages);
|
| 204 |
+
throw std::runtime_error("[FT Error][dispatch_stages::dispatch] " + err_msg);
|
| 205 |
+
}
|
| 206 |
+
};
|
| 207 |
+
|
| 208 |
+
template<typename T,
|
| 209 |
+
typename WeightType,
|
| 210 |
+
typename arch,
|
| 211 |
+
typename EpilogueTag,
|
| 212 |
+
typename ThreadblockShape,
|
| 213 |
+
typename WarpShape>
|
| 214 |
+
struct dispatch_stages<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 2> {
|
| 215 |
+
static void dispatch(const T *A,
|
| 216 |
+
const WeightType *B,
|
| 217 |
+
const T *weight_scales,
|
| 218 |
+
const T *biases,
|
| 219 |
+
T *C,
|
| 220 |
+
int m,
|
| 221 |
+
int n,
|
| 222 |
+
int k,
|
| 223 |
+
int bias_stride,
|
| 224 |
+
CutlassGemmConfig gemm_config,
|
| 225 |
+
char *workspace,
|
| 226 |
+
size_t workspace_bytes,
|
| 227 |
+
cudaStream_t stream,
|
| 228 |
+
int *occupancy = nullptr)
|
| 229 |
+
{
|
| 230 |
+
|
| 231 |
+
FT_LOG_DEBUG(__PRETTY_FUNCTION__);
|
| 232 |
+
generic_mixed_gemm_kernelLauncher<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 2>(
|
| 233 |
+
A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
| 234 |
+
}
|
| 235 |
+
};
|
| 236 |
+
|
| 237 |
+
template<typename T,
|
| 238 |
+
typename WeightType,
|
| 239 |
+
typename EpilogueTag,
|
| 240 |
+
typename ThreadblockShape,
|
| 241 |
+
typename WarpShape,
|
| 242 |
+
int Stages>
|
| 243 |
+
struct dispatch_stages<T,
|
| 244 |
+
WeightType,
|
| 245 |
+
cutlass::arch::Sm80,
|
| 246 |
+
EpilogueTag,
|
| 247 |
+
ThreadblockShape,
|
| 248 |
+
WarpShape,
|
| 249 |
+
Stages,
|
| 250 |
+
typename std::enable_if<(Stages > 2)>::type> {
|
| 251 |
+
static void dispatch(const T *A,
|
| 252 |
+
const WeightType *B,
|
| 253 |
+
const T *weight_scales,
|
| 254 |
+
const T *biases,
|
| 255 |
+
T *C,
|
| 256 |
+
int m,
|
| 257 |
+
int n,
|
| 258 |
+
int k,
|
| 259 |
+
int bias_stride,
|
| 260 |
+
CutlassGemmConfig gemm_config,
|
| 261 |
+
char *workspace,
|
| 262 |
+
size_t workspace_bytes,
|
| 263 |
+
cudaStream_t stream,
|
| 264 |
+
int *occupancy = nullptr)
|
| 265 |
+
{
|
| 266 |
+
|
| 267 |
+
FT_LOG_DEBUG(__PRETTY_FUNCTION__);
|
| 268 |
+
generic_mixed_gemm_kernelLauncher<T,
|
| 269 |
+
WeightType,
|
| 270 |
+
cutlass::arch::Sm80,
|
| 271 |
+
EpilogueTag,
|
| 272 |
+
ThreadblockShape,
|
| 273 |
+
WarpShape,
|
| 274 |
+
Stages>(
|
| 275 |
+
A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
| 276 |
+
}
|
| 277 |
+
};
|
| 278 |
+
|
| 279 |
+
template <typename T,
|
| 280 |
+
typename WeightType,
|
| 281 |
+
typename arch,
|
| 282 |
+
typename EpilogueTag,
|
| 283 |
+
typename ThreadblockShape,
|
| 284 |
+
typename WarpShape>
|
| 285 |
+
void dispatch_gemm_config(const T *A,
|
| 286 |
+
const WeightType *B,
|
| 287 |
+
const T *weight_scales,
|
| 288 |
+
const T *biases,
|
| 289 |
+
T *C,
|
| 290 |
+
int m,
|
| 291 |
+
int n,
|
| 292 |
+
int k,
|
| 293 |
+
int bias_stride,
|
| 294 |
+
CutlassGemmConfig gemm_config,
|
| 295 |
+
char *workspace,
|
| 296 |
+
size_t workspace_bytes,
|
| 297 |
+
cudaStream_t stream,
|
| 298 |
+
int *occupancy = nullptr)
|
| 299 |
+
{
|
| 300 |
+
|
| 301 |
+
FT_LOG_DEBUG(__PRETTY_FUNCTION__);
|
| 302 |
+
switch (gemm_config.stages) {
|
| 303 |
+
case 2:
|
| 304 |
+
using DispatcherStages2 = dispatch_stages<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 2>;
|
| 305 |
+
DispatcherStages2::dispatch(
|
| 306 |
+
A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
| 307 |
+
break;
|
| 308 |
+
case 3:
|
| 309 |
+
using DispatcherStages3 = dispatch_stages<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 3>;
|
| 310 |
+
DispatcherStages3::dispatch(
|
| 311 |
+
A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
| 312 |
+
break;
|
| 313 |
+
case 4:
|
| 314 |
+
using DispatcherStages4 = dispatch_stages<T, WeightType, arch, EpilogueTag, ThreadblockShape, WarpShape, 4>;
|
| 315 |
+
DispatcherStages4::dispatch(
|
| 316 |
+
A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
| 317 |
+
break;
|
| 318 |
+
default:
|
| 319 |
+
std::string err_msg = "dispatch_gemm_config does not support stages " + std::to_string(gemm_config.stages);
|
| 320 |
+
throw std::runtime_error("[FT Error][dispatch_gemm_config] " + err_msg);
|
| 321 |
+
break;
|
| 322 |
+
}
|
| 323 |
+
}
|
| 324 |
+
|
| 325 |
+
template <typename T, typename WeightType, typename arch, typename EpilogueTag>
|
| 326 |
+
void dispatch_gemm_to_cutlass(const T *A,
|
| 327 |
+
const WeightType *B,
|
| 328 |
+
const T *weight_scales,
|
| 329 |
+
const T *biases,
|
| 330 |
+
T *C,
|
| 331 |
+
int m,
|
| 332 |
+
int n,
|
| 333 |
+
int k,
|
| 334 |
+
int bias_stride,
|
| 335 |
+
char *workspace,
|
| 336 |
+
size_t workspace_bytes,
|
| 337 |
+
CutlassGemmConfig gemm_config,
|
| 338 |
+
cudaStream_t stream,
|
| 339 |
+
int *occupancy = nullptr)
|
| 340 |
+
{
|
| 341 |
+
|
| 342 |
+
FT_LOG_DEBUG(__PRETTY_FUNCTION__);
|
| 343 |
+
|
| 344 |
+
// Note that SIMT configs are omitted here since they are not supported for fpA_intB.
|
| 345 |
+
// We also only instantiate configs here where threadblockShapeM == warpShapeM since those usually perform the best
|
| 346 |
+
// for mixed type gemms.
|
| 347 |
+
switch (gemm_config.tile_config) {
|
| 348 |
+
case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
|
| 349 |
+
dispatch_gemm_config<T,
|
| 350 |
+
WeightType,
|
| 351 |
+
arch,
|
| 352 |
+
EpilogueTag,
|
| 353 |
+
cutlass::gemm::GemmShape<32, 128, 64>,
|
| 354 |
+
cutlass::gemm::GemmShape<32, 32, 64>>(
|
| 355 |
+
A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
| 356 |
+
break;
|
| 357 |
+
case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
|
| 358 |
+
dispatch_gemm_config<T,
|
| 359 |
+
WeightType,
|
| 360 |
+
arch,
|
| 361 |
+
EpilogueTag,
|
| 362 |
+
cutlass::gemm::GemmShape<64, 128, 64>,
|
| 363 |
+
cutlass::gemm::GemmShape<64, 32, 64>>(
|
| 364 |
+
A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
| 365 |
+
break;
|
| 366 |
+
case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
|
| 367 |
+
dispatch_gemm_config<T,
|
| 368 |
+
WeightType,
|
| 369 |
+
arch,
|
| 370 |
+
EpilogueTag,
|
| 371 |
+
cutlass::gemm::GemmShape<128, 128, 64>,
|
| 372 |
+
cutlass::gemm::GemmShape<128, 32, 64>>(
|
| 373 |
+
A, B, weight_scales, biases, C, m, n, k, bias_stride, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
| 374 |
+
break;
|
| 375 |
+
case CutlassTileConfig::Undefined:
|
| 376 |
+
throw std::runtime_error("[FT Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config undefined.");
|
| 377 |
+
break;
|
| 378 |
+
case CutlassTileConfig::ChooseWithHeuristic:
|
| 379 |
+
throw std::runtime_error(
|
| 380 |
+
"[FT Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config should have already been set by heuristic.");
|
| 381 |
+
break;
|
| 382 |
+
default:
|
| 383 |
+
throw std::runtime_error(
|
| 384 |
+
"[FT Error][fpA_intB][dispatch_gemm_to_cutlass] Config is invalid for mixed type GEMM.");
|
| 385 |
+
break;
|
| 386 |
+
}
|
| 387 |
+
}
|
| 388 |
+
|
| 389 |
+
template<typename T, typename WeightType>
|
| 390 |
+
CutlassFpAIntBGemmRunner<T, WeightType>::CutlassFpAIntBGemmRunner()
|
| 391 |
+
{
|
| 392 |
+
FT_LOG_DEBUG(__PRETTY_FUNCTION__);
|
| 393 |
+
int device{-1};
|
| 394 |
+
check_cuda_error(cudaGetDevice(&device));
|
| 395 |
+
sm_ = getSMVersion();
|
| 396 |
+
check_cuda_error(cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device));
|
| 397 |
+
}
|
| 398 |
+
|
| 399 |
+
template<typename T, typename WeightType>
|
| 400 |
+
CutlassFpAIntBGemmRunner<T, WeightType>::~CutlassFpAIntBGemmRunner()
|
| 401 |
+
{
|
| 402 |
+
FT_LOG_DEBUG(__PRETTY_FUNCTION__);
|
| 403 |
+
}
|
| 404 |
+
|
| 405 |
+
template<typename T, typename WeightType>
|
| 406 |
+
template<typename EpilogueTag>
|
| 407 |
+
void CutlassFpAIntBGemmRunner<T, WeightType>::dispatch_to_arch<EpilogueTag>(const T* A,
|
| 408 |
+
const WeightType* B,
|
| 409 |
+
const T* weight_scales,
|
| 410 |
+
const T* biases,
|
| 411 |
+
T* C,
|
| 412 |
+
int m,
|
| 413 |
+
int n,
|
| 414 |
+
int k,
|
| 415 |
+
int bias_stride,
|
| 416 |
+
CutlassGemmConfig gemm_config,
|
| 417 |
+
char* workspace_ptr,
|
| 418 |
+
const size_t workspace_bytes,
|
| 419 |
+
cudaStream_t stream,
|
| 420 |
+
int* occupancy)
|
| 421 |
+
{
|
| 422 |
+
FT_LOG_DEBUG(__PRETTY_FUNCTION__);
|
| 423 |
+
if (sm_ >= 70 && sm_ < 75) {
|
| 424 |
+
dispatch_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm70, EpilogueTag>(
|
| 425 |
+
A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy);
|
| 426 |
+
} else if (sm_ >= 75 && sm_ < 80) {
|
| 427 |
+
dispatch_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm75, EpilogueTag>(
|
| 428 |
+
A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy);
|
| 429 |
+
} else if (sm_ >= 80 && sm_ < 90) {
|
| 430 |
+
dispatch_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm80, EpilogueTag>(
|
| 431 |
+
A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy);
|
| 432 |
+
}
|
| 433 |
+
else {
|
| 434 |
+
throw std::runtime_error(
|
| 435 |
+
"[FT Error][CutlassFpAIntBGemmRunner][GEMM Dispatch] Arch unsupported for CUTLASS mixed type GEMM");
|
| 436 |
+
}
|
| 437 |
+
}
|
| 438 |
+
|
| 439 |
+
template<typename T, typename WeightType>
|
| 440 |
+
template<typename EpilogueTag>
|
| 441 |
+
void CutlassFpAIntBGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(const T* A,
|
| 442 |
+
const WeightType* B,
|
| 443 |
+
const T* weight_scales,
|
| 444 |
+
const T* biases,
|
| 445 |
+
T* C,
|
| 446 |
+
int m,
|
| 447 |
+
int n,
|
| 448 |
+
int k,
|
| 449 |
+
int bias_stride,
|
| 450 |
+
char* workspace_ptr,
|
| 451 |
+
const size_t workspace_bytes,
|
| 452 |
+
cudaStream_t stream)
|
| 453 |
+
{
|
| 454 |
+
FT_LOG_DEBUG(__PRETTY_FUNCTION__);
|
| 455 |
+
static constexpr bool is_weight_only = !std::is_same<T, WeightType>::value;
|
| 456 |
+
std::vector<CutlassGemmConfig> candidate_configs = get_candidate_configs(sm_, is_weight_only, false);
|
| 457 |
+
std::vector<int> occupancies(candidate_configs.size());
|
| 458 |
+
|
| 459 |
+
for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {
|
| 460 |
+
dispatch_to_arch<EpilogueTag>(A,
|
| 461 |
+
B,
|
| 462 |
+
weight_scales,
|
| 463 |
+
biases,
|
| 464 |
+
C,
|
| 465 |
+
m,
|
| 466 |
+
n,
|
| 467 |
+
k,
|
| 468 |
+
bias_stride,
|
| 469 |
+
candidate_configs[ii],
|
| 470 |
+
workspace_ptr,
|
| 471 |
+
workspace_bytes,
|
| 472 |
+
stream,
|
| 473 |
+
&occupancies[ii]);
|
| 474 |
+
}
|
| 475 |
+
// Standard GEMM, so 1 "expert". We use the same function for MoE and regular FFN.
|
| 476 |
+
static constexpr int num_experts = 1;
|
| 477 |
+
CutlassGemmConfig chosen_config = estimate_best_config_from_occupancies(candidate_configs,
|
| 478 |
+
occupancies,
|
| 479 |
+
m,
|
| 480 |
+
n,
|
| 481 |
+
k,
|
| 482 |
+
num_experts,
|
| 483 |
+
split_k_limit,
|
| 484 |
+
workspace_bytes,
|
| 485 |
+
multi_processor_count_,
|
| 486 |
+
is_weight_only);
|
| 487 |
+
|
| 488 |
+
dispatch_to_arch<EpilogueTag>(
|
| 489 |
+
A, B, weight_scales, biases, C, m, n, k, bias_stride, chosen_config, workspace_ptr, workspace_bytes, stream);
|
| 490 |
+
}
|
| 491 |
+
|
| 492 |
+
template <typename T, typename WeightType>
|
| 493 |
+
void CutlassFpAIntBGemmRunner<T, WeightType>::gemm_bias_act(const T *A,
|
| 494 |
+
const WeightType *B,
|
| 495 |
+
const T *weight_scales,
|
| 496 |
+
const T *biases,
|
| 497 |
+
T *C,
|
| 498 |
+
int m,
|
| 499 |
+
int n,
|
| 500 |
+
int k,
|
| 501 |
+
int bias_stride,
|
| 502 |
+
ActivationType activation_type,
|
| 503 |
+
char *workspace_ptr,
|
| 504 |
+
const size_t workspace_bytes,
|
| 505 |
+
cudaStream_t stream)
|
| 506 |
+
{
|
| 507 |
+
FT_LOG_DEBUG(__PRETTY_FUNCTION__);
|
| 508 |
+
|
| 509 |
+
switch (activation_type) {
|
| 510 |
+
case ActivationType::Relu:
|
| 511 |
+
run_gemm<EpilogueOpBiasReLU>(
|
| 512 |
+
A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, stream);
|
| 513 |
+
break;
|
| 514 |
+
case ActivationType::Gelu:
|
| 515 |
+
run_gemm<EpilogueOpBiasFtGelu>(
|
| 516 |
+
A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, stream);
|
| 517 |
+
break;
|
| 518 |
+
case ActivationType::Silu:
|
| 519 |
+
run_gemm<EpilogueOpBiasSilu>(
|
| 520 |
+
A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, stream);
|
| 521 |
+
break;
|
| 522 |
+
case ActivationType::Identity:
|
| 523 |
+
run_gemm<EpilogueOpBias>(A, B, weight_scales, biases, C, m, n, k, bias_stride, workspace_ptr, workspace_bytes, stream);
|
| 524 |
+
break;
|
| 525 |
+
case ActivationType::InvalidType:
|
| 526 |
+
FT_CHECK_WITH_INFO(false, "Activation type for fpA_intB must be valid.");
|
| 527 |
+
break;
|
| 528 |
+
default: {
|
| 529 |
+
if (isGatedActivation(activation_type)) {
|
| 530 |
+
FT_CHECK_WITH_INFO(false, "Fused gated activations not supported");
|
| 531 |
+
}
|
| 532 |
+
else {
|
| 533 |
+
FT_CHECK_WITH_INFO(false, "Invalid activation type.");
|
| 534 |
+
}
|
| 535 |
+
}
|
| 536 |
+
}
|
| 537 |
+
}
|
| 538 |
+
|
| 539 |
+
template<typename T, typename WeightType>
|
| 540 |
+
void CutlassFpAIntBGemmRunner<T, WeightType>::gemm(const T* A,
|
| 541 |
+
const WeightType* B,
|
| 542 |
+
const T* weight_scales,
|
| 543 |
+
T* C,
|
| 544 |
+
int m,
|
| 545 |
+
int n,
|
| 546 |
+
int k,
|
| 547 |
+
char* workspace_ptr,
|
| 548 |
+
const size_t workspace_bytes,
|
| 549 |
+
cudaStream_t stream)
|
| 550 |
+
{
|
| 551 |
+
FT_LOG_DEBUG(__PRETTY_FUNCTION__);
|
| 552 |
+
run_gemm<EpilogueOpNoBias>(A, B, weight_scales, nullptr, C, m, n, k, 0, workspace_ptr, workspace_bytes, stream);
|
| 553 |
+
}
|
| 554 |
+
|
| 555 |
+
template <typename T, typename WeightType, typename Arch,
|
| 556 |
+
typename ThreadblockShape, typename WarpShape, typename EpilogueOp,
|
| 557 |
+
int stages>
|
| 558 |
+
void dispatch_gemm_residual(const T *A, const WeightType *B,
|
| 559 |
+
const T *weight_scales, const T *biases,
|
| 560 |
+
const T *residual, T *C, int m, int n, int k,
|
| 561 |
+
char *workspace_ptr, const size_t workspace_bytes,
|
| 562 |
+
cudaStream_t stream) {
|
| 563 |
+
using ElementType = typename cutlass::platform::conditional<
|
| 564 |
+
cutlass::platform::is_same<T, half>::value, cutlass::half_t, T>::type;
|
| 565 |
+
using ElementOutput = ElementType;
|
| 566 |
+
|
| 567 |
+
using MixedGemmArchTraits =
|
| 568 |
+
cutlass::gemm::kernel::MixedGemmArchTraits<ElementType, WeightType, Arch>;
|
| 569 |
+
using ElementAccumulator = typename EpilogueOp::ElementAccumulator;
|
| 570 |
+
|
| 571 |
+
using Swizzle =
|
| 572 |
+
typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
|
| 573 |
+
using InstructionShape = typename MixedGemmArchTraits::InstructionShape;
|
| 574 |
+
|
| 575 |
+
using Epilogue = typename cutlass::gemm::kernel::DefaultGemmWithBroadcast<
|
| 576 |
+
ElementType, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone,
|
| 577 |
+
MixedGemmArchTraits::ElementsPerAccessA, WeightType,
|
| 578 |
+
typename MixedGemmArchTraits::LayoutB, cutlass::ComplexTransform::kNone,
|
| 579 |
+
MixedGemmArchTraits::ElementsPerAccessB, ElementType,
|
| 580 |
+
cutlass::layout::RowMajor, ElementAccumulator,
|
| 581 |
+
cutlass::arch::OpClassTensorOp, Arch, ThreadblockShape, WarpShape,
|
| 582 |
+
InstructionShape, EpilogueOp, Swizzle, stages,
|
| 583 |
+
typename MixedGemmArchTraits::Operator>::Epilogue;
|
| 584 |
+
|
| 585 |
+
using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm<
|
| 586 |
+
ElementType, cutlass::layout::RowMajor,
|
| 587 |
+
MixedGemmArchTraits::ElementsPerAccessA, WeightType,
|
| 588 |
+
typename MixedGemmArchTraits::LayoutB,
|
| 589 |
+
MixedGemmArchTraits::ElementsPerAccessB, ElementType,
|
| 590 |
+
cutlass::layout::RowMajor, ElementAccumulator,
|
| 591 |
+
cutlass::arch::OpClassTensorOp, Arch, ThreadblockShape, WarpShape,
|
| 592 |
+
InstructionShape, EpilogueOp, Swizzle, stages, true,
|
| 593 |
+
typename MixedGemmArchTraits::Operator>::GemmKernel;
|
| 594 |
+
|
| 595 |
+
using GemmKernel = cutlass::gemm::kernel::GemmFpAIntBWithBroadcast<
|
| 596 |
+
typename GemmKernel_::Mma, Epilogue,
|
| 597 |
+
typename GemmKernel_::ThreadblockSwizzle, Arch>;
|
| 598 |
+
|
| 599 |
+
using Gemm = cutlass::gemm::device::GemmUniversalBase<GemmKernel>;
|
| 600 |
+
|
| 601 |
+
// TODO: Support batch
|
| 602 |
+
const int batch_count = 1;
|
| 603 |
+
const auto lda = k;
|
| 604 |
+
const int ldb =
|
| 605 |
+
cutlass::platform::is_same<cutlass::layout::RowMajor,
|
| 606 |
+
typename MixedGemmArchTraits::LayoutB>::value
|
| 607 |
+
? n
|
| 608 |
+
: k * GemmKernel::kInterleave;
|
| 609 |
+
const int ldc = n;
|
| 610 |
+
|
| 611 |
+
typename Gemm::Arguments args(
|
| 612 |
+
{m, n, k}, batch_count,
|
| 613 |
+
{ElementAccumulator(1.f), ElementAccumulator(1.f)}, A, B, weight_scales,
|
| 614 |
+
residual, C, biases, nullptr, 0, 0, 0, 0, 0, 0, lda, ldb, ldc, ldc, 0, 0);
|
| 615 |
+
|
| 616 |
+
if (GemmKernel::kInterleave > 1 &&
|
| 617 |
+
(k % MixedGemmArchTraits::ThreadblockK)) {
|
| 619 |
+
throw std::runtime_error(
|
| 620 |
+
"Temp assertion: k must be multiple of threadblockK");
|
| 621 |
+
}
|
| 622 |
+
|
| 623 |
+
Gemm gemm;
|
| 624 |
+
auto can_implement = gemm.can_implement(args);
|
| 625 |
+
if (can_implement != cutlass::Status::kSuccess) {
|
| 626 |
+
std::string err_msg =
|
| 627 |
+
"fpA_intB cutlass kernel will fail for params. Error: " +
|
| 628 |
+
std::string(cutlassGetStatusString(can_implement));
|
| 629 |
+
throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg);
|
| 630 |
+
}
|
| 631 |
+
|
| 632 |
+
auto init_status = gemm.initialize(args, workspace_ptr, stream);
|
| 633 |
+
if (init_status != cutlass::Status::kSuccess) {
|
| 634 |
+
std::string err_msg =
|
| 635 |
+
"Failed to initialize cutlass fpA_intB gemm. Error: " +
|
| 636 |
+
std::string(cutlassGetStatusString(init_status));
|
| 637 |
+
throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg);
|
| 638 |
+
}
|
| 639 |
+
|
| 640 |
+
auto run_status = gemm.run(stream);
|
| 641 |
+
if (run_status != cutlass::Status::kSuccess) {
|
| 642 |
+
std::string err_msg = "Failed to run cutlass fpA_intB gemm. Error: " +
|
| 643 |
+
std::string(cutlassGetStatusString(run_status));
|
| 644 |
+
throw std::runtime_error("[FT Error][fpA_intB Runner] " + err_msg);
|
| 645 |
+
}
|
| 646 |
+
}
|
| 647 |
+
|
| 648 |
+
template <typename T, typename WeightType, typename Arch, typename EpilogueOp,
|
| 649 |
+
int stages>
|
| 650 |
+
void dispatch_gemm_residual(CutlassTileConfig tile_config, const T *A,
|
| 651 |
+
const WeightType *B, const T *weight_scales,
|
| 652 |
+
const T *biases, const T *residual, T *C, int m,
|
| 653 |
+
int n, int k, char *workspace_ptr,
|
| 654 |
+
const size_t workspace_bytes, cudaStream_t stream) {
|
| 655 |
+
if (tile_config == CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64) {
|
| 656 |
+
dispatch_gemm_residual<
|
| 657 |
+
T, WeightType, Arch, cutlass::gemm::GemmShape<32, 128, 64>,
|
| 658 |
+
cutlass::gemm::GemmShape<32, 32, 64>, EpilogueOp, stages>(
|
| 659 |
+
A, B, weight_scales, biases, residual, C, m, n, k, workspace_ptr,
|
| 660 |
+
workspace_bytes, stream);
|
| 661 |
+
} else if (tile_config ==
|
| 662 |
+
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64) {
|
| 663 |
+
dispatch_gemm_residual<
|
| 664 |
+
T, WeightType, Arch, cutlass::gemm::GemmShape<64, 128, 64>,
|
| 665 |
+
cutlass::gemm::GemmShape<64, 32, 64>, EpilogueOp, stages>(
|
| 666 |
+
A, B, weight_scales, biases, residual, C, m, n, k, workspace_ptr,
|
| 667 |
+
workspace_bytes, stream);
|
| 668 |
+
} else { // CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
|
| 669 |
+
dispatch_gemm_residual<
|
| 670 |
+
T, WeightType, Arch, cutlass::gemm::GemmShape<128, 128, 64>,
|
| 671 |
+
cutlass::gemm::GemmShape<128, 32, 64>, EpilogueOp, stages>(
|
| 672 |
+
A, B, weight_scales, biases, residual, C, m, n, k, workspace_ptr,
|
| 673 |
+
workspace_bytes, stream);
|
| 674 |
+
}
|
| 675 |
+
}
|
| 676 |
+
|
| 677 |
+
template <typename T, typename WeightType, typename Arch, typename EpilogueOp>
|
| 678 |
+
void dispatch_gemm_residual(CutlassGemmConfig config, const T *A,
|
| 679 |
+
const WeightType *B, const T *weight_scales,
|
| 680 |
+
const T *biases, const T *residual, T *C, int m,
|
| 681 |
+
int n, int k, char *workspace_ptr,
|
| 682 |
+
const size_t workspace_bytes, cudaStream_t stream) {
|
| 683 |
+
if constexpr (std::is_same<Arch, cutlass::arch::Sm75>::value) {
|
| 684 |
+
dispatch_gemm_residual<T, WeightType, cutlass::arch::Sm75, EpilogueOp, 2>(
|
| 685 |
+
config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k,
|
| 686 |
+
workspace_ptr, workspace_bytes, stream);
|
| 687 |
+
} else if constexpr (std::is_same<Arch, cutlass::arch::Sm70>::value) {
|
| 688 |
+
dispatch_gemm_residual<T, WeightType, cutlass::arch::Sm70, EpilogueOp, 2>(
|
| 689 |
+
config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k,
|
| 690 |
+
workspace_ptr, workspace_bytes, stream);
|
| 691 |
+
} else {
|
| 692 |
+
if (config.stages == 3) {
|
| 693 |
+
dispatch_gemm_residual<T, WeightType, Arch, EpilogueOp, 3>(
|
| 694 |
+
config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k,
|
| 695 |
+
workspace_ptr, workspace_bytes, stream);
|
| 696 |
+
} else if (config.stages == 4) {
|
| 697 |
+
dispatch_gemm_residual<T, WeightType, Arch, EpilogueOp, 4>(
|
| 698 |
+
config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k,
|
| 699 |
+
workspace_ptr, workspace_bytes, stream);
|
| 700 |
+
} else { // 2
|
| 701 |
+
dispatch_gemm_residual<T, WeightType, Arch, EpilogueOp, 2>(
|
| 702 |
+
config.tile_config, A, B, weight_scales, biases, residual, C, m, n, k,
|
| 703 |
+
workspace_ptr, workspace_bytes, stream);
|
| 704 |
+
}
|
| 705 |
+
}
|
| 706 |
+
}
|
| 707 |
+
|
| 708 |
+
template <typename T, typename WeightType, typename Arch,
|
| 709 |
+
template <typename T_> class ActivationOp,
|
| 710 |
+
template <typename T_> class BinaryOp>
|
| 711 |
+
inline void
|
| 712 |
+
dispatch_gemm_residual(CutlassGemmConfig config, const T *A,
|
| 713 |
+
const WeightType *B, const T *weight_scales,
|
| 714 |
+
const T *biases, const T *residual, T *C, int m, int n,
|
| 715 |
+
int k, const std::string &unary_op, char *workspace_ptr,
|
| 716 |
+
const size_t workspace_bytes, cudaStream_t stream) {
|
| 717 |
+
using ElementOutput = T;
|
| 718 |
+
using MixedGemmArchTraits =
|
| 719 |
+
cutlass::gemm::kernel::MixedGemmArchTraits<T, WeightType, Arch>;
|
| 720 |
+
using ElementAccumulator = typename MixedGemmArchTraits::AccType;
|
| 721 |
+
|
| 722 |
+
if (unary_op == "identity") {
|
| 723 |
+
using EpilogueOp =
|
| 724 |
+
cutlass::epilogue::thread::LinearCombinationResidualBlock<
|
| 725 |
+
ElementOutput, ElementAccumulator, ElementAccumulator,
|
| 726 |
+
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
| 727 |
+
ActivationOp, BinaryOp, cutlass::epilogue::thread::Identity>;
|
| 728 |
+
dispatch_gemm_residual<T, WeightType, Arch, EpilogueOp>(
|
| 729 |
+
config, A, B, weight_scales, biases, residual, C, m, n, k,
|
| 730 |
+
workspace_ptr, workspace_bytes, stream);
|
| 731 |
+
} else if (unary_op == "relu") {
|
| 732 |
+
using EpilogueOp =
|
| 733 |
+
cutlass::epilogue::thread::LinearCombinationResidualBlock<
|
| 734 |
+
ElementOutput, ElementAccumulator, ElementAccumulator,
|
| 735 |
+
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
|
| 736 |
+
ActivationOp, BinaryOp, cutlass::epilogue::thread::ReLu>;
|
| 737 |
+
dispatch_gemm_residual<T, WeightType, Arch, EpilogueOp>(
|
| 738 |
+
config, A, B, weight_scales, biases, residual, C, m, n, k,
|
| 739 |
+
workspace_ptr, workspace_bytes, stream);
|
| 740 |
+
} else {
|
| 741 |
+
throw std::runtime_error(
|
| 742 |
+
"[FT Error][Unsupported unary op after residual block] " + unary_op);
|
| 743 |
+
}
|
| 744 |
+
}
|
| 745 |
+
|
| 746 |
+
template <typename T, typename WeightType, typename Arch,
|
| 747 |
+
template <typename T_> class ActivationOp>
|
| 748 |
+
void dispatch_gemm_residual(CutlassGemmConfig config, const T *A,
|
| 749 |
+
const WeightType *B, const T *weight_scales,
|
| 750 |
+
const T *biases, const T *residual, T *C, int m,
|
| 751 |
+
int n, int k, const std::string &binary_op,
|
| 752 |
+
const std::string &unary_op, char *workspace_ptr,
|
| 753 |
+
const size_t workspace_bytes, cudaStream_t stream) {
|
| 754 |
+
if (binary_op == "plus") {
|
| 755 |
+
dispatch_gemm_residual<T, WeightType, Arch, ActivationOp, cutlass::plus>(
|
| 756 |
+
config, A, B, weight_scales, biases, residual, C, m, n, k, unary_op,
|
| 757 |
+
workspace_ptr, workspace_bytes, stream);
|
| 758 |
+
} else if (binary_op == "multiply") {
|
| 759 |
+
dispatch_gemm_residual<T, WeightType, Arch, ActivationOp,
|
| 760 |
+
cutlass::multiplies>(
|
| 761 |
+
config, A, B, weight_scales, biases, residual, C, m, n, k, unary_op,
|
| 762 |
+
workspace_ptr, workspace_bytes, stream);
|
| 763 |
+
} else {
|
| 764 |
+
throw std::runtime_error(
|
| 765 |
+
"[FT Error][Unsupported binary op for residual block] " + binary_op);
|
| 766 |
+
}
|
| 767 |
+
}
|
| 768 |
+
|
| 769 |
+
template <typename T, typename WeightType, typename Arch>
|
| 770 |
+
void dispatch_gemm_residual(CutlassGemmConfig config, const T *A,
|
| 771 |
+
const WeightType *B, const T *weight_scales,
|
| 772 |
+
const T *biases, const T *residual, T *C, int m,
|
| 773 |
+
int n, int k, const std::string &activation,
|
| 774 |
+
const std::string &binary_op,
|
| 775 |
+
const std::string &unary_op, char *workspace_ptr,
|
| 776 |
+
const size_t workspace_bytes, cudaStream_t stream) {
|
| 777 |
+
if (activation == "identity") {
|
| 778 |
+
dispatch_gemm_residual<T, WeightType, Arch,
|
| 779 |
+
cutlass::epilogue::thread::Identity>(
|
| 780 |
+
config, A, B, weight_scales, biases, residual, C, m, n, k, binary_op,
|
| 781 |
+
unary_op, workspace_ptr, workspace_bytes, stream);
|
| 782 |
+
} else if (activation == "silu") {
|
| 783 |
+
dispatch_gemm_residual<T, WeightType, Arch,
|
| 784 |
+
cutlass::epilogue::thread::SiLu>(
|
| 785 |
+
config, A, B, weight_scales, biases, residual, C, m, n, k, binary_op,
|
| 786 |
+
unary_op, workspace_ptr, workspace_bytes, stream);
|
| 787 |
+
} else if (activation == "relu") {
|
| 788 |
+
dispatch_gemm_residual<T, WeightType, Arch,
|
| 789 |
+
cutlass::epilogue::thread::ReLu>(
|
| 790 |
+
config, A, B, weight_scales, biases, residual, C, m, n, k, binary_op,
|
| 791 |
+
unary_op, workspace_ptr, workspace_bytes, stream);
|
| 792 |
+
} else if (activation == "gelu") {
|
| 793 |
+
dispatch_gemm_residual<T, WeightType, Arch,
|
| 794 |
+
cutlass::epilogue::thread::GELU>(
|
| 795 |
+
config, A, B, weight_scales, biases, residual, C, m, n, k, binary_op,
|
| 796 |
+
unary_op, workspace_ptr, workspace_bytes, stream);
|
| 797 |
+
} else {
|
| 798 |
+
throw std::runtime_error(
|
| 799 |
+
"[FT Error][Unsupported activation before residual binary op] " +
|
| 800 |
+
activation);
|
| 801 |
+
}
|
| 802 |
+
}
|
| 803 |
+
|
| 804 |
+
template <typename T, typename WeightType>
|
| 805 |
+
void CutlassFpAIntBGemmRunner<T, WeightType>::gemm_bias_act_residual(
|
| 806 |
+
const T *A, const WeightType *B, const T *weight_scales, const T *biases,
|
| 807 |
+
const T *residual, T *C, int m, int n, int k, const std::string &activation,
|
| 808 |
+
const std::string &binary_op, const std::string &unary_op,
|
| 809 |
+
char *workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) {
|
| 810 |
+
|
| 811 |
+
std::vector<CutlassGemmConfig> candidate_configs =
|
| 812 |
+
get_candidate_configs(sm_, true, false);
|
| 813 |
+
std::vector<int> occupancies(candidate_configs.size());
|
| 814 |
+
|
| 815 |
+
for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {
|
| 816 |
+
dispatch_to_arch<EpilogueOpNoBias>(
|
| 817 |
+
A, B, weight_scales, biases, C, m, n, k, 0, candidate_configs[ii],
|
| 818 |
+
workspace_ptr, workspace_bytes, stream, &occupancies[ii]);
|
| 819 |
+
}
|
| 820 |
+
|
| 821 |
+
CutlassGemmConfig chosen_config = estimate_best_config_from_occupancies(
|
| 822 |
+
candidate_configs, occupancies, m, n, k, 1, split_k_limit,
|
| 823 |
+
workspace_bytes, multi_processor_count_, true);
|
| 824 |
+
|
| 825 |
+
if (sm_ >= 80 && sm_ < 90) {
|
| 826 |
+
dispatch_gemm_residual<T, WeightType, cutlass::arch::Sm80>(
|
| 827 |
+
chosen_config, A, B, weight_scales, biases, residual, C, m, n, k,
|
| 828 |
+
activation, binary_op, unary_op, workspace_ptr, workspace_bytes,
|
| 829 |
+
stream);
|
| 830 |
+
} else if (sm_ >= 75 && sm_ < 80) {
|
| 831 |
+
dispatch_gemm_residual<T, WeightType, cutlass::arch::Sm75>(
|
| 832 |
+
chosen_config, A, B, weight_scales, biases, residual, C, m, n, k,
|
| 833 |
+
activation, binary_op, unary_op, workspace_ptr, workspace_bytes,
|
| 834 |
+
stream);
|
| 835 |
+
} else if (sm_ == 70) {
|
| 836 |
+
dispatch_gemm_residual<T, WeightType, cutlass::arch::Sm70>(
|
| 837 |
+
chosen_config, A, B, weight_scales, biases, residual, C, m, n, k,
|
| 838 |
+
activation, binary_op, unary_op, workspace_ptr, workspace_bytes,
|
| 839 |
+
stream);
|
| 840 |
+
} else {
|
| 841 |
+
throw std::runtime_error("[FT Error][Unsupported SM] " + std::to_string(sm_));
|
| 842 |
+
}
|
| 843 |
+
}
|
| 844 |
+
|
| 845 |
+
template<typename T, typename WeightType>
|
| 846 |
+
int CutlassFpAIntBGemmRunner<T, WeightType>::getWorkspaceSize(const int m, const int n, const int k)
|
| 847 |
+
{
|
| 848 |
+
FT_LOG_DEBUG(__PRETTY_FUNCTION__);
|
| 849 |
+
// TODO(masahi): Shouldn't it be 0?
|
| 850 |
+
|
| 851 |
+
// These are the min tile sizes for each config, which would launch the maximum number of blocks
|
| 852 |
+
const int max_grid_m = (m + 31) / 32;
|
| 853 |
+
const int max_grid_n = (n + 127) / 128;
|
| 854 |
+
// We need 4 bytes per block in the worst case. We launch split_k_limit in z dim.
|
| 855 |
+
return max_grid_m * max_grid_n * split_k_limit * 4;
|
| 856 |
+
}
|
| 857 |
+
|
| 858 |
+
} // namespace fastertransformer
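As a worked example of the bound computed by getWorkspaceSize above: for m = 512, n = 4096, the smallest tile shape (32 x 128) gives max_grid_m = (512 + 31) / 32 = 16 and max_grid_n = (4096 + 127) / 128 = 32, so the function returns 16 * 32 * 7 * 4 = 14336 bytes. These figures only illustrate the formula, not a measured requirement; k does not enter the bound because the 4-byte reduction slot is allocated per output tile and per split-k slice.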
cutlass_kernels/fpA_intB_gemm_wrapper.cu
ADDED
@@ -0,0 +1,201 @@
#include <torch/all.h>
|
| 2 |
+
#include "cub/cub.cuh"
|
| 3 |
+
#include <cuda_runtime.h>
|
| 4 |
+
#include <cuda_fp16.h>
|
| 5 |
+
#include <c10/cuda/CUDAGuard.h>
|
| 6 |
+
#include "fpA_intB_gemm_wrapper.h"
|
| 7 |
+
#include "fpA_intB_gemm.h"
|
| 8 |
+
#include "cutlass_preprocessors.h"
|
| 9 |
+
#include "cuda_utils.h"
|
| 10 |
+
#include "weightOnlyBatchedGemv/enabled.h"
|
| 11 |
+
#include "weightOnlyBatchedGemv/kernelLauncher.h"
|
| 12 |
+
#include "torch_utils.h"
|
| 13 |
+
|
| 14 |
+
#include <vector>
|
| 15 |
+
|
| 16 |
+
namespace ft = fastertransformer;
|
| 17 |
+
|
| 18 |
+
int getWorkspaceSize(const int m, const int n, const int k)
|
| 19 |
+
{
|
| 20 |
+
// These are the min tile sizes for each config, which would launch the maximum number of blocks
|
| 21 |
+
const int max_grid_m = (m + 31) / 32;
|
| 22 |
+
const int max_grid_n = (n + 127) / 128;
|
| 23 |
+
const int split_k_limit = 7;
|
| 24 |
+
// We need 4 bytes per block in the worst case. We launch split_k_limit in z dim.
|
| 25 |
+
return max_grid_m * max_grid_n * split_k_limit * 4;
|
| 26 |
+
}
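// Example (sketch): a caller could size the GEMM workspace with this helper
// before launching, e.g.
//   const int ws_bytes = getWorkspaceSize(m, n, k);
//   torch::Tensor ws = torch::empty({ws_bytes}, torch::dtype(torch::kInt8).device(input.device()));
//   char *ws_ptr = reinterpret_cast<char *>(ws.data_ptr<int8_t>());
// The tensor name "input" and this allocation pattern are illustrative assumptions.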
|
| 27 |
+
|
| 28 |
+
std::vector<torch::Tensor>
|
| 29 |
+
symmetric_quantize_last_axis_of_tensor(torch::Tensor const &weight,
|
| 30 |
+
at::ScalarType quant_type,
|
| 31 |
+
bool return_unprocessed_quantized_tensor)
|
| 32 |
+
{
|
| 33 |
+
CHECK_CPU(weight);
|
| 34 |
+
CHECK_CONTIGUOUS(weight);
|
| 35 |
+
TORCH_CHECK(weight.numel() != 0, "weight should not be empty tensor");
|
| 36 |
+
TORCH_CHECK(weight.dim() == 2 || weight.dim() == 3, "Invalid dim. The dim of weight should be 2 or 3");
|
| 37 |
+
|
| 38 |
+
auto _st = weight.scalar_type();
|
| 39 |
+
TORCH_CHECK(_st == torch::kFloat32 || _st == torch::kFloat16, "Invalid datatype. Weight must be FP16 or FP32");
|
| 40 |
+
TORCH_CHECK(quant_type == torch::kInt8 || quant_type == at::ScalarType::QUInt4x2, "Must be int4 or int8 quantization");
|
| 41 |
+
ft::QuantType ft_quant_type = ft::get_ft_quant_type(quant_type);
|
| 42 |
+
|
| 43 |
+
const size_t num_experts = weight.dim() == 2 ? 1 : weight.size(0);
|
| 44 |
+
const size_t num_rows = weight.size(-2);
|
| 45 |
+
const size_t num_cols = weight.size(-1);
|
| 46 |
+
|
| 47 |
+
const size_t bits_in_type = ft::get_bits_in_quant_type(ft_quant_type);
|
| 48 |
+
const size_t bytes_per_out_col = num_cols * bits_in_type / 8;
|
| 49 |
+
|
| 50 |
+
const size_t input_mat_size = num_rows * num_cols;
|
| 51 |
+
const size_t quantized_mat_size = num_rows * bytes_per_out_col;
|
| 52 |
+
|
| 53 |
+
std::vector<long int> quantized_weight_shape;
|
| 54 |
+
std::vector<long int> scale_shape;
|
| 55 |
+
if (weight.dim() == 2) {
|
| 56 |
+
quantized_weight_shape = {long(num_rows), long(bytes_per_out_col)};
|
| 57 |
+
scale_shape = {long(num_cols)};
|
| 58 |
+
}
|
| 59 |
+
else if (weight.dim() == 3) {
|
| 60 |
+
quantized_weight_shape = {long(num_experts), long(num_rows), long(bytes_per_out_col)};
|
| 61 |
+
scale_shape = {long(num_experts), long(num_cols)};
|
| 62 |
+
}
|
| 63 |
+
else {
|
| 64 |
+
TORCH_CHECK(false, "Invalid weight dimension. Weight must have dim 2 or 3");
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
torch::Tensor unprocessed_quantized_weight =
|
| 68 |
+
torch::empty(quantized_weight_shape, torch::dtype(torch::kInt8).device(torch::kCPU).requires_grad(false));
|
| 69 |
+
|
| 70 |
+
torch::Tensor processed_quantized_weight = torch::empty_like(unprocessed_quantized_weight);
|
| 71 |
+
|
| 72 |
+
torch::Tensor scales = torch::empty(scale_shape, torch::dtype(weight.dtype()).device(torch::kCPU).requires_grad(false));
|
| 73 |
+
|
| 74 |
+
int8_t *unprocessed_quantized_weight_ptr = reinterpret_cast<int8_t *>(unprocessed_quantized_weight.data_ptr());
|
| 75 |
+
int8_t *processed_quantized_weight_ptr = reinterpret_cast<int8_t *>(processed_quantized_weight.data_ptr());
|
| 76 |
+
|
| 77 |
+
if (weight.scalar_type() == at::ScalarType::Float)
|
| 78 |
+
{
|
| 79 |
+
ft::symmetric_quantize<float, float>(processed_quantized_weight_ptr,
|
| 80 |
+
unprocessed_quantized_weight_ptr,
|
| 81 |
+
reinterpret_cast<float *>(scales.data_ptr()),
|
| 82 |
+
reinterpret_cast<const float *>(weight.data_ptr()),
|
| 83 |
+
{num_rows, num_cols},
|
| 84 |
+
ft_quant_type);
|
| 85 |
+
}
|
| 86 |
+
else if (weight.scalar_type() == at::ScalarType::Half)
|
| 87 |
+
{
|
| 88 |
+
ft::symmetric_quantize<half, half>(processed_quantized_weight_ptr,
|
| 89 |
+
unprocessed_quantized_weight_ptr,
|
| 90 |
+
reinterpret_cast<half *>(scales.data_ptr()),
|
| 91 |
+
reinterpret_cast<const half *>(weight.data_ptr()),
|
| 92 |
+
{num_rows, num_cols},
|
| 93 |
+
ft_quant_type);
|
| 94 |
+
}
|
| 95 |
+
else
|
| 96 |
+
{
|
| 97 |
+
TORCH_CHECK(false, "Invalid data type. Weight must be FP32/FP16");
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
if (return_unprocessed_quantized_tensor)
|
| 101 |
+
{
|
| 102 |
+
return std::vector<torch::Tensor>{unprocessed_quantized_weight, processed_quantized_weight, scales};
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
return std::vector<torch::Tensor>{processed_quantized_weight, scales};
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
torch::Tensor preprocess_weights_cuda(torch::Tensor const &origin_weight,
|
| 109 |
+
bool is_int4)
|
| 110 |
+
{
|
| 111 |
+
// guarantee the weight is cpu tensor
|
| 112 |
+
CHECK_CPU(origin_weight);
|
| 113 |
+
|
| 114 |
+
torch::Tensor preprocessed_quantized_weight = torch::empty_like(origin_weight);
|
| 115 |
+
int8_t *preprocessed_quantized_weight_ptr = reinterpret_cast<int8_t *>(preprocessed_quantized_weight.data_ptr());
|
| 116 |
+
const int8_t *row_major_quantized_weight_ptr = reinterpret_cast<const int8_t *>(origin_weight.data_ptr());
|
| 117 |
+
size_t rows = origin_weight.size(-2);
|
| 118 |
+
size_t cols = origin_weight.size(-1);
|
| 119 |
+
int arch = ft::getSMVersion();
|
| 120 |
+
ft::preprocess_weights(preprocessed_quantized_weight_ptr,
|
| 121 |
+
row_major_quantized_weight_ptr,
|
| 122 |
+
rows,
|
| 123 |
+
cols,
|
| 124 |
+
is_int4,
|
| 125 |
+
arch);
|
| 126 |
+
return preprocessed_quantized_weight;
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
torch::Tensor w8_a16_gemm_forward_cuda(torch::Tensor const &input,
|
| 130 |
+
torch::Tensor const &weight,
|
| 131 |
+
torch::Tensor const &scale)
|
| 132 |
+
{
|
| 133 |
+
c10::cuda::CUDAGuard device_guard(input.device());
|
| 134 |
+
// TORCH_CHECK(input.dim() == 3 || input.dim() == 2, "Invalid input dim: ", input.dim());
|
| 135 |
+
const int m = input.dim() == 2 ? input.size(0) : input.size(0) * input.size(1);
|
| 136 |
+
const int k = input.size(-1);
|
| 137 |
+
const int n = weight.size(-1);
|
| 138 |
+
auto options = torch::TensorOptions().dtype(input.dtype()).device(input.device());
|
| 139 |
+
torch::Tensor output = input.dim() == 2 ? torch::empty({m, n}, options) : torch::empty({input.size(0), input.size(1), n}, options);
|
| 140 |
+
const ft::half *input_ptr = reinterpret_cast<ft::half *>(input.data_ptr());
|
| 141 |
+
const uint8_t *weight_ptr = reinterpret_cast<const uint8_t *>(weight.data_ptr());
|
| 142 |
+
const ft::half *scale_ptr = reinterpret_cast<ft::half *>(scale.data_ptr());
|
| 143 |
+
ft::half *output_ptr = reinterpret_cast<ft::half *>(output.data_ptr());
|
| 144 |
+
// const int max_size = std::max(n, k);
|
| 145 |
+
// size_t workspace_size = getWorkspaceSize(m, max_size, max_size);
|
| 146 |
+
// void *ptr = nullptr;
|
| 147 |
+
// char *workspace_ptr = workspace_size > 0 ? (char *)cudaMalloc((void **)&ptr, workspace_size) : nullptr;
|
| 148 |
+
const bool use_cuda_kernel = m <= SMALL_M_FAST_PATH;
|
| 149 |
+
// const bool use_cuda_kernel = false;
|
| 150 |
+
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
| 151 |
+
|
| 152 |
+
if(use_cuda_kernel){
|
| 153 |
+
tensorrt_llm::kernels::WeightOnlyActivationType weight_only_act_type = tensorrt_llm::kernels::WeightOnlyActivationType::FP16;
|
| 154 |
+
tensorrt_llm::kernels::WeightOnlyQuantType weight_only_quant_type = tensorrt_llm::kernels::WeightOnlyQuantType::Int8b;
|
| 155 |
+
tensorrt_llm::kernels::WeightOnlyParams params{weight_ptr, reinterpret_cast<const uint8_t *>(scale.data_ptr()), nullptr,
|
| 156 |
+
reinterpret_cast<half *>(input.data_ptr()), nullptr, nullptr, reinterpret_cast<half *>(output.data_ptr()), m, n, k, 0, weight_only_quant_type,
|
| 157 |
+
tensorrt_llm::kernels::WeightOnlyType::PerChannel,
|
| 158 |
+
tensorrt_llm::kernels::WeightOnlyActivationFunctionType::Identity, weight_only_act_type};
|
| 159 |
+
tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, stream);
|
| 160 |
+
}
|
| 161 |
+
else
|
| 162 |
+
ft::gemm_fp16_int(
|
| 163 |
+
input_ptr,
|
| 164 |
+
weight_ptr,
|
| 165 |
+
scale_ptr,
|
| 166 |
+
output_ptr,
|
| 167 |
+
m, n, k,
|
| 168 |
+
nullptr,
|
| 169 |
+
0,
|
| 170 |
+
stream);
|
| 171 |
+
return output;
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
torch::Tensor w8_a16_gemm_forward_cuda_(torch::Tensor const &input,
|
| 176 |
+
torch::Tensor const &weight,
|
| 177 |
+
torch::Tensor const &scale,
|
| 178 |
+
torch::Tensor &output,
|
| 179 |
+
const int64_t m,
|
| 180 |
+
const int64_t n,
|
| 181 |
+
const int64_t k)
|
| 182 |
+
{
|
| 183 |
+
c10::cuda::CUDAGuard device_guard(input.device());
|
| 184 |
+
|
| 185 |
+
const ft::half *input_ptr = reinterpret_cast<ft::half *>(input.data_ptr());
|
| 186 |
+
const uint8_t *weight_ptr = reinterpret_cast<const uint8_t *>(weight.data_ptr());
|
| 187 |
+
const ft::half *scale_ptr = reinterpret_cast<ft::half *>(scale.data_ptr());
|
| 188 |
+
ft::half *output_ptr = reinterpret_cast<ft::half *>(output.data_ptr());
|
| 189 |
+
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
| 190 |
+
|
| 191 |
+
ft::gemm_fp16_int(
|
| 192 |
+
input_ptr,
|
| 193 |
+
weight_ptr,
|
| 194 |
+
scale_ptr,
|
| 195 |
+
output_ptr,
|
| 196 |
+
m, n, k,
|
| 197 |
+
nullptr,
|
| 198 |
+
0,
|
| 199 |
+
stream);
|
| 200 |
+
return output;
|
| 201 |
+
}
|
cutlass_kernels/fpA_intB_gemm_wrapper.h
ADDED
@@ -0,0 +1,23 @@
#include <torch/all.h>
#include <vector>

#define SMALL_M_FAST_PATH 4
std::vector<torch::Tensor>
symmetric_quantize_last_axis_of_tensor(torch::Tensor const &weight,
                                       at::ScalarType quant_type,
                                       bool return_unprocessed_quantized_tensor);

torch::Tensor preprocess_weights_cuda(torch::Tensor const &ori_weight,
                                      bool is_int4);

torch::Tensor w8_a16_gemm_forward_cuda(torch::Tensor const &input,
                                       torch::Tensor const &weight,
                                       torch::Tensor const &scale);

torch::Tensor w8_a16_gemm_forward_cuda_(torch::Tensor const &input,
                                        torch::Tensor const &weight,
                                        torch::Tensor const &scale,
                                        torch::Tensor &output,
                                        const int64_t m,
                                        const int64_t n,
                                        const int64_t k);
torch-ext/quantization_eetq/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .custom_ops import w8_a16_gemm, w8_a16_gemm_, preprocess_weights, quant_weights

__all__ = ["w8_a16_gemm", "w8_a16_gemm_", "preprocess_weights", "quant_weights"]
torch-ext/quantization_eetq/custom_ops.py
ADDED
@@ -0,0 +1,36 @@
from typing import List
import torch

from ._ops import ops


def w8_a16_gemm(
    input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
    return ops.w8_a16_gemm(input, weight, scale)


def w8_a16_gemm_(
    input: torch.Tensor,
    weight: torch.Tensor,
    scale: torch.Tensor,
    output: torch.Tensor,
    m: int,
    n: int,
    k: int,
) -> torch.Tensor:
    return ops.w8_a16_gemm_(input, weight, scale, output, m, n, k)


def preprocess_weights(origin_weight: torch.Tensor, is_int4: bool) -> torch.Tensor:
    return ops.preprocess_weights(origin_weight, is_int4)


def quant_weights(
    origin_weight: torch.Tensor,
    quant_type: torch.dtype,
    return_unprocessed_quantized_tensor: bool,
) -> List[torch.Tensor]:
    return ops.quant_weights(
        origin_weight, quant_type, return_unprocessed_quantized_tensor
    )
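
The module above is the whole Python surface of the package. As a rough orientation, the sketch below walks the int8 path end to end: quantize an FP16 weight on the CPU with quant_weights, then run the FP16-activation / int8-weight GEMM on the GPU with w8_a16_gemm. This is a minimal sketch and not part of the commit; the 4096x4096 shape, and the assumption that the quantized weight returned here can be moved to the GPU and fed straight to w8_a16_gemm, are illustrative only.

# Illustrative usage sketch -- not part of this commit.
import torch
from quantization_eetq import quant_weights, w8_a16_gemm

# FP16 weight on the CPU, quantized symmetrically along the last axis.
weight = torch.randn(4096, 4096, dtype=torch.float16)       # made-up (k, n) shape
qweight, scales = quant_weights(weight, torch.int8, False)  # int8 weight + per-channel scales

# FP16 activations on the GPU; the op allocates and returns an (m, n) FP16 output.
x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
y = w8_a16_gemm(x, qweight.cuda(), scales.cuda())
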
torch-ext/registration.h
ADDED
@@ -0,0 +1,27 @@
#pragma once

#include <Python.h>

#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)

#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)

// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)

// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
  TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)

// REGISTER_EXTENSION allows the shared library to be loaded and initialized
// via python's import statement.
#define REGISTER_EXTENSION(NAME)                                               \
  PyMODINIT_FUNC CONCAT(PyInit_, NAME)() {                                     \
    static struct PyModuleDef module = {PyModuleDef_HEAD_INIT,                 \
                                        STRINGIFY(NAME), nullptr, 0, nullptr}; \
    return PyModule_Create(&module);                                           \
  }
torch-ext/torch_binding.cpp
ADDED
@@ -0,0 +1,19 @@
#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.def("w8_a16_gemm(Tensor input, Tensor weight, Tensor scale) -> Tensor");
  ops.impl("w8_a16_gemm", torch::kCUDA, &w8_a16_gemm_forward_cuda);
  ops.def("w8_a16_gemm_(Tensor input, Tensor weight, Tensor scale, Tensor! output,"
          "int m, int n, int k) -> Tensor");
  ops.impl("w8_a16_gemm_", torch::kCUDA, &w8_a16_gemm_forward_cuda_);
  ops.def("preprocess_weights(Tensor origin_weight, bool is_int4) -> Tensor");
  ops.impl("preprocess_weights", torch::kCUDA, &preprocess_weights_cuda);
  ops.def("quant_weights(Tensor origin_weight, ScalarType quant_type,"
          "bool return_unprocessed_quantized_tensor) -> Tensor[]");
  ops.impl("quant_weights", torch::kCUDA, &symmetric_quantize_last_axis_of_tensor);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
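
The `Tensor! output` annotation in the w8_a16_gemm_ schema above marks the output buffer as mutated in place, with the problem sizes passed explicitly. A hypothetical call through the Python wrapper could look like the following; the buffer shapes mirror how the CUDA wrapper reads m, n and k, and the placeholder weight/scale tensors are assumptions, not code from this commit.

# Illustrative usage sketch -- not part of this commit.
import torch
from quantization_eetq import w8_a16_gemm_

m, n, k = 8, 4096, 4096                                         # made-up problem size
x = torch.randn(m, k, dtype=torch.float16, device="cuda")       # FP16 activations
qweight = torch.zeros(k, n, dtype=torch.int8, device="cuda")    # stand-in for a preprocessed int8 weight
scales = torch.ones(n, dtype=torch.float16, device="cuda")      # stand-in per-channel scales
out = torch.empty(m, n, dtype=torch.float16, device="cuda")     # caller-owned output buffer

w8_a16_gemm_(x, qweight, scales, out, m, n, k)                  # writes the GEMM result into `out`
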
torch-ext/torch_binding.h
ADDED
@@ -0,0 +1,25 @@
#pragma once

#include <vector>

#include <torch/torch.h>

std::vector<torch::Tensor>
symmetric_quantize_last_axis_of_tensor(torch::Tensor const &weight,
                                       at::ScalarType quant_type,
                                       bool return_unprocessed_quantized_tensor);

torch::Tensor preprocess_weights_cuda(torch::Tensor const &ori_weight,
                                      bool is_int4);

torch::Tensor w8_a16_gemm_forward_cuda(torch::Tensor const &input,
                                       torch::Tensor const &weight,
                                       torch::Tensor const &scale);

torch::Tensor w8_a16_gemm_forward_cuda_(torch::Tensor const &input,
                                        torch::Tensor const &weight,
                                        torch::Tensor const &scale,
                                        torch::Tensor &output,
                                        const int64_t m,
                                        const int64_t n,
                                        const int64_t k);
utils/activation_types.h
ADDED
@@ -0,0 +1,40 @@
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "cuda_utils.h"

namespace fastertransformer {

enum class ActivationType {
    Gelu,
    Relu,
    Silu,
    GeGLU,
    ReGLU,
    SiGLU,
    Identity,
    InvalidType
};

inline bool isGatedActivation(ActivationType activation_type)
{
    return activation_type == ActivationType::GeGLU || activation_type == ActivationType::ReGLU
           || activation_type == ActivationType::SiGLU;
}

} // namespace fastertransformer
utils/cuda_utils.cc
ADDED
@@ -0,0 +1,55 @@
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "cuda_utils.h"

namespace fastertransformer {

/* ***************************** common utils ****************************** */

cudaError_t getSetDevice(int i_device, int* o_device)
{
    int current_dev_id = 0;
    cudaError_t err = cudaSuccess;

    if (o_device != NULL) {
        err = cudaGetDevice(&current_dev_id);
        if (err != cudaSuccess) {
            return err;
        }
        if (current_dev_id == i_device) {
            *o_device = i_device;
        }
        else {
            err = cudaSetDevice(i_device);
            if (err != cudaSuccess) {
                return err;
            }
            *o_device = current_dev_id;
        }
    }
    else {
        err = cudaSetDevice(i_device);
        if (err != cudaSuccess) {
            return err;
        }
    }

    return cudaSuccess;
}

/* ************************** end of common utils ************************** */
} // namespace fastertransformer
utils/cuda_utils.h
ADDED
@@ -0,0 +1,76 @@
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "logger.h"

#include <cuda_runtime.h>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

namespace fastertransformer {
/* **************************** debug tools ********************************* */
template<typename T>
void check(T result, char const* const func, const char* const file, int const line)
{
    if (result) {
        throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + ("<unknown>") + " "
                                 + file + ":" + std::to_string(line) + " \n");
    }
}

#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__)

[[noreturn]] inline void throwRuntimeError(const char* const file, int const line, std::string const& info = "")
{
    throw std::runtime_error(std::string("[FT][ERROR] ") + info + " Assertion fail: " + file + ":"
                             + std::to_string(line) + " \n");
}

inline void myAssert(bool result, const char* const file, int const line, std::string const& info = "")
{
    if (!result) {
        throwRuntimeError(file, line, info);
    }
}

#define FT_CHECK(val) myAssert(val, __FILE__, __LINE__)
#define FT_CHECK_WITH_INFO(val, info)                                              \
    do {                                                                           \
        bool is_valid_val = (val);                                                 \
        if (!is_valid_val) {                                                       \
            fastertransformer::myAssert(is_valid_val, __FILE__, __LINE__, (info)); \
        }                                                                          \
    } while (0)

/* ***************************** common utils ****************************** */
inline int getSMVersion()
{
    int device{-1};
    check_cuda_error(cudaGetDevice(&device));
    int sm_major = 0;
    int sm_minor = 0;
    check_cuda_error(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device));
    check_cuda_error(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device));
    return sm_major * 10 + sm_minor;
}

cudaError_t getSetDevice(int i_device, int* o_device = NULL);
/* ************************** end of common utils ************************** */
} // namespace fastertransformer
utils/logger.cc
ADDED
@@ -0,0 +1,59 @@
/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "logger.h"
#include <cuda_runtime.h>

namespace fastertransformer {

Logger::Logger()
{
    char* is_first_rank_only_char = std::getenv("FT_LOG_FIRST_RANK_ONLY");
    bool is_first_rank_only =
        (is_first_rank_only_char != nullptr && std::string(is_first_rank_only_char) == "ON") ? true : false;

    int device_id;
    cudaGetDevice(&device_id);

    char* level_name = std::getenv("FT_LOG_LEVEL");
    if (level_name != nullptr) {
        std::map<std::string, Level> name_to_level = {
            {"TRACE", TRACE},
            {"DEBUG", DEBUG},
            {"INFO", INFO},
            {"WARNING", WARNING},
            {"ERROR", ERROR},
        };
        auto level = name_to_level.find(level_name);
        // If FT_LOG_FIRST_RANK_ONLY=ON, set LOG LEVEL of other device to ERROR
        if (is_first_rank_only && device_id != 0) {
            level = name_to_level.find("ERROR");
        }
        if (level != name_to_level.end()) {
            setLevel(level->second);
        }
        else {
            fprintf(stderr,
                    "[FT][WARNING] Invalid logger level FT_LOG_LEVEL=%s. "
                    "Ignore the environment variable and use a default "
                    "logging level.\n",
                    level_name);
            level_name = nullptr;
        }
    }
}

} // namespace fastertransformer
utils/logger.h
ADDED
@@ -0,0 +1,121 @@
/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <cstdlib>
#include <map>
#include <string>

#include "string_utils.h"

namespace fastertransformer {

class Logger {

public:
    enum Level {
        TRACE = 0,
        DEBUG = 10,
        INFO = 20,
        WARNING = 30,
        ERROR = 40
    };

    static Logger& getLogger()
    {
        thread_local Logger instance;
        return instance;
    }
    Logger(Logger const&) = delete;
    void operator=(Logger const&) = delete;

    template<typename... Args>
    void log(const Level level, const std::string format, const Args&... args)
    {
        if (level_ <= level) {
            std::string fmt = getPrefix(level) + format + "\n";
            FILE* out = level_ < WARNING ? stdout : stderr;
            std::string logstr = fmtstr(fmt, args...);
            fprintf(out, "%s", logstr.c_str());
        }
    }

    template<typename... Args>
    void log(const Level level, const int rank, const std::string format, const Args&... args)
    {
        if (level_ <= level) {
            std::string fmt = getPrefix(level, rank) + format + "\n";
            FILE* out = level_ < WARNING ? stdout : stderr;
            std::string logstr = fmtstr(fmt, args...);
            fprintf(out, "%s", logstr.c_str());
        }
    }

    void setLevel(const Level level)
    {
        level_ = level;
        log(INFO, "Set logger level by %s", getLevelName(level).c_str());
    }

    int getLevel() const
    {
        return level_;
    }

private:
    const std::string PREFIX = "[FT]";
    const std::map<const Level, const std::string> level_name_ = {
        {TRACE, "TRACE"}, {DEBUG, "DEBUG"}, {INFO, "INFO"}, {WARNING, "WARNING"}, {ERROR, "ERROR"}};

#ifndef NDEBUG
    const Level DEFAULT_LOG_LEVEL = DEBUG;
#else
    const Level DEFAULT_LOG_LEVEL = INFO;
#endif
    Level level_ = DEFAULT_LOG_LEVEL;

    Logger();

    inline const std::string getLevelName(const Level level)
    {
        return level_name_.at(level);
    }

    inline const std::string getPrefix(const Level level)
    {
        return PREFIX + "[" + getLevelName(level) + "] ";
    }

    inline const std::string getPrefix(const Level level, const int rank)
    {
        return PREFIX + "[" + getLevelName(level) + "][" + std::to_string(rank) + "] ";
    }
};

#define FT_LOG(level, ...)                                                        \
    do {                                                                          \
        if (fastertransformer::Logger::getLogger().getLevel() <= level) {         \
            fastertransformer::Logger::getLogger().log(level, __VA_ARGS__);       \
        }                                                                         \
    } while (0)

#define FT_LOG_TRACE(...) FT_LOG(fastertransformer::Logger::TRACE, __VA_ARGS__)
#define FT_LOG_DEBUG(...) FT_LOG(fastertransformer::Logger::DEBUG, __VA_ARGS__)
#define FT_LOG_INFO(...) FT_LOG(fastertransformer::Logger::INFO, __VA_ARGS__)
#define FT_LOG_WARNING(...) FT_LOG(fastertransformer::Logger::WARNING, __VA_ARGS__)
#define FT_LOG_ERROR(...) FT_LOG(fastertransformer::Logger::ERROR, __VA_ARGS__)
} // namespace fastertransformer
utils/string_utils.h
ADDED
@@ -0,0 +1,54 @@
/*
 * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <memory>  // std::make_unique
#include <sstream> // std::stringstream
#include <string>
#include <vector>

namespace fastertransformer {

template<typename... Args>
inline std::string fmtstr(const std::string& format, Args... args)
{
    // This function came from a code snippet in stackoverflow under cc-by-1.0
    // https://stackoverflow.com/questions/2342162/stdstring-formatting-like-sprintf

    // Disable format-security warning in this function.
#if defined(_MSC_VER) // for visual studio
#pragma warning(push)
#pragma warning(warning(disable : 4996))
#elif defined(__GNUC__) || defined(__clang__) // for gcc or clang
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wformat-security"
#endif
    int size_s = std::snprintf(nullptr, 0, format.c_str(), args...) + 1; // Extra space for '\0'
    if (size_s <= 0) {
        throw std::runtime_error("Error during formatting.");
    }
    auto size = static_cast<size_t>(size_s);
    auto buf = std::make_unique<char[]>(size);
    std::snprintf(buf.get(), size, format.c_str(), args...);
#if defined(_MSC_VER)
#pragma warning(pop)
#elif defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop
#endif
    return std::string(buf.get(), buf.get() + size - 1); // We don't want the '\0' inside
}
} // namespace fastertransformer
utils/torch_utils.h
ADDED
@@ -0,0 +1,65 @@
#pragma once
#include "torch/csrc/cuda/Stream.h"
#include "torch/all.h"
#include <ATen/cuda/CUDAContext.h>
#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
#include <nvToolsExt.h>
#include <torch/custom_class.h>
#include <torch/script.h>
#include <vector>

#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small")
#define CHECK_TYPE(x, st) TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type: " #x)
#define CHECK_TH_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CPU(x) TORCH_CHECK(!x.is_cuda(), #x " must be a CPU tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x, st)  \
    CHECK_TH_CUDA(x);       \
    CHECK_CONTIGUOUS(x);    \
    CHECK_TYPE(x, st)
#define CHECK_CPU_INPUT(x, st) \
    CHECK_CPU(x);              \
    CHECK_CONTIGUOUS(x);       \
    CHECK_TYPE(x, st)
#define CHECK_OPTIONAL_INPUT(x, st) \
    if (x.has_value()) {            \
        CHECK_INPUT(x.value(), st); \
    }
#define CHECK_OPTIONAL_CPU_INPUT(x, st) \
    if (x.has_value()) {                \
        CHECK_CPU_INPUT(x.value(), st); \
    }
#define PRINT_TENSOR(x) std::cout << #x << ":\n" << x << std::endl
#define PRINT_TENSOR_SIZE(x) std::cout << "size of " << #x << ": " << x.sizes() << std::endl

namespace fastertransformer {

template<typename T>
inline T* get_ptr(torch::Tensor& t)
{
    return reinterpret_cast<T*>(t.data_ptr());
}

std::vector<size_t> convert_shape(torch::Tensor tensor);

size_t sizeBytes(torch::Tensor tensor);

// Defined in a header, so marked inline to avoid multiple-definition errors
// when the header is included from more than one translation unit.
inline QuantType get_ft_quant_type(torch::ScalarType quant_type)
{
    if (quant_type == torch::kInt8) {
        return QuantType::INT8_WEIGHT_ONLY;
    }
    else if (quant_type == at::ScalarType::QUInt4x2) {
        return QuantType::PACKED_INT4_WEIGHT_ONLY;
    }
    else {
        TORCH_CHECK(false, "Invalid quantization type");
    }
}

} // namespace fastertransformer
weightOnlyBatchedGemv/common.h
ADDED
@@ -0,0 +1,107 @@
/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cuda_fp16.h>
#if defined(ENABLE_BF16)
#include <cuda_bf16.h>
#endif
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <iostream>

namespace tensorrt_llm
{
namespace kernels
{
enum class WeightOnlyQuantType
{
    Int4b,
    Int8b
};
enum class WeightOnlyType
{
    PerChannel,
    GroupWise
};

struct WeightOnlyPerChannel;
template <int GS>
struct WeightOnlyGroupWise;

enum class WeightOnlyActivationFunctionType
{
    Gelu,
    Relu,
    Identity,
    InvalidType
};

enum class WeightOnlyActivationType
{
    FP16,
    BF16
};

struct WeightOnlyParams
{
    // ActType is fp16 or bf16
    using ActType = void;
    using WeiType = uint8_t;

    const uint8_t* qweight;
    const ActType* scales;
    const ActType* zeros;
    const ActType* in;
    const ActType* act_scale;
    const ActType* bias;
    ActType* out;
    const int m;
    const int n;
    const int k;
    const int group_size;
    WeightOnlyQuantType quant_type;
    WeightOnlyType weight_only_type;
    WeightOnlyActivationFunctionType act_func_type;
    WeightOnlyActivationType act_type;

    WeightOnlyParams(const uint8_t* _qweight, const ActType* _scales, const ActType* _zeros, const ActType* _in,
        const ActType* _act_scale, const ActType* _bias, ActType* _out, const int _m, const int _n, const int _k,
        const int _group_size, const WeightOnlyQuantType _quant_type, const WeightOnlyType _weight_only_type,
        const WeightOnlyActivationFunctionType _act_func_type, const WeightOnlyActivationType _act_type)
        : qweight(_qweight)
        , scales(_scales)
        , zeros(_zeros)
        , in(_in)
        , act_scale(_act_scale)
        , bias(_bias)
        , out(_out)
        , m(_m)
        , n(_n)
        , k(_k)
        , group_size(_group_size)
        , quant_type(_quant_type)
        , weight_only_type(_weight_only_type)
        , act_func_type(_act_func_type)
        , act_type(_act_type)
    {
    }
};
} // namespace kernels
} // namespace tensorrt_llm
weightOnlyBatchedGemv/enabled.h
ADDED
@@ -0,0 +1,105 @@
/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once
#include "cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
#include "common.h"
#include <cuda_runtime.h>

inline int getSMVersion()
{
    int device{-1};
    cudaGetDevice(&device);
    int sm_major = 0;
    int sm_minor = 0;
    cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device);
    cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device);
    return sm_major * 10 + sm_minor;
}

namespace tensorrt_llm
{
namespace kernels
{
template <typename TypeB, typename Layout>
struct SupportedLayout
{
    static constexpr bool value = false;
};

template <>
struct SupportedLayout<uint8_t, cutlass::layout::ColumnMajorTileInterleave<64, 2>>
{
    static constexpr bool value = true;
};

template <>
struct SupportedLayout<cutlass::uint4b_t, cutlass::layout::ColumnMajorTileInterleave<64, 4>>
{
    static constexpr bool value = true;
};

template <typename TypeB, typename Arch>
bool isEnabled()
{
    using Layout = typename cutlass::gemm::kernel::LayoutDetailsB<TypeB, Arch>::Layout;
    return SupportedLayout<TypeB, Layout>::value;
}

template <typename TypeB>
bool isEnabledForArch(int arch)
{
    if (arch >= 70 && arch < 75)
    {
        return isEnabled<TypeB, cutlass::arch::Sm70>();
    }
    else if (arch >= 75 && arch < 80)
    {
        return isEnabled<TypeB, cutlass::arch::Sm75>();
    }
    else if (arch >= 80 && arch <= 90)
    {
        return isEnabled<TypeB, cutlass::arch::Sm80>();
    }
    else
    {
        // TLLM_CHECK_WITH_INFO(false, "Unsupported Arch");
        assert(0);
        return false;
    }
}

inline bool isWeightOnlyBatchedGemvEnabled(WeightOnlyQuantType qtype)
{
    const int arch = getSMVersion();
    if (qtype == WeightOnlyQuantType::Int4b)
    {
        return isEnabledForArch<cutlass::uint4b_t>(arch);
    }
    else if (qtype == WeightOnlyQuantType::Int8b)
    {
        return isEnabledForArch<uint8_t>(arch);
    }
    else
    {
        assert(0);
        // TLLM_CHECK_WITH_INFO(false, "Unsupported WeightOnlyQuantType");
        return false;
    }
}
} // namespace kernels
} // namespace tensorrt_llm