pytorch/torch/csrc/jit/codegen/cuda/interface.cpp
jjsjann123 99e0a87bbb [nvFuser] Latency improvements for pointwise + reduction fusion (#45218)
Summary:
This update contains a lot of changes; some highlights:

- Added Doxygen config file
- Split the fusion IR (a higher-level, TE-like IR) from the kernel IR (a lower-level, CUDA-like IR)
- Improved latency of dynamic shape handling in the fusion logic
- Prevented recompilation of pointwise + reduction fusions when not needed
- Improved inner dimension reduction performance
- Added an input -> kernel + kernel launch parameters cache with an eviction policy (a sketch follows this list)
- Added reduction fusions with multiple outputs (still a single reduction stage)
- Fixed code generation bugs in the symbolic tiled GEMM example
- Added thread predicates to prevent shared memory from being loaded multiple times
- Improved syncthreads placement with shared memory and removed a read-before-write race
- Fixed FP16 reduction fusions where the output would come back as FP32
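
The cache item above can be pictured as a map from an input signature to a compiled kernel and its launch parameters. What follows is a minimal sketch, not the actual nvFuser code: it assumes an LRU eviction policy (the summary does not name the policy) and uses hypothetical stand-in types (CompiledKernel, LaunchParams) keyed by a string signature.

#include <cstddef>
#include <list>
#include <string>
#include <unordered_map>
#include <utility>

struct CompiledKernel {}; // hypothetical stand-in for a generated CUDA kernel
struct LaunchParams {};   // hypothetical stand-in for grid/block launch config

// Maps an input signature to (kernel, launch params), evicting the least
// recently used entry once capacity is reached.
class KernelCacheSketch {
 public:
  explicit KernelCacheSketch(std::size_t capacity) : capacity_(capacity) {}

  // Returns a pointer to the cached entry for this signature, or nullptr.
  std::pair<CompiledKernel, LaunchParams>* lookup(const std::string& sig) {
    auto it = index_.find(sig);
    if (it == index_.end()) {
      return nullptr; // miss: the caller compiles and then calls insert()
    }
    // Move the hit to the front so the least recently used entry stays last.
    entries_.splice(entries_.begin(), entries_, it->second);
    return &it->second->second;
  }

  // Inserts a new entry (assumed not already present), evicting if full.
  void insert(std::string sig, CompiledKernel kernel, LaunchParams params) {
    if (!entries_.empty() && entries_.size() >= capacity_) {
      index_.erase(entries_.back().first);
      entries_.pop_back();
    }
    entries_.emplace_front(
        sig, std::make_pair(std::move(kernel), std::move(params)));
    index_[std::move(sig)] = entries_.begin();
  }

 private:
  using Entry =
      std::pair<std::string, std::pair<CompiledKernel, LaunchParams>>;
  std::size_t capacity_;
  std::list<Entry> entries_; // most recently used at the front
  std::unordered_map<std::string, std::list<Entry>::iterator> index_;
};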

Pull Request resolved: https://github.com/pytorch/pytorch/pull/45218

Reviewed By: ezyang

Differential Revision: D23905183

Pulled By: soumith

fbshipit-source-id: 12f5ad4cbe03e9a25043bccb89e372f8579e2a79
2020-09-24 23:17:20 -07:00

#include <torch/csrc/jit/codegen/cuda/interface.h>

#include <ATen/core/dispatch/OperatorOptions.h>
#include <torch/csrc/jit/runtime/custom_operator.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
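
// Singleton holding the callbacks filled in by the CUDA build; in a CPU-only
// build they remain nullptr, which the TORCH_CHECKs below report.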
CudaFuserInterface* getFuserInterface() {
  static CudaFuserInterface fuser_interface_;
  return &fuser_interface_;
}
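
// A sketch of how the CUDA build would wire itself in (these function names
// are illustrative; the actual registration code is not shown in this file):
//   getFuserInterface()->fn_compile_n_ = &compileCudaFusionGroup;
//   getFuserInterface()->fn_run_n_s_ = &runCudaFusionGroup;
//   getFuserInterface()->fn_fuse_graph = &fuseCudaGraph;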

void compileFusionGroup(Node* fusion_node) {
  TORCH_CHECK(
      getFuserInterface()->fn_compile_n_ != nullptr,
      "Running the CUDA fuser requires a CUDA build.");
  getFuserInterface()->fn_compile_n_(fusion_node);
}

void runFusionGroup(const Node* fusion_node, Stack& stack) {
  TORCH_CHECK(
      getFuserInterface()->fn_run_n_s_ != nullptr,
      "Running the CUDA fuser requires a CUDA build.");
  getFuserInterface()->fn_run_n_s_(fusion_node, stack);
}

void fuseGraph(std::shared_ptr<Graph>& graph) {
  TORCH_CHECK(
      getFuserInterface()->fn_fuse_graph != nullptr,
      "Running the CUDA fuser requires a CUDA build.");
  getFuserInterface()->fn_fuse_graph(graph);
}

} // namespace cuda
} // namespace fuser
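
// Registers an executor for prim::CudaFusionGroup nodes so the interpreter
// dispatches fusion groups through runFusionGroup above.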
namespace {

RegisterOperators reg({
    Operator(
        prim::CudaFusionGroup,
        [](const Node* node) -> Operation {
          return [node](Stack* stack) {
            fuser::cuda::runFusionGroup(node, *stack);
          };
        },
        c10::AliasAnalysisKind::INTERNAL_SPECIAL_CASE),
});

} // namespace

} // namespace jit
} // namespace torch