mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: 1. Added CudaFusionGuard as the custom TypeCheck for nvfuser; enabled dynamic shape support with profiling executor; 2. dropped support for legacy fuser; 3. re-enabled nvfuser tests; 4. added registration for profiling record to allow profiling on user specified nodes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/46452 Reviewed By: zou3519, anjali411 Differential Revision: D24364642 Pulled By: ngimel fbshipit-source-id: daf53a9a6b6636e1ede420a3a6d0397d4a8b450b
685 lines
24 KiB
C++
685 lines
24 KiB
C++
#include <torch/csrc/jit/codegen/cuda/scheduler.h>
|
|
|
|
#include <torch/csrc/jit/codegen/cuda/arith.h>
|
|
#include <torch/csrc/jit/codegen/cuda/executor_utils.h>
|
|
#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
|
|
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
|
|
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
|
|
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
|
|
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
|
|
#include <torch/csrc/jit/codegen/cuda/parser.h>
|
|
|
|
#include <ATen/cuda/CUDAContext.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
namespace fuser {
|
|
namespace cuda {
|
|
|
|
// Unroll factor used by the pointwise scheduler (scheduleFusion) when the
// fusion is not stochastic; currently 1, i.e. unrolling is effectively off.
constexpr int kUnrollFactor = 1;
|
|
|
|
namespace {
|
|
|
|
// Collect the positions of every reduction axis of `tv`, in increasing order.
std::vector<int> reductionAxes(TensorView* tv) {
  std::vector<int> axes;
  const size_t ndims = tv->nDims();
  for (size_t axis_idx = 0; axis_idx < ndims; ++axis_idx) {
    if (tv->axis(axis_idx)->isReduction()) {
      axes.push_back(static_cast<int>(axis_idx));
    }
  }
  return axes;
}
|
|
|
|
// Merge all reduction to the right side and returns total number of
// reduction axes (i.e. the count of reduction axes that existed before
// merging; 0 if the tensor has no reduction axis).
size_t mergeReduction(TensorView* tv) {
  // Position of the most recently visited reduction axis (-1 = none yet).
  int prev_i = -1;
  size_t num_merged = 0;
  // Walk axes right-to-left, merging each reduction axis into the one
  // previously seen to its right, collapsing them into a single axis.
  for (int i = static_cast<int>(tv->nDims()) - 1; i >= 0; i--) {
    if (!tv->axis(i)->isReduction()) {
      continue;
    }
    if (prev_i == -1) {
      // First reduction axis found; nothing to merge into yet.
      prev_i = i;
    } else {
      tv->merge(i, prev_i);
      prev_i = i;
    }
  }
  // If the merged reduction axis ended up leftmost, move it to the innermost
  // (rightmost) position. NOTE(review): intermediate positions are not
  // reordered here — callers such as scheduleReduction coalesce down to
  // exactly 2 dims, where position 0 is the only case needing the move.
  if (prev_i == 0) {
    tv->reorder({{prev_i, -1}});
  }

  // num_merged merges collapsed num_merged + 1 original reduction axes.
  return prev_i == -1 ? 0 : num_merged + 1;
}
|
|
|
|
// merge all non-reduction axes to the left side and returns total number of
// iteration axes (i.e. the count of non-reduction axes that existed before
// merging; 0 if every axis is a reduction axis).
size_t mergeNonReduction(TensorView* tv) {
  // Position of the most recently visited iteration axis (-1 = none yet).
  int prev_i = -1;
  size_t num_merged = 0;
  // Walk axes right-to-left, merging each iteration axis into the one
  // previously seen to its right, collapsing them into a single axis.
  for (int i = static_cast<int>(tv->nDims()) - 1; i >= 0; i--) {
    if (tv->axis(i)->isReduction()) {
      continue;
    }
    if (prev_i == -1) {
      // First iteration axis found; nothing to merge into yet.
      prev_i = i;
    } else {
      tv->merge(i, prev_i);
      prev_i = i;
    }
  }
  // Mirror of mergeReduction: move the merged iteration axis to the
  // outermost (leftmost) position if it isn't already there.
  if (prev_i != 0) {
    tv->reorder({{prev_i, 0}});
  }

  // num_merged merges collapsed num_merged + 1 original iteration axes.
  return prev_i == -1 ? 0 : num_merged + 1;
}
|
|
|
|
} // namespace
|
|
|
|
// This one is a total mess and it should go.
//
// Pointwise-only scheduler: flattens every output TensorView to 1-D, tiles it
// as [blocks, unroll, threads], inlines inputs via computeAt, and binds the
// tiles to BIDx / Unroll / TIDx. Always returns true. The `inputs` argument
// is currently unused.
bool scheduleFusion(Fusion* fusion, const at::ArrayRef<c10::IValue> inputs) {
  FUSER_PERF_SCOPE("scheduleFusion");

  FusionGuard fg(fusion);
  // maybe has_reduction for scheduling should be done on a per output tensor
  // basis.
  TORCH_INTERNAL_ASSERT(
      !fusion->hasReduction(), "This scheduler only handles pointwise ops.");
  // Unrolling is disabled for stochastic fusions (e.g. RNG ops), presumably
  // to keep per-element random-number generation order intact — TODO confirm.
  const bool disable_unroll = fusion->isStochastic();

  for (auto out_val : fusion->outputs()) {
    auto out = out_val->as<TensorView>();

    // Merge all dimensions because we're only supporting pointwise
    while (out->nDims() > 1) {
      out->merge(-2, -1);
    }
  }

  // Run through outputs, grab all inputs of outputs
  // squeeze with computeAt to set overall structure.
  for (auto output : fusion->outputs()) {
    if (output->getValType() != ValType::TensorView)
      continue;
    TensorView* out_tv = output->as<TensorView>();

    // Split into 128 which will be blockDim.x
    out_tv->split(0, kPwThreadX);
    // Split by the unroll factor (kUnrollFactor is currently 1, so this
    // split is trivial unless that constant changes).
    auto ur_factor = disable_unroll ? 1 : kUnrollFactor;
    out_tv->split(0, ur_factor);
  }

  // Second pass: inline producers into each output and bind the three
  // resulting axes [blocks, unroll, threads] to parallel dimensions.
  for (auto output : fusion->outputs()) {
    if (output->getValType() != ValType::TensorView)
      continue;
    TensorView* out_tv = output->as<TensorView>();
    for (Val* inp : fusion->inputsOf(output)) {
      if (inp->getValType().value() == ValType::TensorView)
        inp->as<TensorView>()->computeAt(out_tv, -1);
    }
    out_tv->axis(0)->parallelize(ParallelType::BIDx);
    out_tv->axis(1)->parallelize(ParallelType::Unroll);
    out_tv->axis(2)->parallelize(ParallelType::TIDx);
  }

  return true;
}
|
|
|
|
namespace {
|
|
// Largest power of 2 that is <= n, clamped below at 1 (so n <= 0 yields 1).
constexpr int lastPow2(int n) {
  // Smear the highest set bit of n into all lower bit positions; shifting by
  // 1, 2, 4, 8, 16 covers every bit of a 32-bit int.
  for (int shift = 1; shift <= 16; shift <<= 1) { // NOLINT(cppcoreguidelines-avoid-magic-numbers)
    n |= (n >> shift);
  }
  // n is now a run of ones ending at bit 0; subtracting the lower half
  // isolates the top bit. Clamp to at least 1.
  return std::max(1, n - (n >> 1));
}
|
|
|
|
// Derive launch/blocking parameters for a reduction with `red_elems` elements
// reduced per output and `red_outputs` outputs. `red_on_fastest_dim` selects
// between the two scheduling families used by scheduleReduction.
// NOTE: queries the *current* CUDA device, which may not be the device the
// kernel ultimately runs on (see warning below).
ReductionParams reductionHeuristic(
    int red_elems,
    int red_outputs,
    bool red_on_fastest_dim) {
  ReductionParams rparams;
  rparams.fastest_dim = red_on_fastest_dim;

  int gdimx = LaunchParams::UNINITIALIZED_VAL;
  int gdimy = LaunchParams::UNINITIALIZED_VAL;
  int bdimx = LaunchParams::UNINITIALIZED_VAL;
  int bdimy = LaunchParams::UNINITIALIZED_VAL;

  // 1. Initial Assumptions

  // Evaluate Dimensions of Reduction TensorView
  TORCH_INTERNAL_ASSERT(red_elems > 0 && red_outputs > 0);

  // 2. Initial Definition of Block Dimensions

  // Is fastest dimension a reduction dimension?
  if (rparams.fastest_dim) {
    // Too few elements to unroll: fall back to no unrolling.
    if (red_elems < rparams.loop_unroll) {
      rparams.loop_unroll = 1;
    }
    // bdimx covers the reduction, bdimy covers the outputs.
    bdimx = ceilDiv(red_elems, rparams.loop_unroll);
    bdimy = red_outputs;
  } else {
    // Outer reduction: bdimx covers the outputs, bdimy the reduction.
    bdimx = red_outputs;
    bdimy = red_elems;
  }

  // 3. Applying Power of 2 Blocking based on the Maximum Number of threads

  constexpr int kMaxNumThreads = 512;
  int num_threads = kMaxNumThreads;
  int device_warp_size = at::cuda::warp_size();

  // Round each block dimension down to a power of two, capped at the
  // thread budget.
  if (bdimx < num_threads) {
    bdimx = lastPow2(bdimx);
  } else {
    bdimx = num_threads;
  }

  if (bdimy < num_threads) {
    bdimy = lastPow2(bdimy);
  } else {
    bdimy = num_threads;
  }

  // Rebalance: cap bdimx at a warp, give bdimy the remaining threads, then
  // let bdimx grow back up to its pre-cap value if the budget allows.
  int bdimx_prev = bdimx;
  bdimx = std::min(bdimx, device_warp_size);
  bdimy = std::min(bdimy, num_threads / bdimx);
  bdimx = std::min(bdimx_prev, num_threads / bdimy);

  // 4. Distributing work across a block

  // Magic numbers of calculations allowed per thread.
  constexpr int kMinValuesPerThread = 16;
  constexpr int kMaxValuesPerThread = 256;

  int inputs_consumed_per_block_iter = 1;
  int red_elems_per_thread = red_elems;

  int outputs_produced_per_block_iter = 1;

  // Reduction is performed across warp threads (cross-thread reduction)
  if (rparams.fastest_dim) {
    inputs_consumed_per_block_iter *= bdimx;
    red_elems_per_thread =
        ceilDiv(red_elems_per_thread, inputs_consumed_per_block_iter);
    // Warp threads are applied across the output
  } else {
    outputs_produced_per_block_iter *= bdimx;
  }

  // Decision to do a cross-warp reduction per block
  if (red_elems_per_thread >= (bdimy * kMinValuesPerThread) ||
      red_elems_per_thread >= kMaxValuesPerThread || !rparams.fastest_dim) {
    inputs_consumed_per_block_iter *= bdimy;
    red_elems_per_thread = ceilDiv(red_elems_per_thread, bdimy);
    rparams.cross_block = true;
    rparams.mul_reds_per_blk = false;
    // Do multiple reductions per block
  } else {
    rparams.cross_block = false;
    rparams.mul_reds_per_blk = true;
    outputs_produced_per_block_iter *= bdimy;
  }

  // 5. Distributing work across blocks

  // WARNING: Current device for codegen may not be the target device
  int device_max_threads_per_multiprocessor =
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor;
  int device_multiprocessor_count =
      at::cuda::getCurrentDeviceProperties()->multiProcessorCount;

  // Target enough blocks to saturate every SM at this block size.
  int blocks_per_sm = device_max_threads_per_multiprocessor / (bdimx * bdimy);
  int target_grid_size = device_multiprocessor_count * blocks_per_sm;

  // Setting the number of blocks based on the number of outputs
  gdimx = ceilDiv(red_outputs, outputs_produced_per_block_iter);

  // Cross-block reductions (if necessary)
  if (rparams.cross_block && red_elems_per_thread >= kMaxValuesPerThread &&
      gdimx <= target_grid_size) {
    // Candidate block counts per output: fill the grid, keep at least
    // kMinValuesPerThread per thread, keep at most kMaxValuesPerThread.
    int blks_per_out_1 = ceilDiv(target_grid_size, gdimx);
    int blks_per_out_2 = ceilDiv(red_elems_per_thread, kMinValuesPerThread);
    int blks_per_out_3 = ceilDiv(red_elems_per_thread, kMaxValuesPerThread);
    int blks_per_output =
        std::max(std::min(blks_per_out_1, blks_per_out_2), blks_per_out_3);

    gdimy = std::max(1, blks_per_output);
    // If a cross-block reduction was generated
    if (blks_per_output > 1) {
      rparams.cross_grid = true;
    }
  }

  // Optional debug dump, enabled via environment variable.
  const char* debug_env = getenv("PYTORCH_CUDA_FUSER_RED_SCHED_DEBUG");
  if (debug_env && atoi(debug_env)) {
    std::cout << "\n===== Reduction Parameters ========" << std::endl
              << "Inputs:" << std::endl
              << "\tRed Elems: " << red_elems << " Red Outputs: " << red_outputs
              << " Red On Fastest Dim? " << red_on_fastest_dim << std::endl
              << "Reduction Characteristics:" << std::endl
              << "\tMultiple Reds Per Block? " << rparams.mul_reds_per_blk
              << " Cross Block? " << rparams.cross_block << " Cross Grid? "
              << rparams.cross_grid << std::endl
              << "Recommended Blocking:" << std::endl
              << "\tGridX: " << gdimx << " GridY: " << gdimy
              << " BlckX: " << bdimx << " BlckY: " << bdimy << std::endl
              << "====================================" << std::endl;
  }

  // gdimx and TIDz/BIDz slots are left uninitialized; the code generator
  // binds those from the tensor extents at launch time — TODO confirm.
  rparams.lparams = LaunchParams(
      LaunchParams::UNINITIALIZED_VAL,
      gdimy,
      LaunchParams::UNINITIALIZED_VAL,
      bdimx,
      bdimy,
      LaunchParams::UNINITIALIZED_VAL);
  return rparams;
}
|
|
} // anonymous namespace
|
|
|
|
// Compute reduction scheduling parameters for `red_tv` within `fusion`,
// evaluating the concrete extents of its root domain against the provided
// runtime inputs. Returns c10::nullopt when the fusion has no reduction.
// Asserts that red_tv is non-null, is itself a reduction TensorView, and
// that every root-domain extent can be inferred from `fusion_inputs`.
TORCH_CUDA_API c10::optional<ReductionParams> getReductionHeuristics(
    Fusion* fusion,
    const at::ArrayRef<c10::IValue>& fusion_inputs,
    TensorView* red_tv) {
  // NOTE(review): scope label kept as "scheduleReduction" for profiling
  // continuity even though this function only computes the heuristic.
  FUSER_PERF_SCOPE("scheduleReduction");

  FusionGuard fg(fusion);

  if (!fusion->hasReduction()) {
    return c10::nullopt;
  }

  // Validate red_tv BEFORE dereferencing it. (The original code called
  // red_tv->getRootDomain() first and only asserted non-null afterwards,
  // and also repeated the hasReduction() early-return a second time.)
  TORCH_INTERNAL_ASSERT(
      red_tv != nullptr, "Reduction TensorView wasn't found.");

  TORCH_INTERNAL_ASSERT(
      red_tv->hasReduction(), "TensorView doesn't have a reduction.");

  // A reduction is on the "fastest" dimension when the innermost root
  // domain axis is a reduction axis.
  auto red_root_dom = red_tv->getRootDomain();
  const bool red_on_fastest_dim =
      red_root_dom[red_root_dom.size() - 1]->isReduction();

  const auto red_expr = fusion->origin(red_tv);

  TORCH_INTERNAL_ASSERT(
      red_expr->getExprType() != c10::nullopt &&
          red_expr->getExprType().value() == ExprType::ReductionOp,
      "TensorView doesn't have a reduction.");

  // Bind runtime input sizes so symbolic extents can be evaluated.
  StatefulExpressionEvaluator evaluator(
      executor_utils::statefulBindInputs(fusion_inputs, fusion));

  int64_t red_outputs = 1;
  int64_t red_elements = 1;

  // Accumulate the product of reduction extents (elements reduced per
  // output) and non-reduction extents (number of outputs).
  for (auto id : red_tv->getRootDomain()) {
    auto inferred_val = evaluator.inferValue(id->rawExtent());
    TORCH_INTERNAL_ASSERT(
        inferred_val.has_value(), "Error inferring reduction size.");
    if (id->isReduction()) {
      red_elements *= inferred_val.value();
    } else {
      red_outputs *= inferred_val.value();
    }
  }

  // NOTE(review): int64_t counts narrow to int here, matching
  // reductionHeuristic's signature; extents beyond INT_MAX would overflow.
  return reductionHeuristic(red_elements, red_outputs, red_on_fastest_dim);
}
|
|
|
|
// fusion is the input IR that will be modified by this function.
//
// Applies the reduction schedule described by `rparams` (see
// reductionHeuristic) to `red_tv`: coalesces axes to a 2-D
// [iteration, reduction] shape, splits/rFactors the reduction dimension
// according to the fastest_dim / mul_reds_per_blk / cross_block / cross_grid
// flags, binds axes to parallel types, and inlines producers via computeAt.
// `outs_of_red` are additional output TensorViews that receive matching
// splits and parallelization on their iteration axes.
void scheduleReduction(
    Fusion* fusion,
    const ReductionParams& rparams,
    TensorView* red_tv,
    std::vector<TensorView*> outs_of_red) {
  FusionGuard fg(fusion);

  // We coalesce all reduction axes to the right;
  mergeReduction(red_tv);

  // Merge all iteration dimensions
  mergeNonReduction(red_tv);
  for (auto iter_tv : outs_of_red) {
    mergeNonReduction(iter_tv);
  }

  // Evaluate Dimensions of Reduction TensorView
  auto red_ids = red_tv->domain()->domain();

  TORCH_INTERNAL_ASSERT(
      red_ids.size() == 2, "We coalesced all dimensions into 2 previously.");

  // Unroll split factor for the non-fastest-dim (outer reduction) schedules.
  constexpr int kLoopUnrollSplit = 4;

  // Scheduling the Reduction
  if (rparams.fastest_dim) {
    // Do multiple reductions per block
    if (rparams.mul_reds_per_blk) {
      // Reduction Splits
      // [outputs, |rF-Leftover, X-Warp, rf-Unroll|]
      // Idx:     0     |   1(-1)      2(-2)      3(-1)  |
      //                --------------------------------
      //                Reduction Dimensions
      red_tv->split(1, rparams.loop_unroll);
      red_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));

      // Output Splits
      // [|Out-Leftover, Out-PerBlock|, <Reduction Dims>]
      // Idx:  |     0             1      |   2(-2) -- 3(-1)
      //       ----------------------------
      //       Output Dimensions
      red_tv->split(0, NamedScalar::getParallelDim(ParallelType::TIDy));
      for (auto iter_tv : outs_of_red) {
        iter_tv->split(0, NamedScalar::getParallelDim(ParallelType::TIDy));
      }

      auto red_tv_rf = red_tv->rFactor({-3, -1});

      // WARNING: computeAt will coalesce the rFactored dimensions
      // rFactored Reduction Tensor after computeAt():
      // [<output dims>, | rF-Leftover, X-Warp, rF-Unroll|]
      // Idx:  0 -- 1    |    2(-3)      3(-2)     4(-1) |
      //                 ---------------------------------
      //                 Reduction Dimensions
      red_tv_rf->computeAt(red_tv, -1);

      // After the Reduction Tensor has rFactoring applied
      // Reduction Output Tensor:
      // [Out-Leftover, Out-PerBlock, X-Warp]
      // Idx:     0             1      2(-1)
      if (!outs_of_red.empty()) {
        red_tv->computeAt(outs_of_red[0], -1);
      }

      red_tv_rf->axis(-1)->parallelize(ParallelType::Unroll);

      red_tv->axis(0)->parallelize(ParallelType::BIDx);
      for (auto iter_tv : outs_of_red) {
        iter_tv->axis(0)->parallelize(ParallelType::BIDx);
      }
      red_tv->axis(1)->parallelize(ParallelType::TIDy);
      for (auto iter_tv : outs_of_red) {
        iter_tv->axis(1)->parallelize(ParallelType::TIDy);
      }
      red_tv->axis(-1)->parallelize(ParallelType::TIDx);

      // Bind Inputs to Reduction
      for (auto input : fusion->inputsOf(red_tv_rf)) {
        if (input->getValType().value() == ValType::TensorView) {
          input->as<TensorView>()->computeAt(red_tv_rf, -1);
        }
      }
      // Do a cross-warp reduction per block
    } else {
      if (rparams.cross_grid) {
        // Reduction Splits
        // [outputs, |rF-Leftover, X-Grid, X-Block, X-Warp, rf-Unroll|]
        // Idx:     0     |   1(-5)     2(-4)    3(-3)   4(-2)    5(-1)  |
        //                -------------------------------------------------
        //                Reduction Dimensions
        red_tv->split(1, rparams.loop_unroll);
        red_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
        red_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDy));
        red_tv->split(1, NamedScalar::getParallelDim(ParallelType::BIDy));

