Allow types as node attributes (#26268)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26268

This is necessary to represent operators like isinstance where
the type needs to be recorded in the node.

This diff does not actually use the attributes for anything yet.
One plausible thing to do in the future would be use attributes to
fill in the values of type variables for nodes whose schema include
type variables rather than rematching them from the arguments. However,
this change is not required for isinstance so I have left it for later.

Test Plan: Imported from OSS

Differential Revision: D17412855

Pulled By: zdevito

fbshipit-source-id: 7a2618c8a9f9dfc94858af79afbf433518eda4b3
This commit is contained in:
Zachary DeVito 2019-10-01 16:37:34 -07:00 committed by Facebook Github Bot
parent 00e588290b
commit 8a38a53e4d
4 changed files with 58 additions and 6 deletions

View File

@ -7,6 +7,11 @@
#include <torch/csrc/WindowsTorchApiMacro.h>
namespace c10 {
struct Type;
using TypePtr = std::shared_ptr<Type>;
} // namespace c10
namespace torch {
namespace jit {
@ -14,10 +19,10 @@ using ::c10::Symbol;
constexpr int max_tensor_display_size = 10;
enum class AttributeKind { f, fs, i, is, s, ss, t, ts, g, gs };
enum class AttributeKind { f, fs, i, is, s, ss, t, ts, g, gs, ty, tys };
static inline const char* toString(AttributeKind kind) {
static const char* names[] = {
"f", "fs", "i", "is", "s", "ss", "t", "ts", "g", "gs"};
"f", "fs", "i", "is", "s", "ss", "t", "ts", "g", "gs", "ty", "tys"};
AT_ASSERT(size_t(kind) < sizeof(names) / sizeof(AttributeKind));
return names[int(kind)];
}
@ -80,6 +85,9 @@ using StringAttr = ScalarAttributeValue<std::string, AttributeKind::s>;
using StringsAttr = VectorAttributeValue<std::string, AttributeKind::ss>;
using TensorAttr = ScalarAttributeValue<at::Tensor, AttributeKind::t>;
using TensorsAttr = VectorAttributeValue<at::Tensor, AttributeKind::ts>;
using TypeAttr = ScalarAttributeValue<c10::TypePtr, AttributeKind::ty>;
using TypesAttr = VectorAttributeValue<c10::TypePtr, AttributeKind::tys>;
struct Graph;
// We special case Graph attributes like this because we want to ensure that

View File

@ -121,6 +121,19 @@ static void printStrList(
out << "]";
}
static void printTypeList(
std::ostream& out,
const std::vector<TypePtr>& items) {
out << "[";
int i = 0;
for (auto& item : items) {
if (i++ > 0)
out << ", ";
out << *item;
}
out << "]";
}
void Node::printAttrValue(std::ostream& out, const Symbol& name) const {
switch (kindOf(name)) {
case AttributeKind::f:
@ -175,6 +188,12 @@ void Node::printAttrValue(std::ostream& out, const Symbol& name) const {
case AttributeKind::gs:
out << "[<Graphs>]";
break;
case AttributeKind::ty:
out << *ty(name);
break;
case AttributeKind::tys:
printTypeList(out, tys(name));
break;
}
}

View File

@ -704,6 +704,8 @@ struct TORCH_API Node {
CREATE_ACCESSOR(Ints, is)
CREATE_ACCESSOR(Graph, g)
CREATE_ACCESSOR(Graphs, gs)
CREATE_ACCESSOR(Type, ty)
CREATE_ACCESSOR(Types, tys)
#undef CREATE_ACCESSOR

View File

@ -27,6 +27,19 @@ bool tensorListEqual(
return std::equal(lhs.begin(), lhs.end(), rhs.begin(), tensorEqual);
}
bool typeListEqual(
const std::vector<TypePtr>& lhs,
const std::vector<TypePtr>& rhs) {
if (lhs.size() != rhs.size())
return false;
for (size_t i = 0; i < lhs.size(); ++i) {
if (*lhs[i] != *rhs[i]) {
return false;
}
}
return true;
}
// Check whether two nodes have the same attributes in CSE.
// This function may be too conservative for general use.
// Do NOT support g/gs attributes.
@ -51,10 +64,10 @@ bool attributesEqualCSE(const Node* lhs, const Node* rhs) {
if (lhs->kindOf(name) != rhs->kindOf(name))
return false;
#define COMPARE_ATTRIBUTEVALUE(type) \
case AttributeKind::type: { \
if (lhs->type(name) != rhs->type(name)) \
return false; \
#define COMPARE_ATTRIBUTEVALUE(selector) \
case AttributeKind::selector: { \
if (lhs->selector(name) != rhs->selector(name)) \
return false; \
} break;
switch (lhs->kindOf(name)) {
@ -74,6 +87,16 @@ bool attributesEqualCSE(const Node* lhs, const Node* rhs) {
return false;
break;
}
case AttributeKind::ty:
if (*lhs->ty(name) != *rhs->ty(name)) {
return false;
}
break;
case AttributeKind::tys:
if (!typeListEqual(lhs->tys(name), rhs->tys(name))) {
return false;
}
break;
case AttributeKind::g:
case AttributeKind::gs:
return false;