From 7966e3ca2af4baee452c8c40e8f3fedad9a6ccd8 Mon Sep 17 00:00:00 2001 From: fighting-ye <1138455646@qq.com> Date: Wed, 28 Aug 2024 14:55:14 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dbert=20gpu=E6=80=A7=E8=83=BD?= =?UTF-8?q?=E8=A3=82=E5=8C=96=EF=BC=8C=E9=80=82=E9=85=8D=E8=9E=8D=E5=90=88?= =?UTF-8?q?=E5=BC=80=E5=85=B3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- official/nlp/Bert/run_pretrain.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/official/nlp/Bert/run_pretrain.py b/official/nlp/Bert/run_pretrain.py index 53d927dc1..9478b5a35 100644 --- a/official/nlp/Bert/run_pretrain.py +++ b/official/nlp/Bert/run_pretrain.py @@ -221,8 +221,13 @@ def run_pretrain(): ckpt_save_dir = os.path.join(cfg.save_checkpoint_path, 'ckpt_' + str(get_rank())) mindspore.reset_auto_parallel_context() - mindspore.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, - device_num=device_num) + if cfg.device_target == 'GPU': + comm_fusion_dict = {"allreduce": {"mode": "auto", "config": None}} + mindspore.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, comm_fusion=comm_fusion_dict, + gradients_mean=True, device_num=device_num) + else: + mindspore.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, + device_num=device_num) _set_bert_all_reduce_split() _check_compute_type(cfg) -- Gitee