mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/ Codegen changes include: - codegen improvement: i. improved view support on pointwise and transpose scheduler ii. grouped grid welford added for better outer-norm grid persistence in normalization - misc: i. new composite ops added: variance_mean , arange, ii. fixes misaligned address for transpose scheduler iii. refactor on separation of compilation API from execution API to prepare us for async compilation iv. double type support on expression evaluator v. PYTORCH_NVFUSER_DUMP refactor to save PTX and CUBIN Commits that's in this PR from the devel branch: ``` 89330aa23aa804340b2406ab58899d816e3dc3d2 Tensor factories must set the output shape as its input (#1939) b2fd01ea9346712c6d6f623ca6addbc4888d008e arange support (#1933) 56c00fd3922dad7dfc57351ad7d780f0f2f8e4ed Double support on all expression evaluators (#1937) 371f28223e57fe3f6b5e50a0a45177e6a5c0785c Improve trivial reduction merge support (#1931) 1d0c26790e5647920b40d419d26815bbe310b3a6 Test `rand` in a fusion with zero tensor input (#1932) 0dab160fb2177d178eef3148c6a529e0855009e9 Fix softmax bwd sizes. (#1890) ef98f360f6d3e3e1cc662ecb65202d88150f128d Fix a bug (#1936) 63132a0c56508c550084b07fb76a3df865102d00 Propagate permissive mapping information into indexing pass (#1929) b4ac2c88d78078ee4d8b21c4fc51645b5710a282 Map IterationDomains through view operations. 
(#1919) c0a187a7619d7cf9dc920294e15461791e8d6d4d do not use deprecated functions (#1935) 88de85e758c5e4afb7b6e746573c0d9a53b4cea7 Upstream cherry pick fixes 0811 (#1934) b247dcf7c57dc6ac3f7a799b0a6beb7770536a74 Separate kernel compilation API from kernel execution API (#1914) b34e3b93ee1a8030730c14af3995dd95665af07d Fix `ir_utils::hasBlockSync` + misc fixes in transpose scheduler (#1924) 14a53e6707f43bf760494c238a46386d69830822 Nullary RNGOp (#1892) 3c3c89e638f5172cafb0761f22bacd1fd695eec3 Misc fixes/tuning for transpose scheduler (#1912) 20cf109c8b44d48f61977e35bae94368985144ac Grouped grid welford (#1921) 6cf7eb024c9e53c358cbe56597e117bad56efefd Transpose scheduler small dim sizes better support (#1910) 9341ea9a5bf42f9b14ccad0c94edbc79fc5bb552 Disabled ViewPersistentShmoo sizes that results in NAN (#1922) 057237f66deeea816bb943d802a97c1b7e4414ab Fix CUDA driver error: misaligned address for transpose scheduler (#1918) 3fb3d80339e4f794767a53eb8fdd61e64cf404a2 Add variance_mean function using Welford (#1907) 98febf6aa3b8c6fe4fdfb2864cda9e5d30089262 Remove DisableOption::UnrollWithRng (#1913) ee8ef33a5591b534cf587d347af11e48ba7a15d4 Minor fix for the debug interface of using PTX directly (#1917) 6e8f953351f9dabfd1f991d8431cecb6c2ce684d Add PYTORCH_NVFUSER_DUMP options to save PTX and CUBIN (#1916) 5eefa9a72385f6a4b145680a9dcc52d7e8293763 dopt is only available since nvrtc 11.7 (#1915) 2ec8fc711eafc72451eebf0f5e2a98a38bf3f6ef Kill computeAtBetween (#1911) d0d106a1d9af118d71673173674e875be35d259d Improve view support on pointwise and transpose scheduler (#1906) e71e1ecefe67219846070590bbed54bbc7416b79 Fix name clash of RNG with shared memory (#1904) 3381793a253689abf224febc73fd3fe2a0dbc921 Fix mutator and sameAs for expanded IterDomain (#1902) ``` RUN_TORCHBENCH: nvfuser Differential Revision: [D39324552](https://our.internmc.facebook.com/intern/diff/D39324552) Pull Request resolved: https://github.com/pytorch/pytorch/pull/84626 Approved by: https://github.com/malfet
371 lines
12 KiB
C++
371 lines
12 KiB
C++
#pragma once
|
|
|
|
#include <c10/macros/Export.h>
|
|
#include <c10/util/Exception.h>
|
|
|
|
#include <torch/csrc/jit/codegen/cuda/utils.h>
|
|
|
|
#include <unordered_map>
|
|
|
|
// dispatch.h prevents the need from adding manual dispatch in every class that
|
|
// wants to define how to process a series of nodes. dispatch.h provides 4
|
|
// classes that can be inherited providing a means to override functions on a
|
|
// per-node basis. There are currently 4 provided dispatch mechanisms:
|
|
//
|
|
// OptOutDispatch:
|
|
//
|
|
// provides the functions:
|
|
// virtual void handle(ValType* irnode){}
|
|
//
|
|
// This provides a mechanism to override this handle for particular node
|
|
// types. For example if we only wanted to actually run a function on
|
|
// BinaryOps, we could inherit OptOutDispatch and simply override: void
|
|
// handle(BinaryOp*) { doSomething; } Then we could run through all our
|
|
// Statement* and call OptOutDispatch::handle(statement). When a BinaryOp is
|
|
// encountered our override function will be called. For every other node,
|
|
// nothing will be done.
|
|
//
|
|
// OptInDispatch:
|
|
//
|
|
// This class is similar to OptOutDispatch, however if we encounter a node
|
|
// that we haven't specified an override for in the derived class, an error
|
|
// will be thrown. This is useful if we create a class that is expected to
|
|
// handle any type of node it encounters.
|
|
//
|
|
// OptOutMutator:
|
|
//
|
|
// This class is similar to OptOutDispatch except the functions provided are of
|
|
// type: virtual Statement* mutate(Statement*) this is useful for when we want
|
|
// to have an IR node result from our overloaded functions.
|
|
//
|
|
// OptInMutator:
|
|
//
|
|
// This class is similar to OptInDispatch except the functions provided are of
|
|
// type: virtual Statement* mutate(Statement*) this is useful for when we want
|
|
// to have an IR node result from our overloaded functions.
|
|
|
|
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

class IrContainer;
class Fusion;

// Root statement hierarchy: every IR node is a Statement, which is either an
// Expr (an operation) or a Val (a value). The dispatch classes below forward
// from Statement to Expr/Val to the concrete types declared here.
class Statement;
class Expr;
class Val;

// Vals (IR value nodes)
class IterDomain;
class TensorDomain;
class TensorView;

// Scalar value types
class Bool;
class Double;
class Int;
class ComplexDouble;
class NamedScalar;

// Exprs (Fusion IR operations)
class ARangeOp;
class UnaryOp;
class BinaryOp;
class TernaryOp;
class RNGOp;
class ReductionOp;
class GroupedReductionOp;
class WelfordOp;
class GroupedWelfordOp;
class LoadStoreOp;
class MmaOp;
class BroadcastOp;
class TransposeOp;
class ExpandOp;
class ShiftOp;
class GatherOp;
class ViewAsScalar;
class ViewOp;

// Exprs (iteration-domain transformations)
class Split;
class Merge;
class Swizzle2D;

// Kernel IR nodes (lowered representation)
namespace kir {
class Predicate;
class TensorIndex;
class IntPair;

class Allocate;
class BlockSync;
class GridSync;
class CpAsyncWait;
class CpAsyncCommit;
class ForLoop;
class IfThenElse;
class GridReduction;
class GroupedGridReduction;
class GridBroadcast;
class GridWelford;
class GroupedGridWelford;
class AllocateFusedReduction;
class InitMagicZero;
class UpdateMagicZero;
class Swizzle2DInt;
class PairSelect;

} // namespace kir
|
|
|
|
// By default, all IR nodes are handled in this dispatch, and will call an empty
// function on all nodes. Overriding a particular handle() overload in a derived
// class opts in to processing that node type; everything else falls through to
// unhandled(), which is a no-op here.
class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
 protected:
  // Called for every node whose concrete type is not intercepted by a derived
  // override. No-op by default; OptInConstDispatch overrides this to error.
  virtual void unhandled(const Statement*) {}

 public:
  // Hierarchical dispatch functions for handle: dispatching on a Statement
  // forwards to the Expr/Val overload, which forwards to the overload for the
  // node's concrete type below.
  virtual void handle(const Statement*);
  virtual void handle(const Expr*);
  virtual void handle(const Val*);

  // Vals
  virtual void handle(const IterDomain* stmt);
  virtual void handle(const TensorDomain* stmt);
  virtual void handle(const TensorView* stmt);
  virtual void handle(const Bool* stmt);
  virtual void handle(const Double* stmt);
  virtual void handle(const Int* stmt);
  virtual void handle(const ComplexDouble* stmt);
  virtual void handle(const NamedScalar* stmt);

  // Kernel IR vals
  virtual void handle(const kir::Predicate*);
  virtual void handle(const kir::TensorIndex*);
  virtual void handle(const kir::IntPair*);

  // Exprs
  virtual void handle(const ARangeOp* stmt);
  virtual void handle(const UnaryOp* stmt);
  virtual void handle(const BinaryOp* stmt);
  virtual void handle(const TernaryOp* stmt);
  virtual void handle(const RNGOp* stmt);
  virtual void handle(const ReductionOp* stmt);
  virtual void handle(const GroupedReductionOp* stmt);
  virtual void handle(const WelfordOp* stmt);
  virtual void handle(const GroupedWelfordOp* stmt);
  virtual void handle(const LoadStoreOp* stmt);
  virtual void handle(const MmaOp* stmt);
  virtual void handle(const BroadcastOp* stmt);

  // Iteration-domain transformations
  virtual void handle(const Split* stmt);
  virtual void handle(const Merge* stmt);
  virtual void handle(const Swizzle2D* stmt);
  virtual void handle(const TransposeOp* stmt);
  virtual void handle(const ExpandOp* stmt);
  virtual void handle(const ShiftOp* stmt);
  virtual void handle(const GatherOp* stmt);
  virtual void handle(const ViewAsScalar* stmt);
  virtual void handle(const ViewOp* stmt);

  // Kernel IR exprs
  virtual void handle(const kir::Allocate*);
  virtual void handle(const kir::BlockSync*);
  virtual void handle(const kir::GridSync*);
  virtual void handle(const kir::CpAsyncWait*);
  virtual void handle(const kir::CpAsyncCommit*);
  virtual void handle(const kir::InitMagicZero*);
  virtual void handle(const kir::UpdateMagicZero*);
  virtual void handle(const kir::ForLoop*);
  virtual void handle(const kir::IfThenElse*);
  virtual void handle(const kir::GridReduction*);
  virtual void handle(const kir::GroupedGridReduction*);
  virtual void handle(const kir::GridBroadcast*);
  virtual void handle(const kir::GridWelford*);
  virtual void handle(const kir::GroupedGridWelford*);
  virtual void handle(const kir::AllocateFusedReduction*);
  virtual void handle(const kir::Swizzle2DInt*);
  virtual void handle(const kir::PairSelect*);
};
|
|
|
|
// Non-const counterpart of OptOutConstDispatch: identical hierarchical
// dispatch, but handlers receive mutable node pointers. Unhandled node types
// fall through to unhandled(), which derived classes may override.
class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
 protected:
  // Called for every node whose concrete type is not intercepted by a derived
  // override (defined out of line, unlike the const variant's inline no-op).
  virtual void unhandled(Statement*);

 public:
  // Hierarchical dispatch functions for handle: dispatching on a Statement
  // forwards to the Expr/Val overload, which forwards to the overload for the
  // node's concrete type below.
  virtual void handle(Statement*);
  virtual void handle(Expr*);
  virtual void handle(Val*);

  // Vals
  virtual void handle(Bool* stmt);
  virtual void handle(Double* stmt);
  virtual void handle(Int* stmt);
  virtual void handle(ComplexDouble* stmt);
  virtual void handle(NamedScalar* stmt);
  virtual void handle(IterDomain* stmt);
  virtual void handle(TensorDomain* stmt);
  virtual void handle(TensorView* stmt);

  // Kernel IR vals
  virtual void handle(kir::Predicate*);
  virtual void handle(kir::TensorIndex*);
  virtual void handle(kir::IntPair*);

  // Exprs
  virtual void handle(ARangeOp* stmt);
  virtual void handle(UnaryOp* stmt);
  virtual void handle(BinaryOp* stmt);
  virtual void handle(TernaryOp* stmt);
  virtual void handle(RNGOp* stmt);
  virtual void handle(ReductionOp* stmt);
  virtual void handle(GroupedReductionOp* stmt);
  virtual void handle(WelfordOp* stmt);
  virtual void handle(GroupedWelfordOp* stmt);
  virtual void handle(LoadStoreOp* stmt);
  virtual void handle(MmaOp* stmt);
  virtual void handle(BroadcastOp* stmt);

  // Iteration-domain transformations
  virtual void handle(Split* stmt);
  virtual void handle(Merge* stmt);
  virtual void handle(Swizzle2D* stmt);
  virtual void handle(TransposeOp* stmt);
  virtual void handle(ExpandOp* stmt);
  virtual void handle(ShiftOp* stmt);
  virtual void handle(GatherOp* stmt);
  virtual void handle(ViewAsScalar* stmt);
  virtual void handle(ViewOp* stmt);

  // Kernel IR exprs
  virtual void handle(kir::Allocate* stmt);
  virtual void handle(kir::BlockSync* stmt);
  virtual void handle(kir::GridSync* stmt);
  virtual void handle(kir::CpAsyncWait* stmt);
  virtual void handle(kir::CpAsyncCommit* stmt);
  virtual void handle(kir::InitMagicZero* stmt);
  virtual void handle(kir::UpdateMagicZero* stmt);
  virtual void handle(kir::ForLoop* stmt);
  virtual void handle(kir::IfThenElse* stmt);
  virtual void handle(kir::GridReduction* stmt);
  virtual void handle(kir::GroupedGridReduction* stmt);
  virtual void handle(kir::GridBroadcast* stmt);
  virtual void handle(kir::GridWelford* stmt);
  virtual void handle(kir::GroupedGridWelford* stmt);
  virtual void handle(kir::AllocateFusedReduction* stmt);
  virtual void handle(kir::Swizzle2DInt* stmt);
  virtual void handle(kir::PairSelect* stmt);
};
|
|
|
|
// "Opt-in" variant of OptOutConstDispatch: per the file header, encountering a
// node type the derived class has not overridden raises an error instead of
// being silently ignored. Useful for visitors expected to handle every node.
class TORCH_CUDA_CU_API OptInConstDispatch : public OptOutConstDispatch {
 public:
  using OptOutConstDispatch::handle;

 protected:
  // Final override: reached only when no derived handle() matched the node's
  // concrete type (defined out of line).
  virtual void unhandled(const Statement* stmt) final;
};
|
|
|
|
// "Opt-in" variant of OptOutDispatch: per the file header, encountering a node
// type the derived class has not overridden raises an error instead of being
// silently ignored. Non-const counterpart of OptInConstDispatch.
class TORCH_CUDA_CU_API OptInDispatch : public OptOutDispatch {
 public:
  using OptOutDispatch::handle;

 protected:
  // Final override: reached only when no derived handle() matched the node's
  // concrete type (defined out of line).
  virtual void unhandled(Statement* stmt) final;
};
|
|
|
|
// Class to perform mutations on Fusion IR. Exprs can simply be redefined, but
|
|
// when mutating values they have to be registered through registerMutation so
|
|
// that exprs can detect there's been a muatation and know to modify all
|
|
// instances of that Val. This means each Val should be mutated "consistently".
|
|
// Otherwise behavior may be difficult to understand as it depends on which
|
|
// order mutate is called in. This class expects user to topologically call the
|
|
// statments of interest so inputs are called and mutated before exprs depending
|
|
// on them.
|
|
//
|
|
// Warning: TensorViews need to be treated carefully. As we don't generally
|
|
// register their mutation when their tensor domains only change. If a TV needs
|
|
// to be swapped out, it needs to be registered as a "proper" mutation like
|
|
// other vals, on top of TensorDomain being updated in the mutated TensorView.
|
|
//
|
|
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
|
class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
|
|
public:
|
|
// Hierarchal dispatch functions for handle
|
|
virtual void mutate(Statement* s);
|
|
virtual void mutate(Expr* e);
|
|
virtual void mutate(Val* v);
|
|
|
|
void registerMutation(Val* val, Val* mutation);
|
|
|
|
Val* maybeMutated(Val* val) {
|
|
if (mutations.find(val) == mutations.end()) {
|
|
return val;
|
|
}
|
|
return mutations.at(val);
|
|
}
|
|
|
|
std::unordered_map<Val*, Val*> mutations;
|
|
|
|
//****Functions below defined in mutator.cpp*****
|
|
|
|
// Vals
|
|
virtual void mutate(Bool*);
|
|
virtual void mutate(Double*);
|
|
virtual void mutate(Int*);
|
|
virtual void mutate(ComplexDouble*);
|
|
virtual void mutate(NamedScalar*);
|
|
virtual void mutate(IterDomain*);
|
|
virtual void mutate(TensorDomain*);
|
|
virtual void mutate(TensorView*);
|
|
|
|
virtual void mutate(kir::Predicate*);
|
|
virtual void mutate(kir::TensorIndex*);
|
|
virtual void mutate(kir::IntPair*);
|
|
|
|
// Exprs
|
|
virtual void mutate(ARangeOp*);
|
|
virtual void mutate(UnaryOp*);
|
|
virtual void mutate(BinaryOp*);
|
|
virtual void mutate(TernaryOp*);
|
|
virtual void mutate(RNGOp*);
|
|
virtual void mutate(ReductionOp*);
|
|
virtual void mutate(GroupedReductionOp*);
|
|
virtual void mutate(WelfordOp*);
|
|
virtual void mutate(GroupedWelfordOp*);
|
|
virtual void mutate(LoadStoreOp*);
|
|
virtual void mutate(MmaOp*);
|
|
virtual void mutate(BroadcastOp*);
|
|
|
|
virtual void mutate(Split*);
|
|
virtual void mutate(Merge*);
|
|
virtual void mutate(Swizzle2D*);
|
|
virtual void mutate(TransposeOp*);
|
|
virtual void mutate(ExpandOp*);
|
|
virtual void mutate(ShiftOp*);
|
|
virtual void mutate(GatherOp*);
|
|
virtual void mutate(ViewAsScalar*);
|
|
virtual void mutate(ViewOp*);
|
|
|
|
virtual void mutate(kir::Allocate*);
|
|
virtual void mutate(kir::BlockSync*);
|
|
virtual void mutate(kir::GridSync*);
|
|
virtual void mutate(kir::CpAsyncWait*);
|
|
virtual void mutate(kir::CpAsyncCommit*);
|
|
virtual void mutate(kir::InitMagicZero*);
|
|
virtual void mutate(kir::UpdateMagicZero*);
|
|
virtual void mutate(kir::ForLoop*);
|
|
virtual void mutate(kir::IfThenElse*);
|
|
virtual void mutate(kir::GridReduction*);
|
|
virtual void mutate(kir::GroupedGridReduction*);
|
|
virtual void mutate(kir::GridBroadcast*);
|
|
virtual void mutate(kir::GridWelford*);
|
|
virtual void mutate(kir::GroupedGridWelford*);
|
|
virtual void mutate(kir::AllocateFusedReduction*);
|
|
virtual void mutate(kir::Swizzle2DInt*);
|
|
virtual void mutate(kir::PairSelect*);
|
|
|
|
protected:
|
|
void removeExpr(IrContainer*, Expr*);
|
|
};
|
|
|
|
} // namespace cuda
|
|
} // namespace fuser
|
|
} // namespace jit
|
|
} // namespace torch
|