mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/76485 Adds an environment variable `PYTORCH_JIT_ENABLE_NVFUSER` for controlling whether or not nvfuser is enabled. This required changing the PassManager behavior to support the case where nvfuser gets enabled by default when PYTORCH_JIT_ENABLE_NVFUSER=1. Previously the solution for turning nvfuser on or off was to use the PassManager to register or un-register the pass. That works fine if the pass starts off _disabled_, but causes issues once we try to enable the pass by default. The main issue with enabling by default is with the validation check to see whether NVFuser can be turned on. The check relies on at::globalContext().hasCUDA(), which requires CUDAHooks to be registered before hasCUDA() will work correctly. At static initialization time it's difficult to ensure that CUDAHooks will be registered _before_ we attempt to register the nvfuser pass. In OSS it worked fine, but in internal builds it would fail on ROCm builds. To fix this, we switch the control of NVFuser enablement to a check in the pass. i.e. previously, we enabled/disabled nvfuser by registering or de-registering the pass in pass manager; now, the pass is always registered in pass manager, and enablement is done by a check within the nvfuser pass. Remaining TODO: Connect this with NNC so that in cases where NNC is available but not NVFuser (i.e. on AMD GPUs), NNC can be turned on automatically. Test Plan: Imported from OSS Reviewed By: ejguan Differential Revision: D35982618 Pulled By: davidberard98 fbshipit-source-id: fd5b76bc0b8c8716c96fdc04bebfb15026a7ef60 (cherry picked from commit ff14603ff5ac8d9b6c749c4f111f4a8be8023b7f)
78 lines
2.3 KiB
C++
78 lines
2.3 KiB
C++
#pragma once
|
|
|
|
#include <c10/macros/Export.h>
|
|
#include <torch/csrc/jit/ir/ir.h>
|
|
#include <torch/csrc/jit/passes/pass_manager.h>
|
|
#include <torch/csrc/jit/runtime/profiling_record.h>
|
|
|
|
/*
 * This file declares the APIs for the CUDA fuser.
 *
 * The entry points are held as function pointers in a static struct; the
 * pointers are all null until they are registered separately. This is to
 * support CPU-only compilation.
 * Registration is done in torch/csrc/jit/codegen/cuda/register_interface.cpp
 */
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
namespace fuser {
|
|
namespace cuda {
|
|
|
|
// Returns a non-const reference to the std::atomic<bool> that holds the CUDA
// fusion guard mode, so callers can both read and write the flag through it.
// NOTE(review): what `true` / `false` select is not visible in this header --
// confirm in the implementation.
TORCH_API std::atomic<bool>& getCudaFusionGuardMode();
|
|
|
|
// Accessors for the "singleton fusion" and "horizontal fusion" switches.
// Each setter takes the new value and returns a bool.
// NOTE(review): the returned bool is presumably the previous setting --
// confirm in the implementation.
TORCH_API bool getSingletonFusion();
|
|
TORCH_API bool setSingletonFusion(bool value);
|
|
TORCH_API bool getHorizontalFusion();
|
|
TORCH_API bool setHorizontalFusion(bool value);
|
|
|
|
// dummy struct to allow API registration
|
|
struct CudaFuserInterface {
|
|
void (*fn_compile_n)(Node*) = nullptr;
|
|
void (*fn_run_n_s)(const Node*, Stack&) = nullptr;
|
|
void (*fn_fuse_graph)(std::shared_ptr<Graph>&) = nullptr;
|
|
bool (*fn_can_fuse_n)(const Node*) = nullptr;
|
|
void (*fn_insert_profile_inodes)(ProfilingRecord* pr) = nullptr;
|
|
bool (*fn_profile_n)(const Node*) = nullptr;
|
|
bool (*fn_skip_n)(const std::string&, bool flip) = nullptr;
|
|
};
|
|
|
|
// Returns a pointer to the CudaFuserInterface function-pointer table. It is
// used by the registration code and, internally, by the user-facing API.
|
|
TORCH_API CudaFuserInterface* getFuserInterface();
|
|
|
|
// User-facing entry points of the CUDA fuser.
// NOTE(review): presumably each one forwards to the matching slot of
// CudaFuserInterface; the definitions are not in this header -- confirm.

// Takes the fusion node to compile; returns nothing.
TORCH_API void compileFusionGroup(Node* fusion_node);
|
|
// Takes a (const) fusion node and the interpreter stack by mutable reference.
// NOTE(review): presumably inputs are read from and outputs written to
// `stack` -- confirm in the implementation.
TORCH_API void runFusionGroup(const Node* fusion_node, Stack& stack);
|
|
// Graph pass entry point: this is the function NVFuserPassManager hands to
// PassManager::registerPass. The graph is passed by non-const reference.
TORCH_API void fuseGraph(std::shared_ptr<Graph>&);
|
|
// Reports, as a bool, whether `node` can be fused.
TORCH_API bool canFuseNode(const Node* node);
|
|
// Takes the profiling record `pr` on which to operate.
// NOTE(review): presumably inserts profiling nodes for the CUDA fuser into
// the record's graph -- confirm in the implementation.
TORCH_API void InsertProfileNodesForCUDAFuser(ProfilingRecord* pr);
|
|
// Returns a bool for `node`.
// NOTE(review): presumably "should this node be profiled" -- confirm.
TORCH_API bool profileNode(const Node* node);
|
|
|
|
// Takes a symbol name and a `flip` flag that defaults to true; returns a bool.
// NOTE(review): the meaning of `flip` and of the return value is not visible
// in this header -- confirm in the implementation before relying on either.
TORCH_API bool skipNode(const std::string& symbol_str, bool flip = true);
|
|
|
|
// Checks `tensor` against `guard_tensor_type` and reports the outcome as a
// bool.
// NOTE(review): presumably `true` means the tensor satisfies the guard type;
// which tensor properties are compared is not visible here -- confirm.
TORCH_API bool complyWith(
|
|
const at::Tensor& tensor,
|
|
const c10::TensorTypePtr& guard_tensor_type);
|
|
|
|
// Query / change whether the fuser is enabled.
// NOTE(review): setEnabled returns a bool, presumably the previous state --
// confirm in the implementation.
TORCH_API bool isEnabled();
|
|
TORCH_API bool setEnabled(bool is_enabled);
|
|
|
|
// Pass-manager handle for the `fuseGraph` pass; derives from
// PassManager<NVFuserPassManager> and wraps its static registration helpers.
struct TORCH_API NVFuserPassManager : public PassManager<NVFuserPassManager> {
  // When `enabled` is true, hands `fuseGraph` to PassManager::registerPass;
  // otherwise calls PassManager::clearPass.
  // Returns what PassManager::isRegistered() reported before either call.
  static bool registerPass(bool enabled) {
    const bool was_registered = PassManager::isRegistered();
    if (!enabled) {
      PassManager::clearPass();
      return was_registered;
    }
    PassManager::registerPass(fuseGraph);
    return was_registered;
  }

  // Forwards to PassManager::isRegistered().
  static bool isRegistered() {
    const bool registered = PassManager::isRegistered();
    return registered;
  }
};
|
|
|
|
} // namespace cuda
|
|
} // namespace fuser
|
|
} // namespace jit
|
|
} // namespace torch
|