File size: 15,352 Bytes
			
			| 4d1b54e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 | /******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/
#pragma once
#include <cute/tensor.hpp>
#include "cutlass/fast_math.h"  // For cutlass::FastDivmod
#include "utils.h"
namespace flash {
using namespace cute;
template <int kBlockM, int kHeadDim, int NumThreads, typename Element>
struct PackGQAManager {
    // We use CpAsync for Q, since TMA doesn't work there
    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static constexpr int kGmemElemsPerStore = kGmemElemsPerLoad;
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
    // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each
    // thread to have 4 loads in the M direction and 2 vectorized load in the K direction.
    // In the case of PackGQA, this reduces the number of times we need to call divmod.
    static constexpr int kBytePerRow = kHeadDim * sizeof(Element);
    static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element);
    static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;
    static_assert(NumThreads % kGmemThreadsPerRow == 0, "NumThreads must be a multiple of kGmemThreadsPerRow");
    // We assume threads loading the same row are in the same warp. This is for an optimization in PagedKV where
    // these threads share the same page table entry and share the work of computing pointers to paged K and paged V.
    static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0, "kGmemThreadsPerRow must divide NumThreadsPerWarp");
    using GmemCopyAtomCpAsync = cute::Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL_ZFILL<uint128_t>, Element>;
    using GmemLayoutAtom = Layout<Shape <Int<NumThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;
    using GmemTiledCopyQCpAsync = decltype(
        make_tiled_copy(GmemCopyAtomCpAsync{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, 8 or 16 vals per load
    // Was trying to have each WG loading Q to the rows in sQ that only that WG needs so that we only need
    // to sync within each WG, but didn't seem to be any faster.
    // using GmemLayoutAtomWG = Layout<Shape <Int<128 / kGmemThreadsPerRow>, Int<NumThreads / 128>, Int<kGmemThreadsPerRow> >,
    //     Stride<Int<kGmemThreadsPerRow>, _128, _1>>;
    // using GmemTiledCopyQCpAsyncWG = decltype(
    //     make_tiled_copy(GmemCopyAtomCpAsync{},
    //                     GmemLayoutAtomNew{},
    //                     Layout<Shape<_1, _1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, 8 or 16 vals per load
    using GmemTiledCopyO = decltype(
        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, Int<kGmemElemsPerStore>>>{}));  // Val layout, 8 or 16 vals per store
    template <int NumThreadsPerRow=kGmemThreadsPerRow, typename Engine, typename Layout, typename TensorC>
    CUTLASS_DEVICE
    static auto
    compute_ptr(Tensor<Engine, Layout> &tensor, TensorC const &tRows,
                cutlass::FastDivmod const &qhead_per_khead_divmod, int const thread_idx, int const m_block) {
        // tensor of shape ((qhead_per_khead, seqlen_q))
        static constexpr int NumPtrPerThread = cute::ceil_div(CUTE_STATIC_V(cute::size(tRows)), NumThreadsPerRow);
        using TensorType = typename Engine::value_type;
        Tensor tPrPtr = make_tensor<TensorType const*>(Shape<Int<NumPtrPerThread>>{});
        #pragma unroll
        for (int i = 0; i < NumPtrPerThread; ++i) {
            int const row = i * NumThreads + get<0>(tRows(thread_idx % NumThreadsPerRow));
            int const idx = m_block * kBlockM + row;
            int m_idx, h_idx;
            m_idx = qhead_per_khead_divmod.divmod(h_idx, idx);
            tPrPtr[i] = &tensor(make_coord(make_coord(h_idx, m_idx)));
        }
        return tPrPtr;
    }
    template <typename TensormQ, typename TensorsQ>
    CUTLASS_DEVICE
    static void
    load_Q(TensormQ const &mQ,  // ((qhead_per_khead, seqlen_q), headdim)
           TensorsQ &sQ,  // (kBlockM, kHeadDim)
           cutlass::FastDivmod const &qhead_per_khead_divmod,
           int const thread_idx, int const seqlen_q, int const m_block
          )
    {
        GmemTiledCopyQCpAsync gmem_tiled_copy_Q_cp_async;
        // GmemTiledCopyQCpAsyncNew gmem_tiled_copy_Q_cp_async;
        auto gmem_thr_copy_Q_cp_async = gmem_tiled_copy_Q_cp_async.get_thread_slice(thread_idx);
        Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
        Tensor tQcQ = gmem_thr_copy_Q_cp_async.partition_S(cQ);       // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
        Tensor tQsQ = gmem_thr_copy_Q_cp_async.partition_D(sQ);       // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
        // Tensor tQcQ_ = gmem_thr_copy_Q_cp_async.partition_S(cute::flat_divide(cQ, _64{}));       // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
        // Tensor tQsQ_ = gmem_thr_copy_Q_cp_async.partition_D(cute::flat_divide(sQ, _64{}));       // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
        // Tensor tQcQ = group_modes<1, rank(tQcQ_) - 1>(tQcQ_);
        // Tensor tQsQ = group_modes<1, rank(tQsQ_) - 1>(tQsQ_);
        Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
        #pragma unroll
        for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(_0{}, _0{}, k)) < size<1>(mQ); }
        // Similar to loading K and V when PagedKV, it's expensive to compute the pointers for Q.
        // We split the work among threads loading the same row of Q, then __shfl_sync the pointers.
        Tensor mQ_0 = mQ(_, _0{});
        Tensor tQcQ_row = tQcQ(_0{}, _, _0{});
        Tensor tPrQPtr = compute_ptr(mQ_0, tQcQ_row, qhead_per_khead_divmod, thread_idx, m_block);
        int const qhead_per_khead = qhead_per_khead_divmod.divisor;
        #pragma unroll
        for (int m = 0; m < size<1>(tQsQ); ++m) {
            int idx = m_block * kBlockM + get<0>(tQcQ(_0{}, m, _0{}));
            Element const* q_ptr = reinterpret_cast<Element const*>(__shfl_sync(0xffffffff, reinterpret_cast<uint64_t>(tPrQPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow));
            if (idx < seqlen_q * qhead_per_khead) {
                // if (thread_idx == 0) { printf("m: %d, m_idx: %d, h_idx: %d, q_ptr = %p, q_ptr_og = %p\n", m, m_idx, h_idx, q_ptr, &mQ_copy(0, make_coord(h_idx, m_idx), 0));}
                Tensor mQ_cur = make_tensor(make_gmem_ptr(q_ptr), Shape<Int<kHeadDim>>{});
                Tensor mQ_cur_copy = cute::tiled_divide(mQ_cur, Shape<Int<kGmemElemsPerLoad>>{});
                #pragma unroll
                for (int k = 0; k < size<2>(tQsQ); ++k) {
                    int ki = get<1>(tQcQ(_0{}, _0{}, k)) / kGmemElemsPerLoad;
                    // the "tiled_copy.with(tQpQ(k))"" will fill in zero for columns where tQpQ(k) is false
                    // TODO: check this
                    cute::copy(gmem_tiled_copy_Q_cp_async.with(tQpQ(k)), mQ_cur_copy(_, ki), tQsQ(_, m, k));
                }
            } // Don't need to fill in 0s for sQ since we're not gonna write the output to gmem for those rows
        }
    };
    template <typename TensormLSE, typename TensorsLSE, typename TiledMma>
    CUTLASS_DEVICE
    static void
    store_LSE(TensormLSE &mLSE,  // ((qhead_per_khead, seqlen_q))
              TensorsLSE const &tLSErLSE,  // (kBlockM) split across threads according to tiled_mma
              TiledMma tiled_mma,
              cutlass::FastDivmod const &qhead_per_khead_divmod,
              int const thread_idx, int const seqlen_o, int const m_block
             )
    {
        Tensor caccO = cute::make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
        auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
        Tensor taccOcO = thread_mma.partition_C(caccO);                           // (MMA,MMA_M,MMA_K)
        Tensor taccOcO_row = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout()))(_, _0{});
        CUTE_STATIC_ASSERT_V(size(tLSErLSE) == size(taccOcO_row));                     // MMA_M
        // If PackGQA, we split the work of compute divmod among threads in the same row
        static constexpr int kMmaThreadsPerRow = size<0, 0>(typename TiledMma::AtomLayoutC_TV{});
        static_assert(cutlass::NumThreadsPerWarp % kMmaThreadsPerRow == 0);
        static_assert(CUTE_STATIC_V(size(tLSErLSE)) <= kMmaThreadsPerRow);
        static_assert(CUTE_STATIC_V(size(taccOcO_row)) <= kMmaThreadsPerRow);
        Tensor tPrLSEPtr = compute_ptr<kMmaThreadsPerRow>(mLSE, taccOcO_row, qhead_per_khead_divmod, thread_idx, m_block);
        static_assert(CUTE_STATIC_V(size(tPrLSEPtr)) == 1);
        int const qhead_per_khead = qhead_per_khead_divmod.divisor;
        #pragma unroll
        for (int mi = 0; mi < size(tLSErLSE); ++mi) {
            int const row = m_block * kBlockM + get<0>(taccOcO_row(mi));
            float* ptr_LSE_cur = reinterpret_cast<float*>(__shfl_sync(0xffffffff, reinterpret_cast<uint64_t>(tPrLSEPtr[0]), mi % kMmaThreadsPerRow, kMmaThreadsPerRow));
            if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o * qhead_per_khead) {
                *ptr_LSE_cur = tLSErLSE(mi);
            }
        }
    };
    template <typename TensormO, typename TensorrO>
    CUTLASS_DEVICE
    static void
    store_O(TensormO &mO,  // ((qhead_per_khead, seqlen_o), headdim)
            TensorrO const &tOrO,  // (kBlockM, kHeadDim) split across threads according to gmem_tiled_copy_O
            cutlass::FastDivmod const &qhead_per_khead_divmod,
            int const thread_idx, int const seqlen_o, int const m_block
          )
    {
        GmemTiledCopyO gmem_tiled_copy_O;
        auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
        Tensor cO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
        Tensor tOcO = gmem_thr_copy_O.partition_D(cO);       // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
        Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOcO)));
        #pragma unroll
        for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < size<1>(mO); }
        // Similar to loading K and V when PagedKV, it's expensive to compute the pointers for O.
        // We split the work among threads loading the same row of O, then __shfl_sync the pointers.
        Tensor mO_0 = mO(_, _0{});
        Tensor tOcO_row = tOcO(_0{}, _, _0{});
        Tensor tPrOPtr = compute_ptr(mO_0, tOcO_row, qhead_per_khead_divmod, thread_idx, m_block);
        int const qhead_per_khead = qhead_per_khead_divmod.divisor;
        #pragma unroll
        for (int m = 0; m < size<1>(tOrO); ++m) {
            int idx = m_block * kBlockM + get<0>(tOcO(_0{}, m, _0{}));
            Element* o_ptr = reinterpret_cast<Element*>(__shfl_sync(0xffffffff, reinterpret_cast<uint64_t>(tPrOPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow));
            if (idx < seqlen_o * qhead_per_khead) {
                Tensor mO_cur = make_tensor(make_gmem_ptr(o_ptr), Shape<Int<kHeadDim>>{});
                Tensor mO_cur_copy = cute::tiled_divide(mO_cur, Shape<Int<kGmemElemsPerStore>>{});
                #pragma unroll
                for (int k = 0; k < size<2>(tOrO); ++k) {
                    int ki = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerStore;
                    if (tOpO(k)) {
                        cute::copy(gmem_tiled_copy_O, tOrO(_, m, k), mO_cur_copy(_, ki));
                    }
                }
            }
        }
    };
    template <typename TensormO, typename TensorrO, typename TiledMma>
    CUTLASS_DEVICE
    static void
    store_O_direct(TensormO &mO,  // ((qhead_per_khead, seqlen_o), headdim)
                   TensorrO const &tOrO,  // (kBlockM, kHeadDim) split across threads according to tiled_mma
                   TiledMma tiled_mma,
                   cutlass::FastDivmod const &qhead_per_khead_divmod,
                   int const thread_idx, int const seqlen_o, int const m_block
                 )
    {
        static constexpr int kGmemElemsPerStoreDirect = 2;
        cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element> gmem_copy_direct;
        // Reshape acc from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
        Tensor tOrO_rowcol = make_tensor(tOrO.data(), flash::convert_layout_acc_rowcol(tOrO.layout()));
        Tensor tOrO_copy = cute::tiled_divide(tOrO_rowcol, Shape<_1, Int<kGmemElemsPerStoreDirect>>{});
        Tensor caccO = cute::make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
        auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
        Tensor taccOcO = thread_mma.partition_C(caccO);                           // (MMA,MMA_M,MMA_K)
        Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout()));
        Tensor taccOcO_row = taccOcO_rowcol(_, _0{});
        Tensor taccOcO_col = taccOcO_rowcol(_0{}, _);
        // If PackGQA, we split the work of compute divmod among threads in the same row
        static constexpr int kMmaThreadsPerRow = size<0, 0>(typename TiledMma::AtomLayoutC_TV{});
        static_assert(cutlass::NumThreadsPerWarp % kMmaThreadsPerRow == 0);
        static_assert(CUTE_STATIC_V(size(taccOcO_row)) <= kMmaThreadsPerRow);
        // Similar to loading K and V when PagedKV, it's expensive to compute the pointers for O.
        // We split the work among threads loading the same row of O, then __shfl_sync the pointers.
        Tensor mO_0 = mO(_, _0{});
        Tensor tPrOPtr = compute_ptr<kMmaThreadsPerRow>(mO_0, taccOcO_row, qhead_per_khead_divmod, thread_idx, m_block);
        static_assert(CUTE_STATIC_V(size(tPrOPtr)) == 1);
        int const qhead_per_khead = qhead_per_khead_divmod.divisor;
        #pragma unroll
        for (int m = 0; m < size<1>(tOrO_copy); ++m) {
            int row = m_block * kBlockM + get<0>(taccOcO_row(m));
            Element* o_ptr = reinterpret_cast<Element*>(__shfl_sync(0xffffffff, reinterpret_cast<uint64_t>(tPrOPtr[0]), m % kMmaThreadsPerRow, kMmaThreadsPerRow));
            if (row < seqlen_o * qhead_per_khead) {
                Tensor mO_cur = make_tensor(make_gmem_ptr(o_ptr), Shape<Int<kHeadDim>>{});
                Tensor mO_cur_copy = cute::tiled_divide(mO_cur, Shape<Int<kGmemElemsPerStoreDirect>>{});
                #pragma unroll
                for (int k = 0; k < size<2>(tOrO_copy); ++k) {
                    int col = get<1>(taccOcO_col(k * kGmemElemsPerStoreDirect));
                    if (col < size<1>(mO)) {
                        cute::copy(gmem_copy_direct, tOrO_copy(_, m, k), mO_cur_copy(_, col / kGmemElemsPerStoreDirect));
                    }
                }
            }
        }
    };
};
} // namespace flash
 | 
