[LT] Store OpKind for each IR subclass in a static field

Summary: Currently OpKind is stored as an object field called op_ on each
IR node, and one use of op_ is to avoid dynamic_cast in NodeCast when we
need to downcast a base-node pointer into a concrete sub-node pointer.
As a result, we have to construct and pass in an op whenever we downcast
a node, which becomes quite annoying once we start to implement
trie-based IR node reuse. More importantly, the op for each subclass
should be unique to that subclass, so making it a const static field is
the more logical design.

In this PR, we still keep the object-level op_ for easier XLA adoption. As
future work, we can come back to remove op_, make the op() method
virtual, and get rid of OpKind in all the node constructors.
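
For illustration, the change at a typical call site (a minimal sketch,
using the DeviceData node updated in this PR; `node` is assumed to be a
const Node*):

  // Before: the caller constructs and passes the OpKind explicitly.
  const DeviceData* dd = NodeCast<DeviceData>(node, ltc_device_data);

  // After: NodeCast reads the static field, comparing
  // DeviceData::class_op_kind against node->op() internally.
  const DeviceData* dd = NodeCast<DeviceData>(node);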

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76711

Approved by: https://github.com/wconstab, https://github.com/JackCaoG
Bin Bao 2022-05-06 15:04:59 +00:00 committed by PyTorch MergeBot
parent 8b6a78f39f
commit ac37ddc795
56 changed files with 212 additions and 63 deletions

View File

@ -123,6 +123,7 @@ libtorch_cpp_generated_sources = [
"torch/csrc/autograd/generated/Functions.cpp",
"torch/csrc/autograd/generated/variable_factories.h",
"torch/csrc/lazy/generated/LazyIr.h",
"torch/csrc/lazy/generated/LazyIr.cpp",
"torch/csrc/lazy/generated/LazyNativeFunctions.h",
"torch/csrc/lazy/generated/LazyNativeFunctions.cpp",
"torch/csrc/lazy/generated/RegisterAutogradLazy.cpp",
@ -1914,6 +1915,7 @@ test_suite(
for path in [
"aten/src/ATen/templates/DispatchKeyNativeFunctions.cpp",
"aten/src/ATen/templates/DispatchKeyNativeFunctions.h",
"aten/src/ATen/templates/LazyIr.cpp",
"aten/src/ATen/templates/LazyIr.h",
"aten/src/ATen/templates/RegisterDispatchKey.cpp",
"aten/src/ATen/native/native_functions.yaml",

View File

@ -0,0 +1,8 @@
// ${generated_comment}
${includes}
${namespace_prologue}
${opkind_definitions}
${namespace_epilogue}

View File

@ -27,6 +27,7 @@ def define_targets(rules):
srcs = [
":DispatchKeyNativeFunctions.cpp",
":DispatchKeyNativeFunctions.h",
":LazyIr.cpp",
":LazyIr.h",
":RegisterDispatchKey.cpp",
":native_functions.yaml",
@ -111,6 +112,7 @@ _GENERATED_CPP = [
"torch/csrc/autograd/generated/python_torch_functions_1.cpp",
"torch/csrc/autograd/generated/python_torch_functions_2.cpp",
"torch/csrc/autograd/generated/python_variable_methods.cpp",
"torch/csrc/lazy/generated/LazyIr.cpp",
"torch/csrc/lazy/generated/LazyNativeFunctions.cpp",
"torch/csrc/lazy/generated/RegisterAutogradLazy.cpp",
"torch/csrc/lazy/generated/RegisterLazy.cpp",

View File

@ -352,6 +352,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
)
if(BUILD_LAZY_TS_BACKEND)
list(APPEND GENERATED_CXX_TORCH
"${TORCH_SRC_DIR}/csrc/lazy/generated/LazyIr.cpp"
"${TORCH_SRC_DIR}/csrc/lazy/generated/LazyNativeFunctions.cpp"
"${TORCH_SRC_DIR}/csrc/lazy/generated/RegisterAutogradLazy.cpp"
"${TORCH_SRC_DIR}/csrc/lazy/generated/RegisterLazy.cpp"
@ -432,6 +433,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
"${TORCH_ROOT}/aten/src/ATen/templates/DispatchKeyNativeFunctions.h"
"${TORCH_ROOT}/aten/src/ATen/templates/DispatchKeyNativeFunctions.cpp"
"${TORCH_ROOT}/aten/src/ATen/templates/LazyIr.h"
"${TORCH_ROOT}/aten/src/ATen/templates/LazyIr.cpp"
"${TORCH_ROOT}/aten/src/ATen/templates/RegisterDispatchKey.cpp"
"${TOOLS_PATH}/autograd/templates/VariableType.h"
"${TOOLS_PATH}/autograd/templates/VariableType.cpp"

View File

@ -17,6 +17,8 @@ namespace lazy {
class TestLeafNode : public Node {
public:
static const OpKind class_op_kind;
explicit TestLeafNode(size_t param)
: Node(OpKind(), /* num_outputs */ 1),
hash_(Hash(param)),
@ -38,6 +40,8 @@ class TestLeafNode : public Node {
size_t param_;
};
const OpKind TestLeafNode::class_op_kind = OpKind();
TEST(IrTest, BasicTest) {
NodePtr node1 = MakeNode<TestLeafNode>(1);
NodePtr node2 = MakeNode<TestLeafNode>(2);
@ -45,7 +49,7 @@ TEST(IrTest, BasicTest) {
EXPECT_EQ(node1->num_outputs(), 1);
const TestLeafNode* leafptr = NodeCast<TestLeafNode>(node1.get(), OpKind());
const TestLeafNode* leafptr = NodeCast<TestLeafNode>(node1.get());
EXPECT_TRUE(leafptr != nullptr);
}
@ -102,7 +106,7 @@ TEST(IrTest, TsNodeTest) {
EXPECT_EQ(node1->num_outputs(), 1);
const TsNode* leafptr = NodeCast<TsNode>(node1.get(), OpKind(at::aten::view));
const TsNode* leafptr = dynamic_cast<const TsNode*>(node1.get());
EXPECT_TRUE(leafptr != nullptr);
}

View File

@ -13,8 +13,10 @@ namespace lazy {
class TrieCacheNode : public Node {
public:
static const OpKind class_op_kind;
explicit TrieCacheNode(size_t id)
: Node(OpKind(), /* num_outputs */ 1), id_(id), hash_(Hash(id_)) {}
: Node(class_op_kind, /* num_outputs */ 1), id_(id), hash_(Hash(id_)) {}
~TrieCacheNode() override = default;
bool Equal(size_t id) const {
@ -36,6 +38,8 @@ class TrieCacheNode : public Node {
hash_t hash_;
};
const OpKind TrieCacheNode::class_op_kind = OpKind();
TEST(TrieCacheTest, TestSinglePath) {
FLAGS_torch_lazy_reuse_ir = true;
TrieCache::Get()->Clear();
@ -45,9 +49,9 @@ TEST(TrieCacheTest, TestSinglePath) {
NodePtr c = MakeNode<TrieCacheNode>(2);
TrieCache::Get()->ResetCurrent(); // MarkStep
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 0).get(), a.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 1).get(), b.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 2).get(), c.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(2).get(), c.get());
TrieCache::Get()->ResetCurrent(); // MarkStep
}
@ -67,20 +71,20 @@ TEST(TrieCacheTest, TestTwoPaths) {
NodePtr c = MakeNode<TrieCacheNode>(2);
TrieCache::Get()->ResetCurrent(); // MarkStep
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 0).get(), a.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 1).get(), b.get());
NodePtr d = ReuseOrMakeNode<TrieCacheNode>(OpKind(), 3);
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get());
NodePtr d = ReuseOrMakeNode<TrieCacheNode>(3);
EXPECT_NE(d.get(), c.get());
TrieCache::Get()->ResetCurrent(); // MarkStep
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 0).get(), a.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 1).get(), b.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 3).get(), d.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(3).get(), d.get());
TrieCache::Get()->ResetCurrent(); // MarkStep
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 0).get(), a.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 1).get(), b.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 2).get(), c.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(2).get(), c.get());
TrieCache::Get()->ResetCurrent(); // MarkStep
}

View File

@ -11,6 +11,7 @@
# This is duplicated in caffe2/CMakeLists.txt for now and not yet used in buck
GENERATED_LAZY_TS_CPP = [
"lazy/generated/LazyIr.cpp",
"lazy/generated/LazyNativeFunctions.cpp",
"lazy/generated/RegisterAutogradLazy.cpp",
"lazy/generated/RegisterLazy.cpp",
@ -425,6 +426,7 @@ lazy_tensor_ts_sources = [
"torch/csrc/lazy/ts_backend/ops/expand.cpp",
"torch/csrc/lazy/ts_backend/ops/generic.cpp",
"torch/csrc/lazy/ts_backend/ops/scalar.cpp",
"torch/csrc/lazy/ts_backend/ops/to_copy.cpp",
"torch/csrc/lazy/ts_backend/view_ops/as_strided.cpp",
"torch/csrc/lazy/ts_backend/view_ops/as_strided_view_update.cpp",
"torch/csrc/lazy/ts_backend/view_ops/diagonal.cpp",

View File

@ -175,6 +175,8 @@ inline std::ostream& operator<<(std::ostream& stream, const Node& node) {
return stream;
}
// Note: Keep this version of NodeCast for smooth PyTorch/XLA migration, and
// clean up once the migration is done.
template <typename T>
const T* NodeCast(const Node* node, OpKind op) {
if (op != node->op()) {
@ -187,6 +189,18 @@ const T* NodeCast(const Node* node, OpKind op) {
#endif
}
template <typename T>
const T* NodeCast(const Node* node) {
if (T::class_op_kind != node->op()) {
return nullptr;
}
#ifdef NDEBUG
return static_cast<const T*>(node);
#else
return &dynamic_cast<const T&>(*node);
#endif
}
// Represents a specific output produced by a node. Since the output of a node
// can be composed by multiple outputs, the node+index coordinates fully qualify
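
Every concrete node now follows the same two-part pattern that the new
NodeCast overload above relies on: declare the static field in the
header, and define it once in the source file with that subclass's
unique OpKind. A condensed sketch, modeled on the Cast node below:

  // header
  class TORCH_API Cast : public TsNode {
   public:
    static const OpKind class_op_kind;
    // ...
  };

  // source
  const OpKind Cast::class_op_kind(ltc_cast);

With these in place, NodeCast<Cast>(node) can reject mismatched nodes by
comparing Cast::class_op_kind to node->op() before the (debug-only)
dynamic_cast.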

View File

@ -15,9 +15,9 @@ namespace torch {
namespace lazy {
template <typename T, typename... Args>
NodePtr ReuseNode(OpKind op, Args&&... args) {
NodePtr ReuseNode(Args&&... args) {
if (FLAGS_torch_lazy_reuse_ir) {
return LookupNodeFromTrieCache<T>(op, std::forward<Args>(args)...);
return LookupNodeFromTrieCache<T>(std::forward<Args>(args)...);
}
return nullptr;
}
@ -27,7 +27,7 @@ template <typename T, typename... Args>
NodePtr MakeNode(Args&&... args) {
NodePtr node = std::make_shared<T>(std::forward<Args>(args)...);
if (FLAGS_torch_lazy_reuse_ir) {
// If ir caching is enabled, we need to record all new nodes
TrieCache::Get()->Insert(node);
}
return node;
@ -35,8 +35,8 @@ NodePtr MakeNode(Args&&... args) {
// op is passed in for a more efficient node casting, see the implementation of NodeCast
template <typename T, typename... Args>
NodePtr ReuseOrMakeNode(OpKind op, Args&&... args) {
NodePtr node = ReuseNode<T>(op, std::forward<Args>(args)...);
NodePtr ReuseOrMakeNode(Args&&... args) {
NodePtr node = ReuseNode<T>(std::forward<Args>(args)...);
if (!node) {
node = MakeNode<T>(std::forward<Args>(args)...);
}

View File

@ -52,11 +52,11 @@ class TORCH_API TrieCache {
};
template <typename T, typename... Args>
NodePtr LookupNodeFromTrieCache(OpKind op, Args&&... args) {
NodePtr LookupNodeFromTrieCache(Args&&... args) {
auto& successors = TrieCache::Get()->Current()->successors;
for (auto it = successors.begin(); it != successors.end(); it++) {
NodePtr ir_node = (*it)->ir_node;
const T* concrete_node = NodeCast<T>(ir_node.get(), op);
const T* concrete_node = NodeCast<T>(ir_node.get());
if (concrete_node && concrete_node->Equal(std::forward<Args>(args)...)) {
TORCH_LAZY_COUNTER("IrNodeReused::" + std::string(typeid(T).name()), 1);
TrieCache::Get()->SetCurrent(it);
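
Note the contract this lookup relies on: the same argument pack is
forwarded both to each candidate's Equal() and, on a cache miss, to the
node constructor, so the two signatures must stay in step. A sketch
based on the TrieCacheNode test fixture above (the Equal body is an
assumption):

  class TrieCacheNode : public Node {
   public:
    static const OpKind class_op_kind;
    explicit TrieCacheNode(size_t id)
        : Node(class_op_kind, /* num_outputs */ 1), id_(id), hash_(Hash(id_)) {}
    // Must accept exactly the arguments that ReuseOrMakeNode forwards.
    bool Equal(size_t id) const { return id_ == id; }
    // ...
  };

  // With IR reuse enabled, ReuseOrMakeNode<TrieCacheNode>(7) first tries
  // LookupNodeFromTrieCache<TrieCacheNode>(7), which NodeCasts each
  // successor and calls Equal(7); on a miss it falls back to
  // MakeNode<TrieCacheNode>(7).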

View File

@ -4,6 +4,9 @@
namespace torch {
namespace lazy {
const OpKind TSNativeBatchNormBackward::class_op_kind(at::aten::native_batch_norm_backward);
const OpKind TSNativeBatchNormForward::class_op_kind(at::aten::native_batch_norm);
TSNativeBatchNormBackward::TSNativeBatchNormBackward(
const torch::lazy::Value& grad_out, const torch::lazy::Value& input,
const torch::lazy::Value& weight, const torch::lazy::Value& running_mean,

View File

@ -8,6 +8,8 @@ namespace lazy {
// Node for the backward batch norm operator.
class TSNativeBatchNormBackward : public torch::lazy::TsNode {
public:
static const OpKind class_op_kind;
TSNativeBatchNormBackward(const torch::lazy::Value& grad_out, const torch::lazy::Value& input,
const torch::lazy::Value& weight, const torch::lazy::Value& running_mean,
const torch::lazy::Value& running_var, const torch::lazy::Value& save_mean,
@ -35,6 +37,8 @@ class TSNativeBatchNormBackward : public torch::lazy::TsNode {
class TSNativeBatchNormForward : public torch::lazy::TsNode {
public:
static const OpKind class_op_kind;
TSNativeBatchNormForward(const torch::lazy::Value& input, const torch::lazy::Value& weight,
const torch::lazy::Value& bias, const torch::lazy::Value& running_mean,
const torch::lazy::Value& running_var, bool training,

View File

@ -15,6 +15,9 @@ Shape NodeOutputShape(const Value& input, c10::ScalarType type) {
}
} // namespace
const OpKind Cast::class_op_kind(ltc_cast);
Cast::Cast(
const Value& input,
at::ScalarType dtype,

View File

@ -9,6 +9,8 @@ namespace lazy {
class TORCH_API Cast : public TsNode {
public:
static const OpKind class_op_kind;
Cast(
const Value& input,
at::ScalarType dtype,

View File

@ -7,6 +7,8 @@
namespace torch {
namespace lazy {
const OpKind DeviceData::class_op_kind(ltc_device_data);
DeviceData::DeviceData(std::shared_ptr<BackendData> data)
: TsNode(
ltc_device_data,
@ -22,7 +24,7 @@ std::string DeviceData::ToString() const {
}
const DeviceData* DeviceData::Cast(const Node* node) {
return NodeCast<DeviceData>(node, ltc_device_data);
return NodeCast<DeviceData>(node);
}
} // namespace lazy

View File

@ -8,6 +8,8 @@ namespace lazy {
class TORCH_API DeviceData : public TsNode {
public:
static const OpKind class_op_kind;
explicit DeviceData(std::shared_ptr<BackendData> data);
std::string ToString() const override;

View File

@ -3,6 +3,8 @@
namespace torch {
namespace lazy {
const OpKind Expand::class_op_kind(at::aten::expand);
Expand::Expand(
const Value& input,
std::vector<int64_t> size,

View File

@ -9,6 +9,8 @@ namespace lazy {
class TORCH_API Expand : public TsNode {
public:
static const OpKind class_op_kind;
Expand(const Value& input, std::vector<int64_t> size, bool is_scalar_expand);
std::string ToString() const override;

View File

@ -4,6 +4,8 @@
namespace torch {
namespace lazy {
const OpKind Normal::class_op_kind(c10::Symbol::fromQualString("aten::normal_"));
Normal::Normal(const torch::lazy::Value& self, const double& mean, const double& std, std::vector<torch::lazy::Shape>&& shapes)
: torch::lazy::TsNode(torch::lazy::OpKind(c10::Symbol::fromQualString("aten::normal_")),
{self}, std::move(shapes),

View File

@ -7,6 +7,8 @@ namespace lazy {
class Normal : public torch::lazy::TsNode {
public:
static const OpKind class_op_kind;
Normal(const torch::lazy::Value& self, const double& mean, const double& std, std::vector<torch::lazy::Shape>&& shapes);
std::string ToString() const override;

View File

@ -10,6 +10,8 @@ namespace lazy {
using at::operator<<;
const OpKind Scalar::class_op_kind(at::prim::Constant);
Scalar::Scalar(const at::Scalar& value, Shape shape)
: TsNode(
OpKind(at::prim::Constant),

View File

@ -12,6 +12,8 @@ namespace lazy {
// computation graph.
class TORCH_API Scalar : public TsNode {
public:
static const OpKind class_op_kind;
Scalar(const at::Scalar& value, Shape shape);
Scalar(const at::Scalar& value, c10::ScalarType type);

View File

@ -0,0 +1,9 @@
#include <torch/csrc/lazy/ts_backend/ops/to_copy.h>
namespace torch {
namespace lazy {
const OpKind ToCopy::class_op_kind(at::aten::_to_copy);
} // namespace lazy
} // namespace torch

View File

@ -12,6 +12,8 @@ namespace lazy {
// the aten/eager fallback necessitating directly implementing the right to(device) behavior
class ToCopy : public torch::lazy::TsNode {
public:
static const OpKind class_op_kind;
ToCopy(const torch::lazy::Value& self, const c10::optional<at::ScalarType>& dtype, const c10::optional<at::Layout>& layout, const c10::optional<at::Device>& device, const c10::optional<bool>& pin_memory, const bool& non_blocking, const c10::optional<at::MemoryFormat>& memory_format, std::vector<torch::lazy::Shape>&& shapes)
: torch::lazy::TsNode(torch::lazy::OpKind(at::aten::_to_copy),
{self}, std::move(shapes),
@ -85,5 +87,6 @@ class ToCopy : public torch::lazy::TsNode {
bool non_blocking;
c10::optional<at::MemoryFormat> memory_format;
};
} // namespace lazy
} // namespace torch

View File

@ -31,7 +31,7 @@ TSLoweringContext::TSLoweringContext(
void TSLoweringContext::AssignOutputOp(
const Output& output,
torch::jit::Value* op) {
auto ts_node = NodeCast<TsNode>(output.node, output.node->op());
const TsNode* ts_node = static_cast<const TsNode*>(output.node);
std::string stack_trace = ts_node->getPythonStacktrace();
if (!stack_trace.empty()) {
op->node()->s_(c10::Symbol::attr("source"), stack_trace);

View File

@ -77,6 +77,8 @@ TSOpVector TsNode::Lower(std::shared_ptr<torch::jit::GraphFunction> function,
return {};
}
const OpKind TensorList::class_op_kind(tensor_list_opkind);
TensorList::TensorList(OpList values)
: TsNode(/*op=*/tensor_list_opkind,
/*operands=*/values,

View File

@ -64,6 +64,8 @@ const OpKind tensor_list_opkind = OpKind::Get("lazy_tensors::tensor_list");
// TODO(whc) once Shape() API is moved to Node base, also make it virtual, and then implement it as NotImplemented for
// TensorList, also fixing the assertion that would fail.
struct TORCH_API TensorList : public TsNode {
static const OpKind class_op_kind;
TensorList() = delete;
TensorList(OpList values);

View File

@ -55,85 +55,67 @@ class TSNodeLowering : public TSNodeLoweringInterface {
// classes
TSOpVector LowerNonCodegenOps(const torch::lazy::Node* node) {
if (node->op().op == at::aten::as_strided) {
return LowerAsStrided(torch::lazy::NodeCast<torch::lazy::AsStrided>(
node, torch::lazy::OpKind(at::aten::as_strided)));
return LowerAsStrided(torch::lazy::NodeCast<torch::lazy::AsStrided>(node));
}
if (node->op() == *torch::lazy::ltc_as_strided_view_update) {
return LowerAsStridedViewUpdate(
torch::lazy::NodeCast<torch::lazy::AsStridedViewUpdate>(
node, *torch::lazy::ltc_as_strided_view_update));
torch::lazy::NodeCast<torch::lazy::AsStridedViewUpdate>(node));
}
if (node->op() == *torch::lazy::ltc_cast) {
return LowerCast(torch::lazy::NodeCast<torch::lazy::Cast>(
node, *torch::lazy::ltc_cast));
return LowerCast(torch::lazy::NodeCast<torch::lazy::Cast>(node));
}
if (node->op() == *torch::lazy::ltc_select_view_update) {
return LowerSelectViewUpdate(
torch::lazy::NodeCast<torch::lazy::SelectViewUpdate>(
node, *torch::lazy::ltc_select_view_update));
torch::lazy::NodeCast<torch::lazy::SelectViewUpdate>(node));
}
if (node->op() == *torch::lazy::ltc_narrow_view_update) {
return LowerNarrowViewUpdate(
torch::lazy::NodeCast<torch::lazy::NarrowViewUpdate>(
node, *torch::lazy::ltc_narrow_view_update));
torch::lazy::NodeCast<torch::lazy::NarrowViewUpdate>(node));
}
if (node->op().op == at::prim::Constant) {
return LowerScalar(torch::lazy::NodeCast<torch::lazy::Scalar>(
node, torch::lazy::OpKind(at::prim::Constant)));
return LowerScalar(torch::lazy::NodeCast<torch::lazy::Scalar>(node));
}
if (node->op().op == at::aten::native_batch_norm) {
return LowerBatchNorm(
torch::lazy::NodeCast<TSNativeBatchNormForward>(
node, torch::lazy::OpKind(at::aten::native_batch_norm)));
torch::lazy::NodeCast<TSNativeBatchNormForward>(node));
}
if (node->op().op == at::aten::native_batch_norm_backward) {
return LowerBatchNormBackward(
torch::lazy::NodeCast<TSNativeBatchNormBackward>(
node, torch::lazy::OpKind(at::aten::native_batch_norm_backward)));
torch::lazy::NodeCast<TSNativeBatchNormBackward>(node));
}
if (node->op().op == at::aten::expand) {
return LowerExpand(
torch::lazy::NodeCast<torch::lazy::Expand>(
node, torch::lazy::OpKind(at::aten::expand)));
torch::lazy::NodeCast<torch::lazy::Expand>(node));
}
if (node->op().op == at::aten::narrow) {
return LowerNarrow(torch::lazy::NodeCast<torch::lazy::Narrow>(
node, torch::lazy::OpKind(at::aten::narrow)));
return LowerNarrow(torch::lazy::NodeCast<torch::lazy::Narrow>(node));
}
if (node->op().op == at::aten::permute) {
return LowerPermute(torch::lazy::NodeCast<torch::lazy::Permute>(
node, torch::lazy::OpKind(at::aten::permute)));
return LowerPermute(torch::lazy::NodeCast<torch::lazy::Permute>(node));
}
if (node->op().op == at::aten::select) {
return LowerSelect(torch::lazy::NodeCast<torch::lazy::Select>(
node, torch::lazy::OpKind(at::aten::select)));
return LowerSelect(torch::lazy::NodeCast<torch::lazy::Select>(node));
}
if (node->op().op == at::aten::squeeze) {
return LowerSqueeze(
torch::lazy::NodeCast<Squeeze>(
node, torch::lazy::OpKind(at::aten::squeeze)));
torch::lazy::NodeCast<Squeeze>(node));
}
if (node->op().op == at::aten::unsqueeze) {
return LowerUnsqueeze(
torch::lazy::NodeCast<Unsqueeze>(
node, torch::lazy::OpKind(at::aten::unsqueeze)));
torch::lazy::NodeCast<Unsqueeze>(node));
}
if (node->op().op == at::aten::view) {
return LowerView(torch::lazy::NodeCast<torch::lazy::View>(
node, torch::lazy::OpKind(at::aten::view)));
return LowerView(torch::lazy::NodeCast<torch::lazy::View>(node));
}
if (node->op().op == at::aten::diagonal) {
return LowerDiagonal(torch::lazy::NodeCast<torch::lazy::Diagonal>(
node, torch::lazy::OpKind(at::aten::diagonal)));
return LowerDiagonal(torch::lazy::NodeCast<torch::lazy::Diagonal>(node));
}
if (node->op() == *torch::lazy::ltc_diagonal_view_update) {
return LowerDiagonalViewUpdate(torch::lazy::NodeCast<torch::lazy::DiagonalViewUpdate>(
node, *torch::lazy::ltc_diagonal_view_update));
return LowerDiagonalViewUpdate(torch::lazy::NodeCast<torch::lazy::DiagonalViewUpdate>(node));
}
if (node->op() == *torch::lazy::ltc_device_data) {
const torch::lazy::DeviceData* device_data_node =
torch::lazy::NodeCast<torch::lazy::DeviceData>(
node, *torch::lazy::ltc_device_data);
torch::lazy::NodeCast<torch::lazy::DeviceData>(node);
auto infoptr = device_data_node->data()->info();
auto deviceDataInfoPtr = (torch::lazy::LazyGraphExecutor::DeviceDataInfo*) infoptr;
if (GRAPH_DUMP_ENABLED) {

View File

@ -8,6 +8,8 @@
namespace torch {
namespace lazy {
const OpKind AsStrided::class_op_kind(at::aten::as_strided);
AsStrided::AsStrided(
const Value& input,
std::vector<int64_t> size,

View File

@ -9,6 +9,8 @@ namespace lazy {
class TORCH_API AsStrided : public TsNode {
public:
static const OpKind class_op_kind;
AsStrided(
const Value& input,
std::vector<int64_t> size,

View File

@ -7,6 +7,8 @@
namespace torch {
namespace lazy {
const OpKind AsStridedViewUpdate::class_op_kind(ltc_as_strided_view_update);
AsStridedViewUpdate::AsStridedViewUpdate(
const Value& target,
const Value& input,

View File

@ -9,6 +9,8 @@ namespace lazy {
class TORCH_API AsStridedViewUpdate : public TsNode {
public:
static const OpKind class_op_kind;
AsStridedViewUpdate(
const Value& target,
const Value& input,

View File

@ -7,6 +7,8 @@
namespace torch {
namespace lazy {
const OpKind Diagonal::class_op_kind(at::aten::diagonal);
Diagonal::Diagonal(
const Value& input,
int64_t offset,

View File

@ -7,6 +7,8 @@ namespace lazy {
class TORCH_API Diagonal : public TsNode {
public:
static const OpKind class_op_kind;
Diagonal(const Value& input, int64_t offset, int64_t dim1, int64_t dim2);
std::string ToString() const override;

View File

@ -5,6 +5,8 @@
namespace torch {
namespace lazy {
const OpKind DiagonalViewUpdate::class_op_kind(ltc_diagonal_view_update);
DiagonalViewUpdate::DiagonalViewUpdate(
const Value& target,
const Value& input,

View File

@ -7,6 +7,8 @@ namespace lazy {
class TORCH_API DiagonalViewUpdate : public TsNode {
public:
static const OpKind class_op_kind;
DiagonalViewUpdate(
const Value& target,
const Value& input,

View File

@ -5,6 +5,8 @@
namespace torch {
namespace lazy {
const OpKind Narrow::class_op_kind(at::aten::narrow);
Narrow::Narrow(
const Value& input,
c10::ArrayRef<int64_t> base_indices,

View File

@ -7,6 +7,8 @@ namespace lazy {
class TORCH_API Narrow : public TsNode {
public:
static const OpKind class_op_kind;
Narrow(
const Value& input,
c10::ArrayRef<int64_t> base_indices,

View File

@ -5,6 +5,8 @@
namespace torch {
namespace lazy {
const OpKind NarrowViewUpdate::class_op_kind(ltc_narrow_view_update);
NarrowViewUpdate::NarrowViewUpdate(
const Value& input,
const Value& source,

View File

@ -7,6 +7,8 @@ namespace lazy {
class TORCH_API NarrowViewUpdate : public TsNode {
public:
static const OpKind class_op_kind;
NarrowViewUpdate(
const Value& input,
const Value& source,

View File

@ -6,6 +6,8 @@
namespace torch {
namespace lazy {
const OpKind Permute::class_op_kind(at::aten::permute);
Permute::Permute(const Value& input, std::vector<int64_t> dims)
: TsNode(
OpKind(at::aten::permute),

View File

@ -7,6 +7,8 @@ namespace lazy {
class TORCH_API Permute : public TsNode {
public:
static const OpKind class_op_kind;
Permute(const Value& input, std::vector<int64_t> dims);
std::string ToString() const override;

View File

@ -4,13 +4,14 @@ namespace torch {
namespace lazy {
namespace {
Shape NodeOutputShape(const Value& input, c10::ArrayRef<int64_t> size) {
return Shape(input.shape().scalar_type(), size);
}
} // namespace
const OpKind Resize::class_op_kind(at::aten::resize);
Resize::Resize(const Value& input, std::vector<int64_t> size)
: TsNode(
OpKind(at::aten::resize),

View File

@ -7,6 +7,8 @@ namespace lazy {
class TORCH_API Resize : public TsNode {
public:
static const OpKind class_op_kind;
Resize(const Value& input, std::vector<int64_t> size);
std::string ToString() const override;

View File

@ -6,6 +6,8 @@
namespace torch {
namespace lazy {
const OpKind Select::class_op_kind(at::aten::select);
Select::Select(
const Value& input,
int64_t dim,

View File

@ -7,6 +7,8 @@ namespace lazy {
class TORCH_API Select : public TsNode {
public:
static const OpKind class_op_kind;
Select(
const Value& input,
int64_t dim,

View File

@ -7,6 +7,8 @@
namespace torch {
namespace lazy {
const OpKind SelectViewUpdate::class_op_kind(ltc_select_view_update);
SelectViewUpdate::SelectViewUpdate(
const Value& target,
const Value& source,

View File

@ -7,6 +7,8 @@ namespace lazy {
class TORCH_API SelectViewUpdate : public TsNode {
public:
static const OpKind class_op_kind;
SelectViewUpdate(
const Value& target,
const Value& source,

View File

@ -6,6 +6,8 @@
namespace torch {
namespace lazy {
const OpKind Squeeze::class_op_kind(at::aten::squeeze);
Squeeze::Squeeze(const torch::lazy::Value& input, int dim)
: torch::lazy::TsNode(torch::lazy::OpKind(at::aten::squeeze), {input},
/*num_outputs=*/1, torch::lazy::MHash(dim)),

View File

@ -7,6 +7,8 @@ namespace lazy {
class TORCH_API Squeeze : public TsNode {
public:
static const OpKind class_op_kind;
// Squeeze out the specified dimension index, -1 for all trivial dimensions.
Squeeze(const torch::lazy::Value& input, int dim);

View File

@ -5,6 +5,8 @@
namespace torch {
namespace lazy {
const OpKind Unsqueeze::class_op_kind(at::aten::unsqueeze);
Unsqueeze::Unsqueeze(const torch::lazy::Value& input, int dim)
: torch::lazy::TsNode(
torch::lazy::OpKind(at::aten::unsqueeze),

View File

@ -7,6 +7,8 @@ namespace lazy {
class TORCH_API Unsqueeze : public TsNode {
public:
static const OpKind class_op_kind;
Unsqueeze(const torch::lazy::Value& input, int dim);
std::string ToString() const override;

View File

@ -6,7 +6,6 @@ namespace torch {
namespace lazy {
namespace {
Shape NodeOutputShape(const Value& input, c10::ArrayRef<int64_t> output_sizes) {
const Shape& input_shape = input.shape();
const auto complete_output_sizes =
@ -16,6 +15,8 @@ Shape NodeOutputShape(const Value& input, c10::ArrayRef<int64_t> output_sizes) {
} // namespace
const OpKind View::class_op_kind(at::aten::view);
View::View(const Value& input, std::vector<int64_t> output_size)
: TsNode(
OpKind(at::aten::view),

View File

@ -9,6 +9,8 @@ namespace lazy {
class TORCH_API View : public TsNode {
public:
static const OpKind class_op_kind;
View(const Value& input, std::vector<int64_t> output_size);
std::string ToString() const override;

View File

@ -119,6 +119,16 @@ class GenLazyIR(ABC):
func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
return self.gen(f)
@method_with_native_function
def gen_opkind_definition(
self, f: Union[NativeFunctionsGroup, NativeFunction]
) -> List[str]:
func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
schema = LazyIrSchema(func)
return [
f"const OpKind {schema.node_name}::class_op_kind{{{aten_symbol(schema)}}};"
]
# there is no lowering functionality generated unless this IR base class is subclassed and
# implemented as a backend-specific node
def lowering_function(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> str:
@ -202,6 +212,8 @@ class GenLazyIR(ABC):
f"""\
class {schema.node_name} : public {self.node_base} {{
public:
static const OpKind class_op_kind;
{schema.node_name}({node_ctor_args}, std::vector<Shape>&& shapes)
: {self.node_base_ctor_call(schema)}{comma_if_scalar_initializers}
{scalar_initializers}
@ -221,7 +233,6 @@ class {schema.node_name} : public {self.node_base} {{
{scalar_decls}
{has_optional_decls}
}};
""",
@ -394,6 +405,7 @@ class GenLazyNativeFuncDefinition:
schema = LazyIrSchema(func.func)
return [
f"""\
{sig.decl(name=f"{self.class_method_name}::{metadata.kernel}")} {{
{self.force_eager_fallback(func, schema)}
{self.metrics(func, schema)}

View File

@ -524,6 +524,29 @@ def run_gen_lazy_tensor(
"namespace_epilogue": ns_helper.epilogue,
},
)
# Generate OpKind definitions for IR node classes
fm.write_with_template(
"LazyIr.cpp",
"LazyIr.cpp",
lambda: {
"includes": [
f"#include <{path}>"
for path in [
f"{output_dir}/LazyIr.h",
]
],
"opkind_definitions": list(
concat_map_codegen(
lazy_ir_generator(
backend_indices[backend_key], node_base
).gen_opkind_definition,
grouped_native_functions,
)
),
"namespace_prologue": ns_helper.prologue,
"namespace_epilogue": ns_helper.epilogue,
},
)
if __name__ == "__main__":
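
Put together, the template and generator above reduce the emitted
LazyIr.cpp to one definition per codegen'd node class. A hypothetical
excerpt of the generated file (node names illustrative):

  // generated LazyIr.cpp (illustrative excerpt)
  #include <torch/csrc/lazy/generated/LazyIr.h>

  namespace torch {
  namespace lazy {

  const OpKind Abs::class_op_kind{at::aten::abs};
  const OpKind Amax::class_op_kind{at::aten::amax};
  // ... one line per generated IR node class.

  } // namespace lazy
  } // namespace torch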