        auto red_tv_rf = red_tv->rFactor(
            {-5, -1}); // NOLINT(cppcoreguidelines-avoid-magic-numbers)

        // WARNING: computeAt will coalesce the rFactored dimensions
        // rFactored Reduction Tensor after computeAt():
        // [Outputs, |X-Grid, X-Block, X-Warp, rF-Leftover, rF-Unroll|]
        // Idx:     0    |  1(-5)    2(-4)    3(-3)     4(-2)       5(-1)  |
        //               -------------------------------------------------
        //               Reduction Dimensions
        red_tv_rf->computeAt(red_tv, -1);

        // After the Reduction Tensor has rFactoring applied
        // Reduction Output Tensor:
        // [Outputs, X-Grid, X-Block, X-Warp]
        // Idx:     0   1(-3)    2(-2)   3(-1)

        if (!outs_of_red.empty()) {
          red_tv->computeAt(outs_of_red[0], -1);
        }

        red_tv_rf->axis(-1)->parallelize(ParallelType::Unroll);

        red_tv->axis(0)->parallelize(ParallelType::BIDx);
        for (auto iter_tv : outs_of_red) {
          iter_tv->axis(0)->parallelize(ParallelType::BIDx);
        }
        red_tv->axis(-1)->parallelize(ParallelType::TIDx);
        red_tv->axis(-2)->parallelize(ParallelType::TIDy);
        red_tv->axis(-3)->parallelize(ParallelType::BIDy);

