Commit Graph

126 Commits

Author SHA1 Message Date
Elias Ellison
0a9778a372 Expose cudaStreamCaptureMode in CUDA Graphs, use local setting in inductor (#107407)
> capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream.
> Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc,
> may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
> actions in the current thread, and "relaxed" will not error on these actions.

Inductor codegen is single-threaded, so it should be safe to enable "thread_local" for inductor's cuda graph capturing. We have seen errors when inductor cudagraphs has been used concurrently with data preprocessing in other threads.
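
A hedged usage sketch of the new parameter (shapes and values are illustrative; assumes a CUDA device):

```python
import torch

# Capture with capture_error_mode="thread_local" so only unsafe CUDA
# actions (e.g. cudaMalloc) from the capturing thread error out during
# capture; other threads are left alone.
static_input = torch.zeros(8, device="cuda")
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g, capture_error_mode="thread_local"):
    static_output = static_input * 2

static_input.copy_(torch.ones(8, device="cuda"))
g.replay()  # static_output now holds the result for the updated input
```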

Differential Revision: [D48656014](https://our.internmc.facebook.com/intern/diff/D48656014)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107407
Approved by: https://github.com/albanD, https://github.com/eqy
2023-08-25 01:44:26 +00:00
FFFrog
4d13422997 fix errors about mypy check in torch/_inductor/compile_fx.py (#107508)
Mypy errors in `compile_fx.py` blocked the merging of [PR1](https://github.com/pytorch/pytorch/pull/107127) and [PR2](https://github.com/pytorch/pytorch/pull/107448).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107508
Approved by: https://github.com/ezyang
2023-08-22 22:33:37 +00:00
angelayi
d5b8c71112 [inductor] Revert inductor changes in #105977 (#107468)
Reverts inductor changes in #105977

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107468
Approved by: https://github.com/jansel
2023-08-21 15:50:03 +00:00
Simon Fan
aca3d1433c Estimate Scheduler node runtimes (#106426)
Working on this as a starter task with @Chillee

This PR adds a method under BaseSchedulerNode to estimate the node's runtime in seconds.

We use a heuristic-based approach, first considering whether the operation is memory-bandwidth bound or compute bound:
- memory-bandwidth bound: we compute the number of bytes read/written
- compute bound: we compute the FLOPS required by the operation

One use case is as a cost model for scheduling: https://github.com/pytorch/pytorch/pull/100762
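
The heuristic amounts to a roofline model; a minimal standalone sketch (the device peak numbers are illustrative assumptions, not the PR's constants):

```python
# A node is either memory-bandwidth bound or compute bound; its runtime
# is the larger of the two costs.
PEAK_BANDWIDTH = 2.0e12  # bytes/s, illustrative device bandwidth
PEAK_FLOPS = 3.0e14      # FLOP/s, illustrative device peak throughput

def estimate_runtime_seconds(bytes_accessed: int, flops: int) -> float:
    memory_time = bytes_accessed / PEAK_BANDWIDTH
    compute_time = flops / PEAK_FLOPS
    return max(memory_time, compute_time)
```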

```
(pytorch-3.10) [14:08:02] ~/local/pytorch (xmfan/estimate_snode_runtime) > python3 test/inductor/test_perf.py -k EstimateSnodeRuntimeTests
[(ExternKernelSchedulerNode(name='buf0'), 400)]
[(ExternKernelSchedulerNode(name='buf0'), 2.35057908433887e-27)]
.[(ExternKernelSchedulerNode(name='buf0'), 3000), (SchedulerNode(name='buf1'), 3000)]
[(ExternKernelSchedulerNode(name='buf0'), 2.35057908433887e-26), (SchedulerNode(name='buf1'), 7.187055238190188e-09)]
.[(ExternKernelSchedulerNode(name='buf0'), 3000)]
[(ExternKernelSchedulerNode(name='buf0'), 2.35057908433887e-26)]
.[(ExternKernelSchedulerNode(name='buf0'), 34600)]
[(ExternKernelSchedulerNode(name='buf0'), 3.22687496698039e-24)]
.[(ExternKernelSchedulerNode(name='buf0'), 396)]
[(ExternKernelSchedulerNode(name='buf0'), 1.88046326747109e-27)]
.[(ExternKernelSchedulerNode(name='buf0'), 396)]
[(ExternKernelSchedulerNode(name='buf0'), 1.88046326747109e-27)]
.[(ExternKernelSchedulerNode(name='buf0'), 7776176)]
[(ExternKernelSchedulerNode(name='buf0'), 4.63240241413653e-21)]
.[(FusedSchedulerNode(nodes=buf0_buf1), 210)]
[(FusedSchedulerNode(nodes=buf0_buf1), 5.030938666733132e-10)]
.[(ExternKernelSchedulerNode(name='buf0'), 300)]
[(ExternKernelSchedulerNode(name='buf0'), 2.35057908433887e-27)]
.[(SchedulerNode(name='buf0'), 20)]
[(SchedulerNode(name='buf0'), 4.7913701587934585e-11)]
.
----------------------------------------------------------------------
Ran 10 tests in 14.311s
OK
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106426
Approved by: https://github.com/Chillee
2023-08-17 17:23:30 +00:00
Shunting Zhang
91778ada87 [inductor] graph replayer (#106952)
Recently I've found it a bit painful to run benchmark scripts in my dev environment. E.g., the command below
```
 python benchmarks/dynamo/huggingface.py --backend inductor --amp --performance --only YituTechConvBert --training
```
took about 2 minutes to run. It may take even longer for some other models.

The command is slow since it needs to:
- do the dynamo work
- verify the model on CPU
- run perf tests
- compile all the graphs

However, oftentimes I only need to debug inductor-specific logic like loop ordering and fusion. A lot of what the script does is useless for me. Also, I only need to test one graph at a time (e.g., check the fwd graph first and, when I'm done, continue to the bwd graph) rather than compiling all the graphs.

The graph replayer adds a `@save_args` decorator to the `compile_fx_inner` function. When `config.save_args` is true, it pickles all the arguments to `compile_fx_inner` to the file system. Later on, we can call `load_args_and_run_compile_fx_inner("/tmp/inductor_saved_args/compile_fx_inner_0.pkl")` to replay the graph and compile it with inductor.
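
A hedged sketch of that workflow (the exact import location of `load_args_and_run_compile_fx_inner` is an assumption):

```python
import torch._inductor.config as config

# 1) First run: capture the arguments passed to compile_fx_inner.
config.save_args = True  # pickles args to /tmp/inductor_saved_args/
# ... run the benchmark script once to capture the graphs ...

# 2) Later: replay just the inductor compilation of one captured graph,
#    skipping dynamo, aot_autograd, and the perf harness entirely.
from torch._inductor.compile_fx import load_args_and_run_compile_fx_inner
load_args_and_run_compile_fx_inner(
    "/tmp/inductor_saved_args/compile_fx_inner_0.pkl"
)
```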

Replaying the fwd graph took around 60 seconds (maybe this can be further reduced, but it is already a 2x speedup for dev efficiency), and it only took around 20 seconds to reach the `Scheduler.__init__` method.

I also checked the `TORCH_COMPILE_DEBUG` flag that already exists. The most similar part of `TORCH_COMPILE_DEBUG` is that it can save a graph and its arguments and rerun it later. The difference here is that, rather than running the model, we want to call the inductor API to compile the model (without even going through dynamo or aot_autograd).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106952
Approved by: https://github.com/jansel
ghstack dependencies: #106990
2023-08-11 22:28:20 +00:00
Yang Chen
40a15b50a8 Enable mypy checking in compile_fx.py (#105830)
This is part of the effort for issue #105230

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105830
Approved by: https://github.com/eellison
2023-08-09 09:05:23 +00:00
angelayi
5b13c779d4 [AOTInductor] Remove call to aot_autograd when receiving ExportedProgram (#105977)
https://github.com/pytorch/pytorch/issues/105555

The existing flow first exports and then calls torch._inductor.aot_compile. However, export calls aot_autograd with the core ATen decomposition table, and then torch._inductor.aot_compile calls aot_autograd again with the inductor decomposition table. The second call to aot_autograd is supposedly causing some problems and seems excessive, so instead we will create a new function, torch._export.aot_compile, which exports using the inductor decomposition table and passes the result to inductor's compile_fx_aot; because the program has already been exported, it avoids calling aot_autograd again.

```
def aot_compile(
    f: Callable,
    args: Tuple[Any],
    kwargs: Optional[Dict[str, Any]] = None,
    constraints: Optional[List[Constraint]] = None,
) -> Tuple[str, ExportedProgram]:
```
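
A hedged usage sketch against the signature above (the function body and inputs are illustrative; assumes the entry point lands as `torch._export.aot_compile`):

```python
import torch
from torch._export import aot_compile

def f(x, y):
    return x + y

# so_path: path to the artifact produced by compile_fx_aot
# exported_program: the program exported with the inductor decomposition table
so_path, exported_program = aot_compile(f, (torch.randn(8), torch.randn(8)))
```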

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105977
Approved by: https://github.com/desertfire, https://github.com/zhxchen17, https://github.com/eellison
2023-08-04 15:35:23 +00:00
Elias Ellison
57f2a8d3a8 freezing w aot (#105497)
Freezing will take parameters and turn them into constants. A couple changes here:

- move the setting of `flat_params[dropped_index]` before cpp compilation so that cpp_wrapper knows they have been dropped
- compile_fx_aot doesn't use aot_autograd for invocation, so we no longer add the wrapper which discards dropped param indices. Continuing to add arguments everywhere didn't seem great, so I added `_in_aot_compilation`, but maybe reviewers would prefer something else.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105497
Approved by: https://github.com/desertfire
2023-08-02 16:30:08 +00:00
XiaobingSuper
eab3b2637a only collect fx node for user_visible_outputs when doing output stride conversion (#106194)
For yolo3, there is a subgraph whose output contains an int value, and collecting `user_visible_outputs` to do output stride conversion raised `AttributeError: 'int' object has no attribute 'name'`. This PR adds a check that an output is an fx node before it is added to `user_visible_outputs`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106194
Approved by: https://github.com/jgong5, https://github.com/eellison, https://github.com/shunting314
2023-07-30 13:48:22 +00:00
Elias Ellison
37cfe944bb add support for mutated params (#106098)
Previously, this didn't work because of the warmup run. Now that we no longer run warmup and then execution within one inductor invocation, this works. llama inference: 1.6x -> 4.4x

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106098
Approved by: https://github.com/ezyang
2023-07-28 17:27:06 +00:00
XiaobingSuper
9c1802f8e3 inductor: using binary folding path to do conv+bn folding (#105650)
This path uses binary folding to do conv+bn folding, avoiding `make_fx`, which hits tracing errors in some models' dynamic-shape paths.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105650
Approved by: https://github.com/eellison
2023-07-26 07:37:47 +00:00
Jason Ansel
c902b84e0b Compiled autograd (#103822)
This branch:
1) converts the autograd tape into an FX graph
2) caches that conversion using a "shadow" graph
3) compiles and runs the generated FX graph instead of the normal autograd

What works currently:
1) Caching, capture, and initial integration
2) Backwards hooks
3) Inlining AotAutograd generated subgraphs
4) torch.compiling the generated FX graph
5) Auto-detecting dynamic shapes based on changes

Future work
1) Larger scale testing
2) Boxed calling convention, so memory can be freed incrementally
3) Support hooks on SavedTensor
4) Additional testing by running eager autograd tests under compiled_autograd.enable()
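
A hedged usage sketch of the `compiled_autograd.enable()` entry point mentioned above (model and inputs are illustrative):

```python
import torch
from torch._dynamo import compiled_autograd

model = torch.nn.Linear(8, 8)
x = torch.randn(4, 8)

# Run backward under compiled autograd: the autograd tape is captured
# into an FX graph and handed to the given compiler (here torch.compile).
with compiled_autograd.enable(torch.compile):
    loss = model(x).sum()
    loss.backward()
```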

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103822
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-07-24 21:12:05 +00:00
Bin Bao
b0816e4714 [inductor] Fix AOTInductor output issues (#105773)
Summary: This is a follow-up on https://github.com/pytorch/pytorch/pull/105496. There are several issues with the previous fix,
1) It explicitly does a copy for every output at the end of the main function;
2) When an output is ReinterpretView, no as_strided was generated for it;
3) There can be duplicated buffer declarations.

This PR fixes these issues by making sure can_reuse behaves consistently between the two AOTInductor passes, and thus always generates the same set of kernels. It also adds handling of ReinterpretView.

Differential Revision: [D47692214](https://our.internmc.facebook.com/intern/diff/D47692214)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105773
Approved by: https://github.com/jansel
2023-07-24 01:58:49 +00:00
David Berard
5abc5ab55d [inductor] Disable cudagraphs if index_put_ fallback is encountered (#105439)
**TL;DR**: if lowerings.py encounters aten.index_put, it will set V.graph.cudagraphs_okay = False, which will disable cudagraphs. index_put needs to be disabled because it crashes cuda graphs.

index_put_ fallbacks fail with cuda graphs when `accumulate=True` - likely for the same reason that it fails with deterministic_algorithms_enabled:
fcb7d4b358/aten/src/ATen/native/TensorAdvancedIndexing.cpp (L730)

A first attempt was just to expand the scenarios where `index_put_` is one of the disallowed kernels in utils.py: 2fa7d11b64/torch/_inductor/utils.py (L436-L438)

However, this disables cuda graphs in too many scenarios: index_put doesn't cause issues if it gets fused; it only causes issues if the aten kernel gets called. So in the updated version of this PR, we check for fallbacks in lowerings.py and disable cudagraphs only if a fallback is encountered there.

Example of failure outside of PT2:

```python
import torch

def fn(x, y, z):
    x = torch.zeros_like(x)
    return x.index_put_([y], z, True)
    # return x + 1

x = torch.zeros((512, 512), dtype=torch.bool, device='cuda')
y = torch.arange(512, dtype=torch.int64, device='cuda')
z = torch.ones((512, 512), dtype=torch.bool, device='cuda')

s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for i in range(3):
        fn(x, y, z)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    fn(x, y, z)
```

fails with
```
Traceback (most recent call last):
  File "/data/users/dberard/scripts/graphed_index_put.py", line 24, in <module>
    fn(x, y, z)
  File "/data/users/dberard/scripts/graphed_index_put.py", line 8, in fn
    return x.index_put_([y], z, True)
RuntimeError: CUDA error: operation not permitted when stream is capturing
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/data/users/dberard/scripts/graphed_index_put.py", line 24, in <module>
    fn(x, y, z)
  File "/data/users/dberard/pytorch/torch/cuda/graphs.py", line 173, in __exit__
    self.cuda_graph.capture_end()
  File "/data/users/dberard/pytorch/torch/cuda/graphs.py", line 79, in capture_end
    super().capture_end()
RuntimeError: CUDA error: operation failed due to a previous error during capture
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
```

Differential Revision: [D47538548](https://our.internmc.facebook.com/intern/diff/D47538548)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105439
Approved by: https://github.com/eellison
2023-07-19 23:38:29 +00:00
Bin Bao
fe04c6c371 [inductor] Allow specify a subdir to store .so and .cubin files (#105466)
Summary: The subdir is used to store .so and .cubin files generated by AOTInductor. It can either be specified, or created based on a hash of the input graph.

Differential Revision: [D47556730](https://our.internmc.facebook.com/intern/diff/D47556730)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105466
Approved by: https://github.com/chenyang78
2023-07-19 03:13:50 +00:00
Nicolas Macchioni
6ca3d7e1a2 [pt2][inductor] only use global cache on MAST (#105375)
Summary:
Until we can further investigate the autotuning differences between MAST and non-MAST (devserver) environments, turn off the global cache for all non-MAST environments. This ensures we don't see unexpected regressions.

Also update scuba logging for cache lookup, and add scuba logging for autotuning results.

Test Plan: sandcastle + CI

Differential Revision: D47516633

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105375
Approved by: https://github.com/jansel
2023-07-18 06:16:47 +00:00
willfengg
8010f6bf48 [dynamo][inductor] Provide public API to get compiler options/configs (#105026)
issues resolved: https://github.com/pytorch/pytorch/issues/101832

**context**: get the torch.compile config for further usage. E.g., the training platform wants to know whether a model is compiled with cudagraphs enabled and trigger further action accordingly

**how it is implemented**
   * the core logic is backend.get_compiler_config() in torch/_dynamo/eval_frame.py
   * for backend='inductor' / _TorchCompileInductorWrapper, we have an inductor-specific implementation of get_compiler_config in torch/_inductor/compile_fx.py and torch/__init__.py

**how to use it**: Below is an example.

```
import torch

# A stand-in module so the snippet is self-contained.
class DummyModule(torch.nn.Module):
    def forward(self, x):
        return x + 1

model = DummyModule()
optimized_module = torch.compile(
    model, options={"triton.cudagraphs": True}
)
compiler_config = optimized_module.get_compiler_config()

if compiler_config["triton.cudagraphs"]:
    pass
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105026
Approved by: https://github.com/yanboliang, https://github.com/jansel
2023-07-18 06:12:06 +00:00
Edward Z. Yang
1152e86da1 Transmute refined SymInt into int (#104828)
Previously, x.size(0) could return a SymInt even when the internal
sympy expression was actually already constant (e.g., due to an
introduced guard). We now allow querying the Python object with
maybe_as_int, which lets us transmute these objects back to
int when possible.

It is still possible to end up with a constant SymInt even after this
change, e.g., if you get out a SymInt and while holding onto it
specialize it, but casual users are more likely to get ints when they
want to.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104828
Approved by: https://github.com/Skylion007
2023-07-15 18:46:10 +00:00
Edward Z. Yang
10cbc9a063 Enable cuda graphs for dynamic shapes (#105064)
The general idea is to do a separate CUDA graph for each size. Because of cuda graph trees, these graphs will all share the same memory pool, so your memory usage will only be the worst case memory usage of the biggest dynamic size you want. This requires an extra dispatch in the cudagraphified callable. You must pay for a CUDA graph recording for every dynamic size you encounter, but this is MUCH cheaper than running the entire PT2 compile stack, so I expect you to still see benefits.

This was surprisingly easy to do.
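
A conceptual sketch of the per-size dispatch (not the actual implementation; real cudagraph trees also share one memory pool across all recorded sizes):

```python
import torch

_graphs = {}

def run_cudagraphified(fn, inp):
    key = tuple(inp.shape)
    if key not in _graphs:
        # Pay one graph recording per new dynamic size -- much cheaper
        # than re-running the whole PT2 compile stack.
        g = torch.cuda.CUDAGraph()
        static_inp = inp.clone()
        with torch.cuda.graph(g):
            static_out = fn(static_inp)
        _graphs[key] = (g, static_inp, static_out)
    g, static_inp, static_out = _graphs[key]
    static_inp.copy_(inp)  # extra dispatch: copy into the static buffer
    g.replay()
    return static_out
```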

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105064
Approved by: https://github.com/voznesenskym
2023-07-14 16:13:50 +00:00
PyTorch MergeBot
1c69f363c4 Revert "Transmute refined SymInt into int (#104828)"
This reverts commit 0f322a300e.

Reverted https://github.com/pytorch/pytorch/pull/104828 on behalf of https://github.com/ezyang due to executorch failure ([comment](https://github.com/pytorch/pytorch/pull/104828#issuecomment-1635997559))
2023-07-14 15:08:11 +00:00
Edward Z. Yang
0f322a300e Transmute refined SymInt into int (#104828)
Previously, x.size(0) could return a SymInt even when the internal
sympy expression was actually already constant (e.g., due to an
introduced guard). We now allow querying the Python object with
maybe_as_int, which lets us transmute these objects back to
int when possible.

It is still possible to end up with a constant SymInt even after this
change, e.g., if you get out a SymInt and while holding onto it
specialize it, but casual users are more likely to get ints when they
want to.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104828
Approved by: https://github.com/Skylion007
2023-07-13 07:02:52 +00:00
Horace He
601db856d1 elevated cudagraphs failure to warning, added lineno to recompiles (#105081)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105081
Approved by: https://github.com/mlazos
2023-07-13 01:17:58 +00:00
Edward Z. Yang
979f826015 Read out real strides from compilation result, rather than real args (#105010)
This prefigures a refactor that will move the backward compilation
to entirely ahead of time, so I need to extract these strides some
other way.  Straight from the compiler's mouth will do it.

I can't easily get the information via the return result of `fw_compiler` without changing the calling convention, so instead I smuggle it via TracingContext. TracingContext may be None when we are compiling patterns for the joint graph pattern matcher.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105010
Approved by: https://github.com/shunting314
2023-07-12 11:33:08 +00:00
Edward Z. Yang
6059fea760 Make perf_hint_log report at info level (#104873)
If you do it at warning, these log messages will get displayed by
default, which is not the intended behavior.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104873
Approved by: https://github.com/mlazos
2023-07-10 23:46:34 +00:00
Adnan Akhundov
4911b80b8e [inductor] addmm + ReLU / GELU fusion pass (#104132)
Summary:

Add a new pass in `post_grad.py` for replacing addmm + ReLU / GELU activation with the corresponding `_addmm_activation` call (with `use_gelu=False` or `True`, respectively). The replacement is done only when `max_autotune_gemm=False` and the activation is fusible.
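
The rewrite, illustrated directly on tensors rather than via the pattern matcher (a hedged sketch; tensor shapes are illustrative):

```python
import torch

bias, mat1, mat2 = torch.randn(4), torch.randn(4, 8), torch.randn(8, 4)

# addmm followed by ReLU is equivalent to the fused kernel.
eager = torch.relu(torch.addmm(bias, mat1, mat2))
fused = torch._addmm_activation(bias, mat1, mat2)  # use_gelu=False -> ReLU
torch.testing.assert_close(eager, fused)
```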

Test Plan:

$ python test/inductor/test_pattern_matcher.py -k test_addmm_activation -v

(__main__.TestPaternMatcher.test_addmm_activation) ... /data/users/aakhundov/pytorch/torch/_inductor/compile_fx.py:128: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
Using FallbackKernel: aten._addmm_activation.default
Using FallbackKernel: aten._addmm_activation.default
/data/users/aakhundov/pytorch/torch/_dynamo/eval_frame.py:373: UserWarning: changing options to `torch.compile()` may require calling `torch._dynamo.reset()` to take effect
  warnings.warn(
frames [('total', 1), ('ok', 1)]
stats [('calls_captured', 2), ('unique_graphs', 1)]
aot_autograd [('total', 1), ('ok', 1)]
inductor []
ok

----------------------------------------------------------------------
Ran 1 test in 13.415s

OK

Reviewers: @eellison

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104132
Approved by: https://github.com/eellison, https://github.com/jansel
2023-07-10 16:44:14 +00:00
Edward Z. Yang
2385dad4b3 Enable automatic_dynamic_shapes by default (#103623)
Some notes:

* I now manually stop `_generate` jobs from running with cudagraphs, as it is unrealistic to expect to cudagraph autoregressive generation up to max sequence length; that would imply compiling the entire unrolled sequence generation. Concretely, cm3leon_generate was timing out post this change, likely due to the compile-time slowdown of dynamic shapes ON TOP OF accidentally unrolling all the loops
* A few torch._dynamo.reset calls tactically inserted to force recompiles on tests that expected them
* expectedFailureAutomaticDynamic flipped into patching automatic_dynamic_shapes=False
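
For reference, a hedged sketch of what automatic dynamic shapes means at runtime:

```python
import torch
import torch._dynamo

torch._dynamo.reset()  # drop cached graphs so the next call recompiles

fn = torch.compile(lambda x: x * 2)
fn(torch.randn(4))  # first compile: static shapes
fn(torch.randn(8))  # size changed: one recompile with dynamic shapes,
                    # after which further sizes reuse the dynamic graph
```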

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103623
Approved by: https://github.com/voznesenskym
2023-07-05 00:25:02 +00:00
Shunting Zhang
98f00f881f [inductor] convert layout of conv weight ahead of time for inference (#103642)
This PR handles inference. Will do similar thing for training later.

Some manual testing results show this can improve inference perf by 2-3% (absolute improvement, not relative).
- convmixer: 4.285x -> 4.309x
- resnet50: 2.170x -> 2.203x

The PR is built upon freezing, since without freezing the weight input for a conv node may not be a parameter directly but the output of precision-converting ops. It's much easier to implement this PR after freezing.

Commands
```
TORCHINDUCTOR_FREEZING=1 python benchmarks/dynamo/timm_models.py --backend inductor --amp --performance --only convmixer_768_32 --inference
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103642
Approved by: https://github.com/eellison
2023-06-28 17:42:32 +00:00
Elias Ellison
de7b6e55eb Fix bad cudagraph interaction (#104196)
Fix for https://github.com/pytorch/pytorch/issues/103126

As mentioned there,

> We need to make sure we are not removing the misaligned inputs before we are checking for misalignment in cudagraphs, so we know not to expect a static input for the misaligned tensors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104196
Approved by: https://github.com/desertfire
2023-06-27 21:36:09 +00:00
Elias Ellison
edc9c0df7e Fold Conv-Bn (#100653)
Adds Conv-BN folding to inductor freezing. One thing that's a little awkward now is we'll want different decompositions to run depending on if we are in the inference compiler. For now, I require that you run with torch.no_grad() so we can detect if no gradients are required before calling aot_autograd.
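
For reference, the algebra a conv+BN folding pass relies on, as a standalone sketch (not inductor's implementation; function name is illustrative):

```python
import torch

@torch.no_grad()
def fold_conv_bn(conv: torch.nn.Conv2d, bn: torch.nn.BatchNorm2d) -> torch.nn.Conv2d:
    # BN(conv(x)) = scale * conv(x) + shift, so the scale folds into the
    # conv weight and the shift into its bias.
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused = torch.nn.Conv2d(
        conv.in_channels, conv.out_channels, conv.kernel_size,
        stride=conv.stride, padding=conv.padding, bias=True,
    )
    fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
    bias = conv.bias if conv.bias is not None else torch.zeros(conv.out_channels)
    fused.bias.copy_((bias - bn.running_mean) * scale + bn.bias)
    return fused
```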

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100653
Approved by: https://github.com/jansel
2023-06-26 16:04:34 +00:00
Bin Bao
c1fffdcd5b Change how AOTInductor's fx input is produced (#104123)
Test Plan: CI

Reviewed By: wushirong

Differential Revision: D46983754

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104123
Approved by: https://github.com/chenyang78
2023-06-26 15:59:33 +00:00
Antoni Viros i Martin
0d653730ce Refactor bits for the codegen cache (#103452)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103452
Approved by: https://github.com/ezyang
2023-06-22 13:04:22 +00:00
Bin Bao
da7ca82121 [inductor] Store real inputs to be used for cpp wrapper codegen (#103289)
Summary: defaked args (zeros) may cause device-side access assertions, so
record the original real tensor inputs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103289
Approved by: https://github.com/jansel, https://github.com/eellison
2023-06-15 20:05:50 +00:00
Edward Z. Yang
bc6ec97e02 Switch dynamic_shapes to True by default (#103597)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103597
Approved by: https://github.com/voznesenskym
2023-06-15 15:16:20 +00:00
Elias Ellison
d083d444ff Inductor Freezing (#100652)
Adds a freezing pass, gated by inductor `config.freezing`, that constant-folds parameters. This occurs post-functionalization in aot_autograd, both to capture dispatching and to let passes run on the functionalized graph. A few notes:

- There is an option to discard parameters `config.freezing_discard_parameters` which will take the current eager modules and wrap parameters to a Tensor subclass which will error if used.
- I needed to expose flat_params in aot_autograd in order to discard old references when we constant-fold away parameters, as with amp. I also exposed `fw_metadata` to avoid constant-folding mutated parameters.
- Caching parameter transformations/constant folding across different inferences: not yet implemented
- Checking version_counter of constant-folded params: not yet implemented

I'm not really sure what the actual naming should be. In jit there was both "freezing", which was platform agnostic, and "optimize for inference", which made device specific optimizations. We're doing the latter here but maybe freezing is a better name.
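
A hedged usage sketch (the model and shapes are illustrative; the config flag is also reachable via `TORCHINDUCTOR_FREEZING=1`):

```python
import torch
import torch._inductor.config as inductor_config

# Enable the freezing pass so parameters are constant-folded into the
# compiled graph.
inductor_config.freezing = True

model = torch.nn.Linear(16, 16).eval()
opt_model = torch.compile(model)

with torch.no_grad():  # run inference-only so the pass sees no gradients
    out = opt_model(torch.randn(2, 16))
```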

Differential Revision: [D46244033](https://our.internmc.facebook.com/intern/diff/D46244033)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100652
Approved by: https://github.com/jansel
2023-06-12 20:56:03 +00:00
Edward Z. Yang
54daf870bc CUDA graphs overrides dynamic shapes and forces specialization (#103290)
Previously, cudagraphs and dynamic_shapes were incompatible and enabling
dynamic shapes would forcibly disable cudagraphs.  This new strategy
I think is better.  The idea is essentially that cudagraphs is an
"optimization" that happens to guard on every input.  When cudagraphs
is on, we force everything static, and this automatically does the right
thing because we will force a recompile if sizes change.

This obsoletes https://github.com/pytorch/pytorch/pull/101813

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103290
Approved by: https://github.com/voznesenskym, https://github.com/eellison
2023-06-12 20:26:55 +00:00
PyTorch MergeBot
d89dd05e4d Revert "CUDA graphs overrides dynamic shapes and forces specialization (#103290)"
This reverts commit c760f0e4dd.

Reverted https://github.com/pytorch/pytorch/pull/103290 on behalf of https://github.com/ezyang due to to handle the other cuda graphs case ([comment](https://github.com/pytorch/pytorch/pull/103290#issuecomment-1584977767))
2023-06-09 18:25:28 +00:00
Edward Z. Yang
c760f0e4dd CUDA graphs overrides dynamic shapes and forces specialization (#103290)
Previously, cudagraphs and dynamic_shapes were incompatible and enabling
dynamic shapes would forcibly disable cudagraphs.  This new strategy
I think is better.  The idea is essentially that cudagraphs is an
"optimization" that happens to guard on every input.  When cudagraphs
is on, we force everything static, and this automatically does the right
thing because we will force a recompile if sizes change.

This obsoletes https://github.com/pytorch/pytorch/pull/101813

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103290
Approved by: https://github.com/voznesenskym
2023-06-09 17:43:47 +00:00
Bin Bao
49577c7e47 [inductor] Turn off autotune_cublasLt for cpp_wrapper (#103004)
Summary: bias_addmm is not backed by a cpp function, so turn off
autotune_cublasLt for cpp_wrapper + max_autotune. We can add a cpp
function implementation if there is a performance need.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103004
Approved by: https://github.com/jansel
2023-06-06 14:08:05 +00:00
Bin Bao
44fdfd3222 [inductor] Support select_algorithm with cpp_wrapper (#103003)
Summary: This is one step towards getting cpp_wrapper to work with max_autotune.
Switch to use unique kernel name to cache generated cubin file.

This is a copy of https://github.com/pytorch/pytorch/pull/102738 to solve a ghstack issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103003
Approved by: https://github.com/jansel
2023-06-06 14:08:05 +00:00
Bin Bao
881307abcf [inductor] Fix a cpp_wrapper issue when fx_passes modified fx graph (#102851)
Summary: Currently cpp_wrapper for CUDA does codegen in two passes, which
means we need to deepcopy the input module to isolate any fx
transformations between the two passes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102851
Approved by: https://github.com/jansel
2023-06-05 00:20:38 +00:00
Shunting Zhang
86c7652503 [inductor] layout optimization for conv (#99773)
A convolution kernel with channels-last inputs runs much faster than one with contiguous inputs. This PR leverages that to optimize tensor layouts so we provide channels-last inputs to convolutions. Some care needs to be taken to not convert tensor layouts between contiguous and channels-last back and forth: those extra copies hurt performance quite a bit.

Latest perf number [here](https://hud.pytorch.org/benchmark/compilers?startTime=Wed%2C%2024%20May%202023%2023%3A40%3A37%20GMT&stopTime=Wed%2C%2031%20May%202023%2023%3A40%3A37%20GMT&granularity=hour&suite=torchbench&mode=training&dtype=amp&lBranch=shunting-layout-opt-19&lCommit=baa797fc100688dfb044fbcbdebcfd2591710f78&rBranch=main&rCommit=999bae0f54108ffc5b7cf2524a02a83901554b16)
- TB: 1.64x -> 1.69x
- HF: 1.79x -> 1.78x (random noise)
- TIMM: 1.51x -> 1.65x

Right now we disable layout optimization for dynamic shape since there is perf loss in that combination. Here is a GH issue to followup: https://github.com/pytorch/pytorch/issues/102670
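
A hedged illustration of the preferred layout (shapes are illustrative):

```python
import torch

# Convert both conv inputs to channels-last once, and keep the whole
# region in that layout to avoid contiguous <-> channels-last copies.
x = torch.randn(8, 3, 32, 32).to(memory_format=torch.channels_last)
w = torch.randn(16, 3, 3, 3).to(memory_format=torch.channels_last)

out = torch.nn.functional.conv2d(x, w)
# The convolution propagates the layout, so downstream ops also see
# channels-last inputs.
assert out.is_contiguous(memory_format=torch.channels_last)
```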

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99773
Approved by: https://github.com/jansel
2023-06-02 21:08:18 +00:00
spectrometerHBH
5ee46afc05 perf hint logging in inductor (#102250)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102250
Approved by: https://github.com/Skylion007, https://github.com/shunting314, https://github.com/jansel
2023-05-27 03:43:30 +00:00
Animesh Jain
9c4fd72b53 [aot_autograd][functional_rng] Change calling convention (#102344)
Key change: seed and offset are the last 2 args in both the fwd and bwd graphs.
Reason: the cudagraphs implementation in inductor currently relies on very simple ordering guarantees, i.e. the first n inputs are static for both fwd and bwd graphs. In the current implementation of functionalization of RNG ops, this assumption is broken because the first 2 inputs are seed and offset.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102344
Approved by: https://github.com/eellison
2023-05-26 21:27:20 +00:00
Bin Bao
fd1d442185 [inductor] Add more dynamic shapes support for CudaWrapperCodeGen (#102019)
Summary: Use size hints for autotuning; fix some symbol-arg codegen
problems. More PRs coming to fix unit test failures.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102019
Approved by: https://github.com/jansel
2023-05-24 13:29:47 +00:00
Jason Ansel
0c6f409cda [inductor] Refactor RNG operators (#100064)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100064
Approved by: https://github.com/ngimel
2023-05-20 03:43:33 +00:00
PyTorch MergeBot
5f07c589b0 Revert "[inductor] Refactor RNG operators (#100064)"
This reverts commit 3bbf0683a1.

Reverted https://github.com/pytorch/pytorch/pull/100064 on behalf of https://github.com/izaitsevfb due to breaks inductor tests, see D45936056 ([comment](https://github.com/pytorch/pytorch/pull/100064#issuecomment-1552093728))
2023-05-17 21:16:41 +00:00
Jason Ansel
3bbf0683a1 [inductor] Refactor RNG operators (#100064)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100064
Approved by: https://github.com/ngimel
2023-05-17 01:29:31 +00:00
Daohang Shi
2af7df62a5 log inductor compilation time to scuba (#101317)
Summary: Set up timer around `compile_fx_inner` and log to scuba

Differential Revision: D45822137

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101317
Approved by: https://github.com/nmacchioni
2023-05-16 16:32:17 +00:00
Elias Ellison
3b7c6b21d7 Disable locality reordering in training (#101423)
Differential Revision: [D45874682](https://our.internmc.facebook.com/intern/diff/D45874682)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101423
Approved by: https://github.com/ngimel
2023-05-15 21:34:49 +00:00
Bin Bao
86ddfc7f68 [inductor] Move cpp wrapper trigger logic to inner_compile (#100611)
Summary: This enables cpp wrapper for backward as well.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100611
Approved by: https://github.com/jansel
2023-05-08 15:24:02 +00:00