pytorch/torch/csrc/jit/fuser/compiler.cpp
Adam Paszke 31b3d81714 Broadcast prim::FusedConcat inputs independently when checking kernels (#14503)
Summary:
Fixes #14483.

cc zou3519 mruberry
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14503

Differential Revision: D13256343

Pulled By: zou3519

fbshipit-source-id: 1c68a23f425be067a742bada7ee8cdfab7fc3fa2
2018-11-29 13:05:00 -08:00

228 lines
6.9 KiB
C++

#include "torch/csrc/jit/fuser/compiler.h"
#include "ATen/ATen.h"
#include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/type.h"
#include "torch/csrc/jit/code_template.h"
#include "torch/csrc/jit/assertions.h"
#include "torch/csrc/jit/passes/shape_analysis.h"
#include "torch/csrc/jit/fuser/interface.h"
#include "torch/csrc/jit/fuser/kernel_cache.h"
#include "torch/csrc/jit/fuser/codegen.h"
#include "torch/csrc/jit/fuser/tensor_desc.h"
#if USE_CUDA_FUSER
#include "torch/csrc/jit/fuser/cuda/fused_kernel.h"
#endif // USE_CUDA_FUSER
#if USE_CPU_FUSER
#include "torch/csrc/jit/fuser/cpu/fused_kernel.h"
#endif // USE_CPU_FUSER
#include <iostream>
#include <memory>
#include <unordered_set>
#include <utility>
#include <string>
#include <atomic>
#include <sstream>
#include <stdexcept>
#include <tuple>
namespace torch { namespace jit { namespace fuser {
// Running count of kernel ids handed out so far. It serves both as a
// debugging statistic and as the source of unique kernel name suffixes.
static std::atomic<size_t> next_kernel_id{0};

// Cached fuser debug level. A negative value means the environment
// variable has not been consulted (or produced a negative number).
static int debug_fusion{-1};

// Reports the current value of the kernel id counter.
size_t nCompiledKernels() {
  return next_kernel_id.load();
}

// Returns the fuser debug level.
// On the first call (while the cached value is still negative) the level is
// read from the PYTORCH_FUSION_DEBUG environment variable; an unset variable
// yields 0. Later calls return the cached value.
int debugFuser() {
  if (debug_fusion >= 0) {
    return debug_fusion;
  }
  const char* const env_value = getenv("PYTORCH_FUSION_DEBUG");
  if (env_value != nullptr) {
    debug_fusion = atoi(env_value);
  } else {
    debug_fusion = 0;
  }
  return debug_fusion;
}
// Returns the prim::ConstantChunk node that is the sole user of `input`.
// Returns nullptr when `input` does not have exactly one use, or when its
// single user is a node of any other kind.
static const Node* usedInFusedChunk(const Value* input) {
  // Bind by const reference instead of copying the use list. This is valid
  // whether uses() hands back a reference or a temporary (a temporary bound
  // to a const reference lives until the end of this scope).
  const auto& uses = input->uses();
  if (uses.size() != 1) {
    return nullptr;
  }
  const Node* user = uses[0].user;
  return user->kind() == prim::ConstantChunk ? user : nullptr;
}
// Appends one chunk descriptor per graph input to spec.inputChunks().
// An input whose only user is a prim::ConstantChunk node gets that node's
// (chunks, dim) attributes; every other input gets the descriptor (1, 0).
static void setInputChunkDescriptors(KernelSpec& spec) {
  auto& chunk_descriptors = spec.inputChunks();
  chunk_descriptors.reserve(spec.graph()->inputs().size());

  for (const Value* graph_input : spec.graph()->inputs()) {
    const Node* chunk_node = usedInFusedChunk(graph_input);
    if (chunk_node == nullptr) {
      chunk_descriptors.emplace_back(1, 0);
      continue;
    }
    chunk_descriptors.emplace_back(
        chunk_node->i(attr::chunks), chunk_node->i(attr::dim));
  }
}
// Run a DFS traversal to find all inputs that affect a given output value
static std::vector<int64_t> getInputDependencies(const Value* output) {
std::vector<const Value*> queue{output};
std::unordered_set<const Value*> inputs;
std::unordered_set<const Value*> seen;
while (!queue.empty()) {
const Value* val = queue.back(); queue.pop_back();
const Node* producer = val->node();
if (producer->kind() == prim::Param) {
inputs.insert(val);
continue;
}
for (const Value* input : producer->inputs()) {
if (/*bool inserted = */seen.insert(input).second) {
queue.push_back(input);
}
}
}
// Convert Value* into offsets into the graph's input list
std::vector<int64_t> offsets;
offsets.reserve(inputs.size());
for (const Value* input : inputs) {
offsets.push_back(input->offset());
}
std::sort(offsets.begin(), offsets.end());
return offsets;
}
// Computes the distinct sets of graph inputs that feed each output and
// appends them to spec.inputBroadcastGroups().
// For an output produced by prim::FusedConcat, each input of the concat node
// contributes its own group instead of the output as a whole.
static void setInputBroadcastGroups(KernelSpec& spec) {
  using InputOffsets = std::vector<int64_t>;
  // A set is used so that identical dependency lists are stored only once.
  std::unordered_set<InputOffsets, torch::hash<InputOffsets>> unique_groups;

  for (const Value* graph_output : spec.graph()->outputs()) {
    const Node* producer = graph_output->node();
    if (producer->kind() != prim::FusedConcat) {
      unique_groups.insert(getInputDependencies(graph_output));
      continue;
    }
    for (const Value* concat_piece : producer->inputs()) {
      unique_groups.insert(getInputDependencies(concat_piece));
    }
  }

  auto& destination = spec.inputBroadcastGroups();
  for (const InputOffsets& group : unique_groups) {
    destination.push_back(group);
  }
}
// Performs "upfront" compilation where storage is known but shapes are not.
// Currently identifies how to expand all tensors so that all intermediate
// tensors are the same shape, simplifying code generation.
// Broadcast groups and chunks are identified without shape information
// using logical properties of how each works. In particular, tensors
// are always expandable to the outputs of pointwise operations they
// or their descendants are involved in, which means that in a DAG of
// pointwise operations all tensors are expandable to the (single) output.
// Note: The logic is slightly complicated by concatenation and chunking.
//
// Concretely, this records two pieces of shape-independent metadata on spec:
//  - the input broadcast groups (via setInputBroadcastGroups), and
//  - one chunk descriptor per graph input (via setInputChunkDescriptors).
static void upfrontCompilation(KernelSpec& spec) {
setInputBroadcastGroups(spec);
setInputChunkDescriptors(spec);
}
// Registers the subgraph of a fusion group and returns the key it was stored
// under.
// The subgraph is copied, stripped of shape information and passed to
// store(). When fusion is possible on the CPU or the GPU, the stored spec is
// looked up again and upfront compilation is run on it.
int64_t registerFusion(const Node* fusion_group) {
  auto subgraph = fusion_group->g(attr::Subgraph)->copy();
  EraseShapeInformation(subgraph);
  const auto key = store(subgraph);

  // Nothing more to prepare when no fuser backend can be used.
  if (!canFuseOnCPU() && !canFuseOnGPU()) {
    return key;
  }

  const auto stored_spec = retrieve(key);
  JIT_ASSERT(stored_spec);
  upfrontCompilation(**stored_spec);
  return key;
}
// Generates and compiles a fused kernel for `spec`, specialized to the
// argument descriptions in `arg_spec`.
//
// spec:     kernel specification; its graph is copied before being modified.
// arg_spec: provides one TensorDesc per graph input, used to set the scalar
//           type and number of dimensions of the corresponding input.
// map_size: sizes given to every output. For an output produced by
//           prim::FusedConcat the size along the concat dimension is
//           multiplied by the number of concat inputs.
// device:   device the kernel targets; selects CUDA or CPU code generation.
//
// Returns the newly created kernel.
// Throws std::runtime_error when the fuser for the requested device type is
// not enabled in this build, and std::out_of_range (from vector::at) when a
// concat dimension is not a valid index into map_size.
std::shared_ptr<FusedKernel> compileKernel(
    const KernelSpec& spec
  , const ArgSpec& arg_spec
  , const std::vector<int64_t>& map_size
  , const at::Device device) {
  const std::vector<TensorDesc>& input_desc = arg_spec.descs();

  // Specialize a private copy of the graph to the actual input types, then
  // let shape propagation derive the types of the remaining values.
  auto graph = spec.graph()->copy();
  for (size_t i = 0; i < input_desc.size(); i++) {
    const auto& desc = input_desc[i];
    graph->inputs()[i]->setType(TensorType::create(desc.scalar_type, device, desc.nDim())); // TODO: nDim is bad, as it is collapsed
  }
  PropagateInputShapes(graph);

  // Creates output descriptions
  std::vector<TensorDesc> output_desc;
  for (const Value* output : graph->outputs()) {
    std::vector<int64_t> sizes = map_size;
    if (output->node()->kind() == prim::FusedConcat) {
      // A concat output is as large as all of its inputs laid out along the
      // concat dimension.
      sizes.at(output->node()->i(attr::dim)) *= output->node()->inputs().size();
    }
    auto scalar_type = output->type()->expect<c10::TensorType const>()->scalarType();
    auto type = CompleteTensorType::create(scalar_type, device, sizes);
    output_desc.emplace_back(std::move(type));
  }

  // Each kernel gets a unique name derived from the global kernel counter.
  const std::string name = "kernel_" + std::to_string(next_kernel_id++);
  const bool use_cuda = device.is_cuda();
  std::string code;
  std::vector<PartitionDesc> chunk_desc;
  std::vector<PartitionDesc> concat_desc;
  bool has_random;
  std::tie(code, chunk_desc, concat_desc, has_random)
    = generateKernel(
        name
      , *graph
      , input_desc
      , output_desc
      , use_cuda);

  std::shared_ptr<FusedKernel> fused_kernel;
  if (use_cuda) {
#if USE_CUDA_FUSER
    fused_kernel = std::make_shared<cuda::FusedKernelCUDA>(
        device.index()
      , name
      , code
      , input_desc
      , output_desc
      , chunk_desc
      , concat_desc
      , has_random);
#else
    throw std::runtime_error("CUDA Fusion is not supported on this build.");
#endif // USE_CUDA_FUSER
  } else {
#if USE_CPU_FUSER
    fused_kernel = std::make_shared<cpu::FusedKernelCPU>(
        name
      , code
      , input_desc
      , output_desc
      , chunk_desc
      , concat_desc
      , has_random);
#else
    throw std::runtime_error("CPU Fusion is not supported on this build.");
#endif // USE_CPU_FUSER
  }

  return fused_kernel;
}
} // namespace fuser
} // namespace jit
} // namespace torch