Revert "Fix torch.nn.functional.hardswish gradients corner case (#148049)"

This reverts commit 29b28e9d9f.

Reverted https://github.com/pytorch/pytorch/pull/148049 on behalf of https://github.com/soulitzer due to This may be causing an accuracy failure on inductor ([comment](https://github.com/pytorch/pytorch/pull/148049#issuecomment-2706839169))
This commit is contained in:
PyTorch MergeBot 2025-03-07 16:05:56 +00:00
parent 17302b4bc8
commit abcca2fcbb
4 changed files with 26 additions and 28 deletions

View File

@@ -832,9 +832,9 @@ void hardswish_backward_kernel(TensorIterator& iter) {
cpu_kernel_vec(
iter,
[&](scalar_t grad_val, scalar_t self_val) -> scalar_t {
- if (float(self_val) <= neg_three) {
+ if (float(self_val) < neg_three) {
return zero;
- } else if (float(self_val) < three) {
+ } else if (float(self_val) <= three) {
return float(grad_val) * ((float(self_val) / three) + one_half);
} else {
return grad_val;
@@ -847,19 +847,19 @@ void hardswish_backward_kernel(TensorIterator& iter) {
Vec::blendv(
grad_val0 * ((self_val0 / kThreeVec) + kOneHalfVec),
grad_val0,
- self_val0 >= kThreeVec
+ self_val0 > kThreeVec
),
kZeroVec,
- self_val0 <= kNegThreeVec
+ self_val0 < kNegThreeVec
);
self_val1 = Vec::blendv(
Vec::blendv(
grad_val1 * ((self_val1 / kThreeVec) + kOneHalfVec),
grad_val1,
- self_val1 >= kThreeVec
+ self_val1 > kThreeVec
),
kZeroVec,
- self_val1 <= kNegThreeVec
+ self_val1 < kNegThreeVec
);
return convert_from_float<scalar_t>(self_val0, self_val1);
});
@@ -878,9 +878,9 @@ void hardswish_backward_kernel(TensorIterator& iter) {
cpu_kernel_vec(
iter,
[&](scalar_t grad_val, scalar_t self_val) {
- if (self_val <= neg_three) {
+ if (self_val < neg_three) {
return zero;
- } else if (self_val < three) {
+ } else if (self_val <= three) {
return grad_val * ((self_val / three) + one_half);
} else {
return grad_val;
@@ -891,10 +891,10 @@ void hardswish_backward_kernel(TensorIterator& iter) {
Vec::blendv(
grad_val * ((self_val / kThreeVec) + kOneHalfVec),
grad_val,
- self_val >= kThreeVec
+ self_val > kThreeVec
),
kZeroVec,
- self_val <= kNegThreeVec
+ self_val < kNegThreeVec
);
}
);

View File

@@ -45,9 +45,9 @@ void hardswish_backward_kernel(TensorIterator& iter) {
[zero, three, neg_three, one_half]GPU_LAMBDA(scalar_t grad_val_, scalar_t self_val_) -> scalar_t {
opmath_t grad_val = static_cast<opmath_t>(grad_val_);
opmath_t self_val = static_cast<opmath_t>(self_val_);
- if (self_val <= neg_three) {
+ if (self_val < neg_three) {
return zero;
- } else if (self_val < three) {
+ } else if (self_val <= three) {
return grad_val * ((self_val / three) + one_half);
} else {
return grad_val;

View File

@@ -11151,22 +11151,20 @@ class TestNNDeviceType(NNTestCase):
inputs.requires_grad = True
self.assertTrue(gradcheck(F.hardswish, (inputs,)))
- def _test_hardswish_grad_corner(self, device, dtype, scalar, ref_fn):
-     m = nn.Hardswish()
-     shape = (1, 9, 9, 1)
-     inputs = torch.ones(shape, device=device, dtype=dtype)
-     inputs = inputs * scalar
-     inputs.requires_grad = True
-     fwd_result = m(inputs)
-     fwd_result.backward(torch.ones_like(fwd_result))
-     ref = ref_fn(shape, device=device, dtype=dtype)
-     self.assertEqual(inputs.grad, ref)
- @onlyNativeDeviceTypes
+ @onlyCPU
  @dtypes(torch.half, torch.bfloat16, torch.float)
  def test_hardswish_grad_corner(self, device, dtype):
-     self._test_hardswish_grad_corner(device, dtype, 3, torch.ones)
-     self._test_hardswish_grad_corner(device, dtype, -3, torch.zeros)
+     m = nn.Hardswish()
+     shape = (1, 9, 9, 1)
+     cpu_input = torch.ones(shape, device=device, dtype=dtype)
+     cpu_input = cpu_input * 3
+     cpu_input.requires_grad = True
+     fwd_result = m(cpu_input)
+     grad = torch.ones_like(fwd_result)
+     fwd_result.backward(grad)
+     ref = torch.ones(shape, device=device, dtype=dtype)
+     ref.fill_(1.5)
+     self.assertEqual(cpu_input.grad, ref)
def _test_batchnorm_eval(self, ndim, device, dtype, module_dtype=None):
module_dtype = module_dtype or dtype

View File

@@ -936,8 +936,8 @@ const std::vector<std::string> functions = {
def hardswish(self):
result = torch.hardswish(self)
def backward(grad_output):
- m = (self >= 3.).type_as(result)
- m = torch.where((self > -3.) & (self < 3.), self / 3. + .5, m)
+ m = (self > 3.).type_as(result)
+ m = torch.where((self >= -3.) & (self <= 3.), self / 3. + .5, m)
return grad_output * m
return result, backward