hardshrink_cpu and hardshrink_backward_cpu refactoring with at::native::cpu_kernel
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/22459

Differential Revision: D16132625
Pulled By: pbelevich
fbshipit-source-id: d7eb1cd6ed04eba3d0c54feaca1e5ab2836211b5
This commit is contained in:
parent
ef36046ad7
commit
bcfa023a00
aten/src/ATen/native/Activation.cpp

@@ -14,6 +14,8 @@ static const double SELU_ALPHA = 1.6732632423543772848170429916717;
 static const double SELU_SCALE = 1.0507009873554804934193349852946;
 
 DEFINE_DISPATCH(threshold_stub);
+DEFINE_DISPATCH(hardshrink_cpu_stub);
+DEFINE_DISPATCH(hardshrink_backward_cpu_stub);
 
 Tensor relu(const Tensor & self) {
   return at::threshold(self, 0, 0);
@@ -339,35 +341,15 @@ std::tuple<Tensor, Tensor> prelu_backward_cpu(const Tensor& grad_out_, const Ten
 
 // -----------------------------------
 Tensor hardshrink_cpu(const Tensor & self, Scalar lambd) {
   auto out_tensor = at::empty_like(self);
-  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "hardshrink_cpu", [&] {
-    auto lambd_val = lambd.to<scalar_t>();
-    at::CPU_tensor_apply2<scalar_t, scalar_t>(
-      self,
-      out_tensor,
-      [&](
-        scalar_t& self_val,
-        scalar_t& out_tensor_val) {
-          out_tensor_val = (self_val >= -lambd_val && self_val <= lambd_val) ? scalar_t(0) : self_val;
-    });
-  });
+  auto iter = TensorIterator::unary_op(out_tensor, self);
+  hardshrink_cpu_stub(kCPU, *iter, lambd);
   return out_tensor;
 }
 
 Tensor hardshrink_backward_cpu(const Tensor & grad, const Tensor & self, Scalar lambd) {
   auto out_tensor = at::empty_like(self);
-  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "hardshrink_backward_cpu", [&] {
-    auto lambd_val = lambd.to<scalar_t>();
-    at::CPU_tensor_apply3<scalar_t, scalar_t, scalar_t>(
-      self,
-      grad,
-      out_tensor,
-      [&](
-        scalar_t& self_val,
-        scalar_t& grad_val,
-        scalar_t& out_tensor_val) {
-          out_tensor_val = (self_val >= -lambd_val && self_val <= lambd_val) ? scalar_t(0) : grad_val;
-    });
-  });
+  auto iter = TensorIterator::binary_op(out_tensor, grad, self);
+  hardshrink_backward_cpu_stub(kCPU, *iter, lambd);
   return out_tensor;
 }
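After this change the op itself is just plumbing: it allocates the output, wires output and input into a TensorIterator, and calls through a dispatch stub; the element-wise loop moves into the CPU kernel file. As a rough standalone analogue of what cpu_kernel does with the scalar lambda (plain C++, no ATen types; unary_loop is an illustrative stand-in, not ATen API):

#include <cstddef>
#include <iostream>
#include <vector>

// Illustrative stand-in for the role of cpu_kernel: walk the iterator's
// elements and apply the scalar op pointwise.
template <typename scalar_t, typename Op>
void unary_loop(scalar_t* out, const scalar_t* in, std::size_t n, Op op) {
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = op(in[i]);
  }
}

int main() {
  const float lambd = 0.3f;
  std::vector<float> in = {1.0f, 0.0f, 0.5f, 0.6f};
  std::vector<float> out(in.size());
  // Same predicate as the hardshrink kernel: zero values inside [-lambd, lambd].
  unary_loop(out.data(), in.data(), in.size(), [=](float x) {
    return (x >= -lambd && x <= lambd) ? 0.0f : x;
  });
  for (float v : out) std::cout << v << ' ';  // prints: 1 0 0.5 0.6
  std::cout << '\n';
}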
aten/src/ATen/native/Activation.h

@@ -14,10 +14,14 @@ using threshold_fn = void (*)(TensorIterator&, Scalar, Scalar);
 using activation_fn = void (*)(const Tensor& /* X */, Tensor* /* Y */);
 using activation_backward_fn =
     void (*)(const Tensor& /* dY */, const Tensor& /* X */, Tensor* /* dX */);
+using hardshrink_cpu_fn = void (*)(TensorIterator&, Scalar);
+using hardshrink_backward_cpu_fn = void (*)(TensorIterator&, Scalar);
 
 DECLARE_DISPATCH(threshold_fn, threshold_stub);
 DECLARE_DISPATCH(activation_fn, GeluKernel);
 DECLARE_DISPATCH(activation_backward_fn, GeluBackwardKernel);
+DECLARE_DISPATCH(hardshrink_cpu_fn, hardshrink_cpu_stub);
+DECLARE_DISPATCH(hardshrink_backward_cpu_fn, hardshrink_backward_cpu_stub);
 
 } // namespace native
 
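The DECLARE_DISPATCH / DEFINE_DISPATCH / REGISTER_DISPATCH triple implements a per-device function-pointer table: the header declares a typed stub, the op file defines it, and each backend registers its kernel into it. A minimal hand-rolled sketch of the idea (the Stub template and Device enum here are illustrative, not the real DispatchStub machinery):

#include <cassert>
#include <iostream>
#include <utility>

// Hand-rolled sketch of the DispatchStub idea: a named slot holding one
// function pointer per device; callers invoke through the slot.
enum class Device { CPU, CUDA };

template <typename Fn>
struct Stub {
  Fn cpu = nullptr;
  Fn cuda = nullptr;
  template <typename... Args>
  auto operator()(Device d, Args&&... args) {
    Fn fn = (d == Device::CPU) ? cpu : cuda;
    assert(fn && "no kernel registered for this device");
    return fn(std::forward<Args>(args)...);
  }
};

using hardshrink_fn = float (*)(float, float);  // illustrative signature

// DECLARE_DISPATCH + DEFINE_DISPATCH analogue: a typed, named stub object.
Stub<hardshrink_fn> hardshrink_stub;

// The backend's kernel, registered into the stub below.
float hardshrink_cpu_impl(float x, float lambd) {
  return (x >= -lambd && x <= lambd) ? 0.0f : x;
}

int main() {
  hardshrink_stub.cpu = &hardshrink_cpu_impl;  // REGISTER_DISPATCH analogue
  std::cout << hardshrink_stub(Device::CPU, 0.6f, 0.3f) << '\n';  // prints 0.6
}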
aten/src/ATen/native/cpu/Activation.cpp

@@ -175,11 +175,41 @@ void GeluBackwardKernelImpl(const Tensor& dY, const Tensor& X, Tensor* dX) {
   }
 }
 
+void hardshrink_cpu_kernel(TensorIterator& iter, Scalar lambd) {
+  AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "hardshrink_cpu", [&] {
+    auto lambd_val = lambd.to<scalar_t>();
+    cpu_kernel_vec(iter,
+      [=](scalar_t self_val) {
+        return (self_val >= -lambd_val && self_val <= lambd_val) ? scalar_t(0) : self_val;
+      },
+      [=](Vec256<scalar_t> self_val) {
+        return ((self_val < -lambd_val) | (self_val > lambd_val)) & self_val;
+      }
+    );
+  });
+}
+
+void hardshrink_backward_cpu_kernel(TensorIterator& iter, Scalar lambd) {
+  AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "hardshrink_backward_cpu", [&] {
+    auto lambd_val = lambd.to<scalar_t>();
+    cpu_kernel_vec(iter,
+      [=](scalar_t grad_val, scalar_t self_val) {
+        return (self_val >= -lambd_val && self_val <= lambd_val) ? scalar_t(0) : grad_val;
+      },
+      [=](Vec256<scalar_t> grad_val, Vec256<scalar_t> self_val) {
+        return ((self_val < -lambd_val) | (self_val > lambd_val)) & grad_val;
+      }
+    );
+  });
+}
+
 } // namespace
 
 REGISTER_DISPATCH(threshold_stub, &threshold_kernel);
 REGISTER_DISPATCH(GeluKernel, &GeluKernelImpl);
 REGISTER_DISPATCH(GeluBackwardKernel, &GeluBackwardKernelImpl);
+REGISTER_DISPATCH(hardshrink_cpu_stub, &hardshrink_cpu_kernel);
+REGISTER_DISPATCH(hardshrink_backward_cpu_stub, &hardshrink_backward_cpu_kernel);
 
 } // namespace native
 } // namespace at
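The vectorized lambda avoids a per-element branch: on Vec256, a comparison yields an all-ones or all-zeros bit pattern per lane, `|` combines the two out-of-band conditions, and `&` with the value keeps out-of-band lanes while zeroing lanes inside [-lambd, lambd]. A scalar model of a single lane (the ternary below stands in for the SIMD compare, which produces the mask without branching; hardshrink_branchless is illustrative, not library code):

#include <cstdint>
#include <cstring>
#include <iostream>

// One lane of the Vec256 trick: build an all-ones/all-zeros mask from the
// comparison, then AND it with the value's bits to pass or zero the float.
float hardshrink_branchless(float x, float lambd) {
  uint32_t keep = (x < -lambd || x > lambd) ? 0xFFFFFFFFu : 0u;  // lane mask
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));  // reinterpret the float as bits
  bits &= keep;                          // zero the lane if inside the band
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

int main() {
  std::cout << hardshrink_branchless(0.6f, 0.3f) << ' '   // 0.6 (kept)
            << hardshrink_branchless(0.2f, 0.3f) << '\n'; // 0 (masked out)
}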
test/test_torch.py

@@ -8242,6 +8242,28 @@ class _TestTorchMixin(object):
         # test non-contiguous case
         self.assertEqual(torch.tensor([1, 0, 0.5, 0.6]).view(2, 2), data.t().hardshrink(0.3))
 
+    def test_hardshrink_edge_cases(self):
+        def h(t, values, l_expected):
+            for l, expected in l_expected.items():
+                values_tensor = torch.tensor([float(v) for v in values]).type(t)
+                expected_tensor = torch.tensor([float(v) for v in expected]).type(t)
+                self.assertEqual(expected_tensor == values_tensor.hardshrink(l),
+                                 torch.ones_like(values_tensor))
+
+        def test_helper(t, min, max):
+            h(t, [0.0, min, -min, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf],
+              {0.0: [0.0, min, -min, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf],
+               min: [0.0, 0.0, 0.0, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf],
+               0.1: [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, -1.0, max, -max, inf, -inf],
+               1.0: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, max, -max, inf, -inf],
+               max: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, inf, -inf],
+               inf: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]})
+
+        test_helper(torch.DoubleTensor,
+                    torch.finfo(torch.double).tiny, torch.finfo(torch.double).max)
+        test_helper(torch.FloatTensor,
+                    torch.finfo(torch.float).tiny, torch.finfo(torch.float).max)
+
     def test_unbiased(self):
         tensor = torch.randn(100)
         self.assertEqual(tensor.var(0), tensor.var(0, unbiased=True))
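The new test pins down the boundary behavior: the band check is inclusive, so values exactly at ±lambd are zeroed, and ±inf survives any finite lambd but is zeroed when lambd is itself inf. A scalar C++ spot-check of the same predicate (hardshrink_ref is a hypothetical helper mirroring the kernel's scalar lambda):

#include <cassert>
#include <limits>

// Scalar model of the kernel's predicate, used to spot-check the same
// edge cases the Python test exercises.
double hardshrink_ref(double x, double lambd) {
  return (x >= -lambd && x <= lambd) ? 0.0 : x;
}

int main() {
  const double inf = std::numeric_limits<double>::infinity();
  const double tiny = std::numeric_limits<double>::min();
  assert(hardshrink_ref(tiny, tiny) == 0.0);  // boundary is inclusive
  assert(hardshrink_ref(0.1, tiny) == 0.1);   // just outside the band
  assert(hardshrink_ref(inf, 1.0) == inf);    // inf passes any finite lambd
  assert(hardshrink_ref(inf, inf) == 0.0);    // lambd == inf zeroes everything
  assert(hardshrink_ref(-inf, inf) == 0.0);
}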