From 8ff912b0a129562c7946e17a86e26468c417bff9 Mon Sep 17 00:00:00 2001 From: fighting-ye <1138455646@qq.com> Date: Thu, 1 Aug 2024 11:27:47 +0800 Subject: [PATCH] =?UTF-8?q?Bert=E4=BB=A3=E7=A0=81=E9=87=8D=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- official/nlp/Bert/pretrain_config.yaml | 18 +- .../Bert/pretrain_config_Ascend_Boost.yaml | 18 +- .../nlp/Bert/pretrain_config_Ascend_Thor.yaml | 17 +- official/nlp/Bert/src/__init__.py | 4 +- official/nlp/Bert/src/bert_model.py | 403 +++++------------- .../generate_chinesener_mindrecord.py | 2 +- .../generate_squad_mindrecord.py | 2 +- official/nlp/Bert/task_classifier_config.yaml | 5 +- .../nlp/Bert/task_classifier_cpu_config.yaml | 5 +- official/nlp/Bert/task_ner_config.yaml | 5 +- official/nlp/Bert/task_ner_cpu_config.yaml | 5 +- official/nlp/Bert/task_squad_config.yaml | 5 +- 12 files changed, 190 insertions(+), 299 deletions(-) diff --git a/official/nlp/Bert/pretrain_config.yaml b/official/nlp/Bert/pretrain_config.yaml index d4469561b..2708b1f9e 100644 --- a/official/nlp/Bert/pretrain_config.yaml +++ b/official/nlp/Bert/pretrain_config.yaml @@ -109,6 +109,10 @@ base_net_cfg: use_relative_positions: False dtype: mstype.float32 compute_type: mstype.float16 + return_all_encoders: True + has_attention_mask: True + max_relative_position: 16 + use_token_type: True # nezha nezha_batch_size: 96 nezha_net_cfg: @@ -127,6 +131,10 @@ nezha_net_cfg: use_relative_positions: True dtype: mstype.float32 compute_type: mstype.float16 + return_all_encoders: True + has_attention_mask: True + max_relative_position: 16 + use_token_type: True # large large_batch_size: 24 large_net_cfg: @@ -145,6 +153,10 @@ large_net_cfg: use_relative_positions: False dtype: mstype.float32 compute_type: mstype.float16 + return_all_encoders: True + has_attention_mask: True + max_relative_position: 16 + use_token_type: True # Accelerated large network which is only supported in Ascend yet. 
large_boost_batch_size: 24 large_boost_net_cfg: @@ -163,8 +175,10 @@ large_boost_net_cfg: use_relative_positions: False dtype: mstype.float32 compute_type: mstype.float16 - - + return_all_encoders: True + has_attention_mask: True + max_relative_position: 16 + use_token_type: True --- # Help description for each configuration enable_modelarts: "Whether training on modelarts, default: False" diff --git a/official/nlp/Bert/pretrain_config_Ascend_Boost.yaml b/official/nlp/Bert/pretrain_config_Ascend_Boost.yaml index 032ac0db4..7af63586c 100644 --- a/official/nlp/Bert/pretrain_config_Ascend_Boost.yaml +++ b/official/nlp/Bert/pretrain_config_Ascend_Boost.yaml @@ -110,6 +110,10 @@ base_net_cfg: use_relative_positions: False dtype: mstype.float32 compute_type: mstype.float16 + return_all_encoders: True + has_attention_mask: True + max_relative_position: 16 + use_token_type: True # nezha nezha_batch_size: 96 nezha_net_cfg: @@ -128,6 +132,10 @@ nezha_net_cfg: use_relative_positions: True dtype: mstype.float32 compute_type: mstype.float16 + return_all_encoders: True + has_attention_mask: True + max_relative_position: 16 + use_token_type: True # large large_batch_size: 24 large_net_cfg: @@ -146,6 +154,10 @@ large_net_cfg: use_relative_positions: False dtype: mstype.float32 compute_type: mstype.float16 + return_all_encoders: True + has_attention_mask: True + max_relative_position: 16 + use_token_type: True # Accelerated large network which is only supported in Ascend yet. 
large_boost_batch_size: 24 large_boost_net_cfg: @@ -164,8 +176,10 @@ large_boost_net_cfg: use_relative_positions: False dtype: mstype.float32 compute_type: mstype.float16 - - + return_all_encoders: True + has_attention_mask: True + max_relative_position: 16 + use_token_type: True --- # Help description for each configuration enable_modelarts: "Whether training on modelarts, default: False" diff --git a/official/nlp/Bert/pretrain_config_Ascend_Thor.yaml b/official/nlp/Bert/pretrain_config_Ascend_Thor.yaml index 31ac77f8a..e35aab0ef 100644 --- a/official/nlp/Bert/pretrain_config_Ascend_Thor.yaml +++ b/official/nlp/Bert/pretrain_config_Ascend_Thor.yaml @@ -109,6 +109,10 @@ base_net_cfg: use_relative_positions: False dtype: mstype.float32 compute_type: mstype.float16 + return_all_encoders: True + has_attention_mask: True + max_relative_position: 16 + use_token_type: True # nezha nezha_batch_size: 96 nezha_net_cfg: @@ -127,6 +131,10 @@ nezha_net_cfg: use_relative_positions: True dtype: mstype.float32 compute_type: mstype.float16 + return_all_encoders: True + has_attention_mask: True + max_relative_position: 16 + use_token_type: True # large large_batch_size: 20 large_net_cfg: @@ -145,6 +153,10 @@ large_net_cfg: use_relative_positions: False dtype: mstype.float32 compute_type: mstype.float16 + return_all_encoders: True + has_attention_mask: True + max_relative_position: 16 + use_token_type: True # Accelerated large network which is only supported in Ascend yet. 
large_boost_batch_size: 20 large_boost_net_cfg: @@ -163,7 +175,10 @@ large_boost_net_cfg: use_relative_positions: False dtype: mstype.float32 compute_type: mstype.float16 - + return_all_encoders: True + has_attention_mask: True + max_relative_position: 16 + use_token_type: True --- # Help description for each configuration diff --git a/official/nlp/Bert/src/__init__.py b/official/nlp/Bert/src/__init__.py index 85eb83764..57b62f938 100644 --- a/official/nlp/Bert/src/__init__.py +++ b/official/nlp/Bert/src/__init__.py @@ -21,7 +21,7 @@ from .bert_for_pre_training import BertNetworkWithLoss, BertPreTraining, \ BertTrainOneStepWithLossScaleCellForAdam, \ BertNetworkMatchBucket, BertPretrainEval from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \ - BertOutput, BertSelfAttention, BertTransformer, EmbeddingLookup, \ + BertOutput, BertSelfAttention, BertTransformer, \ EmbeddingPostprocessor, RelaPosEmbeddingsGenerator, RelaPosMatrixGenerator, \ SaturateCast, CreateAttentionMaskFromInputMask from .adam import AdamWeightDecayForBert, AdamWeightDecayOp @@ -32,7 +32,7 @@ __all__ = [ "BertTrainAccumulationAllReducePostWithLossScaleCell", "BertNetworkMatchBucket", "BertPretrainEval", "BertAttention", "BertConfig", "BertEncoderCell", "BertModel", "BertOutput", - "BertSelfAttention", "BertTransformer", "EmbeddingLookup", + "BertSelfAttention", "BertTransformer", "EmbeddingPostprocessor", "RelaPosEmbeddingsGenerator", "AdamWeightDecayForBert", "RelaPosMatrixGenerator", "SaturateCast", "CreateAttentionMaskFromInputMask", "BertTrainOneStepWithLossScaleCellForAdam", "AdamWeightDecayOp" diff --git a/official/nlp/Bert/src/bert_model.py b/official/nlp/Bert/src/bert_model.py index ecc07b5b1..2b291a0dd 100644 --- a/official/nlp/Bert/src/bert_model.py +++ b/official/nlp/Bert/src/bert_model.py @@ -54,6 +54,11 @@ class BertConfig: use_relative_positions (bool): Specifies whether to use relative positions. Default: False. 
dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32. compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32. + return_all_encoders (bool): Specifies whether to return all encoders. Default: False. + has_attention_mask (bool): Specifies whether to use attention mask. Default: True. + max_relative_position (int): Max value of relative position. Default: 16. + use_token_type (bool): Specifies whether to use token type embeddings. Default: True. + """ def __init__(self, seq_length=128, @@ -70,7 +75,11 @@ class BertConfig: initializer_range=0.02, use_relative_positions=False, dtype=mstype.float32, - compute_type=mstype.float32): + compute_type=mstype.float32, + return_all_encoders=False, + has_attention_mask=True, + max_relative_position=16, + use_token_type=True): self.seq_length = seq_length self.vocab_size = vocab_size self.hidden_size = hidden_size @@ -86,108 +95,39 @@ class BertConfig: self.use_relative_positions = use_relative_positions self.dtype = dtype self.compute_type = compute_type - - -class EmbeddingLookup(nn.Cell): - """ - A embeddings lookup table with a fixed dictionary and size. - - Args: - vocab_size (int): Size of the dictionary of embeddings. - embedding_size (int): The size of each embedding vector. - embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of - each embedding vector. - use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. 
- """ - def __init__(self, - vocab_size, - embedding_size, - embedding_shape, - use_one_hot_embeddings=False, - initializer_range=0.02): - super(EmbeddingLookup, self).__init__() - self.vocab_size = vocab_size - self.use_one_hot_embeddings = use_one_hot_embeddings - self.embedding_table = Parameter(initializer - (TruncatedNormal(initializer_range), - [vocab_size, embedding_size])) - self.expand = P.ExpandDims() - self.shape_flat = (-1,) - self.gather = P.Gather() - self.one_hot = P.OneHot() - self.on_value = Tensor(1.0, mstype.float32) - self.off_value = Tensor(0.0, mstype.float32) - self.array_mul = P.MatMul() - self.reshape = P.Reshape() - self.shape = tuple(embedding_shape) - - def construct(self, input_ids): - """Get output and embeddings lookup table""" - extended_ids = self.expand(input_ids, -1) - flat_ids = self.reshape(extended_ids, self.shape_flat) - if self.use_one_hot_embeddings: - one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value) - output_for_reshape = self.array_mul( - one_hot_ids, self.embedding_table) - else: - output_for_reshape = self.gather(self.embedding_table, flat_ids, 0) - output = self.reshape(output_for_reshape, self.shape) - return output, self.embedding_table.value() - + self.return_all_encoders = return_all_encoders + self.has_attention_mask = has_attention_mask + self.max_relative_position = max_relative_position + self.use_token_type = use_token_type class EmbeddingPostprocessor(nn.Cell): """ Postprocessors apply positional and token type embeddings to word embeddings. Args: - embedding_size (int): The size of each embedding vector. + config (Class): Configuration for BertModel. embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of each embedding vector. - use_token_type (bool): Specifies whether to use token type embeddings. Default: False. - token_type_vocab_size (int): Size of token type vocab. Default: 16. 
- use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. - max_position_embeddings (int): Maximum length of sequences used in this - model. Default: 512. - dropout_prob (float): The dropout probability. Default: 0.1. """ - def __init__(self, - embedding_size, - embedding_shape, - use_relative_positions=False, - use_token_type=False, - token_type_vocab_size=16, - use_one_hot_embeddings=False, - initializer_range=0.02, - max_position_embeddings=512, - dropout_prob=0.1): + def __init__(self, config, embedding_shape): super(EmbeddingPostprocessor, self).__init__() - self.use_token_type = use_token_type - self.token_type_vocab_size = token_type_vocab_size - self.use_one_hot_embeddings = use_one_hot_embeddings - self.max_position_embeddings = max_position_embeddings + self.embedding_dim = config.hidden_size + self.use_token_type = config.use_token_type + self.token_type_vocab_size = config.type_vocab_size + self.max_position_embeddings = config.max_position_embeddings self.token_type_embedding = nn.Embedding( - vocab_size=token_type_vocab_size, - embedding_size=embedding_size, - use_one_hot=use_one_hot_embeddings) - self.shape_flat = (-1,) - self.one_hot = P.OneHot() - self.on_value = Tensor(1.0, mstype.float32) - self.off_value = Tensor(0.1, mstype.float32) - self.array_mul = P.MatMul() + vocab_size=self.token_type_vocab_size, + embedding_size=self.embedding_dim) self.reshape = P.Reshape() self.shape = tuple(embedding_shape) - self.dropout = nn.Dropout(p=dropout_prob) - self.gather = P.Gather() - self.use_relative_positions = use_relative_positions - self.slice = P.StridedSlice() + self.dropout = nn.Dropout(p=config.hidden_dropout_prob) + self.use_relative_positions = config.use_relative_positions _, seq, _ = self.shape self.full_position_embedding = nn.Embedding( - vocab_size=max_position_embeddings, - embedding_size=embedding_size, + 
vocab_size=self.max_position_embeddings, + embedding_size=self.embedding_dim, use_one_hot=False) - self.layernorm = nn.LayerNorm((embedding_size,)) + self.layernorm = nn.LayerNorm((self.embedding_dim,)) self.position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32)) self.add = P.Add() @@ -212,25 +152,20 @@ class BertOutput(nn.Cell): Apply a linear computation to hidden status and a residual computation to input. Args: + config (Class): Configuration for BertModel. in_channels (int): Input channels. out_channels (int): Output channels. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. - dropout_prob (float): The dropout probability. Default: 0.1. - compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32. """ - def __init__(self, - in_channels, - out_channels, - initializer_range=0.02, - dropout_prob=0.1, - compute_type=mstype.float32): + def __init__(self, config, in_channels, out_channels): super(BertOutput, self).__init__() - self.dense = nn.Dense(in_channels, out_channels, - weight_init=TruncatedNormal(initializer_range)).to_float(compute_type) - self.dropout = nn.Dropout(p=dropout_prob) - self.dropout_prob = dropout_prob + self.dropout_prob = config.hidden_dropout_prob + self.compute_type = config.compute_type + self.out_channels = out_channels + self.dense = nn.Dense(in_channels, self.out_channels, + weight_init=TruncatedNormal(config.initializer_range)).to_float(self.compute_type) + self.dropout = nn.Dropout(p=self.dropout_prob) self.add = P.Add() - self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type) + self.layernorm = nn.LayerNorm((self.out_channels,)).to_float(self.compute_type) self.cast = P.Cast() def construct(self, hidden_status, input_tensor): @@ -334,10 +269,9 @@ class SaturateCast(nn.Cell): the danger that the value will overflow or underflow. Args: - src_type (:class:`mindspore.dtype`): The type of the elements of the input tensor. 
Default: mstype.float32. dst_type (:class:`mindspore.dtype`): The type of the elements of the output tensor. Default: mstype.float32. """ - def __init__(self, src_type=mstype.float32, dst_type=mstype.float32): + def __init__(self, dst_type=mstype.float32): super(SaturateCast, self).__init__() np_type = mstype.dtype_to_nptype(dst_type) @@ -360,60 +294,39 @@ class BertAttention(nn.Cell): Apply multi-headed attention from "from_tensor" to "to_tensor". Args: + config (Class): Configuration for BertModel. from_tensor_width (int): Size of last dim of from_tensor. to_tensor_width (int): Size of last dim of to_tensor. - num_attention_heads (int): Number of attention heads. Default: 1. size_per_head (int): Size of each attention head. Default: 512. - query_act (str): Activation function for the query transform. Default: None. - key_act (str): Activation function for the key transform. Default: None. - value_act (str): Activation function for the value transform. Default: None. - has_attention_mask (bool): Specifies whether to use attention mask. Default: False. - attention_probs_dropout_prob (float): The dropout probability for - BertAttention. Default: 0.0. use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. - use_relative_positions (bool): Specifies whether to use relative positions. Default: False. - compute_type (:class:`mindspore.dtype`): Compute type in BertAttention. Default: mstype.float32. 
""" - def __init__(self, + def __init__(self, config, from_tensor_width, to_tensor_width, - num_attention_heads=1, size_per_head=512, - query_act=None, - key_act=None, - value_act=None, - has_attention_mask=False, - attention_probs_dropout_prob=0.0, - use_one_hot_embeddings=False, - initializer_range=0.02, - use_relative_positions=False, - compute_type=mstype.float32): + use_one_hot_embeddings=False): super(BertAttention, self).__init__() - self.num_attention_heads = num_attention_heads + self.num_attention_heads = config.num_attention_heads self.size_per_head = size_per_head - self.has_attention_mask = has_attention_mask - self.use_relative_positions = use_relative_positions - + self.has_attention_mask = config.has_attention_mask + self.use_relative_positions = config.use_relative_positions + self.compute_type = config.compute_type self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head)) self.reshape = P.Reshape() self.shape_from_2d = (-1, from_tensor_width) self.shape_to_2d = (-1, to_tensor_width) - weight = TruncatedNormal(initializer_range) - units = num_attention_heads * size_per_head + weight = TruncatedNormal(config.initializer_range) + units = self.num_attention_heads * size_per_head self.query_layer = nn.Dense(from_tensor_width, units, - activation=query_act, - weight_init=weight).to_float(compute_type) + weight_init=weight).to_float(self.compute_type) self.key_layer = nn.Dense(to_tensor_width, units, - activation=key_act, - weight_init=weight).to_float(compute_type) + weight_init=weight).to_float(self.compute_type) self.value_layer = nn.Dense(to_tensor_width, units, - activation=value_act, - weight_init=weight).to_float(compute_type) + weight_init=weight).to_float(self.compute_type) self.matmul_trans_b = P.BatchMatMul(transpose_b=True) self.multiply = P.Mul() @@ -425,7 +338,7 @@ class BertAttention(nn.Cell): self.matmul = P.BatchMatMul() self.softmax = nn.Softmax() - self.dropout = nn.Dropout(p=attention_probs_dropout_prob) + self.dropout = 
nn.Dropout(p=config.attention_probs_dropout_prob) if self.has_attention_mask: self.expand_dims = P.ExpandDims() @@ -434,18 +347,24 @@ class BertAttention(nn.Cell): self.cast = P.Cast() self.get_dtype = P.DType() - self.shape_return = (-1, num_attention_heads * size_per_head) + self.shape_return = (-1, self.num_attention_heads * size_per_head) - self.cast_compute_type = SaturateCast(dst_type=compute_type) + self.cast_compute_type = SaturateCast(dst_type=self.compute_type) if self.use_relative_positions: self._generate_relative_positions_embeddings = \ RelaPosEmbeddingsGenerator(depth=size_per_head, - max_relative_position=16, - initializer_range=initializer_range, + max_relative_position=config.max_relative_position, + initializer_range=config.initializer_range, use_one_hot_embeddings=use_one_hot_embeddings) def construct(self, from_tensor, to_tensor, attention_mask): """reshape 2d/3d input tensors to 2d""" + # Scalar dimensions referenced here: + # B = batch size (number of sequences) + # F = `from_tensor` sequence length + # T = `to_tensor` sequence length + # N = `num_attention_heads` + # H = `size_per_head` shape_from = F.shape(attention_mask)[2] from_tensor = F.depend(from_tensor, shape_from) from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d) @@ -458,10 +377,11 @@ class BertAttention(nn.Cell): query_layer = self.transpose(query_layer, self.trans_shape) key_layer = self.reshape(key_out, (-1, shape_from, self.num_attention_heads, self.size_per_head)) key_layer = self.transpose(key_layer, self.trans_shape) - + # `attention_scores` = [B, N, F, T] attention_scores = self.matmul_trans_b(query_layer, key_layer) # use_relative_position, supplementary logic + # Self-Attention with Relative Position Representations if self.use_relative_positions: # relations_keys is [F|T, F|T, H] relations_keys = self._generate_relative_positions_embeddings(shape_from) @@ -542,51 +462,32 @@ class BertSelfAttention(nn.Cell): Apply self-attention. 
Args: - hidden_size (int): Size of the bert encoder layers. - num_attention_heads (int): Number of attention heads. Default: 12. - attention_probs_dropout_prob (float): The dropout probability for - BertAttention. Default: 0.1. + config (Class): Configuration for BertModel. use_one_hot_embeddings (bool): Specifies whether to use one_hot encoding form. Default: False. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. - hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. - use_relative_positions (bool): Specifies whether to use relative positions. Default: False. - compute_type (:class:`mindspore.dtype`): Compute type in BertSelfAttention. Default: mstype.float32. """ - def __init__(self, - hidden_size, - num_attention_heads=12, - attention_probs_dropout_prob=0.1, - use_one_hot_embeddings=False, - initializer_range=0.02, - hidden_dropout_prob=0.1, - use_relative_positions=False, - compute_type=mstype.float32): + def __init__(self, config, use_one_hot_embeddings=False): super(BertSelfAttention, self).__init__() - if hidden_size % num_attention_heads != 0: + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + if self.hidden_size % self.num_attention_heads != 0: raise ValueError("The hidden size (%d) is not a multiple of the number " - "of attention heads (%d)" % (hidden_size, num_attention_heads)) + "of attention heads (%d)" % (self.hidden_size, self.num_attention_heads)) - self.size_per_head = int(hidden_size / num_attention_heads) + self.size_per_head = int(self.hidden_size / self.num_attention_heads) self.attention = BertAttention( - from_tensor_width=hidden_size, - to_tensor_width=hidden_size, - num_attention_heads=num_attention_heads, + config=config, + from_tensor_width=self.hidden_size, + to_tensor_width=self.hidden_size, size_per_head=self.size_per_head, - attention_probs_dropout_prob=attention_probs_dropout_prob, - 
use_one_hot_embeddings=use_one_hot_embeddings, - initializer_range=initializer_range, - use_relative_positions=use_relative_positions, - has_attention_mask=True, - compute_type=compute_type) - - self.output = BertOutput(in_channels=hidden_size, - out_channels=hidden_size, - initializer_range=initializer_range, - dropout_prob=hidden_dropout_prob, - compute_type=compute_type) + use_one_hot_embeddings=use_one_hot_embeddings + ) + + self.output = BertOutput(config=config, + in_channels=self.hidden_size, + out_channels=self.hidden_size) self.reshape = P.Reshape() - self.shape = (-1, hidden_size) + self.shape = (-1, self.hidden_size) def construct(self, input_tensor, attention_mask): attention_output = self.attention(input_tensor, input_tensor, attention_mask) @@ -599,48 +500,24 @@ class BertEncoderCell(nn.Cell): Encoder cells used in BertTransformer. Args: - hidden_size (int): Size of the bert encoder layers. Default: 768. - num_attention_heads (int): Number of attention heads. Default: 12. - intermediate_size (int): Size of intermediate layer. Default: 3072. - attention_probs_dropout_prob (float): The dropout probability for - BertAttention. Default: 0.02. + config (Class): Configuration for BertModel. use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. - hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. - use_relative_positions (bool): Specifies whether to use relative positions. Default: False. - hidden_act (str): Activation function. Default: "gelu". - compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32. 
""" - def __init__(self, - hidden_size=768, - num_attention_heads=12, - intermediate_size=3072, - attention_probs_dropout_prob=0.02, - use_one_hot_embeddings=False, - initializer_range=0.02, - hidden_dropout_prob=0.1, - use_relative_positions=False, - hidden_act="gelu", - compute_type=mstype.float32): + def __init__(self, config, use_one_hot_embeddings=False): super(BertEncoderCell, self).__init__() + self.hidden_size = config.hidden_size + self.compute_type = config.compute_type + self.intermediate_size = config.intermediate_size self.attention = BertSelfAttention( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - attention_probs_dropout_prob=attention_probs_dropout_prob, - use_one_hot_embeddings=use_one_hot_embeddings, - initializer_range=initializer_range, - hidden_dropout_prob=hidden_dropout_prob, - use_relative_positions=use_relative_positions, - compute_type=compute_type) - self.intermediate = nn.Dense(in_channels=hidden_size, - out_channels=intermediate_size, - activation=hidden_act, - weight_init=TruncatedNormal(initializer_range)).to_float(compute_type) - self.output = BertOutput(in_channels=intermediate_size, - out_channels=hidden_size, - initializer_range=initializer_range, - dropout_prob=hidden_dropout_prob, - compute_type=compute_type) + config=config, + use_one_hot_embeddings=use_one_hot_embeddings) + self.intermediate = nn.Dense(in_channels=self.hidden_size, + out_channels=self.intermediate_size, + activation=config.hidden_act, + weight_init=TruncatedNormal(config.initializer_range)).to_float(self.compute_type) + self.output = BertOutput(config=config, + in_channels=self.intermediate_size, + out_channels=self.hidden_size) def construct(self, hidden_states, attention_mask): # self-attention @@ -657,54 +534,24 @@ class BertTransformer(nn.Cell): Multi-layer bert transformer. Args: - hidden_size (int): Size of the encoder layers. - num_hidden_layers (int): Number of hidden layers in encoder cells. 
- num_attention_heads (int): Number of attention heads in encoder cells. Default: 12. - intermediate_size (int): Size of intermediate layer in encoder cells. Default: 3072. - attention_probs_dropout_prob (float): The dropout probability for - BertAttention. Default: 0.1. + config (Class): Configuration for BertModel. use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. - initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02. - hidden_dropout_prob (float): The dropout probability for BertOutput. Default: 0.1. - use_relative_positions (bool): Specifies whether to use relative positions. Default: False. - hidden_act (str): Activation function used in the encoder cells. Default: "gelu". - compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32. - return_all_encoders (bool): Specifies whether to return all encoders. Default: False. """ - def __init__(self, - hidden_size, - num_hidden_layers, - num_attention_heads=12, - intermediate_size=3072, - attention_probs_dropout_prob=0.1, - use_one_hot_embeddings=False, - initializer_range=0.02, - hidden_dropout_prob=0.1, - use_relative_positions=False, - hidden_act="gelu", - compute_type=mstype.float32, - return_all_encoders=False): + def __init__(self, config, use_one_hot_embeddings=False): super(BertTransformer, self).__init__() - self.return_all_encoders = return_all_encoders + self.return_all_encoders = config.return_all_encoders + self.hidden_size = config.hidden_size layers = [] - for _ in range(num_hidden_layers): - layer = BertEncoderCell(hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - intermediate_size=intermediate_size, - attention_probs_dropout_prob=attention_probs_dropout_prob, - use_one_hot_embeddings=use_one_hot_embeddings, - initializer_range=initializer_range, - hidden_dropout_prob=hidden_dropout_prob, - use_relative_positions=use_relative_positions, - hidden_act=hidden_act, - 
compute_type=compute_type) + for _ in range(config.num_hidden_layers): + layer = BertEncoderCell(config=config, + use_one_hot_embeddings=use_one_hot_embeddings) layers.append(layer) self.layers = nn.CellList(layers) self.reshape = P.Reshape() - self.shape = (-1, hidden_size) + self.shape = (-1, self.hidden_size) def construct(self, input_tensor, attention_mask): """Multi-layer bert transformer.""" @@ -756,10 +603,7 @@ class BertModel(nn.Cell): is_training (bool): True for training mode. False for eval mode. use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False. """ - def __init__(self, - config, - is_training, - use_one_hot_embeddings=False): + def __init__(self, config, is_training, use_one_hot_embeddings=False): super(BertModel, self).__init__() config = copy.deepcopy(config) if not is_training: @@ -770,6 +614,8 @@ class BertModel(nn.Cell): self.num_hidden_layers = config.num_hidden_layers self.embedding_size = config.hidden_size self.token_type_ids = None + self.compute_type = config.compute_type + self.return_all_encoders = config.return_all_encoders self.last_idx = self.num_hidden_layers - 1 output_embedding_shape = [-1, config.seq_length, self.embedding_size] @@ -781,58 +627,45 @@ class BertModel(nn.Cell): embedding_table=TruncatedNormal(config.initializer_range)) self.bert_embedding_postprocessor = EmbeddingPostprocessor( - embedding_size=self.embedding_size, - embedding_shape=output_embedding_shape, - use_relative_positions=config.use_relative_positions, - use_token_type=True, - token_type_vocab_size=config.type_vocab_size, - use_one_hot_embeddings=use_one_hot_embeddings, - initializer_range=0.02, - max_position_embeddings=config.max_position_embeddings, - dropout_prob=config.hidden_dropout_prob) + config=config, + embedding_shape=output_embedding_shape) self.bert_encoder = BertTransformer( - hidden_size=self.hidden_size, - num_attention_heads=config.num_attention_heads, - num_hidden_layers=self.num_hidden_layers, - 
intermediate_size=config.intermediate_size, - attention_probs_dropout_prob=config.attention_probs_dropout_prob, - use_one_hot_embeddings=use_one_hot_embeddings, - initializer_range=config.initializer_range, - hidden_dropout_prob=config.hidden_dropout_prob, - use_relative_positions=config.use_relative_positions, - hidden_act=config.hidden_act, - compute_type=config.compute_type, - return_all_encoders=True) + config=config, + use_one_hot_embeddings=use_one_hot_embeddings) self.cast = P.Cast() self.dtype = config.dtype - self.cast_compute_type = SaturateCast(dst_type=config.compute_type) + self.cast_compute_type = SaturateCast(dst_type=self.compute_type) self.slice = P.StridedSlice() self.squeeze_1 = P.Squeeze(axis=1) self.dense = nn.Dense(self.hidden_size, self.hidden_size, activation="tanh", - weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type) + weight_init=TruncatedNormal(config.initializer_range)).to_float(self.compute_type) self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config) def construct(self, input_ids, token_type_ids, input_mask): """Bidirectional Encoder Representations from Transformers.""" # embedding + # Perform embedding lookup on the word ids. embedding_tables = self.bert_embedding_lookup.embedding_table word_embeddings = self.bert_embedding_lookup(input_ids) + # Add positional embeddings and token type embeddings, then layer + # normalize and perform dropout. 
embedding_output = self.bert_embedding_postprocessor(token_type_ids, word_embeddings) - # attention mask [batch_size, seq_length, seq_length] + # attention mask [batch_size, 1, seq_length] attention_mask = self._create_attention_mask_from_input_mask(input_mask) # bert encoder encoder_output = self.bert_encoder(self.cast_compute_type(embedding_output), attention_mask) - - sequence_output = self.cast(encoder_output[self.last_idx], self.dtype) - + if self.return_all_encoders: + sequence_output = self.cast(encoder_output[self.last_idx], self.dtype) + else: + sequence_output = self.cast(encoder_output[0], self.dtype) # pooler batch_size = P.Shape()(input_ids)[0] sequence_slice = self.slice(sequence_output, diff --git a/official/nlp/Bert/src/generate_mindrecord/generate_chinesener_mindrecord.py b/official/nlp/Bert/src/generate_mindrecord/generate_chinesener_mindrecord.py index f74db8f9b..3da29d5ef 100644 --- a/official/nlp/Bert/src/generate_mindrecord/generate_chinesener_mindrecord.py +++ b/official/nlp/Bert/src/generate_mindrecord/generate_chinesener_mindrecord.py @@ -286,7 +286,7 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser(description='Make dataset in mindrecord format.') parser.add_argument('--data_dir', default=".", type=str, help='') - parser.add_argument('--max_seq_length', default=202, type=int, help='') + parser.add_argument('--max_seq_length', default=128, type=int, help='') parser.add_argument('--do_train', default=True, type=bool, help='') parser.add_argument('--do_eval', default=True, type=bool, help='') parser.add_argument('--do_lower_case', default=True, type=bool, help='') diff --git a/official/nlp/Bert/src/generate_mindrecord/generate_squad_mindrecord.py b/official/nlp/Bert/src/generate_mindrecord/generate_squad_mindrecord.py index f7274c596..698545111 100644 --- a/official/nlp/Bert/src/generate_mindrecord/generate_squad_mindrecord.py +++ b/official/nlp/Bert/src/generate_mindrecord/generate_squad_mindrecord.py @@ -37,7 
+37,7 @@ def parse_args(): parser.add_argument("--do_lower_case", type=bool, default=True, help="Whether to lower case the input text. " "Should be True for uncased models" " and False for cased models.") - parser.add_argument("--max_seq_length", type=int, default=384, help="Maximum sequence length.") + parser.add_argument("--max_seq_length", type=int, default=128, help="Maximum sequence length.") parser.add_argument("--doc_stride", type=int, default=128, help="When splitting up a long document into chunks, " "how much stride to take between chunks.") parser.add_argument("--max_query_length", type=int, default=64, diff --git a/official/nlp/Bert/task_classifier_config.yaml b/official/nlp/Bert/task_classifier_config.yaml index 258005bbb..d1fe14c8b 100644 --- a/official/nlp/Bert/task_classifier_config.yaml +++ b/official/nlp/Bert/task_classifier_config.yaml @@ -71,7 +71,10 @@ bert_net_cfg: use_relative_positions: False dtype: mstype.float32 compute_type: mstype.float16 - + return_all_encoders: True + has_attention_mask: True + max_relative_position: 16 + use_token_type: True --- # Help description for each configuration enable_modelarts: "Whether training on modelarts, default: False" diff --git a/official/nlp/Bert/task_classifier_cpu_config.yaml b/official/nlp/Bert/task_classifier_cpu_config.yaml index 0e0b8a397..30cd4dd66 100644 --- a/official/nlp/Bert/task_classifier_cpu_config.yaml +++ b/official/nlp/Bert/task_classifier_cpu_config.yaml @@ -73,7 +73,10 @@ bert_net_cfg: use_relative_positions: False dtype: mstype.float32 compute_type: mstype.float16 - + return_all_encoders: True + has_attention_mask: True + max_relative_position: 16 + use_token_type: True --- # Help description for each configuration enable_modelarts: "Whether training on modelarts, default: False" diff --git a/official/nlp/Bert/task_ner_config.yaml b/official/nlp/Bert/task_ner_config.yaml index 7243bf2ef..cb7d621df 100644 --- a/official/nlp/Bert/task_ner_config.yaml +++ 
b/official/nlp/Bert/task_ner_config.yaml @@ -74,7 +74,10 @@ bert_net_cfg: use_relative_positions: False dtype: mstype.float32 compute_type: mstype.float16 - + return_all_encoders: True + has_attention_mask: True + max_relative_position: 16 + use_token_type: True --- # Help description for each configuration enable_modelarts: "Whether training on modelarts, default: False" diff --git a/official/nlp/Bert/task_ner_cpu_config.yaml b/official/nlp/Bert/task_ner_cpu_config.yaml index 765ea1475..fedc14c56 100644 --- a/official/nlp/Bert/task_ner_cpu_config.yaml +++ b/official/nlp/Bert/task_ner_cpu_config.yaml @@ -75,7 +75,10 @@ bert_net_cfg: use_relative_positions: False dtype: mstype.float32 compute_type: mstype.float16 - + return_all_encoders: True + has_attention_mask: True + max_relative_position: 16 + use_token_type: True --- # Help description for each configuration enable_modelarts: "Whether training on modelarts, default: False" diff --git a/official/nlp/Bert/task_squad_config.yaml b/official/nlp/Bert/task_squad_config.yaml index 8ff70db8c..55e6f74be 100644 --- a/official/nlp/Bert/task_squad_config.yaml +++ b/official/nlp/Bert/task_squad_config.yaml @@ -73,7 +73,10 @@ bert_net_cfg: use_relative_positions: False dtype: mstype.float32 compute_type: mstype.float16 - + return_all_encoders: True + has_attention_mask: True + max_relative_position: 16 + use_token_type: True --- # Help description for each configuration enable_modelarts: "Whether training on modelarts, default: False" -- Gitee