pytorch/torch/csrc/jit/codegen/cuda/interface.cpp
Soumith Chintala d9dd353a00 fix clang-format (#35884)
Summary:
breakage introduced in PR that I landed
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35884

Differential Revision: D20817603

Pulled By: soumith

fbshipit-source-id: b0729bed81549d4c8e6a889c380baa19c73ef127
2020-04-02 12:12:27 -07:00

53 lines
1.4 KiB
C++

#include <torch/csrc/jit/codegen/cuda/interface.h>
#include <ATen/core/dispatch/OperatorOptions.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
// Accessor for the single CudaFuserInterface object used by this TU's
// dispatch helpers. The object is a function-local static: it is constructed
// on the first call (C++11 makes that initialization thread-safe), has static
// storage duration, and every call returns the address of the same instance.
CudaFuserInterface* getFuserInterface() {
  static CudaFuserInterface instance;
  auto* const shared_instance = &instance;
  return shared_instance;
}
// Compiles `fusion_node` by forwarding it to the `fn_compile_n_` hook stored on
// the shared fuser interface. If the hook is unset (null), TORCH_CHECK fails
// with the "requires a CUDA build" message. Presumably the hook is installed
// only by the CUDA-enabled fuser code; confirm where it is assigned.
void compileFusionGroup(Node* fusion_node) {
  // Bind by reference so the member is still read at each use, as before.
  const auto& compile_hook = getFuserInterface()->fn_compile_n_;
  TORCH_CHECK(
      compile_hook != nullptr,
      "Running the CUDA fuser requires a CUDA build.");
  compile_hook(fusion_node);
}
// Executes the fusion group `fusion_node` against the interpreter `stack` by
// forwarding both to the `fn_run_n_s_` hook on the shared fuser interface.
// If the hook is unset (null), TORCH_CHECK fails with the "requires a CUDA
// build" message. `stack` is passed by non-const reference, so the hook may
// read and write it.
void runFusionGroup(const Node* fusion_node, Stack& stack) {
  // Bind by reference so the member is still read at each use, as before.
  const auto& run_hook = getFuserInterface()->fn_run_n_s_;
  TORCH_CHECK(
      run_hook != nullptr,
      "Running the CUDA fuser requires a CUDA build.");
  run_hook(fusion_node, stack);
}
// Applies the registered graph-fusion pass to `graph` by forwarding it to the
// `fn_fuse_graph` hook on the shared fuser interface. If the hook is unset
// (null), TORCH_CHECK fails with the "requires a CUDA build" message.
// `graph` is a non-const reference to the shared_ptr, so the hook could in
// principle mutate or reseat it. Confirm the intended contract in the hook
// implementation.
void fuseGraph(std::shared_ptr<Graph>& graph) {
  // Bind by reference so the member is still read at each use, as before.
  const auto& fuse_hook = getFuserInterface()->fn_fuse_graph;
  TORCH_CHECK(
      fuse_hook != nullptr,
      "Running the CUDA fuser requires a CUDA build.");
  fuse_hook(graph);
}
} // namespace cuda
} // namespace fuser
// Static registrar that binds prim::CudaFusionGroup to an implementation.
// The outer lambda is called with the graph node and returns an Operation.
// That Operation forwards the captured node and the interpreter stack to
// fuser::cuda::runFusionGroup. The node is captured as a raw pointer, so it
// must outlive the returned Operation.
//
// The registrar is placed in an anonymous namespace so it has internal
// linkage. A namespace-scope `torch::jit::reg` with external linkage would
// export a very generic symbol that can clash at link time with another
// translation unit's registrar of the same name. The object is still
// dynamically initialized, so the registration side effect is unchanged.
namespace {
RegisterOperators reg({
    Operator(
        prim::CudaFusionGroup,
        [](const Node* node) -> Operation {
          return [node](Stack& stack) {
            fuser::cuda::runFusionGroup(node, stack);
            // NOTE(review): the int-returning Operation signature appears to
            // use 0 as the normal return value. Confirm against the
            // Operation typedef.
            return 0;
          };
        },
        // Selects the INTERNAL_SPECIAL_CASE alias-analysis kind for this op.
        // Presumably alias analysis special-cases this prim internally rather
        // than deriving it from a schema. Confirm in AliasAnalysisKind docs.
        c10::AliasAnalysisKind::INTERNAL_SPECIAL_CASE),
});
} // namespace
} // namespace jit
} // namespace torch