mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: A few more quality of life improvements for NNC's python bindings: - Use standard `torch.dtype`s (rather than `te.Dtype`) - Make names optional (they don't seem to matter) - Make shapes optional - A few implicit conversions to make code cleaner Followup to https://github.com/pytorch/pytorch/issues/59920 Pull Request resolved: https://github.com/pytorch/pytorch/pull/60058 Reviewed By: bertmaher Differential Revision: D29151953 Pulled By: jansel fbshipit-source-id: c8286e329eb4ee3921ca0786e17248cf6a898bd8
380 lines
11 KiB
C++
380 lines
11 KiB
C++
/**
 * This file implements the core classes for Tensor Expressions.
 *
 * The structure of the expressions is inspired by Halide/TVM IR.
 */
|
|
#pragma once
|
|
|
|
#include <torch/csrc/jit/tensorexpr/ir_mutator.h>
|
|
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
|
|
#include <torch/csrc/jit/tensorexpr/mem_arena.h>
|
|
#include <torch/csrc/jit/tensorexpr/types.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
namespace tensorexpr {
|
|
|
|
// Coarse kind tag stored in every Expr (see Expr::expr_type()).
// Leaf nodes in this header (Var, Buf) are tagged kPrimitive; Expr's
// constructor defaults to kOther for anything without a dedicated tag.
// The enumerator order (and thus the numeric values) is significant and
// must not be changed.
enum IRNodeType {
  // Leaf / primitive nodes.
  kPrimitive,
  // Arithmetic-style operators (by name).
  kAdd,
  kSub,
  kMul,
  kDiv,
  kMod,
  kMax,
  kMin,
  // And/Or, shift and xor operators (by name).
  kAnd,
  kOr,
  kLshift,
  kRshift,
  kXor,
  // Selection and conversion nodes (by name).
  kCompareSelect,
  kCast,
  kBitCast,
  // Fallback for every node kind not listed above.
  kOther,
};
|
|
|
|
// The common base between all expression node.
|
|
class TORCH_API Expr : public KernelScopedObject {
|
|
public:
|
|
explicit Expr(Dtype dtype, IRNodeType expr_type = kOther)
|
|
: dtype_(dtype), expr_type_(expr_type) {}
|
|
Dtype dtype() const {
|
|
return dtype_;
|
|
}
|
|
virtual void accept(IRVisitor* visitor) const = 0;
|
|
virtual const Expr* accept_mutator(IRMutator* mutator) const = 0;
|
|
|
|
IRNodeType expr_type() const {
|
|
return expr_type_;
|
|
}
|
|
// Is this a fixed (constant) immediate value.
|
|
virtual bool isConstant() const {
|
|
return false;
|
|
}
|
|
|
|
private:
|
|
Dtype dtype_;
|
|
IRNodeType expr_type_;
|
|
};
|
|
|
|
// A CRTP pattern to accept visitors for children class,
|
|
// and dispatch back to the children.
|
|
template <class Op, class Base = Expr>
|
|
class ExprNode : public Base {
|
|
public:
|
|
using ExprNodeBase = ExprNode<Op>;
|
|
void accept(IRVisitor* visitor) const override {
|
|
visitor->visit(static_cast<const Op*>(this));
|
|
}
|
|
const Expr* accept_mutator(IRMutator* mutator) const override;
|
|
// pass the constructor to the base class
|
|
using Base::Base;
|
|
};
|
|
|
|
// A wrapper object to the underlying ExprNode.
|
|
// Also serves the primary way to build and operate on other expressions.
|
|
class TORCH_API ExprHandle {
|
|
public:
|
|
ExprHandle() = default;
|
|
explicit ExprHandle(const Expr* node)
|
|
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
|
: base_expr_node_(const_cast<Expr*>(node)) {}
|
|
|
|
Expr* node() {
|
|
return base_expr_node_;
|
|
}
|
|
|
|
const Expr* node() const {
|
|
return base_expr_node_;
|
|
}
|
|
|
|
bool empty() const {
|
|
return base_expr_node_ == nullptr;
|
|
}
|
|
|
|
#define IMM_EXPR_DECLARE(Type, Name) ExprHandle(Type v);
|
|
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_EXPR_DECLARE);
|
|
#undef IMM_EXPR_DECLARE
|
|
|
|
template <class Op>
|
|
Op* AsNode() {
|
|
return dynamic_cast<Op*>(this->node());
|
|
}
|
|
|
|
template <class Op>
|
|
const Op* AsNode() const {
|
|
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
|
return const_cast<ExprHandle*>(this)->AsNode<Op>();
|
|
}
|
|
|
|
Dtype dtype() const {
|
|
return node()->dtype();
|
|
}
|
|
|
|
// Handling the math operators.
|
|
ExprHandle operator+(const ExprHandle& other) const;
|
|
ExprHandle operator-(const ExprHandle& other) const;
|
|
ExprHandle operator*(const ExprHandle& other) const;
|
|
ExprHandle operator/(const ExprHandle& other) const;
|
|
ExprHandle operator%(const ExprHandle& other) const;
|
|
ExprHandle operator==(const ExprHandle& other) const;
|
|
ExprHandle operator!=(const ExprHandle& other) const;
|
|
ExprHandle operator>(const ExprHandle& other) const;
|
|
ExprHandle operator>=(const ExprHandle& other) const;
|
|
ExprHandle operator<(const ExprHandle& other) const;
|
|
ExprHandle operator<=(const ExprHandle& other) const;
|
|
ExprHandle operator&(const ExprHandle& other) const;
|
|
ExprHandle operator|(const ExprHandle& other) const;
|
|
ExprHandle operator&&(const ExprHandle& other) const;
|
|
ExprHandle operator||(const ExprHandle& other) const;
|
|
ExprHandle operator^(const ExprHandle& other) const;
|
|
ExprHandle operator<<(const ExprHandle& other) const;
|
|
ExprHandle operator>>(const ExprHandle& other) const;
|
|
|
|
private:
|
|
Expr* base_expr_node_ = nullptr;
|
|
};
|
|
|
|
// The underlying representation node to a Var.
|
|
// Currently, each Var object represents a unique variable, even though the
|
|
// names might be the same. We should consider add a unique_name as well.
|
|
class TORCH_API Var : public ExprNode<Var> {
|
|
public:
|
|
static ExprHandle make(const std::string& name_hint, Dtype dtype) {
|
|
return ExprHandle(new Var(name_hint, dtype));
|
|
}
|
|
static ExprHandle make(Dtype dtype) {
|
|
return ExprHandle(new Var("", dtype));
|
|
}
|
|
|
|
// TODO: unique_name
|
|
const std::string& name_hint() const {
|
|
return name_hint_;
|
|
}
|
|
|
|
void set_name_hint(const std::string& name_hint) {
|
|
name_hint_ = name_hint;
|
|
}
|
|
|
|
Var(std::string name_hint, Dtype dtype)
|
|
: ExprNodeBase(dtype, kPrimitive), name_hint_(std::move(name_hint)) {}
|
|
|
|
private:
|
|
std::string name_hint_;
|
|
};
|
|
|
|
// Expression node describing a buffer.
//
// A Buf holds the following:
//  - a base handle Var (the name-based constructor creates it with dtype
//    kHandle);
//  - a list of dimension expressions;
//  - the element Dtype;
//  - an optional initializer expression (nullptr when absent).
// The buffer's name hint is stored on, and read from, its base handle.
class TORCH_API Buf : public ExprNode<Buf> {
 public:
  // Factory functions (defined out of line). The second overload takes no
  // name hint.
  static ExprHandle make(
      const std::string& name_hint,
      const std::vector<ExprHandle>& dims,
      Dtype dtype);
  static ExprHandle make(const std::vector<ExprHandle>& dims, Dtype dtype);

