pytorch/torch/csrc/jit/codegen/cuda/fusion.h
jjsjann123 df741c589f [NVFuser] Upstream push 0809 (#83067)
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes include:

- codegen improvements:
  1. removes unnecessary sync from redundant thread compute analysis
  2. symmetric API for BestEffortReplay
  3. support merge on trivial reductions
  4. Ampere async copy improvements
- bug fixes:
  1. vectorization bug fixes
  2. type inference patch: fixes upstream #81725
  3. segmenter bug fix with deterministic iteration ordering
- parser update:
  1. added leaky_relu
- scheduler:
  1. normalization scheduler cleanup
  2. simplifies matmul scheduling with new transform propagator
  3. merge all dimensions in PW scheduler
  4. various gemm related improvements
- debuggability:
  1. nsight compute support
  2. debug dump for InlinePropagator
  3. Add `UnaryOpType::Print`

Squashed commits to work around (WAR) the GitHub API.
Commits from the devel branch that are actually in this PR:

```
dfe02f3faed4c64477e5f5c678f21f33415d0195 Merge remote-tracking branch 'csarofeen/devel' into HEAD
16173732ecfafc4797e93c2449cfb778015a6c7a Add `TensorViewBuilder::shape(std::vector<Val*> shape)` (#1884)
7cfb7796bdcf055eb61d600b7b5c9df292950290 Merge pull request #1887 from csarofeen/upstream_merge_0803
3399f6de62061d30781de50ef1862bbfb1615173 Merge remote-tracking branch 'origin/viable/strict' into HEAD
01208f5bba3bc158d41ccbefa0ee2c5ceea7aedb Add `UnaryOpType::Print` which can be helpful for debugging (#1878)
0646522454aa715ef164c88a73fb8bdddc706805 Remove redundant TORCH_INTERNAL_ASSERT in lower_magic_zero.cpp (#1881)
7bc76aa219293a59e4166e258d76289fe13633ca Fix most inlined propagator for mismatched dims (#1875)
501f4aa270bf4dd47b0d2f4860bc6f23ebc32a38 Nonaffine swizzle formulation ep.2: Loop swizzle variant. (#1826)
d863d690f923047a85b5229a787118708f810741 Ampere async copy ep.2: circular buffering extension to support pipelined matmul operand load (#1827)
e0ae11a61c87cd998e88ddd79a496548171c31e0 Larger sized mma instructions to support full vectorization (#1824)
9bb4cf7a66b098f04c9d95a2d34ab2bceee151b3 fragment iteration to support fully unrolled mma ops (#1823)
a48270a18dc2d3accc2626758d14d5858ae55032 Merge all dims in pointwise scheduler (#1872)
172fb3673fb4aaf4c1e889922a4fc5c06cbd59f7 Make MostInlined and BestEffort inline propagation no longer assert replayed (#1868)
a64462a5ac2fcf57a177bf36b0f26c61a4e252a4 Allow trivial reduction to be merged (#1871)
440102bcda6eb1dcd42d5fa5aeab9d6b049956bc Symmetric API for BestEffortReplay (#1870)
d1caf330c08ea8002f7133ca655bbd5b28c4eb98 Some misc cleanups/refactor split out from #1854 (#1867)
1013eda50be38eac96c00ba781340ac199d5a136 Remove some welford specific logic. (#1864)
51589d36be5a101d06e641fe0400b39028b7cb81 Some cleanups on tests and heuristics params (#1866)
a6b3e70da5dee51dbc246347228ea21384e46ac3 Segmenter bug fix, and deterministic iteration ordering.  (#1865)
1b665b9b5e562d6f0caba5e7319e83e5df64104f Add nullptr checks to IrBuilder (#1861)
1cd9451d7493f631c2837ba07c1ea93a74e83a15 Simplify matmul scheduling with the new transform propagator.  (#1817)
bbc1fb9b8c454f557ab9fcf5b1c3cef9b9e136d0 Add leaky_relu operation (#1852)
e842a9bab5e9f7289b7ce33ee37a682b22373f49 Minor cleanup in pointwise scheduler (#1858)
9ee850ca2f7f51dd5269bffb1255e485f809282d Fix stringstream usage (#1857)
20a36c1e4f28c4ff9837e56784be2686d17435f3 Improve nsight compute support (#1855)
405910308301097297b55c34d560aab6a360e897 Remove debugging `true ||` from getPointwiseHeuristics (#1822)
01117bfe8fdfacdbfdcfba9a624cdf900fe044d4 Misc cleanup (#1853)
5cc64943dc381a568223140bce0f22163c01e29f Apply the magic-zero protection to each indexed domain individually for predicate indexing (#1846)
92e6f0207e3a89fe90fd5cd3ffc575dfd766ba00 Cleanup normalization scheduler (#1845)
db89c6591a2f21130599a93675e0615e55564e41 Type inference patch (#1848)
102fe93a4605ca465cda26ebaee4ba1af2026901 Add debug dump for InlinePropagator (#1847)
b7a4d93d375a6e2ddef483763c93ffddc62ec452 Redundant thread compute analysis to avoid un-necessary sync insertion (#1687)
942be5b256056d0e02877361b814ae6af32ca15f Upstream ci build fixes (#1842)
0b83645915029d67f9345aa4649b8c6f62b0061b Fix vectorization bug introduced in #1831 (#1840)
63630f1ae091180e541932a9d9dc598e0a9902dd Move MaxProducerPosUpdater into InlinePropagator::tearDown (#1825)
9135a963c01d97ba34b1a7d2f106e78a13fd6651 Fix transpose benchmark dtype (#1839)
2c9a6c02312d5bf4f83cde653b847b4f85849432 Add extra configurability to `parallelizeAllLike` (#1831)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D38543000](https://our.internmc.facebook.com/intern/diff/D38543000)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83067
Approved by: https://github.com/davidberard98
2022-08-10 21:02:56 +00:00


#pragma once
#include <ATen/core/ivalue.h>
#include <c10/macros/Export.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_container.h>
#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
//! Usage: FusionGuard and Fusion are required user interfaces for any operation
//! underlying the code generator. A Fusion instance must be active in order to
//! create values and expressions and to generate code. It is the responsibility
//! of the user to create a Fusion instance and register it with the fusion
//! guard.
//! The simplest example of this is:
//!
//! Fusion fusion;
//! FusionGuard fg(&fusion);
//!
//! Once a fusion is active, all values and operations will be registered with
//! it.
//!
//! FusionGuard and Fusion are critical to the lifetime model of the IR system.
//! FusionGuard is a convenient way to set which base container instance holds
//! the defined IR. Statements that are defined are registered through the
//! FusionGuard with a particular Fusion. FusionGuard provides convenient
//! methods to access the active fusion so it doesn't need to be passed around
//! constantly. Every IR node class derived from Statement must register with a
//! Fusion to avoid memory leaks.
//!
//! A Fusion is generally thought of as a translated fusion group from the JIT.
//! It is likely a single kernel, although we don't have to stick to this in the
//! future and could in theory generate multiple kernels with an executor to run
//! them.
//!
//! Fusion also allows users to set input/output values, which tell us how to
//! hook up runtime data to and from the JIT and provide the mechanisms for
//! dependency analysis and dead code elimination (DCE), including safety
//! checks.
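The dependency analysis and dead code elimination (DCE) mentioned above can be illustrated with a small, self-contained sketch. It does not use nvFuser: `Expr`, `MiniGraph`, `exprsFromOutputs`, and `unregisteredLeaves` are hypothetical names, and values are plain strings. Starting from the registered outputs, it follows each value's single (SSA) definition backward, emitting expressions in topological order (loosely analogous to what `Fusion::exprs()` is documented to return) and collecting the leaves, which `Fusion::validateInputs()` is documented to require as registered inputs. Expressions that no output depends on are never visited. How the real implementation treats constants or other special leaves is not shown in this header and is not modeled here.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

// Hypothetical, simplified IR (not nvFuser's classes): values are strings and
// each Expr lists the values it consumes and produces.
struct Expr {
  std::string name;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

class MiniGraph {
 public:
  // Register an expression; in an SSA graph each value has one definition.
  void add(Expr e) {
    for (const auto& out : e.outputs) {
      assert(definition_.count(out) == 0 && "SSA: one definition per value");
      definition_[out] = exprs_.size();
    }
    exprs_.push_back(std::move(e));
  }

  // Names of the exprs needed to compute `outputs`, in topological order.
  // Exprs that no output depends on are skipped (dead code elimination).
  std::vector<std::string> exprsFromOutputs(
      const std::vector<std::string>& outputs) const {
    std::vector<std::string> order;
    std::unordered_set<std::string> leaves;
    walk(outputs, order, leaves);
    return order;
  }

  // Leaves reachable from `outputs` that are missing from `inputs`;
  // an empty result means every leaf is a registered input.
  std::vector<std::string> unregisteredLeaves(
      const std::vector<std::string>& inputs,
      const std::vector<std::string>& outputs) const {
    std::vector<std::string> order;
    std::unordered_set<std::string> leaves;
    walk(outputs, order, leaves);
    std::vector<std::string> missing;
    for (const auto& leaf : leaves) {
      if (std::find(inputs.begin(), inputs.end(), leaf) == inputs.end()) {
        missing.push_back(leaf);
      }
    }
    std::sort(missing.begin(), missing.end());  // deterministic result
    return missing;
  }

 private:
  void walk(const std::vector<std::string>& outputs,
            std::vector<std::string>& order,
            std::unordered_set<std::string>& leaves) const {
    std::unordered_set<std::size_t> visited;
    for (const auto& out : outputs) {
      visit(out, visited, order, leaves);
    }
  }

  // Post-order DFS: an expr is appended only after the exprs that define its
  // inputs, which is a valid topological order for an acyclic (SSA) graph.
  void visit(const std::string& val, std::unordered_set<std::size_t>& visited,
             std::vector<std::string>& order,
             std::unordered_set<std::string>& leaves) const {
    auto it = definition_.find(val);
    if (it == definition_.end()) {
      leaves.insert(val);  // no definition: must come from a fusion input
      return;
    }
    if (!visited.insert(it->second).second) {
      return;  // already emitted through another path
    }
    const Expr& e = exprs_[it->second];
    for (const auto& in : e.inputs) {
      visit(in, visited, order, leaves);
    }
    order.push_back(e.name);
  }

  std::vector<Expr> exprs_;
  std::unordered_map<std::string, std::size_t> definition_;
};
```

For example, with `e0: c = f(a, b)`, `e1: d = g(c)`, and `e2: unused = h(a)`, requesting output `d` yields `e0, e1` (the dead `e2` is skipped), and registering only `a` as an input leaves `b` reported as missing.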
class Fusion;
class TensorView;
class WelfordResult;
class SegmentCandidateFinder;
class SegmentedFusion;
//! FusionGuard is our "context manager". It holds the active fusion and
//! allows it to be accessed anywhere through FusionGuard::getCurFusion().
class TORCH_CUDA_CU_API FusionGuard {
 public:
  Fusion* prev_fusion;

  //! Set the active fusion so it can be manipulated.
  explicit FusionGuard(Fusion* fusion);
  ~FusionGuard();

  static Fusion* getCurFusion();
  static void setCurFusion(Fusion* fusion);
};
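The guard's role as a "context manager" can be sketched as a plain RAII type. This is an illustration, not the real `FusionGuard`: `Container`, `ContainerGuard`, and `requireActive` are hypothetical, and the per-thread `thread_local` pointer is an assumption of the sketch, since this header does not show where the active fusion is stored. The constructor saves the previously active container (the real class exposes a `prev_fusion` member, which by its name appears to serve a similar purpose) and the destructor restores it.

```cpp
#include <cassert>

// Hypothetical stand-in for the IR container; not nvFuser's Fusion.
struct Container {
  int id;
};

// RAII "context manager": makes a container active for the guard's lifetime
// and reinstates the previously active one when the guard is destroyed.
class ContainerGuard {
 public:
  explicit ContainerGuard(Container* container) : prev_(active_) {
    active_ = container;
  }
  ~ContainerGuard() { active_ = prev_; }

  // Copying would let two guards restore the same saved pointer.
  ContainerGuard(const ContainerGuard&) = delete;
  ContainerGuard& operator=(const ContainerGuard&) = delete;

  static Container* current() { return active_; }

 private:
  Container* prev_;
  // Assumption of this sketch: one active container per thread.
  static thread_local Container* active_;
};

thread_local Container* ContainerGuard::active_ = nullptr;

// Analogous to requiring an active Fusion before IR nodes are created.
Container& requireActive() {
  Container* c = ContainerGuard::current();
  assert(c != nullptr && "no active container");
  return *c;
}
```

Because each destructor reinstates the container that was active when its guard was created, guards nest like scopes, so code deep in a call chain can reach the active container without it being passed around.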
//! Fusion is mutable but unique. Nodes cannot be copied in any way from one
//! Fusion to another. If anything like that is desired, it would require
//! duplicating all associated values and exprs. Fusion is considered to be SSA,
//! though this could also change in the future if there is a good reason to do
//! so.
//!
//! The Fusion owns the whole IR graph (Vals and Exprs).
//!
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class TORCH_CUDA_CU_API Fusion : public IrContainer {
  typedef std::unordered_map<int, std::vector<int64_t>> PermutationMap;

 public:
  Fusion() = default;

  Fusion(const Fusion& other);
  Fusion(Fusion&& other) noexcept;

  Fusion& operator=(const Fusion& other);
  Fusion& operator=(Fusion&& other) noexcept;

  ~Fusion();

  friend void swap(Fusion& a, Fusion& b) noexcept;

  void clear() noexcept;

  //! Break the dependency chains associated with expr, remove references to
  //! it, and delete it.
  void removeExpr(Expr* expr) override;

  //! Completely remove val from the fusion and break all dependencies
  //! associated with it.
  void removeVal(Val* val) override;

  //! Register input as an input of the fusion
  void addInput(Val* input);

  //! Register output as an output of the fusion
  void addOutput(Val* output);

  //! Deregister input as an input of the fusion
  void removeInput(Val* input);

  //! Deregister output as an output of the fusion
  void removeOutput(Val* output);

  //! Replace output with another value
  void replaceOutput(Val* output, Val* replacement);

  //! Assert that all leaves found from outputs are registered as an input
  void validateInputs();

  //! Print this fusion to the console
  void print();

  //! Print Arith exprs
  //! \param from_outputs_only Only print exprs reachable from outputs
  void printMath(bool from_outputs_only = true);

  //! Print transformations used in fusion (can be very verbose)
  void printTransforms();

  //! Lower the fusion and print a kernel
  void printKernel(DataType index_type = DataType::Int);

  //! Return a list of topologically sorted expressions. This only includes
  //! exprs required to generate registered outputs.
  std::vector<Expr*> exprs();

  //! Return a vector of fusion inputs that feed this Val
  std::vector<Val*> inputsOf(Val* val);

  //! Return all Vals in math expressions that cannot be eliminated.
  //!
  //! It is generally equivalent to the vals used to generate outputs.
  //! However, when a multi-output expression exists and only some of its
  //! outputs are used, the remaining unused outputs are also included, as
  //! they must show up in the final code.
  std::vector<Val*> usedMathVals();

  //! Returns all vals that are produced by used math expressions and
  //! also do not have further consumers.
  //!
  //! In the case of an active multi-output expression, the returned vector
  //! will include the expression outputs that did not lead to a fusion
  //! output.
  std::vector<Val*> terminatingMathVals();

  //! Return all Exprs that use val
  std::unordered_set<Expr*> unordered_uses(const Val* val) const;

  //! Return the Expr that produces val
  Expr* definition(const Val* val) const;

  //! Return whether the kernel needs to set itself up to generate random
  //! numbers
  bool isStochastic();

  //! Run fusion segmentation algorithm to create a segmented fusion
  std::unique_ptr<SegmentedFusion> segment(
      const at::ArrayRef<at::IValue>& inputs);

  const auto& inputs() const {
    return inputs_;
  }

  const auto& outputs() const {
    return outputs_;
  }

  std::vector<Val*> getTerminatingOutputs();

  // Alias an output to an input value. This is a workaround (WAR) to allow
  // in-place updates of an input tensor.
  // Note: this is not always safe and should be used with extra caution.
  // Currently the only place it's used is in the running stats update for
  // batch normalization.
  // TODO: segmentation should be made aware of aliases, so that the input
  // tensor is always included in the segment where the output is produced.
  void aliasOutputToInput(Val* output, Val* input);
  Val* getOutputAlias(Val* output);
  std::unordered_set<int> getOutputAliasIndices() const;
  std::vector<std::pair<int, int>> getInputAliasIndices() const;

  // Mark the input at index to be permuted by permutation
  void setPermutationOnInput(int index, std::vector<int64_t> permutation) {
    permuted_input_map_.insert({index, permutation});
  }

  // Mark the output at index to be restored by permutation
  void setPermutationOnOutput(int index, std::vector<int64_t> permutation) {
    permuted_output_map_.insert({index, permutation});
  }

  // Return a map from indices to permutations, indicating all input tensors
  // that need to be permuted
  const PermutationMap& getPermutationInputMap() const {
    return permuted_input_map_;
  }

  // Return a map from indices to permutations, indicating all output tensors
  // that need to be permuted
  const PermutationMap& getPermutationOutputMap() const {
    return permuted_output_map_;
  }

  bool isTVUseInfoValid() {
    return all_tv_uses_valid_;
  }

  bool isUpdatingTVUseInfo() {
    return is_during_update_uses_;
  }

  const auto& ioAlias() const {
    return io_alias_;
  }

 protected:
  friend SegmentCandidateFinder;
  friend SegmentedFusion;
  friend class TranslateApplicableWelford;
  friend Val;

  static IrCloner copy(const Fusion* from, Fusion* to);

  //! Register the Val with this fusion
  virtual void registerVal(Val* val) override;

  //! Register expr with this fusion.
  //! When we register an expression, we want to update the dependency tracking
  //! of Vals. If this container is not a Kernel, it will remove previous
  //! definitions of outputs and register this Expr as the definition.
  //! Otherwise, it will update the definition if not previously set, but will
  //! not remove old definitions.
  virtual void registerExpr(Expr* expr) override;

  //! Clear Exprs from TV uses that are not required to produce outputs from
  //! inputs. The only other place this is used (other than Fusion) is in
  //! Val::uses()
  void resetTvUses();

 private:
  // Determine if the two values are compatible for aliasing:
  // same DataType, ValType, and number of dimensions
  bool isAliasCompatible(Val* left, Val* right);

 private:
  // Fusion inputs and outputs
  std::vector<Val*> inputs_;
  std::vector<Val*> outputs_;

  // io alias pointing from output to input
  std::unordered_map<Val*, Val*> io_alias_;

  // See Note [ Permutation support in nvfuser ]
  // map from indices of input tensor to permutation
  PermutationMap permuted_input_map_;

  // map from indices of output tensor to permutation
  PermutationMap permuted_output_map_;

  // Records whether the current use data in the IR nodes is valid;
  // the states are either all valid or all invalid
  bool all_tv_uses_valid_ = false;
  bool is_during_update_uses_ = false;
};
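The alias helpers above keep an output-to-input map (`io_alias_`) and, per the comment on `isAliasCompatible`, only pair values with the same DataType, ValType, and number of dimensions. The following sketch models that bookkeeping outside nvFuser. `DType`, `Kind`, `Value`, and `AliasTable` are hypothetical, and reporting aliases as sorted (output index, input index) pairs is this sketch's own convention, not necessarily the pair order returned by `getInputAliasIndices()`.

```cpp
#include <algorithm>
#include <cassert>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical stand-ins for DataType and ValType; not nvFuser's enums.
enum class DType { Float, Half, Int };
enum class Kind { Tensor, Scalar };

struct Value {
  DType dtype;
  Kind kind;
  int ndims;
};

// Output -> input alias table with a compatibility check.
class AliasTable {
 public:
  AliasTable(std::vector<const Value*> inputs,
             std::vector<const Value*> outputs)
      : inputs_(std::move(inputs)), outputs_(std::move(outputs)) {}

  // Same data type, same value kind, same number of dimensions.
  static bool compatible(const Value& a, const Value& b) {
    return a.dtype == b.dtype && a.kind == b.kind && a.ndims == b.ndims;
  }

  // Records the alias and returns true only if both values are registered
  // and compatible; otherwise records nothing and returns false.
  bool alias(const Value* output, const Value* input) {
    if (indexOf(outputs_, output) < 0 || indexOf(inputs_, input) < 0 ||
        !compatible(*output, *input)) {
      return false;
    }
    alias_[output] = input;
    return true;
  }

  // The input an output is aliased to, or nullptr if it has no alias.
  const Value* aliasOf(const Value* output) const {
    auto it = alias_.find(output);
    return it == alias_.end() ? nullptr : it->second;
  }

  // (output index, input index) pairs, sorted by output index.
  std::vector<std::pair<int, int>> indexPairs() const {
    std::vector<std::pair<int, int>> pairs;
    for (const auto& entry : alias_) {
      const int out_idx = indexOf(outputs_, entry.first);
      const int in_idx = indexOf(inputs_, entry.second);
      assert(out_idx >= 0 && in_idx >= 0 && "alias() stores registered values");
      pairs.emplace_back(out_idx, in_idx);
    }
    std::sort(pairs.begin(), pairs.end());
    return pairs;
  }

 private:
  static int indexOf(const std::vector<const Value*>& vals, const Value* v) {
    auto it = std::find(vals.begin(), vals.end(), v);
    return it == vals.end() ? -1 : static_cast<int>(it - vals.begin());
  }

  std::vector<const Value*> inputs_;
  std::vector<const Value*> outputs_;
  std::unordered_map<const Value*, const Value*> alias_;
};
```

The compatibility check only rules out obviously mismatched pairs; it does not make an in-place update safe in general, which is why the comment above asks for extra caution.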
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch
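`usedMathVals()` and `terminatingMathVals()` differ from a plain reachability walk only when a multi-output expression, such as a Welford-style reduction that produces several results, has outputs that never reach a fusion output. The sketch below uses a hypothetical string-based IR (`Expr`, `MathVals`, and `classifyMathVals` are illustrative names, not nvFuser API). It marks expressions reachable from the fusion outputs as used, keeps every output of a used expression, and treats a value as terminating when no used expression consumes it; counting only used expressions as consumers is an assumption of this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical simplified IR (not nvFuser's classes): values are strings and
// an Expr may produce several outputs.
struct Expr {
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

struct MathVals {
  std::set<std::string> used;         // values that cannot be eliminated
  std::set<std::string> terminating;  // produced by a used expr, never consumed
};

MathVals classifyMathVals(const std::vector<Expr>& exprs,
                          const std::vector<std::string>& fusion_outputs) {
  // Map each value to the single expr that defines it (SSA).
  std::unordered_map<std::string, std::size_t> definition;
  for (std::size_t i = 0; i < exprs.size(); ++i) {
    for (const auto& out : exprs[i].outputs) {
      assert(definition.count(out) == 0 && "SSA: one definition per value");
      definition[out] = i;
    }
  }

  // Walk backward from the fusion outputs to mark the used exprs.
  std::vector<bool> expr_used(exprs.size(), false);
  std::vector<std::string> worklist(fusion_outputs);
  while (!worklist.empty()) {
    const std::string val = worklist.back();
    worklist.pop_back();
    auto it = definition.find(val);
    if (it == definition.end() || expr_used[it->second]) {
      continue;  // a leaf, or an expr that was already visited
    }
    expr_used[it->second] = true;
    for (const auto& in : exprs[it->second].inputs) {
      worklist.push_back(in);
    }
  }

  MathVals result;
  std::set<std::string> consumed;
  for (std::size_t i = 0; i < exprs.size(); ++i) {
    if (!expr_used[i]) continue;
    // Every output of a used expr is kept, even if no fusion output needs
    // it, because the expression still produces it in the generated code.
    result.used.insert(exprs[i].inputs.begin(), exprs[i].inputs.end());
    result.used.insert(exprs[i].outputs.begin(), exprs[i].outputs.end());
    consumed.insert(exprs[i].inputs.begin(), exprs[i].inputs.end());
  }
  for (std::size_t i = 0; i < exprs.size(); ++i) {
    if (!expr_used[i]) continue;
    for (const auto& out : exprs[i].outputs) {
      if (consumed.count(out) == 0) result.terminating.insert(out);
    }
  }
  return result;
}
```

With a multi-output expression `x -> {mean, var, n}` followed by `mean -> out`, and a dead `x -> dead`, the values `var` and `n` stay in the used set even though only `mean` feeds `out`, and they appear among the terminating values alongside `out` itself.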