diff --git a/ops/c_api/quant_batch_matmul_internal/quant_batch_matmul_internal.cc b/ops/c_api/quant_batch_matmul_internal/quant_batch_matmul_internal.cc new file mode 100644 index 0000000000000000000000000000000000000000..63713de70a41e473079ccd43f2039e413b2fad98 --- /dev/null +++ b/ops/c_api/quant_batch_matmul_internal/quant_batch_matmul_internal.cc @@ -0,0 +1,335 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// ============================================================================= +// GRAPH MODE IMPLEMENTATION +// ============================================================================= + +#include + +#include "ops/framework/ms_kernels_internal/graphmode/internal_kernel_mod.h" +#include "ops/framework/utils.h" + +namespace ms_custom_ops { + +constexpr size_t kQbmmMatSize = 2; + +enum class QuantBatchMatmulInternalInputIndex : size_t { + kInputX1Index = 0, + kInputX2Index, + kInputScaleIndex, + kInputOffsetIndex, + kInputBiasIndex, + kInputPertokenScaleIndex, + kInputTransposeX1Index, + kInputTransposeX2Index, + kInputX1FormatIndex, + kInputX2FormatIndex, + kInputOutputFormatIndex, + kInputOutputDtypeIndex, + kInputsNum, +}; + +enum class QuantBatchMatmulInternalOutputIndex : size_t { + kOutputIndex = 0, +}; + +ShapeVector BatchMatMulInternalMakeShape(const ShapeVector x1_shape, const ShapeVector x2_shape, + bool transpose_x1, bool transpose_x2, size_t offset) { + if (x1_shape.size() < kQbmmMatSize || x2_shape.size() < kQbmmMatSize) { + MS_LOG(EXCEPTION) << "For 'QuantBatchMatmulInternal', the dimension of 'x1' and 'x2' " + << "should be at least 2, but got " << x1_shape << " and " << x2_shape; + } + ShapeVector out_shape; + ShapeVector long_shape = x1_shape.size() > x2_shape.size() ? x1_shape : x2_shape; + ShapeVector short_shape = x1_shape.size() > x2_shape.size() ? x2_shape : x1_shape; + size_t size_diff = long_shape.size() - short_shape.size(); + for (size_t i = 0; i < long_shape.size() - offset; i++) { + if (long_shape[i] < 0) { + out_shape.push_back(abstract::Shape::kShapeDimAny); + } else if (i >= size_diff) { + out_shape.push_back(long_shape[i] > short_shape[i - size_diff] + ? long_shape[i] + : short_shape[i - size_diff]); + } else { + out_shape.push_back(long_shape[i]); + } + } + size_t x1_offset = x1_shape.size() - offset; + size_t x2_offset = x2_shape.size() - offset; + out_shape.push_back(x1_shape[x1_offset + (transpose_x1 ? 1 : 0)]); + out_shape.push_back(x2_shape[x2_offset + (transpose_x2 ? 
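+  // The last two output dims are m and n: m comes from x1 (its last dim when
+  // transpose_x1 is set, otherwise its second-to-last), and n comes from x2 (its
+  // second-to-last dim when transpose_x2 is set, otherwise its last).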
0 : 1)]); + return out_shape; +} + +inline internal_v2::InternalOpPtr CreateQuantBatchMatmulInternalOpWithParam( + const internal_v2::InputsImmutableInfoList &inputs, + const internal_v2::OutputsImmutableInfoList &outputs, const bool &transpose_x1, + const bool &transpose_x2, const DataFormat &x1_format, const DataFormat &x2_format, + const DataFormat &output_format, const bool &with_pertoken_scale, const bool &with_bias) { + internal_v2::MatmulParam param; + param.transpose_a = transpose_x1; + param.transpose_b = transpose_x2; + param.with_pertoken_scale = with_pertoken_scale; + param.with_bias = with_bias; + param.enable_shuffle = false; // the real definition is in the internal + param.enable_dequant = true; + + // Map format to internal_v2 enum and set appropriate format + auto inputs_clone = inputs; + auto outputs_clone = outputs; + + inputs_clone[static_cast(QuantBatchMatmulInternalInputIndex::kInputX1Index)].SetFormat( + internal_v2::TensorFormat::kFormatND); + inputs_clone[static_cast(QuantBatchMatmulInternalInputIndex::kInputX2Index)].SetFormat( + internal_v2::TensorFormat::kFormatND); + outputs_clone[static_cast(QuantBatchMatmulInternalOutputIndex::kOutputIndex)] + .SetFormat(internal_v2::TensorFormat::kFormatND); + if (x1_format == DataFormat::FRACTAL_NZ) { + inputs_clone[static_cast(QuantBatchMatmulInternalInputIndex::kInputX1Index)] + .SetFormat(internal_v2::TensorFormat::kFormatFRACTAL_NZ); + } + if (x2_format == DataFormat::FRACTAL_NZ) { + inputs_clone[static_cast(QuantBatchMatmulInternalInputIndex::kInputX2Index)] + .SetFormat(internal_v2::TensorFormat::kFormatFRACTAL_NZ); + } + if (output_format == DataFormat::FRACTAL_NZ) { + outputs_clone[static_cast(QuantBatchMatmulInternalOutputIndex::kOutputIndex)] + .SetFormat(internal_v2::TensorFormat::kFormatFRACTAL_NZ); + } + + return internal_v2::CreateMatmulOp(inputs_clone, outputs_clone, param, + internal_v2::kInternalMatMulOpName); +} + +class OPS_API QuantBatchMatmulInternalOpFuncImpl : public OpFuncImpl { + public: + ShapeArray InferShape(const PrimitivePtr &primitive, + const InferInfoPtrList &input_infos) const override { + auto x1_shape = input_infos[static_cast( + QuantBatchMatmulInternalInputIndex::kInputX1Index)]->GetShape(); + auto x2_shape = input_infos[static_cast( + QuantBatchMatmulInternalInputIndex::kInputX2Index)]->GetShape(); + if (IsDynamicRank(x1_shape) || IsDynamicRank(x2_shape)) { + return {ShapeVector({abstract::Shape::kShapeRankAny})}; + } + bool transpose_x1 = input_infos[static_cast( + QuantBatchMatmulInternalInputIndex::kInputTransposeX1Index)]->GetScalarValueWithCheck(); + bool transpose_x2 = input_infos[static_cast( + QuantBatchMatmulInternalInputIndex::kInputTransposeX2Index)]->GetScalarValueWithCheck(); + ShapeVector out_shape = + BatchMatMulInternalMakeShape(x1_shape, x2_shape, transpose_x1, transpose_x2, kQbmmMatSize); + return {out_shape}; + } + + std::vector InferType(const PrimitivePtr &primitive, + const InferInfoPtrList &input_infos) const override { + TypeId output_type = TypeId::kNumberTypeFloat16; + if (!input_infos[static_cast( + QuantBatchMatmulInternalInputIndex::kInputOutputDtypeIndex)]->IsNone()) { + auto dtype_ptr = input_infos[static_cast( + QuantBatchMatmulInternalInputIndex::kInputOutputDtypeIndex)]->GetScalarValueWithCheck(); + output_type = static_cast(dtype_ptr); + } + return {output_type}; + } + + bool GeneralInferRegistered() const override { return true; } +}; + +class QuantBatchMatmulInternal : public InternalKernelMod { + public: + QuantBatchMatmulInternal() : 
InternalKernelMod() {} + ~QuantBatchMatmulInternal() = default; + + void InitKernelInputsOutputsIndex() override { + kernel_inputs_index_ = { + static_cast(QuantBatchMatmulInternalInputIndex::kInputX1Index), + static_cast(QuantBatchMatmulInternalInputIndex::kInputX2Index), + static_cast(QuantBatchMatmulInternalInputIndex::kInputBiasIndex), + static_cast(QuantBatchMatmulInternalInputIndex::kInputScaleIndex), + static_cast(QuantBatchMatmulInternalInputIndex::kInputPertokenScaleIndex)}; + kernel_outputs_index_ = { + static_cast(QuantBatchMatmulInternalOutputIndex::kOutputIndex)}; + } + + bool Init(const std::vector &inputs, + const std::vector &outputs) override { + bool result = InternalKernelMod::Init(inputs, outputs); + + auto output_format = inputs.at(static_cast( + QuantBatchMatmulInternalInputIndex::kInputOutputFormatIndex)); + auto output_format_val = + static_cast(output_format->GetValueWithCheck()); + + if (output_format_val == DataFormat::FRACTAL_NZ) { + ClearNzOutputIndices(); + AddNzOutputIndex( + static_cast(QuantBatchMatmulInternalOutputIndex::kOutputIndex)); + } + + return result; + } + + protected: + internal_v2::InternalOpPtr CreateKernel( + const internal_v2::InputsImmutableInfoList &inputs, + const internal_v2::OutputsImmutableInfoList &outputs, + const std::vector &ms_inputs, + const std::vector &ms_outputs) override { + auto transpose_x1 = ms_inputs.at(static_cast( + QuantBatchMatmulInternalInputIndex::kInputTransposeX1Index))->GetValueWithCheck(); + auto transpose_x2 = ms_inputs.at(static_cast( + QuantBatchMatmulInternalInputIndex::kInputTransposeX2Index))->GetValueWithCheck(); + this->x1_format_ = static_cast(ms_inputs.at(static_cast( + QuantBatchMatmulInternalInputIndex::kInputX1FormatIndex))->GetValueWithCheck()); + this->x2_format_ = static_cast(ms_inputs.at(static_cast( + QuantBatchMatmulInternalInputIndex::kInputX2FormatIndex))->GetValueWithCheck()); + this->output_format_ = static_cast(ms_inputs.at(static_cast( + QuantBatchMatmulInternalInputIndex::kInputOutputFormatIndex))->GetValueWithCheck()); + bool with_pertoken_scale = !(ms_inputs.at(static_cast(QuantBatchMatmulInternalInputIndex::kInputPertokenScaleIndex))->GetType()->isa()); + bool with_bias = !(ms_inputs.at(static_cast(QuantBatchMatmulInternalInputIndex::kInputBiasIndex))->GetType()->isa()); + + return CreateQuantBatchMatmulInternalOpWithParam(inputs, outputs, transpose_x1, transpose_x2, + this->x1_format_, this->x2_format_, + this->output_format_, with_pertoken_scale, with_bias); + } + + uint64_t GenerateTilingKey(const std::vector &inputs) override { + return InternalTilingCache::GenerateKey(kernel_name_, inputs, this->x1_format_, + this->x2_format_, this->output_format_); + } + + private: + DataFormat x1_format_{DataFormat::ND}; + DataFormat x2_format_{DataFormat::ND}; + DataFormat output_format_{DataFormat::ND}; +}; +} // namespace ms_custom_ops + +REG_GRAPH_MODE_OP(quant_batch_matmul_internal, + ms_custom_ops::QuantBatchMatmulInternalOpFuncImpl, + ms_custom_ops::QuantBatchMatmulInternal); + +// ============================================================================= +// PYBOOST MODE IMPLEMENTATION +// ============================================================================= + +#include "ops/framework/ms_kernels_internal/pyboost/internal_pyboost_runner.h" + +namespace ms_custom_ops { +class QuantBatchMatmulInternalRunner : public InternalPyboostRunner { + public: + using InternalPyboostRunner::InternalPyboostRunner; + + void SetTransposeX1(const bool &transpose_x1) { this->transpose_x1_ = 
transpose_x1; }
+  void SetTransposeX2(const bool &transpose_x2) { this->transpose_x2_ = transpose_x2; }
+  void SetX1Format(const DataFormat &x1_format) { this->x1_format_ = x1_format; }
+  void SetX2Format(const DataFormat &x2_format) { this->x2_format_ = x2_format; }
+  void SetOutputFormat(const DataFormat &output_format) {
+    this->output_format_ = output_format;
+  }
+  void SetWithPertokenScale(const bool &with_pertoken_scale) { this->with_pertoken_scale_ = with_pertoken_scale; }
+  void SetWithBias(const bool &with_bias) { this->with_bias_ = with_bias; }
+
+ protected:
+  internal_v2::InternalOpPtr CreateKernel(
+      const internal_v2::InputsImmutableInfoList &inputs,
+      const internal_v2::OutputsImmutableInfoList &outputs) override {
+    return CreateQuantBatchMatmulInternalOpWithParam(inputs, outputs, this->transpose_x1_,
+                                                     this->transpose_x2_, this->x1_format_,
+                                                     this->x2_format_, this->output_format_,
+                                                     this->with_pertoken_scale_, this->with_bias_);
+  }
+
+ private:
+  bool transpose_x1_{false};
+  bool transpose_x2_{false};
+  bool with_pertoken_scale_{false};
+  bool with_bias_{false};
+  DataFormat x1_format_{DataFormat::ND};
+  DataFormat x2_format_{DataFormat::ND};
+  DataFormat output_format_{DataFormat::ND};
+};
+
+ms::Tensor npu_quant_batch_matmul_internal(
+    const ms::Tensor &x1, const ms::Tensor &x2, const ms::Tensor &scale,
+    const std::optional<ms::Tensor> &offset, const std::optional<ms::Tensor> &bias,
+    const std::optional<ms::Tensor> &pertoken_scale, std::optional<bool> transpose_x1,
+    std::optional<bool> transpose_x2, std::optional<int64_t> x1_format,
+    std::optional<int64_t> x2_format, std::optional<int64_t> output_format,
+    std::optional<int64_t> output_dtype) {
+  auto op_name = "QuantBatchMatmulInternal";
+  auto runner = std::make_shared<QuantBatchMatmulInternalRunner>(op_name);
+  MS_EXCEPTION_IF_NULL(runner);
+
+  runner->SetTransposeX1(transpose_x1.value_or(false));
+  runner->SetTransposeX2(transpose_x2.value_or(false));
+  runner->SetX1Format(static_cast<DataFormat>(x1_format.value_or(0)));
+  runner->SetX2Format(static_cast<DataFormat>(x2_format.value_or(0)));
+  runner->SetOutputFormat(static_cast<DataFormat>(output_format.value_or(0)));
+  runner->SetWithPertokenScale(pertoken_scale.has_value());
+  runner->SetWithBias(bias.has_value());
+
+  // Setup the runner with all parameters (including hash calculation)
+  runner->Setup(op_name, x1, x2, scale, GetTensorOrEmpty(offset), GetTensorOrEmpty(bias),
+                GetTensorOrEmpty(pertoken_scale), transpose_x1.value_or(false),
+                transpose_x2.value_or(false), x1_format.value_or(0), x2_format.value_or(0),
+                output_format.value_or(0), output_dtype.value_or(0));
+
+  // Infer output shape and type
+  auto transpose_x1_val = transpose_x1.value_or(false);
+  auto transpose_x2_val = transpose_x2.value_or(false);
+  auto output_shape = BatchMatMulInternalMakeShape(x1.shape(), x2.shape(), transpose_x1_val,
+                                                   transpose_x2_val, kQbmmMatSize);
+  if (output_format.has_value() &&
+      static_cast<DataFormat>(output_format.value()) == DataFormat::FRACTAL_NZ) {
+    CheckShapeHWAlignment(output_shape, x1.data_type());
+  }
+  TypeId out_dtype = TypeId::kNumberTypeFloat16;
+  if (output_dtype.has_value()) {
+    out_dtype = static_cast<TypeId>(output_dtype.value());
+  }
+  std::vector<ms::Tensor> inputs = {x1, x2, GetTensorOrEmpty(bias), scale,
+                                    GetTensorOrEmpty(pertoken_scale)};
+  std::vector<ms::Tensor> outputs = {ms::Tensor(out_dtype, output_shape)};
+  runner->GetOrCreateKernel(inputs, outputs);
+  runner->Run(inputs, outputs);
+  return outputs[0];
+}
+} // namespace ms_custom_ops
+
+auto pyboost_quant_batch_matmul_internal(
+    const ms::Tensor &x1, const ms::Tensor &x2, const ms::Tensor &scale,
+    const std::optional<ms::Tensor> &offset, const std::optional<ms::Tensor> &bias,
+    const std::optional<ms::Tensor> &pertoken_scale, std::optional<bool> transpose_x1,
+    std::optional<bool> transpose_x2, std::optional<int64_t> x1_format,
+    std::optional<int64_t> x2_format, std::optional<int64_t> output_format,
+    std::optional<int64_t> output_dtype) {
+  return ms::pynative::PyboostRunner::Call<1>(
+      ms_custom_ops::npu_quant_batch_matmul_internal, x1, x2, scale, offset, bias,
+      pertoken_scale, transpose_x1, transpose_x2, x1_format, x2_format, output_format,
+      output_dtype);
+}
+
+MS_CUSTOM_OPS_EXTENSION_MODULE(m) {
+  m.def("quant_batch_matmul_internal", &pyboost_quant_batch_matmul_internal, "QuantBatchMatmulInternal",
+        pybind11::arg("x1"), pybind11::arg("x2"), pybind11::arg("scale"),
+        pybind11::arg("offset") = std::nullopt, pybind11::arg("bias") = std::nullopt,
+        pybind11::arg("pertoken_scale") = std::nullopt,
+        pybind11::arg("transpose_x1") = false, pybind11::arg("transpose_x2") = false,
+        pybind11::arg("x1_format") = 0, pybind11::arg("x2_format") = 0,
+        pybind11::arg("output_format") = 0, pybind11::arg("output_dtype") = 0);
+}
diff --git a/ops/c_api/quant_batch_matmul_internal/quant_batch_matmul_internal.md b/ops/c_api/quant_batch_matmul_internal/quant_batch_matmul_internal.md
new file mode 100644
index 0000000000000000000000000000000000000000..257a022612bd904c747039a6dc979fa04a966501
--- /dev/null
+++ b/ops/c_api/quant_batch_matmul_internal/quant_batch_matmul_internal.md
@@ -0,0 +1,166 @@
+# quant_batch_matmul_internal operator
+
+## Description
+
+The quant_batch_matmul_internal operator performs quantized batched matrix multiplication. It takes int8 quantized inputs and produces results in bfloat16 or float16 precision, providing efficient quantized matrix computation. The operator can transpose either input matrix, and the data formats of the inputs and the output can be specified by the caller.
+
+## Input Parameters
+
+| Name           | DType                | Shape         | Optional | Inplace | Format        | Description                                                       |
+|----------------|----------------------|---------------|----------|---------|---------------|-------------------------------------------------------------------|
+| x1             | Tensor(int8)         | 2-D or higher | No       | No      | ND/FRACTAL_NZ | First quantized input matrix (int8)                               |
+| x2             | Tensor(int8)         | 2-D or higher | No       | No      | ND/FRACTAL_NZ | Second quantized input matrix (int8)                              |
+| scale          | Tensor(float16/bf16) | scalar or 1-D | No       | No      | -             | Quantization scale factor                                         |
+| offset         | Tensor(int32)        | scalar        | Yes      | No      | -             | Quantization offset, defaults to None                             |
+| bias           | Tensor(float16/bf16) | 1-D           | Yes      | No      | -             | Bias term, defaults to None                                       |
+| pertoken_scale | Tensor(float16/bf16) | 1-D           | Yes      | No      | -             | Per-token scale factor, defaults to None                          |
+| transpose_x1   | bool                 | -             | Yes      | -       | -             | Whether to transpose x1, defaults to False                        |
+| transpose_x2   | bool                 | -             | Yes      | -       | -             | Whether to transpose x2, defaults to False                        |
+| x1_format      | int                  | -             | Yes      | -       | -             | Data format of x1: 0 means ND, 1 means FRACTAL_NZ, defaults to 0  |
+| x2_format      | int                  | -             | Yes      | -       | -             | Data format of x2: 0 means ND, 1 means FRACTAL_NZ, defaults to 0  |
+| output_format  | int                  | -             | Yes      | -       | -             | Output data format: 0 means ND, 1 means FRACTAL_NZ, defaults to 0 |
+| output_dtype   | TypeId               | -             | Yes      | -       | -             | Output data type, float16 or bfloat16, defaults to float16        |
+
+## Output Parameters
+
+| Name | DType                | Shape                                       | Description                                   |
+|------|----------------------|---------------------------------------------|-----------------------------------------------|
+| y    | Tensor(float16/bf16) | Shape following matrix multiplication rules | Result of the quantized matrix multiplication |
+
+## Supported Products
+
+- Atlas 800I A2 inference products, Atlas inference series products
+
+## Constraints
+
+1. **Data type constraints**:
+   - x1 and x2 must be int8
+   - scale, bias and pertoken_scale support float16 or bfloat16
+   - offset is int32
+   - the output y supports float16 or bfloat16 (selected via output_dtype)
+2. Atlas 800I A2 inference products only support the ND format
+3. Atlas inference series products support the ND and FRACTAL_NZ data formats; the recommended combinations are:
+
+   | x1         | x2         | y          | Recommended scenario                            |
+   |------------|------------|------------|-------------------------------------------------|
+   | FRACTAL_NZ | FRACTAL_NZ | ND         | Prefill stage, where dimension 0 of x1 is large |
+   | ND         | FRACTAL_NZ | ND         | Decode stage, where dimension 0 of x1 is small  |
+   | FRACTAL_NZ | FRACTAL_NZ | FRACTAL_NZ | General case                                    |
+
+   When an input is in FRACTAL_NZ format, first call the `trans_data` operator on that input to convert its format; when the output is in FRACTAL_NZ format, call `trans_data` on the output to convert it back.
+
+   **Note**: In PyNative mode, only the case where x1 and y are in ND format is supported.
+
+## Usage Examples
+
+### Basic usage example
+
+```python
+import mindspore as ms
+import numpy as np
+import ms_custom_ops
+
+ms.set_context(device_target="Ascend")
+
+@ms.jit
+def quant_matmul_func(x1, x2, scale, transpose_x1=False, transpose_x2=False,
+                      x1_format=0, x2_format=0, output_format=0, output_dtype=ms.float16):
+    return ms_custom_ops.quant_batch_matmul_internal(
+        x1, x2, scale, transpose_x1=transpose_x1, transpose_x2=transpose_x2,
+        x1_format=x1_format, x2_format=x2_format, output_format=output_format,
+        output_dtype=output_dtype)
+
+# Example 1: basic quantized matrix multiplication
+batch = 2
+m = 128
+k = 256
+n = 128
+x1 = np.random.randint(-128, 127, (batch, m, k)).astype(np.int8)
+x2 = np.random.randint(-128, 127, (batch, k, n)).astype(np.int8)
+scale = np.random.randn(1).astype(np.float16)
+
+ms_x1 = ms.Tensor(x1)
+ms_x2 = ms.Tensor(x2)
+ms_scale = ms.Tensor(scale)
+output = quant_matmul_func(ms_x1, ms_x2, ms_scale)
+print("Output shape:", output.shape)
+print("Output dtype:", output.dtype)
+```
+
+### Example with bfloat16 output and bias
+
+```python
+import mindspore as ms
+import numpy as np
+import ms_custom_ops
+
+ms.set_context(device_target="Ascend")
+
+@ms.jit
+def quant_matmul_with_bias_func(x1, x2, scale, bias):
+    return ms_custom_ops.quant_batch_matmul_internal(
+        x1, x2, scale, bias=bias, output_dtype=ms.bfloat16)
+
+# x1 shape: (batch, m, k), int8
+# x2 shape: (batch, k, n), int8
+# scale shape: (1,), float16
+# bias shape: (n,), float16
+batch = 2
+m = 128
+k = 256
+n = 128
+x1 = np.random.randint(-128, 127, (batch, m, k)).astype(np.int8)
+x2 = np.random.randint(-128, 127, (batch, k, n)).astype(np.int8)
+scale = np.random.randn(1).astype(np.float16)
+bias = np.random.randn(n).astype(np.float16)
+
+ms_x1 = ms.Tensor(x1)
+ms_x2 = ms.Tensor(x2)
+ms_scale = ms.Tensor(scale)
+ms_bias = ms.Tensor(bias)
+output = quant_matmul_with_bias_func(ms_x1, ms_x2, ms_scale, ms_bias)
+print("Output shape:", output.shape)
+print("Output dtype:", output.dtype)
+```
+
+### Example with the FRACTAL_NZ format
+
+```python
+import mindspore as ms
+import numpy as np
+import ms_custom_ops
+
+ms.set_context(device_target="Ascend")
+
+@ms.jit
+def quant_matmul_with_nz_func(x1, x2, scale):
+    # Convert the ND inputs to FRACTAL_NZ format
+    x1_nz = ms_custom_ops.trans_data(x1, transdata_type=1)  # ND_TO_FRACTAL_NZ
+    x2_nz = ms_custom_ops.trans_data(x2, transdata_type=1)  # ND_TO_FRACTAL_NZ
+
+    # Run the quantized matmul with inputs and output in FRACTAL_NZ format
+    out_nz = ms_custom_ops.quant_batch_matmul_internal(
+        x1_nz, x2_nz, scale, transpose_x1=False, transpose_x2=False,
+        x1_format=1, x2_format=1, output_format=1)
+
+    # Convert the output from FRACTAL_NZ back to ND format
+    out = ms_custom_ops.trans_data(out_nz, transdata_type=0)  # FRACTAL_NZ_TO_ND
+    return out
+
+# x1 shape: (batch, m, k), int8
+# x2 shape: (batch, k, n), int8
+# Output shape: (batch, m, n)
+batch = 2
+m = 128
+k = 256
+n = 128
+x1 = np.random.randint(-128, 127, (batch, m, k)).astype(np.int8)
+x2 = np.random.randint(-128, 127, (batch, k, n)).astype(np.int8)
+scale = np.random.randn(1).astype(np.float16)
+
+ms_x1 = ms.Tensor(x1)
+ms_x2 = ms.Tensor(x2)
+ms_scale = ms.Tensor(scale)
+output = quant_matmul_with_nz_func(ms_x1, ms_x2, ms_scale)
+print("Output shape:", output.shape)
+```
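+
+## Computation Reference
+
+The sketch below illustrates, in NumPy, the dequantization arithmetic behind the parameters above (int8 matmul accumulated in int32, optional int32 bias, per-channel `scale`, optional per-token `pertoken_scale`). It mirrors the golden function used in the unit tests and is only an illustration, not the device implementation; the helper name `qbmm_reference` is an assumption of this note.
+
+```python
+import numpy as np
+
+def qbmm_reference(x1, x2, scale, bias=None, pertoken_scale=None):
+    """NumPy reference for quant_batch_matmul_internal (ND inputs, no transpose)."""
+    # int8 x int8 matmul, accumulated in int32
+    acc = np.matmul(x1.astype(np.int32), x2.astype(np.int32))
+    if bias is not None:
+        acc = acc + bias                          # int32 bias is added before dequantization
+    y = acc.astype(np.float32) * scale            # per-channel dequantization scale
+    if pertoken_scale is not None:
+        y = y * pertoken_scale[..., np.newaxis]   # optional per-token scale along the m axis
+    return y.astype(np.float16)
+```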
diff --git a/ops/c_api/quant_batch_matmul_internal/quant_batch_matmul_internal_op.yaml b/ops/c_api/quant_batch_matmul_internal/quant_batch_matmul_internal_op.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..00ed986e94e8d65d1ab7a4831f0ddebe836415d0
--- /dev/null
+++ b/ops/c_api/quant_batch_matmul_internal/quant_batch_matmul_internal_op.yaml
@@ -0,0 +1,40 @@
+#operator quant_batch_matmul_internal
+quant_batch_matmul_internal:
+  args:
+    x1:
+      dtype: tensor
+    x2:
+      dtype: tensor
+    scale:
+      dtype: tensor
+    offset:
+      dtype: tensor
+      default: None
+    bias:
+      dtype: tensor
+      default: None
+    pertoken_scale:
+      dtype: tensor
+      default: None
+    transpose_x1:
+      dtype: bool
+      default: false
+    transpose_x2:
+      dtype: bool
+      default: false
+    x1_format:
+      dtype: int
+      default: 0  # 0: ND, 1: FRACTAL_NZ
+    x2_format:
+      dtype: int
+      default: 0  # 0: ND, 1: FRACTAL_NZ
+    output_format:
+      dtype: int
+      default: 0  # 0: ND, 1: FRACTAL_NZ
+    output_dtype:
+      dtype: TypeId
+      default: mstype.float16
+      arg_handler: dtype_to_type_id
+  returns:
+    y:
+      dtype: tensor
\ No newline at end of file
diff --git a/tests/st/test_quant_batch_matmul_internal.py b/tests/st/test_quant_batch_matmul_internal.py
new file mode 100644
index 0000000000000000000000000000000000000000..615bad41b7259681b0bce2968e44b6db0bc0d397
--- /dev/null
+++ b/tests/st/test_quant_batch_matmul_internal.py
@@ -0,0 +1,539 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ + +import os +import numpy as np +import pytest +from mindspore import Profiler + +import mindspore as ms +from mindspore import context +from mindspore._c_expression import MSContext +import ms_custom_ops + +np.set_printoptions(precision=2, suppress=True, linewidth=200) + +# Format constants +ND = 0 +FRACTAL_NZ = 1 + + +def process_deq_scale(deq_scale) -> np.ndarray: + new_deq_scale = np.frombuffer(deq_scale.tobytes(), dtype=np.uint32) + return new_deq_scale.astype(np.int64) + + +def np_qbmm_compute(a, b, tmp_scale, bias=None, tmp_pertoken_scale=None): + c = np.dot(a.astype(np.float32), b.astype(np.float32)).astype(np.int32) + if bias is not None: + c = c + bias + c = c.astype(np.float32) * tmp_scale + if tmp_pertoken_scale is not None: + pertoken_scale = tmp_pertoken_scale[:, np.newaxis] + c = c * pertoken_scale + c = c.astype(np.float16) + return c + + +class Qbmm(ms.nn.Cell): + def __init__(self, weight, scale, bias, pertoken_scale, trans_a, trans_b, dst_dtype, + x1_format=0, x2_format=0, output_format=0): + super().__init__() + self.weight = ms.Parameter(weight, requires_grad=False) + self.scale = ms.Parameter(scale, requires_grad=False) + self.bias = bias + self.pertoken_scale = pertoken_scale + self.trans_a = trans_a + self.trans_b = trans_b + self.dst_dtype = dst_dtype + self.x1_format = x1_format + self.x2_format = x2_format + self.output_format = output_format + + def construct(self, x): + if self.x1_format == FRACTAL_NZ: + x = ms_custom_ops.trans_data(x, transdata_type=1) # ND_TO_FRACTAL_NZ + + output = ms_custom_ops.quant_batch_matmul_internal( + x, self.weight, self.scale, None, self.bias, self.pertoken_scale, + transpose_x1=self.trans_a, transpose_x2=self.trans_b, + x1_format=self.x1_format, x2_format=self.x2_format, + output_format=self.output_format, output_dtype=self.dst_dtype + ) + + if self.output_format == FRACTAL_NZ: + output = ms_custom_ops.trans_data(output, transdata_type=0) # FRACTAL_NZ_TO_ND + return output + + +def qbmm(m, k, n, batch_m=0, trans_a=False, trans_b=False, dst_dtype=ms.float16, scale_dtype=ms.int64, + bias_none=False, pertoken_scale_none=True, profiling=False, is_dyn=False, x1_format=ND, x2_format=ND, output_format=ND): + os.environ['USE_LLM_CUSTOM_MATMUL'] = "off" + os.environ['INTERNAL_PRINT_TILING'] = "on" + + a_shape = (m, k) + # 在多卡并行的场景里,要控制随机种子以防多卡numpy生成不一致 + seed = 0 + np.random.seed(seed) + a = np.random.uniform(-20, 20, size=a_shape).astype(np.int8) + np.random.seed(seed) + b = np.random.uniform(-20, 20, size=(k, n)).astype(np.int8) + np.random.seed(seed) + bias = np.random.randint(-10, 10, (n)).astype(np.int32) + np.random.seed(seed) + tmp_scale = np.random.rand(n).astype(np.float32) / 1000 + scale = process_deq_scale(tmp_scale) + bias_fp16 = bias * tmp_scale + if bias_none: + bias = None + + np.random.seed(seed) + tmp_pertoken_scale = np.random.rand(m).astype(np.float32) + pertoken_scale_ms = ms.Tensor(tmp_pertoken_scale, ms.float32) + if pertoken_scale_none: + tmp_pertoken_scale = None + pertoken_scale_ms = None + + # When the output is bf16, currently only support scale of float32 type. 
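+    # scale_dtype selection below: ms.float32 passes the real-valued scale directly,
+    # while ms.int64 passes the packed value from process_deq_scale(), which
+    # reinterprets the float32 scale's bit pattern as uint32 and widens it to int64.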
+ if scale_dtype == ms.float32 or dst_dtype == ms.bfloat16: + scale_ms = ms.Tensor(tmp_scale, ms.float32) + elif scale_dtype == ms.int64: + scale_ms = ms.Tensor(scale, ms.int64) + + expect_np = np_qbmm_compute(a, b, tmp_scale, bias, tmp_pertoken_scale) + + if trans_a: + a = np.transpose(a, (1, 0)) + if trans_b: + b = np.transpose(b, (1, 0)) + + a_ms = ms.Tensor(a, ms.int8) + b_ms = ms.Tensor(b, ms.int8) + + if x2_format == FRACTAL_NZ or MSContext.get_instance().get_ascend_soc_version() == "ascend310p": + b_ms = ms_custom_ops.trans_data(b_ms, transdata_type=1) # ND_TO_FRACTAL_NZ + if x2_format == ND: + x2_format = FRACTAL_NZ + + net = None + bias_ms = ms.Tensor(bias, ms.int32) + if bias_none: + bias_ms = None + net = Qbmm(b_ms, scale_ms, bias_ms, pertoken_scale_ms, trans_a, trans_b, dst_dtype, + x1_format=x1_format, x2_format=x2_format, output_format=output_format) + + if is_dyn: + input_dyn = ms.Tensor(shape=(None), dtype=ms.int8) + net.set_inputs(input_dyn) + + if profiling: + for _ in range(50): + output = net(a_ms) + return + + output = net(a_ms) + output_np = output.asnumpy() + acc = 0.01 + res = np.allclose(output_np, expect_np, acc, acc) + assert res, "qbmm compare fail." + + +@pytest.mark.level0 +@pytest.mark.platform_ascend310p +@pytest.mark.platform_ascend910b +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('m', [1, 32, 256, 512, 1024, 4096]) +@pytest.mark.env_onecard +def test_qbmm_add_m_4096_4096_false_true_nd_input(m, exec_mode): + """ + Feature: test qbmm operator in graph mode + Description: test qbmm. + Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + qbmm(m, 2048, 2048, trans_a=False, trans_b=True) + + +@pytest.mark.level2 +@pytest.mark.platform_ascend910b +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('input_shape', [(128, 2560, 5120), + (16, 11264, 6912), (16, 6912, 11264)]) +@pytest.mark.env_onecard +def test_qbmm_add_false_true_nd_input(input_shape, exec_mode): + """ + Feature: testqbmm operator in graph mode + Description: testqbmm. + Expectation: the result is correct + """ + m, k, n = input_shape + ms.set_context(device_target="Ascend", mode=exec_mode) + qbmm(m, k, n, trans_a=False, trans_b=True) + +@pytest.mark.level2 +@pytest.mark.platform_ascend310p +@pytest.mark.platform_ascend910b +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('input_shape', [(128, 5120, 10240), + (1024, 5632, 3456), (1024, 3456, 11264)]) +@pytest.mark.env_onecard +def test_qbmm_add_false_true_nz_input(input_shape, exec_mode): + """ + Feature: testqbmm operator in graph mode + Description: testqbmm. + Expectation: the result is correct + """ + m, k, n = input_shape + ms.set_context(device_target="Ascend", mode=exec_mode) + qbmm(m, k, n, trans_a=False, trans_b=True) + +@pytest.mark.level2 +@pytest.mark.platform_ascend310p +@pytest.mark.platform_ascend910b +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('input_shape', [(128, 1234, 2234), + (1024, 2234, 1234), (1024, 2234, 5234)]) +@pytest.mark.env_onecard +def test_qbmm_add_false_true_nz_input_unaligned_k_n(input_shape, exec_mode): + """ + Feature: testqbmm operator in graph mode + Description: testqbmm. 
+ Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + m, k, n = input_shape + qbmm(m, k, n, trans_a=False, trans_b=True) + +@pytest.mark.level2 +@pytest.mark.platform_ascend310p +@pytest.mark.platform_ascend910b +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('input_shape', [(128, 1024, 2234), + (1024, 2048, 1234), (1024, 2048, 5234)]) +@pytest.mark.env_onecard +def test_qbmm_add_false_true_nd_input_unaligned_n(input_shape, exec_mode): + """ + Feature: testqbmm operator in graph mode + Description: testqbmm. + Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + m, k, n = input_shape + qbmm(m, k, n, trans_a=False, trans_b=True) + +@pytest.mark.level2 +@pytest.mark.platform_ascend910b +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('batch_size', [0, 3, 500, 2000]) +@pytest.mark.parametrize('is_bias_none', [True, False]) +@pytest.mark.env_onecard +def test_qbmm_16_250_10_false_true_nz_input_unaligned_k_n_batch_size(batch_size, is_bias_none, exec_mode): + """ + Feature: testqbmm operator in graph mode + Description: testqbmm. + Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + qbmm(16, 250, 10, trans_a=False, trans_b=True, batch_m=batch_size, bias_none=is_bias_none) + +@pytest.mark.level2 +@pytest.mark.platform_ascend910b +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('batch_size', [0, 3, 500, 2000]) +@pytest.mark.parametrize('is_bias_none', [True, False]) +@pytest.mark.env_onecard +def test_qbmm_16_256_10_false_true_nd_input_unaligned_n_batch_size(batch_size, is_bias_none, exec_mode): + """ + Feature: testqbmm operator in graph mode + Description: testqbmm. + Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + qbmm(16, 256, 10, trans_a=False, trans_b=True, batch_m=batch_size, bias_none=is_bias_none) + +@pytest.mark.level2 +@pytest.mark.platform_ascend310p +@pytest.mark.platform_ascend910b +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.env_onecard +def input_matmul_add_32_32_32_false_true_nd_input(exec_mode): + """ + Feature: test qbmm operator in graph mode + Description: test qbmm. + Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + qbmm(16, 32, 64, trans_a=False, trans_b=True) + +@pytest.mark.level2 +@pytest.mark.platform_ascend910b +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.env_onecard +def input_matmul_add_32_32_32_false_true_nz_input(exec_mode): + """ + Feature: test qbmm operator in graph mode + Description: test qbmm. + Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + qbmm(16, 64, 128, trans_a=False, trans_b=True) + +@pytest.mark.level2 +@pytest.mark.platform_ascend910b +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.env_onecard +def test_qbmm_16_32_64_false_true_bias_none_nd_input(exec_mode): + """ + Feature: test qbmm operator in graph mode + Description: test qbmm. 
+ Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + qbmm(16, 32, 64, trans_a=False, trans_b=True, bias_none=True) + +@pytest.mark.level2 +@pytest.mark.platform_ascend910b +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.env_onecard +def test_qbmm_16_32_64_false_true_bias_none_nz_input(exec_mode): + """ + Feature: test qbmm operator in graph mode + Description: test qbmm. + Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + qbmm(16, 64, 128, trans_a=False, trans_b=True, bias_none=True) + +@pytest.mark.level2 +@pytest.mark.platform_ascend910b +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('is_bias_none', [True, False]) +@pytest.mark.env_onecard +def test_qbmm_16_16_32_64_false_true_nd_input(is_bias_none, exec_mode): + """ + Feature: test qbmm operator in graph mode + Description: test qbmm. + Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + qbmm(m=16, k=32, n=64, batch_m=3, trans_a=False, trans_b=True, + bias_none=is_bias_none) + +@pytest.mark.level2 +@pytest.mark.platform_ascend910b +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('is_bias_none', [True, False]) +@pytest.mark.env_onecard +def test_qbmm_16_16_32_64_false_true_nz_input(is_bias_none, exec_mode): + """ + Feature: test qbmm operator in graph mode + Description: test qbmm. + Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + qbmm(m=16, k=64, n=128, batch_m=3, trans_a=False, + trans_b=True, bias_none=is_bias_none) + +@pytest.mark.level2 +@pytest.mark.platform_ascend910b +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('m', [1, 512, 1024]) +@pytest.mark.env_onecard +def test_qbmm_m_4096_4096_false_true_bf16(m, exec_mode): + """ + Feature: test qbmm operator in graph mode + Description: test qbmm. + Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + qbmm(m, 4096, 4096, trans_a=False, trans_b=True, dst_dtype=ms.bfloat16) + +@pytest.mark.level2 +@pytest.mark.platform_ascend910b +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('m', [1, 31, 32, 512, 520, 1024]) +@pytest.mark.parametrize('k', [32, 1024]) +@pytest.mark.parametrize('batch_size', [2, 3]) +@pytest.mark.env_onecard +def test_qbmm_with_batch_prefill(m, k, batch_size, exec_mode): + """ + Feature: test qbmm operator in graph mode + Description: test qbmm. + Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + qbmm(m, k, 1024, batch_m=batch_size, trans_a=False, trans_b=True) + + +@pytest.mark.level2 +@pytest.mark.platform_ascend910b +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('m', [1, 31, 32, 64, 65]) +@pytest.mark.parametrize('k', [32, 64, 128]) +@pytest.mark.parametrize('batch_size', [2, 3]) +@pytest.mark.env_onecard +def test_qbmm_with_batch_increment_310p(m, k, batch_size, exec_mode): + """ + Feature: test qbmm operator in graph mode + Description: test qbmm. 
+ Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + qbmm(m, k, 128, batch_m=batch_size, trans_a=False, trans_b=True) + + +@pytest.mark.level0 +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('m', [48, 128, 1024]) +@pytest.mark.parametrize('x1_format', [ND, FRACTAL_NZ]) +@pytest.mark.parametrize('output_format', [ND, FRACTAL_NZ]) +@pytest.mark.env_onecard +def test_qbmm_with_fp32_scale_ds(m, x1_format, output_format, exec_mode): + """ + Feature: test qbmm operator in graph mode + Description: test qbmm. + Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + qbmm(m, 7168, 1536, trans_b=True, scale_dtype=ms.float32, x1_format=x1_format, output_format=output_format) + qbmm(m, 1536, 24576, trans_b=True, scale_dtype=ms.float32, x1_format=x1_format, output_format=output_format) + qbmm(m, 7168, 576, trans_b=True, scale_dtype=ms.float32, x1_format=x1_format, output_format=output_format) + qbmm(m, 16384, 7168, trans_b=True, scale_dtype=ms.float32, x1_format=x1_format, output_format=output_format) + + +@pytest.mark.level2 +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('m', [48, 128, 1024]) +@pytest.mark.parametrize('x1_format', [ND, FRACTAL_NZ]) +@pytest.mark.parametrize('output_format', [ND, FRACTAL_NZ]) +@pytest.mark.env_onecard +def test_qbmm_with_pertoken_ds(m, x1_format, output_format, exec_mode): + """ + Feature: test qbmm operator in graph mode + Description: test qbmm. + Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + qbmm(m, 7168, 36864, trans_b=True, bias_none=True, scale_dtype=ms.float32, + pertoken_scale_none=False, x1_format=x1_format, output_format=output_format) + qbmm(m, 7168, 18432, trans_b=True, bias_none=True, scale_dtype=ms.float32, + pertoken_scale_none=False, x1_format=x1_format, output_format=output_format) + + +@pytest.mark.level0 +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE]) +@pytest.mark.parametrize('m', [1, 48, 128]) +@pytest.mark.parametrize('bias_none', [False, True]) +@pytest.mark.parametrize('x1_format', [ND, FRACTAL_NZ]) +@pytest.mark.parametrize('output_format', [ND, FRACTAL_NZ]) +@pytest.mark.parametrize('pertoken_scale_none', [False, True]) +@pytest.mark.env_onecard +def test_qbmm_with_scale(m, bias_none, x1_format, output_format, pertoken_scale_none, exec_mode): + """ + Feature: test qbmm operator in graph mode + Description: test qbmm. + Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + qbmm(m, 256, 256, trans_b=True, bias_none=bias_none, scale_dtype=ms.float32, + pertoken_scale_none=pertoken_scale_none, x1_format=x1_format, output_format=output_format) + qbmm(m, 1024, 1024, trans_b=True, bias_none=bias_none, scale_dtype=ms.float32, + pertoken_scale_none=pertoken_scale_none, x1_format=x1_format, output_format=output_format) + + +@pytest.mark.level2 +@pytest.mark.platform_ascend910b +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('m', [1, 16, 48, 128]) +@pytest.mark.env_onecard +def test_qbmm_increment_prof(m, exec_mode): + """ + Feature: test qbmm operator in graph mode + Description: test qbmm. 
+ Expectation: the result is correct + """ + prof_flag = True + ms.set_context(device_target="Ascend", mode=exec_mode, jit_config={"jit_level": "O0", "infer_boost": "on"}) + profiler = Profiler(start_profile=False, output_path="profiler") + profiler.start() + qbmm(m, 4096, 3072, trans_a=False, trans_b=True, profiling=prof_flag, x1_format=ND, x2_format=ND, output_format=ND) + qbmm(m, 2048, 4096, trans_a=False, trans_b=True, profiling=prof_flag, x1_format=ND, x2_format=ND, output_format=ND) + qbmm(m, 4096, 16512, trans_a=False, trans_b=True, profiling=prof_flag, x1_format=ND, x2_format=ND, output_format=ND) + profiler.stop() + profiler.analyse() + +@pytest.mark.level2 +@pytest.mark.platform_ascend910b +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('m', [514, 1024]) +@pytest.mark.env_onecard +def test_qbmm_prefill_prof(m, exec_mode): + """ + Feature: test qbmm operator in graph mode + Description: test qbmm. + Expectation: the result is correct + """ + prof_flag = True + ms.set_context(device_target="Ascend", mode=exec_mode, jit_config={"jit_level": "O0", "infer_boost": "on"}) + profiler = Profiler(start_profile=False, output_path="profiler") + profiler.start() + qbmm(m, 4096, 3072, trans_a=False, trans_b=True, profiling=prof_flag, x1_format=FRACTAL_NZ, x2_format=FRACTAL_NZ, output_format=FRACTAL_NZ) + qbmm(m, 2048, 4096, trans_a=False, trans_b=True, profiling=prof_flag, x1_format=FRACTAL_NZ, x2_format=FRACTAL_NZ, output_format=FRACTAL_NZ) + qbmm(m, 2048, 11008, trans_a=False, trans_b=True, profiling=prof_flag, x1_format=FRACTAL_NZ, x2_format=FRACTAL_NZ, output_format=FRACTAL_NZ) + qbmm(m, 4096, 16512, trans_a=False, trans_b=True, profiling=prof_flag, x1_format=FRACTAL_NZ, x2_format=FRACTAL_NZ, output_format=FRACTAL_NZ) + profiler.stop() + profiler.analyse() + +@pytest.mark.level0 +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('m', [1, 4096]) +@pytest.mark.env_onecard +def test_qbmm_fastgelu(m, exec_mode): + """ + Feature: test qbmm operator in graph mode + Description: test qbmm. + Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + prof_flag = False + qbmm(m, 4096, 11008, trans_a=False, trans_b=True, profiling=prof_flag) + +@pytest.mark.level1 +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('m', [1, 16, 1024]) +@pytest.mark.parametrize('k', [1024, 2048]) +@pytest.mark.parametrize('n', [1024, 2048]) +@pytest.mark.parametrize('is_dyn', [False, True]) +@pytest.mark.env_onecard +def test_qbmm_fastgelu_dyn(m, k, n, is_dyn, exec_mode): + """ + Feature: test qbmm operator in graph mode + Description: test qbmm. + Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode, jit_config={"jit_level": "O0", "infer_boost": "on"}) + prof_flag = False + qbmm(m, k, n, trans_a=False, trans_b=True, profiling=prof_flag, is_dyn=is_dyn)