Commit Graph

433 Commits

Author SHA1 Message Date
xinan.lin
af0e159740 [Inductor XPU] Add XPU check for is_big_gpu(). (#143491)
Fix #143472

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143491
Approved by: https://github.com/desertfire, https://github.com/jansel, https://github.com/EikanWang
2024-12-21 02:27:04 +00:00
Michael Lazos
8960cb5809 Add support for bfloat16 atomic adds in fbcode (#143629)
Reland https://github.com/pytorch/pytorch/pull/141857 and fallback on A100 which doesn't have bfloat16 atomic add instrs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143629
Approved by: https://github.com/eellison
2024-12-20 23:05:13 +00:00
Michael Lazos
b4e0e3bfa3 Backout D66648013 (#143433)
Summary:
backing out https://www.internalfb.com/diff/D66648013 (see comments there for justification)

I will reland and disallow the bfloat16 atomics behavior on A100 because it causes a pretty significant performance regression.

Test Plan: This is a revert

Differential Revision: D67357485

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143433
Approved by: https://github.com/davidberard98
2024-12-19 00:53:49 +00:00
Rachel Guo
9275091d6e [provenance_tracking] Dump inductor_triton_kernel_to_post_grad_nodes.json info in debug_trace (#143055)
Summary:
This diff mainly adds code changes to dump `inductor_triton_kernel_to_post_grad_nodes.json` artifact which contains mapping info from post_grad -> inductor kernel code:
`{"inductor_triton_kernel_name": [post_grad_node_0, post_grad_node_1, ..., ], "..."}.`

Example paste: P1695235000 verified on the test model.  See "Test Plan":

We use this artifact to demonstrate provenance tracking in the frontend 3-tab highlighter tool:
https://github.com/YUNQIUGUO/compiler_explorer (copy/pasted the input files for demo purpose for now and will integrate with Shangdi's tool to 4-tab)

https://pxl.cl/66BzK

Note: Currently only supports mapping for inductor's`TritonKernel` type. TODO for enhancing more support for `ExternKernel` and other inductor generated kernel type, etc.

Test Plan:
test_model_coverage.sh:
```
#!/bin/sh
MODEL_ENTITY_ID=644688112
SNAPSHOT_ID=32
MODULE=merge

# buck2 build --show-output mode/opt -c=python.package_style=inplace -c fbcode.enable_gpu_sections=true -c fbcode.platform=platform010 -c fbcode.split-dwarf=true -c fbcode.nvcc_arch=a100,h100 caffe2/torch/fb/model_transform/experimental/benchmark:mts_gpu_benchmark

TORCH_COMPILE_DEBUG=1 CUDA_VISIBLE_DEVICES=0 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCH_LOGS="+inductor, schedule, fusion, output_code" TORCH_TRACE="tmp/guorachel_tt" TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 ../buck-out/v2/gen/fbcode/d29ee94b913014f1/caffe2/torch/fb/model_transform/experimental/benchmark/__mts_gpu_benchmark__/mts_gpu_benchmark.par --model-path manifold://ads_storage_fblearner/tree/user/facebook/fblearner/predictor/${MODEL_ENTITY_ID}/${SNAPSHOT_ID}/gpu_lowering/input.predictor.disagg.gpu.merge --lower-backend AOT_INDUCTOR_EP --gpu-trace --aot-inductor-config="{'max_autotune': True}" 2>&1 | tee output.txt
```
 {F1973765026}

```
buck2 test 'fbcode//mode/opt' fbcode//caffe2/test/inductor:provenance_tracing -- --exact 'caffe2/test/inductor:provenance_tracing - test_triton_kernel_post_grad_mapping_aot_inductor (caffe2.test.inductor.test_provenance_tracing.TestProvenanceTracingArtifact)'
```

```
TORCH_LOGS="+inductor, output_code" buck2 run -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=h100 @//mode/opt fbcode//caffe2/test/inductor:provenance_tracing -- -r test_triton_kernel_post_grad_mapping_aot_inductor
```

Differential Revision: D66967510

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143055
Approved by: https://github.com/chenyang78
2024-12-18 06:51:50 +00:00
Adnan Akhundov
e885225eda Add persistent+TMA version of Triton mm and addmm (#142101)
This PR adds persistent+TMA versions (Triton template + the corresponding infra) for the `tuned_mm` and `tuned_addmm` lowerings. The persistent+TMA choices are added to the GEMM autotuning if (checked by the `use_triton_tma_template` helper):

1. The min. hardware and Triton version requirements are met for the TMA support.

2. The GEMM inputs are compatible with the Triton TMA API (i.e., 16-byte aligned and contiguous).

3. The `config.triton.enable_persistent_tma_matmul` is set to `True`.

Additional notes:

1. As added in this PR, the TMA uses are not compatible with prolog / epilogue fusion. To this end, in the new Triton template we currently support: TMA-based loads of A/B, but no prologue fusion; epilogue fusion, but no TMA-based stores of C. TMA + fusion compatibility can be added as a follow-up.

2. The current Triton TMA API (`experimental_device_tensormap_create2d`) does not support strides. Due to this, we limit the applicability of the new Triton template to the cases where the inputs are contiguous.

3. The transposed layouts of A and / or B are supported by passing the constexpr flags to the kernel and adjusting the ordering of the block sizes accordingly in the kernel code (this should have no effect on the kernel perf, as decided at the Triton compilation time).

4. After the next Triton pin update, we can switch to the tensor descriptor API (landed recently in https://github.com/triton-lang/triton/pull/5290) in the new Triton template, which should allow lifting 2 and 3 above.

5. The configs for the new Triton template in `persistent_mm_kernel_configs` are preliminary. We should do more perf exploration and possibly augment the config in a follow-up.

6. This PR is rebased onto and unifies with two related PRs landed previously: https://github.com/pytorch/pytorch/pull/142045 (some infra unification with the persistent+TMA template for _scaled_mm) and https://github.com/pytorch/pytorch/pull/134532 (add possibility to disable prolog fusion for selected choices).

7. The current Triton TMA API only supports 1D and 2D descriptors (even after https://github.com/triton-lang/triton/pull/5290, see [here](9829ce87cc/python/triton/language/core.py (L1957))). For now, this blocks adding persistent+TMA template for `torch.bmm`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142101
Approved by: https://github.com/drisspg, https://github.com/eellison
2024-12-16 19:12:12 +00:00
Tom Ritchford
da67a6a7bb [inductor] Replace set by OrderedSet (#138466)
Uses the set_linter from https://github.com/pytorch/pytorch/pull/138454
and considerable manual editing

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138466
Approved by: https://github.com/eellison
2024-12-13 16:08:45 +00:00
Tom Ritchford
dc23f1944a Remove unused Python variables in torch/[_-a]* (#133492)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133492
Approved by: https://github.com/albanD
2024-12-12 17:39:14 +00:00
Blaine Burton Rister
520ba556cd [Inductor] Refactor "r" reduction prefix to {"r0_", "r1_"}. (#142020)
Preparatory refactor for https://github.com/pytorch/pytorch/pull/137243.

# Feature

This PR changes the `RINDEX` / `"r"` symbol type to `(R0_INDEX, R1_INDEX)` and `("r0_", "r1_")`, respectively. This allows the relevant code to support 2D (often ND) reductions. Unlike the parent PR, this one does not change the tiling algorithm, so `"r1_"` is never used. However, it prepares other parts of the system to handle `"r1_"` once we start using it. This should significantly reduce the chances of hitting merge conflicts, making the parent PR much easier to land.

The only change to the generated triton code is to rename `"rindex"` -> `"r0_index"`, `"RBLOCK"` -> `"R0_BLOCK"`, etc. To maintain compatibilty with existing codegen, this also generates aliases to the old reduction variables like `rindex = r0_index`. If we generated 2D reductions (which this PR will not do), the aliases would be more complicated and would collapse 2D multi-indices to linear indices. See some example kernels in the parent PR.

These aliases can be eliminated by the Triton compiler, and should not impact the final machine code running on the GPU. See the perf testing in the parent PR which confirms the aliases do not impact perf.

# Test plan

The existing CI provides good coverage. This PR modifies the expected code in a few places, renaming reduction variables from `r.*` to `r0_.*`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142020
Approved by: https://github.com/jansel

Co-authored-by: Jason Ansel <jansel@meta.com>
2024-12-12 17:22:20 +00:00
PyTorch MergeBot
5c97ac9721 Revert "Remove unused Python variables in torch/[_-a]* (#133492)"
This reverts commit fda975a7b3.

Reverted https://github.com/pytorch/pytorch/pull/133492 on behalf of https://github.com/clee2000 due to Sorry, I need to revert this in order to revert something else.  The only thing you need to do is rebase and remerge ([comment](https://github.com/pytorch/pytorch/pull/133492#issuecomment-2536635516))
2024-12-11 17:29:12 +00:00
Xuan Zhang
ed388394d1 add torchrec collectives to enforce global ordering (#141970)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141970
Approved by: https://github.com/yf225
2024-12-11 02:45:24 +00:00
Tom Ritchford
fda975a7b3 Remove unused Python variables in torch/[_-a]* (#133492)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133492
Approved by: https://github.com/albanD
2024-12-10 21:48:44 +00:00
Michael Lazos
a3abe1a5ae Add support for bfloat16 atomic adds in fbcode (#141857)
This adds support for bfloat16 atomic add in fbcode (OSS will have to wait until those changes are upstreamed to triton)

Originally I attempted to write inline asm, but the triton API was not flexible enough to support this use case. In the long run the right answer is to implement this properly in OSS triton.

relevant issues:
* https://github.com/pytorch/pytorch/issues/137425 in fbcode only
* https://github.com/pytorch/pytorch/issues/97016

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141857
Approved by: https://github.com/eellison
2024-12-10 11:40:15 +00:00
Shangdi Yu
bcddae14ec Enhance "from_node" node meta to track source recursively (#142066)
Summary:
Change the "from_node" node meta format to be able to track the provenance of nodes recursively.

The new "from_node" format is a a list node NodeSource:

```
class NodeSource:
	self.node_name: str
	self.target: str
	self.graph_id: int
	self.pass_name: str
	self.action: str
	self.from_node: List[NoedSource]
```

This is in preparation for the inductor provenance tracking. For background, the inductor provenance tracking doc: https://docs.google.com/document/d/1dGh9myqNhywmbfP0Quzx_f04bghDFlj8cawj8MopiO8/edit?fbclid=IwZXh0bgNhZW0CMTEAAR0jUQ0Tf4ROLDED8Y_eIzrU0KVZVdRmyIQLp-avt-kGRPI_VgYVNyjH_q0_aem_HCQ_pxHDiwOkO9mQyWB2-g&tab=t.0 (internal only),

Test Plan:
```
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export -- -r test_unflatten_multiple_graphs_state
buck run mode/dev-nosan caffe2/test:fx -- -r node_source
```

Differential Revision: D66737916

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142066
Approved by: https://github.com/avikchaudhuri
2024-12-09 23:39:15 +00:00
PyTorch MergeBot
5c76a2834d Revert "add torchrec collectives to enforce global ordering (#141970)"
This reverts commit ceb94d6a7d.

Reverted https://github.com/pytorch/pytorch/pull/141970 on behalf of https://github.com/malfet due to Apologies for reverting this change, but it broke MacOS testing, but CI was broken at the time ([comment](https://github.com/pytorch/pytorch/pull/141970#issuecomment-2529367680))
2024-12-09 20:25:04 +00:00
Jason Ansel
e343f46464 [inductor] Refactor is_big_gpu (#142220)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142220
Approved by: https://github.com/yanboliang
ghstack dependencies: #142219, #142033, #142222
2024-12-08 18:51:36 +00:00
Aaron Gokaslan
2682e5e0d4 [BE]: Add TypeGuard to is_symbolic (#142304)
Improves type inference for is_symbolic. If it's True, it must be either a SymInt or Torch Tensor currently.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142304
Approved by: https://github.com/jansel
2024-12-08 02:18:17 +00:00
Jason Ansel
0367a31401 [inductor] Minor typing changes (#142219)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142219
Approved by: https://github.com/Skylion007, https://github.com/yanboliang
2024-12-07 17:48:37 +00:00
Xuan Zhang
ceb94d6a7d add torchrec collectives to enforce global ordering (#141970)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141970
Approved by: https://github.com/yf225
2024-12-06 22:38:54 +00:00
Mitchell, Frost
20f24e3fbd [inductor][cpp] Add BMM kernel template for autotuning (#129772)
This PR adds the Cpp template for BMM, for FP32, FP16, and BF16. See #125683 for more background.

1.  Adds `CppBmmTemplate` class which inherits from `CppPackedGemmTemplate`. Given a number of worker threads `num_threads` and batch size `B`, execute the Gemm kernel. For the first `B - (B % num_threads)` batch inputs, run one sub-gemm problem per thread. Then for the remaining `B % num_threads` sub-gemms, we execute each subproblem using the parallelized Gemm kernel.
To manage this code, the `GEMM_TEMPLATE` from `CppPackedGemmTemplate` is rendered two different times, one with a single thread and one which includes the parallel OMP pragma.
2. Adapts `CppPackedGemmTemplate` to allow for child class. The `GEMM_TEMPLATE` is separated into different strings to allow for rendering by the child class. Slicing/indexing are adapted to allow for 3D BMM inputs. Additional methods `get_options()` and `_get_params_for_choices()` are added to reduce code duplication.

BMM within `dlrm` benchmark has a single input buffer which is used for but X and W inputs. This is currently not supported in this PR.

### Performance
On Granite/Sapphire Rapids, cpp_bmm template code uses AMX which requires an expensive transpose operation so the BMM op is rarely selected as faster than the existing external bmm kernel. As a result, speedup on SPR is identical with and without BMM code. Pass rate matches the rates for main exactly.

#### Test Summary on Granite Rapids
Test   Scenario | Comp Item | Date | Compiler | torchbench | huggingface | timm_models
-- | -- | -- | -- | -- | -- | --
Single Socket Multi-Threads | Pass   Rate | gemm autotune| inductor | 91%,   73/80 | 100%,   46/46 | 100%,   61/61
   |     |   |  bmm + gemm autotune | inductor | 91%,   73/80 | 100%,   46/46 | 100%,   61/61
  |  |  Geomean Speedup | gemm autotune| inductor | 2.15x | 1.91x | 2.52x
   |     |   |  bmm + gemm autotune | inductor | 2.15x | 1.96x | 2.53x
Single Core Single-Thread | Pass   Rate | gemm autotune | inductor | 91%,   73/80 | 100%,   46/46 | 100%,   61/61
   |    |   |  bmm + gemm autotune| inductor | 91%,   73/80 | 100%,   46/46 | 100%,   61/61
 |  | Geomean Speedup | inductor_locally_benchmark_586 | inductor | 2.43x | 1.56x | 2.60x
   |    |   |  inductor_locally_benchmark_585 | inductor | 2.45x | 1.56x | 2.63x

This is not the case on an older Skylake Xeon machine.
For the BMM ops contained in torchbench models, bmm performance improves by 1.10-2.64x.

#### BF16 28-core Skylake Xeon
| Model | Inductor | GemmAutotune | Gemm+BMM Autotune |
|--------|--------|--------|--------|
| BERT_pytorch | 1.233x | 2.597x | 2.608x |
| hf_DistilBert | 1.128x | 2.242x | 2.368x |
| hf_Reformer | 1.124x | 1.419x | 1.590x |
| hf_T5_base | 1.012x | 1.257x | 1.382x |
| hf_T5_large | 1.085x | 2.228x | 2.345x |

## Example BMM Code
```
#include <c10/util/Unroll.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>

template <bool accum>
inline void cpp_bmm_micro_gemm_amx_kernel_32_2(
    AMXState& amx_state,
    const bfloat16* __restrict__ A,
    const bfloat16* __restrict__ B,
    float* __restrict__ C,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    uint8_t tilecfg_rows
) {
    // TODO(jgong5): add prefetch hint for A, B, C
    auto loadconfig = [](const amx_tilecfg& cfg) {
        _tile_loadconfig(&cfg);
    };
    const auto last_k_offset = K / 32 * 32;
    const auto tail_k_size = K - last_k_offset;
    if C10_LIKELY (last_k_offset > 0) {
        amx_state.configure(tilecfg_rows, 64, 32 / 16, 2, loadconfig);
    } else {
        amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 32 / 16, 2, loadconfig);
    }
    auto load_c = [&]() {
        _tile_loadd(0, C + 0 * ldc + 0, ldc * sizeof(float));
        _tile_loadd(1, C + 0 * ldc + 16, ldc * sizeof(float));
        _tile_loadd(2, C + 16 * ldc + 0, ldc * sizeof(float));
        _tile_loadd(3, C + 16 * ldc + 16, ldc * sizeof(float));
    };
    auto zero_c = [&]() {
        _tile_zero(0);
        _tile_zero(1);
        _tile_zero(2);
        _tile_zero(3);
    };

    if constexpr (accum) {
        load_c();
    } else {
        zero_c();
    }

    auto compute = [&](int k) {
        _tile_stream_loadd(4, A + 0 * lda + k, lda * sizeof(bfloat16));
        _tile_loadd(6, B + k * ldb + 0, ldb * 2 * sizeof(bfloat16));
        _tile_dpbf16ps(0, 4, 6);
        _tile_loadd(7, B + k * ldb + 32, ldb * 2 * sizeof(bfloat16));
        _tile_dpbf16ps(1, 4, 7);
        _tile_stream_loadd(5, A + 16 * lda + k, lda * sizeof(bfloat16));
        _tile_dpbf16ps(2, 5, 6);
        _tile_dpbf16ps(3, 5, 7);
    };

    #pragma GCC unroll 4
    for (int k = 0; k < last_k_offset; k += 32) {
        compute(k);
    }

    auto store_c = [&]() {
    // store to C
        _tile_stored(0, C + 0 * ldc + 0, ldc * sizeof(float));
        _tile_stored(1, C + 0 * ldc + 16, ldc * sizeof(float));
        _tile_stored(2, C + 16 * ldc + 0, ldc * sizeof(float));
        _tile_stored(3, C + 16 * ldc + 16, ldc * sizeof(float));
    };

    // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead
    if C10_UNLIKELY (tail_k_size > 0) {
        if C10_LIKELY (last_k_offset > 0) {
            store_c();
            amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 32 / 16, 2, loadconfig);
            load_c();
        }
        compute(last_k_offset);
    }

    store_c();
}

template <bool accum>
inline void cpp_bmm_micro_gemm_amx_kernel_16_2(
    AMXState& amx_state,
    const bfloat16* __restrict__ A,
    const bfloat16* __restrict__ B,
    float* __restrict__ C,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    uint8_t tilecfg_rows
) {
    // TODO(jgong5): add prefetch hint for A, B, C
    auto loadconfig = [](const amx_tilecfg& cfg) {
        _tile_loadconfig(&cfg);
    };
    const auto last_k_offset = K / 32 * 32;
    const auto tail_k_size = K - last_k_offset;
    if C10_LIKELY (last_k_offset > 0) {
        amx_state.configure(tilecfg_rows, 64, 16 / 16, 2, loadconfig);
    } else {
        amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 16 / 16, 2, loadconfig);
    }
    auto load_c = [&]() {
        _tile_loadd(0, C + 0 * ldc + 0, ldc * sizeof(float));
        _tile_loadd(1, C + 0 * ldc + 16, ldc * sizeof(float));
    };
    auto zero_c = [&]() {
        _tile_zero(0);
        _tile_zero(1);
    };

    if constexpr (accum) {
        load_c();
    } else {
        zero_c();
    }

    auto compute = [&](int k) {
        _tile_stream_loadd(2, A + 0 * lda + k, lda * sizeof(bfloat16));
        _tile_loadd(3, B + k * ldb + 0, ldb * 2 * sizeof(bfloat16));
        _tile_dpbf16ps(0, 2, 3);
        _tile_loadd(4, B + k * ldb + 32, ldb * 2 * sizeof(bfloat16));
        _tile_dpbf16ps(1, 2, 4);
    };

    #pragma GCC unroll 4
    for (int k = 0; k < last_k_offset; k += 32) {
        compute(k);
    }

    auto store_c = [&]() {
    // store to C
        _tile_stored(0, C + 0 * ldc + 0, ldc * sizeof(float));
        _tile_stored(1, C + 0 * ldc + 16, ldc * sizeof(float));
    };

    // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead
    if C10_UNLIKELY (tail_k_size > 0) {
        if C10_LIKELY (last_k_offset > 0) {
            store_c();
            amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 16 / 16, 2, loadconfig);
            load_c();
        }
        compute(last_k_offset);
    }

    store_c();
}

template <bool accum>
inline void cpp_bmm_micro_gemm(
    AMXState& amx_state,
    const bfloat16* __restrict__ A,
    const bfloat16* __restrict__ B,
    float* __restrict__ C,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc
) {
    AOTI_TORCH_CHECK(N % 32 == 0, "N dimension must be multiple of 32");
    AOTI_TORCH_CHECK(K % 2 == 0, "K dimension must be multiple of 2");
    // TODO(jgong5): loop unroll for M and N
    for (int64_t n = 0; n < N; n += 32) {
        for (int64_t m = 0; m < M; m += 32) {
            int64_t block_m = std::min<int64_t>(M - m, 32);
            int64_t m_tail = m;
            if (block_m >= 32) {
                cpp_bmm_micro_gemm_amx_kernel_32_2<accum>(
                    amx_state,
                    A + m * lda,
                    B + n,
                    C + m * ldc + n,
                    K,
                    lda,
                    ldb,
                    ldc,
                    16
                );
                block_m -= 32;
                m_tail += 32;
            }
            else
            if (block_m >= 16) {
                cpp_bmm_micro_gemm_amx_kernel_16_2<accum>(
                    amx_state,
                    A + m * lda,
                    B + n,
                    C + m * ldc + n,
                    K,
                    lda,
                    ldb,
                    ldc,
                    16
                );
                block_m -= 16;
                m_tail += 16;
            }
            if (block_m > 0) {
                cpp_bmm_micro_gemm_amx_kernel_16_2<accum>(
                    amx_state,
                    A + m_tail * lda,
                    B + n,
                    C + m_tail * ldc + n,
                    K,
                    lda,
                    ldb,
                    ldc,
                    block_m
                );
            }
        }
    }
}
void threaded_mm(const bfloat16* X, const bfloat16* W, bfloat16* Y, const int64_t ks_b_index)
{

    constexpr int64_t num_threads = 48;
    constexpr int64_t N = 64;
    constexpr int64_t K = 96;
    constexpr int64_t Mr = 32;
    constexpr int64_t Nr = 32;
    constexpr int64_t Kr = 32;
    constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr;
    constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr;
    constexpr int64_t M = static_cast<int64_t>(384L);
    constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr;
    constexpr int64_t Mt_blocks = 1;
    constexpr int64_t Nt_blocks = 1;
    constexpr int64_t Kt_blocks = 3;
    constexpr int64_t Mc_blocks = 1;
    constexpr int64_t Nc_blocks = 1;
    constexpr int64_t Kc_blocks = 3;
    constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
    constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
    constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks;
    constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks;
    constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;

    // make sure all partitions are assigned
    AOTI_TORCH_CHECK(
        Mt_blocks * Nt_blocks * Kt_blocks * 48 >= Mr_blocks * Nr_blocks * Kr_blocks,
        "Not all partitions are assigned."
    );
    #pragma omp parallel num_threads(48)
    {
        const int tid = omp_get_thread_num();
        const int64_t k_group_id = tid / num_Kt_blocks;
        const int64_t k_slice_id = tid % num_Kt_blocks;
        const int64_t n_group_id = k_group_id / num_Nt_blocks;
        const int64_t n_slice_id = k_group_id % num_Nt_blocks;
        const int64_t k_block_start = k_slice_id * Kt_blocks;
        const int64_t k_block_end = std::min(k_block_start + Kt_blocks, Kr_blocks);
        const int64_t n_block_start = n_slice_id * Nt_blocks;
        const int64_t n_block_end = std::min(n_block_start + Nt_blocks, Nr_blocks);
        const int64_t m_block_start = std::min(n_group_id * Mt_blocks, Mr_blocks);
        const int64_t m_block_end = std::min(m_block_start + Mt_blocks, Mr_blocks);
        const int64_t num_Mc_blocks_per_thread = (m_block_end - m_block_start + Mc_blocks - 1) / Mc_blocks;
        AMXState amx_state;
        auto _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); auto local_acc_buf = _local_acc_buf.get();
        for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
            const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread;
            const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks;
            const int64_t m_start = mc * Mr;
            const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
            const int64_t m_size = m_end - m_start;
            for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
                const int64_t n_start = nc * Nr;
                const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
                const int64_t n_size = n_end - n_start;
                // NB: assume we pad N, nc_block_end won't exceed padded N here.
                const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end);
                if (_local_acc_buf == nullptr) { _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); local_acc_buf = _local_acc_buf.get(); }
                for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
                    int64_t k_start = kc * Kr;
                    int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K);
                    for (int64_t nci = nc; nci < nc_block_end; nci++) {
                        if (kc == k_block_start) {
                            cpp_bmm_micro_gemm<static_cast<bool>(false)>(
                                amx_state,
                                &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]),
                                &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]),
                                &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]),
                                static_cast<int64_t>(m_end + ((-1L)*m_start)),
                                static_cast<int64_t>(Nr),
                                static_cast<int64_t>(k_end + ((-1L)*k_start)),
                                static_cast<int64_t>(96L),
                                static_cast<int64_t>(32L),
                                static_cast<int64_t>(Nc_blocks*Nr)
                            );

                        } else {
                            cpp_bmm_micro_gemm<static_cast<bool>(true)>(
                                amx_state,
                                &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]),
                                &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]),
                                &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]),
                                static_cast<int64_t>(m_end + ((-1L)*m_start)),
                                static_cast<int64_t>(Nr),
                                static_cast<int64_t>(k_end + ((-1L)*k_start)),
                                static_cast<int64_t>(96L),
                                static_cast<int64_t>(32L),
                                static_cast<int64_t>(Nc_blocks*Nr)
                            );

                        }
                    }
                }
                {
                    {
                        #pragma GCC ivdep
                        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(m_end + ((-1L)*m_start)); x0+=static_cast<int64_t>(1L))
                        {
                            for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1+=static_cast<int64_t>(16L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(16));
                                auto tmp1 = at::vec::convert<bfloat16>(tmp0);
                                tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(16));
                            }
                            for(int64_t x1=static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1<static_cast<int64_t>(n_end + ((-1L)*n_start)); x1+=(static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))) == 0 ? 1 : static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))))))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))));
                                auto tmp1 = at::vec::convert<bfloat16>(tmp0);
                                tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))));
                            }
                        }
                    }

                }
            }
        }
        amx_state.release([]() { _tile_release(); });
    }
}
void single_thread_mm(const bfloat16* X, const bfloat16* W, bfloat16* Y, const int64_t ks_b_index)
{

    constexpr int64_t num_threads = 1;
    constexpr int64_t N = 64;
    constexpr int64_t K = 96;
    constexpr int64_t Mr = 32;
    constexpr int64_t Nr = 32;
    constexpr int64_t Kr = 32;
    constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr;
    constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr;
    constexpr int64_t M = static_cast<int64_t>(384L);
    constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr;
    constexpr int64_t Mt_blocks = 12;
    constexpr int64_t Nt_blocks = 2;
    constexpr int64_t Kt_blocks = 3;
    constexpr int64_t Mc_blocks = 12;
    constexpr int64_t Nc_blocks = 1;
    constexpr int64_t Kc_blocks = 3;
    constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
    constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
    constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks;
    constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks;
    constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;

    // make sure all partitions are assigned
    AOTI_TORCH_CHECK(
        Mt_blocks * Nt_blocks * Kt_blocks * 1 >= Mr_blocks * Nr_blocks * Kr_blocks,
        "Not all partitions are assigned."
    );
    {
        constexpr int tid = 0;
        constexpr int64_t k_group_id = 0;
        constexpr int64_t k_slice_id = 0;
        constexpr int64_t n_group_id = 0;
        constexpr int64_t n_slice_id = 0;
        constexpr int64_t m_block_start = 0;
        constexpr int64_t n_block_start = 0;
        constexpr int64_t n_block_end = Nr_blocks;
        constexpr int64_t k_block_start = 0;
        constexpr int64_t k_block_end = Kr_blocks;
        constexpr int64_t num_Mc_blocks_per_thread = num_Mc_blocks;
        constexpr int64_t m_block_end = Mr_blocks;
        AMXState amx_state;
        auto _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); auto local_acc_buf = _local_acc_buf.get();
        for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
            const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread;
            const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks;
            const int64_t m_start = mc * Mr;
            const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
            const int64_t m_size = m_end - m_start;
            for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
                const int64_t n_start = nc * Nr;
                const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
                const int64_t n_size = n_end - n_start;
                // NB: assume we pad N, nc_block_end won't exceed padded N here.
                const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end);
                if (_local_acc_buf == nullptr) { _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); local_acc_buf = _local_acc_buf.get(); }
                for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
                    int64_t k_start = kc * Kr;
                    int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K);
                    for (int64_t nci = nc; nci < nc_block_end; nci++) {
                        if (kc == k_block_start) {
                            cpp_bmm_micro_gemm<static_cast<bool>(false)>(
                                amx_state,
                                &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]),
                                &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]),
                                &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]),
                                static_cast<int64_t>(m_end + ((-1L)*m_start)),
                                static_cast<int64_t>(Nr),
                                static_cast<int64_t>(k_end + ((-1L)*k_start)),
                                static_cast<int64_t>(96L),
                                static_cast<int64_t>(32L),
                                static_cast<int64_t>(Nc_blocks*Nr)
                            );

                        } else {
                            cpp_bmm_micro_gemm<static_cast<bool>(true)>(
                                amx_state,
                                &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]),
                                &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]),
                                &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]),
                                static_cast<int64_t>(m_end + ((-1L)*m_start)),
                                static_cast<int64_t>(Nr),
                                static_cast<int64_t>(k_end + ((-1L)*k_start)),
                                static_cast<int64_t>(96L),
                                static_cast<int64_t>(32L),
                                static_cast<int64_t>(Nc_blocks*Nr)
                            );

                        }
                    }
                }
                {
                    {
                        #pragma GCC ivdep
                        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(m_end + ((-1L)*m_start)); x0+=static_cast<int64_t>(1L))
                        {
                            for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1+=static_cast<int64_t>(16L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(16));
                                auto tmp1 = at::vec::convert<bfloat16>(tmp0);
                                tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(16));
                            }
                            for(int64_t x1=static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1<static_cast<int64_t>(n_end + ((-1L)*n_start)); x1+=(static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))) == 0 ? 1 : static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))))))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))));
                                auto tmp1 = at::vec::convert<bfloat16>(tmp0);
                                tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))));
                            }
                        }
                    }

                }
            }
        }
        amx_state.release([]() { _tile_release(); });
    }
}
extern "C"
void cpp_bmm(const bfloat16* X, const bfloat16* W, bfloat16* Y)
{
    const int64_t B = static_cast<int64_t>(5L);
    constexpr int64_t num_threads = 48;
    int64_t B_single_thread_block = (B / num_threads) * num_threads;

    #pragma omp parallel for num_threads(48)
    for (int64_t b_start = 0; b_start < B_single_thread_block; ++b_start) {
        single_thread_mm(X, W, Y, b_start);
    }
    for (int64_t b_start = B_single_thread_block; b_start < B; ++b_start) {
        threaded_mm(X, W, Y, b_start);
    }
}
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129772
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel
2024-12-06 04:54:00 +00:00
Brian Hirsh
471017cbc9 avoid specializing strides with DDPOptimizer + inductor (#140751)
Fixes https://github.com/pytorch/pytorch/issues/140229

Fixes https://github.com/pytorch/pytorch/issues/139474

The issue was that:

(1) DDPOptimizer has some logic to partition the dynamo graph into buckets, and run AOTAutograd/inductor on each bucket

(2) doing so requires knowing the **exact** strides of the outputs of each subgraph, so we can have example inputs (with correct strides) to each of the later subgraphs to compile with

(3) there is some existing logic to do this today: we have a `fakify_first_call` flag in AOTAutograd that lets you run it with fake tensor inputs (to handle the calling convention changes that AOTAutograd performs at runtime). During this process, we query inductor for the output strides that it compiled with

(4) these outputs strides are stored in the FX graph cache as raw strings of sympy expressions. We have a function, `evaluate_symexpr`, which given the sympy string, and the ShapeEnv's `var_to_val` mapping, will evaluate the sympy string to generate concrete strides

(5) evaluating this expression will specialize on the exact values of any variables in our shape env, however. In DDPOptimizer, we want to know what inductor's stride outputs are symbolically. This requires converting the (string) sympy expression into actual `SymInts` that we can return.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140751
Approved by: https://github.com/eellison
2024-12-05 03:41:12 +00:00
eellison
fd35be2fd3 TritonTemplate dtype fixes (#141991)
- Set the dtype of "acc" appropriately so that epilogue fusion will have args with dtype
- Update dtype propagation to use `type_to_dtype` instead of instantiating tensor
- Throw if we have a string arg where we should have a proper CSEVariable, unless we're doing the Modification Subgraph thing which is nyi. everything else is appropriately typed (cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @drisspg ).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141991
Approved by: https://github.com/drisspg
ghstack dependencies: #139945, #140057, #141495, #141882
2024-12-04 17:24:23 +00:00
iupaikov-amd
5c2584a14c [ROCm] Enable inductor GEMM lowering for gfx11 (#141687)
This check doesn't make sense for some of the AMD gpus since they have the right amount of CUs but multi_processor_count returns WGPs on RDNA while still performing adequately. A lot of tests fail on modern archs due to this check defaulting them to not using the GEMMs backend.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141687
Approved by: https://github.com/pruthvistony, https://github.com/jeffdaily, https://github.com/malfet

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2024-12-02 22:13:34 +00:00
eellison
f83361b274 inductor dtype propagation fixes (#141495)
- Add in upcast_compute_type on creation of new tensors (loads, constants)
- Fixes index_expr - right now we are sort of inconsistent in dtype and dont always respect the dtype specified. would be nice to fix but not doing in this pr.
- bug fix in view dtype where we were always upcasting back to fp32 when input was in bf16/fp16. we should only be doing that if the output is also in bf16/fp16.
- for masked, avoid calling dtype propagation and just use output dtype.

Turns on the runtime dtype verification for opinfo tests. The separate test file is still useful because we can use it for testing turning off codegen_upcast_to_fp32.

Follow ups:

- We could consider requiring less explicit upcast_compute_types calls and do it automatically. That would potentially make things easier but be less flexible in the future. Maybe I should have done it this pr.
- Be more consistent on our index expr dtype printing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141495
Approved by: https://github.com/blaine-rister, https://github.com/arui-meta, https://github.com/ezyang
ghstack dependencies: #139945, #140057
2024-11-28 11:39:38 +00:00
Edward Z. Yang
dbbebee9d7 Code motion CompiledFxGraph to a dedicated file (#141654)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141654
Approved by: https://github.com/aorenste, https://github.com/jansel
ghstack dependencies: #141491, #141492, #141574
2024-11-27 20:42:21 +00:00
James Wu
a7ca6a9113 Enable autograd cache on inductor tests (#140890)
This turns on AOTAutogradCache for all inductor tests. It clears AOTAutogradCache on each test as well, by virtue of the local cache using the same directory to store cache entries.

I've also tested with INDUCTOR_TEST_DISABLE_FRESH_CACHE=1, running all the tests. AOTAutogradCache successfully caches 99% of these. There are a few tests that use view_replay and therefore save functional tensors, which cause AOTAutogradCache to fail to pickle its result. Will look into next steps there, but for now, it seems okay if the cache just misses on those cases where it can't serialize the result. It would be better to check before pickling, though.

I've made the following small bugfixes to get this working:
- Inductor is sometimes used in a standalone mode without dynamo, which leads to attribute errors in check_can_cache. In general, we should *never* crash in cache checking, only bypass. So I change a try catch to check Exception instead of just a specific exception.
- Add extra structured logging for metadata on cache hits

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140890
Approved by: https://github.com/bdhirsh
2024-11-27 20:41:43 +00:00
Boyuan Feng
17fd53d8e5 [Inductor] Inplacing with Donated Buffer (#140113)
Currently, inductor does not inplace update a buffer if it is an input buffer. Because we don't know if an input will be used by other functions.

Donated buffer provides additional information that an input buffer will not be used by other functions. So we can inplace update donated buffer when possible.

[Dashboard](https://hud.pytorch.org/benchmark/torchbench/inductor_dynamic?dashboard=torchinductor&startTime=Mon,%2011%20Nov%202024%2018:14:36%20GMT&stopTime=Mon,%2018%20Nov%202024%2018:14:36%20GMT&granularity=hour&mode=training&dtype=amp&deviceName=cuda%20(a100)&lBranch=bf/donated-buffer-inplace&lCommit=5df0769c00e6f9000caeb10fd5cbf0b165f69c2a&rBranch=main&rCommit=2b39a8db7741b816b03677a9c6fec1af05640dee)

![image](https://github.com/user-attachments/assets/f19d961f-7973-418e-9de8-5c2a97950478)
![image](https://github.com/user-attachments/assets/df3bd6a9-58b8-4e8a-8397-9e3b1de9adfe)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140113
Approved by: https://github.com/eellison
2024-11-27 18:51:52 +00:00
eellison
fd553b9817 Add remaining method and tests for dtype propagation (#140057)
Adds the remaining unimplemented ops as well as an assertion failure if someone adds a new op without a dtype rule.

We test all unique pointwise operators registered as lowerings which have an opinfo. There will be some follow ups for this to work well with both `codegen_upcast_to_fp32` as True and False.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140057
Approved by: https://github.com/arui-meta, https://github.com/blaine-rister, https://github.com/ezyang
ghstack dependencies: #139945
2024-11-27 17:06:44 +00:00
eellison
566ceb3e7e Refactor dtype propagation (#139945)
A couple changes.

- Tries to reuse dtype propagation rules that were already registered in inductor. These were present both with `pointwise_overrides_data` and the `boolean_ops` list. Additionally, the registration of pointwise ops already specified dtype propagation rules. Saves those registrations and reuses them later.

- Factors out `get_promoted_dtype` which uses functools.lru_cache to take in non - CSEVariable args because those will not work with the functools cache.

Tests get added later in the stack when everything is implemented.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139945
Approved by: https://github.com/blaine-rister, https://github.com/arui-meta, https://github.com/ezyang
2024-11-27 16:57:02 +00:00
PyTorch MergeBot
65dbd5cc2d Revert "[Inductor] Inplacing with Donated Buffer (#140113)"
This reverts commit eecc8e362c.

Reverted https://github.com/pytorch/pytorch/pull/140113 on behalf of https://github.com/BoyuanFeng due to break test_donated_buffer_inplace internally since donated_buffer = False if is_fbcode() else True ([comment](https://github.com/pytorch/pytorch/pull/140113#issuecomment-2501954300))
2024-11-26 21:20:59 +00:00
Boyuan Feng
eecc8e362c [Inductor] Inplacing with Donated Buffer (#140113)
Currently, inductor does not inplace update a buffer if it is an input buffer. Because we don't know if an input will be used by other functions.

Donated buffer provides additional information that an input buffer will not be used by other functions. So we can inplace update donated buffer when possible.

[Dashboard](https://hud.pytorch.org/benchmark/torchbench/inductor_dynamic?dashboard=torchinductor&startTime=Mon,%2011%20Nov%202024%2018:14:36%20GMT&stopTime=Mon,%2018%20Nov%202024%2018:14:36%20GMT&granularity=hour&mode=training&dtype=amp&deviceName=cuda%20(a100)&lBranch=bf/donated-buffer-inplace&lCommit=5df0769c00e6f9000caeb10fd5cbf0b165f69c2a&rBranch=main&rCommit=2b39a8db7741b816b03677a9c6fec1af05640dee)

![image](https://github.com/user-attachments/assets/f19d961f-7973-418e-9de8-5c2a97950478)
![image](https://github.com/user-attachments/assets/df3bd6a9-58b8-4e8a-8397-9e3b1de9adfe)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140113
Approved by: https://github.com/eellison
2024-11-26 17:19:50 +00:00
Colin Peppler
8f5edcb75c [CUTLASS] Lift shape & stride information as kernel args (#138611)
Differential Revision: [D64773324](https://our.internmc.facebook.com/intern/diff/D64773324)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138611
Approved by: https://github.com/chenyang78
2024-11-25 17:52:33 +00:00
Jason Ansel
808f0f656d [inductor] Refactor MutableBox to make IRNode typing easier (#140895)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140895
Approved by: https://github.com/ezyang, https://github.com/Skylion007
2024-11-20 19:50:46 +00:00
PyTorch MergeBot
d472a5f680 Revert "[inductor] Refactor MutableBox to make IRNode typing easier (#140895)"
This reverts commit c79e78b503.

Reverted https://github.com/pytorch/pytorch/pull/140895 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I think test_torchbind_inductor is failing in trunk after this lands ([comment](https://github.com/pytorch/pytorch/pull/140895#issuecomment-2484679319))
2024-11-19 04:25:41 +00:00
Jason Ansel
c79e78b503 [inductor] Refactor MutableBox to make IRNode typing easier (#140895)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140895
Approved by: https://github.com/ezyang, https://github.com/Skylion007
2024-11-19 00:24:35 +00:00
Jack Taylor
a33fa37b4e [ROCm] Support new AMD triton stream pipeliner (#139881)
Fixes #139182

In Triton 3.2 num_stages=0 will be deprecated with Triton's AMD backend. Let's query default num_stages from the relevant triton backend

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139881
Approved by: https://github.com/bertmaher
2024-11-08 14:51:05 +00:00
xinan.lin
320374b011 [Inductor] Refine triton_bundler.py to support correctly on Intel GPU and fix CI failures. (#139705)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139705
Approved by: https://github.com/EikanWang, https://github.com/jansel, https://github.com/guangyey
2024-11-07 07:10:28 +00:00
Max Podkorytov
8f077b811b [ROCm][Inductor]Fixing missing ck package warning when the backend is disabled (#139790)
```

test_addmm_multiple_dynamic_cuda (__main__.AOTInductorTestABICompatibleCuda) ... W1101 10:26:20.492000 1361741 torch/_inductor/utils.py:1207] Please pip install Composable Kernel package
AUTOTUNE addmm(16x6, 16x16, 16x6)
  triton_mm_0 0.0104 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=16, BLOCK_N=16, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=0, num_stages=2, num_warps=1
  triton_mm_1 0.0104 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=16, BLOCK_N=16, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=16, num_stages=2, num_warps=1
SingleProcess AUTOTUNE benchmarking takes 0.2182 seconds and 0.2979 seconds precompiling for 2 choices
```
This PR disables the warning message when the CK backend is disabled

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139790
Approved by: https://github.com/ColinPeppler, https://github.com/chenyang78
2024-11-06 22:04:32 +00:00
eellison
aafb3deaf1 Remove multinomial from cudagraph skip list' (#139897)
Since https://github.com/pytorch/pytorch/pull/134818/files we can run multinomial in cudagraph without error.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139897
Approved by: https://github.com/BoyuanFeng
2024-11-06 21:28:42 +00:00
Jason Ansel
ed30fa74ab [inductor] sympy.Integer([01]) -> sympy.S.(Zero|One) (#139523)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139523
Approved by: https://github.com/ezyang
ghstack dependencies: #139364, #139365, #139370, #139452
2024-11-04 04:28:40 +00:00
Jason Ansel
d189f92eb1 [inductor] Remove SIMDKernel.last_usage (#139364)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139364
Approved by: https://github.com/eellison, https://github.com/shunting314
2024-11-04 04:28:18 +00:00
PyTorch MergeBot
0863d6a08e Revert "[inductor] Remove SIMDKernel.last_usage (#139364)"
This reverts commit 286d3ce266.

Reverted https://github.com/pytorch/pytorch/pull/139364 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing lots of internal tests in D65345157 ([comment](https://github.com/pytorch/pytorch/pull/139364#issuecomment-2452897337))
2024-11-02 06:49:11 +00:00
PyTorch MergeBot
98e11b0021 Revert "[inductor] sympy.Integer([01]) -> sympy.S.(Zero|One) (#139523)"
This reverts commit c53beab377.

Reverted https://github.com/pytorch/pytorch/pull/139523 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing lots of internal tests in D65345157 ([comment](https://github.com/pytorch/pytorch/pull/139364#issuecomment-2452897337))
2024-11-02 06:49:10 +00:00
Jason Ansel
c53beab377 [inductor] sympy.Integer([01]) -> sympy.S.(Zero|One) (#139523)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139523
Approved by: https://github.com/ezyang
ghstack dependencies: #139364, #139365, #139370, #139452
2024-11-02 03:04:22 +00:00
Jason Ansel
286d3ce266 [inductor] Remove SIMDKernel.last_usage (#139364)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139364
Approved by: https://github.com/eellison, https://github.com/shunting314
2024-11-01 16:28:15 +00:00
Bin Bao
33dce10ece [AOTI][reland] Update zero size computation in clone_preserve_strides (#139458)
Summary: Reland https://github.com/pytorch/pytorch/pull/139224. clone_preserve_strides implemented in _inductor/utils.py does not handle multi-dimensional 0-size tensor correctly.

Differential Revision: [D65317451](https://our.internmc.facebook.com/intern/diff/D65317451)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139458
Approved by: https://github.com/hl475
2024-11-01 13:51:02 +00:00
PyTorch MergeBot
c4d9428b17 Revert "[AOTI] Update zero size computation in clone_preserve_strides (#139224)"
This reverts commit 206a8dde68.

Reverted https://github.com/pytorch/pytorch/pull/139224 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/139224#issuecomment-2450811914))
2024-10-31 21:05:07 +00:00
Bin Bao
206a8dde68 [AOTI] Update zero size computation in clone_preserve_strides (#139224)
Summary: clone_preserve_strides implemented in _inductor/utils.py does not handle multi-dimensional 0-size tensor correctly. Fix that.

Differential Revision: [D65250405](https://our.internmc.facebook.com/intern/diff/D65250405)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139224
Approved by: https://github.com/angelayi
2024-10-31 17:07:18 +00:00
rzou
ccaa2a206a [inductor] make requires_stride_order more unbacked-symint-aware (#137063)
Previously, we tried to sort SymInt strides to determine the stride
order. This PR makes the sorting more unbacked symint aware: given a Tensor
with sizes (u0, u1, u2), it has strides (u1 * u2, u1, 1), which is
sortable under the guard_size_oblivious assumptions.

Test Plan:
- test case

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137063
Approved by: https://github.com/eellison
2024-10-31 13:11:02 +00:00
eellison
4db6b740bc [Easy] GraphTransformObserver Refactoring (#139292)
Uses `torch._inductor.config.trace.log_url_for_graph_xform` by default as the log url. It was only ever instantiated with this as the log_url argument.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139292
Approved by: https://github.com/shengfukevin, https://github.com/shunting314
2024-10-31 00:33:28 +00:00
Colin L. Rice
0b9320b7c5 fx_graph_cache: Remove custom amd JK (#137501)
This split in JKs was never actually used (We just set both JKs to the same values except when we accidentally didn't due to being humans who make mistakes). This simplifies the overall JK structure and eventually, will let us delete the duplicate JK

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137501
Approved by: https://github.com/oulgen
2024-10-24 01:30:39 +00:00