From 657f8c3e21bd8901dd8ce79ca9a54a45b27f604f Mon Sep 17 00:00:00 2001
From: PyTorch MergeBot
Date: Fri, 31 Oct 2025 12:55:31 +0000
Subject: [PATCH] Revert "Fix torch.full with dynamic tensor fill_value in torch.compile (#166554)"

This reverts commit 32066772b3dee643b1657b8957f32b5ac8b1390a.

Reverted https://github.com/pytorch/pytorch/pull/166554 on behalf of https://github.com/atalman due to Failure: test/nn/test_pooling.py::TestPoolingNNDeviceTypeCPU::test_max_pool_nan_inf_cpu_float32 [GH job link](https://github.com/pytorch/pytorch/actions/runs/18959368975/job/54144148546) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/32066772b3dee643b1657b8957f32b5ac8b1390a) ([comment](https://github.com/pytorch/pytorch/pull/166554#issuecomment-3472976911))
---
 test/dynamo/test_functions.py    | 57 --------------
 torch/_dynamo/variables/torch.py | 23 +++----------
 2 files changed, 5 insertions(+), 75 deletions(-)

diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py
index f3765369e3f..c06331cea7d 100644
--- a/test/dynamo/test_functions.py
+++ b/test/dynamo/test_functions.py
@@ -5216,63 +5216,6 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
         x = torch.randn(1)
         self.assertEqual(opt_mod(x), x + 1)
 
-    def test_full_with_tensor_fill_value(self):
-        """Test that torch.full works correctly with dynamic tensor fill_value"""
-
-        # Test with tensor fill_value (the bug case)
-        def func_tensor(x):
-            return torch.full((2,), x, dtype=torch.float64)
-
-        func_compiled = torch.compile(func_tensor)
-
-        # Test with different values
-        x1 = torch.tensor(5.0, dtype=torch.float64)
-        x2 = torch.tensor(10.0, dtype=torch.float64)
-
-        result1 = func_compiled(x1)
-        expected1 = torch.full((2,), x1, dtype=torch.float64)
-        self.assertEqual(result1, expected1)
-
-        # This is where the bug occurred - second call reused first value
-        result2 = func_compiled(x2)
-        expected2 = torch.full((2,), x2, dtype=torch.float64)
-        self.assertEqual(result2, expected2)
-
-        # Test with different dtypes
-        for dtype in [torch.float32, torch.float64, torch.int32, torch.int64]:
-
-            def func_typed(x):
-                return torch.full((3,), x, dtype=dtype)
-
-            func_typed_compiled = torch.compile(func_typed)
-            x_typed = torch.tensor(7, dtype=dtype)
-            result = func_typed_compiled(x_typed)
-            expected = torch.full((3,), x_typed, dtype=dtype)
-            self.assertEqual(result, expected)
-
-        # Test with non-tensor fill_value (scalar) to ensure we didn't break existing behavior
-        def func_scalar(size):
-            return torch.full((size,), 42.0, dtype=torch.float32)
-
-        func_scalar_compiled = torch.compile(func_scalar)
-
-        result_scalar = func_scalar_compiled(5)
-        expected_scalar = torch.full((5,), 42.0, dtype=torch.float32)
-        self.assertEqual(result_scalar, expected_scalar)
-
-        # Test with different scalar values
-        def func_scalar_param():
-            # Test multiple calls with different hardcoded scalar values
-            a = torch.full((2,), 3.14, dtype=torch.float32)
-            b = torch.full((2,), 2.71, dtype=torch.float32)
-            return a, b
-
-        func_scalar_param_compiled = torch.compile(func_scalar_param)
-        result_a, result_b = func_scalar_param_compiled()
-
-        self.assertEqual(result_a, torch.full((2,), 3.14, dtype=torch.float32))
-        self.assertEqual(result_b, torch.full((2,), 2.71, dtype=torch.float32))
-
 
 instantiate_parametrized_tests(FunctionTests)
 instantiate_parametrized_tests(DefaultsTests)

diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py
index 0d73374623b..e48a4881015 100644
--- a/torch/_dynamo/variables/torch.py
+++ b/torch/_dynamo/variables/torch.py
@@ -834,24 +834,11 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
         @register(torch.full)
         def handle_full(self, tx, size, fill_value, **kwargs):
             if isinstance(fill_value, TensorVariable):
-                # Decompose: create empty tensor and fill it
-                # This avoids the scalar extraction at compile time
-                empty_result = TorchInGraphFunctionVariable(torch.empty).call_function(
-                    tx, [size], kwargs
-                )
-                # Call fill_ method on the empty tensor
-                return empty_result.call_method(tx, "fill_", [fill_value], {})
-            else:
-                # For Python scalars and other non-tensor types, use default lowering
-                from .builder import wrap_fx_proxy
-
-                return wrap_fx_proxy(
-                    tx=tx,
-                    proxy=tx.output.create_proxy(
-                        "call_function",
-                        torch.ops.aten.full.default,
-                        *proxy_args_kwargs([size, fill_value], kwargs),
-                    ),
+                result = TorchInGraphFunctionVariable(
+                    torch.ops.aten._local_scalar_dense
+                ).call_function(tx, [fill_value], {})
+                return TorchInGraphFunctionVariable(torch.full).call_function(
+                    tx, [size, result], kwargs
                 )
 
         @register(torch._foreach_lerp_)
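
For context, a minimal standalone sketch (not part of the patch; it assumes only a PyTorch build with `torch.compile`, and the helper name `make_full` is illustrative) of the user-facing case the removed test exercised: `torch.full` called with a 0-d tensor fill_value, compiled once and invoked with two different values.

```python
import torch


def make_full(x):
    # fill_value is a 0-d tensor rather than a Python scalar
    return torch.full((2,), x, dtype=torch.float64)


compiled = torch.compile(make_full)

x1 = torch.tensor(5.0, dtype=torch.float64)
x2 = torch.tensor(10.0, dtype=torch.float64)

# The reverted handler decomposed this call into empty(...).fill_(x) so the
# fill value stayed dynamic across calls; the restored handler extracts a
# scalar via torch.ops.aten._local_scalar_dense and re-dispatches to torch.full.
print(compiled(x1))  # eager semantics: tensor([5., 5.], dtype=torch.float64)
print(compiled(x2))  # eager semantics: tensor([10., 10.], dtype=torch.float64)
```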