drbh committed · Commit 59bdff8 · Parent(s): f475609

fix: adjust sig types

Files changed:
- flash_mla/flash_mla_api.cu  +5 -19
- torch-ext/torch_binding.cpp +1 -1
- torch-ext/torch_binding.h   +3 -11
flash_mla/flash_mla_api.cu CHANGED

@@ -53,40 +53,26 @@ get_mla_metadata(
     return {tile_scheduler_metadata, num_splits};
 }
 
+// note doubles and longs are used in place of floats and ints
+// https://github.com/pytorch/pytorch/blob/338ed67a1e7aa98dd849f297533c5a71bea4b661/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h#L211
 std::vector<at::Tensor>
 mha_fwd_kvcache_mla(
     at::Tensor &q,                                // batch_size x seqlen_q x num_heads x head_size
     const at::Tensor &kcache,                     // num_blocks x page_block_size x num_heads_k x head_size
-
-    // TODO: fix for optional
-    // std::optional<const at::Tensor> &vcache_,  // num_blocks x page_block_size x num_heads_k x head_size_v
-    const at::Tensor &vcache_,                    // num_blocks x page_block_size x num_heads_k x head_size_v
-
+    const c10::optional<torch::Tensor> &vcache_,  // num_blocks x page_block_size x num_heads_k x head_size_v
     const int64_t head_size_v,
     const at::Tensor &seqlens_k,                  // batch_size
     const at::Tensor &block_table,                // batch_size x max_num_blocks_per_seq
-
-    // TODO: should be float
     const double softmax_scale,
-
+    bool is_causal,
     const at::Tensor &tile_scheduler_metadata,    // num_sm_parts x TileSchedulerMetaDataSize
     const at::Tensor &num_splits                  // batch_size + 1
-
-    // TODO: remove this once determined why build is adding this parameter
-    // const int64_t unknown_param
 ) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
     TORCH_CHECK(is_sm90);
 
-
-    bool is_causal = is_causal_;
-
-
-    // TODO: fix for optional
-    // at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
-    at::Tensor vcache = vcache_;
-
+    at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
     auto q_dtype = q.dtype();
     TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
 
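The comment added above the signature is the crux of the fix: arguments that cross the PyTorch dispatcher are boxed, so a schema `float` arrives in C++ as `double`, a schema `int` as `int64_t`, and a schema `Tensor?` as `c10::optional<torch::Tensor>`. The optional is what enables the fallback the old TODO asked for: when no separate value cache is passed, the kernel reuses the key cache. A minimal standalone sketch of that fallback (the helper name `resolve_vcache` is hypothetical, invented here for illustration):

#include <ATen/ATen.h>
#include <c10/util/Optional.h>

// Hypothetical helper mirroring the new inline fallback in
// mha_fwd_kvcache_mla: an absent vcache_ (None on the Python side)
// resolves to the key cache itself.
at::Tensor resolve_vcache(const c10::optional<at::Tensor> &vcache_,
                          const at::Tensor &kcache) {
  // Equivalent shorthand: return vcache_.value_or(kcache);
  return vcache_.has_value() ? vcache_.value() : kcache;
}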
torch-ext/torch_binding.cpp CHANGED

@@ -8,7 +8,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("get_mla_metadata", torch::kCUDA, &get_mla_metadata);
 
   // TOOD: remove last unknown_param when resolved
-  ops.def("mha_fwd_kvcache_mla(Tensor! q, Tensor! kcache, Tensor
+  ops.def("mha_fwd_kvcache_mla(Tensor! q, Tensor! kcache, Tensor? vcache_, int head_size_v, Tensor! seqlens_k, Tensor! block_table, float softmax_scale, bool is_causal_, Tensor! tile_scheduler_metadata, Tensor! num_splits) -> Tensor[]");
   ops.impl("mha_fwd_kvcache_mla", torch::kCUDA, &mha_fwd_kvcache_mla);
 }
 
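In the schema string, `Tensor!` marks an argument the kernel may mutate in place, `Tensor?` one that may be `None`, and `float`/`int` are the schema spellings that box to C++ `double`/`int64_t`. A minimal sketch of the same registration pattern under a hypothetical `demo` namespace (op name and arguments invented for illustration, not part of this repo):

#include <torch/library.h>
#include <ATen/ATen.h>
#include <vector>

// Hypothetical op exercising the same annotations as mha_fwd_kvcache_mla:
// Tensor! (mutable), Tensor? (optional), float -> double, int -> int64_t,
// and a Tensor[] return boxed from std::vector<at::Tensor>.
std::vector<at::Tensor> demo_op(
    at::Tensor &q,                        // "Tensor! q"
    const c10::optional<at::Tensor> &v_,  // "Tensor? v_"
    int64_t head_size,                    // "int head_size"
    double scale,                         // "float scale"
    bool causal) {                        // "bool causal"
  at::Tensor v = v_.has_value() ? v_.value() : q;
  (void)head_size;
  (void)causal;
  return {q, v * scale};
}

TORCH_LIBRARY(demo, ops) {
  ops.def("demo_op(Tensor! q, Tensor? v_, int head_size, float scale, bool causal) -> Tensor[]");
  ops.impl("demo_op", torch::kCPU, &demo_op);
}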
torch-ext/torch_binding.h CHANGED

@@ -13,21 +13,13 @@ std::vector<torch::Tensor>
 mha_fwd_kvcache_mla(
     torch::Tensor &q,
     const torch::Tensor &kcache,
-
-    // TODO: fix for optional
-    // std::optional<torch::Tensor> &vcache_,
-
-    const torch::Tensor &vcache_,
+    const c10::optional<torch::Tensor> &vcache_,
     const int64_t head_size_v,
     const torch::Tensor &seqlens_k,
     const torch::Tensor &block_table,
-
     // TODO:should be float
-    const double softmax_scale,
-
-    // TODO: fix for mutable bool
-    const bool is_causal_,
-
+    const torch::kFloat softmax_scale,
+    bool is_causal,
     const torch::Tensor &tile_scheduler_metadata,
     const torch::Tensor &num_splits
 );
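One wrinkle remains as committed: the header spells the scale parameter `const torch::kFloat softmax_scale`, but `torch::kFloat` is a `ScalarType` value rather than a type, whereas the `.cu` definition uses `const double softmax_scale`. The sketch below assumes the `double` form. A hypothetical call site (tensors assumed prepared elsewhere with the shapes given in the signature comments; the 576/512 head sizes and scale are illustrative values, not mandated by this commit):

#include <cmath>
#include <vector>
#include <torch/torch.h>
#include "torch_binding.h"

// Hypothetical caller: passing c10::nullopt for vcache_ makes the kernel
// substitute kcache as the value cache.
std::vector<torch::Tensor> decode_step(
    torch::Tensor &q, const torch::Tensor &kcache,
    const torch::Tensor &seqlens_k, const torch::Tensor &block_table,
    const torch::Tensor &tile_scheduler_metadata,
    const torch::Tensor &num_splits) {
  return mha_fwd_kvcache_mla(
      q, kcache,
      /*vcache_=*/c10::nullopt,
      /*head_size_v=*/512,
      seqlens_k, block_table,
      /*softmax_scale=*/1.0 / std::sqrt(576.0),
      /*is_causal=*/true,
      tile_scheduler_metadata, num_splits);
}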