pytorch/torch/csrc/jit/codegen/fuser/codegen.cpp
jjsjann123 7b419e8513 [NVFuser] Upstream push 1026 (#87779)
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Codegen changes include:

* codegen improvement:
    i. allow non-root trivial reductions, allow empty/no-op fusion
    ii. fixes vectorization checks and size calculation
    iii. bank conflict handling improvements
    iv. enables transpose scheduler

* misc:
    i. CI tests failure fixes
    ii. cpp tests file clean up
    iii. trivial forwarding support added to the codegen runtime
    iv. added factory methods support in codegen

Commits from the devel branch that are included in this PR:

```
7117a7e37ebec372d9e802fdfb8abb7786960f4a patching nvfuser conv cudnn test numerics mismatch (#2048)
65af1a4e7013f070df1ba33701f2d524de79d096 Inserting sync for redundant parallel types is already done at the (#2023)
6ac74d181689c8f135f60bfc1ec139d88941c98c Fix sync map (#2047)
f5bca333355e2c0033523f3402de5b8aac602c00 Bank conflict checker improvements (#2032)
d2ca7e3fd203537946be3f7b435303c60fa7f51e Minor update on cp.async code generation. (#1901)
d36cf61f5570c9c992a748126287c4e7432228e0 Test file cleanup (#2040)
0b8e83f49c2ea9f04a4aad5061c1e7f4268474c6 Allow non-root trivial reductions (#2037)
a2dfe40b27cd3f5c04207596f0a1818fbd5e5439 Fix vectorize size calculation (#2035)
e040676a317fe34ea5875276270c7be88f6eaa56 Use withPredicate to replace setPredicate to maintain Exprs immutable (#2025)
197221b847ad5eb347d7ec1cf2706733aacbf97c removing ci workflow (#2034)
40e2703d00795526e7855860aa00b9ab7160755f Reduction rand like patch (#2031)
bc772661cbdb3b711d8e9854ae9b8b7052e3e4a3 Add utility for checking bank conflict of shared memory (#2029)
ddd1cf7695f3fb172a0e4bcb8e4004573617a037 Add back FusionReductionWithTrivialReduction_CUDA (#2030)
fbd97e5ef15fa0f7573800e6fbb5743463fd9e57 Revert "Cleanup trivial reduction workarounds (#2006)" (#2024)
bca20c1dfb8aa8d881fc7973e7579ce82bc6a894 Cleanup trivial reduction workarounds (#2006)
e4b65850eee1d70084105bb6e1f290651adde23e Trivial forwarding (#1995)
1a0e355b5027ed0df501989194ee8f2be3fdd37a Fix contiguity analysis of predicates to match updated contiguity. (#1991)
a4effa6a5f7066647519dc56e854f4c8a2efd2a7 Enable output allocation cache (#2010)
35440b7953ed8da164a5fb28f87d7fd760ac5e00 Patching bn inference (#2016)
0f9f0b4060dc8ca18dc65779cfd7e0776b6b38e8 Add matmul benchmark (#2007)
45045cd05ea268f510587321dbcc8d7c2977cdab Enable tests previously disabled due to an aliasing bug (#2005)
967aa77d2c8e360c7c01587522eec1c1d377c87e Contiguous indexing for View operations (#1990)
a43cb20f48943595894e345865bc1eabf58a5b48 Make inlining even more modular (#2004)
dc458358c0ac91dfaf4e6655a9b3fc206fc0c897 Test util cleanup (#2003)
3ca21ebe4d213f0070ffdfa4ae5d7f6cb0b8e870 More strict validation (#2000)
a7a7d573310c4707a9f381831d3114210461af01 Fix build problem (#1999)
fc235b064e27921fa9d6dbb9dc7055e5bae1c222 Just fixes comments (#1998)
482386c0509fee6edb2964c5ae72074791f3e43a cleanup (#1997)
4cbe0db6558a82c3097d281eec9c85ad2ea0893a Improve divisible split detection (#1970)
42ccc52bdc18bab0330f4b93ed1399164e2980c9 Minor build fix. (#1996)
fcf8c091f72d46f3055975a35afd06263324ede6 Cleanup of lower_utils.cpp: Isolate out GpuLower usage (#1989)
15f2f6dba8cbf408ec93c344767c1862c30f7ecc Move ConcretizedBroadcastDomains to shared_ptr in GpuLower. (#1988)
8f1c7f52679a3ad6acfd419d28a2f4be4a7d89e2 Minor cleanup lower_unroll.cpp (#1994)
1d9858c80319ca7f0037db7de5f04e47f540d76c Minor cleanup (#1992)
f262d9cab59f41c669f53799c6d4a6b9fc4267eb Add support for uniform RNG (#1986)
eb1dad10c73f855eb1ecb20a8b1f7b6edb0c9ea3 Remove non-const functions, remove GpuLower instance on build, pass in ca_map. (#1987)
634820c5e3586c0fe44132c51179b3155be18072 Add support for some empty fusion (#1981)
eabe8d844ad765ee4973faa4821d451ef71b83c3 Segment self mapping fusions (#1954)
e96aacfd9cf9b3c6d08f120282762489bdf540c8 Enable Transpose operation (#1882)
425dce2777420248e9f08893765b5402644f4161 Add a null scheduler that helps segmenting away no-op schedules (#1835)
306d4a68f127dd1b854b749855e48ba23444ba60 Fix canScheduleCompileTime check of transpose scheduler (#1969)
b1bd32cc1b2ae7bbd44701477bddbcfa6642a9be Minor fix (#1967)
bd93578143c1763c1e00ba613a017f8130a6b989 Enable transpose scheduler (#1927)
b7a206e93b4ac823c791c87f12859cf7af264a4c Move scheduler vectorize utilities into their own file (#1959)
d9420e4ca090489bf210e68e9912bb059b895baf View scheduling (#1928)
c668e13aea0cf21d40f95b48e0163b812712cdf2 Upstream push ci fixes (#1965)
c40202bb40ce955955bb97b12762ef3b6b612997 Fix dump effective bandwidth (#1962)
93505bcbb90a7849bd67090fe5708d867e8909e4 WAR on index mapping when exact and permissive maps differ (#1960)
45e95fd1d3c773ee9b2a21d79624c279d269da9f Allow splitting inner-most ID to create virtual innermost ID in transpose scheduler (#1930)
a3ecb339442131f87842eb56955e4f17c544e99f Improve the comments at the beginning of index_compute.h (#1946)
f7bc3417cc2923a635042cc6cc361b2f344248d6 Remove unused variables (#1955)
df3393adbb5cb0309d091f358cfa98706bd4d313 Some cleanup (#1957)
7d1d7c8724ab5a226fad0f5a80feeac04975a496 TVDomainGuard factory (#1953)
357ba224c0fb41ed3e4e8594d95599c973f4a0ca Fill allocation with nan on tests (#1956)
8eafc54685d406f5ac527bcbacc475fda4492d7a Fix detection of unmappable root domains (#1952)
90a51f282601ba8ebd4c84b9334efd7762a234bc Some indexing cleanups, Add eye support (#1940)
ddc01e4e16428aec92f9c84d698f959b6436a971 Exclude unsupported data types (#1951)
992e17c0688fe690c51b50e81a75803621b7e6aa test the groups the same order as they are merged (#1949)
208262b75d1fed0597a0329d61d57bc8bcd7ff14 Move detection of self mapping IDs to IterDomainGraph from (#1941)
ac4de38c6ee53b366e85fdfe408c3642d32b57df Merge pull request #1945 from csarofeen/master_merge_0828
631094891a96f715d8c9925fb73d41013ca7f2e3 Add full, full_like, zeros, zeros_like, ones, ones_like (#1943)
aab10bce4541204c46b91ff0f0ed9878aec1bfc4 Merge remote-tracking branch 'upstream/viable/strict' into HEAD
4c254c063bb55887b45677e3812357556a7aa80d Fix arange when step is negative (#1942)
89330aa23aa804340b2406ab58899d816e3dc3d2 Tensor factories must set the output shape as its input (#1939)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D40869846](https://our.internmc.facebook.com/intern/diff/D40869846)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87779
Approved by: https://github.com/davidberard98
2022-11-04 20:04:34 +00:00

710 lines
24 KiB
C++

#include <torch/csrc/jit/codegen/fuser/codegen.h>
#include <ATen/ATen.h>
#include <ATen/code_template.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/codegen/fuser/compiler.h>
#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/csrc/jit/codegen/fuser/tensor_info.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/codegen/fuser/cpu/resource_strings.h>
#include <torch/csrc/jit/codegen/fuser/cuda/resource_strings.h>
#include <cmath>
#include <cstdint>
#include <iomanip>
#include <iostream>
#include <limits>
#include <locale>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>
namespace torch {
namespace jit {
namespace fuser {
// Template for computing the offset into the tensor to access a value.
// It is instantiated once per dimension ${d} of ${tensor} (see
// emitIndexingFor). The emitted code extracts the index along dimension ${d}
// from ${tensor}_linearIndex, reduced by ${mod_sizes} (a "% size" suffix,
// empty for the outermost dimension). It then adds that index, scaled by
// ${times_stride} (a "* stride" suffix, empty for a contiguous innermost
// dimension), to ${tensor}_offset.
// The commented-out printf below is part of the emitted text; it ends up as a
// comment in the generated kernel.
static auto dim_calc = at::jit::CodeTemplate(R"(
//printf("tensor ${tensor} sizes[${d}] = %d, strides[${d}] = %d\n", ${tensor}.sizes[${d}],${tensor}.strides[${d}]);
size_t ${tensor}_dimIndex${d} = ${tensor}_linearIndex ${mod_sizes};
${tensor}_offset += ${tensor}_dimIndex${d} ${times_stride};
)");
// Returns the identifier used for an IR value in generated kernel code:
// the letter "n" followed by the value's unique id.
static std::string valueName(const Value* n) {
  std::string identifier = "n";
  identifier += c10::to_string(n->unique());
  return identifier;
}
static std::string scalarValue(const int64_t v) {
return c10::to_string(v);
}
static std::string scalarValue(const bool v) {
return c10::to_string(v);
}
// Spells a floating-point constant as a literal for the generated kernel.
//
// Note: The NAN, NEG_INFINITY and POS_INFINITY strings map to device-specific
// implementations of these special values. These macros are found in the
// resource strings for each device.
//
// Finite values are printed with max_digits10 significant digits, so the
// emitted literal parses back to exactly the same double. A precision of 16
// is not enough; for example, 0.1 + 0.2 would be emitted as "0.3", which is a
// different double. The stream is imbued with the classic "C" locale so that
// the decimal separator is always '.', whatever the process-wide locale is.
static std::string scalarValue(const double v) {
  if (std::isnan(v)) {
    return "NAN";
  }
  if (std::isinf(v)) {
    return v < 0 ? "NEG_INFINITY" : "POS_INFINITY";
  }
  std::ostringstream out;
  out.imbue(std::locale::classic());
  out << std::setprecision(std::numeric_limits<double>::max_digits10) << v;
  return out.str();
}
// Returns the C/CUDA type name used to declare storage of the given element
// type in generated kernels (TensorInfo<...> formals and the per-tensor
// vectorization buffers).
// Note: Half is special-cased to avoid returning at::Half; BFloat16 is
// likewise special-cased to return the "__nv_bfloat16" spelling. Both checks
// must stay ahead of the switch, because the macro below also emits cases for
// these types.
// Throws std::runtime_error for scalar types not covered by
// AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS.
static const char* scalarTypeName(const at::ScalarType type) {
  if (type == at::ScalarType::Half) {
    return "half";
  }
  if (type == at::ScalarType::BFloat16) {
    return "__nv_bfloat16";
  }
  switch (type) {
    // Expands to one `case` per scalar type, returning the stringized C++
    // type that the macro associates with it.
#define DEFINE_CASE(ctype, name) \
  case at::ScalarType::name:     \
    return #ctype;
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
#undef DEFINE_CASE
    default:
      throw std::runtime_error("unknown scalar type");
  }
}
// Returns the type used for arithmetic on values of the given element type.
// Half and BFloat16 are computed in float (inputs are converted on load and
// outputs converted back on store); every other type is computed in its own
// storage type.
static const char* calcScalarTypeName(const at::ScalarType type) {
  switch (type) {
    case at::ScalarType::Half:
    case at::ScalarType::BFloat16:
      return "float";
    default:
      return scalarTypeName(type);
  }
}
// Maps a JIT value type to the C type used to declare it in generated code.
// Int, float and bool scalars use fixed C types. Tensors use the compute type
// of their element type (see calcScalarTypeName).
// Throws std::runtime_error when a tensor's scalar type is unknown, which
// indicates a problem in shape/type propagation.
static std::string variableType(const c10::Type& t) {
  switch (t.kind()) {
    case TypeKind::IntType:
      return "int64_t";
    case TypeKind::FloatType:
      return "double";
    case TypeKind::BoolType:
      return "bool";
    default:
      break;
  }
  if (auto scalar_type = t.expectRef<TensorType>().scalarType()) {
    return calcScalarTypeName(*scalar_type);
  }
  // something went wrong with the type analysis during shape propagation
  throw std::runtime_error(
      "unknown scalar type during JIT fusion code generation");
}
// Returns `vn`, wrapped in a C-style cast to the compute type of `outtype`
// when the operand's type `t` requires it:
//  - int/bool scalars: cast unless `outtype` is a non-bool integral type;
//  - float scalars: always cast (see note below);
//  - None: never cast (used for optional arguments such as memory format);
//  - tensors: cast only when their scalar type differs from `outtype`.
// Throws std::runtime_error when a tensor operand's scalar type is unknown.
static std::string typeCastedValueName(
    const c10::Type& t,
    const at::ScalarType outtype,
    const std::string& vn) {
  // Built lazily so calcScalarTypeName is only consulted on paths that cast.
  const auto cast_to_out = [&]() {
    return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
  };
  const auto kind = t.kind();
  if (kind == TypeKind::IntType || kind == TypeKind::BoolType) {
    return isIntegralType(outtype, /*includeBool=*/false) ? vn : cast_to_out();
  }
  if (kind == TypeKind::FloatType) {
    // Scalars carry no float/double distinction, while tensor scalar types
    // do, so a cast is conservatively inserted; it is a no-op when the
    // tensor's scalar type is `double`.
    return cast_to_out();
  }
  if (kind == TypeKind::NoneType) {
    return vn;
  }
  if (auto scalar_type = t.expectRef<TensorType>().scalarType()) {
    return (*scalar_type == outtype) ? vn : cast_to_out();
  }
  // something went wrong with the type analysis during shape propagation
  throw std::runtime_error(
      "unknown scalar type during JIT fusion code generation");
}
// Writes the RHS of "special handling" ops that cannot be expressed as a
// single fixed template. Currently this is only aten::clamp, whose min/max
// operands may each be None.
//
// Note: It may seem unusual to have the bounds as the first case of each
// ternary. This is so that if min or max is NaN, it is "ignored" (every
// comparison against NaN is false), and when the input is NaN, both
// comparisons fail and the output is NaN, too.
//
// Throws std::runtime_error if the op is not supported, or if both clamp
// bounds are None. Without that check, the min-is-None branch would emit a
// reference to the None value, which is never declared in the generated
// kernel.
static std::string encodeSpecialRHS(const Node* n, at::jit::TemplateEnv& env) {
  if (n->kind() != aten::clamp) {
    throw std::runtime_error("Cannot encode RHS of the node, op not supported");
  }
  const auto min = n->input(1);
  const auto max = n->input(2);
  const bool min_is_none = min->node()->mustBeNone();
  const bool max_is_none = max->node()->mustBeNone();
  if (min_is_none && max_is_none) {
    throw std::runtime_error(
        "At least one of 'min' or 'max' must not be None");
  }
  env.s("0", valueName(n->input(0)));
  if (min_is_none) {
    // Upper bound only.
    env.s("1", valueName(max));
    return format("(${0} > ${1} ? ${1} : ${0})", env);
  }
  if (max_is_none) {
    // Lower bound only.
    env.s("1", valueName(min));
    return format("(${0} < ${1} ? ${1} : ${0})", env);
  }
  // Both bounds present.
  env.s("1", valueName(min));
  env.s("2", valueName(max));
  return format("(${0} < ${1} ? ${1} : (${0} > ${2}? ${2} : ${0}))", env);
}
// Per-operator code template for "simple mappable" aten:: ops. Two spellings
// are kept so that encodeRHS can pick the single-precision math function
// (e.g. logf()) when the result is a float value, and the other spelling
// (e.g. log()) for every other result type.
struct RHSTemplate {
  // Distinct float and double spellings.
  RHSTemplate(const char* for_float, const char* for_double)
      : for_float(for_float), for_double(for_double) {}
  // Common case: one spelling serves both precisions. Intentionally implicit
  // so that table entries can be written as {kind, "template"}.
  RHSTemplate(const char* for_float) : RHSTemplate(for_float, for_float) {}

  const char* for_float;
  const char* for_double;
};
// Writes the RHS expression for "simple mappable" ops, i.e. ops whose
// generated code is a fixed per-op template (see RHSTemplate). Ops that are
// not in the table are forwarded to encodeSpecialRHS.
//
// Template placeholders:
//  - ${i}: the i-th input, cast (when needed) to the output's compute type;
//  - ${i_nocast}: the raw i-th input value name (comparison operators only).
// The float spelling of a template is used when the output scalar type is
// float; every other output type uses the double spelling.
static std::string encodeRHS(const Node* n) {
  // Built once on first use and never modified afterwards.
  static const std::unordered_map<NodeKind, RHSTemplate> simple_map_ops = {
      // unary
      {aten::_cast_Float, "static_cast<float>(${0})"},
      {aten::abs, "fabs(${0})"},
      {aten::sigmoid, {"1.f / (1.f + expf(-${0}))", "1. / (1. + exp(-${0}))"}},
      {aten::relu, "${0} < 0 ? 0.f : ${0} "},
      {aten::threshold,
       "${0} <= ${1} ? static_cast<decltype(${0})>(${2}) : ${0} "},
      {aten::log, {"logf(${0})", "log(${0})"}},
      {aten::log10, {"log10f(${0})", "log10(${0})"}},
      {aten::log1p, {"log1pf(${0})", "log1p(${0})"}},
      {aten::log2, {"log2f(${0})", "log2(${0})"}},
      {aten::lgamma, {"lgammaf(${0})", "lgamma(${0})"}},
      {aten::exp, {"expf(${0})", "exp(${0})"}},
      {aten::expm1, {"expm1f(${0})", "expm1(${0})"}},
      {aten::erf, {"erff(${0})", "erf(${0})"}},
      {aten::erfc, {"erfcf(${0})", "erfc(${0})"}},
      {aten::cos, {"cosf(${0})", "cos(${0})"}},
      {aten::acos, {"acosf(${0})", "acos(${0})"}},
      {aten::cosh, {"coshf(${0})", "cosh(${0})"}},
      {aten::sin, {"sinf(${0})", "sin(${0})"}},
      {aten::asin, {"asinf(${0})", "asin(${0})"}},
      {aten::sinh, {"sinhf(${0})", "sinh(${0})"}},
      {aten::tan, {"tanf(${0})", "tan(${0})"}},
      {aten::atan, {"atanf(${0})", "atan(${0})"}},
      {aten::tanh, {"tanhf(${0})", "tanh(${0})"}},
      {aten::sqrt, {"sqrtf(${0})", "sqrt(${0})"}},
      {aten::rsqrt, {"rsqrtf(${0})", "rsqrt(${0})"}},
      {aten::ceil, {"ceilf(${0})", "ceil(${0})"}},
      {aten::floor, {"floorf(${0})", "floor(${0})"}},
      {aten::round, {"roundf(${0})", "round(${0})"}},
      {aten::trunc, {"truncf(${0})", "trunc(${0})"}},
      {aten::frac, {"${0} - truncf(${0})", "${0} - trunc(${0})"}},
      {aten::reciprocal, {"1.f/(${0})", "1./(${0})"}},
      {aten::neg, "-${0}"},
      // simple binary
      {aten::atan2, "atan2(${0}, ${1})"},
      {aten::min,
       "isnan(${0}) ? ${0} : (isnan(${1}) ? ${1} : (${0} < ${1} ? ${0} : ${1}))"},
      {aten::max,
       "isnan(${0}) ? ${0} : (isnan(${1}) ? ${1} : (${0} < ${1} ? ${1} : ${0}))"},
      // binary with other
      // TODO: some of these ops will not get generated because
      // we only work on float inputs/outputs, but they are here to record
      // that they are valid mappable ops once we handle more type
      {aten::__and__, "${0} && ${1}"},
      {aten::__lshift__, "${0} << ${1}"},
      {aten::__or__, "${0} || ${1}"},
      {aten::__rshift__, "${0} >> ${1}"},
      {aten::__xor__, "${0} ^ ${1}"},
      {aten::addcmul, "${0} + ${3} * ${1} * ${2}"},
      {aten::div, "${0} / ${1}"},
      {aten::eq, "${0_nocast} == ${1_nocast}"},
      // fmod needs a double-precision spelling: fmodf would silently narrow
      // double operands to float.
      {aten::fmod, {"fmodf(${0}, ${1})", "fmod(${0}, ${1})"}},
      {aten::ge, "(${0_nocast} >= ${1_nocast})"},
      {aten::gt, "${0_nocast} > ${1_nocast}"},
      {aten::le, "(${0_nocast} <= ${1_nocast})"},
      {aten::lt, "${0_nocast} < ${1_nocast}"},
      {aten::lerp, "${0} + ${2} * (${1} - ${0})"},
      {aten::type_as, "(${0})"},
      {aten::mul, "${0} * ${1}"},
      {aten::ne, "${0_nocast} != ${1_nocast}"},
      {aten::remainder, "fmod((${1} + fmod(${0}, ${1})), ${1})"},
      {aten::pow, {"powf(${0}, ${1})", "pow(${0}, ${1})"}},
      // alpha
      {aten::add, "${0} + ${2}*${1}"},
      {aten::sub, "(${0} - ${2}*${1})"},
      {aten::rand_like, "uniform(rnd())"},
      // where
      {aten::where, "(${0} ? ${1} : ${2})"},
  };
  at::jit::TemplateEnv env;
  const auto templ_it = simple_map_ops.find(n->kind());
  if (templ_it == simple_map_ops.end()) {
    return encodeSpecialRHS(n, env);
  }
  auto outtype = n->output()->type()->expectRef<TensorType>().scalarType();
  TORCH_INTERNAL_ASSERT(outtype);
  size_t i = 0;
  for (auto in : n->inputs()) {
    // PyTorch converts (scalar) argument types to result before applying the
    // operator e.g. 1.4-torch.tensor(3) = -2
    env.s(
        c10::to_string(i),
        typeCastedValueName(*in->type(), *outtype, valueName(in)));
    // Uncasted operands only used for comparison operators
    env.s(c10::to_string(i) + "_nocast", valueName(in));
    i++;
  }
  const RHSTemplate& templ = templ_it->second;
  const char* str =
      (*outtype == at::kFloat) ? templ.for_float : templ.for_double;
  AT_ASSERT(str);
  return format(str, env);
}
// Emits code that converts the flat `linearIndex` into `${tensor}_offset`.
// Dimensions are peeled off innermost-first: each step takes the index along
// the dimension (modulo its size, except for the outermost dimension), scales
// it by the dimension's stride, and divides the remaining linear index by the
// size. The multiplication by the innermost stride is omitted when
// `last_is_cont` is set (the stride is then presumably 1).
static void emitIndexingFor(
    std::ostream& out,
    const std::string& tensor,
    const int ndim,
    const bool last_is_cont) {
  at::jit::TemplateEnv env;
  env.s("tensor", tensor);
  out << format("IndexType ${tensor}_offset = 0;\n", env);
  out << format("IndexType ${tensor}_linearIndex = linearIndex;\n", env);
  for (int dim = ndim - 1; dim >= 0; --dim) {
    env.d("d", dim);
    const bool is_outermost = (dim == 0);
    const bool stride_is_implicit = (dim == ndim - 1) && last_is_cont;

    std::string mod_sizes;
    if (!is_outermost) {
      mod_sizes = format("% ${tensor}.sizes[${d}]", env);
    }
    env.s("mod_sizes", mod_sizes);

    std::string times_stride;
    if (!stride_is_implicit) {
      times_stride = format("* ${tensor}.strides[${d}]", env);
    }
    env.s("times_stride", times_stride);

    out << dim_calc.format(env);
    if (!is_outermost) {
      out << format("${tensor}_linearIndex /= ${tensor}.sizes[${d}];\n", env);
    }
  }
}
// Emits the per-tensor declarations and checks that decide whether the
// 4-wide vectorized path may be used (by clearing `flag_vec4`).
// A 4-element staging buffer `${tensor}_buf` is always declared. Vectorization
// is ruled out statically when the innermost dimension is not contiguous, or
// when elements are wider than 4 bytes (disabled for performance). Otherwise,
// runtime checks are emitted: the innermost size and every outer stride must
// be multiples of 4, and the data pointer must be aligned to 4 elements.
static void emitCheckFor(
    std::ostream& out,
    const std::string& tensor,
    const int ndim,
    const TensorDesc& desc) {
  at::jit::TemplateEnv env;
  env.s("tensor", tensor);
  env.s("scalar_type", scalarTypeName(desc.scalar_type));
  // allocate buffer to load 4
  out << format("${scalar_type} ${tensor}_buf[4];\n", env);

  const bool vectorizable =
      desc.lastIsContiguous() && at::elementSize(desc.scalar_type) <= 4;
  if (!vectorizable) {
    out << "flag_vec4 = false;\n";
    return;
  }

  // Innermost dimension: its stride was already checked at compile time, so
  // only its size needs to be a multiple of 4.
  if (ndim > 0) {
    env.d("d", ndim - 1);
    out << format(
        "if(${tensor}.sizes[${d}] % 4 != 0) flag_vec4 = false;\n", env);
  }
  // Outer dimensions: strides must be multiples of 4.
  for (int dim = ndim - 2; dim >= 0; --dim) {
    env.d("d", dim);
    out << format(
        "if(${tensor}.strides[${d}] % 4 != 0) flag_vec4 = false;\n", env);
  }
  // pointer aligned
  out << format(
      "if(((uint64_t) ${tensor}.data) % (4 * sizeof(${scalar_type})) != 0) flag_vec4 = false;\n",
      env);
}
// Generates the source code of a fused pointwise kernel for `graph`.
//
// name: kernel name substituted into the compilation-unit template.
// inputs: graph inputs, each paired with a TensorDesc for tensor inputs or an
//   empty optional for scalar inputs (which become scalar formals).
// outputs: graph outputs paired with their TensorDesc.
// use_cuda: selects the CUDA compilation-unit template instead of the CPU
//   one. It is required (asserted) for half/bfloat16 tensors and rand_like.
//
// The body is emitted twice: a scalar path (`kernelBody`) and a 4-wide path
// (`kernelBody_vec4`). The 4-wide path works on per-tensor staging buffers
// that `kernelLoad`/`kernelStore` fill and drain. emitCheckFor produces the
// checks that clear `flag_vec4`; presumably the template uses that flag to
// choose between the two paths at runtime (confirm in resource_strings.h).
// TODO: handle cases where we need to generate > 2^32 element tensors
std::string generateKernel(
    const std::string& name,
    const Graph& graph,
    const std::vector<std::pair<const Value*, const c10::optional<TensorDesc>>>&
        inputs,
    const std::vector<std::pair<const Value*, const TensorDesc>>& outputs,
    const bool use_cuda) {
  at::jit::TemplateEnv env;
  env.s("kernelName", name);
  env.s(
      "IndexType",
      "unsigned int"); // Note: not uint32_t to avoid including cstdint
  std::stringstream tensorChecks;
  std::stringstream body;
  std::stringstream body_vec4;
  std::stringstream load;
  std::stringstream store;
  std::stringstream tensorOffsets;
  std::vector<std::string> formals;
  std::vector<std::string> argument_loads;

  // Lambda for writing tensor arguments: registers the kernel formal, the
  // matching load from `args`, and the tensor's vectorization checks and
  // offset computation.
  auto emitFormal = [&](const Value* n, const TensorDesc& desc) {
    // + 1 because the first argument is the linearIndex
    env.d("formal_index", formals.size() + 1);
    // can't be unique() because Param may be an output
    std::string tensor = "t" + c10::to_string(formals.size());
    const auto nDim = desc.nDim();
    emitCheckFor(tensorChecks, tensor, nDim, desc);
    emitIndexingFor(tensorOffsets, tensor, nDim, desc.lastIsContiguous());
    env.s("tensor", tensor);
    env.d("nDim", nDim);
    env.s("scalar_type", scalarTypeName(desc.scalar_type));
    formals.push_back(
        format("const TensorInfo<${scalar_type},${nDim}> ${tensor}", env));
    argument_loads.push_back(format(
        "*static_cast<TensorInfo<${scalar_type},${nDim}>*>(args[${formal_index}])",
        env));
  };
  // Lambda for writing scalar arguments (formal plus load from `args`).
  auto emitScalarFormal = [&](const Value* n) {
    // + 1 because the first argument is the linearIndex
    env.d("formal_index", formals.size() + 1);
    // can't be unique() because Param may be an output
    std::string scalar = "s" + c10::to_string(formals.size());
    env.s("scalar", scalar);
    env.s("scalar_type", variableType(*n->type()));
    formals.push_back(format("${scalar_type} ${scalar}", env));
    argument_loads.push_back(
        format("*static_cast<${scalar_type}*>(args[${formal_index}])", env));
  };

  // Writes input parameters
  for (const auto& input : inputs) {
    if (input.second.has_value()) {
      emitFormal(input.first, *input.second);
    } else {
      emitScalarFormal(input.first);
    }
  }

  // Writes output parameters
  for (const auto& output : outputs) {
    emitFormal(output.first, output.second);
  }

  // Acquires input values
  bool has_half_tensor = false;
  bool has_bfloat_tensor = false;
  size_t formal_count = 0;
  for (const auto& input : inputs) {
    auto p = input.first;
    env.s("node", valueName(p));
    env.d("formal", formal_count++);

    // Acquires and converts (if needed) inputs
    // Note: conversion from half is only supported for CUDA kernels.
    // The conversion immediately converts fp16 inputs to float.
    // Access for other types is common to CUDA and CPU kernels.
    if (input.second.has_value()) {
      const TensorDesc& desc = *input.second;
      const auto is_half = (desc.scalar_type == at::ScalarType::Half);
      const auto is_bfloat = (desc.scalar_type == at::ScalarType::BFloat16);
      const auto is_bool = (desc.scalar_type == at::ScalarType::Bool);
      if (is_half) {
        AT_ASSERT(use_cuda);
        env.s(
            "access",
            format("__half2float(t${formal}.data[t${formal}_offset])", env));
        env.s("access_vec4", format("__half2float(t${formal}_buf[i])", env));
        has_half_tensor = true;
      } else if (is_bfloat) {
        AT_ASSERT(use_cuda);
        env.s(
            "access",
            format(
                "__bfloat162float(t${formal}.data[t${formal}_offset])", env));
        env.s(
            "access_vec4", format("__bfloat162float(t${formal}_buf[i])", env));
        has_bfloat_tensor = true;
      } else if (use_cuda) {
        // No __ldg overload for bool
        if (is_bool) {
          env.s("access", format("t${formal}.data[t${formal}_offset]", env));
        } else {
          env.s(
              "access",
              format("__ldg(&t${formal}.data[t${formal}_offset])", env));
        }
        env.s("access_vec4", format("t${formal}_buf[i]", env));
      } else {
        env.s("access", format("t${formal}.data[t${formal}_offset]", env));
        env.s("access_vec4", format("t${formal}_buf[i]", env));
      }
      env.s("lhs_type", calcScalarTypeName(desc.scalar_type));

      // load input in vectorized code path: 4 elements are moved as a single
      // vector-typed access whose width matches 4 * element size.
      auto ele_size = at::elementSize(desc.scalar_type);
      if (ele_size == 1) {
        env.s(
            "load4",
            format(
                "*(reinterpret_cast<float*>(t${formal}_buf)) = *(reinterpret_cast<float*>(t${formal}.data + t${formal}_offset))",
                env));
      } else if (ele_size == 2) {
        env.s(
            "load4",
            format(
                "*(reinterpret_cast<float2*>(t${formal}_buf)) = *(reinterpret_cast<float2*>(t${formal}.data + t${formal}_offset))",
                env));
      } else if (ele_size == 4) {
        env.s(
            "load4",
            format(
                "*(reinterpret_cast<float4*>(t${formal}_buf)) = *(reinterpret_cast<float4*>(t${formal}.data + t${formal}_offset))",
                env));
      } else {
        env.s(
            "load4",
            format(
                "for(int i = 0; i<4; i++) t${formal}_buf[i] = t${formal}.data[t${formal}_offset + i]",
                env));
      }
      load << format("${load4};\n", env);
    } else {
      env.s("access", format("s${formal}", env));
      env.s("access_vec4", format("s${formal}", env));
      env.s("lhs_type", variableType(*input.first->type()));
    }
    body << format("${lhs_type} ${node} = ${access};\n", env);
    body_vec4 << format("${lhs_type} ${node} = ${access_vec4};\n", env);
  }

  bool has_random = false;
  // Generates code for intermediate nodes
  // Note: Concat and Chunk are implicitly generated
  // Note: Random number generation is only supported for CUDA kernels.
  // Note: Constant None node is ignored and we will handle it in the
  //       places where the constant None node is used
  // Note: No need to iterate over reference as n is a pointer
  for (const auto n : graph.nodes()) {
    static_assert(std::is_pointer<decltype(n)>::value, "n must be a pointer");
    // Note: FusedConcat nodes work by narrowing the output Tensors before the
    // kernel runs
    if (n->kind() == prim::FusedConcat)
      continue;
    if (n->kind() == prim::ConstantChunk)
      continue;
    if (n->mustBeNone())
      continue;
    if (n->kind() == aten::rand_like) {
      AT_ASSERT(use_cuda);
      has_random = true;
    }
    // Always emit double for prim::Constant. This will be narrowed later based
    // on either:
    //  - Tensor-Scalar operator type rules
    //  - Math function rules
    if (n->kind() == prim::Constant) {
      const auto val = toIValue(n->output()).value();
      std::string rhs;
      if (val.isDouble()) {
        rhs = scalarValue(val.toDouble());
      } else if (val.isBool()) {
        rhs = scalarValue(val.toBool());
      } else {
        AT_ASSERT(val.isInt());
        rhs = scalarValue(val.toInt());
      }
      env.s("node", valueName(n->output()));
      env.s("rhs", rhs);
      env.s("lhs_type", variableType(*n->output()->type()));
    } else {
      env.s("node", valueName(n->output()));
      env.s("rhs", encodeRHS(n));
      env.s("lhs_type", variableType(*n->output()->type()));
    }

    body << format("${lhs_type} ${node} = ${rhs};\n", env);
    body_vec4 << format("${lhs_type} ${node} = ${rhs};\n", env);
  }

  // Generates writes to output tensors
  for (const auto& output : outputs) {
    env.d("formal", formal_count++);
    env.s("access", format("t${formal}.data[t${formal}_offset]", env));
    env.s("access_vec4", format("t${formal}_buf[i]", env));
    env.s("node", valueName(output.first));

    // Acquires and converts (if needed) outputs
    // Note: conversion to half is only supported for CUDA kernels.
    const auto is_half = (output.second.scalar_type == at::ScalarType::Half);
    const auto is_bfloat =
        (output.second.scalar_type == at::ScalarType::BFloat16);
    if (is_half) {
      AT_ASSERT(use_cuda);
      body << format("${access} = __float2half(${node});\n", env);
      body_vec4 << format("${access_vec4} = __float2half(${node});\n", env);
      has_half_tensor = true;
    } else if (is_bfloat) {
      AT_ASSERT(use_cuda);
      body << format("${access} = __float2bfloat16(${node});\n", env);
      body_vec4 << format("${access_vec4} = __float2bfloat16(${node});\n", env);
      has_bfloat_tensor = true;
    } else {
      body << format("${access} = ${node};\n", env);
      body_vec4 << format("${access_vec4} = ${node};\n", env);
    }

    // store output in vectorized code path (mirror of the input load4)
    auto ele_size = at::elementSize(output.second.scalar_type);
    if (ele_size == 1) {
      env.s(
          "store4",
          format(
              "*(reinterpret_cast<float*>(t${formal}.data + t${formal}_offset)) = *(reinterpret_cast<float*>(t${formal}_buf))",
              env));
    } else if (ele_size == 2) {
      env.s(
          "store4",
          format(
              "*(reinterpret_cast<float2*>(t${formal}.data + t${formal}_offset)) = *(reinterpret_cast<float2*>(t${formal}_buf))",
              env));
    } else if (ele_size == 4) {
      env.s(
          "store4",
          format(
              "*(reinterpret_cast<float4*>(t${formal}.data + t${formal}_offset)) = *(reinterpret_cast<float4*>(t${formal}_buf))",
              env));
    } else {
      env.s(
          "store4",
          format(
              "for(int i = 0; i<4; i++) t${formal}.data[t${formal}_offset + i] = t${formal}_buf[i]",
              env));
    }
    store << format("${store4};\n", env);
  }

  // Includes headers
  // Note: CUDA kernels support halfs and random generation, CPU kernels do not
  if (has_half_tensor) {
    env.s("HalfHeader", cuda::half_support_literal);
  } else {
    env.s("HalfHeader", "");
  }
  if (has_bfloat_tensor) {
    env.s("BFloat16Header", cuda::bfloat16_support_literal);
  } else {
    env.s("BFloat16Header", "");
  }

  if (has_random) {
    env.s("RandHeader", cuda::rand_support_literal);
    env.s("RandParam", cuda::rand_param);
    env.s("RandInit", cuda::rand_init);
  } else {
    env.s("RandHeader", "");
    env.s("RandParam", "");
    env.s("RandInit", "");
  }

  // HIP headers must be included until precompiled header feature is available
  // clang-format off
#if defined(USE_ROCM)
#if ROCM_VERSION < 40200
  if (use_cuda && has_half_tensor) {
    env.s("RuntimeHeader", R"(
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
)");
  } else if (use_cuda) {
    env.s("RuntimeHeader", R"(
#include <hip/hip_runtime.h>
)");
  }
#else
  // Still need the key defined, but empty.
  env.s("RuntimeHeader", R"()");
#endif
#endif
  // clang-format on

  // Instantiates the CUDA or CPU-specific templates
  env.s("tensorOffsets", tensorOffsets.str());
  env.s("tensorChecks", tensorChecks.str());
  env.s("kernelBody", body.str());
  env.s("kernelBody_vec4", body_vec4.str());
  env.s("kernelLoad", load.str());
  env.s("kernelStore", store.str());
  env.v("formals", formals);
  env.v("argument_loads", argument_loads);
  std::string code_string;
  if (use_cuda) {
    env.s("type_declarations", cuda::type_declarations_template.format(env));
    code_string = cuda::cuda_compilation_unit_template.format(env);
  } else {
    env.s("type_declarations", cpu::type_declarations_template.format(env));
    code_string = cpu::cpu_compilation_unit_template.format(env);
  }

  if (debugFuser()) {
    std::cerr << "fusion code:" << code_string << std::endl;
  }
  return code_string;
}
} // namespace fuser
} // namespace jit
} // namespace torch