Summary: Follow up to https://github.com/pytorch/pytorch/issues/68095
This also changes the files in the ATen folder to include c10's `Export.h` instead, since they can never export `TORCH_PYTHON_API`.
cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69585
Reviewed By: mrshenli
Differential Revision: D32958594
Pulled By: albanD
fbshipit-source-id: 1ec7ef63764573fa2b486928955e3a1172150061
57 lines
1.7 KiB
C++
#pragma once

#include <ATen/ATen.h>
#include <ATen/core/stack.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/ir/ir.h>

#include <cstdint>
#include <memory>
#include <vector>

namespace torch {
namespace jit {

constexpr int kCPUDevice = -1;

// Assigns a "key" to the given fusion_group that it can use to run its
// fusion later (via runFusion() below).
TORCH_API int64_t registerFusion(const Node* fusion_group);

// Runs the fusion corresponding to the given key on the inputs
// found on the stack. Outputs are placed on the same stack.
// In some cases a fusion cannot be run and a fallback path where
// PyTorch's interpreter runs the graph instead is attempted.
TORCH_API void runFusion(const int64_t key, Stack& stack);
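//
// Illustrative sketch (not part of this header; names are assumptions): a
// caller typically obtains a fusion-group Node* from the graph fuser pass,
// registers it once, and then runs it with its inputs pushed onto a Stack.
// The outputs are left on the same stack.
//
//   const Node* fusion_group = /* node produced by the fuser pass */;
//   int64_t key = registerFusion(fusion_group);
//   Stack stack;
//   stack.emplace_back(at::randn({4, 4}));
//   runFusion(key, stack);
//   at::Tensor out = stack.back().toTensor();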

// True if the respective devices can fuse, false otherwise
TORCH_API bool canFuseOnCPU();
TORCH_API bool canFuseOnGPU();

// Sets whether fusion on the CPU is allowed (disabled by default due to
// flakiness)
TORCH_API void overrideCanFuseOnCPU(bool value);

// Sets whether fusion on CPU must use LLVM codegen and not the SimpleIREvaluator
TORCH_API void overrideMustUseLLVMOnCPU(bool value);

// Sets whether fusion on the GPU is allowed (enabled by default)
TORCH_API void overrideCanFuseOnGPU(bool value);
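//
// Illustrative sketch (assumptions, not part of this header): code that wants
// the fused CPU path, e.g. a test, usually saves the current setting, enables
// CPU fusion, and restores the old value when done.
//
//   bool old_cpu_fuse = canFuseOnCPU();
//   overrideCanFuseOnCPU(true);
//   // ... run the workload that should hit the CPU fuser ...
//   overrideCanFuseOnCPU(old_cpu_fuse);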

// Treats the given graph as a fusion group and launches it on the
// specified device with the given inputs.
// Returns the outputs.
TORCH_API std::vector<at::Tensor> debugLaunchGraph(
    Graph& graph,
    at::ArrayRef<at::Tensor> inputs);

// Treats the given graph as a fusion group and returns the generated code.
TORCH_API std::string debugGetFusedKernelCode(
    Graph& graph,
    at::ArrayRef<at::Tensor> inputs);
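//
// Illustrative sketch (hypothetical `graph`/`inputs`, not part of this
// header): given a Graph shaped like a single fusion group and matching
// sample inputs, the debug entry points can dump the generated kernel source
// or run the graph directly and return its outputs.
//
//   std::vector<at::Tensor> inputs = {at::randn({8}), at::randn({8})};
//   std::string code = debugGetFusedKernelCode(graph, inputs);
//   std::vector<at::Tensor> outputs = debugLaunchGraph(graph, inputs);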

TORCH_API size_t nCompiledKernels();

} // namespace jit
} // namespace torch