|
#pragma once |
|
|
|
#include "scaled_mm.cuh" |
|
#include "cutlass_gemm_caller.cuh" |
|
|
|
|
|
|
|
|
|
|
|
|
|
namespace vllm { |
|
|
|
using c3x::cutlass_gemm_caller; |
|
|
|
template <typename InType, typename OutType, |
|
template <typename, typename, typename> typename Epilogue> |
|
struct sm100_fp8_config_default { |
|
|
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); |
|
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; |
|
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; |
|
using TileShape = Shape<_256, _128, _128>; |
|
using ClusterShape = Shape<_2, _2, _1>; |
|
using Cutlass3xGemm = |
|
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape, |
|
KernelSchedule, EpilogueSchedule>; |
|
}; |
|
|
|
template <typename InType, typename OutType, |
|
template <typename, typename, typename> typename Epilogue> |
|
struct sm100_fp8_config_M256 { |
|
|
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); |
|
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; |
|
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; |
|
using TileShape = Shape<_128, _128, _128>; |
|
using ClusterShape = Shape<_2, _1, _1>; |
|
using Cutlass3xGemm = |
|
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape, |
|
KernelSchedule, EpilogueSchedule>; |
|
}; |
|
|
|
template <typename InType, typename OutType, |
|
template <typename, typename, typename> typename Epilogue> |
|
struct sm100_fp8_config_M64 { |
|
|
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); |
|
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; |
|
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; |
|
using TileShape = Shape<_64, _64, _128>; |
|
using ClusterShape = Shape<_1, _1, _1>; |
|
using Cutlass3xGemm = |
|
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape, |
|
KernelSchedule, EpilogueSchedule>; |
|
}; |
|
|
|
template <typename InType, typename OutType, |
|
template <typename, typename, typename> typename Epilogue> |
|
struct sm100_fp8_config_M16 { |
|
|
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); |
|
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; |
|
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; |
|
using TileShape = Shape<_64, _64, _128>; |
|
using ClusterShape = Shape<_1, _4, _1>; |
|
using Cutlass3xGemm = |
|
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape, |
|
KernelSchedule, EpilogueSchedule>; |
|
}; |
|
|
|
template <typename InType, typename OutType, |
|
template <typename, typename, typename> typename Epilogue, |
|
typename... EpilogueArgs> |
|
inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out, |
|
torch::Tensor const& a, |
|
torch::Tensor const& b, |
|
EpilogueArgs&&... args) { |
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); |
|
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); |
|
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); |
|
|
|
using Cutlass3xGemmDefault = |
|
typename sm100_fp8_config_default<InType, OutType, |
|
Epilogue>::Cutlass3xGemm; |
|
using Cutlass3xGemmM16 = |
|
typename sm100_fp8_config_M16<InType, OutType, Epilogue>::Cutlass3xGemm; |
|
using Cutlass3xGemmM64 = |
|
typename sm100_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm; |
|
using Cutlass3xGemmM256 = |
|
typename sm100_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm; |
|
|
|
uint32_t const m = a.size(0); |
|
uint32_t const mp2 = |
|
std::max(static_cast<uint32_t>(16), next_pow_2(m)); |
|
|
|
if (mp2 <= 16) { |
|
|
|
return cutlass_gemm_caller<Cutlass3xGemmM16>( |
|
out, a, b, std::forward<EpilogueArgs>(args)...); |
|
} else if (mp2 <= 64) { |
|
|
|
return cutlass_gemm_caller<Cutlass3xGemmM64>( |
|
out, a, b, std::forward<EpilogueArgs>(args)...); |
|
} else if (mp2 <= 256) { |
|
|
|
return cutlass_gemm_caller<Cutlass3xGemmM256>( |
|
out, a, b, std::forward<EpilogueArgs>(args)...); |
|
} else { |
|
|
|
return cutlass_gemm_caller<Cutlass3xGemmDefault>( |
|
out, a, b, std::forward<EpilogueArgs>(args)...); |
|
} |
|
} |
|
|
|
template <template <typename, typename, typename> typename Epilogue, |
|
typename... EpilogueArgs> |
|
void cutlass_scaled_mm_sm100_fp8_epilogue(torch::Tensor& out, |
|
torch::Tensor const& a, |
|
torch::Tensor const& b, |
|
EpilogueArgs&&... epilogue_args) { |
|
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); |
|
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); |
|
|
|
if (out.dtype() == torch::kBFloat16) { |
|
return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t, |
|
cutlass::bfloat16_t, Epilogue>( |
|
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); |
|
} else { |
|
TORCH_CHECK(out.dtype() == torch::kFloat16); |
|
return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t, |
|
cutlass::half_t, Epilogue>( |
|
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); |
|
} |
|
} |
|
|
|
} |