#pragma once

#include <optional>
#include <string>
#include <utility>
#include <vector>

#include <torch/csrc/jit/tensorexpr/exceptions.h>
#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/fwd_decls.h>
#include <torch/csrc/jit/tensorexpr/stmt.h>

namespace torch::jit::tensorexpr {

enum CompareSelectOperation {
  kEQ = 0,
  kGT,
  kGE,
  kLT,
  kLE,
  kNE,
};

enum CompareSelectBias {
  kUnbiased,
  kLikely,
  kUnlikely,
};

inline int getPrecedence(IRNodeType ty) {
  // Match C++ operator precedence rules, since some backends pretty-print
  // expressions as C++.
  // SEE: https://en.cppreference.com/w/cpp/language/operator_precedence
  switch (ty) {
    case kPrimitive:
      return 0;
    case kCast:
    case kBitCast:
      return 2;
    case kAdd:
    case kSub:
      return 6;
    case kMul:
    case kDiv:
    case kMod:
      return 5;
    case kMax:
    case kMin:
      return 99;
    case kAnd:
      return 11;
    case kOr:
      return 13;
    case kLshift:
    case kRshift:
      return 7;
    case kXor:
      return 12;
    case kCompareSelect:
      return 16;
    default:
      return 99;
  }
}

class TORCH_API Cast : public ExprNode<Cast> {
 public:
  ExprPtr src_value() const {
    return src_value_;
  }

  void set_src_value(ExprPtr src_value) {
    src_value_ = std::move(src_value);
  }

  static ExprHandle make(Dtype dtype, const ExprHandle& src_value) {
    return ExprHandle(alloc<Cast>(dtype, src_value.node()));
  }
  Cast(Dtype dtype, ExprPtr src_value)
      : ExprNodeBase(dtype, kCast), src_value_(std::move(src_value)) {}

  bool isConstant() const override {
    return src_value_->isConstant();
  }

 private:
  ExprPtr src_value_;
};

template <typename T>
ExprHandle cast(const ExprHandle& src_value) {
  return Cast::make(Dtype(ToDtype<T>(), src_value.dtype().lanes()), src_value);
}

// This is a bitwise cast, akin to bitcast in LLVM
class TORCH_API BitCast : public ExprNode<BitCast> {
 public:
  ExprPtr src_value() const {
    return src_value_;
  }

  void set_src_value(ExprPtr src_value) {
    src_value_ = std::move(src_value);
  }

  static ExprHandle make(Dtype dtype, const ExprHandle& src_value) {
    return ExprHandle(alloc<BitCast>(dtype, src_value.node()));
  }
  BitCast(Dtype dtype, ExprPtr src_value)
      : ExprNodeBase(dtype, kBitCast), src_value_(std::move(src_value)) {
    TORCH_CHECK(src_value_->dtype().byte_size() == dtype.byte_size());
  }

  bool isConstant() const override {
    return src_value_->isConstant();
  }

 private:
  ExprPtr src_value_;
};

template <typename T>
ExprHandle bitcast(const ExprHandle& src_value) {
  return BitCast::make(
      Dtype(ToDtype<T>(), src_value.dtype().lanes()), src_value);
}
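// Usage sketch (illustrative only, not part of this header): cast<T> performs
// a value conversion while bitcast<T> reinterprets the same bits, so for a
// hypothetical float handle the two yield different results:
//
//   ExprHandle x = FloatImm::make(1.0f);
//   ExprHandle i = cast<int>(x);    // converting cast: yields int 1
//   ExprHandle b = bitcast<int>(x); // same bits: yields int 0x3f800000
//
// BitCast additionally requires the source and destination dtypes to have
// equal byte_size(), enforced by the TORCH_CHECK in its constructor.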
// Represents the expression node for binary operators.
// A CRTP pattern to share common code among the operators.
template <typename Op>
class BinaryOpNode : public ExprNode<Op> {
 public:
  ExprPtr lhs() const {
    return this->lhs_;
  }
  ExprPtr rhs() const {
    return this->rhs_;
  }

  void set_lhs(ExprPtr lhs) {
    lhs_ = std::move(lhs);
  }

  void set_rhs(ExprPtr rhs) {
    rhs_ = std::move(rhs);
  }

  static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) {
    return ExprHandle(alloc<Op>(lhs.node(), rhs.node()));
  }

  BinaryOpNode(
      ExprPtr lhs_v,
      ExprPtr rhs_v,
      IRNodeType expr_type,
      ScalarType ret_type = ScalarType::Undefined)
      : ExprNode<Op>(
            BinaryOpDtype(lhs_v->dtype(), rhs_v->dtype(), ret_type),
            expr_type),
        lhs_(CastIfNeeded(std::move(lhs_v), ExprNode<Op>::dtype())),
        rhs_(CastIfNeeded(std::move(rhs_v), ExprNode<Op>::dtype())) {}

 private:
  static ExprPtr CastIfNeeded(ExprPtr expr, Dtype dst_dtype) {
    if (expr->dtype() == dst_dtype) {
      return expr;
    }
    return Cast::make(dst_dtype, ExprHandle(std::move(expr))).node();
  }

  ExprPtr lhs_;
  ExprPtr rhs_;
};

namespace detail {
template <typename T>
void bin_op_deducer(BinaryOpNode<T>);
bool bin_op_deducer(...);
} // namespace detail
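// Usage sketch (illustrative only): BinaryOpDtype computes the promoted
// result dtype, and CastIfNeeded wraps the narrower operand in a Cast, so
// mixing immediates of different dtypes is legal:
//
//   ExprHandle sum = Add::make(IntImm::make(2), FloatImm::make(3.0f));
//   // sum has dtype kFloat; the int operand is wrapped in Cast(kFloat, ...)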
class TORCH_API Add : public BinaryOpNode<Add> {
 public:
  Add(ExprPtr lhs, ExprPtr rhs)
      : BinaryOpNode(std::move(lhs), std::move(rhs), IRNodeType::kAdd) {}
};

class TORCH_API Sub : public BinaryOpNode<Sub> {
 public:
  Sub(ExprPtr lhs, ExprPtr rhs)
      : BinaryOpNode(std::move(lhs), std::move(rhs), IRNodeType::kSub) {}
};

class TORCH_API Mul : public BinaryOpNode<Mul> {
 public:
  Mul(ExprPtr lhs, ExprPtr rhs)
      : BinaryOpNode(std::move(lhs), std::move(rhs), IRNodeType::kMul) {}
};

class TORCH_API Div : public BinaryOpNode<Div> {
 public:
  Div(ExprPtr lhs, ExprPtr rhs)
      : BinaryOpNode(std::move(lhs), std::move(rhs), IRNodeType::kDiv) {}
};

class TORCH_API Mod : public BinaryOpNode<Mod> {
 public:
  Mod(ExprPtr lhs, ExprPtr rhs)
      : BinaryOpNode(std::move(lhs), std::move(rhs), IRNodeType::kMod) {}
};

template <typename Op>
class BitwiseOpNode : public BinaryOpNode<Op> {
 public:
  BitwiseOpNode(ExprPtr lhs, ExprPtr rhs, IRNodeType type)
      : BinaryOpNode<Op>(std::move(lhs), std::move(rhs), type) {}
  static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) {
    if (!lhs.dtype().is_integral()) {
      throw unsupported_dtype();
    }
    if (lhs.dtype() != rhs.dtype()) {
      throw malformed_input("lhs/rhs dtype mismatch");
    }
    return BinaryOpNode<Op>::make(lhs, rhs);
  }
};

class TORCH_API And : public BitwiseOpNode<And> {
 public:
  And(ExprPtr lhs, ExprPtr rhs)
      : BitwiseOpNode(std::move(lhs), std::move(rhs), IRNodeType::kAnd) {}
};

class TORCH_API Or : public BitwiseOpNode<Or> {
 public:
  Or(ExprPtr lhs, ExprPtr rhs)
      : BitwiseOpNode(std::move(lhs), std::move(rhs), IRNodeType::kOr) {}
};

class TORCH_API Xor : public BitwiseOpNode<Xor> {
 public:
  Xor(ExprPtr lhs, ExprPtr rhs)
      : BitwiseOpNode(std::move(lhs), std::move(rhs), IRNodeType::kXor) {}
};

class TORCH_API Lshift : public BitwiseOpNode<Lshift> {
 public:
  Lshift(ExprPtr lhs, ExprPtr rhs)
      : BitwiseOpNode(std::move(lhs), std::move(rhs), IRNodeType::kLshift) {}
};

class TORCH_API Rshift : public BitwiseOpNode<Rshift> {
 public:
  Rshift(ExprPtr lhs, ExprPtr rhs)
      : BitwiseOpNode(std::move(lhs), std::move(rhs), IRNodeType::kRshift) {}
};

// TODO: add TORCH_API
// Currently adding it results in a compilation error on Windows
class Max : public BinaryOpNode<Max> {
 private:
  bool propagate_nans_;

 public:
  Max(ExprPtr lhs, ExprPtr rhs, bool propagate_nans)
      : BinaryOpNode(std::move(lhs), std::move(rhs), IRNodeType::kMax),
        propagate_nans_(propagate_nans) {}

  bool propagate_nans() const {
    return propagate_nans_;
  }

  static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) = delete;
  static ExprHandle make(
      const ExprHandle& lhs,
      const ExprHandle& rhs,
      bool propagate_nans) {
    return ExprHandle(alloc<Max>(lhs.node(), rhs.node(), propagate_nans));
  }
};

// TODO: add TORCH_API
// Currently adding it results in a compilation error on Windows
class Min : public BinaryOpNode<Min> {
 private:
  bool propagate_nans_;

 public:
  Min(ExprPtr lhs, ExprPtr rhs, bool propagate_nans)
      : BinaryOpNode(std::move(lhs), std::move(rhs), IRNodeType::kMin),
        propagate_nans_(propagate_nans) {}

  bool propagate_nans() const {
    return propagate_nans_;
  }

  static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) = delete;
  static ExprHandle make(
      const ExprHandle& lhs,
      const ExprHandle& rhs,
      bool propagate_nans) {
    return ExprHandle(alloc<Min>(lhs.node(), rhs.node(), propagate_nans));
  }
};
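// Usage sketch (illustrative only): the two-argument make() overloads of Max
// and Min are deleted, so callers must pick the NaN behavior explicitly.
// With hypothetical float handles `a` and `b`:
//
//   ExprHandle m = Max::make(a, b, /*propagate_nans=*/true);
//   ExprHandle n = Min::make(a, b, /*propagate_nans=*/false);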
// Encode typed immediate values e.g. IntImm, FloatImm.
#define IMM_DECLARE(Type, Name)                               \
  class TORCH_API Name##Imm : public ExprNode<Name##Imm> {    \
   public:                                                    \
    Name##Imm(Type value)                                     \
        : ExprNodeBase(k##Name, kPrimitive), value_(value) {} \
    bool isConstant() const override {                        \
      return true;                                            \
    }                                                         \
    Type value() const {                                      \
      return value_;                                          \
    }                                                         \
    static ExprHandle make(Type value) {                      \
      return ExprHandle(alloc<Name##Imm>(value));             \
    }                                                         \
                                                              \
   private:                                                   \
    Type value_;                                              \
  };
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_DECLARE)
#undef IMM_DECLARE

// Get immediate by ScalarType.
template <typename T>
ExprPtr getImmediateByType(ScalarType immType, T initialVal) {
  switch (immType) {
#define TYPE_CASE(Type, Name) \
  case ScalarType::Name:      \
    return alloc<Name##Imm>(Type(initialVal));
    AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
#undef TYPE_CASE
    default:
      throw unsupported_dtype();
  }
  return nullptr;
}

template <typename T>
ExprPtr getImmediateByType(Dtype dtype, T initialVal) {
  return getImmediateByType<T>(dtype.scalar_type(), initialVal);
}

template <typename T>
ExprPtr immLike(const ExprPtr& e, T v) {
  return getImmediateByType<T>(e->dtype(), v);
}

template <typename T>
ExprPtr immLike(const ExprHandle& e, T v) {
  return immLike(e.node(), v);
}

inline std::optional<int64_t> intValue(const ExprPtr& e) {
#define TYPE_CASE(Type, Name)      \
  if (auto v = to<Name##Imm>(e)) { \
    return v->value();             \
  }
  AT_FORALL_INT_TYPES(TYPE_CASE);
#undef TYPE_CASE
  return std::nullopt;
}

inline std::optional<int64_t> intValue(const ExprHandle& e) {
  return intValue(e.node());
}

template <typename T>
T immediateAs(const ExprPtr& e) {
#define TYPE_CASE(Type, Name)                \
  if (Name##ImmPtr imm = to<Name##Imm>(e)) { \
    return imm->value();                     \
  }
  AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
#undef TYPE_CASE
  throw unsupported_dtype();
  return 0;
}

template <typename T>
T immediateAs(const ExprHandle& e) {
  return immediateAs<T>(e.node());
}

template <typename T>
bool immediateEquals(const ExprPtr& e, T val) {
#define TYPE_CASE(Type, Name)                \
  if (Name##ImmPtr imm = to<Name##Imm>(e)) { \
    return imm->value() == val;              \
  }
  AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
#undef TYPE_CASE
  throw unsupported_dtype();
  return false;
}

TORCH_API bool immediateIsNegative(const ExprPtr& e);

TORCH_API bool immediateIsPositive(const ExprPtr& e);

TORCH_API bool immediateIsZero(const ExprPtr& e);
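// Usage sketch (illustrative only; `e` is a hypothetical handle): immLike
// builds an immediate matching another expression's dtype, and intValue
// extracts the value of an integral immediate if there is one:
//
//   ExprPtr one = immLike(e, 1); // immediate of e's dtype with value 1
//   if (auto v = intValue(e)) {  // std::optional<int64_t>
//     bool isZero = (*v == 0);
//   }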
// Represents a ramp vector node:
//     [base, base + 1 * stride, ..., base + (lanes - 1) * stride]
class TORCH_API Ramp : public ExprNode<Ramp> {
 public:
  ExprPtr base() const {
    return base_;
  }
  ExprPtr stride() const {
    return stride_;
  }

  void set_base(ExprPtr base) {
    base_ = std::move(base);
  }

  void set_stride(ExprPtr stride) {
    stride_ = std::move(stride);
  }

  static ExprHandle make(
      const ExprHandle& base,
      const ExprHandle& stride,
      int64_t lanes) {
    if (stride.dtype() != base.dtype()) {
      throw malformed_input("Bad stride in Ramp");
    }
    return ExprHandle(alloc<Ramp>(base.node(), stride.node(), lanes));
  }
  int64_t lanes() const {
    return lanes_;
  }

  Ramp(ExprPtr base, ExprPtr stride, int64_t lanes)
      : ExprNodeBase(Dtype(base->dtype(), lanes)),
        base_(std::move(base)),
        stride_(std::move(stride)),
        lanes_(lanes) {}

 private:
  ExprPtr base_;
  ExprPtr stride_;
  int64_t lanes_;
};

class TORCH_API Load : public ExprNode<Load> {
 public:
  VarPtr base_handle() const {
    return buf_->base_handle();
  }
  std::vector<ExprPtr> indices() const {
    return indices_;
  }
  ExprPtr flat_index() const {
    TORCH_CHECK(indices_.size() == 1, "Indices haven't been flattened.");
    return indices_[0];
  }
  BufPtr buf() const {
    return buf_;
  }

  void set_buf(BufPtr buf) {
    buf_ = std::move(buf);
  }

  void set_indices(std::vector<ExprPtr> indices) {
    indices_ = std::move(indices);
  }

  static ExprHandle make(
      Dtype dtype,
      const BufHandle& buf,
      const std::vector<ExprHandle>& indices);
  static ExprHandle make(
      const BufHandle& buf,
      const std::vector<ExprHandle>& indices);

  Load(Dtype dtype, BufPtr base_handle, std::vector<ExprPtr> indices);
  Load(const BufPtr& base_handle, const std::vector<ExprPtr>& indices);

 private:
  BufPtr buf_;
  std::vector<ExprPtr> indices_;
};

class TORCH_API Broadcast : public ExprNode<Broadcast> {
 public:
  ExprPtr value() const {
    return value_;
  }

  void set_value(ExprPtr value) {
    value_ = std::move(value);
  }

  int64_t lanes() const {
    return lanes_;
  }
  static ExprHandle make(const ExprHandle& value, int64_t lanes) {
    return ExprHandle(alloc<Broadcast>(value.node(), lanes));
  }
  Broadcast(ExprPtr value, int64_t lanes)
      : ExprNodeBase(Dtype(value->dtype(), lanes)),
        value_(std::move(value)),
        lanes_(lanes) {}

 private:
  ExprPtr value_;
  int64_t lanes_;
};
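// Usage sketch (illustrative only): Ramp and Broadcast are the two ways to
// build vector-valued expressions. With a hypothetical kInt handle `i`, an
// 8-lane index vector and a matching 8-lane splat look like:
//
//   ExprHandle idx = Ramp::make(i, IntImm::make(1), 8); // [i, i+1, ..., i+7]
//   ExprHandle splat = Broadcast::make(FloatImm::make(0.f), 8); // 8 x 0.f
//
// Ramp::make throws malformed_input unless base and stride share a dtype.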
class TORCH_API IfThenElse : public ExprNode<IfThenElse> {
 public:
  ExprPtr condition() const {
    return condition_;
  }

  // Lazily evaluated only if condition is true
  ExprPtr true_value() const {
    return true_;
  }

  // Lazily evaluated only if condition is false
  ExprPtr false_value() const {
    return false_;
  }

  void set_condition(ExprPtr condition) {
    condition_ = std::move(condition);
  }

  void set_true_value(ExprPtr true_value) {
    true_ = std::move(true_value);
  }

  void set_false_value(ExprPtr false_value) {
    false_ = std::move(false_value);
  }

  static ExprHandle make(
      const ExprHandle& c,
      const ExprHandle& t,
      const ExprHandle& f) {
    if (!c.dtype().is_integral()) {
      throw unsupported_dtype();
    }
    if (c.dtype().lanes() != 1) {
      throw unsupported_dtype();
    }
    if (t.dtype() != f.dtype()) {
      throw malformed_input("Bad dtype in IfThenElse");
    }
    return ExprHandle(alloc<IfThenElse>(c.node(), t.node(), f.node()));
  }

  IfThenElse(ExprPtr c, ExprPtr t, ExprPtr f)
      : ExprNodeBase(t->dtype()),
        condition_(std::move(c)),
        true_(std::move(t)),
        false_(std::move(f)) {}

 private:
  ExprPtr condition_;
  ExprPtr true_;
  ExprPtr false_;
};

class TORCH_API CompareSelect : public ExprNode<CompareSelect> {
 public:
  CompareSelectOperation compare_select_op() const {
    return compare_op_;
  }
  ExprPtr lhs() const {
    return this->lhs_;
  }
  ExprPtr rhs() const {
    return this->rhs_;
  }
  ExprPtr ret_val1() const {
    return this->ret_val1_;
  }
  ExprPtr ret_val2() const {
    return this->ret_val2_;
  }

  void set_lhs(ExprPtr lhs) {
    lhs_ = std::move(lhs);
  }

  void set_rhs(ExprPtr rhs) {
    rhs_ = std::move(rhs);
  }

  void set_ret_val1(ExprPtr ret_val1) {
    ret_val1_ = std::move(ret_val1);
  }

  void set_ret_val2(ExprPtr ret_val2) {
    ret_val2_ = std::move(ret_val2);
  }

  CompareSelectBias bias() const {
    return bias_;
  }

  static ExprHandle make(
      const ExprHandle& lhs,
      const ExprHandle& rhs,
      CompareSelectOperation cmp_op,
      CompareSelectBias bias = kUnbiased) {
    if (lhs.dtype() != rhs.dtype()) {
      throw malformed_input("bad dtype in CompareSelect");
    }
    return ExprHandle(alloc<CompareSelect>(
        lhs.node(),
        rhs.node(),
        IntImm::make(1).node(),
        IntImm::make(0).node(),
        cmp_op,
        bias));
  }

  static ExprHandle make(
      const ExprHandle& lhs,
      const ExprHandle& rhs,
      const ExprHandle& ret_val1,
      const ExprHandle& ret_val2,
      CompareSelectOperation cmp_op,
      CompareSelectBias bias = kUnbiased) {
    if (lhs.dtype() != rhs.dtype() || ret_val1.dtype() != ret_val2.dtype()) {
      throw malformed_input("bad dtype in CompareSelect");
    }
    return ExprHandle(alloc<CompareSelect>(
        lhs.node(),
        rhs.node(),
        ret_val1.node(),
        ret_val2.node(),
        cmp_op,
        bias));
  }

  CompareSelect(
      ExprPtr lhs,
      ExprPtr rhs,
      ExprPtr ret_val1,
      ExprPtr ret_val2,
      CompareSelectOperation cmp_op,
      CompareSelectBias bias = kUnbiased)
      : ExprNodeBase(ret_val1->dtype()),
        lhs_(std::move(lhs)),
        rhs_(std::move(rhs)),
        ret_val1_(std::move(ret_val1)),
        ret_val2_(std::move(ret_val2)),
        compare_op_(cmp_op),
        bias_(bias) {}

  CompareSelect(
      ExprPtr lhs,
      ExprPtr rhs,
      CompareSelectOperation cmp_op,
      CompareSelectBias bias = kUnbiased)
      : ExprNodeBase(kInt),
        lhs_(std::move(lhs)),
        rhs_(std::move(rhs)),
        ret_val1_(alloc<IntImm>(1)),
        ret_val2_(alloc<IntImm>(0)),
        compare_op_(cmp_op),
        bias_(bias) {}

 private:
  ExprPtr lhs_;
  ExprPtr rhs_;
  ExprPtr ret_val1_;
  ExprPtr ret_val2_;
  CompareSelectOperation compare_op_;
  CompareSelectBias bias_;
};

enum IntrinsicsOp {
  kSin,
  kCos,
  kTan,
  kAsin,
  kAcos,
  kAtan,
  kAtan2,
  kSinh,
  kCosh,
  kTanh,
  kSigmoid,
  kExp,
  kExpm1,
  kAbs,
  kLog,
  kLog2,
  kLog10,
  kLog1p,
  kErf,
  kErfc,
  kSqrt,
  kRsqrt,
  kPow,
  kCeil,
  kFloor,
  kRound,
  kTrunc,
  kFmod,
  kRemainder,
  kLgamma,
  kFrac,
  kIsNan,
  kRand, // We need more discussion on this. Should we consider making it
         // stateful?
  kMaxIntrinsicsOp,
};

class TORCH_API Intrinsics : public ExprNode<Intrinsics> {
 public:
  static ExprHandle make(IntrinsicsOp op_type, const ExprHandle& v1) {
    return ExprHandle(alloc<Intrinsics>(op_type, v1.node()));
  }

  static ExprHandle make(
      IntrinsicsOp op_type,
      const ExprHandle& v1,
      const ExprHandle& v2) {
    return ExprHandle(alloc<Intrinsics>(op_type, v1.node(), v2.node()));
  }

  static ExprHandle make(
      IntrinsicsOp op_type,
      const std::vector<ExprHandle>& params) {
    std::vector<ExprPtr> params_nodes(params.size());
    for (size_t i = 0; i < params.size(); i++) {
      params_nodes[i] = params[i].node();
    }
    return ExprHandle(alloc<Intrinsics>(op_type, params_nodes));
  }

  static ExprHandle make(IntrinsicsOp op_type, Dtype dtype) {
    return ExprHandle(alloc<Intrinsics>(op_type, dtype));
  }

  IntrinsicsOp op_type() const {
    return op_type_;
  }

  std::string func_name() const {
    switch (op_type()) {
      case kSin:
        return "sin";
      case kCos:
        return "cos";
      case kTan:
        return "tan";
      case kAsin:
        return "asin";
      case kAcos:
        return "acos";
      case kAtan:
        return "atan";
      case kAtan2:
        return "atan2";
      case kSinh:
        return "sinh";
      case kCosh:
        return "cosh";
      case kTanh:
        return "tanh";
      case kSigmoid:
        return "sigmoid";
      case kExp:
        return "exp";
      case kAbs:
        return "abs";
      case kLog:
        return "log";
      case kLog2:
        return "log2";
      case kLog10:
        return "log10";
      case kLog1p:
        return "log1p";
      case kErf:
        return "erf";
      case kSqrt:
        return "sqrt";
      case kRsqrt:
        return "rsqrt";
      case kPow:
        return "pow";
      case kCeil:
        return "ceil";
      case kFloor:
        return "floor";
      case kRound:
        return "round";
      case kTrunc:
        return "trunc";
      case kRand:
        return "rand";
      case kFmod:
        return "fmod";
      case kRemainder:
        return "remainder";
      case kLgamma:
        return "lgamma";
      case kExpm1:
        return "expm1";
      case kErfc:
        return "erfc";
      case kFrac:
        return "frac";
      case kIsNan:
        return "isnan";
      default:
        throw std::runtime_error(
            "invalid op_type: " + std::to_string(op_type()));
    }
  }

  Intrinsics(IntrinsicsOp op_type, Dtype dtype)
      : ExprNodeBase(IntrinsicsDtype(op_type, dtype)),
        params_({}),
        op_type_(op_type) {
    if (OpArgCount(op_type) != 0) {
      throw malformed_input("bad arg count in Intrinsics");
    }
  }

  Intrinsics(IntrinsicsOp op_type, ExprPtr v1)
      : ExprNodeBase(IntrinsicsDtype(op_type, v1->dtype())),
        params_({std::move(v1)}),
        op_type_(op_type) {
    if (OpArgCount(op_type) != 1) {
      throw malformed_input("bad arg count in Intrinsics");
    }
  }

  Intrinsics(IntrinsicsOp op_type, ExprPtr v1, ExprPtr v2)
      : ExprNodeBase(IntrinsicsDtype(op_type, v1->dtype(), v2->dtype())),
        params_({std::move(v1), std::move(v2)}),
        op_type_(op_type) {
    if (OpArgCount(op_type) != 2) {
      throw malformed_input("bad arg count in Intrinsics");
    }
  }

  Intrinsics(IntrinsicsOp op_type, const std::vector<ExprPtr>& params)
      : ExprNodeBase(IntrinsicsDtype(op_type, params)),
        params_(params),
        op_type_(op_type) {
    if (OpArgCount(op_type) != nparams()) {
      throw malformed_input("bad arg count in Intrinsics");
    }
  }

  Intrinsics(IntrinsicsOp op_type, Dtype dtype, std::vector<ExprPtr> params)
      : ExprNodeBase(IntrinsicsDtype(op_type, dtype)),
        params_(std::move(params)),
        op_type_(op_type) {
    if (OpArgCount(op_type) != nparams()) {
      throw malformed_input("bad arg count in Intrinsics");
    }
  }

  bool isPure() const {
    return op_type_ != kRand;
  }

  size_t nparams() const {
    return params_.size();
  }

  ExprPtr param(size_t index) const {
    return params_[index];
  }
  const std::vector<ExprPtr>& params() const {
    return params_;
  }

  void set_params(std::vector<ExprPtr> params) {
    params_ = std::move(params);
  }

  static size_t OpArgCount(IntrinsicsOp op_type);

 private:
  static Dtype IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1);
  static Dtype IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1, Dtype dt2);
  static Dtype IntrinsicsDtype(
      IntrinsicsOp op_type,
      const std::vector<ExprPtr>& params);

  std::vector<ExprPtr> params_;
  IntrinsicsOp op_type_;
};
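// Usage sketch (illustrative only; `x` is a hypothetical float handle):
//
//   ExprHandle s = Intrinsics::make(kSin, x);    // unary intrinsic
//   ExprHandle p = Intrinsics::make(kPow, x, x); // binary intrinsic
//
// The constructors check OpArgCount(op_type) against the number of params,
// and isPure() is false only for kRand.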
TORCH_API std::vector<ExprPtr> ExprHandleVectorToExprVector(
    const std::vector<ExprHandle>& /*v*/);
TORCH_API std::vector<ExprHandle> ExprVectorToExprHandleVector(
    const std::vector<ExprPtr>& /*v*/);
TORCH_API std::vector<VarPtr> VarHandleVectorToVarVector(
    const std::vector<VarHandle>& /*v*/);
TORCH_API std::vector<VarHandle> VarVectorToVarHandleVector(
    const std::vector<VarPtr>& /*v*/);

TORCH_API ExprPtr flatten_index(
    const std::vector<ExprPtr>& dims,
    const std::vector<ExprPtr>& indices,
    const std::vector<ExprPtr>& strides);

} // namespace torch::jit::tensorexpr