mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/26268 This is necessary to represent operators like isinstance where the type needs to be recorded in the node. This diff does not actually use the attributes for anything yet. One plausible thing to do in the future would be to use attributes to fill in the values of type variables for nodes whose schemas include type variables, rather than rematching them from the arguments. However, this change is not required for isinstance, so I have left it for later. Test Plan: Imported from OSS Differential Revision: D17412855 Pulled By: zdevito fbshipit-source-id: 7a2618c8a9f9dfc94858af79afbf433518eda4b3
150 lines
4.2 KiB
C++
150 lines
4.2 KiB
C++
#pragma once
|
|
#include <ATen/ATen.h>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
#include <ATen/core/interned_strings.h>
|
|
|
|
#include <torch/csrc/WindowsTorchApiMacro.h>
|
|
|
|
namespace c10 {

// Forward declarations only: this header needs to name TypePtr (the value
// type of TypeAttr/TypesAttr below) but never dereferences a Type, so the
// complete definition is presumably kept out of this header on purpose.
struct Type;

// Shared, reference-counted handle to a (possibly incomplete here) Type.
using TypePtr = std::shared_ptr<Type>;

} // namespace c10
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
// Bring c10::Symbol into torch::jit so the attribute classes below can name
// attributes with a plain `Symbol`.
using ::c10::Symbol;
|
|
|
|
// Display limit for tensor values; presumably the maximum number of elements
// printed when a tensor is shown (TODO: confirm against callers).
constexpr int max_tensor_display_size{10};
|
|
|
|
enum class AttributeKind { f, fs, i, is, s, ss, t, ts, g, gs, ty, tys };
|
|
static inline const char* toString(AttributeKind kind) {
|
|
static const char* names[] = {
|
|
"f", "fs", "i", "is", "s", "ss", "t", "ts", "g", "gs", "ty", "tys"};
|
|
AT_ASSERT(size_t(kind) < sizeof(names) / sizeof(AttributeKind));
|
|
return names[int(kind)];
|
|
}
|
|
|
|
struct AttributeValue {
|
|
AttributeValue(Symbol name) : name(name) {}
|
|
using Ptr = std::unique_ptr<AttributeValue>;
|
|
Symbol name;
|
|
virtual AttributeKind kind() const = 0;
|
|
virtual Ptr clone() const = 0;
|
|
virtual ~AttributeValue() = default;
|
|
};
|
|
|
|
template <typename T, AttributeKind Kind>
|
|
struct ScalarAttributeValue : public AttributeValue {
|
|
using ConstructorType = T;
|
|
using ValueType = T;
|
|
ScalarAttributeValue(Symbol name, ConstructorType value_)
|
|
: AttributeValue(name), value_(std::move(value_)) {}
|
|
ValueType& value() {
|
|
return value_;
|
|
}
|
|
Ptr clone() const override {
|
|
return Ptr(new ScalarAttributeValue(name, value_));
|
|
}
|
|
AttributeKind kind() const override {
|
|
return Kind;
|
|
}
|
|
|
|
private:
|
|
ValueType value_;
|
|
};
|
|
|
|
template <typename T, AttributeKind Kind>
|
|
struct VectorAttributeValue : public AttributeValue {
|
|
using ConstructorType = std::vector<T>;
|
|
using ValueType = std::vector<T>;
|
|
VectorAttributeValue(Symbol name, ConstructorType value_)
|
|
: AttributeValue(name), value_(std::move(value_)) {}
|
|
ValueType& value() {
|
|
return value_;
|
|
}
|
|
AttributeKind kind() const override {
|
|
return Kind;
|
|
}
|
|
std::unique_ptr<AttributeValue> clone() const override {
|
|
auto copy = value_;
|
|
return Ptr(new VectorAttributeValue(name, std::move(copy)));
|
|
}
|
|
|
|
private:
|
|
ValueType value_;
|
|
};
|
|
|
|
using FloatAttr = ScalarAttributeValue<double, AttributeKind::f>;
|
|
using FloatsAttr = VectorAttributeValue<double, AttributeKind::fs>;
|
|
using IntAttr = ScalarAttributeValue<int64_t, AttributeKind::i>;
|
|
using IntsAttr = VectorAttributeValue<int64_t, AttributeKind::is>;
|
|
using StringAttr = ScalarAttributeValue<std::string, AttributeKind::s>;
|
|
using StringsAttr = VectorAttributeValue<std::string, AttributeKind::ss>;
|
|
using TensorAttr = ScalarAttributeValue<at::Tensor, AttributeKind::t>;
|
|
using TensorsAttr = VectorAttributeValue<at::Tensor, AttributeKind::ts>;
|
|
using TypeAttr = ScalarAttributeValue<c10::TypePtr, AttributeKind::ty>;
|
|
using TypesAttr = VectorAttributeValue<c10::TypePtr, AttributeKind::tys>;
|
|
|
|
struct Graph;
|
|
|
|
// We special case Graph attributes like this because we want to ensure that
|
|
// Graph::copy() is called when we clone() these attributes.
|
|
struct TORCH_API GraphAttr : public AttributeValue {
|
|
using ConstructorType = std::shared_ptr<Graph>;
|
|
using ValueType = std::shared_ptr<Graph>;
|
|
GraphAttr(Symbol name, ConstructorType value_)
|
|
: AttributeValue(name), value_(value_) {}
|
|
ValueType& value() {
|
|
return value_;
|
|
}
|
|
Ptr clone() const override;
|
|
AttributeKind kind() const override {
|
|
return AttributeKind::g;
|
|
}
|
|
|
|
private:
|
|
std::shared_ptr<Graph> value_;
|
|
};
|
|
|
|
// List-of-graphs attribute. clone() is declared here and defined out of line.
struct TORCH_API GraphsAttr : public AttributeValue {
  using ConstructorType = std::vector<std::shared_ptr<Graph>>;
  using ValueType = std::vector<std::shared_ptr<Graph>>;

  // `val` is a sink parameter: taken by value and moved into the member.
  // (Named `val` so it no longer shadows the `value_` member.)
  GraphsAttr(Symbol name, ConstructorType val)
      : AttributeValue(name), value_(std::move(val)) {}

  // Mutable access to the stored list of graph handles.
  ValueType& value() {
    return value_;
  }

  // Read-only access, so the list can be inspected through a const
  // attribute.
  const ValueType& value() const {
    return value_;
  }

  AttributeKind kind() const override {
    return AttributeKind::gs;
  }

  // Defined out of line (not visible here).
  std::unique_ptr<AttributeValue> clone() const override;

 private:
  ValueType value_;
};
|
|
|
|
struct AttributeError : public std::exception {
|
|
AttributeError(Symbol name, bool defined) {
|
|
std::stringstream ss;
|
|
if (!defined) {
|
|
ss << "required keyword attribute '" << name.toUnqualString()
|
|
<< "' is undefined";
|
|
} else {
|
|
ss << "required keyword attribute '" << name.toUnqualString()
|
|
<< "' has the wrong type";
|
|
}
|
|
msg = ss.str();
|
|
}
|
|
const char* what() const noexcept override {
|
|
return msg.c_str();
|
|
}
|
|
|
|
private:
|
|
std::string msg;
|
|
};
|
|
} // namespace jit
|
|
} // namespace torch
|