Add resolve in add decomp to enable view (#153945)

Fixes #148950.

During the construction of graph and running the node of add under [interpreter](/github.com/pytorch/pytorch/blob/d68d4d31f4824f1d1e0d1d6899e9879ad19b0754/torch/fx/interpreter.py#L301
), the functional argument of conj complex tensor gets cloned. This result in always having *.is_conj()* evaluted to false in decomposition function.

Propose a fix of calling resolve_conj() in the decomposition of complex tensor add.

Test as below
`python test/dynamo/test_repros.py ReproTests.test_add_complex_conj`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153945
Approved by: https://github.com/jansel
This commit is contained in:
cz2h 2025-06-14 00:41:50 +00:00 committed by PyTorch MergeBot
parent fec571cfd4
commit dabb55baff
2 changed files with 17 additions and 1 deletions

View File

@ -5742,6 +5742,17 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
torch.view_as_real(out_test).sum().backward() torch.view_as_real(out_test).sum().backward()
self.assertEqual(x_ref.grad, x_test.grad) self.assertEqual(x_ref.grad, x_test.grad)
def test_add_complex_conj(self):
def f(x):
return x + x.conj()
x = torch.randn(4, dtype=torch.complex64, requires_grad=True)
out = torch.compile(f)(x)
expected_complex = (2 * x.real).to(dtype=out.dtype)
self.assertTrue(out.dtype == torch.complex64)
self.assertEqual(out, expected_complex)
# https://github.com/pytorch/pytorch/issues/132200 # https://github.com/pytorch/pytorch/issues/132200
def test_partitioner_cse_respects_mutation_boundaries(self): def test_partitioner_cse_respects_mutation_boundaries(self):
set_available = hasattr(torch.ops, "fsdp") and hasattr(torch.ops.fsdp, "set_") set_available = hasattr(torch.ops, "fsdp") and hasattr(torch.ops.fsdp, "set_")

View File

@ -496,6 +496,10 @@ def add(
reshaped_tensor = tensor.view(new_shape) reshaped_tensor = tensor.view(new_shape)
return reshaped_tensor return reshaped_tensor
# Manually resolve complex tensors, as .is_conj() is unreliable after cloning during compilation.
x = x + 0
z = z + 0
x_reshaped = reshape_tensor_complex(x.view(x.real.dtype)) x_reshaped = reshape_tensor_complex(x.view(x.real.dtype))
z_reshaped = reshape_tensor_complex(z.view(y.real.dtype)) z_reshaped = reshape_tensor_complex(z.view(y.real.dtype))
result = torch.flatten(x_reshaped + z_reshaped, start_dim=-2).view(complex_type) result = torch.flatten(x_reshaped + z_reshaped, start_dim=-2).view(complex_type)
@ -504,7 +508,8 @@ def add(
@register_decomposition([aten.conj_physical]) @register_decomposition([aten.conj_physical])
def conj_physical(self: torch.Tensor) -> torch.Tensor: def conj_physical(self: torch.Tensor) -> torch.Tensor:
assert not self.is_complex(), "TODO: implement this" if self.is_complex():
return NotImplemented
return self return self