mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[hop][BE] add util diff_meta with prettier error message. (#142162)
The error message changes from:
```python
-torch._dynamo.exc.Unsupported: Expected branches to return tensors with same metadata. [(tensor_pair, difference)...]:[('pair0:', TensorMetadata(shape=torch.Size([4, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=None, is_quantized=False, qparams={}), TensorMetadata(shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=None, is_quantized=False, qparams={}))]
```
to
```python
+torch._dynamo.exc.Unsupported: Expect branches to return tensors with same metadata but find pair[0] differ in 'shape', where lhs is TensorMetadata(shape=torch.Size([4, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=None, is_quantized=False, qparams={}) and rhs is TensorMetadata(shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=None, is_quantized=False, qparams={})
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142162
Approved by: https://github.com/zou3519
This commit is contained in:
parent
9ced54a51a
commit
7111cd6ee0
|
|
@ -6957,6 +6957,23 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):
|
|||
with self.assertRaises(AssertionError):
|
||||
opt_test(True, False, inp)
|
||||
|
||||
def test_cond_with_mismatched_output(self):
|
||||
def output_mismatch_test(x):
|
||||
def true_fn():
|
||||
return torch.concat([x, x])
|
||||
|
||||
def false_fn():
|
||||
return x.sin()
|
||||
|
||||
return torch.cond(x.sum() > 0, true_fn, false_fn)
|
||||
|
||||
x = torch.randn(2, 3)
|
||||
with self.assertRaises(torch._dynamo.exc.UncapturedHigherOrderOpError):
|
||||
output_mismatch_test(x)
|
||||
|
||||
with self.assertRaises(torch._dynamo.exc.UncapturedHigherOrderOpError):
|
||||
torch.compile(output_mismatch_test)(x)
|
||||
|
||||
def test_non_aliasing_util(self):
|
||||
from torch._dynamo.variables.higher_order_ops import _assert_tensors_nonaliasing
|
||||
|
||||
|
|
|
|||
|
|
@ -82,6 +82,34 @@ def discard_graph_changes(tx):
|
|||
ctx.__exit__(None, None, None)
|
||||
|
||||
|
||||
def diff_meta(tensor_vars1, tensor_vars2) -> str:
|
||||
from torch._higher_order_ops.utils import diff_tensor_meta
|
||||
|
||||
from . import TensorVariable
|
||||
|
||||
assert all(isinstance(var, TensorVariable) for var in tensor_vars1 + tensor_vars2)
|
||||
all_diffs = []
|
||||
for i, (var1, var2) in enumerate(zip(tensor_vars1, tensor_vars2)):
|
||||
# We have vmap x cond tests and querying is_contiguous inside of vmap for
|
||||
# memory_format other than torch.contiguous_format is not yet implemented.
|
||||
# And it seems the remaining metas are good enough for now.
|
||||
meta1 = _extract_tensor_metadata(
|
||||
var1.proxy.node.meta["example_value"], include_contiguity=False
|
||||
)
|
||||
meta2 = _extract_tensor_metadata(
|
||||
var2.proxy.node.meta["example_value"], include_contiguity=False
|
||||
)
|
||||
# We cannot get accurate require_grad. See Note [invariants for node meta 'val']
|
||||
pair_diffs = diff_tensor_meta(meta1, meta2, check_grad=False)
|
||||
|
||||
if len(pair_diffs) > 0:
|
||||
fmt_str = ", ".join(pair_diffs)
|
||||
all_diffs.append(
|
||||
f"pair[{i}] differ in {fmt_str}, where lhs is {meta1} and rhs is {meta2}"
|
||||
)
|
||||
return "\n".join(all_diffs)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def dynamo_enable_grad(tx: "InstructionTranslator", enable=True):
|
||||
from . import GradModeVariable
|
||||
|
|
@ -888,28 +916,11 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
|||
if not same_treespec.as_python_constant():
|
||||
unimplemented("Expected branches to return the same pytree structure.")
|
||||
|
||||
def diff_meta(tensor_vars1, tensor_vars2):
|
||||
assert all(
|
||||
isinstance(var, TensorVariable) for var in tensor_vars1 + tensor_vars2
|
||||
)
|
||||
all_diffs = []
|
||||
for i, (var1, var2) in enumerate(zip(tensor_vars1, tensor_vars2)):
|
||||
# We check the meta data associated with meta["example_value"]
|
||||
meta1 = _extract_tensor_metadata(
|
||||
var1.proxy.node.meta["example_value"], include_contiguity=False
|
||||
)
|
||||
meta2 = _extract_tensor_metadata(
|
||||
var2.proxy.node.meta["example_value"], include_contiguity=False
|
||||
)
|
||||
if meta1 != meta2:
|
||||
all_diffs.append((f"pair{i}:", meta1, meta2))
|
||||
return all_diffs
|
||||
|
||||
if diffs := diff_meta(
|
||||
true_r.unpack_var_sequence(tx), false_r.unpack_var_sequence(tx)
|
||||
):
|
||||
unimplemented(
|
||||
f"Expected branches to return tensors with same metadata. [(tensor_pair, difference)...]:{diffs}"
|
||||
f"Expect branches to return tensors with same metadata but find {diffs}"
|
||||
)
|
||||
|
||||
(
|
||||
|
|
@ -1119,6 +1130,12 @@ class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
|||
set_subgraph_inputs="flatten_manual",
|
||||
should_flatten_outputs=True,
|
||||
)
|
||||
|
||||
if diffs := diff_meta(operands_seq, body_r.unpack_var_sequence(tx)):
|
||||
unimplemented(
|
||||
f"Expected carried_inputs and body outputs return tensors with same metadata but find:\n{diffs}"
|
||||
)
|
||||
|
||||
(
|
||||
cond_graph,
|
||||
body_graph,
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import torch.fx.traceback as fx_traceback
|
|||
import torch.utils._pytree as pytree
|
||||
from torch._ops import OperatorBase
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.fx.passes.shape_prop import TensorMetadata
|
||||
from torch.multiprocessing.reductions import StorageWeakRef
|
||||
|
||||
|
||||
|
|
@ -481,6 +482,27 @@ def first_slice_copy(t: torch.Tensor, dim: int = 0) -> torch.Tensor:
|
|||
return torch.select_copy(t, dim, 0)
|
||||
|
||||
|
||||
# Reports the difference between meta of two tensors in a string
|
||||
def diff_tensor_meta(
|
||||
meta1: TensorMetadata, meta2: TensorMetadata, check_grad=True
|
||||
) -> List[str]:
|
||||
from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode
|
||||
|
||||
pair_diffs = []
|
||||
for meta_name in TensorMetadata._fields:
|
||||
if not check_grad and meta_name == "requires_grad":
|
||||
continue
|
||||
val1 = getattr(meta1, meta_name)
|
||||
val2 = getattr(meta2, meta_name)
|
||||
try:
|
||||
if val1 != val2:
|
||||
pair_diffs.append(f"'{meta_name}'")
|
||||
except GuardOnDataDependentSymNode as _:
|
||||
pair_diffs.append(f"'{meta_name}'")
|
||||
continue
|
||||
return pair_diffs
|
||||
|
||||
|
||||
# Note [lifted arg types in hop]
|
||||
# For dynamoed hops, we automatically lift the free symbols in tensors as arguments.
|
||||
# This has implications for the types of lifted args for different dispatch keys:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user