Summary: This PR contains the infrastructure of a new CUDA fuser. This CUDA fuser is based on many of the same principles as TensorExpressions and Halide, but the implementation is written from the ground up. The fusion pass itself is similar to the default CUDA fuser; however, it has undergone some refactoring and uses the new code generation infrastructure. For those interested in how the code generation in this PR works, I would recommend reviewing _test/cpp/jit/test_gpu_fusion.cpp_ as well as the long comment section at the beginning of _torch/csrc/jit/codegen/cuda/transform_replay.h_.

One of the largest differences between our approach and that of TVM/Halide is the concept of "TensorView". At a high level, a TensorView should be thought of similarly to how we think of working with Tensors in PyTorch: it is an N-D object which can undergo transformations that change its dimensionality. Dimensionality changes are done through the operations split/merge/reorder/computeAt. These transformations are similar to split/fuse/reorder/compute_at in TVM; they modify how a tensor is iterated over to generate GPU code. Interestingly, in our scheme these transformations are applied to tensors and only impact how that tensor is generated.

**Warning:** This PR is purposefully not feature complete with the current fuser. We wanted to separate the infrastructure from the fusion capabilities. Once it lands, smaller incremental PRs will be submitted to expand the capabilities of the fuser.

**Short term goals:** parity with the current CUDA fuser (including performance):
- Dynamic shapes (no recompilation)
- Implicit handling of broadcast (broadcasted tensors are treated as tensors of the broadcasted size in the generated code)
- Dropout

**Mid-term goals:**
- Transposes fused with pointwise operations, where the transpose involves only 2 axes (across the fused operation).
- 1-D reductions fused with pointwise operations

Pull Request resolved: https://github.com/pytorch/pytorch/pull/34785

Reviewed By: ZolotukhinM

Differential Revision: D20650977

Pulled By: soumith

fbshipit-source-id: ee39c95a880e1b9822e874ed4cc180971572bf63
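
As a rough illustration of the scheduling model described above (a sketch only; `tv` and `producer` are hypothetical names and the exact signatures are assumptions, not necessarily this PR's actual API):

    TensorView* tv = /* ... an N-D tensor in the fusion IR ... */;
    tv->split(0, 128);             // split axis 0 by a factor of 128
    tv->reorder({{0, 1}, {1, 0}}); // swap the first two axes
    producer->computeAt(tv, 1);    // compute "producer" at tv's axis 1
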
#pragma once

#include <torch/csrc/WindowsTorchApiMacro.h>

#include <c10/util/Exception.h>

/*
 * dispatch.h removes the need to add manual dispatch in every class that
 * wants to define how to process a series of nodes. dispatch.h provides
 * classes that can be inherited, providing a means to override functions on
 * a per-node basis. There are currently 4 provided dispatch mechanisms:
 *
 * OptOutDispatch:
 *
 * provides the functions:
 * virtual void handle(ValType* irnode) {}
 *
 * This provides a mechanism to override handle for particular node types.
 * For example, if we only wanted to actually run a function on BinaryOps, we
 * could inherit OptOutDispatch and simply override:
 * void handle(BinaryOp*) { doSomething; }
 * Then we could run through all our Statement* and call
 * OptOutDispatch::handle(statement). When a BinaryOp is encountered, our
 * override function will be called. For every other node, nothing will be
 * done.
 *
 * OptInDispatch:
 *
 * This class is similar to OptOutDispatch; however, if we encounter a node
 * that we haven't specified an override for in the derived class, an error
 * will be thrown. This is useful if we create a class that is expected to
 * handle any type of node it encounters. (OptInConstDispatch provides the
 * same behavior for const node pointers.)
 *
 * OptOutMutator:
 *
 * This class is similar to OptOutDispatch, except the functions provided are
 * of type: virtual Statement* mutate(Statement*). This is useful when we
 * want an IR node to result from our overloaded functions.
 *
 * OptInMutator:
 *
 * This class is similar to OptInDispatch, except the functions provided are
 * of type: virtual Statement* mutate(Statement*). This is useful when we
 * want an IR node to result from our overloaded functions.
 */
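
/*
 * Example (a minimal sketch; "BinaryOpCounter" and "stmts" are hypothetical
 * names): counting BinaryOps with OptOutDispatch, as described above.
 *
 *   struct BinaryOpCounter : public OptOutDispatch {
 *     using OptOutDispatch::handle; // keep the base overloads visible
 *     int count = 0;
 *     void handle(BinaryOp*) override { ++count; }
 *   };
 *
 *   BinaryOpCounter counter;
 *   for (Statement* stmt : stmts)
 *     counter.handle(stmt); // only BinaryOps increment the counter;
 *                           // every other node hits an empty default handler
 */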

namespace torch {
namespace jit {
namespace fuser {

struct Fusion;

// Hierarchical dispatch functions for handle
struct Statement;
struct Expr;
struct Val;

// Vals
struct IterDomain;
struct TensorDomain;
struct TensorView;
struct Float;
struct Int;

// Exprs
struct Split;
struct Merge;
struct Reorder;
struct UnaryOp;
struct BinaryOp;

/*
 * By default, all IR nodes are handled in this dispatch, and will call an
 * empty function on all nodes.
 */
struct TORCH_CUDA_API OptOutDispatch {
  virtual ~OptOutDispatch() = default;
  OptOutDispatch() = default;

  OptOutDispatch(const OptOutDispatch& other) = default;
  OptOutDispatch& operator=(const OptOutDispatch& other) = default;

  OptOutDispatch(OptOutDispatch&& other) = default;
  OptOutDispatch& operator=(OptOutDispatch&& other) = default;

  // Hierarchical dispatch functions for handle
  virtual void handle(Statement*);
  virtual void handle(Expr*);
  virtual void handle(Val*);

  // Vals
  virtual void handle(IterDomain*) {}
  virtual void handle(TensorDomain*) {}
  virtual void handle(TensorView*) {}
  virtual void handle(Float*) {}
  virtual void handle(Int*) {}

  // Exprs
  virtual void handle(Split*) {}
  virtual void handle(Merge*) {}
  virtual void handle(Reorder*) {}
  virtual void handle(UnaryOp*) {}
  virtual void handle(BinaryOp*) {}
};

struct TORCH_CUDA_API OptInConstDispatch {
  virtual ~OptInConstDispatch() = default;
  OptInConstDispatch() = default;

  OptInConstDispatch(const OptInConstDispatch& other) = default;
  OptInConstDispatch& operator=(const OptInConstDispatch& other) = default;

  OptInConstDispatch(OptInConstDispatch&& other) = default;
  OptInConstDispatch& operator=(OptInConstDispatch&& other) = default;

  // Hierarchical dispatch functions for handle
  virtual void handle(const Statement* const);
  virtual void handle(const Expr* const);
  virtual void handle(const Val* const);

  // Vals
  virtual void handle(const IterDomain* const) {
    TORCH_INTERNAL_ASSERT(false, "Handle not overridden for IterDomain.");
  }
  virtual void handle(const TensorDomain* const) {
    TORCH_INTERNAL_ASSERT(false, "Handle not overridden for TensorDomain.");
  }
  virtual void handle(const TensorView* const) {
    TORCH_INTERNAL_ASSERT(false, "Handle not overridden for TensorView.");
  }
  virtual void handle(const Float* const) {
    TORCH_INTERNAL_ASSERT(false, "Handle not overridden for Float.");
  }
  virtual void handle(const Int* const) {
    TORCH_INTERNAL_ASSERT(false, "Handle not overridden for Int.");
  }

  // Exprs
  virtual void handle(const Split* const) {
    TORCH_INTERNAL_ASSERT(false, "Handle not overridden for Split.");
  }
  virtual void handle(const Merge* const) {
    TORCH_INTERNAL_ASSERT(false, "Handle not overridden for Merge.");
  }
  virtual void handle(const Reorder* const) {
    TORCH_INTERNAL_ASSERT(false, "Handle not overridden for Reorder.");
  }
  virtual void handle(const UnaryOp* const) {
    TORCH_INTERNAL_ASSERT(false, "Handle not overridden for UnaryOp.");
  }
  virtual void handle(const BinaryOp* const) {
    TORCH_INTERNAL_ASSERT(false, "Handle not overridden for BinaryOp.");
  }
};

struct TORCH_CUDA_API OptInDispatch {
  virtual ~OptInDispatch() = default;
  OptInDispatch() = default;

  OptInDispatch(const OptInDispatch& other) = default;
  OptInDispatch& operator=(const OptInDispatch& other) = default;

  OptInDispatch(OptInDispatch&& other) = default;
  OptInDispatch& operator=(OptInDispatch&& other) = default;

  // Hierarchical dispatch functions for handle
  virtual void handle(Statement* s);
  virtual void handle(Expr* e);
  virtual void handle(Val* v);

  // Vals
  virtual void handle(IterDomain*) {
    TORCH_INTERNAL_ASSERT(false, "Handle not overridden for IterDomain.");
  }
  virtual void handle(TensorDomain*) {
    TORCH_INTERNAL_ASSERT(false, "Handle not overridden for TensorDomain.");
  }
  virtual void handle(TensorView*) {
    TORCH_INTERNAL_ASSERT(false, "Handle not overridden for TensorView.");
  }
  virtual void handle(Float*) {
    TORCH_INTERNAL_ASSERT(false, "Handle not overridden for Float.");
  }
  virtual void handle(Int*) {
    TORCH_INTERNAL_ASSERT(false, "Handle not overridden for Int.");
  }

  // Exprs
  virtual void handle(Split*) {
    TORCH_INTERNAL_ASSERT(false, "Handle not overridden for Split.");
  }
  virtual void handle(Merge*) {
    TORCH_INTERNAL_ASSERT(false, "Handle not overridden for Merge.");
  }
  virtual void handle(Reorder*) {
    TORCH_INTERNAL_ASSERT(false, "Handle not overridden for Reorder.");
  }
  virtual void handle(UnaryOp*) {
    TORCH_INTERNAL_ASSERT(false, "Handle not overridden for UnaryOp.");
  }
  virtual void handle(BinaryOp*) {
    TORCH_INTERNAL_ASSERT(false, "Handle not overridden for BinaryOp.");
  }
};
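
/*
 * Example (a sketch with hypothetical names): a visitor that is only ever
 * expected to see arithmetic expressions. If any other node type reaches it,
 * the corresponding default handler above fires its TORCH_INTERNAL_ASSERT.
 *
 *   struct ArithmeticOnly : public OptInDispatch {
 *     using OptInDispatch::handle;
 *     int unary_ops = 0;
 *     int binary_ops = 0;
 *     void handle(UnaryOp*) override { ++unary_ops; }
 *     void handle(BinaryOp*) override { ++binary_ops; }
 *   };
 */
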
struct TORCH_CUDA_API OptOutMutator {
  virtual ~OptOutMutator() = default;
  OptOutMutator() = default;

  OptOutMutator(const OptOutMutator& other) = default;
  OptOutMutator& operator=(const OptOutMutator& other) = default;

  OptOutMutator(OptOutMutator&& other) = default;
  OptOutMutator& operator=(OptOutMutator&& other) = default;

  virtual void mutate(Fusion* fusion);

  // Hierarchical dispatch functions for mutate
  virtual Statement* mutate(Statement* s);
  virtual Statement* mutate(Expr* e);
  virtual Statement* mutate(Val* v);

  //****Functions below defined in mutator.cpp*****//

  // Vals
  virtual Statement* mutate(IterDomain*);
  virtual Statement* mutate(TensorDomain*);
  virtual Statement* mutate(TensorView*);
  virtual Statement* mutate(Float*);
  virtual Statement* mutate(Int*);

  // Exprs
  virtual Statement* mutate(Split*);
  virtual Statement* mutate(Merge*);
  virtual Statement* mutate(Reorder*);
  virtual Statement* mutate(UnaryOp*);
  virtual Statement* mutate(BinaryOp*);
};
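
/*
 * Example (a sketch; assumes, for illustration only, that Float is
 * default-constructible): replace every Float encountered with a fresh
 * Float node. All other node types fall through to the default mutate
 * overloads defined in mutator.cpp.
 *
 *   struct FloatReplacer : public OptOutMutator {
 *     using OptOutMutator::mutate;
 *     Statement* mutate(Float*) override {
 *       return new Float(); // hypothetical replacement node
 *     }
 *   };
 */
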
struct TORCH_CUDA_API OptInMutator {
  virtual ~OptInMutator() = default;
  OptInMutator() = default;

  OptInMutator(const OptInMutator& other) = default;
  OptInMutator& operator=(const OptInMutator& other) = default;

  OptInMutator(OptInMutator&& other) = default;
  OptInMutator& operator=(OptInMutator&& other) = default;

  // Hierarchical dispatch functions for mutate
  virtual Statement* mutate(Statement*);
  virtual Statement* mutate(Expr*);
  virtual Statement* mutate(Val*);

  // Vals
  virtual Statement* mutate(IterDomain*) {
    TORCH_INTERNAL_ASSERT(false, "Mutate not overridden for IterDomain.");
  }
  virtual Statement* mutate(TensorDomain*) {
    TORCH_INTERNAL_ASSERT(false, "Mutate not overridden for TensorDomain.");
  }
  virtual Statement* mutate(TensorView*) {
    TORCH_INTERNAL_ASSERT(false, "Mutate not overridden for TensorView.");
  }
  virtual Statement* mutate(Float*) {
    TORCH_INTERNAL_ASSERT(false, "Mutate not overridden for Float.");
  }
  virtual Statement* mutate(Int*) {
    TORCH_INTERNAL_ASSERT(false, "Mutate not overridden for Int.");
  }

  // Exprs
  virtual Statement* mutate(Split*) {
    TORCH_INTERNAL_ASSERT(false, "Mutate not overridden for Split.");
  }
  virtual Statement* mutate(Merge*) {
    TORCH_INTERNAL_ASSERT(false, "Mutate not overridden for Merge.");
  }
  virtual Statement* mutate(Reorder*) {
    TORCH_INTERNAL_ASSERT(false, "Mutate not overridden for Reorder.");
  }
  virtual Statement* mutate(UnaryOp*) {
    TORCH_INTERNAL_ASSERT(false, "Mutate not overridden for UnaryOp.");
  }
  virtual Statement* mutate(BinaryOp*) {
    TORCH_INTERNAL_ASSERT(false, "Mutate not overridden for BinaryOp.");
  }
};

} // namespace fuser
} // namespace jit
} // namespace torch