hardshrink_cpu and hardshrink_backward_cpu refactoring with at::native::cpu_kernel

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/22459

Differential Revision: D16132625

Pulled By: pbelevich

fbshipit-source-id: d7eb1cd6ed04eba3d0c54feaca1e5ab2836211b5
Author: Pavel Belevich, 2019-07-16 18:43:26 -07:00 (committed by Facebook Github Bot)
Parent: ef36046ad7
Commit: bcfa023a00
4 changed files with 62 additions and 24 deletions
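
The refactoring moves the element-wise hardshrink math out of CPU_tensor_apply2/CPU_tensor_apply3 and into a vectorized cpu_kernel_vec kernel reached through a DispatchStub and a TensorIterator. As a reading aid, a condensed sketch of how the three C++ files below fit together (all names are taken from the diff; the sketch is illustrative, not a compilable excerpt of the tree):

// Activation.h declares a stub with the kernel's signature:
//   using hardshrink_cpu_fn = void (*)(TensorIterator&, Scalar);
//   DECLARE_DISPATCH(hardshrink_cpu_fn, hardshrink_cpu_stub);
//
// Activation.cpp defines the stub and drives it through a TensorIterator:
//   DEFINE_DISPATCH(hardshrink_cpu_stub);
//   auto iter = TensorIterator::unary_op(out_tensor, self);
//   hardshrink_cpu_stub(kCPU, *iter, lambd);
//
// cpu/Activation.cpp implements the element-wise kernel with cpu_kernel_vec
// (a scalar lambda plus a Vec256 lambda) and registers it against the stub:
//   REGISTER_DISPATCH(hardshrink_cpu_stub, &hardshrink_cpu_kernel);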

aten/src/ATen/native/Activation.cpp

@@ -14,6 +14,8 @@ static const double SELU_ALPHA = 1.6732632423543772848170429916717;
 static const double SELU_SCALE = 1.0507009873554804934193349852946;
 
 DEFINE_DISPATCH(threshold_stub);
+DEFINE_DISPATCH(hardshrink_cpu_stub);
+DEFINE_DISPATCH(hardshrink_backward_cpu_stub);
 
 Tensor relu(const Tensor & self) {
   return at::threshold(self, 0, 0);
@@ -339,35 +341,15 @@ std::tuple<Tensor, Tensor> prelu_backward_cpu(const Tensor& grad_out_, const Ten
 // -----------------------------------
 
 Tensor hardshrink_cpu(const Tensor & self, Scalar lambd) {
   auto out_tensor = at::empty_like(self);
-  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "hardshrink_cpu", [&] {
-    auto lambd_val = lambd.to<scalar_t>();
-    at::CPU_tensor_apply2<scalar_t, scalar_t>(
-      self,
-      out_tensor,
-      [&](
-        scalar_t& self_val,
-        scalar_t& out_tensor_val) {
-          out_tensor_val = (self_val >= -lambd_val && self_val <= lambd_val) ? scalar_t(0) : self_val;
-    });
-  });
+  auto iter = TensorIterator::unary_op(out_tensor, self);
+  hardshrink_cpu_stub(kCPU, *iter, lambd);
   return out_tensor;
 }
 
 Tensor hardshrink_backward_cpu(const Tensor & grad, const Tensor & self, Scalar lambd) {
   auto out_tensor = at::empty_like(self);
-  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "hardshrink_backward_cpu", [&] {
-    auto lambd_val = lambd.to<scalar_t>();
-    at::CPU_tensor_apply3<scalar_t, scalar_t, scalar_t>(
-      self,
-      grad,
-      out_tensor,
-      [&](
-        scalar_t& self_val,
-        scalar_t& grad_val,
-        scalar_t& out_tensor_val) {
-          out_tensor_val = (self_val >= -lambd_val && self_val <= lambd_val) ? scalar_t(0) : grad_val;
-    });
-  });
+  auto iter = TensorIterator::binary_op(out_tensor, grad, self);
+  hardshrink_backward_cpu_stub(kCPU, *iter, lambd);
   return out_tensor;
 }
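
With the stubs in place, the CPU entry points above reduce to building a TensorIterator and deferring to whichever kernel was registered for the current architecture. A minimal caller-side sketch, assuming a standard ATen build (on CPU tensors at::hardshrink and at::hardshrink_backward resolve to the two functions above):

#include <ATen/ATen.h>

int main() {
  at::Tensor x = at::randn({8});
  // Forward: entries with |x| <= lambd are zeroed, the rest pass through unchanged.
  at::Tensor y = at::hardshrink(x, /*lambd=*/0.5);
  // Backward: the incoming gradient is zeroed exactly where the forward zeroed x.
  at::Tensor gx = at::hardshrink_backward(at::ones_like(x), x, /*lambd=*/0.5);
  return 0;
}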

aten/src/ATen/native/Activation.h

@@ -14,10 +14,14 @@ using threshold_fn = void (*)(TensorIterator&, Scalar, Scalar);
 using activation_fn = void (*)(const Tensor& /* X */, Tensor* /* Y */);
 using activation_backward_fn =
     void (*)(const Tensor& /* dY */, const Tensor& /* X */, Tensor* /* dX */);
+using hardshrink_cpu_fn = void (*)(TensorIterator&, Scalar);
+using hardshrink_backward_cpu_fn = void (*)(TensorIterator&, Scalar);
 
 DECLARE_DISPATCH(threshold_fn, threshold_stub);
 DECLARE_DISPATCH(activation_fn, GeluKernel);
 DECLARE_DISPATCH(activation_backward_fn, GeluBackwardKernel);
+DECLARE_DISPATCH(hardshrink_cpu_fn, hardshrink_cpu_stub);
+DECLARE_DISPATCH(hardshrink_backward_cpu_fn, hardshrink_backward_cpu_stub);
 
 } // namespace native

aten/src/ATen/native/cpu/Activation.cpp

@@ -175,11 +175,41 @@ void GeluBackwardKernelImpl(const Tensor& dY, const Tensor& X, Tensor* dX) {
   }
 }
 
+void hardshrink_cpu_kernel(TensorIterator& iter, Scalar lambd) {
+  AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "hardshrink_cpu", [&] {
+    auto lambd_val = lambd.to<scalar_t>();
+    cpu_kernel_vec(iter,
+      [=](scalar_t self_val) {
+        return (self_val >= -lambd_val && self_val <= lambd_val) ? scalar_t(0) : self_val;
+      },
+      [=](Vec256<scalar_t> self_val) {
+        return ((self_val < -lambd_val) | (self_val > lambd_val)) & self_val;
+      }
+    );
+  });
+}
+
+void hardshrink_backward_cpu_kernel(TensorIterator& iter, Scalar lambd) {
+  AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "hardshrink_backward_cpu", [&] {
+    auto lambd_val = lambd.to<scalar_t>();
+    cpu_kernel_vec(iter,
+      [=](scalar_t grad_val, scalar_t self_val) {
+        return (self_val >= -lambd_val && self_val <= lambd_val) ? scalar_t(0) : grad_val;
+      },
+      [=](Vec256<scalar_t> grad_val, Vec256<scalar_t> self_val) {
+        return ((self_val < -lambd_val) | (self_val > lambd_val)) & grad_val;
+      }
+    );
+  });
+}
+
 } // namespace
 
 REGISTER_DISPATCH(threshold_stub, &threshold_kernel);
 REGISTER_DISPATCH(GeluKernel, &GeluKernelImpl);
 REGISTER_DISPATCH(GeluBackwardKernel, &GeluBackwardKernelImpl);
+REGISTER_DISPATCH(hardshrink_cpu_stub, &hardshrink_cpu_kernel);
+REGISTER_DISPATCH(hardshrink_backward_cpu_stub, &hardshrink_backward_cpu_kernel);
 
 } // namespace native
 } // namespace at
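
The Vec256 lambdas avoid a per-lane branch: a vector comparison yields an all-ones bit pattern in lanes where it holds and all-zeros elsewhere, so ANDing the mask with the value keeps lanes outside [-lambd, lambd] and zeroes the lanes inside. A standalone scalar emulation of that trick, as a hedged sketch (hardshrink_masked is a hypothetical helper, not part of the diff):

#include <cassert>
#include <cstdint>
#include <cstring>

float hardshrink_masked(float x, float lambd) {
  // Emulate one SIMD lane: a true comparison becomes all 1-bits, false becomes all 0-bits.
  uint32_t mask = (x < -lambd || x > lambd) ? 0xFFFFFFFFu : 0u;
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= mask;  // keep x outside the band, zero it inside
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

int main() {
  assert(hardshrink_masked(0.3f, 0.5f) == 0.0f);   // inside the band: zeroed
  assert(hardshrink_masked(0.7f, 0.5f) == 0.7f);   // outside: unchanged
  assert(hardshrink_masked(-0.7f, 0.5f) == -0.7f);
  return 0;
}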

test/test_torch.py

@@ -8242,6 +8242,28 @@ class _TestTorchMixin(object):
         # test non-contiguous case
         self.assertEqual(torch.tensor([1, 0, 0.5, 0.6]).view(2, 2), data.t().hardshrink(0.3))
 
+    def test_hardshrink_edge_cases(self):
+        def h(t, values, l_expected):
+            for l, expected in l_expected.items():
+                values_tensor = torch.tensor([float(v) for v in values]).type(t)
+                expected_tensor = torch.tensor([float(v) for v in expected]).type(t)
+                self.assertEqual(expected_tensor == values_tensor.hardshrink(l),
+                                 torch.ones_like(values_tensor))
+
+        def test_helper(t, min, max):
+            h(t, [0.0, min, -min, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf],
+              {0.0: [0.0, min, -min, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf],
+               min: [0.0, 0.0, 0.0, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf],
+               0.1: [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, -1.0, max, -max, inf, -inf],
+               1.0: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, max, -max, inf, -inf],
+               max: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, inf, -inf],
+               inf: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]})
+
+        test_helper(torch.DoubleTensor,
+                    torch.finfo(torch.double).tiny, torch.finfo(torch.double).max)
+        test_helper(torch.FloatTensor,
+                    torch.finfo(torch.float).tiny, torch.finfo(torch.float).max)
+
     def test_unbiased(self):
         tensor = torch.randn(100)
         self.assertEqual(tensor.var(0), tensor.var(0, unbiased=True))
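
For reference, the boundary behaviour exercised above in C++ terms: the kernel tests a closed interval, so a value equal to lambd is shrunk to zero, while +/-inf pass through any finite lambd. A hedged ATen sketch of the same cases (assumes a standard ATen build):

#include <ATen/ATen.h>
#include <limits>
#include <vector>

int main() {
  const double inf = std::numeric_limits<double>::infinity();
  std::vector<double> vals = {0.5, -0.5, 1.0, inf, -inf};
  // Copy the host buffer into a double tensor.
  at::Tensor x = at::from_blob(vals.data(), {(int64_t)vals.size()}, at::kDouble).clone();
  at::Tensor y = at::hardshrink(x, /*lambd=*/0.5);
  // Expected: {0, 0, 1, inf, -inf}; the boundary values are shrunk, infinities survive.
  return 0;
}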