From 0b6f8e884acfd9c5ae485bed548309c848be82fd Mon Sep 17 00:00:00 2001
From: PaddlePaddle-Gardener
Date: Wed, 12 Jan 2022 14:43:21 +0800
Subject: [PATCH] mirgate_38785

---
 python/paddle/fluid/dygraph/amp/auto_cast.py  | 157 +++++++++---------
 python/paddle/fluid/dygraph/layers.py         |   2 +
 .../test_imperative_auto_mixed_precision.py   | 103 ++++++++----
 3 files changed, 153 insertions(+), 109 deletions(-)

diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py
index ddde3e66c5..f09e210c3c 100644
--- a/python/paddle/fluid/dygraph/amp/auto_cast.py
+++ b/python/paddle/fluid/dygraph/amp/auto_cast.py
@@ -71,7 +71,9 @@ AMP_RELATED_FLAGS_SETTING = {
 }
 
 PURE_FP16_WHITE_LIST = {' '}
-PURE_FP16_BLACK_LIST = {'lookup_table', 'lookup_table_v2'}
+PURE_FP16_BLACK_LIST = {
+    'lookup_table', 'lookup_table_v2', 'scatter', 'scatter_grad'
+}
 
 #NOTE(zhiqiu): similar as paddle.fluid.contrib.mixed_precision.fp16_lists.AutoMixedPrecisionLists._update_list
@@ -118,57 +120,23 @@ def _in_amp_guard():
         return False
 
 
-@dygraph_only
-def pure_fp16_initialize(enable_pure_fp16, models, optimizers):
-    if not enable_pure_fp16:
-        return models, optimizers
+def _in_pure_fp16_guard():
+    tracer = _dygraph_tracer()
+    return tracer and tracer._amp_level == core.AmpLevel.O2
+
+
+@dygraph_only
+def pure_fp16_initialize(models):
     for idx in range(len(models)):
         for layer in models[idx].sublayers(include_self=True):
             layer._casted_by_pure_fp16 = True
-            if len(layer._sub_layers) is 0:
-
-                if (layer._dtype is 'float16') or isinstance(layer, (
-                        paddle.nn.BatchNorm, paddle.nn.LayerNorm)):
-                    continue
-                layer.to(dtype='float16')
-
-    for idx_opt in range(len(optimizers)):
-        # update _param_groups
-        if getattr(optimizers[idx_opt], '_param_groups', None) and isinstance(
-                optimizers[idx_opt]._param_groups[0], dict):
-            for param_group in optimizers[idx_opt]._param_groups:
-                for i, param in enumerate(param_group['params']):
-                    for idx_model in range(len(models)):
-                        for layer in models[idx_model].sublayers(
-                                include_self=True):
-                            if id(param) in layer._parameters_transform_map:
-                                param_group['params'][
-                                    i] = layer._parameters_transform_map[id(
-                                        param)][0]
-            for param_group in optimizers[idx_opt]._parameter_list:
-                params = param_group['params']
-                for i, param in enumerate(params):
-                    for idx_model in range(len(models)):
-                        for layer in models[idx_model].sublayers(
-                                include_self=True):
-                            if id(param) in layer._parameters_transform_map:
-                                params[i] = layer._parameters_transform_map[id(
-                                    param)][0]
-        # update _parameter_list
-        else:
-            for i, param in enumerate(optimizers[idx_opt]._parameter_list):
-                for idx_model in range(len(models)):
-                    for layer in models[idx_model].sublayers(include_self=True):
-                        if id(param) in layer._parameters_transform_map:
-                            optimizers[idx_opt]._parameter_list[
-                                i] = layer._parameters_transform_map[id(param)][
-                                    0]
-                            if hasattr(optimizers[idx_opt], '_param_groups'):
-                                optimizers[idx_opt]._param_groups[
-                                    i] = layer._parameters_transform_map[id(
-                                        param)][0]
-    return models, optimizers
+            if (layer._dtype is 'float16') or isinstance(
+                    layer, (paddle.nn.BatchNorm, paddle.nn.BatchNorm1D,
+                            paddle.nn.BatchNorm2D, paddle.nn.BatchNorm3D,
+                            paddle.nn.LayerNorm)):
+                continue
+            layer._to_impl(dtype='float16', include_sublayers=False)
+    return models
 
 
 def check_models(models):
@@ -177,6 +145,10 @@
             raise RuntimeError(
                 "Current train mode is pure fp16, models should be paddle.nn.Layer, but receive {}.".
                 format(type(model)))
+        if isinstance(model, paddle.DataParallel):
+            raise RuntimeError(
+                "For distributed AMP training, you should first use paddle.amp.decorate() to decorate the origin model, and then call paddle.DataParallel to get the distributed model."
+            )
 
 
 def check_optimizers(optimizers):
@@ -252,6 +224,13 @@
             % tracer._expected_place)
         enable = False
 
+    if tracer._expected_place.is_gpu_place():
+        prop = paddle.device.cuda.get_device_capability()
+        if prop[0] < 7:
+            warnings.warn(
+                "AMP only supports NVIDIA GPU with Compute Capability 7.0 or higher, current GPU is: %s, with Compute Capability: %d.%d."
+                % (paddle.device.cuda.get_device_name(), prop[0], prop[1]))
+
     if level == 'O1':
         amp_level = AMP_LEVEL.O1
         _white_list = WHITE_LIST
@@ -366,6 +345,19 @@
             output2 = models[1](data)
             print(output.dtype) # FP16
             print(output2.dtype) # FP16
+
+        # required: gpu
+        # Demo3: optimizers is None:
+        model3 = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
+        optimizer3 = paddle.optimizer.Adam(parameters=model3.parameters())
+
+        model = paddle.fluid.dygraph.amp_decorate(models=model3, level='O2')
+
+        data = paddle.rand([10, 3, 32, 32])
+
+        with paddle.fluid.dygraph.amp_guard(enable=True, custom_white_list=None, custom_black_list=None, level='O2'):
+            output = model(data)
+            print(output.dtype) # FP16
     """
     if not (level in ['O1', 'O2']):
         raise ValueError(
@@ -373,7 +365,10 @@ def amp_decorate(models,
         )
 
     if level == 'O1':
-        return models, optimizers
+        if optimizers is None:
+            return models
+        else:
+            return models, optimizers
 
     models_is_list = False
     if isinstance(models, paddle.nn.Layer):
@@ -387,30 +382,30 @@
         raise TypeError(
             "models must be either a single model or a list of models.")
 
-    optimizers_is_list = False
-    if isinstance(optimizers, (paddle.optimizer.Optimizer,
-                               paddle.fluid.optimizer.Optimizer)):
-        optimizers_is_list = False
-        optimizers = [optimizers]
-        check_optimizers(optimizers)
-    elif isinstance(optimizers, list):
-        check_optimizers(optimizers)
-        optimizers_is_list = True
-    else:
-        raise TypeError(
-            "optimizers must be either a single optimizer or a list of optimizers."
-        )
+    models = pure_fp16_initialize(models=models)
 
-    models, optimizers = pure_fp16_initialize(
-        enable_pure_fp16=True, models=models, optimizers=optimizers)
-
-    # supprot master_weight
-    for idx_opt in range(len(optimizers)):
-        if hasattr(optimizers[idx_opt], '_multi_precision'):
-            if master_weight is False:
-                optimizers[idx_opt]._multi_precision = False
-            else:
-                optimizers[idx_opt]._multi_precision = True
+    if optimizers is not None:
+        # check optimizers
+        optimizers_is_list = False
+        if isinstance(optimizers, (paddle.optimizer.Optimizer,
+                                   paddle.fluid.optimizer.Optimizer)):
+            optimizers_is_list = False
+            optimizers = [optimizers]
+            check_optimizers(optimizers)
+        elif isinstance(optimizers, list):
+            check_optimizers(optimizers)
+            optimizers_is_list = True
+        else:
+            raise TypeError(
+                "optimizers must be either a single optimizer or a list of optimizers."
+            )
+        # support master_weight
+        for idx_opt in range(len(optimizers)):
+            if hasattr(optimizers[idx_opt], '_multi_precision'):
+                if master_weight is False:
+                    optimizers[idx_opt]._multi_precision = False
+                else:
+                    optimizers[idx_opt]._multi_precision = True
 
     if save_dtype is not None:
         if not (save_dtype in ['float16', 'float32', 'float64']):
             raise ValueError(
@@ -422,12 +417,18 @@
         layer.register_state_dict_hook(StateDictHook(save_dtype))
 
     if models_is_list:
-        if optimizers_is_list:
-            return models, optimizers
+        if optimizers is not None:
+            if optimizers_is_list:
+                return models, optimizers
+            else:
+                return models, optimizers[0]
         else:
-            return models, optimizers[0]
+            return models
     else:
-        if optimizers_is_list:
-            return models[0], optimizers
+        if optimizers is not None:
+            if optimizers_is_list:
+                return models[0], optimizers
+            else:
+                return models[0], optimizers[0]
         else:
-            return models[0], optimizers[0]
+            return models[0]
diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py
index 4c37a378e0..6a65b3bd9c 100644
--- a/python/paddle/fluid/dygraph/layers.py
+++ b/python/paddle/fluid/dygraph/layers.py
@@ -1569,6 +1569,8 @@ class Layer(object):
         for key, buf in self._buffers.items():
             self._buffers[key] = func(buf, device, dtype, blocking)
 
+        self._dtype = dtype
+
     def _to_impl(self,
                  device=None,
                  dtype=None,
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
index 5f1f4a4641..62b40f8857 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
@@ -524,38 +524,37 @@ class TestAmpDecorator(unittest.TestCase):
         self.assertRaises(ValueError, func)
 
-    def test_input_formate_exception(self):
-        def test_model_error():
-            with fluid.dygraph.guard():
-                model = fluid.dygraph.Conv2D(3, 2, 3, bias_attr=False, act=None)
-                opt = paddle.optimizer.SGD(parameters=model.parameters())
-                paddle.amp.decorate(models=None, optimizers=opt, level='O2')
-
-        self.assertRaises(TypeError, test_model_error)
+    def test_input_type_exception(self):
+        def test_error_model():
+            class MyModel(object):
+                def __init__(self):
+                    print("A fake Model")
 
-        def test_optimizer_error():
+            model = MyModel()
             with fluid.dygraph.guard():
-                model = fluid.dygraph.Conv2D(3, 2, 3, bias_attr=False, act=None)
                 paddle.amp.decorate(models=model, optimizers=None, level='O2')
 
-        self.assertRaises(TypeError, test_optimizer_error)
+        self.assertRaises(TypeError, test_error_model)
 
-    def test_input_type_exception(self):
-        def test_error_model_optimizer():
-            class MyModel(object):
-                def __init__(self):
-                    print("A fake Model")
+        def test_error_distributed_model():
+            model = fluid.dygraph.Conv2D(3, 2, 3, bias_attr=False, act=None)
+            model = paddle.DataParallel(model)
+            with fluid.dygraph.guard():
+                model = paddle.amp.decorate(models=model, level='O2')
+
+        self.assertRaises(RuntimeError, test_error_distributed_model)
 
+        def test_error_optimizer():
             class MyOptimizer(object):
                 def __init__(self):
                     print("A fake Optimizer")
 
-            model = MyModel()
+            model = fluid.dygraph.Conv2D(3, 2, 3, bias_attr=False, act=None)
             opt = MyOptimizer()
             with fluid.dygraph.guard():
                 paddle.amp.decorate(models=model, optimizers=opt, level='O2')
 
-        self.assertRaises(TypeError, test_error_model_optimizer)
+        self.assertRaises(TypeError, test_error_optimizer)
 
     def test_set_master_weight(self):
         model1 = fluid.dygraph.Conv2D(3, 2, 3, bias_attr=False, act=None)
@@ -563,32 +562,75 @@ class TestAmpDecorator(unittest.TestCase):
             learning_rate=0.0001,
             parameters=model1.parameters(),
             multi_precision=True)
-        model1, opt1 = paddle.amp.decorate(
-            models=model1, optimizers=opt1, level='O2', master_weight=None)
-        self.assertEqual(opt1._multi_precision, True)
 
         model2 = fluid.dygraph.Conv2D(3, 2, 3, bias_attr=False, act=None)
         opt2 = paddle.optimizer.Adam(
             learning_rate=0.0001,
             parameters=model2.parameters(),
             multi_precision=False)
-        model2, opt2 = paddle.amp.decorate(
-            models=model2, optimizers=opt2, level='O2', master_weight=None)
+
+        model1, opt1 = paddle.amp.decorate(
+            models=model1, optimizers=opt1, level='O2', master_weight=None)
+        self.assertEqual(opt1._multi_precision, True)
+
+        models, opt2 = paddle.amp.decorate(
+            models=[model1, model2],
+            optimizers=opt2,
+            level='O2',
+            master_weight=None)
         self.assertEqual(opt2._multi_precision, True)
 
         model3 = fluid.dygraph.Conv2D(3, 2, 3, bias_attr=False, act=None)
         opt3 = paddle.optimizer.Adam(
             learning_rate=0.0001, parameters=model3.parameters())
-        model3, opt3 = paddle.amp.decorate(
-            models=model3, optimizers=opt3, level='O2', master_weight=True)
-        self.assertEqual(opt3._multi_precision, True)
 
         model4 = fluid.dygraph.Conv2D(3, 2, 3, bias_attr=False, act=None)
        opt4 = paddle.optimizer.Adam(
             learning_rate=0.0001, parameters=model4.parameters())
-        model4, opt4 = paddle.amp.decorate(
-            models=model4, optimizers=opt4, level='O2', master_weight=False)
-        self.assertEqual(opt4._multi_precision, False)
+
+        model3, opts = paddle.amp.decorate(
+            models=model3,
+            optimizers=[opt3, opt4],
+            level='O2',
+            master_weight=True)
+        self.assertEqual(opts[0]._multi_precision, True)
+        self.assertEqual(opts[1]._multi_precision, True)
+
+        models = [model3, model4]
+        optimizers = [opt3, opt4]
+        models, optimizers = paddle.amp.decorate(
+            models=models,
+            optimizers=optimizers,
+            level='O2',
+            master_weight=False)
+        self.assertEqual(optimizers[0]._multi_precision, False)
+        self.assertEqual(optimizers[1]._multi_precision, False)
+
+    def test_skip_BatchNorm_Layer_norm(self):
+        model = paddle.nn.LayerNorm(1)
+        model = paddle.amp.decorate(models=model, level='O2')
+        for param in model.parameters():
+            self.assertEqual((param.dtype == paddle.float32), True)
+
+        model = paddle.nn.BatchNorm(1)
+        model = paddle.amp.decorate(models=model, level='O2')
+        for param in model.parameters():
+            self.assertEqual((param.dtype == paddle.float32), True)
+
+        model = paddle.nn.BatchNorm1D(1)
+        model = paddle.amp.decorate(models=model, level='O2')
+        for param in model.parameters():
+            self.assertEqual((param.dtype == paddle.float32), True)
+
+        model = paddle.nn.BatchNorm2D(1)
+        model = paddle.amp.decorate(models=model, level='O2')
+        for param in model.parameters():
+            self.assertEqual((param.dtype == paddle.float32), True)
+
+        model = paddle.nn.BatchNorm3D(1)
+        model = paddle.amp.decorate(models=model, level='O2')
+        for param in model.parameters():
+            self.assertEqual((param.dtype == paddle.float32), True)
 
 
 class TestPureFp16SaveLoad(unittest.TestCase):
@@ -893,8 +935,7 @@ class TestResnet2(unittest.TestCase):
             train_reader = train_loader
 
         if enable_amp and (level == 'O2'):
-            resnet, optimizer = paddle.amp.decorate(
-                models=resnet, optimizers=optimizer, level='O2')
+            resnet = paddle.amp.decorate(models=resnet, level='O2')
 
         for batch_id, data in enumerate(train_reader()):
             if batch_id >= batch_num:
--
Gitee
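
After this change, paddle.amp.decorate can be called with models only: the optimizers argument becomes optional, and when it is omitted the call returns just the decorated model(s). A minimal usage sketch of that flow, adapted from Demo3 in the patched docstring; it assumes a CUDA-capable GPU with Compute Capability 7.0 or higher, and the layer sizes and input shape are illustrative only:

    import paddle

    # Decorate only the model; no optimizer is passed, so decorate() returns just the model.
    model = paddle.nn.Conv2D(3, 2, 3, bias_attr=False)
    model = paddle.amp.decorate(models=model, level='O2')

    # Run the forward pass under pure-fp16 (O2) autocast.
    data = paddle.rand([10, 3, 32, 32])
    with paddle.amp.auto_cast(enable=True, level='O2'):
        output = model(data)
    print(output.dtype)  # paddle.float16

When an optimizer (or a list of optimizers) is passed as well, decorate still returns the (models, optimizers) pair and applies the master_weight handling shown in the patch.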