#include #include #include #include namespace torch::jit::tensorexpr { static Dtype ChooseDtype(const Dtype& buffer_dtype, const Dtype& index_dtype) { return Dtype(buffer_dtype, index_dtype.lanes()); } static Dtype dtypeOfIndices(const std::vector& indices) { if (indices.empty()) { // Return something so we can handle scalar buffers. return kInt; } return indices.at(0)->dtype(); } static void castIndicesToInts(std::vector& indices) { // Cast all indices to either Int or Long auto index_dtype = ScalarType::Int; for (auto& index : indices) { if (index->dtype().scalar_type() == ScalarType::Long) { // If any of the indexes is Long, cast all of them to Long index_dtype = ScalarType::Long; break; } } for (auto& index : indices) { const Dtype& dt = index->dtype(); if (c10::isIntegralType(dt.scalar_type(), true) && dt.scalar_type() != index_dtype) { index = alloc(Dtype(index_dtype, dt.lanes()), index); } } } Load::Load(Dtype dtype, BufPtr buf, std::vector indices) : ExprNodeBase(dtype), buf_(std::move(buf)), indices_(std::move(indices)) { castIndicesToInts(indices_); } Load::Load(const BufPtr& buf, const std::vector& indices) : Load(ChooseDtype(buf->dtype(), dtypeOfIndices(indices)), buf, indices) {} ExprHandle Load::make( Dtype dtype, const BufHandle& buf, const std::vector& indices) { return ExprHandle( alloc(dtype, buf.node(), ExprHandleVectorToExprVector(indices))); } ExprHandle Load::make( const BufHandle& buf, const std::vector& indices) { return Load::make(buf.dtype(), buf, indices); } Store::Store(BufPtr buf, std::vector indices, ExprPtr value) : buf_(std::move(buf)), indices_(std::move(indices)), value_(std::move(value)) { castIndicesToInts(indices_); } StorePtr Store::make( const BufHandle& buf, const std::vector& indices, const ExprHandle& value) { return alloc( buf.node(), ExprHandleVectorToExprVector(indices), value.node()); } StorePtr BufHandle::store( const std::vector& args, const ExprHandle& value) const { return Store::make(*this, args, value); } ExprPtr flatten_index( const std::vector& dims, const std::vector& indices, const std::vector& strides) { // Handle already flattened indices first if (indices.size() == 1) { return indices[0]; } size_t ndim = dims.size(); if (ndim != indices.size()) { throw malformed_input("dimensions mismatch in flatten_index"); } if (ndim != strides.size()) { throw malformed_input("strides mismatch in flatten_index"); } if (ndim == 0) { return alloc(0); } ExprPtr total_index = immLike(indices[0], 0); for (const auto i : c10::irange(ndim)) { total_index = alloc(total_index, alloc(indices[i], strides[i])); } return total_index; } Dtype Intrinsics::IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1) { if (op_type == kIsNan) { return dt1.cloneWithScalarType(ScalarType::Int); } // TODO: check the op_type and make a real decision return dt1; } Dtype Intrinsics::IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1, Dtype dt2) { // TODO: check the op_type and make a real decision return dt1; } Dtype Intrinsics::IntrinsicsDtype( IntrinsicsOp op_type, const std::vector& params) { // TODO: check the op_type and make a real decision // Doesn't this fail with kRand? if (params.empty()) { throw malformed_input("invalid params in Intrinsics"); } else if (params.size() == 1) { return IntrinsicsDtype(op_type, params[0]->dtype()); } else if (params.size() == 2) { return IntrinsicsDtype(op_type, params[0]->dtype(), params[1]->dtype()); } return params[0]->dtype(); } size_t Intrinsics::OpArgCount(IntrinsicsOp op_type) { switch (op_type) { case kSin: case kCos: case kTan: case kAsin: case kAcos: case kAtan: case kSinh: case kCosh: case kTanh: case kSigmoid: case kExp: case kExpm1: case kAbs: case kLog: case kLog2: case kLog10: case kLog1p: case kErf: case kErfc: case kSqrt: case kRsqrt: case kCeil: case kFloor: case kRound: case kTrunc: case kFrac: case kLgamma: case kIsNan: return 1; case kRand: return 0; case kAtan2: case kFmod: case kPow: case kRemainder: return 2; default: throw std::runtime_error("invalid op_type: " + std::to_string(op_type)); } } ExternalCallPtr ExternalCall::make( BufHandle buf, const std::string& func_name, const std::vector& buf_args, const std::vector& args) { std::vector buf_arg_nodes; buf_arg_nodes.reserve(buf_args.size()); for (const BufHandle& buf_arg : buf_args) { buf_arg_nodes.push_back(buf_arg.node()); } return alloc( buf.node(), func_name, buf_arg_nodes, ExprHandleVectorToExprVector(args)); } ExternalCallWithAllocPtr ExternalCallWithAlloc::make( const std::string& func_name, const std::vector& buf_out_args, const std::vector& buf_args, const std::vector& args) { std::vector buf_out_arg_nodes; buf_out_arg_nodes.reserve(buf_out_args.size()); for (const BufHandle& buf_out_arg : buf_out_args) { buf_out_arg_nodes.push_back(buf_out_arg.node()); } std::vector buf_arg_nodes; buf_arg_nodes.reserve(buf_args.size()); for (const BufHandle& buf_arg : buf_args) { buf_arg_nodes.push_back(buf_arg.node()); } return alloc( func_name, buf_out_arg_nodes, buf_arg_nodes, ExprHandleVectorToExprVector(args)); } FreeExtPtr FreeExt::make(const std::vector& bufs) { std::vector buf_nodes; buf_nodes.reserve(bufs.size()); for (const BufHandle& buf : bufs) { buf_nodes.push_back(buf.node()); } return alloc(buf_nodes); } std::vector ExprHandleVectorToExprVector( const std::vector& v) { std::vector result(v.size()); for (const auto i : c10::irange(v.size())) { result[i] = v[i].node(); } return result; } std::vector ExprVectorToExprHandleVector( const std::vector& v) { std::vector result(v.size()); for (const auto i : c10::irange(v.size())) { result[i] = ExprHandle(v[i]); } return result; } std::vector VarHandleVectorToVarVector( const std::vector& v) { std::vector result(v.size()); for (const auto i : c10::irange(v.size())) { result[i] = v[i].node(); } return result; } std::vector VarVectorToVarHandleVector( const std::vector& v) { std::vector result(v.size()); for (const auto i : c10::irange(v.size())) { result[i] = VarHandle(v[i]); } return result; } bool immediateIsNegative(const ExprPtr& e) { #define TYPE_CASE(Type, Name) \ if (Name##ImmPtr imm = to(e)) { \ return imm->value() < 0; \ } AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TYPE_CASE); #undef TYPE_CASE return false; } bool immediateIsPositive(const ExprPtr& e) { #define TYPE_CASE(Type, Name) \ if (Name##ImmPtr imm = to(e)) { \ return imm->value() > 0; \ } AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE) #undef TYPE_CASE return false; } bool immediateIsZero(const ExprPtr& e) { #define TYPE_CASE(Type, Name) \ if (Name##ImmPtr imm = to(e)) { \ return imm->value() == 0; \ } AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE) #undef TYPE_CASE return false; } } // namespace torch::jit::tensorexpr