#pragma once

#include <ATen/core/ivalue.h>
#include <c10/macros/Export.h>
#include <c10/util/Exception.h>

#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_container.h>
#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>

#include <unordered_map>
#include <unordered_set>
#include <vector>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

//! Usage: FusionGuard and Fusion are required user interfaces for any operation
//! underlying the code generator. In order to create values, expressions, and
//! generate code, a Fusion instance must be active. It is the responsibility of
//! the user to create a Fusion instance and register it with the fusion guard.
//! The simplest example of this is:
//!
//!   Fusion fusion;
//!   FusionGuard fg(&fusion);
//!
//! Once a fusion is active, all values and operations will be registered with
//! it.
//!
//! FusionGuard and Fusion are critical to the lifetime model of the IR system.
//! FusionGuard is a convenient way to set which base container instance holds
//! the defined IR. Statements that are defined are registered through the
//! FusionGuard with a particular Fusion. FusionGuard provides convenient
//! methods to access the active fusion so it doesn't need to be passed around
//! constantly. Any IR node class derived from Statement must register with
//! Fusion to avoid memory leaks.
//!
//! Fusion is generally thought of as a translated fusion group from the JIT. It
//! is likely a single kernel, although we don't have to stick to this in the
//! future and could in theory generate multiple kernels with an executor to run
//! them.
//!
//! Fusion also allows users to set input/output values that will allow us to
//! figure out how to hook up runtime data to and from the JIT as well as
//! provide us mechanisms for dependency analysis and DCE including safety
//! checks.

class Fusion;
class TensorView;
class WelfordResult;

class SegmentCandidateFinder;
class SegmentedFusion;

//! FusionGuard is our "context manager". It holds the active fusion and
//! allows it to be accessed anywhere through FusionGuard::getCurFusion().
class TORCH_CUDA_CU_API FusionGuard {
 public:
  Fusion* prev_fusion;

  //! Set the active fusion so it can be manipulated.
  explicit FusionGuard(Fusion* fusion);

  ~FusionGuard();

  static Fusion* getCurFusion();
  static void setCurFusion(Fusion* fusion);
};
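
// A minimal usage sketch (illustrative only, not part of this header):
// the guard makes `fusion` the active fusion for its scope, and the
// previously active fusion (tracked via prev_fusion) is restored when the
// guard is destroyed, so guards can be nested.
//
//   Fusion fusion;
//   {
//     FusionGuard fg(&fusion);
//     // IR created here is registered with `fusion`.
//     TORCH_INTERNAL_ASSERT(FusionGuard::getCurFusion() == &fusion);
//   }
//   // Any previously active fusion is current again here.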

//! Fusion is mutable but unique. Nodes cannot be copied in any way from one
//! Fusion to another. If anything like that is desired, it would require
//! duplicating all associated values and exprs. Fusion is considered to be SSA,
//! though this could also change in the future if there is a good reason to do
//! so.
//!
//! The Fusion owns the whole IR graph (Vals and Exprs).
//!
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class TORCH_CUDA_CU_API Fusion : public IrContainer {
  typedef std::unordered_map<int, std::vector<int64_t>> PermutationMap;

 public:
  Fusion() = default;

  Fusion(const Fusion& other);
  Fusion(Fusion&& other) noexcept;

  Fusion& operator=(const Fusion& other);
  Fusion& operator=(Fusion&& other) noexcept;

  ~Fusion();

  friend void swap(Fusion& a, Fusion& b) noexcept;

  void clear() noexcept;

  //! Break dependency chains associated with Expr, remove references to
  //! expr, and delete expr.
  void removeExpr(Expr* expr) override;

  //! Completely remove val from the fusion, break all dependencies
  //! associated with it.
  void removeVal(Val* val) override;

  //! Register input as an input of the fusion
  void addInput(Val* input);

  //! Register output as an output of the fusion
  void addOutput(Val* output);
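
  // Illustrative sketch of building a fusion and registering its I/O
  // (assumes the usual nvfuser IR helpers such as makeSymbolicTensor and
  // add, which are declared elsewhere, not in this header):
  //
  //   Fusion fusion;
  //   FusionGuard fg(&fusion);
  //   TensorView* tv0 = makeSymbolicTensor(2);
  //   TensorView* tv1 = makeSymbolicTensor(2);
  //   fusion.addInput(tv0);
  //   fusion.addInput(tv1);
  //   fusion.addOutput(add(tv0, tv1));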

  //! Deregister input as an input of the fusion
  void removeInput(Val* input);

  //! Deregister output as an output of the fusion
  void removeOutput(Val* output);

  //! Replace output with another value
  void replaceOutput(Val* output, Val* replacement);

  //! Assert that all leaves found from outputs are registered as an input
  void validateInputs();

  //! Print this fusion to the console
  void print();

  //! Print arith exprs
  //! \param from_outputs_only Only print exprs reachable from outputs
  void printMath(bool from_outputs_only = true);

  //! Print transformations used in fusion (can be very verbose)
  void printTransforms();

  //! Lower the fusion and print a kernel
  void printKernel(DataType index_type = DataType::Int);
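
  // A typical debugging flow with the printers above (illustrative):
  //
  //   fusion.printMath();       // arith exprs reachable from outputs
  //   fusion.printTransforms(); // per-tensor transformations (verbose)
  //   fusion.printKernel();     // lower and print the generated kernel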

  //! Return a list of topologically sorted expressions. This only includes
  //! exprs required to generate registered outputs.
  std::vector<Expr*> exprs();

  //! Return a vector of fusion inputs that feed this Val
  std::vector<Val*> inputsOf(Val* val);

  //! Return all Vals in math expressions that cannot be eliminated.
  //!
  //! It is generally equivalent to vals that are used to generate
  //! outputs, however, when a multi-output expression exists, and only
  //! some of the outputs are used, the remaining unused outputs are
  //! also included as they must show up in the final code.
  std::vector<Val*> usedMathVals();

  //! Returns all vals that are produced by used math expressions and
  //! also do not have further consumers.
  //!
  //! In the case of an active multi-output expression, the returned vector
  //! will include the expression outputs that did not lead to a fusion
  //! output.
  std::vector<Val*> terminatingMathVals();

  //! Return all Exprs that use val
  std::unordered_set<Expr*> unordered_uses(const Val* val) const;

  //! Return the Expr that produces val
  Expr* definition(const Val* val) const;
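
  // Sketch of walking the graph with these queries (illustrative; assumes
  // Expr::outputs() from ir_base_nodes.h):
  //
  //   for (Expr* e : fusion.exprs()) {
  //     for (Val* out : e->outputs()) {
  //       // Every output of a registered expression is defined by it.
  //       TORCH_INTERNAL_ASSERT(fusion.definition(out) == e);
  //     }
  //   }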

  //! Indicate to kernel to set itself up to generate random numbers
  bool isStochastic();

  //! Run fusion segmentation algorithm to create a segmented fusion
  std::unique_ptr<SegmentedFusion> segment(
      const at::ArrayRef<at::IValue>& inputs);
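
  // Illustrative call (a sketch, not the full runtime flow): segmentation
  // takes the runtime arguments the fusion will receive so it can evaluate
  // per-segment scheduling heuristics.
  //
  //   std::vector<at::IValue> args = {at::randn({8, 8}, at::kCUDA)};
  //   std::unique_ptr<SegmentedFusion> segmented = fusion.segment(args);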

  const auto& inputs() const {
    return inputs_;
  }

  const auto& outputs() const {
    return outputs_;
  }

  std::vector<Val*> getTerminatingOutputs();

  // Alias an output to an input value. This is a WAR to allow in-place
  // updates on an input tensor.
  // Note: this is not always safe and should be used with extra caution.
  // Currently the only place it's used is in the running-stats update for
  // batch normalization.
  // TODO: aliasing should be made known to segmentation, so we'll always
  // include the input tensor in the segment where the output is produced.
  void aliasOutputToInput(Val* output, Val* input);
  Val* getOutputAlias(Val* output);
  std::unordered_set<int> getOutputAliasIndices() const;
  std::vector<std::pair<int, int>> getInputAliasIndices() const;
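
  // Illustrative use, mirroring the batch-norm running-stats case noted
  // above (the names are hypothetical): the updated stats are computed as
  // a regular output, then aliased back to the input so the write happens
  // in place.
  //
  //   fusion.addInput(running_mean);
  //   fusion.addOutput(new_running_mean);
  //   fusion.aliasOutputToInput(new_running_mean, running_mean);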

  // Mark the input at index to be permuted by permutation
  void setPermutationOnInput(int index, std::vector<int64_t> permutation) {
    permuted_input_map_.insert({index, permutation});
  }

  // Mark the output at index to be restored by permutation
  void setPermutationOnOutput(int index, std::vector<int64_t> permutation) {
    permuted_output_map_.insert({index, permutation});
  }

  // Return a map from indices to permutations, indicating all input
  // tensors that need to be permuted
  const PermutationMap& getPermutationInputMap() const {
    return permuted_input_map_;
  }

  // Return a map from indices to permutations, indicating all output
  // tensors that need to be permuted
  const PermutationMap& getPermutationOutputMap() const {
    return permuted_output_map_;
  }
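
  // Illustrative sketch (see Note [ Permutation support in nvfuser ] for
  // the exact semantics): input 0 is permuted before the fusion consumes
  // it, and output 0 is restored by the inverse permutation on the way
  // out. The concrete permutations here are hypothetical.
  //
  //   fusion.setPermutationOnInput(0, {0, 2, 3, 1});
  //   fusion.setPermutationOnOutput(0, {0, 3, 1, 2});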

  bool isTVUseInfoValid() {
    return all_tv_uses_valid_;
  }

  bool isUpdatingTVUseInfo() {
    return is_during_update_uses_;
  }

  const auto& ioAlias() const {
    return io_alias_;
  }

 protected:
  friend SegmentCandidateFinder;
  friend SegmentedFusion;
  friend class TranslateApplicableWelford;
  friend Val;

  static IrCloner copy(const Fusion* from, Fusion* to);

  //! Register the Val with this fusion
  virtual void registerVal(Val* val) override;

  //! Register expr with this fusion.
  //! When we register an expression, we want to update the dependency
  //! tracking of Vals. If this container is not a Kernel, it will remove
  //! previous definitions of outputs and register this Expr as the
  //! definition. Otherwise it will update the definition if not previously
  //! set, but will not remove old definitions.
  virtual void registerExpr(Expr* expr) override;

  //! Clear Exprs from TV uses that are not required to produce outputs
  //! from inputs. The only other place this is used (other than Fusion) is
  //! in Val::uses().
  void resetTvUses();

 private:
  // Determine if the two values are compatible for aliasing:
  // same DataType, ValType, and number of dimensions
  bool isAliasCompatible(Val* left, Val* right);

 private:
  // Fusion inputs and outputs
  std::vector<Val*> inputs_;
  std::vector<Val*> outputs_;

  // io alias pointing from output to input
  std::unordered_map<Val*, Val*> io_alias_;

  // See Note [ Permutation support in nvfuser ]
  // map from indices of input tensors to permutations
  PermutationMap permuted_input_map_;
  // map from indices of output tensors to permutations
  PermutationMap permuted_output_map_;

  // Records whether the current use data in the IR nodes is valid;
  // the states are either all valid or all invalid
  bool all_tv_uses_valid_ = false;
  bool is_during_update_uses_ = false;
};

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch