mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[jit][edge] Migrate to TypeFactory for jit types on mobile (#71516)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71516
Mobile should be able to contruct dynamic types by default.
ghstack-source-id: 147498365
Test Plan:
CI.
**-48KB** binary size reduction for igios BSB.
UMBEX link: https://www.internalfb.com/intern/unigraph/explorer/?jsgq_traversal_spec=%7B%22builds%22%3A[%22bsb%3A422553426218394%5Cu0040base%22%2C%22bsb%3A422553426218394%5Cu0040diff%22]%7D&unigraph_project=UnigraphProjectMbex&is_mbex_redirected
Reviewed By: iseeyuan
Differential Revision: D33673958
fbshipit-source-id: 8600c04ae929283681971aae264d3774188df9cd
(cherry picked from commit 64ebcec09e)
This commit is contained in:
parent
e5794974cb
commit
fe277b8717
|
|
@ -308,7 +308,7 @@ c10::intrusive_ptr<c10::ivalue::Future> intraop_launch_future(
|
||||||
#else
|
#else
|
||||||
// TODO: caffe2::PThreadPool only provides a data-parallel API.
|
// TODO: caffe2::PThreadPool only provides a data-parallel API.
|
||||||
// Task parallelism is not currently supported.
|
// Task parallelism is not currently supported.
|
||||||
auto future = c10::make_intrusive<c10::ivalue::Future>(NoneType::get());
|
auto future = c10::make_intrusive<c10::ivalue::Future>(c10::dynT<NoneType>());
|
||||||
func();
|
func();
|
||||||
future->markCompleted();
|
future->markCompleted();
|
||||||
return future;
|
return future;
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@
|
||||||
#include <ATen/core/jit_type.h>
|
#include <ATen/core/jit_type.h>
|
||||||
#include <ATen/core/function_schema.h>
|
#include <ATen/core/function_schema.h>
|
||||||
#include <ATen/core/functional.h>
|
#include <ATen/core/functional.h>
|
||||||
|
#include <ATen/core/type_factory.h>
|
||||||
|
|
||||||
#include <atomic>
|
#include <atomic>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
@ -102,7 +103,7 @@ class_base::class_base(
|
||||||
{
|
{
|
||||||
detail::checkValidIdent(namespaceName, "Namespace name");
|
detail::checkValidIdent(namespaceName, "Namespace name");
|
||||||
detail::checkValidIdent(className, "Class name");
|
detail::checkValidIdent(className, "Class name");
|
||||||
classTypePtr->addAttribute("capsule", at::CapsuleType::get());
|
classTypePtr->addAttribute("capsule", c10::TypeFactory::get<c10::CapsuleType>());
|
||||||
c10::getCustomClassTypeMap().insert(
|
c10::getCustomClassTypeMap().insert(
|
||||||
{std::type_index(intrusivePtrClassTypeid), classTypePtr});
|
{std::type_index(intrusivePtrClassTypeid), classTypePtr});
|
||||||
c10::getCustomClassTypeMap().insert(
|
c10::getCustomClassTypeMap().insert(
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#include <ATen/core/class_type.h>
|
||||||
#include <ATen/core/ivalue.h>
|
#include <ATen/core/ivalue.h>
|
||||||
#include <ATen/core/jit_type.h>
|
#include <ATen/core/jit_type.h>
|
||||||
#include <ATen/core/type_factory.h>
|
#include <ATen/core/type_factory.h>
|
||||||
|
|
@ -198,6 +199,11 @@ TypePtr DynamicType::containedType(size_t i) const {
|
||||||
return arguments_.elems.at(i).ty;
|
return arguments_.elems.at(i).ty;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t DynamicType::containedTypeSize() const {
|
||||||
|
TORCH_INTERNAL_ASSERT(tag_ != Tag::Class);
|
||||||
|
return arguments_.elems.size();
|
||||||
|
}
|
||||||
|
|
||||||
TypeKind DynamicType::dynamicKind() const {
|
TypeKind DynamicType::dynamicKind() const {
|
||||||
switch (tag_) {
|
switch (tag_) {
|
||||||
#define CASE_TYPE(T, _, __) \
|
#define CASE_TYPE(T, _, __) \
|
||||||
|
|
@ -271,6 +277,16 @@ TypePtr DynamicType::fallback() const {
|
||||||
return VarType::create(*name_);
|
return VarType::create(*name_);
|
||||||
case Tag::AnyClass:
|
case Tag::AnyClass:
|
||||||
return AnyClassType::get();
|
return AnyClassType::get();
|
||||||
|
case Tag::QScheme:
|
||||||
|
return QSchemeType::get();
|
||||||
|
case Tag::Quantizer:
|
||||||
|
return QuantizerType::get();
|
||||||
|
case Tag::AnyEnum:
|
||||||
|
return AnyEnumType::get();
|
||||||
|
case Tag::RRef:
|
||||||
|
return RRefType::create(arguments_.elems[0].ty->fallback());
|
||||||
|
case Tag::Future:
|
||||||
|
return FutureType::create(arguments_.elems[0].ty->fallback());
|
||||||
case Tag::Any:
|
case Tag::Any:
|
||||||
return AnyType::get();
|
return AnyType::get();
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,6 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
|
|
||||||
#include <ATen/core/class_type.h>
|
|
||||||
#include <ATen/core/ivalue.h>
|
|
||||||
#include <ATen/core/jit_type_base.h>
|
#include <ATen/core/jit_type_base.h>
|
||||||
#include <c10/util/Optional.h>
|
#include <c10/util/Optional.h>
|
||||||
|
|
||||||
|
|
@ -53,8 +51,17 @@ constexpr DynamicTypeBits kDynamicClassTypeBit = DYNAMIC_TYPE_BIT(10);
|
||||||
_(Storage, DYNAMIC_TYPE_BIT(16), 1) \
|
_(Storage, DYNAMIC_TYPE_BIT(16), 1) \
|
||||||
_(Var, DYNAMIC_TYPE_BIT(17), 0) \
|
_(Var, DYNAMIC_TYPE_BIT(17), 0) \
|
||||||
_(AnyClass, (kDynamicClassTypeBit | kDynamicAnyTypeBit), 1) \
|
_(AnyClass, (kDynamicClassTypeBit | kDynamicAnyTypeBit), 1) \
|
||||||
|
_(QScheme, DYNAMIC_TYPE_BIT(18), 1) \
|
||||||
|
_(Quantizer, DYNAMIC_TYPE_BIT(19), 1) \
|
||||||
|
_(AnyEnum, DYNAMIC_TYPE_BIT(20), 1) \
|
||||||
|
_(RRef, DYNAMIC_TYPE_BIT(21), 0) \
|
||||||
|
_(Future, DYNAMIC_TYPE_BIT(22), 0) \
|
||||||
_(Any, 0xffffffff, 1)
|
_(Any, 0xffffffff, 1)
|
||||||
|
|
||||||
|
#define FORWARD_DECL_TYPE(NAME, _, __) struct NAME ## Type;
|
||||||
|
FORALL_DYNAMIC_TYPES(FORWARD_DECL_TYPE)
|
||||||
|
#undef FORWARD_DECL_TYPE
|
||||||
|
|
||||||
class DynamicType;
|
class DynamicType;
|
||||||
using DynamicTypePtr = std::shared_ptr<DynamicType>;
|
using DynamicTypePtr = std::shared_ptr<DynamicType>;
|
||||||
|
|
||||||
|
|
@ -142,6 +149,7 @@ class DynamicType : public SharedType {
|
||||||
explicit DynamicType(Tag, c10::string_view, Arguments);
|
explicit DynamicType(Tag, c10::string_view, Arguments);
|
||||||
|
|
||||||
TypePtr containedType(size_t) const override;
|
TypePtr containedType(size_t) const override;
|
||||||
|
size_t containedTypeSize() const override;
|
||||||
Tag tag() const {
|
Tag tag() const {
|
||||||
return tag_;
|
return tag_;
|
||||||
}
|
}
|
||||||
|
|
@ -154,6 +162,9 @@ class DynamicType : public SharedType {
|
||||||
TypeKind dynamicKind() const;
|
TypeKind dynamicKind() const;
|
||||||
|
|
||||||
// Should be used only on the server side to restore static type information.
|
// Should be used only on the server side to restore static type information.
|
||||||
|
#ifndef C10_MOBILE
|
||||||
|
TORCH_API
|
||||||
|
#endif
|
||||||
TypePtr fallback() const;
|
TypePtr fallback() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
@ -188,7 +199,7 @@ class DynamicType : public SharedType {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct DynamicTypeTrait {
|
struct DynamicTypeTrait {
|
||||||
static auto tagValue() {
|
C10_NOINLINE static auto tagValue() {
|
||||||
TORCH_CHECK(false);
|
TORCH_CHECK(false);
|
||||||
return DynamicType::Tag::Any;
|
return DynamicType::Tag::Any;
|
||||||
}
|
}
|
||||||
|
|
@ -201,7 +212,7 @@ C10_NOINLINE DynamicTypePtr makeBaseType(DynamicType::Tag tag);
|
||||||
#define DYNAMIC_TYPE_TAG_VALUE(NAME, _, IS_BASE_TYPE) \
|
#define DYNAMIC_TYPE_TAG_VALUE(NAME, _, IS_BASE_TYPE) \
|
||||||
template <> \
|
template <> \
|
||||||
struct TORCH_API DynamicTypeTrait<NAME##Type> { \
|
struct TORCH_API DynamicTypeTrait<NAME##Type> { \
|
||||||
static auto tagValue() { \
|
C10_ERASE static auto tagValue() { \
|
||||||
return DynamicType::Tag::NAME; \
|
return DynamicType::Tag::NAME; \
|
||||||
} \
|
} \
|
||||||
static constexpr bool isBaseType = IS_BASE_TYPE; \
|
static constexpr bool isBaseType = IS_BASE_TYPE; \
|
||||||
|
|
@ -214,19 +225,4 @@ C10_NOINLINE DynamicTypePtr makeBaseType(DynamicType::Tag tag);
|
||||||
FORALL_DYNAMIC_TYPES(DYNAMIC_TYPE_TAG_VALUE)
|
FORALL_DYNAMIC_TYPES(DYNAMIC_TYPE_TAG_VALUE)
|
||||||
#undef DYNAMIC_TYPE_TAG_VALUE
|
#undef DYNAMIC_TYPE_TAG_VALUE
|
||||||
|
|
||||||
template <>
|
|
||||||
struct IValue::TagType<c10::DynamicType> {
|
|
||||||
static DynamicType::Ptr get(const c10::IValue& v);
|
|
||||||
};
|
|
||||||
|
|
||||||
namespace ivalue {
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct TORCH_API TupleTypeFactory<c10::DynamicType> {
|
|
||||||
static DynamicTypePtr create(std::vector<TypePtr> elemTypes);
|
|
||||||
static DynamicTypePtr fallback(const Type&);
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace ivalue
|
|
||||||
|
|
||||||
} // namespace c10
|
} // namespace c10
|
||||||
|
|
|
||||||
|
|
@ -390,7 +390,7 @@ struct FunctionSchema {
|
||||||
|
|
||||||
// Check that inputs have the correct types and appends any missing default
|
// Check that inputs have the correct types and appends any missing default
|
||||||
// values.
|
// values.
|
||||||
template <typename T = c10::Type>
|
template <typename T = c10::PlatformType>
|
||||||
void checkAndNormalizeInputs(
|
void checkAndNormalizeInputs(
|
||||||
std::vector<IValue>& inputs,
|
std::vector<IValue>& inputs,
|
||||||
const std::unordered_map<std::string, IValue>& kwargs =
|
const std::unordered_map<std::string, IValue>& kwargs =
|
||||||
|
|
|
||||||
|
|
@ -293,7 +293,7 @@ inline void FunctionSchema::checkArg(
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
false,
|
false,
|
||||||
formatTypeMismatchMsg(
|
formatTypeMismatchMsg(
|
||||||
argument, value.type()->repr_str(), pos));
|
argument, value.type<T>()->repr_str(), pos));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@
|
||||||
#include <ATen/core/function.h>
|
#include <ATen/core/function.h>
|
||||||
#include <ATen/core/jit_type.h>
|
#include <ATen/core/jit_type.h>
|
||||||
#include <ATen/core/stack.h>
|
#include <ATen/core/stack.h>
|
||||||
|
#include <ATen/core/type_factory.h>
|
||||||
#include <c10/util/irange.h>
|
#include <c10/util/irange.h>
|
||||||
#include <c10/util/StringUtil.h>
|
#include <c10/util/StringUtil.h>
|
||||||
#include <c10/util/hash.h>
|
#include <c10/util/hash.h>
|
||||||
|
|
@ -403,6 +404,39 @@ bool IValue::is(const IValue& rhs) const {
|
||||||
return lhs == rhs;
|
return lhs == rhs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline bool IValue::isListOf() const {
|
||||||
|
// note: avoids calling type() to avoid extra referencing counting for the returned type.
|
||||||
|
if (!isList()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const auto& ty = static_cast<detail::ListImpl*>(payload.u.as_intrusive_ptr)->elementType;
|
||||||
|
if (ty->kind() == T::Kind) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return *ty == *TypeFactory::get<T>();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IValue::isDoubleList() const {
|
||||||
|
return isListOf<c10::FloatType>();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IValue::isComplexDoubleList() const {
|
||||||
|
return isListOf<c10::ComplexType>();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IValue::isTensorList() const {
|
||||||
|
return isListOf<c10::TensorType>();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IValue::isIntList() const {
|
||||||
|
return isListOf<c10::IntType>();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IValue::isBoolList() const {
|
||||||
|
return isListOf<c10::BoolType>();
|
||||||
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using IValueFormatter = std::function<void(std::ostream&, const IValue&)>;
|
using IValueFormatter = std::function<void(std::ostream&, const IValue&)>;
|
||||||
|
|
@ -430,7 +464,7 @@ std::ostream& printMaybeAnnotatedList(
|
||||||
std::ostream& out,
|
std::ostream& out,
|
||||||
const IValue& the_list,
|
const IValue& the_list,
|
||||||
IValueFormatter formatter) {
|
IValueFormatter formatter) {
|
||||||
auto list_elem_type = the_list.type()->expectRef<ListType>().getElementType();
|
auto list_elem_type = the_list.type()->containedType(0);
|
||||||
if (the_list.toListRef().size() == 0 ||
|
if (the_list.toListRef().size() == 0 ||
|
||||||
!elementTypeCanBeInferredFromMembers(list_elem_type)) {
|
!elementTypeCanBeInferredFromMembers(list_elem_type)) {
|
||||||
out << "annotate(" << the_list.type()->annotation_str() << ", ";
|
out << "annotate(" << the_list.type()->annotation_str() << ", ";
|
||||||
|
|
@ -925,7 +959,7 @@ c10::intrusive_ptr<ivalue::Object> ivalue::Object::deepcopy(IValue::HashAliasedI
|
||||||
auto cu = type_.cu_;
|
auto cu = type_.cu_;
|
||||||
auto object = ivalue::Object::create(WeakOrStrongTypePtr(type_.cu_, type_.type_), type()->numAttributes());
|
auto object = ivalue::Object::create(WeakOrStrongTypePtr(type_.cu_, type_.type_), type()->numAttributes());
|
||||||
for (const auto i : c10::irange(slots_.size())) {
|
for (const auto i : c10::irange(slots_.size())) {
|
||||||
if (slots_[i].type() == c10::CapsuleType::get()) {
|
if (*slots_[i].type() == *c10::TypeFactory::get<CapsuleType>()) {
|
||||||
// If we've gotten here, it means that we have *not* copied this
|
// If we've gotten here, it means that we have *not* copied this
|
||||||
// class via __getstate__ and __setstate__. That fact and the
|
// class via __getstate__ and __setstate__. That fact and the
|
||||||
// fact that we have a Capsule attribute mean that this is a
|
// fact that we have a Capsule attribute mean that this is a
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@
|
||||||
#include <ATen/core/custom_class.h>
|
#include <ATen/core/custom_class.h>
|
||||||
#include <ATen/core/ivalue_to.h>
|
#include <ATen/core/ivalue_to.h>
|
||||||
#include <ATen/core/jit_type_base.h>
|
#include <ATen/core/jit_type_base.h>
|
||||||
|
#include <ATen/core/type_factory.h>
|
||||||
#include <c10/util/C++17.h>
|
#include <c10/util/C++17.h>
|
||||||
#include <c10/util/MaybeOwned.h>
|
#include <c10/util/MaybeOwned.h>
|
||||||
#include <c10/util/intrusive_ptr.h>
|
#include <c10/util/intrusive_ptr.h>
|
||||||
|
|
@ -895,8 +896,8 @@ public:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T = c10::Type>
|
template <typename T = c10::PlatformType>
|
||||||
typename T::Ptr type() const;
|
TypePtr type() const;
|
||||||
|
|
||||||
// Detect aliased tensors.
|
// Detect aliased tensors.
|
||||||
struct HashAliasedIValue {
|
struct HashAliasedIValue {
|
||||||
|
|
|
||||||
|
|
@ -586,6 +586,12 @@ struct TORCH_API TupleTypeFactory<TupleType> {
|
||||||
static TupleTypePtr fallback(const Type& type);
|
static TupleTypePtr fallback(const Type& type);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct TORCH_API TupleTypeFactory<c10::DynamicType> {
|
||||||
|
static DynamicTypePtr create(std::vector<TypePtr> elemTypes);
|
||||||
|
static DynamicTypePtr fallback(const Type&);
|
||||||
|
};
|
||||||
|
|
||||||
struct TORCH_API Tuple : c10::intrusive_ptr_target {
|
struct TORCH_API Tuple : c10::intrusive_ptr_target {
|
||||||
private:
|
private:
|
||||||
TupleElements elements_;
|
TupleElements elements_;
|
||||||
|
|
@ -1915,39 +1921,6 @@ inline ivalue::Tuple& IValue::toTupleRef() const {
|
||||||
payload.u.as_intrusive_ptr);
|
payload.u.as_intrusive_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
inline bool IValue::isListOf() const {
|
|
||||||
// note: avoids calling type() to avoid extra referencing counting for the returned type.
|
|
||||||
if (!isList()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
const auto& ty = static_cast<detail::ListImpl*>(payload.u.as_intrusive_ptr)->elementType;
|
|
||||||
if (ty->kind() == T::Kind) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return *ty == *T::get();
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool IValue::isDoubleList() const {
|
|
||||||
return isListOf<c10::FloatType>();
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool IValue::isComplexDoubleList() const {
|
|
||||||
return isListOf<c10::ComplexType>();
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool IValue::isTensorList() const {
|
|
||||||
return isListOf<c10::TensorType>();
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool IValue::isIntList() const {
|
|
||||||
return isListOf<c10::IntType>();
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool IValue::isBoolList() const {
|
|
||||||
return isListOf<c10::BoolType>();
|
|
||||||
}
|
|
||||||
|
|
||||||
inline IValue::IValue(c10::intrusive_ptr<ivalue::Tuple> v)
|
inline IValue::IValue(c10::intrusive_ptr<ivalue::Tuple> v)
|
||||||
: tag(Tag::Tuple), is_intrusive_ptr(true) {
|
: tag(Tag::Tuple), is_intrusive_ptr(true) {
|
||||||
payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
|
payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
|
||||||
|
|
@ -2285,8 +2258,13 @@ struct IValue::TagType<c10::Type> {
|
||||||
static TORCH_API c10::TypePtr get(const IValue&);
|
static TORCH_API c10::TypePtr get(const IValue&);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct IValue::TagType<c10::DynamicType> {
|
||||||
|
static TORCH_API c10::TypePtr get(const IValue&);
|
||||||
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
typename T::Ptr IValue::type() const {
|
TypePtr IValue::type() const {
|
||||||
return IValue::TagType<T>::get(*this);
|
return IValue::TagType<T>::get(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@
|
||||||
#include <ATen/core/TensorBody.h>
|
#include <ATen/core/TensorBody.h>
|
||||||
#include <ATen/core/functional.h>
|
#include <ATen/core/functional.h>
|
||||||
#include <ATen/core/symbol.h>
|
#include <ATen/core/symbol.h>
|
||||||
|
#include <ATen/core/type_factory.h>
|
||||||
#include <ATen/core/qualified_name.h>
|
#include <ATen/core/qualified_name.h>
|
||||||
#include <c10/util/TypeList.h>
|
#include <c10/util/TypeList.h>
|
||||||
#include <c10/util/Optional.h>
|
#include <c10/util/Optional.h>
|
||||||
|
|
@ -1730,7 +1731,8 @@ struct getTypePtr_<c10::QScheme> final {
|
||||||
template <>
|
template <>
|
||||||
struct getTypePtr_<at::Generator> final {
|
struct getTypePtr_<at::Generator> final {
|
||||||
static decltype(auto) call() {
|
static decltype(auto) call() {
|
||||||
return OptionalType::create(GeneratorType::get());
|
return TypeFactory::create<OptionalType>(
|
||||||
|
TypeFactory::get<GeneratorType>());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
template <>
|
template <>
|
||||||
|
|
@ -1798,7 +1800,8 @@ struct getTypePtr_<c10::Dict<K, V>> final {
|
||||||
template <class T>
|
template <class T>
|
||||||
struct getTypePtr_<at::optional<T>> final {
|
struct getTypePtr_<at::optional<T>> final {
|
||||||
static const auto& call() {
|
static const auto& call() {
|
||||||
static auto type = OptionalType::create(getTypePtr_<T>::call());
|
static auto type = TypeFactory::create<OptionalType>(
|
||||||
|
getTypePtr_<T>::call());
|
||||||
return type;
|
return type;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -558,6 +558,9 @@ struct TORCH_API Type {
|
||||||
virtual TypePtr containedType(size_t i) const {
|
virtual TypePtr containedType(size_t i) const {
|
||||||
return containedTypes().at(i);
|
return containedTypes().at(i);
|
||||||
}
|
}
|
||||||
|
virtual size_t containedTypeSize() const {
|
||||||
|
return containedTypes().size();
|
||||||
|
}
|
||||||
// create a new version of this type, replacing its contained types with
|
// create a new version of this type, replacing its contained types with
|
||||||
// contained_types
|
// contained_types
|
||||||
TypePtr withContained(std::vector<TypePtr> contained_types);
|
TypePtr withContained(std::vector<TypePtr> contained_types);
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
#include <ATen/core/type_factory.h>
|
#include <ATen/core/type_factory.h>
|
||||||
|
|
||||||
|
#include <ATen/core/jit_type.h>
|
||||||
|
|
||||||
namespace c10 {
|
namespace c10 {
|
||||||
|
|
||||||
// Dtype constraints are not constrained in compilation. Therefore, we map
|
// Dtype constraints are not constrained in compilation. Therefore, we map
|
||||||
|
|
@ -56,4 +58,11 @@ const std::unordered_map<std::string, c10::TypePtr>& DefaultTypeFactory::
|
||||||
return map;
|
return map;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c10::TypePtr DefaultTypeFactory::createNamedTuple(
|
||||||
|
const std::string& name,
|
||||||
|
const std::vector<c10::string_view>& fields,
|
||||||
|
const std::vector<c10::TypePtr>& types) {
|
||||||
|
return c10::TupleType::createNamed(name, fields, types);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace c10
|
} // namespace c10
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,19 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <ATen/core/dynamic_type.h>
|
|
||||||
#include <ATen/core/jit_type.h>
|
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#include <ATen/core/dynamic_type.h>
|
||||||
|
#include <ATen/core/jit_type_base.h>
|
||||||
|
#include <c10/macros/Macros.h>
|
||||||
|
|
||||||
namespace c10 {
|
namespace c10 {
|
||||||
|
|
||||||
struct TORCH_API DynamicTypeFactory {
|
template <typename T>
|
||||||
|
struct TORCH_API TypeFactoryBase {};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct TORCH_API TypeFactoryBase<c10::DynamicType> {
|
||||||
template <typename T, typename... Args>
|
template <typename T, typename... Args>
|
||||||
static c10::DynamicTypePtr create(TypePtr ty, Args&&... args) {
|
static c10::DynamicTypePtr create(TypePtr ty, Args&&... args) {
|
||||||
return std::make_shared<c10::DynamicType>(
|
return std::make_shared<c10::DynamicType>(
|
||||||
|
|
@ -29,26 +36,40 @@ struct TORCH_API DynamicTypeFactory {
|
||||||
name,
|
name,
|
||||||
c10::DynamicType::Arguments(fields, types));
|
c10::DynamicType::Arguments(fields, types));
|
||||||
}
|
}
|
||||||
|
template <typename T>
|
||||||
|
C10_ERASE static c10::DynamicTypePtr createNamed(const std::string& name) {
|
||||||
|
return std::make_shared<c10::DynamicType>(
|
||||||
|
c10::DynamicTypeTrait<T>::tagValue(),
|
||||||
|
name,
|
||||||
|
c10::DynamicType::Arguments{});
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
C10_ERASE static c10::DynamicTypePtr get() {
|
||||||
|
return DynamicTypeTrait<T>::getBaseType();
|
||||||
|
}
|
||||||
static const std::unordered_map<std::string, c10::TypePtr>& basePythonTypes();
|
static const std::unordered_map<std::string, c10::TypePtr>& basePythonTypes();
|
||||||
};
|
};
|
||||||
|
|
||||||
|
using DynamicTypeFactory = TypeFactoryBase<c10::DynamicType>;
|
||||||
|
|
||||||
// Helper functions for constructing DynamicTypes inline.
|
// Helper functions for constructing DynamicTypes inline.
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
std::enable_if_t<DynamicTypeTrait<T>::isBaseType, int> = 0>
|
std::enable_if_t<DynamicTypeTrait<T>::isBaseType, int> = 0>
|
||||||
DynamicTypePtr dynT() {
|
C10_ERASE DynamicTypePtr dynT() {
|
||||||
return DynamicTypeTrait<T>::getBaseType();
|
return DynamicTypeFactory::get<T>();
|
||||||
}
|
}
|
||||||
|
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
typename... Args,
|
typename... Args,
|
||||||
std::enable_if_t<!DynamicTypeTrait<T>::isBaseType, int> = 0>
|
std::enable_if_t<!DynamicTypeTrait<T>::isBaseType, int> = 0>
|
||||||
DynamicTypePtr dynT(Args&&... args) {
|
C10_ERASE DynamicTypePtr dynT(Args&&... args) {
|
||||||
return DynamicTypeFactory::create<T>(std::forward<Args>(args)...);
|
return DynamicTypeFactory::create<T>(std::forward<Args>(args)...);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct TORCH_API DefaultTypeFactory {
|
template <>
|
||||||
|
struct TORCH_API TypeFactoryBase<c10::Type> {
|
||||||
template <typename T, typename... Args>
|
template <typename T, typename... Args>
|
||||||
static c10::TypePtr create(TypePtr ty, Args&&... args) {
|
static c10::TypePtr create(TypePtr ty, Args&&... args) {
|
||||||
return T::create(std::move(ty), std::forward<Args>(args)...);
|
return T::create(std::move(ty), std::forward<Args>(args)...);
|
||||||
|
|
@ -60,18 +81,28 @@ struct TORCH_API DefaultTypeFactory {
|
||||||
static c10::TypePtr createNamedTuple(
|
static c10::TypePtr createNamedTuple(
|
||||||
const std::string& name,
|
const std::string& name,
|
||||||
const std::vector<c10::string_view>& fields,
|
const std::vector<c10::string_view>& fields,
|
||||||
const std::vector<c10::TypePtr>& types) {
|
const std::vector<c10::TypePtr>& types);
|
||||||
return c10::TupleType::createNamed(name, fields, types);
|
template <typename T>
|
||||||
|
C10_ERASE static c10::TypePtr createNamed(const std::string& name) {
|
||||||
|
return T::create(name);
|
||||||
}
|
}
|
||||||
static const std::unordered_map<std::string, c10::TypePtr>& basePythonTypes();
|
static const std::unordered_map<std::string, c10::TypePtr>& basePythonTypes();
|
||||||
|
template <typename T>
|
||||||
|
C10_ERASE static c10::TypePtr get() {
|
||||||
|
return T::get();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
using TypeFactory =
|
using DefaultTypeFactory = TypeFactoryBase<c10::Type>;
|
||||||
|
|
||||||
|
using PlatformType =
|
||||||
#ifdef C10_MOBILE
|
#ifdef C10_MOBILE
|
||||||
DynamicTypeFactory
|
c10::DynamicType
|
||||||
#else
|
#else
|
||||||
DefaultTypeFactory
|
c10::Type
|
||||||
#endif
|
#endif
|
||||||
;
|
;
|
||||||
|
|
||||||
|
using TypeFactory = TypeFactoryBase<PlatformType>;
|
||||||
|
|
||||||
} // namespace c10
|
} // namespace c10
|
||||||
|
|
|
||||||
|
|
@ -225,6 +225,16 @@ using namespace c10::hip;
|
||||||
#define C10_ALWAYS_INLINE inline
|
#define C10_ALWAYS_INLINE inline
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if defined(_MSC_VER)
|
||||||
|
#define C10_ATTR_VISIBILITY_HIDDEN
|
||||||
|
#elif defined(__GNUC__)
|
||||||
|
#define C10_ATTR_VISIBILITY_HIDDEN __attribute__((__visibility__("hidden")))
|
||||||
|
#else
|
||||||
|
#define C10_ATTR_VISIBILITY_HIDDEN
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#define C10_ERASE C10_ALWAYS_INLINE C10_ATTR_VISIBILITY_HIDDEN
|
||||||
|
|
||||||
// C10_FALLTHROUGH - Annotate fallthrough to the next case in a switch.
|
// C10_FALLTHROUGH - Annotate fallthrough to the next case in a switch.
|
||||||
#if C10_HAS_CPP_ATTRIBUTE(fallthrough)
|
#if C10_HAS_CPP_ATTRIBUTE(fallthrough)
|
||||||
#define C10_FALLTHROUGH [[fallthrough]]
|
#define C10_FALLTHROUGH [[fallthrough]]
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,9 @@ namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
|
|
||||||
static inline TypePtr unwrapOptional(TypePtr opt_type) {
|
static inline TypePtr unwrapOptional(TypePtr opt_type) {
|
||||||
|
if (auto dyn = opt_type->castRaw<c10::DynamicType>()) {
|
||||||
|
return unwrapOptional(dyn->fallback());
|
||||||
|
}
|
||||||
if (auto unwrap_list_type = opt_type->cast<OptionalType>()) {
|
if (auto unwrap_list_type = opt_type->cast<OptionalType>()) {
|
||||||
return unwrap_list_type->getElementType();
|
return unwrap_list_type->getElementType();
|
||||||
}
|
}
|
||||||
|
|
@ -282,12 +285,17 @@ static bool varargsCanBeUsedAsList(
|
||||||
bool is_last_argument = arg_index + 1 == schema.arguments().size() ||
|
bool is_last_argument = arg_index + 1 == schema.arguments().size() ||
|
||||||
schema.arguments()[arg_index + 1].kwarg_only();
|
schema.arguments()[arg_index + 1].kwarg_only();
|
||||||
|
|
||||||
|
auto arg_type = arg.type();
|
||||||
|
if (auto dyn = arg_type->castRaw<c10::DynamicType>()) {
|
||||||
|
arg_type = dyn->fallback();
|
||||||
|
}
|
||||||
|
|
||||||
// The formal must be a list
|
// The formal must be a list
|
||||||
bool argument_is_list = arg.type()->kind() == TypeKind::ListType;
|
bool argument_is_list = arg_type->kind() == TypeKind::ListType;
|
||||||
|
|
||||||
// matching varargs of typevar list nyi
|
// matching varargs of typevar list nyi
|
||||||
bool typevar_list = argument_is_list &&
|
bool typevar_list = argument_is_list &&
|
||||||
arg.type()->castRaw<ListType>()->getElementType()->cast<VarType>();
|
arg_type->castRaw<ListType>()->getElementType()->cast<VarType>();
|
||||||
|
|
||||||
// it must not be a broadcasting list like int[3],
|
// it must not be a broadcasting list like int[3],
|
||||||
// otherwise a single int is a valid input
|
// otherwise a single int is a valid input
|
||||||
|
|
|
||||||
|
|
@ -41,32 +41,33 @@ namespace jit {
|
||||||
|
|
||||||
TypePtr SchemaTypeParser::parseBaseType() {
|
TypePtr SchemaTypeParser::parseBaseType() {
|
||||||
static std::unordered_map<std::string, TypePtr> type_map = {
|
static std::unordered_map<std::string, TypePtr> type_map = {
|
||||||
{"Generator", GeneratorType::get()},
|
{"Generator", c10::TypeFactory::get<GeneratorType>()},
|
||||||
{"Dimname", StringType::get()},
|
{"Dimname", c10::TypeFactory::get<StringType>()},
|
||||||
{"ScalarType", IntType::get()},
|
{"ScalarType", c10::TypeFactory::get<IntType>()},
|
||||||
{"Layout", IntType::get()},
|
{"Layout", c10::TypeFactory::get<IntType>()},
|
||||||
{"MemoryFormat", IntType::get()},
|
{"MemoryFormat", c10::TypeFactory::get<IntType>()},
|
||||||
{"Storage", StorageType::get()},
|
{"Storage", c10::TypeFactory::get<StorageType>()},
|
||||||
{"QScheme", QSchemeType::get()},
|
{"QScheme", c10::TypeFactory::get<QSchemeType>()},
|
||||||
{"Quantizer", QuantizerType::get()},
|
{"Quantizer", c10::TypeFactory::get<QuantizerType>()},
|
||||||
{"ConstQuantizerPtr",
|
{"ConstQuantizerPtr",
|
||||||
IntType::get()}, // TODO This type should be removed from the schema
|
c10::TypeFactory::get<IntType>()}, // TODO This type should be removed
|
||||||
// parser, it should use the custom class mechanism
|
// from the schema parser, it should
|
||||||
// instead. @jerryzh
|
// use the custom class mechanism
|
||||||
{"Device", DeviceObjType::get()},
|
// instead. @jerryzh
|
||||||
{"Stream", StreamObjType::get()},
|
{"Device", c10::TypeFactory::get<DeviceObjType>()},
|
||||||
{"Scalar", NumberType::get()},
|
{"Stream", c10::TypeFactory::get<StreamObjType>()},
|
||||||
{"str", StringType::get()},
|
{"Scalar", c10::TypeFactory::get<NumberType>()},
|
||||||
{"float", FloatType::get()},
|
{"str", c10::TypeFactory::get<StringType>()},
|
||||||
{"complex", ComplexType::get()},
|
{"float", c10::TypeFactory::get<FloatType>()},
|
||||||
{"int", IntType::get()},
|
{"complex", c10::TypeFactory::get<ComplexType>()},
|
||||||
{"bool", BoolType::get()},
|
{"int", c10::TypeFactory::get<IntType>()},
|
||||||
{"None", NoneType::get()},
|
{"bool", c10::TypeFactory::get<BoolType>()},
|
||||||
{"NoneType", NoneType::get()},
|
{"None", c10::TypeFactory::get<NoneType>()},
|
||||||
{"Capsule", CapsuleType::get()},
|
{"NoneType", c10::TypeFactory::get<NoneType>()},
|
||||||
{"Any", at::AnyType::get()},
|
{"Capsule", c10::TypeFactory::get<CapsuleType>()},
|
||||||
{"AnyClassType", at::AnyClassType::get()},
|
{"Any", c10::TypeFactory::get<c10::AnyType>()},
|
||||||
{"AnyEnumType", at::AnyEnumType::get()},
|
{"AnyClassType", c10::TypeFactory::get<c10::AnyClassType>()},
|
||||||
|
{"AnyEnumType", c10::TypeFactory::get<c10::AnyEnumType>()},
|
||||||
};
|
};
|
||||||
auto tok = L.cur();
|
auto tok = L.cur();
|
||||||
if (!L.nextIf(TK_NONE) && !L.nextIf(TK_NONE_TYPE)) {
|
if (!L.nextIf(TK_NONE) && !L.nextIf(TK_NONE_TYPE)) {
|
||||||
|
|
@ -79,7 +80,7 @@ TypePtr SchemaTypeParser::parseBaseType() {
|
||||||
if (text.size() > 0 && islower(text[0])) {
|
if (text.size() > 0 && islower(text[0])) {
|
||||||
// lower case identifiers that are not otherwise valid types
|
// lower case identifiers that are not otherwise valid types
|
||||||
// are treated as type variables
|
// are treated as type variables
|
||||||
return VarType::create(text);
|
return c10::TypeFactory::createNamed<VarType>(text);
|
||||||
}
|
}
|
||||||
throw ErrorReport(tok.range) << "unknown type specifier";
|
throw ErrorReport(tok.range) << "unknown type specifier";
|
||||||
}
|
}
|
||||||
|
|
@ -313,7 +314,7 @@ std::pair<TypePtr, c10::optional<AliasInfo>> SchemaTypeParser::parseType() {
|
||||||
alias_info->addContainedType(std::move(*r.second));
|
alias_info->addContainedType(std::move(*r.second));
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
value = TupleType::create(std::move(types));
|
value = c10::TypeFactory::create<TupleType>(std::move(types));
|
||||||
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Future") {
|
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Future") {
|
||||||
L.next(); // Future
|
L.next(); // Future
|
||||||
L.expect('(');
|
L.expect('(');
|
||||||
|
|
@ -321,7 +322,7 @@ std::pair<TypePtr, c10::optional<AliasInfo>> SchemaTypeParser::parseType() {
|
||||||
auto subtype = std::move(p.first);
|
auto subtype = std::move(p.first);
|
||||||
auto subalias = std::move(p.second);
|
auto subalias = std::move(p.second);
|
||||||
L.expect(')');
|
L.expect(')');
|
||||||
value = FutureType::create(subtype);
|
value = c10::TypeFactory::create<FutureType>(subtype);
|
||||||
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "RRef") {
|
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "RRef") {
|
||||||
L.next(); // RRef
|
L.next(); // RRef
|
||||||
L.expect('(');
|
L.expect('(');
|
||||||
|
|
@ -329,10 +330,10 @@ std::pair<TypePtr, c10::optional<AliasInfo>> SchemaTypeParser::parseType() {
|
||||||
auto subtype = std::move(p.first);
|
auto subtype = std::move(p.first);
|
||||||
auto subalias = std::move(p.second);
|
auto subalias = std::move(p.second);
|
||||||
L.expect(')');
|
L.expect(')');
|
||||||
value = RRefType::create(subtype);
|
value = c10::TypeFactory::create<RRefType>(subtype);
|
||||||
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Tensor") {
|
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Tensor") {
|
||||||
L.next();
|
L.next();
|
||||||
value = TensorType::get();
|
value = c10::TypeFactory::get<TensorType>();
|
||||||
alias_info = parseAliasAnnotation();
|
alias_info = parseAliasAnnotation();
|
||||||
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Dict") {
|
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Dict") {
|
||||||
L.next();
|
L.next();
|
||||||
|
|
@ -342,7 +343,7 @@ std::pair<TypePtr, c10::optional<AliasInfo>> SchemaTypeParser::parseType() {
|
||||||
auto value_type = parseType().first;
|
auto value_type = parseType().first;
|
||||||
L.expect(')');
|
L.expect(')');
|
||||||
alias_info = parseAliasAnnotation();
|
alias_info = parseAliasAnnotation();
|
||||||
value = DictType::create(key_type, value_type);
|
value = c10::TypeFactory::create<DictType>(key_type, value_type);
|
||||||
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Union") {
|
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Union") {
|
||||||
L.next();
|
L.next();
|
||||||
L.expect('(');
|
L.expect('(');
|
||||||
|
|
@ -395,7 +396,7 @@ std::pair<TypePtr, c10::optional<AliasInfo>> SchemaTypeParser::parseType() {
|
||||||
if (L.cur().kind == '[' && L.lookahead().kind == ']') {
|
if (L.cur().kind == '[' && L.lookahead().kind == ']') {
|
||||||
L.next(); // [
|
L.next(); // [
|
||||||
L.next(); // ]
|
L.next(); // ]
|
||||||
value = ListType::create(value);
|
value = c10::TypeFactory::create<ListType>(value);
|
||||||
auto container = parseAliasAnnotation();
|
auto container = parseAliasAnnotation();
|
||||||
if (container && alias_info) {
|
if (container && alias_info) {
|
||||||
container->addContainedType(std::move(*alias_info));
|
container->addContainedType(std::move(*alias_info));
|
||||||
|
|
|
||||||
|
|
@ -1485,6 +1485,9 @@ inline Value::Value(Node* node_, size_t offset_)
|
||||||
|
|
||||||
inline Value* Value::setType(TypePtr type) {
|
inline Value* Value::setType(TypePtr type) {
|
||||||
AT_ASSERT(type);
|
AT_ASSERT(type);
|
||||||
|
if (auto dyn = type->castRaw<c10::DynamicType>()) {
|
||||||
|
type = dyn->fallback();
|
||||||
|
}
|
||||||
type_ = std::move(type);
|
type_ = std::move(type);
|
||||||
for (Use& use : uses_) {
|
for (Use& use : uses_) {
|
||||||
use.user->op_ = nullptr;
|
use.user->op_ = nullptr;
|
||||||
|
|
|
||||||
|
|
@ -108,7 +108,11 @@ std::pair<IValue, IValue> getFunctionTuple(
|
||||||
static const std::string torch_prefix("__torch__");
|
static const std::string torch_prefix("__torch__");
|
||||||
static const std::string class_prefix("__torch__.torch.classes");
|
static const std::string class_prefix("__torch__.torch.classes");
|
||||||
|
|
||||||
for (const TypePtr& t : mobile_code.types_) {
|
for (const TypePtr& ty : mobile_code.types_) {
|
||||||
|
auto t = ty;
|
||||||
|
if (auto dyn = t->castRaw<c10::DynamicType>()) {
|
||||||
|
t = dyn->fallback();
|
||||||
|
}
|
||||||
std::string type_str = t->annotation_str();
|
std::string type_str = t->annotation_str();
|
||||||
if (t->kind() == TypeKind::TupleType) {
|
if (t->kind() == TypeKind::TupleType) {
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
|
|
@ -216,9 +220,13 @@ std::pair<IValue, IValue> getFunctionTuple(
|
||||||
arg.type()->annotation_str(type_printer) => mangled unique name of the
|
arg.type()->annotation_str(type_printer) => mangled unique name of the
|
||||||
module/submodule
|
module/submodule
|
||||||
*/
|
*/
|
||||||
|
auto arg_type = arg.type();
|
||||||
|
if (auto dyn = arg_type->castRaw<c10::DynamicType>()) {
|
||||||
|
arg_type = dyn->fallback();
|
||||||
|
}
|
||||||
argTables.emplace_back(Table({
|
argTables.emplace_back(Table({
|
||||||
{"name", arg.name()},
|
{"name", arg.name()},
|
||||||
{"type", arg.type()->annotation_str(type_printer)},
|
{"type", arg_type->annotation_str(type_printer)},
|
||||||
{"default_value", arg.default_value()},
|
{"default_value", arg.default_value()},
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -575,7 +575,11 @@ void Pickler::endTypeTag(const IValue& ivalue) {
|
||||||
|
|
||||||
// Push the dict type
|
// Push the dict type
|
||||||
TORCH_INTERNAL_ASSERT(ivalue.type());
|
TORCH_INTERNAL_ASSERT(ivalue.type());
|
||||||
pushString(ivalue.type()->annotation_str());
|
auto type = ivalue.type();
|
||||||
|
if (auto dyn = type->castRaw<c10::DynamicType>()) {
|
||||||
|
type = dyn->fallback();
|
||||||
|
}
|
||||||
|
pushString(type->annotation_str());
|
||||||
|
|
||||||
// Pop the dict and type into a tuple
|
// Pop the dict and type into a tuple
|
||||||
push<PickleOpCode>(PickleOpCode::TUPLE2);
|
push<PickleOpCode>(PickleOpCode::TUPLE2);
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@ static void restoreAccurateTypeTagsIfPossible(const IValue& root) {
|
||||||
// of the contained objects and cannot restore the tags.
|
// of the contained objects and cannot restore the tags.
|
||||||
void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
|
void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
|
||||||
struct Work {
|
struct Work {
|
||||||
TypePtr static_type;
|
TypePtr type;
|
||||||
IValue value;
|
IValue value;
|
||||||
};
|
};
|
||||||
std::vector<Work> to_process = {{type_tag, root}};
|
std::vector<Work> to_process = {{type_tag, root}};
|
||||||
|
|
@ -53,7 +53,11 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
|
||||||
}
|
}
|
||||||
scanned.emplace_hint(it, key);
|
scanned.emplace_hint(it, key);
|
||||||
}
|
}
|
||||||
switch (w.static_type->kind()) {
|
auto kind = w.type->kind();
|
||||||
|
if (auto dyn = w.type->castRaw<c10::DynamicType>()) {
|
||||||
|
kind = dyn->dynamicKind();
|
||||||
|
}
|
||||||
|
switch (kind) {
|
||||||
case TensorType::Kind:
|
case TensorType::Kind:
|
||||||
case StorageType::Kind:
|
case StorageType::Kind:
|
||||||
case NumberType::Kind:
|
case NumberType::Kind:
|
||||||
|
|
@ -83,52 +87,37 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
|
||||||
// no op, there is nothing to tag
|
// no op, there is nothing to tag
|
||||||
break;
|
break;
|
||||||
case DynamicType::Kind:
|
case DynamicType::Kind:
|
||||||
|
case UnionType::Kind:
|
||||||
case EnumType::Kind:
|
case EnumType::Kind:
|
||||||
// TODO(gmagogsfm): Implement serialization/deserialization of Enum.
|
// TODO(gmagogsfm): Implement serialization/deserialization of Enum.
|
||||||
TORCH_INTERNAL_ASSERT(false);
|
TORCH_INTERNAL_ASSERT(false);
|
||||||
case TupleType::Kind: {
|
case TupleType::Kind: {
|
||||||
auto t = w.value.toTuple();
|
auto t = w.value.toTuple();
|
||||||
auto ttype = w.static_type->expect<TupleType>();
|
for (size_t i = 0; i < w.type->containedTypeSize(); ++i) {
|
||||||
for (size_t i = 0; i < ttype->containedTypes().size(); ++i) {
|
Work elem = {w.type->containedType(i), t->elements().at(i)};
|
||||||
Work elem = {ttype->containedTypes().at(i), t->elements().at(i)};
|
|
||||||
to_process.emplace_back(std::move(elem));
|
to_process.emplace_back(std::move(elem));
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case FutureType::Kind: {
|
case FutureType::Kind: {
|
||||||
auto f = w.value.toFuture();
|
auto f = w.value.toFuture();
|
||||||
auto t = w.static_type->expect<FutureType>();
|
|
||||||
if (f->completed()) {
|
if (f->completed()) {
|
||||||
Work elem = {t->getElementType(), f->value()};
|
Work elem = {w.type->containedType(0), f->value()};
|
||||||
to_process.emplace_back(std::move(elem));
|
to_process.emplace_back(std::move(elem));
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case OptionalType::Kind: {
|
case OptionalType::Kind: {
|
||||||
if (!w.value.isNone()) {
|
if (!w.value.isNone()) {
|
||||||
auto t = w.static_type->expect<OptionalType>();
|
Work elem = {w.type->containedType(0), w.value};
|
||||||
Work elem = {t->getElementType(), w.value};
|
|
||||||
to_process.emplace_back(std::move(elem));
|
to_process.emplace_back(std::move(elem));
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case UnionType::Kind: {
|
|
||||||
auto t = w.static_type->expect<UnionType>();
|
|
||||||
if (t->containedTypes().size() == 2 &&
|
|
||||||
t->canHoldType(*NoneType::get())) {
|
|
||||||
if (!w.value.isNone()) {
|
|
||||||
auto inner = t->containedTypes()[0] != NoneType::get()
|
|
||||||
? t->containedTypes()[0]
|
|
||||||
: t->containedTypes()[1];
|
|
||||||
Work elem = {inner, w.value};
|
|
||||||
to_process.emplace_back(std::move(elem));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} break;
|
|
||||||
case ListType::Kind: {
|
case ListType::Kind: {
|
||||||
// specialized lists do not need their type refined, so we can exit
|
// specialized lists do not need their type refined, so we can exit
|
||||||
// early here
|
// early here
|
||||||
if (!w.value.isList()) {
|
if (!w.value.isList()) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
auto elem_type = w.static_type->castRaw<ListType>()->getElementType();
|
auto elem_type = w.type->containedType(0);
|
||||||
auto lst = w.value.toList();
|
auto lst = w.value.toList();
|
||||||
lst.unsafeSetElementType(elem_type);
|
lst.unsafeSetElementType(elem_type);
|
||||||
for (const IValue item : lst) {
|
for (const IValue item : lst) {
|
||||||
|
|
@ -137,13 +126,14 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case DictType::Kind: {
|
case DictType::Kind: {
|
||||||
auto dt = w.static_type->cast<DictType>();
|
|
||||||
auto d = w.value.toGenericDict();
|
auto d = w.value.toGenericDict();
|
||||||
d.unsafeSetKeyType(dt->getKeyType());
|
auto keyType = w.type->containedType(0);
|
||||||
d.unsafeSetValueType(dt->getValueType());
|
auto valType = w.type->containedType(1);
|
||||||
|
d.unsafeSetKeyType(keyType);
|
||||||
|
d.unsafeSetValueType(valType);
|
||||||
for (const auto& item : d) {
|
for (const auto& item : d) {
|
||||||
Work kelem = {dt->getKeyType(), item.key()};
|
Work kelem = {keyType, item.key()};
|
||||||
Work velem = {dt->getValueType(), item.value()};
|
Work velem = {valType, item.value()};
|
||||||
to_process.emplace_back(std::move(kelem));
|
to_process.emplace_back(std::move(kelem));
|
||||||
to_process.emplace_back(std::move(velem));
|
to_process.emplace_back(std::move(velem));
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user