[LT] Store OpKind for each IR subclass in a static field
Summary: Currently OpKind is stored as an object field called op_ for each IR node, and one use of op_ is to avoid a dynamic_cast in NodeCast when we need to downcast a base-node pointer into a concrete sub-node pointer. As a result, we need to construct and pass in an op whenever we downcast a node, and this becomes quite annoying once we start to implement trie-based IR node reuse. More importantly, the op for each subclass should be unique to that subclass, so making it a const static field is the more logical design. In this PR we still keep the object-level op_ for easier XLA adoption. As future work, we can come back to remove op_, make the op() method virtual, and get rid of OpKind in all the node constructors.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76711
Approved by: https://github.com/wconstab, https://github.com/JackCaoG
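For orientation, the core mechanical change can be illustrated with a small self-contained toy (plain C++; OpKind, Node and NodeCast here are simplified stand-ins, not the real torch::lazy types): each subclass publishes its kind as a static member, so the cast helper can read it from the type instead of taking it as an argument.

#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Toy stand-ins for the lazy IR types; the real ones live in torch/csrc/lazy.
struct OpKind {
  std::string name;
  bool operator!=(const OpKind& other) const { return name != other.name; }
};

struct Node {
  explicit Node(OpKind op) : op_(std::move(op)) {}
  virtual ~Node() = default;
  const OpKind& op() const { return op_; }
  OpKind op_;
};

struct ViewNode : Node {
  static const OpKind class_op_kind;  // one kind per subclass, not per object
  ViewNode() : Node(class_op_kind) {}
};
const OpKind ViewNode::class_op_kind{"aten::view"};

// New-style cast: the kind comes from T itself, so callers no longer pass it.
template <typename T>
const T* NodeCast(const Node* node) {
  if (T::class_op_kind != node->op()) {
    return nullptr;
  }
  return static_cast<const T*>(node);
}

int main() {
  std::unique_ptr<Node> n = std::make_unique<ViewNode>();
  assert(NodeCast<ViewNode>(n.get()) != nullptr);
  return 0;
}

In the real code base the same check is performed by the new NodeCast overload added to torch/csrc/lazy/core/ir.h in the diff below, with a dynamic_cast retained in debug builds as a safety net.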
This commit is contained in:
parent 8b6a78f39f
commit ac37ddc795
@@ -123,6 +123,7 @@ libtorch_cpp_generated_sources = [
"torch/csrc/autograd/generated/Functions.cpp",
"torch/csrc/autograd/generated/variable_factories.h",
"torch/csrc/lazy/generated/LazyIr.h",
"torch/csrc/lazy/generated/LazyIr.cpp",
"torch/csrc/lazy/generated/LazyNativeFunctions.h",
"torch/csrc/lazy/generated/LazyNativeFunctions.cpp",
"torch/csrc/lazy/generated/RegisterAutogradLazy.cpp",

@@ -1914,6 +1915,7 @@ test_suite(
for path in [
"aten/src/ATen/templates/DispatchKeyNativeFunctions.cpp",
"aten/src/ATen/templates/DispatchKeyNativeFunctions.h",
"aten/src/ATen/templates/LazyIr.cpp",
"aten/src/ATen/templates/LazyIr.h",
"aten/src/ATen/templates/RegisterDispatchKey.cpp",
"aten/src/ATen/native/native_functions.yaml",

aten/src/ATen/templates/LazyIr.cpp (new file, 8 lines)
@@ -0,0 +1,8 @@
// ${generated_comment}
${includes}

${namespace_prologue}

${opkind_definitions}

${namespace_epilogue}

@@ -27,6 +27,7 @@ def define_targets(rules):
srcs = [
":DispatchKeyNativeFunctions.cpp",
":DispatchKeyNativeFunctions.h",
":LazyIr.cpp",
":LazyIr.h",
":RegisterDispatchKey.cpp",
":native_functions.yaml",

@@ -111,6 +112,7 @@ _GENERATED_CPP = [
"torch/csrc/autograd/generated/python_torch_functions_1.cpp",
"torch/csrc/autograd/generated/python_torch_functions_2.cpp",
"torch/csrc/autograd/generated/python_variable_methods.cpp",
"torch/csrc/lazy/generated/LazyIr.cpp",
"torch/csrc/lazy/generated/LazyNativeFunctions.cpp",
"torch/csrc/lazy/generated/RegisterAutogradLazy.cpp",
"torch/csrc/lazy/generated/RegisterLazy.cpp",

@@ -352,6 +352,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
)
if(BUILD_LAZY_TS_BACKEND)
list(APPEND GENERATED_CXX_TORCH
"${TORCH_SRC_DIR}/csrc/lazy/generated/LazyIr.cpp"
"${TORCH_SRC_DIR}/csrc/lazy/generated/LazyNativeFunctions.cpp"
"${TORCH_SRC_DIR}/csrc/lazy/generated/RegisterAutogradLazy.cpp"
"${TORCH_SRC_DIR}/csrc/lazy/generated/RegisterLazy.cpp"

@@ -432,6 +433,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
"${TORCH_ROOT}/aten/src/ATen/templates/DispatchKeyNativeFunctions.h"
"${TORCH_ROOT}/aten/src/ATen/templates/DispatchKeyNativeFunctions.cpp"
"${TORCH_ROOT}/aten/src/ATen/templates/LazyIr.h"
"${TORCH_ROOT}/aten/src/ATen/templates/LazyIr.cpp"
"${TORCH_ROOT}/aten/src/ATen/templates/RegisterDispatchKey.cpp"
"${TOOLS_PATH}/autograd/templates/VariableType.h"
"${TOOLS_PATH}/autograd/templates/VariableType.cpp"
@@ -17,6 +17,8 @@ namespace lazy {
class TestLeafNode : public Node {
public:
static const OpKind class_op_kind;

explicit TestLeafNode(size_t param)
: Node(OpKind(), /* num_outputs */ 1),
hash_(Hash(param)),

@@ -38,6 +40,8 @@ class TestLeafNode : public Node {
size_t param_;
};

const OpKind TestLeafNode::class_op_kind = OpKind();

TEST(IrTest, BasicTest) {
NodePtr node1 = MakeNode<TestLeafNode>(1);
NodePtr node2 = MakeNode<TestLeafNode>(2);

@@ -45,7 +49,7 @@ TEST(IrTest, BasicTest) {
EXPECT_EQ(node1->num_outputs(), 1);
const TestLeafNode* leafptr = NodeCast<TestLeafNode>(node1.get(), OpKind());
const TestLeafNode* leafptr = NodeCast<TestLeafNode>(node1.get());
EXPECT_TRUE(leafptr != nullptr);
}

@@ -102,7 +106,7 @@ TEST(IrTest, TsNodeTest) {
EXPECT_EQ(node1->num_outputs(), 1);
const TsNode* leafptr = NodeCast<TsNode>(node1.get(), OpKind(at::aten::view));
const TsNode* leafptr = dynamic_cast<const TsNode*>(node1.get());
EXPECT_TRUE(leafptr != nullptr);
}

@@ -13,8 +13,10 @@ namespace lazy {
class TrieCacheNode : public Node {
public:
static const OpKind class_op_kind;

explicit TrieCacheNode(size_t id)
: Node(OpKind(), /* num_outputs */ 1), id_(id), hash_(Hash(id_)) {}
: Node(class_op_kind, /* num_outputs */ 1), id_(id), hash_(Hash(id_)) {}
~TrieCacheNode() override = default;
bool Equal(size_t id) const {

@@ -36,6 +38,8 @@ class TrieCacheNode : public Node {
hash_t hash_;
};

const OpKind TrieCacheNode::class_op_kind = OpKind();

TEST(TrieCacheTest, TestSinglePath) {
FLAGS_torch_lazy_reuse_ir = true;
TrieCache::Get()->Clear();

@@ -45,9 +49,9 @@ TEST(TrieCacheTest, TestSinglePath) {
NodePtr c = MakeNode<TrieCacheNode>(2);
TrieCache::Get()->ResetCurrent(); // MarkStep
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 0).get(), a.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 1).get(), b.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 2).get(), c.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(2).get(), c.get());
TrieCache::Get()->ResetCurrent(); // MarkStep
}

@@ -67,20 +71,20 @@ TEST(TrieCacheTest, TestTwoPaths) {
NodePtr c = MakeNode<TrieCacheNode>(2);
TrieCache::Get()->ResetCurrent(); // MarkStep
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 0).get(), a.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 1).get(), b.get());
NodePtr d = ReuseOrMakeNode<TrieCacheNode>(OpKind(), 3);
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get());
NodePtr d = ReuseOrMakeNode<TrieCacheNode>(3);
EXPECT_NE(d.get(), c.get());
TrieCache::Get()->ResetCurrent(); // MarkStep
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 0).get(), a.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 1).get(), b.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 3).get(), d.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(3).get(), d.get());
TrieCache::Get()->ResetCurrent(); // MarkStep
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 0).get(), a.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 1).get(), b.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 2).get(), c.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(2).get(), c.get());
TrieCache::Get()->ResetCurrent(); // MarkStep
}

@@ -11,6 +11,7 @@
# This is duplicated in caffe2/CMakeLists.txt for now and not yet used in buck
GENERATED_LAZY_TS_CPP = [
"lazy/generated/LazyIr.cpp",
"lazy/generated/LazyNativeFunctions.cpp",
"lazy/generated/RegisterAutogradLazy.cpp",
"lazy/generated/RegisterLazy.cpp",

@@ -425,6 +426,7 @@ lazy_tensor_ts_sources = [
"torch/csrc/lazy/ts_backend/ops/expand.cpp",
"torch/csrc/lazy/ts_backend/ops/generic.cpp",
"torch/csrc/lazy/ts_backend/ops/scalar.cpp",
"torch/csrc/lazy/ts_backend/ops/to_copy.cpp",
"torch/csrc/lazy/ts_backend/view_ops/as_strided.cpp",
"torch/csrc/lazy/ts_backend/view_ops/as_strided_view_update.cpp",
"torch/csrc/lazy/ts_backend/view_ops/diagonal.cpp",
@@ -175,6 +175,8 @@ inline std::ostream& operator<<(std::ostream& stream, const Node& node) {
return stream;
}

// Note: Keep this version of NodeCast for smooth PyTorch/XLA migration, and
// clean up once the migration is done.
template <typename T>
const T* NodeCast(const Node* node, OpKind op) {
if (op != node->op()) {

@@ -187,6 +189,18 @@ const T* NodeCast(const Node* node, OpKind op) {
#endif
}

template <typename T>
const T* NodeCast(const Node* node) {
if (T::class_op_kind != node->op()) {
return nullptr;
}
#ifdef NDEBUG
return static_cast<const T*>(node);
#else
return &dynamic_cast<const T&>(*node);
#endif
}

// Represents a specific output produced by a node. Since the output of a node
// can be composed by multiple outputs, the node+index coordinates fully qualify

@@ -15,9 +15,9 @@ namespace torch {
namespace lazy {

template <typename T, typename... Args>
NodePtr ReuseNode(OpKind op, Args&&... args) {
NodePtr ReuseNode(Args&&... args) {
if (FLAGS_torch_lazy_reuse_ir) {
return LookupNodeFromTrieCache<T>(op, std::forward<Args>(args)...);
return LookupNodeFromTrieCache<T>(std::forward<Args>(args)...);
}
return nullptr;
}

@@ -27,7 +27,7 @@ template <typename T, typename... Args>
NodePtr MakeNode(Args&&... args) {
NodePtr node = std::make_shared<T>(std::forward<Args>(args)...);
if (FLAGS_torch_lazy_reuse_ir) {
// If ir caching is enabled, we need to record all new nodes
// If ir caching is enabled, we need to record all new nodes
TrieCache::Get()->Insert(node);
}
return node;

@@ -35,8 +35,8 @@ NodePtr MakeNode(Args&&... args) {
// op is passed in for a more efficient node casting, see the implementation of NodeCast
template <typename T, typename... Args>
NodePtr ReuseOrMakeNode(OpKind op, Args&&... args) {
NodePtr node = ReuseNode<T>(op, std::forward<Args>(args)...);
NodePtr ReuseOrMakeNode(Args&&... args) {
NodePtr node = ReuseNode<T>(std::forward<Args>(args)...);
if (!node) {
node = MakeNode<T>(std::forward<Args>(args)...);
}

@@ -52,11 +52,11 @@ class TORCH_API TrieCache {
};

template <typename T, typename... Args>
NodePtr LookupNodeFromTrieCache(OpKind op, Args&&... args) {
NodePtr LookupNodeFromTrieCache(Args&&... args) {
auto& successors = TrieCache::Get()->Current()->successors;
for (auto it = successors.begin(); it != successors.end(); it++) {
NodePtr ir_node = (*it)->ir_node;
const T* concrete_node = NodeCast<T>(ir_node.get(), op);
const T* concrete_node = NodeCast<T>(ir_node.get());
if (concrete_node && concrete_node->Equal(std::forward<Args>(args)...)) {
TORCH_LAZY_COUNTER("IrNodeReused::" + std::string(typeid(T).name()), 1);
TrieCache::Get()->SetCurrent(it);
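The same property is what simplifies the trie-based reuse path: since the kind is recoverable from the type, a lookup helper can filter cached nodes by T::class_op_kind and then defer to the subclass's Equal() predicate, so callers pass only constructor arguments. A self-contained toy sketch of that flow (again simplified types; CachedLookupOrMake is a hypothetical helper standing in for ReuseOrMakeNode/LookupNodeFromTrieCache above):

#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy types; not the torch::lazy API.
struct OpKind {
  std::string name;
  bool operator!=(const OpKind& other) const { return name != other.name; }
};

struct Node {
  explicit Node(OpKind op) : op_(std::move(op)) {}
  virtual ~Node() = default;
  const OpKind& op() const { return op_; }
  OpKind op_;
};
using NodePtr = std::shared_ptr<Node>;

struct AddNode : Node {
  static const OpKind class_op_kind;  // supplied once per subclass
  explicit AddNode(int operand) : Node(class_op_kind), operand(operand) {}
  bool Equal(int other) const { return operand == other; }  // reuse predicate
  int operand;
};
const OpKind AddNode::class_op_kind{"toy::add"};

// Hypothetical reuse helper: no OpKind parameter, the static field is used.
template <typename T, typename... Args>
NodePtr CachedLookupOrMake(std::vector<NodePtr>& cache, Args&&... args) {
  for (const NodePtr& candidate : cache) {
    if (T::class_op_kind != candidate->op()) {
      continue;  // cheap kind filter before the concrete comparison
    }
    const T* concrete = static_cast<const T*>(candidate.get());
    if (concrete->Equal(args...)) {
      return candidate;  // reuse the cached node
    }
  }
  NodePtr fresh = std::make_shared<T>(std::forward<Args>(args)...);
  cache.push_back(fresh);
  return fresh;
}

int main() {
  std::vector<NodePtr> cache;
  NodePtr a = CachedLookupOrMake<AddNode>(cache, 1);
  NodePtr b = CachedLookupOrMake<AddNode>(cache, 1);  // same args: reused
  return a == b ? 0 : 1;
}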
@@ -4,6 +4,9 @@
namespace torch {
namespace lazy {

const OpKind TSNativeBatchNormBackward::class_op_kind(at::aten::native_batch_norm_backward);
const OpKind TSNativeBatchNormForward::class_op_kind(at::aten::native_batch_norm);

TSNativeBatchNormBackward::TSNativeBatchNormBackward(
const torch::lazy::Value& grad_out, const torch::lazy::Value& input,
const torch::lazy::Value& weight, const torch::lazy::Value& running_mean,

@@ -8,6 +8,8 @@ namespace lazy {
// Node for the backward batch norm operator.
class TSNativeBatchNormBackward : public torch::lazy::TsNode {
public:
static const OpKind class_op_kind;

TSNativeBatchNormBackward(const torch::lazy::Value& grad_out, const torch::lazy::Value& input,
const torch::lazy::Value& weight, const torch::lazy::Value& running_mean,
const torch::lazy::Value& running_var, const torch::lazy::Value& save_mean,

@@ -35,6 +37,8 @@ class TSNativeBatchNormBackward : public torch::lazy::TsNode {
class TSNativeBatchNormForward : public torch::lazy::TsNode {
public:
static const OpKind class_op_kind;

TSNativeBatchNormForward(const torch::lazy::Value& input, const torch::lazy::Value& weight,
const torch::lazy::Value& bias, const torch::lazy::Value& running_mean,
const torch::lazy::Value& running_var, bool training,

@@ -15,6 +15,9 @@ Shape NodeOutputShape(const Value& input, c10::ScalarType type) {
}

} // namespace

const OpKind Cast::class_op_kind(ltc_cast);

Cast::Cast(
const Value& input,
at::ScalarType dtype,

@@ -9,6 +9,8 @@ namespace lazy {
class TORCH_API Cast : public TsNode {
public:
static const OpKind class_op_kind;

Cast(
const Value& input,
at::ScalarType dtype,

@@ -7,6 +7,8 @@
namespace torch {
namespace lazy {

const OpKind DeviceData::class_op_kind(ltc_device_data);

DeviceData::DeviceData(std::shared_ptr<BackendData> data)
: TsNode(
ltc_device_data,

@@ -22,7 +24,7 @@ std::string DeviceData::ToString() const {
}

const DeviceData* DeviceData::Cast(const Node* node) {
return NodeCast<DeviceData>(node, ltc_device_data);
return NodeCast<DeviceData>(node);
}

} // namespace lazy

@@ -8,6 +8,8 @@ namespace lazy {
class TORCH_API DeviceData : public TsNode {
public:
static const OpKind class_op_kind;

explicit DeviceData(std::shared_ptr<BackendData> data);

std::string ToString() const override;

@@ -3,6 +3,8 @@
namespace torch {
namespace lazy {

const OpKind Expand::class_op_kind(at::aten::expand);

Expand::Expand(
const Value& input,
std::vector<int64_t> size,

@@ -9,6 +9,8 @@ namespace lazy {
class TORCH_API Expand : public TsNode {
public:
static const OpKind class_op_kind;

Expand(const Value& input, std::vector<int64_t> size, bool is_scalar_expand);

std::string ToString() const override;

@@ -4,6 +4,8 @@
namespace torch {
namespace lazy {

const OpKind Normal::class_op_kind(c10::Symbol::fromQualString("aten::normal_"));

Normal::Normal(const torch::lazy::Value& self, const double& mean, const double& std, std::vector<torch::lazy::Shape>&& shapes)
: torch::lazy::TsNode(torch::lazy::OpKind(c10::Symbol::fromQualString("aten::normal_")),
{self}, std::move(shapes),

@@ -7,6 +7,8 @@ namespace lazy {
class Normal : public torch::lazy::TsNode {
public:
static const OpKind class_op_kind;

Normal(const torch::lazy::Value& self, const double& mean, const double& std, std::vector<torch::lazy::Shape>&& shapes);

std::string ToString() const override;

@@ -10,6 +10,8 @@ namespace lazy {
using at::operator<<;

const OpKind Scalar::class_op_kind(at::prim::Constant);

Scalar::Scalar(const at::Scalar& value, Shape shape)
: TsNode(
OpKind(at::prim::Constant),

@@ -12,6 +12,8 @@ namespace lazy {
// computation graph.
class TORCH_API Scalar : public TsNode {
public:
static const OpKind class_op_kind;

Scalar(const at::Scalar& value, Shape shape);
Scalar(const at::Scalar& value, c10::ScalarType type);

torch/csrc/lazy/ts_backend/ops/to_copy.cpp (new file, 9 lines)
@@ -0,0 +1,9 @@
#include <torch/csrc/lazy/ts_backend/ops/to_copy.h>

namespace torch {
namespace lazy {

const OpKind ToCopy::class_op_kind(at::aten::_to_copy);

} // namespace lazy
} // namespace torch

@@ -12,6 +12,8 @@ namespace lazy {
// the aten/eager fallback necessitating directly implementing the right to(device) behavior
class ToCopy : public torch::lazy::TsNode {
public:
static const OpKind class_op_kind;

ToCopy(const torch::lazy::Value& self, const c10::optional<at::ScalarType>& dtype, const c10::optional<at::Layout>& layout, const c10::optional<at::Device>& device, const c10::optional<bool>& pin_memory, const bool& non_blocking, const c10::optional<at::MemoryFormat>& memory_format, std::vector<torch::lazy::Shape>&& shapes)
: torch::lazy::TsNode(torch::lazy::OpKind(at::aten::_to_copy),
{self}, std::move(shapes),

@@ -85,5 +87,6 @@ class ToCopy : public torch::lazy::TsNode {
bool non_blocking;
c10::optional<at::MemoryFormat> memory_format;
};

} // namespace lazy
} // namespace torch
@@ -31,7 +31,7 @@ TSLoweringContext::TSLoweringContext(
void TSLoweringContext::AssignOutputOp(
const Output& output,
torch::jit::Value* op) {
auto ts_node = NodeCast<TsNode>(output.node, output.node->op());
const TsNode* ts_node = static_cast<const TsNode*>(output.node);
std::string stack_trace = ts_node->getPythonStacktrace();
if (!stack_trace.empty()) {
op->node()->s_(c10::Symbol::attr("source"), stack_trace);

@@ -77,6 +77,8 @@ TSOpVector TsNode::Lower(std::shared_ptr<torch::jit::GraphFunction> function,
return {};
}

const OpKind TensorList::class_op_kind(tensor_list_opkind);

TensorList::TensorList(OpList values)
: TsNode(/*op=*/tensor_list_opkind,
/*operands=*/values,

@@ -64,6 +64,8 @@ const OpKind tensor_list_opkind = OpKind::Get("lazy_tensors::tensor_list");
// TODO(whc) once Shape() API is moved to Node base, also make it virtual, and then implement it as NotImplemented for
// TensorList, also fixing the assertion that would fail.
struct TORCH_API TensorList : public TsNode {
static const OpKind class_op_kind;

TensorList() = delete;
TensorList(OpList values);

@@ -55,85 +55,67 @@ class TSNodeLowering : public TSNodeLoweringInterface {
// classes
TSOpVector LowerNonCodegenOps(const torch::lazy::Node* node) {
if (node->op().op == at::aten::as_strided) {
return LowerAsStrided(torch::lazy::NodeCast<torch::lazy::AsStrided>(
node, torch::lazy::OpKind(at::aten::as_strided)));
return LowerAsStrided(torch::lazy::NodeCast<torch::lazy::AsStrided>(node));
}
if (node->op() == *torch::lazy::ltc_as_strided_view_update) {
return LowerAsStridedViewUpdate(
torch::lazy::NodeCast<torch::lazy::AsStridedViewUpdate>(
node, *torch::lazy::ltc_as_strided_view_update));
torch::lazy::NodeCast<torch::lazy::AsStridedViewUpdate>(node));
}
if (node->op() == *torch::lazy::ltc_cast) {
return LowerCast(torch::lazy::NodeCast<torch::lazy::Cast>(
node, *torch::lazy::ltc_cast));
return LowerCast(torch::lazy::NodeCast<torch::lazy::Cast>(node));
}
if (node->op() == *torch::lazy::ltc_select_view_update) {
return LowerSelectViewUpdate(
torch::lazy::NodeCast<torch::lazy::SelectViewUpdate>(
node, *torch::lazy::ltc_select_view_update));
torch::lazy::NodeCast<torch::lazy::SelectViewUpdate>(node));
}
if (node->op() == *torch::lazy::ltc_narrow_view_update) {
return LowerNarrowViewUpdate(
torch::lazy::NodeCast<torch::lazy::NarrowViewUpdate>(
node, *torch::lazy::ltc_narrow_view_update));
torch::lazy::NodeCast<torch::lazy::NarrowViewUpdate>(node));
}
if (node->op().op == at::prim::Constant) {
return LowerScalar(torch::lazy::NodeCast<torch::lazy::Scalar>(
node, torch::lazy::OpKind(at::prim::Constant)));
return LowerScalar(torch::lazy::NodeCast<torch::lazy::Scalar>(node));
}
if (node->op().op == at::aten::native_batch_norm) {
return LowerBatchNorm(
torch::lazy::NodeCast<TSNativeBatchNormForward>(
node, torch::lazy::OpKind(at::aten::native_batch_norm)));
torch::lazy::NodeCast<TSNativeBatchNormForward>(node));
}
if (node->op().op == at::aten::native_batch_norm_backward) {
return LowerBatchNormBackward(
torch::lazy::NodeCast<TSNativeBatchNormBackward>(
node, torch::lazy::OpKind(at::aten::native_batch_norm_backward)));
torch::lazy::NodeCast<TSNativeBatchNormBackward>(node));
}
if (node->op().op == at::aten::expand) {
return LowerExpand(
torch::lazy::NodeCast<torch::lazy::Expand>(
node, torch::lazy::OpKind(at::aten::expand)));
torch::lazy::NodeCast<torch::lazy::Expand>(node));
}
if (node->op().op == at::aten::narrow) {
return LowerNarrow(torch::lazy::NodeCast<torch::lazy::Narrow>(
node, torch::lazy::OpKind(at::aten::narrow)));
return LowerNarrow(torch::lazy::NodeCast<torch::lazy::Narrow>(node));
}
if (node->op().op == at::aten::permute) {
return LowerPermute(torch::lazy::NodeCast<torch::lazy::Permute>(
node, torch::lazy::OpKind(at::aten::permute)));
return LowerPermute(torch::lazy::NodeCast<torch::lazy::Permute>(node));
}
if (node->op().op == at::aten::select) {
return LowerSelect(torch::lazy::NodeCast<torch::lazy::Select>(
node, torch::lazy::OpKind(at::aten::select)));
return LowerSelect(torch::lazy::NodeCast<torch::lazy::Select>(node));
}
if (node->op().op == at::aten::squeeze) {
return LowerSqueeze(
torch::lazy::NodeCast<Squeeze>(
node, torch::lazy::OpKind(at::aten::squeeze)));
torch::lazy::NodeCast<Squeeze>(node));
}
if (node->op().op == at::aten::unsqueeze) {
return LowerUnsqueeze(
torch::lazy::NodeCast<Unsqueeze>(
node, torch::lazy::OpKind(at::aten::unsqueeze)));
torch::lazy::NodeCast<Unsqueeze>(node));
}
if (node->op().op == at::aten::view) {
return LowerView(torch::lazy::NodeCast<torch::lazy::View>(
node, torch::lazy::OpKind(at::aten::view)));
return LowerView(torch::lazy::NodeCast<torch::lazy::View>(node));
}
if (node->op().op == at::aten::diagonal) {
return LowerDiagonal(torch::lazy::NodeCast<torch::lazy::Diagonal>(
node, torch::lazy::OpKind(at::aten::diagonal)));
return LowerDiagonal(torch::lazy::NodeCast<torch::lazy::Diagonal>(node));
}
if (node->op() == *torch::lazy::ltc_diagonal_view_update) {
return LowerDiagonalViewUpdate(torch::lazy::NodeCast<torch::lazy::DiagonalViewUpdate>(
node, *torch::lazy::ltc_diagonal_view_update));
return LowerDiagonalViewUpdate(torch::lazy::NodeCast<torch::lazy::DiagonalViewUpdate>(node));
}
if (node->op() == *torch::lazy::ltc_device_data) {
const torch::lazy::DeviceData* device_data_node =
torch::lazy::NodeCast<torch::lazy::DeviceData>(
node, *torch::lazy::ltc_device_data);
torch::lazy::NodeCast<torch::lazy::DeviceData>(node);
auto infoptr = device_data_node->data()->info();
auto deviceDataInfoPtr = (torch::lazy::LazyGraphExecutor::DeviceDataInfo*) infoptr;
if (GRAPH_DUMP_ENABLED) {
@@ -8,6 +8,8 @@
namespace torch {
namespace lazy {

const OpKind AsStrided::class_op_kind(at::aten::as_strided);

AsStrided::AsStrided(
const Value& input,
std::vector<int64_t> size,

@@ -9,6 +9,8 @@ namespace lazy {
class TORCH_API AsStrided : public TsNode {
public:
static const OpKind class_op_kind;

AsStrided(
const Value& input,
std::vector<int64_t> size,

@@ -7,6 +7,8 @@
namespace torch {
namespace lazy {

const OpKind AsStridedViewUpdate::class_op_kind(ltc_as_strided_view_update);

AsStridedViewUpdate::AsStridedViewUpdate(
const Value& target,
const Value& input,

@@ -9,6 +9,8 @@ namespace lazy {
class TORCH_API AsStridedViewUpdate : public TsNode {
public:
static const OpKind class_op_kind;

AsStridedViewUpdate(
const Value& target,
const Value& input,

@@ -7,6 +7,8 @@
namespace torch {
namespace lazy {

const OpKind Diagonal::class_op_kind(at::aten::diagonal);

Diagonal::Diagonal(
const Value& input,
int64_t offset,

@@ -7,6 +7,8 @@ namespace lazy {
class TORCH_API Diagonal : public TsNode {
public:
static const OpKind class_op_kind;

Diagonal(const Value& input, int64_t offset, int64_t dim1, int64_t dim2);

std::string ToString() const override;

@@ -5,6 +5,8 @@
namespace torch {
namespace lazy {

const OpKind DiagonalViewUpdate::class_op_kind(ltc_diagonal_view_update);

DiagonalViewUpdate::DiagonalViewUpdate(
const Value& target,
const Value& input,

@@ -7,6 +7,8 @@ namespace lazy {
class TORCH_API DiagonalViewUpdate : public TsNode {
public:
static const OpKind class_op_kind;

DiagonalViewUpdate(
const Value& target,
const Value& input,

@@ -5,6 +5,8 @@
namespace torch {
namespace lazy {

const OpKind Narrow::class_op_kind(at::aten::narrow);

Narrow::Narrow(
const Value& input,
c10::ArrayRef<int64_t> base_indices,

@@ -7,6 +7,8 @@ namespace lazy {
class TORCH_API Narrow : public TsNode {
public:
static const OpKind class_op_kind;

Narrow(
const Value& input,
c10::ArrayRef<int64_t> base_indices,

@@ -5,6 +5,8 @@
namespace torch {
namespace lazy {

const OpKind NarrowViewUpdate::class_op_kind(ltc_narrow_view_update);

NarrowViewUpdate::NarrowViewUpdate(
const Value& input,
const Value& source,

@@ -7,6 +7,8 @@ namespace lazy {
class TORCH_API NarrowViewUpdate : public TsNode {
public:
static const OpKind class_op_kind;

NarrowViewUpdate(
const Value& input,
const Value& source,

@@ -6,6 +6,8 @@
namespace torch {
namespace lazy {

const OpKind Permute::class_op_kind(at::aten::permute);

Permute::Permute(const Value& input, std::vector<int64_t> dims)
: TsNode(
OpKind(at::aten::permute),

@@ -7,6 +7,8 @@ namespace lazy {
class TORCH_API Permute : public TsNode {
public:
static const OpKind class_op_kind;

Permute(const Value& input, std::vector<int64_t> dims);

std::string ToString() const override;

@@ -4,13 +4,14 @@ namespace torch {
namespace lazy {

namespace {

Shape NodeOutputShape(const Value& input, c10::ArrayRef<int64_t> size) {
return Shape(input.shape().scalar_type(), size);
}

} // namespace

const OpKind Resize::class_op_kind(at::aten::resize);

Resize::Resize(const Value& input, std::vector<int64_t> size)
: TsNode(
OpKind(at::aten::resize),

@@ -7,6 +7,8 @@ namespace lazy {
class TORCH_API Resize : public TsNode {
public:
static const OpKind class_op_kind;

Resize(const Value& input, std::vector<int64_t> size);

std::string ToString() const override;

@@ -6,6 +6,8 @@
namespace torch {
namespace lazy {

const OpKind Select::class_op_kind(at::aten::select);

Select::Select(
const Value& input,
int64_t dim,

@@ -7,6 +7,8 @@ namespace lazy {
class TORCH_API Select : public TsNode {
public:
static const OpKind class_op_kind;

Select(
const Value& input,
int64_t dim,

@@ -7,6 +7,8 @@
namespace torch {
namespace lazy {

const OpKind SelectViewUpdate::class_op_kind(ltc_select_view_update);

SelectViewUpdate::SelectViewUpdate(
const Value& target,
const Value& source,

@@ -7,6 +7,8 @@ namespace lazy {
class TORCH_API SelectViewUpdate : public TsNode {
public:
static const OpKind class_op_kind;

SelectViewUpdate(
const Value& target,
const Value& source,

@@ -6,6 +6,8 @@
namespace torch {
namespace lazy {

const OpKind Squeeze::class_op_kind(at::aten::squeeze);

Squeeze::Squeeze(const torch::lazy::Value& input, int dim)
: torch::lazy::TsNode(torch::lazy::OpKind(at::aten::squeeze), {input},
/*num_outputs=*/1, torch::lazy::MHash(dim)),

@@ -7,6 +7,8 @@ namespace lazy {
class TORCH_API Squeeze : public TsNode {
public:
static const OpKind class_op_kind;

// Squeeze out the specified dimension index, -1 for all trivial dimensions.
Squeeze(const torch::lazy::Value& input, int dim);

@@ -5,6 +5,8 @@
namespace torch {
namespace lazy {

const OpKind Unsqueeze::class_op_kind(at::aten::unsqueeze);

Unsqueeze::Unsqueeze(const torch::lazy::Value& input, int dim)
: torch::lazy::TsNode(
torch::lazy::OpKind(at::aten::unsqueeze),

@@ -7,6 +7,8 @@ namespace lazy {
class TORCH_API Unsqueeze : public TsNode {
public:
static const OpKind class_op_kind;

Unsqueeze(const torch::lazy::Value& input, int dim);

std::string ToString() const override;

@@ -6,7 +6,6 @@ namespace torch {
namespace lazy {

namespace {

Shape NodeOutputShape(const Value& input, c10::ArrayRef<int64_t> output_sizes) {
const Shape& input_shape = input.shape();
const auto complete_output_sizes =

@@ -16,6 +15,8 @@ Shape NodeOutputShape(const Value& input, c10::ArrayRef<int64_t> output_sizes) {
} // namespace

const OpKind View::class_op_kind(at::aten::view);

View::View(const Value& input, std::vector<int64_t> output_size)
: TsNode(
OpKind(at::aten::view),

@@ -9,6 +9,8 @@ namespace lazy {
class TORCH_API View : public TsNode {
public:
static const OpKind class_op_kind;

View(const Value& input, std::vector<int64_t> output_size);

std::string ToString() const override;
@@ -119,6 +119,16 @@ class GenLazyIR(ABC):
func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
return self.gen(f)

@method_with_native_function
def gen_opkind_definition(
self, f: Union[NativeFunctionsGroup, NativeFunction]
) -> List[str]:
func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
schema = LazyIrSchema(func)
return [
f"const OpKind {schema.node_name}::class_op_kind{{{aten_symbol(schema)}}};"
]

# there is no lowering functionality generated unless this IR base class is subclassed and
# implemented as a backend-specific node
def lowering_function(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> str:

@@ -202,6 +212,8 @@ class GenLazyIR(ABC):
f"""\
class {schema.node_name} : public {self.node_base} {{
public:
static const OpKind class_op_kind;

{schema.node_name}({node_ctor_args}, std::vector<Shape>&& shapes)
: {self.node_base_ctor_call(schema)}{comma_if_scalar_initializers}
{scalar_initializers}

@@ -221,7 +233,6 @@ class {schema.node_name} : public {self.node_base} {{
{scalar_decls}
{has_optional_decls}
}};
""",

@@ -394,6 +405,7 @@ class GenLazyNativeFuncDefinition:
schema = LazyIrSchema(func.func)
return [
f"""\
{sig.decl(name=f"{self.class_method_name}::{metadata.kernel}")} {{
{self.force_eager_fallback(func, schema)}
{self.metrics(func, schema)}

@@ -524,6 +524,29 @@ def run_gen_lazy_tensor(
"namespace_epilogue": ns_helper.epilogue,
},
)
# Generate OpKind definitions for IR node classes
fm.write_with_template(
"LazyIr.cpp",
"LazyIr.cpp",
lambda: {
"includes": [
f"#include <{path}>"
for path in [
f"{output_dir}/LazyIr.h",
]
],
"opkind_definitions": list(
concat_map_codegen(
lazy_ir_generator(
backend_indices[backend_key], node_base
).gen_opkind_definition,
grouped_native_functions,
)
),
"namespace_prologue": ns_helper.prologue,
"namespace_epilogue": ns_helper.epilogue,
},
)

if __name__ == "__main__":
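Putting the codegen pieces together, the generated torch/csrc/lazy/generated/LazyIr.cpp should end up looking roughly like the sketch below; the file shape follows from the LazyIr.cpp template and gen_opkind_definition above, while the specific operator (Abs) and the exact include path are illustrative assumptions, not taken from this diff.

// Illustrative sketch of a generated LazyIr.cpp (operator name is hypothetical).
#include <torch/csrc/lazy/generated/LazyIr.h>

namespace torch {
namespace lazy {

// One definition per codegen'd IR class, emitted by gen_opkind_definition():
const OpKind Abs::class_op_kind{at::aten::abs};

} // namespace lazy
} // namespace torch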