From 0947765eb9208996f221dbcb088df800be3953d7 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Thu, 30 Oct 2025 09:34:35 -0700 Subject: [PATCH] Cache even more work for return_and_correct_aliasing (#166365) Yet another pass found even more work we can move to be done only once. This seems to knock a few microseconds off the DTensor dispatch fast path. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166365 Approved by: https://github.com/bdhirsh --- torch/utils/_python_dispatch.py | 98 +++++++++++++++++++-------------- 1 file changed, 57 insertions(+), 41 deletions(-) diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py index 4ab48bc41ba..d853a25daca 100644 --- a/torch/utils/_python_dispatch.py +++ b/torch/utils/_python_dispatch.py @@ -5,7 +5,7 @@ import warnings from collections import deque from collections.abc import Sequence from dataclasses import dataclass -from typing import Optional, overload, Protocol, Union +from typing import cast, Optional, overload, Protocol, Union from typing_extensions import TypeIs import torch @@ -601,13 +601,20 @@ def _correct_storage_aliasing(func, schema_info, args, outs): raise AssertionError(f"expected torch.Tensor, got {type(ret)}") torch._functionalize_unsafe_set(ret, arg) - for arg_idx, schema_arg in enumerate(schema_info.args): - for return_idx, schema_out in enumerate(schema_info.outs): - is_read_only_alias_match = ( - schema_arg.alias_set & schema_out.alias_set - ) and not schema_arg.is_write - if is_read_only_alias_match: - alias_non_inplace_storage(args[arg_idx], outs[return_idx]) + for arg_idx, return_idx in schema_info.read_only_alias_match_indexes: + alias_non_inplace_storage(args[arg_idx], outs[return_idx]) + + +def _get_write_alias(x) -> Optional[str]: + alias_set = x.alias_set + if not alias_set or not x.is_write: + return None + # torchscript allows for complicated alias sets, but our dispatcher ops only really involve simple aliasing + if len(alias_set) != 1: + raise 
AssertionError("Expected alias_set to contain exactly one element") + # timeit says next(iter(alias_set)) is faster than list(alias_set)[0] even for + # set of size 1 on Python 3.13. + return next(iter(alias_set)) # This abstracts over the fact that in return_and_correct_aliasing, @@ -625,13 +632,16 @@ class SchemaInfo: args: list[AliasInfo] outs: list[AliasInfo] - # NOTE[SchemaInfo int_tags]: This has nothing to do with aliasing, but we take - # advantage of our existing caching of data for each OpOverload to paper over an - # efficiency problem with pybind11::enum_ (which currently is used to implement - # torch.Tag): a scan over a list of pybind enums using `in` is inefficient because - # each element must be converted to int with the __int__ method, which incurs a lot - # of overhead. Converting to int once and caching removes this per-op overhead. - int_tags: list[int] + is_inplace_view_op: bool + + # [_get_write_alias(x) for x in outs]. Guaranteed to contain no Nones; we coerce + # all-Nones result to None instead, and we don't support + # some-but-not-all-Nones. + outs_write_aliases: Optional[list[str]] + + # List of (arg_idx, return_idx) where args[arg_idx].alias_set & + # outs[return_idx].alias_set is not empty, and not args[arg_idx].is_write. + read_only_alias_match_indexes: list[tuple[int, int]] # Given an OpOverload, returns schema information on it. 
@@ -702,16 +712,39 @@ def get_alias_info(func) -> SchemaInfo: ) for a in func._schema.returns ] + read_only_alias_match_indexes = [] + for arg_idx, schema_arg in enumerate(arg_schemas): + for return_idx, schema_out in enumerate(out_schemas): + is_read_only_alias_match = ( + schema_arg.alias_set & schema_out.alias_set + ) and not schema_arg.is_write + if is_read_only_alias_match: + read_only_alias_match_indexes.append((arg_idx, return_idx)) + + outs_write_aliases_list: list[Optional[str]] = [ + _get_write_alias(r) for r in out_schemas + ] + non_nones = sum(x is not None for x in outs_write_aliases_list) + if non_nones == 0: + outs_write_aliases: Optional[list[str]] = None + elif non_nones != len(outs_write_aliases_list): + # simplifying assumption: we don't have **any** ops with return types like "-> (Tensor(a!), Tensor)" + raise RuntimeError("Unsupported schema: " + str(func._schema)) + else: + outs_write_aliases = cast(list[str], outs_write_aliases_list) + schema_info = SchemaInfo( - args=arg_schemas, outs=out_schemas, int_tags=[int(x) for x in func.tags] + args=arg_schemas, + outs=out_schemas, + # This check is surprisingly expensive because pybind11 enum_s are + # inefficient. Just cache it. + is_inplace_view_op=torch.Tag.inplace_view in func.tags, + outs_write_aliases=outs_write_aliases, + read_only_alias_match_indexes=read_only_alias_match_indexes, ) return schema_info -# See NOTE[SchemaInfo int_tags] above. -_TORCH_TAG_INPLACE_VIEW_INT = int(torch.Tag.inplace_view) # type: ignore[call-overload] - - def return_and_correct_aliasing(func, args, kwargs, out): """ This function should be used by wrapper tensor ``__torch_dispatch__`` subclasses @@ -732,17 +765,6 @@ def return_and_correct_aliasing(func, args, kwargs, out): # once for every op in the graph during functionalization. 
schema_info = get_alias_info(func) - def get_write_alias(x): - alias_set = x.alias_set - if not alias_set or not x.is_write: - return None - # torchscript allows for complicated alias sets, but our dispatcher ops only really involve simple aliasing - if len(alias_set) != 1: - raise AssertionError("Expected alias_set to contain exactly one element") - # timeit says next(iter(alias_set)) is faster than list(alias_set)[0] even for - # set of size 1 on Python 3.13. - return next(iter(alias_set)) - def get_arg_from_alias(output_alias, schema_info, args, kwargs): new_args, new_kwargs = torch.fx.operator_schemas.normalize_function( # type: ignore[misc] func, args=args, kwargs=kwargs @@ -770,14 +792,13 @@ def return_and_correct_aliasing(func, args, kwargs, out): # For inplace_view ops in particular, we'll try hard to make sure that the wrapper subclass's # metadata is set correctly. - # See NOTE[SchemaInfo int_tags] above. - if _TORCH_TAG_INPLACE_VIEW_INT in schema_info.int_tags: + if schema_info.is_inplace_view_op: # no_dispatch() to make sure that we secretly change the metadata on the wrapper, # but don't end up dispatching the op anywhere else. mutated_args = [ x for i, x in enumerate(args) - if get_write_alias(schema_info.args[i]) is not None + if _get_write_alias(schema_info.args[i]) is not None ] # Assumption: we have a very small number of inplace_view ops that follow a strict schema: # there is only a single argument that gets its metadata mutated. @@ -803,16 +824,11 @@ def return_and_correct_aliasing(func, args, kwargs, out): # Next: we need to make sure to return inputs directly, if the output is a mutable alias (e.g. add_()). - # Compute write aliases once instead of repeatedly. 
- schema_info_outs_write_aliases = [get_write_alias(r) for r in schema_info.outs] + schema_info_outs_write_aliases = schema_info.outs_write_aliases # simple case: none of our outputs have mutable aliases, so we can return the output as-is - if not any(x is not None for x in schema_info_outs_write_aliases): + if schema_info_outs_write_aliases is None: return out - # simplifying assumption: we don't have **any** ops with return types like "-> (Tensor(a!), Tensor)" - if not all(x is not None for x in schema_info_outs_write_aliases): - raise RuntimeError("Unsupported schema: " + str(func._schema)) - if len(schema_info_outs_write_aliases) == 1: return get_arg_from_alias( schema_info_outs_write_aliases[0], schema_info, args, kwargs