pytorch/torch/csrc/jit/tensorexpr/expr.h
Jason Ansel 85517a2b70 [TensorExpr] More python binding cleanups (#60058)
Summary:
A few more quality of life improvements for NNC's python bindings:
- Use standard `torch.dtype`s (rather than `te.Dtype`)
- Make names optional (they don't seem to matter)
- Make shapes optional
- A few implicit conversions to make code cleaner

Followup to https://github.com/pytorch/pytorch/issues/59920

Pull Request resolved: https://github.com/pytorch/pytorch/pull/60058

Reviewed By: bertmaher

Differential Revision: D29151953

Pulled By: jansel

fbshipit-source-id: c8286e329eb4ee3921ca0786e17248cf6a898bd8
2021-06-16 20:06:08 -07:00

380 lines
11 KiB
C++

/**
* This file implements the core classes for Tensor Expressions.
*
* The structure of the expressions is inspired by Halide/TVM IR.
*/
#pragma once
#include <torch/csrc/jit/tensorexpr/ir_mutator.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <torch/csrc/jit/tensorexpr/mem_arena.h>
#include <torch/csrc/jit/tensorexpr/types.h>
namespace torch {
namespace jit {
namespace tensorexpr {
// Tag stored in every Expr and returned by Expr::expr_type().
// Var and Buf construct themselves with kPrimitive; kOther is the default
// used when a node does not pass an explicit tag to the Expr constructor.
// This is deliberately an unscoped enum: the enumerators are referenced
// unqualified in this header (e.g. kPrimitive, kOther).
// NOTE(review): the remaining enumerators presumably correspond to the
// operator node classes of the same name defined elsewhere — confirm.
enum IRNodeType {
kPrimitive,
kAdd,
kSub,
kMul,
kDiv,
kMod,
kMax,
kMin,
kAnd,
kOr,
kLshift,
kRshift,
kXor,
kCompareSelect,
kCast,
kBitCast,
// Default tag for nodes that do not specify one.
kOther,
};
// Common base class of all expression nodes. It records the node's Dtype and
// its IRNodeType tag, and declares the visitor/mutator entry points that
// concrete nodes have to implement.
class TORCH_API Expr : public KernelScopedObject {
 public:
  // `expr_type` falls back to kOther when the caller does not supply a tag.
  explicit Expr(Dtype dtype, IRNodeType expr_type = kOther)
      : dtype_(dtype), expr_type_(expr_type) {}

  // Plain accessors for the two values captured at construction time.
  Dtype dtype() const {
    return dtype_;
  }
  IRNodeType expr_type() const {
    return expr_type_;
  }

  // Double-dispatch hooks: implementations hand `this` to the given visitor
  // or mutator.
  virtual void accept(IRVisitor* visitor) const = 0;
  virtual const Expr* accept_mutator(IRMutator* mutator) const = 0;

  // Is this a fixed (constant) immediate value? The base implementation says
  // no; subclasses may override.
  virtual bool isConstant() const {
    return false;
  }

 private:
  Dtype dtype_;
  IRNodeType expr_type_;
};
// CRTP helper that implements the visitor/mutator entry points for a concrete
// node type `Op`: `this` is cast back to `Op` and handed to the visitor (or
// mutator), so the overload for the concrete type is selected.
//
// `Base` is the class the helper derives from (Expr by default). Its
// constructors are inherited, so a node can forward constructor arguments
// straight through this helper.
template <class Op, class Base = Expr>
class ExprNode : public Base {
 public:
  // Alias for this exact specialization, convenient in the mem-initializer
  // lists of derived nodes (e.g. `ExprNodeBase(dtype, kPrimitive)`).
  // `Base` has to be spelled out: `ExprNode<Op>` alone always names
  // `ExprNode<Op, Expr>`, which is a different class whenever a non-default
  // `Base` is used.
  using ExprNodeBase = ExprNode<Op, Base>;

  void accept(IRVisitor* visitor) const override {
    visitor->visit(static_cast<const Op*>(this));
  }

  // Defined out of line further down in this header.
  const Expr* accept_mutator(IRMutator* mutator) const override;

  // pass the constructor to the base class
  using Base::Base;
};
// A wrapper object to the underlying ExprNode.
// Also serves the primary way to build and operate on other expressions.
class TORCH_API ExprHandle {
public:
ExprHandle() = default;
explicit ExprHandle(const Expr* node)
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
: base_expr_node_(const_cast<Expr*>(node)) {}
Expr* node() {
return base_expr_node_;
}
const Expr* node() const {
return base_expr_node_;
}
bool empty() const {
return base_expr_node_ == nullptr;
}
#define IMM_EXPR_DECLARE(Type, Name) ExprHandle(Type v);
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_EXPR_DECLARE);
#undef IMM_EXPR_DECLARE
template <class Op>
Op* AsNode() {
return dynamic_cast<Op*>(this->node());
}
template <class Op>
const Op* AsNode() const {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
return const_cast<ExprHandle*>(this)->AsNode<Op>();
}
Dtype dtype() const {
return node()->dtype();
}
// Handling the math operators.
ExprHandle operator+(const ExprHandle& other) const;
ExprHandle operator-(const ExprHandle& other) const;
ExprHandle operator*(const ExprHandle& other) const;
ExprHandle operator/(const ExprHandle& other) const;
ExprHandle operator%(const ExprHandle& other) const;
ExprHandle operator==(const ExprHandle& other) const;
ExprHandle operator!=(const ExprHandle& other) const;
ExprHandle operator>(const ExprHandle& other) const;
ExprHandle operator>=(const ExprHandle& other) const;
ExprHandle operator<(const ExprHandle& other) const;
ExprHandle operator<=(const ExprHandle& other) const;
ExprHandle operator&(const ExprHandle& other) const;
ExprHandle operator|(const ExprHandle& other) const;
ExprHandle operator&&(const ExprHandle& other) const;
ExprHandle operator||(const ExprHandle& other) const;
ExprHandle operator^(const ExprHandle& other) const;
ExprHandle operator<<(const ExprHandle& other) const;
ExprHandle operator>>(const ExprHandle& other) const;
private:
Expr* base_expr_node_ = nullptr;
};
// Node that backs a variable.
// Each Var object currently stands for a distinct variable, even when several
// of them carry the same name hint. We should consider adding a unique_name
// as well.
class TORCH_API Var : public ExprNode<Var> {
 public:
  // Tags the node as kPrimitive and takes ownership of the name string.
  Var(std::string name_hint, Dtype dtype)
      : ExprNodeBase(dtype, kPrimitive), name_hint_(std::move(name_hint)) {}

  // Allocates a new Var and returns it wrapped in an ExprHandle.
  static ExprHandle make(const std::string& name_hint, Dtype dtype) {
    const Var* var = new Var(name_hint, dtype);
    return ExprHandle(var);
  }

  // Same as above, with an empty name hint.
  static ExprHandle make(Dtype dtype) {
    return make("", dtype);
  }

  // TODO: unique_name
  const std::string& name_hint() const {
    return name_hint_;
  }
  void set_name_hint(const std::string& name_hint) {
    name_hint_ = name_hint;
  }

 private:
  std::string name_hint_;
};
// Node describing a buffer: a base Var handle, a list of dimension
// expressions, a Dtype and an optional initializer expression.
class TORCH_API Buf : public ExprNode<Buf> {
 public:
  // Factory helpers returning an ExprHandle; both are only declared here.
  static ExprHandle make(
      const std::string& name_hint,
      const std::vector<ExprHandle>& dims,
      Dtype dtype);
  static ExprHandle make(const std::vector<ExprHandle>& dims, Dtype dtype);

  // TODO: unique_name
  const Var* base_handle() const {
    return base_handle_;
  }
  void set_base_handle(Var* base_handle) {
    base_handle_ = base_handle;
  }

  // The name hint lives on the base handle rather than on the Buf itself, so
  // both accessors simply forward to it.
  const std::string& name_hint() const {
    return base_handle_->name_hint();
  }
  void set_name_hint(const std::string& name_hint) {
    base_handle_->set_name_hint(name_hint);
  }

  // Creates a fresh Var of dtype kHandle named `name_hint` to serve as the
  // base handle, then delegates to the constructor below.
  Buf(const std::string& name_hint,
      const std::vector<const Expr*>& dims,
      Dtype dtype,
      const Expr* initializer = nullptr)
      : Buf(new Var(name_hint, kHandle), dims, dtype, initializer) {}

  // Builds a buffer on top of an existing base handle. `var` must not be
  // null; this is enforced with TORCH_CHECK. `initializer` may be null.
  Buf(Var* var,
      std::vector<const Expr*> dims,
      Dtype dtype,
      const Expr* initializer = nullptr)
      : ExprNodeBase(dtype, kPrimitive),
        base_handle_(var),
        dims_(std::move(dims)),
        initializer_(initializer) {
    TORCH_CHECK(var);
  }

  // Number of dimensions.
  size_t ndim() const {
    return dims_.size();
  }

  // Expression for dimension `index`; throws out_of_range_index when `index`
  // is not smaller than ndim().
  const Expr* dim(size_t index) const {
    if (index >= ndim()) {
      throw out_of_range_index();
    }
    return dims_[index];
  }

  // Returns a copy of the dimension list.
  std::vector<const Expr*> dims() const {
    return dims_;
  }

  // Replaces the dimension list. The argument is taken by value, so it is
  // moved into place instead of being copied a second time.
  void set_dims(std::vector<const Expr*> dims) {
    dims_ = std::move(dims);
  }

  // Initializer expression, or nullptr when none was supplied.
  const Expr* initializer() const {
    return initializer_;
  }

  // True when every dimension expression reports isConstant(); a buffer with
  // no dimensions trivially qualifies.
  bool hasConstantDims() const {
    for (const Expr* d : dims_) {
      if (!d->isConstant()) {
        return false;
      }
    }
    return true;
  }

 private:
  Var* base_handle_;
  std::vector<const Expr*> dims_;
  const Expr* initializer_;
};
// Handle whose wrapped node is a Buf. It adds typed node() accessors and
// forwards a few Buf queries (name hint, rank, dimensions).
class TORCH_API BufHandle : public ExprHandle {
 public:
  // Creates a new Buf through Buf::make and wraps it.
  BufHandle(
      const std::string& name_hint,
      const std::vector<ExprHandle>& dims,
      Dtype dtype)
      : ExprHandle(Buf::make(name_hint, dims, dtype)) {}

  // Variants without an explicit name use "_" as the name hint; the
  // dtype-only form additionally uses an empty dimension list.
  BufHandle(const std::vector<ExprHandle>& dims, Dtype dtype)
      : BufHandle("_", dims, dtype) {}
  explicit BufHandle(Dtype dtype) : BufHandle("_", {}, dtype) {}

  // Wraps an already existing Buf node.
  explicit BufHandle(const Buf* node) : ExprHandle(node) {}

  // Typed views of the wrapped node. These are unchecked static casts.
  const Buf* node() const {
    return static_cast<const Buf*>(ExprHandle::node());
  }
  Buf* node() {
    return static_cast<Buf*>(ExprHandle::node());
  }

  // Load helpers; declared here, defined elsewhere.
  template <typename... Ts>
  inline ExprHandle load(const Ts&... ts) const;
  template <typename T>
  inline ExprHandle load(const std::vector<T>& args) const;

  // Identity comparison: two handles are equal when they wrap the same node.
  // Unlike the ExprHandle operators these yield a plain bool.
  bool operator==(const BufHandle& other) const {
    return node() == other.node();
  }
  bool operator!=(const BufHandle& other) const {
    return node() != other.node();
  }

  const std::string& name_hint() const {
    return node()->name_hint();
  }

  // True when no node is wrapped.
  bool empty() const {
    return node() == nullptr;
  }

  size_t ndim() const {
    return node()->ndim();
  }

  // All dimensions as handles; defined elsewhere.
  std::vector<ExprHandle> dims() const;

  // Dimension `index`, wrapped in a handle.
  ExprHandle dim(size_t index) const {
    return ExprHandle(node()->dim(index));
  }
};
// Handle used to create and refer to an underlying Var node.
// Note: do not add data members to this class. Objects of this type are
// frequently sliced down to ExprHandle, for example:
//   VarHandle x('x'); ExprHandle x2 = x;
class TORCH_API VarHandle : public ExprHandle {
 public:
  // An empty handle that wraps no node.
  VarHandle() : ExprHandle(nullptr) {}

  // Create a brand-new Var, without or with a name hint.
  explicit VarHandle(Dtype dtype) : ExprHandle(Var::make(dtype)) {}
  VarHandle(const std::string& name_hint, Dtype dtype)
      : ExprHandle(Var::make(name_hint, dtype)) {}

  // Wrap an existing Var node.
  explicit VarHandle(const Var* node) : ExprHandle(node) {}

  // Typed view of the wrapped node. This is an unchecked static cast.
  const Var* node() const {
    return static_cast<const Var*>(ExprHandle::node());
  }

  const std::string& name_hint() const {
    return node()->name_hint();
  }

  // True when no node is wrapped.
  bool empty() const {
    return node() == nullptr;
  }

  // Identity comparison: equal when both handles wrap the same node.
  // Unlike the ExprHandle operators these yield a plain bool.
  bool operator==(const VarHandle& other) const {
    return node() == other.node();
  }
  bool operator!=(const VarHandle& other) const {
    return node() != other.node();
  }
};
// Out-of-line definition of ExprNode::accept_mutator: recovers the concrete
// `Op` from `this`, drops constness and hands the node to the mutator.
template <class Op, class Base>
const Expr* ExprNode<Op, Base>::accept_mutator(IRMutator* mutator) const {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  return mutator->mutate(static_cast<Op*>(const_cast<ExprNode*>(this)));
}
inline bool same_node(const ExprHandle& expr1, const ExprHandle& expr2) {
return expr1.AsNode<Expr>() == expr2.AsNode<Expr>();
}
// Free functions that take ExprHandle operands and return an ExprHandle.
// Only the declarations live in this header.
// NOTE(review): each function presumably builds an expression node applying
// the operation of the same name to its operand(s) — confirm against the
// definitions.

// Single-operand functions.
TORCH_API ExprHandle sin(const ExprHandle& v);
TORCH_API ExprHandle cos(const ExprHandle& v);
TORCH_API ExprHandle tan(const ExprHandle& v);
TORCH_API ExprHandle asin(const ExprHandle& v);
TORCH_API ExprHandle acos(const ExprHandle& v);
TORCH_API ExprHandle atan(const ExprHandle& v);
TORCH_API ExprHandle sinh(const ExprHandle& v);
TORCH_API ExprHandle cosh(const ExprHandle& v);
TORCH_API ExprHandle tanh(const ExprHandle& v);
TORCH_API ExprHandle sigmoid(const ExprHandle& v);
TORCH_API ExprHandle exp(const ExprHandle& v);
TORCH_API ExprHandle expm1(const ExprHandle& v);
TORCH_API ExprHandle abs(const ExprHandle& v);
TORCH_API ExprHandle log(const ExprHandle& v);
TORCH_API ExprHandle fast_tanh(const ExprHandle& v);
TORCH_API ExprHandle fast_sigmoid(const ExprHandle& v);
TORCH_API ExprHandle fast_log(const ExprHandle& v);
TORCH_API ExprHandle log_vml(const ExprHandle& v);
TORCH_API ExprHandle log2(const ExprHandle& v);
TORCH_API ExprHandle log10(const ExprHandle& v);
TORCH_API ExprHandle log1p(const ExprHandle& v);
TORCH_API ExprHandle erf(const ExprHandle& v);
TORCH_API ExprHandle erfc(const ExprHandle& v);
TORCH_API ExprHandle sqrt(const ExprHandle& v);
TORCH_API ExprHandle rsqrt(const ExprHandle& v);
TORCH_API ExprHandle ceil(const ExprHandle& v);
TORCH_API ExprHandle floor(const ExprHandle& v);
TORCH_API ExprHandle round(const ExprHandle& v);
TORCH_API ExprHandle trunc(const ExprHandle& v);
TORCH_API ExprHandle frac(const ExprHandle& v);
TORCH_API ExprHandle lgamma(const ExprHandle& v);
// Two-operand functions.
TORCH_API ExprHandle atan2(const ExprHandle& v1, const ExprHandle& v2);
TORCH_API ExprHandle pow(const ExprHandle& v1, const ExprHandle& v2);
TORCH_API ExprHandle fmod(const ExprHandle& v1, const ExprHandle& v2);
TORCH_API ExprHandle remainder(const ExprHandle& v1, const ExprHandle& v2);
// Further single-operand functions.
TORCH_API ExprHandle isnan(const ExprHandle& v1);
TORCH_API ExprHandle Relu(const ExprHandle& v1);
// Three-operand function.
// NOTE(review): presumably selects `t` when `c` holds and `f` otherwise —
// confirm against the definition.
TORCH_API ExprHandle
ifThenElse(const ExprHandle& c, const ExprHandle& t, const ExprHandle& f);
// NOTE(review): presumably turns `v` into a vector expression with `lanes`
// lanes — confirm against the definition.
TORCH_API ExprHandle expr_to_vec(ExprHandle v, int lanes);
} // namespace tensorexpr
} // namespace jit
} // namespace torch