#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/type.h>

#include <torch/csrc/jit/codegen/cuda/dispatch.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

template <typename T>
T* ptr(T& obj) {
  return &obj;
}

template <typename T>
T* ptr(T* obj) {
  return obj;
}

/*
 * Generic dispatch for any handler that does not modify the IR directly.
 * For example, we may want to walk the graph to construct a topologically
 * sorted set of exprs, or to print the IR itself; neither modifies the IR.
 * This dispatch is paired with a class that implements the functions:
 *   template <typename node_type>
 *   int handler(node_type* node)
 *
 * handler should call:
 *   dispatch(this, node_to_dispatch)
 *
 * It could also implement:
 *   int handler(Statement* stmt) {
 *     dispatch(this, stmt);
 *   }
 *
 * And therefore dispatch should never call:
 *   ptr(mutator)->mutate(this->as<Statement>());
 */

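// Illustrative sketch (editorial example, not part of the original file): a
// minimal non-mutating visitor built on this dispatch mechanism. It derives
// from OptOutConstDispatch (declared in dispatch.h) and overrides only the
// node types it cares about; every other node type falls through to the no-op
// default handlers defined later in this file. The class name is hypothetical,
// and the override assumes the handle() overloads are virtual, as declared in
// dispatch.h.
namespace {
class TensorViewCounter : public OptOutConstDispatch {
 public:
  // Keep the inherited handle() overloads visible alongside the override.
  using OptOutConstDispatch::handle;

  void handle(const TensorView*) override {
    ++count_;
  }

  // Usage sketch:
  //   TensorViewCounter counter;
  //   counter.handle(stmt); // routes through Statement::constDispatch
  int count_ = 0;
};
} // namespace
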
template <typename T>
void Val::dispatch(T handler, Val* val) {
  switch (*(val->getValType())) {
    case ValType::Scalar:
      switch (*(val->getDataType())) {
        case DataType::Bool:
          ptr(handler)->handle(val->as<Bool>());
          return;
        case DataType::Double:
          ptr(handler)->handle(val->as<Double>());
          return;
        case DataType::Int:
        case DataType::Int32:
          // Dispatch to Int even with Int32 as we don't have Int32 IR
          // node.
          ptr(handler)->handle(val->as<Int>());
          return;
        case DataType::ComplexDouble:
          ptr(handler)->handle(val->as<ComplexDouble>());
          return;
        default:
          break;
      }
      break;
    case ValType::NamedScalar:
      ptr(handler)->handle(val->as<NamedScalar>());
      return;

    case ValType::IterDomain:
      ptr(handler)->handle(val->as<IterDomain>());
      return;
    case ValType::TensorDomain:
      ptr(handler)->handle(val->as<TensorDomain>());
      return;
    case ValType::TensorView:
      ptr(handler)->handle(val->as<TensorView>());
      return;
    case ValType::Predicate:
      ptr(handler)->handle(val->as<kir::Predicate>());
      return;
    case ValType::TensorIndex:
      ptr(handler)->handle(val->as<kir::TensorIndex>());
      return;
    case ValType::IntPair:
      ptr(handler)->handle(val->as<kir::IntPair>());
      return;
    default:
      break;
  }
  TORCH_INTERNAL_ASSERT(false, "Unknown valtype in dispatch!");
}

template <typename T>
void Expr::dispatch(T handler, Expr* expr) {
  switch (*(expr->getExprType())) {
    case ExprType::UnaryOp:
      ptr(handler)->handle(expr->as<UnaryOp>());
      return;
    case ExprType::BinaryOp:
      ptr(handler)->handle(expr->as<BinaryOp>());
      return;
    case ExprType::TernaryOp:
      ptr(handler)->handle(expr->as<TernaryOp>());
      return;
    case ExprType::ReductionOp:
      ptr(handler)->handle(expr->as<ReductionOp>());
      return;
    case ExprType::GroupedReductionOp:
      ptr(handler)->handle(expr->as<GroupedReductionOp>());
      return;
    case ExprType::WelfordOp:
      ptr(handler)->handle(expr->as<WelfordOp>());
      return;
    case ExprType::LoadStoreOp:
      ptr(handler)->handle(expr->as<LoadStoreOp>());
      return;
    case ExprType::MmaOp:
      ptr(handler)->handle(expr->as<MmaOp>());
      return;
    case ExprType::BroadcastOp:
      ptr(handler)->handle(expr->as<BroadcastOp>());
      return;

    case ExprType::Split:
      ptr(handler)->handle(expr->as<Split>());
      return;
    case ExprType::Merge:
      ptr(handler)->handle(expr->as<Merge>());
      return;
    case ExprType::Swizzle2D:
      ptr(handler)->handle(expr->as<Swizzle2D>());
      return;
    case ExprType::TransposeOp:
      ptr(handler)->handle(expr->as<TransposeOp>());
      return;
    case ExprType::ExpandOp:
      ptr(handler)->handle(expr->as<ExpandOp>());
      return;
    case ExprType::ShiftOp:
      ptr(handler)->handle(expr->as<ShiftOp>());
      return;
    case ExprType::GatherOp:
      ptr(handler)->handle(expr->as<GatherOp>());
      return;
    case ExprType::ViewAsScalar:
      ptr(handler)->handle(expr->as<ViewAsScalar>());
      return;
    case ExprType::ViewOp:
      ptr(handler)->handle(expr->as<ViewOp>());
      return;

    case ExprType::Allocate:
      ptr(handler)->handle(expr->as<kir::Allocate>());
      return;
    case ExprType::BlockSync:
      ptr(handler)->handle(expr->as<kir::BlockSync>());
      return;
    case ExprType::GridSync:
      ptr(handler)->handle(expr->as<kir::GridSync>());
      return;
    case ExprType::CpAsyncWait:
      ptr(handler)->handle(expr->as<kir::CpAsyncWait>());
      return;
    case ExprType::CpAsyncCommit:
      ptr(handler)->handle(expr->as<kir::CpAsyncCommit>());
      return;
    case ExprType::InitMagicZero:
      ptr(handler)->handle(expr->as<kir::InitMagicZero>());
      return;
    case ExprType::UpdateMagicZero:
      ptr(handler)->handle(expr->as<kir::UpdateMagicZero>());
      return;
    case ExprType::ForLoop:
      ptr(handler)->handle(expr->as<kir::ForLoop>());
      return;
    case ExprType::IfThenElse:
      ptr(handler)->handle(expr->as<kir::IfThenElse>());
      return;
    case ExprType::GridReduction:
      ptr(handler)->handle(expr->as<kir::GridReduction>());
      return;
    case ExprType::GroupedGridReduction:
      ptr(handler)->handle(expr->as<kir::GroupedGridReduction>());
      return;
    case ExprType::GridBroadcast:
      ptr(handler)->handle(expr->as<kir::GridBroadcast>());
      return;
    case ExprType::GridWelford:
      ptr(handler)->handle(expr->as<kir::GridWelford>());
      return;
    case ExprType::AllocateFusedReduction:
      ptr(handler)->handle(expr->as<kir::AllocateFusedReduction>());
      return;
    case ExprType::Swizzle2DInt:
      ptr(handler)->handle(expr->as<kir::Swizzle2DInt>());
      return;
    case ExprType::PairSelect:
      ptr(handler)->handle(expr->as<kir::PairSelect>());
      return;
    default:
      TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
  }
}

template <typename T>
void Statement::dispatch(T handler, Statement* stmt) {
  if (stmt->isVal()) {
    ptr(handler)->handle(stmt->as<Val>());
  } else if (stmt->isExpr()) {
    ptr(handler)->handle(stmt->as<Expr>());
  } else
    TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!");
}

template <typename T>
void Val::constDispatch(T handler, const Val* val) {
  switch (*(val->getValType())) {
    case ValType::Scalar:
      switch (*(val->getDataType())) {
        case DataType::Bool:
          ptr(handler)->handle(val->as<Bool>());
          return;
        case DataType::Double:
          ptr(handler)->handle(val->as<Double>());
          return;
        case DataType::Int:
        case DataType::Int32:
          // Dispatch to Int even with Int32 as we don't have Int32 IR
          // node.
          ptr(handler)->handle(val->as<Int>());
          return;
        case DataType::ComplexDouble:
          ptr(handler)->handle(val->as<ComplexDouble>());
          return;
        default:
          break;
      }
      break;
    case ValType::NamedScalar:
      ptr(handler)->handle(val->as<NamedScalar>());
      return;

    case ValType::IterDomain:
      ptr(handler)->handle(val->as<IterDomain>());
      return;
    case ValType::TensorDomain:
      ptr(handler)->handle(val->as<TensorDomain>());
      return;
    case ValType::TensorView:
      ptr(handler)->handle(val->as<TensorView>());
      return;
    case ValType::Predicate:
      ptr(handler)->handle(val->as<kir::Predicate>());
      return;
    case ValType::TensorIndex:
      ptr(handler)->handle(val->as<kir::TensorIndex>());
      return;
    case ValType::IntPair:
      ptr(handler)->handle(val->as<kir::IntPair>());
      return;
    default:
      break;
  }
  TORCH_INTERNAL_ASSERT(false, "Unknown valtype in dispatch!");
}

template <typename T>
void Expr::constDispatch(T handler, const Expr* expr) {
  switch (*(expr->getExprType())) {
    case ExprType::UnaryOp:
      ptr(handler)->handle(expr->as<UnaryOp>());
      return;
    case ExprType::BinaryOp:
      ptr(handler)->handle(expr->as<BinaryOp>());
      return;
    case ExprType::TernaryOp:
      ptr(handler)->handle(expr->as<TernaryOp>());
      return;
    case ExprType::ReductionOp:
      ptr(handler)->handle(expr->as<ReductionOp>());
      return;
    case ExprType::GroupedReductionOp:
      ptr(handler)->handle(expr->as<GroupedReductionOp>());
      return;
    case ExprType::WelfordOp:
      ptr(handler)->handle(expr->as<WelfordOp>());
      return;
    case ExprType::LoadStoreOp:
      ptr(handler)->handle(expr->as<LoadStoreOp>());
      return;
    case ExprType::MmaOp:
      ptr(handler)->handle(expr->as<MmaOp>());
      return;
    case ExprType::BroadcastOp:
      ptr(handler)->handle(expr->as<BroadcastOp>());
      return;

    case ExprType::Split:
      ptr(handler)->handle(expr->as<Split>());
      return;
    case ExprType::Merge:
      ptr(handler)->handle(expr->as<Merge>());
      return;
    case ExprType::Swizzle2D:
      ptr(handler)->handle(expr->as<Swizzle2D>());
      return;
    case ExprType::TransposeOp:
      ptr(handler)->handle(expr->as<TransposeOp>());
      return;
    case ExprType::ExpandOp:
      ptr(handler)->handle(expr->as<ExpandOp>());
      return;
    case ExprType::ShiftOp:
      ptr(handler)->handle(expr->as<ShiftOp>());
      return;
    case ExprType::GatherOp:
      ptr(handler)->handle(expr->as<GatherOp>());
      return;
    case ExprType::ViewAsScalar:
      ptr(handler)->handle(expr->as<ViewAsScalar>());
      return;
    case ExprType::ViewOp:
      ptr(handler)->handle(expr->as<ViewOp>());
      return;

    case ExprType::Allocate:
      ptr(handler)->handle(expr->as<kir::Allocate>());
      return;
    case ExprType::BlockSync:
      ptr(handler)->handle(expr->as<kir::BlockSync>());
      return;
    case ExprType::GridSync:
      ptr(handler)->handle(expr->as<kir::GridSync>());
      return;
    case ExprType::CpAsyncWait:
      ptr(handler)->handle(expr->as<kir::CpAsyncWait>());
      return;
    case ExprType::CpAsyncCommit:
      ptr(handler)->handle(expr->as<kir::CpAsyncCommit>());
      return;
    case ExprType::InitMagicZero:
      ptr(handler)->handle(expr->as<kir::InitMagicZero>());
      return;
    case ExprType::UpdateMagicZero:
      ptr(handler)->handle(expr->as<kir::UpdateMagicZero>());
      return;
    case ExprType::ForLoop:
      ptr(handler)->handle(expr->as<kir::ForLoop>());
      return;
    case ExprType::IfThenElse:
      ptr(handler)->handle(expr->as<kir::IfThenElse>());
      return;
    case ExprType::GridReduction:
      ptr(handler)->handle(expr->as<kir::GridReduction>());
      return;
    case ExprType::GroupedGridReduction:
      ptr(handler)->handle(expr->as<kir::GroupedGridReduction>());
      return;
    case ExprType::GridBroadcast:
      ptr(handler)->handle(expr->as<kir::GridBroadcast>());
      return;
    case ExprType::GridWelford:
      ptr(handler)->handle(expr->as<kir::GridWelford>());
      return;
    case ExprType::AllocateFusedReduction:
      ptr(handler)->handle(expr->as<kir::AllocateFusedReduction>());
      return;
    case ExprType::Swizzle2DInt:
      ptr(handler)->handle(expr->as<kir::Swizzle2DInt>());
      return;
    case ExprType::PairSelect:
      ptr(handler)->handle(expr->as<kir::PairSelect>());
      return;
    default:
      TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
  }
}

template <typename T>
void Statement::constDispatch(T handler, const Statement* stmt) {
  if (stmt->isVal()) {
    ptr(handler)->handle(stmt->as<Val>());
  } else if (stmt->isExpr()) {
    ptr(handler)->handle(stmt->as<Expr>());
  } else
    TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!");
}

/*
 * Generic mutatorDispatch for any handler that modifies the IR. This could be
 * a transformation on loop structures, or parallelizing a loop. This
 * mutatorDispatch is paired with a class that implements the functions:
 *   template <typename node_type>
 *   Statement* mutate(node_type* node)
 *
 * mutate should call:
 *   mutatorDispatch(this, node_to_dispatch)
 *
 * It could also implement:
 *   Statement* mutate(Statement* stmt) {
 *     mutatorDispatch(this, stmt);
 *   }
 *
 * And therefore dispatch should never call:
 *   ptr(mutator)->mutate(this->as<Statement>());
 */

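// Illustrative sketch (editorial example, not part of the original file): the
// handler contract described above only requires that the mutator type
// provide mutate() overloads for the node types reachable from
// mutatorDispatch. A catch-all member template satisfies every case below;
// real mutators in this codebase derive from OptOutMutator instead. The names
// here are hypothetical.
namespace {
struct CountingMutator {
  // Called by Val::mutatorDispatch / Expr::mutatorDispatch with the concrete
  // node pointer, e.g. Double* or UnaryOp*.
  template <typename NodeT>
  void mutate(NodeT*) {
    ++visited;
  }

  // Usage sketch (instantiates the templates defined in this file):
  //   CountingMutator m;
  //   Val::mutatorDispatch(&m, some_val);
  int visited = 0;
};
} // namespace
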
template <typename T>
void Val::mutatorDispatch(T mutator, Val* val) {
  switch (*(val->getValType())) {
    case ValType::Scalar:
      switch (*(val->getDataType())) {
        case DataType::Bool:
          ptr(mutator)->mutate(val->as<Bool>());
          return;
        case DataType::Double:
          ptr(mutator)->mutate(val->as<Double>());
          return;
        case DataType::Int:
          ptr(mutator)->mutate(val->as<Int>());
          return;
        case DataType::ComplexDouble:
          ptr(mutator)->mutate(val->as<ComplexDouble>());
          return;
        default:
          break;
      }
      break;
    case ValType::NamedScalar:
      ptr(mutator)->mutate(val->as<NamedScalar>());
      return;

    case ValType::IterDomain:
      ptr(mutator)->mutate(val->as<IterDomain>());
      return;
    case ValType::TensorDomain:
      ptr(mutator)->mutate(val->as<TensorDomain>());
      return;
    case ValType::TensorView:
      ptr(mutator)->mutate(val->as<TensorView>());
      return;
    case ValType::Predicate:
      ptr(mutator)->mutate(val->as<kir::Predicate>());
      return;
    case ValType::TensorIndex:
      ptr(mutator)->mutate(val->as<kir::TensorIndex>());
      return;
    case ValType::IntPair:
      ptr(mutator)->mutate(val->as<kir::IntPair>());
      return;
    default:
      break;
  }
  TORCH_INTERNAL_ASSERT(false, "Unknown valtype in dispatch!");
}

template <typename T>
void Expr::mutatorDispatch(T mutator, Expr* expr) {
  switch (*(expr->getExprType())) {
    case ExprType::UnaryOp:
      ptr(mutator)->mutate(expr->as<UnaryOp>());
      return;
    case ExprType::BinaryOp:
      ptr(mutator)->mutate(expr->as<BinaryOp>());
      return;
    case ExprType::TernaryOp:
      ptr(mutator)->mutate(expr->as<TernaryOp>());
      return;
    case ExprType::ReductionOp:
      ptr(mutator)->mutate(expr->as<ReductionOp>());
      return;
    case ExprType::GroupedReductionOp:
      ptr(mutator)->mutate(expr->as<GroupedReductionOp>());
      return;
    case ExprType::WelfordOp:
      ptr(mutator)->mutate(expr->as<WelfordOp>());
      return;
    case ExprType::LoadStoreOp:
      ptr(mutator)->mutate(expr->as<LoadStoreOp>());
      return;
    case ExprType::MmaOp:
      ptr(mutator)->mutate(expr->as<MmaOp>());
      return;
    case ExprType::BroadcastOp:
      ptr(mutator)->mutate(expr->as<BroadcastOp>());
      return;

    case ExprType::Split:
      ptr(mutator)->mutate(expr->as<Split>());
      return;
    case ExprType::Merge:
      ptr(mutator)->mutate(expr->as<Merge>());
      return;
    case ExprType::Swizzle2D:
      ptr(mutator)->mutate(expr->as<Swizzle2D>());
      return;
    case ExprType::TransposeOp:
      ptr(mutator)->mutate(expr->as<TransposeOp>());
      return;
    case ExprType::ExpandOp:
      ptr(mutator)->mutate(expr->as<ExpandOp>());
      return;
    case ExprType::ShiftOp:
      ptr(mutator)->mutate(expr->as<ShiftOp>());
      return;
    case ExprType::GatherOp:
      ptr(mutator)->mutate(expr->as<GatherOp>());
      return;
    case ExprType::ViewAsScalar:
      ptr(mutator)->mutate(expr->as<ViewAsScalar>());
      return;
    case ExprType::ViewOp:
      ptr(mutator)->mutate(expr->as<ViewOp>());
      return;

    case ExprType::Allocate:
      ptr(mutator)->mutate(expr->as<kir::Allocate>());
      return;
    case ExprType::BlockSync:
      ptr(mutator)->mutate(expr->as<kir::BlockSync>());
      return;
    case ExprType::GridSync:
      ptr(mutator)->mutate(expr->as<kir::GridSync>());
      return;
    case ExprType::CpAsyncWait:
      ptr(mutator)->mutate(expr->as<kir::CpAsyncWait>());
      return;
    case ExprType::CpAsyncCommit:
      ptr(mutator)->mutate(expr->as<kir::CpAsyncCommit>());
      return;
    case ExprType::InitMagicZero:
      ptr(mutator)->mutate(expr->as<kir::InitMagicZero>());
      return;
    case ExprType::UpdateMagicZero:
      ptr(mutator)->mutate(expr->as<kir::UpdateMagicZero>());
      return;
    case ExprType::ForLoop:
      ptr(mutator)->mutate(expr->as<kir::ForLoop>());
      return;
    case ExprType::IfThenElse:
      ptr(mutator)->mutate(expr->as<kir::IfThenElse>());
      return;
    case ExprType::GridReduction:
      ptr(mutator)->mutate(expr->as<kir::GridReduction>());
      return;
    case ExprType::GroupedGridReduction:
      ptr(mutator)->mutate(expr->as<kir::GroupedGridReduction>());
      return;
    case ExprType::GridBroadcast:
      ptr(mutator)->mutate(expr->as<kir::GridBroadcast>());
      return;
    case ExprType::GridWelford:
      ptr(mutator)->mutate(expr->as<kir::GridWelford>());
      return;
    case ExprType::AllocateFusedReduction:
      ptr(mutator)->mutate(expr->as<kir::AllocateFusedReduction>());
      return;
    case ExprType::Swizzle2DInt:
      ptr(mutator)->mutate(expr->as<kir::Swizzle2DInt>());
      return;
    case ExprType::PairSelect:
      ptr(mutator)->mutate(expr->as<kir::PairSelect>());
      return;
    default:
      TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
  }
}

template <typename T>
void Statement::mutatorDispatch(T mutator, Statement* stmt) {
  if (stmt->isVal()) {
    ptr(mutator)->mutate(stmt->as<Val>());
    return;
  }
  if (stmt->isExpr()) {
    ptr(mutator)->mutate(stmt->as<Expr>());
    return;
  }
  TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!");
}

/*
 * Handler template instantiations. These should only have to be done on base
 * classes. Actual visitors/mutators should inherit from these classes and
 * call ->dispatch(this) to avoid needing an explicit instantiation.
 */
template void Statement::dispatch(OptOutDispatch&, Statement*);
template void Statement::dispatch(OptOutDispatch*, Statement*);
template void Val::dispatch(OptOutDispatch&, Val*);
template void Val::dispatch(OptOutDispatch*, Val*);
template void Expr::dispatch(OptOutDispatch&, Expr*);
template void Expr::dispatch(OptOutDispatch*, Expr*);

template void Statement::dispatch(OptInDispatch, Statement*);
template void Statement::dispatch(OptInDispatch*, Statement*);
template void Val::dispatch(OptInDispatch, Val*);
template void Val::dispatch(OptInDispatch*, Val*);
template void Expr::dispatch(OptInDispatch, Expr*);
template void Expr::dispatch(OptInDispatch*, Expr*);

template void Statement::constDispatch(OptOutConstDispatch&, const Statement*);
template void Statement::constDispatch(OptOutConstDispatch*, const Statement*);
template void Val::constDispatch(OptOutConstDispatch&, const Val*);
template void Val::constDispatch(OptOutConstDispatch*, const Val*);
template void Expr::constDispatch(OptOutConstDispatch&, const Expr*);
template void Expr::constDispatch(OptOutConstDispatch*, const Expr*);

template void Statement::constDispatch(OptInConstDispatch&, const Statement*);
template void Statement::constDispatch(OptInConstDispatch*, const Statement*);
template void Val::constDispatch(OptInConstDispatch&, const Val*);
template void Val::constDispatch(OptInConstDispatch*, const Val*);
template void Expr::constDispatch(OptInConstDispatch&, const Expr*);
template void Expr::constDispatch(OptInConstDispatch*, const Expr*);

template void Statement::mutatorDispatch(OptOutMutator&, Statement*);
template void Statement::mutatorDispatch(OptOutMutator*, Statement*);
template void Val::mutatorDispatch(OptOutMutator&, Val*);
template void Val::mutatorDispatch(OptOutMutator*, Val*);
template void Expr::mutatorDispatch(OptOutMutator&, Expr*);
template void Expr::mutatorDispatch(OptOutMutator*, Expr*);

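// Illustrative sketch (editorial example, not part of the original file): as
// the comment above notes, concrete visitors reuse these base-class
// instantiations by inheriting and overriding the virtual handle() overloads;
// no further explicit instantiation is needed. With the OptIn variants, any
// node type that is not overridden trips the assert in unhandled() below. The
// class name is hypothetical, and this assumes OptInDispatch inherits the
// handle() overloads from OptOutDispatch, as declared in dispatch.h.
namespace {
class SplitOnlyVisitor : public OptInDispatch {
 public:
  using OptInDispatch::handle;

  // Only Split expressions are accepted; any other node reaching this visitor
  // asserts via OptInDispatch::unhandled (defined below).
  void handle(Split*) override {}
};
} // namespace
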
void OptOutDispatch::handle(Statement* s) {
  Statement::dispatch(this, s);
}

void OptOutDispatch::handle(Expr* e) {
  Expr::dispatch(this, e);
}

void OptOutDispatch::handle(Val* v) {
  Val::dispatch(this, v);
}

void OptOutConstDispatch::handle(const Statement* s) {
  Statement::constDispatch(this, s);
}

void OptOutConstDispatch::handle(const Expr* e) {
  Expr::constDispatch(this, e);
}

void OptOutConstDispatch::handle(const Val* v) {
  Val::constDispatch(this, v);
}

void OptInConstDispatch::unhandled(const Statement* stmt) {
  if (stmt->isExpr()) {
    TORCH_INTERNAL_ASSERT(
        false, "Handle not overridden for ", stmt->getExprType().value(), ".");
  } else if (stmt->isVal()) {
    TORCH_INTERNAL_ASSERT(
        false, "Handle not overridden for ", stmt->getValType().value(), ".");
  } else {
    TORCH_INTERNAL_ASSERT(false, "Unrecognized statement type.");
  }
}

void OptInDispatch::unhandled(Statement* stmt) {
  if (stmt->isExpr()) {
    TORCH_INTERNAL_ASSERT(
        false, "Handle not overridden for ", stmt->getExprType().value(), ".");
  } else if (stmt->isVal()) {
    TORCH_INTERNAL_ASSERT(
        false, "Handle not overridden for ", stmt->getValType().value(), ".");
  } else {
    TORCH_INTERNAL_ASSERT(false, "Unrecognized statement type.");
  }
}

// Vals
void OptOutConstDispatch::handle(const Bool* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const Double* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const Int* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const ComplexDouble* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const NamedScalar* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const IterDomain* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const TensorDomain* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const TensorView* stmt) {
  unhandled(stmt);
}

void OptOutConstDispatch::handle(const kir::Predicate* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::TensorIndex* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::IntPair* stmt) {
  unhandled(stmt);
}

// Exprs
void OptOutConstDispatch::handle(const UnaryOp* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const BinaryOp* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const TernaryOp* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const ReductionOp* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const GroupedReductionOp* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const WelfordOp* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const LoadStoreOp* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const MmaOp* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const BroadcastOp* stmt) {
  unhandled(stmt);
}

void OptOutConstDispatch::handle(const Split* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const Merge* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const Swizzle2D* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const TransposeOp* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const ExpandOp* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const ShiftOp* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const GatherOp* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const ViewAsScalar* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const ViewOp* stmt) {
  unhandled(stmt);
}

void OptOutConstDispatch::handle(const kir::Allocate* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::BlockSync* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::GridSync* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::CpAsyncWait* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::CpAsyncCommit* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::InitMagicZero* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::UpdateMagicZero* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::ForLoop* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::IfThenElse* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::GridReduction* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::GroupedGridReduction* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::GridBroadcast* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::GridWelford* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::AllocateFusedReduction* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::Swizzle2DInt* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::PairSelect* stmt) {
  unhandled(stmt);
}

void OptOutDispatch::unhandled(Statement*) {}

// Vals
void OptOutDispatch::handle(Bool* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(Double* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(Int* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(ComplexDouble* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(NamedScalar* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(IterDomain* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(TensorDomain* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(TensorView* stmt) {
  unhandled(stmt);
}

void OptOutDispatch::handle(kir::Predicate* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::TensorIndex* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::IntPair* stmt) {
  unhandled(stmt);
}

// Exprs
void OptOutDispatch::handle(UnaryOp* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(BinaryOp* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(TernaryOp* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(ReductionOp* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(GroupedReductionOp* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(WelfordOp* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(LoadStoreOp* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(MmaOp* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(BroadcastOp* stmt) {
  unhandled(stmt);
}

void OptOutDispatch::handle(Split* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(Merge* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(Swizzle2D* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(TransposeOp* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(ExpandOp* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(ShiftOp* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(GatherOp* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(ViewAsScalar* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(ViewOp* stmt) {
  unhandled(stmt);
}

void OptOutDispatch::handle(kir::Allocate* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::BlockSync* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::GridSync* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::CpAsyncWait* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::CpAsyncCommit* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::InitMagicZero* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::UpdateMagicZero* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::ForLoop* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::IfThenElse* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::GridReduction* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::GroupedGridReduction* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::GridBroadcast* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::GridWelford* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::AllocateFusedReduction* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::Swizzle2DInt* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(kir::PairSelect* stmt) {
  unhandled(stmt);
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch