diff --git a/ccsrc/ops/ascendc/add_rms_norm.cc b/ccsrc/ops/ascendc/add_rms_norm.cc new file mode 100644 index 0000000000000000000000000000000000000000..a9f47a260b9e89de49dc51d88587a19a49e4e607 --- /dev/null +++ b/ccsrc/ops/ascendc/add_rms_norm.cc @@ -0,0 +1,163 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// ============================================================================= +// GRAPH MODE IMPLEMENTATION +// ============================================================================= + +#include "ascendc_kernel_mod.h" +#include "ms_extension/api.h" +#include +#include +#include + +namespace mindspore { +namespace ops { +class OPS_API AddRmsNormCustomOpFuncImpl : public OpFuncImpl { +public: + ShapeArray InferShape(const PrimitivePtr &primitive, + const InferInfoPtrList &input_infos) const override { + auto &x1 = input_infos[kInputIndex0]; + auto &x2 = input_infos[kInputIndex1]; + auto &gamma = input_infos[kInputIndex2]; + const auto &x1_shape = x1->GetShape(); + const auto &x2_shape = x2->GetShape(); + const auto &gamma_shape = gamma->GetShape(); + auto gamma_rank = gamma_shape.size(); + + if (x1->IsDynamicRank() && x2->IsDynamicRank() && gamma->IsDynamicRank()) { + auto out_shape = ShapeVector{abstract::Shape::kShapeRankAny}; + return {out_shape, out_shape, out_shape}; + } + + if (!(x1->IsDynamic() || x2->IsDynamic())) { + if (x1_shape != x2_shape) { + MS_EXCEPTION(ValueError) << "For AddRmsNorm, shape of x1: " << x1_shape + << " are not consistent with the shape x2: " << x2_shape << " ."; + } + } + auto out_shape = x1_shape; + auto out_rank = out_shape.size(); + auto rstd_shape = out_shape; + if (gamma->IsDynamicRank()) { + if (!IsDynamicRank(out_shape)) { + rstd_shape = ShapeVector(out_rank, abstract::TensorShape::kShapeDimAny); + } else { + rstd_shape = ShapeVector{abstract::TensorShape::kShapeRankAny}; + } + } else if (!IsDynamicRank(out_shape)) { + if (gamma_rank > out_rank) { + MS_LOG(EXCEPTION) << "For AddRmsNorm, The [gamma] rank can not be bigger than the rank of " + "other two inputs. but got gamma_rank: " + << gamma_rank << ", out_rank: " << out_rank; + } + for (auto dim = out_rank - gamma_rank; dim < out_rank; dim++) { + int64_t x_dim = out_shape[dim]; + int64_t gamma_dim = gamma_shape[dim - out_rank + gamma_rank]; + if (x_dim != gamma_dim && (x_dim != abstract::TensorShape::kShapeDimAny && + gamma_dim != abstract::TensorShape::kShapeDimAny)) { + MS_LOG(EXCEPTION) << "For AddRmsNorm, Each dimension of [gamma] must be aligned to the " + "corresponding dimension of other two inputs. 
But got: gamma_dim: " + << gamma_dim << ", x_dim: " << x_dim; + } + rstd_shape[dim] = 1; + } + } + return {out_shape, rstd_shape, out_shape}; + } + + std::vector InferType(const PrimitivePtr &primitive, + const InferInfoPtrList &input_infos) const override { + auto x_dtype = input_infos[0]->GetType(); + return {x_dtype, TypeId::kNumberTypeFloat, x_dtype}; + } + + bool GeneralInferRegistered() const override { return true; } +}; +} // namespace ops +} // namespace mindspore + +namespace ms_custom_ops { +class AddRmsNormCustomAscend : public AscendCKernelMod { +public: + AddRmsNormCustomAscend() : AscendCKernelMod(std::move("aclnnAddRmsNormCustom")) {} + ~AddRmsNormCustomAscend() = default; + + bool Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + MS_EXCEPTION_IF_NULL(stream_ptr); + epsilon_ = static_cast(inputs[3]->GetValueWithCheck()); + RunOp(stream_ptr, workspace, inputs[0], inputs[1], inputs[2], epsilon_, outputs[0], outputs[1], + outputs[2]); + return true; + } + + void GetWorkSpaceInfo(const std::vector &inputs, + const std::vector &outputs) override { + GetWorkspaceForResize(inputs[0], inputs[1], inputs[2], epsilon_, outputs[0], outputs[1], + outputs[2]); + } + +private: + DEFINE_GET_WORKSPACE_FOR_RESIZE(); + double epsilon_{1e-6f}; // Default epsilon value, can be overridden by input tensor +}; +} // namespace ms_custom_ops + +MS_CUSTOM_OPS_REGISTER(add_rms_norm, AddRmsNormCustomOpFuncImpl, AddRmsNormCustomAscend); + +// ============================================================================= +// PYBOOST MODE IMPLEMENTATION +// ============================================================================= + +#include "ascendc_pyboost_runner.h" + +namespace ms_custom_ops { +using namespace mindspore; +using namespace mindspore::device::ascend; + +std::vector custom_add_rms_norm(const ms::Tensor &x1, const ms::Tensor &x2, + const ms::Tensor &gamma, float epsilon) { + auto x1_shape = x1.shape(); + auto gamma_shape = gamma.shape(); + auto rstd_shape = x1_shape; + size_t x1_rank = x1_shape.size(); + size_t gamma_rank = gamma_shape.size(); + for (size_t i = x1_rank - gamma_rank; i < x1_rank; ++i) { + rstd_shape[i] = 1; + } + + auto out_y = ms::Tensor(x1.data_type(), x1_shape); + auto out_rstd = ms::Tensor(TypeId::kNumberTypeFloat32, rstd_shape); + auto out_x = ms::Tensor(x1.data_type(), x1_shape); + auto runner = std::make_shared("AddRmsNorm"); + runner->SetLaunchFunc( + LAUNCH_ASCENDC_FUNC(aclnnAddRmsNormCustom, x1, x2, gamma, epsilon, out_y, out_rstd, out_x)); + runner->Run({x1, x2, gamma}, {out_y, out_rstd, out_x}); + return {out_y, out_rstd, out_x}; +} + +auto pyboost_add_rms_norm(const ms::Tensor &x1, const ms::Tensor &x2, const ms::Tensor &gamma, + float epsilon) { + return ms::pynative::PyboostRunner::Call<3>(custom_add_rms_norm, x1, x2, gamma, epsilon); +} +} // namespace ms_custom_ops + +MS_CUSTOM_OPS_EXTENSION_MODULE(m) { + m.def("add_rms_norm", &ms_custom_ops::pyboost_add_rms_norm, "add_rms_norm", pybind11::arg("x1"), + pybind11::arg("x2"), pybind11::arg("gamma"), pybind11::arg("epsilon") = 1e-6f); +} diff --git a/ccsrc/ops/ascendc/kernel_impl/op_host/add_rms_norm_custom.cpp b/ccsrc/ops/ascendc/kernel_impl/op_host/add_rms_norm_custom.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4c8657e157ebb0a09b9afe9cd8228b637002fe9c --- /dev/null +++ b/ccsrc/ops/ascendc/kernel_impl/op_host/add_rms_norm_custom.cpp @@ -0,0 +1,175 @@ +/** + * @file add_custom.cpp + * + * Copyright (C) 
2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. + */ +#include "add_rms_norm_custom_tiling.h" +#include "graph/utils/type_utils.h" +#include "register/op_def_registry.h" +#include "tiling/platform/platform_ascendc.h" + + +namespace optiling { +constexpr uint32_t kDtypeKeyFp16 = 1; +constexpr uint32_t kDtypeKeyFp32 = 2; +constexpr uint32_t kDtypeKeyBf16 = 3; +constexpr uint32_t kUbFactorB16 = 12288; +constexpr uint32_t kUbFactorB32 = 10240; +constexpr uint32_t kUbFactorB16Cutd = 12096; +constexpr uint32_t kUbFactorB32Cutd = 9696; +constexpr uint32_t kBlockAlignNum = 16; +constexpr size_t kWorkspaceSize = 16 * 1024 * 1024 + 256; + +inline int64_t CeilDiv(const int64_t dividend, const int64_t divisor) { + if (divisor == 0) { + return 0; + } + return (dividend + divisor - 1) / divisor; +} + +static ge::graphStatus TilingFunc(gert::TilingContext *context) { + AddRmsNormTilingData tiling; + auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); + auto block_dims = ascendcPlatform.GetCoreNumAiv(); + const float *eps = context->GetAttrs()->GetAttrPointer(0); + + uint32_t row_factor = 64; + int64_t num_col = 1; + int64_t num_row = 1; + + auto gamma_shape = context->GetInputShape(2)->GetOriginShape(); + auto gamma_dim = gamma_shape.GetDimNum(); + for (size_t idx = 0; idx < gamma_dim; idx++) { + num_col = num_col * gamma_shape.GetDim(idx); + } + float avg_factor = (num_col == 0) ? 0 : 1.0 / num_col; + + auto x1_shape = context->GetInputShape(0)->GetOriginShape(); + auto x_dim = x1_shape.GetDimNum(); + for (size_t idx = 0; idx < x_dim - gamma_dim; idx++) { + num_row = num_row * x1_shape.GetDim(idx); + } + + uint32_t block_factor = 1; + uint32_t tile_num = CeilDiv(num_row, block_dims * block_factor); + block_factor *= tile_num; + uint32_t use_core_num = CeilDiv(num_row, block_factor); + + uint32_t dtype_key; + uint32_t ub_factor = kUbFactorB16; + bool is_cast_gamma = false; + ge::DataType x1_dtype = context->GetInputDesc(0)->GetDataType(); + ge::DataType gamma_dtype = context->GetInputDesc(2)->GetDataType(); + if (x1_dtype == ge::DataType::DT_FLOAT16) { + dtype_key = kDtypeKeyFp16; + if (gamma_dtype == ge::DataType::DT_FLOAT) { + is_cast_gamma = true; + ub_factor = kUbFactorB32; + } + } else if (x1_dtype == ge::DataType::DT_FLOAT) { + dtype_key = kDtypeKeyFp32; + ub_factor = kUbFactorB32; + } else if (x1_dtype == ge::DataType::DT_BF16) { + dtype_key = kDtypeKeyBf16; + if (gamma_dtype == ge::DataType::DT_FLOAT) { + is_cast_gamma = true; + ub_factor = kUbFactorB32; + } + } + + uint32_t split_d = num_col > ub_factor ? 1 : 0; + if (split_d == 1) { + ub_factor = ((x1_dtype == ge::DataType::DT_FLOAT) || is_cast_gamma) ? 
kUbFactorB32Cutd + : kUbFactorB16Cutd; + uint32_t col_tile_num = CeilDiv(num_col, ub_factor); + ub_factor = CeilDiv(num_col, col_tile_num * kBlockAlignNum) * kBlockAlignNum; + } + + uint32_t tiling_key = dtype_key * 10 + split_d; + if (is_cast_gamma) { + tiling_key = tiling_key + 100; + } + + tiling.set_num_col(num_col); + tiling.set_num_row(num_row); + tiling.set_epsilon(*eps); + tiling.set_block_factor(block_factor); + tiling.set_row_factor(row_factor); + tiling.set_ub_factor(ub_factor); + tiling.set_avg_factor(avg_factor); + + context->SetBlockDim(use_core_num); + context->SetTilingKey(tiling_key); + tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), + context->GetRawTilingData()->GetCapacity()); + context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); + size_t *currentWorkspace = context->GetWorkspaceSizes(1); + currentWorkspace[0] = kWorkspaceSize; + return ge::GRAPH_SUCCESS; +} +} // namespace optiling + +namespace ge { +static ge::graphStatus InferShape(gert::InferShapeContext *context) { + const gert::Shape *x1_shape = context->GetInputShape(0); + gert::Shape *y_shape = context->GetOutputShape(0); + gert::Shape *x_shape = context->GetOutputShape(2); + *y_shape = *x1_shape; + *x_shape = *x1_shape; + return GRAPH_SUCCESS; +} +static graphStatus InferDataType(gert::InferDataTypeContext *context) { + const auto inputDataType = context->GetInputDataType(0); + context->SetOutputDataType(0, inputDataType); + context->SetOutputDataType(1, ge::DT_FLOAT); + context->SetOutputDataType(2, inputDataType); + return ge::GRAPH_SUCCESS; +} +} // namespace ge + +namespace ops { +class AddRmsNormCustom : public OpDef { +public: + explicit AddRmsNormCustom(const char *name) : OpDef(name) { + this->Input("x1") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("x2") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("gamma") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("y") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("rstd") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("x") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Attr("epsilon").Float(); + + this->SetInferShape(ge::InferShape).SetInferDataType(ge::InferDataType); + this->AICore().SetTiling(optiling::TilingFunc).AddConfig("ascend910b"); + } +}; +OP_ADD(AddRmsNormCustom); +} // namespace ops diff --git a/ccsrc/ops/ascendc/kernel_impl/op_host/add_rms_norm_custom_tiling.h b/ccsrc/ops/ascendc/kernel_impl/op_host/add_rms_norm_custom_tiling.h new file mode 100644 index 
0000000000000000000000000000000000000000..b2280d1625479cad888437828e2e0a137c7ec90e --- /dev/null +++ b/ccsrc/ops/ascendc/kernel_impl/op_host/add_rms_norm_custom_tiling.h @@ -0,0 +1,27 @@ +/** + * @file add_custom_tiling.h + * + * Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. + */ +#ifndef ADD_RMS_NORM_CUSTOM_TILING_H +#define ADD_RMS_NORM_CUSTOM_TILING_H +#include "register/tilingdata_base.h" + +namespace optiling { +BEGIN_TILING_DATA_DEF(AddRmsNormTilingData) + TILING_DATA_FIELD_DEF(uint32_t, num_row); + TILING_DATA_FIELD_DEF(uint32_t, num_col); + TILING_DATA_FIELD_DEF(uint32_t, block_factor); + TILING_DATA_FIELD_DEF(uint32_t, row_factor); + TILING_DATA_FIELD_DEF(uint32_t, ub_factor); + TILING_DATA_FIELD_DEF(float, epsilon); + TILING_DATA_FIELD_DEF(float, avg_factor); +END_TILING_DATA_DEF; + +REGISTER_TILING_DATA_CLASS(AddRmsNormCustom, AddRmsNormTilingData) +} +#endif // ADD_RMS_NORM_CUSTOM_TILING_H diff --git a/ccsrc/ops/ascendc/kernel_impl/op_kernel/add_rms_norm_custom.cpp b/ccsrc/ops/ascendc/kernel_impl/op_kernel/add_rms_norm_custom.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a689fa84ba49fde0478ed3558b688f2c581c5b61 --- /dev/null +++ b/ccsrc/ops/ascendc/kernel_impl/op_kernel/add_rms_norm_custom.cpp @@ -0,0 +1,1030 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! + * \file add_rms_norm.cpp + * \brief + */ + +#include "kernel_operator.h" + +using namespace AscendC; + +#ifdef __CCE_KT_TEST__ +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif + +#if __CCE_AICORE__ != 220 +#define bfloat16_t int16_t +#endif +constexpr int32_t BUFFER_NUM = 1; // tensor num for each queue +constexpr int32_t NUM_PER_REP_FP32 = 64; // ONE_REPEAT_BYTE_SIZE / sizeof(float); +constexpr int32_t NUM_PER_BLK_FP32 = 8; +constexpr float MINUS_HALF = -0.5; +constexpr float ZERO = 0; +constexpr float ONE = 1; + +template __aicore__ inline T CeilDiv(T x, T y) { return y == 0 ? 
x : (x + y - 1) / y; } + +template struct integral_constant { static constexpr Tp value = v; }; +using true_type = integral_constant; +using false_type = integral_constant; +template struct is_same : public false_type {}; +template struct is_same : public true_type {}; + +__aicore__ inline void ReduceSumFP32(const LocalTensor &dst_local, + const LocalTensor &src_local, + const LocalTensor &work_local, int32_t count) { + // count need smaller than 255 repeat + if (g_coreType == AIV) { + uint64_t mask = NUM_PER_REP_FP32; + int32_t repeatTimes = count / NUM_PER_REP_FP32; + int32_t tailCount = count % NUM_PER_REP_FP32; + int32_t bodyCount = repeatTimes * NUM_PER_REP_FP32; + BinaryRepeatParams repeatParams; + repeatParams.src0RepStride = ONE_REPEAT_BYTE_SIZE / ONE_BLK_SIZE; + repeatParams.src0BlkStride = 1; + repeatParams.src1RepStride = 0; + repeatParams.src1BlkStride = 1; + repeatParams.dstRepStride = 0; + repeatParams.dstBlkStride = 1; + Duplicate(work_local, ZERO, NUM_PER_REP_FP32); + pipe_barrier(PIPE_V); + if (likely(repeatTimes > 0)) { + Add(work_local, src_local, work_local, mask, repeatTimes, repeatParams); + pipe_barrier(PIPE_V); + } + if (unlikely(tailCount != 0)) { + Add(work_local, src_local[bodyCount], work_local, tailCount, 1, repeatParams); + pipe_barrier(PIPE_V); + } + AscendCUtils::SetMask(NUM_PER_REP_FP32); + vcadd((__ubuf__ float *)dst_local.GetPhyAddr(), (__ubuf__ float *)work_local.GetPhyAddr(), 1, 0, + 1, 0, false); + pipe_barrier(PIPE_V); + } +} + +__aicore__ inline void ReduceSumCustom(const LocalTensor &dst_local, + const LocalTensor &src_local, + const LocalTensor &work_local, int32_t count) { +#if __CCE_AICORE__ == 220 + ReduceSumFP32(dst_local, src_local, work_local, count); +#else + ReduceSum(dst_local, src_local, dst_local, count); +#endif +} + +__aicore__ inline void ReduceSumFP32ToBlock(const LocalTensor &dst_local, + const LocalTensor &src_local, + const LocalTensor &work_local, int32_t count) { + // count need smaller than 255 repeat + uint64_t mask = NUM_PER_REP_FP32; + int32_t repeatTimes = count / NUM_PER_REP_FP32; + int32_t tailCount = count % NUM_PER_REP_FP32; + int32_t bodyCount = repeatTimes * NUM_PER_REP_FP32; + BinaryRepeatParams repeatParams; + repeatParams.src0RepStride = ONE_REPEAT_BYTE_SIZE / ONE_BLK_SIZE; + repeatParams.src0BlkStride = 1; + repeatParams.src1RepStride = 0; + repeatParams.src1BlkStride = 1; + repeatParams.dstRepStride = 0; + repeatParams.dstBlkStride = 1; + Duplicate(work_local, ZERO, NUM_PER_REP_FP32); + pipe_barrier(PIPE_V); + if (likely(repeatTimes > 0)) { + Add(work_local, src_local, work_local, mask, repeatTimes, repeatParams); + pipe_barrier(PIPE_V); + } + if (unlikely(tailCount != 0)) { + Add(work_local, src_local[bodyCount], work_local, tailCount, 1, repeatParams); + pipe_barrier(PIPE_V); + } + BlockReduceSum(dst_local, work_local, 1, mask, 1, 1, DEFAULT_REPEAT_STRIDE); + pipe_barrier(PIPE_V); +} + +__aicore__ inline void BlockReduceSumFP32(const LocalTensor &dst_local, + const LocalTensor &src_local, int32_t count) { + // count need multiple of 8 + int32_t repeatTimes = count / NUM_PER_REP_FP32; + int32_t tailCount = count % NUM_PER_REP_FP32; + int32_t dstAddr = repeatTimes * 8; + int32_t srcAddr = repeatTimes * NUM_PER_REP_FP32; + if (likely(repeatTimes > 0)) { + BlockReduceSum(dst_local, src_local, repeatTimes, NUM_PER_REP_FP32, 1, 1, + DEFAULT_REPEAT_STRIDE); + pipe_barrier(PIPE_V); + } + if (tailCount != 0) { + BlockReduceSum(dst_local[dstAddr], src_local[srcAddr], 1, tailCount, 1, 1, + DEFAULT_REPEAT_STRIDE); + 
pipe_barrier(PIPE_V); + } +} + +template +__aicore__ inline void DataCopyCustom(const U &dstTensor, const R &srcTensor, + const uint32_t count) { +#if __CCE_AICORE__ == 220 + DataCopyParams copyParams; + copyParams.blockLen = count * sizeof(T); + copyParams.blockCount = 1; + if constexpr (is_same>::value) { + DataCopyPadParams padParams; + DataCopyPad(dstTensor, srcTensor, copyParams, padParams); + } else { + DataCopyPad(dstTensor, srcTensor, copyParams); + } +#else + // only support count greater than 32byte + int32_t numPerBlock = ONE_BLK_SIZE / sizeof(T); + if (count % numPerBlock == 0) { + DataCopy(dstTensor, srcTensor, count); + } else { + if constexpr (is_same>::value) { + int32_t num = AlignUp(count, numPerBlock); + DataCopy(dstTensor, srcTensor, num); + } else { + int32_t num = count / numPerBlock * numPerBlock; + DataCopy(dstTensor, srcTensor, num); + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + for (int32_t i = 0; i < numPerBlock; i++) { + T tensorValue = srcTensor.GetValue(count - numPerBlock + i); + srcTensor.SetValue(i, tensorValue); + } + set_flag(PIPE_S, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_S, PIPE_MTE3, EVENT_ID0); + DataCopy(dstTensor[count - numPerBlock], srcTensor, numPerBlock); + } + } +#endif +} + +template class KernelAddRmsNorm { +public: + __aicore__ inline KernelAddRmsNorm() {} + __aicore__ inline void Init(GM_ADDR x1, GM_ADDR x2, GM_ADDR gamma, GM_ADDR y, GM_ADDR rstd, + GM_ADDR x, uint32_t numRow, uint32_t numCol, uint32_t blockFactor, + uint32_t rowFactor, uint32_t ubFactor, float epsilon, + bool is_cast_gamma = false) { + ASSERT(GetBlockNum() != 0 && "Block dim can not be zero!"); + this->numRow = numRow; + this->numCol = numCol; + this->blockFactor = blockFactor; + this->rowFactor = rowFactor; + this->ubFactor = ubFactor; + this->epsilon = epsilon; + this->avgFactor = (float)1.0 / numCol; + this->is_cast_gamma = is_cast_gamma; + + if (GetBlockIdx() < GetBlockNum() - 1) { + this->rowWork = blockFactor; + } else if (GetBlockIdx() == GetBlockNum() - 1) { + this->rowWork = numRow - (GetBlockNum() - 1) * blockFactor; + } else { + } + // get start index for current core, core parallel + x1Gm.SetGlobalBuffer((__gm__ T *)x1 + GetBlockIdx() * blockFactor * numCol, rowWork * numCol); + x2Gm.SetGlobalBuffer((__gm__ T *)x2 + GetBlockIdx() * blockFactor * numCol, rowWork * numCol); + if (is_cast_gamma) { + gammaGmFp32.SetGlobalBuffer((__gm__ float *)gamma, numCol); + } else { + gammaGm.SetGlobalBuffer((__gm__ T *)gamma, numCol); + } + yGm.SetGlobalBuffer((__gm__ T *)y + GetBlockIdx() * blockFactor * numCol, rowWork * numCol); + rstdGm.SetGlobalBuffer((__gm__ float *)rstd + GetBlockIdx() * blockFactor, blockFactor); + xGm.SetGlobalBuffer((__gm__ T *)x + GetBlockIdx() * blockFactor * numCol, rowWork * numCol); + + // pipe alloc memory to queue, the unit is Bytes + pipe.InitBuffer(inQueueX, BUFFER_NUM, ubFactor * sizeof(T)); + if (is_cast_gamma) { + pipe.InitBuffer(inQueueGamma, BUFFER_NUM, ubFactor * sizeof(float)); + } else { + pipe.InitBuffer(inQueueGamma, BUFFER_NUM, ubFactor * sizeof(T)); + } + pipe.InitBuffer(outQueueY, BUFFER_NUM, ubFactor * sizeof(T)); + pipe.InitBuffer(outQueueRstd, BUFFER_NUM, rowFactor * sizeof(float)); + + if constexpr (is_same::value || is_same::value) { + pipe.InitBuffer(xFp32Buf, ubFactor * sizeof(float)); + } + pipe.InitBuffer(sqxBuf, ubFactor * sizeof(float)); + pipe.InitBuffer(reduceFp32Buf, NUM_PER_REP_FP32 * sizeof(float)); + } + + __aicore__ inline void Process() { + CopyInGamma(); + uint32_t 
i_o_max = CeilDiv(rowWork, rowFactor); + uint32_t row_tail = rowWork - (i_o_max - 1) * rowFactor; + if (is_cast_gamma) { + LocalTensor gammaLocal = inQueueGamma.DeQue(); + // SubProcess(0, rowFactor, gammaLocal); + for (uint32_t i_o = 0; i_o < i_o_max - 1; i_o++) { + SubProcessFp32(i_o, rowFactor, gammaLocal); + } + SubProcessFp32(i_o_max - 1, row_tail, gammaLocal); + inQueueGamma.FreeTensor(gammaLocal); + } else { + LocalTensor gammaLocal = inQueueGamma.DeQue(); + // SubProcess(0, rowFactor, gammaLocal); + for (uint32_t i_o = 0; i_o < i_o_max - 1; i_o++) { + SubProcess(i_o, rowFactor, gammaLocal); + } + SubProcess(i_o_max - 1, row_tail, gammaLocal); + inQueueGamma.FreeTensor(gammaLocal); + } + } + + __aicore__ inline void SubProcess(uint32_t i_o, uint32_t calc_row_num, + LocalTensor &gammaLocal) { + LocalTensor rstdLocal = outQueueRstd.AllocTensor(); + for (uint32_t i_i = 0; i_i < calc_row_num; i_i++) { + uint32_t gm_bias = (i_o * rowFactor + i_i) * numCol; + CopyIn(gm_bias); + Compute(i_i, gammaLocal, rstdLocal); + CopyOutY(gm_bias); + } + outQueueRstd.EnQue(rstdLocal); + CopyOutRstd(i_o, calc_row_num); + } + + __aicore__ inline void SubProcessFp32(uint32_t i_o, uint32_t calc_row_num, + LocalTensor &gammaLocal) { + LocalTensor rstdLocal = outQueueRstd.AllocTensor(); + for (uint32_t i_i = 0; i_i < calc_row_num; i_i++) { + uint32_t gm_bias = (i_o * rowFactor + i_i) * numCol; + CopyIn(gm_bias); + ComputeFp32(i_i, gammaLocal, rstdLocal); + CopyOutY(gm_bias); + } + outQueueRstd.EnQue(rstdLocal); + CopyOutRstd(i_o, calc_row_num); + } + +private: + __aicore__ inline void CopyIn(uint32_t gm_bias) { + LocalTensor x1Local_in = inQueueX.AllocTensor(); + LocalTensor x2Local = sqxBuf.Get(); + LocalTensor xLocal = outQueueY.AllocTensor(); + + if constexpr (is_same::value || is_same::value) { + x2Local = x2Local[ubFactor]; + } + + DataCopyCustom(x1Local_in, x1Gm[gm_bias], numCol); + DataCopyCustom(x2Local, x2Gm[gm_bias], numCol); + inQueueX.EnQue(x1Local_in); + auto x1Local = inQueueX.DeQue(); + + if constexpr (is_same::value) { + LocalTensor x1_fp32 = xFp32Buf.Get(); + Add(xLocal, x1Local, x2Local, numCol); + pipe_barrier(PIPE_V); + Cast(x1_fp32, xLocal, RoundMode::CAST_NONE, numCol); + pipe_barrier(PIPE_V); + } else if constexpr (is_same::value) { + LocalTensor x1_fp32 = xFp32Buf.Get(); + LocalTensor x2_fp32 = sqxBuf.Get(); + Cast(x1_fp32, x1Local, RoundMode::CAST_NONE, numCol); + Cast(x2_fp32, x2Local, RoundMode::CAST_NONE, numCol); + pipe_barrier(PIPE_V); + Add(x1_fp32, x1_fp32, x2_fp32, numCol); + pipe_barrier(PIPE_V); + Cast(xLocal, x1_fp32, RoundMode::CAST_RINT, numCol); + pipe_barrier(PIPE_V); + + // cast for precision issue + Cast(x1_fp32, xLocal, RoundMode::CAST_NONE, numCol); + pipe_barrier(PIPE_V); + } else { + Add(x1Local, x1Local, x2Local, numCol); + pipe_barrier(PIPE_V); + Adds(xLocal, x1Local, (float)0, numCol); + } + inQueueX.FreeTensor(x1Local); + + // CopyOut x1 + x2 + outQueueY.EnQue(xLocal); + auto x_out = outQueueY.DeQue(); + DataCopyCustom(xGm[gm_bias], x_out, numCol); + outQueueY.FreeTensor(x_out); + } + + __aicore__ inline void CopyInGamma() { + if (is_cast_gamma) { + LocalTensor gammaLocal = inQueueGamma.AllocTensor(); + DataCopyCustom(gammaLocal, gammaGmFp32, numCol); + inQueueGamma.EnQue(gammaLocal); + } else { + LocalTensor gammaLocal = inQueueGamma.AllocTensor(); + DataCopyCustom(gammaLocal, gammaGm, numCol); + inQueueGamma.EnQue(gammaLocal); + } + } + + __aicore__ inline void Compute(uint32_t inner_progress, LocalTensor gammaLocal, + LocalTensor rstdLocal) { + LocalTensor 
xLocal = inQueueX.AllocTensor(); + LocalTensor sqx = sqxBuf.Get(); + LocalTensor reduce_buf_local = reduceFp32Buf.Get(); + Mul(sqx, xLocal, xLocal, numCol); + pipe_barrier(PIPE_V); + + Muls(sqx, sqx, avgFactor, numCol); + pipe_barrier(PIPE_V); + + ReduceSumCustom(sqx, sqx, reduce_buf_local, numCol); + pipe_barrier(PIPE_V); + Adds(sqx, sqx, epsilon, 1); + pipe_barrier(PIPE_V); + + Sqrt(sqx, sqx, 1); + Duplicate(reduce_buf_local, ONE, 1); + pipe_barrier(PIPE_V); + Div(sqx, reduce_buf_local, sqx, 1); + pipe_barrier(PIPE_V); + event_t event_v_s = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_S)); + set_flag(PIPE_V, PIPE_S, event_v_s); + wait_flag(PIPE_V, PIPE_S, event_v_s); + float rstd_value = sqx.GetValue(0); + event_t event_s_v = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); + set_flag(PIPE_S, PIPE_V, event_s_v); + wait_flag(PIPE_S, PIPE_V, event_s_v); + rstdLocal.SetValue(inner_progress, rstd_value); + pipe_barrier(PIPE_V); + LocalTensor yLocal = outQueueY.AllocTensor(); + Muls(yLocal, xLocal, rstd_value, numCol); + inQueueX.FreeTensor(xLocal); + pipe_barrier(PIPE_V); + Mul(yLocal, gammaLocal, yLocal, numCol); + pipe_barrier(PIPE_V); + outQueueY.EnQue(yLocal); + } + + __aicore__ inline void Compute(uint32_t inner_progress, LocalTensor gammaLocal, + LocalTensor rstdLocal) { + LocalTensor x_fp32 = xFp32Buf.Get(); + LocalTensor sqx = sqxBuf.Get(); + LocalTensor reduce_buf_local = reduceFp32Buf.Get(); + + Mul(sqx, x_fp32, x_fp32, numCol); + pipe_barrier(PIPE_V); + + Muls(sqx, sqx, avgFactor, numCol); + pipe_barrier(PIPE_V); + ReduceSumCustom(sqx, sqx, reduce_buf_local, numCol); + pipe_barrier(PIPE_V); + + Adds(sqx, sqx, epsilon, 1); + pipe_barrier(PIPE_V); + + Sqrt(sqx, sqx, 1); + Duplicate(reduce_buf_local, ONE, 1); + pipe_barrier(PIPE_V); + Div(sqx, reduce_buf_local, sqx, 1); + pipe_barrier(PIPE_V); + event_t event_v_s = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_S)); + set_flag(PIPE_V, PIPE_S, event_v_s); + wait_flag(PIPE_V, PIPE_S, event_v_s); + float rstd_value = sqx.GetValue(0); + event_t event_s_v = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); + set_flag(PIPE_S, PIPE_V, event_s_v); + wait_flag(PIPE_S, PIPE_V, event_s_v); + rstdLocal.SetValue(inner_progress, rstd_value); + pipe_barrier(PIPE_V); + Muls(x_fp32, x_fp32, rstd_value, numCol); + pipe_barrier(PIPE_V); + LocalTensor yLocal = outQueueY.AllocTensor(); + Cast(yLocal, x_fp32, RoundMode::CAST_RINT, numCol); + pipe_barrier(PIPE_V); + Cast(x_fp32, yLocal, RoundMode::CAST_NONE, numCol); + pipe_barrier(PIPE_V); + Cast(sqx, gammaLocal, RoundMode::CAST_NONE, numCol); // gamma_fp32 reuse sqx + pipe_barrier(PIPE_V); + Mul(x_fp32, x_fp32, sqx, numCol); + pipe_barrier(PIPE_V); + Cast(yLocal, x_fp32, RoundMode::CAST_RINT, numCol); + pipe_barrier(PIPE_V); + + event_t event_v_mte = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2)); + set_flag(PIPE_V, PIPE_MTE2, event_v_mte); + wait_flag(PIPE_V, PIPE_MTE2, event_v_mte); + + outQueueY.EnQue(yLocal); + } + + __aicore__ inline void Compute(uint32_t inner_progress, LocalTensor gammaLocal, + LocalTensor rstdLocal) { + LocalTensor x_fp32 = xFp32Buf.Get(); + LocalTensor sqx = sqxBuf.Get(); + LocalTensor reduce_buf_local = reduceFp32Buf.Get(); + + Mul(sqx, x_fp32, x_fp32, numCol); + pipe_barrier(PIPE_V); + + Muls(sqx, sqx, avgFactor, numCol); + pipe_barrier(PIPE_V); + + ReduceSumCustom(sqx, sqx, reduce_buf_local, numCol); + pipe_barrier(PIPE_V); + + Adds(sqx, sqx, epsilon, 1); + pipe_barrier(PIPE_V); + + Sqrt(sqx, sqx, 1); + Duplicate(reduce_buf_local, 
ONE, 1); + pipe_barrier(PIPE_V); + Div(sqx, reduce_buf_local, sqx, 1); + pipe_barrier(PIPE_V); + event_t event_v_s = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_S)); + set_flag(PIPE_V, PIPE_S, event_v_s); + wait_flag(PIPE_V, PIPE_S, event_v_s); + float rstd_value = sqx.GetValue(0); + event_t event_s_v = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); + set_flag(PIPE_S, PIPE_V, event_s_v); + wait_flag(PIPE_S, PIPE_V, event_s_v); + rstdLocal.SetValue(inner_progress, rstd_value); + pipe_barrier(PIPE_V); + Muls(x_fp32, x_fp32, rstd_value, numCol); + pipe_barrier(PIPE_V); + LocalTensor yLocal = outQueueY.AllocTensor(); + Cast(yLocal, x_fp32, RoundMode::CAST_NONE, numCol); + + event_t event_v_mte = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2)); + set_flag(PIPE_V, PIPE_MTE2, event_v_mte); + wait_flag(PIPE_V, PIPE_MTE2, event_v_mte); + + pipe_barrier(PIPE_V); + Mul(yLocal, gammaLocal, yLocal, numCol); + pipe_barrier(PIPE_V); + outQueueY.EnQue(yLocal); + } + + __aicore__ inline void ComputeFp32(uint32_t inner_progress, LocalTensor gammaLocal, + LocalTensor rstdLocal) { + LocalTensor x_fp32 = xFp32Buf.Get(); + LocalTensor sqx = sqxBuf.Get(); + LocalTensor reduce_buf_local = reduceFp32Buf.Get(); + + Mul(sqx, x_fp32, x_fp32, numCol); + pipe_barrier(PIPE_V); + + Muls(sqx, sqx, avgFactor, numCol); + pipe_barrier(PIPE_V); + + ReduceSumCustom(sqx, sqx, reduce_buf_local, numCol); + pipe_barrier(PIPE_V); + + Adds(sqx, sqx, epsilon, 1); + pipe_barrier(PIPE_V); + + Sqrt(sqx, sqx, 1); + Duplicate(reduce_buf_local, ONE, 1); + pipe_barrier(PIPE_V); + Div(sqx, reduce_buf_local, sqx, 1); + pipe_barrier(PIPE_V); + + event_t event_v_s = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_S)); + set_flag(PIPE_V, PIPE_S, event_v_s); + wait_flag(PIPE_V, PIPE_S, event_v_s); + float rstd_value = sqx.GetValue(0); + event_t event_s_v = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); + set_flag(PIPE_S, PIPE_V, event_s_v); + wait_flag(PIPE_S, PIPE_V, event_s_v); + rstdLocal.SetValue(inner_progress, rstd_value); + pipe_barrier(PIPE_V); + Muls(x_fp32, x_fp32, rstd_value, numCol); + pipe_barrier(PIPE_V); + Mul(x_fp32, x_fp32, gammaLocal, numCol); + pipe_barrier(PIPE_V); + if (is_same::value) { + LocalTensor yLocal = outQueueY.AllocTensor(); + + Cast(yLocal, x_fp32, RoundMode::CAST_NONE, numCol); + + event_t event_v_mte = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2)); + set_flag(PIPE_V, PIPE_MTE2, event_v_mte); + wait_flag(PIPE_V, PIPE_MTE2, event_v_mte); + pipe_barrier(PIPE_V); + + outQueueY.EnQue(yLocal); + } else { + LocalTensor yLocal = outQueueY.AllocTensor(); + + Cast(yLocal, x_fp32, RoundMode::CAST_RINT, numCol); + pipe_barrier(PIPE_V); + + event_t event_v_mte = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_MTE2)); + set_flag(PIPE_V, PIPE_MTE2, event_v_mte); + wait_flag(PIPE_V, PIPE_MTE2, event_v_mte); + pipe_barrier(PIPE_V); + + outQueueY.EnQue(yLocal); + } + } + + __aicore__ inline void CopyOutY(uint32_t progress) { + LocalTensor yLocal = outQueueY.DeQue(); + DataCopyCustom(yGm[progress], yLocal, numCol); + outQueueY.FreeTensor(yLocal); + } + + __aicore__ inline void CopyOutRstd(uint32_t outer_progress, uint32_t num) { + LocalTensor rstdLocal = outQueueRstd.DeQue(); + // #if __CCE_AICORE__ == 220 + // DataCopyCustom(rstdGm[outer_progress * rowFactor], rstdLocal, num); + // #endif + outQueueRstd.FreeTensor(rstdLocal); + } + +private: + TPipe pipe; + // create queues for input, in this case depth is equal to buffer num + TQue inQueueX, inQueueGamma; + // 
create queues for output, in this case depth is equal to buffer num + TQue outQueueY, outQueueRstd; + + TBuf xFp32Buf; + TBuf sqxBuf; + TBuf reduceFp32Buf; + GlobalTensor x1Gm; + GlobalTensor x2Gm; + GlobalTensor gammaGm; + GlobalTensor gammaGmFp32; + GlobalTensor yGm; + GlobalTensor rstdGm; + GlobalTensor xGm; + + uint32_t numRow; + uint32_t numCol; + uint32_t blockFactor; // number of calculations rows on each core + uint32_t rowFactor; + uint32_t ubFactor; + float epsilon; + float avgFactor; + bool is_cast_gamma; + + uint32_t rowWork = 1; +}; + +template class KernelAddRmsNormSplitD { +public: + __aicore__ inline KernelAddRmsNormSplitD() {} + __aicore__ inline void Init(GM_ADDR x1, GM_ADDR x2, GM_ADDR gamma, GM_ADDR y, GM_ADDR rstd, + GM_ADDR x, GM_ADDR workspace, uint32_t numRow, uint32_t numCol, + uint32_t blockFactor, uint32_t rowFactor, uint32_t ubFactor, + float epsilon, bool is_cast_gamma = false) { + ASSERT(GetBlockNum() != 0 && "Block dim can not be zero!"); + this->numRow = numRow; + this->numCol = numCol; + this->blockFactor = blockFactor; + this->rowFactor = rowFactor; + this->ubFactor = ubFactor; + this->epsilon = epsilon; + this->avgFactor = (float)1.0 / numCol; + this->is_cast_gamma = is_cast_gamma; + + if (GetBlockIdx() < GetBlockNum() - 1) { + this->rowWork = blockFactor; + } else if (GetBlockIdx() == GetBlockNum() - 1) { + this->rowWork = numRow - (GetBlockNum() - 1) * blockFactor; + } else { + } + // get start index for current core, core parallel + x1Gm.SetGlobalBuffer((__gm__ T *)x1 + GetBlockIdx() * blockFactor * numCol, rowWork * numCol); + x2Gm.SetGlobalBuffer((__gm__ T *)x2 + GetBlockIdx() * blockFactor * numCol, rowWork * numCol); + if (is_cast_gamma) { + gammaGmFp32.SetGlobalBuffer((__gm__ float *)gamma, numCol); + } else { + gammaGm.SetGlobalBuffer((__gm__ T *)gamma, numCol); + } + yGm.SetGlobalBuffer((__gm__ T *)y + GetBlockIdx() * blockFactor * numCol, rowWork * numCol); + rstdGm.SetGlobalBuffer((__gm__ float *)rstd + GetBlockIdx() * blockFactor, blockFactor); + xGm.SetGlobalBuffer((__gm__ T *)x + GetBlockIdx() * blockFactor * numCol, rowWork * numCol); + + // pipe alloc memory to queue, the unit is Bytes. + // We need 2 buffers here for both x1 and x2. 
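+    // Layout of inQueueX per tile: x1 occupies [0, ubFactor) and x2 occupies [ubFactor, 2 * ubFactor),
+    // which is what the x1x2_in[0] / x1x2_in[ubFactor] slicing in CopyInAndAdd relies on.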
+ pipe.InitBuffer(inQueueX, BUFFER_NUM, 2 * ubFactor * sizeof(T)); + if (is_cast_gamma) { + pipe.InitBuffer(inQueueGamma, BUFFER_NUM, ubFactor * sizeof(float)); + } else { + pipe.InitBuffer(inQueueGamma, BUFFER_NUM, ubFactor * sizeof(T)); + } + pipe.InitBuffer(outQueueY, BUFFER_NUM, ubFactor * sizeof(T)); + pipe.InitBuffer(outQueueRstd, BUFFER_NUM, rowFactor * sizeof(float)); + + if constexpr (is_same::value || is_same::value) { + pipe.InitBuffer(xFp32Buf, ubFactor * sizeof(float)); + } + pipe.InitBuffer(sqxBuf, ubFactor * sizeof(float)); + pipe.InitBuffer(sumBuf, rowFactor * NUM_PER_BLK_FP32 * sizeof(float)); + pipe.InitBuffer(reduceFp32Buf, NUM_PER_REP_FP32 * sizeof(float)); + } + + __aicore__ inline void Process() { + uint32_t i_o_max = CeilDiv(rowWork, rowFactor); + uint32_t row_tail = rowWork - (i_o_max - 1) * rowFactor; + uint32_t j_max = CeilDiv(numCol, ubFactor); + uint32_t col_tail = numCol - (j_max - 1) * ubFactor; + for (uint32_t i_o = 0; i_o < i_o_max - 1; i_o++) { + SubProcess(i_o, rowFactor, j_max, col_tail); + } + SubProcess(i_o_max - 1, row_tail, j_max, col_tail); + } + + __aicore__ inline void SubProcess(uint32_t i_o, uint32_t calc_row_num, uint32_t j_max, + uint32_t col_tail) { + LocalTensor sumLocal = sumBuf.Get(); + + LocalTensor rstdLocal = outQueueRstd.AllocTensor(); + Duplicate(rstdLocal, (float)0.0, calc_row_num); + pipe_barrier(PIPE_V); + for (uint32_t j = 0; j < j_max - 1; j++) { + ComputeFormer(i_o, calc_row_num, j, rstdLocal, sumLocal, ubFactor); + } + // do tail + ComputeFormer(i_o, calc_row_num, j_max - 1, rstdLocal, sumLocal, col_tail); + ComputeRstd(rstdLocal, calc_row_num); + + for (uint32_t j = 0; j < j_max - 1; j++) { + ComputeLatter(i_o, calc_row_num, j, rstdLocal, ubFactor); + } + ComputeLatter(i_o, calc_row_num, j_max - 1, rstdLocal, col_tail); + outQueueRstd.EnQue(rstdLocal); + CopyOutRstd(i_o, calc_row_num); + } + +private: + __aicore__ inline void CopyInAndAdd(uint32_t i_idx, uint32_t j_idx, uint32_t num) { + LocalTensor x1x2_in = inQueueX.AllocTensor(); + LocalTensor x1_in = x1x2_in[0]; + LocalTensor x2_in = x1x2_in[ubFactor]; + DataCopyCustom(x1_in, x1Gm[i_idx * numCol + j_idx * ubFactor], num); + DataCopyCustom(x2_in, x2Gm[i_idx * numCol + j_idx * ubFactor], num); + inQueueX.EnQue(x1x2_in); + LocalTensor x1x2Local = inQueueX.DeQue(); + + auto x1Local = x1x2Local[0]; + auto x2Local = x1x2Local[ubFactor]; + + LocalTensor xLocal = outQueueY.AllocTensor(); + + if constexpr (is_same::value) { + LocalTensor x1_fp32 = xFp32Buf.Get(); + + Add(xLocal, x1Local, x2Local, num); + pipe_barrier(PIPE_V); + Cast(x1_fp32, xLocal, RoundMode::CAST_NONE, num); + pipe_barrier(PIPE_V); + // x1+x2 saved in x1_fp32 + } else if constexpr (is_same::value) { + LocalTensor x1_fp32 = xFp32Buf.Get(); + LocalTensor x2_fp32 = x1x2Local.template ReinterpretCast(); + + Cast(x1_fp32, x1Local, RoundMode::CAST_NONE, num); + pipe_barrier(PIPE_V); + Cast(x2_fp32, x2Local, RoundMode::CAST_NONE, num); + pipe_barrier(PIPE_V); + + Add(x1_fp32, x1_fp32, x2_fp32, num); + pipe_barrier(PIPE_V); + Cast(xLocal, x1_fp32, RoundMode::CAST_RINT, num); + pipe_barrier(PIPE_V); + + // cast for precision issue + Cast(x1_fp32, xLocal, RoundMode::CAST_NONE, num); + pipe_barrier(PIPE_V); + // x1+x2 saved in x1_fp32 + } else { + Add(x1Local, x1Local, x2Local, num); + pipe_barrier(PIPE_V); + Adds(xLocal, x1Local, (float)0.0, num); + // x1+x2 saved in inQueueX + } + inQueueX.FreeTensor(x1x2Local); + + // copy out to workspace && x_out + outQueueY.EnQue(xLocal); + auto x_out = outQueueY.DeQue(); + 
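+    // The partial x1 + x2 tile is flushed to xGm right away: the split-D path cannot hold a whole
+    // row in UB, so ComputeLatter recomputes this tile from x1Gm/x2Gm when it is needed again.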
DataCopyCustom(xGm[i_idx * numCol + j_idx * ubFactor], x_out, num); + outQueueY.FreeTensor(x_out); + } + + __aicore__ inline void ComputeFormer(uint32_t i_o_idx, uint32_t calc_row_num, uint32_t j_idx, + LocalTensor &rstdLocal, LocalTensor &sumLocal, + uint32_t num) { + for (uint32_t i_i = 0; i_i < calc_row_num; i_i++) { + CopyInAndAdd(i_o_idx * rowFactor + i_i, j_idx, num); + ComputeSum(i_i, sumLocal, num); + } + BlockReduceSumFP32(sumLocal, sumLocal, calc_row_num * NUM_PER_BLK_FP32); + Add(rstdLocal, rstdLocal, sumLocal, calc_row_num); + pipe_barrier(PIPE_V); + } + + __aicore__ inline void ComputeSum(uint32_t i_i_idx, LocalTensor &sumLocal, uint32_t num) { + LocalTensor sqx = sqxBuf.Get(); + LocalTensor reduce_buf_local = reduceFp32Buf.Get(); + if constexpr (is_same::value || is_same::value) { + LocalTensor x_fp32 = xFp32Buf.Get(); + pipe_barrier(PIPE_V); + Mul(sqx, x_fp32, x_fp32, num); + } else { + LocalTensor xLocal = inQueueX.AllocTensor(); + pipe_barrier(PIPE_V); + Mul(sqx, xLocal, xLocal, num); + inQueueX.FreeTensor(xLocal); + } + pipe_barrier(PIPE_V); + Muls(sqx, sqx, avgFactor, num); + pipe_barrier(PIPE_V); + // 8 means 8 fp32 pre block + ReduceSumFP32ToBlock(sumLocal[i_i_idx * 8], sqx, reduce_buf_local, num); + } + + __aicore__ inline void ComputeRstd(LocalTensor rstdLocal, uint32_t num) { + LocalTensor reduce_buf_local = reduceFp32Buf.Get(); + Adds(rstdLocal, rstdLocal, epsilon, num); + pipe_barrier(PIPE_V); + Sqrt(rstdLocal, rstdLocal, num); + Duplicate(reduce_buf_local, ONE, num); + pipe_barrier(PIPE_V); + Div(rstdLocal, reduce_buf_local, rstdLocal, num); + pipe_barrier(PIPE_V); + } + + __aicore__ inline void ComputeLatter(uint32_t i_o_idx, uint32_t calc_row_num, uint32_t j_idx, + LocalTensor &rstdLocal, uint32_t num) { + CopyInGamma(j_idx, num); + if (is_cast_gamma) { + LocalTensor gammaLocal = inQueueGamma.DeQue(); + for (uint32_t i_i = 0; i_i < calc_row_num; i_i++) { + CopyInAndAdd(i_o_idx * rowFactor + i_i, j_idx, num); + ComputeYFp32(i_i, gammaLocal, rstdLocal, num); + CopyOutY(i_o_idx * rowFactor + i_i, j_idx, num); + } + inQueueGamma.FreeTensor(gammaLocal); + } else { + LocalTensor gammaLocal = inQueueGamma.DeQue(); + for (uint32_t i_i = 0; i_i < calc_row_num; i_i++) { + CopyInAndAdd(i_o_idx * rowFactor + i_i, j_idx, num); + ComputeY(i_i, gammaLocal, rstdLocal, num); + CopyOutY(i_o_idx * rowFactor + i_i, j_idx, num); + } + inQueueGamma.FreeTensor(gammaLocal); + } + } + + __aicore__ inline void CopyInGamma(uint32_t j_idx, uint32_t num) { + if (is_cast_gamma) { + LocalTensor gammaLocal = inQueueGamma.AllocTensor(); + DataCopyCustom(gammaLocal, gammaGmFp32[j_idx * ubFactor], num); + inQueueGamma.EnQue(gammaLocal); + } else { + LocalTensor gammaLocal = inQueueGamma.AllocTensor(); + DataCopyCustom(gammaLocal, gammaGm[j_idx * ubFactor], num); + inQueueGamma.EnQue(gammaLocal); + } + } + + __aicore__ inline void ComputeY(uint32_t i_i_idx, LocalTensor &gammaLocal, + LocalTensor &rstdLocal, uint32_t num) { + LocalTensor x_fp32 = xFp32Buf.Get(); + LocalTensor sqx = sqxBuf.Get(); + event_t event_v_s = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_S)); + set_flag(PIPE_V, PIPE_S, event_v_s); + wait_flag(PIPE_V, PIPE_S, event_v_s); + float rstd_value = rstdLocal.GetValue(i_i_idx); + event_t event_s_v = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); + set_flag(PIPE_S, PIPE_V, event_s_v); + wait_flag(PIPE_S, PIPE_V, event_s_v); + pipe_barrier(PIPE_V); + Muls(x_fp32, x_fp32, rstd_value, num); + pipe_barrier(PIPE_V); + LocalTensor yLocal = outQueueY.AllocTensor(); + 
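+    // Cast the normalized fp32 values down to the output dtype, then apply gamma in that dtype.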
Cast(yLocal, x_fp32, RoundMode::CAST_NONE, num); + pipe_barrier(PIPE_V); + Mul(yLocal, gammaLocal, yLocal, num); + pipe_barrier(PIPE_V); + outQueueY.EnQue(yLocal); + } + + __aicore__ inline void ComputeY(uint32_t i_i_idx, LocalTensor &gammaLocal, + LocalTensor &rstdLocal, uint32_t num) { + LocalTensor xLocal = inQueueX.AllocTensor(); // inQueueX.DeQue(); + LocalTensor sqx = sqxBuf.Get(); + event_t event_v_s = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_S)); + set_flag(PIPE_V, PIPE_S, event_v_s); + wait_flag(PIPE_V, PIPE_S, event_v_s); + float rstd_value = rstdLocal.GetValue(i_i_idx); + event_t event_s_v = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); + set_flag(PIPE_S, PIPE_V, event_s_v); + wait_flag(PIPE_S, PIPE_V, event_s_v); + LocalTensor yLocal = outQueueY.AllocTensor(); + Muls(yLocal, xLocal, rstd_value, num); + inQueueX.FreeTensor(xLocal); + pipe_barrier(PIPE_V); + Mul(yLocal, gammaLocal, yLocal, num); + pipe_barrier(PIPE_V); + outQueueY.EnQue(yLocal); + } + + __aicore__ inline void ComputeY(uint32_t i_i_idx, LocalTensor &gammaLocal, + LocalTensor &rstdLocal, uint32_t num) { + LocalTensor x_fp32 = xFp32Buf.Get(); + LocalTensor sqx = sqxBuf.Get(); + event_t event_v_s = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_S)); + set_flag(PIPE_V, PIPE_S, event_v_s); + wait_flag(PIPE_V, PIPE_S, event_v_s); + float rstd_value = rstdLocal.GetValue(i_i_idx); + event_t event_s_v = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); + set_flag(PIPE_S, PIPE_V, event_s_v); + wait_flag(PIPE_S, PIPE_V, event_s_v); + pipe_barrier(PIPE_V); + Muls(x_fp32, x_fp32, rstd_value, num); + pipe_barrier(PIPE_V); + LocalTensor yLocal = outQueueY.AllocTensor(); + Cast(yLocal, x_fp32, RoundMode::CAST_RINT, num); + pipe_barrier(PIPE_V); + Cast(x_fp32, yLocal, RoundMode::CAST_NONE, num); + pipe_barrier(PIPE_V); + Cast(sqx, gammaLocal, RoundMode::CAST_NONE, num); + pipe_barrier(PIPE_V); + Mul(x_fp32, x_fp32, sqx, num); + pipe_barrier(PIPE_V); + Cast(yLocal, x_fp32, RoundMode::CAST_RINT, num); + pipe_barrier(PIPE_V); + outQueueY.EnQue(yLocal); + } + + __aicore__ inline void ComputeYFp32(uint32_t i_i_idx, LocalTensor &gammaLocal, + LocalTensor &rstdLocal, uint32_t num) { + LocalTensor x_fp32 = xFp32Buf.Get(); + LocalTensor sqx = sqxBuf.Get(); + event_t event_v_s = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_S)); + set_flag(PIPE_V, PIPE_S, event_v_s); + wait_flag(PIPE_V, PIPE_S, event_v_s); + float rstd_value = rstdLocal.GetValue(i_i_idx); + event_t event_s_v = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); + set_flag(PIPE_S, PIPE_V, event_s_v); + wait_flag(PIPE_S, PIPE_V, event_s_v); + pipe_barrier(PIPE_V); + Muls(x_fp32, x_fp32, rstd_value, num); + pipe_barrier(PIPE_V); + Mul(x_fp32, gammaLocal, x_fp32, num); + pipe_barrier(PIPE_V); + if (is_same::value) { + LocalTensor yLocal = outQueueY.AllocTensor(); + Cast(yLocal, x_fp32, RoundMode::CAST_NONE, num); + pipe_barrier(PIPE_V); + outQueueY.EnQue(yLocal); + } else { + LocalTensor yLocal = outQueueY.AllocTensor(); + Cast(yLocal, x_fp32, RoundMode::CAST_RINT, num); + pipe_barrier(PIPE_V); + outQueueY.EnQue(yLocal); + } + } + + __aicore__ inline void CopyOutY(uint32_t i_idx, uint32_t j_idx, uint32_t num) { + LocalTensor yLocal = outQueueY.DeQue(); + pipe_barrier(PIPE_ALL); + DataCopyCustom(yGm[i_idx * numCol + j_idx * ubFactor], yLocal, num); + pipe_barrier(PIPE_ALL); + outQueueY.FreeTensor(yLocal); + } + + __aicore__ inline void CopyOutRstd(uint32_t i_o_idx, uint32_t num) { + LocalTensor rstdLocal = outQueueRstd.DeQue(); 
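+    // rstd is copied back to global memory only for __CCE_AICORE__ == 220 (the ascend910b target
+    // this op is registered for); other targets skip the copy.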
+#if __CCE_AICORE__ == 220 + DataCopyCustom(rstdGm[i_o_idx * rowFactor], rstdLocal, num); +#endif + outQueueRstd.FreeTensor(rstdLocal); + } + +private: + TPipe pipe; + // create queues for input, in this case depth is equal to buffer num + TQue inQueueX, inQueueGamma; + // create queues for output, in this case depth is equal to buffer num + TQue outQueueY, outQueueRstd; + TBuf xFp32Buf; + TBuf sqxBuf; + TBuf sumBuf; + TBuf reduceFp32Buf; + + GlobalTensor x1Gm; + GlobalTensor x2Gm; + GlobalTensor gammaGm; + GlobalTensor gammaGmFp32; + GlobalTensor yGm; + GlobalTensor rstdGm; + GlobalTensor xGm; + + uint32_t numRow; + uint32_t numCol; + uint32_t blockFactor; // number of calculations rows on each core + uint32_t rowFactor; + uint32_t ubFactor; + float epsilon; + float avgFactor; + bool is_cast_gamma; + uint32_t rowWork = 1; + + int tempbufNum; +}; + +inline __aicore__ int32_t AlignDiv32(int32_t n) { return ((n + 31) & ~31) / 32; } + +extern "C" __global__ __aicore__ void add_rms_norm_custom(GM_ADDR x1, GM_ADDR x2, GM_ADDR gamma, + GM_ADDR y, GM_ADDR rstd, GM_ADDR x, + GM_ADDR workspace, GM_ADDR tiling) { + GET_TILING_DATA(tilingData, tiling); + GM_ADDR usrWorkspace = AscendC::GetUserWorkspace(workspace); + if (TILING_KEY_IS(10)) { + KernelAddRmsNorm op; + op.Init(x1, x2, gamma, y, rstd, x, tilingData.num_row, tilingData.num_col, + tilingData.block_factor, tilingData.row_factor, tilingData.ub_factor, + tilingData.epsilon); + op.Process(); + } else if (TILING_KEY_IS(20)) { + KernelAddRmsNorm op; + op.Init(x1, x2, gamma, y, rstd, x, tilingData.num_row, tilingData.num_col, + tilingData.block_factor, tilingData.row_factor, tilingData.ub_factor, + tilingData.epsilon); + op.Process(); + } else if (TILING_KEY_IS(30)) { + KernelAddRmsNorm op; + op.Init(x1, x2, gamma, y, rstd, x, tilingData.num_row, tilingData.num_col, + tilingData.block_factor, tilingData.row_factor, tilingData.ub_factor, + tilingData.epsilon); + op.Process(); + } else if (TILING_KEY_IS(11)) { + KernelAddRmsNormSplitD op; + op.Init(x1, x2, gamma, y, rstd, x, usrWorkspace, tilingData.num_row, tilingData.num_col, + tilingData.block_factor, tilingData.row_factor, tilingData.ub_factor, + tilingData.epsilon); + op.Process(); + } else if (TILING_KEY_IS(21)) { + KernelAddRmsNormSplitD op; + op.Init(x1, x2, gamma, y, rstd, x, usrWorkspace, tilingData.num_row, tilingData.num_col, + tilingData.block_factor, tilingData.row_factor, tilingData.ub_factor, + tilingData.epsilon); + op.Process(); + } else if (TILING_KEY_IS(31)) { + KernelAddRmsNormSplitD op; + op.Init(x1, x2, gamma, y, rstd, x, usrWorkspace, tilingData.num_row, tilingData.num_col, + tilingData.block_factor, tilingData.row_factor, tilingData.ub_factor, + tilingData.epsilon); + op.Process(); + } + + if (TILING_KEY_IS(110)) { + KernelAddRmsNorm op; + op.Init(x1, x2, gamma, y, rstd, x, tilingData.num_row, tilingData.num_col, + tilingData.block_factor, tilingData.row_factor, tilingData.ub_factor, + tilingData.epsilon, true); + op.Process(); + } else if (TILING_KEY_IS(130)) { + KernelAddRmsNorm op; + op.Init(x1, x2, gamma, y, rstd, x, tilingData.num_row, tilingData.num_col, + tilingData.block_factor, tilingData.row_factor, tilingData.ub_factor, + tilingData.epsilon, true); + op.Process(); + } else if (TILING_KEY_IS(111)) { + KernelAddRmsNormSplitD op; + op.Init(x1, x2, gamma, y, rstd, x, usrWorkspace, tilingData.num_row, tilingData.num_col, + tilingData.block_factor, tilingData.row_factor, tilingData.ub_factor, + tilingData.epsilon, true); + op.Process(); + } else if (TILING_KEY_IS(131)) 
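+    // tiling key 131 = bf16 inputs (dtype_key 3) with column split (split_d 1) and fp32 gamma (+100); see TilingFunc.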
{ + KernelAddRmsNormSplitD op; + op.Init(x1, x2, gamma, y, rstd, x, usrWorkspace, tilingData.num_row, tilingData.num_col, + tilingData.block_factor, tilingData.row_factor, tilingData.ub_factor, + tilingData.epsilon, true); + op.Process(); + } +} + +void add_rms_norm_custom_do(uint32_t blockDim, void *l2ctrl, void *stream, uint8_t *x1, uint8_t *x2, + uint8_t *gamma, uint8_t *y, uint8_t *rstd, uint8_t *x, + uint8_t *workspace, uint8_t *tiling) { + add_rms_norm_custom<<>>(x1, x2, gamma, y, rstd, x, workspace, tiling); +} diff --git a/tests/st/test_add_rms_norm.py b/tests/st/test_add_rms_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..366477e558a2cd346832f86f26b4520248645217 --- /dev/null +++ b/tests/st/test_add_rms_norm.py @@ -0,0 +1,62 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" tests_custom_pyboost_ascend """ + +import numpy as np +import mindspore as ms +from mindspore import Tensor, context +import pytest +import ms_custom_ops + +@ms.jit(jit_level="O0", infer_boost="on") +def add_rms_norm(x1, x2, gamma, epsilon=1e-6): + return ms.ops.add_rms_norm(x1, x2, gamma, epsilon) + +@pytest.mark.parametrize('exec_mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('dtype', [ms.float16, ms.float32, ms.bfloat16]) +@pytest.mark.parametrize('shape', [(1, 1024, 1024)]) +def test_custom_add_rms_norm(exec_mode, dtype, shape): + ms.set_device("Ascend") + + def add_rms_norm_custom(x1, x2, gamma, epsilon=1e-6): + return ms_custom_ops.add_rms_norm(x1, x2, gamma, epsilon) + + if exec_mode == context.GRAPH_MODE: + add_rms_norm_custom = ms.jit(add_rms_norm_custom, jit_level="O0", infer_boost="on") + + x1 = Tensor(np.random.rand(*shape), dtype) + x2 = Tensor(np.random.rand(*shape), dtype) + gamma = Tensor(np.random.rand(*shape), dtype) + eps = 1e-6 + out = add_rms_norm_custom(x1, x2, gamma, eps) + expect = add_rms_norm(x1, x2, gamma, eps) + np.testing.assert_allclose( + out[0].astype(ms.float32).asnumpy(), + expect[0].astype(ms.float32).asnumpy(), + rtol=1e-3, + atol=1e-3, + ) + np.testing.assert_allclose( + out[1].astype(ms.float32).asnumpy(), + expect[1].astype(ms.float32).asnumpy(), + rtol=1e-3, + atol=1e-3, + ) + np.testing.assert_allclose( + out[2].astype(ms.float32).asnumpy(), + expect[2].astype(ms.float32).asnumpy(), + rtol=1e-3, + atol=1e-3, + ) diff --git a/yaml/ascendc/add_rms_norm_op.yaml b/yaml/ascendc/add_rms_norm_op.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7fed2ee545f89a97312b8b589c2097c1ae40fcc9 --- /dev/null +++ b/yaml/ascendc/add_rms_norm_op.yaml @@ -0,0 +1,19 @@ +#operator add_rms_norm +add_rms_norm: + args: + x1: + dtype: tensor + x2: + dtype: tensor + gamma: + dtype: tensor + epsilon: + dtype: float + default: 1e-6 + returns: + y: + dtype: tensor + rstd: + dtype: tensor + x: + dtype: tensor diff --git a/yaml/doc/add_rms_norm_doc.yaml b/yaml/doc/add_rms_norm_doc.yaml new file 
mode 100644 index 0000000000000000000000000000000000000000..2a5e8fb752f22358a444a41381b34a423716a8f4 --- /dev/null +++ b/yaml/doc/add_rms_norm_doc.yaml @@ -0,0 +1,50 @@ +add_rms_norm: + description: | + The AddRmsNorm is a fusion operator that fusing RmsNorm and its preceding Add operator, reducing the time for + moving data in and out. + It computes the following expression: + + .. math:: + \begin{array}{ll} \\ + x_i = x1_i + x2_i \\ + y_i=RmsNorm(x_i)=\frac{x_i}{\sqrt{\frac{1}{n}\sum_{i=1}^{n}{ x_i^2}+\varepsilon}}\gamma_i + \end{array} + + .. warning:: + This is an experimental API that is subject to change or deletion. This API is only supported in Atlas A2 + training series for now. + + Args: + x1 (Tensor): Input data of AddRmsNorm. Support data type: float16, float32, bfloat16. + x2 (Tensor): Input data of AddRmsNorm. Support data type: float16, float32, bfloat16. + gamma (Tensor): Learnable parameter :math:`\gamma` . Support data type: float16, float32, bfloat16. + epsilon (float, optional): A float number ranged in (0, 1] to prevent division by 0. Default value is `1e-6`. + + Returns: + - Tensor, denotes the normalized result, has the same type and shape as `x1`. + - Tensor, with the float data type, denotes the reciprocal of the input standard deviation, used by gradient + calculation. + - Tensor, the sum of `x1` and `x2`. + + Raises: + TypeError: If data type of `x1` or `x2` is not one of the following: float16, float32, bfloat16. + TypeError: If data type of `gamma` is not one of the following: float16, float32, bfloat16. + ValueError: If `epsilon` is not a float between 0 and 1. + ValueError: If the rank of `gamma` is greater than the rank of `x1` or `x2`. + RuntimeError: If the shapes of `x1` and `x2` are not same. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> import mindspore + >>> import numpy as np + >>> from mindspore import Tensor + >>> import ms_custom_ops + >>> x1 = Tensor(np.array([[0.5, 1.0, 1.5], [0.5, 1.0, 1.5]]), mindspore.float32) + >>> x2 = Tensor(np.array([[0.5, 1.0, 1.5], [0.5, 1.0, 1.5]]), mindspore.float32) + >>> gamma = Tensor(np.ones([3]), mindspore.float32) + >>> y, _, _ = ms_custom_ops.add_rms_norm(x1, x2, gamma) + >>> print(y) + [[0.46290997 0.92581993 1.3887299] + [0.46290997 0.92581993 1.3887299]]
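For readers of this patch, the fused semantics documented above can be cross-checked against a small NumPy sketch (illustrative only; add_rms_norm_ref is a hypothetical helper, not part of the op):

import numpy as np

def add_rms_norm_ref(x1, x2, gamma, epsilon=1e-6):
    # x = x1 + x2; y = x / sqrt(mean(x^2) + eps) * gamma; rstd = 1 / sqrt(mean(x^2) + eps)
    x = x1 + x2
    xf = x.astype(np.float32)
    # reduce over the trailing axes covered by gamma, keeping dims so rstd broadcasts over x
    axes = tuple(range(x.ndim - gamma.ndim, x.ndim))
    rstd = 1.0 / np.sqrt(np.mean(np.square(xf), axis=axes, keepdims=True) + epsilon)
    y = (xf * rstd * gamma.astype(np.float32)).astype(x1.dtype)
    return y, rstd, x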