  // Creates a fresh base handle Var of dtype kHandle named `name_hint`.
  Buf(const std::string& name_hint,
      const std::vector<const Expr*>& dims,
      Dtype dtype,
      const Expr* initializer = nullptr)
      : Buf(new Var(name_hint, kHandle), dims, dtype, initializer) {}

  // Uses `var` as the base handle; `var` must be non-null (checked).
  Buf(Var* var,
      std::vector<const Expr*> dims,
      Dtype dtype,
      const Expr* initializer = nullptr)
      : ExprNodeBase(dtype, kPrimitive),
        base_handle_(var),
        dims_(std::move(dims)),
        initializer_(initializer) {
    TORCH_CHECK(var);
  }

  // TODO: unique_name
  const Var* base_handle() const {
    return base_handle_;
  }
  // NOTE(review): unlike the constructor, this does not reject nullptr, yet
  // name_hint()/set_name_hint() dereference the handle unconditionally.
  void set_base_handle(Var* base_handle) {
    base_handle_ = base_handle;
  }

  const std::string& name_hint() const {
    return base_handle_->name_hint();
  }
  void set_name_hint(const std::string& name_hint) {
    base_handle_->set_name_hint(name_hint);
  }

  size_t ndim() const {
    return dims_.size();
  }

  // Bounds-checked access to one dimension; throws out_of_range_index()
  // when `index >= ndim()`.
  const Expr* dim(size_t index) const {
    if (index >= ndim()) {
      throw out_of_range_index();
    }
    return dims_[index];
  }

