[core] Dispatch to at::nansum_out rather than at::native::nansum_out (#156642)
Calling `at::native::nansum_out` causes the fake kernel to dispatch to a `make_reduction` call and then segfault later due to the `mutable_data_ptr` call in `TensorIteratorBase::build`. It also causes a fake-tensor propagation issue in Dynamo. The added tests demonstrate both issues. This patch fixes them by dispatching to `at::nansum_out` instead.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156642
Approved by: https://github.com/zou3519
This commit is contained in:
parent 863327ae49
commit 89aa708b39
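For context, a minimal repro sketch in Python, condensed from the two tests added below; the pre-fix outcomes (a segfault via `mutable_data_ptr` under fake tensors, and a fake-tensor propagation error in Dynamo) are as described in the message above, not re-verified here:

import torch

# 1. Fake-tensor propagation (see the FakeTensorTest change below): pre-fix,
#    the out= variant of nanmean reached a real mutable_data_ptr call and
#    segfaulted under fake tensors.
with torch._subclasses.fake_tensor.FakeTensorMode():
    x = torch.randn(10)
    out = torch.empty(())
    torch.nanmean(x, out=out)  # crashed before this commit

# 2. Dynamo tracing (see the dynamo test below): compiling the out= variant
#    hit the same fake-tensor issue while tracing.
def f(x, out):
    torch.nanmean(x, out=out)

torch.compile(f, backend="eager", fullgraph=True)(torch.randn(4), torch.tensor(0.0))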
--- a/aten/src/ATen/native/ReduceOps.cpp
+++ b/aten/src/ATen/native/ReduceOps.cpp
@@ -1451,7 +1451,7 @@ Tensor& nanmean_out(
       "nanmean(): expected input to have floating point or complex dtype but got ",
       self.scalar_type());
   const auto factor = at::native::isnan(self).logical_not_().sum(dim, keepdim);
-  at::native::nansum_out(self, dim, keepdim, opt_dtype, result).div_(factor);
+  at::nansum_out(result, self, dim, keepdim, opt_dtype).div_(factor);
   return result;
 }
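The one-line fix above swaps a direct call into the native kernel for a dispatched call: `at::native::nansum_out` invokes the implementation directly and bypasses the dispatcher, so dispatch modes such as FakeTensorMode never see it, while the generated `at::nansum_out` re-enters the dispatcher (with the out argument first, per the generated-API convention). A hedged sketch of that interception from Python; `LogMode` is an illustrative name, not part of the commit:

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class LogMode(TorchDispatchMode):
    # Log each aten op that reaches the dispatcher's Python key.
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(func)
        return func(*args, **(kwargs or {}))

x = torch.randn(4)
out = torch.tensor(0.0)
with LogMode():
    torch.nanmean(x, out=out)
    # nanmean's composite implementation runs above the mode, so its inner
    # dispatched calls are logged; post-fix that includes the nansum out=
    # overload, whereas pre-fix the native call skipped dispatch entirely.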
--- a/test/dynamo/test_repros.py
+++ b/test/dynamo/test_repros.py
@@ -7015,6 +7015,18 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
         with self.assertRaises(torch._dynamo.exc.Unsupported):
             fn(torch.ones(3))
 
+    def test_nanmean_out(self):
+        def f(x, out):
+            torch.nanmean(x, out=out)
+
+        x = torch.randn(4)
+        out_ref = torch.tensor(0.0)
+        out_res = torch.tensor(0.0)
+
+        f(x, out_ref)
+        torch.compile(f, backend="eager", fullgraph=True)(x, out_res)
+        self.assertEqual(out_ref, out_res)
+
 
 class ReproTestsDevice(torch._dynamo.test_case.TestCase):
     def test_sub_alpha_scalar_repro(self, device):
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -983,6 +983,15 @@ class FakeTensorTest(TestCase):
             y = fast_div(mode, x, 2)
             self.assertEqual(y.dtype, torch.float32)
 
+    def test_nanmean_out(self):
+        # Regression test to ensure we don't error out.
+        with torch._subclasses.fake_tensor.FakeTensorMode() as mode:
+            x = torch.randn(10)
+            out = torch.empty(())
+            torch.nanmean(x, out=out)
+
+        self.assertEqual(out.dtype, x.dtype)
+
 
 instantiate_parametrized_tests(FakeTensorTest)
|||
Loading…
Reference in New Issue
Block a user