Summary: Have basic reduction fusion working, and have improved the code generator to approach the performance of eager-mode reductions. Coming soon are pointwise-reduction fusions, structured in a way that should prevent the possibility of hitting regressions. Also working on performant softmax kernels in the code generator, which may be our next fusion target.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/40864
Reviewed By: ngimel
Differential Revision: D22392877
Pulled By: soumith
fbshipit-source-id: 457448a807d628b1035f6d90bc0abe8a87bf8447
722 lines
23 KiB
C++
#include <ATen/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <c10/core/ScalarType.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/util/ArrayRef.h>

#include <torch/csrc/jit/codegen/cuda/kernel.h>
#include <torch/csrc/jit/codegen/cuda/kernel_arg.h>
#include <torch/csrc/jit/codegen/cuda/kernel_resource_strings.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/parser.h>

#include <torch/csrc/jit/resource_guard.h>
#include <fstream>
#include <iostream>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

constexpr auto kCgNamespace = "CudaCodeGen";
constexpr auto kKernelName = "kernel";

namespace {

// See NOTE [ USE OF NVRTC AND DRIVER API ]
static const at::cuda::NVRTC& nvrtc() {
  return at::globalContext().getNVRTC();
}

static int ceilDiv(const int a, const int b) {
  return (a + b - 1) / b;
}
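
// For example, ceilDiv(1000, 128) == 8; this is used below to size the launch
// grid so that blocks * threads covers every output element.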

// Go through a tensor, and grab its sizes/strides, potentially broadcasted.
struct ExtractSizeStride {
  std::vector<int64_t> sizes;
  std::vector<int64_t> strides;

  explicit ExtractSizeStride(
      const at::Tensor& val,
      c10::optional<at::IntArrayRef> broadcasted_size = c10::nullopt) {
    if (broadcasted_size) {
      // [Note - broadcast support in integration]
      // PyTorch follows the numpy broadcasting rule.
      // (https://numpy.org/doc/stable/user/basics.broadcasting.html)
      //
      // So when the ranks of the two operands differ, we align them on the
      // higher dimensions, hence the offset o_dim - b_dim to the index here.
      int b_dim = (int)broadcasted_size->size();
      int o_dim = (int)val.dim();
      TORCH_CHECK(b_dim >= o_dim);
      for (int i = 0; i < b_dim; i++) {
        sizes.push_back(broadcasted_size->at(i));
        int index = i + o_dim - b_dim;
        if (index < 0) {
          strides.push_back(0);
        } else if (val.sizes()[index] == sizes[i]) {
          strides.push_back(val.strides()[index]);
        } else {
          TORCH_CHECK(
              val.sizes()[index] == 1,
              "Dimension size is not compatible for broadcast");
          strides.push_back(0);
        }
      }
    } else {
      auto o_dim = val.dim();
      for (decltype(val.dim()) i{0}; i < o_dim; i++) {
        sizes.push_back(val.sizes()[i]);
        strides.push_back(val.strides()[i]);
      }
    }
  }
};
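
// Example: for val of shape [3, 1] with strides [s0, s1] and a
// broadcasted_size of [2, 3, 4], the result is sizes = {2, 3, 4} and
// strides = {0, s0, 0}; the missing leading dimension and the size-1
// dimension are broadcast by giving them a stride of 0.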

struct KernelArgumentHolder {
 private:
  std::vector<ArgAbstract*> arguments;
  std::vector<void*> void_ptrs;
  bool changed = true;

 public:
  virtual ~KernelArgumentHolder() {
    for (auto arg : arguments)
      delete arg;
  }

  // Push a tensor to the arguments
  void push(
      const at::Tensor& val,
      c10::optional<at::IntArrayRef> broadcasted_size = c10::nullopt) {
    changed = true;
    ExtractSizeStride ess(val, std::move(broadcasted_size));
    int nDims = ess.sizes.size();

    c10::ScalarType dtype = val.scalar_type();
    TensorArgAbstract* tensor_arg = getTensorArg(dtype, nDims);
    tensor_arg->setPointer(val.data_ptr());
    for (int i = 0; i < nDims; i++) {
      tensor_arg->setSize(i, ess.sizes[i]);
      tensor_arg->setStride(i, ess.strides[i]);
    }
    arguments.push_back(tensor_arg);
  }

  // Push a scalar or integer to the arguments
  void push(const IValue& val) {
    changed = true;
    TORCH_INTERNAL_ASSERT(
        val.isScalar(),
        "Tried to push an arg to run in a fused kernel, expected a scalar but got, ",
        val);
    switch (val.toScalar().type()) {
      case (c10::ScalarType::Double):
        arguments.push_back(new FloatArg((float)val.toDouble()));
        return;
      case (c10::ScalarType::Long):
        arguments.push_back(new IntArg((int)val.toInt()));
        return;
      default:
        TORCH_INTERNAL_ASSERT(
            false,
            " Tried to create argument to send to a fused kernel, but got an unexpected type.");
    }
    TORCH_INTERNAL_ASSERT(
        false,
        " Tried to create argument to send to a fused kernel, but got a non-scalar type.");
  }

  void push(const uint64_t& val) {
    arguments.push_back(new ULongArg(val));
  }

  // Create buffer, flatten arguments into it, align by 8 Bytes, return pointers
  // in the buffer
  void** getBuffer() {
    if (changed) {
      void_ptrs = std::vector<void*>(arguments.size(), nullptr);
      for (decltype(arguments.size()) i{0}; i < arguments.size(); i++)
        void_ptrs[i] = static_cast<void*>(arguments[i]->arg());
      changed = false;
    }
    return void_ptrs.data();
  }
};
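
// Usage sketch (mirrors runKernel/runTestKernel below; names are illustrative):
// tensors, scalars, and philox seeds are pushed in kernel-parameter order, and
// getBuffer() yields the void** array handed to cuLaunchKernel, e.g.
//
//   KernelArgumentHolder args;
//   args.push(input_tensor, broadcasted_shape); // tensor sizes/strides
//   args.push(scalar_ivalue);                   // Double -> FloatArg, Long -> IntArg
//   args.push(output_tensor);
//   // cuLaunchKernel(..., args.getBuffer(), nullptr);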

std::pair<std::string, std::string> codeGeneration(Fusion* fusion) {
  std::stringstream str_stream;
  str_stream << "namespace " << kCgNamespace << " {\n"
             << code_template_tensor_struct << "\n"
             << code_fp16_support << "\n"
             << code_random_number_gen << "\n"
             << code_helper_funcs << "\n"
             << code_template_block_reduction << "\n"
             << code_template_grid_reduction << "\n";
  std::stringstream cdg;
  GPULower gpulw(fusion);
  gpulw.printKernel(str_stream, kKernelName);
  str_stream << "\n} // namespace";

  std::string func_name = std::string(kCgNamespace) + "::" + kKernelName;
  return std::make_pair(func_name, str_stream.str());
}
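
// The pair returned above is ("CudaCodeGen::kernel", <source>): the generated
// source wraps the resource strings and the lowered kernel (named "kernel") in
// namespace CudaCodeGen, and the qualified name is later mapped to the mangled
// symbol via nvrtcAddNameExpression/nvrtcGetLoweredName in compileKernel.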

bool validateKernelArgTensor(
    const at::Tensor& arg,
    const Val* param,
    int device_index,
    std::stringstream& msg) {
  // Arg is a tensor. Param must be a tensor too.
  if (*param->getValType() != ValType::TensorView) {
    msg << "Argument is a tensor, but the parameter is not.";
    return false;
  }

  // Check the rank of the tensors.
  size_t arg_dim = arg.dim();
  // Note: This requires the current Fusion to be active.
  size_t param_dim = TensorDomain::noReductions(
                         static_cast<const TensorView*>(param)->getRootDomain())
                         .size();
  // See [Note - broadcast support in integration]
  // Because broadcasting is handled in integration, we relax the rank check
  // as necessary.
  if (arg_dim > param_dim) {
    msg << "Argument tensor's rank is " << arg_dim
        << ", but the parameter's rank is " << param_dim;
    return false;
  }

  if (arg.device().index() != device_index) {
    msg << "Argument is on a device that the kernel was not compiled for";
    return false;
  }
  // Check element type
  at::ScalarType arg_data_type = arg.scalar_type();
  DataType param_data_type = *param->getDataType();
  bool match = false;
  switch (arg_data_type) {
    case at::ScalarType::Half:
      match = param_data_type == DataType::Half;
      break;
    case at::ScalarType::Float:
      match = param_data_type == DataType::Float;
      break;
    case at::ScalarType::Bool:
      match = param_data_type == DataType::Bool;
      break;
    default:
      msg << "Argument element type, " << arg_data_type
          << ", is not supported.";
      return false;
  }
  if (!match)
    msg << "Argument element type is " << arg_data_type
        << ", but the parameter is " << param_data_type;
  return match;
}

bool validateKernelArgScalar(
    const c10::TypePtr& arg_type,
    const Val* param,
    std::stringstream& msg) {
  if (!param->isScalar()) {
    msg << "Argument is a scalar, but the parameter is not.";
    return false;
  }
  DataType param_type = *param->getDataType();
  bool match = false;
  switch (arg_type->kind()) {
    case c10::TypeKind::IntType:
      match = param_type == DataType::Int;
      break;
    case c10::TypeKind::FloatType:
      match = param_type == DataType::Float;
      break;
    case c10::TypeKind::BoolType:
      match = param_type == DataType::Bool;
      break;
    default:
      match = false;
  }
  if (!match) {
    msg << "Argument type is " << *arg_type << ", but the parameter is "
        << param_type;
  }
  return match;
}

bool validateKernelArg(
    const c10::IValue& arg,
    const Val* param,
    int device_index,
    std::stringstream& msg) {
  if (arg.type()->kind() != c10::TypeKind::TensorType) {
    return validateKernelArgScalar(arg.type(), param, msg);
  } else {
    return validateKernelArgTensor(arg.toTensor(), param, device_index, msg);
  }
}

void validateKernelArgs(
    const CudaKernel& entry,
    const at::ArrayRef<IValue>& inputs,
    const std::vector<at::Tensor>& outputs) {
  // This is necessary because we traverse the fusion graph later in the check.
  FusionGuard fg(&entry);
  // Check inputs
  TORCH_INTERNAL_ASSERT(
      inputs.size() == entry.fusion_->inputs().size(),
      "Wrong number of kernel inputs.");
  for (size_t i = 0; i < inputs.size(); ++i) {
    const IValue& arg = inputs[i];
    const Val* param = entry.fusion_->inputs()[i];
    std::stringstream msg;
    TORCH_INTERNAL_ASSERT(
        validateKernelArg(arg, param, entry.device_, msg),
        "Input argument at position ",
        i,
        " is invalid; ",
        msg.str());
  }

  TORCH_INTERNAL_ASSERT(
      entry.fusion_->outputs().size() != 0,
      "Kernel should have at least one output tensor.");

  TORCH_INTERNAL_ASSERT(
      outputs.size() == entry.fusion_->outputs().size(),
      "Wrong number of kernel outputs.");
  for (size_t i = 0; i < outputs.size(); ++i) {
    const at::Tensor& arg = outputs[i];
    const Val* param = entry.fusion_->outputs()[i];
    std::stringstream msg;
    TORCH_INTERNAL_ASSERT(
        validateKernelArgTensor(arg, param, entry.device_, msg),
        "Output argument at position ",
        i,
        " is invalid; ",
        msg.str());
  }
}

size_t size(const dim3& d) {
  return (size_t)d.x * (size_t)d.y * (size_t)d.z;
}

dim3 dimensionOfReductionBlock(
    const dim3& block_dim,
    bool x_thread,
    bool y_thread,
    bool z_thread) {
  return dim3{x_thread ? block_dim.x : 1,
              y_thread ? block_dim.y : 1,
              z_thread ? block_dim.z : 1};
}

int sizeOfReductionBlock(
    const dim3& block_dim,
    bool x_thread,
    bool y_thread,
    bool z_thread) {
  return size(
      dimensionOfReductionBlock(block_dim, x_thread, y_thread, z_thread));
}

// Returns the total number of reduction segments.
size_t numberOfReductionSegments(
    const dim3& grid_dim,
    bool x_block,
    bool y_block,
    bool z_block) {
  return (x_block ? 1 : grid_dim.x) * (y_block ? 1 : grid_dim.y) *
      (z_block ? 1 : grid_dim.z);
}
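
// Example: with grid_dim = {8, 4, 1} and a reduction parallelized only over
// BIDx (x_block = true, y_block = z_block = false), there are 1 * 4 * 1 = 4
// independent reduction segments, one per row of blocks along y.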

std::array<size_t, 2> gridReductionTempBufferSizes(CudaKernel* entry) {
  size_t buffer_size = 0;
  size_t sync_flag_size = 0;
  for (auto expr : entry->fusion_->exprs(true)) {
    if (expr->getExprType() != ExprType::ReductionOp)
      continue;
    ReductionOp* rop = static_cast<ReductionOp*>(expr);
    auto domains = rop->getParallelReductionDomains();
    bool x_block = domains.find(ParallelType::BIDx) != domains.end();
    bool y_block = domains.find(ParallelType::BIDy) != domains.end();
    bool z_block = domains.find(ParallelType::BIDz) != domains.end();
    // No buffer needed unless it's a grid reduction
    if (!x_block && !y_block && !z_block)
      continue;
    // The assumption here is that reduction along the block-parallel domains
    // is done prior to this grid reduction, so those domains do not need to
    // participate in the grid reduction.
    bool x_thread = domains.find(ParallelType::TIDx) == domains.end();
    bool y_thread = domains.find(ParallelType::TIDy) == domains.end();
    bool z_thread = domains.find(ParallelType::TIDz) == domains.end();
    auto rb_size =
        sizeOfReductionBlock(entry->block_, x_thread, y_thread, z_thread);
    auto num_blocks = size(entry->grid_);
    auto element_size = dataTypeSize(*(rop->out()->getDataType()));
    auto required_temp_buffer_size = num_blocks * rb_size * element_size;
    buffer_size = std::max(buffer_size, required_temp_buffer_size);
    auto flag_size = sizeof(unsigned) *
        numberOfReductionSegments(entry->grid_, x_block, y_block, z_block);
    sync_flag_size = std::max(sync_flag_size, flag_size);
  }
  return {{buffer_size, sync_flag_size}};
}
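
// Worked example: grid_ = {8, 4, 1} (32 blocks), block_ = {128, 4, 1}, and a
// float reduction over BIDx and TIDx. TIDx is reduced within the block first,
// so the per-block footprint is rb_size = 1 * 4 * 1 = 4, giving a work buffer
// of 32 * 4 * 4 B = 512 B, plus 1 * 4 * 1 = 4 sync flags
// (4 * sizeof(unsigned) = 16 B).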

} // namespace

bool NaivePWKernelArgsReq::matchKernelSize(const at::ArrayRef<IValue> inputs) {
  TORCH_INTERNAL_ASSERT(
      inputs.size() == dims_.size(),
      "Wrong number of inputs fed to the generated kernel!");
  for (size_t i = 0; i < dims_.size(); i++) {
    if (inputs[i].isTensor()) {
      if (inputs[i].toTensor().dim() != dims_[i]) {
        return false;
      }
    } else {
      if (dims_[i] != -1) {
        return false;
      }
    }
  }
  return true;
}

void compileKernel(CudaKernel* entry) {
  // Generate CUDA code.
  std::string code;
  std::string func_name;
  std::tie(func_name, code) = codeGeneration(entry->fusion_.get());

  static int32_t compiled_kernel_id = 0;
  // We increment the id here instead of at the end of the function so that an
  // error during jit-compilation does not leave a confusing debug message.
  compiled_kernel_id++;
  const char* debug_env = getenv("PYTORCH_CUDA_FUSER_DEBUG");
  if (debug_env && atoi(debug_env)) {
    std::cout << "\n==== codegen output for kernel: " << compiled_kernel_id
              << " ====" << std::endl
              << code << std::endl
              << "====================================" << std::endl;
  }

  // vvv NVRTC COMPILATION vvv

  // Lazily construct the context if it does not exist yet.
  CUcontext pctx = nullptr;
  AT_CUDA_DRIVER_CHECK(nvrtc().cuCtxGetCurrent(&pctx));
  if (!pctx) {
    std::unique_lock<std::mutex> cudaFreeMutexLock(
        *(c10::cuda::CUDACachingAllocator::getFreeMutex()));
    cudaFree(nullptr);
  }

  // Set the device for the operation.
  at::cuda::set_device(entry->device_);
  entry->has_random_ = entry->fusion_->hasRNG();

  const auto prop = at::cuda::getCurrentDeviceProperties();
  int nvrtc_major, nvrtc_minor;
  AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcVersion(&nvrtc_major, &nvrtc_minor));

  // Short-circuit if the NVRTC version is too low.
  TORCH_INTERNAL_ASSERT(nvrtc_major >= 6);
  // Major and minor are determined by device properties and possibly
  // "downcompiled" to a lower (compatible) compute architecture based on the
  // NVRTC version.
  int major, minor;
  major = prop->major;
  minor = prop->minor;
  nvrtcProgram program;
  AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcCreateProgram(
      &program, code.c_str(), nullptr, 0, nullptr, nullptr));
  ResourceGuard holdProgram(
      [&] { AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcDestroyProgram(&program)); });

  const std::string compute = "--gpu-architecture=compute_" +
      std::to_string(major) + std::to_string(minor);
  const std::vector<const char*> args = {
      "--std=c++14", compute.c_str(), "-default-device"};

  nvrtc().nvrtcAddNameExpression(program, func_name.c_str());
  const auto result =
      nvrtc().nvrtcCompileProgram(program, args.size(), args.data());
  if (result != NVRTC_SUCCESS) {
    size_t logsize;
    nvrtc().nvrtcGetProgramLogSize(program, &logsize);
    std::vector<char> log(logsize);
    nvrtc().nvrtcGetProgramLog(program, log.data());

    TORCH_INTERNAL_ASSERT(
        false, code.c_str(), "\nCUDA NVRTC compile error: ", log.data());
  }
  const char* lowered_kernel_name;
  nvrtc().nvrtcGetLoweredName(program, func_name.c_str(), &lowered_kernel_name);

  AT_CUDA_NVRTC_CHECK(result);
  size_t ptx_size;
  AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetPTXSize(program, &ptx_size));
  std::vector<char> ptx;
  ptx.resize(ptx_size);
  AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetPTX(program, ptx.data()));

  // TODO: We do go through a different code path here; we should investigate
  // whether this has an impact on the generated binary.
  const char* prefix_env = getenv("PYTORCH_CUDA_FUSER_CUBIN");
  if (prefix_env) {
    // Output PTX file
    std::stringstream ptx_file_name;
    ptx_file_name << prefix_env << "_" << compiled_kernel_id << ".ptx";
    std::ofstream myPtxFile(ptx_file_name.str().c_str(), std::ios::out);
    if (myPtxFile.is_open()) {
      myPtxFile.write(ptx.data(), ptx.size());
      myPtxFile.close();
    }

    CUlinkState linkState;

    AT_CUDA_DRIVER_CHECK(nvrtc().cuLinkCreate(0, nullptr, nullptr, &linkState));
    AT_CUDA_DRIVER_CHECK(nvrtc().cuLinkAddData(
        linkState,
        CU_JIT_INPUT_PTX,
        ptx.data(),
        ptx_size,
        "compiling PTX",
        0,
        nullptr,
        nullptr));
    size_t cubinSize;
    void* cubin;
    AT_CUDA_DRIVER_CHECK(nvrtc().cuLinkComplete(linkState, &cubin, &cubinSize));

    // Output binary file
    std::stringstream cubin_file_name;
    cubin_file_name << prefix_env << "_" << compiled_kernel_id << ".cubin";
    std::ofstream myCubinFile(
        cubin_file_name.str().c_str(), std::ios::out | std::ios::binary);
    if (myCubinFile.is_open()) {
      myCubinFile.write(static_cast<const char*>(cubin), cubinSize);
      myCubinFile.close();
    }

    // Load the compiled cubin
    AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleLoadData(&(entry->module_), cubin));
  } else {
    // Load the PTX directly
    AT_CUDA_DRIVER_CHECK(
        nvrtc().cuModuleLoadData(&(entry->module_), ptx.data()));
  }
  AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleGetFunction(
      &(entry->function_), entry->module_, lowered_kernel_name));
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION < 305
  // The HIP function signature is not compatible yet.
  uint32_t max_blocks;
  AT_CUDA_DRIVER_CHECK(nvrtc().hipOccupancyMaxActiveBlocksPerMultiprocessor(
      &max_blocks, entry->function_, 128, 0));
  entry->max_blocks_ = max_blocks;
#else
  AT_CUDA_DRIVER_CHECK(nvrtc().cuOccupancyMaxActiveBlocksPerMultiprocessor(
      &entry->max_blocks_, entry->function_, 128, 0));
#endif
  entry->max_blocks_ *= prop->multiProcessorCount;
}
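
// Note on the environment variables used above: setting
// PYTORCH_CUDA_FUSER_DEBUG=1 prints the generated CUDA source for each kernel,
// and setting PYTORCH_CUDA_FUSER_CUBIN=<prefix> additionally links the PTX and
// writes <prefix>_<id>.ptx and <prefix>_<id>.cubin before loading the module.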

void runKernel(
    CudaKernel* entry,
    const at::ArrayRef<IValue> inputs,
    const std::vector<at::Tensor>& outputs,
    const std::vector<int64_t>& broadcasted_shape) {
  validateKernelArgs(*entry, inputs, outputs);

  const auto prior_device = at::cuda::current_device();
  at::cuda::set_device(entry->device_);
  auto stream = at::cuda::getCurrentCUDAStream();

  // TODO: Proper API to establish reasonable launch configurations;
  // Naive launch config;
  const size_t numel = outputs[0].numel();

  int blocks = 1;
  int thread_x = 1;
  int thread_y = 1;
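
  // Example of the heuristic below: for an input of shape [N, M] reduced over
  // its last (fastest-changing) dimension, outputs[0].dim() == 1 and
  // reduction_axes_ == {1}, so reduction_axes_.back() equals
  // outputs[0].dim() + reduction_axes_.size() - 1 and the `fcd_reduction`
  // (kFcdReductionThreadX) configuration is chosen; otherwise the 2D non-FCD
  // block shape is used.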
  if (!entry->reduction_axes_.empty()) {
    // TODO: MAJOR HACK! Expr evaluation would make launch configuration much
    // easier.
    blocks = numel;
    // Translated to `fcd_reduction`
    if (entry->reduction_axes_.back() ==
        outputs[0].dim() + ((int)entry->reduction_axes_.size()) - 1) {
      thread_x = kFcdReductionThreadX;
      thread_y = 1;
    } else {
      thread_x = kNonFcdReductionThreadX;
      thread_y = kNonFcdReductionThreadY;
    }
  } else {
    // TODO: we can't arbitrarily clamp this down until we get striding.
    blocks = ceilDiv(numel, kPwThreadX * entry->unroll_factor_);
    thread_x = kPwThreadX;
    thread_y = 1;
  }
  const auto nBlocks = blocks;
  const auto nThreadx = thread_x;
  const auto nThready = thread_y;

  KernelArgumentHolder kernel_args;

  // Naive I/O setup; this ignores potential transformations, i.e. the I/O
  // allocated here from the subgraph could be, and very likely is, different
  // from the I/O expected by the generated CUDA kernel.
  for (auto& input : inputs) {
    if (input.isTensor()) {
      kernel_args.push(input.toTensor(), broadcasted_shape);
    } else {
      kernel_args.push(input);
    }
  }

  for (auto& output : outputs) {
    kernel_args.push(output);
  }

  // TODO: this probably won't work for us.
  if (entry->has_random_) {
    std::pair<uint64_t, uint64_t> philox_engine_inputs;
    const auto rand_offset = 4 * (std::ceil(numel / (4.0 * 128 * nBlocks)) + 1);
    auto gen = at::cuda::detail::getDefaultCUDAGenerator();
    {
      // See Note [Acquire lock when using random generators]
      std::lock_guard<std::mutex> lock(gen.mutex());
      philox_engine_inputs =
          at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(
              rand_offset);
    }
    kernel_args.push(philox_engine_inputs.first);
    kernel_args.push(philox_engine_inputs.second);
  }

  // Launch the kernel.
  AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel(
      entry->function_,
      nBlocks,
      1,
      1,
      nThreadx,
      nThready,
      1,
      0,
      stream,
      kernel_args.getBuffer(),
      nullptr));

  // Reset the device (see at::DeviceGuard notes above).
  at::cuda::set_device(prior_device);
}

// WARNING:
// This function is here for testing purposes only.
void runTestKernel(
    CudaKernel* entry,
    const at::ArrayRef<IValue> inputs,
    const std::vector<at::Tensor>& outputs) {
  validateKernelArgs(*entry, inputs, outputs);

  const auto prior_device = at::cuda::current_device();
  at::cuda::set_device(entry->device_);
  auto stream = at::cuda::getCurrentCUDAStream();

  // TODO: Proper API to establish reasonable launch configurations;
  // Naive launch config;
  TORCH_INTERNAL_ASSERT(!outputs.empty(), "No outputs set for test kernel.");
  size_t numel = outputs[0].numel();

  // TODO: we can't arbitrarily clamp this down until we get striding.
  const auto nBlocks = ceilDiv(numel, 128 * entry->unroll_factor_);

  KernelArgumentHolder kernel_args;

  auto exprs = entry->fusion_->exprs(true);

  // Naive I/O setup; this ignores potential transformations, i.e. the I/O
  // allocated here from the subgraph could be, and very likely is, different
  // from the I/O expected by the generated CUDA kernel.
  for (auto& input : inputs) {
    if (input.isTensor()) {
      TORCH_INTERNAL_ASSERT(
          input.toTensor().device().index() == entry->device_,
          "input to kernel is on a device that the kernel was not compiled for");
      TORCH_INTERNAL_ASSERT(
          !entry->fusion_->outputs().empty(),
          "No output found for this kernel, aborting.");
      kernel_args.push(input.toTensor());
    } else {
      kernel_args.push(input);
    }
  }

  for (auto& output : outputs) {
    kernel_args.push(output);
  }

  // TODO: this probably won't work for us.
  if (entry->has_random_) {
    std::pair<uint64_t, uint64_t> philox_engine_inputs;
    const auto rand_offset = 4 * (std::ceil(numel / (4.0 * 128 * nBlocks)) + 1);
    auto gen = at::cuda::detail::getDefaultCUDAGenerator();
    {
      // See Note [Acquire lock when using random generators]
      std::lock_guard<std::mutex> lock(gen.mutex());
      philox_engine_inputs =
          at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(
              rand_offset);
    }
    kernel_args.push(philox_engine_inputs.first);
    kernel_args.push(philox_engine_inputs.second);
  }

  // When the kernel has global reductions, it needs two additional temporary
  // buffers: one for intermediate results and another for synchronization
  // among thread blocks.
  if (entry->fusion_->hasGridReduction()) {
    auto temp_buf_type = at::kFloat;
    auto temp_buf_sizes = gridReductionTempBufferSizes(entry);
    auto options =
        at::TensorOptions().dtype(temp_buf_type).device(at::kCUDA, 0);
    at::Tensor reduction_work_buffer = at::empty(
        {(long)(temp_buf_sizes[0] / c10::elementSize(temp_buf_type))}, options);
    kernel_args.push(reduction_work_buffer);
    at::Tensor sync_flags = at::zeros(
        {(long)(temp_buf_sizes[1] / c10::elementSize(temp_buf_type))}, options);
    kernel_args.push(sync_flags);
  }

  // Launch the kernel.
  AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel(
      entry->function_,
      entry->grid_.x,
      entry->grid_.y,
      entry->grid_.z,
      entry->block_.x,
      entry->block_.y,
      entry->block_.z,
      0,
      stream,
      kernel_args.getBuffer(),
      nullptr));

  // Reset the device (see at::DeviceGuard notes above).
  at::cuda::set_device(prior_device);
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch