pytorch/torch/csrc/jit/codegen/cuda/fusion.h
jjsjann123 0e582fbfcc [NVFuser] Upstream push 0907 (#84626)
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Codegen changes include:

- codegen improvement:
i. improved view support on pointwise and transpose scheduler
ii. grouped grid welford added for better outer-norm grid persistence in normalization

- misc:
i. new composite ops added: variance_mean, arange
ii. fix for misaligned address in transpose scheduler
iii. refactor separating the compilation API from the execution API, to prepare us for async compilation
iv. double type support in the expression evaluator
v. PYTORCH_NVFUSER_DUMP refactor to save PTX and CUBIN

Commits in this PR from the devel branch:
```
89330aa23aa804340b2406ab58899d816e3dc3d2 Tensor factories must set the output shape as its input (#1939)
b2fd01ea9346712c6d6f623ca6addbc4888d008e arange support (#1933)
56c00fd3922dad7dfc57351ad7d780f0f2f8e4ed Double support on all expression evaluators (#1937)
371f28223e57fe3f6b5e50a0a45177e6a5c0785c Improve trivial reduction merge support (#1931)
1d0c26790e5647920b40d419d26815bbe310b3a6 Test `rand` in a fusion with zero tensor input (#1932)
0dab160fb2177d178eef3148c6a529e0855009e9 Fix softmax bwd sizes. (#1890)
ef98f360f6d3e3e1cc662ecb65202d88150f128d Fix a bug (#1936)
63132a0c56508c550084b07fb76a3df865102d00 Propagate permissive mapping information into indexing pass (#1929)
b4ac2c88d78078ee4d8b21c4fc51645b5710a282 Map IterationDomains through view operations. (#1919)
c0a187a7619d7cf9dc920294e15461791e8d6d4d do not use deprecated functions (#1935)
88de85e758c5e4afb7b6e746573c0d9a53b4cea7 Upstream cherry pick fixes 0811 (#1934)
b247dcf7c57dc6ac3f7a799b0a6beb7770536a74 Separate kernel compilation API from kernel execution API (#1914)
b34e3b93ee1a8030730c14af3995dd95665af07d Fix `ir_utils::hasBlockSync` + misc fixes in transpose scheduler (#1924)
14a53e6707f43bf760494c238a46386d69830822 Nullary RNGOp (#1892)
3c3c89e638f5172cafb0761f22bacd1fd695eec3 Misc fixes/tuning for transpose scheduler (#1912)
20cf109c8b44d48f61977e35bae94368985144ac Grouped grid welford (#1921)
6cf7eb024c9e53c358cbe56597e117bad56efefd Transpose scheduler small dim sizes better support (#1910)
9341ea9a5bf42f9b14ccad0c94edbc79fc5bb552 Disabled ViewPersistentShmoo sizes that results in NAN (#1922)
057237f66deeea816bb943d802a97c1b7e4414ab Fix CUDA driver error: misaligned address for transpose scheduler  (#1918)
3fb3d80339e4f794767a53eb8fdd61e64cf404a2 Add variance_mean function using Welford (#1907)
98febf6aa3b8c6fe4fdfb2864cda9e5d30089262 Remove DisableOption::UnrollWithRng (#1913)
ee8ef33a5591b534cf587d347af11e48ba7a15d4 Minor fix for the debug interface of using PTX directly (#1917)
6e8f953351f9dabfd1f991d8431cecb6c2ce684d Add PYTORCH_NVFUSER_DUMP options to save PTX and CUBIN (#1916)
5eefa9a72385f6a4b145680a9dcc52d7e8293763 dopt is only available since nvrtc 11.7 (#1915)
2ec8fc711eafc72451eebf0f5e2a98a38bf3f6ef Kill computeAtBetween (#1911)
d0d106a1d9af118d71673173674e875be35d259d Improve view support on pointwise and transpose scheduler (#1906)
e71e1ecefe67219846070590bbed54bbc7416b79 Fix name clash of RNG with shared memory (#1904)
3381793a253689abf224febc73fd3fe2a0dbc921 Fix mutator and sameAs for expanded IterDomain (#1902)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D39324552](https://our.internmc.facebook.com/intern/diff/D39324552)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84626
Approved by: https://github.com/malfet
2022-09-23 20:29:48 +00:00


#pragma once
#include <ATen/core/ivalue.h>
#include <c10/macros/Export.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_container.h>
#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
//! Usage: FusionGuard and Fusion are required user interfaces for any operation
//! underlying the code generator. In order to create values and expressions,
//! and to generate code, a Fusion instance must be active. It is the
//! responsibility of the user to create a Fusion instance and register it with
//! the fusion guard. The simplest example of this is:
//! The simplest example of this is:
//!
//! Fusion fusion;
//! FusionGuard fg(&fusion);
//!
//! Once a fusion is active all values and operations will be registered with
//! it.
//!
//! FusionGuard and Fusion are critical to the lifetime model of the IR system.
//! FusionGuard is a convenient way to set what base container instance holds
//! the defined IR. Statements that are defined are registered through the
//! FusionGuard with a particular Fusion. FusionGuard provides convenient
//! methods to access the active fusion so it doesn't need to be passed around
//! constantly. Any IR node derived classes from Statement must register with
//! Fusion to avoid memory leaks.
//!
//! Fusion is generally thought of as a translated fusion group from the JIT. It
is likely a single kernel, although we don't have to stick to this in the
//! future and could in theory generate multiple kernels with an executor to run
//! them.
//!
//! Fusion also allows users to set input/output values that will allow us to
//! figure out how to hook up runtime data to and from the JIT as well as
//! provide us mechanisms for dependency analysis and DCE including safety
//! checks.
class Fusion;
class TensorView;
class WelfordResult;
class SegmentCandidateFinder;
class SegmentedFusion;
class KernelArgumentHolder;
//! FusionGuard is our "context manager". It holds the active fusion and
//! allows it to be accessed anywhere through FusionGuard::getCurFusion().
class TORCH_CUDA_CU_API FusionGuard {
public:
Fusion* prev_fusion;
//! Set the active fusion so it can be manipulated.
explicit FusionGuard(Fusion* fusion);
~FusionGuard();
static Fusion* getCurFusion();
static void setCurFusion(Fusion* fusion);
};
//! Fusion is mutable but unique. Nodes cannot be copied in any way from one
//! Fusion to another. If anything like that is desired, it would require
//! duplicating all associated values and exprs. Fusion is considered to be SSA,
//! though this could also change in the future if there is a good reason to do
//! so.
//!
//! The Fusion owns the whole IR graph (Vals and Exprs)
//!
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class TORCH_CUDA_CU_API Fusion : public IrContainer {
typedef std::unordered_map<int, std::vector<int64_t>> PermutationMap;
public:
Fusion() = default;
Fusion(const Fusion& other);
Fusion(Fusion&& other) noexcept;
Fusion& operator=(const Fusion& other);
Fusion& operator=(Fusion&& other) noexcept;
~Fusion();
friend void swap(Fusion& a, Fusion& b) noexcept;
void clear() noexcept;
//! Break dependency chains associated with Expr, remove references to expr,
//! and delete expr
void removeExpr(Expr* expr) override;
//! Completely remove val from the fusion, break all dependencies associated
//! with it
void removeVal(Val* val) override;
//! Register input as an input of the fusion
void addInput(Val* input);
//! Register output as an output of the fusion
void addOutput(Val* output);
//! Deregister input as an input of the fusion
void removeInput(Val* input);
//! Deregister output as an output of the fusion
void removeOutput(Val* output);
//! Replace output with another value
void replaceOutput(Val* output, Val* replacement);
//! Assert that all leaves found from outputs are registered as inputs
void validateInputs();
//! Print this fusion to the console
void print();
//! Print Arith exprs
//! \param from_outputs_only Only print exprs reachable from outputs
void printMath(bool from_outputs_only = true);
//! Print transformations used in fusion (can be very verbose)
void printTransforms();
//! Lower the fusion and print a kernel
void printKernel(DataType index_type = DataType::Int);
//! Return a list of topologically sorted expressions. This only includes
//! exprs required to generate registered outputs.
std::vector<Expr*> exprs();
//! Return a vector of fusion inputs that feed this Val
std::vector<Val*> inputsOf(Val* val);
//! Return all Vals in math expressions that cannot be eliminated.
//!
//! It is generally equivalent to vals that are used to generate
//! outputs, however, when a multi-output expression exists, and only
//! some of the outputs are used, the remaining unused outputs are
//! also included as they must show up in the final code.
std::vector<Val*> usedMathVals();
//! Returns all vals that are produced by used math expressions and
//! also do not have further consumers.
//!
//! In the case of an active multi-output expression, the returned vector
//! will include the expression outputs that did not lead to a fusion
//! output.
std::vector<Val*> terminatingMathVals();
//! Return all Exprs that use val
std::unordered_set<Expr*> unordered_uses(const Val* val) const;
//! Return the Expr that produces val
Expr* definition(const Val* val) const;
//! Indicate whether the kernel must set itself up to generate random numbers
bool isStochastic();
//! Run fusion segmentation algorithm to create a segmented fusion
std::unique_ptr<SegmentedFusion> segment(const KernelArgumentHolder& args);
const auto& inputs() const {
return inputs_;
}
std::vector<Val*> inputsAndCreated();
const auto& outputs() const {
return outputs_;
}
std::vector<Val*> getTerminatingOutputs() const;
// Alias an output to an input value; this is a workaround to allow in-place
// updates on an input tensor.
// Note: this is not always safe and should be used with extra caution.
// Currently the only place it's used is in the running stats update for batch
// normalization.
// TODO: segmentation should be made aware of aliases, so we always include
// the input tensor in the segment where the output is produced.
void aliasOutputToInput(Val* output, Val* input);
Val* getOutputAlias(Val* output);
std::unordered_set<int> getOutputAliasIndices() const;
std::vector<std::pair<int, int>> getInputAliasIndices() const;
// mark input at index to be permuted by permutation
void setPermutationOnInput(int index, std::vector<int64_t> permutation) {
permuted_input_map_.insert({index, permutation});
}
// mark output at index to be restored by permutation
void setPermutationOnOutput(int index, std::vector<int64_t> permutation) {
permuted_output_map_.insert({index, permutation});
}
// return a map of indices to permutations, indicating all input tensors
// that need to be permuted
const PermutationMap& getPermutationInputMap() const {
return permuted_input_map_;
}
// return a map of indices to permutations, indicating all output tensors
// that need to be permuted
const PermutationMap& getPermutationOutputMap() const {
return permuted_output_map_;
}
bool isTVUseInfoValid() {
return all_tv_uses_valid_;
}
bool isUpdatingTVUseInfo() {
return is_during_update_uses_;
}
const auto& ioAlias() const {
return io_alias_;
}
protected:
friend SegmentCandidateFinder;
friend SegmentedFusion;
friend class TranslateApplicableWelford;
friend Val;
static IrCloner copy(const Fusion* from, Fusion* to);
//! Register the Val with this fusion
virtual void registerVal(Val* val) override;
//! Register expr with this fusion.
//! When we register an expression, we want to update the dependency tracking
//! of Vals. If this container is not a Kernel, it will remove previous
//! definitions of outputs and register this Expr as the definition. Otherwise
//! it will update the definition if not previously set, but will not remove
//! old definitions.
virtual void registerExpr(Expr* expr) override;
//! Clear Exprs from TV uses that are not required to produce outputs from
//! inputs. The only other place this is used (other than Fusion) is in
//! Val::uses()
void resetTvUses();
private:
// Determine if the two values are compatible for aliasing
// Same DataType, ValType, and number of dimensions
bool isAliasCompatible(Val* left, Val* right);
private:
// Fusion inputs and outputs
std::vector<Val*> inputs_;
std::vector<Val*> outputs_;
// io alias pointing from output to input
std::unordered_map<Val*, Val*> io_alias_;
// See Note [ Permutation support in nvfuser ]
// map from indices of input tensor to permutation
PermutationMap permuted_input_map_;
// map from indices of output tensor to permutation
PermutationMap permuted_output_map_;
// Records if the current use data in the IR nodes are valid
// the states are either all valid or all invalid
bool all_tv_uses_valid_ = false;
bool is_during_update_uses_ = false;
};
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch