Cache even more work for return_and_correct_aliasing (#166365)

Yet another pass found even more work that we can move so it is done only once per op and cached. This seems to knock a few microseconds off the DTensor dispatch fast path.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166365
Approved by: https://github.com/bdhirsh
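The pattern behind the change, as a minimal self-contained sketch (the names AliasArg, FakeOp, CachedOpInfo, and get_cached_op_info are hypothetical stand-ins, not the actual PyTorch code): anything that depends only on an operator's schema and tags is computed once, memoized per op object, and merely looked up on the dispatch fast path.

import functools
from dataclasses import dataclass


@dataclass(frozen=True)
class AliasArg:
    alias_set: frozenset[str]
    is_write: bool


@dataclass(frozen=True)
class FakeOp:
    # Stand-in for an OpOverload: hashable, with schema-like fields and tags.
    name: str
    args: tuple[AliasArg, ...]
    outs: tuple[AliasArg, ...]
    tags: frozenset[str]


@dataclass(frozen=True)
class CachedOpInfo:
    # (arg_idx, out_idx) pairs whose alias sets overlap and whose arg is not written.
    read_only_alias_match_indexes: tuple[tuple[int, int], ...]
    # Whether the op carries the inplace_view tag; checked once, not per call.
    is_inplace_view_op: bool


@functools.cache  # op objects are long-lived singletons, so a per-op cache is safe
def get_cached_op_info(op: FakeOp) -> CachedOpInfo:
    matches = tuple(
        (i, j)
        for i, a in enumerate(op.args)
        for j, o in enumerate(op.outs)
        if (a.alias_set & o.alias_set) and not a.is_write
    )
    return CachedOpInfo(matches, "inplace_view" in op.tags)


view_op = FakeOp(
    name="view",
    args=(AliasArg(frozenset({"a"}), is_write=False),),
    outs=(AliasArg(frozenset({"a"}), is_write=False),),
    tags=frozenset(),
)
# The first call pays for the nested scan; later calls on the hot path are a dict hit.
print(get_cached_op_info(view_op).read_only_alias_match_indexes)  # ((0, 0),)

In the diff below, the cached container is the SchemaInfo dataclass, and the newly precomputed fields are is_inplace_view_op, outs_write_aliases, and read_only_alias_match_indexes.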
Authored by Scott Wolchok, 2025-10-30 09:34:35 -07:00; committed by PyTorch MergeBot
parent 239e7b541a
commit 0947765eb9


@@ -5,7 +5,7 @@ import warnings
 from collections import deque
 from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import Optional, overload, Protocol, Union
+from typing import cast, Optional, overload, Protocol, Union
 from typing_extensions import TypeIs
 
 import torch
@@ -601,15 +601,22 @@ def _correct_storage_aliasing(func, schema_info, args, outs):
                 raise AssertionError(f"expected torch.Tensor, got {type(ret)}")
         torch._functionalize_unsafe_set(ret, arg)
 
-    for arg_idx, schema_arg in enumerate(schema_info.args):
-        for return_idx, schema_out in enumerate(schema_info.outs):
-            is_read_only_alias_match = (
-                schema_arg.alias_set & schema_out.alias_set
-            ) and not schema_arg.is_write
-            if is_read_only_alias_match:
-                alias_non_inplace_storage(args[arg_idx], outs[return_idx])
+    for arg_idx, return_idx in schema_info.read_only_alias_match_indexes:
+        alias_non_inplace_storage(args[arg_idx], outs[return_idx])
 
 
+def _get_write_alias(x) -> Optional[str]:
+    alias_set = x.alias_set
+    if not alias_set or not x.is_write:
+        return None
+    # torchscript allows for complicated alias sets, but our dispatcher ops only really involve simple aliasing
+    if len(alias_set) != 1:
+        raise AssertionError("Expected alias_set to contain exactly one element")
+    # timeit says next(iter(alias_set)) is faster than list(alias_set)[0] even for
+    # set of size 1 on Python 3.13.
+    return next(iter(alias_set))
+
+
 # This abstracts over the fact that in return_and_correct_aliasing,
 # we sometimes use torchgen schema parsing (for aten ops, since torchscript's schema parsing is sometimes buggy),
 # and sometimes use torchscript schema parsing (for custom ops, for which torchgen parsing is untested).
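The timeit claim in the comment above is easy to reproduce; a minimal micro-benchmark (the exact numbers vary by machine and Python build, so treat this as a sketch rather than a guarantee):

import timeit

alias_set = {"a"}  # a singleton alias set, as in _get_write_alias
print("next(iter(s)):", timeit.timeit(lambda: next(iter(alias_set)), number=1_000_000))
print("list(s)[0]:   ", timeit.timeit(lambda: list(alias_set)[0], number=1_000_000))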
@@ -625,13 +632,16 @@ class SchemaInfo:
     args: list[AliasInfo]
     outs: list[AliasInfo]
-    # NOTE[SchemaInfo int_tags]: This has nothing to do with aliasing, but we take
-    # advantage of our existing caching of data for each OpOverload to paper over an
-    # efficiency problem with pybind11::enum_ (which currently is used to implement
-    # torch.Tag): a scan over a list of pybind enums using `in` is inefficient because
-    # each element must be converted to int with the __int__ method, which incurs a lot
-    # of overhead. Converting to int once and caching removes this per-op overhead.
-    int_tags: list[int]
+    is_inplace_view_op: bool
+    # [_get_write_alias(x) for x in outs]. Guaranteed to contain no Nones; we coerce
+    # an all-Nones result to None instead, and we don't support
+    # some-but-not-all-Nones.
+    outs_write_aliases: Optional[list[str]]
+    # List of (arg_idx, return_idx) where args[arg_idx].alias_set &
+    # outs[return_idx].alias_set is not empty, and not args[arg_idx].is_write.
+    read_only_alias_match_indexes: list[tuple[int, int]]
 
 
 # Given an OpOverload, returns schema information on it.
@@ -702,16 +712,39 @@ def get_alias_info(func) -> SchemaInfo:
         )
         for a in func._schema.returns
     ]
+
+    read_only_alias_match_indexes = []
+    for arg_idx, schema_arg in enumerate(arg_schemas):
+        for return_idx, schema_out in enumerate(out_schemas):
+            is_read_only_alias_match = (
+                schema_arg.alias_set & schema_out.alias_set
+            ) and not schema_arg.is_write
+            if is_read_only_alias_match:
+                read_only_alias_match_indexes.append((arg_idx, return_idx))
+
+    outs_write_aliases_list: list[Optional[str]] = [
+        _get_write_alias(r) for r in out_schemas
+    ]
+    non_nones = sum(x is not None for x in outs_write_aliases_list)
+    if non_nones == 0:
+        outs_write_aliases: Optional[list[str]] = None
+    elif non_nones != len(outs_write_aliases_list):
+        # simplifying assumption: we don't have **any** ops with return types like "-> (Tensor(a!), Tensor)"
+        raise RuntimeError("Unsupported schema: " + str(func._schema))
+    else:
+        outs_write_aliases = cast(list[str], outs_write_aliases_list)
+
     schema_info = SchemaInfo(
-        args=arg_schemas, outs=out_schemas, int_tags=[int(x) for x in func.tags]
+        args=arg_schemas,
+        outs=out_schemas,
+        # This check is surprisingly expensive because pybind11 enum_s are
+        # inefficient. Just cache it.
+        is_inplace_view_op=torch.Tag.inplace_view in func.tags,
+        outs_write_aliases=outs_write_aliases,
+        read_only_alias_match_indexes=read_only_alias_match_indexes,
     )
     return schema_info
 
 
-# See NOTE[SchemaInfo int_tags] above.
-_TORCH_TAG_INPLACE_VIEW_INT = int(torch.Tag.inplace_view)  # type: ignore[call-overload]
-
-
 def return_and_correct_aliasing(func, args, kwargs, out):
     """
     This function should be used by wrapper tensor ``__torch_dispatch__`` subclasses
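A quick way to inspect the newly cached fields, assuming a PyTorch build that includes this change and that get_alias_info is importable from torch.utils._python_dispatch (the module this diff appears to touch):

import torch
from torch.utils._python_dispatch import get_alias_info  # assumed import path

# add.Tensor has no mutable outputs, so the cached value should be None and
# return_and_correct_aliasing can take the early-return path.
print(get_alias_info(torch.ops.aten.add.Tensor).outs_write_aliases)

# add_.Tensor writes to and returns `self` (alias annotation `a!`), so its
# single write alias (e.g. ['a']) is precomputed here instead of on every call.
print(get_alias_info(torch.ops.aten.add_.Tensor).outs_write_aliases)

# add_ is an in-place op but not an inplace_view op, so this should be False.
print(get_alias_info(torch.ops.aten.add_.Tensor).is_inplace_view_op)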
@@ -732,17 +765,6 @@ def return_and_correct_aliasing(func, args, kwargs, out):
     # once for every op in the graph during functionalization.
     schema_info = get_alias_info(func)
 
-    def get_write_alias(x):
-        alias_set = x.alias_set
-        if not alias_set or not x.is_write:
-            return None
-        # torchscript allows for complicated alias sets, but our dispatcher ops only really involve simple aliasing
-        if len(alias_set) != 1:
-            raise AssertionError("Expected alias_set to contain exactly one element")
-        # timeit says next(iter(alias_set)) is faster than list(alias_set)[0] even for
-        # set of size 1 on Python 3.13.
-        return next(iter(alias_set))
-
     def get_arg_from_alias(output_alias, schema_info, args, kwargs):
         new_args, new_kwargs = torch.fx.operator_schemas.normalize_function(  # type: ignore[misc]
             func, args=args, kwargs=kwargs
@@ -770,14 +792,13 @@ def return_and_correct_aliasing(func, args, kwargs, out):
     # For inplace_view ops in particular, we'll try hard to make sure that the wrapper subclass's
     # metadata is set correctly.
-    # See NOTE[SchemaInfo int_tags] above.
-    if _TORCH_TAG_INPLACE_VIEW_INT in schema_info.int_tags:
+    if schema_info.is_inplace_view_op:
         # no_dispatch() to make sure that we secretly change the metadata on the wrapper,
         # but don't end up dispatching the op anywhere else.
         mutated_args = [
             x
             for i, x in enumerate(args)
-            if get_write_alias(schema_info.args[i]) is not None
+            if _get_write_alias(schema_info.args[i]) is not None
         ]
         # Assumption: we have a very small number of inplace_view ops that follow a strict schema:
         # there is only a single argument that gets its metadata mutated.
@@ -803,16 +824,11 @@ def return_and_correct_aliasing(func, args, kwargs, out):
     # Next: we need to make sure to return inputs directly, if the output is a mutable alias (e.g. add_()).
-    # Compute write aliases once instead of repeatedly.
-    schema_info_outs_write_aliases = [get_write_alias(r) for r in schema_info.outs]
+    schema_info_outs_write_aliases = schema_info.outs_write_aliases
     # simple case: none of our outputs have mutable aliases, so we can return the output as-is
-    if not any(x is not None for x in schema_info_outs_write_aliases):
+    if schema_info_outs_write_aliases is None:
         return out
 
-    # simplifying assumption: we don't have **any** ops with return types like "-> (Tensor(a!), Tensor)"
-    if not all(x is not None for x in schema_info_outs_write_aliases):
-        raise RuntimeError("Unsupported schema: " + str(func._schema))
-
     if len(schema_info_outs_write_aliases) == 1:
         return get_arg_from_alias(
             schema_info_outs_write_aliases[0], schema_info, args, kwargs