From e9a1d7288018d9f605c5250bb010ee0984430229 Mon Sep 17 00:00:00 2001
From: lihui
Date: Fri, 12 Sep 2025 08:35:34 +0800
Subject: [PATCH] Add the npu_rms_norm_quant fused operator
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 codegen/templates/_op_plugin_docs.py      |  43 +++++
 op_plugin/config/op_plugin_functions.yaml |   8 +
 .../torch_npu_OpApi_schema_all.json        |   3 +
 .../test_npu_rms_norm_quant.py             | 170 ++++++++++++++++++
 4 files changed, 224 insertions(+)
 create mode 100644 test/test_custom_ops/test_npu_rms_norm_quant.py

diff --git a/codegen/templates/_op_plugin_docs.py b/codegen/templates/_op_plugin_docs.py
index 3adf4be27..7b8ee3afc 100644
--- a/codegen/templates/_op_plugin_docs.py
+++ b/codegen/templates/_op_plugin_docs.py
@@ -9273,6 +9273,49 @@ y, rstd = torch_npu.npu_gemma_rms_norm(input_x, input_gamma)
 """
 )
 
+_add_torch_npu_docstr(
+    "npu_rms_norm_quant",
+    """
+API prototype
+npu_rms_norm_quant(Tensor x, Tensor gamma, Tensor beta, Tensor scale, Tensor offset, float epsilon=1e-06) -> Tensor
+
+Description
+RmsNormQuant fuses the RmsNorm operator with the Quantize operator that follows it, reducing data move-in/move-out between memory levels.
+
+Parameters
+x: Device-side Tensor. The input to be normalized. 1 to 8 dimensions; data types FLOAT16 and BFLOAT16; format ND. Empty tensors are not supported.
+gamma: Device-side Tensor. The normalization weight. 1 or 2 dimensions, with the last dimension equal to the last dimension of x; same data type as x; format ND. Empty tensors are not supported.
+beta: Device-side Tensor. The normalization bias. Same shape as gamma and same data type as x; format ND. Empty tensors are not supported.
+scale: Device-side Tensor. The scale applied when quantizing the normalized result into y. 1-dimensional with a single element; format ND. Empty tensors are not supported.
+offset: Device-side Tensor. The offset (zero point) applied when quantizing the normalized result into y. Same shape as scale; format ND. Empty tensors are not supported.
+epsilon: double. Added to the variance to avoid division by zero. Defaults to 1e-6.
+
+Outputs
+y: Device-side Tensor. Data type INT8; shape and data format match the input x. Non-contiguous tensors are supported; empty tensors are not.
+
+Constraints
+The last-axis length of x and y, as well as the last-axis length of gamma, must be at least 32 bytes.
+
+Supported products
+Atlas A3 training series products / Atlas A3 inference series products
+Atlas A2 training series products / Atlas 800I A2 inference products / A200I A2 Box heterogeneous component
+
+Example
+import torch
+import torch_npu
+
+eps = 1e-6
+x = torch.randn(16, dtype=torch.float16).npu()
+gamma = torch.randn(16, dtype=torch.float16).npu()
+beta = torch.zeros(16, dtype=torch.float16).npu()
+scale = torch.ones(1, dtype=torch.float16).npu()
+offset = torch.zeros(1, dtype=torch.int8).npu()
+
+y = torch_npu.npu_rms_norm_quant(x, gamma, beta, scale, offset, eps)
+_ = y.cpu().numpy()
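+
+# For reference only: a minimal unfused sketch of the same computation
+# (an illustrative assumption; the fused kernel's internal precision and
+# rounding may differ slightly from this version):
+x_fp32 = x.float()
+rstd = torch.rsqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + eps)
+y_fp32 = x_fp32 * rstd * gamma.float() + beta.float()
+y_ref = torch.clamp(torch.round(y_fp32 / scale.float() + offset.float()),
+                    -128, 127).to(torch.int8)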
+"""
+)
+
 _add_torch_npu_docstr(
     "npu_add_rms_norm_cast",
     """
diff --git a/op_plugin/config/op_plugin_functions.yaml b/op_plugin/config/op_plugin_functions.yaml
index 8603580f8..112a8f840 100644
--- a/op_plugin/config/op_plugin_functions.yaml
+++ b/op_plugin/config/op_plugin_functions.yaml
@@ -5568,6 +5568,14 @@ custom:
         dtype: x1
       exec: aclnnAddRmsNorm
 
+  - func: npu_rms_norm_quant(Tensor x, Tensor gamma, Tensor beta, Tensor scale, Tensor offset, float epsilon=1e-06) -> Tensor
+    op_api: all_version
+    gen_opapi:
+      y:
+        size: x
+        dtype: at::kChar
+      exec: aclnnRmsNormQuant, x, gamma, beta, scale, offset, epsilon, y
+
   - func: npu_add_rms_norm_cast(Tensor x1, Tensor x2, Tensor gamma, float epsilon=1e-06) -> (Tensor, Tensor, Tensor, Tensor)
     op_api: all_version
     gen_opapi:
diff --git a/test/core_tests/torch_npu_OpApi_schema_all.json b/test/core_tests/torch_npu_OpApi_schema_all.json
index faf3ff9e7..2b5ab82b5 100644
--- a/test/core_tests/torch_npu_OpApi_schema_all.json
+++ b/test/core_tests/torch_npu_OpApi_schema_all.json
@@ -812,6 +812,9 @@
     "func: npu_add_layer_norm_backward(Tensor? dy_opt, Tensor x1, Tensor x2, Tensor rstd, Tensor mean, Tensor gamma, Tensor? dsum_opt) -> (Tensor, Tensor, Tensor, Tensor)": {
         "version": ["all_version"]
     },
+    "func: npu_rms_norm_quant(Tensor x, Tensor gamma, Tensor beta, Tensor scale, Tensor offset, float epsilon=1e-06) -> Tensor": {
+        "version": ["all_version"]
+    },
     "func: npu_rms_norm(Tensor self, Tensor gamma, float epsilon=1e-06) -> (Tensor, Tensor)": {
         "version": ["all_version"]
     },
diff --git a/test/test_custom_ops/test_npu_rms_norm_quant.py b/test/test_custom_ops/test_npu_rms_norm_quant.py
new file mode 100644
index 000000000..85dfd7742
--- /dev/null
+++ b/test/test_custom_ops/test_npu_rms_norm_quant.py
@@ -0,0 +1,170 @@
+import unittest
+
+import numpy as np
+import torch
+
+import torch_npu
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.common_utils import SupportedDevices
+
+
+class TestNPURmsNormQuant(TestCase):
+
+    def compare(self, a, b, benchmark):
+        # Mixed-tolerance check for quantized outputs: absolute error where
+        # |a| < 1, relative error elsewhere. Cast to float first so that the
+        # int8 subtraction cannot wrap around.
+        a = a.float()
+        b = b.float()
+        diff_abs = torch.abs(a - b)
+        if torch.max(diff_abs).item() == 0:
+            return True
+        error_count = 0
+        for i in range(a.shape[0]):
+            a_i = a[i].item()
+            b_i = b[i].item()
+            if a_i == 0 and b_i == 0:
+                diff_rel_item = 0
+            elif a_i == 0 or b_i == 0:
+                diff_rel_item = 1
+            else:
+                diff_rel_item = diff_abs[i].item() / abs(a_i)
+            if abs(a_i) < 1 and diff_abs[i].item() > benchmark:
+                error_count += 1
+            elif abs(a_i) >= 1 and diff_rel_item > benchmark:
+                error_count += 1
+            if error_count > 10:
+                break
+        return error_count == 0
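+
+    def compare_vectorized(self, a, b, benchmark):
+        # Hypothetical vectorized equivalent of compare() above, included for
+        # illustration only; the tests below do not call it.
+        a = a.float()
+        b = b.float()
+        diff = torch.abs(a - b)
+        rel = diff / torch.abs(a).clamp_min(1e-12)
+        bad = torch.where(torch.abs(a) < 1, diff > benchmark, rel > benchmark)
+        return not bad.any().item()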
+
+    def npu_rms_norm_quant_golden(self, x, gamma, beta, scale,
+                                  offset, epsilon=1e-06):
+        # CPU reference: RMS-normalize in float32, then quantize the result
+        # per tensor to int8.
+        ori_shape = x.shape
+        x_fp32 = x.float()
+        gamma_fp32 = gamma.float()
+        beta_fp32 = beta.float()
+        variance = torch.mean(torch.pow(x_fp32, 2), dim=-1, keepdim=True)
+        rstd = 1 / torch.sqrt(variance + epsilon)
+        y = x_fp32 * rstd * gamma_fp32 + beta_fp32
+        y_quant = torch.quantize_per_tensor(y, scale.float().item(),
+                                            int(offset.item()), torch.qint8)
+        # A 2-D gamma/beta broadcasts a 1-D x to 2-D; restore x's shape.
+        return y_quant.int_repr().clone().reshape(ori_shape)
+
+    def test_npu_rms_norm_quant(self):
+        torch.manual_seed(42)
+        np.random.seed(42)
+        # x shapes of 1 to 8 dims, each paired with a 1-D and a 2-D
+        # gamma/beta shape.
+        shape_list = [
+            [[16], [16]],
+            [[16], [1, 16]],
+            [[1, 16], [16]],
+            [[1, 16], [1, 16]],
+            [[1, 1, 16], [16]],
+            [[1, 1, 16], [1, 16]],
+            [[1, 1, 1, 16], [16]],
+            [[1, 1, 1, 16], [1, 16]],
+            [[1, 1, 1, 1, 16], [16]],
+            [[1, 1, 1, 1, 16], [1, 16]],
+            [[1, 1, 1, 1, 1, 16], [16]],
+            [[1, 1, 1, 1, 1, 16], [1, 16]],
+            [[1, 1, 1, 1, 1, 1, 16], [16]],
+            [[1, 1, 1, 1, 1, 1, 16], [1, 16]],
+            [[1, 1, 1, 1, 1, 1, 1, 16], [16]],
+            [[1, 1, 1, 1, 1, 1, 1, 16], [1, 16]],
+        ]
+
+        benchmark_int8 = 1
+
+        for x_shape, quant_shape in shape_list:
+            x = torch.randn(x_shape, dtype=torch.float16)
+            gamma = torch.randn(quant_shape, dtype=torch.float16)
+            beta = torch.randn(quant_shape, dtype=torch.float16)
+            scale = torch.rand(1, dtype=torch.float16) * 0.8 + 0.2  # in [0.2, 1.0)
+            offset = torch.randint(-5, 6, (1,), dtype=torch.int8)
+
+            x_npu = x.npu()
+            gamma_npu = gamma.npu()
+            beta_npu = beta.npu()
+            scale_npu = scale.npu()
+            offset_npu = offset.npu()
+
+            y_ref = self.npu_rms_norm_quant_golden(x, gamma, beta, scale,
+                                                   offset, epsilon=1e-6)
+            y_npu = torch_npu.npu_rms_norm_quant(x_npu, gamma_npu, beta_npu,
+                                                 scale_npu, offset_npu,
+                                                 epsilon=1e-6)
+
+            y_ref_flat = y_ref.reshape(-1).cpu()
+            y_npu_flat = y_npu.reshape(-1).cpu()
+            self.assertTrue(self.compare(y_ref_flat, y_npu_flat, benchmark_int8))
+
+    def test_npu_rms_norm_quant_bf16(self):
+        torch.manual_seed(42)
+        np.random.seed(42)
+        shape_list = [
+            [[16], [16]],
+            [[16], [1, 16]],
+            [[1, 16], [16]],
+            [[1, 16], [1, 16]],
+            [[1, 1, 16], [16]],
+            [[1, 1, 16], [1, 16]],
+            [[1, 1, 1, 16], [16]],
+            [[1, 1, 1, 16], [1, 16]],
+            [[1, 1, 1, 1, 16], [16]],
+            [[1, 1, 1, 1, 16], [1, 16]],
+            [[1, 1, 1, 1, 1, 16], [16]],
+            [[1, 1, 1, 1, 1, 16], [1, 16]],
+            [[1, 1, 1, 1, 1, 1, 16], [16]],
+            [[1, 1, 1, 1, 1, 1, 16], [1, 16]],
+            [[1, 1, 1, 1, 1, 1, 1, 16], [16]],
+            [[1, 1, 1, 1, 1, 1, 1, 16], [1, 16]],
+        ]
+
+        benchmark_int8 = 1
+
+        for x_shape, quant_shape in shape_list:
+            x = torch.randn(x_shape, dtype=torch.bfloat16)
+            gamma = torch.randn(quant_shape, dtype=torch.bfloat16)
+            beta = torch.randn(quant_shape, dtype=torch.bfloat16)
+            scale = torch.rand(1, dtype=torch.bfloat16) * 0.8 + 0.2  # in [0.2, 1.0)
+            offset = torch.randint(-5, 6, (1,), dtype=torch.int8)
+
+            x_npu = x.npu()
+            gamma_npu = gamma.npu()
+            beta_npu = beta.npu()
+            scale_npu = scale.npu()
+            offset_npu = offset.npu()
+
+            y_ref = self.npu_rms_norm_quant_golden(x, gamma, beta, scale,
+                                                   offset, epsilon=1e-6)
+            y_npu = torch_npu.npu_rms_norm_quant(x_npu, gamma_npu, beta_npu,
+                                                 scale_npu, offset_npu,
+                                                 epsilon=1e-6)
+
+            y_ref_flat = y_ref.reshape(-1).cpu()
+            y_npu_flat = y_npu.reshape(-1).cpu()
+            self.assertTrue(self.compare(y_ref_flat, y_npu_flat, benchmark_int8))
+
+
+if __name__ == "__main__":
+    run_tests()
--
Gitee