[jit][edge] Migrate to TypeFactory for jit types on mobile (#71516)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71516

Mobile should be able to contruct dynamic types by default.
ghstack-source-id: 147498365

Test Plan:
CI.

**-48KB** binary size reduction for igios BSB.
UMBEX link: https://www.internalfb.com/intern/unigraph/explorer/?jsgq_traversal_spec=%7B%22builds%22%3A[%22bsb%3A422553426218394%5Cu0040base%22%2C%22bsb%3A422553426218394%5Cu0040diff%22]%7D&unigraph_project=UnigraphProjectMbex&is_mbex_redirected

Reviewed By: iseeyuan

Differential Revision: D33673958

fbshipit-source-id: 8600c04ae929283681971aae264d3774188df9cd
(cherry picked from commit 64ebcec09e)
This commit is contained in:
Zhengxu Chen 2022-01-25 22:58:45 -08:00 committed by PyTorch MergeBot
parent e5794974cb
commit fe277b8717
20 changed files with 236 additions and 140 deletions

View File

@ -308,7 +308,7 @@ c10::intrusive_ptr<c10::ivalue::Future> intraop_launch_future(
#else #else
// TODO: caffe2::PThreadPool only provides a data-parallel API. // TODO: caffe2::PThreadPool only provides a data-parallel API.
// Task parallelism is not currently supported. // Task parallelism is not currently supported.
auto future = c10::make_intrusive<c10::ivalue::Future>(NoneType::get()); auto future = c10::make_intrusive<c10::ivalue::Future>(c10::dynT<NoneType>());
func(); func();
future->markCompleted(); future->markCompleted();
return future; return future;

View File

@ -3,6 +3,7 @@
#include <ATen/core/jit_type.h> #include <ATen/core/jit_type.h>
#include <ATen/core/function_schema.h> #include <ATen/core/function_schema.h>
#include <ATen/core/functional.h> #include <ATen/core/functional.h>
#include <ATen/core/type_factory.h>
#include <atomic> #include <atomic>
#include <unordered_map> #include <unordered_map>
@ -102,7 +103,7 @@ class_base::class_base(
{ {
detail::checkValidIdent(namespaceName, "Namespace name"); detail::checkValidIdent(namespaceName, "Namespace name");
detail::checkValidIdent(className, "Class name"); detail::checkValidIdent(className, "Class name");
classTypePtr->addAttribute("capsule", at::CapsuleType::get()); classTypePtr->addAttribute("capsule", c10::TypeFactory::get<c10::CapsuleType>());
c10::getCustomClassTypeMap().insert( c10::getCustomClassTypeMap().insert(
{std::type_index(intrusivePtrClassTypeid), classTypePtr}); {std::type_index(intrusivePtrClassTypeid), classTypePtr});
c10::getCustomClassTypeMap().insert( c10::getCustomClassTypeMap().insert(

View File

@ -2,6 +2,7 @@
#include <string> #include <string>
#include <ATen/core/class_type.h>
#include <ATen/core/ivalue.h> #include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h> #include <ATen/core/jit_type.h>
#include <ATen/core/type_factory.h> #include <ATen/core/type_factory.h>
@ -198,6 +199,11 @@ TypePtr DynamicType::containedType(size_t i) const {
return arguments_.elems.at(i).ty; return arguments_.elems.at(i).ty;
} }
size_t DynamicType::containedTypeSize() const {
TORCH_INTERNAL_ASSERT(tag_ != Tag::Class);
return arguments_.elems.size();
}
TypeKind DynamicType::dynamicKind() const { TypeKind DynamicType::dynamicKind() const {
switch (tag_) { switch (tag_) {
#define CASE_TYPE(T, _, __) \ #define CASE_TYPE(T, _, __) \
@ -271,6 +277,16 @@ TypePtr DynamicType::fallback() const {
return VarType::create(*name_); return VarType::create(*name_);
case Tag::AnyClass: case Tag::AnyClass:
return AnyClassType::get(); return AnyClassType::get();
case Tag::QScheme:
return QSchemeType::get();
case Tag::Quantizer:
return QuantizerType::get();
case Tag::AnyEnum:
return AnyEnumType::get();
case Tag::RRef:
return RRefType::create(arguments_.elems[0].ty->fallback());
case Tag::Future:
return FutureType::create(arguments_.elems[0].ty->fallback());
case Tag::Any: case Tag::Any:
return AnyType::get(); return AnyType::get();
} }

View File

@ -3,8 +3,6 @@
#include <memory> #include <memory>
#include <type_traits> #include <type_traits>
#include <ATen/core/class_type.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type_base.h> #include <ATen/core/jit_type_base.h>
#include <c10/util/Optional.h> #include <c10/util/Optional.h>
@ -53,8 +51,17 @@ constexpr DynamicTypeBits kDynamicClassTypeBit = DYNAMIC_TYPE_BIT(10);
_(Storage, DYNAMIC_TYPE_BIT(16), 1) \ _(Storage, DYNAMIC_TYPE_BIT(16), 1) \
_(Var, DYNAMIC_TYPE_BIT(17), 0) \ _(Var, DYNAMIC_TYPE_BIT(17), 0) \
_(AnyClass, (kDynamicClassTypeBit | kDynamicAnyTypeBit), 1) \ _(AnyClass, (kDynamicClassTypeBit | kDynamicAnyTypeBit), 1) \
_(QScheme, DYNAMIC_TYPE_BIT(18), 1) \
_(Quantizer, DYNAMIC_TYPE_BIT(19), 1) \
_(AnyEnum, DYNAMIC_TYPE_BIT(20), 1) \
_(RRef, DYNAMIC_TYPE_BIT(21), 0) \
_(Future, DYNAMIC_TYPE_BIT(22), 0) \
_(Any, 0xffffffff, 1) _(Any, 0xffffffff, 1)
#define FORWARD_DECL_TYPE(NAME, _, __) struct NAME ## Type;
FORALL_DYNAMIC_TYPES(FORWARD_DECL_TYPE)
#undef FORWARD_DECL_TYPE
class DynamicType; class DynamicType;
using DynamicTypePtr = std::shared_ptr<DynamicType>; using DynamicTypePtr = std::shared_ptr<DynamicType>;
@ -142,6 +149,7 @@ class DynamicType : public SharedType {
explicit DynamicType(Tag, c10::string_view, Arguments); explicit DynamicType(Tag, c10::string_view, Arguments);
TypePtr containedType(size_t) const override; TypePtr containedType(size_t) const override;
size_t containedTypeSize() const override;
Tag tag() const { Tag tag() const {
return tag_; return tag_;
} }
@ -154,6 +162,9 @@ class DynamicType : public SharedType {
TypeKind dynamicKind() const; TypeKind dynamicKind() const;
// Should be used only on the server side to restore static type information. // Should be used only on the server side to restore static type information.
#ifndef C10_MOBILE
TORCH_API
#endif
TypePtr fallback() const; TypePtr fallback() const;
private: private:
@ -188,7 +199,7 @@ class DynamicType : public SharedType {
template <typename T> template <typename T>
struct DynamicTypeTrait { struct DynamicTypeTrait {
static auto tagValue() { C10_NOINLINE static auto tagValue() {
TORCH_CHECK(false); TORCH_CHECK(false);
return DynamicType::Tag::Any; return DynamicType::Tag::Any;
} }
@ -201,7 +212,7 @@ C10_NOINLINE DynamicTypePtr makeBaseType(DynamicType::Tag tag);
#define DYNAMIC_TYPE_TAG_VALUE(NAME, _, IS_BASE_TYPE) \ #define DYNAMIC_TYPE_TAG_VALUE(NAME, _, IS_BASE_TYPE) \
template <> \ template <> \
struct TORCH_API DynamicTypeTrait<NAME##Type> { \ struct TORCH_API DynamicTypeTrait<NAME##Type> { \
static auto tagValue() { \ C10_ERASE static auto tagValue() { \
return DynamicType::Tag::NAME; \ return DynamicType::Tag::NAME; \
} \ } \
static constexpr bool isBaseType = IS_BASE_TYPE; \ static constexpr bool isBaseType = IS_BASE_TYPE; \
@ -214,19 +225,4 @@ C10_NOINLINE DynamicTypePtr makeBaseType(DynamicType::Tag tag);
FORALL_DYNAMIC_TYPES(DYNAMIC_TYPE_TAG_VALUE) FORALL_DYNAMIC_TYPES(DYNAMIC_TYPE_TAG_VALUE)
#undef DYNAMIC_TYPE_TAG_VALUE #undef DYNAMIC_TYPE_TAG_VALUE
template <>
struct IValue::TagType<c10::DynamicType> {
static DynamicType::Ptr get(const c10::IValue& v);
};
namespace ivalue {
template <>
struct TORCH_API TupleTypeFactory<c10::DynamicType> {
static DynamicTypePtr create(std::vector<TypePtr> elemTypes);
static DynamicTypePtr fallback(const Type&);
};
} // namespace ivalue
} // namespace c10 } // namespace c10

View File

@ -390,7 +390,7 @@ struct FunctionSchema {
// Check that inputs have the correct types and appends any missing default // Check that inputs have the correct types and appends any missing default
// values. // values.
template <typename T = c10::Type> template <typename T = c10::PlatformType>
void checkAndNormalizeInputs( void checkAndNormalizeInputs(
std::vector<IValue>& inputs, std::vector<IValue>& inputs,
const std::unordered_map<std::string, IValue>& kwargs = const std::unordered_map<std::string, IValue>& kwargs =

View File

@ -293,7 +293,7 @@ inline void FunctionSchema::checkArg(
TORCH_CHECK( TORCH_CHECK(
false, false,
formatTypeMismatchMsg( formatTypeMismatchMsg(
argument, value.type()->repr_str(), pos)); argument, value.type<T>()->repr_str(), pos));
} }
} }

View File

@ -6,6 +6,7 @@
#include <ATen/core/function.h> #include <ATen/core/function.h>
#include <ATen/core/jit_type.h> #include <ATen/core/jit_type.h>
#include <ATen/core/stack.h> #include <ATen/core/stack.h>
#include <ATen/core/type_factory.h>
#include <c10/util/irange.h> #include <c10/util/irange.h>
#include <c10/util/StringUtil.h> #include <c10/util/StringUtil.h>
#include <c10/util/hash.h> #include <c10/util/hash.h>
@ -403,6 +404,39 @@ bool IValue::is(const IValue& rhs) const {
return lhs == rhs; return lhs == rhs;
} }
template <typename T>
inline bool IValue::isListOf() const {
// note: avoids calling type() to avoid extra referencing counting for the returned type.
if (!isList()) {
return false;
}
const auto& ty = static_cast<detail::ListImpl*>(payload.u.as_intrusive_ptr)->elementType;
if (ty->kind() == T::Kind) {
return true;
}
return *ty == *TypeFactory::get<T>();
}
bool IValue::isDoubleList() const {
return isListOf<c10::FloatType>();
}
bool IValue::isComplexDoubleList() const {
return isListOf<c10::ComplexType>();
}
bool IValue::isTensorList() const {
return isListOf<c10::TensorType>();
}
bool IValue::isIntList() const {
return isListOf<c10::IntType>();
}
bool IValue::isBoolList() const {
return isListOf<c10::BoolType>();
}
namespace { namespace {
using IValueFormatter = std::function<void(std::ostream&, const IValue&)>; using IValueFormatter = std::function<void(std::ostream&, const IValue&)>;
@ -430,7 +464,7 @@ std::ostream& printMaybeAnnotatedList(
std::ostream& out, std::ostream& out,
const IValue& the_list, const IValue& the_list,
IValueFormatter formatter) { IValueFormatter formatter) {
auto list_elem_type = the_list.type()->expectRef<ListType>().getElementType(); auto list_elem_type = the_list.type()->containedType(0);
if (the_list.toListRef().size() == 0 || if (the_list.toListRef().size() == 0 ||
!elementTypeCanBeInferredFromMembers(list_elem_type)) { !elementTypeCanBeInferredFromMembers(list_elem_type)) {
out << "annotate(" << the_list.type()->annotation_str() << ", "; out << "annotate(" << the_list.type()->annotation_str() << ", ";
@ -925,7 +959,7 @@ c10::intrusive_ptr<ivalue::Object> ivalue::Object::deepcopy(IValue::HashAliasedI
auto cu = type_.cu_; auto cu = type_.cu_;
auto object = ivalue::Object::create(WeakOrStrongTypePtr(type_.cu_, type_.type_), type()->numAttributes()); auto object = ivalue::Object::create(WeakOrStrongTypePtr(type_.cu_, type_.type_), type()->numAttributes());
for (const auto i : c10::irange(slots_.size())) { for (const auto i : c10::irange(slots_.size())) {
if (slots_[i].type() == c10::CapsuleType::get()) { if (*slots_[i].type() == *c10::TypeFactory::get<CapsuleType>()) {
// If we've gotten here, it means that we have *not* copied this // If we've gotten here, it means that we have *not* copied this
// class via __getstate__ and __setstate__. That fact and the // class via __getstate__ and __setstate__. That fact and the
// fact that we have a Capsule attribute mean that this is a // fact that we have a Capsule attribute mean that this is a

View File

@ -6,6 +6,7 @@
#include <ATen/core/custom_class.h> #include <ATen/core/custom_class.h>
#include <ATen/core/ivalue_to.h> #include <ATen/core/ivalue_to.h>
#include <ATen/core/jit_type_base.h> #include <ATen/core/jit_type_base.h>
#include <ATen/core/type_factory.h>
#include <c10/util/C++17.h> #include <c10/util/C++17.h>
#include <c10/util/MaybeOwned.h> #include <c10/util/MaybeOwned.h>
#include <c10/util/intrusive_ptr.h> #include <c10/util/intrusive_ptr.h>
@ -895,8 +896,8 @@ public:
} }
} }
template <typename T = c10::Type> template <typename T = c10::PlatformType>
typename T::Ptr type() const; TypePtr type() const;
// Detect aliased tensors. // Detect aliased tensors.
struct HashAliasedIValue { struct HashAliasedIValue {

View File

@ -586,6 +586,12 @@ struct TORCH_API TupleTypeFactory<TupleType> {
static TupleTypePtr fallback(const Type& type); static TupleTypePtr fallback(const Type& type);
}; };
template <>
struct TORCH_API TupleTypeFactory<c10::DynamicType> {
static DynamicTypePtr create(std::vector<TypePtr> elemTypes);
static DynamicTypePtr fallback(const Type&);
};
struct TORCH_API Tuple : c10::intrusive_ptr_target { struct TORCH_API Tuple : c10::intrusive_ptr_target {
private: private:
TupleElements elements_; TupleElements elements_;
@ -1915,39 +1921,6 @@ inline ivalue::Tuple& IValue::toTupleRef() const {
payload.u.as_intrusive_ptr); payload.u.as_intrusive_ptr);
} }
template <typename T>
inline bool IValue::isListOf() const {
// note: avoids calling type() to avoid extra referencing counting for the returned type.
if (!isList()) {
return false;
}
const auto& ty = static_cast<detail::ListImpl*>(payload.u.as_intrusive_ptr)->elementType;
if (ty->kind() == T::Kind) {
return true;
}
return *ty == *T::get();
}
inline bool IValue::isDoubleList() const {
return isListOf<c10::FloatType>();
}
inline bool IValue::isComplexDoubleList() const {
return isListOf<c10::ComplexType>();
}
inline bool IValue::isTensorList() const {
return isListOf<c10::TensorType>();
}
inline bool IValue::isIntList() const {
return isListOf<c10::IntType>();
}
inline bool IValue::isBoolList() const {
return isListOf<c10::BoolType>();
}
inline IValue::IValue(c10::intrusive_ptr<ivalue::Tuple> v) inline IValue::IValue(c10::intrusive_ptr<ivalue::Tuple> v)
: tag(Tag::Tuple), is_intrusive_ptr(true) { : tag(Tag::Tuple), is_intrusive_ptr(true) {
payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
@ -2285,8 +2258,13 @@ struct IValue::TagType<c10::Type> {
static TORCH_API c10::TypePtr get(const IValue&); static TORCH_API c10::TypePtr get(const IValue&);
}; };
template <>
struct IValue::TagType<c10::DynamicType> {
static TORCH_API c10::TypePtr get(const IValue&);
};
template <typename T> template <typename T>
typename T::Ptr IValue::type() const { TypePtr IValue::type() const {
return IValue::TagType<T>::get(*this); return IValue::TagType<T>::get(*this);
} }

View File

@ -5,6 +5,7 @@
#include <ATen/core/TensorBody.h> #include <ATen/core/TensorBody.h>
#include <ATen/core/functional.h> #include <ATen/core/functional.h>
#include <ATen/core/symbol.h> #include <ATen/core/symbol.h>
#include <ATen/core/type_factory.h>
#include <ATen/core/qualified_name.h> #include <ATen/core/qualified_name.h>
#include <c10/util/TypeList.h> #include <c10/util/TypeList.h>
#include <c10/util/Optional.h> #include <c10/util/Optional.h>
@ -1730,7 +1731,8 @@ struct getTypePtr_<c10::QScheme> final {
template <> template <>
struct getTypePtr_<at::Generator> final { struct getTypePtr_<at::Generator> final {
static decltype(auto) call() { static decltype(auto) call() {
return OptionalType::create(GeneratorType::get()); return TypeFactory::create<OptionalType>(
TypeFactory::get<GeneratorType>());
} }
}; };
template <> template <>
@ -1798,7 +1800,8 @@ struct getTypePtr_<c10::Dict<K, V>> final {
template <class T> template <class T>
struct getTypePtr_<at::optional<T>> final { struct getTypePtr_<at::optional<T>> final {
static const auto& call() { static const auto& call() {
static auto type = OptionalType::create(getTypePtr_<T>::call()); static auto type = TypeFactory::create<OptionalType>(
getTypePtr_<T>::call());
return type; return type;
} }
}; };

View File

@ -558,6 +558,9 @@ struct TORCH_API Type {
virtual TypePtr containedType(size_t i) const { virtual TypePtr containedType(size_t i) const {
return containedTypes().at(i); return containedTypes().at(i);
} }
virtual size_t containedTypeSize() const {
return containedTypes().size();
}
// create a new version of this type, replacing its contained types with // create a new version of this type, replacing its contained types with
// contained_types // contained_types
TypePtr withContained(std::vector<TypePtr> contained_types); TypePtr withContained(std::vector<TypePtr> contained_types);

View File

@ -1,5 +1,7 @@
#include <ATen/core/type_factory.h> #include <ATen/core/type_factory.h>
#include <ATen/core/jit_type.h>
namespace c10 { namespace c10 {
// Dtype constraints are not constrained in compilation. Therefore, we map // Dtype constraints are not constrained in compilation. Therefore, we map
@ -56,4 +58,11 @@ const std::unordered_map<std::string, c10::TypePtr>& DefaultTypeFactory::
return map; return map;
} }
c10::TypePtr DefaultTypeFactory::createNamedTuple(
const std::string& name,
const std::vector<c10::string_view>& fields,
const std::vector<c10::TypePtr>& types) {
return c10::TupleType::createNamed(name, fields, types);
}
} // namespace c10 } // namespace c10

View File

@ -1,12 +1,19 @@
#pragma once #pragma once
#include <ATen/core/dynamic_type.h>
#include <ATen/core/jit_type.h>
#include <type_traits> #include <type_traits>
#include <unordered_map>
#include <ATen/core/dynamic_type.h>
#include <ATen/core/jit_type_base.h>
#include <c10/macros/Macros.h>
namespace c10 { namespace c10 {
struct TORCH_API DynamicTypeFactory { template <typename T>
struct TORCH_API TypeFactoryBase {};
template <>
struct TORCH_API TypeFactoryBase<c10::DynamicType> {
template <typename T, typename... Args> template <typename T, typename... Args>
static c10::DynamicTypePtr create(TypePtr ty, Args&&... args) { static c10::DynamicTypePtr create(TypePtr ty, Args&&... args) {
return std::make_shared<c10::DynamicType>( return std::make_shared<c10::DynamicType>(
@ -29,26 +36,40 @@ struct TORCH_API DynamicTypeFactory {
name, name,
c10::DynamicType::Arguments(fields, types)); c10::DynamicType::Arguments(fields, types));
} }
template <typename T>
C10_ERASE static c10::DynamicTypePtr createNamed(const std::string& name) {
return std::make_shared<c10::DynamicType>(
c10::DynamicTypeTrait<T>::tagValue(),
name,
c10::DynamicType::Arguments{});
}
template <typename T>
C10_ERASE static c10::DynamicTypePtr get() {
return DynamicTypeTrait<T>::getBaseType();
}
static const std::unordered_map<std::string, c10::TypePtr>& basePythonTypes(); static const std::unordered_map<std::string, c10::TypePtr>& basePythonTypes();
}; };
using DynamicTypeFactory = TypeFactoryBase<c10::DynamicType>;
// Helper functions for constructing DynamicTypes inline. // Helper functions for constructing DynamicTypes inline.
template < template <
typename T, typename T,
std::enable_if_t<DynamicTypeTrait<T>::isBaseType, int> = 0> std::enable_if_t<DynamicTypeTrait<T>::isBaseType, int> = 0>
DynamicTypePtr dynT() { C10_ERASE DynamicTypePtr dynT() {
return DynamicTypeTrait<T>::getBaseType(); return DynamicTypeFactory::get<T>();
} }
template < template <
typename T, typename T,
typename... Args, typename... Args,
std::enable_if_t<!DynamicTypeTrait<T>::isBaseType, int> = 0> std::enable_if_t<!DynamicTypeTrait<T>::isBaseType, int> = 0>
DynamicTypePtr dynT(Args&&... args) { C10_ERASE DynamicTypePtr dynT(Args&&... args) {
return DynamicTypeFactory::create<T>(std::forward<Args>(args)...); return DynamicTypeFactory::create<T>(std::forward<Args>(args)...);
} }
struct TORCH_API DefaultTypeFactory { template <>
struct TORCH_API TypeFactoryBase<c10::Type> {
template <typename T, typename... Args> template <typename T, typename... Args>
static c10::TypePtr create(TypePtr ty, Args&&... args) { static c10::TypePtr create(TypePtr ty, Args&&... args) {
return T::create(std::move(ty), std::forward<Args>(args)...); return T::create(std::move(ty), std::forward<Args>(args)...);
@ -60,18 +81,28 @@ struct TORCH_API DefaultTypeFactory {
static c10::TypePtr createNamedTuple( static c10::TypePtr createNamedTuple(
const std::string& name, const std::string& name,
const std::vector<c10::string_view>& fields, const std::vector<c10::string_view>& fields,
const std::vector<c10::TypePtr>& types) { const std::vector<c10::TypePtr>& types);
return c10::TupleType::createNamed(name, fields, types); template <typename T>
C10_ERASE static c10::TypePtr createNamed(const std::string& name) {
return T::create(name);
} }
static const std::unordered_map<std::string, c10::TypePtr>& basePythonTypes(); static const std::unordered_map<std::string, c10::TypePtr>& basePythonTypes();
template <typename T>
C10_ERASE static c10::TypePtr get() {
return T::get();
}
}; };
using TypeFactory = using DefaultTypeFactory = TypeFactoryBase<c10::Type>;
using PlatformType =
#ifdef C10_MOBILE #ifdef C10_MOBILE
DynamicTypeFactory c10::DynamicType
#else #else
DefaultTypeFactory c10::Type
#endif #endif
; ;
using TypeFactory = TypeFactoryBase<PlatformType>;
} // namespace c10 } // namespace c10

View File

@ -225,6 +225,16 @@ using namespace c10::hip;
#define C10_ALWAYS_INLINE inline #define C10_ALWAYS_INLINE inline
#endif #endif
#if defined(_MSC_VER)
#define C10_ATTR_VISIBILITY_HIDDEN
#elif defined(__GNUC__)
#define C10_ATTR_VISIBILITY_HIDDEN __attribute__((__visibility__("hidden")))
#else
#define C10_ATTR_VISIBILITY_HIDDEN
#endif
#define C10_ERASE C10_ALWAYS_INLINE C10_ATTR_VISIBILITY_HIDDEN
// C10_FALLTHROUGH - Annotate fallthrough to the next case in a switch. // C10_FALLTHROUGH - Annotate fallthrough to the next case in a switch.
#if C10_HAS_CPP_ATTRIBUTE(fallthrough) #if C10_HAS_CPP_ATTRIBUTE(fallthrough)
#define C10_FALLTHROUGH [[fallthrough]] #define C10_FALLTHROUGH [[fallthrough]]

View File

@ -16,6 +16,9 @@ namespace torch {
namespace jit { namespace jit {
static inline TypePtr unwrapOptional(TypePtr opt_type) { static inline TypePtr unwrapOptional(TypePtr opt_type) {
if (auto dyn = opt_type->castRaw<c10::DynamicType>()) {
return unwrapOptional(dyn->fallback());
}
if (auto unwrap_list_type = opt_type->cast<OptionalType>()) { if (auto unwrap_list_type = opt_type->cast<OptionalType>()) {
return unwrap_list_type->getElementType(); return unwrap_list_type->getElementType();
} }
@ -282,12 +285,17 @@ static bool varargsCanBeUsedAsList(
bool is_last_argument = arg_index + 1 == schema.arguments().size() || bool is_last_argument = arg_index + 1 == schema.arguments().size() ||
schema.arguments()[arg_index + 1].kwarg_only(); schema.arguments()[arg_index + 1].kwarg_only();
auto arg_type = arg.type();
if (auto dyn = arg_type->castRaw<c10::DynamicType>()) {
arg_type = dyn->fallback();
}
// The formal must be a list // The formal must be a list
bool argument_is_list = arg.type()->kind() == TypeKind::ListType; bool argument_is_list = arg_type->kind() == TypeKind::ListType;
// matching varargs of typevar list nyi // matching varargs of typevar list nyi
bool typevar_list = argument_is_list && bool typevar_list = argument_is_list &&
arg.type()->castRaw<ListType>()->getElementType()->cast<VarType>(); arg_type->castRaw<ListType>()->getElementType()->cast<VarType>();
// it must not be a broadcasting list like int[3], // it must not be a broadcasting list like int[3],
// otherwise a single int is a valid input // otherwise a single int is a valid input

View File

@ -41,32 +41,33 @@ namespace jit {
TypePtr SchemaTypeParser::parseBaseType() { TypePtr SchemaTypeParser::parseBaseType() {
static std::unordered_map<std::string, TypePtr> type_map = { static std::unordered_map<std::string, TypePtr> type_map = {
{"Generator", GeneratorType::get()}, {"Generator", c10::TypeFactory::get<GeneratorType>()},
{"Dimname", StringType::get()}, {"Dimname", c10::TypeFactory::get<StringType>()},
{"ScalarType", IntType::get()}, {"ScalarType", c10::TypeFactory::get<IntType>()},
{"Layout", IntType::get()}, {"Layout", c10::TypeFactory::get<IntType>()},
{"MemoryFormat", IntType::get()}, {"MemoryFormat", c10::TypeFactory::get<IntType>()},
{"Storage", StorageType::get()}, {"Storage", c10::TypeFactory::get<StorageType>()},
{"QScheme", QSchemeType::get()}, {"QScheme", c10::TypeFactory::get<QSchemeType>()},
{"Quantizer", QuantizerType::get()}, {"Quantizer", c10::TypeFactory::get<QuantizerType>()},
{"ConstQuantizerPtr", {"ConstQuantizerPtr",
IntType::get()}, // TODO This type should be removed from the schema c10::TypeFactory::get<IntType>()}, // TODO This type should be removed
// parser, it should use the custom class mechanism // from the schema parser, it should
// use the custom class mechanism
// instead. @jerryzh // instead. @jerryzh
{"Device", DeviceObjType::get()}, {"Device", c10::TypeFactory::get<DeviceObjType>()},
{"Stream", StreamObjType::get()}, {"Stream", c10::TypeFactory::get<StreamObjType>()},
{"Scalar", NumberType::get()}, {"Scalar", c10::TypeFactory::get<NumberType>()},
{"str", StringType::get()}, {"str", c10::TypeFactory::get<StringType>()},
{"float", FloatType::get()}, {"float", c10::TypeFactory::get<FloatType>()},
{"complex", ComplexType::get()}, {"complex", c10::TypeFactory::get<ComplexType>()},
{"int", IntType::get()}, {"int", c10::TypeFactory::get<IntType>()},
{"bool", BoolType::get()}, {"bool", c10::TypeFactory::get<BoolType>()},
{"None", NoneType::get()}, {"None", c10::TypeFactory::get<NoneType>()},
{"NoneType", NoneType::get()}, {"NoneType", c10::TypeFactory::get<NoneType>()},
{"Capsule", CapsuleType::get()}, {"Capsule", c10::TypeFactory::get<CapsuleType>()},
{"Any", at::AnyType::get()}, {"Any", c10::TypeFactory::get<c10::AnyType>()},
{"AnyClassType", at::AnyClassType::get()}, {"AnyClassType", c10::TypeFactory::get<c10::AnyClassType>()},
{"AnyEnumType", at::AnyEnumType::get()}, {"AnyEnumType", c10::TypeFactory::get<c10::AnyEnumType>()},
}; };
auto tok = L.cur(); auto tok = L.cur();
if (!L.nextIf(TK_NONE) && !L.nextIf(TK_NONE_TYPE)) { if (!L.nextIf(TK_NONE) && !L.nextIf(TK_NONE_TYPE)) {
@ -79,7 +80,7 @@ TypePtr SchemaTypeParser::parseBaseType() {
if (text.size() > 0 && islower(text[0])) { if (text.size() > 0 && islower(text[0])) {
// lower case identifiers that are not otherwise valid types // lower case identifiers that are not otherwise valid types
// are treated as type variables // are treated as type variables
return VarType::create(text); return c10::TypeFactory::createNamed<VarType>(text);
} }
throw ErrorReport(tok.range) << "unknown type specifier"; throw ErrorReport(tok.range) << "unknown type specifier";
} }
@ -313,7 +314,7 @@ std::pair<TypePtr, c10::optional<AliasInfo>> SchemaTypeParser::parseType() {
alias_info->addContainedType(std::move(*r.second)); alias_info->addContainedType(std::move(*r.second));
} }
}); });
value = TupleType::create(std::move(types)); value = c10::TypeFactory::create<TupleType>(std::move(types));
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Future") { } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Future") {
L.next(); // Future L.next(); // Future
L.expect('('); L.expect('(');
@ -321,7 +322,7 @@ std::pair<TypePtr, c10::optional<AliasInfo>> SchemaTypeParser::parseType() {
auto subtype = std::move(p.first); auto subtype = std::move(p.first);
auto subalias = std::move(p.second); auto subalias = std::move(p.second);
L.expect(')'); L.expect(')');
value = FutureType::create(subtype); value = c10::TypeFactory::create<FutureType>(subtype);
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "RRef") { } else if (L.cur().kind == TK_IDENT && L.cur().text() == "RRef") {
L.next(); // RRef L.next(); // RRef
L.expect('('); L.expect('(');
@ -329,10 +330,10 @@ std::pair<TypePtr, c10::optional<AliasInfo>> SchemaTypeParser::parseType() {
auto subtype = std::move(p.first); auto subtype = std::move(p.first);
auto subalias = std::move(p.second); auto subalias = std::move(p.second);
L.expect(')'); L.expect(')');
value = RRefType::create(subtype); value = c10::TypeFactory::create<RRefType>(subtype);
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Tensor") { } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Tensor") {
L.next(); L.next();
value = TensorType::get(); value = c10::TypeFactory::get<TensorType>();
alias_info = parseAliasAnnotation(); alias_info = parseAliasAnnotation();
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Dict") { } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Dict") {
L.next(); L.next();
@ -342,7 +343,7 @@ std::pair<TypePtr, c10::optional<AliasInfo>> SchemaTypeParser::parseType() {
auto value_type = parseType().first; auto value_type = parseType().first;
L.expect(')'); L.expect(')');
alias_info = parseAliasAnnotation(); alias_info = parseAliasAnnotation();
value = DictType::create(key_type, value_type); value = c10::TypeFactory::create<DictType>(key_type, value_type);
} else if (L.cur().kind == TK_IDENT && L.cur().text() == "Union") { } else if (L.cur().kind == TK_IDENT && L.cur().text() == "Union") {
L.next(); L.next();
L.expect('('); L.expect('(');
@ -395,7 +396,7 @@ std::pair<TypePtr, c10::optional<AliasInfo>> SchemaTypeParser::parseType() {
if (L.cur().kind == '[' && L.lookahead().kind == ']') { if (L.cur().kind == '[' && L.lookahead().kind == ']') {
L.next(); // [ L.next(); // [
L.next(); // ] L.next(); // ]
value = ListType::create(value); value = c10::TypeFactory::create<ListType>(value);
auto container = parseAliasAnnotation(); auto container = parseAliasAnnotation();
if (container && alias_info) { if (container && alias_info) {
container->addContainedType(std::move(*alias_info)); container->addContainedType(std::move(*alias_info));

View File

@ -1485,6 +1485,9 @@ inline Value::Value(Node* node_, size_t offset_)
inline Value* Value::setType(TypePtr type) { inline Value* Value::setType(TypePtr type) {
AT_ASSERT(type); AT_ASSERT(type);
if (auto dyn = type->castRaw<c10::DynamicType>()) {
type = dyn->fallback();
}
type_ = std::move(type); type_ = std::move(type);
for (Use& use : uses_) { for (Use& use : uses_) {
use.user->op_ = nullptr; use.user->op_ = nullptr;

View File

@ -108,7 +108,11 @@ std::pair<IValue, IValue> getFunctionTuple(
static const std::string torch_prefix("__torch__"); static const std::string torch_prefix("__torch__");
static const std::string class_prefix("__torch__.torch.classes"); static const std::string class_prefix("__torch__.torch.classes");
for (const TypePtr& t : mobile_code.types_) { for (const TypePtr& ty : mobile_code.types_) {
auto t = ty;
if (auto dyn = t->castRaw<c10::DynamicType>()) {
t = dyn->fallback();
}
std::string type_str = t->annotation_str(); std::string type_str = t->annotation_str();
if (t->kind() == TypeKind::TupleType) { if (t->kind() == TypeKind::TupleType) {
TORCH_CHECK( TORCH_CHECK(
@ -216,9 +220,13 @@ std::pair<IValue, IValue> getFunctionTuple(
arg.type()->annotation_str(type_printer) => mangled unique name of the arg.type()->annotation_str(type_printer) => mangled unique name of the
module/submodule module/submodule
*/ */
auto arg_type = arg.type();
if (auto dyn = arg_type->castRaw<c10::DynamicType>()) {
arg_type = dyn->fallback();
}
argTables.emplace_back(Table({ argTables.emplace_back(Table({
{"name", arg.name()}, {"name", arg.name()},
{"type", arg.type()->annotation_str(type_printer)}, {"type", arg_type->annotation_str(type_printer)},
{"default_value", arg.default_value()}, {"default_value", arg.default_value()},
})); }));
} }

View File

@ -575,7 +575,11 @@ void Pickler::endTypeTag(const IValue& ivalue) {
// Push the dict type // Push the dict type
TORCH_INTERNAL_ASSERT(ivalue.type()); TORCH_INTERNAL_ASSERT(ivalue.type());
pushString(ivalue.type()->annotation_str()); auto type = ivalue.type();
if (auto dyn = type->castRaw<c10::DynamicType>()) {
type = dyn->fallback();
}
pushString(type->annotation_str());
// Pop the dict and type into a tuple // Pop the dict and type into a tuple
push<PickleOpCode>(PickleOpCode::TUPLE2); push<PickleOpCode>(PickleOpCode::TUPLE2);

View File

@ -34,7 +34,7 @@ static void restoreAccurateTypeTagsIfPossible(const IValue& root) {
// of the contained objects and cannot restore the tags. // of the contained objects and cannot restore the tags.
void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) { void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
struct Work { struct Work {
TypePtr static_type; TypePtr type;
IValue value; IValue value;
}; };
std::vector<Work> to_process = {{type_tag, root}}; std::vector<Work> to_process = {{type_tag, root}};
@ -53,7 +53,11 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
} }
scanned.emplace_hint(it, key); scanned.emplace_hint(it, key);
} }
switch (w.static_type->kind()) { auto kind = w.type->kind();
if (auto dyn = w.type->castRaw<c10::DynamicType>()) {
kind = dyn->dynamicKind();
}
switch (kind) {
case TensorType::Kind: case TensorType::Kind:
case StorageType::Kind: case StorageType::Kind:
case NumberType::Kind: case NumberType::Kind:
@ -83,52 +87,37 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
// no op, there is nothing to tag // no op, there is nothing to tag
break; break;
case DynamicType::Kind: case DynamicType::Kind:
case UnionType::Kind:
case EnumType::Kind: case EnumType::Kind:
// TODO(gmagogsfm): Implement serialization/deserialization of Enum. // TODO(gmagogsfm): Implement serialization/deserialization of Enum.
TORCH_INTERNAL_ASSERT(false); TORCH_INTERNAL_ASSERT(false);
case TupleType::Kind: { case TupleType::Kind: {
auto t = w.value.toTuple(); auto t = w.value.toTuple();
auto ttype = w.static_type->expect<TupleType>(); for (size_t i = 0; i < w.type->containedTypeSize(); ++i) {
for (size_t i = 0; i < ttype->containedTypes().size(); ++i) { Work elem = {w.type->containedType(i), t->elements().at(i)};
Work elem = {ttype->containedTypes().at(i), t->elements().at(i)};
to_process.emplace_back(std::move(elem)); to_process.emplace_back(std::move(elem));
} }
} break; } break;
case FutureType::Kind: { case FutureType::Kind: {
auto f = w.value.toFuture(); auto f = w.value.toFuture();
auto t = w.static_type->expect<FutureType>();
if (f->completed()) { if (f->completed()) {
Work elem = {t->getElementType(), f->value()}; Work elem = {w.type->containedType(0), f->value()};
to_process.emplace_back(std::move(elem)); to_process.emplace_back(std::move(elem));
} }
} break; } break;
case OptionalType::Kind: { case OptionalType::Kind: {
if (!w.value.isNone()) { if (!w.value.isNone()) {
auto t = w.static_type->expect<OptionalType>(); Work elem = {w.type->containedType(0), w.value};
Work elem = {t->getElementType(), w.value};
to_process.emplace_back(std::move(elem)); to_process.emplace_back(std::move(elem));
} }
} break; } break;
case UnionType::Kind: {
auto t = w.static_type->expect<UnionType>();
if (t->containedTypes().size() == 2 &&
t->canHoldType(*NoneType::get())) {
if (!w.value.isNone()) {
auto inner = t->containedTypes()[0] != NoneType::get()
? t->containedTypes()[0]
: t->containedTypes()[1];
Work elem = {inner, w.value};
to_process.emplace_back(std::move(elem));
}
}
} break;
case ListType::Kind: { case ListType::Kind: {
// specialized lists do not need their type refined, so we can exit // specialized lists do not need their type refined, so we can exit
// early here // early here
if (!w.value.isList()) { if (!w.value.isList()) {
break; break;
} }
auto elem_type = w.static_type->castRaw<ListType>()->getElementType(); auto elem_type = w.type->containedType(0);
auto lst = w.value.toList(); auto lst = w.value.toList();
lst.unsafeSetElementType(elem_type); lst.unsafeSetElementType(elem_type);
for (const IValue item : lst) { for (const IValue item : lst) {
@ -137,13 +126,14 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
} }
} break; } break;
case DictType::Kind: { case DictType::Kind: {
auto dt = w.static_type->cast<DictType>();
auto d = w.value.toGenericDict(); auto d = w.value.toGenericDict();
d.unsafeSetKeyType(dt->getKeyType()); auto keyType = w.type->containedType(0);
d.unsafeSetValueType(dt->getValueType()); auto valType = w.type->containedType(1);
d.unsafeSetKeyType(keyType);
d.unsafeSetValueType(valType);
for (const auto& item : d) { for (const auto& item : d) {
Work kelem = {dt->getKeyType(), item.key()}; Work kelem = {keyType, item.key()};
Work velem = {dt->getValueType(), item.value()}; Work velem = {valType, item.value()};
to_process.emplace_back(std::move(kelem)); to_process.emplace_back(std::move(kelem));
to_process.emplace_back(std::move(velem)); to_process.emplace_back(std::move(velem));
} }