mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "Fix torch.nn.functional.hardswish gradients corner case (#148049)"
This reverts commit 29b28e9d9f.
Reverted https://github.com/pytorch/pytorch/pull/148049 on behalf of https://github.com/soulitzer due to This may be causing an accuracy failure on inductor ([comment](https://github.com/pytorch/pytorch/pull/148049#issuecomment-2706839169))
This commit is contained in:
parent
17302b4bc8
commit
abcca2fcbb
|
|
@@ -832,9 +832,9 @@ void hardswish_backward_kernel(TensorIterator& iter) {
     cpu_kernel_vec(
         iter,
         [&](scalar_t grad_val, scalar_t self_val) -> scalar_t {
-          if (float(self_val) <= neg_three) {
+          if (float(self_val) < neg_three) {
             return zero;
-          } else if (float(self_val) < three) {
+          } else if (float(self_val) <= three) {
             return float(grad_val) * ((float(self_val) / three) + one_half);
           } else {
             return grad_val;
|
|
@@ -847,19 +847,19 @@ void hardswish_backward_kernel(TensorIterator& iter) {
           Vec::blendv(
               grad_val0 * ((self_val0 / kThreeVec) + kOneHalfVec),
               grad_val0,
-              self_val0 >= kThreeVec
+              self_val0 > kThreeVec
           ),
           kZeroVec,
-          self_val0 <= kNegThreeVec
+          self_val0 < kNegThreeVec
       );
       self_val1 = Vec::blendv(
           Vec::blendv(
               grad_val1 * ((self_val1 / kThreeVec) + kOneHalfVec),
               grad_val1,
-              self_val1 >= kThreeVec
+              self_val1 > kThreeVec
           ),
           kZeroVec,
-          self_val1 <= kNegThreeVec
+          self_val1 < kNegThreeVec
       );
       return convert_from_float<scalar_t>(self_val0, self_val1);
     });
|
@@ -878,9 +878,9 @@ void hardswish_backward_kernel(TensorIterator& iter) {
     cpu_kernel_vec(
         iter,
         [&](scalar_t grad_val, scalar_t self_val) {
-          if (self_val <= neg_three) {
+          if (self_val < neg_three) {
             return zero;
-          } else if (self_val < three) {
+          } else if (self_val <= three) {
             return grad_val * ((self_val / three) + one_half);
           } else {
             return grad_val;
|
@@ -891,10 +891,10 @@ void hardswish_backward_kernel(TensorIterator& iter) {
           Vec::blendv(
               grad_val * ((self_val / kThreeVec) + kOneHalfVec),
               grad_val,
-              self_val >= kThreeVec
+              self_val > kThreeVec
           ),
           kZeroVec,
-          self_val <= kNegThreeVec
+          self_val < kNegThreeVec
       );
     }
   );
|
|
|||
|
|
@@ -45,9 +45,9 @@ void hardswish_backward_kernel(TensorIterator& iter) {
       [zero, three, neg_three, one_half]GPU_LAMBDA(scalar_t grad_val_, scalar_t self_val_) -> scalar_t {
         opmath_t grad_val = static_cast<opmath_t>(grad_val_);
         opmath_t self_val = static_cast<opmath_t>(self_val_);
-        if (self_val <= neg_three) {
+        if (self_val < neg_three) {
           return zero;
-        } else if (self_val < three) {
+        } else if (self_val <= three) {
           return grad_val * ((self_val / three) + one_half);
         } else {
           return grad_val;
|
|
|||
|
|
@@ -11151,22 +11151,20 @@ class TestNNDeviceType(NNTestCase):
         inputs.requires_grad = True
         self.assertTrue(gradcheck(F.hardswish, (inputs,)))

-    def _test_hardswish_grad_corner(self, device, dtype, scalar, ref_fn):
-        m = nn.Hardswish()
-        shape = (1, 9, 9, 1)
-        inputs = torch.ones(shape, device=device, dtype=dtype)
-        inputs = inputs * scalar
-        inputs.requires_grad = True
-        fwd_result = m(inputs)
-        fwd_result.backward(torch.ones_like(fwd_result))
-        ref = ref_fn(shape, device=device, dtype=dtype)
-        self.assertEqual(inputs.grad, ref)
-
-    @onlyNativeDeviceTypes
+    @onlyCPU
     @dtypes(torch.half, torch.bfloat16, torch.float)
     def test_hardswish_grad_corner(self, device, dtype):
-        self._test_hardswish_grad_corner(device, dtype, 3, torch.ones)
-        self._test_hardswish_grad_corner(device, dtype, -3, torch.zeros)
+        m = nn.Hardswish()
+        shape = (1, 9, 9, 1)
+        cpu_input = torch.ones(shape, device=device, dtype=dtype)
+        cpu_input = cpu_input * 3
+        cpu_input.requires_grad = True
+        fwd_result = m(cpu_input)
+        grad = torch.ones_like(fwd_result)
+        fwd_result.backward(grad)
+        ref = torch.ones(shape, device=device, dtype=dtype)
+        ref.fill_(1.5)
+        self.assertEqual(cpu_input.grad, ref)

     def _test_batchnorm_eval(self, ndim, device, dtype, module_dtype=None):
         module_dtype = module_dtype or dtype
|
|
|||
|
|
@@ -936,8 +936,8 @@ const std::vector<std::string> functions = {
         def hardswish(self):
             result = torch.hardswish(self)
             def backward(grad_output):
-                m = (self >= 3.).type_as(result)
-                m = torch.where((self > -3.) & (self < 3.), self / 3. + .5, m)
+                m = (self > 3.).type_as(result)
+                m = torch.where((self >= -3.) & (self <= 3.), self / 3. + .5, m)
                 return grad_output * m
             return result, backward

|
|
|||
Loading…
Reference in New Issue
Block a user