|  | #pragma once | 
					
						
						|  |  | 
					
						
						|  | #include "scaled_mm_c2x.cuh" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | namespace vllm { | 
					
						
						|  |  | 
					
						
						|  | template <typename InType, typename OutType, | 
					
						
						|  | template <typename, typename> typename Epilogue> | 
					
						
						|  | struct sm89_int8_fallback_gemm { | 
					
						
						|  |  | 
					
						
						|  | static_assert(std::is_same<InType, int8_t>()); | 
					
						
						|  | using TileShape = cutlass::gemm::GemmShape<32, 64, 128>; | 
					
						
						|  | using WarpShape = cutlass::gemm::GemmShape<16, 64, 64>; | 
					
						
						|  | using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; | 
					
						
						|  | static int32_t const MainLoopStages = 5; | 
					
						
						|  |  | 
					
						
						|  | using Cutlass2xGemm = | 
					
						
						|  | cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90, InType, OutType, | 
					
						
						|  | Epilogue, TileShape, WarpShape, InstructionShape, 5>; | 
					
						
						|  | }; | 
					
						
						|  |  | 
					
						
						|  | struct sm89_int8_config_default { | 
					
						
						|  |  | 
					
						
						|  | using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; | 
					
						
						|  | using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; | 
					
						
						|  |  | 
					
						
						|  | template <typename InType, typename OutType, | 
					
						
						|  | template <typename, typename> typename Epilogue, | 
					
						
						|  | typename... EpilogueArgs> | 
					
						
						|  | static void dispatch(torch::Tensor& out, torch::Tensor const& a, | 
					
						
						|  | torch::Tensor const& b, EpilogueArgs&&... args) { | 
					
						
						|  | static_assert(std::is_same<InType, int8_t>()); | 
					
						
						|  | TORCH_CHECK(a.dtype() == torch::kInt8); | 
					
						
						|  |  | 
					
						
						|  | using FallbackGemm = | 
					
						
						|  | typename sm89_int8_fallback_gemm<InType, OutType, | 
					
						
						|  | Epilogue>::Cutlass2xGemm; | 
					
						
						|  |  | 
					
						
						|  | uint32_t const n = out.size(1); | 
					
						
						|  | uint32_t const np2 = next_pow_2(n); | 
					
						
						|  |  | 
					
						
						|  | if (np2 <= 4096) { | 
					
						
						|  | using TileShape = cutlass::gemm::GemmShape<128, 128, 64>; | 
					
						
						|  |  | 
					
						
						|  | return vllm::fallback_cutlass_gemm_caller< | 
					
						
						|  | vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, | 
					
						
						|  | InType, OutType, Epilogue, TileShape, WarpShape, | 
					
						
						|  | InstructionShape, 5>, | 
					
						
						|  | FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); | 
					
						
						|  | } else if (np2 <= 8192) { | 
					
						
						|  | using TileShape = cutlass::gemm::GemmShape<256, 128, 64>; | 
					
						
						|  |  | 
					
						
						|  | return vllm::fallback_cutlass_gemm_caller< | 
					
						
						|  | vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, | 
					
						
						|  | InType, OutType, Epilogue, TileShape, WarpShape, | 
					
						
						|  | InstructionShape, 3>, | 
					
						
						|  | FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); | 
					
						
						|  | } else if (np2 <= 16384) { | 
					
						
						|  | using TileShape = cutlass::gemm::GemmShape<128, 128, 64>; | 
					
						
						|  |  | 
					
						
						|  | return vllm::fallback_cutlass_gemm_caller< | 
					
						
						|  | vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, | 
					
						
						|  | InType, OutType, Epilogue, TileShape, WarpShape, | 
					
						
						|  | InstructionShape, 5>, | 
					
						
						|  | FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); | 
					
						
						|  | } else { | 
					
						
						|  | using TileShape = cutlass::gemm::GemmShape<256, 128, 64>; | 
					
						
						|  |  | 
					
						
						|  | return vllm::fallback_cutlass_gemm_caller< | 
					
						
						|  | vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, | 
					
						
						|  | InType, OutType, Epilogue, TileShape, WarpShape, | 
					
						
						|  | InstructionShape, 3>, | 
					
						
						|  | FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); | 
					
						
						|  | } | 
					
						
						|  | } | 
					
						
						|  | }; | 
					
						
						|  |  | 
					
						
						|  | struct sm89_int8_config_M256 { | 
					
						
						|  |  | 
					
						
						|  | using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; | 
					
						
						|  | using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; | 
					
						
						|  |  | 
					
						
						|  | template <typename InType, typename OutType, | 
					
						
						|  | template <typename, typename> typename Epilogue, | 
					
						
						|  | typename... EpilogueArgs> | 
					
						
						|  | static void dispatch(torch::Tensor& out, torch::Tensor const& a, | 
					
						
						|  | torch::Tensor const& b, EpilogueArgs&&... args) { | 
					
						
						|  | static_assert(std::is_same<InType, int8_t>()); | 
					
						
						|  | TORCH_CHECK(a.dtype() == torch::kInt8); | 
					
						
						|  |  | 
					
						
						|  | using FallbackGemm = | 
					
						
						|  | typename sm89_int8_fallback_gemm<InType, OutType, | 
					
						
						|  | Epilogue>::Cutlass2xGemm; | 
					
						
						|  |  | 
					
						
						|  | uint32_t const n = out.size(1); | 
					
						
						|  | uint32_t const np2 = next_pow_2(n); | 
					
						
						|  |  | 
					
						
						|  | if (np2 <= 4096) { | 
					
						
						|  | using TileShape = cutlass::gemm::GemmShape<64, 128, 128>; | 
					
						
						|  |  | 
					
						
						|  | return vllm::fallback_cutlass_gemm_caller< | 
					
						
						|  | vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, | 
					
						
						|  | InType, OutType, Epilogue, TileShape, WarpShape, | 
					
						
						|  | InstructionShape, 3>, | 
					
						
						|  | FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); | 
					
						
						|  | } else if (np2 <= 8192) { | 
					
						
						|  | using TileShape = cutlass::gemm::GemmShape<128, 128, 64>; | 
					
						
						|  |  | 
					
						
						|  | return vllm::fallback_cutlass_gemm_caller< | 
					
						
						|  | vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, | 
					
						
						|  | InType, OutType, Epilogue, TileShape, WarpShape, | 
					
						
						|  | InstructionShape, 5>, | 
					
						
						|  | FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); | 
					
						
						|  | } else if (np2 <= 16384) { | 
					
						
						|  | using TileShape = cutlass::gemm::GemmShape<256, 128, 64>; | 
					
						
						|  |  | 
					
						
						|  | return vllm::fallback_cutlass_gemm_caller< | 
					
						
						|  | vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, | 
					
						
						|  | InType, OutType, Epilogue, TileShape, WarpShape, | 
					
						
						|  | InstructionShape, 3>, | 
					
						
						|  | FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); | 
					
						
						|  | } else { | 
					
						
						|  | using TileShape = cutlass::gemm::GemmShape<128, 128, 64>; | 
					
						
						|  |  | 
					
						
						|  | return vllm::fallback_cutlass_gemm_caller< | 
					
						
						|  | vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, | 
					
						
						|  | InType, OutType, Epilogue, TileShape, WarpShape, | 
					
						
						|  | InstructionShape, 5>, | 
					
						
						|  | FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); | 
					
						
						|  | } | 
					
						
						|  | } | 
					
						
						|  | }; | 
					
						
						|  |  | 
					
						
						|  | struct sm89_int8_config_M128 { | 
					
						
						|  |  | 
					
						
						|  | using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; | 
					
						
						|  |  | 
					
						
						|  | template <typename InType, typename OutType, | 
					
						
						|  | template <typename, typename> typename Epilogue, | 
					
						
						|  | typename... EpilogueArgs> | 
					
						
						|  | static void dispatch(torch::Tensor& out, torch::Tensor const& a, | 
					
						
						|  | torch::Tensor const& b, EpilogueArgs&&... args) { | 
					
						
						|  | static_assert(std::is_same<InType, int8_t>()); | 
					
						
						|  | TORCH_CHECK(a.dtype() == torch::kInt8); | 
					
						
						|  |  | 
					
						
						|  | using FallbackGemm = | 
					
						
						|  | typename sm89_int8_fallback_gemm<InType, OutType, | 
					
						
						|  | Epilogue>::Cutlass2xGemm; | 
					
						
						|  |  | 
					
						
						|  | uint32_t const n = out.size(1); | 
					
						
						|  | uint32_t const np2 = next_pow_2(n); | 
					
						
						|  |  | 
					
						
						|  | if (np2 <= 8192) { | 
					
						
						|  | using TileShape = cutlass::gemm::GemmShape<64, 128, 128>; | 
					
						
						|  | using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; | 
					
						
						|  |  | 
					
						
						|  | return vllm::fallback_cutlass_gemm_caller< | 
					
						
						|  | vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, | 
					
						
						|  | InType, OutType, Epilogue, TileShape, WarpShape, | 
					
						
						|  | InstructionShape, 3>, | 
					
						
						|  | FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); | 
					
						
						|  | } else if (np2 <= 16384) { | 
					
						
						|  | using TileShape = cutlass::gemm::GemmShape<128, 128, 64>; | 
					
						
						|  | using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; | 
					
						
						|  |  | 
					
						
						|  | return vllm::fallback_cutlass_gemm_caller< | 
					
						
						|  | vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, | 
					
						
						|  | InType, OutType, Epilogue, TileShape, WarpShape, | 
					
						
						|  | InstructionShape, 5>, | 
					
						
						|  | FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); | 
					
						
						|  | } else { | 
					
						
						|  | using TileShape = cutlass::gemm::GemmShape<64, 64, 128>; | 
					
						
						|  | using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; | 
					
						
						|  |  | 
					
						
						|  | return vllm::fallback_cutlass_gemm_caller< | 
					
						
						|  | vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, | 
					
						
						|  | InType, OutType, Epilogue, TileShape, WarpShape, | 
					
						
						|  | InstructionShape, 5>, | 
					
						
						|  | FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); | 
					
						
						|  | } | 
					
						
						|  | } | 
					
						
						|  | }; | 
					
						
						|  |  | 
					
						
						|  | struct sm89_int8_config_M64 { | 
					
						
						|  |  | 
					
						
						|  | using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; | 
					
						
						|  |  | 
					
						
						|  | template <typename InType, typename OutType, | 
					
						
						|  | template <typename, typename> typename Epilogue, | 
					
						
						|  | typename... EpilogueArgs> | 
					
						
						|  | static void dispatch(torch::Tensor& out, torch::Tensor const& a, | 
					
						
						|  | torch::Tensor const& b, EpilogueArgs&&... args) { | 
					
						
						|  | static_assert(std::is_same<InType, int8_t>()); | 
					
						
						|  | TORCH_CHECK(a.dtype() == torch::kInt8); | 
					
						
						|  |  | 
					
						
						|  | using FallbackGemm = | 
					
						
						|  | typename sm89_int8_fallback_gemm<InType, OutType, | 
					
						
						|  | Epilogue>::Cutlass2xGemm; | 
					
						
						|  |  | 
					
						
						|  | uint32_t const n = out.size(1); | 
					
						
						|  | uint32_t const np2 = next_pow_2(n); | 
					
						
						|  |  | 
					
						
						|  | if (np2 <= 8192) { | 
					
						
						|  | using TileShape = cutlass::gemm::GemmShape<64, 64, 128>; | 
					
						
						|  | using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; | 
					
						
						|  |  | 
					
						
						|  | return vllm::fallback_cutlass_gemm_caller< | 
					
						
						|  | vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, | 
					
						
						|  | InType, OutType, Epilogue, TileShape, WarpShape, | 
					
						
						|  | InstructionShape, 5>, | 
					
						
						|  | FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); | 
					
						
						|  | } else { | 
					
						
						|  | using TileShape = cutlass::gemm::GemmShape<64, 128, 128>; | 
					
						
						|  | using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; | 
					
						
						|  |  | 
					
						
						|  | return vllm::fallback_cutlass_gemm_caller< | 
					
						
						|  | vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, | 
					
						
						|  | InType, OutType, Epilogue, TileShape, WarpShape, | 
					
						
						|  | InstructionShape, 3>, | 
					
						
						|  | FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); | 
					
						
						|  | } | 
					
						
						|  | } | 
					
						
						|  | }; | 
					
						
						|  |  | 
					
						
						|  | struct sm89_int8_config_M32 { | 
					
						
						|  |  | 
					
						
						|  | using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; | 
					
						
						|  |  | 
					
						
						|  | template <typename InType, typename OutType, | 
					
						
						|  | template <typename, typename> typename Epilogue, | 
					
						
						|  | typename... EpilogueArgs> | 
					
						
						|  | static void dispatch(torch::Tensor& out, torch::Tensor const& a, | 
					
						
						|  | torch::Tensor const& b, EpilogueArgs&&... args) { | 
					
						
						|  | static_assert(std::is_same<InType, int8_t>()); | 
					
						
						|  | TORCH_CHECK(a.dtype() == torch::kInt8); | 
					
						
						|  |  | 
					
						
						|  | using FallbackGemm = | 
					
						
						|  | typename sm89_int8_fallback_gemm<InType, OutType, | 
					
						
						|  | Epilogue>::Cutlass2xGemm; | 
					
						
						|  |  | 
					
						
						|  | uint32_t const n = out.size(1); | 
					
						
						|  | uint32_t const np2 = next_pow_2(n); | 
					
						
						|  |  | 
					
						
						|  | if (np2 <= 8192) { | 
					
						
						|  | using TileShape = cutlass::gemm::GemmShape<32, 64, 128>; | 
					
						
						|  | using WarpShape = cutlass::gemm::GemmShape<16, 64, 64>; | 
					
						
						|  |  | 
					
						
						|  | return vllm::fallback_cutlass_gemm_caller< | 
					
						
						|  | vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, | 
					
						
						|  | InType, OutType, Epilogue, TileShape, WarpShape, | 
					
						
						|  | InstructionShape, 5>, | 
					
						
						|  | FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); | 
					
						
						|  | } else { | 
					
						
						|  | using TileShape = cutlass::gemm::GemmShape<32, 128, 128>; | 
					
						
						|  | using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; | 
					
						
						|  |  | 
					
						
						|  | return vllm::fallback_cutlass_gemm_caller< | 
					
						
						|  | vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, | 
					
						
						|  | InType, OutType, Epilogue, TileShape, WarpShape, | 
					
						
						|  | InstructionShape, 4>, | 
					
						
						|  | FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); | 
					
						
						|  | } | 
					
						
						|  | } | 
					
						
						|  | }; | 
					
						
						|  |  | 
					
						
						|  | struct sm89_int8_config_M16 { | 
					
						
						|  |  | 
					
						
						|  | using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>; | 
					
						
						|  | using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; | 
					
						
						|  |  | 
					
						
						|  | template <typename InType, typename OutType, | 
					
						
						|  | template <typename, typename> typename Epilogue, | 
					
						
						|  | typename... EpilogueArgs> | 
					
						
						|  | static void dispatch(torch::Tensor& out, torch::Tensor const& a, | 
					
						
						|  | torch::Tensor const& b, EpilogueArgs&&... args) { | 
					
						
						|  | static_assert(std::is_same<InType, int8_t>()); | 
					
						
						|  | TORCH_CHECK(a.dtype() == torch::kInt8); | 
					
						
						|  |  | 
					
						
						|  | using FallbackGemm = | 
					
						
						|  | typename sm89_int8_fallback_gemm<InType, OutType, | 
					
						
						|  | Epilogue>::Cutlass2xGemm; | 
					
						
						|  |  | 
					
						
						|  | uint32_t const n = out.size(1); | 
					
						
						|  | uint32_t const np2 = next_pow_2(n); | 
					
						
						|  |  | 
					
						
						|  | if (np2 <= 8192) { | 
					
						
						|  | using TileShape = cutlass::gemm::GemmShape<16, 64, 128>; | 
					
						
						|  |  | 
					
						
						|  | return vllm::fallback_cutlass_gemm_caller< | 
					
						
						|  | vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, | 
					
						
						|  | InType, OutType, Epilogue, TileShape, WarpShape, | 
					
						
						|  | InstructionShape, 5>, | 
					
						
						|  | FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); | 
					
						
						|  | } else { | 
					
						
						|  | using TileShape = cutlass::gemm::GemmShape<16, 128, 128>; | 
					
						
						|  |  | 
					
						
						|  | return vllm::fallback_cutlass_gemm_caller< | 
					
						
						|  | vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90, | 
					
						
						|  | InType, OutType, Epilogue, TileShape, WarpShape, | 
					
						
						|  | InstructionShape, 4>, | 
					
						
						|  | FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...); | 
					
						
						|  | } | 
					
						
						|  | } | 
					
						
						|  | }; | 
					
						
						|  |  | 
					
						
						|  | template <typename InType, typename OutType, | 
					
						
						|  | template <typename, typename> typename Epilogue, | 
					
						
						|  | typename... EpilogueArgs> | 
					
						
						|  | inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out, | 
					
						
						|  | torch::Tensor const& a, | 
					
						
						|  | torch::Tensor const& b, | 
					
						
						|  | EpilogueArgs&&... args) { | 
					
						
						|  | static_assert(std::is_same<InType, int8_t>()); | 
					
						
						|  | TORCH_CHECK(a.dtype() == torch::kInt8); | 
					
						
						|  | TORCH_CHECK(b.dtype() == torch::kInt8); | 
					
						
						|  |  | 
					
						
						|  | uint32_t const m = a.size(0); | 
					
						
						|  | uint32_t const mp2 = | 
					
						
						|  | std::max(static_cast<uint32_t>(16), next_pow_2(m)); | 
					
						
						|  |  | 
					
						
						|  | if (mp2 <= 16) { | 
					
						
						|  |  | 
					
						
						|  | return sm89_int8_config_M16::dispatch<InType, OutType, Epilogue>( | 
					
						
						|  | out, a, b, std::forward<EpilogueArgs>(args)...); | 
					
						
						|  | } else if (mp2 <= 32) { | 
					
						
						|  |  | 
					
						
						|  | return sm89_int8_config_M32::dispatch<InType, OutType, Epilogue>( | 
					
						
						|  | out, a, b, std::forward<EpilogueArgs>(args)...); | 
					
						
						|  | } else if (mp2 <= 64) { | 
					
						
						|  |  | 
					
						
						|  | return sm89_int8_config_M64::dispatch<InType, OutType, Epilogue>( | 
					
						
						|  | out, a, b, std::forward<EpilogueArgs>(args)...); | 
					
						
						|  | } else if (mp2 <= 128) { | 
					
						
						|  |  | 
					
						
						|  | return sm89_int8_config_M128::dispatch<InType, OutType, Epilogue>( | 
					
						
						|  | out, a, b, std::forward<EpilogueArgs>(args)...); | 
					
						
						|  | } else if (mp2 <= 256) { | 
					
						
						|  |  | 
					
						
						|  | return sm89_int8_config_M256::dispatch<InType, OutType, Epilogue>( | 
					
						
						|  | out, a, b, std::forward<EpilogueArgs>(args)...); | 
					
						
						|  | } else { | 
					
						
						|  |  | 
					
						
						|  | return sm89_int8_config_default::dispatch<InType, OutType, Epilogue>( | 
					
						
						|  | out, a, b, std::forward<EpilogueArgs>(args)...); | 
					
						
						|  | } | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | } | 
					
						
						|  |  |