diff --git a/src/cam/comm_operator/pybind/functions.h b/src/cam/comm_operator/pybind/functions.h
index 4d5b4be7ca85b586422b851744925becb84264b5..e0855dfdf96d5fe6865303b3c79ab57d13b7d2dc 100644
--- a/src/cam/comm_operator/pybind/functions.h
+++ b/src/cam/comm_operator/pybind/functions.h
@@ -34,7 +34,7 @@ at::Tensor fused_deep_moe_impl_autograd(
     int64_t quantMode, \
     int64_t globalBs);
 
-std::tuple
+std::vector
 moe_dispatch_normal_impl_autograd(
     const at::Tensor &x, \
     const at::Tensor &topkIdx, \
diff --git a/src/cam/comm_operator/pybind/setup.py b/src/cam/comm_operator/pybind/setup.py
index 1da4931fb92b540d01a16edab739bd957008b52d..3a5f3831417ef137dad2896eb59e5005037868c1 100644
--- a/src/cam/comm_operator/pybind/setup.py
+++ b/src/cam/comm_operator/pybind/setup.py
@@ -53,7 +53,7 @@ print(compile_args)
 
 exts = []
 ext1 = NpuExtension(
-    name="cam_ge_op_lib",
+    name="umdk_cam_op_lib",
     include_dirs=[
        os.path.join(torch_npu_path, "include"),
        os.path.join(torch_npu_path, "include/third_party/acl/inc/acl/"),
@@ -92,9 +92,9 @@ exts.append(ext1)
 BdistWheelBuild.dependencies = ["libc10.so", "libtorch.so", "libtorch_cpu.so", "libtorch_python.so", "libtorch_npu.so"]
 
 setup(
-    name="cam_ge_operator",
+    name="umdk_cam_op_lib",
     version=env_version,
-    keywords="cam_ge_op_lib",
+    keywords="umdk_cam_op_lib",
     ext_modules=exts,
     packages=find_packages(),
     cmdclass={
diff --git a/src/cam/examples/fused_deep_moe_sample.py b/src/cam/examples/fused_deep_moe_sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..669fd1e9415677da829c3213fc87aee33ff7b2e7
--- /dev/null
+++ b/src/cam/examples/fused_deep_moe_sample.py
@@ -0,0 +1,449 @@
+#
+# SPDX-License-Identifier: MIT
+# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+# Description: Example for the fused_deep_moe operator.
+#              This sample shows how to use the FusedDeepMoe operator and the equivalent
+#              "small" operators, where "small" means a combination of fine-grained operators
+#              (dispatch + gmm1 + swiglu + gmm2 + combine) that produces the same result as
+#              the fused FusedDeepMoe operator.
+# Create: 2025-12-11
+# Note:
+# History: 2025-12-11 create example file
+#
+
+import gc
+import os
+import sys
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import torch_npu
+import torchair
+
+# Import the CAM operator library. Make sure the library is installed correctly before importing it.
+import umdk_cam_op_lib
+
+torch_npu.npu.config.allow_internal_format = True
+LOG_NAME = "fused_deep_moe_sample_logs"
+
+
+def redirect_output(log_file_path):
+    log_path = Path(LOG_NAME) / log_file_path
+    log_path.parent.mkdir(parents=True, exist_ok=True)
+    f = open(LOG_NAME + "/" + log_file_path, "w")
+    os.dup2(f.fileno(), sys.stdout.fileno())
+    os.dup2(f.fileno(), sys.stderr.fileno())
+    return f
+
+
+# Split the last dimension of w into (2, n // tile_n, tile_n // 2) tiles, swap
+# the first two tile axes, and flatten back to the original shape.
+def permute_weight(w: torch.Tensor, tile_n):
+    *dims, n = w.shape
+    order = list(range(len(dims))) + [-2, -3, -1]
+    return w.reshape(*dims, 2, n // tile_n,
+                     tile_n // 2).permute(order).reshape(*dims,
+                                                         n).contiguous()
+
+
+# Convert an inclusive prefix-sum sequence back to per-element counts.
+def from_inclusive_prefix_sum(pref):
+    if isinstance(pref, torch.Tensor):
+        if pref.numel() == 0:
+            return pref
+        return torch.cat([pref[:1], pref[1:] - pref[:-1]])
+
+    if not pref:
+        return []
+    out = [pref[0]]
+    for i in range(1, len(pref)):
+        out.append(pref[i] - pref[i - 1])
+    return out
+
+
+def output_to_file(rank_id):
+    return False
+
+
+class DecodeMoeOps(torch.nn.Module):
+
+    def __init__(self,
+                 gmm1_weight,
+                 gmm1_weight_scale,
+                 gmm2_weight,
+                 gmm2_weight_scale,
+                 ep_hcomm_info,
+                 batch_size,
+                 token_hidden_size,
+                 moe_intermediate_size,
+                 ep_world_size,
+                 moe_expert_num,
+                 global_rank_id,
+                 shared_expert_rank_num=0):
+        super().__init__()
+        self.ep_hcomm_info = ep_hcomm_info
+        self.batch_size = batch_size
+        self.token_hidden_size = token_hidden_size
+        self.moe_intermediate_size = moe_intermediate_size
+        self.ep_world_size = ep_world_size
+        self.moe_expert_num = moe_expert_num
+        self.global_rank_id = global_rank_id
+        self.shared_expert_rank_num = shared_expert_rank_num
+        is_shared_expert = global_rank_id < shared_expert_rank_num
+        moe_expert_num_per_rank = moe_expert_num // (ep_world_size -
+                                                     shared_expert_rank_num)
+        self.local_expert_num = 1 if is_shared_expert else moe_expert_num_per_rank
+        self.ep_recv_count_size = self.local_expert_num * ep_world_size
+        self.gmm1_weight = torch.empty([
+            self.local_expert_num, self.token_hidden_size,
+            self.moe_intermediate_size * 2
+        ])
+        self.gmm1_weight_scale = torch.empty(
+            [self.local_expert_num, self.moe_intermediate_size * 2])
+        self.gmm2_weight = torch.empty([
+            self.local_expert_num, self.moe_intermediate_size,
+            self.token_hidden_size
+        ])
+        self.gmm2_weight_scale = torch.empty(
+            [self.local_expert_num, self.token_hidden_size])
+        self._process_weights_after_loading(gmm1_weight, gmm1_weight_scale,
+                                            gmm2_weight, gmm2_weight_scale)
+
+    def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale,
+                                       gmm2_weight, gmm2_weight_scale):
+        raise NotImplementedError("To be implemented in subclass")
+
+    def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
+        raise NotImplementedError("To be implemented in subclass")
+
+    def forward(self, x, expert_ids, smooth_scales, expert_scales):
+        return self._apply_ops(x, expert_ids, smooth_scales, expert_scales)
+
+
+# Reference path built from the individual operators:
+# dispatch -> gmm1 -> swiglu -> gmm2 -> combine.
+class SmallOps(DecodeMoeOps):
+
+    def __init__(self,
+                 gmm1_weight,
+                 gmm1_weight_scale,
+                 gmm2_weight,
+                 gmm2_weight_scale,
+                 ep_hcomm_info,
+                 batch_size,
+                 token_hidden_size,
+                 moe_intermediate_size,
+                 ep_world_size,
+                 moe_expert_num,
+                 global_rank_id,
+                 shared_expert_rank_num=0):
+        super().__init__(gmm1_weight, gmm1_weight_scale, gmm2_weight,
+                         gmm2_weight_scale, ep_hcomm_info, batch_size,
+                         token_hidden_size, moe_intermediate_size,
+                         ep_world_size, moe_expert_num, global_rank_id,
+                         shared_expert_rank_num)
+        self.tp_hcomm_info = ""
+
+    def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale,
+                                       gmm2_weight, gmm2_weight_scale):
+        gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
+                                                torch_npu.Format.FRACTAL_NZ)
+        gmm2_weight = torch_npu.npu_format_cast(gmm2_weight,
+                                                torch_npu.Format.FRACTAL_NZ)
+        self.gmm1_weight = torch.nn.Parameter(gmm1_weight, requires_grad=False)
+        self.gmm1_weight_scale = torch.nn.Parameter(gmm1_weight_scale,
+                                                    requires_grad=False)
+        self.gmm2_weight = torch.nn.Parameter(gmm2_weight, requires_grad=False)
+        self.gmm2_weight_scale = torch.nn.Parameter(gmm2_weight_scale,
+                                                    requires_grad=False)
+
+    def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
+        outputs = torch_npu.npu_moe_distribute_dispatch_v2(
+            x=x,
+            expert_ids=expert_ids,
+            expert_scales=expert_scales,
+            group_ep=self.ep_hcomm_info,
+            ep_world_size=self.ep_world_size,
+            ep_rank_id=self.global_rank_id,
+            moe_expert_num=self.moe_expert_num,
+            group_tp=self.tp_hcomm_info,
+            tp_world_size=1,
+            tp_rank_id=0,
+            expert_shard_type=0,
+            shared_expert_num=1,
+            shared_expert_rank_num=self.shared_expert_rank_num,
+            quant_mode=2,
+            global_bs=self.batch_size * self.ep_world_size,
+            expert_token_nums_type=1,  # 0 means prefix sums, 1 means per-expert counts
+        )
+        expand_x, dynamic_scales, assist_info_for_combine, expert_token_nums, ep_send_counts, tp_send_counts, expand_scales = outputs
+        output_dtype = x.dtype
+
+        y1_int32 = torch_npu.npu_grouped_matmul(
+            x=[expand_x],
+            weight=[self.gmm1_weight],
+            split_item=3,
+            group_list_type=1,  # defaults to 0, which means prefix-sum form
+            group_type=0,  # 0 means grouping along the m axis
+            group_list=expert_token_nums,
+            output_dtype=torch.int32)[0]
+        y1, y1_scale = torch_npu.npu_dequant_swiglu_quant(
+            x=y1_int32,
+            weight_scale=self.gmm1_weight_scale.to(torch.float32),
+            activation_scale=dynamic_scales,
+            bias=None,
+            quant_scale=None,
+            quant_offset=None,
+            group_index=expert_token_nums,
+            activate_left=True,
+            quant_mode=1,
+        )
+        y2 = torch_npu.npu_grouped_matmul(x=[y1],
+                                          weight=[self.gmm2_weight],
+                                          scale=[self.gmm2_weight_scale],
+                                          per_token_scale=[y1_scale],
+                                          split_item=2,
+                                          group_list_type=1,
+                                          group_type=0,
+                                          group_list=expert_token_nums,
+                                          output_dtype=output_dtype)[0]
+        combine_output = torch_npu.npu_moe_distribute_combine_v2(
+            expand_x=y2,
+            expert_ids=expert_ids,
+            assist_info_for_combine=assist_info_for_combine,
+            ep_send_counts=ep_send_counts,
+            expert_scales=expert_scales,
+            group_ep=self.ep_hcomm_info,
+            ep_world_size=self.ep_world_size,
+            ep_rank_id=self.global_rank_id,
+            moe_expert_num=self.moe_expert_num,
+            tp_send_counts=tp_send_counts,
+            expand_scales=expand_scales,
+            group_tp=self.tp_hcomm_info,
+            tp_world_size=1,
+            tp_rank_id=0,
+            expert_shard_type=0,
+            shared_expert_num=1,
+            shared_expert_rank_num=self.shared_expert_rank_num,
+            global_bs=self.batch_size * self.ep_world_size)
+        return (combine_output, ep_send_counts[:self.ep_recv_count_size])
+
+
+# Equivalent path using the single fused_deep_moe operator.
+class FusionOp(DecodeMoeOps):
+
+    def __init__(self,
+                 gmm1_weight,
+                 gmm1_weight_scale,
+                 gmm2_weight,
+                 gmm2_weight_scale,
+                 ep_hcomm_info,
+                 batch_size,
+                 token_hidden_size,
+                 moe_intermediate_size,
+                 ep_world_size,
+                 moe_expert_num,
+                 global_rank_id,
+                 shared_expert_rank_num=0):
+        super().__init__(gmm1_weight, gmm1_weight_scale, gmm2_weight,
+                         gmm2_weight_scale, ep_hcomm_info, batch_size,
+                         token_hidden_size, moe_intermediate_size,
+                         ep_world_size, moe_expert_num, global_rank_id,
+                         shared_expert_rank_num)
+
+    def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale,
+                                       gmm2_weight, gmm2_weight_scale):
+        gmm1_weight = gmm1_weight.transpose(1, 2).contiguous()\
+            .view(self.local_expert_num, 2, self.moe_intermediate_size // 64, 64, self.token_hidden_size)\
+            .transpose(1, 2).contiguous()\
+            .view(self.local_expert_num, self.moe_intermediate_size * 2, self.token_hidden_size)\
+            .transpose(1, 2).contiguous()
+        gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
+                                                torch_npu.Format.ND)
+        gmm1_weight.add_(0)
+        gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
+                                                torch_npu.Format.FRACTAL_NZ)
+        gmm1_weight_scale = permute_weight(gmm1_weight_scale, 128)
+        gmm2_weight = torch_npu.npu_format_cast(
+            gmm2_weight.transpose(1, 2).contiguous(),
+            torch_npu.Format.FRACTAL_NZ)
+
+        gmm1_weight_scale = gmm1_weight_scale.float()
+        gmm2_weight_scale = gmm2_weight_scale.float()
+
+        self.gmm1_weight = torch.nn.Parameter(gmm1_weight, requires_grad=False)
+        self.gmm1_weight_scale = torch.nn.Parameter(gmm1_weight_scale,
+                                                    requires_grad=False)
+        self.gmm2_weight = torch.nn.Parameter(gmm2_weight, requires_grad=False)
+        self.gmm2_weight_scale = torch.nn.Parameter(gmm2_weight_scale,
+                                                    requires_grad=False)
+
+    def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
+        output = torch.ops.umdk_cam_op_lib.fused_deep_moe(
+            x=x,
+            expertIds=expert_ids,
+            gmm1PermutedWeight=self.gmm1_weight,
+            gmm1PermutedWeightScale=self.gmm1_weight_scale,
+            gmm2Weight=self.gmm2_weight,
+            gmm2WeightScale=self.gmm2_weight_scale,
+            expertSmoothScalesOptional=smooth_scales,
+            expertScalesOptional=expert_scales,
+            groupEp=self.ep_hcomm_info,
+            epRankSize=self.ep_world_size,
+            epRankId=self.global_rank_id,
+            moeExpertNum=self.moe_expert_num,
+            sharedExpertNum=1,
+            sharedExpertRankNum=self.shared_expert_rank_num,
+            quantMode=0,
+            globalBs=self.batch_size * self.ep_world_size)
+        return output
+
+
+def generate_datas(batch_size,
+                   token_hidden_size,
+                   moe_intermediate_size,
+                   ep_world_size,
+                   moe_expert_num,
+                   global_rank_id,
+                   shared_expert_rank_num=0,
+                   top_k=8,
+                   test_bfloat16=True,
+                   enable_dynamic_bs=False):
+    is_shared_expert = global_rank_id < shared_expert_rank_num
+    moe_expert_num_per_rank = moe_expert_num // (ep_world_size -
+                                                 shared_expert_rank_num)
+    actual_bs = int(
+        torch.randint(1, batch_size, [1]).item()
+        if enable_dynamic_bs else batch_size)
+    local_expert_num = 1 if is_shared_expert else moe_expert_num_per_rank
+    gmm1_input_dim = token_hidden_size
+    gmm1_output_dim = moe_intermediate_size * 2
+    gmm2_input_dim = moe_intermediate_size
+    gmm2_output_dim = token_hidden_size
+    x = torch.rand([actual_bs, token_hidden_size]) * 10 - 5
+    expert_ids = torch.arange(
+        global_rank_id * batch_size * top_k,
+        global_rank_id * batch_size * top_k + actual_bs * top_k).to(
+            torch.int32).view(actual_bs, top_k)
+    expert_ids = expert_ids % moe_expert_num
+    if is_shared_expert:
+        gmm1_weight = torch.ones([
+            local_expert_num, gmm1_input_dim, gmm1_output_dim
+        ]).to(torch.int8) * 4
+        gmm2_weight = torch.ones([
+            local_expert_num, gmm2_input_dim, gmm2_output_dim
+        ]).to(torch.int8) * 4
+        gmm1_weight[:, :, ::2] = gmm1_weight[:, :, ::2] * -1
+        gmm2_weight[:, :, ::2] = gmm2_weight[:, :, ::2] * -1
+        gmm1_weight_scale = torch.ones(
+            [local_expert_num, gmm1_output_dim]) * 0.0015
+        gmm2_weight_scale = torch.ones(
+            [local_expert_num, gmm2_output_dim]) * 0.0015
+    else:
+        gmm1_weight = torch.randint(
+            -16, 16,
+            [local_expert_num, gmm1_input_dim, gmm1_output_dim]).to(torch.int8)
+        gmm2_weight = torch.randint(
+            -16, 16,
+            [local_expert_num, gmm2_input_dim, gmm2_output_dim]).to(torch.int8)
+        gmm1_weight_scale = torch.rand(
+            [local_expert_num, gmm1_output_dim]) * 0.003 + 0.0015
+        gmm2_weight_scale = torch.rand(
+            [local_expert_num, gmm2_output_dim]) * 0.003 + 0.0015
+    expert_scales = torch.rand(actual_bs, top_k)
+    if test_bfloat16:
+        x = x.bfloat16()
+        gmm1_weight_scale = gmm1_weight_scale.bfloat16()
+        gmm2_weight_scale = gmm2_weight_scale.bfloat16()
+    else:
+        x = x.half()
+    smooth_scales = None
+    return (x, expert_ids, smooth_scales, expert_scales), \
+        (gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale), \
+        actual_bs
+
+
+def run_once(local_rank_id,
+             batch_size,
+             token_hidden_size,
+             moe_intermediate_size,
+             ep_world_size,
+             moe_expert_num,
+             shared_expert_rank_num=0,
+             top_k=8,
+             test_bfloat16=True,
+             enable_dynamic_bs=False):
+    # Configure the log output file name.
+    log_file = (redirect_output(f"local_rank_{local_rank_id}.log")
+                if output_to_file(local_rank_id) else None)
+    # Tested on a single A3 node with 16 dies.
+    global_rank_id = local_rank_id
+    device_id = local_rank_id % 16
+    torch_npu.npu.set_device(device_id)
+
+    # Initialize the distributed environment.
+    os.environ["MASTER_ADDR"] = "127.0.0.1"
+    os.environ["MASTER_PORT"] = "29500"  # any free port will do
+    dist.init_process_group(backend="hccl",
+                            rank=local_rank_id,
+                            world_size=ep_world_size)
+    ep_ranks_list = list(np.arange(0, ep_world_size))
+    ep_group = dist.new_group(backend="hccl", ranks=ep_ranks_list)
+    ep_group_small = dist.new_group(backend="hccl", ranks=ep_ranks_list)
+
+    ep_hcomm_info_fused = ep_group._get_backend(
+        torch.device("npu")).get_hccl_comm_name(local_rank_id)
+    ep_hcomm_info_small = ep_group_small._get_backend(
+        torch.device("npu")).get_hccl_comm_name(local_rank_id)
+    torch_npu.npu.synchronize(device_id)
+
+    # Build the required parameters and weight data.
+    parameter = (batch_size, token_hidden_size, moe_intermediate_size,
+                 ep_world_size, moe_expert_num, global_rank_id,
+                 shared_expert_rank_num)
+    input_datas, weight_datas, actual_bs = generate_datas(
+        *parameter, top_k, test_bfloat16, enable_dynamic_bs)
+    input_datas = [
+        data.npu() if data is not None else None for data in input_datas
+    ]
+    weight_datas = [
+        data.npu() if data is not None else None for data in weight_datas
+    ]
+    small_ops = SmallOps(*weight_datas, ep_hcomm_info_small,
+                         *parameter).npu()  # type: ignore
+    fused_ops = FusionOp(*weight_datas, ep_hcomm_info_fused,
+                         *parameter).npu()  # type: ignore
+    small_op_token_output, small_op_count_output = small_ops(*input_datas)
+    fused_op_token_output, fused_op_count_output = fused_ops(*input_datas)
+    torch_npu.npu.synchronize(device_id)
+
+    # Clean up resources.
+    dist.destroy_process_group()
+    if log_file is not None:
+        log_file.close()
+    small_op_count_output = from_inclusive_prefix_sum(small_op_count_output)
+    torch.testing.assert_close(small_op_token_output.cpu(),
+                               fused_op_token_output.cpu(),
+                               atol=2.0,
+                               rtol=0.02)
+    torch.testing.assert_close(small_op_count_output.cpu(),
+                               fused_op_count_output.cpu())
+    gc.collect()
+    torch.npu.empty_cache()
+    torch.npu.reset_peak_memory_stats()
+
+
+@torch.inference_mode()
+def test():
+    batch_size = 64
+    token_hidden_size = 7168
+    moe_intermediate_size = 2048
+    ep_world_size = 16
+    moe_expert_num = 64
+    shared_expert_rank_num = 0
+    top_k = 8
+    test_bfloat16 = True
+    enable_dynamic_bs = False
+    args = (batch_size, token_hidden_size, moe_intermediate_size,
+            ep_world_size, moe_expert_num, shared_expert_rank_num, top_k,
+            test_bfloat16, enable_dynamic_bs)
+    mp.spawn(run_once, args=args, nprocs=ep_world_size, join=True)
+
+if __name__ == "__main__":
+    test()
\ No newline at end of file
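
Usage note for the new sample (a minimal launch sketch, assuming a single A3 node with 16 visible NPU dies and the umdk_cam_op_lib wheel built by the setup.py above already installed): test() hard-codes ep_world_size = 16, maps each spawned rank to one die via device_id = local_rank_id % 16, and launches all ranks itself through mp.spawn, so no external launcher is needed:

    python src/cam/examples/fused_deep_moe_sample.py

Each rank builds the same weights for SmallOps and FusionOp and checks that the fused output matches the small-operator pipeline within atol=2.0 / rtol=0.02; per-rank logs are written under fused_deep_moe_sample_logs/ once output_to_file() is changed to return True.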