mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Cache even more work for return_and_correct_aliasing (#166365)
Yet another pass found even more work we can move to be done only once. This seems to knock a few microseconds off the DTensor dispatch fast path. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166365 Approved by: https://github.com/bdhirsh
This commit is contained in:
parent
239e7b541a
commit
0947765eb9
|
|
@ -5,7 +5,7 @@ import warnings
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, overload, Protocol, Union
|
from typing import cast, Optional, overload, Protocol, Union
|
||||||
from typing_extensions import TypeIs
|
from typing_extensions import TypeIs
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -601,13 +601,20 @@ def _correct_storage_aliasing(func, schema_info, args, outs):
|
||||||
raise AssertionError(f"expected torch.Tensor, got {type(ret)}")
|
raise AssertionError(f"expected torch.Tensor, got {type(ret)}")
|
||||||
torch._functionalize_unsafe_set(ret, arg)
|
torch._functionalize_unsafe_set(ret, arg)
|
||||||
|
|
||||||
for arg_idx, schema_arg in enumerate(schema_info.args):
|
for arg_idx, return_idx in schema_info.read_only_alias_match_indexes:
|
||||||
for return_idx, schema_out in enumerate(schema_info.outs):
|
alias_non_inplace_storage(args[arg_idx], outs[return_idx])
|
||||||
is_read_only_alias_match = (
|
|
||||||
schema_arg.alias_set & schema_out.alias_set
|
|
||||||
) and not schema_arg.is_write
|
def _get_write_alias(x) -> Optional[str]:
|
||||||
if is_read_only_alias_match:
|
alias_set = x.alias_set
|
||||||
alias_non_inplace_storage(args[arg_idx], outs[return_idx])
|
if not alias_set or not x.is_write:
|
||||||
|
return None
|
||||||
|
# torchscript allows for complicated alias sets, but our dispatcher ops only really involve simple aliasing
|
||||||
|
if len(alias_set) != 1:
|
||||||
|
raise AssertionError("Expected alias_set to contain exactly one element")
|
||||||
|
# timeit says next(iter(alias_set)) is faster than list(alias_set)[0] even for
|
||||||
|
# set of size 1 on Python 3.13.
|
||||||
|
return next(iter(alias_set))
|
||||||
|
|
||||||
|
|
||||||
# This abstracts over the fact that in return_and_correct_aliasing,
|
# This abstracts over the fact that in return_and_correct_aliasing,
|
||||||
|
|
@ -625,13 +632,16 @@ class SchemaInfo:
|
||||||
args: list[AliasInfo]
|
args: list[AliasInfo]
|
||||||
outs: list[AliasInfo]
|
outs: list[AliasInfo]
|
||||||
|
|
||||||
# NOTE[SchemaInfo int_tags]: This has nothing to do with aliasing, but we take
|
is_inplace_view_op: bool
|
||||||
# advantage of our existing caching of data for each OpOverload to paper over an
|
|
||||||
# efficiency problem with pybind11::enum_ (which currently is used to implement
|
# [_get_write_alias(x) for x in outs]. Guaranteed to contain no Nones; we coerce
|
||||||
# torch.Tag): a scan over a list of pybind enums using `in` is inefficient because
|
# all-Nones result to None instead, and we don't support
|
||||||
# each element must be converted to int with the __int__ method, which incurs a lot
|
# some-but-not-all-Nones.
|
||||||
# of overhead. Converting to int once and caching removes this per-op overhead.
|
outs_write_aliases: Optional[list[str]]
|
||||||
int_tags: list[int]
|
|
||||||
|
# List of (arg_idx, return_idx) where args[arg_idx].alias_set &
|
||||||
|
# outs[return_idx].alias_set is not empty, and not args[arg_idx].is_write.
|
||||||
|
read_only_alias_match_indexes: list[tuple[int, int]]
|
||||||
|
|
||||||
|
|
||||||
# Given an OpOverload, returns schema information on it.
|
# Given an OpOverload, returns schema information on it.
|
||||||
|
|
@ -702,16 +712,39 @@ def get_alias_info(func) -> SchemaInfo:
|
||||||
)
|
)
|
||||||
for a in func._schema.returns
|
for a in func._schema.returns
|
||||||
]
|
]
|
||||||
|
read_only_alias_match_indexes = []
|
||||||
|
for arg_idx, schema_arg in enumerate(arg_schemas):
|
||||||
|
for return_idx, schema_out in enumerate(out_schemas):
|
||||||
|
is_read_only_alias_match = (
|
||||||
|
schema_arg.alias_set & schema_out.alias_set
|
||||||
|
) and not schema_arg.is_write
|
||||||
|
if is_read_only_alias_match:
|
||||||
|
read_only_alias_match_indexes.append((arg_idx, return_idx))
|
||||||
|
|
||||||
|
outs_write_aliases_list: list[Optional[str]] = [
|
||||||
|
_get_write_alias(r) for r in out_schemas
|
||||||
|
]
|
||||||
|
non_nones = sum(x is not None for x in outs_write_aliases_list)
|
||||||
|
if non_nones == 0:
|
||||||
|
outs_write_aliases: Optional[list[str]] = None
|
||||||
|
elif non_nones != len(outs_write_aliases_list):
|
||||||
|
# simplifying assumption: we don't have **any** ops with return types like "-> (Tensor(a!), Tensor)"
|
||||||
|
raise RuntimeError("Unsupported schema: " + str(func._schema))
|
||||||
|
else:
|
||||||
|
outs_write_aliases = cast(list[str], outs_write_aliases_list)
|
||||||
|
|
||||||
schema_info = SchemaInfo(
|
schema_info = SchemaInfo(
|
||||||
args=arg_schemas, outs=out_schemas, int_tags=[int(x) for x in func.tags]
|
args=arg_schemas,
|
||||||
|
outs=out_schemas,
|
||||||
|
# This check is surprisingly expensive because pybind11 enum_s are
|
||||||
|
# inefficient. Just cache it.
|
||||||
|
is_inplace_view_op=torch.Tag.inplace_view in func.tags,
|
||||||
|
outs_write_aliases=outs_write_aliases,
|
||||||
|
read_only_alias_match_indexes=read_only_alias_match_indexes,
|
||||||
)
|
)
|
||||||
return schema_info
|
return schema_info
|
||||||
|
|
||||||
|
|
||||||
# See NOTE[SchemaInfo int_tags] above.
|
|
||||||
_TORCH_TAG_INPLACE_VIEW_INT = int(torch.Tag.inplace_view) # type: ignore[call-overload]
|
|
||||||
|
|
||||||
|
|
||||||
def return_and_correct_aliasing(func, args, kwargs, out):
|
def return_and_correct_aliasing(func, args, kwargs, out):
|
||||||
"""
|
"""
|
||||||
This function should be used by wrapper tensor ``__torch_dispatch__`` subclasses
|
This function should be used by wrapper tensor ``__torch_dispatch__`` subclasses
|
||||||
|
|
@ -732,17 +765,6 @@ def return_and_correct_aliasing(func, args, kwargs, out):
|
||||||
# once for every op in the graph during functionalization.
|
# once for every op in the graph during functionalization.
|
||||||
schema_info = get_alias_info(func)
|
schema_info = get_alias_info(func)
|
||||||
|
|
||||||
def get_write_alias(x):
|
|
||||||
alias_set = x.alias_set
|
|
||||||
if not alias_set or not x.is_write:
|
|
||||||
return None
|
|
||||||
# torchscript allows for complicated alias sets, but our dispatcher ops only really involve simple aliasing
|
|
||||||
if len(alias_set) != 1:
|
|
||||||
raise AssertionError("Expected alias_set to contain exactly one element")
|
|
||||||
# timeit says next(iter(alias_set)) is faster than list(alias_set)[0] even for
|
|
||||||
# set of size 1 on Python 3.13.
|
|
||||||
return next(iter(alias_set))
|
|
||||||
|
|
||||||
def get_arg_from_alias(output_alias, schema_info, args, kwargs):
|
def get_arg_from_alias(output_alias, schema_info, args, kwargs):
|
||||||
new_args, new_kwargs = torch.fx.operator_schemas.normalize_function( # type: ignore[misc]
|
new_args, new_kwargs = torch.fx.operator_schemas.normalize_function( # type: ignore[misc]
|
||||||
func, args=args, kwargs=kwargs
|
func, args=args, kwargs=kwargs
|
||||||
|
|
@ -770,14 +792,13 @@ def return_and_correct_aliasing(func, args, kwargs, out):
|
||||||
|
|
||||||
# For inplace_view ops in particular, we'll try hard to make sure that the wrapper subclass's
|
# For inplace_view ops in particular, we'll try hard to make sure that the wrapper subclass's
|
||||||
# metadata is set correctly.
|
# metadata is set correctly.
|
||||||
# See NOTE[SchemaInfo int_tags] above.
|
if schema_info.is_inplace_view_op:
|
||||||
if _TORCH_TAG_INPLACE_VIEW_INT in schema_info.int_tags:
|
|
||||||
# no_dispatch() to make sure that we secretly change the metadata on the wrapper,
|
# no_dispatch() to make sure that we secretly change the metadata on the wrapper,
|
||||||
# but don't end up dispatching the op anywhere else.
|
# but don't end up dispatching the op anywhere else.
|
||||||
mutated_args = [
|
mutated_args = [
|
||||||
x
|
x
|
||||||
for i, x in enumerate(args)
|
for i, x in enumerate(args)
|
||||||
if get_write_alias(schema_info.args[i]) is not None
|
if _get_write_alias(schema_info.args[i]) is not None
|
||||||
]
|
]
|
||||||
# Assumption: we have a very small number of inplace_view ops that follow a strict schema:
|
# Assumption: we have a very small number of inplace_view ops that follow a strict schema:
|
||||||
# there is only a single argument that gets its metadata mutated.
|
# there is only a single argument that gets its metadata mutated.
|
||||||
|
|
@ -803,16 +824,11 @@ def return_and_correct_aliasing(func, args, kwargs, out):
|
||||||
|
|
||||||
# Next: we need to make sure to return inputs directly, if the output is a mutable alias (e.g. add_()).
|
# Next: we need to make sure to return inputs directly, if the output is a mutable alias (e.g. add_()).
|
||||||
|
|
||||||
# Compute write aliases once instead of repeatedly.
|
schema_info_outs_write_aliases = schema_info.outs_write_aliases
|
||||||
schema_info_outs_write_aliases = [get_write_alias(r) for r in schema_info.outs]
|
|
||||||
# simple case: none of our outputs have mutable aliases, so we can return the output as-is
|
# simple case: none of our outputs have mutable aliases, so we can return the output as-is
|
||||||
if not any(x is not None for x in schema_info_outs_write_aliases):
|
if schema_info_outs_write_aliases is None:
|
||||||
return out
|
return out
|
||||||
|
|
||||||
# simplifying assumption: we don't have **any** ops with return types like "-> (Tensor(a!), Tensor)"
|
|
||||||
if not all(x is not None for x in schema_info_outs_write_aliases):
|
|
||||||
raise RuntimeError("Unsupported schema: " + str(func._schema))
|
|
||||||
|
|
||||||
if len(schema_info_outs_write_aliases) == 1:
|
if len(schema_info_outs_write_aliases) == 1:
|
||||||
return get_arg_from_alias(
|
return get_arg_from_alias(
|
||||||
schema_info_outs_write_aliases[0], schema_info, args, kwargs
|
schema_info_outs_write_aliases[0], schema_info, args, kwargs
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user