Revert D33850228: [pytorch][PR] Implement Tanh Gelu Approximation

Test Plan: revert-hammer

Differential Revision: D33850228 (23d03025dc)

Original commit changeset: 3cc33fb298e4

Original Phabricator Diff: D33850228 (23d03025dc)

fbshipit-source-id: 9436e7df73c2b2e2011f321674f24973316d3692
Nikita Shulga (2022-01-31 09:32:17 -08:00), committed by Facebook GitHub Bot
parent 214624e254
commit c9efb58223
50 changed files with 207 additions and 766 deletions
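
For context: the reverted change added a 'tanh' approximation alongside the exact erf-based GELU. A minimal NumPy sketch of the two formulas involved (illustrative only; not the ATen kernels, and scipy is assumed for erf):

import numpy as np
from scipy.special import erf

def gelu_erf(x):
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
    return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))

def gelu_tanh(x):
    # The approximation the reverted PR added:
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    inner = np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1.0 + np.tanh(inner))

x = np.linspace(-3.0, 3.0, 101)
print(np.abs(gelu_erf(x) - gelu_tanh(x)).max())  # small but nonzero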

View File

@ -485,7 +485,7 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
KERNEL_CPU(ADD_NS(avg_pool1d), "avg_pool1d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool), fp32)
KERNEL_CPU(ADD_NS(avg_pool2d), "avg_pool2d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool, c10::optional<int64_t>), fp32)
KERNEL_CPU(ADD_NS(avg_pool3d), "avg_pool3d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool, c10::optional<int64_t>), fp32)
KERNEL_CPU(ADD_NS(gelu), "gelu", Tensor (const Tensor &, int64_t), fp32)
KERNEL_CPU(ADD_NS(gelu), "gelu", Tensor (const Tensor &), fp32)
KERNEL_CPU(ADD_NS(upsample_nearest1d), "upsample_nearest1d", Tensor (const Tensor &, IntArrayRef, c10::optional<double>), fp32)
KERNEL_CPU(ADD_NS(upsample_nearest1d), "upsample_nearest1d.vec", Tensor (const Tensor &, c10::optional<IntArrayRef>, c10::optional<ArrayRef<double>>), fp32)
KERNEL_CPU(ADD_NS(_upsample_nearest_exact1d), "_upsample_nearest_exact1d", Tensor (const Tensor &, IntArrayRef, c10::optional<double>), fp32)

View File

@ -164,12 +164,12 @@ TORCH_META_FUNC(softshrink_backward) (
build_borrowing_binary_op(maybe_get_output(), grad, self);
}
TORCH_META_FUNC(gelu) (const Tensor & self, int64_t approximate) {
TORCH_META_FUNC(gelu) (const Tensor & self) {
build_unary_op(maybe_get_output(), self);
}
TORCH_META_FUNC(gelu_backward) (
const Tensor& grad, const Tensor& self, int64_t approximate
const Tensor& grad, const Tensor& self
) {
build_borrowing_binary_op(maybe_get_output(), grad, self);
}
@ -324,37 +324,37 @@ bool use_mkldnn(const Tensor& input) {
}
TORCH_IMPL_FUNC(gelu_out_cpu) (
const Tensor& self, int64_t approximate, const Tensor& result
const Tensor& self, const Tensor& result
) {
#if AT_MKLDNN_ENABLED()
if (use_mkldnn(self) && (approximate == at::Gelu::None)) {
if (use_mkldnn(self)) {
const ideep::tensor& x = itensor_from_tensor(self);
ideep::tensor y = itensor_from_tensor(result);
ideep::eltwise_forward::compute(
x, y, ideep::algorithm::eltwise_gelu_erf, ideep::prop_kind::forward_training, /*alpha*/ 0.0);
} else {
GeluKernel(kCPU, *this, approximate);
GeluKernel(kCPU, *this);
}
#else
GeluKernel(kCPU, *this, approximate);
GeluKernel(kCPU, *this);
#endif
}
TORCH_IMPL_FUNC(gelu_backward_out_cpu) (
const Tensor& grad, const Tensor& self, int64_t approximate, const Tensor& grad_input
const Tensor& grad, const Tensor& self, const Tensor& grad_input
) {
#if AT_MKLDNN_ENABLED()
if (use_mkldnn(self) && (approximate == at::Gelu::None)) {
if (use_mkldnn(self)) {
const ideep::tensor& x = itensor_from_tensor(self);
ideep::tensor grady = itensor_from_tensor(grad);
ideep::tensor gradx = itensor_from_tensor(grad_input);
ideep::eltwise_backward::compute(x, grady, gradx,
ideep::algorithm::eltwise_gelu_erf, /*alpha*/ 0.0);
} else {
GeluBackwardKernel(kCPU, *this, approximate);
GeluBackwardKernel(kCPU, *this);
}
#else
GeluBackwardKernel(kCPU, *this, approximate);
GeluBackwardKernel(kCPU, *this);
#endif
}

View File

@ -12,19 +12,6 @@ struct TensorIteratorBase;
class TensorBase;
}
namespace at {
namespace Gelu {
// Keep this in sync with Gelu class in torch/nn/_gelu.py
// These constants control the approximation behavior of gelu functions.
enum Gelu {
None, // Baseline Gelu
Tanh, // Tanh Gelu Approximation
END
};
} // namespace Gelu
} // namespace at
namespace at { namespace native {
using structured_activation_fn = void (*)(TensorIteratorBase&);
@ -48,8 +35,6 @@ using elu_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const
using leaky_relu_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
using leaky_relu_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
using log_sigmoid_cpu_fn = void (*)(TensorBase&, TensorBase&, const TensorBase&);
using gelu_fn = void (*)(TensorIteratorBase&, int64_t);
using gelu_backward_fn = void (*)(TensorIteratorBase&, int64_t);
DECLARE_DISPATCH(elu_fn, elu_stub);
DECLARE_DISPATCH(elu_backward_fn, elu_backward_stub);
@ -58,8 +43,8 @@ DECLARE_DISPATCH(softplus_backward_fn, softplus_backward_stub);
DECLARE_DISPATCH(log_sigmoid_cpu_fn, log_sigmoid_cpu_stub);
DECLARE_DISPATCH(activation_backward_fn, log_sigmoid_backward_stub);
DECLARE_DISPATCH(threshold_fn, threshold_stub);
DECLARE_DISPATCH(gelu_fn, GeluKernel);
DECLARE_DISPATCH(gelu_backward_fn, GeluBackwardKernel);
DECLARE_DISPATCH(structured_activation_fn, GeluKernel);
DECLARE_DISPATCH(structured_activation_backward_fn, GeluBackwardKernel);
DECLARE_DISPATCH(hardtanh_backward_fn, hardtanh_backward_stub);
DECLARE_DISPATCH(hardsigmoid_fn, hardsigmoid_stub);
DECLARE_DISPATCH(hardsigmoid_backward_fn, hardsigmoid_backward_stub);

View File

@ -166,7 +166,7 @@ void elu_backward_kernel(TensorIteratorBase& it, const Scalar& alpha, const Scal
// TODO(yangxm): Add another fast kernel using formula
// y = 0.5x * (1 + tanh(sqrt(2/Pi) * (x + 0.044715x^3)))
// and the fast tanh impl from Eigen.
void GeluKernelImpl(TensorIteratorBase& it, int64_t approximate) {
void GeluKernelImpl(TensorIteratorBase& it) {
auto grain_size = at::internal::GRAIN_SIZE;
// Numbers based on benchmarking.
// Benchmark: benchmarks/operator_benchmarks/pt/gelu_test.py
@ -187,134 +187,53 @@ void GeluKernelImpl(TensorIteratorBase& it, int64_t approximate) {
if (it.numel() > GELU_MIN_ELEMENTS_FOR_MULTI_THREADING) {
grain_size = it.numel() / at::get_num_threads();
}
if (approximate == at::Gelu::Tanh) {
AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::BFloat16, it.dtype(), "GeluKernelImpl", [&]() {
using Vec = vec::Vectorized<scalar_t>;
const Vec kBetaVec(scalar_t(M_SQRT2 * M_2_SQRTPI * 0.5));
const Vec kKappaVec(scalar_t(0.044715));
const Vec kOneVec(scalar_t(1));
const Vec kPointFiveVec(scalar_t(0.5));
cpu_kernel_vec(
it,
[](scalar_t x) {
const scalar_t kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
const scalar_t kKappa = 0.044715;
auto x_cube = x * x * x;
auto inner = kBeta * (x + kKappa * x_cube);
return scalar_t(0.5) * x * (scalar_t(1) + std::tanh(inner));
},
[&](Vec x_vec) {
auto x_cube = x_vec * x_vec * x_vec;
auto inner_vec = kBetaVec * (x_vec + kKappaVec * x_cube);
return kPointFiveVec * x_vec * (kOneVec + inner_vec.tanh());
},
grain_size);
});
} else {
AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::BFloat16, it.dtype(), "GeluKernelImpl", [&]() {
using Vec = vec::Vectorized<scalar_t>;
const Vec kAlphaVec(scalar_t(M_SQRT1_2));
const Vec kOneVec(scalar_t(1));
const Vec kPointFiveVec(scalar_t(0.5));
cpu_kernel_vec(
it,
[](scalar_t x) {
const scalar_t kAlpha = scalar_t(M_SQRT1_2);
return x * scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
},
[&](Vec x_vec) {
return x_vec * kPointFiveVec *
(kOneVec + (x_vec * kAlphaVec).erf());
},
grain_size);
});
}
AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::BFloat16, it.dtype(), "GeluKernelImpl", [&]() {
using Vec = vec::Vectorized<scalar_t>;
const Vec kAlphaVec(scalar_t(M_SQRT1_2));
const Vec kOneVec(scalar_t(1));
const Vec kPointFiveVec(scalar_t(0.5));
cpu_kernel_vec(
it,
[](scalar_t x) {
const scalar_t kAlpha = scalar_t(M_SQRT1_2);
return x * scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
},
[&](Vec x_vec) {
return x_vec * kPointFiveVec *
(kOneVec + (x_vec * kAlphaVec).erf());
},
grain_size);
});
}
void GeluBackwardKernelImpl(TensorIteratorBase& it, int64_t approximate) {
if (approximate == at::Gelu::Tanh) {
AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::BFloat16, it.dtype(), "GeluBackwardKernelImpl", [&]() {
using Vec = vec::Vectorized<scalar_t>;
const Vec kBetaVec(scalar_t(M_SQRT2 * M_2_SQRTPI * 0.5));
const Vec kKappaVec(scalar_t(0.044715));
const Vec kOneVec(scalar_t(1));
const Vec kThreeVec(scalar_t(3));
const Vec kPointFiveVec(scalar_t(0.5));
cpu_kernel_vec(
it,
[](scalar_t dy, scalar_t x) {
const scalar_t kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
const scalar_t kKappa = 0.044715;
auto x_sq = x * x;
auto x_cube = x_sq * x;
auto inner = kBeta * (x + kKappa * x_cube);
auto tanh_inner = std::tanh(inner);
auto left = scalar_t(0.5) * x;
auto right = scalar_t(1) + tanh_inner;
auto left_derivative = scalar_t(0.5) * right;
auto tanh_derivative = scalar_t(1) - tanh_inner * tanh_inner;
auto inner_derivative =
kBeta * (scalar_t(1) + scalar_t(3) * kKappa * x_sq);
auto right_derivative = left * tanh_derivative * inner_derivative;
return dy * (left_derivative + right_derivative);
},
[&](Vec dy_vec, Vec x_vec) {
auto x_sq = x_vec * x_vec;
auto x_cube = x_vec * x_vec * x_vec;
auto inner_vec =
kBetaVec * (x_vec + kKappaVec * x_cube);
auto tanh_inner_vec = inner_vec.tanh();
auto left_vec = kPointFiveVec * x_vec;
auto right_vec = kOneVec + tanh_inner_vec;
auto left_derivative_vec = kPointFiveVec * right_vec;
auto tanh_derivative_vec =
kOneVec - tanh_inner_vec * tanh_inner_vec;
auto inner_derivative_vec =
kBetaVec * (kOneVec + kThreeVec * kKappaVec * x_sq);
auto right_derivative_vec =
left_vec * tanh_derivative_vec * inner_derivative_vec;
return dy_vec * (left_derivative_vec + right_derivative_vec);
});
});
} else {
AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::BFloat16, it.dtype(), "GeluBackwardKernelImpl", [&]() {
using Vec = vec::Vectorized<scalar_t>;
const Vec kAlphaVec(scalar_t(M_SQRT1_2));
const Vec kBetaVec(scalar_t(M_2_SQRTPI * M_SQRT1_2 * 0.5));
const Vec kOneVec(scalar_t(1));
const Vec kPointFiveVec(scalar_t(0.5));
const Vec kMinusPointFiveVec(scalar_t(-0.5));
cpu_kernel_vec(
it,
[](scalar_t dy, scalar_t x) {
const scalar_t kAlpha = scalar_t(M_SQRT1_2);
const scalar_t kBeta = M_2_SQRTPI * M_SQRT1_2 * scalar_t(0.5);
const scalar_t cdf =
scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
const scalar_t pdf = kBeta * std::exp(x * x * scalar_t(-0.5));
return dy * (cdf + x * pdf);
},
[&](Vec dy_vec, Vec x_vec) {
const Vec cdf_vec =
kPointFiveVec * (kOneVec + (x_vec * kAlphaVec).erf());
const Vec pdf_vec =
kBetaVec * (x_vec * x_vec * kMinusPointFiveVec).exp();
return dy_vec * (cdf_vec + x_vec * pdf_vec);
});
});
}
void GeluBackwardKernelImpl(TensorIteratorBase& it) {
AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::BFloat16, it.dtype(), "GeluBackwardKernelImpl", [&]() {
using Vec = vec::Vectorized<scalar_t>;
const Vec kAlphaVec(scalar_t(M_SQRT1_2));
const Vec kBetaVec(scalar_t(M_2_SQRTPI * M_SQRT1_2 * 0.5));
const Vec kOneVec(scalar_t(1));
const Vec kPointFiveVec(scalar_t(0.5));
const Vec kMinusPointFiveVec(scalar_t(-0.5));
cpu_kernel_vec(
it,
[](scalar_t dy, scalar_t x) {
const scalar_t kAlpha = scalar_t(M_SQRT1_2);
const scalar_t kBeta = M_2_SQRTPI * M_SQRT1_2 * scalar_t(0.5);
const scalar_t cdf =
scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
const scalar_t pdf = kBeta * std::exp(x * x * scalar_t(-0.5));
return dy * (cdf + x * pdf);
},
[&](Vec dy_vec, Vec x_vec) {
const Vec cdf_vec =
kPointFiveVec * (kOneVec + (x_vec * kAlphaVec).erf());
const Vec pdf_vec =
kBetaVec * (x_vec * x_vec * kMinusPointFiveVec).exp();
return dy_vec * (cdf_vec + x_vec * pdf_vec);
});
});
}
void hardsigmoid_kernel(TensorIteratorBase& iter) {
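
The tanh branch deleted above implements the derivative of the approximation via the product rule. A quick NumPy finite-difference check of that closed form (a sketch mirroring the deleted kernel, not the kernel itself):

import numpy as np

def gelu_tanh(x):
    inner = np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1.0 + np.tanh(inner))

def gelu_tanh_grad(x):
    # left_derivative + right_derivative, as in the deleted kernel.
    beta = np.sqrt(2.0 / np.pi)   # kBeta
    kappa = 0.044715              # kKappa
    t = np.tanh(beta * (x + kappa * x**3))
    left_derivative = 0.5 * (1.0 + t)
    right_derivative = 0.5 * x * (1.0 - t * t) * beta * (1.0 + 3.0 * kappa * x * x)
    return left_derivative + right_derivative

x = np.linspace(-3.0, 3.0, 13)
eps = 1e-6
fd = (gelu_tanh(x + eps) - gelu_tanh(x - eps)) / (2.0 * eps)
assert np.allclose(gelu_tanh_grad(x), fd, atol=1e-5)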

View File

@ -153,15 +153,15 @@ std::tuple<Tensor, Tensor> prelu_backward_cuda(const Tensor& grad_out_, const Te
}
TORCH_IMPL_FUNC(gelu_out_cuda) (
const Tensor& /*self*/, int64_t approximate, const Tensor& /*result*/
) {
GeluCUDAKernelImpl(*this, approximate);
const Tensor& /*self*/, const Tensor& /*result*/
) {
GeluCUDAKernelImpl(*this);
}
TORCH_IMPL_FUNC(gelu_backward_out_cuda) (
const Tensor& /*grad*/, const Tensor& /*self*/, int64_t approximate, const Tensor& /*grad_input*/
) {
GeluBackwardCUDAKernelImpl(*this, approximate);
const Tensor& /*grad*/, const Tensor& /*self*/, const Tensor& /*grad_input*/
) {
GeluBackwardCUDAKernelImpl(*this);
}
}} // namespace at::native

View File

@ -392,71 +392,30 @@ void elu_backward_kernel(TensorIteratorBase& iter, const Scalar& alpha, const Sc
});
}
void GeluCUDAKernelImpl(TensorIteratorBase& it, int64_t approximate) {
if (approximate == at::Gelu::Tanh) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluCUDAKernelImpl", [&]() {
using T_ACC = acc_type<scalar_t, true>;
gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
constexpr T_ACC kBeta = M_SQRT2 * M_2_SQRTPI * T_ACC(0.5);
constexpr T_ACC kKappa = 0.044715;
auto x_cube = static_cast<T_ACC>(x) * static_cast<T_ACC>(x) * static_cast<T_ACC>(x);
auto inner = kBeta * (static_cast<T_ACC>(x) + kKappa * x_cube);
return T_ACC(0.5) * static_cast<T_ACC>(x) * (T_ACC(1) + c10::cuda::compat::tanh(inner));
});
void GeluCUDAKernelImpl(TensorIteratorBase& it) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluCUDAKernelImpl", [&]() {
using T_ACC = acc_type<scalar_t, true>;
gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
return static_cast<T_ACC>(x) *
c10::cuda::compat::normcdf(static_cast<T_ACC>(x));
});
} else {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluCUDAKernelImpl", [&]() {
using T_ACC = acc_type<scalar_t, true>;
gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
constexpr T_ACC kAlpha = M_SQRT1_2;
return static_cast<T_ACC>(x) * T_ACC(0.5) * (T_ACC(1) + ::erf(static_cast<T_ACC>(x) * kAlpha));
});
});
}
});
}
void GeluBackwardCUDAKernelImpl(TensorIteratorBase& it, int64_t approximate) {
if (approximate == at::Gelu::Tanh) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
it.dtype(), "GeluBackwardCUDAKernelImpl", [&]() {
using T_ACC = acc_type<scalar_t, true>;
gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t {
constexpr T_ACC kBeta = M_SQRT2 * M_2_SQRTPI * T_ACC(0.5);
constexpr T_ACC kKappa = 0.044715;
auto x_sq = static_cast<T_ACC>(x) * static_cast<T_ACC>(x);
auto x_cube = x_sq * static_cast<T_ACC>(x);
auto inner = kBeta * (static_cast<T_ACC>(x) + kKappa * x_cube);
auto tanh_inner = c10::cuda::compat::tanh(inner);
auto left = T_ACC(0.5) * static_cast<T_ACC>(x);
auto right = T_ACC(1) + tanh_inner;
auto left_derivative = 0.5 * right;
auto tanh_derivative = T_ACC(1) - tanh_inner * tanh_inner;
auto inner_derivative = kBeta * (T_ACC(1) + T_ACC(3) * kKappa * x_sq);
auto right_derivative = left * tanh_derivative * inner_derivative;
return static_cast<T_ACC>(dy) * (left_derivative + right_derivative);
void GeluBackwardCUDAKernelImpl(TensorIteratorBase& it) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
it.dtype(), "GeluBackwardCUDAKernelImpl", [&]() {
using T_ACC = acc_type<scalar_t, true>;
gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t {
constexpr T_ACC kBeta = M_2_SQRTPI * M_SQRT1_2 * T_ACC(0.5);
const T_ACC cdf = c10::cuda::compat::normcdf(static_cast<T_ACC>(x));
const T_ACC pdf =
c10::cuda::compat::exp(
T_ACC(-0.5) * static_cast<T_ACC>(x) * static_cast<T_ACC>(x)) *
kBeta;
return static_cast<T_ACC>(dy) * (cdf + static_cast<T_ACC>(x) * pdf);
});
});
} else {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
it.dtype(), "GeluBackwardCUDAKernelImpl", [&]() {
using T_ACC = acc_type<scalar_t, true>;
gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t {
constexpr T_ACC kBeta = M_2_SQRTPI * M_SQRT1_2 * T_ACC(0.5);
constexpr T_ACC kAlpha = M_SQRT1_2;
const T_ACC cdf =
T_ACC(0.5) * (T_ACC(1) + ::erf(static_cast<T_ACC>(x) * kAlpha));
const T_ACC pdf =
c10::cuda::compat::exp(
T_ACC(-0.5) * static_cast<T_ACC>(x) * static_cast<T_ACC>(x)) *
kBeta;
return static_cast<T_ACC>(dy) * (cdf + static_cast<T_ACC>(x) * pdf);
});
});
}
}
namespace {
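
The retained CUDA forward computes the exact form via normcdf, i.e. GELU(x) = x * Phi(x). An eager-mode equivalence check (a sketch; a CPU tensor is enough to illustrate):

import torch
import torch.nn.functional as F
from torch.distributions import Normal

x = torch.randn(16)
phi = Normal(0.0, 1.0).cdf(x)  # standard normal CDF, like c10::cuda::compat::normcdf
print(torch.allclose(F.gelu(x), x * phi, atol=1e-6))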

View File

@ -1,5 +1,4 @@
#include <ATen/native/Activation.h>
#include <cstdint>
namespace at {
@ -25,7 +24,7 @@ void launch_prelu_cuda_backward_kernel_multi_weights(
const TensorBase &input, const TensorBase &weight, const TensorBase &grad_out,
const TensorBase &input_grad, const TensorBase &weight_grad_collector);
void GeluCUDAKernelImpl(TensorIteratorBase& it, int64_t approximate);
void GeluBackwardCUDAKernelImpl(TensorIteratorBase& it, int64_t approximate);
void GeluCUDAKernelImpl(TensorIteratorBase& it);
void GeluBackwardCUDAKernelImpl(TensorIteratorBase& it);
}} // namespace at::native

View File

@ -1,18 +1,17 @@
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Config.h>
#include <ATen/native/Activation.h>
#if !AT_MKLDNN_ENABLED()
namespace at { namespace native {
Tensor mkldnn_gelu(const Tensor& input, int64_t approximate) {
Tensor mkldnn_gelu(const Tensor& input) {
TORCH_CHECK(false, "mkldnn_gelu: ATen not compiled with MKLDNN support");
}
Tensor mkldnn_gelu_backward(const Tensor& grad_output, const Tensor& input, int64_t approximate) {
Tensor mkldnn_gelu_backward(const Tensor& grad_output, const Tensor& input) {
TORCH_CHECK(false, "mkldnn_gelu_backward: ATen not compiled with MKLDNN support");
}
@ -25,13 +24,11 @@ Tensor mkldnn_gelu_backward(const Tensor& grad_output, const Tensor& input, int6
namespace at { namespace native {
Tensor mkldnn_gelu(const Tensor& input, int64_t approximate) {
Tensor mkldnn_gelu(const Tensor& input) {
if (input.scalar_type() == ScalarType::BFloat16) {
TORCH_CHECK(mkldnn_bf16_device_check(),
"mkldnn_gelu: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq");
}
TORCH_CHECK(approximate == at::Gelu::None,
"mkldnn_gelu: fast, approximate gelu is not supported");
const ideep::tensor& x = itensor_from_tensor(input);
ideep::tensor y;
ideep::eltwise_forward::compute(
@ -40,9 +37,7 @@ Tensor mkldnn_gelu(const Tensor& input, int64_t approximate) {
input.options().device_opt());
}
Tensor mkldnn_gelu_backward(const Tensor& grad_output, const Tensor& input, int64_t approximate) {
TORCH_CHECK(approximate == at::Gelu::None,
"mkldnn_gelu_backward: fast, approximate gelu is not supported");
Tensor mkldnn_gelu_backward(const Tensor& grad_output, const Tensor& input) {
const ideep::tensor& x = itensor_from_tensor(input);
ideep::tensor grady = itensor_from_tensor(grad_output);
ideep::tensor gradx;

View File

@ -3724,7 +3724,7 @@
CPU: prelu_backward_cpu
CUDA: prelu_backward_cuda
- func: gelu.out(Tensor self, int approximate=0, *, Tensor(a!) out) -> Tensor(a!)
- func: gelu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
structured: True
structured_inherits: TensorIteratorBase
device_check: NoCheck # TensorIterator
@ -3733,7 +3733,7 @@
CPU: gelu_out_cpu
CUDA: gelu_out_cuda
- func: gelu(Tensor self, int approximate=0) -> Tensor
- func: gelu(Tensor self) -> Tensor
structured_delegate: gelu.out
device_check: NoCheck # TensorIterator
python_module: nn
@ -3741,7 +3741,7 @@
MkldnnCPU: mkldnn_gelu
QuantizedCPU: gelu_quantized_cpu
- func: gelu_backward.grad_input(Tensor grad_output, Tensor self, int approximate=0, *, Tensor(a!) grad_input) -> Tensor(a!)
- func: gelu_backward.grad_input(Tensor grad, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!)
structured: True
structured_inherits: TensorIteratorBase
python_module: nn
@ -3749,7 +3749,7 @@
CPU: gelu_backward_out_cpu
CUDA: gelu_backward_out_cuda
- func: gelu_backward(Tensor grad_output, Tensor self, int approximate=0) -> Tensor
- func: gelu_backward(Tensor grad, Tensor self) -> Tensor
structured_delegate: gelu_backward.grad_input
python_module: nn
dispatch:

View File

@ -1,7 +1,6 @@
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/native/Activation.h>
#include <ATen/native/SortingUtils.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/UpSample.h>
@ -616,7 +615,7 @@ static void leaky_qrelu_out_kernel(Tensor& out, const Tensor& qx,
});
}
void qgelu_kernel(const Tensor& qx, Tensor& qy, int64_t approximate) {
void qgelu_kernel(const Tensor& qx, Tensor& qy) {
int64_t zero_point = qx.q_zero_point();
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
float scale = qx.q_scale();
@ -627,83 +626,40 @@ void qgelu_kernel(const Tensor& qx, Tensor& qy, int64_t approximate) {
float output_scale = scale;
float inv_output_scale = 1.0 / output_scale;
const auto kAlphaVec = Vectorized<float>(M_SQRT1_2);
const auto kBetaVec = Vectorized<float>(M_SQRT2 * M_2_SQRTPI * 0.5);
const auto kKappaVec = Vectorized<float>(0.044715);
const auto kOneVec = Vectorized<float>(1);
const auto kPointFiveVec = Vectorized<float>(0.5);
if (approximate == at::Gelu::Tanh) {
AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qgelu", [&]() {
qy = at::_empty_affine_quantized(
qx.sizes(),
// NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
output_scale,
output_zero_point,
c10::nullopt);
auto iter = TensorIterator::unary_op(qy, qx);
AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qgelu", [&]() {
qy = at::_empty_affine_quantized(
qx.sizes(),
// NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
output_scale,
output_zero_point,
c10::nullopt);
auto iter = TensorIterator::unary_op(qy, qx);
using Vec = Vectorized<scalar_t>;
cpu_kernel_vec(
iter,
[&](scalar_t value_qx) -> scalar_t {
const auto value_dx =
at::native::dequantize_val(scale, zero_point, value_qx);
const auto kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
const auto kKappa = 0.044715;
const auto x_cube = value_dx * value_dx * value_dx;
const auto inner = kBeta * (value_dx + kKappa * x_cube);
const auto value_dy = 0.5 * value_dx * (1.0 + std::tanh(inner));
return at::native::quantize_val<scalar_t>(
output_scale, output_zero_point, value_dy);
},
[&](Vec value_qx) -> Vec {
auto value_dx = value_qx.dequantize(
scale_vec, zero_point_vec, scale_neg_zp_premul_vec);
for (auto & value : value_dx) {
auto value_cube = value * value * value;
auto inner = kBetaVec * (value + kKappaVec * value_cube);
value = kPointFiveVec * value * (kOneVec + inner.tanh());
}
return Vec::quantize(
value_dx, output_scale, output_zero_point, inv_output_scale);
});
});
} else {
AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qgelu", [&]() {
qy = at::_empty_affine_quantized(
qx.sizes(),
// NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
output_scale,
output_zero_point,
c10::nullopt);
auto iter = TensorIterator::unary_op(qy, qx);
using Vec = Vectorized<scalar_t>;
cpu_kernel_vec(
iter,
[&](scalar_t value_qx) -> scalar_t {
const auto value_dx =
at::native::dequantize_val(scale, zero_point, value_qx);
const auto value_dy =
value_dx * 0.5 * (1 + std::erf(value_dx * M_SQRT1_2));
return at::native::quantize_val<scalar_t>(
output_scale, output_zero_point, value_dy);
},
[&](Vec value_qx) -> Vec {
auto value_dx = value_qx.dequantize(
scale_vec, zero_point_vec, scale_neg_zp_premul_vec);
for (auto & value : value_dx) {
value = value * kPointFiveVec * (kOneVec + (value * kAlphaVec).erf());
}
return Vec::quantize(
value_dx, output_scale, output_zero_point, inv_output_scale);
});
});
}
using Vec = Vectorized<scalar_t>;
cpu_kernel_vec(
iter,
[&](scalar_t value_qx) -> scalar_t {
const auto value_dx =
at::native::dequantize_val(scale, zero_point, value_qx);
const auto value_dy =
value_dx * 0.5 * (1 + std::erf(value_dx * M_SQRT1_2));
return at::native::quantize_val<scalar_t>(
output_scale, output_zero_point, value_dy);
},
[&](Vec value_qx) -> Vec {
auto value_dx = value_qx.dequantize(
scale_vec, zero_point_vec, scale_neg_zp_premul_vec);
for (auto & value : value_dx) {
value = value * kPointFiveVec * (kOneVec + (value * kAlphaVec).erf());
}
return Vec::quantize(
value_dx, output_scale, output_zero_point, inv_output_scale);
});
});
}
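
The quantized kernel above is the standard dequantize -> compute -> requantize pattern, reusing the input's scale and zero point for the output. An eager-mode sketch of the same path (illustrative; the fused kernel avoids the intermediate float tensor):

import torch
import torch.nn.functional as F

x = torch.randn(4, 4)
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.qint8)

# Reference path: dequantize, apply the erf GELU, requantize with the
# input's quantization parameters, matching qgelu_kernel above.
ref = torch.quantize_per_tensor(F.gelu(qx.dequantize()),
                                scale=qx.q_scale(),
                                zero_point=qx.q_zero_point(),
                                dtype=torch.qint8)
qy = F.gelu(qx)  # dispatches to gelu_quantized_cpu
# Expect agreement up to one quantization step of rounding.
print(torch.allclose(qy.dequantize(), ref.dequantize(), atol=qx.q_scale()))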

View File

@ -15,9 +15,9 @@ namespace native {
DEFINE_DISPATCH(qgelu_stub);
Tensor gelu_quantized_cpu(const Tensor& qx, int64_t approximate) {
Tensor gelu_quantized_cpu(const Tensor& qx) {
Tensor qy;
qgelu_stub(qx.device().type(), qx, qy, approximate);
qgelu_stub(qx.device().type(), qx, qy);
return qy;
}
}} // namespace at::native

View File

@ -8,7 +8,7 @@ namespace native {
using qrelu_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/);
using qrelu_leaky_fn = void (*)(Tensor& /*out*/, const Tensor& /*qx*/,
const Scalar& /*negval_*/);
using qgelu_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/, int64_t /* approximate */);
using qgelu_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/);
using qsigmoid_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/, double output_scale, int64_t output_zero_point);
using qhardsigmoid_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/);
using qclamp_fn = void (*)(

View File

@ -973,17 +973,10 @@ TEST_F(FunctionalTest, GLU) {
}
TEST_F(FunctionalTest, GELU) {
GELU model;
const auto x = torch::linspace(-3.0, 3.0, 100);
const auto y_exp = x * 0.5 * (1.0 + torch::erf(x / std::sqrt(2.0)));
const auto y = F::gelu(x, F::GELUFuncOptions().approximate(torch::kNone));
ASSERT_TRUE(torch::allclose(y, y_exp, 1.4e-06, 1e-05));
}
TEST_F(FunctionalTest, TanhGELU) {
const auto x = torch::linspace(-3.0, 3.0, 100);
const auto inner = std::sqrt(2 / M_PI) * (x + 0.044715 * x.pow(3.0));
const auto y_exp = 0.5 * x * (1.0 + inner.tanh());
const auto y = F::gelu(x, F::GELUFuncOptions().approximate(torch::kTanh));
const auto y = F::gelu(x);
ASSERT_TRUE(torch::allclose(y, y_exp, 1.4e-06, 1e-05));
}

View File

@ -2854,23 +2854,13 @@ TEST_F(ModulesTest, GLU) {
}
TEST_F(ModulesTest, GELU) {
GELU model(GELUOptions().approximate(torch::kNone));
GELU model;
const auto x = torch::linspace(-3.0, 3.0, 100);
const auto y_exp = x * 0.5 * (1.0 + torch::erf(x / std::sqrt(2.0)));
const auto y = model(x);
ASSERT_TRUE(torch::allclose(y, y_exp, 1.4e-06, 1e-05));
}
TEST_F(ModulesTest, TanhGELU) {
GELU model(GELUOptions().approximate(torch::kTanh));
const auto x = torch::linspace(-3.0, 3.0, 100);
const auto inner = std::sqrt(2 / M_PI) * (x + 0.044715 * x.pow(3.0));
const auto y_exp = 0.5 * x * (1.0 + inner.tanh());
const auto y = model(x);
ASSERT_TRUE(torch::allclose(y, y_exp, 1.4e-06, 1e-05));
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST_F(ModulesTest, Mish) {
Mish model;
auto x = torch::randn(100) * 10;

View File

@ -50,8 +50,12 @@ ALLOW_LIST = [
("aten::adaptive_avg_pool3d_backward", datetime.date(9999, 1, 1)),
("aten::_embedding_bag_dense_backward", datetime.date(9999, 1, 1)),
("aten::randperm", datetime.date(9999, 1, 1)),
("aten::gelu", datetime.date(2022, 3, 1)),
("aten::gelu_backward", datetime.date(2022, 3, 1)),
("aten::_conv_depthwise2d_backward", datetime.date(2022, 1, 31)),
("aten::conv_depthwise3d_backward", datetime.date(2022, 1, 31)),
("aten::cudnn_convolution.deprecated", datetime.date(2022, 1, 31)),
("aten::cudnn_convolution.deprecated2", datetime.date(2022, 1, 31)),
("aten::cudnn_convolution_transpose.deprecated", datetime.date(2022, 1, 31)),
("aten::cudnn_convolution_transpose.deprecated2", datetime.date(2022, 1, 31)),
("aten::cudnn_convolution_backward", datetime.date(2022, 1, 31)),
("aten::cudnn_convolution_backward_input", datetime.date(2022, 1, 31)),
("aten::cudnn_convolution_backward_weight", datetime.date(2022, 1, 31)),

View File

@ -447,7 +447,7 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
%0 : int[] = prim::Constant[value=[2, 2, 1]]()
%1 : int = prim::Constant[value=0]()
%2 : Tensor = aten::t(%b)
%3 : Tensor = aten::relu(%2)
%3 : Tensor = aten::gelu(%2)
%4 : (Tensor, Tensor, Tensor[]) = prim::TupleConstruct(%b, %3, %2)
return (%4)
"""
@ -471,7 +471,7 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
%1 : int = prim::Constant[value=0]()
%d : Tensor = aten::t(%c)
%2 : Tensor = aten::t(%b)
%3 : Tensor = aten::relu(%2)
%3 : Tensor = aten::gelu(%2)
%4 : (Tensor, Tensor, Tensor[]) = prim::TupleConstruct(%3, %2, %d, %b, %c, %b)
return (%4)
"""

View File

@ -136,7 +136,7 @@ class TestExportAsContribOps(unittest.TestCase):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.gelu = torch.nn.GELU(approximate='none')
self.gelu = torch.nn.GELU()
def forward(self, x):
res = []
@ -149,7 +149,7 @@ class TestExportAsContribOps(unittest.TestCase):
res.append(x[0])
return torch.stack(res), torch.stack(res2)
def symbolic_custom_gelu(g, input, approximate):
def symbolic_custom_gelu(g, input):
return g.op("com.microsoft::Gelu", input).setType(input.type())
from torch.onnx import register_custom_op_symbolic
@ -157,7 +157,7 @@ class TestExportAsContribOps(unittest.TestCase):
x = torch.randn(3, 3, 4, requires_grad=True)
model = torch.jit.script(M())
run_model_test(self, model, input=(x,))
run_model_test(self, model, input=(x, ))
if __name__ == "__main__":
unittest.main()

View File

@ -2383,17 +2383,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
def test_gelu(self):
class GeluModel(torch.nn.Module):
def forward(self, x):
return torch.nn.functional.gelu(x, 'none')
model = GeluModel()
inputs = torch.randn(2, 4, 5, 6, requires_grad=True)
self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE)
@skipIfUnsupportedMinOpsetVersion(9)
def test_tanh_gelu(self):
class GeluModel(torch.nn.Module):
def forward(self, x):
return torch.nn.functional.gelu(x, 'tanh')
return torch.nn.functional.gelu(x)
model = GeluModel()
inputs = torch.randn(2, 4, 5, 6, requires_grad=True)

View File

@ -6225,16 +6225,7 @@ class TestONNXRuntime(unittest.TestCase):
def test_gelu(self):
class GeluModel(torch.nn.Module):
def forward(self, x):
return torch.nn.functional.gelu(x, 'none')
x = torch.randn(2, 4, 5, 6, requires_grad=True)
self.run_test(GeluModel(), x)
@skipIfUnsupportedMinOpsetVersion(9)
def test_tanh_gelu(self):
class GeluModel(torch.nn.Module):
def forward(self, x):
return torch.nn.functional.gelu(x, 'tanh')
return torch.nn.functional.gelu(x)
x = torch.randn(2, 4, 5, 6, requires_grad=True)
self.run_test(GeluModel(), x)

View File

@ -804,11 +804,11 @@ class TestUtilityFuns_opset9(_BaseTestCase):
def test_custom_opsets_gelu(self):
self.addCleanup(unregister_custom_op_symbolic, "::gelu", 1)
def gelu(g, self, approximate):
def gelu(g, self):
return g.op("com.microsoft::Gelu", self).setType(self.type())
register_custom_op_symbolic("::gelu", gelu, 1)
model = torch.nn.GELU(approximate='none')
model = torch.nn.GELU()
x = torch.randn(3, 3)
f = io.BytesIO()
torch.onnx.export(model, (x, ), f,
@ -824,11 +824,11 @@ class TestUtilityFuns_opset9(_BaseTestCase):
def test_register_aten_custom_op_symbolic(self):
self.addCleanup(unregister_custom_op_symbolic, "aten::gelu", 1)
def gelu(g, self, approximate):
def gelu(g, self):
return g.op("com.microsoft::Gelu", self).setType(self.type())
register_custom_op_symbolic("aten::gelu", gelu, 1)
model = torch.nn.GELU(approximate='none')
model = torch.nn.GELU()
x = torch.randn(3, 3)
f = io.BytesIO()
torch.onnx.export(model, (x, ), f, opset_version=self.opset_version)

View File

@ -440,9 +440,8 @@ class TestQuantizedOps(TestCase):
shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4))
dtypes = (torch.quint8, torch.qint8)
memory_formats = (torch.channels_last, torch.contiguous_format)
approximation = ['none', 'tanh']
test_cases = itertools.product(shapes, dtypes, memory_formats, approximation)
for shape, dtype, memory_format, approximate in test_cases:
test_cases = itertools.product(shapes, dtypes, memory_formats)
for shape, dtype, memory_format in test_cases:
if memory_format == torch.channels_last and len(shape) != 4:
continue
X, scale, zero_point, torch_type = \
@ -454,7 +453,7 @@ class TestQuantizedOps(TestCase):
dqX = qX.dequantize()
op = torch.nn.functional.gelu
dqY = op(dqX, approximate)
dqY = op(dqX)
qY = torch.quantize_per_tensor(dqY, scale=scale, zero_point=zero_point,
dtype=torch_type)
qY_hat = op(qX)

View File

@ -3516,7 +3516,6 @@ class TestFunctionalTracing(JitTestCase):
"adaptive_max_pool1d_with_indices": ARG_TYPE_MISMATCH,
"fractional_max_pool2d_with_indices": ARG_TYPE_MISMATCH,
"fractional_max_pool3d_with_indices": ARG_TYPE_MISMATCH,
"gelu": CONTROL_FLOW,
"hardshrink": ARG_TYPE_MISMATCH,
"layer_norm": ARG_TYPE_MISMATCH,
"lp_pool1d": ARG_TYPE_MISMATCH,

View File

@ -1260,37 +1260,6 @@ class TestTEFuser(JitTestCase):
" ".join(["Failed:", str(dtype), 'isnan', device])
)
def test_gelu(self):
def apply(fn):
return lambda x, approximate: fn(x, approximate)
unary_ops = [
F.gelu,
]
sizes = [(1,), (2,), (4, 4)]
for dtype, op, device, size in product(self.dtypes, unary_ops, self.devices, sizes):
# TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed
if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
continue
try:
x = self.data_for(dtype, device, size=size)
cond = self.data_for(torch.bool, device)
fn = apply(op)
ref = fn(x, cond)
except Exception:
# If eager mode doesn't support a dtype/op/device combo,
# neither does the fuser. Catch everything to avoid needing to
# guess what errors might be thrown by eager.
continue
try:
t = torch.jit.trace(fn, (x, cond))
torch.testing.assert_close(ref, t(x, cond))
self.assertAllFused(t.graph_for(x, cond))
except Exception as e:
raise RuntimeError(
" ".join(["Failed:", str(dtype), op.__name__, device, str(size)])
)
def test_unary_ops(self):
def apply(fn):
return lambda x: fn(x)
@ -1325,6 +1294,7 @@ class TestTEFuser(JitTestCase):
F.softplus,
torch.sqrt,
torch.rsqrt,
F.gelu,
torch.abs,
torch.ceil,
torch.floor,
@ -2237,6 +2207,7 @@ works_list = [
'mul',
'ne',
'neg',
'nn.functional.gelu',
'nn.functional.hardshrink',
'nn.functional.hardsigmoid',
'nn.functional.hardswish',

View File

@ -9153,25 +9153,16 @@ class TestNN(NNTestCase):
def _gelu_ref(X):
return X * stats.norm.cdf(X)
def _tanh_gelu_ref(X):
M_SQRT_2_PI = math.sqrt(2 / math.pi)
Z = M_SQRT_2_PI * (X + 0.044715 * np.power(X, 3.0))
return 0.5 * X * (1.0 + np.tanh(Z))
for approximate in ['none', 'tanh']:
for d in devices:
if contiguous:
X = torch.rand(n, m, dtype=dtype, requires_grad=True, device=d)
else:
X = torch.rand(n, m, dtype=dtype, requires_grad=True, device=d)[:, ::2]
res = F.gelu(X, approximate)
if approximate == 'tanh':
ref = _tanh_gelu_ref(X.to(numpy_dtype).cpu().detach().numpy())
else:
ref = _gelu_ref(X.to(numpy_dtype).cpu().detach().numpy())
self.assertEqual(res, ref, rtol=rtol, atol=atol, exact_dtype=False)
if dtype == torch.float64:
gradcheck(F.gelu, [X, approximate], eps=1e-4)
for d in devices:
if contiguous:
X = torch.rand(n, m, dtype=dtype, requires_grad=True, device=d)
else:
X = torch.rand(n, m, dtype=dtype, requires_grad=True, device=d)[:, ::2]
res = F.gelu(X)
ref = _gelu_ref(X.to(numpy_dtype).cpu().detach().numpy())
self.assertEqual(res, ref, rtol=rtol, atol=atol, exact_dtype=False)
if dtype == torch.float64:
gradcheck(F.gelu, [X], eps=1e-4)
for n in range(1, 10):
for m in range(1, 10):

View File

@ -1806,14 +1806,10 @@
- name: celu_(Tensor(a!) self, Scalar alpha=1.0) -> Tensor(a!)
self: elu_backward(grad, alpha, 1, 1.0/alpha.toFloat(), /* is_result */ true, result)
- name: gelu(Tensor self, int approximate=0) -> Tensor
self: gelu_backward(grad, self, approximate)
- name: gelu(Tensor self) -> Tensor
self: "GradMode::is_enabled() ? infinitely_differentiable_gelu_backward(grad, self) : gelu_backward(grad, self)"
result: auto_element_wise
- name: gelu_backward(Tensor grad_output, Tensor self, int approximate=0) -> Tensor
grad_output: gelu_backward(grad, self, approximate)
self: gelu_double_backward(grad, grad_output, self, approximate)
- name: glu(Tensor self, int dim=-1) -> Tensor
self: glu_backward(grad, self, dim)
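
For reference, the erf-branch gradient that both gelu_backward and infinitely_differentiable_gelu_backward compute is the standard result GELU'(x) = Phi(x) + x * phi(x), with Phi(x) = 0.5 * (1 + erf(x / sqrt(2))) and phi(x) = exp(-x^2 / 2) / sqrt(2 * pi); the kernels' kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5 is exactly 1 / sqrt(2 * pi).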

View File

@ -3,7 +3,6 @@
#include <string>
#include <ATen/core/Reduction.h>
#include <ATen/native/Activation.h>
#include <c10/util/Exception.h>
#include <c10/util/variant.h>
#include <torch/csrc/Export.h>
@ -80,11 +79,6 @@ std::string operator()(const enumtype::k##name& v) const { \
//
// Note that we also provide the default constructor `SomeOptions() {}`, so that
// `SomeOptions options = {}` can work.
#define TORCH_OPTIONS_CTOR_VARIANT_ARG2(OPTIONS_NAME, ARG_NAME, TYPE1, TYPE2) \
OPTIONS_NAME() {} \
OPTIONS_NAME(torch::enumtype::TYPE1 ARG_NAME) : ARG_NAME##_(torch::TYPE1) {} \
OPTIONS_NAME(torch::enumtype::TYPE2 ARG_NAME) : ARG_NAME##_(torch::TYPE2) {}
#define TORCH_OPTIONS_CTOR_VARIANT_ARG3(OPTIONS_NAME, ARG_NAME, TYPE1, TYPE2, TYPE3) \
OPTIONS_NAME() {} \
OPTIONS_NAME(torch::enumtype::TYPE1 ARG_NAME) : ARG_NAME##_(torch::TYPE1) {} \
@ -206,19 +200,5 @@ at::Reduction::Reduction reduction_get_enum(V variant_enum) {
}
}
template <typename V>
at::Gelu::Gelu gelu_get_enum(V variant_enum) {
if (c10::get_if<enumtype::kNone>(&variant_enum)) {
return at::Gelu::None;
} else if (c10::get_if<enumtype::kTanh>(&variant_enum)) {
return at::Gelu::Tanh;
} else {
TORCH_CHECK(
false,
get_enum_name(variant_enum), " is not a valid value for gelu approximate");
return at::Gelu::END;
}
}
} // namespace enumtype
} // namespace torch

View File

@ -336,16 +336,8 @@ inline Tensor glu(const Tensor& input, const GLUFuncOptions& options = {}) {
// ============================================================================
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor gelu(const Tensor& input, GELUFuncOptions::gelu_t approximate) {
return torch::gelu(input, enumtype::gelu_get_enum(approximate));
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */
inline Tensor gelu(const Tensor& input, const GELUFuncOptions& options = {}) {
return detail::gelu(input, options.approximate());
inline Tensor gelu(const Tensor& input) {
return torch::gelu(input);
}
// ============================================================================

View File

@ -570,17 +570,12 @@ TORCH_MODULE(GLU);
// NOLINTNEXTLINE(bugprone-exception-escape)
class TORCH_API GELUImpl : public torch::nn::Cloneable<GELUImpl> {
public:
explicit GELUImpl(const GELUOptions& options_ = {});
Tensor forward(const Tensor& input);
void reset() override;
/// Pretty prints the `GELU` module into the given `stream`.
void pretty_print(std::ostream& stream) const override;
/// The options with which this `Module` was constructed.
GELUOptions options;
};
/// A `ModuleHolder` subclass for `GELUImpl`.

View File

@ -1,7 +1,6 @@
#pragma once
#include <torch/arg.h>
#include <torch/enum.h>
#include <torch/csrc/Export.h>
#include <torch/types.h>
@ -96,37 +95,6 @@ using GLUFuncOptions = GLUOptions;
// ============================================================================
/// Options for the `GELU` module.
///
/// Example:
/// ```
/// GELU model(GELUOptions(torch::kNone));
/// ```
struct TORCH_API GELUOptions {
typedef c10::variant<enumtype::kNone, enumtype::kTanh> gelu_t;
TORCH_OPTIONS_CTOR_VARIANT_ARG2(GELUOptions, approximate, kNone, kTanh)
/// Specifies the approximation to apply to the output.
TORCH_ARG(gelu_t, approximate) = torch::kNone;
};
namespace functional {
/// Options for `torch::nn::functional::gelu`.
///
/// See the documentation for `torch::nn::GELUOptions` class to learn what
/// arguments are supported.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::gelu(input, F::GELUFuncOptions(torch::kNone));
/// ```
using GELUFuncOptions = GELUOptions;
} // namespace functional
// ============================================================================
/// Options for the `Hardshrink` module.
///
/// Example:

View File

@ -284,10 +284,8 @@ void GLUImpl::pretty_print(std::ostream& stream) const {
// ============================================================================
GELUImpl::GELUImpl(const GELUOptions& options_) : options(options_) {}
Tensor GELUImpl::forward(const Tensor& input) {
return F::detail::gelu(input, options.approximate());
return F::gelu(input);
}
void GELUImpl::reset() {}

View File

@ -12,7 +12,6 @@
#include <ATen/ExpandUtils.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/Activation.h>
#include <ATen/ScalarOps.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/Utils.h>
@ -2339,46 +2338,6 @@ std::tuple<Tensor, Tensor, Tensor> prelu_double_backward(
}
}
Tensor gelu_double_backward(
const Tensor & ggI,
const Tensor & gO,
const Tensor & input,
int64_t approximate) {
if (approximate == at::Gelu::Tanh) {
constexpr auto kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
constexpr auto kKappa = 0.044715;
auto inner = kBeta * (input + kKappa * pow(input, 3));
auto tanh_inner = tanh(inner);
auto sech_inner = 1 / cosh(inner);
auto f = 0.5 * input;
auto g = 1 - tanh_inner * tanh_inner;
auto h = kBeta * (1 + 3 * kKappa * input * input);
auto f_prime_gh = 0.5 * g * h;
auto g_prime = (2 * sech_inner) * (-sech_inner * tanh_inner) * h;
auto g_prime_fh = f * h * g_prime;
auto h_prime = 6 * kKappa * input * kBeta;
auto h_prime_fg = f * g * h_prime;
// left_derivative = f_prime_gh
// right_derivative = f_prime_gh + g_prime_fh + h_prime_fg
// dgrad_dX = left_derivative + right_derivative
auto gI = ggI * gO * (2 * f_prime_gh + g_prime_fh + h_prime_fg);
return gI;
} else {
constexpr auto kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5;
auto input_sq = input * input;
auto pdf = kBeta * at::exp(-0.5 * input_sq);
auto dgrad_dInput = 2 * pdf - input_sq * pdf;
auto gI = ggI * gO * dgrad_dInput;
return gI;
}
}
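
The deleted gelu_double_backward encodes the second derivative of GELU; for the erf branch it is GELU''(x) = (2 - x^2) * phi(x), which is the kernel's 2 * pdf - input_sq * pdf. A SymPy sanity check (a sketch; expect 0):

import sympy as sp

x = sp.symbols('x')
Phi = (1 + sp.erf(x / sp.sqrt(2))) / 2          # standard normal CDF
phi = sp.exp(-x**2 / 2) / sp.sqrt(2 * sp.pi)    # standard normal PDF

gelu = x * Phi
second = sp.diff(gelu, x, 2)
claimed = 2 * phi - x**2 * phi                  # dgrad_dInput above
print(sp.simplify(second - claimed))            # expect 0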
Tensor elu_double_backward(
const Tensor& grad,
const Tensor& grad_output,

View File

@ -303,11 +303,6 @@ std::tuple<Tensor, Tensor, Tensor> prelu_double_backward(
const Tensor & grad_out,
const Tensor & input_,
const Tensor & weight_);
Tensor gelu_double_backward(
const Tensor & ggI,
const Tensor & gO,
const Tensor & input,
int64_t approximate);
Tensor as_strided_backward(Tensor grad, TensorGeometry input_geometry, IntArrayRef sizes, IntArrayRef strides, optional<int64_t> storage_offset_);
std::tuple<Tensor, Tensor> atan2_backward(const Tensor& grad, const Tensor& self, const Tensor& other, std::array<bool, 2> output_mask);
std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(

View File

@ -12,8 +12,6 @@
#include <torch/csrc/jit/frontend/function_schema_parser.h>
#include <torch/csrc/jit/ir/constants.h>
#include <ATen/native/Activation.h>
#include <unordered_map>
#include <utility>
@ -2275,8 +2273,7 @@ class IrParser {
}
{
auto ptr_op = getOperatorForLiteral(
"aten::gelu(Tensor self, int approximate=0) -> Tensor");
auto ptr_op = getOperatorForLiteral("aten::gelu(Tensor self) -> Tensor");
REGISTER_PARSE_RULE(
ptr_op,
{
@ -2286,20 +2283,7 @@ class IrParser {
c10::nullopt, value_map[node->inputs()[0]->unique()]);
auto self = list_val.front();
list_val.pop_front();
auto approximate = constant_as<int64_t>(node->input(1));
TORCH_INTERNAL_ASSERT(
approximate.has_value(),
"The approximate parameter is required.");
const bool kApproximate = approximate.value();
Val* out = nullptr;
if (kApproximate == at::Gelu::Tanh) {
out = fast_gelu(self);
} else {
out = unaryOp(UnaryOpType::Gelu, self);
}
auto out = gelu(self);
value_map.emplace(
node->output()->unique(), ValueHolder(out, format));
},
@ -2309,7 +2293,7 @@ class IrParser {
{
auto ptr_op = getOperatorForLiteral(
"aten::gelu_backward(Tensor grad_output, Tensor self, int approximate=0) -> Tensor");
"aten::gelu_backward(Tensor grad, Tensor self) -> Tensor");
REGISTER_PARSE_RULE(
ptr_op,
{
@ -2324,19 +2308,7 @@ class IrParser {
auto self = list_val.front();
list_val.pop_front();
auto approximate = constant_as<int64_t>(node->input(2));
TORCH_INTERNAL_ASSERT(
approximate.has_value(),
"The approximate parameter is required.");
const bool kApproximate = approximate.value();
Val* grad_in = nullptr;
if (kApproximate == at::Gelu::Tanh) {
grad_in = fast_gelu_backward(grad_out, self);
} else {
grad_in = gelu_backward(grad_out, self);
}
auto grad_in = gelu_backward(grad_out, self);
value_map.emplace(
node->output()->unique(), ValueHolder(grad_in, format));
},
@ -3043,38 +3015,6 @@ bool insertProfileIValue(ProfilingRecord* pr, Node* node, size_t offset) {
}
}
static auto gelu_schema =
getOperatorForLiteral(
"aten::gelu(Tensor self, int approximate=0) -> Tensor")
->schema();
if (node->matches(gelu_schema)) {
switch (offset) {
// argument 1: approximate;
case 1:
profileInt(pr, node, offset);
break;
default:
return false;
}
return true;
}
static auto gelu_backward_schema =
getOperatorForLiteral(
"aten::gelu_backward(Tensor grad_output, Tensor self, int approximate=0) -> Tensor")
->schema();
if (node->matches(gelu_backward_schema)) {
switch (offset) {
// argument 2: approximate;
case 2:
profileInt(pr, node, offset);
break;
default:
return false;
}
return true;
}
static auto softmax_backward_data_schema =
getOperatorForLiteral(
"aten::_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor")

View File

@ -56,14 +56,6 @@ def full_0_4(size:List[int], fill_value:number, *, dtype:Optional[int]=None,
{"full_out_0_4", R"SCRIPT(
def full_out_0_4(size:List[int], fill_value:number, *, out:Tensor) -> Tensor:
return torch.full(size, fill_value, out=out)
)SCRIPT"},
{"gelu_0_8", R"SCRIPT(
def gelu_0_8(self: Tensor) -> Tensor:
return torch._C._nn.gelu(self, 0)
)SCRIPT"},
{"gelu_out_0_8", R"SCRIPT(
def gelu_out_0_8(self: Tensor, *, out: Tensor) -> Tensor:
return torch._C._nn.gelu(self, 0, out=out)
)SCRIPT"}});
std::shared_ptr<Graph> create_upgrader_graph(

View File

@ -43,16 +43,7 @@ static std::unordered_map<std::string, std::vector<UpgraderEntry>> operatorVersi
{"aten::full.out",
{{5,
"full_out_0_4",
"aten::full.out(int[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!)"}}},
{"aten::gelu",
{{9,
"gelu_0_8",
"aten::gelu(Tensor self) -> Tensor"}}},
{"aten::gelu.out",
{{9,
"gelu_out_0_8",
"aten::gelu(Tensor self, *, Tensor(a!) out) -> Tensor"}}}
});
"aten::full.out(int[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!)"}}}});
const std::unordered_map<std::string, std::vector<UpgraderEntry>>&
get_operator_version_map() {

View File

@ -872,7 +872,7 @@ class ShapePropagator : public PropertyPropBase {
"aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
"aten::rsqrt(Tensor self) -> Tensor",
"aten::selu(Tensor self) -> Tensor",
"aten::gelu(Tensor self, int approximate=0) -> Tensor",
"aten::gelu(Tensor self) -> Tensor",
"aten::sigmoid(Tensor self) -> Tensor",
"aten::sign(Tensor self) -> Tensor",
"aten::sin(Tensor self) -> Tensor",

View File

@ -913,10 +913,16 @@ const std::vector<std::string> functions = {
return grad_output * torch.where(self > 0, 1.0, negative_slope).type_as(result), None
return result, backward
def gelu(self : Tensor, approximate : int):
result = torch.gelu(self, approximate)
def gelu(self):
result = torch.gelu(self)
def backward(grad_output):
return torch.gelu_backward(grad_output, self, approximate), None
m_2_sqrtpi = 1.12837916709551257390
m_sqrt1_2 = 0.707106781186547524401
alpha = m_sqrt1_2
beta = m_2_sqrtpi * m_sqrt1_2 * 0.5
cdf = (torch.erf(self * m_sqrt1_2) + 1.0) * 0.5
pdf = beta * torch.exp(self * self * -0.5)
return grad_output * (cdf + self * pdf)
return result, backward
def hardswish(self):

View File

@ -74,7 +74,7 @@ const OperatorMap<std::string>& get_tensorexpr_elementwise_set() {
{"aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor", "unary"},
{"aten::softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor", "unary"},
{"aten::relu6(Tensor self) -> Tensor", "unary"},
{"aten::gelu(Tensor self, int approximate=0) -> Tensor", "unary"},
{"aten::gelu(Tensor self) -> Tensor", "unary"},
{"aten::neg(Tensor self) -> Tensor", "unary"},
{"aten::reciprocal(Tensor self) -> Tensor", "unary"},
{"aten::expm1(Tensor self) -> Tensor", "unary"},

View File

@ -3,8 +3,6 @@
#include <torch/csrc/jit/tensorexpr/lowerings.h>
#include <torch/csrc/jit/tensorexpr/operators/operators.h>
#include <ATen/native/Activation.h>
namespace torch {
namespace jit {
namespace tensorexpr {
@ -643,34 +641,21 @@ int nnc_lowerings_lazy_registration() {
});
RegisterNNCLoweringsFunction aten_gelu(
{"aten::gelu(Tensor self, int approximate=0) -> (Tensor)"},
{"aten::gelu(Tensor self) -> (Tensor)"},
[](const std::vector<ArgValue>& inputs,
const std::vector<ExprHandle>& outputShape,
const c10::optional<ScalarType>& outputType,
at::Device device) {
return computeOneOperandWithCondition(
return computeOneOperand(
"aten_gelu",
inputs,
outputShape,
outputType,
[](const ExprHandle& a, const ExprHandle& approximate) {
[](const ExprHandle& a) {
auto m_sqrt1_2 = Cast::make(a.dtype(), M_SQRT1_2);
auto one = Cast::make(a.dtype(), 1.);
auto point_five = Cast::make(a.dtype(), .5);
auto tanh_gelu_flag = Cast::make(approximate.dtype(), at::Gelu::Tanh);
// approximate == 'none'
auto m_sqrt1_2 = Cast::make(a.dtype(), M_SQRT1_2);
auto gelu_result = a * point_five * (one + erf(a * m_sqrt1_2));
// approximate == 'tanh'
auto beta = Cast::make(a.dtype(), M_SQRT2 * M_2_SQRTPI * 0.5);
auto kappa = Cast::make(a.dtype(), 0.044715);
auto a_cube = a * a * a;
auto inner = beta * (a + kappa * a_cube);
auto tanh_gelu_result = point_five * a * (one + tanh(inner));
auto cs = CompareSelect::make(approximate, tanh_gelu_flag, kEQ);
return ifThenElse(cs, tanh_gelu_result, gelu_result);
return a * point_five * (one + erf(a * m_sqrt1_2));
});
});
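
The deleted lowering computed both formulas and selected per element with a CompareSelect on the runtime approximate flag. In eager terms the select collapses to a branch, roughly as below (a sketch with a hypothetical helper name, not NNC IR):

import math
import torch

def gelu_select(a: torch.Tensor, approximate: int) -> torch.Tensor:
    # 0 = erf ('none'), 1 = tanh, mirroring at::Gelu in the reverted change.
    erf_result = a * 0.5 * (1.0 + torch.erf(a * math.sqrt(0.5)))
    inner = math.sqrt(2.0 / math.pi) * (a + 0.044715 * a**3)
    tanh_result = 0.5 * a * (1.0 + torch.tanh(inner))
    return tanh_result if approximate == 1 else erf_result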

View File

@ -43,31 +43,6 @@ Tensor computeOneOperand(
});
}
Tensor computeOneOperandWithCondition(
const std::string& name,
const std::vector<ArgValue>& inputValues,
const std::vector<ExprHandle>& outputShape,
const c10::optional<ScalarType>& outputType,
const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
innerExpr,
const int checkParamTypes) {
return Compute(
name,
c10::fmap<DimArg>(outputShape),
[inputValues, outputType, innerExpr, checkParamTypes](
const std::vector<VarHandle>& axes) {
std::vector<ExprHandle> indices(axes.begin(), axes.end());
std::vector<ExprHandle> inputs = {
tensorOrConstant(inputValues[0], indices)};
promoteInputs(inputs, checkParamTypes);
// Last expr is the condition, which we don't promote
inputs.emplace_back(tensorOrConstant(inputValues[1], indices));
ExprHandle compute = innerExpr(inputs[0], inputs[1]);
return demoteOutput(compute, outputType);
});
}
Tensor computeTwoOperand(
const std::string& name,
const std::vector<ArgValue>& inputValues,

View File

@ -17,14 +17,6 @@ Tensor computeOneOperand(
const c10::optional<ScalarType>& outputType,
const std::function<ExprHandle(const ExprHandle&)>& innerExpr,
const int checkParamTypes = kAllTypes);
Tensor computeOneOperandWithCondition(
const std::string& name,
const std::vector<ArgValue>& inputValues,
const std::vector<ExprHandle>& outputShape,
const c10::optional<ScalarType>& outputType,
const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
innerExpr,
const int checkParamTypes = kAllTypes);
Tensor computeTwoOperand(
const std::string& name,
const std::vector<ArgValue>& inputValues,

View File

@ -1,11 +0,0 @@
# Keep this file in sync with enums in aten/src/ATen/core/Gelu.h
def get_enum(gelu_approximation: str) -> int:
if gelu_approximation == 'none':
ret = 0
elif gelu_approximation == 'tanh':
ret = 1
else:
ret = -1 # TODO: remove once JIT exceptions support control flow
raise ValueError("{} is not a valid value for gelu approximation".format(gelu_approximation))
return ret

View File

@ -20,7 +20,6 @@ from ..overrides import (
has_torch_function, has_torch_function_unary, has_torch_function_variadic,
handle_torch_function)
from . import _reduction as _Reduction
from . import _gelu as _Gelu
from . import grad # noqa: F401
from .modules import utils
from .modules.utils import _single, _pair, _triple, _list_with_default
@ -1652,31 +1651,19 @@ See :class:`~torch.nn.LogSigmoid` for more details.
)
def gelu(input: Tensor, approximate: str = 'none') -> Tensor:
r"""gelu(input, approximate = 'none') -> Tensor
def gelu(input):
r"""gelu(input) -> Tensor
Applies element-wise the function
:math:`\text{GELU}(x) = x * \Phi(x)`
where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.
When the approximate argument is 'tanh', Gelu is estimated with:
:math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt(2 / \pi) * (x + 0.044715 * x^3)))
See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_.
"""
if has_torch_function_unary(input):
return handle_torch_function(gelu, (input,), input, approximate=approximate)
# Enforce that the full call with the new kwarg is not invoked when scripting.
# TODO: Remove this scripting logic once the 2-week FC window has passed.
if not torch.jit.is_scripting():
return torch._C._nn.gelu(input, _Gelu.get_enum(approximate))
# When scripting, make a simpler call as long as the kwarg is set to the default value.
elif approximate == 'none':
return torch._C._nn.gelu(input)
else:
raise RuntimeError("TorchScript currently does not support approximate in nn.Gelu")
return handle_torch_function(gelu, (input,), input)
return torch._C._nn.gelu(input)
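
Call-site impact of the revert, in short (usage sketch):

import torch
import torch.nn.functional as F

x = torch.randn(8)
y = F.gelu(x)                        # exact erf-based GELU; the only form after this revert

# Removed by this commit (raises TypeError once reverted):
# y = F.gelu(x, approximate='tanh')  # tanh approximation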
def hardshrink(input: Tensor, lambd: float = 0.5) -> Tensor:

View File

@ -141,7 +141,7 @@ def rrelu(input: Tensor, lower: float = ..., upper: float = ..., training: bool
inplace: bool = ...) -> Tensor: ...
def gelu(input: Any, approximate: str = ...): ...
def gelu(input: Any): ...
def hardshrink(input: Tensor, lambd: float = ...) -> Tensor: ...

View File

@ -654,13 +654,6 @@ class GELU(Module):
where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.
When the approximate argument is 'tanh', Gelu is estimated with:
:math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt(2 / \pi) * (x + 0.044715 * x^3)))
Args:
approximate (string, optional): the gelu approximation algorithm to use:
``'none'`` | ``'tanh'``. Default: ``'none'``
Shape:
- Input: :math:`(*)`, where :math:`*` means any number of dimensions.
- Output: :math:`(*)`, same shape as the input.
@ -673,18 +666,8 @@ class GELU(Module):
>>> input = torch.randn(2)
>>> output = m(input)
"""
__constants__ = ['approximate']
approximate: str
def __init__(self, approximate: str = 'none') -> None:
super(GELU, self).__init__()
self.approximate = approximate
def forward(self, input: Tensor) -> Tensor:
return F.gelu(input, self.approximate)
def extra_repr(self) -> str:
return 'approximate={}'.format(self.approximate)
return F.gelu(input)
class Hardshrink(Module):

View File

@ -3014,27 +3014,12 @@ def remainder(g, input, other):
quo = g.op("Mul", div, other)
return g.op("Sub", input, quo)
@parse_args("v", "i")
def gelu(g, self, approximate):
# none approximate : onnx::Constant[value={0}]
# tanh approximate : onnx::Constant[value={1}]
if approximate == 1:
kBeta = math.sqrt(2 / math.pi)
kKappa = 0.044715
beta = torch.tensor(kBeta, dtype=torch.double)
kappa = torch.tensor(kKappa, dtype=torch.double)
one = torch.tensor(1., dtype=torch.double)
half = torch.tensor(0.5, dtype=torch.double)
self_cube = mul(g, self, mul(g, self, self))
inner = mul(g, beta, add(g, self, mul(g, kappa, self_cube)))
return mul(g, half, mul(g, self, add(g, one, g.op("Tanh", inner))))
else:
_sqrt2 = 1.4142135623730951
erf = g.op("Erf", g.op("Div", self, torch.tensor(_sqrt2, dtype=torch.double)))
erf_plusone = add(g, erf, g.op("Constant", value_t=torch.tensor(1, dtype=torch.double)))
return mul(g, mul(g, self, erf_plusone), g.op("Constant", value_t=torch.tensor(0.5, dtype=torch.double)))
def gelu(g, self):
_sqrt2 = 1.4142135623730951
erf = g.op("Erf", g.op("Div", self, torch.tensor(_sqrt2, dtype=torch.double)))
erf_plusone = add(g, erf, g.op("Constant", value_t=torch.tensor(1, dtype=torch.double)))
return mul(g, mul(g, self, erf_plusone), g.op("Constant", value_t=torch.tensor(0.5, dtype=torch.double)))
@parse_args("v", "i", "v", "v", "f", "i")
def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled):

View File

@ -730,7 +730,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False,
_random_samples=None: -1),
torch.nn.functional.gaussian_nll_loss: lambda input, target, var, full=False, eps=1e-06, reduction='mean': -1,
torch.nn.functional.gelu: lambda input, approximate='none': -1,
torch.nn.functional.gelu: lambda input: -1,
torch.nn.functional.glu: lambda input, dim=-1: -1,
torch.nn.functional.grid_sample: lambda input, grid, mode='bilinear', padding_mode='zeros', align_corners=None: -1,
torch.nn.functional.group_norm: lambda input, num_groups, weight=None, bias=None, eps=1e-05: -1,

View File

@ -327,8 +327,7 @@ class AutocastCPUTestLists(object):
self.nn_fp32 = [
("avg_pool2d", dummy_bf16[2], {"kernel_size": (3, 2), "stride": (1, 1)}),
("avg_pool3d", dummy_bf16[3], {"kernel_size": (3, 3, 3), "stride": (1, 1, 1)}),
("gelu", dummy_bf16[3], {"approximate": torch.nn._gelu.get_enum('none')}),
("gelu", dummy_bf16[3], {"approximate": torch.nn._gelu.get_enum('tanh')}),
("gelu", dummy_bf16[3]),
("upsample_nearest1d", dummy_bf16[2], {"output_size": (n)}),
("upsample_nearest2d", dummy_bf16[3], {"output_size": (n, n)}),
("upsample_nearest3d", dummy_bf16[4], {"output_size": (n, n, n)}),

View File

@ -3903,6 +3903,7 @@ def sample_inputs_layer_norm(opinfo, device, dtype, requires_grad, **kwargs):
# With `None` weight and bias (tests failing for this, see the link above)
# yield SampleInput(make_arg((1, 2)), args=((2,), None, make_arg((2,))))
def sample_inputs_local_response_norm(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@ -3924,6 +3925,7 @@ def sample_inputs_local_response_norm(opinfo, device, dtype, requires_grad, **kw
for input_shape, size, kwargs in cases:
yield SampleInput(make_arg(input_shape), args=(size,), kwargs=kwargs)
def sample_inputs_hardswish(self, device, dtype, requires_grad, **kwargs):
N = 5
# make sure we are testing -3 -> 3 range. default is -10 -> 10 so maybe unnecessary ?
@ -4080,13 +4082,8 @@ def sample_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs):
def sample_inputs_gelu(self, device, dtype, requires_grad, **kwargs):
N = 5
tensors = []
for _ in range(1, N):
for approximate in ['none', 'tanh']:
tensors.append(SampleInput(
make_tensor((N * 2, N * 2), device=device, dtype=dtype,
requires_grad=requires_grad, low=-3, high=3),
kwargs=dict(approximate=approximate)))
tensors = [SampleInput(make_tensor((N * 2, N * 2), device=device, dtype=dtype,
requires_grad=requires_grad, low=-3, high=3)) for _ in range(1, N)]
return tensors
def sample_inputs_max_min_reduction_with_dim(op_info, device, dtype, requires_grad, **kwargs):
@ -11776,7 +11773,7 @@ op_db: List[OpInfo] = [
supports_gradgrad=True,
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=False,
supports_fwgrad_bwgrad=True,
autodiff_nonfusible_nodes=["aten::gelu"]),
OpInfo('nn.functional.relu6',
aten_name="relu6",

View File

@ -3716,16 +3716,12 @@ new_module_tests = [
),
dict(
module_name='GELU',
constructor_args=('none',),
cpp_constructor_args='torch::nn::GELUOptions().approximate(torch::kNone)',
input_size=(),
desc='scalar',
reference_fn=lambda x, *_: x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))),
),
dict(
module_name='GELU',
constructor_args=('none',),
cpp_constructor_args='torch::nn::GELUOptions().approximate(torch::kNone)',
input_size=(3, 2, 5),
reference_fn=lambda x, *_: x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))),
),