Summary:
Codegen used to generate tmp_arg_{index} as temporary args, and index is the position of the caller.
We changed the logic of codegen such that we can reuse previous generated samples, and only delete after arg is no longer used. In this case, we need to make {index} unique, since different functions could reuse the same "tmp_arg_{index}" name string, but corresponds to different args.
Test Plan: `python test/inductor/test_aot_inductor.py -k test_autotuning_args_reuse`
Differential Revision: D72297084
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150522
Approved by: https://github.com/desertfire, https://github.com/22quinn
This PR supports symbol inputs to graph partition functions. Before this PR, we rely on `node.read_writes` to get partition inputs. However, this does not cover symbol inputs.
In this PR, for each graph partition, we collect all symbol inputs which are required to be in scope to successfully perform codegen, including:
- free symbols used in partition nodes.
- free symbols in partition input/node shapes, strides, and offsets. This is needed for recording cudagraphs for tensors with dynamic shapes.
### Note1: MutationLayout
In this example, node.layout is MutationLayoutSHOULDREMOVE. The symint from index `n` does not appear in the size, offset, stridese of node.layout. This symint appear in node.layout.target. So we need extra handle for it.
```python
x = torch.zeros(7, device="cuda")
def fn(n, a):
a[n] = -1
return a
opt_fn = torch.compile(fn, fullgraph=True)
for n in range(2, x.shape[0]):
opt_fn(n, x)
```
### Note2: Composability with Padded Tensor Subclass
W/o graph partition, Padded Tensor subclass lifts outer shapes to input arguments (i.e., arg0_1 for s0, arg1_1 for s1) but does not lift inner shapes (i.e., s2 and s3). Since cudagraph cache relies on integer inputs, it will cache on outer shapes and ignore inner shapes, which is bad.
```
def call(args):
arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1 = args
args.clear()
s0 = arg0_1
s1 = arg1_1
arg2_1_size = arg2_1.size()
s2 = arg2_1_size[0]
s3 = arg2_1_size[1]
assert_size_stride(arg2_1, (s2, s3), (s3, 1))
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
buf0 = empty_strided_cuda((s2, s3), (s3, 1), torch.float32)
# Topologically Sorted Source Nodes: [x1, mul], Original ATen: [aten.add, aten.mul]
triton_poi_fused_add_mul_0_xnumel = s2*s3
stream0 = get_raw_stream(0)
triton_poi_fused_add_mul_0.run(arg2_1, buf0, triton_poi_fused_add_mul_0_xnumel, stream=stream0)
del arg2_1
return (buf0, s0, s1, s1, )
```
w/ graph partition, the partition function only includes tensor and inner shapes as inputs, to make sure the cudagraph caching is correct. Full Comparison: [code](https://www.internalfb.com/intern/diffing/?paste_number=1761674743)
```python
def call(self, args):
arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1 = args
args.clear()
s0 = arg0_1
s1 = arg1_1
arg2_1_size = arg2_1.size()
s2 = arg2_1_size[0]
s3 = arg2_1_size[1]
assert_size_stride(arg2_1, (s2, s3), (s3, 1))
partition0_args = [arg2_1, s2, s3]
del arg2_1
(buf0,) = self.partitions[0](partition0_args)
del partition0_args
return (buf0, s0, s1, s1, )
```
The number of cudagraphs is validated below: (also added to test)
```python
import torch
from padded_tensor import PaddedTensor
# Turning off graph_partition leads to
# torch._inductor.cudagraph_trees.get_container(0).tree_manager.new_graph_id().id=6
# at the end, which is wrong.
# torch._inductor.config.graph_partition = False
# Turning on graph_partition leads to
# torch._inductor.cudagraph_trees.get_container(0).tree_manager.new_graph_id().id=4
# at the end, which is correct.
torch._inductor.config.graph_partition = True
def f(x):
x1 = x + 1
return x1 * 2
compiled_f = torch.compile(f, mode="reduce-overhead")
def run(shape):
x = torch.randn(*shape, device="cuda")
pad_x = PaddedTensor.from_tensor(x, multipliers={0:4, 1:4})
assert hasattr(pad_x, "multipliers"), breakpoint()
eager_out = f(pad_x)
for _ in range(3):
compiled_out = compiled_f(pad_x)
compiled_out = compiled_f(pad_x)
assert eager_out.shape == compiled_out.shape
assert eager_out.tensor.shape == compiled_out.tensor.shape
assert torch.allclose(eager_out.tensor, compiled_out.tensor)
# static shape. record a NEW cudagraph. 1 cudagraph in total now.
run((2,3))
# outer shape is dynamic, leading to a new dynamo graph
# this new dynamo graph forces a NEW cudagraph. 2 cudagraphs in total now
run((3,4))
# outer shape changed but inner shape does not change
# so NO new cudagraph is recorded
run((2,2))
# inner shape is dynamic now, leading to a new dynamo graph
# this new dynamo graph forces a NEW cudagraph. 3 cudagraphs in total now
run((5,6))
# does NOT record a new cudagraph
run((7,8))
# record a NEW cudagraph. 4 cudagraphs in total now
run((10,11))
assert torch._inductor.cudagraph_trees.get_container(0).tree_manager.new_graph_id().id == 4
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149458
Approved by: https://github.com/eellison
Summary:
User defined Triton kernel sometimes rely on real inputs to determine
the path of execution. We need real inputs to invoke the correct
behavior of the user defined triton kernels (see example in test case,
where we have an early return for random inputs)
Test Plan:
Included in the commit.
python test/inductor/test_aot_inductor.py -k triton_autotuning
python test/inductor/test_aot_inductor.py -k triton_mutated_autotuning
Reviewers:
Subscribers:
Tasks:
Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149553
Approved by: https://github.com/davidberard98, https://github.com/eellison
Patches over an issue where randomly generated example tensors can cause kernel autotuning to fail, when those tensors would not be possible outputs from previous kernels in the sequence. This fixes a failure in `test_torchinductor_opinfo.py` when run with compile-time autotuning, `test_comprehensive_nanquantile_cuda_float64`.
For clarity, the situation triggering this PR looks like kernels `A -> BCDE -> F` (`BCDE` is fused), where one of the outputs from `A` is a boolean tensor describing some of the input data. Previously, we randomly regenerated that boolean tensor and the input data before passing them to `BCDE`, so that they no longer matched. This caused a `tl.device_assert` call in `BCDE` to fail. With this PR, we reuse the random data input to `A` and the output Boolean tensor, such that they match and pass the device assertion in `BCDE`.
Fixes#147799.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146706
Approved by: https://github.com/desertfire
Summary:
Relands D69965761 / https://github.com/pytorch/pytorch/pull/147583
Before this PR, calling a triton kernel would look like:
```py
kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0)
```
where the `grid=` was passed as a callable (function closure) arg. This PR removes the grid arg:
```py
kernel.run(a, b, xnumel, stream=stream0)
```
instead now the grid computation is included in the kernel launcher, with something like:
```py
def launcher(in_ptr0, out_ptr0, xnumel, stream):
grid_0 = ((xnumel + 1023) >> 10)
grid_1 = 1
grid_2 = 1
runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel)
```
This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`.
It also allows us to unify the handling of grids between the Python and C++ wrapper code. Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid.
This unification allows this PR to be a net deletion of code.
Differential [disconnected] Revision: D70471332
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148305
Approved by: https://github.com/shunting314, https://github.com/eellison
Summary:
Relands D69965761 / https://github.com/pytorch/pytorch/pull/147583
Before this PR, calling a triton kernel would look like:
```py
kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0)
```
where the `grid=` was passed as a callable (function closure) arg. This PR removes the grid arg:
```py
kernel.run(a, b, xnumel, stream=stream0)
```
instead now the grid computation is included in the kernel launcher, with something like:
```py
def launcher(in_ptr0, out_ptr0, xnumel, stream):
grid_0 = ((xnumel + 1023) >> 10)
grid_1 = 1
grid_2 = 1
runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel)
```
This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`.
It also allows us to unify the handling of grids between the Python and C++ wrapper code. Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid.
This unification allows this PR to be a net deletion of code.
Differential Revision: D70471332
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148305
Approved by: https://github.com/shunting314, https://github.com/eellison
Before this PR, calling a triton kernel would look like:
```py
kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0)
```
where the `grid=` was passed as a callable (function closure) arg. This PR removes the grid arg:
```py
kernel.run(a, b, xnumel, stream=stream0)
```
instead now the grid computation is included in the kernel launcher, with something like:
```py
def launcher(in_ptr0, out_ptr0, xnumel, stream):
grid_0 = ((xnumel + 1023) >> 10)
grid_1 = 1
grid_2 = 1
runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel)
```
This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`.
It also allows us to unify the handling of grids between the Python and C++ wrapper code. Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid.
This unification allows this PR to be a net deletion of code.
Note the attached diff contains some minor fbcode-only changes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147583
Approved by: https://github.com/eellison, https://github.com/shunting314
In this PR, we extract `codegen_unbacked_symbol_defs` of FallbackKernel out as a `codegen_unbacked_symbol_defs_for_outputs` method in wrapper. With it, HOPs can support the case where the subgraph returns a tensor with unbacked symints. This PR only do it for cond, we'll have follow up PRs for others (e.g. while_loop) as well.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147567
Approved by: https://github.com/jansel
TODO:
- [x] Add handling for when forward is invoked multiple times without invoking backward, so that the fwd/backward states are out of sync
- [x] Update rng state initialization to take from correct device
- [x] Tests
- [x] handling of retain_graph
- [x] respect fallback random
Fix for https://github.com/pytorch/pytorch/issues/130123.
Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the new mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDAGraph safe rng states.
We have a pair of rng states for the fwd and backward respectively. In both forward and backward the rng op will get run with `graphsafe_run_with_rng_state` which takes in RNG state and it hooks onto the current RNG generator before running the operator. The rng states for fwd/backward are initialized with the same value. We ensure that for any given run of the forward, the corresponding backward run will have the same rng states for the op as was observed in the forward.
```
===== Forward graph 1 =====
/data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)
# No stacktrace found for following nodes
graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0); fwd_rng_state_0 = None
...
===== Backward graph 1 =====
def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)
# No stacktrace found for following nodes
graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0); bwd_rng_state_0 = None
```
There is some extra complication when a user either calls backward with retain_graph, or calls the backward in a different order as they called the forward. If a user has state fwd_rng_state0, bwd_rng_state0 and calls:
- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0
Then naively, when bwd1 is invoked the bwd rng states would not be equal to the same states that were observed in fwd1. I added handling of this in the aot runtime wrappers to detect pending backward invocations, and the current position of the bwd rng states, and to update when necesssary.
Other notes:
Because nodes which appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused the rng across ops, the forward and backward would be run with different rng states. I.e., not applied in the same order.
Questions for reviewers:
This does change numerics, bc the rng of the op is now taken from the input rng state instead of whatever the rng would be midway through running the graph. Technically, we only need this for cuda graph. But, I'd prefer to not have a rng divergence just for cudagraph. I am making it respect `fallback_random`.
Edit: decided to apply to non cudagraphs as well, so long as fallback_random is not set
I'm initializing the rng states by cloning the current state. If you had something like 5 different rands in the model with the same shape, theyd all get the same value. This doesn't seem great. I could use some other initialization scheme like taking seed from graph position, or etc etc. Not sure. Let me know thoughts.
Edit: updated to be taken from randint()
Update: initializing rng states from torch.randint..
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
Summary:
Add unique_user_kernel_names which mimics what unique_kernel_names do, but for user defined Triton kernels.
This does rewrite the copied kernel src, and modifies non-Inductor generated code, so we split it out from unique_kernel_names, where we have more control over all namings and generations.
Test Plan: Only used for debug purpose
Differential Revision: D69966608
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147587
Approved by: https://github.com/desertfire
TODO:
- [x] Add handling for when forward is invoked multiple times without invoking backward, so that the fwd/backward states are out of sync
- [x] Update rng state initialization to take from correct device
- [x] Tests
- [x] handling of retain_graph
- [x] respect fallback random
Fix for https://github.com/pytorch/pytorch/issues/130123.
Updates the aot_eager and cudagraph compilation of `run_and_save_rng_state` to use the new mechanism added by https://github.com/pytorch/pytorch/pull/114068 for CUDAGraph safe rng states.
We have a pair of rng states for the fwd and backward respectively. In both forward and backward the rng op will get run with `graphsafe_run_with_rng_state` which takes in RNG state and it hooks onto the current RNG generator before running the operator. The rng states for fwd/backward are initialized with the same value. We ensure that for any given run of the forward, the corresponding backward run will have the same rng states for the op as was observed in the forward.
```
===== Forward graph 1 =====
/data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", fwd_rng_state_0):
sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)
# No stacktrace found for following nodes
graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = fwd_rng_state_0); fwd_rng_state_0 = None
...
===== Backward graph 1 =====
def forward(self, primals_1: "f32[4, 4][4, 1]cuda:0", primals_2: "f32[4, 4][4, 1]cuda:0", tangents_1: "f32[4, 4][4, 1]cuda:0", bwd_rng_state_0):
sin: "f32[4, 4][4, 1]cuda:0" = torch.ops.aten.sin.default(primals_1)
# No stacktrace found for following nodes
graphsafe_run_with_rng_state = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten.rand.default, [4, 4], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False, rng_state = bwd_rng_state_0); bwd_rng_state_0 = None
```
There is some extra complication when a user either calls backward with retain_graph, or calls the backward in a different order as they called the forward. If a user has state fwd_rng_state0, bwd_rng_state0 and calls:
- fwd0: fwd_rng_state0 -> fwd_rng_state1
- fwd1: fwd_rng_state1 -> fwd_rng_state2
- bwd1
- bwd0
Then naively, when bwd1 is invoked the bwd rng states would not be equal to the same states that were observed in fwd1. I added handling of this in the aot runtime wrappers to detect pending backward invocations, and the current position of the bwd rng states, and to update when necesssary.
Other notes:
Because nodes which appear later in the forward appear earlier in the backward, we need a separate rng state for each operator. If we reused the rng across ops, the forward and backward would be run with different rng states. I.e., not applied in the same order.
Questions for reviewers:
This does change numerics, bc the rng of the op is now taken from the input rng state instead of whatever the rng would be midway through running the graph. Technically, we only need this for cuda graph. But, I'd prefer to not have a rng divergence just for cudagraph. I am making it respect `fallback_random`.
Edit: decided to apply to non cudagraphs as well, so long as fallback_random is not set
I'm initializing the rng states by cloning the current state. If you had something like 5 different rands in the model with the same shape, theyd all get the same value. This doesn't seem great. I could use some other initialization scheme like taking seed from graph position, or etc etc. Not sure. Let me know thoughts.
Edit: updated to be taken from randint()
Update: initializing rng states from torch.randint..
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146878
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
As title.
Many changes adapted from https://github.com/pytorch/pytorch/pull/129537.
Also this diff is only for *static* method of torchbind *attributes*. Some case that's not supported/tested:
- dynamic torchbind objects
- torchbind objects as an input to the module.
Note that in JIT Inductor, the attributes are lifted as inputs. So even if we just have torchbind objects as attributes, they will show up as inputs in the graph.
Example generated python code in torch.compile with inductor backend for the test case in `inductor/test_torchbind.py` (P1730554370):
```python
async_compile.wait(globals())
del async_compile
def call(args):
arg1_1, arg2_1, arg3_1 = args
args.clear()
assert_size_stride(arg1_1, (2, 3), (3, 1))
assert_size_stride(arg2_1, (2, 3), (3, 1))
buf2 = empty_strided_cpu((2, 3), (3, 1), torch.float32)
cpp_fused_add_0(arg1_1, arg2_1, buf2)
del arg1_1
del arg2_1
# Topologically Sorted Source Nodes: [x, takes_foo_tuple_return], Original ATen: [aten.add]
buf3 = torch.ops._TorchScriptTesting.takes_foo_tuple_return.default(arg3_1, buf2)
buf4 = buf3[0]
assert_size_stride(buf4, (2, 3), (3, 1))
buf5 = buf3[1]
assert_size_stride(buf5, (2, 3), (3, 1))
buf6 = buf4; del buf4 # reuse
cpp_fused_add_1(buf6, buf5)
del buf5
# Topologically Sorted Source Nodes: [y, b], Original ATen: [aten.add]
buf7 = torch.ops._TorchScriptTesting.takes_foo.default(arg3_1, buf6)
del buf3
del buf6
buf8 = buf7
assert_size_stride(buf8, (2, 3), (3, 1))
# Topologically Sorted Source Nodes: [c], Original ATen: []
buf9 = torch.ops.higher_order.call_torchbind(arg3_1, 'add_tensor', buf2)
del arg3_1
del buf7
buf10 = buf9
assert_size_stride(buf10, (2, 3), (3, 1))
del buf9
buf11 = buf2; del buf2 # reuse
cpp_fused_add_2(buf11, buf8, buf10)
return (buf11, )
def benchmark_compiled_module(times=10, repeat=10):
from torch._dynamo.testing import rand_strided
from torch._inductor.utils import print_performance
arg1_1 = rand_strided((2, 3), (3, 1), device='cpu', dtype=torch.float32)
arg2_1 = rand_strided((2, 3), (3, 1), device='cpu', dtype=torch.float32)
import pickle
global arg3_1
arg3_1 = pickle.loads(b'\x80\x04\x95[\x00\x00\x00\x00\x00\x00\x00\x8c\x05torch\x94\x8c\x0cScriptObject\x94\x93\x94)\x81\x94]\x94(K\nK\x14e\x8c0__torch__.torch.classes._TorchScriptTesting._Foo\x94\x86\x94b.')
fn = lambda: call([arg1_1, arg2_1, arg3_1])
return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
from torch._inductor.wrapper_benchmark import compiled_module_main
compiled_module_main('None', benchmark_compiled_module)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146927
Approved by: https://github.com/angelayi
This fixes handling for "1" and "None" args with new Triton versions. TL;DR: triton_meta["constants"] (which is passed to ASTSource) should be a map of {"kwarg_name": constant_value} for values which are tl.constexpr, or have a value of 1 or None (i.e. "specialized" constants). For constant args, triton_meta["signature"][arg_name] should be "constexpr" (even for specialized constants).
Note: This adds support for Triton versions after 5512; but not for versions in between 5220 and 5512 (i.e. `TritonAttrsDescriptorVersion.V3_BACKENDS_TUPLE`). There's a completely different format for constants/signature in the commit range in between.
To test: I ran `test_torchinductor.py` and `test_triton_kernels.py` with the main branch of triton (~jan 27). The only failing tests are aoti-related tests (which need to be fixed as a follow-up), and test_mutable_custom_op_fixed_layout2_cuda (which is failing with or without the new triton version on my machine); additionally, the split-scan/split-reduction kernels rely on https://github.com/triton-lang/triton/pull/5723.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145515
Approved by: https://github.com/SamGinzburg
Longer term would be good to add as a feature to cpp_wrapper, but this makes sure it doesn't fail on main.
Not sure if this needs a test because it's not meant to compose, but will add one if necessary.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145538
Approved by: https://github.com/desertfire
Prior to this PR, constexprs were appearing in signatures as `{.. "XBLOCK : tl.constexpr": "constexpr"}` when they really should appear as `{.. "XBLOCK": "constexpr"}`.
This PR represents the argument names as ArgName objects, which can optionally be marked as constexpr.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145583
Approved by: https://github.com/jansel
Triton commit 5220 adds tuple support in Triton (changing the indexing format in AttrsDescriptor) and commit 5512 replaces AttrsDescriptor with raw tuples. This PR fixes user-defined triton kernel handling (in most cases) for these new triton commits.
What this PR fixes:
* in triton_kernel_wrap.py, AST->TTIR parsing was to be updated for the new triton API
* ir.py - don't remove None args when using newer triton versions
* wrapper.py - update signature & constant handling
What this doesn't fix:
* correct None handling - I want to do a closer look at constant handling (including None, equal_to_1, and other constants).
* cpp wrapper (which needs to be fixed for both user-defined triton kernels and inductor-generated kernels)
test/inductor/test_triton_kernels.py passed on triton commit 74de6b46, with the exception of three tests (those shown here: 1374074098)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145348
Approved by: https://github.com/jansel
ghstack dependencies: #145051
Some context: Inplace padding is an optimization to do padding in place. E.g., if a tensor has size [2048, 2047] and stride [2048, 1]. When we need pad one extra element to the end of each row (e.g. during mm padding), we can just reuse the original tensor and do the padding inplace. This saves memory and bandwidth. One caveat for this optimization is, PyTorch does not allocate 2048 elements for the last row of the original tensor. It only allocate 2047 elements. So assuming the last row having enough space for 2048 elements may be wrong and cause OOB memory access (although I never see this happen maybe due to overallocation in the CUDACachingAllocation, this should better be fixed).
The fix is when we allocate the tensor, instead of doing something like:
```
buf0 = randn_strided([2048, 2047], [2048, 1])
```
we do some small overallocation
```
buf0 = randn_strided([2048, 2048], [2048, 1]).as_strided([2048, 2047], [2048, 1])
```
cpp_wrapper needs special handling since memory allocation goes thru different code path to python wrapper.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145325
Approved by: https://github.com/desertfire, https://github.com/jansel
ghstack dependencies: #140249
**Summary**
Enable the CPP Grouped GEMM Fusion, lowering and Grouped GEMM Template following the RFC: https://github.com/pytorch/pytorch/issues/144012
- Support flexible number of GEMMs
- Share activation across GEMMs
- The Grouped GEMM Template supports independent activations
- However, the pattern matcher requires an anchor node, which is as the shared activation across GEMMs
- Each GEMM can have a unique weight but same sizes
- Each GEMM can have a unique bias or None
- Current PR does not yet support biases; this will be addressed in a follow-up epilogue fusion PR
- Each GEMM have its own epilogues
- Epilogue fusion is not yet supported in this PR and will be enabled in an upcoming follow-up epilogue fusion PR
**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_grouped_linear
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_grouped_linear_invalid
python -u -m pytest -s -v test/inductor/test_cpu_cpp_wrapper.py -k test_grouped_linear
```
**Example**
Here is the example and generated code
```
batch_size = 4
in_features = 512
out_features = 1024
dtype = torch.bfloat16
class M(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.linear0 = torch.nn.Linear(in_features, out_features, bias=False)
self.linear1 = torch.nn.Linear(in_features, out_features, bias=False)
def forward(self, x):
return self.linear0(x), self.linear1(x)
if __name__ == "__main__":
with torch.no_grad():
input = torch.randn(batch_size, in_features, dtype=dtype)
m = M(bias=bias).to(dtype=dtype).eval()
cm = torch.compile(m)
act_res = cm(input)
```
Generated Code: https://gist.github.com/leslie-fang-intel/ed2e8d23aeb3586eb504feeace692e16#file-grouped-gemm-generated-code-py
**Next Step**
- Support Epilogue fusion
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143796
Approved by: https://github.com/jgong5, https://github.com/jansel