[hop] Rework the check of Metadata in the functionalization key (#148789)
This PR is a mostly cosmetic rework of the metadata check performed by some HOPs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148789
Approved by: https://github.com/ydwu4
This commit is contained in:
parent f06e366532
commit cd5c13d8f0
@@ -108,7 +108,7 @@ def check_meta_consistency_vt(
     lhs_name: str,
     rhs_name: str,
 ) -> None:
-    from torch._higher_order_ops.while_loop import check_meta_consistency
+    from torch._higher_order_ops.utils import check_meta_consistency
 
     from . import TensorVariable
 
@@ -12,6 +12,7 @@ from torch._higher_order_ops.utils import (
     _maybe_run_with_interpreter,
     _set_compilation_env,
     autograd_not_implemented,
+    check_meta_consistency,
     first_slice_copy,
     reenter_make_fx,
     unique_graph_id,
@@ -371,12 +372,13 @@ def trace_associative_scan(
         xs
     ), f"expected combine_fn to return {len(xs)} results but got {len(outputs)}"
 
-    for i, o in zip(xs, outputs):
-        o_meta = o.meta["tensor_meta"]
-        assert o_meta.dtype == i.dtype, (
-            f"combine_fn output type mismatch, expected {i.dtype} "
-            + f"but got {o_meta.dtype}"
-        )
+    xs_fake_tensors: list[torch.Tensor | torch.SymInt | int] = [
+        first_slice_copy(x) for x in xs
+    ]
+    output_fake_tensors: list[torch.Tensor | torch.SymInt | int] = [
+        c.meta["val"] for c in outputs
+    ]
+    check_meta_consistency(xs_fake_tensors, output_fake_tensors, "init", "carry")
 
     _, combine_graph_name = unique_graph_id(proxy_mode, prefix="scan_combine_graph")
 
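To see what the reworked check buys here: a minimal sketch, not part of this commit, of the two lists that the new check_meta_consistency call compares in trace_associative_scan. The tensor shapes are illustrative assumptions; the real call feeds fake tensors from node.meta["val"] rather than eager tensors.

import torch
from torch._higher_order_ops.utils import check_meta_consistency, first_slice_copy

# For a scan over dim 0, every combine_fn output must carry the metadata of a
# single slice of the corresponding scanned input.
xs = [torch.randn(5, 3)]                    # scanned input, scan length 5
slices = [first_slice_copy(x) for x in xs]  # one-element stand-ins, shape (3,)

check_meta_consistency(slices, [torch.empty(3)], "init", "carry")  # passes

# Same dtype but wrong shape: the old dtype-only assert accepted this; the
# consolidated check raises torch._dynamo.exc.UncapturedHigherOrderOpError.
# check_meta_consistency(slices, [torch.empty(4)], "init", "carry")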
@@ -13,6 +13,7 @@ from torch._higher_order_ops.utils import (
     _has_potential_branch_input_mutation,
     _maybe_compile_and_run_fn,
     autograd_not_implemented,
+    check_meta_consistency,
     first_slice_copy,
     reenter_make_fx,
     unique_graph_id,
@@ -314,20 +315,13 @@ def trace_scan(
     assert outputs is not None
 
     carry, output = _extract_carry_and_out(outputs, len(init))
 
-    for ini, ca in zip(init, carry):
-        ini_meta = ini
-        carry_meta = ca.meta["tensor_meta"]
-        carry_val = ca.meta["val"]
-        if (
-            carry_val.device != ini_meta.device
-            or carry_meta.dtype != ini_meta.dtype
-            or carry_meta.shape != ini_meta.shape
-        ):
-            raise RuntimeError(
-                f"Expected metadata of the combine_fn result {carry_meta} to be the same as "
-                + f"the metadata of init with {ini_meta}"
-            )
+    init_fake_tensors: list[torch.Tensor | torch.SymInt | int] = [
+        i.clone() for i in init
+    ]
+    carry_fake_tensors: list[torch.Tensor | torch.SymInt | int] = [
+        c.meta["val"] for c in carry
+    ]
+    check_meta_consistency(init_fake_tensors, carry_fake_tensors, "init", "carry")
 
     _, combine_graph_name = unique_graph_id(proxy_mode, prefix="scan_combine_graph")
 
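The scan change follows the same pattern, and also swaps the ad-hoc RuntimeError for torch._dynamo.exc.UncapturedHigherOrderOpError. A hypothetical repro, not taken from the PR, of the kind of carry mismatch this guards against; it assumes the experimental scan entry point in torch._higher_order_ops.scan:

import torch
from torch._higher_order_ops.scan import scan

def bad_combine(carry, x):
    # The carry enters as float32 but is returned as float64, so its metadata
    # no longer matches init -- exactly what check_meta_consistency rejects.
    new_carry = (carry + x).to(torch.float64)
    return new_carry, new_carry.clone()

init = torch.zeros(3)
xs = torch.randn(5, 3)
# Expected to fail during tracing with UncapturedHigherOrderOpError
# ("Expected init and carry to have same metadata but found: ...").
ys = scan(bad_combine, init, xs)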
@@ -15,7 +15,7 @@ from torch.fx.experimental.proxy_tensor import (
     disable_proxy_modes_tracing,
     make_fx,
 )
-from torch.fx.passes.shape_prop import TensorMetadata
+from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
 from torch.multiprocessing.reductions import StorageWeakRef
 
 
@@ -142,6 +142,88 @@ def _maybe_reenter_make_fx(fn):
     return _maybe_make_fx_with_fake_mode(fn)
 
 
+def check_meta_consistency(
+    lhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
+    rhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
+    lhs_name: str,
+    rhs_name: str,
+) -> None:
+    def diff_meta_pairs(
+        lhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
+        rhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
+    ) -> list[str]:
+        def diff_meta(
+            lhs: Union[torch.Tensor, torch.SymInt, int],
+            rhs: Union[torch.Tensor, torch.SymInt, int],
+        ) -> str:
+            if isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor):
+                return ", ".join(
+                    diff_tensor_meta(
+                        # We set include_contiguity=False because we have vmap x cond tests,
+                        # where include_contiguity=True would call t.is_contiguous inside of
+                        # vmap and fail with "querying is_contiguous inside of vmap for
+                        # memory_format other than torch.contiguous_format is not yet
+                        # implemented". This is fine because stride is still checked.
+                        _extract_tensor_metadata(lhs, include_contiguity=False),
+                        _extract_tensor_metadata(rhs, include_contiguity=False),
+                        check_grad=False,
+                    )
+                )
+            else:
+
+                def _both_int_types(lhs, rhs):
+                    return isinstance(lhs, (int, torch.SymInt)) and isinstance(
+                        rhs, (int, torch.SymInt)
+                    )
+
+                def _both_tensor(lhs, rhs):
+                    return isinstance(lhs, torch.Tensor) and isinstance(
+                        rhs, torch.Tensor
+                    )
+
+                if not _both_int_types(lhs, rhs) and not _both_tensor(lhs, rhs):
+                    return f"type: {lhs} vs {rhs}"
+
+            return ""
+
+        # Manually check the device of lhs and rhs, as this field is currently not part of TensorMetadata
+        def diff_device(
+            lhs: Union[torch.Tensor, torch.SymInt, int],
+            rhs: Union[torch.Tensor, torch.SymInt, int],
+        ) -> str:
+            if isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor):
+                if (
+                    rhs.device.type == lhs.device.type
+                    and rhs.device.index == lhs.device.index
+                ):
+                    return ""
+                else:
+                    return "device"
+            return ""
+
+        if len(lhs_list) != len(rhs_list):
+            raise torch._dynamo.exc.UncapturedHigherOrderOpError(
+                f"Expected {lhs_name} and {rhs_name} to have same number of outputs but got lhs:{lhs_list} and rhs:{rhs_list}"
+            )
+        all_diffs = []
+        for i, (lhs, rhs) in enumerate(zip(lhs_list, rhs_list)):
+            if diff := diff_meta(lhs, rhs):
+                all_diffs.append(
+                    f"pair[{i}] differ in {diff}, where lhs is {lhs} and rhs is {rhs}"
+                )
+            if diff := diff_device(lhs, rhs):
+                all_diffs.append(
+                    f"pair[{i}] differ in {diff}, where lhs is {lhs} and rhs is {rhs}"
+                )
+        return all_diffs
+
+    if all_diffs := diff_meta_pairs(lhs_list, rhs_list):
+        diff_str = "\n".join(all_diffs)
+        raise torch._dynamo.exc.UncapturedHigherOrderOpError(
+            f"Expected {lhs_name} and {rhs_name} to have same metadata but found:\n{diff_str}"
+        )
+
+
 @contextmanager
 def _set_compilation_env():
     _old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag
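A quick behavioral sketch of the helper as defined above, not part of the diff; the printed message is paraphrased, not verbatim output:

import torch
from torch._higher_order_ops.utils import check_meta_consistency

# Mixed tensor/int lists are allowed; matching metadata returns None silently
# (int values may differ, only their type is compared).
check_meta_consistency([torch.ones(2, 3), 7], [torch.zeros(2, 3), 3], "init", "carry")

try:
    # dtype mismatch on the tensor pair
    check_meta_consistency(
        [torch.ones(2, 3)],
        [torch.zeros(2, 3, dtype=torch.float64)],
        "init",
        "carry",
    )
except torch._dynamo.exc.UncapturedHigherOrderOpError as e:
    print(e)  # Expected init and carry to have same metadata but found:
              # pair[0] differ in dtype: ..., where lhs is ... and rhs is ...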
@@ -11,7 +11,7 @@ from torch._higher_order_ops.utils import (
     _maybe_run_with_interpreter,
     _set_compilation_env,
     autograd_not_implemented,
-    diff_tensor_meta,
+    check_meta_consistency,
     reenter_make_fx,
     UnsupportedAliasMutationException,
     validate_subgraph_args_types,
@@ -23,7 +23,6 @@ from torch.fx.experimental.proxy_tensor import (
     ProxyTorchDispatchMode,
     track_tensor_tree,
 )
-from torch.fx.passes.shape_prop import _extract_tensor_metadata
 
 
 class WhileLoopOp(HigherOrderOperator):
@@ -340,88 +339,6 @@ def while_loop_tracing(mode, cond_fn, body_fn, carried_inputs, additional_inputs
     )
 
 
-def check_meta_consistency(
-    lhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
-    rhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
-    lhs_name: str,
-    rhs_name: str,
-) -> None:
-    def diff_meta_pairs(
-        lhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
-        rhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
-    ) -> list[str]:
-        def diff_meta(
-            lhs: Union[torch.Tensor, torch.SymInt, int],
-            rhs: Union[torch.Tensor, torch.SymInt, int],
-        ) -> str:
-            if isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor):
-                return ", ".join(
-                    diff_tensor_meta(
-                        # We set include_contiguity=False because we have vmap x cond tests,
-                        # where include_contiguity=True would call t.is_contiguous inside of
-                        # vmap and fail with "querying is_contiguous inside of vmap for
-                        # memory_format other than torch.contiguous_format is not yet
-                        # implemented". This is fine because stride is still checked.
-                        _extract_tensor_metadata(lhs, include_contiguity=False),
-                        _extract_tensor_metadata(rhs, include_contiguity=False),
-                        check_grad=False,
-                    )
-                )
-            else:
-
-                def _both_int_types(lhs, rhs):
-                    return isinstance(lhs, (int, torch.SymInt)) and isinstance(
-                        rhs, (int, torch.SymInt)
-                    )
-
-                def _both_tensor(lhs, rhs):
-                    return isinstance(lhs, torch.Tensor) and isinstance(
-                        rhs, torch.Tensor
-                    )
-
-                if not _both_int_types(lhs, rhs) and not _both_tensor(lhs, rhs):
-                    return f"type: {lhs} vs {rhs}"
-
-            return ""
-
-        # Manually check the device of lhs and rhs, as this field is currently not part of TensorMetadata
-        def diff_device(
-            lhs: Union[torch.Tensor, torch.SymInt, int],
-            rhs: Union[torch.Tensor, torch.SymInt, int],
-        ) -> str:
-            if isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor):
-                if (
-                    rhs.device.type == lhs.device.type
-                    and rhs.device.index == lhs.device.index
-                ):
-                    return ""
-                else:
-                    return "device"
-            return ""
-
-        if len(lhs_list) != len(rhs_list):
-            raise torch._dynamo.exc.UncapturedHigherOrderOpError(
-                f"Expected {lhs_name} and {rhs_name} to have same number of outputs but got lhs:{lhs_list} and rhs:{rhs_list}"
-            )
-        all_diffs = []
-        for i, (lhs, rhs) in enumerate(zip(lhs_list, rhs_list)):
-            if diff := diff_meta(lhs, rhs):
-                all_diffs.append(
-                    f"pair[{i}] differ in {diff}, where lhs is {lhs} and rhs is {rhs}"
-                )
-            if diff := diff_device(lhs, rhs):
-                all_diffs.append(
-                    f"pair[{i}] differ in {diff}, where lhs is {lhs} and rhs is {rhs}"
-                )
-        return all_diffs
-
-    if all_diffs := diff_meta_pairs(lhs_list, rhs_list):
-        diff_str = "\n".join(all_diffs)
-        raise torch._dynamo.exc.UncapturedHigherOrderOpError(
-            f"Expected {lhs_name} and {rhs_name} to have same metadata but found:\n{diff_str}"
-        )
-
-
 @while_loop_op.py_impl(FakeTensorMode)
 def while_loop_fake_tensor_mode(
     mode, cond_fn, body_fn, carried_inputs, additional_inputs
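Finally, because the helper changed homes, any out-of-tree caller that imported it from the old location needs the same one-line update the first hunk makes in Dynamo (hypothetical downstream snippet):

# Before this commit, the helper lived beside while_loop:
#   from torch._higher_order_ops.while_loop import check_meta_consistency
# After it, the shared copy lives in the HOP utils module:
from torch._higher_order_ops.utils import check_meta_consistency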