diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index 36e8765759b..061ddcb7c6c 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -582,18 +582,25 @@ class _TargetArgsExpr(_TargetExpr): def pytree_flatten( args: Sequence[Any], kwargs: Mapping[Any, Any] ) -> Tuple[Sequence[Any], Union[_SimpleSpec, pytree.TreeSpec]]: - def norm_spec(s: pytree.TreeSpec) -> pytree.TreeSpec: - if s.type is None: - return s - mapping = {immutable_list: list, tuple: list, immutable_dict: dict} - return pytree.TreeSpec( - mapping.get(s.type, s.type), - s.context, - list(map(norm_spec, s.children_specs)), - ) + type_mapping = {immutable_list: tuple, list: tuple, immutable_dict: dict} - flat, spec = pytree.tree_flatten([args, kwargs]) - spec = norm_spec(spec) + def convert_type(x: Any) -> Any: + cls = type(x) + convert_fn = type_mapping.get(cls) + if convert_fn is not None: + return pytree.tree_map( + convert_type, + convert_fn(x), + is_leaf=lambda x: type(x) in type_mapping, + ) + return x + + normalized_args_tree = pytree.tree_map( + convert_type, + (args, kwargs), + is_leaf=lambda x: type(x) in type_mapping, + ) + flat, spec = pytree.tree_flatten(normalized_args_tree) return flat, spec def __repr__(self) -> str: diff --git a/torch/onnx/_internal/io_adapter.py b/torch/onnx/_internal/io_adapter.py index 16c1313a2d5..7334c79620d 100644 --- a/torch/onnx/_internal/io_adapter.py +++ b/torch/onnx/_internal/io_adapter.py @@ -136,15 +136,35 @@ class OutputAdapter: # TODO: make_fx lose stack info https://github.com/pytorch/pytorch/issues/90276 -def _replace_tuple_with_list(spec: pytree.TreeSpec) -> pytree.TreeSpec: - _type = list if spec.type == tuple else spec.type - return pytree.TreeSpec( - _type, spec.context, list(map(_replace_tuple_with_list, spec.children_specs)) +# TODO(XuehaiPan): Dynamo does not support `dummy_leaf = object()` as a sentinel value in the frame. +class _DummyLeaf: # use a class instead. + pass + + +def _replace_list_with_tuple(spec: pytree.TreeSpec) -> pytree.TreeSpec: + def replace_list_with_tuple(x: Any) -> Any: + if type(x) is list: + return pytree.tree_map( + replace_list_with_tuple, + tuple(x), + is_leaf=lambda x: type(x) is list, + ) + return x + + dummy_leaf = _DummyLeaf() + dummy_tree = pytree.tree_unflatten([dummy_leaf] * spec.num_leaves, spec) + dummy_tree = pytree.tree_map( + replace_list_with_tuple, + dummy_tree, + is_leaf=lambda x: type(x) is list, ) + return pytree.tree_structure(dummy_tree) -def _open_top_level_list_if_single_element(spec: pytree.TreeSpec) -> pytree.TreeSpec: - if spec.type == list and spec.num_children == 1: +def _open_top_level_sequence_if_single_element( + spec: pytree.TreeSpec, +) -> pytree.TreeSpec: + if spec.type in (tuple, list) and spec.num_children == 1: return spec.children_specs[0] return spec @@ -167,10 +187,10 @@ def _assert_identical_pytree_spec( pass_if_any_checks: Sequence[Callable[[], bool]] = [ lambda: spec1 == spec2, # FIXME: Bug in `dynamo.export`. Sometimes outputs returned in 'list' instead of 'tuple'. - lambda: _replace_tuple_with_list(spec1) == _replace_tuple_with_list(spec2), + lambda: _replace_list_with_tuple(spec1) == _replace_list_with_tuple(spec2), # FIXME: Bug in `dynamo.export`. Sometimes single function return is wrapped in list. - lambda: _open_top_level_list_if_single_element(spec1) == spec2, - lambda: spec1 == _open_top_level_list_if_single_element(spec2), + lambda: _open_top_level_sequence_if_single_element(spec1) == spec2, + lambda: spec1 == _open_top_level_sequence_if_single_element(spec2), ] if not any(check() for check in pass_if_any_checks):