Commit Graph

102 Commits

Author SHA1 Message Date
Adnan Akhundov
4911b80b8e [inductor] addmm + ReLU / GELU fusion pass (#104132)
Summary:

Add a new path in `post_grad.py` for replacing addmm + ReLU / GELU activation with the corresponding `_addmm_activation` call (with `use_gelu=False` or `True`, respectively). The replacement is done only when `max_autotune_gemm=False` and the activation is fusible.
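
For reference, a minimal eager-mode sketch of the fused op the pass targets (assuming the private `torch._addmm_activation` binding is callable directly; the actual replacement happens on the post-grad fx graph, not in user code):

```
import torch

bias = torch.randn(16)
a = torch.randn(8, 32)
b = torch.randn(32, 16)

# addmm followed by ReLU, expressed as a single fused op
fused = torch._addmm_activation(bias, a, b, use_gelu=False)
reference = torch.relu(torch.addmm(bias, a, b))
torch.testing.assert_close(fused, reference)
# use_gelu=True fuses a GELU epilogue instead of ReLU
```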

Test Plan:

$ python test/inductor/test_pattern_matcher.py -k test_addmm_activation -v

(__main__.TestPaternMatcher.test_addmm_activation) ... /data/users/aakhundov/pytorch/torch/_inductor/compile_fx.py:128: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
Using FallbackKernel: aten._addmm_activation.default
Using FallbackKernel: aten._addmm_activation.default
/data/users/aakhundov/pytorch/torch/_dynamo/eval_frame.py:373: UserWarning: changing options to `torch.compile()` may require calling `torch._dynamo.reset()` to take effect
  warnings.warn(
frames [('total', 1), ('ok', 1)]
stats [('calls_captured', 2), ('unique_graphs', 1)]
aot_autograd [('total', 1), ('ok', 1)]
inductor []
ok

----------------------------------------------------------------------
Ran 1 test in 13.415s

OK

Reviewers: @eellison

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104132
Approved by: https://github.com/eellison, https://github.com/jansel
2023-07-10 16:44:14 +00:00
Edward Z. Yang
2385dad4b3 Enable automatic_dynamic_shapes by default (#103623)
Some notes:

* I now manually stop `_generate` jobs from running with cudagraphs, as it is unrealistic to expect to cudagraph autoregressive generation up to the max sequence length; that would imply compiling the entire unrolled sequence generation. Concretely, cm3leon_generate was timing out after this change, likely due to the compile-time slowdown of dynamic shapes on top of accidentally unrolling all the loops.
* A few `torch._dynamo.reset` calls are tactically inserted to force recompiles on tests that expect them.
* `expectedFailureAutomaticDynamic` is flipped into patching `automatic_dynamic_shapes=False` (a short sketch of the new default behavior follows below).
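
A minimal sketch of what the new default does (assuming a build where `automatic_dynamic_shapes` exists under `torch._dynamo.config`):

```
import torch
import torch._dynamo

torch._dynamo.config.automatic_dynamic_shapes = True  # now the default
torch._dynamo.reset()                                  # drop compiled frames so the option takes effect

@torch.compile
def f(x):
    return x * 2

f(torch.randn(4))   # first call: compiled with static shapes
f(torch.randn(8))   # size change: one recompile, marking the changing dimension dynamic
```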

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103623
Approved by: https://github.com/voznesenskym
2023-07-05 00:25:02 +00:00
Shunting Zhang
98f00f881f [inductor] convert layout of conv weight ahead of time for inference (#103642)
This PR handles inference; a similar change for training will follow later.

Some manual testing results show this can improve inference perf by 2-3% (absolute improvement, not relative):
- convmixer: 4.285x -> 4.309x
- resnet50: 2.170x -> 2.203x

The PR is built upon freezing: without freezing, the weight input of a conv node may not be a parameter directly but the output of precision-converting ops, so it is much easier to implement this change after freezing.
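
A rough eager-mode sketch of the kind of ahead-of-time weight conversion the pass bakes into the frozen graph (illustrative only; the real pass rewrites the constant-folded weight inside the fx graph):

```
import torch

conv = torch.nn.Conv2d(3, 8, kernel_size=3).eval()

# convert the (frozen) weight to channels-last once, instead of on every inference call
conv = conv.to(memory_format=torch.channels_last)
assert conv.weight.is_contiguous(memory_format=torch.channels_last)

x = torch.randn(1, 3, 32, 32).to(memory_format=torch.channels_last)
out = conv(x)
```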

Commands
```
TORCHINDUCTOR_FREEZING=1 python benchmarks/dynamo/timm_models.py --backend inductor --amp --performance --only convmixer_768_32 --inference
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103642
Approved by: https://github.com/eellison
2023-06-28 17:42:32 +00:00
Elias Ellison
de7b6e55eb Fix bad cudagraph interaction (#104196)
Fix for https://github.com/pytorch/pytorch/issues/103126

As mentioned there,

> We need to make sure we are not removing the misaligned inputs before we are checking for misalignment in cudagraphs, so we know not to expect a static input for the misaligned tensors.
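
A small illustration of the kind of misalignment in question (a sketch assuming a 16-byte alignment requirement, which may not be the exact constant inductor checks; requires a GPU):

```
import torch

base = torch.randn(9, device="cuda")
view = base[1:]                  # starts 4 bytes into base's storage

print(base.data_ptr() % 16)      # 0: aligned, can be treated as a static cudagraph input
print(view.data_ptr() % 16)      # 4: misaligned, so cudagraphs must not expect it to stay static
```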

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104196
Approved by: https://github.com/desertfire
2023-06-27 21:36:09 +00:00
Elias Ellison
edc9c0df7e Fold Conv-Bn (#100653)
Adds Conv-BN folding to inductor freezing. One thing that's a little awkward now is that we'll want different decompositions to run depending on whether we are in the inference compiler. For now, I require that you run with torch.no_grad() so we can detect that no gradients are required before calling aot_autograd.
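
For intuition, a small sketch of the underlying algebraic fold using an existing eager utility (this is not the code path the inductor pass takes, just the same math):

```
import torch
from torch.nn.utils.fusion import fuse_conv_bn_eval

conv = torch.nn.Conv2d(3, 8, kernel_size=3).eval()
bn = torch.nn.BatchNorm2d(8).eval()

fused = fuse_conv_bn_eval(conv, bn)   # a single Conv2d with BN folded into weight/bias

x = torch.randn(1, 3, 16, 16)
with torch.no_grad():                 # inference-only, matching the requirement described above
    torch.testing.assert_close(fused(x), bn(conv(x)))
```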

Differential Revision: [](https://our.internmc.facebook.com/intern/diff/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100653
Approved by: https://github.com/jansel
2023-06-26 16:04:34 +00:00
Bin Bao
c1fffdcd5b Change how AOTInductor's fx input is produced (#104123)
Test Plan: CI

Reviewed By: wushirong

Differential Revision: D46983754

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104123
Approved by: https://github.com/chenyang78
2023-06-26 15:59:33 +00:00
Antoni Viros i Martin
0d653730ce Refactor bits for the codegen cache (#103452)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103452
Approved by: https://github.com/ezyang
2023-06-22 13:04:22 +00:00
Bin Bao
da7ca82121 [inductor] Store real inputs to be used for cpp wrapper codegen (#103289)
Summary: Defaked args (zeros) may cause device-side access assertions, so
record the original real tensor inputs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103289
Approved by: https://github.com/jansel, https://github.com/eellison
2023-06-15 20:05:50 +00:00
Edward Z. Yang
bc6ec97e02 Switch dynamic_shapes to True by default (#103597)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103597
Approved by: https://github.com/voznesenskym
2023-06-15 15:16:20 +00:00
Elias Ellison
d083d444ff Inductor Freezing (#100652)
Adds a freezing pass, gated by inductor `config.freezing`, that constant-folds parameters. It runs after functionalization in aot_autograd, both to capture dispatching and to allow passes to run on the functionalized graph. A few notes (a short usage sketch follows below):

- There is an option, `config.freezing_discard_parameters`, to discard parameters: it takes the current eager modules and wraps their parameters in a Tensor subclass that errors if used.
- I needed to expose flat_params in aot_autograd in order to discard old references when we constant-fold away parameters, as with amp. I also exposed `fw_metadata` to avoid constant folding mutated parameters.
- Caching parameter transformations/constant folding across different inference runs is not yet implemented.
- Checking the version_counter of constant-folded params is not yet implemented.

I'm not really sure what the naming should be. In JIT there was both "freezing", which was platform agnostic, and "optimize for inference", which made device-specific optimizations. We're doing the latter here, but maybe freezing is a better name.
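
A minimal usage sketch based on the description above (config names as given in this commit; `MyModel` and `example_input` are placeholders):

```
import torch
import torch._inductor.config as inductor_config

inductor_config.freezing = True                        # enable the constant-folding freezing pass
# inductor_config.freezing_discard_parameters = True   # optionally wrap/discard the eager parameters

model = MyModel().eval()                               # placeholder module
with torch.no_grad():                                  # required so aot_autograd sees no gradients are needed
    compiled = torch.compile(model)
    out = compiled(example_input)                      # placeholder input
```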

Differential Revision: [D46244033](https://our.internmc.facebook.com/intern/diff/D46244033)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100652
Approved by: https://github.com/jansel
2023-06-12 20:56:03 +00:00
Edward Z. Yang
54daf870bc CUDA graphs overrides dynamic shapes and forces specialization (#103290)
Previously, cudagraphs and dynamic_shapes were incompatible, and enabling
dynamic shapes would forcibly disable cudagraphs.  I think this new
strategy is better.  The idea is essentially that cudagraphs is an
"optimization" that happens to guard on every input.  When cudagraphs
is on, we force everything static, and this automatically does the right
thing because we will force a recompile if sizes change.
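
A sketch of the resulting behavior (assuming a CUDA device; `mode="reduce-overhead"` is the public switch for cudagraphs):

```
import torch

@torch.compile(mode="reduce-overhead", dynamic=True)   # cudagraphs on, dynamic shapes requested
def f(x):
    return x.sin()

f(torch.randn(4, device="cuda"))   # recorded as a static-size graph
f(torch.randn(8, device="cuda"))   # size change forces a recompile and a new graph,
                                   # rather than a single dynamic-shape kernel
```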

This obsoletes https://github.com/pytorch/pytorch/pull/101813

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103290
Approved by: https://github.com/voznesenskym, https://github.com/eellison
2023-06-12 20:26:55 +00:00
PyTorch MergeBot
d89dd05e4d Revert "CUDA graphs overrides dynamic shapes and forces specialization (#103290)"
This reverts commit c760f0e4dd.

Reverted https://github.com/pytorch/pytorch/pull/103290 on behalf of https://github.com/ezyang due to to handle the other cuda graphs case ([comment](https://github.com/pytorch/pytorch/pull/103290#issuecomment-1584977767))
2023-06-09 18:25:28 +00:00
Edward Z. Yang
c760f0e4dd CUDA graphs overrides dynamic shapes and forces specialization (#103290)
Previously, cudagraphs and dynamic_shapes were incompatible, and enabling
dynamic shapes would forcibly disable cudagraphs.  I think this new
strategy is better.  The idea is essentially that cudagraphs is an
"optimization" that happens to guard on every input.  When cudagraphs
is on, we force everything static, and this automatically does the right
thing because we will force a recompile if sizes change.

This obsoletes https://github.com/pytorch/pytorch/pull/101813

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103290
Approved by: https://github.com/voznesenskym
2023-06-09 17:43:47 +00:00
Bin Bao
49577c7e47 [inductor] Turn off autotune_cublasLt for cpp_wrapper (#103004)
Summary: bias_addmm is not backed by a cpp function, so turn off
autotune_cublasLt for cpp_wrapper + max_autotune. We can add a cpp
function implementation if there is a performance need.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103004
Approved by: https://github.com/jansel
2023-06-06 14:08:05 +00:00
Bin Bao
44fdfd3222 [inductor] Support select_algorithm with cpp_wrapper (#103003)
Summary: This is one step towards getting cpp_wrapper to work with max_autotune.
Switch to using a unique kernel name to cache the generated cubin file.

This is a copy of https://github.com/pytorch/pytorch/pull/102738 to solve a ghstack issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103003
Approved by: https://github.com/jansel
2023-06-06 14:08:05 +00:00
Bin Bao
881307abcf [inductor] Fix a cpp_wrapper issue when fx_passes modified fx graph (#102851)
Summary: Currently cpp_wrapper for CUDA runs codegen in two passes, which
means we need to deepcopy the input module to isolate any fx
transformations between the two passes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102851
Approved by: https://github.com/jansel
2023-06-05 00:20:38 +00:00
Shunting Zhang
86c7652503 [inductor] layout optimization for conv (#99773)
A convolution kernel with channels-last inputs runs much faster than one with contiguous inputs. The PR leverages that to optimize tensor layouts so we provide channels-last inputs to convolution. Some care needs to be taken not to convert tensor layouts back and forth between contiguous and channels-last, since those extra copies hurt performance quite a bit.
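
For reference, the eager-mode layout the pass arranges automatically (a hand-written sketch; when compiling, inductor inserts the equivalent conversions itself):

```
import torch

conv = torch.nn.Conv2d(64, 64, kernel_size=3, padding=1).to(memory_format=torch.channels_last)
x = torch.randn(8, 64, 56, 56).to(memory_format=torch.channels_last)

out = conv(x)
# the output stays channels-last, so chains of convolutions avoid layout round-trips
assert out.is_contiguous(memory_format=torch.channels_last)
```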

Latest perf number [here](https://hud.pytorch.org/benchmark/compilers?startTime=Wed%2C%2024%20May%202023%2023%3A40%3A37%20GMT&stopTime=Wed%2C%2031%20May%202023%2023%3A40%3A37%20GMT&granularity=hour&suite=torchbench&mode=training&dtype=amp&lBranch=shunting-layout-opt-19&lCommit=baa797fc100688dfb044fbcbdebcfd2591710f78&rBranch=main&rCommit=999bae0f54108ffc5b7cf2524a02a83901554b16)
- TB: 1.64x -> 1.69x
- HF: 1.79x -> 1.78x (random noise)
- TIMM: 1.51x -> 1.65x

Right now we disable layout optimization for dynamic shapes since there is a perf loss in that combination. Here is a GH issue to follow up: https://github.com/pytorch/pytorch/issues/102670

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99773
Approved by: https://github.com/jansel
2023-06-02 21:08:18 +00:00
spectrometerHBH
5ee46afc05 perf hint logging in inductor (#102250)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102250
Approved by: https://github.com/Skylion007, https://github.com/shunting314, https://github.com/jansel
2023-05-27 03:43:30 +00:00
Animesh Jain
9c4fd72b53 [aot_autograd][functional_rng] Change calling convention (#102344)
Key change - seed and offset are the last 2 args in both the fwd and bwd graphs.
Reason - The cudagraphs implementation in inductor currently relies on a very simple ordering guarantee, i.e., the first n inputs are static for both fwd and bwd graphs. In the current implementation of functionalization of rng ops, this assumption is broken because the first 2 inputs are seed and offset.
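
Illustrative (hypothetical) graph signatures for the change; the actual argument names come from aot_autograd's codegen and may differ:

```
# before: seed/offset came first, breaking cudagraphs' "first n inputs are static" assumption
def forward(fwd_seed, fwd_base_offset, primals_1, primals_2):
    ...

# after: seed/offset are the last two args in both the forward and backward graphs
def forward(primals_1, primals_2, fwd_seed, fwd_base_offset):
    ...

def backward(tangents_1, saved_tensor_1, bwd_seed, bwd_base_offset):
    ...
```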

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102344
Approved by: https://github.com/eellison
2023-05-26 21:27:20 +00:00
Bin Bao
fd1d442185 [inductor] Add more dynamic shapes support for CudaWrapperCodeGen (#102019)
Summary: Use size hints for autotuning; fix some symbol-arg codegen
problems. More PRs are coming to fix unit test failures.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102019
Approved by: https://github.com/jansel
2023-05-24 13:29:47 +00:00
Jason Ansel
0c6f409cda [inductor] Refactor RNG operators (#100064)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100064
Approved by: https://github.com/ngimel
2023-05-20 03:43:33 +00:00
PyTorch MergeBot
5f07c589b0 Revert "[inductor] Refactor RNG operators (#100064)"
This reverts commit 3bbf0683a1.

Reverted https://github.com/pytorch/pytorch/pull/100064 on behalf of https://github.com/izaitsevfb due to breaks inductor tests, see D45936056 ([comment](https://github.com/pytorch/pytorch/pull/100064#issuecomment-1552093728))
2023-05-17 21:16:41 +00:00
Jason Ansel
3bbf0683a1 [inductor] Refactor RNG operators (#100064)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100064
Approved by: https://github.com/ngimel
2023-05-17 01:29:31 +00:00
Daohang Shi
2af7df62a5 log inductor compilation time to scuba (#101317)
Summary: Set up a timer around `compile_fx_inner` and log the result to Scuba.
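
Roughly, the instrumentation amounts to something like the sketch below (`log_to_scuba` is a stand-in for the internal logging call, not a real API):

```
import time
from torch._inductor.compile_fx import compile_fx_inner

def timed_compile_fx_inner(gm, example_inputs, **kwargs):
    start = time.time()
    result = compile_fx_inner(gm, example_inputs, **kwargs)          # the wrapped inductor entry point
    log_to_scuba({"inductor_compile_time_s": time.time() - start})   # hypothetical logging hook
    return result
```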

Differential Revision: D45822137

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101317
Approved by: https://github.com/nmacchioni
2023-05-16 16:32:17 +00:00
Elias Ellison
3b7c6b21d7 Disable locality reordering in training (#101423)
Differential Revision: [D45874682](https://our.internmc.facebook.com/intern/diff/D45874682)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101423
Approved by: https://github.com/ngimel
2023-05-15 21:34:49 +00:00
Bin Bao
86ddfc7f68 [inductor] Move cpp wrapper trigger logic to inner_compile (#100611)
Summary: This enables cpp wrapper for backward as well.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100611
Approved by: https://github.com/jansel
2023-05-08 15:24:02 +00:00
Bin Bao
ec3c8abb54 [inductor] Remove redundant model copy when running with cpp_wrapper (#100275)
Summary: to reduce the peak memory consumption

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100275
Approved by: https://github.com/jansel
2023-05-02 16:43:18 +00:00
Bin Bao
afa9d10ed6 [inductor] Support mixed device in cpp wrapper (#99950)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99950
Approved by: https://github.com/jgong5, https://github.com/jansel
2023-04-26 16:26:56 +00:00
Bin Bao
e43918b93a [inductor] Fix AOTInductor (#99203)
Summary: Fix the broken AOTInductor flow and add a smoketest on CI.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99203
Approved by: https://github.com/jansel
2023-04-25 14:42:12 +00:00
Edward Z. Yang
a109453df4 Delete use_functionalize feature flag (#99317)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99317
Approved by: https://github.com/voznesenskym
2023-04-18 02:09:57 +00:00
Edward Z. Yang
17d7be68ee Delete functorch use_fake_tensor and debug_fake_cross_ref (#99314)
Using fake tensor with AOTAutograd is now mandatory, simplifying our
logic.  Unfortunately, this means debug_fake_cross_ref must go,
but I don't think anyone has used it recently.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99314
Approved by: https://github.com/eellison, https://github.com/zou3519
2023-04-18 02:09:54 +00:00
Jason Ansel
6e1e27fc4e [inductor] Refactor pre-grad passes into inductor.fx_passes (#99130)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99130
Approved by: https://github.com/ngimel
2023-04-16 04:05:56 +00:00
PyTorch MergeBot
629377ea8b Revert "Replace _dynamo.config with an object instead of module (#96455)"
This reverts commit 420104a886.

Reverted https://github.com/pytorch/pytorch/pull/96455 on behalf of https://github.com/jansel due to BC breaking, was landed prematurely
2023-04-12 15:06:14 +00:00
XiaobingSuper
9c98f2ceb7 inductor: rewrite mkldnn fx fusion using pattern_matcher(binary) (#97141)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97141
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/jansel
2023-04-12 06:23:03 +00:00
XiaobingSuper
c214c50355 inductor: rewrite mkldnn fx fusion using pattern_matcher(conv_unary) (#97007)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97007
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/jansel
2023-04-12 05:52:54 +00:00
Han Qi
420104a886 Replace _dynamo.config with an object instead of module (#96455)
Summary:
    Replace _dynamo.config with an object instead of module

    Current usage patterns of setting and reading fields on config will work
    unchanged.

    Only changes needed going forward:
    1. import torch._dynamo.config will not work. However, just doing
       import torch._dynamo is sufficient to access dynamo config
       as torch._dynamo.config.

    2. Files inside the _dynamo folder need to access config via
       from torch._dynamo.config_util import config instead of
       from torch._dynamo import config, because _dynamo/__init__.py
       imports some of those files and a direct import would be circular
       (see the usage sketch below).
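
A short usage sketch mirroring the two points above (module name `config_util` as given in this description; note that this change was later reverted, per the revert entry earlier in this log):

```
# outside torch._dynamo: existing read/write patterns keep working
import torch._dynamo
torch._dynamo.config.suppress_errors = True

# inside the _dynamo package itself: import the object directly to avoid a circular import
from torch._dynamo.config_util import config
config.suppress_errors = True
```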

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96455
Approved by: https://github.com/williamwen42
2023-04-11 21:23:32 +00:00
Elias Ellison
76ac454146 Index expanded dims before checking memory overlap (#98656)
As the comment for `get_expanded_dims` says:

```
# copy_ fails when trying to write to tensors with memory overlap,
# for expanded dimensions (a dimension which used to have size 1 -> ?)
# we can select one element from that dimension and write to it
# to achieve writing to all values of that dimension of the input tensor
```

We were doing this for the copy, but not for checking whether we could copy. Update it so we index first and then check for memory overlap. This covers all of the `complex_striding` warnings I observed in TB.
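
A small repro of the underlying behavior (a sketch, not inductor's actual code):

```
import torch

buf = torch.randn(1, 4).expand(8, 4)   # dim 0 is expanded: all 8 rows alias one row of storage
src = torch.randn(8, 4)

# buf.copy_(src) would fail, because more than one element of buf maps to the same memory.
# Selecting a single index along the expanded dim removes the overlap; writing that slice
# still updates every (aliased) row of the expanded tensor.
buf.select(0, 0).copy_(src.select(0, 0))
assert torch.equal(buf[3], src[0])
```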

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98656
Approved by: https://github.com/ngimel, https://github.com/yf225
2023-04-10 22:58:32 +00:00
Kazuaki Ishizaki
f011db345f Fix typos under torch/_inductor directory (#97592)
This PR fixes typos in comments and messages of `.py` files under `torch/_inductor` directory

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97592
Approved by: https://github.com/dagitses, https://github.com/kit1980
2023-04-10 22:53:18 +00:00
Jason Ansel
8fee46693c Fused attention patterns (#97741)
Patterns based on https://github.com/pytorch/pytorch/pull/94729 mainly as a forcing function for implementing joint graph replacements.

Up until now, we had two places to do pattern matching:
1) Pre-grad has janky infra (graph not normalized or functional), but is
   desirable for many types of passes where you want your change to
   affect grad formulas.
2) Post-grad has good infra, but can't change grad formulas.

This PR adds a third place to do pattern matching: the joint
forward+backwards graph.  The idea is to take the patterns and lower
them to a joint graph and replace both the forwards+backwards before
we partition them.  This allows us to do something similar to pre-grad
transforms, but run after normalization and functionalization.

Note that we don't seem to have kernels for all of these patterns; some get decomposed in the dispatcher.
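
As a concrete example of the kind of pattern involved, a hedged sketch in eager terms (the real patterns are registered against the traced joint graph, not written like this):

```
import torch
import torch.nn.functional as F

def attention_math(q, k, v):
    # the "unfused" pattern as it appears in user code
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
    return scores.softmax(dim=-1) @ v

def attention_fused(q, k, v):
    # what the joint-graph replacement would lower it to
    return F.scaled_dot_product_attention(q, k, v)

q = k = v = torch.randn(2, 4, 16, 8)
torch.testing.assert_close(attention_math(q, k, v), attention_fused(q, k, v), rtol=1e-4, atol=1e-4)
```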

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97741
Approved by: https://github.com/Chillee
2023-04-10 00:35:22 +00:00
Bin Bao
152d65ae1d [reland][inductor] Enable CudaWrapperCodeGen for non-AOT mode (#98534)
Summary: This is a reland of #98264.

When _inductor.config.cpp_wrapper is specified, we run a
two-pass wrapper codegen to generate wrapper code in cpp that calls
cuLaunchKernel to launch pre-compiled CUDA kernels, and then call
load_inline to load that generated wrapper back into the Python world.
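
A toy sketch of the second pass's loading mechanism: compile some C++ source with `load_inline` and call it from Python (this is not the generated wrapper itself, and it needs a working C++ toolchain):

```
import torch
from torch.utils.cpp_extension import load_inline

cpp_source = """
at::Tensor call_wrapper(at::Tensor x) {
    // a real generated wrapper would launch pre-compiled kernels via cuLaunchKernel here
    return x * 2;
}
"""

mod = load_inline(name="toy_cpp_wrapper", cpp_sources=cpp_source, functions=["call_wrapper"])
print(mod.call_wrapper(torch.ones(3)))   # tensor([2., 2., 2.])
```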

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98534
Approved by: https://github.com/huydhn
2023-04-07 02:04:03 +00:00
PyTorch MergeBot
f228b3977b Revert "[inductor] Enable CudaWrapperCodeGen for non-AOT mode (#98264)"
This reverts commit 77f32eb6cc.

Reverted https://github.com/pytorch/pytorch/pull/98264 on behalf of https://github.com/huydhn due to Sorry for reverting your PR, but this is failing in trunk due to a name error fake_mode_from_tensors is not defined 67d1a77086. This is probably a landrace
2023-04-06 19:00:09 +00:00
Bin Bao
77f32eb6cc [inductor] Enable CudaWrapperCodeGen for non-AOT mode (#98264)
Summary: When _inductor.config.cpp_wrapper is specified, we run a
two-pass wrapper codegen to generate wrapper code in cpp that calls
cuLaunchKernel to launch pre-compiled CUDA kernels, and then call
load_inline to load that generated wrapper back into the Python world.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98264
Approved by: https://github.com/ngimel
2023-04-06 15:59:55 +00:00
Edward Z. Yang
680bf14a40 [EASY] Fix some more places where we incorrectly assume only Tensor (#98310)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98310
Approved by: https://github.com/voznesenskym
2023-04-06 00:57:59 +00:00
Edward Z. Yang
d01ee10b25 Add detect_fake_mode (#98321)
This replaces fake_mode_from_tensors, but it preferentially looks for a
fake mode in TracingContext, and then for an active fake mode on the
dispatch stack, before groveling through tensors to find it.

This advances PegasusForCausalLM, which was previously failing because
we generated a graph that had a parameter (non-fake) and a SymInt, and
thus we failed to detect the correct fake mode.
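
A small sketch of the API as described (the import path, assumed here to be `torch._guards`, is where it landed around this time):

```
import torch
from torch._guards import detect_fake_mode            # import location assumed
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode()
fake_input = fake_mode.from_tensor(torch.randn(2, 2))
real_param = torch.nn.Parameter(torch.randn(2, 2))     # non-fake, like the PegasusForCausalLM case

# even with a mix of fake and non-fake inputs, the fake mode is found
assert detect_fake_mode([fake_input, real_param]) is fake_mode
```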

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98321
Approved by: https://github.com/voznesenskym
2023-04-05 22:15:16 +00:00
Jason Ansel
3344d79e3f Pattern matcher improvements (#97740)
This adds support for multi-output patterns and example-based
replacements.

Tests/usage are next in this PR stack.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97740
Approved by: https://github.com/ngimel
2023-04-05 15:25:34 +00:00
Elias Ellison
feb9ec4282 Account for forwards which whose corresponding backwards are not invoked (#98112)
Previously, running a forward graph whose backward we never invoked would prevent us from switching from warmup to recording. Now the heuristic is refined to increment the generation as soon as we invoke a backward graph. This still handles the
```
mod1 = torch.compile(...)

mod2 = torch.compile(...)

mod2(mod1(x)).sum().backward()
```
case while accounting for graphs which we may not run backward of.

It also now handles the case where we skip cudagraphify the backward of a forward.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98112
Approved by: https://github.com/jansel
2023-04-05 06:12:16 +00:00
Bin Bao
96f548a1ac [inductor] Add an AOT mode for the Triton backend (#98214)
Summary:
This is a copy of https://github.com/pytorch/pytorch/pull/97152 to make
the landing easier.

This PR implements a two-pass wrapper codegen for the Triton
backend to achieve ahead-of-time compilation. In the first pass, the
regular python wrapper code will be generated, and then the generated
code will be executed to perform Triton compilation and autotuning.
After that, the second pass wrapper codegen will generate C++ wrapper
with proper CUDA API to load and launch Triton-generated CUDA kernels.

As with the AOT mode for the cpp backend, the next step would be to provide
a more complete API for AOT.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98214
Approved by: https://github.com/eellison
2023-04-03 22:19:18 +00:00
PyTorch MergeBot
aee96e2cb3 Revert "[inductor] Refactor cpp_wrapper to be an attribute of GraphLowering (#97709)"
This reverts commit 8710dc8d5a.

Reverted https://github.com/pytorch/pytorch/pull/97709 on behalf of https://github.com/malfet due to Broke cpu_wrapper tests on MacOS, see https://github.com/pytorch/pytorch/actions/runs/4545603517/jobs/8014327136#step:13:868
2023-03-28 22:07:33 +00:00
Elias Ellison
6854fd7189 Add Config to Skip Cpp Codegen, Enable in FBCode (#97204)
Differential Revision: [D44353662](https://our.internmc.facebook.com/intern/diff/D44353662)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97204
Approved by: https://github.com/ngimel, https://github.com/bertmaher, https://github.com/mikekgfb, https://github.com/cpuhrsch
2023-03-28 18:21:15 +00:00
Bin Bao
8710dc8d5a [inductor] Refactor cpp_wrapper to be an attribute of GraphLowering (#97709)
Summary: to prepare for further AOT Inductor changes

### 🤖 Generated by Copilot at 7dff885

This pull request adds support for AOT compilation and C++ wrapper code generation for inductor models. It modifies the `GraphLowering` class in `torch/_inductor/graph.py` and the `compile_fx` function in `torch/_inductor/compile_fx.py` to enable this feature.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97709
Approved by: https://github.com/jansel
2023-03-28 16:50:36 +00:00