[core] Dispatch to at::nansum_out rather than at::native::nansum_out (#156642)

Calling `at::native::nansum_out` causes the fake kernel to dispatch to a
`make_reduction` call and then segfaults later due to the
`mutable_data_ptr` call in `TensorIteratorBase::build`. It also causes
fake tensor propagation issue in Dynamo. The added tests demonstrate the
aforementioned 2 issues.

This patch fixes it by dispatching to `at::nansum_out` instead.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156642
Approved by: https://github.com/zou3519
This commit is contained in:
Ryan Guo 2025-06-24 11:08:33 -07:00 committed by PyTorch MergeBot
parent 863327ae49
commit 89aa708b39
3 changed files with 22 additions and 1 deletions

View File

@ -1451,7 +1451,7 @@ Tensor& nanmean_out(
"nanmean(): expected input to have floating point or complex dtype but got ",
self.scalar_type());
const auto factor = at::native::isnan(self).logical_not_().sum(dim, keepdim);
at::native::nansum_out(self, dim, keepdim, opt_dtype, result).div_(factor);
at::nansum_out(result, self, dim, keepdim, opt_dtype).div_(factor);
return result;
}

View File

@ -7015,6 +7015,18 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
with self.assertRaises(torch._dynamo.exc.Unsupported):
fn(torch.ones(3))
def test_nanmean_out(self):
def f(x, out):
torch.nanmean(x, out=out)
x = torch.randn(4)
out_ref = torch.tensor(0.0)
out_res = torch.tensor(0.0)
f(x, out_ref)
torch.compile(f, backend="eager", fullgraph=True)(x, out_res)
self.assertEqual(out_ref, out_res)
class ReproTestsDevice(torch._dynamo.test_case.TestCase):
def test_sub_alpha_scalar_repro(self, device):

View File

@ -983,6 +983,15 @@ class FakeTensorTest(TestCase):
y = fast_div(mode, x, 2)
self.assertEqual(y.dtype, torch.float32)
def test_nanmean_out(self):
# Regression test to ensure we don't error out.
with torch._subclasses.fake_tensor.FakeTensorMode() as mode:
x = torch.randn(10)
out = torch.empty(())
torch.nanmean(x, out=out)
self.assertEqual(out.dtype, x.dtype)
instantiate_parametrized_tests(FakeTensorTest)