        // Bind Inputs to Reduction
        for (auto input : fusion->inputsOf(red_tv_rf)) {
          if (input->getValType().value() == ValType::TensorView) {
            input->as<TensorView>()->computeAt(red_tv_rf, -1);
          }
        }
      } else {
        // Cross-block (but not cross-grid) reduction on the fastest dim.
        // Reduction Splits
        // [outputs, |rF-Leftover, X-Block, X-Warp, rf-Unroll|]
        // Idx:     0     |   1(-4)     2(-3)   3(-2)    4(-1)  |
        //                -----------------------------------------
        //                Reduction Dimensions
        red_tv->split(1, rparams.loop_unroll);
        red_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
        red_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDy));

        auto red_tv_rf = red_tv->rFactor({-4, -1});

        // WARNING: computeAt will coalesce the rFactored dimensions
        // rFactored Reduction Tensor after computeAt():
        // [Outputs, |X-Block, X-Warp, rF-Leftover, rF-Unroll|]
        // Idx:     0    |  1(-4)    2(-3)     3(-2)       4(-1)  |
        //               -----------------------------------------
        //               Reduction Dimensions
        red_tv_rf->computeAt(red_tv, -1);

        // After the Reduction Tensor has rFactoring applied
        // Reduction Output Tensor:
        // [Outputs, X-Block, X-Warp]
        // Idx:     0    1(-2)   2(-1)

        if (!outs_of_red.empty()) {
          red_tv->computeAt(outs_of_red[0], -1);
        }

        red_tv_rf->axis(-1)->parallelize(ParallelType::Unroll);

        red_tv->axis(0)->parallelize(ParallelType::BIDx);
        for (auto iter_tv : outs_of_red) {
          iter_tv->axis(0)->parallelize(ParallelType::BIDx);
        }
        red_tv->axis(-1)->parallelize(ParallelType::TIDx);
        red_tv->axis(-2)->parallelize(ParallelType::TIDy);

        // Bind Inputs to Reduction
        for (auto input : fusion->inputsOf(red_tv_rf)) {
          if (input->getValType().value() == ValType::TensorView) {
            input->as<TensorView>()->computeAt(red_tv_rf, -1);
          }
        }
      }
    }
  } else {
    // Reduction is NOT on the fastest dimension (outer reduction).
    if (rparams.cross_block) {
      if (rparams.cross_grid) {
        // Reduction Splits
        // [outputs, |rF-Leftover, rf-Unroll, X-Grid, X-Block|]
        // Idx:     0     |   1(-4)      2(-3)     3(-2)   4(-1)  |
        //                -----------------------------------------
        //                Reduction Dimensions
        red_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDy));
        red_tv->split(1, NamedScalar::getParallelDim(ParallelType::BIDy));
        red_tv->split(1, kLoopUnrollSplit);

        // Reordering the Unroll dimension eases applying computeAt()
        // for preceding operations and the rFactored Tensor.
        //                               |--- Reordered ----|
        //                               V                  V
        // [outputs, |rF-Leftover, X-Block, X-Grid, rF-Unroll|]
        // Idx:     0     |   1(-4)      2(-3)   3(-2)    4(-1)  |
        //                -----------------------------------------
        //                Reduction Dimensions
        red_tv->reorder({{-1, -3}, {-3, -1}});

        // Output Splits
        // [|Out-Leftover, Out-PerBlock|, <Reduction Dims>]
        // Idx:  |     0             1      |   2(-4) -- 5(-1)
        //       ----------------------------
        //       Output Dimensions
        red_tv->split(0, NamedScalar::getParallelDim(ParallelType::TIDx));
        for (auto iter_tv : outs_of_red) {
          iter_tv->split(0, NamedScalar::getParallelDim(ParallelType::TIDx));
        }

        auto red_tv_rf = red_tv->rFactor({-4, -1});

        // WARNING: computeAt will coalesce the rFactored dimensions
        // rFactored Reduction Tensor after computeAt():
        // [<output dims>, |X-Block, X-Grid, rF-Leftover, rF-Unroll|]
        // Idx:  0 -- 1    |  2(-4)   3(-3)     4(-2)        5(-1) |
        //                 -----------------------------------------
        //                 Reduction Dimensions
        red_tv_rf->computeAt(red_tv, -1);

        // After the Reduction Tensor has rFactoring applied
        // Reduction Output Tensor:
        // [Out-Leftover, Out-PerBlock, X-Block, X-Grid]
        // Idx:     0             1       2(-2)   3(-1)

        if (!outs_of_red.empty()) {
          red_tv->computeAt(outs_of_red[0], -1);
        }

        red_tv_rf->axis(-1)->parallelize(ParallelType::Unroll);

        red_tv->axis(0)->parallelize(ParallelType::BIDx);
        for (auto iter_tv : outs_of_red) {
          iter_tv->axis(0)->parallelize(ParallelType::BIDx);
          iter_tv->axis(1)->parallelize(ParallelType::TIDx);
        }

        red_tv->axis(-3)->parallelize(ParallelType::TIDx);
        red_tv->axis(-2)->parallelize(ParallelType::TIDy);
        red_tv->axis(-1)->parallelize(ParallelType::BIDy);

        // Bind Inputs to Reduction
        for (auto input : fusion->inputsOf(red_tv_rf)) {
          if (input->getValType().value() == ValType::TensorView) {
            input->as<TensorView>()->computeAt(red_tv_rf, -1);
          }
        }
      } else {
        // Reduction Splits
        // [outputs, |rF-Leftover, rf-Unroll, X-Block|]
        // Idx:     0     |   1(-3)      2(-2)     3(-1)  |
        //                ---------------------------------
        //                Reduction Dimensions
        red_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDy));
        red_tv->split(1, kLoopUnrollSplit);

        // Reordering the Unroll dimension eases applying computeAt()
        // for preceding operations and the rFactored Tensor.
        //                               |- Reordered -|
        //                               V             V
        // [outputs, |rF-Leftover, X-Block, rF-Unroll|]
        // Idx:     0     |   1(-3)      2(-2)    3(-1)  |
        //                ---------------------------------
        //                Reduction Dimensions
        red_tv->reorder({{-1, -2}, {-2, -1}});

        // Output Splits
        // [|Out-Leftover, Out-PerBlock|, <Reduction Dims>]
        // Idx:  |     0             1      |   2(-3) -- 4(-1)
        //       ----------------------------
        //       Output Dimensions
        red_tv->split(0, NamedScalar::getParallelDim(ParallelType::TIDx));
        for (auto iter_tv : outs_of_red) {
          iter_tv->split(0, NamedScalar::getParallelDim(ParallelType::TIDx));
        }

        auto red_tv_rf = red_tv->rFactor({-3, -1});

        // WARNING: computeAt will coalesce the rFactored dimensions
        // rFactored Reduction Tensor after computeAt():
        // [<output dims>, |X-Block, rF-Leftover, rF-Unroll|]
        // Idx:  0 -- 1    |  2(-3)     3(-2)        4(-1) |
        //                 ---------------------------------
        //                 Reduction Dimensions
        red_tv_rf->computeAt(red_tv, -1);

        // After the Reduction Tensor has rFactoring applied
        // Reduction Output Tensor:
        // [Out-Leftover, Out-PerBlock, X-Block]
        // Idx:     0             1       2(-1)

        if (!outs_of_red.empty()) {
          red_tv->computeAt(outs_of_red[0], -1);
        }

        red_tv_rf->axis(-1)->parallelize(ParallelType::Unroll);

        red_tv->axis(0)->parallelize(ParallelType::BIDx);
        for (auto iter_tv : outs_of_red) {
          iter_tv->axis(0)->parallelize(ParallelType::BIDx);
          iter_tv->axis(1)->parallelize(ParallelType::TIDx);
        }
        red_tv->axis(-2)->parallelize(ParallelType::TIDx);
        red_tv->axis(-1)->parallelize(ParallelType::TIDy);

        // Bind Inputs to Reduction
        for (auto input : fusion->inputsOf(red_tv_rf)) {
          if (input->getValType().value() == ValType::TensorView) {
            input->as<TensorView>()->computeAt(red_tv_rf, -1);
          }
        }
      }
    } else {
      // No cross-block reduction: each thread reduces serially; only the
      // output dimension is split and parallelized (no rFactor needed).
      red_tv->split(0, NamedScalar::getParallelDim(ParallelType::TIDx));
      for (auto iter_tv : outs_of_red) {
        iter_tv->split(0, NamedScalar::getParallelDim(ParallelType::TIDx));
      }

      if (!outs_of_red.empty()) {
        red_tv->computeAt(outs_of_red[0], -1);
      }

      red_tv->axis(0)->parallelize(ParallelType::BIDx);
      red_tv->axis(1)->parallelize(ParallelType::TIDx);
      for (auto iter_tv : outs_of_red) {
        iter_tv->axis(0)->parallelize(ParallelType::BIDx);
        iter_tv->axis(1)->parallelize(ParallelType::TIDx);
      }

      // Inline producers directly into red_tv (no rFactor tensor here).
      for (auto input : fusion->inputsOf(red_tv)) {
        if (input->getValType().value() == ValType::TensorView) {
          input->as<TensorView>()->computeAt(red_tv, -1);
        }
      }
    }
  }
}
|
|
|
|
} // namespace cuda
|
|
} // namespace fuser
|
|
} // namespace jit
|
|
} // namespace torch
|