[hop][BE] add util diff_meta with prettier error message. (#142162)

The error message changes from:
```python
-torch._dynamo.exc.Unsupported: Expected branches to return tensors with same metadata. [(tensor_pair, difference)...]:[('pair0:', TensorMetadata(shape=torch.Size([4, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=None, is_quantized=False, qparams={}), TensorMetadata(shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=None, is_quantized=False, qparams={}))]
```
to
```python
+torch._dynamo.exc.Unsupported: Expect branches to return tensors with same metadata but find pair[0] differ in 'shape', where lhs is TensorMetadata(shape=torch.Size([4, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=None, is_quantized=False, qparams={}) and rhs is TensorMetadata(shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=None, is_quantized=False, qparams={})
```
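
For context, the improved message can be reproduced with the test case added in this PR; the sketch below mirrors that test (the branch shapes `[4, 3]` vs. `[2, 3]` come from concatenating the input with itself vs. leaving it unchanged):

```python
# Minimal repro mirroring the test added in this PR: the two branches of
# torch.cond return tensors of different shapes, which Dynamo rejects with
# the improved "pair[0] differ in 'shape'" message.
import torch

def output_mismatch_test(x):
    def true_fn():
        return torch.concat([x, x])  # shape [4, 3]

    def false_fn():
        return x.sin()  # shape [2, 3]

    return torch.cond(x.sum() > 0, true_fn, false_fn)

x = torch.randn(2, 3)
torch.compile(output_mismatch_test)(x)  # raises torch._dynamo.exc.UncapturedHigherOrderOpError
```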

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142162
Approved by: https://github.com/zou3519
Yidi Wu 2024-12-06 10:38:50 -08:00 committed by PyTorch MergeBot
parent 9ced54a51a
commit 7111cd6ee0
3 changed files with 74 additions and 18 deletions


@@ -6957,6 +6957,23 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):
         with self.assertRaises(AssertionError):
             opt_test(True, False, inp)
 
+    def test_cond_with_mismatched_output(self):
+        def output_mismatch_test(x):
+            def true_fn():
+                return torch.concat([x, x])
+
+            def false_fn():
+                return x.sin()
+
+            return torch.cond(x.sum() > 0, true_fn, false_fn)
+
+        x = torch.randn(2, 3)
+        with self.assertRaises(torch._dynamo.exc.UncapturedHigherOrderOpError):
+            output_mismatch_test(x)
+
+        with self.assertRaises(torch._dynamo.exc.UncapturedHigherOrderOpError):
+            torch.compile(output_mismatch_test)(x)
+
     def test_non_aliasing_util(self):
         from torch._dynamo.variables.higher_order_ops import _assert_tensors_nonaliasing


@@ -82,6 +82,34 @@ def discard_graph_changes(tx):
         ctx.__exit__(None, None, None)
 
 
+def diff_meta(tensor_vars1, tensor_vars2) -> str:
+    from torch._higher_order_ops.utils import diff_tensor_meta
+
+    from . import TensorVariable
+
+    assert all(isinstance(var, TensorVariable) for var in tensor_vars1 + tensor_vars2)
+    all_diffs = []
+    for i, (var1, var2) in enumerate(zip(tensor_vars1, tensor_vars2)):
+        # We have vmap x cond tests and querying is_contiguous inside of vmap for
+        # memory_format other than torch.contiguous_format is not yet implemented.
+        # And it seems the remaining metas are good enough for now.
+        meta1 = _extract_tensor_metadata(
+            var1.proxy.node.meta["example_value"], include_contiguity=False
+        )
+        meta2 = _extract_tensor_metadata(
+            var2.proxy.node.meta["example_value"], include_contiguity=False
+        )
+        # We cannot get accurate require_grad. See Note [invariants for node meta 'val']
+        pair_diffs = diff_tensor_meta(meta1, meta2, check_grad=False)
+        if len(pair_diffs) > 0:
+            fmt_str = ", ".join(pair_diffs)
+            all_diffs.append(
+                f"pair[{i}] differ in {fmt_str}, where lhs is {meta1} and rhs is {meta2}"
+            )
+    return "\n".join(all_diffs)
+
+
 @contextlib.contextmanager
 def dynamo_enable_grad(tx: "InstructionTranslator", enable=True):
     from . import GradModeVariable
@@ -888,28 +916,11 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
         if not same_treespec.as_python_constant():
             unimplemented("Expected branches to return the same pytree structure.")
 
-        def diff_meta(tensor_vars1, tensor_vars2):
-            assert all(
-                isinstance(var, TensorVariable) for var in tensor_vars1 + tensor_vars2
-            )
-            all_diffs = []
-            for i, (var1, var2) in enumerate(zip(tensor_vars1, tensor_vars2)):
-                # We check the meta data associated with meta["example_value"]
-                meta1 = _extract_tensor_metadata(
-                    var1.proxy.node.meta["example_value"], include_contiguity=False
-                )
-                meta2 = _extract_tensor_metadata(
-                    var2.proxy.node.meta["example_value"], include_contiguity=False
-                )
-                if meta1 != meta2:
-                    all_diffs.append((f"pair{i}:", meta1, meta2))
-            return all_diffs
-
         if diffs := diff_meta(
             true_r.unpack_var_sequence(tx), false_r.unpack_var_sequence(tx)
         ):
             unimplemented(
-                f"Expected branches to return tensors with same metadata. [(tensor_pair, difference)...]:{diffs}"
+                f"Expect branches to return tensors with same metadata but find {diffs}"
             )
 
         (
@@ -1119,6 +1130,12 @@ class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable):
             set_subgraph_inputs="flatten_manual",
             should_flatten_outputs=True,
         )
+
+        if diffs := diff_meta(operands_seq, body_r.unpack_var_sequence(tx)):
+            unimplemented(
+                f"Expected carried_inputs and body outputs return tensors with same metadata but find:\n{diffs}"
+            )
+
         (
             cond_graph,
             body_graph,


@@ -9,6 +9,7 @@ import torch.fx.traceback as fx_traceback
 import torch.utils._pytree as pytree
 from torch._ops import OperatorBase
 from torch.fx.experimental.proxy_tensor import make_fx
+from torch.fx.passes.shape_prop import TensorMetadata
 from torch.multiprocessing.reductions import StorageWeakRef
@@ -481,6 +482,27 @@ def first_slice_copy(t: torch.Tensor, dim: int = 0) -> torch.Tensor:
     return torch.select_copy(t, dim, 0)
 
 
+# Reports the difference between meta of two tensors in a string
+def diff_tensor_meta(
+    meta1: TensorMetadata, meta2: TensorMetadata, check_grad=True
+) -> List[str]:
+    from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode
+
+    pair_diffs = []
+    for meta_name in TensorMetadata._fields:
+        if not check_grad and meta_name == "requires_grad":
+            continue
+        val1 = getattr(meta1, meta_name)
+        val2 = getattr(meta2, meta_name)
+        try:
+            if val1 != val2:
+                pair_diffs.append(f"'{meta_name}'")
+        except GuardOnDataDependentSymNode as _:
+            pair_diffs.append(f"'{meta_name}'")
+            continue
+    return pair_diffs
+
+
 # Note [lifted arg types in hop]
 # For dynamoed hops, we automatically lift the free symbols in tensors as arguments.
 # This has implications for the types of lifted args for different dispatch keys:
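
Outside the hop code paths, the new `diff_tensor_meta` helper can also be exercised on its own. A minimal sketch, constructing the `TensorMetadata` values by hand purely for illustration (the real callers extract them from node `meta["example_value"]` instead):

```python
# Illustrative sketch only: diff two hand-built TensorMetadata values with the
# new helper; the field values are copied from the error message above.
import torch
from torch.fx.passes.shape_prop import TensorMetadata
from torch._higher_order_ops.utils import diff_tensor_meta

lhs = TensorMetadata(
    shape=torch.Size([4, 3]), dtype=torch.float32, requires_grad=False,
    stride=(3, 1), memory_format=None, is_quantized=False, qparams={},
)
rhs = TensorMetadata(
    shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=False,
    stride=(3, 1), memory_format=None, is_quantized=False, qparams={},
)
print(diff_tensor_meta(lhs, rhs, check_grad=False))  # ["'shape'"]
```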