Revert "Fix torch.full with dynamic tensor fill_value in torch.compile (#166554)"
This reverts commit 32066772b3. Reverted https://github.com/pytorch/pytorch/pull/166554 on behalf of https://github.com/atalman due to failure: test/nn/test_pooling.py::TestPoolingNNDeviceTypeCPU::test_max_pool_nan_inf_cpu_float32 [GH job link](https://github.com/pytorch/pytorch/actions/runs/18959368975/job/54144148546) [HUD commit link](32066772b3) ([comment](https://github.com/pytorch/pytorch/pull/166554#issuecomment-3472976911))
This commit is contained in:
parent b0831930ed
commit 657f8c3e21
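For context, the reverted PR #166554 targeted a bug where a compiled `torch.full` call with a tensor `fill_value` could reuse the value captured on the first invocation (see the removed test's comment "second call reused first value" below). A minimal sketch of that scenario, with illustrative names that are not taken from the PR itself:

```python
import torch

def make_full(x):
    # x is a 0-dim tensor, so the fill value is only known at runtime
    return torch.full((2,), x, dtype=torch.float64)

compiled = torch.compile(make_full)

x1 = torch.tensor(5.0, dtype=torch.float64)
x2 = torch.tensor(10.0, dtype=torch.float64)

print(compiled(x1))  # expected: tensor([5., 5.], dtype=torch.float64)
print(compiled(x2))  # expected: tensor([10., 10.]); the reported bug returned [5., 5.] again
```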
@@ -5216,63 +5216,6 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
         x = torch.randn(1)
         self.assertEqual(opt_mod(x), x + 1)
 
-    def test_full_with_tensor_fill_value(self):
-        """Test that torch.full works correctly with dynamic tensor fill_value"""
-
-        # Test with tensor fill_value (the bug case)
-        def func_tensor(x):
-            return torch.full((2,), x, dtype=torch.float64)
-
-        func_compiled = torch.compile(func_tensor)
-
-        # Test with different values
-        x1 = torch.tensor(5.0, dtype=torch.float64)
-        x2 = torch.tensor(10.0, dtype=torch.float64)
-
-        result1 = func_compiled(x1)
-        expected1 = torch.full((2,), x1, dtype=torch.float64)
-        self.assertEqual(result1, expected1)
-
-        # This is where the bug occurred - second call reused first value
-        result2 = func_compiled(x2)
-        expected2 = torch.full((2,), x2, dtype=torch.float64)
-        self.assertEqual(result2, expected2)
-
-        # Test with different dtypes
-        for dtype in [torch.float32, torch.float64, torch.int32, torch.int64]:
-
-            def func_typed(x):
-                return torch.full((3,), x, dtype=dtype)
-
-            func_typed_compiled = torch.compile(func_typed)
-            x_typed = torch.tensor(7, dtype=dtype)
-            result = func_typed_compiled(x_typed)
-            expected = torch.full((3,), x_typed, dtype=dtype)
-            self.assertEqual(result, expected)
-
-        # Test with non-tensor fill_value (scalar) to ensure we didn't break existing behavior
-        def func_scalar(size):
-            return torch.full((size,), 42.0, dtype=torch.float32)
-
-        func_scalar_compiled = torch.compile(func_scalar)
-
-        result_scalar = func_scalar_compiled(5)
-        expected_scalar = torch.full((5,), 42.0, dtype=torch.float32)
-        self.assertEqual(result_scalar, expected_scalar)
-
-        # Test with different scalar values
-        def func_scalar_param():
-            # Test multiple calls with different hardcoded scalar values
-            a = torch.full((2,), 3.14, dtype=torch.float32)
-            b = torch.full((2,), 2.71, dtype=torch.float32)
-            return a, b
-
-        func_scalar_param_compiled = torch.compile(func_scalar_param)
-        result_a, result_b = func_scalar_param_compiled()
-
-        self.assertEqual(result_a, torch.full((2,), 3.14, dtype=torch.float32))
-        self.assertEqual(result_b, torch.full((2,), 2.71, dtype=torch.float32))
-
 
 instantiate_parametrized_tests(FunctionTests)
 instantiate_parametrized_tests(DefaultsTests)
@@ -834,24 +834,11 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
         @register(torch.full)
         def handle_full(self, tx, size, fill_value, **kwargs):
             if isinstance(fill_value, TensorVariable):
-                # Decompose: create empty tensor and fill it
-                # This avoids the scalar extraction at compile time
-                empty_result = TorchInGraphFunctionVariable(torch.empty).call_function(
-                    tx, [size], kwargs
-                )
-                # Call fill_ method on the empty tensor
-                return empty_result.call_method(tx, "fill_", [fill_value], {})
-            else:
-                # For Python scalars and other non-tensor types, use default lowering
-                from .builder import wrap_fx_proxy
-
-                return wrap_fx_proxy(
-                    tx=tx,
-                    proxy=tx.output.create_proxy(
-                        "call_function",
-                        torch.ops.aten.full.default,
-                        *proxy_args_kwargs([size, fill_value], kwargs),
-                    ),
-                )
+                result = TorchInGraphFunctionVariable(
+                    torch.ops.aten._local_scalar_dense
+                ).call_function(tx, [fill_value], {})
+                return TorchInGraphFunctionVariable(torch.full).call_function(
+                    tx, [size, result], kwargs
+                )
 
         @register(torch._foreach_lerp_)
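In eager terms, the restored handler extracts a scalar from the tensor `fill_value` via `torch.ops.aten._local_scalar_dense` (roughly what `Tensor.item()` dispatches to) and passes it on to `torch.full`, whereas the reverted change decomposed the call into an empty tensor plus an in-place `fill_`, keeping the fill value a tensor. A rough eager-mode analogy of the two lowerings, illustrative only and not the actual Dynamo code path, which builds graph nodes instead of running eagerly:

```python
import torch

size = (2,)
fill_value = torch.tensor(5.0, dtype=torch.float64)

# Restored behavior: extract the scalar, then call torch.full with it.
scalar = torch.ops.aten._local_scalar_dense(fill_value)  # ~ fill_value.item()
restored = torch.full(size, scalar, dtype=torch.float64)

# Reverted behavior: decompose into empty() + fill_(), fill_value stays a tensor.
reverted = torch.empty(size, dtype=torch.float64).fill_(fill_value)

assert torch.equal(restored, reverted)
```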