From 701c2194bb09bb108b4d2c87804355f7dd62d3b9 Mon Sep 17 00:00:00 2001 From: PaddlePaddle-Gardener Date: Wed, 12 Jan 2022 14:27:32 +0800 Subject: [PATCH] mirgate_38456 --- python/paddle/nn/functional/loss.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 8eb6e05fc0..f13f14cdde 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1696,6 +1696,13 @@ def cross_entropy(input, out = _C_ops.elementwise_mul(out, weight_gather_reshape) else: + if input.shape[axis] != weight.shape[-1]: + raise ValueError( + "input's class_dimension({}) must equal to " + "weight's class_dimension({}) " + "when weight is provided" \ + .format(input.shape[axis], weight.shape[-1])) + valid_label = paddle.where(label == ignore_index, paddle.zeros_like(label), label) # TODO: Temporarily use paddle.nonzero instead of paddle.max @@ -1715,12 +1722,6 @@ def cross_entropy(input, raise ValueError( "Target({}) is out of class_dimension's upper bound({})". format(invalid_label[0], input.shape[axis] - 1)) - if input.shape[axis] != weight.shape[-1]: - raise ValueError( - "input's class_dimension({}) must equal to " - "weight's class_dimension({}) " - "when weight is provided" \ - .format(input.shape[axis], weight.shape[-1])) ignore_weight_mask = paddle.cast((label != ignore_index), out.dtype) -- Gitee