From 5200f5b8a9ac2c68ad9370a567dba6f9eae35279 Mon Sep 17 00:00:00 2001 From: huoxinyou Date: Wed, 5 Nov 2025 16:55:42 +0800 Subject: [PATCH 1/3] quant batch matmul internal --- .../quant_batch_matmul_internal.cc | 333 ++++++++++ .../quant_batch_matmul_internal.md | 166 +++++ .../quant_batch_matmul_internal_op.yaml | 42 ++ tests/st/test_quant_batch_matmul_internal.py | 584 ++++++++++++++++++ 4 files changed, 1125 insertions(+) create mode 100644 ops/c_api/quant_batch_matmul_internal/quant_batch_matmul_internal.cc create mode 100644 ops/c_api/quant_batch_matmul_internal/quant_batch_matmul_internal.md create mode 100644 ops/c_api/quant_batch_matmul_internal/quant_batch_matmul_internal_op.yaml create mode 100644 tests/st/test_quant_batch_matmul_internal.py diff --git a/ops/c_api/quant_batch_matmul_internal/quant_batch_matmul_internal.cc b/ops/c_api/quant_batch_matmul_internal/quant_batch_matmul_internal.cc new file mode 100644 index 0000000..512c717 --- /dev/null +++ b/ops/c_api/quant_batch_matmul_internal/quant_batch_matmul_internal.cc @@ -0,0 +1,333 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// ============================================================================= +// GRAPH MODE IMPLEMENTATION +// ============================================================================= + +#include + +#include "ops/framework/ms_kernels_internal/graphmode/internal_kernel_mod.h" +#include "ops/framework/utils.h" + +namespace ms_custom_ops { + +constexpr size_t kQbmmMatSize = 2; + +enum class QuantBatchMatmulInternalInputIndex : size_t { + kInputX1Index = 0, + kInputX2Index, + kInputScaleIndex, + kInputOffsetIndex, + kInputBiasIndex, + kInputPertokenScaleIndex, + kInputTransposeX1Index, + kInputTransposeX2Index, + kInputX1FormatIndex, + kInputX2FormatIndex, + kInputOutputFormatIndex, + kInputOutputDtypeIndex, + kInputsNum, +}; + +enum class QuantBatchMatmulInternalOutputIndex : size_t { + kOutputIndex = 0, +}; + +ShapeVector BatchMatMulMakeShape(const ShapeVector x1_shape, const ShapeVector x2_shape, + bool transpose_x1, bool transpose_x2, size_t offset) { + if (x1_shape.size() < kQbmmMatSize || x2_shape.size() < kQbmmMatSize) { + MS_LOG(EXCEPTION) << "For 'QuantBatchMatmulInternal', the dimension of 'x1' and 'x2' " + << "should be at least 2, but got " << x1_shape << " and " << x2_shape; + } + ShapeVector out_shape; + ShapeVector long_shape = x1_shape.size() > x2_shape.size() ? x1_shape : x2_shape; + ShapeVector short_shape = x1_shape.size() > x2_shape.size() ? x2_shape : x1_shape; + size_t size_diff = long_shape.size() - short_shape.size(); + for (size_t i = 0; i < long_shape.size() - offset; i++) { + if (long_shape[i] < 0) { + out_shape.push_back(abstract::Shape::kShapeDimAny); + } else if (i >= size_diff) { + out_shape.push_back(long_shape[i] > short_shape[i - size_diff] + ? 
long_shape[i] + : short_shape[i - size_diff]); + } else { + out_shape.push_back(long_shape[i]); + } + } + size_t x1_offset = x1_shape.size() - offset; + size_t x2_offset = x2_shape.size() - offset; + out_shape.push_back(x1_shape[x1_offset + (transpose_x1 ? 1 : 0)]); + out_shape.push_back(x2_shape[x2_offset + (transpose_x2 ? 0 : 1)]); + return out_shape; +} + +inline internal_v2::InternalOpPtr CreateQuantBatchMatmulInternalOpWithParam( + const internal_v2::InputsImmutableInfoList &inputs, + const internal_v2::OutputsImmutableInfoList &outputs, const bool &transpose_x1, + const bool &transpose_x2, const DataFormat &x1_format, const DataFormat &x2_format, + const DataFormat &output_format, const bool &with_pertoken_scale, const bool &with_bias) { + internal_v2::MatmulParam param; + param.transpose_a = transpose_x1; + param.transpose_b = transpose_x2; + param.with_pertoken_scale = with_pertoken_scale; + param.with_bias = with_bias; + param.enable_shuffle = false; // the real definition is in the internal + param.enable_dequant = true; + + // Map format to internal_v2 enum and set appropriate format + auto inputs_clone = inputs; + auto outputs_clone = outputs; + + inputs_clone[static_cast(QuantBatchMatmulInternalInputIndex::kInputX1Index)].SetFormat( + internal_v2::TensorFormat::kFormatND); + inputs_clone[static_cast(QuantBatchMatmulInternalInputIndex::kInputX2Index)].SetFormat( + internal_v2::TensorFormat::kFormatND); + outputs_clone[static_cast(QuantBatchMatmulInternalOutputIndex::kOutputIndex)] + .SetFormat(internal_v2::TensorFormat::kFormatND); + if (x1_format == DataFormat::FRACTAL_NZ) { + inputs_clone[static_cast(QuantBatchMatmulInternalInputIndex::kInputX1Index)] + .SetFormat(internal_v2::TensorFormat::kFormatFRACTAL_NZ); + } + if (x2_format == DataFormat::FRACTAL_NZ) { + inputs_clone[static_cast(QuantBatchMatmulInternalInputIndex::kInputX2Index)] + .SetFormat(internal_v2::TensorFormat::kFormatFRACTAL_NZ); + } + if (output_format == DataFormat::FRACTAL_NZ) { + outputs_clone[static_cast(QuantBatchMatmulInternalOutputIndex::kOutputIndex)] + .SetFormat(internal_v2::TensorFormat::kFormatFRACTAL_NZ); + } + + return internal_v2::CreateMatmulOp(inputs_clone, outputs_clone, param, + internal_v2::kInternalMatMulOpName); +} + +class OPS_API QuantBatchMatmulInternalOpFuncImpl : public OpFuncImpl { + public: + ShapeArray InferShape(const PrimitivePtr &primitive, + const InferInfoPtrList &input_infos) const override { + auto x1_shape = input_infos[static_cast( + QuantBatchMatmulInternalInputIndex::kInputX1Index)]->GetShape(); + auto x2_shape = input_infos[static_cast( + QuantBatchMatmulInternalInputIndex::kInputX2Index)]->GetShape(); + if (IsDynamicRank(x1_shape) || IsDynamicRank(x2_shape)) { + return {ShapeVector({abstract::Shape::kShapeRankAny})}; + } + bool transpose_x1 = input_infos[static_cast( + QuantBatchMatmulInternalInputIndex::kInputTransposeX1Index)]->GetScalarValueWithCheck(); + bool transpose_x2 = input_infos[static_cast( + QuantBatchMatmulInternalInputIndex::kInputTransposeX2Index)]->GetScalarValueWithCheck(); + ShapeVector out_shape = + BatchMatMulMakeShape(x1_shape, x2_shape, transpose_x1, transpose_x2, kQbmmMatSize); + return {out_shape}; + } + + std::vector InferType(const PrimitivePtr &primitive, + const InferInfoPtrList &input_infos) const override { + TypeId output_type = TypeId::kNumberTypeFloat16; + if (!input_infos[static_cast( + QuantBatchMatmulInternalInputIndex::kInputOutputDtypeIndex)]->IsNone()) { + auto dtype_ptr = input_infos[static_cast( + 
QuantBatchMatmulInternalInputIndex::kInputOutputDtypeIndex)]->GetScalarValueWithCheck(); + output_type = static_cast(dtype_ptr); + } + return {output_type}; + } + + bool GeneralInferRegistered() const override { return true; } +}; + +class QuantBatchMatmulInternal : public InternalKernelMod { + public: + QuantBatchMatmulInternal() : InternalKernelMod() {} + ~QuantBatchMatmulInternal() = default; + + void InitKernelInputsOutputsIndex() override { + kernel_inputs_index_ = { + static_cast(QuantBatchMatmulInternalInputIndex::kInputX1Index), + static_cast(QuantBatchMatmulInternalInputIndex::kInputX2Index), + static_cast(QuantBatchMatmulInternalInputIndex::kInputBiasIndex), + static_cast(QuantBatchMatmulInternalInputIndex::kInputScaleIndex), + static_cast(QuantBatchMatmulInternalInputIndex::kInputPertokenScaleIndex)}; + kernel_outputs_index_ = { + static_cast(QuantBatchMatmulInternalOutputIndex::kOutputIndex)}; + } + + bool Init(const std::vector &inputs, + const std::vector &outputs) override { + bool result = InternalKernelMod::Init(inputs, outputs); + + auto output_format = inputs.at(static_cast( + QuantBatchMatmulInternalInputIndex::kInputOutputFormatIndex)); + auto output_format_val = + static_cast(output_format->GetValueWithCheck()); + + if (output_format_val == DataFormat::FRACTAL_NZ) { + ClearNzOutputIndices(); + AddNzOutputIndex( + static_cast(QuantBatchMatmulInternalOutputIndex::kOutputIndex)); + } + + return result; + } + + protected: + internal_v2::InternalOpPtr CreateKernel( + const internal_v2::InputsImmutableInfoList &inputs, + const internal_v2::OutputsImmutableInfoList &outputs, + const std::vector &ms_inputs, + const std::vector &ms_outputs) override { + auto transpose_x1 = ms_inputs.at(static_cast( + QuantBatchMatmulInternalInputIndex::kInputTransposeX1Index))->GetValueWithCheck(); + auto transpose_x2 = ms_inputs.at(static_cast( + QuantBatchMatmulInternalInputIndex::kInputTransposeX2Index))->GetValueWithCheck(); + this->x1_format_ = static_cast(ms_inputs.at(static_cast( + QuantBatchMatmulInternalInputIndex::kInputX1FormatIndex))->GetValueWithCheck()); + this->x2_format_ = static_cast(ms_inputs.at(static_cast( + QuantBatchMatmulInternalInputIndex::kInputX2FormatIndex))->GetValueWithCheck()); + this->output_format_ = static_cast(ms_inputs.at(static_cast( + QuantBatchMatmulInternalInputIndex::kInputOutputFormatIndex))->GetValueWithCheck()); + bool with_pertoken_scale = !(ms_inputs.at(static_cast(QuantBatchMatmulInternalInputIndex::kInputPertokenScaleIndex))->GetType()->isa()); + bool with_bias = !(ms_inputs.at(static_cast(QuantBatchMatmulInternalInputIndex::kInputBiasIndex))->GetType()->isa()); + + return CreateQuantBatchMatmulInternalOpWithParam(inputs, outputs, transpose_x1, transpose_x2, + this->x1_format_, this->x2_format_, + this->output_format_, with_pertoken_scale, with_bias); + } + + uint64_t GenerateTilingKey(const std::vector &inputs) override { + return InternalTilingCache::GenerateKey(kernel_name_, inputs, this->x1_format_, + this->x2_format_, this->output_format_); + } + + private: + DataFormat x1_format_{DataFormat::ND}; + DataFormat x2_format_{DataFormat::ND}; + DataFormat output_format_{DataFormat::ND}; +}; +} // namespace ms_custom_ops + +REG_GRAPH_MODE_OP(quant_batch_matmul_internal, + ms_custom_ops::QuantBatchMatmulInternalOpFuncImpl, + ms_custom_ops::QuantBatchMatmulInternal); + +// ============================================================================= +// PYBOOST MODE IMPLEMENTATION +// 
============================================================================= + +#include "ops/framework/ms_kernels_internal/pyboost/internal_pyboost_runner.h" + +namespace ms_custom_ops { +class QuantBatchMatmulInternalRunner : public InternalPyboostRunner { + public: + using InternalPyboostRunner::InternalPyboostRunner; + + void SetTransposeX1(const bool &transpose_x1) { this->transpose_x1_ = transpose_x1; } + void SetTransposeX2(const bool &transpose_x2) { this->transpose_x2_ = transpose_x2; } + void SetX1Format(const DataFormat &x1_format) { this->x1_format_ = x1_format; } + void SetX2Format(const DataFormat &x2_format) { this->x2_format_ = x2_format; } + void SetOutputFormat(const DataFormat &output_format) { + this->output_format_ = output_format; + } + void SetWithPertokenScale(const bool &with_pertoken_scale) { this->with_pertoken_scale_ = with_pertoken_scale; } + void SetWithBias(const bool &with_bias) { this->with_bias_ = with_bias; } + + protected: + internal_v2::InternalOpPtr CreateKernel( + const internal_v2::InputsImmutableInfoList &inputs, + const internal_v2::OutputsImmutableInfoList &outputs) override { + return CreateQuantBatchMatmulInternalOpWithParam(inputs, outputs, this->transpose_x1_, + this->transpose_x2_, this->x1_format_, + this->x2_format_, this->output_format_, + this->with_pertoken_scale_, this->with_bias_); + } + + private: + bool transpose_x1_{false}; + bool transpose_x2_{false}; + DataFormat x1_format_{DataFormat::ND}; + DataFormat x2_format_{DataFormat::ND}; + DataFormat output_format_{DataFormat::ND}; +}; + +ms::Tensor npu_quant_batch_matmul_internal( + const ms::Tensor &x1, const ms::Tensor &x2, const ms::Tensor &scale, + const std::optional &offset, const std::optional &bias, + const std::optional &pertoken_scale, std::optional transpose_x1, + std::optional transpose_x2, std::optional x1_format, + std::optional x2_format, std::optional output_format, + std::optional output_dtype) { + auto op_name = "QuantBatchMatmulInternal"; + auto runner = std::make_shared(op_name); + MS_EXCEPTION_IF_NULL(runner); + + runner->SetTransposeX1(transpose_x1.value_or(false)); + runner->SetTransposeX2(transpose_x2.value_or(false)); + runner->SetX1Format(static_cast(x1_format.value_or(0))); + runner->SetX2Format(static_cast(x2_format.value_or(0))); + runner->SetOutputFormat(static_cast(output_format.value_or(0))); + runner->SetWithPertokenScale(pertoken_scale.has_value()); + runner->SetWithBias(bias.has_value()); + + // Setup the runner with all parameters (including hash calculation) + runner->Setup(op_name, x1, x2, scale, GetTensorOrEmpty(offset), GetTensorOrEmpty(bias), + GetTensorOrEmpty(pertoken_scale), transpose_x1.value_or(false), + transpose_x2.value_or(false), x1_format.value_or(0), x2_format.value_or(0), + output_format.value_or(0), output_dtype.value_or(0)); + + // Infer output shape and type + auto transpose_x1_val = transpose_x1.value_or(false); + auto transpose_x2_val = transpose_x2.value_or(false); + auto output_shape = BatchMatMulMakeShape(x1.shape(), x2.shape(), transpose_x1_val, + transpose_x2_val, kQbmmMatSize); + if (output_format.has_value() && + static_cast(output_format.value()) == DataFormat::FRACTAL_NZ) { + CheckShapeHWAlignment(output_shape, x1.data_type()); + } + TypeId out_dtype = TypeId::kNumberTypeFloat16; + if (output_dtype.has_value()) { + out_dtype = static_cast(output_dtype.value()); + } + std::vector inputs = {x1, x2, GetTensorOrEmpty(bias), scale, GetTensorOrEmpty(pertoken_scale)}; + std::vector outputs = {ms::Tensor(out_dtype, 
output_shape)}; + runner->GetOrCreateKernel(inputs, outputs); + runner->Run(inputs, outputs); + return outputs[0]; +} +} // namespace ms_custom_ops + +auto pyboost_quant_batch_matmul_internal( + const ms::Tensor &x1, const ms::Tensor &x2, const ms::Tensor &scale, + const std::optional &offset, const std::optional &bias, + const std::optional &pertoken_scale, std::optional transpose_x1, + std::optional transpose_x2, std::optional x1_format, + std::optional x2_format, std::optional output_format, + std::optional output_dtype) { + return ms::pynative::PyboostRunner::Call<1>( + ms_custom_ops::npu_quant_batch_matmul_internal, x1, x2, scale, offset, bias, + pertoken_scale, transpose_x1, transpose_x2, x1_format, x2_format, output_format, + output_dtype); +} + +MS_CUSTOM_OPS_EXTENSION_MODULE(m) { + m.def("quant_batch_matmul_internal", &pyboost_quant_batch_matmul_internal, "QuantBatchMatmulInternal", + pybind11::arg("x1"), pybind11::arg("x2"), pybind11::arg("scale"), + pybind11::arg("offset") = pybind11::none(), pybind11::arg("bias") = pybind11::none(), + pybind11::arg("pertoken_scale") = pybind11::none(), + pybind11::arg("transpose_x1") = false, pybind11::arg("transpose_x2") = false, + pybind11::arg("x1_format") = 0, pybind11::arg("x2_format") = 0, + pybind11::arg("output_format") = 0, pybind11::arg("output_dtype") = 0); +} diff --git a/ops/c_api/quant_batch_matmul_internal/quant_batch_matmul_internal.md b/ops/c_api/quant_batch_matmul_internal/quant_batch_matmul_internal.md new file mode 100644 index 0000000..257a022 --- /dev/null +++ b/ops/c_api/quant_batch_matmul_internal/quant_batch_matmul_internal.md @@ -0,0 +1,166 @@ +# quant_batch_matmul_internal 算子 + +## 描述 + +quant_batch_matmul_internal 算子用于执行量化批处理矩阵乘法操作。该算子支持 int8 量化输入,输出 bfloat16 或 float16 精度的结果,提供高效的量化矩阵计算能力。算子支持输入矩阵的转置操作,并可指定输入和输出的数据格式。 + +## 输入参数 + +| Name | DType | Shape | Optional | Inplace | Format | Description | +|----------------|--------------------|-----------|----------|---------|---------------|------------------------------------------------------| +| x1 | Tensor(int8) | 2维及以上 | No | No | ND/FRACTAL_NZ | 第一个量化输入矩阵(int8 类型) | +| x2 | Tensor(int8) | 2维及以上 | No | No | ND/FRACTAL_NZ | 第二个量化输入矩阵(int8 类型) | +| scale | Tensor(float16/bf16) | 标量或1维 | No | No | - | 量化缩放因子 | +| offset | Tensor(int32) | 标量 | Yes | No | - | 量化偏移量,默认为 None | +| bias | Tensor(float16/bf16) | 1维 | Yes | No | - | 偏置项,默认为 None | +| pertoken_scale | Tensor(float16/bf16) | 1维 | Yes | No | - | 每 token 缩放因子,默认为 None | +| transpose_x1 | bool | - | Yes | - | - | 是否对 x1 进行转置,默认为 False | +| transpose_x2 | bool | - | Yes | - | - | 是否对 x2 进行转置,默认为 False | +| x1_format | int | - | Yes | - | - | x1 的数据格式:0 表示 ND,1 表示 FRACTAL_NZ,默认为 0 | +| x2_format | int | - | Yes | - | - | x2 的数据格式:0 表示 ND,1 表示 FRACTAL_NZ,默认为 0 | +| output_format | int | - | Yes | - | - | 输出的数据格式:0 表示 ND,1 表示 FRACTAL_NZ,默认为 0 | +| output_dtype | TypeId | - | Yes | - | - | 输出数据类型,支持 float16 或 bfloat16,默认为 float16 | + +## 输出参数 + +| Name | DType | Shape | Description | +|------|--------------------|------------------------|------------------| +| y | Tensor(float16/bf16) | 符合矩阵乘法规则的形状 | 量化矩阵乘法的计算结果 | + +## 支持产品 + +- Atlas 800I A2 推理产品、Atlas 推理系列产品 + +## 约束说明 + +1. **数据类型约束**: + - x1 和 x2 必须为 int8 类型 + - scale、bias、pertoken_scale 支持 float16 或 bfloat16 类型 + - offset 为 int32 类型 + - 输出 y 支持 float16 或 bfloat16 类型(通过 output_dtype 指定) +2. Atlas 800I A2 推理产品仅支持 ND 格式 +3. 
Atlas 推理系列产品支持 ND 和 FRACTAL_NZ 数据格式,推荐组合如下: + + | x1 | x2 | y | 推荐场景 | + |------------|------------|------------|-------------------------------------------------| + | FRACTAL_NZ | FRACTAL_NZ | ND | prefill 阶段,x1 第 0 维度较大的情况 | + | ND | FRACTAL_NZ | ND | decode 阶段,x1 第 0 维度较小的情况 | + | FRACTAL_NZ | FRACTAL_NZ | FRACTAL_NZ | 通用情况 | + + 当输入为 FRACTAL_NZ 格式时,需要对该输入先调用 `trans_data` 算子进行格式转换;当输出为 FRACTAL_NZ 格式时,需要对输出调用 `trans_data` 算子进行格式转换。 + + **注意**:Pynative Mode 下仅支持 x1 和 y 为 ND 格式的情况。 + +## 使用示例 + +### 基本使用示例 + +```python +import mindspore as ms +import numpy as np +import ms_custom_ops + +ms.set_context(device_target="Ascend") + +@ms.jit +def quant_matmul_func(x1, x2, scale, transpose_x1=False, transpose_x2=False, + x1_format=0, x2_format=0, output_format=0, output_dtype=ms.float16): + return ms_custom_ops.quant_batch_matmul_internal( + x1, x2, scale, transpose_x1=transpose_x1, transpose_x2=transpose_x2, + x1_format=x1_format, x2_format=x2_format, output_format=output_format, + output_dtype=output_dtype) + +# 示例1:基础量化矩阵乘法 +batch = 2 +m = 128 +k = 256 +n = 128 +x1 = np.random.randint(-128, 127, (batch, m, k)).astype(np.int8) +x2 = np.random.randint(-128, 127, (batch, k, n)).astype(np.int8) +scale = np.random.randn(1).astype(np.float16) + +ms_x1 = ms.Tensor(x1) +ms_x2 = ms.Tensor(x2) +ms_scale = ms.Tensor(scale) +output = quant_matmul_func(ms_x1, ms_x2, ms_scale) +print("Output shape:", output.shape) +print("Output dtype:", output.dtype) +``` + +### 使用 bfloat16 输出和偏置的示例 + +```python +import mindspore as ms +import numpy as np +import ms_custom_ops + +ms.set_context(device_target="Ascend") + +@ms.jit +def quant_matmul_with_bias_func(x1, x2, scale, bias): + return ms_custom_ops.quant_batch_matmul_internal( + x1, x2, scale, bias=bias, output_dtype=ms.bfloat16) + +# x1 形状: (batch, m, k), int8 类型 +# x2 形状: (batch, k, n), int8 类型 +# scale 形状: (1,), float16 类型 +# bias 形状: (n,), float16 类型 +batch = 2 +m = 128 +k = 256 +n = 128 +x1 = np.random.randint(-128, 127, (batch, m, k)).astype(np.int8) +x2 = np.random.randint(-128, 127, (batch, k, n)).astype(np.int8) +scale = np.random.randn(1).astype(np.float16) +bias = np.random.randn(n).astype(np.float16) + +ms_x1 = ms.Tensor(x1) +ms_x2 = ms.Tensor(x2) +ms_scale = ms.Tensor(scale) +ms_bias = ms.Tensor(bias) +output = quant_matmul_with_bias_func(ms_x1, ms_x2, ms_scale, ms_bias) +print("Output shape:", output.shape) +print("Output dtype:", output.dtype) +``` + +### 使用 FRACTAL_NZ 格式的示例 + +```python +import mindspore as ms +import numpy as np +import ms_custom_ops + +ms.set_context(device_target="Ascend") + +@ms.jit +def quant_matmul_with_nz_func(x1, x2, scale): + # 将 ND 格式转换为 FRACTAL_NZ 格式 + x1_nz = ms_custom_ops.trans_data(x1, transdata_type=1) # ND_TO_FRACTAL_NZ + x2_nz = ms_custom_ops.trans_data(x2, transdata_type=1) # ND_TO_FRACTAL_NZ + + # 执行量化矩阵乘法,指定输入和输出格式为 FRACTAL_NZ + out_nz = ms_custom_ops.quant_batch_matmul_internal( + x1_nz, x2_nz, scale, transpose_x1=False, transpose_x2=False, + x1_format=1, x2_format=1, output_format=1) + + # 将输出从 FRACTAL_NZ 格式转换回 ND 格式 + out = ms_custom_ops.trans_data(out_nz, transdata_type=0) # FRACTAL_NZ_TO_ND + return out + +# x1 形状: (batch, m, k), int8 类型 +# x2 形状: (batch, k, n), int8 类型 +# 输出形状: (batch, m, n) +batch = 2 +m = 128 +k = 256 +n = 128 +x1 = np.random.randint(-128, 127, (batch, m, k)).astype(np.int8) +x2 = np.random.randint(-128, 127, (batch, k, n)).astype(np.int8) +scale = np.random.randn(1).astype(np.float16) + +ms_x1 = ms.Tensor(x1) +ms_x2 = ms.Tensor(x2) +ms_scale = ms.Tensor(scale) +output = 
quant_matmul_with_nz_func(ms_x1, ms_x2, ms_scale) +print("Output shape:", output.shape) +``` diff --git a/ops/c_api/quant_batch_matmul_internal/quant_batch_matmul_internal_op.yaml b/ops/c_api/quant_batch_matmul_internal/quant_batch_matmul_internal_op.yaml new file mode 100644 index 0000000..8eb2bf8 --- /dev/null +++ b/ops/c_api/quant_batch_matmul_internal/quant_batch_matmul_internal_op.yaml @@ -0,0 +1,42 @@ +#operator quant_batch_matmul_internal +quant_batch_matmul_internal: + args: + x1: + dtype: tensor + x2: + dtype: tensor + scale: + dtype: tensor + offset: + dtype: tensor + default: None + bias: + dtype: tensor + default: None + pertoken_scale: + dtype: tensor + default: None + transpose_x1: + dtype: bool + default: false + transpose_x2: + dtype: bool + default: false + x1_format: + dtype: int + default: 0 # 0: ND, 1: FRACTAL_NZ + x2_format: + dtype: int + default: 0 # 0: ND, 1: FRACTAL_NZ + output_format: + dtype: int + default: 0 # 0: ND, 1: FRACTAL_NZ + output_dtype: + dtype: TypeId + default: mstype.float16 + arg_handler: dtype_to_type_id + args_signature: + dtype_group: (x1, x2) + returns: + y: + dtype: tensor \ No newline at end of file diff --git a/tests/st/test_quant_batch_matmul_internal.py b/tests/st/test_quant_batch_matmul_internal.py new file mode 100644 index 0000000..9ba57fc --- /dev/null +++ b/tests/st/test_quant_batch_matmul_internal.py @@ -0,0 +1,584 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import os +import sys +import logging +import numpy as np +import pytest +import mindspore as ms +from mindspore import Profiler +from mindspore import context +from mindspore.common.np_dtype import bfloat16 +from mindspore.common.api import jit +from mindspore._c_expression import MSContext +from functools import wraps +from st_utils import custom_compare +import ms_custom_ops + +np.set_printoptions(precision=2, suppress=True, linewidth=200) + +def jit_for_graph_mode(fn): + """ + A decorator that conditionally applies jit to a function at runtime based on the context mode. 
+ """ + jitted_fn = jit(fn) + @wraps(fn) + def wrapper(*args, **kwargs): + if context.get_context("mode") == context.GRAPH_MODE: + return jitted_fn(*args, **kwargs) + return fn(*args, **kwargs) + return wrapper + +ND = 0 +FRACTAL_NZ = 1 + + +def np_quant_batch_matmul_compute(x1, x2, scale, offset=None, bias=None, pertoken_scale=None, + transpose_x1=False, transpose_x2=False): + """Compute quant_batch_matmul using numpy for comparison.""" + # Handle transpose + if transpose_x1: + if len(x1.shape) == 2: + x1 = x1.T + else: + x1 = np.transpose(x1, tuple(range(len(x1.shape) - 2)) + (-1, -2)) + if transpose_x2: + if len(x2.shape) == 2: + x2 = x2.T + else: + x2 = np.transpose(x2, tuple(range(len(x2.shape) - 2)) + (-1, -2)) + + # Matrix multiplication + result = np.matmul(x1.astype(np.float32), x2.astype(np.float32)).astype(np.int32) + + # Add offset if provided + if offset is not None: + result = result + offset.astype(np.int32) + + # Add bias if provided + if bias is not None: + bias_fp32 = bias.astype(np.float32) + if len(result.shape) == 2: + result = result.astype(np.float32) + bias_fp32.reshape(-1) + else: + result = result.astype(np.float32) + bias_fp32 + + # Apply scale + scale_fp32 = scale.astype(np.float32) + if len(scale_fp32.shape) == 0 or scale_fp32.shape[0] == 1: + result = result.astype(np.float32) * scale_fp32.item() + else: + if len(result.shape) == 2: + result = result.astype(np.float32) * scale_fp32.reshape(-1) + else: + result = result.astype(np.float32) * scale_fp32.reshape(-1) + + # Apply pertoken_scale if provided + if pertoken_scale is not None: + pertoken_scale_fp32 = pertoken_scale.astype(np.float32) + if len(result.shape) == 2: + result = result * pertoken_scale_fp32.reshape(-1, 1) + else: + result = result * pertoken_scale_fp32.reshape(-1, 1, 1) + + return result.astype(np.float32) + + +class QuantBatchMatmulInternalCustom(ms.nn.Cell): + def __init__(self, weight, scale, bias, pertoken_scale, ta, tb, + x1_format=ND, x2_format=ND, output_format=ND, output_dtype=ms.float16): + super().__init__() + self.weight = ms.Parameter(weight, requires_grad=False) + self.weight.set_data(weight) + self.scale = ms.Parameter(scale, requires_grad=False) + if bias is not None: + self.bias = ms.Parameter(bias, requires_grad=False) + else: + self.bias = None + if pertoken_scale is not None: + self.pertoken_scale = ms.Parameter(pertoken_scale, requires_grad=False) + else: + self.pertoken_scale = None + self.trans_x1 = ta + self.trans_x2 = tb + self.x1_format = x1_format + self.x2_format = x2_format + self.output_format = output_format + self.output_dtype = output_dtype + + @jit_for_graph_mode + def construct(self, i0): + if self.x1_format == FRACTAL_NZ: + i0 = ms_custom_ops.trans_data(i0, transdata_type=1) # ND_TO_FRACTAL_NZ + + output = ms_custom_ops.quant_batch_matmul_internal( + i0, self.weight, self.scale, + offset=None, # offset is optional + bias=self.bias, + pertoken_scale=self.pertoken_scale, + transpose_x1=self.trans_x1, + transpose_x2=self.trans_x2, + x1_format=self.x1_format, + x2_format=self.x2_format, + output_format=self.output_format, + output_dtype=self.output_dtype) + + if self.output_format == FRACTAL_NZ: + output = ms_custom_ops.trans_data(output, transdata_type=0, out_crops=None) # FRACTAL_NZ_TO_ND + return output + + +def quant_batch_matmul_internal(m, k, n, trans_x1=False, trans_x2=False, + with_bias=False, with_pertoken_scale=False, + mstype=ms.float16, profiling=False, + x1_format=ND, x2_format=ND, output_format=ND, + output_dtype=None): + 
os.environ['USE_LLM_CUSTOM_MATMUL'] = "off" + os.environ['INTERNAL_PRINT_TILING'] = "on" + + if output_dtype is None: + output_dtype = mstype + + if ms.float16 == mstype: + np_type = np.float16 + elif ms.float32 == mstype: + np_type = np.float32 + elif ms.bfloat16 == mstype: + np_type = bfloat16 + else: + np_type = np.float16 + + # Generate random int8 inputs + np.random.seed(0) + if trans_x1: + i0_host = np.random.randint(-128, 127, size=[k, m]).astype(np.int8) + else: + i0_host = np.random.randint(-128, 127, size=[m, k]).astype(np.int8) + + if trans_x2: + i1_host = np.random.randint(-128, 127, size=[n, k]).astype(np.int8) + else: + i1_host = np.random.randint(-128, 127, size=[k, n]).astype(np.int8) + + # Generate scale (float16/bf16) + np.random.seed(0) + if output_dtype == ms.bfloat16: + scale_host = np.random.randn(n).astype(np.float32) + else: + scale_host = np.random.randn(n).astype(np.float16) + + # Generate bias if needed + bias_host = None + if with_bias: + np.random.seed(0) + if output_dtype == ms.bfloat16: + bias_host = np.random.randn(n).astype(np.float32) + else: + bias_host = np.random.randn(n).astype(np.float16) + + # Generate pertoken_scale if needed + pertoken_scale_host = None + if with_pertoken_scale: + np.random.seed(0) + if output_dtype == ms.bfloat16: + pertoken_scale_host = np.random.randn(m).astype(np.float32) + else: + pertoken_scale_host = np.random.randn(m).astype(np.float16) + + # Compute expected result + i0_host_fp32 = i0_host.astype(np.float32) + i1_host_fp32 = i1_host.astype(np.float32) + expect = np_quant_batch_matmul_compute(i0_host_fp32, i1_host_fp32, scale_host.astype(np.float32), + bias=bias_host.astype(np.float32) if bias_host is not None else None, + pertoken_scale=pertoken_scale_host.astype(np.float32) if pertoken_scale_host is not None else None, + transpose_x1=trans_x1, transpose_x2=trans_x2) + print("numpy compute done") + + input1 = ms.Tensor(i0_host, ms.int8) + input2 = ms.Tensor(i1_host, ms.int8) + scale_tensor = ms.Tensor(scale_host, output_dtype) + + bias_tensor = None + if with_bias: + bias_tensor = ms.Tensor(bias_host, output_dtype) + + pertoken_scale_tensor = None + if with_pertoken_scale: + pertoken_scale_tensor = ms.Tensor(pertoken_scale_host, output_dtype) + + # Handle format conversion for x2 + if x2_format == FRACTAL_NZ or MSContext.get_instance().get_ascend_soc_version() == "ascend310p": + input2 = ms_custom_ops.trans_data(input2, transdata_type=1) # ND_TO_FRACTAL_NZ + x2_format = FRACTAL_NZ + else: + x2_format = ND + + net = QuantBatchMatmulInternalCustom(input2, scale_tensor, bias_tensor, pertoken_scale_tensor, + trans_x1, trans_x2, x1_format, x2_format, output_format, output_dtype) + + if profiling: + for i in range(50): + output = net(input1) + return + + output = net(input1) + output_fp32 = output.astype(ms.float32) + output_np = output_fp32.asnumpy() + + # Use custom_compare for better accuracy handling + res = custom_compare(expect, output_np, output_dtype) + assert res, "quant_batch_matmul_internal compare fail." 
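# A rough sketch of the semantics the tests below check, following the numpy
# reference np_quant_batch_matmul_compute above (offset, bias and pertoken_scale
# are optional):
#   y = cast((x1_int8 @ x2_int8 + offset + bias) * scale * pertoken_scale, output_dtype)
# where scale and bias broadcast over the n dimension and pertoken_scale
# broadcasts over the m (token) dimension.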
+ + +@pytest.mark.level0 +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('trans_x2', [False, True]) +@pytest.mark.parametrize('mstype', [ms.float16]) +@pytest.mark.parametrize('x1_format', [ND, FRACTAL_NZ]) +@pytest.mark.parametrize('output_format', [ND, FRACTAL_NZ]) +@pytest.mark.env_onecard +def test_quant_batch_matmul_internal_1024_1024_1024_nz_input_fp16(exec_mode, trans_x2, mstype, x1_format, output_format, request): + """ + Feature: test quant_batch_matmul_internal operator in graph and pynative mode + Description: test quant_batch_matmul_internal. + Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) + quant_batch_matmul_internal(1024, 1024, 1024, trans_x1=False, trans_x2=trans_x2, + mstype=mstype, x1_format=x1_format, output_format=output_format) + + +@pytest.mark.level0 +@pytest.mark.platform_ascend910b +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('trans_x2', [False, True]) +@pytest.mark.parametrize('mstype', [ms.bfloat16]) +@pytest.mark.env_onecard +def test_quant_batch_matmul_internal_1024_1024_1024_input_bfp16(exec_mode, trans_x2, mstype, request): + """ + Feature: test quant_batch_matmul_internal operator in graph mode + Description: test quant_batch_matmul_internal. + Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) + if "platform_ascend310p" in request.config.getoption("-m") and mstype is ms.bfloat16: + pytest.skip("Skipping ms.bfloat16 for 310p mark") + quant_batch_matmul_internal(1024, 1024, 1024, trans_x1=False, trans_x2=trans_x2, mstype=mstype, + output_dtype=ms.bfloat16) + + +@pytest.mark.level0 +@pytest.mark.platform_ascend910b +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('trans_x2', [False, True]) +@pytest.mark.parametrize('mstype', [ms.float16]) +@pytest.mark.env_onecard +def test_quant_batch_matmul_internal_2048_2048_2048_nd_input_fp16(exec_mode, trans_x2, mstype, request): + """ + Feature: test quant_batch_matmul_internal operator in graph mode + Description: test quant_batch_matmul_internal. + Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) + quant_batch_matmul_internal(2048, 2048, 2048, trans_x1=False, trans_x2=trans_x2, mstype=mstype, + output_dtype=ms.float16) + + +@pytest.mark.level2 +@pytest.mark.platform_ascend910b +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('trans_x2', [False, True]) +@pytest.mark.parametrize('mstype', [ms.bfloat16]) +@pytest.mark.env_onecard +def test_quant_batch_matmul_internal_2048_2048_2048_nd_input_bfp16(exec_mode, trans_x2, mstype, request): + """ + Feature: test quant_batch_matmul_internal operator in graph mode + Description: test quant_batch_matmul_internal. 
+ Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) + if "platform_ascend310p" in request.config.getoption("-m") and mstype is ms.bfloat16: + pytest.skip("Skipping ms.bfloat16 for 310p mark") + quant_batch_matmul_internal(2048, 2048, 2048, trans_x1=False, trans_x2=trans_x2, mstype=mstype, + output_dtype=ms.bfloat16) + + +@pytest.mark.level0 +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('trans_x2', [True]) +@pytest.mark.parametrize('mstype', [ms.float16]) +@pytest.mark.parametrize('x1_format', [ND, FRACTAL_NZ]) +@pytest.mark.parametrize('output_format', [ND, FRACTAL_NZ]) +@pytest.mark.env_onecard +def test_quant_batch_matmul_internal_1024_1234_1234_input_unaligned_k_n_fp16(exec_mode, trans_x2, mstype, x1_format, output_format, request): + """ + Feature: test quant_batch_matmul_internal operator in graph and pynative mode + Description: Test that unaligned large/edge n dimension raise exception + Expectation: CheckDimensionAlignment validation rejects unaligned inputs + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) + with pytest.raises(RuntimeError, match="dimension must be aligned"): + quant_batch_matmul_internal(1024, 1234, 1234, trans_x1=False, trans_x2=trans_x2, + mstype=mstype, x1_format=x1_format, output_format=output_format, + output_dtype=ms.float16) + logging.info( + "Unaligned dimension correctly rejected: shape=%s", + (1234, 1234) + ) + + +@pytest.mark.level2 +@pytest.mark.platform_ascend910b +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('trans_x2', [False, True]) +@pytest.mark.parametrize('mstype', [ms.bfloat16]) +@pytest.mark.env_onecard +def test_quant_batch_matmul_internal_1024_1234_1234_nz_input_unaligned_k_n_bfp16(exec_mode, trans_x2, mstype, request): + """ + Feature: test quant_batch_matmul_internal operator in graph mode + Description: test quant_batch_matmul_internal. + Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) + if "platform_ascend310p" in request.config.getoption("-m") and mstype is ms.bfloat16: + pytest.skip("Skipping ms.bfloat16 for 310p mark") + quant_batch_matmul_internal(1024, 1234, 1234, trans_x1=False, trans_x2=trans_x2, mstype=mstype, + output_dtype=ms.bfloat16) + + +@pytest.mark.level2 +@pytest.mark.platform_ascend910b +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('trans_x2', [False, True]) +@pytest.mark.parametrize('mstype', [ms.float16]) +@pytest.mark.env_onecard +def test_quant_batch_matmul_internal_1024_2048_2234_nd_input_unaligned_n_fp16(exec_mode, trans_x2, mstype, request): + """ + Feature: test quant_batch_matmul_internal operator in graph mode + Description: test quant_batch_matmul_internal. 
+ Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) + quant_batch_matmul_internal(1024, 2048, 2234, trans_x1=False, trans_x2=trans_x2, mstype=mstype, + output_dtype=ms.float16) + + +@pytest.mark.level2 +@pytest.mark.platform_ascend910b +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('trans_x2', [False, True]) +@pytest.mark.parametrize('mstype', [ms.bfloat16]) +@pytest.mark.env_onecard +def test_quant_batch_matmul_internal_1024_2048_2234_nd_input_unaligned_n_bfp16(exec_mode, trans_x2, mstype, request): + """ + Feature: test quant_batch_matmul_internal operator in graph mode + Description: test quant_batch_matmul_internal. + Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) + if "platform_ascend310p" in request.config.getoption("-m") and mstype is ms.bfloat16: + pytest.skip("Skipping ms.bfloat16 for 310p mark") + quant_batch_matmul_internal(1024, 2048, 2234, trans_x1=False, trans_x2=trans_x2, mstype=mstype, + output_dtype=ms.bfloat16) + + +@pytest.mark.level0 +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('m', [1, 10, 20]) +@pytest.mark.parametrize('x1_format', [ND]) +@pytest.mark.parametrize('output_format', [ND]) +@pytest.mark.env_onecard +def test_quant_batch_matmul_internal_m_4096_4096_False_True_float16_nd_input(exec_mode, m, x1_format, output_format): + """ + Feature: test quant_batch_matmul_internal operator in graph and pynative mode + Description: test quant_batch_matmul_internal. 
+ Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) + quant_batch_matmul_internal(m, 4096, 4096, trans_x1=False, trans_x2=True, + mstype=ms.float16, x1_format=x1_format, output_format=output_format, + output_dtype=ms.float16) + + +@pytest.mark.level0 +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE]) +@pytest.mark.parametrize('m', [1, 10, 20]) +@pytest.mark.parametrize('x1_format', [FRACTAL_NZ]) +@pytest.mark.parametrize('output_format', [ND, FRACTAL_NZ]) +@pytest.mark.env_onecard +def test_quant_batch_matmul_internal_m_4096_4096_False_True_float16_nz_input(exec_mode, m, x1_format, output_format): + """ + Feature: quant_batch_matmul_internal operator input_x1 dimension alignment validation + Description: Test that unaligned large/edge m dimension raise exception + Expectation: CheckDimensionAlignment validation rejects unaligned inputs + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) + quant_batch_matmul_internal(m, 4096, 4096, trans_x1=False, trans_x2=True, + mstype=ms.float16, x1_format=x1_format, output_format=output_format, + output_dtype=ms.float16) + + +@pytest.mark.level0 +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize("exec_mode", [context.PYNATIVE_MODE]) +@pytest.mark.parametrize('m', [1, 10, 20]) +@pytest.mark.parametrize('x1_format', [FRACTAL_NZ]) +@pytest.mark.parametrize('output_format', [ND, FRACTAL_NZ]) +@pytest.mark.env_onecard +def test_quant_batch_matmul_internal_m_4096_4096_False_True_float16_nz_input_pynative(exec_mode, m, x1_format, output_format): + """ + Feature: quant_batch_matmul_internal operator input_x1 dimension alignment validation + Description: Test that unaligned large/edge m dimension raise exception + Expectation: CheckDimensionAlignment validation rejects unaligned inputs + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) + with pytest.raises(RuntimeError, match="dimension must be aligned"): + quant_batch_matmul_internal(m, 4096, 4096, trans_x1=False, trans_x2=True, + mstype=ms.float16, x1_format=x1_format, output_format=output_format, + output_dtype=ms.float16) + logging.info( + "Unaligned dimension correctly rejected: shape=%s", + (m, 4096) + ) + + +@pytest.mark.level2 +@pytest.mark.platform_ascend910b +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('m', [1, 256, 1024]) +@pytest.mark.env_onecard +def test_quant_batch_matmul_internal_m_4096_4096_False_True_float16_nz_input_unaligned_k_n(exec_mode, m): + """ + Feature: test quant_batch_matmul_internal operator in graph mode + Description: test quant_batch_matmul_internal. 
+ Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) + quant_batch_matmul_internal(m, 1234, 1234, trans_x1=False, trans_x2=True, mstype=ms.float16, + output_dtype=ms.float16) + + +@pytest.mark.level2 +@pytest.mark.platform_ascend910b +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('m', [1, 32, 1024]) +@pytest.mark.env_onecard +def test_quant_batch_matmul_internal_m_2048_2234_False_True_float16_nd_input_unaligned_n(exec_mode, m): + """ + Feature: test quant_batch_matmul_internal operator in graph mode + Description: test quant_batch_matmul_internal. + Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) + quant_batch_matmul_internal(m, 2048, 2234, trans_x1=False, trans_x2=True, mstype=ms.float16, + output_dtype=ms.float16) + + +@pytest.mark.level2 +@pytest.mark.platform_ascend910b +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('with_bias', [False, True]) +@pytest.mark.parametrize('with_pertoken_scale', [False, True]) +@pytest.mark.env_onecard +def test_quant_batch_matmul_internal_with_bias_pertoken(exec_mode, with_bias, with_pertoken_scale): + """ + Feature: test quant_batch_matmul_internal operator with bias and pertoken_scale + Description: test quant_batch_matmul_internal. + Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) + quant_batch_matmul_internal(1024, 2048, 2048, trans_x1=False, trans_x2=True, + with_bias=with_bias, with_pertoken_scale=with_pertoken_scale, + mstype=ms.float16, output_dtype=ms.float16) + + +@pytest.mark.level2 +@pytest.mark.platform_ascend910b +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.env_onecard +def test_quant_batch_matmul_internal_in_real_shape_increment(exec_mode): + """ + Feature: test quant_batch_matmul_internal operator in graph mode + Description: test quant_batch_matmul_internal. 
+ Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) + prof_flag = False + profiler = Profiler(start_profile=False, output_path="profiler") + profiler.start() + quant_batch_matmul_internal(16, 2752, 4096, trans_x1=False, trans_x2=True, + mstype=ms.float16, profiling=prof_flag, output_dtype=ms.float16) + quant_batch_matmul_internal(16, 32, 4096, trans_x1=False, trans_x2=True, + mstype=ms.float16, profiling=prof_flag, output_dtype=ms.float16) + quant_batch_matmul_internal(16, 4096, 32, trans_x1=False, trans_x2=True, + mstype=ms.float16, profiling=prof_flag, output_dtype=ms.float16) + quant_batch_matmul_internal(16, 4096, 8256, trans_x1=False, trans_x2=True, + mstype=ms.float16, profiling=prof_flag, output_dtype=ms.float16) + profiler.stop() + profiler.analyse() + + +@pytest.mark.level2 +@pytest.mark.platform_ascend910b +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('m', [514, 1024]) +@pytest.mark.parametrize("prof_flag_str", [0, 1]) +def test_quant_batch_matmul_internal_in_real_shape_prefill(exec_mode, m, prof_flag_str): + """ + Feature: test quant_batch_matmul_internal operator in graph mode + Description: test quant_batch_matmul_internal. + Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) + prof_flag = bool(int(prof_flag_str)) + profiler = Profiler(start_profile=False, output_path="profiler") + profiler.start() + quant_batch_matmul_internal(m, 2752, 4096, trans_x1=False, trans_x2=True, + mstype=ms.float16, profiling=prof_flag, output_dtype=ms.float16) + quant_batch_matmul_internal(m, 32, 4096, trans_x1=False, trans_x2=True, + mstype=ms.float16, profiling=prof_flag, output_dtype=ms.float16) + quant_batch_matmul_internal(m, 4096, 32, trans_x1=False, trans_x2=True, + mstype=ms.float16, profiling=prof_flag, output_dtype=ms.float16) + quant_batch_matmul_internal(m, 4096, 8256, trans_x1=False, trans_x2=True, + mstype=ms.float16, profiling=prof_flag, output_dtype=ms.float16) + quant_batch_matmul_internal(m, 5504, 4096, trans_x1=False, trans_x2=True, + mstype=ms.float16, profiling=prof_flag, output_dtype=ms.float16) + profiler.stop() + profiler.analyse() -- Gitee From 6b90a9060d3ede3b323d5f37865d1daef9695e1d Mon Sep 17 00:00:00 2001 From: huoxinyou Date: Wed, 12 Nov 2025 16:38:42 +0800 Subject: [PATCH 2/3] fix bug --- .../quant_batch_matmul_internal.cc | 12 +- tests/st/test_quant_batch_matmul_internal.py | 734 +++++++++--------- 2 files changed, 359 insertions(+), 387 deletions(-) diff --git a/ops/c_api/quant_batch_matmul_internal/quant_batch_matmul_internal.cc b/ops/c_api/quant_batch_matmul_internal/quant_batch_matmul_internal.cc index 512c717..3b7464f 100644 --- a/ops/c_api/quant_batch_matmul_internal/quant_batch_matmul_internal.cc +++ b/ops/c_api/quant_batch_matmul_internal/quant_batch_matmul_internal.cc @@ -47,8 +47,8 @@ enum class QuantBatchMatmulInternalOutputIndex : size_t { kOutputIndex = 0, }; -ShapeVector BatchMatMulMakeShape(const ShapeVector x1_shape, const ShapeVector x2_shape, - bool transpose_x1, bool transpose_x2, size_t offset) { +ShapeVector BatchMatMulInternalMakeShape(const ShapeVector x1_shape, const ShapeVector x2_shape, + bool transpose_x1, bool transpose_x2, size_t offset) { if (x1_shape.size() < kQbmmMatSize || x2_shape.size() < kQbmmMatSize) { 
MS_LOG(EXCEPTION) << "For 'QuantBatchMatmulInternal', the dimension of 'x1' and 'x2' " << "should be at least 2, but got " << x1_shape << " and " << x2_shape; @@ -131,7 +131,7 @@ class OPS_API QuantBatchMatmulInternalOpFuncImpl : public OpFuncImpl { bool transpose_x2 = input_infos[static_cast( QuantBatchMatmulInternalInputIndex::kInputTransposeX2Index)]->GetScalarValueWithCheck(); ShapeVector out_shape = - BatchMatMulMakeShape(x1_shape, x2_shape, transpose_x1, transpose_x2, kQbmmMatSize); + BatchMatMulInternalMakeShape(x1_shape, x2_shape, transpose_x1, transpose_x2, kQbmmMatSize); return {out_shape}; } @@ -258,6 +258,8 @@ class QuantBatchMatmulInternalRunner : public InternalPyboostRunner { private: bool transpose_x1_{false}; bool transpose_x2_{false}; + bool with_pertoken_scale_{false}; + bool with_bias_{false}; DataFormat x1_format_{DataFormat::ND}; DataFormat x2_format_{DataFormat::ND}; DataFormat output_format_{DataFormat::ND}; @@ -291,8 +293,8 @@ ms::Tensor npu_quant_batch_matmul_internal( // Infer output shape and type auto transpose_x1_val = transpose_x1.value_or(false); auto transpose_x2_val = transpose_x2.value_or(false); - auto output_shape = BatchMatMulMakeShape(x1.shape(), x2.shape(), transpose_x1_val, - transpose_x2_val, kQbmmMatSize); + auto output_shape = BatchMatMulInternalMakeShape(x1.shape(), x2.shape(), transpose_x1_val, + transpose_x2_val, kQbmmMatSize); if (output_format.has_value() && static_cast(output_format.value()) == DataFormat::FRACTAL_NZ) { CheckShapeHWAlignment(output_shape, x1.data_type()); diff --git a/tests/st/test_quant_batch_matmul_internal.py b/tests/st/test_quant_batch_matmul_internal.py index 9ba57fc..b81ac58 100644 --- a/tests/st/test_quant_batch_matmul_internal.py +++ b/tests/st/test_quant_batch_matmul_internal.py @@ -14,571 +14,541 @@ # ============================================================================ import os -import sys -import logging import numpy as np import pytest -import mindspore as ms from mindspore import Profiler + +import mindspore as ms from mindspore import context -from mindspore.common.np_dtype import bfloat16 -from mindspore.common.api import jit from mindspore._c_expression import MSContext -from functools import wraps -from st_utils import custom_compare import ms_custom_ops np.set_printoptions(precision=2, suppress=True, linewidth=200) -def jit_for_graph_mode(fn): - """ - A decorator that conditionally applies jit to a function at runtime based on the context mode. 
- """ - jitted_fn = jit(fn) - @wraps(fn) - def wrapper(*args, **kwargs): - if context.get_context("mode") == context.GRAPH_MODE: - return jitted_fn(*args, **kwargs) - return fn(*args, **kwargs) - return wrapper - +# Format constants ND = 0 FRACTAL_NZ = 1 -def np_quant_batch_matmul_compute(x1, x2, scale, offset=None, bias=None, pertoken_scale=None, - transpose_x1=False, transpose_x2=False): - """Compute quant_batch_matmul using numpy for comparison.""" - # Handle transpose - if transpose_x1: - if len(x1.shape) == 2: - x1 = x1.T - else: - x1 = np.transpose(x1, tuple(range(len(x1.shape) - 2)) + (-1, -2)) - if transpose_x2: - if len(x2.shape) == 2: - x2 = x2.T - else: - x2 = np.transpose(x2, tuple(range(len(x2.shape) - 2)) + (-1, -2)) - - # Matrix multiplication - result = np.matmul(x1.astype(np.float32), x2.astype(np.float32)).astype(np.int32) - - # Add offset if provided - if offset is not None: - result = result + offset.astype(np.int32) - - # Add bias if provided +def process_deq_scale(deq_scale) -> np.ndarray: + new_deq_scale = np.frombuffer(deq_scale.tobytes(), dtype=np.uint32) + return new_deq_scale.astype(np.int64) + + +def np_qbmm_compute(a, b, tmp_scale, bias=None, tmp_pertoken_scale=None): + c = np.dot(a.astype(np.float32), b.astype(np.float32)).astype(np.int32) if bias is not None: - bias_fp32 = bias.astype(np.float32) - if len(result.shape) == 2: - result = result.astype(np.float32) + bias_fp32.reshape(-1) - else: - result = result.astype(np.float32) + bias_fp32 - - # Apply scale - scale_fp32 = scale.astype(np.float32) - if len(scale_fp32.shape) == 0 or scale_fp32.shape[0] == 1: - result = result.astype(np.float32) * scale_fp32.item() - else: - if len(result.shape) == 2: - result = result.astype(np.float32) * scale_fp32.reshape(-1) - else: - result = result.astype(np.float32) * scale_fp32.reshape(-1) - - # Apply pertoken_scale if provided - if pertoken_scale is not None: - pertoken_scale_fp32 = pertoken_scale.astype(np.float32) - if len(result.shape) == 2: - result = result * pertoken_scale_fp32.reshape(-1, 1) - else: - result = result * pertoken_scale_fp32.reshape(-1, 1, 1) - - return result.astype(np.float32) - - -class QuantBatchMatmulInternalCustom(ms.nn.Cell): - def __init__(self, weight, scale, bias, pertoken_scale, ta, tb, - x1_format=ND, x2_format=ND, output_format=ND, output_dtype=ms.float16): + c = c + bias + c = c.astype(np.float32) * tmp_scale + if tmp_pertoken_scale is not None: + pertoken_scale = tmp_pertoken_scale[:, np.newaxis] + c = c * pertoken_scale + c = c.astype(np.float16) + return c + + +class Qbmm(ms.nn.Cell): + def __init__(self, weight, scale, bias, pertoken_scale, trans_a, trans_b, dst_dtype, + x1_format=0, x2_format=0, output_format=0): super().__init__() self.weight = ms.Parameter(weight, requires_grad=False) - self.weight.set_data(weight) self.scale = ms.Parameter(scale, requires_grad=False) - if bias is not None: - self.bias = ms.Parameter(bias, requires_grad=False) - else: - self.bias = None - if pertoken_scale is not None: - self.pertoken_scale = ms.Parameter(pertoken_scale, requires_grad=False) - else: - self.pertoken_scale = None - self.trans_x1 = ta - self.trans_x2 = tb + self.bias = bias + self.pertoken_scale = pertoken_scale + self.trans_a = trans_a + self.trans_b = trans_b + self.dst_dtype = dst_dtype self.x1_format = x1_format self.x2_format = x2_format self.output_format = output_format - self.output_dtype = output_dtype - @jit_for_graph_mode - def construct(self, i0): + def construct(self, x): if self.x1_format == FRACTAL_NZ: - i0 = 
ms_custom_ops.trans_data(i0, transdata_type=1) # ND_TO_FRACTAL_NZ - + x = ms_custom_ops.trans_data(x, transdata_type=1) # ND_TO_FRACTAL_NZ + output = ms_custom_ops.quant_batch_matmul_internal( - i0, self.weight, self.scale, - offset=None, # offset is optional - bias=self.bias, - pertoken_scale=self.pertoken_scale, - transpose_x1=self.trans_x1, - transpose_x2=self.trans_x2, - x1_format=self.x1_format, - x2_format=self.x2_format, - output_format=self.output_format, - output_dtype=self.output_dtype) - + x, self.weight, self.scale, None, self.bias, self.pertoken_scale, + transpose_x1=self.trans_a, transpose_x2=self.trans_b, + x1_format=self.x1_format, x2_format=self.x2_format, + output_format=self.output_format, output_dtype=self.dst_dtype + ) + if self.output_format == FRACTAL_NZ: - output = ms_custom_ops.trans_data(output, transdata_type=0, out_crops=None) # FRACTAL_NZ_TO_ND + output = ms_custom_ops.trans_data(output, transdata_type=0) # FRACTAL_NZ_TO_ND return output -def quant_batch_matmul_internal(m, k, n, trans_x1=False, trans_x2=False, - with_bias=False, with_pertoken_scale=False, - mstype=ms.float16, profiling=False, - x1_format=ND, x2_format=ND, output_format=ND, - output_dtype=None): +def qbmm(m, k, n, batch_m=0, trans_a=False, trans_b=False, dst_dtype=ms.float16, scale_dtype=ms.int64, + bias_none=False, pertoken_scale_none=True, profiling=False, is_dyn=False, x1_format=ND, x2_format=ND, output_format=ND): os.environ['USE_LLM_CUSTOM_MATMUL'] = "off" os.environ['INTERNAL_PRINT_TILING'] = "on" - if output_dtype is None: - output_dtype = mstype - - if ms.float16 == mstype: - np_type = np.float16 - elif ms.float32 == mstype: - np_type = np.float32 - elif ms.bfloat16 == mstype: - np_type = bfloat16 - else: - np_type = np.float16 - - # Generate random int8 inputs - np.random.seed(0) - if trans_x1: - i0_host = np.random.randint(-128, 127, size=[k, m]).astype(np.int8) - else: - i0_host = np.random.randint(-128, 127, size=[m, k]).astype(np.int8) - - if trans_x2: - i1_host = np.random.randint(-128, 127, size=[n, k]).astype(np.int8) - else: - i1_host = np.random.randint(-128, 127, size=[k, n]).astype(np.int8) - - # Generate scale (float16/bf16) - np.random.seed(0) - if output_dtype == ms.bfloat16: - scale_host = np.random.randn(n).astype(np.float32) - else: - scale_host = np.random.randn(n).astype(np.float16) - - # Generate bias if needed - bias_host = None - if with_bias: - np.random.seed(0) - if output_dtype == ms.bfloat16: - bias_host = np.random.randn(n).astype(np.float32) - else: - bias_host = np.random.randn(n).astype(np.float16) - - # Generate pertoken_scale if needed - pertoken_scale_host = None - if with_pertoken_scale: - np.random.seed(0) - if output_dtype == ms.bfloat16: - pertoken_scale_host = np.random.randn(m).astype(np.float32) - else: - pertoken_scale_host = np.random.randn(m).astype(np.float16) - - # Compute expected result - i0_host_fp32 = i0_host.astype(np.float32) - i1_host_fp32 = i1_host.astype(np.float32) - expect = np_quant_batch_matmul_compute(i0_host_fp32, i1_host_fp32, scale_host.astype(np.float32), - bias=bias_host.astype(np.float32) if bias_host is not None else None, - pertoken_scale=pertoken_scale_host.astype(np.float32) if pertoken_scale_host is not None else None, - transpose_x1=trans_x1, transpose_x2=trans_x2) - print("numpy compute done") - - input1 = ms.Tensor(i0_host, ms.int8) - input2 = ms.Tensor(i1_host, ms.int8) - scale_tensor = ms.Tensor(scale_host, output_dtype) - - bias_tensor = None - if with_bias: - bias_tensor = ms.Tensor(bias_host, output_dtype) - - 
pertoken_scale_tensor = None - if with_pertoken_scale: - pertoken_scale_tensor = ms.Tensor(pertoken_scale_host, output_dtype) - - # Handle format conversion for x2 + a_shape = (m, k) + # 在多卡并行的场景里,要控制随机种子以防多卡numpy生成不一致 + seed = 0 + np.random.seed(seed) + a = np.random.uniform(-20, 20, size=a_shape).astype(np.int8) + np.random.seed(seed) + b = np.random.uniform(-20, 20, size=(k, n)).astype(np.int8) + np.random.seed(seed) + bias = np.random.randint(-10, 10, (n)).astype(np.int32) + np.random.seed(seed) + tmp_scale = np.random.rand(n).astype(np.float32) / 1000 + scale = process_deq_scale(tmp_scale) + bias_fp16 = bias * tmp_scale + if bias_none: + bias = None + + np.random.seed(seed) + tmp_pertoken_scale = np.random.rand(m).astype(np.float32) + pertoken_scale_ms = ms.Tensor(tmp_pertoken_scale, ms.float32) + if pertoken_scale_none: + tmp_pertoken_scale = None + pertoken_scale_ms = None + + # When the output is bf16, currently only support scale of float32 type. + if scale_dtype == ms.float32 or dst_dtype == ms.bfloat16: + scale_ms = ms.Tensor(tmp_scale, ms.float32) + elif scale_dtype == ms.int64: + scale_ms = ms.Tensor(scale, ms.int64) + + expect_np = np_qbmm_compute(a, b, tmp_scale, bias, tmp_pertoken_scale) + + if trans_a: + a = np.transpose(a, (1, 0)) + if trans_b: + b = np.transpose(b, (1, 0)) + + a_ms = ms.Tensor(a, ms.int8) + b_ms = ms.Tensor(b, ms.int8) + if x2_format == FRACTAL_NZ or MSContext.get_instance().get_ascend_soc_version() == "ascend310p": - input2 = ms_custom_ops.trans_data(input2, transdata_type=1) # ND_TO_FRACTAL_NZ - x2_format = FRACTAL_NZ - else: - x2_format = ND + b_ms = ms_custom_ops.trans_data(b_ms, transdata_type=1) # ND_TO_FRACTAL_NZ + if x2_format == ND: + x2_format = FRACTAL_NZ + + net = None + bias_ms = ms.Tensor(bias, ms.int32) + if bias_none: + bias_ms = None + net = Qbmm(b_ms, scale_ms, bias_ms, pertoken_scale_ms, trans_a, trans_b, dst_dtype, + x1_format=x1_format, x2_format=x2_format, output_format=output_format) - net = QuantBatchMatmulInternalCustom(input2, scale_tensor, bias_tensor, pertoken_scale_tensor, - trans_x1, trans_x2, x1_format, x2_format, output_format, output_dtype) + if is_dyn: + input_dyn = ms.Tensor(shape=(None), dtype=ms.int8) + net.set_inputs(input_dyn) if profiling: - for i in range(50): - output = net(input1) + for _ in range(50): + output = net(a_ms) return - output = net(input1) - output_fp32 = output.astype(ms.float32) - output_np = output_fp32.asnumpy() - - # Use custom_compare for better accuracy handling - res = custom_compare(expect, output_np, output_dtype) - assert res, "quant_batch_matmul_internal compare fail." + output = net(a_ms) + output_np = output.asnumpy() + acc = 0.01 + res = np.allclose(output_np, expect_np, acc, acc) + assert res, "qbmm compare fail." 
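+# Note: process_deq_scale() above reinterprets the float32 per-channel scale
+# bits as uint32 and widens them to int64; this is a sketch of the assumed
+# packing for the int64 scale path, while the numpy reference itself still
+# applies the original float32 scale values.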
@pytest.mark.level0
+@pytest.mark.platform_ascend910b
@pytest.mark.platform_ascend310p
@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE])
-@pytest.mark.parametrize('trans_x2', [False, True])
-@pytest.mark.parametrize('mstype', [ms.float16])
-@pytest.mark.parametrize('x1_format', [ND, FRACTAL_NZ])
-@pytest.mark.parametrize('output_format', [ND, FRACTAL_NZ])
+@pytest.mark.parametrize('m', [1, 256, 1024])
@pytest.mark.env_onecard
-def test_quant_batch_matmul_internal_1024_1024_1024_nz_input_fp16(exec_mode, trans_x2, mstype, x1_format, output_format, request):
+def test_qbmm_add_m_4096_4096_false_true_nz_input(m, exec_mode):
    """
-    Feature: test quant_batch_matmul_internal operator in graph and pynative mode
-    Description: test quant_batch_matmul_internal.
+    Feature: test qbmm operator in graph mode
+    Description: test qbmm.
    Expectation: the result is correct
    """
    ms.set_context(device_target="Ascend", mode=exec_mode)
-    ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"})
-    quant_batch_matmul_internal(1024, 1024, 1024, trans_x1=False, trans_x2=trans_x2,
-                                mstype=mstype, x1_format=x1_format, output_format=output_format)
-
+    qbmm(m, 4096, 4096, trans_a=False, trans_b=True)


@pytest.mark.level0
+@pytest.mark.platform_ascend310p
@pytest.mark.platform_ascend910b
@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE])
-@pytest.mark.parametrize('trans_x2', [False, True])
-@pytest.mark.parametrize('mstype', [ms.bfloat16])
+@pytest.mark.parametrize('m', [1, 32, 256, 512, 1024, 4096])
@pytest.mark.env_onecard
-def test_quant_batch_matmul_internal_1024_1024_1024_input_bfp16(exec_mode, trans_x2, mstype, request):
+def test_qbmm_add_m_4096_4096_false_true_nd_input(m, exec_mode):
    """
-    Feature: test quant_batch_matmul_internal operator in graph mode
-    Description: test quant_batch_matmul_internal.
+    Feature: test qbmm operator in graph mode
+    Description: test qbmm.
    Expectation: the result is correct
    """
    ms.set_context(device_target="Ascend", mode=exec_mode)
-    ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"})
-    if "platform_ascend310p" in request.config.getoption("-m") and mstype is ms.bfloat16:
-        pytest.skip("Skipping ms.bfloat16 for 310p mark")
-    quant_batch_matmul_internal(1024, 1024, 1024, trans_x1=False, trans_x2=trans_x2, mstype=mstype,
-                                output_dtype=ms.bfloat16)
+    qbmm(m, 2048, 2048, trans_a=False, trans_b=True)


-@pytest.mark.level0
+@pytest.mark.level2
@pytest.mark.platform_ascend910b
@pytest.mark.platform_ascend310p
@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE])
-@pytest.mark.parametrize('trans_x2', [False, True])
-@pytest.mark.parametrize('mstype', [ms.float16])
+@pytest.mark.parametrize('input_shape', [(128, 2560, 5120),
+                                         (16, 11264, 6912), (16, 6912, 11264)])
@pytest.mark.env_onecard
-def test_quant_batch_matmul_internal_2048_2048_2048_nd_input_fp16(exec_mode, trans_x2, mstype, request):
+def test_qbmm_add_false_true_nd_input(input_shape, exec_mode):
    """
-    Feature: test quant_batch_matmul_internal operator in graph mode
-    Description: test quant_batch_matmul_internal.
+    Feature: test qbmm operator in graph mode
+    Description: test qbmm.
    Expectation: the result is correct
    """
+    m, k, n = input_shape
    ms.set_context(device_target="Ascend", mode=exec_mode)
-    ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"})
-    quant_batch_matmul_internal(2048, 2048, 2048, trans_x1=False, trans_x2=trans_x2, mstype=mstype,
-                                output_dtype=ms.float16)
-
+    qbmm(m, k, n, trans_a=False, trans_b=True)


@pytest.mark.level2
+@pytest.mark.platform_ascend310p
@pytest.mark.platform_ascend910b
@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE])
-@pytest.mark.parametrize('trans_x2', [False, True])
-@pytest.mark.parametrize('mstype', [ms.bfloat16])
+@pytest.mark.parametrize('input_shape', [(128, 5120, 10240),
+                                         (1024, 5632, 3456), (1024, 3456, 11264)])
@pytest.mark.env_onecard
-def test_quant_batch_matmul_internal_2048_2048_2048_nd_input_bfp16(exec_mode, trans_x2, mstype, request):
+def test_qbmm_add_false_true_nz_input(input_shape, exec_mode):
    """
-    Feature: test quant_batch_matmul_internal operator in graph mode
-    Description: test quant_batch_matmul_internal.
+    Feature: test qbmm operator in graph mode
+    Description: test qbmm.
    Expectation: the result is correct
    """
+    m, k, n = input_shape
    ms.set_context(device_target="Ascend", mode=exec_mode)
-    ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"})
-    if "platform_ascend310p" in request.config.getoption("-m") and mstype is ms.bfloat16:
-        pytest.skip("Skipping ms.bfloat16 for 310p mark")
-    quant_batch_matmul_internal(2048, 2048, 2048, trans_x1=False, trans_x2=trans_x2, mstype=mstype,
-                                output_dtype=ms.bfloat16)
-
+    qbmm(m, k, n, trans_a=False, trans_b=True)


-@pytest.mark.level0
+@pytest.mark.level2
@pytest.mark.platform_ascend310p
+@pytest.mark.platform_ascend910b
@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE])
-@pytest.mark.parametrize('trans_x2', [True])
-@pytest.mark.parametrize('mstype', [ms.float16])
-@pytest.mark.parametrize('x1_format', [ND, FRACTAL_NZ])
-@pytest.mark.parametrize('output_format', [ND, FRACTAL_NZ])
+@pytest.mark.parametrize('input_shape', [(128, 1234, 2234),
+                                         (1024, 2234, 1234), (1024, 2234, 5234)])
@pytest.mark.env_onecard
-def test_quant_batch_matmul_internal_1024_1234_1234_input_unaligned_k_n_fp16(exec_mode, trans_x2, mstype, x1_format, output_format, request):
+def test_qbmm_add_false_true_nz_input_unaligned_k_n(input_shape, exec_mode):
    """
-    Feature: test quant_batch_matmul_internal operator in graph and pynative mode
-    Description: Test that unaligned large/edge n dimension raise exception
-    Expectation: CheckDimensionAlignment validation rejects unaligned inputs
+    Feature: test qbmm operator in graph mode
+    Description: test qbmm.
+    Expectation: the result is correct
    """
    ms.set_context(device_target="Ascend", mode=exec_mode)
-    ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"})
-    with pytest.raises(RuntimeError, match="dimension must be aligned"):
-        quant_batch_matmul_internal(1024, 1234, 1234, trans_x1=False, trans_x2=trans_x2,
-                                    mstype=mstype, x1_format=x1_format, output_format=output_format,
-                                    output_dtype=ms.float16)
-    logging.info(
-        "Unaligned dimension correctly rejected: shape=%s",
-        (1234, 1234)
-    )
-
+    m, k, n = input_shape
+    qbmm(m, k, n, trans_a=False, trans_b=True)


@pytest.mark.level2
+@pytest.mark.platform_ascend310p
@pytest.mark.platform_ascend910b
@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE])
-@pytest.mark.parametrize('trans_x2', [False, True])
-@pytest.mark.parametrize('mstype', [ms.bfloat16])
+@pytest.mark.parametrize('input_shape', [(128, 1024, 2234),
+                                         (1024, 2048, 1234), (1024, 2048, 5234)])
@pytest.mark.env_onecard
-def test_quant_batch_matmul_internal_1024_1234_1234_nz_input_unaligned_k_n_bfp16(exec_mode, trans_x2, mstype, request):
+def test_qbmm_add_false_true_nd_input_unaligned_n(input_shape, exec_mode):
    """
-    Feature: test quant_batch_matmul_internal operator in graph mode
-    Description: test quant_batch_matmul_internal.
+    Feature: test qbmm operator in graph mode
+    Description: test qbmm.
    Expectation: the result is correct
    """
    ms.set_context(device_target="Ascend", mode=exec_mode)
-    ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"})
-    if "platform_ascend310p" in request.config.getoption("-m") and mstype is ms.bfloat16:
-        pytest.skip("Skipping ms.bfloat16 for 310p mark")
-    quant_batch_matmul_internal(1024, 1234, 1234, trans_x1=False, trans_x2=trans_x2, mstype=mstype,
-                                output_dtype=ms.bfloat16)
-
+    m, k, n = input_shape
+    qbmm(m, k, n, trans_a=False, trans_b=True)


@pytest.mark.level2
@pytest.mark.platform_ascend910b
@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE])
-@pytest.mark.parametrize('trans_x2', [False, True])
-@pytest.mark.parametrize('mstype', [ms.float16])
+@pytest.mark.parametrize('batch_size', [0, 3, 500, 2000])
+@pytest.mark.parametrize('is_bias_none', [True, False])
@pytest.mark.env_onecard
-def test_quant_batch_matmul_internal_1024_2048_2234_nd_input_unaligned_n_fp16(exec_mode, trans_x2, mstype, request):
+def test_qbmm_16_250_10_false_true_nz_input_unaligned_k_n_batch_size(batch_size, is_bias_none, exec_mode):
    """
-    Feature: test quant_batch_matmul_internal operator in graph mode
-    Description: test quant_batch_matmul_internal.
+    Feature: test qbmm operator in graph mode
+    Description: test qbmm.
    Expectation: the result is correct
    """
    ms.set_context(device_target="Ascend", mode=exec_mode)
-    ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"})
-    quant_batch_matmul_internal(1024, 2048, 2234, trans_x1=False, trans_x2=trans_x2, mstype=mstype,
-                                output_dtype=ms.float16)
-
+    qbmm(16, 250, 10, trans_a=False, trans_b=True, batch_m=batch_size, bias_none=is_bias_none)


@pytest.mark.level2
@pytest.mark.platform_ascend910b
@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE])
-@pytest.mark.parametrize('trans_x2', [False, True])
-@pytest.mark.parametrize('mstype', [ms.bfloat16])
+@pytest.mark.parametrize('batch_size', [0, 3, 500, 2000])
+@pytest.mark.parametrize('is_bias_none', [True, False])
@pytest.mark.env_onecard
-def test_quant_batch_matmul_internal_1024_2048_2234_nd_input_unaligned_n_bfp16(exec_mode, trans_x2, mstype, request):
+def test_qbmm_16_256_10_false_true_nd_input_unaligned_n_batch_size(batch_size, is_bias_none, exec_mode):
    """
-    Feature: test quant_batch_matmul_internal operator in graph mode
-    Description: test quant_batch_matmul_internal.
+    Feature: test qbmm operator in graph mode
+    Description: test qbmm.
    Expectation: the result is correct
    """
    ms.set_context(device_target="Ascend", mode=exec_mode)
-    ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"})
-    if "platform_ascend310p" in request.config.getoption("-m") and mstype is ms.bfloat16:
-        pytest.skip("Skipping ms.bfloat16 for 310p mark")
-    quant_batch_matmul_internal(1024, 2048, 2234, trans_x1=False, trans_x2=trans_x2, mstype=mstype,
-                                output_dtype=ms.bfloat16)
+    qbmm(16, 256, 10, trans_a=False, trans_b=True, batch_m=batch_size, bias_none=is_bias_none)

-
-@pytest.mark.level0
+@pytest.mark.level2
@pytest.mark.platform_ascend310p
+@pytest.mark.platform_ascend910b
@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE])
-@pytest.mark.parametrize('m', [1, 10, 20])
-@pytest.mark.parametrize('x1_format', [ND])
-@pytest.mark.parametrize('output_format', [ND])
@pytest.mark.env_onecard
-def test_quant_batch_matmul_internal_m_4096_4096_False_True_float16_nd_input(exec_mode, m, x1_format, output_format):
+def input_matmul_add_32_32_32_false_true_nd_input(exec_mode):
    """
-    Feature: test quant_batch_matmul_internal operator in graph and pynative mode
-    Description: test quant_batch_matmul_internal.
+    Feature: test qbmm operator in graph mode
+    Description: test qbmm.
Expectation: the result is correct """ ms.set_context(device_target="Ascend", mode=exec_mode) - ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) - quant_batch_matmul_internal(m, 4096, 4096, trans_x1=False, trans_x2=True, - mstype=ms.float16, x1_format=x1_format, output_format=output_format, - output_dtype=ms.float16) - + qbmm(16, 32, 64, trans_a=False, trans_b=True) -@pytest.mark.level0 +@pytest.mark.level2 +@pytest.mark.platform_ascend910b @pytest.mark.platform_ascend310p -@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE]) -@pytest.mark.parametrize('m', [1, 10, 20]) -@pytest.mark.parametrize('x1_format', [FRACTAL_NZ]) -@pytest.mark.parametrize('output_format', [ND, FRACTAL_NZ]) +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) @pytest.mark.env_onecard -def test_quant_batch_matmul_internal_m_4096_4096_False_True_float16_nz_input(exec_mode, m, x1_format, output_format): +def input_matmul_add_32_32_32_false_true_nz_input(exec_mode): """ - Feature: quant_batch_matmul_internal operator input_x1 dimension alignment validation - Description: Test that unaligned large/edge m dimension raise exception - Expectation: CheckDimensionAlignment validation rejects unaligned inputs + Feature: test qbmm operator in graph mode + Description: test qbmm. + Expectation: the result is correct """ ms.set_context(device_target="Ascend", mode=exec_mode) - ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) - quant_batch_matmul_internal(m, 4096, 4096, trans_x1=False, trans_x2=True, - mstype=ms.float16, x1_format=x1_format, output_format=output_format, - output_dtype=ms.float16) - + qbmm(16, 64, 128, trans_a=False, trans_b=True) -@pytest.mark.level0 +@pytest.mark.level2 +@pytest.mark.platform_ascend910b @pytest.mark.platform_ascend310p -@pytest.mark.parametrize("exec_mode", [context.PYNATIVE_MODE]) -@pytest.mark.parametrize('m', [1, 10, 20]) -@pytest.mark.parametrize('x1_format', [FRACTAL_NZ]) -@pytest.mark.parametrize('output_format', [ND, FRACTAL_NZ]) +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) @pytest.mark.env_onecard -def test_quant_batch_matmul_internal_m_4096_4096_False_True_float16_nz_input_pynative(exec_mode, m, x1_format, output_format): +def test_qbmm_16_32_64_false_true_bias_none_nd_input(exec_mode): """ - Feature: quant_batch_matmul_internal operator input_x1 dimension alignment validation - Description: Test that unaligned large/edge m dimension raise exception - Expectation: CheckDimensionAlignment validation rejects unaligned inputs + Feature: test qbmm operator in graph mode + Description: test qbmm. + Expectation: the result is correct """ ms.set_context(device_target="Ascend", mode=exec_mode) - ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) - with pytest.raises(RuntimeError, match="dimension must be aligned"): - quant_batch_matmul_internal(m, 4096, 4096, trans_x1=False, trans_x2=True, - mstype=ms.float16, x1_format=x1_format, output_format=output_format, - output_dtype=ms.float16) - logging.info( - "Unaligned dimension correctly rejected: shape=%s", - (m, 4096) - ) + qbmm(16, 32, 64, trans_a=False, trans_b=True, bias_none=True) +@pytest.mark.level2 +@pytest.mark.platform_ascend910b +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.env_onecard +def test_qbmm_16_32_64_false_true_bias_none_nz_input(exec_mode): + """ + Feature: test qbmm operator in graph mode + Description: test qbmm. 
+ Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + qbmm(16, 64, 128, trans_a=False, trans_b=True, bias_none=True) @pytest.mark.level2 @pytest.mark.platform_ascend910b @pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) -@pytest.mark.parametrize('m', [1, 256, 1024]) +@pytest.mark.parametrize('is_bias_none', [True, False]) @pytest.mark.env_onecard -def test_quant_batch_matmul_internal_m_4096_4096_False_True_float16_nz_input_unaligned_k_n(exec_mode, m): +def test_qbmm_16_16_32_64_false_true_nd_input(is_bias_none, exec_mode): """ - Feature: test quant_batch_matmul_internal operator in graph mode - Description: test quant_batch_matmul_internal. + Feature: test qbmm operator in graph mode + Description: test qbmm. Expectation: the result is correct """ ms.set_context(device_target="Ascend", mode=exec_mode) - ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) - quant_batch_matmul_internal(m, 1234, 1234, trans_x1=False, trans_x2=True, mstype=ms.float16, - output_dtype=ms.float16) + qbmm(m=16, k=32, n=64, batch_m=3, trans_a=False, trans_b=True, + bias_none=is_bias_none) +@pytest.mark.level2 +@pytest.mark.platform_ascend910b +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('is_bias_none', [True, False]) +@pytest.mark.env_onecard +def test_qbmm_16_16_32_64_false_true_nz_input(is_bias_none, exec_mode): + """ + Feature: test qbmm operator in graph mode + Description: test qbmm. + Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + qbmm(m=16, k=64, n=128, batch_m=3, trans_a=False, + trans_b=True, bias_none=is_bias_none) @pytest.mark.level2 @pytest.mark.platform_ascend910b @pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) -@pytest.mark.parametrize('m', [1, 32, 1024]) +@pytest.mark.parametrize('m', [1, 512, 1024]) @pytest.mark.env_onecard -def test_quant_batch_matmul_internal_m_2048_2234_False_True_float16_nd_input_unaligned_n(exec_mode, m): +def test_qbmm_m_4096_4096_false_true_bf16(m, exec_mode): """ - Feature: test quant_batch_matmul_internal operator in graph mode - Description: test quant_batch_matmul_internal. + Feature: test qbmm operator in graph mode + Description: test qbmm. Expectation: the result is correct """ ms.set_context(device_target="Ascend", mode=exec_mode) - ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) - quant_batch_matmul_internal(m, 2048, 2234, trans_x1=False, trans_x2=True, mstype=ms.float16, - output_dtype=ms.float16) + qbmm(m, 4096, 4096, trans_a=False, trans_b=True, dst_dtype=ms.bfloat16) + +@pytest.mark.level2 +@pytest.mark.platform_ascend910b +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('m', [1, 31, 32, 512, 520, 1024]) +@pytest.mark.parametrize('k', [32, 1024]) +@pytest.mark.parametrize('batch_size', [2, 3]) +@pytest.mark.env_onecard +def test_qbmm_with_batch_prefill(m, k, batch_size, exec_mode): + """ + Feature: test qbmm operator in graph mode + Description: test qbmm. 
+ Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + qbmm(m, k, 1024, batch_m=batch_size, trans_a=False, trans_b=True) @pytest.mark.level2 @pytest.mark.platform_ascend910b @pytest.mark.platform_ascend310p @pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) -@pytest.mark.parametrize('with_bias', [False, True]) -@pytest.mark.parametrize('with_pertoken_scale', [False, True]) +@pytest.mark.parametrize('m', [1, 31, 32, 64, 65]) +@pytest.mark.parametrize('k', [32, 64, 128]) +@pytest.mark.parametrize('batch_size', [2, 3]) +@pytest.mark.env_onecard +def test_qbmm_with_batch_increment_310p(m, k, batch_size, exec_mode): + """ + Feature: test qbmm operator in graph mode + Description: test qbmm. + Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + qbmm(m, k, 128, batch_m=batch_size, trans_a=False, trans_b=True) + + +@pytest.mark.level0 +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('m', [48, 128, 1024]) +@pytest.mark.parametrize('x1_format', [ND, FRACTAL_NZ]) +@pytest.mark.parametrize('output_format', [ND, FRACTAL_NZ]) @pytest.mark.env_onecard -def test_quant_batch_matmul_internal_with_bias_pertoken(exec_mode, with_bias, with_pertoken_scale): +def test_qbmm_with_fp32_scale_ds(m, x1_format, output_format, exec_mode): """ - Feature: test quant_batch_matmul_internal operator with bias and pertoken_scale - Description: test quant_batch_matmul_internal. + Feature: test qbmm operator in graph mode + Description: test qbmm. Expectation: the result is correct """ ms.set_context(device_target="Ascend", mode=exec_mode) - ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) - quant_batch_matmul_internal(1024, 2048, 2048, trans_x1=False, trans_x2=True, - with_bias=with_bias, with_pertoken_scale=with_pertoken_scale, - mstype=ms.float16, output_dtype=ms.float16) + qbmm(m, 7168, 1536, trans_b=True, scale_dtype=ms.float32, x1_format=x1_format, output_format=output_format) + qbmm(m, 1536, 24576, trans_b=True, scale_dtype=ms.float32, x1_format=x1_format, output_format=output_format) + qbmm(m, 7168, 576, trans_b=True, scale_dtype=ms.float32, x1_format=x1_format, output_format=output_format) + qbmm(m, 16384, 7168, trans_b=True, scale_dtype=ms.float32, x1_format=x1_format, output_format=output_format) @pytest.mark.level2 -@pytest.mark.platform_ascend910b @pytest.mark.platform_ascend310p @pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('m', [48, 128, 1024]) +@pytest.mark.parametrize('x1_format', [ND, FRACTAL_NZ]) +@pytest.mark.parametrize('output_format', [ND, FRACTAL_NZ]) @pytest.mark.env_onecard -def test_quant_batch_matmul_internal_in_real_shape_increment(exec_mode): +def test_qbmm_with_pertoken_ds(m, x1_format, output_format, exec_mode): """ - Feature: test quant_batch_matmul_internal operator in graph mode - Description: test quant_batch_matmul_internal. + Feature: test qbmm operator in graph mode + Description: test qbmm. 
Expectation: the result is correct """ ms.set_context(device_target="Ascend", mode=exec_mode) - ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) - prof_flag = False + qbmm(m, 7168, 36864, trans_b=True, bias_none=True, scale_dtype=ms.float32, + pertoken_scale_none=False, x1_format=x1_format, output_format=output_format) + qbmm(m, 7168, 18432, trans_b=True, bias_none=True, scale_dtype=ms.float32, + pertoken_scale_none=False, x1_format=x1_format, output_format=output_format) + + +@pytest.mark.level0 +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE]) +@pytest.mark.parametrize('m', [1, 48, 128]) +@pytest.mark.parametrize('bias_none', [False, True]) +@pytest.mark.parametrize('x1_format', [ND, FRACTAL_NZ]) +@pytest.mark.parametrize('output_format', [ND, FRACTAL_NZ]) +@pytest.mark.parametrize('pertoken_scale_none', [False, True]) +@pytest.mark.env_onecard +def test_qbmm_with_scale(m, bias_none, x1_format, output_format, pertoken_scale_none, exec_mode): + """ + Feature: test qbmm operator in graph mode + Description: test qbmm. + Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + qbmm(m, 256, 256, trans_b=True, bias_none=bias_none, scale_dtype=ms.float32, + pertoken_scale_none=pertoken_scale_none, x1_format=x1_format, output_format=output_format) + qbmm(m, 1024, 1024, trans_b=True, bias_none=bias_none, scale_dtype=ms.float32, + pertoken_scale_none=pertoken_scale_none, x1_format=x1_format, output_format=output_format) + + +@pytest.mark.level2 +@pytest.mark.platform_ascend910b +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('m', [1, 16, 48, 128]) +@pytest.mark.env_onecard +def test_qbmm_increment_prof(m, exec_mode): + """ + Feature: test qbmm operator in graph mode + Description: test qbmm. 
+ Expectation: the result is correct + """ + prof_flag = True + ms.set_context(device_target="Ascend", mode=exec_mode, jit_config={"jit_level": "O0", "infer_boost": "on"}) profiler = Profiler(start_profile=False, output_path="profiler") profiler.start() - quant_batch_matmul_internal(16, 2752, 4096, trans_x1=False, trans_x2=True, - mstype=ms.float16, profiling=prof_flag, output_dtype=ms.float16) - quant_batch_matmul_internal(16, 32, 4096, trans_x1=False, trans_x2=True, - mstype=ms.float16, profiling=prof_flag, output_dtype=ms.float16) - quant_batch_matmul_internal(16, 4096, 32, trans_x1=False, trans_x2=True, - mstype=ms.float16, profiling=prof_flag, output_dtype=ms.float16) - quant_batch_matmul_internal(16, 4096, 8256, trans_x1=False, trans_x2=True, - mstype=ms.float16, profiling=prof_flag, output_dtype=ms.float16) + qbmm(m, 4096, 3072, trans_a=False, trans_b=True, profiling=prof_flag, x1_format=ND, x2_format=ND, output_format=ND) + qbmm(m, 2048, 4096, trans_a=False, trans_b=True, profiling=prof_flag, x1_format=ND, x2_format=ND, output_format=ND) + qbmm(m, 4096, 16512, trans_a=False, trans_b=True, profiling=prof_flag, x1_format=ND, x2_format=ND, output_format=ND) profiler.stop() profiler.analyse() - @pytest.mark.level2 @pytest.mark.platform_ascend910b @pytest.mark.platform_ascend310p @pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) @pytest.mark.parametrize('m', [514, 1024]) -@pytest.mark.parametrize("prof_flag_str", [0, 1]) -def test_quant_batch_matmul_internal_in_real_shape_prefill(exec_mode, m, prof_flag_str): +@pytest.mark.env_onecard +def test_qbmm_prefill_prof(m, exec_mode): """ - Feature: test quant_batch_matmul_internal operator in graph mode - Description: test quant_batch_matmul_internal. + Feature: test qbmm operator in graph mode + Description: test qbmm. 
Expectation: the result is correct """ - ms.set_context(device_target="Ascend", mode=exec_mode) - ms.set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) - prof_flag = bool(int(prof_flag_str)) + prof_flag = True + ms.set_context(device_target="Ascend", mode=exec_mode, jit_config={"jit_level": "O0", "infer_boost": "on"}) profiler = Profiler(start_profile=False, output_path="profiler") profiler.start() - quant_batch_matmul_internal(m, 2752, 4096, trans_x1=False, trans_x2=True, - mstype=ms.float16, profiling=prof_flag, output_dtype=ms.float16) - quant_batch_matmul_internal(m, 32, 4096, trans_x1=False, trans_x2=True, - mstype=ms.float16, profiling=prof_flag, output_dtype=ms.float16) - quant_batch_matmul_internal(m, 4096, 32, trans_x1=False, trans_x2=True, - mstype=ms.float16, profiling=prof_flag, output_dtype=ms.float16) - quant_batch_matmul_internal(m, 4096, 8256, trans_x1=False, trans_x2=True, - mstype=ms.float16, profiling=prof_flag, output_dtype=ms.float16) - quant_batch_matmul_internal(m, 5504, 4096, trans_x1=False, trans_x2=True, - mstype=ms.float16, profiling=prof_flag, output_dtype=ms.float16) + qbmm(m, 4096, 3072, trans_a=False, trans_b=True, profiling=prof_flag, x1_format=FRACTAL_NZ, x2_format=FRACTAL_NZ, output_format=FRACTAL_NZ) + qbmm(m, 2048, 4096, trans_a=False, trans_b=True, profiling=prof_flag, x1_format=FRACTAL_NZ, x2_format=FRACTAL_NZ, output_format=FRACTAL_NZ) + qbmm(m, 2048, 11008, trans_a=False, trans_b=True, profiling=prof_flag, x1_format=FRACTAL_NZ, x2_format=FRACTAL_NZ, output_format=FRACTAL_NZ) + qbmm(m, 4096, 16512, trans_a=False, trans_b=True, profiling=prof_flag, x1_format=FRACTAL_NZ, x2_format=FRACTAL_NZ, output_format=FRACTAL_NZ) profiler.stop() profiler.analyse() + +@pytest.mark.level0 +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('m', [1, 4096]) +@pytest.mark.env_onecard +def test_qbmm_fastgelu(m, exec_mode): + """ + Feature: test qbmm operator in graph mode + Description: test qbmm. + Expectation: the result is correct + """ + ms.set_context(device_target="Ascend", mode=exec_mode) + prof_flag = False + qbmm(m, 4096, 11008, trans_a=False, trans_b=True, profiling=prof_flag) + +@pytest.mark.level1 +@pytest.mark.platform_ascend310p +@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('m', [1, 16, 1024]) +@pytest.mark.parametrize('k', [1024, 2048]) +@pytest.mark.parametrize('n', [1024, 2048]) +@pytest.mark.parametrize('is_dyn', [False, True]) +@pytest.mark.env_onecard +def test_qbmm_fastgelu_dyn(m, k, n, is_dyn, exec_mode): + """ + Feature: test qbmm operator in graph mode + Description: test qbmm. 
+    Expectation: the result is correct
+    """
+    ms.set_context(device_target="Ascend", mode=exec_mode, jit_config={"jit_level": "O0", "infer_boost": "on"})
+    prof_flag = False
+    qbmm(m, k, n, trans_a=False, trans_b=True, profiling=prof_flag, is_dyn=is_dyn)
-- 
Gitee


From 85e89e672093361b84792edd8c1b0270166b3382 Mon Sep 17 00:00:00 2001
From: huoxinyou
Date: Fri, 14 Nov 2025 11:36:44 +0800
Subject: [PATCH 3/3] fix bug

---
 .../quant_batch_matmul_internal.cc            |  4 ++--
 .../quant_batch_matmul_internal_op.yaml       |  2 --
 tests/st/test_quant_batch_matmul_internal.py  | 15 ---------------
 3 files changed, 2 insertions(+), 19 deletions(-)

diff --git a/ops/c_api/quant_batch_matmul_internal/quant_batch_matmul_internal.cc b/ops/c_api/quant_batch_matmul_internal/quant_batch_matmul_internal.cc
index 3b7464f..63713de 100644
--- a/ops/c_api/quant_batch_matmul_internal/quant_batch_matmul_internal.cc
+++ b/ops/c_api/quant_batch_matmul_internal/quant_batch_matmul_internal.cc
@@ -327,8 +327,8 @@ auto pyboost_quant_batch_matmul_internal(
 MS_CUSTOM_OPS_EXTENSION_MODULE(m) {
   m.def("quant_batch_matmul_internal", &pyboost_quant_batch_matmul_internal, "QuantBatchMatmulInternal",
         pybind11::arg("x1"), pybind11::arg("x2"), pybind11::arg("scale"),
-        pybind11::arg("offset") = pybind11::none(), pybind11::arg("bias") = pybind11::none(),
-        pybind11::arg("pertoken_scale") = pybind11::none(),
+        pybind11::arg("offset") = std::nullopt, pybind11::arg("bias") = std::nullopt,
+        pybind11::arg("pertoken_scale") = std::nullopt,
         pybind11::arg("transpose_x1") = false, pybind11::arg("transpose_x2") = false,
         pybind11::arg("x1_format") = 0, pybind11::arg("x2_format") = 0,
         pybind11::arg("output_format") = 0, pybind11::arg("output_dtype") = 0);
diff --git a/ops/c_api/quant_batch_matmul_internal/quant_batch_matmul_internal_op.yaml b/ops/c_api/quant_batch_matmul_internal/quant_batch_matmul_internal_op.yaml
index 8eb2bf8..00ed986 100644
--- a/ops/c_api/quant_batch_matmul_internal/quant_batch_matmul_internal_op.yaml
+++ b/ops/c_api/quant_batch_matmul_internal/quant_batch_matmul_internal_op.yaml
@@ -35,8 +35,6 @@ quant_batch_matmul_internal:
       dtype: TypeId
       default: mstype.float16
       arg_handler: dtype_to_type_id
-  args_signature:
-    dtype_group: (x1, x2)
   returns:
     y:
       dtype: tensor
\ No newline at end of file
diff --git a/tests/st/test_quant_batch_matmul_internal.py b/tests/st/test_quant_batch_matmul_internal.py
index b81ac58..615bad4 100644
--- a/tests/st/test_quant_batch_matmul_internal.py
+++ b/tests/st/test_quant_batch_matmul_internal.py
@@ -150,21 +150,6 @@ def qbmm(m, k, n, batch_m=0, trans_a=False, trans_b=False, dst_dtype=ms.float16,
     assert res, "qbmm compare fail."


-@pytest.mark.level0
-@pytest.mark.platform_ascend910b
-@pytest.mark.platform_ascend310p
-@pytest.mark.parametrize("exec_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE])
-@pytest.mark.parametrize('m', [1, 256, 1024])
-@pytest.mark.env_onecard
-def test_qbmm_add_m_4096_4096_false_true_nz_input(m, exec_mode):
-    """
-    Feature: test qbmm operator in graph mode
-    Description: test qbmm.
-    Expectation: the result is correct
-    """
-    ms.set_context(device_target="Ascend", mode=exec_mode)
-    qbmm(m, 4096, 4096, trans_a=False, trans_b=True)
-
-- 
Gitee
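# --- Illustrative usage sketch (appendix; not part of the patches above) ---
# A minimal call of the binding registered in quant_batch_matmul_internal.cc, assuming the
# ms_custom_ops package has been built from this patch series and MindSpore is running on
# an Ascend device. Shapes and values are made up; the positional style for the optional
# tensors mirrors the calls in the test file above.
import numpy as np
import mindspore as ms
import ms_custom_ops

ms.set_context(device_target="Ascend")
x1 = ms.Tensor(np.random.randint(-128, 127, size=(16, 256)).astype(np.int8))   # int8 activations, shape (m, k)
x2 = ms.Tensor(np.random.randint(-128, 127, size=(512, 256)).astype(np.int8))  # int8 weights, shape (n, k), used with transpose_x2=True
scale = ms.Tensor(np.random.rand(512).astype(np.float32))                      # per-channel dequant scale, shape (n,)

y = ms_custom_ops.quant_batch_matmul_internal(
    x1, x2, scale, None, None, None,        # offset, bias and pertoken_scale are optional
    transpose_x1=False, transpose_x2=True,
    output_dtype=ms.float16)                # expected output: shape (16, 512), float16
# ---------------------------------------------------------------------------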