  // Returns a copy of the dimension list.
  std::vector<const Expr*> dims() const {
    return dims_;
  }

  // Replaces the dimension list. `dims` is already a by-value copy owned by
  // this call, so it is moved into place rather than copied again.
  void set_dims(std::vector<const Expr*> dims) {
    dims_ = std::move(dims);
  }

  // Initializer expression, or nullptr if none was supplied.
  const Expr* initializer() const {
    return initializer_;
  }

  // True if every dimension expression is constant (vacuously true when the
  // buffer has no dimensions).
  bool hasConstantDims() const {
    for (const Expr* d : dims_) {
      if (!d->isConstant()) {
        return false;
      }
    }
    return true;
  }

 private:
  Var* base_handle_ = nullptr;
  std::vector<const Expr*> dims_;
  const Expr* initializer_ = nullptr;
};
|
|
|
|
// Typed handle around a Buf node, adding buffer-specific accessors on top
// of ExprHandle.
class TORCH_API BufHandle : public ExprHandle {
 public:
  BufHandle(
      const std::string& name_hint,
      const std::vector<ExprHandle>& dims,
      Dtype dtype)
      : ExprHandle(Buf::make(name_hint, dims, dtype)) {}

  // Buffer with the placeholder name hint "_".
  BufHandle(const std::vector<ExprHandle>& dims, Dtype dtype)
      : BufHandle("_", dims, dtype) {}

  // Zero-dimensional buffer with the placeholder name hint "_".
  explicit BufHandle(Dtype dtype) : BufHandle("_", {}, dtype) {}

  // Wraps an existing Buf node.
  explicit BufHandle(const Buf* node) : ExprHandle(node) {}

  // The wrapped node, typed as Buf.
  const Buf* node() const {
    return static_cast<const Buf*>(ExprHandle::node());
  }
  Buf* node() {
    return static_cast<Buf*>(ExprHandle::node());
  }

  bool empty() const {
    return node() == nullptr;
  }

  // Load expressions from this buffer (defined out of line).
  template <typename... Ts>
  inline ExprHandle load(const Ts&... ts) const;
  template <typename T>
  inline ExprHandle load(const std::vector<T>& args) const;

  // Identity comparison: two handles are equal iff they wrap the same node.
  bool operator==(const BufHandle& other) const {
    return node() == other.node();
  }
  bool operator!=(const BufHandle& other) const {
    return !operator==(other);
  }

  const std::string& name_hint() const {
    return node()->name_hint();
  }

  size_t ndim() const {
    return node()->ndim();
  }
  std::vector<ExprHandle> dims() const;
  ExprHandle dim(size_t index) const {
    const Expr* d = node()->dim(index);
    return ExprHandle(d);
  }
};
|
|
|
|
// An expression to construct the underlying variable node.
|
|
// Note: do not store any info here, since it is often possible to slice this
|
|
// object. For example: VarHandle x('x'); ExprHandle x2 = x;
|
|
class TORCH_API VarHandle : public ExprHandle {
|
|
public:
|
|
VarHandle() : ExprHandle(nullptr) {}
|
|
explicit VarHandle(Dtype dtype) : ExprHandle(Var::make(dtype)) {}
|
|
VarHandle(const std::string& name_hint, Dtype dtype)
|
|
: ExprHandle(Var::make(name_hint, dtype)) {}
|
|
explicit VarHandle(const Var* node) : ExprHandle(node) {}
|
|
const Var* node() const {
|
|
return static_cast<const Var*>(ExprHandle::node());
|
|
}
|
|
bool operator==(const VarHandle& other) const {
|
|
return this->node() == other.node();
|
|
}
|
|
bool operator!=(const VarHandle& other) const {
|
|
return !(*this == other);
|
|
}
|
|
|
|
const std::string& name_hint() const {
|
|
return this->node()->name_hint();
|
|
}
|
|
bool empty() const {
|
|
return (this->node() == nullptr);
|
|
}
|
|
};
|
|
|
|
template <class Op, class Base>
|
|
const Expr* ExprNode<Op, Base>::accept_mutator(IRMutator* mutator) const {
|
|
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
|
ExprNode* this_mutable = const_cast<ExprNode*>(this);
|
|
return mutator->mutate(static_cast<Op*>(this_mutable));
|
|
}
|
|
|
|
inline bool same_node(const ExprHandle& expr1, const ExprHandle& expr2) {
|
|
return expr1.AsNode<Expr>() == expr2.AsNode<Expr>();
|
|
}
|
|
|
|
// Expression builders for math intrinsics and helpers. Each takes handles
// to its operand expressions and returns a handle to a new expression; the
// definitions live out of line.

