#pragma once #include #include #include #include #include #include #include namespace torch::jit { constexpr int kCPUDevice = -1; // Assigns a "key" to the given fusion_group that it can use to run its // fusion later (via runFusion() below). TORCH_API int64_t registerFusion(const Node* fusion_group); // Runs the fusion corresponding to the given key on the inputs // found on the stack. Outputs are placed on the same stack. // In some cases a fusion cannot be run and a fallback path where // PyTorch's interpreter runs the graph instead is attempted. TORCH_API void runFusion(const int64_t key, Stack& stack); // True if the respective devices can fuse, false otherwise TORCH_API bool canFuseOnCPU(); TORCH_API bool canFuseOnGPU(); // Sets whether fusion on the CPU is allowed (disabled by default due to // flakiness) TORCH_API void overrideCanFuseOnCPU(bool value); // Sets whether fusion on CPU must use LLVM Codegen and not SimplieIREval TORCH_API void overrideMustUseLLVMOnCPU(bool value); // Sets whether fusion on the GPU is allowed (enabled by default) TORCH_API void overrideCanFuseOnGPU(bool value); // Treats the given graph as a fusion group and launches it on the // specified device with the given inputs. // Returns the outputs. TORCH_API std::vector debugLaunchGraph( Graph& graph, at::ArrayRef inputs); // Treats the given graph as a fusion group and returns the generated code. TORCH_API std::string debugGetFusedKernelCode( Graph& graph, at::ArrayRef inputs); TORCH_API size_t nCompiledKernels(); } // namespace torch::jit