[NVFuser] Upstream push 0809 (#83067)
Syncing the nvfuser devel branch to upstream master: https://github.com/csarofeen/pytorch/

Code changes include:

- codegen improvements:
  1. removes unnecessary sync from redundant thread compute analysis
  2. symmetric API for BestEffortReplay
  3. support merge on trivial reductions
  4. Ampere async copy improvements
- bug fixes:
  1. vectorization bug fixes
  2. type inference patch: fixes upstream #81725
  3. segmenter bug fix with deterministic iteration ordering
- parser update:
  1. added leaky_relu
- scheduler:
  1. normalization scheduler clean-up
  2. simplifies matmul scheduling with the new transform propagator
  3. merge all dimensions in the pointwise scheduler
  4. various gemm-related improvements
- debuggability:
  1. nsight compute support
  2. debug dump for InlinePropagator
  3. add `UnaryOpType::Print`

Commits were squashed to work around the GitHub API. Commits actually contained in this PR from the devel branch:

```
dfe02f3faed4c64477e5f5c678f21f33415d0195 Merge remote-tracking branch 'csarofeen/devel' into HEAD
16173732ecfafc4797e93c2449cfb778015a6c7a Add `TensorViewBuilder::shape(std::vector<Val*> shape)` (#1884)
7cfb7796bdcf055eb61d600b7b5c9df292950290 Merge pull request #1887 from csarofeen/upstream_merge_0803
3399f6de62061d30781de50ef1862bbfb1615173 Merge remote-tracking branch 'origin/viable/strict' into HEAD
01208f5bba3bc158d41ccbefa0ee2c5ceea7aedb Add `UnaryOpType::Print` which can be helpful for debugging (#1878)
0646522454aa715ef164c88a73fb8bdddc706805 Remove redundant TORCH_INTERNAL_ASSERT in lower_magic_zero.cpp (#1881)
7bc76aa219293a59e4166e258d76289fe13633ca Fix most inlined propagator for mismatched dims (#1875)
501f4aa270bf4dd47b0d2f4860bc6f23ebc32a38 Nonaffine swizzle formulation ep.2: Loop swizzle variant. (#1826)
d863d690f923047a85b5229a787118708f810741 Ampere async copy ep.2: circular buffering extension to support pipelined matmul operand load (#1827)
e0ae11a61c87cd998e88ddd79a496548171c31e0 Larger sized mma instructions to support full vectorization (#1824)
9bb4cf7a66b098f04c9d95a2d34ab2bceee151b3 fragment iteration to support fully unrolled mma ops (#1823)
a48270a18dc2d3accc2626758d14d5858ae55032 Merge all dims in pointwise scheduler (#1872)
172fb3673fb4aaf4c1e889922a4fc5c06cbd59f7 Make MostInlined and BestEffort inline propagation no longer assert replayed (#1868)
a64462a5ac2fcf57a177bf36b0f26c61a4e252a4 Allow trivial reduction to be merged (#1871)
440102bcda6eb1dcd42d5fa5aeab9d6b049956bc Symmetric API for BestEffortReplay (#1870)
d1caf330c08ea8002f7133ca655bbd5b28c4eb98 Some misc cleanups/refactor split out from #1854 (#1867)
1013eda50be38eac96c00ba781340ac199d5a136 Remove some welford specific logic. (#1864)
51589d36be5a101d06e641fe0400b39028b7cb81 Some cleanups on tests and heuristics params (#1866)
a6b3e70da5dee51dbc246347228ea21384e46ac3 Segmenter bug fix, and deterministic iteration ordering. (#1865)
1b665b9b5e562d6f0caba5e7319e83e5df64104f Add nullptr checks to IrBuilder (#1861)
1cd9451d7493f631c2837ba07c1ea93a74e83a15 Simplify matmul scheduling with the new transform propagator. (#1817)
bbc1fb9b8c454f557ab9fcf5b1c3cef9b9e136d0 Add leaky_relu operation (#1852)
e842a9bab5e9f7289b7ce33ee37a682b22373f49 Minor cleanup in pointwise scheduler (#1858)
9ee850ca2f7f51dd5269bffb1255e485f809282d Fix stringstream usage (#1857)
20a36c1e4f28c4ff9837e56784be2686d17435f3 Improve nsight compute support (#1855)
405910308301097297b55c34d560aab6a360e897 Remove debugging `true ||` from getPointwiseHeuristics (#1822)
01117bfe8fdfacdbfdcfba9a624cdf900fe044d4 Misc cleanup (#1853)
5cc64943dc381a568223140bce0f22163c01e29f Apply the magic-zero protection to each indexed domain individually for predicate indexing (#1846)
92e6f0207e3a89fe90fd5cd3ffc575dfd766ba00 Cleanup normalization scheduler (#1845)
db89c6591a2f21130599a93675e0615e55564e41 Type inference patch (#1848)
102fe93a4605ca465cda26ebaee4ba1af2026901 Add debug dump for InlinePropagator (#1847)
b7a4d93d375a6e2ddef483763c93ffddc62ec452 Redundant thread compute analysis to avoid un-necessary sync insertion (#1687)
942be5b256056d0e02877361b814ae6af32ca15f Upstream ci build fixes (#1842)
0b83645915029d67f9345aa4649b8c6f62b0061b Fix vectorization bug introduced in #1831 (#1840)
63630f1ae091180e541932a9d9dc598e0a9902dd Move MaxProducerPosUpdater into InlinePropagator::tearDown (#1825)
9135a963c01d97ba34b1a7d2f106e78a13fd6651 Fix transpose benchmark dtype (#1839)
2c9a6c02312d5bf4f83cde653b847b4f85849432 Add extra configurability to `parallelizeAllLike` (#1831)
```

RUN_TORCHBENCH: nvfuser
Differential Revision: [D38543000](https://our.internmc.facebook.com/intern/diff/D38543000)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83067
Approved by: https://github.com/davidberard98
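The `UnaryOpType::Print` item above corresponds to a new `print()` op in the nvfuser C++ arith API (its declarations appear in the `arith.h`/`arith.cpp` hunks further down). The following is a minimal usage sketch only, not code from this commit: it assumes a PyTorch build with nvfuser, and the `TensorViewBuilder().ndims(...)` input-creation call is an assumption (only `.shape()`, `.dtype()`, and `.build()` are visible in this diff).

```cpp
// Hedged sketch: insert a runtime debug print between two ops in a fusion.
#include <torch/csrc/jit/codegen/cuda/arith.h>   // relu, add, print
#include <torch/csrc/jit/codegen/cuda/fusion.h>  // Fusion, FusionGuard

using namespace torch::jit::fuser::cuda;

void buildFusionWithDebugPrint() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // 2-D symbolic float input (builder call assumed, see note above).
  TensorView* tv0 =
      TensorViewBuilder().ndims(2).dtype(DataType::Float).build();
  fusion.addInput(tv0);

  // print() behaves like any other unary op: it returns a new TensorView
  // that must be consumed downstream so the print is not dead code.
  TensorView* tv1 = relu(tv0);
  TensorView* tv2 = print(tv1);
  TensorView* tv3 = add(tv2, tv2);
  fusion.addOutput(tv3);
}
```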
This commit is contained in:
parent ce8716f59a
commit df741c589f
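Most of the benchmark hunks below follow one mechanical migration: the scheduler heuristics entry points now hand back a pointer-like handle (a `std::shared_ptr` in the new `toString(const std::shared_ptr<HeuristicParams>&)` overload) instead of a `c10::optional`, and the executor-cache log exposes a single polymorphic `params` object. The fragment below is a sketch of the new call pattern lifted from those hunks; it is not standalone-runnable and assumes the benchmark fixtures (`fusion`, `at_inputs`, `fusion_executor_cache`, `benchmark_state`) from the files being patched.

```cpp
// Old call sites (removed below):
//   auto norm_params = getPersistentHeuristics(&fusion, at_inputs);
//   TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
//   schedulePersistentKernel(&fusion, norm_params.value());

// New call sites (added below): null-check and dereference the handle.
auto norm_params = getPersistentHeuristics(&fusion, at_inputs);
TORCH_CHECK(norm_params != nullptr, "Norm scheduler can't be used!");
schedulePersistentKernel(&fusion, *norm_params);

// Benchmark labels now come from the unified params object instead of
// branching on reduction_params / pointwise_params.
auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo();
auto params = toString(compile_log.params);
auto lparams = toString(compile_log.fusion_executor->lastLaunchParams());
benchmark_state.SetLabel(params + lparams);
```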
@@ -133,8 +133,8 @@ static void MagicScheduler_DivMaxSoftDropFwd(
   std::vector<at::Tensor> cg_outputs;
 
   auto norm_params = getPersistentHeuristics(&fusion, at_inputs);
-  TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
-  schedulePersistentKernel(&fusion, norm_params.value());
+  TORCH_CHECK(norm_params != nullptr, "Norm scheduler can't be used!");
+  schedulePersistentKernel(&fusion, *norm_params);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
@@ -143,7 +143,7 @@ static void MagicScheduler_DivMaxSoftDropFwd(
   cudaDeviceSynchronize();
   for (auto _ : benchmark_state) {
     CudaKernelTimer timer;
-    cg_outputs = fe.runFusion({t0, t1}, norm_params.value().lparams);
+    cg_outputs = fe.runFusion({t0, t1}, norm_params->lparams);
     benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
   }
   // Sync everything up before we're finished, don't want to run ahead on the
@@ -193,8 +193,8 @@ static void MagicScheduler_DivMaxSoftDropBwd(
   std::vector<at::Tensor> cg_outputs;
 
   auto norm_params = getPersistentHeuristics(&fusion, at_inputs);
-  TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
-  schedulePersistentKernel(&fusion, norm_params.value());
+  TORCH_CHECK(norm_params != nullptr, "Norm scheduler can't be used!");
+  schedulePersistentKernel(&fusion, *norm_params);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
@@ -203,7 +203,7 @@ static void MagicScheduler_DivMaxSoftDropBwd(
   cudaDeviceSynchronize();
   for (auto _ : benchmark_state) {
     CudaKernelTimer timer;
-    cg_outputs = fe.runFusion({t0, t1, t2, t3}, norm_params.value().lparams);
+    cg_outputs = fe.runFusion({t0, t1, t2, t3}, norm_params->lparams);
     benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
   }
   // Sync everything up before we're finished, don't want to run ahead on the
@@ -308,8 +308,8 @@ static void MagicScheduler_BiasDropoutAddLayernormFwd(
   std::vector<at::Tensor> cg_outputs;
 
   auto norm_params = getPersistentHeuristics(&fusion, at_inputs);
-  TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
-  schedulePersistentKernel(&fusion, norm_params.value());
+  TORCH_CHECK(norm_params != nullptr, "Norm scheduler can't be used!");
+  schedulePersistentKernel(&fusion, *norm_params);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
@@ -319,7 +319,7 @@ static void MagicScheduler_BiasDropoutAddLayernormFwd(
   cudaDeviceSynchronize();
   for (auto _ : benchmark_state) {
     CudaKernelTimer timer;
-    cg_outputs = fe.runFusion(at_inputs, norm_params.value().lparams);
+    cg_outputs = fe.runFusion(at_inputs, norm_params->lparams);
     benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
   }
   // Sync everything up before we're finished, don't want to run ahead on the
@@ -423,8 +423,8 @@ static void MagicScheduler_BiasDropoutAddLayernormBwd1(
   std::vector<at::Tensor> cg_outputs;
 
   auto norm_params = getReductionHeuristics(&fusion, at_inputs);
-  TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
-  scheduleReduction(&fusion, norm_params.value());
+  TORCH_CHECK(norm_params != nullptr, "Norm scheduler can't be used!");
+  scheduleReduction(&fusion, *norm_params);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
@@ -434,7 +434,7 @@ static void MagicScheduler_BiasDropoutAddLayernormBwd1(
   cudaDeviceSynchronize();
   for (auto _ : benchmark_state) {
     clearL2Cache();
-    cg_outputs = fe.runFusion(at_inputs, norm_params.value().lparams);
+    cg_outputs = fe.runFusion(at_inputs, norm_params->lparams);
     benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
   }
   // Sync everything up before we're finished, don't want to run ahead on the
@@ -534,8 +534,8 @@ static void MagicScheduler_BiasDropoutAddLayernormBwd2(
   std::vector<at::Tensor> cg_outputs;
 
   auto norm_params = getPersistentHeuristics(&fusion, at_inputs);
-  TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
-  schedulePersistentKernel(&fusion, norm_params.value());
+  TORCH_CHECK(norm_params != nullptr, "Norm scheduler can't be used!");
+  schedulePersistentKernel(&fusion, *norm_params);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
@@ -545,7 +545,7 @@ static void MagicScheduler_BiasDropoutAddLayernormBwd2(
   cudaDeviceSynchronize();
   for (auto _ : benchmark_state) {
     CudaKernelTimer timer;
-    cg_outputs = fe.runFusion(at_inputs, norm_params.value().lparams);
+    cg_outputs = fe.runFusion(at_inputs, norm_params->lparams);
     benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
   }
   // Sync everything up before we're finished, don't want to run ahead on the
@@ -625,8 +625,8 @@ static void MagicScheduler_BiasDropoutAddLayernormBwd3(
   std::vector<at::Tensor> cg_outputs;
 
   auto norm_params = getReductionHeuristics(&fusion, at_inputs);
-  TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
-  scheduleReduction(&fusion, norm_params.value());
+  TORCH_CHECK(norm_params != nullptr, "Norm scheduler can't be used!");
+  scheduleReduction(&fusion, *norm_params);
 
   FusionExecutor fe;
   fe.compileFusion(&fusion);
@@ -636,7 +636,7 @@ static void MagicScheduler_BiasDropoutAddLayernormBwd3(
   cudaDeviceSynchronize();
   for (auto _ : benchmark_state) {
     CudaKernelTimer timer;
-    cg_outputs = fe.runFusion(at_inputs, norm_params.value().lparams);
+    cg_outputs = fe.runFusion(at_inputs, norm_params->lparams);
     benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
   }
   // Sync everything up before we're finished, don't want to run ahead on the
@@ -69,8 +69,7 @@ static void NvFuserScheduler_Broadcast(
 
   auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo();
   auto executor_instance = compile_log.fusion_executor;
-  TORCH_INTERNAL_ASSERT(compile_log.pointwise_params.has_value());
-  auto params = toString(compile_log.pointwise_params.value());
+  auto params = toString(compile_log.params);
   auto lparams = toString(compile_log.fusion_executor->lastLaunchParams());
 
   benchmark_state.SetLabel(params + lparams);
@@ -65,8 +65,7 @@ static void NvFuserScheduler_Reduction(
 
   auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo();
   auto executor_instance = compile_log.fusion_executor;
-  TORCH_INTERNAL_ASSERT(compile_log.reduction_params.has_value());
-  auto rparams = toString(compile_log.reduction_params.value());
+  auto rparams = toString(compile_log.params);
   auto lparams = toString(compile_log.fusion_executor->lastLaunchParams());
 
   benchmark_state.SetLabel(rparams + lparams);
@@ -135,8 +135,7 @@ static void NvFuserScheduler_SBR(
 
   auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo();
   auto executor_instance = compile_log.fusion_executor;
-  TORCH_INTERNAL_ASSERT(compile_log.pointwise_params.has_value());
-  auto params = toString(compile_log.pointwise_params.value());
+  auto params = toString(compile_log.params);
   auto lparams = toString(compile_log.fusion_executor->lastLaunchParams());
 
   benchmark_state.SetLabel(params + lparams);
@@ -238,8 +237,7 @@ static void NvFuserScheduler_SBR_Norm(
 
   auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo();
   auto executor_instance = compile_log.fusion_executor;
-  TORCH_INTERNAL_ASSERT(compile_log.pointwise_params.has_value());
-  auto params = toString(compile_log.pointwise_params.value());
+  auto params = toString(compile_log.params);
   auto lparams = toString(compile_log.fusion_executor->lastLaunchParams());
 
   benchmark_state.SetLabel(params + lparams);
@@ -84,8 +84,8 @@ static void setupTranspose(
     return (is_transpose) ? transpose(tv, axes.first, axes.second) : tv;
   };
 
-  auto input1 = makeContigTensor(num_dims);
-  auto input2 = makeContigTensor(num_dims);
+  auto input1 = makeContigTensor(num_dims, dtype);
+  auto input2 = makeContigTensor(num_dims, dtype);
   fusion->addInput(input1);
   fusion->addInput(input2);
 
@@ -89,6 +89,20 @@ std::string toString(PointwiseParams params) {
   return ss.str();
 }
 
+std::string toString(const std::shared_ptr<HeuristicParams>& params) {
+  auto rparams = std::dynamic_pointer_cast<ReductionParams>(params);
+  if (rparams) {
+    return toString(*rparams);
+  }
+  auto pparams = std::dynamic_pointer_cast<PointwiseParams>(params);
+  if (pparams) {
+    return toString(*pparams);
+  }
+  TORCH_INTERNAL_ASSERT(
+      false,
+      "Unknown heuristic parameter type. Did you just added a new heuristic parameter type but forget to update here?");
+}
+
 std::string toString(LaunchParams lparams) {
   std::stringstream ss;
   lparams.toString();
@@ -123,9 +137,7 @@ TensorView* makeContigTensor(size_t ndims, DataType dtype) {
       .build();
 }
 
-TensorView* makeConcreteTensor(
-    std::vector<int64_t> shape,
-    DataType dtype) {
+TensorView* makeConcreteTensor(std::vector<int64_t> shape, DataType dtype) {
   return TensorViewBuilder().shape(shape).dtype(dtype).build();
 }
 
@@ -157,15 +169,9 @@ void runBenchmarkIterations(
   auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo();
   auto executor_instance = compile_log.fusion_executor;
 
-  if (compile_log.reduction_params.has_value()) {
-    auto rparams = toString(compile_log.reduction_params.value());
-    auto lparams = toString(compile_log.fusion_executor->lastLaunchParams());
-    benchmark_state.SetLabel(rparams + lparams);
-  } else if (compile_log.pointwise_params.has_value()){
-    auto pparams = toString(compile_log.pointwise_params.value());
-    auto lparams = toString(compile_log.fusion_executor->lastLaunchParams());
-    benchmark_state.SetLabel(pparams + lparams);
-  }
+  auto params = toString(compile_log.params);
+  auto lparams = toString(compile_log.fusion_executor->lastLaunchParams());
+  benchmark_state.SetLabel(params + lparams);
 
   executor_instance->setMeasureKernelTimeFlag(true);
 
@@ -38,6 +38,7 @@ TensorView* makeContigConcreteTensor(
 
 std::string toString(ReductionParams rparams);
 std::string toString(PointwiseParams params);
+std::string toString(const std::shared_ptr<HeuristicParams>& params);
 std::string toString(LaunchParams lparams);
 
 // Run benchmark iterations with provided inputs. If not segmented, report
@@ -717,8 +717,10 @@ libtorch_cuda_core_sources = [
     "torch/csrc/jit/codegen/cuda/register_interface.cpp",
     "torch/csrc/jit/codegen/cuda/root_domain_map.cpp",
     "torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp",
+    "torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.cpp",
     "torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp",
     "torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp",
+    "torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp",
    "torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp",
     "torch/csrc/jit/codegen/cuda/scheduler/registry.cpp",
     "torch/csrc/jit/codegen/cuda/scheduler/utils.cpp",
@@ -669,6 +669,7 @@ class TestCudaFuser(JitTestCase):
             torch.isreal,
             torch.nn.functional.softplus,
             torch.nn.functional.gelu,
+            torch.nn.functional.leaky_relu,
             torch.nn.functional.silu,
             torch.relu,
             torch.sigmoid,
@@ -4938,6 +4939,28 @@ class TestCudaFuser(JitTestCase):
         t2_jit = torch.jit.script(t2)
         self._run_helper(t2_jit, t2, x0, x1, x2, check_stride=True)
 
+    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
+                     "Requires fusion optimization pass to be effective")
+    def test_type_inference(self):
+        device = "cuda"
+        x0 = torch.randn(10, 128, device=device)
+        x1 = torch.rand_like(x0)
+        x2 = torch.rand_like(x0)
+
+        def t(x0, x1, x2, flag : bool = True):
+            x3 = 2.0 * x0
+            x4 = 2.0 * x1
+            x5 = 2.0 * x2
+            if flag:
+                return torch.stack([x3, x4, x5], dim=-1)
+            # second code path doesn't run through profiling
+            # hence would utilize type inference with profiling information
+            return x0 + x1 + x2
+
+        t_jit = torch.jit.script(t)
+        self._run_helper(t_jit, t, x0, x1, x2, check_stride=True)
+
 
 class TestEnableDisableCudaFuser(JitTestCase):
     def setUp(self):
@@ -466,6 +466,7 @@ NVFUSER_DEFINE_UNARY_OP(relu, Relu)
 NVFUSER_DEFINE_UNARY_OP(round, Round)
 NVFUSER_DEFINE_UNARY_OP(silu, Silu)
 NVFUSER_DEFINE_UNARY_OP(trunc, Trunc)
+NVFUSER_DEFINE_UNARY_OP(print, Print)
 #undef NVFUSER_DEFINE_UNARY_OP
 
 Val* randlike(Val* v) {
@@ -1430,12 +1431,6 @@ WelfordResult::WelfordResult(
   TORCH_INTERNAL_ASSERT(avg->definition()->sameAs(n->definition()));
 }
 
-WelfordResult WelfordResult::rFactor(const std::vector<int>& axes) {
-  auto o_tv = avg->definition()->as<WelfordOp>()->out()->as<TensorView>();
-  auto rf_tvs = o_tv->rFactor(axes, std::vector<TensorView*>{avg, var_sum, n});
-  return WelfordResult{rf_tvs.at(0), rf_tvs.at(1), rf_tvs.at(2)};
-}
-
 // COMPOUND OPERATIONS
 
 // add_alpha
@@ -107,8 +107,6 @@ class TORCH_CUDA_CU_API WelfordResult {
       TensorView* in_avg,
       TensorView* in_var_sum,
       TensorView* in_n);
-
-  WelfordResult rFactor(const std::vector<int>& axes);
 };
 
 //! Welford operator on specified axes. This is currently the only scan op with
@@ -253,6 +251,9 @@ TORCH_CUDA_CU_API TensorView* isposinf(TensorView*);
 // isreal
 TORCH_CUDA_CU_API Val* isreal(Val*);
 TORCH_CUDA_CU_API TensorView* isreal(TensorView*);
+// print
+TORCH_CUDA_CU_API Val* print(Val*);
+TORCH_CUDA_CU_API TensorView* print(TensorView*);
 
 // Broadcasts inp based on bool vector. Size of broadcast bool vector should be
 // the number of dims desired in the broadcasted tensor. This vector should be
@@ -474,7 +474,11 @@ class CudaKernelGenerator : private OptOutConstDispatch {
     }
   }
 
-  void handle(const kir::TensorIndex* ti) final {
+  //! Returns the sum of all indices in a TensorIndex,
+  //! or 0 if the indices vector is empty.
+  //! Used lowering generic tensor index and lowering
+  //! mma fragment indices.
+  std::string genTensorIndex(const kir::TensorIndex* ti) {
     bool first = true;
     std::stringstream index;
     for (auto* ind : ti->indices()) {
@@ -490,12 +494,17 @@ class CudaKernelGenerator : private OptOutConstDispatch {
     if (first) {
       index << "0";
     }
+
+    return index.str();
+  }
+
+  void handle(const kir::TensorIndex* ti) final {
     bool is_volatile = ti->view()->getMemoryType() == MemoryType::Global &&
         kernel_->summary().sync_map.needsRawSync(ti->view()).hasBID();
     if (is_volatile) {
       code_ << "*(volatile " << ti->getDataType().value() << "*)&";
     }
-    code_ << varName(ti->view()) << "[" << index.str() << "]";
+    code_ << varName(ti->view()) << "[" << genTensorIndex(ti) << "]";
   }
 
   void handle(const ViewAsScalar* sv) final {
@@ -621,7 +630,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
     // Double buffered local tensors need indexed initialization,
     // so will need to use `arraySet` option.
     if (out_tv->getMemoryType() == MemoryType::Local &&
-        !out_tv->isDoubleBuffered()) {
+        !(out_tv->isDoubleBuffered() || out_tv->isCircularBuffered())) {
       // Vectorized initialization
       indent() << varName(out_tv) << ".set(" << gen(uop->in()) << ");\n";
     } else {
@@ -971,13 +980,13 @@ class CudaKernelGenerator : private OptOutConstDispatch {
     }
   }
 
-  std::string genArchString(MmaOptions options) {
+  std::string genArchString(MmaOptions::MacroType macro) {
     std::stringstream ss;
-    if (isVolta(options.macro)) {
+    if (isVolta(macro)) {
       ss << "Volta";
-    } else if (isTuring(options.macro)) {
+    } else if (isTuring(macro)) {
       ss << "Turing";
-    } else if (isAmpere(options.macro)) {
+    } else if (isAmpere(macro)) {
       ss << "Ampere";
     } else {
       TORCH_INTERNAL_ASSERT(false, "mma macro unknown arch");
@@ -988,7 +997,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
   std::string genMmaOp(const MmaOp* mma, bool init = false) {
     std::stringstream ss;
     auto options = mma->options();
-    ss << genArchString(options) << "::";
+    ss << genArchString(options.macro) << "::";
     if (init) {
       ss << "init";
     }
@@ -1013,14 +1022,17 @@ class CudaKernelGenerator : private OptOutConstDispatch {
     auto options = mma->options();
     auto in_a = mma->inA()->as<kir::TensorIndex>()->view();
     auto dtype = in_a->getDataType().value();
-    indent() << kTab << "reinterpret_cast<Array<" << dtype << ","
+    indent() << kTab << "&(reinterpret_cast<Array<" << dtype << ","
              << getInputARegisterSize(options.macro) << ","
              << getInputARegisterSize(options.macro) << ">*>(&"
-             << gen(mma->inA()) << "),\n";
-    indent() << kTab << "reinterpret_cast<Array<" << dtype << ","
+             << varName(mma->inA()->as<kir::TensorIndex>()->view()) << ")["
+             << genTensorIndex(mma->inA()->as<kir::TensorIndex>()) << "])"
+             << ",\n";
+    indent() << kTab << "&(reinterpret_cast<Array<" << dtype << ","
              << getInputBRegisterSize(options.macro) << ","
             << getInputBRegisterSize(options.macro) << ">*>(&"
-             << gen(mma->inB()) << ")";
+             << varName(mma->inB()->as<kir::TensorIndex>()->view()) << ")["
+             << genTensorIndex(mma->inB()->as<kir::TensorIndex>()) << "])";
   }
 
   void genMmaInitialization(const MmaOp* mma, const UnaryOp* uop) {
@@ -2332,7 +2344,19 @@ class CudaKernelGenerator : private OptOutConstDispatch {
   }
 
   void handle(const kir::CpAsyncWait* cpasync_wait) final {
-    indent() << "Ampere::cpAsyncBarrier();\n";
+    if (cpasync_wait->keepStages() > 0) {
+      // Perform partial sync, see comment on kir::CpAsyncWait.
+      indent() << "Ampere::cpAsyncPartialBarrier<" << cpasync_wait->keepStages()
+               << ">();\n";
+    } else {
+      // Perform sync all, see comment on kir::CpAsyncWait.
+      indent() << "Ampere::cpAsyncBarrier();\n";
+    }
   }
 
+  void handle(const kir::CpAsyncCommit* cpasync_wait) final {
+    // Commit inflight cp.async transfers. See comment on kir::CpAsyncCommit.
+    indent() << "Ampere::cpAsyncCommit();\n";
+  }
+
   void handle(const kir::GridSync* sync) final {
@@ -2390,11 +2414,9 @@ class CudaKernelGenerator : private OptOutConstDispatch {
 
   void handle(const kir::IntPair* int_pair) {
     const auto def = int_pair->definition();
-    if (print_inline_) {
-      code_ << gen(def);
-    } else {
-      code_ << varName(int_pair);
-    }
+    TORCH_INTERNAL_ASSERT(
+        def != nullptr, "no support for un-inlined int pair yet.");
+    code_ << gen(def);
   }
 
   void handle(const kir::PairSelect* pair_select) {
@@ -14,6 +14,35 @@ namespace jit {
 namespace fuser {
 namespace cuda {
 
+// Simple selector that only propagates across tensor views in the provided
+// unordered_set. Will also propagate to all consumers of those tensors, and the
+// siblings of those tensors.
+class ComputeAtSelector : public MaxInfoSpanningTree::Selector {
+  std::unordered_set<TensorView*> selected_;
+
+ public:
+  virtual bool allowC2P(TensorView* from, TensorView* to) override {
+    return selected_.count(to) > 0;
+  }
+
+  virtual bool allowP2C(TensorView* from, TensorView* to) override {
+    // If the producer is in the selected set, then the consumer must also be
+    // replayed to obtain a compatible loop structure so that this producer
+    // can be consumed in this loop.
+    return selected_.count(from) > 0 || selected_.count(to) > 0;
+  }
+
+  virtual bool allowSibling(TensorView* from, TensorView* to) override {
+    return true;
+  }
+
+  ComputeAtSelector(std::unordered_set<TensorView*> selected)
+      : selected_(std::move(selected)) {}
+  const std::unordered_set<TensorView*>& selected() const {
+    return selected_;
+  }
+};
+
 namespace {
 
 // Wrapper around set_intersection
@@ -182,11 +211,10 @@ void ComputeAt::runAt(
   FusionGuard fg(producer->fusion());
 
   auto selected = getPropagationSubgraph(producer, consumer);
-  InlinePropagatorSelector selector(selected);
+  ComputeAtSelector selector(selected);
 
   InlinePropagator inline_propagator(
       consumer, consumer_position, mode, selector.selected());
-  MaxProducerPosUpdater updater;
 
   MaxRootDomainInfoSpanningTree path(consumer, consumer_position, &selector);
 
@@ -199,7 +227,6 @@ void ComputeAt::runAt(
   }
 
   path.traverse(&inline_propagator);
-  path.traverse(&updater);
 }
 
 void ComputeAt::runWith(
@@ -224,11 +251,10 @@ void ComputeAt::runWith(
   FusionGuard fg(producer->fusion());
 
   auto selected = getPropagationSubgraph(producer, consumer);
-  InlinePropagatorSelector selector(selected);
+  ComputeAtSelector selector(selected);
 
   InlinePropagator inline_propagator(
       producer, producer_position, mode, selector.selected());
-  MaxProducerPosUpdater updater;
 
   MaxRootDomainInfoSpanningTree path(producer, producer_position, &selector);
 
@@ -240,7 +266,6 @@ void ComputeAt::runWith(
     path.traverse(&propagator);
   }
   path.traverse(&inline_propagator);
-  path.traverse(&updater);
 }
 
 } // namespace cuda
@@ -178,6 +178,8 @@ void IterDomainGraph::build(Fusion* fusion) {
           BestEffortReplay::replayPasC(p_tv, c_tv, -1, pairwise_map);
 
       const auto& permissive_c2p_map = permissive_replay_PasC.getReplay();
+      const auto permissive_disjoint_sets =
+          permissive_replay_PasC.getDisjointSets();
 
       // For exact mapings do not map any broadcast dimensions to
       // non-broadcast dimensions. Prevent any broadcasted axes being mapped
@@ -213,6 +215,17 @@ void IterDomainGraph::build(Fusion* fusion) {
         auto p_id = entry.second;
         if (idIsAComputeAtLeafDomain(p_id, p_tv)) {
           loop_nodes_.mapEntries(c_id, p_id);
+        } else {
+          // When there are trivial reductions merged with other dims, `p_id`
+          // might not be a compute at leaf domain of `p_tv`, but it actually
+          // has an equivalent compute at leaf domain. For that case, we map
+          // the equivalent compute at leaf domain.
+          for (int i = 0; i < p_tv->getComputeAtPosition(); i++) {
+            auto id = p_tv->axis(i);
+            if (permissive_disjoint_sets.permissiveAreMapped(p_id, id)) {
+              loop_nodes_.mapEntries(c_id, id);
+            }
+          }
         }
         permissive_nodes_.mapEntries(c_id, p_id);
         consumers_.at(p_id).pushBack(c_id);
@@ -225,8 +238,8 @@ void IterDomainGraph::build(Fusion* fusion) {
       mapMaybeSwizzleOp(permissive_nodes_, c_id);
     }
 
-    // Make sure we always get root mapping for the permissive map. Because
-    // of forwarding we could otherwise miss some root mappings.
+    // Make sure we always get root mapping for the permissive map.
+    // Because of forwarding we could otherwise miss some root mappings.
     for (auto entry : permissive_c2p_root_map) {
      auto c_id = entry.first;
       auto p_id = entry.second;
@@ -96,6 +96,47 @@ class VectorOfUniqueEntries {
     return set_.find(entry) != set_.end();
   }
 
+  // Erase given entry from the containers if
+  // there is a match.
+  void erase(T entry) {
+    vector_.erase(
+        std::remove_if(
+            vector_.begin(),
+            vector_.end(),
+            [entry](T val) { return val == entry; }),
+        vector_.end());
+
+    set_.erase(entry);
+  }
+
+  // Insert elements at the end of the container.
+  template <typename InputIt>
+  void insert(InputIt begin, InputIt end) {
+    for (auto it = begin; it != end; it++) {
+      pushBack(*it);
+    }
+  }
+
+  // Returns iterator pointing to the beginning of vector container
+  auto begin() const {
+    return vector().begin();
+  }
+
+  // Returns iterator pointing to the end of vector container
+  auto end() const {
+    return vector().end();
+  }
+
+  // Returns iterator pointing to the beginning of vector container
+  auto begin() {
+    return vector().begin();
+  }
+
+  // Returns iterator pointing to the end of vector container
+  auto end() {
+    return vector().end();
+  }
+
   std::string toString() {
     std::stringstream ss;
     ss << "{ ";
@@ -163,6 +163,9 @@ void Expr::dispatch(T handler, Expr* expr) {
     case ExprType::CpAsyncWait:
       ptr(handler)->handle(expr->as<kir::CpAsyncWait>());
       return;
+    case ExprType::CpAsyncCommit:
+      ptr(handler)->handle(expr->as<kir::CpAsyncCommit>());
+      return;
     case ExprType::InitMagicZero:
       ptr(handler)->handle(expr->as<kir::InitMagicZero>());
       return;
@@ -334,6 +337,9 @@ void Expr::constDispatch(T handler, const Expr* expr) {
     case ExprType::CpAsyncWait:
       ptr(handler)->handle(expr->as<kir::CpAsyncWait>());
       return;
+    case ExprType::CpAsyncCommit:
+      ptr(handler)->handle(expr->as<kir::CpAsyncCommit>());
+      return;
     case ExprType::InitMagicZero:
       ptr(handler)->handle(expr->as<kir::InitMagicZero>());
       return;
@@ -513,6 +519,9 @@ void Expr::mutatorDispatch(T mutator, Expr* expr) {
     case ExprType::CpAsyncWait:
       ptr(mutator)->mutate(expr->as<kir::CpAsyncWait>());
      return;
+    case ExprType::CpAsyncCommit:
+      ptr(mutator)->mutate(expr->as<kir::CpAsyncCommit>());
+      return;
     case ExprType::InitMagicZero:
       ptr(mutator)->mutate(expr->as<kir::InitMagicZero>());
       return;
@@ -757,6 +766,9 @@ void OptOutConstDispatch::handle(const kir::GridSync* stmt) {
 void OptOutConstDispatch::handle(const kir::CpAsyncWait* stmt) {
   unhandled(stmt);
 }
+void OptOutConstDispatch::handle(const kir::CpAsyncCommit* stmt) {
+  unhandled(stmt);
+}
 void OptOutConstDispatch::handle(const kir::InitMagicZero* stmt) {
   unhandled(stmt);
 }
@@ -898,6 +910,9 @@ void OptOutDispatch::handle(kir::GridSync* stmt) {
 void OptOutDispatch::handle(kir::CpAsyncWait* stmt) {
   unhandled(stmt);
 }
+void OptOutDispatch::handle(kir::CpAsyncCommit* stmt) {
+  unhandled(stmt);
+}
 void OptOutDispatch::handle(kir::InitMagicZero* stmt) {
   unhandled(stmt);
 }
@@ -98,6 +98,7 @@ class Allocate;
 class BlockSync;
 class GridSync;
 class CpAsyncWait;
+class CpAsyncCommit;
 class ForLoop;
 class IfThenElse;
 class GridReduction;
@@ -163,6 +164,7 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
   virtual void handle(const kir::BlockSync*);
   virtual void handle(const kir::GridSync*);
   virtual void handle(const kir::CpAsyncWait*);
+  virtual void handle(const kir::CpAsyncCommit*);
   virtual void handle(const kir::InitMagicZero*);
   virtual void handle(const kir::UpdateMagicZero*);
   virtual void handle(const kir::ForLoop*);
@@ -225,6 +227,7 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
   virtual void handle(kir::BlockSync* stmt);
   virtual void handle(kir::GridSync* stmt);
   virtual void handle(kir::CpAsyncWait* stmt);
+  virtual void handle(kir::CpAsyncCommit* stmt);
   virtual void handle(kir::InitMagicZero* stmt);
   virtual void handle(kir::UpdateMagicZero* stmt);
   virtual void handle(kir::ForLoop* stmt);
@@ -328,6 +331,7 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
   virtual void mutate(kir::BlockSync*);
   virtual void mutate(kir::GridSync*);
   virtual void mutate(kir::CpAsyncWait*);
+  virtual void mutate(kir::CpAsyncCommit*);
   virtual void mutate(kir::InitMagicZero*);
   virtual void mutate(kir::UpdateMagicZero*);
   virtual void mutate(kir::ForLoop*);
@@ -93,7 +93,9 @@ std::string FusionExecutor::getStructuredCode(const std::string& kernel) {
     std::cout << "\n======= Codegen output for kernel: " << kernelName()
               << " =======\n\n"
               << code << "\n======================================\n\n";
-  } else if (isDebugDumpEnabled(DebugDumpOption::CudaToFile)) {
+  }
+  if (isDebugDumpEnabled(DebugDumpOption::CudaToFile) ||
+      isDebugDumpEnabled(DebugDumpOption::DebugInfo)) {
     std::stringstream file_name;
     file_name << "__tmp_kernel" << fusion_id_ << ".cu";
     std::cout << "PRINTING: " << file_name.str() << std::endl;
@@ -799,7 +799,7 @@ kir::ExpressionEvaluator bindKernelInputs(
           extent->toString(),
           " to ",
           value,
-          "but it's already set to ",
+          " but it's already set to ",
           *prev_value);
       should_bind = false;
     }
@@ -925,9 +925,12 @@ std::pair<NvrtcFunction, std::string> nvrtcCompile(
   nvrtcProgram program; // NOLINT(cppcoreguidelines-init-variables)
 
   {
+    std::stringstream ss;
+    ss << "__tmp_kernel" << id << ".cu";
+    std::string name = ss.str();
     FUSER_PERF_SCOPE("executor_utils::NvrtcCreateProgram");
     AT_CUDA_NVRTC_CHECK(at::globalContext().getNVRTC().nvrtcCreateProgram(
-        &program, code.c_str(), nullptr, 0, nullptr, nullptr));
+        &program, code.c_str(), name.c_str(), 0, nullptr, nullptr));
   }
 
   ResourceGuard holdProgram([&] {
@@ -974,11 +977,13 @@ std::pair<NvrtcFunction, std::string> nvrtcCompile(
     args.push_back("--fmad=true");
   }
 #endif
 
-#ifndef NDEBUG
-  // Add line info to generated kernels
-  args.push_back("-lineinfo");
-#else
+  if (isDebugDumpEnabled(DebugDumpOption::DebugInfo)) {
+    args.push_back("-lineinfo");
+    args.push_back("-G");
+    args.push_back("--dopt=on");
+  }
+#ifdef NDEBUG
   // Avoid excessive register usage from assertion
   args.push_back("-DNDEBUG");
 #endif
@@ -1,5 +1,6 @@
 #include <torch/csrc/jit/codegen/cuda/arith.h>
 #include <torch/csrc/jit/codegen/cuda/codegen.h>
+#include <torch/csrc/jit/codegen/cuda/disjoint_set.h>
 #include <torch/csrc/jit/codegen/cuda/fusion.h>
 #include <torch/csrc/jit/codegen/cuda/fusion_segmenter.h>
 #include <torch/csrc/jit/codegen/cuda/instrumentation.h>
@@ -229,15 +230,6 @@ void Fusion::addOutput(Val* output) {
   all_tv_uses_valid_ = false;
 }
 
-void Fusion::addOutput(WelfordResult& wr) {
-  // Want to always make sure the avg gets added last
-  // since avg will be the out() value of welfordOp,
-  // and want to make it the top of the computeAt chain
-  addOutput(wr.var_sum);
-  addOutput(wr.n);
-  addOutput(wr.avg);
-}
-
 void Fusion::removeInput(Val* input) {
   auto find_input = std::find(inputs_.begin(), inputs_.end(), input);
   if (find_input != inputs_.end()) {
@@ -516,7 +508,19 @@ std::vector<Val*> Fusion::usedMathVals() {
   return used_math_vals;
 }
 
-std::unordered_set<Expr*> Fusion::unordered_uses(Val* val) const {
+std::vector<Val*> Fusion::terminatingMathVals() {
+  VectorOfUniqueEntries<Val*> result;
+  auto used_vals = usedMathVals();
+  for (auto v : used_vals) {
+    // Locate the vals that are not expr outputs but have valid definitions.
+    if (unordered_uses(v).empty() && v->definition() != nullptr) {
+      result.pushBack(v);
+    }
+  }
+  return result.vector();
+}
+
+std::unordered_set<Expr*> Fusion::unordered_uses(const Val* val) const {
   return std::unordered_set<Expr*>(val->uses().begin(), val->uses().end());
 }
 
@@ -110,9 +110,6 @@ class TORCH_CUDA_CU_API Fusion : public IrContainer {
   //! Register output as an output of the fusion
   void addOutput(Val* output);
 
-  //! Register output as an output of the fusion
-  void addOutput(WelfordResult& output);
-
   //! Deregister input as an input of the fusion
   void removeInput(Val* input);
 
@@ -153,8 +150,16 @@ class TORCH_CUDA_CU_API Fusion : public IrContainer {
   //! also included as they must show up in the final code.
   std::vector<Val*> usedMathVals();
 
+  //! Returns all vals that are produced by used math expressions and
+  //! also do not have further consumers.
+  //!
+  //! In the case of an active multi-output expressions, the returned vector
+  //! will include the expression outputs that did not lead to an fusion
+  //! output.
+  std::vector<Val*> terminatingMathVals();
+
   //! Return all Exprs that use val
-  std::unordered_set<Expr*> unordered_uses(Val* val) const;
+  std::unordered_set<Expr*> unordered_uses(const Val* val) const;
 
   //! Return the Expr that produces val
   Expr* definition(const Val* val) const;
@@ -16,6 +16,12 @@ namespace jit {
 namespace fuser {
 namespace cuda {
 
+namespace {
+
+using GroupSet = VectorOfUniqueEntries<SegmentedGroup*>;
+
+} // namespace
+
 std::vector<SegmentedGroup::NeighborGroup> SegmentedGroup::getNeighborGroups() {
   std::vector<NeighborGroup> neighbors;
   for (auto inp : producer_edges) {
@@ -75,7 +81,7 @@ std::vector<SegmentedGroup::NeighborGroup> SegmentedGroup::
     return {};
   }
 
-  std::vector<bool> can_merge(true, neighbors.size());
+  std::vector<bool> can_merge(neighbors.size(), true);
 
   // Find neighbors with a level that is only 1 differant than this groups level
   for (const auto i : c10::irange(neighbors.size())) {
@@ -155,16 +161,16 @@ void insertUniquePredicated(
     std::vector<Val*>& v,
     const std::vector<SegmentedEdge*>& e,
    PREDICATE pred) {
-  std::unordered_set<Val*> to_add;
-  std::transform(
-      e.cbegin(),
-      e.cend(),
-      std::inserter(to_add, to_add.end()),
-      [](SegmentedEdge* se) { return se->val; });
+  VectorOfUniqueEntries<Val*> to_add;
+  for (auto edge : e) {
+    to_add.pushBack(edge->val);
+  }
+
   std::copy_if(
-      to_add.begin(), to_add.end(), std::back_inserter(v), [pred](Val* val) {
-        return pred(val);
-      });
+      to_add.vector().begin(),
+      to_add.vector().end(),
+      std::back_inserter(v),
+      [pred](Val* val) { return pred(val); });
 }
 
 void SegmentedGroup::finalize() {
@@ -811,7 +817,6 @@ void SegmentedFusion::finalize() {
 //! currently O(n^2). O(nlogn) would be a reasonable
 //! goal to achieve.
 class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
-  using GroupSet = std::unordered_set<SegmentedGroup*>;
   using GroupSetOwningPtr = std::unique_ptr<GroupSet>;
   using DependencyMap = std::unordered_map<SegmentedGroup*, GroupSetOwningPtr>;
 
@@ -829,7 +834,7 @@ class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
       const std::vector<SegmentedGroup*>& groups_to_check) {
     auto& producers_of_group = getAllKnownProducersSet(group);
     for (const auto& potential_producer : groups_to_check) {
-      if (producers_of_group->count(potential_producer)) {
+      if (producers_of_group->has(potential_producer)) {
        return true;
       }
     }
@@ -841,7 +846,7 @@ class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
    if (it == known_producers_of_.end()) {
      return false;
    }
-    return it->second->count(b);
+    return it->second->has(b);
  }
 
  bool isProducerOf(SegmentedGroup* a, SegmentedGroup* b) {
@@ -872,18 +877,14 @@ class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
     GroupSet values_between;
     auto& all_producers_of_consumer = known_producers_of_.at(consumer);
     TORCH_INTERNAL_ASSERT(
-        all_producers_of_consumer->count(producer),
+        all_producers_of_consumer->has(producer),
         "Fusion segment: Trying to compute path between two nodes that are not producer-consumer pairs");
 
-    std::copy_if(
-        all_producers_of_consumer->begin(),
-        all_producers_of_consumer->end(),
-        std::inserter(values_between, values_between.end()),
-        [this, producer](SegmentedGroup* producer_of_consumer) {
-          // Checks if producer is on the producer path of this intermediate
-          // node
-          return known_producers_of_.at(producer_of_consumer)->count(producer);
-        });
+    for (auto producer_of_consumer : *all_producers_of_consumer) {
+      if (known_producers_of_.at(producer_of_consumer)->has(producer)) {
+        values_between.pushBack(producer_of_consumer);
+      }
+    }
 
     return values_between;
   }
@@ -892,7 +893,7 @@ class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
   //! used for generating assertions after transforms
   bool isproducerMapDAG() const {
     for (auto& it : known_producers_of_) {
-      if (it.second->count(it.first)) {
+      if (it.second->has(it.first)) {
        return false;
      }
    }
@@ -909,7 +910,7 @@ class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
  void addConsumersToWorkList(SegmentedGroup* producer, GroupSet& to_visit) {
    for (auto e : producer->consumer_edges) {
      // A consumer wouldn't have been worked before any of its producer
-      to_visit.insert(e->to);
+      to_visit.pushBack(e->to);
    }
  }
 
@@ -922,7 +923,7 @@ class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
      SegmentedGroup* from) {
    auto& producer_set_to_merge = *getAllKnownProducersSet(from);
    for (auto group : producer_set_to_merge) {
-      getAllKnownProducersSet(into)->insert(group);
+      getAllKnownProducersSet(into)->pushBack(group);
    }
  }
 
@@ -943,8 +944,8 @@ class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
 
    GroupSet intersection;
    for (auto group : smaller_group_set) {
-      if (bigger_group_set.count(group)) {
-        intersection.insert(group);
+      if (bigger_group_set.has(group)) {
+        intersection.pushBack(group);
      }
    }
    return intersection;
@@ -956,7 +957,7 @@ class GroupDependencyAnalysis : public NonCopyable, public SegmenterAnalysis {
 };
 
 //! Finds the common producers of given set of groups
-GroupDependencyAnalysis::GroupSet GroupDependencyAnalysis::getCommonProducersOf(
+GroupSet GroupDependencyAnalysis::getCommonProducersOf(
     std::vector<SegmentedGroup*> groups) {
  if (groups.empty()) {
    return {};
@@ -1006,9 +1007,9 @@ void GroupDependencyAnalysis::mergeGroups(
  // update producer maps of other groups
  for (auto& it : known_producers_of_) {
    // for all groups that are produced by either a or b
-    if (it.second->count(a) || it.second->count(b)) {
+    if (it.second->has(a) || it.second->has(b)) {
      // insert ab as the new producer
-      it.second->insert(ab);
+      it.second->pushBack(ab);
      // all producers of both a and b are now producers of `it`
      mergeAllKnownProducersIntoFrom(it.first, ab);
    }
@@ -1054,7 +1055,7 @@ void GroupDependencyAnalysis::mergeGroups(
        it.second->erase(merged_producer);
      }
      // insert the new group as producer
-      it.second->insert(merged);
+      it.second->pushBack(merged);
    }
  }
 }
@@ -1068,11 +1069,11 @@ void GroupDependencyAnalysis::computeAllProducers() {
 
  // Collect source nodes, with no producers we are guaranteed
  // a source node on a DAG
-  std::copy_if(
-      segmented_fusion_->cgroups().begin(),
-      segmented_fusion_->cgroups().end(),
-      std::inserter(visited, visited.end()),
-      [](SegmentedGroup* group) { return group->producer_edges.empty(); });
+  for (auto group : segmented_fusion_->cgroups()) {
+    if (group->producer_edges.empty()) {
+      visited.pushBack(group);
+    }
+  }
 
  // visited now only contain source nodes
  // they can go backward to nowhere
@@ -1086,20 +1087,18 @@ void GroupDependencyAnalysis::computeAllProducers() {
    if (std::all_of(
            visiting_group->producer_edges.begin(),
            visiting_group->producer_edges.end(),
-            [&visited](SegmentedEdge* e) {
-              return visited.count(e->from);
-            })) {
+            [&visited](SegmentedEdge* e) { return visited.has(e->from); })) {
      // filter multi-edges
      GroupSet producers_of_visiting_group;
      for (auto edge : visiting_group->producer_edges) {
-        producers_of_visiting_group.insert(edge->from);
+        producers_of_visiting_group.pushBack(edge->from);
      }
 
      // populate all possible paths
      // from producer backward, including
      // the producer
      for (auto producer : producers_of_visiting_group) {
-        getAllKnownProducersSet(visiting_group)->insert(producer);
+        getAllKnownProducersSet(visiting_group)->pushBack(producer);
        mergeAllKnownProducersIntoFrom(visiting_group, producer);
      }
      to_update = visiting_group;
@@ -1109,7 +1108,7 @@ void GroupDependencyAnalysis::computeAllProducers() {
    if (to_update) {
      addConsumersToWorkList(to_update, to_visit);
      to_visit.erase(to_update);
-      visited.insert(to_update);
+      visited.pushBack(to_update);
    } else {
      TORCH_INTERNAL_ASSERT(false, "unreachable, original graph not a DAG");
    }
@@ -2060,7 +2059,6 @@ bool SegmentCandidateFinder::TranslateWelfordInFusion(
 //! This pass tries to merge nodes with the same reduction type based
 //! on the graph structure.
 class CombineReductions {
-  using GroupSet = std::unordered_set<SegmentedGroup*>;
  using GroupVec = std::vector<SegmentedGroup*>;
  class ReductionSignature;
 
@@ -2240,7 +2238,7 @@ class CombineReductions {
            groups_with_reductions_.begin(),
            groups_with_reductions_.end(),
            [&all_groups_to_merge](SegmentedGroup* group) {
-              return all_groups_to_merge.count(group);
+              return all_groups_to_merge.has(group);
            }),
        groups_with_reductions_.end());
 
@@ -2374,7 +2372,7 @@ class CombineReductions {
            groups_with_reductions_.begin(),
            groups_with_reductions_.end(),
            [&groups_to_merge_set](SegmentedGroup* group) {
-              return groups_to_merge_set.count(group);
+              return groups_to_merge_set.has(group);
            }),
        groups_with_reductions_.end());
 
@@ -2414,8 +2412,8 @@ class CombineReductions {
            maybe_consumer, maybe_producer)) {
      auto groups_to_check =
          dependency_analysis->valuesBetween(maybe_producer, maybe_consumer);
-      groups_to_check.insert(maybe_producer);
-      groups_to_check.insert(maybe_consumer);
+      groups_to_check.pushBack(maybe_producer);
+      groups_to_check.pushBack(maybe_consumer);
 
      // Check that either no group has a reduction or all groups have the same
      // reduction signature
@@ -2428,13 +2426,13 @@ class CombineReductions {
      // output edge does not generate much saving of global memory access
      // we want to postpone merging these edges till the very final pass
      for (auto producer_edge_of_group : group->producer_edges) {
-        if (groups_to_check.count(producer_edge_of_group->from) &&
+        if (groups_to_check.has(producer_edge_of_group->from) &&
            producer_edge_of_group->val->isFusionOutput()) {
          return {};
        }
      }
      for (auto consumer_edge_of_group : group->consumer_edges) {
-        if (groups_to_check.count(consumer_edge_of_group->to) &&
+        if (groups_to_check.has(consumer_edge_of_group->to) &&
            consumer_edge_of_group->val->isFusionOutput()) {
          return {};
        }
@@ -2653,11 +2651,11 @@ void SegmentCandidateFinder::findSegments() {
 
  // Expressions to exclude from segmentation because they're just derived from
  // unary ops on inputs to the complete fusion
-  std::unordered_set<Expr*> excluded_inp_unary_exprs;
+  VectorOfUniqueEntries<Expr*> excluded_inp_unary_exprs;
 
  // "Terminating" outputs from the excluded input unary exprs, these will be
  // treated as complete fusion inputs.
-  std::unordered_set<Val*> forwarded_inputs;
+  VectorOfUniqueEntries<Val*> forwarded_inputs;
  {
    std::deque<Expr*> to_visit;
    for (auto inp : completeFusion()->inputs()) {
@@ -2677,8 +2675,8 @@ void SegmentCandidateFinder::findSegments() {
    }
 
    if (expr->output(0)->uses().size() > 1) {
-      excluded_inp_unary_exprs.emplace(expr);
-      forwarded_inputs.emplace(expr->output(0));
+      excluded_inp_unary_exprs.pushBack(expr);
+      forwarded_inputs.pushBack(expr->output(0));
      continue;
    }
 
@@ -2735,7 +2733,7 @@ void SegmentCandidateFinder::findSegments() {
      continue;
    }
 
-    if (excluded_inp_unary_exprs.count(expr)) {
+    if (excluded_inp_unary_exprs.has(expr)) {
      continue;
    }
 
@@ -527,13 +527,43 @@ void IndexCompute::handle(Swizzle2D* swizzle_2d) {
  const auto out_x_ind = out_x_it->second;
  const auto out_y_ind = out_y_it->second;
 
-  // Actual swizzle operation is handled via IndexSwizzle pass
-  // all behavior in this pass is directly forward through the
-  // index and extent.
-  index_map_[in_x_id] = out_x_ind;
-  index_map_[in_y_id] = out_y_ind;
-  extent_map_[in_y_id] = getExtent(out_y_id);
-  extent_map_[in_x_id] = getExtent(out_x_id);
+  if (swizzle_mode_ == SwizzleMode::NoSwizzle ||
+      swizzle_mode_ != swizzle_2d->swizzleMode()) {
+    // Handle inactive swizzles by just passing through index
+    // and extend information.
+
+    TORCH_INTERNAL_ASSERT(
+        index_map_.count(in_x_id) == index_map_.count(in_y_id),
+        "input index should be either both defined or both undefined");
+    if (index_map_.count(in_x_id)) {
+      // Only propagate original index through if
+      // the input index hasn't been computed.
+      // TODO:
+      // This part should be cleaner once we remove the
+      // second index traversal pass.
+      return;
+    }
+    index_map_[in_x_id] = out_x_ind;
+    index_map_[in_y_id] = out_y_ind;
+    extent_map_[in_y_id] = getExtent(out_y_id);
+    extent_map_[in_x_id] = getExtent(out_x_id);
+  } else {
+    // Generate integer swizzle math if the
+    // swizzle is activated. See also
+    // [Note on swizzle mode].
+
+    auto out_pair = IrBuilder::swizzle2DIntExpr(
+        out_x_ind,
+        out_y_ind,
+        getExtent(out_x_id),
+        getExtent(out_y_id),
+        swizzle_2d->swizzleType());
+
+    index_map_[in_x_id] =
+        IrBuilder::pairSelectExpr(out_pair, kir::PairSelect::Selection::X);
+    index_map_[in_y_id] =
+        IrBuilder::pairSelectExpr(out_pair, kir::PairSelect::Selection::Y);
+  }
 }
 
 void IndexCompute::handle(Expr* e) {
@@ -616,9 +646,31 @@ IndexCompute::IndexCompute(
      reference_halo_extent_map_(std::move(reference_halo_extent_map)) {
  FUSER_PERF_SCOPE("GpuLower::Lower::IndexCompute::IndexCompute");
  concrete_id_pass_ = true;
+  swizzle_mode_ = SwizzleMode::Loop;
 }
 
 void IndexCompute::run(const LoopIndexing& loop_indexing) {
+  // Apply loop swizzles if there are any that outputs to
+  // the loop domains.
+  // Currently only support loop swizzles that directly output
+  // to concrete loop domains and these are validated in
+  // validate swizzle pass.
+  // TODO:
+  // will gradually enable replaying and mapping of loop
+  // swizzles in the IR infrastructure and once that's piped
+  // through this part of logic will be removed.
+  std::unordered_set<Expr*> visited;
+  for (auto loop_id : loop_indexing.loopDomains()) {
+    auto loop_id_def = loop_id->definition();
+    if (loop_id_def != nullptr && loop_id_def->isA<Swizzle2D>()) {
+      if (visited.insert(loop_id_def).second) {
+        handle(loop_id_def);
+      }
+    }
+  }
+
+  // Run through the loop indexing expressions and generate
+  // the indexing integer math for the concrete ids.
  for (auto expr : loop_indexing.getBackwardExprList()) {
    handle(expr);
  }
@@ -955,6 +1007,7 @@ void IndexSwizzle::run() {
    UpdateLeafIndices update_leaves(td_, indexMap(), extentMap());
    index_map_ = update_leaves.indexMap();
    extent_map_ = update_leaves.extentMap();
+    IndexCompute::swizzle_mode_ = SwizzleMode::Data;
    IndexCompute::run();
  }
 }
@@ -969,7 +1022,8 @@ void IndexSwizzle::handle(Expr* e) {
        return swizzled_ids_.find(id) != swizzled_ids_.end();
      }) ||
      (e->isA<Swizzle2D>() &&
-       e->as<Swizzle2D>()->swizzleType() != Swizzle2DType::NoSwizzle);
+       e->as<Swizzle2D>()->swizzleType() != Swizzle2DType::NoSwizzle &&
+       e->as<Swizzle2D>()->swizzleMode() == SwizzleMode::Data);
  if (!needs_update) {
    return;
  }
@@ -983,8 +1037,6 @@ void IndexSwizzle::handle(Swizzle2D* swizzle_2d) {
  auto out_x_id = swizzle_2d->outX();
  auto out_y_id = swizzle_2d->outY();
-  auto in_x_id = swizzle_2d->inX();
-  auto in_y_id = swizzle_2d->inY();
 
  auto out_x_it = index_map_.find(out_x_id);
  auto out_y_it = index_map_.find(out_y_id);
 
@@ -998,28 +1050,7 @@ void IndexSwizzle::handle(Swizzle2D* swizzle_2d) {
      out_x_it != index_map_.end() && out_y_it != index_map_.end(),
      "Swizzle output indices were not propagated through");
 
-  const auto out_x_ind = out_x_it->second;
-  const auto out_y_ind = out_y_it->second;
-
-  // Can propagate zero only for a few
-  // swizzle types (TODO)
-
-  if (swizzle_2d->swizzleType() != Swizzle2DType::NoSwizzle) {
-    auto out_pair = IrBuilder::swizzle2DIntExpr(
-        out_x_ind,
-        out_y_ind,
-        getExtent(out_x_id),
-        getExtent(out_y_id),
-        swizzle_2d->swizzleType());
-
-    index_map_[in_x_id] =
-        IrBuilder::pairSelectExpr(out_pair, kir::PairSelect::Selection::X);
-    index_map_[in_y_id] =
-        IrBuilder::pairSelectExpr(out_pair, kir::PairSelect::Selection::Y);
-
-    swizzled_ids_.insert(in_x_id);
-    swizzled_ids_.insert(in_y_id);
-  }
+  IndexCompute::handle(swizzle_2d);
 }
 
 // Used for local and shared index mapping. Returns a map from loops
@ -1125,7 +1156,16 @@ indexMapFromTV(
|
|||
// Similarly for local memory tensors, zero replacement can be
|
||||
// only done when there's a matching domain with the same
|
||||
// parallel type
|
||||
-      (loop->iter_domain()->isThread() && is_local && same_parallel_type)) {
+      (loop->iter_domain()->isThread() && is_local && same_parallel_type) ||
|
||||
// MMA operands are currently indexed in units of "fragments",
|
||||
// so each mma tensor domain would be zero-ed and the tensor index
|
||||
// calculated here would be the fragment index.
|
||||
// TODO: This is a quick WAR to enable iterating over a register array
|
||||
// of MMA fragments, so we could generate unrolled mma loops.
|
||||
// Eventually we still want IdGraph to be able to analyze the
|
||||
// in-register layout of mma fragments for more unified indexing math
|
||||
// as well as more flexibility in swizzling loops.
|
||||
(loop->iter_domain()->isMma() && !as_consumer)) {
|
||||
idx = GpuLower::current()->kernel()->zeroVal();
|
||||
zero_loops.insert(loop);
|
||||
} else {
|
||||
|
|
@ -1139,8 +1179,11 @@ indexMapFromTV(
|
|||
}
|
||||
|
||||
if (loop == double_buffer_loop) {
|
||||
auto stage_depth =
|
||||
GpuLower::current()->doubleBufferInfo().getStageDepthFor(
|
||||
loop->iter_domain());
|
||||
idx = SimplifyingIrBuilder::addExpr(
|
||||
-      idx, GpuLower::current()->kernel()->oneVal());
+      idx, SimplifyingIrBuilder::create<Int>(stage_depth - 1));
|
||||
}
|
||||
|
||||
loop_to_ind_map[loop] = idx;
|
||||
|
|
@ -1802,14 +1845,16 @@ std::vector<Val*> Index::getNonGlobalProducerStridedIndices(
|
|||
}
|
||||
}
|
||||
|
||||
-  if (producer_tv->isDoubleBuffered()) {
+  if (producer_tv->isDoubleBuffered() || producer_tv->isCircularBuffered()) {
|
||||
auto db_loop = gpu_lower->doubleBufferInfo().getDoubleBufferLoop(
|
||||
producer_tv, loops, true);
|
||||
if (db_loop != nullptr) {
|
||||
auto stage_depth = gpu_lower->doubleBufferInfo().getStageDepthFor(
|
||||
db_loop->iter_domain());
|
||||
auto loop_index =
|
||||
db_loop->isTrivial() ? db_loop->start() : db_loop->index();
|
||||
auto db_switch_index = SimplifyingIrBuilder::modExpr(
|
||||
-          loop_index, SimplifyingIrBuilder::create<Int>(2));
+          loop_index, SimplifyingIrBuilder::create<Int>(stage_depth));
|
||||
auto original_alloc_size =
|
||||
gpu_lower->doubleBufferInfo().getOriginalAllocSize(producer_tv);
|
||||
auto db_strided_index =
|
||||
|
|
@ -2068,14 +2113,36 @@ std::vector<Val*> Index::getNonGlobalConsumerStridedIndices(
|
|||
TORCH_INTERNAL_ASSERT(
|
||||
strided_inds.size() == consumer_tv->getMaybeRFactorDomain().size());
|
||||
|
||||
-  if (consumer_tv->isDoubleBuffered()) {
-    auto db_loop = gpu_lower->doubleBufferInfo().getDoubleBufferLoop(
-        consumer_tv, loops, true);
-    if (db_loop != nullptr) {
-      auto db_switch_index = SimplifyingIrBuilder::subExpr(
-          gpu_lower->kernel()->oneVal(),
-          SimplifyingIrBuilder::modExpr(
-              db_loop->index(), SimplifyingIrBuilder::create<Int>(2)));
+  if (consumer_tv->isDoubleBuffered() || consumer_tv->isCircularBuffered()) {
+    auto db_loop =
+        gpu_lower->doubleBufferInfo().getDoubleBufferLoop(consumer_tv, loops);
|
||||
auto stage_depth =
|
||||
gpu_lower->doubleBufferInfo().getStageDepthFor(db_loop->iter_domain());
|
||||
bool is_circular_buffer_loop = stage_depth > 2;
|
||||
bool is_prolog =
|
||||
db_loop->doubleBufferLoopStage() == DoubleBufferLoopStage::Prolog;
|
||||
|
||||
Val* db_switch_index = nullptr;
|
||||
|
||||
// In the double-buffered case we don't materialize the prolog loop, as
// there will be only one iteration. In the circular-buffer case we do
// materialize the prolog loop, covering the first N-1 iterations, N being
// the stage depth.
|
||||
if (!is_prolog || is_circular_buffer_loop) {
|
||||
if (is_prolog && is_circular_buffer_loop) {
|
||||
// The buffer switching logic is the same as original index
|
||||
// in the case of circular buffer prolog.
|
||||
db_switch_index = db_loop->index();
|
||||
} else {
|
||||
// Switching index generated for main loop or epilog component.
|
||||
db_switch_index = SimplifyingIrBuilder::modExpr(
|
||||
SimplifyingIrBuilder::addExpr(
|
||||
db_loop->index(),
|
||||
SimplifyingIrBuilder::create<Int>(stage_depth - 1)),
|
||||
SimplifyingIrBuilder::create<Int>(stage_depth));
|
||||
}
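To make the switching index concrete, the standalone sketch below evaluates the same formulas on plain integers for an assumed stage depth of 3: the circular-buffer prolog reuses the raw loop index, while the main/epilog component uses `(i + stage_depth - 1) % stage_depth`. This only illustrates the arithmetic, not the generated kernel code.

```cpp
#include <cstdio>

int main() {
  constexpr int kStageDepth = 3; // assumed circular buffering depth

  // Prolog covers the first stage_depth - 1 iterations; the switching index
  // is simply the loop index, i.e. slots 0, 1, ..., stage_depth - 2.
  for (int i = 0; i < kStageDepth - 1; ++i) {
    std::printf("prolog iter %d -> slot %d\n", i, i);
  }

  // Main/epilog component: the switching index lags the loop index by one
  // slot modulo stage_depth.
  for (int i = 0; i < 6; ++i) {
    int slot = (i + kStageDepth - 1) % kStageDepth;
    std::printf("main iter %d -> slot %d\n", i, slot);
  }
  return 0;
}
```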
// Use the generated switching buffer index to access the buffer space.
|
||||
auto original_alloc_size =
|
||||
gpu_lower->doubleBufferInfo().getOriginalAllocSize(consumer_tv);
|
||||
auto db_strided_index =
|
||||
|
|
@ -2110,7 +2177,8 @@ std::vector<Val*> Index::getProducerStridedIndices(
|
|||
TORCH_INTERNAL_ASSERT(
|
||||
strided_indices.size() ==
|
||||
producer->getMaybeRFactorDomain().size() +
|
||||
-      (producer->isDoubleBuffered() ? 1 : 0));
+      (producer->isDoubleBuffered() || producer->isCircularBuffered() ? 1
+                                                                      : 0));
|
||||
|
||||
return strided_indices;
|
||||
}
|
||||
|
|
@ -2721,8 +2789,8 @@ std::pair<Val*, Val*> hoistPredicates(
|
|||
Val* stop_index,
|
||||
const std::vector<kir::ForLoop*>& loops,
|
||||
std::vector<IterDomain*> loop_domains,
|
||||
-    const std::unordered_map<IterDomain*, Val*> start_initial_loop_index_map,
-    const std::unordered_map<IterDomain*, Val*> stop_initial_loop_index_map,
+    const std::unordered_map<IterDomain*, Val*>& start_initial_loop_index_map,
+    const std::unordered_map<IterDomain*, Val*>& stop_initial_loop_index_map,
|
||||
kir::ForLoop* unswitch_or_vec_loop,
|
||||
IterDomain* predicated_consumer_id,
|
||||
TensorView* predicated_consumer_tv) {
|
||||
|
|
@ -2771,6 +2839,22 @@ std::pair<Val*, Val*> hoistPredicates(
|
|||
return {hoisted_start_index, hoisted_stop_index};
|
||||
}
|
||||
|
||||
// Updates a loop index map with a loop index protected by magic zero
|
||||
std::unordered_map<IterDomain*, Val*> updateInitialLoopIndexMap(
|
||||
const std::unordered_map<IterDomain*, Val*>& initial_loop_index_map,
|
||||
const IndexMagicZeroInfo& magic_zero_info) {
|
||||
if (magic_zero_info.original_loop_index != nullptr) {
|
||||
TORCH_INTERNAL_ASSERT(magic_zero_info.protected_loop_index != nullptr);
|
||||
auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID(
|
||||
magic_zero_info.loop_id, IdMappingMode::EXACT);
|
||||
auto updated_map = initial_loop_index_map;
|
||||
updated_map[concrete_loop_id] = magic_zero_info.protected_loop_index;
|
||||
return updated_map;
|
||||
} else {
|
||||
return initial_loop_index_map;
|
||||
}
|
||||
}
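For background, the sketch below mimics the general magic-zero idea in plain C++: a value that is zero at runtime but opaque to the compiler is added to selected indices so that unrolled copies keep distinct index/predicate computations. The `magic_zero` variable and the use of `volatile` here are illustrative stand-ins, not the nvfuser kernel mechanism.

```cpp
#include <cstdio>

// Illustrative sketch of the "magic zero" idea: a value that is provably
// zero at runtime but that the compiler must treat as unknown. Adding it to
// an index inside an unrolled loop keeps each unrolled copy's index and
// predicate computation distinct instead of being folded into one
// over-reused value. `volatile` stands in for the opaque variable here.
volatile int magic_zero = 0;

int main() {
  int data[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  int sum = 0;
  // Pretend this loop is unrolled by the compiler; the protected index
  // i + magic_zero must be recomputed per copy.
  for (int i = 0; i < 8; ++i) {
    sum += data[i + magic_zero];
  }
  std::printf("sum = %d\n", sum);
  return 0;
}
```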
} // namespace
|
||||
|
||||
// Returns predicates and the concrete (by loop map) root domains they cover
|
||||
|
|
@ -2886,13 +2970,38 @@ std::vector<RootPredicateInfo> Index::getReferenceRootPredicates(
|
|||
auto stop_index = consumer_stop_indexing_it->second;
|
||||
auto start_index = consumer_start_index_map.at(contig_id);
|
||||
|
||||
IndexMagicZeroInfo start_magic_zero_info;
|
||||
IndexMagicZeroInfo stop_magic_zero_info;
|
||||
|
||||
// When the start and stop indices are not the same, apply the
|
||||
// magic-zero protection separately for both of them.
|
||||
if (stop_index != start_index) {
|
||||
start_magic_zero_info = protectPredicateIndexWithMagicZero(
|
||||
start_index, start_indexing_from_idgraph, loops);
|
||||
stop_magic_zero_info = protectPredicateIndexWithMagicZero(
|
||||
stop_index, stop_indexing_from_idgraph, loops);
|
||||
} else {
|
||||
stop_magic_zero_info = protectPredicateIndexWithMagicZero(
|
||||
stop_index, stop_indexing_from_idgraph, loops);
|
||||
start_magic_zero_info = stop_magic_zero_info;
|
||||
}
|
||||
|
||||
start_index = start_magic_zero_info.index;
|
||||
stop_index = stop_magic_zero_info.index;
|
||||
|
||||
// Update the loop-index map with the magic-zero protection info
|
||||
// before passing it to the hoisting function
|
||||
std::tie(start_index, stop_index) = hoistPredicates(
|
||||
start_index,
|
||||
stop_index,
|
||||
loops,
|
||||
stop_indexing_from_idgraph.resolved_loop_domains,
|
||||
-        start_indexing_from_idgraph.initial_concrete_index_map,
-        stop_indexing_from_idgraph.initial_concrete_index_map,
+        updateInitialLoopIndexMap(
+            start_indexing_from_idgraph.initial_concrete_index_map,
+            start_magic_zero_info),
+        updateInitialLoopIndexMap(
+            stop_indexing_from_idgraph.initial_concrete_index_map,
+            stop_magic_zero_info),
|
||||
unswitch_or_vec_loop,
|
||||
contig_id,
|
||||
consumer_tv);
|
||||
|
|
@ -2935,19 +3044,6 @@ std::vector<RootPredicateInfo> Index::getReferenceRootPredicates(
|
|||
return pred_info_vec;
|
||||
}
|
||||
|
||||
bool Index::protectWithMagicZero(
|
||||
kir::ForLoop* loop,
|
||||
IterDomain* reference_domain,
|
||||
Val* ind) {
|
||||
bool ref_dom_simple =
|
||||
(reference_domain == nullptr ? true
|
||||
: reference_domain->definition() != nullptr);
|
||||
bool ind_simple =
|
||||
(ind == nullptr ? true
|
||||
: ind->definition() != nullptr && !ind->isZeroInt());
|
||||
return loop->isUnrolled() && (!ref_dom_simple || !ind_simple);
|
||||
}
|
||||
|
||||
RootPredicateInfo RootPredicateInfo::getFalseInfo() {
|
||||
RootPredicateInfo info;
|
||||
info.start_predicate_ = GpuLower::current()->kernel()->falseVal();
|
||||
|
|
|
|||
|
|
@ -130,6 +130,13 @@ class IndexCompute : public BackwardVisitor {
|
|||
// map rather than the actual IDs used in the ID expressions.
|
||||
bool concrete_id_pass_ = false;
|
||||
|
||||
// Mode of swizzle that are activated in this index compute
|
||||
// instance. Will treat swizzles of different mode as no-op.
|
||||
// Currently data mode swizzles are handled same as before in IndexSwizzle
|
||||
// pass, while loop mode swizzles are handled early on in concrete indexing
|
||||
// pass. See also [Note on swizzle mode]
|
||||
SwizzleMode swizzle_mode_ = SwizzleMode::NoSwizzle;
|
||||
|
||||
public:
|
||||
const std::unordered_map<IterDomain*, Val*>& indexMap() const {
|
||||
return index_map_;
|
||||
|
|
@ -361,18 +368,6 @@ class Index {
|
|||
const std::vector<kir::ForLoop*>& loops,
|
||||
kir::ForLoop* unswitch_or_vec_loop,
|
||||
bool padding_predicate);
|
||||
|
||||
// Determine if we may run into over reuse of predicates or registers in the
|
||||
// compiler. If the loop can be unrolled and the index and domain are not
|
||||
// "simple" we likely want the loop protected.
|
||||
//
|
||||
// Magic zero protection should only be done for global memory and predicates.
|
||||
// We should avoid use on registers. Shared memory does not require it, but
|
||||
// likely wouldn't hurt.
|
||||
static bool protectWithMagicZero(
|
||||
kir::ForLoop* loop,
|
||||
IterDomain* reference_domain = nullptr,
|
||||
Val* ind = nullptr);
|
||||
};
|
||||
|
||||
// Used for local and shared index mapping. Returns a map from loops
|
||||
|
|
|
|||
|
|
@ -10,22 +10,10 @@ namespace jit {
|
|||
namespace fuser {
|
||||
namespace cuda {
|
||||
|
||||
bool InlinePropagatorSelector::allowC2P(TensorView* from, TensorView* to) {
|
||||
return selected_.count(to) > 0;
|
||||
}
|
||||
|
||||
bool InlinePropagatorSelector::allowP2C(TensorView* from, TensorView* to) {
|
||||
// If the producer is in the selected set, then the consumer must also be
|
||||
// replayed to obtain a compatible loop structure so that this producer
|
||||
// can be consumed in this loop.
|
||||
return selected_.count(from) > 0 || selected_.count(to) > 0;
|
||||
}
|
||||
|
||||
bool InlinePropagatorSelector::allowSibling(TensorView* from, TensorView* to) {
|
||||
return true;
|
||||
}
|
||||
|
||||
MaxPosCalculator::MaxPosCalculator(ComputeAtMode mode) : mode_(mode) {
|
||||
MaxPosCalculator::MaxPosCalculator(
|
||||
ComputeAtMode mode,
|
||||
std::unordered_set<IterDomain*> uninlinable_ids)
|
||||
: mode_(mode), uninlinable_ids_(std::move(uninlinable_ids)) {
|
||||
buildUnmappableDims();
|
||||
}
|
||||
|
||||
|
|
@ -65,6 +53,10 @@ bool MaxPosCalculator::isAllowedID(
|
|||
allowed = allowed && !id->isReduction();
|
||||
}
|
||||
|
||||
if (uninlinable_ids_.count(id)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!allow_vectorize) {
|
||||
// Avoid inlining if marked as Vectorize or Group. In the case of
|
||||
// BestEffort and MostInlined modes, avoid Unroll as well.
|
||||
|
|
@ -121,6 +113,13 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer(
|
|||
|
||||
for (size_t producer_pos = 0; producer_pos < producer->nDims();
|
||||
producer_pos++) {
|
||||
// If the producer position is mismatching with the consumer, then we can
|
||||
// not inline into this position, otherwise the max producer position of
|
||||
// the consumer will become invalid and expression sort will fail.
|
||||
if (TransformReplay::getMatchedLeafPosWithoutReplayCasP(
|
||||
consumer, producer, producer_pos + 1) < 0) {
|
||||
return producer_pos;
|
||||
}
|
||||
auto map_it = p2c_replay_map.find(producer->axis(producer_pos));
|
||||
if (map_it != p2c_replay_map.end()) {
|
||||
auto c_id = map_it->second;
|
||||
|
|
@ -147,9 +146,17 @@ size_t InlinePropagator::getMaxPosAll(TensorView* tv, bool check_siblings) {
|
|||
}
|
||||
|
||||
void InlinePropagator::setCAPos(TensorView* tv) {
|
||||
bool debug = isDebugDumpEnabled(DebugDumpOption::InlinePropagator);
|
||||
size_t pos = mapped_reference_pos_.at(tv);
|
||||
if (debug) {
|
||||
std::cout << " Setting CA pos of " << tv << ":" << std::endl;
|
||||
std::cout << " mapped position: " << pos << std::endl;
|
||||
}
|
||||
if ((selected_.empty() || selected_.count(tv)) && !tv->isFusionInput()) {
|
||||
auto max_pos = getMaxPosAll(tv);
|
||||
if (debug) {
|
||||
std::cout << " max inlinable position: " << max_pos << std::endl;
|
||||
}
|
||||
if (mode_ == ComputeAtMode::Standard) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
pos <= max_pos,
|
||||
|
|
@ -159,14 +166,32 @@ void InlinePropagator::setCAPos(TensorView* tv) {
|
|||
pos,
|
||||
", max position that's allowed is ",
|
||||
max_pos);
|
||||
} else {
|
||||
} else if (mode_ == ComputeAtMode::BestEffort) {
|
||||
pos = std::min<size_t>(pos, max_pos);
|
||||
} else {
|
||||
pos = max_pos;
|
||||
}
|
||||
// hoist inner most broadcast
|
||||
while (pos > 0 && tv->axis(pos - 1)->isBroadcast()) {
|
||||
pos--;
|
||||
}
|
||||
tv->setComputeAt(pos);
|
||||
auto current_ca_pos = tv->getComputeAtPosition();
|
||||
if (debug) {
|
||||
std::cout << " current CA position: " << current_ca_pos << std::endl;
|
||||
}
|
||||
if (pos > current_ca_pos) {
|
||||
if (debug) {
|
||||
std::cout << " new CA position: " << pos << std::endl;
|
||||
}
|
||||
tv->setComputeAt(pos);
|
||||
for (auto consumer_tv : ir_utils::consumerTvsOf(tv)) {
|
||||
needs_update_max_producer_.insert(consumer_tv);
|
||||
}
|
||||
} else if (debug) {
|
||||
std::cout << " CA position not changed" << std::endl;
|
||||
}
|
||||
} else if (debug) {
|
||||
std::cout << " tensor not selected, skip" << std::endl;
|
||||
}
|
||||
}
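A standalone sketch of the position arithmetic performed above, with hypothetical inputs (`mapped_pos`, `max_pos`, broadcast flags); it mirrors the mode handling, the broadcast hoisting, and the rule that the compute-at position is only ever raised.

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

enum class Mode { Standard, BestEffort, MostInlined };

// Illustrative only: mirrors the setCAPos position arithmetic on plain data.
// `is_broadcast[i]` marks broadcast axes of the tensor.
size_t resolveCAPos(
    size_t mapped_pos,
    size_t max_pos,
    Mode mode,
    const std::vector<bool>& is_broadcast,
    size_t current_ca_pos) {
  size_t pos = mapped_pos;
  if (mode == Mode::BestEffort) {
    pos = std::min(pos, max_pos);
  } else if (mode == Mode::MostInlined) {
    pos = max_pos;
  } // Standard mode would instead assert mapped_pos <= max_pos.

  // Hoist innermost broadcast axes out of the inlined region.
  while (pos > 0 && is_broadcast[pos - 1]) {
    pos--;
  }
  // The compute-at position is only ever raised, never lowered.
  return std::max(pos, current_ca_pos);
}

int main() {
  std::vector<bool> bcast = {false, false, true};
  std::printf("%zu\n", resolveCAPos(3, 3, Mode::BestEffort, bcast, 1)); // 2
  std::printf("%zu\n", resolveCAPos(3, 2, Mode::MostInlined, bcast, 1)); // 2
  return 0;
}
```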
@ -174,8 +199,9 @@ InlinePropagator::InlinePropagator(
|
|||
TensorView* reference,
|
||||
int64_t reference_pos,
|
||||
ComputeAtMode mode,
|
||||
std::unordered_set<TensorView*> selected)
|
||||
: max_pos_calc(mode),
|
||||
std::unordered_set<TensorView*> selected,
|
||||
std::unordered_set<IterDomain*> uninlinable_ids)
|
||||
: max_pos_calc(mode, std::move(uninlinable_ids)),
|
||||
selected_(std::move(selected)),
|
||||
reference_(reference),
|
||||
mode_(mode) {
|
||||
|
|
@ -194,69 +220,150 @@ InlinePropagator::InlinePropagator(
|
|||
reference_pos_ = reference_pos;
|
||||
}
|
||||
|
||||
void InlinePropagator::setUp() {
|
||||
bool debug = isDebugDumpEnabled(DebugDumpOption::InlinePropagator);
|
||||
mapped_reference_pos_[reference_] = reference_pos_;
|
||||
if (debug) {
|
||||
std::cout << "InlinePropagator::setUp" << std::endl;
|
||||
std::cout << " reference: " << reference_ << " @ " << reference_pos_
|
||||
<< std::endl;
|
||||
}
|
||||
setCAPos(reference_);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Try to find the aligned position on consumer's domain corresponding to the
|
||||
// compute at position of producer domain. Used in InlinePropagator pass only.
|
||||
// No checking on actual producer-consumer relationship.
|
||||
unsigned int getConsumerPosAlignedToProducerCA(
|
||||
TensorView* consumer,
|
||||
TensorView* producer) {
|
||||
// Locate consumer's position that aligns with
|
||||
// the producer's new compute at axis. We need broadcast axes forwarded so we
|
||||
// need to replay PasC as CasP will not forward broadcast dims. For example
|
||||
// if we have:
|
||||
// T2[ iS22{( 3 * 1 )} ] ca_pos( 1 ) = broadcast( T1[ iS1{3} ] ca_pos( 1 )
|
||||
// produce_pos( 1) ) CasP will have the mapping iS1{3} -> iS2{3} and PasC will
|
||||
// have the mapping iS22{( 3 * 1 )} <- iS1{3} We need the latter. Refer to
|
||||
// NVFuserTest.FusionComplexBCast1_CUDA
|
||||
|
||||
auto disjoint_sets =
|
||||
BestEffortReplay::replayPasC(
|
||||
producer, consumer, -1, PairwiseRootDomainMap(producer, consumer))
|
||||
.getDisjointSets();
|
||||
|
||||
// Find the innermost position of consumer that has
|
||||
// been mapped within the producer ca axis.
|
||||
unsigned int consumer_pos = consumer->nDims();
|
||||
while (consumer_pos > 0) {
|
||||
auto consumer_id = consumer->axis((int)consumer_pos - 1);
|
||||
auto p_dom = producer->domain()->domain();
|
||||
if (std::any_of(
|
||||
p_dom.begin(),
|
||||
p_dom.begin() + producer->getComputeAtPosition(),
|
||||
[&consumer_id, &disjoint_sets](IterDomain* p_id) {
|
||||
return disjoint_sets.permissiveAreMapped(consumer_id, p_id);
|
||||
})) {
|
||||
break;
|
||||
}
|
||||
consumer_pos--;
|
||||
}
|
||||
|
||||
return consumer_pos;
|
||||
}
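A minimal standalone sketch of the scan above, with a hypothetical `mapped` predicate standing in for the disjoint-set query returned by `BestEffortReplay`.

```cpp
#include <cstddef>
#include <functional>

// Illustrative only: scan from the innermost consumer axis outward and stop
// at the first axis that maps to some producer axis inside the producer's
// compute-at range.
size_t alignedConsumerPos(
    size_t consumer_ndims,
    size_t producer_ca_pos,
    const std::function<bool(size_t, size_t)>& mapped) {
  size_t consumer_pos = consumer_ndims;
  while (consumer_pos > 0) {
    bool found = false;
    for (size_t p = 0; p < producer_ca_pos; ++p) {
      if (mapped(consumer_pos - 1, p)) {
        found = true;
        break;
      }
    }
    if (found) {
      break;
    }
    consumer_pos--;
  }
  return consumer_pos;
}

int main() {
  // Toy mapping: consumer axis i maps to producer axis i.
  auto identity = [](size_t c, size_t p) { return c == p; };
  // 3-D consumer, producer inlined at position 2 -> aligned position 2.
  return alignedConsumerPos(3, 2, identity) == 2 ? 0 : 1;
}
```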
} // namespace
|
||||
|
||||
void InlinePropagator::tearDown() {
|
||||
for (auto consumer : needs_update_max_producer_) {
|
||||
unsigned int consumer_pos = 0;
|
||||
for (auto producer : ir_utils::producerTvsOf(consumer)) {
|
||||
consumer_pos = std::max(
|
||||
consumer_pos, getConsumerPosAlignedToProducerCA(consumer, producer));
|
||||
}
|
||||
consumer->setMaxProducer(consumer_pos);
|
||||
}
|
||||
}
|
||||
|
||||
void InlinePropagator::propagateC2P(TensorView* from, TensorView* to) {
|
||||
if (is_first_) {
|
||||
is_first_ = false;
|
||||
mapped_reference_pos_[reference_] = reference_pos_;
|
||||
setCAPos(reference_);
|
||||
bool debug = isDebugDumpEnabled(DebugDumpOption::InlinePropagator);
|
||||
if (debug) {
|
||||
std::cout << "InlinePropagator::propagateC2P" << std::endl;
|
||||
std::cout << " from: " << from << std::endl;
|
||||
std::cout << " to: " << to << std::endl;
|
||||
}
|
||||
// Step 1: find mapped_reference_pos_[to]
|
||||
int from_pos;
|
||||
if (mode_ != ComputeAtMode::MostInlined) {
|
||||
from_pos = mapped_reference_pos_.at(from);
|
||||
} else {
|
||||
from_pos = from->nDims();
|
||||
}
|
||||
int from_pos = mapped_reference_pos_.at(from);
|
||||
auto to_pos =
|
||||
TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, from_pos);
|
||||
TORCH_CHECK(
|
||||
to_pos >= 0,
|
||||
"Unable to propagate CA position from consumer ",
|
||||
from,
|
||||
" at ",
|
||||
from_pos,
|
||||
" to producer ",
|
||||
to,
|
||||
" because this would require replay.");
|
||||
if (mode_ == ComputeAtMode::Standard) {
|
||||
TORCH_CHECK(
|
||||
to_pos >= 0,
|
||||
"Unable to propagate CA position from consumer ",
|
||||
from,
|
||||
" at ",
|
||||
from_pos,
|
||||
" to producer ",
|
||||
to,
|
||||
" because this would require replay.");
|
||||
} else {
|
||||
// For MostInlined and BestEffort inline propagation, we allow the DAG to
// not be replayed fully consistently. In that case, we simply don't inline
// into the mismatched dimension.
|
||||
while (to_pos < 0) {
|
||||
from_pos--;
|
||||
to_pos = TransformReplay::getMatchedLeafPosWithoutReplayPasC(
|
||||
to, from, from_pos);
|
||||
}
|
||||
}
|
||||
mapped_reference_pos_[to] = to_pos;
|
||||
// Step 2: set CA position of `to`
|
||||
setCAPos(to);
|
||||
}
|
||||
|
||||
void InlinePropagator::propagateP2C(TensorView* from, TensorView* to) {
|
||||
if (is_first_) {
|
||||
is_first_ = false;
|
||||
mapped_reference_pos_[reference_] = reference_pos_;
|
||||
setCAPos(reference_);
|
||||
bool debug = isDebugDumpEnabled(DebugDumpOption::InlinePropagator);
|
||||
if (debug) {
|
||||
std::cout << "InlinePropagator::propagateP2C" << std::endl;
|
||||
std::cout << " from: " << from << std::endl;
|
||||
std::cout << " to: " << to << std::endl;
|
||||
}
|
||||
// Step 1: find mapped_reference_pos_[to]
|
||||
int from_pos;
|
||||
if (mode_ != ComputeAtMode::MostInlined) {
|
||||
from_pos = mapped_reference_pos_.at(from);
|
||||
} else {
|
||||
from_pos = from->nDims();
|
||||
}
|
||||
int from_pos = mapped_reference_pos_.at(from);
|
||||
auto to_pos =
|
||||
TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, from_pos);
|
||||
TORCH_CHECK(
|
||||
to_pos >= 0,
|
||||
"Unable to propagate CA position from producer ",
|
||||
from,
|
||||
" at ",
|
||||
from_pos,
|
||||
" to consumer ",
|
||||
to,
|
||||
" because this would require replay.");
|
||||
if (mode_ == ComputeAtMode::Standard) {
|
||||
TORCH_CHECK(
|
||||
to_pos >= 0,
|
||||
"Unable to propagate CA position from producer ",
|
||||
from,
|
||||
" at ",
|
||||
from_pos,
|
||||
" to consumer ",
|
||||
to,
|
||||
" because this would require replay.");
|
||||
} else {
|
||||
// For MostInlined and BestEffort inline propagation, we allow the DAG to
// not be replayed fully consistently. In that case, we simply don't inline
// into the mismatched dimension.
|
||||
while (to_pos < 0) {
|
||||
from_pos--;
|
||||
to_pos = TransformReplay::getMatchedLeafPosWithoutReplayCasP(
|
||||
to, from, from_pos);
|
||||
}
|
||||
}
|
||||
mapped_reference_pos_[to] = to_pos;
|
||||
// Step 2: set CA position of `to`
|
||||
setCAPos(to);
|
||||
}
|
||||
|
||||
void InlinePropagator::propagateSibling(TensorView* from, TensorView* to) {
|
||||
if (is_first_) {
|
||||
is_first_ = false;
|
||||
mapped_reference_pos_[reference_] = reference_pos_;
|
||||
setCAPos(reference_);
|
||||
bool debug = isDebugDumpEnabled(DebugDumpOption::InlinePropagator);
|
||||
if (debug) {
|
||||
std::cout << "InlinePropagator::propagateSibling" << std::endl;
|
||||
std::cout << " from: " << from << std::endl;
|
||||
std::cout << " to: " << to << std::endl;
|
||||
}
|
||||
// Step 1: find mapped_reference_pos_[to]
|
||||
auto from_pos = mapped_reference_pos_.at(from);
|
||||
|
|
@ -272,96 +379,6 @@ void InlinePropagator::propagateSibling(TensorView* from, TensorView* to) {
|
|||
setCAPos(to);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Try to find the aligned position on consumer's domain corresponding to the
|
||||
// compute at position of producer domain. Used in computeAt pass only. No
|
||||
// checking on actual producer-consumer relationship.
|
||||
unsigned int getConsumerPosAlignedToProducerCA(
|
||||
TensorView* consumer,
|
||||
TensorView* producer) {
|
||||
// Locate consumer's position that aligns with
|
||||
// the producer's new compute at axis. We need broadcast axes forwarded so we
|
||||
// need to replay PasC as CasP will not forward braodcast dims. For example
|
||||
// if we have:
|
||||
// T2[ iS22{( 3 * 1 )} ] ca_pos( 1 ) = broadcast( T1[ iS1{3} ] ca_pos( 1 )
|
||||
// produce_pos( 1) ) CasP will have the mapping iS1{3} -> iS2{3} and PasC will
|
||||
// have the mapping iS22{( 3 * 1 )} <- iS1{3} We need the latter. Refer to
|
||||
// NVFuserTest.FusionComplexBCast1_CUDA
|
||||
|
||||
auto c2p_map =
|
||||
BestEffortReplay::replayPasC(
|
||||
producer,
|
||||
consumer,
|
||||
-1,
|
||||
// Compute at root domain may not be valid here, as all
|
||||
// producers don't have to be able to map into consumer at
|
||||
// max producer position. Since computeAt should be valid
|
||||
// and this mechanism is only intended to lower produce
|
||||
// position of consumer, we can simply use the pairwise map.
|
||||
PairwiseRootDomainMap(producer, consumer))
|
||||
.getReplay();
|
||||
|
||||
// Find the innermost position of consumer that has
|
||||
// been mapped within the producer ca axis.
|
||||
unsigned int consumer_pos = consumer->nDims();
|
||||
while (consumer_pos > 0) {
|
||||
auto consumer_id = consumer->axis((int)consumer_pos - 1);
|
||||
auto p_dom = producer->domain()->domain();
|
||||
if (std::any_of(
|
||||
p_dom.begin(),
|
||||
p_dom.begin() + producer->getComputeAtPosition(),
|
||||
[&consumer_id, &c2p_map](IterDomain* p_id) {
|
||||
auto c_id_it = c2p_map.find(consumer_id);
|
||||
if (c_id_it != c2p_map.end()) {
|
||||
return c_id_it->second == p_id;
|
||||
}
|
||||
return false;
|
||||
})) {
|
||||
break;
|
||||
}
|
||||
consumer_pos--;
|
||||
}
|
||||
|
||||
return consumer_pos;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Try to find the aligned position on consumer's domain corresponding to the
|
||||
// compute at position of producer domain.
|
||||
void MaxProducerPosUpdater::handle(TensorView* consumer) {
|
||||
unsigned int consumer_pos = 0;
|
||||
for (auto producer : ir_utils::producerTvsOf(consumer)) {
|
||||
consumer_pos = std::max(
|
||||
consumer_pos, getConsumerPosAlignedToProducerCA(consumer, producer));
|
||||
}
|
||||
consumer->setMaxProducer(consumer_pos);
|
||||
}
|
||||
|
||||
void MaxProducerPosUpdater::propagateC2P(TensorView* from, TensorView* to) {
|
||||
if (updated_.empty()) {
|
||||
// handle the reference tensor
|
||||
updated_.insert(nullptr);
|
||||
propagateC2P(nullptr, from);
|
||||
}
|
||||
for (auto consumer_tv : ir_utils::consumerTvsOf(to)) {
|
||||
if (updated_.count(consumer_tv) > 0) {
|
||||
continue;
|
||||
}
|
||||
handle(consumer_tv);
|
||||
updated_.insert(consumer_tv);
|
||||
}
|
||||
}
|
||||
|
||||
void MaxProducerPosUpdater::propagateP2C(TensorView* from, TensorView* to) {
|
||||
propagateC2P(from, to);
|
||||
}
|
||||
|
||||
void MaxProducerPosUpdater::propagateSibling(TensorView* from, TensorView* to) {
|
||||
propagateC2P(from, to);
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace fuser
|
||||
} // namespace jit
|
||||
|
|
|
|||
|
|
@ -11,31 +11,16 @@ namespace jit {
|
|||
namespace fuser {
|
||||
namespace cuda {
|
||||
|
||||
// Simple selector that only propagates across tensor views in the provided
|
||||
// unordered_set. Will also propagate to all consumers of those tensors, and the
|
||||
// siblings of those tensors.
|
||||
class TORCH_CUDA_CU_API InlinePropagatorSelector
|
||||
: public MaxInfoSpanningTree::Selector {
|
||||
std::unordered_set<TensorView*> selected_;
|
||||
|
||||
public:
|
||||
virtual bool allowC2P(TensorView* from, TensorView* to) override;
|
||||
virtual bool allowP2C(TensorView* from, TensorView* to) override;
|
||||
virtual bool allowSibling(TensorView* from, TensorView* to) override;
|
||||
|
||||
InlinePropagatorSelector(std::unordered_set<TensorView*> selected)
|
||||
: selected_(std::move(selected)){};
|
||||
const std::unordered_set<TensorView*>& selected() const {
|
||||
return selected_;
|
||||
}
|
||||
};
|
||||
|
||||
class TORCH_CUDA_CU_API MaxPosCalculator {
|
||||
ComputeAtMode mode_ = ComputeAtMode::Standard;
|
||||
|
||||
// Root domains in the producer that are unmappable to any of its consumers
|
||||
std::unordered_set<IterDomain*> unmappable_dims_;
|
||||
|
||||
// User set IterDomains to not inline, used in schedulers to avoid inlining
|
||||
// trivial reductions
|
||||
std::unordered_set<IterDomain*> uninlinable_ids_;
|
||||
|
||||
// Iterate through all TVs and collect the dimensions of each TV that don't
|
||||
// map to all its consumer TVs.
|
||||
void buildUnmappableDims();
|
||||
|
|
@ -65,7 +50,9 @@ class TORCH_CUDA_CU_API MaxPosCalculator {
|
|||
TensorView* producer,
|
||||
TensorView* consumer) const;
|
||||
|
||||
MaxPosCalculator(ComputeAtMode mode);
|
||||
MaxPosCalculator(
|
||||
ComputeAtMode mode,
|
||||
std::unordered_set<IterDomain*> uninlinable_ids = {});
|
||||
};
|
||||
|
||||
// Propagate inline position to the `selected` tensors in the DAG. If `selected`
|
||||
|
|
@ -91,17 +78,18 @@ class TORCH_CUDA_CU_API InlinePropagator
|
|||
|
||||
const MaxPosCalculator max_pos_calc;
|
||||
std::unordered_set<TensorView*> selected_;
|
||||
std::unordered_set<TensorView*> needs_update_max_producer_;
|
||||
TensorView* reference_;
|
||||
size_t reference_pos_;
|
||||
ComputeAtMode mode_ = ComputeAtMode::Standard;
|
||||
bool is_first_ = true;
|
||||
|
||||
public:
|
||||
InlinePropagator(
|
||||
TensorView* reference,
|
||||
int64_t reference_pos,
|
||||
ComputeAtMode mode = ComputeAtMode::Standard,
|
||||
std::unordered_set<TensorView*> selected = {});
|
||||
std::unordered_set<TensorView*> selected = {},
|
||||
std::unordered_set<IterDomain*> uninlinable_ids = {});
|
||||
|
||||
InlinePropagator(
|
||||
TensorView* reference,
|
||||
|
|
@ -117,24 +105,11 @@ class TORCH_CUDA_CU_API InlinePropagator
|
|||
|
||||
// Actually propagate the transformations for the inlining pass. Uses the
|
||||
// functions above to figure out what position to do the propagation at.
|
||||
virtual void setUp() override;
|
||||
virtual void propagateC2P(TensorView* from, TensorView* to) override;
|
||||
virtual void propagateP2C(TensorView* from, TensorView* to) override;
|
||||
virtual void propagateSibling(TensorView* from, TensorView* to) override;
|
||||
};
|
||||
|
||||
// This is actually not a propagation, it only sets the max producer position of
|
||||
// the tensors, and it is not needed to compute the max producer position in a
|
||||
// specific order. But MaxInfoSpanningTree provides a very convenient API to
|
||||
// visit the tensors, so I just use it for cleaner code.
|
||||
class TORCH_CUDA_CU_API MaxProducerPosUpdater
|
||||
: public MaxInfoSpanningTree::Propagator {
|
||||
std::unordered_set<TensorView*> updated_;
|
||||
void handle(TensorView* tv);
|
||||
|
||||
public:
|
||||
virtual void propagateC2P(TensorView* from, TensorView* to) override;
|
||||
virtual void propagateP2C(TensorView* from, TensorView* to) override;
|
||||
virtual void propagateSibling(TensorView* from, TensorView* to) override;
|
||||
virtual void tearDown() override;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
|||
|
|
@ -84,6 +84,9 @@ Val* IrBuilder::newResult(DataType dtype) {
|
|||
}
|
||||
|
||||
Val* IrBuilder::newArithmeticExpr(BinaryOpType op_type, Val* lhs, Val* rhs) {
|
||||
TORCH_CHECK(
|
||||
lhs != nullptr && rhs != nullptr,
|
||||
"Either lhs or rhs is a nullptr in newArithmeticExpr.");
|
||||
TORCH_CHECK(
|
||||
lhs->dtype() == rhs->dtype(),
|
||||
"Incompatible operand types: ",
|
||||
|
|
@ -97,6 +100,9 @@ Val* IrBuilder::newArithmeticExpr(BinaryOpType op_type, Val* lhs, Val* rhs) {
|
|||
}
|
||||
|
||||
Val* IrBuilder::newLogicExpr(BinaryOpType op_type, Val* lhs, Val* rhs) {
|
||||
TORCH_CHECK(
|
||||
lhs != nullptr && rhs != nullptr,
|
||||
"Either lhs or rhs is a nullptr in newLogicExpr.");
|
||||
auto result = IrBuilder::create<Bool>(c10::nullopt);
|
||||
IrBuilder::create<BinaryOp>(op_type, result, lhs, rhs);
|
||||
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
|
||||
|
|
@ -104,6 +110,9 @@ Val* IrBuilder::newLogicExpr(BinaryOpType op_type, Val* lhs, Val* rhs) {
|
|||
}
|
||||
|
||||
Val* IrBuilder::whereExpr(Val* pred, Val* lhs, Val* rhs) {
|
||||
TORCH_CHECK(
|
||||
pred != nullptr && lhs != nullptr && rhs != nullptr,
|
||||
"Either pred, lhs, or rhs is a nullptr in whereExpr.");
|
||||
TORCH_CHECK(lhs->dtype() == rhs->dtype(), "Incompatible operand types");
|
||||
auto result = newResult(lhs->dtype());
|
||||
IrBuilder::create<TernaryOp>(TernaryOpType::Where, result, pred, lhs, rhs);
|
||||
|
|
@ -111,30 +120,35 @@ Val* IrBuilder::whereExpr(Val* pred, Val* lhs, Val* rhs) {
|
|||
}
|
||||
|
||||
Val* IrBuilder::negExpr(Val* val) {
|
||||
TORCH_CHECK(val != nullptr, "val is a nullptr in negExpr.");
|
||||
auto result = newResult(val->dtype());
|
||||
IrBuilder::create<UnaryOp>(UnaryOpType::Neg, result, val);
|
||||
return result;
|
||||
}
|
||||
|
||||
Val* IrBuilder::notExpr(Val* val) {
|
||||
TORCH_CHECK(val != nullptr, "val is a nullptr in notExpr.");
|
||||
auto result = newResult(val->dtype());
|
||||
IrBuilder::create<UnaryOp>(UnaryOpType::Not, result, val);
|
||||
return result;
|
||||
}
|
||||
|
||||
Val* IrBuilder::setExpr(Val* val) {
|
||||
TORCH_CHECK(val != nullptr, "val is a nullptr in setExpr.");
|
||||
auto result = newResult(val->dtype());
|
||||
IrBuilder::create<UnaryOp>(UnaryOpType::Set, result, val);
|
||||
return result;
|
||||
}
|
||||
|
||||
Val* IrBuilder::setExprNamedScalar(const std::string& name, Val* val) {
|
||||
TORCH_CHECK(val != nullptr, "val is a nullptr in setExprNamedScalar.");
|
||||
auto result = IrBuilder::create<NamedScalar>(name, val->dtype());
|
||||
IrBuilder::create<UnaryOp>(UnaryOpType::Set, result, val);
|
||||
return result;
|
||||
}
|
||||
|
||||
Val* IrBuilder::addressExprNamedScalar(const std::string& name, Val* val) {
|
||||
TORCH_CHECK(val != nullptr, "val is a nullptr in addressExprNamedScalar.");
|
||||
auto result = IrBuilder::create<NamedScalar>(name, DataType::Int);
|
||||
IrBuilder::create<UnaryOp>(UnaryOpType::Address, result, val);
|
||||
return result;
|
||||
|
|
|
|||
|
|
@ -381,7 +381,11 @@ class TORCH_CUDA_CU_API TensorView : public Val {
|
|||
|
||||
//! Swizzle the rectangular tile defined by the iterdomains corresponding
|
||||
//! to the 2 given indices.
|
||||
TensorView* swizzle(Swizzle2DType swizzle_type, int x, int y);
|
||||
TensorView* swizzle(
|
||||
Swizzle2DType swizzle_type,
|
||||
int x,
|
||||
int y,
|
||||
SwizzleMode swizzle_mode = SwizzleMode::Data);
|
||||
|
||||
// WARNING: rFactor does not return this TensorView, it returns a new
|
||||
// tensorview consumed by this!
|
||||
|
|
@ -450,10 +454,26 @@ class TORCH_CUDA_CU_API TensorView : public Val {
|
|||
// Apply double buffering transformation
|
||||
void doubleBuffer();
|
||||
|
||||
// Apply circular buffering transformation
|
||||
void circularBuffer(unsigned int number_of_stage);
|
||||
|
||||
// Returns true if this tensor is double buffered.
|
||||
bool isDoubleBuffered() const {
|
||||
return is_double_buffered_;
|
||||
}
|
||||
|
||||
// Returns true if this tensor is circular buffered.
|
||||
bool isCircularBuffered() const {
|
||||
return is_circular_buffered_;
|
||||
}
|
||||
|
||||
// Returns the depth of circular buffering if applicable.
|
||||
unsigned int circularBufferDepth() const {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
is_circular_buffered_, toString(), "not circular buffered");
|
||||
return circular_buffer_stage_;
|
||||
}
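A brief, non-compilable scheduling sketch of how these buffering calls might be used, assuming `tv_in` is a fusion input and that `cacheAfter`/`setMemoryType` are available on `TensorView` as elsewhere in this codebase.

```cpp
// Sketch only: stage a global-memory input through shared memory and let the
// loads run a few iterations ahead of the consuming computation.
auto tv_smem = tv_in->cacheAfter();
tv_smem->setMemoryType(MemoryType::Shared);

// Double buffering is the stage_depth == 2 special case; circularBuffer
// generalizes it to deeper pipelines (e.g. for async copies on Ampere).
tv_smem->circularBuffer(/*number_of_stage=*/3);

TORCH_INTERNAL_ASSERT(tv_smem->isCircularBuffered());
TORCH_INTERNAL_ASSERT(tv_smem->circularBufferDepth() == 3);
```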
//! Transforms the innermost iterdomains according to the given mma swizzle,
|
||||
//! this should be used on the tvs that are either inputs/outputs of an
|
||||
//! MmaOp, or any tv's that are involved in prolog/epilog fusions and need to
|
||||
|
|
@ -509,6 +529,13 @@ class TORCH_CUDA_CU_API TensorView : public Val {
|
|||
SwizzleType swizzle_type_ = SwizzleType::NoSwizzle;
|
||||
std::vector<IterDomain*> axes_to_swizzle_;
|
||||
bool is_double_buffered_ = false;
|
||||
|
||||
//! Indicates if the tensor is circular buffered.
|
||||
bool is_circular_buffered_ = false;
|
||||
|
||||
//! Indicates the circular buffering stage depth if applicable.
|
||||
unsigned int circular_buffer_stage_ = 0;
|
||||
|
||||
// special handling for CPU based zero-dim tensors (i.e. CPU Tensors that only
|
||||
// have one value). This is only used if on an input value, otherwise ignored.
|
||||
// This is important as special handling because these "scalars" should be
|
||||
|
|
@ -545,7 +572,8 @@ class TORCH_CUDA_CU_API TensorViewBuilder {
|
|||
TensorViewBuilder& contiguity(std::vector<bool> contiguity);
|
||||
|
||||
//! Set the shape (default 0 dimensional, ie. scalar)
|
||||
TensorViewBuilder& shape(std::vector<int64_t> shape);
|
||||
TensorViewBuilder& shape(std::vector<Val*> shape);
|
||||
TensorViewBuilder& shape(const std::vector<int64_t>& shape);
|
||||
|
||||
//! Creates a new TensorView with the specified options
|
||||
TensorView* build() const;
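A short, non-compilable sketch of the new `Val*` overload, assuming an active `Fusion`/`FusionGuard` so that `IrBuilder` can create the extent values.

```cpp
// Sketch only: build a 2-D float tensor whose second extent is symbolic.
// IrBuilder::create<Int>(c10::nullopt) is assumed to create an unknown
// (symbolic) integer, as done elsewhere in this codebase.
Val* symbolic_extent = IrBuilder::create<Int>(c10::nullopt);
TensorView* tv =
    TensorViewBuilder()
        .shape(std::vector<Val*>{IrBuilder::create<Int>(8), symbolic_extent})
        .build();
```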
@ -554,7 +582,7 @@ class TORCH_CUDA_CU_API TensorViewBuilder {
|
|||
size_t ndims_ = 0;
|
||||
DataType dtype_ = DataType::Float;
|
||||
std::vector<bool> contiguity_;
|
||||
std::vector<int64_t> shape_;
|
||||
std::vector<Val*> shape_;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
|||
|
|
@ -338,6 +338,22 @@ class TORCH_CUDA_CU_API WelfordOp : public Expr {
|
|||
//! Fused Matmul operation
|
||||
class TORCH_CUDA_CU_API MmaOp : public Expr {
|
||||
public:
|
||||
// This is a temporary data structure for the scheduling-specific
// parameters that we still need to store on an mma node. Eventually only
// the mma macro type will stay on the IR node after additional cleanups.
|
||||
struct OptionsInMma {
|
||||
MmaOptions::MacroType macro = MmaOptions::MacroType::NoMMA;
|
||||
MmaOptions::MmaInputLayout operand_layout = MmaOptions::MmaInputLayout::TT;
|
||||
int accumulator_stride = 0;
|
||||
|
||||
bool operator==(const OptionsInMma& other) const {
|
||||
return macro == other.macro && operand_layout == other.operand_layout &&
|
||||
accumulator_stride == other.accumulator_stride;
|
||||
}
|
||||
};
|
||||
|
||||
MmaOp(IrBuilderPasskey, Val* out, Val* in_a, Val* in_b, Val* init);
|
||||
|
||||
MmaOp(
|
||||
|
|
@ -346,7 +362,7 @@ class TORCH_CUDA_CU_API MmaOp : public Expr {
|
|||
Val* in_a,
|
||||
Val* in_b,
|
||||
Val* init,
|
||||
MmaOptions options);
|
||||
OptionsInMma options);
|
||||
|
||||
MmaOp(const MmaOp* src, IrCloner* ir_cloner);
|
||||
|
||||
|
|
@ -379,7 +395,15 @@ class TORCH_CUDA_CU_API MmaOp : public Expr {
|
|||
}
|
||||
|
||||
void configureOptions(MmaOptions options) {
|
||||
options_ = options;
|
||||
options_ = OptionsInMma();
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
options.macro != MmaOptions::MacroType::NoMMA,
|
||||
"Un-configured mma type from options.");
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
options.accumulator_stride > 0, "Un-configured accumulator stride.");
|
||||
options_->accumulator_stride = options.accumulator_stride;
|
||||
options_->macro = options.macro;
|
||||
options_->operand_layout = options.operand_layout;
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
@ -387,7 +411,7 @@ class TORCH_CUDA_CU_API MmaOp : public Expr {
|
|||
Val* const in_a_ = nullptr;
|
||||
Val* const in_b_ = nullptr;
|
||||
Val* const init_ = nullptr;
|
||||
c10::optional<MmaOptions> options_ = c10::nullopt;
|
||||
c10::optional<OptionsInMma> options_ = c10::nullopt;
|
||||
};
|
||||
|
||||
class TORCH_CUDA_CU_API TransposeOp : public Expr {
|
||||
|
|
@ -967,7 +991,8 @@ class TORCH_CUDA_CU_API IterDomain : public Val {
|
|||
static std::pair<IterDomain*, IterDomain*> swizzle(
|
||||
Swizzle2DType swizzle_type,
|
||||
IterDomain* in_x,
|
||||
IterDomain* in_y);
|
||||
IterDomain* in_y,
|
||||
SwizzleMode swizzle_mode = SwizzleMode::Data);
|
||||
|
||||
bool isMmaSwizzled() const {
|
||||
return is_mma_swizzled_;
|
||||
|
|
@ -1174,7 +1199,11 @@ class TORCH_CUDA_CU_API TensorDomain : public Val {
|
|||
|
||||
//! Applies 2D swizzle on a rectangular tile defined by
|
||||
//! a pair of iterdomains contained in this domain.
|
||||
void swizzle(Swizzle2DType swizzle_type, int x, int y);
|
||||
void swizzle(
|
||||
Swizzle2DType swizzle_type,
|
||||
int x,
|
||||
int y,
|
||||
SwizzleMode swizzle_mode = SwizzleMode::Data);
|
||||
|
||||
// Transform TensorView according to merge and split transformations
|
||||
TensorDomain* view(
|
||||
|
|
@ -1315,7 +1344,8 @@ class TORCH_CUDA_CU_API Swizzle2D : public Expr {
|
|||
IterDomain* out_y,
|
||||
IterDomain* in_x,
|
||||
IterDomain* in_y,
|
||||
Swizzle2DType swizzle_type = Swizzle2DType::NoSwizzle);
|
||||
Swizzle2DType swizzle_type = Swizzle2DType::NoSwizzle,
|
||||
SwizzleMode swizzle_mode = SwizzleMode::Data);
|
||||
|
||||
Swizzle2D(const Swizzle2D* src, IrCloner* ir_cloner);
|
||||
|
||||
|
|
@ -1335,10 +1365,14 @@ class TORCH_CUDA_CU_API Swizzle2D : public Expr {
|
|||
return in_y_;
|
||||
}
|
||||
|
||||
const auto& swizzleType() const {
|
||||
auto swizzleType() const {
|
||||
return swizzle_type_;
|
||||
}
|
||||
|
||||
auto swizzleMode() const {
|
||||
return swizzle_mode_;
|
||||
}
|
||||
|
||||
bool sameAs(const Statement* other) const override;
|
||||
|
||||
private:
|
||||
|
|
@ -1353,7 +1387,50 @@ class TORCH_CUDA_CU_API Swizzle2D : public Expr {
|
|||
|
||||
// The type of predefined 1-to-1 functions
|
||||
// used for swizzling math.
|
||||
Swizzle2DType swizzle_type_;
|
||||
Swizzle2DType swizzle_type_ = Swizzle2DType::NoSwizzle;
|
||||
|
||||
// Swizzle mode of this swizzle instance.
|
||||
// [Note on swizzle mode]
|
||||
// On the current implementations we support two modes of
|
||||
// swizzle math, namely, data mode and loop mode.
|
||||
// `Data` mode swizzling is a swizzle that will change the
|
||||
// data layout in shared memory, likely in global memory buffers
|
||||
// as well in the future. see also IndexSwizzle in index_compute.cpp.
|
||||
//
|
||||
// Most important use cases are transpose bank conflict removal, and mma
|
||||
// swizzled shared memory layout. Example illustrated in 1D case:
|
||||
//
|
||||
// for (int i = 0; i<I; i++){
|
||||
// # This is a `Data` mode swizzle.
|
||||
// Tshared [swizzled(i)] = Tin[i];
|
||||
// }
|
||||
// # Now Tshared holds swizzled data, i.e. the data layout of
|
||||
// Tshared does not map to Tin with affine relationships.
|
||||
//
|
||||
// for(int i=0;i<I;i++){
|
||||
// Tout = Tshared[swizzled(i)];
|
||||
// }
|
||||
//
|
||||
// `Loop` mode swizzling does not affect the data layout of any buffer
|
||||
// but only permutes the iteration order of serial or parallel loop.
|
||||
// This is useful when we want to designate non-affine mapping of thread
|
||||
// to data or we want to generate non-affine loops.
|
||||
// Example illustrated in 1D case:
|
||||
// for (int i = 0; i<I; i++){
|
||||
// # This is a `Loop` mode swizzle
|
||||
// Tshared [swizzled(i)] = Tin[swizzled(i)];
|
||||
// }
|
||||
// # Now Tshared holds normal data, i.e. it still has
|
||||
// the same data layout as if the swizzle wasn't there.
|
||||
//
|
||||
// # Consumers of Tshared do not need to know about the
|
||||
// loop swizzle at previous op if not inlined.
|
||||
// for(int i=0;i<I;i++){
|
||||
// Tout = Tshared[i];
|
||||
// }
|
||||
// TODO: Loop swizzles eventually will be piped through in all mappings
|
||||
// and replay of the fusion IR infrastructure.
|
||||
SwizzleMode swizzle_mode_ = SwizzleMode::Data;
|
||||
};
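The note above can be made concrete with a small standalone program; `swizzled` below is an arbitrary stand-in bijection, not one of the predefined `Swizzle2DType` functions.

```cpp
#include <cstdio>

constexpr int kN = 8;

// Illustrative stand-in for a 1D swizzle function (any bijection works).
int swizzled(int i) {
  return (kN - 1) - i; // reverse order
}

int main() {
  int t_in[kN], t_shared[kN], t_out[kN];
  for (int i = 0; i < kN; ++i) {
    t_in[i] = i * 10;
  }

  // `Data` mode: the buffer's layout changes, so readers must also apply
  // the swizzle to get the original element back.
  for (int i = 0; i < kN; ++i) {
    t_shared[swizzled(i)] = t_in[i];
  }
  for (int i = 0; i < kN; ++i) {
    t_out[i] = t_shared[swizzled(i)];
  }

  // `Loop` mode: only the iteration order changes; the buffer ends up with
  // the same layout as t_in, so readers can index it directly.
  for (int i = 0; i < kN; ++i) {
    t_shared[swizzled(i)] = t_in[swizzled(i)];
  }
  for (int i = 0; i < kN; ++i) {
    t_out[i] = t_shared[i];
  }

  std::printf("t_out[3] = %d\n", t_out[3]); // 30 in both cases
  return 0;
}
```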
//! Integer value which has a special name
|
||||
|
|
|
|||
|
|
@ -599,6 +599,10 @@ void IrPrinter::handle(const kir::BlockSync* node) {
|
|||
}
|
||||
|
||||
void IrPrinter::handle(const kir::CpAsyncWait* node) {
|
||||
indent() << "CPASYNC_WAIT(" << node->keepStages() << ")\n";
|
||||
}
|
||||
|
||||
void IrPrinter::handle(const kir::CpAsyncCommit* node) {
|
||||
indent() << "CPASYNC_WAIT()\n";
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -112,6 +112,7 @@ class TORCH_CUDA_CU_API IrPrinter : public OptInConstDispatch {
|
|||
void handle(const kir::BlockSync*) final;
|
||||
void handle(const kir::GridSync*) final;
|
||||
void handle(const kir::CpAsyncWait*) final;
|
||||
void handle(const kir::CpAsyncCommit*) final;
|
||||
void handle(const kir::InitMagicZero*) final;
|
||||
void handle(const kir::UpdateMagicZero*) final;
|
||||
void handle(const kir::AllocateFusedReduction*) final;
|
||||
|
|
|
|||
|
|
@ -630,7 +630,7 @@ MmaOp::MmaOp(
|
|||
Val* in_a,
|
||||
Val* in_b,
|
||||
Val* init,
|
||||
MmaOptions options)
|
||||
OptionsInMma options)
|
||||
: MmaOp(passkey, out, in_a, in_b, init) {
|
||||
options_ = options;
|
||||
}
|
||||
|
|
@ -1293,7 +1293,8 @@ std::pair<IterDomain*, IterDomain*> IterDomain::stridedSplit(int factor) {
|
|||
std::pair<IterDomain*, IterDomain*> IterDomain::swizzle(
|
||||
Swizzle2DType swizzle_type,
|
||||
IterDomain* in_x,
|
||||
IterDomain* in_y) {
|
||||
IterDomain* in_y,
|
||||
SwizzleMode swizzle_mode) {
|
||||
TORCH_CHECK(
|
||||
!in_x->extent()->isZeroInt() && !in_y->extent()->isZeroInt(),
|
||||
"Invalid swizzling of a empty dimension.");
|
||||
|
|
@ -1319,7 +1320,7 @@ std::pair<IterDomain*, IterDomain*> IterDomain::swizzle(
|
|||
IterDomain* out_y = IterDomainBuilder(in_y).build();
|
||||
|
||||
IrBuilder::create<Swizzle2D>(
|
||||
in_x->container(), out_x, out_y, in_x, in_y, swizzle_type);
|
||||
in_x->container(), out_x, out_y, in_x, in_y, swizzle_type, swizzle_mode);
|
||||
|
||||
return std::make_pair(out_x, out_y);
|
||||
}
|
||||
|
|
@ -1790,7 +1791,11 @@ std::vector<IterDomain*> TensorDomain::orderedAs(
|
|||
return reordered_domain;
|
||||
}
|
||||
|
||||
void TensorDomain::swizzle(Swizzle2DType swizzle_type, int x, int y) {
|
||||
void TensorDomain::swizzle(
|
||||
Swizzle2DType swizzle_type,
|
||||
int x,
|
||||
int y,
|
||||
SwizzleMode swizzle_mode) {
|
||||
TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do merge on a 0-dim domain");
|
||||
|
||||
TORCH_CHECK(
|
||||
|
|
@ -1808,7 +1813,7 @@ void TensorDomain::swizzle(Swizzle2DType swizzle_type, int x, int y) {
|
|||
IterDomain* axis_out_y = nullptr;
|
||||
|
||||
std::tie(axis_out_x, axis_out_y) =
|
||||
IterDomain::swizzle(swizzle_type, axis_x, axis_y);
|
||||
IterDomain::swizzle(swizzle_type, axis_x, axis_y, swizzle_mode);
|
||||
|
||||
domain_.erase(domain_.begin() + x);
|
||||
domain_.insert(domain_.begin() + x, axis_out_x);
|
||||
|
|
@ -2039,13 +2044,15 @@ Swizzle2D::Swizzle2D(
|
|||
IterDomain* out_y,
|
||||
IterDomain* in_x,
|
||||
IterDomain* in_y,
|
||||
Swizzle2DType swizzle_type)
|
||||
Swizzle2DType swizzle_type,
|
||||
SwizzleMode swizzle_mode)
|
||||
: Expr(passkey, ExprType::Swizzle2D),
|
||||
out_x_{out_x},
|
||||
out_y_{out_y},
|
||||
in_x_{in_x},
|
||||
in_y_{in_y},
|
||||
swizzle_type_(swizzle_type) {
|
||||
swizzle_type_(swizzle_type),
|
||||
swizzle_mode_(swizzle_mode) {
|
||||
addOutput(out_x);
|
||||
addOutput(out_y);
|
||||
addInput(in_x);
|
||||
|
|
@ -2071,7 +2078,8 @@ Swizzle2D::Swizzle2D(const Swizzle2D* src, IrCloner* ir_cloner)
|
|||
out_y_(ir_cloner->clone(src->out_y_)),
|
||||
in_x_(ir_cloner->clone(src->in_x_)),
|
||||
in_y_(ir_cloner->clone(src->in_y_)),
|
||||
swizzle_type_(src->swizzle_type_) {}
|
||||
swizzle_type_(src->swizzle_type_),
|
||||
swizzle_mode_(src->swizzle_mode_) {}
|
||||
|
||||
NamedScalar::NamedScalar(
|
||||
IrBuilderPasskey passkey,
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
#include <torch/csrc/jit/codegen/cuda/ir_builder.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
|
||||
|
||||
#include <set>
|
||||
|
||||
|
|
@ -473,25 +474,23 @@ TensorView* rfactorHelper(
|
|||
TensorView* reduction_tv,
|
||||
const std::vector<int>& axes) {
|
||||
TORCH_INTERNAL_ASSERT(reduction_tv->definition() != nullptr);
|
||||
const bool is_welford = reduction_tv->definition()->isA<WelfordOp>();
|
||||
if (!is_welford) {
|
||||
const bool has_multiple_tvs = reduction_tv->definition()->inputs().size() > 1;
|
||||
if (!has_multiple_tvs) {
|
||||
return reduction_tv->rFactor(axes);
|
||||
}
|
||||
auto welford = reduction_tv->definition()->as<WelfordOp>();
|
||||
auto w_avg = welford->outAvg()->as<TensorView>();
|
||||
auto w_var = welford->outVar()->as<TensorView>();
|
||||
auto w_n = welford->outN()->as<TensorView>();
|
||||
|
||||
auto rtvs =
|
||||
reduction_tv->rFactor(axes, std::vector<TensorView*>{w_avg, w_var, w_n});
|
||||
std::vector<TensorView*> out_tvs;
|
||||
std::transform(
|
||||
reduction_tv->definition()->outputs().begin(),
|
||||
reduction_tv->definition()->outputs().end(),
|
||||
std::back_inserter(out_tvs),
|
||||
[](Val* val) { return val->as<TensorView>(); });
|
||||
|
||||
if (reduction_tv == w_n) {
|
||||
return rtvs.at(2);
|
||||
} else if (reduction_tv == w_var) {
|
||||
return rtvs.at(1);
|
||||
} else {
|
||||
return rtvs.at(0);
|
||||
}
|
||||
auto rf_tvs = reduction_tv->rFactor(axes, out_tvs);
|
||||
|
||||
return rf_tvs.at(std::distance(
|
||||
out_tvs.begin(),
|
||||
std::find(out_tvs.begin(), out_tvs.end(), reduction_tv)));
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
|
@ -652,6 +651,19 @@ std::vector<TensorView*> allTvs(Fusion* fusion) {
|
|||
return uniqueEntries<TensorView>(all_tvs);
|
||||
}
|
||||
|
||||
std::vector<TensorView*> allTvsExcept(
|
||||
Fusion* fusion,
|
||||
const std::unordered_set<TensorView*>& except) {
|
||||
auto all_tvs = allTvs(fusion);
|
||||
std::vector<TensorView*> result;
|
||||
for (auto tv : all_tvs) {
|
||||
if (except.count(tv) == 0) {
|
||||
result.emplace_back(tv);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<Expr*> getReductionOps(Fusion* fusion, bool ignore_trivial) {
|
||||
std::vector<Expr*> red_ops;
|
||||
|
||||
|
|
@ -796,6 +808,145 @@ Val* getReductionInitValOf(TensorView* tv) {
|
|||
return init;
|
||||
}
|
||||
|
||||
// TODO: Should mma be in here? Should we return true if it's a trivial
|
||||
// reduction?
|
||||
bool isReductionOp(const Expr* expr) {
|
||||
// Note that GridReduction inherits ReductionOp
|
||||
return expr->isA<ReductionOp>() || expr->isA<GroupedReductionOp>() ||
|
||||
expr->isA<WelfordOp>() || expr->isA<kir::GridWelford>();
|
||||
}
|
||||
|
||||
bool isReductionTvOp(const Expr* expr) {
|
||||
return ir_utils::isTvOp(expr) && isReductionOp(expr);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
struct ReplaceValInIndexVal : public OptInDispatch {
|
||||
public:
|
||||
//! Apply replacements to index as specified in
|
||||
//! replacement_map. index is assumed to consist only of Int and
|
||||
//! NamedScalar
|
||||
static Val* replace(
|
||||
Val* index,
|
||||
const std::unordered_map<Val*, Val*>& replacement_map) {
|
||||
ReplaceValInIndexVal replace_index_val(replacement_map);
|
||||
replace_index_val.handle(index);
|
||||
// Return the original index if not replaced
|
||||
if (replace_index_val.is_replaced_) {
|
||||
return replace_index_val.last_visited_val_;
|
||||
} else {
|
||||
return index;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
ReplaceValInIndexVal(const std::unordered_map<Val*, Val*>& replacement_map)
|
||||
: replacement_map_(replacement_map) {}
|
||||
|
||||
using OptOutDispatch::handle;
|
||||
|
||||
void handle(Val* val) override {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
val->isA<Int>() || val->isA<NamedScalar>() || val->isA<kir::IntPair>(),
|
||||
"Invalid Val type: ",
|
||||
val->toString());
|
||||
|
||||
// if val appears in the replacement map, stop traversing and set
|
||||
// the current val with the replacement
|
||||
auto it = replacement_map_.find(val);
|
||||
if (it != replacement_map_.end()) {
|
||||
last_visited_val_ = it->second;
|
||||
is_replaced_ = true;
|
||||
return;
|
||||
}
|
||||
|
||||
// Recursively traverse its defining expr
|
||||
auto def = val->definition();
|
||||
if (def != nullptr) {
|
||||
switch (def->etype()) {
|
||||
case ExprType::UnaryOp:
|
||||
case ExprType::BinaryOp:
|
||||
case ExprType::Swizzle2DInt:
|
||||
case ExprType::PairSelect:
|
||||
handle(val->definition());
|
||||
break;
|
||||
default:
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
false, "Unexpected definition: ", def->toString())
|
||||
}
|
||||
// last_visited_val_ is set in the expr handlers
|
||||
} else {
|
||||
last_visited_val_ = val;
|
||||
}
|
||||
}
|
||||
|
||||
// Clone expression after recursively replacing inputs
|
||||
void handle(UnaryOp* uop) override {
|
||||
handle(uop->in());
|
||||
auto inp = last_visited_val_;
|
||||
TORCH_INTERNAL_ASSERT(uop->out()->isA<Int>());
|
||||
auto out = IrBuilder::create<Int>(c10::nullopt);
|
||||
IrBuilder::create<UnaryOp>(uop->getUnaryOpType(), out, inp);
|
||||
last_visited_val_ = out;
|
||||
}
|
||||
|
||||
// Clone expression after recursively replacing inputs
|
||||
void handle(BinaryOp* bop) override {
|
||||
handle(bop->lhs());
|
||||
auto lhs = last_visited_val_;
|
||||
handle(bop->rhs());
|
||||
auto rhs = last_visited_val_;
|
||||
TORCH_INTERNAL_ASSERT(bop->out()->isA<Int>());
|
||||
auto out = IrBuilder::create<Int>(c10::nullopt);
|
||||
IrBuilder::create<BinaryOp>(bop->getBinaryOpType(), out, lhs, rhs);
|
||||
last_visited_val_ = out;
|
||||
}
|
||||
|
||||
// Clone expression after recursively replacing inputs
|
||||
void handle(kir::Swizzle2DInt* swizzle_2d) override {
|
||||
handle(swizzle_2d->inX());
|
||||
auto in_x = last_visited_val_;
|
||||
handle(swizzle_2d->inY());
|
||||
auto in_y = last_visited_val_;
|
||||
auto out = IrBuilder::create<kir::IntPair>();
|
||||
|
||||
// Extents are assumed constant in swizzle so no need to
|
||||
// duplicate their graphs.
|
||||
IrBuilder::create<kir::Swizzle2DInt>(
|
||||
out,
|
||||
in_x,
|
||||
in_y,
|
||||
swizzle_2d->extentX(),
|
||||
swizzle_2d->extentY(),
|
||||
swizzle_2d->swizzleType());
|
||||
last_visited_val_ = out;
|
||||
}
|
||||
|
||||
void handle(kir::PairSelect* pair_select) override {
|
||||
handle(pair_select->in()->asVal());
|
||||
auto in = last_visited_val_;
|
||||
TORCH_INTERNAL_ASSERT(pair_select->out()->isA<Int>());
|
||||
auto out = IrBuilder::create<Int>(c10::nullopt);
|
||||
IrBuilder::create<kir::PairSelect>(
|
||||
out, in->as<kir::IntPair>(), pair_select->selection());
|
||||
last_visited_val_ = out;
|
||||
}
|
||||
|
||||
private:
|
||||
const std::unordered_map<Val*, Val*>& replacement_map_;
|
||||
Val* last_visited_val_ = nullptr;
|
||||
bool is_replaced_ = false;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
Val* replaceValInIndexVal(
|
||||
Val* index,
|
||||
const std::unordered_map<Val*, Val*>& replacement_map) {
|
||||
return ReplaceValInIndexVal::replace(index, replacement_map);
|
||||
}
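The same clone-on-replace idea, shown on a toy expression tree so it can run standalone; the `Node` type and helpers below are illustrative only and unrelated to the nvfuser IR classes.

```cpp
#include <cassert>
#include <memory>
#include <unordered_map>

// Illustrative only: leaves found in the replacement map are swapped out;
// every expression on the path to a replaced leaf is cloned, so the original
// tree (and anything else sharing its leaves) is left untouched.
struct Node {
  int value;                      // used when the node is a leaf
  std::shared_ptr<Node> lhs, rhs; // non-null when the node is an addition
};

std::shared_ptr<Node> replaceLeaves(
    const std::shared_ptr<Node>& node,
    const std::unordered_map<Node*, std::shared_ptr<Node>>& replacements) {
  auto it = replacements.find(node.get());
  if (it != replacements.end()) {
    return it->second;
  }
  if (!node->lhs) {
    return node; // unreplaced leaf: reuse as-is
  }
  auto new_lhs = replaceLeaves(node->lhs, replacements);
  auto new_rhs = replaceLeaves(node->rhs, replacements);
  if (new_lhs == node->lhs && new_rhs == node->rhs) {
    return node; // nothing below changed: no need to clone
  }
  return std::make_shared<Node>(Node{0, new_lhs, new_rhs});
}

int eval(const std::shared_ptr<Node>& n) {
  return n->lhs ? eval(n->lhs) + eval(n->rhs) : n->value;
}

int main() {
  auto a = std::make_shared<Node>(Node{1, nullptr, nullptr});
  auto b = std::make_shared<Node>(Node{2, nullptr, nullptr});
  auto sum = std::make_shared<Node>(Node{0, a, b});
  auto c = std::make_shared<Node>(Node{10, nullptr, nullptr});

  auto replaced = replaceLeaves(sum, {{a.get(), c}});
  assert(eval(sum) == 3);       // original untouched
  assert(eval(replaced) == 12); // clone sees the replacement
  return 0;
}
```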
} // namespace ir_utils
|
||||
} // namespace cuda
|
||||
} // namespace fuser
|
||||
|
|
|
|||
|
|
@ -156,6 +156,18 @@ std::vector<int> normalizeOld2New(
|
|||
// Reference is found through direct pointer comparison.
|
||||
Expr* replaceValInExpr(Expr* expr, Val* reference, Val* substitute);
|
||||
|
||||
//! Replace Vals in an index Val as specified by replacement_map while
|
||||
//! cloning the given index Val. The index val is assumed to represent
|
||||
//! a tensor index consisting of Ints and arithmetic expressions.
|
||||
//!
|
||||
//! This is similar to replaceValInExpr but is different as Vals are
|
||||
//! cloned such that other exprs using the same leaf Vals are not
|
||||
//! modified. TODO: Consider cleaning up the multiple replacement
|
||||
//! routines.
|
||||
Val* replaceValInIndexVal(
|
||||
Val* index,
|
||||
const std::unordered_map<Val*, Val*>& replacement_map);
|
||||
|
||||
// Makes rfactor generic with reduction ops and Welford
|
||||
TORCH_CUDA_CU_API TensorView* rfactorHelper(
|
||||
TensorView* red_tv,
|
||||
|
|
@ -282,6 +294,12 @@ TORCH_CUDA_CU_API std::vector<TensorView*> outputTvsOf(
|
|||
// returns all tensor views in fusion that are used between outputs and inputs.
|
||||
TORCH_CUDA_CU_API std::vector<TensorView*> allTvs(Fusion* fusion);
|
||||
|
||||
// returns all tensor views in fusion that are used between outputs and inputs
|
||||
// except the specified set.
|
||||
TORCH_CUDA_CU_API std::vector<TensorView*> allTvsExcept(
|
||||
Fusion* fusion,
|
||||
const std::unordered_set<TensorView*>& except);
|
||||
|
||||
TORCH_CUDA_CU_API std::vector<Expr*> getReductionOps(
|
||||
Fusion* fusion,
|
||||
bool ignore_trivial = true);
|
||||
|
|
@ -289,6 +307,12 @@ TORCH_CUDA_CU_API std::vector<Expr*> getReductionOps(
|
|||
// Returns the initialization value of tv or nullptr if not initialized.
|
||||
TORCH_CUDA_CU_API Val* getReductionInitValOf(TensorView* tv);
|
||||
|
||||
// Returns if Expr is a reduction op
|
||||
TORCH_CUDA_CU_API bool isReductionOp(const Expr*);
|
||||
|
||||
// Returns if Expr is a reduction op with TensorView or TensorIndex
|
||||
TORCH_CUDA_CU_API bool isReductionTvOp(const Expr*);
|
||||
|
||||
template <typename T>
|
||||
std::string toString(const T& nodes) {
|
||||
std::stringstream ss;
@ -327,7 +327,7 @@ std::vector<at::Tensor> FusionKernelRuntime::runKernelWithInput(
|
|||
auto scheduler_entry = schedulers()[group_id].get();
|
||||
|
||||
// Check that the heuristics are matched, in the case of segmented fusion
|
||||
TORCH_INTERNAL_ASSERT(!sg || scheduler_entry->heuristc() == sg->heuristic());
|
||||
TORCH_INTERNAL_ASSERT(!sg || scheduler_entry->heuristic() == sg->heuristic());
|
||||
|
||||
if (!executors_[group_id].compiled()) {
|
||||
FUSER_PERF_SCOPE("FusionKernelRuntime::runKernelWithInput::Compile");
|
||||
|
|
@ -341,32 +341,16 @@ std::vector<at::Tensor> FusionKernelRuntime::runKernelWithInput(
|
|||
options.index_mode = scheduler_entry->indexMode();
|
||||
FusionGuard fg(fusion_to_run.get());
|
||||
scheduler_entry->schedule(fusion_to_run.get());
|
||||
// Load launch params for reduction and normalization kernels
|
||||
if (scheduler_entry->hasReductionParam()) {
|
||||
launch_params = scheduler_entry->reductionParams().lparams;
|
||||
} else {
|
||||
launch_params = scheduler_entry->pointwiseParams().lparams;
|
||||
}
|
||||
launch_params = scheduler_entry->params()->lparams;
|
||||
executors_[group_id].compileFusion(
|
||||
fusion_to_run.get(), inputs, launch_params, options);
|
||||
} else {
|
||||
// Load launch params for reduction and normalization kernels
|
||||
if (scheduler_entry->hasReductionParam()) {
|
||||
launch_params = scheduler_entry->reductionParams().lparams;
|
||||
} else {
|
||||
launch_params = scheduler_entry->pointwiseParams().lparams;
|
||||
}
|
||||
launch_params = scheduler_entry->params()->lparams;
|
||||
}
|
||||
|
||||
if (profiling_) {
|
||||
most_recent_executor_log_.fusion_executor = &executors_[group_id];
|
||||
if (scheduler_entry->hasReductionParam()) {
|
||||
most_recent_executor_log_.reduction_params =
|
||||
scheduler_entry->reductionParams();
|
||||
} else {
|
||||
most_recent_executor_log_.pointwise_params =
|
||||
scheduler_entry->pointwiseParams();
|
||||
}
|
||||
most_recent_executor_log_.params = scheduler_entry->params()->clone();
|
||||
}
|
||||
|
||||
auto& executor = executors_[group_id];
|
||||
|
|
@ -395,11 +379,7 @@ std::vector<at::Tensor> FusionKernelRuntime::runKernelWithInput(
|
|||
}
|
||||
}
|
||||
std::cout << "Compiler log: " << executor.compilerLog() << "\n";
|
||||
if (scheduler_entry->hasReductionParam()) {
|
||||
std::cout << scheduler_entry->reductionParams().toString() << "\n";
|
||||
} else {
|
||||
std::cout << scheduler_entry->pointwiseParams().toString() << "\n";
|
||||
}
|
||||
std::cout << scheduler_entry->params()->toString() << "\n";
|
||||
std::cout << "With arguments: " << executor.lastLaunchParams().toString();
|
||||
std::cout << executor.kernelName() << " " << executor.bytesProcessed()
|
||||
<< " bytes/ " << std::setprecision(3) << executor.kernelTimeMs()
|
||||
|
|
@ -604,13 +584,8 @@ void FusionKernelRuntime::updateHeuristicsLaunchParams(
|
|||
update_heuristics->heuristicsList().size() == scheduler_list_length);
|
||||
for (const auto i : c10::irange(scheduler_list_length)) {
|
||||
auto& schedulerPtr = heuristics_->heuristicsList()[i];
|
||||
if (schedulerPtr->hasReductionParam()) {
|
||||
schedulerPtr->updateLaunchConstraint(
|
||||
update_heuristics->heuristicsList()[i]->reductionParams().lparams);
|
||||
} else {
|
||||
schedulerPtr->updateLaunchConstraint(
|
||||
update_heuristics->heuristicsList()[i]->pointwiseParams().lparams);
|
||||
}
|
||||
schedulerPtr->updateLaunchConstraint(
|
||||
update_heuristics->heuristicsList()[i]->params()->lparams);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -25,8 +25,7 @@ class SchedulerRuntimeInfo;
|
|||
|
||||
// Utilities for benchmarking and profiling
|
||||
struct ExecutorLog {
|
||||
c10::optional<ReductionParams> reduction_params = c10::nullopt;
|
||||
c10::optional<PointwiseParams> pointwise_params = c10::nullopt;
|
||||
std::shared_ptr<HeuristicParams> params = nullptr;
|
||||
FusionExecutor* fusion_executor = nullptr;
|
||||
};
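The new `params` member above replaces the separate reduction/pointwise optionals. Judging only from the call sites in this diff (`params()->lparams`, `params()->toString()`, `params()->clone()`), the base class presumably exposes roughly the surface sketched below; this is an inferred shape, not the actual declaration:

```cpp
// Inferred sketch only: minimal HeuristicParams surface implied by usage.
struct HeuristicParamsSketch {
  LaunchParams lparams;                      // read as params()->lparams
  virtual std::string toString() const = 0;  // printed in profiling output
  virtual std::shared_ptr<HeuristicParamsSketch> clone() const = 0;
  virtual ~HeuristicParamsSketch() = default;
};
```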
@ -19,7 +19,7 @@ void ExpressionEvaluator::bind(
|
|||
TORCH_CHECK(
|
||||
value->definition() == nullptr,
|
||||
"Tried to bind to a value that is computed in the kernel IR: ",
|
||||
value->toString(),
|
||||
value->toInlineString(),
|
||||
" with ",
|
||||
concrete_value);
|
||||
known_values_[value] = concrete_value;
|
||||
|
|
|
|||
|
|
@ -93,8 +93,15 @@ GridSync::GridSync(
|
|||
sync_dims_(sync_dims),
|
||||
sync_buffer_(sync_buffer) {}
|
||||
|
||||
CpAsyncWait::CpAsyncWait(IrBuilderPasskey passkey)
|
||||
: Expr(passkey, ExprType::CpAsyncWait) {
|
||||
CpAsyncWait::CpAsyncWait(IrBuilderPasskey passkey, unsigned int keep_stages)
|
||||
: Expr(passkey, ExprType::CpAsyncWait), keep_stages_(keep_stages) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
passkey.ir_container_->isA<kir::Kernel>(),
|
||||
"IR type only valid for Kernel container.");
|
||||
}
|
||||
|
||||
CpAsyncCommit::CpAsyncCommit(IrBuilderPasskey passkey)
|
||||
: Expr(passkey, ExprType::CpAsyncCommit) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
passkey.ir_container_->isA<kir::Kernel>(),
|
||||
"IR type only valid for Kernel container.");
@ -55,6 +55,7 @@ class Allocate;
|
|||
class BlockSync;
|
||||
class GridSync;
|
||||
class CpAsyncWait;
|
||||
class CpAsyncCommit;
|
||||
class InitMagicZero;
|
||||
class UpdateMagicZero;
|
||||
class ForLoop;
|
||||
|
|
@ -258,11 +259,27 @@ class TORCH_CUDA_CU_API BlockSync final : public Expr {
|
|||
};
|
||||
|
||||
// CpAsyncWait represents wait intrinsics for cp.async
|
||||
// TODO: expand to support different wait modes of the intrinsic
|
||||
// as the analysis passes build out.
|
||||
class TORCH_CUDA_CU_API CpAsyncWait final : public Expr {
|
||||
public:
|
||||
explicit CpAsyncWait(IrBuilderPasskey passkey);
|
||||
explicit CpAsyncWait(IrBuilderPasskey passkey, unsigned int keep_stages = 0);
|
||||
|
||||
//! Returns the remaining number of stages that are not synchronized
|
||||
//! after this op.
|
||||
unsigned int keepStages() const {
|
||||
return keep_stages_;
|
||||
}
|
||||
|
||||
private:
|
||||
//! Number of stages to leave un-sync'ed by this op.
|
||||
unsigned int keep_stages_ = 0;
|
||||
};
|
||||
|
||||
// CpAsyncCommit represents commit intrinsics for cp.async.
// A commit intrinsic communicates the delimiter of transaction groups
// to the async load hardware. See [Circular buffer] for example usage.
|
||||
class TORCH_CUDA_CU_API CpAsyncCommit final : public Expr {
|
||||
public:
|
||||
explicit CpAsyncCommit(IrBuilderPasskey passkey);
|
||||
};
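For readers unfamiliar with the underlying hardware primitives: on sm80+, these two nodes presumably map onto the cp.async commit/wait PTX barriers, with `keep_stages` corresponding to the `cp.async.wait_group` operand. A hedged CUDA sketch of those raw intrinsics (not nvfuser's actual runtime helpers) follows:

```cpp
// Hedged CUDA sketch (requires sm_80+): the raw PTX barriers these IR nodes
// model. nvfuser emits its own runtime wrappers; only the semantics are shown.
__device__ inline void cpAsyncCommitGroup() {
  asm volatile("cp.async.commit_group;\n" ::);
}

template <int kKeepStages>
__device__ inline void cpAsyncWaitAllBut() {
  // Block until all but the most recent kKeepStages committed groups are done,
  // matching CpAsyncWait::keepStages().
  asm volatile("cp.async.wait_group %0;\n" ::"n"(kKeepStages));
}
```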
|
||||
|
||||
// Synchronize all blocks in device, implies cooperative group launch is
@ -257,6 +257,9 @@ void GpuLower::lower(Fusion* fusion, DataType index_type) {
|
|||
// Validate mma data format and compatibility if any on the fusion.
|
||||
validateMma(fusion_);
|
||||
|
||||
// Validate swizzle usage on the fusion schedule.
|
||||
validateSwizzle(fusion_);
|
||||
|
||||
// Compute thread predicates. Depends on parallel_dimension_map_
|
||||
thread_pred_map_.build(fusion_);
|
||||
|
||||
|
|
|
|||
|
|
@ -408,7 +408,7 @@ class AllocationInserter : public kir::ExprMutator {
|
|||
|
||||
// Double the allocation size if double-buffered. Record the
|
||||
// original size for indexing.
|
||||
if (info.buffer->isDoubleBuffered()) {
|
||||
if (info.buffer->isDoubleBuffered() || info.buffer->isCircularBuffered()) {
|
||||
Val* original_alloc_size = nullptr;
|
||||
for (auto alloc_dim : alloc_dims) {
|
||||
if (original_alloc_size == nullptr) {
|
||||
|
|
@ -420,7 +420,11 @@ class AllocationInserter : public kir::ExprMutator {
|
|||
}
|
||||
GpuLower::current()->doubleBufferInfo().setOriginalAllocSize(
|
||||
info.buffer, original_alloc_size);
|
||||
alloc_dims.push_back(IrBuilder::create<Int>(2));
|
||||
int double_buffer_stage = 2;
|
||||
if (info.buffer->isCircularBuffered()) {
|
||||
double_buffer_stage = info.buffer->circularBufferDepth();
|
||||
}
|
||||
alloc_dims.push_back(IrBuilder::create<Int>(double_buffer_stage));
|
||||
}
|
||||
|
||||
// Create the allocation node
|
||||
|
|
|
|||
|
|
@ -120,7 +120,7 @@ class DoubleBufferFusionInspector : private IterVisitor {
|
|||
using IterVisitor::handle;
|
||||
|
||||
void handle(TensorView* tv) final {
|
||||
if (!tv->isDoubleBuffered()) {
|
||||
if (!(tv->isDoubleBuffered() || tv->isCircularBuffered())) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -190,10 +190,12 @@ class DoubleBufferLoopCloner : public kir::IrVisitor {
|
|||
double_buffer_loop_->iter_domain(), loop_type_);
|
||||
auto start = double_buffer_loop_->start();
|
||||
auto stop = double_buffer_loop_->stop();
|
||||
auto stage_depth = gpu_lower->doubleBufferInfo().getStageDepthFor(
|
||||
double_buffer_loop_->iter_domain());
|
||||
|
||||
if (loop_type_ == DoubleBufferLoopStage::Prolog) {
|
||||
TORCH_INTERNAL_ASSERT(start->isZeroInt());
|
||||
stop = gpu_lower->kernel()->oneVal();
|
||||
stop = SimplifyingIrBuilder::create<Int>(stage_depth - 1);
|
||||
} else if (
|
||||
loop_type_ == DoubleBufferLoopStage::Main &&
|
||||
requireEpilogue(double_buffer_load_exprs_)) {
|
||||
|
|
@ -202,7 +204,8 @@ class DoubleBufferLoopCloner : public kir::IrVisitor {
|
|||
} else if (loop_type_ == DoubleBufferLoopStage::Epilog) {
|
||||
TORCH_INTERNAL_ASSERT(requireEpilogue(double_buffer_load_exprs_));
|
||||
start = IrBuilder::subExpr(
|
||||
double_buffer_loop_->stop(), gpu_lower->kernel()->oneVal());
|
||||
double_buffer_loop_->stop(),
|
||||
SimplifyingIrBuilder::create<Int>(stage_depth - 1));
|
||||
}
|
||||
|
||||
cloned_top_level_loop_ = IrBuilder::create<kir::ForLoop>(
|
||||
|
|
@ -217,6 +220,11 @@ class DoubleBufferLoopCloner : public kir::IrVisitor {
|
|||
loop_type_);
|
||||
|
||||
handle(double_buffer_loop_);
|
||||
|
||||
if (stage_depth > 2) {
|
||||
cloned_top_level_loop_->body().push_back(
|
||||
IrBuilder::create<kir::CpAsyncCommit>());
|
||||
}
|
||||
}
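A worked example of the bounds chosen above, assuming a stage depth D and an original loop extent N (plain host C++, only to make the arithmetic concrete):

```cpp
int main() {
  constexpr int D = 3;   // stage depth
  constexpr int N = 10;  // extent of the double/circular buffered loop
  // Bounds set by DoubleBufferLoopCloner above.
  constexpr int prolog_stop = D - 1;         // prolog issues the first D-1 loads
  constexpr int epilog_start = N - (D - 1);  // epilog drains the last D-1 iterations
  static_assert(prolog_stop == 2, "prolog covers D-1 iterations");
  static_assert(epilog_start == 8, "epilog covers the last D-1 iterations");
  return 0;
}
```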
|
||||
|
||||
void handle(kir::ForLoop* fl) final {
|
||||
|
|
@ -314,7 +322,8 @@ class DoubleBufferLoopNestInspector : private kir::IrVisitor {
|
|||
}
|
||||
|
||||
// Ignore init loop
|
||||
if (!out_tv->isDoubleBuffered() || !expr->input(0)->isA<TensorView>()) {
|
||||
if (!(out_tv->isDoubleBuffered() || out_tv->isCircularBuffered()) ||
|
||||
!expr->input(0)->isA<TensorView>()) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -430,7 +439,11 @@ class DoubleBufferInserter : private kir::ExprMutator {
|
|||
// loop is async copy. We want to wait for the gmem loads to
|
||||
// finish before synchronizing the block.
|
||||
if (std::any_of(loads.begin(), loads.end(), ir_utils::isCpAsyncOp)) {
|
||||
auto cp_async_wait = IrBuilder::create<kir::CpAsyncWait>();
|
||||
auto stage_depth =
|
||||
GpuLower::current()->doubleBufferInfo().getStageDepthFor(
|
||||
double_buffer_loop->iter_domain());
|
||||
auto cp_async_wait =
|
||||
IrBuilder::create<kir::CpAsyncWait>(stage_depth - 2);
|
||||
registerInsertBefore(double_buffer_loop, cp_async_wait);
|
||||
insert_cpasync_wait = true;
|
||||
}
|
||||
|
|
@ -506,7 +519,9 @@ class DoubleBufferInserter : private kir::ExprMutator {
|
|||
// passes. Cleanups suggested in [Double Buffer Sync]
|
||||
// would resolve this dependency on pass ordering.
|
||||
auto end_of_loop_expr = main_loop->body().exprs().back();
|
||||
auto cp_async_wait = IrBuilder::create<kir::CpAsyncWait>();
|
||||
auto stage_depth = GpuLower::current()->doubleBufferInfo().getStageDepthFor(
|
||||
main_loop->iter_domain());
|
||||
auto cp_async_wait = IrBuilder::create<kir::CpAsyncWait>(stage_depth - 2);
|
||||
|
||||
// Check if a sync has been inserted by WAR sync pass.
|
||||
auto block_sync_it = std::find_if(
|
||||
|
|
@ -557,7 +572,9 @@ bool DoubleBufferInfo::isDoubleBufferedIterDomain(IterDomain* id) {
|
|||
|
||||
DoubleBufferInfo::TvInfo& DoubleBufferInfo::getTvInfo(const TensorView* tv) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
tv->isDoubleBuffered(), "Not a double-buffered tensor: ", tv->toString());
|
||||
tv->isDoubleBuffered() || tv->isCircularBuffered(),
|
||||
"Not a double-buffered tensor: ",
|
||||
tv->toString());
|
||||
return map_[tv];
|
||||
}
|
||||
|
||||
|
|
@ -565,16 +582,63 @@ void DoubleBufferInfo::setDoubleBufferAxis(
|
|||
const TensorView* tv,
|
||||
IterDomain* axis) {
|
||||
getTvInfo(tv).double_buffer_axis = axis;
|
||||
|
||||
// Also validate the stage consistency with CA map.
|
||||
unsigned int stage_depth = 0;
|
||||
if (tv->isCircularBuffered()) {
|
||||
stage_depth = tv->circularBufferDepth();
|
||||
} else {
|
||||
// Double buffering is essentially
// circular buffering with depth 2.
|
||||
stage_depth = 2;
|
||||
}
|
||||
|
||||
// Set and validate the new stage depth.
|
||||
setStageDepth(axis, stage_depth);
|
||||
}
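Restating the depth selection above as a standalone helper (the function name is hypothetical; it only paraphrases the branch directly above):

```cpp
// Hypothetical helper: a circular-buffered tensor reports its own depth,
// while plain double buffering is treated as circular buffering with depth 2.
unsigned int stageDepthOf(const TensorView* tv) {
  return tv->isCircularBuffered() ? tv->circularBufferDepth() : 2;
}
```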
|
||||
|
||||
void DoubleBufferInfo::setStageDepth(IterDomain* id, unsigned int stage_depth) {
|
||||
auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID(
|
||||
id, IdMappingMode::LOOP);
|
||||
|
||||
auto maybe_exisiting_depth_it = stage_depth_.find(concrete_loop_id);
|
||||
if (maybe_exisiting_depth_it == stage_depth_.end()) {
|
||||
stage_depth_[concrete_loop_id] = stage_depth;
|
||||
} else {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
stage_depth == maybe_exisiting_depth_it->second,
|
||||
"Unsupported multiple depth pipelining, was set to ",
|
||||
maybe_exisiting_depth_it->second,
|
||||
" by ",
|
||||
maybe_exisiting_depth_it->first->toString(),
|
||||
" and then set to ",
|
||||
stage_depth,
|
||||
" by ",
|
||||
concrete_loop_id->toString());
|
||||
}
|
||||
}
|
||||
|
||||
IterDomain* DoubleBufferInfo::getDoubleBufferAxis(const TensorView* tv) {
|
||||
if (!tv->isDoubleBuffered()) {
|
||||
if (!(tv->isDoubleBuffered() || tv->isCircularBuffered())) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return getTvInfo(tv).double_buffer_axis;
|
||||
}
|
||||
|
||||
unsigned int DoubleBufferInfo::getStageDepthFor(
|
||||
IterDomain* double_buffer_axis) {
|
||||
auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID(
|
||||
double_buffer_axis, IdMappingMode::LOOP);
|
||||
|
||||
auto maybe_depth_it = stage_depth_.find(concrete_id);
|
||||
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
maybe_depth_it != stage_depth_.end(), "Stage depth not found");
|
||||
|
||||
return maybe_depth_it->second;
|
||||
}
|
||||
|
||||
kir::ForLoop* DoubleBufferInfo::getDoubleBufferLoop(
|
||||
IterDomain* axis,
|
||||
const std::vector<kir::ForLoop*>& loops,
|
||||
|
|
@ -582,7 +646,8 @@ kir::ForLoop* DoubleBufferInfo::getDoubleBufferLoop(
|
|||
auto loop_it = std::find_if(loops.begin(), loops.end(), [&](const auto loop) {
|
||||
return GpuLower::current()->caMap()->areMapped(
|
||||
loop->iter_domain(), axis, IdMappingMode::EXACT) &&
|
||||
(!ignore_prologue || !loop->stop()->isOneInt());
|
||||
(!ignore_prologue ||
|
||||
loop->doubleBufferLoopStage() != DoubleBufferLoopStage::Prolog);
|
||||
});
|
||||
|
||||
if (loop_it != loops.end()) {
|
||||
|
|
@ -612,7 +677,7 @@ void DoubleBufferInfo::setOriginalAllocSize(
|
|||
}
|
||||
|
||||
Val* DoubleBufferInfo::getOriginalAllocSize(const TensorView* tv) {
|
||||
if (!tv->isDoubleBuffered()) {
|
||||
if (!(tv->isDoubleBuffered() || tv->isCircularBuffered())) {
|
||||
return nullptr;
|
||||
}
@ -77,6 +77,85 @@
|
|||
// - Double the allocation size
|
||||
// - Omit the RAW sync in the Main and Epilogue loops
|
||||
|
||||
// [Circular buffer] A generalization of double buffering.
// On sm80+ hardware there is asynchronous copy infrastructure that
// motivates a circular buffering generalization of double buffering.
// Almost all analyses previously done for double buffering are exactly
// the same with circular buffering, except for the introduction of a
// new concept: `stage depth`.
//
// The `stage depth` is defined as the multiplier of extra buffering
// space used. In the case of double buffering, the stage depth would
// be 2.
//
// A circular buffered loop structure looks like the following, which
// exactly parallels the double buffered loop structure, since it is an
// exact generalization serving the same purpose.
|
||||
//
|
||||
// Here S is the original allocation size as above,
|
||||
// D is the stage depth. With D=2, the below loop structure becomes
|
||||
// exactly the same as the case in double buffering.
|
||||
//
|
||||
// allocate X[S*D] // allocation
|
||||
// for i in 0..D-1: // prolog
|
||||
// for j in ...
|
||||
// if pred:
|
||||
// x[i*S+j] = y[i, j];
|
||||
//
|
||||
// for i in 0..N: // main loop
|
||||
// for j in ...
|
||||
// if pred:
|
||||
// x[((i+D-1)%D)*S+j] = y[i+D-1, j];
|
||||
// for j in ...
|
||||
// .. = x[(i%D)*S+j]
|
||||
//
|
||||
// (Epilog omitted since this only makes sense when using
// cp.async, where the producer will be in global mem and the consumer will
// be in shared mem).
|
||||
//
|
||||
// The profitability of this optimization comes from extra tolerance
|
||||
// of global memory pipeline latency, as on the expression `.. = x[(i%D)*S+j]`
|
||||
// we only need to make sure the data for the current iteration is
|
||||
// completed, while the remaining D-2 load iterations could still be in progress
// and overlap with the computation of the current loop.
|
||||
//
|
||||
// To express this pattern on sm80+ hardware we can group the loads
|
||||
// in each iteration of the circular buffered loop as one "transaction",
|
||||
// and specify how many transactions we want to ensure completion when
|
||||
// we insert the async barriers.
|
||||
//
|
||||
// allocate X[S*D] // allocation
|
||||
// for i in 0..D-1: // prolog
|
||||
// for j in ...
|
||||
// if pred:
|
||||
// x[i*S+j] = y[i, j];
|
||||
// cp.async.commit; // mark the transaction boundary
|
||||
//
|
||||
// # At this point we have D-1 transactions on the fly,
//   and for the first iteration of the main loop we need
//   one transaction completed, so we leave D-2 transactions
//   on the fly, which would be the input to the barrier instruction.
|
||||
//
|
||||
// cp.async.wait D-2 // ensure all but the last D-2 transactions complete.
|
||||
//
|
||||
// for i in 0..N: // main loop
|
||||
// # At this point we always have D-2 transactions on the fly,
//   and one completed.
|
||||
// for j in ...
|
||||
// if pred:
|
||||
// x[((i+D-1)%D)*S+j] = y[i+D-1, j];
|
||||
// for j in ...
|
||||
// .. = x[(i%D)*S+j]
|
||||
// cp.async.commit; // mark the transaction boundary for the
|
||||
// load issued in this iteration.
|
||||
// # At this point we have D-1 transactions on the fly,
|
||||
// and none completed.
|
||||
// cp.async.wait D-2; // Ensure all but the last D-2 transactions complete.
|
||||
// __syncthreads(); // Need to __syncthreads() because each thread will only
//                     ensure completion of its own async copies, so
//                     we need to sync at this point to ensure
//                     completion of the whole tile.
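To make the schedule above concrete, here is a minimal host-side C++ simulation of the circular-buffered indexing for an arbitrary stage depth D. It is illustrative only: `y` is faked with made-up values, and the cp.async barriers are reduced to comments.

```cpp
#include <cstdio>
#include <vector>

int main() {
  constexpr int D = 4;  // stage depth (D = 2 degenerates to double buffering)
  constexpr int S = 8;  // original per-stage allocation size
  constexpr int N = 6;  // number of main-loop iterations
  std::vector<int> x(S * D, -1);  // allocate X[S*D]

  // Prolog: issue the first D-1 load stages.
  for (int i = 0; i < D - 1; ++i) {
    for (int j = 0; j < S; ++j) {
      x[i * S + j] = i * 100 + j;  // stands in for y[i, j]
    }
    // cp.async.commit would mark this stage as one transaction.
  }
  // cp.async.wait(D-2) would ensure the oldest transaction has completed here.

  // Main loop: load stage i+D-1 while consuming stage i.
  for (int i = 0; i < N; ++i) {
    for (int j = 0; j < S; ++j) {
      x[((i + D - 1) % D) * S + j] = (i + D - 1) * 100 + j;
    }
    for (int j = 0; j < S; ++j) {
      int consumed = x[(i % D) * S + j];  // .. = x[(i%D)*S+j]
      std::printf("iter %d consumes %d\n", i, consumed);
    }
    // cp.async.commit; cp.async.wait(D-2); __syncthreads();
  }
  return 0;
}
```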
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace fuser {
|
||||
|
|
@ -132,9 +211,23 @@ class TORCH_CUDA_CU_API DoubleBufferInfo {
|
|||
//! as a double buffer loop.
|
||||
bool isDoubleBufferedIterDomain(IterDomain* id);
|
||||
|
||||
//! Get the number of circular buffer stages for the given axis;
//! the number of stages is 2 in the case of a double buffered loop.
|
||||
unsigned int getStageDepthFor(IterDomain* circular_buffered_id);
|
||||
|
||||
private:
|
||||
TvInfo& getTvInfo(const TensorView* tv);
|
||||
|
||||
//! Set the number of circular buffer stages for the given
|
||||
//! circular_buffered_id.
|
||||
//! Current code generation only supports one stage depth per loop
//! disjoint set, so this function will throw an error when trying to
//! set different stage numbers on iterdomains that are loop mapped.
|
||||
void setStageDepth(
|
||||
IterDomain* circular_buffered_id,
|
||||
unsigned int stage_depth);
|
||||
|
||||
private:
|
||||
//! Keeps track of information for lowering double buffered tensors
|
||||
std::unordered_map<const TensorView*, TvInfo> map_;
|
||||
|
|
@ -142,6 +235,12 @@ class TORCH_CUDA_CU_API DoubleBufferInfo {
|
|||
//! Keeps track of which concrete loop map is realizing double buffer
|
||||
//! iterdomains.
|
||||
std::unordered_set<const IterDomain*> concrete_double_buffered_loop_id_;
|
||||
|
||||
//! Keeps track of double buffer loop stage depth.
|
||||
//! Currently, for each disjoint set of loop mapped iterdomains,
//! only one stage depth is supported, so that the loops can indeed
//! share the same prolog extent and main loop offset.
|
||||
std::unordered_map<IterDomain*, unsigned int> stage_depth_;
|
||||
};
|
||||
|
||||
} // namespace cuda
@ -252,8 +252,10 @@ class ExprSegmentationSorter {
|
|||
// Allocate an empty expr group and return it
|
||||
ExprGroup* makeEmptyGroup();
|
||||
|
||||
// Allocate an expr group with the provided expr and return it
|
||||
ExprGroup* makeEmptyGroup(Expr*);
|
||||
// Allocate an expr group with the provided expr and return it. Also requires
|
||||
// information on if this expression is a terminating expression (none of its
|
||||
// outputs are used in other expressions being sorted).
|
||||
ExprGroup* makeEmptyGroup(Expr*, bool terminating_expr);
|
||||
|
||||
// Returns if sg1 and sg2 should be merged together, is called if they can
|
||||
// based on the current status of the DAG.
|
||||
|
|
@ -538,14 +540,19 @@ ExprGroup* ExprSegmentationSorter::makeEmptyGroup() {
|
|||
return groups_.back().get();
|
||||
}
|
||||
|
||||
ExprGroup* ExprSegmentationSorter::makeEmptyGroup(Expr* expr) {
|
||||
ExprGroup* ExprSegmentationSorter::makeEmptyGroup(
|
||||
Expr* expr,
|
||||
bool terminating_expr) {
|
||||
auto group = makeEmptyGroup();
|
||||
group->exprs().push_back(expr);
|
||||
if (ir_utils::isTvOp(expr)) {
|
||||
auto out_tv = expr->outputs()[0]->as<TensorView>();
|
||||
// Grab all id's that are shared with other tensors.
|
||||
for (const auto tv_i : c10::irange(out_tv->getComputeAtPosition())) {
|
||||
group->payload()->ca_domains_.push_back(out_tv->axis(tv_i));
|
||||
// If not connected to consumers, it doesn't matter what compute-at is set to
|
||||
if (!terminating_expr) {
|
||||
for (const auto tv_i : c10::irange(out_tv->getComputeAtPosition())) {
|
||||
group->payload()->ca_domains_.push_back(out_tv->axis(tv_i));
|
||||
}
|
||||
}
|
||||
for (const auto tv_i : c10::irange(out_tv->getMaxProducerPosition())) {
|
||||
group->payload()->pa_domains_.push_back(out_tv->axis(tv_i));
|
||||
|
|
@ -1219,9 +1226,24 @@ void ExprSegmentationSorter::sort() {
|
|||
// Need this for initialization of the DAG that is processed
|
||||
std::unordered_map<Expr*, ExprGroup*> expr2group;
|
||||
|
||||
auto all_exprs = fusion_->exprs();
|
||||
|
||||
// Figure out all the values used as inputs to the expressions we're sorting
|
||||
// (to find terminating expressions). There could be branches of expressions
|
||||
// not used to produce outputs, so can't simply check val->uses() to figure
|
||||
// out if it's actually used in the expressions we're sorting.
|
||||
std::unordered_set<Val*> used_vals;
|
||||
for (auto expr : all_exprs) {
|
||||
used_vals.insert(expr->inputs().begin(), expr->inputs().end());
|
||||
}
|
||||
|
||||
// Initialize DAG, convert each expr to a segment group
|
||||
for (auto expr : fusion_->exprs()) {
|
||||
auto group = makeEmptyGroup(expr);
|
||||
for (auto expr : all_exprs) {
|
||||
bool is_terminating_expr = std::none_of(
|
||||
expr->outputs().begin(),
|
||||
expr->outputs().end(),
|
||||
[&used_vals](Val* output) { return used_vals.count(output) > 0; });
|
||||
auto group = makeEmptyGroup(expr, is_terminating_expr);
|
||||
expr2group.insert(std::make_pair(expr, group));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -834,6 +834,11 @@ void IndexLowering::handle(const kir::CpAsyncWait* wait) {
|
|||
pushBack(const_cast<kir::CpAsyncWait*>(wait)); // NOLINT
|
||||
}
|
||||
|
||||
void IndexLowering::handle(const kir::CpAsyncCommit* commit) {
|
||||
// TODO(kir): remove the need for const_cast
|
||||
pushBack(const_cast<kir::CpAsyncCommit*>(commit)); // NOLINT
|
||||
}
|
||||
|
||||
void IndexLowering::generate(const std::vector<Expr*>& exprs) {
|
||||
for (auto expr : exprs) {
|
||||
OptOutConstDispatch::handle(expr);
|
||||
|
|
|
|||
|
|
@ -55,6 +55,7 @@ class TORCH_CUDA_CU_API IndexLowering : private OptOutConstDispatch {
|
|||
void handle(const kir::BlockSync*) final;
|
||||
void handle(const kir::GridSync*) final;
|
||||
void handle(const kir::CpAsyncWait*) final;
|
||||
void handle(const kir::CpAsyncCommit*) final;
|
||||
|
||||
void generate(const std::vector<Expr*>& exprs);
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/lower_index_compute.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/lower_magic_zero.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/lower_validation.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/transform_iter.h>
|
||||
|
|
@ -24,39 +25,6 @@ IndexFromIdGraph::IndexFromIdGraph(
|
|||
|
||||
namespace {
|
||||
|
||||
void insertMagicZero(
|
||||
const std::vector<kir::ForLoop*>& loops,
|
||||
const std::vector<IterDomain*>& loop_domains,
|
||||
std::unordered_map<IterDomain*, Val*>& concrete_loop_idx_map) {
|
||||
// Find magic zero insertion point
|
||||
IterDomain* magic_zero_loop = nullptr;
|
||||
|
||||
// Search for proper magic zero insertion point,
|
||||
// prefer innermost.
|
||||
for (auto idx : c10::irange(loops.size())) {
|
||||
auto loop = loops[idx];
|
||||
auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID(
|
||||
loop_domains[idx], IdMappingMode::EXACT);
|
||||
auto loop_ind = concrete_loop_idx_map.at(concrete_loop_id);
|
||||
|
||||
// Save the concrete id if this loop id is decided to
|
||||
// be the insertion point by the magic zero util.
|
||||
if (Index::protectWithMagicZero(loop, concrete_loop_id, loop_ind)) {
|
||||
magic_zero_loop = concrete_loop_id;
|
||||
}
|
||||
}
|
||||
|
||||
// Insert magic zero if insertion point found
|
||||
if (magic_zero_loop != nullptr &&
|
||||
concrete_loop_idx_map.count(magic_zero_loop)) {
|
||||
auto& ind = concrete_loop_idx_map.at(magic_zero_loop);
|
||||
if (!ind->isConstScalar()) {
|
||||
ind = SimplifyingIrBuilder::addExpr(
|
||||
ind, GpuLower::current()->kernel()->magicZeroVal());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Maps all producer domains to consumer with broadcast
|
||||
// forwarding. Used to find the allocation position.
|
||||
// TODO: should this be an ir_util ? Didn't seem to be
|
||||
|
|
@ -159,7 +127,7 @@ IndexingParameters getGlobalIndexParameters(
|
|||
index_parameters.concrete_id_to_halo_extent =
|
||||
GpuLower::current()->haloInfo().buildConcreteHaloExtentMap(loop_indexing);
|
||||
|
||||
insertMagicZero(
|
||||
protectNonPredicateIndexWithMagicZero(
|
||||
loops,
|
||||
loop_indexing.loopDomains(),
|
||||
index_parameters.initial_concrete_id_index);
|
||||
|
|
@ -182,10 +150,13 @@ IndexingParameters getGlobalIndexParameters(
|
|||
|
||||
auto concrete_loop_id = ir_utils::caMapExactConcreteId(loop_id);
|
||||
|
||||
auto stage_depth =
|
||||
GpuLower::current()->doubleBufferInfo().getStageDepthFor(
|
||||
loop->iter_domain());
|
||||
index_parameters.initial_concrete_id_index[concrete_loop_id] =
|
||||
SimplifyingIrBuilder::addExpr(
|
||||
index_parameters.initial_concrete_id_index[concrete_loop_id],
|
||||
GpuLower::current()->kernel()->oneVal());
|
||||
SimplifyingIrBuilder::create<Int>(stage_depth - 1));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -412,9 +383,12 @@ IndexingParameters getPredicateInitialIndexParameters(
|
|||
// be true that that index has been modified to support
|
||||
// unswitch. In that case, it is not necessary to move ahead the
|
||||
// index for double buffering.
|
||||
auto stage_depth =
|
||||
GpuLower::current()->doubleBufferInfo().getStageDepthFor(
|
||||
db_loop->iter_domain());
|
||||
if (cur_index == db_loop->index()) {
|
||||
loop_to_ind_map[db_loop] = SimplifyingIrBuilder::addExpr(
|
||||
cur_index, GpuLower::current()->kernel()->oneVal());
|
||||
cur_index, SimplifyingIrBuilder::create<Int>(stage_depth - 1));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -428,10 +402,9 @@ IndexingParameters getPredicateInitialIndexParameters(
|
|||
loop_to_ind_map.at(loop);
|
||||
}
|
||||
|
||||
insertMagicZero(
|
||||
loops,
|
||||
loop_indexing.loopDomains(),
|
||||
index_parameters.initial_concrete_id_index);
|
||||
// Note that, unlike non-predicate indexing, magic-zero insertion is
|
||||
// not done at this point but is done individually for each indexed
|
||||
// domain. See Index::getReferenceRootPredicates.
|
||||
|
||||
// Derive the halo extents from the loop indexing result.
|
||||
index_parameters.concrete_id_to_halo_extent =
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <torch/csrc/jit/codegen/cuda/fusion.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/index_compute.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
|
|
|||
|
|
@ -124,13 +124,15 @@ CommonIndexKey::CommonIndexKey(
|
|||
if (it != concrete_leaf_ids.end()) {
|
||||
// This leaf reference id is used for indexing the consumer id
|
||||
used_loops_.push_back(loop);
|
||||
auto index_it =
|
||||
loop_index_map.find(gpu_lower->caMap()->getConcreteMappedID(
|
||||
loop_domains.at(i), IdMappingMode::EXACT));
|
||||
auto loop_concrete_id = gpu_lower->caMap()->getConcreteMappedID(
|
||||
loop_domains.at(i), IdMappingMode::EXACT);
|
||||
auto index_it = loop_index_map.find(loop_concrete_id);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
index_it != loop_index_map.end(),
|
||||
"Index not found for leaf ID, ",
|
||||
loop_domains.at(i)->toString());
|
||||
loop_domains.at(i)->toString(),
|
||||
", concrete ID: ",
|
||||
loop_concrete_id->toString());
|
||||
loop_index_vals_.push_back(index_it->second);
|
||||
}
|
||||
}
|
||||
|
|
@ -235,6 +237,7 @@ std::pair<Val*, bool> CommonIndexMap::insert(
|
|||
|
||||
const CommonIndexKey key(
|
||||
indexed_consumer_id, consumer_td, ref_td, ref_index_map, loops);
|
||||
|
||||
return tryInsertNewIndex(key, index);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -724,7 +724,7 @@ class ReadAfterWriteSyncs : public kir::ExprMutator {
|
|||
// here, except for the initial load part, which is taken care
|
||||
// separately by DoubleBufferInserter.
|
||||
if (tv->getMemoryType() == MemoryType::Shared &&
|
||||
!tv->isDoubleBuffered()) {
|
||||
!(tv->isDoubleBuffered() || tv->isCircularBuffered())) {
|
||||
smem[tv] = expr;
|
||||
|
||||
// only keep track of async writes in smem_async
|
||||
|
|
|
|||
|
|
@ -2,8 +2,10 @@
|
|||
|
||||
#include <torch/csrc/jit/codegen/cuda/dispatch.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/kernel_ir_dispatch.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/lower_index_compute.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
|
@ -88,6 +90,133 @@ bool isProtectedWithMagicZero(const Val* val) {
|
|||
return bop->getBinaryOpType() == BinaryOpType::Add && isMagicZero(bop->rhs());
|
||||
}
|
||||
|
||||
bool needsMagicZero(
|
||||
kir::ForLoop* loop,
|
||||
IterDomain* reference_domain,
|
||||
Val* ind) {
|
||||
if (ind->isConstScalar()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ref_dom_simple =
|
||||
reference_domain == nullptr || reference_domain->definition() != nullptr;
|
||||
bool ind_simple =
|
||||
ind == nullptr || (ind->definition() != nullptr && !ind->isZeroInt());
|
||||
return loop->isUnrolled() && (!ref_dom_simple || !ind_simple);
|
||||
}
|
||||
|
||||
void protectNonPredicateIndexWithMagicZero(
|
||||
const std::vector<kir::ForLoop*>& loops,
|
||||
const std::vector<IterDomain*>& loop_domains,
|
||||
std::unordered_map<IterDomain*, Val*>& concrete_loop_idx_map) {
|
||||
// Find magic zero insertion point
|
||||
IterDomain* magic_zero_loop = nullptr;
|
||||
|
||||
// Search for proper magic zero insertion point,
|
||||
// prefer innermost.
|
||||
for (auto idx : c10::irange(loops.size())) {
|
||||
auto loop = loops[idx];
|
||||
auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID(
|
||||
loop_domains[idx], IdMappingMode::EXACT);
|
||||
auto loop_ind = concrete_loop_idx_map.at(concrete_loop_id);
|
||||
|
||||
// Save the concrete id if this loop id is decided to
|
||||
// be the insertion point by the magic zero util.
|
||||
if (needsMagicZero(loop, concrete_loop_id, loop_ind)) {
|
||||
magic_zero_loop = concrete_loop_id;
|
||||
}
|
||||
}
|
||||
|
||||
// Insert magic zero if insertion point found
|
||||
if (magic_zero_loop != nullptr &&
|
||||
concrete_loop_idx_map.count(magic_zero_loop)) {
|
||||
auto& ind = concrete_loop_idx_map.at(magic_zero_loop);
|
||||
ind = SimplifyingIrBuilder::addExpr(
|
||||
ind, GpuLower::current()->kernel()->magicZeroVal());
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
//! Protect loop_index_to_protect appearing in overall_index_val
|
||||
IndexMagicZeroInfo protectIndexByReplacingLoopIndex(
|
||||
IterDomain* loop_id,
|
||||
Val* overall_index_val,
|
||||
Val* loop_index_to_protect) {
|
||||
auto protected_loop_index = SimplifyingIrBuilder::addExpr(
|
||||
loop_index_to_protect, GpuLower::current()->kernel()->magicZeroVal());
|
||||
|
||||
std::unordered_map<Val*, Val*> replacement_map;
|
||||
replacement_map[loop_index_to_protect] = protected_loop_index;
|
||||
|
||||
auto protected_index =
|
||||
ir_utils::replaceValInIndexVal(overall_index_val, replacement_map);
|
||||
|
||||
IndexMagicZeroInfo info;
|
||||
info.index = protected_index;
|
||||
info.original_loop_index = loop_index_to_protect;
|
||||
info.protected_loop_index = protected_loop_index;
|
||||
info.loop_id = loop_id;
|
||||
return info;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
IndexMagicZeroInfo protectPredicateIndexWithMagicZero(
|
||||
Val* index,
|
||||
const IndexFromIdGraph& id_graph,
|
||||
const std::vector<kir::ForLoop*>& loops) {
|
||||
// Gather the loop indices
|
||||
std::unordered_set<Val*> loop_indices;
|
||||
for (auto loop_id : id_graph.resolved_loop_domains) {
|
||||
auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID(
|
||||
loop_id, IdMappingMode::EXACT);
|
||||
auto index_it = id_graph.initial_concrete_index_map.find(concrete_loop_id);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
index_it != id_graph.initial_concrete_index_map.end(),
|
||||
"Index not found for loop: ",
|
||||
concrete_loop_id->toString());
|
||||
auto loop_index = index_it->second;
|
||||
loop_indices.insert(loop_index);
|
||||
}
|
||||
|
||||
// Figure out which loop indices are used in index
|
||||
const auto vals = DependencyCheck::getAllValsBetween(loop_indices, {index});
|
||||
|
||||
// Traverse from the inner-most loop and apply the magic-zero
// protection if needed.
|
||||
for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) {
|
||||
auto loop = loops.at(i);
|
||||
auto loop_id = id_graph.resolved_loop_domains.at(i);
|
||||
TORCH_INTERNAL_ASSERT(GpuLower::current()->caMap()->areMapped(
|
||||
loop_id, loop->iter_domain(), IdMappingMode::PERMISSIVE));
|
||||
IterDomain* concrete_loop_id =
|
||||
GpuLower::current()->caMap()->getConcreteMappedID(
|
||||
loop_id, IdMappingMode::EXACT);
|
||||
auto index_it = id_graph.initial_concrete_index_map.find(concrete_loop_id);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
index_it != id_graph.initial_concrete_index_map.end());
|
||||
auto loop_index = index_it->second;
|
||||
|
||||
const auto is_loop_index_used =
|
||||
std::find(vals.begin(), vals.end(), loop_index) != vals.end();
|
||||
|
||||
if (!is_loop_index_used) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (needsMagicZero(loop, concrete_loop_id, loop_index)) {
|
||||
return protectIndexByReplacingLoopIndex(loop_id, index, loop_index);
|
||||
}
|
||||
}
|
||||
|
||||
// No loop is identified to require protection with magic zero. Just
|
||||
// return the index argument as is
|
||||
IndexMagicZeroInfo not_proteced;
|
||||
not_proteced.index = index;
|
||||
return not_proteced;
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace fuser
|
||||
} // namespace jit
@ -10,6 +10,8 @@ namespace jit {
|
|||
namespace fuser {
|
||||
namespace cuda {
|
||||
|
||||
struct IndexFromIdGraph;
|
||||
|
||||
//! Insert the magic zero definition at the beginning of the kernel. Insert a
//! magic zero update after every (outer-most) loop nest with a compile-time extent.
|
||||
//!
|
||||
|
|
@ -17,13 +19,59 @@ namespace cuda {
|
|||
std::vector<Expr*> insertMagicZero(const std::vector<Expr*>& exprs);
|
||||
|
||||
//! Check if val is a reference to the magic zero variable
|
||||
bool isMagicZero(const Val* val);
|
||||
TORCH_CUDA_CU_API bool isMagicZero(const Val* val);
|
||||
|
||||
//! Check if val is protected with magic zero.
|
||||
//!
|
||||
//! Specifically, this returns true if val is defined as "x + magic_zero".
|
||||
bool isProtectedWithMagicZero(const Val* val);
|
||||
|
||||
// Determine if we may run into over-reuse of predicates or registers in the
// compiler. If the loop can be unrolled and the index and domain are not
// "simple", we likely want the loop protected.
|
||||
//
|
||||
// Magic zero protection should only be done for global memory and predicates.
|
||||
// We should avoid use on registers. Shared memory does not require it, but
|
||||
// likely wouldn't hurt.
|
||||
bool needsMagicZero(
|
||||
kir::ForLoop* loop,
|
||||
IterDomain* reference_domain = nullptr,
|
||||
Val* ind = nullptr);
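A standalone restatement of that decision rule, matching the `needsMagicZero` implementation earlier in this diff (plain booleans stand in for the real loop/domain/index checks; this is not the actual signature):

```cpp
// Sketch of the rule only: protect when the loop may be unrolled and either
// the reference domain or the index expression is not "simple".
bool wantsMagicZeroProtection(
    bool index_is_const,
    bool loop_unrolled,
    bool domain_simple,
    bool index_simple) {
  if (index_is_const) {
    return false;  // constant indices never need protection
  }
  return loop_unrolled && (!domain_simple || !index_simple);
}
```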
|
||||
|
||||
struct IndexMagicZeroInfo {
|
||||
//! Index that may be updated with magic zero
|
||||
Val* index = nullptr;
|
||||
//! Loop index that is protected by magic zero. nullptr if no loop
|
||||
//! is protected
|
||||
Val* original_loop_index = nullptr;
|
||||
//! Protected loop index. nullptr if no loop is protected
|
||||
Val* protected_loop_index = nullptr;
|
||||
//! Protected loop. nullptr if no loop is protected
|
||||
IterDomain* loop_id = nullptr;
|
||||
};
|
||||
|
||||
//! Protect an index val of an IterDomain with magic zero
|
||||
//!
|
||||
//! This should be only used for predicate indexing.
|
||||
//!
|
||||
//! No protection is done if none of the loops is determined to require
|
||||
//! protection by needsMagicZero.
|
||||
IndexMagicZeroInfo protectPredicateIndexWithMagicZero(
|
||||
Val* index,
|
||||
const IndexFromIdGraph& id_graph,
|
||||
const std::vector<kir::ForLoop*>& loops);
|
||||
|
||||
//! Protect an index val of a tensor with magic zero
|
||||
//!
|
||||
//! This should be only used for non-predicate indexing.
|
||||
//!
|
||||
//! No protection is done if none of the loops is determined to require
|
||||
//! protection by needsMagicZero.
|
||||
void protectNonPredicateIndexWithMagicZero(
|
||||
const std::vector<kir::ForLoop*>& loops,
|
||||
const std::vector<IterDomain*>& loop_domains,
|
||||
std::unordered_map<IterDomain*, Val*>& concrete_loop_idx_map);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace fuser
|
||||
} // namespace jit
@ -137,6 +137,16 @@ void SyncMap::build(Fusion* fusion) {
|
|||
->threadPredMap()
|
||||
.getPredicateInfo(producer)
|
||||
.redundant_types;
|
||||
// Get the parallel types that are inactive in consumer's use chains.
|
||||
auto producer_redundant_use_types = GpuLower::current()
|
||||
->threadPredMap()
|
||||
.getPredicateInfo(producer)
|
||||
.redundant_use_types;
|
||||
|
||||
// In the sync info pass we only consider the parallel types in the
// producer that are redundantly produced but not redundantly consumed.
|
||||
producer_redundant_types =
|
||||
producer_redundant_types & (~producer_redundant_use_types);
|
||||
|
||||
for (const auto producer_i : c10::irange(producer->nDims())) {
|
||||
auto producer_axis = producer->axis(producer_i);
|
||||
|
|
@ -205,25 +215,24 @@ void SyncMap::build(Fusion* fusion) {
|
|||
continue;
|
||||
}
|
||||
|
||||
auto parallel_type_i = getParallelTypeBitMapOffset(parallel_type);
|
||||
|
||||
auto p_id = producer_parallel_ids[parallel_type_i];
|
||||
auto c_id = consumer_parallel_ids[parallel_type_i];
|
||||
|
||||
// If consumer is parallelized with this type but producer is
|
||||
// predicated redundant on this type. This parallel dimension
|
||||
// is a RAW dimension. See test: FusionSeriaSmemWriteParallelRead1/2
|
||||
//
|
||||
// Even if consumer is not parallelized with this type, would still
|
||||
// need a raw sync unless all use chain of the producer end with an
|
||||
// output with the same redundant type.
|
||||
// TODO: need a separate pass to detect the case where no raw sync
|
||||
// is needed in this case, i.e. all use-def chains are redundant.
|
||||
// In the case when the parallel ids are mapped by the CA map,
// we will additionally need to consider whether the producer is
// a redundant write. The RAW dim can be skipped only if the
// consumer use chains contain only redundant uses.
// TODO:
//   still losing a bit of precision here for expr-ordering
//   sensitive cases, but we could wait until that becomes
//   a perf limiter to fix.
|
||||
if (producer_redundant_types.get(parallel_type)) {
|
||||
raw_dims.set(parallel_type);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto parallel_type_i = getParallelTypeBitMapOffset(parallel_type);
|
||||
|
||||
auto p_id = producer_parallel_ids[parallel_type_i];
|
||||
auto c_id = consumer_parallel_ids[parallel_type_i];
|
||||
|
||||
if (p_id == nullptr && c_id == nullptr) {
|
||||
continue;
|
||||
} else if (p_id != nullptr && c_id != nullptr) {
|
||||
|
|
|
|||
|
|
@ -210,11 +210,11 @@ void ThreadPredicateMap::updateBitSet(const Expr* expr) {
|
|||
|
||||
auto tv_inp = inp->as<TensorView>();
|
||||
|
||||
// Change for welford Op, we want the users of all outputs of welfordOp
|
||||
// to use a single predicate name.
|
||||
// If tv_inp was an output of a multi-output expression, just change it to a
|
||||
// consistent sibling to use a single predicate name.
|
||||
if (auto tv_def = tv_inp->definition()) {
|
||||
if (auto wop = dynamic_cast<WelfordOp*>(tv_def)) {
|
||||
tv_inp = wop->out()->as<TensorView>();
|
||||
if (tv_def->outputs().size() > 1) {
|
||||
tv_inp = ir_utils::getTvOutput(tv_def);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -285,6 +285,184 @@ void ThreadPredicateMap::updateBitSet(const Expr* expr) {
|
|||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
//! A simple backward data flow pass:
|
||||
//! This pass propagates information backward to annotate "redundant use
|
||||
//! chain"'s.
|
||||
//! The reason this is needed is that, for example, if we have a chain
//! of register-to-register ops that begins with a redundant shared mem write
//! and ends with an op that non-redundantly uses the result, we'd need to
//! insert a sync at the beginning of the register-to-register chain.
|
||||
//!
|
||||
//! The same mechanism also applies in the case of a register/sharedmem chain
|
||||
//! that starts and ends with global memory read/write.
|
||||
//!
|
||||
//! The propagation rule is summarized as follows:
|
||||
//!
|
||||
//! Shared TV val:
|
||||
//! Reset all block redundant info to its own redundant write info
|
||||
//! Backpropagate grid redundant info
|
||||
//! Global TV val:
|
||||
//! Reset all redundant info to its own redundant write info
|
||||
//! Local Tv val:
|
||||
//! Backpropagate all redundant info
|
||||
//! Exprs:
|
||||
//! Propagate redundant info backwards from outputs to inputs:
|
||||
//! For each parallel type,
|
||||
//! The parallel type is redundantly used in the expr input
|
||||
//! only if all of the outputs redundantly use the same type.
|
||||
class RedundantUseAnalysis : BackwardVisitor {
|
||||
public:
|
||||
RedundantUseAnalysis(Fusion* fusion, const ThreadPredicateMap& pred_map)
|
||||
: fusion_(fusion), pred_map_(pred_map) {
|
||||
traverseFrom(fusion, fusion->terminatingMathVals());
|
||||
}
|
||||
|
||||
//! Returns a bit map signifying the parallel dimensions
|
||||
//! on which the given tv is redundantly used. On these
|
||||
//! dimensions not all threads/blocks are required to
|
||||
//! hold valid value for their dependent computations.
|
||||
ParallelTypeBitmap getRedundantUseBitMap(const TensorView* tv) {
|
||||
// Since all tv's consumers are visited at this point, we
|
||||
// can aggregate the final redundant use info for this tv.
|
||||
if (fusion_->unordered_uses(tv).empty()) {
|
||||
// Base case, un-used is also not redundantly used
|
||||
return ParallelTypeBitmap();
|
||||
} else {
|
||||
// Aggregate redundant use as a conjunction of all
|
||||
// consumer's redundant consumer info propagated
|
||||
// backward from their consumer chains.
|
||||
ParallelTypeBitmap redundant_use;
|
||||
redundant_use.setAllBID();
|
||||
redundant_use.setAllTID();
|
||||
for (auto expr : fusion_->unordered_uses(tv)) {
|
||||
redundant_use &= redundant_expr_use_map_.at(expr);
|
||||
}
|
||||
|
||||
return redundant_use;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
using BackwardVisitor::handle;
|
||||
|
||||
void handle(TensorView* tv) final {
|
||||
auto redundant_tv_map = pred_map_.getPredicateInfo(tv).redundant_types;
|
||||
|
||||
// Setup the info to propagate backward for the producer tv's and
|
||||
// expressions.
|
||||
ParallelTypeBitmap& redundant_consumer_map =
|
||||
redundant_consumer_parallel_type_map_[tv];
|
||||
|
||||
// Initialize the use map to the redundant pred result
|
||||
redundant_consumer_map = redundant_tv_map;
|
||||
|
||||
if (tv->getMemoryType() == MemoryType::Shared) {
|
||||
backPropagateRedundantUse(
|
||||
redundant_consumer_map,
|
||||
tv,
|
||||
false, // no propagate TID redundant use for shared tv
|
||||
true // propagate BID redundant use
|
||||
);
|
||||
|
||||
} else if (tv->getMemoryType() == MemoryType::Local) {
|
||||
backPropagateRedundantUse(
|
||||
redundant_consumer_map,
|
||||
tv,
|
||||
true, // propagate TID redundant use
|
||||
true // propagate BID redundant use
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
void backPropagateRedundantUse(
|
||||
ParallelTypeBitmap& use_map,
|
||||
TensorView* tv,
|
||||
bool propagate_tid,
|
||||
bool propagate_bid) {
|
||||
// Clear the propagated part of the original result
|
||||
if (propagate_bid) {
|
||||
use_map.setAllBID();
|
||||
}
|
||||
if (propagate_tid) {
|
||||
use_map.setAllTID();
|
||||
}
|
||||
|
||||
for (auto expr : fusion_->unordered_uses(tv)) {
|
||||
// Assuming all consumer expressions have been
|
||||
// visited at this point since we are traversing
|
||||
// backward.
|
||||
auto expr_use_map = redundant_expr_use_map_.at(expr);
|
||||
// Clear the part of expression use map that does not
|
||||
// need to be propagated.
|
||||
if (!propagate_bid) {
|
||||
expr_use_map.setAllBID();
|
||||
}
|
||||
if (!propagate_tid) {
|
||||
expr_use_map.setAllTID();
|
||||
}
|
||||
|
||||
// Accumulate expression redundant usage
|
||||
// This implements the `only if all` part in
|
||||
// the discussion above.
|
||||
use_map &= expr_use_map;
|
||||
}
|
||||
}
|
||||
|
||||
void handle(Expr* expr) final {
|
||||
if (ir_utils::isTvOp(expr)) {
|
||||
// Initialize redundant info for current expr
|
||||
c10::optional<ParallelTypeBitmap> maybe_expr_pred_map;
|
||||
|
||||
for (auto consumer_tv :
|
||||
ir_utils::filterByType<TensorView>(expr->outputs())) {
|
||||
auto tv_redundant_bitmap =
|
||||
redundant_consumer_parallel_type_map_.at(consumer_tv);
|
||||
|
||||
if (maybe_expr_pred_map.has_value()) {
|
||||
// Accumulate redundant info of this tv output.
|
||||
maybe_expr_pred_map.value() &= tv_redundant_bitmap;
|
||||
} else {
|
||||
// Copy the tv's redundant info as the first valid case.
|
||||
maybe_expr_pred_map = tv_redundant_bitmap;
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
maybe_expr_pred_map.has_value(), "TV op not having a tv output");
|
||||
redundant_expr_use_map_[expr] = maybe_expr_pred_map.value();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// Populated redundant use information on the used tv's
|
||||
// This map provides information on if the given tv does not require
|
||||
// valid data from its producer on any parallel dimensions.
|
||||
// For example:
|
||||
// T1_local = T0_shared[...]
|
||||
// if(tid.x == 0)
|
||||
// T2_shared[...] = T1_local[...]
|
||||
// Then tidx would be a redundant consumer parallel type
// for T1, as T1 is a local tensor, and only threads satisfying
// tidx == 0 would need to provide valid data.
|
||||
// In this case, not all threads would need to read correct data
|
||||
// from T0_shared, which would help remove some sync's.
|
||||
std::unordered_map<const TensorView*, ParallelTypeBitmap>
|
||||
redundant_consumer_parallel_type_map_;
|
||||
|
||||
// Populated redundant use information on the used tv expressions.
|
||||
std::unordered_map<const Expr*, ParallelTypeBitmap> redundant_expr_use_map_;
|
||||
|
||||
// Short cut to the owning fusion of this analysis.
|
||||
Fusion* fusion_ = nullptr;
|
||||
|
||||
// Short cut to the active pred map analysis this pass is running as part of.
|
||||
const ThreadPredicateMap& pred_map_;
|
||||
};
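The "redundantly used by the expr only if all of its outputs redundantly use the same type" rule above boils down to a bitwise AND across the outputs' bitmaps. A minimal standalone sketch, with `std::bitset` standing in for `ParallelTypeBitmap`:

```cpp
#include <bitset>
#include <vector>

using Bitmap = std::bitset<6>;  // one bit per parallel type (TIDx..BIDz)

// A parallel type stays "redundantly used" by the expr only if every output
// reports it as redundantly used.
Bitmap redundantUseOfExpr(const std::vector<Bitmap>& output_redundant_use) {
  Bitmap acc;
  acc.set();  // start from all-redundant
  for (const auto& out : output_redundant_use) {
    acc &= out;
  }
  return acc;
}
```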
|
||||
|
||||
} // namespace
|
||||
|
||||
void ThreadPredicateMap::build(Fusion* fusion) {
|
||||
FUSER_PERF_SCOPE("GpuLower::Lower::ThreadPredicateMap");
|
||||
|
||||
|
|
@ -298,6 +476,15 @@ void ThreadPredicateMap::build(Fusion* fusion) {
|
|||
updateBitSet(expr);
|
||||
}
|
||||
updated_tvs_.clear();
|
||||
populateRedundantUseMap(fusion);
|
||||
}
|
||||
|
||||
void ThreadPredicateMap::populateRedundantUseMap(Fusion* fusion) {
|
||||
RedundantUseAnalysis redundant_use(fusion, *this);
|
||||
for (auto& it : thread_predicates_) {
|
||||
it.second.redundant_use_types =
|
||||
redundant_use.getRedundantUseBitMap(it.first);
|
||||
}
|
||||
}
|
||||
|
||||
ThreadPredicateMap::const_iterator ThreadPredicateMap::find(
|
||||
|
|
@ -399,6 +586,23 @@ ParallelTypeBitmap ThreadPredicateMap::getParallelBroadcastDomains(
|
|||
return parallel_broadcast & at(tv).limited_types;
|
||||
}
|
||||
|
||||
ParallelTypeBitmap ThreadPredicateMap::getRedundantConsumerType(
|
||||
Expr* expr) const {
|
||||
c10::optional<ParallelTypeBitmap> result;
|
||||
for (auto out_tv : ir_utils::filterByType<TensorView>(expr->outputs())) {
|
||||
auto out_tv_redundant_map = getPredicateInfo(out_tv).redundant_use_types;
|
||||
if (!result.has_value()) {
|
||||
result = out_tv_redundant_map;
|
||||
} else {
|
||||
result.value() &= out_tv_redundant_map;
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
result.has_value(), "ThreadPredicateMap : TV op assumed");
|
||||
return result.value();
|
||||
}
|
||||
|
||||
void ThreadPredicateMap::markAsUpdated(const TensorView* tv) {
|
||||
updated_tvs_.insert(tv);
|
||||
}
|
||||
|
|
@ -410,6 +614,7 @@ void ThreadPredicateMap::print() const {
|
|||
std::cout << "T" << kv.first->name();
|
||||
std::cout << " {" << kv.second.limited_types.toString() << "}\n";
|
||||
std::cout << "{" << kv.second.redundant_types.toString() << "}\n";
|
||||
std::cout << "{" << kv.second.redundant_use_types.toString() << "}\n";
|
||||
}
|
||||
std::cout << "--------------------------------\n\n";
|
||||
}
|
||||
|
|
|
|||
|
|
@ -48,9 +48,24 @@ class TORCH_CUDA_CU_API ThreadPredicateMap {
|
|||
ParallelTypeBitmap limited_types;
|
||||
// Parallel types where only one thread/block is enough.
|
||||
ParallelTypeBitmap redundant_types;
|
||||
// Tracking use chain of redundant writes:
|
||||
// [Redundant use chain]
|
||||
// a parallel type is a `redundant_consumer_type` only
|
||||
// if all of its propagation use chains terminate with
|
||||
// a redundant write of this type.
|
||||
// A propagation use chain is currently either a reg-to-reg
|
||||
// chain for a shared mem tv, or a reg/smem-to-reg/smem chain
|
||||
// for a global tv.
|
||||
// This is complementary information to `redundant_types`.
|
||||
// If a tensor view is redundantly written and not redundantly
// used by all consumers (see FusionRedundantPredSync3),
// a RAW sync will need to be inserted before reading
// this redundantly written tensor.
|
||||
ParallelTypeBitmap redundant_use_types;
|
||||
bool operator==(const PredicateInfo& other) const {
|
||||
return limited_types == other.limited_types &&
|
||||
redundant_types == other.redundant_types;
|
||||
redundant_types == other.redundant_types &&
|
||||
redundant_use_types == other.redundant_use_types;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -92,6 +107,9 @@ class TORCH_CUDA_CU_API ThreadPredicateMap {
|
|||
static Bool* getPredicateFromPredicateInfo(
|
||||
const ThreadPredicateMap::PredicateInfo& pred_info);
|
||||
|
||||
//! Get the redundant use types of the given expr, see [Redundant use chain]
|
||||
ParallelTypeBitmap getRedundantConsumerType(Expr* expr) const;
|
||||
|
||||
private:
|
||||
// Update the thread_predicates bitset based on provided Expr
|
||||
void updateBitSet(const Expr*);
|
||||
|
|
@ -111,6 +129,10 @@ class TORCH_CUDA_CU_API ThreadPredicateMap {
|
|||
//! Update a mapping
|
||||
bool update(const TensorView* tv, const PredicateInfo& pred_and_src);
|
||||
|
||||
//! Backward populate redundant use chain info once the redundant
|
||||
//! parallel writes have been identified.
|
||||
void populateRedundantUseMap(Fusion* fusion);
|
||||
|
||||
private:
|
||||
MapType thread_predicates_;
|
||||
//! Keep track of updated tensors that need predicates to be computed
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/lower_trivial_reductions.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
|
||||
|
||||
#include <unordered_set>
|
||||
|
|
@ -23,29 +24,39 @@ bool traverseToRFactorTensor(TensorView* tv, IterDomain* root_id) {
|
|||
TORCH_INTERNAL_ASSERT(
|
||||
root_id->definition() == nullptr, "Not root IterDomain: ", root_id);
|
||||
|
||||
if (tv->definition() == nullptr) {
|
||||
auto def = tv->definition();
|
||||
|
||||
if (def == nullptr) {
|
||||
// This is an input tensor, so no rfactor tensor to traverse.
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto& inputs = tv->definition()->inputs();
|
||||
|
||||
// Check the reduction expression that produces tv
|
||||
if (inputs.size() != 1 || !inputs[0]->isA<TensorView>() ||
|
||||
(tv->definition()->getExprType() != ExprType::ReductionOp &&
|
||||
tv->definition()->getExprType() != ExprType::WelfordOp)) {
|
||||
// No rfactor producer found
|
||||
if (!ir_utils::isReductionOp(def) || def->isA<MmaOp>()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto producer = inputs[0]->as<TensorView>();
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
def->inputs().size() == def->outputs().size(),
|
||||
"This logic block assumes number of inputs is the same as number of outputs of reduction ops.");
|
||||
|
||||
if (!producer->hasRFactor()) {
|
||||
// Reduction expr may have multiple inputs, just grab any TV
|
||||
// input. Note that in theory it is possible that a
|
||||
// GroupedReductionOp has rfactor inputs as well as non-rfactor
|
||||
// inputs, so grabbing the one that actually corresponds to tv can
|
||||
// be important. In reality, though, such a GroupedReductionOp
|
||||
// should not happen as we do not group reductions of rfactor and
|
||||
// non-rfactor tensors.
|
||||
auto producer_tv = ir_utils::getTvInput(def);
|
||||
|
||||
TORCH_INTERNAL_ASSERT(producer_tv != nullptr);
|
||||
|
||||
if (!producer_tv->hasRFactor()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto c2p = PairwiseRootDomainMap(producer, tv)
|
||||
.mapConsumerToProducer(tv->domain(), producer->domain());
|
||||
auto c2p = PairwiseRootDomainMap(producer_tv, tv)
|
||||
.mapConsumerToProducer(tv->domain(), producer_tv->domain());
|
||||
|
||||
auto producer_id_it = c2p.find(root_id);
|
||||
if (producer_id_it == c2p.end()) {
|
||||
|
|
@ -55,7 +66,7 @@ bool traverseToRFactorTensor(TensorView* tv, IterDomain* root_id) {
|
|||
|
||||
auto producer_root_id = producer_id_it->second;
|
||||
|
||||
return analyzeIfDerivedFromTrivialReduction(producer, producer_root_id);
|
||||
return analyzeIfDerivedFromTrivialReduction(producer_tv, producer_root_id);
|
||||
}
|
||||
|
||||
bool analyzeIfDerivedFromTrivialReduction(TensorView* tv, IterDomain* id) {
|
||||
|
|
@ -109,11 +120,6 @@ bool TrivialReductionInfo::isDerived(IterDomain* id) const {
|
|||
return domains_.find(id) != domains_.end();
|
||||
}
|
||||
|
||||
bool TrivialReductionInfo::isDerivedFromRoot(IterDomain* id) const {
|
||||
return domains_derived_from_root_.find(id) !=
|
||||
domains_derived_from_root_.end();
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace fuser
|
||||
} // namespace jit
|
||||
|
|
|
|||
|
|
@ -21,9 +21,6 @@ class TORCH_CUDA_CU_API TrivialReductionInfo {
|
|||
|
||||
bool isDerived(IterDomain* id) const;
|
||||
|
||||
// TODO: Not used, cleanup
|
||||
bool isDerivedFromRoot(IterDomain* id) const;
|
||||
|
||||
private:
|
||||
//! IterDomains that are derived only from trivial
|
||||
//! reductions. Included domains are not limited to reduction axes as
|
||||
|
|
|
|||
|
|
@ -183,14 +183,13 @@ TensorView* getTvOutput(const Expr* expr) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
bool isReductionOp(const Expr* expr) {
|
||||
// Note that GridReduction inherits ReductionOp
|
||||
return expr->isA<ReductionOp>() || expr->isA<GroupedReductionOp>() ||
|
||||
expr->isA<WelfordOp>() || expr->isA<kir::GridWelford>();
|
||||
}
|
||||
|
||||
bool isReductionTvOp(const Expr* expr) {
|
||||
return isTvOp(expr) && isReductionOp(expr);
|
||||
TensorView* getTvInput(const Expr* expr) {
|
||||
for (auto inp : expr->inputs()) {
|
||||
if (auto tv = getTv(inp)) {
|
||||
return tv;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool isScalarOp(const Expr* expr) {
|
||||
|
|
@ -483,7 +482,7 @@ BasicAllocInfo getAllocInformation(
|
|||
|
||||
// Allocation of a double buffered tensor is placed outside its
|
||||
// double buffer axis.
|
||||
if (tv->isDoubleBuffered() &&
|
||||
if ((tv->isDoubleBuffered() || tv->isCircularBuffered()) &&
|
||||
tv->axis(info.alloc_pos) ==
|
||||
gpu_lower->doubleBufferInfo().getDoubleBufferAxis(tv)) {
|
||||
outer_alloc_found = true;
|
||||
|
|
|
|||
|
|
@ -79,11 +79,8 @@ TORCH_CUDA_CU_API bool isTvOp(const Expr*);
|
|||
// Returns the first output of Expr that is a TensorView
|
||||
TORCH_CUDA_CU_API TensorView* getTvOutput(const Expr*);
|
||||
|
||||
// Returns if Expr is a reduction op
|
||||
TORCH_CUDA_CU_API bool isReductionOp(const Expr*);
|
||||
|
||||
// Returns if Expr is a reduction op with TensorView or TensorIndex
|
||||
TORCH_CUDA_CU_API bool isReductionTvOp(const Expr*);
|
||||
// Returns the first input of Expr that is a TensorView
|
||||
TORCH_CUDA_CU_API TensorView* getTvInput(const Expr*);
|
||||
|
||||
bool hasBlockSync(const Expr* expr, const ThreadPredicateMap& pred_map);
|
||||
|
||||
|
|
|
|||
|
|
@ -899,22 +899,42 @@ void validateMmaTensors(MmaOp* mma) {
|
|||
}
|
||||
|
||||
// Note: this check will be relaxed in a follow up.
|
||||
auto validate_operand_ids = [](const TensorView* tv) {
|
||||
auto validate_operand = [](const TensorView* tv) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
tv->getMemoryType() == MemoryType::Local,
|
||||
"Only supporting register input for mma ops, up to sm80 all mma ops have to take register inputs.");
|
||||
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
std::all_of(
|
||||
tv->domain()->domain().begin() + tv->getComputeAtPosition(),
|
||||
tv->domain()->domain().end(),
|
||||
[](IterDomain* id) {
|
||||
return id->isMmaSwizzled() ||
|
||||
(id->isBroadcast() &&
|
||||
// MMA instructions can only take inputs from registers,
|
||||
// so we always assume mma op inputs are located on
|
||||
// registers.
|
||||
// Currently requiring that serial ids on the right of the
|
||||
// CA axis are constant sized to ensure early detection of
|
||||
// invalid mma schedules.
|
||||
((id->isBroadcast() || id->extent()->isConstInt()) &&
|
||||
id->getParallelType() == ParallelType::Serial);
|
||||
}),
|
||||
"All id's on the right of CA pos needs to be mma-swizzled by WarpMmaSwizzler\n",
|
||||
tv);
|
||||
};
|
||||
|
||||
validate_operand_ids(mma->inA()->as<TensorView>());
|
||||
validate_operand_ids(mma->inB()->as<TensorView>());
|
||||
validate_operand(mma->inA()->as<TensorView>());
|
||||
validate_operand(mma->inB()->as<TensorView>());
|
||||
|
||||
// Additionally validate that mma is not directly taking a double buffered
|
||||
// register input as the double buffer indexing is currently not compatible
|
||||
// with fragment iteration. Would need to require a cache stage in this case.
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
!mma->inA()->as<TensorView>()->isDoubleBuffered(),
|
||||
"MMA op cannot directly take double buffered register input, put a set stage before.");
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
!mma->inB()->as<TensorView>()->isDoubleBuffered(),
|
||||
"MMA op cannot directly take double buffered register input, put a set stage before.");
|
||||
}
|
||||
|
||||
//! Note and TODO:
|
||||
|
|
@ -1011,6 +1031,7 @@ void validateMma(Fusion* fusion) {
|
|||
validateMinimumArch(7, 0);
|
||||
break;
|
||||
case MmaOptions::MacroType::Turing_16_8_16:
|
||||
case MmaOptions::MacroType::Turing_16_16_16:
|
||||
validateMinimumArch(7, 5);
|
||||
|
||||
// Check that operands come from ldmatrix, can be
|
||||
|
|
@ -1019,6 +1040,7 @@ void validateMma(Fusion* fusion) {
|
|||
validateTuringMmaInput(mma->inB()->as<TensorView>());
|
||||
break;
|
||||
case MmaOptions::MacroType::Ampere_16_8_16:
|
||||
case MmaOptions::MacroType::Ampere_16_16_16:
|
||||
validateMinimumArch(8, 0);
|
||||
|
||||
// Check that operands come from ldmatrix, can be
|
||||
|
|
@ -1037,17 +1059,76 @@ void validateMma(Fusion* fusion) {
|
|||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Utility function to validate a loop swizzle:
|
||||
// 1. Throws an error if any output of the swizzle is not in leaf_domain set.
|
||||
// 2. Warns if any output of the swizzle is not the concrete id of the loop
|
||||
// map.
|
||||
// The second case would make the codegen ignore this swizzle, as if it was not
|
||||
// there at all.
|
||||
void validateLoopSwizzle(
|
||||
Expr* swizzle_expr,
|
||||
std::unordered_set<IterDomain*>& leaf_domains) {
|
||||
for (auto out_id :
|
||||
ir_utils::filterByType<IterDomain>(swizzle_expr->outputs())) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
leaf_domains.count(out_id),
|
||||
"Loop swizzle can only be direct producer of leaf domains.");
|
||||
if (GpuLower::current()->caMap()->getConcreteMappedID(
|
||||
out_id, IdMappingMode::LOOP) != out_id) {
|
||||
TORCH_WARN_ONCE("Ignored loop swizzle :", swizzle_expr->toString());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void validateSwizzle(Fusion* fusion) {
|
||||
auto used_vals = fusion->usedMathVals();
|
||||
for (auto tv : ir_utils::filterByType<TensorView>(used_vals)) {
|
||||
if (tv->hasSwizzleOp()) {
|
||||
std::unordered_set<IterDomain*> tv_leaf_domain_set(
|
||||
tv->domain()->domain().begin(), tv->domain()->domain().end());
|
||||
|
||||
// Make sure no swizzle op is inlined:
|
||||
auto inlined_swizzles = ir_utils::getAllSwizzlesBetween(
|
||||
tv->getMaybeRFactorDomain(),
|
||||
{tv->domain()->domain().begin(),
|
||||
tv->domain()->domain().begin() + tv->getComputeAtPosition()});
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
inlined_swizzles.empty(), "No support for inlined swizzles");
|
||||
|
||||
auto not_inlined_swizzles = ir_utils::getAllSwizzlesBetween(
|
||||
tv->getMaybeRFactorDomain(),
|
||||
{tv->domain()->domain().begin() + tv->getComputeAtPosition(),
|
||||
tv->domain()->domain().end()});
|
||||
|
||||
// Check inlined swizzles: only loop swizzles can be inlined currently
|
||||
// as inlining data swizzles would require additional support for an unswizzle
// operator, which currently doesn't have important use cases.
|
||||
for (auto swizzle_expr : inlined_swizzles) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
swizzle_expr->as<Swizzle2D>()->swizzleMode() == SwizzleMode::Loop,
|
||||
"Only support inlining loop swizzles");
|
||||
validateLoopSwizzle(swizzle_expr, tv_leaf_domain_set);
|
||||
}
|
||||
|
||||
std::unordered_set<Expr*> inlined_swizzle_set(
|
||||
inlined_swizzles.begin(), inlined_swizzles.end());
|
||||
|
||||
// Check not inlined swizzles:
|
||||
// Apply the loop swizzle check when it applies, and
|
||||
// also make sure that the swizzle is not also in the inlined_swizzle set.
|
||||
// The latter would mean that one output of the swizzle is inlined while
|
||||
// the other is not. Such case will not be supported.
|
||||
for (auto swizzle_expr : not_inlined_swizzles) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
!inlined_swizzle_set.count(swizzle_expr),
|
||||
"Cannot partially inline across swizzle domains.",
|
||||
swizzle_expr->toString());
|
||||
if (swizzle_expr->as<Swizzle2D>()->swizzleMode() == SwizzleMode::Loop) {
|
||||
validateLoopSwizzle(swizzle_expr, tv_leaf_domain_set);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -135,6 +135,7 @@ void MaxInfoSpanningTree::traverse(Propagator* propagator) {
|
|||
if (path_.empty()) {
|
||||
compute_spanning_tree();
|
||||
}
|
||||
propagator->setUp();
|
||||
for (const auto& next_hop : path_) {
|
||||
switch (next_hop.type) {
|
||||
case NextHopType::SIBLING:
|
||||
|
|
@ -148,6 +149,7 @@ void MaxInfoSpanningTree::traverse(Propagator* propagator) {
|
|||
break;
|
||||
}
|
||||
}
|
||||
propagator->tearDown();
|
||||
}
|
||||
|
||||
MaxRootDomainInfoSpanningTree::RootDomainInfo::operator bool() const {
|
||||
|
|
@ -422,6 +424,18 @@ void SpanningTreePrinter::propagateSibling(TensorView* from, TensorView* to) {
|
|||
stream_ << " to: " << to->toString() << std::endl;
|
||||
}
|
||||
|
||||
bool SetSelector::allowC2P(TensorView* from, TensorView* to) {
|
||||
return selected_.count(to) > 0;
|
||||
}
|
||||
|
||||
bool SetSelector::allowP2C(TensorView* from, TensorView* to) {
|
||||
return selected_.count(to) > 0;
|
||||
}
|
||||
|
||||
bool SetSelector::allowSibling(TensorView* from, TensorView* to) {
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace fuser
|
||||
} // namespace jit
|
||||
|
|
|
|||
|
|
@ -47,6 +47,8 @@ class TORCH_CUDA_CU_API MaxInfoSpanningTree {
|
|||
|
||||
// This is the interface to implement the actual propagation
|
||||
struct Propagator {
|
||||
virtual void setUp() {}
|
||||
virtual void tearDown() {}
|
||||
virtual void propagateC2P(TensorView* from, TensorView* to) = 0;
|
||||
virtual void propagateP2C(TensorView* from, TensorView* to) = 0;
|
||||
virtual void propagateSibling(TensorView* from, TensorView* to) = 0;
|
||||
|
|
@@ -254,6 +256,25 @@ class TORCH_CUDA_CU_API SpanningTreePrinter
SpanningTreePrinter(std::ostream& stream = std::cout) : stream_(stream) {}
};

// Simple selector for selecting subgraphs to build spanning trees. The selector
// allows propagation only to the given set of selected tensorviews, except for
// sibling propagation, which we should never block.
class TORCH_CUDA_CU_API SetSelector : public MaxInfoSpanningTree::Selector {
std::unordered_set<TensorView*> selected_;

public:
virtual bool allowC2P(TensorView* from, TensorView* to) override;
virtual bool allowP2C(TensorView* from, TensorView* to) override;
virtual bool allowSibling(TensorView* from, TensorView* to) override;

SetSelector(std::unordered_set<TensorView*> selected)
: selected_(std::move(selected)) {}

const std::unordered_set<TensorView*>& selected() const {
return selected_;
}
};

} // namespace cuda
} // namespace fuser
} // namespace jit

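A minimal sketch of how the new SetSelector could be wired into a spanning-tree traversal. The MaxRootDomainInfoSpanningTree constructor taking a Selector* and the TransformPropagator usage are assumptions based on the surrounding API; tv1, tv2, tv3 and reference_tv are placeholder TensorViews.

```cpp
// Restrict C2P/P2C propagation to an explicit subset of tensors (assumed API).
std::unordered_set<TensorView*> allowed{tv1, tv2, tv3};
SetSelector selector(allowed);

// Build the spanning tree from a reference tensor, but only across `allowed`,
// then replay the reference transforms over that subgraph.
MaxRootDomainInfoSpanningTree tree(reference_tv, &selector);
TransformPropagator propagator(reference_tv);
tree.traverse(&propagator);
```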
@ -7,6 +7,16 @@ namespace jit {
|
|||
namespace fuser {
|
||||
namespace cuda {
|
||||
|
||||
MmaOp* MmaOptions::mmaOp() const {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
accumulator_tv != nullptr && accumulator_tv->definition() != nullptr,
|
||||
"Invalid accumulator_tv.");
|
||||
auto mma_op = dynamic_cast<MmaOp*>(accumulator_tv->definition());
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
mma_op != nullptr, "accumulator tv not an output of mma op");
|
||||
return mma_op;
|
||||
}
|
||||
|
||||
MmaBuilder::MmaBuilder(
|
||||
MmaOptions::MacroType macro,
|
||||
MatMulTileOptions gemm_tile) {
|
||||
|
|
@ -22,6 +32,10 @@ MmaBuilder::MmaBuilder(
|
|||
case MmaOptions::MacroType::Ampere_16_8_16:
|
||||
option_.accumulator_stride = outer_stride * 2;
|
||||
break;
|
||||
case MmaOptions::MacroType::Ampere_16_16_16:
|
||||
case MmaOptions::MacroType::Turing_16_16_16:
|
||||
option_.accumulator_stride = outer_stride * 4;
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "unsupported macro");
|
||||
break;
|
||||
|
|
@ -41,7 +55,7 @@ MmaBuilder& MmaBuilder::operand(MmaOptions::Operand a_or_b) {
|
|||
// TODO: validate op config
|
||||
MmaOptions MmaBuilder::build() const {
|
||||
TORCH_CHECK(
|
||||
option_.mma_op != nullptr,
|
||||
option_.accumulator_tv != nullptr,
|
||||
"Please configure accumulator tv before using swizzle options.")
|
||||
return option_;
|
||||
}
|
||||
|
|
@ -60,9 +74,10 @@ void MmaBuilder::accumulatorTv(TensorView* tv) {
|
|||
TORCH_CHECK(
|
||||
tv->getMemoryType() == MemoryType::Local, "Mma only outputs to register");
|
||||
TORCH_CHECK(tv->definition(), "Input cannot be accumulator tv");
|
||||
auto mma = dynamic_cast<MmaOp*>(tv->definition());
|
||||
TORCH_CHECK(mma, "Requires mma op output for reduction tv");
|
||||
option_.mma_op = mma;
|
||||
TORCH_CHECK(
|
||||
tv->definition()->isA<MmaOp>(),
|
||||
"Requires mma op output for reduction tv");
|
||||
option_.accumulator_tv = tv;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
|
@ -73,6 +88,8 @@ LoadStoreOpType getLdMatrixType(MmaOptions options) {
|
|||
switch (options.macro) {
|
||||
case MmaOptions::MacroType::Turing_16_8_16:
|
||||
case MmaOptions::MacroType::Ampere_16_8_16:
|
||||
case MmaOptions::MacroType::Ampere_16_16_16:
|
||||
case MmaOptions::MacroType::Turing_16_16_16:
|
||||
// Turing mma assumes TN as default
|
||||
transpose = (options.operand == MmaOptions::Operand::A &&
|
||||
!isOperandTransposed(options)) ||
|
||||
|
|
@ -98,16 +115,20 @@ bool isVolta(MmaOptions::MacroType macro) {
|
|||
}
|
||||
|
||||
bool isTuring(MmaOptions::MacroType macro) {
|
||||
return macro == MmaOptions::MacroType::Turing_16_8_16;
|
||||
return macro == MmaOptions::MacroType::Turing_16_8_16 ||
|
||||
macro == MmaOptions::MacroType::Turing_16_16_16;
|
||||
}
|
||||
|
||||
bool isAmpere(MmaOptions::MacroType macro) {
|
||||
return macro == MmaOptions::MacroType::Ampere_16_8_16;
|
||||
return macro == MmaOptions::MacroType::Ampere_16_8_16 ||
|
||||
macro == MmaOptions::MacroType::Ampere_16_16_16;
|
||||
}
|
||||
|
||||
int getOutputRegisterSize(MmaOptions::MacroType macro) {
|
||||
switch (macro) {
|
||||
case MmaOptions::MacroType::Volta_16_16_4:
|
||||
case MmaOptions::MacroType::Ampere_16_16_16:
|
||||
case MmaOptions::MacroType::Turing_16_16_16:
|
||||
return 8;
|
||||
break;
|
||||
case MmaOptions::MacroType::Turing_16_8_16:
|
||||
|
|
@ -127,7 +148,9 @@ int getInputARegisterSize(MmaOptions::MacroType macro) {
|
|||
return 4;
|
||||
break;
|
||||
case MmaOptions::MacroType::Turing_16_8_16:
|
||||
case MmaOptions::MacroType::Turing_16_16_16:
|
||||
case MmaOptions::MacroType::Ampere_16_8_16:
|
||||
case MmaOptions::MacroType::Ampere_16_16_16:
|
||||
return 8;
|
||||
break;
|
||||
default:
|
||||
|
|
@ -145,6 +168,9 @@ int getInputBRegisterSize(MmaOptions::MacroType macro) {
|
|||
case MmaOptions::MacroType::Turing_16_8_16:
|
||||
case MmaOptions::MacroType::Ampere_16_8_16:
|
||||
return 4;
|
||||
case MmaOptions::MacroType::Turing_16_16_16:
|
||||
case MmaOptions::MacroType::Ampere_16_16_16:
|
||||
return 8;
|
||||
default:
|
||||
TORCH_INTERNAL_ASSERT(false, "unknown macro");
|
||||
break;
|
||||
|
|
@ -197,6 +223,10 @@ std::string toString(MmaOptions::MacroType mt) {
|
|||
case MmaOptions::MacroType::Ampere_16_8_16:
|
||||
ss << "M16N8K16";
|
||||
break;
|
||||
case MmaOptions::MacroType::Turing_16_16_16:
|
||||
case MmaOptions::MacroType::Ampere_16_16_16:
|
||||
ss << "M16N16K16";
|
||||
break;
|
||||
default:
|
||||
TORCH_INTERNAL_ASSERT(false, "undefined mma type");
|
||||
break;
|
||||
|
|
|
|||
|
|
@ -19,6 +19,10 @@ struct GemmTile {
|
|||
GemmTile operator/(const GemmTile& other) {
|
||||
return GemmTile(m / other.m, n / other.n, k / other.k);
|
||||
}
|
||||
|
||||
std::vector<int> toVector() {
|
||||
return {m, n, k};
|
||||
}
|
||||
};
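A small hedged example of the GemmTile arithmetic above; the three-argument GemmTile constructor is implied by operator/ and the tile values are illustrative only.

```cpp
// Component-wise division gives the number of warp tiles per CTA tile.
GemmTile cta_tile(128, 128, 32);
GemmTile warp_tile(64, 64, 32);
GemmTile warps_per_cta = cta_tile / warp_tile;   // {2, 2, 1}
std::vector<int> v = warps_per_cta.toVector();   // {2, 2, 1}
```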
|
||||
|
||||
//! Utility data structure for recording gemm tiles
|
||||
|
|
@ -58,7 +62,9 @@ struct MmaOptions {
|
|||
NoMMA = 0,
|
||||
Volta_16_16_4,
|
||||
Ampere_16_8_16,
|
||||
Ampere_16_16_16,
|
||||
Turing_16_8_16,
|
||||
Turing_16_16_16,
|
||||
Ampere_16_8_8 // place holder for tf32
|
||||
};
|
||||
|
||||
|
|
@ -95,8 +101,18 @@ struct MmaOptions {
|
|||
accumulator_stride == other.accumulator_stride;
|
||||
}
|
||||
|
||||
// To be inferred by mma builder interface.
|
||||
MmaOp* mma_op = nullptr;
|
||||
// The accumulator tensorview register supplied by the
|
||||
// scheduler interface. Each mma builder is responsible
|
||||
// for the parameters of one mma op, so the options struct
|
||||
// would need a pointer to keep track of which mma op it
|
||||
// is describing.
|
||||
// Tracking mma expressions would not be stable as the expression
|
||||
// can get deleted by mutate passes.
|
||||
TensorView* accumulator_tv = nullptr;
|
||||
|
||||
//! Returns the mma op that this options parameter list
|
||||
//! is describing. See comment on accumulator_tv.
|
||||
MmaOp* mmaOp() const;
|
||||
};
|
||||
|
||||
//! User interface for configuring the mma and mma related
|
||||
|
|
|
|||
|
|
@ -477,6 +477,9 @@ void OptOutMutator::mutate(kir::GridSync*) {
|
|||
void OptOutMutator::mutate(kir::CpAsyncWait*) {
|
||||
TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
|
||||
}
|
||||
void OptOutMutator::mutate(kir::CpAsyncCommit*) {
|
||||
TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
|
||||
}
|
||||
void OptOutMutator::mutate(kir::InitMagicZero*) {
|
||||
TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
|
||||
}
|
||||
|
|
|
|||
|
|
@@ -184,6 +184,13 @@ TensorView* tanh_backward(TensorView* dy, TensorView* tanh_x) {
return dx;
}

TensorView* leaky_relu(TensorView* x, Val* negative_slope) {
TORCH_INTERNAL_ASSERT(x != nullptr, "input is invalid.");
TORCH_INTERNAL_ASSERT(negative_slope != nullptr, "negative_slope is invalid");
auto zero = IrBuilder::create<Double>(x->container(), 0.);
return where(ge(x, zero), x, mul(negative_slope, x));
}

TensorView* view_as_real(TensorView* x) {
auto input_type = x->getDataType().value();
TORCH_CHECK(

@@ -52,6 +52,7 @@ TORCH_CUDA_CU_API TensorView* gelu_backward(TensorView* dy, TensorView* x);
TORCH_CUDA_CU_API TensorView* tanh_gelu(TensorView* x);
TORCH_CUDA_CU_API TensorView* tanh_gelu_backward(TensorView* dy, TensorView* x);
TORCH_CUDA_CU_API TensorView* tanh_backward(TensorView* dy, TensorView* tanh_x);
TORCH_CUDA_CU_API TensorView* leaky_relu(TensorView* x, Val* negative_slope);

TORCH_CUDA_CU_API TensorView* view_as_real(TensorView* x);

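A hedged sketch of how the new leaky_relu composite might be exercised from the fusion definition API; Fusion, FusionGuard, TensorViewBuilder and IrBuilder appear elsewhere in this diff, but the exact builder options chosen here are illustrative assumptions.

```cpp
// Sketch only: define a 2-D float fusion that applies the new composite.
Fusion fusion;
FusionGuard fg(&fusion);

auto x = TensorViewBuilder().ndims(2).dtype(DataType::Float).build();
fusion.addInput(x);

// 0.01 mirrors the aten::leaky_relu default registered in the parser rule.
auto slope = IrBuilder::create<Double>(0.01);
auto y = leaky_relu(x, slope);
fusion.addOutput(y);
```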
@ -529,8 +529,8 @@ ForwardNormResult batch_norm(
|
|||
auto invstd_bcast = broadcast(unbiased_invstd, broadcast_mask);
|
||||
|
||||
// During inference, mean/invstd output are empty tensors
|
||||
mean = TensorViewBuilder().shape({0}).build();
|
||||
invstd = TensorViewBuilder().shape({0}).build();
|
||||
mean = TensorViewBuilder().shape(std::vector<int64_t>{0}).build();
|
||||
invstd = TensorViewBuilder().shape(std::vector<int64_t>{0}).build();
|
||||
y = mul(x_sub_mean, invstd_bcast);
|
||||
}
|
||||
|
||||
|
|
@ -782,8 +782,8 @@ ForwardNormResult instance_norm(
|
|||
broadcast(unbiased_invstd, channels_only_broadcast_mask);
|
||||
|
||||
// During inference, mean/invstd output are empty tensors
|
||||
mean = TensorViewBuilder().shape({0}).build();
|
||||
invstd = TensorViewBuilder().shape({0}).build();
|
||||
mean = TensorViewBuilder().shape(std::vector<int64_t>{0}).build();
|
||||
invstd = TensorViewBuilder().shape(std::vector<int64_t>{0}).build();
|
||||
y = mul(x_sub_mean, invstd_bcast);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2909,6 +2909,34 @@ class IrParser {
|
|||
nullptr);
|
||||
}
|
||||
|
||||
{
|
||||
auto ptr_op = getOperatorForLiteral(
|
||||
"aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor");
|
||||
REGISTER_PARSE_RULE(
|
||||
ptr_op,
|
||||
{
|
||||
MemoryFormat format;
|
||||
std::list<Val*> list_val;
|
||||
std::tie(format, list_val) = getConsistentValues(
|
||||
c10::nullopt, value_map[node->inputs()[0]->unique()]);
|
||||
auto self = list_val.front()->as<TensorView>();
|
||||
list_val.pop_front();
|
||||
|
||||
Val* negative_slope = value_map[node->inputs()[1]->unique()];
|
||||
|
||||
auto out = leaky_relu(self, negative_slope);
|
||||
value_map.emplace(
|
||||
node->output()->unique(), ValueHolder(out, format));
|
||||
},
|
||||
[](const Node* node) -> bool {
|
||||
if (!isInputNonSizeZeroTensor(node)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
},
|
||||
nullptr);
|
||||
}
|
||||
|
||||
{
|
||||
auto ptr_op = getOperatorForLiteral(
|
||||
"aten::gelu(Tensor self, *, str approximate='none') -> Tensor");
|
||||
|
|
|
|||
|
|
@ -528,3 +528,105 @@ __device__ inline int64_t readCycleCounter() {
|
|||
__threadfence();
|
||||
return clock64();
|
||||
}
|
||||
|
||||
__device__ float print_impl(const char* name, float value) {
|
||||
printf(
|
||||
"%s = %f @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n",
|
||||
name,
|
||||
value,
|
||||
(int)threadIdx.x,
|
||||
(int)threadIdx.y,
|
||||
(int)threadIdx.z,
|
||||
(int)blockIdx.x,
|
||||
(int)blockIdx.y,
|
||||
(int)blockIdx.z);
|
||||
return value;
|
||||
}
|
||||
|
||||
__device__ double print_impl(const char* name, double value) {
|
||||
printf(
|
||||
"%s = %lf @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n",
|
||||
name,
|
||||
value,
|
||||
(int)threadIdx.x,
|
||||
(int)threadIdx.y,
|
||||
(int)threadIdx.z,
|
||||
(int)blockIdx.x,
|
||||
(int)blockIdx.y,
|
||||
(int)blockIdx.z);
|
||||
return value;
|
||||
}
|
||||
|
||||
__device__ int print_impl(const char* name, int value) {
|
||||
printf(
|
||||
"%s = %d @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n",
|
||||
name,
|
||||
value,
|
||||
(int)threadIdx.x,
|
||||
(int)threadIdx.y,
|
||||
(int)threadIdx.z,
|
||||
(int)blockIdx.x,
|
||||
(int)blockIdx.y,
|
||||
(int)blockIdx.z);
|
||||
return value;
|
||||
}
|
||||
|
||||
__device__ int64_t print_impl(const char* name, int64_t value) {
|
||||
printf(
|
||||
"%s = %ld @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n",
|
||||
name,
|
||||
value,
|
||||
(int)threadIdx.x,
|
||||
(int)threadIdx.y,
|
||||
(int)threadIdx.z,
|
||||
(int)blockIdx.x,
|
||||
(int)blockIdx.y,
|
||||
(int)blockIdx.z);
|
||||
return value;
|
||||
}
|
||||
|
||||
__device__ bool print_impl(const char* name, bool value) {
|
||||
printf(
|
||||
"%s = %s @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n",
|
||||
name,
|
||||
value ? "true" : "false",
|
||||
(int)threadIdx.x,
|
||||
(int)threadIdx.y,
|
||||
(int)threadIdx.z,
|
||||
(int)blockIdx.x,
|
||||
(int)blockIdx.y,
|
||||
(int)blockIdx.z);
|
||||
return value;
|
||||
}
|
||||
|
||||
__device__ __half print_impl(const char* name, __half value) {
|
||||
printf(
|
||||
"%s = %f @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n",
|
||||
name,
|
||||
__half2float(value),
|
||||
(int)threadIdx.x,
|
||||
(int)threadIdx.y,
|
||||
(int)threadIdx.z,
|
||||
(int)blockIdx.x,
|
||||
(int)blockIdx.y,
|
||||
(int)blockIdx.z);
|
||||
return value;
|
||||
}
|
||||
|
||||
#if __CUDACC_VER_MAJOR__ >= 11
__device__ __bfloat print_impl(const char* name, __bfloat value) {
printf(
"%s = %f @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n",
name,
__bfloat2float(value),
(int)threadIdx.x,
(int)threadIdx.y,
(int)threadIdx.z,
(int)blockIdx.x,
(int)blockIdx.y,
(int)blockIdx.z);
return value;
}
#endif

#define print(...) print_impl(#__VA_ARGS__, (__VA_ARGS__))

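A hedged sketch of how the print helper added above might be used from device code; the kernel itself is illustrative and assumes the runtime file defining print_impl/print is in scope.

```cpp
// Wrap any intermediate value in print(...) to log it per thread together with
// its thread/block coordinates; the value is passed through unchanged.
__global__ void debug_kernel(const float* in, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // Prints e.g. "in[i] * 2.0f = 4.000000 @ threadIdx=(3,0,0), blockIdx=(1,0,0)"
    out[i] = print(in[i] * 2.0f);
  }
}
```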
@@ -157,6 +157,15 @@ DEVICE_INLINE void cpAsyncBarrier() {
asm volatile("cp.async.wait_all;");
}

DEVICE_INLINE void cpAsyncCommit() {
asm volatile("cp.async.commit_group;");
}

template <int keep_stages>
DEVICE_INLINE void cpAsyncPartialBarrier() {
asm volatile("cp.async.wait_group %0;\n" ::"n"(keep_stages));
}

} // namespace Ampere

#endif // Arch 80

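A hedged skeleton of the circular-buffered pipeline these two helpers enable (the commit/wait pattern the new matmul scheduler relies on); the copy bodies, buffers and stage count are placeholders, and only cpAsyncCommit / cpAsyncPartialBarrier come from this diff.

```cpp
// Illustrative main loop: issue cp.async copies for a stage, commit them as a
// group, and only wait until at most kStages - 2 groups remain in flight
// before consuming the oldest stage.
template <int kStages>
__device__ void pipelined_loop(int num_tiles) {
  // Prologue: prefetch the first kStages - 1 stages.
  for (int stage = 0; stage < kStages - 1; ++stage) {
    // ... issue cp.async copies for `stage` into its smem buffer ...
    Ampere::cpAsyncCommit();
  }
  for (int k = 0; k < num_tiles; ++k) {
    // ... issue cp.async copies for stage (k + kStages - 1) % kStages ...
    Ampere::cpAsyncCommit();
    // Block until the stage about to be consumed has landed in shared memory.
    Ampere::cpAsyncPartialBarrier<kStages - 2>();
    __syncthreads();
    // ... ldmatrix + mma on stage k % kStages ...
  }
}
```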
@ -276,6 +276,30 @@ DEVICE_INLINE void M16N8K16TN(
|
|||
_C[acc_stride + 1] = C_data[3];
|
||||
}
|
||||
|
||||
template <int acc_stride>
|
||||
DEVICE_INLINE void initM16N16K16TN(Array<float, 8, 8>* accumulator) {
|
||||
float* _C = reinterpret_cast<float*>(accumulator);
|
||||
initM16N8K16TN<acc_stride>(reinterpret_cast<Array<float, 4, 4>*>(&_C[0]));
|
||||
initM16N8K16TN<acc_stride>(reinterpret_cast<Array<float, 4, 4>*>(&_C[2]));
|
||||
}
|
||||
|
||||
template <int acc_stride = 2>
|
||||
DEVICE_INLINE void M16N16K16TN(
|
||||
Array<float, 8, 8>* C,
|
||||
Array<__half, 8, 8>* A,
|
||||
Array<__half, 8, 8>* B) {
|
||||
float* _C = reinterpret_cast<float*>(C);
|
||||
__half* _B = reinterpret_cast<__half*>(B);
|
||||
M16N8K16TN<acc_stride>(
|
||||
reinterpret_cast<Array<float, 4, 4>*>(&_C[0]),
|
||||
A,
|
||||
reinterpret_cast<Array<__half, 4, 4>*>(&_B[0]));
|
||||
M16N8K16TN<acc_stride>(
|
||||
reinterpret_cast<Array<float, 4, 4>*>(&_C[2]),
|
||||
A,
|
||||
reinterpret_cast<Array<__half, 4, 4>*>(&_B[4]));
|
||||
}
|
||||
|
||||
} // namespace Turing
|
||||
|
||||
#endif // Arch 75
|
||||
|
|
@ -338,6 +362,30 @@ DEVICE_INLINE void M16N8K16TN(
|
|||
_C[acc_stride + 1] = C_data[3];
|
||||
}
|
||||
|
||||
template <int acc_stride>
|
||||
DEVICE_INLINE void initM16N16K16TN(Array<float, 8, 8>* accumulator) {
|
||||
float* _C = reinterpret_cast<float*>(accumulator);
|
||||
initM16N8K16TN<acc_stride>(reinterpret_cast<Array<float, 4, 4>*>(&_C[0]));
|
||||
initM16N8K16TN<acc_stride>(reinterpret_cast<Array<float, 4, 4>*>(&_C[2]));
|
||||
}
|
||||
|
||||
template <int acc_stride = 2>
|
||||
DEVICE_INLINE void M16N16K16TN(
|
||||
Array<float, 8, 8>* C,
|
||||
Array<__half, 8, 8>* A,
|
||||
Array<__half, 8, 8>* B) {
|
||||
float* _C = reinterpret_cast<float*>(C);
|
||||
__half* _B = reinterpret_cast<__half*>(B);
|
||||
M16N8K16TN<acc_stride>(
|
||||
reinterpret_cast<Array<float, 4, 4>*>(&_C[0]),
|
||||
A,
|
||||
reinterpret_cast<Array<__half, 4, 4>*>(&_B[0]));
|
||||
M16N8K16TN<acc_stride>(
|
||||
reinterpret_cast<Array<float, 4, 4>*>(&_C[2]),
|
||||
A,
|
||||
reinterpret_cast<Array<__half, 4, 4>*>(&_B[4]));
|
||||
}
|
||||
|
||||
} // namespace Ampere
|
||||
|
||||
#endif // Arch 80
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
#include <torch/csrc/jit/codegen/cuda/fusion.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
|
||||
|
||||
namespace torch {
|
||||
|
|
@ -24,6 +25,8 @@ namespace HeuristicCompileTime {
|
|||
|
||||
//! Enum for all possible types of cached entries of compile-time info.
|
||||
enum class CompileTimeEntryType {
|
||||
DOMAIN_MAP,
|
||||
REFERENCE_TENSORS,
|
||||
VECTORIZABLE_INPUTS_AND_OUTPUTS,
|
||||
UNROLLABLE_INPUTS_AND_OUTPUTS,
|
||||
REDUCTION_TVS,
|
||||
|
|
@ -32,6 +35,24 @@ enum class CompileTimeEntryType {
|
|||
BROADCAST_BYTE_MULTIPLES
|
||||
};
|
||||
|
||||
//! Entry type definition class for `DOMAIN_MAP`,
|
||||
//! stores the domain map of a fusion.
|
||||
class DomainMap {
|
||||
public:
|
||||
using DataType = pointwise_utils::DomainMap;
|
||||
static const CompileTimeEntryType EntryType =
|
||||
CompileTimeEntryType::DOMAIN_MAP;
|
||||
};
|
||||
|
||||
//! Entry type definition class for `REFERENCE_TENSORS`,
|
||||
//! stores the reference TensorViews used to schedule a fusion.
|
||||
class ReferenceTensors {
|
||||
public:
|
||||
using DataType = std::vector<TensorView*>;
|
||||
static const CompileTimeEntryType EntryType =
|
||||
CompileTimeEntryType::REFERENCE_TENSORS;
|
||||
};
|
||||
|
||||
//! Entry type definition class for `VECTORIZABLE_INPUTS_AND_OUTPUTS`,
|
||||
//! stores the vectorizable TensorViews on a fusion's inputs and outputs.
|
||||
class VectorizableInputsAndOutputs {
|
||||
|
|
|
|||
torch/csrc/jit/codegen/cuda/scheduler/heuristic.h (new file, 37 lines)

@@ -0,0 +1,37 @@
#pragma once

#include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>

#include <string>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

class HeuristicParams {
public:
std::string tag = "";

LaunchParams lparams;

virtual std::string toString() const {
return "Undefined Heuristic Params";
}

virtual size_t hash() const = 0;

virtual ~HeuristicParams() = default;

virtual bool sameAs(const std::shared_ptr<HeuristicParams>& other) const = 0;

virtual std::shared_ptr<HeuristicParams> clone() const = 0;

HeuristicParams() = default;
HeuristicParams(const std::string& tag) : tag(tag) {}
};

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch

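For illustration, a hedged sketch of a concrete parameter class implementing the new interface; the class name and fields are hypothetical and not part of this PR.

```cpp
// Hypothetical heuristic parameter set built on HeuristicParams.
class ExampleParams : public HeuristicParams {
 public:
  int unroll_factor = 1;
  bool vectorize = false;

  using HeuristicParams::HeuristicParams;

  std::string toString() const override {
    return "ExampleParams{unroll=" + std::to_string(unroll_factor) +
        ", vectorize=" + (vectorize ? "true" : "false") + "}";
  }

  size_t hash() const override {
    return static_cast<size_t>(unroll_factor) * 2 + (vectorize ? 1 : 0);
  }

  bool sameAs(const std::shared_ptr<HeuristicParams>& other) const override {
    auto o = std::dynamic_pointer_cast<ExampleParams>(other);
    return o != nullptr && o->unroll_factor == unroll_factor &&
        o->vectorize == vectorize;
  }

  std::shared_ptr<HeuristicParams> clone() const override {
    return std::make_shared<ExampleParams>(*this);
  }
};
```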
torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp (new file, 329 lines)

@@ -0,0 +1,329 @@
#include <torch/csrc/jit/codegen/cuda/scheduler/matmul.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/scheduler/mma_utils.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace fuser {
|
||||
namespace cuda {
|
||||
|
||||
namespace {
|
||||
// Move the broadcast axes to the left on the specified number of inner
|
||||
// dimensions e.g. (when number_of_inner_pos == 3):
|
||||
// [... I0, B, I1] -> [... B, I0, I1]
|
||||
// should probably be only used to order innermost mnk axes.
|
||||
void moveInnerBroadcastLeft(TensorView* tv, int number_of_inner_pos = 3) {
|
||||
TORCH_INTERNAL_ASSERT(tv->nDims() >= number_of_inner_pos);
|
||||
std::vector<int> broadcast_pos;
|
||||
std::vector<int> nonbroadcast_pos;
|
||||
|
||||
for (auto i : c10::irange(number_of_inner_pos)) {
|
||||
auto axis_idx = i - number_of_inner_pos;
|
||||
auto id = tv->axis(axis_idx);
|
||||
if (id->isBroadcast()) {
|
||||
broadcast_pos.push_back(axis_idx);
|
||||
} else {
|
||||
nonbroadcast_pos.push_back(axis_idx);
|
||||
}
|
||||
}
|
||||
|
||||
auto combined_pos_vec = broadcast_pos;
|
||||
combined_pos_vec.insert(
|
||||
combined_pos_vec.end(), nonbroadcast_pos.begin(), nonbroadcast_pos.end());
|
||||
|
||||
std::unordered_map<int, int> order_map;
|
||||
for (auto i : c10::irange(number_of_inner_pos)) {
|
||||
order_map[combined_pos_vec.at(i)] = i - number_of_inner_pos;
|
||||
}
|
||||
|
||||
// Apply ordering.
|
||||
tv->reorder(order_map);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void scheduleMatmul(
|
||||
TensorView* c,
|
||||
TensorView* a,
|
||||
TensorView* b,
|
||||
MatmulParam& params) {
|
||||
// Unpack from params.
|
||||
auto& mma_builder = params.mma_builder;
|
||||
auto& gemm_tile = params.tile_sizes;
|
||||
|
||||
// Including current tensor naming convention for reference,
|
||||
// this is very temporary and will change over time and
|
||||
// in fact the whole body of this function will
|
||||
// eventually be a set of utility functions for different
|
||||
// sections of matmul(fusion) kernels, with
|
||||
// each having its own build out to do.
|
||||
//
|
||||
// Current naming convention:
|
||||
//
|
||||
// operands assumed in global memory : a, b
|
||||
//
|
||||
// registers staging global load : ar, br (short for a/b read)
|
||||
//
|
||||
// shared mem cache of operands : acw_smem, bcw_smem (short for a/b
|
||||
// cache_write smem)
|
||||
//
|
||||
// registers at shared memory load output : acr, bcr (short for a/b cache
|
||||
// read)
|
||||
//
|
||||
// register tensor input to the actual mma op: ab, bb (short for a/b
|
||||
// broadcasted)
|
||||
//
|
||||
// accumulator register: cc (short for c cache)
|
||||
//
|
||||
// result in global memory: c
|
||||
|
||||
// Currently only support a, b, c as fusion inputs/outputs
|
||||
// aka. no prolog and epilog fusion yet.
|
||||
TORCH_CHECK(
|
||||
c->isFusionOutput() && a->isFusionInput() && b->isFusionInput(),
|
||||
"not supporting matmul fusion yet");
|
||||
TORCH_CHECK(c->definition() && c->definition()->isA<MmaOp>());
|
||||
|
||||
mma_builder.configureMma(c);
|
||||
|
||||
// TODO:
|
||||
// Beyond this point, mma_builder really just becomes a populated
|
||||
// list of parameters to describe the mma swizzles that should
|
||||
// be annotated on the tensor domain. Conceptually the mma builder
|
||||
// object should be separated to 2 parts, one as scheduler utility
|
||||
// and the other as matmul heuristic parameters, which we are
|
||||
// starting to build out.
|
||||
|
||||
// Setup register and shared memory stages:
|
||||
// TODO: this section goes to a separate matmul util,
|
||||
// and needs more configurability.
|
||||
|
||||
// Setup accumulator register.
|
||||
auto cc = c->cacheBefore();
|
||||
|
||||
// Get the input to the mma op.
|
||||
auto mma = dynamic_cast<MmaOp*>(cc->definition());
|
||||
TORCH_INTERNAL_ASSERT(mma != nullptr);
|
||||
auto ab = mma->inA()->as<TensorView>();
|
||||
auto bb = mma->inB()->as<TensorView>();
|
||||
|
||||
// Get exact configurations from mma builder.
|
||||
mma_builder.accumulatorTv(cc);
|
||||
auto mma_options = mma_builder.build();
|
||||
|
||||
// Staging register for global memory load
|
||||
TensorView *ar = a, *br = b;
|
||||
|
||||
if (!params.async_gmem_load_operands) {
|
||||
ar = a->cacheAfter();
|
||||
br = b->cacheAfter();
|
||||
}
|
||||
|
||||
// TODO:
|
||||
// Significant build out needed here
|
||||
// for more flexibility and data type support.
|
||||
// Shared memory
|
||||
TensorView* acw_smem = nullptr;
|
||||
TensorView* bcw_smem = nullptr;
|
||||
// Shared memory read
|
||||
TensorView* acr = nullptr;
|
||||
TensorView* bcr = nullptr;
|
||||
|
||||
// Different paths because Volta swizzle needs to
|
||||
// involve the broadcast dimensions that are concretized
|
||||
// at mma, while Ampere ones should be done before
|
||||
// the broadcast op to be able to use cp.async.
|
||||
// TODO:
|
||||
// Also a few additional parameters should be introduced
|
||||
// to control this stage of scheduling.
|
||||
if (isVolta(mma_options.macro)) {
|
||||
acw_smem = ab->cacheAfter();
|
||||
bcw_smem = bb->cacheAfter();
|
||||
// Cache again to be able to vectorize.
|
||||
acw_smem = acw_smem->cacheAfter();
|
||||
bcw_smem = bcw_smem->cacheAfter();
|
||||
|
||||
acr = acw_smem->cacheAfter();
|
||||
bcr = bcw_smem->cacheAfter();
|
||||
if (params.double_buffer_options.double_buffer_smem_read) {
|
||||
// Provide another copy op between the double buffered
|
||||
// smem load register and the actual mma ops to avoid
|
||||
// complication in double buffered fragment iteration.
|
||||
ab = acr->cacheAfter();
|
||||
bb = bcr->cacheAfter();
|
||||
} else {
|
||||
ab = acr;
|
||||
bb = bcr;
|
||||
}
|
||||
|
||||
} else {
|
||||
// Use cp.async as requested in scheduler params.
|
||||
c10::optional<LoadStoreOpType> load_op = c10::nullopt;
|
||||
if (params.async_gmem_load_operands) {
|
||||
load_op = LoadStoreOpType::CpAsync;
|
||||
}
|
||||
|
||||
acw_smem = ar->cacheAfter(load_op);
|
||||
bcw_smem = br->cacheAfter(load_op);
|
||||
acr = acw_smem->cacheAfter(
|
||||
mma_builder.operand(MmaOptions::Operand::A).ldMatrix());
|
||||
bcr = bcw_smem->cacheAfter(
|
||||
mma_builder.operand(MmaOptions::Operand::B).ldMatrix());
|
||||
}
|
||||
|
||||
// Make a CTA tile
|
||||
// ------------------------------------------------------------------
|
||||
scheduler_utils::matmul_utils::canonicalizeMmaTvOrdering(cc);
|
||||
// [... M,N,K]
|
||||
scheduler_utils::matmul_utils::makeTile(cc, gemm_tile.cta_tile.toVector());
|
||||
|
||||
// [Mo, No, Ko, Mi, Ni, Ki]
|
||||
// Propagate tiling globally
|
||||
scheduler_utils::transformPropagateToAllFrom(cc, -1);
|
||||
|
||||
// Schedule warp tile
|
||||
scheduler_utils::matmul_utils::scheduleWarpTileWithReduction(cc, gemm_tile);
|
||||
|
||||
// Propagate warp tile to main loop and epilog/output tvs
|
||||
scheduler_utils::BoundedDirectionalTransformPropagator::bothWays(
|
||||
cc, -1, {acw_smem, bcw_smem}, {c});
|
||||
|
||||
// Schedule prolog:
|
||||
// TODO: this section goes to a separate matmul util,
|
||||
// and needs more configurability.
|
||||
// ------------------------------------------------------------------
|
||||
scheduler_utils::matmul_utils::orderTiledConcreteIdAsRoot(acw_smem);
|
||||
// [... M, K]
|
||||
acw_smem->merge(-2);
|
||||
scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
|
||||
acw_smem, gemm_tile, 8, false);
|
||||
|
||||
// [... N, K]
|
||||
scheduler_utils::matmul_utils::orderTiledConcreteIdAsRoot(bcw_smem);
|
||||
bcw_smem->merge(-2);
|
||||
scheduler_utils::matmul_utils::scheduleContiguousVectorLoad(
|
||||
bcw_smem, gemm_tile, 8, false);
|
||||
|
||||
// Propagate prolog tensors
|
||||
// propagate up the DAG, and propagate parallel type.
|
||||
scheduler_utils::BoundedDirectionalTransformPropagator::backward(
|
||||
acw_smem,
|
||||
-1,
|
||||
{a},
|
||||
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
|
||||
.propagateParallelType());
|
||||
scheduler_utils::BoundedDirectionalTransformPropagator::backward(
|
||||
bcw_smem,
|
||||
-1,
|
||||
{b},
|
||||
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
|
||||
.propagateParallelType());
|
||||
|
||||
// Set computeAt, setup the loop nesting structure on the kernel.
|
||||
// TODO: this section goes to a separate matmul util,
|
||||
// and needs more configurability.
|
||||
// ------------------------------------------------------------------
|
||||
// CTA tile:
|
||||
|
||||
// Swizzle block tiles:
|
||||
c->swizzle(Swizzle2DType::ZShape, 0, 1, SwizzleMode::Loop);
|
||||
|
||||
a->computeAt(c, 2);
|
||||
b->computeAt(c, 2);
|
||||
|
||||
// Prolog:
|
||||
a->computeAt(cc, 3);
|
||||
b->computeAt(cc, 3);
|
||||
|
||||
// Main Loop:
|
||||
acr->computeAt(cc, -6);
|
||||
bcr->computeAt(cc, -6);
|
||||
|
||||
// Add mma swizzle:
|
||||
// TODO: this section goes to a separate matmul util,
|
||||
// and needs more configurability.
|
||||
// ------------------------------------------------------------------
|
||||
if (isTuring(mma_options.macro) || isAmpere(mma_options.macro)) {
|
||||
moveInnerBroadcastLeft(ab);
|
||||
moveInnerBroadcastLeft(bb);
|
||||
}
|
||||
|
||||
ab->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::A).build());
|
||||
bb->applyMmaSwizzle(mma_builder.operand(MmaOptions::Operand::B).build());
|
||||
|
||||
// Propagate mma input swizzle up the DAG
|
||||
// to all the tensors before mma op and after shared mem read.
|
||||
scheduler_utils::BoundedDirectionalTransformPropagator::backward(
|
||||
ab,
|
||||
-1,
|
||||
{acw_smem},
|
||||
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
|
||||
.propagateParallelType());
|
||||
scheduler_utils::BoundedDirectionalTransformPropagator::backward(
|
||||
bb,
|
||||
-1,
|
||||
{bcw_smem},
|
||||
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
|
||||
.propagateParallelType());
|
||||
|
||||
cc->applyMmaSwizzle(
|
||||
mma_builder.operand(MmaOptions::Operand::Accumulator).build());
|
||||
|
||||
// Set memory type:
|
||||
acw_smem->setMemoryType(MemoryType::Shared);
|
||||
bcw_smem->setMemoryType(MemoryType::Shared);
|
||||
|
||||
// Set parallelization:
|
||||
// TODO: this section goes to a separate matmul util,
|
||||
// and needs more configurability.
|
||||
// ------------------------------------------------------------------
|
||||
|
||||
// Vectorize smem stores/loads:
|
||||
acw_smem->axis(-1)->parallelize(ParallelType::Vectorize);
|
||||
bcw_smem->axis(-1)->parallelize(ParallelType::Vectorize);
|
||||
|
||||
acr->axis(-1)->parallelize(ParallelType::Vectorize);
|
||||
bcr->axis(-1)->parallelize(ParallelType::Vectorize);
|
||||
|
||||
// 0 1 2 3 4 5 6 7 8 9 10
|
||||
// [Mo No Ko Mwo Nwo Kw Mw Nw (Mi Ni Ki)]
|
||||
cc->axis(0)->parallelize(ParallelType::BIDx);
|
||||
cc->axis(1)->parallelize(ParallelType::BIDy);
|
||||
cc->axis(3)->parallelize(ParallelType::TIDz);
|
||||
cc->axis(4)->parallelize(ParallelType::TIDy);
|
||||
|
||||
// Propagate mma output swizzle and parallelization down the DAG
|
||||
if (params.double_buffer_options.double_buffer_smem_write) {
|
||||
TORCH_CHECK(
|
||||
params.double_buffer_options.smem_double_buffer_stage > 1,
|
||||
"Invalid buffer stage config")
|
||||
if (params.double_buffer_options.smem_double_buffer_stage > 2) {
|
||||
TORCH_CHECK(
|
||||
params.async_gmem_load_operands,
|
||||
"Circular buffer only supports async load");
|
||||
}
|
||||
|
||||
acw_smem->circularBuffer(
|
||||
params.double_buffer_options.smem_double_buffer_stage);
|
||||
bcw_smem->circularBuffer(
|
||||
params.double_buffer_options.smem_double_buffer_stage);
|
||||
}
|
||||
|
||||
if (params.double_buffer_options.double_buffer_smem_read) {
|
||||
acr->doubleBuffer();
|
||||
bcr->doubleBuffer();
|
||||
}
|
||||
|
||||
scheduler_utils::BoundedDirectionalTransformPropagator::forward(
|
||||
cc,
|
||||
-1,
|
||||
{c},
|
||||
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
|
||||
.propagateParallelType()
|
||||
.propagateToBoundary());
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace fuser
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
torch/csrc/jit/codegen/cuda/scheduler/matmul.h (new file, 55 lines)

@@ -0,0 +1,55 @@
#pragma once

#include <ATen/core/ivalue.h>

#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/mma_type.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

//! Starting point for matmul scheduler parameters:
class MatmulParam {
public:
MatmulParam(MmaBuilder builder) : mma_builder(builder) {}

struct DoubleBufferOptions {
bool double_buffer_smem_write = false;
bool double_buffer_smem_read = false;
int smem_double_buffer_stage = 2;
};

//! (Ampere+) Use cp.async to load operands.
bool async_gmem_load_operands = false;

//! Specifies the tiling hierarchy on block,
//! warp, and instruction levels.
MatMulTileOptions tile_sizes;

//! Parameters for configuring mma ops.
MmaBuilder mma_builder;

//! Specify which tensor we double buffer.
DoubleBufferOptions double_buffer_options;
};

//! Prototype auto scheduling function.
//! Currently only supports a pure matmul with no
//! fused prolog or epilog.
//!
//! TODO:
//! - will support a range of fusions in a follow up
//! - will formalize scheduling decisions into
//! matmul params data structure.
TORCH_CUDA_CU_API void scheduleMatmul(
TensorView* c_tv,
TensorView* a_tv,
TensorView* b_tv,
MatmulParam& params);

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch

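A hedged sketch of how this prototype scheduler might be driven; MmaBuilder and GemmTile come from mma_type.h in this diff, while the warp/instruction tile field names of MatMulTileOptions and the concrete tile sizes are assumptions for illustration, and a/b/c stand for the fusion's operand and output TensorViews.

```cpp
// Illustrative only: Ampere 16x16x16 mma, cp.async operand loads, 3-stage
// circular-buffered smem writes and double-buffered smem reads.
MatMulTileOptions gemm_tile;
gemm_tile.cta_tile = GemmTile(128, 128, 32);        // cta_tile appears in this diff
gemm_tile.warp_tile = GemmTile(64, 64, 32);         // field name assumed
gemm_tile.instruction_tile = GemmTile(16, 16, 16);  // field name assumed

MmaBuilder mma_builder(MmaOptions::MacroType::Ampere_16_16_16, gemm_tile);

MatmulParam params(mma_builder);
params.tile_sizes = gemm_tile;
params.async_gmem_load_operands = true;
params.double_buffer_options.double_buffer_smem_write = true;
params.double_buffer_options.double_buffer_smem_read = true;
params.double_buffer_options.smem_double_buffer_stage = 3;

scheduleMatmul(c, a, b, params);
```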
@ -4,6 +4,7 @@
|
|||
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/scheduler/mma_utils.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
|
@ -133,6 +134,13 @@ void WarpMmaSwizzler::scheduleMmaWarpOutput(
|
|||
setWarpMapped(tv, 4);
|
||||
}
|
||||
break;
|
||||
case MmaOptions::MacroType::Turing_16_16_16:
|
||||
case MmaOptions::MacroType::Ampere_16_16_16:
|
||||
scheduleTuringM16N16K16MmaWarpOutput(tv, options);
|
||||
if (tv->definition()->isA<MmaOp>()) {
|
||||
setWarpMapped(tv, 4);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(
|
||||
false, "scheduleMmaWarp: unsupported mma option ", toString(macro));
|
||||
|
|
@ -150,6 +158,8 @@ void WarpMmaSwizzler::scheduleOperandRead(TensorView* tv, MmaOptions options) {
|
|||
break;
|
||||
case MmaOptions::MacroType::Turing_16_8_16:
|
||||
case MmaOptions::MacroType::Ampere_16_8_16:
|
||||
case MmaOptions::MacroType::Turing_16_16_16:
|
||||
case MmaOptions::MacroType::Ampere_16_16_16:
|
||||
scheduleTuringOperandRead(tv, options);
|
||||
break;
|
||||
default:
|
||||
|
|
@ -245,6 +255,14 @@ std::vector<IterDomain*> getMmaDomains(MmaOp* mma, MmaDimension dimension) {
|
|||
return result;
|
||||
}
|
||||
|
||||
//! Variant of getMmaDomains that returns a set
|
||||
std::unordered_set<IterDomain*> getMmaDomainSet(
|
||||
MmaOp* mma,
|
||||
MmaDimension dimension) {
|
||||
auto mma_domains = getMmaDomains(mma, dimension);
|
||||
return {mma_domains.begin(), mma_domains.end()};
|
||||
}
|
||||
|
||||
// [MMA dimension matching]
|
||||
// Returns all the axes that correspond to the given mma dimension. This is the
|
||||
// first relaxation step on the mma check.
|
||||
|
|
@ -325,9 +343,10 @@ void validateMmaRootInnerMNK(
|
|||
int m,
|
||||
int n,
|
||||
int k) {
|
||||
auto m_dims = getMmaRootDimensions(tv, options.mma_op, MmaDimension::M);
|
||||
auto n_dims = getMmaRootDimensions(tv, options.mma_op, MmaDimension::N);
|
||||
auto k_dims = getMmaRootDimensions(tv, options.mma_op, MmaDimension::K);
|
||||
auto mma = options.mmaOp();
|
||||
auto m_dims = getMmaRootDimensions(tv, mma, MmaDimension::M);
|
||||
auto n_dims = getMmaRootDimensions(tv, mma, MmaDimension::N);
|
||||
auto k_dims = getMmaRootDimensions(tv, mma, MmaDimension::K);
|
||||
|
||||
TORCH_CHECK(
|
||||
!m_dims.empty() && !n_dims.empty() && !k_dims.empty(),
|
||||
|
|
@ -354,8 +373,9 @@ void validateMmaRootInnerMNK(
|
|||
//! swizzles to the right axes.
|
||||
//! This check will be relaxed as we build out the mma usage patterns.
|
||||
void validateMmaRootInnerMN(TensorView* tv, MmaOptions options, int m, int n) {
|
||||
auto m_dims = getMmaRootDimensions(tv, options.mma_op, MmaDimension::M);
|
||||
auto n_dims = getMmaRootDimensions(tv, options.mma_op, MmaDimension::N);
|
||||
auto mma = options.mmaOp();
|
||||
auto m_dims = getMmaRootDimensions(tv, mma, MmaDimension::M);
|
||||
auto n_dims = getMmaRootDimensions(tv, mma, MmaDimension::N);
|
||||
|
||||
TORCH_CHECK(
|
||||
!m_dims.empty() && !n_dims.empty(),
|
||||
|
|
@ -494,14 +514,17 @@ void scheduleLdMatrix(TensorView* tv, MmaOptions options) {
|
|||
// Check mma option is supported
|
||||
TORCH_CHECK(
|
||||
options.macro == MmaOptions::MacroType::Ampere_16_8_16 ||
|
||||
options.macro == MmaOptions::MacroType::Turing_16_8_16,
|
||||
options.macro == MmaOptions::MacroType::Ampere_16_16_16 ||
|
||||
options.macro == MmaOptions::MacroType::Turing_16_8_16 ||
|
||||
options.macro == MmaOptions::MacroType::Turing_16_16_16,
|
||||
"scheduleLdMatrix: unknown macro for ldmatrix");
|
||||
|
||||
if (options.operand == MmaOptions::Operand::A) {
|
||||
TORCH_INTERNAL_ASSERT(tv->nDims() >= 2);
|
||||
// validation:
|
||||
auto m_dims = getMmaRootDimensions(tv, options.mma_op, MmaDimension::M);
|
||||
auto k_dims = getMmaRootDimensions(tv, options.mma_op, MmaDimension::K);
|
||||
auto mma = options.mmaOp();
|
||||
auto m_dims = getMmaRootDimensions(tv, mma, MmaDimension::M);
|
||||
auto k_dims = getMmaRootDimensions(tv, mma, MmaDimension::K);
|
||||
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
canValidateIsInnerDim(m_dims.back(), tv->axis(-2), 16),
|
||||
|
|
@ -532,46 +555,84 @@ void scheduleLdMatrix(TensorView* tv, MmaOptions options) {
|
|||
|
||||
tv->axis(-2)->parallelize(ParallelType::TIDx);
|
||||
} else if (options.operand == MmaOptions::Operand::B) {
|
||||
auto n_dims = getMmaRootDimensions(tv, options.mma_op, MmaDimension::N);
|
||||
auto k_dims = getMmaRootDimensions(tv, options.mma_op, MmaDimension::K);
|
||||
auto mma = options.mmaOp();
|
||||
auto n_dims = getMmaRootDimensions(tv, mma, MmaDimension::N);
|
||||
auto k_dims = getMmaRootDimensions(tv, mma, MmaDimension::K);
|
||||
|
||||
// validation:
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
canValidateIsInnerDim(n_dims.back(), tv->axis(-2), 8),
|
||||
"MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain");
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
canValidateIsInnerDim(k_dims.back(), tv->axis(-1), 16),
|
||||
"MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain");
|
||||
|
||||
if (transposed) {
|
||||
// [8, 16]
|
||||
tv->split(-2, 4);
|
||||
// Each ldmatrix 4 would be loading an effective 16x16x16 tile, which is 2x
|
||||
// the
|
||||
// size of regular 16x8x16 tile supported by largest mma operation. The
|
||||
// swizzle also needs to be different to take this into account.
|
||||
// TODO:
|
||||
// Using an emulated 16x16x16 mma tile is a temporary step to enable the
|
||||
// widest load possible for scheduler bring up phase.
|
||||
// A unifying step would be needed in a follow up to support all these
|
||||
// swizzles
|
||||
// with a single affine utility.
|
||||
bool use_ldmatrix4 = canValidateIsInnerDim(n_dims.back(), tv->axis(-2), 16);
|
||||
|
||||
// [2i, 4i, 16]
|
||||
tv->reorder({{-1, -2}, {-2, -1}});
|
||||
// [2i, 16, 4i]
|
||||
if (use_ldmatrix4) {
|
||||
// [... N16, K16]
|
||||
tv->split(-2, 8);
|
||||
tv->split(-1, 8);
|
||||
|
||||
tv->merge(-3);
|
||||
// [warp, 4i]
|
||||
} else {
|
||||
//[8, 16]
|
||||
tv->split(-1, 4);
|
||||
tv->split(-2, 2);
|
||||
// -4 -3 -2 -1
|
||||
// [... N2o, N8, K2o, K8]
|
||||
tv->reorder({{-3, -2}, {-2, -3}});
|
||||
// [... N2o, K2o, N8, K8]
|
||||
|
||||
// 0 1 2 3
|
||||
//[8, oo2,oi2,i4]
|
||||
tv->reorder({{-4, -2}, {-2, -4}});
|
||||
|
||||
// 0 1 2 3
|
||||
//[oi2, oo2, 8,i4]
|
||||
if (transposed) {
|
||||
tv->reorder({{-1, -2}, {-2, -1}});
|
||||
}
|
||||
|
||||
tv->merge(-4);
|
||||
tv->merge(-3);
|
||||
// 0 1
|
||||
//[warp, i4]
|
||||
}
|
||||
|
||||
tv->axis(-2)->parallelize(ParallelType::TIDx);
|
||||
// [Warp, K8]
|
||||
tv->axis(-2)->parallelize(ParallelType::TIDx);
|
||||
if (is_immediate_output) {
|
||||
tv->axis(-1)->parallelize(ParallelType::Vectorize);
|
||||
}
|
||||
} else {
|
||||
// validation:
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
canValidateIsInnerDim(n_dims.back(), tv->axis(-2), 8),
|
||||
"MMA swizzle: requires instruction tile iterdomains on the innermost side of the tensordomain");
|
||||
|
||||
if (transposed) {
|
||||
// [8, 16]
|
||||
tv->split(-2, 4);
|
||||
|
||||
// [2i, 4i, 16]
|
||||
tv->reorder({{-1, -2}, {-2, -1}});
|
||||
// [2i, 16, 4i]
|
||||
|
||||
tv->merge(-3);
|
||||
// [warp, 4i]
|
||||
} else {
|
||||
//[8, 16]
|
||||
tv->split(-1, 4);
|
||||
tv->split(-2, 2);
|
||||
|
||||
// 0 1 2 3
|
||||
//[8, oo2,oi2,i4]
|
||||
tv->reorder({{-4, -2}, {-2, -4}});
|
||||
|
||||
// 0 1 2 3
|
||||
//[oi2, oo2, 8,i4]
|
||||
|
||||
tv->merge(-4);
|
||||
tv->merge(-3);
|
||||
// 0 1
|
||||
//[warp, i4]
|
||||
}
|
||||
|
||||
tv->axis(-2)->parallelize(ParallelType::TIDx);
|
||||
}
|
||||
} else {
|
||||
TORCH_INTERNAL_ASSERT(false, "unreachable");
|
||||
}
|
||||
|
|
@ -704,6 +765,52 @@ void WarpMmaSwizzler::scheduleTuringM16N8K16MmaWarpOutput(
|
|||
tv->axis(m_pos)->parallelize(ParallelType::TIDx);
|
||||
}
|
||||
|
||||
void WarpMmaSwizzler::scheduleTuringM16N16K16MmaWarpOutput(
|
||||
TensorView* tv,
|
||||
const MmaOptions& options) {
|
||||
// Assume last 2 dims [M16, N8] or [M16, N8, R]
|
||||
// Locate instruction m
|
||||
bool is_reduction = tv->axis(-1)->isReduction();
|
||||
|
||||
// Make sure instruction tile size is correct.
|
||||
if (is_reduction) {
|
||||
validateMmaRootInnerMNK(tv, options, 16, 16, 16);
|
||||
} else {
|
||||
validateMmaRootInnerMN(tv, options, 16, 16);
|
||||
}
|
||||
|
||||
int m_pos = is_reduction ? -3 : -2;
|
||||
// m
|
||||
// [16, 16 (,R)]
|
||||
|
||||
tv->split(m_pos + 1, 8);
|
||||
// m
|
||||
// [16, n2, 8 (,R)]
|
||||
tv->reorder({{m_pos, m_pos - 1}, {m_pos - 1, m_pos}});
|
||||
|
||||
// m
|
||||
// [n2, 16, 8 (,R)]
|
||||
tv->split(m_pos, 8);
|
||||
tv->split(m_pos + 1, 2);
|
||||
|
||||
// m
|
||||
// [2o, 8o, 4i, 2i (,R)]
|
||||
tv->merge(m_pos - 1);
|
||||
|
||||
// m
|
||||
// [2o, Warp, 2i (,R)]
|
||||
TORCH_CHECK(tv->definition() != nullptr);
|
||||
|
||||
if (is_reduction && tv->definition()->isA<MmaOp>()) {
|
||||
// Set instruction loops for mma reduce
|
||||
for (int pos : c10::irange(5)) {
|
||||
tv->axis(-pos - 1)->parallelize(ParallelType::Mma);
|
||||
}
|
||||
}
|
||||
|
||||
tv->axis(m_pos)->parallelize(ParallelType::TIDx);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
bool isMmaInitLoop(const kir::Scope& loop_body) {
|
||||
|
|
@ -750,6 +857,76 @@ bool isMmaInitLoop(const kir::ForLoop* loop) {
|
|||
|
||||
} // namespace mma_util
|
||||
|
||||
void scheduler_utils::matmul_utils::canonicalizeMmaTvOrdering(TensorView* tv) {
|
||||
std::unordered_set<IterDomain*> root_id_set{
|
||||
tv->getMaybeRFactorDomain().begin(), tv->getMaybeRFactorDomain().end()};
|
||||
|
||||
auto mma = dynamic_cast<MmaOp*>(tv->definition());
|
||||
TORCH_CHECK(
|
||||
mma != nullptr, "canonicalizeMmaTvOrdering : only support mma op output");
|
||||
|
||||
auto m_id_set = mma_util::getMmaDomainSet(mma, mma_util::MmaDimension::M);
|
||||
auto n_id_set = mma_util::getMmaDomainSet(mma, mma_util::MmaDimension::N);
|
||||
auto k_id_set = mma_util::getMmaDomainSet(mma, mma_util::MmaDimension::K);
|
||||
|
||||
std::vector<int> batch_pos, prev_reduction_pos, m_pos, n_pos, k_pos;
|
||||
|
||||
auto ndims = tv->nDims();
|
||||
|
||||
for (auto idx : c10::irange(ndims)) {
|
||||
auto id = tv->axis(idx);
|
||||
TORCH_CHECK(root_id_set.count(id), id->toString(), " not a root id.");
|
||||
|
||||
// Categorize each original iterdomain position
|
||||
if (m_id_set.count(id)) {
|
||||
m_pos.push_back(idx);
|
||||
} else if (n_id_set.count(id)) {
|
||||
n_pos.push_back(idx);
|
||||
} else if (k_id_set.count(id)) {
|
||||
k_pos.push_back(idx);
|
||||
} else if (id->isReduction()) {
|
||||
prev_reduction_pos.push_back(idx);
|
||||
} else {
|
||||
batch_pos.push_back(idx);
|
||||
}
|
||||
}
|
||||
|
||||
// Collect all mma id's, other id's would be either
|
||||
// batch or incoming reduction.
|
||||
|
||||
// Ordering map from old position to new position
|
||||
// that we will build using the position vectors.
|
||||
std::unordered_map<int, int> order_map;
|
||||
|
||||
// Running position counter keeping track of the
|
||||
// current insert position in order_map.
|
||||
int current_pos = 0;
|
||||
|
||||
// Utility to insert the ordered pos sequences to
|
||||
// the ordering map.
|
||||
auto insert_to_order_map =
|
||||
[&order_map, ¤t_pos](const std::vector<int>& original_pos) {
|
||||
for (auto pos : original_pos) {
|
||||
order_map[pos] = current_pos++;
|
||||
}
|
||||
};
|
||||
|
||||
// Order the categories, while keeping the original
|
||||
// intra-category ordering.
|
||||
insert_to_order_map(batch_pos);
|
||||
insert_to_order_map(prev_reduction_pos);
|
||||
insert_to_order_map(m_pos);
|
||||
insert_to_order_map(n_pos);
|
||||
insert_to_order_map(k_pos);
|
||||
|
||||
// Validate that all of the root ids are covered by
|
||||
// the inserted categories.
|
||||
TORCH_INTERNAL_ASSERT(current_pos == ndims, "Id not completely categorized");
|
||||
|
||||
// Apply the new ordering
|
||||
tv->reorder(order_map);
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace fuser
|
||||
} // namespace jit
|
||||
|
|
|
|||
|
|
@@ -115,18 +115,32 @@ class TORCH_CUDA_CU_API WarpMmaSwizzler {
      MmaOptions options = MmaOptions());

 private:
  //! Swizzle implementations for Volta mma.
  //! Operand swizzle implementations for Volta mma.
  static void scheduleVoltaOperandRead(TensorView* tv, MmaOptions options);

  //! Accumulator swizzle implementations for Volta mma.
  static void scheduleVoltaM16N16K4Fp32Output(
      TensorView* tv,
      const MmaOptions& options);

  //! Swizzle implementations for Turing mma.
  //! Operand swizzle implementations for Turing and Ampere mma.
  static void scheduleTuringOperandRead(TensorView* tv, MmaOptions options);

  //! Accumulator swizzle implementation for Turing and Ampere mma.
  static void scheduleTuringM16N8K16MmaWarpOutput(
      TensorView* tv,
      const MmaOptions& options);

  //! Accumulator swizzle implementation for emulated 16x16x16 mma tile
  //! that enables using ldmatrix.x4.
  //! Note:
  //!  Keeping both this option and the ldmatrix.x2 variant above for
  //!  now for wider scheduler exploration space. Eventually both of
  //!  these can be unified with a single affine utility.
  static void scheduleTuringM16N16K16MmaWarpOutput(
      TensorView* tv,
      const MmaOptions& options);

  //! Utility to lock the transformed dimensions from further transforms.
  static void setWarpMapped(TensorView* tv, int number_of_dims);
};
@@ -33,7 +33,7 @@ int64_t roundUpPow2Or8(const int64_t x) {

// Copied from reduction scheduler, should generalize. Simply needed to take out
// grid reductions.
ReductionParams innerPersistentHeuristic(
std::shared_ptr<ReductionParams> innerPersistentHeuristic(
    const int64_t total_reduction_numel,
    const int64_t total_iteration_numel,
    const int64_t inner_most_dimension_numel,

@@ -86,10 +86,9 @@ ReductionParams innerPersistentHeuristic(
  const int64_t warp_size_based_on_l1 = std::min(
      ceilDiv(
          total_reduction_numel,
          std::max(
              l1_cache /
                  (n_tensor_inputs * max_input_dtype_size * active_threads),
              (int64_t)1)),
          scheduler_utils::safeDiv(
              l1_cache,
              n_tensor_inputs * max_input_dtype_size * active_threads)),
      (int64_t)16);

  // Take the smaller

@@ -105,7 +104,7 @@ ReductionParams innerPersistentHeuristic(
  // communication is slow so it shouldn't be done for every element in the
  // reduction.
  int64_t min_target_iterations =
      std::max((int64_t)32 / (int64_t)max_input_dtype_size, (int64_t)1);
      scheduler_utils::safeDiv(32, max_input_dtype_size);

  // Start trying to break parallelization up across threads,
  // unrolling/iterations, and blocks.

@@ -121,8 +120,8 @@ ReductionParams innerPersistentHeuristic(
  // If we have more than a wave of blocks, put parallelism into unrolling and
  // target iterations
  if (target_blocks > device_multiprocessor_count) {
    auto available_unroll = std::max(
        n_elems / (warp_size * device_multiprocessor_count), (int64_t)1);
    auto available_unroll = scheduler_utils::safeDiv(
        n_elems, warp_size * device_multiprocessor_count);

    // Spread across unrolling and iterations, want a balance of the two so flip
    // back and forth to alternate adding to them.

@@ -140,12 +139,10 @@ ReductionParams innerPersistentHeuristic(
      target_iterations *= 2;
    }

    available_unroll = std::max(
        n_elems /
            (warp_size * device_multiprocessor_count * target_unroll *
             target_iterations),
        (int64_t)1);

    available_unroll = scheduler_utils::safeDiv(
        n_elems,
        warp_size * device_multiprocessor_count * target_unroll *
            target_iterations);
    flip = !flip;
  }
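The repeated `std::max(x / y, (int64_t)1)` pattern being replaced above is exactly what `scheduler_utils::safeDiv` factors out: an integer division clamped below at 1, so later splits never see a zero factor. A minimal sketch of that helper and of how a heuristic might use it; the surrounding quantities here are illustrative, not the real scheduler state.

```
#include <algorithm>
#include <cstdint>
#include <iostream>

// Divide x by y, but never return less than 1 (mirrors scheduler_utils::safeDiv).
int64_t safeDiv(int64_t x, int64_t y) {
  return std::max(x / y, (int64_t)1);
}

int main() {
  // Hypothetical heuristic inputs.
  int64_t n_elems = 1000;
  int64_t warp_size = 32;
  int64_t sm_count = 108;

  // Without the clamp this would be 0 for small problems, and any later
  // split by available_unroll would produce empty tiles.
  int64_t available_unroll = safeDiv(n_elems, warp_size * sm_count);
  std::cout << "available_unroll = " << available_unroll << std::endl;  // 1
  return 0;
}
```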
@@ -171,9 +168,8 @@ ReductionParams innerPersistentHeuristic(

  // Compute maximum number of reductions we could do in the same kernel based
  // on persistent buffer size
  const int64_t max_multi_reduction_factor = std::max(
      scheduler_utils::register_file_size / max_persistent_buffer_size,
      (int64_t)1);
  const int64_t max_multi_reduction_factor = scheduler_utils::safeDiv(
      scheduler_utils::register_file_size, max_persistent_buffer_size);

  // To get to target threads:
  // Prioritize

@@ -226,18 +222,18 @@ ReductionParams innerPersistentHeuristic(

  // Put everything else in bdimy for now
  bdimy = std::min(
      std::max(warp_size / bdimx, (int64_t)1), max_multi_reduction_factor);
      scheduler_utils::safeDiv(warp_size, bdimx), max_multi_reduction_factor);

  // If 3D fill the rest of the threads into bdimz
  bdimz = std::min(
      std::min(
          std::max(max_threads_in_block / (bdimx * bdimy), (int64_t)1),
          scheduler_utils::safeDiv(max_threads_in_block, bdimx * bdimy),
          outer_reduction_numel),
      scheduler_utils::z_block_limit);

  // If 3D doesn't fill out the threads, adjust to add to bdimy
  bdimy = std::min(
      std::max(max_threads_in_block / (bdimx * bdimz), (int64_t)1),
      scheduler_utils::safeDiv(max_threads_in_block, bdimx * bdimz),
      max_multi_reduction_factor);

  // If we don't have a full warp and have an unroll factor, move unroll into

@@ -251,14 +247,14 @@ ReductionParams innerPersistentHeuristic(

    // Readjust bdimy and bdimz
    bdimy = std::min(
        std::max(warp_size / bdimx, (int64_t)1), max_multi_reduction_factor);
        scheduler_utils::safeDiv(warp_size, bdimx), max_multi_reduction_factor);

    bdimz = std::min(
        std::max(max_threads_in_block / (bdimx * bdimy), (int64_t)1),
        scheduler_utils::safeDiv(max_threads_in_block, bdimx * bdimy),
        outer_reduction_numel);

    bdimy = std::min(
        std::max(max_threads_in_block / (bdimx * bdimz), (int64_t)1),
        scheduler_utils::safeDiv(max_threads_in_block, bdimx * bdimz),
        max_multi_reduction_factor);
  }

@@ -313,14 +309,13 @@ ReductionParams innerPersistentHeuristic(
  // iteration domain
  if (inner_reduction_unroll_factor * outer_reduction_unroll_factor <
          max_unroll &&
      std::max(max_multi_reduction_factor / bdimy, (int64_t)1) > 2) {
      scheduler_utils::safeDiv(max_multi_reduction_factor, bdimy) > 2) {
    // Don't go over a combined inner/outer unroll of max_unroll
    auto unroll_available = std::min(
        std::max(
            max_unroll /
                (inner_reduction_unroll_factor * outer_reduction_unroll_factor),
            (int64_t)1),
        std::max(max_multi_reduction_factor / bdimy, (int64_t)1));
        scheduler_utils::safeDiv(
            max_unroll,
            inner_reduction_unroll_factor * outer_reduction_unroll_factor),
        scheduler_utils::safeDiv(max_multi_reduction_factor, bdimy));
    if (unroll_available > 1 && godim > 2 * device_multiprocessor_count) {
      unroll_available = std::min(
          unroll_available, ceilDiv(godim, 2 * device_multiprocessor_count));
@@ -414,51 +409,52 @@ ReductionParams innerPersistentHeuristic(
  int64_t gdimy = LaunchParams::UNINITIALIZED_VAL;
  int64_t gdimz = LaunchParams::UNINITIALIZED_VAL;

  ReductionParams rparams;
  auto rparams = std::make_shared<ReductionParams>();

  rparams.persistent_kernel = true;
  rparams.fastest_dim = true;
  rparams->persistent_kernel = true;
  rparams->fastest_dim = true;

  // Inner reduction domain
  rparams.cross_block_inner_reduction = true;
  rparams.block_dim_inner_reduction = ParallelType::TIDx;
  rparams.pad_inner_reduction_to_warp = pad_bdimx;
  rparams.batches_per_block_inner_reduction = batches_per_block_inner_reduction;
  rparams->cross_block_inner_reduction = true;
  rparams->block_dim_inner_reduction = ParallelType::TIDx;
  rparams->pad_inner_reduction_to_warp = pad_bdimx;
  rparams->batches_per_block_inner_reduction =
      batches_per_block_inner_reduction;

  // For persistent schedules always have to mark the reduction unrolled
  // otherwise rfactor can fail
  rparams.unroll_factor_inner_reduction = inner_reduction_unroll_factor;
  rparams.vectorize_inner_reduction = vectorize;
  rparams->unroll_factor_inner_reduction = inner_reduction_unroll_factor;
  rparams->vectorize_inner_reduction = vectorize;

  // Iter domain
  rparams.multiple_reds_per_blk = bdimy > 1;
  if (rparams.multiple_reds_per_blk) {
    rparams.block_dim_iter_dom = ParallelType::TIDy;
  rparams->multiple_reds_per_blk = bdimy > 1;
  if (rparams->multiple_reds_per_blk) {
    rparams->block_dim_iter_dom = ParallelType::TIDy;
  }

  if (godim > 1) {
    rparams.grid_dim_iter_dom = ParallelType::BIDx;
    rparams->grid_dim_iter_dom = ParallelType::BIDx;
    if (godim > scheduler_utils::x_grid_limit) {
      rparams.split_grid_dim_iter_dom = true;
      rparams->split_grid_dim_iter_dom = true;
      gdimx = scheduler_utils::x_grid_limit;
    }
  }

  if (iter_unroll_factor > 1) {
    rparams.unroll_factor_iter_dom = iter_unroll_factor;
    rparams->unroll_factor_iter_dom = iter_unroll_factor;
  }

  // Outer reduction domain
  rparams.schedule_3D = total_reduction_numel != inner_most_dimension_numel;
  if (rparams.schedule_3D) {
    rparams.batches_per_block_outer_reduction =
  rparams->schedule_3D = total_reduction_numel != inner_most_dimension_numel;
  if (rparams->schedule_3D) {
    rparams->batches_per_block_outer_reduction =
        batches_per_block_outer_reduction;
    rparams.block_dim_outer_reduction = ParallelType::TIDz;
    rparams.cross_block_outer_reduction = true;
    rparams.unroll_factor_outer_reduction = outer_reduction_unroll_factor;
    rparams->block_dim_outer_reduction = ParallelType::TIDz;
    rparams->cross_block_outer_reduction = true;
    rparams->unroll_factor_outer_reduction = outer_reduction_unroll_factor;
  }

  rparams.lparams = LaunchParams(
  rparams->lparams = LaunchParams(
      gdimx,
      gdimy,
      gdimz,

@@ -466,7 +462,7 @@ ReductionParams innerPersistentHeuristic(
      bdimy,
      LaunchParams::UNINITIALIZED_VAL);

  rparams.tag = "Inner Persistent Heuristic.\n";
  rparams->tag = "Inner Persistent Heuristic.\n";

  if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) {
    std::cerr << "\n===== Reduction Stats ========\n"

@@ -483,7 +479,7 @@ ReductionParams innerPersistentHeuristic(
              << "\n"
              << "block(" << (pad_bdimx ? padded_bdimx : bdimx) << ", " << bdimy
              << ", " << bdimz << ")";
    std::cerr << rparams.toString() << std::endl;
    std::cerr << rparams->toString() << std::endl;
  }

  return rparams;
@@ -492,7 +488,7 @@ ReductionParams innerPersistentHeuristic(
// Copied from reduction scheduler, should generalize. Simply needed to take out
// grid reductions.
// TODO: Check adding iteration domain unrolling
ReductionParams OuterPersistentHeuristic(
std::shared_ptr<ReductionParams> outerPersistentHeuristic(
    const int64_t total_reduction_numel,
    const int64_t total_iteration_numel,
    const int64_t n_tensor_inputs,

@@ -700,47 +696,47 @@ ReductionParams OuterPersistentHeuristic(

  gdimx = ceilDiv(total_iteration_numel, bdimx);

  ReductionParams rparams;
  rparams.batches_per_block_inner_reduction = batches_per_block;
  rparams.persistent_kernel = true;
  auto rparams = std::make_shared<ReductionParams>();
  rparams->batches_per_block_inner_reduction = batches_per_block;
  rparams->persistent_kernel = true;

  rparams.fastest_dim = false;
  rparams.cross_block_inner_reduction = true;
  rparams.cross_grid_inner_reduction = false;
  rparams.multiple_reds_per_blk = bdimx > 1;
  rparams->fastest_dim = false;
  rparams->cross_block_inner_reduction = true;
  rparams->cross_grid_inner_reduction = false;
  rparams->multiple_reds_per_blk = bdimx > 1;

  if (rparams.multiple_reds_per_blk) {
    rparams.block_dim_iter_dom = ParallelType::TIDx;
  if (rparams->multiple_reds_per_blk) {
    rparams->block_dim_iter_dom = ParallelType::TIDx;
  }

  rparams.grid_dim_iter_dom = ParallelType::BIDx;
  rparams.split_grid_dim_iter_dom = gdimx > scheduler_utils::x_grid_limit;
  rparams->grid_dim_iter_dom = ParallelType::BIDx;
  rparams->split_grid_dim_iter_dom = gdimx > scheduler_utils::x_grid_limit;

  if (rparams.block_dim_iter_dom == ParallelType::TIDx) {
    rparams.block_dim_inner_reduction = ParallelType::TIDy;
  if (rparams->block_dim_iter_dom == ParallelType::TIDx) {
    rparams->block_dim_inner_reduction = ParallelType::TIDy;
  } else {
    rparams.block_dim_inner_reduction = ParallelType::TIDx;
    rparams->block_dim_inner_reduction = ParallelType::TIDx;
  }

  // Always need to mark inner reduction unroll for rfactor in outer persitent
  // kernels
  rparams.unroll_factor_inner_reduction = inner_reduction_unroll_factor;
  rparams->unroll_factor_inner_reduction = inner_reduction_unroll_factor;

  rparams.unroll_factor_iter_dom = iter_unroll_factor;
  rparams->unroll_factor_iter_dom = iter_unroll_factor;

  if (iter_unroll_factor > 1) {
    rparams.vectorize_iter_dom = vectorize;
    rparams->vectorize_iter_dom = vectorize;
  }

  rparams.lparams = LaunchParams(
  rparams->lparams = LaunchParams(
      LaunchParams::UNINITIALIZED_VAL,
      LaunchParams::UNINITIALIZED_VAL,
      LaunchParams::UNINITIALIZED_VAL,
      rparams.multiple_reds_per_blk ? bdimx : bdimy,
      rparams->multiple_reds_per_blk ? bdimx : bdimy,
      LaunchParams::UNINITIALIZED_VAL,
      LaunchParams::UNINITIALIZED_VAL);

  rparams.tag = "Outer persistent kernel heuristic.\n";
  rparams->tag = "Outer persistent kernel heuristic.\n";

  if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) {
    std::cerr << "\n===== Reduction Stats ========\n"
@@ -754,7 +750,7 @@ ReductionParams OuterPersistentHeuristic(
              << "max_multi_reduction_factor: " << max_multi_reduction_factor
              << "\n"
              << "block(" << bdimx << ", " << bdimy << ", 1)" << std::endl;
    std::cerr << rparams.toString() << std::endl;
    std::cerr << rparams->toString() << std::endl;
  }

  return rparams;

@@ -762,7 +758,7 @@ ReductionParams OuterPersistentHeuristic(

} // namespace

ReductionParams PersistentHeuristic(
std::shared_ptr<ReductionParams> persistentHeuristic(
    const int64_t total_reduction_numel,
    const int64_t total_iteration_numel,
    const int64_t inner_most_dimension_numel,

@@ -772,7 +768,7 @@ ReductionParams PersistentHeuristic(
    const int64_t max_persistent_buffer_size,
    size_t vectorize_factor,
    bool project_persistent_buffers) {
  ReductionParams rparams;
  std::shared_ptr<ReductionParams> rparams;
  if (fastest_dim_reduction) {
    rparams = innerPersistentHeuristic(
        total_reduction_numel,

@@ -783,7 +779,7 @@ ReductionParams PersistentHeuristic(
        max_persistent_buffer_size,
        vectorize_factor);
  } else {
    rparams = OuterPersistentHeuristic(
    rparams = outerPersistentHeuristic(
        total_reduction_numel,
        total_iteration_numel,
        n_tensor_inputs,

@@ -791,11 +787,11 @@ ReductionParams PersistentHeuristic(
        max_persistent_buffer_size,
        vectorize_factor);
  }
  rparams.project_persistent_buffers = project_persistent_buffers;
  rparams->project_persistent_buffers = project_persistent_buffers;
  return rparams;
}

TORCH_CUDA_CU_API c10::optional<ReductionParams> getPersistentHeuristics(
TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getPersistentHeuristics(
    Fusion* fusion,
    SchedulerRuntimeInfo& runtime_info,
    HeuristicSummary* data_cache) {
@@ -827,8 +823,7 @@ TORCH_CUDA_CU_API c10::optional<ReductionParams> getPersistentHeuristics(

  TORCH_INTERNAL_ASSERT(
      red_expr->getExprType() != c10::nullopt &&
          (red_expr->getExprType().value() == ExprType::ReductionOp ||
           red_expr->getExprType().value() == ExprType::WelfordOp),
      ir_utils::isReductionOp(red_expr),
      "TensorView doesn't have a reduction.");

  auto tv_inps = ir_utils::filterByType<TensorView>(fusion->inputs());

@@ -939,7 +934,7 @@ TORCH_CUDA_CU_API c10::optional<ReductionParams> getPersistentHeuristics(
    n_tensor_inputs++;
  }

  return PersistentHeuristic(
  return persistentHeuristic(
      properties.total_reduction_numel,
      properties.total_iteration_numel,
      properties.inner_most_dimension_numel,

@@ -951,7 +946,7 @@ TORCH_CUDA_CU_API c10::optional<ReductionParams> getPersistentHeuristics(
      project_persistent_buffers);
}

TORCH_CUDA_CU_API c10::optional<ReductionParams> getPersistentHeuristics(
TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getPersistentHeuristics(
    Fusion* fusion,
    const at::ArrayRef<c10::IValue>& runtime_inputs,
    HeuristicSummary* data_cache) {

@@ -980,8 +975,9 @@ TORCH_CUDA_CU_API void schedulePersistentKernel(

  bool unroll = rparams.isUnrolled();

  // Cache inputs if unrolled
  auto cached_inputs = scheduler_utils::cacheInputs(fusion, unroll);
  // Cache inputs even if not unrolled, as otherwise we may not create a
  // persistent buffer if that persistent buffer would be the input.
  auto cached_inputs = scheduler_utils::cacheInputs(fusion, true);

  // Cache and fork outputs
  auto cached_outputs = scheduler_utils::cacheAndForkOutputs(fusion, unroll);
@@ -18,12 +18,12 @@ namespace cuda {
class SchedulerRuntimeInfo;
class HeuristicSummary;

TORCH_CUDA_CU_API c10::optional<ReductionParams> getPersistentHeuristics(
TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getPersistentHeuristics(
    Fusion* fusion,
    const at::ArrayRef<c10::IValue>& runtime_inputs,
    HeuristicSummary* data_cache = nullptr);

TORCH_CUDA_CU_API c10::optional<ReductionParams> getPersistentHeuristics(
TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getPersistentHeuristics(
    Fusion* fusion,
    SchedulerRuntimeInfo& runtime_info,
    HeuristicSummary* data_cache = nullptr);
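With the declarations above, the heuristic entry points now return `std::shared_ptr<ReductionParams>` instead of `c10::optional<ReductionParams>`, so callers test for null and share the params object rather than copying it. A minimal self-contained sketch of that calling pattern; `Params` and `getHeuristics` are stand-ins for illustration, not the nvfuser API.

```
#include <iostream>
#include <memory>
#include <string>

// Stand-in for ReductionParams; the real struct lives in the nvfuser scheduler.
struct Params {
  std::string tag;
};

// A heuristic that can fail now signals failure with a null shared_ptr
// instead of an empty c10::optional.
std::shared_ptr<Params> getHeuristics(bool can_schedule) {
  if (!can_schedule) {
    return nullptr;
  }
  auto p = std::make_shared<Params>();
  p->tag = "Inner Persistent Heuristic.\n";
  return p;
}

int main() {
  auto params = getHeuristics(/*can_schedule=*/true);
  if (params == nullptr) {
    std::cerr << "Could not schedule persistent kernel." << std::endl;
    return 1;
  }
  // The same shared_ptr can be cached and handed to the scheduler without
  // copying the params object.
  std::cout << params->tag;
  return 0;
}
```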
@@ -27,26 +27,28 @@ namespace {
// Unused at the moment, commenting for clang tidy
constexpr int64_t kThreadX = 128;

// DomainMap uses the ComputeAtMap to find a reference TensorView
// that maps to all iterDomains in the fusion.
class DomainMap {
class DomainMap : public pointwise_utils::DomainMap {
 public:
  DomainMap(Fusion* fusion) : fusion_(fusion), ca_map_(ComputeAtMap(fusion)) {
    view_tvs_ = scheduler_utils::getViewTVs(fusion);
  }
  using pointwise_utils::DomainMap::DomainMap;

  // The pointwise scheduler heuristics requires a minimum number of axes.
  // The output reference tensor should respect this requirement.
  TensorView* findReferenceTensorView(int minimum_num_axes = 0) const {
    TensorView* result = nullptr;
    int max_dims = -1;
    for (auto output_tv :
         ir_utils::filterByType<TensorView>(fusion_->outputs())) {
      if (isValidReference(output_tv) &&
          hasMinimumSize(output_tv, minimum_num_axes) &&
          !output_tv->isFusionInput()) {
        return output_tv;
        int n_dims = pointwise_utils::nRootDims(output_tv);
        if (n_dims > max_dims) {
          result = output_tv;
          max_dims = n_dims;
        }
      }
    }
    return nullptr;
    return result;
  }

  static bool hasReferenceTensorView(Fusion* fusion) {

@@ -71,7 +73,7 @@ class DomainMap {
      continue;
    }

    if (!areAllMapped(input_tv, output_tv)) {
    if (!areAllInputIdsMappedToOutput(input_tv, output_tv)) {
      return false;
    }
  }
@@ -83,92 +85,11 @@ class DomainMap {
    TORCH_INTERNAL_ASSERT(tv != nullptr);
    return (num_axes == 0 || tv->getMaybeRFactorDomain().size() > num_axes);
  }

  // Determine if all iterDomains are mapped between input and output tvs
  bool areAllMapped(TensorView* input_tv, TensorView* output_tv) const {
    // Get concrete IDs for input root or rfactor domain
    std::unordered_set<IterDomain*> in_concrete_ids;
    for (auto in_id : input_tv->getMaybeRFactorDomain()) {
      if (!ca_map_.getConcreteMappedID(in_id, IdMappingMode::EXACT)
               ->isBroadcast() &&
          !in_id->isReduction()) {
        in_concrete_ids.insert(
            ca_map_.getConcreteMappedID(in_id, IdMappingMode::EXACT));
      }
    }

    // Erase all input concrete IDs mapped to the output domain
    // Ignore unresolved broadcast dimensions
    for (auto out_id : output_tv->getMaybeRFactorDomain()) {
      if (!out_id->isBroadcast()) {
        if (!eraseIfMapped(in_concrete_ids, out_id)) {
          eraseIfMappedThroughView(in_concrete_ids, out_id);
        }
      }
    }
    return in_concrete_ids.empty();
  }

  // Erase input concrete ID if it is mapped to output ID
  bool eraseIfMapped(
      std::unordered_set<IterDomain*>& in_concrete_ids,
      IterDomain* out_id) const {
    auto out_concrete_id =
        ca_map_.getConcreteMappedID(out_id, IdMappingMode::EXACT);
    auto in_concrete_id_iter = in_concrete_ids.find(out_concrete_id);
    bool found_match = in_concrete_id_iter != in_concrete_ids.end();
    if (found_match) {
      in_concrete_ids.erase(in_concrete_id_iter);
    }
    return found_match;
  }

  // Check if in_id is mapped to out_id through any view rfactor domain
  void eraseIfMappedThroughView(
      std::unordered_set<IterDomain*>& in_concrete_ids,
      IterDomain* out_id) const {
    for (auto view : view_tvs_) {
      // Find any ID in view rfactor domain that is mapped to output ID
      auto view_rfactor_id = anyMapped(view->getRFactorDomain(), out_id);
      if (view_rfactor_id == nullptr) {
        continue;
      }

      if (view_rfactor_id->isRFactorProduct()) {
        // Check if input ID is mapped to any input IDs of the view rfactor ID
        auto root_inputs = InputsOf::outputs(fusion_, {view_rfactor_id});
        auto filtered_root_ids =
            ir_utils::filterByType<IterDomain>(root_inputs);
        for (auto view_root_id : filtered_root_ids) {
          eraseIfMapped(in_concrete_ids, view_root_id);
        }
      } else {
        // Otherwise, the input ID must map to the view rfactor ID
        eraseIfMapped(in_concrete_ids, view_rfactor_id);
      }
    }
  }

  // Find any id in domain that maps with target id
  IterDomain* anyMapped(
      const std::vector<IterDomain*> domain,
      IterDomain* target) const {
    for (auto id : domain) {
      if (ca_map_.areMapped(id, target, IdMappingMode::EXACT)) {
        return id;
      }
    }
    return nullptr;
  }

  Fusion* fusion_ = nullptr;
  ComputeAtMap ca_map_;
  std::vector<TensorView*> view_tvs_;
};

} // namespace

c10::optional<PointwiseParams> getPointwiseHeuristics(
std::shared_ptr<PointwiseParams> getPointwiseHeuristics(
    Fusion* fusion,
    const at::ArrayRef<c10::IValue>& runtime_inputs,
    HeuristicSummary* data_cache) {
@@ -176,7 +97,7 @@ c10::optional<PointwiseParams> getPointwiseHeuristics(
  return getPointwiseHeuristics(fusion, runtime_info, data_cache);
}

c10::optional<PointwiseParams> getPointwiseHeuristics(
std::shared_ptr<PointwiseParams> getPointwiseHeuristics(
    Fusion* fusion,
    SchedulerRuntimeInfo& runtime_info,
    HeuristicSummary* data_cache) {

@@ -187,35 +108,21 @@ c10::optional<PointwiseParams> getPointwiseHeuristics(
  // Incase any buffer is of type DataType::Index
  DataType index_type = indexModeToDtype(runtime_info.getIndexMode());

  TensorView* largest_out = nullptr;
  int max_dims = -1;

  auto in_tvs = ir_utils::filterByType<TensorView>(fusion->inputs());
  // Will want to access this with direct indexing later, convert now.
  std::vector<TensorView*> out_tvs;
  // Only use valid reference tensors during heuristics analysis
  DomainMap domain_map(fusion);
  for (auto out_tv : ir_utils::filterByType<TensorView>(fusion->outputs())) {
    if (domain_map.isValidReference(out_tv)) {
      out_tvs.push_back(out_tv);
    }
  }
  TORCH_INTERNAL_ASSERT(
      !out_tvs.empty(), "No valid reference outputs were found!");

  for (auto out_tv : out_tvs) {
    int n_dims = 0;
    for (auto id : out_tv->getMaybeRFactorDomain()) {
      if (id->isReduction() || id->isBroadcast()) {
        continue;
      }
      n_dims++;
    }
    if (n_dims > max_dims) {
      largest_out = out_tv;
      max_dims = n_dims;
    }
  }
  auto domain_map_entry =
      HeuristicSummaryEntry<HeuristicCompileTime::DomainMap>(
          data_cache,
          [fusion]() { return std::make_unique<DomainMap>(fusion); });
  const auto& domain_map = dynamic_cast<DomainMap&>(domain_map_entry.get());

  auto largest_out_entry =
      HeuristicSummaryEntry<HeuristicCompileTime::ReferenceTensors>(
          data_cache, [&domain_map]() {
            std::vector<TensorView*> data{domain_map.findReferenceTensorView()};
            return std::make_unique<std::vector<TensorView*>>(std::move(data));
          });
  TensorView* largest_out = largest_out_entry.get()[0];

  TORCH_INTERNAL_ASSERT(largest_out != nullptr);

@@ -224,15 +131,12 @@ c10::optional<PointwiseParams> getPointwiseHeuristics(

  // TODO: Set to 1?
  int64_t max_input_dtype_size = 2;
  size_t n_tensors = 0;

  for (auto inp : in_tvs) {
    max_input_dtype_size = std::max(
        max_input_dtype_size,
        (int64_t)dataTypeSize(inp->getDataType().value(), index_type));
    n_tensors++;
  }
  n_tensors += std::distance(out_tvs.begin(), out_tvs.end());

  auto ref_root = largest_out->getMaybeRFactorDomain();
  std::vector<int64_t> elem_counts(ref_root.size(), 1);

@@ -266,10 +170,7 @@ c10::optional<PointwiseParams> getPointwiseHeuristics(
              std::vector<scheduler_utils::BroadcastMultiple>>();
        });
    broadcast_byte_multiples_entry.get();

    PointwiseParams params;
    params.tag = "Pointwise heuristics";
    return params;
    return std::make_shared<PointwiseParams>("Pointwise heuristics");
  }

  // Find all vectorizable inputs/outputs
@@ -307,8 +208,7 @@ c10::optional<PointwiseParams> getPointwiseHeuristics(
    max_unroll_factor = 1;
  }

  PointwiseParams params;
  params.tag = "Pointwise heuristics";
  auto params = std::make_shared<PointwiseParams>("Pointwise heuristics");

  /*
   * 2D pointwise scheduling logic. What is expected is there's some

@@ -389,7 +289,7 @@ c10::optional<PointwiseParams> getPointwiseHeuristics(
  }

  // If there isn't very much parallelism available, just use 1D scheduler
  if (true || n_elems * 2 > device_multiprocessor_count * kThreadX) {
  if (n_elems * 2 > device_multiprocessor_count * kThreadX) {
    int64_t min_total_transfer = std::numeric_limits<int64_t>::max();

    for (const auto break_point_i : c10::irange(ref_root.size())) {

@@ -476,7 +376,7 @@ c10::optional<PointwiseParams> getPointwiseHeuristics(
  // Vectorizing innermost domains

  // Don't try to vectorize if it's not recommended
  params.unroll_factor = 1;
  params->unroll_factor = 1;

  // Compute maximum vectorize factor that can be used
  size_t vectorize_factor = max_unroll_factor;

@@ -506,34 +406,33 @@ c10::optional<PointwiseParams> getPointwiseHeuristics(
  }

  if (vectorize_factor == 1) {
    params.vectorize = false;
    params.unroll_factor = max_unroll_factor;
    params->vectorize = false;
    params->unroll_factor = max_unroll_factor;
  } else {
    params.vectorize = true;
    params.unroll_factor = vectorize_factor;
    params->vectorize = true;
    params->unroll_factor = vectorize_factor;
  }

  TORCH_INTERNAL_ASSERT(right_elem_count > 0 || break_point == 0);
  TORCH_INTERNAL_ASSERT(!(bdimy > 1 && gdim_right > 1));

  params.break_point = break_point;
  params.flip_grid_binding = flip_grid_binding;
  params.split_block = bdimy > 1;
  params->break_point = break_point;
  params->flip_grid_binding = flip_grid_binding;
  params->split_block = bdimy > 1;

  params.lparams.bind(bdimx, ParallelType::TIDx);
  if (params.split_block) {
    params.lparams.bind(bdimy, ParallelType::TIDy);
  params->lparams.bind(bdimx, ParallelType::TIDx);
  if (params->split_block) {
    params->lparams.bind(bdimy, ParallelType::TIDy);
  }
  if ((flip_grid_binding && gdim_right > 65535) ||
      (!flip_grid_binding && gdim_left > 65535)) {
    params.split_grid_y_dim = true;
    params->split_grid_y_dim = true;
  }

  if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) {
    std::cerr << "\n===== Pointwise Stats ========\n"
              << "num_elems: " << n_elems << "\n"
              << "elem_counts: " << elem_counts << "\n"
              << "n_tensor_inputs: " << n_tensors << "\n"
              << "max_input_dtype_size: " << max_input_dtype_size << "\n"
              << "vectorize_factor: " << vectorize_factor << std::endl;
    std::cerr << "broadcast_byte_multiples: ";

@@ -545,7 +444,7 @@ c10::optional<PointwiseParams> getPointwiseHeuristics(
              << (right_elem_count > 0 ? n_elems / right_elem_count : 0)
              << " RHS elems: " << right_elem_count << std::endl;
    std::cerr << std::endl;
    std::cerr << params.toString() << std::endl;
    std::cerr << params->toString() << std::endl;
  }

  return params;
@@ -558,25 +457,11 @@ LaunchParams schedulePointwise(
  FUSER_PERF_SCOPE("scheduleFusion");
  auto params = getPointwiseHeuristics(fusion, runtime_inputs);
  TORCH_INTERNAL_ASSERT(
      params.has_value(), "Could not schedule pointwise operation.");
  schedulePointwise(fusion, params.value());
  return params.value().lparams;
      params != nullptr, "Could not schedule pointwise operation.");
  schedulePointwise(fusion, *params);
  return params->lparams;
}

namespace {
// Returns number of non-reduction/non-broadcast dims in rfactor domain
size_t nRootDims(const TensorView* tv) {
  auto root_dom = tv->getMaybeRFactorDomain();
  size_t tv_n_dims = 0;
  for (auto dim : root_dom) {
    if (!dim->isReduction() && !dim->isBroadcast()) {
      tv_n_dims++;
    }
  }
  return tv_n_dims;
}
} // namespace

bool hasReferenceTensorView(Fusion* fusion) {
  return DomainMap::hasReferenceTensorView(fusion);
}
@@ -617,11 +502,11 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {

  size_t max_dims = 0;
  for (auto inp : input_tvs) {
    max_dims = std::max(nRootDims(inp), max_dims);
    max_dims = std::max(pointwise_utils::nRootDims(inp), max_dims);
  }

  for (auto out : output_tvs) {
    max_dims = std::max(nRootDims(out), max_dims);
    max_dims = std::max(pointwise_utils::nRootDims(out), max_dims);
  }

  // If everything is zero dim tensors, just return.

@@ -643,10 +528,6 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
  int rhs_i = -1;
  for (int i = (int)reference_tv->nDims(); i > (int)params.break_point; i--) {
    auto axis_i = i - 1;
    if (reference_tv->axis(axis_i)->isBroadcast() ||
        reference_tv->axis(axis_i)->isReduction()) {
      continue;
    }
    if (rhs_i == -1) {
      rhs_i = axis_i;
    } else {

@@ -663,10 +544,6 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
  int lhs_i = -1;
  for (int i = (int)params.break_point; i > 0; i--) {
    auto axis_i = i - 1;
    if (reference_tv->axis(axis_i)->isBroadcast() ||
        reference_tv->axis(axis_i)->isReduction()) {
      continue;
    }
    if (lhs_i == -1) {
      lhs_i = axis_i;
    } else {

@@ -676,6 +553,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
  }

  int64_t unswitch_pos;
  IterDomain* vectorize_id = nullptr;
  if (params.break_point) {
    // 2D parallelization scheme
    TORCH_INTERNAL_ASSERT(rhs_i >= 0 && lhs_i >= 0);
@@ -691,10 +569,8 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
      // [outer, Unswitch | i-remainder, TIDx, Vectorization]
      reference_tv->axis(1)->parallelize(ParallelType::Unswitch);
      reference_tv->axis(3)->parallelize(ParallelType::TIDx);

      // Aggressively mark with vectorized and cleanup later. That way we
      // don't have to manually specify parallelization outside the reference.
      reference_tv->axis(4)->parallelize(ParallelType::Vectorize);
      // Vectorization are propagated separately
      vectorize_id = reference_tv->axis(4);

      // [outer, Unswitch | i-remainder, TIDx, Vectorization]
      // To make consistent with unrolling:

@@ -709,6 +585,11 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
      reference_tv->reorder({{1, 2}});
      // [outer, i-remainder, unswitch, unroll, TIDx ]
      reference_tv->axis(2)->parallelize(ParallelType::Unswitch);
      // Here we do not set axis(3)->parallelize(Unroll) because we do not want
      // it to be propagated. We manually unroll by splitting the inline
      // propagation process into two steps:
      // step 1: inline at the unswitch position for cached inputs and outputs
      // step 2: inline at the inner most dim for the rest of the graph
      reference_tv->axis(4)->parallelize(ParallelType::TIDx);

      //[outer | i-remainder, Unswitch, Unroll, TIDx]

@@ -795,9 +676,8 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
      reference_tv->axis(0)->parallelize(ParallelType::BIDx);
      reference_tv->axis(1)->parallelize(ParallelType::TIDx);
      reference_tv->axis(2)->parallelize(ParallelType::Unswitch);
      // Aggressively mark with vectorized and cleanup later. That way we
      // don't have to manually specify parallelization outside the reference.
      reference_tv->axis(3)->parallelize(ParallelType::Vectorize);
      // Vectorization are propagated separately
      vectorize_id = reference_tv->axis(3);

      //[BIDx, TIDx, Unswitch, Vectorization]
      // To make consistent with unrolling:

@@ -814,6 +694,11 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
      // [BIDx, Unswitch, Unroll, TIDx]
      reference_tv->axis(0)->parallelize(ParallelType::BIDx);
      reference_tv->axis(1)->parallelize(ParallelType::Unswitch);
      // Here we do not set axis(2)->parallelize(Unroll) because we do not want
      // it to be propagated. We manually unroll by splitting the inline
      // propagation process into two steps:
      // step 1: inline at the unswitch position for cached inputs and outputs
      // step 2: inline at the inner most dim for the rest of the graph
      reference_tv->axis(3)->parallelize(ParallelType::TIDx);
    }
    unswitch_pos = 2;
@@ -822,43 +707,42 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
  TransformPropagator propagator(reference_tv);
  MaxRootDomainInfoSpanningTree spanning_tree(reference_tv);
  spanning_tree.traverse(&propagator);
  scheduler_utils::parallelizeAllLike(reference_tv, all_tvs);
  scheduler_utils::parallelizeAllLike(reference_tv);

  if (params.vectorize) {
    // Grab all tensor views that should be vectorized
    auto vectorized_tvs =
    auto inputs_outputs =
        scheduler_utils::getInputsOutputsWithInnerDim(reference_tv, true);
    // Going to move inputs to consumers of inputs, need a copy as we'll modify
    // the original.
    {
      auto vectorized_tvs_copy = vectorized_tvs;
      for (auto inp : vectorized_tvs_copy) {
        if (!inp->isFusionInput()) {
          continue;
        }
        vectorized_tvs.erase(
            std::find(vectorized_tvs.begin(), vectorized_tvs.end(), inp));
        auto consumer_tvs = ir_utils::consumerTvsOf(inp);
        vectorized_tvs.insert(
            vectorized_tvs.end(), consumer_tvs.begin(), consumer_tvs.end());
    std::vector<TensorView*> vectorized_tvs;
    bool should_vectorize_reference_tv = false;
    for (auto tv : inputs_outputs) {
      if (tv == reference_tv) {
        should_vectorize_reference_tv = true;
      }
      if (!tv->isFusionInput()) {
        vectorized_tvs.emplace_back(tv);
        continue;
      }
      // move inputs to consumers of inputs
      auto consumer_tvs = ir_utils::consumerTvsOf(tv);
      vectorized_tvs.insert(
          vectorized_tvs.end(), consumer_tvs.begin(), consumer_tvs.end());
    }
    // Clear vectorize on tensors that shouldn't have it
    for (auto tv : all_tvs) {
      if (std::find(vectorized_tvs.begin(), vectorized_tvs.end(), tv) ==
          vectorized_tvs.end()) {
        for (auto id : tv->domain()->domain()) {
          if (id->getParallelType() == ParallelType::Vectorize) {
            id->parallelize(ParallelType::Serial);
          }
        }
      }
    // Aggressively mark with vectorized and cleanup later. That way we
    // don't have to manually specify parallelization outside the reference.
    vectorize_id->parallelize(ParallelType::Vectorize);
    scheduler_utils::parallelizeAllLike(
        reference_tv, vectorized_tvs, {ParallelType::Vectorize});
    if (!should_vectorize_reference_tv) {
      vectorize_id->parallelize(ParallelType::Serial);
    }
  }

  // Begin by inlining at the unswitch position for the entire DAG. The cached
  // inputs, and outputs will keep this inline position, but other tensors will
  // get a higher position in later inline propagation.
  // get a higher position in later inline propagation. We need this separate
  // step because we were not using ParallelType::Unroll, so we have to do
  // unrolling manually.
  InlinePropagator inline_unswitch(
      reference_tv, unswitch_pos, ComputeAtMode::BestEffort);
  spanning_tree.traverse(&inline_unswitch);
@@ -877,10 +761,6 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
  InlinePropagator inline_inner_most(
      reference_tv, -1, ComputeAtMode::BestEffort, inner_most_tensors);
  spanning_tree.traverse(&inline_inner_most);

  // Fix max producer position
  MaxProducerPosUpdater updater;
  spanning_tree.traverse(&updater);
}

} // namespace cuda

@@ -13,12 +13,12 @@ namespace cuda {
class SchedulerRuntimeInfo;
class HeuristicSummary;

TORCH_CUDA_CU_API c10::optional<PointwiseParams> getPointwiseHeuristics(
TORCH_CUDA_CU_API std::shared_ptr<PointwiseParams> getPointwiseHeuristics(
    Fusion* fusion,
    const at::ArrayRef<c10::IValue>& runtime_inputs,
    HeuristicSummary* data_cache = nullptr);

TORCH_CUDA_CU_API c10::optional<PointwiseParams> getPointwiseHeuristics(
TORCH_CUDA_CU_API std::shared_ptr<PointwiseParams> getPointwiseHeuristics(
    Fusion* fusion,
    SchedulerRuntimeInfo& runtime_info,
    HeuristicSummary* data_cache = nullptr);
@@ -1,6 +1,6 @@
#pragma once

#include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/heuristic.h>

#include <sstream>

@@ -9,11 +9,11 @@ namespace jit {
namespace fuser {
namespace cuda {

// Parameters the Reduction Heuristic Generates to describe the optimial
// schedule. Warning: equal operator is intended for use in caching the kernel
// associated with these reduction parameters. It does not check if the launch
// parameters are equivelent!
class PointwiseParams {
// Parameters of the pointwise heuristic to describe the optimial schedule.
// Warning: equal operator is intended for use in caching the kernel associated
// with these pointwise parameters. It does not check if the launch parameters
// are equivelent!
class PointwiseParams : public HeuristicParams {
 public:
  // vectorize if true, otherwise unroll
  bool vectorize = false;
@@ -39,12 +39,16 @@ class PointwiseParams {
  // Unroll or vectorization factor
  size_t unroll_factor = 1;

  std::string tag = "";

  LaunchParams lparams;
  using HeuristicParams::HeuristicParams;

  // Warning: Does not check launch parameters!
  bool operator==(const PointwiseParams& other) const {
  bool sameAs(
      const std::shared_ptr<HeuristicParams>& other_base) const override {
    auto other_casted = std::dynamic_pointer_cast<PointwiseParams>(other_base);
    if (other_casted == nullptr) {
      return false;
    }
    const PointwiseParams& other = *other_casted;
    bool attr_equal = other.vectorize == vectorize &&
        other.break_point == break_point && other.split_block == split_block &&
        other.split_grid_y_dim == split_grid_y_dim &&

@@ -53,7 +57,7 @@ class PointwiseParams {
    return attr_equal;
  }

  std::string toString() const {
  std::string toString() const override {
    std::stringstream ss;
    ss << "\n===== Pointwise Parameters ========\n"
       << (tag == "" ? "" : "Tag: ") << tag << " Pointwise Characteristics:\n"

@@ -82,20 +86,21 @@ class PointwiseParams {
    ss << "====================================\n";
    return ss.str();
  }
};

// Warning: Hash is not based on launch parameters!
class PointwiseParamsHash {
 public:
  size_t operator()(const PointwiseParams& pp) const {
    size_t attr_hash = static_cast<size_t>(pp.vectorize) ^
        static_cast<size_t>(pp.break_point) << 4 ^
        static_cast<size_t>(pp.split_block) << 5 ^
        static_cast<size_t>(pp.split_grid_y_dim) << 6 ^
        static_cast<size_t>(pp.unroll_factor) << 9 ^
        static_cast<size_t>(pp.flip_grid_binding) << 10;
  // Warning: Hash is not based on launch parameters!
  size_t hash() const override {
    size_t attr_hash = static_cast<size_t>(vectorize) ^
        static_cast<size_t>(break_point) << 4 ^
        static_cast<size_t>(split_block) << 5 ^
        static_cast<size_t>(split_grid_y_dim) << 6 ^
        static_cast<size_t>(unroll_factor) << 9 ^
        static_cast<size_t>(flip_grid_binding) << 10;
    return attr_hash;
  }

  std::shared_ptr<HeuristicParams> clone() const override {
    return std::make_shared<PointwiseParams>(*this);
  }
};

} // namespace cuda
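The `PointwiseParams` changes above fold the old free-standing `PointwiseParamsHash` functor into virtual `sameAs`/`hash`/`clone` methods on a shared `HeuristicParams` base. A minimal self-contained sketch of that shape; `HeuristicParamsLike` and `PointwiseLikeParams` are illustrative names, not the actual nvfuser interface.

```
#include <cstddef>
#include <iostream>
#include <memory>

// Illustrative base class: virtual equality, hashing, and cloning so a kernel
// cache can store heterogeneous heuristic parameters behind one pointer type.
struct HeuristicParamsLike {
  virtual ~HeuristicParamsLike() = default;
  virtual bool sameAs(const std::shared_ptr<HeuristicParamsLike>& other) const = 0;
  virtual size_t hash() const = 0;
  virtual std::shared_ptr<HeuristicParamsLike> clone() const = 0;
};

struct PointwiseLikeParams : HeuristicParamsLike {
  bool vectorize = false;
  size_t break_point = 0;

  bool sameAs(const std::shared_ptr<HeuristicParamsLike>& other) const override {
    auto casted = std::dynamic_pointer_cast<PointwiseLikeParams>(other);
    // A different concrete type can never compare equal.
    return casted != nullptr && casted->vectorize == vectorize &&
        casted->break_point == break_point;
  }
  size_t hash() const override {
    return static_cast<size_t>(vectorize) ^ (break_point << 4);
  }
  std::shared_ptr<HeuristicParamsLike> clone() const override {
    return std::make_shared<PointwiseLikeParams>(*this);
  }
};

int main() {
  auto a = std::make_shared<PointwiseLikeParams>();
  a->vectorize = true;
  auto b = a->clone();
  std::cout << std::boolalpha << a->sameAs(b) << " " << (a->hash() == b->hash())
            << std::endl;  // prints: true true
  return 0;
}
```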
torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.cpp (new file, 101 lines)
@@ -0,0 +1,101 @@
#include <torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
namespace pointwise_utils {

DomainMap::DomainMap(Fusion* fusion)
    : fusion_(fusion), ca_map_(ComputeAtMap(fusion)) {
  view_tvs_ = scheduler_utils::getViewTVs(fusion);
}

bool DomainMap::areExactMapped(IterDomain* id1, IterDomain* id2) {
  return ca_map_.areMapped(id1, id2, IdMappingMode::EXACT);
}

// Determine if all IterDomains in input are mapped to output
bool DomainMap::areAllInputIdsMappedToOutput(
    TensorView* input_tv,
    TensorView* output_tv) const {
  // Get concrete IDs for input root or rfactor domain
  std::unordered_set<IterDomain*> in_concrete_ids;
  for (auto in_id : input_tv->getMaybeRFactorDomain()) {
    auto concrete = ca_map_.getConcreteMappedID(in_id, IdMappingMode::EXACT);
    if (!concrete->isBroadcast() && !in_id->isReduction()) {
      in_concrete_ids.insert(concrete);
    }
  }

  // Erase all input concrete IDs mapped to the output domain
  // Ignore unresolved broadcast dimensions
  for (auto out_id : output_tv->getMaybeRFactorDomain()) {
    if (!out_id->isBroadcast()) {
      if (!eraseIfMapped(in_concrete_ids, out_id)) {
        eraseIfInputMappedThroughViewToOutput(in_concrete_ids, out_id);
      }
    }
  }
  return in_concrete_ids.empty();
}

// Erase input concrete ID if it is mapped to output ID
bool DomainMap::eraseIfMapped(
    std::unordered_set<IterDomain*>& in_concrete_ids,
    IterDomain* out_id) const {
  auto out_concrete_id =
      ca_map_.getConcreteMappedID(out_id, IdMappingMode::EXACT);
  auto in_concrete_id_iter = in_concrete_ids.find(out_concrete_id);
  bool found_match = in_concrete_id_iter != in_concrete_ids.end();
  if (found_match) {
    in_concrete_ids.erase(in_concrete_id_iter);
  }
  return found_match;
}

// Check if in_id is mapped to out_id through any view rfactor domain.
// Currently this function only allow having one view on the path from input to
// output. If there are multiple views, then likely the pointwise scheduler will
// reject the fusion because we can not correctly find a reference tensor.
void DomainMap::eraseIfInputMappedThroughViewToOutput(
    std::unordered_set<IterDomain*>& in_concrete_ids,
    IterDomain* out_id) const {
  for (auto view : view_tvs_) {
    // Find any ID in view rfactor domain that is mapped to output ID
    auto view_rfactor_id = anyMapped(view->getRFactorDomain(), out_id);
    if (view_rfactor_id == nullptr) {
      continue;
    }

    if (view_rfactor_id->isRFactorProduct()) {
      // Check if input ID is mapped to any input IDs of the view rfactor ID
      auto root_inputs = InputsOf::outputs(fusion_, {view_rfactor_id});
      auto filtered_root_ids = ir_utils::filterByType<IterDomain>(root_inputs);
      for (auto view_root_id : filtered_root_ids) {
        eraseIfMapped(in_concrete_ids, view_root_id);
      }
    } else {
      // Otherwise, the input ID must map to the view rfactor ID
      eraseIfMapped(in_concrete_ids, view_rfactor_id);
    }
  }
}

// Find any id in domain that maps with target id
IterDomain* DomainMap::anyMapped(
    const std::vector<IterDomain*>& domain,
    IterDomain* target) const {
  for (auto id : domain) {
    if (ca_map_.areMapped(id, target, IdMappingMode::EXACT)) {
      return id;
    }
  }
  return nullptr;
}

} // namespace pointwise_utils
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch
torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.h (new file, 68 lines)
@@ -0,0 +1,68 @@
#pragma once

#include <torch/csrc/jit/codegen/cuda/compute_at_map.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
namespace pointwise_utils {

// DomainMap uses the ComputeAtMap to find a reference TensorView
// that maps to all IterDomains in the fusion.
class DomainMap {
 public:
  DomainMap(Fusion* fusion);
  virtual ~DomainMap() = default;

  bool areExactMapped(IterDomain* id1, IterDomain* id2);

  const ComputeAtMap& getComputeAtMap() const {
    return ca_map_;
  }

 protected:
  // Determine if all iterDomains are mapped between input and output tvs
  bool areAllInputIdsMappedToOutput(TensorView* input_tv, TensorView* output_tv)
      const;

  // Erase input concrete ID if it is mapped to output ID
  bool eraseIfMapped(
      std::unordered_set<IterDomain*>& in_concrete_ids,
      IterDomain* out_id) const;

  // Check if in_id is mapped to out_id through any view rfactor domain
  void eraseIfInputMappedThroughViewToOutput(
      std::unordered_set<IterDomain*>& in_concrete_ids,
      IterDomain* out_id) const;

  // Find any id in domain that maps with target id
  IterDomain* anyMapped(
      const std::vector<IterDomain*>& domain,
      IterDomain* target) const;

  Fusion* fusion_ = nullptr;
  ComputeAtMap ca_map_;
  std::vector<TensorView*> view_tvs_;
};

// Returns number of non-reduction/non-broadcast dims in rfactor domain
inline size_t nRootDims(const TensorView* tv) {
  auto root_dom = tv->getMaybeRFactorDomain();
  size_t tv_n_dims = 0;
  for (auto dim : root_dom) {
    if (!dim->isReduction() && !dim->isBroadcast()) {
      tv_n_dims++;
    }
  }
  return tv_n_dims;
}

} // namespace pointwise_utils
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch
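The `DomainMap::areAllInputIdsMappedToOutput` check above boils down to set subtraction: collect the input's non-broadcast, non-reduction IDs, erase every one that some output ID maps to, and require the set to end up empty. A toy version of that idea over plain integers; real nvfuser IDs go through the `ComputeAtMap`, so the integer IDs here are just stand-ins.

```
#include <iostream>
#include <unordered_set>
#include <vector>

// Toy check: every input id must be "covered" by some output id.
bool allInputsCoveredByOutputs(
    const std::vector<int>& input_ids,
    const std::vector<int>& output_ids) {
  std::unordered_set<int> remaining(input_ids.begin(), input_ids.end());
  for (int out_id : output_ids) {
    remaining.erase(out_id);  // analogous to eraseIfMapped
  }
  // Empty means every input id was mapped to the output somewhere.
  return remaining.empty();
}

int main() {
  std::cout << std::boolalpha
            << allInputsCoveredByOutputs({0, 1, 2}, {2, 1, 0, 3}) << " "  // true
            << allInputsCoveredByOutputs({0, 1, 4}, {0, 1, 2})            // false
            << std::endl;
  return 0;
}
```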
@@ -40,11 +40,6 @@ int64_t roundDownPow2OrMultipleOf(const int64_t x, const int64_t multiple) {
  return std::max(std::max(round_down_multiple, round_down_pow2), (int64_t)1);
}

// Div x by y, but min at 1
int64_t safeDiv(const int64_t x, const int64_t y) {
  return std::max(x / y, (int64_t)1);
}

int64_t clamp(const int64_t val, const int64_t min_val, const int64_t max_val) {
  return std::min(std::max(val, min_val), max_val);
}

@@ -54,20 +49,20 @@ int64_t clamp(const int64_t val, const int64_t min_val, const int64_t max_val) {
void reduceProductTo(int64_t& z, int64_t& y, int64_t& x, const int64_t max) {
  TORCH_INTERNAL_ASSERT(max > 1);
  if (z * y * x > max) {
    z = safeDiv(z, 2);
    z = scheduler_utils::safeDiv(z, 2);
  }
  if (z * y * x > max) {
    y = safeDiv(y, 2);
    y = scheduler_utils::safeDiv(y, 2);
  }
  if (z * y * x > max) {
    x = safeDiv(x, 2);
    x = scheduler_utils::safeDiv(x, 2);
  }
  if (z * y * x > max) {
    reduceProductTo(x, y, z, max);
  }
}

ReductionParams innerReductionHeuristic(
std::shared_ptr<ReductionParams> innerReductionHeuristic(
    const int64_t total_reduction_numel,
    const int64_t total_iteration_numel,
    const int64_t inner_most_dimension_numel,
@@ -382,21 +377,21 @@ ReductionParams innerReductionHeuristic(
    // require iterating over this entire function.
  }

  ReductionParams rparams;
  rparams.fastest_dim = true;
  rparams.cross_block_inner_reduction = true;
  rparams.block_dim_inner_reduction = ParallelType::TIDx;
  rparams.cross_grid_inner_reduction = gridim > 1;
  rparams.multiple_reds_per_blk = bdimy > 1;
  auto rparams = std::make_shared<ReductionParams>();
  rparams->fastest_dim = true;
  rparams->cross_block_inner_reduction = true;
  rparams->block_dim_inner_reduction = ParallelType::TIDx;
  rparams->cross_grid_inner_reduction = gridim > 1;
  rparams->multiple_reds_per_blk = bdimy > 1;
  bool pad_bdimx = bdimx > 16 &&
      bdimx * bdimy <
          (int64_t)at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
  // If barely just covering reduction dim, don't pad to the next warp
  pad_bdimx = pad_bdimx &&
      bdimx * inner_reduction_unroll_factor != inner_most_dimension_numel;
  rparams.pad_inner_reduction_to_warp = pad_bdimx;
  rparams->pad_inner_reduction_to_warp = pad_bdimx;

  if (rparams.pad_inner_reduction_to_warp) {
  if (rparams->pad_inner_reduction_to_warp) {
    // Adjust bdimx based on padding
    auto min_warp_size =
        (int64_t)at::cuda::getCurrentDeviceProperties()->warpSize;

@@ -405,24 +400,24 @@ ReductionParams innerReductionHeuristic(
        : bdimx + min_warp_size - bdimx % min_warp_size;
  }

  rparams.unroll_factor_inner_reduction = inner_reduction_unroll_factor;
  rparams.vectorize_inner_reduction = vectorize;
  rparams->unroll_factor_inner_reduction = inner_reduction_unroll_factor;
  rparams->vectorize_inner_reduction = vectorize;

  if (rparams.multiple_reds_per_blk) {
    rparams.block_dim_iter_dom = ParallelType::TIDy;
  if (rparams->multiple_reds_per_blk) {
    rparams->block_dim_iter_dom = ParallelType::TIDy;
  }

  rparams.unroll_factor_iter_dom = iter_unroll_factor;
  rparams->unroll_factor_iter_dom = iter_unroll_factor;

  rparams.schedule_3D = total_reduction_numel != inner_most_dimension_numel;
  rparams->schedule_3D = total_reduction_numel != inner_most_dimension_numel;
  // Outer reduction domain
  if (rparams.schedule_3D) {
    rparams.cross_grid_outer_reduction = grodim > 1;
  if (rparams->schedule_3D) {
    rparams->cross_grid_outer_reduction = grodim > 1;
    if (bdimz > 1) {
      rparams.block_dim_outer_reduction = ParallelType::TIDz;
      rparams.cross_block_outer_reduction = true;
      rparams->block_dim_outer_reduction = ParallelType::TIDz;
      rparams->cross_block_outer_reduction = true;
    }
    rparams.unroll_factor_outer_reduction = outer_reduction_unroll_factor;
    rparams->unroll_factor_outer_reduction = outer_reduction_unroll_factor;
  }

  int64_t gdimx = LaunchParams::UNINITIALIZED_VAL;
@@ -433,38 +428,38 @@ ReductionParams innerReductionHeuristic(
  // gdimx assigned to grdim. Otherwise it's helpful to pull godim into gdimx in
  // case it's larger than gdimy can hold, as not doing so can thrash the cache.

  if (rparams.cross_grid_inner_reduction) {
    rparams.grid_dim_inner_reduction = ParallelType::BIDx;
    rparams.split_grid_dim_inner_reduction = true;
  if (rparams->cross_grid_inner_reduction) {
    rparams->grid_dim_inner_reduction = ParallelType::BIDx;
    rparams->split_grid_dim_inner_reduction = true;
    gdimx = std::min(gridim, scheduler_utils::x_grid_limit);

    rparams.grid_dim_iter_dom = ParallelType::BIDy;
    rparams->grid_dim_iter_dom = ParallelType::BIDy;
    if (godim > scheduler_utils::y_grid_limit) {
      rparams.split_grid_dim_iter_dom = true;
      rparams->split_grid_dim_iter_dom = true;
      gdimy = std::min(godim, scheduler_utils::y_grid_limit);
    }

  } else {
    rparams.grid_dim_iter_dom = ParallelType::BIDx;
    rparams->grid_dim_iter_dom = ParallelType::BIDx;
    if (gdimx > scheduler_utils::x_grid_limit) {
      rparams.split_grid_dim_iter_dom = true;
      rparams->split_grid_dim_iter_dom = true;
      gdimx = godim;
    }
  }

  if (rparams.cross_grid_outer_reduction) {
    if (rparams.cross_block_inner_reduction) {
      rparams.grid_dim_outer_reduction = ParallelType::BIDz;
  if (rparams->cross_grid_outer_reduction) {
    if (rparams->cross_block_inner_reduction) {
      rparams->grid_dim_outer_reduction = ParallelType::BIDz;
      gdimz = std::min(grodim, scheduler_utils::z_grid_limit);
      rparams.split_grid_dim_outer_reduction = true;
      rparams->split_grid_dim_outer_reduction = true;
    } else {
      rparams.grid_dim_outer_reduction = ParallelType::BIDy;
      rparams->grid_dim_outer_reduction = ParallelType::BIDy;
      gdimy = std::min(grodim, scheduler_utils::y_grid_limit);
      rparams.split_grid_dim_outer_reduction = true;
      rparams->split_grid_dim_outer_reduction = true;
    }
  }

  rparams.lparams = LaunchParams(
  rparams->lparams = LaunchParams(
      gdimx,
      gdimy,
      gdimz,
@@ -483,20 +478,20 @@ ReductionParams innerReductionHeuristic(
|
|||
<< "max_input_dtype_size: " << max_input_dtype_size << "\n"
|
||||
<< "block(" << bdimx << ", " << bdimy << ", " << bdimz << ")"
|
||||
<< std::endl;
|
||||
std::cerr << rparams.toString() << std::endl;
|
||||
std::cerr << rparams->toString() << std::endl;
|
||||
}
|
||||
|
||||
// If 3d, check if it's supported by the scheduler, otherwise force 1D
|
||||
// schedule
|
||||
if (rparams.schedule_3D) {
|
||||
if (rparams.multiple_reds_per_blk &&
|
||||
(rparams.cross_grid_inner_reduction ||
|
||||
rparams.cross_grid_outer_reduction)) {
|
||||
if (rparams->schedule_3D) {
|
||||
if (rparams->multiple_reds_per_blk &&
|
||||
(rparams->cross_grid_inner_reduction ||
|
||||
rparams->cross_grid_outer_reduction)) {
|
||||
if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) {
|
||||
std::cerr << "\n===== UNSUPPORTED REDUCTION HEURISTIC ========\n";
|
||||
std::cerr << rparams.multiple_reds_per_blk << ", "
|
||||
<< (rparams.unroll_factor_inner_reduction > 1) << ", "
|
||||
<< rparams.cross_grid_inner_reduction << std::endl;
|
||||
std::cerr << rparams->multiple_reds_per_blk << ", "
|
||||
<< (rparams->unroll_factor_inner_reduction > 1) << ", "
|
||||
<< rparams->cross_grid_inner_reduction << std::endl;
|
||||
}
|
||||
return innerReductionHeuristic(
|
||||
total_reduction_numel,
|
||||
|
|
@@ -511,7 +506,7 @@ ReductionParams innerReductionHeuristic(
|
|||
return rparams;
|
||||
}
|
||||
|
||||
ReductionParams OuterReductionHeuristic(
|
||||
std::shared_ptr<ReductionParams> outerReductionHeuristic(
|
||||
const int64_t total_reduction_numel,
|
||||
const int64_t total_iteration_numel,
|
||||
const int64_t n_tensor_inputs,
|
||||
|
|
@@ -684,16 +679,17 @@ ReductionParams OuterReductionHeuristic(
|
|||
bdimx = roundUpPow2OrMultipleOf(bdimx, 8);
|
||||
|
||||
// Fill bdimy with left over threads
|
||||
bdimy =
|
||||
std::min(safeDiv(target_threads_in_block, bdimx), total_reduction_numel);
|
||||
bdimy = std::min(
|
||||
scheduler_utils::safeDiv(target_threads_in_block, bdimx),
|
||||
total_reduction_numel);
|
||||
|
||||
bdimy = roundDownPow2OrMultipleOf(bdimy, 8);
|
||||
|
||||
// Move parallelization into unrolling the reduction dimension if
|
||||
// parallelizing iteration dimension didn't take the available unroll factor.
|
||||
if (iter_unroll_factor < max_unroll && rDimAvail() > 2) {
|
||||
inner_reduction_unroll_factor =
|
||||
std::min(rDimAvail(), safeDiv(max_unroll, iter_unroll_factor));
|
||||
inner_reduction_unroll_factor = std::min(
|
||||
rDimAvail(), scheduler_utils::safeDiv(max_unroll, iter_unroll_factor));
|
||||
|
||||
inner_reduction_unroll_factor =
|
||||
scheduler_utils::lastPow2(inner_reduction_unroll_factor);
|
||||
|
|
@@ -731,7 +727,8 @@ ReductionParams OuterReductionHeuristic(
|
|||
// Empirically found stride shouldn't exceed 256kiB boundaries in a block
|
||||
int64_t kMaxStride = 128 * 1024;
|
||||
|
||||
int64_t max_remainder_size = safeDiv(kMaxStride, bytes_stride_remainder);
|
||||
int64_t max_remainder_size =
|
||||
scheduler_utils::safeDiv(kMaxStride, bytes_stride_remainder);
|
||||
|
||||
int64_t grdim_for_stride = ceilDiv(
|
||||
total_reduction_numel,
|
||||
|
|
@@ -771,13 +768,13 @@ ReductionParams OuterReductionHeuristic(
|
|||
// Always disabled for now.
|
||||
// bool flip_grid = gidim > 1 && gidim < 8;
|
||||
const bool flip_grid = false;
|
||||
ReductionParams rparams;
|
||||
auto rparams = std::make_shared<ReductionParams>();
|
||||
// cross grid implies cross block
|
||||
rparams.cross_block_inner_reduction = bdimy > 1 || grdim > 1;
|
||||
rparams.cross_grid_inner_reduction = grdim > 1;
|
||||
if (rparams.cross_grid_inner_reduction) {
|
||||
rparams.split_grid_dim_inner_reduction = true;
|
||||
rparams.grid_dim_inner_reduction =
|
||||
rparams->cross_block_inner_reduction = bdimy > 1 || grdim > 1;
|
||||
rparams->cross_grid_inner_reduction = grdim > 1;
|
||||
if (rparams->cross_grid_inner_reduction) {
|
||||
rparams->split_grid_dim_inner_reduction = true;
|
||||
rparams->grid_dim_inner_reduction =
|
||||
flip_grid ? ParallelType::BIDx : ParallelType::BIDy;
|
||||
if (flip_grid) {
|
||||
gdimx = std::min(grdim, scheduler_utils::x_grid_limit);
|
||||
|
|
@@ -785,17 +782,17 @@ ReductionParams OuterReductionHeuristic(
|
|||
gdimy = std::min(grdim, scheduler_utils::y_grid_limit);
|
||||
}
|
||||
}
|
||||
rparams.multiple_reds_per_blk = bdimx > 1 || iter_unroll_factor > 1;
|
||||
rparams->multiple_reds_per_blk = bdimx > 1 || iter_unroll_factor > 1;
|
||||
|
||||
if (rparams.multiple_reds_per_blk) {
|
||||
rparams.block_dim_iter_dom = ParallelType::TIDx;
|
||||
if (rparams->multiple_reds_per_blk) {
|
||||
rparams->block_dim_iter_dom = ParallelType::TIDx;
|
||||
}
|
||||
|
||||
rparams.grid_dim_iter_dom =
|
||||
rparams->grid_dim_iter_dom =
|
||||
flip_grid ? ParallelType::BIDy : ParallelType::BIDx;
|
||||
if (gidim > (flip_grid ? scheduler_utils::y_grid_limit
|
||||
: scheduler_utils::x_grid_limit)) {
|
||||
rparams.split_grid_dim_iter_dom = true;
|
||||
rparams->split_grid_dim_iter_dom = true;
|
||||
if (flip_grid) {
|
||||
gdimy = scheduler_utils::y_grid_limit;
|
||||
} else {
|
||||
|
|
@@ -803,29 +800,29 @@ ReductionParams OuterReductionHeuristic(
|
|||
}
|
||||
}
|
||||
|
||||
rparams.flip_grid = flip_grid;
|
||||
rparams->flip_grid = flip_grid;
|
||||
|
||||
if (rparams.cross_block_inner_reduction) {
|
||||
if (rparams.block_dim_iter_dom == ParallelType::TIDx) {
|
||||
rparams.block_dim_inner_reduction = ParallelType::TIDy;
|
||||
if (rparams->cross_block_inner_reduction) {
|
||||
if (rparams->block_dim_iter_dom == ParallelType::TIDx) {
|
||||
rparams->block_dim_inner_reduction = ParallelType::TIDy;
|
||||
} else {
|
||||
rparams.block_dim_inner_reduction = ParallelType::TIDx;
|
||||
rparams->block_dim_inner_reduction = ParallelType::TIDx;
|
||||
}
|
||||
}
|
||||
|
||||
rparams.unroll_factor_inner_reduction = inner_reduction_unroll_factor;
|
||||
rparams->unroll_factor_inner_reduction = inner_reduction_unroll_factor;
|
||||
|
||||
rparams.unroll_factor_iter_dom = iter_unroll_factor;
|
||||
rparams->unroll_factor_iter_dom = iter_unroll_factor;
|
||||
if (iter_unroll_factor > 1) {
|
||||
rparams.vectorize_iter_dom = vectorize;
|
||||
rparams->vectorize_iter_dom = vectorize;
|
||||
}
|
||||
|
||||
rparams.lparams = LaunchParams(
|
||||
rparams->lparams = LaunchParams(
|
||||
gdimx,
|
||||
gdimy,
|
||||
LaunchParams::UNINITIALIZED_VAL,
|
||||
rparams.multiple_reds_per_blk ? bdimx : bdimy,
|
||||
rparams.multiple_reds_per_blk ? bdimy : LaunchParams::UNINITIALIZED_VAL,
|
||||
rparams->multiple_reds_per_blk ? bdimx : bdimy,
|
||||
rparams->multiple_reds_per_blk ? bdimy : LaunchParams::UNINITIALIZED_VAL,
|
||||
LaunchParams::UNINITIALIZED_VAL);
|
||||
|
||||
if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) {
|
||||
|
|
@@ -836,14 +833,14 @@ ReductionParams OuterReductionHeuristic(
|
|||
<< "n_tensor_inputs: " << n_tensor_inputs << "\n"
|
||||
<< "max_input_dtype_size: " << max_input_dtype_size << "\n"
|
||||
<< "block(" << bdimx << ", " << bdimy << ", 1)" << std::endl;
|
||||
std::cerr << rparams.toString() << std::endl;
|
||||
std::cerr << rparams->toString() << std::endl;
|
||||
}
|
||||
return rparams;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
ReductionParams reductionHeuristic(
|
||||
std::shared_ptr<ReductionParams> reductionHeuristic(
|
||||
const int64_t total_reduction_numel,
|
||||
const int64_t total_iteration_numel,
|
||||
const int64_t inner_most_dimension_numel,
|
||||
|
|
@@ -861,7 +858,7 @@ ReductionParams reductionHeuristic(
|
|||
vectorize_factor);
|
||||
} else {
|
||||
// 3D schedules not enabled for outer reductions
|
||||
return OuterReductionHeuristic(
|
||||
return outerReductionHeuristic(
|
||||
total_reduction_numel,
|
||||
total_iteration_numel,
|
||||
n_tensor_inputs,
|
||||
|
|
@@ -870,7 +867,7 @@ ReductionParams reductionHeuristic(
|
|||
}
|
||||
}
|
||||
|
||||
TORCH_CUDA_CU_API c10::optional<ReductionParams> getReductionHeuristics(
|
||||
TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getReductionHeuristics(
|
||||
Fusion* fusion,
|
||||
const at::ArrayRef<c10::IValue>& runtime_inputs,
|
||||
HeuristicSummary* data_cache) {
|
||||
|
|
@@ -881,7 +878,7 @@ TORCH_CUDA_CU_API c10::optional<ReductionParams> getReductionHeuristics(
|
|||
return getReductionHeuristics(fusion, runtime_info, data_cache);
|
||||
}
|
||||
|
||||
TORCH_CUDA_CU_API c10::optional<ReductionParams> getReductionHeuristics(
|
||||
TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getReductionHeuristics(
|
||||
Fusion* fusion,
|
||||
SchedulerRuntimeInfo& runtime_info,
|
||||
HeuristicSummary* data_cache) {
|
||||
|
|
@@ -911,8 +908,7 @@ TORCH_CUDA_CU_API c10::optional<ReductionParams> getReductionHeuristics(
|
|||
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
red_expr->getExprType() != c10::nullopt &&
|
||||
(red_expr->getExprType().value() == ExprType::ReductionOp ||
|
||||
red_expr->getExprType().value() == ExprType::WelfordOp),
|
||||
ir_utils::isReductionOp(red_expr),
|
||||
"TensorView doesn't have a reduction.");
|
||||
|
||||
auto properties =
|
||||
|
|
|
|||
|
|
@@ -13,12 +13,12 @@ namespace cuda {
|
|||
class SchedulerRuntimeInfo;
|
||||
class HeuristicSummary;
|
||||
|
||||
TORCH_CUDA_CU_API c10::optional<ReductionParams> getReductionHeuristics(
|
||||
TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getReductionHeuristics(
|
||||
Fusion* fusion,
|
||||
const at::ArrayRef<c10::IValue>& runtime_inputs,
|
||||
HeuristicSummary* data_cache = nullptr);
|
||||
|
||||
TORCH_CUDA_CU_API c10::optional<ReductionParams> getReductionHeuristics(
|
||||
TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getReductionHeuristics(
|
||||
Fusion* fusion,
|
||||
SchedulerRuntimeInfo& runtime_info,
|
||||
HeuristicSummary* data_cache = nullptr);
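The two declarations above switch the heuristic entry points from returning `c10::optional<ReductionParams>` to `std::shared_ptr<ReductionParams>`. A minimal sketch of the calling pattern this implies, mirroring the `nullptr` checks used later in this diff; the `fusion` and `runtime_info` setup is assumed, and `scheduleReduction` is assumed to keep taking a `const ReductionParams&` as it does in `ReductionScheduler::schedule`:
```cpp
// Sketch only: consuming the shared_ptr-based heuristics API. The nullptr
// check replaces the old c10::optional::has_value() test.
std::shared_ptr<ReductionParams> rparams =
    getReductionHeuristics(fusion, runtime_info);
TORCH_INTERNAL_ASSERT(
    rparams != nullptr, "Failed to generate reduction heuristics");
if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) {
  std::cerr << rparams->toString() << std::endl;
}
scheduleReduction(fusion, *rparams);
```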
|
||||
|
|
|
|||
|
|
@@ -1,6 +1,6 @@
|
|||
#pragma once

#include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/heuristic.h>

#include <sstream>

@@ -9,11 +9,11 @@ namespace jit {
|
|||
namespace fuser {
|
||||
namespace cuda {
|
||||
|
||||
// Parameters the Reduction Heuristic Generates to describe the optimial
// schedule. Warning: equal operator is intended for use in caching the kernel
// associated with these reduction parameters. It does not check if the launch
// parameters are equivelent!
class ReductionParams {
// Parameters of the reduction heuristic to describe the optimal schedule.
// Warning: equal operator is intended for use in caching the kernel associated
// with these reduction parameters. It does not check if the launch parameters
// are equivalent!
class ReductionParams : public HeuristicParams {
|
||||
public:
|
||||
// Reducing inner most dimension?
|
||||
bool fastest_dim = false;
|
||||
|
|
@@ -100,18 +100,22 @@ class ReductionParams {
|
|||
// parameters, not used for equivalence/hashing.
|
||||
ParallelType grid_dim_outer_reduction = ParallelType::Serial;
|
||||
|
||||
std::string tag = "";
|
||||
|
||||
LaunchParams lparams;
|
||||
|
||||
bool isUnrolled() const {
|
||||
return unroll_factor_inner_reduction > 1 || unroll_factor_iter_dom > 1 ||
|
||||
unroll_factor_outer_reduction > 1;
|
||||
}
|
||||
|
||||
public:
|
||||
using HeuristicParams::HeuristicParams;
|
||||
|
||||
// Warning: Does not check launch parameters!
|
||||
bool operator==(const ReductionParams& other) const {
|
||||
bool sameAs(
|
||||
const std::shared_ptr<HeuristicParams>& other_base) const override {
|
||||
auto other_casted = std::dynamic_pointer_cast<ReductionParams>(other_base);
|
||||
if (other_casted == nullptr) {
|
||||
return false;
|
||||
}
|
||||
const ReductionParams& other = *other_casted;
|
||||
bool attr_equal = other.fastest_dim == fastest_dim &&
|
||||
other.persistent_kernel == persistent_kernel &&
|
||||
other.project_persistent_buffers == project_persistent_buffers &&
|
||||
|
|
@@ -139,7 +143,7 @@ class ReductionParams {
|
|||
return attr_equal;
|
||||
}
|
||||
|
||||
std::string toString() const {
|
||||
std::string toString() const override {
|
||||
std::stringstream ss;
|
||||
ss << "\n===== Reduction Parameters ========\n"
|
||||
<< (tag == "" ? "" : "Tag: ") << tag << "\n"
|
||||
|
|
@@ -216,38 +220,37 @@ class ReductionParams {
|
|||
ss << "====================================\n";
|
||||
return ss.str();
|
||||
}
|
||||
};
|
||||
|
||||
// Warning: Hash is not based on launch parameters!
|
||||
class ReductionParamsHash {
|
||||
public:
|
||||
size_t operator()(const ReductionParams& rp) const {
|
||||
// Warning: Hash is not based on launch parameters!
|
||||
size_t hash() const override {
|
||||
constexpr size_t bits = sizeof(std::size_t) * 8;
|
||||
size_t attr_hash = static_cast<size_t>(rp.fastest_dim) << (bits - 1) ^
|
||||
static_cast<size_t>(rp.persistent_kernel) << (bits - 2) ^
|
||||
static_cast<size_t>(rp.project_persistent_buffers) << (bits - 3) ^
|
||||
static_cast<size_t>(rp.schedule_3D) << (bits - 4) ^
|
||||
static_cast<size_t>(rp.flip_grid) << (bits - 5) ^
|
||||
static_cast<size_t>(rp.cross_block_inner_reduction) << (bits - 6) ^
|
||||
static_cast<size_t>(rp.cross_grid_inner_reduction) << (bits - 7) ^
|
||||
static_cast<size_t>(rp.unroll_factor_inner_reduction) << (bits - 8) ^
|
||||
static_cast<size_t>(rp.vectorize_inner_reduction) << (bits - 9) ^
|
||||
static_cast<size_t>(rp.split_grid_dim_inner_reduction) << (bits - 10) ^
|
||||
static_cast<size_t>(rp.pad_inner_reduction_to_warp) << (bits - 11) ^
|
||||
static_cast<size_t>(rp.batches_per_block_inner_reduction)
|
||||
<< (bits - 12) ^
|
||||
static_cast<size_t>(rp.multiple_reds_per_blk) << (bits - 13) ^
|
||||
static_cast<size_t>(rp.unroll_factor_iter_dom) << (bits - 14) ^
|
||||
static_cast<size_t>(rp.vectorize_iter_dom) << (bits - 15) ^
|
||||
static_cast<size_t>(rp.split_grid_dim_iter_dom) << (bits - 16) ^
|
||||
static_cast<size_t>(rp.cross_block_outer_reduction) << (bits - 17) ^
|
||||
static_cast<size_t>(rp.cross_grid_outer_reduction) << (bits - 18) ^
|
||||
static_cast<size_t>(rp.split_grid_dim_outer_reduction) << (bits - 19) ^
|
||||
static_cast<size_t>(rp.batches_per_block_outer_reduction)
|
||||
<< (bits - 20) ^
|
||||
static_cast<size_t>(rp.unroll_factor_outer_reduction) << (bits - 21);
|
||||
size_t attr_hash = static_cast<size_t>(fastest_dim) << (bits - 1) ^
|
||||
static_cast<size_t>(persistent_kernel) << (bits - 2) ^
|
||||
static_cast<size_t>(project_persistent_buffers) << (bits - 3) ^
|
||||
static_cast<size_t>(schedule_3D) << (bits - 4) ^
|
||||
static_cast<size_t>(flip_grid) << (bits - 5) ^
|
||||
static_cast<size_t>(cross_block_inner_reduction) << (bits - 6) ^
|
||||
static_cast<size_t>(cross_grid_inner_reduction) << (bits - 7) ^
|
||||
static_cast<size_t>(unroll_factor_inner_reduction) << (bits - 8) ^
|
||||
static_cast<size_t>(vectorize_inner_reduction) << (bits - 9) ^
|
||||
static_cast<size_t>(split_grid_dim_inner_reduction) << (bits - 10) ^
|
||||
static_cast<size_t>(pad_inner_reduction_to_warp) << (bits - 11) ^
|
||||
static_cast<size_t>(batches_per_block_inner_reduction) << (bits - 12) ^
|
||||
static_cast<size_t>(multiple_reds_per_blk) << (bits - 13) ^
|
||||
static_cast<size_t>(unroll_factor_iter_dom) << (bits - 14) ^
|
||||
static_cast<size_t>(vectorize_iter_dom) << (bits - 15) ^
|
||||
static_cast<size_t>(split_grid_dim_iter_dom) << (bits - 16) ^
|
||||
static_cast<size_t>(cross_block_outer_reduction) << (bits - 17) ^
|
||||
static_cast<size_t>(cross_grid_outer_reduction) << (bits - 18) ^
|
||||
static_cast<size_t>(split_grid_dim_outer_reduction) << (bits - 19) ^
|
||||
static_cast<size_t>(batches_per_block_outer_reduction) << (bits - 20) ^
|
||||
static_cast<size_t>(unroll_factor_outer_reduction) << (bits - 21);
|
||||
return attr_hash;
|
||||
}
|
||||
|
||||
std::shared_ptr<HeuristicParams> clone() const override {
|
||||
return std::make_shared<ReductionParams>(*this);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cuda
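Since `ReductionParams` now derives from `HeuristicParams` and overrides `sameAs`, `toString`, `hash`, and `clone`, kernel caching can work purely through the base-class interface. A short sketch of that usage, built only from calls shown in this diff (`std::make_shared<ReductionParams>()`, `clone()`, `sameAs()`, `hash()`):
```cpp
// Sketch: compare and hash heuristic parameters through the HeuristicParams
// base, without knowing whether the entry is a reduction or pointwise kernel.
std::shared_ptr<HeuristicParams> a = std::make_shared<ReductionParams>();
std::shared_ptr<HeuristicParams> b = a->clone();
TORCH_INTERNAL_ASSERT(a->sameAs(b));            // attribute-wise equality
TORCH_INTERNAL_ASSERT(a->hash() == b->hash());  // hash ignores launch params
```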
|
||||
|
|
|
|||
|
|
@@ -1,8 +1,10 @@
|
|||
#include <torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.h>
|
||||
|
||||
#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/inline_propagator.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/ir_cloner.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/maxinfo_propagator.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/scheduler/registry.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
|
||||
|
|
@@ -219,13 +221,13 @@ void multiReductionInliner(
|
|||
std::vector<TensorView*> reduction_tvs,
|
||||
std::vector<TensorView*> cached_inputs,
|
||||
std::vector<std::pair<TensorView*, TensorView*>> cached_outputs) {
|
||||
// Propagate transformations before we rfactor the other reductions
|
||||
TransformPropagator propagator(reference_tv);
|
||||
MaxRootDomainInfoSpanningTree(reference_tv).traverse(&propagator);
|
||||
|
||||
// Apply rfactor to all reductions if applicable
|
||||
std::vector<TensorView*> rfactor_tvs;
|
||||
|
||||
// If reduction_tv is rfactored, rfactor all reductions.
|
||||
if (reference_tv != reduction_tv) {
|
||||
// Apply rfactor to all reductions if applicable
|
||||
std::vector<int> rfactor_axes;
|
||||
for (const auto i : c10::irange(reference_tv->nDims())) {
|
||||
if (reference_tv->axis((int)i)->isReduction() &&
|
||||
|
|
@@ -236,155 +238,86 @@ void multiReductionInliner(
|
|||
|
||||
for (auto reduction_tv_ : reduction_tvs) {
|
||||
if (reduction_tv_ == reduction_tv) {
|
||||
// The reduction tv
|
||||
rfactor_tvs.push_back(reference_tv);
|
||||
// This should come in already rfactored
|
||||
continue;
|
||||
} else {
|
||||
rfactor_tvs.push_back(
|
||||
ir_utils::rfactorHelper(reduction_tv_, rfactor_axes));
|
||||
ir_utils::rfactorHelper(reduction_tv_, rfactor_axes);
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
reduction_tvs.size() == rfactor_tvs.size(),
|
||||
"Expected all reductions to contain rfactor.");
|
||||
}
|
||||
|
||||
// Propagate parallelization
|
||||
scheduler_utils::parallelizeAllLike(reference_tv, ir_utils::allTvs(fusion));
|
||||
|
||||
// Find iter domains that are mapped to a trivial reduction, these should
|
||||
// never be inlined.
|
||||
std::unordered_set<IterDomain*> mapped_to_trivial_reduction =
|
||||
scheduler_utils::getTrivialReductionMap(fusion);
|
||||
|
||||
bool unroll = rparams.isUnrolled();
|
||||
|
||||
bool vectorize =
|
||||
rparams.vectorize_inner_reduction || rparams.vectorize_iter_dom;
|
||||
|
||||
// Propagate parallelization except vectorization and unrolling
|
||||
scheduler_utils::parallelizeAllLike(
|
||||
reference_tv,
|
||||
{},
|
||||
allParallelTypesExcept(
|
||||
{ParallelType::Unroll,
|
||||
ParallelType::Vectorize,
|
||||
ParallelType::MisalignedVectorize}));
|
||||
|
||||
if (unroll) {
|
||||
// Inline Input caches to their consumers outside unswitched/vectorization
|
||||
// position Inline consumers of input caches to rfactor tensors
|
||||
|
||||
// Mark which tensor views are actual input caches to leave vectorization on
|
||||
// them
|
||||
std::unordered_set<TensorView*> keep_unrolled;
|
||||
|
||||
std::vector<TensorView*> compute_from;
|
||||
// Find all tensor views that should have unroll or vectorization
|
||||
std::unordered_set<TensorView*> are_unrolled;
|
||||
|
||||
// Grab all tensor views that should be vectorized
|
||||
auto vectorizable_inputs_outputs =
|
||||
scheduler_utils::getInputsOutputsWithInnerDim(reference_tv, true);
|
||||
|
||||
// Inputs to cache
|
||||
auto vectorizable_expr = [](Expr* e) {
|
||||
return e->isA<UnaryOp>() &&
|
||||
e->as<UnaryOp>()->getUnaryOpType() == UnaryOpType::Set;
|
||||
};
|
||||
|
||||
for (auto cached_input : cached_inputs) {
|
||||
auto consumers_of_input_cache = ir_utils::consumerTvsOf(cached_input);
|
||||
for (auto consumer : consumers_of_input_cache) {
|
||||
auto unswitch_it = std::find_if(
|
||||
consumer->domain()->domain().begin(),
|
||||
consumer->domain()->domain().end(),
|
||||
[&mapped_to_trivial_reduction](IterDomain* id) {
|
||||
return id->getParallelType() == ParallelType::Unswitch ||
|
||||
id->getParallelType() == ParallelType::Unroll ||
|
||||
id->getParallelType() == ParallelType::Vectorize ||
|
||||
id->getParallelType() == ParallelType::MisalignedVectorize ||
|
||||
mapped_to_trivial_reduction.count(id);
|
||||
});
|
||||
auto unswitch_pos = unswitch_it == consumer->domain()->domain().end()
|
||||
? -1
|
||||
: std::distance(consumer->domain()->domain().begin(), unswitch_it) +
|
||||
1;
|
||||
|
||||
cached_input->computeAt(
|
||||
consumer, unswitch_pos, ComputeAtMode::BestEffort);
|
||||
compute_from.push_back(consumer);
|
||||
|
||||
if (vectorize) {
|
||||
auto producer_tvs = ir_utils::producerTvsOf(cached_input);
|
||||
if (producer_tvs.size() == 1 &&
|
||||
std::find(
|
||||
vectorizable_inputs_outputs.begin(),
|
||||
vectorizable_inputs_outputs.end(),
|
||||
producer_tvs[0]) != vectorizable_inputs_outputs.end()) {
|
||||
keep_unrolled.emplace(cached_input);
|
||||
}
|
||||
} else {
|
||||
keep_unrolled.emplace(cached_input);
|
||||
if (vectorize) {
|
||||
auto producer_tvs = ir_utils::producerTvsOf(cached_input);
|
||||
if (producer_tvs.size() == 1 &&
|
||||
vectorizable_expr(cached_input->definition()) &&
|
||||
std::find(
|
||||
vectorizable_inputs_outputs.begin(),
|
||||
vectorizable_inputs_outputs.end(),
|
||||
producer_tvs[0]) != vectorizable_inputs_outputs.end()) {
|
||||
are_unrolled.emplace(cached_input);
|
||||
}
|
||||
} else {
|
||||
are_unrolled.emplace(cached_input);
|
||||
}
|
||||
}
|
||||
|
||||
// Inline output caches into outputs
|
||||
std::vector<TensorView*> compute_to;
|
||||
for (auto cached_output_pair : cached_outputs) {
|
||||
auto cached_output = cached_output_pair.first;
|
||||
auto output = cached_output_pair.second;
|
||||
|
||||
if (vectorize) {
|
||||
if (std::find(
|
||||
if (vectorizable_expr(output->definition()) &&
|
||||
std::find(
|
||||
vectorizable_inputs_outputs.begin(),
|
||||
vectorizable_inputs_outputs.end(),
|
||||
output) != vectorizable_inputs_outputs.end()) {
|
||||
keep_unrolled.emplace(output);
|
||||
are_unrolled.emplace(output);
|
||||
}
|
||||
} else {
|
||||
keep_unrolled.emplace(output);
|
||||
}
|
||||
|
||||
// If an output has multiple consumers don't process compute at structure
|
||||
// here, we want only terminating outputs
|
||||
if (cached_output->uses().size() > 1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto pos_it = std::find_if(
|
||||
output->domain()->domain().begin(),
|
||||
output->domain()->domain().end(),
|
||||
[&mapped_to_trivial_reduction](IterDomain* id) {
|
||||
return id->getParallelType() == ParallelType::Unswitch ||
|
||||
id->getParallelType() == ParallelType::Unroll ||
|
||||
id->getParallelType() == ParallelType::Vectorize ||
|
||||
id->getParallelType() == ParallelType::MisalignedVectorize ||
|
||||
mapped_to_trivial_reduction.count(id);
|
||||
});
|
||||
auto pos = pos_it == output->domain()->domain().end()
|
||||
? -1
|
||||
: std::distance(output->domain()->domain().begin(), pos_it) + 1;
|
||||
|
||||
cached_output->computeAt(output, pos, ComputeAtMode::BestEffort);
|
||||
|
||||
compute_to.push_back(cached_output);
|
||||
}
|
||||
|
||||
{
|
||||
// Add inputs to compute_at that weren't unrolled
|
||||
auto processed_inputs = ir_utils::inputTvsOf(compute_from);
|
||||
std::unordered_set<TensorView*> processed_inputs_set{
|
||||
processed_inputs.begin(), processed_inputs.end()};
|
||||
for (auto inp_tv : ir_utils::filterByType<TensorView>(fusion->inputs())) {
|
||||
if (!processed_inputs_set.count(inp_tv)) {
|
||||
compute_from.push_back(inp_tv);
|
||||
}
|
||||
}
|
||||
|
||||
auto processed_outputs = ir_utils::inputTvsOf(compute_to);
|
||||
std::unordered_set<TensorView*> processed_outputs_set{
|
||||
processed_outputs.begin(), processed_outputs.end()};
|
||||
for (auto out_tv :
|
||||
ir_utils::filterByType<TensorView>(fusion->outputs())) {
|
||||
if (!processed_outputs_set.count(out_tv) && out_tv->uses().empty()) {
|
||||
compute_to.push_back(out_tv);
|
||||
}
|
||||
are_unrolled.emplace(output);
|
||||
}
|
||||
}
|
||||
|
||||
// Before compute at-ing the internal structure, remove vectorization
|
||||
// anywhere it doesn't belong. Otherwise it will mess up our inlining. Clear
|
||||
// explicit unroll or vectorization when not for input or output GMEM
|
||||
// transfers.
|
||||
for (auto tv : ir_utils::allTvs(fusion)) {
|
||||
if (!keep_unrolled.count(tv)) {
|
||||
// Propagate vectorization/unrolling to those tensors that need it
|
||||
scheduler_utils::parallelizeAllLike(
|
||||
reference_tv,
|
||||
-1,
|
||||
{are_unrolled.begin(), are_unrolled.end()},
|
||||
{ParallelType::Unroll,
|
||||
ParallelType::Vectorize,
|
||||
ParallelType::MisalignedVectorize});
|
||||
|
||||
std::vector<TensorView*> rfactor_and_reduction_tvs = {
|
||||
reference_tv, reduction_tv};
|
||||
// If reference shouldn't be unrolled, clear that parallel type.
|
||||
for (auto tv : rfactor_and_reduction_tvs) {
|
||||
if (are_unrolled.count(tv) == 0) {
|
||||
for (const auto i : c10::irange(tv->nDims())) {
|
||||
auto id = tv->axis((int)i);
|
||||
if (id->getParallelType() == ParallelType::Unroll ||
|
||||
|
|
@@ -395,146 +328,22 @@ void multiReductionInliner(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Make sure not to completely inline if there are trivial reductions in the
|
||||
// fusion
|
||||
auto pos_it = std::find_if(
|
||||
reference_tv->domain()->domain().begin(),
|
||||
reference_tv->domain()->domain().end(),
|
||||
[&mapped_to_trivial_reduction](IterDomain* id) {
|
||||
return mapped_to_trivial_reduction.count(id);
|
||||
});
|
||||
|
||||
auto pos = pos_it == reference_tv->domain()->domain().end()
|
||||
? -1
|
||||
: std::distance(reference_tv->domain()->domain().begin(), pos_it) + 1;
|
||||
|
||||
// Compute at inputs to rfactor dimensions
|
||||
scheduler_utils::computeAtBetween(
|
||||
compute_from, rfactor_tvs, pos, ComputeAtMode::MostInlined);
|
||||
|
||||
// Inline rfactor into reduction
|
||||
if (reference_tv != reduction_tv) {
|
||||
// Compute at rfactor into following reduction, keep outside first
|
||||
// reduction iter domain in the rfactor tensor view
|
||||
for (const auto i : c10::irange(rfactor_tvs.size())) {
|
||||
if (rparams.unroll_factor_iter_dom > 1) {
|
||||
auto rfactor_tv = rfactor_tvs[i];
|
||||
auto rfactor_tv_dom = rfactor_tv->domain()->domain();
|
||||
auto reduction_it = std::find_if(
|
||||
rfactor_tv_dom.begin(), rfactor_tv_dom.end(), [](IterDomain* id) {
|
||||
return id->isReduction();
|
||||
});
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
reduction_it != rfactor_tv_dom.end(),
|
||||
"Expected reduction axis in ",
|
||||
rfactor_tv);
|
||||
auto pos = std::distance(rfactor_tv_dom.begin(), reduction_it);
|
||||
// I would like computeAtMode here to be Standard. However, the
|
||||
// processing of welford rfactors in compute at ends up propagating
|
||||
// compute at from reduction_tv->rfactor_tv to all outputs.
|
||||
rfactor_tv->computeWith(
|
||||
reduction_tvs[i], pos, ComputeAtMode::BestEffort);
|
||||
} else {
|
||||
rfactor_tvs[i]->computeWith(
|
||||
reduction_tvs[i], -1, ComputeAtMode::BestEffort);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove anything before a reduction from compute_from
|
||||
{
|
||||
auto producers_of_reductions = DependencyCheck::getAllValsBetween(
|
||||
{fusion->inputs().begin(), fusion->inputs().end()},
|
||||
{reduction_tvs.begin(), reduction_tvs.end()});
|
||||
|
||||
auto producer_tvs_of_reductions =
|
||||
ir_utils::filterByType<TensorView>(producers_of_reductions);
|
||||
compute_from.erase(
|
||||
std::remove_if(
|
||||
compute_from.begin(),
|
||||
compute_from.end(),
|
||||
[&producer_tvs_of_reductions](TensorView* compute_from_tv) {
|
||||
return std::find(
|
||||
producer_tvs_of_reductions.begin(),
|
||||
producer_tvs_of_reductions.end(),
|
||||
compute_from_tv) != producer_tvs_of_reductions.end();
|
||||
}),
|
||||
compute_from.end());
|
||||
}
|
||||
|
||||
// Add reduction tensor views to compute from
|
||||
compute_from.insert(
|
||||
compute_from.end(), reduction_tvs.begin(), reduction_tvs.end());
|
||||
|
||||
// Compute between reductions and output caches
|
||||
scheduler_utils::computeAtBetween(
|
||||
compute_from,
|
||||
compute_to,
|
||||
-1,
|
||||
ComputeAtMode::BestEffort,
|
||||
mapped_to_trivial_reduction);
|
||||
|
||||
} else {
|
||||
// Want to inline, especially backwards based on reduction_tv, otherwise
|
||||
// rfactor tv may not be inlined correctly
|
||||
auto ref_tvs = rfactor_tvs.size() ? rfactor_tvs : reduction_tvs;
|
||||
for (auto red_tv : ref_tvs) {
|
||||
auto pos_it = std::find_if(
|
||||
red_tv->domain()->domain().begin(),
|
||||
red_tv->domain()->domain().end(),
|
||||
[&mapped_to_trivial_reduction](IterDomain* id) {
|
||||
return id->getParallelType() == ParallelType::Unswitch ||
|
||||
id->getParallelType() == ParallelType::Unroll ||
|
||||
id->getParallelType() == ParallelType::Vectorize ||
|
||||
id->getParallelType() == ParallelType::MisalignedVectorize ||
|
||||
mapped_to_trivial_reduction.count(id);
|
||||
});
|
||||
auto pos = pos_it == red_tv->domain()->domain().end()
|
||||
? -1
|
||||
: std::distance(red_tv->domain()->domain().begin(), pos_it) + 1;
|
||||
|
||||
scheduler_utils::computeAtInputs(red_tv, pos, ComputeAtMode::MostInlined);
|
||||
scheduler_utils::computeWithOutputs(
|
||||
red_tv, pos, ComputeAtMode::BestEffort);
|
||||
}
|
||||
// For topologies where there may not be paths to all inputs/outputs from
|
||||
// the reductions, we need to take a similar approach to the unrolled
|
||||
// version and setup of compute at from inputs->outputs avoiding going
|
||||
// through the reduction expressions. This can be done by grabbing inputs
|
||||
// not on path to a reduction, and computeAt-ing with all outputs. This
|
||||
// doesn't guarantee we don't go through a reduction, but with best effort
|
||||
// it should minimize damage if it does.
|
||||
std::vector<TensorView*> compute_to;
|
||||
for (auto out : ir_utils::filterByType<TensorView>(fusion->outputs())) {
|
||||
// only terminating outputs
|
||||
if (out->uses().size() || out->isFusionInput()) {
|
||||
continue;
|
||||
}
|
||||
compute_to.push_back(out);
|
||||
}
|
||||
|
||||
std::vector<TensorView*> compute_from;
|
||||
std::unordered_set<TensorView*> inps_of_reds;
|
||||
{
|
||||
auto inps_of_red_vec = ir_utils::inputTvsOf(ref_tvs);
|
||||
inps_of_reds = std::unordered_set<TensorView*>(
|
||||
inps_of_red_vec.begin(), inps_of_red_vec.end());
|
||||
}
|
||||
for (auto inp : ir_utils::filterByType<TensorView>(fusion->inputs())) {
|
||||
if (inps_of_reds.find(inp) != inps_of_reds.end()) {
|
||||
continue;
|
||||
}
|
||||
compute_from.push_back(inp);
|
||||
}
|
||||
|
||||
scheduler_utils::computeAtBetween(
|
||||
compute_from,
|
||||
compute_to,
|
||||
-1,
|
||||
ComputeAtMode::BestEffort,
|
||||
mapped_to_trivial_reduction);
|
||||
}
|
||||
|
||||
// Find iter domains that are mapped to a trivial reduction, these should
|
||||
// never be inlined.
|
||||
std::unordered_set<IterDomain*> mapped_to_trivial_reduction =
|
||||
scheduler_utils::getTrivialReductionMap(fusion);
|
||||
|
||||
// Inline the schedule
|
||||
InlinePropagator inline_propagator(
|
||||
reference_tv,
|
||||
-1,
|
||||
ComputeAtMode::MostInlined,
|
||||
{},
|
||||
mapped_to_trivial_reduction);
|
||||
|
||||
MaxRootDomainInfoSpanningTree(reference_tv).traverse(&inline_propagator);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
|
|
|||
|
|
@@ -22,7 +22,7 @@ TensorView* scheduleReductionTV(
|
|||
bool has_iter_axis);
|
||||
|
||||
// Inlining function intended for single or multi reduction fusions.
|
||||
void multiReductionInliner(
|
||||
TORCH_CUDA_CU_API void multiReductionInliner(
|
||||
Fusion* fusion,
|
||||
const ReductionParams& rparams,
|
||||
TensorView* reduction_tv,
|
||||
|
|
@@ -41,7 +41,7 @@ void multiReductionInliner(
|
|||
// Rfactored axes are reductions bound to grid or blocks. If no axes are bound
|
||||
// to a grid or block dimension it will rfactor the r-unswitch dimension.
|
||||
// Reduction inliner expects an rfactored domain.
|
||||
TensorView* sortAndRFactor(TensorView* reference_tv);
|
||||
TORCH_CUDA_CU_API TensorView* sortAndRFactor(TensorView* reference_tv);
|
||||
|
||||
// Take all projectable persistent buffers, and move them to the inputs.
|
||||
TORCH_CUDA_CU_API void projectPersistentBuffers(Fusion* fusion);
|
||||
|
|
|
|||
|
|
@@ -239,8 +239,8 @@ class SchedulerTopologyChecker {
|
|||
static bool hasPostReductionBCast(Fusion* fusion) {
|
||||
auto all_vals = fusion->usedMathVals();
|
||||
for (auto tv : ir_utils::filterByType<TensorView>(all_vals)) {
|
||||
// Welford can have 2 outputs, so do this on all found reduction tensor
|
||||
// views
|
||||
// Reductions can have multiple outputs, so do this on all found reduction
|
||||
// tensor views
|
||||
if (tv->hasReduction() && !tv->isFusionInput()) {
|
||||
auto tv_chains = tvChains(DependencyCheck::getAllUseChains(tv));
|
||||
// Propagate forward from reduction through all uses of the reduction
|
||||
|
|
@@ -301,18 +301,17 @@ class SchedulerTopologyChecker {
|
|||
|
||||
// When checking post reduction vals, we need to make sure
|
||||
// we are really checking paths starting from all outputs
|
||||
// of multi-output reductions, i.e. welford. The reduction_tv
|
||||
// vector is assumed to only have one of them.
|
||||
// of multi-output reductions, i.e. welford/grouped reduction. The
|
||||
// reduction_tv vector is assumed to only have one of them.
|
||||
std::unordered_set<Val*> reduction_tv_set(
|
||||
reduction_tvs.begin(), reduction_tvs.end());
|
||||
|
||||
for (auto red : reduction_tvs) {
|
||||
if (red->definition()) {
|
||||
if (auto wop = dynamic_cast<WelfordOp*>(red->definition())) {
|
||||
for (auto wop_output : wop->outputs()) {
|
||||
if (wop_output->isA<TensorView>()) {
|
||||
reduction_tv_set.insert(wop_output);
|
||||
}
|
||||
if (ir_utils::isReductionOp(red->definition())) {
|
||||
auto outs = red->definition()->outputs();
|
||||
for (auto out_tv : ir_utils::filterByType<TensorView>(outs)) {
|
||||
reduction_tv_set.insert(out_tv);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@@ -721,18 +720,7 @@ bool SchedulerEntry::sameAs(const SchedulerEntry* other) {
|
|||
if (index_mode_ != other->index_mode_) {
|
||||
return false;
|
||||
}
|
||||
// Heuristic equal should imply has_reduction_param_ equal,
|
||||
// need to double check if it is the case before removing
|
||||
// the below one.
|
||||
if (has_reduction_param_ != other->has_reduction_param_) {
|
||||
return false;
|
||||
}
|
||||
if (has_reduction_param_) {
|
||||
return rparams_ == other->rparams_;
|
||||
} else {
|
||||
return pparams_ == other->pparams_;
|
||||
}
|
||||
return true;
|
||||
return params_->sameAs(other->params_);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
|
@@ -834,7 +822,7 @@ class ReductionScheduler : public SchedulerEntry {
|
|||
Fusion* fusion,
|
||||
SchedulerRuntimeInfo& runtime_info,
|
||||
HeuristicSummary* data_cache = nullptr)
|
||||
: SchedulerEntry(ScheduleHeuristic::Reduction, true) {
|
||||
: SchedulerEntry(ScheduleHeuristic::Reduction) {
|
||||
computeHeuristics(fusion, runtime_info, data_cache);
|
||||
}
|
||||
|
||||
|
|
@@ -964,7 +952,7 @@ class ReductionScheduler : public SchedulerEntry {
|
|||
|
||||
void schedule(Fusion* fusion) override {
|
||||
FUSER_PERF_SCOPE("Schedule Single Reduction");
|
||||
scheduleReduction(fusion, rparams());
|
||||
scheduleReduction(fusion, reductionParams());
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
@@ -972,9 +960,8 @@ class ReductionScheduler : public SchedulerEntry {
|
|||
Fusion* fusion,
|
||||
SchedulerRuntimeInfo& runtime_info,
|
||||
HeuristicSummary* data_cache = nullptr) {
|
||||
auto param = getReductionHeuristics(fusion, runtime_info, data_cache);
|
||||
TORCH_INTERNAL_ASSERT(param.has_value());
|
||||
rparams() = param.value();
|
||||
params_ = getReductionHeuristics(fusion, runtime_info, data_cache);
|
||||
TORCH_INTERNAL_ASSERT(params_ != nullptr);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@@ -984,7 +971,7 @@ class PointWiseScheduler : public SchedulerEntry {
|
|||
Fusion* fusion,
|
||||
SchedulerRuntimeInfo& runtime_info,
|
||||
HeuristicSummary* data_cache = nullptr)
|
||||
: SchedulerEntry(ScheduleHeuristic::PointWise, false) {
|
||||
: SchedulerEntry(ScheduleHeuristic::PointWise) {
|
||||
computeHeuristics(fusion, runtime_info, data_cache);
|
||||
}
|
||||
|
||||
|
|
@@ -1000,9 +987,8 @@ class PointWiseScheduler : public SchedulerEntry {
|
|||
|
||||
auto reduction_ops =
|
||||
ir_utils::getReductionOps(fusion, true /* ignore_trivial */);
|
||||
auto welford_ops = ir_utils::filterByType<WelfordOp>(reduction_ops);
|
||||
|
||||
if (!reduction_ops.empty() || !welford_ops.empty()) {
|
||||
if (!reduction_ops.empty()) {
|
||||
scheduler_debug_utils::canScheduleRejectReason(
|
||||
ScheduleHeuristic::PointWise, "no support for reduction ops");
|
||||
return false;
|
||||
|
|
@@ -1027,16 +1013,15 @@ class PointWiseScheduler : public SchedulerEntry {
|
|||
|
||||
void schedule(Fusion* fusion) override {
|
||||
FUSER_PERF_SCOPE("Schedule PointWise Fusion");
|
||||
schedulePointwise(fusion, pparams());
|
||||
schedulePointwise(fusion, pointwiseParams());
|
||||
}
|
||||
|
||||
void computeHeuristics(
|
||||
Fusion* fusion,
|
||||
SchedulerRuntimeInfo& runtime_info,
|
||||
HeuristicSummary* data_cache = nullptr) {
|
||||
auto pparam = getPointwiseHeuristics(fusion, runtime_info, data_cache);
|
||||
TORCH_INTERNAL_ASSERT(pparam.has_value());
|
||||
pparams() = pparam.value();
|
||||
params_ = getPointwiseHeuristics(fusion, runtime_info, data_cache);
|
||||
TORCH_INTERNAL_ASSERT(params_ != nullptr);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@@ -1046,13 +1031,13 @@ class PersistentKernelScheduler : public SchedulerEntry {
|
|||
Fusion* fusion,
|
||||
SchedulerRuntimeInfo& runtime_info,
|
||||
HeuristicSummary* data_cache = nullptr)
|
||||
: SchedulerEntry(ScheduleHeuristic::Persistent, true) {
|
||||
: SchedulerEntry(ScheduleHeuristic::Persistent) {
|
||||
computeHeuristics(fusion, runtime_info, data_cache);
|
||||
}
|
||||
|
||||
void schedule(Fusion* fusion) override {
|
||||
FUSER_PERF_SCOPE("Schedule Persistent Fusion");
|
||||
schedulePersistentKernel(fusion, rparams());
|
||||
schedulePersistentKernel(fusion, reductionParams());
|
||||
}
|
||||
|
||||
static bool canScheduleCompileTime(Fusion* fusion) {
|
||||
|
|
@@ -1065,15 +1050,6 @@ class PersistentKernelScheduler : public SchedulerEntry {
|
|||
|
||||
auto reduction_ops =
|
||||
ir_utils::getReductionOps(fusion, false /* ignore_trivial */);
|
||||
auto welford_ops = ir_utils::filterByType<WelfordOp>(reduction_ops);
|
||||
// For persistent schedule we want welford translated to average and
|
||||
// standard deviation reductions.
|
||||
if (welford_ops.begin() != welford_ops.end()) {
|
||||
scheduler_debug_utils::canScheduleRejectReason(
|
||||
ScheduleHeuristic::Persistent,
|
||||
"no support for un-translated welford");
|
||||
return false;
|
||||
}
|
||||
|
||||
auto view_tvs = scheduler_utils::getViewTVs(fusion);
|
||||
if (view_tvs.size() > 0) {
|
||||
|
|
@@ -1263,9 +1239,8 @@ class PersistentKernelScheduler : public SchedulerEntry {
|
|||
Fusion* fusion,
|
||||
SchedulerRuntimeInfo& runtime_info,
|
||||
HeuristicSummary* data_cache = nullptr) {
|
||||
auto param = getPersistentHeuristics(fusion, runtime_info, data_cache);
|
||||
TORCH_INTERNAL_ASSERT(param.has_value());
|
||||
rparams() = param.value();
|
||||
params_ = getPersistentHeuristics(fusion, runtime_info, data_cache);
|
||||
TORCH_INTERNAL_ASSERT(params_ != nullptr);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@@ -1367,11 +1342,7 @@ c10::optional<ScheduleHeuristic> SchedulerEntry::proposeHeuristics(
|
|||
}
|
||||
|
||||
size_t SchedulerEntryHash::operator()(const SchedulerEntry& se) const {
|
||||
if (se.hasReductionParam()) {
|
||||
return ReductionParamsHash()(se.reductionParams());
|
||||
} else {
|
||||
return PointwiseParamsHash()(se.pointwiseParams());
|
||||
}
|
||||
return se.params()->hash();
|
||||
}
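`SchedulerEntryHash` above and `SchedulerEntry::sameAs` now route through the shared `params_` pointer instead of branching on `has_reduction_param_`. A hedged sketch of how an entry's parameters are reached after this change, using only accessors introduced in this diff; `entry` is assumed to be a scheduler entry whose heuristics were already computed:
```cpp
// Sketch only: generic vs. typed access to a SchedulerEntry's heuristics.
size_t h = entry->params()->hash();  // works for any heuristic kind
if (entry->heuristic() == ScheduleHeuristic::Reduction) {
  // Typed accessor; asserts if the stored params are not ReductionParams.
  const ReductionParams& rp = entry->reductionParams();
  std::cerr << rp.toString() << std::endl;
}
```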
|
||||
|
||||
std::string toString(ScheduleHeuristic sh) {
|
||||
|
|
@@ -1444,6 +1415,9 @@ HeuristicSummary::HeuristicSummary(
|
|||
void HeuristicSummary::validate() const {
|
||||
switch (heuristic_) {
|
||||
case ScheduleHeuristic::PointWise: {
|
||||
TORCH_INTERNAL_ASSERT(entry_type_map_.count(EntryType::DOMAIN_MAP));
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
entry_type_map_.count(EntryType::REFERENCE_TENSORS));
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
entry_type_map_.count(EntryType::VECTORIZABLE_INPUTS_AND_OUTPUTS));
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
|
|
@@ -1512,6 +1486,8 @@ HeuristicSummaryEntry<EntryClass>::HeuristicSummaryEntry(
|
|||
}
|
||||
|
||||
// Template instantiation for pre-defined cache entries
|
||||
template class HeuristicSummaryEntry<HeuristicCompileTime::DomainMap>;
|
||||
template class HeuristicSummaryEntry<HeuristicCompileTime::ReferenceTensors>;
|
||||
template class HeuristicSummaryEntry<
|
||||
HeuristicCompileTime::VectorizableInputsAndOutputs>;
|
||||
template class HeuristicSummaryEntry<
|
||||
|
|
|
|||
|
|
@@ -2,6 +2,9 @@
|
|||
#include <torch/csrc/jit/codegen/cuda/fusion.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/scheduler/compile_time_info.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/scheduler/heuristic.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/scheduler/pointwise_heuristic.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
|
||||
#include <torch/csrc/jit/codegen/cuda/utils.h>
|
||||
|
||||
|
|
@@ -158,11 +161,7 @@ class TORCH_CUDA_CU_API SchedulerEntry {
|
|||
//! Heuristic comparison
|
||||
bool sameAs(const SchedulerEntry* other);
|
||||
|
||||
bool hasReductionParam() const {
|
||||
return has_reduction_param_;
|
||||
}
|
||||
|
||||
ScheduleHeuristic heuristc() const {
|
||||
ScheduleHeuristic heuristic() const {
|
||||
return heuristc_;
|
||||
}
|
||||
|
||||
|
|
@@ -170,51 +169,38 @@ class TORCH_CUDA_CU_API SchedulerEntry {
|
|||
return index_mode_;
|
||||
}
|
||||
|
||||
const std::shared_ptr<HeuristicParams>& params() const {
|
||||
return params_;
|
||||
}
|
||||
|
||||
const ReductionParams& reductionParams() const {
|
||||
auto rparams = std::dynamic_pointer_cast<ReductionParams>(params_);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
has_reduction_param_, "This schedule heuristic is not reduction.");
|
||||
return rparams_;
|
||||
rparams != nullptr, "Heuristic parameter is not a reduction parameter");
|
||||
return *rparams;
|
||||
}
|
||||
|
||||
const PointwiseParams& pointwiseParams() const {
|
||||
auto pparams = std::dynamic_pointer_cast<PointwiseParams>(params_);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
!has_reduction_param_, "This schedule heuristic is not pointwise.");
|
||||
return pparams_;
|
||||
pparams != nullptr, "Heuristic parameter is not a pointwise parameter");
|
||||
return *pparams;
|
||||
}
|
||||
|
||||
void updateLaunchConstraint(const LaunchParams& launch_params) {
|
||||
if (hasReductionParam()) {
|
||||
rparams_.lparams = launch_params;
|
||||
} else {
|
||||
pparams_.lparams = launch_params;
|
||||
}
|
||||
params_->lparams = launch_params;
|
||||
}
|
||||
|
||||
protected:
|
||||
explicit SchedulerEntry(ScheduleHeuristic heuristic, bool has_reduction_param)
|
||||
: heuristc_(heuristic), has_reduction_param_(has_reduction_param) {}
|
||||
explicit SchedulerEntry(ScheduleHeuristic heuristic) : heuristc_(heuristic) {}
|
||||
|
||||
ReductionParams& rparams() {
|
||||
return rparams_;
|
||||
}
|
||||
|
||||
PointwiseParams& pparams() {
|
||||
return pparams_;
|
||||
}
|
||||
//! Heuristic parameters if applicable
|
||||
std::shared_ptr<HeuristicParams> params_ = nullptr;
|
||||
|
||||
private:
|
||||
//! What kind of heuristics does this entry have?
|
||||
const ScheduleHeuristic heuristc_;
|
||||
|
||||
//! Has reduction params if true, else has pointwise params
|
||||
const bool has_reduction_param_;
|
||||
|
||||
//! Reduction parameters if applicable
|
||||
ReductionParams rparams_;
|
||||
|
||||
//! Pointwise parameters if applicable
|
||||
PointwiseParams pparams_;
|
||||
|
||||
//! Kernel Index Mode
|
||||
KernelIndexMode index_mode_ = KernelIndexMode::INT64;
|
||||
};
|
||||
|
|
@@ -226,10 +212,12 @@ class TORCH_CUDA_CU_API SchedulerEntryHash {
|
|||
};
|
||||
|
||||
//! Debug print function for heuristics
|
||||
std::string toString(ScheduleHeuristic sh);
|
||||
TORCH_CUDA_CU_API std::string toString(ScheduleHeuristic sh);
|
||||
|
||||
//! Debug print function for heuristics
|
||||
std::ostream& operator<<(std::ostream& os, ScheduleHeuristic sh);
|
||||
TORCH_CUDA_CU_API std::ostream& operator<<(
|
||||
std::ostream& os,
|
||||
ScheduleHeuristic sh);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace fuser
|
||||
|
|
|
|||
|
|
@@ -188,30 +188,53 @@ size_t mergeNonReduction(
|
|||
|
||||
void parallelizeAllLike(
|
||||
TensorView* reference_tv,
|
||||
const std::vector<TensorView*>& all_tvs) {
|
||||
int64_t pos,
|
||||
std::vector<TensorView*> selected_tvs,
|
||||
const std::unordered_set<ParallelType>& selected_parallel_types,
|
||||
bool propagate_padding) {
|
||||
FusionGuard fg(reference_tv->fusion());
|
||||
|
||||
if (pos < 0) {
|
||||
pos += reference_tv->nDims() + 1;
|
||||
}
|
||||
TORCH_CHECK(
|
||||
pos >= 0 && pos <= reference_tv->nDims(),
|
||||
"parallelizeAllLike called on an position outside valid range.");
|
||||
|
||||
std::unordered_map<IterDomain*, IterDomain*> concrete_to_reference_map;
|
||||
|
||||
auto ca_map = ComputeAtMap(FusionGuard::getCurFusion());
|
||||
|
||||
for (auto id : reference_tv->domain()->domain()) {
|
||||
ca_map.getConcreteMappedID(id, IdMappingMode::PERMISSIVE)
|
||||
->parallelize(id->getParallelType());
|
||||
if (id->hasPaddingToMultipleOfWarp()) {
|
||||
ca_map.getConcreteMappedID(id, IdMappingMode::PERMISSIVE)
|
||||
->padToMultipleOfWarp(id->getMaybeSizeAfterPadding());
|
||||
}
|
||||
const auto& reference_dom = reference_tv->domain()->domain();
|
||||
for (auto it = reference_dom.begin(); it != reference_dom.begin() + pos;
|
||||
it++) {
|
||||
auto ca_id = ca_map.getConcreteMappedID(*it, IdMappingMode::PERMISSIVE);
|
||||
concrete_to_reference_map[ca_id] = *it;
|
||||
}
|
||||
|
||||
for (auto tv : all_tvs) {
|
||||
if (selected_tvs.empty()) {
|
||||
selected_tvs = ir_utils::allTvs(reference_tv->fusion());
|
||||
}
|
||||
for (auto tv : selected_tvs) {
|
||||
if (tv->isFusionInput()) {
|
||||
continue;
|
||||
}
|
||||
for (const auto i : c10::irange(tv->domain()->domain().size())) {
|
||||
auto ca_id =
|
||||
ca_map.getConcreteMappedID(tv->axis(i), IdMappingMode::PERMISSIVE);
|
||||
tv->axis(i)->parallelize(ca_id->getParallelType());
|
||||
if (ca_id->hasPaddingToMultipleOfWarp()) {
|
||||
tv->axis(i)->padToMultipleOfWarp(ca_id->getMaybeSizeAfterPadding());
|
||||
if (concrete_to_reference_map.count(ca_id) > 0) {
|
||||
auto reference_id = concrete_to_reference_map.at(ca_id);
|
||||
auto reference_parallel_type = reference_id->getParallelType();
|
||||
if (selected_parallel_types.empty() ||
|
||||
selected_parallel_types.count(reference_parallel_type)) {
|
||||
tv->axis(i)->parallelize(reference_parallel_type);
|
||||
}
|
||||
if (propagate_padding) {
|
||||
if (reference_id->hasPaddingToMultipleOfWarp()) {
|
||||
tv->axis(i)->padToMultipleOfWarp(
|
||||
reference_id->getMaybeSizeAfterPadding());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
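The rewritten `parallelizeAllLike` above takes an inclusive position bound, an explicit set of tensors to touch, and a whitelist of parallel types, instead of unconditionally propagating to every tensor. The two call patterns below are taken from the reduction inliner elsewhere in this diff and illustrate the intended split between "everything except unroll/vectorize" and "only unroll/vectorize, only where wanted":
```cpp
// 1) Propagate all parallel types except unrolling/vectorization to every
//    tensor in the fusion (an empty selection means "all tvs").
scheduler_utils::parallelizeAllLike(
    reference_tv,
    {},
    allParallelTypesExcept(
        {ParallelType::Unroll,
         ParallelType::Vectorize,
         ParallelType::MisalignedVectorize}));

// 2) Propagate only unrolling/vectorization, and only to the tensors that
//    were selected for it, considering all of reference_tv's axes (pos = -1).
scheduler_utils::parallelizeAllLike(
    reference_tv,
    -1,
    {are_unrolled.begin(), are_unrolled.end()},
    {ParallelType::Unroll,
     ParallelType::Vectorize,
     ParallelType::MisalignedVectorize});
```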
|
||||
|
|
@@ -1607,7 +1630,8 @@ void scheduleWarpTileWithNoReduction(TensorView* tv, MatMulTileOptions tile) {
|
|||
void scheduleContiguousVectorLoad(
|
||||
TensorView* tv,
|
||||
MatMulTileOptions tile,
|
||||
int vector_word) {
|
||||
int vector_word,
|
||||
bool vectorize) {
|
||||
auto warp_dims = tile.cta_tile / tile.warp_tile;
|
||||
int num_of_thread = warp_dims.m * warp_dims.n * warp_dims.k * 32;
|
||||
|
||||
|
|
@@ -1630,14 +1654,423 @@ void scheduleContiguousVectorLoad(
|
|||
tv->split(-3, warp_dims.k);
|
||||
}
|
||||
|
||||
tv->axis(-1)->parallelize(ParallelType::Vectorize);
|
||||
if (vectorize) {
|
||||
tv->axis(-1)->parallelize(ParallelType::Vectorize);
|
||||
}
|
||||
|
||||
tv->axis(-2)->parallelize(ParallelType::TIDx);
|
||||
tv->axis(-3)->parallelize(ParallelType::TIDy);
|
||||
tv->axis(-4)->parallelize(ParallelType::TIDz);
|
||||
}
|
||||
|
||||
void makeTile(TensorView* tv, std::vector<int> tile_sizes) {
|
||||
TORCH_CHECK(
|
||||
tv->domain()->domain().size() >= tile_sizes.size(),
|
||||
"Tensor dimension less than tile dimension!");
|
||||
|
||||
// Number of inner dimensions we are tiling.
|
||||
const auto tile_dimension_size = tile_sizes.size();
|
||||
|
||||
// Split the inner dimensions:
|
||||
for (auto idx : c10::irange(tile_dimension_size)) {
|
||||
// Using negative indexing to accommodate potential batching
|
||||
// dimensions on the further left. Eg.:
|
||||
// 0, 1, 2 -> -3,-2,-1
|
||||
// [M, N, K] -> [B0, B1, M, N, K]
|
||||
tv->split(idx - tile_dimension_size, tile_sizes.at(idx));
|
||||
}
|
||||
|
||||
// The transformation happened should look like:
|
||||
// Before After
|
||||
// [..., M, N, K] -> [..., Mo, Mi, No, Ni, Ko, Ki]
|
||||
|
||||
// Re-order the tiles so that all the outer tiles are
|
||||
// on the left of all the inner tiles
|
||||
std::unordered_map<int, int> reorder_map_old_to_new;
|
||||
|
||||
// Number of tiled inner dimensions after we split.
|
||||
const auto split_tile_dimension_size = 2 * tile_dimension_size;
|
||||
for (auto idx : c10::irange(split_tile_dimension_size)) {
|
||||
// We want to reorder as follows:
|
||||
// Before
|
||||
//
|
||||
// [..., Mo, Mi, No, Ni, Ko, Ki] ->
|
||||
// After
|
||||
// vvv group0 vvv vvv group1 vvv
|
||||
// [..., Mo, No, Ko, Mi, Ni, Ki]
|
||||
|
||||
// The index offset within group of current
|
||||
// iterdomain, with grouping specified above.
|
||||
auto index_within_group = idx / 2;
|
||||
|
||||
// The index of the group the current id belongs
|
||||
// to, as specified above.
|
||||
auto group_index = idx % 2;
|
||||
|
||||
// Calculate the actual index after reordering
|
||||
auto index_after_reorder =
|
||||
group_index * tile_dimension_size + index_within_group;
|
||||
|
||||
// Add pair {idx_before, idx_after} to re-order map.
|
||||
reorder_map_old_to_new.insert(std::make_pair(
|
||||
idx - split_tile_dimension_size,
|
||||
index_after_reorder - split_tile_dimension_size));
|
||||
}
|
||||
|
||||
// Apply the re-order map to tensor
|
||||
tv->reorder(reorder_map_old_to_new);
|
||||
}
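A small illustration of what `makeTile` does to the leaf domain, following the comments in the function; the tile sizes and the batch dimension are hypothetical:
```cpp
// Hypothetical operand: root domain [B, M, N, K].
// Tile the three innermost dimensions; batch dims further left are untouched.
makeTile(tv, {128, 128, 32});
// Leaf domain after the splits:  [B, Mo, Mi(128), No, Ni(128), Ko, Ki(32)]
// Leaf domain after the reorder: [B, Mo, No, Ko, Mi(128), Ni(128), Ki(32)]
// i.e. all outer tiles grouped on the left, all inner tiles on the right.
```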
|
||||
|
||||
namespace {
|
||||
|
||||
c10::optional<IterDomain*> getMaybeRootIfInnermostTiled(
|
||||
IterDomain* id,
|
||||
const std::unordered_set<IterDomain*>& maybe_rfactor_id_set) {
|
||||
// Root id defaults to an "innermost id".
|
||||
while (id->definition() && !maybe_rfactor_id_set.count(id)) {
|
||||
if (auto split = dynamic_cast<Split*>(id->definition())) {
|
||||
if (id == split->inner()) {
|
||||
id = split->in();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
// Didn't pass the inner most check, return empty.
|
||||
return c10::nullopt;
|
||||
}
|
||||
|
||||
return id;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TORCH_CUDA_CU_API void orderTiledConcreteIdAsRoot(TensorView* tv) {
|
||||
auto ndims = tv->nDims();
|
||||
|
||||
// Keep track of the left most position where we will
|
||||
// be reordering the axes.
|
||||
auto leftmost_pos = ndims;
|
||||
|
||||
// Pull the root id's of the given tv.
|
||||
std::unordered_set<IterDomain*> maybe_rfactor_id_set{
|
||||
tv->getMaybeRFactorDomain().begin(), tv->getMaybeRFactorDomain().end()};
|
||||
|
||||
// Keep track of leaf positions that is either a reduction
|
||||
// or a broadcast.
|
||||
// Note: Currently don't really see a case where this function
|
||||
// should be called on a reduction output tv, but adding them
|
||||
// here for completeness.
|
||||
std::deque<int> broadcast_or_reduction_pos;
|
||||
|
||||
// Map the root id's to their innermost concrete id's
|
||||
// on the leaf.
|
||||
std::unordered_map<IterDomain*, int> root_id_to_inner_leaf_pos;
|
||||
|
||||
// Try to re-order inner iterdomains from the innermost
|
||||
// position backward. This utility only tries to re-order
|
||||
// inner tiles on the innermost positions, like the resulting
|
||||
// tensor from makeTile utility.
|
||||
// The re-ordering would first try to decide the inner iterdomains
|
||||
// we want to re-order. For this we start from the innermost position
|
||||
// and move back and collect all the iterdomains that we know
|
||||
// are inner tiles of some root domain or broadcast/reduction domains
|
||||
// that won't affect the concrete id layout.
|
||||
// The collection process would stop whenever an iterdomain that is
|
||||
// neither an inner tile nor reduction/broadcast is found, and would
|
||||
// not re-order any iterdomain beyond that point to keep the
|
||||
// outer loop structure unchanged.
|
||||
for (int64_t i = static_cast<int64_t>(ndims) - 1; i >= 0; i--) {
|
||||
auto leaf_id = tv->axis(i);
|
||||
if (leaf_id->isBroadcast() || leaf_id->isReduction()) {
|
||||
// Register this reduction or broadcast axis
|
||||
// to reorder.
|
||||
broadcast_or_reduction_pos.push_front(i);
|
||||
leftmost_pos = i;
|
||||
continue;
|
||||
}
|
||||
auto maybe_root =
|
||||
getMaybeRootIfInnermostTiled(leaf_id, maybe_rfactor_id_set);
|
||||
|
||||
if (maybe_root.has_value()) {
|
||||
// Found an innermost id, add them to the
|
||||
// axes to reorder.
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
root_id_to_inner_leaf_pos
|
||||
.insert(std::make_pair(maybe_root.value(), i))
|
||||
.second,
|
||||
"Multiple \"innermost\" id seen for root id :",
|
||||
maybe_root.value()->toString(),
|
||||
" on ",
|
||||
tv->toString(),
|
||||
" very likely an invariant is broken.");
|
||||
leftmost_pos = i;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate the ordering:
|
||||
|
||||
// pointer to the current target position after
|
||||
// reordering
|
||||
int current_pos = leftmost_pos;
|
||||
std::unordered_map<int, int> reorder_map_old_to_new;
|
||||
|
||||
// first place all the broadcast and reduction on the left:
|
||||
for (auto original_broadcast_or_reduction_pos : broadcast_or_reduction_pos) {
|
||||
reorder_map_old_to_new[original_broadcast_or_reduction_pos] = current_pos++;
|
||||
}
|
||||
|
||||
// Next put all the innermost leaf id's, we make sure that
|
||||
// the inner tile ordering follows the corresponding root
|
||||
// domain ordering by iterating on the root domain and
|
||||
// find their corresponding inner tile iterdomains from
|
||||
// the populated root_id_to_inner_leaf_pos.
|
||||
for (auto root_id : tv->getMaybeRFactorDomain()) {
|
||||
auto leaf_id_pos_it = root_id_to_inner_leaf_pos.find(root_id);
|
||||
if (leaf_id_pos_it != root_id_to_inner_leaf_pos.end()) {
|
||||
reorder_map_old_to_new[leaf_id_pos_it->second] = current_pos++;
|
||||
}
|
||||
}
|
||||
|
||||
// Validate that we have processed all inner ids or broadcast/reduction
|
||||
// ids we have registered.
|
||||
TORCH_INTERNAL_ASSERT(current_pos == ndims, "Inconsistent ordering logic");
|
||||
|
||||
// Apply the new order:
|
||||
tv->reorder(reorder_map_old_to_new);
|
||||
}
|
||||
|
||||
} // namespace matmul_utils
|
||||
|
||||
//! Propagate current transformations on from_tv to all graphs
|
||||
TORCH_CUDA_CU_API void transformPropagateToAllFrom(
|
||||
TensorView* from_tv,
|
||||
int pos) {
|
||||
TransformPropagator propagator(from_tv, pos);
|
||||
MaxRootDomainInfoSpanningTree(from_tv, nullptr).traverse(&propagator);
|
||||
}

namespace {

//! Utility enum to signify which direction
//! BoundedDirectionalTransformPropagator
//! passes will propagate the transforms.
enum class PropagateDirection { Backward = 0, Forward };

//! Returns true if the given tensorview is a fake boundary
//! TensorView, see Note [Fake Boundary Tensorview].
//! This function assumes, and does not check, that tv is a boundary
//! of the selected_tv_set.
bool isFakeBoundaryTensorview(
    TensorView* tv,
    const std::unordered_set<TensorView*>& selected_tv_set,
    PropagateDirection direction) {
  if (direction == PropagateDirection::Forward) {
    // In the case of forward propagation,
    // a boundary tv is a fake boundary if
    // it has any consumer tv that's in the selected
    // set.
    for (auto consumer_tv : ir_utils::consumerTvsOf(tv)) {
      if (selected_tv_set.count(consumer_tv)) {
        // Found a consumer that's in the selected tv set.
        return true;
      }
    }

  } else {
    // In the case of backward propagation,
    // a boundary tv is a fake boundary if it has any producer
    // that is within the selected set.
    for (auto producer_tv : ir_utils::producerTvsOf(tv)) {
      if (selected_tv_set.count(producer_tv)) {
        // Found a producer that's in the selected tv set.
        return true;
      }
    }
  }

  // Didn't find any producer/consumer in the selected tv set.
  // The given tv is not a fake boundary tv.
  return false;
}

//! Utility function to generate the set of tensorviews to propagate
//! transforms to by BoundedDirectionalTransformPropagator.
std::unordered_set<TensorView*> getDirectionalPropagatePathSet(
    TensorView* from_tv,
    std::vector<TensorView*> boundary_tvs,
    BoundedDirectionalTransformPropagator::Options options,
    PropagateDirection direction) {
  // Prepare to collect all candidate tensorviews
  // within the specified boundary.
  std::vector<Val*> propagate_candidate;

  // Collect boundary tvs in a set.
  std::unordered_set<TensorView*> boundary_tv_set(
      boundary_tvs.begin(), boundary_tvs.end());

  if (direction == PropagateDirection::Forward) {
    // In the case of forward propagation, collect all tvs
    // that are consumers of `from_tv` and producers of
    // boundary tvs.
    propagate_candidate = DependencyCheck::getAllValsBetween(
        {from_tv}, {boundary_tvs.begin(), boundary_tvs.end()});
  } else {
    // In the case of backward propagation, collect all tvs
    // that are producers of `from_tv` and consumers of
    // boundary tvs.
    propagate_candidate = DependencyCheck::getAllValsBetween(
        {boundary_tvs.begin(), boundary_tvs.end()}, {from_tv});
  }

  // Populate initial selected tensorviews in a set.
  auto propagate_candidate_tv_view =
      ir_utils::filterByType<TensorView>(propagate_candidate);
  // Prepare to filter out unwanted tensorviews according
  // to the option parameters.
  std::unordered_set<TensorView*> propagate_path_set{
      propagate_candidate_tv_view.begin(), propagate_candidate_tv_view.end()};

  // Remove boundary tensorviews if we don't want to transform
  // tensorviews on the boundary.
  if (!options.transform_boundary) {
    // Additional refining step to identify "fake boundary" tensorviews.
    // We don't want to erase fake boundary tensorviews from the selected
    // set when we are erasing boundary tvs.
    //
    // Note [Fake Boundary Tensorview]
    // A tensorview, Tv0, is defined as a fake boundary tv if
    //  1. Tv0 is in the given boundary set.
    //  2. There is a path from another boundary tv, Tv1, to from_tv that
    //     goes through Tv0.
    //
    // In this case the propagation behavior is not precisely defined.
    // Our current decision is to treat such a tensorview as a non-boundary
    // tv to make sure the propagation paths are not blocked. E.g.:
    //
    //  T1 = T0
    //  T2 = T1
    //  T3 = T2 + T1
    // if we propagate with from_tv = {T3}, boundary_tv = {T0, T2},
    // transform_boundary = false
    //
    // Here T2 is a fake boundary and we will still transform T2 as it is
    // on the path between T3 and T0.

    // Initialize set of fake boundary tvs.
    std::unordered_set<TensorView*> fake_boundary_set;

    // Populate the set of fake boundary tvs.
    std::copy_if(
        boundary_tvs.begin(),
        boundary_tvs.end(),
        std::inserter(fake_boundary_set, fake_boundary_set.end()),
        [&propagate_path_set, direction](TensorView* tv) {
          return isFakeBoundaryTensorview(tv, propagate_path_set, direction);
        });

    // Remove boundary tvs from the selected set, keeping fake boundary tvs.
    for (auto boundary_tv : boundary_tvs) {
      if (!fake_boundary_set.count(boundary_tv)) {
        propagate_path_set.erase(boundary_tv);
      }
    }
  }

  return propagate_path_set;
}
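To make the Note above concrete, here is a hedged, self-contained sketch of the quoted T0/T1/T2/T3 example written as fusion code. The scheduling and the call at the end are illustrative; they assume the public `backward` interface declared later in this PR:

```cpp
#include <torch/csrc/jit/codegen/cuda/arith.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>

using namespace torch::jit::fuser::cuda;

void fakeBoundarySketch() {
  Fusion fusion;
  FusionGuard fg(&fusion);
  TensorView* t0 = TensorViewBuilder().ndims(2).build();
  fusion.addInput(t0);
  TensorView* t1 = add(t0, t0); // "T1 = T0" (any elementwise op works here)
  TensorView* t2 = add(t1, t1); // "T2 = T1"
  TensorView* t3 = add(t2, t1); // "T3 = T2 + T1"
  fusion.addOutput(t3);

  // Give T3 something to propagate.
  t3->merge(0);
  t3->split(0, 32);

  // Backward propagation from T3 bounded by {T0, T2} with the default
  // transform_boundary = false: T2 is a "fake boundary" and is still
  // transformed, because it lies on the path from T3 to the other
  // boundary tv T0.
  scheduler_utils::BoundedDirectionalTransformPropagator::backward(
      t3, -1, {t0, t2});
}
```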

} // namespace

void BoundedDirectionalTransformPropagator::propagate(
    TensorView* from_tv,
    int pos,
    std::unordered_set<TensorView*> included_tvs,
    Options options) {
  // Run transform propagation using the custom selector.
  SetSelector selector(included_tvs);
  TransformPropagator propagator(from_tv, pos);
  MaxRootDomainInfoSpanningTree(from_tv, &selector).traverse(&propagator);

  // Propagate parallel type if requested by option parameters.
  if (options.propagate_parallel_type) {
    scheduler_utils::parallelizeAllLike(
        from_tv,
        options.parallel_propagation_pos,
        {included_tvs.begin(), included_tvs.end()},
        allParallelTypesExcept({ParallelType::Vectorize, ParallelType::Mma}));
  }
}

void BoundedDirectionalTransformPropagator::backward(
    TensorView* from,
    int pos,
    std::vector<TensorView*> to,
    c10::optional<Options> options) {
  if (!options.has_value()) {
    options = Options();
  }
  TORCH_INTERNAL_ASSERT(
      !to.empty(),
      "Propagation needs to be bounded, so no support for empty boundary.");

  // Collect all tvs to include on the backward path as specified
  // by the boundary and options.
  auto included_tvs = getDirectionalPropagatePathSet(
      from, to, *options, PropagateDirection::Backward);
  // Actually run the propagation.
  propagate(from, pos, included_tvs, *options);
}

void BoundedDirectionalTransformPropagator::forward(
    TensorView* from,
    int pos,
    std::vector<TensorView*> to,
    c10::optional<Options> options) {
  if (!options.has_value()) {
    options = Options();
  }
  TORCH_INTERNAL_ASSERT(
      !to.empty(),
      "Propagation needs to be bounded, so no support for empty boundary.");

  // Collect all tvs to include on the forward path as specified
  // by the boundary and options.
  auto included_tvs = getDirectionalPropagatePathSet(
      from, to, *options, PropagateDirection::Forward);

  // Actually run the propagation.
  propagate(from, pos, included_tvs, *options);
}

void BoundedDirectionalTransformPropagator::bothWays(
    TensorView* from,
    int pos,
    std::vector<TensorView*> backward_to,
    std::vector<TensorView*> forward_to,
    c10::optional<Options> options) {
  if (!options.has_value()) {
    options = Options();
  }
  TORCH_INTERNAL_ASSERT(
      !backward_to.empty() && !forward_to.empty(),
      "Propagation needs to be bounded, so no support for empty boundary.");

  // Collect all tvs to include on the backward and forward paths as specified
  // by the boundaries and options.
  auto backward_included_tvs = getDirectionalPropagatePathSet(
      from, backward_to, *options, PropagateDirection::Backward);
  auto forward_included_tvs = getDirectionalPropagatePathSet(
      from, forward_to, *options, PropagateDirection::Forward);

  // Combine the included tvs on both paths.
  auto included_tvs = backward_included_tvs;
  included_tvs.insert(forward_included_tvs.begin(), forward_included_tvs.end());

  // Run the propagation on the combined set of tvs.
  propagate(from, pos, included_tvs, *options);
}

// Grab all values and expressions used to make the merged_domain and remove
// them from the fusion
void cleanUpInnermostMergedDomains(
@@ -2,6 +2,7 @@
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/maxinfo_propagator.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h>

namespace torch {

@@ -11,6 +12,7 @@ namespace cuda {
class SchedulerRuntimeInfo;
class ExpressionEvaluator;
class HeuristicSummary;

namespace scheduler_utils {

@@ -37,6 +39,11 @@ constexpr int64_t lastPow2(int64_t n) {
  return std::max((int64_t)1, n - (n >> 1));
}

// Div x by y, but min at 1
inline int64_t safeDiv(const int64_t x, const int64_t y) {
  return std::max(x / y, (int64_t)1);
}
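A quick illustration of the clamping behavior (the values are arbitrary):

```cpp
// safeDiv never returns 0, so downstream split/unroll factors stay valid
// even when the divisor is larger than the dividend.
int64_t a = scheduler_utils::safeDiv(256, 8); // 32
int64_t b = scheduler_utils::safeDiv(3, 8);   // 3 / 8 == 0, clamped to 1
```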

// Merge all reductions to the right side and return the total number of
// reduction axes. `dont_merge` is typically used for trivial reductions.
size_t mergeReduction(

@@ -49,9 +56,32 @@ size_t mergeNonReduction(
    TensorView* tv,
    const std::unordered_set<IterDomain*>& dont_merge = {});

// Propagate the parallelization from the selected dimensions of the reference
// tensor to their corresponding dimensions in all selected tensors in the DAG.
// Position `pos` means selecting all the dimensions [0, 1, ..., pos - 1]. pos =
// -1 means selecting all dimensions. `selected_tvs` are selected tensors in the
// DAG. Empty `selected_tvs` means selecting all tensors in the fusion of
// `reference_tv`. `selected_parallel_types` are the selected parallel types.
// Empty `selected_parallel_types` means selecting all parallel types.
TORCH_CUDA_CU_API void parallelizeAllLike(
    TensorView* reference_tv,
-    const std::vector<TensorView*>& all_tvs);
+    int64_t pos = -1,
+    std::vector<TensorView*> selected_tvs = {},
+    const std::unordered_set<ParallelType>& selected_parallel_types = {},
+    bool propagate_padding = true);

TORCH_CUDA_CU_API inline void parallelizeAllLike(
    TensorView* reference_tv,
    std::vector<TensorView*> selected_tvs,
    const std::unordered_set<ParallelType>& selected_parallel_types = {},
    bool propagate_padding = true) {
  parallelizeAllLike(
      reference_tv,
      -1,
      std::move(selected_tvs),
      selected_parallel_types,
      propagate_padding);
}
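A hedged usage sketch of the extended interface (`tv1`, `tv2`, `tv3` are placeholders, not names from this diff):

```cpp
// Propagate every parallel type from every dimension of tv3 to all
// tensorviews in its fusion. This is the common case, and what the updated
// call sites in the tests below rely on:
scheduler_utils::parallelizeAllLike(tv3);

// Restricted form: only the first two dimensions of the reference, only to
// tv1 and tv2, and only block/thread bindings (leaving vectorization
// decisions per-tensor):
scheduler_utils::parallelizeAllLike(
    tv3, 2, {tv1, tv2}, {ParallelType::BIDx, ParallelType::TIDx});
```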

TORCH_CUDA_CU_API void computeAtInputs(
    TensorView* consumer,

@@ -166,7 +196,6 @@ std::pair<bool, bool> canonicalDimReduction(

// Return a list of tensor views that are outputs of reduction operations. If
// multiple outputs of an expression are found, only include one in the list
// (WelfordOp)
TORCH_CUDA_CU_API std::vector<TensorView*> getReductionTvs(
    Fusion* fusion,
    bool ignore_trivial = true);

@@ -179,13 +208,14 @@ void clearMemorySpace(Fusion* fusion);

// Returns cached-after tensors of the fusion inputs if unrolled. Otherwise
// return an empty vector.
-std::vector<TensorView*> cacheInputs(Fusion* fusion, bool unroll);
+TORCH_CUDA_CU_API std::vector<TensorView*> cacheInputs(
+    Fusion* fusion,
+    bool unroll);

// Returns the pairs of <cache of each fusion output, corresponding output> for
// all outputs.
-std::vector<std::pair<TensorView*, TensorView*>> cacheAndForkOutputs(
-    Fusion* fusion,
-    bool unroll);
+TORCH_CUDA_CU_API std::vector<std::pair<TensorView*, TensorView*>>
+cacheAndForkOutputs(Fusion* fusion, bool unroll);

// Ignores broadcast and reduction, returns iter domain in root domain that's
// "inner most". If this is an rfactored reduction domain, actually check the

@@ -289,10 +319,12 @@ namespace matmul_utils {
TORCH_CUDA_CU_API void scheduleContiguousVectorLoad(
    TensorView* tv,
    MatMulTileOptions tile,
-    int vector_word);
+    int vector_word,
+    bool vectorize = true);

//! Schedule utility for mma output in matmul main loop:
//!  Realize the hierarchical tiling based on the given tiling options.
//! TODO: rewrite this one with makeTile
TORCH_CUDA_CU_API void scheduleWarpTileWithReduction(
    TensorView* tv,
    MatMulTileOptions tile);

@@ -300,12 +332,142 @@ TORCH_CUDA_CU_API void scheduleWarpTileWithReduction(
//! Schedule utility for mma output in matmul main loop:
//!  Realize the hierarchical tiling based on the given tiling options
//!  on consumers of mma ops in the epilog.
//! TODO: remove this one eventually.
TORCH_CUDA_CU_API void scheduleWarpTileWithNoReduction(
    TensorView* tv,
    MatMulTileOptions tile);

//! Lower level primitive splitting inner iterdomains into tiles:
//! Eg.
//!  A[B,I0,I1,I2] -> makeTile({1,2,3})
//! Gives A[B, I0o, I1o, I2o, I0i(1), I1i(2), I2i(3)]
TORCH_CUDA_CU_API void makeTile(TensorView* tv, std::vector<int> tile_sizes);

//! Order the inner tile dimensions as the original order in
//!  the root domain. Also puts broadcast domains on the left.
//! Eg. A[I0o,I1o,B2o,I0i,I1i,B2i] (root domain: I1,B,I0)
//! -> A[I0o, I1o, B2o, B2i, I1i, I0i]
//! This is used to facilitate data layout swizzling and
//!  defining vectorized loads.
TORCH_CUDA_CU_API void orderTiledConcreteIdAsRoot(TensorView* tv);
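A hedged sketch of how these two utilities compose, following the Eg. comments above (the tensorview and the tile sizes are made up):

```cpp
// Assume tv has root domain [B, I0, I1, I2] inside a matmul schedule.
scheduler_utils::matmul_utils::makeTile(tv, {16, 16, 8});
// Leaf domains: [B, I0o, I1o, I2o, I0i(16), I1i(16), I2i(8)]
scheduler_utils::matmul_utils::orderTiledConcreteIdAsRoot(tv);
// The inner tiles now follow the root-domain order (broadcast tiles, if any,
// are placed on the left of the inner-tile block), keeping the innermost
// dimensions in a layout suitable for swizzling and vectorized loads.
```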

//! Orders the root id ordering of the given tv as
//!  [Batch, Previous Reduction, M, N, K]
//! for easier processing of later scheduling steps.
//!
//! This matching works on the root domain only, and
//!  will throw if the tv has a leaf iterdomain that is
//!  not a root id.
TORCH_CUDA_CU_API void canonicalizeMmaTvOrdering(TensorView* tv);

} // namespace matmul_utils

//! Propagate current transformations on from_tv, up to the given
//!  position, to all tensorviews on the owning fusion that have
//!  a connection with `from_tv` on the fusion graph.
TORCH_CUDA_CU_API void transformPropagateToAllFrom(
    TensorView* from_tv,
    int pos);

//! A type of custom transform propagator that propagates iterdomain
//!  transforms from a source tv to all tvs that are selected
//!  using a "direction" and a "boundary".
//!
//! The propagation model always assumes a `from_tv`, a `direction` and a
//! `boundary`.
//!
//! This propagator will only transform producers and consumers
//! of `from_tv`, and all propagation modes **require** a boundary to be
//! specified to signify where the propagation should stop.
//!
//! There are currently three modes of propagation: forward, backward and
//! both-way, see the comments on the interface functions for details.
struct TORCH_CUDA_CU_API BoundedDirectionalTransformPropagator {
  //! Custom option container for configuring
  //!  the transform propagation actions.
  //! All option values default to false unless
  //!  the corresponding setter is called.
  struct Options {
    //! If true, the transform propagator will
    //!  also propagate parallel types from
    //!  `from_tv` to all selected tvs.
    bool propagate_parallel_type = false;

    //! If true, the specified boundary tvs
    //!  will also be replayed as `from_tv`.
    //!  If false, they will not be affected
    //!  by the propagation pass.
    bool transform_boundary = false;

    //! Sets the position boundary in parallel
    //!  type propagation, see comment on
    //!  scheduler_utils::parallelizeAllLike.
    //! Only used if propagate_parallel_type == true.
    int parallel_propagation_pos = -1;

    //! Setter for enabling parallel type
    //!  propagation. See comment on the variable.
    //!
    //! \param up_to_pos sets the parallel type
    //!  propagation boundary. See comment on
    //!  scheduler_utils::parallelizeAllLike.
    Options propagateParallelType(int up_to_pos = -1) {
      propagate_parallel_type = true;
      parallel_propagation_pos = up_to_pos;
      return *this;
    }

    //! Setter for enabling propagation to
    //!  boundary tvs. See comment on the variable.
    Options propagateToBoundary() {
      transform_boundary = true;
      return *this;
    }
  };

  //! Replay transforms from tensorview `from`
  //!  to the tensorviews that are consumers
  //!  of boundary tensorviews in `to` and producers of `from`.
  static void backward(
      TensorView* from,
      int pos,
      std::vector<TensorView*> to,
      c10::optional<Options> options = c10::nullopt);

  //! Replay transforms from tensorview `from`
  //!  to the tensorviews that are producers
  //!  of boundary tensorviews in `to` and consumers of `from`.
  static void forward(
      TensorView* from,
      int pos,
      std::vector<TensorView*> to,
      c10::optional<Options> options = c10::nullopt);

  //! Replay transforms from tensorview `from`
  //!  to all the tensorviews that are consumers
  //!  of tensorviews in `backward_to` and producers
  //!  of tensorviews in `forward_to` while being
  //!  either a producer or a consumer of tensorview `from`.
  static void bothWays(
      TensorView* from,
      int pos,
      std::vector<TensorView*> backward_to,
      std::vector<TensorView*> forward_to,
      c10::optional<Options> options = c10::nullopt);

 private:
  //! Utility function:
  //!  Will realize the transform propagation to the
  //!  tensorviews in `included_tvs`.
  //! Assumes that all tvs in included_tvs are either
  //!  a producer or a consumer of from_tv.
  static void propagate(
      TensorView* from_tv,
      int pos,
      std::unordered_set<TensorView*> included_tvs,
      Options options);
};
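A hedged usage sketch (the tensorview names are placeholders; assumes the scheduler_utils namespace is in scope):

```cpp
// Replay the first three leaf dimensions of a reference tensorview onto its
// producer chain, stop at the cached operands, transform those boundary tvs
// too, and propagate the parallel bindings along the way:
BoundedDirectionalTransformPropagator::backward(
    reference_tv,
    3,
    {operand_a_cache, operand_b_cache},
    BoundedDirectionalTransformPropagator::Options()
        .propagateParallelType()
        .propagateToBoundary());

// Forward variant: replay everything (pos = -1) from the mma result toward
// the fusion output, leaving the boundary tensorview itself untouched
// (transform_boundary defaults to false):
BoundedDirectionalTransformPropagator::forward(mma_result, -1, {fusion_out});
```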

} // namespace scheduler_utils
} // namespace cuda
} // namespace fuser

@@ -211,6 +211,8 @@ TensorView::TensorView(const TensorView* src, IrCloner* ir_cloner)
      memory_type_(src->memory_type_),
      swizzle_type_(src->swizzle_type_),
      is_double_buffered_(src->is_double_buffered_),
      is_circular_buffered_(src->is_circular_buffered_),
      circular_buffer_stage_(src->circular_buffer_stage_),
      cpu_scalar_(src->cpu_scalar_),
      has_swizzle_op_(src->has_swizzle_op_) {
  for (const auto id : src->axesToSwizzle()) {

@@ -571,7 +573,11 @@ TensorView* TensorView::swizzle(
  return this;
}

-TensorView* TensorView::swizzle(Swizzle2DType swizzle_type, int x, int y) {
+TensorView* TensorView::swizzle(
+    Swizzle2DType swizzle_type,
+    int x,
+    int y,
+    SwizzleMode swizzle_mode) {
  has_swizzle_op_ = true;
  if (x < 0) {
    x += domain()->nDims();

@@ -645,7 +651,7 @@ TensorView* TensorView::swizzle(Swizzle2DType swizzle_type, int x, int y) {
    }
  }

-  domain()->swizzle(swizzle_type, x, y);
+  domain()->swizzle(swizzle_type, x, y, swizzle_mode);

  return this;
}

@@ -735,11 +741,10 @@ TensorView* TensorView::multiOutputRfactorHelper(
      !container()->isA<kir::Kernel>(),
      "Function invalid for kernel container.");
  // Hack:
-  // Semantically we should always keep the outputs of welfordOp scheduled
-  // the same but the user end cannot guarantee that.
-  // In order to guarantee that the rFactor is defined meaningfully the
-  // scheduling of the output TV that got the rfactor call is force replayed
-  // towards the other two
+  // Semantically we should always keep the outputs of multi reduction ops
+  // scheduled the same but the user end cannot guarantee that. In order to
+  // guarantee that the rFactor is defined meaningfully the scheduling of the
+  // output TV that got the rfactor call is force replayed towards the other two

  if (!sameAs(tv)) {
    auto root = tv->getRootDomain();

@@ -758,7 +763,7 @@ TensorView* TensorView::multiOutputRfactorHelper(
    std::vector<IterDomain*> new_id;
    for (auto id : domain()->domain()) {
      TORCH_INTERNAL_ASSERT(
-          replay.getReplay().count(id), "Welford Replay Failed");
+          replay.getReplay().count(id), "Multi-output reduction replay failed");
      new_id.push_back(replay.getReplay().at(id));
    }

@@ -795,12 +800,11 @@ std::vector<TensorView*> TensorView::rFactor(
  TORCH_CHECK(nDims() > 0, "Tried to rFactor a 0-dim TensorView");
  FusionGuard fg(fusion());
  TORCH_CHECK(
-      definition() != nullptr &&
-          (definition()->getExprType() == ExprType::GroupedReductionOp ||
-           definition()->getExprType() == ExprType::WelfordOp),
-      "Error rfactoring welford ",
+      definition() != nullptr && ir_utils::isReductionOp(definition()),
+      "Error rfactoring multi-output reduction op ",
      this,
-      " its definition is either a nullptr or not a GroupedReductionOp or a WelfordOp.");
+      " its definition is either a nullptr or not a GroupedReductionOp or a multi-output reduction op.");

  TORCH_CHECK(
      !domain()->hasRFactor(), "Cannot call rfactor on the same view twice.");

@@ -1129,6 +1133,21 @@ void TensorView::doubleBuffer() {
  is_double_buffered_ = true;
}

void TensorView::circularBuffer(unsigned int stage) {
  // Early correctness checking. May miss eventual errors as the
  // checks depend on memory types and parallelization, which may not
  // be finalized until lowering.
  TORCH_INTERNAL_ASSERT(stage > 1, "Unsupported stage number");
  if (stage == 2) {
    // Re-direct to the double buffer interface if stage is 2.
    doubleBuffer();
    return;
  }
  validateDoubleBufferedTensor(this);
  is_circular_buffered_ = true;
  circular_buffer_stage_ = stage;
}
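A short usage sketch (the tensorview names are placeholders):

```cpp
// Pipeline a shared-memory operand load over 3 stages; a stage count of 2
// simply falls back to the pre-existing double-buffering path above.
smem_operand_cache->circularBuffer(3);
// For comparison, the classic interface:
gmem_read_cache->doubleBuffer();
```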

bool TensorView::isEmptyTensor() const {
  auto& root_domain = getMaybeRFactorDomain();
  return std::all_of(

@@ -1174,7 +1193,29 @@ TensorViewBuilder& TensorViewBuilder::contiguity(std::vector<bool> contiguity) {
  return *this;
}

-TensorViewBuilder& TensorViewBuilder::shape(std::vector<int64_t> shape) {
+TensorViewBuilder& TensorViewBuilder::shape(const std::vector<int64_t>& shape) {
  TORCH_CHECK(shape_.empty(), "Attempting to reset shape");
  if (!shape.empty()) {
    TORCH_CHECK(ndims_ == 0 || ndims_ == shape.size());
    ndims_ = shape.size();
  }
  shape_.clear();
  shape_.reserve(shape.size());
  for (int64_t i : shape) {
    if (i == -1) {
      shape_.emplace_back(IrBuilder::create<Int>());
    } else {
      TORCH_CHECK(
          i >= 0,
          "Invalid extent value. ",
          "For a tensor representing a single scalar use ndims = 0 with no sizes set.");
      shape_.emplace_back(IrBuilder::create<Int>(i));
    }
  }
  return *this;
}
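For reference, a hedged sketch of the two `shape()` overloads side by side (extent values are arbitrary; the `Val*` overload is the one added in this stack):

```cpp
// Concrete int64_t sizes; -1 requests a symbolic extent.
TensorView* tv_a = TensorViewBuilder()
                       .shape(std::vector<int64_t>{128, -1, 32})
                       .dtype(DataType::Half)
                       .contiguity(std::vector<bool>(3, true))
                       .build();

// Val* extents: mix a compile-time constant with a fresh symbolic Int.
std::vector<Val*> extents = {
    IrBuilder::create<Int>(128), IrBuilder::create<Int>()};
TensorView* tv_b = TensorViewBuilder().shape(extents).build();
```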

TensorViewBuilder& TensorViewBuilder::shape(std::vector<Val*> shape) {
  TORCH_CHECK(shape_.empty(), "Attempting to reset shape");
  if (!shape.empty()) {
    TORCH_CHECK(ndims_ == 0 || ndims_ == shape.size());

@@ -1188,17 +1229,13 @@ TensorView* TensorViewBuilder::build() const {
  // Build the domain
  std::vector<IterDomain*> domain(ndims_, nullptr);
  for (const auto i : c10::irange(ndims_)) {
-    if (shape_.empty() || shape_[i] == -1) {
+    if (shape_.empty()) {
      domain[i] =
          IterDomainBuilder(
              FusionGuard::getCurFusion()->zeroVal(), IrBuilder::create<Int>())
              .build();
    } else {
-      TORCH_CHECK(
-          shape_[i] >= 0,
-          "Invalid extent value. ",
-          "For a tensor representing a single scalar use ndims = 0 with no sizes set.");
-      if (shape_[i] == 1) {
+      if (shape_[i]->isOneInt()) {
        // If size is known to be 1, assume it needs to be broadcasted.
        domain[i] = IterDomainBuilder(
                        FusionGuard::getCurFusion()->zeroVal(),

@@ -1206,10 +1243,9 @@ TensorView* TensorViewBuilder::build() const {
                        .iter_type(IterType::Broadcast)
                        .build();
      } else {
-        domain[i] = IterDomainBuilder(
-                        FusionGuard::getCurFusion()->zeroVal(),
-                        IrBuilder::create<Int>(shape_[i]))
-                        .build();
+        domain[i] =
+            IterDomainBuilder(FusionGuard::getCurFusion()->zeroVal(), shape_[i])
+                .build();
      }
    }
  }

File diff suppressed because it is too large
@@ -26,6 +26,7 @@
#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
#include <torch/csrc/jit/codegen/cuda/test/test_gpu_validator.h>
#include <torch/csrc/jit/codegen/cuda/test/test_utils.h>
#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
#include <torch/csrc/jit/codegen/cuda/transform_rfactor.h>

@@ -46,29 +47,6 @@ using namespace at::indexing;

namespace {

-// Make a tensor that is known to be fully contiguous of dimensionality=ndims,
-// but unknown sizes
-TensorView* makeContigTensor(size_t ndims, DataType dtype = DataType::Float) {
-  return TensorViewBuilder()
-      .ndims(ndims)
-      .dtype(dtype)
-      .contiguity(std::vector<bool>(ndims, true))
-      .build();
-}
-
-// Make a tensor that is known to be non-contiguous of dimensionality=ndims,
-// but unknown sizes
-TensorView* makeSymbolicTensor(size_t ndims, DataType dtype = DataType::Float) {
-  return TensorViewBuilder().ndims(ndims).dtype(dtype).build();
-}
-
-// Make a non-contiguous tensor of compile-time known sizes
-TensorView* makeConcreteTensor(
-    std::vector<int64_t> shape,
-    DataType dtype = DataType::Float) {
-  return TensorViewBuilder().shape(shape).dtype(dtype).build();
-}

class KernelExprVisitor : private kir::IrVisitor {
 public:
  static std::vector<Expr*> getAllExprs(const kir::Kernel* kernel) {

@@ -144,7 +122,7 @@ TEST_F(NVFuserTest, FusionGridAllreduce1_CUDA) {
  tv3->axis(0)->parallelize(ParallelType::BIDy);
  tv3->axis(2)->parallelize(ParallelType::BIDx);
  tv3->axis(3)->parallelize(ParallelType::TIDx);
-  scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv3);

  // Just to make sure fused_reduction and work buffers are allocated
  // uniquely

@@ -247,7 +225,7 @@ TEST_F(NVFuserTest, FusionGridAllreduce3_CUDA) {

  tv3->axis(1)->parallelize(ParallelType::BIDx);
  tv3->axis(2)->parallelize(ParallelType::TIDx);
-  scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv3);

  GpuLower gpulw(&fusion);
  validateNoParallelBroadcastExist(gpulw.kernel());

@@ -293,7 +271,7 @@ TEST_F(NVFuserTest, FusionGridAllreduce4_CUDA) {

  tv4->axis(0)->parallelize(ParallelType::BIDx);
  tv4->axis(1)->parallelize(ParallelType::TIDx);
-  scheduler_utils::parallelizeAllLike(tv4, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv4);

  GpuLower gpulw(&fusion);
  validateNoParallelBroadcastExist(gpulw.kernel());

@@ -352,7 +330,7 @@ TEST_F(NVFuserTest, FusionGridAllreduce5_CUDA) {

  tv4->axis(1)->parallelize(ParallelType::BIDx);
  tv4->axis(2)->parallelize(ParallelType::TIDx);
-  scheduler_utils::parallelizeAllLike(tv4, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv4);

  tv6->axis(0)->parallelize(ParallelType::BIDy);
  tv6->axis(1)->parallelize(ParallelType::BIDx);

@@ -410,7 +388,7 @@ TEST_F(NVFuserTest, FusionGridAllreduce6_CUDA) {
  tv1->axis(2)->parallelize(ParallelType::BIDx);
  tv1->axis(3)->parallelize(ParallelType::TIDx);

-  scheduler_utils::parallelizeAllLike(tv1, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv1);

  tv1->axis(4)->parallelize(ParallelType::Vectorize);

@@ -456,7 +434,7 @@ TEST_F(NVFuserTest, FusionGridAllreduceWelford1_CUDA) {

  tv5->axis(0)->parallelize(ParallelType::BIDx);
  tv5->axis(1)->parallelize(ParallelType::TIDx);
-  scheduler_utils::parallelizeAllLike(tv5, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv5);

  GpuLower gpulw(&fusion);
  validateNoParallelBroadcastExist(gpulw.kernel());

@@ -506,7 +484,7 @@ TEST_F(NVFuserTest, FusionGridAllreduceWelford2_CUDA) {

  tv3->axis(1)->parallelize(ParallelType::BIDx);
  tv3->axis(2)->parallelize(ParallelType::TIDx);
-  scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv3);

  // There must be no parallel broadcast
  GpuLower gpulw(&fusion);

@@ -610,7 +588,7 @@ TEST_F(NVFuserTest, FusionFusedReductionBatchnorm_CUDA) {
  TransformPropagator propagator(tv0);
  MaxRootDomainInfoSpanningTree(tv0).traverse(&propagator);

-  auto tvs_rf = tvs.rFactor({-5, -4, -3, -2, -1});
+  ir_utils::rfactorHelper(tvs.avg, {-5, -4, -3, -2, -1});

  tv0->computeAt(tv29, 2);
  tv1->computeAt(tv29, 2);

@@ -622,7 +600,7 @@ TEST_F(NVFuserTest, FusionFusedReductionBatchnorm_CUDA) {
  tv29->axis(2)->parallelize(ParallelType::BIDy);
  tv29->axis(3)->parallelize(ParallelType::TIDz);
  tv29->axis(4)->parallelize(ParallelType::TIDx);
-  scheduler_utils::parallelizeAllLike(tv29, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv29);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto options_half = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);

@@ -696,7 +674,7 @@ TEST_F(NVFuserTest, FusionGroupedReduction1_CUDA) {

  tv2->axis(0)->parallelize(ParallelType::BIDx);
  tv2->axis(1)->parallelize(ParallelType::TIDx);
-  scheduler_utils::parallelizeAllLike(tv2, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv2);

  std::vector<int64_t> shape({99, 999});

@@ -872,7 +850,7 @@ TEST_F(NVFuserTest, FusionGroupedReduction6_CUDA) {
  tv2->axis(0)->parallelize(ParallelType::BIDx);
  tv2->axis(1)->parallelize(ParallelType::TIDx);

-  scheduler_utils::parallelizeAllLike(tv2, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv2);

  std::vector<int64_t> shape({99, 999});

@@ -937,7 +915,7 @@ TEST_F(NVFuserTest, FusionGroupedReductionRfactor1_CUDA) {
  tv1_rf->axis(0)->parallelize(ParallelType::BIDx);
  tv1_rf->axis(2)->parallelize(ParallelType::TIDx);

-  scheduler_utils::parallelizeAllLike(tv1_rf, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv1_rf);

  std::vector<int64_t> shape({12345});

@@ -982,7 +960,7 @@ TEST_F(NVFuserTest, FusionGroupedReductionRfactor2_CUDA) {
  tv1_rf->axis(0)->parallelize(ParallelType::BIDx);
  tv1_rf->axis(2)->parallelize(ParallelType::TIDx);

-  scheduler_utils::parallelizeAllLike(tv1_rf, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv1_rf);

  std::vector<int64_t> shape({12345});

@@ -1028,7 +1006,7 @@ TEST_F(NVFuserTest, FusionGroupedReductionAfterComputeAt_CUDA) {
  groupReductions({tv2, tv3});

  tv2->axis(1)->parallelize(ParallelType::TIDx);
-  scheduler_utils::parallelizeAllLike(tv2, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv2);

  std::vector<int64_t> shape({3, 1234});

@@ -1068,7 +1046,7 @@ TEST_F(NVFuserTest, FusionGroupAllreduce1_CUDA) {

  tv2->axis(0)->parallelize(ParallelType::BIDx);
  tv2->axis(1)->parallelize(ParallelType::TIDx);
-  scheduler_utils::parallelizeAllLike(tv2, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv2);

  std::vector<int64_t> shape({999});

@@ -1117,7 +1095,7 @@ TEST_F(NVFuserTest, FusionGroupAllreduce2_CUDA) {
  tv1->axis(0)->parallelize(ParallelType::BIDy);
  tv1->axis(1)->parallelize(ParallelType::BIDx);
  tv1->axis(2)->parallelize(ParallelType::TIDx);
-  scheduler_utils::parallelizeAllLike(tv1, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv1);

  std::vector<int64_t> shape({10, 999});

@@ -1169,7 +1147,7 @@ TEST_F(NVFuserTest, FusionGroupAllreduce3_CUDA) {

  tv1->axis(0)->parallelize(ParallelType::BIDx);
  tv1->axis(1)->parallelize(ParallelType::TIDx);
-  scheduler_utils::parallelizeAllLike(tv1, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv1);

  std::vector<int64_t> shape({999});

@@ -1221,7 +1199,7 @@ TEST_F(NVFuserTest, FusionGroupAllreduce4_CUDA) {

  reduction_tv->axis(0)->parallelize(ParallelType::BIDx);
  reduction_tv->axis(1)->parallelize(ParallelType::TIDx);
-  scheduler_utils::parallelizeAllLike(reduction_tv, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(reduction_tv);

  std::vector<int64_t> shape({999});

@@ -1281,7 +1259,7 @@ TEST_F(NVFuserTest, FusionGroupAllreduce5_CUDA) {

  tv1->axis(0)->parallelize(ParallelType::BIDx);
  tv1->axis(1)->parallelize(ParallelType::TIDx);
-  scheduler_utils::parallelizeAllLike(tv1, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv1);

  std::vector<int64_t> shape({999});

@@ -1445,7 +1423,7 @@ TEST_F(NVFuserTest, FusionPersistentBNBackwardAllreduce_CUDA) {
  }

  // Parallelization
-  scheduler_utils::parallelizeAllLike(grad_input, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(grad_input);
  input_cache->axis(-1)->parallelize(ParallelType::Vectorize);
  grad_output_cache->axis(-1)->parallelize(ParallelType::Vectorize);

@@ -1559,7 +1537,7 @@ TEST_F(NVFuserTest, FusionGroupedReductionReEntrant1_CUDA) {
  tv2->axis(2)->parallelize(ParallelType::BIDx);
  tv2->axis(3)->parallelize(ParallelType::TIDx);

-  scheduler_utils::parallelizeAllLike(tv2, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv2);

  std::vector<int64_t> shape({99, 999});

@@ -1670,7 +1648,7 @@ TEST_F(NVFuserTest, FusionGroupedReductionChannelsLastBatchNormLike_CUDA) {
  ref->axis(4)->parallelize(ParallelType::Serial);
  ref->axis(5)->parallelize(ParallelType::Serial);

-  scheduler_utils::parallelizeAllLike(ref, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(ref);

  tv0_cache->axis(-1)->parallelize(ParallelType::Vectorize);
  tv1_cache->axis(-1)->parallelize(ParallelType::Vectorize);

@@ -1799,7 +1777,7 @@ TEST_F(
  ref->axis(4)->parallelize(ParallelType::Serial);
  ref->axis(5)->parallelize(ParallelType::Serial);

-  scheduler_utils::parallelizeAllLike(ref, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(ref);

  tv0_cache->axis(-1)->parallelize(ParallelType::Vectorize);
  tv1_cache->axis(-1)->parallelize(ParallelType::Vectorize);

@@ -1866,7 +1844,7 @@ TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduce1_CUDA) {
  tv1->axis(2)->parallelize(ParallelType::BIDx);
  tv1->axis(3)->parallelize(ParallelType::TIDx);

-  scheduler_utils::parallelizeAllLike(tv1, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv1);

  tv2->axis(4)->parallelize(ParallelType::Group);

@@ -1944,7 +1922,7 @@ TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduce2_CUDA) {
  tv1->axis(2)->parallelize(ParallelType::BIDx);
  tv1->axis(3)->parallelize(ParallelType::TIDx);

-  scheduler_utils::parallelizeAllLike(tv1, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv1);

  tv2->axis(4)->parallelize(ParallelType::Group);
  tv2->axis(5)->parallelize(ParallelType::Group);

@@ -2030,7 +2008,7 @@ TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduce3_CUDA) {
  tv1->axis(2)->parallelize(ParallelType::BIDx);
  tv1->axis(3)->parallelize(ParallelType::TIDx);

-  scheduler_utils::parallelizeAllLike(tv1, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv1);

  tv2->axis(4)->parallelize(ParallelType::Group);

@@ -26,6 +26,7 @@
#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
#include <torch/csrc/jit/codegen/cuda/test/test_gpu_validator.h>
#include <torch/csrc/jit/codegen/cuda/test/test_utils.h>
#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
#include <torch/csrc/jit/codegen/cuda/transform_rfactor.h>

@@ -46,48 +47,6 @@ using namespace at::indexing;

namespace {

-// Make a tensor that is known to be fully contiguous of dimensionality=ndims,
-// but unknown sizes
-TensorView* makeContigTensor(size_t ndims, DataType dtype = DataType::Float) {
-  return TensorViewBuilder()
-      .ndims(ndims)
-      .dtype(dtype)
-      .contiguity(std::vector<bool>(ndims, true))
-      .build();
-}
-
-// Make a tensor that is known to be non-contiguous of dimensionality=ndims,
-// but unknown sizes
-TensorView* makeSymbolicTensor(size_t ndims, DataType dtype = DataType::Float) {
-  return TensorViewBuilder().ndims(ndims).dtype(dtype).build();
-}
-
-// Make a non-contiguous tensor of compile-time known sizes
-TensorView* makeConcreteTensor(
-    std::vector<int64_t> shape,
-    DataType dtype = DataType::Float) {
-  return TensorViewBuilder().shape(shape).dtype(dtype).build();
-}
-
-void checkIntValue(
-    ExpressionEvaluator& evaluator,
-    Val* val,
-    Int::ScalarType expected_value) {
-  TORCH_CHECK(val->isAnInt());
-  const auto actual_value = evaluator.evaluate(val);
-  TORCH_CHECK(actual_value.has_value());
-  TORCH_CHECK(actual_value.value() == expected_value);
-}
-
-void checkIntValue(
-    kir::ExpressionEvaluator& evaluator,
-    const Val* val,
-    Int::ScalarType expected_value) {
-  const auto actual_value = evaluator.evaluate(val);
-  TORCH_CHECK(actual_value.has_value());
-  TORCH_CHECK(actual_value.value() == expected_value);
-}

// Used to signify invalid ranges, i.e., values at offset 0 to
// start_offset, and values at offset stop_offset to the end of the
// domain.

@@ -2175,7 +2134,7 @@ TEST_F(NVFuserTest, FusionHdiff_CUDA) {
  out->axis(3)->parallelize(ParallelType::TIDy);
  out->axis(4)->parallelize(ParallelType::TIDx);
  // Apply the same parallelization to all other tensors
-  scheduler_utils::parallelizeAllLike(out, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(out);

  // Store intermediate stencil results on smem so that they can be
  // accessed by threads

@@ -2733,7 +2692,7 @@ TEST_F(NVFuserTest, FusionGather6_CUDA) {
  out->axis(1)->parallelize(ParallelType::BIDx);
  out->axis(2)->parallelize(ParallelType::TIDy);
  out->axis(3)->parallelize(ParallelType::TIDx);
-  scheduler_utils::parallelizeAllLike(out, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(out);

  const int s1 = 101;
  const int s2 = 99;

@@ -2793,7 +2752,7 @@ TEST_F(NVFuserTest, FusionGather7_CUDA) {
  out->axis(1)->parallelize(ParallelType::BIDx);
  out->axis(2)->parallelize(ParallelType::TIDy);
  out->axis(3)->parallelize(ParallelType::TIDx);
-  scheduler_utils::parallelizeAllLike(out, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(out);

  const int s1 = 101;
  const int s2 = 99;

@@ -2894,7 +2853,7 @@ TEST_F(NVFuserTest, FusionGather9_CUDA) {
  out->axis(1)->parallelize(ParallelType::BIDx);
  out->axis(2)->parallelize(ParallelType::TIDy);
  out->axis(3)->parallelize(ParallelType::TIDx);
-  scheduler_utils::parallelizeAllLike(out, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(out);

  const int s1 = 101;
  const int s2 = 99;

@@ -3817,7 +3776,7 @@ TEST_F(NVFuserTest, FusionShiftNoPadding1_CUDA) {

  tv5->axis(-1)->parallelize(ParallelType::TIDx);
  tv5->axis(-2)->parallelize(ParallelType::TIDy);
-  scheduler_utils::parallelizeAllLike(tv5, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv5);

  int numel_x = 99;
  int numel_y = 101;

@@ -3873,7 +3832,7 @@ TEST_F(NVFuserTest, FusionShiftNoPadding2_CUDA) {
  tv3->computeAt(tv5, -1);

  tv5->axis(-1)->parallelize(ParallelType::TIDx);
-  scheduler_utils::parallelizeAllLike(tv5, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv5);

  int numel_x = 99;
  int numel_y = 101;

@@ -3934,7 +3893,7 @@ TEST_F(NVFuserTest, FusionShiftNoPadding3_CUDA) {
  tv3->computeAt(tv_avg, -1);

  tv_avg->axis(-1)->parallelize(ParallelType::TIDx);
-  scheduler_utils::parallelizeAllLike(tv_avg, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv_avg);

  int numel_x = 99;
  int numel_y = 101;

@@ -4122,7 +4081,7 @@ TEST_F(NVFuserTest, FusionShiftPadding1_CUDA) {

  tv5->axis(-1)->parallelize(ParallelType::TIDx);
  tv5->axis(-2)->parallelize(ParallelType::TIDy);
-  scheduler_utils::parallelizeAllLike(tv5, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv5);

  int numel_x = 99;
  int numel_y = 101;

@@ -5080,7 +5039,7 @@ TEST_F(NVFuserTest, FusionMaxPoolingStrided_CUDA) {
  max_tensor->axis(4)->parallelize(ParallelType::TIDy);
  max_tensor->axis(5)->parallelize(ParallelType::TIDx);

-  scheduler_utils::parallelizeAllLike(max_tensor, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(max_tensor);

  inp_cache->setMemoryType(MemoryType::Shared);

@@ -5332,7 +5291,7 @@ TEST_F(NVFuserTest, FusionGather9ptStencilDoubleBuffering_CUDA) {
  out->axis(2)->parallelize(ParallelType::TIDy);
  out->axis(0)->parallelize(ParallelType::BIDx);

-  scheduler_utils::parallelizeAllLike(out, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(out);

  tv0_cache->doubleBuffer();

@@ -5380,7 +5339,7 @@ TEST_F(NVFuserTest, FusionValidateParallelizeShift_CUDA) {

  tv5->axis(1)->parallelize(ParallelType::TIDx);

-  scheduler_utils::parallelizeAllLike(tv5, ir_utils::allTvs(&fusion));
+  scheduler_utils::parallelizeAllLike(tv5);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({1024 * 32}, options);

File diff suppressed because it is too large

@@ -28,6 +28,7 @@
#include <torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
#include <torch/csrc/jit/codegen/cuda/test/test_gpu_validator.h>
#include <torch/csrc/jit/codegen/cuda/test/test_utils.h>
#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
#include <torch/csrc/jit/codegen/cuda/transform_rfactor.h>

@@ -49,33 +50,6 @@ namespace jit {
using namespace torch::jit::fuser::cuda;
using namespace at::indexing;

-namespace {
-
-// Make a tensor that is known to be fully contiguous of dimensionality=ndims,
-// but unknown sizes
-TensorView* makeContigTensor(size_t ndims, DataType dtype = DataType::Float) {
-  return TensorViewBuilder()
-      .ndims(ndims)
-      .dtype(dtype)
-      .contiguity(std::vector<bool>(ndims, true))
-      .build();
-}
-
-// Make a tensor that is known to be non-contiguous of dimensionality=ndims,
-// but unknown sizes
-TensorView* makeSymbolicTensor(size_t ndims, DataType dtype = DataType::Float) {
-  return TensorViewBuilder().ndims(ndims).dtype(dtype).build();
-}
-
-// Make a non-contiguous tensor of compile-time known sizes
-TensorView* makeConcreteTensor(
-    std::vector<int64_t> shape,
-    DataType dtype = DataType::Float) {
-  return TensorViewBuilder().shape(shape).dtype(dtype).build();
-}
-
-} // namespace

TEST_F(NVFuserTest, FusionViewDtypeSameSizeOutput_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);
Some files were not shown because too many files have changed in this diff.