mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Allow types as node attributes (#26268)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/26268 This is necessary to represent operators like isinstance where the type needs to be recorded in the node. This diff does not actually use the attributes for anything yet. One plausible thing to do in the future would be use attributes to fill in the values of type variables for nodes whose schema include type variables rather than rematching them from the arguments. However, this change is not required for isinstance so I have left it for later. Test Plan: Imported from OSS Differential Revision: D17412855 Pulled By: zdevito fbshipit-source-id: 7a2618c8a9f9dfc94858af79afbf433518eda4b3
This commit is contained in:
parent
00e588290b
commit
8a38a53e4d
|
|
@ -7,6 +7,11 @@
|
|||
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
|
||||
namespace c10 {
|
||||
struct Type;
|
||||
using TypePtr = std::shared_ptr<Type>;
|
||||
} // namespace c10
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
|
|
@ -14,10 +19,10 @@ using ::c10::Symbol;
|
|||
|
||||
constexpr int max_tensor_display_size = 10;
|
||||
|
||||
enum class AttributeKind { f, fs, i, is, s, ss, t, ts, g, gs };
|
||||
enum class AttributeKind { f, fs, i, is, s, ss, t, ts, g, gs, ty, tys };
|
||||
static inline const char* toString(AttributeKind kind) {
|
||||
static const char* names[] = {
|
||||
"f", "fs", "i", "is", "s", "ss", "t", "ts", "g", "gs"};
|
||||
"f", "fs", "i", "is", "s", "ss", "t", "ts", "g", "gs", "ty", "tys"};
|
||||
AT_ASSERT(size_t(kind) < sizeof(names) / sizeof(AttributeKind));
|
||||
return names[int(kind)];
|
||||
}
|
||||
|
|
@ -80,6 +85,9 @@ using StringAttr = ScalarAttributeValue<std::string, AttributeKind::s>;
|
|||
using StringsAttr = VectorAttributeValue<std::string, AttributeKind::ss>;
|
||||
using TensorAttr = ScalarAttributeValue<at::Tensor, AttributeKind::t>;
|
||||
using TensorsAttr = VectorAttributeValue<at::Tensor, AttributeKind::ts>;
|
||||
using TypeAttr = ScalarAttributeValue<c10::TypePtr, AttributeKind::ty>;
|
||||
using TypesAttr = VectorAttributeValue<c10::TypePtr, AttributeKind::tys>;
|
||||
|
||||
struct Graph;
|
||||
|
||||
// We special case Graph attributes like this because we want to ensure that
|
||||
|
|
|
|||
|
|
@ -121,6 +121,19 @@ static void printStrList(
|
|||
out << "]";
|
||||
}
|
||||
|
||||
static void printTypeList(
|
||||
std::ostream& out,
|
||||
const std::vector<TypePtr>& items) {
|
||||
out << "[";
|
||||
int i = 0;
|
||||
for (auto& item : items) {
|
||||
if (i++ > 0)
|
||||
out << ", ";
|
||||
out << *item;
|
||||
}
|
||||
out << "]";
|
||||
}
|
||||
|
||||
void Node::printAttrValue(std::ostream& out, const Symbol& name) const {
|
||||
switch (kindOf(name)) {
|
||||
case AttributeKind::f:
|
||||
|
|
@ -175,6 +188,12 @@ void Node::printAttrValue(std::ostream& out, const Symbol& name) const {
|
|||
case AttributeKind::gs:
|
||||
out << "[<Graphs>]";
|
||||
break;
|
||||
case AttributeKind::ty:
|
||||
out << *ty(name);
|
||||
break;
|
||||
case AttributeKind::tys:
|
||||
printTypeList(out, tys(name));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -704,6 +704,8 @@ struct TORCH_API Node {
|
|||
CREATE_ACCESSOR(Ints, is)
|
||||
CREATE_ACCESSOR(Graph, g)
|
||||
CREATE_ACCESSOR(Graphs, gs)
|
||||
CREATE_ACCESSOR(Type, ty)
|
||||
CREATE_ACCESSOR(Types, tys)
|
||||
|
||||
#undef CREATE_ACCESSOR
|
||||
|
||||
|
|
|
|||
|
|
@ -27,6 +27,19 @@ bool tensorListEqual(
|
|||
return std::equal(lhs.begin(), lhs.end(), rhs.begin(), tensorEqual);
|
||||
}
|
||||
|
||||
bool typeListEqual(
|
||||
const std::vector<TypePtr>& lhs,
|
||||
const std::vector<TypePtr>& rhs) {
|
||||
if (lhs.size() != rhs.size())
|
||||
return false;
|
||||
for (size_t i = 0; i < lhs.size(); ++i) {
|
||||
if (*lhs[i] != *rhs[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check whether two nodes have the same attributes in CSE.
|
||||
// This function may be too conservative for general use.
|
||||
// Do NOT support g/gs attributes.
|
||||
|
|
@ -51,10 +64,10 @@ bool attributesEqualCSE(const Node* lhs, const Node* rhs) {
|
|||
if (lhs->kindOf(name) != rhs->kindOf(name))
|
||||
return false;
|
||||
|
||||
#define COMPARE_ATTRIBUTEVALUE(type) \
|
||||
case AttributeKind::type: { \
|
||||
if (lhs->type(name) != rhs->type(name)) \
|
||||
return false; \
|
||||
#define COMPARE_ATTRIBUTEVALUE(selector) \
|
||||
case AttributeKind::selector: { \
|
||||
if (lhs->selector(name) != rhs->selector(name)) \
|
||||
return false; \
|
||||
} break;
|
||||
|
||||
switch (lhs->kindOf(name)) {
|
||||
|
|
@ -74,6 +87,16 @@ bool attributesEqualCSE(const Node* lhs, const Node* rhs) {
|
|||
return false;
|
||||
break;
|
||||
}
|
||||
case AttributeKind::ty:
|
||||
if (*lhs->ty(name) != *rhs->ty(name)) {
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
case AttributeKind::tys:
|
||||
if (!typeListEqual(lhs->tys(name), rhs->tys(name))) {
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
case AttributeKind::g:
|
||||
case AttributeKind::gs:
|
||||
return false;
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user