Summary: This PR adds dynamic-shape support for AOTInductor
* On the runtime/interface side, we added two structs, StaticDimInfo
and DynamicDimInfo, to hold values for static and dynamic dimensions,
respectively. Dynamic dimensions are tracked by an unordered map field
defined in AOTInductorModelBase. At inference time, the run method
assigns the current concrete value to each dynamic dimension before
executing any kernel.
* On the CUDA wrapper codegen side, we generate dynamic symbols
appropriately for shape computations. We simulate kernel launch grids
on the C++ side by reusing the grid functions from the Python side.
The returned grid configs, which may contain symbolic expressions,
are printed out in their C++ forms via the CppPrinter. Note that
when dynamic shapes are involved, we have to compute grid configs
for each kernel at runtime in the same way as we do for launching
the corresponding Triton kernel. Otherwise, we may end up with
memory-access failures or miscomputations caused by invalid indices
when fetching or storing data in device memory.
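For intuition, here is a rough Python sketch of the codegen idea: evaluate a Python-side grid function on a symbolic size and print the resulting expression in C++ form. This is illustrative only; it uses sympy's stock C printer rather than Inductor's own CppPrinter and grid helpers.
```python
# Illustrative sketch only: Inductor's real codegen uses its own grid helpers
# and CppPrinter; plain sympy is used here just to show the idea.
import sympy

def grid(xnumel, block=1024):
    # The same kind of grid function a Python-side Triton launch would use.
    return (sympy.ceiling(xnumel / block), 1, 1)

s0 = sympy.Symbol("s0", positive=True, integer=True)  # a dynamic dimension
grid_x, grid_y, grid_z = grid(s0 * 128)

# Print the symbolic grid expression in a C-compatible form, to be embedded in
# the generated wrapper and re-evaluated at runtime with the real value of s0.
print(sympy.ccode(grid_x))
```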
Differential Revision: D49100472
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109012
Approved by: https://github.com/khabinov, https://github.com/desertfire, https://github.com/hl475
Summary: Switch AOTInductor unit tests and integration tests to invoke the same runtime interface. This is only an effort to unify the usage of the runtime; scrutiny of the interface itself will come in later PRs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108663
Approved by: https://github.com/ezyang
ghstack dependencies: #108653
Fixes https://github.com/pytorch/pytorch/issues/108323.
The cpp wrapper has a functional regression on `llama` and `tnt_s_patch16_224` due to the recently added support for scaled dot product flash attention in Inductor.
The schema of this OP is as follows:
```
- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
```
For `llama` and `tnt_s_patch16_224`, the OP is called as below, where the three positional args that have default values (`float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False`) are not passed.
```python
y = torch.ops.aten._scaled_dot_product_flash_attention.default(x0, x1, x2, scale = 0.125)
```
This PR fixes the cpp wrapper support for this case.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108552
Approved by: https://github.com/jgong5, https://github.com/desertfire, https://github.com/jansel
This replaces the `var_unnormalized` reduction type with `welford_reduce`, which takes the input data and outputs not just the variance but also the mean and weights, accounting for the full Welford accumulator state. This lets us avoid re-computing the mean, and we now have enough information to create a multi-layer reduction, which I implement here by adding a second reduction type, `welford_combine`, that reduces over all three inputs simultaneously.
Multi-layer support is particularly important because normalization operators like BatchNorm are being split in many timm models, which previously meant `var_unnormalized` had to fall back to a two-pass variance calculation.
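For reference, a minimal sketch of the two reduction types, using the standard parallel Welford combine formulas; this is illustrative Python, not Inductor's actual codegen.
```python
def welford_reduce(values):
    # Accumulate (mean, m2, weight) over raw inputs; m2 is the sum of squared
    # deviations, i.e. roughly the old "unnormalized variance".
    mean, m2, weight = 0.0, 0.0, 0.0
    for x in values:
        weight += 1.0
        delta = x - mean
        mean += delta / weight
        m2 += delta * (x - mean)
    return mean, m2, weight

def welford_combine(a, b):
    # Merge two partial accumulator states; this is what makes a multi-layer
    # (split) reduction possible without a second pass over the data.
    mean_a, m2_a, w_a = a
    mean_b, m2_b, w_b = b
    w = w_a + w_b
    if w == 0.0:
        return 0.0, 0.0, 0.0
    delta = mean_b - mean_a
    mean = mean_a + delta * w_b / w
    m2 = m2_a + m2_b + delta * delta * w_a * w_b / w
    return mean, m2, w
```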
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104725
Approved by: https://github.com/lezcano
Working on this as a starter task with @Chillee
This PR adds a method under BaseSchedulerNode to estimate the node's runtime in seconds.
We use a heuristic-based approach, first considering whether the operation is memory-bandwidth bound or compute bound:
- memory-bandwidth bound: we count the number of bytes that are read/written
- compute bound: we count the FLOPs required by the operation
One use case is as a cost model for scheduling: https://github.com/pytorch/pytorch/pull/100762
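A rough sketch of the heuristic; the bandwidth/FLOPS constants and names below are hypothetical placeholders, not the values used in the actual implementation.
```python
# Hypothetical device peaks; the real implementation derives these per device.
PEAK_MEMORY_BANDWIDTH = 2.0e12   # bytes / second
PEAK_FLOPS = 3.0e14              # FLOPs / second

def estimate_runtime_seconds(bytes_read_written: int, flops: int) -> float:
    # Memory-bandwidth-bound estimate: time to move all bytes read/written.
    memory_time = bytes_read_written / PEAK_MEMORY_BANDWIDTH
    # Compute-bound estimate: time to execute the required FLOPs.
    compute_time = flops / PEAK_FLOPS
    # The node is limited by whichever resource dominates.
    return max(memory_time, compute_time)
```
Sample output from `test/inductor/test_perf.py`: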
```
(pytorch-3.10) [14:08:02] ~/local/pytorch (xmfan/estimate_snode_runtime) > python3 test/inductor/test_perf.py -k EstimateSnodeRuntimeTests
[(ExternKernelSchedulerNode(name='buf0'), 400)]
[(ExternKernelSchedulerNode(name='buf0'), 2.35057908433887e-27)]
.[(ExternKernelSchedulerNode(name='buf0'), 3000), (SchedulerNode(name='buf1'), 3000)]
[(ExternKernelSchedulerNode(name='buf0'), 2.35057908433887e-26), (SchedulerNode(name='buf1'), 7.187055238190188e-09)]
.[(ExternKernelSchedulerNode(name='buf0'), 3000)]
[(ExternKernelSchedulerNode(name='buf0'), 2.35057908433887e-26)]
.[(ExternKernelSchedulerNode(name='buf0'), 34600)]
[(ExternKernelSchedulerNode(name='buf0'), 3.22687496698039e-24)]
.[(ExternKernelSchedulerNode(name='buf0'), 396)]
[(ExternKernelSchedulerNode(name='buf0'), 1.88046326747109e-27)]
.[(ExternKernelSchedulerNode(name='buf0'), 396)]
[(ExternKernelSchedulerNode(name='buf0'), 1.88046326747109e-27)]
.[(ExternKernelSchedulerNode(name='buf0'), 7776176)]
[(ExternKernelSchedulerNode(name='buf0'), 4.63240241413653e-21)]
.[(FusedSchedulerNode(nodes=buf0_buf1), 210)]
[(FusedSchedulerNode(nodes=buf0_buf1), 5.030938666733132e-10)]
.[(ExternKernelSchedulerNode(name='buf0'), 300)]
[(ExternKernelSchedulerNode(name='buf0'), 2.35057908433887e-27)]
.[(SchedulerNode(name='buf0'), 20)]
[(SchedulerNode(name='buf0'), 4.7913701587934585e-11)]
.
----------------------------------------------------------------------
Ran 10 tests in 14.311s
OK
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106426
Approved by: https://github.com/Chillee
While working on loop ordering, I happened to find that Inductor may cache a stale inner_fn_str and ReadWrites object in a ComputedBuffer.
Let's say we have a producer buffer buf0 and a consumer buffer buf1. Before we call GraphLowering.finalize, the layout of buf0 may be a FlexibleLayout. At that moment, the inner_fn_str or ReadWrites object computed for buf1 will be based on the layout of buf0, which most likely is a contiguous FlexibleLayout, and they will be cached on the buf1 object (or buf1.data).
However, after we call GraphLowering.finalize, we may realize it's better to give buf0 a non-contiguous layout (e.g., because its input has a non-contiguous layout, or for other reasons). The layout change of buf0 should affect the inner_fn_str and ReadWrites object of buf1, but we may have already cached those on buf1. The stale ReadWrites objects for buf1 may then result in sub-optimal strides for buf1.
This may affect perf and I'll check the nightly runs.
Here is a dump of `nodes` in `Scheduler.__init__` before the fix as a reference: https://gist.github.com/shunting314/ed2152a08e268f5563fd55398b1392c7
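A toy illustration (not Inductor code) of the hazard: anything cached on the consumer before the producer's layout is finalized keeps reflecting the old layout.
```python
from functools import cached_property

class Producer:
    layout = "contiguous (FlexibleLayout default)"

class Consumer:
    def __init__(self, producer):
        self.producer = producer

    @cached_property
    def inner_fn_str(self):
        # Captures whatever layout the producer has at first access.
        return f"load(buf0, layout={self.producer.layout})"

buf0 = Producer()
buf1 = Consumer(buf0)
buf1.inner_fn_str                     # cached while buf0 is still flexible
buf0.layout = "channels-last"         # finalize picks a different layout
print(buf1.inner_fn_str)              # still shows the stale contiguous layout
```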
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106502
Approved by: https://github.com/jansel
Currently when dynamic=True, TritonTemplates won't be used, as the condition `if list(call_args) != expected_args` defined in `TritonTemplate` cannot be satisfied. This PR fixes the issue by allowing symbolic variable names to be passed via `extra_args` and replacing all symbolic values in the generated TritonTemplate code with the call_arg names.
With this change, a locally compiled mm + epilogue node calls into the Triton kernel successfully.
This PR also introduces a new config, "max_autotune_gemm_backends", to allow specifying candidate gemm backends for max autotune. Current choices: combinations of ATEN and TRITON. This makes testing easier, since we can explicitly test Triton gemm kernels + epilogue fusions + dynamic shapes without falling back to ATen ops.
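A usage sketch of the new config (the config names live in `torch._inductor.config`; treat the exact values and the toy model below as illustrative):
```python
import torch
import torch._inductor.config as inductor_config

# Restrict max-autotune gemm candidates to Triton templates only, so the test
# exercises Triton gemm + epilogue fusion under dynamic shapes.
inductor_config.max_autotune = True
inductor_config.max_autotune_gemm_backends = "TRITON"

@torch.compile(dynamic=True)
def mm_relu(a, b):
    return torch.relu(a @ b)

out = mm_relu(torch.randn(64, 128, device="cuda"),
              torch.randn(128, 32, device="cuda"))
```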
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105295
Approved by: https://github.com/jansel
Previously, we made backwards graph compilation lazy to avoid paying
for compilation if the user didn't actually end up using the backwards
graph. This was useful in the old days when a lot of things in Inductor
didn't work and we could bypass errors this way.
However, this has a bad implication for dynamic shapes: the backwards
graph compilation can trigger extra guards, which are too late to
install in the Dynamo context if we wait until backwards is being run.
So in this PR I move us back to compiling the backwards graph
immediately if we capture any SymInts for backwards.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104971
Approved by: https://github.com/Chillee
Fix cpp wrapper failure on TorchBench model `hf_Reformer` with `randn`:
```
random_rotations = torch.randn(rotations_shape, device=vectors.device, dtype=vectors.dtype)
```
For the cpp wrapper, when `kwargs` is not empty, we need to know the exact overload schema of an `OpOverloadPacket` kernel in order to handle the `kwargs` properly when calling the cpp kernel: this includes finding the correct order of the kwargs and getting the default values of optional args that are not provided at the call site (`layout` in the above case).
The current support in this PR is conservative and we'll extend the functionality in subsequent PRs.
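A small sketch of the kind of schema inspection this requires (illustrative; the cpp wrapper does this during codegen rather than at Python runtime):
```python
import torch

packet = torch.ops.aten.randn              # OpOverloadPacket
overload = packet.default                  # resolve to a concrete OpOverload
schema = overload._schema                  # full argument schema

for arg in schema.arguments:
    # Argument order, kwarg-only flags, and defaults (e.g. layout) tell us how
    # to place kwargs and fill omitted optional args in the C++ call.
    print(arg.name, arg.kwarg_only, arg.default_value)
```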
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104575
Approved by: https://github.com/jgong5, https://github.com/desertfire
This allows `ops.minimum` and `ops.maximum` used in indirect indexing to be hoisted into direct indexing expressions. I also add support for Min/Max to the cpp printer and fix the Triton printer to support multi-argument Min/Max.
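As an illustrative example (assuming the `clamp` here lowers to `ops.minimum`/`ops.maximum`), an index clamp like the one below can now become part of the direct indexing expression rather than forcing an indirect load:
```python
import torch

@torch.compile
def gather_clamped(x, idx):
    # The clamp lowers to minimum/maximum; with this PR those can be hoisted
    # out of the indirect index and folded into direct indexing.
    clamped = idx.clamp(0, x.size(0) - 1)
    return x[clamped]

x = torch.randn(1024)
idx = torch.randint(-5, 2000, (256,))
print(gather_clamped(x, idx).shape)
```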
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105020
Approved by: https://github.com/lezcano
This PR handles inference. We will do a similar thing for training later.
Some manual testing results show this can improve inference perf by 2-3% (absolute improvement, not relative):
- convmixer: 4.285x -> 4.309x
- resnet50: 2.170x -> 2.203x
The PR is built upon freezing, since without freezing the weight input of a conv node may not be a parameter directly but rather the output of precision-converting ops. It is much easier to implement this PR after freezing.
Commands
```
TORCHINDUCTOR_FREEZING=1 python benchmarks/dynamo/timm_models.py --backend inductor --amp --performance --only convmixer_768_32 --inference
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103642
Approved by: https://github.com/eellison
This PR decouples the logic necessary to compute bounds on variables
from the logic that uses this info to perform the strength analysis on
int64 variables. While doing so, it tries to minimize the number of
attributes of the class in favour of local variables.
This class is now accessible from any `LoopBody` object.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100549
Approved by: https://github.com/eellison
Introduces two higher order operators
* run_and_save_rng_state - Saves the current rng state and then runs the op.
* run_with_rng_state - Runs the op with the rng state supplied as an input
Ideally, we would like to use torch.compile for these operators. But currently the plan is to introduce these operators at the partitioner level, obviating the need to support them fully through the torch.compile stack. To ensure that we have good enough debugging with minifiers, we have ensured that they work with make_fx. In the future, we can move on to torch.compile.
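A semantics-only sketch of the two operators (plain Python, not the actual higher order op implementations registered by this PR):
```python
import torch

def run_and_save_rng_state(op, *args, **kwargs):
    # Capture the RNG state first, then run the op; both are returned so the
    # state can be replayed later (e.g. during backward recomputation).
    rng_state = torch.get_rng_state()
    return rng_state, op(*args, **kwargs)

def run_with_rng_state(rng_state, op, *args, **kwargs):
    # Temporarily restore the saved state so the op reproduces the same
    # random numbers, leaving the ambient RNG state untouched afterwards.
    with torch.random.fork_rng():
        torch.set_rng_state(rng_state)
        return op(*args, **kwargs)

state, y1 = run_and_save_rng_state(torch.rand, 4)
y2 = run_with_rng_state(state, torch.rand, 4)
assert torch.equal(y1, y2)
```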
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102934
Approved by: https://github.com/jansel, https://github.com/zou3519
This PR just contains some mild gyrations necessary to appease mypy.
However, it is not complete; there are a number of legitimate bugs
and typing mistakes that I need to work out before I can actually
turn this on.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100712
Approved by: https://github.com/ngimel
Added helper functions to match nodes in the graph that were decomposed from their source (leaf modules or functional ops) as a result of dynamo tracing.
`get_source_partitions(graph: torch.fx.Graph, wanted_sources: List[Any]) -> Dict[Any, SourcePartition]`
Args:
* graph: The graph we want to partition
* wanted_sources: List of sources whose decomposed nodes we want to match. A source can be a function (ex. torch.nn.functional.linear) or a leaf module type (ex. torch.nn.Linear)
Returns:
* Dictionary mapping sources (ex. torch.nn.modules.linear.Linear) to a list of SourcePartitions that correspond to the list of nodes that were flattened from a module of that type.
```
@dataclass
class SourcePartition():
# Nodes in a particular partition
nodes: List[Node]
# Module type
module_type: Type
# Nodes in the graph that are needed as inputs to the partition
input_nodes: List[Node] = field(default_factory=list)
# Nodes in the partition that are being used by nodes outside of the partition
output_nodes: List[Node] = field(default_factory=list)
# Parameters that are being used
params: List[str] = field(default_factory=list)
```
Example:
Original:
```
x -> linear -> linear -> relu -> linear
```
Traced graph:
```
.graph():
%arg0 : [#users=1] = placeholder[target=arg0]
%_param_constant0 : [#users=1] = get_attr[target=_param_constant0]
%t_default : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%_param_constant0,), kwargs = {})
%_param_constant1 : [#users=1] = get_attr[target=_param_constant1]
%addmm_default : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %arg0, %t_default), kwargs = {})
%_param_constant0_1 : [#users=1] = get_attr[target=_param_constant0]
%t_default_1 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%_param_constant0_1,), kwargs = {})
%_param_constant1_1 : [#users=1] = get_attr[target=_param_constant1]
%addmm_default_1 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1_1, %addmm_default, %t_default_1), kwargs = {})
%relu_default : [#users=1] = call_function[target=torch.ops.aten.relu.default](args = (%addmm_default_1,), kwargs = {})
%_param_constant2 : [#users=1] = get_attr[target=_param_constant2]
%t_default_2 : [#users=1] = call_function[target=torch.ops.aten.t.default](args = (%_param_constant2,), kwargs = {})
%_param_constant3 : [#users=1] = get_attr[target=_param_constant3]
%addmm_default_2 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant3, %relu_default, %t_default_2), kwargs = {})
return [addmm_default_2]
```
Result of `get_source_partitions`:
```
{<class 'torch.nn.modules.linear.Linear'>: [
ModulePartition(nodes=[_param_constant0, t_default, _param_constant1, addmm_default], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[arg0], output_nodes=[addmm_default], params=["_param_constant0", "_param_constant1"]),
ModulePartition(nodes=[_param_constant0_1, t_default_1, _param_constant1_1, addmm_default_1], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[addmm_default], output_nodes=[addmm_default_1], params=["_param_constant0_1", "_param_constant1_1"]),
ModulePartition(nodes=[_param_constant2, t_default_2, _param_constant3, addmm_default_2], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[relu_default], output_nodes=[addmm_default_2], params=["_param_constant2", "_param_constant3"])],
<class 'torch.nn.modules.activation.ReLU'>: [
ModulePartition(nodes=[relu_default], module_type=<class 'torch.nn.modules.activation.ReLU'>, input_nodes=[addmm_default_1], output_nodes=[relu_default], params=[])]}
```
Also added a helper function to check whether two source partitions are connected:
`check_subgraphs_connected(subgraph1: SourcePartition, subgraph2: SourcePartition) -> bool`
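A usage sketch (the import path is assumed to be `torch.fx.passes.utils.source_matcher_utils`, and the capture step is illustrative):
```python
import torch
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(8, 8)
        self.fc2 = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

# Capture with dynamo so node.meta carries source information (the exact
# export API has shifted across releases; adjust to your version).
gm = torch._dynamo.export(M())(torch.randn(2, 8)).graph_module

partitions = get_source_partitions(gm.graph, [torch.nn.Linear])
for part in partitions.get(torch.nn.Linear, []):
    print(part.input_nodes, part.output_nodes, part.params)
```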
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98628
Approved by: https://github.com/cccclai
Command to run max autotune baseline:
```
TORCHINDUCTOR_MAX_AUTOTUNE=1 time python benchmarks/dynamo/torchbench.py --backend inductor --amp --performance --only ${MODEL_NAME} --training --batch-size-file $(realpath benchmarks/dynamo/torchbench_models_list.txt)
```
Command to do coordinate descent autotuning:
```
TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 TORCHINDUCTOR_CACHE_DIR=/tmp/torchinductor_shunting_coordesc TORCHINDUCTOR_PERSISTENT_REDUCTIONS=0 TORCHINDUCTOR_MAX_AUTOTUNE=1 time python benchmarks/dynamo/torchbench.py --backend inductor --amp --performance --only ${MODEL_NAME} --training --batch-size-file $(realpath benchmarks/dynamo/torchbench_models_list.txt)
```
Explanation of the env vars used in the commands:
```
- TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 : enable coordinate descent tuning
- TORCHINDUCTOR_PERSISTENT_REDUCTIONS=0 : disable persistent reductions. We need to do this so we can tune RBLOCK for reductions
- TORCHINDUCTOR_MAX_AUTOTUNE=1: enable max autotune
- TORCHINDUCTOR_CACHE_DIR=/tmp/torchinductor_shunting_coordesc : use a separate cache dir for coordinate descent tuning. Optional.
```
Here are my experiments results for around 40 torchbench models: https://docs.google.com/spreadsheets/d/1G7i2whIf8Yu-HhN_WovNxwcE-iFDSAw4x3NK4uL4XhI/edit#gid=0
Some highlights
- We improve 2.2% further upon max-autotune on average (geomean)
- timm_resnest benefits the most from coordinate descent tuning, with a 1.07x speedup
- We get decent speedups on transformer models
- BERT_pytorch: 1.056x
- timm_vision_transformer: 1.04x
- hf_Bert: 1.030x
- For resnet models, it looks like we have less gain as the model gets larger. My guess is that larger models spend more time on mm/conv, so our tuning for pointwise/reduction helps less
- resnet18: 1.021x
- resnet50: 1.014x
- resnet152: 1.005x
This kind of coordinate descent autotuning gives us an 'upper bound' on the gain we can get from tuning configs for pointwise/reduction kernels. On the other hand, by spot checking, it roughly doubles the compilation time compared to max-autotune. Next steps could be:
- We disable persistent reductions in coordinate descent autotuning (they are still enabled in the baseline) so that we can tune RBLOCK for reductions. We could also try using autotuning to decide whether or not to use a persistent reduction.
- Pick good configs without benchmarking (e.g., Natalia mentioned checking register spills).
- Try the idea on matmul so we know what the potential is there.
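For intuition, a toy sketch of the coordinate descent loop (illustrative only, with made-up tunable names; not the actual tuner in Inductor):
```python
def coordinate_descent_tune(benchmark, config, candidates):
    # Starting from the max-autotune winner, tweak one tunable at a time and
    # keep a change only if the kernel actually benchmarks faster.
    best_time = benchmark(config)
    improved = True
    while improved:
        improved = False
        for field, values in candidates.items():
            for value in values:
                trial = {**config, field: value}
                trial_time = benchmark(trial)
                if trial_time < best_time:
                    config, best_time, improved = trial, trial_time, True
    return config

# Example search space for a pointwise/reduction Triton kernel:
# candidates = {"XBLOCK": [64, 128, 256], "RBLOCK": [1, 2, 4, 8], "num_warps": [2, 4, 8]}
```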
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97203
Approved by: https://github.com/ngimel