pytorch/torch/csrc/jit/codegen/cuda/interface.cpp
Sebastian Messmer 53af9df557 Unify boxed function signature between jit and c10 (#37034)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37034

c10 boxed functions take a `Stack*`, while JIT boxed functions took a `Stack&`.
c10 boxed functions return nothing, while JIT boxed functions returned an `int` that was always zero.

This changes JIT to follow the c10 convention.
ghstack-source-id: 106834069

Test Plan: unit tests

Differential Revision: D20567950

fbshipit-source-id: 1a7aea291023afc52ae706957e9a5ca576fbb53b
2020-06-29 19:24:26 -07:00

54 lines
1.4 KiB
C++

#include <torch/csrc/jit/codegen/cuda/interface.h>
#include <ATen/core/dispatch/OperatorOptions.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
// Returns a pointer to the single CudaFuserInterface instance shared by this
// translation unit's entry points. The object is a function-local static, so
// it is constructed on the first call and the same address is returned on
// every subsequent call. Its function-pointer hooks are presumably filled in
// by the CUDA-enabled build elsewhere; that is not shown here, so confirm.
CudaFuserInterface* getFuserInterface() {
  static CudaFuserInterface instance;
  return &instance;
}
// Forwards `fusion_node` to the registered `fn_compile_n_` hook.
// Throws via TORCH_CHECK when no hook has been installed. The error message
// attributes that case to running without a CUDA build.
void compileFusionGroup(Node* fusion_node) {
  auto* const iface = getFuserInterface();
  TORCH_CHECK(
      iface->fn_compile_n_ != nullptr,
      "Running the CUDA fuser requires a CUDA build.");
  iface->fn_compile_n_(fusion_node);
}
// Executes `fusion_node` by forwarding it, together with the caller's value
// `stack`, to the registered `fn_run_n_s_` hook.
// Throws via TORCH_CHECK when no hook has been installed. The error message
// attributes that case to running without a CUDA build.
void runFusionGroup(const Node* fusion_node, Stack& stack) {
  auto* const iface = getFuserInterface();
  TORCH_CHECK(
      iface->fn_run_n_s_ != nullptr,
      "Running the CUDA fuser requires a CUDA build.");
  iface->fn_run_n_s_(fusion_node, stack);
}
// Hands `graph` (by non-const reference, so the hook may modify or reseat it)
// to the registered `fn_fuse_graph` hook.
// Throws via TORCH_CHECK when no hook has been installed. The error message
// attributes that case to running without a CUDA build.
void fuseGraph(std::shared_ptr<Graph>& graph) {
  auto* const iface = getFuserInterface();
  TORCH_CHECK(
      iface->fn_fuse_graph != nullptr,
      "Running the CUDA fuser requires a CUDA build.");
  iface->fn_fuse_graph(graph);
}
} // namespace cuda
} // namespace fuser
namespace {
// Registers an operator for prim::CudaFusionGroup at static-initialization
// time. For each graph node, the factory returns an Operation that captures
// the node pointer by value. When invoked with a boxed `Stack*` (the c10
// boxed-function convention), it dispatches to fuser::cuda::runFusionGroup.
// The alias analysis kind is INTERNAL_SPECIAL_CASE.
RegisterOperators reg({
    Operator(
        prim::CudaFusionGroup,
        [](const Node* fusion_node) -> Operation {
          return [fusion_node](Stack* boxed_stack) {
            fuser::cuda::runFusionGroup(fusion_node, *boxed_stack);
          };
        },
        c10::AliasAnalysisKind::INTERNAL_SPECIAL_CASE),
});
} // namespace
} // namespace jit
} // namespace torch