Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Revert "Fix torch.full with dynamic tensor fill_value in torch.compile (#166554)"
This reverts commit 32066772b3. Reverted https://github.com/pytorch/pytorch/pull/166554 on behalf of https://github.com/atalman due to failure: test/nn/test_pooling.py::TestPoolingNNDeviceTypeCPU::test_max_pool_nan_inf_cpu_float32 [GH job link](https://github.com/pytorch/pytorch/actions/runs/18959368975/job/54144148546) [HUD commit link](32066772b3) ([comment](https://github.com/pytorch/pytorch/pull/166554#issuecomment-3472976911))
This commit is contained in:
parent: b0831930ed
commit: 657f8c3e21
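For context, the change being reverted (#166554) addressed torch.compile behavior when torch.full receives a 0-d tensor as fill_value; the test removed below asserts that a second call with a different fill tensor does not reuse the value captured on the first call. A minimal sketch of that scenario, adapted from the removed test (function and variable names are illustrative, and the printed results assume the fix is in place):

import torch

def make_full(x):
    # fill_value is a 0-d tensor rather than a Python scalar
    return torch.full((2,), x, dtype=torch.float64)

compiled = torch.compile(make_full)

x1 = torch.tensor(5.0, dtype=torch.float64)
x2 = torch.tensor(10.0, dtype=torch.float64)

# The bug the removed test guarded against: the second call reusing the
# fill value baked in during the first compilation.
print(compiled(x1))  # expected tensor([5., 5.], dtype=torch.float64)
print(compiled(x2))  # expected tensor([10., 10.], dtype=torch.float64)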
@@ -5216,63 +5216,6 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
         x = torch.randn(1)
         self.assertEqual(opt_mod(x), x + 1)
 
-    def test_full_with_tensor_fill_value(self):
-        """Test that torch.full works correctly with dynamic tensor fill_value"""
-
-        # Test with tensor fill_value (the bug case)
-        def func_tensor(x):
-            return torch.full((2,), x, dtype=torch.float64)
-
-        func_compiled = torch.compile(func_tensor)
-
-        # Test with different values
-        x1 = torch.tensor(5.0, dtype=torch.float64)
-        x2 = torch.tensor(10.0, dtype=torch.float64)
-
-        result1 = func_compiled(x1)
-        expected1 = torch.full((2,), x1, dtype=torch.float64)
-        self.assertEqual(result1, expected1)
-
-        # This is where the bug occurred - second call reused first value
-        result2 = func_compiled(x2)
-        expected2 = torch.full((2,), x2, dtype=torch.float64)
-        self.assertEqual(result2, expected2)
-
-        # Test with different dtypes
-        for dtype in [torch.float32, torch.float64, torch.int32, torch.int64]:
-
-            def func_typed(x):
-                return torch.full((3,), x, dtype=dtype)
-
-            func_typed_compiled = torch.compile(func_typed)
-            x_typed = torch.tensor(7, dtype=dtype)
-            result = func_typed_compiled(x_typed)
-            expected = torch.full((3,), x_typed, dtype=dtype)
-            self.assertEqual(result, expected)
-
-        # Test with non-tensor fill_value (scalar) to ensure we didn't break existing behavior
-        def func_scalar(size):
-            return torch.full((size,), 42.0, dtype=torch.float32)
-
-        func_scalar_compiled = torch.compile(func_scalar)
-
-        result_scalar = func_scalar_compiled(5)
-        expected_scalar = torch.full((5,), 42.0, dtype=torch.float32)
-        self.assertEqual(result_scalar, expected_scalar)
-
-        # Test with different scalar values
-        def func_scalar_param():
-            # Test multiple calls with different hardcoded scalar values
-            a = torch.full((2,), 3.14, dtype=torch.float32)
-            b = torch.full((2,), 2.71, dtype=torch.float32)
-            return a, b
-
-        func_scalar_param_compiled = torch.compile(func_scalar_param)
-        result_a, result_b = func_scalar_param_compiled()
-
-        self.assertEqual(result_a, torch.full((2,), 3.14, dtype=torch.float32))
-        self.assertEqual(result_b, torch.full((2,), 2.71, dtype=torch.float32))
-
 
 instantiate_parametrized_tests(FunctionTests)
 instantiate_parametrized_tests(DefaultsTests)
@@ -834,24 +834,11 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
         @register(torch.full)
         def handle_full(self, tx, size, fill_value, **kwargs):
             if isinstance(fill_value, TensorVariable):
-                # Decompose: create empty tensor and fill it
-                # This avoids the scalar extraction at compile time
-                empty_result = TorchInGraphFunctionVariable(torch.empty).call_function(
-                    tx, [size], kwargs
-                )
-                # Call fill_ method on the empty tensor
-                return empty_result.call_method(tx, "fill_", [fill_value], {})
-            else:
-                # For Python scalars and other non-tensor types, use default lowering
-                from .builder import wrap_fx_proxy
-
-                return wrap_fx_proxy(
-                    tx=tx,
-                    proxy=tx.output.create_proxy(
-                        "call_function",
-                        torch.ops.aten.full.default,
-                        *proxy_args_kwargs([size, fill_value], kwargs),
-                    ),
-                )
+                result = TorchInGraphFunctionVariable(
+                    torch.ops.aten._local_scalar_dense
+                ).call_function(tx, [fill_value], {})
+                return TorchInGraphFunctionVariable(torch.full).call_function(
+                    tx, [size, result], kwargs
+                )
 
         @register(torch._foreach_lerp_)
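For illustration only (this is eager-mode PyTorch, not the Dynamo handler above): the removed branch decomposes torch.full into empty followed by in-place fill_, which keeps the fill value as a tensor in the traced graph, while the restored branch extracts a Python scalar via aten._local_scalar_dense (the op behind Tensor.item()) and passes that scalar back to torch.full. A rough sketch of the two eager equivalents:

import torch

size = (2,)
fill_value = torch.tensor(5.0, dtype=torch.float64)

# Removed approach: build an empty tensor and fill it, keeping fill_value a tensor.
a = torch.empty(size, dtype=torch.float64).fill_(fill_value)

# Restored approach: extract a scalar (what _local_scalar_dense / .item() does)
# and pass it to torch.full.
b = torch.full(size, fill_value.item(), dtype=torch.float64)

assert torch.equal(a, b)  # identical results in eager mode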