#include <torch/csrc/jit/tensorexpr/kernel.h>

#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>

using namespace torch::jit;
using namespace torch::jit::tensorexpr;

namespace torch {
namespace jit {
namespace tensorexpr {

static int te_cuda_pointwise_loop_levels = -1;
static int te_cuda_pointwise_block_count = -1;
static int te_cuda_pointwise_block_size = -1;
static bool fallback_allowed = false;
static bool te_generate_block_code = false;
static bool te_must_use_llvm_on_cpu = false;

bool setFallbackAllowed(bool value) {
  bool old_value = fallback_allowed;
  fallback_allowed = value;
  return old_value;
}

bool fallbackAllowed() {
  static const char* enable_c_str = std::getenv("PYTORCH_TENSOREXPR_FALLBACK");
  if (!enable_c_str) {
    return fallback_allowed;
  }
  if (std::string(enable_c_str) == "0") {
    return false;
  }
  return true;
}

bool fallbackEnforced() {
  static const char* enable_c_str = std::getenv("PYTORCH_TENSOREXPR_FALLBACK");
  if (!enable_c_str) {
    return fallback_allowed;
  }
  if (std::string(enable_c_str) == "2") {
    return true;
  }
  return false;
}

int& getTECudaPointwiseLoopLevels() {
  return te_cuda_pointwise_loop_levels;
}

int& getTECudaPointwiseBlockCount() {
  return te_cuda_pointwise_block_count;
}

int& getTECudaPointwiseBlockSize() {
  return te_cuda_pointwise_block_size;
}

// TODO: Remove this global var
// Ideally Block code gen should be decided
// based on device type in tensor.
bool& getTEGenerateBlockCode() {
  return te_generate_block_code;
}

bool& getTEMustUseLLVMOnCPU() {
  return te_must_use_llvm_on_cpu;
}

c10::optional<at::Device> pickDeviceType(
    const at::ArrayRef<torch::jit::Value*>& inputs) {
  c10::optional<at::Device> device = c10::nullopt;
  for (auto const& input : inputs) {
    auto tt = input->type()->cast<TensorType>();
    if (tt && tt->device()) {
      if (device && *device != *tt->device()) {
        return c10::nullopt;
      }
      device = *tt->device();
    }
  }
  return device;
}

} // namespace tensorexpr
} // namespace jit
} // namespace torch

static at::ScalarType tensorType(Tensor* t) {
  return static_cast<at::ScalarType>(t->body()->dtype().scalar_type());
}

static std::vector<ExprHandle> computeIndicesToBroadcast(
    const std::vector<ExprHandle>& outputAxes,
    const std::vector<ExprHandle>& inputSizes) {
  if (outputAxes.size() < inputSizes.size()) {
    throw malformed_input("Cannot broadcast to a lower rank tensor");
  }

  std::vector<ExprHandle> bcast;
  auto axisIt = outputAxes.rbegin();
  auto sizeIt = inputSizes.rbegin();
  while (sizeIt != inputSizes.rend()) {
    auto const& size = sizeIt->AsNode<IntImm>();
    if (size && size->value() == 1) {
      bcast.emplace_back(0);
    } else {
      bcast.emplace_back(*axisIt);
    }
    ++axisIt;
    ++sizeIt;
  }
  std::reverse(bcast.begin(), bcast.end());
  return bcast;
}

ExprHandle TensorExprKernel::broadcast(
    Tensor* t,
    const std::vector<ExprHandle>& axes) {
  return t->call(computeIndicesToBroadcast(
      axes, ExprVectorToExprHandleVector(t->buf()->dims())));
}

ExprHandle TensorExprKernel::chunk(
    Tensor* t,
    size_t chunkIdx,
    int64_t dim,
    int64_t chunks,
    const std::vector<ExprHandle>& axes) {
  if (dim < 0) {
    dim = axes.size() + dim;
  }
  auto sizes = bufferSizes(t);
  size_t step = sizes[dim] / chunks;

  std::vector<ExprHandle> indices;
  for (size_t i = 0; i < axes.size(); ++i) {
    if (i == dim) {
      indices.push_back(axes[i] + IntImm::make((int)chunkIdx * (int)step));
    } else {
      indices.push_back(axes[i]);
    }
  }

  return t->call(indices);
}

ExprHandle promoteToDtype(ExprHandle e, ScalarType dt) {
  if (e.dtype().scalar_type() == dt) {
    return e;
  }

  switch (dt) {
// NOLINTNEXTLINE
#define TYPE_CASE(Type, Name) \
  case ScalarType::Name:      \
    e = cast<Type>(e);        \
    break;
    AT_FORALL_SCALAR_TYPES_AND2(Half, Bool, TYPE_CASE);
#undef TYPE_CASE
    default:
      throw unsupported_dtype();
  }
  return e;
}
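// tensorOrConstant returns an expression for `v` at the given output axes:
// if `v` maps to an already-lowered tensor, it emits a (possibly broadcast)
// call into that tensor, otherwise `v` is treated as a scalar constant.
// For example, an input of sizes [5, 1] consumed at output axes [i, j] is
// indexed as input[i, 0], per computeIndicesToBroadcast above.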
ExprHandle TensorExprKernel::tensorOrConstant(
    const torch::jit::Value* v,
    const std::vector<ExprHandle>& axes) {
  auto ti = tensors_.find(v->unique());
  if (ti != tensors_.end()) {
    return broadcast(ti->second, axes);
  }
  return constant(v);
}

std::vector<ExprHandle> TensorExprKernel::sizesFromVaryingShape(
    const c10::VaryingShape<int64_t>& shape) {
  std::vector<ExprHandle> dims;
  for (size_t i = 0; i < *shape.size(); i++) {
    dims.push_back(IntImm::make(*shape[i]));
  }
  return dims;
}

std::vector<DimArg> TensorExprKernel::dimsFromSizes(
    const std::vector<ExprHandle>& sizes) {
  std::vector<DimArg> dimArgs;
  for (size_t idx = 0; idx < sizes.size(); idx++) {
    dimArgs.emplace_back(DimArg(sizes[idx], "i" + c10::to_string(idx)));
  }
  return dimArgs;
}

std::vector<ExprHandle> TensorExprKernel::sizesForValue(
    const torch::jit::Value* v) {
  if (known_sizes_.count(v)) {
    return known_sizes_.at(v);
  }

  // If the shape is present in the type info, just extract it from here. No
  // need to infer it.
  if (v->type()->kind() == TypeKind::TensorType) {
    auto tt = v->type()->cast<TensorType>();
    if (tt->isComplete()) {
      return sizesFromVaryingShape(tt->sizes());
    }
  }

  if (v->type()->isSubtypeOf(FloatType::get()) ||
      v->type()->isSubtypeOf(IntType::get())) {
    return {1};
  }
  if (v->type()->isSubtypeOf(NoneType::get())) {
    return {};
  }

  known_sizes_[v] = inferSizesForValue(v);
  return known_sizes_.at(v);
}

std::vector<ExprHandle> TensorExprKernel::inferSizesForValue(
    const torch::jit::Value* v) {
  switch (v->node()->kind()) {
    case aten::_cast_Float:
    case aten::sigmoid:
    case aten::reciprocal:
    case aten::neg:
    case aten::relu:
    case aten::log:
    case aten::log10:
    case aten::log1p:
    case aten::log2:
    case aten::exp:
    case aten::expm1:
    case aten::erf:
    case aten::erfc:
    case aten::cos:
    case aten::sin:
    case aten::tan:
    case aten::rand_like:
    case aten::acos:
    case aten::asin:
    case aten::cosh:
    case aten::sinh:
    case aten::atan:
    case aten::tanh:
    case aten::sqrt:
    case aten::rsqrt:
    case aten::abs:
    case aten::ceil:
    case aten::floor:
    case aten::round:
    case aten::trunc:
    case aten::frac:
    case aten::lgamma:
    case aten::type_as:
      return sizesForValue(v->node()->input(0));

    case aten::sub:
    case aten::add:
    case aten::mul:
    case aten::div:
    case aten::__and__:
    case aten::__or__:
    case aten::__xor__:
    case aten::__lshift__:
    case aten::__rshift__:
    case aten::eq:
    case aten::ne:
    case aten::ge:
    case aten::gt:
    case aten::le:
    case aten::lt:
    case aten::min:
    case aten::max:
    case aten::pow:
    case aten::fmod:
    case aten::remainder:
    case aten::atan2: {
      std::vector<std::vector<ExprHandle>> shapes;
      for (size_t idx = 0; idx < 2; idx++) {
        torch::jit::Value* inp = v->node()->input(idx);
        shapes.push_back(sizesForValue(inp));
      }
      return broadcastShapes(shapes);
    }
    case aten::lerp:
    case aten::clamp:
    case aten::threshold:
    case aten::where: {
      std::vector<std::vector<ExprHandle>> shapes;
      for (size_t idx = 0; idx < 3; idx++) {
        torch::jit::Value* inp = v->node()->input(idx);
        shapes.push_back(sizesForValue(inp));
      }
      return broadcastShapes(shapes);
    }

    case aten::addcmul: {
      std::vector<std::vector<ExprHandle>> shapes;
      for (size_t idx = 0; idx < 4; idx++) {
        torch::jit::Value* inp = v->node()->input(idx);
        shapes.push_back(sizesForValue(inp));
      }
      return broadcastShapes(shapes);
    }
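    // prim::ConstantChunk splits one dimension into `chunks` equal pieces, so
    // each output keeps the input shape except along `dim`, which is divided
    // by `chunks`. For example, chunking a [4, 8] tensor into 2 chunks along
    // dim=1 yields outputs of shape [4, 4].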
    case prim::ConstantChunk: {
      auto shape = sizesForValue(v->node()->input());
      int dim = v->node()->i(attr::dim);
      int chunks = v->node()->i(attr::chunks);
      shape[dim] = IRSimplifier::simplify(shape[dim] / chunks);
      return shape;
    }

    case aten::unsqueeze: {
      auto const& n = v->node();
      auto shape = sizesForValue(n->input(0));

      int64_t dim = constant(n->input(1)).AsNode<IntImm>()->value();
      // From the documentation
      // (https://pytorch.org/docs/master/generated/torch.unsqueeze.html):
      //
      // A dim value within the range [-input.dim() - 1, input.dim() + 1) can
      // be used. Negative dim will correspond to unsqueeze() applied at
      // dim = dim + input.dim() + 1.
      if (dim < 0) {
        dim = dim + shape.size() + 1;
      }
      if (dim < 0 || dim > shape.size()) {
        throw std::runtime_error("Invalid 'dim' input in aten::unsqueeze");
      }

      shape.insert(shape.begin() + dim, ExprHandle(1));
      return shape;
    }

    case aten::cat: {
      // In JIT IR, aten::cat usually appears with the following nodes around
      // it:
      //   %dim : int = prim::Constant[value=0]()
      //   %inputs : Tensor[] = prim::ListConstruct(%a, %b, ...)
      //   %cat_output : Tensor = aten::cat(%inputs, %dim)
      // Shapes of the input tensors could only differ at the dimension %dim.
      // The size of the output tensor on that dimension is a sum of the
      // corresponding sizes of the input tensors; the other dimensions have
      // the same sizes.
      // Negative dim will correspond to dim = dim + input.dim().
      auto const& n = v->node();
      auto inputs = n->input(0)->node()->inputs();
      if (inputs.size() == 0) {
        throw std::runtime_error("Empty input list is passed to aten::cat");
      }

      TORCH_INTERNAL_ASSERT(n->input(1)->node()->kind() == prim::Constant);
      int64_t dim = n->input(1)->node()->i(attr::value);
      auto shape = sizesForValue(inputs[0]);
      if (dim < 0) {
        dim += shape.size();
      }
      if (dim < 0 || dim > shape.size()) {
        throw std::runtime_error("Invalid 'dim' input in aten::cat");
      }

      ExprHandle concat_dim_size = 0;
      for (auto input : inputs) {
        concat_dim_size = concat_dim_size + sizesForValue(input)[dim];
      }
      concat_dim_size = IRSimplifier::simplify(concat_dim_size);
      shape[dim] = concat_dim_size;
      return shape;
    }

    case aten::softmax:
    case aten::log_softmax:
      // Output of softmax / log_softmax has the same shape as input 0.
      return sizesForValue(v->node()->input(0));

    case aten::slice:
      throw std::runtime_error(
          "Shape info is not implemented for this kind of node");

    default: {
      GRAPH_DEBUG("Can't infer sizes for the node: ", *v->node());
      GRAPH_DEBUG("Full fusion group graph:\n", *v->node()->owningGraph());
      throw std::runtime_error("Unhandled node kind");
    }
  }
}

ExprHandle TensorExprKernel::constant(const torch::jit::Value* v) {
  if (v->node()->kind() == prim::Constant) {
    const auto val = toIValue(v).value();
    if (val.isDouble()) {
      return FloatImm::make(static_cast<float>(val.toDouble()));
    } else if (val.isInt()) {
      return IntImm::make(val.toInt());
    } else if (val.isBool()) {
      return BoolImm::make(val.toBool());
    } else if (val.isNone()) {
      // This is just a placeholder so we don't throw. None-handling is
      // operator-specific and should be handled properly in the
      // operator-specific lowering code.
      return IntImm::make(0);
    } else {
      throw unsupported_dtype();
    }
  }

  if (!scalars_.count(v->unique())) {
    throw malformed_input("no scalar in Constant");
  }

  return scalars_.at(v->unique());
}

ExprHandle promoteIntegerToFloat(const ExprHandle& e) {
  auto scalarType = static_cast<c10::ScalarType>(e.dtype().scalar_type());
  if (!c10::isIntegralType(scalarType, /*includeBool*/ true)) {
    return e;
  }
  auto defaultType = static_cast<tensorexpr::ScalarType>(
      c10::typeMetaToScalarType(c10::get_default_dtype()));
  return Cast::make(Dtype(defaultType, e.dtype().lanes()), e);
}

void TensorExprKernel::promoteInputs(std::vector<ExprHandle>& inputs) {
  if (inputs.empty()) {
    return;
  }

  // Find the highest type among the inputs.
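  // (c10::promoteTypes follows the usual type-promotion lattice: for example,
  // an Int operand mixed with a Float operand promotes both to Float, and a
  // Bool operand combined with integral inputs promotes to Int.)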
ScalarType highType = inputs[0].dtype().scalar_type(); for (const auto input : inputs) { highType = promoteTypes(highType, input.dtype().scalar_type()); } for (ExprHandle& e : inputs) { e = promoteToDtype(e, highType); } } ExprHandle TensorExprKernel::demoteOutput( const ExprHandle& e, const torch::jit::Value* v) { if (v->type()->kind() != TypeKind::TensorType) { return e; } if (!v->isCompleteTensor()) { return e; } auto tt = *v->type()->cast()->scalarType(); if (tt == static_cast(e.dtype().scalar_type())) { return e; } switch (tt) { // NOLINTNEXTLINE #define TYPE_CASE(Type, Name) \ case at::ScalarType::Name: \ return cast(e); AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE); #undef TYPE_CASE case at::ScalarType::Bool: return cast(e); default: throw unsupported_dtype(); } return e; } static bool isOne(ExprHandle e) { auto const& n = e.AsNode(); if (!n) { return false; } return n->value() == 1; } std::vector TensorExprKernel::broadcastShapes( std::vector> shapes) { size_t n = shapes.size(); if (n == 1) { return shapes[0]; } auto res1 = broadcastShapes(shapes[n - 2], shapes[n - 1]); shapes[n - 2] = res1; shapes.pop_back(); auto res2 = broadcastShapes(shapes); return res2; } std::vector TensorExprKernel::broadcastShapes( const std::vector& a, const std::vector& b) { auto at = a.rbegin(); auto bt = b.rbegin(); std::vector ret; while (at != a.rend() || bt != b.rend()) { if (at == a.rend()) { hasBroadcast_ = true; ret.push_back(*bt++); continue; } if (bt == b.rend()) { hasBroadcast_ = true; ret.push_back(*at++); continue; } // TODO: if neither *at nor *bt is 1, ensure they are identical // expressions. Nb: `==` doesn't work since that simply produces a new // ExprHandle. ExprHandle dim = *at; if (isOne(*at)) { if (!isOne(*bt)) { dim = *bt; hasBroadcast_ = true; } } ret.push_back(dim); at++; bt++; } std::reverse(ret.begin(), ret.end()); return ret; } std::vector TensorExprKernel::valueShape( const torch::jit::Value* v) { auto it = tensors_.find(v->unique()); if (it == tensors_.end()) { return {}; } return ExprVectorToExprHandleVector(it->second->dims()); } Tensor* TensorExprKernel::computeOneOperand( const std::string& name, const torch::jit::Value* v, const std::function& innerExpr) { auto const& n = v->node(); auto const& shape = valueShape(n->inputs()[0]); return Compute( name, c10::fmap(shape), [this, v, innerExpr](const std::vector& axes) { auto const& n = v->node(); std::vector indices(axes.begin(), axes.end()); std::vector inputs = { tensorOrConstant(n->inputs()[0], indices)}; promoteInputs(inputs); ExprHandle compute = innerExpr(inputs[0]); return demoteOutput(compute, n->output()); }); } Tensor* TensorExprKernel::computeTwoOperand( const std::string& name, const torch::jit::Value* v, const std::function& innerExpr) { auto const& n = v->node(); auto const& shape = broadcastShapes(valueShape(n->inputs()[0]), valueShape(n->inputs()[1])); return Compute( name, c10::fmap(shape), [this, v, innerExpr](const std::vector& axes) { auto const& n = v->node(); std::vector indices(axes.begin(), axes.end()); std::vector inputs = { tensorOrConstant(n->inputs()[0], indices), tensorOrConstant(n->inputs()[1], indices), }; promoteInputs(inputs); ExprHandle compute = innerExpr(inputs[0], inputs[1]); return demoteOutput(compute, n->output()); }); } Tensor* TensorExprKernel::computeTwoOperandWithAlpha( const std::string& name, const torch::jit::Value* v, const std::function& innerExpr) { auto const& n = v->node(); auto const& shape = broadcastShapes(valueShape(n->inputs()[0]), valueShape(n->inputs()[1])); return 
Compute( name, c10::fmap(shape), [this, v, innerExpr](const std::vector& axes) { auto const& n = v->node(); std::vector indices(axes.begin(), axes.end()); std::vector inputs = { tensorOrConstant(n->inputs()[0], indices), tensorOrConstant(n->inputs()[1], indices), tensorOrConstant(n->inputs()[2], indices), }; promoteInputs(inputs); ExprHandle compute = innerExpr(inputs[0], inputs[2] * inputs[1]); return demoteOutput(compute, n->output()); }); } Tensor* TensorExprKernel::computeConditionWithTwoOperand( const std::string& name, const torch::jit::Value* v, const std::function< ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>& innerExpr) { auto const& n = v->node(); std::vector> shapes; for (size_t idx = 0; idx < 2; idx++) { torch::jit::Value* inp = n->input(idx); shapes.push_back(sizesForValue(inp)); } auto const& shape = broadcastShapes(shapes); return Compute( name, c10::fmap(shape), [this, v, innerExpr](const std::vector& axes) { auto const& n = v->node(); std::vector indices(axes.begin(), axes.end()); std::vector inputs = { tensorOrConstant(n->inputs()[1], indices), tensorOrConstant(n->inputs()[2], indices), }; promoteInputs(inputs); // First expr is the condition, which we don't promote inputs.emplace( inputs.begin(), tensorOrConstant(n->inputs()[0], indices)); ExprHandle compute = innerExpr(inputs[0], inputs[1], inputs[2]); return demoteOutput(compute, n->output()); }); } Tensor* TensorExprKernel::computeThreeOperand( const std::string& name, const torch::jit::Value* v, const std::function< ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>& innerExpr) { auto const& n = v->node(); std::vector> shapes; for (size_t idx = 0; idx < 3; idx++) { torch::jit::Value* inp = n->input(idx); shapes.push_back(sizesForValue(inp)); } auto const& shape = broadcastShapes(shapes); return Compute( name, c10::fmap(shape), [this, v, innerExpr](const std::vector& axes) { auto const& n = v->node(); std::vector indices(axes.begin(), axes.end()); std::vector inputs = { tensorOrConstant(n->inputs()[0], indices), tensorOrConstant(n->inputs()[1], indices), tensorOrConstant(n->inputs()[2], indices), }; promoteInputs(inputs); ExprHandle compute = innerExpr(inputs[0], inputs[1], inputs[2]); return demoteOutput(compute, n->output()); }); } Tensor* TensorExprKernel::computeFourOperand( const std::string& name, const torch::jit::Value* v, const std::function& innerExpr) { auto const& n = v->node(); std::vector> shapes; for (size_t idx = 0; idx < 4; idx++) { torch::jit::Value* inp = n->input(idx); shapes.push_back(sizesForValue(inp)); } auto const& shape = broadcastShapes(shapes); return Compute( name, c10::fmap(shape), [this, v, innerExpr](const std::vector& axes) { auto const& n = v->node(); std::vector indices(axes.begin(), axes.end()); std::vector inputs = { tensorOrConstant(n->inputs()[0], indices), tensorOrConstant(n->inputs()[1], indices), tensorOrConstant(n->inputs()[2], indices), tensorOrConstant(n->inputs()[3], indices), }; promoteInputs(inputs); ExprHandle compute = innerExpr(inputs[0], inputs[1], inputs[2], inputs[3]); return demoteOutput(compute, n->output()); }); } namespace { // Convert boolean to integer, if needed. ExprHandle boolToInteger(const ExprHandle& x) { return x.dtype().scalar_type() == ScalarType::Bool ? 
cast(x) : x; } } // namespace c10::optional findDtypeForValue(const torch::jit::Value* v) { if (v->type()->kind() == TypeKind::TensorType) { auto tt = v->type()->cast(); if (tt->scalarType()) { return static_cast(*tt->scalarType()); } } return c10::nullopt; } Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { switch (v->node()->kind()) { case aten::add: { auto add_lambda = [](const ExprHandle& lhs, const ExprHandle& rhs) { return boolToInteger(lhs) + boolToInteger(rhs); }; TORCH_INTERNAL_ASSERT( v->node()->inputs().size() == 2 || v->node()->inputs().size() == 3); return (v->node()->inputs().size() > 2) ? computeTwoOperandWithAlpha("aten_add", v, add_lambda) : computeTwoOperand("aten_add", v, add_lambda); } break; case aten::_cast_Float: { return computeOneOperand("aten_cast_float", v, [](const ExprHandle& a) { return cast(a); }); } break; case aten::sub: { auto sub_lambda = [](const ExprHandle& lhs, const ExprHandle& rhs) { // NB: sub isn't supported on boolean, no need to promote to integer. return lhs - rhs; }; TORCH_INTERNAL_ASSERT( v->node()->inputs().size() == 2 || v->node()->inputs().size() == 3); return (v->node()->inputs().size() > 2) ? computeTwoOperandWithAlpha("aten_sub", v, sub_lambda) : computeTwoOperand("aten_sub", v, sub_lambda); } break; case aten::mul: { return computeTwoOperand( "aten_mul", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { return boolToInteger(lhs) * boolToInteger(rhs); }); } break; case aten::div: { return computeTwoOperand( "aten_div", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { return boolToInteger(lhs) / boolToInteger(rhs); }); } break; case aten::__and__: { return computeTwoOperand( "aten_and", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { return boolToInteger(lhs) & boolToInteger(rhs); }); } break; case aten::__or__: { return computeTwoOperand( "aten_or", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { return boolToInteger(lhs) | boolToInteger(rhs); }); } break; case aten::__xor__: { return computeTwoOperand( "aten_xor", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { return boolToInteger(lhs) ^ boolToInteger(rhs); }); } break; case aten::__lshift__: { return computeTwoOperand( "aten_lshift", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { return lhs << rhs; }); } break; case aten::__rshift__: { return computeTwoOperand( "aten_rshift", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { return lhs >> rhs; }); } break; case aten::addcmul: { return computeFourOperand( "aten_addcmul", v, [](const ExprHandle& a0, const ExprHandle& a1, const ExprHandle& a2, const ExprHandle& a3) { return a0 + a3 * a1 * a2; }); } break; case aten::eq: { return computeTwoOperand( "aten_eq", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { return cast(lhs == rhs); }); } break; case aten::ne: { return computeTwoOperand( "aten_ne", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { return cast(lhs != rhs); }); } break; case aten::ge: { return computeTwoOperand( "aten_ge", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { return cast(lhs >= rhs); }); } break; case aten::gt: { return computeTwoOperand( "aten_gt", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { return cast(lhs > rhs); }); } break; case aten::le: { return computeTwoOperand( "aten_le", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { return cast(lhs <= rhs); }); } break; case aten::lt: { return computeTwoOperand( "aten_lt", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { return cast(lhs < rhs); }); } break; case aten::min: { 
return computeTwoOperand( "aten_min", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { return Min::make(boolToInteger(lhs), boolToInteger(rhs), false); }); } break; case aten::max: { return computeTwoOperand( "aten_max", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { return Max::make(boolToInteger(lhs), boolToInteger(rhs), false); }); } break; case aten::clamp: { bool noMin = false; bool noMax = false; if (v->node()->input(1)->node()->kind() == prim::Constant) { const auto val = toIValue(v->node()->input(1)).value(); if (val.isNone()) { noMin = true; } } if (v->node()->input(2)->node()->kind() == prim::Constant) { const auto val = toIValue(v->node()->input(2)).value(); if (val.isNone()) { noMax = true; } } return computeThreeOperand( "aten_clamp", v, [noMin, noMax]( const ExprHandle& in, const ExprHandle& min, const ExprHandle& max) { if (noMin && noMax) { return in; } else if (noMin) { return CompareSelect::make(in, max, max, in, kGT); } else if (noMax) { return CompareSelect::make(in, min, min, in, kLT); } else { return CompareSelect::make( in, min, min, CompareSelect::make(in, max, max, in, kGT), kLT); } }); } break; case aten::sigmoid: { return computeOneOperand( "aten_sigmoid", v, [](const ExprHandle& a) { return sigmoid(a); }); } break; case aten::reciprocal: { return computeOneOperand("aten_reciprocal", v, [](const ExprHandle& a) { return ExprHandle(1.0f) / a; }); } break; case aten::neg: { return computeOneOperand("aten_neg", v, [](const ExprHandle& a) { return ExprHandle(-0) - a; }); } break; case aten::relu: { return computeOneOperand("aten_relu", v, [](const ExprHandle& a) { auto zero = Cast::make(a.dtype(), 0); return ifThenElse(CompareSelect::make(a, zero, kLT), zero, a); }); } break; case aten::log: { return computeOneOperand("aten_log", v, [](const ExprHandle& a) { return log(promoteIntegerToFloat(a)); }); } break; case aten::log10: { return computeOneOperand("aten_log10", v, [](const ExprHandle& a) { return log10(promoteIntegerToFloat(a)); }); } break; case aten::log1p: { return computeOneOperand( "aten_log1p", v, [](const ExprHandle& a) { return log1p(a); }); } break; case aten::log2: { return computeOneOperand("aten_log2", v, [](const ExprHandle& a) { return log2(promoteIntegerToFloat(a)); }); } break; case aten::exp: { return computeOneOperand( "aten_exp", v, [](const ExprHandle& a) { return exp(a); }); } break; case aten::expm1: { return computeOneOperand( "aten_expm1", v, [](const ExprHandle& a) { return expm1(a); }); } break; case aten::erf: { return computeOneOperand( "aten_erf", v, [](const ExprHandle& a) { return erf(a); }); } break; case aten::erfc: { return computeOneOperand( "aten_erfc", v, [](const ExprHandle& a) { return erfc(a); }); } break; case aten::cos: { return computeOneOperand("aten_cos", v, [](const ExprHandle& a) { return cos(promoteIntegerToFloat(a)); }); } break; case aten::sin: { return computeOneOperand("aten_sin", v, [](const ExprHandle& a) { return sin(promoteIntegerToFloat(a)); }); } break; case aten::tan: { return computeOneOperand("aten_tan", v, [](const ExprHandle& a) { return tan(promoteIntegerToFloat(a)); }); } break; case aten::type_as: { auto const& n = v->node(); Tensor* rhs = tensors_.at(n->inputs()[1]->unique()); auto dtype = rhs->body()->dtype(); return computeOneOperand( "aten_type_as", v, [dtype](const ExprHandle& lhs) { return Cast::make(dtype, lhs); }); } break; case aten::rand_like: { hasRandom_ = true; return computeOneOperand("aten_rand_like", v, [](const ExprHandle& a) { return 
Intrinsics::make(IntrinsicsOp::kRand, a.dtype()); }); } break; case aten::pow: { return computeTwoOperand( "aten_pow", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { if (!rhs.node()->isConstant()) { return pow(lhs, rhs); } double val = immediateAs(IRSimplifier::simplify(rhs.node())); if (val == 1.0f) { return lhs; } else if (val == 2.0f) { // NOLINT return lhs * lhs; } else if (val == 3.0f) { // NOLINT return (lhs * lhs) * lhs; } else if (val == 4.0f) { // NOLINT ExprHandle tmp = lhs * lhs; return tmp * tmp; } else if (val == 0.5f) { // NOLINT return sqrt(lhs); } else if (val == 0.0f) { return ExprHandle(1.0f); } else if (val == -0.5f) { // NOLINT return rsqrt(lhs); } else if (val == -1.0f) { return ExprHandle(1.0f) / lhs; } else if (val == -2.0f) { // NOLINT return ExprHandle(1.0f) / (lhs * lhs); } return pow(lhs, rhs); }); } break; case aten::fmod: { return computeTwoOperand( "aten_fmod", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { return fmod(lhs, rhs); }); } break; case aten::lerp: { return computeThreeOperand( "aten_lerp", v, [](const ExprHandle& a, const ExprHandle& end, const ExprHandle& weight) { return a + weight * (end - a); }); } break; case aten::remainder: { auto imodImpl = [](const ExprHandle& lhs, const ExprHandle& rhs) { return Mod::make(lhs, rhs); }; auto fmodImpl = [](const ExprHandle& lhs, const ExprHandle& rhs) { return fmod((rhs + fmod(lhs, rhs)), rhs); }; { auto const& n = v->node(); auto const& shape = broadcastShapes( valueShape(n->inputs()[0]), valueShape(n->inputs()[1])); return Compute( "aten_remainder", c10::fmap(shape), [&](const std::vector& axes) { auto const& n = v->node(); std::vector indices(axes.begin(), axes.end()); std::vector inputs = { tensorOrConstant(n->inputs()[0], indices), tensorOrConstant(n->inputs()[1], indices), }; promoteInputs(inputs); bool allInt = true; for (auto& e : inputs) { if (e.dtype().is_floating_point()) { allInt = false; break; } } if (allInt) { return demoteOutput( imodImpl(inputs[0], inputs[1]), n->output()); } else { return demoteOutput( fmodImpl(inputs[0], inputs[1]), n->output()); } }); } } break; case aten::acos: { return computeOneOperand("aten_acos", v, [](const ExprHandle& a) { return acos(promoteIntegerToFloat(a)); }); } break; case aten::asin: { return computeOneOperand( "aten_asin", v, [](const ExprHandle& a) { return asin(a); }); } break; case aten::cosh: { return computeOneOperand( "aten_cosh", v, [](const ExprHandle& a) { return cosh(a); }); } break; case aten::sinh: { return computeOneOperand( "aten_sinh", v, [](const ExprHandle& a) { return sinh(a); }); } break; case aten::atan: { return computeOneOperand("aten_atan", v, [](const ExprHandle& a) { return atan(promoteIntegerToFloat(a)); }); } break; case aten::atan2: { return computeTwoOperand( "aten_atan2", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { return atan2(lhs, rhs); }); } break; case aten::tanh: { return computeOneOperand("aten_tanh", v, [](const ExprHandle& a) { return tanh(promoteIntegerToFloat(a)); }); } break; case aten::sqrt: { return computeOneOperand("aten_sqrt", v, [](const ExprHandle& a) { return sqrt(promoteIntegerToFloat(a)); }); } break; case aten::rsqrt: { return computeOneOperand( "aten_rsqrt", v, [](const ExprHandle& a) { return rsqrt(a); }); } break; case aten::abs: { return computeOneOperand( "aten_abs", v, [](const ExprHandle& a) { return fabs(a); }); } break; case aten::ceil: { return computeOneOperand( "aten_ceil", v, [](const ExprHandle& a) { return ceil(a); }); } break; case aten::floor: { return 
computeOneOperand( "aten_floor", v, [](const ExprHandle& a) { return floor(a); }); } break; case aten::round: { return computeOneOperand( "aten_round", v, [](const ExprHandle& a) { return round(a); }); } break; case aten::trunc: { return computeOneOperand( "aten_trunc", v, [](const ExprHandle& a) { return trunc(a); }); } break; case aten::threshold: { return computeThreeOperand( "aten_threshold", v, [](const ExprHandle& a, const ExprHandle& threshold, const ExprHandle& value) { return ifThenElse(CompareSelect::make(a, threshold, kLE), value, a); }); } break; case aten::where: { return computeConditionWithTwoOperand( "aten_where", v, [](const ExprHandle& a0, const ExprHandle& a1, const ExprHandle& a2) { return ifThenElse(a0, a1, a2); }); } break; case aten::frac: { return computeOneOperand( "aten_frac", v, [](const ExprHandle& a) { return a - floor(a); }); } break; case aten::lgamma: { return computeOneOperand( "aten_lgamma", v, [](const ExprHandle& a) { return lgamma(a); }); } break; case prim::ConstantChunk: { return Compute( "prim_constantchunk", dimsFromSizes(sizesForValue(v)), [this, v](const std::vector& axes) { auto const& n = v->node(); int64_t dim = n->i(attr::dim); int64_t chunks = n->i(attr::chunks); std::vector indices(axes.begin(), axes.end()); return chunk( tensors_.at(n->inputs()[0]->unique()), v->offset(), dim, chunks, indices); }); } case aten::cat: { return Compute( "aten_cat", dimsFromSizes(sizesForValue(v)), [this, v](const std::vector& axes) { auto const& n = v->node(); auto inputs = n->inputs()[0]->node()->inputs(); if (inputs.size() == 0) { throw std::runtime_error( "Empty input list is passed to aten::cat"); } // Some of the inputs can be empty tensors, we need to skip them // when we construct the expression, but we need to take them into // account in dtype promotion. std::vector nonempty_inputs; for (auto input : inputs) { if (input->type()->kind() == TypeKind::TensorType) { auto tt = input->type()->cast(); if (tt->isComplete() && tt->sizes().size() && tt->sizes()[0] && *tt->sizes()[0]) { nonempty_inputs.push_back(input); } } } // When all inputs are empty tensors, the tensor we create for this // computation would contain no elements, so it doesn't really // matter what we return here, so just return 0. if (!nonempty_inputs.size()) { return ExprHandle(0); } int64_t dim = n->inputs()[1]->node()->i(attr::value); if (dim < 0) { dim += axes.size(); } if (dim < 0 || dim >= axes.size()) { throw std::runtime_error("invalid 'dim' value in aten::cat"); } // Promote input types. // Note that we need to consider all inputs, including empty - they // also affect the resultant dtype. auto maybe_dtype = findDtypeForValue(inputs[0]); TORCH_INTERNAL_ASSERT( maybe_dtype, "Cannot find dtype for one of aten::cat inputs"); ScalarType highType = *maybe_dtype; for (const auto input : inputs) { auto maybe_dtype = findDtypeForValue(input); TORCH_INTERNAL_ASSERT( maybe_dtype, "Cannot find dtype for one of aten::cat inputs"); highType = promoteTypes(highType, *maybe_dtype); } // Now we know the final dtype, we know what inputs are non-empty, // and we know that there is at least one such an input. With all // that we construct a tensor expression performing the // concatenation. // The expression we build here is a cascading if-then-else that // essentially represents: // // inp1[i, j, k] if 0 < i < l1, // out[i,j,k] = inp2[i, j-l1, k] if l1 =< i < l1 + l2, // ... // inpN[i, j-l_N_1, k] if l1+l2+...l_N_1 < i // where l_i is the corresponding size of the i-th input. 
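            // For instance, concatenating two non-empty inputs of sizes 3 and
            // 5 along `dim` lowers to roughly:
            //   out[i, ...] = i < 3 ? inp1[i, ...] : inp2[i - 3, ...]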
std::vector newAxes(axes.begin(), axes.end()); ExprHandle load = promoteToDtype( tensorOrConstant(nonempty_inputs[0], newAxes), highType); size_t offset = bufferSizes(tensors_.at(nonempty_inputs[0]->unique()))[dim]; newAxes[dim] = newAxes[dim] - IntImm::make(offset); for (size_t ii = 1; ii < nonempty_inputs.size(); ++ii) { auto input = nonempty_inputs[ii]; load = ifThenElse( CompareSelect::make(axes[dim], IntImm::make(offset), kLT), load, promoteToDtype(tensorOrConstant(input, newAxes), highType)); offset += bufferSizes(tensors_.at(input->unique()))[dim]; newAxes[dim] = axes[dim] - IntImm::make(offset); } return load; }); } case aten::slice: { return Compute( "aten_slice", dimsFromSizes(sizesForValue(v)), [this, v](const std::vector& axes) { auto const& n = v->node(); int dim = constant(n->inputs()[1]).AsNode()->value(); ExprHandle start = constant(n->inputs()[2]); ExprHandle stride = constant(n->inputs()[4]); std::vector newAxes(axes.begin(), axes.end()); newAxes[dim] = stride * newAxes[dim] + start; return tensorOrConstant(n->inputs()[0], newAxes); }); } case aten::unsqueeze: { return Compute( "aten_unsqueeze", dimsFromSizes(sizesForValue(v)), [this, v](const std::vector& axes) { auto const& n = v->node(); int64_t dim = constant(n->inputs()[1]).AsNode()->value(); if (dim < 0) { if (axes.size() == 0) { throw malformed_input("axes are zero handling unsqueeze"); } dim += axes.size(); } // To construct an expression for an 'unsqueezed' tensor we need to // drop the DIM-th axis, i.e. // unsqueezed_v[i,j,k,l] = v[i,j,l] # dim = 2 - drop index 'k' // 0 1 2 3 std::vector indices; int64_t i = 0; for (auto a : axes) { if (i++ != dim) { indices.emplace_back(ExprHandle(a.node())); } } return tensorOrConstant(n->inputs()[0], indices); }); } case aten::sum: { return computeSum(v); } case aten::softmax: { return computeSoftmax(v, false); } case aten::log_softmax: { return computeSoftmax(v, true); } default: { throw std::runtime_error("Unhandled node kind"); } } } Stmt* TensorExprKernel::generateStmt(BackendType backendType) { torch::jit::tensorexpr::LoopNest l(tensorOutputs_); GRAPH_DEBUG("Original Stmt:\n", std::to_string(l.root_stmt()), "\n"); bool hasReduction = NodeFinder::find(l.root_stmt()).size() != 0; l.inlineIntermediateBufs(); if (backendType == kCudaCodeGen) { for (auto tensor : tensorOutputs_) { std::vector loops = l.getLoopStmtsFor(tensor); TORCH_INTERNAL_ASSERT(!loops.empty(), "loops should not be empty"); For* flattened = nullptr; LoopNest::flatten(loops, &flattened); assert(flattened); int loopLevels = getTECudaPointwiseLoopLevels(); const int kDefaultLoopLevels = 2; loopLevels = (loopLevels > 0) ? loopLevels : kDefaultLoopLevels; int blockCount = getTECudaPointwiseBlockCount(); int blockSize = getTECudaPointwiseBlockSize(); if (loopLevels == 2) { For* outer; For* inner; const int kDefaultBlockSize = 512; if (blockSize < 0) { blockSize = kDefaultBlockSize; } l.splitWithMask(flattened, blockSize, &outer, &inner); l.setGPUBlockIndex(outer, 0); l.setGPUThreadIndex(inner, 0); } else if (loopLevels == 3) { For* outer; For* inner; For* inner1; For* inner2; // TODO: change the number of microprocessors const int kDefaultBlockCount = 1280; const int kDefaultBlockSize = 256; blockCount = (blockCount > 0) ? blockCount : kDefaultBlockCount; blockSize = (blockSize > 0) ? 
blockSize : kDefaultBlockSize; l.splitWithMask(flattened, blockCount * blockSize, &outer, &inner); l.splitWithMask(inner, blockSize, &inner1, &inner2); l.setGPUBlockIndex(inner1, 0); l.setGPUThreadIndex(inner2, 0); } else { throw std::runtime_error( "Invalid loop-level: " + c10::to_string(loopLevels)); } } } if (backendType == kBlockCodeGen) { auto block_analysis = std::make_unique(); for (auto tensor : tensorOutputs_) { const int default_fp16_blocksize = 16; const int default_uint8_blocksize = 32; int blockSize = default_fp16_blocksize; // We only handle looplevels == 2 for now // Run Block analysis to get multi dim buffer info auto root_stmt = l.root_stmt(); root_stmt->accept(block_analysis.get()); if (tensor->buf()->dtype().scalar_type() == ScalarType::Byte) { blockSize = default_uint8_blocksize; } std::vector loops = l.getLoopStmtsFor(tensor); TORCH_INTERNAL_ASSERT(!loops.empty(), "loops should not be empty"); For* flattened = nullptr; LoopNest::flatten(loops, &flattened); assert(flattened); For* outer = nullptr; For* inner = nullptr; l.splitWithMask(flattened, blockSize, &outer, &inner); l.setGPUBlockIndex(outer, 0); l.setGPUThreadIndex(inner, 0); l.setBufferMap(outer, block_analysis->getBufferMap()); } } l.prepareForCodegen(); if (backendType == kLLVMCodeGen && !hasReduction) { l.vectorizeInnerLoops(); } Stmt* stmt = l.root_stmt(); // Arithmetic Simplification. stmt = IRSimplifier::simplify(stmt); GRAPH_DEBUG("Final Stmt:\n", std::to_string(stmt), "\n"); return stmt; } std::string TensorExprKernel::getCodeGenName(BackendType backendType) { switch (backendType) { case kCudaCodeGen: return "cuda_codegen"; case kLLVMCodeGen: return "llvm_codegen"; case kSimpleIREval: return "simple_ir_eval"; case kBlockCodeGen: return "block_codegen"; default: throw std::runtime_error( "invalid backend type: " + c10::to_string(static_cast(backendType))); } } std::vector TensorExprKernel::prepareBufferArgs() { std::vector params; for (auto const& arg : kernelArgs_) { params.push_back(arg.buffer()); for (auto const& size : arg.sizes()) { params.emplace_back(size.var); } for (auto const& stride : arg.strides()) { params.emplace_back(stride.var); } } for (auto& o : tensorOutputs_) { params.emplace_back(o); } return params; } template static bool isValidPrimProperty(const c10::optional& a, T b) { return !a.has_value() || *a == b; } TensorExprKernel::BackendType TensorExprKernel::inferBackendTypeFromDevice( at::Device device) { BackendType backendType = BackendType::kUninitialized; if (device.type() == at::kCUDA) { backendType = kCudaCodeGen; } else if (device.type() == at::kCPU && getTEGenerateBlockCode()) { backendType = kBlockCodeGen; } else if (device.type() == at::kCPU) { #ifdef TORCH_ENABLE_LLVM backendType = kLLVMCodeGen; #else backendType = kSimpleIREval; #endif if (getTEMustUseLLVMOnCPU() && backendType == kSimpleIREval) { throw std::runtime_error("LLVM Backend not found"); } } else { throw std::runtime_error("Invalid device type"); } return backendType; } void TensorExprKernel::bindInput(const torch::jit::Value* input) { auto const& t = input->type(); switch (t->kind()) { case TypeKind::TensorType: { auto tt = input->type()->cast(); Placeholder inBuffer( "t" + input->debugName(), ToDtype(static_cast(*tt->scalarType())), {0}); std::vector inputTensorDims; for (size_t i = 0; i < *tt->sizes().size(); i++) { auto const size = *tt->sizes()[i]; inputTensorDims.emplace_back( DimArg(IntImm::make(size), "i" + c10::to_string(i))); } auto const strides = tt->strides(); tensors_.emplace( input->unique(), 
Compute( "input" + c10::to_string(tensors_.size() + 1), inputTensorDims, [&](const std::vector& axes) { ExprHandle idx = 0; for (size_t i = 0; i < axes.size(); i++) { idx = idx + axes[i] * IntImm::make(*strides[i]); } return inBuffer.load(idx); })); kernelArgs_.emplace_back( inBuffer, std::vector(), std::vector()); break; } case TypeKind::FloatType: { VarHandle v("v" + input->debugName(), kFloat); kernelArgs_.emplace_back(v); scalars_.emplace(input->unique(), v); break; } case TypeKind::BoolType: { VarHandle v("v" + input->debugName(), kBool); kernelArgs_.emplace_back(v); scalars_.emplace(input->unique(), v); break; } case TypeKind::IntType: { VarHandle v("v" + input->debugName(), kInt); kernelArgs_.emplace_back(v); scalars_.emplace(input->unique(), v); break; } default: { throw unsupported_dtype(); break; } } } namespace { // Remove all indices from axes positions. std::vector squeezeIndices( const ParameterList& indices, const std::vector& axes) { std::vector indices_squeezed; for (size_t dim = 0; dim < indices.size(); ++dim) { if (!std::count(axes.begin(), axes.end(), dim)) { indices_squeezed.push_back(indices[dim]); } } return indices_squeezed; } } // namespace Tensor* TensorExprKernel::computeSum(const torch::jit::Value* v) { auto reduction_info = getReductionInfo(v->node()); return Reduce( "sum", reduction_info.outputDims, Sum(), [&](ParameterList& indices) { const auto& axes = reduction_info.axes; // "Squeeze" out indices inserted when keepdim is set. auto indices_squeezed = reduction_info.keepdim ? squeezeIndices(indices, axes) : indices; TORCH_INTERNAL_ASSERT(axes.size() <= indices_squeezed.size()); // Move innermost indices into axes positions: // 1. Fill the outermost indices first. // 2. Insert the innermost indices into the correct axis position, // displacing the outermost indices as needed. std::vector indices_exprs; size_t i = 0; for (; i < indices_squeezed.size() - axes.size(); ++i) { indices_exprs.push_back(indices_squeezed[i]); } for (auto axis : axes) { indices_exprs.insert( indices_exprs.begin() + axis, indices_squeezed[i]); ++i; } auto indexed = tensorOrConstant(v->node()->input(0), indices_exprs); if (reduction_info.dtype) { return Cast::make(*reduction_info.dtype, indexed); } else { return indexed; } }, reduction_info.reductionDims); } Tensor* TensorExprKernel::computeSoftmax( const torch::jit::Value* v, bool log_softmax) { // Softmax is computed as follows: // softmax(vi) = exp(vi) / sum(exp(vi)) // // In order to avoid overflow issues due to exp of a large number, we // subtract the max of that dim before computing exp. // softmax(vi) = exp(vi - max(vi)) / sum(exp(vi - max(vi))) // // This is implemented as 4 loopnests: // - First loop computes the max over the softmax dim. // - Second loop computes exp for every element in v after subtracting // the max of the softmax dim it belongs to. // - Third loop computes the sum over the softmax dim. // - Final loop computes softmax for every element in v. // LogSoftmax is computed as follows: // log_softmax(vi) = log(softmax(vi)) // = vi - log(sum(exp(vi))) // // Using the same max trick as above: // log_softmax(vi) = vi - max(vi) - log(sum(exp(vi - max(vi)))) // // This is implemented as 5 loopnests: // - First loop computes the max over the softmax dim. // - Second loop computes exp for every element in v after subtracting // the max of the softmax dim it belongs to. // - Third loop computes the sum over the softmax dim. // - Fourth loop computes log for every element in the sum. 
// - Final loop computes the log_softmax for every element in v. TORCH_INTERNAL_ASSERT(v->node()->inputs().size() == 3); auto output_dims = dimsFromSizes(sizesForValue(v)); // We do not handle None for dims (input 1) because that is supposed to // be deprecated. TORCH_INTERNAL_ASSERT(v->node()->input(1)->node()->kind() == prim::Constant); size_t softmax_dim = v->node()->input(1)->node()->i(attr::value); TORCH_INTERNAL_ASSERT(softmax_dim < output_dims.size()); std::vector non_softmax_dims; for (size_t i = 0; i < output_dims.size(); ++i) { if (i != softmax_dim) { non_softmax_dims.push_back(output_dims[i]); } } // Softmax implementation includes two reductions, one to find the max and // the other to calculate the sum along the softmax dim. These reductions // will have the softmax dimension as the inner most loop. So, the innermost // index in the indices will refer to the softmax dimension. // Update the indices by moving the softmax dimension index to the // appropriate position. auto move_softmax_dim_index_to_pos = [&](const ParameterList& indices) { std::vector new_indices; for (auto ind : indices) { new_indices.push_back(ind); } for (size_t i = softmax_dim; i < indices.size() - 1; ++i) { new_indices[i + 1] = indices[i]; } new_indices[softmax_dim] = indices[indices.size() - 1]; return new_indices; }; // Remove the index corresponding to the softmax dimension. auto remove_softmax_dim_index = [&](const ParameterList& indices) { std::vector new_indices; for (size_t i = 0; i < indices.size(); ++i) { if (i != softmax_dim) { new_indices.push_back(indices[i]); } } return new_indices; }; auto convert_indices_to_expr_handle = [&](const ParameterList& indices) { std::vector new_indices(indices.size()); for (size_t i = 0; i < indices.size(); ++i) { new_indices[i] = indices[i]; } return new_indices; }; c10::optional dtype = ToDtype(ScalarType::None); auto maybe_dtype = v->node()->get(attr::dtype); if (maybe_dtype && !maybe_dtype->isNone()) { dtype = ToDtype(static_cast(maybe_dtype->toInt())); } auto max = Reduce( "aten_softmax_max", non_softmax_dims, Maximum(dtype.value()), [&](ParameterList& indices) { return tensorOrConstant( v->node()->inputs()[0], move_softmax_dim_index_to_pos(indices)); }, {output_dims[softmax_dim]}); auto e = Compute("aten_softmax_exp", output_dims, [&](ParameterList& indices) { auto inp = tensorOrConstant( v->node()->inputs()[0], convert_indices_to_expr_handle(indices)); return exp(inp - max->call(remove_softmax_dim_index(indices))); }); auto sum = Reduce( "aten_softmax_sum", non_softmax_dims, Sum(), [&](ParameterList& indices) { return e->call(move_softmax_dim_index_to_pos(indices)); }, {output_dims[softmax_dim]}); if (!log_softmax) { return Compute("aten_softmax", output_dims, [&](ParameterList& indices) { return e->call(indices) / sum->call(remove_softmax_dim_index(indices)); }); } auto log_sum = Compute( "aten_softmax_log_sum", non_softmax_dims, [&](ParameterList& indices) { return log(sum->call(indices)); }); return Compute("aten_log_softmax", output_dims, [&](ParameterList& indices) { auto inp = tensorOrConstant( v->node()->inputs()[0], convert_indices_to_expr_handle(indices)); auto non_softmax_indices = remove_softmax_dim_index(indices); return inp - max->call(non_softmax_indices) - log_sum->call(non_softmax_indices); }); } TensorExprKernel::ReductionInfo TensorExprKernel::getReductionInfo( const torch::jit::Node* node) { std::vector axes; bool keepdim = false; // aten::sum takes the input tensor named self. 
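  // aten::sum comes in two flavors: sum(self, dtype) and
  // sum(self, dims, keepdim, dtype). The input-count check below tells them
  // apart: with more than two inputs we reduce over the given dims (honoring
  // keepdim), otherwise we reduce over every dimension.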
  auto sizes = sizesForValue(node->namedInput(attr::self));
  const auto inputs = node->inputs();
  int rank = sizes.size();
  if (inputs.size() > 2) {
    auto nodeAxes = getReductionAxes(node);
    // Canonicalize axes: wrap around, sort and make unique.
    for (auto axis : nodeAxes) {
      axes.push_back(at::maybe_wrap_dim(axis, rank));
    }
    std::sort(axes.begin(), axes.end());
    axes.erase(std::unique(axes.begin(), axes.end()), axes.end());
    keepdim = node->get(attr::keepdim)->toBool();
  } else {
    axes.resize(sizes.size());
    std::iota(axes.begin(), axes.end(), 0);
  }

  // Axes go into reduction dimensions.
  std::vector<DimArg> reductionDims;
  reductionDims.reserve(sizes.size());
  for (size_t axis : axes) {
    reductionDims.emplace_back(sizes[axis]);
  }

  auto allDims = dimsFromSizes(sizes);
  std::vector<DimArg> outputDims;
  // Output dimensions are the complement of axes. When keepdim is set, a
  // one-sized dimension is inserted for each axis.
  for (size_t dim = 0; dim < allDims.size(); ++dim) {
    if (!std::count(axes.begin(), axes.end(), dim)) {
      outputDims.emplace_back(sizes[dim]);
    } else if (keepdim) {
      outputDims.emplace_back(1);
    }
  }

  c10::optional<Dtype> dtype;
  auto dtypeValue = node->get(attr::dtype);
  if (!dtypeValue->isNone()) {
    auto scalarType = static_cast<ScalarType>(dtypeValue->toInt());
    dtype = ToDtype(scalarType);
  }
  return {reductionDims, outputDims, axes, keepdim, dtype};
}

std::vector<int64_t> TensorExprKernel::getReductionAxes(
    const torch::jit::Node* node) {
  std::vector<int64_t> axes;
  auto axesNode = node->namedInput(attr::dim)->node();
  // There are two possible representations for reduction axes:
  //   1. A prim::ListConstruct of integer constants.
  //   2. A prim::Constant list of integer ival's.
  // We need to handle both of them.
  if (axesNode->kind() == prim::ListConstruct) {
    for (auto axisNode : axesNode->inputs()) {
      axes.push_back(constant(axisNode).AsNode<IntImm>()->value());
    }
    return axes;
  }

  TORCH_INTERNAL_ASSERT(axesNode->kind() == prim::Constant);
  TORCH_INTERNAL_ASSERT(axesNode->kindOf(attr::value) == AttributeKind::ival);
  const auto& genericList = axesNode->ival(attr::value).toList();
  for (const IValue axisNode : genericList) {
    axes.push_back(axisNode.toInt());
  }
  return axes;
}
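// compile() drives the whole lowering pipeline: bind the graph inputs to
// buffers/scalars, lower each fusible node to a tensor expression, move the
// graph outputs into tensorOutputs_, pick a backend for the input device,
// build and simplify the loop nest, and hand the resulting Stmt to codegen.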
void TensorExprKernel::compile() {
  KernelScope kernelScope(&kernelArena_);
  GRAPH_DUMP("TensorExprKernel graph:", graph_);

  // Bind inputs to buffers.
  nInputs_ = graph_->inputs().size();
  for (auto const& input : graph_->inputs()) {
    bindInput(input);
    inputTypes_.push_back(input->type());
  }

  // Bind nodes to tensor compute expressions.
  for (auto const& n : graph_->nodes()) {
    if (n->kind() == prim::Constant || n->kind() == prim::ListConstruct) {
      continue;
    } else {
      for (auto const& output : n->outputs()) {
        if (output->hasUses()) {
          tensors_.emplace(output->unique(), computeValue(output));
        }
      }
    }
    if (hasRandom_ && hasBroadcast_) {
      throw std::runtime_error(
          "Cannot support broadcast and random within one kernel");
    }
  }

  // Move output operands from `tensors_` to `tensorOutputs_`
  for (const auto& output : graph_->outputs()) {
    if (!tensors_.count(output->unique())) {
      throw malformed_input("cannot find output Tensor");
    }
    tensorOutputs_.emplace_back(tensors_.at(output->unique()));
    tensors_.erase(output->unique());
  }

  device_ = *pickDeviceType(graph_->inputs());
  BackendType backendType = inferBackendTypeFromDevice(device_);
  Stmt* stmt = generateStmt(backendType);
  // Set up formal params (inputs, then outputs) for kernel.
  std::vector<CodeGen::BufferArg> params = prepareBufferArgs();
  // Generate code.
  codegen_ = CreateCodeGen(
      getCodeGenName(backendType),
      stmt,
      params,
      device_,
      SubgraphUtils::generateNameForGraph(graph_));
}

TensorExprKernel::TensorExprKernel(const std::shared_ptr<Graph>& subgraph)
    : graph_(subgraph), code_(subgraph, "") {
  if (!fallbackAllowed()) {
    compile();
    return;
  }

  try {
    compile();
  } catch (...) {
    fallback_ = true;
  }
}

void TensorExprKernel::run(Stack& stack) {
  if (fallbackEnforced()) {
    fallback(stack);
    return;
  }
  if (!fallbackAllowed()) {
    runKernel(stack);
    return;
  }

  if (fallback_) {
    fallback(stack);
    return;
  }
  try {
    runKernel(stack);
  } catch (...) {
    fallback_ = true;
    fallback(stack);
  }
}

std::vector<CodeGen::CallArg> TensorExprKernel::prepareRunArgs(
    const at::ArrayRef<IValue>& inputs,
    std::vector<at::Tensor>& outputs) {
  std::map<const Expr*, int32_t> varToSize;

  std::vector<CodeGen::CallArg> runArgs;
  for (size_t i = 0; i < inputs.size(); i++) {
    auto const& input = inputs[i];
    if (input.isInt()) {
      runArgs.emplace_back((int32_t)input.toInt());
    } else if (input.isDouble()) {
      runArgs.emplace_back((float)input.toDouble());
    } else if (input.isTensor()) {
      auto const& tensor = input.toTensor();
      runArgs.emplace_back(tensor.data_ptr());
      for (auto const& size : kernelArgs_[i].sizes()) {
        int32_t s = tensor.sizes()[size.idx];
        runArgs.emplace_back(s);
        varToSize[size.var.node()] = s;
      }
      for (auto const& stride : kernelArgs_[i].strides()) {
        int32_t s = tensor.strides()[stride.idx];
        runArgs.emplace_back(s);
      }
    }
  }

  for (auto& o : tensorOutputs_) {
    std::vector<int64_t> tensorSize;
    for (const Expr* dim : o->dims()) {
      auto it = varToSize.find(dim);
      if (it != varToSize.end()) {
        tensorSize.push_back(it->second);
      } else {
        const IntImm* s = dynamic_cast<const IntImm*>(dim);
        if (!s) {
          throw malformed_input("output expected Int", dim);
        }
        tensorSize.push_back(s->value());
      }
    }

    outputs.push_back(at::empty(
        tensorSize, c10::TensorOptions(tensorType(o)).device(device_)));
    runArgs.emplace_back(outputs.back().data_ptr());
  }
  return runArgs;
}

Stmt* TensorExprKernel::getCodeGenStmt() {
  return codegen_->stmt();
}

void TensorExprKernel::runKernel(Stack& stack) {
  KernelScope kernelScope(&kernelArena_);

  // Set up arguments (inputs, then outputs) for kernel call.
  auto inputs = last(stack, nInputs_);
  std::vector<at::Tensor> outputs;
  std::vector<CodeGen::CallArg> runArgs = prepareRunArgs(inputs, outputs);

  // Call the kernel.
  codegen_->call(runArgs);

  // Update the stack.
  drop(stack, nInputs_);
  for (auto& o : outputs) {
    push_one(stack, std::move(o));
  }
}
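// Rough usage sketch of this class, along the lines of the C++ kernel tests.
// The IR string, tensor shapes, and variable names below are illustrative
// only and may need adjusting to the exact parser syntax of a given release:
//
//   KernelScope kernel_scope;
//   auto graph = std::make_shared<Graph>();
//   parseIR(R"IR(
//     graph(%a : Float(4, 4, strides=[4, 1], device=cpu),
//           %b : Float(4, 4, strides=[4, 1], device=cpu)):
//       %c : Float(4, 4, strides=[4, 1], device=cpu) = aten::mul(%a, %b)
//       return (%c))IR", &*graph);
//   TensorExprKernel k(graph);
//   // run() consumes the inputs from the stack and pushes the outputs.
//   std::vector<IValue> stack = {at::rand({4, 4}), at::rand({4, 4})};
//   k.run(stack);
//   at::Tensor result = stack[0].toTensor();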