From d724e285d0f68d624336afcf81a9b29d6ec7d0f1 Mon Sep 17 00:00:00 2001
From: PaddlePaddle-Gardener
Date: Thu, 13 Jan 2022 14:21:57 +0800
Subject: [PATCH] mirgate_38885

---
 paddle/fluid/operators/abs_op.cu        |   4 +-
 paddle/fluid/operators/bce_loss_op.cu   |   4 +-
 paddle/fluid/operators/clip_op.h        |   2 +-
 paddle/fluid/operators/p_norm_op.cu     | 291 ++++++++++--------------
 paddle/fluid/operators/renorm_op.cu     | 238 +++++++++++++++++++
 paddle/pten/kernels/gpu/cast_kernel.cu  |   2 +-
 paddle/pten/kernels/gpu/scale_kernel.cu |   2 +-
 7 files changed, 361 insertions(+), 182 deletions(-)

diff --git a/paddle/fluid/operators/abs_op.cu b/paddle/fluid/operators/abs_op.cu
index 94b0a3ae72..86748d4505 100644
--- a/paddle/fluid/operators/abs_op.cu
+++ b/paddle/fluid/operators/abs_op.cu
@@ -24,14 +24,14 @@ struct CudaAbsFunctor;
 template
 struct CudaAbsFunctor>> {
-  __device__ __forceinline__ math::Real operator()(const T& x) const {
+  __device__ __forceinline__ math::Real operator()(const T x) const {
     return abs(x);
   }
 };
 
 template
 struct CudaAbsFunctor>> {
-  __device__ __forceinline__ T operator()(const T& x) const {
+  __device__ __forceinline__ T operator()(const T x) const {
     return std::abs(x);
   }
 };
diff --git a/paddle/fluid/operators/bce_loss_op.cu b/paddle/fluid/operators/bce_loss_op.cu
index 18562b2432..da96aa92cd 100644
--- a/paddle/fluid/operators/bce_loss_op.cu
+++ b/paddle/fluid/operators/bce_loss_op.cu
@@ -28,8 +28,8 @@ template
 struct BCELossGradFunctor {
   T one = static_cast(1.0f);
   T eps = static_cast(1e-12);
-  __device__ __forceinline__ T operator()(const T& x, const T& label,
-                                          const T& dout) const {
+  __device__ __forceinline__ T operator()(const T x, const T label,
+                                          const T dout) const {
     T term1 = max((one - x) * x, eps);
     return (dout * (x - label) / term1);
   }
diff --git a/paddle/fluid/operators/clip_op.h b/paddle/fluid/operators/clip_op.h
index f08a7b2d57..3672fa983e 100644
--- a/paddle/fluid/operators/clip_op.h
+++ b/paddle/fluid/operators/clip_op.h
@@ -32,7 +32,7 @@ template
 class ClipFunctor {
  public:
  explicit ClipFunctor(const T min, const T max) : min_(min), max_(max) {}
-  HOSTDEVICE T operator()(const T& x) const {
+  HOSTDEVICE T operator()(const T x) const {
    return x < min_ ? min_ : x > max_ ? max_ : x;
  }
diff --git a/paddle/fluid/operators/p_norm_op.cu b/paddle/fluid/operators/p_norm_op.cu
index cfe778c491..b2a9ca6f93 100644
--- a/paddle/fluid/operators/p_norm_op.cu
+++ b/paddle/fluid/operators/p_norm_op.cu
@@ -21,7 +21,11 @@ limitations under the License.
 */
 namespace cub = hipcub;
 #endif
 #include "paddle/fluid/operators/amp/fp16_type_traits.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
+#include "paddle/fluid/operators/fc_op.h"
 #include "paddle/fluid/operators/p_norm_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
 #include "paddle/fluid/platform/float16.h"
 namespace paddle {
@@ -56,87 +60,41 @@ __device__ __forceinline__ double inline_pow(double base, double exponent) {
   return pow(base, exponent);
 }
-template
-__global__ void Pnorm(const T* x, const int pre,
-                      const int axis_n,  // dim in axis
-                      const int post, float porder, T* out_norm) {
-  using MT = typename details::MPTypeTrait::Type;
-  typedef cub::BlockReduce BlockReduce;
-  __shared__ typename BlockReduce::TempStorage temp_storage;
-  int num = pre * post;
-  auto porder_t = static_cast(porder);
-  auto porder_inv = static_cast(1.0 / porder);
-
-  for (int i = blockIdx.x; i < num; i += gridDim.x) {
-    int base = (i / post) * post * axis_n + (i % post);
-    MT sum = static_cast(0.0);
-    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
-      const MT x_ij = static_cast(x[base + j * post]);
-      sum += inline_pow(inline_abs(x_ij), porder_t);
-    }
-    MT reduce_result = BlockReduce(temp_storage).Sum(sum);
-    if (threadIdx.x == 0)
-      out_norm[i] = static_cast(inline_pow(reduce_result, porder_inv));
+template
+struct NonzeroFunctor {
+  HOSTDEVICE explicit inline NonzeroFunctor() {}
+  HOSTDEVICE inline T operator()(const T x) const {
+    return static_cast(static_cast(x) != 0);
   }
-}
+};
-template
-__global__ void ZeorNorm(const T* x, const int pre,
-                         const int axis_n,  // dim in axis
-                         const int post, T* out_norm) {
-  using MT = typename details::MPTypeTrait::Type;
-  typedef cub::BlockReduce BlockReduce;
-  __shared__ typename BlockReduce::TempStorage temp_storage;
-  int num = pre * post;
-  for (int i = blockIdx.x; i < num; i += gridDim.x) {
-    int base = (i / post) * post * axis_n + (i % post);
-    MT sum = static_cast(0.0);
-    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
-      const MT x_ij = static_cast(x[base + j * post]);
-      sum += static_cast(static_cast(x_ij) != 0);
-    }
-    MT reduce_result = BlockReduce(temp_storage).Sum(sum);
-    if (threadIdx.x == 0) out_norm[i] = static_cast(reduce_result);
+template
+struct AbsFunctor {
+  HOSTDEVICE explicit inline AbsFunctor() {}
+  HOSTDEVICE inline T operator()(const T x) const {
+    return static_cast(inline_abs(x));
   }
-}
+};
-template
-__global__ void InfNorm(const T* x, const int pre,
-                        const int axis_n,  // dim in axis
-                        const int post, T* out_norm) {
-  typedef cub::BlockReduce BlockReduce;
-  __shared__ typename BlockReduce::TempStorage temp_storage;
-  int num = pre * post;
-  for (int i = blockIdx.x; i < num; i += gridDim.x) {
-    int base = (i / post) * post * axis_n + (i % post);
-    T cur_max = inline_abs(x[base]);
-    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
-      T x_ij_abs = inline_abs(x[base + j * post]);
-      if (cur_max < x_ij_abs) cur_max = x_ij_abs;
-    }
-    T reduce_result = BlockReduce(temp_storage).Reduce(cur_max, cub::Max());
-    if (threadIdx.x == 0) out_norm[i] = reduce_result;
+template
+struct UnsignedPowFunctor {
+  HOSTDEVICE explicit inline UnsignedPowFunctor(float porder) {
+    this->porder = porder;
   }
+  HOSTDEVICE inline Ty operator()(const Tx x) const {
+    return static_cast(inline_pow(inline_abs(x), static_cast(porder)));
+  }
+  float porder;
+};
-template
-__global__ void NegInfNorm(const T* x, const int pre,
-                           const int axis_n,  // dim in axis
-                           const int post, T* out_norm) {
-  typedef cub::BlockReduce BlockReduce;
-  __shared__ typename BlockReduce::TempStorage temp_storage;
-  int num = pre * post;
-  for (int i = blockIdx.x; i < num; i += gridDim.x) {
-    int base = (i / post) * post * axis_n + (i % post);
-    T cur_min = inline_abs(x[base]);
-    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
-      T x_ij_abs = inline_abs(x[base + j * post]);
-      if (cur_min > x_ij_abs) cur_min = x_ij_abs;
-    }
-    T reduce_result = BlockReduce(temp_storage).Reduce(cur_min, cub::Min());
-    if (threadIdx.x == 0) out_norm[i] = reduce_result;
+template
+struct PowFunctor {
+  HOSTDEVICE explicit inline PowFunctor(float porder) { this->porder = porder; }
+  HOSTDEVICE inline Ty operator()(const Tx x) const {
+    return static_cast(inline_pow(x, static_cast(porder)));
   }
-}
+  float porder;
+};
 
 template
 class PnormCUDAKernel : public framework::OpKernel {
@@ -146,101 +104,84 @@ class PnormCUDAKernel : public framework::OpKernel {
     auto* out_norm = ctx.Output("Out");
     const T* x = in_x->data();
     T* norm = out_norm->mutable_data(ctx.GetPlace());
-    auto xdim = in_x->dims();
     auto ndim = out_norm->dims();
     float porder = ctx.Attr("porder");
-    int axis = ctx.Attr("axis");
     bool asvector = ctx.Attr("asvector");
-    if (axis < 0) axis = xdim.size() + axis;
-    int pre, n, post;
-    GetDims(xdim, axis, &pre, &n, &post, asvector);
-
-    auto& dev_ctx = ctx.cuda_device_context();
+    int axis = ctx.Attr("axis");
+    std::vector reduce_axis = {axis};
+    reduce_axis = GetReduceDim(reduce_axis, xdim.size(), asvector);
-#ifdef __HIPCC__
-    const int block = 256;
-#else
-    const int block = 512;
-#endif
+    auto stream = ctx.cuda_device_context().stream();
-    int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
-    const int max_blocks = std::max(max_threads / block, 1);
-    int grid = std::min(max_blocks, pre * post);
+    using MT = typename details::MPTypeTrait::Type;
     if (porder == 0) {
-      ZeorNorm<<>>(x, pre, n, post,
-                    norm);
+      TensorReduceFunctorImpl>(
+          *in_x, out_norm, NonzeroFunctor(), reduce_axis, stream);
     } else if (porder == INFINITY) {
-      InfNorm<<>>(x, pre, n, post,
-                   norm);
+      TensorReduceFunctorImpl>(
+          *in_x, out_norm, AbsFunctor(), reduce_axis, stream);
     } else if (porder == -INFINITY) {
-      NegInfNorm<<>>(x, pre, n,
-                      post, norm);
+      TensorReduceFunctorImpl>(
+          *in_x, out_norm, AbsFunctor(), reduce_axis, stream);
     } else {
-      Pnorm<<>>(x, pre, n, post,
-             porder, norm);
+      framework::Tensor tmp_x;
+      tmp_x.mutable_data(xdim, ctx.GetPlace());
+      std::vector ins = {in_x};
+      std::vector outs = {&tmp_x};
+      auto func = UnsignedPowFunctor(porder);
+      const auto& cuda_ctx =
+          ctx.template device_context();
+
+      LaunchSameDimsElementwiseCudaKernel>(
+          cuda_ctx, ins, &outs, func);
+      framework::Tensor tmp_y;
+      tmp_y.mutable_data(ndim, ctx.GetPlace());
+      TensorReduceFunctorImpl>(
+          tmp_x, &tmp_y, kps::IdentityFunctor(), reduce_axis, stream);
+      const framework::Tensor* tmp_norm = &tmp_y;
+      ins = {tmp_norm};
+      outs = {out_norm};
+      auto func_inverse = UnsignedPowFunctor(1. / porder);
+
+      LaunchSameDimsElementwiseCudaKernel>(
+          cuda_ctx, ins, &outs, func_inverse);
     }
   }
 };
-template
-__global__ void PnormGradient(const T* x, const T* x_norm, const T* y_grad,
-                              const float porder, const int pre,
-                              const int axis_n, const int post, const T eps,
-                              T* x_grad) {
-  using MT = typename details::MPTypeTrait::Type;
-  // dx = (x/pnorm_broadcast).pow(p-1) * norm_dy.broadcast * sign(x)
-  int num = pre * post;
-  auto porder_grad = static_cast(porder - 1.0f);
-  for (int i = blockIdx.x; i < num; i += gridDim.x) {
-    __shared__ MT pnorm_i;
-    __shared__ MT yout_i;
-
-    auto base = (i / post) * post * axis_n + (i % post);
-
-    if (threadIdx.x == 0) {
-      pnorm_i = static_cast(x_norm[i]);
-      yout_i = static_cast(y_grad[i]);
-    }
-    __syncthreads();
-
-    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
-      int index = base + j * post;
-      const MT x_ij = static_cast(inline_abs(x[index]));
-      x_grad[index] = static_cast(
-          inline_pow(x_ij, porder_grad) /
-          (inline_pow(pnorm_i, porder_grad) + static_cast(eps)) * yout_i *
-          static_cast(inline_sign(x[index])));
-    }
+template
+struct AbsMaxAndMinGradFunctor {
+  template
+  void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
+                  const Dim& dim, int size) {
+    auto equals = ((*x).abs() == y->broadcast(dim));
+    auto ones = dx->constant(static_cast(1.));
+    auto negs = dx->constant(static_cast(-1.));
+    auto zeros = dx->constant(static_cast(0.));
+    auto positives = (*x) > zeros;
+    dx->device(place) = dy->broadcast(dim) * equals.select(ones, zeros) *
+                        positives.select(ones, negs);
   }
-}
-
-template
-__global__ void InfNormGradient(const T* x, const T* x_norm, const T* y_grad,
-                                const int pre, const int axis_n, const int post,
-                                T* x_grad) {
-  int num = pre * post;
-  for (int i = blockIdx.x; i < num; i += gridDim.x) {
-    __shared__ T pnorm_i;
-    __shared__ T yout_i;
-    auto base = (i / post) * post * axis_n + (i % post);
-    if (threadIdx.x == 0) {
-      pnorm_i = x_norm[i];
-      yout_i = y_grad[i];
-    }
-    __syncthreads();
+};
-    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
-      int index = base + j * post;
-      const T x_ij = inline_abs(x[index]);
-      if (x_ij == pnorm_i) {
-        x_grad[index] = static_cast(inline_sign(x[index])) * yout_i;
-      } else {
-        x_grad[index] = static_cast(0);
-      }
-    }
+template
+struct PNormPostGradFunctor {
+  template
+  void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
+                  const Dim& dim, int size) {
+    auto ones = dx->constant(static_cast(1.));
+    auto negs = dx->constant(static_cast(-1.));
+    auto zeros = dx->constant(static_cast(0.));
+    auto positives = (*x) > zeros;
+    dx->device(place) = (*dx) * dy->broadcast(dim) * y->broadcast(dim) *
+                        positives.select(ones, negs);
   }
-}
+};
 
 template
 class PnormGradCUDAKernel : public framework::OpKernel {
@@ -252,40 +193,40 @@ class PnormGradCUDAKernel : public framework::OpKernel {
         ctx.Input(framework::GradVarName("Out"));
     auto* out_dx = ctx.Output(framework::GradVarName("X"));
     T* dx = out_dx->mutable_data(ctx.GetPlace());
-    const T* x = in_x->data();
-    const T* x_norm = in_norm->data();
-    const T* norm_dy = in_norm_dy->data();
     auto xdim = in_x->dims();
     float porder = ctx.Attr("porder");
-    T eps = static_cast(ctx.Attr("epsilon"));
     int axis = ctx.Attr("axis");
-    bool asvector = ctx.Attr("asvector");
+    bool reduce_all = (in_norm->numel() == 1);
     if (axis < 0) axis = xdim.size() + axis;
-    int pre, n, post;
-    GetDims(xdim, axis, &pre, &n, &post, asvector);
-
-    auto& dev_ctx = ctx.cuda_device_context();
+    const std::vector dims = {axis};
-#ifdef __HIPCC__
-    const int block = 256;
-#else
-    const int block = 512;
-#endif
+    auto& cuda_ctx = ctx.template device_context();
-    int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
-    const int max_blocks = std::max(max_threads / block, 1);
-    int grid = std::min(max_blocks, pre * post);
     if (porder == 0) {
       math::SetConstant set_zero;
-      auto& dev_ctx = ctx.template device_context();
-      set_zero(dev_ctx, out_dx, static_cast(0));
+      set_zero(cuda_ctx, out_dx, static_cast(0));
     } else if (porder == INFINITY || porder == -INFINITY) {
-      InfNormGradient<<>>(
-          x, x_norm, norm_dy, pre, n, post, dx);
+      LaunchReduceGradKernel>(
+          ctx, in_x, in_norm, in_norm_dy, out_dx, dims, reduce_all);
     } else {
-      PnormGradient<<>>(
-          x, x_norm, norm_dy, porder, pre, n, post, eps, dx);
+      framework::Tensor tmp_norm;
+      tmp_norm.mutable_data(in_norm->dims(), ctx.GetPlace());
+      std::vector ins = {in_norm};
+      std::vector outs = {&tmp_norm};
+      auto pow_functor = PowFunctor(1. - porder);
+      LaunchSameDimsElementwiseCudaKernel>(cuda_ctx, ins, &outs,
+                                           pow_functor);
+      ins = {in_x};
+      outs = {out_dx};
+      auto unsigned_pow = UnsignedPowFunctor(porder - 1.);
+      LaunchSameDimsElementwiseCudaKernel>(
+          cuda_ctx, ins, &outs, unsigned_pow);
+      const framework::Tensor* tmp_norm_const = &tmp_norm;
+      LaunchReduceGradKernel>(
+          ctx, in_x, tmp_norm_const, in_norm_dy, out_dx, dims, reduce_all);
     }
   }
 };
diff --git a/paddle/fluid/operators/renorm_op.cu b/paddle/fluid/operators/renorm_op.cu
index e69de29bb2..b21b9fde56 100644
--- a/paddle/fluid/operators/renorm_op.cu
+++ b/paddle/fluid/operators/renorm_op.cu
@@ -0,0 +1,238 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/fluid/operators/renorm_op.h" + +#include +#include + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" +#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" +#include "paddle/fluid/operators/utils.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" + +namespace paddle { +namespace operators { + +__device__ __forceinline__ float inline_pow(float base, float exponent) { + return pow(base, exponent); +} + +__device__ __forceinline__ double inline_pow(double base, double exponent) { + return pow(base, exponent); +} + +__device__ __forceinline__ float inline_abs(float x) { return abs(x); } +__device__ __forceinline__ double inline_abs(double x) { return abs(x); } + +template +struct UnsignedPowFunctor { + HOSTDEVICE explicit inline UnsignedPowFunctor(float porder) { + this->porder = porder; + } + HOSTDEVICE inline Ty operator()(const Tx x) const { + return static_cast(inline_pow(inline_abs(x), static_cast(porder))); + } + float porder; +}; + +template +__global__ void RenormKernelFunc3(int64_t size, T* dim_value, float p, + float max_norm) { + int64_t i = ((int64_t)blockIdx.x) * blockDim.x + threadIdx.x; + if (i < size) { + T temp = pow(dim_value[i], (T)(1.0 / p)); + dim_value[i] = 1.0; + if (temp > max_norm) dim_value[i] = max_norm / temp; + } +} + +template +__global__ void RenormKernelFunc4(const T* x_data, T* out_data, int64_t size, + T* dim_value, int64_t dimension_each, + int64_t dim_divisor) { + int64_t i = ((int64_t)blockIdx.x) * blockDim.x + threadIdx.x; + auto dim_index = i / dim_divisor % dimension_each; + if (i < size) { + if (dim_value[dim_index] < 1.0) + out_data[i] = dim_value[dim_index] * x_data[i]; + else + out_data[i] = x_data[i]; + } +} + +template +__global__ void RenormGradKernelFunc1(const T* x_data, const T* dout_data, + T* pow_value, T* mul_value, int64_t size, + int64_t dimension_each, float p, + int64_t dim_divisor) { + int64_t i = ((int64_t)blockIdx.x) * blockDim.x + threadIdx.x; + auto dim_index = i / dim_divisor % dimension_each; + if (i < size) { + pow_value[i] = pow(abs(x_data[i]), (T)p); + mul_value[i] = x_data[i] * dout_data[i]; + } +} + +template +__global__ void RenormGradKernelFunc2(const T* x_data, const T* dout_data, + T* dx_data, int64_t size, T* dim_value, + T* dim_power_sum, T* weight_derivative, + int64_t dimension_each, float p, + float max_norm, int64_t dim_divisor) { + int64_t i = ((int64_t)blockIdx.x) * blockDim.x + threadIdx.x; + auto dim_index = i / dim_divisor % dimension_each; + if (i < dimension_each) { + dim_power_sum[i] = 0; + auto temp = pow(dim_value[i], (T)(1.0 / p)); + if (temp > max_norm) { + dim_power_sum[i] = pow(dim_value[i], (T)(-1.0 - 1.0 / p)) * -1 * max_norm; + dim_value[i] = max_norm / temp; + } else { + dim_value[i] = 1.0; + } + } + __syncthreads(); + if (i < size) { + dx_data[i] = dim_value[dim_index] * dout_data[i]; + dx_data[i] = dx_data[i] + + weight_derivative[dim_index] * dim_power_sum[dim_index] * + pow(abs(x_data[i]), T(p - 1.0)) * + (x_data[i] >= 0 ? 
+  }
+}
+
+template
+class CUDARenormKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* x = context.Input("X");
+    Tensor* out = context.Output("Out");
+    auto numel = x->numel();
+    const T* x_data = x->data();
+    auto input_dims = x->dims();
+    float max_norm = context.Attr("max_norm");
+    float p = context.Attr("p");
+    int dim = context.Attr("axis");
+    auto dimension_each = input_dims[dim];
+    auto dim_size = input_dims.size();
+    framework::Tensor pow_value, dim_value;
+    int64_t dim_divisor = 1, pre_mul = 1;
+    for (int i = dim + 1; i < dim_size; i++) dim_divisor *= input_dims[i];
+    for (int i = 0; i < dim; i++) pre_mul *= input_dims[i];
+    pow_value.Resize(
+        framework::make_ddim({pre_mul, dimension_each, dim_divisor}));
+    dim_value.Resize(framework::make_ddim({dimension_each}));
+    pow_value.mutable_data(context.GetPlace());
+    out->Resize(framework::make_ddim(framework::vectorize(input_dims)));
+    T* out_data = out->mutable_data(context.GetPlace());
+    auto stream = context.cuda_device_context().stream();
+    int block = std::min(numel, static_cast(256));
+    using MT = typename details::MPTypeTrait::Type;
+    int grid = (numel + block - 1) / block;
+
+    int block2 = std::min(dimension_each, static_cast(256));
+    int grid2 = (dimension_each + block2 - 1) / block2;
+    std::vector ins = {x};
+    std::vector outs = {&pow_value};
+    auto func = UnsignedPowFunctor(p);
+    const auto& cuda_ctx =
+        context.template device_context();
+
+    LaunchSameDimsElementwiseCudaKernel>(
+        cuda_ctx, ins, &outs, func);
+    std::vector reduce_axis = {0, 2};
+    TensorReduceFunctorImpl>(
+        pow_value, &dim_value, kps::IdentityFunctor(), reduce_axis, stream);
+    RenormKernelFunc3<<>>(
+        numel, dim_value.mutable_data(context.GetPlace()), p, max_norm);
+    RenormKernelFunc4<<>>(
+        x_data, out_data, numel, dim_value.mutable_data(context.GetPlace()),
+        dimension_each, dim_divisor);
+    // platform::GpuStreamSync(stream);
+  }
+};
+
+template
+class CUDAGradRenormKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const framework::Tensor* d_out =
+        ctx.Input(framework::GradVarName("Out"));
+    const framework::Tensor* x = ctx.Input("X");
+    framework::Tensor* d_x =
+        ctx.Output(framework::GradVarName("X"));
+
+    auto numel = d_out->numel();
+    const T* dout_data = d_out->data();
+    const T* x_data = x->data();
+    auto input_dims = x->dims();
+    float max_norm = ctx.Attr("max_norm");
+    float p = ctx.Attr("p");
+    int dim = ctx.Attr("axis");
+    auto dimension_each = input_dims[dim];
+    auto dim_size = input_dims.size();
+    int64_t dim_divisor = 1, pre_mul = 1;
+    for (int i = dim + 1; i < dim_size; i++) dim_divisor *= input_dims[i];
+    for (int i = 0; i < dim; i++) pre_mul *= input_dims[i];
+    d_x->Resize(framework::make_ddim(framework::vectorize(input_dims)));
+    T* dx_data = d_x->mutable_data(ctx.GetPlace());
+    framework::Tensor pow_value, mul_value, dim_value, dim_power_sum,
+        weight_derivative;
+    pow_value.Resize(
+        framework::make_ddim({pre_mul, dimension_each, dim_divisor}));
+    mul_value.Resize(
+        framework::make_ddim({pre_mul, dimension_each, dim_divisor}));
+    dim_value.Resize(framework::make_ddim({dimension_each}));
+    dim_power_sum.Resize(framework::make_ddim({dimension_each}));
+    weight_derivative.Resize(framework::make_ddim({dimension_each}));
+    auto stream = ctx.cuda_device_context().stream();
+    int block = std::min(numel, static_cast(256));
+    int grid = (numel + block - 1) / block;
+    pow_value.mutable_data(ctx.GetPlace());
+    mul_value.mutable_data(ctx.GetPlace());
+    dim_value.mutable_data(ctx.GetPlace());
+    dim_power_sum.mutable_data(ctx.GetPlace());
+    weight_derivative.mutable_data(ctx.GetPlace());
+    RenormGradKernelFunc1<<>>(
+        x_data, dout_data, pow_value.mutable_data(ctx.GetPlace()),
+        mul_value.mutable_data(ctx.GetPlace()), numel, dimension_each, p,
+        dim_divisor);
+    std::vector reduce_axis = {0, 2};
+    TensorReduceFunctorImpl>(
+        pow_value, &dim_value, kps::IdentityFunctor(), reduce_axis, stream);
+    TensorReduceFunctorImpl>(
+        mul_value, &weight_derivative, kps::IdentityFunctor(), reduce_axis,
+        stream);
+    RenormGradKernelFunc2<<>>(
+        x_data, dout_data, dx_data, numel,
+        dim_value.mutable_data(ctx.GetPlace()),
+        dim_power_sum.mutable_data(ctx.GetPlace()),
+        weight_derivative.mutable_data(ctx.GetPlace()), dimension_each, p,
+        max_norm, dim_divisor);
+    // platform::GpuStreamSync(stream);
+  }
+};
+
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(renorm, ops::CUDARenormKernel,
+                        ops::CUDARenormKernel);
+
+REGISTER_OP_CUDA_KERNEL(renorm_grad, ops::CUDAGradRenormKernel,
+                        ops::CUDAGradRenormKernel);
diff --git a/paddle/pten/kernels/gpu/cast_kernel.cu b/paddle/pten/kernels/gpu/cast_kernel.cu
index 9f65400f93..0bbe7a3a13 100644
--- a/paddle/pten/kernels/gpu/cast_kernel.cu
+++ b/paddle/pten/kernels/gpu/cast_kernel.cu
@@ -30,7 +30,7 @@ namespace pten {
 
 template
 struct CastFuctor {
-  __device__ __forceinline__ OutT operator()(const InT& x) const {
+  __device__ __forceinline__ OutT operator()(const InT x) const {
     return static_cast(x);
   }
 };
diff --git a/paddle/pten/kernels/gpu/scale_kernel.cu b/paddle/pten/kernels/gpu/scale_kernel.cu
index f4bb5c5dbf..68574c063e 100644
--- a/paddle/pten/kernels/gpu/scale_kernel.cu
+++ b/paddle/pten/kernels/gpu/scale_kernel.cu
@@ -34,7 +34,7 @@ struct ScaleFunctor {
     bias_after_scale = is_bias_after_sacle;
   }
 
-  __device__ __forceinline__ InT operator()(const InT& x) const {
+  __device__ __forceinline__ InT operator()(const InT x) const {
    if (bias_after_scale) {
      return scale * x + bias;
    } else {
--
Gitee
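
A recurring change in this patch is that device functors now receive their operands by value (const T x) rather than by const reference, which is what the vectorized elementwise launchers pass to them, and the hand-rolled cub::BlockReduce kernels in p_norm_op.cu are replaced by generic elementwise and reduce helpers. The standalone CUDA sketch below illustrates only the by-value functor pattern outside of Paddle; the kernel name, functor, and launch configuration here are illustrative assumptions, not code from the patch or from Paddle's launch machinery.

#include <cmath>
#include <cstdio>
#include <cuda_runtime.h>

// Device functor that takes its argument by value, mirroring the
// `const T& x` -> `const T x` change applied throughout the patch.
template <typename T>
struct AbsFunctor {
  __device__ __forceinline__ T operator()(const T x) const { return fabs(x); }
};

// Generic elementwise kernel: applies an arbitrary functor to each element.
// The functor object itself is also passed to the kernel by value.
template <typename T, typename Functor>
__global__ void ElementwiseKernel(const T* in, T* out, int n, Functor func) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = func(in[i]);
}

int main() {
  const int n = 8;
  double h_in[n] = {-3, -2, -1, 0, 1, 2, 3, -4}, h_out[n];
  double *d_in, *d_out;
  cudaMalloc(&d_in, n * sizeof(double));
  cudaMalloc(&d_out, n * sizeof(double));
  cudaMemcpy(d_in, h_in, n * sizeof(double), cudaMemcpyHostToDevice);

  // One block of 32 threads is enough for this toy input.
  ElementwiseKernel<<<1, 32>>>(d_in, d_out, n, AbsFunctor<double>());

  cudaMemcpy(h_out, d_out, n * sizeof(double), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) printf("%g ", h_out[i]);  // 3 2 1 0 1 2 3 4
  printf("\n");
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}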