|
#pragma once |
|
|
|
#include "cuda_utils.h" |
|
#include "cutlass/cutlass.h" |
|
#include "cutlass/numeric_types.h" |
|
|
|
#include "cute/tensor.hpp" |
|
#include "cutlass/tensor_ref.h" |
|
#include "cutlass/gemm/dispatch_policy.hpp" |
|
#include "cutlass/gemm/collective/collective_builder.hpp" |
|
#include "cutlass/gemm/device/gemm_universal_adapter.h" |
|
#include "cutlass/gemm/kernel/gemm_universal.hpp" |
|
#include "cutlass/gemm/kernel/tile_scheduler_params.h" |
|
#include "cutlass/epilogue/dispatch_policy.hpp" |
|
#include "cutlass/epilogue/collective/collective_builder.hpp" |
|
|
|
#include "cutlass_extensions/gemm/dispatch_policy.hpp" |
|
#include "cutlass_extensions/gemm/collective/collective_builder.hpp" |
|
|
|
#include "cutlass_gemm_caller.cuh" |
|
|
|
namespace vllm { |
|
|
|
using namespace cute; |
|
|
|
|
|
template <class OutType, int ScaleGranularityM, |
|
int ScaleGranularityN, int ScaleGranularityK, |
|
class MmaTileShape, class ClusterShape, |
|
class EpilogueScheduler, class MainloopScheduler, |
|
bool swap_ab_ = false> |
|
struct cutlass_3x_gemm_fp8_blockwise { |
|
static constexpr bool swap_ab = swap_ab_; |
|
using ElementAB = cutlass::float_e4m3_t; |
|
|
|
using ElementA = ElementAB; |
|
using LayoutA = cutlass::layout::RowMajor; |
|
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type; |
|
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; |
|
|
|
using ElementB = ElementAB; |
|
using LayoutB = cutlass::layout::ColumnMajor; |
|
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type; |
|
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; |
|
|
|
using ElementD = OutType; |
|
using LayoutD = cutlass::layout::RowMajor; |
|
using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose<LayoutD>::type; |
|
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; |
|
|
|
using ElementC = void; |
|
using LayoutC = LayoutD; |
|
using LayoutC_Transpose = LayoutD_Transpose; |
|
static constexpr int AlignmentC = AlignmentD; |
|
|
|
using ElementAccumulator = float; |
|
using ElementCompute = float; |
|
using ElementBlockScale = float; |
|
|
|
using ScaleConfig = conditional_t<swap_ab, |
|
cutlass::detail::Sm100BlockwiseScaleConfig< |
|
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, |
|
cute::UMMA::Major::K, cute::UMMA::Major::MN>, |
|
cutlass::detail::Sm100BlockwiseScaleConfig< |
|
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, |
|
cute::UMMA::Major::MN, cute::UMMA::Major::K>>; |
|
|
|
|
|
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); |
|
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); |
|
|
|
using ArchTag = cutlass::arch::Sm100; |
|
using OperatorClass = cutlass::arch::OpClassTensorOp; |
|
|
|
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; |
|
using ElementScalar = float; |
|
using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>; |
|
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< |
|
ArchTag, |
|
OperatorClass, |
|
MmaTileShape, |
|
ClusterShape, |
|
cutlass::epilogue::collective::EpilogueTileAuto, |
|
ElementAccumulator, |
|
ElementCompute, |
|
ElementC, |
|
conditional_t<swap_ab, LayoutC_Transpose, LayoutC>, |
|
AlignmentC, |
|
ElementD, |
|
conditional_t<swap_ab, LayoutD_Transpose, LayoutD>, |
|
AlignmentD, |
|
EpilogueScheduler, |
|
DefaultOperation |
|
>::CollectiveOp; |
|
|
|
using StageCountType = cutlass::gemm::collective::StageCountAuto; |
|
using CollectiveMainloop = conditional_t<swap_ab, |
|
typename cutlass::gemm::collective::CollectiveBuilder< |
|
ArchTag, |
|
OperatorClass, |
|
ElementB, |
|
cute::tuple<LayoutB_Transpose, LayoutSFA>, |
|
AlignmentB, |
|
ElementA, |
|
cute::tuple<LayoutA_Transpose, LayoutSFB>, |
|
AlignmentA, |
|
ElementAccumulator, |
|
MmaTileShape, |
|
ClusterShape, |
|
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>, |
|
MainloopScheduler |
|
>::CollectiveOp, |
|
typename cutlass::gemm::collective::CollectiveBuilder< |
|
ArchTag, |
|
OperatorClass, |
|
ElementA, |
|
cute::tuple<LayoutA, LayoutSFA>, |
|
AlignmentA, |
|
ElementB, |
|
cute::tuple<LayoutB, LayoutSFB>, |
|
AlignmentB, |
|
ElementAccumulator, |
|
MmaTileShape, |
|
ClusterShape, |
|
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>, |
|
MainloopScheduler |
|
>::CollectiveOp>; |
|
|
|
using KernelType = enable_sm100_only<cutlass::gemm::kernel::GemmUniversal< |
|
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>; |
|
|
|
struct GemmKernel : public KernelType {}; |
|
}; |
|
|
|
template <typename Gemm> |
|
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, |
|
torch::Tensor const& b, |
|
torch::Tensor const& a_scales, |
|
torch::Tensor const& b_scales) { |
|
static constexpr bool swap_ab = Gemm::swap_ab; |
|
using GemmKernel = typename Gemm::GemmKernel; |
|
using StrideA = typename Gemm::GemmKernel::StrideA; |
|
using StrideB = typename Gemm::GemmKernel::StrideB; |
|
using StrideD = typename Gemm::GemmKernel::StrideD; |
|
using StrideC = typename Gemm::GemmKernel::StrideC; |
|
using LayoutSFA = typename Gemm::LayoutSFA; |
|
using LayoutSFB = typename Gemm::LayoutSFB; |
|
using ScaleConfig = typename Gemm::ScaleConfig; |
|
|
|
using ElementAB = typename Gemm::ElementAB; |
|
using ElementD = typename Gemm::ElementD; |
|
|
|
int32_t m = a.size(0), n = b.size(1), k = a.size(1); |
|
|
|
StrideA a_stride; |
|
StrideB b_stride; |
|
StrideC c_stride; |
|
a_stride = |
|
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1)); |
|
b_stride = |
|
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1)); |
|
c_stride = |
|
cutlass::make_cute_packed_stride(StrideC{}, swap_ab ? cute::make_shape(n, m, 1) : cute::make_shape(m, n, 1)); |
|
|
|
LayoutSFA layout_SFA = swap_ab ? |
|
ScaleConfig::tile_atom_to_shape_SFA(make_shape(n, m, k, 1)) : |
|
ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1)); |
|
LayoutSFB layout_SFB = swap_ab ? |
|
ScaleConfig::tile_atom_to_shape_SFB(make_shape(n, m, k, 1)) : |
|
ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1)); |
|
|
|
auto a_ptr = static_cast<ElementAB*>(a.data_ptr()); |
|
auto b_ptr = static_cast<ElementAB*>(b.data_ptr()); |
|
auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr()); |
|
auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr()); |
|
|
|
auto mainloop_args = [&](){ |
|
|
|
if (swap_ab) { |
|
return typename GemmKernel::MainloopArguments{ |
|
b_ptr, b_stride, a_ptr, a_stride, |
|
b_scales_ptr, layout_SFA, a_scales_ptr, layout_SFB |
|
}; |
|
} |
|
else { |
|
return typename GemmKernel::MainloopArguments{ |
|
a_ptr, a_stride, b_ptr, b_stride, |
|
a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB |
|
}; |
|
} |
|
}(); |
|
auto prob_shape = swap_ab ? cute::make_shape(n, m, k, 1) : cute::make_shape(m, n, k, 1); |
|
|
|
auto c_ptr = static_cast<ElementD*>(out.data_ptr()); |
|
typename GemmKernel::EpilogueArguments epilogue_args{ |
|
{}, c_ptr, c_stride, c_ptr, c_stride}; |
|
c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args, |
|
epilogue_args); |
|
} |
|
|
|
template <typename OutType> |
|
void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out, |
|
torch::Tensor const& a, |
|
torch::Tensor const& b, |
|
torch::Tensor const& a_scales, |
|
torch::Tensor const& b_scales) { |
|
int32_t m = a.size(0), n = b.size(1), k = a.size(1), sms; |
|
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device()); |
|
|
|
constexpr int TILE_K = 128; |
|
|
|
bool swap_ab = (m < 16) || (m % 4 != 0); |
|
bool use_tma_epilogue = (m * n) % 4 == 0; |
|
if (!swap_ab) { |
|
constexpr int TILE_N = 128; |
|
int tile_m = 256; |
|
if (cuda_utils::ceil_div(n, TILE_N) * cuda_utils::ceil_div(m, 64) <= sms) { |
|
tile_m = 64; |
|
} |
|
else if (cuda_utils::ceil_div(n, TILE_N) * cuda_utils::ceil_div(m, 128) <= sms) { |
|
tile_m = 128; |
|
} |
|
if (tile_m == 64) { |
|
if (use_tma_epilogue) { |
|
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise< |
|
OutType, 1, TILE_N, TILE_K, Shape<_64, Int<TILE_N>, Int<TILE_K>>, |
|
Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm, |
|
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>( |
|
out, a, b, a_scales, b_scales); |
|
} else { |
|
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise< |
|
OutType, 1, TILE_N, TILE_K, Shape<_64, Int<TILE_N>, Int<TILE_K>>, |
|
Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm, |
|
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>( |
|
out, a, b, a_scales, b_scales); |
|
} |
|
} else if (tile_m == 128) { |
|
if (use_tma_epilogue) { |
|
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise< |
|
OutType, 1, TILE_N, TILE_K, Shape<_128, Int<TILE_N>, Int<TILE_K>>, |
|
Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm, |
|
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>( |
|
out, a, b, a_scales, b_scales); |
|
} else { |
|
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise< |
|
OutType, 1, TILE_N, TILE_K, Shape<_128, Int<TILE_N>, Int<TILE_K>>, |
|
Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm, |
|
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>( |
|
out, a, b, a_scales, b_scales); |
|
} |
|
} else { |
|
if (use_tma_epilogue) { |
|
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise< |
|
OutType, 1, TILE_N, TILE_K, Shape<_256, Int<TILE_N>, Int<TILE_K>>, |
|
Shape<_2, _1, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm, |
|
cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>( |
|
out, a, b, a_scales, b_scales); |
|
} else { |
|
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise< |
|
OutType, 1, TILE_N, TILE_K, Shape<_256, Int<TILE_N>, Int<TILE_K>>, |
|
Shape<_2, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized2Sm, |
|
cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>( |
|
out, a, b, a_scales, b_scales); |
|
} |
|
} |
|
} else { |
|
|
|
constexpr int TILE_M = 128; |
|
constexpr int TILE_N = 16; |
|
|
|
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise< |
|
OutType, TILE_M, 1, TILE_K, Shape<Int<TILE_M>, Int<TILE_N>, Int<TILE_K>>, |
|
Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm, |
|
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100, true>>( |
|
out, a, b, a_scales, b_scales); |
|
} |
|
} |
|
|
|
} |
|
|