Cache even more work for return_and_correct_aliasing (#166365)

Yet another pass turned up more work that can be done once per OpOverload and cached. This seems to knock a few microseconds off the DTensor dispatch fast path.
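For readers outside the diff, the gist is to hoist per-op schema analysis off the dispatch hot path and memoize it per OpOverload. A minimal standalone sketch of that pattern, using only torch's schema introspection (`_alias_matches` is a hypothetical name for illustration, not the helper this PR adds):

    from functools import lru_cache

    import torch

    @lru_cache(maxsize=None)
    def _alias_matches(op):
        # Computed once per OpOverload; every later dispatch of the same op
        # is served from the cache.
        args, rets = op._schema.arguments, op._schema.returns
        matches = []
        for i, a in enumerate(args):
            if a.alias_info is None or a.alias_info.is_write:
                continue  # only read-only aliases qualify
            for j, r in enumerate(rets):
                if r.alias_info is not None and set(a.alias_info.before_set) & set(
                    r.alias_info.before_set
                ):
                    matches.append((i, j))
        return matches

    # view(Tensor(a) self, SymInt[] size) -> Tensor(a): output 0 aliases arg 0
    print(_alias_matches(torch.ops.aten.view.default))  # [(0, 0)]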
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166365
Approved by: https://github.com/bdhirsh
Scott Wolchok 2025-10-30 09:34:35 -07:00 committed by PyTorch MergeBot
parent 239e7b541a
commit 0947765eb9

torch/utils/_python_dispatch.py

@@ -5,7 +5,7 @@ import warnings
 from collections import deque
 from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import Optional, overload, Protocol, Union
+from typing import cast, Optional, overload, Protocol, Union
 from typing_extensions import TypeIs

 import torch
@@ -601,13 +601,20 @@ def _correct_storage_aliasing(func, schema_info, args, outs):
             raise AssertionError(f"expected torch.Tensor, got {type(ret)}")
         torch._functionalize_unsafe_set(ret, arg)

-    for arg_idx, schema_arg in enumerate(schema_info.args):
-        for return_idx, schema_out in enumerate(schema_info.outs):
-            is_read_only_alias_match = (
-                schema_arg.alias_set & schema_out.alias_set
-            ) and not schema_arg.is_write
-            if is_read_only_alias_match:
-                alias_non_inplace_storage(args[arg_idx], outs[return_idx])
+    for arg_idx, return_idx in schema_info.read_only_alias_match_indexes:
+        alias_non_inplace_storage(args[arg_idx], outs[return_idx])
+
+
+def _get_write_alias(x) -> Optional[str]:
+    alias_set = x.alias_set
+    if not alias_set or not x.is_write:
+        return None
+    # torchscript allows for complicated alias sets, but our dispatcher ops only really involve simple aliasing
+    if len(alias_set) != 1:
+        raise AssertionError("Expected alias_set to contain exactly one element")
+    # timeit says next(iter(alias_set)) is faster than list(alias_set)[0] even for
+    # set of size 1 on Python 3.13.
+    return next(iter(alias_set))


 # This abstracts over the fact that in return_and_correct_aliasing,
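To see the newly module-level `_get_write_alias` in action (these are private internals, so the import path may shift, but as of this commit the following should hold):

    from torch.utils._python_dispatch import AliasInfo, _get_write_alias

    # Tensor(a!) self in add_(Tensor(a!) self, ...): written-to alias "a"
    print(_get_write_alias(AliasInfo(alias_set={"a"}, is_write=True, name="self")))   # a
    # A plain Tensor argument has no alias set, hence no write alias
    print(_get_write_alias(AliasInfo(alias_set=set(), is_write=False, name="other"))) # None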
@@ -625,13 +632,16 @@ class SchemaInfo:
     args: list[AliasInfo]
     outs: list[AliasInfo]
-    # NOTE[SchemaInfo int_tags]: This has nothing to do with aliasing, but we take
-    # advantage of our existing caching of data for each OpOverload to paper over an
-    # efficiency problem with pybind11::enum_ (which currently is used to implement
-    # torch.Tag): a scan over a list of pybind enums using `in` is inefficient because
-    # each element must be converted to int with the __int__ method, which incurs a lot
-    # of overhead. Converting to int once and caching removes this per-op overhead.
-    int_tags: list[int]
+    is_inplace_view_op: bool
+    # [_get_write_alias(x) for x in outs]. Guaranteed to contain no Nones; we coerce
+    # an all-Nones result to None instead, and we don't support
+    # some-but-not-all-Nones.
+    outs_write_aliases: Optional[list[str]]
+    # List of (arg_idx, return_idx) where args[arg_idx].alias_set &
+    # outs[return_idx].alias_set is not empty, and not args[arg_idx].is_write.
+    read_only_alias_match_indexes: list[tuple[int, int]]


 # Given an OpOverload, returns schema information on it.
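A quick way to see what lands in these new fields (again private API, subject to change):

    import torch
    from torch.utils._python_dispatch import get_alias_info

    # add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
    info = get_alias_info(torch.ops.aten.add_.Tensor)
    print(info.outs_write_aliases)             # ['a']: the output writes through `self`
    print(info.is_inplace_view_op)             # False: mutates data, not metadata

    # view(Tensor(a) self, SymInt[] size) -> Tensor(a)
    info = get_alias_info(torch.ops.aten.view.default)
    print(info.outs_write_aliases)             # None: no mutable aliases
    print(info.read_only_alias_match_indexes)  # [(0, 0)]: output 0 read-only-aliases arg 0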
@@ -702,16 +712,39 @@ def get_alias_info(func) -> SchemaInfo:
         )
         for a in func._schema.returns
     ]
+    read_only_alias_match_indexes = []
+    for arg_idx, schema_arg in enumerate(arg_schemas):
+        for return_idx, schema_out in enumerate(out_schemas):
+            is_read_only_alias_match = (
+                schema_arg.alias_set & schema_out.alias_set
+            ) and not schema_arg.is_write
+            if is_read_only_alias_match:
+                read_only_alias_match_indexes.append((arg_idx, return_idx))
+
+    outs_write_aliases_list: list[Optional[str]] = [
+        _get_write_alias(r) for r in out_schemas
+    ]
+    non_nones = sum(x is not None for x in outs_write_aliases_list)
+    if non_nones == 0:
+        outs_write_aliases: Optional[list[str]] = None
+    elif non_nones != len(outs_write_aliases_list):
+        # simplifying assumption: we don't have **any** ops with return types like "-> (Tensor(a!), Tensor)"
+        raise RuntimeError("Unsupported schema: " + str(func._schema))
+    else:
+        outs_write_aliases = cast(list[str], outs_write_aliases_list)
     schema_info = SchemaInfo(
-        args=arg_schemas, outs=out_schemas, int_tags=[int(x) for x in func.tags]
+        args=arg_schemas,
+        outs=out_schemas,
+        # This check is surprisingly expensive because pybind11 enum_s are
+        # inefficient. Just cache it.
+        is_inplace_view_op=torch.Tag.inplace_view in func.tags,
+        outs_write_aliases=outs_write_aliases,
+        read_only_alias_match_indexes=read_only_alias_match_indexes,
     )
     return schema_info


-# See NOTE[SchemaInfo int_tags] above.
-_TORCH_TAG_INPLACE_VIEW_INT = int(torch.Tag.inplace_view)  # type: ignore[call-overload]
-
-
 def return_and_correct_aliasing(func, args, kwargs, out):
     """
     This function should be used by wrapper tensor ``__torch_dispatch__`` subclasses
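To make the "surprisingly expensive" claim concrete, a rough micro-benchmark sketch (absolute numbers will vary by machine and build):

    import timeit

    import torch

    func = torch.ops.aten.add_.Tensor
    cached = torch.Tag.inplace_view in func.tags  # what SchemaInfo now stores

    # Each `in` scan compares pybind11 enum values one by one; the cached
    # bool skips all of that on the hot path.
    scan = timeit.timeit(lambda: torch.Tag.inplace_view in func.tags, number=100_000)
    hit = timeit.timeit(lambda: cached, number=100_000)
    print(f"tag scan: {scan:.4f}s  cached bool: {hit:.4f}s")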
@@ -732,17 +765,6 @@ def return_and_correct_aliasing(func, args, kwargs, out):
     # once for every op in the graph during functionalization.
     schema_info = get_alias_info(func)

-    def get_write_alias(x):
-        alias_set = x.alias_set
-        if not alias_set or not x.is_write:
-            return None
-        # torchscript allows for complicated alias sets, but our dispatcher ops only really involve simple aliasing
-        if len(alias_set) != 1:
-            raise AssertionError("Expected alias_set to contain exactly one element")
-        # timeit says next(iter(alias_set)) is faster than list(alias_set)[0] even for
-        # set of size 1 on Python 3.13.
-        return next(iter(alias_set))
-
     def get_arg_from_alias(output_alias, schema_info, args, kwargs):
         new_args, new_kwargs = torch.fx.operator_schemas.normalize_function(  # type: ignore[misc]
             func, args=args, kwargs=kwargs
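The lookup `get_arg_from_alias` performs can be approximated from the cached schema info: find the argument whose alias set contains the output's write alias. A hedged sketch (skipping the kwarg normalization the real helper does):

    import torch
    from torch.utils._python_dispatch import get_alias_info

    info = get_alias_info(torch.ops.aten.add_.Tensor)
    out_alias = info.outs_write_aliases[0]  # "a"
    # The argument sharing that alias is the one to return (here: `self`)
    idx = next(i for i, a in enumerate(info.args) if out_alias in a.alias_set)
    print(idx, info.args[idx].name)  # 0 self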
@@ -770,14 +792,13 @@ def return_and_correct_aliasing(func, args, kwargs, out):
     # For inplace_view ops in particular, we'll try hard to make sure that the wrapper subclass's
     # metadata is set correctly.
-    # See NOTE[SchemaInfo int_tags] above.
-    if _TORCH_TAG_INPLACE_VIEW_INT in schema_info.int_tags:
+    if schema_info.is_inplace_view_op:
         # no_dispatch() to make sure that we secretly change the metadata on the wrapper,
         # but don't end up dispatching the op anywhere else.
         mutated_args = [
             x
             for i, x in enumerate(args)
-            if get_write_alias(schema_info.args[i]) is not None
+            if _get_write_alias(schema_info.args[i]) is not None
         ]
         # Assumption: we have a very small number of inplace_view ops that follow a strict schema:
         # there is only a single argument that gets its metadata mutated.
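For reference, the ops that take this branch are the handful tagged `inplace_view` in native_functions.yaml, e.g. `transpose_`, which mutates metadata rather than data:

    import torch
    from torch.utils._python_dispatch import get_alias_info

    # transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!)
    print(get_alias_info(torch.ops.aten.transpose_.default).is_inplace_view_op)  # True
    print(get_alias_info(torch.ops.aten.add_.Tensor).is_inplace_view_op)         # False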
@@ -803,16 +824,11 @@ def return_and_correct_aliasing(func, args, kwargs, out):
     # Next: we need to make sure to return inputs directly, if the output is a mutable alias (e.g. add_()).

-    # Compute write aliases once instead of repeatedly.
-    schema_info_outs_write_aliases = [get_write_alias(r) for r in schema_info.outs]
+    schema_info_outs_write_aliases = schema_info.outs_write_aliases
     # simple case: none of our outputs have mutable aliases, so we can return the output as-is
-    if not any(x is not None for x in schema_info_outs_write_aliases):
+    if schema_info_outs_write_aliases is None:
         return out

-    # simplifying assumption: we don't have **any** ops with return types like "-> (Tensor(a!), Tensor)"
-    if not all(x is not None for x in schema_info_outs_write_aliases):
-        raise RuntimeError("Unsupported schema: " + str(func._schema))
-
     if len(schema_info_outs_write_aliases) == 1:
         return get_arg_from_alias(
             schema_info_outs_write_aliases[0], schema_info, args, kwargs