Port softshrink to structured (#57623)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57623
Test Plan: Imported from OSS
Reviewed By: VitalyFedyunin
Differential Revision: D28224703
Pulled By: ezyang
fbshipit-source-id: 62e40d53eb130205f6c4d2775082e436e6adadce
commit f23e10f27b
parent d65dff463a
aten/src/ATen/native/Activation.cpp

@@ -66,6 +66,18 @@ TORCH_META_FUNC(hardsigmoid) (const Tensor& self) {
   build_unary_op(maybe_get_output(), self);
 }
 
+static inline void softshrink_check(const Scalar& lambd) {
+  double lamb = lambd.to<double>();
+  TORCH_CHECK(lamb >= 0, "lambda must be greater or equal to 0, but found to be ", lamb, ".");
+}
+
+TORCH_META_FUNC(softshrink) (
+  const Tensor & self, const Scalar& lambd
+) {
+  softshrink_check(lambd);
+  build_unary_op(maybe_get_output(), self);
+}
+
 } // namespace meta
 
 namespace native {
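(File labels on the hunks in this view are inferred from the PyTorch source tree; they were not preserved in the extraction.) In the structured-kernels design, TORCH_META_FUNC only validates arguments and infers the output's shape and dtype via build_unary_op; no computation happens here, which is what makes shape-only (meta) execution possible. A minimal usage sketch of the operator being ported, with expected rather than captured outputs:

import torch
import torch.nn.functional as F

x = torch.tensor([-2.0, -0.3, 0.0, 0.3, 2.0])
print(F.softshrink(x, lambd=0.5))
# expected: tensor([-1.5000, 0.0000, 0.0000, 0.0000, 1.5000])

# The TORCH_CHECK added above rejects a negative lambda up front:
try:
    F.softshrink(x, lambd=-1.0)
except RuntimeError as err:
    print(err)  # expected: lambda must be greater or equal to 0, but found to be -1.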
@@ -136,6 +148,12 @@ TORCH_IMPL_FUNC(hardsigmoid_out) (
   hardsigmoid_stub(device_type(), *this);
 }
 
+TORCH_IMPL_FUNC(softshrink_out) (
+  const Tensor & self, const Scalar& lambd, const Tensor& result
+) {
+  softshrink_stub(device_type(), *this, lambd);
+}
+
 Tensor hardtanh(const Tensor& self, const Scalar& min, const Scalar& max) {
   return at::clamp(self, min, max);
 }
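TORCH_IMPL_FUNC receives *this with the output already allocated (the META function above fixed its shape and dtype), so the implementation reduces to a single stub dispatch. An illustrative, self-contained Python sketch of the overall structured flow, where softshrink_check mirrors the C++ above and F.softshrink stands in for the device stub:

import torch
import torch.nn.functional as F

def softshrink_check(lambd):
    # mirrors the TORCH_CHECK in the META function
    if lambd < 0:
        raise RuntimeError(
            "lambda must be greater or equal to 0, but found to be {}.".format(lambd))

def structured_softshrink(x, lambd=0.5, out=None):
    softshrink_check(lambd)            # META step: validate arguments
    if out is None:
        out = torch.empty_like(x)      # allocation from the meta-inferred shape/dtype
    out.copy_(F.softshrink(x, lambd))  # IMPL step: stand-in for softshrink_stub
    return out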
@@ -699,25 +717,6 @@ Tensor hardshrink_backward(const Tensor & grad, const Tensor & self, const Scala
   return out_tensor;
 }
 
-static inline void softshrink_check(const Scalar& lambd) {
-  double lamb = lambd.to<double>();
-  TORCH_CHECK(lamb >= 0, "lambda must be greater or equal to 0, but found to be ", lamb, ".");
-}
-
-Tensor& softshrink_out(const Tensor & self, const Scalar& lambd, Tensor& result) {
-  softshrink_check(lambd);
-  auto iter = TensorIterator::unary_op(result, self);
-  softshrink_stub(iter.device_type(), iter, lambd);
-  return result;
-}
-
-Tensor softshrink(const Tensor & self, const Scalar& lambd) {
-  softshrink_check(lambd);
-  Tensor result;
-  auto iter = TensorIterator::unary_op(result, self);
-  softshrink_stub(iter.device_type(), iter, lambd);
-  return iter.output();
-}
 
 Tensor& softshrink_backward_out(const Tensor & grad, const Tensor & self, const Scalar& lambd, Tensor& grad_input) {
   auto iter = TensorIterator::binary_op(grad_input, grad, self);
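With the port, these hand-written softshrink and softshrink_out wrappers are dead code: the structured codegen derives the functional and out= variants from the single META/IMPL pair above, so the duplicated softshrink_check and TensorIterator setup are deleted.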
aten/src/ATen/native/Activation.h

@@ -21,6 +21,7 @@ using hardsigmoid_backward_fn = void(*)(TensorIterator&);
 using hardswish_fn = void(*)(TensorIterator&);
 using hardswish_backward_fn = void(*)(TensorIterator&);
 using shrink_fn = void (*)(TensorIterator&, const Scalar&);
+using softshrink_fn = void (*)(TensorIteratorBase&, const Scalar&);
 using shrink_backward_fn = void (*)(TensorIterator&, const Scalar&);
 using elu_fn = void (*)(TensorIteratorBase&, const Scalar&, const Scalar&, const Scalar&);
 using elu_backward_fn = void (*)(TensorIterator&, const Scalar&, const Scalar&, const Scalar&, bool);

@@ -43,7 +44,7 @@ DECLARE_DISPATCH(hardsigmoid_backward_fn, hardsigmoid_backward_stub);
 DECLARE_DISPATCH(hardswish_fn, hardswish_stub);
 DECLARE_DISPATCH(hardswish_backward_fn, hardswish_backward_stub);
 DECLARE_DISPATCH(shrink_fn, hardshrink_stub);
-DECLARE_DISPATCH(shrink_fn, softshrink_stub);
+DECLARE_DISPATCH(softshrink_fn, softshrink_stub);
 DECLARE_DISPATCH(shrink_backward_fn, shrink_backward_stub);
 DECLARE_DISPATCH(leaky_relu_fn, leaky_relu_stub);
 DECLARE_DISPATCH(leaky_relu_backward_fn, leaky_relu_backward_stub);
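The header change gives softshrink_stub its own function-pointer type: a structured implementation passes *this as a TensorIteratorBase&, so the stub can no longer share shrink_fn (which takes TensorIterator&) with hardshrink_stub.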
aten/src/ATen/native/cpu/Activation.cpp

@@ -384,7 +384,7 @@ void hardshrink_kernel(TensorIterator& iter, const Scalar& lambd) {
   });
 }
 
-void softshrink_kernel(TensorIterator& iter, const Scalar& lambd) {
+void softshrink_kernel(TensorIteratorBase& iter, const Scalar& lambd) {
   AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "softshrink_cpu", [&]() {
     auto lambd_val = lambd.to<scalar_t>();
     cpu_kernel(iter, [=](scalar_t a) -> scalar_t {
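Only the iterator parameter type changes here; the kernel body, elided at the end of the hunk, is untouched. For reference, the per-element rule that cpu_kernel applies is the standard softshrink definition, sketched in Python:

def softshrink_scalar(a, lambd=0.5):
    # elementwise rule applied by cpu_kernel / gpu_kernel
    if a > lambd:
        return a - lambd
    if a < -lambd:
        return a + lambd
    return 0.0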
aten/src/ATen/native/cuda/Activation.cu

@@ -255,7 +255,7 @@ void hardshrink_kernel(TensorIterator& iter, const Scalar& value) {
   });
 }
 
-void softshrink_kernel(TensorIterator& iter, const Scalar& value) {
+void softshrink_kernel(TensorIteratorBase& iter, const Scalar& value) {
   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "softshrink_cuda", [&]() {
     auto lambd = value.to<scalar_t>();
     gpu_kernel(iter, [lambd]GPU_LAMBDA(scalar_t a) -> scalar_t {
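The CUDA kernel gets the same one-line signature change. Note that it dispatches over Half and BFloat16 in addition to the standard floating types covered by the CPU kernel above; the hunk again ends just before the unchanged lambda body.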
aten/src/ATen/native/native_functions.yaml

@@ -8223,16 +8223,17 @@
     CPU, CUDA: softplus_backward
 
 - func: softshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
   device_check: NoCheck # TensorIterator
   python_module: nn
   dispatch:
     CPU, CUDA: softshrink_out
 
 - func: softshrink(Tensor self, Scalar lambd=0.5) -> Tensor
+  structured_delegate: softshrink.out
   device_check: NoCheck # TensorIterator
   python_module: nn
-  dispatch:
-    CPU, CUDA: softshrink
 
 - func: softshrink_backward.grad_input(Tensor grad_output, Tensor self, Scalar lambd, *, Tensor(a!) grad_input) -> Tensor(a!)
   python_module: nn
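In the operator schema, structured: True plus structured_inherits: TensorIteratorBase marks softshrink.out as the structured kernel, and the functional softshrink becomes a structured_delegate of it. Its per-backend dispatch lines therefore go away: the codegen now produces the functional variant from the out= kernel instead of the hand-written native::softshrink deleted above.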
test/test_nn.py

@@ -16052,7 +16052,6 @@ class TestNNDeviceType(NNTestCase):
         output = model(input)
         torch.autograd.gradcheck(model, input)
 
-    @expectedFailureMeta # https://github.com/pytorch/pytorch/issues/54897
     def test_softshrink_inplace_overlap(self, device):
         x = torch.randn((1, 6), device=device).expand((6, 6))
         with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
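The @expectedFailureMeta decorator comes off because structured ops get meta-device support for free: the META function performs shape inference without touching any backend kernel. A quick check of the newly-expected behavior (assuming a build that includes this change):

import torch

x = torch.randn(2, 3, device="meta")
y = torch.nn.functional.softshrink(x)
print(y.shape, y.device)  # expected: torch.Size([2, 3]) meta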