Add resolve in add decomp to enable view (#153945)
Fixes #148950. While constructing the graph and running the `add` node under the [interpreter](https://github.com/pytorch/pytorch/blob/d68d4d31f4824f1d1e0d1d6899e9879ad19b0754/torch/fx/interpreter.py#L301), the functional argument holding the conjugated complex tensor gets cloned. As a result, `.is_conj()` always evaluates to false inside the decomposition function. This change fixes the issue by calling `resolve_conj()` in the decomposition of complex tensor add.

Test: `python test/dynamo/test_repros.py ReproTests.test_add_complex_conj`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153945
Approved by: https://github.com/jansel
This commit is contained in:
parent
fec571cfd4
commit
dabb55baff
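For context, here is a minimal sketch of the eager-mode semantics the fix relies on: `Tensor.conj()` only sets the conjugate bit lazily, `is_conj()` reports that bit, and `resolve_conj()` materializes the conjugated values (it is a no-op when the bit is not set, so calling it unconditionally in the decomposition is safe). This snippet is illustrative only and is not part of the commit.

```python
import torch

x = torch.randn(3, dtype=torch.complex64)

c = x.conj()                      # lazy view: only the conjugate bit is set
assert c.is_conj()

r = c.resolve_conj()              # materializes the conjugated values
assert not r.is_conj()
assert torch.equal(r, torch.complex(x.real, -x.imag))

# For a tensor without the conjugate bit, resolve_conj() just returns the input,
# so calling it unconditionally costs nothing.
assert x.resolve_conj() is x
```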
@@ -5742,6 +5742,17 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
         torch.view_as_real(out_test).sum().backward()
         self.assertEqual(x_ref.grad, x_test.grad)

+    def test_add_complex_conj(self):
+        def f(x):
+            return x + x.conj()
+
+        x = torch.randn(4, dtype=torch.complex64, requires_grad=True)
+        out = torch.compile(f)(x)
+        expected_complex = (2 * x.real).to(dtype=out.dtype)
+
+        self.assertTrue(out.dtype == torch.complex64)
+        self.assertEqual(out, expected_complex)
+
     # https://github.com/pytorch/pytorch/issues/132200
     def test_partitioner_cse_respects_mutation_boundaries(self):
         set_available = hasattr(torch.ops, "fsdp") and hasattr(torch.ops.fsdp, "set_")
@@ -496,6 +496,10 @@ def add(
         reshaped_tensor = tensor.view(new_shape)
         return reshaped_tensor

+    # Manually resolve complex tensors, as .is_conj() is unreliable after cloning during compilation.
+    x = x.resolve_conj()
+    z = z.resolve_conj()
+
     x_reshaped = reshape_tensor_complex(x.view(x.real.dtype))
     z_reshaped = reshape_tensor_complex(z.view(y.real.dtype))
     result = torch.flatten(x_reshaped + z_reshaped, start_dim=-2).view(complex_type)
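The decomposition above implements complex addition by viewing each complex tensor as pairs of real values. The sketch below conveys that idea using `view_as_real`/`view_as_complex` rather than the decomposition's own `reshape_tensor_complex` helper, and shows why an unresolved conjugate view must be materialized first; the function name `add_complex_via_real_pairs` is made up for illustration and is not part of this commit.

```python
import torch

def add_complex_via_real_pairs(x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    # Illustrative sketch (not the actual inductor code): a complex tensor is
    # reinterpreted as (real, imag) float pairs, the pairs are added
    # elementwise, and the result is viewed back as complex.
    #
    # The reinterpretation only sees the stored values, so any pending lazy
    # conjugation must be materialized first; otherwise the imaginary parts
    # would be added with the wrong sign.
    x = x.resolve_conj()
    z = z.resolve_conj()
    return torch.view_as_complex(torch.view_as_real(x) + torch.view_as_real(z))

a = torch.randn(4, dtype=torch.complex64)
out = add_complex_via_real_pairs(a, a.conj())  # a + a.conj() == 2 * a.real
assert torch.allclose(out.real, 2 * a.real)
assert torch.allclose(out.imag, torch.zeros_like(out.imag))
```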
@@ -504,7 +508,8 @@ def add(


 @register_decomposition([aten.conj_physical])
 def conj_physical(self: torch.Tensor) -> torch.Tensor:
-    assert not self.is_complex(), "TODO: implement this"
+    if self.is_complex():
+        return NotImplemented
     return self

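For reference, `conj_physical` eagerly materializes the conjugate, unlike the lazy `Tensor.conj()`; the decomposition above now declines complex inputs by returning `NotImplemented` instead of asserting, while real inputs still take the `return self` path. A small illustrative snippet of the eager semantics, not part of this commit:

```python
import torch

x = torch.randn(3, dtype=torch.complex64)

lazy = x.conj()                   # lazy: only sets the conjugate bit
phys = torch.conj_physical(x)     # eager: materializes conjugated values

assert lazy.is_conj()
assert not phys.is_conj()
assert torch.equal(phys, lazy.resolve_conj())

# For real tensors conjugation is the identity, which is what the
# decomposition's `return self` branch covers.
y = torch.randn(3)
assert torch.equal(torch.conj_physical(y), y)
```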