From 13362eef30663bc0171ff860bf9030b4bf1a7b88 Mon Sep 17 00:00:00 2001
From: l00486551
Date: Thu, 16 Sep 2021 10:40:40 +0800
Subject: [PATCH] add fat_deepffm GPU support

---
 research/recommend/Fat-DeepFFM/README.md    | 106 +++++++++++++-----
 .../scripts/run_distribute_train_gpu.sh     |  60 ++++++++++
 research/recommend/Fat-DeepFFM/train.py     |  38 ++++---
 3 files changed, 163 insertions(+), 41 deletions(-)
 create mode 100644 research/recommend/Fat-DeepFFM/scripts/run_distribute_train_gpu.sh

diff --git a/research/recommend/Fat-DeepFFM/README.md b/research/recommend/Fat-DeepFFM/README.md
index 95f456819..8cbd4e440 100644
--- a/research/recommend/Fat-DeepFFM/README.md
+++ b/research/recommend/Fat-DeepFFM/README.md
@@ -97,11 +97,36 @@ Fat - DeepFFM consists of three parts. The FFM component is a factorization mach
   [hccl tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
 
+- running on GPU
+
+  ```shell
+  # run training example
+  python train.py \
+    --dataset_path='data/mindrecord' \
+    --ckpt_path='./checkpoint/Fat-DeepFFM' \
+    --eval_file_name='./auc.log' \
+    --loss_file_name='./loss.log' \
+    --device_target='GPU' \
+    --do_eval=True > output.log 2>&1 &
+
+  # run distributed training example
+  bash scripts/run_distribute_train_gpu.sh 8 /dataset_path
+
+  # run evaluation example
+  python eval.py \
+    --dataset_path='dataset/mindrecord' \
+    --ckpt_path='./checkpoint/Fat-DeepFFM.ckpt' \
+    --device_target='GPU' \
+    --device_id=0 > eval_output.log 2>&1 &
+  OR
+  bash scripts/run_eval.sh 0 GPU /dataset_path /ckpt_path
+  ```
+
 # [Script Description](#contents)
 
 ## [Script and Sample Code](#contents)
 
-```path
+```bash
 .
 └─Fat-deepffm
   ├─README.md
@@ -109,6 +134,7 @@ Fat - DeepFFM consists of three parts. The FFM component is a factorization mach
   ├─scripts
     ├─run_alone_train.sh          # launch standalone training(1p) in Ascend
     ├─run_distribute_train.sh     # launch distributed training(8p) in Ascend
+    ├─run_distribute_train_gpu.sh # launch distributed training(8p) in GPU
     └─run_eval.sh                 # launch evaluating in Ascend
   ├─src
     ├─config.py                   # parameter configuration
@@ -189,6 +215,24 @@ Parameters for both training and evaluation can be set in config.py
 
   The model checkpoint will be saved in the current directory.
 
+- running on GPU
+
+  ```shell
+  bash scripts/run_alone_train.sh [DATASET_PATH] [DEVICE_TARGET] [DO_EVAL]
+  ```
+
+  After training, you'll get some checkpoint files under the `./checkpoint` folder by default. The loss values are saved in the loss.log file.
+
+  ```log
+  2021-06-19 21:59:10 epoch: 1 step: 5166, loss is 0.46262410283088684
+  2021-06-19 22:12:13 epoch: 2 step: 5166, loss is 0.4792023301124573
+  2021-06-19 22:21:03 epoch: 3 step: 5166, loss is 0.4666571617126465
+  2021-06-19 22:29:54 epoch: 4 step: 5166, loss is 0.44029417634010315
+  ...
+  ```
+
+  The model checkpoint will be saved in the current directory.
+
 ### Distributed Training
 
 - running on Ascend
@@ -199,6 +243,12 @@ Parameters for both training and evaluation can be set in config.py
   The above shell script will run distribute training in the background. You can view the results through the file `log[X]/output.log`. The loss value are saved in loss.log file.
 
+- running on GPU
+
+  ```shell
+  bash scripts/run_distribute_train_gpu.sh 8 /dataset_path
+  ```
+
 ## [Evaluation Process](#contents)
 
 ### Evaluation
 
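The loss.log excerpt above follows a fixed pattern: a timestamp, the epoch and step counters, and the loss value. To skim a finished run, a few lines of Python are enough. The sketch below assumes every line matches exactly that pattern and that the file sits in the current directory (both assumptions; adjust the path and regex to your setup):

```python
# Summarize a loss.log written by train.py (via --loss_file_name).
# Expected line format, as shown in the README excerpt above:
#   2021-06-19 21:59:10 epoch: 1 step: 5166, loss is 0.46262410283088684
import re

LINE = re.compile(r"epoch: (\d+) step: (\d+), loss is ([0-9.]+)")

with open("loss.log") as log:
    for row in log:
        match = LINE.search(row)
        if match:  # skip any line that does not match the expected format
            epoch, step, loss = match.groups()
            print(f"epoch {int(epoch):3d}  step {step}  loss {float(loss):.4f}")
```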
@@ -261,37 +311,37 @@ Inference result is saved in current path, you can find result like this in acc.log file.
 
 ### Training Performance
 
-| Parameters                 | Ascend                                                      |
-| -------------------------- | ----------------------------------------------------------- |
-| Model Version              | Fat-DeepFFM                                                 |
-| Resource                   | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 |
-| uploaded Date              | 09/15/2020 (month/day/year)                                 |
-| MindSpore Version          | 1.2.0                                                       |
-| Dataset                    | Criteo                                                      |
-| Training Parameters        | epoch=30, batch_size=1000, lr=1e-4                          |
-| Optimizer                  | Adam                                                        |
-| Loss Function              | Sigmoid Cross Entropy With Logits                           |
-| outputs                    | AUC                                                         |
-| Loss                       | 0.45                                                        |
-| Speed                      | 1pc: 8.16 ms/step;                                          |
-| Total time                 | 1pc: 4 hours;                                               |
-| Parameters (M)             | 560.34                                                      |
-| Checkpoint for Fine tuning | 87.65M (.ckpt file)                                         |
+| Parameters                 | Ascend                                                      | GPU                                |
+| -------------------------- | ----------------------------------------------------------- | ---------------------------------- |
+| Model Version              | Fat-DeepFFM                                                 | Fat-DeepFFM                        |
+| Resource                   | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 | NV SMX2 V100-32G                   |
+| uploaded Date              | 09/15/2020 (month/day/year)                                 | 09/04/2021 (month/day/year)        |
+| MindSpore Version          | 1.2.0                                                       | 1.3.0                              |
+| Dataset                    | Criteo                                                      | Criteo                             |
+| Training Parameters        | epoch=30, batch_size=1000, lr=1e-4                          | epoch=30, batch_size=1000, lr=1e-4 |
+| Optimizer                  | Adam                                                        | Adam                               |
+| Loss Function              | Sigmoid Cross Entropy With Logits                           | Sigmoid Cross Entropy With Logits  |
+| outputs                    | AUC                                                         | AUC                                |
+| Loss                       | 0.45                                                        | 0.43                               |
+| Speed                      | 1pc: 8.16 ms/step;                                          | N/A                                |
+| Total time                 | 1pc: 4 hours;                                               | 1pc: 10 hours;                     |
+| Parameters (M)             | 560.34                                                      | 560.34                             |
+| Checkpoint for Fine tuning | 87.65M (.ckpt file)                                         | 89.35M (.ckpt file)                |
 | Scripts                    | [deepfm script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/recommend/Fat-DeepFFM) |
 
 ### Inference Performance
 
-| Parameters          | Ascend                      |
-| ------------------- | --------------------------- |
-| Model Version       | DeepFM                      |
-| Resource            | Ascend 910; OS Euler2.8     |
-| Uploaded Date       | 06/20/2021 (month/day/year) |
-| MindSpore Version   | 1.2.0                       |
-| Dataset             | Criteo                      |
-| batch_size          | 1000                        |
-| outputs             | AUC                         |
-| AUC                 | 1pc: 80.90%;                |
-| Model for inference | 87.65M (.ckpt file)         |
+| Parameters          | Ascend                      | GPU                            |
+| ------------------- | --------------------------- | ------------------------------ |
+| Model Version       | Fat-DeepFFM                 | Fat-DeepFFM                    |
+| Resource            | Ascend 910; OS Euler2.8     | NV SMX2 V100-32G; OS Euler3.10 |
+| Uploaded Date       | 06/20/2021 (month/day/year) | 09/04/2021 (month/day/year)    |
+| MindSpore Version   | 1.2.0                       | 1.3.0                          |
+| Dataset             | Criteo                      | Criteo                         |
+| batch_size          | 1000                        | 1000                           |
+| outputs             | AUC                         | AUC                            |
+| AUC                 | 1pc: 80.90%;                | 1pc: 79.54%;                   |
+| Model for inference | 87.65M (.ckpt file)         | 89.35M (.ckpt file)            |
 
 # [Description of Random Situation](#contents)
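Both performance tables report AUC as the evaluation output. In the repository this is computed by the `AUCMetric` class in `src/metrics.py`; the sketch below only illustrates the general shape of such a metric, is not the repository implementation, and assumes NumPy and scikit-learn are installed:

```python
# Illustrative AUC metric in the spirit of src/metrics.py's AUCMetric.
# Not the repository code; roc_auc_score requires scikit-learn.
import numpy as np
from sklearn.metrics import roc_auc_score


class SketchAUCMetric:
    """Accumulate sigmoid outputs and labels, then compute ROC-AUC."""

    def __init__(self):
        self.clear()

    def clear(self):
        self.preds, self.labels = [], []

    def update(self, preds, labels):
        # Flatten each batch so batches of any shape can be concatenated.
        self.preds.append(np.asarray(preds).reshape(-1))
        self.labels.append(np.asarray(labels).reshape(-1))

    def eval(self):
        return roc_auc_score(np.concatenate(self.labels),
                             np.concatenate(self.preds))
```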
diff --git a/research/recommend/Fat-DeepFFM/scripts/run_distribute_train_gpu.sh b/research/recommend/Fat-DeepFFM/scripts/run_distribute_train_gpu.sh
new file mode 100644
index 000000000..426cbada2
--- /dev/null
+++ b/research/recommend/Fat-DeepFFM/scripts/run_distribute_train_gpu.sh
@@ -0,0 +1,60 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+echo "Please run the script as: "
+echo "bash scripts/run_distribute_train_gpu.sh DEVICE_NUM DATASET_PATH"
+echo "for example: bash scripts/run_distribute_train_gpu.sh 8 /dataset_path"
+echo "After running the script, the network runs in the background. The log will be generated in log/output.log"
+
+if [ $# != 2 ]; then
+  echo "Usage: bash scripts/run_distribute_train_gpu.sh [DEVICE_NUM] [DATASET_PATH]"
+  exit 1
+fi
+
+get_real_path() {
+  if [ "${1:0:1}" == "/" ]; then
+    echo "$1"
+  else
+    echo "$(realpath -m "$PWD/$1")"
+  fi
+}
+
+dataset_path=$(get_real_path "$2")
+echo "$dataset_path"
+
+if [ ! -d "$dataset_path" ]
+then
+  echo "error: dataset_path=$dataset_path is not a directory."
+  exit 1
+fi
+
+export RANK_SIZE=$1
+export DATA_URL=$dataset_path
+
+rm -rf log
+mkdir ./log
+mkdir ./log/ckpt
+cp *.py ./log
+cp -r ./src ./log
+cd ./log || exit
+env > env.log
+mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
+  python -u train.py \
+    --dataset_path=$DATA_URL \
+    --ckpt_path="./ckpt/Fat-DeepFFM" \
+    --eval_file_name='auc.log' \
+    --loss_file_name='loss.log' \
+    --device_target='GPU' \
+    --do_eval=True > output.log 2>&1 &
\ No newline at end of file
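The launch script above verifies only that `DATASET_PATH` is a directory; unreadable or misnamed MindRecord files would surface only after all mpirun processes have started. A quick single-process check can catch this earlier. The file name in this sketch is a placeholder, not a repository file; use whatever your Criteo preprocessing step actually produced under the dataset directory:

```python
# Sanity-check one MindRecord file before a multi-process launch.
# "train_input_part.mindrecord0" is a placeholder name -- adjust it.
import mindspore.dataset as ds

dataset = ds.MindDataset("/dataset_path/train_input_part.mindrecord0")
print("samples:", dataset.get_dataset_size())

# Peek at one record to confirm the expected columns are present.
sample = next(dataset.create_dict_iterator(num_epochs=1, output_numpy=True))
print("columns:", {name: value.shape for name, value in sample.items()})
```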
diff --git a/research/recommend/Fat-DeepFFM/train.py b/research/recommend/Fat-DeepFFM/train.py
index 4c437a467..5cca2cd5e 100644
--- a/research/recommend/Fat-DeepFFM/train.py
+++ b/research/recommend/Fat-DeepFFM/train.py
@@ -24,7 +24,7 @@ from src.fat_deepffm import ModelBuilder
 from src.metrics import AUCMetric
 from mindspore import context, Model
 from mindspore.common import set_seed
-from mindspore.communication.management import init, get_rank
+from mindspore.communication.management import init, get_rank, get_group_size
 from mindspore.context import ParallelMode
 from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
 
@@ -48,17 +48,34 @@ set_seed(1)
 if __name__ == '__main__':
     model_config = ModelConfig()
     if rank_size > 1:
-        device_id = int(os.getenv('DEVICE_ID'))
-        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
-                            device_id=device_id)
-        context.reset_auto_parallel_context()
-        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          gradients_mean=True,
-                                          all_reduce_fusion_config=[9, 11])
-        init()
-        rank_id = get_rank()
+        if args.device_target == "Ascend":
+            # DEVICE_ID is exported per process by the Ascend launcher.
+            device_id = int(os.getenv('DEVICE_ID'))
+            context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
+                                device_id=device_id)
+            context.reset_auto_parallel_context()
+            context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
+                                              gradients_mean=True,
+                                              all_reduce_fusion_config=[9, 11])
+            init()
+            rank_id = get_rank()
+        elif args.device_target == "GPU":
+            # Under mpirun, init() sets up NCCL and assigns one rank per process.
+            init()
+            context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
+            context.reset_auto_parallel_context()
+            context.set_auto_parallel_context(device_num=get_group_size(),
+                                              parallel_mode=ParallelMode.DATA_PARALLEL,
+                                              gradients_mean=True)
+            rank_id = get_rank()
+        else:
+            raise ValueError("Unsupported device_target: {}".format(args.device_target))
     else:
-        device_id = int(os.getenv('DEVICE_ID'))
-        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
-                            device_id=device_id)
+        if args.device_target == "Ascend":
+            device_id = int(os.getenv('DEVICE_ID'))
+            context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
+                                device_id=device_id)
+        else:
+            # Single-card GPU runs do not rely on the DEVICE_ID environment variable.
+            context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
         rank_size = None
@@ -72,12 +89,12 @@ if __name__ == '__main__':
     time_callback = TimeMonitor(data_size=ds_train.get_dataset_size())
     loss_callback = LossCallback(args.loss_file_name)
     cb = [loss_callback, time_callback]
-    if rank_size == 1 or device_id == 0:
+    if rank_size is None or rank_id == 0:
         config_ck = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size() * model_config.epoch_size,
                                      keep_checkpoint_max=model_config.keep_checkpoint_max)
         ckpoint_cb = ModelCheckpoint(prefix=args.ckpt_path, config=config_ck)
         cb += [ckpoint_cb]
-    if args.do_eval and device_id == 0:
+    if args.do_eval and (rank_size is None or rank_id == 0):
         ds_test = get_mindrecord_dataset(args.dataset_path, train_mode=False)
         eval_callback = AUCCallBack(model, ds_test, eval_file_path=args.eval_file_name)
         cb.append(eval_callback)
-- 
Gitee
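For readers porting other models, the GPU branch added to train.py reduces to a small initialization pattern: create the communication group first, then declare data parallelism over `get_group_size()` ranks. Below is a minimal sketch of those same calls, meant to be launched under mpirun as the new script does; it is an illustration, not a replacement for train.py:

```python
# Minimal GPU data-parallel setup, mirroring the train.py change above.
# Launch with, e.g.: mpirun -n 8 python this_file.py
from mindspore import context
from mindspore.communication.management import get_group_size, get_rank, init
from mindspore.context import ParallelMode

init()  # NCCL initialization; mpirun gives each process one rank
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=get_group_size(),
                                  parallel_mode=ParallelMode.DATA_PARALLEL,
                                  gradients_mean=True)
print("initialized rank", get_rank(), "of", get_group_size())
```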