Commit Graph

267 Commits

Author SHA1 Message Date
Shunting Zhang
bbded928b3 [inductor] make inductor work with new triton compile interface (#115878)
Two recent triton PRs (https://github.com/openai/triton/pull/2701, https://github.com/openai/triton/pull/2756) changed the interface of triton.compile; this PR adds the necessary changes on the inductor side to work with both the old and new compile APIs.

This PR also includes some simplification of the compilation calls in the subprocess and the main process:
- Previously we passed warm_cache_only=True if the compilation happened in a subprocess, but triton never uses that argument in the currently used pin, so it is removed.
- Previously we only passed compute_capability if the compilation happened in a subprocess. This PR changes that to always pass compute_capability to triton.compile, whether compilation happens in the main process or a subprocess.

Updated:
There are more interface changes on the triton side, e.g.:
- tl.math.{min, max} now require a propagate_nan argument.
- JITFunction.run now requires a warmup argument. This affects the benchmarking phase of matmul max-autotune. On the other hand, JITFunction.run now forbids the stream argument; simply not passing it when benchmarking matmul triton kernels works for both the old and new versions of triton.
- The triton Autotuner renamed its attributes from 'warmup' to 'num_warmup' and from 'rep' to 'num_rep'. This caused dynamo to fail on triton Autotuner objects, since dynamo's TritonKernelVariable makes assumptions about attribute names; it is exercised by test cases where a model calls the triton Autotuner directly (see the compatibility sketch below).
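One way to stay compatible with both attribute spellings is a getattr fallback; this is an illustrative sketch (the fallback defaults are assumptions), not the code from this PR:
```python
def get_autotuner_warmup_and_rep(autotuner):
    # Newer Triton renamed 'warmup' -> 'num_warmup' and 'rep' -> 'num_rep'.
    warmup = getattr(autotuner, "num_warmup", None)
    if warmup is None:
        warmup = getattr(autotuner, "warmup", 25)   # old attribute name; assumed default
    rep = getattr(autotuner, "num_rep", None)
    if rep is None:
        rep = getattr(autotuner, "rep", 100)        # old attribute name; assumed default
    return warmup, rep
```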

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115878
Approved by: https://github.com/jansel
2023-12-21 00:03:38 +00:00
PyTorch MergeBot
c215e59bf2 Revert "[inductor] Avoid bool being upcast to int (#109913)"
This reverts commit 92998693a9.

Reverted https://github.com/pytorch/pytorch/pull/109913 on behalf of https://github.com/jeanschmidt due to causing a performance regression in relevant metrics. @malfet I believe you are the correct person to help identify and fix the issues. For more details, check the internal OPS count for ads metrics in the related internal diff ([comment](https://github.com/pytorch/pytorch/pull/109913#issuecomment-1864397407))
2023-12-20 12:33:50 +00:00
Philip Meier
505a9e4854 add support for dynamic shapes in round (#115259)
Fixes #114310 and supersedes #114748.

There are two reasons why we have quite a few special cases for `round`:

1. `round` is actually two ops. With `ndigits=None` (the default), `round` always returns an integer; when `ndigits` is an integer, the return type is a float (see the snippet below).
2. Although `round` takes two arguments, it is a unary function with a parameter rather than a binary one.
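A plain-Python illustration of reason 1, using the standard-library `round` (not code from this PR):
```python
assert round(2.5) == 2               # ndigits=None -> returns an int (ties round to even)
assert round(3.14159, 2) == 3.14     # ndigits given -> returns a float
assert isinstance(round(2.5), int)
assert isinstance(round(3.14159, 2), float)
```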

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115259
Approved by: https://github.com/peterbell10, https://github.com/lezcano
2023-12-19 15:45:50 +00:00
Peter Bell
92998693a9 [inductor] Avoid bool being upcast to int (#109913)
Currently the inductor code for `x.any(-1)` does this strange dance:
```python
tmp0 = tl.load(in_ptr0 + (r1 + (128*x0)), rmask & xmask)
tmp1 = tmp0.to(tl.int64)
tmp2 = (tmp1 != 0)
```

This happens because `register_lowering` is doing type promotion with the
dimension argument, and so promotes to `int64` which we then cast back to bool.
A better fix would be to fix `register_lowering` but for now I just remove
the unnecessary type promotion from `aten.any`.

In the current code we also see:
```python
     tmp5 = tl.where(rmask & xmask, tmp3, 0)
```
which promotes the boolean value to int since `0` is an int32 in triton.
This fixes it to generate a boolean constant instead.

Finally there is also a triton bug where the `tl.load` itself upcasts to
`tl.int8`. I fix this by adding an explicit cast to `tl.int1`. The final
kernel code looks like:

```python
tmp0 = tl.load(in_ptr0 + (r1 + (128*x0)), rmask & xmask).to(tl.int1)
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
tmp3 = tl.full([1, 1], 0, tl.int1)
tmp4 = tl.where(rmask & xmask, tmp1, tmp3)
tmp5 = triton_helpers.any(tmp4, 1)[:, None]

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109913
Approved by: https://github.com/lezcano
2023-12-19 14:16:10 +00:00
vfdev-5
c7ae2c170f [inductor] Added non-integer expr support for floordiv in triton codegen (#115751)
Description:
- Added non-integer expr support for floordiv in triton codegen
- Added a test
  - the cpp test is skipped because it is currently failing; https://github.com/pytorch/pytorch/pull/115647 may fix it

This PR fixes a compilation error with the following code:
```python
import torch

def func(x, a):
    n = (a * 1.234) // 8.234
    y = x + n
    return y

cfunc = torch.compile(func, dynamic=True, fullgraph=True)

device = "cuda"
x = torch.tensor(0, dtype=torch.float32, device=device)
a = 33

out = cfunc(x, a)
expected = func(x, a)
torch.testing.assert_close(out, expected)
```
Error message on Nightly:
```
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
    raise self._exception
torch._dynamo.exc.BackendCompilerFailed: backend='compile_fx_wrapper' raised:
CompilationError: at 7:38:def triton_(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = ((1.23400000000000*ks0) // 8.23400000000000)
                                      ^
AssertionError()
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115751
Approved by: https://github.com/peterbell10
2023-12-13 23:17:42 +00:00
Yang Chen
1392843e7b [inductor] make sure bitcast input and target type have the same bitwidth (#115619)
This PR fixes #104791.

bitcast requires the source and target types to have the same bitwidth.
Because the input tensor's dtype could be promoted, e.g. from float16 to
float, in such cases we have to cast the tensor back to its original source
dtype before invoking bitcast. After that, we also need to convert the
bit-casted tensor back to float to make sure we keep using higher-precision
values for the rest of the computation.
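A minimal sketch of that cast-then-bitcast pattern in plain PyTorch (`Tensor.view(dtype)` bitcasts when the element sizes match); this is illustrative, not the inductor lowering itself:
```python
import torch

x = torch.randn(4, dtype=torch.float16)
promoted = x.to(torch.float32)                         # dtype was promoted upstream, e.g. float16 -> float32
same_width = promoted.to(torch.float16)                # cast back to the original 16-bit dtype first
bits = same_width.view(torch.int16)                    # bitcast now pairs two 16-bit types
restored = bits.view(torch.float16).to(torch.float32)  # keep the rest of the computation in higher precision
```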

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115619
Approved by: https://github.com/jansel, https://github.com/eellison
2023-12-13 00:53:04 +00:00
Peter Bell
40dc0580a6 [inductor] De-duplicate triton helper functions (#115546)
Previously, if two calls to cumsum were generated in the same triton kernel
we would emit identical helper functions with different names. Now we
recognize identical functions and define each one only once. To do this, the
helper's name is chosen only after codegen.
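A rough sketch of the de-duplication idea (illustrative, not the actual inductor implementation): helper bodies are keyed before a name is assigned.
```python
class HelperFunctions:
    def __init__(self):
        self.finalized_helpers = []   # helper source strings with names filled in
        self._templates_seen = {}     # body template -> assigned name

    def add(self, template: str) -> str:
        # `template` is the helper source with a `{name}` placeholder, e.g.
        # "def {name}(arg0, arg1):\n    return arg0 + arg1"
        existing = self._templates_seen.get(template)
        if existing is not None:
            return existing           # identical body already defined once
        name = f"_triton_helper_fn{len(self.finalized_helpers)}"
        self._templates_seen[template] = name
        self.finalized_helpers.append(template.format(name=name))
        return name
```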

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115546
Approved by: https://github.com/lezcano
ghstack dependencies: #109132
2023-12-12 16:30:50 +00:00
Peter Bell
02196c21ac [inductor] Parameterize ir.Scan on combine_fn (#109132)
This replaces `tl.cumsum` and `tl.cumprod` with calls to `tl.associative_scan`
where the combine function is generated from inductor IR.

So before we had:
```python
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 20
    rnumel = 30
    RBLOCK: tl.constexpr = 32
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (30*x0)), rmask & xmask, other=0).to(tl.float32)
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    tmp2 = tl.where(rmask & xmask, tmp1, 0)
    tmp3 = tl.cumsum(tmp2, 1)
    tl.store(out_ptr0 + (r1 + (30*x0)), tmp3, rmask & xmask)
```

Now we have:
```python
@triton.jit
def _triton_helper_fn0(arg0, arg1):
    tmp0 = arg0 + arg1
    return tmp0

@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 20
    rnumel = 30
    RBLOCK: tl.constexpr = 32
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (30*x0)), rmask & xmask, other=0).to(tl.float32)
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    tmp2 = tl.where(rmask & xmask, tmp1, 0)
    tmp3 = tl.associative_scan(tmp2, 1, _triton_helper_fn0)
    tl.store(out_ptr0 + (r1 + (30*x0)), tmp3, rmask & xmask)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109132
Approved by: https://github.com/lezcano
2023-12-12 16:30:50 +00:00
Peter Bell
7aac689b19 [inductor] Add ir.Scan and lower aten.cumsum on CUDA (#106581)
This adds the `ir.Scan` node (currently only supported on CUDA) which re-uses the existing reduction kernel machinery to support different kinds of non-pointwise ops. Just like reductions it supports prologue and epilogue fusions and has both persistent and non-persistent kernel generation.

Currently this doesn't support the equivalent of `Reduction.create_multilayer` and will instead fall back to eager in those cases. This is because splitting into multiple kernel invocations ends up being far slower than cub's single kernel strategy which matches the performance of a copy kernel.
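A minimal usage sketch (assuming a CUDA build with inductor available); a cumsum like this is lowered through the new ir.Scan node:
```python
import torch

@torch.compile
def prefix_sum(x):
    return x.cumsum(dim=-1)

x = torch.randn(20, 30, device="cuda")
torch.testing.assert_close(prefix_sum(x), x.cumsum(dim=-1))
```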

Fixes https://github.com/pytorch/pytorch/issues/93631

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106581
Approved by: https://github.com/lezcano, https://github.com/atalman
2023-12-05 23:31:49 +00:00
Jez Ng
c808a84680 Better logging for "cannot fuse" reasons (#115003)
This was invaluable when I was debugging #114917. Without the node names
in the log message, it was difficult to make sense of them.

However, I did not want to bloat the number of LOC with this change.
Thus, instead of calling `debug()` directly with the node arguments, I
made a new callable class WhyNoFuse to partially apply the node
arguments at the top of each fusion-checking method. WhyNoFuse generates
the logging string only when its `__str__` method gets called, so there
is minimal overhead when logging is disabled.
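A sketch of that lazy-logging pattern; the class name and idea follow the PR description, but the body here is illustrative rather than the exact inductor code:
```python
import logging

fusion_log = logging.getLogger("torch._inductor.fusion")

class WhyNoFuse:
    def __init__(self, node1, node2):
        self.node1, self.node2 = node1, node2
        self.reason, self.args = "", ()

    def __call__(self, reason, *args):
        self.reason, self.args = reason, args
        fusion_log.debug("cannot fuse %s with %s: %s", self.node1, self.node2, self)

    def __str__(self):
        # Only formatted when a handler actually emits the record, so disabled
        # logging pays almost nothing.
        return self.reason % self.args

# Usage inside a fusion-checking method (hypothetical):
#   why = WhyNoFuse(node1, node2)
#   ...
#   why("device mismatch (%s vs %s)", dev1, dev2)
```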

I also removed the various logging 'tags' like "vert:1" / "triton:1" --
the log messages themselves are unique enough that the user can identify
them without the tag.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115003
Approved by: https://github.com/Skylion007
2023-12-03 04:48:43 +00:00
Shunting Zhang
68a8d74f3f [inductor] benchmark epilogue fused matmul template (#114809)
We want to be able to benchmark epilogue-fused triton matmul kernels for a couple of reasons:
1. @eellison found that certain TB models (resnet50, resnet152, moco) sometimes fail in max-autotune mode on the dashboard. The issue is quite hard to repro due to flakiness and only gets triggered when certain triton configs for certain epilogue-fused kernels get picked (disabling epilogue fusion bypasses the issue). It would be nice to have a runnable script that directly runs that kernel to ease further debugging.
2. This is a necessary piece for doing benchmark fusion for triton matmul kernels. cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler for this

Example runnable kernel script: https://gist.github.com/shunting314/00bdbc1b6b46bfa73d1389d8f40cd669

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114809
Approved by: https://github.com/eellison
2023-12-01 21:05:01 +00:00
Jez Ng
47e6cc4d22 Remove yet more type-ignores in dynamo/inductor (#114684)
Probably the last big batch for a while

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114684
Approved by: https://github.com/Skylion007
2023-11-28 22:09:38 +00:00
chundian
74e10f0f60 [inductor] Fix torch.split bug on unbacked symint (#113406)
torch.split(x, l) fails when l contains unbacked symints.

E.g. l = y.tolist() makes l unbacked, because l depends on the data of y.
The downstream call `SliceView.create()` evaluates the shape even if the
input shape is an unbacked symint, which triggers the bug (a rough repro
sketch follows below).
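A rough repro sketch of that failure mode (illustrative; the real test lives in test_unbacked_symints.py, and capture_scalar_outputs is an assumption about the required config):
```python
import torch
torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(fullgraph=True)
def f(x, y):
    sizes = y.tolist()             # data-dependent -> unbacked symints
    return torch.split(x, sizes)   # previously hit the SliceView.create() evaluation bug

f(torch.randn(10), torch.tensor([3, 7]))
```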

Test Plan:
python test/inductor/test_unbacked_symints.py -k test_split_with_sizes

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113406
Approved by: https://github.com/aakhundov, https://github.com/ezyang
2023-11-28 20:45:13 +00:00
Jez Ng
71b742b42c [inductor] Remove more type: ignore comments (#114162)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114162
Approved by: https://github.com/Skylion007, https://github.com/eellison
2023-11-28 06:45:55 +00:00
PyTorch MergeBot
ccb1de3595 Revert "[inductor] Fix torch.split bug on unbacked symint (#113406)"
This reverts commit cd7d6938c1.

Reverted https://github.com/pytorch/pytorch/pull/113406 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/113406#issuecomment-1827727411))
2023-11-27 12:20:52 +00:00
chundian
cd7d6938c1 [inductor] Fix torch.split bug on unbacked symint (#113406)
torch.split(x, l) fails when l contains unbacked symints.

E.g. l = y.tolist() makes l unbacked, because l depends on the data of y.
The downstream call `SliceView.create()` evaluates the shape even if the
input shape is an unbacked symint, which triggers the bug.

Test Plan:
python test/inductor/test_unbacked_symints.py -k test_split_with_sizes

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113406
Approved by: https://github.com/aakhundov, https://github.com/ezyang
2023-11-24 07:21:00 +00:00
Xu Han
0f887a6d1a limit fused kernel num args. (#113131)
Fixes #97361

When a fused kernel has more than 1024 parameters, ctypes throws an error.
Limiting the number of args is a mechanism to protect stack memory: C++ passes args via the stack, and stack memory has a size limit.

Code changes:

1. The cpp backend checks the fused nodes' arg count; when it reaches the limit, the backend flushes its status to ready.
2. The scheduler checks the `ready_to_flush` API and helps the backend flush codegen.
3. Add a `ready_to_flush` API to `BaseScheduling`; the Triton backend returns False since it does not support flushing yet (see the sketch below).
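A minimal sketch of the flush hook; `ready_to_flush` comes from the PR, while the bookkeeping and the threshold below are illustrative assumptions:
```python
class BaseScheduling:
    def ready_to_flush(self) -> bool:
        # Backends that cannot flush mid-schedule (e.g. Triton, for now) return False.
        return False

class CppScheduling(BaseScheduling):
    MAX_FUSED_KERNEL_ARGS_NUM = 500  # assumed safety margin under the ~1024-arg ctypes limit

    def __init__(self):
        self._pending_arg_count = 0

    def note_kernel_args(self, num_args: int) -> None:
        # Hypothetical bookkeeping updated as fused nodes accumulate arguments.
        self._pending_arg_count += num_args

    def ready_to_flush(self) -> bool:
        return self._pending_arg_count > self.MAX_FUSED_KERNEL_ARGS_NUM
```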

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113131
Approved by: https://github.com/jgong5, https://github.com/mlazos
2023-11-22 18:05:33 +00:00
Jez Ng
87925789ae Make V.graph properly typed (#114025)
Previously it lacked a type hint and so was treated as an Any type. This
resulted in a lot of untyped code downstream, as V.graph is referenced in
many places in inductor code. I've now typed it properly as GraphLowering
and fixed the numerous type errors this surfaced.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114025
Approved by: https://github.com/eellison
ghstack dependencies: #114013
2023-11-21 02:14:29 +00:00
PyTorch MergeBot
ff7c06a01b Revert "limit fused kernel num args. (#113131)"
This reverts commit 7b442c2b0a.

Reverted https://github.com/pytorch/pytorch/pull/113131 on behalf of https://github.com/albanD due to Breaks lint on trunk ([comment](https://github.com/pytorch/pytorch/pull/113131#issuecomment-1817548349))
2023-11-18 16:14:08 +00:00
Han, Xu
7b442c2b0a limit fused kernel num args. (#113131)
Fixes #97361

When a fused kernel has more than 1024 parameters, ctypes throws an error.
Limiting the number of args is a mechanism to protect stack memory: C++ passes args via the stack, and stack memory has a size limit.

Code changes:

1. The cpp backend checks the fused nodes' arg count; when it reaches the limit, the backend flushes its status to ready.
2. The scheduler checks the `ready_to_flush` API and helps the backend flush codegen.
3. Add a `ready_to_flush` API to `BaseScheduling`; the Triton backend returns False since it does not support flushing yet.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113131
Approved by: https://github.com/jgong5, https://github.com/mlazos
2023-11-18 03:55:52 +00:00
Jez Ng
4667e20b3f Delete a bunch of type-ignores (#113990)
* Replaced `ignore[import]` by mypy config file entries
* Removed a bunch of ignores around previously-fixed attr-defined /
  call-arg issues
* Fixed some invalid / undefined types; added a few more type-ignores to
  squelch the downstream errors this exposed

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113990
Approved by: https://github.com/eellison, https://github.com/Skylion007
ghstack dependencies: #113979
2023-11-18 02:48:38 +00:00
Jez Ng
204ec11e6d [inductor][easy] Fix fusion logging (#113308)
We should use %s instead of %d, as the numel may be a sympy Expr.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113308
Approved by: https://github.com/lezcano
2023-11-09 03:19:39 +00:00
Jez Ng
dc63248b76 Make dynamo configs more amenable to static type checking (#112130)
`install_config_module` makes a regular module into a ConfigModule with
extra methods defined on it. mypy thinks those extra methods (or module
functions) are undefined since it cannot analyze something so
dynamic. As a workaround, I've created a fake module that defines these
extra functions, which I import into the config modules during type
checking.

As part of this change, I've also added more types to config_utils.py
and enabled typechecking for torch/_dynamo/config.py.
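A generic sketch of that workaround (the declarations below are assumptions for illustration, not the actual stub's contents): a block visible only to the type checker declares what install_config_module() adds at runtime.
```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Hand-written declarations of the dynamically-installed ConfigModule methods.
    def save_config() -> bytes: ...
    def get_config_copy() -> dict: ...

verbose = False  # an ordinary config entry in this module

# At runtime, install_config_module(sys.modules[__name__]) (called at the bottom of the
# real config files) provides the actual implementations of these methods.
```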
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112130
Approved by: https://github.com/jansel
2023-11-08 21:17:45 +00:00
drisspg
74c24d2367 Fixes a bug in inductor.triton.load (#113047)
Letting CI/CD tell me if there is anything wrong with this.

Original bug:
``` Shell
        r1 = rindex
        tmp37 = tl.load(out_ptr2 + (r1 + (8192*x0)), rmask, eviction_policy='evict_first', other=0)
                                                     ^
AssertionError('cannot cast int32[constexpr[1],constexpr[2048]] to <[1, 2048], fp8e4nv>')
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113047
Approved by: https://github.com/Skylion007, https://github.com/ipiszy
2023-11-07 04:06:54 +00:00
Aaron Gokaslan
8219bf051b [BE]: Apply RUF015 to torch folder (#113025)
Removes unnecessary intermediate allocations when only the first element of an iterable is needed. There is a small chance this may have side effects, as the entire iterator is no longer consumed, but this is a far more efficient way to retrieve the first element.
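What RUF015 flags, in a nutshell (an illustrative example, not code taken from this PR):
```python
items = (x * x for x in range(1_000_000))
first = list(items)[0]            # before: materializes the entire iterator

items = (x * x for x in range(1_000_000))
first = next(iter(items))         # after: stops after the first element
```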

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113025
Approved by: https://github.com/ezyang, https://github.com/malfet
2023-11-07 00:48:15 +00:00
Shunting Zhang
493ae78201 [inductor] nan-checker (#112091)
This PR is split out of https://github.com/pytorch/pytorch/pull/108193. It adds the ability to insert an assertion after each triton kernel call to make sure no tensor argument is nan/inf. It helped me find a few bugs when working on benchmark fusion (caused by messing up some kernel/graph-level state when generating kernel code).
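Roughly the kind of check the generated code inserts after each kernel call (an illustrative sketch, not the exact inductor codegen):
```python
import torch

def assert_no_nan_or_inf(kernel_name, *tensors):
    for i, t in enumerate(tensors):
        assert not torch.isnan(t).any().item(), f"{kernel_name}: arg {i} contains nan"
        assert not torch.isinf(t).any().item(), f"{kernel_name}: arg {i} contains inf"
```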

Right now we have to disable cudagraphs to enable the nan/inf checks; otherwise we will see errors like: https://gist.github.com/shunting314/053db66c4f121e5f4c5de159bf0032ed . My best guess is that it's due to the GPU->CPU copy during capturing for cudagraphs. cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @eellison if there is an easy way to make it work with cudagraphs. But even if the nan-checker is not compatible with cudagraphs, it's probably still fine since it's just for debugging purposes.

Test command:
```
TORCHINDUCTOR_BENCHMARK_KERNEL=1 TORCHINDUCTOR_NAN_ASSERTS=1 python benchmarks/dynamo/huggingface.py --backend inductor --amp --performance --only BertForMaskedLM --training --disable-cudagraphs
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112091
Approved by: https://github.com/eellison, https://github.com/jansel
2023-11-02 02:32:04 +00:00
David Berard
8191fb3e06 [Reland2] [inductor][BE] split triton_meta and inductor_meta (#112351)
triton_meta is intended to be passed directly to triton. Previously we were also putting other metadata into triton_meta, but we should split the other metadata into a separate dict to avoid possible conflicts in the future.

This PR splits out triton_meta and inductor_meta so we have a place to put additional metadata that isn't intended to be passed to triton.

Tests - wait for CI

Differential Revision: [D50864493](https://our.internmc.facebook.com/intern/diff/D50864493)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112351
Approved by: https://github.com/eellison
2023-11-02 00:40:12 +00:00
Jiong Gong
e061144aaf [inductor] replace ops.div with ops.truediv (#112243)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112243
Approved by: https://github.com/lezcano
ghstack dependencies: #112234
2023-11-01 05:50:51 +00:00
Shunting Zhang
a1e222ef02 metric table (#109245)
In dynamo/inductor, it sometimes helps to gather metrics/statistics for each model at different levels: model level, graph level, kernel level, or per pair of fusion nodes. This kind of thing would be very easy to do with Scuba, but we only have Scuba in fbcode. This PR builds metric tables to solve part of the problem.

Q: why not log to stdout/stderr directly?
A: Sometimes we need more structured data. E.g., it would be helpful to gather all the stats in a CSV and then do post-processing (like calculating a geomean etc.). Also, the metric table tags each row with the model name, which is helpful.

Q: what's the difference with speedup_inductor.csv?
A: speedup_inductor.csv is a special case that gathers statistics at the model level, i.e., we have one row for each model. But recording statistics at a finer-grained level, like per graph, is also helpful.

Example use cases:
- As a follow-up to the benchmark fusion PR, I want to gather all the 'slow' fusions and analyze them. With the metric table, I can easily log slow fusions for each model into a CSV file. Here is the log gathered for huggingface: https://gist.github.com/shunting314/964e73cc98368b301414ec7b7ad4c702 .
- To help understand the effect of the 'loop ordering after fusion' PR, it would be helpful to gather stats like how many fusions happen for each graph. Previously we logged these metrics to stderr directly, but logging them in a structured way is more useful.
- Gather the number of registers, register spills, and shared-memory usage for each kernel in each model, with runnable kernel code logged.
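A generic sketch of the per-model metric table idea described above (the class and method names here are illustrative assumptions, not this PR's actual API):
```python
import csv
from dataclasses import dataclass, field

@dataclass
class MetricTable:
    name: str
    headers: list
    rows: list = field(default_factory=list)

    def add_row(self, model_name: str, row: list) -> None:
        # Tag every row with the model name so CSVs from many runs can be merged.
        self.rows.append([model_name, *row])

    def write_csv(self, path: str) -> None:
        with open(path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["model_name", *self.headers])
            writer.writerows(self.rows)
```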

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109245
Approved by: https://github.com/jansel, https://github.com/mlazos
2023-11-01 02:33:42 +00:00
Shunting Zhang
fbafff3668 [reland][inductor] benchmark fusion (#112450)
reland https://github.com/pytorch/pytorch/pull/108193

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112450
Approved by: https://github.com/jansel
2023-10-31 18:17:06 +00:00
PyTorch MergeBot
64fd027f2e Revert "[inductor] benchmark fusion (#108193)"
This reverts commit 73cc5d1cdd.

Reverted https://github.com/pytorch/pytorch/pull/108193 on behalf of https://github.com/izaitsevfb due to Trying to unblock the revert of #108690, please rebase and reland. ([comment](https://github.com/pytorch/pytorch/pull/108193#issuecomment-1782157638))
2023-10-27 01:40:06 +00:00
Shunting Zhang
73cc5d1cdd [inductor] benchmark fusion (#108193)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108193
Approved by: https://github.com/jansel
2023-10-26 22:18:37 +00:00
PyTorch MergeBot
485cc0faae Revert "[inductor] benchmark fusion (#108193)"
This reverts commit ec0cdcdf6a.

Reverted https://github.com/pytorch/pytorch/pull/108193 on behalf of https://github.com/ZainRizvi due to This test is breaking trunk. In the future please make sure to add the ciflow/trunk label before force merging any PR to ensure your code doesn't break those tests ([comment](https://github.com/pytorch/pytorch/pull/108193#issuecomment-1781473282))
2023-10-26 16:41:20 +00:00
Shunting Zhang
ec0cdcdf6a [inductor] benchmark fusion (#108193)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108193
Approved by: https://github.com/jansel
2023-10-26 04:14:22 +00:00
Guilherme Leobas
f97c2dabd9 Move negative index checking to common.py - Fix issue 97365 (#108690)
Fixes https://github.com/pytorch/pytorch/issues/97365

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108690
Approved by: https://github.com/lezcano
2023-10-24 17:27:54 +00:00
PyTorch MergeBot
e62c887bab Revert "[inductor][BE] split triton_meta and inductor_meta (#111397)"
This reverts commit 070b94dc08.

Reverted https://github.com/pytorch/pytorch/pull/111397 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/111397#issuecomment-1776282039))
2023-10-24 00:52:24 +00:00
David Berard
070b94dc08 [inductor][BE] split triton_meta and inductor_meta (#111397)
triton_meta is intended to be passed directly to triton. Previously we were also putting other metadata into triton_meta, but we should split the other metadata into a separate dict to avoid possible conflicts in the future.

This PR splits out triton_meta and inductor_meta so we have a place to put additional metadata that isn't intended to be passed to triton.

Tests - wait for CI

Differential Revision: [D50442547](https://our.internmc.facebook.com/intern/diff/D50442547)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111397
Approved by: https://github.com/shunting314, https://github.com/eellison
2023-10-23 21:38:21 +00:00
Jon Chuang
9c7f464eef [inductor]: Better debugging of can_fuse decisions with TORCH_LOGS=fusion (#110415)
Fixes https://github.com/pytorch/pytorch/issues/110393

Example logs (for adagrad on main). In this case, it clearly identifies device mismatch as a potential red flag, which is indeed the obstacle to adagrad's successful fusion. (see: https://github.com/pytorch/pytorch/pull/110339)
```
[2023-10-03 21:50:24,084] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] ===== attempting fusion (1/10): 18 nodes =====
[2023-10-03 21:50:24,084] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (foreach:3): candidate consumer has no dep in any foreach producer
[2023-10-03 21:50:24,084] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (foreach:3): candidate consumer has no dep in any foreach producer
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (foreach:3): candidate consumer has no dep in any foreach producer
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (foreach:3): candidate consumer has no dep in any foreach producer
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (7): device mismatch (node1: cuda:0, node2: cpu)
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (7): device mismatch (node1: cuda:0, node2: cpu)
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (7): device mismatch (node1: cuda:0, node2: cpu)
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (7): device mismatch (node1: cuda:0, node2: cpu)
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (foreach:3): candidate consumer has no dep in any foreach producer
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (7): device mismatch (node1: cuda:0, node2: cpu)
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (foreach:3): candidate consumer has no dep in any foreach producer
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (foreach:3): candidate consumer has no dep in any foreach producer
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (7): device mismatch (node1: cuda:0, node2: cpu)
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (7): device mismatch (node1: cuda:0, node2: cpu)
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (7): device mismatch (node1: cuda:0, node2: cpu)
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (foreach:3): candidate consumer has no dep in any foreach producer
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] 13 possible fusions:
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] (ForeachKernelSchedulerNode(nodes=buf0_buf1_buf2_buf3), ForeachKernelSchedulerNode(nodes=buf4_buf5_buf6_buf7))
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] (ForeachKernelSchedulerNode(nodes=buf4_buf5_buf6_buf7), SchedulerNode(name='buf8'))
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] (ForeachKernelSchedulerNode(nodes=buf4_buf5_buf6_buf7), SchedulerNode(name='buf10'))
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] (ForeachKernelSchedulerNode(nodes=buf0_buf1_buf2_buf3), SchedulerNode(name='buf12'))
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] (ForeachKernelSchedulerNode(nodes=buf0_buf1_buf2_buf3), SchedulerNode(name='buf14'))
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] (ForeachKernelSchedulerNode(nodes=buf4_buf5_buf6_buf7), SchedulerNode(name='buf9'))
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] (ForeachKernelSchedulerNode(nodes=buf4_buf5_buf6_buf7), SchedulerNode(name='buf11'))
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] (ForeachKernelSchedulerNode(nodes=buf0_buf1_buf2_buf3), SchedulerNode(name='buf13'))
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] (ForeachKernelSchedulerNode(nodes=buf0_buf1_buf2_buf3), SchedulerNode(name='buf15'))
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] (SchedulerNode(name='buf25'), SchedulerNode(name='buf33'))
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] (SchedulerNode(name='buf43'), SchedulerNode(name='buf51'))
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] (SchedulerNode(name='buf34'), SchedulerNode(name='buf42'))
[2023-10-03 21:50:24,085] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] (SchedulerNode(name='buf16'), SchedulerNode(name='buf24'))
[2023-10-03 21:50:24,087] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] completed fusion round (1/10): fused 18 nodes into 5 nodes
[2023-10-03 21:50:24,087] [0/0] torch._inductor.scheduler.__schedule: [DEBUG]
[2023-10-03 21:50:24,087] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] ===== attempting fusion (2/10): 5 nodes =====
[2023-10-03 21:50:24,087] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (7): device mismatch (node1: cuda:0, node2: cpu)
[2023-10-03 21:50:24,087] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (7): device mismatch (node1: cuda:0, node2: cpu)
[2023-10-03 21:50:24,087] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (7): device mismatch (node1: cuda:0, node2: cpu)
[2023-10-03 21:50:24,087] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] cannot fuse (7): device mismatch (node1: cuda:0, node2: cpu)
[2023-10-03 21:50:24,087] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] 0 possible fusions:
[2023-10-03 21:50:24,087] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] completed fusion round (2/10): fused 5 nodes into 5 nodes
[2023-10-03 21:50:24,087] [0/0] torch._inductor.scheduler.__schedule: [DEBUG]
[2023-10-03 21:50:24,087] [0/0] torch._inductor.scheduler.__schedule: [DEBUG] ===== fusion complete (2 iterations) =====

```

CC @jansel @ngimel @mlazos @shunting314 @peterbell10  as code owners

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110415
Approved by: https://github.com/mlazos
2023-10-13 00:36:45 +00:00
Jack Taylor
96f616a054 Revert tl.int1 casting change for ROCm to avoid hangs (#110531)
Seeing hangs on ROCm seemingly after this PR https://github.com/pytorch/pytorch/pull/110388
https://ossci-raw-job-status.s3.amazonaws.com/log/17381916785
`inductor/test_torchinductor_opinfo.py::TestInductorOpInfoCUDA::test_comprehensive_exp2_cuda_bool Command took >30min, returning 124`

Conditionalising this out while we investigate.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110531
Approved by: https://github.com/peterbell10
2023-10-06 08:53:45 +00:00
Kazuaki Ishizaki
434a996c42 Fix typo under torch/_inductor directory (#110530)
This PR fixes typos in comments and messages in files under the `torch/_inductor` directory.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110530
Approved by: https://github.com/kit1980
2023-10-05 02:17:20 +00:00
Peter Bell
dc794ec32c [dynamo] Trace through builtin abs (#110398)
In Python, `abs(x)` does nothing but delegate to `x.__abs__()`, so we should do
the same in dynamo. This also adds `SymNode.__abs__` so we can trace through
indexing expressions involving `abs`.
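A plain-Python illustration of that delegation (the class here is made up for the example, not dynamo code):
```python
class Meters:
    def __init__(self, value):
        self.value = value

    def __abs__(self):
        # abs(Meters(...)) ends up calling this method.
        return Meters(abs(self.value))

assert abs(Meters(-3)).value == 3
```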

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110398
Approved by: https://github.com/jansel, https://github.com/lezcano
2023-10-03 19:25:37 +00:00
Levy Zhao
7f0a659ccc Script to compare measured (trace) runtimes with estimated runtimes (#108037) (#109076)
Summary:

X-link: https://github.com/pytorch/benchmark/pull/1856

Reviewed By: xmfan, xuzhao9

Differential Revision: D48523883

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109076
Approved by: https://github.com/xw285cornell
2023-10-03 17:05:35 +00:00
Peter Bell
01b2f25ebd [inductor] Cast loads from boolean tensors to tl.int1 (#110388)
Triton currently loads a pointer to `tl.int1` as `tl.int8`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110388
Approved by: https://github.com/lezcano, https://github.com/Skylion007
2023-10-02 22:52:08 +00:00
chilli
13681382d5 Add heuristic for when evict_first should be set (and some other minor things) (#108841)
Example of when the `evict_first` heuristic helps.
```
import torch

@torch.compile
def f(a, b):
    return (a * b).sum(dim=-1)

N = 512
inps = (torch.randn(N, N, N).permute(2, 1, 0), torch.randn(N, N, N).permute(1, 2, 0))
from torch._inductor.utils import do_bench
print(do_bench(lambda: f(*inps)))
```

This generates code like this: http://ix.io/4HFs

```
Original: 3.8 ms
This PR: 3.54 ms
Always `evict_first`: 5.4 ms
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108841
Approved by: https://github.com/lezcano, https://github.com/jansel
2023-10-01 17:06:12 +00:00
Jon Chuang
6aae636f69 chore(inductor): Simplify will_fusion_create_cycle and cleanup to node.ancestors (#109976)
recursive_predecessors == ancestors, so rename it.

Improve comments

Simplify `will_fusion_create_cycle` - make it easier to read and add detailed comments.

Diagram illustrating the clarified shortcut:
![Inductor Deep Dive](https://github.com/pytorch/pytorch/assets/9093549/7a30e088-8a33-4a9c-a8a7-81199cd086e2)

CC: @ngimel

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109976
Approved by: https://github.com/jansel
2023-09-27 20:48:53 +00:00
Peter Bell
92d86cd1ad [inductor] Fix triton compiler error in multilayer any (#109325)
Fixes #109196

When we have a split reduction and the tensor is not an even multiple of the split size,
we use `ops.masked` to pad to an even multiple. In the case here we generated:
```python
tmp5 = tl.where(mask, tmp4, 0)
```

which implicitly promotes our boolean value to `int32`. The fix is to give the default
value the same dtype as `result`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109325
Approved by: https://github.com/lezcano
2023-09-26 12:29:29 +00:00
Ying Zhang
bbdce93571 Basic fp8 support in Inductor (#109168)
Add basic fp8 support in Inductor, including:
* Fix fp8 Triton codegen issues;
* Add min_elements_per_thread requirement for fp8-related dtype conversions. More details on the Triton implementation can be found at 10f59d8ce0/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp (L10).

Note that the current implementation only works for Pointwise. Will create follow-up PRs for Reduction.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109168
Approved by: https://github.com/drisspg
2023-09-23 04:41:41 +00:00
Edward Z. Yang
3268b039ec Handle unbacked symints in Triton size hints (#109609)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109609
Approved by: https://github.com/yf225
2023-09-22 03:16:53 +00:00
PyTorch MergeBot
169ae7540d Revert "Handle unbacked symints in Triton size hints (#109609)"
This reverts commit 654731a52b.

Reverted https://github.com/pytorch/pytorch/pull/109609 on behalf of https://github.com/ezyang due to this seems to regress HF perf ([comment](https://github.com/pytorch/pytorch/pull/109609#issuecomment-1729688883))
2023-09-21 14:25:42 +00:00
Edward Z. Yang
654731a52b Handle unbacked symints in Triton size hints (#109609)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109609
Approved by: https://github.com/yf225
ghstack dependencies: #109603
2023-09-20 18:03:54 +00:00