Commit Graph

513 Commits

Author SHA1 Message Date
PyTorch MergeBot
d49abf039a Revert "update pointwise cat heuristics (#125772)"
This reverts commit d19d932183.

Reverted https://github.com/pytorch/pytorch/pull/125772 on behalf of https://github.com/izaitsevfb due to Fails numerical stability test for aps model, see D57215900 ([comment](https://github.com/pytorch/pytorch/pull/125772#issuecomment-2105932504))
2024-05-11 15:27:44 +00:00
Brian Hirsh
f25c7c9699 functionalize storage resizing, minimal ppFSDP traceable forward (#122434)
More details further down, but first a more high-level description of "how do we functionalize storage resizing"

Today, dynamo converts `param.untyped_storage().resize_(x)` calls that it sees from fsdp into a custom op, `ops.inductor.resize_storage_bytes_(x)`

So given this setup, there are 3 main cases that I think we want to handle:

(1) graph input starts with a real storage size, gets resized down to zero in the graph
(2) graph input starts with 0 storage size, gets resized up in the graph
(3) graph input starts with 0 storage size, gets resized up and used in some compute, then resized back down to 0

For case (1) we need to emit a `resize_storage_bytes_` at the end of the graph, similar to how we emit `copy_()` for data mutations.

For case (2), we need to emit a `resize_storage_bytes_` in the graph, and we **also** need to emit a `copy_()` (the input had its storage resized up, and filled in with data, which is we need to reflect as an input mutation)

For case (3), the net effect is that the input had no data on entry and exit of the function, so we don't need to emit any mutable ops in the end of the graph.

The main thing to call out is that: we need to write a functionalization rule for `resize_storage_byte_`, (`FunctionalTensorWrapper::storage_resize_()`) and this rule actually does very little. We would like to **not** emit any new ops in the graph (like say, a functional resize op). Instead, we should expect / rely on the fact that any resize up will be immediately followed by a `copy_()`/`foreach_copy_`/`out=` op, that will fill in the data of the tensor. So `FunctionalTensor` can temporarily live in a state where its data is invalid, until the `x.copy_(y)` "updates" its data with the new tensor.

So effectively, all that this rule does is:

(1) it stores metadata on the storage, indicating that the tensor was resized, as well as the updated storage size. We need this info in AOTAutograd, so it knows whether to emit a mutable resize_() op in the graph epilogue

(2) There is also a corner case: if we are resizing down to zero, but our tensor had **previously** had a zero size storage, then we update `value_` to point to the original value of the tensor. The reason this seems safe is because if we have a zero storage sized tensor `x`, and we resize it up, use it in some compute, resize it back down to zero, and use it somewhere, we would want the functional version of this code to use the original `x` after the second resize. For FSDP, this is important because we end up saving parameters (graph inputs) for backward, and we want to make sure that the thing we save (and the output to the forward graph) is the original, zero-storage-sized parameter, and not the "version 2" of the parameter after the first resize_()

I think a good order to look at changes in this PR would be:

(1) `test_aotdispatch.py` shows the 3 main cases I focused on as well as the expected functionalized graphs

(2) In `FunctionalStorageImpl.h/cpp`, I had to add a notion of "original base", and "original/curr_size". The first is so I can re-use the zero-size tensor after multiple resizes, and the second is so I can tell in AOTAutograd whether any resizes canceled each other out into a no-op

(3) FunctionalTensorWrapper.h/cpp has the new resize functionalizion rule + some extra utils

(4) `_functorch/_autograd`: the main changes in this folder were around adding the logic at trace-time to detect when we need to put a resize_() in the graph. I also have some assertions to check that any inputs that experience storage resizing will **always be in the graph** and not the opaque epilogue, and I also limited the resize_() mutation case so that you can only ever start with zero storage, or end with zero storage (you can't do e.g. `torch.ones(2).storage().resize_(3)`), and banned it on tensor subclasses

(5) `fake_tensor.py`/`meta_utils.py`: we now need to be able to fakeify tensors with zero storage, so I added a quick version of it in meta_utils.py. This also.. has ramifications for fake tensor caching that I need to fix (include the storage size on the cache key, maybe?)

------------------

This PR subsumes https://github.com/pytorch/pytorch/pull/120971.

This PR is enough to **almost** get a simple ppFSDP forward pass tracing with a functionalized resize_() properly. It also attempts to do the updated version from @jansel, where we don't have any notion of `resize_()` in the graph at all, post functionalization. It would probably be good to test it with @yf225 's FSDP changes, and see how many of the FX passes it allows us to remove. I think that in theory, it should allow us to remove all FX passes that affect the forward graph / partitioner, **except** the one that forces views to be recomputed in the backward (more details below).

There are a few things worth calling out:

(1) failed attempt at functionalizing `aten.copy_()`. I originally wanted to get a version takes these operations:
```
param.storage().resize_(all_gather_size)
param.copy_(all_gather_buffer)
out = aten.matmul(param, param)
```
and functionalizes them into:
```
out = aten.matmul(all_gather_buffer, all_gather_buffer)
```

This would involve getting functionalization to turn `x.copy_(y)` into a giant no-op that just returns `y`. Unfortunately, we can't actually do this in a reasonable way within functionalization (instead, there's a functional `aten.copy` in the graph - see the test case graph expecttest for details). Why? In order for that transformation to be safe, `x` and `y` need to have the same metadata. However, it's possible for `x` and `y` to be subclasses of different types. This is not something we can easily tell from within functionalization, and would be a layering violation. So for now I'm leaving it to downstream code to optimize away the `aten.copy` (this is already the case today, so I think inductor can handle this)

(2) The forward doesn't **actually** run successfully in this PR (see the `assertRaisesRegex` in the test). Why?

The final forward graph looks like this:
```
def forward(self, primals_1, primals_2):
    _foreach_copy = torch.ops.aten._foreach_copy.default([primals_1], [primals_2]);  primals_2 = None
    getitem = _foreach_copy[0];  _foreach_copy = None
    mm = torch.ops.aten.mm.default(getitem, getitem);  getitem = None
    t_1 = torch.ops.aten.t.default(primals_1);  primals_1 = None
    return [mm, t_1]
```

Where `primals_1` starts out as a secretly-zero-storage-size parameter, and gets resized up and back down within the forward (these are functionalized away).

Importantly, the matmul happy on the result of the `foreach_copy`, **but** the activation that we save for backward (`t_1`) is the result of transposing the **original parameter** (the zero-storage-size param). This is exactly the optimization in fsdp that allows us to have good peak memory usage.

The problem is that the min-cut partitioner decides to save `t_1` for backward. Running this code in eager breaks, because the kernel for `aten.permute(x)` is not happy when `x` has secretly-zero-sized-storage.

The real problem here is that in eager mode the `permute` kernel runs during the backward, after backward hooks have properly resized the saved activation. Here, we are running the transpose in the forward.

One option would be to turn off the checks in our view kernels and allow them to work on zero-storage-sized tensors, which feels pretty bad. Another option is to tweak the partitioner (or use one of Will's FX passes) to force the partitioner to not save views for backward, and allow the views to be recomputed in the backward. This seems kind of silly, but is also probably harmless.

(3) The backward is still broken. To be fair, this issue is pretty separable from "functionalizing storage resize calls", and can be fixed later (either by a real fix to our tracing infra, or via another hacky FX pass). More description of this problem is described at issue (8) of my PR description in https://github.com/pytorch/pytorch/pull/120971

(4) I only added support for "full graph" resizing: basically, the limited case where a param starts with zero storage size, and gets resized up and back down. I think we can add support for the graph break case, but I think we can keep that add-on separate from this PR unless we need it immediately. I also added asserts so we should fail loudly when we hit this case

(5) I have a change to FakeTensor creation when inputs have zero storage size that.. is probably ok. But I also removed FakeTensor caching on view ops, which I probably need to fix before I can land this PR

(6) I added a notion of "original_base" to `FunctionalStorageImpl`. More details are in the comments, but my rational for this was that we basically need it to ensure that autograd saves the **original**, zero-storage-sized param for backward, after resizing up and back down

(7) I had to update our eager kernels for `aten.copy` and `aten._foreach_copy`, to handle the case where the `self` argument has secretly-zero-storage. Inductor can probably generate correct code for this case, but we need these ops to work properly in this situation for the `aot_eager` backend to do the right thing

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122434
Approved by: https://github.com/jansel
2024-05-10 18:09:10 +00:00
eellison
d19d932183 update pointwise cat heuristics (#125772)
Fix for https://github.com/pytorch/pytorch/issues/122871. There are two cases where we emit pointwise cat:

- fusing into a pointwise use
- horizontally fusing copy_ kernels

The regression I looked into previously was due to being overly aggressive in the latter case. I've updated the logic there so that we only emit the horizontal fusion in the case that we would have to emit separate copy_ kernels anyway.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125772
Approved by: https://github.com/Chillee
2024-05-10 01:07:39 +00:00
leslie-fang-intel
d83ab88f81 [Inductor] [Quant] Enable lowering of quant per tensor and refactor quant pattern (#124041)
**Summary**
Per the discussion in https://github.com/pytorch/pytorch/pull/123444, the `decomposed quant/dequant` patterns changed after https://github.com/pytorch/pytorch/pull/123445, we can move the optimization of `decomposed quant/dequant` from inductor decomposition into lowering phase to avoid the changes. In this way, we can:

- Avoid the pattern matcher failure introduced in https://github.com/pytorch/pytorch/pull/123445
- Make the quantization pattern clearer in the pattern matcher phase, since the `quant/dequant` nodes have not been decomposed.

**Changes in this PR**

- Move optimization of `decomposed quant/dequant` from inductor decomposition into lowering phase.
- Corresponding changes in the quantization pattern matcher to ensure no bc-breaking.

**TestPlan**
```
python -u -m pytest -s -v test/inductor/test_mkldnn_pattern_matcher.py -k test_q
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124041
Approved by: https://github.com/peterbell10, https://github.com/jgong5
2024-05-09 08:40:44 +00:00
PyTorch MergeBot
ea3f625e32 Revert "[Inductor] [Quant] Enable lowering of quant per tensor and refactor quant pattern (#124041)"
This reverts commit 33e6791645.

Reverted https://github.com/pytorch/pytorch/pull/124041 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I think there is a land race with the change 33e6791645 ([comment](https://github.com/pytorch/pytorch/pull/124041#issuecomment-2101766558))
2024-05-09 01:34:19 +00:00
leslie-fang-intel
33e6791645 [Inductor] [Quant] Enable lowering of quant per tensor and refactor quant pattern (#124041)
**Summary**
Per the discussion in https://github.com/pytorch/pytorch/pull/123444, the `decomposed quant/dequant` patterns changed after https://github.com/pytorch/pytorch/pull/123445, we can move the optimization of `decomposed quant/dequant` from inductor decomposition into lowering phase to avoid the changes. In this way, we can:

- Avoid the pattern matcher failure introduced in https://github.com/pytorch/pytorch/pull/123445
- Make the quantization pattern clearer in the pattern matcher phase, since the `quant/dequant` nodes have not been decomposed.

**Changes in this PR**

- Move optimization of `decomposed quant/dequant` from inductor decomposition into lowering phase.
- Corresponding changes in the quantization pattern matcher to ensure no bc-breaking.

**TestPlan**
```
python -u -m pytest -s -v test/inductor/test_mkldnn_pattern_matcher.py -k test_q
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124041
Approved by: https://github.com/peterbell10, https://github.com/jgong5
2024-05-09 00:54:22 +00:00
Andrew M. James
445a0c01da Retry: Low mem max_pool2d_with_indices (#122832)
Based on #105687

The low memory path does not need to strictly return the int8 offsets
instead the offset to index computation can be separated from the
inner function of the max pool lowering. The partitioner can then choose
to move the offset to index computation into the backward pass.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122832
Approved by: https://github.com/peterbell10, https://github.com/eellison
2024-05-08 19:37:08 +00:00
angelayi
8be4c1bc2f [export] Add metadata for nodes insert_deferred_runtime_asserts (#125414)
Fixes [internal error](https://fb.workplace.com/groups/1075192433118967/permalink/1416709435633930/).

The issue is that the asserting nodes added in the `insert_deferred_runtime_assertion` pass do not contain metadata that the ExportedProgram requires the graph to have. One solution to fix this is to retrace the entire module, or another solution is to manually add back this metadata.

This diff implements the latter solution (manually add back the metadata) through hooking into fx.graph's `create_node` function, and adding export-specific metadata for every node that is created. The reason I did this is so that the `insert_deferred_runtime_assertion` does not have to know about what metadata export wants.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125414
Approved by: https://github.com/zhxchen17, https://github.com/BoyuanFeng
2024-05-07 23:15:21 +00:00
Peter Bell
24b64fc482 [HOP][inductor] Support pytrees as associative_scan input (#122137)
This allows `associative_scan` to take an arbitrary pytree of tensors,
which is flattened to their leaves before calling the `associative_scan`
higher order operator.

I also add support in inductor to generate code for scanning over sequences
of tensors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122137
Approved by: https://github.com/lezcano, https://github.com/Chillee
ghstack dependencies: #119430
2024-05-06 11:29:28 +00:00
Yifu Wang
58d8388ed3 Remove Inductor IRs for legacy functional collectives (#124992)
This PR completely removes the Inductor IR for legacy functional collectives:
- Removed the `CollectiveKernel` hiearchy and `Wait`, as well as the corresponding lowerings. These IRs are target (i.e. Python) specific and don't model node dependencies propoerly (e.g. they rely on `never_reuse_buffers` for correct behavior). They've been superceded by `ir._CollectiveKernel`.
- Removed `InPlaceHint` and the scheduler logic for handling it. `InPlaceHint` is a codegen-time buffer reuse mechanism controlled by the IR's codegen. It's a bit hacky and overlaps with the default buffer reuse mechanism. Removing it since it is only used by legacy functional collectives.
- Removed `OutputBuffer` and `MultiOutputNoSizeAssert` which are designed for and only used by legacy functional collectives.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124992
Approved by: https://github.com/Chillee, https://github.com/wanchaol
2024-05-05 19:49:58 +00:00
Edward Z. Yang
5503c29357 Introduce torch.utils._sympy.symbol (#125395)
This provides utilities for creating and querying properties on
sympy.Symbol.  I want to use this refactor to get a better handle on how
the 's' prefix is being used in Inductor.  To start, I only do
symbolic_shapes code because that's what I'm familiar with.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125395
Approved by: https://github.com/Skylion007
2024-05-03 21:24:23 +00:00
drisspg
25691558d9 Change templated_attention -> flex_attention (#125251)
# Summary

Change all the names

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125251
Approved by: https://github.com/Chillee, https://github.com/yanboliang
2024-05-01 01:08:48 +00:00
Kazuaki Ishizaki
9fec26e231 Fix typo under torch/_inductor directory (#119658)
This PR fixes typo in comments and msgs under `torch/_inductor` directory, and also changes the corresponding test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119658
Approved by: https://github.com/colesbury
2024-04-30 22:28:56 +00:00
Wanchao Liang
00df0d3e94 [dtensor] implement shard dim change with alltoall (#124872)
as titled, we implement a dedicated communication op to allow efficient
sharding dimension change using alltoall, to replace our previous
allgather + local chunk

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124872
Approved by: https://github.com/XilunWu, https://github.com/yifuwang
ghstack dependencies: #124871
2024-04-30 18:30:34 +00:00
PyTorch MergeBot
f1d1e3246f Revert "[dtensor] implement shard dim change with alltoall (#124872)"
This reverts commit 6b79469d24.

Reverted https://github.com/pytorch/pytorch/pull/124872 on behalf of https://github.com/clee2000 due to broke distributed/tensor/parallel/test_tp_examples.py::DistTensorParallelExampleTest::test_transformer_training_is_seq_parallel_True https://github.com/pytorch/pytorch/actions/runs/8882762411/job/24389191482 f7f018a0ed.  Bad TD ([comment](https://github.com/pytorch/pytorch/pull/124872#issuecomment-2083599445))
2024-04-29 20:26:16 +00:00
Wanchao Liang
6b79469d24 [dtensor] implement shard dim change with alltoall (#124872)
as titled, we implement a dedicated communication op to allow efficient
sharding dimension change using alltoall, to replace our previous
allgather + local chunk

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124872
Approved by: https://github.com/XilunWu, https://github.com/yifuwang
ghstack dependencies: #124871
2024-04-29 17:22:30 +00:00
Edward Z. Yang
e5e623af4b Codegen runtime asserts in Inductor (#124874)
This completely subsumes https://github.com/pytorch/pytorch/pull/120816

This makes use of the unbacked binding machinery to teach Inductor how to generate deferred runtime asserts directly. There is some back story about why I did it this way, let me explain.

Previously, our strategy for generating runtime asserts was that Dynamo would insert them into the FX graph after finishing tracing, and we would attempt to code generate them based on the FX graph. This is a good strategy for export, where we immediately export the graph. However, this strategy was afflicted by problems in eager, where we reuse the same ShapeEnv as before. In particular, on subsequent graph passes, we would immediately turn all of these assertions into noops, because when we evaluated their expressions, we would see that because we had a deferred runtime assert in the ShapeEnv, we know "oh, of course this expression is True" already. Oops!

So, with this PR, we take the attitude that as long as the ShapeEnv sticks around, the ShapeEnv's list of deferred runtime asserts is the source of truth, and we don't put anything in the graph. So we just need to decide when to actually generate asserts, and the place I picked was Inductor lowering, since we already have an AssertScalar buffer concept, and so I just need to insert them at this point. AssertScalar also uses raw sympy.Expr rather than SymInt/Bool, so it is easier to prevent unrestricted simplification at this point.

There are a few things jumbled together in this PR. I can split them if you want, but some of the changes are before I changed my strategy, but they're useful changes anyway.

**torch/_dynamo/output_graph.py** and **torch/_inductor/lowering.py** - Here, we stop putting deferred runtime asserts in the graph. I also have to make sure we don't DCE unused symbol arguments; we're going to get some goofy graph arguments this way, will be good to restore that optimization eventually. We also just disable codegen for `_assert_scalar`  entirely; we assume that ShapeEnv will be good enough to capture all of these.

**torch/_inductor/codegen/wrapper.py** and **torch/_inductor/ir.py** - Add a way to codegen sizevars without forcing simplification

**torch/_inductor/graph.py** - The main logic. Our strategy is to interpose in the same place we are testing that unbacked SymInts are properly showing up in lowered code. The logic is directly analogous to the logic in the existing insert deferred runtime asserts FX pass, but it's simpler because sympy expressions can be directly stored on inductor IR nodes.

**torch/fx/experimental/symbolic_shapes.py** - For extra safety, we have a way of freezing runtime asserts, so that if you try to add more we error. This prevents us from adding runtime asserts after we've done lowering. There's a funny interaction with backwards which there's a comment for in graph.py

**torch/fx/passes/runtime_assert.py** - This is not really needed in this PR, but I rewrote the runtime assert logic to use unbacked_bindings rather than inferring it by looking for unbacked SymInts. Now, keypaths are translated into FX node acessors. Unfortunately, I couldn't delete the old inference code, because you still need it to find backed SymInts from arguments (as this pass may be used on graphs which don't explicitly bind all their shape variables as argments). There are some new tests exercising this.

TODO: I think we need to generate asserts for replacements too. This is a preexisting problem that the old FX pass had too.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124874
Approved by: https://github.com/jansel
ghstack dependencies: #124864
2024-04-29 10:19:29 +00:00
Edward Z. Yang
b4597fffce Try to reuse old symbol name rather than new symbol name when renaming (#124782)
Previously, unbacked SymInts would gradually get larger and larger as we kept rebinding them.  Now, we do the replacement to preserve the old symbol.

Actually doing this is a bit tricky.  Here’s the order things happen when retracing data dependent:

1. Run fake tensor prop: allocate new unbacked SymInt
2. Run proxy tensor mode, calculate bindings and associate them with FX node
3. Run PropagateUnbackedSymInts, rename unbacked bindings to their old ones so they are consistent

So the problem is when we calculate bindings in step (2), we don't know
what the original names are yet, we only find out later at (3).  But by
the time (3) runs, we've already stuffed some new bindings in
meta["unbacked_bindings"] and we don't know how to update them!  To fix
this, I introduce resolve_unbacked_bindings which post facto applies any
of the renamings we discovered in (3).

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124782
Approved by: https://github.com/lezcano
ghstack dependencies: #124310, #124314, #124316, #124394, #124739
2024-04-25 14:02:42 +00:00
Edward Z. Yang
13ab24f192 Reimplement unbacked symbol bindings in Inductor (#124394)
This PR has a lot of "draw the rest of the fucking owl" energy. Here's how to break it down.

1. **torch/_inductor/graph.py** - We start by tightening unbacked symbol invariants. Specifically, as we lower FX nodes, we check whether or not every unbacked_binding recorded on the FX node meta, actually ends up getting bound (according to get_unbacked_symbol_defs) in all the buffers generated by the lowering. Hopefully this invariant is self evident. This leads to a lot of failures.
2. **torch/_inductor/ir.py** - Problem 1: There is softness in how Inductor computes defs of unbacked symbols in IR node. Previously, we tried to infer it by looking at the output sizes/strides/etc and see if new unbacked symbols popped up that we hadn't seen in the inputs. I don't know exactly what was buggy about the old code, but sometimes we would fail to notice an unbacked symbol had been bound, or rebind an unbacked symbol multiple times. Fortunately, thanks to the earlier PRs in our stack, we now have a nice list of unbacked symbol bindings from FX, so we now just store it directly on ExternKernel and use it directly to report defs. This has to be done twice: once for FallbackKernel (e.g., nonzero) and once for DynamicScalar (e.g., item) (see also **torch/_inductor/lowering.py**, **torch/_inductor/codegen/wrapper.py** and  **torch/_inductor/codegen/cpp_wrapper_cpu.py** for the lowering and codegen changes for item)
   * **process_kernel** - Sidequest! It turns out that Inductor lowering can reallocate unbacked symbols. This happens specifically when we repropagate fake tensors through the operator in `process_kernel`. This repropagation process is necessary because Inductor may have changed the strides of input tensors, and it must now recompute the strides so that it can continue to appropriately plan the rest of the lowering process. This is fine: we just make sure we do the rebind unbacked + compute_unbacked_bindings dance we've been doing previously in the PR stack. But instead of putting unbacked_bindings on a new FX node, they go straight into our unbacked_bindings on the Inductor IR node.
    * **codegen_unbacked_symbol_defs** - Sidequest! FallbackKernel lowering is done in two steps. First, you emit the FallbackKernel buffer. Then, you emit MultiOutput buffers which actually give access to the individual outputs of FallbackKernel, which may have been multi-output. There is a design decision here: does the FallbackKernel bind the unbacked symbols, or the MultiOutput buffer? Historically, we put the binding on MultiOutput buffer, because it's more convenient: the FallbackKernel buffer is fake, in fact, it doesn't even get a name in C++ codegen. But it's kind of inconsistent with the keypath model that we've been tracking unbacked bindings with: if you have a multi-output node, you'd expect a keypath like `[0].size()[0]` representing the first output's first dimension size. That suggests that it's the FallbackKernel that should define the things. So that was my first implementation. Unfortunately, the C++ codegen is too cursed and I could not understand how to make it work in that case. So now we just unsoundly assume you cannot have multi-output data dependent output, and do the codegen in MultiOutput. There are some comments explaining exactly what we are improperly assuming.
3. **_rename_unbacked_to** in **torch/fx/experimental/symbolic_shapes.py** - Previously, when we renamed unbacked symbols, we clobbered any facts we previously knew about them. So for example, if we had a replacement `u0 -> s0` but then we renamed u0 to u1, we would now setup the replacement `u0 -> u1`, clobbering the old replacement. This apparently didn't matter in earlier PRs in the stack, but with Inductor now on the ball, there were some tests that indicated this was a problem. The solution is easy: if u0 had a preexisting replacement, reapply it to u1. However...
    * **torch/_functorch/_aot_autograd/collect_metadata_analysis.py** - When we run forward analysis, this triggers fake tensor repropagation and fresh allocations. Previously, we just cleared out the pending symbols when finished the analysis. But with the change above, this would also migrate replacements to the new symbols... which are now dead. So now we explicitly suppress generation of these symbols with `ignore_fresh_unbacked_symbols` so that no rebinding happens at all.
    * **torch/_dynamo/eval_frame.py** - same deal; I just searched for all sites we called clear() on pending
4. The last step is fixing the long tail of extra problems that show up, now that unbacked_bindings are load bearing into Inductor
    * **torch/_dynamo/eval_frame.py** - Some of the exports are making copies of nodes without repropagating fake tensors, so in this case, it is important to also copy the `unbacked_bindings` (apparently this didn't matter before without the Inductor changes)
    * **torch/_export/pass_base.py** - I discover that this is doing fake tensor repropagation via a test suite failure. Do the same playbook as AOTAutograd: PropagateUnbackedSymInts too!  Actually, they also have implemented their own tracer as well, so do the same playbook as proxy_tensor: record unbacked_bindings on the newly traced nodes. UGH code duplication.
    * **torch/_subclasses/fake_tensor.py**, **torch/_subclasses/fake_impls.py** (with call site updates at  **torch/_functorch/_aot_autograd/traced_function_transforms.py** and **torch/fx/passes/fake_tensor_prop.py**) - What's this new epoch thing? I noticed that sometimes I would be retracing, call nonzero() on a fake tensor, and not allocate a new unbacked symbol. This is actually bad, because if I don't get a new unbacked symbol, I don't know there's a binding site, and `unbacked_bindings` is now missing a binding. The reason for this is memoization: if I reuse the exact same fake tensor on my retrace, it will already have an unbacked symint memoized on it and we will short circuit allocation. Well, that's no good. So I associate the memos with a fake tensor epoch, and every time you start a new fake tensor propagation from scratch, you bump the epoch so that I clear all the memos.
    * **torch/_inductor/scheduler.py** - I notice in unit tests that V.current_node is not always set when we call process_kernel. So I save it into the IR node and restore it when we are running `get_estimated_runtime`.
    * **torch/fx/experimental/symbolic_shapes.py** - A few things
      * **rebind_unbacked** (re **_tensor_version**). Ordinarily, when you have an unbacked SymInt, you persistently hvae it all the way to the end of the program. `_tensor_version` violates this: this generates an unbacked SymInt (for reasons I don't quite understand?) and then gets rid of it later. This triggered an assert violation. I think this op is kind of misusing unbacked SymInt, but I didn't know how to refactor it, so it gets a special case.
      * **rebind_unbacked** (re **Simplify SymBool binding**). Ugh, SymBool, what a pain in the butt. I have an assert that you can only rebind unbacked symbol to another unbacked symbol. This assert fails when a boolean is involved, because the result of running keypath on the result is not `u1`, it's `sympy.Piecewise(... sympy.Eq(u1, 1) ...)`. This is actually just `u1`, but Sympy doesn't know it because it doesn't know that `u1` value range is `[0, 1]`. So we manually implement the simplification needed to get the assert to pass.
      * **compute_unbacked_bindings** (re **This is pretty fragile**). There is a really funny disaster involving memoization and Inductor process kernel. Ordinarily when I retrace, if there was a memo hit in the old trace, there will be a memo hit in the new trace. However, Inductor process kernel breaks this, because it recreates fake tensor inputs to the operator call from scratch (since they might have different strides), and obviously these tensor inputs don't have the memo from the old one. I tried a little bit to try to manually transplant the memo to the new fake tensor but it seemed hopeless, so I just let the fresh symbol ride, allocating a new unbacked symbol. However, in one of our tests, we rely on knowing that the first nonzero call is equal to the second (memoized) nonzero call. The equality test looked pretty easy to discharge, so I just went ahead and added a deferred runtime assert to this effect and it worked.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124394
Approved by: https://github.com/jansel
ghstack dependencies: #124310, #124314, #124316
2024-04-25 02:08:59 +00:00
eellison
68225072e8 Match insignificant strides for sdpa inputs (#124859)
Fix for https://github.com/pytorch/pytorch/issues/124289.

There was a tensor which had a single, expanded element. inductor generated the strides as all 0, while sdpa expects a dense last dimension `t.stride(-1) == 1`. While these are equivalent, we still hit an error in the kernel. We could make fixes in sdpa, but matching the insignificant strides in inductor also works and I am less aware of the downstream sdpa kernel details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124859
Approved by: https://github.com/drisspg
ghstack dependencies: #124751
2024-04-24 23:44:23 +00:00
Peter Bell
7ecbbc40c3 [HOP][inductor] Add higher order associative scan operator (#119430)
Currently only supports single tensor scans, e.g. `cumsum`, `cumprod`, `logcumsumexp`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119430
Approved by: https://github.com/Chillee
2024-04-23 14:40:13 +00:00
Bin Bao
bb37910e30 [AOTI] Fixes ScatterFallback codegen (#124580)
Summary: For https://github.com/pytorch/pytorch/issues/123184. ScatterFallback currently relies on op name matching for codegen, which makes its cpp codegen fragile. Refactor to use op_overload and fix the relevant unit test failures.

Differential Revision: [D56417815](https://our.internmc.facebook.com/intern/diff/D56417815)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124580
Approved by: https://github.com/chenyang78
2024-04-22 20:47:26 +00:00
Aaron Gokaslan
29cc293725 [BE]: FURB142 - Remove set mutations. Use set update (#124551)
Uses set mutation methods instead of manually reimplementing (update, set_difference etc).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124551
Approved by: https://github.com/ezyang
2024-04-21 14:12:33 +00:00
chilli
e620c3e814 Optimized templated attention to use exp2 (#124356)
0.705 (vs. FA2) to 0.860 after this change.

<img width="1270" alt="image" src="https://github.com/pytorch/pytorch/assets/6355099/d58f57ba-e50e-44ea-8a8a-4f13b8650adf">

to

<img width="1277" alt="image" src="https://github.com/pytorch/pytorch/assets/6355099/f1945b67-0cfc-463c-a2f6-5812b90677fe">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124356
Approved by: https://github.com/drisspg
2024-04-19 01:58:19 +00:00
Pearu Peterson
43b4ac956e Add index_reduce decomposition (#122579)
As in the title.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122579
Approved by: https://github.com/peterbell10
ghstack dependencies: #123375
2024-04-18 01:30:47 +00:00
Xuehai Pan
93e249969b [BE] enable ruff rule RSE and remove useless parentheses in raise statements (#124261)
Remove useless parentheses in `raise` statements if the exception type is raised with no argument.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124261
Approved by: https://github.com/albanD
2024-04-17 19:29:34 +00:00
Aaron Gokaslan
1d6c5972c1 [BE]: Optimize min/max/sum comprehensions C419 (#123960)
Automatic fixes that replaces certain list comprehensions with generator ones where appropriate so that they are immediately consumed. This is preview functionality in ruff for rule C419 and it was automatically applied.

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123960
Approved by: https://github.com/malfet
2024-04-12 23:54:15 +00:00
chilli
7f6884f620 Added some extra repr to triton template buffers and added autotuned block configs to templated attention (#123813)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123813
Approved by: https://github.com/drisspg, https://github.com/shunting314
ghstack dependencies: #123768
2024-04-11 23:57:47 +00:00
Jiong Gong
0fd072bf90 [inductor] easy: move mkldnn lowerings to its own file (#123556)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123556
Approved by: https://github.com/peterbell10, https://github.com/jansel
2024-04-09 03:44:27 +00:00
angelayi
493478db4a [effects] Add inductor support for tokens (#122347)
Given the following code/dynamo graph:
```
class GraphModule(torch.nn.Module):
    def forward(self, L_x_ : torch.Tensor):
        l_x_ = L_x_
        _print = torch.ops.aten._print('moo')
        res = l_x_ + l_x_;  l_x_ = None
        _print_1 = torch.ops.aten._print('moo')
        return (res,)
```

AOTAutograd will trace the following program, threading tokens from the inputs, through the effectful operator calls (torch.ops.aten._print), and as an output:
```
class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[0]", arg1_1: "f32[2, 3]"):
        with_effects = torch._higher_order_ops.effects.with_effects(arg0_1, torch.ops.aten._print.default, 'moo');  arg0_1 = None
        getitem: "f32[0]" = with_effects[0];  with_effects = None
        add: "f32[2, 3]" = torch.ops.aten.add.Tensor(arg1_1, arg1_1);  arg1_1 = None
        with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops.aten._print.default, 'moo');  getitem = None
        getitem_2: "f32[0]" = with_effects_1[0];  with_effects_1 = None
        return (getitem_2, add)
```
However when we get to inductor, since we want the inductor generated code to not have any token inputs/outputs for better readability, we want to modify the aten graph by removing the tokens from inputs, and creating them through `torch.ops.aten._make_dep_token`, and sinking them through the `torch.ops.aten._sink_tokens` operators.
This has to be done *after* the partitioner, otherwise the partitioner will add the make_token/sink_token operators to the backwards graph.
```
class <lambda>(torch.nn.Module):
   def forward(self, arg1_1: "f32[2, 3]"):
       _make_dep_token_default: "f32[0]" = torch.ops.aten._make_dep_token.default()
       with_effects = torch._higher_order_ops.effects.with_effects(_make_dep_token_default, torch.ops.aten._print.default, 'moo');  _make_dep_token_default = None
       getitem: "f32[0]" = with_effects[0];  with_effects = None
       add: "f32[2, 3]" = torch.ops.aten.add.Tensor(arg1_1, arg1_1);  arg1_1 = None
       with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops.aten._print.default, 'moo');  getitem = None
       getitem_2: "f32[0]" = with_effects_1[0];  with_effects_1 = None
       _sink_tokens_default = torch.ops.aten._sink_tokens.default((getitem_2,));  getitem_2 = None
       return (add,)
```
When doing inductor lowering, we convert `with_effects` calls to an `EffectfulKernel`, which just a `FallbackKernel` but with a pointer to previous effectful operator's call. During scheduling, we will create a `StarDep` between the EffectfulKernel and its previous EffectfulKernel so that they don't get reordered. The inductor generated python code looks like:
```
def call(args):
    arg1_1, = args
    args.clear()
    assert_size_stride(arg1_1, (2, 3), (3, 1))
    # Source Nodes: [_print], Original ATen: []
    buf2 = aten._print.default('moo')
    # Source Nodes: [_print_1], Original ATen: []
    buf3 = aten._print.default('moo')
    buf4 = empty_strided_cpu((2, 3), (3, 1), torch.float32)
    cpp_fused_add_0(arg1_1, buf4)
    del arg1_1
    return (buf4, )
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122347
Approved by: https://github.com/bdhirsh
2024-04-09 03:22:32 +00:00
Gao Tianlin
77681facac [fix] inductor split lowering fails if item() is captured (#123032)
Fixes #122937

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123032
Approved by: https://github.com/jansel
2024-04-07 04:23:57 +00:00
drisspg
f4e2a226aa ScoreMod API (#121845)
# Summary

This PR adds a new higher-order_op: `templated_attention`.  This op is designed to extend the functionality of torch.nn.fucntional.scaled_dot_product_attention.  PyTorch has efficient pre-written fused-attention kernels. However, users want to modify how scores are computed (a substep inside attention) -- this traditionally requires the user to write their own attention kernel. One such modification to attention scores that is not currently supported by the top level SDPA op is:[ Attention with Linear Biases (ALiBi](https://arxiv.org/abs/2108.12409)).

This higher-order op will instead accept a callable( 'score_mod') function that is through torch.compile will be used to create an efficient attention kernel instantiation.

### Details

This HOP utilizes the existing fx and HOP infra to capture and convert the User `score-mod` function and convert to an FX graph module. Inductor then consumes this HOP that has a `ir.Subgraph` input. It will inline this lowered subgraph into a triton kernel which performs fused attention with the modification to the scores matrix inlined.

### API

The API for a score_mod function should be as follows:

```Python
def score_mod(score: torch.Tensor, batch: torch.Tensor, head: torch.Tensor, token_1: torch.Tensor, token_kv: torch.Tensor) -> torch.Tensor
```

This function receives five parameters:

- `score`: A scalar tensor representing the attention score, with the same data type and device as the query, key, and value tensors.
- `batch`, `head`, `seq_len_q`, `seq_len_kv`: Scalar tensors indicating the batch index, head index, query index, and key/value index, respectively, with torch.int data type and located on the same device as the score tensor.

Consider inputs query, key, value of shapes (2, 4, 16, 8), leading to an intermediate attention score matrix of shape (2, 4, 16, 16)

The score_mod function will be vectorized over each element of this matrix. For instance, modifying the score at the position corresponding to the 0th batch, 2nd head, between the 8th query and the 9th key element, would be invoked as:

```Python
score_mod(score[0,2,8,9], torch.tensor(0), torch.tensor(2), torch.tensor(8), torch.tensor(9))
```

### Examples
```Python
import torch
from torch.nn.attention.templated_attention import templated_attention

torch.manual_seed(0)

# Lets create some input tensors
# The input tensor has shape (batch_size, num_heads, seq_len, head_dim)
query = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32)
key = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32)
value = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32)

# Lets create a fun new score_modification! I will call this
# Checkerboard. It will reduce the score for neighboring tokens (1 step apart)
# in the sequence. And increase the score for tokens 2 steps apart. For everything
# else, the score will remain the same.

def checkerboard(score, batch, head, token_q, token_kv):
    score = torch.where(torch.abs(token_kv - token_q) == 1, score * 0.5, score)
    score = torch.where(torch.abs(token_kv - token_q) == 2, score * 2.0, score)
    return score

# Lets call templated_attention with this new score modification
output = templated_attention(query, key, value, score_mod=checkerboard)

compiled_templated_attention = torch.compile(templated_attention)
out_compiled = compiled_templated_attention(query, key, value, score_mod=checkerboard)

torch.testing.assert_close(output, out_compiled, atol=2e-2, rtol=2e-2)
```

### Future Work
- This PR is currently only forward only. However the triton kernel for backwards where score_modifications to not rely on external buffers has been explored here: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/flash/flash_attention.py
- Kernel Improvements; There are has been some larger updates to the fused attention implementation that Triton uses in its tutorials. The implementation of this kernel is based on a prior version and should be updated.
- We may want to unify this API under the top level SDPA API and leave that as a follow up once this is more stable
- Should we error on CPU?
- There are some issues with dynamic shapes
- Capturing of free variables and lifting to inputs to the subgraph is not working correctly today

### Performance
Comparisons generated by this benchmark:

| Type    |   Speedup |   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod     | dtype          |
|---------|-----------|--------------|-------------|-------------|-------------|------------|---------------|----------------|
| Average |     5.412 |              |             |             |             |            |               |                |
| Max     |     8.882 |           16 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |
| Min     |     3.645 |            8 |          16 |         512 |         512 |         64 | causal_mask   | torch.bfloat16 |
| Min     |     0.345 |            1 |          16 |        1024 |        1024 |         64 | pathological  | torch.bfloat16 |

For reference

| Configuration                                 | Forward Time (µ seconds) | Backend          | Speedup |
|-----------------------------------------------|--------------------------|------------------|---------|
| Fastest Config in Sweep (`8 16 4096 4096 64 relative_bias torch.bfloat16`) | 3608                   | Templated Attention                | 1.0  |
| Compiled SDPA (No Mask)                       | 9928                   | Math             | 2.75x   |
| Compiled SDPA (With Mask)                     | 11898                    | Math             | 3.29x   |
| Compiled SDPA (With Mask) | 8704                      | Memory Efficient Attention | 2.42x   |
| Compiled SDPA (No Mask) | 2548                     | FlashAttention2 | 0.706x   |

The speedups are measuring compiled templated attention speed versus different calls to torch.nn.functional.sdpa

<details>

<summary> FULL PERFORMANCE SWEEP NUMBERS </summary>

|   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod     | dtype          |   eager_time |   compiled_time |   speedup |
|--------------|-------------|-------------|-------------|------------|---------------|----------------|--------------|-----------------|-----------|
|            1 |          16 |         512 |         512 |         64 | causal_mask   | torch.bfloat16 |      331.444 |          67.221 |     4.931 |
|            1 |          16 |         512 |         512 |         64 | relative_bias | torch.bfloat16 |      335.300 |          64.187 |     5.224 |
|            1 |          16 |         512 |         512 |         64 | head_bias     | torch.bfloat16 |      352.039 |          63.806 |     5.517 |
|            1 |          16 |         512 |         512 |         64 | pathological  | torch.bfloat16 |      371.699 |         711.349 |     0.523 |
|            1 |          16 |        1024 |        1024 |         64 | causal_mask   | torch.bfloat16 |      333.488 |          86.455 |     3.857 |
|            1 |          16 |        1024 |        1024 |         64 | relative_bias | torch.bfloat16 |      322.363 |          82.469 |     3.909 |
|            1 |          16 |        1024 |        1024 |         64 | head_bias     | torch.bfloat16 |      349.967 |          82.233 |     4.256 |
|            1 |          16 |        1024 |        1024 |         64 | pathological  | torch.bfloat16 |      486.359 |        1412.453 |     0.344 |
|            1 |          16 |        4096 |        4096 |         64 | causal_mask   | torch.bfloat16 |     2794.597 |         551.188 |     5.070 |
|            1 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |     3965.150 |         513.101 |     7.728 |
|            1 |          16 |        4096 |        4096 |         64 | head_bias     | torch.bfloat16 |     2408.013 |         504.759 |     4.771 |
|            1 |          16 |        4096 |        4096 |         64 | pathological  | torch.bfloat16 |     6850.531 |       16733.675 |     0.409 |
|            8 |          16 |         512 |         512 |         64 | causal_mask   | torch.bfloat16 |      441.939 |         123.576 |     3.576 |
|            8 |          16 |         512 |         512 |         64 | relative_bias | torch.bfloat16 |      560.379 |         116.710 |     4.801 |
|            8 |          16 |         512 |         512 |         64 | head_bias     | torch.bfloat16 |      421.172 |         115.825 |     3.636 |
|            8 |          16 |         512 |         512 |         64 | pathological  | torch.bfloat16 |      994.492 |        2132.806 |     0.466 |
|            8 |          16 |        1024 |        1024 |         64 | causal_mask   | torch.bfloat16 |     1436.430 |         309.495 |     4.641 |
|            8 |          16 |        1024 |        1024 |         64 | relative_bias | torch.bfloat16 |     1892.216 |         290.186 |     6.521 |
|            8 |          16 |        1024 |        1024 |         64 | head_bias     | torch.bfloat16 |     1360.665 |         282.956 |     4.809 |
|            8 |          16 |        1024 |        1024 |         64 | pathological  | torch.bfloat16 |     3525.532 |        8359.702 |     0.422 |
|            8 |          16 |        4096 |        4096 |         64 | causal_mask   | torch.bfloat16 |    22026.839 |        3864.604 |     5.700 |
|            8 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |    31262.746 |        3609.551 |     8.661 |
|            8 |          16 |        4096 |        4096 |         64 | head_bias     | torch.bfloat16 |    20219.079 |        3480.402 |     5.809 |
|            8 |          16 |        4096 |        4096 |         64 | pathological  | torch.bfloat16 |    54654.647 |      116652.357 |     0.469 |
|           16 |          16 |         512 |         512 |         64 | causal_mask   | torch.bfloat16 |      820.606 |         188.683 |     4.349 |
|           16 |          16 |         512 |         512 |         64 | relative_bias | torch.bfloat16 |     1058.362 |         179.295 |     5.903 |
|           16 |          16 |         512 |         512 |         64 | head_bias     | torch.bfloat16 |      784.372 |         175.714 |     4.464 |
|           16 |          16 |         512 |         512 |         64 | pathological  | torch.bfloat16 |     1890.792 |        4212.877 |     0.449 |
|           16 |          16 |        1024 |        1024 |         64 | causal_mask   | torch.bfloat16 |     2781.830 |         557.017 |     4.994 |
|           16 |          16 |        1024 |        1024 |         64 | relative_bias | torch.bfloat16 |     3694.050 |         525.249 |     7.033 |
|           16 |          16 |        1024 |        1024 |         64 | head_bias     | torch.bfloat16 |     2634.164 |         507.613 |     5.189 |
|           16 |          16 |        1024 |        1024 |         64 | pathological  | torch.bfloat16 |     6959.917 |       15331.116 |     0.454 |
|           16 |          16 |        4096 |        4096 |         64 | causal_mask   | torch.bfloat16 |    43889.096 |        7582.018 |     5.789 |
|           16 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |    62784.293 |        7075.846 |     8.873 |
|           16 |          16 |        4096 |        4096 |         64 | head_bias     | torch.bfloat16 |    40308.606 |        6829.587 |     5.902 |
|           16 |          16 |        4096 |        4096 |         64 | pathological  | torch.bfloat16 |   108892.137 |      233090.953 |     0.467 |
</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121845
Approved by: https://github.com/Chillee, https://github.com/zou3519
2024-04-06 01:10:44 +00:00
xinan.lin
9743e3a19c [Inductor Intel GPU backend Upstream] Add Inductor Intel GPU backend. (#121895)
As the design in RFC https://github.com/pytorch/pytorch/issues/114856, this PR implemented Intel GPU Inductor backend by:
- Reuse WrapperCodegen and TritonScheduling for python wrapper and kernel code generation. And implenented device-specific code generation in XPUDeviceOpOverrides
- Reuse fx_pass, lowering, codecache, triton kernel auto-tuning, and compilation.

For the test case, this PR provided test/inductor/test_xpu_basic.py for basic inductor backend functionality testing.
We'll reuse all the existing Inductor test case in the next PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121895
Approved by: https://github.com/EikanWang, https://github.com/jansel, https://github.com/desertfire
2024-04-05 09:05:11 +00:00
PyTorch MergeBot
16cb5d48dd Revert "[inductor] Add explicit ops.fma and use it in softmax_backward (#122518)"
This reverts commit 05984e642b.

Reverted https://github.com/pytorch/pytorch/pull/122518 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it starts failing in trunk 05984e642b ([comment](https://github.com/pytorch/pytorch/pull/122518#issuecomment-2038631010))
2024-04-05 02:09:32 +00:00
Peter Bell
05984e642b [inductor] Add explicit ops.fma and use it in softmax_backward (#122518)
This allows us to generate an fma even when fp-fusion is disabled
in the compiler.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122518
Approved by: https://github.com/lezcano, https://github.com/Chillee
ghstack dependencies: #121924
2024-04-04 20:53:14 +00:00
eellison
d9cbd57dfe Make u/int8 cat inductor fallback cpu-only (#123278)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123278
Approved by: https://github.com/Chillee
2024-04-04 13:54:37 +00:00
ydwu4
a4035bea5c [while_loop] support closures (#123018)
We add an additional_inputs arguments to the HOP while_loop and rename the operands to carried_inputs based on offline discussion with @zou3519 . This allows us to support closures, parameters and buffers.

The alternative is to pass the lifted inputs directly to outputs of body_fn. But since we want the body_fn's output to not aliasing input. We'll need to copy the inputs and remove the copies later. This is a bit more work to do.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123018
Approved by: https://github.com/aakhundov
ghstack dependencies: #123217
2024-04-03 19:35:15 +00:00
Andrew M. James
bde1a93bc4 Add lowering for resize, decomp for resize_as. (#122317)
This has been split off from #121354 as the inplace version of these
methods prove to be rather tricky.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122317
Approved by: https://github.com/peterbell10, https://github.com/lezcano
2024-04-03 17:47:29 +00:00
Bin Bao
0ff6155eee [AOTI] Support module buffer mutation (#123164)
Summary: Fixes https://github.com/pytorch/pytorch/issues/120424. Because in a forward pass module buffers may be mutated, we need to allow that in AOTI. In addition, this will be a necessary step if we want to extend AOTI to training.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123164
Approved by: https://github.com/digantdesai, https://github.com/malfet, https://github.com/chenyang78, https://github.com/khabinov
2024-04-02 20:25:26 +00:00
Gao Tianlin
aaef246c74 remove log2 decomposition; add log2 lowering (#123112)
Same reason as `log10`. `log2` is a core aten op, we should not decompose it. As https://github.com/pytorch/pytorch/pull/110882 suggested, it often maps to a hardware intrinsic; Furthermore, decomposing it will negatively impact the numerical precision of the output.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123112
Approved by: https://github.com/peterbell10
2024-04-02 16:16:26 +00:00
Peter Bell
09c72eaa3f [inductor] Remove identity from ops.scan (#119727)
Currently scan has an `init` argument which must be the identity of the
combine function. This isn't strictly necessary if we are more careful about
keeping track of the first element and avoid combining it with anything.

This does additionally require that there are no active load masks, since we can't
do the `where_cond` any more. However, this shouldn't be possible anyway since
scans are always realized and only fused via the scheduler.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119727
Approved by: https://github.com/lezcano
2024-04-01 22:47:26 +00:00
eellison
8b49782ba6 [Inductor] require channels last output for channels last input for max_pool2d_backward (#122749)
Previously we fell back on max_pool2d_with_indices_backward for channels last.. Turns out this was slow because we were inferring a contiguous output for channels last inputs. Fixing the layout and lowering gives a 1-2% TIMM win. It will also unblock saving the indices as int8 kernel offsets since we now lower channels last output.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122749
Approved by: https://github.com/Chillee, https://github.com/amjames, https://github.com/jansel, https://github.com/shunting314
2024-04-01 22:02:00 +00:00
Peter Bell
03439d4c1c [inductor] Lower divide by constant as multiplication by reciprocal (#121924)
Fixes #101039

This lowers division by a constant value to be multipication by reciprocal.
The same optimization is applied in eager mode on CUDA:

0636c11811/aten/src/ATen/native/cuda/BinaryDivTrueKernel.cu (L36-L38)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121924
Approved by: https://github.com/lezcano
2024-04-01 14:37:37 +00:00
vfdev-5
b524a404e0 Fixed support for uint8 in upsample bicubic2d decomposition (#120411)
Superseeds https://github.com/pytorch/pytorch/pull/104248

Description:
- Fixed support for uint8 for upsample bicubic2d decomposition (on `main` results are wrong, so we can tolerate the slowdown)
- Added missing clamp(0, 1) for xscale and yscale
  - slowdown for f32 on cpu. PR on nodes fusion on CPU: https://github.com/pytorch/pytorch/pull/120077 can help for upsampling cases with align corners = true
  - the slowdown mainly due to the added clamp op and also partially reduced when using torch.stack in weights computation on cpu.
- Removed lowering implementation

Benchmarks:
```
[-------------------------------------------------------------------------------------------------------------------------------------------------------- Interpolate, cpu --------------------------------------------------------------------------------------------------------------------------------------------------------]
                                                                                                                                                   |  Eager (2.4.0a0+git0c61c20) PR  |  Compiled (2.4.0a0+git0c61c20) PR  |  Compiled (2.4.0a0+git069270d) Nightly  |  speed-up PR vs Nightly  |  Eager (2.4.0a0+git069270d) Nightly
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      Input (1, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (256, 256)       |        613.029 (+-1.590)        |         5477.608 (+-9.027)         |           3060.314 (+-12.368)           |     0.559 (+-0.000)      |          608.735 (+-6.336)
      Input (1, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (256, 256)      |        610.176 (+-1.428)        |        5718.503 (+-11.203)         |           3424.022 (+-12.836)           |     0.599 (+-0.000)      |          604.781 (+-6.229)
      Input (1, 3, 500, 400), torch.uint8, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (256, 256)           |        325.001 (+-0.840)        |        6183.029 (+-10.893)         |            3275.032 (+-7.625)           |     0.530 (+-0.000)      |          325.693 (+-1.067)
      Input (1, 3, 500, 400), torch.uint8, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (256, 256)          |        325.855 (+-1.108)        |        6391.394 (+-11.552)         |            3533.410 (+-7.666)           |     0.553 (+-0.000)      |          325.838 (+-1.457)
      Input (1, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (200, 300)     |       2521.533 (+-14.857)       |        5025.217 (+-13.415)         |            2814.304 (+-6.742)           |     0.560 (+-0.000)      |         2520.308 (+-10.796)
      Input (1, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (200, 300)    |       2531.204 (+-12.534)       |        5294.925 (+-11.994)         |            3147.590 (+-6.808)           |     0.594 (+-0.000)      |         2521.228 (+-11.732)
      Input (1, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (200, 300)         |        758.352 (+-10.362)       |        5639.912 (+-14.495)         |            3014.123 (+-8.799)           |     0.534 (+-0.000)      |          756.114 (+-4.792)
      Input (1, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (200, 300)        |        758.712 (+-5.781)        |         5927.541 (+-9.982)         |            3249.555 (+-7.226)           |     0.548 (+-0.000)      |          757.719 (+-5.653)
      Input (1, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (600, 700)       |       1524.469 (+-12.860)       |        34321.641 (+-80.310)        |           19373.714 (+-56.351)          |     0.564 (+-0.000)      |         1518.082 (+-49.653)
      Input (1, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (600, 700)      |       1521.746 (+-13.780)       |        35949.711 (+-81.010)        |           21782.366 (+-68.938)          |     0.606 (+-0.000)      |         1467.911 (+-15.901)
      Input (1, 3, 300, 400), torch.uint8, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (600, 700)           |        712.311 (+-5.361)        |        38826.510 (+-92.267)        |           20762.314 (+-59.303)          |     0.535 (+-0.000)      |          712.669 (+-4.673)
      Input (1, 3, 300, 400), torch.uint8, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (600, 700)          |        715.060 (+-4.757)        |        40269.353 (+-92.543)        |           22402.114 (+-81.574)          |     0.556 (+-0.000)      |          716.001 (+-8.945)

      Input (4, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (256, 256)       |       2331.889 (+-29.159)       |        21541.096 (+-72.346)        |           12181.194 (+-45.288)          |     0.565 (+-0.000)      |         2304.864 (+-21.351)
      Input (4, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (256, 256)      |       2333.697 (+-10.066)       |        22514.154 (+-57.798)        |           21709.449 (+-98.307)          |     0.964 (+-0.000)      |         2302.141 (+-13.041)
      Input (4, 3, 500, 400), torch.uint8, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (256, 256)           |        1198.768 (+-5.364)       |       37652.371 (+-101.644)        |           42740.413 (+-98.571)          |     1.135 (+-0.000)      |          1197.104 (+-7.225)
      Input (4, 3, 500, 400), torch.uint8, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (256, 256)          |        1196.851 (+-5.118)       |       39678.341 (+-173.750)        |           46807.738 (+-92.744)          |     1.180 (+-0.000)      |          1189.322 (+-5.681)
      Input (4, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (200, 300)     |       10020.978 (+-54.855)      |        19955.290 (+-71.891)        |           11420.521 (+-53.179)          |     0.572 (+-0.000)      |         9999.583 (+-61.230)
      Input (4, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (200, 300)    |       10066.441 (+-62.700)      |       21058.334 (+-183.414)        |           19986.577 (+-65.304)          |     0.949 (+-0.000)      |         10018.672 (+-59.188)
      Input (4, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (200, 300)         |       3171.135 (+-14.635)       |        19687.864 (+-54.320)        |           23313.699 (+-57.391)          |     1.184 (+-0.000)      |         3182.191 (+-17.686)
      Input (4, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (200, 300)        |       3181.314 (+-13.784)       |        20224.468 (+-50.827)        |          30541.963 (+-381.385)          |     1.510 (+-0.000)      |         3183.578 (+-16.203)
      Input (4, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (600, 700)       |       5879.450 (+-31.551)       |       136918.555 (+-480.320)       |          77723.568 (+-331.766)          |     0.568 (+-0.000)      |         5726.061 (+-87.517)
      Input (4, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (600, 700)      |       5882.869 (+-30.325)       |       143378.094 (+-513.842)       |         137244.074 (+-4827.730)         |     0.957 (+-0.000)      |         5727.679 (+-22.164)
      Input (4, 3, 300, 400), torch.uint8, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (600, 700)           |       2674.937 (+-45.003)       |      244829.360 (+-1930.579)       |         271283.073 (+-2243.245)         |     1.108 (+-0.000)      |         2676.054 (+-24.632)
      Input (4, 3, 300, 400), torch.uint8, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (600, 700)          |       2676.217 (+-16.601)       |      248658.668 (+-2904.952)       |         296514.520 (+-2983.281)         |     1.192 (+-0.000)      |         2682.844 (+-19.886)

      Input (1, 3, 500, 400), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (256, 256)     |        1768.437 (+-6.294)       |        2934.013 (+-28.870)         |            2520.649 (+-6.797)           |     0.859 (+-0.000)      |          1759.292 (+-5.097)
      Input (1, 3, 500, 400), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (256, 256)    |        1748.660 (+-5.550)       |         3271.104 (+-7.557)         |            2891.306 (+-7.632)           |     0.884 (+-0.000)      |          1746.341 (+-5.845)
      Input (1, 3, 500, 400), torch.float32, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (256, 256)         |        2813.150 (+-6.656)       |         3258.973 (+-7.543)         |            2766.286 (+-6.473)           |     0.849 (+-0.000)      |          2805.077 (+-7.611)
      Input (1, 3, 500, 400), torch.float32, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (256, 256)        |        2812.102 (+-8.211)       |         3568.780 (+-9.018)         |            3125.870 (+-7.324)           |     0.876 (+-0.000)      |          2834.178 (+-9.034)
      Input (1, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (200, 300)   |        1687.975 (+-9.527)       |         2752.085 (+-9.627)         |            2373.274 (+-7.888)           |     0.862 (+-0.000)      |          1698.782 (+-8.098)
      Input (1, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (200, 300)  |        1696.606 (+-8.678)       |        3056.317 (+-13.303)         |           2699.160 (+-10.638)           |     0.883 (+-0.000)      |         1684.942 (+-10.519)
      Input (1, 3, 1200, 1300), torch.float32, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (200, 300)       |        2613.491 (+-9.769)       |        3176.493 (+-13.366)         |            2730.193 (+-9.573)           |     0.859 (+-0.000)      |          2625.085 (+-9.943)
      Input (1, 3, 1200, 1300), torch.float32, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (200, 300)      |       2614.946 (+-34.129)       |        3465.398 (+-11.165)         |           3044.396 (+-11.447)           |     0.879 (+-0.000)      |          2627.355 (+-9.608)
      Input (1, 3, 300, 400), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (600, 700)     |       10784.549 (+-58.181)      |        18292.452 (+-59.344)        |           15909.922 (+-49.864)          |     0.870 (+-0.000)      |         10837.656 (+-51.947)
      Input (1, 3, 300, 400), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (600, 700)    |       10786.513 (+-52.308)      |        20449.038 (+-56.204)        |           18295.997 (+-54.522)          |     0.895 (+-0.000)      |         10843.751 (+-44.781)
      Input (1, 3, 300, 400), torch.float32, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (600, 700)         |       17532.699 (+-64.807)      |        20425.699 (+-80.271)        |           17517.040 (+-79.705)          |     0.858 (+-0.000)      |         17595.597 (+-61.870)
      Input (1, 3, 300, 400), torch.float32, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (600, 700)        |       17530.816 (+-55.131)      |        22450.080 (+-92.899)        |           19827.828 (+-77.649)          |     0.883 (+-0.000)      |         17615.934 (+-71.716)

      Input (4, 3, 500, 400), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (256, 256)     |       6875.484 (+-40.543)       |        11569.509 (+-62.462)        |          10053.350 (+-208.136)          |     0.869 (+-0.000)      |         6864.501 (+-46.747)
      Input (4, 3, 500, 400), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (256, 256)    |       6843.126 (+-44.498)       |        12915.236 (+-60.654)        |          25335.058 (+-382.640)          |     1.962 (+-0.000)      |         6899.002 (+-46.861)
      Input (4, 3, 500, 400), torch.float32, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (256, 256)         |       11103.418 (+-51.318)      |        28834.389 (+-78.395)        |          37405.463 (+-581.646)          |     1.297 (+-0.000)      |         11223.012 (+-60.709)
      Input (4, 3, 500, 400), torch.float32, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (256, 256)        |       11092.994 (+-70.835)      |       36597.023 (+-118.988)        |           45761.267 (+-85.051)          |     1.250 (+-0.000)      |         11104.014 (+-61.288)
      Input (4, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (200, 300)   |       7106.791 (+-63.666)       |        11191.071 (+-45.402)        |           9786.037 (+-75.781)           |     0.874 (+-0.000)      |         7129.419 (+-77.674)
      Input (4, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (200, 300)  |       7146.519 (+-28.376)       |        12443.571 (+-39.425)        |           20147.067 (+-74.771)          |     1.619 (+-0.000)      |         7179.622 (+-64.847)
      Input (4, 3, 1200, 1300), torch.float32, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (200, 300)       |       10533.849 (+-44.227)      |       34814.909 (+-138.127)        |          42803.001 (+-114.326)          |     1.229 (+-0.000)      |         10644.039 (+-59.681)
      Input (4, 3, 1200, 1300), torch.float32, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (200, 300)      |       10548.910 (+-44.221)      |       42876.940 (+-146.959)        |          49711.443 (+-139.276)          |     1.159 (+-0.000)      |         10652.375 (+-44.174)
      Input (4, 3, 300, 400), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (600, 700)     |      42814.521 (+-103.198)      |       73100.489 (+-435.262)        |          63587.659 (+-134.266)          |     0.870 (+-0.000)      |        43208.921 (+-195.287)
      Input (4, 3, 300, 400), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (600, 700)    |      42812.373 (+-103.870)      |       81769.160 (+-373.369)        |         175159.813 (+-2028.558)         |     2.142 (+-0.000)      |         43007.691 (+-96.358)
      Input (4, 3, 300, 400), torch.float32, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (600, 700)         |      69955.505 (+-373.373)      |      215248.616 (+-2040.775)       |         267511.246 (+-2094.161)         |     1.243 (+-0.000)      |        70382.679 (+-594.941)
      Input (4, 3, 300, 400), torch.float32, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (600, 700)        |      69852.157 (+-490.076)      |      242841.484 (+-19645.513)      |         317931.678 (+-2016.498)         |     1.309 (+-0.000)      |        70074.819 (+-352.919)

Times are in microseconds (us).

[-------------------------------------------------------------------------------------------------------------------------------------------------------- Interpolate, cuda ---------------------------------------------------------------------------------------------------------------------------------------------------------]
                                                                                                                                                     |  Eager (2.4.0a0+git0c61c20) PR  |  Compiled (2.4.0a0+git0c61c20) PR  |  Compiled (2.4.0a0+git069270d) Nightly  |  speed-up PR vs Nightly  |  Eager (2.4.0a0+git069270d) Nightly
1 threads: ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      Input (1, 3, 2345, 2456), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (1234, 1345)   |         97.727 (+-0.018)        |          97.765 (+-0.025)          |             97.773 (+-0.027)            |     1.000 (+-0.000)      |           97.905 (+-0.040)
      Input (1, 3, 2345, 2456), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (1234, 1345)  |         97.615 (+-0.066)        |          97.332 (+-0.032)          |             97.950 (+-0.026)            |     1.006 (+-0.000)      |           97.690 (+-0.062)
      Input (1, 3, 2345, 2456), torch.float32, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (1234, 1345)       |        100.635 (+-0.033)        |         125.883 (+-0.020)          |            102.499 (+-0.116)            |     0.814 (+-0.000)      |          101.103 (+-0.027)
      Input (1, 3, 2345, 2456), torch.float32, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (1234, 1345)      |        100.898 (+-0.036)        |         109.717 (+-0.336)          |            102.558 (+-0.120)            |     0.935 (+-0.000)      |          101.642 (+-0.105)
      Input (4, 3, 2345, 2456), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (1234, 1345)   |        462.853 (+-0.028)        |         382.475 (+-0.047)          |            382.472 (+-0.033)            |     1.000 (+-0.000)      |          462.188 (+-0.014)
      Input (4, 3, 2345, 2456), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (1234, 1345)  |        462.783 (+-0.021)        |         382.806 (+-0.037)          |            382.563 (+-0.043)            |     0.999 (+-0.000)      |          462.089 (+-0.028)
      Input (4, 3, 2345, 2456), torch.float32, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (1234, 1345)       |        466.721 (+-0.022)        |         384.438 (+-0.027)          |            384.886 (+-0.037)            |     1.001 (+-0.000)      |          467.014 (+-0.025)
      Input (4, 3, 2345, 2456), torch.float32, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (1234, 1345)      |        466.993 (+-0.032)        |         384.212 (+-0.009)          |            383.946 (+-0.029)            |     0.999 (+-0.000)      |          466.575 (+-0.020)
      Input (1, 3, 1234, 1345), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (2345, 2456)   |        190.070 (+-0.082)        |         209.353 (+-1.096)          |            202.870 (+-0.888)            |     0.969 (+-0.000)      |          189.371 (+-0.164)
      Input (1, 3, 1234, 1345), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (2345, 2456)  |        190.021 (+-0.018)        |         210.504 (+-0.456)          |            201.814 (+-0.770)            |     0.959 (+-0.000)      |          189.314 (+-0.036)
      Input (1, 3, 1234, 1345), torch.float32, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (2345, 2456)       |        188.860 (+-0.207)        |         336.635 (+-0.023)          |            252.026 (+-0.510)            |     0.749 (+-0.000)      |          188.860 (+-0.170)
      Input (1, 3, 1234, 1345), torch.float32, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (2345, 2456)      |        188.725 (+-0.214)        |         276.329 (+-0.563)          |            251.439 (+-0.524)            |     0.910 (+-0.000)      |          188.776 (+-0.189)
      Input (4, 3, 1234, 1345), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (2345, 2456)   |        781.879 (+-0.086)        |         836.389 (+-7.177)          |            816.483 (+-6.626)            |     0.976 (+-0.000)      |          781.362 (+-0.106)
      Input (4, 3, 1234, 1345), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (2345, 2456)  |        781.824 (+-0.099)        |         840.406 (+-7.111)          |            807.530 (+-6.514)            |     0.961 (+-0.000)      |          781.307 (+-0.129)
      Input (4, 3, 1234, 1345), torch.float32, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (2345, 2456)       |        769.290 (+-0.309)        |         675.498 (+-1.537)          |            688.171 (+-4.326)            |     1.019 (+-0.000)      |          769.830 (+-0.222)
      Input (4, 3, 1234, 1345), torch.float32, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (2345, 2456)      |        769.240 (+-0.179)        |         675.800 (+-1.113)          |            673.176 (+-1.740)            |     0.996 (+-0.000)      |          769.935 (+-0.171)

Times are in microseconds (us).

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120411
Approved by: https://github.com/lezcano
2024-03-29 13:15:25 +00:00
Hector Yuen
8a33a77fd1 Back out "Added a check in register_lowering to avoid decomposed ops (#117632)" (#122709)
Summary:
Original commit changeset: ebda663a196b

Original Phabricator Diff: D55271788

Test Plan: Some models are failing torch compile with this, retrying the tests

Reviewed By: colinchan15

Differential Revision: D55374457

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122709
Approved by: https://github.com/huydhn
2024-03-28 17:46:57 +00:00
Jason Ansel
07d037674f [inductor] Fix issue with randint + symbolic shapes (#122428)
Fixes #122405

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122428
Approved by: https://github.com/ezyang
2024-03-24 03:41:13 +00:00
angelayi
ed15370aab [aoti] Add handling of ir.Constants in promote_constants (#122419)
This issue popped up when enabling predispatch IR on the benchmarks (https://github.com/pytorch/pytorch/pull/122225)

On the following model:
```
class M(torch.nn.Module):
    def __init__(self, device):
        super().__init__()
        self.device = device

    def forward(self, x):
        t = torch.tensor(x.size(-1), device=self.device, dtype=torch.float)
        t = torch.sqrt(t * 3)
        return x * t
```

We get the following error:
```
======================================================================
ERROR: test_constant_abi_compatible_cuda (__main__.AOTInductorTestABICompatibleCuda)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/data/users/angelayi/pytorch/torch/testing/_internal/common_utils.py", line 2741, in wrapper
    method(*args, **kwargs)
  File "/data/users/angelayi/pytorch/test/inductor/test_torchinductor.py", line 9232, in new_test
    return value(self)
  File "/home/angelayi/.conda/envs/pytorch10/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/data/users/angelayi/pytorch/test/inductor/test_aot_inductor.py", line 922, in test_constant
    self.check_model(M(self.device), (torch.randn(5, 5, device=self.device),))
  File "/data/users/angelayi/pytorch/test/inductor/test_aot_inductor.py", line 91, in check_model
    actual = AOTIRunnerUtil.run(
  File "/data/users/angelayi/pytorch/test/inductor/test_aot_inductor_utils.py", line 102, in run
    so_path = AOTIRunnerUtil.compile(
  File "/data/users/angelayi/pytorch/test/inductor/test_aot_inductor_utils.py", line 40, in compile
    so_path = torch._inductor.aot_compile_ep(
  File "/data/users/angelayi/pytorch/torch/_inductor/__init__.py", line 150, in aot_compile_ep
    return compile_fx_aot(
  File "/data/users/angelayi/pytorch/torch/_inductor/compile_fx.py", line 1005, in compile_fx_aot
    compiled_lib_path = compile_fx(
  File "/home/angelayi/.conda/envs/pytorch10/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/data/users/angelayi/pytorch/torch/_inductor/compile_fx.py", line 1111, in compile_fx
    return compile_fx(
  File "/home/angelayi/.conda/envs/pytorch10/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/data/users/angelayi/pytorch/torch/_inductor/compile_fx.py", line 1145, in compile_fx
    return compile_fx(
  File "/home/angelayi/.conda/envs/pytorch10/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/data/users/angelayi/pytorch/torch/_inductor/compile_fx.py", line 1336, in compile_fx
    return inference_compiler(unlifted_gm, example_inputs_)
  File "/data/users/angelayi/pytorch/torch/_dynamo/utils.py", line 265, in time_wrapper
    r = func(*args, **kwargs)
  File "/data/users/angelayi/pytorch/torch/_inductor/compile_fx.py", line 1266, in fw_compiler_base
    return inner_compile(
  File "/home/angelayi/.conda/envs/pytorch10/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/data/users/angelayi/pytorch/torch/_dynamo/repro/after_aot.py", line 83, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
  File "/data/users/angelayi/pytorch/torch/_inductor/debug.py", line 304, in inner
    return fn(*args, **kwargs)
  File "/home/angelayi/.conda/envs/pytorch10/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/angelayi/.conda/envs/pytorch10/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/data/users/angelayi/pytorch/torch/_dynamo/utils.py", line 265, in time_wrapper
    r = func(*args, **kwargs)
  File "/data/users/angelayi/pytorch/torch/_inductor/compile_fx.py", line 447, in compile_fx_inner
    compiled_graph = fx_codegen_and_compile(
  File "/data/users/angelayi/pytorch/torch/_inductor/compile_fx.py", line 707, in fx_codegen_and_compile
    graph.run(*example_inputs)
  File "/data/users/angelayi/pytorch/torch/_dynamo/utils.py", line 265, in time_wrapper
    r = func(*args, **kwargs)
  File "/data/users/angelayi/pytorch/torch/_inductor/graph.py", line 612, in run
    return super().run(*args)
  File "/data/users/angelayi/pytorch/torch/fx/interpreter.py", line 145, in run
    self.env[node] = self.run_node(node)
  File "/data/users/angelayi/pytorch/torch/_inductor/graph.py", line 957, in run_node
    result = super().run_node(n)
  File "/data/users/angelayi/pytorch/torch/fx/interpreter.py", line 202, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/data/users/angelayi/pytorch/torch/_inductor/graph.py", line 819, in call_function
    raise LoweringException(e, target, args, kwargs).with_traceback(
  File "/data/users/angelayi/pytorch/torch/_inductor/graph.py", line 816, in call_function
    out = lowerings[target](*args, **kwargs)
  File "/data/users/angelayi/pytorch/torch/_inductor/lowering.py", line 298, in wrapped
    out = decomp_fn(*args, **kwargs)
  File "/data/users/angelayi/pytorch/torch/_inductor/lowering.py", line 5340, in mul
    return make_pointwise(fn)(a, b)
  File "/data/users/angelayi/pytorch/torch/_inductor/lowering.py", line 409, in inner
    inputs = promote_constants(inputs, override_return_dtype)
  File "/data/users/angelayi/pytorch/torch/_inductor/lowering.py", line 373, in promote_constants
    ex = next(x for x in inputs if isinstance(x, (TensorBox, ExpandView)))
torch._inductor.exc.LoweringException: StopIteration:
  target: aten.mul.Tensor
  args[0]: Constant(value=5.0, dtype=torch.float32, device=device(type='cuda', index=0))
  args[1]: 3
```

So I added an additional casing in `promote_constants` to handle the ir.Constants and now it works! Although please let me know if this is the wrong approach. Here's a paste of the full run with the inductor logs: P1198927007

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122419
Approved by: https://github.com/eellison, https://github.com/desertfire, https://github.com/chenyang78
2024-03-22 18:39:36 +00:00
vfdev-5
90a13c3c5b Added a check in register_lowering to avoid decomposed ops (#117632)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117632
Approved by: https://github.com/lezcano
2024-03-22 16:38:31 +00:00
chilli
d34514f8db Renamed mutationlayout/aliasedlayout (#122474)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122474
Approved by: https://github.com/jansel
ghstack dependencies: #121624
2024-03-22 08:32:14 +00:00
Adnan Akhundov
e419011471 [inductor] Add torch.while_loop support to JIT Inductor (#122069)
Summary: `torch.while_loop` HOP support is added to JIT Inductor. The test coverage is limited due to the functionality constraints of the upstream `torch.while_loop` op in Dynamo / Export. When those are lifted, we'll add more tests (see TODO-s in the test file).

AOT Inductor support will be added in a follow-up PR.

Test Plan:

```
$ python test/inductor/test_control_flow.py
...
----------------------------------------------------------------------
Ran 38 tests in 159.387s

OK
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122069
Approved by: https://github.com/jansel, https://github.com/eellison
2024-03-22 02:45:27 +00:00