mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Adds reduction support for the code generator. Reductions are fully supported with split/merge/reorder/rfactor/computeAt/unroll operators. There is also cross-thread (intra-block) reduction support. The two remaining pieces missing for reduction support are: - Safety: if cross-thread reduction was used, child operators shouldn't be able to bind that thread dim anymore. - Cross-block reduction: we will want inter-block reduction support to match parity with tensor iterator. The PR also provides FP16 support for fusions now: we insert casts to FP32 on FP16 inputs, and casts to FP16 on FP16 outputs. Also working towards reductions and shape inference for reductions in the fusion pass. Pull Request resolved: https://github.com/pytorch/pytorch/pull/38627 Reviewed By: albanD Differential Revision: D21663196 Pulled By: soumith fbshipit-source-id: 3ff2df563f86c39cd5821ab9c1148149e5172a9e
431 lines
14 KiB
C++
431 lines
14 KiB
C++
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/type.h>

#include <torch/csrc/jit/codegen/cuda/dispatch.h>

#include <memory>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
namespace fuser {
|
|
|
|
// Normalizes a dispatch handler to a pointer, so the dispatch templates below
// can be instantiated with either a handler object or a pointer to one.

// Handler passed as an object: return its address. std::addressof is used
// rather than `&obj` so a type that overloads unary operator& cannot redirect
// the dispatch target.
template <typename T>
T* ptr(T& obj) {
  return std::addressof(obj);
}

// Handler passed as a pointer: return it unchanged. For an lvalue pointer
// argument this overload is more specialized than ptr(T&), so it is selected.
template <typename T>
T* ptr(T* obj) {
  return obj;
}
|
|
|
|
/*
 * Generic dispatch for any handler that does not modify the IR directly.
 * For example, we may want to walk the graph to construct a topologically
 * sorted set of exprs; this doesn't modify the IR. We also use this to print
 * the IR itself.
 *
 * This dispatch is paired with a class that implements the functions:
 *   template <typename node_type>
 *   void handle(node_type* node)
 *
 * handle should call:
 *   dispatch(this, node_to_dispatch)
 *
 * It could also implement:
 *   void handle(Statement* stmt) {
 *     dispatch(this, stmt);
 *   }
 *
 * And therefore dispatch should never call:
 *   ptr(handler)->handle(static_cast<Statement*>(stmt));
 * since handle(Statement*) itself calls dispatch and the two would recurse
 * forever.
 */
|
|
|
|
template <typename T>
|
|
void Val::dispatch(T handler, Val* val) {
|
|
switch (*(val->getValType())) {
|
|
case ValType::Scalar:
|
|
switch (*(val->getDataType())) {
|
|
case DataType::Bool:
|
|
ptr(handler)->handle(static_cast<Bool*>(val));
|
|
return;
|
|
case DataType::Float:
|
|
ptr(handler)->handle(static_cast<Float*>(val));
|
|
return;
|
|
case DataType::Half:
|
|
ptr(handler)->handle(static_cast<Half*>(val));
|
|
return;
|
|
case DataType::Int:
|
|
ptr(handler)->handle(static_cast<Int*>(val));
|
|
return;
|
|
default:
|
|
break;
|
|
}
|
|
break;
|
|
case ValType::IterDomain:
|
|
ptr(handler)->handle(static_cast<IterDomain*>(val));
|
|
return;
|
|
case ValType::TensorDomain:
|
|
ptr(handler)->handle(static_cast<TensorDomain*>(val));
|
|
return;
|
|
case ValType::TensorView:
|
|
ptr(handler)->handle(static_cast<TensorView*>(val));
|
|
return;
|
|
case ValType::TensorIndex:
|
|
ptr(handler)->handle(static_cast<TensorIndex*>(val));
|
|
return;
|
|
case ValType::NamedScalar:
|
|
ptr(handler)->handle(static_cast<NamedScalar*>(val));
|
|
return;
|
|
default:
|
|
break;
|
|
}
|
|
TORCH_INTERNAL_ASSERT(false, "Unknown valtype in dispatch!");
|
|
}
|
|
|
|
template <typename T>
|
|
void Expr::dispatch(T handler, Expr* expr) {
|
|
switch (*(expr->getExprType())) {
|
|
case ExprType::Split:
|
|
ptr(handler)->handle(static_cast<Split*>(expr));
|
|
return;
|
|
case ExprType::Merge:
|
|
ptr(handler)->handle(static_cast<Merge*>(expr));
|
|
return;
|
|
case ExprType::Reorder:
|
|
ptr(handler)->handle(static_cast<Reorder*>(expr));
|
|
return;
|
|
case ExprType::UnaryOp:
|
|
ptr(handler)->handle(static_cast<UnaryOp*>(expr));
|
|
return;
|
|
case ExprType::BinaryOp:
|
|
ptr(handler)->handle(static_cast<BinaryOp*>(expr));
|
|
return;
|
|
case ExprType::TernaryOp:
|
|
ptr(handler)->handle(static_cast<TernaryOp*>(expr));
|
|
return;
|
|
case ExprType::ReductionOp:
|
|
ptr(handler)->handle(static_cast<ReductionOp*>(expr));
|
|
return;
|
|
case ExprType::ForLoop:
|
|
ptr(handler)->handle(static_cast<ForLoop*>(expr));
|
|
return;
|
|
case ExprType::IfThenElse:
|
|
ptr(handler)->handle(static_cast<IfThenElse*>(expr));
|
|
return;
|
|
case ExprType::Allocate:
|
|
ptr(handler)->handle(static_cast<Allocate*>(expr));
|
|
return;
|
|
default:
|
|
TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
|
|
}
|
|
}
|
|
|
|
template <typename T>
|
|
void Statement::dispatch(T handler, Statement* stmt) {
|
|
if (stmt->isVal()) {
|
|
ptr(handler)->handle(static_cast<Val*>(stmt));
|
|
} else if (stmt->isExpr()) {
|
|
ptr(handler)->handle(static_cast<Expr*>(stmt));
|
|
} else
|
|
TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!");
|
|
}
|
|
|
|
template <typename T>
|
|
void Val::constDispatch(T handler, const Val* const val) {
|
|
switch (*(val->getValType())) {
|
|
case ValType::Scalar:
|
|
switch (*(val->getDataType())) {
|
|
case DataType::Bool:
|
|
ptr(handler)->handle(static_cast<const Bool* const>(val));
|
|
return;
|
|
case DataType::Float:
|
|
ptr(handler)->handle(static_cast<const Float*>(val));
|
|
return;
|
|
case DataType::Half:
|
|
ptr(handler)->handle(static_cast<const Half* const>(val));
|
|
return;
|
|
case DataType::Int:
|
|
ptr(handler)->handle(static_cast<const Int*>(val));
|
|
return;
|
|
default:
|
|
break;
|
|
}
|
|
break;
|
|
case ValType::IterDomain:
|
|
ptr(handler)->handle(static_cast<const IterDomain*>(val));
|
|
return;
|
|
case ValType::TensorDomain:
|
|
ptr(handler)->handle(static_cast<const TensorDomain*>(val));
|
|
return;
|
|
case ValType::TensorView:
|
|
ptr(handler)->handle(static_cast<const TensorView*>(val));
|
|
return;
|
|
case ValType::TensorIndex:
|
|
ptr(handler)->handle(static_cast<const TensorIndex*>(val));
|
|
return;
|
|
case ValType::NamedScalar:
|
|
ptr(handler)->handle(static_cast<const NamedScalar*>(val));
|
|
return;
|
|
default:
|
|
break;
|
|
}
|
|
TORCH_INTERNAL_ASSERT(false, "Unknown valtype in dispatch!");
|
|
}
|
|
|
|
template <typename T>
|
|
void Expr::constDispatch(T handler, const Expr* expr) {
|
|
switch (*(expr->getExprType())) {
|
|
case ExprType::Split:
|
|
ptr(handler)->handle(static_cast<const Split*>(expr));
|
|
return;
|
|
case ExprType::Merge:
|
|
ptr(handler)->handle(static_cast<const Merge*>(expr));
|
|
return;
|
|
case ExprType::Reorder:
|
|
ptr(handler)->handle(static_cast<const Reorder*>(expr));
|
|
return;
|
|
case ExprType::UnaryOp:
|
|
ptr(handler)->handle(static_cast<const UnaryOp*>(expr));
|
|
return;
|
|
case ExprType::BinaryOp:
|
|
ptr(handler)->handle(static_cast<const BinaryOp*>(expr));
|
|
return;
|
|
case ExprType::TernaryOp:
|
|
ptr(handler)->handle(static_cast<const TernaryOp* const>(expr));
|
|
return;
|
|
case ExprType::ReductionOp:
|
|
ptr(handler)->handle(static_cast<const ReductionOp* const>(expr));
|
|
return;
|
|
case ExprType::ForLoop:
|
|
ptr(handler)->handle(static_cast<const ForLoop*>(expr));
|
|
return;
|
|
case ExprType::IfThenElse:
|
|
ptr(handler)->handle(static_cast<const IfThenElse*>(expr));
|
|
return;
|
|
case ExprType::Allocate:
|
|
ptr(handler)->handle(static_cast<const Allocate*>(expr));
|
|
return;
|
|
default:
|
|
TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
|
|
}
|
|
}
|
|
|
|
template <typename T>
|
|
void Statement::constDispatch(T handler, const Statement* stmt) {
|
|
if (stmt->isVal()) {
|
|
ptr(handler)->handle(static_cast<const Val*>(stmt));
|
|
} else if (stmt->isExpr()) {
|
|
ptr(handler)->handle(static_cast<const Expr*>(stmt));
|
|
} else
|
|
TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!");
|
|
}
|
|
|
|
/*
 * Generic mutatorDispatch for any handler that modifies the IR. This could be
 * a transformation on loop structures, or parallelizing a loop. This
 * mutatorDispatch is paired with a class that implements the functions:
 *   template <typename node_type>
 *   Statement* mutate(node_type* node)
 *
 * mutate should call:
 *   Statement::mutatorDispatch(this, node_to_dispatch)
 *
 * It could also implement:
 *   Statement* mutate(Statement* stmt) {
 *     return Statement::mutatorDispatch(this, stmt);
 *   }
 *
 * And therefore mutatorDispatch should never call:
 *   ptr(mutator)->mutate(static_cast<Statement*>(stmt));
 * since mutate(Statement*) itself calls mutatorDispatch and the two would
 * recurse forever.
 */
|
|
template <typename T>
|
|
Statement* Val::mutatorDispatch(T mutator, Val* val) {
|
|
switch (*(val->getValType())) {
|
|
case ValType::Scalar:
|
|
switch (*(val->getDataType())) {
|
|
case DataType::Bool:
|
|
return ptr(mutator)->mutate(static_cast<Bool*>(val));
|
|
case DataType::Float:
|
|
return ptr(mutator)->mutate(static_cast<Float*>(val));
|
|
case DataType::Half:
|
|
return ptr(mutator)->mutate(static_cast<Half*>(val));
|
|
case DataType::Int:
|
|
return ptr(mutator)->mutate(static_cast<Int*>(val));
|
|
default:
|
|
break;
|
|
}
|
|
break;
|
|
case ValType::IterDomain:
|
|
return ptr(mutator)->mutate(static_cast<IterDomain*>(val));
|
|
case ValType::TensorDomain:
|
|
return ptr(mutator)->mutate(static_cast<TensorDomain*>(val));
|
|
case ValType::TensorView:
|
|
return ptr(mutator)->mutate(static_cast<TensorView*>(val));
|
|
case ValType::TensorIndex:
|
|
return ptr(mutator)->mutate(static_cast<TensorIndex*>(val));
|
|
case ValType::NamedScalar:
|
|
return ptr(mutator)->mutate(static_cast<NamedScalar*>(val));
|
|
default:
|
|
break;
|
|
}
|
|
TORCH_INTERNAL_ASSERT(false, "Unknown valtype in dispatch!");
|
|
}
|
|
|
|
template <typename T>
|
|
Statement* Expr::mutatorDispatch(T mutator, Expr* expr) {
|
|
switch (*(expr->getExprType())) {
|
|
case ExprType::Split:
|
|
return ptr(mutator)->mutate(static_cast<Split*>(expr));
|
|
case ExprType::Merge:
|
|
return ptr(mutator)->mutate(static_cast<Merge*>(expr));
|
|
case ExprType::Reorder:
|
|
return ptr(mutator)->mutate(static_cast<Reorder*>(expr));
|
|
case ExprType::UnaryOp:
|
|
return ptr(mutator)->mutate(static_cast<UnaryOp*>(expr));
|
|
case ExprType::BinaryOp:
|
|
return ptr(mutator)->mutate(static_cast<BinaryOp*>(expr));
|
|
case ExprType::TernaryOp:
|
|
return ptr(mutator)->mutate(static_cast<TernaryOp*>(expr));
|
|
case ExprType::ReductionOp:
|
|
return ptr(mutator)->mutate(static_cast<ReductionOp*>(expr));
|
|
case ExprType::ForLoop:
|
|
return ptr(mutator)->mutate(static_cast<ForLoop*>(expr));
|
|
case ExprType::IfThenElse:
|
|
return ptr(mutator)->mutate(static_cast<IfThenElse*>(expr));
|
|
case ExprType::Allocate:
|
|
return ptr(mutator)->mutate(static_cast<Allocate*>(expr));
|
|
default:
|
|
TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
|
|
}
|
|
}
|
|
|
|
template <typename T>
|
|
Statement* Statement::mutatorDispatch(T mutator, Statement* stmt) {
|
|
if (stmt->isVal()) {
|
|
return ptr(mutator)->mutate(static_cast<Val*>(stmt));
|
|
}
|
|
if (stmt->isExpr()) {
|
|
return ptr(mutator)->mutate(static_cast<Expr*>(stmt));
|
|
}
|
|
TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!");
|
|
}
|
|
|
|
/*
 * Handler template instantiations. These should only have to be done on base
 * classes. Actual visitors/mutators should inherit from these classes and
 * route nodes through the base-class handle()/mutate() entry points (which
 * call e.g. Statement::dispatch(this, stmt)) to avoid needing an explicit
 * instantiation.
 */
|
|
template void Statement::dispatch(OptOutDispatch, Statement*);
|
|
template void Statement::dispatch(OptOutDispatch*, Statement*);
|
|
template void Val::dispatch(OptOutDispatch, Val*);
|
|
template void Val::dispatch(OptOutDispatch*, Val*);
|
|
template void Expr::dispatch(OptOutDispatch, Expr*);
|
|
template void Expr::dispatch(OptOutDispatch*, Expr*);
|
|
|
|
template void Statement::dispatch(OptInDispatch, Statement*);
|
|
template void Statement::dispatch(OptInDispatch*, Statement*);
|
|
template void Val::dispatch(OptInDispatch, Val*);
|
|
template void Val::dispatch(OptInDispatch*, Val*);
|
|
template void Expr::dispatch(OptInDispatch, Expr*);
|
|
template void Expr::dispatch(OptInDispatch*, Expr*);
|
|
|
|
template void Statement::constDispatch(
|
|
OptOutConstDispatch,
|
|
const Statement* const);
|
|
template void Statement::constDispatch(
|
|
OptOutConstDispatch*,
|
|
const Statement* const);
|
|
template void Val::constDispatch(OptOutConstDispatch, const Val* const);
|
|
template void Val::constDispatch(OptOutConstDispatch*, const Val* const);
|
|
template void Expr::constDispatch(OptOutConstDispatch, const Expr* const);
|
|
template void Expr::constDispatch(OptOutConstDispatch*, const Expr* const);
|
|
|
|
template void Statement::constDispatch(
|
|
OptInConstDispatch,
|
|
const Statement* const);
|
|
template void Statement::constDispatch(
|
|
OptInConstDispatch*,
|
|
const Statement* const);
|
|
template void Val::constDispatch(OptInConstDispatch, const Val* const);
|
|
template void Val::constDispatch(OptInConstDispatch*, const Val* const);
|
|
template void Expr::constDispatch(OptInConstDispatch, const Expr* const);
|
|
template void Expr::constDispatch(OptInConstDispatch*, const Expr* const);
|
|
|
|
template Statement* Statement::mutatorDispatch(OptOutMutator, Statement*);
|
|
template Statement* Statement::mutatorDispatch(OptOutMutator*, Statement*);
|
|
template Statement* Val::mutatorDispatch(OptOutMutator, Val*);
|
|
template Statement* Val::mutatorDispatch(OptOutMutator*, Val*);
|
|
template Statement* Expr::mutatorDispatch(OptOutMutator, Expr*);
|
|
template Statement* Expr::mutatorDispatch(OptOutMutator*, Expr*);
|
|
|
|
template Statement* Statement::mutatorDispatch(OptInMutator, Statement*);
|
|
template Statement* Statement::mutatorDispatch(OptInMutator*, Statement*);
|
|
template Statement* Val::mutatorDispatch(OptInMutator, Val*);
|
|
template Statement* Val::mutatorDispatch(OptInMutator*, Val*);
|
|
template Statement* Expr::mutatorDispatch(OptInMutator, Expr*);
|
|
template Statement* Expr::mutatorDispatch(OptInMutator*, Expr*);
|
|
|
|
void OptOutDispatch::handle(Statement* s) {
|
|
Statement::dispatch(this, s);
|
|
}
|
|
void OptOutDispatch::handle(Expr* e) {
|
|
Expr::dispatch(this, e);
|
|
}
|
|
void OptOutDispatch::handle(Val* v) {
|
|
Val::dispatch(this, v);
|
|
}
|
|
|
|
void OptInDispatch::handle(Statement* s) {
|
|
Statement::dispatch(this, s);
|
|
}
|
|
void OptInDispatch::handle(Expr* e) {
|
|
Expr::dispatch(this, e);
|
|
}
|
|
void OptInDispatch::handle(Val* v) {
|
|
Val::dispatch(this, v);
|
|
}
|
|
|
|
void OptOutConstDispatch::handle(const Statement* const s) {
|
|
Statement::constDispatch(this, s);
|
|
}
|
|
void OptOutConstDispatch::handle(const Expr* const e) {
|
|
Expr::constDispatch(this, e);
|
|
}
|
|
void OptOutConstDispatch::handle(const Val* const v) {
|
|
Val::constDispatch(this, v);
|
|
}
|
|
|
|
void OptInConstDispatch::handle(const Statement* const s) {
|
|
Statement::constDispatch(this, s);
|
|
}
|
|
void OptInConstDispatch::handle(const Expr* const e) {
|
|
Expr::constDispatch(this, e);
|
|
}
|
|
void OptInConstDispatch::handle(const Val* const v) {
|
|
Val::constDispatch(this, v);
|
|
}
|
|
|
|
// OptInMutator entry points. Statements and Exprs are routed through their
// static mutatorDispatch. A Val first consults `mutations`, so a value that has
// already been mutated resolves to its recorded mutation.
Statement* OptInMutator::mutate(Statement* s) {
  return Statement::mutatorDispatch(this, s);
}

Statement* OptInMutator::mutate(Expr* e) {
  return Expr::mutatorDispatch(this, e);
}

Statement* OptInMutator::mutate(Val* v) {
  // If value is already mutated, return the mutation. Reuse the iterator from
  // find() instead of performing a second lookup through operator[].
  const auto it = mutations.find(v);
  if (it != mutations.end())
    return it->second;
  return Val::mutatorDispatch(this, v);
}
|
|
|
|
// OptOutMutator entry points. Statements and Exprs are routed through their
// static mutatorDispatch. A Val first consults `mutations`, so a value that has
// already been mutated resolves to its recorded mutation.
Statement* OptOutMutator::mutate(Statement* s) {
  return Statement::mutatorDispatch(this, s);
}

Statement* OptOutMutator::mutate(Expr* e) {
  return Expr::mutatorDispatch(this, e);
}

Statement* OptOutMutator::mutate(Val* v) {
  // If value is already mutated, return the mutation. Reuse the iterator from
  // find() instead of performing a second lookup through operator[].
  const auto it = mutations.find(v);
  if (it != mutations.end())
    return it->second;
  return Val::mutatorDispatch(this, v);
}
|
|
|
|
} // namespace fuser
|
|
} // namespace jit
|
|
} // namespace torch
|