Port softshrink to structured (#57623)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57623

Test Plan: Imported from OSS

Reviewed By: VitalyFedyunin

Differential Revision: D28224703

Pulled By: ezyang

fbshipit-source-id: 62e40d53eb130205f6c4d2775082e436e6adadce
This commit is contained in:
Freey0 2021-05-14 00:46:19 -07:00 committed by Facebook GitHub Bot
parent d65dff463a
commit f23e10f27b
6 changed files with 25 additions and 25 deletions

View File

@ -66,6 +66,18 @@ TORCH_META_FUNC(hardsigmoid) (const Tensor& self) {
build_unary_op(maybe_get_output(), self);
}
static inline void softshrink_check(const Scalar& lambd) {
double lamb = lambd.to<double>();
TORCH_CHECK(lamb >= 0, "lambda must be greater or equal to 0, but found to be ", lamb, ".");
}
TORCH_META_FUNC(softshrink) (
const Tensor & self, const Scalar& lambd
) {
softshrink_check(lambd);
build_unary_op(maybe_get_output(), self);
}
} // namespace meta
namespace native {
@ -136,6 +148,12 @@ TORCH_IMPL_FUNC(hardsigmoid_out) (
hardsigmoid_stub(device_type(), *this);
}
TORCH_IMPL_FUNC(softshrink_out) (
const Tensor & self, const Scalar& lambd, const Tensor& result
) {
softshrink_stub(device_type(), *this, lambd);
}
Tensor hardtanh(const Tensor& self, const Scalar& min, const Scalar& max) {
return at::clamp(self, min, max);
}
@ -699,25 +717,6 @@ Tensor hardshrink_backward(const Tensor & grad, const Tensor & self, const Scala
return out_tensor;
}
static inline void softshrink_check(const Scalar& lambd) {
double lamb = lambd.to<double>();
TORCH_CHECK(lamb >= 0, "lambda must be greater or equal to 0, but found to be ", lamb, ".");
}
Tensor& softshrink_out(const Tensor & self, const Scalar& lambd, Tensor& result) {
softshrink_check(lambd);
auto iter = TensorIterator::unary_op(result, self);
softshrink_stub(iter.device_type(), iter, lambd);
return result;
}
Tensor softshrink(const Tensor & self, const Scalar& lambd) {
softshrink_check(lambd);
Tensor result;
auto iter = TensorIterator::unary_op(result, self);
softshrink_stub(iter.device_type(), iter, lambd);
return iter.output();
}
Tensor& softshrink_backward_out(const Tensor & grad, const Tensor & self, const Scalar& lambd, Tensor& grad_input) {
auto iter = TensorIterator::binary_op(grad_input, grad, self);

View File

@ -21,6 +21,7 @@ using hardsigmoid_backward_fn = void(*)(TensorIterator&);
using hardswish_fn = void(*)(TensorIterator&);
using hardswish_backward_fn = void(*)(TensorIterator&);
using shrink_fn = void (*)(TensorIterator&, const Scalar&);
using softshrink_fn = void (*)(TensorIteratorBase&, const Scalar&);
using shrink_backward_fn = void (*)(TensorIterator&, const Scalar&);
using elu_fn = void (*)(TensorIteratorBase&, const Scalar&, const Scalar&, const Scalar&);
using elu_backward_fn = void (*)(TensorIterator&, const Scalar&, const Scalar&, const Scalar&, bool);
@ -43,7 +44,7 @@ DECLARE_DISPATCH(hardsigmoid_backward_fn, hardsigmoid_backward_stub);
DECLARE_DISPATCH(hardswish_fn, hardswish_stub);
DECLARE_DISPATCH(hardswish_backward_fn, hardswish_backward_stub);
DECLARE_DISPATCH(shrink_fn, hardshrink_stub);
DECLARE_DISPATCH(shrink_fn, softshrink_stub);
DECLARE_DISPATCH(softshrink_fn, softshrink_stub);
DECLARE_DISPATCH(shrink_backward_fn, shrink_backward_stub);
DECLARE_DISPATCH(leaky_relu_fn, leaky_relu_stub);
DECLARE_DISPATCH(leaky_relu_backward_fn, leaky_relu_backward_stub);

View File

@ -384,7 +384,7 @@ void hardshrink_kernel(TensorIterator& iter, const Scalar& lambd) {
});
}
void softshrink_kernel(TensorIterator& iter, const Scalar& lambd) {
void softshrink_kernel(TensorIteratorBase& iter, const Scalar& lambd) {
AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "softshrink_cpu", [&]() {
auto lambd_val = lambd.to<scalar_t>();
cpu_kernel(iter, [=](scalar_t a) -> scalar_t {

View File

@ -255,7 +255,7 @@ void hardshrink_kernel(TensorIterator& iter, const Scalar& value) {
});
}
void softshrink_kernel(TensorIterator& iter, const Scalar& value) {
void softshrink_kernel(TensorIteratorBase& iter, const Scalar& value) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "softshrink_cuda", [&]() {
auto lambd = value.to<scalar_t>();
gpu_kernel(iter, [lambd]GPU_LAMBDA(scalar_t a) -> scalar_t {

View File

@ -8223,16 +8223,17 @@
CPU, CUDA: softplus_backward
- func: softshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!)
structured: True
structured_inherits: TensorIteratorBase
device_check: NoCheck # TensorIterator
python_module: nn
dispatch:
CPU, CUDA: softshrink_out
- func: softshrink(Tensor self, Scalar lambd=0.5) -> Tensor
structured_delegate: softshrink.out
device_check: NoCheck # TensorIterator
python_module: nn
dispatch:
CPU, CUDA: softshrink
- func: softshrink_backward.grad_input(Tensor grad_output, Tensor self, Scalar lambd, *, Tensor(a!) grad_input) -> Tensor(a!)
python_module: nn

View File

@ -16052,7 +16052,6 @@ class TestNNDeviceType(NNTestCase):
output = model(input)
torch.autograd.gradcheck(model, input)
@expectedFailureMeta # https://github.com/pytorch/pytorch/issues/54897
def test_softshrink_inplace_overlap(self, device):
x = torch.randn((1, 6), device=device).expand((6, 6))
with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):