#include <torch/csrc/jit/tensorexpr/kernel.h>

#include <c10/util/string_utils.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>

using namespace torch::jit;
using namespace torch::jit::tensorexpr;

namespace torch {
namespace jit {
namespace tensorexpr {

static int te_cuda_pointwise_loop_levels = -1;
static int te_cuda_pointwise_block_count = -1;
static int te_cuda_pointwise_block_size = -1;
static bool fallback_allowed = false;
static bool te_generate_block_code = false;
static bool te_must_use_llvm_on_cpu = false;

bool setFallbackAllowed(bool value) {
  bool old_value = fallback_allowed;
  fallback_allowed = value;
  return old_value;
}

bool fallbackAllowed() {
  static const char* enable_c_str = std::getenv("PYTORCH_TENSOREXPR_FALLBACK");
  if (!enable_c_str) {
    return fallback_allowed;
  }
  if (std::string(enable_c_str) == "0") {
    return false;
  }
  return true;
}

bool fallbackEnforced() {
  static const char* enable_c_str = std::getenv("PYTORCH_TENSOREXPR_FALLBACK");
  if (!enable_c_str) {
    return fallback_allowed;
  }
  if (std::string(enable_c_str) == "2") {
    return true;
  }
  return false;
}

int& getTECudaPointwiseLoopLevels() {
  return te_cuda_pointwise_loop_levels;
}

int& getTECudaPointwiseBlockCount() {
  return te_cuda_pointwise_block_count;
}

int& getTECudaPointwiseBlockSize() {
  return te_cuda_pointwise_block_size;
}

// TODO: Remove this global var
// Ideally Block code gen should be decided
// based on device type in tensor.
bool& getTEGenerateBlockCode() {
  return te_generate_block_code;
}

bool& getTEMustUseLLVMOnCPU() {
  return te_must_use_llvm_on_cpu;
}

c10::optional<at::Device> pickDeviceType(
    const at::ArrayRef<torch::jit::Value*>& inputs) {
  c10::optional<at::Device> device = c10::nullopt;
  for (auto const& input : inputs) {
    auto tt = input->type()->cast<TensorType>();
    if (tt && tt->device()) {
      if (device && *device != *tt->device()) {
        return c10::nullopt;
      }
      device = *tt->device();
    }
  }
  return device;
}

} // namespace tensorexpr
} // namespace jit
} // namespace torch

static at::ScalarType tensorType(Tensor* t) {
  return static_cast<at::ScalarType>(t->body()->dtype().scalar_type());
}

static std::vector<ExprHandle> computeIndicesToBroadcast(
    const std::vector<ExprHandle>& outputAxes,
    const std::vector<ExprHandle>& inputSizes) {
  if (outputAxes.size() < inputSizes.size()) {
    throw malformed_input("Cannot broadcast to a lower rank tensor");
  }
  std::vector<ExprHandle> bcast;
  auto axisIt = outputAxes.rbegin();
  auto sizeIt = inputSizes.rbegin();
  while (sizeIt != inputSizes.rend()) {
    auto const& size = sizeIt->AsNode<IntImm>();
    if (size && size->value() == 1) {
      bcast.emplace_back(0);
    } else {
      bcast.emplace_back(*axisIt);
    }
    ++axisIt;
    ++sizeIt;
  }
  std::reverse(bcast.begin(), bcast.end());
  return bcast;
}
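
// Illustrative example (sizes made up): with outputAxes = {i, j} and
// inputSizes = {1, N}, the loop above walks both lists from the innermost
// dimension and returns {0, j} - size-1 input dimensions are pinned to
// index 0, while every other dimension reuses the corresponding output axis.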

ExprHandle TensorExprKernel::broadcast(
    Tensor* t,
    const std::vector<ExprHandle>& axes) {
  return t->call(computeIndicesToBroadcast(
      axes, ExprVectorToExprHandleVector(t->buf()->dims())));
}

ExprHandle TensorExprKernel::chunk(
    Tensor* t,
    size_t chunkIdx,
    int64_t dim,
    int64_t chunks,
    const std::vector<ExprHandle>& axes) {
  if (dim < 0) {
    dim = axes.size() + dim;
  }
  auto sizes = bufferSizes(t);
  size_t step = sizes[dim] / chunks;

  std::vector<ExprHandle> indices;
  for (size_t i = 0; i < axes.size(); ++i) {
    if (i == dim) {
      indices.push_back(axes[i] + IntImm::make((int)chunkIdx * (int)step));
    } else {
      indices.push_back(axes[i]);
    }
  }

  return t->call(indices);
}
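
// Illustrative example (shape made up): chunking a [6, 4] tensor into
// chunks = 3 pieces along dim = 0 gives step = 2, so chunkIdx = 1 reads
// t->call({axes[0] + 2, axes[1]}), i.e. rows 2..3 of the original tensor.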

ExprHandle promoteToDtype(ExprHandle e, ScalarType dt) {
  if (e.dtype().scalar_type() == dt) {
    return e;
  }

  switch (dt) {
// NOLINTNEXTLINE
#define TYPE_CASE(Type, Name) \
  case ScalarType::Name:      \
    e = cast<Type>(e);        \
    break;
    AT_FORALL_SCALAR_TYPES_AND2(Half, Bool, TYPE_CASE);
#undef TYPE_CASE
    default:
      throw unsupported_dtype();
  }
  return e;
}

ExprHandle TensorExprKernel::tensorOrConstant(
    const torch::jit::Value* v,
    const std::vector<ExprHandle>& axes) {
  auto ti = tensors_.find(v->unique());
  if (ti != tensors_.end()) {
    return broadcast(ti->second, axes);
  }
  return constant(v);
}

std::vector<ExprHandle> TensorExprKernel::sizesFromVaryingShape(
    const c10::VaryingShape<int64_t>& shape) {
  std::vector<ExprHandle> dims;
  for (size_t i = 0; i < *shape.size(); i++) {
    dims.push_back(IntImm::make(*shape[i]));
  }
  return dims;
}

std::vector<DimArg> TensorExprKernel::dimsFromSizes(
    const std::vector<ExprHandle>& sizes) {
  std::vector<DimArg> dimArgs;
  for (size_t idx = 0; idx < sizes.size(); idx++) {
    dimArgs.emplace_back(DimArg(sizes[idx], "i" + c10::to_string(idx)));
  }
  return dimArgs;
}

std::vector<ExprHandle> TensorExprKernel::sizesForValue(
    const torch::jit::Value* v) {
  if (known_sizes_.count(v)) {
    return known_sizes_.at(v);
  }

  // If the shape is present in the type info, just extract it from here. No
  // need to infer it.
  if (v->type()->kind() == TypeKind::TensorType) {
    auto tt = v->type()->cast<TensorType>();
    if (tt->isComplete()) {
      return sizesFromVaryingShape(tt->sizes());
    }
  }

  if (v->type()->isSubtypeOf(FloatType::get()) ||
      v->type()->isSubtypeOf(IntType::get())) {
    return {1};
  }
  if (v->type()->isSubtypeOf(NoneType::get())) {
    return {};
  }

  known_sizes_[v] = inferSizesForValue(v);
  return known_sizes_.at(v);
}

std::vector<ExprHandle> TensorExprKernel::inferSizesForValue(
    const torch::jit::Value* v) {
  switch (v->node()->kind()) {
    case aten::_cast_Float:
    case aten::sigmoid:
    case aten::reciprocal:
    case aten::neg:
    case aten::relu:
    case aten::log:
    case aten::log10:
    case aten::log1p:
    case aten::log2:
    case aten::exp:
    case aten::expm1:
    case aten::erf:
    case aten::erfc:
    case aten::cos:
    case aten::sin:
    case aten::tan:
    case aten::rand_like:
    case aten::acos:
    case aten::asin:
    case aten::cosh:
    case aten::sinh:
    case aten::atan:
    case aten::tanh:
    case aten::sqrt:
    case aten::rsqrt:
    case aten::abs:
    case aten::ceil:
    case aten::floor:
    case aten::round:
    case aten::trunc:
    case aten::frac:
    case aten::lgamma:
    case aten::type_as:
      return sizesForValue(v->node()->input(0));

    case aten::sub:
    case aten::add:
    case aten::mul:
    case aten::div:
    case aten::__and__:
    case aten::__or__:
    case aten::__xor__:
    case aten::__lshift__:
    case aten::__rshift__:
    case aten::eq:
    case aten::ne:
    case aten::ge:
    case aten::gt:
    case aten::le:
    case aten::lt:
    case aten::min:
    case aten::max:
    case aten::pow:
    case aten::fmod:
    case aten::remainder:
    case aten::atan2: {
      std::vector<std::vector<ExprHandle>> shapes;
      for (size_t idx = 0; idx < 2; idx++) {
        torch::jit::Value* inp = v->node()->input(idx);
        shapes.push_back(sizesForValue(inp));
      }
      return broadcastShapes(shapes);
    }

    case aten::lerp:
    case aten::clamp:
    case aten::threshold:
    case aten::where: {
      std::vector<std::vector<ExprHandle>> shapes;
      for (size_t idx = 0; idx < 3; idx++) {
        torch::jit::Value* inp = v->node()->input(idx);
        shapes.push_back(sizesForValue(inp));
      }
      return broadcastShapes(shapes);
    }

    case aten::addcmul: {
      std::vector<std::vector<ExprHandle>> shapes;
      for (size_t idx = 0; idx < 4; idx++) {
        torch::jit::Value* inp = v->node()->input(idx);
        shapes.push_back(sizesForValue(inp));
      }
      return broadcastShapes(shapes);
    }

    case prim::ConstantChunk: {
      auto shape = sizesForValue(v->node()->input());
      int dim = v->node()->i(attr::dim);
      int chunks = v->node()->i(attr::chunks);
      shape[dim] = IRSimplifier::simplify(shape[dim] / chunks);
      return shape;
    }

    case aten::unsqueeze: {
      auto const& n = v->node();
      auto shape = sizesForValue(n->input(0));

      int64_t dim = constant(n->input(1)).AsNode<IntImm>()->value();
      // From the documentation
      // (https://pytorch.org/docs/master/generated/torch.unsqueeze.html):
      //
      // A dim value within the range [-input.dim() - 1, input.dim() + 1) can be
      // used. Negative dim will correspond to unsqueeze() applied at dim = dim
      // + input.dim() + 1.
      if (dim < 0) {
        dim = dim + shape.size() + 1;
      }
      if (dim < 0 || dim > shape.size()) {
        throw std::runtime_error("Invalid 'dim' input in aten::unsqueeze");
      }

      shape.insert(shape.begin() + dim, ExprHandle(1));
      return shape;
    }

    case aten::cat: {
      // In JIT IR, aten::cat usually appears with the following nodes around
      // it:
      //   %dim : int = prim::Constant[value=0]()
      //   %inputs : Tensor[] = prim::ListConstruct(%a, %b, ...)
      //   %cat_output : Tensor = aten::cat(%inputs, %dim)
      // Shapes of the input tensors can only differ at the dimension %dim.
      // The size of the output tensor along that dimension is the sum of the
      // corresponding sizes of the input tensors; the other dimensions have
      // the same sizes.
      // Negative dim corresponds to dim = dim + input.dim().
      auto const& n = v->node();
      auto inputs = n->input(0)->node()->inputs();
      if (inputs.size() == 0) {
        throw std::runtime_error("Empty input list is passed to aten::cat");
      }

      TORCH_INTERNAL_ASSERT(n->input(1)->node()->kind() == prim::Constant);
      int64_t dim = n->input(1)->node()->i(attr::value);
      auto shape = sizesForValue(inputs[0]);
      if (dim < 0) {
        dim += shape.size();
      }
      if (dim < 0 || dim > shape.size()) {
        throw std::runtime_error("Invalid 'dim' input in aten::cat");
      }

      ExprHandle concat_dim_size = 0;
      for (auto input : inputs) {
        concat_dim_size = concat_dim_size + sizesForValue(input)[dim];
      }
      concat_dim_size = IRSimplifier::simplify(concat_dim_size);
      shape[dim] = concat_dim_size;
      return shape;
    }

    case aten::softmax:
    case aten::log_softmax:
      // Output of softmax / log_softmax has the same shape as input 0.
      return sizesForValue(v->node()->input(0));

    case aten::slice:
      throw std::runtime_error(
          "Shape info is not implemented for this kind of node");

    default: {
      GRAPH_DEBUG("Can't infer sizes for the node: ", *v->node());
      GRAPH_DEBUG("Full fusion group graph:\n", *v->node()->owningGraph());
      throw std::runtime_error("Unhandled node kind");
    }
  }
}
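
// Illustrative example (shapes made up): for aten::cat of inputs with sizes
// [2, 3] and [4, 3] along dim = 0, the code above keeps the non-cat
// dimensions of the first input and sums the cat dimension, producing
// [2 + 4, 3], which IRSimplifier folds to [6, 3].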

ExprHandle TensorExprKernel::constant(const torch::jit::Value* v) {
  if (v->node()->kind() == prim::Constant) {
    const auto val = toIValue(v).value();
    if (val.isDouble()) {
      return FloatImm::make(static_cast<float>(val.toDouble()));
    } else if (val.isInt()) {
      return IntImm::make(val.toInt());
    } else if (val.isBool()) {
      return BoolImm::make(val.toBool());
    } else if (val.isNone()) {
      // This is just a placeholder so we don't throw. None-handling
      // is operator-specific and should be handled properly in
      // the operator-specific lowering code.
      return IntImm::make(0);
    } else {
      throw unsupported_dtype();
    }
  }

  if (!scalars_.count(v->unique())) {
    throw malformed_input("no scalar in Constant");
  }

  return scalars_.at(v->unique());
}

ExprHandle promoteIntegerToFloat(const ExprHandle& e) {
  auto scalarType = static_cast<c10::ScalarType>(e.dtype().scalar_type());
  if (!c10::isIntegralType(scalarType, /*includeBool*/ true)) {
    return e;
  }
  auto defaultType = static_cast<tensorexpr::ScalarType>(
      c10::typeMetaToScalarType(c10::get_default_dtype()));
  return Cast::make(Dtype(defaultType, e.dtype().lanes()), e);
}

void TensorExprKernel::promoteInputs(std::vector<ExprHandle>& inputs) {
  if (inputs.empty()) {
    return;
  }

  // Find the highest type among the inputs.
  ScalarType highType = inputs[0].dtype().scalar_type();
  for (const auto input : inputs) {
    highType = promoteTypes(highType, input.dtype().scalar_type());
  }

  for (ExprHandle& e : inputs) {
    e = promoteToDtype(e, highType);
  }
}
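
// Illustrative example: if `inputs` holds a Float expression and an Int
// expression, the promoteTypes table picks Float as the highest type and the
// Int operand gets wrapped in a cast, so mixed-dtype binary ops compute in
// the promoted type before demoteOutput casts back to the output dtype.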

ExprHandle TensorExprKernel::demoteOutput(
    const ExprHandle& e,
    const torch::jit::Value* v) {
  if (v->type()->kind() != TypeKind::TensorType) {
    return e;
  }

  if (!v->isCompleteTensor()) {
    return e;
  }

  auto tt = *v->type()->cast<TensorType>()->scalarType();

  if (tt == static_cast<at::ScalarType>(e.dtype().scalar_type())) {
    return e;
  }

  switch (tt) {
// NOLINTNEXTLINE
#define TYPE_CASE(Type, Name) \
  case at::ScalarType::Name:  \
    return cast<Type>(e);
    AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE);
#undef TYPE_CASE
    case at::ScalarType::Bool:
      return cast<bool>(e);
    default:
      throw unsupported_dtype();
  }

  return e;
}

static bool isOne(ExprHandle e) {
  auto const& n = e.AsNode<IntImm>();
  if (!n) {
    return false;
  }
  return n->value() == 1;
}

std::vector<ExprHandle> TensorExprKernel::broadcastShapes(
    std::vector<std::vector<ExprHandle>> shapes) {
  size_t n = shapes.size();
  if (n == 1) {
    return shapes[0];
  }
  auto res1 = broadcastShapes(shapes[n - 2], shapes[n - 1]);
  shapes[n - 2] = res1;
  shapes.pop_back();
  auto res2 = broadcastShapes(shapes);
  return res2;
}

std::vector<ExprHandle> TensorExprKernel::broadcastShapes(
    const std::vector<ExprHandle>& a,
    const std::vector<ExprHandle>& b) {
  auto at = a.rbegin();
  auto bt = b.rbegin();
  std::vector<ExprHandle> ret;
  while (at != a.rend() || bt != b.rend()) {
    if (at == a.rend()) {
      hasBroadcast_ = true;
      ret.push_back(*bt++);
      continue;
    }
    if (bt == b.rend()) {
      hasBroadcast_ = true;
      ret.push_back(*at++);
      continue;
    }
    // TODO: if neither *at nor *bt is 1, ensure they are identical
    // expressions.  Nb: `==` doesn't work since that simply produces a new
    // ExprHandle.
    ExprHandle dim = *at;
    if (isOne(*at)) {
      if (!isOne(*bt)) {
        dim = *bt;
        hasBroadcast_ = true;
      }
    }
    ret.push_back(dim);
    at++;
    bt++;
  }
  std::reverse(ret.begin(), ret.end());
  return ret;
}
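
// Illustrative example (shapes made up): broadcasting [3, 1, 5] with [4, 5]
// walks both shapes from the innermost dimension: 5 vs 5 -> 5, 1 vs 4 -> 4
// (hasBroadcast_ is set), and the leftover 3 is copied through, giving
// [3, 4, 5] after the final reverse.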

std::vector<ExprHandle> TensorExprKernel::valueShape(
    const torch::jit::Value* v) {
  auto it = tensors_.find(v->unique());
  if (it == tensors_.end()) {
    return {};
  }
  return ExprVectorToExprHandleVector(it->second->dims());
}

Tensor* TensorExprKernel::computeOneOperand(
    const std::string& name,
    const torch::jit::Value* v,
    const std::function<ExprHandle(const ExprHandle&)>& innerExpr) {
  auto const& n = v->node();
  auto const& shape = valueShape(n->inputs()[0]);
  return Compute(
      name,
      c10::fmap<DimArg>(shape),
      [this, v, innerExpr](const std::vector<VarHandle>& axes) {
        auto const& n = v->node();
        std::vector<ExprHandle> indices(axes.begin(), axes.end());
        std::vector<ExprHandle> inputs = {
            tensorOrConstant(n->inputs()[0], indices)};

        promoteInputs(inputs);
        ExprHandle compute = innerExpr(inputs[0]);
        return demoteOutput(compute, n->output());
      });
}

Tensor* TensorExprKernel::computeTwoOperand(
    const std::string& name,
    const torch::jit::Value* v,
    const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
        innerExpr) {
  auto const& n = v->node();
  auto const& shape =
      broadcastShapes(valueShape(n->inputs()[0]), valueShape(n->inputs()[1]));
  return Compute(
      name,
      c10::fmap<DimArg>(shape),
      [this, v, innerExpr](const std::vector<VarHandle>& axes) {
        auto const& n = v->node();
        std::vector<ExprHandle> indices(axes.begin(), axes.end());
        std::vector<ExprHandle> inputs = {
            tensorOrConstant(n->inputs()[0], indices),
            tensorOrConstant(n->inputs()[1], indices),
        };

        promoteInputs(inputs);
        ExprHandle compute = innerExpr(inputs[0], inputs[1]);
        return demoteOutput(compute, n->output());
      });
}

Tensor* TensorExprKernel::computeTwoOperandWithAlpha(
    const std::string& name,
    const torch::jit::Value* v,
    const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
        innerExpr) {
  auto const& n = v->node();
  auto const& shape =
      broadcastShapes(valueShape(n->inputs()[0]), valueShape(n->inputs()[1]));
  return Compute(
      name,
      c10::fmap<DimArg>(shape),
      [this, v, innerExpr](const std::vector<VarHandle>& axes) {
        auto const& n = v->node();
        std::vector<ExprHandle> indices(axes.begin(), axes.end());
        std::vector<ExprHandle> inputs = {
            tensorOrConstant(n->inputs()[0], indices),
            tensorOrConstant(n->inputs()[1], indices),
            tensorOrConstant(n->inputs()[2], indices),
        };

        promoteInputs(inputs);
        ExprHandle compute = innerExpr(inputs[0], inputs[2] * inputs[1]);
        return demoteOutput(compute, n->output());
      });
}
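
// Note: the third operand above is the scalar `alpha`, so for a call like
// aten::add(self, other, alpha) the expression built is
// innerExpr(self, alpha * other), matching add's semantics of
// self + alpha * other (and likewise for aten::sub).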

Tensor* TensorExprKernel::computeConditionWithTwoOperand(
    const std::string& name,
    const torch::jit::Value* v,
    const std::function<
        ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>&
        innerExpr) {
  auto const& n = v->node();
  std::vector<std::vector<ExprHandle>> shapes;
  for (size_t idx = 0; idx < 2; idx++) {
    torch::jit::Value* inp = n->input(idx);
    shapes.push_back(sizesForValue(inp));
  }
  auto const& shape = broadcastShapes(shapes);
  return Compute(
      name,
      c10::fmap<DimArg>(shape),
      [this, v, innerExpr](const std::vector<VarHandle>& axes) {
        auto const& n = v->node();
        std::vector<ExprHandle> indices(axes.begin(), axes.end());
        std::vector<ExprHandle> inputs = {
            tensorOrConstant(n->inputs()[1], indices),
            tensorOrConstant(n->inputs()[2], indices),
        };

        promoteInputs(inputs);
        // First expr is the condition, which we don't promote
        inputs.emplace(
            inputs.begin(), tensorOrConstant(n->inputs()[0], indices));
        ExprHandle compute = innerExpr(inputs[0], inputs[1], inputs[2]);
        return demoteOutput(compute, n->output());
      });
}

Tensor* TensorExprKernel::computeThreeOperand(
    const std::string& name,
    const torch::jit::Value* v,
    const std::function<
        ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>&
        innerExpr) {
  auto const& n = v->node();
  std::vector<std::vector<ExprHandle>> shapes;
  for (size_t idx = 0; idx < 3; idx++) {
    torch::jit::Value* inp = n->input(idx);
    shapes.push_back(sizesForValue(inp));
  }
  auto const& shape = broadcastShapes(shapes);
  return Compute(
      name,
      c10::fmap<DimArg>(shape),
      [this, v, innerExpr](const std::vector<VarHandle>& axes) {
        auto const& n = v->node();
        std::vector<ExprHandle> indices(axes.begin(), axes.end());
        std::vector<ExprHandle> inputs = {
            tensorOrConstant(n->inputs()[0], indices),
            tensorOrConstant(n->inputs()[1], indices),
            tensorOrConstant(n->inputs()[2], indices),
        };

        promoteInputs(inputs);
        ExprHandle compute = innerExpr(inputs[0], inputs[1], inputs[2]);
        return demoteOutput(compute, n->output());
      });
}

Tensor* TensorExprKernel::computeFourOperand(
    const std::string& name,
    const torch::jit::Value* v,
    const std::function<ExprHandle(
        const ExprHandle&,
        const ExprHandle&,
        const ExprHandle&,
        const ExprHandle&)>& innerExpr) {
  auto const& n = v->node();
  std::vector<std::vector<ExprHandle>> shapes;
  for (size_t idx = 0; idx < 4; idx++) {
    torch::jit::Value* inp = n->input(idx);
    shapes.push_back(sizesForValue(inp));
  }
  auto const& shape = broadcastShapes(shapes);
  return Compute(
      name,
      c10::fmap<DimArg>(shape),
      [this, v, innerExpr](const std::vector<VarHandle>& axes) {
        auto const& n = v->node();
        std::vector<ExprHandle> indices(axes.begin(), axes.end());
        std::vector<ExprHandle> inputs = {
            tensorOrConstant(n->inputs()[0], indices),
            tensorOrConstant(n->inputs()[1], indices),
            tensorOrConstant(n->inputs()[2], indices),
            tensorOrConstant(n->inputs()[3], indices),
        };

        promoteInputs(inputs);
        ExprHandle compute =
            innerExpr(inputs[0], inputs[1], inputs[2], inputs[3]);
        return demoteOutput(compute, n->output());
      });
}

namespace {

// Convert boolean to integer, if needed.
ExprHandle boolToInteger(const ExprHandle& x) {
  return x.dtype().scalar_type() == ScalarType::Bool ? cast<int>(x) : x;
}

} // namespace

c10::optional<ScalarType> findDtypeForValue(const torch::jit::Value* v) {
  if (v->type()->kind() == TypeKind::TensorType) {
    auto tt = v->type()->cast<TensorType>();
    if (tt->scalarType()) {
      return static_cast<ScalarType>(*tt->scalarType());
    }
  }
  return c10::nullopt;
}

Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
  switch (v->node()->kind()) {
    case aten::add: {
      auto add_lambda = [](const ExprHandle& lhs, const ExprHandle& rhs) {
        return boolToInteger(lhs) + boolToInteger(rhs);
      };
      TORCH_INTERNAL_ASSERT(
          v->node()->inputs().size() == 2 || v->node()->inputs().size() == 3);
      return (v->node()->inputs().size() > 2)
          ? computeTwoOperandWithAlpha("aten_add", v, add_lambda)
          : computeTwoOperand("aten_add", v, add_lambda);
    } break;

    case aten::_cast_Float: {
      return computeOneOperand("aten_cast_float", v, [](const ExprHandle& a) {
        return cast<float>(a);
      });
    } break;

    case aten::sub: {
      auto sub_lambda = [](const ExprHandle& lhs, const ExprHandle& rhs) {
        // NB: sub isn't supported on boolean, no need to promote to integer.
        return lhs - rhs;
      };
      TORCH_INTERNAL_ASSERT(
          v->node()->inputs().size() == 2 || v->node()->inputs().size() == 3);
      return (v->node()->inputs().size() > 2)
          ? computeTwoOperandWithAlpha("aten_sub", v, sub_lambda)
          : computeTwoOperand("aten_sub", v, sub_lambda);
    } break;

    case aten::mul: {
      return computeTwoOperand(
          "aten_mul", v, [](const ExprHandle& lhs, const ExprHandle& rhs) {
            return boolToInteger(lhs) * boolToInteger(rhs);
          });
    } break;

    case aten::div: {
      return computeTwoOperand(
          "aten_div", v, [](const ExprHandle& lhs, const ExprHandle& rhs) {
            return boolToInteger(lhs) / boolToInteger(rhs);
          });
    } break;

    case aten::__and__: {
      return computeTwoOperand(
          "aten_and", v, [](const ExprHandle& lhs, const ExprHandle& rhs) {
            return boolToInteger(lhs) & boolToInteger(rhs);
          });
    } break;

    case aten::__or__: {
      return computeTwoOperand(
          "aten_or", v, [](const ExprHandle& lhs, const ExprHandle& rhs) {
            return boolToInteger(lhs) | boolToInteger(rhs);
          });
    } break;

    case aten::__xor__: {
      return computeTwoOperand(
          "aten_xor", v, [](const ExprHandle& lhs, const ExprHandle& rhs) {
            return boolToInteger(lhs) ^ boolToInteger(rhs);
          });
    } break;

    case aten::__lshift__: {
      return computeTwoOperand(
          "aten_lshift", v, [](const ExprHandle& lhs, const ExprHandle& rhs) {
            return lhs << rhs;
          });
    } break;

    case aten::__rshift__: {
      return computeTwoOperand(
          "aten_rshift", v, [](const ExprHandle& lhs, const ExprHandle& rhs) {
            return lhs >> rhs;
          });
    } break;

    case aten::addcmul: {
      return computeFourOperand(
          "aten_addcmul",
          v,
          [](const ExprHandle& a0,
             const ExprHandle& a1,
             const ExprHandle& a2,
             const ExprHandle& a3) { return a0 + a3 * a1 * a2; });
    } break;

    case aten::eq: {
      return computeTwoOperand(
          "aten_eq", v, [](const ExprHandle& lhs, const ExprHandle& rhs) {
            return cast<bool>(lhs == rhs);
          });
    } break;

    case aten::ne: {
      return computeTwoOperand(
          "aten_ne", v, [](const ExprHandle& lhs, const ExprHandle& rhs) {
            return cast<bool>(lhs != rhs);
          });
    } break;
    case aten::ge: {
      return computeTwoOperand(
          "aten_ge", v, [](const ExprHandle& lhs, const ExprHandle& rhs) {
            return cast<bool>(lhs >= rhs);
          });
    } break;

    case aten::gt: {
      return computeTwoOperand(
          "aten_gt", v, [](const ExprHandle& lhs, const ExprHandle& rhs) {
            return cast<bool>(lhs > rhs);
          });
    } break;

    case aten::le: {
      return computeTwoOperand(
          "aten_le", v, [](const ExprHandle& lhs, const ExprHandle& rhs) {
            return cast<bool>(lhs <= rhs);
          });
    } break;

    case aten::lt: {
      return computeTwoOperand(
          "aten_lt", v, [](const ExprHandle& lhs, const ExprHandle& rhs) {
            return cast<bool>(lhs < rhs);
          });
    } break;

    case aten::min: {
      return computeTwoOperand(
          "aten_min", v, [](const ExprHandle& lhs, const ExprHandle& rhs) {
            return Min::make(boolToInteger(lhs), boolToInteger(rhs), false);
          });
    } break;

    case aten::max: {
      return computeTwoOperand(
          "aten_max", v, [](const ExprHandle& lhs, const ExprHandle& rhs) {
            return Max::make(boolToInteger(lhs), boolToInteger(rhs), false);
          });
    } break;

    case aten::clamp: {
      bool noMin = false;
      bool noMax = false;
      if (v->node()->input(1)->node()->kind() == prim::Constant) {
        const auto val = toIValue(v->node()->input(1)).value();
        if (val.isNone()) {
          noMin = true;
        }
      }

      if (v->node()->input(2)->node()->kind() == prim::Constant) {
        const auto val = toIValue(v->node()->input(2)).value();
        if (val.isNone()) {
          noMax = true;
        }
      }

      return computeThreeOperand(
          "aten_clamp",
          v,
          [noMin, noMax](
              const ExprHandle& in,
              const ExprHandle& min,
              const ExprHandle& max) {
            if (noMin && noMax) {
              return in;
            } else if (noMin) {
              return CompareSelect::make(in, max, max, in, kGT);
            } else if (noMax) {
              return CompareSelect::make(in, min, min, in, kLT);
            } else {
              return CompareSelect::make(
                  in,
                  min,
                  min,
                  CompareSelect::make(in, max, max, in, kGT),
                  kLT);
            }
          });
    } break;
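
    // Illustrative reading of the clamp lowering above: with both bounds
    // present the nested CompareSelect is equivalent to
    //   result = (in < min) ? min : ((in > max) ? max : in),
    // and when one bound is a None constant only the remaining comparison
    // is emitted.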

    case aten::sigmoid: {
      return computeOneOperand(
          "aten_sigmoid", v, [](const ExprHandle& a) { return sigmoid(a); });
    } break;

    case aten::reciprocal: {
      return computeOneOperand("aten_reciprocal", v, [](const ExprHandle& a) {
        return ExprHandle(1.0f) / a;
      });
    } break;

    case aten::neg: {
      return computeOneOperand("aten_neg", v, [](const ExprHandle& a) {
        return ExprHandle(-0) - a;
      });
    } break;

    case aten::relu: {
      return computeOneOperand("aten_relu", v, [](const ExprHandle& a) {
        auto zero = Cast::make(a.dtype(), 0);
        return ifThenElse(CompareSelect::make(a, zero, kLT), zero, a);
      });
    } break;

    case aten::log: {
      return computeOneOperand("aten_log", v, [](const ExprHandle& a) {
        return log(promoteIntegerToFloat(a));
      });
    } break;

    case aten::log10: {
      return computeOneOperand("aten_log10", v, [](const ExprHandle& a) {
        return log10(promoteIntegerToFloat(a));
      });
    } break;

    case aten::log1p: {
      return computeOneOperand(
          "aten_log1p", v, [](const ExprHandle& a) { return log1p(a); });
    } break;

    case aten::log2: {
      return computeOneOperand("aten_log2", v, [](const ExprHandle& a) {
        return log2(promoteIntegerToFloat(a));
      });
    } break;

    case aten::exp: {
      return computeOneOperand(
          "aten_exp", v, [](const ExprHandle& a) { return exp(a); });
    } break;

    case aten::expm1: {
      return computeOneOperand(
          "aten_expm1", v, [](const ExprHandle& a) { return expm1(a); });
    } break;

    case aten::erf: {
      return computeOneOperand(
          "aten_erf", v, [](const ExprHandle& a) { return erf(a); });
    } break;

    case aten::erfc: {
      return computeOneOperand(
          "aten_erfc", v, [](const ExprHandle& a) { return erfc(a); });
    } break;

    case aten::cos: {
      return computeOneOperand("aten_cos", v, [](const ExprHandle& a) {
        return cos(promoteIntegerToFloat(a));
      });
    } break;

    case aten::sin: {
      return computeOneOperand("aten_sin", v, [](const ExprHandle& a) {
        return sin(promoteIntegerToFloat(a));
      });
    } break;

    case aten::tan: {
      return computeOneOperand("aten_tan", v, [](const ExprHandle& a) {
        return tan(promoteIntegerToFloat(a));
      });
    } break;

    case aten::type_as: {
      auto const& n = v->node();
      Tensor* rhs = tensors_.at(n->inputs()[1]->unique());
      auto dtype = rhs->body()->dtype();
      return computeOneOperand(
          "aten_type_as", v, [dtype](const ExprHandle& lhs) {
            return Cast::make(dtype, lhs);
          });
    } break;

    case aten::rand_like: {
      hasRandom_ = true;
      return computeOneOperand("aten_rand_like", v, [](const ExprHandle& a) {
        return Intrinsics::make(IntrinsicsOp::kRand, a.dtype());
      });
    } break;

    case aten::pow: {
      return computeTwoOperand(
          "aten_pow", v, [](const ExprHandle& lhs, const ExprHandle& rhs) {
            if (!rhs.node()->isConstant()) {
              return pow(lhs, rhs);
            }
            double val =
                immediateAs<double>(IRSimplifier::simplify(rhs.node()));

            if (val == 1.0f) {
              return lhs;
            } else if (val == 2.0f) { // NOLINT
              return lhs * lhs;
            } else if (val == 3.0f) { // NOLINT
              return (lhs * lhs) * lhs;
            } else if (val == 4.0f) { // NOLINT
              ExprHandle tmp = lhs * lhs;
              return tmp * tmp;
            } else if (val == 0.5f) { // NOLINT
              return sqrt(lhs);
            } else if (val == 0.0f) {
              return ExprHandle(1.0f);
            } else if (val == -0.5f) { // NOLINT
              return rsqrt(lhs);
            } else if (val == -1.0f) {
              return ExprHandle(1.0f) / lhs;
            } else if (val == -2.0f) { // NOLINT
              return ExprHandle(1.0f) / (lhs * lhs);
            }
            return pow(lhs, rhs);
          });
    } break;
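
    // Note on the constant-exponent fast paths above: when `rhs` folds to a
    // constant, small integer and half-integer exponents are expanded into
    // multiplications, sqrt/rsqrt, or reciprocals instead of a pow() call;
    // e.g. a constant exponent of 2.0 lowers to lhs * lhs and -0.5 lowers
    // to rsqrt(lhs). Non-constant exponents fall through to pow(lhs, rhs).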

    case aten::fmod: {
      return computeTwoOperand(
          "aten_fmod", v, [](const ExprHandle& lhs, const ExprHandle& rhs) {
            return fmod(lhs, rhs);
          });
    } break;

    case aten::lerp: {
      return computeThreeOperand(
          "aten_lerp",
          v,
          [](const ExprHandle& a,
             const ExprHandle& end,
             const ExprHandle& weight) { return a + weight * (end - a); });
    } break;
    case aten::remainder: {
      auto imodImpl = [](const ExprHandle& lhs, const ExprHandle& rhs) {
        return Mod::make(lhs, rhs);
      };
      auto fmodImpl = [](const ExprHandle& lhs, const ExprHandle& rhs) {
        return fmod((rhs + fmod(lhs, rhs)), rhs);
      };
      {
        auto const& n = v->node();
        auto const& shape = broadcastShapes(
            valueShape(n->inputs()[0]), valueShape(n->inputs()[1]));
        return Compute(
            "aten_remainder",
            c10::fmap<DimArg>(shape),
            [&](const std::vector<VarHandle>& axes) {
              auto const& n = v->node();
              std::vector<ExprHandle> indices(axes.begin(), axes.end());
              std::vector<ExprHandle> inputs = {
                  tensorOrConstant(n->inputs()[0], indices),
                  tensorOrConstant(n->inputs()[1], indices),
              };

              promoteInputs(inputs);
              bool allInt = true;
              for (auto& e : inputs) {
                if (e.dtype().is_floating_point()) {
                  allInt = false;
                  break;
                }
              }
              if (allInt) {
                return demoteOutput(
                    imodImpl(inputs[0], inputs[1]), n->output());
              } else {
                return demoteOutput(
                    fmodImpl(inputs[0], inputs[1]), n->output());
              }
            });
      }

    } break;

    case aten::acos: {
      return computeOneOperand("aten_acos", v, [](const ExprHandle& a) {
        return acos(promoteIntegerToFloat(a));
      });
    } break;

    case aten::asin: {
      return computeOneOperand(
          "aten_asin", v, [](const ExprHandle& a) { return asin(a); });
    } break;

    case aten::cosh: {
      return computeOneOperand(
          "aten_cosh", v, [](const ExprHandle& a) { return cosh(a); });
    } break;

    case aten::sinh: {
      return computeOneOperand(
          "aten_sinh", v, [](const ExprHandle& a) { return sinh(a); });
    } break;

    case aten::atan: {
      return computeOneOperand("aten_atan", v, [](const ExprHandle& a) {
        return atan(promoteIntegerToFloat(a));
      });
    } break;

    case aten::atan2: {
      return computeTwoOperand(
          "aten_atan2", v, [](const ExprHandle& lhs, const ExprHandle& rhs) {
            return atan2(lhs, rhs);
          });
    } break;

    case aten::tanh: {
      return computeOneOperand("aten_tanh", v, [](const ExprHandle& a) {
        return tanh(promoteIntegerToFloat(a));
      });
    } break;

    case aten::sqrt: {
      return computeOneOperand("aten_sqrt", v, [](const ExprHandle& a) {
        return sqrt(promoteIntegerToFloat(a));
      });
    } break;

    case aten::rsqrt: {
      return computeOneOperand(
          "aten_rsqrt", v, [](const ExprHandle& a) { return rsqrt(a); });
    } break;

    case aten::abs: {
      return computeOneOperand(
          "aten_abs", v, [](const ExprHandle& a) { return fabs(a); });
    } break;

    case aten::ceil: {
      return computeOneOperand(
          "aten_ceil", v, [](const ExprHandle& a) { return ceil(a); });
    } break;

    case aten::floor: {
      return computeOneOperand(
          "aten_floor", v, [](const ExprHandle& a) { return floor(a); });
    } break;

    case aten::round: {
      return computeOneOperand(
          "aten_round", v, [](const ExprHandle& a) { return round(a); });
    } break;

    case aten::trunc: {
      return computeOneOperand(
          "aten_trunc", v, [](const ExprHandle& a) { return trunc(a); });
    } break;

    case aten::threshold: {
      return computeThreeOperand(
          "aten_threshold",
          v,
          [](const ExprHandle& a,
             const ExprHandle& threshold,
             const ExprHandle& value) {
            return ifThenElse(CompareSelect::make(a, threshold, kLE), value, a);
          });
    } break;

    case aten::where: {
      return computeConditionWithTwoOperand(
          "aten_where",
          v,
          [](const ExprHandle& a0, const ExprHandle& a1, const ExprHandle& a2) {
            return ifThenElse(a0, a1, a2);
          });
    } break;

    case aten::frac: {
      return computeOneOperand(
          "aten_frac", v, [](const ExprHandle& a) { return a - floor(a); });
    } break;

    case aten::lgamma: {
      return computeOneOperand(
          "aten_lgamma", v, [](const ExprHandle& a) { return lgamma(a); });
    } break;

    case prim::ConstantChunk: {
      return Compute(
          "prim_constantchunk",
          dimsFromSizes(sizesForValue(v)),
          [this, v](const std::vector<VarHandle>& axes) {
            auto const& n = v->node();
            int64_t dim = n->i(attr::dim);
            int64_t chunks = n->i(attr::chunks);
            std::vector<ExprHandle> indices(axes.begin(), axes.end());
            return chunk(
                tensors_.at(n->inputs()[0]->unique()),
                v->offset(),
                dim,
                chunks,
                indices);
          });
    }

    case aten::cat: {
      return Compute(
          "aten_cat",
          dimsFromSizes(sizesForValue(v)),
          [this, v](const std::vector<VarHandle>& axes) {
            auto const& n = v->node();
            auto inputs = n->inputs()[0]->node()->inputs();
            if (inputs.size() == 0) {
              throw std::runtime_error(
                  "Empty input list is passed to aten::cat");
            }

            // Some of the inputs can be empty tensors; we need to skip them
            // when we construct the expression, but we still need to take
            // them into account in dtype promotion.
            std::vector<const torch::jit::Value*> nonempty_inputs;
            for (auto input : inputs) {
              if (input->type()->kind() == TypeKind::TensorType) {
                auto tt = input->type()->cast<TensorType>();
                if (tt->isComplete() && tt->sizes().size() && tt->sizes()[0] &&
                    *tt->sizes()[0]) {
                  nonempty_inputs.push_back(input);
                }
              }
            }

            // When all inputs are empty tensors, the tensor we create for this
            // computation contains no elements, so it doesn't really matter
            // what we return here; just return 0.
            if (!nonempty_inputs.size()) {
              return ExprHandle(0);
            }

            int64_t dim = n->inputs()[1]->node()->i(attr::value);
            if (dim < 0) {
              dim += axes.size();
            }

            if (dim < 0 || dim >= axes.size()) {
              throw std::runtime_error("invalid 'dim' value in aten::cat");
            }

            // Promote input types.
            // Note that we need to consider all inputs, including empty - they
            // also affect the resultant dtype.
            auto maybe_dtype = findDtypeForValue(inputs[0]);
            TORCH_INTERNAL_ASSERT(
                maybe_dtype, "Cannot find dtype for one of aten::cat inputs");
            ScalarType highType = *maybe_dtype;
            for (const auto input : inputs) {
              auto maybe_dtype = findDtypeForValue(input);
              TORCH_INTERNAL_ASSERT(
                  maybe_dtype, "Cannot find dtype for one of aten::cat inputs");
              highType = promoteTypes(highType, *maybe_dtype);
            }

            // Now we know the final dtype, we know which inputs are non-empty,
            // and we know that there is at least one such input. With all
            // that we construct a tensor expression performing the
            // concatenation.
            // The expression we build here is a cascading if-then-else that
            // essentially represents:
            //
            //              inp1[i, j, k]         if 0 <= i < l1,
            // out[i,j,k] = inp2[i - l1, j, k]    if l1 <= i < l1 + l2,
            //              ...
            //              inpN[i - (l1 + ... + l_{N-1}), j, k]   otherwise,
            // where l_i is the size of the i-th input along the
            // concatenation dimension.
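            //
            // Illustrative example (sizes made up): concatenating two inputs
            // of sizes 2 and 3 along the cat dimension produces
            //   load = ifThenElse(axes[dim] < 2,
            //                     inp1[..., axes[dim], ...],
            //                     inp2[..., axes[dim] - 2, ...]),
            // and each additional input adds one more nested ifThenElse with
            // an updated running offset.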
            std::vector<ExprHandle> newAxes(axes.begin(), axes.end());
            ExprHandle load = promoteToDtype(
                tensorOrConstant(nonempty_inputs[0], newAxes), highType);
            size_t offset =
                bufferSizes(tensors_.at(nonempty_inputs[0]->unique()))[dim];
            newAxes[dim] = newAxes[dim] - IntImm::make(offset);

            for (size_t ii = 1; ii < nonempty_inputs.size(); ++ii) {
              auto input = nonempty_inputs[ii];
              load = ifThenElse(
                  CompareSelect::make(axes[dim], IntImm::make(offset), kLT),
                  load,
                  promoteToDtype(tensorOrConstant(input, newAxes), highType));

              offset += bufferSizes(tensors_.at(input->unique()))[dim];
              newAxes[dim] = axes[dim] - IntImm::make(offset);
            }

            return load;
          });
    }
    case aten::slice: {
      return Compute(
          "aten_slice",
          dimsFromSizes(sizesForValue(v)),
          [this, v](const std::vector<VarHandle>& axes) {
            auto const& n = v->node();
            int dim = constant(n->inputs()[1]).AsNode<IntImm>()->value();
            ExprHandle start = constant(n->inputs()[2]);
            ExprHandle stride = constant(n->inputs()[4]);

            std::vector<ExprHandle> newAxes(axes.begin(), axes.end());
            newAxes[dim] = stride * newAxes[dim] + start;
            return tensorOrConstant(n->inputs()[0], newAxes);
          });
    }

    case aten::unsqueeze: {
      return Compute(
          "aten_unsqueeze",
          dimsFromSizes(sizesForValue(v)),
          [this, v](const std::vector<VarHandle>& axes) {
            auto const& n = v->node();
            int64_t dim = constant(n->inputs()[1]).AsNode<IntImm>()->value();
            if (dim < 0) {
              if (axes.size() == 0) {
                throw malformed_input("axes are zero handling unsqueeze");
              }
              dim += axes.size();
            }
            // To construct an expression for an 'unsqueezed' tensor we need to
            // drop the DIM-th axis, i.e.
            //    unsqueezed_v[i,j,k,l] = v[i,j,l] # dim = 2 - drop index 'k'
            //                 0 1 2 3
            std::vector<ExprHandle> indices;
            int64_t i = 0;
            for (auto a : axes) {
              if (i++ != dim) {
                indices.emplace_back(ExprHandle(a.node()));
              }
            }

            return tensorOrConstant(n->inputs()[0], indices);
          });
    }

    case aten::sum: {
      return computeSum(v);
    }

    case aten::softmax: {
      return computeSoftmax(v, false);
    }

    case aten::log_softmax: {
      return computeSoftmax(v, true);
    }

    default: {
      throw std::runtime_error("Unhandled node kind");
    }
  }
}

Stmt* TensorExprKernel::generateStmt(BackendType backendType) {
  torch::jit::tensorexpr::LoopNest l(tensorOutputs_);
  GRAPH_DEBUG("Original Stmt:\n", std::to_string(l.root_stmt()), "\n");

  bool hasReduction = NodeFinder<ReduceOp>::find(l.root_stmt()).size() != 0;

  l.inlineIntermediateBufs();

  if (backendType == kCudaCodeGen) {
    for (auto tensor : tensorOutputs_) {
      std::vector<For*> loops = l.getLoopStmtsFor(tensor);
      TORCH_INTERNAL_ASSERT(!loops.empty(), "loops should not be empty");
      For* flattened = nullptr;
      LoopNest::flatten(loops, &flattened);
      assert(flattened);

      int loopLevels = getTECudaPointwiseLoopLevels();
      const int kDefaultLoopLevels = 2;
      loopLevels = (loopLevels > 0) ? loopLevels : kDefaultLoopLevels;
      int blockCount = getTECudaPointwiseBlockCount();
      int blockSize = getTECudaPointwiseBlockSize();

      if (loopLevels == 2) {
        For* outer;
        For* inner;
        const int kDefaultBlockSize = 512;
        if (blockSize < 0) {
          blockSize = kDefaultBlockSize;
        }
        l.splitWithMask(flattened, blockSize, &outer, &inner);
        l.setGPUBlockIndex(outer, 0);
        l.setGPUThreadIndex(inner, 0);
      } else if (loopLevels == 3) {
        For* outer;
        For* inner;
        For* inner1;
        For* inner2;
        // TODO: change the number of microprocessors
        const int kDefaultBlockCount = 1280;
        const int kDefaultBlockSize = 256;
        blockCount = (blockCount > 0) ? blockCount : kDefaultBlockCount;
        blockSize = (blockSize > 0) ? blockSize : kDefaultBlockSize;
        l.splitWithMask(flattened, blockCount * blockSize, &outer, &inner);
        l.splitWithMask(inner, blockSize, &inner1, &inner2);
        l.setGPUBlockIndex(inner1, 0);
        l.setGPUThreadIndex(inner2, 0);
      } else {
        throw std::runtime_error(
            "Invalid loop-level: " + c10::to_string(loopLevels));
      }
    }
  }

  if (backendType == kBlockCodeGen) {
    auto block_analysis = std::make_unique<CreateBufferMap>();
    for (auto tensor : tensorOutputs_) {
      const int default_fp16_blocksize = 16;
      const int default_uint8_blocksize = 32;
      int blockSize = default_fp16_blocksize;
      // We only handle looplevels == 2 for now.
      // Run Block analysis to get multi-dim buffer info.
      auto root_stmt = l.root_stmt();
      root_stmt->accept(block_analysis.get());

      if (tensor->buf()->dtype().scalar_type() == ScalarType::Byte) {
        blockSize = default_uint8_blocksize;
      }

      std::vector<For*> loops = l.getLoopStmtsFor(tensor);
      TORCH_INTERNAL_ASSERT(!loops.empty(), "loops should not be empty");
      For* flattened = nullptr;
      LoopNest::flatten(loops, &flattened);
      assert(flattened);

      For* outer = nullptr;
      For* inner = nullptr;
      l.splitWithMask(flattened, blockSize, &outer, &inner);
      l.setGPUBlockIndex(outer, 0);
      l.setGPUThreadIndex(inner, 0);
      l.setBufferMap(outer, block_analysis->getBufferMap());
    }
  }

  l.prepareForCodegen();

  if (backendType == kLLVMCodeGen && !hasReduction) {
    l.vectorizeInnerLoops();
  }

  Stmt* stmt = l.root_stmt();
  // Arithmetic Simplification.
  stmt = IRSimplifier::simplify(stmt);
  GRAPH_DEBUG("Final Stmt:\n", std::to_string(stmt), "\n");
  return stmt;
}

std::string TensorExprKernel::getCodeGenName(BackendType backendType) {
  switch (backendType) {
    case kCudaCodeGen:
      return "cuda_codegen";
    case kLLVMCodeGen:
      return "llvm_codegen";
    case kSimpleIREval:
      return "simple_ir_eval";
    case kBlockCodeGen:
      return "block_codegen";
    default:
      throw std::runtime_error(
          "invalid backend type: " +
          c10::to_string(static_cast<int>(backendType)));
  }
}

std::vector<CodeGen::BufferArg> TensorExprKernel::prepareBufferArgs() {
  std::vector<CodeGen::BufferArg> params;
  for (auto const& arg : kernelArgs_) {
    params.push_back(arg.buffer());
    for (auto const& size : arg.sizes()) {
      params.emplace_back(size.var);
    }
    for (auto const& stride : arg.strides()) {
      params.emplace_back(stride.var);
    }
  }
  for (auto& o : tensorOutputs_) {
    params.emplace_back(o);
  }
  return params;
}

template <typename T>
static bool isValidPrimProperty(const c10::optional<T>& a, T b) {
  return !a.has_value() || *a == b;
}

TensorExprKernel::BackendType TensorExprKernel::inferBackendTypeFromDevice(
    at::Device device) {
  BackendType backendType = BackendType::kUninitialized;
  if (device.type() == at::kCUDA) {
    backendType = kCudaCodeGen;
  } else if (device.type() == at::kCPU && getTEGenerateBlockCode()) {
    backendType = kBlockCodeGen;
  } else if (device.type() == at::kCPU) {
#ifdef TORCH_ENABLE_LLVM
    backendType = kLLVMCodeGen;
#else
    backendType = kSimpleIREval;
#endif
    if (getTEMustUseLLVMOnCPU() && backendType == kSimpleIREval) {
      throw std::runtime_error("LLVM Backend not found");
    }
  } else {
    throw std::runtime_error("Invalid device type");
  }
  return backendType;
}

void TensorExprKernel::bindInput(const torch::jit::Value* input) {
  auto const& t = input->type();
  switch (t->kind()) {
    case TypeKind::TensorType: {
      auto tt = input->type()->cast<TensorType>();
      Placeholder inBuffer(
          "t" + input->debugName(),
          ToDtype(static_cast<ScalarType>(*tt->scalarType())),
          {0});
      std::vector<DimArg> inputTensorDims;
      for (size_t i = 0; i < *tt->sizes().size(); i++) {
        auto const size = *tt->sizes()[i];
        inputTensorDims.emplace_back(
            DimArg(IntImm::make(size), "i" + c10::to_string(i)));
      }
      auto const strides = tt->strides();
      tensors_.emplace(
          input->unique(),
          Compute(
              "input" + c10::to_string(tensors_.size() + 1),
              inputTensorDims,
              [&](const std::vector<VarHandle>& axes) {
                ExprHandle idx = 0;
                for (size_t i = 0; i < axes.size(); i++) {
                  idx = idx + axes[i] * IntImm::make(*strides[i]);
                }
                return inBuffer.load(idx);
              }));
      kernelArgs_.emplace_back(
          inBuffer, std::vector<ShapeArg>(), std::vector<ShapeArg>());
      break;
    }
    case TypeKind::FloatType: {
      VarHandle v("v" + input->debugName(), kFloat);
      kernelArgs_.emplace_back(v);
      scalars_.emplace(input->unique(), v);
      break;
    }
    case TypeKind::BoolType: {
      VarHandle v("v" + input->debugName(), kBool);
      kernelArgs_.emplace_back(v);
      scalars_.emplace(input->unique(), v);
      break;
    }
    case TypeKind::IntType: {
      VarHandle v("v" + input->debugName(), kInt);
      kernelArgs_.emplace_back(v);
      scalars_.emplace(input->unique(), v);
      break;
    }
    default: {
      throw unsupported_dtype();
      break;
    }
  }
}

namespace {

// Remove all indices from axes positions.
std::vector<VarHandle> squeezeIndices(
    const ParameterList& indices,
    const std::vector<size_t>& axes) {
  std::vector<VarHandle> indices_squeezed;
  for (size_t dim = 0; dim < indices.size(); ++dim) {
    if (!std::count(axes.begin(), axes.end(), dim)) {
      indices_squeezed.push_back(indices[dim]);
    }
  }
  return indices_squeezed;
}

} // namespace

Tensor* TensorExprKernel::computeSum(const torch::jit::Value* v) {
  auto reduction_info = getReductionInfo(v->node());
  return Reduce(
      "sum",
      reduction_info.outputDims,
      Sum(),
      [&](ParameterList& indices) {
        const auto& axes = reduction_info.axes;
        // "Squeeze" out indices inserted when keepdim is set.
        auto indices_squeezed =
            reduction_info.keepdim ? squeezeIndices(indices, axes) : indices;
        TORCH_INTERNAL_ASSERT(axes.size() <= indices_squeezed.size());
        // Move innermost indices into axes positions:
        //   1. Fill the outermost indices first.
        //   2. Insert the innermost indices into the correct axis position,
        //      displacing the outermost indices as needed.
        std::vector<ExprHandle> indices_exprs;
        size_t i = 0;
        for (; i < indices_squeezed.size() - axes.size(); ++i) {
          indices_exprs.push_back(indices_squeezed[i]);
        }
        for (auto axis : axes) {
          indices_exprs.insert(
              indices_exprs.begin() + axis, indices_squeezed[i]);
          ++i;
        }
        auto indexed = tensorOrConstant(v->node()->input(0), indices_exprs);
        if (reduction_info.dtype) {
          return Cast::make(*reduction_info.dtype, indexed);
        } else {
          return indexed;
        }
      },
      reduction_info.reductionDims);
}
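
// Illustrative example (rank made up): summing a rank-3 input over axis 1
// with keepdim = false gives outputDims = {d0, d2} and reductionDims = {d1}.
// The Reduce body then receives indices {i0, i2, r} (reduction index
// innermost), and the loops above rebuild them as {i0, r, i2} before
// indexing the input.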

Tensor* TensorExprKernel::computeSoftmax(
    const torch::jit::Value* v,
    bool log_softmax) {
  // Softmax is computed as follows:
  //    softmax(vi) = exp(vi) / sum(exp(vi))
  //
  // In order to avoid overflow issues due to exp of a large number, we
  // subtract the max of that dim before computing exp.
  //    softmax(vi) = exp(vi - max(vi)) / sum(exp(vi - max(vi)))
  //
  // This is implemented as 4 loopnests:
  //   - First loop computes the max over the softmax dim.
  //   - Second loop computes exp for every element in v after subtracting
  //     the max of the softmax dim it belongs to.
  //   - Third loop computes the sum over the softmax dim.
  //   - Final loop computes softmax for every element in v.

  // LogSoftmax is computed as follows:
  //    log_softmax(vi) = log(softmax(vi))
  //                    = vi - log(sum(exp(vi)))
  //
  // Using the same max trick as above:
  //    log_softmax(vi) = vi - max(vi) - log(sum(exp(vi - max(vi))))
  //
  // This is implemented as 5 loopnests:
  //   - First loop computes the max over the softmax dim.
  //   - Second loop computes exp for every element in v after subtracting
  //     the max of the softmax dim it belongs to.
  //   - Third loop computes the sum over the softmax dim.
  //   - Fourth loop computes log for every element in the sum.
  //   - Final loop computes the log_softmax for every element in v.

  TORCH_INTERNAL_ASSERT(v->node()->inputs().size() == 3);
  auto output_dims = dimsFromSizes(sizesForValue(v));

  // We do not handle None for dims (input 1) because that is supposed to
  // be deprecated.
  TORCH_INTERNAL_ASSERT(v->node()->input(1)->node()->kind() == prim::Constant);
  size_t softmax_dim = v->node()->input(1)->node()->i(attr::value);
  TORCH_INTERNAL_ASSERT(softmax_dim < output_dims.size());

  std::vector<DimArg> non_softmax_dims;
  for (size_t i = 0; i < output_dims.size(); ++i) {
    if (i != softmax_dim) {
      non_softmax_dims.push_back(output_dims[i]);
    }
  }

  // Softmax implementation includes two reductions, one to find the max and
  // the other to calculate the sum along the softmax dim. These reductions
  // will have the softmax dimension as the inner most loop. So, the innermost
  // index in the indices will refer to the softmax dimension.

  // Update the indices by moving the softmax dimension index to the
  // appropriate position.
  auto move_softmax_dim_index_to_pos = [&](const ParameterList& indices) {
    std::vector<ExprHandle> new_indices;
    for (auto ind : indices) {
      new_indices.push_back(ind);
    }
    for (size_t i = softmax_dim; i < indices.size() - 1; ++i) {
      new_indices[i + 1] = indices[i];
    }
    new_indices[softmax_dim] = indices[indices.size() - 1];
    return new_indices;
  };

  // Remove the index corresponding to the softmax dimension.
  auto remove_softmax_dim_index = [&](const ParameterList& indices) {
    std::vector<ExprHandle> new_indices;
    for (size_t i = 0; i < indices.size(); ++i) {
      if (i != softmax_dim) {
        new_indices.push_back(indices[i]);
      }
    }
    return new_indices;
  };

  auto convert_indices_to_expr_handle = [&](const ParameterList& indices) {
    std::vector<ExprHandle> new_indices(indices.size());
    for (size_t i = 0; i < indices.size(); ++i) {
      new_indices[i] = indices[i];
    }
    return new_indices;
  };

  c10::optional<Dtype> dtype = ToDtype(ScalarType::None);
  auto maybe_dtype = v->node()->get(attr::dtype);
  if (maybe_dtype && !maybe_dtype->isNone()) {
    dtype = ToDtype(static_cast<ScalarType>(maybe_dtype->toInt()));
  }

  auto max = Reduce(
      "aten_softmax_max",
      non_softmax_dims,
      Maximum(dtype.value()),
      [&](ParameterList& indices) {
        return tensorOrConstant(
            v->node()->inputs()[0], move_softmax_dim_index_to_pos(indices));
      },
      {output_dims[softmax_dim]});
  auto e =
      Compute("aten_softmax_exp", output_dims, [&](ParameterList& indices) {
        auto inp = tensorOrConstant(
            v->node()->inputs()[0], convert_indices_to_expr_handle(indices));
        return exp(inp - max->call(remove_softmax_dim_index(indices)));
      });
  auto sum = Reduce(
      "aten_softmax_sum",
      non_softmax_dims,
      Sum(),
      [&](ParameterList& indices) {
        return e->call(move_softmax_dim_index_to_pos(indices));
      },
      {output_dims[softmax_dim]});
  if (!log_softmax) {
    return Compute("aten_softmax", output_dims, [&](ParameterList& indices) {
      return e->call(indices) / sum->call(remove_softmax_dim_index(indices));
    });
  }

  auto log_sum = Compute(
      "aten_softmax_log_sum", non_softmax_dims, [&](ParameterList& indices) {
        return log(sum->call(indices));
      });
  return Compute("aten_log_softmax", output_dims, [&](ParameterList& indices) {
    auto inp = tensorOrConstant(
        v->node()->inputs()[0], convert_indices_to_expr_handle(indices));
    auto non_softmax_indices = remove_softmax_dim_index(indices);
    return inp - max->call(non_softmax_indices) -
        log_sum->call(non_softmax_indices);
  });
}

TensorExprKernel::ReductionInfo TensorExprKernel::getReductionInfo(
    const torch::jit::Node* node) {
  std::vector<size_t> axes;
  bool keepdim = false;
  // aten::sum takes the input tensor named self.
  auto sizes = sizesForValue(node->namedInput(attr::self));
  const auto inputs = node->inputs();
  int rank = sizes.size();
  if (inputs.size() > 2) {
    auto nodeAxes = getReductionAxes(node);
    // Canonicalize axes: wrap around, sort and make unique.
    for (auto axis : nodeAxes) {
      axes.push_back(at::maybe_wrap_dim(axis, rank));
    }
    std::sort(axes.begin(), axes.end());
    axes.erase(std::unique(axes.begin(), axes.end()), axes.end());
    keepdim = node->get(attr::keepdim)->toBool();
  } else {
    axes.resize(sizes.size());
    std::iota(axes.begin(), axes.end(), 0);
  }
  // Axes go into reduction dimensions.
  std::vector<DimArg> reductionDims;
  reductionDims.reserve(sizes.size());
  for (size_t axis : axes) {
    reductionDims.emplace_back(sizes[axis]);
  }
  auto allDims = dimsFromSizes(sizes);
  std::vector<DimArg> outputDims;
  // Output dimensions are the complement of axes. When keepdim is set, a
  // one-sized dimension is inserted for each axis.
  for (size_t dim = 0; dim < allDims.size(); ++dim) {
    if (!std::count(axes.begin(), axes.end(), dim)) {
      outputDims.emplace_back(sizes[dim]);
    } else if (keepdim) {
      outputDims.emplace_back(1);
    }
  }
  c10::optional<Dtype> dtype;
  auto dtypeValue = node->get(attr::dtype);
  if (!dtypeValue->isNone()) {
    auto scalarType = static_cast<ScalarType>(dtypeValue->toInt());
    dtype = ToDtype(scalarType);
  }
  return {reductionDims, outputDims, axes, keepdim, dtype};
}
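
// Illustrative example (shape made up): for aten::sum over dim 1 of a
// [2, 3, 4] input with keepdim = true, axes = {1}, reductionDims = {3},
// and outputDims = {2, 1, 4}; with keepdim = false the size-1 placeholder
// is dropped and outputDims = {2, 4}.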

std::vector<int64_t> TensorExprKernel::getReductionAxes(
    const torch::jit::Node* node) {
  std::vector<int64_t> axes;
  auto axesNode = node->namedInput(attr::dim)->node();
  // There are two possible representations for reduction axes:
  //   1. A prim::ListConstruct of integer constants.
  //   2. A prim::Constant list of integer ival's.
  // We need to handle both of them.
  if (axesNode->kind() == prim::ListConstruct) {
    for (auto axisNode : axesNode->inputs()) {
      axes.push_back(constant(axisNode).AsNode<IntImm>()->value());
    }
    return axes;
  }
  TORCH_INTERNAL_ASSERT(axesNode->kind() == prim::Constant);
  TORCH_INTERNAL_ASSERT(axesNode->kindOf(attr::value) == AttributeKind::ival);
  const auto& genericList = axesNode->ival(attr::value).toList();
  for (const IValue axisNode : genericList) {
    axes.push_back(axisNode.toInt());
  }
  return axes;
}

void TensorExprKernel::compile() {
  KernelScope kernelScope(&kernelArena_);
  GRAPH_DUMP("TensorExprKernel graph:", graph_);
  // Bind inputs to buffers.
  nInputs_ = graph_->inputs().size();
  for (auto const& input : graph_->inputs()) {
    bindInput(input);
    inputTypes_.push_back(input->type());
  }

  // Bind nodes to tensor compute expressions.
  for (auto const& n : graph_->nodes()) {
    if (n->kind() == prim::Constant || n->kind() == prim::ListConstruct) {
      continue;
    } else {
      for (auto const& output : n->outputs()) {
        if (output->hasUses()) {
          tensors_.emplace(output->unique(), computeValue(output));
        }
      }
    }
    if (hasRandom_ && hasBroadcast_) {
      throw std::runtime_error(
          "Cannot support broadcast and random within one kernel");
    }
  }

  // Move output operands from `tensors_` to `tensorOutputs_`.
  for (const auto& output : graph_->outputs()) {
    if (!tensors_.count(output->unique())) {
      throw malformed_input("cannot find output Tensor");
    }
    tensorOutputs_.emplace_back(tensors_.at(output->unique()));
    tensors_.erase(output->unique());
  }

  device_ = *pickDeviceType(graph_->inputs());
  BackendType backendType = inferBackendTypeFromDevice(device_);
  Stmt* stmt = generateStmt(backendType);
  // Set up formal params (inputs, then outputs) for kernel.
  std::vector<CodeGen::BufferArg> params = prepareBufferArgs();

  // Generate code.
  codegen_ = CreateCodeGen(
      getCodeGenName(backendType),
      stmt,
      params,
      device_,
      SubgraphUtils::generateNameForGraph(graph_));
}

TensorExprKernel::TensorExprKernel(const std::shared_ptr<Graph>& subgraph)
    : graph_(subgraph), code_(subgraph, "") {
  if (!fallbackAllowed()) {
    compile();
    return;
  }

  try {
    compile();
  } catch (...) {
    fallback_ = true;
  }
}

void TensorExprKernel::run(Stack& stack) {
  if (fallbackEnforced()) {
    fallback(stack);
    return;
  }
  if (!fallbackAllowed()) {
    runKernel(stack);
    return;
  }

  if (fallback_) {
    fallback(stack);
    return;
  }
  try {
    runKernel(stack);
  } catch (...) {
    fallback_ = true;
    fallback(stack);
  }
}

std::vector<CodeGen::CallArg> TensorExprKernel::prepareRunArgs(
    const at::ArrayRef<IValue>& inputs,
    std::vector<at::Tensor>& outputs) {
  std::map<const Expr*, int32_t> varToSize;

  std::vector<CodeGen::CallArg> runArgs;
  for (size_t i = 0; i < inputs.size(); i++) {
    auto const& input = inputs[i];
    if (input.isInt()) {
      runArgs.emplace_back((int32_t)input.toInt());
    } else if (input.isDouble()) {
      runArgs.emplace_back((float)input.toDouble());
    } else if (input.isTensor()) {
      auto const& tensor = input.toTensor();
      runArgs.emplace_back(tensor.data_ptr());
      for (auto const& size : kernelArgs_[i].sizes()) {
        int32_t s = tensor.sizes()[size.idx];
        runArgs.emplace_back(s);
        varToSize[size.var.node()] = s;
      }
      for (auto const& stride : kernelArgs_[i].strides()) {
        int32_t s = tensor.strides()[stride.idx];
        runArgs.emplace_back(s);
      }
    }
  }

  for (auto& o : tensorOutputs_) {
    std::vector<int64_t> tensorSize;
    for (const Expr* dim : o->dims()) {
      auto it = varToSize.find(dim);
      if (it != varToSize.end()) {
        tensorSize.push_back(it->second);
      } else {
        const IntImm* s = dynamic_cast<const IntImm*>(dim);
        if (!s) {
          throw malformed_input("output expected Int", dim);
        }
        tensorSize.push_back(s->value());
      }
    }

    outputs.push_back(at::empty(
        tensorSize, c10::TensorOptions(tensorType(o)).device(device_)));
    runArgs.emplace_back(outputs.back().data_ptr());
  }
  return runArgs;
}

Stmt* TensorExprKernel::getCodeGenStmt() {
  return codegen_->stmt();
}

void TensorExprKernel::runKernel(Stack& stack) {
  KernelScope kernelScope(&kernelArena_);

  // Set up arguments (inputs, then outputs) for kernel call.
  auto inputs = last(stack, nInputs_);
  std::vector<at::Tensor> outputs;

  std::vector<CodeGen::CallArg> runArgs = prepareRunArgs(inputs, outputs);

  // Call the kernel.
  codegen_->call(runArgs);

  // Update the stack.
  drop(stack, nInputs_);
  for (auto& o : outputs) {
    push_one(stack, std::move(o));
  }
}