// Trigonometric and hyperbolic functions.
TORCH_API ExprHandle sin(const ExprHandle& v);
TORCH_API ExprHandle cos(const ExprHandle& v);
TORCH_API ExprHandle tan(const ExprHandle& v);
TORCH_API ExprHandle asin(const ExprHandle& v);
TORCH_API ExprHandle acos(const ExprHandle& v);
TORCH_API ExprHandle atan(const ExprHandle& v);
TORCH_API ExprHandle sinh(const ExprHandle& v);
TORCH_API ExprHandle cosh(const ExprHandle& v);
TORCH_API ExprHandle tanh(const ExprHandle& v);

// Exponentials and logarithms.
TORCH_API ExprHandle exp(const ExprHandle& v);
TORCH_API ExprHandle expm1(const ExprHandle& v);
TORCH_API ExprHandle log(const ExprHandle& v);
TORCH_API ExprHandle log2(const ExprHandle& v);
TORCH_API ExprHandle log10(const ExprHandle& v);
TORCH_API ExprHandle log1p(const ExprHandle& v);

// Alternative implementations of the functions above (by name: fast
// approximations and a VML-style log).
TORCH_API ExprHandle fast_tanh(const ExprHandle& v);
TORCH_API ExprHandle fast_sigmoid(const ExprHandle& v);
TORCH_API ExprHandle fast_log(const ExprHandle& v);
TORCH_API ExprHandle log_vml(const ExprHandle& v);

// Other unary functions.
TORCH_API ExprHandle sigmoid(const ExprHandle& v);
TORCH_API ExprHandle abs(const ExprHandle& v);
TORCH_API ExprHandle erf(const ExprHandle& v);
TORCH_API ExprHandle erfc(const ExprHandle& v);
TORCH_API ExprHandle sqrt(const ExprHandle& v);
TORCH_API ExprHandle rsqrt(const ExprHandle& v);
TORCH_API ExprHandle ceil(const ExprHandle& v);
TORCH_API ExprHandle floor(const ExprHandle& v);
TORCH_API ExprHandle round(const ExprHandle& v);
TORCH_API ExprHandle trunc(const ExprHandle& v);
TORCH_API ExprHandle frac(const ExprHandle& v);
TORCH_API ExprHandle lgamma(const ExprHandle& v);
TORCH_API ExprHandle isnan(const ExprHandle& v1);
TORCH_API ExprHandle Relu(const ExprHandle& v1);

// Binary functions.
TORCH_API ExprHandle atan2(const ExprHandle& v1, const ExprHandle& v2);
TORCH_API ExprHandle pow(const ExprHandle& v1, const ExprHandle& v2);
TORCH_API ExprHandle fmod(const ExprHandle& v1, const ExprHandle& v2);
TORCH_API ExprHandle remainder(const ExprHandle& v1, const ExprHandle& v2);

// Selects `t` or `f` depending on the condition `c`.
TORCH_API ExprHandle
ifThenElse(const ExprHandle& c, const ExprHandle& t, const ExprHandle& f);

// Builds a `lanes`-wide vector expression from `v`.
TORCH_API ExprHandle expr_to_vec(ExprHandle v, int lanes);
|
|
|
|
} // namespace tensorexpr
|
|
} // namespace jit
|
|
} // namespace torch
|