Mitchell, Frost
20f24e3fbd
[inductor][cpp] Add BMM kernel template for autotuning ( #129772 )
This PR adds the Cpp template for BMM, for FP32, FP16, and BF16. See #125683 for more background.
1. Adds a `CppBmmTemplate` class which inherits from `CppPackedGemmTemplate`. Given a number of worker threads `num_threads` and a batch size `B`, the Gemm kernel is executed as follows: for the first `B - (B % num_threads)` batch inputs, one sub-gemm problem is run per thread; for the remaining `B % num_threads` sub-gemms, each subproblem is executed using the parallelized Gemm kernel (see the sketch below).
To manage this code, the `GEMM_TEMPLATE` from `CppPackedGemmTemplate` is rendered twice: once with a single thread and once including the parallel OMP pragma.
2. Adapts `CppPackedGemmTemplate` to allow for subclassing. The `GEMM_TEMPLATE` is split into separate strings so that the child class can render them. Slicing/indexing are adapted to allow for 3D BMM inputs. Additional methods `get_options()` and `_get_params_for_choices()` are added to reduce code duplication.
The BMM within the `dlrm` benchmark has a single input buffer which is used for both the X and W inputs. This is not currently supported by this PR.
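As a rough illustration of the batch partitioning described in point 1, here is a minimal Python sketch (not the actual template code; `split_batch` is a hypothetical helper used only for illustration):
```
def split_batch(B: int, num_threads: int):
    # Batches dispatched to the single-threaded sub-gemm kernel, one per OMP thread.
    b_single_thread = (B // num_threads) * num_threads  # == B - (B % num_threads)
    # Remaining batches each run through the parallelized Gemm kernel.
    b_parallel = B - b_single_thread                    # == B % num_threads
    return b_single_thread, b_parallel

# With B=5 and 48 threads (as in the generated example below), all 5 sub-gemms
# fall back to the parallelized kernel; with B=100, 96 run single-threaded and 4 parallel.
assert split_batch(5, 48) == (0, 5)
assert split_batch(100, 48) == (96, 4)
```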
### Performance
On Granite Rapids/Sapphire Rapids, the cpp_bmm template code uses AMX, which requires an expensive transpose operation, so the BMM op is rarely selected as faster than the existing external bmm kernel. As a result, the speedup on SPR is identical with and without the BMM template, and the pass rate exactly matches that of main.
#### Test Summary on Granite Rapids
Test Scenario | Comp Item | Config | Compiler | torchbench | huggingface | timm_models
-- | -- | -- | -- | -- | -- | --
Single Socket Multi-Threads | Pass Rate | gemm autotune | inductor | 91%, 73/80 | 100%, 46/46 | 100%, 61/61
Single Socket Multi-Threads | Pass Rate | bmm + gemm autotune | inductor | 91%, 73/80 | 100%, 46/46 | 100%, 61/61
Single Socket Multi-Threads | Geomean Speedup | gemm autotune | inductor | 2.15x | 1.91x | 2.52x
Single Socket Multi-Threads | Geomean Speedup | bmm + gemm autotune | inductor | 2.15x | 1.96x | 2.53x
Single Core Single-Thread | Pass Rate | gemm autotune | inductor | 91%, 73/80 | 100%, 46/46 | 100%, 61/61
Single Core Single-Thread | Pass Rate | bmm + gemm autotune | inductor | 91%, 73/80 | 100%, 46/46 | 100%, 61/61
Single Core Single-Thread | Geomean Speedup | inductor_locally_benchmark_586 | inductor | 2.43x | 1.56x | 2.60x
Single Core Single-Thread | Geomean Speedup | inductor_locally_benchmark_585 | inductor | 2.45x | 1.56x | 2.63x
This is not the case on an older Skylake Xeon machine: for the BMM ops contained in torchbench models, bmm performance improves by 1.10-2.64x.
#### BF16 28-core Skylake Xeon
| Model | Inductor | GemmAutotune | Gemm+BMM Autotune |
|--------|--------|--------|--------|
| BERT_pytorch | 1.233x | 2.597x | 2.608x |
| hf_DistilBert | 1.128x | 2.242x | 2.368x |
| hf_Reformer | 1.124x | 1.419x | 1.590x |
| hf_T5_base | 1.012x | 1.257x | 1.382x |
| hf_T5_large | 1.085x | 2.228x | 2.345x |
## Example BMM Code
```
#include <c10/util/Unroll.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
template <bool accum>
inline void cpp_bmm_micro_gemm_amx_kernel_32_2(
AMXState& amx_state,
const bfloat16* __restrict__ A,
const bfloat16* __restrict__ B,
float* __restrict__ C,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
uint8_t tilecfg_rows
) {
// TODO(jgong5): add prefetch hint for A, B, C
auto loadconfig = [](const amx_tilecfg& cfg) {
_tile_loadconfig(&cfg);
};
const auto last_k_offset = K / 32 * 32;
const auto tail_k_size = K - last_k_offset;
if C10_LIKELY (last_k_offset > 0) {
amx_state.configure(tilecfg_rows, 64, 32 / 16, 2, loadconfig);
} else {
amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 32 / 16, 2, loadconfig);
}
auto load_c = [&]() {
_tile_loadd(0, C + 0 * ldc + 0, ldc * sizeof(float));
_tile_loadd(1, C + 0 * ldc + 16, ldc * sizeof(float));
_tile_loadd(2, C + 16 * ldc + 0, ldc * sizeof(float));
_tile_loadd(3, C + 16 * ldc + 16, ldc * sizeof(float));
};
auto zero_c = [&]() {
_tile_zero(0);
_tile_zero(1);
_tile_zero(2);
_tile_zero(3);
};
if constexpr (accum) {
load_c();
} else {
zero_c();
}
auto compute = [&](int k) {
_tile_stream_loadd(4, A + 0 * lda + k, lda * sizeof(bfloat16));
_tile_loadd(6, B + k * ldb + 0, ldb * 2 * sizeof(bfloat16));
_tile_dpbf16ps(0, 4, 6);
_tile_loadd(7, B + k * ldb + 32, ldb * 2 * sizeof(bfloat16));
_tile_dpbf16ps(1, 4, 7);
_tile_stream_loadd(5, A + 16 * lda + k, lda * sizeof(bfloat16));
_tile_dpbf16ps(2, 5, 6);
_tile_dpbf16ps(3, 5, 7);
};
#pragma GCC unroll 4
for (int k = 0; k < last_k_offset; k += 32) {
compute(k);
}
auto store_c = [&]() {
// store to C
_tile_stored(0, C + 0 * ldc + 0, ldc * sizeof(float));
_tile_stored(1, C + 0 * ldc + 16, ldc * sizeof(float));
_tile_stored(2, C + 16 * ldc + 0, ldc * sizeof(float));
_tile_stored(3, C + 16 * ldc + 16, ldc * sizeof(float));
};
// TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead
if C10_UNLIKELY (tail_k_size > 0) {
if C10_LIKELY (last_k_offset > 0) {
store_c();
amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 32 / 16, 2, loadconfig);
load_c();
}
compute(last_k_offset);
}
store_c();
}
template <bool accum>
inline void cpp_bmm_micro_gemm_amx_kernel_16_2(
AMXState& amx_state,
const bfloat16* __restrict__ A,
const bfloat16* __restrict__ B,
float* __restrict__ C,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
uint8_t tilecfg_rows
) {
// TODO(jgong5): add prefetch hint for A, B, C
auto loadconfig = [](const amx_tilecfg& cfg) {
_tile_loadconfig(&cfg);
};
const auto last_k_offset = K / 32 * 32;
const auto tail_k_size = K - last_k_offset;
if C10_LIKELY (last_k_offset > 0) {
amx_state.configure(tilecfg_rows, 64, 16 / 16, 2, loadconfig);
} else {
amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 16 / 16, 2, loadconfig);
}
auto load_c = [&]() {
_tile_loadd(0, C + 0 * ldc + 0, ldc * sizeof(float));
_tile_loadd(1, C + 0 * ldc + 16, ldc * sizeof(float));
};
auto zero_c = [&]() {
_tile_zero(0);
_tile_zero(1);
};
if constexpr (accum) {
load_c();
} else {
zero_c();
}
auto compute = [&](int k) {
_tile_stream_loadd(2, A + 0 * lda + k, lda * sizeof(bfloat16));
_tile_loadd(3, B + k * ldb + 0, ldb * 2 * sizeof(bfloat16));
_tile_dpbf16ps(0, 2, 3);
_tile_loadd(4, B + k * ldb + 32, ldb * 2 * sizeof(bfloat16));
_tile_dpbf16ps(1, 2, 4);
};
#pragma GCC unroll 4
for (int k = 0; k < last_k_offset; k += 32) {
compute(k);
}
auto store_c = [&]() {
// store to C
_tile_stored(0, C + 0 * ldc + 0, ldc * sizeof(float));
_tile_stored(1, C + 0 * ldc + 16, ldc * sizeof(float));
};
// TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead
if C10_UNLIKELY (tail_k_size > 0) {
if C10_LIKELY (last_k_offset > 0) {
store_c();
amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 16 / 16, 2, loadconfig);
load_c();
}
compute(last_k_offset);
}
store_c();
}
template <bool accum>
inline void cpp_bmm_micro_gemm(
AMXState& amx_state,
const bfloat16* __restrict__ A,
const bfloat16* __restrict__ B,
float* __restrict__ C,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc
) {
AOTI_TORCH_CHECK(N % 32 == 0, "N dimension must be multiple of 32");
AOTI_TORCH_CHECK(K % 2 == 0, "K dimension must be multiple of 2");
// TODO(jgong5): loop unroll for M and N
for (int64_t n = 0; n < N; n += 32) {
for (int64_t m = 0; m < M; m += 32) {
int64_t block_m = std::min<int64_t>(M - m, 32);
int64_t m_tail = m;
if (block_m >= 32) {
cpp_bmm_micro_gemm_amx_kernel_32_2<accum>(
amx_state,
A + m * lda,
B + n,
C + m * ldc + n,
K,
lda,
ldb,
ldc,
16
);
block_m -= 32;
m_tail += 32;
}
else
if (block_m >= 16) {
cpp_bmm_micro_gemm_amx_kernel_16_2<accum>(
amx_state,
A + m * lda,
B + n,
C + m * ldc + n,
K,
lda,
ldb,
ldc,
16
);
block_m -= 16;
m_tail += 16;
}
if (block_m > 0) {
cpp_bmm_micro_gemm_amx_kernel_16_2<accum>(
amx_state,
A + m_tail * lda,
B + n,
C + m_tail * ldc + n,
K,
lda,
ldb,
ldc,
block_m
);
}
}
}
}
void threaded_mm(const bfloat16* X, const bfloat16* W, bfloat16* Y, const int64_t ks_b_index)
{
constexpr int64_t num_threads = 48;
constexpr int64_t N = 64;
constexpr int64_t K = 96;
constexpr int64_t Mr = 32;
constexpr int64_t Nr = 32;
constexpr int64_t Kr = 32;
constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr;
constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr;
constexpr int64_t M = static_cast<int64_t>(384L);
constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr;
constexpr int64_t Mt_blocks = 1;
constexpr int64_t Nt_blocks = 1;
constexpr int64_t Kt_blocks = 3;
constexpr int64_t Mc_blocks = 1;
constexpr int64_t Nc_blocks = 1;
constexpr int64_t Kc_blocks = 3;
constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks;
constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks;
constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
// make sure all partitions are assigned
AOTI_TORCH_CHECK(
Mt_blocks * Nt_blocks * Kt_blocks * 48 >= Mr_blocks * Nr_blocks * Kr_blocks,
"Not all partitions are assigned."
);
#pragma omp parallel num_threads(48)
{
const int tid = omp_get_thread_num();
const int64_t k_group_id = tid / num_Kt_blocks;
const int64_t k_slice_id = tid % num_Kt_blocks;
const int64_t n_group_id = k_group_id / num_Nt_blocks;
const int64_t n_slice_id = k_group_id % num_Nt_blocks;
const int64_t k_block_start = k_slice_id * Kt_blocks;
const int64_t k_block_end = std::min(k_block_start + Kt_blocks, Kr_blocks);
const int64_t n_block_start = n_slice_id * Nt_blocks;
const int64_t n_block_end = std::min(n_block_start + Nt_blocks, Nr_blocks);
const int64_t m_block_start = std::min(n_group_id * Mt_blocks, Mr_blocks);
const int64_t m_block_end = std::min(m_block_start + Mt_blocks, Mr_blocks);
const int64_t num_Mc_blocks_per_thread = (m_block_end - m_block_start + Mc_blocks - 1) / Mc_blocks;
AMXState amx_state;
auto _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); auto local_acc_buf = _local_acc_buf.get();
for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread;
const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks;
const int64_t m_start = mc * Mr;
const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
const int64_t m_size = m_end - m_start;
for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
const int64_t n_start = nc * Nr;
const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
const int64_t n_size = n_end - n_start;
// NB: assume we pad N, nc_block_end won't exceed padded N here.
const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end);
if (_local_acc_buf == nullptr) { _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); local_acc_buf = _local_acc_buf.get(); }
for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
int64_t k_start = kc * Kr;
int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K);
for (int64_t nci = nc; nci < nc_block_end; nci++) {
if (kc == k_block_start) {
cpp_bmm_micro_gemm<static_cast<bool>(false)>(
amx_state,
&(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]),
&(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]),
&(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]),
static_cast<int64_t>(m_end + ((-1L)*m_start)),
static_cast<int64_t>(Nr),
static_cast<int64_t>(k_end + ((-1L)*k_start)),
static_cast<int64_t>(96L),
static_cast<int64_t>(32L),
static_cast<int64_t>(Nc_blocks*Nr)
);
} else {
cpp_bmm_micro_gemm<static_cast<bool>(true)>(
amx_state,
&(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]),
&(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]),
&(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]),
static_cast<int64_t>(m_end + ((-1L)*m_start)),
static_cast<int64_t>(Nr),
static_cast<int64_t>(k_end + ((-1L)*k_start)),
static_cast<int64_t>(96L),
static_cast<int64_t>(32L),
static_cast<int64_t>(Nc_blocks*Nr)
);
}
}
}
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(m_end + ((-1L)*m_start)); x0+=static_cast<int64_t>(1L))
{
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(16));
auto tmp1 = at::vec::convert<bfloat16>(tmp0);
tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(16));
}
for(int64_t x1=static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1<static_cast<int64_t>(n_end + ((-1L)*n_start)); x1+=(static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))) == 0 ? 1 : static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))))))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))));
auto tmp1 = at::vec::convert<bfloat16>(tmp0);
tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))));
}
}
}
}
}
}
amx_state.release([]() { _tile_release(); });
}
}
void single_thread_mm(const bfloat16* X, const bfloat16* W, bfloat16* Y, const int64_t ks_b_index)
{
constexpr int64_t num_threads = 1;
constexpr int64_t N = 64;
constexpr int64_t K = 96;
constexpr int64_t Mr = 32;
constexpr int64_t Nr = 32;
constexpr int64_t Kr = 32;
constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr;
constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr;
constexpr int64_t M = static_cast<int64_t>(384L);
constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr;
constexpr int64_t Mt_blocks = 12;
constexpr int64_t Nt_blocks = 2;
constexpr int64_t Kt_blocks = 3;
constexpr int64_t Mc_blocks = 12;
constexpr int64_t Nc_blocks = 1;
constexpr int64_t Kc_blocks = 3;
constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks;
constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks;
constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
// make sure all partitions are assigned
AOTI_TORCH_CHECK(
Mt_blocks * Nt_blocks * Kt_blocks * 1 >= Mr_blocks * Nr_blocks * Kr_blocks,
"Not all partitions are assigned."
);
{
constexpr int tid = 0;
constexpr int64_t k_group_id = 0;
constexpr int64_t k_slice_id = 0;
constexpr int64_t n_group_id = 0;
constexpr int64_t n_slice_id = 0;
constexpr int64_t m_block_start = 0;
constexpr int64_t n_block_start = 0;
constexpr int64_t n_block_end = Nr_blocks;
constexpr int64_t k_block_start = 0;
constexpr int64_t k_block_end = Kr_blocks;
constexpr int64_t num_Mc_blocks_per_thread = num_Mc_blocks;
constexpr int64_t m_block_end = Mr_blocks;
AMXState amx_state;
auto _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); auto local_acc_buf = _local_acc_buf.get();
for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread;
const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks;
const int64_t m_start = mc * Mr;
const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
const int64_t m_size = m_end - m_start;
for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
const int64_t n_start = nc * Nr;
const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
const int64_t n_size = n_end - n_start;
// NB: assume we pad N, nc_block_end won't exceed padded N here.
const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end);
if (_local_acc_buf == nullptr) { _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); local_acc_buf = _local_acc_buf.get(); }
for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
int64_t k_start = kc * Kr;
int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K);
for (int64_t nci = nc; nci < nc_block_end; nci++) {
if (kc == k_block_start) {
cpp_bmm_micro_gemm<static_cast<bool>(false)>(
amx_state,
&(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]),
&(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]),
&(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]),
static_cast<int64_t>(m_end + ((-1L)*m_start)),
static_cast<int64_t>(Nr),
static_cast<int64_t>(k_end + ((-1L)*k_start)),
static_cast<int64_t>(96L),
static_cast<int64_t>(32L),
static_cast<int64_t>(Nc_blocks*Nr)
);
} else {
cpp_bmm_micro_gemm<static_cast<bool>(true)>(
amx_state,
&(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]),
&(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]),
&(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]),
static_cast<int64_t>(m_end + ((-1L)*m_start)),
static_cast<int64_t>(Nr),
static_cast<int64_t>(k_end + ((-1L)*k_start)),
static_cast<int64_t>(96L),
static_cast<int64_t>(32L),
static_cast<int64_t>(Nc_blocks*Nr)
);
}
}
}
{
{
#pragma GCC ivdep
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(m_end + ((-1L)*m_start)); x0+=static_cast<int64_t>(1L))
{
for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1+=static_cast<int64_t>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(16));
auto tmp1 = at::vec::convert<bfloat16>(tmp0);
tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(16));
}
for(int64_t x1=static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1<static_cast<int64_t>(n_end + ((-1L)*n_start)); x1+=(static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))) == 0 ? 1 : static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))))))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))));
auto tmp1 = at::vec::convert<bfloat16>(tmp0);
tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))));
}
}
}
}
}
}
amx_state.release([]() { _tile_release(); });
}
}
extern "C"
void cpp_bmm(const bfloat16* X, const bfloat16* W, bfloat16* Y)
{
const int64_t B = static_cast<int64_t>(5L);
constexpr int64_t num_threads = 48;
int64_t B_single_thread_block = (B / num_threads) * num_threads;
#pragma omp parallel for num_threads(48)
for (int64_t b_start = 0; b_start < B_single_thread_block; ++b_start) {
single_thread_mm(X, W, Y, b_start);
}
for (int64_t b_start = B_single_thread_block; b_start < B; ++b_start) {
threaded_mm(X, W, Y, b_start);
}
}
```
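For reference, a hedged sketch of how this autotuning path might be exercised (the `torch._inductor.config` knobs `max_autotune` and `max_autotune_gemm_backends` are assumed here and may differ across versions; shapes and dtype match the generated kernel above):
```
import torch
import torch._inductor.config as inductor_config

# Assumed config knobs; exact names/values may differ across PyTorch versions.
inductor_config.max_autotune = True
inductor_config.max_autotune_gemm_backends = "CPP,ATEN"

def f(x, w):
    return torch.bmm(x, w)

# Shapes and dtype matching the generated example above: B=5, M=384, K=96, N=64, BF16.
x = torch.randn(5, 384, 96, dtype=torch.bfloat16)
w = torch.randn(5, 96, 64, dtype=torch.bfloat16)
y = torch.compile(f)(x, w)
```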
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129772
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel
2024-12-06 04:54:00 +00:00
niklasz
3f457ee1f6
Fix AOT Graph capture not propagating non_blocking copy parameter to inductor codegen ( #136513 )
Fixes #136260
**Note**: this is my first code contribution to torch so please let me know if there's anything I need to fix/some other convention I should follow.
Regarding the bug, re-running the issue's reproduction code:
```
import torch
def fn(x):
return x.to(device="cuda", non_blocking=True)
inp = torch.randn(3, 4)
torch.compile(fn)(inp)
```
We now have `non_blocking` being passed on to codegen properly:
```
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code] TRACED GRAPH
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code] ===== pre insert_deferred_runtime_asserts __compiled_fn_1 =====
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code] <eval_with_key>.0 class GraphModule(torch.nn.Module):
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code] def forward(self, L_x_: "f32[3, 4]"):
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code] l_x_ = L_x_
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code] # File: /home/niklasz/Desktop/pytorch/temp/reproduction.py:4 in fn, code: return x.to(device="cuda", non_blocking=True)
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code] to: "f32[3, 4]" = l_x_.to(device = 'cuda', non_blocking = True); l_x_ = None
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code] return (to,)
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code] TRACED GRAPH
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code] ===== __compiled_fn_1 =====
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code] /home/niklasz/Desktop/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code] def forward(self, L_x_: "f32[3, 4][4, 1]cpu"):
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code] l_x_ = L_x_
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code] # File: /home/niklasz/Desktop/pytorch/temp/reproduction.py:4 in fn, code: return x.to(device="cuda", non_blocking=True)
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code] to: "f32[3, 4][4, 1]cuda:0" = l_x_.to(device = 'cuda', non_blocking = True); l_x_ = None
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code] return (to,)
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]
V0922 20:33:25.404000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:114] [0/0] [__aot_graphs] aot_config id: 0, fw_metadata=ViewAndMutationMeta(input_info=[InputAliasInfo(is_leaf=True, mutates_data=False, mutates_metadata=False, mutations_hidden_from_autograd=True, mutations_under_no_grad_or_inference_mode=False, mutation_inductor_storage_resize=False, mutates_storage_metadata=False, requires_grad=False, keep_input_mutations=True)], output_info=[OutputAliasInfo(output_type=<OutputType.non_alias: 1>, raw_type=<class 'torch._subclasses.functional_tensor.FunctionalTensor'>, base_idx=None, dynamic_dims=set(), requires_grad=False, functional_tensor=None)], num_intermediate_bases=0, keep_input_mutations=True, traced_tangents=[], subclass_inp_meta=[0], subclass_fw_graph_out_meta=[0], subclass_tangent_meta=[], is_train=False, traced_tangent_metas=None, num_symints_saved_for_bw=None, grad_enabled_mutation=None, deterministic=None, static_input_indices=[], tokens={}, indices_of_inputs_that_requires_grad_with_mutations_in_bw=[], bw_donated_idxs=None, num_backward_tokens=0),subclass_metadata=None
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs] TRACED GRAPH
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs] ===== Forward graph 0 =====
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs] /home/niklasz/Desktop/pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs] def forward(self, arg0_1: "f32[3, 4][4, 1]cpu"):
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs] # File: /home/niklasz/Desktop/pytorch/temp/reproduction.py:4 in fn, code: return x.to(device="cuda", non_blocking=True)
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs] device_put: "f32[3, 4][4, 1]cuda:0" = torch.ops.prims.device_put.default(arg0_1, device(type='cuda', index=0), True); arg0_1 = None
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs] convert_element_type: "f32[3, 4][4, 1]cuda:0" = torch.ops.prims.convert_element_type.default(device_put, torch.float32); device_put = None
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs] return (convert_element_type,)
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1134] [0/0] [__output_code] Output code written to: /tmp/torchinductor_niklasz/ha/chaai264g6ribfw3q2qhl6ayjtaqaavku5wivxtzw4nabgd6htsv.py
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] Output code:
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] # AOT ID: ['0_inference']
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from ctypes import c_void_p, c_long, c_int
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import torch
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import math
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import random
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import os
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import tempfile
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from math import inf, nan
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.hooks import run_intermediate_hooks
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.utils import maybe_profile
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.codegen.memory_planning import _align as align
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch import device, empty_strided
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.async_compile import AsyncCompile
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.select_algorithm import extern_kernels
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.codegen.multi_kernel import MultiKernelCall
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] aten = torch.ops.aten
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] inductor_ops = torch.ops.inductor
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] _quantized = torch.ops._quantized
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] assert_size_stride = torch._C._dynamo.guards.assert_size_stride
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] alloc_from_pool = torch.ops.inductor._alloc_from_pool
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] async_compile = AsyncCompile()
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] async_compile.wait(globals())
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] del async_compile
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] def call(args):
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] arg0_1, = args
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] args.clear()
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] assert_size_stride(arg0_1, (3, 4), (4, 1))
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] with torch.cuda._DeviceGuard(0):
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] torch.cuda.set_device(0)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] buf0 = empty_strided_cuda((3, 4), (4, 1), torch.float32)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] buf0.copy_(arg0_1, True)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] del arg0_1
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] return (buf0, )
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] def benchmark_compiled_module(times=10, repeat=10):
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._dynamo.testing import rand_strided
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.utils import print_performance
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] arg0_1 = rand_strided((3, 4), (4, 1), device='cpu', dtype=torch.float32)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] fn = lambda: call([arg0_1])
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] return print_performance(fn, times=times, repeat=repeat)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] if __name__ == "__main__":
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.wrapper_benchmark import compiled_module_main
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] compiled_module_main('None', benchmark_compiled_module)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
```
See the line `buf0.copy_(arg0_1, True)` above. Specific log setting used: `export TORCH_LOGS="graph_code,aot_graphs,output_code"`
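For completeness, a minimal sketch (assuming a CUDA device and pinned host memory) of the copy that the generated wrapper performs; the second positional argument of `Tensor.copy_` is `non_blocking`:
```
import torch

x = torch.randn(3, 4).pin_memory()  # pinned host memory so the copy can overlap
buf0 = torch.empty_strided((3, 4), (4, 1), device="cuda", dtype=torch.float32)
buf0.copy_(x, True)                 # non_blocking=True, as in the generated wrapper code
torch.cuda.synchronize()            # wait for the async copy before reading on the host
assert torch.equal(buf0.cpu(), x)
```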
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136513
Approved by: https://github.com/eellison
2024-10-01 00:32:47 +00:00