#include "torch/csrc/jit/fuser/compiler.h" #include "ATen/ATen.h" #include "torch/csrc/jit/ir.h" #include "torch/csrc/jit/type.h" #include "torch/csrc/jit/code_template.h" #include "torch/csrc/jit/assertions.h" #include "torch/csrc/jit/passes/shape_analysis.h" #include "torch/csrc/jit/fuser/interface.h" #include "torch/csrc/jit/fuser/kernel_cache.h" #include "torch/csrc/jit/fuser/codegen.h" #include "torch/csrc/jit/fuser/tensor_desc.h" #if USE_CUDA_FUSER #include "torch/csrc/jit/fuser/cuda/fused_kernel.h" #endif // USE_CUDA_FUSER #if USE_CPU_FUSER #include "torch/csrc/jit/fuser/cpu/fused_kernel.h" #endif // USE_CUDA_FUSER #include #include #include #include #include #include #include #include #include namespace torch { namespace jit { namespace fuser { // Counter for number of kernels compiled, used for debugging and // creating arbitrary kernel names. static std::atomic next_kernel_id{0}; size_t nCompiledKernels() { return next_kernel_id.load(); } // If the given node is used once by a chunk node, returns that node. // Returns nullptr otherwise. static const Node* usedInFusedChunk(const Value* input) { const auto uses = input->uses(); if (uses.size() == 1) { const Node *user = uses[0].user; if (user->kind() == prim::ConstantChunk) { return user; } } return nullptr; } static void setInputChunkDescriptors(KernelSpec& spec) { spec.inputChunks().reserve((spec.graph())->inputs().size()); for (const Value* input : (spec.graph())->inputs()) { if (const Node* chunk = usedInFusedChunk(input)) { spec.inputChunks().emplace_back(chunk->i(attr::chunks), chunk->i(attr::dim)); } else { spec.inputChunks().emplace_back(1, 0); } } } // Run a DFS traversal to find all inputs that affect a given output value static std::vector getInputDependencies(const Value* output) { std::vector queue{output}; std::unordered_set inputs; std::unordered_set seen; while (!queue.empty()) { const Value* val = queue.back(); queue.pop_back(); const Node* producer = val->node(); if (producer->kind() == prim::Param) { inputs.insert(val); continue; } for (const Value* input : producer->inputs()) { if (/*bool inserted = */seen.insert(input).second) { queue.push_back(input); } } } // Convert Value* into offsets into the graph's input list std::vector offsets; offsets.reserve(inputs.size()); for (const Value* input : inputs) { offsets.push_back(input->offset()); } std::sort(offsets.begin(), offsets.end()); return offsets; } static void setInputBroadcastGroups(KernelSpec& spec) { std::unordered_set, torch::hash>> broadcast_groups; for (const Value* output : (spec.graph())->outputs()) { broadcast_groups.insert(getInputDependencies(output)); } std::copy( broadcast_groups.begin() , broadcast_groups.end() , std::back_inserter(spec.inputBroadcastGroups())); } // Performs "upfront" compilation where storage is known but shapes are not. // Currently identifies how to expand all tensors so that all intermediate // tensors are the same shape, simplifying code generation. // Broadcast groups and chunks are identified without shape information // using logical properties of how each works. In particular, tensors // are always expandable to the outputs of pointwise operations they // or their descendants are involved in, which means that in a DAG of // pointwise operations all tensors are expandable to the (single) output. // Note: The logic is slightly complicated by concatenation and chunking. 
static void upfrontCompilation(KernelSpec& spec) {
  setInputBroadcastGroups(spec);
  setInputChunkDescriptors(spec);
}

int64_t registerFusion(const Node* fusion_group) {
  // Creates and stores the KernelSpec for this fusion group's subgraph
  auto graph = fusion_group->g(attr::Subgraph)->copy();
  EraseShapeInformation(*graph);
  const auto key = store(graph);

  if (canFuseOnCPU() || canFuseOnGPU()) {
    const auto maybe_spec = retrieve(key);
    JIT_ASSERT(maybe_spec);
    upfrontCompilation(**maybe_spec);
  }

  return key;
}

std::shared_ptr<FusedKernel> compileKernel(
  const KernelSpec& spec
, const ArgSpec& arg_spec
, const std::vector<int64_t>& map_size
, const at::Device device) {
  const std::vector<TensorDesc>& input_desc = arg_spec.descs();

  // Note: this assumes fused kernels only operate on floating point values
  c10::optional<at::ScalarType> scalar_type;
  for (const auto& desc : input_desc) {
    if (isFloatingType(desc.scalar_type)) {
      scalar_type = desc.scalar_type;
      break;
    }
  }
  JIT_ASSERT(scalar_type);

  // Creates output descriptions
  std::vector<TensorDesc> output_desc;
  for (const Value* output : (spec.graph())->outputs()) {
    std::vector<int64_t> sizes = map_size;
    if (output->node()->kind() == prim::FusedConcat) {
      sizes.at(output->node()->i(attr::dim)) *= output->node()->inputs().size();
    }
    auto type = CompleteTensorType::create(*scalar_type, device, sizes);
    output_desc.emplace_back(std::move(type));
  }

  const std::string name = "kernel_" + std::to_string(next_kernel_id++);
  const bool use_cuda = device.is_cuda();
  std::string code;
  std::vector<PartitionDesc> chunk_desc;
  std::vector<PartitionDesc> concat_desc;
  bool has_random;
  std::tie(code, chunk_desc, concat_desc, has_random) = generateKernel(
    name
  , *(spec.graph())
  , input_desc
  , output_desc
  , use_cuda);

  std::shared_ptr<FusedKernel> fused_kernel;
  if (use_cuda) {
    #if USE_CUDA_FUSER
      fused_kernel = std::make_shared<cuda::FusedKernelCUDA>(
        device.index()
      , name
      , code
      , input_desc
      , output_desc
      , chunk_desc
      , concat_desc
      , has_random);
    #else
      throw std::runtime_error("CUDA Fusion is not supported on this build.");
    #endif // USE_CUDA_FUSER
  } else {
    #if USE_CPU_FUSER
      fused_kernel = std::make_shared<cpu::FusedKernelCPU>(
        name
      , code
      , input_desc
      , output_desc
      , chunk_desc
      , concat_desc
      , has_random);
    #else
      throw std::runtime_error("CPU Fusion is not supported on this build.");
    #endif // USE_CPU_FUSER
  }

  return fused_kernel;
}

} // namespace fuser
} // namespace jit
} // namespace torch
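
// Usage sketch (an assumption about how these entry points fit together, not
// something this translation unit enforces): a fusion pass calls
// registerFusion() once per fusion group to obtain a cache key, and the
// fusion executor calls compileKernel() lazily the first time that group runs
// with a particular ArgSpec / map_size / device combination, caching the
// resulting FusedKernel for subsequent runs.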