Revert D33850228: [pytorch][PR] Implement Tanh Gelu Approximation
Test Plan: revert-hammer

Differential Revision: D33850228 (23d03025dc)
Original commit changeset: 3cc33fb298e4
Original Phabricator Diff: D33850228 (23d03025dc)
fbshipit-source-id: 9436e7df73c2b2e2011f321674f24973316d3692

parent 214624e254
commit c9efb58223

@@ -485,7 +485,7 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
   KERNEL_CPU(ADD_NS(avg_pool1d), "avg_pool1d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool), fp32)
   KERNEL_CPU(ADD_NS(avg_pool2d), "avg_pool2d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool, c10::optional<int64_t>), fp32)
   KERNEL_CPU(ADD_NS(avg_pool3d), "avg_pool3d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool, c10::optional<int64_t>), fp32)
-  KERNEL_CPU(ADD_NS(gelu), "gelu", Tensor (const Tensor &, int64_t), fp32)
+  KERNEL_CPU(ADD_NS(gelu), "gelu", Tensor (const Tensor &), fp32)
   KERNEL_CPU(ADD_NS(upsample_nearest1d), "upsample_nearest1d", Tensor (const Tensor &, IntArrayRef, c10::optional<double>), fp32)
   KERNEL_CPU(ADD_NS(upsample_nearest1d), "upsample_nearest1d.vec", Tensor (const Tensor &, c10::optional<IntArrayRef>, c10::optional<ArrayRef<double>>), fp32)
   KERNEL_CPU(ADD_NS(_upsample_nearest_exact1d), "_upsample_nearest_exact1d", Tensor (const Tensor &, IntArrayRef, c10::optional<double>), fp32)

@@ -164,12 +164,12 @@ TORCH_META_FUNC(softshrink_backward) (
   build_borrowing_binary_op(maybe_get_output(), grad, self);
 }
 
-TORCH_META_FUNC(gelu) (const Tensor & self, int64_t approximate) {
+TORCH_META_FUNC(gelu) (const Tensor & self) {
   build_unary_op(maybe_get_output(), self);
 }
 
 TORCH_META_FUNC(gelu_backward) (
-  const Tensor& grad, const Tensor& self, int64_t approximate
+  const Tensor& grad, const Tensor& self
 ) {
   build_borrowing_binary_op(maybe_get_output(), grad, self);
 }

@@ -324,37 +324,37 @@ bool use_mkldnn(const Tensor& input) {
 }
 
 TORCH_IMPL_FUNC(gelu_out_cpu) (
-  const Tensor& self, int64_t approximate, const Tensor& result
+  const Tensor& self, const Tensor& result
 ) {
 #if AT_MKLDNN_ENABLED()
-  if (use_mkldnn(self) && (approximate == at::Gelu::None)) {
+  if (use_mkldnn(self)) {
     const ideep::tensor& x = itensor_from_tensor(self);
     ideep::tensor y = itensor_from_tensor(result);
     ideep::eltwise_forward::compute(
       x, y, ideep::algorithm::eltwise_gelu_erf, ideep::prop_kind::forward_training, /*alpha*/ 0.0);
   } else {
-    GeluKernel(kCPU, *this, approximate);
+    GeluKernel(kCPU, *this);
   }
 #else
-  GeluKernel(kCPU, *this, approximate);
+  GeluKernel(kCPU, *this);
 #endif
 }
 
 TORCH_IMPL_FUNC(gelu_backward_out_cpu) (
-  const Tensor& grad, const Tensor& self, int64_t approximate, const Tensor& grad_input
+  const Tensor& grad, const Tensor& self, const Tensor& grad_input
 ) {
 #if AT_MKLDNN_ENABLED()
-  if (use_mkldnn(self) && (approximate == at::Gelu::None)) {
+  if (use_mkldnn(self)) {
     const ideep::tensor& x = itensor_from_tensor(self);
     ideep::tensor grady = itensor_from_tensor(grad);
     ideep::tensor gradx = itensor_from_tensor(grad_input);
     ideep::eltwise_backward::compute(x, grady, gradx,
       ideep::algorithm::eltwise_gelu_erf, /*alpha*/ 0.0);
   } else {
-    GeluBackwardKernel(kCPU, *this, approximate);
+    GeluBackwardKernel(kCPU, *this);
   }
 #else
-  GeluBackwardKernel(kCPU, *this, approximate);
+  GeluBackwardKernel(kCPU, *this);
 #endif
 }

@@ -12,19 +12,6 @@ struct TensorIteratorBase;
 class TensorBase;
 }
 
-namespace at {
-namespace Gelu {
-
-// Keep this in sync with Gelu class in torch/nn/_gelu.py
-// These constants control the approximation behavior of gelu functions.
-enum Gelu {
-  None, // Baseline Gelu
-  Tanh, // Tahn Gelu Approximation
-  END
-};
-} // namespace Gelu
-} // namespace at
-
 namespace at { namespace native {
 
 using structured_activation_fn = void (*)(TensorIteratorBase&);

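For reference, the two definitions this removed enum switched between — the exact (erf) form and the tanh approximation — both appear verbatim in the kernels later in this diff:

$$\mathrm{GELU}(x) = x\,\Phi(x) = \tfrac{x}{2}\Big(1 + \operatorname{erf}\big(x/\sqrt{2}\big)\Big), \qquad \mathrm{GELU}_{\tanh}(x) = \tfrac{x}{2}\Big(1 + \tanh\big(\sqrt{2/\pi}\,(x + 0.044715\,x^{3})\big)\Big)$$
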
@@ -48,8 +35,6 @@ using elu_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const
 using leaky_relu_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
 using leaky_relu_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
 using log_sigmoid_cpu_fn = void (*)(TensorBase&, TensorBase&, const TensorBase&);
-using gelu_fn = void (*)(TensorIteratorBase&, int64_t);
-using gelu_backward_fn = void (*)(TensorIteratorBase&, int64_t);
 
 DECLARE_DISPATCH(elu_fn, elu_stub);
 DECLARE_DISPATCH(elu_backward_fn, elu_backward_stub);

@@ -58,8 +43,8 @@ DECLARE_DISPATCH(softplus_backward_fn, softplus_backward_stub);
 DECLARE_DISPATCH(log_sigmoid_cpu_fn, log_sigmoid_cpu_stub);
 DECLARE_DISPATCH(activation_backward_fn, log_sigmoid_backward_stub);
 DECLARE_DISPATCH(threshold_fn, threshold_stub);
-DECLARE_DISPATCH(gelu_fn, GeluKernel);
-DECLARE_DISPATCH(gelu_backward_fn, GeluBackwardKernel);
+DECLARE_DISPATCH(structured_activation_fn, GeluKernel);
+DECLARE_DISPATCH(structured_activation_backward_fn, GeluBackwardKernel);
 DECLARE_DISPATCH(hardtanh_backward_fn, hardtanh_backward_stub);
 DECLARE_DISPATCH(hardsigmoid_fn, hardsigmoid_stub);
 DECLARE_DISPATCH(hardsigmoid_backward_fn, hardsigmoid_backward_stub);

@@ -166,7 +166,7 @@ void elu_backward_kernel(TensorIteratorBase& it, const Scalar& alpha, const Scal
 // TODO(yangxm): Add another fast kernel using formula
 // y = 0.5x * (1 + tanh(sqrt(2/Pi) * (x + 0.044715x^3)))
 // and the fast tanh impl from Eigen.
-void GeluKernelImpl(TensorIteratorBase& it, int64_t approximate) {
+void GeluKernelImpl(TensorIteratorBase& it) {
   auto grain_size = at::internal::GRAIN_SIZE;
   // Numbers based on benchmarking.
   // Benchmark: benchmarks/operator_benchmarks/pt/gelu_test.py

@@ -187,134 +187,53 @@ void GeluKernelImpl(TensorIteratorBase& it, int64_t approximate) {
   if (it.numel() > GELU_MIN_ELEMENTS_FOR_MULTI_THREADING) {
     grain_size = it.numel() / at::get_num_threads();
   }
-  if (approximate == at::Gelu::Tanh) {
-    AT_DISPATCH_FLOATING_TYPES_AND(
-        ScalarType::BFloat16, it.dtype(), "GeluKernelImpl", [&]() {
-          using Vec = vec::Vectorized<scalar_t>;
-          const Vec kBetaVec(scalar_t(M_SQRT2 * M_2_SQRTPI * 0.5));
-          const Vec kKappaVec(scalar_t(0.044715));
-          const Vec kOneVec(scalar_t(1));
-          const Vec kPointFiveVec(scalar_t(0.5));
-          cpu_kernel_vec(
-              it,
-              [](scalar_t x) {
-                const scalar_t kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
-                const scalar_t kKappa = 0.044715;
-                auto x_cube = x * x * x;
-                auto inner = kBeta * (x + kKappa * x_cube);
-                return scalar_t(0.5) * x * (scalar_t(1) + std::tanh(inner));
-              },
-              [&](Vec x_vec) {
-                auto x_cube = x_vec * x_vec * x_vec;
-                auto inner_vec = kBetaVec * (x_vec + kKappaVec * x_cube);
-                return kPointFiveVec * x_vec * (kOneVec + inner_vec.tanh());
-              },
-              grain_size);
-        });
-  } else {
-    AT_DISPATCH_FLOATING_TYPES_AND(
-        ScalarType::BFloat16, it.dtype(), "GeluKernelImpl", [&]() {
-          using Vec = vec::Vectorized<scalar_t>;
-          const Vec kAlphaVec(scalar_t(M_SQRT1_2));
-          const Vec kOneVec(scalar_t(1));
-          const Vec kPointFiveVec(scalar_t(0.5));
-          cpu_kernel_vec(
-              it,
-              [](scalar_t x) {
-                const scalar_t kAlpha = scalar_t(M_SQRT1_2);
-                return x * scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
-              },
-              [&](Vec x_vec) {
-                return x_vec * kPointFiveVec *
-                    (kOneVec + (x_vec * kAlphaVec).erf());
-              },
-              grain_size);
-        });
-  }
+  AT_DISPATCH_FLOATING_TYPES_AND(
+      ScalarType::BFloat16, it.dtype(), "GeluKernelImpl", [&]() {
+        using Vec = vec::Vectorized<scalar_t>;
+        const Vec kAlphaVec(scalar_t(M_SQRT1_2));
+        const Vec kOneVec(scalar_t(1));
+        const Vec kPointFiveVec(scalar_t(0.5));
+        cpu_kernel_vec(
+            it,
+            [](scalar_t x) {
+              const scalar_t kAlpha = scalar_t(M_SQRT1_2);
+              return x * scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
+            },
+            [&](Vec x_vec) {
+              return x_vec * kPointFiveVec *
+                  (kOneVec + (x_vec * kAlphaVec).erf());
+            },
+            grain_size);
+      });
 }
 
-void GeluBackwardKernelImpl(TensorIteratorBase& it, int64_t approximate) {
-  if (approximate == at::Gelu::Tanh) {
-    AT_DISPATCH_FLOATING_TYPES_AND(
-        ScalarType::BFloat16, it.dtype(), "GeluBackwardKernelImpl", [&]() {
-          using Vec = vec::Vectorized<scalar_t>;
-          const Vec kBetaVec(scalar_t(M_SQRT2 * M_2_SQRTPI * 0.5));
-          const Vec kKappaVec(scalar_t(0.044715));
-          const Vec kOneVec(scalar_t(1));
-          const Vec kThreeVec(scalar_t(3));
-          const Vec kPointFiveVec(scalar_t(0.5));
-          cpu_kernel_vec(
-              it,
-              [](scalar_t dy, scalar_t x) {
-                const scalar_t kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
-                const scalar_t kKappa = 0.044715;
-                auto x_sq = x * x;
-                auto x_cube = x_sq * x;
-                auto inner = kBeta * (x + kKappa * x_cube);
-                auto tanh_inner = std::tanh(inner);
-
-                auto left = scalar_t(0.5) * x;
-                auto right = scalar_t(1) + tanh_inner;
-
-                auto left_derivative = scalar_t(0.5) * right;
-
-                auto tanh_derivative = scalar_t(1) - tanh_inner * tanh_inner;
-                auto inner_derivative =
-                    kBeta * (scalar_t(1) + scalar_t(3) * kKappa * x_sq);
-                auto right_derivative = left * tanh_derivative * inner_derivative;
-
-                return dy * (left_derivative + right_derivative);
-              },
-              [&](Vec dy_vec, Vec x_vec) {
-                auto x_sq = x_vec * x_vec;
-                auto x_cube = x_vec * x_vec * x_vec;
-                auto inner_vec =
-                    kBetaVec * (x_vec + kKappaVec * x_cube);
-                auto tanh_inner_vec = inner_vec.tanh();
-
-                auto left_vec = kPointFiveVec * x_vec;
-                auto right_vec = kOneVec + tanh_inner_vec;
-
-                auto left_derivative_vec = kPointFiveVec * right_vec;
-
-                auto tanh_derivative_vec =
-                    kOneVec - tanh_inner_vec * tanh_inner_vec;
-                auto inner_derivative_vec =
-                    kBetaVec * (kOneVec + kThreeVec * kKappaVec * x_sq);
-                auto right_derivative_vec =
-                    left_vec * tanh_derivative_vec * inner_derivative_vec;
-
-                return dy_vec * (left_derivative_vec + right_derivative_vec);
-              });
-        });
-  } else {
-    AT_DISPATCH_FLOATING_TYPES_AND(
-        ScalarType::BFloat16, it.dtype(), "GeluBackwardKernelImpl", [&]() {
-          using Vec = vec::Vectorized<scalar_t>;
-          const Vec kAlphaVec(scalar_t(M_SQRT1_2));
-          const Vec kBetaVec(scalar_t(M_2_SQRTPI * M_SQRT1_2 * 0.5));
-          const Vec kOneVec(scalar_t(1));
-          const Vec kPointFiveVec(scalar_t(0.5));
-          const Vec kMinusPointFiveVec(scalar_t(-0.5));
-          cpu_kernel_vec(
-              it,
-              [](scalar_t dy, scalar_t x) {
-                const scalar_t kAlpha = scalar_t(M_SQRT1_2);
-                const scalar_t kBeta = M_2_SQRTPI * M_SQRT1_2 * scalar_t(0.5);
-                const scalar_t cdf =
-                    scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
-                const scalar_t pdf = kBeta * std::exp(x * x * scalar_t(-0.5));
-                return dy * (cdf + x * pdf);
-              },
-              [&](Vec dy_vec, Vec x_vec) {
-                const Vec cdf_vec =
-                    kPointFiveVec * (kOneVec + (x_vec * kAlphaVec).erf());
-                const Vec pdf_vec =
-                    kBetaVec * (x_vec * x_vec * kMinusPointFiveVec).exp();
-                return dy_vec * (cdf_vec + x_vec * pdf_vec);
-              });
-        });
-  }
+void GeluBackwardKernelImpl(TensorIteratorBase& it) {
+  AT_DISPATCH_FLOATING_TYPES_AND(
+      ScalarType::BFloat16, it.dtype(), "GeluBackwardKernelImpl", [&]() {
+        using Vec = vec::Vectorized<scalar_t>;
+        const Vec kAlphaVec(scalar_t(M_SQRT1_2));
+        const Vec kBetaVec(scalar_t(M_2_SQRTPI * M_SQRT1_2 * 0.5));
+        const Vec kOneVec(scalar_t(1));
+        const Vec kPointFiveVec(scalar_t(0.5));
+        const Vec kMinusPointFiveVec(scalar_t(-0.5));
+        cpu_kernel_vec(
+            it,
+            [](scalar_t dy, scalar_t x) {
+              const scalar_t kAlpha = scalar_t(M_SQRT1_2);
+              const scalar_t kBeta = M_2_SQRTPI * M_SQRT1_2 * scalar_t(0.5);
+              const scalar_t cdf =
+                  scalar_t(0.5) * (scalar_t(1) + std::erf(x * kAlpha));
+              const scalar_t pdf = kBeta * std::exp(x * x * scalar_t(-0.5));
+              return dy * (cdf + x * pdf);
+            },
+            [&](Vec dy_vec, Vec x_vec) {
+              const Vec cdf_vec =
+                  kPointFiveVec * (kOneVec + (x_vec * kAlphaVec).erf());
+              const Vec pdf_vec =
+                  kBetaVec * (x_vec * x_vec * kMinusPointFiveVec).exp();
+              return dy_vec * (cdf_vec + x_vec * pdf_vec);
+            });
+      });
 }
 
 void hardsigmoid_kernel(TensorIteratorBase& iter) {

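Reading the gradients off the kernels above: the restored exact path uses the constant kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5 = 1/sqrt(2*pi), i.e. the standard normal density, so it computes

$$\frac{d}{dx}\,\mathrm{GELU}(x) = \Phi(x) + x\,\varphi(x), \qquad \varphi(x) = \tfrac{1}{\sqrt{2\pi}}\,e^{-x^{2}/2},$$

while the removed tanh path, with inner term $u = \sqrt{2/\pi}\,(x + 0.044715\,x^{3})$, computed

$$\frac{d}{dx}\,\mathrm{GELU}_{\tanh}(x) = \tfrac{1}{2}\big(1 + \tanh u\big) + \tfrac{x}{2}\,\big(1 - \tanh^{2}u\big)\,\sqrt{2/\pi}\,\big(1 + 3\cdot 0.044715\,x^{2}\big).$$

These are transcriptions of the code (left_derivative and right_derivative in the deleted branch), not new results.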
@@ -153,15 +153,15 @@ std::tuple<Tensor, Tensor> prelu_backward_cuda(const Tensor& grad_out_, const Te
 }
 
 TORCH_IMPL_FUNC(gelu_out_cuda) (
-  const Tensor& /*self*/, int64_t approximate, const Tensor& /*result*/
+  const Tensor& /*self*/, const Tensor& /*result*/
 ) {
-  GeluCUDAKernelImpl(*this, approximate);
+  GeluCUDAKernelImpl(*this);
 }
 
 TORCH_IMPL_FUNC(gelu_backward_out_cuda) (
-  const Tensor& /*grad*/, const Tensor& /*self*/, int64_t approximate, const Tensor& /*grad_input*/
+  const Tensor& /*grad*/, const Tensor& /*self*/, const Tensor& /*grad_input*/
 ) {
-  GeluBackwardCUDAKernelImpl(*this, approximate);
+  GeluBackwardCUDAKernelImpl(*this);
 }
 
 }} // namespace at::native

@@ -392,71 +392,30 @@ void elu_backward_kernel(TensorIteratorBase& iter, const Scalar& alpha, const Sc
   });
 }
 
-void GeluCUDAKernelImpl(TensorIteratorBase& it, int64_t approximate) {
-  if (approximate == at::Gelu::Tanh) {
-    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluCUDAKernelImpl", [&]() {
-      using T_ACC = acc_type<scalar_t, true>;
-      gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
-        constexpr T_ACC kBeta = M_SQRT2 * M_2_SQRTPI * T_ACC(0.5);
-        constexpr T_ACC kKappa = 0.044715;
-        auto x_cube = static_cast<T_ACC>(x) * static_cast<T_ACC>(x) * static_cast<T_ACC>(x);
-        auto inner = kBeta * (static_cast<T_ACC>(x) + kKappa * x_cube);
-        return T_ACC(0.5) * static_cast<T_ACC>(x) * (T_ACC(1) + c10::cuda::compat::tanh(inner));
-      });
-    });
-  } else {
-    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluCUDAKernelImpl", [&]() {
-      using T_ACC = acc_type<scalar_t, true>;
-      gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
-        constexpr T_ACC kAlpha = M_SQRT1_2;
-        return static_cast<T_ACC>(x) * T_ACC(0.5) * (T_ACC(1) + ::erf(static_cast<T_ACC>(x) * kAlpha));
-      });
-    });
-  }
-}
+void GeluCUDAKernelImpl(TensorIteratorBase& it) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluCUDAKernelImpl", [&]() {
+    using T_ACC = acc_type<scalar_t, true>;
+    gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
+      return static_cast<T_ACC>(x) *
+          c10::cuda::compat::normcdf(static_cast<T_ACC>(x));
+    });
+  });
+}
 
-void GeluBackwardCUDAKernelImpl(TensorIteratorBase& it, int64_t approximate) {
-  if (approximate == at::Gelu::Tanh) {
-    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
-        it.dtype(), "GeluBackwardCUDAKernelImpl", [&]() {
-      using T_ACC = acc_type<scalar_t, true>;
-      gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t {
-        constexpr T_ACC kBeta = M_SQRT2 * M_2_SQRTPI * T_ACC(0.5);
-        constexpr T_ACC kKappa = 0.044715;
-        auto x_sq = static_cast<T_ACC>(x) * static_cast<T_ACC>(x);
-        auto x_cube = x_sq * static_cast<T_ACC>(x);
-        auto inner = kBeta * (static_cast<T_ACC>(x) + kKappa * x_cube);
-        auto tanh_inner = c10::cuda::compat::tanh(inner);
-
-        auto left = T_ACC(0.5) * static_cast<T_ACC>(x);
-        auto right = T_ACC(1) + tanh_inner;
-
-        auto left_derivative = 0.5 * right;
-
-        auto tanh_derivative = T_ACC(1) - tanh_inner * tanh_inner;
-        auto inner_derivative = kBeta * (T_ACC(1) + T_ACC(3) * kKappa * x_sq);
-        auto right_derivative = left * tanh_derivative * inner_derivative;
-
-        return static_cast<T_ACC>(dy) * (left_derivative + right_derivative);
-      });
-    });
-  } else {
-    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
-        it.dtype(), "GeluBackwardCUDAKernelImpl", [&]() {
-      using T_ACC = acc_type<scalar_t, true>;
-      gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t {
-        constexpr T_ACC kBeta = M_2_SQRTPI * M_SQRT1_2 * T_ACC(0.5);
-        constexpr T_ACC kAlpha = M_SQRT1_2;
-        const T_ACC cdf =
-            T_ACC(0.5) * (T_ACC(1) + ::erf(static_cast<T_ACC>(x) * kAlpha));
-        const T_ACC pdf =
-            c10::cuda::compat::exp(
-                T_ACC(-0.5) * static_cast<T_ACC>(x) * static_cast<T_ACC>(x)) *
-            kBeta;
-        return static_cast<T_ACC>(dy) * (cdf + static_cast<T_ACC>(x) * pdf);
-      });
-    });
-  }
-}
+void GeluBackwardCUDAKernelImpl(TensorIteratorBase& it) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
+      it.dtype(), "GeluBackwardCUDAKernelImpl", [&]() {
+    using T_ACC = acc_type<scalar_t, true>;
+    gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t {
+      constexpr T_ACC kBeta = M_2_SQRTPI * M_SQRT1_2 * T_ACC(0.5);
+      const T_ACC cdf = c10::cuda::compat::normcdf(static_cast<T_ACC>(x));
+      const T_ACC pdf =
+          c10::cuda::compat::exp(
+              T_ACC(-0.5) * static_cast<T_ACC>(x) * static_cast<T_ACC>(x)) *
+          kBeta;
+      return static_cast<T_ACC>(dy) * (cdf + static_cast<T_ACC>(x) * pdf);
+    });
+  });
+}
 
 namespace {

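Note that the restored CUDA forward skips erf entirely and evaluates the exact form through the normal CDF, which keeps it analytically identical to the CPU path:

$$\mathrm{GELU}(x) = x\,\Phi(x), \qquad \Phi(x) = \mathtt{normcdf}(x) = \tfrac{1}{2}\big(1 + \operatorname{erf}(x/\sqrt{2})\big).$$
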
@@ -1,5 +1,4 @@
 
-#include <ATen/native/Activation.h>
 #include <cstdint>
 
 namespace at {

@@ -25,7 +24,7 @@ void launch_prelu_cuda_backward_kernel_multi_weights(
     const TensorBase &input, const TensorBase &weight, const TensorBase &grad_out,
     const TensorBase &input_grad, const TensorBase &weight_grad_collector);
 
-void GeluCUDAKernelImpl(TensorIteratorBase& it, int64_t approximate);
-void GeluBackwardCUDAKernelImpl(TensorIteratorBase& it, int64_t approximate);
+void GeluCUDAKernelImpl(TensorIteratorBase& it);
+void GeluBackwardCUDAKernelImpl(TensorIteratorBase& it);
 
 }} // namespace at::native

@@ -1,18 +1,17 @@
 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h>
 #include <ATen/Config.h>
-#include <ATen/native/Activation.h>
 
 
 #if !AT_MKLDNN_ENABLED()
 
 namespace at { namespace native {
 
-Tensor mkldnn_gelu(const Tensor& input, int64_t approximate) {
+Tensor mkldnn_gelu(const Tensor& input) {
   TORCH_CHECK(false, "mkldnn_gelu: ATen not compiled with MKLDNN support");
 }
 
-Tensor mkldnn_gelu_backward(const Tensor& grad_output, const Tensor& input, int64_t approximate) {
+Tensor mkldnn_gelu_backward(const Tensor& grad_output, const Tensor& input) {
   TORCH_CHECK(false, "mkldnn_gelu_backward: ATen not compiled with MKLDNN support");
 }

@@ -25,13 +24,11 @@ Tensor mkldnn_gelu_backward(const Tensor& grad_output, const Tensor& input, int6
 
 namespace at { namespace native {
 
-Tensor mkldnn_gelu(const Tensor& input, int64_t approximate) {
+Tensor mkldnn_gelu(const Tensor& input) {
   if (input.scalar_type() == ScalarType::BFloat16) {
     TORCH_CHECK(mkldnn_bf16_device_check(),
         "mkldnn_gelu: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq");
   }
-  TORCH_CHECK(approximate == at::Gelu::None,
-      "mkldnn_gelu: fast, approximate gelu is not supported");
   const ideep::tensor& x = itensor_from_tensor(input);
   ideep::tensor y;
   ideep::eltwise_forward::compute(

@@ -40,9 +37,7 @@ Tensor mkldnn_gelu(const Tensor& input, int64_t approximate) {
       input.options().device_opt());
 }
 
-Tensor mkldnn_gelu_backward(const Tensor& grad_output, const Tensor& input, int64_t approximate) {
-  TORCH_CHECK(approximate == at::Gelu::None,
-      "mkldnn_gelu_backward: fast, approximate gelu is not supported");
+Tensor mkldnn_gelu_backward(const Tensor& grad_output, const Tensor& input) {
   const ideep::tensor& x = itensor_from_tensor(input);
   ideep::tensor grady = itensor_from_tensor(grad_output);
   ideep::tensor gradx;

@@ -3724,7 +3724,7 @@
     CPU: prelu_backward_cpu
     CUDA: prelu_backward_cuda
 
-- func: gelu.out(Tensor self, int approximate=0, *, Tensor(a!) out) -> Tensor(a!)
+- func: gelu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
   structured: True
   structured_inherits: TensorIteratorBase
   device_check: NoCheck   # TensorIterator

@@ -3733,7 +3733,7 @@
     CPU: gelu_out_cpu
     CUDA: gelu_out_cuda
 
-- func: gelu(Tensor self, int approximate=0) -> Tensor
+- func: gelu(Tensor self) -> Tensor
   structured_delegate: gelu.out
   device_check: NoCheck   # TensorIterator
   python_module: nn

@@ -3741,7 +3741,7 @@
     MkldnnCPU: mkldnn_gelu
     QuantizedCPU: gelu_quantized_cpu
 
-- func: gelu_backward.grad_input(Tensor grad_output, Tensor self, int approximate=0, *, Tensor(a!) grad_input) -> Tensor(a!)
+- func: gelu_backward.grad_input(Tensor grad, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!)
   structured: True
   structured_inherits: TensorIteratorBase
   python_module: nn

@@ -3749,7 +3749,7 @@
     CPU: gelu_backward_out_cpu
     CUDA: gelu_backward_out_cuda
 
-- func: gelu_backward(Tensor grad_output, Tensor self, int approximate=0) -> Tensor
+- func: gelu_backward(Tensor grad, Tensor self) -> Tensor
   structured_delegate: gelu_backward.grad_input
   python_module: nn
   dispatch:

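A minimal call-site sketch of what the restored schema means for the generated C++ API. The tensors here are hypothetical; this assumes the usual one-to-one mapping from native_functions.yaml entries to at:: functions, which holds for gelu:

#include <ATen/ATen.h>

at::Tensor x = at::randn({4});
at::Tensor grad = at::ones_like(x);
at::Tensor y = at::gelu(x);                  // exact erf-based GELU, no selector argument
at::Tensor gx = at::gelu_backward(grad, x);  // dGELU/dx contracted with an incoming gradient
// The reverted schema additionally accepted an integer, e.g. at::gelu(x, /*approximate=*/1)
// for the tanh approximation; that overload no longer exists after this commit.
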
@@ -1,7 +1,6 @@
 #include <ATen/ATen.h>
 #include <ATen/Dispatch.h>
 #include <ATen/Parallel.h>
-#include <ATen/native/Activation.h>
 #include <ATen/native/SortingUtils.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/UpSample.h>

@@ -616,7 +615,7 @@ static void leaky_qrelu_out_kernel(Tensor& out, const Tensor& qx,
   });
 }
 
-void qgelu_kernel(const Tensor& qx, Tensor& qy, int64_t approximate) {
+void qgelu_kernel(const Tensor& qx, Tensor& qy) {
   int64_t zero_point = qx.q_zero_point();
   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
   float scale = qx.q_scale();

|
@ -627,83 +626,40 @@ void qgelu_kernel(const Tensor& qx, Tensor& qy, int64_t approximate) {
|
||||||
float output_scale = scale;
|
float output_scale = scale;
|
||||||
float inv_output_scale = 1.0 / output_scale;
|
float inv_output_scale = 1.0 / output_scale;
|
||||||
const auto kAlphaVec = Vectorized<float>(M_SQRT1_2);
|
const auto kAlphaVec = Vectorized<float>(M_SQRT1_2);
|
||||||
const auto kBetaVec = Vectorized<float>(M_SQRT2 * M_2_SQRTPI * 0.5);
|
|
||||||
const auto kKappaVec = Vectorized<float>(0.044715);
|
|
||||||
const auto kOneVec = Vectorized<float>(1);
|
const auto kOneVec = Vectorized<float>(1);
|
||||||
const auto kPointFiveVec = Vectorized<float>(0.5);
|
const auto kPointFiveVec = Vectorized<float>(0.5);
|
||||||
|
|
||||||
if (approximate == at::Gelu::Tanh) {
|
AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qgelu", [&]() {
|
||||||
AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qgelu", [&]() {
|
qy = at::_empty_affine_quantized(
|
||||||
qy = at::_empty_affine_quantized(
|
qx.sizes(),
|
||||||
qx.sizes(),
|
// NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
|
||||||
// NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
|
at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
|
||||||
at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
|
output_scale,
|
||||||
output_scale,
|
output_zero_point,
|
||||||
output_zero_point,
|
c10::nullopt);
|
||||||
c10::nullopt);
|
auto iter = TensorIterator::unary_op(qy, qx);
|
||||||
auto iter = TensorIterator::unary_op(qy, qx);
|
|
||||||
|
|
||||||
using Vec = Vectorized<scalar_t>;
|
using Vec = Vectorized<scalar_t>;
|
||||||
cpu_kernel_vec(
|
cpu_kernel_vec(
|
||||||
iter,
|
iter,
|
||||||
[&](scalar_t value_qx) -> scalar_t {
|
[&](scalar_t value_qx) -> scalar_t {
|
||||||
const auto value_dx =
|
const auto value_dx =
|
||||||
at::native::dequantize_val(scale, zero_point, value_qx);
|
at::native::dequantize_val(scale, zero_point, value_qx);
|
||||||
|
const auto value_dy =
|
||||||
const auto kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
|
value_dx * 0.5 * (1 + std::erf(value_dx * M_SQRT1_2));
|
||||||
const auto kKappa = 0.044715;
|
return at::native::quantize_val<scalar_t>(
|
||||||
const auto x_cube = value_dx * value_dx * value_dx;
|
output_scale, output_zero_point, value_dy);
|
||||||
const auto inner = kBeta * (value_dx + kKappa * x_cube);
|
},
|
||||||
const auto value_dy = 0.5 * value_dx * (1.0 + std::tanh(inner));
|
[&](Vec value_qx) -> Vec {
|
||||||
|
auto value_dx = value_qx.dequantize(
|
||||||
return at::native::quantize_val<scalar_t>(
|
scale_vec, zero_point_vec, scale_neg_zp_premul_vec);
|
||||||
output_scale, output_zero_point, value_dy);
|
for (auto & value : value_dx) {
|
||||||
},
|
value = value * kPointFiveVec * (kOneVec + (value * kAlphaVec).erf());
|
||||||
[&](Vec value_qx) -> Vec {
|
}
|
||||||
auto value_dx = value_qx.dequantize(
|
return Vec::quantize(
|
||||||
scale_vec, zero_point_vec, scale_neg_zp_premul_vec);
|
value_dx, output_scale, output_zero_point, inv_output_scale);
|
||||||
for (auto & value : value_dx) {
|
});
|
||||||
auto value_cube = value * value * value;
|
});
|
||||||
auto inner = kBetaVec * (value + kKappaVec * value_cube);
|
|
||||||
value = kPointFiveVec * value * (kOneVec + inner.tanh());
|
|
||||||
}
|
|
||||||
return Vec::quantize(
|
|
||||||
value_dx, output_scale, output_zero_point, inv_output_scale);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qgelu", [&]() {
|
|
||||||
qy = at::_empty_affine_quantized(
|
|
||||||
qx.sizes(),
|
|
||||||
// NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
|
|
||||||
at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
|
|
||||||
output_scale,
|
|
||||||
output_zero_point,
|
|
||||||
c10::nullopt);
|
|
||||||
auto iter = TensorIterator::unary_op(qy, qx);
|
|
||||||
|
|
||||||
using Vec = Vectorized<scalar_t>;
|
|
||||||
cpu_kernel_vec(
|
|
||||||
iter,
|
|
||||||
[&](scalar_t value_qx) -> scalar_t {
|
|
||||||
const auto value_dx =
|
|
||||||
at::native::dequantize_val(scale, zero_point, value_qx);
|
|
||||||
const auto value_dy =
|
|
||||||
value_dx * 0.5 * (1 + std::erf(value_dx * M_SQRT1_2));
|
|
||||||
return at::native::quantize_val<scalar_t>(
|
|
||||||
output_scale, output_zero_point, value_dy);
|
|
||||||
},
|
|
||||||
[&](Vec value_qx) -> Vec {
|
|
||||||
auto value_dx = value_qx.dequantize(
|
|
||||||
scale_vec, zero_point_vec, scale_neg_zp_premul_vec);
|
|
||||||
for (auto & value : value_dx) {
|
|
||||||
value = value * kPointFiveVec * (kOneVec + (value * kAlphaVec).erf());
|
|
||||||
}
|
|
||||||
return Vec::quantize(
|
|
||||||
value_dx, output_scale, output_zero_point, inv_output_scale);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
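Both the removed and the restored qgelu kernels follow the same dequantize, float GELU, requantize pattern, reusing the input's quantization parameters for the output. Schematically, with scale and zero point read off qX as in the code above:

$$qY = \operatorname{quantize}\big(\mathrm{GELU}(\operatorname{dequantize}(qX)),\ \text{output\_scale},\ \text{output\_zero\_point}\big).$$
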
@@ -15,9 +15,9 @@ namespace native {
 
 DEFINE_DISPATCH(qgelu_stub);
 
-Tensor gelu_quantized_cpu(const Tensor& qx, int64_t approximate) {
+Tensor gelu_quantized_cpu(const Tensor& qx) {
   Tensor qy;
-  qgelu_stub(qx.device().type(), qx, qy, approximate);
+  qgelu_stub(qx.device().type(), qx, qy);
   return qy;
 }
 }} // namespace at::native

@@ -8,7 +8,7 @@ namespace native {
 using qrelu_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/);
 using qrelu_leaky_fn = void (*)(Tensor& /*out*/, const Tensor& /*qx*/,
                                 const Scalar& /*negval_*/);
-using qgelu_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/, int64_t /* approximate */);
+using qgelu_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/);
 using qsigmoid_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/, double output_scale, int64_t output_zero_point);
 using qhardsigmoid_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/);
 using qclamp_fn = void (*)(

@@ -973,17 +973,10 @@ TEST_F(FunctionalTest, GLU) {
 }
 
 TEST_F(FunctionalTest, GELU) {
+  GELU model;
   const auto x = torch::linspace(-3.0, 3.0, 100);
   const auto y_exp = x * 0.5 * (1.0 + torch::erf(x / std::sqrt(2.0)));
-  const auto y = F::gelu(x, F::GELUFuncOptions().approximate(torch::kNone));
-  ASSERT_TRUE(torch::allclose(y, y_exp, 1.4e-06, 1e-05));
-}
-
-TEST_F(FunctionalTest, TanhGELU) {
-  const auto x = torch::linspace(-3.0, 3.0, 100);
-  const auto inner = std::sqrt(2 / M_PI) * (x + 0.044715 * x.pow(3.0));
-  const auto y_exp = 0.5 * x * (1.0 + inner.tanh());
-  const auto y = F::gelu(x, F::GELUFuncOptions().approximate(torch::kTanh));
+  const auto y = F::gelu(x);
   ASSERT_TRUE(torch::allclose(y, y_exp, 1.4e-06, 1e-05));
 }

@@ -2854,23 +2854,13 @@ TEST_F(ModulesTest, GLU) {
 }
 
 TEST_F(ModulesTest, GELU) {
-  GELU model(GELUOptions().approximate(torch::kNone));
+  GELU model;
   const auto x = torch::linspace(-3.0, 3.0, 100);
   const auto y_exp = x * 0.5 * (1.0 + torch::erf(x / std::sqrt(2.0)));
   const auto y = model(x);
   ASSERT_TRUE(torch::allclose(y, y_exp, 1.4e-06, 1e-05));
 }
 
-TEST_F(ModulesTest, TanhGELU) {
-  GELU model(GELUOptions().approximate(torch::kTanh));
-  const auto x = torch::linspace(-3.0, 3.0, 100);
-  const auto inner = std::sqrt(2 / M_PI) * (x + 0.044715 * x.pow(3.0));
-  const auto y_exp = 0.5 * x * (1.0 + inner.tanh());
-  const auto y = model(x);
-  ASSERT_TRUE(torch::allclose(y, y_exp, 1.4e-06, 1e-05));
-}
-
-// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 TEST_F(ModulesTest, Mish) {
   Mish model;
   auto x = torch::randn(100) * 10;

@@ -50,8 +50,12 @@ ALLOW_LIST = [
     ("aten::adaptive_avg_pool3d_backward", datetime.date(9999, 1, 1)),
     ("aten::_embedding_bag_dense_backward", datetime.date(9999, 1, 1)),
    ("aten::randperm", datetime.date(9999, 1, 1)),
-    ("aten::gelu", datetime.date(2022, 3, 1)),
-    ("aten::gelu_backward", datetime.date(2022, 3, 1)),
+    ("aten::_conv_depthwise2d_backward", datetime.date(2022, 1, 31)),
+    ("aten::conv_depthwise3d_backward", datetime.date(2022, 1, 31)),
+    ("aten::cudnn_convolution.deprecated", datetime.date(2022, 1, 31)),
+    ("aten::cudnn_convolution.deprecated2", datetime.date(2022, 1, 31)),
+    ("aten::cudnn_convolution_transpose.deprecated", datetime.date(2022, 1, 31)),
+    ("aten::cudnn_convolution_transpose.deprecated2", datetime.date(2022, 1, 31)),
     ("aten::cudnn_convolution_backward", datetime.date(2022, 1, 31)),
     ("aten::cudnn_convolution_backward_input", datetime.date(2022, 1, 31)),
     ("aten::cudnn_convolution_backward_weight", datetime.date(2022, 1, 31)),

@@ -447,7 +447,7 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
             %0 : int[] = prim::Constant[value=[2, 2, 1]]()
             %1 : int = prim::Constant[value=0]()
             %2 : Tensor = aten::t(%b)
-            %3 : Tensor = aten::relu(%2)
+            %3 : Tensor = aten::gelu(%2)
             %4 : (Tensor, Tensor, Tensor[]) = prim::TupleConstruct(%b, %3, %2)
             return (%4)
             """

@@ -471,7 +471,7 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
             %1 : int = prim::Constant[value=0]()
             %d : Tensor = aten::t(%c)
             %2 : Tensor = aten::t(%b)
-            %3 : Tensor = aten::relu(%2)
+            %3 : Tensor = aten::gelu(%2)
             %4 : (Tensor, Tensor, Tensor[]) = prim::TupleConstruct(%3, %2, %d, %b, %c, %b)
             return (%4)
             """

@@ -136,7 +136,7 @@ class TestExportAsContribOps(unittest.TestCase):
         class M(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.gelu = torch.nn.GELU(approximate='none')
+                self.gelu = torch.nn.GELU()
 
             def forward(self, x):
                 res = []

@@ -149,7 +149,7 @@ class TestExportAsContribOps(unittest.TestCase):
                 res.append(x[0])
             return torch.stack(res), torch.stack(res2)
 
-        def symbolic_custom_gelu(g, input, approximate):
+        def symbolic_custom_gelu(g, input):
             return g.op("com.microsoft::Gelu", input).setType(input.type())
 
         from torch.onnx import register_custom_op_symbolic

@@ -157,7 +157,7 @@ class TestExportAsContribOps(unittest.TestCase):
 
         x = torch.randn(3, 3, 4, requires_grad=True)
         model = torch.jit.script(M())
-        run_model_test(self, model, input=(x,))
+        run_model_test(self, model, input=(x, ))
 
 if __name__ == "__main__":
     unittest.main()

@@ -2383,17 +2383,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
     def test_gelu(self):
         class GeluModel(torch.nn.Module):
             def forward(self, x):
-                return torch.nn.functional.gelu(x, 'none')
+                return torch.nn.functional.gelu(x)
 
-        model = GeluModel()
-        inputs = torch.randn(2, 4, 5, 6, requires_grad=True)
-        self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE)
-
-    @skipIfUnsupportedMinOpsetVersion(9)
-    def test_tanh_gelu(self):
-        class GeluModel(torch.nn.Module):
-            def forward(self, x):
-                return torch.nn.functional.gelu(x, 'tanh')
-
         model = GeluModel()
         inputs = torch.randn(2, 4, 5, 6, requires_grad=True)

@@ -6225,16 +6225,7 @@ class TestONNXRuntime(unittest.TestCase):
     def test_gelu(self):
         class GeluModel(torch.nn.Module):
             def forward(self, x):
-                return torch.nn.functional.gelu(x, 'none')
+                return torch.nn.functional.gelu(x)
 
-        x = torch.randn(2, 4, 5, 6, requires_grad=True)
-        self.run_test(GeluModel(), x)
-
-    @skipIfUnsupportedMinOpsetVersion(9)
-    def test_tanh_gelu(self):
-        class GeluModel(torch.nn.Module):
-            def forward(self, x):
-                return torch.nn.functional.gelu(x, 'tanh')
-
         x = torch.randn(2, 4, 5, 6, requires_grad=True)
         self.run_test(GeluModel(), x)

@@ -804,11 +804,11 @@ class TestUtilityFuns_opset9(_BaseTestCase):
     def test_custom_opsets_gelu(self):
         self.addCleanup(unregister_custom_op_symbolic, "::gelu", 1)
 
-        def gelu(g, self, approximate):
+        def gelu(g, self):
             return g.op("com.microsoft::Gelu", self).setType(self.type())
 
         register_custom_op_symbolic("::gelu", gelu, 1)
-        model = torch.nn.GELU(approximate='none')
+        model = torch.nn.GELU()
         x = torch.randn(3, 3)
         f = io.BytesIO()
         torch.onnx.export(model, (x, ), f,

@@ -824,11 +824,11 @@ class TestUtilityFuns_opset9(_BaseTestCase):
     def test_register_aten_custom_op_symbolic(self):
         self.addCleanup(unregister_custom_op_symbolic, "aten::gelu", 1)
 
-        def gelu(g, self, approximate):
+        def gelu(g, self):
             return g.op("com.microsoft::Gelu", self).setType(self.type())
 
         register_custom_op_symbolic("aten::gelu", gelu, 1)
-        model = torch.nn.GELU(approximate='none')
+        model = torch.nn.GELU()
         x = torch.randn(3, 3)
         f = io.BytesIO()
         torch.onnx.export(model, (x, ), f, opset_version=self.opset_version)

@@ -440,9 +440,8 @@ class TestQuantizedOps(TestCase):
         shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4))
         dtypes = (torch.quint8, torch.qint8)
         memory_formats = (torch.channels_last, torch.contiguous_format)
-        approximation = ['none', 'tanh']
-        test_cases = itertools.product(shapes, dtypes, memory_formats, approximation)
-        for shape, dtype, memory_format, approximate in test_cases:
+        test_cases = itertools.product(shapes, dtypes, memory_formats)
+        for shape, dtype, memory_format in test_cases:
             if memory_format == torch.channels_last and len(shape) != 4:
                 continue
             X, scale, zero_point, torch_type = \

@@ -454,7 +453,7 @@ class TestQuantizedOps(TestCase):
             dqX = qX.dequantize()
 
             op = torch.nn.functional.gelu
-            dqY = op(dqX, approximate)
+            dqY = op(dqX)
             qY = torch.quantize_per_tensor(dqY, scale=scale, zero_point=zero_point,
                                            dtype=torch_type)
             qY_hat = op(qX)

@@ -3516,7 +3516,6 @@ class TestFunctionalTracing(JitTestCase):
         "adaptive_max_pool1d_with_indices": ARG_TYPE_MISMATCH,
         "fractional_max_pool2d_with_indices": ARG_TYPE_MISMATCH,
         "fractional_max_pool3d_with_indices": ARG_TYPE_MISMATCH,
-        "gelu": CONTROL_FLOW,
         "hardshrink": ARG_TYPE_MISMATCH,
         "layer_norm": ARG_TYPE_MISMATCH,
         "lp_pool1d": ARG_TYPE_MISMATCH,

|
@ -1260,37 +1260,6 @@ class TestTEFuser(JitTestCase):
|
||||||
" ".join(["Failed:", str(dtype), 'isnan', device])
|
" ".join(["Failed:", str(dtype), 'isnan', device])
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_gelu(self):
|
|
||||||
def apply(fn):
|
|
||||||
return lambda x, approximate: fn(x, approximate)
|
|
||||||
|
|
||||||
unary_ops = [
|
|
||||||
F.gelu,
|
|
||||||
]
|
|
||||||
sizes = [(1,), (2,), (4, 4)]
|
|
||||||
for dtype, op, device, size in product(self.dtypes, unary_ops, self.devices, sizes):
|
|
||||||
# TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed
|
|
||||||
if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
x = self.data_for(dtype, device, size=size)
|
|
||||||
cond = self.data_for(torch.bool, device)
|
|
||||||
fn = apply(op)
|
|
||||||
ref = fn(x, cond)
|
|
||||||
except Exception:
|
|
||||||
# If eager mode doesn't support a dtype/op/device combo,
|
|
||||||
# neither does the fuser. Catch everything to avoid needing to
|
|
||||||
# guess what errors might be thrown by eager.
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
t = torch.jit.trace(fn, (x, cond))
|
|
||||||
torch.testing.assert_close(ref, t(x, cond))
|
|
||||||
self.assertAllFused(t.graph_for(x, cond))
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(
|
|
||||||
" ".join(["Failed:", str(dtype), op.__name__, device, str(size)])
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_unary_ops(self):
|
def test_unary_ops(self):
|
||||||
def apply(fn):
|
def apply(fn):
|
||||||
return lambda x: fn(x)
|
return lambda x: fn(x)
|
||||||
|
|
@@ -1325,6 +1294,7 @@ class TestTEFuser(JitTestCase):
             F.softplus,
             torch.sqrt,
             torch.rsqrt,
+            F.gelu,
             torch.abs,
             torch.ceil,
             torch.floor,

@@ -2237,6 +2207,7 @@ works_list = [
     'mul',
     'ne',
     'neg',
+    'nn.functional.gelu',
     'nn.functional.hardshrink',
     'nn.functional.hardsigmoid',
     'nn.functional.hardswish',

||||||
|
|
@ -9153,25 +9153,16 @@ class TestNN(NNTestCase):
|
||||||
def _gelu_ref(X):
|
def _gelu_ref(X):
|
||||||
return X * stats.norm.cdf(X)
|
return X * stats.norm.cdf(X)
|
||||||
|
|
||||||
def _tanh_gelu_ref(X):
|
for d in devices:
|
||||||
M_SQRT_2_PI = math.sqrt(2 / math.pi)
|
if contiguous:
|
||||||
Z = M_SQRT_2_PI * (X + 0.044715 * np.power(X, 3.0))
|
X = torch.rand(n, m, dtype=dtype, requires_grad=True, device=d)
|
||||||
return 0.5 * X * (1.0 + np.tanh(Z))
|
else:
|
||||||
|
X = torch.rand(n, m, dtype=dtype, requires_grad=True, device=d)[:, ::2]
|
||||||
for approximate in ['none', 'tanh']:
|
res = F.gelu(X)
|
||||||
for d in devices:
|
ref = _gelu_ref(X.to(numpy_dtype).cpu().detach().numpy())
|
||||||
if contiguous:
|
self.assertEqual(res, ref, rtol=rtol, atol=atol, exact_dtype=False)
|
||||||
X = torch.rand(n, m, dtype=dtype, requires_grad=True, device=d)
|
if dtype == torch.float64:
|
||||||
else:
|
gradcheck(F.gelu, [X], eps=1e-4)
|
||||||
X = torch.rand(n, m, dtype=dtype, requires_grad=True, device=d)[:, ::2]
|
|
||||||
res = F.gelu(X, approximate)
|
|
||||||
if approximate == 'tanh':
|
|
||||||
ref = _tanh_gelu_ref(X.to(numpy_dtype).cpu().detach().numpy())
|
|
||||||
else:
|
|
||||||
ref = _gelu_ref(X.to(numpy_dtype).cpu().detach().numpy())
|
|
||||||
self.assertEqual(res, ref, rtol=rtol, atol=atol, exact_dtype=False)
|
|
||||||
if dtype == torch.float64:
|
|
||||||
gradcheck(F.gelu, [X, approximate], eps=1e-4)
|
|
||||||
|
|
||||||
for n in range(1, 10):
|
for n in range(1, 10):
|
||||||
for m in range(1, 10):
|
for m in range(1, 10):
|
||||||
|
|
|
||||||
|
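The surviving SciPy reference agrees with the C++ kernels because stats.norm.cdf is exactly the Phi in the erf form,

$$X\,\Phi(X) = \tfrac{X}{2}\Big(1 + \operatorname{erf}\big(X/\sqrt{2}\big)\Big),$$

and the deleted _tanh_gelu_ref used the same $\sqrt{2/\pi}\approx 0.7978845608$ and $0.044715$ constants as the removed kernels.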
|
@ -1806,14 +1806,10 @@
|
||||||
- name: celu_(Tensor(a!) self, Scalar alpha=1.0) -> Tensor(a!)
|
- name: celu_(Tensor(a!) self, Scalar alpha=1.0) -> Tensor(a!)
|
||||||
self: elu_backward(grad, alpha, 1, 1.0/alpha.toFloat(), /* is_result */ true, result)
|
self: elu_backward(grad, alpha, 1, 1.0/alpha.toFloat(), /* is_result */ true, result)
|
||||||
|
|
||||||
- name: gelu(Tensor self, int approximate=0) -> Tensor
|
- name: gelu(Tensor self) -> Tensor
|
||||||
self: gelu_backward(grad, self, approximate)
|
self: "GradMode::is_enabled() ? infinitely_differentiable_gelu_backward(grad, self) : gelu_backward(grad, self)"
|
||||||
result: auto_element_wise
|
result: auto_element_wise
|
||||||
|
|
||||||
- name: gelu_backward(Tensor grad_output, Tensor self, int approximate=0) -> Tensor
|
|
||||||
grad_output: gelu_backward(grad, self, approximate)
|
|
||||||
self: gelu_double_backward(grad, grad_output, self, approximate)
|
|
||||||
|
|
||||||
- name: glu(Tensor self, int dim=-1) -> Tensor
|
- name: glu(Tensor self, int dim=-1) -> Tensor
|
||||||
self: glu_backward(grad, self, dim)
|
self: glu_backward(grad, self, dim)
|
||||||
|
|
||||||
|
|
|
||||||
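One detail worth noting in the restored entry: under grad mode the backward routes to infinitely_differentiable_gelu_backward, so higher-order gradients work without the dedicated gelu_double_backward formula deleted here. For reference (this is a derivation from the first derivative, not text from the commit), differentiating $\mathrm{GELU}'(x) = \Phi(x) + x\,\varphi(x)$ once more gives

$$\mathrm{GELU}''(x) = \varphi(x)\,\big(2 - x^{2}\big), \qquad \varphi(x) = \tfrac{1}{\sqrt{2\pi}}\,e^{-x^{2}/2},$$

which is what a hand-written double backward for the exact path would have to implement.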
|
|
@ -3,7 +3,6 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include <ATen/core/Reduction.h>
|
#include <ATen/core/Reduction.h>
|
||||||
#include <ATen/native/Activation.h>
|
|
||||||
#include <c10/util/Exception.h>
|
#include <c10/util/Exception.h>
|
||||||
#include <c10/util/variant.h>
|
#include <c10/util/variant.h>
|
||||||
#include <torch/csrc/Export.h>
|
#include <torch/csrc/Export.h>
|
||||||
|
|
@@ -80,11 +79,6 @@ std::string operator()(const enumtype::k##name& v) const { \
 //
 // Note that we also provide the default constructor `SomeOptions() {}`, so that
 // `SomeOptions options = {}` can work.
-#define TORCH_OPTIONS_CTOR_VARIANT_ARG2(OPTIONS_NAME, ARG_NAME, TYPE1, TYPE2) \
-  OPTIONS_NAME() {} \
-  OPTIONS_NAME(torch::enumtype::TYPE1 ARG_NAME) : ARG_NAME##_(torch::TYPE1) {} \
-  OPTIONS_NAME(torch::enumtype::TYPE2 ARG_NAME) : ARG_NAME##_(torch::TYPE2) {}
-
 #define TORCH_OPTIONS_CTOR_VARIANT_ARG3(OPTIONS_NAME, ARG_NAME, TYPE1, TYPE2, TYPE3) \
   OPTIONS_NAME() {} \
   OPTIONS_NAME(torch::enumtype::TYPE1 ARG_NAME) : ARG_NAME##_(torch::TYPE1) {} \

@@ -206,19 +200,5 @@ at::Reduction::Reduction reduction_get_enum(V variant_enum) {
   }
 }
 
-template <typename V>
-at::Gelu::Gelu gelu_get_enum(V variant_enum) {
-  if (c10::get_if<enumtype::kNone>(&variant_enum)) {
-    return at::Gelu::None;
-  } else if (c10::get_if<enumtype::kTanh>(&variant_enum)) {
-    return at::Gelu::Tanh;
-  } else {
-    TORCH_CHECK(
-      false,
-      get_enum_name(variant_enum), " is not a valid value for gelu approximate");
-    return at::Gelu::END;
-  }
-}
-
 } // namespace enumtype
 } // namespace torch

@@ -336,16 +336,8 @@ inline Tensor glu(const Tensor& input, const GLUFuncOptions& options = {}) {
 
 // ============================================================================
 
-#ifndef DOXYGEN_SHOULD_SKIP_THIS
-namespace detail {
-inline Tensor gelu(const Tensor& input, GELUFuncOptions::gelu_t approximate) {
-  return torch::gelu(input, enumtype::gelu_get_enum(approximate));
-}
-} // namespace detail
-#endif /* DOXYGEN_SHOULD_SKIP_THIS */
-
-inline Tensor gelu(const Tensor& input, const GELUFuncOptions& options = {}) {
-  return detail::gelu(input, options.approximate());
+inline Tensor gelu(const Tensor& input) {
+  return torch::gelu(input);
 }
 
 // ============================================================================

@@ -570,17 +570,12 @@ TORCH_MODULE(GLU);
 // NOLINTNEXTLINE(bugprone-exception-escape)
 class TORCH_API GELUImpl : public torch::nn::Cloneable<GELUImpl> {
  public:
-  explicit GELUImpl(const GELUOptions& options_ = {});
-
   Tensor forward(const Tensor& input);
 
   void reset() override;
 
   /// Pretty prints the `GELU` module into the given `stream`.
   void pretty_print(std::ostream& stream) const override;
-
-  /// The options with which this `Module` was constructed.
-  GELUOptions options;
 };
 
 /// A `ModuleHolder` subclass for `GELUImpl`.
@@ -1,7 +1,6 @@
 #pragma once
 
 #include <torch/arg.h>
-#include <torch/enum.h>
 #include <torch/csrc/Export.h>
 #include <torch/types.h>
 
@@ -96,37 +95,6 @@ using GLUFuncOptions = GLUOptions;
 
 // ============================================================================
 
-/// Options for the `GELU` module.
-///
-/// Example:
-/// ```
-/// GELU model(GELUOptions(torch::kNone));
-/// ```
-struct TORCH_API GELUOptions {
-  typedef c10::variant<enumtype::kNone, enumtype::kTanh> gelu_t;
-
-  TORCH_OPTIONS_CTOR_VARIANT_ARG2(GELUOptions, approximate, kNone, kTanh)
-
-  /// Specifies the approximation to apply to the output.
-  TORCH_ARG(gelu_t, approximate) = torch::kNone;
-};
-
-namespace functional {
-/// Options for `torch::nn::functional::gelu`.
-///
-/// See the documentation for `torch::nn::GELUOptions` class to learn what
-/// arguments are supported.
-///
-/// Example:
-/// ```
-/// namespace F = torch::nn::functional;
-/// F::gelu(input, F::GELUFuncOptions(torch::kNone));
-/// ```
-using GELUFuncOptions = GELUOptions;
-} // namespace functional
-
-// ============================================================================
-
 /// Options for the `Hardshrink` module.
 ///
 /// Example:
@@ -284,10 +284,8 @@ void GLUImpl::pretty_print(std::ostream& stream) const {
 
 // ============================================================================
 
-GELUImpl::GELUImpl(const GELUOptions& options_) : options(options_) {}
-
 Tensor GELUImpl::forward(const Tensor& input) {
-  return F::detail::gelu(input, options.approximate());
+  return F::gelu(input);
 }
 
 void GELUImpl::reset() {}
@@ -12,7 +12,6 @@
 #include <ATen/ExpandUtils.h>
 #include <ATen/native/IndexingUtils.h>
 #include <ATen/native/LinearAlgebraUtils.h>
-#include <ATen/native/Activation.h>
 #include <ATen/ScalarOps.h>
 #include <ATen/SparseTensorUtils.h>
 #include <ATen/Utils.h>
@@ -2339,46 +2338,6 @@ std::tuple<Tensor, Tensor, Tensor> prelu_double_backward(
   }
 }
 
-Tensor gelu_double_backward(
-    const Tensor & ggI,
-    const Tensor & gO,
-    const Tensor & input,
-    int64_t approximate) {
-  if (approximate == at::Gelu::Tanh) {
-    constexpr auto kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
-    constexpr auto kKappa = 0.044715;
-
-    auto inner = kBeta * (input + kKappa * pow(input, 3));
-    auto tanh_inner = tanh(inner);
-    auto sech_inner = 1 / cosh(inner);
-
-    auto f = 0.5 * input;
-    auto g = 1 - tanh_inner * tanh_inner;
-    auto h = kBeta * (1 + 3 * kKappa * input * input);
-
-    auto f_prime_gh = 0.5 * g * h;
-
-    auto g_prime = (2 * sech_inner) * (-sech_inner * tanh_inner) * h;
-    auto g_prime_fh = f * h * g_prime;
-
-    auto h_prime = 6 * kKappa * input * kBeta;
-    auto h_prime_fg = f * g * h_prime;
-
-    // left_derivative = f_prime_gh
-    // right_derivative = f_prime_gh + g_prime_fh + h_prime_fg
-    // dgrad_dX = left_derivative + right_derivative
-    auto gI = ggI * gO * (2 * f_prime_gh + g_prime_fh + h_prime_fg);
-    return gI;
-  } else {
-    constexpr auto kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5;
-    auto input_sq = input * input;
-    auto pdf = kBeta * at::exp(-0.5 * input_sq);
-    auto dgrad_dInput = 2 * pdf - input_sq * pdf;
-    auto gI = ggI * gO * dgrad_dInput;
-    return gI;
-  }
-}
-
 Tensor elu_double_backward(
     const Tensor& grad,
     const Tensor& grad_output,
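For reference, the else-branch deleted above encodes the exact-GELU second derivative, gelu''(x) = 2*pdf(x) - x^2*pdf(x) with pdf(x) = exp(-x^2/2)/sqrt(2*pi). A minimal Python sketch (not part of this commit; it assumes only stock PyTorch autograd) cross-checking that closed form numerically:

import math
import torch

x = torch.randn(8, dtype=torch.double, requires_grad=True)
y = torch.nn.functional.gelu(x)
(g,) = torch.autograd.grad(y.sum(), x, create_graph=True)
(gg,) = torch.autograd.grad(g.sum(), x)  # elementwise second derivative of gelu

k_beta = 1.0 / math.sqrt(2.0 * math.pi)  # same value as M_2_SQRTPI * M_SQRT1_2 * 0.5
pdf = k_beta * torch.exp(-0.5 * x * x)
closed_form = 2.0 * pdf - x * x * pdf    # the dgrad_dInput expression above
assert torch.allclose(gg, closed_form.detach(), atol=1e-10)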
@@ -303,11 +303,6 @@ std::tuple<Tensor, Tensor, Tensor> prelu_double_backward(
     const Tensor & grad_out,
     const Tensor & input_,
     const Tensor & weight_);
-Tensor gelu_double_backward(
-    const Tensor & ggI,
-    const Tensor & gO,
-    const Tensor & input,
-    int64_t approximate);
 Tensor as_strided_backward(Tensor grad, TensorGeometry input_geometry, IntArrayRef sizes, IntArrayRef strides, optional<int64_t> storage_offset_);
 std::tuple<Tensor, Tensor> atan2_backward(const Tensor& grad, const Tensor& self, const Tensor& other, std::array<bool, 2> output_mask);
 std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(
@@ -12,8 +12,6 @@
 #include <torch/csrc/jit/frontend/function_schema_parser.h>
 #include <torch/csrc/jit/ir/constants.h>
 
-#include <ATen/native/Activation.h>
-
 #include <unordered_map>
 #include <utility>
 
@@ -2275,8 +2273,7 @@ class IrParser {
     }
 
     {
-      auto ptr_op = getOperatorForLiteral(
-          "aten::gelu(Tensor self, int approximate=0) -> Tensor");
+      auto ptr_op = getOperatorForLiteral("aten::gelu(Tensor self) -> Tensor");
       REGISTER_PARSE_RULE(
           ptr_op,
           {
@@ -2286,20 +2283,7 @@ class IrParser {
                 c10::nullopt, value_map[node->inputs()[0]->unique()]);
             auto self = list_val.front();
             list_val.pop_front();
-
-            auto approximate = constant_as<int64_t>(node->input(1));
-            TORCH_INTERNAL_ASSERT(
-                approximate.has_value(),
-                "The approximate parameter is required.");
-            const bool kApproximate = approximate.value();
-
-            Val* out = nullptr;
-            if (kApproximate == at::Gelu::Tanh) {
-              out = fast_gelu(self);
-            } else {
-              out = unaryOp(UnaryOpType::Gelu, self);
-            }
-
+            auto out = gelu(self);
             value_map.emplace(
                 node->output()->unique(), ValueHolder(out, format));
           },
@@ -2309,7 +2293,7 @@ class IrParser {
 
     {
       auto ptr_op = getOperatorForLiteral(
-          "aten::gelu_backward(Tensor grad_output, Tensor self, int approximate=0) -> Tensor");
+          "aten::gelu_backward(Tensor grad, Tensor self) -> Tensor");
       REGISTER_PARSE_RULE(
           ptr_op,
           {
|
|
@ -2324,19 +2308,7 @@ class IrParser {
|
||||||
auto self = list_val.front();
|
auto self = list_val.front();
|
||||||
list_val.pop_front();
|
list_val.pop_front();
|
||||||
|
|
||||||
auto approximate = constant_as<int64_t>(node->input(2));
|
auto grad_in = gelu_backward(grad_out, self);
|
||||||
TORCH_INTERNAL_ASSERT(
|
|
||||||
approximate.has_value(),
|
|
||||||
"The approximate parameter is required.");
|
|
||||||
const bool kApproximate = approximate.value();
|
|
||||||
|
|
||||||
Val* grad_in = nullptr;
|
|
||||||
if (kApproximate == at::Gelu::Tanh) {
|
|
||||||
grad_in = fast_gelu_backward(grad_out, self);
|
|
||||||
} else {
|
|
||||||
grad_in = gelu_backward(grad_out, self);
|
|
||||||
}
|
|
||||||
|
|
||||||
value_map.emplace(
|
value_map.emplace(
|
||||||
node->output()->unique(), ValueHolder(grad_in, format));
|
node->output()->unique(), ValueHolder(grad_in, format));
|
||||||
},
|
},
|
||||||
|
|
@@ -3043,38 +3015,6 @@ bool insertProfileIValue(ProfilingRecord* pr, Node* node, size_t offset) {
     }
   }
 
-  static auto gelu_schema =
-      getOperatorForLiteral(
-          "aten::gelu(Tensor self, int approximate=0) -> Tensor")
-          ->schema();
-  if (node->matches(gelu_schema)) {
-    switch (offset) {
-      // argument 1: approximate;
-      case 1:
-        profileInt(pr, node, offset);
-        break;
-      default:
-        return false;
-    }
-    return true;
-  }
-
-  static auto gelu_backward_schema =
-      getOperatorForLiteral(
-          "aten::gelu_backward(Tensor grad_output, Tensor self, int approximate=0) -> Tensor")
-          ->schema();
-  if (node->matches(gelu_backward_schema)) {
-    switch (offset) {
-      // argument 2: approximate;
-      case 2:
-        profileInt(pr, node, offset);
-        break;
-      default:
-        return false;
-    }
-    return true;
-  }
-
   static auto softmax_backward_data_schema =
       getOperatorForLiteral(
           "aten::_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor")
@@ -56,14 +56,6 @@ def full_0_4(size:List[int], fill_value:number, *, dtype:Optional[int]=None,
 {"full_out_0_4", R"SCRIPT(
 def full_out_0_4(size:List[int], fill_value:number, *, out:Tensor) -> Tensor:
   return torch.full(size, fill_value, out=out)
-)SCRIPT"},
-{"gelu_0_8", R"SCRIPT(
-def gelu_0_8(self: Tensor) -> Tensor:
-  return torch._C._nn.gelu(self, 0)
-)SCRIPT"},
-{"gelu_out_0_8", R"SCRIPT(
-def gelu_out_0_8(self: Tensor, *, out: Tensor) -> Tensor:
-  return torch._C._nn.gelu(self, 0, out=out)
 )SCRIPT"}});
 
 std::shared_ptr<Graph> create_upgrader_graph(
@@ -43,16 +43,7 @@ static std::unordered_map<std::string, std::vector<UpgraderEntry>> operatorVersi
     {"aten::full.out",
      {{5,
        "full_out_0_4",
-       "aten::full.out(int[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!)"}}},
-    {"aten::gelu",
-     {{9,
-       "gelu_0_8",
-       "aten::gelu(Tensor self) -> Tensor"}}},
-    {"aten::gelu.out",
-     {{9,
-       "gelu_out_0_8",
-       "aten::gelu(Tensor self, *, Tensor(a!) out) -> Tensor"}}}
-});
+       "aten::full.out(int[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!)"}}}});
 
 const std::unordered_map<std::string, std::vector<UpgraderEntry>>&
 get_operator_version_map() {
@@ -872,7 +872,7 @@ class ShapePropagator : public PropertyPropBase {
           "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
           "aten::rsqrt(Tensor self) -> Tensor",
           "aten::selu(Tensor self) -> Tensor",
-          "aten::gelu(Tensor self, int approximate=0) -> Tensor",
+          "aten::gelu(Tensor self) -> Tensor",
           "aten::sigmoid(Tensor self) -> Tensor",
           "aten::sign(Tensor self) -> Tensor",
           "aten::sin(Tensor self) -> Tensor",
@@ -913,10 +913,16 @@ const std::vector<std::string> functions = {
             return grad_output * torch.where(self > 0, 1.0, negative_slope).type_as(result), None
         return result, backward
 
-    def gelu(self : Tensor, approximate : int):
-        result = torch.gelu(self, approximate)
+    def gelu(self):
+        result = torch.gelu(self)
         def backward(grad_output):
-            return torch.gelu_backward(grad_output, self, approximate), None
+            m_2_sqrtpi = 1.12837916709551257390
+            m_sqrt1_2 = 0.707106781186547524401
+            alpha = m_sqrt1_2
+            beta = m_2_sqrtpi * m_sqrt1_2 * 0.5
+            cdf = (torch.erf(self * m_sqrt1_2) + 1.0) * 0.5
+            pdf = beta * torch.exp(self * self * -0.5)
+            return grad_output * (cdf + self * pdf)
         return result, backward
 
     def hardswish(self):
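The restored backward above is the exact-erf derivative, gelu'(x) = Phi(x) + x * phi(x), where Phi and phi are the standard-normal CDF and PDF. A small sketch (not from this diff; plain PyTorch assumed) checking the same constants against autograd:

import torch

x = torch.randn(8, dtype=torch.double, requires_grad=True)
torch.nn.functional.gelu(x).sum().backward()

m_2_sqrtpi = 1.12837916709551257390
m_sqrt1_2 = 0.707106781186547524401
beta = m_2_sqrtpi * m_sqrt1_2 * 0.5
cdf = (torch.erf(x * m_sqrt1_2) + 1.0) * 0.5  # Phi(x), standard normal CDF
pdf = beta * torch.exp(x * x * -0.5)          # phi(x), standard normal PDF
assert torch.allclose(x.grad, (cdf + x * pdf).detach())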
@@ -74,7 +74,7 @@ const OperatorMap<std::string>& get_tensorexpr_elementwise_set() {
       {"aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor", "unary"},
       {"aten::softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor", "unary"},
       {"aten::relu6(Tensor self) -> Tensor", "unary"},
-      {"aten::gelu(Tensor self, int approximate=0) -> Tensor", "unary"},
+      {"aten::gelu(Tensor self) -> Tensor", "unary"},
       {"aten::neg(Tensor self) -> Tensor", "unary"},
       {"aten::reciprocal(Tensor self) -> Tensor", "unary"},
       {"aten::expm1(Tensor self) -> Tensor", "unary"},
@@ -3,8 +3,6 @@
 #include <torch/csrc/jit/tensorexpr/lowerings.h>
 #include <torch/csrc/jit/tensorexpr/operators/operators.h>
 
-#include <ATen/native/Activation.h>
-
 namespace torch {
 namespace jit {
 namespace tensorexpr {
@@ -643,34 +641,21 @@ int nnc_lowerings_lazy_registration() {
       });
 
   RegisterNNCLoweringsFunction aten_gelu(
-      {"aten::gelu(Tensor self, int approximate=0) -> (Tensor)"},
+      {"aten::gelu(Tensor self) -> (Tensor)"},
       [](const std::vector<ArgValue>& inputs,
          const std::vector<ExprHandle>& outputShape,
          const c10::optional<ScalarType>& outputType,
         at::Device device) {
-        return computeOneOperandWithCondition(
+        return computeOneOperand(
            "aten_gelu",
            inputs,
            outputShape,
            outputType,
-            [](const ExprHandle& a, const ExprHandle& approximate) {
+            [](const ExprHandle& a) {
+              auto m_sqrt1_2 = Cast::make(a.dtype(), M_SQRT1_2);
              auto one = Cast::make(a.dtype(), 1.);
              auto point_five = Cast::make(a.dtype(), .5);
-              auto tanh_gelu_flag = Cast::make(approximate.dtype(), at::Gelu::Tanh);
-
-              // approximate == 'none'
-              auto m_sqrt1_2 = Cast::make(a.dtype(), M_SQRT1_2);
-              auto gelu_result = a * point_five * (one + erf(a * m_sqrt1_2));
-
-              // approximate == 'tanh'
-              auto beta = Cast::make(a.dtype(), M_SQRT2 * M_2_SQRTPI * 0.5);
-              auto kappa = Cast::make(a.dtype(), 0.044715);
-              auto a_cube = a * a * a;
-              auto inner = beta * (a + kappa * a_cube);
-              auto tanh_gelu_result = point_five * a * (one + tanh(inner));
-
-              auto cs = CompareSelect::make(approximate, tanh_gelu_flag, kEQ);
-              return ifThenElse(cs, tanh_gelu_result, gelu_result);
+              return a * point_five * (one + erf(a * m_sqrt1_2));
            });
       });
 
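For comparison, the tanh path dropped from this lowering computed 0.5 * a * (1 + tanh(beta * (a + kappa * a^3))) with beta = sqrt(2/pi) and kappa = 0.044715. A hedged Python sketch of both formulas and their gap, with the constants taken from the deleted code (function names here are illustrative only):

import math
import torch

def gelu_erf(a):
    # exact form kept by the lowering above: 0.5 * a * (1 + erf(a / sqrt(2)))
    return 0.5 * a * (1.0 + torch.erf(a * math.sqrt(0.5)))

def gelu_tanh(a):
    beta = math.sqrt(2.0 / math.pi)  # M_SQRT2 * M_2_SQRTPI * 0.5
    kappa = 0.044715
    inner = beta * (a + kappa * a ** 3)
    return 0.5 * a * (1.0 + torch.tanh(inner))

a = torch.linspace(-4.0, 4.0, 9, dtype=torch.double)
print((gelu_tanh(a) - gelu_erf(a)).abs().max())  # small but nonzero approximation error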
@@ -43,31 +43,6 @@ Tensor computeOneOperand(
       });
 }
 
-Tensor computeOneOperandWithCondition(
-    const std::string& name,
-    const std::vector<ArgValue>& inputValues,
-    const std::vector<ExprHandle>& outputShape,
-    const c10::optional<ScalarType>& outputType,
-    const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
-        innerExpr,
-    const int checkParamTypes) {
-  return Compute(
-      name,
-      c10::fmap<DimArg>(outputShape),
-      [inputValues, outputType, innerExpr, checkParamTypes](
-          const std::vector<VarHandle>& axes) {
-        std::vector<ExprHandle> indices(axes.begin(), axes.end());
-        std::vector<ExprHandle> inputs = {
-            tensorOrConstant(inputValues[0], indices)};
-
-        promoteInputs(inputs, checkParamTypes);
-        // Last expr is the condition, which we don't promote
-        inputs.emplace_back(tensorOrConstant(inputValues[1], indices));
-        ExprHandle compute = innerExpr(inputs[0], inputs[1]);
-        return demoteOutput(compute, outputType);
-      });
-}
-
 Tensor computeTwoOperand(
     const std::string& name,
     const std::vector<ArgValue>& inputValues,
@@ -17,14 +17,6 @@ Tensor computeOneOperand(
     const c10::optional<ScalarType>& outputType,
     const std::function<ExprHandle(const ExprHandle&)>& innerExpr,
     const int checkParamTypes = kAllTypes);
-Tensor computeOneOperandWithCondition(
-    const std::string& name,
-    const std::vector<ArgValue>& inputValues,
-    const std::vector<ExprHandle>& outputShape,
-    const c10::optional<ScalarType>& outputType,
-    const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
-        innerExpr,
-    const int checkParamTypes = kAllTypes);
 Tensor computeTwoOperand(
     const std::string& name,
     const std::vector<ArgValue>& inputValues,
@@ -1,11 +0,0 @@
-# Keep this file in sync with enums in aten/src/ATen/core/Gelu.h
-
-def get_enum(gelu_approximation: str) -> int:
-    if gelu_approximation == 'none':
-        ret = 0
-    elif gelu_approximation == 'tanh':
-        ret = 1
-    else:
-        ret = -1  # TODO: remove once JIT exceptions support control flow
-        raise ValueError("{} is not a valid value for gelu approximation".format(gelu_approximation))
-    return ret
@@ -20,7 +20,6 @@ from ..overrides import (
     has_torch_function, has_torch_function_unary, has_torch_function_variadic,
     handle_torch_function)
 from . import _reduction as _Reduction
-from . import _gelu as _Gelu
 from . import grad  # noqa: F401
 from .modules import utils
 from .modules.utils import _single, _pair, _triple, _list_with_default
@@ -1652,31 +1651,19 @@ See :class:`~torch.nn.LogSigmoid` for more details.
 )
 
 
-def gelu(input: Tensor, approximate: str = 'none') -> Tensor:
-    r"""gelu(input, approximate = 'none') -> Tensor
+def gelu(input):
+    r"""gelu(input) -> Tensor
 
     Applies element-wise the function
     :math:`\text{GELU}(x) = x * \Phi(x)`
 
     where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.
 
-    When the approximate argument is 'tanh', Gelu is estimated with:
-    :math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt(2 / \pi) * (x + 0.044715 * x^3)))
-
     See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_.
     """
     if has_torch_function_unary(input):
-        return handle_torch_function(gelu, (input,), input, approximate=approximate)
-    # Enforce that the full call with the new kwarg is not invoked when scripting.
-    # TODO: Remove this scripting logic once the 2-week FC window has passed.
-    if not torch.jit.is_scripting():
-        return torch._C._nn.gelu(input, _Gelu.get_enum(approximate))
-    # When scripting, make a simpler call as long as the kwarg is set to the default value.
-    elif approximate == 'none':
-        return torch._C._nn.gelu(input)
-    else:
-        raise RuntimeError("TorchScript currently does not support approximate in nn.Gelu")
+        return handle_torch_function(gelu, (input,), input)
+    return torch._C._nn.gelu(input)
 
 
 def hardshrink(input: Tensor, lambd: float = 0.5) -> Tensor:
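After the revert, the Python entry point is back to a single argument; a minimal usage sketch against this revision:

import torch
import torch.nn.functional as F

x = torch.randn(4)
y = F.gelu(x)  # single-argument call; no `approximate` keyword on this revision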
@@ -141,7 +141,7 @@ def rrelu(input: Tensor, lower: float = ..., upper: float = ..., training: bool
                 inplace: bool = ...) -> Tensor: ...
 
 
-def gelu(input: Any, approximate: str = ...): ...
+def gelu(input: Any): ...
 
 
 def hardshrink(input: Tensor, lambd: float = ...) -> Tensor: ...
@@ -654,13 +654,6 @@ class GELU(Module):
 
     where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.
 
-    When the approximate argument is 'tanh', Gelu is estimated with:
-    :math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt(2 / \pi) * (x + 0.044715 * x^3)))
-
-    Args:
-        approximate (string, optional): the gelu approximation algorithm to use:
-            ``'none'`` | ``'tanh'``. Default: ``'none'``
-
     Shape:
         - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
         - Output: :math:`(*)`, same shape as the input.
@@ -673,18 +666,8 @@ class GELU(Module):
         >>> input = torch.randn(2)
         >>> output = m(input)
     """
-    __constants__ = ['approximate']
-    approximate: str
-
-    def __init__(self, approximate: str = 'none') -> None:
-        super(GELU, self).__init__()
-        self.approximate = approximate
-
     def forward(self, input: Tensor) -> Tensor:
-        return F.gelu(input, self.approximate)
-
-    def extra_repr(self) -> str:
-        return 'approximate={}'.format(self.approximate)
+        return F.gelu(input)
 
 
 class Hardshrink(Module):
@@ -3014,27 +3014,12 @@ def remainder(g, input, other):
     quo = g.op("Mul", div, other)
     return g.op("Sub", input, quo)
 
-@parse_args("v", "i")
-def gelu(g, self, approximate):
-    # none approximate : onnx::Constant[value={0}]
-    # tanh approximate : onnx::Constant[value={1}]
-    if approximate == 1:
-        kBeta = math.sqrt(2 / math.pi)
-        kKappa = 0.044715
-
-        beta = torch.tensor(kBeta, dtype=torch.double)
-        kappa = torch.tensor(kKappa, dtype=torch.double)
-        one = torch.tensor(1., dtype=torch.double)
-        half = torch.tensor(0.5, dtype=torch.double)
-
-        self_cube = mul(g, self, mul(g, self, self))
-        inner = mul(g, beta, add(g, self, mul(g, kappa, self_cube)))
-        return mul(g, half, mul(g, self, add(g, one, g.op("Tanh", inner))))
-    else:
-        _sqrt2 = 1.4142135623730951
-        erf = g.op("Erf", g.op("Div", self, torch.tensor(_sqrt2, dtype=torch.double)))
-        erf_plusone = add(g, erf, g.op("Constant", value_t=torch.tensor(1, dtype=torch.double)))
-        return mul(g, mul(g, self, erf_plusone), g.op("Constant", value_t=torch.tensor(0.5, dtype=torch.double)))
+def gelu(g, self):
+    _sqrt2 = 1.4142135623730951
+    erf = g.op("Erf", g.op("Div", self, torch.tensor(_sqrt2, dtype=torch.double)))
+    erf_plusone = add(g, erf, g.op("Constant", value_t=torch.tensor(1, dtype=torch.double)))
+    return mul(g, mul(g, self, erf_plusone), g.op("Constant", value_t=torch.tensor(0.5, dtype=torch.double)))
 
 @parse_args("v", "i", "v", "v", "f", "i")
 def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled):
@@ -730,7 +730,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False,
         _random_samples=None: -1),
     torch.nn.functional.gaussian_nll_loss: lambda input, target, var, full=False, eps=1e-06, reduction='mean': -1,
-    torch.nn.functional.gelu: lambda input, approximate='none': -1,
+    torch.nn.functional.gelu: lambda input: -1,
     torch.nn.functional.glu: lambda input, dim=-1: -1,
     torch.nn.functional.grid_sample: lambda input, grid, mode='bilinear', padding_mode='zeros', align_corners=None: -1,
     torch.nn.functional.group_norm: lambda input, num_groups, weight=None, bias=None, eps=1e-05: -1,
@@ -327,8 +327,7 @@ class AutocastCPUTestLists(object):
         self.nn_fp32 = [
             ("avg_pool2d", dummy_bf16[2], {"kernel_size": (3, 2), "stride": (1, 1)}),
             ("avg_pool3d", dummy_bf16[3], {"kernel_size": (3, 3, 3), "stride": (1, 1, 1)}),
-            ("gelu", dummy_bf16[3], {"approximate": torch.nn._gelu.get_enum('none')}),
-            ("gelu", dummy_bf16[3], {"approximate": torch.nn._gelu.get_enum('tanh')}),
+            ("gelu", dummy_bf16[3]),
             ("upsample_nearest1d", dummy_bf16[2], {"output_size": (n)}),
             ("upsample_nearest2d", dummy_bf16[3], {"output_size": (n, n)}),
             ("upsample_nearest3d", dummy_bf16[4], {"output_size": (n, n, n)}),
@@ -3903,6 +3903,7 @@ def sample_inputs_layer_norm(opinfo, device, dtype, requires_grad, **kwargs):
     # With `None` weight and bias (tests failing for this, see the link above)
     # yield SampleInput(make_arg((1, 2)), args=((2,), None, make_arg((2,))))
 
+
 def sample_inputs_local_response_norm(opinfo, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
 
@@ -3924,6 +3925,7 @@ def sample_inputs_local_response_norm(opinfo, device, dtype, requires_grad, **kw
     for input_shape, size, kwargs in cases:
         yield SampleInput(make_arg(input_shape), args=(size,), kwargs=kwargs)
 
+
 def sample_inputs_hardswish(self, device, dtype, requires_grad, **kwargs):
     N = 5
     # make sure we are testing -3 -> 3 range. default is -10 -> 10 so maybe unnecessary ?
@@ -4080,13 +4082,8 @@ def sample_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs):
 
 def sample_inputs_gelu(self, device, dtype, requires_grad, **kwargs):
     N = 5
-    tensors = []
-    for _ in range(1, N):
-        for approximate in ['none', 'tanh']:
-            tensors.append(SampleInput(
-                make_tensor((N * 2, N * 2), device=device, dtype=dtype,
-                            requires_grad=requires_grad, low=-3, high=3),
-                kwargs=dict(approximate=approximate)))
+    tensors = [SampleInput(make_tensor((N * 2, N * 2), device=device, dtype=dtype,
+                           requires_grad=requires_grad, low=-3, high=3)) for _ in range(1, N)]
     return tensors
 
 def sample_inputs_max_min_reduction_with_dim(op_info, device, dtype, requires_grad, **kwargs):
@@ -11776,7 +11773,7 @@ op_db: List[OpInfo] = [
            supports_gradgrad=True,
            supports_out=False,
            supports_forward_ad=True,
-           supports_fwgrad_bwgrad=False,
+           supports_fwgrad_bwgrad=True,
            autodiff_nonfusible_nodes=["aten::gelu"]),
     OpInfo('nn.functional.relu6',
            aten_name="relu6",
@@ -3716,16 +3716,12 @@ new_module_tests = [
     ),
     dict(
         module_name='GELU',
-        constructor_args=('none',),
-        cpp_constructor_args='torch::nn::GELUOptions().approximate(torch::kNone)',
         input_size=(),
         desc='scalar',
         reference_fn=lambda x, *_: x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))),
     ),
     dict(
         module_name='GELU',
-        constructor_args=('none',),
-        cpp_constructor_args='torch::nn::GELUOptions().approximate(torch::kNone)',
         input_size=(3, 2, 5),
         reference_fn=lambda x, *_: x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))),
     ),
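A usage sketch mirroring the reference_fn in the tests above (illustrative only): on this revision torch.nn.GELU takes no constructor arguments and should match the erf closed form.

import math
import torch

m = torch.nn.GELU()  # no constructor arguments on this revision
x = torch.randn(3, 2, 5, dtype=torch.double)
reference = x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
assert torch.allclose(m(x), reference)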