Fast path binary ops in fake tensor (#94047)

Fast path execution of a few binary ops in fake tensor, to speed up trace time. When testing `python benchmarks/dynamo/timm_models.py --accuracy --timing --backend aot_eager --dynamic-shapes --float32 --only hrnet_w18`, I get the following trace speedup.

Before:

```
cuda eval  hrnet_w18                           PASS
TIMING: entire_frame_compile:53.97591 backend_compile:33.60832
STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:4995 | FakeTensorMode.__torch_dispatch__:89985 | ProxyTorchDispatchMode.__torch_dispatch__:3010
```

After:

```
cuda eval  hrnet_w18                           PASS
TIMING: entire_frame_compile:40.18931 backend_compile:25.28828
STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:4995 | FakeTensorMode.__torch_dispatch__:69478 | attempt fast:4399 | fast is_contiguous:4399 | ProxyTorchDispatchMode.__torch_dispatch__:3010
```

My experiment notebook can be found at https://docs.google.com/document/d/1_dTIQUwjIVnEWmiFAavJQYVF8uzXqD9Dk6b9gGQLF_U/edit#

This is not the "most" optimized version of the code; compare it with the Horace/Voz roofline experiment:

```
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index e3bf545f3b8..395942c6ffe 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -774,6 +774,10 @@ class FakeTensorMode(TorchDispatchMode):
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         kwargs = kwargs if kwargs else {}

+        with no_dispatch():
+            if func in {aten.mul.Tensor, aten.add.Tensor, aten.sub.Tensor, aten.relu.default}:
+                return FakeTensor(self, torch.empty(args[0].shape, device='meta'), device='cuda')
+
         if func == torch.ops.prim.device.default:
             assert len(args) == 1 and isinstance(args[0], FakeTensor)
             if args[0].fake_mode.in_kernel_invocation:
```

I am still leaving about 5s of trace time improvement on the table (3s of which is attributable to not yet handling relu).

The implementation here is based off of https://github.com/pytorch/pytorch/pull/93118/, but I modeled the short-circuit logic off of TensorIterator's implementation, for ease of code review and correctness verification. However, there are some important divergences:

* Traditional fast setup in TensorIterator only short-circuits if the shapes of all input elements are equal. On hrnet_w18, only 5% of fast-pathed binary operators actually satisfy this. So instead, I compute the broadcasted shape, but then I only allow the fast path if (1) at least one input tensor has a shape that is exactly the output size, and (2) all the tensors are contiguous (or all the tensors are channels last); see the sketch after this list.
* I had to manually adjust the logic to handle wrapped numbers (which ordinarily are handled by wrapping into tensors). I think I got this right.
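
To make the first divergence concrete, here is a small illustration (not part of the PR; the tensor names and shapes are made up) of which operand combinations the heuristic lets through:

```
import torch

# Broadcasted output shape is (8, 3, 4); the second operand's shape equals it
# exactly and both operands are contiguous, so the fast path applies.
a = torch.randn(8, 1, 4)
b = torch.randn(8, 3, 4)

# Broadcasted output shape is again (8, 3, 4), but NEITHER operand's shape
# equals the output shape ("both tensors nontrivially broadcast"), so the
# implementation falls back to the slow reference.
c = torch.randn(8, 1, 4)
d = torch.randn(1, 3, 4)

# Mixed layouts (one contiguous, one channels-last) also fall back to the
# slow path, since the output layout is no longer obvious.
e = torch.randn(2, 3, 8, 8)
f = torch.randn(2, 3, 8, 8).to(memory_format=torch.channels_last)
```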

Some evidence that this heuristic is correct is in https://gist.github.com/ezyang/b22fa7b72b7349137211d8dc7041f758: I exhaustively test all dim=3 tensors with sizes drawn from [1, 2] and show that PrimTorch and the new algorithm produce the same significant strides. In fact, there ARE differences between this algorithm and PrimTorch, but in those cases this algorithm agrees with TensorIterator where PrimTorch is wrong (sample case: size=(1, 1, 2), stride=(1, 1, 1), stride=(1, 1, 1)).
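
A minimal sketch of that kind of check (my reconstruction, not the gist itself), under the assumption that eager TensorIterator output strides are the reference: enumerate all dim=3 size pairs from {1, 2}, keep only the cases the fast path accepts, and compare significant strides against the contiguous allocation the fast path would make.

```
import itertools
import torch

def significant_strides(t):
    # Strides of size-1 dimensions are not observable, so ignore them.
    return tuple(s for s, sz in zip(t.stride(), t.shape) if sz != 1)

for sa, sb in itertools.product(itertools.product([1, 2], repeat=3), repeat=2):
    out_shape = torch.broadcast_shapes(sa, sb)
    # Only the cases the fast path actually accepts: at least one input
    # already has exactly the output shape (and everything is contiguous).
    if sa != tuple(out_shape) and sb != tuple(out_shape):
        continue
    eager = torch.add(torch.empty(sa), torch.empty(sb))  # TensorIterator
    fast = torch.empty(out_shape)                        # fast-path allocation
    assert significant_strides(eager) == significant_strides(fast), (sa, sb)
```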

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94047
Approved by: https://github.com/eellison
Author: Edward Z. Yang, 2023-02-07 07:15:15 -08:00 (committed by PyTorch MergeBot)
Commit: d690a596dc (parent 0603f4ff14)
2 changed files with 234 additions and 3 deletions

torch/_subclasses/fake_tensor.py:

```
@@ -1,6 +1,7 @@
 import contextlib
 import functools
 import itertools
+import logging
 import os
 import weakref
 from dataclasses import dataclass
@@ -11,7 +12,12 @@ from weakref import ReferenceType
 import torch
 from torch._guards import Source
 from torch._ops import OpOverload
-from torch._prims_common import is_float_dtype, is_integer_dtype
+from torch._prims_common import (
+    elementwise_dtypes,
+    ELEMENTWISE_TYPE_PROMOTION_KIND,
+    is_float_dtype,
+    is_integer_dtype,
+)
 from torch._subclasses.meta_utils import MetaConverter
 from torch.fx.operator_schemas import normalize_function
 from torch.multiprocessing.reductions import StorageWeakRef
@@ -20,9 +26,11 @@ from torch.utils._mode_utils import no_dispatch
 from torch.utils._python_dispatch import TorchDispatchMode
 from torch.utils._pytree import PyTree, tree_flatten, tree_map, tree_map_only
-from torch.utils._stats import count
+from torch.utils._stats import count, count_label
 from torch.utils.weak import WeakIdRef

+log = logging.getLogger(__name__)
+
 pytree = torch.utils._pytree
 T = TypeVar("T")
 TensorWeakRef = Any
@@ -31,6 +39,22 @@ aten = torch._ops.ops.aten
 CONSTANT_NUMEL_LIMIT = 1

+RECURSION_COUNT = 0
+
+
+# Small helper that increments recursion count, and
+# resets it when the object goes out of scope. Useful
+# if you don't want to increase indentation which is
+# what a context manager would do.
+class IncrementRecursionCount:
+    def __init__(self):
+        global RECURSION_COUNT
+        RECURSION_COUNT += 1
+
+    def __del__(self):
+        global RECURSION_COUNT
+        RECURSION_COUNT -= 1
+
+
 @dataclass
 class UnsupportedFakeTensorException(RuntimeError):
@@ -509,6 +533,189 @@ def conv(fake_mode, func, *args, **kwargs):
     )


+FAST_OP_IMPLEMENTATIONS = {}
+
+
+# Unlike register_op_impl, these don't do the slow iteration for
+# run_impl_check, and these run BEFORE decompositions
+def register_fast_op_impl(func: OpOverload):
+    def impl_decorator(op_impl):
+        FAST_OP_IMPLEMENTATIONS[func] = op_impl
+        return op_impl
+
+    return impl_decorator
+
+
+# infer_size_impl in ExpandUtils
+def infer_size(a, b):
+    dimsA = len(a)
+    dimsB = len(b)
+    ndim = max(dimsA, dimsB)
+    expandedSizes = [0] * ndim
+    for i in range(ndim - 1, -1, -1):
+        offset = ndim - 1 - i
+        dimA = dimsA - 1 - offset
+        dimB = dimsB - 1 - offset
+        sizeA = a[dimA] if dimA >= 0 else 1
+        sizeB = b[dimB] if dimB >= 0 else 1
+        if not (sizeA == sizeB or sizeA == 1 or sizeB == 1):
+            raise RuntimeError(
+                f"The size of tensor a ({sizeA}) "
+                f"must match the size of tensor b ({sizeB}) "
+                f"at non-singleton dimension {i})"
+            )
+        expandedSizes[i] = sizeB if sizeA == 1 else sizeA
+    return tuple(expandedSizes)
+
+
+def make_fast_binary_impl(slow_ref):
+    def fast_binary_impl(mode, *args, **kwargs):
+        def slow(msg):
+            count_label(f"slow {msg}")
+            with mode:
+                return slow_ref(*args, **kwargs)
+
+        count_label("attempt fast")
+
+        # Fast path (based off of TensorIterator fast path).
+        # Unfortunately, there is no way to easily deduplicate
+        # this with either the TensorIterator C++ implementation
+        # (which we don't want to SymIntify, and also the algorithm
+        # here is slightly different from TensorIterator to allow
+        # for broadcasting), nor the PrimTorch implementation
+        # (which does not actually implement a fast path.)
+
+        operands = args
+
+        # compute_shape
+        has_scalars = False
+        has_tensors = False
+        final_shape = None
+        for op in operands:
+            shape = op.shape if isinstance(op, torch.Tensor) else ()
+            if len(shape) == 0:
+                has_scalars = True
+            else:
+                has_tensors = True
+                if final_shape is None:
+                    final_shape = shape
+                # TODO: Minor optimization: track if the shapes
+                # were equal so you can skip the equality check
+                # below if unnecessary
+                final_shape = infer_size(final_shape, shape)
+        assert final_shape is not None

+        # Do some extra safety checks to see if the output
+        # stride is obvious
+        for op in operands:
+            if isinstance(op, torch.Tensor) and op.shape == final_shape:
+                break
+        else:
+            return slow("both tensors nontrivially broadcast")
+
+        # compute_types
+        cpu = torch.device("cpu")
+        common_device = cpu
+        common_dtype = None
+        output_dtype = None
+        has_different_input_dtypes = False
+        for op in operands:
+            if not isinstance(op, torch.Tensor):
+                # Use elementwise_dtypes for the tricky case
+                has_different_input_dtypes = True
+                continue
+            if common_device == cpu and not op.device.type == "cpu":
+                common_device = op.device
+            # Slightly simplified here as target_dtype cannot vary
+            if common_dtype is None:
+                common_dtype = op.dtype
+            elif common_dtype != op.dtype:
+                has_different_input_dtypes = True
+
+        if has_different_input_dtypes:
+            # compute promotion
+            # TODO: we don't need the compute type
+            _, common_dtype = elementwise_dtypes(
+                *operands, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+            )
+
+        # check all tensors on same device
+        # cpu scalars are assumed allow
+        current_cpu_scalars_on_non_cpu = 0
+        max_cpu_scalars_on_non_cpu = 1  # hard coded atm
+        for op in operands:
+            if not isinstance(op, torch.Tensor):
+                continue
+            if common_device != cpu and op.dim() == 0 and op.device == cpu:
+                if current_cpu_scalars_on_non_cpu >= max_cpu_scalars_on_non_cpu:
+                    return slow("error")
+                current_cpu_scalars_on_non_cpu += 1
+            elif op.device != common_device:
+                return slow("error")
+
+        # compute_fast_setup_type
+        is_contiguous = True
+        is_channels_last = True
+        # TODO: is_non-overlapping_and_dense (not bound from Python
+        # no inplace, no out, everything defined
+        for op in operands:
+            if not isinstance(op, torch.Tensor):
+                continue
+            is_contiguous = is_contiguous and op.is_contiguous(
+                memory_format=torch.contiguous_format
+            )
+            is_channels_last = is_channels_last and op.is_contiguous(
+                memory_format=torch.channels_last
+            )
+        if is_contiguous:
+            # do contiguous
+            count_label("fast is_contiguous")
+            return FakeTensor(
+                mode,
+                torch.empty(
+                    final_shape,
+                    dtype=common_dtype,
+                    device="meta",
+                    memory_format=torch.contiguous_format,
+                ),
+                device=common_device,
+            )
+        if is_channels_last:
+            count_label("fast channels_last")
+            # do channels last
+            return FakeTensor(
+                mode,
+                torch.empty(
+                    final_shape,
+                    dtype=common_dtype,
+                    device="meta",
+                    memory_format=torch.channels_last,
+                ),
+                device=common_device,
+            )
+
+        return slow("no contiguity match")
+
+    return fast_binary_impl
+
+
+@functools.lru_cache(None)
+def get_fast_op_impls():
+    import torch._refs
+
+    register_fast_op_impl(torch.ops.aten.add.Tensor)(
+        make_fast_binary_impl(torch._refs.add)
+    )
+    register_fast_op_impl(torch.ops.aten.sub.Tensor)(
+        make_fast_binary_impl(torch._refs.sub)
+    )
+    register_fast_op_impl(torch.ops.aten.mul.Tensor)(make_fast_binary_impl(torch._refs.mul))  # type: ignore[has-type]
+    register_fast_op_impl(torch.ops.aten.div.Tensor)(
+        make_fast_binary_impl(torch._refs.div)
+    )
+    return FAST_OP_IMPLEMENTATIONS
+
+
 @contextlib.contextmanager
 def in_kernel_invocation_manager(fake_mode):
     # See: note [Fake Tensor Dispatch Keys]
@@ -776,6 +983,13 @@ class FakeTensorMode(TorchDispatchMode):
     @count
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        try:
+            return self.dispatch(func, types, args, kwargs)
+        except TypeError:
+            log.exception("fake tensor raised TypeError")
+            raise
+
+    def dispatch(self, func, types, args=(), kwargs=None):
         kwargs = kwargs if kwargs else {}

         if func == torch.ops.prim.device.default:
@@ -785,6 +999,12 @@ class FakeTensorMode(TorchDispatchMode):
             else:
                 return args[0].fake_device

+        if log.getEffectiveLevel() <= logging.DEBUG:
+            log.debug(
+                f"{' ' * RECURSION_COUNT}FakeTensorMode.__torch_dispatch__: {func}"
+            )
+            incr = IncrementRecursionCount()
+
         # Some attribute queries that can be serviced directly
         # See Note [is_coalesced is dispatched]
         if func in {
@@ -894,6 +1114,12 @@ class FakeTensorMode(TorchDispatchMode):
         # is written to must be invalidated
         self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)

+        # Try for fastpath
+        if has_symbolic_sizes:
+            fast_impl = get_fast_op_impls().get(func)
+            if fast_impl is not None:
+                return fast_impl(self, *args, **kwargs)
+
         # If there's a Python meta, prefer that over the decomposition
         from torch._decomp import meta_table as meta_table
```

torch/utils/_stats.py:

```
@@ -3,8 +3,13 @@
 # AND SCRUB AWAY TORCH NOTIONS THERE.
 import collections
 import functools
+from typing import OrderedDict

-simple_call_counter = collections.OrderedDict()
+simple_call_counter: OrderedDict[str, int] = collections.OrderedDict()
+
+def count_label(label):
+    prev = simple_call_counter.setdefault(label, 0)
+    simple_call_counter[label] = prev + 1

 def count(fn):
     @functools.wraps(fn)
```
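
As a small usage note (not part of the PR), the new count_label helper just bumps a named entry in simple_call_counter, which is where the extra STATS entries above (attempt fast, fast is_contiguous) come from:

```
from torch.utils._stats import count_label, simple_call_counter

count_label("attempt fast")
count_label("attempt fast")
count_label("fast is_contiguous")

# Mirrors how the STATS line above is assembled from the counter dict.
print(" | ".join(f"{k}:{v}" for k, v in simple_call_counter.items()))
# attempt fast:2 | fast is_contiguous:1
```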