Revert "Fix torch.full with dynamic tensor fill_value in torch.compile (#166554)"

This reverts commit 32066772b3.

Reverted https://github.com/pytorch/pytorch/pull/166554 on behalf of https://github.com/atalman due to Failure: test/nn/test_pooling.py::TestPoolingNNDeviceTypeCPU::test_max_pool_nan_inf_cpu_float32 [GH job link](https://github.com/pytorch/pytorch/actions/runs/18959368975/job/54144148546) [HUD commit link](32066772b3) ([comment](https://github.com/pytorch/pytorch/pull/166554#issuecomment-3472976911))
This commit is contained in:
PyTorch MergeBot 2025-10-31 12:55:31 +00:00
parent b0831930ed
commit 657f8c3e21
2 changed files with 5 additions and 75 deletions

View File

@ -5216,63 +5216,6 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
x = torch.randn(1)
self.assertEqual(opt_mod(x), x + 1)
def test_full_with_tensor_fill_value(self):
"""Test that torch.full works correctly with dynamic tensor fill_value"""
# Test with tensor fill_value (the bug case)
def func_tensor(x):
return torch.full((2,), x, dtype=torch.float64)
func_compiled = torch.compile(func_tensor)
# Test with different values
x1 = torch.tensor(5.0, dtype=torch.float64)
x2 = torch.tensor(10.0, dtype=torch.float64)
result1 = func_compiled(x1)
expected1 = torch.full((2,), x1, dtype=torch.float64)
self.assertEqual(result1, expected1)
# This is where the bug occurred - second call reused first value
result2 = func_compiled(x2)
expected2 = torch.full((2,), x2, dtype=torch.float64)
self.assertEqual(result2, expected2)
# Test with different dtypes
for dtype in [torch.float32, torch.float64, torch.int32, torch.int64]:
def func_typed(x):
return torch.full((3,), x, dtype=dtype)
func_typed_compiled = torch.compile(func_typed)
x_typed = torch.tensor(7, dtype=dtype)
result = func_typed_compiled(x_typed)
expected = torch.full((3,), x_typed, dtype=dtype)
self.assertEqual(result, expected)
# Test with non-tensor fill_value (scalar) to ensure we didn't break existing behavior
def func_scalar(size):
return torch.full((size,), 42.0, dtype=torch.float32)
func_scalar_compiled = torch.compile(func_scalar)
result_scalar = func_scalar_compiled(5)
expected_scalar = torch.full((5,), 42.0, dtype=torch.float32)
self.assertEqual(result_scalar, expected_scalar)
# Test with different scalar values
def func_scalar_param():
# Test multiple calls with different hardcoded scalar values
a = torch.full((2,), 3.14, dtype=torch.float32)
b = torch.full((2,), 2.71, dtype=torch.float32)
return a, b
func_scalar_param_compiled = torch.compile(func_scalar_param)
result_a, result_b = func_scalar_param_compiled()
self.assertEqual(result_a, torch.full((2,), 3.14, dtype=torch.float32))
self.assertEqual(result_b, torch.full((2,), 2.71, dtype=torch.float32))
instantiate_parametrized_tests(FunctionTests)
instantiate_parametrized_tests(DefaultsTests)

View File

@ -834,24 +834,11 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
@register(torch.full)
def handle_full(self, tx, size, fill_value, **kwargs):
if isinstance(fill_value, TensorVariable):
# Decompose: create empty tensor and fill it
# This avoids the scalar extraction at compile time
empty_result = TorchInGraphFunctionVariable(torch.empty).call_function(
tx, [size], kwargs
)
# Call fill_ method on the empty tensor
return empty_result.call_method(tx, "fill_", [fill_value], {})
else:
# For Python scalars and other non-tensor types, use default lowering
from .builder import wrap_fx_proxy
return wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
torch.ops.aten.full.default,
*proxy_args_kwargs([size, fill_value], kwargs),
),
result = TorchInGraphFunctionVariable(
torch.ops.aten._local_scalar_dense
).call_function(tx, [fill_value], {})
return TorchInGraphFunctionVariable(torch.full).call_function(
tx, [size, result], kwargs
)
@register(torch._foreach_lerp_)