pytorch/torch/csrc/jit/codegen/cuda/manager.cpp
Christian Sarofeen 6d24f8fe21 Infrastructure for a new CUDA Fuser (#34785)
Summary:
**Summary:** This PR contains the infrastructure of a new CUDA fuser. This CUDA fuser is based on many of the same principles as TensorExpressions and Halide, but the implementation is written from the ground up. The fusion pass itself is similar to that of the default CUDA fuser; however, it has undergone some refactoring and uses the new code generation infrastructure. For those interested in how the code generation in this PR works, I would recommend reviewing _test/cpp/jit/test_gpu_fusion.cpp_ as well as the long comment section at the beginning of _torch/csrc/jit/codegen/cuda/transform_replay.h_.

One of the largest differences between our approach and that of TVM/Halide is the concept of a "TensorView". At a high level, a TensorView should be thought of much like a Tensor in PyTorch: it is an N-D object that can undergo transformations which change its dimensionality. These dimensionality changes are performed through the operations split/merge/reorder/computeAt. They are similar to TVM's split/fuse/reorder/compute_at and modify how a tensor is iterated over when GPU code is generated. Interestingly, in our scheme these transformations are applied per tensor and only affect how that particular tensor is generated.
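To make that scheduling vocabulary concrete, here is a minimal, hypothetical sketch of how a TensorView might be transformed. The names and exact signatures below (`FusionGuard`, `split`, `reorder`, `computeAt`, and the placeholder tensor construction) are assumptions for illustration only, not code from this PR; _test/cpp/jit/test_gpu_fusion.cpp_ is the authoritative reference for the real API.

```cpp
// Hypothetical sketch only: illustrates how per-tensor transformations
// reshape the iteration space used when generating GPU code.
Fusion fusion;
FusionGuard fg(&fusion); // assumed RAII helper that makes `fusion` active

// tv0: a 2-D input registered with the fusion; tv1: a pointwise consumer.
TensorView* tv0 = /* 2-D input tensor */;
TensorView* tv1 = /* elementwise op consuming tv0 */;

// Split tv1's outer axis into [ceil(I0 / 128), 128] so the inner extent can
// map onto a 128-thread CUDA block. Only tv1's iteration order changes.
tv1->split(0, 128);

// Swap two of tv1's axes; other tensors in the fusion are untouched.
tv1->reorder({{1, 2}, {2, 1}});

// Compute tv0 inside tv1's loop nest at position 1, so tv0's values are
// produced where they are consumed instead of being materialized globally.
tv0->computeAt(tv1, 1);
```

The key point the sketch is meant to show: each transformation attaches to a single TensorView and changes only how that tensor is iterated, rather than rewriting a global schedule.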

**Warning:** This PR is purposefully not feature complete with the current fuser. We wanted to separate the infrastructure from the fusion capabilities. Once this lands, smaller incremental PRs will be submitted to expand the capabilities of the fuser.

**Short term goals:**

Parity with current CUDA fuser (including performance):
- Dynamic shapes (no recompilation)
- Implicit handling of broadcast (broadcasted tensors are treated as tensors of the broadcasted size in the generated code)
- Dropout

**Mid-term goals:**

- Transposes fused with pointwise operations where transpose involves only 2 axes (across the fused operation).
- 1-D reductions fused with pointwise operations
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34785

Reviewed By: ZolotukhinM

Differential Revision: D20650977

Pulled By: soumith

fbshipit-source-id: ee39c95a880e1b9822e874ed4cc180971572bf63
2020-04-02 09:22:42 -07:00


#include <torch/csrc/jit/codegen/cuda/manager.h>
#include <torch/csrc/jit/codegen/cuda/kernel.h>
#include <torch/csrc/jit/codegen/cuda/parser.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/tensor.h>
#include <torch/csrc/jit/codegen/cuda/utils.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <unordered_map>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
namespace {
// CudaFusionManager holds compiled `CudaKernel` and handles all interfacing
// including compilation and execution.
//
// We cache two maps here:
// a. string of graph -> kernel_id
// b. kernel_id -> CudaKernel
//
// This allows CudaKernel reuse across nodes;
class CudaFusionManager {
 public:
  static CudaFusionManager& getManager() {
    static CudaFusionManager cuda_fusion_manager_;
    return cuda_fusion_manager_;
  };

  // TODO: I'm assuming we have stride information in `graph->toString`
  //       We need to make sure stride information is in the final string, as
  //       we want to AVOID kernel reuse between different fusion_node, unless
  //       they have identical contiguity information! (So identical stride +
  //       shape is even more restricting in a good way)
  int32_t registerOrGetCacheId(std::shared_ptr<Graph>& graph) {
    std::lock_guard<std::mutex> guard(mutex_);

    // prepare graph for lowering;
    Canonicalize(graph, false);
    // EraseShapeInformation(graph);
    auto repr = graph->toString(false);

    // create new graph_cache_ entry;
    if (graph_cache_.count(repr) == 0) {
      int32_t kernel_id = getNextUniqueID();
      graph_cache_[repr] = kernel_id;

      Fusion fusion;
      // lower torch::jit::Graph to torch::jit::fuser::cuda::fusion
      parseJitIR(graph, fusion);

      // default constructor via accessing empty key;
      compileKernel(fusion, kernel_cache_[kernel_id]);

      return kernel_id;
    } else {
      return graph_cache_[repr];
    }
  };

  void runFusionNode(
      int32_t kernel_id,
      const at::ArrayRef<IValue> inputs,
      std::vector<at::Tensor> outputs) {
    TORCH_CHECK(
        kernel_cache_.count(kernel_id) != 0, "kernel id not recognized");
    CudaKernel& cuda_kernel_entry = kernel_cache_[kernel_id];
    runKernel(cuda_kernel_entry, inputs, outputs);
  }

 private:
  std::mutex mutex_;

  void runCudaKernel(
      int32_t key,
      const std::vector<int>& contiguity_tag,
      const c10::Device) {};

  int32_t getNextUniqueID() {
    return next_unique_id_++;
  };

  std::unordered_map<std::string, int32_t> graph_cache_;
  std::unordered_map<int64_t, CudaKernel> kernel_cache_;

  int32_t next_unique_id_ = 0;
};

} // namespace
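
// Compiles the subgraph held by a prim::CudaFusionGroup node (registering it
// with CudaFusionManager if it has not been seen before) and records the
// resulting kernel id on the node as its attr::cache_id attribute.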
void compileCudaFusionGroup(Node* fusion_node) {
  TORCH_CHECK(
      fusion_node->kind() == prim::CudaFusionGroup,
      "Only prim::CudaFusionGroup can be compiled");
  if (fusion_node->hasAttribute(attr::cache_id)) {
    TORCH_WARN("Double registration of CudaFusionGroup on CudaFusionManager");
  }
  int32_t fusion_cache_id =
      CudaFusionManager::getManager().registerOrGetCacheId(
          fusion_node->g(attr::Subgraph));
  fusion_node->i_(attr::cache_id, fusion_cache_id);
}
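
// Runs the kernel previously compiled for a prim::CudaFusionGroup node:
// allocates output tensors from the subgraph's (static, complete) TensorType
// information, launches the kernel, and replaces the consumed inputs on the
// stack with the produced outputs.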
void runCudaFusionGroup(const Node* const fusion_node, Stack& stack) {
  TORCH_CHECK(
      fusion_node->kind() == prim::CudaFusionGroup,
      "prim::CudaFusionGroup expected");
  // TODO: should we support runtime compilation with updated dynamic shape;
  //       shape inference would be needed so we can allocate output;
  TORCH_CHECK(
      fusion_node->hasAttribute(attr::cache_id),
      "node prim::CudaFusionGroup has not been compiled yet");
  int32_t kernel_id = fusion_node->i(attr::cache_id);

  // Currently we just construct I/O tensors for static graph;
  const std::shared_ptr<Graph> graph = fusion_node->g(attr::Subgraph);

  const auto nInputs = graph->inputs().size();
  at::ArrayRef<IValue> inputs = last(stack, nInputs);

  // we need to construct outputs;
  std::vector<at::Tensor> outputs;
  for (const auto* const output : graph->outputs()) {
    auto type = output->type()->expect<TensorType>();
    // Expect output to be tensor;
    TORCH_CHECK(
        type && type->isComplete(),
        "Complete TensorType for output is expected.");

    const auto device = *(type->device());
    const auto scalar_type = *(type->scalarType());

    auto options = at::TensorOptions()
                       .dtype(scalar_type)
                       .layout(at::kStrided)
                       .device(device)
                       .requires_grad(type->requires_grad());

    // TODO: We should infer output shape from `inputs`
    const auto sizes = extractSizes(type);
    const auto strides = extractStrides(type);
    auto tensor = at::empty_strided(sizes, strides, options);
    outputs.push_back(tensor);
  }

  CudaFusionManager::getManager().runFusionNode(kernel_id, inputs, outputs);

  drop(stack, inputs.size());
  stack.insert(
      stack.end(),
      std::make_move_iterator(outputs.begin()),
      std::make_move_iterator(outputs.end()));
}
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch