pytorch/torch/csrc/jit/attributes.h
Zachary DeVito 8a38a53e4d Allow types as node attributes (#26268)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26268

This is necessary to represent operators like isinstance where
the type needs to be recorded in the node.

This diff does not actually use the attributes for anything yet.
One plausible thing to do in the future would be to use attributes to
fill in the values of type variables for nodes whose schemas include
type variables, rather than rematching them from the arguments. However,
this change is not required for isinstance, so I have left it for later.

Test Plan: Imported from OSS

Differential Revision: D17412855

Pulled By: zdevito

fbshipit-source-id: 7a2618c8a9f9dfc94858af79afbf433518eda4b3
2019-10-01 16:46:55 -07:00

150 lines
4.2 KiB
C++

#pragma once
#include <string>
#include <utility>
#include <vector>

#include <ATen/ATen.h>
#include <ATen/core/interned_strings.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
// Forward-declares c10::Type and the shared-pointer alias TypePtr so the
// attribute types in this header can hold a TypePtr using only this
// declaration of Type.
// NOTE(review): this alias must stay identical to the TypePtr alias in the
// header that defines c10::Type -- confirm if that definition changes.
namespace c10 {
struct Type;
using TypePtr = std::shared_ptr<Type>;
} // namespace c10
namespace torch {
namespace jit {
using ::c10::Symbol;
// NOTE(review): not referenced anywhere in this header. Presumably an upper
// bound on how many tensor elements are shown when a tensor attribute is
// printed -- confirm against the code that uses it.
constexpr int max_tensor_display_size = 10;
// Tag identifying the payload type carried by an AttributeValue.
// A trailing 's' denotes the std::vector form of the preceding kind.
// The enumerator order is significant: toString() indexes its name table
// by the numeric value of the enumerator.
enum class AttributeKind {
  f, // double (FloatAttr)
  fs, // std::vector<double> (FloatsAttr)
  i, // int64_t (IntAttr)
  is, // std::vector<int64_t> (IntsAttr)
  s, // std::string (StringAttr)
  ss, // std::vector<std::string> (StringsAttr)
  t, // at::Tensor (TensorAttr)
  ts, // std::vector<at::Tensor> (TensorsAttr)
  g, // std::shared_ptr<Graph> (GraphAttr)
  gs, // std::vector<std::shared_ptr<Graph>> (GraphsAttr)
  ty, // c10::TypePtr (TypeAttr)
  tys // std::vector<c10::TypePtr> (TypesAttr)
};
// Returns the short mnemonic ("f", "fs", ..., "tys") for `kind`.
// The result points at a string literal, so it never needs to be freed.
// Asserts (via AT_ASSERT) that `kind` maps to an entry of the name table.
static inline const char* toString(AttributeKind kind) {
  // One entry per AttributeKind enumerator, in declaration order.
  static const char* names[] = {
      "f", "fs", "i", "is", "s", "ss", "t", "ts", "g", "gs", "ty", "tys"};
  // Number of table entries. The divisor must be the element size (a
  // pointer), not sizeof(AttributeKind): wherever a pointer is wider than
  // the enum, dividing by the enum size overstates the entry count and lets
  // out-of-range kinds slip past the bounds check below.
  constexpr size_t num_names = sizeof(names) / sizeof(names[0]);
  const size_t index = static_cast<size_t>(kind);
  AT_ASSERT(index < num_names);
  return names[index];
}
// Abstract base for a single named attribute. Concrete subclasses store the
// payload and report its type through kind().
struct AttributeValue {
  // Owning pointer type returned by clone().
  using Ptr = std::unique_ptr<AttributeValue>;

  AttributeValue(Symbol name) : name(name) {}

  // Symbol under which this attribute is stored.
  Symbol name;

  // Tag describing which concrete payload this attribute holds.
  virtual AttributeKind kind() const = 0;

  // Returns a newly allocated attribute owned by the caller; each subclass
  // decides how its payload is duplicated.
  virtual Ptr clone() const = 0;

  // Virtual so that destroying through a Ptr runs the subclass destructor.
  virtual ~AttributeValue() = default;
};
// Attribute that stores exactly one value of type T. The AttributeKind tag
// is fixed at compile time through the Kind template argument.
template <typename T, AttributeKind Kind>
struct ScalarAttributeValue : public AttributeValue {
  using ConstructorType = T;
  using ValueType = T;

  // Takes the payload by value and moves it into the attribute.
  ScalarAttributeValue(Symbol name, ConstructorType initial)
      : AttributeValue(name), value_(std::move(initial)) {}

  // Always reports the compile-time Kind.
  AttributeKind kind() const override {
    return Kind;
  }

  // Creates a new attribute with the same name and a copy-constructed
  // payload.
  Ptr clone() const override {
    return Ptr(new ScalarAttributeValue<T, Kind>(name, value_));
  }

  // Mutable access to the stored payload.
  ValueType& value() {
    return value_;
  }

 private:
  ValueType value_;
};
template <typename T, AttributeKind Kind>
struct VectorAttributeValue : public AttributeValue {
using ConstructorType = std::vector<T>;
using ValueType = std::vector<T>;
VectorAttributeValue(Symbol name, ConstructorType value_)
: AttributeValue(name), value_(std::move(value_)) {}
ValueType& value() {
return value_;
}
AttributeKind kind() const override {
return Kind;
}
std::unique_ptr<AttributeValue> clone() const override {
auto copy = value_;
return Ptr(new VectorAttributeValue(name, std::move(copy)));
}
private:
ValueType value_;
};
// Concrete attribute types for every AttributeKind except g and gs, which
// are implemented by the GraphAttr and GraphsAttr structs instead of these
// templates. Each scalar alias pairs with the vector alias on the next line.
using FloatAttr = ScalarAttributeValue<double, AttributeKind::f>;
using FloatsAttr = VectorAttributeValue<double, AttributeKind::fs>;
using IntAttr = ScalarAttributeValue<int64_t, AttributeKind::i>;
using IntsAttr = VectorAttributeValue<int64_t, AttributeKind::is>;
using StringAttr = ScalarAttributeValue<std::string, AttributeKind::s>;
using StringsAttr = VectorAttributeValue<std::string, AttributeKind::ss>;
using TensorAttr = ScalarAttributeValue<at::Tensor, AttributeKind::t>;
using TensorsAttr = VectorAttributeValue<at::Tensor, AttributeKind::ts>;
using TypeAttr = ScalarAttributeValue<c10::TypePtr, AttributeKind::ty>;
using TypesAttr = VectorAttributeValue<c10::TypePtr, AttributeKind::tys>;
// Forward declaration: the attribute only stores a shared_ptr<Graph> here.
struct Graph;

// Attribute holding a single shared Graph.
// We special case Graph attributes like this because we want to ensure that
// Graph::copy() is called when we clone() these attributes.
struct TORCH_API GraphAttr : public AttributeValue {
  using ConstructorType = std::shared_ptr<Graph>;
  using ValueType = std::shared_ptr<Graph>;

  // Takes the graph pointer by value and moves it into the attribute. Moving
  // (rather than copying) the by-value parameter avoids a redundant
  // reference-count increment/decrement and matches the other attribute
  // constructors in this header.
  GraphAttr(Symbol name, ConstructorType value_)
      : AttributeValue(name), value_(std::move(value_)) {}

  // Mutable access to the stored graph pointer.
  ValueType& value() {
    return value_;
  }

  // Declared here, defined out of line.
  Ptr clone() const override;

  AttributeKind kind() const override {
    return AttributeKind::g;
  }

 private:
  std::shared_ptr<Graph> value_;
};
// Attribute holding a list of shared Graphs (AttributeKind::gs).
struct TORCH_API GraphsAttr : public AttributeValue {
  using ConstructorType = std::vector<std::shared_ptr<Graph>>;
  using ValueType = std::vector<std::shared_ptr<Graph>>;

  // Takes the list by value and moves it into the attribute.
  GraphsAttr(Symbol name, ConstructorType graphs)
      : AttributeValue(name), value_(std::move(graphs)) {}

  // Declared here, defined out of line.
  // NOTE(review): presumably copies each contained graph, mirroring
  // GraphAttr -- confirm in the implementation file.
  Ptr clone() const override;

  AttributeKind kind() const override {
    return AttributeKind::gs;
  }

  // Mutable access to the stored list of graph pointers.
  ValueType& value() {
    return value_;
  }

 private:
  ValueType value_;
};
// Exception describing a problem with a required keyword attribute: it is
// either not defined, or defined with a type other than the one requested.
struct AttributeError : public std::exception {
  // name:    attribute whose unqualified name appears in the message.
  // defined: false -> message says the attribute "is undefined";
  //          true  -> message says it "has the wrong type".
  AttributeError(Symbol name, bool defined) {
    // Both messages share the same prefix and differ only in the tail.
    const char* const tail =
        defined ? "' has the wrong type" : "' is undefined";
    std::stringstream out;
    out << "required keyword attribute '" << name.toUnqualString() << tail;
    msg = out.str();
  }

  // Returns the message built by the constructor; the pointer stays valid
  // for as long as this exception object is alive and unmodified.
  const char* what() const noexcept override {
    return msg.c_str();
  }

 private:
  std::string msg;
};
} // namespace jit
} // namespace torch