[hop] Rework the check of Metadata in the functionalization key (#148789)

This PR is a largely cosmetic rework of the metadata check performed by some HOPs: the check_meta_consistency helper moves from torch._higher_order_ops.while_loop into torch._higher_order_ops.utils, and scan and associative_scan now reuse it in place of their own inline metadata checks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148789
Approved by: https://github.com/ydwu4
Thomas Bohnstingl 2025-03-18 20:30:59 +00:00 committed by PyTorch MergeBot
parent f06e366532
commit cd5c13d8f0
5 changed files with 101 additions and 106 deletions
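
For orientation: the scattered per-HOP checks below are consolidated into a single helper, check_meta_consistency, which now lives in torch._higher_order_ops.utils and raises torch._dynamo.exc.UncapturedHigherOrderOpError on any mismatch. A minimal sketch of the call pattern, assuming a PyTorch build that includes this change (tensor shapes are illustrative):

import torch

from torch._higher_order_ops.utils import check_meta_consistency

init = [torch.zeros(2, 3)]
carry = [torch.ones(2, 3)]
# Matching dtype, shape, stride, and device: the helper returns without raising.
check_meta_consistency(init, carry, "init", "carry")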

View File

@@ -108,7 +108,7 @@ def check_meta_consistency_vt(
     lhs_name: str,
     rhs_name: str,
 ) -> None:
-    from torch._higher_order_ops.while_loop import check_meta_consistency
+    from torch._higher_order_ops.utils import check_meta_consistency

     from . import TensorVariable

View File

@@ -12,6 +12,7 @@ from torch._higher_order_ops.utils import (
     _maybe_run_with_interpreter,
     _set_compilation_env,
     autograd_not_implemented,
+    check_meta_consistency,
     first_slice_copy,
     reenter_make_fx,
     unique_graph_id,
@@ -371,12 +372,13 @@ def trace_associative_scan(
         xs
     ), f"expected combine_fn to return {len(xs)} results but got {len(outputs)}"

-    for i, o in zip(xs, outputs):
-        o_meta = o.meta["tensor_meta"]
-        assert o_meta.dtype == i.dtype, (
-            f"combine_fn output type mismatch, expected {i.dtype} "
-            + f"but got {o_meta.dtype}"
-        )
+    xs_fake_tensors: list[torch.Tensor | torch.SymInt | int] = [
+        first_slice_copy(x) for x in xs
+    ]
+    output_fake_tensors: list[torch.Tensor | torch.SymInt | int] = [
+        c.meta["val"] for c in outputs
+    ]
+    check_meta_consistency(xs_fake_tensors, output_fake_tensors, "init", "carry")

     _, combine_graph_name = unique_graph_id(proxy_mode, prefix="scan_combine_graph")
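
The per-output dtype assertion above is replaced by a full metadata comparison between a single slice of each scanned input and the traced combine_fn outputs. A hedged illustration of why the slice has the right metadata to compare against, assuming first_slice_copy selects index 0 along the scanned dimension (dim 0), as its use here suggests:

import torch

from torch._higher_order_ops.utils import first_slice_copy

xs = torch.randn(5, 2, 3)    # hypothetical input, scanned over dim 0
leaf = first_slice_copy(xs)  # drops the scanned dim: shape (2, 3), same dtype/device
assert leaf.shape == torch.Size([2, 3]) and leaf.dtype == xs.dtype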

View File

@@ -13,6 +13,7 @@ from torch._higher_order_ops.utils import (
     _has_potential_branch_input_mutation,
     _maybe_compile_and_run_fn,
     autograd_not_implemented,
+    check_meta_consistency,
     first_slice_copy,
     reenter_make_fx,
     unique_graph_id,
@@ -314,20 +315,13 @@ def trace_scan(
     assert outputs is not None
     carry, output = _extract_carry_and_out(outputs, len(init))

-    for ini, ca in zip(init, carry):
-        ini_meta = ini
-        carry_meta = ca.meta["tensor_meta"]
-        carry_val = ca.meta["val"]
-        if (
-            carry_val.device != ini_meta.device
-            or carry_meta.dtype != ini_meta.dtype
-            or carry_meta.shape != ini_meta.shape
-        ):
-            raise RuntimeError(
-                f"Expected metadata of the combine_fn result {carry_meta} to be the same as "
-                + f"the metadata of init with {ini_meta}"
-            )
+    init_fake_tensors: list[torch.Tensor | torch.SymInt | int] = [
+        i.clone() for i in init
+    ]
+    carry_fake_tensors: list[torch.Tensor | torch.SymInt | int] = [
+        c.meta["val"] for c in carry
+    ]
+    check_meta_consistency(init_fake_tensors, carry_fake_tensors, "init", "carry")

     _, combine_graph_name = unique_graph_id(proxy_mode, prefix="scan_combine_graph")
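
With the shared helper, a carry whose metadata drifts away from init now surfaces as an UncapturedHigherOrderOpError listing every mismatching pair, rather than the ad-hoc RuntimeError removed above. An illustrative sketch (values and shapes are made up):

import torch

from torch._dynamo.exc import UncapturedHigherOrderOpError
from torch._higher_order_ops.utils import check_meta_consistency

init = [torch.zeros(2, 3)]
bad_carry = [torch.zeros(2, 4, dtype=torch.float64)]  # shape and dtype drifted

try:
    check_meta_consistency(init, bad_carry, "init", "carry")
except UncapturedHigherOrderOpError as e:
    print(e)  # "Expected init and carry to have same metadata but found:\npair[0] differ in ..."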

View File

@@ -15,7 +15,7 @@ from torch.fx.experimental.proxy_tensor import (
     disable_proxy_modes_tracing,
     make_fx,
 )
-from torch.fx.passes.shape_prop import TensorMetadata
+from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
 from torch.multiprocessing.reductions import StorageWeakRef
@@ -142,6 +142,88 @@ def _maybe_reenter_make_fx(fn):
         return _maybe_make_fx_with_fake_mode(fn)


+def check_meta_consistency(
+    lhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
+    rhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
+    lhs_name: str,
+    rhs_name: str,
+) -> None:
+    def diff_meta_pairs(
+        lhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
+        rhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
+    ) -> list[str]:
+        def diff_meta(
+            lhs: Union[torch.Tensor, torch.SymInt, int],
+            rhs: Union[torch.Tensor, torch.SymInt, int],
+        ) -> str:
+            if isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor):
+                return ", ".join(
+                    diff_tensor_meta(
+                        # We set include_contiguity=False because we have vmap x cond tests, where
+                        # include_contiguity=True will call t.is_contiguous inside of vmap and get an error
+                        # "querying is_contiguous inside of vmap for memory_format other than
+                        # torch.contiguous_format is not yet implemented". This is fine because stride
+                        # is still checked.
+                        _extract_tensor_metadata(lhs, include_contiguity=False),
+                        _extract_tensor_metadata(rhs, include_contiguity=False),
+                        check_grad=False,
+                    )
+                )
+            else:
+
+                def _both_int_types(lhs, rhs):
+                    return isinstance(lhs, (int, torch.SymInt)) and isinstance(
+                        rhs, (int, torch.SymInt)
+                    )
+
+                def _both_tensor(lhs, rhs):
+                    return isinstance(lhs, torch.Tensor) and isinstance(
+                        rhs, torch.Tensor
+                    )
+
+                if not _both_int_types(lhs, rhs) and not _both_tensor(lhs, rhs):
+                    return f"type: {lhs} vs {rhs}"
+
+            return ""
+
+        # Manually check the device of lhs and rhs as this field is currently not part of TensorMetadata
+        def diff_device(
+            lhs: Union[torch.Tensor, torch.SymInt, int],
+            rhs: Union[torch.Tensor, torch.SymInt, int],
+        ) -> str:
+            if isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor):
+                if (
+                    rhs.device.type == lhs.device.type
+                    and rhs.device.index == lhs.device.index
+                ):
+                    return ""
+                else:
+                    return "device"
+            return ""
+
+        if len(lhs_list) != len(rhs_list):
+            raise torch._dynamo.exc.UncapturedHigherOrderOpError(
+                f"Expected {lhs_name} and {rhs_name} to have same number of outputs but got lhs:{lhs_list} and rhs:{rhs_list}"
+            )
+
+        all_diffs = []
+        for i, (lhs, rhs) in enumerate(zip(lhs_list, rhs_list)):
+            if diff := diff_meta(lhs, rhs):
+                all_diffs.append(
+                    f"pair[{i}] differ in {diff}, where lhs is {lhs} and rhs is {rhs}"
+                )
+            if diff := diff_device(lhs, rhs):
+                all_diffs.append(
+                    f"pair[{i}] differ in {diff}, where lhs is {lhs} and rhs is {rhs}"
+                )
+        return all_diffs
+
+    if all_diffs := diff_meta_pairs(lhs_list, rhs_list):
+        diff_str = "\n".join(all_diffs)
+        raise torch._dynamo.exc.UncapturedHigherOrderOpError(
+            f"Expected {lhs_name} and {rhs_name} to have same metadata but found:\n{diff_str}"
+        )
+
+
 @contextmanager
 def _set_compilation_env():
     _old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag
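
Two details of the helper above are easy to miss: non-tensor entries are only checked for being int-like (int or SymInt), not for equal values, and device is compared separately because TensorMetadata does not record it. A small sketch of the int-like branch (a hypothetical example, not a test from this PR):

import torch

from torch._higher_order_ops.utils import check_meta_consistency

# Both entries are int-like, so only their "kind" is compared; differing values pass.
check_meta_consistency([3], [7], "init", "carry")

# Pairing an int with a Tensor would be reported as a type mismatch and raise
# torch._dynamo.exc.UncapturedHigherOrderOpError:
# check_meta_consistency([3], [torch.zeros(1)], "init", "carry")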

View File

@@ -11,7 +11,7 @@ from torch._higher_order_ops.utils import (
     _maybe_run_with_interpreter,
     _set_compilation_env,
     autograd_not_implemented,
-    diff_tensor_meta,
+    check_meta_consistency,
     reenter_make_fx,
     UnsupportedAliasMutationException,
     validate_subgraph_args_types,
@@ -23,7 +23,6 @@ from torch.fx.experimental.proxy_tensor import (
     ProxyTorchDispatchMode,
     track_tensor_tree,
 )
-from torch.fx.passes.shape_prop import _extract_tensor_metadata


 class WhileLoopOp(HigherOrderOperator):
@@ -340,88 +339,6 @@ def while_loop_tracing(mode, cond_fn, body_fn, carried_inputs, additional_inputs
     )


-def check_meta_consistency(
-    lhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
-    rhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
-    lhs_name: str,
-    rhs_name: str,
-) -> None:
-    def diff_meta_pairs(
-        lhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
-        rhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
-    ) -> list[str]:
-        def diff_meta(
-            lhs: Union[torch.Tensor, torch.SymInt, int],
-            rhs: Union[torch.Tensor, torch.SymInt, int],
-        ) -> str:
-            if isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor):
-                return ", ".join(
-                    diff_tensor_meta(
-                        # We set include_contiguity=False because we have vmap x cond tests, where
-                        # include_contiguity=True will call t.is_contiguous inside of vmap and get an error
-                        # "querying is_contiguous inside of vmap for memory_format other than
-                        # torch.contiguous_format is not yet implemented". This is fine because stride
-                        # is still checked.
-                        _extract_tensor_metadata(lhs, include_contiguity=False),
-                        _extract_tensor_metadata(rhs, include_contiguity=False),
-                        check_grad=False,
-                    )
-                )
-            else:
-
-                def _both_int_types(lhs, rhs):
-                    return isinstance(lhs, (int, torch.SymInt)) and isinstance(
-                        rhs, (int, torch.SymInt)
-                    )
-
-                def _both_tensor(lhs, rhs):
-                    return isinstance(lhs, torch.Tensor) and isinstance(
-                        rhs, torch.Tensor
-                    )
-
-                if not _both_int_types(lhs, rhs) and not _both_tensor(lhs, rhs):
-                    return f"type: {lhs} vs {rhs}"
-
-            return ""
-
-        # Manually check the device of lhs and rhs as this field is currently not part of TensorMetadata
-        def diff_device(
-            lhs: Union[torch.Tensor, torch.SymInt, int],
-            rhs: Union[torch.Tensor, torch.SymInt, int],
-        ) -> str:
-            if isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor):
-                if (
-                    rhs.device.type == lhs.device.type
-                    and rhs.device.index == lhs.device.index
-                ):
-                    return ""
-                else:
-                    return "device"
-            return ""
-
-        if len(lhs_list) != len(rhs_list):
-            raise torch._dynamo.exc.UncapturedHigherOrderOpError(
-                f"Expected {lhs_name} and {rhs_name} to have same number of outputs but got lhs:{lhs_list} and rhs:{rhs_list}"
-            )
-
-        all_diffs = []
-        for i, (lhs, rhs) in enumerate(zip(lhs_list, rhs_list)):
-            if diff := diff_meta(lhs, rhs):
-                all_diffs.append(
-                    f"pair[{i}] differ in {diff}, where lhs is {lhs} and rhs is {rhs}"
-                )
-            if diff := diff_device(lhs, rhs):
-                all_diffs.append(
-                    f"pair[{i}] differ in {diff}, where lhs is {lhs} and rhs is {rhs}"
-                )
-        return all_diffs
-
-    if all_diffs := diff_meta_pairs(lhs_list, rhs_list):
-        diff_str = "\n".join(all_diffs)
-        raise torch._dynamo.exc.UncapturedHigherOrderOpError(
-            f"Expected {lhs_name} and {rhs_name} to have same metadata but found:\n{diff_str}"
-        )
-
-
 @while_loop_op.py_impl(FakeTensorMode)
 def while_loop_fake_tensor_mode(
     mode, cond_fn, body_fn, carried_inputs, additional_inputs