[hop][BE] add util diff_meta with prettier error message. (#142162)
The error message changes from:
```python
-torch._dynamo.exc.Unsupported: Expected branches to return tensors with same metadata. [(tensor_pair, difference)...]:[('pair0:', TensorMetadata(shape=torch.Size([4, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=None, is_quantized=False, qparams={}), TensorMetadata(shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=None, is_quantized=False, qparams={}))]
```
to
```python
+torch._dynamo.exc.Unsupported: Expect branches to return tensors with same metadata but find pair[0] differ in 'shape', where lhs is TensorMetadata(shape=torch.Size([4, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=None, is_quantized=False, qparams={}) and rhs is TensorMetadata(shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=None, is_quantized=False, qparams={})
```
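For reference, a minimal repro sketch mirroring the `test_cond_with_mismatched_output` test added in this diff: the two branches of `torch.cond` return tensors with different shapes, so the new message names the mismatching field instead of dumping a raw tuple list.

```python
# Minimal repro sketch, mirroring the test added in this PR: the branches return
# tensors of different shapes, so Dynamo reports which metadata field differs.
import torch
import torch._dynamo


def output_mismatch_test(x):
    def true_fn():
        return torch.concat([x, x])  # shape (4, 3)

    def false_fn():
        return x.sin()  # shape (2, 3)

    return torch.cond(x.sum() > 0, true_fn, false_fn)


x = torch.randn(2, 3)
try:
    torch.compile(output_mismatch_test)(x)
except torch._dynamo.exc.UncapturedHigherOrderOpError as e:
    print(e)  # error text should mention pair[0] differing in 'shape'
```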
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142162
Approved by: https://github.com/zou3519
This commit is contained in:
parent 9ced54a51a
commit 7111cd6ee0
```diff
@@ -6957,6 +6957,23 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):
         with self.assertRaises(AssertionError):
             opt_test(True, False, inp)
 
+    def test_cond_with_mismatched_output(self):
+        def output_mismatch_test(x):
+            def true_fn():
+                return torch.concat([x, x])
+
+            def false_fn():
+                return x.sin()
+
+            return torch.cond(x.sum() > 0, true_fn, false_fn)
+
+        x = torch.randn(2, 3)
+        with self.assertRaises(torch._dynamo.exc.UncapturedHigherOrderOpError):
+            output_mismatch_test(x)
+
+        with self.assertRaises(torch._dynamo.exc.UncapturedHigherOrderOpError):
+            torch.compile(output_mismatch_test)(x)
+
     def test_non_aliasing_util(self):
         from torch._dynamo.variables.higher_order_ops import _assert_tensors_nonaliasing
```
```diff
@@ -82,6 +82,34 @@ def discard_graph_changes(tx):
         ctx.__exit__(None, None, None)
 
 
+def diff_meta(tensor_vars1, tensor_vars2) -> str:
+    from torch._higher_order_ops.utils import diff_tensor_meta
+
+    from . import TensorVariable
+
+    assert all(isinstance(var, TensorVariable) for var in tensor_vars1 + tensor_vars2)
+    all_diffs = []
+    for i, (var1, var2) in enumerate(zip(tensor_vars1, tensor_vars2)):
+        # We have vmap x cond tests and querying is_contiguous inside of vmap for
+        # memory_format other than torch.contiguous_format is not yet implemented.
+        # And it seems the remaining metas are good enough for now.
+        meta1 = _extract_tensor_metadata(
+            var1.proxy.node.meta["example_value"], include_contiguity=False
+        )
+        meta2 = _extract_tensor_metadata(
+            var2.proxy.node.meta["example_value"], include_contiguity=False
+        )
+        # We cannot get accurate require_grad. See Note [invariants for node meta 'val']
+        pair_diffs = diff_tensor_meta(meta1, meta2, check_grad=False)
+
+        if len(pair_diffs) > 0:
+            fmt_str = ", ".join(pair_diffs)
+            all_diffs.append(
+                f"pair[{i}] differ in {fmt_str}, where lhs is {meta1} and rhs is {meta2}"
+            )
+    return "\n".join(all_diffs)
+
+
 @contextlib.contextmanager
 def dynamo_enable_grad(tx: "InstructionTranslator", enable=True):
     from . import GradModeVariable
```
```diff
@@ -888,28 +916,11 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
         if not same_treespec.as_python_constant():
             unimplemented("Expected branches to return the same pytree structure.")
 
-        def diff_meta(tensor_vars1, tensor_vars2):
-            assert all(
-                isinstance(var, TensorVariable) for var in tensor_vars1 + tensor_vars2
-            )
-            all_diffs = []
-            for i, (var1, var2) in enumerate(zip(tensor_vars1, tensor_vars2)):
-                # We check the meta data associated with meta["example_value"]
-                meta1 = _extract_tensor_metadata(
-                    var1.proxy.node.meta["example_value"], include_contiguity=False
-                )
-                meta2 = _extract_tensor_metadata(
-                    var2.proxy.node.meta["example_value"], include_contiguity=False
-                )
-                if meta1 != meta2:
-                    all_diffs.append((f"pair{i}:", meta1, meta2))
-            return all_diffs
-
         if diffs := diff_meta(
             true_r.unpack_var_sequence(tx), false_r.unpack_var_sequence(tx)
         ):
             unimplemented(
-                f"Expected branches to return tensors with same metadata. [(tensor_pair, difference)...]:{diffs}"
+                f"Expect branches to return tensors with same metadata but find {diffs}"
             )
 
         (
```
```diff
@@ -1119,6 +1130,12 @@ class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable):
             set_subgraph_inputs="flatten_manual",
             should_flatten_outputs=True,
         )
 
+        if diffs := diff_meta(operands_seq, body_r.unpack_var_sequence(tx)):
+            unimplemented(
+                f"Expected carried_inputs and body outputs return tensors with same metadata but find:\n{diffs}"
+            )
+
         (
             cond_graph,
             body_graph,
```
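The while_loop hunk above applies the same check between the carried inputs and the body outputs. Below is a hedged sketch of the kind of program expected to trip it; the `while_loop` import path and call signature are assumptions, not part of this diff.

```python
# Hypothetical sketch, not taken from this PR's tests: the body changes the shape
# of the carried tensor, so carried_inputs and body outputs no longer share metadata.
import torch
from torch._higher_order_ops.while_loop import while_loop  # assumed import path


def cond_fn(x):
    return x.sum() < 10  # boolean scalar tensor predicate


def body_fn(x):
    return (torch.cat([x, x]),)  # doubles the first dimension each iteration


x = torch.randn(2, 3)
compiled = torch.compile(lambda t: while_loop(cond_fn, body_fn, (t,)))
try:
    compiled(x)
except Exception as e:  # expected: an error naming carried_inputs/body metadata mismatch
    print(type(e).__name__, e)
```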
```diff
@@ -9,6 +9,7 @@ import torch.fx.traceback as fx_traceback
 import torch.utils._pytree as pytree
 from torch._ops import OperatorBase
 from torch.fx.experimental.proxy_tensor import make_fx
+from torch.fx.passes.shape_prop import TensorMetadata
 from torch.multiprocessing.reductions import StorageWeakRef
```
```diff
@@ -481,6 +482,27 @@ def first_slice_copy(t: torch.Tensor, dim: int = 0) -> torch.Tensor:
     return torch.select_copy(t, dim, 0)
 
 
+# Reports the difference between meta of two tensors in a string
+def diff_tensor_meta(
+    meta1: TensorMetadata, meta2: TensorMetadata, check_grad=True
+) -> List[str]:
+    from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode
+
+    pair_diffs = []
+    for meta_name in TensorMetadata._fields:
+        if not check_grad and meta_name == "requires_grad":
+            continue
+        val1 = getattr(meta1, meta_name)
+        val2 = getattr(meta2, meta_name)
+        try:
+            if val1 != val2:
+                pair_diffs.append(f"'{meta_name}'")
+        except GuardOnDataDependentSymNode as _:
+            pair_diffs.append(f"'{meta_name}'")
+            continue
+    return pair_diffs
+
+
 # Note [lifted arg types in hop]
 # For dynamoed hops, we automatically lift the free symbols in tensors as arguments.
 # This has implications for the types of lifted args for different dispatch keys:
```
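A small usage sketch for the new `diff_tensor_meta` helper follows; the `_extract_tensor_metadata` import from `torch.fx.passes.shape_prop` is an assumption based on how the Dynamo-side `diff_meta` above uses it.

```python
# Usage sketch, assuming _extract_tensor_metadata lives in torch.fx.passes.shape_prop
# as used by the Dynamo-side diff_meta in this PR.
import torch
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch._higher_order_ops.utils import diff_tensor_meta

a = torch.randn(4, 3)
b = torch.randn(2, 3)

meta_a = _extract_tensor_metadata(a)
meta_b = _extract_tensor_metadata(b)

# Returns the names of the mismatching TensorMetadata fields, e.g. ["'shape'"];
# with check_grad=False the requires_grad field is skipped.
print(diff_tensor_meta(meta_a, meta_b, check_grad=False))
```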