mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Follows #133067 Pull Request resolved: https://github.com/pytorch/pytorch/pull/133399 Approved by: https://github.com/Skylion007
55 lines
1.7 KiB
C++
55 lines
1.7 KiB
C++
#pragma once
|
|
|
|
#include <ATen/ATen.h>
#include <ATen/core/stack.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/ir/ir.h>

#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>
|
|
|
|
namespace torch::jit {
|
|
|
|
constexpr int kCPUDevice = -1;
|
|
|
|
// Assigns a "key" to the given fusion_group that it can use to run its
|
|
// fusion later (via runFusion() below).
|
|
TORCH_API int64_t registerFusion(const Node* fusion_group);
|
|
|
|
// Runs the fusion corresponding to the given key on the inputs
|
|
// found on the stack. Outputs are placed on the same stack.
|
|
// In some cases a fusion cannot be run and a fallback path where
|
|
// PyTorch's interpreter runs the graph instead is attempted.
|
|
TORCH_API void runFusion(const int64_t key, Stack& stack);
|
|
|
|
// True if the respective devices can fuse, false otherwise
|
|
TORCH_API bool canFuseOnCPU();
|
|
TORCH_API bool canFuseOnGPU();
|
|
|
|
// Sets whether fusion on the CPU is allowed (disabled by default due to
|
|
// flakiness)
|
|
TORCH_API void overrideCanFuseOnCPU(bool value);
|
|
|
|
// Sets whether fusion on CPU must use LLVM Codegen and not SimplieIREval
|
|
TORCH_API void overrideMustUseLLVMOnCPU(bool value);
|
|
|
|
// Sets whether fusion on the GPU is allowed (enabled by default)
|
|
TORCH_API void overrideCanFuseOnGPU(bool value);
|
|
|
|
// Treats the given graph as a fusion group and launches it on the
|
|
// specified device with the given inputs.
|
|
// Returns the outputs.
|
|
TORCH_API std::vector<at::Tensor> debugLaunchGraph(
|
|
Graph& graph,
|
|
at::ArrayRef<at::Tensor> inputs);
|
|
|
|
// Treats the given graph as a fusion group and returns the generated code.
|
|
TORCH_API std::string debugGetFusedKernelCode(
|
|
Graph& graph,
|
|
at::ArrayRef<at::Tensor> inputs);
|
|
|
|
TORCH_API size_t nCompiledKernels();
|
|
|
|
} // namespace torch::jit
|