From 10379c03ffa7adf931233ab3949ba91680c69342 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=AD=99=E6=98=8A=E8=BE=B0?= Date: Tue, 5 Aug 2025 21:56:27 +0800 Subject: [PATCH 1/2] custom ops support paged_cache_load --- .../ms_kernels_internal/paged_cache_load.cc | 232 ++++++++++++++++++ .../paged_cache_load_op.yaml | 40 +++ 2 files changed, 272 insertions(+) create mode 100644 ccsrc/ops/ms_kernels_internal/paged_cache_load.cc create mode 100644 yaml/ms_kernels_internal/paged_cache_load_op.yaml diff --git a/ccsrc/ops/ms_kernels_internal/paged_cache_load.cc b/ccsrc/ops/ms_kernels_internal/paged_cache_load.cc new file mode 100644 index 0000000..5b48d67 --- /dev/null +++ b/ccsrc/ops/ms_kernels_internal/paged_cache_load.cc @@ -0,0 +1,232 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// ============================================================================= +// GRAPH MODE IMPLEMENTATION +// ============================================================================= + +#include "internal_kernel_mod.h" +#include "mindspore/ops/ops_utils/op_utils.h" +#include "ops/ops_func_impl/op_func_impl.h" + +namespace ms_custom_ops { +constexpr size_t kInputKeyCacheIndex = 0; +constexpr size_t kInputValueCacheIndex = 1; +constexpr size_t kInputBlockTableIndex = 2; +constexpr size_t kInputSeqLensIndex = 3; +constexpr size_t kInputKeyIndex = 4; +constexpr size_t kInputValueIndex = 5; +constexpr size_t kInputSeqStartsIndex = 6; +constexpr size_t kInputParamKvCacheCfgIndex = 7; +constexpr size_t kInputParamIsSeqLensCumsumTypeIndex = 8; +constexpr size_t kInputParamHasSeqStartsIndex = 9; +constexpr size_t kOutputKeyOutIndex = 0; +constexpr size_t kOutputValueOutIndex = 1; +class OPS_API CustomPagedCacheLoadOpFuncImpl : public OpFuncImpl { +public: + ShapeArray InferShape(const PrimitivePtr &primitive, + const InferInfoPtrList &input_infos) const override { + return {input_infos[kInputKeyIndex]->GetShape(), input_infos[kInputValueIndex]->GetShape()}; + } + std::vector + InferType(const PrimitivePtr &primitive, + const InferInfoPtrList &input_infos) const override { + return {{input_infos[kInputKeyIndex]->GetType(), input_infos[kInputValueIndex]->GetType()}}; + } + + bool GeneralInferRegistered() const override { return true; } +}; +} // namespace ms_custom_ops + +namespace ms_custom_ops { +class CustomPagedCacheLoad : public InternalKernelMod { +public: + CustomPagedCacheLoad() : InternalKernelMod(), skip_execution_(false) {} + ~CustomPagedCacheLoad() = default; + + void InitKernelInputsOutputsIndex() override { + kernel_inputs_index_ = {kInputKeyCacheIndex, kInputValueCacheIndex, kInputBlockTableIndex, kInputSeqLensIndex, + kInputKeyIndex, kInputValueIndex, kInputSeqStartsIndex}; + kernel_outputs_index_ = {kOutputKeyOutIndex, kOutputValueOutIndex}; + } + + int Resize(const std::vector &inputs, + const std::vector &outputs) override { + // Check if any input has shape containing 0 + for (const auto &input : inputs) { + if (input == 
nullptr) + continue; + auto shape = input->GetShapeVector(); + for (const auto &dim : shape) { + if (dim == 0) { + MS_LOG(INFO) << "paged_cache_load: Skipping execution due to zero " + "dimension in input shape: " + << shape; + skip_execution_ = true; + return KernelMod::Resize(inputs, outputs); // Skip execution + } + } + } + + skip_execution_ = false; + // Call base class implementation + return InternalKernelMod::Resize(inputs, outputs); + } + + bool Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, + void *stream_ptr) override { + // Skip execution if flag is set + if (skip_execution_) { + return true; // Skip execution, return success + } + + // Call base class implementation + return InternalKernelMod::Launch(inputs, workspace, outputs, stream_ptr); + } + +protected: + internal::InternalOpPtr + CreateKernel(const internal::InputsImmutableInfoList &inputs, + const internal::OutputsImmutableInfoList &outputs, + const std::vector &ms_inputs, + const std::vector &ms_outputs) override { + internal::PagedCacheLoadParam param; + auto kv_cache_cfg_type = ms_inputs.at(kInputParamKvCacheCfgIndex); + auto is_seq_lens_cumsum_type = ms_inputs.at(kInputParamIsSeqLensCumsumTypeIndex); + auto has_seq_starts = ms_inputs.at(kInputParamHasSeqStartsIndex); + param.kv_cache_cfg_type = kv_cache_cfg_type->GetValue().value(); + param.is_seq_lens_cumsum_type = is_seq_lens_cumsum_type->GetValue().value(); + param.has_seq_starts = has_seq_starts->GetValue().value(); + return internal::CreatePagedCacheLoadOp(inputs, outputs, param, internal::kInternalPagedCacheLoadOpName); + } + +private: + bool skip_execution_; // Flag to skip execution when shape contains 0 +}; +} // namespace ms_custom_ops +REG_GRAPH_MODE_OP(paged_cache_load, ms_custom_ops::CustomPagedCacheLoadOpFuncImpl, + ms_custom_ops::CustomPagedCacheLoad); + +// ============================================================================= +// PYBOOST MODE IMPLEMENTATION +// ============================================================================= + +#include "internal_pyboost_runner.h" + +using namespace ms_custom_ops; +namespace ms::pynative { +class PagedCacheLoadRunner : public InternalPyboostRunner { +public: + using InternalPyboostRunner::InternalPyboostRunner; + void SetKvCacheCfg(const int32_t &kv_cache_cfg) { this->kv_cache_cfg_ = kv_cache_cfg; } + void SetIsSeqLensCumsumType(const bool &is_seq_lens_cumsum_type) { + this->is_seq_lens_cumsum_type_ = is_seq_lens_cumsum_type; + } + void SetHasSeqStarts(const bool &has_seq_starts) { this->has_seq_starts_ = has_seq_starts; } + internal::PagedCacheLoadParam param_; +protected: + internal::InternalOpPtr + CreateKernel(const internal::InputsImmutableInfoList &inputs, + const internal::OutputsImmutableInfoList &outputs) override { + return internal::CreatePagedCacheLoadOp( + inputs, outputs, param_, internal::kInternalPagedCacheLoadOpName); + } + +private: + int32_t kv_cache_cfg_{0}; + bool is_seq_lens_cumsum_type_{false}; + bool has_seq_starts_{false}; +}; +MS_KERNELS_INTERNAL_NAME_REG(PagedCacheLoad, + internal::kInternalPagedCacheLoadOpName); +} // namespace ms::pynative + +namespace ms_custom_ops { +// Helper function to convert optional tensor to tensor or empty tensor +ms::Tensor GetTensorOrEmpty(const std::optional &opt_tensor); + +// infer shape and type func +// ms::Tensor GenResultTensor(const ms::Tensor &key) { +// return ms::Tensor(key.data_type(), key.shape()); +// } + +std::vector npu_paged_cache_load(const ms::Tensor &key_cache, + const 
ms::Tensor &value_cache, + const ms::Tensor &block_table, + const ms::Tensor &seq_lens, + const ms::Tensor &key, + const ms::Tensor &value, + const std::optional &seq_starts, + std::optional kv_cache_cfg, + std::optional is_seq_lens_cumsum_type, + std::optional has_seq_starts) { + auto op_name = "PagedCacheLoad"; + auto runner = std::make_shared(op_name); + MS_EXCEPTION_IF_NULL(runner); + + // Set head_num if provided + if (kv_cache_cfg.has_value()) { + runner->SetKvCacheCfg(static_cast(kv_cache_cfg.value())); + } + if (is_seq_lens_cumsum_type.has_value()) { + runner->SetIsSeqLensCumsumType(is_seq_lens_cumsum_type.value()); + } + if (has_seq_starts.has_value()) { + runner->SetHasSeqStarts(has_seq_starts.value()); + } + runner->param_.kv_cache_cfg_type = static_cast(kv_cache_cfg.value()); + runner->param_.is_seq_lens_cumsum_type = is_seq_lens_cumsum_type.value(); + runner->param_.has_seq_starts = has_seq_starts.value(); + + // Setup the runner with all parameters (including hash calculation) + runner->Setup(op_name, key_cache, value_cache, block_table, seq_lens, key, value, seq_starts, kv_cache_cfg, + is_seq_lens_cumsum_type, has_seq_starts); + std::vector inputs = {key_cache, value_cache, block_table, seq_lens, key, value, + GetTensorOrEmpty(seq_starts)}; + std::vector outputs = {key, value}; + runner->GetOrCreateKernel(inputs, outputs); + runner->Run(inputs, outputs); + return outputs; +} +} // namespace ms_custom_ops + +auto pyboost_paged_cache_load(const ms::Tensor &key_cache, + const ms::Tensor &value_cache, + const ms::Tensor &block_table, + const ms::Tensor &seq_lens, + const ms::Tensor &key, + const ms::Tensor &value, + const std::optional &seq_starts, + std::optional kv_cache_cfg, + std::optional is_seq_lens_cumsum_type, + std::optional has_seq_starts) { + return ms::pynative::PyboostRunner::Call<2>( + ms_custom_ops::npu_paged_cache_load, key_cache, value_cache, block_table, seq_lens, key, value, seq_starts, + kv_cache_cfg, is_seq_lens_cumsum_type, has_seq_starts); +} + +MS_CUSTOM_OPS_EXTENSION_MODULE(m) { + m.def("paged_cache_load", &pyboost_paged_cache_load, "Paged Cache Load", + pybind11::arg("key_cache"), pybind11::arg("value_cache"), + pybind11::arg("block_table"), pybind11::arg("seq_lens"), + pybind11::arg("key"), pybind11::arg("value"), + pybind11::arg("seq_starts") = std::nullopt, + pybind11::arg("kv_cache_cfg") = std::nullopt, + pybind11::arg("is_seq_lens_cumsum_type") = std::nullopt, + pybind11::arg("has_seq_starts") = std::nullopt); +} diff --git a/yaml/ms_kernels_internal/paged_cache_load_op.yaml b/yaml/ms_kernels_internal/paged_cache_load_op.yaml new file mode 100644 index 0000000..309fdc7 --- /dev/null +++ b/yaml/ms_kernels_internal/paged_cache_load_op.yaml @@ -0,0 +1,40 @@ +#operator paged_cache_load +paged_cache_load: + args: + key_cache: + dtype: tensor + value_cache: + dtype: tensor + block_tables: + dtype: tensor + seq_lens: + dtype: tensor + key: + dtype: tensor + value: + dtype: tensor + seq_starts: + dtype: tensor + default: None + kv_cache_cfg: + dtype: int + default: 0 + is_seq_lens_cumsum_type: + dtype: bool + default: false + has_seq_starts: + dtype: bool + default: false + args_signature: + rw_write: key, value + labels: + side_effect_mem: True + returns: + key_out: + dtype: tensor + inplace: key + value_out: + dtype: tensor + inplace: value + class: + name: PagedCacheLoad -- Gitee From 0b14e5ea736aa2f32b46274636c42f684b0b951e Mon Sep 17 00:00:00 2001 From: tianxiaodong3 Date: Mon, 4 Aug 2025 09:54:57 +0800 Subject: [PATCH 2/2] support custom ringmla op 
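
Add a ring multi-head latent attention (ring_mla) custom op with both a graph
mode path (CustomRingMLAOpFuncImpl / CustomRingMLA) and a pyboost path
(RingMLARunner), backed by internal::CreateRingMLAOp. Query and key are passed
as split tensors (a 128-dim no-rope part and a 64-dim rope part), per-batch
sequence lengths are given via q_seq_lens and context_lens, and when calc_type
selects the ring-update path the previous step's output and log-sum-exp
(o_prev, lse_prev) are folded into the result. The op returns the attention
output and a float32 LSE.

A minimal PyNative-mode usage sketch follows; argument order, shapes, and
dtypes mirror tests/st/test_ms_ring_mla.py added in this patch, and the random
inputs are illustrative only:

    import math
    import numpy as np
    import mindspore as ms
    import ms_custom_ops

    heads, kv_heads, d_nope, d_rope, dv = 16, 16, 128, 64, 128
    q_lens, kv_lens = [100, 100], [100, 100]
    nq, nkv = sum(q_lens), sum(kv_lens)

    def rand(*shape):
        # fp16 inputs, uniform in [-1, 1), as in the test file
        return ms.Tensor(np.random.uniform(-1, 1, shape).astype(np.float16))

    q_nope, q_rope = rand(nq, heads, d_nope), rand(nq, heads, d_rope)
    k_nope, k_rope = rand(nkv, kv_heads, d_nope), rand(nkv, kv_heads, d_rope)
    value = rand(nkv, kv_heads, dv)
    # calc_type=1 is the first ring step: previous output/LSE are placeholders.
    o_prev = ms.Tensor(np.zeros((nq, heads, dv), np.float16))
    lse_prev = ms.Tensor(np.zeros((heads, nq), np.float32))
    # The tests move the per-batch length tensors to CPU before the call.
    q_seq_lens = ms.Tensor(np.array(q_lens, np.int32)).move_to("CPU")
    context_lens = ms.Tensor(np.array(kv_lens, np.int32)).move_to("CPU")

    attn_out, lse = ms_custom_ops.ring_mla(
        q_nope, q_rope, k_nope, k_rope, value,
        None, None, None, None, None, None, None, None,  # mask/alibi/quant inputs unused here
        o_prev, lse_prev, q_seq_lens, context_lens,
        heads, 1.0 / math.sqrt(d_nope + d_rope), kv_heads, 0, 1)
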
--- .../ms_kernels_internal/ring_mla/ring_mla.cc | 288 +++++++++++ .../ms_kernels_internal/ring_mla/ring_mla.h | 128 +++++ .../ring_mla/ring_mla_runner.cc | 206 ++++++++ .../ring_mla/ring_mla_runner.h | 61 +++ tests/st/test_ms_ring_mla.py | 470 ++++++++++++++++++ yaml/ms_kernels_internal/ring_mla_op.yaml | 72 +++ 6 files changed, 1225 insertions(+) create mode 100644 ccsrc/ops/ms_kernels_internal/ring_mla/ring_mla.cc create mode 100644 ccsrc/ops/ms_kernels_internal/ring_mla/ring_mla.h create mode 100644 ccsrc/ops/ms_kernels_internal/ring_mla/ring_mla_runner.cc create mode 100644 ccsrc/ops/ms_kernels_internal/ring_mla/ring_mla_runner.h create mode 100644 tests/st/test_ms_ring_mla.py create mode 100644 yaml/ms_kernels_internal/ring_mla_op.yaml diff --git a/ccsrc/ops/ms_kernels_internal/ring_mla/ring_mla.cc b/ccsrc/ops/ms_kernels_internal/ring_mla/ring_mla.cc new file mode 100644 index 0000000..60bba55 --- /dev/null +++ b/ccsrc/ops/ms_kernels_internal/ring_mla/ring_mla.cc @@ -0,0 +1,288 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ring_mla.h" + +namespace ms_custom_ops { + +void CustomRingMLAOpFuncImpl::CheckInputShape(const PrimitivePtr &primitive, + const InferInfoPtrList &input_infos) const { + // Helper lambda for shape check + auto check_shape_rank = [](const std::vector &shape, size_t expected_rank, const std::string &name) { + MS_CHECK_VALUE(shape.size() == expected_rank, + CheckAndConvertUtils::FormatCommMsg("For RingMLA The rank of " + name + " must be ", expected_rank, + ", but got shape: ", shape)); + }; + + auto check_head_dim = [](const std::vector &shape, int64_t expected, const std::string &name) { + MS_CHECK_VALUE(shape.back() == expected, + CheckAndConvertUtils::FormatCommMsg("For RingMLA The headDim of " + name + " must be ", expected, + ", but got shape: ", shape)); + }; + + // query + if (!input_infos[kQueryIdx]->IsDynamic()) { + const auto &query_shape = input_infos[kQueryIdx]->GetShape(); + check_shape_rank(query_shape, QKV_SHAPE_RANK, "query"); + check_head_dim(query_shape, QK_SPLIT1_HEAD_DIM, "query"); + } + + // query_rope + if (!input_infos[kQueryRopeIdx]->IsDynamic()) { + const auto &query_rope_shape = input_infos[kQueryRopeIdx]->GetShape(); + check_shape_rank(query_rope_shape, QKV_SHAPE_RANK, "query_rope"); + check_head_dim(query_rope_shape, QK_SPLIT2_HEAD_DIM, "query_rope"); + } + + // key + if (!input_infos[kKeyIdx]->IsDynamic()) { + const auto &key_shape = input_infos[kKeyIdx]->GetShape(); + check_shape_rank(key_shape, QKV_SHAPE_RANK, "key"); + check_head_dim(key_shape, QK_SPLIT1_HEAD_DIM, "key"); + } + + // key_rope + if (!input_infos[kKeyRopeIdx]->IsDynamic()) { + const auto &key_rope_shape = input_infos[kKeyRopeIdx]->GetShape(); + check_shape_rank(key_rope_shape, QKV_SHAPE_RANK, "key_rope"); + check_head_dim(key_rope_shape, QK_SPLIT2_HEAD_DIM, "key_rope"); + } + + // value + if (!input_infos[kValueIdx]->IsDynamic()) { + const auto &value_shape = 
input_infos[kValueIdx]->GetShape(); + check_shape_rank(value_shape, QKV_SHAPE_RANK, "value"); + check_head_dim(value_shape, QK_SPLIT1_HEAD_DIM, "value"); + } + + if (is_input_softmax_lse_) { + if (!input_infos[kOPrevIdx]->IsDynamic()) { + const auto &prev_out_shape = input_infos[kOPrevIdx]->GetShape(); + check_shape_rank(prev_out_shape, QKV_SHAPE_RANK, "prev_out"); + check_head_dim(prev_out_shape, QK_SPLIT1_HEAD_DIM, "prev_out"); + } + + if (!input_infos[kLsePrevIdx]->IsDynamic()) { + const auto &prev_lse_shape = input_infos[kLsePrevIdx]->GetShape(); + check_shape_rank(prev_lse_shape, LSE_SHAPE_RANK, "prev_lse"); + } + } +} + +ShapeArray CustomRingMLAOpFuncImpl::InferShape(const PrimitivePtr &primitive, + const InferInfoPtrList &input_infos) const { + auto calc_type = static_cast( + input_infos[kCalcTypeIdx]->GetScalarValueWithCheck()); + is_input_softmax_lse_ = (calc_type == internal::RingMLAParam::CalcType::CALC_TYPE_DEFAULT); + (void)CheckInputShape(primitive, input_infos); + const auto &query_shape = input_infos[kQueryIdx]->GetShape(); + const auto &value_shape = input_infos[kValueIdx]->GetShape(); + ShapeVector attn_out_shape = query_shape; + attn_out_shape[QKV_HEAD_DIM_IDX] = value_shape[QKV_HEAD_DIM_IDX]; + + ShapeVector lse_out_shape; + if (is_input_softmax_lse_) { + lse_out_shape = input_infos[kLsePrevIdx]->GetShape(); + return {attn_out_shape, lse_out_shape}; + } + lse_out_shape = query_shape; + lse_out_shape[LSE_N_TOKENS_IDX] = query_shape[QKV_N_TOKENS_IDX]; + lse_out_shape[LSE_HEAD_NUM_IDX] = query_shape[QKV_HEAD_NUM_IDX]; + lse_out_shape.resize(LSE_SHAPE_RANK); + return {attn_out_shape, lse_out_shape}; +} + +std::vector CustomRingMLAOpFuncImpl::InferType(const PrimitivePtr &primitive, + const InferInfoPtrList &input_infos) const { + auto query_type = input_infos[kQueryIdx]->GetType(); + return {query_type, TypeId::kNumberTypeFloat32}; +} + +bool CustomRingMLA::RingMLAParamCheck(const internal::RingMLAParam &op_param) { + if (op_param.calcType != internal::RingMLAParam::CalcType::CALC_TYPE_DEFAULT && + op_param.calcType != internal::RingMLAParam::CalcType::CALC_TYPE_FISRT_RING) { + MS_LOG(ERROR) << "Ring MLA expects calcType to be one of CALC_TYPE_DEFAULT, CALC_TYPE_FISRT_RING. 
" + << "But got param.calcType = " << op_param.calcType; + return false; + } + if (op_param.headNum <= 0) { + MS_LOG(ERROR) << "Ring MLA expects headNum to be greater than zero, But got param.headNum = " << op_param.headNum; + return false; + } + if (op_param.kvHeadNum < 0) { + MS_LOG(ERROR) << "Ring MLA expects kvHeadNum to be no less than zero, " + << "But got param.kvHeadNum = " << op_param.kvHeadNum; + return false; + } + if (op_param.kvHeadNum > 0 && op_param.headNum % op_param.kvHeadNum != 0) { + MS_LOG(ERROR) << "Ring MLA expects headNum to be divisible by kvHeadNum, " + << "But got param.headNum = " << op_param.headNum + << ", param.kvHeadNum = " << op_param.kvHeadNum; + return false; + } + if (op_param.headNum < op_param.kvHeadNum) { + MS_LOG(ERROR) << "Ring MLA expects headNum >= kvHeadNum, " + << "But got param.headNum = " << op_param.headNum + << ", param.kvHeadNum = " << op_param.kvHeadNum; + return false; + } + if (op_param.maskType != internal::RingMLAParam::MaskType::NO_MASK && + op_param.maskType != internal::RingMLAParam::MaskType::MASK_TYPE_TRIU) { + MS_LOG(ERROR) << "Ring MLA expects maskType as one of NO_MASK, MASK_TYPE_TRIU, " + << "But got param.maskType = " << op_param.maskType; + return false; + } + if (op_param.inputLayout != internal::RingMLAParam::InputLayout::TYPE_BSND) { + MS_LOG(ERROR) << "Ring MLA only supports inputLayout as TYPE_BSND, " + << "But got param.inputLayout = " << op_param.inputLayout; + return false; + } + if (op_param.kernelType != internal::RingMLAParam::KernelType::KERNELTYPE_HIGH_PRECISION) { + MS_LOG(ERROR) << "Ring MLA only supports kernelType as KERNELTYPE_HIGH_PRECISION, " + << "But got param.kernelType = " << op_param.kernelType; + return false; + } + return true; +} + +// Helper to extract a vector from a KernelTensor, supporting int32 and int64 +static void ExtractSeqLenVector(KernelTensor *const seq_len_tensor, std::vector *out_vec) { + MS_EXCEPTION_IF_NULL(seq_len_tensor); + out_vec->clear(); + TypeId dtype = seq_len_tensor->dtype_id(); + if (dtype == kNumberTypeInt64) { + const auto &vec64 = seq_len_tensor->GetValueWithCheck>(); + out_vec->assign(vec64.begin(), vec64.end()); + } else if (dtype == kNumberTypeInt32) { + *out_vec = seq_len_tensor->GetValueWithCheck>(); + } else { + MS_LOG(EXCEPTION) << "actual_seq_lengths data type must be Int32 or Int64, but got " + << TypeIdToString(dtype); + } +} + +// Returns true if the new sequence length vector is different from the old one +static bool NeedUpdateSeqLen(const std::vector &old_seq_len, const std::vector &new_seq_len) { + if (old_seq_len.size() != new_seq_len.size()) { + return true; + } + for (size_t i = 0; i < new_seq_len.size(); ++i) { + if (old_seq_len[i] != new_seq_len[i]) { + return true; + } + } + return false; +} + +// Updates seq_len from the input tensor if needed, returns true if update is needed +static bool GetSeqLenFromInputAndCheckUpdate(const std::string &kernel_name, const std::string &tensor_name, + KernelTensor *const seq_len_tensor, std::vector *seq_len) { + MS_EXCEPTION_IF_NULL(seq_len_tensor); + + // If the tensor is not None, extract and compare + if (seq_len_tensor->type_id() != kMetaTypeNone) { + std::vector new_seq_len; + ExtractSeqLenVector(seq_len_tensor, &new_seq_len); + + bool need_update = NeedUpdateSeqLen(*seq_len, new_seq_len); + if (need_update) { + *seq_len = std::move(new_seq_len); + } + + MS_LOG(INFO) << "For op '" << kernel_name << "', set param seq_len with tensor_input '" << tensor_name << "' as " + << (*seq_len); + return need_update; + } 
+ + // If tensor is None, handle accordingly + MS_LOG(INFO) << "For op '" << kernel_name << "', param seq_len must be set, but none of '" + << tensor_name << "' is found in tensor_input"; + if (seq_len->empty()) { + // No previous value, nothing to update + return false; + } + // Previous value exists, but now input is None: clear and signal update + seq_len->clear(); + return true; +} + +internal::InternalOpPtr CustomRingMLA::CreateKernel(const internal::InputsImmutableInfoList &inputs_ii, + const internal::OutputsImmutableInfoList &outputs_ii, + const std::vector &ms_inputs, + const std::vector &ms_outputs) { + // Extract and set all required parameters from ms_inputs + param_.headNum = static_cast(ms_inputs[kHeadNumIdx]->GetValueWithCheck()); + param_.qkScale = ms_inputs[kQkScaleIdx]->GetValueWithCheck(); + param_.kvHeadNum = static_cast(ms_inputs[kKvHeadNumIdx]->GetValueWithCheck()); + param_.maskType = static_cast( + ms_inputs[kMaskTypeIdx]->GetValueWithCheck()); + param_.calcType = static_cast( + ms_inputs[kCalcTypeIdx]->GetValueWithCheck()); + + // Update sequence lengths from input tensors + (void)GetSeqLenFromInputAndCheckUpdate(kernel_name_, "q_seq_lens", ms_inputs[kQSeqLenIdx], ¶m_.qSeqLen); + (void)GetSeqLenFromInputAndCheckUpdate(kernel_name_, "batch_valid_length", + ms_inputs[kKVSeqLenIdx], ¶m_.kvSeqLen); + + MS_CHECK_VALUE(RingMLAParamCheck(param_), + CheckAndConvertUtils::FormatCommMsg("For RingMLA The param is invalid, please check the input " + "parameters, kernel_name: ", kernel_name_)); + + created_flag_ = true; + return internal::CreateRingMLAOp(inputs_ii, outputs_ii, param_, internal::kInternalRingMLAOpName); +} + +bool CustomRingMLA::UpdateParam(const std::vector &inputs, + const std::vector &outputs) { + if (created_flag_) { + // Sequence lengths already initialized in CreateKernel, skip update + created_flag_ = false; + return true; + } + + // Check if either q_seq_len or kv_seq_len needs update + bool q_need_update = GetSeqLenFromInputAndCheckUpdate(kernel_name_, "q_seq_lens", + inputs[kQSeqLenIdx], ¶m_.qSeqLen); + bool kv_need_update = GetSeqLenFromInputAndCheckUpdate(kernel_name_, "batch_valid_length", + inputs[kKVSeqLenIdx], ¶m_.kvSeqLen); + + if (q_need_update || kv_need_update) { + auto ret = internal_op_->UpdateParam(¶m_); + if (ret != internal::kInternalOk) { + MS_LOG(ERROR) << "CustomRingMLA UpdateParam failed, kernel_name: " << kernel_name_; + return false; + } + return true; + } + + return true; +} + +uint64_t CustomRingMLA::GenerateTilingKey(const std::vector &inputs) { + // User defined CacheKey, the inputs should include all the factors which will affect tiling result. 
+ return InternalTilingCache::GenerateKey(kernel_name_, inputs, param_.qSeqLen, param_.kvSeqLen); +} + +void CustomRingMLA::InitKernelInputsOutputsIndex() { + kernel_inputs_index_ = {kQueryIdx, kQueryRopeIdx, kKeyIdx, kKeyRopeIdx, kValueIdx, kMaskIdx, kAlibiCoeffIdx, + kDeqQKIdx, kOffsetQKIdx, kDeqPVIdx, kOffsetPVIdx, kQuantPIdx, kLogNIdx, + kOPrevIdx, kLsePrevIdx}; + kernel_outputs_index_ = {kAttentionOutIdx, kSoftmaxLseOutIdx}; +} + +} // namespace ms_custom_ops + +REG_GRAPH_MODE_OP(ring_mla, ms_custom_ops::CustomRingMLAOpFuncImpl, ms_custom_ops::CustomRingMLA); diff --git a/ccsrc/ops/ms_kernels_internal/ring_mla/ring_mla.h b/ccsrc/ops/ms_kernels_internal/ring_mla/ring_mla.h new file mode 100644 index 0000000..159b4c7 --- /dev/null +++ b/ccsrc/ops/ms_kernels_internal/ring_mla/ring_mla.h @@ -0,0 +1,128 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_KERNELS_INTERNAL_RING_MLA_H_ +#define MS_KERNELS_INTERNAL_RING_MLA_H_ + +#include +#include +#include +#include +#include "internal_kernel_mod.h" +#include "ir/tensor.h" +#include "kernel/ascend/acl_ir/acl_convert.h" +#include "mindspore/ops/ops_utils/op_utils.h" +#include "ms_extension/api.h" +#include "ops/base_operator.h" +#include "ops/ops_func_impl/op_func_impl.h" +#include "ops/ops_func_impl/simple_infer.h" +#include "runtime/device/kernel_runtime.h" +#include "utils/check_convert_utils.h" + +namespace { +// shape rank +constexpr auto QKV_SHAPE_RANK = 3; // [sum(seqlen), headNum, headSize] +constexpr auto LSE_SHAPE_RANK = 2; // [headNum, qNTokens] +// query, key, value dim index +constexpr auto QKV_N_TOKENS_IDX = 0; +constexpr auto QKV_HEAD_NUM_IDX = 1; +constexpr auto QKV_HEAD_DIM_IDX = 2; +constexpr auto QK_SPLIT1_HEAD_DIM = 128; +constexpr auto QK_SPLIT2_HEAD_DIM = 64; +// lse dim index +constexpr auto LSE_N_TOKENS_IDX = 1; +constexpr auto LSE_HEAD_NUM_IDX = 0; +// seqlen, mask index +constexpr auto SEQLEN_BATCH_IDX = 0; + +enum RingMLAInputIndex : int { + kQueryIdx = 0, + kQueryRopeIdx, + kKeyIdx, + kKeyRopeIdx, + kValueIdx, + kMaskIdx, + kAlibiCoeffIdx, + kDeqQKIdx, + kOffsetQKIdx, + kDeqPVIdx, + kOffsetPVIdx, + kQuantPIdx, + kLogNIdx, + kOPrevIdx, + kLsePrevIdx, + kQSeqLenIdx, + kKVSeqLenIdx, + kHeadNumIdx, + kQkScaleIdx, + kKvHeadNumIdx, + kMaskTypeIdx, + kCalcTypeIdx, + kRingMLAInputNums +}; + +enum RingMLAOutputIndex : int { + kAttentionOutIdx = 0, + kSoftmaxLseOutIdx, + kRingMLAOutputNums +}; +} // namespace + +namespace ms_custom_ops { + +class OPS_API CustomRingMLAOpFuncImpl : public OpFuncImpl { + public: + ShapeArray InferShape(const PrimitivePtr &primitive, + const InferInfoPtrList &input_infos) const override; + std::vector InferType(const PrimitivePtr &primitive, + const InferInfoPtrList &input_infos) const override; + bool GeneralInferRegistered() const override { return true; } + std::set GetValueDependArgIndices() const override { + return {kQSeqLenIdx, kKVSeqLenIdx, kHeadNumIdx, kQkScaleIdx, kKvHeadNumIdx, kMaskTypeIdx, kCalcTypeIdx}; 
+ }; + + protected: + void CheckInputShape(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const; + + private: + mutable bool is_input_softmax_lse_{false}; +}; + +class CustomRingMLA : public InternalKernelMod { + public: + CustomRingMLA() = default; + ~CustomRingMLA() override = default; + void InitKernelInputsOutputsIndex() override; + + protected: + internal::InternalOpPtr CreateKernel( + const internal::InputsImmutableInfoList &inputs, + const internal::OutputsImmutableInfoList &outputs, + const std::vector &ms_inputs, + const std::vector &ms_outputs) override; + bool UpdateParam(const std::vector &inputs, + const std::vector &outputs) override; + uint64_t GenerateTilingKey(const std::vector &inputs) override; + + private: + bool RingMLAParamCheck(const internal::RingMLAParam &op_param); + bool created_flag_{false}; + internal::RingMLAParam param_; +}; + +} // namespace ms_custom_ops + +#endif // MS_KERNELS_INTERNAL_RING_MLA_H_ diff --git a/ccsrc/ops/ms_kernels_internal/ring_mla/ring_mla_runner.cc b/ccsrc/ops/ms_kernels_internal/ring_mla/ring_mla_runner.cc new file mode 100644 index 0000000..c5f2075 --- /dev/null +++ b/ccsrc/ops/ms_kernels_internal/ring_mla/ring_mla_runner.cc @@ -0,0 +1,206 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ring_mla_runner.h" + +using namespace ms_custom_ops; +namespace ms::pynative { + +namespace { + +inline bool GetSeqLenFromInputTensor(const ms::Tensor &input_tensor, std::vector *seq_len) { + if (seq_len == nullptr) { + MS_LOG(EXCEPTION) << "For GetSeqLenFromInputTensor, the seq_len ptr is nullptr."; + } + auto input_tensor_ptr = input_tensor.tensor(); + auto input_tensor_value = static_cast(input_tensor_ptr->data_c()); + if (input_tensor_value == nullptr) { + MS_LOG(EXCEPTION) << "For GetSeqLenFromInputTensor, the input_tensor_value is nullptr."; + } + auto input_tensor_value_num = input_tensor.numel(); + seq_len->clear(); + for (size_t i = 0; i < input_tensor_value_num; ++i) { + seq_len->emplace_back(input_tensor_value[i]); + } + return true; +} + +} // namespace + +void RingMLARunner::SetSeqLen(const std::optional &q_seq_lens, + const std::optional &context_lens) { + if (!q_seq_lens.has_value() || !context_lens.has_value()) { + MS_LOG(EXCEPTION) << "For RingMLARunner, the q_seq_lens and context_lens must not be None."; + return; + } + (void)GetSeqLenFromInputTensor(q_seq_lens.value(), ¶m_.qSeqLen); + (void)GetSeqLenFromInputTensor(context_lens.value(), ¶m_.kvSeqLen); +} + +void RingMLARunner::SetRingMLAParam(int64_t head_num, float scale_value, + int64_t kv_head_num, int64_t mask_type, int64_t calc_type) { + param_.headNum = static_cast(head_num); + param_.qkScale = scale_value; + param_.kvHeadNum = static_cast(kv_head_num); + param_.maskType = static_cast(mask_type); + param_.calcType = static_cast(calc_type); +} + +bool RingMLARunner::UpdateParam() { + if (created_flag_) { + created_flag_ = false; + return true; + } + if (internal_op_ == nullptr) { + MS_LOG(ERROR) << "RingMLARunner UpdateParam failed, internal_op_ is nullptr."; + return false; + } + auto ret = internal_op_->UpdateParam(¶m_); + if (ret != internal::kInternalOk) { + MS_LOG(ERROR) << "RingMLARunner UpdateParam failed."; + return false; + } + return true; +} + +internal::InternalOpPtr RingMLARunner::CreateKernel( + const internal::InputsImmutableInfoList &inputs, + const internal::OutputsImmutableInfoList &outputs) { + created_flag_ = true; + return internal::CreateRingMLAOp(inputs, outputs, param_, internal::kInternalRingMLAOpName); +} + +MS_KERNELS_INTERNAL_NAME_REG(RingMLA, internal::kInternalRingMLAOpName); + +} // namespace ms::pynative + +namespace ms_custom_ops { + +namespace { + +ms::Tensor ToTensorOrEmpty(const std::optional &opt_tensor) { + return opt_tensor.has_value() ? 
opt_tensor.value() : ms::Tensor(); +} + +ms::Tensor GenAttnOutTensor(const ms::Tensor &query) { + return ms::Tensor(query.data_type(), query.shape()); +} + +ms::Tensor GenLseOutTensor(const ms::Tensor &query, const std::optional &lse_prev, + const int64_t &calc_type) { + using CalcType = internal::RingMLAParam::CalcType; + bool is_ring = static_cast(calc_type) == CalcType::CALC_TYPE_DEFAULT; + if (is_ring && lse_prev.has_value()) { + return ms::Tensor(lse_prev.value().data_type(), lse_prev.value().shape()); + } + + constexpr size_t QKV_N_TOKENS_IDX = 0; + constexpr size_t QKV_HEAD_NUM_IDX = 1; + constexpr size_t LSE_N_TOKENS_IDX = 1; + constexpr size_t LSE_HEAD_NUM_IDX = 0; + constexpr size_t LSE_SHAPE_RANK = 2; // [headNum, qNTokens] + + auto query_shape = query.shape(); + auto lse_out_shape = query_shape; + lse_out_shape[LSE_N_TOKENS_IDX] = query_shape[QKV_N_TOKENS_IDX]; + lse_out_shape[LSE_HEAD_NUM_IDX] = query_shape[QKV_HEAD_NUM_IDX]; + lse_out_shape.resize(LSE_SHAPE_RANK); + return ms::Tensor(TypeId::kNumberTypeFloat32, lse_out_shape); +} + +} // namespace + +std::vector npu_ring_mla( + const ms::Tensor &query, const ms::Tensor &query_rope, const ms::Tensor &key, + const ms::Tensor &key_rope, const ms::Tensor &value, const std::optional &mask, + const std::optional &alibi_coeff, const std::optional &deq_scale_qk, + const std::optional &deq_offset_qk, const std::optional &deq_scale_pv, + const std::optional &deq_offset_pv, const std::optional &quant_p, + const std::optional &log_n, const std::optional &o_prev, + const std::optional &lse_prev, const std::optional &q_seq_lens, + const std::optional &context_lens, const int64_t &head_num, const float &scale_value, + const int64_t &kv_head_num, const int64_t &mask_type, const int64_t &calc_type) { + const std::string op_name = "RingMLA"; + auto runner = std::make_shared(op_name); + MS_EXCEPTION_IF_NULL(runner); + + runner->SetRingMLAParam(head_num, scale_value, kv_head_num, mask_type, calc_type); + runner->SetSeqLen(q_seq_lens, context_lens); + + // Setup the runner with all parameters (including hash calculation) + runner->Setup(op_name, query, query_rope, key, key_rope, value, mask, alibi_coeff, deq_scale_qk, deq_offset_qk, + deq_scale_pv, deq_offset_pv, quant_p, log_n, o_prev, lse_prev, q_seq_lens, context_lens, + head_num, scale_value, kv_head_num, mask_type, calc_type); + + auto attn_out = GenAttnOutTensor(query); + auto lse_out = GenLseOutTensor(query, lse_prev, calc_type); + + std::vector inputs = { + query, query_rope, key, key_rope, value, + ToTensorOrEmpty(mask), ToTensorOrEmpty(alibi_coeff), + ToTensorOrEmpty(deq_scale_qk), ToTensorOrEmpty(deq_offset_qk), + ToTensorOrEmpty(deq_scale_pv), ToTensorOrEmpty(deq_offset_pv), + ToTensorOrEmpty(quant_p), ToTensorOrEmpty(log_n), + ToTensorOrEmpty(o_prev), ToTensorOrEmpty(lse_prev) + }; + std::vector outputs = {attn_out, lse_out}; + runner->GetOrCreateKernel(inputs, outputs); + runner->Run(inputs, outputs); + return outputs; +} + +} // namespace ms_custom_ops + +auto pyboost_ring_mla(const ms::Tensor &query, const ms::Tensor &query_rope, const ms::Tensor &key, + const ms::Tensor &key_rope, const ms::Tensor &value, const std::optional &mask, + const std::optional &alibi_coeff, const std::optional &deq_scale_qk, + const std::optional &deq_offset_qk, const std::optional &deq_scale_pv, + const std::optional &deq_offset_pv, const std::optional &quant_p, + const std::optional &log_n, const std::optional &o_prev, + const std::optional &lse_prev, const ms::Tensor &q_seq_lens, + const ms::Tensor 
&context_lens, const int64_t &head_num, const float &scale_value, + const int64_t &kv_head_num, const int64_t &mask_type, const int64_t &calc_type) { + return ms::pynative::PyboostRunner::Call<2>( + ms_custom_ops::npu_ring_mla, query, query_rope, key, key_rope, value, mask, alibi_coeff, deq_scale_qk, + deq_offset_qk, deq_scale_pv, deq_offset_pv, quant_p, log_n, o_prev, lse_prev, q_seq_lens, context_lens, + head_num, scale_value, kv_head_num, mask_type, calc_type); +} + +MS_CUSTOM_OPS_EXTENSION_MODULE(m) { + m.def("ring_mla", &pyboost_ring_mla, "Ring MLA", + pybind11::arg("query"), + pybind11::arg("query_rope"), + pybind11::arg("key"), + pybind11::arg("key_rope"), + pybind11::arg("value"), + pybind11::arg("mask") = std::nullopt, + pybind11::arg("alibi_coeff") = std::nullopt, + pybind11::arg("deq_scale_qk") = std::nullopt, + pybind11::arg("deq_offset_qk") = std::nullopt, + pybind11::arg("deq_scale_pv") = std::nullopt, + pybind11::arg("deq_offset_pv") = std::nullopt, + pybind11::arg("quant_p") = std::nullopt, + pybind11::arg("log_n") = std::nullopt, + pybind11::arg("o_prev") = std::nullopt, + pybind11::arg("lse_prev") = std::nullopt, + pybind11::arg("q_seq_lens"), + pybind11::arg("context_lens"), + pybind11::arg("head_num"), + pybind11::arg("scale_value"), + pybind11::arg("kv_head_num"), + pybind11::arg("mask_type"), + pybind11::arg("calc_type")); +} diff --git a/ccsrc/ops/ms_kernels_internal/ring_mla/ring_mla_runner.h b/ccsrc/ops/ms_kernels_internal/ring_mla/ring_mla_runner.h new file mode 100644 index 0000000..3536e99 --- /dev/null +++ b/ccsrc/ops/ms_kernels_internal/ring_mla/ring_mla_runner.h @@ -0,0 +1,61 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MS_KERNELS_INTERNAL_RING_MLA_RUNNER_H_ +#define MS_KERNELS_INTERNAL_RING_MLA_RUNNER_H_ + +#include +#include +#include +#include +#include + +#include "internal_kernel_mod.h" +#include "ir/tensor.h" +#include "kernel/ascend/acl_ir/acl_convert.h" +#include "mindspore/ops/ops_utils/op_utils.h" +#include "ms_extension/api.h" +#include "ops/base_operator.h" +#include "ops/ops_func_impl/op_func_impl.h" +#include "ops/ops_func_impl/simple_infer.h" +#include "runtime/device/kernel_runtime.h" +#include "utils/check_convert_utils.h" +#include "internal_pyboost_runner.h" + +using namespace ms_custom_ops; +namespace ms::pynative { + +class RingMLARunner : public InternalPyboostRunner { + public: + using InternalPyboostRunner::InternalPyboostRunner; + void SetSeqLen(const std::optional &q_seq_lens, + const std::optional &context_lens); + void SetRingMLAParam(int64_t head_num, float scale_value, + int64_t kv_head_num, int64_t mask_type, int64_t calc_type); + + protected: + bool UpdateParam() override; + internal::InternalOpPtr CreateKernel(const internal::InputsImmutableInfoList &inputs, + const internal::OutputsImmutableInfoList &outputs) override; + + private: + bool created_flag_{false}; + internal::RingMLAParam param_; +}; + +} // namespace ms::pynative + +#endif // MS_KERNELS_INTERNAL_RING_MLA_RUNNER_H_ diff --git a/tests/st/test_ms_ring_mla.py b/tests/st/test_ms_ring_mla.py new file mode 100644 index 0000000..8ce4405 --- /dev/null +++ b/tests/st/test_ms_ring_mla.py @@ -0,0 +1,470 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Tests for ms_custom_ops.ring_mla using numpy golden reference.""" + +from typing import List, Optional, Tuple +import math + +import numpy as np +import pytest + +import mindspore as ms +from mindspore import Tensor, context, ops, nn +from mindspore.common.np_dtype import bfloat16 as np_bfloat16 + +import ms_custom_ops + + +class TestConfig: + def __init__(self, device_target: str = "Ascend", mode: int = context.GRAPH_MODE): + self.device_target = device_target + self.mode = mode + + def apply(self): + context.set_context(device_target=self.device_target, mode=self.mode) + + +def _make_triu_mask(size: int, dtype: np.dtype, batch: Optional[int] = None) -> np.ndarray: + # Follow coefficient semantics similar to provided torch test code + if dtype == np.float16: + # mask values directly used + base = -10000.0 + mask = np.triu(np.ones((size, size), dtype=np.float32) * base, 1) + else: + # bf16 and others: use a very negative number + base = 1 + mask = np.triu(np.ones((size, size), dtype=np.float32), 1) * base + if batch is not None: + mask = np.broadcast_to(mask, (batch, size, size)).copy() + return mask.astype(np.float32) + + +def _reconstruct_full(q_base: np.ndarray, q_rope: np.ndarray) -> np.ndarray: + # q_base: [q_ntokens, heads, d_base], q_rope: [q_ntokens, heads, d_rope] + return np.concatenate([q_base, q_rope], axis=-1) + + +def _expand_kv_to_heads(k_or_v: np.ndarray, heads: int, kv_heads: int) -> np.ndarray: + # k_or_v: [kv_ntokens, kv_heads, dim] + if heads == kv_heads: + return k_or_v + group_num = heads // kv_heads + # Repeat along kv_head dim to match total heads + return np.repeat(k_or_v, repeats=group_num, axis=1) + + +def _golden_attention( + q_base: np.ndarray, + q_rope: np.ndarray, + k_base: np.ndarray, + k_rope: np.ndarray, + v: np.ndarray, + mask: Optional[np.ndarray], + q_seq_lens: List[int], + kv_seq_lens: List[int], + heads: int, + kv_heads: int, + scale: float, + out_dim: int, + out_dtype: np.dtype, +) -> Tuple[np.ndarray, np.ndarray]: + """Compute golden attention and lse without ring update. 
+ + Returns: + out: [q_ntokens, heads, out_dim] in out_dtype + lse: [heads, q_ntokens] in float32 + """ + q = _reconstruct_full(q_base, q_rope) # [q_ntokens, heads, d] + k = _reconstruct_full(k_base, k_rope) # [kv_ntokens, kv_heads, d] + v_dim = v.shape[-1] + assert out_dim == v_dim + + # Expand K/V from kv_heads to heads by repeating per-group + k_exp = _expand_kv_to_heads(k, heads, kv_heads) # [kv_ntokens, heads, d] + v_exp = _expand_kv_to_heads(v, heads, kv_heads) # [kv_ntokens, heads, out_dim] + + q_ntokens = q.shape[0] + kv_ntokens = k.shape[0] + assert sum(q_seq_lens) == q_ntokens + assert sum(kv_seq_lens) == kv_ntokens + + # Offsets per batch + out = np.zeros((q_ntokens, heads, out_dim), dtype=np.float32) + lse = np.zeros((heads, q_ntokens), dtype=np.float32) + + q_offset = 0 + kv_offset = 0 + batch = len(q_seq_lens) + for b in range(batch): + q_len = q_seq_lens[b] + kv_len = kv_seq_lens[b] + + if q_len == 0: + continue + + q_slice = q[q_offset : q_offset + q_len] # [q_len, heads, d] + if kv_len == 0: + # Output zeros, lse zeros + q_offset += q_len + continue + + k_slice = k_exp[kv_offset : kv_offset + kv_len] # [kv_len, heads, d] + v_slice = v_exp[kv_offset : kv_offset + kv_len] # [kv_len, heads, out_dim] + + # Compute per-head attention + # logits[i, h, j] = dot(q_slice[i,h,:], k_slice[j,h,:]) * scale + # We'll compute as batch matmul per head using einsum + # q_slice: [q_len, heads, d], k_slice: [kv_len, heads, d] + logits = np.einsum("qhd,khd->qhk", q_slice.astype(np.float32), k_slice.astype(np.float32)) * scale + + # Apply mask if provided + if mask is not None: + if mask.ndim == 2: + mask_slice = mask[:q_len, :kv_len] + elif mask.ndim == 3: + mask_slice = mask[b, :q_len, :kv_len] + elif mask.ndim == 4: + # [batch, heads, q, kv] + mask_slice = mask[b, :, :q_len, :kv_len] # [heads, q, kv] + # transpose to [q, heads, kv] + mask_slice = np.transpose(mask_slice, (1, 0, 2)) + else: + raise ValueError("Unsupported mask ndim") + if mask.ndim < 4: + # broadcast to [q, heads, kv] by expanding head axis + mask_slice = np.broadcast_to(mask_slice[:, None, :], logits.shape).copy() + logits = logits + mask_slice.astype(np.float32) + + # Softmax per head and query across kv axis + m = np.max(logits, axis=2, keepdims=True) + exp_logits = np.exp((logits - m).astype(np.float32)) + denom = np.sum(exp_logits, axis=2, keepdims=True) + p = exp_logits / np.maximum(denom, 1e-38) + + # Output: [q_len, heads, out_dim] + o = np.einsum("qhk,khd->qhd", p.astype(np.float32), v_slice.astype(np.float32)) + + # LSE: [heads, q_len] + lse_b = (np.log(np.maximum(denom.squeeze(-1), 1e-38)) + m.squeeze(-1)).transpose(1, 0) + + out[q_offset : q_offset + q_len] = o + lse[:, q_offset : q_offset + q_len] = lse_b + + q_offset += q_len + kv_offset += kv_len + + return out.astype(out_dtype), lse.astype(np.float32) + + +def _golden_ring_update( + out_cur: np.ndarray, # [q_ntokens, heads, out_dim] + lse_cur: np.ndarray, # [heads, q_ntokens] + o_prev: np.ndarray, # [q_ntokens, heads, out_dim] + lse_prev: np.ndarray, # [heads, q_ntokens] +) -> Tuple[np.ndarray, np.ndarray]: + # Combine according to: new_o = (o_cur * exp(lse_cur) + o_prev * exp(lse_prev)) / (exp(lse_cur)+exp(lse_prev)) + exp_new = np.exp(lse_cur.astype(np.float32)) + exp_old = np.exp(lse_prev.astype(np.float32)) + + # Align shapes + exp_new_e = np.transpose(exp_new, (1, 0))[:, :, None] # [q, h, 1] + exp_old_e = np.transpose(exp_old, (1, 0))[:, :, None] # [q, h, 1] + + num = out_cur.astype(np.float32) * exp_new_e + o_prev.astype(np.float32) * exp_old_e + den = 
exp_new_e + exp_old_e + out_combined = num / np.maximum(den, 1e-38) + + lse_combined = np.log(np.maximum(exp_new + exp_old, 1e-38)) # [heads, q] + return out_combined, lse_combined + + +def _ms_tensor(x: np.ndarray) -> Tensor: + if x.dtype == np_bfloat16: + # MindSpore expects float32 array then cast by dtype + return Tensor(x.astype(np.float32)).astype(ms.bfloat16) + return Tensor(x) + + +def _init_prev_tensors(rng: np.random.Generator, q_ntokens: int, heads: int, dv: int, + dtype: np.dtype, is_ring: int) -> Tuple[np.ndarray, np.ndarray]: + if is_ring == 1: + o_prev = rng.uniform(-1.0, 1.0, size=(q_ntokens, heads, dv)).astype(dtype) + lse_prev = (rng.random((heads, q_ntokens)) * 10.0).astype(np.float32) + else: + o_prev = np.zeros((q_ntokens, heads, dv), dtype=dtype) + lse_prev = np.zeros((heads, q_ntokens), dtype=np.float32) + return o_prev, lse_prev + + +class RingMLANet(nn.Cell): + """Thin wrapper to call ms_custom_ops.ring_mla with fixed attributes.""" + + def __init__(self, head_num: int, scale_value: float, kv_head_num: int, mask_type: int, calc_type: int): + super().__init__() + self.head_num = head_num + self.scale_value = scale_value + self.kv_head_num = kv_head_num + self.mask_type = mask_type + self.calc_type = calc_type + # determine execution mode once during initialization + self._is_pynative = (context.get_context("mode") == context.PYNATIVE_MODE) + + def construct(self, q_nope, q_rope, key, k_rope, value, mask, alibi_coeff, + deq_scale_qk, deq_offset_qk, deq_scale_pv, deq_offset_pv, quant_p, log_n, o_prev, lse_prev, + q_seq_lens, context_lens): + if self._is_pynative: + q_lens_cpu = q_seq_lens.move_to("CPU") + kv_lens_cpu = context_lens.move_to("CPU") + else: + q_lens_cpu = ops.move_to(q_seq_lens, "CPU") + kv_lens_cpu = ops.move_to(context_lens, "CPU") + return ms_custom_ops.ring_mla( + q_nope, q_rope, key, k_rope, value, mask, alibi_coeff, + deq_scale_qk, deq_offset_qk, deq_scale_pv, deq_offset_pv, quant_p, log_n, o_prev, lse_prev, + q_lens_cpu, kv_lens_cpu, + self.head_num, self.scale_value, self.kv_head_num, self.mask_type, self.calc_type) + + +class RingMLATestCase: + """A comprehensive test case for ring multi-head latent attention (MLA) operations. + + This class encapsulates all the necessary components for testing ring MLA functionality, + including input generation, mask creation, golden reference computation, and comparison + with MindSpore implementation. It supports various configurations such as different + data types (fp16, bf16), mask types (none, triu), and sequence lengths for both + queries and key-values. 
+ """ + + def __init__( + self, + *, + heads: int, + kv_heads: int, + dim_qk: int, + dim_v: int, + q_seq_lens: List[int], + kv_seq_lens: List[int], + np_dtype: np.dtype, + mask_type: int, # 0: no mask, 1: triu + is_ring: int, + rng_seed: int, + mask_size: Optional[int] = None, + ): + self.heads = heads + self.kv_heads = kv_heads + self.dim_qk = dim_qk + self.dim_v = dim_v + self.q_seq_lens = q_seq_lens + self.kv_seq_lens = kv_seq_lens + self.np_dtype = np_dtype + self.mask_type = mask_type + self.is_ring = is_ring + self.rng = np.random.default_rng(rng_seed) + self.q_ntokens = int(sum(q_seq_lens)) + self.kv_ntokens = int(sum(kv_seq_lens)) + self.d_base = 128 + self.d_rope = dim_qk - self.d_base + self.scale = 1.0 / math.sqrt(float(dim_qk)) + self.max_seq = max(max(q_seq_lens), max(kv_seq_lens)) + self.mask_size = mask_size if mask_size is not None else self.max_seq + + def build_inputs(self): + q_full = self.rng.uniform(-1.0, 1.0, size=(self.q_ntokens, self.heads, self.dim_qk)).astype(self.np_dtype) + k_full = self.rng.uniform(-1.0, 1.0, size=(self.kv_ntokens, self.kv_heads, self.dim_qk)).astype(self.np_dtype) + v = self.rng.uniform(-1.0, 1.0, size=(self.kv_ntokens, self.kv_heads, self.dim_v)).astype(self.np_dtype) + q_base, q_rope = q_full[..., : self.d_base], q_full[..., self.d_base :] + k_base, k_rope = k_full[..., : self.d_base], k_full[..., self.d_base :] + return q_base, q_rope, k_base, k_rope, v + + def build_masks(self, batch: Optional[int] = None): + if self.mask_type == 0: + return None, None + assert self.mask_size == 512 + # fp16: both op and golden use the same values + if self.np_dtype == np.float16: + mask = _make_triu_mask(self.mask_size, np.float16, batch) + return mask.astype(np.float16), mask.astype(np.float32) + # bf16: op uses structural bf16 mask, golden uses -3e38 fp32 + base = np.triu(np.ones((self.mask_size, self.mask_size), dtype=np.float32), 1) + if batch is not None: + base = np.broadcast_to(base, (batch, self.mask_size, self.mask_size)).copy() + mask_op = base.astype(np_bfloat16) + mask_golden = base * -3e38 + return mask_op, mask_golden + + def run(self, run_mode: int, dynamic: bool = False): + q_base, q_rope, k_base, k_rope, v = self.build_inputs() + assert len(self.q_seq_lens) == len(self.kv_seq_lens) + batch = len(self.q_seq_lens) + mask_op, mask_golden = self.build_masks(batch=batch) + + # Golden + out_dtype = np.float16 if self.np_dtype == np.float16 else np_bfloat16 + cur_out, cur_lse = _golden_attention( + q_base, q_rope, k_base, k_rope, v, + mask_golden if mask_golden is not None else None, + self.q_seq_lens, self.kv_seq_lens, + self.heads, self.kv_heads, self.scale, self.dim_v, out_dtype, + ) + o_prev, lse_prev = _init_prev_tensors(self.rng, self.q_ntokens, self.heads, self.dim_v, self.np_dtype, is_ring=self.is_ring) + if self.is_ring == 1: + golden_out, golden_lse = _golden_ring_update(cur_out.astype(np.float32), cur_lse, o_prev.astype(np.float32), lse_prev) + else: + golden_out, golden_lse = cur_out, cur_lse + + # Net + calc_type = 0 if self.is_ring == 1 else 1 + net = RingMLANet(self.heads, self.scale, self.kv_heads, self.mask_type, calc_type) + + # Optionally enable dynamic shape by setting input placeholders + if dynamic: + ms_dtype = ms.float16 if self.np_dtype == np.float16 else ms.bfloat16 + # query no rope / rope + q_nope_dyn = Tensor(shape=[None, self.heads, self.d_base], dtype=ms_dtype) + q_rope_dyn = Tensor(shape=[None, self.heads, self.d_rope], dtype=ms_dtype) + # key / rope / value + k_nope_dyn = Tensor(shape=[None, self.kv_heads, 
self.d_base], dtype=ms_dtype) + k_rope_dyn = Tensor(shape=[None, self.kv_heads, self.d_rope], dtype=ms_dtype) + v_dyn = Tensor(shape=[None, self.kv_heads, self.dim_v], dtype=ms_dtype) + # mask (optional) + if self.mask_type == 0: + mask_dyn = None + else: + mask_dtype = ms.float16 if self.np_dtype == np.float16 else ms.bfloat16 + mask_dyn = Tensor(shape=[None, self.mask_size, self.mask_size], dtype=mask_dtype) + # optional tensors left as None + alibi_dyn = None + deq_scale_qk_dyn = None + deq_offset_qk_dyn = None + deq_scale_pv_dyn = None + deq_offset_pv_dyn = None + quant_p_dyn = None + log_n_dyn = None + # previous outputs and lse + o_prev_dyn = Tensor(shape=[None, self.heads, self.dim_v], dtype=ms_dtype) + lse_prev_dyn = Tensor(shape=[self.heads, None], dtype=ms.float32) + # sequence length tensors + q_lens_dyn = Tensor(shape=[None], dtype=ms.int32) + kv_lens_dyn = Tensor(shape=[None], dtype=ms.int32) + + net.set_inputs( + q_nope_dyn, q_rope_dyn, + k_nope_dyn, k_rope_dyn, + v_dyn, mask_dyn, + alibi_dyn, deq_scale_qk_dyn, deq_offset_qk_dyn, deq_scale_pv_dyn, deq_offset_pv_dyn, quant_p_dyn, log_n_dyn, + o_prev_dyn, lse_prev_dyn, + q_lens_dyn, kv_lens_dyn, + ) + out, lse = net( + _ms_tensor(q_base), _ms_tensor(q_rope), + _ms_tensor(k_base), _ms_tensor(k_rope), + _ms_tensor(v), _ms_tensor(mask_op) if mask_op is not None else None, + None, None, None, None, None, None, None, + _ms_tensor(o_prev), _ms_tensor(lse_prev), + _ms_tensor(np.array(self.q_seq_lens, dtype=np.int32)), + _ms_tensor(np.array(self.kv_seq_lens, dtype=np.int32)), + ) + + # Compare + out_np = (out.float().asnumpy() if self.np_dtype == np_bfloat16 else out.asnumpy()).astype(np.float32) + lse_np = lse.asnumpy().astype(np.float32) + tol = (1e-2, 1e-2) if self.np_dtype == np_bfloat16 else (1e-3, 1e-3) + assert np.allclose(out_np, golden_out.astype(np.float32), rtol=tol[0], atol=tol[1]) + assert np.allclose(lse_np, golden_lse.astype(np.float32), rtol=1e-3, atol=1e-3) + +@pytest.mark.level0 +@pytest.mark.platform_ascend910b +@pytest.mark.env_onecard +@pytest.mark.parametrize('run_mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('is_ring', [0, 1]) +@pytest.mark.parametrize('dynamic', [True, False]) +def test_ring_mla_fp16_no_mask(run_mode, is_ring, dynamic): + cfg = TestConfig(device_target="Ascend", mode=run_mode) + cfg.apply() + case = RingMLATestCase( + heads=16, kv_heads=16, dim_qk=192, dim_v=128, + q_seq_lens=[100, 100], kv_seq_lens=[100, 100], np_dtype=np.float16, + mask_type=0, is_ring=is_ring, rng_seed=2025 + is_ring, + ) + case.run(run_mode, dynamic=dynamic) + + +@pytest.mark.level0 +@pytest.mark.platform_ascend910b +@pytest.mark.env_onecard +@pytest.mark.parametrize('run_mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('is_ring', [0, 1]) +@pytest.mark.parametrize('dynamic', [True, False]) +def test_ring_mla_fp16_mask(run_mode, is_ring, dynamic): + cfg = TestConfig(device_target="Ascend", mode=run_mode) + cfg.apply() + case = RingMLATestCase( + heads=16, kv_heads=16, dim_qk=192, dim_v=128, + q_seq_lens=[150, 50], kv_seq_lens=[200, 200], np_dtype=np.float16, + mask_type=1, is_ring=is_ring, rng_seed=2026 + is_ring, mask_size=512, + ) + case.run(run_mode, dynamic=dynamic) + + +@pytest.mark.level0 +@pytest.mark.platform_ascend910b +@pytest.mark.env_onecard +@pytest.mark.parametrize('run_mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('is_ring', [0, 1]) +@pytest.mark.parametrize('dynamic', [True, False]) +def 
test_ring_mla_bf16_no_mask(run_mode, is_ring, dynamic): + cfg = TestConfig(device_target="Ascend", mode=run_mode) + cfg.apply() + case = RingMLATestCase( + heads=16, kv_heads=16, dim_qk=192, dim_v=128, + q_seq_lens=[128, 128], kv_seq_lens=[128, 128], np_dtype=np_bfloat16, + mask_type=0, is_ring=is_ring, rng_seed=2027 + is_ring, + ) + case.run(run_mode, dynamic=dynamic) + + +@pytest.mark.level0 +@pytest.mark.platform_ascend910b +@pytest.mark.env_onecard +@pytest.mark.parametrize('run_mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('is_ring', [0, 1]) +@pytest.mark.parametrize('dynamic', [True, False]) +def test_ring_mla_bf16_mask(run_mode, is_ring, dynamic): + cfg = TestConfig(device_target="Ascend", mode=run_mode) + cfg.apply() + case = RingMLATestCase( + heads=16, kv_heads=16, dim_qk=192, dim_v=128, + q_seq_lens=[120, 72], kv_seq_lens=[192, 192], np_dtype=np_bfloat16, + mask_type=1, is_ring=is_ring, rng_seed=2028 + is_ring, mask_size=512, + ) + case.run(run_mode, dynamic=dynamic) + + +@pytest.mark.level0 +@pytest.mark.platform_ascend910b +@pytest.mark.env_onecard +@pytest.mark.parametrize('run_mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('is_ring', [0, 1]) +@pytest.mark.parametrize('dynamic', [True, False]) +def test_ring_mla_bf16_mask_diff_qkv_lens(run_mode, is_ring, dynamic): + cfg = TestConfig(device_target="Ascend", mode=run_mode) + cfg.apply() + case = RingMLATestCase( + heads=16, kv_heads=16, dim_qk=192, dim_v=128, + q_seq_lens=[64, 128, 32, 1, 100], kv_seq_lens=[200, 180, 50, 10, 128], np_dtype=np_bfloat16, + mask_type=1, is_ring=is_ring, rng_seed=2029 + is_ring, mask_size=512, + ) + case.run(run_mode, dynamic=dynamic) + diff --git a/yaml/ms_kernels_internal/ring_mla_op.yaml b/yaml/ms_kernels_internal/ring_mla_op.yaml new file mode 100644 index 0000000..d982e33 --- /dev/null +++ b/yaml/ms_kernels_internal/ring_mla_op.yaml @@ -0,0 +1,72 @@ +#operator ring_mla +ring_mla: + args: + query: + dtype: tensor + query_rope: + dtype: tensor + key: + dtype: tensor + key_rope: + dtype: tensor + value: + dtype: tensor + mask: + dtype: tensor + default: None + alibi_coeff: + dtype: tensor + default: None + deq_scale_qk: + dtype: tensor + default: None + deq_offset_qk: + dtype: tensor + default: None + deq_scale_pv: + dtype: tensor + default: None + deq_offset_pv: + dtype: tensor + default: None + quant_p: + dtype: tensor + default: None + log_n: + dtype: tensor + default: None + o_prev: + dtype: tensor + default: None + lse_prev: + dtype: tensor + default: None + q_seq_lens: + dtype: tensor + default: None + context_lens: + dtype: tensor + default: None + head_num: + dtype: int + default: 0 + scale_value: + dtype: float + default: 1.0 + kv_head_num: + dtype: int + default: 0 + mask_type: + dtype: int + default: 0 + calc_type: + dtype: int + default: 0 + returns: + attention_out: + dtype: tensor + lse: + dtype: tensor + dispatch: + enable: True + InternalOpAscend: AutoGen -- Gitee