mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Sub-step of my attempt to split up the torch_cuda library, as it is huge. Please look at https://github.com/pytorch/pytorch/issues/49050 for details on the split and which files are in which target. This PR introduces two new macros for Windows DLL purposes, TORCH_CUDA_CPP_API and TORCH_CUDA_CU_API. Both are defined as TORCH_CUDA_API for the time being. Pull Request resolved: https://github.com/pytorch/pytorch/pull/50627 Reviewed By: mruberry Differential Revision: D25955441 Pulled By: janeyx99 fbshipit-source-id: ff226026833b8fb2fb7c77df6f2d6c824f006869
41 lines
1.3 KiB
C++
41 lines
1.3 KiB
C++
#pragma once
|
|
|
|
#include <torch/csrc/WindowsTorchApiMacro.h>
|
|
#include <torch/csrc/jit/ir/ir.h>
|
|
|
|
/*
|
|
* This file handles compilation and execution of a CudaFusionGroup;
|
|
*
|
|
* A CudaFusionGroup node comes with `attr::Subgraph` containing the computation
|
|
* graph. We compile the graph to generate CUDA functions and cache them in a
|
|
* registry. We cache & reuse kernels across nodes sharing identical graphs.
|
|
*
|
|
* After compilation, we assign the key of the cached kernel as an integer attribute
|
|
* on the node `attr::cache_id`.
|
|
*/
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
namespace fuser {
|
|
namespace cuda {
|
|
|
|
// Prepare `fusion_node` for execution.
|
|
// Finds an existing `CudaKernel`, or compiles a new one, for the graph stored in the node's `attr::Subgraph`.
|
|
// Side effect: assigns `attr::cache_id` on `fusion_node`, which is why the node is taken as a non-const pointer.
|
|
TORCH_CUDA_CU_API void compileCudaFusionGroup(Node* fusion_node);
|
|
|
|
// Execute `fusion_node`.
|
|
// Current protocol: the function allocates the output tensors and appends them
|
|
// to `stack` after execution.
|
|
// TODO: support shape inference. Right now we only handle static shapes.
// NOTE(review): presumably the node must already have been prepared by
// compileCudaFusionGroup (so `attr::cache_id` is set) — confirm against the implementation.
|
|
TORCH_CUDA_CU_API void runCudaFusionGroup(
|
|
const Node* fusion_node,
|
|
Stack& stack);
|
|
|
|
// Graph-level entry point; `graph` is taken by non-const reference to its shared_ptr.
// NOTE(review): presumably this is the pass that groups fusible nodes of `graph`
// into CudaFusionGroup nodes — confirm against the implementation.
TORCH_CUDA_CU_API void CudaFuseGraph(std::shared_ptr<Graph>& graph);
|
|
|
|
} // namespace cuda
|
|
} // namespace fuser
|
|
} // namespace jit
|
|
} // namespace torch
|