[NVFuser] Upstream push 0809 (#83067)

Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes include:

- codegen improvements:
  1. removes unnecessary syncs via redundant-thread compute analysis
  2. symmetric API for BestEffortReplay
  3. supports merging trivial reductions
  4. Ampere async copy improvements
- bug fixes:
  1. vectorization bug fixes
  2. type inference patch: fixes upstream #81725
  3. segmenter bug fix with deterministic iteration ordering
- parser updates:
  1. added leaky_relu (exercised in the sketch after this list)
- scheduler:
  1. normalization scheduler cleanup
  2. simplifies matmul scheduling with the new transform propagator
  3. merges all dimensions in the pointwise scheduler
  4. various GEMM-related improvements
- debuggability:
  1. Nsight Compute support
  2. debug dump for InlinePropagator
  3. adds `UnaryOpType::Print`

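As a quick, hypothetical illustration of the parser update, the sketch below scripts a small pointwise function ending in `torch.nn.functional.leaky_relu`, which the CUDA fuser can now parse and fuse. This is not code from this PR: it assumes a CUDA build with nvfuser enabled (the `torch._C._jit_set_nvfuser_enabled` toggle is an assumption borrowed from the JIT test harness of this era) and relies on the profiling executor, so a few warm-up runs are needed before a fused kernel is generated.

```python
import torch

# Assumption: this toggle is available on nvfuser-enabled CUDA builds of this era.
torch._C._jit_set_nvfuser_enabled(True)

def f(x, y):
    # Pointwise chain ending in leaky_relu, now recognized by the nvfuser parser.
    return torch.nn.functional.leaky_relu(x * y + y, 0.01)

f_jit = torch.jit.script(f)

x = torch.randn(10, 128, device="cuda")
y = torch.randn_like(x)
for _ in range(3):  # warm-up so the profiling executor hands the graph to nvfuser
    out = f_jit(x, y)
assert torch.allclose(out, f(x, y))
```
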
Commits were squashed to work around (WAR) the GitHub API.
Commits actually included in this PR from the devel branch:

```
dfe02f3faed4c64477e5f5c678f21f33415d0195 Merge remote-tracking branch 'csarofeen/devel' into HEAD
16173732ecfafc4797e93c2449cfb778015a6c7a Add `TensorViewBuilder::shape(std::vector<Val*> shape)` (#1884)
7cfb7796bdcf055eb61d600b7b5c9df292950290 Merge pull request #1887 from csarofeen/upstream_merge_0803
3399f6de62061d30781de50ef1862bbfb1615173 Merge remote-tracking branch 'origin/viable/strict' into HEAD
01208f5bba3bc158d41ccbefa0ee2c5ceea7aedb Add `UnaryOpType::Print` which can be helpful for debugging (#1878)
0646522454aa715ef164c88a73fb8bdddc706805 Remove redundant TORCH_INTERNAL_ASSERT in lower_magic_zero.cpp (#1881)
7bc76aa219293a59e4166e258d76289fe13633ca Fix most inlined propagator for mismatched dims (#1875)
501f4aa270bf4dd47b0d2f4860bc6f23ebc32a38 Nonaffine swizzle formulation ep.2: Loop swizzle variant. (#1826)
d863d690f923047a85b5229a787118708f810741 Ampere async copy ep.2: circular buffering extension to support pipelined matmul operand load (#1827)
e0ae11a61c87cd998e88ddd79a496548171c31e0 Larger sized mma instructions to support full vectorization (#1824)
9bb4cf7a66b098f04c9d95a2d34ab2bceee151b3 fragment iteration to support fully unrolled mma ops (#1823)
a48270a18dc2d3accc2626758d14d5858ae55032 Merge all dims in pointwise scheduler (#1872)
172fb3673fb4aaf4c1e889922a4fc5c06cbd59f7 Make MostInlined and BestEffort inline propagation no longer assert replayed (#1868)
a64462a5ac2fcf57a177bf36b0f26c61a4e252a4 Allow trivial reduction to be merged (#1871)
440102bcda6eb1dcd42d5fa5aeab9d6b049956bc Symmetric API for BestEffortReplay (#1870)
d1caf330c08ea8002f7133ca655bbd5b28c4eb98 Some misc cleanups/refactor split out from #1854 (#1867)
1013eda50be38eac96c00ba781340ac199d5a136 Remove some welford specific logic. (#1864)
51589d36be5a101d06e641fe0400b39028b7cb81 Some cleanups on tests and heuristics params (#1866)
a6b3e70da5dee51dbc246347228ea21384e46ac3 Segmenter bug fix, and deterministic iteration ordering.  (#1865)
1b665b9b5e562d6f0caba5e7319e83e5df64104f Add nullptr checks to IrBuilder (#1861)
1cd9451d7493f631c2837ba07c1ea93a74e83a15 Simplify matmul scheduling with the new transform propagator.  (#1817)
bbc1fb9b8c454f557ab9fcf5b1c3cef9b9e136d0 Add leaky_relu operation (#1852)
e842a9bab5e9f7289b7ce33ee37a682b22373f49 Minor cleanup in pointwise scheduler (#1858)
9ee850ca2f7f51dd5269bffb1255e485f809282d Fix stringstream usage (#1857)
20a36c1e4f28c4ff9837e56784be2686d17435f3 Improve nsight compute support (#1855)
405910308301097297b55c34d560aab6a360e897 Remove debugging `true ||` from getPointwiseHeuristics (#1822)
01117bfe8fdfacdbfdcfba9a624cdf900fe044d4 Misc cleanup (#1853)
5cc64943dc381a568223140bce0f22163c01e29f Apply the magic-zero protection to each indexed domain individually for predicate indexing (#1846)
92e6f0207e3a89fe90fd5cd3ffc575dfd766ba00 Cleanup normalization scheduler (#1845)
db89c6591a2f21130599a93675e0615e55564e41 Type inference patch (#1848)
102fe93a4605ca465cda26ebaee4ba1af2026901 Add debug dump for InlinePropagator (#1847)
b7a4d93d375a6e2ddef483763c93ffddc62ec452 Redundant thread compute analysis to avoid un-necessary sync insertion (#1687)
942be5b256056d0e02877361b814ae6af32ca15f Upstream ci build fixes (#1842)
0b83645915029d67f9345aa4649b8c6f62b0061b Fix vectorization bug introduced in #1831 (#1840)
63630f1ae091180e541932a9d9dc598e0a9902dd Move MaxProducerPosUpdater into InlinePropagator::tearDown (#1825)
9135a963c01d97ba34b1a7d2f106e78a13fd6651 Fix transpose benchmark dtype (#1839)
2c9a6c02312d5bf4f83cde653b847b4f85849432 Add extra configurability to `parallelizeAllLike` (#1831)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D38543000](https://our.internmc.facebook.com/intern/diff/D38543000)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83067
Approved by: https://github.com/davidberard98
Author: jjsjann123 (2022-08-09 05:54:39 -07:00); committed by PyTorch MergeBot
Parent: ce8716f59a
Commit: df741c589f
109 changed files with 5870 additions and 3231 deletions

View File

@ -133,8 +133,8 @@ static void MagicScheduler_DivMaxSoftDropFwd(
std::vector<at::Tensor> cg_outputs;
auto norm_params = getPersistentHeuristics(&fusion, at_inputs);
TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
schedulePersistentKernel(&fusion, norm_params.value());
TORCH_CHECK(norm_params != nullptr, "Norm scheduler can't be used!");
schedulePersistentKernel(&fusion, *norm_params);
FusionExecutor fe;
fe.compileFusion(&fusion);
@ -143,7 +143,7 @@ static void MagicScheduler_DivMaxSoftDropFwd(
cudaDeviceSynchronize();
for (auto _ : benchmark_state) {
CudaKernelTimer timer;
cg_outputs = fe.runFusion({t0, t1}, norm_params.value().lparams);
cg_outputs = fe.runFusion({t0, t1}, norm_params->lparams);
benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
}
// Sync everything up before we're finished, don't want to run ahead on the
@ -193,8 +193,8 @@ static void MagicScheduler_DivMaxSoftDropBwd(
std::vector<at::Tensor> cg_outputs;
auto norm_params = getPersistentHeuristics(&fusion, at_inputs);
TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
schedulePersistentKernel(&fusion, norm_params.value());
TORCH_CHECK(norm_params != nullptr, "Norm scheduler can't be used!");
schedulePersistentKernel(&fusion, *norm_params);
FusionExecutor fe;
fe.compileFusion(&fusion);
@ -203,7 +203,7 @@ static void MagicScheduler_DivMaxSoftDropBwd(
cudaDeviceSynchronize();
for (auto _ : benchmark_state) {
CudaKernelTimer timer;
cg_outputs = fe.runFusion({t0, t1, t2, t3}, norm_params.value().lparams);
cg_outputs = fe.runFusion({t0, t1, t2, t3}, norm_params->lparams);
benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
}
// Sync everything up before we're finished, don't want to run ahead on the
@ -308,8 +308,8 @@ static void MagicScheduler_BiasDropoutAddLayernormFwd(
std::vector<at::Tensor> cg_outputs;
auto norm_params = getPersistentHeuristics(&fusion, at_inputs);
TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
schedulePersistentKernel(&fusion, norm_params.value());
TORCH_CHECK(norm_params != nullptr, "Norm scheduler can't be used!");
schedulePersistentKernel(&fusion, *norm_params);
FusionExecutor fe;
fe.compileFusion(&fusion);
@ -319,7 +319,7 @@ static void MagicScheduler_BiasDropoutAddLayernormFwd(
cudaDeviceSynchronize();
for (auto _ : benchmark_state) {
CudaKernelTimer timer;
cg_outputs = fe.runFusion(at_inputs, norm_params.value().lparams);
cg_outputs = fe.runFusion(at_inputs, norm_params->lparams);
benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
}
// Sync everything up before we're finished, don't want to run ahead on the
@ -423,8 +423,8 @@ static void MagicScheduler_BiasDropoutAddLayernormBwd1(
std::vector<at::Tensor> cg_outputs;
auto norm_params = getReductionHeuristics(&fusion, at_inputs);
TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
scheduleReduction(&fusion, norm_params.value());
TORCH_CHECK(norm_params != nullptr, "Norm scheduler can't be used!");
scheduleReduction(&fusion, *norm_params);
FusionExecutor fe;
fe.compileFusion(&fusion);
@ -434,7 +434,7 @@ static void MagicScheduler_BiasDropoutAddLayernormBwd1(
cudaDeviceSynchronize();
for (auto _ : benchmark_state) {
clearL2Cache();
cg_outputs = fe.runFusion(at_inputs, norm_params.value().lparams);
cg_outputs = fe.runFusion(at_inputs, norm_params->lparams);
benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
}
// Sync everything up before we're finished, don't want to run ahead on the
@ -534,8 +534,8 @@ static void MagicScheduler_BiasDropoutAddLayernormBwd2(
std::vector<at::Tensor> cg_outputs;
auto norm_params = getPersistentHeuristics(&fusion, at_inputs);
TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
schedulePersistentKernel(&fusion, norm_params.value());
TORCH_CHECK(norm_params != nullptr, "Norm scheduler can't be used!");
schedulePersistentKernel(&fusion, *norm_params);
FusionExecutor fe;
fe.compileFusion(&fusion);
@ -545,7 +545,7 @@ static void MagicScheduler_BiasDropoutAddLayernormBwd2(
cudaDeviceSynchronize();
for (auto _ : benchmark_state) {
CudaKernelTimer timer;
cg_outputs = fe.runFusion(at_inputs, norm_params.value().lparams);
cg_outputs = fe.runFusion(at_inputs, norm_params->lparams);
benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
}
// Sync everything up before we're finished, don't want to run ahead on the
@ -625,8 +625,8 @@ static void MagicScheduler_BiasDropoutAddLayernormBwd3(
std::vector<at::Tensor> cg_outputs;
auto norm_params = getReductionHeuristics(&fusion, at_inputs);
TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
scheduleReduction(&fusion, norm_params.value());
TORCH_CHECK(norm_params != nullptr, "Norm scheduler can't be used!");
scheduleReduction(&fusion, *norm_params);
FusionExecutor fe;
fe.compileFusion(&fusion);
@ -636,7 +636,7 @@ static void MagicScheduler_BiasDropoutAddLayernormBwd3(
cudaDeviceSynchronize();
for (auto _ : benchmark_state) {
CudaKernelTimer timer;
cg_outputs = fe.runFusion(at_inputs, norm_params.value().lparams);
cg_outputs = fe.runFusion(at_inputs, norm_params->lparams);
benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
}
// Sync everything up before we're finished, don't want to run ahead on the

View File

@ -69,8 +69,7 @@ static void NvFuserScheduler_Broadcast(
auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo();
auto executor_instance = compile_log.fusion_executor;
TORCH_INTERNAL_ASSERT(compile_log.pointwise_params.has_value());
auto params = toString(compile_log.pointwise_params.value());
auto params = toString(compile_log.params);
auto lparams = toString(compile_log.fusion_executor->lastLaunchParams());
benchmark_state.SetLabel(params + lparams);

View File

@ -65,8 +65,7 @@ static void NvFuserScheduler_Reduction(
auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo();
auto executor_instance = compile_log.fusion_executor;
TORCH_INTERNAL_ASSERT(compile_log.reduction_params.has_value());
auto rparams = toString(compile_log.reduction_params.value());
auto rparams = toString(compile_log.params);
auto lparams = toString(compile_log.fusion_executor->lastLaunchParams());
benchmark_state.SetLabel(rparams + lparams);

View File

@ -135,8 +135,7 @@ static void NvFuserScheduler_SBR(
auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo();
auto executor_instance = compile_log.fusion_executor;
TORCH_INTERNAL_ASSERT(compile_log.pointwise_params.has_value());
auto params = toString(compile_log.pointwise_params.value());
auto params = toString(compile_log.params);
auto lparams = toString(compile_log.fusion_executor->lastLaunchParams());
benchmark_state.SetLabel(params + lparams);
@ -238,8 +237,7 @@ static void NvFuserScheduler_SBR_Norm(
auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo();
auto executor_instance = compile_log.fusion_executor;
TORCH_INTERNAL_ASSERT(compile_log.pointwise_params.has_value());
auto params = toString(compile_log.pointwise_params.value());
auto params = toString(compile_log.params);
auto lparams = toString(compile_log.fusion_executor->lastLaunchParams());
benchmark_state.SetLabel(params + lparams);

View File

@ -84,8 +84,8 @@ static void setupTranspose(
return (is_transpose) ? transpose(tv, axes.first, axes.second) : tv;
};
auto input1 = makeContigTensor(num_dims);
auto input2 = makeContigTensor(num_dims);
auto input1 = makeContigTensor(num_dims, dtype);
auto input2 = makeContigTensor(num_dims, dtype);
fusion->addInput(input1);
fusion->addInput(input2);

View File

@ -89,6 +89,20 @@ std::string toString(PointwiseParams params) {
return ss.str();
}
std::string toString(const std::shared_ptr<HeuristicParams>& params) {
auto rparams = std::dynamic_pointer_cast<ReductionParams>(params);
if (rparams) {
return toString(*rparams);
}
auto pparams = std::dynamic_pointer_cast<PointwiseParams>(params);
if (pparams) {
return toString(*pparams);
}
TORCH_INTERNAL_ASSERT(
false,
"Unknown heuristic parameter type. Did you just add a new heuristic parameter type but forget to update here?");
}
std::string toString(LaunchParams lparams) {
std::stringstream ss;
lparams.toString();
@ -123,9 +137,7 @@ TensorView* makeContigTensor(size_t ndims, DataType dtype) {
.build();
}
TensorView* makeConcreteTensor(
std::vector<int64_t> shape,
DataType dtype) {
TensorView* makeConcreteTensor(std::vector<int64_t> shape, DataType dtype) {
return TensorViewBuilder().shape(shape).dtype(dtype).build();
}
@ -157,15 +169,9 @@ void runBenchmarkIterations(
auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo();
auto executor_instance = compile_log.fusion_executor;
if (compile_log.reduction_params.has_value()) {
auto rparams = toString(compile_log.reduction_params.value());
auto lparams = toString(compile_log.fusion_executor->lastLaunchParams());
benchmark_state.SetLabel(rparams + lparams);
} else if (compile_log.pointwise_params.has_value()){
auto pparams = toString(compile_log.pointwise_params.value());
auto lparams = toString(compile_log.fusion_executor->lastLaunchParams());
benchmark_state.SetLabel(pparams + lparams);
}
auto params = toString(compile_log.params);
auto lparams = toString(compile_log.fusion_executor->lastLaunchParams());
benchmark_state.SetLabel(params + lparams);
executor_instance->setMeasureKernelTimeFlag(true);

View File

@ -38,6 +38,7 @@ TensorView* makeContigConcreteTensor(
std::string toString(ReductionParams rparams);
std::string toString(PointwiseParams params);
std::string toString(const std::shared_ptr<HeuristicParams>& params);
std::string toString(LaunchParams lparams);
// Run benchmark iterations with provided inputs. If not segmented, report

View File

@ -717,8 +717,10 @@ libtorch_cuda_core_sources = [
"torch/csrc/jit/codegen/cuda/register_interface.cpp",
"torch/csrc/jit/codegen/cuda/root_domain_map.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/registry.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/utils.cpp",

View File

@ -669,6 +669,7 @@ class TestCudaFuser(JitTestCase):
torch.isreal,
torch.nn.functional.softplus,
torch.nn.functional.gelu,
torch.nn.functional.leaky_relu,
torch.nn.functional.silu,
torch.relu,
torch.sigmoid,
@ -4938,6 +4939,28 @@ class TestCudaFuser(JitTestCase):
t2_jit = torch.jit.script(t2)
self._run_helper(t2_jit, t2, x0, x1, x2, check_stride=True)
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_type_inference(self):
device = "cuda"
x0 = torch.randn(10, 128, device=device)
x1 = torch.rand_like(x0)
x2 = torch.rand_like(x0)
def t(x0, x1, x2, flag : bool = True):
x3 = 2.0 * x0
x4 = 2.0 * x1
x5 = 2.0 * x2
if flag:
return torch.stack([x3, x4, x5], dim=-1)
# second code path doesn't run through profiling
# hence would utilize type inference with profiling information
return x0 + x1 + x2
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x0, x1, x2, check_stride=True)
class TestEnableDisableCudaFuser(JitTestCase):
def setUp(self):

View File

@ -466,6 +466,7 @@ NVFUSER_DEFINE_UNARY_OP(relu, Relu)
NVFUSER_DEFINE_UNARY_OP(round, Round)
NVFUSER_DEFINE_UNARY_OP(silu, Silu)
NVFUSER_DEFINE_UNARY_OP(trunc, Trunc)
NVFUSER_DEFINE_UNARY_OP(print, Print)
#undef NVFUSER_DEFINE_UNARY_OP
Val* randlike(Val* v) {
@ -1430,12 +1431,6 @@ WelfordResult::WelfordResult(
TORCH_INTERNAL_ASSERT(avg->definition()->sameAs(n->definition()));
}
WelfordResult WelfordResult::rFactor(const std::vector<int>& axes) {
auto o_tv = avg->definition()->as<WelfordOp>()->out()->as<TensorView>();
auto rf_tvs = o_tv->rFactor(axes, std::vector<TensorView*>{avg, var_sum, n});
return WelfordResult{rf_tvs.at(0), rf_tvs.at(1), rf_tvs.at(2)};
}
// COMPOUND OPERATIONS
// add_alpha

View File

@ -107,8 +107,6 @@ class TORCH_CUDA_CU_API WelfordResult {
TensorView* in_avg,
TensorView* in_var_sum,
TensorView* in_n);
WelfordResult rFactor(const std::vector<int>& axes);
};
//! Welford operator on specified axes. This is currently the only scan op with
@ -253,6 +251,9 @@ TORCH_CUDA_CU_API TensorView* isposinf(TensorView*);
// isreal
TORCH_CUDA_CU_API Val* isreal(Val*);
TORCH_CUDA_CU_API TensorView* isreal(TensorView*);
// print
TORCH_CUDA_CU_API Val* print(Val*);
TORCH_CUDA_CU_API TensorView* print(TensorView*);
// Broadcasts inp based on bool vector. Size of broadcast bool vector should be
// the number of dims desired in the broadcasted tensor. This vector should be

View File

@ -474,7 +474,11 @@ class CudaKernelGenerator : private OptOutConstDispatch {
}
}
void handle(const kir::TensorIndex* ti) final {
//! Returns the sum of all indices in a TensorIndex,
//! or 0 if the indices vector is empty.
//! Used when lowering generic tensor indices and
//! mma fragment indices.
std::string genTensorIndex(const kir::TensorIndex* ti) {
bool first = true;
std::stringstream index;
for (auto* ind : ti->indices()) {
@ -490,12 +494,17 @@ class CudaKernelGenerator : private OptOutConstDispatch {
if (first) {
index << "0";
}
return index.str();
}
void handle(const kir::TensorIndex* ti) final {
bool is_volatile = ti->view()->getMemoryType() == MemoryType::Global &&
kernel_->summary().sync_map.needsRawSync(ti->view()).hasBID();
if (is_volatile) {
code_ << "*(volatile " << ti->getDataType().value() << "*)&";
}
code_ << varName(ti->view()) << "[" << index.str() << "]";
code_ << varName(ti->view()) << "[" << genTensorIndex(ti) << "]";
}
void handle(const ViewAsScalar* sv) final {
@ -621,7 +630,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
// Double buffered local tensors need indexed initialization,
// so will need to use `arraySet` option.
if (out_tv->getMemoryType() == MemoryType::Local &&
!out_tv->isDoubleBuffered()) {
!(out_tv->isDoubleBuffered() || out_tv->isCircularBuffered())) {
// Vectorized initialization
indent() << varName(out_tv) << ".set(" << gen(uop->in()) << ");\n";
} else {
@ -971,13 +980,13 @@ class CudaKernelGenerator : private OptOutConstDispatch {
}
}
std::string genArchString(MmaOptions options) {
std::string genArchString(MmaOptions::MacroType macro) {
std::stringstream ss;
if (isVolta(options.macro)) {
if (isVolta(macro)) {
ss << "Volta";
} else if (isTuring(options.macro)) {
} else if (isTuring(macro)) {
ss << "Turing";
} else if (isAmpere(options.macro)) {
} else if (isAmpere(macro)) {
ss << "Ampere";
} else {
TORCH_INTERNAL_ASSERT(false, "mma macro unknown arch");
@ -988,7 +997,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
std::string genMmaOp(const MmaOp* mma, bool init = false) {
std::stringstream ss;
auto options = mma->options();
ss << genArchString(options) << "::";
ss << genArchString(options.macro) << "::";
if (init) {
ss << "init";
}
@ -1013,14 +1022,17 @@ class CudaKernelGenerator : private OptOutConstDispatch {
auto options = mma->options();
auto in_a = mma->inA()->as<kir::TensorIndex>()->view();
auto dtype = in_a->getDataType().value();
indent() << kTab << "reinterpret_cast<Array<" << dtype << ","
indent() << kTab << "&(reinterpret_cast<Array<" << dtype << ","
<< getInputARegisterSize(options.macro) << ","
<< getInputARegisterSize(options.macro) << ">*>(&"
<< gen(mma->inA()) << "),\n";
indent() << kTab << "reinterpret_cast<Array<" << dtype << ","
<< varName(mma->inA()->as<kir::TensorIndex>()->view()) << ")["
<< genTensorIndex(mma->inA()->as<kir::TensorIndex>()) << "])"
<< ",\n";
indent() << kTab << "&(reinterpret_cast<Array<" << dtype << ","
<< getInputBRegisterSize(options.macro) << ","
<< getInputBRegisterSize(options.macro) << ">*>(&"
<< gen(mma->inB()) << ")";
<< varName(mma->inB()->as<kir::TensorIndex>()->view()) << ")["
<< genTensorIndex(mma->inB()->as<kir::TensorIndex>()) << "])";
}
void genMmaInitialization(const MmaOp* mma, const UnaryOp* uop) {
@ -2332,7 +2344,19 @@ class CudaKernelGenerator : private OptOutConstDispatch {
}
void handle(const kir::CpAsyncWait* cpasync_wait) final {
indent() << "Ampere::cpAsyncBarrier();\n";
if (cpasync_wait->keepStages() > 0) {
// Perform partial sync, see comment on kir::CpAsyncWait.
indent() << "Ampere::cpAsyncPartialBarrier<" << cpasync_wait->keepStages()
<< ">();\n";
} else {
// Perform sync all, see comment on kir::CpAsyncWait.
indent() << "Ampere::cpAsyncBarrier();\n";
}
}
void handle(const kir::CpAsyncCommit* cpasync_wait) final {
// Commit inflight cp.async transfers. See comment on kir::CpAsyncCommit.
indent() << "Ampere::cpAsyncCommit();\n";
}
void handle(const kir::GridSync* sync) final {
@ -2390,11 +2414,9 @@ class CudaKernelGenerator : private OptOutConstDispatch {
void handle(const kir::IntPair* int_pair) {
const auto def = int_pair->definition();
if (print_inline_) {
code_ << gen(def);
} else {
code_ << varName(int_pair);
}
TORCH_INTERNAL_ASSERT(
def != nullptr, "no support for un-inlined int pair yet.");
code_ << gen(def);
}
void handle(const kir::PairSelect* pair_select) {

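For the `CpAsyncCommit` and partial `CpAsyncWait` handlers in the codegen change above, the emitted `Ampere::cpAsyncCommit()` and `Ampere::cpAsyncPartialBarrier<N>()` runtime helpers are assumed to wrap the PTX `cp.async` group primitives. The standalone sketch below uses illustrative helper names and targets sm_80+ only; it is not nvfuser's actual runtime implementation, just the mapping those helpers are assumed to perform:

```cuda
// Minimal sketch of cp.async group semantics on sm_80+ (compile with -arch=sm_80).
// Helper names are illustrative; nvfuser's Ampere:: runtime helpers are assumed
// to lower to the same PTX instructions.

__device__ inline void cp_async_16B(void* smem_dst, const void* gmem_src) {
  // Issue one asynchronous 16-byte global->shared copy.
  unsigned smem = static_cast<unsigned>(__cvta_generic_to_shared(smem_dst));
  asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n" ::"r"(smem),
               "l"(gmem_src));
}

__device__ inline void cp_async_commit() {
  // Close the current group of issued copies (what cpAsyncCommit() emits).
  asm volatile("cp.async.commit_group;\n" ::);
}

template <int kKeepStages>
__device__ inline void cp_async_wait() {
  // Wait until at most kKeepStages committed groups remain in flight
  // (cpAsyncPartialBarrier<kKeepStages>()); kKeepStages == 0 waits for all,
  // matching the full cpAsyncBarrier() path.
  asm volatile("cp.async.wait_group %0;\n" ::"n"(kKeepStages));
}
```

Keeping `kKeepStages > 0` is what enables the circular-buffered (multi-stage) pipeline in this PR: the consumer waits only for the oldest stage while newer copies stay in flight.
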
View File

@ -14,6 +14,35 @@ namespace jit {
namespace fuser {
namespace cuda {
// Simple selector that only propagates across tensor views in the provided
// unordered_set. Will also propagate to all consumers of those tensors, and the
// siblings of those tensors.
class ComputeAtSelector : public MaxInfoSpanningTree::Selector {
std::unordered_set<TensorView*> selected_;
public:
virtual bool allowC2P(TensorView* from, TensorView* to) override {
return selected_.count(to) > 0;
}
virtual bool allowP2C(TensorView* from, TensorView* to) override {
// If the producer is in the selected set, then the consumer must also be
// replayed to obtain a compatible loop structure so that this producer
// can be consumed in this loop.
return selected_.count(from) > 0 || selected_.count(to) > 0;
}
virtual bool allowSibling(TensorView* from, TensorView* to) override {
return true;
}
ComputeAtSelector(std::unordered_set<TensorView*> selected)
: selected_(std::move(selected)) {}
const std::unordered_set<TensorView*>& selected() const {
return selected_;
}
};
namespace {
// Wrapper around set_intersection
@ -182,11 +211,10 @@ void ComputeAt::runAt(
FusionGuard fg(producer->fusion());
auto selected = getPropagationSubgraph(producer, consumer);
InlinePropagatorSelector selector(selected);
ComputeAtSelector selector(selected);
InlinePropagator inline_propagator(
consumer, consumer_position, mode, selector.selected());
MaxProducerPosUpdater updater;
MaxRootDomainInfoSpanningTree path(consumer, consumer_position, &selector);
@ -199,7 +227,6 @@ void ComputeAt::runAt(
}
path.traverse(&inline_propagator);
path.traverse(&updater);
}
void ComputeAt::runWith(
@ -224,11 +251,10 @@ void ComputeAt::runWith(
FusionGuard fg(producer->fusion());
auto selected = getPropagationSubgraph(producer, consumer);
InlinePropagatorSelector selector(selected);
ComputeAtSelector selector(selected);
InlinePropagator inline_propagator(
producer, producer_position, mode, selector.selected());
MaxProducerPosUpdater updater;
MaxRootDomainInfoSpanningTree path(producer, producer_position, &selector);
@ -240,7 +266,6 @@ void ComputeAt::runWith(
path.traverse(&propagator);
}
path.traverse(&inline_propagator);
path.traverse(&updater);
}
} // namespace cuda

View File

@ -178,6 +178,8 @@ void IterDomainGraph::build(Fusion* fusion) {
BestEffortReplay::replayPasC(p_tv, c_tv, -1, pairwise_map);
const auto& permissive_c2p_map = permissive_replay_PasC.getReplay();
const auto permissive_disjoint_sets =
permissive_replay_PasC.getDisjointSets();
// For exact mapings do not map any broadcast dimensions to
// non-broadcast dimensions. Prevent any broadcasted axes being mapped
@ -213,6 +215,17 @@ void IterDomainGraph::build(Fusion* fusion) {
auto p_id = entry.second;
if (idIsAComputeAtLeafDomain(p_id, p_tv)) {
loop_nodes_.mapEntries(c_id, p_id);
} else {
// When there are trivial reductions merged with other dims, `p_id`
// might not be a compute at leaf domain of `p_tv`, but it actually
// has an equivalent compute at leaf domain. For that case, we map
// the equivalent compute at leaf domain.
for (int i = 0; i < p_tv->getComputeAtPosition(); i++) {
auto id = p_tv->axis(i);
if (permissive_disjoint_sets.permissiveAreMapped(p_id, id)) {
loop_nodes_.mapEntries(c_id, id);
}
}
}
permissive_nodes_.mapEntries(c_id, p_id);
consumers_.at(p_id).pushBack(c_id);
@ -225,8 +238,8 @@ void IterDomainGraph::build(Fusion* fusion) {
mapMaybeSwizzleOp(permissive_nodes_, c_id);
}
// Make sure we always get root mapping for the permissive map. Because
// of forwarding we could otherwise miss some root mappings.
// Make sure we always get root mapping for the permissive map.
// Because of forwarding we could otherwise miss some root mappings.
for (auto entry : permissive_c2p_root_map) {
auto c_id = entry.first;
auto p_id = entry.second;

View File

@ -96,6 +96,47 @@ class VectorOfUniqueEntries {
return set_.find(entry) != set_.end();
}
// Erase given entry from the containers if
// there is a match.
void erase(T entry) {
vector_.erase(
std::remove_if(
vector_.begin(),
vector_.end(),
[entry](T val) { return val == entry; }),
vector_.end());
set_.erase(entry);
}
// Insert elements at the end of the container.
template <typename InputIt>
void insert(InputIt begin, InputIt end) {
for (auto it = begin; it != end; it++) {
pushBack(*it);
}
}
// Returns iterator pointing to the beginning of vector container
auto begin() const {
return vector().begin();
}
// Returns iterator pointing to the end of vector container
auto end() const {
return vector().end();
}
// Returns iterator pointing to the beginning of vector container
auto begin() {
return vector().begin();
}
// Returns iterator pointing to the end of vector container
auto end() {
return vector().end();
}
std::string toString() {
std::stringstream ss;
ss << "{ ";

View File

@ -163,6 +163,9 @@ void Expr::dispatch(T handler, Expr* expr) {
case ExprType::CpAsyncWait:
ptr(handler)->handle(expr->as<kir::CpAsyncWait>());
return;
case ExprType::CpAsyncCommit:
ptr(handler)->handle(expr->as<kir::CpAsyncCommit>());
return;
case ExprType::InitMagicZero:
ptr(handler)->handle(expr->as<kir::InitMagicZero>());
return;
@ -334,6 +337,9 @@ void Expr::constDispatch(T handler, const Expr* expr) {
case ExprType::CpAsyncWait:
ptr(handler)->handle(expr->as<kir::CpAsyncWait>());
return;
case ExprType::CpAsyncCommit:
ptr(handler)->handle(expr->as<kir::CpAsyncCommit>());
return;
case ExprType::InitMagicZero:
ptr(handler)->handle(expr->as<kir::InitMagicZero>());
return;
@ -513,6 +519,9 @@ void Expr::mutatorDispatch(T mutator, Expr* expr) {
case ExprType::CpAsyncWait:
ptr(mutator)->mutate(expr->as<kir::CpAsyncWait>());
return;
case ExprType::CpAsyncCommit:
ptr(mutator)->mutate(expr->as<kir::CpAsyncCommit>());
return;
case ExprType::InitMagicZero:
ptr(mutator)->mutate(expr->as<kir::InitMagicZero>());
return;
@ -757,6 +766,9 @@ void OptOutConstDispatch::handle(const kir::GridSync* stmt) {
void OptOutConstDispatch::handle(const kir::CpAsyncWait* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::CpAsyncCommit* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::InitMagicZero* stmt) {
unhandled(stmt);
}
@ -898,6 +910,9 @@ void OptOutDispatch::handle(kir::GridSync* stmt) {
void OptOutDispatch::handle(kir::CpAsyncWait* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(kir::CpAsyncCommit* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(kir::InitMagicZero* stmt) {
unhandled(stmt);
}

View File

@ -98,6 +98,7 @@ class Allocate;
class BlockSync;
class GridSync;
class CpAsyncWait;
class CpAsyncCommit;
class ForLoop;
class IfThenElse;
class GridReduction;
@ -163,6 +164,7 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
virtual void handle(const kir::BlockSync*);
virtual void handle(const kir::GridSync*);
virtual void handle(const kir::CpAsyncWait*);
virtual void handle(const kir::CpAsyncCommit*);
virtual void handle(const kir::InitMagicZero*);
virtual void handle(const kir::UpdateMagicZero*);
virtual void handle(const kir::ForLoop*);
@ -225,6 +227,7 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
virtual void handle(kir::BlockSync* stmt);
virtual void handle(kir::GridSync* stmt);
virtual void handle(kir::CpAsyncWait* stmt);
virtual void handle(kir::CpAsyncCommit* stmt);
virtual void handle(kir::InitMagicZero* stmt);
virtual void handle(kir::UpdateMagicZero* stmt);
virtual void handle(kir::ForLoop* stmt);
@ -328,6 +331,7 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
virtual void mutate(kir::BlockSync*);
virtual void mutate(kir::GridSync*);
virtual void mutate(kir::CpAsyncWait*);
virtual void mutate(kir::CpAsyncCommit*);
virtual void mutate(kir::InitMagicZero*);
virtual void mutate(kir::UpdateMagicZero*);
virtual void mutate(kir::ForLoop*);

View File

@ -93,7 +93,9 @@ std::string FusionExecutor::getStructuredCode(const std::string& kernel) {
std::cout << "\n======= Codegen output for kernel: " << kernelName()
<< " =======\n\n"
<< code << "\n======================================\n\n";
} else if (isDebugDumpEnabled(DebugDumpOption::CudaToFile)) {
}
if (isDebugDumpEnabled(DebugDumpOption::CudaToFile) ||
isDebugDumpEnabled(DebugDumpOption::DebugInfo)) {
std::stringstream file_name;
file_name << "__tmp_kernel" << fusion_id_ << ".cu";
std::cout << "PRINTING: " << file_name.str() << std::endl;

View File

@ -799,7 +799,7 @@ kir::ExpressionEvaluator bindKernelInputs(
extent->toString(),
" to ",
value,
"but it's already set to ",
" but it's already set to ",
*prev_value);
should_bind = false;
}
@ -925,9 +925,12 @@ std::pair<NvrtcFunction, std::string> nvrtcCompile(
nvrtcProgram program; // NOLINT(cppcoreguidelines-init-variables)
{
std::stringstream ss;
ss << "__tmp_kernel" << id << ".cu";
std::string name = ss.str();
FUSER_PERF_SCOPE("executor_utils::NvrtcCreateProgram");
AT_CUDA_NVRTC_CHECK(at::globalContext().getNVRTC().nvrtcCreateProgram(
&program, code.c_str(), nullptr, 0, nullptr, nullptr));
&program, code.c_str(), name.c_str(), 0, nullptr, nullptr));
}
ResourceGuard holdProgram([&] {
@ -974,11 +977,13 @@ std::pair<NvrtcFunction, std::string> nvrtcCompile(
args.push_back("--fmad=true");
}
#endif
#ifndef NDEBUG
// Add line info to generated kernels
args.push_back("-lineinfo");
#else
if (isDebugDumpEnabled(DebugDumpOption::DebugInfo)) {
args.push_back("-lineinfo");
args.push_back("-G");
args.push_back("--dopt=on");
}
#ifdef NDEBUG
// Avoid excessive register usage from assertion
args.push_back("-DNDEBUG");
#endif

View File

@ -1,5 +1,6 @@
#include <torch/csrc/jit/codegen/cuda/arith.h>
#include <torch/csrc/jit/codegen/cuda/codegen.h>
#include <torch/csrc/jit/codegen/cuda/disjoint_set.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/fusion_segmenter.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
@ -229,15 +230,6 @@ void Fusion::addOutput(Val* output) {
all_tv_uses_valid_ = false;
}
void Fusion::addOutput(WelfordResult& wr) {
// Want to always make sure the avg gets added last
// since avg will be the out() value of welfordOp,
// and want to make it the top of the computeAt chain
addOutput(wr.var_sum);
addOutput(wr.n);
addOutput(wr.avg);
}
void Fusion::removeInput(Val* input) {
auto find_input = std::find(inputs_.begin(), inputs_.end(), input);
if (find_input != inputs_.end()) {
@ -516,7 +508,19 @@ std::vector<Val*> Fusion::usedMathVals() {
return used_math_vals;
}
std::unordered_set<Expr*> Fusion::unordered_uses(Val* val) const {
std::vector<Val*> Fusion::terminatingMathVals() {
VectorOfUniqueEntries<Val*> result;
auto used_vals = usedMathVals();
for (auto v : used_vals) {
// Locate the vals that are not expr outputs but have valid definitions.
if (unordered_uses(v).empty() && v->definition() != nullptr) {
result.pushBack(v);
}
}
return result.vector();
}
std::unordered_set<Expr*> Fusion::unordered_uses(const Val* val) const {
return std::unordered_set<Expr*>(val->uses().begin(), val->uses().end());
}

View File

@ -110,9 +110,6 @@ class TORCH_CUDA_CU_API Fusion : public IrContainer {
//! Register output as an output of the fusion
void addOutput(Val* output);
//! Register output as an output of the fusion
void addOutput(WelfordResult& output);
//! Deregister input as an input of the fusion
void removeInput(Val* input);
@ -153,8 +150,16 @@ class TORCH_CUDA_CU_API Fusion : public IrContainer {
//! also included as they must show up in the final code.
std::vector<Val*> usedMathVals();
//! Returns all vals that are produced by used math expressions and
//! also do not have further consumers.
//!
//! In the case of an active multi-output expression, the returned vector
//! will include the expression outputs that did not lead to a fusion
//! output.
std::vector<Val*> terminatingMathVals();
//! Return all Exprs that use val
std::unordered_set<Expr*> unordered_uses(Val* val) const;
std::unordered_set<Expr*> unordered_uses(const Val* val) const;
//! Return the Expr that produces val
Expr* definition(const Val* val) const;

View File

@ -16,6 +16,12 @@ namespace jit {
namespace fuser {
namespace cuda {
namespace {
using GroupSet = VectorOfUniqueEntries<SegmentedGroup*>;
} // namespace
std::vector<SegmentedGroup::NeighborGroup> SegmentedGroup::getNeighborGroups() {
std::vector<NeighborGroup> neighbors;
for (auto inp : producer_edges) {
@ -75,7 +81,7 @@ std::vector<SegmentedGroup::NeighborGroup> SegmentedGroup::
return {};
}
std::vector<bool> can_merge(true, neighbors.size());
std::vector<bool> can_merge(neighbors.size(), true);
// Find neighbors with a level that is only 1 differant than this groups level
for (const auto i : c10::irange(neighbors.size())) {
@ -155,16 +161,16 @@ void insertUniquePredicated(
std::vector<Val*>& v,
const std::vector<SegmentedEdge*>& e,
PREDICATE pred) {
std::unordered_set<Val*> to_add;
std::transform(
e.cbegin(),
e.cend(),
std::inserter(to_add, to_add.end()),
[](SegmentedEdge* se) { return se->val; });
VectorOfUniqueEntries<Val*> to_add;
for (auto edge : e) {
to_add.pushBack(edge->val);
}
std::copy_if(
to_add.begin(), to_add.end(), std::back_inserter(v), [pred](Val* val) {
return pred(val);
});
to_add.vector().begin(),
to_add.vector().end(),
std::back_inserter(v),
[pred](Val* val) { return pred(val); });
}
void SegmentedGroup::finalize() {
@ -811,7 +817,6 @@ void SegmentedFusion::finalize() {
//! currently O(n^2). O(nlogn) would be a reasonable
//! goal to achieve.
class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
using GroupSet = std::unordered_set<SegmentedGroup*>;
using GroupSetOwningPtr = std::unique_ptr<GroupSet>;
using DependencyMap = std::unordered_map<SegmentedGroup*, GroupSetOwningPtr>;
@ -829,7 +834,7 @@ class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
const std::vector<SegmentedGroup*>& groups_to_check) {
auto& producers_of_group = getAllKnownProducersSet(group);
for (const auto& potential_producer : groups_to_check) {
if (producers_of_group->count(potential_producer)) {
if (producers_of_group->has(potential_producer)) {
return true;
}
}
@ -841,7 +846,7 @@ class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
if (it == known_producers_of_.end()) {
return false;
}
return it->second->count(b);
return it->second->has(b);
}
bool isProducerOf(SegmentedGroup* a, SegmentedGroup* b) {
@ -872,18 +877,14 @@ class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
GroupSet values_between;
auto& all_producers_of_consumer = known_producers_of_.at(consumer);
TORCH_INTERNAL_ASSERT(
all_producers_of_consumer->count(producer),
all_producers_of_consumer->has(producer),
"Fusion segment: Trying to compute path between two nodes that are not producer-consumer pairs");
std::copy_if(
all_producers_of_consumer->begin(),
all_producers_of_consumer->end(),
std::inserter(values_between, values_between.end()),
[this, producer](SegmentedGroup* producer_of_consumer) {
// Checks if producer is on the producer path of this intermediate
// node
return known_producers_of_.at(producer_of_consumer)->count(producer);
});
for (auto producer_of_consumer : *all_producers_of_consumer) {
if (known_producers_of_.at(producer_of_consumer)->has(producer)) {
values_between.pushBack(producer_of_consumer);
}
}
return values_between;
}
@ -892,7 +893,7 @@ class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
//! used for generating assertions after transforms
bool isproducerMapDAG() const {
for (auto& it : known_producers_of_) {
if (it.second->count(it.first)) {
if (it.second->has(it.first)) {
return false;
}
}
@ -909,7 +910,7 @@ class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
void addConsumersToWorkList(SegmentedGroup* producer, GroupSet& to_visit) {
for (auto e : producer->consumer_edges) {
// A consumer wouldn't have been worked before any of its producer
to_visit.insert(e->to);
to_visit.pushBack(e->to);
}
}
@ -922,7 +923,7 @@ class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
SegmentedGroup* from) {
auto& producer_set_to_merge = *getAllKnownProducersSet(from);
for (auto group : producer_set_to_merge) {
getAllKnownProducersSet(into)->insert(group);
getAllKnownProducersSet(into)->pushBack(group);
}
}
@ -943,8 +944,8 @@ class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
GroupSet intersection;
for (auto group : smaller_group_set) {
if (bigger_group_set.count(group)) {
intersection.insert(group);
if (bigger_group_set.has(group)) {
intersection.pushBack(group);
}
}
return intersection;
@ -956,7 +957,7 @@ class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
};
//! Finds the common producers of given set of groups
GroupDependencyAnalysis::GroupSet GroupDependencyAnalysis::getCommonProducersOf(
GroupSet GroupDependencyAnalysis::getCommonProducersOf(
std::vector<SegmentedGroup*> groups) {
if (groups.empty()) {
return {};
@ -1006,9 +1007,9 @@ void GroupDependencyAnalysis::mergeGroups(
// update producer maps of other groups
for (auto& it : known_producers_of_) {
// for all groups that are produced by either a or b
if (it.second->count(a) || it.second->count(b)) {
if (it.second->has(a) || it.second->has(b)) {
// insert ab as the new producer
it.second->insert(ab);
it.second->pushBack(ab);
// all producers of both a and b are now producers of `it`
mergeAllKnownProducersIntoFrom(it.first, ab);
}
@ -1054,7 +1055,7 @@ void GroupDependencyAnalysis::mergeGroups(
it.second->erase(merged_producer);
}
// insert the new group as producer
it.second->insert(merged);
it.second->pushBack(merged);
}
}
}
@ -1068,11 +1069,11 @@ void GroupDependencyAnalysis::computeAllProducers() {
// Collect source nodes, with no producers we are guaranteed
// a source node on a DAG
std::copy_if(
segmented_fusion_->cgroups().begin(),
segmented_fusion_->cgroups().end(),
std::inserter(visited, visited.end()),
[](SegmentedGroup* group) { return group->producer_edges.empty(); });
for (auto group : segmented_fusion_->cgroups()) {
if (group->producer_edges.empty()) {
visited.pushBack(group);
}
}
// visited now only contain source nodes
// they can go backward to nowhere
@ -1086,20 +1087,18 @@ void GroupDependencyAnalysis::computeAllProducers() {
if (std::all_of(
visiting_group->producer_edges.begin(),
visiting_group->producer_edges.end(),
[&visited](SegmentedEdge* e) {
return visited.count(e->from);
})) {
[&visited](SegmentedEdge* e) { return visited.has(e->from); })) {
// filter multi-edges
GroupSet producers_of_visiting_group;
for (auto edge : visiting_group->producer_edges) {
producers_of_visiting_group.insert(edge->from);
producers_of_visiting_group.pushBack(edge->from);
}
// populate all possible paths
// from producer backward, including
// the producer
for (auto producer : producers_of_visiting_group) {
getAllKnownProducersSet(visiting_group)->insert(producer);
getAllKnownProducersSet(visiting_group)->pushBack(producer);
mergeAllKnownProducersIntoFrom(visiting_group, producer);
}
to_update = visiting_group;
@ -1109,7 +1108,7 @@ void GroupDependencyAnalysis::computeAllProducers() {
if (to_update) {
addConsumersToWorkList(to_update, to_visit);
to_visit.erase(to_update);
visited.insert(to_update);
visited.pushBack(to_update);
} else {
TORCH_INTERNAL_ASSERT(false, "unreachable, original graph not a DAG");
}
@ -2060,7 +2059,6 @@ bool SegmentCandidateFinder::TranslateWelfordInFusion(
//! This pass tries to merge nodes with the same reduction type based
//! on the graph structure.
class CombineReductions {
using GroupSet = std::unordered_set<SegmentedGroup*>;
using GroupVec = std::vector<SegmentedGroup*>;
class ReductionSignature;
@ -2240,7 +2238,7 @@ class CombineReductions {
groups_with_reductions_.begin(),
groups_with_reductions_.end(),
[&all_groups_to_merge](SegmentedGroup* group) {
return all_groups_to_merge.count(group);
return all_groups_to_merge.has(group);
}),
groups_with_reductions_.end());
@ -2374,7 +2372,7 @@ class CombineReductions {
groups_with_reductions_.begin(),
groups_with_reductions_.end(),
[&groups_to_merge_set](SegmentedGroup* group) {
return groups_to_merge_set.count(group);
return groups_to_merge_set.has(group);
}),
groups_with_reductions_.end());
@ -2414,8 +2412,8 @@ class CombineReductions {
maybe_consumer, maybe_producer)) {
auto groups_to_check =
dependency_analysis->valuesBetween(maybe_producer, maybe_consumer);
groups_to_check.insert(maybe_producer);
groups_to_check.insert(maybe_consumer);
groups_to_check.pushBack(maybe_producer);
groups_to_check.pushBack(maybe_consumer);
// Check that either no group has a reduction or all groups have the same
// reduction signature
@ -2428,13 +2426,13 @@ class CombineReductions {
// output edge does not generate much saving of global memory access
// we want to postpone merging these edges till the very final pass
for (auto producer_edge_of_group : group->producer_edges) {
if (groups_to_check.count(producer_edge_of_group->from) &&
if (groups_to_check.has(producer_edge_of_group->from) &&
producer_edge_of_group->val->isFusionOutput()) {
return {};
}
}
for (auto consumer_edge_of_group : group->consumer_edges) {
if (groups_to_check.count(consumer_edge_of_group->to) &&
if (groups_to_check.has(consumer_edge_of_group->to) &&
consumer_edge_of_group->val->isFusionOutput()) {
return {};
}
@ -2653,11 +2651,11 @@ void SegmentCandidateFinder::findSegments() {
// Expressions to exclude from segmentation because they're just derived from
// unary ops on inputs to the complete fusion
std::unordered_set<Expr*> excluded_inp_unary_exprs;
VectorOfUniqueEntries<Expr*> excluded_inp_unary_exprs;
// "Terminating" outputs from the excluded input unary exprs, these will be
// treated as complete fusion inputs.
std::unordered_set<Val*> forwarded_inputs;
VectorOfUniqueEntries<Val*> forwarded_inputs;
{
std::deque<Expr*> to_visit;
for (auto inp : completeFusion()->inputs()) {
@ -2677,8 +2675,8 @@ void SegmentCandidateFinder::findSegments() {
}
if (expr->output(0)->uses().size() > 1) {
excluded_inp_unary_exprs.emplace(expr);
forwarded_inputs.emplace(expr->output(0));
excluded_inp_unary_exprs.pushBack(expr);
forwarded_inputs.pushBack(expr->output(0));
continue;
}
@ -2735,7 +2733,7 @@ void SegmentCandidateFinder::findSegments() {
continue;
}
if (excluded_inp_unary_exprs.count(expr)) {
if (excluded_inp_unary_exprs.has(expr)) {
continue;
}

View File

@ -527,13 +527,43 @@ void IndexCompute::handle(Swizzle2D* swizzle_2d) {
const auto out_x_ind = out_x_it->second;
const auto out_y_ind = out_y_it->second;
// Actual swizzle operation is handled via IndexSwizzle pass
// all behavior in this pass is directly forward through the
// index and extent.
index_map_[in_x_id] = out_x_ind;
index_map_[in_y_id] = out_y_ind;
extent_map_[in_y_id] = getExtent(out_y_id);
extent_map_[in_x_id] = getExtent(out_x_id);
if (swizzle_mode_ == SwizzleMode::NoSwizzle ||
swizzle_mode_ != swizzle_2d->swizzleMode()) {
// Handle inactive swizzles by just passing through index
// and extend information.
TORCH_INTERNAL_ASSERT(
index_map_.count(in_x_id) == index_map_.count(in_y_id),
"input index should be either both defined or both undefined");
if (index_map_.count(in_x_id)) {
// Only propagate original index through if
// the input index hasn't been computed.
// TODO:
// This part should be cleaner once we remove the
// second index traversal pass.
return;
}
index_map_[in_x_id] = out_x_ind;
index_map_[in_y_id] = out_y_ind;
extent_map_[in_y_id] = getExtent(out_y_id);
extent_map_[in_x_id] = getExtent(out_x_id);
} else {
// Generate integer swizzle math if the
// swizzle is activated. See also
// [Note on swizzle mode].
auto out_pair = IrBuilder::swizzle2DIntExpr(
out_x_ind,
out_y_ind,
getExtent(out_x_id),
getExtent(out_y_id),
swizzle_2d->swizzleType());
index_map_[in_x_id] =
IrBuilder::pairSelectExpr(out_pair, kir::PairSelect::Selection::X);
index_map_[in_y_id] =
IrBuilder::pairSelectExpr(out_pair, kir::PairSelect::Selection::Y);
}
}
void IndexCompute::handle(Expr* e) {
@ -616,9 +646,31 @@ IndexCompute::IndexCompute(
reference_halo_extent_map_(std::move(reference_halo_extent_map)) {
FUSER_PERF_SCOPE("GpuLower::Lower::IndexCompute::IndexCompute");
concrete_id_pass_ = true;
swizzle_mode_ = SwizzleMode::Loop;
}
void IndexCompute::run(const LoopIndexing& loop_indexing) {
// Apply loop swizzles if there are any that outputs to
// the loop domains.
// Currently only support loop swizzles that directly output
// to concrete loop domains and these are validated in
// validate swizzle pass.
// TODO:
// will gradually enable replaying and mapping of loop
// swizzles in the IR infrastructure and once that's piped
// through this part of logic will be removed.
std::unordered_set<Expr*> visited;
for (auto loop_id : loop_indexing.loopDomains()) {
auto loop_id_def = loop_id->definition();
if (loop_id_def != nullptr && loop_id_def->isA<Swizzle2D>()) {
if (visited.insert(loop_id_def).second) {
handle(loop_id_def);
}
}
}
// Run through the loop indexing expressions and generate
// the indexing integer math for the concrete ids.
for (auto expr : loop_indexing.getBackwardExprList()) {
handle(expr);
}
@ -955,6 +1007,7 @@ void IndexSwizzle::run() {
UpdateLeafIndices update_leaves(td_, indexMap(), extentMap());
index_map_ = update_leaves.indexMap();
extent_map_ = update_leaves.extentMap();
IndexCompute::swizzle_mode_ = SwizzleMode::Data;
IndexCompute::run();
}
}
@ -969,7 +1022,8 @@ void IndexSwizzle::handle(Expr* e) {
return swizzled_ids_.find(id) != swizzled_ids_.end();
}) ||
(e->isA<Swizzle2D>() &&
e->as<Swizzle2D>()->swizzleType() != Swizzle2DType::NoSwizzle);
e->as<Swizzle2D>()->swizzleType() != Swizzle2DType::NoSwizzle &&
e->as<Swizzle2D>()->swizzleMode() == SwizzleMode::Data);
if (!needs_update) {
return;
}
@ -983,8 +1037,6 @@ void IndexSwizzle::handle(Expr* e) {
void IndexSwizzle::handle(Swizzle2D* swizzle_2d) {
auto out_x_id = swizzle_2d->outX();
auto out_y_id = swizzle_2d->outY();
auto in_x_id = swizzle_2d->inX();
auto in_y_id = swizzle_2d->inY();
auto out_x_it = index_map_.find(out_x_id);
auto out_y_it = index_map_.find(out_y_id);
@ -998,28 +1050,7 @@ void IndexSwizzle::handle(Swizzle2D* swizzle_2d) {
out_x_it != index_map_.end() && out_y_it != index_map_.end(),
"Swizzle output indices were not propagated through");
const auto out_x_ind = out_x_it->second;
const auto out_y_ind = out_y_it->second;
// Can propagate zero only for a few
// swizzle types (TODO)
if (swizzle_2d->swizzleType() != Swizzle2DType::NoSwizzle) {
auto out_pair = IrBuilder::swizzle2DIntExpr(
out_x_ind,
out_y_ind,
getExtent(out_x_id),
getExtent(out_y_id),
swizzle_2d->swizzleType());
index_map_[in_x_id] =
IrBuilder::pairSelectExpr(out_pair, kir::PairSelect::Selection::X);
index_map_[in_y_id] =
IrBuilder::pairSelectExpr(out_pair, kir::PairSelect::Selection::Y);
swizzled_ids_.insert(in_x_id);
swizzled_ids_.insert(in_y_id);
}
IndexCompute::handle(swizzle_2d);
}
// Used for local and shared index mapping. Returns a map from loops
@ -1125,7 +1156,16 @@ indexMapFromTV(
// Similarly for local memory tensors, zero replacement can be
// only done when there's a matching domain with the same
// parallel type
(loop->iter_domain()->isThread() && is_local && same_parallel_type)) {
(loop->iter_domain()->isThread() && is_local && same_parallel_type) ||
// MMA operands are currently indexed in units of "fragments",
// so each mma tensor domain would be zero-ed and the tensor index
// calculated here would be the fragment index.
// TODO: This is a quick WAR to enable iterating over a register array
// of MMA fragments, so we could generate unrolled mma loops.
// Eventually we still want IdGraph to be able to analyze the
// in-register layout of mma fragments for more unified indexing math
// as well as more flexibility in swizzling loops.
(loop->iter_domain()->isMma() && !as_consumer)) {
idx = GpuLower::current()->kernel()->zeroVal();
zero_loops.insert(loop);
} else {
@ -1139,8 +1179,11 @@ indexMapFromTV(
}
if (loop == double_buffer_loop) {
auto stage_depth =
GpuLower::current()->doubleBufferInfo().getStageDepthFor(
loop->iter_domain());
idx = SimplifyingIrBuilder::addExpr(
idx, GpuLower::current()->kernel()->oneVal());
idx, SimplifyingIrBuilder::create<Int>(stage_depth - 1));
}
loop_to_ind_map[loop] = idx;
@ -1802,14 +1845,16 @@ std::vector<Val*> Index::getNonGlobalProducerStridedIndices(
}
}
if (producer_tv->isDoubleBuffered()) {
if (producer_tv->isDoubleBuffered() || producer_tv->isCircularBuffered()) {
auto db_loop = gpu_lower->doubleBufferInfo().getDoubleBufferLoop(
producer_tv, loops, true);
if (db_loop != nullptr) {
auto stage_depth = gpu_lower->doubleBufferInfo().getStageDepthFor(
db_loop->iter_domain());
auto loop_index =
db_loop->isTrivial() ? db_loop->start() : db_loop->index();
auto db_switch_index = SimplifyingIrBuilder::modExpr(
loop_index, SimplifyingIrBuilder::create<Int>(2));
loop_index, SimplifyingIrBuilder::create<Int>(stage_depth));
auto original_alloc_size =
gpu_lower->doubleBufferInfo().getOriginalAllocSize(producer_tv);
auto db_strided_index =
@ -2068,14 +2113,36 @@ std::vector<Val*> Index::getNonGlobalConsumerStridedIndices(
TORCH_INTERNAL_ASSERT(
strided_inds.size() == consumer_tv->getMaybeRFactorDomain().size());
if (consumer_tv->isDoubleBuffered()) {
auto db_loop = gpu_lower->doubleBufferInfo().getDoubleBufferLoop(
consumer_tv, loops, true);
if (db_loop != nullptr) {
auto db_switch_index = SimplifyingIrBuilder::subExpr(
gpu_lower->kernel()->oneVal(),
SimplifyingIrBuilder::modExpr(
db_loop->index(), SimplifyingIrBuilder::create<Int>(2)));
if (consumer_tv->isDoubleBuffered() || consumer_tv->isCircularBuffered()) {
auto db_loop =
gpu_lower->doubleBufferInfo().getDoubleBufferLoop(consumer_tv, loops);
auto stage_depth =
gpu_lower->doubleBufferInfo().getStageDepthFor(db_loop->iter_domain());
bool is_circular_buffer_loop = stage_depth > 2;
bool is_prolog =
db_loop->doubleBufferLoopStage() == DoubleBufferLoopStage::Prolog;
Val* db_switch_index = nullptr;
// In double buffered we don't materialize the prolog loop as there will
// be only one iteration. In circular buffer case we materialize the
// prolog loop as well covering the first N-1 iterations, N being the
// stage depth.
if (!is_prolog || is_circular_buffer_loop) {
if (is_prolog && is_circular_buffer_loop) {
// The buffer switching logic is the same as original index
// in the case of circular buffer prolog.
db_switch_index = db_loop->index();
} else {
// Switching index generated for main loop or epilog component.
db_switch_index = SimplifyingIrBuilder::modExpr(
SimplifyingIrBuilder::addExpr(
db_loop->index(),
SimplifyingIrBuilder::create<Int>(stage_depth - 1)),
SimplifyingIrBuilder::create<Int>(stage_depth));
}
// Use the generated switching buffer index to access the buffer space.
auto original_alloc_size =
gpu_lower->doubleBufferInfo().getOriginalAllocSize(consumer_tv);
auto db_strided_index =
@ -2110,7 +2177,8 @@ std::vector<Val*> Index::getProducerStridedIndices(
TORCH_INTERNAL_ASSERT(
strided_indices.size() ==
producer->getMaybeRFactorDomain().size() +
(producer->isDoubleBuffered() ? 1 : 0));
(producer->isDoubleBuffered() || producer->isCircularBuffered() ? 1
: 0));
return strided_indices;
}
@ -2721,8 +2789,8 @@ std::pair<Val*, Val*> hoistPredicates(
Val* stop_index,
const std::vector<kir::ForLoop*>& loops,
std::vector<IterDomain*> loop_domains,
const std::unordered_map<IterDomain*, Val*> start_initial_loop_index_map,
const std::unordered_map<IterDomain*, Val*> stop_initial_loop_index_map,
const std::unordered_map<IterDomain*, Val*>& start_initial_loop_index_map,
const std::unordered_map<IterDomain*, Val*>& stop_initial_loop_index_map,
kir::ForLoop* unswitch_or_vec_loop,
IterDomain* predicated_consumer_id,
TensorView* predicated_consumer_tv) {
@ -2771,6 +2839,22 @@ std::pair<Val*, Val*> hoistPredicates(
return {hoisted_start_index, hoisted_stop_index};
}
// Updates a loop index map with a loop index protected by magic zero
std::unordered_map<IterDomain*, Val*> updateInitialLoopIndexMap(
const std::unordered_map<IterDomain*, Val*>& initial_loop_index_map,
const IndexMagicZeroInfo& magic_zero_info) {
if (magic_zero_info.original_loop_index != nullptr) {
TORCH_INTERNAL_ASSERT(magic_zero_info.protected_loop_index != nullptr);
auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID(
magic_zero_info.loop_id, IdMappingMode::EXACT);
auto updated_map = initial_loop_index_map;
updated_map[concrete_loop_id] = magic_zero_info.protected_loop_index;
return updated_map;
} else {
return initial_loop_index_map;
}
}
} // namespace
// Returns predicates and the concrete (by loop map) root domains they cover
@ -2886,13 +2970,38 @@ std::vector<RootPredicateInfo> Index::getReferenceRootPredicates(
auto stop_index = consumer_stop_indexing_it->second;
auto start_index = consumer_start_index_map.at(contig_id);
IndexMagicZeroInfo start_magic_zero_info;
IndexMagicZeroInfo stop_magic_zero_info;
// When the start and stop indices are not the same, apply the
// magic-zero protection separately for both of them.
if (stop_index != start_index) {
start_magic_zero_info = protectPredicateIndexWithMagicZero(
start_index, start_indexing_from_idgraph, loops);
stop_magic_zero_info = protectPredicateIndexWithMagicZero(
stop_index, stop_indexing_from_idgraph, loops);
} else {
stop_magic_zero_info = protectPredicateIndexWithMagicZero(
stop_index, stop_indexing_from_idgraph, loops);
start_magic_zero_info = stop_magic_zero_info;
}
start_index = start_magic_zero_info.index;
stop_index = stop_magic_zero_info.index;
// Update the loop-index map with the magic-zero protection info
// before passing it to the hoisting function
std::tie(start_index, stop_index) = hoistPredicates(
start_index,
stop_index,
loops,
stop_indexing_from_idgraph.resolved_loop_domains,
start_indexing_from_idgraph.initial_concrete_index_map,
stop_indexing_from_idgraph.initial_concrete_index_map,
updateInitialLoopIndexMap(
start_indexing_from_idgraph.initial_concrete_index_map,
start_magic_zero_info),
updateInitialLoopIndexMap(
stop_indexing_from_idgraph.initial_concrete_index_map,
stop_magic_zero_info),
unswitch_or_vec_loop,
contig_id,
consumer_tv);
@ -2935,19 +3044,6 @@ std::vector<RootPredicateInfo> Index::getReferenceRootPredicates(
return pred_info_vec;
}
bool Index::protectWithMagicZero(
kir::ForLoop* loop,
IterDomain* reference_domain,
Val* ind) {
bool ref_dom_simple =
(reference_domain == nullptr ? true
: reference_domain->definition() != nullptr);
bool ind_simple =
(ind == nullptr ? true
: ind->definition() != nullptr && !ind->isZeroInt());
return loop->isUnrolled() && (!ref_dom_simple || !ind_simple);
}
RootPredicateInfo RootPredicateInfo::getFalseInfo() {
RootPredicateInfo info;
info.start_predicate_ = GpuLower::current()->kernel()->falseVal();


@ -130,6 +130,13 @@ class IndexCompute : public BackwardVisitor {
// map rather than the actual IDs used in the ID expressions.
bool concrete_id_pass_ = false;
// Mode of swizzle that is activated in this index compute
// instance. Swizzles of a different mode are treated as no-ops.
// Currently data-mode swizzles are handled the same as before in the
// IndexSwizzle pass, while loop-mode swizzles are handled early on in the
// concrete indexing pass. See also [Note on swizzle mode]
SwizzleMode swizzle_mode_ = SwizzleMode::NoSwizzle;
public:
const std::unordered_map<IterDomain*, Val*>& indexMap() const {
return index_map_;
@ -361,18 +368,6 @@ class Index {
const std::vector<kir::ForLoop*>& loops,
kir::ForLoop* unswitch_or_vec_loop,
bool padding_predicate);
// Determine if we may run into over reuse of predicates or registers in the
// compiler. If the loop can be unrolled and the index and domain are not
// "simple" we likely want the loop protected.
//
// Magic zero protection should only be done for global memory and predicates.
// We should avoid use on registers. Shared memory does not require it, but
// likely wouldn't hurt.
static bool protectWithMagicZero(
kir::ForLoop* loop,
IterDomain* reference_domain = nullptr,
Val* ind = nullptr);
};
// Used for local and shared index mapping. Returns a map from loops


@ -10,22 +10,10 @@ namespace jit {
namespace fuser {
namespace cuda {
bool InlinePropagatorSelector::allowC2P(TensorView* from, TensorView* to) {
return selected_.count(to) > 0;
}
bool InlinePropagatorSelector::allowP2C(TensorView* from, TensorView* to) {
// If the producer is in the selected set, then the consumer must also be
// replayed to obtain a compatible loop structure so that this producer
// can be consumed in this loop.
return selected_.count(from) > 0 || selected_.count(to) > 0;
}
bool InlinePropagatorSelector::allowSibling(TensorView* from, TensorView* to) {
return true;
}
MaxPosCalculator::MaxPosCalculator(ComputeAtMode mode) : mode_(mode) {
MaxPosCalculator::MaxPosCalculator(
ComputeAtMode mode,
std::unordered_set<IterDomain*> uninlinable_ids)
: mode_(mode), uninlinable_ids_(std::move(uninlinable_ids)) {
buildUnmappableDims();
}
@ -65,6 +53,10 @@ bool MaxPosCalculator::isAllowedID(
allowed = allowed && !id->isReduction();
}
if (uninlinable_ids_.count(id)) {
return false;
}
if (!allow_vectorize) {
// Avoid inlining if marked as Vectorize or Group. In the case of
// BestEffort and MostInlined modes, avoid Unroll as well.
@ -121,6 +113,13 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer(
for (size_t producer_pos = 0; producer_pos < producer->nDims();
producer_pos++) {
// If the producer position does not match the consumer, then we can
// not inline into this position; otherwise the max producer position of
// the consumer will become invalid and expression sorting will fail.
if (TransformReplay::getMatchedLeafPosWithoutReplayCasP(
consumer, producer, producer_pos + 1) < 0) {
return producer_pos;
}
auto map_it = p2c_replay_map.find(producer->axis(producer_pos));
if (map_it != p2c_replay_map.end()) {
auto c_id = map_it->second;
@ -147,9 +146,17 @@ size_t InlinePropagator::getMaxPosAll(TensorView* tv, bool check_siblings) {
}
void InlinePropagator::setCAPos(TensorView* tv) {
bool debug = isDebugDumpEnabled(DebugDumpOption::InlinePropagator);
size_t pos = mapped_reference_pos_.at(tv);
if (debug) {
std::cout << " Setting CA pos of " << tv << ":" << std::endl;
std::cout << " mapped position: " << pos << std::endl;
}
if ((selected_.empty() || selected_.count(tv)) && !tv->isFusionInput()) {
auto max_pos = getMaxPosAll(tv);
if (debug) {
std::cout << " max inlinable position: " << max_pos << std::endl;
}
if (mode_ == ComputeAtMode::Standard) {
TORCH_INTERNAL_ASSERT(
pos <= max_pos,
@ -159,14 +166,32 @@ void InlinePropagator::setCAPos(TensorView* tv) {
pos,
", max position that's allowed is ",
max_pos);
} else {
} else if (mode_ == ComputeAtMode::BestEffort) {
pos = std::min<size_t>(pos, max_pos);
} else {
pos = max_pos;
}
// hoist innermost broadcast
while (pos > 0 && tv->axis(pos - 1)->isBroadcast()) {
pos--;
}
tv->setComputeAt(pos);
auto current_ca_pos = tv->getComputeAtPosition();
if (debug) {
std::cout << " current CA position: " << current_ca_pos << std::endl;
}
if (pos > current_ca_pos) {
if (debug) {
std::cout << " new CA position: " << pos << std::endl;
}
tv->setComputeAt(pos);
for (auto consumer_tv : ir_utils::consumerTvsOf(tv)) {
needs_update_max_producer_.insert(consumer_tv);
}
} else if (debug) {
std::cout << " CA position not changed" << std::endl;
}
} else if (debug) {
std::cout << " tensor not selected, skip" << std::endl;
}
}
@ -174,8 +199,9 @@ InlinePropagator::InlinePropagator(
TensorView* reference,
int64_t reference_pos,
ComputeAtMode mode,
std::unordered_set<TensorView*> selected)
: max_pos_calc(mode),
std::unordered_set<TensorView*> selected,
std::unordered_set<IterDomain*> uninlinable_ids)
: max_pos_calc(mode, std::move(uninlinable_ids)),
selected_(std::move(selected)),
reference_(reference),
mode_(mode) {
@ -194,69 +220,150 @@ InlinePropagator::InlinePropagator(
reference_pos_ = reference_pos;
}
void InlinePropagator::setUp() {
bool debug = isDebugDumpEnabled(DebugDumpOption::InlinePropagator);
mapped_reference_pos_[reference_] = reference_pos_;
if (debug) {
std::cout << "InlinePropagator::setUp" << std::endl;
std::cout << " reference: " << reference_ << " @ " << reference_pos_
<< std::endl;
}
setCAPos(reference_);
}
namespace {
// Try to find the aligned position on consumer's domain corresponding to the
// compute at position of producer domain. Used in InlinePropagator pass only.
// No checking on actual producer-consumer relationship.
unsigned int getConsumerPosAlignedToProducerCA(
TensorView* consumer,
TensorView* producer) {
// Locate consumer's position that aligns with
// the producer's new compute at axis. We need broadcast axes forwarded so we
// need to replay PasC as CasP will not forward broadcast dims. For example
// if we have:
// T2[ iS22{( 3 * 1 )} ] ca_pos( 1 ) = broadcast( T1[ iS1{3} ] ca_pos( 1 )
// produce_pos( 1) ) CasP will have the mapping iS1{3} -> iS2{3} and PasC will
// have the mapping iS22{( 3 * 1 )} <- iS1{3} We need the latter. Refer to
// NVFuserTest.FusionComplexBCast1_CUDA
auto disjoint_sets =
BestEffortReplay::replayPasC(
producer, consumer, -1, PairwiseRootDomainMap(producer, consumer))
.getDisjointSets();
// Find the innermost position of consumer that has
// been mapped within the producer ca axis.
unsigned int consumer_pos = consumer->nDims();
while (consumer_pos > 0) {
auto consumer_id = consumer->axis((int)consumer_pos - 1);
auto p_dom = producer->domain()->domain();
if (std::any_of(
p_dom.begin(),
p_dom.begin() + producer->getComputeAtPosition(),
[&consumer_id, &disjoint_sets](IterDomain* p_id) {
return disjoint_sets.permissiveAreMapped(consumer_id, p_id);
})) {
break;
}
consumer_pos--;
}
return consumer_pos;
}
} // namespace
void InlinePropagator::tearDown() {
for (auto consumer : needs_update_max_producer_) {
unsigned int consumer_pos = 0;
for (auto producer : ir_utils::producerTvsOf(consumer)) {
consumer_pos = std::max(
consumer_pos, getConsumerPosAlignedToProducerCA(consumer, producer));
}
consumer->setMaxProducer(consumer_pos);
}
}
void InlinePropagator::propagateC2P(TensorView* from, TensorView* to) {
if (is_first_) {
is_first_ = false;
mapped_reference_pos_[reference_] = reference_pos_;
setCAPos(reference_);
bool debug = isDebugDumpEnabled(DebugDumpOption::InlinePropagator);
if (debug) {
std::cout << "InlinePropagator::propagateC2P" << std::endl;
std::cout << " from: " << from << std::endl;
std::cout << " to: " << to << std::endl;
}
// Step 1: find mapped_reference_pos_[to]
int from_pos;
if (mode_ != ComputeAtMode::MostInlined) {
from_pos = mapped_reference_pos_.at(from);
} else {
from_pos = from->nDims();
}
int from_pos = mapped_reference_pos_.at(from);
auto to_pos =
TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, from_pos);
TORCH_CHECK(
to_pos >= 0,
"Unable to propagate CA position from consumer ",
from,
" at ",
from_pos,
" to producer ",
to,
" because this would require replay.");
if (mode_ == ComputeAtMode::Standard) {
TORCH_CHECK(
to_pos >= 0,
"Unable to propagate CA position from consumer ",
from,
" at ",
from_pos,
" to producer ",
to,
" because this would require replay.");
} else {
// For MostInlined and BestEffort inline propagation, we allow the DAG to
// not be replayed fully consistently. In that case, we just don't inline
// into the mismatched dimension.
while (to_pos < 0) {
from_pos--;
to_pos = TransformReplay::getMatchedLeafPosWithoutReplayPasC(
to, from, from_pos);
}
}
mapped_reference_pos_[to] = to_pos;
// Step 2: set CA position of `to`
setCAPos(to);
}
void InlinePropagator::propagateP2C(TensorView* from, TensorView* to) {
if (is_first_) {
is_first_ = false;
mapped_reference_pos_[reference_] = reference_pos_;
setCAPos(reference_);
bool debug = isDebugDumpEnabled(DebugDumpOption::InlinePropagator);
if (debug) {
std::cout << "InlinePropagator::propagateP2C" << std::endl;
std::cout << " from: " << from << std::endl;
std::cout << " to: " << to << std::endl;
}
// Step 1: find mapped_reference_pos_[to]
int from_pos;
if (mode_ != ComputeAtMode::MostInlined) {
from_pos = mapped_reference_pos_.at(from);
} else {
from_pos = from->nDims();
}
int from_pos = mapped_reference_pos_.at(from);
auto to_pos =
TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, from_pos);
TORCH_CHECK(
to_pos >= 0,
"Unable to propagate CA position from producer ",
from,
" at ",
from_pos,
" to consumer ",
to,
" because this would require replay.");
if (mode_ == ComputeAtMode::Standard) {
TORCH_CHECK(
to_pos >= 0,
"Unable to propagate CA position from producer ",
from,
" at ",
from_pos,
" to consumer ",
to,
" because this would require replay.");
} else {
// For MostInlined and BestEffort inline propagation, we allow the DAG to
// not be replayed fully consistently. In that case, we just don't inline
// into the mismatched dimension.
while (to_pos < 0) {
from_pos--;
to_pos = TransformReplay::getMatchedLeafPosWithoutReplayCasP(
to, from, from_pos);
}
}
mapped_reference_pos_[to] = to_pos;
// Step 2: set CA position of `to`
setCAPos(to);
}
void InlinePropagator::propagateSibling(TensorView* from, TensorView* to) {
if (is_first_) {
is_first_ = false;
mapped_reference_pos_[reference_] = reference_pos_;
setCAPos(reference_);
bool debug = isDebugDumpEnabled(DebugDumpOption::InlinePropagator);
if (debug) {
std::cout << "InlinePropagator::propagateSibling" << std::endl;
std::cout << " from: " << from << std::endl;
std::cout << " to: " << to << std::endl;
}
// Step 1: find mapped_reference_pos_[to]
auto from_pos = mapped_reference_pos_.at(from);
@ -272,96 +379,6 @@ void InlinePropagator::propagateSibling(TensorView* from, TensorView* to) {
setCAPos(to);
}
namespace {
// Try to find the aligned position on consumer's domain corresponding to the
// compute at position of producer domain. Used in computeAt pass only. No
// checking on actual producer-consumer relationship.
unsigned int getConsumerPosAlignedToProducerCA(
TensorView* consumer,
TensorView* producer) {
// Locate consumer's position that aligns with
// the producer's new compute at axis. We need broadcast axes forwarded so we
// need to replay PasC as CasP will not forward broadcast dims. For example
// if we have:
// T2[ iS22{( 3 * 1 )} ] ca_pos( 1 ) = broadcast( T1[ iS1{3} ] ca_pos( 1 )
// produce_pos( 1) ) CasP will have the mapping iS1{3} -> iS2{3} and PasC will
// have the mapping iS22{( 3 * 1 )} <- iS1{3} We need the latter. Refer to
// NVFuserTest.FusionComplexBCast1_CUDA
auto c2p_map =
BestEffortReplay::replayPasC(
producer,
consumer,
-1,
// Compute at root domain may not be valid here, as all
// producers don't have to be able to map into consumer at
// max producer position. Since computeAt should be valid
// and this mechanism is only intended to lower produce
// position of consumer, we can simply use the pairwise map.
PairwiseRootDomainMap(producer, consumer))
.getReplay();
// Find the innermost position of consumer that has
// been mapped within the producer ca axis.
unsigned int consumer_pos = consumer->nDims();
while (consumer_pos > 0) {
auto consumer_id = consumer->axis((int)consumer_pos - 1);
auto p_dom = producer->domain()->domain();
if (std::any_of(
p_dom.begin(),
p_dom.begin() + producer->getComputeAtPosition(),
[&consumer_id, &c2p_map](IterDomain* p_id) {
auto c_id_it = c2p_map.find(consumer_id);
if (c_id_it != c2p_map.end()) {
return c_id_it->second == p_id;
}
return false;
})) {
break;
}
consumer_pos--;
}
return consumer_pos;
}
} // namespace
// Try to find the aligned position on consumer's domain corresponding to the
// compute at position of producer domain.
void MaxProducerPosUpdater::handle(TensorView* consumer) {
unsigned int consumer_pos = 0;
for (auto producer : ir_utils::producerTvsOf(consumer)) {
consumer_pos = std::max(
consumer_pos, getConsumerPosAlignedToProducerCA(consumer, producer));
}
consumer->setMaxProducer(consumer_pos);
}
void MaxProducerPosUpdater::propagateC2P(TensorView* from, TensorView* to) {
if (updated_.empty()) {
// handle the reference tensor
updated_.insert(nullptr);
propagateC2P(nullptr, from);
}
for (auto consumer_tv : ir_utils::consumerTvsOf(to)) {
if (updated_.count(consumer_tv) > 0) {
continue;
}
handle(consumer_tv);
updated_.insert(consumer_tv);
}
}
void MaxProducerPosUpdater::propagateP2C(TensorView* from, TensorView* to) {
propagateC2P(from, to);
}
void MaxProducerPosUpdater::propagateSibling(TensorView* from, TensorView* to) {
propagateC2P(from, to);
}
} // namespace cuda
} // namespace fuser
} // namespace jit


@ -11,31 +11,16 @@ namespace jit {
namespace fuser {
namespace cuda {
// Simple selector that only propagates across tensor views in the provided
// unordered_set. Will also propagate to all consumers of those tensors, and the
// siblings of those tensors.
class TORCH_CUDA_CU_API InlinePropagatorSelector
: public MaxInfoSpanningTree::Selector {
std::unordered_set<TensorView*> selected_;
public:
virtual bool allowC2P(TensorView* from, TensorView* to) override;
virtual bool allowP2C(TensorView* from, TensorView* to) override;
virtual bool allowSibling(TensorView* from, TensorView* to) override;
InlinePropagatorSelector(std::unordered_set<TensorView*> selected)
: selected_(std::move(selected)){};
const std::unordered_set<TensorView*>& selected() const {
return selected_;
}
};
class TORCH_CUDA_CU_API MaxPosCalculator {
ComputeAtMode mode_ = ComputeAtMode::Standard;
// Root domains in producer that are unmappable to any of its consumers
std::unordered_set<IterDomain*> unmappable_dims_;
// User set IterDomains to not inline, used in schedulers to avoid inlining
// trivial reductions
std::unordered_set<IterDomain*> uninlinable_ids_;
// Iterate through all TVs and collect the dimensions of each TV that don't
// map to all its consumer TVs.
void buildUnmappableDims();
@ -65,7 +50,9 @@ class TORCH_CUDA_CU_API MaxPosCalculator {
TensorView* producer,
TensorView* consumer) const;
MaxPosCalculator(ComputeAtMode mode);
MaxPosCalculator(
ComputeAtMode mode,
std::unordered_set<IterDomain*> uninlinable_ids = {});
};
// Propagate inline position to the `selected` tensors in the DAG. If `selected`
@ -91,17 +78,18 @@ class TORCH_CUDA_CU_API InlinePropagator
const MaxPosCalculator max_pos_calc;
std::unordered_set<TensorView*> selected_;
std::unordered_set<TensorView*> needs_update_max_producer_;
TensorView* reference_;
size_t reference_pos_;
ComputeAtMode mode_ = ComputeAtMode::Standard;
bool is_first_ = true;
public:
InlinePropagator(
TensorView* reference,
int64_t reference_pos,
ComputeAtMode mode = ComputeAtMode::Standard,
std::unordered_set<TensorView*> selected = {});
std::unordered_set<TensorView*> selected = {},
std::unordered_set<IterDomain*> uninlinable_ids = {});
InlinePropagator(
TensorView* reference,
@ -117,24 +105,11 @@ class TORCH_CUDA_CU_API InlinePropagator
// Actually propagate the transformations for the inlining pass. Uses the
// functions above to figure out what position to do the propagation at.
virtual void setUp() override;
virtual void propagateC2P(TensorView* from, TensorView* to) override;
virtual void propagateP2C(TensorView* from, TensorView* to) override;
virtual void propagateSibling(TensorView* from, TensorView* to) override;
};
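A minimal usage sketch of the extended constructor (the tensor and IterDomain names are hypothetical, and the usual MaxRootDomainInfoSpanningTree driver is assumed for walking the DAG):

```cpp
// Hedged sketch: inline as much as possible starting from `reference_tv`,
// but keep a trivial-reduction IterDomain out of the inlined region via the
// new `uninlinable_ids` argument. Variables other than the InlinePropagator
// API itself are assumed to exist in the calling scheduler.
std::unordered_set<IterDomain*> uninlinable_ids{trivial_reduction_id};
InlinePropagator inline_propagator(
    reference_tv,
    /*reference_pos=*/reference_tv->nDims(),
    ComputeAtMode::MostInlined,
    /*selected=*/{},
    uninlinable_ids);
MaxRootDomainInfoSpanningTree(reference_tv).traverse(&inline_propagator);
```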
// This is actually not a propagation, it only sets the max producer position of
// the tensors, and it is not needed to compute the max producer position in a
// specific order. But MaxInfoSpanningTree provides a very convenient API to
// visit the tensors, so I just use it for cleaner code.
class TORCH_CUDA_CU_API MaxProducerPosUpdater
: public MaxInfoSpanningTree::Propagator {
std::unordered_set<TensorView*> updated_;
void handle(TensorView* tv);
public:
virtual void propagateC2P(TensorView* from, TensorView* to) override;
virtual void propagateP2C(TensorView* from, TensorView* to) override;
virtual void propagateSibling(TensorView* from, TensorView* to) override;
virtual void tearDown() override;
};
} // namespace cuda


@ -84,6 +84,9 @@ Val* IrBuilder::newResult(DataType dtype) {
}
Val* IrBuilder::newArithmeticExpr(BinaryOpType op_type, Val* lhs, Val* rhs) {
TORCH_CHECK(
lhs != nullptr && rhs != nullptr,
"Either lhs or rhs is a nullptr in newArithmeticExpr.");
TORCH_CHECK(
lhs->dtype() == rhs->dtype(),
"Incompatible operand types: ",
@ -97,6 +100,9 @@ Val* IrBuilder::newArithmeticExpr(BinaryOpType op_type, Val* lhs, Val* rhs) {
}
Val* IrBuilder::newLogicExpr(BinaryOpType op_type, Val* lhs, Val* rhs) {
TORCH_CHECK(
lhs != nullptr && rhs != nullptr,
"Either lhs or rhs is a nullptr in newLogicExpr.");
auto result = IrBuilder::create<Bool>(c10::nullopt);
IrBuilder::create<BinaryOp>(op_type, result, lhs, rhs);
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
@ -104,6 +110,9 @@ Val* IrBuilder::newLogicExpr(BinaryOpType op_type, Val* lhs, Val* rhs) {
}
Val* IrBuilder::whereExpr(Val* pred, Val* lhs, Val* rhs) {
TORCH_CHECK(
pred != nullptr && lhs != nullptr && rhs != nullptr,
"Either pred, lhs, or rhs is a nullptr in whereExpr.");
TORCH_CHECK(lhs->dtype() == rhs->dtype(), "Incompatible operand types");
auto result = newResult(lhs->dtype());
IrBuilder::create<TernaryOp>(TernaryOpType::Where, result, pred, lhs, rhs);
@ -111,30 +120,35 @@ Val* IrBuilder::whereExpr(Val* pred, Val* lhs, Val* rhs) {
}
Val* IrBuilder::negExpr(Val* val) {
TORCH_CHECK(val != nullptr, "val is a nullptr in negExpr.");
auto result = newResult(val->dtype());
IrBuilder::create<UnaryOp>(UnaryOpType::Neg, result, val);
return result;
}
Val* IrBuilder::notExpr(Val* val) {
TORCH_CHECK(val != nullptr, "val is a nullptr in notExpr.");
auto result = newResult(val->dtype());
IrBuilder::create<UnaryOp>(UnaryOpType::Not, result, val);
return result;
}
Val* IrBuilder::setExpr(Val* val) {
TORCH_CHECK(val != nullptr, "val is a nullptr in setExpr.");
auto result = newResult(val->dtype());
IrBuilder::create<UnaryOp>(UnaryOpType::Set, result, val);
return result;
}
Val* IrBuilder::setExprNamedScalar(const std::string& name, Val* val) {
TORCH_CHECK(val != nullptr, "val is a nullptr in setExprNamedScalar.");
auto result = IrBuilder::create<NamedScalar>(name, val->dtype());
IrBuilder::create<UnaryOp>(UnaryOpType::Set, result, val);
return result;
}
Val* IrBuilder::addressExprNamedScalar(const std::string& name, Val* val) {
TORCH_CHECK(val != nullptr, "val is a nullptr in addressExprNamedScalar.");
auto result = IrBuilder::create<NamedScalar>(name, DataType::Int);
IrBuilder::create<UnaryOp>(UnaryOpType::Address, result, val);
return result;


@ -381,7 +381,11 @@ class TORCH_CUDA_CU_API TensorView : public Val {
//! Swizzle the rectangular tile defined by the iterdomains corresponding
//! to the 2 given indices.
TensorView* swizzle(Swizzle2DType swizzle_type, int x, int y);
TensorView* swizzle(
Swizzle2DType swizzle_type,
int x,
int y,
SwizzleMode swizzle_mode = SwizzleMode::Data);
// WARNING: rFactor does not return this TensorView, it returns a new
// tensorview consumed by this!
@ -450,10 +454,26 @@ class TORCH_CUDA_CU_API TensorView : public Val {
// Apply double buffering transformation
void doubleBuffer();
// Apply circular buffering transformation
void circularBuffer(unsigned int number_of_stage);
// Returns true if this tensor is double buffered.
bool isDoubleBuffered() const {
return is_double_buffered_;
}
// Returns true if this tensor is circular buffered.
bool isCircularBuffered() const {
return is_circular_buffered_;
}
// Returns the depth of circular buffering if applicable.
unsigned int circularBufferDepth() const {
TORCH_INTERNAL_ASSERT(
is_circular_buffered_, toString(), "not circular buffered");
return circular_buffer_stage_;
}
//! Transforms the innermost iterdomains according to the given mma swizzle,
//! this should be used on the tvs that are either inputs/outputs of an
//! MmaOp, or any tv's that are involved in prolog/epilog fusions and need to
@ -509,6 +529,13 @@ class TORCH_CUDA_CU_API TensorView : public Val {
SwizzleType swizzle_type_ = SwizzleType::NoSwizzle;
std::vector<IterDomain*> axes_to_swizzle_;
bool is_double_buffered_ = false;
//! Indicates if the tensor is circular buffered.
bool is_circular_buffered_ = false;
//! Indicates the circular buffering stage depth if applicable.
unsigned int circular_buffer_stage_ = 0;
// special handling for CPU based zero-dim tensors (i.e. CPU Tensors that only
// have one value). This is only used if on an input value, otherwise ignored.
// This is important as special handling because these "scalars" should be
@ -545,7 +572,8 @@ class TORCH_CUDA_CU_API TensorViewBuilder {
TensorViewBuilder& contiguity(std::vector<bool> contiguity);
//! Set the shape (default 0 dimensional, ie. scalar)
TensorViewBuilder& shape(std::vector<int64_t> shape);
TensorViewBuilder& shape(std::vector<Val*> shape);
TensorViewBuilder& shape(const std::vector<int64_t>& shape);
//! Creates a new TensorView with the specified options
TensorView* build() const;
@ -554,7 +582,7 @@ class TORCH_CUDA_CU_API TensorViewBuilder {
size_t ndims_ = 0;
DataType dtype_ = DataType::Float;
std::vector<bool> contiguity_;
std::vector<int64_t> shape_;
std::vector<Val*> shape_;
};
} // namespace cuda
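A minimal sketch of the new Val-based shape overload; it assumes an active Fusion/FusionGuard so that IrBuilder creates the symbolic extents in the current fusion:

```cpp
// Hedged sketch: build a 2D tensor whose extents are symbolic Vals rather
// than compile-time constants, which the int64_t overload cannot express.
// Assumes a FusionGuard is active.
std::vector<Val*> symbolic_shape{
    IrBuilder::create<Int>(c10::nullopt), IrBuilder::create<Int>(c10::nullopt)};
TensorView* tv = TensorViewBuilder().shape(symbolic_shape).build();
```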


@ -338,6 +338,22 @@ class TORCH_CUDA_CU_API WelfordOp : public Expr {
//! Fused Matmul operation
class TORCH_CUDA_CU_API MmaOp : public Expr {
public:
// This is a temporary data structure for the
// scheduling-specific parameters that we still need
// to store on an mma node. Eventually only the
// mma macro type will stay on the IR node
// after additional cleanups.
struct OptionsInMma {
MmaOptions::MacroType macro = MmaOptions::MacroType::NoMMA;
MmaOptions::MmaInputLayout operand_layout = MmaOptions::MmaInputLayout::TT;
int accumulator_stride = 0;
bool operator==(const OptionsInMma& other) const {
return macro == other.macro && operand_layout == other.operand_layout &&
accumulator_stride == other.accumulator_stride;
}
};
MmaOp(IrBuilderPasskey, Val* out, Val* in_a, Val* in_b, Val* init);
MmaOp(
@ -346,7 +362,7 @@ class TORCH_CUDA_CU_API MmaOp : public Expr {
Val* in_a,
Val* in_b,
Val* init,
MmaOptions options);
OptionsInMma options);
MmaOp(const MmaOp* src, IrCloner* ir_cloner);
@ -379,7 +395,15 @@ class TORCH_CUDA_CU_API MmaOp : public Expr {
}
void configureOptions(MmaOptions options) {
options_ = options;
options_ = OptionsInMma();
TORCH_INTERNAL_ASSERT(
options.macro != MmaOptions::MacroType::NoMMA,
"Un-configured mma type from options.");
TORCH_INTERNAL_ASSERT(
options.accumulator_stride > 0, "Un-configured accumulator stride.");
options_->accumulator_stride = options.accumulator_stride;
options_->macro = options.macro;
options_->operand_layout = options.operand_layout;
}
private:
@ -387,7 +411,7 @@ class TORCH_CUDA_CU_API MmaOp : public Expr {
Val* const in_a_ = nullptr;
Val* const in_b_ = nullptr;
Val* const init_ = nullptr;
c10::optional<MmaOptions> options_ = c10::nullopt;
c10::optional<OptionsInMma> options_ = c10::nullopt;
};
class TORCH_CUDA_CU_API TransposeOp : public Expr {
@ -967,7 +991,8 @@ class TORCH_CUDA_CU_API IterDomain : public Val {
static std::pair<IterDomain*, IterDomain*> swizzle(
Swizzle2DType swizzle_type,
IterDomain* in_x,
IterDomain* in_y);
IterDomain* in_y,
SwizzleMode swizzle_mode = SwizzleMode::Data);
bool isMmaSwizzled() const {
return is_mma_swizzled_;
@ -1174,7 +1199,11 @@ class TORCH_CUDA_CU_API TensorDomain : public Val {
//! Applies 2D swizzle on a rectangular tile defined by
//! a pair of iterdomains contained in this domain.
void swizzle(Swizzle2DType swizzle_type, int x, int y);
void swizzle(
Swizzle2DType swizzle_type,
int x,
int y,
SwizzleMode swizzle_mode = SwizzleMode::Data);
// Transform TensorView according to merge and split transformations
TensorDomain* view(
@ -1315,7 +1344,8 @@ class TORCH_CUDA_CU_API Swizzle2D : public Expr {
IterDomain* out_y,
IterDomain* in_x,
IterDomain* in_y,
Swizzle2DType swizzle_type = Swizzle2DType::NoSwizzle);
Swizzle2DType swizzle_type = Swizzle2DType::NoSwizzle,
SwizzleMode swizzle_mode = SwizzleMode::Data);
Swizzle2D(const Swizzle2D* src, IrCloner* ir_cloner);
@ -1335,10 +1365,14 @@ class TORCH_CUDA_CU_API Swizzle2D : public Expr {
return in_y_;
}
const auto& swizzleType() const {
auto swizzleType() const {
return swizzle_type_;
}
auto swizzleMode() const {
return swizzle_mode_;
}
bool sameAs(const Statement* other) const override;
private:
@ -1353,7 +1387,50 @@ class TORCH_CUDA_CU_API Swizzle2D : public Expr {
// The type of predefined 1-to-1 functions
// used for swizzling math.
Swizzle2DType swizzle_type_;
Swizzle2DType swizzle_type_ = Swizzle2DType::NoSwizzle;
// Swizzle mode of this swizzle instance.
// [Note on swizzle mode]
// In the current implementation we support two modes of
// swizzle math, namely, data mode and loop mode.
// `Data` mode swizzling is a swizzle that changes the
// data layout in shared memory, and likely in global memory buffers
// as well in the future. See also IndexSwizzle in index_compute.cpp.
//
// Most important use cases are transpose bank conflict removal, and mma
// swizzled shared memory layout. Example illustrated in 1D case:
//
// for (int i = 0; i<I; i++){
// # This is a `Data` mode swizzle.
// Tshared [swizzled(i)] = Tin[i];
// }
// # Now Tshared holds swizzled data, i.e. the data layout of
// Tshared does not map to Tin with affine relationships.
//
// for(int i=0;i<I;i++){
// Tout = Tshared[swizzled(i)];
// }
//
// `Loop` mode swizzling does not affect the data layout of any buffer
// but only permutes the iteration order of serial or parallel loop.
// This is useful when we want to designate non-affine mapping of thread
// to data or we want to generate non-affine loops.
// Example illustrated in 1D case:
// for (int i = 0; i<I; i++){
// # This is a `Loop` mode swizzle
// Tshared [swizzled(i)] = Tin[swizzled(i)];
// }
// # Now Tshared holds normal data, i.e. it still has
// the same data layout as if the swizzle wasn't there.
//
// # Consumers of Tshared do not need to know about the
// loop swizzle at the previous op if not inlined.
// for(int i=0;i<I;i++){
// Tout = Tshared[i];
// }
// TODO: Loop swizzles eventually will be piped through in all mappings
// and replay of the fusion IR infrastructure.
SwizzleMode swizzle_mode_ = SwizzleMode::Data;
};
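A short scheduling sketch for the new SwizzleMode argument; the tensor and the choice of Swizzle2DType::ZShape are assumptions, only the swizzle signature shown above is taken from the change:

```cpp
// Hedged sketch: apply a loop-mode swizzle to two leaf axes of a tensor.
// Per the note above, this only permutes the iteration order of the loops;
// the data layout of tv in memory is left untouched. Swizzle2DType::ZShape
// is assumed to be one of the predefined 2D swizzle functions.
tv->swizzle(Swizzle2DType::ZShape, 0, 1, SwizzleMode::Loop);
```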
//! Integer value which has a special name


@ -599,6 +599,10 @@ void IrPrinter::handle(const kir::BlockSync* node) {
}
void IrPrinter::handle(const kir::CpAsyncWait* node) {
indent() << "CPASYNC_WAIT(" << node->keepStages() << ")\n";
}
void IrPrinter::handle(const kir::CpAsyncCommit* node) {
indent() << "CPASYNC_WAIT()\n";
}


@ -112,6 +112,7 @@ class TORCH_CUDA_CU_API IrPrinter : public OptInConstDispatch {
void handle(const kir::BlockSync*) final;
void handle(const kir::GridSync*) final;
void handle(const kir::CpAsyncWait*) final;
void handle(const kir::CpAsyncCommit*) final;
void handle(const kir::InitMagicZero*) final;
void handle(const kir::UpdateMagicZero*) final;
void handle(const kir::AllocateFusedReduction*) final;


@ -630,7 +630,7 @@ MmaOp::MmaOp(
Val* in_a,
Val* in_b,
Val* init,
MmaOptions options)
OptionsInMma options)
: MmaOp(passkey, out, in_a, in_b, init) {
options_ = options;
}
@ -1293,7 +1293,8 @@ std::pair<IterDomain*, IterDomain*> IterDomain::stridedSplit(int factor) {
std::pair<IterDomain*, IterDomain*> IterDomain::swizzle(
Swizzle2DType swizzle_type,
IterDomain* in_x,
IterDomain* in_y) {
IterDomain* in_y,
SwizzleMode swizzle_mode) {
TORCH_CHECK(
!in_x->extent()->isZeroInt() && !in_y->extent()->isZeroInt(),
"Invalid swizzling of a empty dimension.");
@ -1319,7 +1320,7 @@ std::pair<IterDomain*, IterDomain*> IterDomain::swizzle(
IterDomain* out_y = IterDomainBuilder(in_y).build();
IrBuilder::create<Swizzle2D>(
in_x->container(), out_x, out_y, in_x, in_y, swizzle_type);
in_x->container(), out_x, out_y, in_x, in_y, swizzle_type, swizzle_mode);
return std::make_pair(out_x, out_y);
}
@ -1790,7 +1791,11 @@ std::vector<IterDomain*> TensorDomain::orderedAs(
return reordered_domain;
}
void TensorDomain::swizzle(Swizzle2DType swizzle_type, int x, int y) {
void TensorDomain::swizzle(
Swizzle2DType swizzle_type,
int x,
int y,
SwizzleMode swizzle_mode) {
TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do merge on a 0-dim domain");
TORCH_CHECK(
@ -1808,7 +1813,7 @@ void TensorDomain::swizzle(Swizzle2DType swizzle_type, int x, int y) {
IterDomain* axis_out_y = nullptr;
std::tie(axis_out_x, axis_out_y) =
IterDomain::swizzle(swizzle_type, axis_x, axis_y);
IterDomain::swizzle(swizzle_type, axis_x, axis_y, swizzle_mode);
domain_.erase(domain_.begin() + x);
domain_.insert(domain_.begin() + x, axis_out_x);
@ -2039,13 +2044,15 @@ Swizzle2D::Swizzle2D(
IterDomain* out_y,
IterDomain* in_x,
IterDomain* in_y,
Swizzle2DType swizzle_type)
Swizzle2DType swizzle_type,
SwizzleMode swizzle_mode)
: Expr(passkey, ExprType::Swizzle2D),
out_x_{out_x},
out_y_{out_y},
in_x_{in_x},
in_y_{in_y},
swizzle_type_(swizzle_type) {
swizzle_type_(swizzle_type),
swizzle_mode_(swizzle_mode) {
addOutput(out_x);
addOutput(out_y);
addInput(in_x);
@ -2071,7 +2078,8 @@ Swizzle2D::Swizzle2D(const Swizzle2D* src, IrCloner* ir_cloner)
out_y_(ir_cloner->clone(src->out_y_)),
in_x_(ir_cloner->clone(src->in_x_)),
in_y_(ir_cloner->clone(src->in_y_)),
swizzle_type_(src->swizzle_type_) {}
swizzle_type_(src->swizzle_type_),
swizzle_mode_(src->swizzle_mode_) {}
NamedScalar::NamedScalar(
IrBuilderPasskey passkey,


@ -3,6 +3,7 @@
#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
#include <set>
@ -473,25 +474,23 @@ TensorView* rfactorHelper(
TensorView* reduction_tv,
const std::vector<int>& axes) {
TORCH_INTERNAL_ASSERT(reduction_tv->definition() != nullptr);
const bool is_welford = reduction_tv->definition()->isA<WelfordOp>();
if (!is_welford) {
const bool has_multiple_tvs = reduction_tv->definition()->inputs().size() > 1;
if (!has_multiple_tvs) {
return reduction_tv->rFactor(axes);
}
auto welford = reduction_tv->definition()->as<WelfordOp>();
auto w_avg = welford->outAvg()->as<TensorView>();
auto w_var = welford->outVar()->as<TensorView>();
auto w_n = welford->outN()->as<TensorView>();
auto rtvs =
reduction_tv->rFactor(axes, std::vector<TensorView*>{w_avg, w_var, w_n});
std::vector<TensorView*> out_tvs;
std::transform(
reduction_tv->definition()->outputs().begin(),
reduction_tv->definition()->outputs().end(),
std::back_inserter(out_tvs),
[](Val* val) { return val->as<TensorView>(); });
if (reduction_tv == w_n) {
return rtvs.at(2);
} else if (reduction_tv == w_var) {
return rtvs.at(1);
} else {
return rtvs.at(0);
}
auto rf_tvs = reduction_tv->rFactor(axes, out_tvs);
return rf_tvs.at(std::distance(
out_tvs.begin(),
std::find(out_tvs.begin(), out_tvs.end(), reduction_tv)));
}
namespace {
@ -652,6 +651,19 @@ std::vector<TensorView*> allTvs(Fusion* fusion) {
return uniqueEntries<TensorView>(all_tvs);
}
std::vector<TensorView*> allTvsExcept(
Fusion* fusion,
const std::unordered_set<TensorView*>& except) {
auto all_tvs = allTvs(fusion);
std::vector<TensorView*> result;
for (auto tv : all_tvs) {
if (except.count(tv) == 0) {
result.emplace_back(tv);
}
}
return result;
}
std::vector<Expr*> getReductionOps(Fusion* fusion, bool ignore_trivial) {
std::vector<Expr*> red_ops;
@ -796,6 +808,145 @@ Val* getReductionInitValOf(TensorView* tv) {
return init;
}
// TODO: Should mma be in here? Should we return true if it's a trivial
// reduction?
bool isReductionOp(const Expr* expr) {
// Note that GridReduction inherits ReductionOp
return expr->isA<ReductionOp>() || expr->isA<GroupedReductionOp>() ||
expr->isA<WelfordOp>() || expr->isA<kir::GridWelford>();
}
bool isReductionTvOp(const Expr* expr) {
return ir_utils::isTvOp(expr) && isReductionOp(expr);
}
namespace {
struct ReplaceValInIndexVal : public OptInDispatch {
public:
//! Apply replacements to index as specified in
//! replacement_map. index is assumed to consist only of Int and
//! NamedScalar
static Val* replace(
Val* index,
const std::unordered_map<Val*, Val*>& replacement_map) {
ReplaceValInIndexVal replace_index_val(replacement_map);
replace_index_val.handle(index);
// Return the original index if not replaced
if (replace_index_val.is_replaced_) {
return replace_index_val.last_visited_val_;
} else {
return index;
}
}
private:
ReplaceValInIndexVal(const std::unordered_map<Val*, Val*>& replacement_map)
: replacement_map_(replacement_map) {}
using OptOutDispatch::handle;
void handle(Val* val) override {
TORCH_INTERNAL_ASSERT(
val->isA<Int>() || val->isA<NamedScalar>() || val->isA<kir::IntPair>(),
"Invalid Val type: ",
val->toString());
// if val appears in the replacement map, stop traversing and set
// the current val with the replacement
auto it = replacement_map_.find(val);
if (it != replacement_map_.end()) {
last_visited_val_ = it->second;
is_replaced_ = true;
return;
}
// Recursively traverse its defining expr
auto def = val->definition();
if (def != nullptr) {
switch (def->etype()) {
case ExprType::UnaryOp:
case ExprType::BinaryOp:
case ExprType::Swizzle2DInt:
case ExprType::PairSelect:
handle(val->definition());
break;
default:
TORCH_INTERNAL_ASSERT(
false, "Unexpected definition: ", def->toString())
}
// last_visited_val_ is set in the expr handlers
} else {
last_visited_val_ = val;
}
}
// Clone expression after recursively replacing inputs
void handle(UnaryOp* uop) override {
handle(uop->in());
auto inp = last_visited_val_;
TORCH_INTERNAL_ASSERT(uop->out()->isA<Int>());
auto out = IrBuilder::create<Int>(c10::nullopt);
IrBuilder::create<UnaryOp>(uop->getUnaryOpType(), out, inp);
last_visited_val_ = out;
}
// Clone expression after recursively replacing inputs
void handle(BinaryOp* bop) override {
handle(bop->lhs());
auto lhs = last_visited_val_;
handle(bop->rhs());
auto rhs = last_visited_val_;
TORCH_INTERNAL_ASSERT(bop->out()->isA<Int>());
auto out = IrBuilder::create<Int>(c10::nullopt);
IrBuilder::create<BinaryOp>(bop->getBinaryOpType(), out, lhs, rhs);
last_visited_val_ = out;
}
// Clone expression after recursively replacing inputs
void handle(kir::Swizzle2DInt* swizzle_2d) override {
handle(swizzle_2d->inX());
auto in_x = last_visited_val_;
handle(swizzle_2d->inY());
auto in_y = last_visited_val_;
auto out = IrBuilder::create<kir::IntPair>();
// Extents are assumed constant in swizzle so no need to
// duplicate their graphs.
IrBuilder::create<kir::Swizzle2DInt>(
out,
in_x,
in_y,
swizzle_2d->extentX(),
swizzle_2d->extentY(),
swizzle_2d->swizzleType());
last_visited_val_ = out;
}
void handle(kir::PairSelect* pair_select) override {
handle(pair_select->in()->asVal());
auto in = last_visited_val_;
TORCH_INTERNAL_ASSERT(pair_select->out()->isA<Int>());
auto out = IrBuilder::create<Int>(c10::nullopt);
IrBuilder::create<kir::PairSelect>(
out, in->as<kir::IntPair>(), pair_select->selection());
last_visited_val_ = out;
}
private:
const std::unordered_map<Val*, Val*>& replacement_map_;
Val* last_visited_val_ = nullptr;
bool is_replaced_ = false;
};
} // namespace
Val* replaceValInIndexVal(
Val* index,
const std::unordered_map<Val*, Val*>& replacement_map) {
return ReplaceValInIndexVal::replace(index, replacement_map);
}
} // namespace ir_utils
} // namespace cuda
} // namespace fuser


@ -156,6 +156,18 @@ std::vector<int> normalizeOld2New(
// Reference is found through direct pointer comparison.
Expr* replaceValInExpr(Expr* expr, Val* reference, Val* substitute);
//! Replace Vals in an index Val as specified by replacement_map while
//! cloning the given index Val. The index val is assumed to represent
//! a tensor index consisting of Ints and arithmetic expressions.
//!
//! This is similar to replaceValInExpr but differs in that Vals are
//! cloned, so no other exprs using the same leaf Vals are
//! modified. TODO: Consider cleaning up the multiple replacement
//! routines.
Val* replaceValInIndexVal(
Val* index,
const std::unordered_map<Val*, Val*>& replacement_map);
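A minimal sketch of the new helper; the three Val* variables are hypothetical and only the signature above is assumed:

```cpp
// Hedged sketch: clone `index` with one leaf Val substituted, e.g. swapping a
// raw loop index for its magic-zero-protected version, without touching any
// other expression that still uses the original Val.
std::unordered_map<Val*, Val*> replacement_map;
replacement_map.emplace(loop_index, protected_loop_index);
Val* protected_index = ir_utils::replaceValInIndexVal(index, replacement_map);
```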
// Makes rfactor generic with reduction ops and Welford
TORCH_CUDA_CU_API TensorView* rfactorHelper(
TensorView* red_tv,
@ -282,6 +294,12 @@ TORCH_CUDA_CU_API std::vector<TensorView*> outputTvsOf(
// returns all tensor views in fusion that are used between outputs and inputs.
TORCH_CUDA_CU_API std::vector<TensorView*> allTvs(Fusion* fusion);
// returns all tensor views in fusion that are used between outputs and inputs
// except the specified set.
TORCH_CUDA_CU_API std::vector<TensorView*> allTvsExcept(
Fusion* fusion,
const std::unordered_set<TensorView*>& except);
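A tiny usage sketch (the fusion and the cached tensors are hypothetical):

```cpp
// Hedged sketch: visit every TensorView in the fusion except a couple of
// cached tensors the scheduler wants to handle separately.
for (TensorView* tv :
     ir_utils::allTvsExcept(fusion, {cached_input, cached_output})) {
  // ... schedule tv ...
}
```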
TORCH_CUDA_CU_API std::vector<Expr*> getReductionOps(
Fusion* fusion,
bool ignore_trivial = true);
@ -289,6 +307,12 @@ TORCH_CUDA_CU_API std::vector<Expr*> getReductionOps(
// Returns the initialization value of tv or nullptr if not initialized.
TORCH_CUDA_CU_API Val* getReductionInitValOf(TensorView* tv);
// Returns if Expr is a reduction op
TORCH_CUDA_CU_API bool isReductionOp(const Expr*);
// Returns if Expr is a reduction op with TensorView or TensorIndex
TORCH_CUDA_CU_API bool isReductionTvOp(const Expr*);
template <typename T>
std::string toString(const T& nodes) {
std::stringstream ss;


@ -327,7 +327,7 @@ std::vector<at::Tensor> FusionKernelRuntime::runKernelWithInput(
auto scheduler_entry = schedulers()[group_id].get();
// Check that the heuristics are matched, in the case of segmented fusion
TORCH_INTERNAL_ASSERT(!sg || scheduler_entry->heuristc() == sg->heuristic());
TORCH_INTERNAL_ASSERT(!sg || scheduler_entry->heuristic() == sg->heuristic());
if (!executors_[group_id].compiled()) {
FUSER_PERF_SCOPE("FusionKernelRuntime::runKernelWithInput::Compile");
@ -341,32 +341,16 @@ std::vector<at::Tensor> FusionKernelRuntime::runKernelWithInput(
options.index_mode = scheduler_entry->indexMode();
FusionGuard fg(fusion_to_run.get());
scheduler_entry->schedule(fusion_to_run.get());
// Load launch params for reduction and normalization kernels
if (scheduler_entry->hasReductionParam()) {
launch_params = scheduler_entry->reductionParams().lparams;
} else {
launch_params = scheduler_entry->pointwiseParams().lparams;
}
launch_params = scheduler_entry->params()->lparams;
executors_[group_id].compileFusion(
fusion_to_run.get(), inputs, launch_params, options);
} else {
// Load launch params for reduction and normalization kernels
if (scheduler_entry->hasReductionParam()) {
launch_params = scheduler_entry->reductionParams().lparams;
} else {
launch_params = scheduler_entry->pointwiseParams().lparams;
}
launch_params = scheduler_entry->params()->lparams;
}
if (profiling_) {
most_recent_executor_log_.fusion_executor = &executors_[group_id];
if (scheduler_entry->hasReductionParam()) {
most_recent_executor_log_.reduction_params =
scheduler_entry->reductionParams();
} else {
most_recent_executor_log_.pointwise_params =
scheduler_entry->pointwiseParams();
}
most_recent_executor_log_.params = scheduler_entry->params()->clone();
}
auto& executor = executors_[group_id];
@ -395,11 +379,7 @@ std::vector<at::Tensor> FusionKernelRuntime::runKernelWithInput(
}
}
std::cout << "Compiler log: " << executor.compilerLog() << "\n";
if (scheduler_entry->hasReductionParam()) {
std::cout << scheduler_entry->reductionParams().toString() << "\n";
} else {
std::cout << scheduler_entry->pointwiseParams().toString() << "\n";
}
std::cout << scheduler_entry->params()->toString() << "\n";
std::cout << "With arguments: " << executor.lastLaunchParams().toString();
std::cout << executor.kernelName() << " " << executor.bytesProcessed()
<< " bytes/ " << std::setprecision(3) << executor.kernelTimeMs()
@ -604,13 +584,8 @@ void FusionKernelRuntime::updateHeuristicsLaunchParams(
update_heuristics->heuristicsList().size() == scheduler_list_length);
for (const auto i : c10::irange(scheduler_list_length)) {
auto& schedulerPtr = heuristics_->heuristicsList()[i];
if (schedulerPtr->hasReductionParam()) {
schedulerPtr->updateLaunchConstraint(
update_heuristics->heuristicsList()[i]->reductionParams().lparams);
} else {
schedulerPtr->updateLaunchConstraint(
update_heuristics->heuristicsList()[i]->pointwiseParams().lparams);
}
schedulerPtr->updateLaunchConstraint(
update_heuristics->heuristicsList()[i]->params()->lparams);
}
}


@ -25,8 +25,7 @@ class SchedulerRuntimeInfo;
// Utilities for benchmarking and profiling
struct ExecutorLog {
c10::optional<ReductionParams> reduction_params = c10::nullopt;
c10::optional<PointwiseParams> pointwise_params = c10::nullopt;
std::shared_ptr<HeuristicParams> params = nullptr;
FusionExecutor* fusion_executor = nullptr;
};


@ -19,7 +19,7 @@ void ExpressionEvaluator::bind(
TORCH_CHECK(
value->definition() == nullptr,
"Tried to bind to a value that is computed in the kernel IR: ",
value->toString(),
value->toInlineString(),
" with ",
concrete_value);
known_values_[value] = concrete_value;


@ -93,8 +93,15 @@ GridSync::GridSync(
sync_dims_(sync_dims),
sync_buffer_(sync_buffer) {}
CpAsyncWait::CpAsyncWait(IrBuilderPasskey passkey)
: Expr(passkey, ExprType::CpAsyncWait) {
CpAsyncWait::CpAsyncWait(IrBuilderPasskey passkey, unsigned int keep_stages)
: Expr(passkey, ExprType::CpAsyncWait), keep_stages_(keep_stages) {
TORCH_INTERNAL_ASSERT(
passkey.ir_container_->isA<kir::Kernel>(),
"IR type only valid for Kernel container.");
}
CpAsyncCommit::CpAsyncCommit(IrBuilderPasskey passkey)
: Expr(passkey, ExprType::CpAsyncCommit) {
TORCH_INTERNAL_ASSERT(
passkey.ir_container_->isA<kir::Kernel>(),
"IR type only valid for Kernel container.");


@ -55,6 +55,7 @@ class Allocate;
class BlockSync;
class GridSync;
class CpAsyncWait;
class CpAsyncCommit;
class InitMagicZero;
class UpdateMagicZero;
class ForLoop;
@ -258,11 +259,27 @@ class TORCH_CUDA_CU_API BlockSync final : public Expr {
};
// CpAsyncWait represents wait intrinsics for cp.async
// TODO: expand to support different wait modes of the intrinsic
// as the analysis passes build out.
class TORCH_CUDA_CU_API CpAsyncWait final : public Expr {
public:
explicit CpAsyncWait(IrBuilderPasskey passkey);
explicit CpAsyncWait(IrBuilderPasskey passkey, unsigned int keep_stages = 0);
//! Returns the remaining number of stages that are not synchronized
//! after this op.
unsigned int keepStages() const {
return keep_stages_;
}
private:
//! Number of stages to leave un-sync'ed by this op.
unsigned int keep_stages_ = 0;
};
// CpAsyncCommit represents commit intrinsics for cp.async
// A commit intrinsic communicates the delimiter of transaction groups
// to the async load hardware. For example usage see [Circular buffer].
class TORCH_CUDA_CU_API CpAsyncCommit final : public Expr {
public:
explicit CpAsyncCommit(IrBuilderPasskey passkey);
};
// Synchronize all blocks in device, implies cooperative group launch is


@ -257,6 +257,9 @@ void GpuLower::lower(Fusion* fusion, DataType index_type) {
// Validate mma data format and compatibility if any on the fusion.
validateMma(fusion_);
// Validate swizzle usage on the fusion schedule.
validateSwizzle(fusion_);
// Compute thread predicates. Depends on parallel_dimension_map_
thread_pred_map_.build(fusion_);


@ -408,7 +408,7 @@ class AllocationInserter : public kir::ExprMutator {
// Double the allocation size if double-buffered. Record the
// original size for indexing.
if (info.buffer->isDoubleBuffered()) {
if (info.buffer->isDoubleBuffered() || info.buffer->isCircularBuffered()) {
Val* original_alloc_size = nullptr;
for (auto alloc_dim : alloc_dims) {
if (original_alloc_size == nullptr) {
@ -420,7 +420,11 @@ class AllocationInserter : public kir::ExprMutator {
}
GpuLower::current()->doubleBufferInfo().setOriginalAllocSize(
info.buffer, original_alloc_size);
alloc_dims.push_back(IrBuilder::create<Int>(2));
int double_buffer_stage = 2;
if (info.buffer->isCircularBuffered()) {
double_buffer_stage = info.buffer->circularBufferDepth();
}
alloc_dims.push_back(IrBuilder::create<Int>(double_buffer_stage));
}
// Create the allocation node


@ -120,7 +120,7 @@ class DoubleBufferFusionInspector : private IterVisitor {
using IterVisitor::handle;
void handle(TensorView* tv) final {
if (!tv->isDoubleBuffered()) {
if (!(tv->isDoubleBuffered() || tv->isCircularBuffered())) {
return;
}
@ -190,10 +190,12 @@ class DoubleBufferLoopCloner : public kir::IrVisitor {
double_buffer_loop_->iter_domain(), loop_type_);
auto start = double_buffer_loop_->start();
auto stop = double_buffer_loop_->stop();
auto stage_depth = gpu_lower->doubleBufferInfo().getStageDepthFor(
double_buffer_loop_->iter_domain());
if (loop_type_ == DoubleBufferLoopStage::Prolog) {
TORCH_INTERNAL_ASSERT(start->isZeroInt());
stop = gpu_lower->kernel()->oneVal();
stop = SimplifyingIrBuilder::create<Int>(stage_depth - 1);
} else if (
loop_type_ == DoubleBufferLoopStage::Main &&
requireEpilogue(double_buffer_load_exprs_)) {
@ -202,7 +204,8 @@ class DoubleBufferLoopCloner : public kir::IrVisitor {
} else if (loop_type_ == DoubleBufferLoopStage::Epilog) {
TORCH_INTERNAL_ASSERT(requireEpilogue(double_buffer_load_exprs_));
start = IrBuilder::subExpr(
double_buffer_loop_->stop(), gpu_lower->kernel()->oneVal());
double_buffer_loop_->stop(),
SimplifyingIrBuilder::create<Int>(stage_depth - 1));
}
cloned_top_level_loop_ = IrBuilder::create<kir::ForLoop>(
@ -217,6 +220,11 @@ class DoubleBufferLoopCloner : public kir::IrVisitor {
loop_type_);
handle(double_buffer_loop_);
if (stage_depth > 2) {
cloned_top_level_loop_->body().push_back(
IrBuilder::create<kir::CpAsyncCommit>());
}
}
void handle(kir::ForLoop* fl) final {
@ -314,7 +322,8 @@ class DoubleBufferLoopNestInspector : private kir::IrVisitor {
}
// Ignore init loop
if (!out_tv->isDoubleBuffered() || !expr->input(0)->isA<TensorView>()) {
if (!(out_tv->isDoubleBuffered() || out_tv->isCircularBuffered()) ||
!expr->input(0)->isA<TensorView>()) {
return;
}
@ -430,7 +439,11 @@ class DoubleBufferInserter : private kir::ExprMutator {
// loop is async copy. We want to wait for the gmem loads to
// finish before synchronizing the block.
if (std::any_of(loads.begin(), loads.end(), ir_utils::isCpAsyncOp)) {
auto cp_async_wait = IrBuilder::create<kir::CpAsyncWait>();
auto stage_depth =
GpuLower::current()->doubleBufferInfo().getStageDepthFor(
double_buffer_loop->iter_domain());
auto cp_async_wait =
IrBuilder::create<kir::CpAsyncWait>(stage_depth - 2);
registerInsertBefore(double_buffer_loop, cp_async_wait);
insert_cpasync_wait = true;
}
@ -506,7 +519,9 @@ class DoubleBufferInserter : private kir::ExprMutator {
// passes. Cleanups suggested in [Double Buffer Sync]
// would resolve this dependency on pass ordering.
auto end_of_loop_expr = main_loop->body().exprs().back();
auto cp_async_wait = IrBuilder::create<kir::CpAsyncWait>();
auto stage_depth = GpuLower::current()->doubleBufferInfo().getStageDepthFor(
main_loop->iter_domain());
auto cp_async_wait = IrBuilder::create<kir::CpAsyncWait>(stage_depth - 2);
// Check if a sync has been inserted by WAR sync pass.
auto block_sync_it = std::find_if(
@ -557,7 +572,9 @@ bool DoubleBufferInfo::isDoubleBufferedIterDomain(IterDomain* id) {
DoubleBufferInfo::TvInfo& DoubleBufferInfo::getTvInfo(const TensorView* tv) {
TORCH_INTERNAL_ASSERT(
tv->isDoubleBuffered(), "Not a double-buffered tensor: ", tv->toString());
tv->isDoubleBuffered() || tv->isCircularBuffered(),
"Not a double-buffered tensor: ",
tv->toString());
return map_[tv];
}
@ -565,16 +582,63 @@ void DoubleBufferInfo::setDoubleBufferAxis(
const TensorView* tv,
IterDomain* axis) {
getTvInfo(tv).double_buffer_axis = axis;
// Also validate the stage consistency with CA map.
unsigned int stage_depth = 0;
if (tv->isCircularBuffered()) {
stage_depth = tv->circularBufferDepth();
} else {
// Double buffer is essentially
// circular buffer with depth 2.
stage_depth = 2;
}
// Set and validate the new stage depth.
setStageDepth(axis, stage_depth);
}
void DoubleBufferInfo::setStageDepth(IterDomain* id, unsigned int stage_depth) {
auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID(
id, IdMappingMode::LOOP);
auto maybe_exisiting_depth_it = stage_depth_.find(concrete_loop_id);
if (maybe_exisiting_depth_it == stage_depth_.end()) {
stage_depth_[concrete_loop_id] = stage_depth;
} else {
TORCH_INTERNAL_ASSERT(
stage_depth == maybe_exisiting_depth_it->second,
"Unsupported multiple depth pipelining, was set to ",
maybe_exisiting_depth_it->second,
" by ",
maybe_exisiting_depth_it->first->toString(),
" and then set to ",
stage_depth,
" by ",
concrete_loop_id->toString());
}
}
IterDomain* DoubleBufferInfo::getDoubleBufferAxis(const TensorView* tv) {
if (!tv->isDoubleBuffered()) {
if (!(tv->isDoubleBuffered() || tv->isCircularBuffered())) {
return nullptr;
}
return getTvInfo(tv).double_buffer_axis;
}
unsigned int DoubleBufferInfo::getStageDepthFor(
IterDomain* double_buffer_axis) {
auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID(
double_buffer_axis, IdMappingMode::LOOP);
auto maybe_depth_it = stage_depth_.find(concrete_id);
TORCH_INTERNAL_ASSERT(
maybe_depth_it != stage_depth_.end(), "Stage depth not found");
return maybe_depth_it->second;
}
kir::ForLoop* DoubleBufferInfo::getDoubleBufferLoop(
IterDomain* axis,
const std::vector<kir::ForLoop*>& loops,
@ -582,7 +646,8 @@ kir::ForLoop* DoubleBufferInfo::getDoubleBufferLoop(
auto loop_it = std::find_if(loops.begin(), loops.end(), [&](const auto loop) {
return GpuLower::current()->caMap()->areMapped(
loop->iter_domain(), axis, IdMappingMode::EXACT) &&
(!ignore_prologue || !loop->stop()->isOneInt());
(!ignore_prologue ||
loop->doubleBufferLoopStage() != DoubleBufferLoopStage::Prolog);
});
if (loop_it != loops.end()) {
@ -612,7 +677,7 @@ void DoubleBufferInfo::setOriginalAllocSize(
}
Val* DoubleBufferInfo::getOriginalAllocSize(const TensorView* tv) {
if (!tv->isDoubleBuffered()) {
if (!(tv->isDoubleBuffered() || tv->isCircularBuffered())) {
return nullptr;
}


@ -77,6 +77,85 @@
// - Double the allocation size
// - Omit the RAW sync in the Main and Epilogue loops
// [Circular buffer] A generalization of double buffering.
// On sm80+ hardware there is asynchronous copy infrastructure that
// motivates a circular buffering generalization of double buffering.
// Almost all analyses previously done for double buffering are exactly
// the same for circular buffering, except for the introduction of a
// new concept: `stage depth`.
//
// The `stage depth` is defined as the multiplier of extra buffering
// space used. In the case of double buffering, the stage depth would
// be 2.
//
// A circular buffered loop structure looks like the following, which
// exactly parallels the double buffered loop structure, since
// it is an exact generalization serving the same purpose.
//
// Here S is the original allocation size as above,
// D is the stage depth. With D=2, the below loop structure becomes
// exactly the same as the case in double buffering.
//
// allocate X[S*D] // allocation
// for i in 0..D-1: // prolog
// for j in ...
// if pred:
// x[i*S+j] = y[i, j];
//
// for i in 0..N: // main loop
// for j in ...
// if pred:
// x[((i+D-1)%D)*S+j] = y[i+D-1, j];
// for j in ...
// .. = x[(i%D)*S+j]
//
// (Epilog omitted since it only makes sense when using
// cp.async, where the producer will be in global mem and the consumer
// will be in shared mem).
//
// The profitability of this optimization comes from extra tolerance
// of global memory pipeline latency, as on the expression `.. = x[(i%D)*S+j]`
// we only need to make sure the data for the current iteration is
// completed while the remaining D-2 load iterations could still be in progress
// and overlap with the computes of the current loop.
//
// To express this pattern on sm80+ hardware we can group the loads
// in each iteration of the circular buffered loop as one "transaction",
// and specify how many transactions we want to ensure have completed when
// we insert the async barriers.
//
// allocate X[S*D] // allocation
// for i in 0..D-1: // prolog
// for j in ...
// if pred:
// x[i*S+j] = y[i, j];
// cp.async.commit; // mark the transaction boundary
//
// # At this point we have D-1 transactions on the fly,
// and for the first iteration of the main loop we need
// one transaction completed, so we leave D-2 transactions
// on the fly, which would be the input to the barrier instruction.
//
// cp.async.wait D-2 // ensure all but the last D-2 transactions complete.
//
// for i in 0..N: // main loop
// # At this point we always have D-2 transactions on the fly,
// and one completed.
// for j in ...
// if pred:
// x[((i+D-1)%D)*S+j] = y[i+D-1, j];
// for j in ...
// .. = x[(i%D)*S+j]
// cp.async.commit; // mark the transaction boundary for the
// load issued in this iteration.
// # At this point we have D-1 transactions on the fly,
// and none completed.
// cp.async.wait D-2; // Ensure all but the last D-2 transactions complete.
// __syncthreads(); // Need to syncthreads because each thread will only
// ensure completion of its own async copies so
// would need to sync to this point to ensure
// completion of the whole tile.
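A minimal scheduling sketch that ties the new circular buffering API to the loop structure above; the tensor and the depth of 4 are illustrative, and setMemoryType is assumed to be the usual TensorView API:

```cpp
// Hedged sketch: stage an operand tile in shared memory with a 4-stage
// circular buffer. With D = 4, lowering allocates S*4 elements, emits a
// 3-iteration prolog, and waits with cp.async.wait 2 (= D-2) before the
// main-loop compute, as described above.
tv_smem->setMemoryType(MemoryType::Shared);
tv_smem->circularBuffer(/*number_of_stage=*/4);
TORCH_INTERNAL_ASSERT(tv_smem->isCircularBuffered());
TORCH_INTERNAL_ASSERT(tv_smem->circularBufferDepth() == 4);
```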
namespace torch {
namespace jit {
namespace fuser {
@ -132,9 +211,23 @@ class TORCH_CUDA_CU_API DoubleBufferInfo {
//! as a double buffer loop.
bool isDoubleBufferedIterDomain(IterDomain* id);
//! Get the number of circular buffer stages for the given axis,
//! the number of stages will be 2 in the case of double buffer loop.
unsigned int getStageDepthFor(IterDomain* circular_buffered_id);
private:
TvInfo& getTvInfo(const TensorView* tv);
//! Set the number of circular buffer stages for the given
//! circular_buffered_id.
//! Current code generation only supports one stage depth per loop disjoint
//! set, so this function will throw an error if trying to set different
//! stage numbers to iterdomains that are loop mapped.
void setStageDepth(
IterDomain* circular_buffered_id,
unsigned int stage_depth);
private:
//! Keeps track of information for lowering double buffered tensors
std::unordered_map<const TensorView*, TvInfo> map_;
@ -142,6 +235,12 @@ class TORCH_CUDA_CU_API DoubleBufferInfo {
//! Keeps track of which concrete loop map is realizing double buffer
//! iterdomains.
std::unordered_set<const IterDomain*> concrete_double_buffered_loop_id_;
//! Keeps track of double buffer loop stage depth.
//! Currently, for each disjoint set of loop mapped iterdomains,
//! only one stage depth is supported, so that the loops can indeed
//! share the same prolog extent and main loop offset.
std::unordered_map<IterDomain*, unsigned int> stage_depth_;
};
} // namespace cuda


@ -252,8 +252,10 @@ class ExprSegmentationSorter {
// Allocate an empty expr group and return it
ExprGroup* makeEmptyGroup();
// Allocate an expr group with the provided expr and return it
ExprGroup* makeEmptyGroup(Expr*);
// Allocate an expr group with the provided expr and return it. Also requires
// information on whether this expression is a terminating expression (none of
// its outputs are used in other expressions being sorted).
ExprGroup* makeEmptyGroup(Expr*, bool terminating_expr);
// Returns if sg1 and sg2 should be merged together, is called if they can
// based on the current status of the DAG.
@ -538,14 +540,19 @@ ExprGroup* ExprSegmentationSorter::makeEmptyGroup() {
return groups_.back().get();
}
ExprGroup* ExprSegmentationSorter::makeEmptyGroup(Expr* expr) {
ExprGroup* ExprSegmentationSorter::makeEmptyGroup(
Expr* expr,
bool terminating_expr) {
auto group = makeEmptyGroup();
group->exprs().push_back(expr);
if (ir_utils::isTvOp(expr)) {
auto out_tv = expr->outputs()[0]->as<TensorView>();
// Grab all id's that are shared with other tensors.
for (const auto tv_i : c10::irange(out_tv->getComputeAtPosition())) {
group->payload()->ca_domains_.push_back(out_tv->axis(tv_i));
// If not connected to consumers, it doesn't matter what compute at is set to
if (!terminating_expr) {
for (const auto tv_i : c10::irange(out_tv->getComputeAtPosition())) {
group->payload()->ca_domains_.push_back(out_tv->axis(tv_i));
}
}
for (const auto tv_i : c10::irange(out_tv->getMaxProducerPosition())) {
group->payload()->pa_domains_.push_back(out_tv->axis(tv_i));
@ -1219,9 +1226,24 @@ void ExprSegmentationSorter::sort() {
// Need this for initialization of the DAG that is processed
std::unordered_map<Expr*, ExprGroup*> expr2group;
auto all_exprs = fusion_->exprs();
// Figure out all the values used as inputs to the expressions we're sorting
// (to find terminating expressions). There could be branches of expressions
// not used to produce outputs, so we can't simply check val->uses() to figure
// out if it's actually used in the expressions we're sorting.
std::unordered_set<Val*> used_vals;
for (auto expr : all_exprs) {
used_vals.insert(expr->inputs().begin(), expr->inputs().end());
}
// Initialize DAG, convert each expr to a segment group
for (auto expr : fusion_->exprs()) {
auto group = makeEmptyGroup(expr);
for (auto expr : all_exprs) {
bool is_terminating_expr = std::none_of(
expr->outputs().begin(),
expr->outputs().end(),
[&used_vals](Val* output) { return used_vals.count(output) > 0; });
auto group = makeEmptyGroup(expr, is_terminating_expr);
expr2group.insert(std::make_pair(expr, group));
}

View File

@ -834,6 +834,11 @@ void IndexLowering::handle(const kir::CpAsyncWait* wait) {
pushBack(const_cast<kir::CpAsyncWait*>(wait)); // NOLINT
}
void IndexLowering::handle(const kir::CpAsyncCommit* commit) {
// TODO(kir): remove the need for const_cast
pushBack(const_cast<kir::CpAsyncCommit*>(commit)); // NOLINT
}
void IndexLowering::generate(const std::vector<Expr*>& exprs) {
for (auto expr : exprs) {
OptOutConstDispatch::handle(expr);

View File

@ -55,6 +55,7 @@ class TORCH_CUDA_CU_API IndexLowering : private OptOutConstDispatch {
void handle(const kir::BlockSync*) final;
void handle(const kir::GridSync*) final;
void handle(const kir::CpAsyncWait*) final;
void handle(const kir::CpAsyncCommit*) final;
void generate(const std::vector<Expr*>& exprs);

View File

@ -3,6 +3,7 @@
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/lower_index_compute.h>
#include <torch/csrc/jit/codegen/cuda/lower_magic_zero.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
#include <torch/csrc/jit/codegen/cuda/lower_validation.h>
#include <torch/csrc/jit/codegen/cuda/transform_iter.h>
@ -24,39 +25,6 @@ IndexFromIdGraph::IndexFromIdGraph(
namespace {
void insertMagicZero(
const std::vector<kir::ForLoop*>& loops,
const std::vector<IterDomain*>& loop_domains,
std::unordered_map<IterDomain*, Val*>& concrete_loop_idx_map) {
// Find magic zero insertion point
IterDomain* magic_zero_loop = nullptr;
// Search for proper magic zero insertion point,
// prefer innermost.
for (auto idx : c10::irange(loops.size())) {
auto loop = loops[idx];
auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID(
loop_domains[idx], IdMappingMode::EXACT);
auto loop_ind = concrete_loop_idx_map.at(concrete_loop_id);
// Save the concrete id if this loop id is decided to
// be the insertion point by the magic zero util.
if (Index::protectWithMagicZero(loop, concrete_loop_id, loop_ind)) {
magic_zero_loop = concrete_loop_id;
}
}
// Insert magic zero if insertion point found
if (magic_zero_loop != nullptr &&
concrete_loop_idx_map.count(magic_zero_loop)) {
auto& ind = concrete_loop_idx_map.at(magic_zero_loop);
if (!ind->isConstScalar()) {
ind = SimplifyingIrBuilder::addExpr(
ind, GpuLower::current()->kernel()->magicZeroVal());
}
}
}
// Maps all producer domains to consumer with broadcast
// forwarding. Used to find the allocation position.
// TODO: should this be an ir_util ? Didn't seem to be
@ -159,7 +127,7 @@ IndexingParameters getGlobalIndexParameters(
index_parameters.concrete_id_to_halo_extent =
GpuLower::current()->haloInfo().buildConcreteHaloExtentMap(loop_indexing);
insertMagicZero(
protectNonPredicateIndexWithMagicZero(
loops,
loop_indexing.loopDomains(),
index_parameters.initial_concrete_id_index);
@ -182,10 +150,13 @@ IndexingParameters getGlobalIndexParameters(
auto concrete_loop_id = ir_utils::caMapExactConcreteId(loop_id);
auto stage_depth =
GpuLower::current()->doubleBufferInfo().getStageDepthFor(
loop->iter_domain());
index_parameters.initial_concrete_id_index[concrete_loop_id] =
SimplifyingIrBuilder::addExpr(
index_parameters.initial_concrete_id_index[concrete_loop_id],
GpuLower::current()->kernel()->oneVal());
SimplifyingIrBuilder::create<Int>(stage_depth - 1));
}
}
}
@ -412,9 +383,12 @@ IndexingParameters getPredicateInitialIndexParameters(
// be true that that index has been modified to support
// unswitch. In that case, it is not necessary to move ahead the
// index for double buffering.
auto stage_depth =
GpuLower::current()->doubleBufferInfo().getStageDepthFor(
db_loop->iter_domain());
if (cur_index == db_loop->index()) {
loop_to_ind_map[db_loop] = SimplifyingIrBuilder::addExpr(
cur_index, GpuLower::current()->kernel()->oneVal());
cur_index, SimplifyingIrBuilder::create<Int>(stage_depth - 1));
}
}
}
@ -428,10 +402,9 @@ IndexingParameters getPredicateInitialIndexParameters(
loop_to_ind_map.at(loop);
}
insertMagicZero(
loops,
loop_indexing.loopDomains(),
index_parameters.initial_concrete_id_index);
// Note that, unlike non-predicate indexing, magic-zero insertion is
// not done at this point but is done individually for each indexed
// domain. See Index::getReferenceRootPredicates.
// Derive the halo extents from the loop indexing result.
index_parameters.concrete_id_to_halo_extent =

View File

@ -1,6 +1,7 @@
#pragma once
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/index_compute.h>
namespace torch {
namespace jit {

View File

@ -124,13 +124,15 @@ CommonIndexKey::CommonIndexKey(
if (it != concrete_leaf_ids.end()) {
// This leaf reference id is used for indexing the consumer id
used_loops_.push_back(loop);
auto index_it =
loop_index_map.find(gpu_lower->caMap()->getConcreteMappedID(
loop_domains.at(i), IdMappingMode::EXACT));
auto loop_concrete_id = gpu_lower->caMap()->getConcreteMappedID(
loop_domains.at(i), IdMappingMode::EXACT);
auto index_it = loop_index_map.find(loop_concrete_id);
TORCH_INTERNAL_ASSERT(
index_it != loop_index_map.end(),
"Index not found for leaf ID, ",
loop_domains.at(i)->toString());
loop_domains.at(i)->toString(),
", concrete ID: ",
loop_concrete_id->toString());
loop_index_vals_.push_back(index_it->second);
}
}
@ -235,6 +237,7 @@ std::pair<Val*, bool> CommonIndexMap::insert(
const CommonIndexKey key(
indexed_consumer_id, consumer_td, ref_td, ref_index_map, loops);
return tryInsertNewIndex(key, index);
}

View File

@ -724,7 +724,7 @@ class ReadAfterWriteSyncs : public kir::ExprMutator {
// here, except for the initial load part, which is taken care
// separately by DoubleBufferInserter.
if (tv->getMemoryType() == MemoryType::Shared &&
!tv->isDoubleBuffered()) {
!(tv->isDoubleBuffered() || tv->isCircularBuffered())) {
smem[tv] = expr;
// only keep track of async writes in smem_async

View File

@ -2,8 +2,10 @@
#include <torch/csrc/jit/codegen/cuda/dispatch.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/lower_index_compute.h>
namespace torch {
namespace jit {
@ -88,6 +90,133 @@ bool isProtectedWithMagicZero(const Val* val) {
return bop->getBinaryOpType() == BinaryOpType::Add && isMagicZero(bop->rhs());
}
bool needsMagicZero(
kir::ForLoop* loop,
IterDomain* reference_domain,
Val* ind) {
if (ind->isConstScalar()) {
return false;
}
bool ref_dom_simple =
reference_domain == nullptr || reference_domain->definition() != nullptr;
bool ind_simple =
ind == nullptr || (ind->definition() != nullptr && !ind->isZeroInt());
return loop->isUnrolled() && (!ref_dom_simple || !ind_simple);
}
void protectNonPredicateIndexWithMagicZero(
const std::vector<kir::ForLoop*>& loops,
const std::vector<IterDomain*>& loop_domains,
std::unordered_map<IterDomain*, Val*>& concrete_loop_idx_map) {
// Find magic zero insertion point
IterDomain* magic_zero_loop = nullptr;
// Search for proper magic zero insertion point,
// prefer innermost.
for (auto idx : c10::irange(loops.size())) {
auto loop = loops[idx];
auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID(
loop_domains[idx], IdMappingMode::EXACT);
auto loop_ind = concrete_loop_idx_map.at(concrete_loop_id);
// Save the concrete id if this loop id is decided to
// be the insertion point by the magic zero util.
if (needsMagicZero(loop, concrete_loop_id, loop_ind)) {
magic_zero_loop = concrete_loop_id;
}
}
// Insert magic zero if insertion point found
if (magic_zero_loop != nullptr &&
concrete_loop_idx_map.count(magic_zero_loop)) {
auto& ind = concrete_loop_idx_map.at(magic_zero_loop);
ind = SimplifyingIrBuilder::addExpr(
ind, GpuLower::current()->kernel()->magicZeroVal());
}
}
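To make the effect concrete, a hand-written CUDA illustration of the idea (not codegen output; the kernel shape and the `opaque_zero` name are made up):

```cpp
// An "opaque zero" the compiler cannot fold away keeps it from
// over-aggressively reusing index/predicate registers across the
// iterations of an unrolled loop.
__global__ void unrolledCopy(float* out, const float* in, int n, int opaque_zero) {
  // opaque_zero is always launched with the value 0, but the compiler
  // cannot prove that.
  int base = (threadIdx.x + blockIdx.x * blockDim.x) * 4;
#pragma unroll
  for (int i = 0; i < 4; ++i) {
    int idx = base + (i + opaque_zero); // protected index
    if (idx < n) {
      out[idx] = in[idx];
    }
  }
}
```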
namespace {
//! Protect loop_index_to_protect appearing in overall_index_val
IndexMagicZeroInfo protectIndexByReplacingLoopIndex(
IterDomain* loop_id,
Val* overall_index_val,
Val* loop_index_to_protect) {
auto protected_loop_index = SimplifyingIrBuilder::addExpr(
loop_index_to_protect, GpuLower::current()->kernel()->magicZeroVal());
std::unordered_map<Val*, Val*> replacement_map;
replacement_map[loop_index_to_protect] = protected_loop_index;
auto protected_index =
ir_utils::replaceValInIndexVal(overall_index_val, replacement_map);
IndexMagicZeroInfo info;
info.index = protected_index;
info.original_loop_index = loop_index_to_protect;
info.protected_loop_index = protected_loop_index;
info.loop_id = loop_id;
return info;
}
} // namespace
IndexMagicZeroInfo protectPredicateIndexWithMagicZero(
Val* index,
const IndexFromIdGraph& id_graph,
const std::vector<kir::ForLoop*>& loops) {
// Gather the loop indices
std::unordered_set<Val*> loop_indices;
for (auto loop_id : id_graph.resolved_loop_domains) {
auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID(
loop_id, IdMappingMode::EXACT);
auto index_it = id_graph.initial_concrete_index_map.find(concrete_loop_id);
TORCH_INTERNAL_ASSERT(
index_it != id_graph.initial_concrete_index_map.end(),
"Index not found for loop: ",
concrete_loop_id->toString());
auto loop_index = index_it->second;
loop_indices.insert(loop_index);
}
// Figure out which loop indices are used in index
const auto vals = DependencyCheck::getAllValsBetween(loop_indices, {index});
// Traverse from the inner-most loop and apply the magic-zero
// protection if needed
for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) {
auto loop = loops.at(i);
auto loop_id = id_graph.resolved_loop_domains.at(i);
TORCH_INTERNAL_ASSERT(GpuLower::current()->caMap()->areMapped(
loop_id, loop->iter_domain(), IdMappingMode::PERMISSIVE));
IterDomain* concrete_loop_id =
GpuLower::current()->caMap()->getConcreteMappedID(
loop_id, IdMappingMode::EXACT);
auto index_it = id_graph.initial_concrete_index_map.find(concrete_loop_id);
TORCH_INTERNAL_ASSERT(
index_it != id_graph.initial_concrete_index_map.end());
auto loop_index = index_it->second;
const auto is_loop_index_used =
std::find(vals.begin(), vals.end(), loop_index) != vals.end();
if (!is_loop_index_used) {
continue;
}
if (needsMagicZero(loop, concrete_loop_id, loop_index)) {
return protectIndexByReplacingLoopIndex(loop_id, index, loop_index);
}
}
// No loop is identified to require protection with magic zero. Just
// return the index argument as is
IndexMagicZeroInfo not_protected;
not_protected.index = index;
return not_protected;
}
} // namespace cuda
} // namespace fuser
} // namespace jit

View File

@ -10,6 +10,8 @@ namespace jit {
namespace fuser {
namespace cuda {
struct IndexFromIdGraph;
//! Insert magic zero definition at the beginning of the kernel. Insert magic
//! zero update after every (outermost) loop nest with a compile time extent.
//!
@ -17,13 +19,59 @@ namespace cuda {
std::vector<Expr*> insertMagicZero(const std::vector<Expr*>& exprs);
//! Check if val is a reference to the magic zero variable
bool isMagicZero(const Val* val);
TORCH_CUDA_CU_API bool isMagicZero(const Val* val);
//! Check if val is protected with magic zero.
//!
//! Specifically, this returns true if val is defined as "x + magic_zero".
bool isProtectedWithMagicZero(const Val* val);
// Determine if we may run into over-reuse of predicates or registers in the
// compiler. If the loop can be unrolled and the index and domain are not
// "simple", we likely want the loop protected.
//
// Magic zero protection should only be done for global memory and predicates.
// We should avoid using it on registers. Shared memory does not require it, but
// likely wouldn't hurt.
bool needsMagicZero(
kir::ForLoop* loop,
IterDomain* reference_domain = nullptr,
Val* ind = nullptr);
struct IndexMagicZeroInfo {
//! Index that may be updated with magic zero
Val* index = nullptr;
//! Loop index that is protected by magic zero. nullptr if no loop
//! is protected
Val* original_loop_index = nullptr;
//! Protected loop index. nullptr if no loop is protected
Val* protected_loop_index = nullptr;
//! Protected loop. nullptr if no loop is protected
IterDomain* loop_id = nullptr;
};
//! Protect an index val of an IterDomain with magic zero
//!
//! This should be only used for predicate indexing.
//!
//! No protection is done if none of the loops is determined to require
//! protection by needsMagicZero.
IndexMagicZeroInfo protectPredicateIndexWithMagicZero(
Val* index,
const IndexFromIdGraph& id_graph,
const std::vector<kir::ForLoop*>& loops);
//! Protect an index val of a tensor with magic zero
//!
//! This should be only used for non-predicate indexing.
//!
//! No protection is done if none of the loops is determined to require
//! protection by needsMagicZero.
void protectNonPredicateIndexWithMagicZero(
const std::vector<kir::ForLoop*>& loops,
const std::vector<IterDomain*>& loop_domains,
std::unordered_map<IterDomain*, Val*>& concrete_loop_idx_map);
} // namespace cuda
} // namespace fuser
} // namespace jit

View File

@ -137,6 +137,16 @@ void SyncMap::build(Fusion* fusion) {
->threadPredMap()
.getPredicateInfo(producer)
.redundant_types;
// Get the parallel types that are inactive in consumer's use chains.
auto producer_redundant_use_types = GpuLower::current()
->threadPredMap()
.getPredicateInfo(producer)
.redundant_use_types;
// In the sync info pass we only consider the parallel types in the
// producer that are redundantly produced but not redundantly consumed.
producer_redundant_types =
producer_redundant_types & (~producer_redundant_use_types);
for (const auto producer_i : c10::irange(producer->nDims())) {
auto producer_axis = producer->axis(producer_i);
@ -205,25 +215,24 @@ void SyncMap::build(Fusion* fusion) {
continue;
}
auto parallel_type_i = getParallelTypeBitMapOffset(parallel_type);
auto p_id = producer_parallel_ids[parallel_type_i];
auto c_id = consumer_parallel_ids[parallel_type_i];
// If the consumer is parallelized with this type but the producer is
// predicated redundant on this type, this parallel dimension
// is a RAW dimension. See test: FusionSeriaSmemWriteParallelRead1/2
//
// Even if the consumer is not parallelized with this type, we would still
// need a raw sync unless all use chains of the producer end with an
// output with the same redundant type.
// TODO: need a separate pass to detect the case where no raw sync
// is needed in this case, i.e. all use-def chains are redundant.
// In the case when the parallel id's are mapped by the ca map,
// we will additionally need to consider whether the producer is
// a redundant write. The raw dim can be skipped only if
// the consumer use chains only contain redundant uses.
// TODO:
//  still losing a bit of precision here for expr-ordering
//  sensitive cases, but we could wait until that becomes
//  a perf limiter to fix.
if (producer_redundant_types.get(parallel_type)) {
raw_dims.set(parallel_type);
continue;
}
auto parallel_type_i = getParallelTypeBitMapOffset(parallel_type);
auto p_id = producer_parallel_ids[parallel_type_i];
auto c_id = consumer_parallel_ids[parallel_type_i];
if (p_id == nullptr && c_id == nullptr) {
continue;
} else if (p_id != nullptr && c_id != nullptr) {

View File

@ -210,11 +210,11 @@ void ThreadPredicateMap::updateBitSet(const Expr* expr) {
auto tv_inp = inp->as<TensorView>();
// Change for welford Op, we want the users of all outputs of welfordOp
// to use a single predicate name.
// If tv_inp was an output of a multi-output expression, just change it to a
// consistent sibling to use a single predicate name.
if (auto tv_def = tv_inp->definition()) {
if (auto wop = dynamic_cast<WelfordOp*>(tv_def)) {
tv_inp = wop->out()->as<TensorView>();
if (tv_def->outputs().size() > 1) {
tv_inp = ir_utils::getTvOutput(tv_def);
}
}
@ -285,6 +285,184 @@ void ThreadPredicateMap::updateBitSet(const Expr* expr) {
}
}
namespace {
//! A simple backward data flow pass:
//! This pass propagates information backward to annotate "redundant use
//! chains".
//! The reason this is needed is that, for example, if we have a chain
//! of register-to-register ops that begins with a redundant shared mem write
//! and ends with an op that non-redundantly uses the result, we'd need to
//! insert a sync at the beginning of the register-to-register chain.
//!
//! The same mechanism also applies in the case of a register/sharedmem chain
//! that starts and ends with global memory read/write.
//!
//! The propagation rule is summarized as follows:
//!
//! Shared TV val:
//! Reset all block redundant info to its own redundant write info
//! Backpropagate grid redundant info
//! Global TV val:
//! Reset all redundant info to its own redundant write info
//! Local TV val:
//! Backpropagate all redundant info
//! Exprs:
//! Propagate redundant info backwards from outputs to inputs:
//! For each parallel type,
//! The parallel type is redundantly used in the expr input
//! only if all of the outputs redundantly use the same type.
class RedundantUseAnalysis : BackwardVisitor {
public:
RedundantUseAnalysis(Fusion* fusion, const ThreadPredicateMap& pred_map)
: fusion_(fusion), pred_map_(pred_map) {
traverseFrom(fusion, fusion->terminatingMathVals());
}
//! Returns a bit map signifying the parallel dimensions
//! on which the given tv is redundantly used. On these
//! dimensions, not all threads/blocks are required to
//! hold a valid value for their dependent computations.
ParallelTypeBitmap getRedundantUseBitMap(const TensorView* tv) {
// Since all tv's consumers are visited at this point, we
// can aggregate the final redundant use info for this tv.
if (fusion_->unordered_uses(tv).empty()) {
// Base case: an unused tv is also not redundantly used
return ParallelTypeBitmap();
} else {
// Aggregate redundant use as a conjunction of all
// consumers' redundant-use info propagated
// backward from their consumer chains.
ParallelTypeBitmap redundant_use;
redundant_use.setAllBID();
redundant_use.setAllTID();
for (auto expr : fusion_->unordered_uses(tv)) {
redundant_use &= redundant_expr_use_map_.at(expr);
}
return redundant_use;
}
}
private:
using BackwardVisitor::handle;
void handle(TensorView* tv) final {
auto redundant_tv_map = pred_map_.getPredicateInfo(tv).redundant_types;
// Setup the info to propagate backward for the producer tv's and
// expressions.
ParallelTypeBitmap& redundant_consumer_map =
redundant_consumer_parallel_type_map_[tv];
// Initialize the use map to the redundant pred result
redundant_consumer_map = redundant_tv_map;
if (tv->getMemoryType() == MemoryType::Shared) {
backPropagateRedundantUse(
redundant_consumer_map,
tv,
false, // no propagate TID redundant use for shared tv
true // propagate BID redundant use
);
} else if (tv->getMemoryType() == MemoryType::Local) {
backPropagateRedundantUse(
redundant_consumer_map,
tv,
true, // propagate TID redundant use
true // propagate BID redundant use
);
}
}
void backPropagateRedundantUse(
ParallelTypeBitmap& use_map,
TensorView* tv,
bool propagate_tid,
bool propagate_bid) {
// Clear the propagated part of the original result
if (propagate_bid) {
use_map.setAllBID();
}
if (propagate_tid) {
use_map.setAllTID();
}
for (auto expr : fusion_->unordered_uses(tv)) {
// Assuming all consumer expressions have been
// visited at this point since we are traversing
// backward.
auto expr_use_map = redundant_expr_use_map_.at(expr);
// Clear the part of expression use map that does not
// need to be propagated.
if (!propagate_bid) {
expr_use_map.setAllBID();
}
if (!propagate_tid) {
expr_use_map.setAllTID();
}
// Accumulate expression redundant usage
// This implements the `only if all` part in
// the discussion above.
use_map &= expr_use_map;
}
}
void handle(Expr* expr) final {
if (ir_utils::isTvOp(expr)) {
// Initialize redundant info for current expr
c10::optional<ParallelTypeBitmap> maybe_expr_pred_map;
for (auto consumer_tv :
ir_utils::filterByType<TensorView>(expr->outputs())) {
auto tv_redundant_bitmap =
redundant_consumer_parallel_type_map_.at(consumer_tv);
if (maybe_expr_pred_map.has_value()) {
// Accumulate redundant info of this tv output.
maybe_expr_pred_map.value() &= tv_redundant_bitmap;
} else {
// Copy the tv's redundant info as the first valid case.
maybe_expr_pred_map = tv_redundant_bitmap;
}
}
TORCH_INTERNAL_ASSERT(
maybe_expr_pred_map.has_value(), "TV op not having a tv output");
redundant_expr_use_map_[expr] = maybe_expr_pred_map.value();
}
}
private:
// Populated redundant use information on the used tv's.
// This map provides information on whether the given tv does not require
// valid data from its producer on any parallel dimensions.
// For example:
// T1_local = T0_shared[...]
// if(tid.x == 0)
// T2_shared[...] = T1_local[...]
// Then tidx would be a redundant consumer parallel type
// for T1, as T1 is a local tensor, and only threads satisfying
// tidx == 0 would need to provide valid data.
// In this case, not all threads would need to read correct data
// from T0_shared, which helps remove some syncs.
std::unordered_map<const TensorView*, ParallelTypeBitmap>
redundant_consumer_parallel_type_map_;
// Populated redundant use information on the used tv expressions.
std::unordered_map<const Expr*, ParallelTypeBitmap> redundant_expr_use_map_;
// Short cut to the owning fusion of this analysis.
Fusion* fusion_ = nullptr;
// Short cut to the active pred map analysis this pass is running as part of.
const ThreadPredicateMap& pred_map_;
};
} // namespace
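A standalone sketch of the "only if all" conjunction and the memory-type rules described in the comment above (plain C++ bitsets, not the nvfuser `ParallelTypeBitmap` API):

```cpp
#include <bitset>
#include <vector>

// Bits 0..2 stand in for TIDx/y/z, bits 3..5 for BIDx/y/z.
using Bits = std::bitset<6>;
const Bits kAllTID{0b000111};
const Bits kAllBID{0b111000};

// A parallel type stays redundantly used by a tensor's consumers only if
// every consumer expression redundantly uses it; an unused tensor is not
// redundantly used at all.
Bits aggregateConsumerUse(const std::vector<Bits>& consumer_expr_use) {
  if (consumer_expr_use.empty()) {
    return Bits{};
  }
  Bits use = kAllTID | kAllBID;
  for (const Bits& b : consumer_expr_use) {
    use &= b;
  }
  return use;
}

// Memory-type-dependent rule for a tensor: global resets everything to its
// own redundant-write info, shared keeps its own TID info but back-propagates
// BID info, local back-propagates everything.
Bits redundantUseForTv(Bits own_write, Bits from_consumers, bool is_shared, bool is_global) {
  if (is_global) {
    return own_write;
  }
  if (is_shared) {
    return (own_write & kAllTID) | (from_consumers & kAllBID);
  }
  return from_consumers;
}
```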
void ThreadPredicateMap::build(Fusion* fusion) {
FUSER_PERF_SCOPE("GpuLower::Lower::ThreadPredicateMap");
@ -298,6 +476,15 @@ void ThreadPredicateMap::build(Fusion* fusion) {
updateBitSet(expr);
}
updated_tvs_.clear();
populateRedundantUseMap(fusion);
}
void ThreadPredicateMap::populateRedundantUseMap(Fusion* fusion) {
RedundantUseAnalysis redundant_use(fusion, *this);
for (auto& it : thread_predicates_) {
it.second.redundant_use_types =
redundant_use.getRedundantUseBitMap(it.first);
}
}
ThreadPredicateMap::const_iterator ThreadPredicateMap::find(
@ -399,6 +586,23 @@ ParallelTypeBitmap ThreadPredicateMap::getParallelBroadcastDomains(
return parallel_broadcast & at(tv).limited_types;
}
ParallelTypeBitmap ThreadPredicateMap::getRedundantConsumerType(
Expr* expr) const {
c10::optional<ParallelTypeBitmap> result;
for (auto out_tv : ir_utils::filterByType<TensorView>(expr->outputs())) {
auto out_tv_redundant_map = getPredicateInfo(out_tv).redundant_use_types;
if (!result.has_value()) {
result = out_tv_redundant_map;
} else {
result.value() &= out_tv_redundant_map;
}
}
TORCH_INTERNAL_ASSERT(
result.has_value(), "ThreadPredicateMap : TV op assumed");
return result.value();
}
void ThreadPredicateMap::markAsUpdated(const TensorView* tv) {
updated_tvs_.insert(tv);
}
@ -410,6 +614,7 @@ void ThreadPredicateMap::print() const {
std::cout << "T" << kv.first->name();
std::cout << " {" << kv.second.limited_types.toString() << "}\n";
std::cout << "{" << kv.second.redundant_types.toString() << "}\n";
std::cout << "{" << kv.second.redundant_use_types.toString() << "}\n";
}
std::cout << "--------------------------------\n\n";
}

View File

@ -48,9 +48,24 @@ class TORCH_CUDA_CU_API ThreadPredicateMap {
ParallelTypeBitmap limited_types;
// Parallel types where only one thread/block is enough.
ParallelTypeBitmap redundant_types;
// Tracking use chain of redundant writes:
// [Redundant use chain]
// A parallel type is a `redundant_consumer_type` only
// if all of its propagation use chains terminate with
// a redundant write of this type.
// A propagation use chain is currently either a reg-to-reg
// chain for a shared mem tv, or a reg/smem-to-reg/smem chain
// for a global tv.
// This is complementary information to `redundant_types`.
// If a tensor view is redundantly written and not redundantly
// used by all consumers (see FusionRedundantPredSync3),
// a RAW sync will need to be inserted before reading
// this redundantly written tensor.
ParallelTypeBitmap redundant_use_types;
bool operator==(const PredicateInfo& other) const {
return limited_types == other.limited_types &&
redundant_types == other.redundant_types;
redundant_types == other.redundant_types &&
redundant_use_types == other.redundant_use_types;
}
};
@ -92,6 +107,9 @@ class TORCH_CUDA_CU_API ThreadPredicateMap {
static Bool* getPredicateFromPredicateInfo(
const ThreadPredicateMap::PredicateInfo& pred_info);
//! Get the redundant use types of the given expr, see [Redundant use chain]
ParallelTypeBitmap getRedundantConsumerType(Expr* expr) const;
private:
// Update the thread_predicates bitset based on provided Expr
void updateBitSet(const Expr*);
@ -111,6 +129,10 @@ class TORCH_CUDA_CU_API ThreadPredicateMap {
//! Update a mapping
bool update(const TensorView* tv, const PredicateInfo& pred_and_src);
//! Backward populate redundant use chain info once the redundant
//! parallel writes have been identified.
void populateRedundantUseMap(Fusion* fusion);
private:
MapType thread_predicates_;
//! Keep track of updated tensors that need predicates to be computed

View File

@ -5,6 +5,7 @@
#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/lower_trivial_reductions.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
#include <unordered_set>
@ -23,29 +24,39 @@ bool traverseToRFactorTensor(TensorView* tv, IterDomain* root_id) {
TORCH_INTERNAL_ASSERT(
root_id->definition() == nullptr, "Not root IterDomain: ", root_id);
if (tv->definition() == nullptr) {
auto def = tv->definition();
if (def == nullptr) {
// This is an input tensor, so no rfactor tensor to traverse.
return false;
}
const auto& inputs = tv->definition()->inputs();
// Check the reduction expression that produces tv
if (inputs.size() != 1 || !inputs[0]->isA<TensorView>() ||
(tv->definition()->getExprType() != ExprType::ReductionOp &&
tv->definition()->getExprType() != ExprType::WelfordOp)) {
// No rfactor producer found
if (!ir_utils::isReductionOp(def) || def->isA<MmaOp>()) {
return false;
}
auto producer = inputs[0]->as<TensorView>();
TORCH_INTERNAL_ASSERT(
def->inputs().size() == def->outputs().size(),
"This logic block assumes number of inputs is the same as number of outputs of reduction ops.");
if (!producer->hasRFactor()) {
// A reduction expr may have multiple inputs; just grab any TV
// input. Note that in theory it is possible that a
// GroupedReductionOp has rfactor inputs as well as non-rfactor
// inputs, so grabbing the one that actually corresponds to tv can
// be important. In reality, though, such a GroupedReductionOp
// should not happen as we do not group reductions of rfactor and
// non-rfactor tensors.
auto producer_tv = ir_utils::getTvInput(def);
TORCH_INTERNAL_ASSERT(producer_tv != nullptr);
if (!producer_tv->hasRFactor()) {
return false;
}
auto c2p = PairwiseRootDomainMap(producer, tv)
.mapConsumerToProducer(tv->domain(), producer->domain());
auto c2p = PairwiseRootDomainMap(producer_tv, tv)
.mapConsumerToProducer(tv->domain(), producer_tv->domain());
auto producer_id_it = c2p.find(root_id);
if (producer_id_it == c2p.end()) {
@ -55,7 +66,7 @@ bool traverseToRFactorTensor(TensorView* tv, IterDomain* root_id) {
auto producer_root_id = producer_id_it->second;
return analyzeIfDerivedFromTrivialReduction(producer, producer_root_id);
return analyzeIfDerivedFromTrivialReduction(producer_tv, producer_root_id);
}
bool analyzeIfDerivedFromTrivialReduction(TensorView* tv, IterDomain* id) {
@ -109,11 +120,6 @@ bool TrivialReductionInfo::isDerived(IterDomain* id) const {
return domains_.find(id) != domains_.end();
}
bool TrivialReductionInfo::isDerivedFromRoot(IterDomain* id) const {
return domains_derived_from_root_.find(id) !=
domains_derived_from_root_.end();
}
} // namespace cuda
} // namespace fuser
} // namespace jit

View File

@ -21,9 +21,6 @@ class TORCH_CUDA_CU_API TrivialReductionInfo {
bool isDerived(IterDomain* id) const;
// TODO: Not used, cleanup
bool isDerivedFromRoot(IterDomain* id) const;
private:
//! IterDomains that are derived only from trivial
//! reductons. Included domains are not limited to reduction axes as

View File

@ -183,14 +183,13 @@ TensorView* getTvOutput(const Expr* expr) {
return nullptr;
}
bool isReductionOp(const Expr* expr) {
// Note that GridReduction inherits ReductionOp
return expr->isA<ReductionOp>() || expr->isA<GroupedReductionOp>() ||
expr->isA<WelfordOp>() || expr->isA<kir::GridWelford>();
}
bool isReductionTvOp(const Expr* expr) {
return isTvOp(expr) && isReductionOp(expr);
TensorView* getTvInput(const Expr* expr) {
for (auto inp : expr->inputs()) {
if (auto tv = getTv(inp)) {
return tv;
}
}
return nullptr;
}
bool isScalarOp(const Expr* expr) {
@ -483,7 +482,7 @@ BasicAllocInfo getAllocInformation(
// Allocation of a double buffered tensor is placed outside its
// double buffer axis.
if (tv->isDoubleBuffered() &&
if ((tv->isDoubleBuffered() || tv->isCircularBuffered()) &&
tv->axis(info.alloc_pos) ==
gpu_lower->doubleBufferInfo().getDoubleBufferAxis(tv)) {
outer_alloc_found = true;

View File

@ -79,11 +79,8 @@ TORCH_CUDA_CU_API bool isTvOp(const Expr*);
// Returns the first output of Expr that is a TensorView
TORCH_CUDA_CU_API TensorView* getTvOutput(const Expr*);
// Returns if Expr is a reduction op
TORCH_CUDA_CU_API bool isReductionOp(const Expr*);
// Returns if Expr is a reduction op with TensorView or TensorIndex
TORCH_CUDA_CU_API bool isReductionTvOp(const Expr*);
// Returns the first input of Expr that is a TensorView
TORCH_CUDA_CU_API TensorView* getTvInput(const Expr*);
bool hasBlockSync(const Expr* expr, const ThreadPredicateMap& pred_map);

View File

@ -899,22 +899,42 @@ void validateMmaTensors(MmaOp* mma) {
}
// Note: this check will be relaxed in a follow up.
auto validate_operand_ids = [](const TensorView* tv) {
auto validate_operand = [](const TensorView* tv) {
TORCH_INTERNAL_ASSERT(
tv->getMemoryType() == MemoryType::Local,
"Only supporting register input for mma ops, up to sm80 all mma ops have to take register inputs.");
TORCH_INTERNAL_ASSERT(
std::all_of(
tv->domain()->domain().begin() + tv->getComputeAtPosition(),
tv->domain()->domain().end(),
[](IterDomain* id) {
return id->isMmaSwizzled() ||
(id->isBroadcast() &&
// MMA instructions can only take inputs from registers,
// so we always assume mma op inputs are located on
// registers.
// Currently requiring that serial ids on the right of the
// CA axis are constant sized to ensure early detection of
// invalid mma schedules.
((id->isBroadcast() || id->extent()->isConstInt()) &&
id->getParallelType() == ParallelType::Serial);
}),
"All id's on the right of CA pos needs to be mma-swizzled by WarpMmaSwizzler\n",
tv);
};
validate_operand_ids(mma->inA()->as<TensorView>());
validate_operand_ids(mma->inB()->as<TensorView>());
validate_operand(mma->inA()->as<TensorView>());
validate_operand(mma->inB()->as<TensorView>());
// Additionally validate that mma is not directly taking a double buffered
// register input as the double buffer indexing is currently not compatible
// with fragment iteration. Would need to require a cache stage in this case.
TORCH_INTERNAL_ASSERT(
!mma->inA()->as<TensorView>()->isDoubleBuffered(),
"MMA op cannot directly take double buffered register input, put a set stage before.");
TORCH_INTERNAL_ASSERT(
!mma->inB()->as<TensorView>()->isDoubleBuffered(),
"MMA op cannot directly take double buffered register input, put a set stage before.");
}
//! Note and TODO:
@ -1011,6 +1031,7 @@ void validateMma(Fusion* fusion) {
validateMinimumArch(7, 0);
break;
case MmaOptions::MacroType::Turing_16_8_16:
case MmaOptions::MacroType::Turing_16_16_16:
validateMinimumArch(7, 5);
// Check that operands come from ldmatrix, can be
@ -1019,6 +1040,7 @@ void validateMma(Fusion* fusion) {
validateTuringMmaInput(mma->inB()->as<TensorView>());
break;
case MmaOptions::MacroType::Ampere_16_8_16:
case MmaOptions::MacroType::Ampere_16_16_16:
validateMinimumArch(8, 0);
// Check that operands come from ldmatrix, can be
@ -1037,17 +1059,76 @@ void validateMma(Fusion* fusion) {
}
}
namespace {
// Utility function to validate a loop swizzle:
// 1. Throws an error if any output of the swizzle is not in the leaf_domain set.
// 2. Warns if any output of the swizzle is not the concrete id of the loop
// map.
// The second case would make the codegen ignore this swizzle, as if it were not
// there at all.
void validateLoopSwizzle(
Expr* swizzle_expr,
std::unordered_set<IterDomain*>& leaf_domains) {
for (auto out_id :
ir_utils::filterByType<IterDomain>(swizzle_expr->outputs())) {
TORCH_INTERNAL_ASSERT(
leaf_domains.count(out_id),
"Loop swizzle can only be direct producer of leaf domains.");
if (GpuLower::current()->caMap()->getConcreteMappedID(
out_id, IdMappingMode::LOOP) != out_id) {
TORCH_WARN_ONCE("Ignored loop swizzle :", swizzle_expr->toString());
}
}
}
} // namespace
void validateSwizzle(Fusion* fusion) {
auto used_vals = fusion->usedMathVals();
for (auto tv : ir_utils::filterByType<TensorView>(used_vals)) {
if (tv->hasSwizzleOp()) {
std::unordered_set<IterDomain*> tv_leaf_domain_set(
tv->domain()->domain().begin(), tv->domain()->domain().end());
// Make sure no swizzle op is inlined:
auto inlined_swizzles = ir_utils::getAllSwizzlesBetween(
tv->getMaybeRFactorDomain(),
{tv->domain()->domain().begin(),
tv->domain()->domain().begin() + tv->getComputeAtPosition()});
TORCH_INTERNAL_ASSERT(
inlined_swizzles.empty(), "No support for inlined swizzles");
auto not_inlined_swizzles = ir_utils::getAllSwizzlesBetween(
tv->getMaybeRFactorDomain(),
{tv->domain()->domain().begin() + tv->getComputeAtPosition(),
tv->domain()->domain().end()});
// Check inlined swizzles: only loop swizzles can currently be inlined,
// as inlining data swizzles would require additional support for an unswizzle
// operator, which currently doesn't have important use cases.
for (auto swizzle_expr : inlined_swizzles) {
TORCH_INTERNAL_ASSERT(
swizzle_expr->as<Swizzle2D>()->swizzleMode() == SwizzleMode::Loop,
"Only support inlining loop swizzles");
validateLoopSwizzle(swizzle_expr, tv_leaf_domain_set);
}
std::unordered_set<Expr*> inlined_swizzle_set(
inlined_swizzles.begin(), inlined_swizzles.end());
// Check not-inlined swizzles:
// Apply the loop swizzle check where it applies, and
// also make sure that no swizzle is also in the inlined_swizzle set.
// The latter would mean that one output of the swizzle is inlined while
// the other is not, which is not supported.
for (auto swizzle_expr : not_inlined_swizzles) {
TORCH_INTERNAL_ASSERT(
!inlined_swizzle_set.count(swizzle_expr),
"Cannot partially inline across swizzle domains.",
swizzle_expr->toString());
if (swizzle_expr->as<Swizzle2D>()->swizzleMode() == SwizzleMode::Loop) {
validateLoopSwizzle(swizzle_expr, tv_leaf_domain_set);
}
}
}
}
}

View File

@ -135,6 +135,7 @@ void MaxInfoSpanningTree::traverse(Propagator* propagator) {
if (path_.empty()) {
compute_spanning_tree();
}
propagator->setUp();
for (const auto& next_hop : path_) {
switch (next_hop.type) {
case NextHopType::SIBLING:
@ -148,6 +149,7 @@ void MaxInfoSpanningTree::traverse(Propagator* propagator) {
break;
}
}
propagator->tearDown();
}
MaxRootDomainInfoSpanningTree::RootDomainInfo::operator bool() const {
@ -422,6 +424,18 @@ void SpanningTreePrinter::propagateSibling(TensorView* from, TensorView* to) {
stream_ << " to: " << to->toString() << std::endl;
}
bool SetSelector::allowC2P(TensorView* from, TensorView* to) {
return selected_.count(to) > 0;
}
bool SetSelector::allowP2C(TensorView* from, TensorView* to) {
return selected_.count(to) > 0;
}
bool SetSelector::allowSibling(TensorView* from, TensorView* to) {
return true;
}
} // namespace cuda
} // namespace fuser
} // namespace jit

View File

@ -47,6 +47,8 @@ class TORCH_CUDA_CU_API MaxInfoSpanningTree {
// This is the interface to implement the actual propagation
struct Propagator {
virtual void setUp() {}
virtual void tearDown() {}
virtual void propagateC2P(TensorView* from, TensorView* to) = 0;
virtual void propagateP2C(TensorView* from, TensorView* to) = 0;
virtual void propagateSibling(TensorView* from, TensorView* to) = 0;
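A minimal propagator sketch (hypothetical class, not part of this PR) showing where the new hooks fire: `traverse()` calls `setUp()` once before walking the spanning tree and `tearDown()` once after the last hop.

```cpp
#include <vector>

// Hypothetical propagator that records every tensor it is propagated to.
class VisitedRecorder : public MaxInfoSpanningTree::Propagator {
 public:
  std::vector<TensorView*> visited;
  void setUp() override {
    visited.clear(); // runs once, before the first hop
  }
  void tearDown() override {
    // runs once, after the last hop; a real propagator could validate here
  }
  void propagateC2P(TensorView* from, TensorView* to) override {
    visited.push_back(to);
  }
  void propagateP2C(TensorView* from, TensorView* to) override {
    visited.push_back(to);
  }
  void propagateSibling(TensorView* from, TensorView* to) override {
    visited.push_back(to);
  }
};
```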
@ -254,6 +256,25 @@ class TORCH_CUDA_CU_API SpanningTreePrinter
SpanningTreePrinter(std::ostream& stream = std::cout) : stream_(stream) {}
};
// Simple selector for selecting subgraphs to build spanning trees. The selector
// allows propagation only to the given set of selected tensorviews, except for
// sibling propagation, which we should never block.
class TORCH_CUDA_CU_API SetSelector : public MaxInfoSpanningTree::Selector {
std::unordered_set<TensorView*> selected_;
public:
virtual bool allowC2P(TensorView* from, TensorView* to) override;
virtual bool allowP2C(TensorView* from, TensorView* to) override;
virtual bool allowSibling(TensorView* from, TensorView* to) override;
SetSelector(std::unordered_set<TensorView*> selected)
: selected_(std::move(selected)) {}
const std::unordered_set<TensorView*>& selected() const {
return selected_;
}
};
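A usage sketch for the selector (hedged: the spanning-tree constructor signature and the surrounding tensors are assumed, mirroring how the bounded propagation helpers use it):

```cpp
// Hypothetical usage: restrict transform propagation from a reference tensor
// to an explicit subset of tensors (siblings are always allowed through).
std::unordered_set<TensorView*> allowed{tv1, tv2, tv3};
SetSelector selector(allowed);
MaxRootDomainInfoSpanningTree tree(reference_tv, &selector);
TransformPropagator propagator(reference_tv);
tree.traverse(&propagator);
```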
} // namespace cuda
} // namespace fuser
} // namespace jit

View File

@ -7,6 +7,16 @@ namespace jit {
namespace fuser {
namespace cuda {
MmaOp* MmaOptions::mmaOp() const {
TORCH_INTERNAL_ASSERT(
accumulator_tv != nullptr && accumulator_tv->definition() != nullptr,
"Invalid accumulator_tv.");
auto mma_op = dynamic_cast<MmaOp*>(accumulator_tv->definition());
TORCH_INTERNAL_ASSERT(
mma_op != nullptr, "accumulator tv not an output of mma op");
return mma_op;
}
MmaBuilder::MmaBuilder(
MmaOptions::MacroType macro,
MatMulTileOptions gemm_tile) {
@ -22,6 +32,10 @@ MmaBuilder::MmaBuilder(
case MmaOptions::MacroType::Ampere_16_8_16:
option_.accumulator_stride = outer_stride * 2;
break;
case MmaOptions::MacroType::Ampere_16_16_16:
case MmaOptions::MacroType::Turing_16_16_16:
option_.accumulator_stride = outer_stride * 4;
break;
default:
TORCH_CHECK(false, "unsupported macro");
break;
@ -41,7 +55,7 @@ MmaBuilder& MmaBuilder::operand(MmaOptions::Operand a_or_b) {
// TODO: validate op config
MmaOptions MmaBuilder::build() const {
TORCH_CHECK(
option_.mma_op != nullptr,
option_.accumulator_tv != nullptr,
"Please configure accumulator tv before using swizzle options.")
return option_;
}
@ -60,9 +74,10 @@ void MmaBuilder::accumulatorTv(TensorView* tv) {
TORCH_CHECK(
tv->getMemoryType() == MemoryType::Local, "Mma only outputs to register");
TORCH_CHECK(tv->definition(), "Input cannot be accumulator tv");
auto mma = dynamic_cast<MmaOp*>(tv->definition());
TORCH_CHECK(mma, "Requires mma op output for reduction tv");
option_.mma_op = mma;
TORCH_CHECK(
tv->definition()->isA<MmaOp>(),
"Requires mma op output for reduction tv");
option_.accumulator_tv = tv;
}
namespace {
@ -73,6 +88,8 @@ LoadStoreOpType getLdMatrixType(MmaOptions options) {
switch (options.macro) {
case MmaOptions::MacroType::Turing_16_8_16:
case MmaOptions::MacroType::Ampere_16_8_16:
case MmaOptions::MacroType::Ampere_16_16_16:
case MmaOptions::MacroType::Turing_16_16_16:
// Turing mma assumes TN as default
transpose = (options.operand == MmaOptions::Operand::A &&
!isOperandTransposed(options)) ||
@ -98,16 +115,20 @@ bool isVolta(MmaOptions::MacroType macro) {
}
bool isTuring(MmaOptions::MacroType macro) {
return macro == MmaOptions::MacroType::Turing_16_8_16;
return macro == MmaOptions::MacroType::Turing_16_8_16 ||
macro == MmaOptions::MacroType::Turing_16_16_16;
}
bool isAmpere(MmaOptions::MacroType macro) {
return macro == MmaOptions::MacroType::Ampere_16_8_16;
return macro == MmaOptions::MacroType::Ampere_16_8_16 ||
macro == MmaOptions::MacroType::Ampere_16_16_16;
}
int getOutputRegisterSize(MmaOptions::MacroType macro) {
switch (macro) {
case MmaOptions::MacroType::Volta_16_16_4:
case MmaOptions::MacroType::Ampere_16_16_16:
case MmaOptions::MacroType::Turing_16_16_16:
return 8;
break;
case MmaOptions::MacroType::Turing_16_8_16:
@ -127,7 +148,9 @@ int getInputARegisterSize(MmaOptions::MacroType macro) {
return 4;
break;
case MmaOptions::MacroType::Turing_16_8_16:
case MmaOptions::MacroType::Turing_16_16_16:
case MmaOptions::MacroType::Ampere_16_8_16:
case MmaOptions::MacroType::Ampere_16_16_16:
return 8;
break;
default:
@ -145,6 +168,9 @@ int getInputBRegisterSize(MmaOptions::MacroType macro) {
case MmaOptions::MacroType::Turing_16_8_16:
case MmaOptions::MacroType::Ampere_16_8_16:
return 4;
case MmaOptions::MacroType::Turing_16_16_16:
case MmaOptions::MacroType::Ampere_16_16_16:
return 8;
default:
TORCH_INTERNAL_ASSERT(false, "unknown macro");
break;
@ -197,6 +223,10 @@ std::string toString(MmaOptions::MacroType mt) {
case MmaOptions::MacroType::Ampere_16_8_16:
ss << "M16N8K16";
break;
case MmaOptions::MacroType::Turing_16_16_16:
case MmaOptions::MacroType::Ampere_16_16_16:
ss << "M16N16K16";
break;
default:
TORCH_INTERNAL_ASSERT(false, "undefined mma type");
break;

View File

@ -19,6 +19,10 @@ struct GemmTile {
GemmTile operator/(const GemmTile& other) {
return GemmTile(m / other.m, n / other.n, k / other.k);
}
std::vector<int> toVector() {
return {m, n, k};
}
};
//! Utility data structure for recording gemm tiles
@ -58,7 +62,9 @@ struct MmaOptions {
NoMMA = 0,
Volta_16_16_4,
Ampere_16_8_16,
Ampere_16_16_16,
Turing_16_8_16,
Turing_16_16_16,
Ampere_16_8_8 // place holder for tf32
};
@ -95,8 +101,18 @@ struct MmaOptions {
accumulator_stride == other.accumulator_stride;
}
// To be inferred by mma builder interface.
MmaOp* mma_op = nullptr;
// The accumulator tensorview register supplied by the
// scheduler interface. Each mma builder is responsible
// for the parameters of one mma op, so the options struct
// would need a pointer to keep track of which mma op it
// is describing.
// Tracking the mma expression directly would not be stable, as the
// expression can get deleted by mutation passes.
TensorView* accumulator_tv = nullptr;
//! Returns the mma op that this options parameter list
//! is describing. See comment on accumulator_tv.
MmaOp* mmaOp() const;
};
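A short sketch of the intended builder flow (mirroring the matmul scheduler changes later in this PR; `c` and `gemm_tile` are assumed to exist):

```cpp
// c is the MmaOp output in global memory, gemm_tile the chosen tile sizes.
MmaBuilder mma_builder(MmaOptions::MacroType::Ampere_16_16_16, gemm_tile);
mma_builder.configureMma(c);
auto cc = c->cacheBefore();      // accumulator register tensor
mma_builder.accumulatorTv(cc);   // stable handle; the MmaOp expression itself
                                 // may be replaced by mutation passes
auto options = mma_builder.build();
auto mma = options.mmaOp();      // recovered on demand as cc->definition()
```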
//! User interface for configuring the mma and mma related

View File

@ -477,6 +477,9 @@ void OptOutMutator::mutate(kir::GridSync*) {
void OptOutMutator::mutate(kir::CpAsyncWait*) {
TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::CpAsyncCommit*) {
TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}
void OptOutMutator::mutate(kir::InitMagicZero*) {
TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
}

View File

@ -184,6 +184,13 @@ TensorView* tanh_backward(TensorView* dy, TensorView* tanh_x) {
return dx;
}
TensorView* leaky_relu(TensorView* x, Val* negative_slope) {
TORCH_INTERNAL_ASSERT(x != nullptr, "input is invalid.");
TORCH_INTERNAL_ASSERT(negative_slope != nullptr, "negative_slope is invalid");
auto zero = IrBuilder::create<Double>(x->container(), 0.);
return where(ge(x, zero), x, mul(negative_slope, x));
}
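A test-style usage sketch of the new composite op (hypothetical fusion-definition code; `makeSymbolicTensor` follows the existing nvfuser test conventions):

```cpp
// Build a small fusion that applies leaky_relu with slope 0.01 to a 2D input.
Fusion fusion;
FusionGuard fg(&fusion);
auto tv0 = makeSymbolicTensor(2); // [i0, i1]
fusion.addInput(tv0);
auto slope = IrBuilder::create<Double>(0.01);
auto tv1 = leaky_relu(tv0, slope); // where(x >= 0, x, slope * x)
fusion.addOutput(tv1);
```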
TensorView* view_as_real(TensorView* x) {
auto input_type = x->getDataType().value();
TORCH_CHECK(

View File

@ -52,6 +52,7 @@ TORCH_CUDA_CU_API TensorView* gelu_backward(TensorView* dy, TensorView* x);
TORCH_CUDA_CU_API TensorView* tanh_gelu(TensorView* x);
TORCH_CUDA_CU_API TensorView* tanh_gelu_backward(TensorView* dy, TensorView* x);
TORCH_CUDA_CU_API TensorView* tanh_backward(TensorView* dy, TensorView* tanh_x);
TORCH_CUDA_CU_API TensorView* leaky_relu(TensorView* x, Val* negative_slope);
TORCH_CUDA_CU_API TensorView* view_as_real(TensorView* x);

View File

@ -529,8 +529,8 @@ ForwardNormResult batch_norm(
auto invstd_bcast = broadcast(unbiased_invstd, broadcast_mask);
// During inference, mean/invstd output are empty tensors
mean = TensorViewBuilder().shape({0}).build();
invstd = TensorViewBuilder().shape({0}).build();
mean = TensorViewBuilder().shape(std::vector<int64_t>{0}).build();
invstd = TensorViewBuilder().shape(std::vector<int64_t>{0}).build();
y = mul(x_sub_mean, invstd_bcast);
}
@ -782,8 +782,8 @@ ForwardNormResult instance_norm(
broadcast(unbiased_invstd, channels_only_broadcast_mask);
// During inference, mean/invstd output are empty tensors
mean = TensorViewBuilder().shape({0}).build();
invstd = TensorViewBuilder().shape({0}).build();
mean = TensorViewBuilder().shape(std::vector<int64_t>{0}).build();
invstd = TensorViewBuilder().shape(std::vector<int64_t>{0}).build();
y = mul(x_sub_mean, invstd_bcast);
}

View File

@ -2909,6 +2909,34 @@ class IrParser {
nullptr);
}
{
auto ptr_op = getOperatorForLiteral(
"aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor");
REGISTER_PARSE_RULE(
ptr_op,
{
MemoryFormat format;
std::list<Val*> list_val;
std::tie(format, list_val) = getConsistentValues(
c10::nullopt, value_map[node->inputs()[0]->unique()]);
auto self = list_val.front()->as<TensorView>();
list_val.pop_front();
Val* negative_slope = value_map[node->inputs()[1]->unique()];
auto out = leaky_relu(self, negative_slope);
value_map.emplace(
node->output()->unique(), ValueHolder(out, format));
},
[](const Node* node) -> bool {
if (!isInputNonSizeZeroTensor(node)) {
return false;
}
return true;
},
nullptr);
}
{
auto ptr_op = getOperatorForLiteral(
"aten::gelu(Tensor self, *, str approximate='none') -> Tensor");

View File

@ -528,3 +528,105 @@ __device__ inline int64_t readCycleCounter() {
__threadfence();
return clock64();
}
__device__ float print_impl(const char* name, float value) {
printf(
"%s = %f @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n",
name,
value,
(int)threadIdx.x,
(int)threadIdx.y,
(int)threadIdx.z,
(int)blockIdx.x,
(int)blockIdx.y,
(int)blockIdx.z);
return value;
}
__device__ double print_impl(const char* name, double value) {
printf(
"%s = %lf @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n",
name,
value,
(int)threadIdx.x,
(int)threadIdx.y,
(int)threadIdx.z,
(int)blockIdx.x,
(int)blockIdx.y,
(int)blockIdx.z);
return value;
}
__device__ int print_impl(const char* name, int value) {
printf(
"%s = %d @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n",
name,
value,
(int)threadIdx.x,
(int)threadIdx.y,
(int)threadIdx.z,
(int)blockIdx.x,
(int)blockIdx.y,
(int)blockIdx.z);
return value;
}
__device__ int64_t print_impl(const char* name, int64_t value) {
printf(
"%s = %ld @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n",
name,
value,
(int)threadIdx.x,
(int)threadIdx.y,
(int)threadIdx.z,
(int)blockIdx.x,
(int)blockIdx.y,
(int)blockIdx.z);
return value;
}
__device__ bool print_impl(const char* name, bool value) {
printf(
"%s = %s @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n",
name,
value ? "true" : "false",
(int)threadIdx.x,
(int)threadIdx.y,
(int)threadIdx.z,
(int)blockIdx.x,
(int)blockIdx.y,
(int)blockIdx.z);
return value;
}
__device__ __half print_impl(const char* name, __half value) {
printf(
"%s = %f @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n",
name,
__half2float(value),
(int)threadIdx.x,
(int)threadIdx.y,
(int)threadIdx.z,
(int)blockIdx.x,
(int)blockIdx.y,
(int)blockIdx.z);
return value;
}
#if __CUDACC_VER_MAJOR__ >= 11
__device__ __bfloat print_impl(const char* name, __bfloat value) {
printf(
"%s = %f @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n",
name,
__bfloat2float(value),
(int)threadIdx.x,
(int)threadIdx.y,
(int)threadIdx.z,
(int)blockIdx.x,
(int)blockIdx.y,
(int)blockIdx.z);
return value;
}
#endif
#define print(...) print_impl(#__VA_ARGS__, (__VA_ARGS__))
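A device-side usage sketch: wrapping a value in `print(...)` logs it together with the thread and block coordinates and returns it unchanged, so it can be dropped into an existing expression.

```cpp
// Hypothetical device function using the print(...) debug macro above.
__device__ float scaleSquared(float x) {
  // Prints e.g. "x * x = 4.000000 @ threadIdx=(0,0,0), blockIdx=(0,0,0)"
  return 2.0f * print(x * x);
}
```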

View File

@ -157,6 +157,15 @@ DEVICE_INLINE void cpAsyncBarrier() {
asm volatile("cp.async.wait_all;");
}
DEVICE_INLINE void cpAsyncCommit() {
asm volatile("cp.async.commit_group;");
}
template <int keep_stages>
DEVICE_INLINE void cpAsyncPartialBarrier() {
asm volatile("cp.async.wait_group %0;\n" ::"n"(keep_stages));
}
} // namespace Ampere
#endif // Arch 80

View File

@ -276,6 +276,30 @@ DEVICE_INLINE void M16N8K16TN(
_C[acc_stride + 1] = C_data[3];
}
template <int acc_stride>
DEVICE_INLINE void initM16N16K16TN(Array<float, 8, 8>* accumulator) {
float* _C = reinterpret_cast<float*>(accumulator);
initM16N8K16TN<acc_stride>(reinterpret_cast<Array<float, 4, 4>*>(&_C[0]));
initM16N8K16TN<acc_stride>(reinterpret_cast<Array<float, 4, 4>*>(&_C[2]));
}
template <int acc_stride = 2>
DEVICE_INLINE void M16N16K16TN(
Array<float, 8, 8>* C,
Array<__half, 8, 8>* A,
Array<__half, 8, 8>* B) {
float* _C = reinterpret_cast<float*>(C);
__half* _B = reinterpret_cast<__half*>(B);
M16N8K16TN<acc_stride>(
reinterpret_cast<Array<float, 4, 4>*>(&_C[0]),
A,
reinterpret_cast<Array<__half, 4, 4>*>(&_B[0]));
M16N8K16TN<acc_stride>(
reinterpret_cast<Array<float, 4, 4>*>(&_C[2]),
A,
reinterpret_cast<Array<__half, 4, 4>*>(&_B[4]));
}
} // namespace Turing
#endif // Arch 75
@ -338,6 +362,30 @@ DEVICE_INLINE void M16N8K16TN(
_C[acc_stride + 1] = C_data[3];
}
template <int acc_stride>
DEVICE_INLINE void initM16N16K16TN(Array<float, 8, 8>* accumulator) {
float* _C = reinterpret_cast<float*>(accumulator);
initM16N8K16TN<acc_stride>(reinterpret_cast<Array<float, 4, 4>*>(&_C[0]));
initM16N8K16TN<acc_stride>(reinterpret_cast<Array<float, 4, 4>*>(&_C[2]));
}
template <int acc_stride = 2>
DEVICE_INLINE void M16N16K16TN(
Array<float, 8, 8>* C,
Array<__half, 8, 8>* A,
Array<__half, 8, 8>* B) {
float* _C = reinterpret_cast<float*>(C);
__half* _B = reinterpret_cast<__half*>(B);
M16N8K16TN<acc_stride>(
reinterpret_cast<Array<float, 4, 4>*>(&_C[0]),
A,
reinterpret_cast<Array<__half, 4, 4>*>(&_B[0]));
M16N8K16TN<acc_stride>(
reinterpret_cast<Array<float, 4, 4>*>(&_C[2]),
A,
reinterpret_cast<Array<__half, 4, 4>*>(&_B[4]));
}
} // namespace Ampere
#endif // Arch 80

View File

@ -2,6 +2,7 @@
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
namespace torch {
@ -24,6 +25,8 @@ namespace HeuristicCompileTime {
//! Enum for all possible types of cached entries of compile-time info.
enum class CompileTimeEntryType {
DOMAIN_MAP,
REFERENCE_TENSORS,
VECTORIZABLE_INPUTS_AND_OUTPUTS,
UNROLLABLE_INPUTS_AND_OUTPUTS,
REDUCTION_TVS,
@ -32,6 +35,24 @@ enum class CompileTimeEntryType {
BROADCAST_BYTE_MULTIPLES
};
//! Entry type definition class for `DOMAIN_MAP`,
//! stores the domain map of a fusion.
class DomainMap {
public:
using DataType = pointwise_utils::DomainMap;
static const CompileTimeEntryType EntryType =
CompileTimeEntryType::DOMAIN_MAP;
};
//! Entry type definition class for `REFERENCE_TENSORS`,
//! stores the reference TensorViews used to schedule a fusion.
class ReferenceTensors {
public:
using DataType = std::vector<TensorView*>;
static const CompileTimeEntryType EntryType =
CompileTimeEntryType::REFERENCE_TENSORS;
};
//! Entry type definition class for `VECTORIZABLE_INPUTS_AND_OUTPUTS`,
//! stores the vectorizable TensorViews on a fusion's inputs and outputs.
class VectorizableInputsAndOutputs {

View File

@ -0,0 +1,37 @@
#pragma once
#include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>
#include <string>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
class HeuristicParams {
public:
std::string tag = "";
LaunchParams lparams;
virtual std::string toString() const {
return "Undefined Heuristic Params";
}
virtual size_t hash() const = 0;
virtual ~HeuristicParams() = default;
virtual bool sameAs(const std::shared_ptr<HeuristicParams>& other) const = 0;
virtual std::shared_ptr<HeuristicParams> clone() const = 0;
HeuristicParams() = default;
HeuristicParams(const std::string& tag) : tag(tag) {}
};
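A minimal sketch (hypothetical class, not part of this PR) of a concrete parameter type implementing this interface:

```cpp
#include <functional>
#include <memory>
#include <string>

// Hypothetical scheduler parameter type built on the HeuristicParams
// interface above.
class ExampleParams : public HeuristicParams {
 public:
  int unroll_factor = 1;

  using HeuristicParams::HeuristicParams;

  std::string toString() const override {
    return tag + ": unroll_factor=" + std::to_string(unroll_factor);
  }
  size_t hash() const override {
    return std::hash<int>()(unroll_factor);
  }
  bool sameAs(const std::shared_ptr<HeuristicParams>& other) const override {
    auto o = std::dynamic_pointer_cast<ExampleParams>(other);
    return o != nullptr && o->unroll_factor == unroll_factor;
  }
  std::shared_ptr<HeuristicParams> clone() const override {
    return std::make_shared<ExampleParams>(*this);
  }
};
```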
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch

View File

@ -0,0 +1,329 @@
#include <torch/csrc/jit/codegen/cuda/scheduler/matmul.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/mma_utils.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
namespace {
// Move the broadcast axes to the left on the specified number of inner
// dimensions, e.g. (when number_of_inner_pos == 3):
//   [... I0, B, I1] -> [... B, I0, I1]
// This should probably only be used to order the innermost mnk axes.
void moveInnerBroadcastLeft(TensorView* tv, int number_of_inner_pos = 3) {
TORCH_INTERNAL_ASSERT(tv->nDims() >= number_of_inner_pos);
std::vector<int> broadcast_pos;
std::vector<int> nonbroadcast_pos;
for (auto i : c10::irange(number_of_inner_pos)) {
auto axis_idx = i - number_of_inner_pos;
auto id = tv->axis(axis_idx);
if (id->isBroadcast()) {
broadcast_pos.push_back(axis_idx);
} else {
nonbroadcast_pos.push_back(axis_idx);
}
}
auto combined_pos_vec = broadcast_pos;
combined_pos_vec.insert(
combined_pos_vec.end(), nonbroadcast_pos.begin(), nonbroadcast_pos.end());
std::unordered_map<int, int> order_map;
for (auto i : c10::irange(number_of_inner_pos)) {
order_map[combined_pos_vec.at(i)] = i - number_of_inner_pos;
}
// Apply ordering.
tv->reorder(order_map);
}
} // namespace
void scheduleMatmul(
TensorView* c,
TensorView* a,
TensorView* b,
MatmulParam& params) {
// Unpack from params.
auto& mma_builder = params.mma_builder;
auto& gemm_tile = params.tile_sizes;
// Including the current tensor naming convention for reference.
// This is very temporary and will change over time; in fact, the whole body
// of this function will eventually become a set of utility functions for
// different sections of matmul (fusion) kernels, each with its own
// build-out to do.
//
// Current naming convention:
//
// operands assumed in global memory : a, b
//
// registers staging global load : ar, br (short for a/b read)
//
// shared mem cache of operands : acw_smem, bcw_smem (short for a/b
// cache_write smem)
//
// registers at shared memory load output : acr, bcr (short for a/b cache
// read)
//
// register tensor input to the actual mma op: ab, bb (short for a/b
// broadcasted)
//
// accumulator register: cc (short for c cache)
//
// result in global memory: c
// Currently we only support a, b, c as fusion inputs/outputs,
// i.e. no prolog and epilog fusion yet.
TORCH_CHECK(
c->isFusionOutput() && a->isFusionInput() && b->isFusionInput(),
"not supporting matmul fusion yet");
TORCH_CHECK(c->definition() && c->definition()->isA<MmaOp>());
mma_builder.configureMma(c);
// TODO:
// Beyond this point, mma_builder really just becomes a populated
// list of parameters that describes the mma swizzles that should
// be annotated on the tensor domain. Conceptually the mma builder
// object should be separated into two parts, one as a scheduler utility
// and the other as matmul heuristic parameters, which we are
// starting to build out.
// Setup register and shared memory stages:
// TODO: this section goes to a separate matmul util,
// and needs more configurability.
// Setup accumulator register.
auto cc = c->cacheBefore();
// Get the input to the mma op.
auto mma = dynamic_cast<MmaOp*>(cc->definition());
TORCH_INTERNAL_ASSERT(mma != nullptr);
auto ab = mma->inA()->as<TensorView>();
auto bb = mma->inB()->as<TensorView>();
// Get exact configurations from mma builder.
mma_builder.accumulatorTv(cc);
auto mma_options = mma_builder.build();
// Staging register for global memory load
TensorView *ar = a, *br = b;
if (!params.async_gmem_load_operands) {
ar = a->cacheAfter();
br = b->cacheAfter();
}
// TODO:
// Significant build out needed here
// for more flexibility and data type support.
// Shared memory
TensorView* acw_smem = nullptr;
TensorView* bcw_smem = nullptr;
// Shared memory read
TensorView* acr = nullptr;
TensorView* bcr = nullptr;
// Different paths because Volta swizzle needs to
// involve the broadcast dimensions that are concretized
// at mma, while Ampere ones should be done before
// the broadcast op to be able to use cp.async.
// TODO:
// Also a few additional parameters should be introduced
// to control this stage of scheduling.
if (isVolta(mma_options.macro)) {
acw_smem = ab->cacheAfter();
bcw_smem = bb->cacheAfter();
// Cache again to be able to vectorize.
acw_smem = acw_smem->cacheAfter();
bcw_smem = bcw_smem->cacheAfter();
acr = acw_smem->cacheAfter();
bcr = bcw_smem->cacheAfter();
if (params.double_buffer_options.double_buffer_smem_read) {
// Provide another copy op between the double buffered
// smem load register and the actual mma ops to avoid
// complication in double buffered fragment iteration.
ab = acr->cacheAfter();
bb = bcr->cacheAfter();
} else {
ab = acr;
bb = bcr;
}
} else {
// Use cp.async as requested in scheduler params.
c10::optional<LoadStoreOpType> load_op = c10::nullopt;
if (params.async_gmem_load_operands) {
load_op = LoadStoreOpType::CpAsync;
}
acw_smem = ar->cacheAfter(load_op);
bcw_smem = br->cacheAfter(load_op);
acr = acw_smem->cacheAfter(
mma_builder.operand(MmaOptions::Operand::A).ldMatrix());
bcr = bcw_smem->cacheAfter(
mma_builder.operand(MmaOptions::Operand::B).ldMatrix());
}
// Make a CTA tile
// ------------------------------------------------------------------
scheduler_utils::matmul_utils::canonicalizeMmaTvOrdering(cc);
// [... M,N,K]
scheduler_utils::matmul_utils::makeTile(cc, gemm_tile.cta_tile.toVector());
// [Mo, No, Ko, Mi, Ni, Ki]
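// For illustration, if cta_tile were 128x128x32 (a hypothetical choice),
// [M, N, K] would become [M/128, N/128, K/32, 128, 128, 32].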
// Propagate tiling globally
scheduler_utils::transformPropagateToAllFrom(cc, -1);
// Schedule warp tile
scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(cc, gemm_tile);
// Propagate warp tile to main loop and epilog/output tvs
scheduler_utils::BoundedDirectionalTransformPropagator::bothWays(
cc, -1, {acw_smem, bcw_smem}, {c});
// Schedule prolog:
// TODO: this section goes to a separate matmul util,
// and needs more configurability.
// ------------------------------------------------------------------
scheduler_utils::matmul_utils::orderTiledConcreteIdAsRoot(acw_smem);
// [... M, K]
acw_smem->merge(-2);
scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
acw_smem, gemm_tile, 8, false);
// [... N, K]
scheduler_utils::matmul_utils::orderTiledConcreteIdAsRoot(bcw_smem);
bcw_smem->merge(-2);
scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
bcw_smem, gemm_tile, 8, false);
// Propagate prolog tensors
// propagate up the DAG, and propagate parallel type.
scheduler_utils::BoundedDirectionalTransformPropagator::backward(
acw_smem,
-1,
{a},
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType());
scheduler_utils::BoundedDirectionalTransformPropagator::backward(
bcw_smem,
-1,
{b},
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType());
// Set computeAt, setup the loop nesting structure on the kernel.
// TODO: this section goes to a separate matmul util,
// and needs more configurability.
// ------------------------------------------------------------------
// CTA tile:
// Swizzle block tiles:
c->swizzle(Swizzle2DType::ZShape, 0, 1, SwizzleMode::Loop);
a->computeAt(c, 2);
b->computeAt(c, 2);
// Prolog:
a->computeAt(cc, 3);
b->computeAt(cc, 3);
// Main Loop:
acr->computeAt(cc, -6);
bcr->computeAt(cc, -6);
// Add mma swizzle:
// TODO: this section goes to a separate matmul util,
// and needs more configurability.
// ------------------------------------------------------------------
if (isTuring(mma_options.macro) || isAmpere(mma_options.macro)) {
moveInnerBroadcastLeft(ab);
moveInnerBroadcastLeft(bb);
}
ab->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());
bb->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());
// Propagate mma input swizzle up the DAG
// to all the tensors before mma op and after shared mem read.
scheduler_utils::BoundedDirectionalTransformPropagator::backward(
ab,
-1,
{acw_smem},
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType());
scheduler_utils::BoundedDirectionalTransformPropagator::backward(
bb,
-1,
{bcw_smem},
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType());
cc->applyMmaSwizzle(
mma_builder.operand(MmaOptions::Operand::Accumulator).build());
// Set memory type:
acw_smem->setMemoryType(MemoryType::Shared);
bcw_smem->setMemoryType(MemoryType::Shared);
// Set parallelization:
// TODO: this section goes to a separate matmul util,
// and needs more configurability.
// ------------------------------------------------------------------
// Vectorize smem stores/loads:
acw_smem->axis(-1)->parallelize(ParallelType::Vectorize);
bcw_smem->axis(-1)->parallelize(ParallelType::Vectorize);
acr->axis(-1)->parallelize(ParallelType::Vectorize);
bcr->axis(-1)->parallelize(ParallelType::Vectorize);
// 0 1 2 3 4 5 6 7 8 9 10
// [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)]
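// Block tiles Mo/No bind to the grid (BIDx/BIDy), warp tiles Mwo/Nwo bind to
// TIDz/TIDy, Ko/Kw/Mw/Nw stay as serial loops here, and the innermost
// (Mi Ni Ki) group is handled by the mma swizzle applied above.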
cc->axis(0)->parallelize(ParallelType::BIDx);
cc->axis(1)->parallelize(ParallelType::BIDy);
cc->axis(3)->parallelize(ParallelType::TIDz);
cc->axis(4)->parallelize(ParallelType::TIDy);
// Propagate mma output swizzle and parallelization down the DAG
if (params.double_buffer_options.double_buffer_smem_write) {
TORCH_CHECK(
params.double_buffer_options.smem_double_buffer_stage > 1,
"Invalid buffer stage config")
if (params.double_buffer_options.smem_double_buffer_stage > 2) {
TORCH_CHECK(
params.async_gmem_load_operands,
"Circular buffer only supports async load");
}
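// Note: stage == 2 corresponds to classic double buffering of the smem
// writes; stage > 2 keeps more tiles in flight (circular buffering), which
// per the check above is only supported together with cp.async loads.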
acw_smem->circularBuffer(
params.double_buffer_options.smem_double_buffer_stage);
bcw_smem->circularBuffer(
params.double_buffer_options.smem_double_buffer_stage);
}
if (params.double_buffer_options.double_buffer_smem_read) {
acr->doubleBuffer();
bcr->doubleBuffer();
}
scheduler_utils::BoundedDirectionalTransformPropagator::forward(
cc,
-1,
{c},
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType()
.propagateToBoundary());
}
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch


@ -0,0 +1,55 @@
#pragma once
#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/mma_type.h>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
//! Starting point for matmul scheduler parameters:
class MatmulParam {
public:
MatmulParam(MmaBuilder builder) : mma_builder(builder) {}
struct DoubleBufferOptions {
bool double_buffer_smem_write = false;
bool double_buffer_smem_read = false;
int smem_double_buffer_stage = 2;
};
//! (Ampere+) Use cp.async to load operands.
bool async_gmem_load_operands = false;
//! Specifies the tiling hierarchy on block,
//! warp, and instruction levels.
MatMulTileOptions tile_sizes;
//! Parameters for configuring mma ops.
MmaBuilder mma_builder;
//! Specify which tensor we double buffer.
DoubleBufferOptions double_buffer_options;
};
//! Prototype auto scheduling function.
//! Currently only supports a pure matmul with no
//! fused prolog or epilog.
//!
//! TODO:
//! - will support a range of fusions in a follow up
//! - will formalize scheduling decisions into
//! matmul params data structure.
TORCH_CUDA_CU_API void scheduleMatmul(
TensorView* c_tv,
TensorView* a_tv,
TensorView* b_tv,
MatmulParam& params);
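// A minimal usage sketch (the macro, layout, and tile sizes below are
// illustrative assumptions only, not a recommended configuration):
//
//   MatMulTileOptions gemm_tile;
//   gemm_tile.cta_tile = GemmTile(128, 128, 32);
//   gemm_tile.warp_tile = GemmTile(64, 64, 32);
//   gemm_tile.instruction_tile = GemmTile(16, 8, 16);
//   MmaBuilder mma_builder(
//       MmaOptions::MacroType::Ampere_16_8_16, gemm_tile);
//   mma_builder.layout(MmaOptions::MmaInputLayout::TT);
//   MatmulParam params(mma_builder);
//   params.tile_sizes = gemm_tile;
//   params.async_gmem_load_operands = true;
//   scheduleMatmul(c_tv, a_tv, b_tv, params);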
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch


@ -4,6 +4,7 @@
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/mma_utils.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
namespace torch {
namespace jit {
@ -133,6 +134,13 @@ void WarpMmaSwizzler::scheduleMmaWarpOutput(
setWarpMapped(tv, 4);
}
break;
case MmaOptions::MacroType::Turing_16_16_16:
case MmaOptions::MacroType::Ampere_16_16_16:
scheduleTuringM16N16K16MmaWarpOutput(tv, options);
if (tv->definition()->isA<MmaOp>()) {
setWarpMapped(tv, 4);
}
break;
default:
TORCH_CHECK(
false, "scheduleMmaWarp: unsupported mma option ", toString(macro));
@ -150,6 +158,8 @@ void WarpMmaSwizzler::scheduleOperandRead(TensorView* tv, MmaOptions options) {
break;
case MmaOptions::MacroType::Turing_16_8_16:
case MmaOptions::MacroType::Ampere_16_8_16:
case MmaOptions::MacroType::Turing_16_16_16:
case MmaOptions::MacroType::Ampere_16_16_16:
scheduleTuringOperandRead(tv, options);
break;
default:
@ -245,6 +255,14 @@ std::vector<IterDomain*> getMmaDomains(MmaOp* mma, MmaDimension dimension) {
return result;
}
//! Variant of getMmaDomains that returns a set
std::unordered_set<IterDomain*> getMmaDomainSet(
MmaOp* mma,
MmaDimension dimension) {
auto mma_domains = getMmaDomains(mma, dimension);
return {mma_domains.begin(), mma_domains.end()};
}
// [MMA dimension matching]
// Returns all the axes that correspond to the given mma dimension. This is the
// first relaxation step on the mma check.
@ -325,9 +343,10 @@ void validateMmaRootInnerMNK(
int m,
int n,
int k) {
auto m_dims = getMmaRootDimensions(tv, options.mma_op, MmaDimension::M);
auto n_dims = getMmaRootDimensions(tv, options.mma_op, MmaDimension::N);
auto k_dims = getMmaRootDimensions(tv, options.mma_op, MmaDimension::K);
auto mma = options.mmaOp();
auto m_dims = getMmaRootDimensions(tv, mma, MmaDimension::M);
auto n_dims = getMmaRootDimensions(tv, mma, MmaDimension::N);
auto k_dims = getMmaRootDimensions(tv, mma, MmaDimension::K);
TORCH_CHECK(
!m_dims.empty() && !n_dims.empty() && !k_dims.empty(),
@ -354,8 +373,9 @@ void validateMmaRootInnerMNK(
//! swizzles to the right axes.
//! This check will be relaxed as we build out the mma usage patterns.
void validateMmaRootInnerMN(TensorView* tv, MmaOptions options, int m, int n) {
auto m_dims = getMmaRootDimensions(tv, options.mma_op, MmaDimension::M);
auto n_dims = getMmaRootDimensions(tv, options.mma_op, MmaDimension::N);
auto mma = options.mmaOp();
auto m_dims = getMmaRootDimensions(tv, mma, MmaDimension::M);
auto n_dims = getMmaRootDimensions(tv, mma, MmaDimension::N);
TORCH_CHECK(
!m_dims.empty() && !n_dims.empty(),
@ -494,14 +514,17 @@ void scheduleLdMatrix(TensorView* tv, MmaOptions options) {
// Check mma option is supported
TORCH_CHECK(
options.macro == MmaOptions::MacroType::Ampere_16_8_16 ||
options.macro == MmaOptions::MacroType::Turing_16_8_16,
options.macro == MmaOptions::MacroType::Ampere_16_16_16 ||
options.macro == MmaOptions::MacroType::Turing_16_8_16 ||
options.macro == MmaOptions::MacroType::Turing_16_16_16,
"scheduleLdMatrix: unknown macro for ldmatrix");
if (options.operand == MmaOptions::Operand::A) {
TORCH_INTERNAL_ASSERT(tv->nDims() >= 2);
// validation:
auto m_dims = getMmaRootDimensions(tv, options.mma_op, MmaDimension::M);
auto k_dims = getMmaRootDimensions(tv, options.mma_op, MmaDimension::K);
auto mma = options.mmaOp();
auto m_dims = getMmaRootDimensions(tv, mma, MmaDimension::M);
auto k_dims = getMmaRootDimensions(tv, mma, MmaDimension::K);
TORCH_INTERNAL_ASSERT(
canValidateIsInnerDim(m_dims.back(), tv->axis(-2), 16),
@ -532,46 +555,84 @@ void scheduleLdMatrix(TensorView* tv, MmaOptions options) {
tv->axis(-2)->parallelize(ParallelType::TIDx);
} else if (options.operand == MmaOptions::Operand::B) {
auto n_dims = getMmaRootDimensions(tv, options.mma_op, MmaDimension::N);
auto k_dims = getMmaRootDimensions(tv, options.mma_op, MmaDimension::K);
auto mma = options.mmaOp();
auto n_dims = getMmaRootDimensions(tv, mma, MmaDimension::N);
auto k_dims = getMmaRootDimensions(tv, mma, MmaDimension::K);
// validation:
TORCH_INTERNAL_ASSERT(
canValidateIsInnerDim(n_dims.back(), tv->axis(-2), 8),
"MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain");
TORCH_INTERNAL_ASSERT(
canValidateIsInnerDim(k_dims.back(), tv->axis(-1), 16),
"MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain");
if (transposed) {
// [8, 16]
tv->split(-2, 4);
// Each ldmatrix.x4 loads an effective 16x16x16 tile, which is 2x the size
// of the regular 16x8x16 tile supported by the largest mma operation. The
// swizzle also needs to be different to take this into account.
// TODO:
// Using an emulated 16x16x16 mma tile is a temporary step to enable the
// widest load possible for the scheduler bring-up phase. A unifying step
// would be needed in a follow-up to support all these swizzles with a
// single affine utility.
bool use_ldmatrix4 = canValidateIsInnerDim(n_dims.back(), tv->axis(-2), 16);
// [2i, 4i, 16]
tv->reorder({{-1, -2}, {-2, -1}});
// [2i, 16, 4i]
if (use_ldmatrix4) {
// [... N16, K16]
tv->split(-2, 8);
tv->split(-1, 8);
tv->merge(-3);
// [warp, 4i]
} else {
//[8, 16]
tv->split(-1, 4);
tv->split(-2, 2);
// -4 -3 -2 -1
// [... N2o, N8, K2o, K8]
tv->reorder({{-3, -2}, {-2, -3}});
// [... N2o, K2o, N8, K8]
// 0 1 2 3
//[8, oo2,oi2,i4]
tv->reorder({{-4, -2}, {-2, -4}});
// 0 1 2 3
//[oi2, oo2, 8,i4]
if (transposed) {
tv->reorder({{-1, -2}, {-2, -1}});
}
tv->merge(-4);
tv->merge(-3);
// 0 1
//[warp, i4]
}
tv->axis(-2)->parallelize(ParallelType::TIDx);
// [Warp, K8]
tv->axis(-2)->parallelize(ParallelType::TIDx);
if (is_immediate_output) {
tv->axis(-1)->parallelize(ParallelType::Vectorize);
}
} else {
// validation:
TORCH_INTERNAL_ASSERT(
canValidateIsInnerDim(n_dims.back(), tv->axis(-2), 8),
"MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain");
if (transposed) {
// [8, 16]
tv->split(-2, 4);
// [2i, 4i, 16]
tv->reorder({{-1, -2}, {-2, -1}});
// [2i, 16, 4i]
tv->merge(-3);
// [warp, 4i]
} else {
//[8, 16]
tv->split(-1, 4);
tv->split(-2, 2);
// 0 1 2 3
//[8, oo2,oi2,i4]
tv->reorder({{-4, -2}, {-2, -4}});
// 0 1 2 3
//[oi2, oo2, 8,i4]
tv->merge(-4);
tv->merge(-3);
// 0 1
//[warp, i4]
}
tv->axis(-2)->parallelize(ParallelType::TIDx);
}
} else {
TORCH_INTERNAL_ASSERT(false, "unreachable");
}
@ -704,6 +765,52 @@ void WarpMmaSwizzler::scheduleTuringM16N8K16MmaWarpOutput(
tv->axis(m_pos)->parallelize(ParallelType::TIDx);
}
void WarpMmaSwizzler::scheduleTuringM16N16K16MmaWarpOutput(
TensorView* tv,
const MmaOptions& options) {
// Assume last 2 dims [M16, N16] or [M16, N16, R]
// Locate instruction m
bool is_reduction = tv->axis(-1)->isReduction();
// Make sure instruction tile size is correct.
if (is_reduction) {
validateMmaRootInnerMNK(tv, options, 16, 16, 16);
} else {
validateMmaRootInnerMN(tv, options, 16, 16);
}
int m_pos = is_reduction ? -3 : -2;
// m
// [16, 16 (,R)]
tv->split(m_pos + 1, 8);
// m
// [16, n2, 8 (,R)]
tv->reorder({{m_pos, m_pos - 1}, {m_pos - 1, m_pos}});
// m
// [n2, 16, 8 (,R)]
tv->split(m_pos, 8);
tv->split(m_pos + 1, 2);
// m
// [2o, 8o, 4i, 2i (,R)]
tv->merge(m_pos - 1);
// m
// [2o, Warp, 2i (,R)]
TORCH_CHECK(tv->definition() != nullptr);
if (is_reduction && tv->definition()->isA<MmaOp>()) {
// Set instruction loops for mma reduce
for (int pos : c10::irange(5)) {
tv->axis(-pos - 1)->parallelize(ParallelType::Mma);
}
}
tv->axis(m_pos)->parallelize(ParallelType::TIDx);
}
namespace {
bool isMmaInitLoop(const kir::Scope& loop_body) {
@ -750,6 +857,76 @@ bool isMmaInitLoop(const kir::ForLoop* loop) {
} // namespace mma_util
void scheduler_utils::matmul_utils::canonicalizeMmaTvOrdering(TensorView* tv) {
std::unordered_set<IterDomain*> root_id_set{
tv->getMaybeRFactorDomain().begin(), tv->getMaybeRFactorDomain().end()};
auto mma = dynamic_cast<MmaOp*>(tv->definition());
TORCH_CHECK(
mma != nullptr, "canonicalizeMmaTvOrdering : only support mma op output");
auto m_id_set = mma_util::getMmaDomainSet(mma, mma_util::MmaDimension::M);
auto n_id_set = mma_util::getMmaDomainSet(mma, mma_util::MmaDimension::N);
auto k_id_set = mma_util::getMmaDomainSet(mma, mma_util::MmaDimension::K);
std::vector<int> batch_pos, prev_reduction_pos, m_pos, n_pos, k_pos;
auto ndims = tv->nDims();
for (auto idx : c10::irange(ndims)) {
auto id = tv->axis(idx);
TORCH_CHECK(root_id_set.count(id), id->toString(), " not a root id.");
// Categorize each original iterdomain position
if (m_id_set.count(id)) {
m_pos.push_back(idx);
} else if (n_id_set.count(id)) {
n_pos.push_back(idx);
} else if (k_id_set.count(id)) {
k_pos.push_back(idx);
} else if (id->isReduction()) {
prev_reduction_pos.push_back(idx);
} else {
batch_pos.push_back(idx);
}
}
// Collect all mma id's, other id's would be either
// batch or incoming reduction.
// Ordering map from old position to new position
// that we will build using the position vectors.
std::unordered_map<int, int> order_map;
// Running position counter keeping track of the
// current insert position in order_map.
int current_pos = 0;
// Utility to insert the ordered pos sequences to
// the ordering map.
auto insert_to_order_map =
[&order_map, &current_pos](const std::vector<int>& original_pos) {
for (auto pos : original_pos) {
order_map[pos] = current_pos++;
}
};
// Order the categories, while keeping the original
// intra-category ordering.
insert_to_order_map(batch_pos);
insert_to_order_map(prev_reduction_pos);
insert_to_order_map(m_pos);
insert_to_order_map(n_pos);
insert_to_order_map(k_pos);
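// For example, an mma output whose root domain arrives as [K, M, N] (no
// batch or prior reduction ids) maps to the new order [M, N, K]; a batched
// [M, K, B, N] maps to [B, M, N, K].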
// Validate that all of the root ids are covered by
// the inserted categories.
TORCH_INTERNAL_ASSERT(current_pos == ndims, "Id not completely categorized");
// Apply the new ordering
tv->reorder(order_map);
}
} // namespace cuda
} // namespace fuser
} // namespace jit


@ -115,18 +115,32 @@ class TORCH_CUDA_CU_API WarpMmaSwizzler {
MmaOptions options = MmaOptions());
private:
//! Swizzle implementations for Volta mma.
//! Operand swizzle implementations for Volta mma.
static void scheduleVoltaOperandRead(TensorView* tv, MmaOptions options);
//! Accumulator swizzle implementations for Volta mma.
static void scheduleVoltaM16N16K4Fp32Output(
TensorView* tv,
const MmaOptions& options);
//! Swizzle implementations for Turing mma.
//! Operand swizzle implementations for Turing and Ampere mma.
static void scheduleTuringOperandRead(TensorView* tv, MmaOptions options);
//! Accumulator swizzle implementation for Turing and Ampere mma.
static void scheduleTuringM16N8K16MmaWarpOutput(
TensorView* tv,
const MmaOptions& options);
//! Accumulator swizzle implementation for emulated 16x16x16 mma tile
//! that enables using ldmatrix.x4.
//! Note:
//! Keeping both this option and the ldmatrix.x2 variant above for
//! now for wider scheduler exploration space. Eventually both of
//! these can be unified with a single affine utility.
static void scheduleTuringM16N16K16MmaWarpOutput(
TensorView* tv,
const MmaOptions& options);
//! Utility to lock the transformed dimensions from further transforms.
static void setWarpMapped(TensorView* tv, int number_of_dims);
};


@ -33,7 +33,7 @@ int64_t roundUpPow2Or8(const int64_t x) {
// Copied from reduction scheduler, should generalize. Simply needed to take out
// grid reductions.
ReductionParams innerPersistentHeuristic(
std::shared_ptr<ReductionParams> innerPersistentHeuristic(
const int64_t total_reduction_numel,
const int64_t total_iteration_numel,
const int64_t inner_most_dimension_numel,
@ -86,10 +86,9 @@ ReductionParams innerPersistentHeuristic(
const int64_t warp_size_based_on_l1 = std::min(
ceilDiv(
total_reduction_numel,
std::max(
l1_cache /
(n_tensor_inputs * max_input_dtype_size * active_threads),
(int64_t)1)),
scheduler_utils::safeDiv(
l1_cache,
n_tensor_inputs * max_input_dtype_size * active_threads)),
(int64_t)16);
// Take the smaller
@ -105,7 +104,7 @@ ReductionParams innerPersistentHeuristic(
// communication is slow so it shouldn't be done for every element in the
// reduction.
int64_t min_target_iterations =
std::max((int64_t)32 / (int64_t)max_input_dtype_size, (int64_t)1);
scheduler_utils::safeDiv(32, max_input_dtype_size);
// Start trying to break parallelization up across threads,
// unrolling/iterations, and blocks.
@ -121,8 +120,8 @@ ReductionParams innerPersistentHeuristic(
// If we have more than a wave of blocks, put parallelism into unrolling and
// target iterations
if (target_blocks > device_multiprocessor_count) {
auto available_unroll = std::max(
n_elems / (warp_size * device_multiprocessor_count), (int64_t)1);
auto available_unroll = scheduler_utils::safeDiv(
n_elems, warp_size * device_multiprocessor_count);
// Spread across unrolling and iterations, want a balance of the two so flip
// back and forth to alternate adding to them.
@ -140,12 +139,10 @@ ReductionParams innerPersistentHeuristic(
target_iterations *= 2;
}
available_unroll = std::max(
n_elems /
(warp_size * device_multiprocessor_count * target_unroll *
target_iterations),
(int64_t)1);
available_unroll = scheduler_utils::safeDiv(
n_elems,
warp_size * device_multiprocessor_count * target_unroll *
target_iterations);
flip = !flip;
}
@ -171,9 +168,8 @@ ReductionParams innerPersistentHeuristic(
// Compute maximum number of reductions we could do in the same kernel based
// on persistent buffer size
const int64_t max_multi_reduction_factor = std::max(
scheduler_utils::register_file_size / max_persistent_buffer_size,
(int64_t)1);
const int64_t max_multi_reduction_factor = scheduler_utils::safeDiv(
scheduler_utils::register_file_size, max_persistent_buffer_size);
// To get to target threads:
// Prioritize
@ -226,18 +222,18 @@ ReductionParams innerPersistentHeuristic(
// Put everything else in bdimy for now
bdimy = std::min(
std::max(warp_size / bdimx, (int64_t)1), max_multi_reduction_factor);
scheduler_utils::safeDiv(warp_size, bdimx), max_multi_reduction_factor);
// If 3D fill the rest of the threads into bdimz
bdimz = std::min(
std::min(
std::max(max_threads_in_block / (bdimx * bdimy), (int64_t)1),
scheduler_utils::safeDiv(max_threads_in_block, bdimx * bdimy),
outer_reduction_numel),
scheduler_utils::z_block_limit);
// If 3D doesn't fill out the threads, adjust to add to bdimy
bdimy = std::min(
std::max(max_threads_in_block / (bdimx * bdimz), (int64_t)1),
scheduler_utils::safeDiv(max_threads_in_block, bdimx * bdimz),
max_multi_reduction_factor);
// If we don't have a full warp and have an unroll factor, move unroll into
@ -251,14 +247,14 @@ ReductionParams innerPersistentHeuristic(
// Readjust bdimy and bdimz
bdimy = std::min(
std::max(warp_size / bdimx, (int64_t)1), max_multi_reduction_factor);
scheduler_utils::safeDiv(warp_size, bdimx), max_multi_reduction_factor);
bdimz = std::min(
std::max(max_threads_in_block / (bdimx * bdimy), (int64_t)1),
scheduler_utils::safeDiv(max_threads_in_block, bdimx * bdimy),
outer_reduction_numel);
bdimy = std::min(
std::max(max_threads_in_block / (bdimx * bdimz), (int64_t)1),
scheduler_utils::safeDiv(max_threads_in_block, bdimx * bdimz),
max_multi_reduction_factor);
}
@ -313,14 +309,13 @@ ReductionParams innerPersistentHeuristic(
// iteration domain
if (inner_reduction_unroll_factor * outer_reduction_unroll_factor <
max_unroll &&
std::max(max_multi_reduction_factor / bdimy, (int64_t)1) > 2) {
scheduler_utils::safeDiv(max_multi_reduction_factor, bdimy) > 2) {
// Don't go over a combined inner/outer unroll of max_unroll
auto unroll_available = std::min(
std::max(
max_unroll /
(inner_reduction_unroll_factor * outer_reduction_unroll_factor),
(int64_t)1),
std::max(max_multi_reduction_factor / bdimy, (int64_t)1));
scheduler_utils::safeDiv(
max_unroll,
inner_reduction_unroll_factor * outer_reduction_unroll_factor),
scheduler_utils::safeDiv(max_multi_reduction_factor, bdimy));
if (unroll_available > 1 && godim > 2 * device_multiprocessor_count) {
unroll_available = std::min(
unroll_available, ceilDiv(godim, 2 * device_multiprocessor_count));
@ -414,51 +409,52 @@ ReductionParams innerPersistentHeuristic(
int64_t gdimy = LaunchParams::UNINITIALIZED_VAL;
int64_t gdimz = LaunchParams::UNINITIALIZED_VAL;
ReductionParams rparams;
auto rparams = std::make_shared<ReductionParams>();
rparams.persistent_kernel = true;
rparams.fastest_dim = true;
rparams->persistent_kernel = true;
rparams->fastest_dim = true;
// Inner reduction domain
rparams.cross_block_inner_reduction = true;
rparams.block_dim_inner_reduction = ParallelType::TIDx;
rparams.pad_inner_reduction_to_warp = pad_bdimx;
rparams.batches_per_block_inner_reduction = batches_per_block_inner_reduction;
rparams->cross_block_inner_reduction = true;
rparams->block_dim_inner_reduction = ParallelType::TIDx;
rparams->pad_inner_reduction_to_warp = pad_bdimx;
rparams->batches_per_block_inner_reduction =
batches_per_block_inner_reduction;
// For persistent schedules always have to mark the reduction unrolled
// otherwise rfactor can fail
rparams.unroll_factor_inner_reduction = inner_reduction_unroll_factor;
rparams.vectorize_inner_reduction = vectorize;
rparams->unroll_factor_inner_reduction = inner_reduction_unroll_factor;
rparams->vectorize_inner_reduction = vectorize;
// Iter domain
rparams.multiple_reds_per_blk = bdimy > 1;
if (rparams.multiple_reds_per_blk) {
rparams.block_dim_iter_dom = ParallelType::TIDy;
rparams->multiple_reds_per_blk = bdimy > 1;
if (rparams->multiple_reds_per_blk) {
rparams->block_dim_iter_dom = ParallelType::TIDy;
}
if (godim > 1) {
rparams.grid_dim_iter_dom = ParallelType::BIDx;
rparams->grid_dim_iter_dom = ParallelType::BIDx;
if (godim > scheduler_utils::x_grid_limit) {
rparams.split_grid_dim_iter_dom = true;
rparams->split_grid_dim_iter_dom = true;
gdimx = scheduler_utils::x_grid_limit;
}
}
if (iter_unroll_factor > 1) {
rparams.unroll_factor_iter_dom = iter_unroll_factor;
rparams->unroll_factor_iter_dom = iter_unroll_factor;
}
// Outer reduction domain
rparams.schedule_3D = total_reduction_numel != inner_most_dimension_numel;
if (rparams.schedule_3D) {
rparams.batches_per_block_outer_reduction =
rparams->schedule_3D = total_reduction_numel != inner_most_dimension_numel;
if (rparams->schedule_3D) {
rparams->batches_per_block_outer_reduction =
batches_per_block_outer_reduction;
rparams.block_dim_outer_reduction = ParallelType::TIDz;
rparams.cross_block_outer_reduction = true;
rparams.unroll_factor_outer_reduction = outer_reduction_unroll_factor;
rparams->block_dim_outer_reduction = ParallelType::TIDz;
rparams->cross_block_outer_reduction = true;
rparams->unroll_factor_outer_reduction = outer_reduction_unroll_factor;
}
rparams.lparams = LaunchParams(
rparams->lparams = LaunchParams(
gdimx,
gdimy,
gdimz,
@ -466,7 +462,7 @@ ReductionParams innerPersistentHeuristic(
bdimy,
LaunchParams::UNINITIALIZED_VAL);
rparams.tag = "Inner Persistent Heuristic.\n";
rparams->tag = "Inner Persistent Heuristic.\n";
if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) {
std::cerr << "\n===== Reduction Stats ========\n"
@ -483,7 +479,7 @@ ReductionParams innerPersistentHeuristic(
<< "\n"
<< "block(" << (pad_bdimx ? padded_bdimx : bdimx) << ", " << bdimy
<< ", " << bdimz << ")";
std::cerr << rparams.toString() << std::endl;
std::cerr << rparams->toString() << std::endl;
}
return rparams;
@ -492,7 +488,7 @@ ReductionParams innerPersistentHeuristic(
// Copied from reduction scheduler, should generalize. Simply needed to take out
// grid reductions.
// TODO: Check adding iteration domain unrolling
ReductionParams OuterPersistentHeuristic(
std::shared_ptr<ReductionParams> outerPersistentHeuristic(
const int64_t total_reduction_numel,
const int64_t total_iteration_numel,
const int64_t n_tensor_inputs,
@ -700,47 +696,47 @@ ReductionParams OuterPersistentHeuristic(
gdimx = ceilDiv(total_iteration_numel, bdimx);
ReductionParams rparams;
rparams.batches_per_block_inner_reduction = batches_per_block;
rparams.persistent_kernel = true;
auto rparams = std::make_shared<ReductionParams>();
rparams->batches_per_block_inner_reduction = batches_per_block;
rparams->persistent_kernel = true;
rparams.fastest_dim = false;
rparams.cross_block_inner_reduction = true;
rparams.cross_grid_inner_reduction = false;
rparams.multiple_reds_per_blk = bdimx > 1;
rparams->fastest_dim = false;
rparams->cross_block_inner_reduction = true;
rparams->cross_grid_inner_reduction = false;
rparams->multiple_reds_per_blk = bdimx > 1;
if (rparams.multiple_reds_per_blk) {
rparams.block_dim_iter_dom = ParallelType::TIDx;
if (rparams->multiple_reds_per_blk) {
rparams->block_dim_iter_dom = ParallelType::TIDx;
}
rparams.grid_dim_iter_dom = ParallelType::BIDx;
rparams.split_grid_dim_iter_dom = gdimx > scheduler_utils::x_grid_limit;
rparams->grid_dim_iter_dom = ParallelType::BIDx;
rparams->split_grid_dim_iter_dom = gdimx > scheduler_utils::x_grid_limit;
if (rparams.block_dim_iter_dom == ParallelType::TIDx) {
rparams.block_dim_inner_reduction = ParallelType::TIDy;
if (rparams->block_dim_iter_dom == ParallelType::TIDx) {
rparams->block_dim_inner_reduction = ParallelType::TIDy;
} else {
rparams.block_dim_inner_reduction = ParallelType::TIDx;
rparams->block_dim_inner_reduction = ParallelType::TIDx;
}
// Always need to mark inner reduction unroll for rfactor in outer persistent
// kernels
rparams.unroll_factor_inner_reduction = inner_reduction_unroll_factor;
rparams->unroll_factor_inner_reduction = inner_reduction_unroll_factor;
rparams.unroll_factor_iter_dom = iter_unroll_factor;
rparams->unroll_factor_iter_dom = iter_unroll_factor;
if (iter_unroll_factor > 1) {
rparams.vectorize_iter_dom = vectorize;
rparams->vectorize_iter_dom = vectorize;
}
rparams.lparams = LaunchParams(
rparams->lparams = LaunchParams(
LaunchParams::UNINITIALIZED_VAL,
LaunchParams::UNINITIALIZED_VAL,
LaunchParams::UNINITIALIZED_VAL,
rparams.multiple_reds_per_blk ? bdimx : bdimy,
rparams->multiple_reds_per_blk ? bdimx : bdimy,
LaunchParams::UNINITIALIZED_VAL,
LaunchParams::UNINITIALIZED_VAL);
rparams.tag = "Outer persistent kernel heuristic.\n";
rparams->tag = "Outer persistent kernel heuristic.\n";
if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) {
std::cerr << "\n===== Reduction Stats ========\n"
@ -754,7 +750,7 @@ ReductionParams OuterPersistentHeuristic(
<< "max_multi_reduction_factor: " << max_multi_reduction_factor
<< "\n"
<< "block(" << bdimx << ", " << bdimy << ", 1)" << std::endl;
std::cerr << rparams.toString() << std::endl;
std::cerr << rparams->toString() << std::endl;
}
return rparams;
@ -762,7 +758,7 @@ ReductionParams OuterPersistentHeuristic(
} // namespace
ReductionParams PersistentHeuristic(
std::shared_ptr<ReductionParams> persistentHeuristic(
const int64_t total_reduction_numel,
const int64_t total_iteration_numel,
const int64_t inner_most_dimension_numel,
@ -772,7 +768,7 @@ ReductionParams PersistentHeuristic(
const int64_t max_persistent_buffer_size,
size_t vectorize_factor,
bool project_persistent_buffers) {
ReductionParams rparams;
std::shared_ptr<ReductionParams> rparams;
if (fastest_dim_reduction) {
rparams = innerPersistentHeuristic(
total_reduction_numel,
@ -783,7 +779,7 @@ ReductionParams PersistentHeuristic(
max_persistent_buffer_size,
vectorize_factor);
} else {
rparams = OuterPersistentHeuristic(
rparams = outerPersistentHeuristic(
total_reduction_numel,
total_iteration_numel,
n_tensor_inputs,
@ -791,11 +787,11 @@ ReductionParams PersistentHeuristic(
max_persistent_buffer_size,
vectorize_factor);
}
rparams.project_persistent_buffers = project_persistent_buffers;
rparams->project_persistent_buffers = project_persistent_buffers;
return rparams;
}
TORCH_CUDA_CU_API c10::optional<ReductionParams> getPersistentHeuristics(
TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getPersistentHeuristics(
Fusion* fusion,
SchedulerRuntimeInfo& runtime_info,
HeuristicSummary* data_cache) {
@ -827,8 +823,7 @@ TORCH_CUDA_CU_API c10::optional<ReductionParams> getPersistentHeuristics(
TORCH_INTERNAL_ASSERT(
red_expr->getExprType() != c10::nullopt &&
(red_expr->getExprType().value() == ExprType::ReductionOp ||
red_expr->getExprType().value() == ExprType::WelfordOp),
ir_utils::isReductionOp(red_expr),
"TensorView doesn't have a reduction.");
auto tv_inps = ir_utils::filterByType<TensorView>(fusion->inputs());
@ -939,7 +934,7 @@ TORCH_CUDA_CU_API c10::optional<ReductionParams> getPersistentHeuristics(
n_tensor_inputs++;
}
return PersistentHeuristic(
return persistentHeuristic(
properties.total_reduction_numel,
properties.total_iteration_numel,
properties.inner_most_dimension_numel,
@ -951,7 +946,7 @@ TORCH_CUDA_CU_API c10::optional<ReductionParams> getPersistentHeuristics(
project_persistent_buffers);
}
TORCH_CUDA_CU_API c10::optional<ReductionParams> getPersistentHeuristics(
TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getPersistentHeuristics(
Fusion* fusion,
const at::ArrayRef<c10::IValue>& runtime_inputs,
HeuristicSummary* data_cache) {
@ -980,8 +975,9 @@ TORCH_CUDA_CU_API void schedulePersistentKernel(
bool unroll = rparams.isUnrolled();
// Cache inputs if unrolled
auto cached_inputs = scheduler_utils::cacheInputs(fusion, unroll);
// Cache inputs even if not unrolled, as otherwise we may not create a
// persistent buffer if that persistent buffer would be the input.
auto cached_inputs = scheduler_utils::cacheInputs(fusion, true);
// Cache and fork outputs
auto cached_outputs = scheduler_utils::cacheAndForkOutputs(fusion, unroll);


@ -18,12 +18,12 @@ namespace cuda {
class SchedulerRuntimeInfo;
class HeuristicSummary;
TORCH_CUDA_CU_API c10::optional<ReductionParams> getPersistentHeuristics(
TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getPersistentHeuristics(
Fusion* fusion,
const at::ArrayRef<c10::IValue>& runtime_inputs,
HeuristicSummary* data_cache = nullptr);
TORCH_CUDA_CU_API c10::optional<ReductionParams> getPersistentHeuristics(
TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getPersistentHeuristics(
Fusion* fusion,
SchedulerRuntimeInfo& runtime_info,
HeuristicSummary* data_cache = nullptr);


@ -27,26 +27,28 @@ namespace {
// Unused at the moment, commenting for clang tidy
constexpr int64_t kThreadX = 128;
// DomainMap uses the ComputeAtMap to find a reference TensorView
// that maps to all iterDomains in the fusion.
class DomainMap {
class DomainMap : public pointwise_utils::DomainMap {
public:
DomainMap(Fusion* fusion) : fusion_(fusion), ca_map_(ComputeAtMap(fusion)) {
view_tvs_ = scheduler_utils::getViewTVs(fusion);
}
using pointwise_utils::DomainMap::DomainMap;
// The pointwise scheduler heuristic requires a minimum number of axes.
// The output reference tensor should respect this requirement.
TensorView* findReferenceTensorView(int minimum_num_axes = 0) const {
TensorView* result = nullptr;
int max_dims = -1;
for (auto output_tv :
ir_utils::filterByType<TensorView>(fusion_->outputs())) {
if (isValidReference(output_tv) &&
hasMinimumSize(output_tv, minimum_num_axes) &&
!output_tv->isFusionInput()) {
return output_tv;
int n_dims = pointwise_utils::nRootDims(output_tv);
if (n_dims > max_dims) {
result = output_tv;
max_dims = n_dims;
}
}
}
return nullptr;
return result;
}
static bool hasReferenceTensorView(Fusion* fusion) {
@ -71,7 +73,7 @@ class DomainMap {
continue;
}
if (!areAllMapped(input_tv, output_tv)) {
if (!areAllInputIdsMappedToOutput(input_tv, output_tv)) {
return false;
}
}
@ -83,92 +85,11 @@ class DomainMap {
TORCH_INTERNAL_ASSERT(tv != nullptr);
return (num_axes == 0 || tv->getMaybeRFactorDomain().size() > num_axes);
}
// Determine if all iterDomains are mapped between input and output tvs
bool areAllMapped(TensorView* input_tv, TensorView* output_tv) const {
// Get concrete IDs for input root or rfactor domain
std::unordered_set<IterDomain*> in_concrete_ids;
for (auto in_id : input_tv->getMaybeRFactorDomain()) {
if (!ca_map_.getConcreteMappedID(in_id, IdMappingMode::EXACT)
->isBroadcast() &&
!in_id->isReduction()) {
in_concrete_ids.insert(
ca_map_.getConcreteMappedID(in_id, IdMappingMode::EXACT));
}
}
// Erase all input concrete IDs mapped to the output domain
// Ignore unresolved broadcast dimensions
for (auto out_id : output_tv->getMaybeRFactorDomain()) {
if (!out_id->isBroadcast()) {
if (!eraseIfMapped(in_concrete_ids, out_id)) {
eraseIfMappedThroughView(in_concrete_ids, out_id);
}
}
}
return in_concrete_ids.empty();
}
// Erase input concrete ID if it is mapped to output ID
bool eraseIfMapped(
std::unordered_set<IterDomain*>& in_concrete_ids,
IterDomain* out_id) const {
auto out_concrete_id =
ca_map_.getConcreteMappedID(out_id, IdMappingMode::EXACT);
auto in_concrete_id_iter = in_concrete_ids.find(out_concrete_id);
bool found_match = in_concrete_id_iter != in_concrete_ids.end();
if (found_match) {
in_concrete_ids.erase(in_concrete_id_iter);
}
return found_match;
}
// Check if in_id is mapped to out_id through any view rfactor domain
void eraseIfMappedThroughView(
std::unordered_set<IterDomain*>& in_concrete_ids,
IterDomain* out_id) const {
for (auto view : view_tvs_) {
// Find any ID in view rfactor domain that is mapped to output ID
auto view_rfactor_id = anyMapped(view->getRFactorDomain(), out_id);
if (view_rfactor_id == nullptr) {
continue;
}
if (view_rfactor_id->isRFactorProduct()) {
// Check if input ID is mapped to any input IDs of the view rfactor ID
auto root_inputs = InputsOf::outputs(fusion_, {view_rfactor_id});
auto filtered_root_ids =
ir_utils::filterByType<IterDomain>(root_inputs);
for (auto view_root_id : filtered_root_ids) {
eraseIfMapped(in_concrete_ids, view_root_id);
}
} else {
// Otherwise, the input ID must map to the view rfactor ID
eraseIfMapped(in_concrete_ids, view_rfactor_id);
}
}
}
// Find any id in domain that maps with target id
IterDomain* anyMapped(
const std::vector<IterDomain*> domain,
IterDomain* target) const {
for (auto id : domain) {
if (ca_map_.areMapped(id, target, IdMappingMode::EXACT)) {
return id;
}
}
return nullptr;
}
Fusion* fusion_ = nullptr;
ComputeAtMap ca_map_;
std::vector<TensorView*> view_tvs_;
};
} // namespace
c10::optional<PointwiseParams> getPointwiseHeuristics(
std::shared_ptr<PointwiseParams> getPointwiseHeuristics(
Fusion* fusion,
const at::ArrayRef<c10::IValue>& runtime_inputs,
HeuristicSummary* data_cache) {
@ -176,7 +97,7 @@ c10::optional<PointwiseParams> getPointwiseHeuristics(
return getPointwiseHeuristics(fusion, runtime_info, data_cache);
}
c10::optional<PointwiseParams> getPointwiseHeuristics(
std::shared_ptr<PointwiseParams> getPointwiseHeuristics(
Fusion* fusion,
SchedulerRuntimeInfo& runtime_info,
HeuristicSummary* data_cache) {
@ -187,35 +108,21 @@ c10::optional<PointwiseParams> getPointwiseHeuristics(
// In case any buffer is of type DataType::Index
DataType index_type = indexModeToDtype(runtime_info.getIndexMode());
TensorView* largest_out = nullptr;
int max_dims = -1;
auto in_tvs = ir_utils::filterByType<TensorView>(fusion->inputs());
// Will want to access this with direct indexing later, convert now.
std::vector<TensorView*> out_tvs;
// Only use valid reference tensors during heuristics analysis
DomainMap domain_map(fusion);
for (auto out_tv : ir_utils::filterByType<TensorView>(fusion->outputs())) {
if (domain_map.isValidReference(out_tv)) {
out_tvs.push_back(out_tv);
}
}
TORCH_INTERNAL_ASSERT(
!out_tvs.empty(), "No valid reference outputs were found!");
for (auto out_tv : out_tvs) {
int n_dims = 0;
for (auto id : out_tv->getMaybeRFactorDomain()) {
if (id->isReduction() || id->isBroadcast()) {
continue;
}
n_dims++;
}
if (n_dims > max_dims) {
largest_out = out_tv;
max_dims = n_dims;
}
}
auto domain_map_entry =
HeuristicSummaryEntry<HeuristicCompileTime::DomainMap>(
data_cache,
[fusion]() { return std::make_unique<DomainMap>(fusion); });
const auto& domain_map = dynamic_cast<DomainMap&>(domain_map_entry.get());
auto largest_out_entry =
HeuristicSummaryEntry<HeuristicCompileTime::ReferenceTensors>(
data_cache, [&domain_map]() {
std::vector<TensorView*> data{domain_map.findReferenceTensorView()};
return std::make_unique<std::vector<TensorView*>>(std::move(data));
});
TensorView* largest_out = largest_out_entry.get()[0];
TORCH_INTERNAL_ASSERT(largest_out != nullptr);
@ -224,15 +131,12 @@ c10::optional<PointwiseParams> getPointwiseHeuristics(
// TODO: Set to 1?
int64_t max_input_dtype_size = 2;
size_t n_tensors = 0;
for (auto inp : in_tvs) {
max_input_dtype_size = std::max(
max_input_dtype_size,
(int64_t)dataTypeSize(inp->getDataType().value(), index_type));
n_tensors++;
}
n_tensors += std::distance(out_tvs.begin(), out_tvs.end());
auto ref_root = largest_out->getMaybeRFactorDomain();
std::vector<int64_t> elem_counts(ref_root.size(), 1);
@ -266,10 +170,7 @@ c10::optional<PointwiseParams> getPointwiseHeuristics(
std::vector<scheduler_utils::BroadcastMultiple>>();
});
broadcast_byte_multiples_entry.get();
PointwiseParams params;
params.tag = "Pointwise heuristics";
return params;
return std::make_shared<PointwiseParams>("Pointwise heuristics");
}
// Find all vectorizable inputs/outputs
@ -307,8 +208,7 @@ c10::optional<PointwiseParams> getPointwiseHeuristics(
max_unroll_factor = 1;
}
PointwiseParams params;
params.tag = "Pointwise heuristics";
auto params = std::make_shared<PointwiseParams>("Pointwise heuristics");
/*
* 2D pointwise scheduling logic. What is expected is there's some
@ -389,7 +289,7 @@ c10::optional<PointwiseParams> getPointwiseHeuristics(
}
// If there isn't very much parallelism available, just use 1D scheduler
if (true || n_elems * 2 > device_multiprocessor_count * kThreadX) {
if (n_elems * 2 > device_multiprocessor_count * kThreadX) {
int64_t min_total_transfer = std::numeric_limits<int64_t>::max();
for (const auto break_point_i : c10::irange(ref_root.size())) {
@ -476,7 +376,7 @@ c10::optional<PointwiseParams> getPointwiseHeuristics(
// Vectorizing innermost domains
// Don't try to vectorize if it's not recommended
params.unroll_factor = 1;
params->unroll_factor = 1;
// Compute maximum vectorize factor that can be used
size_t vectorize_factor = max_unroll_factor;
@ -506,34 +406,33 @@ c10::optional<PointwiseParams> getPointwiseHeuristics(
}
if (vectorize_factor == 1) {
params.vectorize = false;
params.unroll_factor = max_unroll_factor;
params->vectorize = false;
params->unroll_factor = max_unroll_factor;
} else {
params.vectorize = true;
params.unroll_factor = vectorize_factor;
params->vectorize = true;
params->unroll_factor = vectorize_factor;
}
TORCH_INTERNAL_ASSERT(right_elem_count > 0 || break_point == 0);
TORCH_INTERNAL_ASSERT(!(bdimy > 1 && gdim_right > 1));
params.break_point = break_point;
params.flip_grid_binding = flip_grid_binding;
params.split_block = bdimy > 1;
params->break_point = break_point;
params->flip_grid_binding = flip_grid_binding;
params->split_block = bdimy > 1;
params.lparams.bind(bdimx, ParallelType::TIDx);
if (params.split_block) {
params.lparams.bind(bdimy, ParallelType::TIDy);
params->lparams.bind(bdimx, ParallelType::TIDx);
if (params->split_block) {
params->lparams.bind(bdimy, ParallelType::TIDy);
}
if ((flip_grid_binding && gdim_right > 65535) ||
(!flip_grid_binding && gdim_left > 65535)) {
params.split_grid_y_dim = true;
params->split_grid_y_dim = true;
}
if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) {
std::cerr << "\n===== Pointwise Stats ========\n"
<< "num_elems: " << n_elems << "\n"
<< "elem_counts: " << elem_counts << "\n"
<< "n_tensor_inputs: " << n_tensors << "\n"
<< "max_input_dtype_size: " << max_input_dtype_size << "\n"
<< "vectorize_factor: " << vectorize_factor << std::endl;
std::cerr << "broadcast_byte_multiples: ";
@ -545,7 +444,7 @@ c10::optional<PointwiseParams> getPointwiseHeuristics(
<< (right_elem_count > 0 ? n_elems / right_elem_count : 0)
<< " RHS elems: " << right_elem_count << std::endl;
std::cerr << std::endl;
std::cerr << params.toString() << std::endl;
std::cerr << params->toString() << std::endl;
}
return params;
@ -558,25 +457,11 @@ LaunchParams schedulePointwise(
FUSER_PERF_SCOPE("scheduleFusion");
auto params = getPointwiseHeuristics(fusion, runtime_inputs);
TORCH_INTERNAL_ASSERT(
params.has_value(), "Could not schedule pointwise operation.");
schedulePointwise(fusion, params.value());
return params.value().lparams;
params != nullptr, "Could not schedule pointwise operation.");
schedulePointwise(fusion, *params);
return params->lparams;
}
namespace {
// Returns number of non-reduction/non-broadcast dims in rfactor domain
size_t nRootDims(const TensorView* tv) {
auto root_dom = tv->getMaybeRFactorDomain();
size_t tv_n_dims = 0;
for (auto dim : root_dom) {
if (!dim->isReduction() && !dim->isBroadcast()) {
tv_n_dims++;
}
}
return tv_n_dims;
}
} // namespace
bool hasReferenceTensorView(Fusion* fusion) {
return DomainMap::hasReferenceTensorView(fusion);
}
@ -617,11 +502,11 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
size_t max_dims = 0;
for (auto inp : input_tvs) {
max_dims = std::max(nRootDims(inp), max_dims);
max_dims = std::max(pointwise_utils::nRootDims(inp), max_dims);
}
for (auto out : output_tvs) {
max_dims = std::max(nRootDims(out), max_dims);
max_dims = std::max(pointwise_utils::nRootDims(out), max_dims);
}
// If everything is zero dim tensors, just return.
@ -643,10 +528,6 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
int rhs_i = -1;
for (int i = (int)reference_tv->nDims(); i > (int)params.break_point; i--) {
auto axis_i = i - 1;
if (reference_tv->axis(axis_i)->isBroadcast() ||
reference_tv->axis(axis_i)->isReduction()) {
continue;
}
if (rhs_i == -1) {
rhs_i = axis_i;
} else {
@ -663,10 +544,6 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
int lhs_i = -1;
for (int i = (int)params.break_point; i > 0; i--) {
auto axis_i = i - 1;
if (reference_tv->axis(axis_i)->isBroadcast() ||
reference_tv->axis(axis_i)->isReduction()) {
continue;
}
if (lhs_i == -1) {
lhs_i = axis_i;
} else {
@ -676,6 +553,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
}
int64_t unswitch_pos;
IterDomain* vectorize_id = nullptr;
if (params.break_point) {
// 2D parallelization scheme
TORCH_INTERNAL_ASSERT(rhs_i >= 0 && lhs_i >= 0);
@ -691,10 +569,8 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
// [outer, Unswitch | i-remainder, TIDx, Vectorization]
reference_tv->axis(1)->parallelize(ParallelType::Unswitch);
reference_tv->axis(3)->parallelize(ParallelType::TIDx);
// Aggressively mark with vectorized and cleanup later. That way we
// don't have to manually specify parallelization outside the reference.
reference_tv->axis(4)->parallelize(ParallelType::Vectorize);
// Vectorization is propagated separately
vectorize_id = reference_tv->axis(4);
// [outer, Unswitch | i-remainder, TIDx, Vectorization]
// To make consistent with unrolling:
@ -709,6 +585,11 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
reference_tv->reorder({{1, 2}});
// [outer, i-remainder, unswitch, unroll, TIDx ]
reference_tv->axis(2)->parallelize(ParallelType::Unswitch);
// Here we do not set axis(3)->parallelize(Unroll) because we do not want
// it to be propagated. We manually unroll by splitting the inline
// propagation process into two steps:
// step 1: inline at the unswitch position for cached inputs and outputs
// step 2: inline at the inner most dim for the rest of the graph
reference_tv->axis(4)->parallelize(ParallelType::TIDx);
//[outer | i-remainder, Unswitch, Unroll, TIDx]
@ -795,9 +676,8 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
reference_tv->axis(0)->parallelize(ParallelType::BIDx);
reference_tv->axis(1)->parallelize(ParallelType::TIDx);
reference_tv->axis(2)->parallelize(ParallelType::Unswitch);
// Aggressively mark with vectorized and cleanup later. That way we
// don't have to manually specify parallelization outside the reference.
reference_tv->axis(3)->parallelize(ParallelType::Vectorize);
// Vectorization is propagated separately
vectorize_id = reference_tv->axis(3);
//[BIDx, TIDx, Unswitch, Vectorization]
// To make consistent with unrolling:
@ -814,6 +694,11 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
// [BIDx, Unswitch, Unroll, TIDx]
reference_tv->axis(0)->parallelize(ParallelType::BIDx);
reference_tv->axis(1)->parallelize(ParallelType::Unswitch);
// Here we do not set axis(2)->parallelize(Unroll) because we do not want
// it to be propagated. We manually unroll by splitting the inline
// propagation process into two steps:
// step 1: inline at the unswitch position for cached inputs and outputs
// step 2: inline at the inner most dim for the rest of the graph
reference_tv->axis(3)->parallelize(ParallelType::TIDx);
}
unswitch_pos = 2;
@ -822,43 +707,42 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
TransformPropagator propagator(reference_tv);
MaxRootDomainInfoSpanningTree spanning_tree(reference_tv);
spanning_tree.traverse(&propagator);
scheduler_utils::parallelizeAllLike(reference_tv, all_tvs);
scheduler_utils::parallelizeAllLike(reference_tv);
if (params.vectorize) {
// Grab all tensor views that should be vectorized
auto vectorized_tvs =
auto inputs_outputs =
scheduler_utils::getInputsOutputsWithInnerDim(reference_tv, true);
// Going to move inputs to consumers of inputs, need a copy as we'll modify
// the original.
{
auto vectorized_tvs_copy = vectorized_tvs;
for (auto inp : vectorized_tvs_copy) {
if (!inp->isFusionInput()) {
continue;
}
vectorized_tvs.erase(
std::find(vectorized_tvs.begin(), vectorized_tvs.end(), inp));
auto consumer_tvs = ir_utils::consumerTvsOf(inp);
vectorized_tvs.insert(
vectorized_tvs.end(), consumer_tvs.begin(), consumer_tvs.end());
std::vector<TensorView*> vectorized_tvs;
bool should_vectorize_reference_tv = false;
for (auto tv : inputs_outputs) {
if (tv == reference_tv) {
should_vectorize_reference_tv = true;
}
if (!tv->isFusionInput()) {
vectorized_tvs.emplace_back(tv);
continue;
}
// move inputs to consumers of inputs
auto consumer_tvs = ir_utils::consumerTvsOf(tv);
vectorized_tvs.insert(
vectorized_tvs.end(), consumer_tvs.begin(), consumer_tvs.end());
}
// Clear vectorize on tensors that shouldn't have it
for (auto tv : all_tvs) {
if (std::find(vectorized_tvs.begin(), vectorized_tvs.end(), tv) ==
vectorized_tvs.end()) {
for (auto id : tv->domain()->domain()) {
if (id->getParallelType() == ParallelType::Vectorize) {
id->parallelize(ParallelType::Serial);
}
}
}
// Aggressively mark with vectorized and cleanup later. That way we
// don't have to manually specify parallelization outside the reference.
vectorize_id->parallelize(ParallelType::Vectorize);
scheduler_utils::parallelizeAllLike(
reference_tv, vectorized_tvs, {ParallelType::Vectorize});
if (!should_vectorize_reference_tv) {
vectorize_id->parallelize(ParallelType::Serial);
}
}
// Begin by inlining at the unswitch position for the entire DAG. The cached
// inputs, and outputs will keep this inline position, but other tensors will
// get a higher position in later inline propagation.
// get a higher position in later inline propagation. We need this separate
// step because we were not using ParallelType::Unroll, so we have to do
// unrolling manually.
InlinePropagator inline_unswitch(
reference_tv, unswitch_pos, ComputeAtMode::BestEffort);
spanning_tree.traverse(&inline_unswitch);
@ -877,10 +761,6 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
InlinePropagator inline_inner_most(
reference_tv, -1, ComputeAtMode::BestEffort, inner_most_tensors);
spanning_tree.traverse(&inline_inner_most);
// Fix max producer position
MaxProducerPosUpdater updater;
spanning_tree.traverse(&updater);
}
} // namespace cuda


@ -13,12 +13,12 @@ namespace cuda {
class SchedulerRuntimeInfo;
class HeuristicSummary;
TORCH_CUDA_CU_API c10::optional<PointwiseParams> getPointwiseHeuristics(
TORCH_CUDA_CU_API std::shared_ptr<PointwiseParams> getPointwiseHeuristics(
Fusion* fusion,
const at::ArrayRef<c10::IValue>& runtime_inputs,
HeuristicSummary* data_cache = nullptr);
TORCH_CUDA_CU_API c10::optional<PointwiseParams> getPointwiseHeuristics(
TORCH_CUDA_CU_API std::shared_ptr<PointwiseParams> getPointwiseHeuristics(
Fusion* fusion,
SchedulerRuntimeInfo& runtime_info,
HeuristicSummary* data_cache = nullptr);


@ -1,6 +1,6 @@
#pragma once
#include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/heuristic.h>
#include <sstream>
@ -9,11 +9,11 @@ namespace jit {
namespace fuser {
namespace cuda {
// Parameters the Reduction Heuristic Generates to describe the optimal
// schedule. Warning: equal operator is intended for use in caching the kernel
// associated with these reduction parameters. It does not check if the launch
// parameters are equivalent!
class PointwiseParams {
// Parameters of the pointwise heuristic to describe the optimal schedule.
// Warning: equal operator is intended for use in caching the kernel associated
// with these pointwise parameters. It does not check if the launch parameters
// are equivalent!
class PointwiseParams : public HeuristicParams {
public:
// vectorize if true, otherwise unroll
bool vectorize = false;
@ -39,12 +39,16 @@ class PointwiseParams {
// Unroll or vectorization factor
size_t unroll_factor = 1;
std::string tag = "";
LaunchParams lparams;
using HeuristicParams::HeuristicParams;
// Warning: Does not check launch parameters!
bool operator==(const PointwiseParams& other) const {
bool sameAs(
const std::shared_ptr<HeuristicParams>& other_base) const override {
auto other_casted = std::dynamic_pointer_cast<PointwiseParams>(other_base);
if (other_casted == nullptr) {
return false;
}
const PointwiseParams& other = *other_casted;
bool attr_equal = other.vectorize == vectorize &&
other.break_point == break_point && other.split_block == split_block &&
other.split_grid_y_dim == split_grid_y_dim &&
@ -53,7 +57,7 @@ class PointwiseParams {
return attr_equal;
}
std::string toString() const {
std::string toString() const override {
std::stringstream ss;
ss << "\n===== Pointwise Parameters ========\n"
<< (tag == "" ? "" : "Tag: ") << tag << " Pointwise Characteristics:\n"
@ -82,20 +86,21 @@ class PointwiseParams {
ss << "====================================\n";
return ss.str();
}
};
// Warning: Hash is not based on launch parameters!
class PointwiseParamsHash {
public:
size_t operator()(const PointwiseParams& pp) const {
size_t attr_hash = static_cast<size_t>(pp.vectorize) ^
static_cast<size_t>(pp.break_point) << 4 ^
static_cast<size_t>(pp.split_block) << 5 ^
static_cast<size_t>(pp.split_grid_y_dim) << 6 ^
static_cast<size_t>(pp.unroll_factor) << 9 ^
static_cast<size_t>(pp.flip_grid_binding) << 10;
// Warning: Hash is not based on launch parameters!
size_t hash() const override {
size_t attr_hash = static_cast<size_t>(vectorize) ^
static_cast<size_t>(break_point) << 4 ^
static_cast<size_t>(split_block) << 5 ^
static_cast<size_t>(split_grid_y_dim) << 6 ^
static_cast<size_t>(unroll_factor) << 9 ^
static_cast<size_t>(flip_grid_binding) << 10;
return attr_hash;
}
std::shared_ptr<HeuristicParams> clone() const override {
return std::make_shared<PointwiseParams>(*this);
}
};
} // namespace cuda


@ -0,0 +1,101 @@
#include <torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.h>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
namespace pointwise_utils {
DomainMap::DomainMap(Fusion* fusion)
: fusion_(fusion), ca_map_(ComputeAtMap(fusion)) {
view_tvs_ = scheduler_utils::getViewTVs(fusion);
}
bool DomainMap::areExactMapped(IterDomain* id1, IterDomain* id2) {
return ca_map_.areMapped(id1, id2, IdMappingMode::EXACT);
}
// Determine if all IterDomains in input are mapped to output
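// (e.g. an input with rfactor domain [I0, I1] checked against an output with
// [I0] returns false, since I1 has no non-broadcast counterpart in the output)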
bool DomainMap::areAllInputIdsMappedToOutput(
TensorView* input_tv,
TensorView* output_tv) const {
// Get concrete IDs for input root or rfactor domain
std::unordered_set<IterDomain*> in_concrete_ids;
for (auto in_id : input_tv->getMaybeRFactorDomain()) {
auto concrete = ca_map_.getConcreteMappedID(in_id, IdMappingMode::EXACT);
if (!concrete->isBroadcast() && !in_id->isReduction()) {
in_concrete_ids.insert(concrete);
}
}
// Erase all input concrete IDs mapped to the output domain
// Ignore unresolved broadcast dimensions
for (auto out_id : output_tv->getMaybeRFactorDomain()) {
if (!out_id->isBroadcast()) {
if (!eraseIfMapped(in_concrete_ids, out_id)) {
eraseIfInputMappedThroughViewToOutput(in_concrete_ids, out_id);
}
}
}
return in_concrete_ids.empty();
}
// Erase input concrete ID if it is mapped to output ID
bool DomainMap::eraseIfMapped(
std::unordered_set<IterDomain*>& in_concrete_ids,
IterDomain* out_id) const {
auto out_concrete_id =
ca_map_.getConcreteMappedID(out_id, IdMappingMode::EXACT);
auto in_concrete_id_iter = in_concrete_ids.find(out_concrete_id);
bool found_match = in_concrete_id_iter != in_concrete_ids.end();
if (found_match) {
in_concrete_ids.erase(in_concrete_id_iter);
}
return found_match;
}
// Check if in_id is mapped to out_id through any view rfactor domain.
// Currently this function only allows having one view on the path from input
// to output. If there are multiple views, then likely the pointwise scheduler
// will reject the fusion because we cannot correctly find a reference tensor.
void DomainMap::eraseIfInputMappedThroughViewToOutput(
std::unordered_set<IterDomain*>& in_concrete_ids,
IterDomain* out_id) const {
for (auto view : view_tvs_) {
// Find any ID in view rfactor domain that is mapped to output ID
auto view_rfactor_id = anyMapped(view->getRFactorDomain(), out_id);
if (view_rfactor_id == nullptr) {
continue;
}
if (view_rfactor_id->isRFactorProduct()) {
// Check if input ID is mapped to any input IDs of the view rfactor ID
auto root_inputs = InputsOf::outputs(fusion_, {view_rfactor_id});
auto filtered_root_ids = ir_utils::filterByType<IterDomain>(root_inputs);
for (auto view_root_id : filtered_root_ids) {
eraseIfMapped(in_concrete_ids, view_root_id);
}
} else {
// Otherwise, the input ID must map to the view rfactor ID
eraseIfMapped(in_concrete_ids, view_rfactor_id);
}
}
}
// Find any id in domain that maps with target id
IterDomain* DomainMap::anyMapped(
const std::vector<IterDomain*>& domain,
IterDomain* target) const {
for (auto id : domain) {
if (ca_map_.areMapped(id, target, IdMappingMode::EXACT)) {
return id;
}
}
return nullptr;
}
} // namespace pointwise_utils
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch
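A conceptual sketch of the set-erasure logic in `DomainMap::areAllInputIdsMappedToOutput` above: collect the input's concrete IDs, erase each one that an output ID maps to, and declare the input covered when the set is empty. Plain integers stand in for concrete IterDomains here; this is not nvfuser code.

```cpp
#include <iostream>
#include <unordered_set>
#include <vector>

// Plain ints stand in for concrete IterDomain pointers.
using ConcreteId = int;

// Returns true if every input id is "mapped" to some output id,
// mirroring DomainMap::areAllInputIdsMappedToOutput above.
bool allInputIdsMappedToOutput(
    const std::vector<ConcreteId>& input_ids,
    const std::vector<ConcreteId>& output_ids) {
  std::unordered_set<ConcreteId> remaining(input_ids.begin(), input_ids.end());
  for (auto out_id : output_ids) {
    remaining.erase(out_id);  // the erase-if-mapped step
  }
  return remaining.empty();
}

int main() {
  std::cout << allInputIdsMappedToOutput({0, 1}, {0, 1, 2}) << "\n";  // 1
  std::cout << allInputIdsMappedToOutput({0, 3}, {0, 1, 2}) << "\n";  // 0
  return 0;
}
```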

View File

@ -0,0 +1,68 @@
#pragma once
#include <torch/csrc/jit/codegen/cuda/compute_at_map.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
namespace pointwise_utils {
// DomainMap uses the ComputeAtMap to find a reference TensorView
// that maps to all IterDomains in the fusion.
class DomainMap {
public:
DomainMap(Fusion* fusion);
virtual ~DomainMap() = default;
bool areExactMapped(IterDomain* id1, IterDomain* id2);
const ComputeAtMap& getComputeAtMap() const {
return ca_map_;
}
protected:
// Determine if all iterDomains are mapped between input and output tvs
bool areAllInputIdsMappedToOutput(TensorView* input_tv, TensorView* output_tv)
const;
// Erase input concrete ID if it is mapped to output ID
bool eraseIfMapped(
std::unordered_set<IterDomain*>& in_concrete_ids,
IterDomain* out_id) const;
// Check if in_id is mapped to out_id through any view rfactor domain
void eraseIfInputMappedThroughViewToOutput(
std::unordered_set<IterDomain*>& in_concrete_ids,
IterDomain* out_id) const;
// Find any id in domain that maps with target id
IterDomain* anyMapped(
const std::vector<IterDomain*>& domain,
IterDomain* target) const;
Fusion* fusion_ = nullptr;
ComputeAtMap ca_map_;
std::vector<TensorView*> view_tvs_;
};
// Returns number of non-reduction/non-broadcast dims in rfactor domain
inline size_t nRootDims(const TensorView* tv) {
auto root_dom = tv->getMaybeRFactorDomain();
size_t tv_n_dims = 0;
for (auto dim : root_dom) {
if (!dim->isReduction() && !dim->isBroadcast()) {
tv_n_dims++;
}
}
return tv_n_dims;
}
} // namespace pointwise_utils
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch
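A toy version of the `nRootDims` helper above; the `Dim` struct is a hypothetical stand-in for IterDomain carrying only the two flags the function checks.

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Minimal stand-in for an IterDomain: only the flags nRootDims cares about.
struct Dim {
  bool is_reduction;
  bool is_broadcast;
};

// Count dims that are neither reductions nor broadcasts, as nRootDims does.
size_t nRootDims(const std::vector<Dim>& root_dom) {
  size_t n = 0;
  for (const auto& d : root_dom) {
    if (!d.is_reduction && !d.is_broadcast) {
      ++n;
    }
  }
  return n;
}

int main() {
  // Domains [I, B, R]: only the iteration dim counts.
  std::cout << nRootDims({{false, false}, {false, true}, {true, false}}) << "\n";  // 1
  return 0;
}
```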

View File

@ -40,11 +40,6 @@ int64_t roundDownPow2OrMultipleOf(const int64_t x, const int64_t multiple) {
return std::max(std::max(round_down_multiple, round_down_pow2), (int64_t)1);
}
// Div x by y, but min at 1
int64_t safeDiv(const int64_t x, const int64_t y) {
return std::max(x / y, (int64_t)1);
}
int64_t clamp(const int64_t val, const int64_t min_val, const int64_t max_val) {
return std::min(std::max(val, min_val), max_val);
}
@ -54,20 +49,20 @@ int64_t clamp(const int64_t val, const int64_t min_val, const int64_t max_val) {
void reduceProductTo(int64_t& z, int64_t& y, int64_t& x, const int64_t max) {
TORCH_INTERNAL_ASSERT(max > 1);
if (z * y * x > max) {
z = safeDiv(z, 2);
z = scheduler_utils::safeDiv(z, 2);
}
if (z * y * x > max) {
y = safeDiv(y, 2);
y = scheduler_utils::safeDiv(y, 2);
}
if (z * y * x > max) {
x = safeDiv(x, 2);
x = scheduler_utils::safeDiv(x, 2);
}
if (z * y * x > max) {
reduceProductTo(x, y, z, max);
}
}
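The removed local `safeDiv` now lives in `scheduler_utils`, and `reduceProductTo` above halves z, then y, then x (never below 1), recursing with the arguments reversed if the product is still too large. A standalone re-implementation of that logic, for illustration only:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>

// Div x by y, but clamp the result at 1 (mirrors scheduler_utils::safeDiv).
int64_t safeDiv(const int64_t x, const int64_t y) {
  return std::max(x / y, (int64_t)1);
}

// Halve z, then y, then x until z*y*x <= max, recursing with the order
// reversed if that is still not enough. Assumes max > 1, as asserted above.
void reduceProductTo(int64_t& z, int64_t& y, int64_t& x, const int64_t max) {
  if (z * y * x > max) {
    z = safeDiv(z, 2);
  }
  if (z * y * x > max) {
    y = safeDiv(y, 2);
  }
  if (z * y * x > max) {
    x = safeDiv(x, 2);
  }
  if (z * y * x > max) {
    reduceProductTo(x, y, z, max);
  }
}

int main() {
  int64_t z = 8, y = 8, x = 8;
  reduceProductTo(z, y, x, 64);
  std::cout << z << " " << y << " " << x << "\n";  // 4 4 4, product now 64
  return 0;
}
```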
ReductionParams innerReductionHeuristic(
std::shared_ptr<ReductionParams> innerReductionHeuristic(
const int64_t total_reduction_numel,
const int64_t total_iteration_numel,
const int64_t inner_most_dimension_numel,
@ -382,21 +377,21 @@ ReductionParams innerReductionHeuristic(
// require iterating over this entire function.
}
ReductionParams rparams;
rparams.fastest_dim = true;
rparams.cross_block_inner_reduction = true;
rparams.block_dim_inner_reduction = ParallelType::TIDx;
rparams.cross_grid_inner_reduction = gridim > 1;
rparams.multiple_reds_per_blk = bdimy > 1;
auto rparams = std::make_shared<ReductionParams>();
rparams->fastest_dim = true;
rparams->cross_block_inner_reduction = true;
rparams->block_dim_inner_reduction = ParallelType::TIDx;
rparams->cross_grid_inner_reduction = gridim > 1;
rparams->multiple_reds_per_blk = bdimy > 1;
bool pad_bdimx = bdimx > 16 &&
bdimx * bdimy <
(int64_t)at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
// If barely just covering reduction dim, don't pad to the next warp
pad_bdimx = pad_bdimx &&
bdimx * inner_reduction_unroll_factor != inner_most_dimension_numel;
rparams.pad_inner_reduction_to_warp = pad_bdimx;
rparams->pad_inner_reduction_to_warp = pad_bdimx;
if (rparams.pad_inner_reduction_to_warp) {
if (rparams->pad_inner_reduction_to_warp) {
// Adjust bdimx based on padding
auto min_warp_size =
(int64_t)at::cuda::getCurrentDeviceProperties()->warpSize;
@ -405,24 +400,24 @@ ReductionParams innerReductionHeuristic(
: bdimx + min_warp_size - bdimx % min_warp_size;
}
rparams.unroll_factor_inner_reduction = inner_reduction_unroll_factor;
rparams.vectorize_inner_reduction = vectorize;
rparams->unroll_factor_inner_reduction = inner_reduction_unroll_factor;
rparams->vectorize_inner_reduction = vectorize;
if (rparams.multiple_reds_per_blk) {
rparams.block_dim_iter_dom = ParallelType::TIDy;
if (rparams->multiple_reds_per_blk) {
rparams->block_dim_iter_dom = ParallelType::TIDy;
}
rparams.unroll_factor_iter_dom = iter_unroll_factor;
rparams->unroll_factor_iter_dom = iter_unroll_factor;
rparams.schedule_3D = total_reduction_numel != inner_most_dimension_numel;
rparams->schedule_3D = total_reduction_numel != inner_most_dimension_numel;
// Outer reduction domain
if (rparams.schedule_3D) {
rparams.cross_grid_outer_reduction = grodim > 1;
if (rparams->schedule_3D) {
rparams->cross_grid_outer_reduction = grodim > 1;
if (bdimz > 1) {
rparams.block_dim_outer_reduction = ParallelType::TIDz;
rparams.cross_block_outer_reduction = true;
rparams->block_dim_outer_reduction = ParallelType::TIDz;
rparams->cross_block_outer_reduction = true;
}
rparams.unroll_factor_outer_reduction = outer_reduction_unroll_factor;
rparams->unroll_factor_outer_reduction = outer_reduction_unroll_factor;
}
int64_t gdimx = LaunchParams::UNINITIALIZED_VAL;
@ -433,38 +428,38 @@ ReductionParams innerReductionHeuristic(
// gdimx assigned to grdim. Otherwise it's helpful to pull godim into gdimx in
// case it's larger than gdimy can hold, as not doing so can thrash the cache.
if (rparams.cross_grid_inner_reduction) {
rparams.grid_dim_inner_reduction = ParallelType::BIDx;
rparams.split_grid_dim_inner_reduction = true;
if (rparams->cross_grid_inner_reduction) {
rparams->grid_dim_inner_reduction = ParallelType::BIDx;
rparams->split_grid_dim_inner_reduction = true;
gdimx = std::min(gridim, scheduler_utils::x_grid_limit);
rparams.grid_dim_iter_dom = ParallelType::BIDy;
rparams->grid_dim_iter_dom = ParallelType::BIDy;
if (godim > scheduler_utils::y_grid_limit) {
rparams.split_grid_dim_iter_dom = true;
rparams->split_grid_dim_iter_dom = true;
gdimy = std::min(godim, scheduler_utils::y_grid_limit);
}
} else {
rparams.grid_dim_iter_dom = ParallelType::BIDx;
rparams->grid_dim_iter_dom = ParallelType::BIDx;
if (gdimx > scheduler_utils::x_grid_limit) {
rparams.split_grid_dim_iter_dom = true;
rparams->split_grid_dim_iter_dom = true;
gdimx = godim;
}
}
if (rparams.cross_grid_outer_reduction) {
if (rparams.cross_block_inner_reduction) {
rparams.grid_dim_outer_reduction = ParallelType::BIDz;
if (rparams->cross_grid_outer_reduction) {
if (rparams->cross_block_inner_reduction) {
rparams->grid_dim_outer_reduction = ParallelType::BIDz;
gdimz = std::min(grodim, scheduler_utils::z_grid_limit);
rparams.split_grid_dim_outer_reduction = true;
rparams->split_grid_dim_outer_reduction = true;
} else {
rparams.grid_dim_outer_reduction = ParallelType::BIDy;
rparams->grid_dim_outer_reduction = ParallelType::BIDy;
gdimy = std::min(grodim, scheduler_utils::y_grid_limit);
rparams.split_grid_dim_outer_reduction = true;
rparams->split_grid_dim_outer_reduction = true;
}
}
rparams.lparams = LaunchParams(
rparams->lparams = LaunchParams(
gdimx,
gdimy,
gdimz,
@ -483,20 +478,20 @@ ReductionParams innerReductionHeuristic(
<< "max_input_dtype_size: " << max_input_dtype_size << "\n"
<< "block(" << bdimx << ", " << bdimy << ", " << bdimz << ")"
<< std::endl;
std::cerr << rparams.toString() << std::endl;
std::cerr << rparams->toString() << std::endl;
}
// If 3d, check if it's supported by the scheduler, otherwise force 1D
// schedule
if (rparams.schedule_3D) {
if (rparams.multiple_reds_per_blk &&
(rparams.cross_grid_inner_reduction ||
rparams.cross_grid_outer_reduction)) {
if (rparams->schedule_3D) {
if (rparams->multiple_reds_per_blk &&
(rparams->cross_grid_inner_reduction ||
rparams->cross_grid_outer_reduction)) {
if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) {
std::cerr << "\n===== UNSUPPORTED REDUCTION HEURISTIC ========\n";
std::cerr << rparams.multiple_reds_per_blk << ", "
<< (rparams.unroll_factor_inner_reduction > 1) << ", "
<< rparams.cross_grid_inner_reduction << std::endl;
std::cerr << rparams->multiple_reds_per_blk << ", "
<< (rparams->unroll_factor_inner_reduction > 1) << ", "
<< rparams->cross_grid_inner_reduction << std::endl;
}
return innerReductionHeuristic(
total_reduction_numel,
@ -511,7 +506,7 @@ ReductionParams innerReductionHeuristic(
return rparams;
}
ReductionParams OuterReductionHeuristic(
std::shared_ptr<ReductionParams> outerReductionHeuristic(
const int64_t total_reduction_numel,
const int64_t total_iteration_numel,
const int64_t n_tensor_inputs,
@ -684,16 +679,17 @@ ReductionParams OuterReductionHeuristic(
bdimx = roundUpPow2OrMultipleOf(bdimx, 8);
// Fill bdimy with left over threads
bdimy =
std::min(safeDiv(target_threads_in_block, bdimx), total_reduction_numel);
bdimy = std::min(
scheduler_utils::safeDiv(target_threads_in_block, bdimx),
total_reduction_numel);
bdimy = roundDownPow2OrMultipleOf(bdimy, 8);
// Move parallelization into unrolling the reduction dimension if
// parallelizing iteration dimension didn't take the available unroll factor.
if (iter_unroll_factor < max_unroll && rDimAvail() > 2) {
inner_reduction_unroll_factor =
std::min(rDimAvail(), safeDiv(max_unroll, iter_unroll_factor));
inner_reduction_unroll_factor = std::min(
rDimAvail(), scheduler_utils::safeDiv(max_unroll, iter_unroll_factor));
inner_reduction_unroll_factor =
scheduler_utils::lastPow2(inner_reduction_unroll_factor);
@ -731,7 +727,8 @@ ReductionParams OuterReductionHeuristic(
// Empirically found stride shouldn't exceed 256kiB boundaries in a block
int64_t kMaxStride = 128 * 1024;
int64_t max_remainder_size = safeDiv(kMaxStride, bytes_stride_remainder);
int64_t max_remainder_size =
scheduler_utils::safeDiv(kMaxStride, bytes_stride_remainder);
int64_t grdim_for_stride = ceilDiv(
total_reduction_numel,
@ -771,13 +768,13 @@ ReductionParams OuterReductionHeuristic(
// Always disabled for now.
// bool flip_grid = gidim > 1 && gidim < 8;
const bool flip_grid = false;
ReductionParams rparams;
auto rparams = std::make_shared<ReductionParams>();
// cross grid implies cross block
rparams.cross_block_inner_reduction = bdimy > 1 || grdim > 1;
rparams.cross_grid_inner_reduction = grdim > 1;
if (rparams.cross_grid_inner_reduction) {
rparams.split_grid_dim_inner_reduction = true;
rparams.grid_dim_inner_reduction =
rparams->cross_block_inner_reduction = bdimy > 1 || grdim > 1;
rparams->cross_grid_inner_reduction = grdim > 1;
if (rparams->cross_grid_inner_reduction) {
rparams->split_grid_dim_inner_reduction = true;
rparams->grid_dim_inner_reduction =
flip_grid ? ParallelType::BIDx : ParallelType::BIDy;
if (flip_grid) {
gdimx = std::min(grdim, scheduler_utils::x_grid_limit);
@ -785,17 +782,17 @@ ReductionParams OuterReductionHeuristic(
gdimy = std::min(grdim, scheduler_utils::y_grid_limit);
}
}
rparams.multiple_reds_per_blk = bdimx > 1 || iter_unroll_factor > 1;
rparams->multiple_reds_per_blk = bdimx > 1 || iter_unroll_factor > 1;
if (rparams.multiple_reds_per_blk) {
rparams.block_dim_iter_dom = ParallelType::TIDx;
if (rparams->multiple_reds_per_blk) {
rparams->block_dim_iter_dom = ParallelType::TIDx;
}
rparams.grid_dim_iter_dom =
rparams->grid_dim_iter_dom =
flip_grid ? ParallelType::BIDy : ParallelType::BIDx;
if (gidim > (flip_grid ? scheduler_utils::y_grid_limit
: scheduler_utils::x_grid_limit)) {
rparams.split_grid_dim_iter_dom = true;
rparams->split_grid_dim_iter_dom = true;
if (flip_grid) {
gdimy = scheduler_utils::y_grid_limit;
} else {
@ -803,29 +800,29 @@ ReductionParams OuterReductionHeuristic(
}
}
rparams.flip_grid = flip_grid;
rparams->flip_grid = flip_grid;
if (rparams.cross_block_inner_reduction) {
if (rparams.block_dim_iter_dom == ParallelType::TIDx) {
rparams.block_dim_inner_reduction = ParallelType::TIDy;
if (rparams->cross_block_inner_reduction) {
if (rparams->block_dim_iter_dom == ParallelType::TIDx) {
rparams->block_dim_inner_reduction = ParallelType::TIDy;
} else {
rparams.block_dim_inner_reduction = ParallelType::TIDx;
rparams->block_dim_inner_reduction = ParallelType::TIDx;
}
}
rparams.unroll_factor_inner_reduction = inner_reduction_unroll_factor;
rparams->unroll_factor_inner_reduction = inner_reduction_unroll_factor;
rparams.unroll_factor_iter_dom = iter_unroll_factor;
rparams->unroll_factor_iter_dom = iter_unroll_factor;
if (iter_unroll_factor > 1) {
rparams.vectorize_iter_dom = vectorize;
rparams->vectorize_iter_dom = vectorize;
}
rparams.lparams = LaunchParams(
rparams->lparams = LaunchParams(
gdimx,
gdimy,
LaunchParams::UNINITIALIZED_VAL,
rparams.multiple_reds_per_blk ? bdimx : bdimy,
rparams.multiple_reds_per_blk ? bdimy : LaunchParams::UNINITIALIZED_VAL,
rparams->multiple_reds_per_blk ? bdimx : bdimy,
rparams->multiple_reds_per_blk ? bdimy : LaunchParams::UNINITIALIZED_VAL,
LaunchParams::UNINITIALIZED_VAL);
if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) {
@ -836,14 +833,14 @@ ReductionParams OuterReductionHeuristic(
<< "n_tensor_inputs: " << n_tensor_inputs << "\n"
<< "max_input_dtype_size: " << max_input_dtype_size << "\n"
<< "block(" << bdimx << ", " << bdimy << ", 1)" << std::endl;
std::cerr << rparams.toString() << std::endl;
std::cerr << rparams->toString() << std::endl;
}
return rparams;
}
} // namespace
ReductionParams reductionHeuristic(
std::shared_ptr<ReductionParams> reductionHeuristic(
const int64_t total_reduction_numel,
const int64_t total_iteration_numel,
const int64_t inner_most_dimension_numel,
@ -861,7 +858,7 @@ ReductionParams reductionHeuristic(
vectorize_factor);
} else {
// 3D schedules not enabled for outer reductions
return OuterReductionHeuristic(
return outerReductionHeuristic(
total_reduction_numel,
total_iteration_numel,
n_tensor_inputs,
@ -870,7 +867,7 @@ ReductionParams reductionHeuristic(
}
}
TORCH_CUDA_CU_API c10::optional<ReductionParams> getReductionHeuristics(
TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getReductionHeuristics(
Fusion* fusion,
const at::ArrayRef<c10::IValue>& runtime_inputs,
HeuristicSummary* data_cache) {
@ -881,7 +878,7 @@ TORCH_CUDA_CU_API c10::optional<ReductionParams> getReductionHeuristics(
return getReductionHeuristics(fusion, runtime_info, data_cache);
}
TORCH_CUDA_CU_API c10::optional<ReductionParams> getReductionHeuristics(
TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getReductionHeuristics(
Fusion* fusion,
SchedulerRuntimeInfo& runtime_info,
HeuristicSummary* data_cache) {
@ -911,8 +908,7 @@ TORCH_CUDA_CU_API c10::optional<ReductionParams> getReductionHeuristics(
TORCH_INTERNAL_ASSERT(
red_expr->getExprType() != c10::nullopt &&
(red_expr->getExprType().value() == ExprType::ReductionOp ||
red_expr->getExprType().value() == ExprType::WelfordOp),
ir_utils::isReductionOp(red_expr),
"TensorView doesn't have a reduction.");
auto properties =

View File

@ -13,12 +13,12 @@ namespace cuda {
class SchedulerRuntimeInfo;
class HeuristicSummary;
TORCH_CUDA_CU_API c10::optional<ReductionParams> getReductionHeuristics(
TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getReductionHeuristics(
Fusion* fusion,
const at::ArrayRef<c10::IValue>& runtime_inputs,
HeuristicSummary* data_cache = nullptr);
TORCH_CUDA_CU_API c10::optional<ReductionParams> getReductionHeuristics(
TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getReductionHeuristics(
Fusion* fusion,
SchedulerRuntimeInfo& runtime_info,
HeuristicSummary* data_cache = nullptr);

View File

@ -1,6 +1,6 @@
#pragma once
#include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/heuristic.h>
#include <sstream>
@ -9,11 +9,11 @@ namespace jit {
namespace fuser {
namespace cuda {
// Parameters the Reduction Heuristic Generates to describe the optimial
// schedule. Warning: equal operator is intended for use in caching the kernel
// associated with these reduction parameters. It does not check if the launch
// parameters are equivelent!
class ReductionParams {
// Parameters of the reduction heuristic to describe the optimal schedule.
// Warning: the equality operator is intended for use in caching the kernel
// associated with these reduction parameters. It does not check if the launch
// parameters are equivalent!
class ReductionParams : public HeuristicParams {
public:
// Reducing inner most dimension?
bool fastest_dim = false;
@ -100,18 +100,22 @@ class ReductionParams {
// parameters, not used for equivalence/hashing.
ParallelType grid_dim_outer_reduction = ParallelType::Serial;
std::string tag = "";
LaunchParams lparams;
bool isUnrolled() const {
return unroll_factor_inner_reduction > 1 || unroll_factor_iter_dom > 1 ||
unroll_factor_outer_reduction > 1;
}
public:
using HeuristicParams::HeuristicParams;
// Warning: Does not check launch parameters!
bool operator==(const ReductionParams& other) const {
bool sameAs(
const std::shared_ptr<HeuristicParams>& other_base) const override {
auto other_casted = std::dynamic_pointer_cast<ReductionParams>(other_base);
if (other_casted == nullptr) {
return false;
}
const ReductionParams& other = *other_casted;
bool attr_equal = other.fastest_dim == fastest_dim &&
other.persistent_kernel == persistent_kernel &&
other.project_persistent_buffers == project_persistent_buffers &&
@ -139,7 +143,7 @@ class ReductionParams {
return attr_equal;
}
std::string toString() const {
std::string toString() const override {
std::stringstream ss;
ss << "\n===== Reduction Parameters ========\n"
<< (tag == "" ? "" : "Tag: ") << tag << "\n"
@ -216,38 +220,37 @@ class ReductionParams {
ss << "====================================\n";
return ss.str();
}
};
// Warning: Hash is not based on launch parameters!
class ReductionParamsHash {
public:
size_t operator()(const ReductionParams& rp) const {
// Warning: Hash is not based on launch parameters!
size_t hash() const override {
constexpr size_t bits = sizeof(std::size_t) * 8;
size_t attr_hash = static_cast<size_t>(rp.fastest_dim) << (bits - 1) ^
static_cast<size_t>(rp.persistent_kernel) << (bits - 2) ^
static_cast<size_t>(rp.project_persistent_buffers) << (bits - 3) ^
static_cast<size_t>(rp.schedule_3D) << (bits - 4) ^
static_cast<size_t>(rp.flip_grid) << (bits - 5) ^
static_cast<size_t>(rp.cross_block_inner_reduction) << (bits - 6) ^
static_cast<size_t>(rp.cross_grid_inner_reduction) << (bits - 7) ^
static_cast<size_t>(rp.unroll_factor_inner_reduction) << (bits - 8) ^
static_cast<size_t>(rp.vectorize_inner_reduction) << (bits - 9) ^
static_cast<size_t>(rp.split_grid_dim_inner_reduction) << (bits - 10) ^
static_cast<size_t>(rp.pad_inner_reduction_to_warp) << (bits - 11) ^
static_cast<size_t>(rp.batches_per_block_inner_reduction)
<< (bits - 12) ^
static_cast<size_t>(rp.multiple_reds_per_blk) << (bits - 13) ^
static_cast<size_t>(rp.unroll_factor_iter_dom) << (bits - 14) ^
static_cast<size_t>(rp.vectorize_iter_dom) << (bits - 15) ^
static_cast<size_t>(rp.split_grid_dim_iter_dom) << (bits - 16) ^
static_cast<size_t>(rp.cross_block_outer_reduction) << (bits - 17) ^
static_cast<size_t>(rp.cross_grid_outer_reduction) << (bits - 18) ^
static_cast<size_t>(rp.split_grid_dim_outer_reduction) << (bits - 19) ^
static_cast<size_t>(rp.batches_per_block_outer_reduction)
<< (bits - 20) ^
static_cast<size_t>(rp.unroll_factor_outer_reduction) << (bits - 21);
size_t attr_hash = static_cast<size_t>(fastest_dim) << (bits - 1) ^
static_cast<size_t>(persistent_kernel) << (bits - 2) ^
static_cast<size_t>(project_persistent_buffers) << (bits - 3) ^
static_cast<size_t>(schedule_3D) << (bits - 4) ^
static_cast<size_t>(flip_grid) << (bits - 5) ^
static_cast<size_t>(cross_block_inner_reduction) << (bits - 6) ^
static_cast<size_t>(cross_grid_inner_reduction) << (bits - 7) ^
static_cast<size_t>(unroll_factor_inner_reduction) << (bits - 8) ^
static_cast<size_t>(vectorize_inner_reduction) << (bits - 9) ^
static_cast<size_t>(split_grid_dim_inner_reduction) << (bits - 10) ^
static_cast<size_t>(pad_inner_reduction_to_warp) << (bits - 11) ^
static_cast<size_t>(batches_per_block_inner_reduction) << (bits - 12) ^
static_cast<size_t>(multiple_reds_per_blk) << (bits - 13) ^
static_cast<size_t>(unroll_factor_iter_dom) << (bits - 14) ^
static_cast<size_t>(vectorize_iter_dom) << (bits - 15) ^
static_cast<size_t>(split_grid_dim_iter_dom) << (bits - 16) ^
static_cast<size_t>(cross_block_outer_reduction) << (bits - 17) ^
static_cast<size_t>(cross_grid_outer_reduction) << (bits - 18) ^
static_cast<size_t>(split_grid_dim_outer_reduction) << (bits - 19) ^
static_cast<size_t>(batches_per_block_outer_reduction) << (bits - 20) ^
static_cast<size_t>(unroll_factor_outer_reduction) << (bits - 21);
return attr_hash;
}
std::shared_ptr<HeuristicParams> clone() const override {
return std::make_shared<ReductionParams>(*this);
}
};
} // namespace cuda
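The `sameAs`/`hash`/`clone` trio above is what lets `SchedulerEntry` (later in this diff) hold a single `std::shared_ptr<HeuristicParams>` instead of separate reduction and pointwise members. Below is a minimal standalone sketch of that polymorphic pattern; the `Toy*` class names and fields are illustrative, not the nvfuser definitions.

```cpp
#include <cstddef>
#include <iostream>
#include <memory>

// Minimal stand-in for the HeuristicParams base interface.
struct ToyHeuristicParams {
  virtual ~ToyHeuristicParams() = default;
  virtual bool sameAs(const std::shared_ptr<ToyHeuristicParams>& other) const = 0;
  virtual size_t hash() const = 0;
  virtual std::shared_ptr<ToyHeuristicParams> clone() const = 0;
};

// One concrete parameter set; the dynamic_pointer_cast in sameAs rejects
// comparisons against a different parameter type, as ReductionParams does.
struct ToyReductionParams : ToyHeuristicParams {
  bool fastest_dim = false;
  int unroll_factor = 1;

  bool sameAs(const std::shared_ptr<ToyHeuristicParams>& other) const override {
    auto casted = std::dynamic_pointer_cast<ToyReductionParams>(other);
    if (casted == nullptr) {
      return false;
    }
    return casted->fastest_dim == fastest_dim &&
        casted->unroll_factor == unroll_factor;
  }
  size_t hash() const override {
    return static_cast<size_t>(fastest_dim) << 1 ^
        static_cast<size_t>(unroll_factor);
  }
  std::shared_ptr<ToyHeuristicParams> clone() const override {
    return std::make_shared<ToyReductionParams>(*this);
  }
};

int main() {
  auto a = std::make_shared<ToyReductionParams>();
  auto b = a->clone();
  std::cout << a->sameAs(b) << " " << (a->hash() == b->hash()) << "\n";  // 1 1
  return 0;
}
```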

View File

@ -1,8 +1,10 @@
#include <torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.h>
#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/inline_propagator.h>
#include <torch/csrc/jit/codegen/cuda/ir_cloner.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/maxinfo_propagator.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/registry.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
@ -219,13 +221,13 @@ void multiReductionInliner(
std::vector<TensorView*> reduction_tvs,
std::vector<TensorView*> cached_inputs,
std::vector<std::pair<TensorView*, TensorView*>> cached_outputs) {
// Propagate transformations before we rfactor the other reductions
TransformPropagator propagator(reference_tv);
MaxRootDomainInfoSpanningTree(reference_tv).traverse(&propagator);
// Apply rfactor to all reductions if applicable
std::vector<TensorView*> rfactor_tvs;
// If reduction_tv is rfactored, rfactor all reductions.
if (reference_tv != reduction_tv) {
// Apply rfactor to all reductions if applicable
std::vector<int> rfactor_axes;
for (const auto i : c10::irange(reference_tv->nDims())) {
if (reference_tv->axis((int)i)->isReduction() &&
@ -236,155 +238,86 @@ void multiReductionInliner(
for (auto reduction_tv_ : reduction_tvs) {
if (reduction_tv_ == reduction_tv) {
// The reduction tv
rfactor_tvs.push_back(reference_tv);
// This should come in already rfactored
continue;
} else {
rfactor_tvs.push_back(
ir_utils::rfactorHelper(reduction_tv_, rfactor_axes));
ir_utils::rfactorHelper(reduction_tv_, rfactor_axes);
}
}
TORCH_INTERNAL_ASSERT(
reduction_tvs.size() == rfactor_tvs.size(),
"Expected all reductions to contain rfactor.");
}
// Propagate parallelization
scheduler_utils::parallelizeAllLike(reference_tv, ir_utils::allTvs(fusion));
// Find iter domains that are mapped to a trivial reduction, these should
// never be inlined.
std::unordered_set<IterDomain*> mapped_to_trivial_reduction =
scheduler_utils::getTrivialReductionMap(fusion);
bool unroll = rparams.isUnrolled();
bool vectorize =
rparams.vectorize_inner_reduction || rparams.vectorize_iter_dom;
// Propagate parallelization except vectorization and unrolling
scheduler_utils::parallelizeAllLike(
reference_tv,
{},
allParallelTypesExcept(
{ParallelType::Unroll,
ParallelType::Vectorize,
ParallelType::MisalignedVectorize}));
if (unroll) {
// Inline Input caches to their consumers outside unswitched/vectorization
// position Inline consumers of input caches to rfactor tensors
// Mark which tensor views are actual input caches to leave vectorization on
// them
std::unordered_set<TensorView*> keep_unrolled;
std::vector<TensorView*> compute_from;
// Find all tensor views that should have unroll or vectorization
std::unordered_set<TensorView*> are_unrolled;
// Grab all tensor views that should be vectorized
auto vectorizable_inputs_outputs =
scheduler_utils::getInputsOutputsWithInnerDim(reference_tv, true);
// Inputs to cache
auto vectorizable_expr = [](Expr* e) {
return e->isA<UnaryOp>() &&
e->as<UnaryOp>()->getUnaryOpType() == UnaryOpType::Set;
};
for (auto cached_input : cached_inputs) {
auto consumers_of_input_cache = ir_utils::consumerTvsOf(cached_input);
for (auto consumer : consumers_of_input_cache) {
auto unswitch_it = std::find_if(
consumer->domain()->domain().begin(),
consumer->domain()->domain().end(),
[&mapped_to_trivial_reduction](IterDomain* id) {
return id->getParallelType() == ParallelType::Unswitch ||
id->getParallelType() == ParallelType::Unroll ||
id->getParallelType() == ParallelType::Vectorize ||
id->getParallelType() == ParallelType::MisalignedVectorize ||
mapped_to_trivial_reduction.count(id);
});
auto unswitch_pos = unswitch_it == consumer->domain()->domain().end()
? -1
: std::distance(consumer->domain()->domain().begin(), unswitch_it) +
1;
cached_input->computeAt(
consumer, unswitch_pos, ComputeAtMode::BestEffort);
compute_from.push_back(consumer);
if (vectorize) {
auto producer_tvs = ir_utils::producerTvsOf(cached_input);
if (producer_tvs.size() == 1 &&
std::find(
vectorizable_inputs_outputs.begin(),
vectorizable_inputs_outputs.end(),
producer_tvs[0]) != vectorizable_inputs_outputs.end()) {
keep_unrolled.emplace(cached_input);
}
} else {
keep_unrolled.emplace(cached_input);
if (vectorize) {
auto producer_tvs = ir_utils::producerTvsOf(cached_input);
if (producer_tvs.size() == 1 &&
vectorizable_expr(cached_input->definition()) &&
std::find(
vectorizable_inputs_outputs.begin(),
vectorizable_inputs_outputs.end(),
producer_tvs[0]) != vectorizable_inputs_outputs.end()) {
are_unrolled.emplace(cached_input);
}
} else {
are_unrolled.emplace(cached_input);
}
}
// Inline output caches into outputs
std::vector<TensorView*> compute_to;
for (auto cached_output_pair : cached_outputs) {
auto cached_output = cached_output_pair.first;
auto output = cached_output_pair.second;
if (vectorize) {
if (std::find(
if (vectorizable_expr(output->definition()) &&
std::find(
vectorizable_inputs_outputs.begin(),
vectorizable_inputs_outputs.end(),
output) != vectorizable_inputs_outputs.end()) {
keep_unrolled.emplace(output);
are_unrolled.emplace(output);
}
} else {
keep_unrolled.emplace(output);
}
// If an output has multiple consumers don't process compute at structure
// here, we want only terminating outputs
if (cached_output->uses().size() > 1) {
continue;
}
auto pos_it = std::find_if(
output->domain()->domain().begin(),
output->domain()->domain().end(),
[&mapped_to_trivial_reduction](IterDomain* id) {
return id->getParallelType() == ParallelType::Unswitch ||
id->getParallelType() == ParallelType::Unroll ||
id->getParallelType() == ParallelType::Vectorize ||
id->getParallelType() == ParallelType::MisalignedVectorize ||
mapped_to_trivial_reduction.count(id);
});
auto pos = pos_it == output->domain()->domain().end()
? -1
: std::distance(output->domain()->domain().begin(), pos_it) + 1;
cached_output->computeAt(output, pos, ComputeAtMode::BestEffort);
compute_to.push_back(cached_output);
}
{
// Add inputs to compute_at that weren't unrolled
auto processed_inputs = ir_utils::inputTvsOf(compute_from);
std::unordered_set<TensorView*> processed_inputs_set{
processed_inputs.begin(), processed_inputs.end()};
for (auto inp_tv : ir_utils::filterByType<TensorView>(fusion->inputs())) {
if (!processed_inputs_set.count(inp_tv)) {
compute_from.push_back(inp_tv);
}
}
auto processed_outputs = ir_utils::inputTvsOf(compute_to);
std::unordered_set<TensorView*> processed_outputs_set{
processed_outputs.begin(), processed_outputs.end()};
for (auto out_tv :
ir_utils::filterByType<TensorView>(fusion->outputs())) {
if (!processed_outputs_set.count(out_tv) && out_tv->uses().empty()) {
compute_to.push_back(out_tv);
}
are_unrolled.emplace(output);
}
}
// Before compute at-ing the internal structure, remove vectorization
// anywhere it doesn't belong. Otherwise it will mess up our inlining. Clear
// explicit unroll or vectorization when not for input or output GMEM
// transfers.
for (auto tv : ir_utils::allTvs(fusion)) {
if (!keep_unrolled.count(tv)) {
// Propagate vectorization/unrolling to those tensors that need it
scheduler_utils::parallelizeAllLike(
reference_tv,
-1,
{are_unrolled.begin(), are_unrolled.end()},
{ParallelType::Unroll,
ParallelType::Vectorize,
ParallelType::MisalignedVectorize});
std::vector<TensorView*> rfactor_and_reduction_tvs = {
reference_tv, reduction_tv};
// If reference shouldn't be unrolled, clear that parallel type.
for (auto tv : rfactor_and_reduction_tvs) {
if (are_unrolled.count(tv) == 0) {
for (const auto i : c10::irange(tv->nDims())) {
auto id = tv->axis((int)i);
if (id->getParallelType() == ParallelType::Unroll ||
@ -395,146 +328,22 @@ void multiReductionInliner(
}
}
}
// Make sure not to completely inline if there's trivial reductions in the
// fusion
auto pos_it = std::find_if(
reference_tv->domain()->domain().begin(),
reference_tv->domain()->domain().end(),
[&mapped_to_trivial_reduction](IterDomain* id) {
return mapped_to_trivial_reduction.count(id);
});
auto pos = pos_it == reference_tv->domain()->domain().end()
? -1
: std::distance(reference_tv->domain()->domain().begin(), pos_it) + 1;
// Compute at inputs to rfactor dimensions
scheduler_utils::computeAtBetween(
compute_from, rfactor_tvs, pos, ComputeAtMode::MostInlined);
// Inline rfactor into reduction
if (reference_tv != reduction_tv) {
// Compute at rfactor into following reduction, keep outside first
// reduction iter domain in the rfactor tensor view
for (const auto i : c10::irange(rfactor_tvs.size())) {
if (rparams.unroll_factor_iter_dom > 1) {
auto rfactor_tv = rfactor_tvs[i];
auto rfactor_tv_dom = rfactor_tv->domain()->domain();
auto reduction_it = std::find_if(
rfactor_tv_dom.begin(), rfactor_tv_dom.end(), [](IterDomain* id) {
return id->isReduction();
});
TORCH_INTERNAL_ASSERT(
reduction_it != rfactor_tv_dom.end(),
"Expected reduction axis in ",
rfactor_tv);
auto pos = std::distance(rfactor_tv_dom.begin(), reduction_it);
// I would like computeAtMode here to be Standard. However, the
// processing of welford rfactors in compute at ends up propagating
// compute at from reduction_tv->rfactor_tv to all outputs.
rfactor_tv->computeWith(
reduction_tvs[i], pos, ComputeAtMode::BestEffort);
} else {
rfactor_tvs[i]->computeWith(
reduction_tvs[i], -1, ComputeAtMode::BestEffort);
}
}
}
// Remove anything before a reduction from compute_from
{
auto producers_of_reductions = DependencyCheck::getAllValsBetween(
{fusion->inputs().begin(), fusion->inputs().end()},
{reduction_tvs.begin(), reduction_tvs.end()});
auto producer_tvs_of_reductions =
ir_utils::filterByType<TensorView>(producers_of_reductions);
compute_from.erase(
std::remove_if(
compute_from.begin(),
compute_from.end(),
[&producer_tvs_of_reductions](TensorView* compute_from_tv) {
return std::find(
producer_tvs_of_reductions.begin(),
producer_tvs_of_reductions.end(),
compute_from_tv) != producer_tvs_of_reductions.end();
}),
compute_from.end());
}
// Add reduction tensor views to compute from
compute_from.insert(
compute_from.end(), reduction_tvs.begin(), reduction_tvs.end());
// Compute between reductions and output caches
scheduler_utils::computeAtBetween(
compute_from,
compute_to,
-1,
ComputeAtMode::BestEffort,
mapped_to_trivial_reduction);
} else {
// Want to inline, especially backwards based on reduction_tv, otherwise
// rfactor tv may not be inlined correctly
auto ref_tvs = rfactor_tvs.size() ? rfactor_tvs : reduction_tvs;
for (auto red_tv : ref_tvs) {
auto pos_it = std::find_if(
red_tv->domain()->domain().begin(),
red_tv->domain()->domain().end(),
[&mapped_to_trivial_reduction](IterDomain* id) {
return id->getParallelType() == ParallelType::Unswitch ||
id->getParallelType() == ParallelType::Unroll ||
id->getParallelType() == ParallelType::Vectorize ||
id->getParallelType() == ParallelType::MisalignedVectorize ||
mapped_to_trivial_reduction.count(id);
});
auto pos = pos_it == red_tv->domain()->domain().end()
? -1
: std::distance(red_tv->domain()->domain().begin(), pos_it) + 1;
scheduler_utils::computeAtInputs(red_tv, pos, ComputeAtMode::MostInlined);
scheduler_utils::computeWithOutputs(
red_tv, pos, ComputeAtMode::BestEffort);
}
// For topologies where there may not be paths to all inputs/outputs from
// the reductions, we need to take a similar approach to the unrolled
// version and setup of compute at from inputs->outputs avoiding going
// through the reduction expressions. This can be done by grabbing inputs
// not on path to a reduction, and computeAt-ing with all outputs. This
// doesn't guarantee we don't go through a reduction, but with best effort
// it should minimize damage if it does.
std::vector<TensorView*> compute_to;
for (auto out : ir_utils::filterByType<TensorView>(fusion->outputs())) {
// only terminating outputs
if (out->uses().size() || out->isFusionInput()) {
continue;
}
compute_to.push_back(out);
}
std::vector<TensorView*> compute_from;
std::unordered_set<TensorView*> inps_of_reds;
{
auto inps_of_red_vec = ir_utils::inputTvsOf(ref_tvs);
inps_of_reds = std::unordered_set<TensorView*>(
inps_of_red_vec.begin(), inps_of_red_vec.end());
}
for (auto inp : ir_utils::filterByType<TensorView>(fusion->inputs())) {
if (inps_of_reds.find(inp) != inps_of_reds.end()) {
continue;
}
compute_from.push_back(inp);
}
scheduler_utils::computeAtBetween(
compute_from,
compute_to,
-1,
ComputeAtMode::BestEffort,
mapped_to_trivial_reduction);
}
// Find iter domains that are mapped to a trivial reduction, these should
// never be inlined.
std::unordered_set<IterDomain*> mapped_to_trivial_reduction =
scheduler_utils::getTrivialReductionMap(fusion);
// Inline the schedule
InlinePropagator inline_propagator(
reference_tv,
-1,
ComputeAtMode::MostInlined,
{},
mapped_to_trivial_reduction);
MaxRootDomainInfoSpanningTree(reference_tv).traverse(&inline_propagator);
}
namespace {

View File

@ -22,7 +22,7 @@ TensorView* scheduleReductionTV(
bool has_iter_axis);
// Inlining function intended for single or multi reduction fusions.
void multiReductionInliner(
TORCH_CUDA_CU_API void multiReductionInliner(
Fusion* fusion,
const ReductionParams& rparams,
TensorView* reduction_tv,
@ -41,7 +41,7 @@ void multiReductionInliner(
// Rfactored axes are reductions bound to grid or blocks. If no axes are bound
// to a grid or block dimension it will rfactor the r-unswitch dimension.
// Reduction inliner expects an rfactored domain.
TensorView* sortAndRFactor(TensorView* reference_tv);
TORCH_CUDA_CU_API TensorView* sortAndRFactor(TensorView* reference_tv);
// Take all projectable persistent buffers, and move them to the inputs.
TORCH_CUDA_CU_API void projectPersistentBuffers(Fusion* fusion);

View File

@ -239,8 +239,8 @@ class SchedulerTopologyChecker {
static bool hasPostReductionBCast(Fusion* fusion) {
auto all_vals = fusion->usedMathVals();
for (auto tv : ir_utils::filterByType<TensorView>(all_vals)) {
// Welford can have 2 outputs, so do this on all found reduction tensor
// views
// Reductions can have multiple outputs, so do this on all found reduction
// tensor views
if (tv->hasReduction() && !tv->isFusionInput()) {
auto tv_chains = tvChains(DependencyCheck::getAllUseChains(tv));
// Propagate forward from reduction through all uses of the reduction
@ -301,18 +301,17 @@ class SchedulerTopologyChecker {
// When checking post reduction vals, we need to make sure
// we are really checking paths starting from all outputs
// of multi-output reductions, i.e. welford. The reduction_tv
// vector is assumed to only have one of them.
// of multi-output reductions, i.e. welford/grouped reduction. The
// reduction_tv vector is assumed to only have one of them.
std::unordered_set<Val*> reduction_tv_set(
reduction_tvs.begin(), reduction_tvs.end());
for (auto red : reduction_tvs) {
if (red->definition()) {
if (auto wop = dynamic_cast<WelfordOp*>(red->definition())) {
for (auto wop_output : wop->outputs()) {
if (wop_output->isA<TensorView>()) {
reduction_tv_set.insert(wop_output);
}
if (ir_utils::isReductionOp(red->definition())) {
auto outs = red->definition()->outputs();
for (auto out_tv : ir_utils::filterByType<TensorView>(outs)) {
reduction_tv_set.insert(out_tv);
}
}
}
@ -721,18 +720,7 @@ bool SchedulerEntry::sameAs(const SchedulerEntry* other) {
if (index_mode_ != other->index_mode_) {
return false;
}
// Heuristic equal should imply has_reduction_param_ equal,
// need to double check if it is the case before removing
// the below one.
if (has_reduction_param_ != other->has_reduction_param_) {
return false;
}
if (has_reduction_param_) {
return rparams_ == other->rparams_;
} else {
return pparams_ == other->pparams_;
}
return true;
return params_->sameAs(other->params_);
}
namespace {
@ -834,7 +822,7 @@ class ReductionScheduler : public SchedulerEntry {
Fusion* fusion,
SchedulerRuntimeInfo& runtime_info,
HeuristicSummary* data_cache = nullptr)
: SchedulerEntry(ScheduleHeuristic::Reduction, true) {
: SchedulerEntry(ScheduleHeuristic::Reduction) {
computeHeuristics(fusion, runtime_info, data_cache);
}
@ -964,7 +952,7 @@ class ReductionScheduler : public SchedulerEntry {
void schedule(Fusion* fusion) override {
FUSER_PERF_SCOPE("Schedule Single Reduction");
scheduleReduction(fusion, rparams());
scheduleReduction(fusion, reductionParams());
}
private:
@ -972,9 +960,8 @@ class ReductionScheduler : public SchedulerEntry {
Fusion* fusion,
SchedulerRuntimeInfo& runtime_info,
HeuristicSummary* data_cache = nullptr) {
auto param = getReductionHeuristics(fusion, runtime_info, data_cache);
TORCH_INTERNAL_ASSERT(param.has_value());
rparams() = param.value();
params_ = getReductionHeuristics(fusion, runtime_info, data_cache);
TORCH_INTERNAL_ASSERT(params_ != nullptr);
}
};
@ -984,7 +971,7 @@ class PointWiseScheduler : public SchedulerEntry {
Fusion* fusion,
SchedulerRuntimeInfo& runtime_info,
HeuristicSummary* data_cache = nullptr)
: SchedulerEntry(ScheduleHeuristic::PointWise, false) {
: SchedulerEntry(ScheduleHeuristic::PointWise) {
computeHeuristics(fusion, runtime_info, data_cache);
}
@ -1000,9 +987,8 @@ class PointWiseScheduler : public SchedulerEntry {
auto reduction_ops =
ir_utils::getReductionOps(fusion, true /* ignore_trivial */);
auto welford_ops = ir_utils::filterByType<WelfordOp>(reduction_ops);
if (!reduction_ops.empty() || !welford_ops.empty()) {
if (!reduction_ops.empty()) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::PointWise, "no support for reduction ops");
return false;
@ -1027,16 +1013,15 @@ class PointWiseScheduler : public SchedulerEntry {
void schedule(Fusion* fusion) override {
FUSER_PERF_SCOPE("Schedule PointWise Fusion");
schedulePointwise(fusion, pparams());
schedulePointwise(fusion, pointwiseParams());
}
void computeHeuristics(
Fusion* fusion,
SchedulerRuntimeInfo& runtime_info,
HeuristicSummary* data_cache = nullptr) {
auto pparam = getPointwiseHeuristics(fusion, runtime_info, data_cache);
TORCH_INTERNAL_ASSERT(pparam.has_value());
pparams() = pparam.value();
params_ = getPointwiseHeuristics(fusion, runtime_info, data_cache);
TORCH_INTERNAL_ASSERT(params_ != nullptr);
}
};
@ -1046,13 +1031,13 @@ class PersistentKernelScheduler : public SchedulerEntry {
Fusion* fusion,
SchedulerRuntimeInfo& runtime_info,
HeuristicSummary* data_cache = nullptr)
: SchedulerEntry(ScheduleHeuristic::Persistent, true) {
: SchedulerEntry(ScheduleHeuristic::Persistent) {
computeHeuristics(fusion, runtime_info, data_cache);
}
void schedule(Fusion* fusion) override {
FUSER_PERF_SCOPE("Schedule Persistent Fusion");
schedulePersistentKernel(fusion, rparams());
schedulePersistentKernel(fusion, reductionParams());
}
static bool canScheduleCompileTime(Fusion* fusion) {
@ -1065,15 +1050,6 @@ class PersistentKernelScheduler : public SchedulerEntry {
auto reduction_ops =
ir_utils::getReductionOps(fusion, false /* ignore_trivial */);
auto welford_ops = ir_utils::filterByType<WelfordOp>(reduction_ops);
// For persistent schedule we want welford translated to average and
// standard deviation reductions.
if (welford_ops.begin() != welford_ops.end()) {
scheduler_debug_utils::canScheduleRejectReason(
ScheduleHeuristic::Persistent,
"no support for un-translated welford");
return false;
}
auto view_tvs = scheduler_utils::getViewTVs(fusion);
if (view_tvs.size() > 0) {
@ -1263,9 +1239,8 @@ class PersistentKernelScheduler : public SchedulerEntry {
Fusion* fusion,
SchedulerRuntimeInfo& runtime_info,
HeuristicSummary* data_cache = nullptr) {
auto param = getPersistentHeuristics(fusion, runtime_info, data_cache);
TORCH_INTERNAL_ASSERT(param.has_value());
rparams() = param.value();
params_ = getPersistentHeuristics(fusion, runtime_info, data_cache);
TORCH_INTERNAL_ASSERT(params_ != nullptr);
}
};
@ -1367,11 +1342,7 @@ c10::optional<ScheduleHeuristic> SchedulerEntry::proposeHeuristics(
}
size_t SchedulerEntryHash::operator()(const SchedulerEntry& se) const {
if (se.hasReductionParam()) {
return ReductionParamsHash()(se.reductionParams());
} else {
return PointwiseParamsHash()(se.pointwiseParams());
}
return se.params()->hash();
}
std::string toString(ScheduleHeuristic sh) {
@ -1444,6 +1415,9 @@ HeuristicSummary::HeuristicSummary(
void HeuristicSummary::validate() const {
switch (heuristic_) {
case ScheduleHeuristic::PointWise: {
TORCH_INTERNAL_ASSERT(entry_type_map_.count(EntryType::DOMAIN_MAP));
TORCH_INTERNAL_ASSERT(
entry_type_map_.count(EntryType::REFERENCE_TENSORS));
TORCH_INTERNAL_ASSERT(
entry_type_map_.count(EntryType::VECTORIZABLE_INPUTS_AND_OUTPUTS));
TORCH_INTERNAL_ASSERT(
@ -1512,6 +1486,8 @@ HeuristicSummaryEntry<EntryClass>::HeuristicSummaryEntry(
}
// Template instantiation for pre-defined cache entries
template class HeuristicSummaryEntry<HeuristicCompileTime::DomainMap>;
template class HeuristicSummaryEntry<HeuristicCompileTime::ReferenceTensors>;
template class HeuristicSummaryEntry<
HeuristicCompileTime::VectorizableInputsAndOutputs>;
template class HeuristicSummaryEntry<

View File

@ -2,6 +2,9 @@
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/compile_time_info.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/heuristic.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/pointwise_heuristic.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
#include <torch/csrc/jit/codegen/cuda/utils.h>
@ -158,11 +161,7 @@ class TORCH_CUDA_CU_API SchedulerEntry {
//! Heuristic comparison
bool sameAs(const SchedulerEntry* other);
bool hasReductionParam() const {
return has_reduction_param_;
}
ScheduleHeuristic heuristc() const {
ScheduleHeuristic heuristic() const {
return heuristc_;
}
@ -170,51 +169,38 @@ class TORCH_CUDA_CU_API SchedulerEntry {
return index_mode_;
}
const std::shared_ptr<HeuristicParams>& params() const {
return params_;
}
const ReductionParams& reductionParams() const {
auto rparams = std::dynamic_pointer_cast<ReductionParams>(params_);
TORCH_INTERNAL_ASSERT(
has_reduction_param_, "This schedule heuristic is not reduction.");
return rparams_;
rparams != nullptr, "Heuristic parameter is not a reduction parameter");
return *rparams;
}
const PointwiseParams& pointwiseParams() const {
auto pparams = std::dynamic_pointer_cast<PointwiseParams>(params_);
TORCH_INTERNAL_ASSERT(
!has_reduction_param_, "This schedule heuristic is not pointwise.");
return pparams_;
pparams != nullptr, "Heuristic parameter is not a pointwise parameter");
return *pparams;
}
void updateLaunchConstraint(const LaunchParams& launch_params) {
if (hasReductionParam()) {
rparams_.lparams = launch_params;
} else {
pparams_.lparams = launch_params;
}
params_->lparams = launch_params;
}
protected:
explicit SchedulerEntry(ScheduleHeuristic heuristic, bool has_reduction_param)
: heuristc_(heuristic), has_reduction_param_(has_reduction_param) {}
explicit SchedulerEntry(ScheduleHeuristic heuristic) : heuristc_(heuristic) {}
ReductionParams& rparams() {
return rparams_;
}
PointwiseParams& pparams() {
return pparams_;
}
//! Heuristic parameters if applicable
std::shared_ptr<HeuristicParams> params_ = nullptr;
private:
//! What kind of heuristics does this entry have?
const ScheduleHeuristic heuristc_;
//! Has reduction params if true, else has pointwise params
const bool has_reduction_param_;
//! Reduction parameters if applicable
ReductionParams rparams_;
//! Pointwise parameters if applicable
PointwiseParams pparams_;
//! Kernel Index Mode
KernelIndexMode index_mode_ = KernelIndexMode::INT64;
};
@ -226,10 +212,12 @@ class TORCH_CUDA_CU_API SchedulerEntryHash {
};
//! Debug print function for heuristics
std::string toString(ScheduleHeuristic sh);
TORCH_CUDA_CU_API std::string toString(ScheduleHeuristic sh);
//! Debug print function for heuristics
std::ostream& operator<<(std::ostream& os, ScheduleHeuristic sh);
TORCH_CUDA_CU_API std::ostream& operator<<(
std::ostream& os,
ScheduleHeuristic sh);
} // namespace cuda
} // namespace fuser

View File

@ -188,30 +188,53 @@ size_t mergeNonReduction(
void parallelizeAllLike(
TensorView* reference_tv,
const std::vector<TensorView*>& all_tvs) {
int64_t pos,
std::vector<TensorView*> selected_tvs,
const std::unordered_set<ParallelType>& selected_parallel_types,
bool propagate_padding) {
FusionGuard fg(reference_tv->fusion());
if (pos < 0) {
pos += reference_tv->nDims() + 1;
}
TORCH_CHECK(
pos >= 0 && pos <= reference_tv->nDims(),
"parallelizeAllLike called on an position outside valid range.");
std::unordered_map<IterDomain*, IterDomain*> concrete_to_reference_map;
auto ca_map = ComputeAtMap(FusionGuard::getCurFusion());
for (auto id : reference_tv->domain()->domain()) {
ca_map.getConcreteMappedID(id, IdMappingMode::PERMISSIVE)
->parallelize(id->getParallelType());
if (id->hasPaddingToMultipleOfWarp()) {
ca_map.getConcreteMappedID(id, IdMappingMode::PERMISSIVE)
->padToMultipleOfWarp(id->getMaybeSizeAfterPadding());
}
const auto& reference_dom = reference_tv->domain()->domain();
for (auto it = reference_dom.begin(); it != reference_dom.begin() + pos;
it++) {
auto ca_id = ca_map.getConcreteMappedID(*it, IdMappingMode::PERMISSIVE);
concrete_to_reference_map[ca_id] = *it;
}
for (auto tv : all_tvs) {
if (selected_tvs.empty()) {
selected_tvs = ir_utils::allTvs(reference_tv->fusion());
}
for (auto tv : selected_tvs) {
if (tv->isFusionInput()) {
continue;
}
for (const auto i : c10::irange(tv->domain()->domain().size())) {
auto ca_id =
ca_map.getConcreteMappedID(tv->axis(i), IdMappingMode::PERMISSIVE);
tv->axis(i)->parallelize(ca_id->getParallelType());
if (ca_id->hasPaddingToMultipleOfWarp()) {
tv->axis(i)->padToMultipleOfWarp(ca_id->getMaybeSizeAfterPadding());
if (concrete_to_reference_map.count(ca_id) > 0) {
auto reference_id = concrete_to_reference_map.at(ca_id);
auto reference_parallel_type = reference_id->getParallelType();
if (selected_parallel_types.empty() ||
selected_parallel_types.count(reference_parallel_type)) {
tv->axis(i)->parallelize(reference_parallel_type);
}
if (propagate_padding) {
if (reference_id->hasPaddingToMultipleOfWarp()) {
tv->axis(i)->padToMultipleOfWarp(
reference_id->getMaybeSizeAfterPadding());
}
}
}
}
}
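The rewritten `parallelizeAllLike` above copies a reference axis's parallel type onto a mapped axis only when that type is in `selected_parallel_types`, with an empty set meaning "all types". The standalone sketch below isolates just that filtering rule with a toy parallel-type enum; it is not the nvfuser implementation.

```cpp
#include <iostream>
#include <unordered_set>

enum class ToyParallelType { Serial, TIDx, TIDy, Vectorize, Unroll };

// Copy a parallel type from the reference only when it is selected
// (an empty selection means "propagate every type"), mirroring the
// filtering done in parallelizeAllLike above.
ToyParallelType maybePropagate(
    ToyParallelType reference_type,
    ToyParallelType current_type,
    const std::unordered_set<ToyParallelType>& selected_types) {
  if (selected_types.empty() || selected_types.count(reference_type)) {
    return reference_type;
  }
  return current_type;
}

int main() {
  std::unordered_set<ToyParallelType> all_but_vectorize = {
      ToyParallelType::TIDx, ToyParallelType::TIDy, ToyParallelType::Unroll};
  // Vectorize is filtered out, TIDx is propagated.
  std::cout
      << (maybePropagate(ToyParallelType::Vectorize, ToyParallelType::Serial,
                         all_but_vectorize) == ToyParallelType::Serial)
      << " "
      << (maybePropagate(ToyParallelType::TIDx, ToyParallelType::Serial,
                         all_but_vectorize) == ToyParallelType::TIDx)
      << "\n";  // 1 1
  return 0;
}
```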
@ -1607,7 +1630,8 @@ void scheduleWarpTileWithNoReduction(TensorView* tv, MatMulTileOptions tile) {
void scheduleContiguousVectorLoad(
TensorView* tv,
MatMulTileOptions tile,
int vector_word) {
int vector_word,
bool vectorize) {
auto warp_dims = tile.cta_tile / tile.warp_tile;
int num_of_thread = warp_dims.m * warp_dims.n * warp_dims.k * 32;
@ -1630,14 +1654,423 @@ void scheduleContiguousVectorLoad(
tv->split(-3, warp_dims.k);
}
tv->axis(-1)->parallelize(ParallelType::Vectorize);
if (vectorize) {
tv->axis(-1)->parallelize(ParallelType::Vectorize);
}
tv->axis(-2)->parallelize(ParallelType::TIDx);
tv->axis(-3)->parallelize(ParallelType::TIDy);
tv->axis(-4)->parallelize(ParallelType::TIDz);
}
void makeTile(TensorView* tv, std::vector<int> tile_sizes) {
TORCH_CHECK(
tv->domain()->domain().size() >= tile_sizes.size(),
"Tensor dimension less than tile dimension!");
// Number of inner dimensions we are tiling.
const auto tile_dimension_size = tile_sizes.size();
// Split the inner dimensions:
for (auto idx : c10::irange(tile_dimension_size)) {
// Using negative indexing to accommodate potential batching
// dimensions further to the left. E.g.:
// 0, 1, 2 -> -3,-2,-1
// [M, N, K] -> [B0, B1, M, N, K]
tv->split(idx - tile_dimension_size, tile_sizes.at(idx));
}
// The transformation that happened should look like:
// Before After
// [..., M, N, K] -> [..., Mo, Mi, No, Ni, Ko, Ki]
// Re-order the tiles so that all the outer tiles are
// on the left of all the inner tiles
std::unordered_map<int, int> reorder_map_old_to_new;
// Number of tiled inner dimensions after we split.
const auto split_tile_dimension_size = 2 * tile_dimension_size;
for (auto idx : c10::irange(split_tile_dimension_size)) {
// We want to reorder as follows:
// Before
//
// [..., Mo, Mi, No, Ni, Ko, Ki] ->
// After
// vvv group0 vvv vvv group1 vvv
// [..., Mo, No, Ko, Mi, Ni, Ki]
// The index offset within group of current
// iterdomain, with grouping specified above.
auto index_within_group = idx / 2;
// The index of the group the current id belongs
// to, as specified above.
auto group_index = idx % 2;
// Calculate the actual index after reordering
auto index_after_reorder =
group_index * tile_dimension_size + index_within_group;
// Add pair {idx_before, idx_after} to re-order map.
reorder_map_old_to_new.insert(std::make_pair(
idx - split_tile_dimension_size,
index_after_reorder - split_tile_dimension_size));
}
// Apply the re-order map to tensor
tv->reorder(reorder_map_old_to_new);
}
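To make the `idx / 2` and `idx % 2` regrouping in `makeTile` above concrete, this standalone snippet computes the reorder map for three tiled dimensions, reproducing the [..., Mo, Mi, No, Ni, Ko, Ki] -> [..., Mo, No, Ko, Mi, Ni, Ki] reshuffle described in the comments.

```cpp
#include <iostream>
#include <unordered_map>

int main() {
  const int tile_dimension_size = 3;  // e.g. tiling M, N and K
  const int split_size = 2 * tile_dimension_size;
  std::unordered_map<int, int> reorder_map_old_to_new;
  for (int idx = 0; idx < split_size; ++idx) {
    // Outer/inner tiles alternate after the split: even positions are the
    // outer tiles (group 0), odd positions the inner tiles (group 1).
    int index_within_group = idx / 2;
    int group_index = idx % 2;
    int index_after_reorder =
        group_index * tile_dimension_size + index_within_group;
    // Same negative indexing as makeTile, counted from the right.
    reorder_map_old_to_new[idx - split_size] = index_after_reorder - split_size;
  }
  // Prints -6->-6, -5->-3, -4->-5, -3->-2, -2->-4, -1->-1:
  // i.e. [Mo, Mi, No, Ni, Ko, Ki] -> [Mo, No, Ko, Mi, Ni, Ki].
  for (int old_pos = -split_size; old_pos < 0; ++old_pos) {
    std::cout << old_pos << " -> " << reorder_map_old_to_new[old_pos] << "\n";
  }
  return 0;
}
```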
namespace {
c10::optional<IterDomain*> getMaybeRootIfInnermostTiled(
IterDomain* id,
const std::unordered_set<IterDomain*>& maybe_rfactor_id_set) {
// Root id defaults to an "innermost id".
while (id->definition() && !maybe_rfactor_id_set.count(id)) {
if (auto split = dynamic_cast<Split*>(id->definition())) {
if (id == split->inner()) {
id = split->in();
continue;
}
}
// Didn't pass the inner most check, return empty.
return c10::nullopt;
}
return id;
}
} // namespace
TORCH_CUDA_CU_API void orderTiledConcreteIdAsRoot(TensorView* tv) {
auto ndims = tv->nDims();
// Keep track of the left most position where we will
// be reordering the axes.
auto leftmost_pos = ndims;
// Pull the root id's of the given tv.
std::unordered_set<IterDomain*> maybe_rfactor_id_set{
tv->getMaybeRFactorDomain().begin(), tv->getMaybeRFactorDomain().end()};
// Keep track of leaf positions that is either a reduction
// or a broadcast.
// Note: We currently don't see a case where this function should be
// called on a reduction output tv, but they are handled here for
// completeness.
std::deque<int> broadcast_or_reduction_pos;
// Map the root id's to their innermost concrete id's
// on the leaf.
std::unordered_map<IterDomain*, int> root_id_to_inner_leaf_pos;
// Try to re-order inner iterdomains from the innermost
// position backward. This utility only tries to re-order
// inner tiles on the innermost positions, like the resulting
// tensor from makeTile utility.
// The re-ordering would first try to decide the inner iterdomains
// we want to re-order. For this we start from the innermost position
// and move back and collect all the iterdomains that we know
// are inner tiles of some root domain or broadcast/reduction domains
// that won't affect the concrete id layout.
// The collection process would stop whenever an iterdomain that is
// neither an inner tile nor reduction/broadcast is found, and would
// not re-order any iterdomain beyond that point to keep the
// outer loop structure unchanged.
for (int64_t i = static_cast<int64_t>(ndims) - 1; i >= 0; i--) {
auto leaf_id = tv->axis(i);
if (leaf_id->isBroadcast() || leaf_id->isReduction()) {
// Register this reduction or broadcast axis
// to reorder.
broadcast_or_reduction_pos.push_front(i);
leftmost_pos = i;
continue;
}
auto maybe_root =
getMaybeRootIfInnermostTiled(leaf_id, maybe_rfactor_id_set);
if (maybe_root.has_value()) {
// Found an innermost id, add them to the
// axes to reorder.
TORCH_INTERNAL_ASSERT(
root_id_to_inner_leaf_pos
.insert(std::make_pair(maybe_root.value(), i))
.second,
"Multiple \"innermost\" id seen for root id :",
maybe_root.value()->toString(),
" on ",
tv->toString(),
" very likely an invariant is broken.");
leftmost_pos = i;
} else {
break;
}
}
// Calculate the ordering:
// pointer to the current target position after
// reordering
int current_pos = leftmost_pos;
std::unordered_map<int, int> reorder_map_old_to_new;
// first place all the broadcast and reduction on the left:
for (auto original_broadcast_or_reduction_pos : broadcast_or_reduction_pos) {
reorder_map_old_to_new[original_broadcast_or_reduction_pos] = current_pos++;
}
// Next put all the innermost leaf id's, we make sure that
// the inner tile ordering follows the corresponding root
// domain ordering by iterating on the root domain and
// find their corresponding inner tile iterdomains from
// the populated root_id_to_inner_leaf_pos.
for (auto root_id : tv->getMaybeRFactorDomain()) {
auto leaf_id_pos_it = root_id_to_inner_leaf_pos.find(root_id);
if (leaf_id_pos_it != root_id_to_inner_leaf_pos.end()) {
reorder_map_old_to_new[leaf_id_pos_it->second] = current_pos++;
}
}
// Validate that we have processed all inner ids or broadcast/reduction
// ids we have registered.
TORCH_INTERNAL_ASSERT(current_pos == ndims, "Inconsistent ordering logic");
// Apply the new order:
tv->reorder(reorder_map_old_to_new);
}
} // namespace matmul_utils
//! Propagate current transformations on from_tv to all graphs
TORCH_CUDA_CU_API void transformPropagateToAllFrom(
TensorView* from_tv,
int pos) {
TransformPropagator propagator(from_tv, pos);
MaxRootDomainInfoSpanningTree(from_tv, nullptr).traverse(&propagator);
}
namespace {
//! Utility enum to signify which direction
//! BoundedDirectionalTransformPropagator
//! passes will propagate the transforms.
enum class PropagateDirection { Backward = 0, Forward };
//! Returns true if the given tensorview is a fake boundary
//! TensorView, see Note [Fake Boundary Tensorview].
//! This function assumes, and does not check, that tv is a boundary
//! of the selected tv set.
bool isFakeBoundaryTensorview(
TensorView* tv,
const std::unordered_set<TensorView*>& selected_tv_set,
PropagateDirection direction) {
if (direction == PropagateDirection::Forward) {
// In the case of forward propagation,
// a boundary tv is a fake boundary if
// it has any consumer tv that's in the selected
// set.
for (auto consumer_tv : ir_utils::consumerTvsOf(tv)) {
if (selected_tv_set.count(consumer_tv)) {
// Found a consumer that's in selected tv set.
return true;
}
}
} else {
// In the case of backward propagation,
// a boundary tv is a fake boundary if it has any producer
// that is within the selected set.
for (auto producer_tv : ir_utils::producerTvsOf(tv)) {
if (selected_tv_set.count(producer_tv)) {
// Found a producer that's in selected tv set.
return true;
}
}
}
// Didn't find any producer/consumer in the selected tv set.
// The given tv is not a fake boundary tv.
return false;
}
//! Utility function to generate the set of tensorviews to propagate
//! transform to by BoundedDirectionalTransformPropagator.
std::unordered_set<TensorView*> getDirectionalPropagatePathSet(
TensorView* from_tv,
std::vector<TensorView*> boundary_tvs,
BoundedDirectionalTransformPropagator::Options options,
PropagateDirection direction) {
// Prepare to collect all candidate tensorviews
// within the specified boundary.
std::vector<Val*> propagate_candidate;
// Collect boundary tvs in a set.
std::unordered_set<TensorView*> boundary_tv_set(
boundary_tvs.begin(), boundary_tvs.end());
if (direction == PropagateDirection::Forward) {
// In the case of forward propagation, collect all tvs
// that are consumers of `from_tv` and producers of
// boundary tvs.
propagate_candidate = DependencyCheck::getAllValsBetween(
{from_tv}, {boundary_tvs.begin(), boundary_tvs.end()});
} else {
// In the case of backward propagation, collect all tvs
// that are producers of `from_tv` and consumers of
// boundary tvs.
propagate_candidate = DependencyCheck::getAllValsBetween(
{boundary_tvs.begin(), boundary_tvs.end()}, {from_tv});
}
// Populate initial selected tensorviews in a set.
auto propagate_candidate_tv_view =
ir_utils::filterByType<TensorView>(propagate_candidate);
// Prepare to filter out unwanted tensorviews according
// to the option parameters.
std::unordered_set<TensorView*> propagate_path_set{
propagate_candidate_tv_view.begin(), propagate_candidate_tv_view.end()};
// Remove boundary tensorviews if we don't want to transform
// tensorviews on the boundary.
if (!options.transform_boundary) {
// Additional refining step to identify "fake boundary" tensorviews.
// We don't want to erase fake boundary tensorviews from the selected
// set when we are erasing boundary tvs.
//
// Note [Fake Boundary Tensorview]
// A tensorview, Tv0, is defined as a fake boundary tv if
// 1. Tv0 is in the given boundary set.
// 2. There is a path from another boundary tv, Tv1, to from_tv that
// goes through Tv0.
//
// In this case the propagation behavior is not precisely defined.
// Our current decision is to treat such tensorview as non-boundary
// tv to make sure the propagation paths are not blocked. E.g.:
//
// T1 = T0
// T2 = T1
// T3 = T2 + T1
// if we propagate with from_tv = {T3}, boundary_tv = {T0, T2},
// transform_boundary=false
//
// Here T2 is a fake boundary and we will still transform T2 as it is
// on the path between T3 and T0.
// Initialize set of fake boundary tvs.
std::unordered_set<TensorView*> fake_boundary_set;
// Populate the set of fake boundary tvs.
std::copy_if(
boundary_tvs.begin(),
boundary_tvs.end(),
std::inserter(fake_boundary_set, fake_boundary_set.end()),
[&propagate_path_set, direction](TensorView* tv) {
return isFakeBoundaryTensorview(tv, propagate_path_set, direction);
});
// Remove boundary tvs from the selected set, keeping fake boundary tvs.
for (auto boundary_tv : boundary_tvs) {
if (!fake_boundary_set.count(boundary_tv)) {
propagate_path_set.erase(boundary_tv);
}
}
}
return propagate_path_set;
}
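To make the fake-boundary case above concrete, here is a minimal, test-style sketch that is not part of this diff: it assumes the usual nvfuser test helpers (`makeSymbolicTensor`) and arithmetic ops, and the function name is only illustrative.

```
// Hypothetical reproduction of Note [Fake Boundary Tensorview].
void fakeBoundaryExample() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto t0 = makeSymbolicTensor(2); // T0
  fusion.addInput(t0);
  auto t1 = set(t0);               // T1 = T0
  auto t2 = set(t1);               // T2 = T1
  auto t3 = add(t2, t1);           // T3 = T2 + T1
  fusion.addOutput(t3);

  t3->merge(0)->split(0, 128);

  // Backward propagation from T3 bounded by {T0, T2} with
  // transform_boundary = false: T2 is a fake boundary because it sits on
  // the path T3 -> T2 -> T1 -> T0, so it is still transformed; T0 has no
  // producer in the selected set and is left untouched.
  scheduler_utils::BoundedDirectionalTransformPropagator::backward(
      t3, -1, {t0, t2});
}
```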
} // namespace
void BoundedDirectionalTransformPropagator::propagate(
TensorView* from_tv,
int pos,
std::unordered_set<TensorView*> included_tvs,
Options options) {
// Run transform propagation using the custom selector.
SetSelector selector(included_tvs);
TransformPropagator propagator(from_tv, pos);
MaxRootDomainInfoSpanningTree(from_tv, &selector).traverse(&propagator);
// Propagate parallel type if requested by option parameters.
if (options.propagate_parallel_type) {
scheduler_utils::parallelizeAllLike(
from_tv,
options.parallel_propagation_pos,
{included_tvs.begin(), included_tvs.end()},
allParallelTypesExcept({ParallelType::Vectorize, ParallelType::Mma}));
}
}
void BoundedDirectionalTransformPropagator::backward(
TensorView* from,
int pos,
std::vector<TensorView*> to,
c10::optional<Options> options) {
if (!options.has_value()) {
options = Options();
}
TORCH_INTERNAL_ASSERT(
!to.empty(),
"Propagation needs to be bounded, so no support for empty boundary.");
// Collect all tvs to be included on the backward path as specified
// by the boundary and options.
auto included_tvs = getDirectionalPropagatePathSet(
from, to, *options, PropagateDirection::Backward);
// Actually run the propagation.
propagate(from, pos, included_tvs, *options);
}
void BoundedDirectionalTransformPropagator::forward(
TensorView* from,
int pos,
std::vector<TensorView*> to,
c10::optional<Options> options) {
if (!options.has_value()) {
options = Options();
}
TORCH_INTERNAL_ASSERT(
!to.empty(),
"Propagation needs to be bounded, so no support for empty boundary.");
// Collect all tvs to be included on the forward path as specified
// by the boundary and options.
auto included_tvs = getDirectionalPropagatePathSet(
from, to, *options, PropagateDirection::Forward);
// Actually run the propagation.
propagate(from, pos, included_tvs, *options);
}
void BoundedDirectionalTransformPropagator::bothWays(
TensorView* from,
int pos,
std::vector<TensorView*> backward_to,
std::vector<TensorView*> forward_to,
c10::optional<Options> options) {
if (!options.has_value()) {
options = Options();
}
TORCH_INTERNAL_ASSERT(
!backward_to.empty() && !forward_to.empty(),
"Propagation needs to be bounded, so no support for empty boundary.");
// Collect all tvs to be included on the backward and forward paths as
// specified by the boundary and options.
auto backward_included_tvs = getDirectionalPropagatePathSet(
from, backward_to, *options, PropagateDirection::Backward);
auto forward_included_tvs = getDirectionalPropagatePathSet(
from, forward_to, *options, PropagateDirection::Forward);
// Combine the included tvs from both paths.
auto included_tvs = backward_included_tvs;
included_tvs.insert(forward_included_tvs.begin(), forward_included_tvs.end());
// Run the propagation on the combined set of tvs.
propagate(from, pos, included_tvs, *options);
}
// Grab all values and expressions used to make the merged_domain and remove
// them from the fusion
void cleanUpInnermostMergedDomains(


@ -2,6 +2,7 @@
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/maxinfo_propagator.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h>
namespace torch {
@ -11,6 +12,7 @@ namespace cuda {
class SchedulerRuntimeInfo;
class ExpressionEvaluator;
class HeuristicSummary;
namespace scheduler_utils {
@ -37,6 +39,11 @@ constexpr int64_t lastPow2(int64_t n) {
return std::max((int64_t)1, n - (n >> 1));
}
// Div x by y, but min at 1
inline int64_t safeDiv(const int64_t x, const int64_t y) {
return std::max(x / y, (int64_t)1);
}
// Merge all reduction axes to the right side and return the total number of
// reduction axes. `dont_merge` is typically used for trivial reductions.
size_t mergeReduction(
@ -49,9 +56,32 @@ size_t mergeNonReduction(
TensorView* tv,
const std::unordered_set<IterDomain*>& dont_merge = {});
// Propagate the parallelization from the selected dimensions of the reference
// tensor to their corresponding dimensions in all selected tensors in the DAG.
// Position `pos` means selecting all the dimensions [0, 1, ..., pos - 1]. pos =
// -1 means selecting all dimensions. `selected_tvs` are selected tensors in the
// DAG. Empty `selected_tvs` means selecting all tensors in the fusion of
// `reference_tv`. `selected_parallel_types` are the selected parallel types.
// Empty `selected_parallel_types` means selecting all parallel types.
TORCH_CUDA_CU_API void parallelizeAllLike(
TensorView* reference_tv,
const std::vector<TensorView*>& all_tvs);
int64_t pos = -1,
std::vector<TensorView*> selected_tvs = {},
const std::unordered_set<ParallelType>& selected_parallel_types = {},
bool propagate_padding = true);
TORCH_CUDA_CU_API inline void parallelizeAllLike(
TensorView* reference_tv,
std::vector<TensorView*> selected_tvs,
const std::unordered_set<ParallelType>& selected_parallel_types = {},
bool propagate_padding = true) {
parallelizeAllLike(
reference_tv,
-1,
std::move(selected_tvs),
selected_parallel_types,
propagate_padding);
}
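For reference, the updated `parallelizeAllLike` interface can be exercised as below. This is a hedged sketch, not part of the diff; `reference_tv`, `tv_cache` and `tv_out` are hypothetical tensorviews of a fusion being scheduled, and `allParallelTypesExcept` is the helper used by `BoundedDirectionalTransformPropagator::propagate` in this PR.

```
// Copy reference_tv's parallelization (all positions, pos = -1) to every
// tv in its fusion -- the form the updated tests below use:
scheduler_utils::parallelizeAllLike(reference_tv);

// Restrict propagation to the first two dimensions, a subset of tvs, and
// all parallel types except vectorization:
scheduler_utils::parallelizeAllLike(
    reference_tv,
    /*pos=*/2,
    /*selected_tvs=*/{tv_cache, tv_out},
    allParallelTypesExcept({ParallelType::Vectorize}));
```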
TORCH_CUDA_CU_API void computeAtInputs(
TensorView* consumer,
@ -166,7 +196,6 @@ std::pair<bool, bool> canonicalDimReduction(
// Return a list of tensor views that are outputs of reduction operations. If
// multiple outputs of an expression are found, only include one in the list
// (WelfordOp)
TORCH_CUDA_CU_API std::vector<TensorView*> getReductionTvs(
Fusion* fusion,
bool ignore_trivial = true);
@ -179,13 +208,14 @@ void clearMemorySpace(Fusion* fusion);
// Returns the cache-after tensors of the fusion inputs if unrolled. Otherwise
// returns an empty vector.
std::vector<TensorView*> cacheInputs(Fusion* fusion, bool unroll);
TORCH_CUDA_CU_API std::vector<TensorView*> cacheInputs(
Fusion* fusion,
bool unroll);
// Returns the pairs of <cache of each fusion output, corresponding output> for
// all outputs.
std::vector<std::pair<TensorView*, TensorView*>> cacheAndForkOutputs(
Fusion* fusion,
bool unroll);
TORCH_CUDA_CU_API std::vector<std::pair<TensorView*, TensorView*>>
cacheAndForkOutputs(Fusion* fusion, bool unroll);
// Ignores broadcast and reduction, returns iter domain in root domain that's
// "inner most". If this is an rfactored reduction domain, actually check the
@ -289,10 +319,12 @@ namespace matmul_utils {
TORCH_CUDA_CU_API void scheduleContiguousVectorLoad(
TensorView* tv,
MatMulTileOptions tile,
int vector_word);
int vector_word,
bool vectorize = true);
//! Schedule utility for mma output in matmul main loop:
//! Realize the hierarchical tiling based on the given tiling options.
//! TODO: rewrite this one with makeTile
TORCH_CUDA_CU_API void scheduleWarpTileWithReduction(
TensorView* tv,
MatMulTileOptions tile);
@ -300,12 +332,142 @@ TORCH_CUDA_CU_API void scheduleWarpTileWithReduction(
//! Schedule utility for mma output in matmul main loop:
//! Realize the hierarchical tiling based on the given tiling options
//! on consumers of mma ops in epilog.
//! TODO: remove this one eventually.
TORCH_CUDA_CU_API void scheduleWarpTileWithNoReduction(
TensorView* tv,
MatMulTileOptions tile);
//! Lower-level primitive splitting inner iterdomains into tiles:
//! Eg.
//! A[B,I0,I1,I2] -> makeTile({1,2,3})
//! Gives A[B, I0o, I1o, I2o, I0i(1), I1i(2), I2i(3)]
TORCH_CUDA_CU_API void makeTile(TensorView* tv, std::vector<int> tile_sizes);
//! Order the inner tile dimensions to match the original order of the
//! root domain, putting broadcast domains on the left.
//! Eg. A[I0o,I1o,B2o,I0i,I1i,B2i] (root domain: I1,B,I0)
//! -> A[I0o, I1o, B2o, B2i, I1i, I0i]
//! This is used to facilitate data layout swizzling and
//! defining vectorized loads.
TORCH_CUDA_CU_API void orderTiledConcreteIdAsRoot(TensorView* tv);
//! Orders the root ids of the given tv as
//! [Batch, Previous Reduction, M, N, K]
//! for easier processing of later scheduling steps.
//!
//! This matching works on root domain only, and
//! will throw if the tv has a leaf iterdomain that is
//! not a root id.
TORCH_CUDA_CU_API void canonicalizeMmaTvOrdering(TensorView* tv);
} // namespace matmul_utils
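A sketch of how these matmul utilities compose, under assumptions about the caller: `tv` would be the mma output and `gemm_tile` the tile options chosen by the matmul scheduler; the function name is hypothetical and not part of this diff.

```
// Hypothetical composition of the matmul_utils helpers declared above.
void scheduleMmaResultSketch(TensorView* tv, MatMulTileOptions gemm_tile) {
  // Put the root ids in [Batch, Previous Reduction, M, N, K] order;
  // throws if tv has a leaf iterdomain that is not a root id.
  matmul_utils::canonicalizeMmaTvOrdering(tv);

  // Realize the hierarchical CTA/warp tiling in the matmul main loop.
  matmul_utils::scheduleWarpTileWithReduction(tv, gemm_tile);
}
```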
//! Propagate current transformations on from_tv up to the given
//! position to all tensorviews in the owning fusion that have
//! a connection with `from_tv` in the fusion graph.
TORCH_CUDA_CU_API void transformPropagateToAllFrom(
TensorView* from_tv,
int pos);
//! A type of custom transform propagator that propagates iterdomain
//! transforms from a source tv to all tvs that are selected
//! using a "direction" and a "boundary".
//!
//! The propagation model always assumes a `from_tv`, a `direction` and a
//! `boundary`.
//!
//! This propagator will only transform producers and consumers
//! of `from_tv`, and all propagation modes **require** a boundary to be
//! specified to signify where the propagation should stop.
//!
//! There are currently three modes of propagation: forward, backward and
//! both-way, see comment on the interface functions for details.
struct TORCH_CUDA_CU_API BoundedDirectionalTransformPropagator {
//! Custom option container for configuring
//! the transform propagation actions.
//! All option values default to false unless
//! the corresponding setter is called.
struct Options {
//! If true, the transform propagator will
//! also propagate parallel types from
//! `from_tv` to all selected tvs.
bool propagate_parallel_type = false;
//! If true, the specified boundary tvs
//! will also be replayed as `from_tv`.
//! If false, they will not be affected
//! by the propagation pass.
bool transform_boundary = false;
//! Sets the position boundary in parallel
//! type propagation; see comment on
//! scheduler_utils::parallelizeAllLike.
//! Only used if propagate_parallel_type==true.
int parallel_propagation_pos = -1;
//! Setter for enabling parallel type
//! propagation. See comment on the variable.
//!
//! \param up_to_pos sets the parallel type
//! propagation boundary; see comment on
//! scheduler_utils::parallelizeAllLike.
Options propagateParallelType(int up_to_pos = -1) {
propagate_parallel_type = true;
parallel_propagation_pos = up_to_pos;
return *this;
}
//! Setter for enabling propagation to
//! boundary tvs. See comment on the variable.
Options propagateToBoundary() {
transform_boundary = true;
return *this;
}
};
//! Replay transforms from tensorview `from`
//! to the tensorviews that are consumers
//! of boundary tensorviews in `to` and producers of `from`.
static void backward(
TensorView* from,
int pos,
std::vector<TensorView*> to,
c10::optional<Options> options = c10::nullopt);
//! Replay transforms from tensorview `from`
//! to the tensorviews that are producers
//! of boundary tensorviews in `to` and consumers of `from`.
static void forward(
TensorView* from,
int pos,
std::vector<TensorView*> to,
c10::optional<Options> options = c10::nullopt);
//! Replay transforms from tensorview `from`
//! to all the tensorviews that are consumers
//! of tensorviews in `backward_to` and producers
//! of tensorviews in `forward_to` while being
//! either a producer or a consumer of tensorview `from`.
static void bothWays(
TensorView* from,
int pos,
std::vector<TensorView*> backward_to,
std::vector<TensorView*> forward_to,
c10::optional<Options> options = c10::nullopt);
private:
//! Utility function:
//! realizes the transform propagation on the
//! tensorviews in `included_tvs`.
//! Assumes that all tvs in included_tvs are either
//! a producer or a consumer of from_tv.
static void propagate(
TensorView* from_tv,
int pos,
std::unordered_set<TensorView*> included_tvs,
Options options);
};
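Typical invocations of the propagator look like the following. This is only a hedged usage sketch; `mma_result`, `a_cache`, `b_cache` and `output_tv` are hypothetical tensorviews standing in for whatever the calling scheduler holds.

```
// Replay mma_result's transforms backward onto its producer chain,
// stopping at the operand caches, transforming the boundary tvs as well
// and copying parallel types:
BoundedDirectionalTransformPropagator::backward(
    mma_result,
    -1,
    {a_cache, b_cache},
    BoundedDirectionalTransformPropagator::Options()
        .propagateParallelType()
        .propagateToBoundary());

// Replay forward from mma_result to the output; the boundary tv itself is
// left untouched since transform_boundary defaults to false:
BoundedDirectionalTransformPropagator::forward(mma_result, -1, {output_tv});
```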
} // namespace scheduler_utils
} // namespace cuda
} // namespace fuser


@ -211,6 +211,8 @@ TensorView::TensorView(const TensorView* src, IrCloner* ir_cloner)
memory_type_(src->memory_type_),
swizzle_type_(src->swizzle_type_),
is_double_buffered_(src->is_double_buffered_),
is_circular_buffered_(src->is_circular_buffered_),
circular_buffer_stage_(src->circular_buffer_stage_),
cpu_scalar_(src->cpu_scalar_),
has_swizzle_op_(src->has_swizzle_op_) {
for (const auto id : src->axesToSwizzle()) {
@ -571,7 +573,11 @@ TensorView* TensorView::swizzle(
return this;
}
TensorView* TensorView::swizzle(Swizzle2DType swizzle_type, int x, int y) {
TensorView* TensorView::swizzle(
Swizzle2DType swizzle_type,
int x,
int y,
SwizzleMode swizzle_mode) {
has_swizzle_op_ = true;
if (x < 0) {
x += domain()->nDims();
@ -645,7 +651,7 @@ TensorView* TensorView::swizzle(Swizzle2DType swizzle_type, int x, int y) {
}
}
domain()->swizzle(swizzle_type, x, y);
domain()->swizzle(swizzle_type, x, y, swizzle_mode);
return this;
}
@ -735,11 +741,10 @@ TensorView* TensorView::multiOutputRfactorHelper(
!container()->isA<kir::Kernel>(),
"Function invalid for kernel container.");
// Hack:
// Semantically we should always keep the outputs of welfordOp scheduled
// the same but the user end cannot guarantee that.
// In order to guarantee that the rFactor is defined meaningfully the
// scheduling of the output TV that got the rfactor call is force replayed
// towards the other two
// Semantically we should always keep the outputs of multi reduction ops
// scheduled the same but the user end cannot guarantee that. In order to
// guarantee that the rFactor is defined meaningfully the scheduling of the
// output TV that got the rfactor call is force replayed towards the other two
if (!sameAs(tv)) {
auto root = tv->getRootDomain();
@ -758,7 +763,7 @@ TensorView* TensorView::multiOutputRfactorHelper(
std::vector<IterDomain*> new_id;
for (auto id : domain()->domain()) {
TORCH_INTERNAL_ASSERT(
replay.getReplay().count(id), "Welford Replay Failed");
replay.getReplay().count(id), "Multi-output reduction replay failed");
new_id.push_back(replay.getReplay().at(id));
}
@ -795,12 +800,11 @@ std::vector<TensorView*> TensorView::rFactor(
TORCH_CHECK(nDims() > 0, "Tried to rFactor a 0-dim TensorView");
FusionGuard fg(fusion());
TORCH_CHECK(
definition() != nullptr &&
(definition()->getExprType() == ExprType::GroupedReductionOp ||
definition()->getExprType() == ExprType::WelfordOp),
"Error rfactoring welford ",
definition() != nullptr && ir_utils::isReductionOp(definition()),
"Error rfactoring multi-output reduction op ",
this,
" its definition is either a nullptr or not a GroupedReductionOp or a WelfordOp.");
" its definition is either a nullptr or not a GroupedReductionOp or a multi-output reduction op.");
TORCH_CHECK(
!domain()->hasRFactor(), "Cannot call rfactor on the same view twice.");
@ -1129,6 +1133,21 @@ void TensorView::doubleBuffer() {
is_double_buffered_ = true;
}
void TensorView::circularBuffer(unsigned int stage) {
// Early correctness checking. May miss eventual errors as the
// checks depend on memory types and parallelization, which may not
// be finalized until lowering.
TORCH_INTERNAL_ASSERT(stage > 1, "Unsupported stage number");
if (stage == 2) {
// Redirect to the double buffer interface if stage is 2.
doubleBuffer();
return;
}
validateDoubleBufferedTensor(this);
is_circular_buffered_ = true;
circular_buffer_stage_ = stage;
}
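A hedged usage sketch for the new entry point (the tensor and the stage count are assumptions, not part of this change): a shared-memory operand cache in a pipelined main loop could be staged as follows.

```
// Stage a shared-memory cache three deep; stage == 2 would fall back to
// the existing double-buffering path via doubleBuffer().
a_smem_cache->setMemoryType(MemoryType::Shared);
a_smem_cache->circularBuffer(3);
```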
bool TensorView::isEmptyTensor() const {
auto& root_domain = getMaybeRFactorDomain();
return std::all_of(
@ -1174,7 +1193,29 @@ TensorViewBuilder& TensorViewBuilder::contiguity(std::vector<bool> contiguity) {
return *this;
}
TensorViewBuilder& TensorViewBuilder::shape(std::vector<int64_t> shape) {
TensorViewBuilder& TensorViewBuilder::shape(const std::vector<int64_t>& shape) {
TORCH_CHECK(shape_.empty(), "Attempting to reset shape");
if (!shape.empty()) {
TORCH_CHECK(ndims_ == 0 || ndims_ == shape.size());
ndims_ = shape.size();
}
shape_.clear();
shape_.reserve(shape.size());
for (int64_t i : shape) {
if (i == -1) {
shape_.emplace_back(IrBuilder::create<Int>());
} else {
TORCH_CHECK(
i >= 0,
"Invalid extent value. ",
"For a tensor representing a single scalar use ndims = 0 with no sizes set.");
shape_.emplace_back(IrBuilder::create<Int>(i));
}
}
return *this;
}
TensorViewBuilder& TensorViewBuilder::shape(std::vector<Val*> shape) {
TORCH_CHECK(shape_.empty(), "Attempting to reset shape");
if (!shape.empty()) {
TORCH_CHECK(ndims_ == 0 || ndims_ == shape.size());
@ -1188,17 +1229,13 @@ TensorView* TensorViewBuilder::build() const {
// Build the domain
std::vector<IterDomain*> domain(ndims_, nullptr);
for (const auto i : c10::irange(ndims_)) {
if (shape_.empty() || shape_[i] == -1) {
if (shape_.empty()) {
domain[i] =
IterDomainBuilder(
FusionGuard::getCurFusion()->zeroVal(), IrBuilder::create<Int>())
.build();
} else {
TORCH_CHECK(
shape_[i] >= 0,
"Invalid extent value. ",
"For a tensor representing a single scalar use ndims = 0 with no sizes set.");
if (shape_[i] == 1) {
if (shape_[i]->isOneInt()) {
// If size is known to be 1, assume it needs to be broadcasted.
domain[i] = IterDomainBuilder(
FusionGuard::getCurFusion()->zeroVal(),
@ -1206,10 +1243,9 @@ TensorView* TensorViewBuilder::build() const {
.iter_type(IterType::Broadcast)
.build();
} else {
domain[i] = IterDomainBuilder(
FusionGuard::getCurFusion()->zeroVal(),
IrBuilder::create<Int>(shape_[i]))
.build();
domain[i] =
IterDomainBuilder(FusionGuard::getCurFusion()->zeroVal(), shape_[i])
.build();
}
}
}
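Putting the two `shape` overloads together (the `int64_t` form and the new `Val*` form from #1884), a hedged builder sketch might look like this; the variable names are illustrative only.

```
// -1 still means a symbolic extent in the int64_t overload; a size of 1
// becomes a broadcast domain in build().
auto tv_concrete = TensorViewBuilder()
                       .shape(std::vector<int64_t>{4, 1, -1})
                       .dtype(DataType::Half)
                       .build();

// The Val* overload lets callers pass extents (e.g. symbolic Ints) directly.
auto tv_symbolic = TensorViewBuilder()
                       .shape(std::vector<Val*>{
                           IrBuilder::create<Int>(), IrBuilder::create<Int>(8)})
                       .build();
```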

File diff suppressed because it is too large.


@ -26,6 +26,7 @@
#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
#include <torch/csrc/jit/codegen/cuda/test/test_gpu_validator.h>
#include <torch/csrc/jit/codegen/cuda/test/test_utils.h>
#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
#include <torch/csrc/jit/codegen/cuda/transform_rfactor.h>
@ -46,29 +47,6 @@ using namespace at::indexing;
namespace {
// Make a tensor that is known to be fully contiguous of dimensionality=ndims,
// but unknown sizes
TensorView* makeContigTensor(size_t ndims, DataType dtype = DataType::Float) {
return TensorViewBuilder()
.ndims(ndims)
.dtype(dtype)
.contiguity(std::vector<bool>(ndims, true))
.build();
}
// Make a tensor that is known to be non-contiguous of dimensionality=ndims,
// but unknown sizes
TensorView* makeSymbolicTensor(size_t ndims, DataType dtype = DataType::Float) {
return TensorViewBuilder().ndims(ndims).dtype(dtype).build();
}
// Make a non-contiguous tensor of compile-time known sizes
TensorView* makeConcreteTensor(
std::vector<int64_t> shape,
DataType dtype = DataType::Float) {
return TensorViewBuilder().shape(shape).dtype(dtype).build();
}
class KernelExprVisitor : private kir::IrVisitor {
public:
static std::vector<Expr*> getAllExprs(const kir::Kernel* kernel) {
@ -144,7 +122,7 @@ TEST_F(NVFuserTest, FusionGridAllreduce1_CUDA) {
tv3->axis(0)->parallelize(ParallelType::BIDy);
tv3->axis(2)->parallelize(ParallelType::BIDx);
tv3->axis(3)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv3);
// Just to make sure fused_reduction and work buffers are allocated
// uniquely
@ -247,7 +225,7 @@ TEST_F(NVFuserTest, FusionGridAllreduce3_CUDA) {
tv3->axis(1)->parallelize(ParallelType::BIDx);
tv3->axis(2)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv3);
GpuLower gpulw(&fusion);
validateNoParallelBroadcastExist(gpulw.kernel());
@ -293,7 +271,7 @@ TEST_F(NVFuserTest, FusionGridAllreduce4_CUDA) {
tv4->axis(0)->parallelize(ParallelType::BIDx);
tv4->axis(1)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(tv4, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv4);
GpuLower gpulw(&fusion);
validateNoParallelBroadcastExist(gpulw.kernel());
@ -352,7 +330,7 @@ TEST_F(NVFuserTest, FusionGridAllreduce5_CUDA) {
tv4->axis(1)->parallelize(ParallelType::BIDx);
tv4->axis(2)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(tv4, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv4);
tv6->axis(0)->parallelize(ParallelType::BIDy);
tv6->axis(1)->parallelize(ParallelType::BIDx);
@ -410,7 +388,7 @@ TEST_F(NVFuserTest, FusionGridAllreduce6_CUDA) {
tv1->axis(2)->parallelize(ParallelType::BIDx);
tv1->axis(3)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(tv1, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv1);
tv1->axis(4)->parallelize(ParallelType::Vectorize);
@ -456,7 +434,7 @@ TEST_F(NVFuserTest, FusionGridAllreduceWelford1_CUDA) {
tv5->axis(0)->parallelize(ParallelType::BIDx);
tv5->axis(1)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(tv5, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv5);
GpuLower gpulw(&fusion);
validateNoParallelBroadcastExist(gpulw.kernel());
@ -506,7 +484,7 @@ TEST_F(NVFuserTest, FusionGridAllreduceWelford2_CUDA) {
tv3->axis(1)->parallelize(ParallelType::BIDx);
tv3->axis(2)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv3);
// There must be no parallel broadcast
GpuLower gpulw(&fusion);
@ -610,7 +588,7 @@ TEST_F(NVFuserTest, FusionFusedReductionBatchnorm_CUDA) {
TransformPropagator propagator(tv0);
MaxRootDomainInfoSpanningTree(tv0).traverse(&propagator);
auto tvs_rf = tvs.rFactor({-5, -4, -3, -2, -1});
ir_utils::rfactorHelper(tvs.avg, {-5, -4, -3, -2, -1});
tv0->computeAt(tv29, 2);
tv1->computeAt(tv29, 2);
@ -622,7 +600,7 @@ TEST_F(NVFuserTest, FusionFusedReductionBatchnorm_CUDA) {
tv29->axis(2)->parallelize(ParallelType::BIDy);
tv29->axis(3)->parallelize(ParallelType::TIDz);
tv29->axis(4)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(tv29, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv29);
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto options_half = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
@ -696,7 +674,7 @@ TEST_F(NVFuserTest, FusionGroupedReduction1_CUDA) {
tv2->axis(0)->parallelize(ParallelType::BIDx);
tv2->axis(1)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(tv2, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv2);
std::vector<int64_t> shape({99, 999});
@ -872,7 +850,7 @@ TEST_F(NVFuserTest, FusionGroupedReduction6_CUDA) {
tv2->axis(0)->parallelize(ParallelType::BIDx);
tv2->axis(1)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(tv2, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv2);
std::vector<int64_t> shape({99, 999});
@ -937,7 +915,7 @@ TEST_F(NVFuserTest, FusionGroupedReductionRfactor1_CUDA) {
tv1_rf->axis(0)->parallelize(ParallelType::BIDx);
tv1_rf->axis(2)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(tv1_rf, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv1_rf);
std::vector<int64_t> shape({12345});
@ -982,7 +960,7 @@ TEST_F(NVFuserTest, FusionGroupedReductionRfactor2_CUDA) {
tv1_rf->axis(0)->parallelize(ParallelType::BIDx);
tv1_rf->axis(2)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(tv1_rf, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv1_rf);
std::vector<int64_t> shape({12345});
@ -1028,7 +1006,7 @@ TEST_F(NVFuserTest, FusionGroupedReductionAfterComputeAt_CUDA) {
groupReductions({tv2, tv3});
tv2->axis(1)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(tv2, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv2);
std::vector<int64_t> shape({3, 1234});
@ -1068,7 +1046,7 @@ TEST_F(NVFuserTest, FusionGroupAllreduce1_CUDA) {
tv2->axis(0)->parallelize(ParallelType::BIDx);
tv2->axis(1)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(tv2, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv2);
std::vector<int64_t> shape({999});
@ -1117,7 +1095,7 @@ TEST_F(NVFuserTest, FusionGroupAllreduce2_CUDA) {
tv1->axis(0)->parallelize(ParallelType::BIDy);
tv1->axis(1)->parallelize(ParallelType::BIDx);
tv1->axis(2)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(tv1, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv1);
std::vector<int64_t> shape({10, 999});
@ -1169,7 +1147,7 @@ TEST_F(NVFuserTest, FusionGroupAllreduce3_CUDA) {
tv1->axis(0)->parallelize(ParallelType::BIDx);
tv1->axis(1)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(tv1, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv1);
std::vector<int64_t> shape({999});
@ -1221,7 +1199,7 @@ TEST_F(NVFuserTest, FusionGroupAllreduce4_CUDA) {
reduction_tv->axis(0)->parallelize(ParallelType::BIDx);
reduction_tv->axis(1)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(reduction_tv, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(reduction_tv);
std::vector<int64_t> shape({999});
@ -1281,7 +1259,7 @@ TEST_F(NVFuserTest, FusionGroupAllreduce5_CUDA) {
tv1->axis(0)->parallelize(ParallelType::BIDx);
tv1->axis(1)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(tv1, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv1);
std::vector<int64_t> shape({999});
@ -1445,7 +1423,7 @@ TEST_F(NVFuserTest, FusionPersistentBNBackwardAllreduce_CUDA) {
}
// Parallelization
scheduler_utils::parallelizeAllLike(grad_input, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(grad_input);
input_cache->axis(-1)->parallelize(ParallelType::Vectorize);
grad_output_cache->axis(-1)->parallelize(ParallelType::Vectorize);
@ -1559,7 +1537,7 @@ TEST_F(NVFuserTest, FusionGroupedReductionReEntrant1_CUDA) {
tv2->axis(2)->parallelize(ParallelType::BIDx);
tv2->axis(3)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(tv2, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv2);
std::vector<int64_t> shape({99, 999});
@ -1670,7 +1648,7 @@ TEST_F(NVFuserTest, FusionGroupedReductionChannelsLastBatchNormLike_CUDA) {
ref->axis(4)->parallelize(ParallelType::Serial);
ref->axis(5)->parallelize(ParallelType::Serial);
scheduler_utils::parallelizeAllLike(ref, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(ref);
tv0_cache->axis(-1)->parallelize(ParallelType::Vectorize);
tv1_cache->axis(-1)->parallelize(ParallelType::Vectorize);
@ -1799,7 +1777,7 @@ TEST_F(
ref->axis(4)->parallelize(ParallelType::Serial);
ref->axis(5)->parallelize(ParallelType::Serial);
scheduler_utils::parallelizeAllLike(ref, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(ref);
tv0_cache->axis(-1)->parallelize(ParallelType::Vectorize);
tv1_cache->axis(-1)->parallelize(ParallelType::Vectorize);
@ -1866,7 +1844,7 @@ TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduce1_CUDA) {
tv1->axis(2)->parallelize(ParallelType::BIDx);
tv1->axis(3)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(tv1, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv1);
tv2->axis(4)->parallelize(ParallelType::Group);
@ -1944,7 +1922,7 @@ TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduce2_CUDA) {
tv1->axis(2)->parallelize(ParallelType::BIDx);
tv1->axis(3)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(tv1, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv1);
tv2->axis(4)->parallelize(ParallelType::Group);
tv2->axis(5)->parallelize(ParallelType::Group);
@ -2030,7 +2008,7 @@ TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduce3_CUDA) {
tv1->axis(2)->parallelize(ParallelType::BIDx);
tv1->axis(3)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(tv1, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv1);
tv2->axis(4)->parallelize(ParallelType::Group);


@ -26,6 +26,7 @@
#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
#include <torch/csrc/jit/codegen/cuda/test/test_gpu_validator.h>
#include <torch/csrc/jit/codegen/cuda/test/test_utils.h>
#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
#include <torch/csrc/jit/codegen/cuda/transform_rfactor.h>
@ -46,48 +47,6 @@ using namespace at::indexing;
namespace {
// Make a tensor that is known to be fully contiguous of dimensionality=ndims,
// but unknown sizes
TensorView* makeContigTensor(size_t ndims, DataType dtype = DataType::Float) {
return TensorViewBuilder()
.ndims(ndims)
.dtype(dtype)
.contiguity(std::vector<bool>(ndims, true))
.build();
}
// Make a tensor that is known to be non-contiguous of dimensionality=ndims,
// but unknown sizes
TensorView* makeSymbolicTensor(size_t ndims, DataType dtype = DataType::Float) {
return TensorViewBuilder().ndims(ndims).dtype(dtype).build();
}
// Make a non-contiguous tensor of compile-time known sizes
TensorView* makeConcreteTensor(
std::vector<int64_t> shape,
DataType dtype = DataType::Float) {
return TensorViewBuilder().shape(shape).dtype(dtype).build();
}
void checkIntValue(
ExpressionEvaluator& evaluator,
Val* val,
Int::ScalarType expected_value) {
TORCH_CHECK(val->isAnInt());
const auto actual_value = evaluator.evaluate(val);
TORCH_CHECK(actual_value.has_value());
TORCH_CHECK(actual_value.value() == expected_value);
}
void checkIntValue(
kir::ExpressionEvaluator& evaluator,
const Val* val,
Int::ScalarType expected_value) {
const auto actual_value = evaluator.evaluate(val);
TORCH_CHECK(actual_value.has_value());
TORCH_CHECK(actual_value.value() == expected_value);
}
// Used to signify invalid ranges, i.e., values at offset 0 to
// start_offset, and values at offset stop_offset to the end of the
// domain.
@ -2175,7 +2134,7 @@ TEST_F(NVFuserTest, FusionHdiff_CUDA) {
out->axis(3)->parallelize(ParallelType::TIDy);
out->axis(4)->parallelize(ParallelType::TIDx);
// Apply the same parallelization to all other tensors
scheduler_utils::parallelizeAllLike(out, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(out);
// Store intermediate stencil results on smem so that they can be
// accessed by threads
@ -2733,7 +2692,7 @@ TEST_F(NVFuserTest, FusionGather6_CUDA) {
out->axis(1)->parallelize(ParallelType::BIDx);
out->axis(2)->parallelize(ParallelType::TIDy);
out->axis(3)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(out, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(out);
const int s1 = 101;
const int s2 = 99;
@ -2793,7 +2752,7 @@ TEST_F(NVFuserTest, FusionGather7_CUDA) {
out->axis(1)->parallelize(ParallelType::BIDx);
out->axis(2)->parallelize(ParallelType::TIDy);
out->axis(3)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(out, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(out);
const int s1 = 101;
const int s2 = 99;
@ -2894,7 +2853,7 @@ TEST_F(NVFuserTest, FusionGather9_CUDA) {
out->axis(1)->parallelize(ParallelType::BIDx);
out->axis(2)->parallelize(ParallelType::TIDy);
out->axis(3)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(out, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(out);
const int s1 = 101;
const int s2 = 99;
@ -3817,7 +3776,7 @@ TEST_F(NVFuserTest, FusionShiftNoPadding1_CUDA) {
tv5->axis(-1)->parallelize(ParallelType::TIDx);
tv5->axis(-2)->parallelize(ParallelType::TIDy);
scheduler_utils::parallelizeAllLike(tv5, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv5);
int numel_x = 99;
int numel_y = 101;
@ -3873,7 +3832,7 @@ TEST_F(NVFuserTest, FusionShiftNoPadding2_CUDA) {
tv3->computeAt(tv5, -1);
tv5->axis(-1)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(tv5, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv5);
int numel_x = 99;
int numel_y = 101;
@ -3934,7 +3893,7 @@ TEST_F(NVFuserTest, FusionShiftNoPadding3_CUDA) {
tv3->computeAt(tv_avg, -1);
tv_avg->axis(-1)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(tv_avg, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv_avg);
int numel_x = 99;
int numel_y = 101;
@ -4122,7 +4081,7 @@ TEST_F(NVFuserTest, FusionShiftPadding1_CUDA) {
tv5->axis(-1)->parallelize(ParallelType::TIDx);
tv5->axis(-2)->parallelize(ParallelType::TIDy);
scheduler_utils::parallelizeAllLike(tv5, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv5);
int numel_x = 99;
int numel_y = 101;
@ -5080,7 +5039,7 @@ TEST_F(NVFuserTest, FusionMaxPoolingStrided_CUDA) {
max_tensor->axis(4)->parallelize(ParallelType::TIDy);
max_tensor->axis(5)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(max_tensor, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(max_tensor);
inp_cache->setMemoryType(MemoryType::Shared);
@ -5332,7 +5291,7 @@ TEST_F(NVFuserTest, FusionGather9ptStencilDoubleBuffering_CUDA) {
out->axis(2)->parallelize(ParallelType::TIDy);
out->axis(0)->parallelize(ParallelType::BIDx);
scheduler_utils::parallelizeAllLike(out, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(out);
tv0_cache->doubleBuffer();
@ -5380,7 +5339,7 @@ TEST_F(NVFuserTest, FusionValidateParallelizeShift_CUDA) {
tv5->axis(1)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(tv5, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv5);
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor t0 = at::randn({1024 * 32}, options);

File diff suppressed because it is too large.


@ -28,6 +28,7 @@
#include <torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
#include <torch/csrc/jit/codegen/cuda/test/test_gpu_validator.h>
#include <torch/csrc/jit/codegen/cuda/test/test_utils.h>
#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
#include <torch/csrc/jit/codegen/cuda/transform_rfactor.h>
@ -49,33 +50,6 @@ namespace jit {
using namespace torch::jit::fuser::cuda;
using namespace at::indexing;
namespace {
// Make a tensor that is known to be fully contiguous of dimensionality=ndims,
// but unknown sizes
TensorView* makeContigTensor(size_t ndims, DataType dtype = DataType::Float) {
return TensorViewBuilder()
.ndims(ndims)
.dtype(dtype)
.contiguity(std::vector<bool>(ndims, true))
.build();
}
// Make a tensor that is known to be non-contiguous of dimensionality=ndims,
// but unknown sizes
TensorView* makeSymbolicTensor(size_t ndims, DataType dtype = DataType::Float) {
return TensorViewBuilder().ndims(ndims).dtype(dtype).build();
}
// Make a non-contiguous tensor of compile-time known sizes
TensorView* makeConcreteTensor(
std::vector<int64_t> shape,
DataType dtype = DataType::Float) {
return TensorViewBuilder().shape(shape).dtype(dtype).build();
}
} // namespace
TEST_F(NVFuserTest, FusionViewDtypeSameSizeOutput_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);

Some files were not shown because too many files have changed in this diff.