mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add resolve in add decomp to enable view (#153945)
Fixes #148950. During the construction of graph and running the node of add under [interpreter](/github.com/pytorch/pytorch/blob/d68d4d31f4824f1d1e0d1d6899e9879ad19b0754/torch/fx/interpreter.py#L301 ), the functional argument of conj complex tensor gets cloned. This result in always having *.is_conj()* evaluted to false in decomposition function. Propose a fix of calling resolve_conj() in the decomposition of complex tensor add. Test as below `python test/dynamo/test_repros.py ReproTests.test_add_complex_conj` Pull Request resolved: https://github.com/pytorch/pytorch/pull/153945 Approved by: https://github.com/jansel
This commit is contained in:
parent
fec571cfd4
commit
dabb55baff
|
|
@ -5742,6 +5742,17 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
|
|||
torch.view_as_real(out_test).sum().backward()
|
||||
self.assertEqual(x_ref.grad, x_test.grad)
|
||||
|
||||
def test_add_complex_conj(self):
|
||||
def f(x):
|
||||
return x + x.conj()
|
||||
|
||||
x = torch.randn(4, dtype=torch.complex64, requires_grad=True)
|
||||
out = torch.compile(f)(x)
|
||||
expected_complex = (2 * x.real).to(dtype=out.dtype)
|
||||
|
||||
self.assertTrue(out.dtype == torch.complex64)
|
||||
self.assertEqual(out, expected_complex)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/132200
|
||||
def test_partitioner_cse_respects_mutation_boundaries(self):
|
||||
set_available = hasattr(torch.ops, "fsdp") and hasattr(torch.ops.fsdp, "set_")
|
||||
|
|
|
|||
|
|
@ -496,6 +496,10 @@ def add(
|
|||
reshaped_tensor = tensor.view(new_shape)
|
||||
return reshaped_tensor
|
||||
|
||||
# Manually resolve complex tensors, as .is_conj() is unreliable after cloning during compilation.
|
||||
x = x + 0
|
||||
z = z + 0
|
||||
|
||||
x_reshaped = reshape_tensor_complex(x.view(x.real.dtype))
|
||||
z_reshaped = reshape_tensor_complex(z.view(y.real.dtype))
|
||||
result = torch.flatten(x_reshaped + z_reshaped, start_dim=-2).view(complex_type)
|
||||
|
|
@ -504,7 +508,8 @@ def add(
|
|||
|
||||
@register_decomposition([aten.conj_physical])
|
||||
def conj_physical(self: torch.Tensor) -> torch.Tensor:
|
||||
assert not self.is_complex(), "TODO: implement this"
|
||||
if self.is_complex():
|
||||
return NotImplemented
|
||||
return self
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user