More robust check of whether a class is defined in torch (#64083)

Summary:
This would prevent bugs for classes that
1) Is defined in a module that happens to start with `torch`, say `torchvision`
2) Is defined in torch but with an import alias like `import torch as th`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64083

Reviewed By: soulitzer

Differential Revision: D30598369

Pulled By: gmagogsfm

fbshipit-source-id: 9d3a7135737b2339c9bd32195e4e69a9c07549d4
This commit is contained in:
gmagogsfm 2021-08-27 08:49:54 -07:00 committed by Facebook GitHub Bot
parent f2c47cf4db
commit ad8eddbd80

View File

@ -1,6 +1,10 @@
import torch
import inspect
import typing
import pathlib
import sys
from typing import Optional, Iterable, List, Dict
from collections import defaultdict
from types import CodeType
@ -15,6 +19,18 @@ try:
except ImportError:
_IS_MONKEYTYPE_INSTALLED = False
# Checks whether a class is defind in `torch.*` modules
def is_torch_native_class(cls):
if not hasattr(cls, '__module__'):
return False
parent_modules = cls.__module__.split('.')
if not parent_modules:
return False
root_module = sys.modules.get(parent_modules[0])
return root_module is torch
def get_type(type):
"""
Helper function which converts the given type to a torchScript acceptable format.
@ -28,7 +44,7 @@ def get_type(type):
# typing.List is not accepted by TorchScript.
type_to_string = str(type)
return type_to_string.replace(type.__module__ + '.', '')
elif type.__module__.startswith('torch'):
elif is_torch_native_class(type):
# If the type is a subtype of torch module, then TorchScript expects a fully qualified name
# for the type which is obtained by combining the module name and type name.
return type.__module__ + '.' + type.__name__