Summary:
This is a prototype for running extern fallback kernels with a host-side proxy executor.
A sample of the generated cpp wrapper call:
```
at::Tensor buf0; // output buffer
void* tensor_args_var_0[] = {&arg0_1, &arg0_1, &arg1_1, &arg0_1, &arg1_1, &buf0};
int64_t int_args_var_1[] = {81, 81, 7, 7, 7, 81};
proxy_executor->call_function("buf0", int_args_var_1, tensor_args_var_0);
```
- In my current implementation, the proxy executor interprets the raw pointers according to the op's schema (see the sketch after this list).
This assumes that a custom op MUST have a valid schema registered with the Dispatcher. (I would like to validate this assumption.)
- I am using the callBoxed() API of the custom kernels. This is unavoidable, as we wish to have a single call_function API for all possible custom kernels.
- These are all the input argument types I have supported so far:
```
union Argument {
  # Bool value does not matter
  1: bool asNone;
  2: TensorArgument asTensor;
  3: list<TensorArgument> asTensors;
  5: i64 asInt;
  7: list<i64> asInts;
  8: double asFloat;
  9: list<double> asFloats;
  10: string asString;
  10.5: list<string> asStrings;
  11: SymIntArgument asSymInt;
  12: list<SymIntArgument> asSymInts;
  13: ScalarType asScalarType;
  14: MemoryFormat asMemoryFormat;
  15: Layout asLayout;
  16: Device asDevice;
  17: bool asBool;
  18: list<bool> asBools;
}
```
- We need a policy for handling unpopulated arguments with default values. Here are the options; the choice has BC implications.
1. Require the exported fx graph to explicitly populate default values if the user doesn't specify them.
2. Require the cpp wrapper to explicitly populate default values if the fx graph doesn't specify them.
3. Have the proxy executor look up default values from the op schema.
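To make the schema-driven interpretation concrete, here is a minimal Python sketch of the proxy-executor idea. The `proxy_call` helper and the flat int/tensor array convention are hypothetical, only Tensor and int arguments are handled, and the real implementation is C++ code that dispatches through `callBoxed()`:
```python
import torch

def proxy_call(op, int_args, tensor_args):
    # Rebuild the argument list by walking the op's schema and pulling
    # from the flat int/tensor arrays in schema order.
    ints, tensors = iter(int_args), iter(tensor_args)
    args = []
    for arg in op._schema.arguments:
        if isinstance(arg.type, torch.TensorType):
            args.append(next(tensors))
        elif isinstance(arg.type, torch.IntType):
            args.append(int(next(ints)))
        else:
            raise NotImplementedError(f"unsupported arg type: {arg.type}")
    # The C++ executor would push args onto a stack and callBoxed();
    # here we simply invoke the op overload directly.
    return op(*args)

# e.g. proxy_call(torch.ops.aten.mm.default, [], [a, b])
```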
Fixes T162112344.
Test Plan:
frontend:
buck2 run mode/dev-sand mode/inplace -c fbcode.enable_gpu_sections=True sigmoid/frontend:export_main
test:
buck2 run mode/dev-sand //deeplearning/aot_inductor/test:test_custom_ops
backend:
buck2 run mode/dev-nosan //deeplearning/aot_inductor/fb:main
buck2 test 'fbcode//mode/opt' fbcode//caffe2/torch/fb/model_transform/experimental/benchmark/test:test_aot_inductor_benchmark -- --exact 'caffe2/torch/fb/model_transform/experimental/benchmark/test:test_aot_inductor_benchmark - test_aot_inductor_benchmark_cmf30x (caffe2.torch.fb.model_transform.experimental.benchmark.test.test_aot_inductor_benchmark.AOTInductorBenchmark)'
Reviewed By: suo
Differential Revision: D48747417
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108350
Approved by: https://github.com/izaitsevfb
Inductor kernel codegen previously had the following side effects:
- in `Kernel.__exit__`, we add locally used buffers to `graph.removed_buffers`
- during codegen, we perform memory allocation/free
These side effects make it hard to do multiple versions of codegen for the same kernel. This PR refactors the code so that kernel codegen does not change graph-level state. After codegening a kernel, the graph-level state is unchanged, so we can go on to codegen another version of the kernel if we want.
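A minimal sketch of the intended shape, with hypothetical names (`local_buffers`, `commit`): side effects are staged on the kernel object during codegen and applied to graph-level state only for the version we keep:
```python
class Kernel:
    def __init__(self):
        self.local_buffers = set()    # buffers only used inside this kernel
        self.removed_buffers = set()  # staged, not written to graph state

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        # Previously this mutated graph.removed_buffers directly, so a
        # second codegen pass over the same kernel saw mutated state.
        self.removed_buffers |= self.local_buffers

def commit(graph_removed_buffers, kernel):
    # Applied exactly once, for the codegen'd version we decide to keep.
    graph_removed_buffers |= kernel.removed_buffers
```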
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107617
Approved by: https://github.com/jansel
For max_pooling code:
```
#pragma GCC ivdep
for(long i2=static_cast<long>(0L); i2<static_cast<long>(56L); i2+=static_cast<long>(1L))
{
    for(long i3=static_cast<long>(0L); i3<static_cast<long>(64L); i3+=static_cast<long>(16L))
    {
        auto tmp0 = at::vec::Vectorized<int>(static_cast<int>((-1L) + (2L*i1)));
        auto tmp1 = at::vec::Vectorized<int>(static_cast<int>(0));
        auto tmp2 = to_float_mask(tmp0 >= tmp1);
        auto tmp3 = at::vec::Vectorized<int>(static_cast<int>(112));
        auto tmp4 = to_float_mask(tmp0 < tmp3);
        auto tmp5 = tmp2 & tmp4;
        auto tmp6 = at::vec::Vectorized<int>(static_cast<int>((-1L) + (2L*i2)));
        auto tmp7 = to_float_mask(tmp6 >= tmp1);
        auto tmp8 = to_float_mask(tmp6 < tmp3);
        auto tmp9 = tmp7 & tmp8;
        auto tmp10 = tmp5 & tmp9;
        auto tmp11 = [&]
        {
            auto tmp12 = at::vec::Vectorized<bfloat16>::loadu(in_ptr0 + static_cast<long>((-7232L) + i3 + (128L*i2) + (14336L*i1) + (802816L*i0)), 16);
            auto tmp13 = cvt_lowp_fp_to_fp32<bfloat16>(tmp12);
            return tmp13;
        };
        auto tmp14 = decltype(tmp11())::blendv(at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity()), tmp11(), to_float_mask(tmp10));
```
the index used by `tmp12` may not be a valid index: for example, with `i1=0, i2=0, i3=0` the index is `-7232L`, which is out of bounds. We may hit a segmentation fault when we call `tmp11()`. The intended behavior is that we read the value only when `tmp10` (the index-check variable) is true; this PR adds masked load support to fix this issue.
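Conceptually, a masked load never dereferences lanes whose mask is false, so an out-of-bounds index under a false mask cannot fault. A NumPy sketch of the semantics (the helper name is made up; the actual change lives in the C++ vectorized codegen):
```python
import numpy as np

def masked_load(buf, idx, mask, other=-np.inf):
    # Lanes with a false mask receive `other`; their (possibly invalid,
    # e.g. negative) indices are never dereferenced.
    out = np.full(idx.shape, other, dtype=buf.dtype)
    mask = np.asarray(mask, dtype=bool)
    out[mask] = buf[idx[mask]]
    return out
```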
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107670
Approved by: https://github.com/jgong5, https://github.com/jansel
This replaces the `var_unnormalized` reduction type with `welford_reduce`, which takes the input data and outputs not just the variance but also the mean and weights that account for the full Welford accumulator state. Thus we can avoid re-computing the mean, and we now have enough information to create a multilayer reduction, which I implement here by adding a second reduction type called `welford_combine` that reduces over all three inputs simultaneously.
Multi-layer support is particularly important as normalization operators like BatchNorm are being split in many timm models, which meant `var_unnormalized` had to fall back to two-pass variance calculation.
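For reference, a plain-Python sketch of the two reduction types using the standard parallel-variance (Chan et al.) formulas; the `(mean, m2, weight)` tuple layout is illustrative, not the exact accumulator representation used in codegen:
```python
def welford_reduce(state, x):
    # Fold one element into a running (mean, m2, weight) accumulator.
    mean, m2, weight = state
    weight += 1
    delta = x - mean
    mean += delta / weight
    m2 += delta * (x - mean)
    return mean, m2, weight

def welford_combine(a, b):
    # Merge two accumulators, e.g. partial results of a split reduction.
    mean_a, m2_a, w_a = a
    mean_b, m2_b, w_b = b
    w = w_a + w_b
    delta = mean_b - mean_a
    mean = mean_a + delta * w_b / w
    m2 = m2_a + m2_b + delta * delta * w_a * w_b / w
    return mean, m2, w  # variance = m2 / w
```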
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104725
Approved by: https://github.com/lezcano
## Summary
This is re-land PR for https://github.com/pytorch/pytorch/pull/100706 to address the compilation latency performance regression.
## Root Cause
For the C++/OpenMP backend, `codecache.pick_vec_isa()`, which checks the vectorization ISA, is a time-consuming one-shot operation. It makes importing the `codegen.cpp` package take longer, because the package's `LoopLevel` is decorated with `@dataclasses.dataclass` and the decorator invokes `codecache.pick_vec_isa()` to initialize the `simd_nelements` of the `LoopLevel`.
c14cf312c9/torch/_inductor/codegen/cpp.py (L2883C53-L2883C53)
The Triton backend does not need this check at all, but we'd prefer uniform code. Therefore, the new design registers both `CppScheduling` for CPU and `TritonScheduling` for Triton regardless of whether the current backend is Triton, which brings additional overhead to the Triton backend.
```python
def init_backend_registration(self):
    if get_scheduling_for_device("cpu") is None:
        from .codegen.cpp import CppScheduling
        register_backend_for_device("cpu", CppScheduling, WrapperCodeGen)
    if get_scheduling_for_device("cuda") is None:
        from .codegen.triton import TritonScheduling
        register_backend_for_device("cuda", TritonScheduling, WrapperCodeGen)
```
## Solution
To resolve the compilation latency regression for the Triton backend, we changed `LoopLevel` slightly ([new code changes](https://github.com/pytorch/pytorch/pull/106874/files#diff-5ab7b0235e2076a5fc6629ba0b109208940f5b94f5c13babc3e0f87cf4fcec82R2893-R2904)) by moving the `simd_nelements` initialization into `__post_init__`, which brings the compilation performance back.
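A simplified sketch of the fix (names follow the description above; the exact fields in cpp.py differ): the ISA probe moves from the dataclass field default, which runs at import time, into `__post_init__`, which runs only when a `LoopLevel` is constructed:
```python
import dataclasses
from torch._inductor import codecache

@dataclasses.dataclass
class LoopLevel:
    # Before: `simd_nelements: int = codecache.pick_vec_isa().nelements()`
    # ran the expensive one-shot ISA check as soon as the module was imported.
    simd_nelements: int = -1

    def __post_init__(self):
        if self.simd_nelements == -1:
            # After: the probe runs only when an instance is created,
            # i.e. only on the C++ backend path.
            self.simd_nelements = codecache.pick_vec_isa().nelements()
```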
## Compilation Latency Performance Result
We ran a single model benchmark and reproduced the compilation regression:
- Run `python benchmarks/dynamo/torchbench.py -dcuda --training --performance --inductor --only hf_Bart`
- W/ PR #100706, the compilation latency is about **57~58s**
```
dev,name,batch_size,speedup,abs_latency,compilation_latency,compression_ratio,eager_peak_mem,dynamo_peak_mem,calls_captured,unique_graphs,graph_breaks,unique_graph_breaks
cuda,hf_Bart,4,1.556712,109.676554,57.055242,0.936330,5.760698,6.152422,642,1,8,7
cuda,hf_Bart,4,1.646658,109.621747,57.909817,0.936330,5.760698,6.152422,642,1,8,7
```
- W/O PR #100706, the compilation latency is about **46~47s**
```
dev,name,batch_size,speedup,abs_latency,compilation_latency,compression_ratio,eager_peak_mem,dynamo_peak_mem,calls_captured,unique_graphs,graph_breaks,unique_graph_breaks
cuda,hf_Bart,4,1.599065,108.702480,47.490346,0.936330,5.760698,6.152422,642,1,8,7
cuda,hf_Bart,4,1.588419,108.431411,46.983041,0.936330,5.760698,6.152422,642,1,8,7
```
This PR fixed the compilation performance regression.
- W/ this PR #106874, the compilation latency is about **47~48s**
```
dev,name,batch_size,speedup,abs_latency,compilation_latency,compression_ratio,eager_peak_mem,dynamo_peak_mem,calls_captured,unique_graphs,graph_breaks,unique_graph_breaks
cuda,hf_Bart,4,1.586261,108.149467,47.481058,0.936330,5.760698,6.152422,642,1,8,7
cuda,hf_Bart,4,1.758915,108.613899,47.925633,0.936330,5.760698,6.152422,642,1,8,7
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106874
Approved by: https://github.com/jansel
This PR aims to sort out the data type for `constant`.
The constant should be promoted to float (https://github.com/pytorch/pytorch/pull/105440), so there are several changes to make:
- Data type propagation should propagate a `constant` node to `float` dtype if the original dtype is `bfloat16`
- We do not need to insert `to_dtype` after the `constant` node; directly initializing an `fp32` constant is faster:
```
vectorized<bfloat16> tmp(value);
vectorized<float> tmp1 = cvt_bf16_fp32(tmp);
->
vectorized<float> tmp(value);
```
- Move `constant` out of the list of operations that can support bf16 without converting to fp32
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105827
Approved by: https://github.com/jgong5, https://github.com/jansel
This PR intends to extend Inductor to support third-party backends that focus only on code generation, just like the existing C++/OpenMP and Triton backends.
Currently, the code generated by Inductor contains two major parts: one is the kernel, and the other is the Python wrapper that glues the kernels. Therefore, a third-party backend needs to customize both parts to generate its specific code.
- Python wrapper code generation
Inductor provides a `WrapperCodeGen` class to generate the Python wrapper code that glues the kernels. Therefore, it is straightforward for a third-party backend to generate backend-specific Python wrapper code: it just needs to inherit from the `WrapperCodeGen` class and override particular member functions as needed.
- Kernel code generation
Kernel code generation is driven by a `Scheduling` object. Hence, a third-party backend needs to provide a custom `Scheduling` for its specific kernel code generation. Currently, `CppScheduling` and `TritonScheduling` serve the C++/OpenMP and Triton backends, respectively, but there is no common `Scheduling` class. Based on the scheduling invocation, this PR abstracts a common `Scheduling` class containing the following member functions.
- [group_fn](71c4becda7/torch/_inductor/scheduler.py (LL649C64-L649C64))
- [flush](71c4becda7/torch/_inductor/scheduler.py (L1150))
- [can_fuse_vertical](71c4becda7/torch/_inductor/scheduler.py (L1006))
- [can_fuse_horizontal](71c4becda7/torch/_inductor/scheduler.py (LL1008C45-L1008C64))
- [codegen_template](71c4becda7/torch/_inductor/scheduler.py (L1234)) _This function is only available for Triton. If the third-party backend behaves as a subclass of `TritonScheduling`, it can override or reuse it._
- [codegen_nodes](71c4becda7/torch/_inductor/scheduler.py (L1234))
- [codegen_sync](71c4becda7/torch/_inductor/scheduler.py (LL1251C1-L1251C1)) _This function is only for Triton debugging purposes, but it might also be useful for other computation devices; therefore, we'd prefer to keep it._
The third-party backend needs to inherit from the `Scheduling` class and implement these functions.
Regarding other classes for code generation like `CppKernel` and `TritonKernel`: they are used by, or are part of the logic of, either `Scheduling` or `WrapperCodeGen`. Hence, this PR does not define an interface for them and leaves the flexibility to the third-party backend, which can implement these classes from scratch or reuse them by inheriting and overriding.
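A hedged sketch of what a third-party backend's registration could look like under this interface; the device name, class names, and overridden hooks below are hypothetical:
```python
from torch._inductor.codegen.common import register_backend_for_device
from torch._inductor.codegen.wrapper import WrapperCodeGen

class MyWrapperCodeGen(WrapperCodeGen):
    # Override only the hooks whose output must be device-specific.
    def write_header(self):
        super().write_header()
        self.header.writeline("import my_backend_runtime  # hypothetical")

class MyScheduling:
    # Implements the common Scheduling interface described above.
    def __init__(self, scheduler):
        self.scheduler = scheduler

    def group_fn(self, sizes): ...
    def can_fuse_vertical(self, node1, node2): ...
    def can_fuse_horizontal(self, node1, node2): ...
    def codegen_nodes(self, nodes): ...
    def codegen_sync(self): ...
    def flush(self): ...

register_backend_for_device("my_device", MyScheduling, MyWrapperCodeGen)
```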
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100706
Approved by: https://github.com/jansel
This allows `ops.minimum` and `ops.maximum` used in indirect indexing to be hoisted
into direct indexing expressions. I also add support for Min/Max to the cpp printer
and fix the triton printer to support multi-argument Min/Max.
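For example (illustrative), the clamping in an index computation like the one below lowers to `ops.minimum`/`ops.maximum`, and since the index is a pure function of the iteration variable it can now be folded into a direct indexing expression instead of forcing an indirect load:
```python
import torch

def fn(x):
    # The min/max from the clamp can be hoisted into direct indexing.
    idx = torch.arange(x.numel(), device=x.device).clamp(max=10)
    return x[idx]
```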
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105020
Approved by: https://github.com/lezcano
This is intended as a first step towards reductions with multiple outputs. This
also incidentally improves CSE of reductions under C++ codegen. For example,
```python
def fn(x):
    return torch.argmin(x, dim=-1), torch.argmin(x, dim=-1)
```
Currently this generates two reductions, where the common load is CSEd
```cpp
for(long i1=static_cast<long>(0L); i1<static_cast<long>(10); i1+=static_cast<long>(1L))
{
    auto tmp0 = in_ptr0[static_cast<long>(i1 + (10L*i0))];
    if (tmp_acc0.value > tmp0) {
        tmp_acc0.index = i1; tmp_acc0.value = tmp0;
    }
    if (tmp_acc1.value > tmp0) {
        tmp_acc1.index = i1; tmp_acc1.value = tmp0;
    }
}
auto tmp1 = tmp_acc0.index;
out_ptr0[static_cast<long>(i0)] = tmp1;
auto tmp2 = tmp_acc1.index;
out_ptr1[static_cast<long>(i0)] = tmp2;
```
but with this change it gets CSEd to a single accumulator
```cpp
for(long i1=static_cast<long>(0L); i1<static_cast<long>(10L); i1+=static_cast<long>(1L))
{
    auto tmp0 = in_ptr0[static_cast<long>(i1 + (10L*i0))];
    if (tmp_acc0.value > tmp0) {
        tmp_acc0.index = i1; tmp_acc0.value = tmp0;
    }
}
auto tmp1 = tmp_acc0.index;
out_ptr0[static_cast<long>(i0)] = tmp1;
out_ptr1[static_cast<long>(i0)] = tmp1;
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102737
Approved by: https://github.com/jgong5, https://github.com/lezcano
This is a bit inefficient because it computes the mean and throws it
away since ir.Reduction nodes only have 1 output. However, the mean
can at least be scheduled into the same loop as the variance now since
there is no data dependency. Thus we can take fewer passes over the
data.
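For instance (illustrative), in a pattern like the following the explicit mean reduction no longer depends on a prior pass, so it can share a loop with the variance reduction:
```python
import torch

def fn(x):
    # Both outputs reduce over the same data with no cross-reduction
    # data dependency, so they can be scheduled into one loop.
    return torch.var_mean(x, dim=-1)
```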
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102486
Approved by: https://github.com/lezcano, https://github.com/jansel
The analysis for SymPy expressions was incorrect: even though it said the
assumption was "smoothness", the assumption was, in fact, that the
formula was monotone in every variable. In other words, it was
assuming that the derivative does not change sign in any variable (!!).
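To see why this is unsound, a quick SymPy illustration (not from the PR): the expression below is smooth but not monotone, so propagating only the endpoint values misses the interior extremum.
```python
import sympy

x = sympy.Symbol("x")
expr = x * (1 - x)
# Endpoint evaluation on x in [0, 1] suggests the range is [0, 0] ...
print(expr.subs(x, 0), expr.subs(x, 1))              # 0 0
# ... but the true maximum is 1/4, attained at x = 1/2.
print(sympy.maximum(expr, x, sympy.Interval(0, 1)))  # 1/4
```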
We implement a function that, given bounds on the values of the free
symbols of a SymPy expression, gives a bound on the expression itself.
We reshuffle a few things in value_ranges.py to create a
`SymPyValueRangeAnalysis` class, but we do not really change any code.
The only relevant change in that file is the addition of the
`sympy_bound` function. We do this because we don't want to inadvertently
use any fallbacks in this case.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104559
Approved by: https://github.com/eellison
**Summary**
We already support vectorized code generation for the `dequant-relu-quant` pattern, in which `to_uint8` is the last node of the quant pattern before the store to memory. However, there is another case: `dequant1-relu-quant2-dequant2-relu-quant3`. Here, `quant2` sits in the middle of the fusion pattern; this PR enables vectorized code generation for `quant2-dequant2`.
**Test Plan**
```
python -u -m pytest -s -v test_cpu_repro.py -k test_dequant_relu_quant_dequant_relu_quant_lowering
```
**Next Step**
* For better performance, we can add another pass to eliminate paired `float_to_uint8` and `uint8_to_float` nodes.
* For better performance, we should annotate `dequant1` and `quant2` as sharing an observer in the quantization recipe. Then we can lower `dequant1-relu-quant2` into a QReLU node to fully eliminate the calculation of `dequant1` and `quant2`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104503
Approved by: https://github.com/jgong5, https://github.com/jansel
**Summary**
Refactor the vectorized code generation for the uint8 input data type. Previously, we combined the uint8 load and the uint8-to-float conversion into one step, as `load_uint8_as_float` and `store_float_as_uint8`. After the refactor, we split them into two steps, load/store and data type conversion, to make the behavior the same as for the BFloat16 data type.
The previously generated code was:
```
#pragma omp for
for(long i0=static_cast<long>(0L); i0<static_cast<long>(432L); i0+=static_cast<long>(16L))
{
    auto tmp0 = at::vec::load_uint8_as_float(in_ptr0 + static_cast<long>(i0));
    auto tmp1 = (tmp0);
    auto tmp2 = at::vec::Vectorized<float>(static_cast<float>(100.0));
    auto tmp3 = tmp1 - tmp2;
    auto tmp4 = at::vec::Vectorized<float>(static_cast<float>(0.01));
    auto tmp5 = tmp3 * tmp4;
    auto tmp6 = at::vec::clamp_min(tmp5, decltype(tmp5)(0));
    auto tmp7 = tmp6 * tmp2;
    auto tmp8 = tmp7.round();
    auto tmp9 = tmp8 + tmp2;
    auto tmp10 = at::vec::Vectorized<float>(static_cast<float>(0.0));
    auto tmp11 = at::vec::maximum(tmp9, tmp10);
    auto tmp12 = at::vec::Vectorized<float>(static_cast<float>(255.0));
    auto tmp13 = at::vec::minimum(tmp11, tmp12);
    auto tmp14 = (tmp13);
    at::vec::store_float_as_uint8(tmp14, out_ptr0 + static_cast<long>(i0));
}
```
After this PR, the generated code is:
```
#pragma omp for
for(long i0=static_cast<long>(0L); i0<static_cast<long>(432L); i0+=static_cast<long>(16L))
{
    auto tmp0 = at::vec::Vectorized<uint8_t>::loadu(in_ptr0 + static_cast<long>(i0), 16);
    auto tmp1 = cvt_uint8_to_fp32_with_same_elem_num(tmp0);
    auto tmp2 = at::vec::Vectorized<float>(static_cast<float>(100.0));
    auto tmp3 = tmp1 - tmp2;
    auto tmp4 = at::vec::Vectorized<float>(static_cast<float>(0.01));
    auto tmp5 = tmp3 * tmp4;
    auto tmp6 = at::vec::clamp_min(tmp5, decltype(tmp5)(0));
    auto tmp7 = tmp6 * tmp2;
    auto tmp8 = tmp7.round();
    auto tmp9 = tmp8 + tmp2;
    auto tmp10 = at::vec::Vectorized<float>(static_cast<float>(0.0));
    auto tmp11 = at::vec::maximum(tmp9, tmp10);
    auto tmp12 = at::vec::Vectorized<float>(static_cast<float>(255.0));
    auto tmp13 = at::vec::minimum(tmp11, tmp12);
    auto tmp14 = cvt_fp32_to_uint8(tmp13);
    tmp14.store(out_ptr0 + static_cast<long>(i0), 16);
}
```
**Test Plan**
```
python -m pytest test_cpu_repro.py -k test_decomposed_dequant_relu_quant
python -m pytest test_cpu_repro.py -k test_tile2d_load_decomposed_dequant_add_relu_quant
python -m pytest test_cpu_repro.py -k test_tile2d_store_channel_shuffle_cl_quant_output
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104075
Approved by: https://github.com/jgong5, https://github.com/jansel
Summary:
This diff adds a path in inductor to invoke gcc through Remote Execution when run from within fbcode.
This should (hopefully) let us kill the `inductor.disable_cpp_codegen` flag, since we should now be able to invoke clang at runtime from within fbcode to compile C++ code. That flag was preventing https://github.com/pytorch/pytorch/pull/100115 from landing, which fixed one of the last remaining models in torchbench that was failing with `torch.compile` (hf_Longformer).
Enumeration of changes:
- updated inductor to invoke `_run_build_command()` when in fbcode, which hooks into Remote Execution
- When inductor invokes g++ normally, it includes a bunch of absolute paths to things like the pytorch header paths and the input and output paths. I changed these all to relative paths when in fbcode, and copied everything we need into a temp dir that we send to Remote Execution.
- updated `triton/fb/make_build_paths.py` to let us grab paths to openmp, sleef, and ld from within the Remote Execution environment. I'm not sure if there's a better way to do this (but this way appeared to work, thanks to Bert's suggestion from https://www.internalfb.com/diff/D46482550?dst_version_fbid=231706286239076&transaction_fbid=229345569847706)
- refactored `triton/fb/build.py`: it had a function to create a triton build command and run it all in one go; I separated out the bit that takes an arbitrary command (our clang command) and runs it with RE
- a few tweaks to the include paths that inductor uses: it adds those two extra paths (sleef and openmp), and it no longer manually includes the `-ltorch`, `-lc10`, `-ltorch_python`, `-ltorch_cpu` libs - the linker was complaining that it couldn't find those libs, and not including those flags ends up working
- I added a few more missing headers. Maybe with D46527002 this won't be necessary?
- I had a basic manual test in `scripts/hirsheybar/tmp2.py`. We probably want to try running an actual job in MAST to make sure this works.
Test Plan: `scripts/hirsheybar/pt2/tmp2.py` has a basic test, but I'm also planning on testing by kicking off a MAST job with cmf_10x (thanks to a bunch of help from Bert)
Reviewed By: bertmaher
Differential Revision: D46364355
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104351
Approved by: https://github.com/bertmaher
Fixes #102752
These 3 fallback kernels appear in GoogleFnet because they take complex arguments - i.e., usually they aren't fallback kernels. To support this model, we added support for these 3 ops.
Details:
1. Add these 3 ops to the allowlist. I assume that we eventually want to support all fallback kernels, but for now we just add these 3 ops to the allowlist.
2. Support complex64 in cpp codegen
3. Support List[] arguments and ScalarType arguments in cpp codegen
4. Allow alias_info in schema arguments. In the original PR supporting fallback kernels for cpp wrapper, ops with schemas with non-null alias_info for any of the arguments were disallowed; but I don't think there's any reason we need to disallow these in cpp wrapper code.
Caveats:
* This has not added support for complex32 or complex128
* It only works with static shapes, not dynamic shapes. It seems like the dynamic shapes issue is unrelated to cpp wrapper, since it fails in the test_torchinductor_dynamic_shapes.py test. I checked these `test_fft_.*` tests, which I added in this PR, and verified that they were broken with dynamic shapes before any of the code changes from this PR.
**Test**:
```
benchmarks/dynamo/huggingface.py --inductor --amp --accuracy --inference --device cuda --cpp-wrapper --only GoogleFnet
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103183
Approved by: https://github.com/desertfire, https://github.com/jgong5, https://github.com/chunyuan-w
Fix https://github.com/pytorch/pytorch/issues/100830.
For an inplace node, a `copy_` is generated, and the `copy_` is `realized` as a `scheduler buffer` since it is a mutation. This `scheduler buffer` is a memory copy, but after fusing with the previous buffer, it is no longer a memory-copy-only buffer.
This PR solves the issue by removing `load_bf16_as_fp32` and `store_bf16_from_fp32`. Instead, enable fp32/bf16 vec conversion in `to_dtype`. Then we always store bf16.
```python
import torch
import torch.nn as nn
torch.manual_seed(420)
from torch._inductor import config

x = torch.randn(1, 18, dtype=torch.bfloat16)

class ExampleModel(nn.Module):
    def __init__(self):
        super(ExampleModel, self).__init__()
        self.relu = nn.ReLU(inplace=True)  # nn.ReLU(inplace=False)

    def forward(self, input1):
        out = self.relu(input1)
        # input1.copy_(out)
        return out

func = ExampleModel()

with torch.no_grad():
    func.train(False)
    res1 = func(x)  # without jit
    print(res1)
    jit_func = torch.compile(func)
    res2 = jit_func(x)
    print(res2)
```
Generated code without this PR (the `tmp3` store is wrong: `tmp3` is `float` while `out_ptr1` is `bf16`):
```
auto tmp0 = load_bf16_as_float(out_ptr1 + static_cast<long>(i0));
auto tmp1 = (tmp0);
auto tmp2 = at::vec::clamp_min(tmp1, decltype(tmp1)(0));
auto tmp3 = (tmp2);
store_float_as_bf16(out_ptr0 + static_cast<long>(i0), tmp3);
tmp3.store(out_ptr1 + static_cast<long>(i0), 16);
```
Generated code with this PR:
```
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(out_ptr1 + static_cast<long>(i0), 16);
auto tmp1 = cvt_bf16_to_fp32(tmp0);
auto tmp2 = at::vec::clamp_min(tmp1, decltype(tmp1)(0));
auto tmp3 = cvt_fp32_to_bf16(tmp2);
tmp3.store(out_ptr0 + static_cast<long>(i0), 16);
tmp3.store(out_ptr1 + static_cast<long>(i0), 16);
```
This PR also fixes the data type propagation for `masked_subblock`.
Previously, the masked_subblock's dtype was propagated from its inputs, which is wrong.
```
opcode name target args kwargs
----------- --------- --------- -------------------------- --------
call_module masked_subblock1 masked_subblock1 (and__2, -inf)
```
Now we propagate it from the subblock with the same name:
```
# graph for body.subblocks['masked_subblock1']
opcode name target args kwargs
----------- --------- --------- -------------------------- --------
placeholder ops ops () {}
call_module get_index get_index ('index2',) {}
call_method load load (ops, 'arg0_1', get_index) {}
call_method to_dtype to_dtype (ops, load, torch.float32) {}
output output output (to_dtype,) {}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101042
Approved by: https://github.com/jgong5, https://github.com/jansel
Currently, reduction bodies are duplicated in several different places.
This PR reduces the duplication by extracting the `combine_fn` definition used in
`_unroll_reduction_fn` and reusing it in the triton codegen. For cpp
this also makes better use of `reduction_combine{,_vec}` by using them
to generate the `omp declare reduction` line and the `vec_reduce_all`
call.
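A hedged sketch of the shared definition's shape (the real `combine_fn` in this PR may have a different signature):
```python
def combine_fn(reduction_type, a, b):
    # Single definition of how two partial accumulator values merge,
    # reused by the unrolled fallback and by each backend's codegen.
    if reduction_type == "sum":
        return a + b
    if reduction_type == "prod":
        return a * b
    if reduction_type == "max":
        return max(a, b)
    if reduction_type == "min":
        return min(a, b)
    raise NotImplementedError(reduction_type)
```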
For triton, the only change is that the combine step gets spread
over two lines, e.g. instead of:
```python
_tmp1 = tl.where(rmask & xmask, triton_helpers.maximum(_tmp1, tmp0), _tmp1)
```
we get
```python
tmp2 = triton_helpers.maximum(_tmp1, tmp0)
_tmp1 = tl.where(rmask & xmask, tmp2, _tmp1)
```
For cpp, the only change is that inplace reduction operations are now written as
an out-of-place operation and an assignment, e.g. instead of
```cpp
omp_out += omp_in
```
we generate
```cpp
omp_out = omp_out + omp_in
```
This is a purely cosmetic change.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99661
Approved by: https://github.com/lezcano, https://github.com/ngimel