pytorch/torch
Edward Z. Yang d690a596dc Fast path binary ops in fake tensor (#94047)
Fast path execution of a few binary ops in fake tensor, to speed up trace time. When testing `python benchmarks/dynamo/timm_models.py --accuracy --timing --backend aot_eager --dynamic-shapes --float32 --only hrnet_w18`, I get the following trace speedup.

Before:

```
cuda eval  hrnet_w18                           PASS
TIMING: entire_frame_compile:53.97591 backend_compile:33.60832
STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:4995 | FakeTensorMode.__torch_dispatch__:89985 | ProxyTorchDispatchMode.__torch_dispatch__:3010
```

After:

```
cuda eval  hrnet_w18                           PASS
TIMING: entire_frame_compile:40.18931 backend_compile:25.28828
STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:4995 | FakeTensorMode.__torch_dispatch__:69478 | attempt fast:4399 | fast is_contiguous:4399 | ProxyTorchDispatchMode.__torch_dispatch__:3010
```

My experiment notebook can be found at https://docs.google.com/document/d/1_dTIQUwjIVnEWmiFAavJQYVF8uzXqD9Dk6b9gGQLF_U/edit#

This is not the "most" optimized version of the code; compare it with Horace/Voz's roofline experiment:

```
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index e3bf545f3b8..395942c6ffe 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -774,6 +774,10 @@ class FakeTensorMode(TorchDispatchMode):
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         kwargs = kwargs if kwargs else {}

+        with no_dispatch():
+            if func in {aten.mul.Tensor, aten.add.Tensor, aten.sub.Tensor, aten.relu.default}:
+                return FakeTensor(self, torch.empty(args[0].shape, device='meta'), device='cuda')
+
         if func == torch.ops.prim.device.default:
             assert len(args) == 1 and isinstance(args[0], FakeTensor)
             if args[0].fake_mode.in_kernel_invocation:
```

I am still leaving about 5s of trace time improvement on the table (3s of which is attributable to not yet handling relu).

The implementation here is based on https://github.com/pytorch/pytorch/pull/93118/, but I modeled the short-circuit logic on TensorIterator's implementation, for ease of code review and correctness verification. However, there are some important divergences:

* Traditional fast setup in TensorIterator only short-circuits if the shapes of all input elements are equal. On hrnet_w18, only 5% of fast-pathed binary operators actually satisfy this. So instead, I compute the broadcasted shape, but then I only allow the fast path if (1) at least one input tensor has a shape that is exactly the output size, and (2) all the tensors are contiguous (or all the tensors are channels last); see the sketch after this list.
* I had to manually adjust the logic to handle wrapped numbers (Python scalars, which are ordinarily handled by wrapping them into tensors). I think I got this right.
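
As a rough illustration of the heuristic (a minimal sketch, not the actual code in `torch/_subclasses/fake_tensor.py`; `infer_size` and `can_fast_path` are hypothetical helper names):

```
import torch

def infer_size(a, b):
    # Standard broadcasting: align trailing dims; sizes must match or be 1.
    ndim = max(len(a), len(b))
    out = []
    for i in range(ndim):
        x = a[len(a) - ndim + i] if len(a) - ndim + i >= 0 else 1
        y = b[len(b) - ndim + i] if len(b) - ndim + i >= 0 else 1
        if x != y and x != 1 and y != 1:
            raise RuntimeError(f"shapes {a} and {b} are not broadcastable")
        out.append(max(x, y))
    return tuple(out)

def can_fast_path(lhs, rhs):
    # Only take the fast path if (1) some input already has the broadcasted
    # output shape and (2) all inputs are contiguous, or all are channels last.
    out_shape = infer_size(tuple(lhs.shape), tuple(rhs.shape))
    # Wrapped numbers (Python scalars turned into 0-dim tensors) need separate
    # handling in the real code; this sketch simply ignores 0-dim operands.
    operands = [t for t in (lhs, rhs) if t.dim() > 0]
    has_full_size_input = any(tuple(t.shape) == out_shape for t in operands)
    all_contig = all(t.is_contiguous() for t in operands)
    all_channels_last = all(
        t.is_contiguous(memory_format=torch.channels_last) for t in operands
    )
    return has_full_size_input and (all_contig or all_channels_last)
```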

Some evidence that this heuristic is correct is in https://gist.github.com/ezyang/b22fa7b72b7349137211d8dc7041f758: I exhaustively test all dim=3 tensors with sizes [1, 2] and show that PrimTorch and the new algorithm produce the same significant strides. There ARE differences between this algorithm and PrimTorch, but this algorithm agrees with TensorIterator where PrimTorch is wrong (sample case: size=(1, 1, 2), stride=(1, 1, 1), stride=(1, 1, 1)).
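
A simplified sketch of that kind of exhaustive check (it only varies shapes, not input strides; the real harness is the linked gist):

```
import itertools
import torch

# Enumerate every pair of 3-d shapes with sizes drawn from {1, 2} and record the
# output strides eager mode (TensorIterator) produces for a binary op; these are
# the reference strides a fake-tensor fast path must reproduce.
reference = {}
for lhs, rhs in itertools.product(itertools.product([1, 2], repeat=3), repeat=2):
    out = torch.empty(lhs) + torch.empty(rhs)
    reference[(lhs, rhs)] = tuple(out.stride())

for (lhs, rhs), strides in sorted(reference.items()):
    print(f"{lhs} + {rhs} -> strides {strides}")
```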

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94047
Approved by: https://github.com/eellison
2023-02-07 18:34:24 +00:00
_awaits [jit] Support Awaitable type (#90863) 2023-01-30 17:38:59 +00:00
_C [jit] Support Awaitable type (#90863) 2023-01-30 17:38:59 +00:00
_C_flatbuffer
_decomp add rsub decomposition with alpha (#94144) 2023-02-07 17:21:13 +00:00
_dispatch
_dynamo General in-place binary op support in dynamo (#94203) 2023-02-07 15:12:32 +00:00
_export [Export] Introduce as_none in ex.Argument union type (#93210) 2023-01-30 21:32:49 +00:00
_functorch Enable Python dispatcher on inference-only aot_dispatch_base (#94118) 2023-02-04 06:10:21 +00:00
_inductor add rsub decomposition with alpha (#94144) 2023-02-07 17:21:13 +00:00
_lazy
_prims [pt2] Fix arange to match ATen behavior (#93353) 2023-02-03 00:44:32 +00:00
_prims_common [pt2] Fix arange to match ATen behavior (#93353) 2023-02-03 00:44:32 +00:00
_refs [decomp] Decompose std/std_mean into aten.var/var_mean (#94072) 2023-02-06 10:22:07 +00:00
_subclasses Fast path binary ops in fake tensor (#94047) 2023-02-07 18:34:24 +00:00
amp
ao [quant][fx][pt2e] Refactor prepare so it's aligned better with the new API plan in pt2e (#94011) 2023-02-07 08:23:56 +00:00
autograd [Py3.11] Remove skip logic from vmap and forward_ad (#91825) 2023-01-25 22:40:56 +00:00
backends Dynamo benchmark: add CPU specific changes (#88477) 2023-01-07 09:26:06 +00:00
contrib
cpu
csrc [mobile] List all missing ops at once (#94205) 2023-02-07 05:45:57 +00:00
cuda FusedAdam(W) should take OptState into account before unscaling grads (#94060) 2023-02-04 05:20:13 +00:00
distributed [FSDP][optim_state_dict] Let optim_state_dict ignore the non-FSDP managed parameters that do not reside on the rank (#94129) 2023-02-07 06:29:28 +00:00
distributions [Dynamo] Fix calling UserDefinedObject.func should pass self object (#92050) 2023-01-21 05:47:01 +00:00
fft
func [functorch] move batch_norm_replacement to torch.func (#91412) 2023-01-12 19:15:41 +00:00
futures Enable xdoctest runner in CI for real this time (#83816) 2022-12-29 05:32:42 +00:00
fx [Dynamo] Support torch.Tensor.fn as TorchVariable, not UserDefinedObjectVariable, preventing graph break (#93243) 2023-02-07 09:26:50 +00:00
jit Make segment_reduce properly private. (#93166) 2023-02-06 18:32:23 +00:00
legacy
lib More fixes and improved clang-tidy checkers (#93213) 2023-02-01 14:44:17 +00:00
linalg
masked Make segment_reduce properly private. (#93166) 2023-02-06 18:32:23 +00:00
monitor Enable xdoctest runner in CI for real this time (#83816) 2022-12-29 05:32:42 +00:00
multiprocessing Set min supported Python version to 3.8 (#93155) 2023-01-29 18:28:46 +00:00
nested
nn [Docs] Add pointer to FlashAttention paper (#94253) 2023-02-07 08:05:10 +00:00
onnx [ONNX] Export 'aten::index_put(self, mask, v)' when rank(mask) < rank(self) (#92862) 2023-01-27 02:00:56 +00:00
optim Look up group["capturable"], not defaults["capturable"] in Adam(W) (#94149) 2023-02-07 00:24:35 +00:00
package Set min supported Python version to 3.8 (#93155) 2023-01-29 18:28:46 +00:00
profiler Silence profiler error (#94013) 2023-02-03 17:33:47 +00:00
quantization AO migration: replace torch internal callsites (#94170) 2023-02-07 02:32:23 +00:00
signal Reland "Add torch.utils.device_mode" (#91796) 2023-01-09 20:57:12 +00:00
sparse Revert "Improve bsr @ strided performance in baddmm for bfloat16/half with Triton kernels. (#88078)" 2023-01-26 16:22:29 +00:00
special
testing add rsub decomposition with alpha (#94144) 2023-02-07 17:21:13 +00:00
utils Fast path binary ops in fake tensor (#94047) 2023-02-07 18:34:24 +00:00
__config__.py
__future__.py
__init__.py temp fix for segment reduce undocumented FC window (#94242) 2023-02-07 18:27:01 +00:00
_appdirs.py
_classes.py
_deploy.py
_guards.py Properly resolve source_ref when constructing shape guards (#91058) 2022-12-30 05:56:56 +00:00
_jit_internal.py [jit] jit._drop fun modifier to allow in jit class non-jit decl funs (#93012) 2023-02-01 09:02:05 +00:00
_linalg_utils.py Remove deprecated torch.symeig (#70988) 2023-01-31 11:59:11 +00:00
_lobpcg.py Fix typo in _lobpcg.py (#91641) 2023-01-04 15:19:05 +00:00
_lowrank.py
_meta_registrations.py [pt2] Fix arange to match ATen behavior (#93353) 2023-02-03 00:44:32 +00:00
_namedtensor_internals.py Enable xdoctest runner in CI for real this time (#83816) 2022-12-29 05:32:42 +00:00
_ops.py Fix checking of current mode in PyOperator dispatch (#92357) 2023-01-18 23:08:36 +00:00
_python_dispatcher.py
_six.py
_sources.py
_storage_docs.py
_tensor_docs.py Point to scatter_reduce for reduce argument in scatter_ docs (#94081) 2023-02-06 19:26:21 +00:00
_tensor_str.py Enable xdoctest runner in CI for real this time (#83816) 2022-12-29 05:32:42 +00:00
_tensor.py Remove deprecated torch.symeig (#70988) 2023-01-31 11:59:11 +00:00
_torch_docs.py Remove deprecated torch.symeig (#70988) 2023-01-31 11:59:11 +00:00
_utils_internal.py
_utils.py Fix serialization (#94096) 2023-02-06 16:30:20 +00:00
_VF.py
_vmap_internals.py Enable xdoctest runner in CI for real this time (#83816) 2022-12-29 05:32:42 +00:00
_weights_only_unpickler.py
abi-check.cpp
CMakeLists.txt Revert "[cuDNN][cuDNN V8 API] Always build assuming cuDNN >= 8.0 (#91527)" 2023-01-16 13:28:09 +00:00
custom_class_detail.h More tidy fixes (#93069) 2023-01-27 06:40:50 +00:00
custom_class.h More fixes and improved clang-tidy checkers (#93213) 2023-02-01 14:44:17 +00:00
extension.h
functional.py Update version numbers in torch.{stft,istft} deprecations (#91761) 2023-01-05 22:17:37 +00:00
hub.py Preventing crashing incase of no network by loading from cache (#91569) 2023-01-11 11:56:46 +00:00
library.h
library.py Enable xdoctest runner in CI for real this time (#83816) 2022-12-29 05:32:42 +00:00
overrides.py Make segment_reduce properly private. (#93166) 2023-02-06 18:32:23 +00:00
py.typed
quasirandom.py
random.py
README.txt
return_types.py
script.h
serialization.py Enable xdoctest runner in CI for real this time (#83816) 2022-12-29 05:32:42 +00:00
storage.py
torch_version.py
types.py

Note [TH abstraction violation]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

TH/THC provide some hpp headers, which are proper C++ headers rather than
C headers.  These headers serve double duty as *internal implementation
detail* headers, whose contents should largely not be used by external
clients.

Ideally, we would not install these headers at all; instead, you should
use public functions (in headers like `THTensor.h`, NOT `THTensor.hpp`)
to manipulate these structs.  However, there are a few places
in torch/csrc where we violate this abstraction.  They are marked with
a pointer to this note.  Each of those sites will have to be refactored
when we refactor the guts of THTensor and related structures.