mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/ Code changes include: - codegen improvements: 1. double support in expression evaluator - bug fixes: 1. dropout fix - rework RNG to support broadcasted dropout (Fixes #82784) 2. expand fix - Patch expand+reduction, expand+view, rework view analysis and guard - scheduler: 1. manual transpose schedule example 2. WIP transpose scheduler Commits that are in this PR from the devel branch: ``` b7435afcd22c917713c2f41a7237bc26e1183f14 Transpose scheduler, step 1 (#1854) 8a45dbf72034684eb8e18b1835b533e90b68f184 Add an example on how to manually schedule transpose (#1889) 83dbf56a9554b2efbd5416461d938fff477b0b27 Patch dropout fix (#1898) 69d3519a532250719b1aa8341b50e067b181b42d Expand+Reduction, Expand+View support, rework View analysis and guards (#1883) 15091c488e96343bdc49e3990acbf238a3b3da51 Rework RNG to correctly support broadcasted dropout (#1888) aafe2d048aaac596e503596a41303423619f3954 Make ExpressionEvaluator support Double (#1885) ``` RUN_TORCHBENCH: nvfuser Differential Revision: [D38657074](https://our.internmc.facebook.com/intern/diff/D38657074) Pull Request resolved: https://github.com/pytorch/pytorch/pull/83239 Approved by: https://github.com/davidberard98
87 lines
2.7 KiB
C++
87 lines
2.7 KiB
C++
#pragma once
|
|
|
|
#include <c10/macros/Export.h>
|
|
#include <torch/csrc/jit/codegen/cuda/transform_view.h>
|
|
#include <torch/csrc/jit/ir/ir.h>
|
|
#include <torch/csrc/jit/passes/pass_manager.h>
|
|
#include <torch/csrc/jit/runtime/profiling_record.h>
|
|
|
|
/*
 * This file contains the public APIs for the CUDA fuser.
 *
 * We use an empty static struct to hold the function pointers, which are
 * registered separately. This is to support CPU-only compilation.
 * Registration is done in torch/csrc/jit/codegen/cuda/register_interface.cpp
 */
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
namespace fuser {
|
|
namespace cuda {
|
|
|
|
// Returns a reference to a mutable global flag. NOTE(review): presumably the
// CudaFusionGuard runtime check consults this to decide whether guard
// validation is enforced — confirm against the implementation in the
// registering .cpp.
TORCH_API std::atomic<bool>& getCudaFusionGuardMode();

// Query/toggle fusion-granularity flags (singleton = single-node fusion
// groups; horizontal = fusing sibling nodes). NOTE(review): the setters appear
// to return the previous value of the flag — confirm in the implementation.
TORCH_API bool getSingletonFusion();
TORCH_API bool setSingletonFusion(bool value);
TORCH_API bool getHorizontalFusion();
TORCH_API bool setHorizontalFusion(bool value);
|
|
|
|
// dummy struct to allow API registration
//
// Every member defaults to nullptr so this header links in CPU-only builds;
// the CUDA build installs the real implementations via
// torch/csrc/jit/codegen/cuda/register_interface.cpp.
struct CudaFuserInterface {
  // Compile a node (presumably a CudaFusionGroup node — confirm in caller).
  void (*fn_compile_n)(Node*) = nullptr;
  // Run a compiled node against the interpreter stack.
  void (*fn_run_n_s)(const Node*, Stack&) = nullptr;
  // Run the CUDA fusion pass over an entire graph, in place.
  void (*fn_fuse_graph)(std::shared_ptr<Graph>&) = nullptr;
  // Predicate: can this node participate in a fusion group?
  bool (*fn_can_fuse_n)(const Node*) = nullptr;
  // Insert the profiling nodes the fuser needs into a profiling record.
  void (*fn_insert_profile_inodes)(ProfilingRecord* pr) = nullptr;
  // Predicate: should this node be profiled for the fuser?
  bool (*fn_profile_n)(const Node*) = nullptr;
  // Skip-list lookup keyed by symbol name; the meaning of `flip` is defined
  // by the registered implementation (see skipNode() below).
  bool (*fn_skip_n)(const std::string&, bool flip) = nullptr;
  // Analyze a view/reshape from original_sizes to new_sizes, producing the
  // constraint used for guarding (see transform_view.h).
  AnalyzeViewConstraint (*fn_analyze_view)(
      const std::vector<int64_t>& original_sizes,
      const std::vector<int64_t>& new_sizes) = nullptr;
};
|
|
|
|
// Get interface, this is used by registration and user facing API internally
TORCH_API CudaFuserInterface* getFuserInterface();

// Public entry points. NOTE(review): these presumably forward to the matching
// CudaFuserInterface function pointers once the CUDA backend has registered
// itself — confirm against the implementing .cpp.
TORCH_API void compileFusionGroup(Node* fusion_node);
TORCH_API void runFusionGroup(const Node* fusion_node, Stack& stack);
TORCH_API void fuseGraph(std::shared_ptr<Graph>&);
TORCH_API bool canFuseNode(const Node* node);
TORCH_API void InsertProfileNodesForCUDAFuser(ProfilingRecord* pr);
TORCH_API bool profileNode(const Node* node);

// Whether a node with the given symbol name should be skipped by the fuser.
// NOTE(review): `flip` semantics are implementation-defined (defaults to
// true) — confirm in the implementing .cpp.
TORCH_API bool skipNode(const std::string& symbol_str, bool flip = true);

// Analyze the view transform original_sizes -> new_sizes; the returned
// AnalyzeViewConstraint (declared in transform_view.h) is used for guarding.
TORCH_API AnalyzeViewConstraint getViewConstraint(
    const std::vector<int64_t>& original_sizes,
    const std::vector<int64_t>& new_sizes);

// Runtime guard helper: does `tensor` satisfy the properties recorded in
// `guard_tensor_type`?
TORCH_API bool complyWith(
    const at::Tensor& tensor,
    const c10::TensorTypePtr& guard_tensor_type);

// Global on/off switch for the NVFuser backend. NOTE(review): setEnabled
// appears to return the previous enabled state, and canBeEnabled presumably
// reports whether enabling is possible in this build/runtime — confirm in
// the implementation.
TORCH_API bool isEnabled();
TORCH_API bool setEnabled(bool is_enabled);
TORCH_API bool canBeEnabled();
|
|
|
|
// Pass-manager hook that installs or removes the NVFuser graph-fusion pass.
struct TORCH_API NVFuserPassManager : public PassManager<NVFuserPassManager> {
  // Turn the fuseGraph pass on (enabled == true) or off (enabled == false).
  // Returns whether the pass was registered before this call, so callers can
  // restore the previous state later.
  static bool registerPass(bool enabled) {
    const bool was_registered = PassManager::isRegistered();
    enabled ? PassManager::registerPass(fuseGraph) : PassManager::clearPass();
    return was_registered;
  }

  // True while the fuseGraph pass is currently installed.
  static bool isRegistered() {
    return PassManager::isRegistered();
  }
};
|
|
|
|
} // namespace cuda
|
|
} // namespace fuser
|
|
} // namespace jit
|
|
} // namespace torch
|