#pragma once

#include <c10/macros/Export.h>
#include <c10/util/Exception.h>

#include <torch/csrc/jit/codegen/cuda/utils.h>

#include <unordered_map>

// dispatch.h removes the need to add manual dispatch in every class that
// wants to define how to process a series of nodes. dispatch.h provides 4
// classes that can be inherited, providing a means to override functions on a
// per-node basis. There are currently 4 provided dispatch mechanisms:
//
// OptOutDispatch:
//
// provides the functions:
//   virtual void handle(ValType* irnode) {}
//
// This provides a mechanism to override handle for particular node types. For
// example, if we only wanted to run a function on BinaryOps, we could inherit
// OptOutDispatch and simply override:
//   void handle(BinaryOp*) { doSomething; }
// Then we could run through all our Statement* and call
// OptOutDispatch::handle(statement). When a BinaryOp is encountered, our
// override function will be called. For every other node, nothing will be
// done.
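//
// An illustrative sketch (BinaryOpCounter is a hypothetical name, not a
// class defined in this codebase):
//
//   class BinaryOpCounter : public OptOutDispatch {
//    public:
//     using OptOutDispatch::handle;
//     int count = 0;
//     void handle(BinaryOp* bop) override {
//       // Called only for BinaryOps; every other node type falls through
//       // to the empty default handler.
//       ++count;
//     }
//   };
//
//   // Usage: call counter.handle(stmt) on each Statement* of interest.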
//
// OptInDispatch:
//
// This class is similar to OptOutDispatch, however if we encounter a node
// that we haven't specified an override for in the derived class, an error
// will be thrown. This is useful if we create a class that is expected to
// handle any type of node it encounters.
//
// OptOutMutator:
//
// This class is similar to OptOutDispatch except the functions provided are
// of type:
//   virtual void mutate(Statement*)
// This is useful when we want our overloaded functions to produce a
// replacement IR node (registered via registerMutation, see OptOutMutator
// below).
//
// OptInMutator:
//
// This class is similar to OptInDispatch except the functions provided are
// of type:
//   virtual void mutate(Statement*)
// This is useful when we want our overloaded functions to produce a
// replacement IR node.

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
class IrContainer;
class Fusion;

// Hierarchical dispatch functions for handle
class Statement;
class Expr;
class Val;

// Vals
class IterDomain;
class TensorDomain;
class TensorView;

class Bool;
class Double;
class Int;
class ComplexDouble;
class NamedScalar;

// Exprs
class UnaryOp;
class BinaryOp;
class TernaryOp;
class ReductionOp;
class GroupedReductionOp;
class WelfordOp;
class LoadStoreOp;
class MmaOp;
class BroadcastOp;
class TransposeOp;
class ExpandOp;
class ShiftOp;
class GatherOp;
class ViewAsScalar;
class ViewOp;

// IterDomain exprs
class Split;
class Merge;
class Swizzle2D;

namespace kir {
class Predicate;
class TensorIndex;
class IntPair;

class Allocate;
class BlockSync;
class GridSync;
class CpAsyncWait;
class CpAsyncCommit;
class ForLoop;
class IfThenElse;
class GridReduction;
class GroupedGridReduction;
class GridBroadcast;
class GridWelford;
class AllocateFusedReduction;
class InitMagicZero;
class UpdateMagicZero;
class Swizzle2DInt;
class PairSelect;

} // namespace kir

// By default, all IR nodes are handled in this dispatch, and will call an
// empty function on all nodes.
class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
 protected:
  virtual void unhandled(const Statement*) {}

 public:
  // Hierarchical dispatch functions for handle
  virtual void handle(const Statement*);
  virtual void handle(const Expr*);
  virtual void handle(const Val*);

  // Vals
  virtual void handle(const IterDomain* stmt);
  virtual void handle(const TensorDomain* stmt);
  virtual void handle(const TensorView* stmt);
  virtual void handle(const Bool* stmt);
  virtual void handle(const Double* stmt);
  virtual void handle(const Int* stmt);
  virtual void handle(const ComplexDouble* stmt);
  virtual void handle(const NamedScalar* stmt);

  virtual void handle(const kir::Predicate*);
  virtual void handle(const kir::TensorIndex*);
  virtual void handle(const kir::IntPair*);

  // Exprs
  virtual void handle(const UnaryOp* stmt);
  virtual void handle(const BinaryOp* stmt);
  virtual void handle(const TernaryOp* stmt);
  virtual void handle(const ReductionOp* stmt);
  virtual void handle(const GroupedReductionOp* stmt);
  virtual void handle(const WelfordOp* stmt);
  virtual void handle(const LoadStoreOp* stmt);
  virtual void handle(const MmaOp* stmt);
  virtual void handle(const BroadcastOp* stmt);

  virtual void handle(const Split* stmt);
  virtual void handle(const Merge* stmt);
  virtual void handle(const Swizzle2D* stmt);
  virtual void handle(const TransposeOp* stmt);
  virtual void handle(const ExpandOp* stmt);
  virtual void handle(const ShiftOp* stmt);
  virtual void handle(const GatherOp* stmt);
  virtual void handle(const ViewAsScalar* stmt);
  virtual void handle(const ViewOp* stmt);

  virtual void handle(const kir::Allocate*);
  virtual void handle(const kir::BlockSync*);
  virtual void handle(const kir::GridSync*);
  virtual void handle(const kir::CpAsyncWait*);
  virtual void handle(const kir::CpAsyncCommit*);
  virtual void handle(const kir::InitMagicZero*);
  virtual void handle(const kir::UpdateMagicZero*);
  virtual void handle(const kir::ForLoop*);
  virtual void handle(const kir::IfThenElse*);
  virtual void handle(const kir::GridReduction*);
  virtual void handle(const kir::GroupedGridReduction*);
  virtual void handle(const kir::GridBroadcast*);
  virtual void handle(const kir::GridWelford*);
  virtual void handle(const kir::AllocateFusedReduction*);
  virtual void handle(const kir::Swizzle2DInt*);
  virtual void handle(const kir::PairSelect*);
};

class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
 protected:
  virtual void unhandled(Statement*);

 public:
  // Hierarchical dispatch functions for handle
  virtual void handle(Statement*);
  virtual void handle(Expr*);
  virtual void handle(Val*);

  // Vals
  virtual void handle(Bool* stmt);
  virtual void handle(Double* stmt);
  virtual void handle(Int* stmt);
  virtual void handle(ComplexDouble* stmt);
  virtual void handle(NamedScalar* stmt);
  virtual void handle(IterDomain* stmt);
  virtual void handle(TensorDomain* stmt);
  virtual void handle(TensorView* stmt);

  virtual void handle(kir::Predicate*);
  virtual void handle(kir::TensorIndex*);
  virtual void handle(kir::IntPair*);

  // Exprs
  virtual void handle(UnaryOp* stmt);
  virtual void handle(BinaryOp* stmt);
  virtual void handle(TernaryOp* stmt);
  virtual void handle(ReductionOp* stmt);
  virtual void handle(GroupedReductionOp* stmt);
  virtual void handle(WelfordOp* stmt);
  virtual void handle(LoadStoreOp* stmt);
  virtual void handle(MmaOp* stmt);
  virtual void handle(BroadcastOp* stmt);

  virtual void handle(Split* stmt);
  virtual void handle(Merge* stmt);
  virtual void handle(Swizzle2D* stmt);
  virtual void handle(TransposeOp* stmt);
  virtual void handle(ExpandOp* stmt);
  virtual void handle(ShiftOp* stmt);
  virtual void handle(GatherOp* stmt);
  virtual void handle(ViewAsScalar* stmt);
  virtual void handle(ViewOp* stmt);

  virtual void handle(kir::Allocate* stmt);
  virtual void handle(kir::BlockSync* stmt);
  virtual void handle(kir::GridSync* stmt);
  virtual void handle(kir::CpAsyncWait* stmt);
  virtual void handle(kir::CpAsyncCommit* stmt);
  virtual void handle(kir::InitMagicZero* stmt);
  virtual void handle(kir::UpdateMagicZero* stmt);
  virtual void handle(kir::ForLoop* stmt);
  virtual void handle(kir::IfThenElse* stmt);
  virtual void handle(kir::GridReduction* stmt);
  virtual void handle(kir::GroupedGridReduction* stmt);
  virtual void handle(kir::GridBroadcast* stmt);
  virtual void handle(kir::GridWelford* stmt);
  virtual void handle(kir::AllocateFusedReduction* stmt);
  virtual void handle(kir::Swizzle2DInt* stmt);
  virtual void handle(kir::PairSelect* stmt);
};

class TORCH_CUDA_CU_API OptInConstDispatch : public OptOutConstDispatch {
 public:
  using OptOutConstDispatch::handle;

 protected:
  virtual void unhandled(const Statement* stmt) final;
};

class TORCH_CUDA_CU_API OptInDispatch : public OptOutDispatch {
 public:
  using OptOutDispatch::handle;

 protected:
  virtual void unhandled(Statement* stmt) final;
};
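
// For exhaustive processing, inherit one of the OptIn variants: any node type
// without an explicit override reaches unhandled(), which errors instead of
// silently ignoring the node. An illustrative sketch (ScalarPrinter is a
// hypothetical name):
//
//   class ScalarPrinter : public OptInConstDispatch {
//    public:
//     using OptInConstDispatch::handle;
//     void handle(const Double* d) override { /* ...process double... */ }
//     void handle(const Int* i) override { /* ...process int... */ }
//     // handle() on any other node type is a hard error.
//   };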

// Class to perform mutations on Fusion IR. Exprs can simply be redefined, but
// when mutating values they have to be registered through registerMutation so
// that exprs can detect there's been a mutation and know to modify all
// instances of that Val. This means each Val should be mutated "consistently".
// Otherwise behavior may be difficult to understand, as it depends on the
// order in which mutate is called. This class expects the user to visit the
// statements of interest in topological order, so inputs are mutated before
// the exprs that depend on them.
//
// Warning: TensorViews need to be treated carefully, as we don't generally
// register their mutation when only their tensor domains change. If a TV needs
// to be swapped out, it needs to be registered as a "proper" mutation like
// other vals, on top of TensorDomain being updated in the mutated TensorView.
//
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
 public:
  // Hierarchical dispatch functions for mutate
  virtual void mutate(Statement* s);
  virtual void mutate(Expr* e);
  virtual void mutate(Val* v);

  void registerMutation(Val* val, Val* mutation);

  // Returns the registered replacement for val, or val itself when no
  // mutation has been registered for it.
  Val* maybeMutated(Val* val) {
    if (mutations.find(val) == mutations.end()) {
      return val;
    }
    return mutations.at(val);
  }

  std::unordered_map<Val*, Val*> mutations;

  //****Functions below defined in mutator.cpp*****

  // Vals
  virtual void mutate(Bool*);
  virtual void mutate(Double*);
  virtual void mutate(Int*);
  virtual void mutate(ComplexDouble*);
  virtual void mutate(NamedScalar*);
  virtual void mutate(IterDomain*);
  virtual void mutate(TensorDomain*);
  virtual void mutate(TensorView*);

  virtual void mutate(kir::Predicate*);
  virtual void mutate(kir::TensorIndex*);
  virtual void mutate(kir::IntPair*);

  // Exprs
  virtual void mutate(UnaryOp*);
  virtual void mutate(BinaryOp*);
  virtual void mutate(TernaryOp*);
  virtual void mutate(ReductionOp*);
  virtual void mutate(GroupedReductionOp*);
  virtual void mutate(WelfordOp*);
  virtual void mutate(LoadStoreOp*);
  virtual void mutate(MmaOp*);
  virtual void mutate(BroadcastOp*);

  virtual void mutate(Split*);
  virtual void mutate(Merge*);
  virtual void mutate(Swizzle2D*);
  virtual void mutate(TransposeOp*);
  virtual void mutate(ExpandOp*);
  virtual void mutate(ShiftOp*);
  virtual void mutate(GatherOp*);
  virtual void mutate(ViewAsScalar*);
  virtual void mutate(ViewOp*);

  virtual void mutate(kir::Allocate*);
  virtual void mutate(kir::BlockSync*);
  virtual void mutate(kir::GridSync*);
  virtual void mutate(kir::CpAsyncWait*);
  virtual void mutate(kir::CpAsyncCommit*);
  virtual void mutate(kir::InitMagicZero*);
  virtual void mutate(kir::UpdateMagicZero*);
  virtual void mutate(kir::ForLoop*);
  virtual void mutate(kir::IfThenElse*);
  virtual void mutate(kir::GridReduction*);
  virtual void mutate(kir::GroupedGridReduction*);
  virtual void mutate(kir::GridBroadcast*);
  virtual void mutate(kir::GridWelford*);
  virtual void mutate(kir::AllocateFusedReduction*);
  virtual void mutate(kir::Swizzle2DInt*);
  virtual void mutate(kir::PairSelect*);

 protected:
  void removeExpr(IrContainer*, Expr*);
};

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch