#include "torch/csrc/jit/fusion_compiler.h"
|
|
#include "torch/csrc/jit/ir.h"
|
|
#include "torch/csrc/jit/DisallowCopy.h"
|
|
#include "torch/csrc/jit/code_template.h"
|
|
#include "torch/csrc/jit/resource_guard.h"
|
|
#include "ATen/ATen.h"
|
|
#include <nvrtc.h>
|
|
#include <cuda.h>
|
|
#include <cuda_runtime.h>
|
|
#include <string>
|
|
#include <algorithm>
|
|
#include <unordered_map>
|
|
#include <vector>
|
|
#include <sstream>
|
|
#include <iostream>
|
|
|
|
namespace torch { namespace jit {
|
|
|
|
std::vector<bool> findContiguous(
    at::IntList sizes,
    at::IntList strides) {
  JIT_ASSERT(sizes.size() == strides.size());
  std::vector<bool> cont(sizes.size());
  for(size_t i = 0; i < sizes.size(); ++i) {
    int64_t expected_stride = (i + 1 < sizes.size()) ? sizes[i+1]*strides[i+1] : 1;
    cont[i] = strides[i] == expected_stride;
  }
  return cont;
}
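
// A worked example of findContiguous (illustrative values, not taken from any
// particular model): sizes = {4, 3, 2}, strides = {6, 2, 1}
//   dim 2: expected stride 1         actual 1 -> contiguous
//   dim 1: expected stride 2*1 = 2   actual 2 -> contiguous
//   dim 0: expected stride 3*2 = 6   actual 6 -> contiguous
// so the result is {true, true, true}. A transposed view such as sizes = {3, 4},
// strides = {1, 3} yields {false, false}: no dimension can be folded into its neighbor.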

// Compress dimensions when the tensor is marked as contiguous (cont).
// Any time we do a compression, we assert that it is valid for this particular tensor.
static void compressContiguous(
    at::IntList sizes,
    at::IntList strides,
    const std::vector<bool> & cont,
    uint32_t * c_sizes,
    uint32_t * c_strides) {
  size_t compressed_dims = 0;
  size_t cur = 0;
  size_t ndim = sizes.size();
  while(cur < ndim) {
    size_t total_size = sizes[cur];
    cur++;
    while(cont[cur-1] && cur < ndim) {
      JIT_ASSERT(strides[cur-1] == sizes[cur]*strides[cur]);
      total_size *= sizes[cur];
      cur++;
    }
    // cur starts pointing at the beginning of the run to compress
    // cur ends one _after_ the terminating false or end of list.
    // total_size is the size of all dimensions [begin,end)
    // examples:
    // f = not cont.
    // t = cont.
    // x = don't care, including past end of list
    // s = start of cur
    // e = end of cur

    // f x x x
    // s e

    // t f x x
    // s   e

    // t t f x
    // s     e

    c_sizes[compressed_dims] = total_size;
    c_strides[compressed_dims] = strides[cur-1];
    compressed_dims++;
  }
  JIT_ASSERT(!cont.back() || strides.back() == 1);
}
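
// Continuing the example above (sizes = {4, 3, 2}, strides = {6, 2, 1}): with
// cont = {t, t, t} the whole run collapses into a single compressed dimension,
// c_sizes = {24}, c_strides = {1}; with cont = {f, t, t} only the inner pair is
// folded, giving c_sizes = {4, 6}, c_strides = {6, 1}. (Illustrative values only.)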

static auto compilation_unit_template = CodeTemplate(R"(
typedef ${IndexType} IndexType;
template<typename T, size_t N>
struct TensorInfo {
  T * data;
  IndexType sizes[N];
  IndexType strides[N];
};

extern "C" __global__
void ${kernelName}(IndexType totalElements, ${formals}) {
  for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
       linearIndex < totalElements;
       linearIndex += gridDim.x * blockDim.x) {
    // Convert `linearIndex` into an offset of tensor:
    ${tensorOffsets}
    // calculate the results
    ${kernelBody}
  }
}
)");

// curDimIndex = linearId % sizes[i]; // % sizes[i] is not needed for d == 0, because we already guard for numel outside the index calculation
// offset += curDimIndex*strides[i]; // *strides[i] is optional if list_is_cont because strides.back() == 1
// linearId /= sizes[i];
static auto dim_calc = CodeTemplate(R"(
//printf("tensor ${tensor} sizes[${d}] = %d, strides[${d}] = %d\n", ${tensor}.sizes[${d}],${tensor}.strides[${d}]);
size_t ${tensor}_dimIndex${d} = ${tensor}_linearIndex ${mod_sizes};
${tensor}_offset += ${tensor}_dimIndex${d} ${times_stride};
)");

static void emitIndexingFor(std::ostream & out, const std::string & tensor, int ndim, bool last_is_cont) {
  TemplateEnv env;
  env.s("tensor",tensor);
  out << format("IndexType ${tensor}_offset = 0;\n",env);
  out << format("IndexType ${tensor}_linearIndex = linearIndex;\n",env);
  for(int d = ndim - 1; d >= 0; --d) {
    env.d("d",d);
    env.s("mod_sizes", d > 0 ? format("% ${tensor}.sizes[${d}]",env) : "");
    env.s("times_stride",(d < ndim - 1 || !last_is_cont) ?
      format("* ${tensor}.strides[${d}]",env) : "");
    out << dim_calc.format(env);
    if(d > 0) {
      out << format("${tensor}_linearIndex /= ${tensor}.sizes[${d}];\n",env);
    }
  }
}
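
// For a hypothetical 2-D formal named "t0" whose last dimension is contiguous,
// emitIndexingFor produces roughly the following (a sketch of the generated
// source, not captured compiler output):
//   IndexType t0_offset = 0;
//   IndexType t0_linearIndex = linearIndex;
//   size_t t0_dimIndex1 = t0_linearIndex % t0.sizes[1];
//   t0_offset += t0_dimIndex1 ;             // "* strides[1]" elided: the stride is known to be 1
//   t0_linearIndex /= t0.sizes[1];
//   size_t t0_dimIndex0 = t0_linearIndex ;  // "% sizes[0]" elided: totalElements already bounds it
//   t0_offset += t0_dimIndex0 * t0.strides[0];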

static std::ostream& operator<<(std::ostream & out, const TensorDesc & d) {
  out << d.scalar_type << "[";
  for(auto b : d.contiguity)
    out << b << ";";
  out << "]";
  return out;
}

static std::string nodeName(Node * n) {
  return "n"+std::to_string(n->unique());
}

// TODO: we need to support double-precision
static std::unordered_map<NodeKind,std::string> simple_map_ops = {
  {NodeKind::Sigmoid, "1.f / (1.f + expf(-${0}))"},
  {NodeKind::Tanh, "tanhf(${0})"},
  {NodeKind::Mul, "${0} * ${1}"},
  {NodeKind::Add, "${0} + ${1}"},
};
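
// Each entry is itself a template body: ${0}, ${1}, ... are substituted with the
// names of the node's inputs. For example, a Mul node whose inputs were emitted
// as n1 and n2 expands to "n1 * n2" (the names here are illustrative; the real
// ones come from Node::unique()).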

const char * toCString(at::ScalarType type) {
  switch(type) {
    #define DEFINE_CASE(ctype,name,_) \
      case at::ScalarType::name: return #ctype;
    AT_FORALL_SCALAR_TYPES(DEFINE_CASE)
    #undef DEFINE_CASE
    default:
      throw std::runtime_error("unknown scalar type");
  }
}

void emitCompilationUnit(std::ostream & out,
                         const std::string & name,
                         AnnotatedGraph & agraph) {
  Graph& subgraph = *agraph.graph;
  TemplateEnv env;
  env.s("kernelName",name);
  // TODO: handle cases where we need to generate > 2^32 element tensors
  env.s("IndexType","unsigned int"); // avoiding slow header includes to get uint32_t

  std::stringstream body;
  std::stringstream tensorOffsets;
  std::vector<std::string> formals;
  auto emitFormal = [&](Node * n, const TensorDesc & desc) {
    std::string tensor = "t" + std::to_string(formals.size()); // can't be unique() because Param may be an output
    size_t nDim = desc.nDim();
    emitIndexingFor(tensorOffsets, tensor, nDim, desc.lastIsContiguous());
    env.s("tensor",tensor);
    env.d("nDim",nDim);
    env.s("scalar_type",toCString(desc.scalar_type));
    formals.push_back(format("TensorInfo<${scalar_type},${nDim}> ${tensor}",env));
  };
  {
    size_t i = 0;
    for(auto p : subgraph.inputs())
      emitFormal(p,agraph.input_desc[i++]);
  }
  {
    size_t i = 0;
    for(auto o : subgraph.outputs())
      emitFormal(o,agraph.output_desc[i++]);
  }
  size_t formal_count = 0;
  for(auto p : subgraph.inputs()) {
    env.s("node",nodeName(p));
    env.d("formal",formal_count++);
    env.s("access",format("t${formal}.data[t${formal}_offset]",env));
    // TODO: actual type propagation rather than relying on auto..
    body << format("auto ${node} = ${access};\n",env);
  }
  for(auto n : subgraph.nodes()) {
    size_t i = 0;
    for(auto in : n->inputs()) {
      env.s(std::to_string(i++),nodeName(in));
    }
    env.s("node",nodeName(n));
    env.s("rhs",format(simple_map_ops.at(n->kind()),env));
    body << format("auto ${node} = ${rhs};\n",env);
  }
  for(auto o : subgraph.outputs()) {
    env.d("formal",formal_count++);
    env.s("access",format("t${formal}.data[t${formal}_offset]",env));
    env.s("node",nodeName(o));
    body << format("${access} = ${node};\n",env);
  }
  env.s("tensorOffsets",tensorOffsets.str());
  env.s("kernelBody",body.str());
  env.v("formals",formals);
  out << compilation_unit_template.format(env);
}
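
// To make the pieces above concrete: for a graph that fuses, say, sigmoid(a * b)
// over float tensors, the emitted compilation unit looks approximately like the
// sketch below. The kernel name, node numbers and dimensionality are illustrative;
// the real text depends on the graph and the TensorDescs.
//
//   extern "C" __global__
//   void kernel_0(IndexType totalElements, TensorInfo<float,1> t0,
//                 TensorInfo<float,1> t1, TensorInfo<float,1> t2) {
//     for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
//          linearIndex < totalElements;
//          linearIndex += gridDim.x * blockDim.x) {
//       /* per-tensor offset computation from emitIndexingFor */
//       auto n0 = t0.data[t0_offset];
//       auto n1 = t1.data[t1_offset];
//       auto n2 = n0 * n1;
//       auto n3 = 1.f / (1.f + expf(-n2));
//       t2.data[t2_offset] = n3;
//     }
//   }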

static void nvrtcCheck(nvrtcResult result, const char * file, int line) {
  if(result != NVRTC_SUCCESS) {
    std::stringstream ss;
    ss << file << ":" << line << ": " << nvrtcGetErrorString(result);
    throw std::runtime_error(ss.str());
  }
}
#define JIT_NVRTC_CHECK(result) \
  nvrtcCheck(result,__FILE__,__LINE__);

static void cuCheck(CUresult result, const char * file, int line) {
  if(result != CUDA_SUCCESS) {
    const char * str;
    cuGetErrorString(result, &str);
    std::stringstream ss;
    ss << file << ":" << line << ": " << str;
    throw std::runtime_error(ss.str());
  }
}
#define JIT_CU_CHECK(result) \
  cuCheck(result,__FILE__,__LINE__);

static void cudaCheck(cudaError_t result, const char * file, int line) {
  if(result != cudaSuccess) {
    std::stringstream ss;
    ss << file << ":" << line << ": " << cudaGetErrorString(result);
    throw std::runtime_error(ss.str());
  }
}
#define JIT_CUDA_CHECK(result) \
  cudaCheck(result,__FILE__,__LINE__);

static int ceilDiv(int a, int b) {
  return (a + b - 1) / b;
}

// host-side view of TensorInfo
// note sizes_strides[0]: a zero-length array, so the sizes and strides can be
// allocated dynamically right after the struct
struct TensorInfo {
  void * data;
  uint32_t sizes_strides[0];
  uint32_t* sizes(size_t nDim) {
    return &sizes_strides[0];
  }
  uint32_t* strides(size_t nDim) {
    return &sizes_strides[nDim];
  }
};
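
// The blob built for each argument is laid out as
//   [ data pointer | sizes[0..nDim) | strides[0..nDim) ]
// which mirrors the device-side TensorInfo<T, N> emitted above, so a pointer to
// the blob can be handed straight to cuLaunchKernel. For a (hypothetical) 2-D
// tensor, sizes(2) is &sizes_strides[0] and strides(2) is &sizes_strides[2].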

CompiledFusionFunction::CompiledFusionFunction(const std::string & name, AnnotatedGraph & agraph)
  : name(name), input_desc(agraph.input_desc), output_desc(agraph.output_desc) {
  std::stringstream cu;
  emitCompilationUnit(cu, name, agraph);
  compilation_unit = cu.str();
  nvrtcProgram program;
  JIT_NVRTC_CHECK(nvrtcCreateProgram(&program,compilation_unit.c_str(), NULL, 0, nullptr, nullptr));
  cudaDeviceProp deviceProp;
  JIT_CUDA_CHECK(cudaGetDevice(&device));
  JIT_CUDA_CHECK(cudaGetDeviceProperties(&deviceProp, device));
  std::string compute = "--gpu-architecture=compute_" + std::to_string(deviceProp.major) + std::to_string(deviceProp.minor);
  std::vector<const char *> args = {"--std=c++11", compute.c_str()};
  nvrtcResult result = nvrtcCompileProgram(program, args.size(), args.data());
  if(result == NVRTC_ERROR_COMPILATION) {
    size_t logsize;
    nvrtcGetProgramLogSize(program, &logsize);
    std::vector<char> log(logsize);
    nvrtcGetProgramLog(program, log.data());
    cu << log.data();
    throw std::runtime_error(cu.str());
  }
  ResourceGuard holdProgram([&] {
    JIT_NVRTC_CHECK(nvrtcDestroyProgram(&program));
  });
  JIT_NVRTC_CHECK(result);
  size_t ptx_size;
  JIT_NVRTC_CHECK(nvrtcGetPTXSize(program, &ptx_size));
  ptx.resize(ptx_size);
  JIT_NVRTC_CHECK(nvrtcGetPTX(program, ptx.data()));

  JIT_CU_CHECK(cuModuleLoadData(&module, ptx.data()));
  JIT_CU_CHECK(cuModuleGetFunction(&function, module, name.c_str()));

  JIT_CU_CHECK(cuOccupancyMaxActiveBlocksPerMultiprocessor(
    &maxBlocks, function, 128, 0));
  JIT_CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
  maxBlocks *= prop.multiProcessorCount;
}

CompiledFusionFunction::~CompiledFusionFunction() {
  JIT_CU_CHECK(cuModuleUnload(module));
}

void CompiledFusionFunction::launch(at::ArrayRef<at::Tensor> inputs, at::ArrayRef<at::Tensor> outputs) {
  JIT_ASSERT(inputs.size() == input_desc.size());
  JIT_ASSERT(outputs.size() == output_desc.size());
  size_t uncompressedDim = input_desc.at(0).contiguity.size();
  uint32_t numel = inputs[0].numel();
  size_t maxPossibleTensorInfoSize = sizeof(TensorInfo) + 2*sizeof(uint32_t)*uncompressedDim;
  size_t maxPossibleBufferSize = maxPossibleTensorInfoSize * (inputs.size() + outputs.size());
  std::vector<char> buffer(maxPossibleBufferSize);
  char * buffer_next = buffer.data();
  std::vector<void*> arguments;
  arguments.reserve(1 + inputs.size() + outputs.size());
  auto addTensorInfo = [&](TensorDesc & desc, const at::Tensor & t) {
    size_t nDim = desc.nDim(); // the compressed dim
    auto ti = reinterpret_cast<TensorInfo*>(buffer_next);
    ti->data = t.data_ptr();
    compressContiguous(t.sizes(), t.strides(), desc.contiguity, ti->sizes(nDim), ti->strides(nDim));
    buffer_next += maxPossibleTensorInfoSize;
    arguments.push_back(ti);
  };
  arguments.push_back(&numel);
  {
    size_t i = 0;
    for(auto & desc : input_desc) {
      addTensorInfo(desc,inputs[i++]);
    }
  }
  {
    size_t i = 0;
    for(auto & desc : output_desc) {
      addTensorInfo(desc,outputs[i++]);
    }
  }
  launch(numel, arguments.data());
}

void CompiledFusionFunction::launch(uint32_t numel, void ** arguments) {
  int numBlocks = std::min(maxBlocks,ceilDiv(numel,blockSize));
  //std::cout << "maxBlocks = " << maxBlocks << " needed blocks: " << ceilDiv(numel,blockSize)
  //          << " numblocks = " << numBlocks;

  JIT_CU_CHECK(cuLaunchKernel(
    function,
    numBlocks, 1, 1,
    blockSize, 1, 1,
    0, nullptr,
    arguments,
    nullptr));
}
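
// Grid-sizing example (assuming blockSize is 128, matching the occupancy query in
// the constructor): numel = 1,000,000 needs ceilDiv(1000000, 128) = 7813 blocks,
// which is then clamped to maxBlocks so the grid never exceeds what keeps every
// SM occupied; the grid-stride loop in the kernel covers any remaining elements.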

FusionCompiler::FusionCompiler() {}

std::shared_ptr<CompiledFusionFunction> FusionCompiler::getOrCompile(AnnotatedGraph & agraph) {
  std::stringstream key;
  key << *agraph.graph << "\n";
  int device;
  JIT_CUDA_CHECK(cudaGetDevice(&device));
  key << "Device " << device << "\n";
  for(auto & i : agraph.input_desc)
    key << i << "\n";
  for(auto & i : agraph.output_desc)
    key << i << "\n";
  std::string key_ = key.str();
  auto it = cache.find(key_);
  if(it == cache.end()) {
    std::string name = "kernel_" + std::to_string(cache.size());
    auto func = std::make_shared<CompiledFusionFunction>(name,agraph);
    it = cache.emplace(key_,std::move(func)).first;
  }
  return it->second;
}

std::shared_ptr<CompiledFusionFunction> FusionCompiler::getOrCompile(Graph & graph) {
  AnnotatedGraph agraph { &graph };
  for(auto & input : graph.inputs()) {
    // graph doesn't record scalar type yet..., assuming float
    TensorType * t = input->type()->cast<TensorType>();
    agraph.input_desc.push_back(TensorDesc(t->scalarType(),t->sizes(),t->strides()));
  }
  for(auto & output : graph.outputs()) {
    TensorType * t = output->type()->cast<TensorType>();
    agraph.output_desc.push_back(TensorDesc(t->scalarType(),t->sizes(),t->strides()));
  }
  return getOrCompile(agraph);
}

void FusionCompiler::debugLaunchGraph(Graph & graph, at::ArrayRef<at::Tensor> inputs, at::ArrayRef<at::Tensor> outputs) {
  AnnotatedGraph agraph { &graph };
  for(auto & i : inputs) {
    agraph.input_desc.emplace_back(i);
  }
  for(auto & i : outputs) {
    agraph.output_desc.emplace_back(i);
  }
  auto func = getOrCompile(agraph);
  func->launch(inputs, outputs);
}

// TODO: thread safety
FusionCompiler & sharedFusionCompiler() {
  static FusionCompiler compiler;
  return compiler;
}

}}