Summary: We have quite a few things going on, and are preparing a push back to upstream so we don't get too desynced.

- Major refactor of transform replay. It is now far more robust and fixes bugs discovered in reductions. This prepares for an extension to explicit broadcast ops, which will be the last major memory pattern needed for op coverage. Broadcast ops will allow us to express up to, and potentially beyond, norms and gemms.
- Initial runtime expression evaluator. This allows us to evaluate expressions at runtime. It will be useful for determining our grid/block layout at runtime, so we don't have to manually compute them according to the code we're trying to generate.
- Move to int64 and double for scalar representations to match PyTorch JIT.
- Improvements to the codegen interface: we now return a Tensor-like object instead of the parent class Val.
- Add `addcmul` and `lerp` ops.
- General updates, fixes, test additions, and test improvements.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/39579
Differential Revision: D21974001
Pulled By: soumith
fbshipit-source-id: 7f7ccc91593466e948f3ce90f8f9b7fbc5c28de2
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/type.h>

#include <torch/csrc/jit/codegen/cuda/dispatch.h>

namespace torch {
namespace jit {
namespace fuser {

template <typename T>
T* ptr(T& obj) {
  return &obj;
}

template <typename T>
T* ptr(T* obj) {
  return obj;
}
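/*
 * The ptr() overloads above exist so every dispatch below can accept its
 * handler either by value or by pointer: ptr() normalizes both spellings to
 * a pointer before the handle()/mutate() call (which is why each template
 * below is instantiated for both T and T*). A standalone sketch of the
 * idiom; Printer and dispatchExample are illustrative only:
 *
 *   struct Printer {
 *     void handle(int x) { std::cout << x << "\n"; }
 *   };
 *
 *   template <typename T>
 *   void dispatchExample(T handler, int x) {
 *     ptr(handler)->handle(x); // works for Printer and Printer* alike
 *   }
 *
 *   Printer p;
 *   dispatchExample(p, 1);  // T = Printer:  ptr(T&) takes the address
 *   dispatchExample(&p, 2); // T = Printer*: ptr(T*) passes it through
 */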
/*
 * Generic dispatch for any handler that does not modify the IR directly.
 * For example, we may want to walk the graph to construct a topologically
 * sorted set of exprs. This doesn't modify the IR directly. We also use this
 * to print the IR itself.
 * This dispatch is paired with a class that implements the function:
 *   template <typename node_type>
 *   int handler(node_type* node)
 *
 * handler should call:
 *   dispatch(this, node_to_dispatch)
 *
 * It could also implement:
 *   int handler(Statement* stmt) {
 *     dispatch(this, stmt);
 *   }
 *
 * And therefore dispatch should never call:
 *   ptr(mutator)->handle(static_cast<Statement*>(this));
 */
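/*
 * A minimal sketch of how a concrete visitor pairs with this dispatch,
 * assuming dispatch.h declares virtual handle() overloads on OptOutDispatch
 * for each concrete node type (as the explicit instantiations at the bottom
 * of this file suggest). CollectTensorViews is hypothetical:
 *
 *   class CollectTensorViews : public OptOutDispatch {
 *    public:
 *     std::vector<TensorView*> views;
 *
 *     // Only this overload is customized; every other node type falls
 *     // through to the no-op defaults in OptOutDispatch.
 *     void handle(TensorView* tv) override {
 *       views.push_back(tv);
 *     }
 *   };
 *
 *   // Usage: handle(stmt) resolves the runtime type of stmt via
 *   // Statement::dispatch/Val::dispatch and calls back into the most
 *   // specific handle() overload.
 *   CollectTensorViews visitor;
 *   visitor.handle(some_statement); // some_statement is hypothetical
 */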
template <typename T>
void Val::dispatch(T handler, Val* val) {
  switch (*(val->getValType())) {
    case ValType::Scalar:
      switch (*(val->getDataType())) {
        case DataType::Bool:
          ptr(handler)->handle(static_cast<Bool*>(val));
          return;
        case DataType::Float:
          ptr(handler)->handle(static_cast<Float*>(val));
          return;
        case DataType::Half:
          ptr(handler)->handle(static_cast<Half*>(val));
          return;
        case DataType::Int:
          ptr(handler)->handle(static_cast<Int*>(val));
          return;
        default:
          break;
      }
      break;
    case ValType::IterDomain:
      ptr(handler)->handle(static_cast<IterDomain*>(val));
      return;
    case ValType::TensorDomain:
      ptr(handler)->handle(static_cast<TensorDomain*>(val));
      return;
    case ValType::TensorView:
      ptr(handler)->handle(static_cast<TensorView*>(val));
      return;
    case ValType::TensorIndex:
      ptr(handler)->handle(static_cast<TensorIndex*>(val));
      return;
    case ValType::NamedScalar:
      ptr(handler)->handle(static_cast<NamedScalar*>(val));
      return;
    default:
      break;
  }
  TORCH_INTERNAL_ASSERT(false, "Unknown valtype in dispatch!");
}
template <typename T>
void Expr::dispatch(T handler, Expr* expr) {
  switch (*(expr->getExprType())) {
    case ExprType::Split:
      ptr(handler)->handle(static_cast<Split*>(expr));
      return;
    case ExprType::Merge:
      ptr(handler)->handle(static_cast<Merge*>(expr));
      return;
    case ExprType::UnaryOp:
      ptr(handler)->handle(static_cast<UnaryOp*>(expr));
      return;
    case ExprType::BinaryOp:
      ptr(handler)->handle(static_cast<BinaryOp*>(expr));
      return;
    case ExprType::TernaryOp:
      ptr(handler)->handle(static_cast<TernaryOp*>(expr));
      return;
    case ExprType::ReductionOp:
      ptr(handler)->handle(static_cast<ReductionOp*>(expr));
      return;
    case ExprType::BroadcastOp:
      ptr(handler)->handle(static_cast<BroadcastOp*>(expr));
      return;
    case ExprType::ForLoop:
      ptr(handler)->handle(static_cast<ForLoop*>(expr));
      return;
    case ExprType::IfThenElse:
      ptr(handler)->handle(static_cast<IfThenElse*>(expr));
      return;
    case ExprType::Allocate:
      ptr(handler)->handle(static_cast<Allocate*>(expr));
      return;
    default:
      TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
  }
}
template <typename T>
void Statement::dispatch(T handler, Statement* stmt) {
  if (stmt->isVal()) {
    ptr(handler)->handle(static_cast<Val*>(stmt));
  } else if (stmt->isExpr()) {
    ptr(handler)->handle(static_cast<Expr*>(stmt));
  } else {
    TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!");
  }
}
template <typename T>
void Val::constDispatch(T handler, const Val* val) {
  switch (*(val->getValType())) {
    case ValType::Scalar:
      switch (*(val->getDataType())) {
        case DataType::Bool:
          ptr(handler)->handle(static_cast<const Bool*>(val));
          return;
        case DataType::Float:
          ptr(handler)->handle(static_cast<const Float*>(val));
          return;
        case DataType::Half:
          ptr(handler)->handle(static_cast<const Half*>(val));
          return;
        case DataType::Int:
          ptr(handler)->handle(static_cast<const Int*>(val));
          return;
        default:
          break;
      }
      break;
    case ValType::IterDomain:
      ptr(handler)->handle(static_cast<const IterDomain*>(val));
      return;
    case ValType::TensorDomain:
      ptr(handler)->handle(static_cast<const TensorDomain*>(val));
      return;
    case ValType::TensorView:
      ptr(handler)->handle(static_cast<const TensorView*>(val));
      return;
    case ValType::TensorIndex:
      ptr(handler)->handle(static_cast<const TensorIndex*>(val));
      return;
    case ValType::NamedScalar:
      ptr(handler)->handle(static_cast<const NamedScalar*>(val));
      return;
    default:
      break;
  }
  TORCH_INTERNAL_ASSERT(false, "Unknown valtype in dispatch!");
}
template <typename T>
void Expr::constDispatch(T handler, const Expr* expr) {
  switch (*(expr->getExprType())) {
    case ExprType::Split:
      ptr(handler)->handle(static_cast<const Split*>(expr));
      return;
    case ExprType::Merge:
      ptr(handler)->handle(static_cast<const Merge*>(expr));
      return;
    case ExprType::UnaryOp:
      ptr(handler)->handle(static_cast<const UnaryOp*>(expr));
      return;
    case ExprType::BinaryOp:
      ptr(handler)->handle(static_cast<const BinaryOp*>(expr));
      return;
    case ExprType::TernaryOp:
      ptr(handler)->handle(static_cast<const TernaryOp*>(expr));
      return;
    case ExprType::ReductionOp:
      ptr(handler)->handle(static_cast<const ReductionOp*>(expr));
      return;
    case ExprType::BroadcastOp:
      ptr(handler)->handle(static_cast<const BroadcastOp*>(expr));
      return;
    case ExprType::ForLoop:
      ptr(handler)->handle(static_cast<const ForLoop*>(expr));
      return;
    case ExprType::IfThenElse:
      ptr(handler)->handle(static_cast<const IfThenElse*>(expr));
      return;
    case ExprType::Allocate:
      ptr(handler)->handle(static_cast<const Allocate*>(expr));
      return;
    default:
      TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
  }
}
template <typename T>
void Statement::constDispatch(T handler, const Statement* stmt) {
  if (stmt->isVal()) {
    ptr(handler)->handle(static_cast<const Val*>(stmt));
  } else if (stmt->isExpr()) {
    ptr(handler)->handle(static_cast<const Expr*>(stmt));
  } else {
    TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!");
  }
}
/*
 * Generic mutatorDispatch for any handler that modifies the IR. This could be
 * a transformation on loop structures, or parallelizing a loop. This
 * mutatorDispatch is paired with a class that implements the function:
 *   template <typename node_type>
 *   Statement* mutate(node_type* node)
 *
 * mutate should call:
 *   mutatorDispatch(this, node_to_dispatch)
 *
 * It could also implement:
 *   Statement* mutate(Statement* stmt) {
 *     return stmt->mutatorDispatch(this);
 *   }
 *
 * And therefore dispatch should never call:
 *   ptr(mutator)->mutate(static_cast<Statement*>(this));
 */
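/*
 * A minimal sketch of a concrete mutator pairing with this dispatch,
 * assuming dispatch.h declares virtual mutate() overloads on OptOutMutator
 * for each concrete node type. ZeroAllFloats and the Float constructor call
 * are illustrative only:
 *
 *   class ZeroAllFloats : public OptOutMutator {
 *    public:
 *     // Replace every Float scalar; all other node types fall through to
 *     // the identity defaults in OptOutMutator.
 *     Statement* mutate(Float* f) override {
 *       return new Float(0.0);
 *     }
 *   };
 */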
template <typename T>
Statement* Val::mutatorDispatch(T mutator, Val* val) {
  switch (*(val->getValType())) {
    case ValType::Scalar:
      switch (*(val->getDataType())) {
        case DataType::Bool:
          return ptr(mutator)->mutate(static_cast<Bool*>(val));
        case DataType::Float:
          return ptr(mutator)->mutate(static_cast<Float*>(val));
        case DataType::Half:
          return ptr(mutator)->mutate(static_cast<Half*>(val));
        case DataType::Int:
          return ptr(mutator)->mutate(static_cast<Int*>(val));
        default:
          break;
      }
      break;
    case ValType::IterDomain:
      return ptr(mutator)->mutate(static_cast<IterDomain*>(val));
    case ValType::TensorDomain:
      return ptr(mutator)->mutate(static_cast<TensorDomain*>(val));
    case ValType::TensorView:
      return ptr(mutator)->mutate(static_cast<TensorView*>(val));
    case ValType::TensorIndex:
      return ptr(mutator)->mutate(static_cast<TensorIndex*>(val));
    case ValType::NamedScalar:
      return ptr(mutator)->mutate(static_cast<NamedScalar*>(val));
    default:
      break;
  }
  TORCH_INTERNAL_ASSERT(false, "Unknown valtype in dispatch!");
}
template <typename T>
Statement* Expr::mutatorDispatch(T mutator, Expr* expr) {
  switch (*(expr->getExprType())) {
    case ExprType::Split:
      return ptr(mutator)->mutate(static_cast<Split*>(expr));
    case ExprType::Merge:
      return ptr(mutator)->mutate(static_cast<Merge*>(expr));
    case ExprType::UnaryOp:
      return ptr(mutator)->mutate(static_cast<UnaryOp*>(expr));
    case ExprType::BinaryOp:
      return ptr(mutator)->mutate(static_cast<BinaryOp*>(expr));
    case ExprType::TernaryOp:
      return ptr(mutator)->mutate(static_cast<TernaryOp*>(expr));
    case ExprType::ReductionOp:
      return ptr(mutator)->mutate(static_cast<ReductionOp*>(expr));
    case ExprType::BroadcastOp:
      return ptr(mutator)->mutate(static_cast<BroadcastOp*>(expr));
    case ExprType::ForLoop:
      return ptr(mutator)->mutate(static_cast<ForLoop*>(expr));
    case ExprType::IfThenElse:
      return ptr(mutator)->mutate(static_cast<IfThenElse*>(expr));
    case ExprType::Allocate:
      return ptr(mutator)->mutate(static_cast<Allocate*>(expr));
    default:
      TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
  }
}
template <typename T>
Statement* Statement::mutatorDispatch(T mutator, Statement* stmt) {
  if (stmt->isVal()) {
    return ptr(mutator)->mutate(static_cast<Val*>(stmt));
  }
  if (stmt->isExpr()) {
    return ptr(mutator)->mutate(static_cast<Expr*>(stmt));
  }
  TORCH_INTERNAL_ASSERT(false, "Unknown stmttype in dispatch!");
}
/*
 * Handler template instantiations. These should only have to be done on base
 * classes. Actual visitors/mutators should inherit from these classes and
 * call ->dispatch(this) to avoid needing an explicit instantiation.
 */
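/*
 * A derived visitor never calls Statement::dispatch() with its own type; it
 * enters through the base-class handle() entry points defined below, where
 * `this` has the base type, so only these base instantiations are needed.
 * MyVisitor and stmt are hypothetical:
 *
 *   MyVisitor v;    // class MyVisitor : public OptOutDispatch { ... };
 *   v.handle(stmt); // OptOutDispatch::handle -> Statement::dispatch with
 *                   // T = OptOutDispatch*, instantiated below, which then
 *                   // virtually calls back into MyVisitor's overrides.
 */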
template void Statement::dispatch(OptOutDispatch, Statement*);
template void Statement::dispatch(OptOutDispatch*, Statement*);
template void Val::dispatch(OptOutDispatch, Val*);
template void Val::dispatch(OptOutDispatch*, Val*);
template void Expr::dispatch(OptOutDispatch, Expr*);
template void Expr::dispatch(OptOutDispatch*, Expr*);

template void Statement::dispatch(OptInDispatch, Statement*);
template void Statement::dispatch(OptInDispatch*, Statement*);
template void Val::dispatch(OptInDispatch, Val*);
template void Val::dispatch(OptInDispatch*, Val*);
template void Expr::dispatch(OptInDispatch, Expr*);
template void Expr::dispatch(OptInDispatch*, Expr*);

template void Statement::constDispatch(
    OptOutConstDispatch,
    const Statement* const);
template void Statement::constDispatch(
    OptOutConstDispatch*,
    const Statement* const);
template void Val::constDispatch(OptOutConstDispatch, const Val* const);
template void Val::constDispatch(OptOutConstDispatch*, const Val* const);
template void Expr::constDispatch(OptOutConstDispatch, const Expr* const);
template void Expr::constDispatch(OptOutConstDispatch*, const Expr* const);

template void Statement::constDispatch(
    OptInConstDispatch,
    const Statement* const);
template void Statement::constDispatch(
    OptInConstDispatch*,
    const Statement* const);
template void Val::constDispatch(OptInConstDispatch, const Val* const);
template void Val::constDispatch(OptInConstDispatch*, const Val* const);
template void Expr::constDispatch(OptInConstDispatch, const Expr* const);
template void Expr::constDispatch(OptInConstDispatch*, const Expr* const);

template Statement* Statement::mutatorDispatch(OptOutMutator, Statement*);
template Statement* Statement::mutatorDispatch(OptOutMutator*, Statement*);
template Statement* Val::mutatorDispatch(OptOutMutator, Val*);
template Statement* Val::mutatorDispatch(OptOutMutator*, Val*);
template Statement* Expr::mutatorDispatch(OptOutMutator, Expr*);
template Statement* Expr::mutatorDispatch(OptOutMutator*, Expr*);

template Statement* Statement::mutatorDispatch(OptInMutator, Statement*);
template Statement* Statement::mutatorDispatch(OptInMutator*, Statement*);
template Statement* Val::mutatorDispatch(OptInMutator, Val*);
template Statement* Val::mutatorDispatch(OptInMutator*, Val*);
template Statement* Expr::mutatorDispatch(OptInMutator, Expr*);
template Statement* Expr::mutatorDispatch(OptInMutator*, Expr*);
void OptOutDispatch::handle(Statement* s) {
  Statement::dispatch(this, s);
}
void OptOutDispatch::handle(Expr* e) {
  Expr::dispatch(this, e);
}
void OptOutDispatch::handle(Val* v) {
  Val::dispatch(this, v);
}

void OptInDispatch::handle(Statement* s) {
  Statement::dispatch(this, s);
}
void OptInDispatch::handle(Expr* e) {
  Expr::dispatch(this, e);
}
void OptInDispatch::handle(Val* v) {
  Val::dispatch(this, v);
}

void OptOutConstDispatch::handle(const Statement* const s) {
  Statement::constDispatch(this, s);
}
void OptOutConstDispatch::handle(const Expr* const e) {
  Expr::constDispatch(this, e);
}
void OptOutConstDispatch::handle(const Val* const v) {
  Val::constDispatch(this, v);
}

void OptInConstDispatch::handle(const Statement* const s) {
  Statement::constDispatch(this, s);
}
void OptInConstDispatch::handle(const Expr* const e) {
  Expr::constDispatch(this, e);
}
void OptInConstDispatch::handle(const Val* const v) {
  Val::constDispatch(this, v);
}
Statement* OptInMutator::mutate(Statement* s) {
  return Statement::mutatorDispatch(this, s);
}

Statement* OptInMutator::mutate(Expr* e) {
  return Expr::mutatorDispatch(this, e);
}

Statement* OptInMutator::mutate(Val* v) {
  // If this val has already been mutated, return the cached replacement.
  if (mutations.find(v) != mutations.end()) {
    return mutations[v];
  }
  return Val::mutatorDispatch(this, v);
}

Statement* OptOutMutator::mutate(Statement* s) {
  return Statement::mutatorDispatch(this, s);
}
Statement* OptOutMutator::mutate(Expr* e) {
  return Expr::mutatorDispatch(this, e);
}
Statement* OptOutMutator::mutate(Val* v) {
  // If this val has already been mutated, return the cached replacement.
  if (mutations.find(v) != mutations.end()) {
    return mutations[v];
  }
  return Val::mutatorDispatch(this, v);
}
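/*
 * The mutations map consulted above memoizes replacements: once a concrete
 * mutate() overload has recorded a replacement for a Val, any later mutate()
 * of that same Val returns the cached replacement instead of re-dispatching.
 * A sketch, assuming the hypothetical SomeMutator's overloads record into
 * mutations:
 *
 *   SomeMutator m; // class SomeMutator : public OptOutMutator { ... };
 *   Statement* a = m.mutate(val); // dispatches; overload sets mutations[val]
 *   Statement* b = m.mutate(val); // cache hit: b == a, no re-dispatch
 */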
} // namespace fuser
} // namespace jit
} // namespace torch