mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Follow up to https://github.com/pytorch/pytorch/issues/68095 This also changes the files from the ATen folder to include c10's `Export.h` instead since they can't ever be exporting `TORCH_PYTHON_API`. cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang Pull Request resolved: https://github.com/pytorch/pytorch/pull/69585 Reviewed By: mrshenli Differential Revision: D32958594 Pulled By: albanD fbshipit-source-id: 1ec7ef63764573fa2b486928955e3a1172150061
33 lines
859 B
C++
33 lines
859 B
C++
#pragma once
|
|
|
|
#include <torch/csrc/Export.h>
|
|
#include <torch/csrc/jit/ir/ir.h>
|
|
|
|
/*
 * API for querying node compatibility with CudaCodeGen.
 *
 * It is used in the optimization passes, where the graph is traversed and the
 * parts that can be handled by CudaCodeGen are partitioned and placed in the
 * `attr::Subgraph` of a `prim::CudaFusionGroup`.
 *
 * The logic is currently very simple: on top of device placement, we consider
 * a `Node` compatible when our parser has a parsing rule for it.
 */
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
namespace fuser {
|
|
namespace cuda {
|
|
|
|
TORCH_CUDA_CU_API bool isFusibleCudaFusionGroup(const Node* node);
|
|
|
|
// consider if `node` could be fused into `fusion`
|
|
TORCH_CUDA_CU_API bool isFusibleCudaFusionGroup(
|
|
const Node* fusion,
|
|
const Node* node);
|
|
|
|
} // namespace cuda
|
|
} // namespace fuser
|
|
} // namespace jit
|
|
} // namespace torch
|