diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py
index 209babb5947..4eeefa2f81f 100644
--- a/torch/_dynamo/variables/higher_order_ops.py
+++ b/torch/_dynamo/variables/higher_order_ops.py
@@ -108,7 +108,7 @@ def check_meta_consistency_vt(
     lhs_name: str,
     rhs_name: str,
 ) -> None:
-    from torch._higher_order_ops.while_loop import check_meta_consistency
+    from torch._higher_order_ops.utils import check_meta_consistency
 
     from . import TensorVariable
 
diff --git a/torch/_higher_order_ops/associative_scan.py b/torch/_higher_order_ops/associative_scan.py
index 55f9d0dcc31..9747282ec82 100644
--- a/torch/_higher_order_ops/associative_scan.py
+++ b/torch/_higher_order_ops/associative_scan.py
@@ -12,6 +12,7 @@ from torch._higher_order_ops.utils import (
     _maybe_run_with_interpreter,
     _set_compilation_env,
     autograd_not_implemented,
+    check_meta_consistency,
     first_slice_copy,
     reenter_make_fx,
     unique_graph_id,
@@ -371,12 +372,13 @@ def trace_associative_scan(
         xs
     ), f"expected combine_fn to return {len(xs)} results but got {len(outputs)}"
 
-    for i, o in zip(xs, outputs):
-        o_meta = o.meta["tensor_meta"]
-        assert o_meta.dtype == i.dtype, (
-            f"combine_fn output type mismatch, expected {i.dtype} "
-            + f"but got {o_meta.dtype}"
-        )
+    xs_fake_tensors: list[torch.Tensor | torch.SymInt | int] = [
+        first_slice_copy(x) for x in xs
+    ]
+    output_fake_tensors: list[torch.Tensor | torch.SymInt | int] = [
+        c.meta["val"] for c in outputs
+    ]
+    check_meta_consistency(xs_fake_tensors, output_fake_tensors, "xs", "outputs")
 
     _, combine_graph_name = unique_graph_id(proxy_mode, prefix="scan_combine_graph")
 
diff --git a/torch/_higher_order_ops/scan.py b/torch/_higher_order_ops/scan.py
index a71f1de05c0..0d69141e619 100644
--- a/torch/_higher_order_ops/scan.py
+++ b/torch/_higher_order_ops/scan.py
@@ -13,6 +13,7 @@ from torch._higher_order_ops.utils import (
     _has_potential_branch_input_mutation,
     _maybe_compile_and_run_fn,
     autograd_not_implemented,
+    check_meta_consistency,
     first_slice_copy,
     reenter_make_fx,
     unique_graph_id,
@@ -314,20 +315,13 @@ def trace_scan(
 
     assert outputs is not None
     carry, output = _extract_carry_and_out(outputs, len(init))
-
-    for ini, ca in zip(init, carry):
-        ini_meta = ini
-        carry_meta = ca.meta["tensor_meta"]
-        carry_val = ca.meta["val"]
-        if (
-            carry_val.device != ini_meta.device
-            or carry_meta.dtype != ini_meta.dtype
-            or carry_meta.shape != ini_meta.shape
-        ):
-            raise RuntimeError(
-                f"Expected metadata of the combine_fn result {carry_meta} to be the same as "
-                + f"the metadata of init with {ini_meta}"
-            )
+    init_fake_tensors: list[torch.Tensor | torch.SymInt | int] = [
+        i.clone() for i in init
+    ]
+    carry_fake_tensors: list[torch.Tensor | torch.SymInt | int] = [
+        c.meta["val"] for c in carry
+    ]
+    check_meta_consistency(init_fake_tensors, carry_fake_tensors, "init", "carry")
 
     _, combine_graph_name = unique_graph_id(proxy_mode, prefix="scan_combine_graph")
 
diff --git a/torch/_higher_order_ops/utils.py b/torch/_higher_order_ops/utils.py
index 0870a077a4f..c0e288eb145 100644
--- a/torch/_higher_order_ops/utils.py
+++ b/torch/_higher_order_ops/utils.py
@@ -15,7 +15,7 @@ from torch.fx.experimental.proxy_tensor import (
     disable_proxy_modes_tracing,
     make_fx,
 )
-from torch.fx.passes.shape_prop import TensorMetadata
+from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
 from torch.multiprocessing.reductions import StorageWeakRef
 
 
@@ -142,6 +142,88 @@ def _maybe_reenter_make_fx(fn):
         return _maybe_make_fx_with_fake_mode(fn)
 
 
+def check_meta_consistency(
+    lhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
+    rhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
+    lhs_name: str,
+    rhs_name: str,
+) -> None:
+    def diff_meta_pairs(
+        lhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
+        rhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
+    ) -> list[str]:
+        def diff_meta(
+            lhs: Union[torch.Tensor, torch.SymInt, int],
+            rhs: Union[torch.Tensor, torch.SymInt, int],
+        ) -> str:
+            if isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor):
+                return ", ".join(
+                    diff_tensor_meta(
+                        # We set include_contiguity=False because we have vmap x cond tests where,
+                        # if include_contiguity=True, calling t.is_contiguous inside of vmap raises
+                        # "querying is_contiguous inside of vmap for memory_format other than
+                        # torch.contiguous_format is not yet implemented". This is acceptable
+                        # because stride is still checked.
+                        _extract_tensor_metadata(lhs, include_contiguity=False),
+                        _extract_tensor_metadata(rhs, include_contiguity=False),
+                        check_grad=False,
+                    )
+                )
+            else:
+
+                def _both_int_types(lhs, rhs):
+                    return isinstance(lhs, (int, torch.SymInt)) and isinstance(
+                        rhs, (int, torch.SymInt)
+                    )
+
+                def _both_tensor(lhs, rhs):
+                    return isinstance(lhs, torch.Tensor) and isinstance(
+                        rhs, torch.Tensor
+                    )
+
+                if not _both_int_types(lhs, rhs) and not _both_tensor(lhs, rhs):
+                    return f"type: {lhs} vs {rhs}"
+
+            return ""
+
+        # Manually check the device of lhs and rhs, as this field is currently not part of TensorMetadata
+        def diff_device(
+            lhs: Union[torch.Tensor, torch.SymInt, int],
+            rhs: Union[torch.Tensor, torch.SymInt, int],
+        ) -> str:
+            if isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor):
+                if (
+                    rhs.device.type == lhs.device.type
+                    and rhs.device.index == lhs.device.index
+                ):
+                    return ""
+                else:
+                    return "device"
+            return ""
+
+        if len(lhs_list) != len(rhs_list):
+            raise torch._dynamo.exc.UncapturedHigherOrderOpError(
+                f"Expected {lhs_name} and {rhs_name} to have same number of outputs but got lhs:{lhs_list} and rhs:{rhs_list}"
+            )
+        all_diffs = []
+        for i, (lhs, rhs) in enumerate(zip(lhs_list, rhs_list)):
+            if diff := diff_meta(lhs, rhs):
+                all_diffs.append(
+                    f"pair[{i}] differ in {diff}, where lhs is {lhs} and rhs is {rhs}"
+                )
+            if diff := diff_device(lhs, rhs):
+                all_diffs.append(
+                    f"pair[{i}] differ in {diff}, where lhs is {lhs} and rhs is {rhs}"
+                )
+        return all_diffs
+
+    if all_diffs := diff_meta_pairs(lhs_list, rhs_list):
+        diff_str = "\n".join(all_diffs)
+        raise torch._dynamo.exc.UncapturedHigherOrderOpError(
+            f"Expected {lhs_name} and {rhs_name} to have same metadata but found:\n{diff_str}"
+        )
+
+
 @contextmanager
 def _set_compilation_env():
     _old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag
diff --git a/torch/_higher_order_ops/while_loop.py b/torch/_higher_order_ops/while_loop.py
index 6aaee3280a0..3aa57ae91d2 100644
--- a/torch/_higher_order_ops/while_loop.py
+++ b/torch/_higher_order_ops/while_loop.py
@@ -11,7 +11,7 @@ from torch._higher_order_ops.utils import (
     _maybe_run_with_interpreter,
     _set_compilation_env,
     autograd_not_implemented,
-    diff_tensor_meta,
+    check_meta_consistency,
     reenter_make_fx,
     UnsupportedAliasMutationException,
     validate_subgraph_args_types,
@@ -23,7 +23,6 @@ from torch.fx.experimental.proxy_tensor import (
     ProxyTorchDispatchMode,
     track_tensor_tree,
 )
-from torch.fx.passes.shape_prop import _extract_tensor_metadata
 
 
 class WhileLoopOp(HigherOrderOperator):
@@ -340,88 +339,6 @@ def while_loop_tracing(mode, cond_fn, body_fn, carried_inputs, additional_inputs
     )
 
 
-def check_meta_consistency(
-    lhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
-    rhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
-    lhs_name: str,
-    rhs_name: str,
-) -> None:
-    def diff_meta_pairs(
-        lhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
-        rhs_list: list[Union[torch.Tensor, torch.SymInt, int]],
-    ) -> list[str]:
-        def diff_meta(
-            lhs: Union[torch.Tensor, torch.SymInt, int],
-            rhs: Union[torch.Tensor, torch.SymInt, int],
-        ) -> str:
-            if isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor):
-                return ", ".join(
-                    diff_tensor_meta(
-                        # We set include contiguity=False because we have vmap x cond tests, where if
-                        # include_contiguity=True will call t.is_contiguous inside of vmap and get an error
-                        # "querying is_contiguous inside of vmap for memory_format other than
-                        # torch.contiguous_format is not yet implemented". This is good for because stride
-                        # is still checked.
-                        _extract_tensor_metadata(lhs, include_contiguity=False),
-                        _extract_tensor_metadata(rhs, include_contiguity=False),
-                        check_grad=False,
-                    )
-                )
-            else:
-
-                def _both_int_types(lhs, rhs):
-                    return isinstance(lhs, (int, torch.SymInt)) and isinstance(
-                        rhs, (int, torch.SymInt)
-                    )
-
-                def _both_tensor(lhs, rhs):
-                    return isinstance(lhs, torch.Tensor) and isinstance(
-                        rhs, torch.Tensor
-                    )
-
-                if not _both_int_types(lhs, rhs) and not _both_tensor(lhs, rhs):
-                    return f"type: {lhs} vs {rhs}"
-
-            return ""
-
-        # Manually check the device of lhs and rhs as this field is currently not part of TensorMetadata
-        def diff_device(
-            lhs: Union[torch.Tensor, torch.SymInt, int],
-            rhs: Union[torch.Tensor, torch.SymInt, int],
-        ) -> str:
-            if isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor):
-                if (
-                    rhs.device.type == lhs.device.type
-                    and rhs.device.index == lhs.device.index
-                ):
-                    return ""
-                else:
-                    return "device"
-            return ""
-
-        if len(lhs_list) != len(rhs_list):
-            raise torch._dynamo.exc.UncapturedHigherOrderOpError(
-                f"Expected {lhs_name} and {rhs_name} to have same number of outputs but got lhs:{lhs_list} and rhs:{rhs_list}"
-            )
-        all_diffs = []
-        for i, (lhs, rhs) in enumerate(zip(lhs_list, rhs_list)):
-            if diff := diff_meta(lhs, rhs):
-                all_diffs.append(
-                    f"pair[{i}] differ in {diff}, where lhs is {lhs} and rhs is {rhs}"
-                )
-            if diff := diff_device(lhs, rhs):
-                all_diffs.append(
-                    f"pair[{i}] differ in {diff}, where lhs is {lhs} and rhs is {rhs}"
-                )
-        return all_diffs
-
-    if all_diffs := diff_meta_pairs(lhs_list, rhs_list):
-        diff_str = "\n".join(all_diffs)
-        raise torch._dynamo.exc.UncapturedHigherOrderOpError(
-            f"Expected {lhs_name} and {rhs_name} to have same metadata but found:\n{diff_str}"
-        )
-
-
 @while_loop_op.py_impl(FakeTensorMode)
 def while_loop_fake_tensor_mode(
     mode, cond_fn, body_fn, carried_inputs, additional_inputs
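
Usage sketch (not part of the patch): a minimal look at how the consolidated helper behaves, assuming this diff is applied. check_meta_consistency is an internal utility; the "init"/"carry" labels and the tensor values below are illustrative only and simply feed the error message.

    import torch
    from torch._dynamo.exc import UncapturedHigherOrderOpError
    from torch._higher_order_ops.utils import check_meta_consistency

    # Tensor pairs are compared on TensorMetadata (dtype, shape, stride, ...)
    # plus device; int/SymInt pairs are only checked for being int-like, not
    # for equal values. So this call passes silently (it returns None).
    check_meta_consistency(
        [torch.zeros(2, 3), 5],
        [torch.ones(2, 3), 7],
        "init",
        "carry",
    )

    # A dtype mismatch raises UncapturedHigherOrderOpError with a message
    # roughly of the form:
    #   Expected init and carry to have same metadata but found:
    #   pair[0] differ in 'dtype: torch.float32 vs torch.float64', where lhs is ... and rhs is ...
    try:
        check_meta_consistency(
            [torch.zeros(2, 3, dtype=torch.float32)],
            [torch.zeros(2, 3, dtype=torch.float64)],
            "init",
            "carry",
        )
    except UncapturedHigherOrderOpError as e:
        print(e)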