Fast path binary ops in fake tensor (#94047)
Fast path execution of a few binary ops in fake tensor, to speed up trace time. When testing `python benchmarks/dynamo/timm_models.py --accuracy --timing --backend aot_eager --dynamic-shapes --float32 --only hrnet_w18`, I get the following trace speedup.

Before:

```
cuda eval hrnet_w18 PASS
TIMING: entire_frame_compile:53.97591 backend_compile:33.60832
STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:4995 | FakeTensorMode.__torch_dispatch__:89985 | ProxyTorchDispatchMode.__torch_dispatch__:3010
```

After:

```
cuda eval hrnet_w18 PASS
TIMING: entire_frame_compile:40.18931 backend_compile:25.28828
STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:4995 | FakeTensorMode.__torch_dispatch__:69478 | attempt fast:4399 | fast is_contiguous:4399 | ProxyTorchDispatchMode.__torch_dispatch__:3010
```

My experiment notebook can be found at https://docs.google.com/document/d/1_dTIQUwjIVnEWmiFAavJQYVF8uzXqD9Dk6b9gGQLF_U/edit#

This is not the "most" optimized version of the code; compare with Horace/Voz's roofline experiment:

```diff
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index e3bf545f3b8..395942c6ffe 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -774,6 +774,10 @@ class FakeTensorMode(TorchDispatchMode):
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         kwargs = kwargs if kwargs else {}

+        with no_dispatch():
+            if func in {aten.mul.Tensor, aten.add.Tensor, aten.sub.Tensor, aten.relu.default}:
+                return FakeTensor(self, torch.empty(args[0].shape, device='meta'), device='cuda')
+
         if func == torch.ops.prim.device.default:
             assert len(args) == 1 and isinstance(args[0], FakeTensor)
             if args[0].fake_mode.in_kernel_invocation:
```

I am still leaving about 5s of trace time improvement on the table (3s of which is attributable to not yet handling relu).

The implementation here is based off of https://github.com/pytorch/pytorch/pull/93118/, but I modeled the short-circuit logic off of TensorIterator's implementation, for ease of code review and correctness verification. However, there are some important divergences:

* Traditional fast setup in TensorIterator only short circuits if the shapes of all input elements are equal. On hrnet_w18, only 5% of fastpath'ed binary operators actually satisfy this. So instead, I compute the broadcasted shape, but then I only allow the fast path if (1) at least one input tensor has a shape that is exactly the output size, and (2) all the tensors are contiguous (or all the tensors are channels last).
* I had to manually adjust the logic to handle wrapped numbers (which ordinarily are handled by wrapping into tensors). I think I got this right. Some evidence that this heuristic is correct is at https://gist.github.com/ezyang/b22fa7b72b7349137211d8dc7041f758, where I exhaustively test all dim=3 tensors with sizes in [1, 2] and show that PrimTorch and the new algorithm produce the same significant strides. There ARE cases where this algorithm differs from PrimTorch, but in those cases it agrees with TensorIterator where PrimTorch is wrong (sample case: size=(1, 1, 2), stride=(1, 1, 1), stride=(1, 1, 1)).

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94047
Approved by: https://github.com/eellison
This commit is contained in: parent 0603f4ff14, commit d690a596dc
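The eligibility rule in the first bullet of the commit message can be distilled into a few lines of plain PyTorch. The sketch below is illustrative only: `fast_path_eligible` is a hypothetical helper (the real logic lives in `make_fast_binary_impl` in the diff below), and it uses eager tensors rather than fake tensors.

```python
import torch

def fast_path_eligible(*tensors):
    # Hypothetical helper mirroring the heuristic described above.
    out_shape = torch.broadcast_shapes(*(t.shape for t in tensors))
    # (1) at least one input already has exactly the broadcasted output shape
    if not any(t.shape == out_shape for t in tensors):
        return False
    # (2) all inputs are contiguous, or all are channels-last contiguous
    all_contiguous = all(t.is_contiguous() for t in tensors)
    all_channels_last = all(
        t.is_contiguous(memory_format=torch.channels_last) for t in tensors
    )
    return all_contiguous or all_channels_last

a = torch.randn(8, 1, 4, 4)
b = torch.randn(8, 3, 4, 4)
print(fast_path_eligible(a, b))                     # True: b already has the output shape
print(fast_path_eligible(a.expand(8, 3, 4, 4), b))  # False: the expanded view is not contiguous
```

When neither condition holds, the patch falls back to the existing PrimTorch reference (`slow_ref`), so correctness never depends on this heuristic.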
torch/_subclasses/fake_tensor.py

```diff
@@ -1,6 +1,7 @@
 import contextlib
 import functools
 import itertools
+import logging
 import os
 import weakref
 from dataclasses import dataclass
@@ -11,7 +12,12 @@ from weakref import ReferenceType
 import torch
 from torch._guards import Source
 from torch._ops import OpOverload
-from torch._prims_common import is_float_dtype, is_integer_dtype
+from torch._prims_common import (
+    elementwise_dtypes,
+    ELEMENTWISE_TYPE_PROMOTION_KIND,
+    is_float_dtype,
+    is_integer_dtype,
+)
 from torch._subclasses.meta_utils import MetaConverter
 from torch.fx.operator_schemas import normalize_function
 from torch.multiprocessing.reductions import StorageWeakRef
@@ -20,9 +26,11 @@ from torch.utils._mode_utils import no_dispatch
 from torch.utils._python_dispatch import TorchDispatchMode

 from torch.utils._pytree import PyTree, tree_flatten, tree_map, tree_map_only
-from torch.utils._stats import count
+from torch.utils._stats import count, count_label
 from torch.utils.weak import WeakIdRef

+log = logging.getLogger(__name__)
+
 pytree = torch.utils._pytree
 T = TypeVar("T")
 TensorWeakRef = Any
```
```diff
@@ -31,6 +39,22 @@ aten = torch._ops.ops.aten

 CONSTANT_NUMEL_LIMIT = 1

+RECURSION_COUNT = 0
+
+
+# Small helper that increments recursion count, and
+# resets it when the object goes out of scope. Useful
+# if you don't want to increase indentation which is
+# what a context manager would do.
+class IncrementRecursionCount:
+    def __init__(self):
+        global RECURSION_COUNT
+        RECURSION_COUNT += 1
+
+    def __del__(self):
+        global RECURSION_COUNT
+        RECURSION_COUNT -= 1
+
+
 @dataclass
 class UnsupportedFakeTensorException(RuntimeError):
```
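`IncrementRecursionCount` leans on CPython reference counting: the decrement happens when the local binding in the dispatch frame is collected, with no explicit cleanup call. A minimal standalone sketch of that behavior (the helper is copied from the hunk above; on non-refcounting interpreters `__del__` may run later):

```python
RECURSION_COUNT = 0

class IncrementRecursionCount:
    def __init__(self):
        global RECURSION_COUNT
        RECURSION_COUNT += 1

    def __del__(self):
        global RECURSION_COUNT
        RECURSION_COUNT -= 1

def dispatch_like():
    incr = IncrementRecursionCount()  # depth goes 0 -> 1 for this frame
    print("depth inside:", RECURSION_COUNT)
    # no explicit cleanup: when `incr` is collected at function exit,
    # __del__ runs and the depth drops back to its previous value

dispatch_like()
print("depth after:", RECURSION_COUNT)  # 0 again on CPython
```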
```diff
@@ -509,6 +533,189 @@ def conv(fake_mode, func, *args, **kwargs):
     )


+FAST_OP_IMPLEMENTATIONS = {}
+
+
+# Unlike register_op_impl, these don't do the slow iteration for
+# run_impl_check, and these run BEFORE decompositions
+def register_fast_op_impl(func: OpOverload):
+    def impl_decorator(op_impl):
+        FAST_OP_IMPLEMENTATIONS[func] = op_impl
+        return op_impl
+
+    return impl_decorator
+
+
+# infer_size_impl in ExpandUtils
+def infer_size(a, b):
+    dimsA = len(a)
+    dimsB = len(b)
+    ndim = max(dimsA, dimsB)
+    expandedSizes = [0] * ndim
+    for i in range(ndim - 1, -1, -1):
+        offset = ndim - 1 - i
+        dimA = dimsA - 1 - offset
+        dimB = dimsB - 1 - offset
+        sizeA = a[dimA] if dimA >= 0 else 1
+        sizeB = b[dimB] if dimB >= 0 else 1
+        if not (sizeA == sizeB or sizeA == 1 or sizeB == 1):
+            raise RuntimeError(
+                f"The size of tensor a ({sizeA}) "
+                f"must match the size of tensor b ({sizeB}) "
+                f"at non-singleton dimension {i})"
+            )
+        expandedSizes[i] = sizeB if sizeA == 1 else sizeA
+    return tuple(expandedSizes)
+
+
+def make_fast_binary_impl(slow_ref):
+    def fast_binary_impl(mode, *args, **kwargs):
+        def slow(msg):
+            count_label(f"slow {msg}")
+            with mode:
+                return slow_ref(*args, **kwargs)
+
+        count_label("attempt fast")
+
+        # Fast path (based off of TensorIterator fast path).
+        # Unfortunately, there is no way to easily deduplicate
+        # this with either the TensorIterator C++ implementation
+        # (which we don't want to SymIntify, and also the algorithm
+        # here is slightly different from TensorIterator to allow
+        # for broadcasting), nor the PrimTorch implementation
+        # (which does not actually implement a fast path.)
+
+        operands = args
+
+        # compute_shape
+        has_scalars = False
+        has_tensors = False
+        final_shape = None
+        for op in operands:
+            shape = op.shape if isinstance(op, torch.Tensor) else ()
+            if len(shape) == 0:
+                has_scalars = True
+            else:
+                has_tensors = True
+            if final_shape is None:
+                final_shape = shape
+            # TODO: Minor optimization: track if the shapes
+            # were equal so you can skip the equality check
+            # below if unnecessary
+            final_shape = infer_size(final_shape, shape)
+        assert final_shape is not None
+
+        # Do some extra safety checks to see if the output
+        # stride is obvious
+        for op in operands:
+            if isinstance(op, torch.Tensor) and op.shape == final_shape:
+                break
+        else:
+            return slow("both tensors nontrivially broadcast")
+
+        # compute_types
+        cpu = torch.device("cpu")
+        common_device = cpu
+        common_dtype = None
+        output_dtype = None
+        has_different_input_dtypes = False
+        for op in operands:
+            if not isinstance(op, torch.Tensor):
+                # Use elementwise_dtypes for the tricky case
+                has_different_input_dtypes = True
+                continue
+            if common_device == cpu and not op.device.type == "cpu":
+                common_device = op.device
+            # Slightly simplified here as target_dtype cannot vary
+            if common_dtype is None:
+                common_dtype = op.dtype
+            elif common_dtype != op.dtype:
+                has_different_input_dtypes = True
+
+        if has_different_input_dtypes:
+            # compute promotion
+            # TODO: we don't need the compute type
+            _, common_dtype = elementwise_dtypes(
+                *operands, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+            )
+
+        # check all tensors on same device
+        # cpu scalars are assumed allow
+        current_cpu_scalars_on_non_cpu = 0
+        max_cpu_scalars_on_non_cpu = 1  # hard coded atm
+        for op in operands:
+            if not isinstance(op, torch.Tensor):
+                continue
+            if common_device != cpu and op.dim() == 0 and op.device == cpu:
+                if current_cpu_scalars_on_non_cpu >= max_cpu_scalars_on_non_cpu:
+                    return slow("error")
+                current_cpu_scalars_on_non_cpu += 1
+            elif op.device != common_device:
+                return slow("error")
+
+        # compute_fast_setup_type
+        is_contiguous = True
+        is_channels_last = True
+        # TODO: is_non-overlapping_and_dense (not bound from Python
+        # no inplace, no out, everything defined
+        for op in operands:
+            if not isinstance(op, torch.Tensor):
+                continue
+            is_contiguous = is_contiguous and op.is_contiguous(
+                memory_format=torch.contiguous_format
+            )
+            is_channels_last = is_channels_last and op.is_contiguous(
+                memory_format=torch.channels_last
+            )
+        if is_contiguous:
+            # do contiguous
+            count_label("fast is_contiguous")
+            return FakeTensor(
+                mode,
+                torch.empty(
+                    final_shape,
+                    dtype=common_dtype,
+                    device="meta",
+                    memory_format=torch.contiguous_format,
+                ),
+                device=common_device,
+            )
+        if is_channels_last:
+            count_label("fast channels_last")
+            # do channels last
+            return FakeTensor(
+                mode,
+                torch.empty(
+                    final_shape,
+                    dtype=common_dtype,
+                    device="meta",
+                    memory_format=torch.channels_last,
+                ),
+                device=common_device,
+            )
+
+        return slow("no contiguity match")
+
+    return fast_binary_impl
+
+
+@functools.lru_cache(None)
+def get_fast_op_impls():
+    import torch._refs
+
+    register_fast_op_impl(torch.ops.aten.add.Tensor)(
+        make_fast_binary_impl(torch._refs.add)
+    )
+    register_fast_op_impl(torch.ops.aten.sub.Tensor)(
+        make_fast_binary_impl(torch._refs.sub)
+    )
+    register_fast_op_impl(torch.ops.aten.mul.Tensor)(make_fast_binary_impl(torch._refs.mul))  # type: ignore[has-type]
+    register_fast_op_impl(torch.ops.aten.div.Tensor)(
+        make_fast_binary_impl(torch._refs.div)
+    )
+    return FAST_OP_IMPLEMENTATIONS
+
+
 @contextlib.contextmanager
 def in_kernel_invocation_manager(fake_mode):
     # See: note [Fake Tensor Dispatch Keys]
```
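A quick sanity check of `infer_size` and the registration table, assuming a build that includes this patch (these are private helpers in `torch._subclasses.fake_tensor`, so the import path may move):

```python
import torch
from torch._subclasses.fake_tensor import get_fast_op_impls, infer_size

# infer_size follows the usual broadcasting rules
assert infer_size((3, 1), (1, 4)) == (3, 4)
assert infer_size((), (2, 2)) == (2, 2)  # a 0-d operand broadcasts against anything
try:
    infer_size((2,), (3,))
except RuntimeError as e:
    print(e)  # non-broadcastable sizes raise, same as eager mode

# The table is keyed by OpOverload; only add/sub/mul/div are registered here.
impls = get_fast_op_impls()
assert torch.ops.aten.add.Tensor in impls
assert torch.ops.aten.relu.default not in impls  # relu is still on the slow path, per the commit message
```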
```diff
@@ -776,6 +983,13 @@ class FakeTensorMode(TorchDispatchMode):

     @count
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        try:
+            return self.dispatch(func, types, args, kwargs)
+        except TypeError:
+            log.exception("fake tensor raised TypeError")
+            raise
+
+    def dispatch(self, func, types, args=(), kwargs=None):
         kwargs = kwargs if kwargs else {}

         if func == torch.ops.prim.device.default:
@@ -785,6 +999,12 @@ class FakeTensorMode(TorchDispatchMode):
             else:
                 return args[0].fake_device

+        if log.getEffectiveLevel() <= logging.DEBUG:
+            log.debug(
+                f"{' ' * RECURSION_COUNT}FakeTensorMode.__torch_dispatch__: {func}"
+            )
+            incr = IncrementRecursionCount()
+
         # Some attribute queries that can be serviced directly
         # See Note [is_coalesced is dispatched]
         if func in {
```
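The new debug logging is gated on the module logger's level, so it can be enabled without touching code. A minimal sketch; the logger name follows from `log = logging.getLogger(__name__)` in `torch/_subclasses/fake_tensor.py`, and the sample output is approximate:

```python
import logging

logging.basicConfig(level=logging.DEBUG)
logging.getLogger("torch._subclasses.fake_tensor").setLevel(logging.DEBUG)

# Subsequent FakeTensorMode dispatches each emit one line (plus the usual
# logging prefix), indented by one space per RECURSION_COUNT level, e.g.:
#   FakeTensorMode.__torch_dispatch__: aten.mul.Tensor
#    FakeTensorMode.__torch_dispatch__: aten.add.Tensor
```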
```diff
@@ -894,6 +1114,12 @@ class FakeTensorMode(TorchDispatchMode):
         # is written to must be invalidated
         self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)

+        # Try for fastpath
+        if has_symbolic_sizes:
+            fast_impl = get_fast_op_impls().get(func)
+            if fast_impl is not None:
+                return fast_impl(self, *args, **kwargs)
+
         # If there's a Python meta, prefer that over the decomposition
         from torch._decomp import meta_table as meta_table

```
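One rough way to observe whether the table is consulted from user code, assuming a build with this patch: run a small compile with dynamic shapes and then dump the counters that `count_label` accumulates. Whether the fast-path counters actually appear depends on the `has_symbolic_sizes` guard above, so treat this purely as a probe, not a guarantee.

```python
import torch
from torch.utils._stats import simple_call_counter

def f(x, y):
    return x * y + y

opt_f = torch.compile(f, backend="aot_eager", dynamic=True)
opt_f(torch.randn(8, 8), torch.randn(8, 8))

# Dump everything counted during tracing; with symbolic sizes this can include
# labels like "attempt fast", "fast is_contiguous", or the various "slow ..." reasons.
print(dict(simple_call_counter))
```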
torch/utils/_stats.py

```diff
@@ -3,8 +3,13 @@
 # AND SCRUB AWAY TORCH NOTIONS THERE.
 import collections
 import functools
+from typing import OrderedDict

-simple_call_counter = collections.OrderedDict()
+simple_call_counter: OrderedDict[str, int] = collections.OrderedDict()
+
+def count_label(label):
+    prev = simple_call_counter.setdefault(label, 0)
+    simple_call_counter[label] = prev + 1

 def count(fn):
     @functools.wraps(fn)
```
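For reference, the new `count_label` counters read back directly from `simple_call_counter`; a minimal sketch using only what this hunk defines:

```python
from torch.utils._stats import count_label, simple_call_counter

count_label("attempt fast")
count_label("attempt fast")
count_label("fast is_contiguous")

# simple_call_counter now contains at least:
#   'attempt fast' -> 2, 'fast is_contiguous' -> 1
print(simple_call_counter)
```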