mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: A lot of changes are in this update, some highlights: - Added Doxygen config file - Split the fusion IR (higher level TE like IR) from kernel IR (lower level CUDA like IR) - Improved latency with dynamic shape handling for the fusion logic - Prevent recompilation for pointwise + reduction fusions when not needed - Improvements to inner dimension reduction performance - Added input -> kernel + kernel launch parameters cache, added eviction policy - Added reduction fusions with multiple outputs (still single reduction stage) - Fixed code generation bugs for symbolic tiled GEMM example - Added thread predicates to prevent shared memory from being loaded multiple times - Improved sync threads placement with shared memory and removed a read-before-write race - Fixes to FP16 reduction fusions where output would come back as FP32 Pull Request resolved: https://github.com/pytorch/pytorch/pull/45218 Reviewed By: ezyang Differential Revision: D23905183 Pulled By: soumith fbshipit-source-id: 12f5ad4cbe03e9a25043bccb89e372f8579e2a79
55 lines
1.4 KiB
C++
55 lines
1.4 KiB
C++
|
|
#include <torch/csrc/jit/codegen/cuda/interface.h>
|
|
#include <ATen/core/dispatch/OperatorOptions.h>
|
|
#include <torch/csrc/jit/runtime/custom_operator.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
namespace fuser {
|
|
namespace cuda {
|
|
|
|
// Accessor for the shared CudaFuserInterface object.
//
// The object is a function-local static, so it is constructed the first time
// this function is called and every call returns the same address.
CudaFuserInterface* getFuserInterface() {
  static CudaFuserInterface instance;
  return &instance;
}
|
|
|
|
// Hands `fusion_node` to the compile hook stored in the fuser interface
// (`fn_compile_n_`).
//
// TORCH_CHECK fails with the message below when the hook is null.
void compileFusionGroup(Node* fusion_node) {
  // Look the interface up once; getFuserInterface() always yields the same
  // object.
  auto* const fuser = getFuserInterface();
  TORCH_CHECK(
      fuser->fn_compile_n_ != nullptr,
      "Running the CUDA fuser requires a CUDA build.");
  fuser->fn_compile_n_(fusion_node);
}
|
|
|
|
// Hands `fusion_node` and `stack` to the run hook stored in the fuser
// interface (`fn_run_n_s_`).
//
// TORCH_CHECK fails with the message below when the hook is null.
void runFusionGroup(const Node* fusion_node, Stack& stack) {
  // Look the interface up once; getFuserInterface() always yields the same
  // object.
  auto* const fuser = getFuserInterface();
  TORCH_CHECK(
      fuser->fn_run_n_s_ != nullptr,
      "Running the CUDA fuser requires a CUDA build.");
  fuser->fn_run_n_s_(fusion_node, stack);
}
|
|
|
|
// Hands `graph` to the graph-fusion hook stored in the fuser interface
// (`fn_fuse_graph`).
//
// TORCH_CHECK fails with the message below when the hook is null.
void fuseGraph(std::shared_ptr<Graph>& graph) {
  // Look the interface up once; getFuserInterface() always yields the same
  // object.
  auto* const fuser = getFuserInterface();
  TORCH_CHECK(
      fuser->fn_fuse_graph != nullptr,
      "Running the CUDA fuser requires a CUDA build.");
  fuser->fn_fuse_graph(graph);
}
|
|
|
|
} // namespace cuda
|
|
} // namespace fuser
|
|
|
|
namespace {
|
|
RegisterOperators reg({
|
|
Operator(
|
|
prim::CudaFusionGroup,
|
|
[](const Node* node) -> Operation {
|
|
return [node](Stack* stack) {
|
|
fuser::cuda::runFusionGroup(node, *stack);
|
|
};
|
|
},
|
|
c10::AliasAnalysisKind::INTERNAL_SPECIAL_CASE),
|
|
});
|
|
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|