As the title says, this PR enables vectorization for the case where the index_expr depends on the vectorized itervar. There are two cases here:
1. The vectorized itervar has a constant stride in the index_expr. In this case we vectorize the index_expr with `Vectorized<int32_t>::arange`.
2. Otherwise, we load the index_expr vector non-contiguously with a loop.
Below is the generated code for the first case, from the test `test_concat_inner_vec`. Here the index_expr is `x1`, which depends on the vectorized itervar `x1` with a constant stride of 1, so we vectorize it with arange. We use `all_zero` as a short-cut on masks to avoid unnecessarily executing nested masked regions that would be invalid.
Before:
```c++
#pragma omp for collapse(2)
for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(1L))
{
    for(long x1=static_cast<long>(0L); x1<static_cast<long>(155L); x1+=static_cast<long>(1L))
    {
        auto tmp0 = c10::convert<long>(x1);
        auto tmp1 = static_cast<long>(0);
        auto tmp2 = tmp0 >= tmp1;
        auto tmp3 = static_cast<long>(35);
        auto tmp4 = tmp0 < tmp3;
        auto tmp5 = [&]
        {
            auto tmp6 = in_ptr0[static_cast<long>(x1 + (35L*x0))];
            return tmp6;
        }
        ;
        auto tmp7 = tmp4 ? tmp5() : static_cast<decltype(tmp5())>(0.0);
        auto tmp8 = tmp0 >= tmp3;
        auto tmp9 = static_cast<long>(155);
        auto tmp10 = tmp0 < tmp9;
        auto tmp11 = [&]
        {
            auto tmp12 = in_ptr1[static_cast<long>((-35L) + x1 + (120L*x0))];
            return tmp12;
        }
        ;
        ...
```
After:
```c++
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(1L))
{
    for(long x1=static_cast<long>(0L); x1<static_cast<long>(144L); x1+=static_cast<long>(16L))
    {
        auto tmp0 = c10::convert<int>(x1);
        auto tmp1 = at::vec::Vectorized<int32_t>::arange(tmp0, 1);
        auto tmp2 = static_cast<int>(0);
        auto tmp3 = at::vec::Vectorized<int>(tmp2);
        auto tmp4 = to_float_mask(tmp1 >= tmp3);
        auto tmp5 = static_cast<int>(35);
        auto tmp6 = at::vec::Vectorized<int>(tmp5);
        auto tmp7 = to_float_mask(tmp1 < tmp6);
        auto tmp8 = [&]
        {
            auto tmp9 = masked_load(in_ptr0 + static_cast<long>(x1 + (35L*x0)), to_float_mask(tmp7));
            return tmp9;
        }
        ;
        auto tmp10 =
        [&]
        {
            if (all_zero(to_float_mask(tmp7)))
            {
                return at::vec::Vectorized<float>(static_cast<float>(0.0));
            }
            else
            {
                return decltype(tmp8())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp8(), to_float_mask(tmp7));
            }
        }
        ()
        ;
        ...
```
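For reference, here is a minimal sketch of what an `all_zero`-style helper could look like (illustrative only, not the exact helper shipped with the cpp prefix header; the real one may use ISA-specific mask extraction):
```c++
#include <ATen/cpu/vec/vec.h>

// Illustrative sketch: returns true when every lane of the float mask is zero,
// which lets the generated code skip a masked region entirely instead of
// computing it and blending the result.
inline bool all_zero_sketch(const at::vec::Vectorized<float>& mask) {
  constexpr int N = at::vec::Vectorized<float>::size();
  __at_align__ float buf[N];
  mask.store(buf);
  for (int i = 0; i < N; ++i) {
    if (buf[i] != 0.0f) {
      return false;
    }
  }
  return true;
}
```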
Below is the generated code for the second case, from the test `test_expr_vec_non_contiguous`. Here, the index_expr is `31L + (63L*(c10::div_floor_integer(x1, 32L))) + (c10::div_floor_integer(x2, 32L))`, which depends on the vectorized itervar `x2` and doesn't have a constant stride, so we load the index_expr vector with a loop. (In fact, this can be further optimized since the index_expr is invariant across the data points in the range [x2, x2+16), so it can be regarded as a scalar; this will be optimized in a follow-up PR.) The code uses `vector_lane_mask_check` to implement the masked version of the non-contiguous load.
Before:
```c++
#pragma omp for collapse(2)
for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
{
    for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L))
    {
        {
            float tmp_acc0 = -std::numeric_limits<float>::infinity();
            for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(1L))
            {
                auto tmp0 = c10::convert<long>(31L + (63L*(c10::div_floor_integer(x1, 32L))) + (c10::div_floor_integer(x2, 32L)));
                auto tmp1 = static_cast<long>(2048);
                auto tmp2 = tmp0 < tmp1;
                auto tmp3 = [&]
                {
                    auto tmp4 = in_ptr0[static_cast<long>(31L + (63L*(c10::div_floor_integer(x1, 32L))) + (2048L*(static_cast<long>(x1) % static_cast<long>(32L))) + (65536L*x0) + (c10::div_floor_integer(x2, 32L)))];
                    return tmp4;
                }
                ;
                auto tmp5 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
                tmp_acc0 = max_propagate_nan(tmp_acc0, tmp5);
            }
            out_ptr0[static_cast<long>(x1 + (1024L*x0))] = tmp_acc0;
        }
    }
}
```
After:
```c++
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
{
    for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(16L))
    {
        {
            #pragma omp declare reduction(max:at::vec::Vectorized<float>:omp_out = at::vec::maximum(omp_out, omp_in)) initializer(omp_priv={at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity())})
            float tmp_acc0 = -std::numeric_limits<float>::infinity();
            at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity());
            for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(1L))
            {
                auto tmp0 =
                [&]
                {
                    __at_align__ std::array<int, 16> tmpbuf;
                    #pragma GCC unroll 16
                    for (long x1_inner = 0; x1_inner < 16; x1_inner++)
                    {
                        tmpbuf[x1_inner] = static_cast<long>(31L + (63L*(c10::div_floor_integer((x1 + x1_inner), 32L))) + (c10::div_floor_integer(x2, 32L)));
                    }
                    return at::vec::Vectorized<int>::loadu(tmpbuf.data());
                }
                ()
                ;
                auto tmp1 = static_cast<int>(2048);
                auto tmp2 = at::vec::Vectorized<int>(tmp1);
                auto tmp3 = to_float_mask(tmp0 < tmp2);
                auto tmp4 = [&]
                {
                    auto tmp5 =
                    [&]
                    {
                        __at_align__ std::array<float, 16> tmpbuf;
                        #pragma GCC unroll 16
                        for (long x1_inner = 0; x1_inner < 16; x1_inner++)
                        {
                            if (vector_lane_mask_check(tmp3, x1_inner))
                            {
                                tmpbuf[x1_inner] = in_ptr0[static_cast<long>(31L + (63L*(c10::div_floor_integer((x1 + x1_inner), 32L))) + (2048L*(static_cast<long>((x1 + x1_inner)) % static_cast<long>(32L))) + (65536L*x0) + (c10::div_floor_integer(x2, 32L)))];
                            }
                        }
                        return at::vec::Vectorized<float>::loadu(tmpbuf.data());
                    }
                    ()
                    ;
                    return tmp5;
                }
                ;
                auto tmp6 =
                [&]
                {
                    if (all_zero(to_float_mask(tmp3)))
                    {
                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                    }
                    else
                    {
                        return decltype(tmp4())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp4(), to_float_mask(tmp3));
                    }
                }
                ()
                ;
                tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp6);
            }
            tmp_acc0_vec.store(out_ptr0 + static_cast<long>(x1 + (1024L*x0)));
        }
    }
}
}
```
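As a rough illustration of the lane check used above, a simplified `vector_lane_mask_check`-style helper might look like the following (a sketch only, assuming a float mask; the actual helper in the cpp prefix header may differ):
```c++
#include <ATen/cpu/vec/vec.h>

// Illustrative sketch: report whether lane `i` of a float mask is set (non-zero),
// so a non-contiguous masked load can skip gathering elements for inactive lanes.
inline bool vector_lane_mask_check_sketch(const at::vec::Vectorized<float>& mask, int i) {
  constexpr int N = at::vec::Vectorized<float>::size();
  __at_align__ float buf[N];
  mask.store(buf);
  return buf[i] != 0.0f;
}
```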
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114545
Approved by: https://github.com/lezcano
Fixes #114310 and supersedes #114748.
There are two reasons why we have quite a few special cases for `round`:
1. `round` is actually two ops. With `ndigits=None` (default), `round` always returns an integer. When `ndigits` is an integer, the returned type is a float.
2. Although `round` takes two arguments, it is a unary function with a parameter rather than a binary one.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115259
Approved by: https://github.com/peterbell10, https://github.com/lezcano
This PR fixes #104791.
bitcast requires the source and target to have the same bitwidth.
Because the input tensor's dtype could be promoted, e.g. from float16 to
float, we have to cast the tensor to its original source dtype before
invoking bitcast in such cases. After that, we also need to convert
the bit-casted tensor back to float to make sure we keep using higher
precision values for the rest of the computation.
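As a concrete illustration of the dtype round-trip described above (a hedged sketch, not the generated kernel code; the element types are only an example):
```c++
#include <cstring>
#include <c10/util/Half.h>
#include <c10/util/BFloat16.h>

// Example: a float16 input was promoted to float for computation. To bitcast it
// (e.g. to bfloat16), first cast back to the original 16-bit source dtype so the
// bitwidths match, then convert the result back to float so the rest of the
// computation keeps using the higher-precision values.
inline float bitcast_promoted_value(float promoted) {
  c10::Half src = static_cast<c10::Half>(promoted);  // back to the source dtype
  c10::BFloat16 dst;
  static_assert(sizeof(src) == sizeof(dst), "bitcast needs equal bitwidth");
  std::memcpy(&dst, &src, sizeof(dst));              // the actual bitcast
  return static_cast<float>(dst);                    // promote again downstream
}
```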
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115619
Approved by: https://github.com/jansel, https://github.com/eellison
Fixes #97361
When a fused kernel has more than 1024 parameters, ctypes throws an error.
Limiting the number of args is a mechanism to protect stack memory: C++ passes args via the stack, and stack memory has a size limitation.
Code change:
1. The cpp backend checks the fused nodes' number of args; if it reaches the limit, the backend sets its flush status to ready.
2. The scheduler checks the `ready_to_flush` API and helps the backend flush codegen.
3. Add a `ready_to_flush` API to `BaseScheduling`; the Triton backend returns False since it does not support this yet.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113131
Approved by: https://github.com/jgong5, https://github.com/mlazos
For embedding lookups, there is indirect indexing with indices that are invariant to the vectorized itervar. To vectorize this, we need to keep the related indexing variables as scalars and allow vectorization when the related index_exprs are invariant to the vectorized itervar.
This PR adds the support by lazily broadcasting scalar values (index_expr and constant) to vectors, so that `CppVecKernel` only generates vector operations when any of the inputs are vectors; otherwise, scalar ops are generated. The CSE variable in the cpp backend is now represented with `CppCSEVariable`, which keeps track of the itervars relevant to the variable and has a flag marking whether it is a scalar or a vector. `CppVecOverrides` is improved to propagate these states as the ops are executed.
For the added UT `test_embedding_vec`, the generated code before this PR is:
```c++
extern "C" void kernel(const long* in_ptr0,
const float* in_ptr1,
const float* in_ptr2,
float* out_ptr0)
{
#pragma omp parallel num_threads(64)
{
{
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(128L); x0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x1=static_cast<long>(0L); x1<static_cast<long>(128L); x1+=static_cast<long>(1L))
{
auto tmp0 = in_ptr0[static_cast<long>(x0)];
auto tmp5 = in_ptr2[static_cast<long>(x1 + (128L*x0))];
auto tmp1 = decltype(tmp0)(tmp0 + 64);
auto tmp2 = tmp0 < 0;
auto tmp3 = tmp2 ? tmp1 : tmp0;
TORCH_CHECK((0 <= tmp3) & (tmp3 < 64L), "index out of bounds: 0 <= tmp3 < 64L")
auto tmp4 = in_ptr1[static_cast<long>(x1 + (128L*tmp3))];
auto tmp6 = decltype(tmp4)(tmp4 + tmp5);
out_ptr0[static_cast<long>(x1 + (128L*x0))] = tmp6;
}
}
}
}
}
```
After this PR, we have:
```c++
extern "C" void kernel(const long* in_ptr0,
const float* in_ptr1,
const float* in_ptr2,
float* out_ptr0)
{
#pragma omp parallel num_threads(64)
{
{
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(128L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(128L); x1+=static_cast<long>(16L))
{
auto tmp0 = in_ptr0[static_cast<long>(x0)];
auto tmp5 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<long>(x1 + (128L*x0)));
auto tmp1 = decltype(tmp0)(tmp0 + 64);
auto tmp2 = tmp0 < 0;
auto tmp3 = tmp2 ? tmp1 : tmp0;
TORCH_CHECK((0 <= tmp3) & (tmp3 < 64L), "index out of bounds: 0 <= tmp3 < 64L")
auto tmp4 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(x1 + (128L*tmp3)));
auto tmp6 = tmp4 + tmp5;
tmp6.store(out_ptr0 + static_cast<long>(x1 + (128L*x0)));
}
}
}
}
}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114062
Approved by: https://github.com/jansel
ghstack dependencies: #113950
Applies PLW0108, which removes useless lambda calls in Python. The rule is in preview, so it is not ready to be enabled by default just yet. These are the autofixes from the rule.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113602
Approved by: https://github.com/albanD
Previously, floor_div operations were defined in
ATen/native/BinaryOps.h. Since this header was not included under
ABI-compat mode, trying to use those indexing operations would result in
compilation errors.
Technically, it is safe to use aten::native::floor_div_* functions in
ABI-compat mode as they are header-only; we could simply include
BinaryOps.h. However, there are other declarations in BinaryOps.h that
are not binary-compatible, so this is not ideal. Thus, I have moved those
functions into a separate file, and put them under c10/util, since they
don't really have tensor-specific logic.
c10 functions are not all header-only, so this still isn't ideal, but
this still seems like an improvement. Moreover, cpp_prefix.h -- used
when compiling cpp kernels -- already includes c10 header files, so
ABI-compatibility already depends on maintaining some c10 functions as
header-only.
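For context, the moved helpers are small header-only functions; a hedged sketch of what a floor-division helper of this kind looks like (illustrative, not the verbatim c10 implementation):
```c++
#include <cstdint>

// Integer division in C++ truncates toward zero; adjust the quotient when the
// operands have different signs and there is a remainder, so the result rounds
// toward negative infinity (Python-style floor division).
inline int64_t div_floor_integer_sketch(int64_t a, int64_t b) {
  int64_t q = a / b;
  int64_t r = a % b;
  return (r != 0 && ((r < 0) != (b < 0))) ? q - 1 : q;
}
```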
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113276
Approved by: https://github.com/chenyang78, https://github.com/desertfire
This was originally ipiszy's PR: https://github.com/pytorch/pytorch/pull/112358
It turns out that we need to add support for optional types in order to
support fp8 gemm (i.e. scaled_mm). Since our ABI-stable C interface
can't support optional<> directly, I am passing in optional types via
pointer instead.
`AtenTensorHandle`s are already pointers, so nothing needs to change
there. Only value types need to change.
We decided on this approach instead of adding an extra `bool` param to
the callee because this simplifies things. Having the same number of
arguments regardless of whether we are emitting Python / C++ /
ABI-compatible C++ makes codegen easier.
There are a number of existing ABI-compatible functions that have
optional-typed value parameters. Previously, they just assumed they
would never be passed a `nullopt` / `None` at runtime. Changing them to
use pointer types now would break ABI stability, so I have created an
exclude list for those functions.
Finally, I think the current implementation is kind of messy, and only
works for FallbackKernels, even though technically ExternKernels could
also have the same issue. It also doesn't support optional types nested
in lists. I've left FIXME comments for both issues.
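To make the calling convention concrete, here is a hedged sketch (the function name and parameters are hypothetical, not the actual shim API): an optional value-typed argument becomes a pointer parameter, with `nullptr` standing in for `nullopt`/`None`.
```c++
#include <cstdint>

// Hypothetical ABI-stable C entry point with two optional value-typed params.
extern "C" void hypothetical_aoti_gemm(
    void* self_handle,            // AtenTensorHandle: already a pointer, unchanged
    void* mat2_handle,            // AtenTensorHandle: already a pointer, unchanged
    const double* scale,          // optional<double>: nullptr means None
    const int32_t* out_dtype) {   // optional<ScalarType>: nullptr means None
  // A real shim would unpack the handles and forward to the kernel; here we
  // only show how nullptr encodes a missing optional.
  double s = scale ? *scale : 1.0;  // fall back to a default when None
  (void)self_handle; (void)mat2_handle; (void)s; (void)out_dtype;
}

void example_calls(void* a, void* b) {
  double scale = 1.0;
  int32_t dtype = 6;  // some ScalarType enum value, for illustration
  hypothetical_aoti_gemm(a, b, &scale, &dtype);    // both optionals provided
  hypothetical_aoti_gemm(a, b, nullptr, nullptr);  // both optionals are None
}
```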
Differential Revision: [D51084289](https://our.internmc.facebook.com/intern/diff/D51084289)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112527
Approved by: https://github.com/chenyang78, https://github.com/desertfire
`install_config_module` makes a regular module into a ConfigModule with
extra methods defined on it. mypy thinks those extra methods (or module
functions) are undefined since it cannot analyze something so
dynamic. As a workaround, I've created a fake module that defines these
extra functions, which I import into the config modules during type
checking.
As part of this change, I've also added more types to config_utils.py
and enabled typechecking for torch/_dynamo/config.py.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112130
Approved by: https://github.com/jansel
We spend somewhere on the order of 1% of the time in `sympy.Expr.free_symbols`, as it is called millions of times.
Most of the time we actually just want to know "is this a constant?"; however, `e.is_constant()` is
horribly slow. It turns out, though, that there is another property, `is_number`, that does what we want.
> property is_number:
>
> Returns True if self has no free symbols and no undefined functions (AppliedUndef, to be precise). It will be faster
> than if not self.free_symbols, however, since is_number will fail as soon as it hits a free symbol or undefined
> function.
Even further, we also avoid the overhead of building the unnecessary set object.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112688
Approved by: https://github.com/lezcano
**Summary**
Follow-up to https://github.com/pytorch/pytorch/pull/109893, which has an issue with CPU support as reported in https://github.com/pytorch/pytorch/issues/109897. This fix mainly includes 2 changes:
- The current implementation of `rename_indexing`
  (10c646295d/torch/_inductor/codegen/common.py (L1023)) only adds symbol names starting with `s` or `ps` into `kernel.args.sizevars`. However, an unbacked symint will start with `i`, so we extend the implementation of `rename_indexing` to support symbols starting with `i`.
- Currently, the internal loop index name also starts with `i`. Since `i` has been used for unbacked symints, change the name to start with `x`, which aligns with Triton.
**Test Plan**
```
python -u -m pytest -s -v test_torchinductor_dynamic_shapes.py -k test_bool_mask_nobreak
python -u -m pytest -s -v test_torchinductor_dynamic_shapes.py -k test_nonzero_size_factory_nobreak
python -u -m pytest -s -v test_torchinductor_dynamic_shapes.py -k test_item_zeros_nobreak
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110262
Approved by: https://github.com/ezyang, https://github.com/jgong5
Summary:
This is a prototype for running extern fallback kernels with a host side proxy executor.
Sample of generated cpp wrapper call:
```
at::Tensor buf0; // output buffer
void* tensor_args_var_0[] = {&arg0_1, &arg0_1, &arg1_1, &arg0_1, &arg1_1, &buf0};
int64_t int_args_var_1[] = {81, 81, 7, 7, 7, 81};
proxy_executor->call_function("buf0", int_args_var_1, tensor_args_var_0);
```
- In my current implementation, the proxy executor interprets the raw pointers according to the op's schema.
This assumes that the custom op MUST have a valid schema registered with the Dispatcher. (I would like to validate this assumption.)
- I am using the `callBoxed()` API of the custom kernels; see the sketch after this list. This is inevitable, as we wish to have a single call_function API for all possible custom kernels.
- These are all the input argument types I have supported so far:
```
union Argument {
  # Bool value does not matter
  1: bool asNone;
  2: TensorArgument asTensor;
  3: list<TensorArgument> asTensors;
  5: i64 asInt;
  7: list<i64> asInts;
  8: double asFloat;
  9: list<double> asFloats;
  10: string asString;
  10.5: list<string> asStrings;
  11: SymIntArgument asSymInt;
  12: list<SymIntArgument> asSymInts;
  13: ScalarType asScalarType;
  14: MemoryFormat asMemoryFormat;
  15: Layout asLayout;
  16: Device asDevice;
  17: bool asBool;
  18: list<bool> asBools;
}
```
- We need a policy for handling unpopulated arguments with default values. Here are the options, and the choice has BC implications:
  1. Require the exported fx graph to explicitly populate default values if the user doesn't specify them.
  2. Require the cpp wrapper to explicitly populate default values if the fx graph doesn't specify them.
  3. Have the proxy executor look up default values from the op schema.
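Below is a minimal sketch of the callee side described in the list above, assuming the executor already knows the op name from the serialized node and receives the flat int/tensor argument arrays from the wrapper (the real proxy executor interprets the arrays strictly according to the schema; this helper is illustrative only):
```c++
#include <cstdint>
#include <ATen/core/Tensor.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/stack.h>
#include <ATen/core/dispatch/Dispatcher.h>

// Illustrative sketch: build a boxed stack from the flat argument arrays and
// dispatch through callBoxed(), which works uniformly for any custom kernel
// with a registered schema.
void proxy_call_sketch(
    const char* op_name,                         // e.g. "my_ns::my_custom_op"
    at::Tensor* const* tensor_args, size_t n_tensors,
    const int64_t* int_args, size_t n_ints) {
  auto op = c10::Dispatcher::singleton().findSchemaOrThrow(op_name, "");
  torch::jit::Stack stack;
  // The real executor walks the schema and decides, per formal argument,
  // whether to pull the next value from the tensor array or the int array;
  // for illustration we simply push all tensors followed by all ints.
  for (size_t i = 0; i < n_tensors; ++i) {
    stack.emplace_back(*tensor_args[i]);
  }
  for (size_t i = 0; i < n_ints; ++i) {
    stack.emplace_back(int_args[i]);
  }
  op.callBoxed(stack);  // boxed call; outputs are left on the stack
}
```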
For fixing T162112344
Test Plan:
frontend:
buck2 run mode/dev-sand mode/inplace -c fbcode.enable_gpu_sections=True sigmoid/frontend:export_main
test:
buck2 run mode/dev-sand //deeplearning/aot_inductor/test:test_custom_ops
backend:
buck2 run mode/dev-nosan //deeplearning/aot_inductor/fb:main
buck2 test 'fbcode//mode/opt' fbcode//caffe2/torch/fb/model_transform/experimental/benchmark/test:test_aot_inductor_benchmark -- --exact 'caffe2/torch/fb/model_transform/experimental/benchmark/test:test_aot_inductor_benchmark - test_aot_inductor_benchmark_cmf30x (caffe2.torch.fb.model_transform.experimental.benchmark.test.test_aot_inductor_benchmark.AOTInductorBenchmark)'
Reviewed By: suo
Differential Revision: D48747417
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108350
Approved by: https://github.com/izaitsevfb
Inductor kernel codegen previously had the following side effects:
- in `Kernel.__exit__`, we add locally used buffers to graph.removed_buffers
- during codegen, we do memory allocation/free.
These make doing multiple versions of codegen for the same kernel hard. This PR refactors the code so that kernel codegen does not change graph-level state. After codegenning a kernel, the graph-level state is unchanged, so we can go on to codegen another version of the kernel if we want.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107617
Approved by: https://github.com/jansel
For max_pooling code:
```c++
#pragma GCC ivdep
for(long i2=static_cast<long>(0L); i2<static_cast<long>(56L); i2+=static_cast<long>(1L))
{
    for(long i3=static_cast<long>(0L); i3<static_cast<long>(64L); i3+=static_cast<long>(16L))
    {
        auto tmp0 = at::vec::Vectorized<int>(static_cast<int>((-1L) + (2L*i1)));
        auto tmp1 = at::vec::Vectorized<int>(static_cast<int>(0));
        auto tmp2 = to_float_mask(tmp0 >= tmp1);
        auto tmp3 = at::vec::Vectorized<int>(static_cast<int>(112));
        auto tmp4 = to_float_mask(tmp0 < tmp3);
        auto tmp5 = tmp2 & tmp4;
        auto tmp6 = at::vec::Vectorized<int>(static_cast<int>((-1L) + (2L*i2)));
        auto tmp7 = to_float_mask(tmp6 >= tmp1);
        auto tmp8 = to_float_mask(tmp6 < tmp3);
        auto tmp9 = tmp7 & tmp8;
        auto tmp10 = tmp5 & tmp9;
        auto tmp11 = [&]
        {
            auto tmp12 = at::vec::Vectorized<bfloat16>::loadu(in_ptr0 + static_cast<long>((-7232L) + i3 + (128L*i2) + (14336L*i1) + (802816L*i0)), 16);
            auto tmp13 = cvt_lowp_fp_to_fp32<bfloat16>(tmp12);
            return tmp13;
        }
        ;
        auto tmp14 = decltype(tmp11())::blendv(at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity()), tmp11(), to_float_mask(tmp10));
```
the index of `tmp12` may not be valid. For example, with `i1=0, i2=0, i3=0` the index is `-7232L`, which is out of bounds, so we may hit a segmentation fault when we call `tmp11()`. The original scalar behavior only reads the value when `tmp10` (the index check variable) is true, so the access is safe there. This PR adds masked_load support to fix this issue.
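A hedged sketch of what a masked load along these lines could look like for float data (illustrative only; the real helper handles multiple dtypes, including bfloat16 as above, and may use hardware masked-load instructions):
```c++
#include <ATen/cpu/vec/vec.h>

// Illustrative sketch: gather only the lanes whose mask is set, leaving the
// other lanes at zero, so no out-of-bounds address is ever dereferenced for an
// inactive lane.
inline at::vec::Vectorized<float> masked_load_sketch(
    const float* ptr, const at::vec::Vectorized<float>& mask) {
  constexpr int N = at::vec::Vectorized<float>::size();
  __at_align__ float mask_buf[N];
  __at_align__ float out_buf[N] = {0.0f};
  mask.store(mask_buf);
  for (int i = 0; i < N; ++i) {
    if (mask_buf[i] != 0.0f) {
      out_buf[i] = ptr[i];
    }
  }
  return at::vec::Vectorized<float>::loadu(out_buf);
}
```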
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107670
Approved by: https://github.com/jgong5, https://github.com/jansel